-rw-r--r--.gitignore43
-rwxr-xr-xbin/ci24
-rwxr-xr-xbin/co24
-rwxr-xr-xbin/dd_extractbin0 -> 15976 bytes
-rwxr-xr-xbin/dd_listbin0 -> 16000 bytes
-rwxr-xr-xbin/ejectbin0 -> 50048 bytes
-rwxr-xr-xbin/identbin0 -> 113096 bytes
-rwxr-xr-xbin/lsscsibin0 -> 57832 bytes
-rwxr-xr-xbin/mergebin0 -> 113080 bytes
-rwxr-xr-xbin/rcsbin0 -> 179816 bytes
-rwxr-xr-xbin/rcsclean24
-rwxr-xr-xbin/rcsdiff24
-rwxr-xr-xbin/rcsfreeze132
-rwxr-xr-xbin/rcsmerge24
-rwxr-xr-xbin/rlog24
-rwxr-xr-xbin/sunhpc153
-rwxr-xr-xbin/syncDaemon167
-rw-r--r--data/.database1
-rw-r--r--data/sunhpc.dbbin0 -> 94208 bytes
-rw-r--r--etc/env.sunhpc35
-rw-r--r--etc/safeputrc4
-rw-r--r--lib/Crypto/Cipher/AES.py250
-rw-r--r--lib/Crypto/Cipher/AES.pyi47
-rw-r--r--lib/Crypto/Cipher/ARC2.py175
-rw-r--r--lib/Crypto/Cipher/ARC2.pyi35
-rw-r--r--lib/Crypto/Cipher/ARC4.py137
-rw-r--r--lib/Crypto/Cipher/ARC4.pyi16
-rw-r--r--lib/Crypto/Cipher/Blowfish.py159
-rw-r--r--lib/Crypto/Cipher/Blowfish.pyi35
-rw-r--r--lib/Crypto/Cipher/CAST.py159
-rw-r--r--lib/Crypto/Cipher/CAST.pyi35
-rw-r--r--lib/Crypto/Cipher/ChaCha20.py287
-rw-r--r--lib/Crypto/Cipher/ChaCha20.pyi25
-rw-r--r--lib/Crypto/Cipher/ChaCha20_Poly1305.py336
-rw-r--r--lib/Crypto/Cipher/ChaCha20_Poly1305.pyi28
-rw-r--r--lib/Crypto/Cipher/DES.py158
-rw-r--r--lib/Crypto/Cipher/DES.pyi35
-rw-r--r--lib/Crypto/Cipher/DES3.py187
-rw-r--r--lib/Crypto/Cipher/DES3.pyi37
-rw-r--r--lib/Crypto/Cipher/PKCS1_OAEP.py239
-rw-r--r--lib/Crypto/Cipher/PKCS1_OAEP.pyi35
-rw-r--r--lib/Crypto/Cipher/PKCS1_v1_5.py217
-rw-r--r--lib/Crypto/Cipher/PKCS1_v1_5.pyi20
-rw-r--r--lib/Crypto/Cipher/Salsa20.py167
-rw-r--r--lib/Crypto/Cipher/Salsa20.pyi27
-rwxr-xr-xlib/Crypto/Cipher/_ARC4.abi3.sobin0 -> 13768 bytes
-rw-r--r--lib/Crypto/Cipher/_EKSBlowfish.py131
-rw-r--r--lib/Crypto/Cipher/_EKSBlowfish.pyi15
-rwxr-xr-xlib/Crypto/Cipher/_Salsa20.abi3.sobin0 -> 26784 bytes
-rw-r--r--lib/Crypto/Cipher/__init__.py79
-rw-r--r--lib/Crypto/Cipher/__init__.pyi0
-rwxr-xr-xlib/Crypto/Cipher/_chacha20.abi3.sobin0 -> 28224 bytes
-rw-r--r--lib/Crypto/Cipher/_mode_cbc.py293
-rw-r--r--lib/Crypto/Cipher/_mode_cbc.pyi25
-rw-r--r--lib/Crypto/Cipher/_mode_ccm.py650
-rw-r--r--lib/Crypto/Cipher/_mode_ccm.pyi47
-rw-r--r--lib/Crypto/Cipher/_mode_cfb.py293
-rw-r--r--lib/Crypto/Cipher/_mode_cfb.pyi26
-rw-r--r--lib/Crypto/Cipher/_mode_ctr.py393
-rw-r--r--lib/Crypto/Cipher/_mode_ctr.pyi27
-rw-r--r--lib/Crypto/Cipher/_mode_eax.py408
-rw-r--r--lib/Crypto/Cipher/_mode_eax.pyi45
-rw-r--r--lib/Crypto/Cipher/_mode_ecb.py220
-rw-r--r--lib/Crypto/Cipher/_mode_ecb.pyi19
-rw-r--r--lib/Crypto/Cipher/_mode_gcm.py620
-rw-r--r--lib/Crypto/Cipher/_mode_gcm.pyi45
-rw-r--r--lib/Crypto/Cipher/_mode_ocb.py525
-rw-r--r--lib/Crypto/Cipher/_mode_ocb.pyi36
-rw-r--r--lib/Crypto/Cipher/_mode_ofb.py282
-rw-r--r--lib/Crypto/Cipher/_mode_ofb.pyi25
-rw-r--r--lib/Crypto/Cipher/_mode_openpgp.py206
-rw-r--r--lib/Crypto/Cipher/_mode_openpgp.pyi20
-rw-r--r--lib/Crypto/Cipher/_mode_siv.py392
-rw-r--r--lib/Crypto/Cipher/_mode_siv.pyi38
-rwxr-xr-xlib/Crypto/Cipher/_pkcs1_decode.abi3.sobin0 -> 28096 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_aes.abi3.sobin0 -> 66256 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_aesni.abi3.sobin0 -> 101136 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_arc2.abi3.sobin0 -> 43776 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_blowfish.abi3.sobin0 -> 69976 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_cast.abi3.sobin0 -> 42976 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_cbc.abi3.sobin0 -> 20736 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_cfb.abi3.sobin0 -> 25440 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_ctr.abi3.sobin0 -> 28600 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_des.abi3.sobin0 -> 75672 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_des3.abi3.sobin0 -> 76480 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_ecb.abi3.sobin0 -> 12440 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_eksblowfish.abi3.sobin0 -> 166264 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_ocb.abi3.sobin0 -> 37344 bytes
-rwxr-xr-xlib/Crypto/Cipher/_raw_ofb.abi3.sobin0 -> 15368 bytes
-rw-r--r--lib/Crypto/Hash/BLAKE2b.py247
-rw-r--r--lib/Crypto/Hash/BLAKE2b.pyi32
-rw-r--r--lib/Crypto/Hash/BLAKE2s.py247
-rw-r--r--lib/Crypto/Hash/BLAKE2s.pyi26
-rw-r--r--lib/Crypto/Hash/CMAC.py302
-rw-r--r--lib/Crypto/Hash/CMAC.pyi30
-rw-r--r--lib/Crypto/Hash/HMAC.py213
-rw-r--r--lib/Crypto/Hash/HMAC.pyi25
-rw-r--r--lib/Crypto/Hash/KMAC128.py179
-rw-r--r--lib/Crypto/Hash/KMAC128.pyi33
-rw-r--r--lib/Crypto/Hash/KMAC256.py74
-rw-r--r--lib/Crypto/Hash/KMAC256.pyi10
-rw-r--r--lib/Crypto/Hash/KangarooTwelve.py262
-rw-r--r--lib/Crypto/Hash/KangarooTwelve.pyi16
-rw-r--r--lib/Crypto/Hash/MD2.py166
-rw-r--r--lib/Crypto/Hash/MD2.pyi19
-rw-r--r--lib/Crypto/Hash/MD4.py185
-rw-r--r--lib/Crypto/Hash/MD4.pyi19
-rw-r--r--lib/Crypto/Hash/MD5.py184
-rw-r--r--lib/Crypto/Hash/MD5.pyi19
-rw-r--r--lib/Crypto/Hash/Poly1305.py217
-rw-r--r--lib/Crypto/Hash/Poly1305.pyi24
-rw-r--r--lib/Crypto/Hash/RIPEMD.py26
-rw-r--r--lib/Crypto/Hash/RIPEMD.pyi3
-rw-r--r--lib/Crypto/Hash/RIPEMD160.py169
-rw-r--r--lib/Crypto/Hash/RIPEMD160.pyi19
-rw-r--r--lib/Crypto/Hash/SHA.py24
-rw-r--r--lib/Crypto/Hash/SHA.pyi4
-rw-r--r--lib/Crypto/Hash/SHA1.py185
-rw-r--r--lib/Crypto/Hash/SHA1.pyi19
-rw-r--r--lib/Crypto/Hash/SHA224.py186
-rw-r--r--lib/Crypto/Hash/SHA224.pyi19
-rw-r--r--lib/Crypto/Hash/SHA256.py185
-rw-r--r--lib/Crypto/Hash/SHA256.pyi18
-rw-r--r--lib/Crypto/Hash/SHA384.py186
-rw-r--r--lib/Crypto/Hash/SHA384.pyi19
-rw-r--r--lib/Crypto/Hash/SHA3_224.py174
-rw-r--r--lib/Crypto/Hash/SHA3_224.pyi19
-rw-r--r--lib/Crypto/Hash/SHA3_256.py174
-rw-r--r--lib/Crypto/Hash/SHA3_256.pyi19
-rw-r--r--lib/Crypto/Hash/SHA3_384.py179
-rw-r--r--lib/Crypto/Hash/SHA3_384.pyi19
-rw-r--r--lib/Crypto/Hash/SHA3_512.py174
-rw-r--r--lib/Crypto/Hash/SHA3_512.pyi19
-rw-r--r--lib/Crypto/Hash/SHA512.py204
-rw-r--r--lib/Crypto/Hash/SHA512.pyi22
-rw-r--r--lib/Crypto/Hash/SHAKE128.py129
-rw-r--r--lib/Crypto/Hash/SHAKE128.pyi13
-rw-r--r--lib/Crypto/Hash/SHAKE256.py130
-rw-r--r--lib/Crypto/Hash/SHAKE256.pyi13
-rw-r--r--lib/Crypto/Hash/TupleHash128.py138
-rw-r--r--lib/Crypto/Hash/TupleHash128.pyi22
-rw-r--r--lib/Crypto/Hash/TupleHash256.py73
-rw-r--r--lib/Crypto/Hash/TupleHash256.pyi5
-rwxr-xr-xlib/Crypto/Hash/_BLAKE2b.abi3.sobin0 -> 21888 bytes
-rwxr-xr-xlib/Crypto/Hash/_BLAKE2s.abi3.sobin0 -> 21712 bytes
-rwxr-xr-xlib/Crypto/Hash/_MD2.abi3.sobin0 -> 20128 bytes
-rwxr-xr-xlib/Crypto/Hash/_MD4.abi3.sobin0 -> 25576 bytes
-rwxr-xr-xlib/Crypto/Hash/_MD5.abi3.sobin0 -> 31704 bytes
-rwxr-xr-xlib/Crypto/Hash/_RIPEMD160.abi3.sobin0 -> 55608 bytes
-rwxr-xr-xlib/Crypto/Hash/_SHA1.abi3.sobin0 -> 74416 bytes
-rwxr-xr-xlib/Crypto/Hash/_SHA224.abi3.sobin0 -> 43792 bytes
-rwxr-xr-xlib/Crypto/Hash/_SHA256.abi3.sobin0 -> 43872 bytes
-rwxr-xr-xlib/Crypto/Hash/_SHA384.abi3.sobin0 -> 50520 bytes
-rwxr-xr-xlib/Crypto/Hash/_SHA512.abi3.sobin0 -> 50624 bytes
-rw-r--r--lib/Crypto/Hash/__init__.py24
-rw-r--r--lib/Crypto/Hash/__init__.pyi0
-rwxr-xr-xlib/Crypto/Hash/_ghash_clmul.abi3.sobin0 -> 50160 bytes
-rwxr-xr-xlib/Crypto/Hash/_ghash_portable.abi3.sobin0 -> 17432 bytes
-rwxr-xr-xlib/Crypto/Hash/_keccak.abi3.sobin0 -> 35064 bytes
-rwxr-xr-xlib/Crypto/Hash/_poly1305.abi3.sobin0 -> 33360 bytes
-rw-r--r--lib/Crypto/Hash/cSHAKE128.py187
-rw-r--r--lib/Crypto/Hash/cSHAKE128.pyi14
-rw-r--r--lib/Crypto/Hash/cSHAKE256.py56
-rw-r--r--lib/Crypto/Hash/cSHAKE256.pyi8
-rw-r--r--lib/Crypto/Hash/keccak.py181
-rw-r--r--lib/Crypto/Hash/keccak.pyi23
-rw-r--r--lib/Crypto/IO/PEM.py189
-rw-r--r--lib/Crypto/IO/PEM.pyi10
-rw-r--r--lib/Crypto/IO/PKCS8.py239
-rw-r--r--lib/Crypto/IO/PKCS8.pyi14
-rw-r--r--lib/Crypto/IO/_PBES.py435
-rw-r--r--lib/Crypto/IO/_PBES.pyi19
-rw-r--r--lib/Crypto/IO/__init__.py31
-rw-r--r--lib/Crypto/Math/Numbers.py42
-rw-r--r--lib/Crypto/Math/Numbers.pyi4
-rw-r--r--lib/Crypto/Math/Primality.py369
-rw-r--r--lib/Crypto/Math/Primality.pyi18
-rw-r--r--lib/Crypto/Math/_IntegerBase.py392
-rw-r--r--lib/Crypto/Math/_IntegerBase.pyi61
-rw-r--r--lib/Crypto/Math/_IntegerCustom.py118
-rw-r--r--lib/Crypto/Math/_IntegerCustom.pyi8
-rw-r--r--lib/Crypto/Math/_IntegerGMP.py762
-rw-r--r--lib/Crypto/Math/_IntegerGMP.pyi3
-rw-r--r--lib/Crypto/Math/_IntegerNative.py395
-rw-r--r--lib/Crypto/Math/_IntegerNative.pyi3
-rw-r--r--lib/Crypto/Math/__init__.py0
-rwxr-xr-xlib/Crypto/Math/_modexp.abi3.sobin0 -> 294464 bytes
-rw-r--r--lib/Crypto/Protocol/KDF.py574
-rw-r--r--lib/Crypto/Protocol/KDF.pyi24
-rw-r--r--lib/Crypto/Protocol/SecretSharing.py278
-rw-r--r--lib/Crypto/Protocol/SecretSharing.pyi22
-rw-r--r--lib/Crypto/Protocol/__init__.py31
-rw-r--r--lib/Crypto/Protocol/__init__.pyi1
-rwxr-xr-xlib/Crypto/Protocol/_scrypt.abi3.sobin0 -> 25024 bytes
-rw-r--r--lib/Crypto/PublicKey/DSA.py682
-rw-r--r--lib/Crypto/PublicKey/DSA.pyi31
-rw-r--r--lib/Crypto/PublicKey/ECC.py1794
-rw-r--r--lib/Crypto/PublicKey/ECC.pyi66
-rw-r--r--lib/Crypto/PublicKey/ElGamal.py286
-rw-r--r--lib/Crypto/PublicKey/ElGamal.pyi18
-rw-r--r--lib/Crypto/PublicKey/RSA.py802
-rw-r--r--lib/Crypto/PublicKey/RSA.pyi51
-rw-r--r--lib/Crypto/PublicKey/__init__.py94
-rw-r--r--lib/Crypto/PublicKey/__init__.pyi0
-rwxr-xr-xlib/Crypto/PublicKey/_ec_ws.abi3.sobin0 -> 1068008 bytes
-rwxr-xr-xlib/Crypto/PublicKey/_ed25519.abi3.sobin0 -> 578280 bytes
-rwxr-xr-xlib/Crypto/PublicKey/_ed448.abi3.sobin0 -> 329424 bytes
-rw-r--r--lib/Crypto/PublicKey/_openssh.py135
-rw-r--r--lib/Crypto/PublicKey/_openssh.pyi7
-rwxr-xr-xlib/Crypto/PublicKey/_x25519.abi3.sobin0 -> 124632 bytes
-rw-r--r--lib/Crypto/Random/__init__.py57
-rw-r--r--lib/Crypto/Random/__init__.pyi19
-rw-r--r--lib/Crypto/Random/random.py138
-rw-r--r--lib/Crypto/Random/random.pyi20
-rw-r--r--lib/Crypto/SelfTest/Cipher/__init__.py60
-rw-r--r--lib/Crypto/SelfTest/Cipher/common.py510
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_AES.py1351
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_ARC2.py167
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_ARC4.py466
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_Blowfish.py160
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_CAST.py101
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_CBC.py556
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_CCM.py936
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_CFB.py411
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_CTR.py472
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_ChaCha20.py529
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_ChaCha20_Poly1305.py770
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_DES.py374
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_DES3.py195
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_EAX.py773
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_GCM.py951
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_OCB.py742
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_OFB.py238
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_OpenPGP.py218
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_SIV.py552
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_Salsa20.py367
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_pkcs1_15.py283
-rw-r--r--lib/Crypto/SelfTest/Cipher/test_pkcs1_oaep.py506
-rw-r--r--lib/Crypto/SelfTest/Hash/__init__.py61
-rw-r--r--lib/Crypto/SelfTest/Hash/common.py290
-rw-r--r--lib/Crypto/SelfTest/Hash/test_BLAKE2.py482
-rw-r--r--lib/Crypto/SelfTest/Hash/test_CMAC.py448
-rw-r--r--lib/Crypto/SelfTest/Hash/test_HMAC.py548
-rw-r--r--lib/Crypto/SelfTest/Hash/test_KMAC.py346
-rw-r--r--lib/Crypto/SelfTest/Hash/test_KangarooTwelve.py324
-rw-r--r--lib/Crypto/SelfTest/Hash/test_MD2.py62
-rw-r--r--lib/Crypto/SelfTest/Hash/test_MD4.py64
-rw-r--r--lib/Crypto/SelfTest/Hash/test_MD5.py94
-rw-r--r--lib/Crypto/SelfTest/Hash/test_Poly1305.py542
-rw-r--r--lib/Crypto/SelfTest/Hash/test_RIPEMD160.py71
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA1.py84
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA224.py63
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA256.py94
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA384.py61
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA3_224.py79
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA3_256.py80
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA3_384.py79
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA3_512.py79
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHA512.py140
-rw-r--r--lib/Crypto/SelfTest/Hash/test_SHAKE.py143
-rw-r--r--lib/Crypto/SelfTest/Hash/test_TupleHash.py286
-rw-r--r--lib/Crypto/SelfTest/Hash/test_cSHAKE.py178
-rw-r--r--lib/Crypto/SelfTest/Hash/test_keccak.py250
-rw-r--r--lib/Crypto/SelfTest/IO/__init__.py47
-rw-r--r--lib/Crypto/SelfTest/IO/test_PBES.py93
-rw-r--r--lib/Crypto/SelfTest/IO/test_PKCS8.py425
-rw-r--r--lib/Crypto/SelfTest/Math/__init__.py49
-rw-r--r--lib/Crypto/SelfTest/Math/test_Numbers.py797
-rw-r--r--lib/Crypto/SelfTest/Math/test_Primality.py118
-rw-r--r--lib/Crypto/SelfTest/Math/test_modexp.py201
-rw-r--r--lib/Crypto/SelfTest/Protocol/__init__.py44
-rw-r--r--lib/Crypto/SelfTest/Protocol/test_KDF.py732
-rw-r--r--lib/Crypto/SelfTest/Protocol/test_SecretSharing.py267
-rw-r--r--lib/Crypto/SelfTest/Protocol/test_rfc1751.py62
-rw-r--r--lib/Crypto/SelfTest/PublicKey/__init__.py53
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_DSA.py247
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_ECC_25519.py327
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_ECC_448.py327
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_ECC_NIST.py1389
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_ElGamal.py217
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_RSA.py317
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_import_DSA.py554
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_import_ECC.py2643
-rw-r--r--lib/Crypto/SelfTest/PublicKey/test_import_RSA.py590
-rw-r--r--lib/Crypto/SelfTest/Random/__init__.py39
-rw-r--r--lib/Crypto/SelfTest/Random/test_random.py167
-rw-r--r--lib/Crypto/SelfTest/Signature/__init__.py41
-rw-r--r--lib/Crypto/SelfTest/Signature/test_dss.py1369
-rw-r--r--lib/Crypto/SelfTest/Signature/test_eddsa.py578
-rw-r--r--lib/Crypto/SelfTest/Signature/test_pkcs1_15.py348
-rw-r--r--lib/Crypto/SelfTest/Signature/test_pss.py377
-rw-r--r--lib/Crypto/SelfTest/Util/__init__.py46
-rw-r--r--lib/Crypto/SelfTest/Util/test_Counter.py67
-rw-r--r--lib/Crypto/SelfTest/Util/test_Padding.py154
-rw-r--r--lib/Crypto/SelfTest/Util/test_asn1.py784
-rw-r--r--lib/Crypto/SelfTest/Util/test_number.py192
-rw-r--r--lib/Crypto/SelfTest/Util/test_rfc1751.py38
-rw-r--r--lib/Crypto/SelfTest/Util/test_strxor.py280
-rw-r--r--lib/Crypto/SelfTest/__init__.py97
-rw-r--r--lib/Crypto/SelfTest/__main__.py38
-rw-r--r--lib/Crypto/SelfTest/loader.py206
-rw-r--r--lib/Crypto/SelfTest/st_common.py55
-rw-r--r--lib/Crypto/Signature/DSS.py403
-rw-r--r--lib/Crypto/Signature/DSS.pyi27
-rw-r--r--lib/Crypto/Signature/PKCS1_PSS.py55
-rw-r--r--lib/Crypto/Signature/PKCS1_PSS.pyi7
-rw-r--r--lib/Crypto/Signature/PKCS1_v1_5.py53
-rw-r--r--lib/Crypto/Signature/PKCS1_v1_5.pyi6
-rw-r--r--lib/Crypto/Signature/__init__.py36
-rw-r--r--lib/Crypto/Signature/eddsa.py341
-rw-r--r--lib/Crypto/Signature/eddsa.pyi21
-rw-r--r--lib/Crypto/Signature/pkcs1_15.py222
-rw-r--r--lib/Crypto/Signature/pkcs1_15.pyi17
-rw-r--r--lib/Crypto/Signature/pss.py386
-rw-r--r--lib/Crypto/Signature/pss.pyi30
-rw-r--r--lib/Crypto/Util/Counter.py77
-rw-r--r--lib/Crypto/Util/Counter.pyi5
-rw-r--r--lib/Crypto/Util/Padding.py108
-rw-r--r--lib/Crypto/Util/Padding.pyi6
-rw-r--r--lib/Crypto/Util/RFC1751.py386
-rw-r--r--lib/Crypto/Util/RFC1751.pyi7
-rw-r--r--lib/Crypto/Util/__init__.py41
-rw-r--r--lib/Crypto/Util/_cpu_features.py46
-rw-r--r--lib/Crypto/Util/_cpu_features.pyi2
-rwxr-xr-xlib/Crypto/Util/_cpuid_c.abi3.sobin0 -> 12776 bytes
-rw-r--r--lib/Crypto/Util/_file_system.py54
-rw-r--r--lib/Crypto/Util/_file_system.pyi4
-rw-r--r--lib/Crypto/Util/_raw_api.py319
-rw-r--r--lib/Crypto/Util/_raw_api.pyi27
-rwxr-xr-xlib/Crypto/Util/_strxor.abi3.sobin0 -> 14960 bytes
-rw-r--r--lib/Crypto/Util/asn1.py939
-rw-r--r--lib/Crypto/Util/asn1.pyi74
-rw-r--r--lib/Crypto/Util/number.py1500
-rw-r--r--lib/Crypto/Util/number.pyi19
-rw-r--r--lib/Crypto/Util/py3compat.py174
-rw-r--r--lib/Crypto/Util/py3compat.pyi33
-rw-r--r--lib/Crypto/Util/strxor.py146
-rw-r--r--lib/Crypto/Util/strxor.pyi6
-rw-r--r--lib/Crypto/__init__.py6
-rw-r--r--lib/Crypto/__init__.pyi4
-rw-r--r--lib/Crypto/py.typed0
-rw-r--r--lib/SQLAlchemy-1.4.40.dist-info/INSTALLER1
-rw-r--r--lib/SQLAlchemy-1.4.40.dist-info/LICENSE19
-rw-r--r--lib/SQLAlchemy-1.4.40.dist-info/METADATA237
-rw-r--r--lib/SQLAlchemy-1.4.40.dist-info/RECORD486
-rw-r--r--lib/SQLAlchemy-1.4.40.dist-info/REQUESTED0
-rw-r--r--lib/SQLAlchemy-1.4.40.dist-info/WHEEL8
-rw-r--r--lib/SQLAlchemy-1.4.40.dist-info/top_level.txt1
-rwxr-xr-xlib/_cffi_backend.cpython-39-x86_64-linux-gnu.sobin0 -> 981144 bytes
-rwxr-xr-xlib/_dbus_bindings.cpython-310-x86_64-linux-gnu.sobin0 -> 565840 bytes
-rwxr-xr-xlib/_dbus_glib_bindings.cpython-310-x86_64-linux-gnu.sobin0 -> 63520 bytes
-rwxr-xr-xlib/_snack.sobin0 -> 45432 bytes
-rwxr-xr-xlib/bin/chardetect8
-rw-r--r--lib/cffi-1.15.1.dist-info/INSTALLER1
-rw-r--r--lib/cffi-1.15.1.dist-info/LICENSE26
-rw-r--r--lib/cffi-1.15.1.dist-info/METADATA34
-rw-r--r--lib/cffi-1.15.1.dist-info/RECORD45
-rw-r--r--lib/cffi-1.15.1.dist-info/REQUESTED0
-rw-r--r--lib/cffi-1.15.1.dist-info/WHEEL6
-rw-r--r--lib/cffi-1.15.1.dist-info/entry_points.txt2
-rw-r--r--lib/cffi-1.15.1.dist-info/top_level.txt2
-rw-r--r--lib/cffi/__init__.py14
-rw-r--r--lib/cffi/_cffi_errors.h149
-rw-r--r--lib/cffi/_cffi_include.h385
-rw-r--r--lib/cffi/_embedding.h528
-rw-r--r--lib/cffi/api.py965
-rw-r--r--lib/cffi/backend_ctypes.py1121
-rw-r--r--lib/cffi/cffi_opcode.py187
-rw-r--r--lib/cffi/commontypes.py80
-rw-r--r--lib/cffi/cparser.py1006
-rw-r--r--lib/cffi/error.py31
-rw-r--r--lib/cffi/ffiplatform.py127
-rw-r--r--lib/cffi/lock.py30
-rw-r--r--lib/cffi/model.py617
-rw-r--r--lib/cffi/parse_c_type.h181
-rw-r--r--lib/cffi/pkgconfig.py121
-rw-r--r--lib/cffi/recompiler.py1581
-rw-r--r--lib/cffi/setuptools_ext.py219
-rw-r--r--lib/cffi/vengine_cpy.py1076
-rw-r--r--lib/cffi/vengine_gen.py675
-rw-r--r--lib/cffi/verifier.py307
-rw-r--r--lib/chardet-5.0.0.dist-info/INSTALLER1
-rw-r--r--lib/chardet-5.0.0.dist-info/LICENSE502
-rw-r--r--lib/chardet-5.0.0.dist-info/METADATA100
-rw-r--r--lib/chardet-5.0.0.dist-info/RECORD99
-rw-r--r--lib/chardet-5.0.0.dist-info/REQUESTED0
-rw-r--r--lib/chardet-5.0.0.dist-info/WHEEL5
-rw-r--r--lib/chardet-5.0.0.dist-info/entry_points.txt3
-rw-r--r--lib/chardet-5.0.0.dist-info/top_level.txt1
-rw-r--r--lib/chardet/__init__.py93
-rw-r--r--lib/chardet/big5freq.py386
-rw-r--r--lib/chardet/big5prober.py47
-rw-r--r--lib/chardet/chardistribution.py259
-rw-r--r--lib/chardet/charsetgroupprober.py109
-rw-r--r--lib/chardet/charsetprober.py138
-rw-r--r--lib/chardet/cli/__init__.py0
-rw-r--r--lib/chardet/cli/chardetect.py86
-rw-r--r--lib/chardet/codingstatemachine.py88
-rw-r--r--lib/chardet/cp949prober.py49
-rw-r--r--lib/chardet/enums.py82
-rw-r--r--lib/chardet/escprober.py102
-rw-r--r--lib/chardet/escsm.py260
-rw-r--r--lib/chardet/eucjpprober.py95
-rw-r--r--lib/chardet/euckrfreq.py196
-rw-r--r--lib/chardet/euckrprober.py47
-rw-r--r--lib/chardet/euctwfreq.py388
-rw-r--r--lib/chardet/euctwprober.py47
-rw-r--r--lib/chardet/gb2312freq.py284
-rw-r--r--lib/chardet/gb2312prober.py47
-rw-r--r--lib/chardet/hebrewprober.py302
-rw-r--r--lib/chardet/jisfreq.py325
-rw-r--r--lib/chardet/johabfreq.py2382
-rw-r--r--lib/chardet/johabprober.py47
-rw-r--r--lib/chardet/jpcntx.py237
-rw-r--r--lib/chardet/langbulgarianmodel.py4649
-rw-r--r--lib/chardet/langgreekmodel.py4397
-rw-r--r--lib/chardet/langhebrewmodel.py4380
-rw-r--r--lib/chardet/langhungarianmodel.py4649
-rw-r--r--lib/chardet/langrussianmodel.py5725
-rw-r--r--lib/chardet/langthaimodel.py4380
-rw-r--r--lib/chardet/langturkishmodel.py4380
-rw-r--r--lib/chardet/latin1prober.py145
-rw-r--r--lib/chardet/mbcharsetprober.py95
-rw-r--r--lib/chardet/mbcsgroupprober.py56
-rw-r--r--lib/chardet/mbcssm.py660
-rw-r--r--lib/chardet/metadata/__init__.py0
-rw-r--r--lib/chardet/metadata/languages.py351
-rw-r--r--lib/chardet/sbcharsetprober.py160
-rw-r--r--lib/chardet/sbcsgroupprober.py88
-rw-r--r--lib/chardet/sjisprober.py98
-rw-r--r--lib/chardet/universaldetector.py328
-rw-r--r--lib/chardet/utf1632prober.py223
-rw-r--r--lib/chardet/utf8prober.py80
-rw-r--r--lib/chardet/version.py9
-rw-r--r--lib/dbus/__init__.py93
-rw-r--r--lib/dbus/_compat.py15
-rw-r--r--lib/dbus/_dbus.py229
-rw-r--r--lib/dbus/_expat_introspect_parser.py87
-rw-r--r--lib/dbus/bus.py434
-rw-r--r--lib/dbus/connection.py651
-rw-r--r--lib/dbus/decorators.py362
-rw-r--r--lib/dbus/exceptions.py133
-rw-r--r--lib/dbus/gi_service.py87
-rw-r--r--lib/dbus/glib.py53
-rw-r--r--lib/dbus/lowlevel.py38
-rw-r--r--lib/dbus/mainloop/__init__.py64
-rw-r--r--lib/dbus/mainloop/glib.py43
-rw-r--r--lib/dbus/proxies.py567
-rw-r--r--lib/dbus/server.py119
-rw-r--r--lib/dbus/service.py840
-rw-r--r--lib/dbus/types.py15
-rw-r--r--lib/greenlet-1.1.3.dist-info/AUTHORS51
-rw-r--r--lib/greenlet-1.1.3.dist-info/INSTALLER1
-rw-r--r--lib/greenlet-1.1.3.dist-info/LICENSE30
-rw-r--r--lib/greenlet-1.1.3.dist-info/LICENSE.PSF47
-rw-r--r--lib/greenlet-1.1.3.dist-info/METADATA103
-rw-r--r--lib/greenlet-1.1.3.dist-info/RECORD71
-rw-r--r--lib/greenlet-1.1.3.dist-info/WHEEL6
-rw-r--r--lib/greenlet-1.1.3.dist-info/top_level.txt1
-rw-r--r--lib/greenlet/__init__.py63
-rwxr-xr-xlib/greenlet/_greenlet.cpython-39-x86_64-linux-gnu.sobin0 -> 130456 bytes
-rw-r--r--lib/greenlet/greenlet.c2170
-rw-r--r--lib/greenlet/greenlet.h161
-rw-r--r--lib/greenlet/platform/setup_switch_x64_masm.cmd2
-rw-r--r--lib/greenlet/platform/switch_aarch64_gcc.h69
-rw-r--r--lib/greenlet/platform/switch_alpha_unix.h30
-rw-r--r--lib/greenlet/platform/switch_amd64_unix.h84
-rw-r--r--lib/greenlet/platform/switch_arm32_gcc.h79
-rw-r--r--lib/greenlet/platform/switch_arm32_ios.h67
-rw-r--r--lib/greenlet/platform/switch_csky_gcc.h48
-rw-r--r--lib/greenlet/platform/switch_m68k_gcc.h38
-rw-r--r--lib/greenlet/platform/switch_mips_unix.h64
-rw-r--r--lib/greenlet/platform/switch_ppc64_aix.h103
-rw-r--r--lib/greenlet/platform/switch_ppc64_linux.h105
-rw-r--r--lib/greenlet/platform/switch_ppc_aix.h87
-rw-r--r--lib/greenlet/platform/switch_ppc_linux.h84
-rw-r--r--lib/greenlet/platform/switch_ppc_macosx.h82
-rw-r--r--lib/greenlet/platform/switch_ppc_unix.h82
-rw-r--r--lib/greenlet/platform/switch_riscv_unix.h32
-rw-r--r--lib/greenlet/platform/switch_s390_unix.h87
-rw-r--r--lib/greenlet/platform/switch_sparc_sun_gcc.h92
-rw-r--r--lib/greenlet/platform/switch_x32_unix.h63
-rw-r--r--lib/greenlet/platform/switch_x64_masm.asm111
-rw-r--r--lib/greenlet/platform/switch_x64_masm.objbin0 -> 1078 bytes
-rw-r--r--lib/greenlet/platform/switch_x64_msvc.h60
-rw-r--r--lib/greenlet/platform/switch_x86_msvc.h88
-rw-r--r--lib/greenlet/platform/switch_x86_unix.h105
-rw-r--r--lib/greenlet/slp_platformselect.h58
-rw-r--r--lib/greenlet/tests/__init__.py0
-rw-r--r--lib/greenlet/tests/_test_extension.c216
-rwxr-xr-xlib/greenlet/tests/_test_extension.cpython-39-x86_64-linux-gnu.sobin0 -> 34632 bytes
-rw-r--r--lib/greenlet/tests/_test_extension_cpp.cpp121
-rwxr-xr-xlib/greenlet/tests/_test_extension_cpp.cpython-39-x86_64-linux-gnu.sobin0 -> 47368 bytes
-rw-r--r--lib/greenlet/tests/test_contextvars.py266
-rw-r--r--lib/greenlet/tests/test_cpp.py18
-rw-r--r--lib/greenlet/tests/test_extension_interface.py77
-rw-r--r--lib/greenlet/tests/test_gc.py77
-rw-r--r--lib/greenlet/tests/test_generator.py59
-rw-r--r--lib/greenlet/tests/test_generator_nested.py165
-rw-r--r--lib/greenlet/tests/test_greenlet.py728
-rw-r--r--lib/greenlet/tests/test_leaks.py178
-rw-r--r--lib/greenlet/tests/test_stack_saved.py19
-rw-r--r--lib/greenlet/tests/test_throw.py100
-rw-r--r--lib/greenlet/tests/test_tracing.py267
-rw-r--r--lib/greenlet/tests/test_version.py39
-rw-r--r--lib/greenlet/tests/test_weakref.py34
-rw-r--r--lib/include/python/greenlet/greenlet.h161
-rw-r--r--lib/pexpect-4.8.0.dist-info/INSTALLER1
-rw-r--r--lib/pexpect-4.8.0.dist-info/LICENSE20
-rw-r--r--lib/pexpect-4.8.0.dist-info/METADATA49
-rw-r--r--lib/pexpect-4.8.0.dist-info/RECORD38
-rw-r--r--lib/pexpect-4.8.0.dist-info/REQUESTED0
-rw-r--r--lib/pexpect-4.8.0.dist-info/WHEEL6
-rw-r--r--lib/pexpect-4.8.0.dist-info/top_level.txt1
-rw-r--r--lib/pexpect/ANSI.py351
-rw-r--r--lib/pexpect/FSM.py334
-rw-r--r--lib/pexpect/__init__.py85
-rw-r--r--lib/pexpect/_async.py103
-rw-r--r--lib/pexpect/bashrc.sh16
-rw-r--r--lib/pexpect/exceptions.py35
-rw-r--r--lib/pexpect/expect.py371
-rw-r--r--lib/pexpect/fdpexpect.py148
-rw-r--r--lib/pexpect/popen_spawn.py188
-rw-r--r--lib/pexpect/pty_spawn.py860
-rw-r--r--lib/pexpect/pxssh.py537
-rw-r--r--lib/pexpect/replwrap.py130
-rw-r--r--lib/pexpect/run.py157
-rw-r--r--lib/pexpect/screen.py431
-rw-r--r--lib/pexpect/spawnbase.py525
-rw-r--r--lib/pexpect/utils.py187
-rw-r--r--lib/prettytable-3.6.0.dist-info/INSTALLER1
-rw-r--r--lib/prettytable-3.6.0.dist-info/METADATA702
-rw-r--r--lib/prettytable-3.6.0.dist-info/RECORD13
-rw-r--r--lib/prettytable-3.6.0.dist-info/REQUESTED0
-rw-r--r--lib/prettytable-3.6.0.dist-info/WHEEL4
-rw-r--r--lib/prettytable-3.6.0.dist-info/licenses/COPYING30
-rw-r--r--lib/prettytable/__init__.py54
-rw-r--r--lib/prettytable/colortable.py97
-rw-r--r--lib/prettytable/prettytable.py2531
-rw-r--r--lib/prettytable/py.typed0
-rw-r--r--lib/psutil-5.9.4.dist-info/INSTALLER1
-rw-r--r--lib/psutil-5.9.4.dist-info/LICENSE29
-rw-r--r--lib/psutil-5.9.4.dist-info/METADATA526
-rw-r--r--lib/psutil-5.9.4.dist-info/RECORD65
-rw-r--r--lib/psutil-5.9.4.dist-info/REQUESTED0
-rw-r--r--lib/psutil-5.9.4.dist-info/WHEEL8
-rw-r--r--lib/psutil-5.9.4.dist-info/top_level.txt1
-rw-r--r--lib/psutil/__init__.py2421
-rw-r--r--lib/psutil/_common.py899
-rw-r--r--lib/psutil/_compat.py450
-rw-r--r--lib/psutil/_psaix.py555
-rw-r--r--lib/psutil/_psbsd.py927
-rw-r--r--lib/psutil/_pslinux.py2257
-rw-r--r--lib/psutil/_psosx.py543
-rw-r--r--lib/psutil/_psposix.py232
-rw-r--r--lib/psutil/_pssunos.py727
-rwxr-xr-xlib/psutil/_psutil_linux.abi3.sobin0 -> 107400 bytes
-rwxr-xr-xlib/psutil/_psutil_posix.abi3.sobin0 -> 71008 bytes
-rw-r--r--lib/psutil/_pswindows.py1120
-rw-r--r--lib/psutil/tests/__init__.py1820
-rw-r--r--lib/psutil/tests/__main__.py15
-rw-r--r--lib/psutil/tests/runner.py350
-rw-r--r--lib/psutil/tests/test_aix.py122
-rw-r--r--lib/psutil/tests/test_bsd.py568
-rw-r--r--lib/psutil/tests/test_connections.py554
-rw-r--r--lib/psutil/tests/test_contracts.py751
-rw-r--r--lib/psutil/tests/test_linux.py2286
-rw-r--r--lib/psutil/tests/test_memleaks.py492
-rw-r--r--lib/psutil/tests/test_misc.py852
-rw-r--r--lib/psutil/tests/test_osx.py241
-rw-r--r--lib/psutil/tests/test_posix.py432
-rw-r--r--lib/psutil/tests/test_process.py1591
-rw-r--r--lib/psutil/tests/test_sunos.py46
-rw-r--r--lib/psutil/tests/test_system.py892
-rw-r--r--lib/psutil/tests/test_testutils.py441
-rw-r--r--lib/psutil/tests/test_unicode.py355
-rw-r--r--lib/psutil/tests/test_windows.py898
-rw-r--r--lib/ptyprocess-0.7.0.dist-info/INSTALLER1
-rw-r--r--lib/ptyprocess-0.7.0.dist-info/LICENSE16
-rw-r--r--lib/ptyprocess-0.7.0.dist-info/METADATA37
-rw-r--r--lib/ptyprocess-0.7.0.dist-info/RECORD13
-rw-r--r--lib/ptyprocess-0.7.0.dist-info/WHEEL5
-rw-r--r--lib/ptyprocess/__init__.py4
-rw-r--r--lib/ptyprocess/_fork_pty.py78
-rw-r--r--lib/ptyprocess/ptyprocess.py842
-rw-r--r--lib/ptyprocess/util.py71
-rw-r--r--lib/pycparser-2.21.dist-info/INSTALLER1
-rw-r--r--lib/pycparser-2.21.dist-info/LICENSE27
-rw-r--r--lib/pycparser-2.21.dist-info/METADATA31
-rw-r--r--lib/pycparser-2.21.dist-info/RECORD41
-rw-r--r--lib/pycparser-2.21.dist-info/WHEEL6
-rw-r--r--lib/pycparser-2.21.dist-info/top_level.txt1
-rw-r--r--lib/pycparser/__init__.py90
-rw-r--r--lib/pycparser/_ast_gen.py336
-rw-r--r--lib/pycparser/_build_tables.py37
-rw-r--r--lib/pycparser/_c_ast.cfg195
-rw-r--r--lib/pycparser/ast_transforms.py164
-rw-r--r--lib/pycparser/c_ast.py1125
-rw-r--r--lib/pycparser/c_generator.py502
-rw-r--r--lib/pycparser/c_lexer.py554
-rw-r--r--lib/pycparser/c_parser.py1936
-rw-r--r--lib/pycparser/lextab.py10
-rw-r--r--lib/pycparser/ply/__init__.py5
-rw-r--r--lib/pycparser/ply/cpp.py905
-rw-r--r--lib/pycparser/ply/ctokens.py133
-rw-r--r--lib/pycparser/ply/lex.py1099
-rw-r--r--lib/pycparser/ply/yacc.py3494
-rw-r--r--lib/pycparser/ply/ygen.py74
-rw-r--r--lib/pycparser/plyparser.py133
-rw-r--r--lib/pycparser/yacctab.py366
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/AUTHORS.rst50
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/INSTALLER1
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/LICENSE.rst61
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/METADATA84
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/RECORD513
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/REQUESTED0
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/WHEEL5
-rw-r--r--lib/pycryptodome-3.15.0.dist-info/top_level.txt1
-rw-r--r--lib/snack.py998
-rw-r--r--lib/sqlalchemy/__init__.py158
-rwxr-xr-xlib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.sobin0 -> 53952 bytes
-rw-r--r--lib/sqlalchemy/connectors/__init__.py10
-rw-r--r--lib/sqlalchemy/connectors/mxodbc.py166
-rw-r--r--lib/sqlalchemy/connectors/pyodbc.py193
-rwxr-xr-xlib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.sobin0 -> 60640 bytes
-rwxr-xr-xlib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.sobin0 -> 92632 bytes
-rw-r--r--lib/sqlalchemy/databases/__init__.py38
-rw-r--r--lib/sqlalchemy/dialects/__init__.py72
-rw-r--r--lib/sqlalchemy/dialects/firebird/__init__.py41
-rw-r--r--lib/sqlalchemy/dialects/firebird/base.py989
-rw-r--r--lib/sqlalchemy/dialects/firebird/fdb.py112
-rw-r--r--lib/sqlalchemy/dialects/firebird/kinterbasdb.py202
-rw-r--r--lib/sqlalchemy/dialects/mssql/__init__.py85
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py3545
-rw-r--r--lib/sqlalchemy/dialects/mssql/information_schema.py232
-rw-r--r--lib/sqlalchemy/dialects/mssql/json.py125
-rw-r--r--lib/sqlalchemy/dialects/mssql/mxodbc.py150
-rw-r--r--lib/sqlalchemy/dialects/mssql/provision.py116
-rw-r--r--lib/sqlalchemy/dialects/mssql/pymssql.py138
-rw-r--r--lib/sqlalchemy/dialects/mssql/pyodbc.py673
-rw-r--r--lib/sqlalchemy/dialects/mysql/__init__.py103
-rw-r--r--lib/sqlalchemy/dialects/mysql/aiomysql.py317
-rw-r--r--lib/sqlalchemy/dialects/mysql/asyncmy.py328
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py3306
-rw-r--r--lib/sqlalchemy/dialects/mysql/cymysql.py82
-rw-r--r--lib/sqlalchemy/dialects/mysql/dml.py175
-rw-r--r--lib/sqlalchemy/dialects/mysql/enumerated.py263
-rw-r--r--lib/sqlalchemy/dialects/mysql/expression.py130
-rw-r--r--lib/sqlalchemy/dialects/mysql/json.py84
-rw-r--r--lib/sqlalchemy/dialects/mysql/mariadb.py25
-rw-r--r--lib/sqlalchemy/dialects/mysql/mariadbconnector.py240
-rw-r--r--lib/sqlalchemy/dialects/mysql/mysqlconnector.py240
-rw-r--r--lib/sqlalchemy/dialects/mysql/mysqldb.py331
-rw-r--r--lib/sqlalchemy/dialects/mysql/oursql.py273
-rw-r--r--lib/sqlalchemy/dialects/mysql/provision.py78
-rw-r--r--lib/sqlalchemy/dialects/mysql/pymysql.py98
-rw-r--r--lib/sqlalchemy/dialects/mysql/pyodbc.py136
-rw-r--r--lib/sqlalchemy/dialects/mysql/reflection.py558
-rw-r--r--lib/sqlalchemy/dialects/mysql/reserved_words.py564
-rw-r--r--lib/sqlalchemy/dialects/mysql/types.py773
-rw-r--r--lib/sqlalchemy/dialects/oracle/__init__.py58
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py2522
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py1424
-rw-r--r--lib/sqlalchemy/dialects/oracle/provision.py160
-rw-r--r--lib/sqlalchemy/dialects/postgresql/__init__.py117
-rw-r--r--lib/sqlalchemy/dialects/postgresql/array.py413
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py1112
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py4651
-rw-r--r--lib/sqlalchemy/dialects/postgresql/dml.py274
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py277
-rw-r--r--lib/sqlalchemy/dialects/postgresql/hstore.py455
-rw-r--r--lib/sqlalchemy/dialects/postgresql/json.py327
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pg8000.py594
-rw-r--r--lib/sqlalchemy/dialects/postgresql/provision.py124
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py1088
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py60
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pygresql.py278
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pypostgresql.py126
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ranges.py138
-rw-r--r--lib/sqlalchemy/dialects/sqlite/__init__.py58
-rw-r--r--lib/sqlalchemy/dialects/sqlite/aiosqlite.py335
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py2556
-rw-r--r--lib/sqlalchemy/dialects/sqlite/dml.py200
-rw-r--r--lib/sqlalchemy/dialects/sqlite/json.py84
-rw-r--r--lib/sqlalchemy/dialects/sqlite/provision.py142
-rw-r--r--lib/sqlalchemy/dialects/sqlite/pysqlcipher.py164
-rw-r--r--lib/sqlalchemy/dialects/sqlite/pysqlite.py613
-rw-r--r--lib/sqlalchemy/dialects/sybase/__init__.py67
-rw-r--r--lib/sqlalchemy/dialects/sybase/base.py1100
-rw-r--r--lib/sqlalchemy/dialects/sybase/mxodbc.py34
-rw-r--r--lib/sqlalchemy/dialects/sybase/pyodbc.py89
-rw-r--r--lib/sqlalchemy/dialects/sybase/pysybase.py106
-rw-r--r--lib/sqlalchemy/engine/__init__.py62
-rw-r--r--lib/sqlalchemy/engine/base.py3450
-rw-r--r--lib/sqlalchemy/engine/characteristics.py56
-rw-r--r--lib/sqlalchemy/engine/create.py743
-rw-r--r--lib/sqlalchemy/engine/cursor.py1942
-rw-r--r--lib/sqlalchemy/engine/default.py1936
-rw-r--r--lib/sqlalchemy/engine/events.py835
-rw-r--r--lib/sqlalchemy/engine/interfaces.py1719
-rw-r--r--lib/sqlalchemy/engine/mock.py118
-rw-r--r--lib/sqlalchemy/engine/reflection.py1160
-rw-r--r--lib/sqlalchemy/engine/result.py1857
-rw-r--r--lib/sqlalchemy/engine/row.py621
-rw-r--r--lib/sqlalchemy/engine/strategies.py17
-rw-r--r--lib/sqlalchemy/engine/url.py806
-rw-r--r--lib/sqlalchemy/engine/util.py253
-rw-r--r--lib/sqlalchemy/event/__init__.py17
-rw-r--r--lib/sqlalchemy/event/api.py219
-rw-r--r--lib/sqlalchemy/event/attr.py468
-rw-r--r--lib/sqlalchemy/event/base.py345
-rw-r--r--lib/sqlalchemy/event/legacy.py185
-rw-r--r--lib/sqlalchemy/event/registry.py297
-rw-r--r--lib/sqlalchemy/events.py14
-rw-r--r--lib/sqlalchemy/exc.py733
-rw-r--r--lib/sqlalchemy/ext/__init__.py11
-rw-r--r--lib/sqlalchemy/ext/associationproxy.py1627
-rw-r--r--lib/sqlalchemy/ext/asyncio/__init__.py22
-rw-r--r--lib/sqlalchemy/ext/asyncio/base.py89
-rw-r--r--lib/sqlalchemy/ext/asyncio/engine.py828
-rw-r--r--lib/sqlalchemy/ext/asyncio/events.py44
-rw-r--r--lib/sqlalchemy/ext/asyncio/exc.py21
-rw-r--r--lib/sqlalchemy/ext/asyncio/result.py671
-rw-r--r--lib/sqlalchemy/ext/asyncio/scoping.py107
-rw-r--r--lib/sqlalchemy/ext/asyncio/session.py759
-rw-r--r--lib/sqlalchemy/ext/automap.py1234
-rw-r--r--lib/sqlalchemy/ext/baked.py648
-rw-r--r--lib/sqlalchemy/ext/compiler.py613
-rw-r--r--lib/sqlalchemy/ext/declarative/__init__.py64
-rw-r--r--lib/sqlalchemy/ext/declarative/extensions.py463
-rw-r--r--lib/sqlalchemy/ext/horizontal_shard.py256
-rw-r--r--lib/sqlalchemy/ext/hybrid.py1206
-rw-r--r--lib/sqlalchemy/ext/indexable.py352
-rw-r--r--lib/sqlalchemy/ext/instrumentation.py416
-rw-r--r--lib/sqlalchemy/ext/mutable.py958
-rw-r--r--lib/sqlalchemy/ext/mypy/__init__.py0
-rw-r--r--lib/sqlalchemy/ext/mypy/apply.py299
-rw-r--r--lib/sqlalchemy/ext/mypy/decl_class.py516
-rw-r--r--lib/sqlalchemy/ext/mypy/infer.py556
-rw-r--r--lib/sqlalchemy/ext/mypy/names.py253
-rw-r--r--lib/sqlalchemy/ext/mypy/plugin.py284
-rw-r--r--lib/sqlalchemy/ext/mypy/util.py305
-rw-r--r--lib/sqlalchemy/ext/orderinglist.py388
-rw-r--r--lib/sqlalchemy/ext/serializer.py177
-rw-r--r--lib/sqlalchemy/future/__init__.py18
-rw-r--r--lib/sqlalchemy/future/engine.py413
-rw-r--r--lib/sqlalchemy/future/orm/__init__.py10
-rw-r--r--lib/sqlalchemy/inspection.py93
-rw-r--r--lib/sqlalchemy/log.py241
-rw-r--r--lib/sqlalchemy/orm/__init__.py344
-rw-r--r--lib/sqlalchemy/orm/attributes.py2331
-rw-r--r--lib/sqlalchemy/orm/base.py572
-rw-r--r--lib/sqlalchemy/orm/clsregistry.py441
-rw-r--r--lib/sqlalchemy/orm/collections.py1706
-rw-r--r--lib/sqlalchemy/orm/context.py3136
-rw-r--r--lib/sqlalchemy/orm/decl_api.py1062
-rw-r--r--lib/sqlalchemy/orm/decl_base.py1210
-rw-r--r--lib/sqlalchemy/orm/dependency.py1290
-rw-r--r--lib/sqlalchemy/orm/descriptor_props.py745
-rw-r--r--lib/sqlalchemy/orm/dynamic.py491
-rw-r--r--lib/sqlalchemy/orm/evaluator.py241
-rw-r--r--lib/sqlalchemy/orm/events.py2876
-rw-r--r--lib/sqlalchemy/orm/exc.py204
-rw-r--r--lib/sqlalchemy/orm/identity.py254
-rw-r--r--lib/sqlalchemy/orm/instrumentation.py652
-rw-r--r--lib/sqlalchemy/orm/interfaces.py978
-rw-r--r--lib/sqlalchemy/orm/loading.py1465
-rw-r--r--lib/sqlalchemy/orm/mapper.py3658
-rw-r--r--lib/sqlalchemy/orm/path_registry.py519
-rw-r--r--lib/sqlalchemy/orm/persistence.py2517
-rw-r--r--lib/sqlalchemy/orm/properties.py430
-rw-r--r--lib/sqlalchemy/orm/query.py3508
-rw-r--r--lib/sqlalchemy/orm/relationships.py3684
-rw-r--r--lib/sqlalchemy/orm/scoping.py228
-rw-r--r--lib/sqlalchemy/orm/session.py4386
-rw-r--r--lib/sqlalchemy/orm/state.py1025
-rw-r--r--lib/sqlalchemy/orm/strategies.py3141
-rw-r--r--lib/sqlalchemy/orm/strategy_options.py2008
-rw-r--r--lib/sqlalchemy/orm/sync.py167
-rw-r--r--lib/sqlalchemy/orm/unitofwork.py784
-rw-r--r--lib/sqlalchemy/orm/util.py2149
-rw-r--r--lib/sqlalchemy/pool/__init__.py56
-rw-r--r--lib/sqlalchemy/pool/base.py1121
-rw-r--r--lib/sqlalchemy/pool/dbapi_proxy.py147
-rw-r--r--lib/sqlalchemy/pool/events.py284
-rw-r--r--lib/sqlalchemy/pool/impl.py514
-rw-r--r--lib/sqlalchemy/processors.py176
-rw-r--r--lib/sqlalchemy/schema.py59
-rw-r--r--lib/sqlalchemy/sql/__init__.py150
-rw-r--r--lib/sqlalchemy/sql/annotation.py364
-rw-r--r--lib/sqlalchemy/sql/base.py1702
-rw-r--r--lib/sqlalchemy/sql/coercions.py1096
-rw-r--r--lib/sqlalchemy/sql/compiler.py5525
-rw-r--r--lib/sqlalchemy/sql/crud.py1091
-rw-r--r--lib/sqlalchemy/sql/ddl.py1341
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py360
-rw-r--r--lib/sqlalchemy/sql/dml.py1514
-rw-r--r--lib/sqlalchemy/sql/elements.py5415
-rw-r--r--lib/sqlalchemy/sql/events.py331
-rw-r--r--lib/sqlalchemy/sql/expression.py278
-rw-r--r--lib/sqlalchemy/sql/functions.py1575
-rw-r--r--lib/sqlalchemy/sql/lambdas.py1314
-rw-r--r--lib/sqlalchemy/sql/naming.py210
-rw-r--r--lib/sqlalchemy/sql/operators.py1688
-rw-r--r--lib/sqlalchemy/sql/roles.py239
-rw-r--r--lib/sqlalchemy/sql/schema.py5268
-rw-r--r--lib/sqlalchemy/sql/selectable.py6946
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py3351
-rw-r--r--lib/sqlalchemy/sql/traversals.py1559
-rw-r--r--lib/sqlalchemy/sql/type_api.py1974
-rw-r--r--lib/sqlalchemy/sql/util.py1120
-rw-r--r--lib/sqlalchemy/sql/visitors.py852
-rw-r--r--lib/sqlalchemy/testing/__init__.py86
-rw-r--r--lib/sqlalchemy/testing/assertions.py845
-rw-r--r--lib/sqlalchemy/testing/assertsql.py457
-rw-r--r--lib/sqlalchemy/testing/asyncio.py128
-rw-r--r--lib/sqlalchemy/testing/config.py209
-rw-r--r--lib/sqlalchemy/testing/engines.py465
-rw-r--r--lib/sqlalchemy/testing/entities.py111
-rw-r--r--lib/sqlalchemy/testing/exclusions.py465
-rw-r--r--lib/sqlalchemy/testing/fixtures.py870
-rw-r--r--lib/sqlalchemy/testing/mock.py32
-rw-r--r--lib/sqlalchemy/testing/pickleable.py151
-rw-r--r--lib/sqlalchemy/testing/plugin/__init__.py0
-rw-r--r--lib/sqlalchemy/testing/plugin/bootstrap.py54
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py789
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py820
-rw-r--r--lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py112
-rw-r--r--lib/sqlalchemy/testing/profiling.py335
-rw-r--r--lib/sqlalchemy/testing/provision.py416
-rw-r--r--lib/sqlalchemy/testing/requirements.py1518
-rw-r--r--lib/sqlalchemy/testing/schema.py218
-rw-r--r--lib/sqlalchemy/testing/suite/__init__.py13
-rw-r--r--lib/sqlalchemy/testing/suite/test_cte.py204
-rw-r--r--lib/sqlalchemy/testing/suite/test_ddl.py381
-rw-r--r--lib/sqlalchemy/testing/suite/test_deprecations.py145
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py361
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py367
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py1738
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py426
-rw-r--r--lib/sqlalchemy/testing/suite/test_rowcount.py165
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py1783
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py282
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py1508
-rw-r--r--lib/sqlalchemy/testing/suite/test_unicode_ddl.py206
-rw-r--r--lib/sqlalchemy/testing/suite/test_update_delete.py60
-rw-r--r--lib/sqlalchemy/testing/util.py458
-rw-r--r--lib/sqlalchemy/testing/warnings.py82
-rw-r--r--lib/sqlalchemy/types.py119
-rw-r--r--lib/sqlalchemy/util/__init__.py175
-rw-r--r--lib/sqlalchemy/util/_collections.py1089
-rw-r--r--lib/sqlalchemy/util/_compat_py3k.py67
-rw-r--r--lib/sqlalchemy/util/_concurrency_py3k.py194
-rw-r--r--lib/sqlalchemy/util/_preloaded.py68
-rw-r--r--lib/sqlalchemy/util/compat.py632
-rw-r--r--lib/sqlalchemy/util/concurrency.py73
-rw-r--r--lib/sqlalchemy/util/deprecations.py417
-rw-r--r--lib/sqlalchemy/util/langhelpers.py1945
-rw-r--r--lib/sqlalchemy/util/queue.py291
-rw-r--r--lib/sqlalchemy/util/topological.py100
-rw-r--r--lib/sunhpc/__init__.py5
-rw-r--r--lib/sunhpc/commands/__init__.py1631
-rw-r--r--lib/sunhpc/commands/add/__init__.py7
-rw-r--r--lib/sunhpc/commands/add/host/__init__.py124
-rw-r--r--lib/sunhpc/commands/add/host/interface/__init__.py109
-rw-r--r--lib/sunhpc/commands/add/host/security/__init__.py99
-rw-r--r--lib/sunhpc/commands/add/security/__init__.py68
-rw-r--r--lib/sunhpc/commands/build/__init__.py7
-rw-r--r--lib/sunhpc/commands/build/initializes/__init__.py626
-rw-r--r--lib/sunhpc/commands/check/__init__.py6
-rw-r--r--lib/sunhpc/commands/check/services/__init__.py52
-rw-r--r--lib/sunhpc/commands/create/__init__.py7
-rw-r--r--lib/sunhpc/commands/create/distro/__init__.py162
-rw-r--r--lib/sunhpc/commands/create/pxelinux/__init__.py7
-rw-r--r--lib/sunhpc/commands/create/pxelinux/client/__init__.py85
-rw-r--r--lib/sunhpc/commands/create/pxelinux/efiboot/__init__.py113
-rw-r--r--lib/sunhpc/commands/create/pxelinux/kickstart/__init__.py85
-rw-r--r--lib/sunhpc/commands/create/pxelinux/product/__init__.py80
-rw-r--r--lib/sunhpc/commands/create/pxelinux/squashfs/__init__.py113
-rw-r--r--lib/sunhpc/commands/create/pxelinux/updates/__init__.py80
-rw-r--r--lib/sunhpc/commands/create/pxelinux/vminitrd/__init__.py124
-rw-r--r--lib/sunhpc/commands/create/repos/__init__.py105
-rw-r--r--lib/sunhpc/commands/create/roll/__init__.py440
-rw-r--r--lib/sunhpc/commands/create/rpm/__init__.py476
-rw-r--r--lib/sunhpc/commands/create/security/__init__.py221
-rw-r--r--lib/sunhpc/commands/create/security/sshd/__init__.py60
-rw-r--r--lib/sunhpc/commands/create/security/users/__init__.py42
-rw-r--r--lib/sunhpc/commands/create/xml/__init__.py116
-rw-r--r--lib/sunhpc/commands/database/__init__.py6
-rw-r--r--lib/sunhpc/commands/database/init/__init__.py487
-rw-r--r--lib/sunhpc/commands/list/__init__.py7
-rw-r--r--lib/sunhpc/commands/list/help/__init__.py71
-rw-r--r--lib/sunhpc/commands/list/host/__init__.py38
-rw-r--r--lib/sunhpc/commands/list/host/interface/__init__.py41
-rw-r--r--lib/sunhpc/commands/list/license/__init__.py20
-rw-r--r--lib/sunhpc/commands/list/license/license.txt6
-rw-r--r--lib/sunhpc/commands/pxelinux/__init__.py7
-rw-r--r--lib/sunhpc/commands/pxelinux/build/__init__.py228
-rw-r--r--lib/sunhpc/commands/pxelinux/build/autofs/__init__.py124
-rw-r--r--lib/sunhpc/commands/pxelinux/build/dhcpd/__init__.py69
-rw-r--r--lib/sunhpc/commands/pxelinux/build/httpd/__init__.py195
-rw-r--r--lib/sunhpc/commands/pxelinux/build/nodes/__init__.py71
-rw-r--r--lib/sunhpc/commands/pxelinux/build/tftpd/__init__.py173
-rw-r--r--lib/sunhpc/commands/repair/__init__.py7
-rw-r--r--lib/sunhpc/commands/repair/permission/__init__.py54
-rw-r--r--lib/sunhpc/commands/repair/users/__init__.py7
-rw-r--r--lib/sunhpc/commands/repair/users/authorized/__init__.py82
-rw-r--r--lib/sunhpc/commands/report/__init__.py7
-rw-r--r--lib/sunhpc/commands/report/distro/__init__.py19
-rw-r--r--lib/sunhpc/commands/report/host/__init__.py148
-rw-r--r--lib/sunhpc/commands/report/host/attr/__init__.py72
-rw-r--r--lib/sunhpc/commands/report/host/dhcpd/__init__.py171
-rw-r--r--lib/sunhpc/commands/report/kickstart/__init__.py57
-rw-r--r--lib/sunhpc/commands/report/knownhosts/__init__.py102
-rw-r--r--lib/sunhpc/commands/report/nextip/__init__.py73
-rw-r--r--lib/sunhpc/commands/run/__init__.py7
-rw-r--r--lib/sunhpc/commands/run/host/__init__.py280
-rw-r--r--lib/sunhpc/commands/set/__init__.py7
-rw-r--r--lib/sunhpc/commands/set/host/__init__.py8
-rw-r--r--lib/sunhpc/commands/set/host/boot/__init__.py52
-rw-r--r--lib/sunhpc/commands/set/host/boot/plugin_00_ip2hex.py170
-rw-r--r--lib/sunhpc/commands/set/host/cpus/__init__.py45
-rw-r--r--lib/sunhpc/commands/set/host/interface/__init__.py3
-rw-r--r--lib/sunhpc/commands/set/host/interface/iface/__init__.py62
-rw-r--r--lib/sunhpc/commands/set/host/interface/ip/__init__.py87
-rw-r--r--lib/sunhpc/commands/set/host/interface/mac/__init__.py72
-rw-r--r--lib/sunhpc/commands/set/host/interface/name/__init__.py83
-rw-r--r--lib/sunhpc/commands/set/host/interface/subnet/__init__.py80
-rw-r--r--lib/sunhpc/commands/soft/__init__.py7
-rw-r--r--lib/sunhpc/commands/soft/autodock/__init__.py120
-rw-r--r--lib/sunhpc/commands/soft/gaussian/__init__.py152
-rw-r--r--lib/sunhpc/commands/sync/__init__.py7
-rw-r--r--lib/sunhpc/commands/sync/config/__init__.py29
-rw-r--r--lib/sunhpc/commands/sync/config/plugin_00_safe.py18
-rw-r--r--lib/sunhpc/commands/sync/config/plugin_05_sshd.py11
-rw-r--r--lib/sunhpc/commands/sync/users/__init__.py21
-rw-r--r--lib/sunhpc/commands/sync/users/plugin_00_fixmaster.py19
-rw-r--r--lib/sunhpc/commands/sync/users/plugin_05_share.py20
-rw-r--r--lib/sunhpc/commands/sync/users/plugin_10_fixusers.py121
-rw-r--r--lib/sunhpc/core/build.py585
-rw-r--r--lib/sunhpc/core/dist.py391
-rw-r--r--lib/sunhpc/core/files.py445
-rw-r--r--lib/sunhpc/core/firewalld.py6
-rw-r--r--lib/sunhpc/core/ip.py116
-rw-r--r--lib/sunhpc/core/partition.py902
-rw-r--r--lib/sunhpc/core/printer.py222
-rw-r--r--lib/sunhpc/core/security.py200
-rw-r--r--lib/sunhpc/core/sql.py110
-rw-r--r--lib/sunhpc/core/utils.py278
-rw-r--r--lib/sunhpc/db/__init__.py0
-rw-r--r--lib/sunhpc/db/alchemy-bak/database.py210
-rw-r--r--lib/sunhpc/db/alchemy-bak/helper.py369
-rw-r--r--lib/sunhpc/db/alchemy-bak/mappings/__init__.py0
-rw-r--r--lib/sunhpc/db/alchemy-bak/mappings/base.py219
-rw-r--r--lib/sunhpc/db/database.py204
-rw-r--r--lib/sunhpc/db/helper.py369
-rw-r--r--lib/sunhpc/db/mappings/__init__.py0
-rw-r--r--lib/sunhpc/db/mappings/base.py177
-rw-r--r--lib/sunhpc/db/sqlite-bak/database.py165
-rw-r--r--lib/sunhpc/db/sqlite-bak/helper.py367
-rw-r--r--lib/sunhpc/invoke.py10
-rw-r--r--lib/sunhpc/modules/__init__.py0
-rw-r--r--lib/sunhpc/modules/compute/00_selinux.py11
-rw-r--r--lib/sunhpc/modules/compute/__init__.py0
-rw-r--r--lib/sunhpc/modules/control/00-base.py191
-rw-r--r--lib/sunhpc/modules/control/05-securtiy.py98
-rw-r--r--lib/sunhpc/modules/control/50-packages.py94
-rw-r--r--lib/sunhpc/modules/control/__init__.py0
-rw-r--r--lib/sunhpc/modules/kickstart/00-base.py79
-rw-r--r--lib/sunhpc/modules/kickstart/05-partition.py42
-rw-r--r--lib/sunhpc/modules/kickstart/10-hostauth.py62
-rw-r--r--lib/sunhpc/modules/kickstart/12-security.py147
-rw-r--r--lib/sunhpc/modules/kickstart/15-network.py40
-rw-r--r--lib/sunhpc/modules/kickstart/20-scripts.py161
-rw-r--r--lib/sunhpc/modules/kickstart/30-services.py59
-rw-r--r--lib/sunhpc/modules/kickstart/50-packages.py98
-rw-r--r--lib/sunhpc/modules/kickstart/60-addons.py30
-rw-r--r--lib/sunhpc/modules/kickstart/62-anaconda.py32
-rw-r--r--lib/sunhpc/modules/kickstart/64-pxeboot.py39
-rw-r--r--lib/sunhpc/modules/kickstart/__init__.py0
-rwxr-xr-xsbin/build-sunhpc14
-rwxr-xr-xsbin/calcrollmd549
-rwxr-xr-xsbin/gen_root_pw10
-rwxr-xr-xsbin/insert-ethers742
-rwxr-xr-xsbin/kgen157
-rwxr-xr-xsbin/mksquashfsbin0 -> 211928 bytes
-rwxr-xr-xsbin/mom_gencfg559
-rwxr-xr-xsbin/restart-anaconda34
-rwxr-xr-xsbin/suncli38
-rwxr-xr-xsbin/sunyums178
-rwxr-xr-xsbin/unsquashfsbin0 -> 110544 bytes
991 files changed, 368151 insertions, 40 deletions
diff --git a/.gitignore b/.gitignore
index a81c8ee..f08d8af 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,28 +3,13 @@ __pycache__/
*.py[cod]
*$py.class
-# C extensions
-*.so
-
# Distribution / packaging
.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
*.egg-info/
-.installed.cfg
*.egg
-MANIFEST
+
+# Vim
+*.swp
# PyInstaller
# Usually these files are written by a python script from a template
@@ -49,7 +34,6 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
-cover/
# Translations
*.mo
@@ -73,7 +57,6 @@ docs/_build/
# PyBuilder
.pybuilder/
-target/
# Jupyter Notebook
.ipynb_checkpoints
@@ -104,15 +87,6 @@ celerybeat.pid
# SageMath parsed files
*.sage.py
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
# Spyder project settings
.spyderproject
.spyproject
@@ -120,19 +94,8 @@ venv.bak/
# Rope project settings
.ropeproject
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
-
-# Cython debug symbols
-cython_debug/
diff --git a/bin/ci b/bin/ci
new file mode 100755
index 0000000..f349367
--- /dev/null
+++ b/bin/ci
@@ -0,0 +1,24 @@
+#!/bin/sh
+# ci (GNU RCS) 5.9.0
+#
+# Copyright (C) 2013 Thien-Thi Nguyen
+#
+# This file is part of GNU RCS.
+#
+# GNU RCS is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# GNU RCS is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+d=`echo "$0" | sed 's|[^/]*$||'`
+exec "$d"rcs ci "$@"
+
+# ci ends here
diff --git a/bin/co b/bin/co
new file mode 100755
index 0000000..2e7d145
--- /dev/null
+++ b/bin/co
@@ -0,0 +1,24 @@
+#!/bin/sh
+# co (GNU RCS) 5.9.0
+#
+# Copyright (C) 2013 Thien-Thi Nguyen
+#
+# This file is part of GNU RCS.
+#
+# GNU RCS is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# GNU RCS is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+d=`echo "$0" | sed 's|[^/]*$||'`
+exec "$d"rcs co "$@"
+
+# co ends here
diff --git a/bin/dd_extract b/bin/dd_extract
new file mode 100755
index 0000000..ef5bcb5
--- /dev/null
+++ b/bin/dd_extract
Binary files differ
diff --git a/bin/dd_list b/bin/dd_list
new file mode 100755
index 0000000..83dc32a
--- /dev/null
+++ b/bin/dd_list
Binary files differ
diff --git a/bin/eject b/bin/eject
new file mode 100755
index 0000000..a750440
--- /dev/null
+++ b/bin/eject
Binary files differ
diff --git a/bin/ident b/bin/ident
new file mode 100755
index 0000000..85830ac
--- /dev/null
+++ b/bin/ident
Binary files differ
diff --git a/bin/lsscsi b/bin/lsscsi
new file mode 100755
index 0000000..1a0971d
--- /dev/null
+++ b/bin/lsscsi
Binary files differ
diff --git a/bin/merge b/bin/merge
new file mode 100755
index 0000000..cb97c61
--- /dev/null
+++ b/bin/merge
Binary files differ
diff --git a/bin/rcs b/bin/rcs
new file mode 100755
index 0000000..e0abbd0
--- /dev/null
+++ b/bin/rcs
Binary files differ
diff --git a/bin/rcsclean b/bin/rcsclean
new file mode 100755
index 0000000..548491d
--- /dev/null
+++ b/bin/rcsclean
@@ -0,0 +1,24 @@
+#!/bin/sh
+# rcsclean (GNU RCS) 5.9.0
+#
+# Copyright (C) 2013 Thien-Thi Nguyen
+#
+# This file is part of GNU RCS.
+#
+# GNU RCS is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# GNU RCS is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+d=`echo "$0" | sed 's|[^/]*$||'`
+exec "$d"rcs rcsclean "$@"
+
+# rcsclean ends here
diff --git a/bin/rcsdiff b/bin/rcsdiff
new file mode 100755
index 0000000..10e6170
--- /dev/null
+++ b/bin/rcsdiff
@@ -0,0 +1,24 @@
+#!/bin/sh
+# rcsdiff (GNU RCS) 5.9.0
+#
+# Copyright (C) 2013 Thien-Thi Nguyen
+#
+# This file is part of GNU RCS.
+#
+# GNU RCS is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# GNU RCS is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+d=`echo "$0" | sed 's|[^/]*$||'`
+exec "$d"rcs rcsdiff "$@"
+
+# rcsdiff ends here
diff --git a/bin/rcsfreeze b/bin/rcsfreeze
new file mode 100755
index 0000000..0a13a17
--- /dev/null
+++ b/bin/rcsfreeze
@@ -0,0 +1,132 @@
+#! /bin/sh
+# rcsfreeze - assign a symbolic revision number to a configuration of RCS files
+#
+# Copyright (C) 2010-2013 Thien-Thi Nguyen
+#
+# This file is part of GNU RCS.
+#
+# GNU RCS is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# GNU RCS is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+##
+# Usage: rcsfreeze [symbolic-name]
+#
+# The idea is to run rcsfreeze each time a new version is checked
+# in. A unique symbolic revision number (C_[number], where number
+# is increased each time rcsfreeze is run) is then assigned to the most
+# recent revision of each RCS file of the main trunk.
+#
+# If the command is invoked with an argument, then this
+# argument is used as the symbolic name to freeze a configuration.
+# The unique identifier is still generated
+# and is listed in the log file but it will not appear as
+# part of the symbolic revision name in the actual RCS file.
+#
+# A log message is requested from the user which is saved for future
+# references.
+#
+# The shell script works only on all RCS files at one time.
+# It is important that all changed files are checked in (there are
+# no precautions against any error in this respect).
+# file names:
+# {RCS/}.rcsfreeze.ver version number
+# {RCS/}.rscfreeze.log log messages, most recent first
+##
+version='rcsfreeze (GNU RCS) 5.9.0
+Copyright (C) 2010-2013 Thien-Thi Nguyen
+Copyright (C) 1990-1995 Paul Eggert
+License GPLv3+; GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>
+This is free software: you are free to change and redistribute it.
+There is NO WARRANTY, to the extent permitted by law.
+
+Written by Stephan v. Bechtolsheim.'
+
+usage ()
+{
+ sed '/^##/,/^##/!d;/^##/d;s/^# //g;s/^#$//g' $0
+}
+
+if [ x"$1" = x--help ] ; then usage ; exit 0 ; fi
+if [ x"$1" = x--version ] ; then echo "$version" ; exit 0 ; fi
+
+PATH=/usr/local/bin:/bin:/usr/bin:/usr/ucb:$PATH
+export PATH
+
+DATE=`date` || exit
+# Check whether we have an RCS subdirectory, so we can have the right
+# prefix for our paths.
+if test -d RCS
+then RCSDIR=RCS/ EXT=
+else RCSDIR= EXT=,v
+fi
+
+# Version number stuff, log message file
+VERSIONFILE=${RCSDIR}.rcsfreeze.ver
+LOGFILE=${RCSDIR}.rcsfreeze.log
+# Initialize, rcsfreeze never run before in the current directory
+test -r $VERSIONFILE || { echo 0 >$VERSIONFILE && >>$LOGFILE; } || exit
+
+# Get Version number, increase it, write back to file.
+VERSIONNUMBER=`cat $VERSIONFILE` &&
+VERSIONNUMBER=`expr $VERSIONNUMBER + 1` &&
+echo $VERSIONNUMBER >$VERSIONFILE || exit
+
+# Symbolic Revision Number
+SYMREV=C_$VERSIONNUMBER
+# Allow the user to give a meaningful symbolic name to the revision.
+SYMREVNAME=${1-$SYMREV}
+echo >&2 "rcsfreeze: symbolic revision number computed: \"${SYMREV}\"
+rcsfreeze: symbolic revision number used: \"${SYMREVNAME}\"
+rcsfreeze: the two differ only when rcsfreeze invoked with argument
+rcsfreeze: give log message, summarizing changes (end with EOF or single '.')" \
+ || exit
+
+# Stamp the logfile. Because the logfile is ordered most recent
+# first, we have to save everything in a temporary file for now.
+TMPLOG=`mktemp -t` || exit
+trap 'rm -f $TMPLOG; exit 1' 1 2 13 15
+# Now ask for a log message, continuously appending to the log file
+(
+ echo "Version: $SYMREVNAME($SYMREV), Date: $DATE
+-----------" || exit
+ while read MESS
+ do
+ case $MESS in
+ .) break
+ esac
+ echo " $MESS" || exit
+ done
+ echo "-----------
+" &&
+ cat $LOGFILE
+) >$TMPLOG &&
+
+# combine old and new logfiles
+cp $TMPLOG $LOGFILE &&
+rm -f $TMPLOG &&
+
+# Now the real work begins by assigning a symbolic revision number
+# to each rcs file. Take the most recent version on the default branch.
+
+# If there are any .*,v files, throw them in too.
+# But ignore RCS/.* files that do not end in ,v.
+DOTFILES=
+for DOTFILE in ${RCSDIR}.*,v
+do
+ if test -f "$DOTFILE"
+ then
+ DOTFILES="${RCSDIR}.*,v"
+ break
+ fi
+done
+
+exec rcs -q -n$SYMREVNAME: ${RCSDIR}*$EXT $DOTFILES
diff --git a/bin/rcsmerge b/bin/rcsmerge
new file mode 100755
index 0000000..b62f940
--- /dev/null
+++ b/bin/rcsmerge
@@ -0,0 +1,24 @@
+#!/bin/sh
+# rcsmerge (GNU RCS) 5.9.0
+#
+# Copyright (C) 2013 Thien-Thi Nguyen
+#
+# This file is part of GNU RCS.
+#
+# GNU RCS is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# GNU RCS is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+d=`echo "$0" | sed 's|[^/]*$||'`
+exec "$d"rcs rcsmerge "$@"
+
+# rcsmerge ends here
diff --git a/bin/rlog b/bin/rlog
new file mode 100755
index 0000000..99e3e22
--- /dev/null
+++ b/bin/rlog
@@ -0,0 +1,24 @@
+#!/bin/sh
+# rlog (GNU RCS) 5.9.0
+#
+# Copyright (C) 2013 Thien-Thi Nguyen
+#
+# This file is part of GNU RCS.
+#
+# GNU RCS is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# GNU RCS is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty
+# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+d=`echo "$0" | sed 's|[^/]*$||'`
+exec "$d"rcs rlog "$@"
+
+# rlog ends here
diff --git a/bin/sunhpc b/bin/sunhpc
new file mode 100755
index 0000000..82f02f9
--- /dev/null
+++ b/bin/sunhpc
@@ -0,0 +1,153 @@
+#!/opt/sunpy3/bin/python3
+#coding:utf-8
+
+import os
+import sys
+import pwd
+import sunhpc
+import syslog
+import shutil
+import sunhpc.invoke
+
+# Python 3 always uses UTF-8 as its default encoding, so the old
+# Python 2 reload(sys)/sys.setdefaultencoding('utf-8') dance is not needed.
+
+def ttySize():
+ try:
+ (width, height) = shutil.get_terminal_size()
+ except:
+ width = 80
+ return width
+
+os.environ['COLUMNS'] = str(ttySize())
+syslog.openlog('SunhpcCMD', syslog.LOG_PID, syslog.LOG_LOCAL0)
+
+try:
+ import sqlalchemy
+ import sunhpc.db.helper
+
+ database = sunhpc.db.helper.DatabaseHelper()
+ database.connect()
+
+except ImportError:
+ raise
+except sqlalchemy.exc.OperationalError:
+ raise
+
+# If no arguments were given, default to the help command ("sunhpc list help").
+if len(sys.argv) == 1:
+ args = [ 'list', 'help']
+else:
+ args = sys.argv[1:]
+#
+# If sunhpc was invoked without arguments, "sunhpc list help" is used;
+# otherwise the words after "sunhpc" are used, e.g. "sunhpc list",
+# so args[0] always holds the first word after "sunhpc".
+# Check whether the command was passed as a single quoted word:
+# sunhpc aa bb -- cmd=['aa']
+# sunhpc "aa bb" -- cmd=['aa','bb']
+module = None
+cmd = args[0].split()
+
+# This branch is normally not taken.
+# It only runs when the command words were passed as one quoted argument,
+# e.g. sunhpc "aa bb"; plain sunhpc aa bb skips it.
+if len(cmd) > 1:
+ s = 'sunhpc.commands.%s' % '.'.join(cmd)
+ # The module path should then be sunhpc.commands.aa.bb
+ try:
+ # If the module does not exist, ModuleNotFoundError is raised.
+ __import__(s)
+
+ # If the module exists but fails here, it is usually missing an import
+ # at the top of the file, and NameError is raised.
+ module = eval(s)
+
+ # Remember the offset so the code below can slice args after the command word.
+ i = 1
+ except:
+ module = None
+
+# Regular command lookup starts here.
+if not module:
+ # Try progressively shorter prefixes of the sunhpc arguments.
+ for i in range(len(args), 0, -1):
+ s = 'sunhpc.commands.%s' % '.'.join(args[:i])
+ # sunhpc.commands.aa.bb
+ # sunhpc.commands.aa
+ try:
+ # Check, longest prefix first, whether the module exists.
+ __import__(s)
+ module = eval(s)
+ if module:
+ # Module found; the remaining arguments are handed to it as parameters.
+ break
+ except ImportError:
+ continue
+
+# No matching module was found, so the command fails.
+if not module:
+ print ('error - invalid sunhpc command "%s"' % args[0])
+ sys.exit(-1)
+
+# Convert the module path back into command-line form:
+# sunhpc.commands.aa.bb
+# becomes "sunhpc commands aa bb", then the leading "sunhpc commands" is dropped.
+# name = aa bb
+name = ' '.join(s.split('.')[2:])
+
+# At this point module refers to the resolved command module,
+# e.g. module = sunhpc.commands.aa.bb
+
+try:
+ # Check that the module defines a Command class; if not, AttributeError is raised.
+ command = getattr(module, 'Command')(database)
+except AttributeError:
+ # No Command class, so we end up here and print the help instead.
+ #
+ # Example: for the command "sunhpc aa bb", module bb has no Command
+ # class; the module path is sunhpc.commands.aa.bb and the name
+ # variable is "aa bb".
+ #
+ # Import the sunhpc help module.
+ import sunhpc.commands.list.help
+ help = sunhpc.commands.list.help.Command(database)
+ # Split the module path and slice it to obtain the sub-path under sunhpc/commands.
+ fullmodpath = s.split('.')
+ submodpath = '/'.join(fullmodpath[2:])
+
+ # Run the help Command's run() function: the first argument is a dict, the second a list.
+ help.run({'subdir':submodpath}, [])
+ print (help.getText())
+ sys.exit(-1)
+
+# The module has a Command class, so call runWrapper, defined in
+# sunhpc.commands.__init__. It serves two purposes:
+# 1. collect the command arguments and normalize their format;
+# 2. invoke the command's entry point.
+# If the command raises its own exception, usage help is printed.
+if command.MustBeRoot and not (command.isRootUser() or command.isApacheUser()):
+ # The command needs elevated privileges, so re-run it through sudo.
+ os.system('sudo %s' % ' '.join(sys.argv))
+else:
+ # Collect the arguments and run the module's run() function; on failure
+ # CommandError is raised. runWrapper is the entry point of the Command
+ # class in the top-level module.
+ # name : command path of the resolved module.
+ # args : everything after the resolved module words is passed as arguments.
+ try:
+ command.runWrapper(name, args[i:])
+ text = command.getText()
+ if len(text) > 0:
+ print (text.rstrip())
+ if text[len(text)-1] != '\n':
+ print ()
+ except sunhpc.core.utils.CommandError as e:
+ msg = ' '.join(str(e).split())
+ #print ("\033[91m[errs] %s\033[0m" % msg)
+ print (msg)
+ print (command.usage())
+ exit(1)
+
+syslog.closelog()
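The lookup above resolves a command line such as "sunhpc aa bb" to the longest importable module under sunhpc.commands. A minimal standalone sketch of the same idea, assuming a package layout like the one used by this script (module names are illustrative; the real code path uses __import__ and eval as shown in the hunk above):

    import importlib

    def resolve_command(args, package='sunhpc.commands'):
        # Try the longest dotted prefix first, e.g. for ['aa', 'bb', 'x']
        # try sunhpc.commands.aa.bb, then sunhpc.commands.aa.
        for i in range(len(args), 0, -1):
            name = '%s.%s' % (package, '.'.join(args[:i]))
            try:
                return importlib.import_module(name), args[i:]
            except ImportError:
                continue
        return None, args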
diff --git a/bin/syncDaemon b/bin/syncDaemon
new file mode 100755
index 0000000..30aa123
--- /dev/null
+++ b/bin/syncDaemon
@@ -0,0 +1,167 @@
+#!/opt/sunpy3/bin/python3
+#coding:utf-8
+import os, sys, time, queue
+import pickle, base64, shutil, tempfile
+import logging, random, argparse, configparser
+from multiprocessing.managers import BaseManager
+class QueueManager(BaseManager):
+ pass
+s_queue = queue.Queue()
+r_queue = queue.Queue()
+
+class Application(object):
+
+ def __init__(self):
+
+ self.conf = '/etc/sunhpc-plugin.conf'
+ self.addr = None
+ self.port = None
+ self.keys = None
+ self.conn = None
+ self.send = None
+ self.recv = None
+
+ self.modules = {}
+ self.logfile = '/opt/sunhpc/logs/syncdaemon.log'
+
+ def config(self):
+ config = configparser.ConfigParser()
+ if os.path.exists(self.conf):
+ try:
+ config.read(self.conf)
+ self.addr = config['plugins']['address']
+ self.port = int(config['plugins']['port'])
+ self.keys = config['plugins']['keys'].encode('utf-8')
+ except KeyError as e:
+ self.outputlog('Read "/etc/sunhpc-plugin.conf" error - %s' % repr(e))
+
+ def parseArgs(self):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--addr', metavar='addr', nargs='?', help='IP address')
+ parser.add_argument('--port', metavar='port', nargs='?', help='IP port')
+ parser.add_argument('--keys', metavar='keys', nargs='?', help='Secret key')
+ parser.add_argument('--conf', metavar='conf', nargs='?', help='Config file')
+
+ args = parser.parse_args()
+ if args.addr:
+ self.addr = args.addr
+ if args.port:
+ self.port = int(args.port)
+ if args.keys:
+ self.keys = args.keys.encode('utf-8')
+ if args.conf:
+ self.conf = args.conf
+
+ def register(self):
+ QueueManager.register('send_queue', callable=lambda: s_queue)
+ QueueManager.register('recv_queue', callable=lambda: r_queue)
+ manager = QueueManager(address=(self.addr, self.port), authkey=self.keys)
+ return manager
+
+ def connect(self):
+ self.conn = self.register()
+ self.conn.start()
+ self.send = self.conn.send_queue()
+ self.recv = self.conn.recv_queue()
+
+ def load_plugins(self):
+ plugin_path = '/opt/sunhpc/var/plugins/syncdata'
+ sys.path.append(plugin_path)
+ tmpdirs = tempfile.mkdtemp()
+ self.modules['temp'] = tmpdirs
+ self.modules['path'] = plugin_path
+ self.modules['plugins'] = []
+
+ for plugin_file in os.listdir(plugin_path):
+ plugin_dict = {}
+ if not plugin_file.endswith('.py'):
+ continue
+
+ if plugin_file in ['plugins.py']:
+ fn = os.path.join(plugin_path, plugin_file)
+ with open(fn, 'rb') as f:
+ self.modules['init'] = base64.b64encode(f.read())
+ continue
+
+ p = plugin_file.split('.py')[0]
+ plugin = __import__(p).plugin()
+ plugin_dict['file'] = plugin_file
+ plugin_dict['modname'] = p
+ # Use the names of functions starting with "src" as keys of the plugin
+ # dict; each src function runs on the server and its result is the value.
+ for fname in plugin.get_srcfuncname(plugin):
+ plugin_n = getattr(plugin, fname)
+ plugin_dict[fname] = plugin_n()
+
+ # Only record functions starting with "set" in the plugin dict;
+ # set functions are executed on the client side.
+ for setname in plugin.get_setfuncname(plugin):
+ plugin_dict[setname] = None
+
+ filename = os.path.join(plugin_path, plugin_file)
+ with open(filename, 'rb') as f:
+ content = base64.b64encode(f.read())
+ plugin_dict['source'] = content
+
+ self.modules['plugins'].append(plugin_dict)
+
+ shutil.rmtree(tmpdirs)
+
+ def running(self):
+
+ print ('addr -- %s' % self.addr)
+ print ('port -- %s' % self.port)
+ self.outputlog('start the syncdaemon...')
+ self.connect()
+ if not self.modules:
+ self.load_plugins()
+
+ data = pickle.dumps(self.modules)
+ running = 1
+ while running:
+ time.sleep(1)
+ try:
+ self.send.put(data)
+ result = self.recv.get(timeout=3)
+ if result == 'exit':
+ running = 0
+
+ self.outputlog(result)
+ except queue.Empty:
+ pass
+
+ self.conn.shutdown()
+ self.outputlog('end the syncdaemon...')
+
+ def outputlog(self, s):
+ log_format = "%(asctime)s %(name)s %(levelname)s %(message)s"
+ dateformat = '%Y-%m-%d %H:%M:%S %a'
+ logging.basicConfig(
+ level = logging.DEBUG,
+ format = log_format,
+ datefmt = dateformat,
+ filename = self.logfile
+ )
+ if s:
+ logging.info(s)
+
+if __name__ == "__main__":
+
+ #
+ # The createKeys function creates a 16-character secret key.
+ # To generate a new key, delete the file
+ # /opt/sunhpc/data/.daemon_keys;
+ # a fresh key will then be created.
+ #
+ # syncDaemon --addr 127.0.0.1 --port 5000
+ #
+ app = Application()
+ app.config()
+ app.parseArgs()
+ app.running()
+
+
+
+
+
+
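syncDaemon publishes two queues through a multiprocessing BaseManager and keeps pushing a pickled bundle of plugins until a client replies with 'exit'. A sketch of the matching client side, assuming the address, port and key normally read from /etc/sunhpc-plugin.conf (the literal values below are placeholders):

    import pickle
    from multiprocessing.managers import BaseManager

    class QueueManager(BaseManager):
        pass

    # On the client side the queues are only looked up, not created.
    QueueManager.register('send_queue')
    QueueManager.register('recv_queue')

    def fetch_plugins(addr='127.0.0.1', port=5000, key=b'placeholder'):
        manager = QueueManager(address=(addr, port), authkey=key)
        manager.connect()                       # attach to the running daemon
        bundle = pickle.loads(manager.send_queue().get(timeout=10))
        manager.recv_queue().put('exit')        # ask the daemon to stop
        return bundle                           # dict with 'plugins', 'init', 'path', 'temp'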
diff --git a/data/.database b/data/.database
new file mode 100644
index 0000000..903daea
--- /dev/null
+++ b/data/.database
@@ -0,0 +1 @@
+2023-04-16 08:28:49
diff --git a/data/sunhpc.db b/data/sunhpc.db
new file mode 100644
index 0000000..4206300
--- /dev/null
+++ b/data/sunhpc.db
Binary files differ
diff --git a/etc/env.sunhpc b/etc/env.sunhpc
new file mode 100644
index 0000000..6fc6588
--- /dev/null
+++ b/etc/env.sunhpc
@@ -0,0 +1,35 @@
+#!/bin/bash
+#
+#SunHPC Env Configure
+#
+
+command -v vim > /dev/null
+if [ $? -eq 0 ]
+then
+ alias vi=vim
+ VICONF="$HOME/.vimrc"
+ [ ! -f "$VICONF" ] && echo -e "set ts=4\nset expandtab" > "$VICONF"
+fi
+
+export SUNHPC_HOME=/opt/sunhpc
+export SUNHPC_PYTHON=/opt/sunpy3
+
+if [ `id -g` -eq 0 ]; then
+ export PATH=$SUNHPC_HOME/bin:$SUNHPC_HOME/sbin:$SUNHPC_PYTHON/bin:$SUNHPC_PYTHON/sbin:$PATH
+else
+ export PATH=$SUNHPC_HOME/bin:$PATH
+fi
+
+# python .pth config
+SUNHPC_PTH=/opt/sunpy3/lib/python3.10/site-packages/sunhpc.pth
+[ ! -e $SUNHPC_PTH ] && echo "$SUNHPC_HOME/lib" > $SUNHPC_PTH
+
+# pip
+[ ! -d ~/.pip ] && mkdir -p ~/.pip
+
+cat > ~/.pip/pip.conf << 'EOF'
+[global]
+index-url = https://pypi.tuna.tsinghua.edu.cn/simple
+[install]
+trusted-host = pypi.tuna.tsinghua.edu.cn
+EOF
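The sunhpc.pth file relies on Python's standard site mechanism: each line of a *.pth file found in site-packages is appended to sys.path at interpreter start-up. A quick check that the path written above was picked up (assuming the layout used by this script):

    import sys
    assert '/opt/sunhpc/lib' in sys.path   # added via sunhpc.pth by the site module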
diff --git a/etc/safeputrc b/etc/safeputrc
new file mode 100644
index 0000000..bdf6869
--- /dev/null
+++ b/etc/safeputrc
@@ -0,0 +1,4 @@
+<?xml version="1.0" standalone="yes"?>
+<safeput>
+ <PrivateNetwork id="172.16.1.254" mask="24"/>
+</safeput>
diff --git a/lib/Crypto/Cipher/AES.py b/lib/Crypto/Cipher/AES.py
new file mode 100644
index 0000000..13bd7ea
--- /dev/null
+++ b/lib/Crypto/Cipher/AES.py
@@ -0,0 +1,250 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/AES.py : AES
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+"""
+Module's constants for the modes of operation supported with AES:
+
+:var MODE_ECB: :ref:`Electronic Code Book (ECB) <ecb_mode>`
+:var MODE_CBC: :ref:`Cipher-Block Chaining (CBC) <cbc_mode>`
+:var MODE_CFB: :ref:`Cipher FeedBack (CFB) <cfb_mode>`
+:var MODE_OFB: :ref:`Output FeedBack (OFB) <ofb_mode>`
+:var MODE_CTR: :ref:`CounTer Mode (CTR) <ctr_mode>`
+:var MODE_OPENPGP: :ref:`OpenPGP Mode <openpgp_mode>`
+:var MODE_CCM: :ref:`Counter with CBC-MAC (CCM) Mode <ccm_mode>`
+:var MODE_EAX: :ref:`EAX Mode <eax_mode>`
+:var MODE_GCM: :ref:`Galois Counter Mode (GCM) <gcm_mode>`
+:var MODE_SIV: :ref:`Synthetic Initialization Vector (SIV) <siv_mode>`
+:var MODE_OCB: :ref:`Offset Code Book (OCB) <ocb_mode>`
+"""
+
+import sys
+
+from Crypto.Cipher import _create_cipher
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ c_size_t, c_uint8_ptr)
+
+from Crypto.Util import _cpu_features
+from Crypto.Random import get_random_bytes
+
+
+_cproto = """
+ int AES_start_operation(const uint8_t key[],
+ size_t key_len,
+ void **pResult);
+ int AES_encrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int AES_decrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int AES_stop_operation(void *state);
+ """
+
+
+# Load portable AES
+_raw_aes_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_aes",
+ _cproto)
+
+# Try to load AES with AES NI instructions
+try:
+ _raw_aesni_lib = None
+ if _cpu_features.have_aes_ni():
+ _raw_aesni_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_aesni",
+ _cproto.replace("AES",
+ "AESNI"))
+# _raw_aesni may not have been compiled in
+except OSError:
+ pass
+
+
+def _create_base_cipher(dict_parameters):
+ """This method instantiates and returns a handle to a low-level
+ base cipher. It will absorb named parameters in the process."""
+
+ use_aesni = dict_parameters.pop("use_aesni", True)
+
+ try:
+ key = dict_parameters.pop("key")
+ except KeyError:
+ raise TypeError("Missing 'key' parameter")
+
+ if len(key) not in key_size:
+ raise ValueError("Incorrect AES key length (%d bytes)" % len(key))
+
+ if use_aesni and _raw_aesni_lib:
+ start_operation = _raw_aesni_lib.AESNI_start_operation
+ stop_operation = _raw_aesni_lib.AESNI_stop_operation
+ else:
+ start_operation = _raw_aes_lib.AES_start_operation
+ stop_operation = _raw_aes_lib.AES_stop_operation
+
+ cipher = VoidPointer()
+ result = start_operation(c_uint8_ptr(key),
+ c_size_t(len(key)),
+ cipher.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the AES cipher"
+ % result)
+ return SmartPointer(cipher.get(), stop_operation)
+
+
+def _derive_Poly1305_key_pair(key, nonce):
+ """Derive a tuple (r, s, nonce) for a Poly1305 MAC.
+
+ If nonce is ``None``, a new 16-byte nonce is generated.
+ """
+
+ if len(key) != 32:
+ raise ValueError("Poly1305 with AES requires a 32-byte key")
+
+ if nonce is None:
+ nonce = get_random_bytes(16)
+ elif len(nonce) != 16:
+ raise ValueError("Poly1305 with AES requires a 16-byte nonce")
+
+ s = new(key[:16], MODE_ECB).encrypt(nonce)
+ return key[16:], s, nonce
+
+
+def new(key, mode, *args, **kwargs):
+ """Create a new AES cipher.
+
+ :param key:
+ The secret key to use in the symmetric cipher.
+
+ It must be 16, 24 or 32 bytes long (respectively for *AES-128*,
+ *AES-192* or *AES-256*).
+
+ For ``MODE_SIV`` only, it doubles to 32, 48, or 64 bytes.
+ :type key: bytes/bytearray/memoryview
+
+ :param mode:
+ The chaining mode to use for encryption or decryption.
+ If in doubt, use ``MODE_EAX``.
+ :type mode: One of the supported ``MODE_*`` constants
+
+ :Keyword Arguments:
+ * **iv** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_CBC``, ``MODE_CFB``, ``MODE_OFB``,
+ and ``MODE_OPENPGP`` modes).
+
+ The initialization vector to use for encryption or decryption.
+
+ For ``MODE_CBC``, ``MODE_CFB``, and ``MODE_OFB`` it must be 16 bytes long.
+
+ For ``MODE_OPENPGP`` mode only,
+ it must be 16 bytes long for encryption
+ and 18 bytes for decryption (in the latter case, it is
+ actually the *encrypted* IV which was prefixed to the ciphertext).
+
+ If not provided, a random byte string is generated (you must then
+ read its value with the :attr:`iv` attribute).
+
+ * **nonce** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_CCM``, ``MODE_EAX``, ``MODE_GCM``,
+ ``MODE_SIV``, ``MODE_OCB``, and ``MODE_CTR``).
+
+ A value that must never be reused for any other encryption done
+ with this key (except possibly for ``MODE_SIV``, see below).
+
+ For ``MODE_EAX``, ``MODE_GCM`` and ``MODE_SIV`` there are no
+ restrictions on its length (recommended: **16** bytes).
+
+ For ``MODE_CCM``, its length must be in the range **[7..13]**.
+ Bear in mind that with CCM there is a trade-off between nonce
+ length and maximum message size. Recommendation: **11** bytes.
+
+ For ``MODE_OCB``, its length must be in the range **[1..15]**
+ (recommended: **15**).
+
+ For ``MODE_CTR``, its length must be in the range **[0..15]**
+ (recommended: **8**).
+
+ For ``MODE_SIV`` the nonce is optional: if it is not specified,
+ no nonce is used, which makes the encryption
+ deterministic.
+
+ If not provided, for modes other than ``MODE_SIV``, a random
+ byte string of the recommended length is used (you must then
+ read its value with the :attr:`nonce` attribute).
+
+ * **segment_size** (*integer*) --
+ (Only ``MODE_CFB``). The number of **bits** the plaintext and ciphertext
+ are segmented in. It must be a multiple of 8.
+ If not specified, it will be assumed to be 8.
+
+ * **mac_len** : (*integer*) --
+ (Only ``MODE_EAX``, ``MODE_GCM``, ``MODE_OCB``, ``MODE_CCM``)
+ Length of the authentication tag, in bytes.
+
+ It must be even and in the range **[4..16]**.
+ The recommended value (and the default, if not specified) is **16**.
+
+ * **msg_len** : (*integer*) --
+ (Only ``MODE_CCM``). Length of the message to (de)cipher.
+ If not specified, ``encrypt`` must be called with the entire message.
+ Similarly, ``decrypt`` can only be called once.
+
+ * **assoc_len** : (*integer*) --
+ (Only ``MODE_CCM``). Length of the associated data.
+ If not specified, all associated data is buffered internally,
+ which may represent a problem for very large messages.
+
+ * **initial_value** : (*integer* or *bytes/bytearray/memoryview*) --
+ (Only ``MODE_CTR``).
+ The initial value for the counter. If not present, the cipher will
+ start counting from 0. The value is incremented by one for each block.
+ The counter number is encoded in big endian mode.
+
+ * **counter** : (*object*) --
+ Instance of ``Crypto.Util.Counter``, which allows full customization
+ of the counter block. This parameter is incompatible with both ``nonce``
+ and ``initial_value``.
+
+ * **use_aesni** : (*boolean*) --
+ Use Intel AES-NI hardware extensions (default: use if available).
+
+ :Return: an AES object, of the applicable mode.
+ """
+
+ kwargs["add_aes_modes"] = True
+ return _create_cipher(sys.modules[__name__], key, mode, *args, **kwargs)
+
+
+MODE_ECB = 1
+MODE_CBC = 2
+MODE_CFB = 3
+MODE_OFB = 5
+MODE_CTR = 6
+MODE_OPENPGP = 7
+MODE_CCM = 8
+MODE_EAX = 9
+MODE_SIV = 10
+MODE_GCM = 11
+MODE_OCB = 12
+
+# Size of a data block (in bytes)
+block_size = 16
+# Size of a key (in bytes)
+key_size = (16, 24, 32)
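A short round trip with the module above, using the EAX mode that the new() docstring recommends as a safe default (key and message are illustrative):

    from Crypto.Cipher import AES
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(16)                    # AES-128
    cipher = AES.new(key, AES.MODE_EAX)           # nonce is generated automatically
    ciphertext, tag = cipher.encrypt_and_digest(b'attack at dawn')

    decipher = AES.new(key, AES.MODE_EAX, nonce=cipher.nonce)
    assert decipher.decrypt_and_verify(ciphertext, tag) == b'attack at dawn'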
diff --git a/lib/Crypto/Cipher/AES.pyi b/lib/Crypto/Cipher/AES.pyi
new file mode 100644
index 0000000..8f655cf
--- /dev/null
+++ b/lib/Crypto/Cipher/AES.pyi
@@ -0,0 +1,47 @@
+from typing import Union, Tuple, Optional, Dict
+
+from Crypto.Cipher._mode_ecb import EcbMode
+from Crypto.Cipher._mode_cbc import CbcMode
+from Crypto.Cipher._mode_cfb import CfbMode
+from Crypto.Cipher._mode_ofb import OfbMode
+from Crypto.Cipher._mode_ctr import CtrMode
+from Crypto.Cipher._mode_openpgp import OpenPgpMode
+from Crypto.Cipher._mode_ccm import CcmMode
+from Crypto.Cipher._mode_eax import EaxMode
+from Crypto.Cipher._mode_gcm import GcmMode
+from Crypto.Cipher._mode_siv import SivMode
+from Crypto.Cipher._mode_ocb import OcbMode
+
+AESMode = int
+
+MODE_ECB: AESMode
+MODE_CBC: AESMode
+MODE_CFB: AESMode
+MODE_OFB: AESMode
+MODE_CTR: AESMode
+MODE_OPENPGP: AESMode
+MODE_CCM: AESMode
+MODE_EAX: AESMode
+MODE_GCM: AESMode
+MODE_SIV: AESMode
+MODE_OCB: AESMode
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ mode: AESMode,
+ iv : Buffer = ...,
+ IV : Buffer = ...,
+ nonce : Buffer = ...,
+ segment_size : int = ...,
+ mac_len : int = ...,
+ assoc_len : int = ...,
+ initial_value : Union[int, Buffer] = ...,
+ counter : Dict = ...,
+ use_aesni : bool = ...) -> \
+ Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode,
+ OpenPgpMode, CcmMode, EaxMode, GcmMode,
+ SivMode, OcbMode]: ...
+
+block_size: int
+key_size: Tuple[int, int, int]
diff --git a/lib/Crypto/Cipher/ARC2.py b/lib/Crypto/Cipher/ARC2.py
new file mode 100644
index 0000000..0ba7e33
--- /dev/null
+++ b/lib/Crypto/Cipher/ARC2.py
@@ -0,0 +1,175 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/ARC2.py : ARC2.py
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+"""
+Module's constants for the modes of operation supported with ARC2:
+
+:var MODE_ECB: :ref:`Electronic Code Book (ECB) <ecb_mode>`
+:var MODE_CBC: :ref:`Cipher-Block Chaining (CBC) <cbc_mode>`
+:var MODE_CFB: :ref:`Cipher FeedBack (CFB) <cfb_mode>`
+:var MODE_OFB: :ref:`Output FeedBack (OFB) <ofb_mode>`
+:var MODE_CTR: :ref:`CounTer Mode (CTR) <ctr_mode>`
+:var MODE_OPENPGP: :ref:`OpenPGP Mode <openpgp_mode>`
+:var MODE_EAX: :ref:`EAX Mode <eax_mode>`
+"""
+
+import sys
+
+from Crypto.Cipher import _create_cipher
+from Crypto.Util.py3compat import byte_string
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ c_size_t, c_uint8_ptr)
+
+_raw_arc2_lib = load_pycryptodome_raw_lib(
+ "Crypto.Cipher._raw_arc2",
+ """
+ int ARC2_start_operation(const uint8_t key[],
+ size_t key_len,
+ size_t effective_key_len,
+ void **pResult);
+ int ARC2_encrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int ARC2_decrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int ARC2_stop_operation(void *state);
+ """
+ )
+
+
+def _create_base_cipher(dict_parameters):
+ """This method instantiates and returns a handle to a low-level
+ base cipher. It will absorb named parameters in the process."""
+
+ try:
+ key = dict_parameters.pop("key")
+ except KeyError:
+ raise TypeError("Missing 'key' parameter")
+
+ effective_keylen = dict_parameters.pop("effective_keylen", 1024)
+
+ if len(key) not in key_size:
+ raise ValueError("Incorrect ARC2 key length (%d bytes)" % len(key))
+
+ if not (40 <= effective_keylen <= 1024):
+ raise ValueError("'effective_key_len' must be at least 40 and no larger than 1024 "
+ "(not %d)" % effective_keylen)
+
+ start_operation = _raw_arc2_lib.ARC2_start_operation
+ stop_operation = _raw_arc2_lib.ARC2_stop_operation
+
+ cipher = VoidPointer()
+ result = start_operation(c_uint8_ptr(key),
+ c_size_t(len(key)),
+ c_size_t(effective_keylen),
+ cipher.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the ARC2 cipher"
+ % result)
+
+ return SmartPointer(cipher.get(), stop_operation)
+
+
+def new(key, mode, *args, **kwargs):
+ """Create a new RC2 cipher.
+
+ :param key:
+ The secret key to use in the symmetric cipher.
+ Its length can vary from 5 to 128 bytes; the actual search space
+ (and the cipher strength) can be reduced with the ``effective_keylen`` parameter.
+ :type key: bytes, bytearray, memoryview
+
+ :param mode:
+ The chaining mode to use for encryption or decryption.
+ :type mode: One of the supported ``MODE_*`` constants
+
+ :Keyword Arguments:
+ * **iv** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_CBC``, ``MODE_CFB``, ``MODE_OFB``,
+ and ``MODE_OPENPGP`` modes).
+
+ The initialization vector to use for encryption or decryption.
+
+ For ``MODE_CBC``, ``MODE_CFB``, and ``MODE_OFB`` it must be 8 bytes long.
+
+ For ``MODE_OPENPGP`` mode only,
+ it must be 8 bytes long for encryption
+ and 10 bytes for decryption (in the latter case, it is
+ actually the *encrypted* IV which was prefixed to the ciphertext).
+
+ If not provided, a random byte string is generated (you must then
+ read its value with the :attr:`iv` attribute).
+
+ * **nonce** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_EAX`` and ``MODE_CTR``).
+
+ A value that must never be reused for any other encryption done
+ with this key.
+
+ For ``MODE_EAX`` there are no
+ restrictions on its length (recommended: **16** bytes).
+
+ For ``MODE_CTR``, its length must be in the range **[0..7]**.
+
+ If not provided for ``MODE_EAX``, a random byte string is generated (you
+ can read it back via the ``nonce`` attribute).
+
+ * **effective_keylen** (*integer*) --
+ Optional. Maximum strength in bits of the actual key used by the ARC2 algorithm.
+ If the supplied ``key`` parameter is longer (in bits) than the value specified
+ here, it will be weakened to match it.
+ If not specified, no limitation is applied.
+
+ * **segment_size** (*integer*) --
+ (Only ``MODE_CFB``). The number of **bits** the plaintext and ciphertext
+ are segmented in. It must be a multiple of 8.
+ If not specified, it will be assumed to be 8.
+
+ * **mac_len** : (*integer*) --
+ (Only ``MODE_EAX``)
+ Length of the authentication tag, in bytes.
+ It must be no longer than 8 (default).
+
+ * **initial_value** : (*integer*) --
+ (Only ``MODE_CTR``). The initial value for the counter within
+ the counter block. By default it is **0**.
+
+ :Return: an ARC2 object, of the applicable mode.
+ """
+
+ return _create_cipher(sys.modules[__name__], key, mode, *args, **kwargs)
+
+MODE_ECB = 1
+MODE_CBC = 2
+MODE_CFB = 3
+MODE_OFB = 5
+MODE_CTR = 6
+MODE_OPENPGP = 7
+MODE_EAX = 9
+
+# Size of a data block (in bytes)
+block_size = 8
+# Size of a key (in bytes)
+key_size = range(5, 128 + 1)
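ARC2 (RC2) is a legacy 64-bit block cipher, so block-mode use needs padding to 8-byte multiples. A minimal CBC sketch (values are illustrative):

    from Crypto.Cipher import ARC2
    from Crypto.Random import get_random_bytes
    from Crypto.Util.Padding import pad, unpad

    key = get_random_bytes(16)
    cipher = ARC2.new(key, ARC2.MODE_CBC)         # an 8-byte IV is generated for us
    ct = cipher.encrypt(pad(b'legacy data', ARC2.block_size))

    decipher = ARC2.new(key, ARC2.MODE_CBC, iv=cipher.iv)
    assert unpad(decipher.decrypt(ct), ARC2.block_size) == b'legacy data'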
diff --git a/lib/Crypto/Cipher/ARC2.pyi b/lib/Crypto/Cipher/ARC2.pyi
new file mode 100644
index 0000000..055c424
--- /dev/null
+++ b/lib/Crypto/Cipher/ARC2.pyi
@@ -0,0 +1,35 @@
+from typing import Union, Dict, Iterable
+
+from Crypto.Cipher._mode_ecb import EcbMode
+from Crypto.Cipher._mode_cbc import CbcMode
+from Crypto.Cipher._mode_cfb import CfbMode
+from Crypto.Cipher._mode_ofb import OfbMode
+from Crypto.Cipher._mode_ctr import CtrMode
+from Crypto.Cipher._mode_openpgp import OpenPgpMode
+from Crypto.Cipher._mode_eax import EaxMode
+
+ARC2Mode = int
+
+MODE_ECB: ARC2Mode
+MODE_CBC: ARC2Mode
+MODE_CFB: ARC2Mode
+MODE_OFB: ARC2Mode
+MODE_CTR: ARC2Mode
+MODE_OPENPGP: ARC2Mode
+MODE_EAX: ARC2Mode
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ mode: ARC2Mode,
+ iv : Buffer = ...,
+ IV : Buffer = ...,
+ nonce : Buffer = ...,
+ segment_size : int = ...,
+ mac_len : int = ...,
+ initial_value : Union[int, Buffer] = ...,
+ counter : Dict = ...) -> \
+ Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode]: ...
+
+block_size: int
+key_size: Iterable[int]
diff --git a/lib/Crypto/Cipher/ARC4.py b/lib/Crypto/Cipher/ARC4.py
new file mode 100644
index 0000000..7150ea6
--- /dev/null
+++ b/lib/Crypto/Cipher/ARC4.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/ARC4.py : ARC4
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import b
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ create_string_buffer, get_raw_buffer,
+ SmartPointer, c_size_t, c_uint8_ptr)
+
+
+_raw_arc4_lib = load_pycryptodome_raw_lib("Crypto.Cipher._ARC4", """
+ int ARC4_stream_encrypt(void *rc4State, const uint8_t in[],
+ uint8_t out[], size_t len);
+ int ARC4_stream_init(uint8_t *key, size_t keylen,
+ void **pRc4State);
+ int ARC4_stream_destroy(void *rc4State);
+ """)
+
+
+class ARC4Cipher:
+ """ARC4 cipher object. Do not create it directly. Use
+ :func:`Crypto.Cipher.ARC4.new` instead.
+ """
+
+ def __init__(self, key, *args, **kwargs):
+ """Initialize an ARC4 cipher object
+
+ See also `new()` at the module level."""
+
+ if len(args) > 0:
+ ndrop = args[0]
+ args = args[1:]
+ else:
+ ndrop = kwargs.pop('drop', 0)
+
+ if len(key) not in key_size:
+ raise ValueError("Incorrect ARC4 key length (%d bytes)" %
+ len(key))
+
+ self._state = VoidPointer()
+ result = _raw_arc4_lib.ARC4_stream_init(c_uint8_ptr(key),
+ c_size_t(len(key)),
+ self._state.address_of())
+ if result != 0:
+ raise ValueError("Error %d while creating the ARC4 cipher"
+ % result)
+ self._state = SmartPointer(self._state.get(),
+ _raw_arc4_lib.ARC4_stream_destroy)
+
+ if ndrop > 0:
+ # This is OK even if the cipher is used for decryption,
+ # since encrypt and decrypt are actually the same thing
+ # with ARC4.
+ self.encrypt(b'\x00' * ndrop)
+
+ self.block_size = 1
+ self.key_size = len(key)
+
+ def encrypt(self, plaintext):
+ """Encrypt a piece of data.
+
+ :param plaintext: The data to encrypt, of any size.
+ :type plaintext: bytes, bytearray, memoryview
+ :returns: the encrypted byte string, of equal length as the
+ plaintext.
+ """
+
+ ciphertext = create_string_buffer(len(plaintext))
+ result = _raw_arc4_lib.ARC4_stream_encrypt(self._state.get(),
+ c_uint8_ptr(plaintext),
+ ciphertext,
+ c_size_t(len(plaintext)))
+ if result:
+ raise ValueError("Error %d while encrypting with RC4" % result)
+ return get_raw_buffer(ciphertext)
+
+ def decrypt(self, ciphertext):
+ """Decrypt a piece of data.
+
+ :param ciphertext: The data to decrypt, of any size.
+ :type ciphertext: bytes, bytearray, memoryview
+ :returns: the decrypted byte string, of equal length as the
+ ciphertext.
+ """
+
+ try:
+ return self.encrypt(ciphertext)
+ except ValueError as e:
+ raise ValueError(str(e).replace("enc", "dec"))
+
+
+def new(key, *args, **kwargs):
+ """Create a new ARC4 cipher.
+
+ :param key:
+ The secret key to use in the symmetric cipher.
+ Its length must be in the range ``[5..256]``.
+ The recommended length is 16 bytes.
+ :type key: bytes, bytearray, memoryview
+
+ :Keyword Arguments:
+ * *drop* (``integer``) --
+ The number of bytes to discard from the initial part of the keystream.
+ That part has been found to be distinguishable from random
+ data (while it should not be) and also correlated to the key.
+
+ The recommended value is 3072_ bytes. The default value is 0.
+
+ :Return: an `ARC4Cipher` object
+
+ .. _3072: http://eprint.iacr.org/2002/067.pdf
+ """
+ return ARC4Cipher(key, *args, **kwargs)
+
+# Size of a data block (in bytes)
+block_size = 1
+# Size of a key (in bytes)
+key_size = range(5, 256+1)
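ARC4 is a stream cipher: encryption and decryption are the same keystream XOR, and the drop parameter discards the statistically biased start of the keystream. A minimal sketch (key is illustrative; note there is no nonce, so a key must never be reused):

    from Crypto.Cipher import ARC4
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(16)
    ct = ARC4.new(key, drop=3072).encrypt(b'stream me')
    pt = ARC4.new(key, drop=3072).decrypt(ct)     # same drop, same keystream
    assert pt == b'stream me'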
diff --git a/lib/Crypto/Cipher/ARC4.pyi b/lib/Crypto/Cipher/ARC4.pyi
new file mode 100644
index 0000000..2e75d6f
--- /dev/null
+++ b/lib/Crypto/Cipher/ARC4.pyi
@@ -0,0 +1,16 @@
+from typing import Any, Union, Iterable
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class ARC4Cipher:
+ block_size: int
+ key_size: int
+
+ def __init__(self, key: Buffer, *args: Any, **kwargs: Any) -> None: ...
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ def decrypt(self, ciphertext: Buffer) -> bytes: ...
+
+def new(key: Buffer, drop : int = ...) -> ARC4Cipher: ...
+
+block_size: int
+key_size: Iterable[int]
diff --git a/lib/Crypto/Cipher/Blowfish.py b/lib/Crypto/Cipher/Blowfish.py
new file mode 100644
index 0000000..6005ffe
--- /dev/null
+++ b/lib/Crypto/Cipher/Blowfish.py
@@ -0,0 +1,159 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/Blowfish.py : Blowfish
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+"""
+Module's constants for the modes of operation supported with Blowfish:
+
+:var MODE_ECB: :ref:`Electronic Code Book (ECB) <ecb_mode>`
+:var MODE_CBC: :ref:`Cipher-Block Chaining (CBC) <cbc_mode>`
+:var MODE_CFB: :ref:`Cipher FeedBack (CFB) <cfb_mode>`
+:var MODE_OFB: :ref:`Output FeedBack (OFB) <ofb_mode>`
+:var MODE_CTR: :ref:`CounTer Mode (CTR) <ctr_mode>`
+:var MODE_OPENPGP: :ref:`OpenPGP Mode <openpgp_mode>`
+:var MODE_EAX: :ref:`EAX Mode <eax_mode>`
+"""
+
+import sys
+
+from Crypto.Cipher import _create_cipher
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer, c_size_t,
+ c_uint8_ptr)
+
+_raw_blowfish_lib = load_pycryptodome_raw_lib(
+ "Crypto.Cipher._raw_blowfish",
+ """
+ int Blowfish_start_operation(const uint8_t key[],
+ size_t key_len,
+ void **pResult);
+ int Blowfish_encrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int Blowfish_decrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int Blowfish_stop_operation(void *state);
+ """
+ )
+
+
+def _create_base_cipher(dict_parameters):
+ """This method instantiates and returns a smart pointer to
+ a low-level base cipher. It will absorb named parameters in
+ the process."""
+
+ try:
+ key = dict_parameters.pop("key")
+ except KeyError:
+ raise TypeError("Missing 'key' parameter")
+
+ if len(key) not in key_size:
+ raise ValueError("Incorrect Blowfish key length (%d bytes)" % len(key))
+
+ start_operation = _raw_blowfish_lib.Blowfish_start_operation
+ stop_operation = _raw_blowfish_lib.Blowfish_stop_operation
+
+ void_p = VoidPointer()
+ result = start_operation(c_uint8_ptr(key),
+ c_size_t(len(key)),
+ void_p.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the Blowfish cipher"
+ % result)
+ return SmartPointer(void_p.get(), stop_operation)
+
+
+def new(key, mode, *args, **kwargs):
+ """Create a new Blowfish cipher
+
+ :param key:
+ The secret key to use in the symmetric cipher.
+ Its length can vary from 5 to 56 bytes.
+ :type key: bytes, bytearray, memoryview
+
+ :param mode:
+ The chaining mode to use for encryption or decryption.
+ :type mode: One of the supported ``MODE_*`` constants
+
+ :Keyword Arguments:
+ * **iv** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_CBC``, ``MODE_CFB``, ``MODE_OFB``,
+ and ``MODE_OPENPGP`` modes).
+
+ The initialization vector to use for encryption or decryption.
+
+ For ``MODE_CBC``, ``MODE_CFB``, and ``MODE_OFB`` it must be 8 bytes long.
+
+ For ``MODE_OPENPGP`` mode only,
+ it must be 8 bytes long for encryption
+ and 10 bytes for decryption (in the latter case, it is
+ actually the *encrypted* IV which was prefixed to the ciphertext).
+
+ If not provided, a random byte string is generated (you must then
+ read its value with the :attr:`iv` attribute).
+
+ * **nonce** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_EAX`` and ``MODE_CTR``).
+
+ A value that must never be reused for any other encryption done
+ with this key.
+
+ For ``MODE_EAX`` there are no
+ restrictions on its length (recommended: **16** bytes).
+
+ For ``MODE_CTR``, its length must be in the range **[0..7]**.
+
+ If not provided for ``MODE_EAX``, a random byte string is generated (you
+ can read it back via the ``nonce`` attribute).
+
+ * **segment_size** (*integer*) --
+ (Only ``MODE_CFB``). The number of **bits** the plaintext and ciphertext
+ are segmented in. It must be a multiple of 8.
+ If not specified, it will be assumed to be 8.
+
+ * **mac_len** : (*integer*) --
+ (Only ``MODE_EAX``)
+ Length of the authentication tag, in bytes.
+ It must be no longer than 8 (default).
+
+ * **initial_value** : (*integer*) --
+ (Only ``MODE_CTR``). The initial value for the counter within
+ the counter block. By default it is **0**.
+
+ :Return: a Blowfish object, of the applicable mode.
+ """
+
+ return _create_cipher(sys.modules[__name__], key, mode, *args, **kwargs)
+
+MODE_ECB = 1
+MODE_CBC = 2
+MODE_CFB = 3
+MODE_OFB = 5
+MODE_CTR = 6
+MODE_OPENPGP = 7
+MODE_EAX = 9
+
+# Size of a data block (in bytes)
+block_size = 8
+# Size of a key (in bytes)
+key_size = range(4, 56 + 1)
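A Blowfish round trip in the authenticated EAX mode (key and message are illustrative; per the docstring above, the tag for this 64-bit block cipher is at most 8 bytes):

    from Crypto.Cipher import Blowfish
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(32)                    # length taken from key_size above
    cipher = Blowfish.new(key, Blowfish.MODE_EAX)
    ct, tag = cipher.encrypt_and_digest(b'blowfish demo')

    decipher = Blowfish.new(key, Blowfish.MODE_EAX, nonce=cipher.nonce)
    assert decipher.decrypt_and_verify(ct, tag) == b'blowfish demo'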
diff --git a/lib/Crypto/Cipher/Blowfish.pyi b/lib/Crypto/Cipher/Blowfish.pyi
new file mode 100644
index 0000000..eff9da9
--- /dev/null
+++ b/lib/Crypto/Cipher/Blowfish.pyi
@@ -0,0 +1,35 @@
+from typing import Union, Dict, Iterable
+
+from Crypto.Cipher._mode_ecb import EcbMode
+from Crypto.Cipher._mode_cbc import CbcMode
+from Crypto.Cipher._mode_cfb import CfbMode
+from Crypto.Cipher._mode_ofb import OfbMode
+from Crypto.Cipher._mode_ctr import CtrMode
+from Crypto.Cipher._mode_openpgp import OpenPgpMode
+from Crypto.Cipher._mode_eax import EaxMode
+
+BlowfishMode = int
+
+MODE_ECB: BlowfishMode
+MODE_CBC: BlowfishMode
+MODE_CFB: BlowfishMode
+MODE_OFB: BlowfishMode
+MODE_CTR: BlowfishMode
+MODE_OPENPGP: BlowfishMode
+MODE_EAX: BlowfishMode
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ mode: BlowfishMode,
+ iv : Buffer = ...,
+ IV : Buffer = ...,
+ nonce : Buffer = ...,
+ segment_size : int = ...,
+ mac_len : int = ...,
+ initial_value : Union[int, Buffer] = ...,
+ counter : Dict = ...) -> \
+ Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode]: ...
+
+block_size: int
+key_size: Iterable[int]
diff --git a/lib/Crypto/Cipher/CAST.py b/lib/Crypto/Cipher/CAST.py
new file mode 100644
index 0000000..c7e82c1
--- /dev/null
+++ b/lib/Crypto/Cipher/CAST.py
@@ -0,0 +1,159 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/CAST.py : CAST
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+"""
+Module's constants for the modes of operation supported with CAST:
+
+:var MODE_ECB: :ref:`Electronic Code Book (ECB) <ecb_mode>`
+:var MODE_CBC: :ref:`Cipher-Block Chaining (CBC) <cbc_mode>`
+:var MODE_CFB: :ref:`Cipher FeedBack (CFB) <cfb_mode>`
+:var MODE_OFB: :ref:`Output FeedBack (OFB) <ofb_mode>`
+:var MODE_CTR: :ref:`CounTer Mode (CTR) <ctr_mode>`
+:var MODE_OPENPGP: :ref:`OpenPGP Mode <openpgp_mode>`
+:var MODE_EAX: :ref:`EAX Mode <eax_mode>`
+"""
+
+import sys
+
+from Crypto.Cipher import _create_cipher
+from Crypto.Util.py3compat import byte_string
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ c_size_t, c_uint8_ptr)
+
+_raw_cast_lib = load_pycryptodome_raw_lib(
+ "Crypto.Cipher._raw_cast",
+ """
+ int CAST_start_operation(const uint8_t key[],
+ size_t key_len,
+ void **pResult);
+ int CAST_encrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CAST_decrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CAST_stop_operation(void *state);
+ """)
+
+
+def _create_base_cipher(dict_parameters):
+ """This method instantiates and returns a handle to a low-level
+ base cipher. It will absorb named parameters in the process."""
+
+ try:
+ key = dict_parameters.pop("key")
+ except KeyError:
+ raise TypeError("Missing 'key' parameter")
+
+ if len(key) not in key_size:
+ raise ValueError("Incorrect CAST key length (%d bytes)" % len(key))
+
+ start_operation = _raw_cast_lib.CAST_start_operation
+ stop_operation = _raw_cast_lib.CAST_stop_operation
+
+ cipher = VoidPointer()
+ result = start_operation(c_uint8_ptr(key),
+ c_size_t(len(key)),
+ cipher.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the CAST cipher"
+ % result)
+
+ return SmartPointer(cipher.get(), stop_operation)
+
+
+def new(key, mode, *args, **kwargs):
+ """Create a new CAST cipher
+
+ :param key:
+ The secret key to use in the symmetric cipher.
+ Its length can vary from 5 to 16 bytes.
+ :type key: bytes, bytearray, memoryview
+
+ :param mode:
+ The chaining mode to use for encryption or decryption.
+ :type mode: One of the supported ``MODE_*`` constants
+
+ :Keyword Arguments:
+ * **iv** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_CBC``, ``MODE_CFB``, ``MODE_OFB``,
+ and ``MODE_OPENPGP`` modes).
+
+ The initialization vector to use for encryption or decryption.
+
+ For ``MODE_CBC``, ``MODE_CFB``, and ``MODE_OFB`` it must be 8 bytes long.
+
+ For ``MODE_OPENPGP`` mode only,
+ it must be 8 bytes long for encryption
+ and 10 bytes for decryption (in the latter case, it is
+ actually the *encrypted* IV which was prefixed to the ciphertext).
+
+ If not provided, a random byte string is generated (you must then
+ read its value with the :attr:`iv` attribute).
+
+ * **nonce** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_EAX`` and ``MODE_CTR``).
+
+ A value that must never be reused for any other encryption done
+ with this key.
+
+ For ``MODE_EAX`` there are no
+ restrictions on its length (recommended: **16** bytes).
+
+ For ``MODE_CTR``, its length must be in the range **[0..7]**.
+
+ If not provided for ``MODE_EAX``, a random byte string is generated (you
+ can read it back via the ``nonce`` attribute).
+
+ * **segment_size** (*integer*) --
+ (Only ``MODE_CFB``). The number of **bits** the plaintext and ciphertext
+ are segmented in. It must be a multiple of 8.
+ If not specified, it will be assumed to be 8.
+
+ * **mac_len** : (*integer*) --
+ (Only ``MODE_EAX``)
+ Length of the authentication tag, in bytes.
+ It must be no longer than 8 (default).
+
+ * **initial_value** : (*integer*) --
+ (Only ``MODE_CTR``). The initial value for the counter within
+ the counter block. By default it is **0**.
+
+ :Return: a CAST object, of the applicable mode.
+ """
+
+ return _create_cipher(sys.modules[__name__], key, mode, *args, **kwargs)
+
+MODE_ECB = 1
+MODE_CBC = 2
+MODE_CFB = 3
+MODE_OFB = 5
+MODE_CTR = 6
+MODE_OPENPGP = 7
+MODE_EAX = 9
+
+# Size of a data block (in bytes)
+block_size = 8
+# Size of a key (in bytes)
+key_size = range(5, 16 + 1)
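A CAST-128 sketch in CFB mode, which works on 8-bit segments by default and therefore needs no padding (values are illustrative):

    from Crypto.Cipher import CAST
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(16)
    cipher = CAST.new(key, CAST.MODE_CFB)         # an 8-byte IV is generated for us
    ct = cipher.encrypt(b'cast5 demo')            # no padding needed in CFB

    decipher = CAST.new(key, CAST.MODE_CFB, iv=cipher.iv)
    assert decipher.decrypt(ct) == b'cast5 demo'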
diff --git a/lib/Crypto/Cipher/CAST.pyi b/lib/Crypto/Cipher/CAST.pyi
new file mode 100644
index 0000000..a0cb6af
--- /dev/null
+++ b/lib/Crypto/Cipher/CAST.pyi
@@ -0,0 +1,35 @@
+from typing import Union, Dict, Iterable
+
+from Crypto.Cipher._mode_ecb import EcbMode
+from Crypto.Cipher._mode_cbc import CbcMode
+from Crypto.Cipher._mode_cfb import CfbMode
+from Crypto.Cipher._mode_ofb import OfbMode
+from Crypto.Cipher._mode_ctr import CtrMode
+from Crypto.Cipher._mode_openpgp import OpenPgpMode
+from Crypto.Cipher._mode_eax import EaxMode
+
+CASTMode = int
+
+MODE_ECB: CASTMode
+MODE_CBC: CASTMode
+MODE_CFB: CASTMode
+MODE_OFB: CASTMode
+MODE_CTR: CASTMode
+MODE_OPENPGP: CASTMode
+MODE_EAX: CASTMode
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ mode: CASTMode,
+ iv : Buffer = ...,
+ IV : Buffer = ...,
+ nonce : Buffer = ...,
+ segment_size : int = ...,
+ mac_len : int = ...,
+ initial_value : Union[int, Buffer] = ...,
+ counter : Dict = ...) -> \
+ Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode]: ...
+
+block_size: int
+key_size : Iterable[int]
diff --git a/lib/Crypto/Cipher/ChaCha20.py b/lib/Crypto/Cipher/ChaCha20.py
new file mode 100644
index 0000000..9bd2252
--- /dev/null
+++ b/lib/Crypto/Cipher/ChaCha20.py
@@ -0,0 +1,287 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Random import get_random_bytes
+
+from Crypto.Util.py3compat import _copy_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ create_string_buffer,
+ get_raw_buffer, VoidPointer,
+ SmartPointer, c_size_t,
+ c_uint8_ptr, c_ulong,
+ is_writeable_buffer)
+
+_raw_chacha20_lib = load_pycryptodome_raw_lib("Crypto.Cipher._chacha20",
+ """
+ int chacha20_init(void **pState,
+ const uint8_t *key,
+ size_t keySize,
+ const uint8_t *nonce,
+ size_t nonceSize);
+
+ int chacha20_destroy(void *state);
+
+ int chacha20_encrypt(void *state,
+ const uint8_t in[],
+ uint8_t out[],
+ size_t len);
+
+ int chacha20_seek(void *state,
+ unsigned long block_high,
+ unsigned long block_low,
+ unsigned offset);
+ int hchacha20( const uint8_t key[32],
+ const uint8_t nonce16[16],
+ uint8_t subkey[32]);
+ """)
+
+
+def _HChaCha20(key, nonce):
+
+ assert(len(key) == 32)
+ assert(len(nonce) == 16)
+
+ subkey = bytearray(32)
+ result = _raw_chacha20_lib.hchacha20(
+ c_uint8_ptr(key),
+ c_uint8_ptr(nonce),
+ c_uint8_ptr(subkey))
+ if result:
+ raise ValueError("Error %d when deriving subkey with HChaCha20" % result)
+
+ return subkey
+
+
+class ChaCha20Cipher(object):
+ """ChaCha20 (or XChaCha20) cipher object.
+ Do not create it directly. Use :py:func:`new` instead.
+
+ :var nonce: The nonce with length 8, 12 or 24 bytes
+ :vartype nonce: bytes
+ """
+
+ block_size = 1
+
+ def __init__(self, key, nonce):
+ """Initialize a ChaCha20/XChaCha20 cipher object
+
+ See also `new()` at the module level."""
+
+ self.nonce = _copy_bytes(None, None, nonce)
+
+ # XChaCha20 requires a key derivation with HChaCha20
+ # See 2.3 in https://tools.ietf.org/html/draft-arciszewski-xchacha-03
+ if len(nonce) == 24:
+ key = _HChaCha20(key, nonce[:16])
+ nonce = b'\x00' * 4 + nonce[16:]
+ self._name = "XChaCha20"
+ else:
+ self._name = "ChaCha20"
+ nonce = self.nonce
+
+ self._next = ( self.encrypt, self.decrypt )
+
+ self._state = VoidPointer()
+ result = _raw_chacha20_lib.chacha20_init(
+ self._state.address_of(),
+ c_uint8_ptr(key),
+ c_size_t(len(key)),
+ nonce,
+ c_size_t(len(nonce)))
+ if result:
+ raise ValueError("Error %d instantiating a %s cipher" % (result,
+ self._name))
+ self._state = SmartPointer(self._state.get(),
+ _raw_chacha20_lib.chacha20_destroy)
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt a piece of data.
+
+ Args:
+ plaintext(bytes/bytearray/memoryview): The data to encrypt, of any size.
+ Keyword Args:
+ output(bytes/bytearray/memoryview): The location where the ciphertext
+ is written to. If ``None``, the ciphertext is returned.
+ Returns:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("Cipher object can only be used for decryption")
+ self._next = ( self.encrypt, )
+ return self._encrypt(plaintext, output)
+
+ def _encrypt(self, plaintext, output):
+ """Encrypt without FSM checks"""
+
+ if output is None:
+ ciphertext = create_string_buffer(len(plaintext))
+ else:
+ ciphertext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(plaintext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = _raw_chacha20_lib.chacha20_encrypt(
+ self._state.get(),
+ c_uint8_ptr(plaintext),
+ c_uint8_ptr(ciphertext),
+ c_size_t(len(plaintext)))
+ if result:
+ raise ValueError("Error %d while encrypting with %s" % (result, self._name))
+
+ if output is None:
+ return get_raw_buffer(ciphertext)
+ else:
+ return None
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt a piece of data.
+
+ Args:
+ ciphertext(bytes/bytearray/memoryview): The data to decrypt, of any size.
+ Keyword Args:
+ output(bytes/bytearray/memoryview): The location where the plaintext
+ is written to. If ``None``, the plaintext is returned.
+ Returns:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("Cipher object can only be used for encryption")
+ self._next = ( self.decrypt, )
+
+ try:
+ return self._encrypt(ciphertext, output)
+ except ValueError as e:
+ raise ValueError(str(e).replace("enc", "dec"))
+
+ def seek(self, position):
+ """Seek to a certain position in the key stream.
+
+ Args:
+ position (integer):
+ The absolute position within the key stream, in bytes.
+ """
+
+ position, offset = divmod(position, 64)
+ block_low = position & 0xFFFFFFFF
+ block_high = position >> 32
+
+ result = _raw_chacha20_lib.chacha20_seek(
+ self._state.get(),
+ c_ulong(block_high),
+ c_ulong(block_low),
+ offset
+ )
+ if result:
+ raise ValueError("Error %d while seeking with %s" % (result, self._name))
+
+
+def _derive_Poly1305_key_pair(key, nonce):
+ """Derive a tuple (r, s, nonce) for a Poly1305 MAC.
+
+ If nonce is ``None``, a new 12-byte nonce is generated.
+ """
+
+ if len(key) != 32:
+ raise ValueError("Poly1305 with ChaCha20 requires a 32-byte key")
+
+ if nonce is None:
+ padded_nonce = nonce = get_random_bytes(12)
+ elif len(nonce) == 8:
+ # See RFC 7539, 2.6: [...] ChaCha20 as specified here requires a 96-bit
+ # nonce. So if the provided nonce is only 64-bit, then the first 32
+ # bits of the nonce will be set to a constant number.
+ # This will usually be zero, but for protocols with multiple senders it may be
+ # different for each sender, but should be the same for all
+ # invocations of the function with the same key by a particular
+ # sender.
+ padded_nonce = b'\x00\x00\x00\x00' + nonce
+ elif len(nonce) == 12:
+ padded_nonce = nonce
+ else:
+ raise ValueError("Poly1305 with ChaCha20 requires an 8- or 12-byte nonce")
+
+ rs = new(key=key, nonce=padded_nonce).encrypt(b'\x00' * 32)
+ return rs[:16], rs[16:], nonce
+
+
+def new(**kwargs):
+ """Create a new ChaCha20 or XChaCha20 cipher
+
+ Keyword Args:
+ key (bytes/bytearray/memoryview): The secret key to use.
+ It must be 32 bytes long.
+ nonce (bytes/bytearray/memoryview): A mandatory value that
+ must never be reused for any other encryption
+ done with this key.
+
+ For ChaCha20, it must be 8 or 12 bytes long.
+
+ For XChaCha20, it must be 24 bytes long.
+
+ If not provided, 8 bytes will be randomly generated
+ (you can find them back in the ``nonce`` attribute).
+
+ :Return: a :class:`Crypto.Cipher.ChaCha20.ChaCha20Cipher` object
+ """
+
+ try:
+ key = kwargs.pop("key")
+ except KeyError as e:
+ raise TypeError("Missing parameter %s" % e)
+
+ nonce = kwargs.pop("nonce", None)
+ if nonce is None:
+ nonce = get_random_bytes(8)
+
+ if len(key) != 32:
+ raise ValueError("ChaCha20/XChaCha20 key must be 32 bytes long")
+
+ if len(nonce) not in (8, 12, 24):
+ raise ValueError("Nonce must be 8/12 bytes(ChaCha20) or 24 bytes (XChaCha20)")
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return ChaCha20Cipher(key, nonce)
+
+# Size of a data block (in bytes)
+block_size = 1
+
+# Size of a key (in bytes)
+key_size = 32
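A ChaCha20 sketch using an explicit 12-byte nonce; passing a 24-byte nonce instead selects XChaCha20 through the HChaCha20 derivation shown above (values are illustrative):

    from Crypto.Cipher import ChaCha20
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(32)
    nonce = get_random_bytes(12)                  # 8 or 12 bytes for ChaCha20, 24 for XChaCha20
    ct = ChaCha20.new(key=key, nonce=nonce).encrypt(b'chacha20 demo')

    decipher = ChaCha20.new(key=key, nonce=nonce)
    assert decipher.decrypt(ct) == b'chacha20 demo'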
diff --git a/lib/Crypto/Cipher/ChaCha20.pyi b/lib/Crypto/Cipher/ChaCha20.pyi
new file mode 100644
index 0000000..3d00a1d
--- /dev/null
+++ b/lib/Crypto/Cipher/ChaCha20.pyi
@@ -0,0 +1,25 @@
+from typing import Union, overload
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def _HChaCha20(key: Buffer, nonce: Buffer) -> bytearray: ...
+
+class ChaCha20Cipher:
+ block_size: int
+ nonce: bytes
+
+ def __init__(self, key: Buffer, nonce: Buffer) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, ciphertext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, ciphertext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ def seek(self, position: int) -> None: ...
+
+def new(key: Buffer, nonce: Buffer = ...) -> ChaCha20Cipher: ...
+
+block_size: int
+key_size: int
diff --git a/lib/Crypto/Cipher/ChaCha20_Poly1305.py b/lib/Crypto/Cipher/ChaCha20_Poly1305.py
new file mode 100644
index 0000000..21ddca3
--- /dev/null
+++ b/lib/Crypto/Cipher/ChaCha20_Poly1305.py
@@ -0,0 +1,336 @@
+# ===================================================================
+#
+# Copyright (c) 2018, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from binascii import unhexlify
+
+from Crypto.Cipher import ChaCha20
+from Crypto.Cipher.ChaCha20 import _HChaCha20
+from Crypto.Hash import Poly1305, BLAKE2s
+
+from Crypto.Random import get_random_bytes
+
+from Crypto.Util.number import long_to_bytes
+from Crypto.Util.py3compat import _copy_bytes, bord
+from Crypto.Util._raw_api import is_buffer
+
+
+def _enum(**enums):
+ return type('Enum', (), enums)
+
+
+_CipherStatus = _enum(PROCESSING_AUTH_DATA=1,
+ PROCESSING_CIPHERTEXT=2,
+ PROCESSING_DONE=3)
+
+
+class ChaCha20Poly1305Cipher(object):
+ """ChaCha20-Poly1305 and XChaCha20-Poly1305 cipher object.
+ Do not create it directly. Use :py:func:`new` instead.
+
+ :var nonce: The nonce with length 8, 12 or 24 bytes
+ :vartype nonce: byte string
+ """
+
+ def __init__(self, key, nonce):
+ """Initialize a ChaCha20-Poly1305 AEAD cipher object
+
+ See also `new()` at the module level."""
+
+ self.nonce = _copy_bytes(None, None, nonce)
+
+ self._next = (self.update, self.encrypt, self.decrypt, self.digest,
+ self.verify)
+
+ self._authenticator = Poly1305.new(key=key, nonce=nonce, cipher=ChaCha20)
+
+ self._cipher = ChaCha20.new(key=key, nonce=nonce)
+ self._cipher.seek(64) # Block counter starts at 1
+
+ self._len_aad = 0
+ self._len_ct = 0
+ self._mac_tag = None
+ self._status = _CipherStatus.PROCESSING_AUTH_DATA
+
+ def update(self, data):
+ """Protect the associated data.
+
+ Associated data (also known as *additional authenticated data* - AAD)
+ is the piece of the message that must stay in the clear, while
+ still allowing the receiver to verify its integrity.
+ An example is packet headers.
+
+ The associated data (possibly split into multiple segments) is
+ fed into :meth:`update` before any call to :meth:`decrypt` or :meth:`encrypt`.
+ If there is no associated data, :meth:`update` is not called.
+
+ :param bytes/bytearray/memoryview data:
+ A piece of associated data. There are no restrictions on its size.
+ """
+
+ if self.update not in self._next:
+ raise TypeError("update() method cannot be called")
+
+ self._len_aad += len(data)
+ self._authenticator.update(data)
+
+ def _pad_aad(self):
+
+ assert(self._status == _CipherStatus.PROCESSING_AUTH_DATA)
+ if self._len_aad & 0x0F:
+ self._authenticator.update(b'\x00' * (16 - (self._len_aad & 0x0F)))
+ self._status = _CipherStatus.PROCESSING_CIPHERTEXT
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt a piece of data.
+
+ Args:
+ plaintext(bytes/bytearray/memoryview): The data to encrypt, of any size.
+ Keyword Args:
+ output(bytes/bytearray/memoryview): The location where the ciphertext
+ is written to. If ``None``, the ciphertext is returned.
+ Returns:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() method cannot be called")
+
+ if self._status == _CipherStatus.PROCESSING_AUTH_DATA:
+ self._pad_aad()
+
+ self._next = (self.encrypt, self.digest)
+
+ result = self._cipher.encrypt(plaintext, output=output)
+ self._len_ct += len(plaintext)
+ if output is None:
+ self._authenticator.update(result)
+ else:
+ self._authenticator.update(output)
+ return result
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt a piece of data.
+
+ Args:
+ ciphertext(bytes/bytearray/memoryview): The data to decrypt, of any size.
+ Keyword Args:
+ output(bytes/bytearray/memoryview): The location where the plaintext
+ is written to. If ``None``, the plaintext is returned.
+ Returns:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() method cannot be called")
+
+ if self._status == _CipherStatus.PROCESSING_AUTH_DATA:
+ self._pad_aad()
+
+ self._next = (self.decrypt, self.verify)
+
+ self._len_ct += len(ciphertext)
+ self._authenticator.update(ciphertext)
+ return self._cipher.decrypt(ciphertext, output=output)
+
+ def _compute_mac(self):
+ """Finalize the cipher (if not done already) and return the MAC."""
+
+ if self._mac_tag:
+ assert(self._status == _CipherStatus.PROCESSING_DONE)
+ return self._mac_tag
+
+ assert(self._status != _CipherStatus.PROCESSING_DONE)
+
+ if self._status == _CipherStatus.PROCESSING_AUTH_DATA:
+ self._pad_aad()
+
+ if self._len_ct & 0x0F:
+ self._authenticator.update(b'\x00' * (16 - (self._len_ct & 0x0F)))
+
+ self._status = _CipherStatus.PROCESSING_DONE
+
+ self._authenticator.update(long_to_bytes(self._len_aad, 8)[::-1])
+ self._authenticator.update(long_to_bytes(self._len_ct, 8)[::-1])
+ self._mac_tag = self._authenticator.digest()
+ return self._mac_tag
+
+ def digest(self):
+ """Compute the *binary* authentication tag (MAC).
+
+ :Return: the MAC tag, as 16 ``bytes``.
+ """
+
+ if self.digest not in self._next:
+ raise TypeError("digest() method cannot be called")
+ self._next = (self.digest,)
+
+ return self._compute_mac()
+
+ def hexdigest(self):
+ """Compute the *printable* authentication tag (MAC).
+
+ This method is like :meth:`digest`.
+
+ :Return: the MAC tag, as a hexadecimal string.
+ """
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def verify(self, received_mac_tag):
+ """Validate the *binary* authentication tag (MAC).
+
+ The receiver invokes this method at the very end, to
+ check if the associated data (if any) and the decrypted
+ messages are valid.
+
+ :param bytes/bytearray/memoryview received_mac_tag:
+ This is the 16-byte *binary* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ if self.verify not in self._next:
+ raise TypeError("verify() cannot be called"
+ " when encrypting a message")
+ self._next = (self.verify,)
+
+ secret = get_random_bytes(16)
+
+ self._compute_mac()
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret,
+ data=self._mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret,
+ data=received_mac_tag)
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Validate the *printable* authentication tag (MAC).
+
+ This method is like :meth:`verify`.
+
+ :param string hex_mac_tag:
+ This is the *printable* MAC.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ self.verify(unhexlify(hex_mac_tag))
+
+ def encrypt_and_digest(self, plaintext):
+ """Perform :meth:`encrypt` and :meth:`digest` in one step.
+
+ :param plaintext: The data to encrypt, of any size.
+ :type plaintext: bytes/bytearray/memoryview
+ :return: a tuple with two ``bytes`` objects:
+
+ - the ciphertext, of equal length as the plaintext
+ - the 16-byte MAC tag
+ """
+
+ return self.encrypt(plaintext), self.digest()
+
+ def decrypt_and_verify(self, ciphertext, received_mac_tag):
+ """Perform :meth:`decrypt` and :meth:`verify` in one step.
+
+ :param ciphertext: The piece of data to decrypt.
+ :type ciphertext: bytes/bytearray/memoryview
+ :param bytes received_mac_tag:
+ This is the 16-byte *binary* MAC, as received from the sender.
+ :return: the decrypted data (as ``bytes``)
+ :raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ plaintext = self.decrypt(ciphertext)
+ self.verify(received_mac_tag)
+ return plaintext
+
+
+def new(**kwargs):
+ """Create a new ChaCha20-Poly1305 or XChaCha20-Poly1305 AEAD cipher.
+
+ :keyword key: The secret key to use. It must be 32 bytes long.
+ :type key: byte string
+
+ :keyword nonce:
+ A value that must never be reused for any other encryption
+ done with this key.
+
+ For ChaCha20-Poly1305, it must be 8 or 12 bytes long.
+
+ For XChaCha20-Poly1305, it must be 24 bytes long.
+
+ If not provided, 12 ``bytes`` will be generated randomly
+ (you can read them back via the ``nonce`` attribute).
+ :type nonce: bytes, bytearray, memoryview
+
+ :Return: a :class:`Crypto.Cipher.ChaCha20.ChaCha20Poly1305Cipher` object
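+
+ A minimal AEAD sketch (``key`` stands in for a 32-byte secret and
+ ``header`` for associated data that is authenticated but not encrypted)::
+
+ cipher = new(key=key)
+ cipher.update(header)
+ ciphertext, tag = cipher.encrypt_and_digest(b'secret payload')
+
+ receiver = new(key=key, nonce=cipher.nonce)
+ receiver.update(header)
+ plaintext = receiver.decrypt_and_verify(ciphertext, tag)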
+ """
+
+ try:
+ key = kwargs.pop("key")
+ except KeyError as e:
+ raise TypeError("Missing parameter %s" % e)
+
+
+ if len(key) != 32:
+ raise ValueError("Key must be 32 bytes long")
+
+ nonce = kwargs.pop("nonce", None)
+ if nonce is None:
+ nonce = get_random_bytes(12)
+
+ if len(nonce) in (8, 12):
+ pass
+ elif len(nonce) == 24:
+ key = _HChaCha20(key, nonce[:16])
+ nonce = b'\x00\x00\x00\x00' + nonce[16:]
+ else:
+ raise ValueError("Nonce must be 8, 12 or 24 bytes long")
+
+ if not is_buffer(nonce):
+ raise TypeError("nonce must be bytes, bytearray or memoryview")
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return ChaCha20Poly1305Cipher(key, nonce)
+
+
+# Size of a key (in bytes)
+key_size = 32
diff --git a/lib/Crypto/Cipher/ChaCha20_Poly1305.pyi b/lib/Crypto/Cipher/ChaCha20_Poly1305.pyi
new file mode 100644
index 0000000..ef0450f
--- /dev/null
+++ b/lib/Crypto/Cipher/ChaCha20_Poly1305.pyi
@@ -0,0 +1,28 @@
+from typing import Union, Tuple, overload
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class ChaCha20Poly1305Cipher:
+ nonce: bytes
+
+ def __init__(self, key: Buffer, nonce: Buffer) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, ciphertext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, ciphertext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, received_mac_tag: Buffer) -> None: ...
+ def hexverify(self, received_mac_tag: str) -> None: ...
+ def encrypt_and_digest(self, plaintext: Buffer) -> Tuple[bytes, bytes]: ...
+ def decrypt_and_verify(self, ciphertext: Buffer, received_mac_tag: Buffer) -> bytes: ...
+
+def new(key: Buffer, nonce: Buffer = ...) -> ChaCha20Poly1305Cipher: ...
+
+block_size: int
+key_size: int
diff --git a/lib/Crypto/Cipher/DES.py b/lib/Crypto/Cipher/DES.py
new file mode 100644
index 0000000..5cc286a
--- /dev/null
+++ b/lib/Crypto/Cipher/DES.py
@@ -0,0 +1,158 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/DES.py : DES
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+"""
+Module's constants for the modes of operation supported with Single DES:
+
+:var MODE_ECB: :ref:`Electronic Code Book (ECB) <ecb_mode>`
+:var MODE_CBC: :ref:`Cipher-Block Chaining (CBC) <cbc_mode>`
+:var MODE_CFB: :ref:`Cipher FeedBack (CFB) <cfb_mode>`
+:var MODE_OFB: :ref:`Output FeedBack (OFB) <ofb_mode>`
+:var MODE_CTR: :ref:`CounTer Mode (CTR) <ctr_mode>`
+:var MODE_OPENPGP: :ref:`OpenPGP Mode <openpgp_mode>`
+:var MODE_EAX: :ref:`EAX Mode <eax_mode>`
+"""
+
+import sys
+
+from Crypto.Cipher import _create_cipher
+from Crypto.Util.py3compat import byte_string
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ c_size_t, c_uint8_ptr)
+
+_raw_des_lib = load_pycryptodome_raw_lib(
+ "Crypto.Cipher._raw_des",
+ """
+ int DES_start_operation(const uint8_t key[],
+ size_t key_len,
+ void **pResult);
+ int DES_encrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int DES_decrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int DES_stop_operation(void *state);
+ """)
+
+
+def _create_base_cipher(dict_parameters):
+ """This method instantiates and returns a handle to a low-level
+ base cipher. It will absorb named parameters in the process."""
+
+ try:
+ key = dict_parameters.pop("key")
+ except KeyError:
+ raise TypeError("Missing 'key' parameter")
+
+ if len(key) != key_size:
+ raise ValueError("Incorrect DES key length (%d bytes)" % len(key))
+
+ start_operation = _raw_des_lib.DES_start_operation
+ stop_operation = _raw_des_lib.DES_stop_operation
+
+ cipher = VoidPointer()
+ result = start_operation(c_uint8_ptr(key),
+ c_size_t(len(key)),
+ cipher.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the DES cipher"
+ % result)
+ return SmartPointer(cipher.get(), stop_operation)
+
+
+def new(key, mode, *args, **kwargs):
+ """Create a new DES cipher.
+
+ :param key:
+ The secret key to use in the symmetric cipher.
+ It must be 8 bytes long. The parity bits will be ignored.
+ :type key: bytes/bytearray/memoryview
+
+ :param mode:
+ The chaining mode to use for encryption or decryption.
+ :type mode: One of the supported ``MODE_*`` constants
+
+ :Keyword Arguments:
+ * **iv** (*byte string*) --
+ (Only applicable for ``MODE_CBC``, ``MODE_CFB``, ``MODE_OFB``,
+ and ``MODE_OPENPGP`` modes).
+
+ The initialization vector to use for encryption or decryption.
+
+ For ``MODE_CBC``, ``MODE_CFB``, and ``MODE_OFB`` it must be 8 bytes long.
+
+ For ``MODE_OPENPGP`` mode only,
+ it must be 8 bytes long for encryption
+ and 10 bytes for decryption (in the latter case, it is
+ actually the *encrypted* IV which was prefixed to the ciphertext).
+
+ If not provided, a random byte string is generated (you must then
+ read its value with the :attr:`iv` attribute).
+
+ * **nonce** (*byte string*) --
+ (Only applicable for ``MODE_EAX`` and ``MODE_CTR``).
+
+ A value that must never be reused for any other encryption done
+ with this key.
+
+ For ``MODE_EAX`` there are no
+ restrictions on its length (recommended: **16** bytes).
+
+ For ``MODE_CTR``, its length must be in the range **[0..7]**.
+
+ If not provided for ``MODE_EAX``, a random byte string is generated (you
+ can read it back via the ``nonce`` attribute).
+
+ * **segment_size** (*integer*) --
+ (Only ``MODE_CFB``). The number of **bits** the plaintext and ciphertext
+ are segmented in. It must be a multiple of 8.
+ If not specified, it will be assumed to be 8.
+
+ * **mac_len** : (*integer*) --
+ (Only ``MODE_EAX``)
+ Length of the authentication tag, in bytes.
+ It must be no longer than 8 (default).
+
+ * **initial_value** : (*integer*) --
+ (Only ``MODE_CTR``). The initial value for the counter within
+ the counter block. By default it is **0**.
+
+ :Return: a DES object, of the applicable mode.
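+
+ A minimal CBC sketch (``key8`` is a placeholder for an 8-byte key;
+ remember that single DES is only suitable for legacy protocols)::
+
+ cipher = new(key8, MODE_CBC)
+ iv = cipher.iv # randomly generated, send it along with the data
+ ct = cipher.encrypt(b'A' * 16) # CBC needs a multiple of 8 bytes
+ pt = new(key8, MODE_CBC, iv=iv).decrypt(ct)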
+ """
+
+ return _create_cipher(sys.modules[__name__], key, mode, *args, **kwargs)
+
+MODE_ECB = 1
+MODE_CBC = 2
+MODE_CFB = 3
+MODE_OFB = 5
+MODE_CTR = 6
+MODE_OPENPGP = 7
+MODE_EAX = 9
+
+# Size of a data block (in bytes)
+block_size = 8
+# Size of a key (in bytes)
+key_size = 8
diff --git a/lib/Crypto/Cipher/DES.pyi b/lib/Crypto/Cipher/DES.pyi
new file mode 100644
index 0000000..1047f13
--- /dev/null
+++ b/lib/Crypto/Cipher/DES.pyi
@@ -0,0 +1,35 @@
+from typing import Union, Dict, Iterable
+
+from Crypto.Cipher._mode_ecb import EcbMode
+from Crypto.Cipher._mode_cbc import CbcMode
+from Crypto.Cipher._mode_cfb import CfbMode
+from Crypto.Cipher._mode_ofb import OfbMode
+from Crypto.Cipher._mode_ctr import CtrMode
+from Crypto.Cipher._mode_openpgp import OpenPgpMode
+from Crypto.Cipher._mode_eax import EaxMode
+
+DESMode = int
+
+MODE_ECB: DESMode
+MODE_CBC: DESMode
+MODE_CFB: DESMode
+MODE_OFB: DESMode
+MODE_CTR: DESMode
+MODE_OPENPGP: DESMode
+MODE_EAX: DESMode
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ mode: DESMode,
+ iv : Buffer = ...,
+ IV : Buffer = ...,
+ nonce : Buffer = ...,
+ segment_size : int = ...,
+ mac_len : int = ...,
+ initial_value : Union[int, Buffer] = ...,
+ counter : Dict = ...) -> \
+ Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, EaxMode]: ...
+
+block_size: int
+key_size: int
diff --git a/lib/Crypto/Cipher/DES3.py b/lib/Crypto/Cipher/DES3.py
new file mode 100644
index 0000000..c0d9367
--- /dev/null
+++ b/lib/Crypto/Cipher/DES3.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/DES3.py : DES3
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+"""
+Module's constants for the modes of operation supported with Triple DES:
+
+:var MODE_ECB: :ref:`Electronic Code Book (ECB) <ecb_mode>`
+:var MODE_CBC: :ref:`Cipher-Block Chaining (CBC) <cbc_mode>`
+:var MODE_CFB: :ref:`Cipher FeedBack (CFB) <cfb_mode>`
+:var MODE_OFB: :ref:`Output FeedBack (OFB) <ofb_mode>`
+:var MODE_CTR: :ref:`CounTer Mode (CTR) <ctr_mode>`
+:var MODE_OPENPGP: :ref:`OpenPGP Mode <openpgp_mode>`
+:var MODE_EAX: :ref:`EAX Mode <eax_mode>`
+"""
+
+import sys
+
+from Crypto.Cipher import _create_cipher
+from Crypto.Util.py3compat import byte_string, bchr, bord, bstr
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ c_size_t)
+
+_raw_des3_lib = load_pycryptodome_raw_lib(
+ "Crypto.Cipher._raw_des3",
+ """
+ int DES3_start_operation(const uint8_t key[],
+ size_t key_len,
+ void **pResult);
+ int DES3_encrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int DES3_decrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int DES3_stop_operation(void *state);
+ """)
+
+
+def adjust_key_parity(key_in):
+ """Set the parity bits in a TDES key.
+
+ :param key_in: the TDES key whose bits need to be adjusted
+ :type key_in: byte string
+
+ :returns: a copy of ``key_in``, with the parity bits correctly set
+ :rtype: byte string
+
+ :raises ValueError: if the TDES key is not 16 or 24 bytes long
+ :raises ValueError: if the TDES key degenerates into Single DES
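+
+ A minimal sketch of generating a fresh, parity-adjusted 24-byte TDES key
+ (degenerate random keys are extremely unlikely, but callers may still
+ want to catch ``ValueError`` and retry)::
+
+ from Crypto.Random import get_random_bytes
+ key = adjust_key_parity(get_random_bytes(24))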
+ """
+
+ def parity_byte(key_byte):
+ parity = 1
+ for i in range(1, 8):
+ parity ^= (key_byte >> i) & 1
+ return (key_byte & 0xFE) | parity
+
+ if len(key_in) not in key_size:
+ raise ValueError("Not a valid TDES key")
+
+ key_out = b"".join([ bchr(parity_byte(bord(x))) for x in key_in ])
+
+ if key_out[:8] == key_out[8:16] or key_out[-16:-8] == key_out[-8:]:
+ raise ValueError("Triple DES key degenerates to single DES")
+
+ return key_out
+
+
+def _create_base_cipher(dict_parameters):
+ """This method instantiates and returns a handle to a low-level base cipher.
+ It will absorb named parameters in the process."""
+
+ try:
+ key_in = dict_parameters.pop("key")
+ except KeyError:
+ raise TypeError("Missing 'key' parameter")
+
+ key = adjust_key_parity(bstr(key_in))
+
+ start_operation = _raw_des3_lib.DES3_start_operation
+ stop_operation = _raw_des3_lib.DES3_stop_operation
+
+ cipher = VoidPointer()
+ result = start_operation(key,
+ c_size_t(len(key)),
+ cipher.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the TDES cipher"
+ % result)
+ return SmartPointer(cipher.get(), stop_operation)
+
+
+def new(key, mode, *args, **kwargs):
+ """Create a new Triple DES cipher.
+
+ :param key:
+ The secret key to use in the symmetric cipher.
+ It must be 16 or 24 bytes long. The parity bits will be ignored.
+ :type key: bytes/bytearray/memoryview
+
+ :param mode:
+ The chaining mode to use for encryption or decryption.
+ :type mode: One of the supported ``MODE_*`` constants
+
+ :Keyword Arguments:
+ * **iv** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_CBC``, ``MODE_CFB``, ``MODE_OFB``,
+ and ``MODE_OPENPGP`` modes).
+
+ The initialization vector to use for encryption or decryption.
+
+ For ``MODE_CBC``, ``MODE_CFB``, and ``MODE_OFB`` it must be 8 bytes long.
+
+ For ``MODE_OPENPGP`` mode only,
+ it must be 8 bytes long for encryption
+ and 10 bytes for decryption (in the latter case, it is
+ actually the *encrypted* IV which was prefixed to the ciphertext).
+
+ If not provided, a random byte string is generated (you must then
+ read its value with the :attr:`iv` attribute).
+
+ * **nonce** (*bytes*, *bytearray*, *memoryview*) --
+ (Only applicable for ``MODE_EAX`` and ``MODE_CTR``).
+
+ A value that must never be reused for any other encryption done
+ with this key.
+
+ For ``MODE_EAX`` there are no
+ restrictions on its length (recommended: **16** bytes).
+
+ For ``MODE_CTR``, its length must be in the range **[0..7]**.
+
+ If not provided for ``MODE_EAX``, a random byte string is generated (you
+ can read it back via the ``nonce`` attribute).
+
+ * **segment_size** (*integer*) --
+ (Only ``MODE_CFB``). The number of **bits** the plaintext and ciphertext
+ are segmented in. It must be a multiple of 8.
+ If not specified, it will be assumed to be 8.
+
+ * **mac_len** : (*integer*) --
+ (Only ``MODE_EAX``)
+ Length of the authentication tag, in bytes.
+ It must be no longer than 8 (default).
+
+ * **initial_value** : (*integer*) --
+ (Only ``MODE_CTR``). The initial value for the counter within
+ the counter block. By default it is **0**.
+
+ :Return: a Triple DES object, of the applicable mode.
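+
+ A minimal EAX sketch (``key`` is assumed to be a 16- or 24-byte key,
+ for instance one produced with :func:`adjust_key_parity`)::
+
+ cipher = new(key, MODE_EAX)
+ ct, tag = cipher.encrypt_and_digest(b'secret payload')
+ receiver = new(key, MODE_EAX, nonce=cipher.nonce)
+ pt = receiver.decrypt_and_verify(ct, tag)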
+ """
+
+ return _create_cipher(sys.modules[__name__], key, mode, *args, **kwargs)
+
+MODE_ECB = 1
+MODE_CBC = 2
+MODE_CFB = 3
+MODE_OFB = 5
+MODE_CTR = 6
+MODE_OPENPGP = 7
+MODE_EAX = 9
+
+# Size of a data block (in bytes)
+block_size = 8
+# Size of a key (in bytes)
+key_size = (16, 24)
diff --git a/lib/Crypto/Cipher/DES3.pyi b/lib/Crypto/Cipher/DES3.pyi
new file mode 100644
index 0000000..a89db9c
--- /dev/null
+++ b/lib/Crypto/Cipher/DES3.pyi
@@ -0,0 +1,37 @@
+from typing import Union, Dict, Tuple
+
+from Crypto.Cipher._mode_ecb import EcbMode
+from Crypto.Cipher._mode_cbc import CbcMode
+from Crypto.Cipher._mode_cfb import CfbMode
+from Crypto.Cipher._mode_ofb import OfbMode
+from Crypto.Cipher._mode_ctr import CtrMode
+from Crypto.Cipher._mode_openpgp import OpenPgpMode
+from Crypto.Cipher._mode_eax import EaxMode
+
+def adjust_key_parity(key_in: bytes) -> bytes: ...
+
+DES3Mode = int
+
+MODE_ECB: DES3Mode
+MODE_CBC: DES3Mode
+MODE_CFB: DES3Mode
+MODE_OFB: DES3Mode
+MODE_CTR: DES3Mode
+MODE_OPENPGP: DES3Mode
+MODE_EAX: DES3Mode
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ mode: DES3Mode,
+ iv : Buffer = ...,
+ IV : Buffer = ...,
+ nonce : Buffer = ...,
+ segment_size : int = ...,
+ mac_len : int = ...,
+ initial_value : Union[int, Buffer] = ...,
+ counter : Dict = ...) -> \
+ Union[EcbMode, CbcMode, CfbMode, OfbMode, CtrMode, OpenPgpMode, EaxMode]: ...
+
+block_size: int
+key_size: Tuple[int, int]
diff --git a/lib/Crypto/Cipher/PKCS1_OAEP.py b/lib/Crypto/Cipher/PKCS1_OAEP.py
new file mode 100644
index 0000000..57a982b
--- /dev/null
+++ b/lib/Crypto/Cipher/PKCS1_OAEP.py
@@ -0,0 +1,239 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/PKCS1_OAEP.py : PKCS#1 OAEP
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Signature.pss import MGF1
+import Crypto.Hash.SHA1
+
+from Crypto.Util.py3compat import bord, _copy_bytes
+import Crypto.Util.number
+from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
+from Crypto.Util.strxor import strxor
+from Crypto import Random
+
+class PKCS1OAEP_Cipher:
+ """Cipher object for PKCS#1 v1.5 OAEP.
+ Do not create directly: use :func:`new` instead."""
+
+ def __init__(self, key, hashAlgo, mgfunc, label, randfunc):
+ """Initialize this PKCS#1 OAEP cipher object.
+
+ :Parameters:
+ key : an RSA key object
+ If a private half is given, both encryption and decryption are possible.
+ If a public half is given, only encryption is possible.
+ hashAlgo : hash object
+ The hash function to use. This can be a module under `Crypto.Hash`
+ or an existing hash object created from any of such modules. If not specified,
+ `Crypto.Hash.SHA1` is used.
+ mgfunc : callable
+ A mask generation function that accepts two parameters: a string to
+ use as seed, and the length of the mask to generate, in bytes.
+ If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
+ label : bytes/bytearray/memoryview
+ A label to apply to this particular encryption. If not specified,
+ an empty string is used. Specifying a label does not improve
+ security.
+ randfunc : callable
+ A function that returns random bytes.
+
+ :attention: Modify the mask generation function only if you know what you are doing.
+ Sender and receiver must use the same one.
+ """
+ self._key = key
+
+ if hashAlgo:
+ self._hashObj = hashAlgo
+ else:
+ self._hashObj = Crypto.Hash.SHA1
+
+ if mgfunc:
+ self._mgf = mgfunc
+ else:
+ self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
+
+ self._label = _copy_bytes(None, None, label)
+ self._randfunc = randfunc
+
+ def can_encrypt(self):
+ """Legacy function to check if you can call :meth:`encrypt`.
+
+ .. deprecated:: 3.0"""
+ return self._key.can_encrypt()
+
+ def can_decrypt(self):
+ """Legacy function to check if you can call :meth:`decrypt`.
+
+ .. deprecated:: 3.0"""
+ return self._key.can_decrypt()
+
+ def encrypt(self, message):
+ """Encrypt a message with PKCS#1 OAEP.
+
+ :param message:
+ The message to encrypt, also known as plaintext. It can be of
+ variable length, but not longer than the RSA modulus (in bytes)
+ minus 2, minus twice the hash output size.
+ For instance, if you use RSA 2048 and SHA-256, the longest message
+ you can encrypt is 190 bytes long.
+ :type message: bytes/bytearray/memoryview
+
+ :returns: The ciphertext, as large as the RSA modulus.
+ :rtype: bytes
+
+ :raises ValueError:
+ if the message is too long.
+ """
+
+ # See 7.1.1 in RFC3447
+ modBits = Crypto.Util.number.size(self._key.n)
+ k = ceil_div(modBits, 8) # Convert from bits to bytes
+ hLen = self._hashObj.digest_size
+ mLen = len(message)
+
+ # Step 1b
+ ps_len = k - mLen - 2 * hLen - 2
+ if ps_len < 0:
+ raise ValueError("Plaintext is too long.")
+ # Step 2a
+ lHash = self._hashObj.new(self._label).digest()
+ # Step 2b
+ ps = b'\x00' * ps_len
+ # Step 2c
+ db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
+ # Step 2d
+ ros = self._randfunc(hLen)
+ # Step 2e
+ dbMask = self._mgf(ros, k-hLen-1)
+ # Step 2f
+ maskedDB = strxor(db, dbMask)
+ # Step 2g
+ seedMask = self._mgf(maskedDB, hLen)
+ # Step 2h
+ maskedSeed = strxor(ros, seedMask)
+ # Step 2i
+ em = b'\x00' + maskedSeed + maskedDB
+ # Step 3a (OS2IP)
+ em_int = bytes_to_long(em)
+ # Step 3b (RSAEP)
+ m_int = self._key._encrypt(em_int)
+ # Step 3c (I2OSP)
+ c = long_to_bytes(m_int, k)
+ return c
+
+ def decrypt(self, ciphertext):
+ """Decrypt a message with PKCS#1 OAEP.
+
+ :param ciphertext: The encrypted message.
+ :type ciphertext: bytes/bytearray/memoryview
+
+ :returns: The original message (plaintext).
+ :rtype: bytes
+
+ :raises ValueError:
+ if the ciphertext has the wrong length, or if decryption
+ fails the integrity check (in which case, the decryption
+ key is probably wrong).
+ :raises TypeError:
+ if the RSA key has no private half (i.e. you are trying
+ to decrypt using a public key).
+ """
+
+ # See 7.1.2 in RFC3447
+ modBits = Crypto.Util.number.size(self._key.n)
+ k = ceil_div(modBits,8) # Convert from bits to bytes
+ hLen = self._hashObj.digest_size
+
+ # Step 1b and 1c
+ if len(ciphertext) != k or k<hLen+2:
+ raise ValueError("Ciphertext with incorrect length.")
+ # Step 2a (O2SIP)
+ ct_int = bytes_to_long(ciphertext)
+ # Step 2b (RSADP)
+ m_int = self._key._decrypt(ct_int)
+ # Complete step 2c (I2OSP)
+ em = long_to_bytes(m_int, k)
+ # Step 3a
+ lHash = self._hashObj.new(self._label).digest()
+ # Step 3b
+ y = em[0]
+ # y must be 0, but we MUST NOT check it here in order not to
+ # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
+ maskedSeed = em[1:hLen+1]
+ maskedDB = em[hLen+1:]
+ # Step 3c
+ seedMask = self._mgf(maskedDB, hLen)
+ # Step 3d
+ seed = strxor(maskedSeed, seedMask)
+ # Step 3e
+ dbMask = self._mgf(seed, k-hLen-1)
+ # Step 3f
+ db = strxor(maskedDB, dbMask)
+ # Step 3g
+ one_pos = hLen + db[hLen:].find(b'\x01')
+ lHash1 = db[:hLen]
+ invalid = bord(y) | int(one_pos < hLen)
+ hash_compare = strxor(lHash1, lHash)
+ for x in hash_compare:
+ invalid |= bord(x)
+ for x in db[hLen:one_pos]:
+ invalid |= bord(x)
+ if invalid != 0:
+ raise ValueError("Incorrect decryption.")
+ # Step 4
+ return db[one_pos + 1:]
+
+def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
+ """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
+
+ :param key:
+ The key object to use to encrypt or decrypt the message.
+ Decryption is only possible with a private RSA key.
+ :type key: RSA key object
+
+ :param hashAlgo:
+ The hash function to use. This can be a module under `Crypto.Hash`
+ or an existing hash object created from any of such modules.
+ If not specified, `Crypto.Hash.SHA1` is used.
+ :type hashAlgo: hash object
+
+ :param mgfunc:
+ A mask generation function that accepts two parameters: a string to
+ use as seed, and the length of the mask to generate, in bytes.
+ If not specified, the standard MGF1 consistent with ``hashAlgo`` is used (a safe choice).
+ :type mgfunc: callable
+
+ :param label:
+ A label to apply to this particular encryption. If not specified,
+ an empty string is used. Specifying a label does not improve
+ security.
+ :type label: bytes/bytearray/memoryview
+
+ :param randfunc:
+ A function that returns random bytes.
+ The default is `Random.get_random_bytes`.
+ :type randfunc: callable
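+
+ A minimal round-trip sketch (``rsa_key`` is a placeholder for an RSA
+ key pair imported or generated with :mod:`Crypto.PublicKey.RSA`)::
+
+ ct = new(rsa_key.publickey()).encrypt(b'secret')
+ pt = new(rsa_key).decrypt(ct)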
+ """
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+ return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)
+
diff --git a/lib/Crypto/Cipher/PKCS1_OAEP.pyi b/lib/Crypto/Cipher/PKCS1_OAEP.pyi
new file mode 100644
index 0000000..6cb80da
--- /dev/null
+++ b/lib/Crypto/Cipher/PKCS1_OAEP.pyi
@@ -0,0 +1,35 @@
+from typing import Optional, Union, Callable, Any, overload
+from typing_extensions import Protocol
+
+from Crypto.PublicKey.RSA import RsaKey
+
+class HashLikeClass(Protocol):
+ digest_size : int
+ def new(self, data: Optional[bytes] = ...) -> Any: ...
+
+class HashLikeModule(Protocol):
+ digest_size : int
+ @staticmethod
+ def new(data: Optional[bytes] = ...) -> Any: ...
+
+HashLike = Union[HashLikeClass, HashLikeModule]
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class PKCS1OAEP_Cipher:
+ def __init__(self,
+ key: RsaKey,
+ hashAlgo: HashLike,
+ mgfunc: Callable[[bytes, int], bytes],
+ label: Buffer,
+ randfunc: Callable[[int], bytes]) -> None: ...
+ def can_encrypt(self) -> bool: ...
+ def can_decrypt(self) -> bool: ...
+ def encrypt(self, message: Buffer) -> bytes: ...
+ def decrypt(self, ciphertext: Buffer) -> bytes: ...
+
+def new(key: RsaKey,
+ hashAlgo: Optional[HashLike] = ...,
+ mgfunc: Optional[Callable[[bytes, int], bytes]] = ...,
+ label: Optional[Buffer] = ...,
+ randfunc: Optional[Callable[[int], bytes]] = ...) -> PKCS1OAEP_Cipher: ...
diff --git a/lib/Crypto/Cipher/PKCS1_v1_5.py b/lib/Crypto/Cipher/PKCS1_v1_5.py
new file mode 100644
index 0000000..d0d474a
--- /dev/null
+++ b/lib/Crypto/Cipher/PKCS1_v1_5.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/PKCS1-v1_5.py : PKCS#1 v1.5
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+__all__ = ['new', 'PKCS115_Cipher']
+
+from Crypto import Random
+from Crypto.Util.number import bytes_to_long, long_to_bytes
+from Crypto.Util.py3compat import bord, is_bytes, _copy_bytes
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, c_size_t,
+ c_uint8_ptr)
+
+
+_raw_pkcs1_decode = load_pycryptodome_raw_lib("Crypto.Cipher._pkcs1_decode",
+ """
+ int pkcs1_decode(const uint8_t *em, size_t len_em,
+ const uint8_t *sentinel, size_t len_sentinel,
+ size_t expected_pt_len,
+ uint8_t *output);
+ """)
+
+
+def _pkcs1_decode(em, sentinel, expected_pt_len, output):
+ if len(em) != len(output):
+ raise ValueError("Incorrect output length")
+
+ ret = _raw_pkcs1_decode.pkcs1_decode(c_uint8_ptr(em),
+ c_size_t(len(em)),
+ c_uint8_ptr(sentinel),
+ c_size_t(len(sentinel)),
+ c_size_t(expected_pt_len),
+ c_uint8_ptr(output))
+ return ret
+
+
+class PKCS115_Cipher:
+ """This cipher can perform PKCS#1 v1.5 RSA encryption or decryption.
+ Do not instantiate directly. Use :func:`Crypto.Cipher.PKCS1_v1_5.new` instead."""
+
+ def __init__(self, key, randfunc):
+ """Initialize this PKCS#1 v1.5 cipher object.
+
+ :Parameters:
+ key : an RSA key object
+ If a private half is given, both encryption and decryption are possible.
+ If a public half is given, only encryption is possible.
+ randfunc : callable
+ Function that returns random bytes.
+ """
+
+ self._key = key
+ self._randfunc = randfunc
+
+ def can_encrypt(self):
+ """Return True if this cipher object can be used for encryption."""
+ return self._key.can_encrypt()
+
+ def can_decrypt(self):
+ """Return True if this cipher object can be used for decryption."""
+ return self._key.can_decrypt()
+
+ def encrypt(self, message):
+ """Produce the PKCS#1 v1.5 encryption of a message.
+
+ This function is named ``RSAES-PKCS1-V1_5-ENCRYPT``, and it is specified in
+ `section 7.2.1 of RFC8017
+ <https://tools.ietf.org/html/rfc8017#page-28>`_.
+
+ :param message:
+ The message to encrypt, also known as plaintext. It can be of
+ variable length, but not longer than the RSA modulus (in bytes) minus 11.
+ :type message: bytes/bytearray/memoryview
+
+ :Returns: A byte string, the ciphertext in which the message is encrypted.
+ It is as long as the RSA modulus (in bytes).
+
+ :Raises ValueError:
+ If the RSA key length is not sufficiently long to deal with the given
+ message.
+ """
+
+ # See 7.2.1 in RFC8017
+ k = self._key.size_in_bytes()
+ mLen = len(message)
+
+ # Step 1
+ if mLen > k - 11:
+ raise ValueError("Plaintext is too long.")
+ # Step 2a
+ ps = []
+ while len(ps) != k - mLen - 3:
+ new_byte = self._randfunc(1)
+ if bord(new_byte[0]) == 0x00:
+ continue
+ ps.append(new_byte)
+ ps = b"".join(ps)
+ assert(len(ps) == k - mLen - 3)
+ # Step 2b
+ em = b'\x00\x02' + ps + b'\x00' + _copy_bytes(None, None, message)
+ # Step 3a (OS2IP)
+ em_int = bytes_to_long(em)
+ # Step 3b (RSAEP)
+ m_int = self._key._encrypt(em_int)
+ # Step 3c (I2OSP)
+ c = long_to_bytes(m_int, k)
+ return c
+
+ def decrypt(self, ciphertext, sentinel, expected_pt_len=0):
+ r"""Decrypt a PKCS#1 v1.5 ciphertext.
+
+ This is the function ``RSAES-PKCS1-V1_5-DECRYPT`` specified in
+ `section 7.2.2 of RFC8017
+ <https://tools.ietf.org/html/rfc8017#page-29>`_.
+
+ Args:
+ ciphertext (bytes/bytearray/memoryview):
+ The ciphertext that contains the message to recover.
+ sentinel (any type):
+ The object to return whenever an error is detected.
+ expected_pt_len (integer):
+ The length the plaintext is known to have, or 0 if unknown.
+
+ Returns (byte string):
+ It is either the original message or the ``sentinel`` (in case of an error).
+
+ .. warning::
+ PKCS#1 v1.5 decryption is intrinsically vulnerable to timing
+ attacks (see `Bleichenbacher's`__ attack).
+ **Use PKCS#1 OAEP instead**.
+
+ This implementation attempts to mitigate the risk
+ with some constant-time constructs.
+ However, they are not sufficient by themselves: the type of protocol you
+ implement and the way you handle errors make a big difference.
+
+ Specifically, you should make it very hard for the (malicious)
+ party that submitted the ciphertext to quickly understand if decryption
+ succeeded or not.
+
+ To this end, it is recommended that your protocol only encrypts
+ plaintexts of fixed length (``expected_pt_len``),
+ that ``sentinel`` is a random byte string of the same length,
+ and that processing continues for as long
+ as possible even if ``sentinel`` is returned (i.e. in case of
+ incorrect decryption).
+
+ .. __: https://dx.doi.org/10.1007/BFb0055716
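+
+ A sketch that follows the advice above (``priv_key`` is a placeholder
+ for the private RSA key and 16 the fixed plaintext length agreed by
+ the protocol)::
+
+ cipher = new(priv_key)
+ sentinel = Random.get_random_bytes(16)
+ pt = cipher.decrypt(ciphertext, sentinel, expected_pt_len=16)
+ # pt is either the real plaintext or the sentinel; keep processing
+ # both cases identically for as long as possible.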
+ """
+
+ # See 7.2.2 in RFC8017
+ k = self._key.size_in_bytes()
+
+ # Step 1
+ if len(ciphertext) != k:
+ raise ValueError("Ciphertext with incorrect length (not %d bytes)" % k)
+
+ # Step 2a (O2SIP)
+ ct_int = bytes_to_long(ciphertext)
+
+ # Step 2b (RSADP)
+ m_int = self._key._decrypt(ct_int)
+
+ # Complete step 2c (I2OSP)
+ em = long_to_bytes(m_int, k)
+
+ # Step 3 (not constant time when the sentinel is not a byte string)
+ output = bytes(bytearray(k))
+ if not is_bytes(sentinel) or len(sentinel) > k:
+ size = _pkcs1_decode(em, b'', expected_pt_len, output)
+ if size < 0:
+ return sentinel
+ else:
+ return output[size:]
+
+ # Step 3 (somewhat constant time)
+ size = _pkcs1_decode(em, sentinel, expected_pt_len, output)
+ return output[size:]
+
+
+def new(key, randfunc=None):
+ """Create a cipher for performing PKCS#1 v1.5 encryption or decryption.
+
+ :param key:
+ The key to use to encrypt or decrypt the message. This is a `Crypto.PublicKey.RSA` object.
+ Decryption is only possible if *key* is a private RSA key.
+ :type key: RSA key object
+
+ :param randfunc:
+ Function that returns random bytes.
+ The default is :func:`Crypto.Random.get_random_bytes`.
+ :type randfunc: callable
+
+ :returns: A cipher object `PKCS115_Cipher`.
+ """
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+ return PKCS115_Cipher(key, randfunc)
diff --git a/lib/Crypto/Cipher/PKCS1_v1_5.pyi b/lib/Crypto/Cipher/PKCS1_v1_5.pyi
new file mode 100644
index 0000000..1719f01
--- /dev/null
+++ b/lib/Crypto/Cipher/PKCS1_v1_5.pyi
@@ -0,0 +1,20 @@
+from typing import Callable, Union, Any, Optional, TypeVar
+
+from Crypto.PublicKey.RSA import RsaKey
+
+Buffer = Union[bytes, bytearray, memoryview]
+T = TypeVar('T')
+
+class PKCS115_Cipher:
+ def __init__(self,
+ key: RsaKey,
+ randfunc: Callable[[int], bytes]) -> None: ...
+ def can_encrypt(self) -> bool: ...
+ def can_decrypt(self) -> bool: ...
+ def encrypt(self, message: Buffer) -> bytes: ...
+ def decrypt(self, ciphertext: Buffer,
+ sentinel: T,
+ expected_pt_len: Optional[int] = ...) -> Union[bytes, T]: ...
+
+def new(key: RsaKey,
+ randfunc: Optional[Callable[[int], bytes]] = ...) -> PKCS115_Cipher: ...
diff --git a/lib/Crypto/Cipher/Salsa20.py b/lib/Crypto/Cipher/Salsa20.py
new file mode 100644
index 0000000..62d0b29
--- /dev/null
+++ b/lib/Crypto/Cipher/Salsa20.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/Salsa20.py : Salsa20 stream cipher (http://cr.yp.to/snuffle.html)
+#
+# Contributed by Fabrizio Tarizzo <fabrizio@fabriziotarizzo.org>.
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import _copy_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ create_string_buffer,
+ get_raw_buffer, VoidPointer,
+ SmartPointer, c_size_t,
+ c_uint8_ptr, is_writeable_buffer)
+
+from Crypto.Random import get_random_bytes
+
+_raw_salsa20_lib = load_pycryptodome_raw_lib("Crypto.Cipher._Salsa20",
+ """
+ int Salsa20_stream_init(uint8_t *key, size_t keylen,
+ uint8_t *nonce, size_t nonce_len,
+ void **pSalsaState);
+ int Salsa20_stream_destroy(void *salsaState);
+ int Salsa20_stream_encrypt(void *salsaState,
+ const uint8_t in[],
+ uint8_t out[], size_t len);
+ """)
+
+
+class Salsa20Cipher:
+ """Salsa20 cipher object. Do not create it directly. Use :py:func:`new`
+ instead.
+
+ :var nonce: The nonce with length 8
+ :vartype nonce: byte string
+ """
+
+ def __init__(self, key, nonce):
+ """Initialize a Salsa20 cipher object
+
+ See also `new()` at the module level."""
+
+ if len(key) not in key_size:
+ raise ValueError("Incorrect key length for Salsa20 (%d bytes)" % len(key))
+
+ if len(nonce) != 8:
+ raise ValueError("Incorrect nonce length for Salsa20 (%d bytes)" %
+ len(nonce))
+
+ self.nonce = _copy_bytes(None, None, nonce)
+
+ self._state = VoidPointer()
+ result = _raw_salsa20_lib.Salsa20_stream_init(
+ c_uint8_ptr(key),
+ c_size_t(len(key)),
+ c_uint8_ptr(nonce),
+ c_size_t(len(nonce)),
+ self._state.address_of())
+ if result:
+ raise ValueError("Error %d instantiating a Salsa20 cipher")
+ self._state = SmartPointer(self._state.get(),
+ _raw_salsa20_lib.Salsa20_stream_destroy)
+
+ self.block_size = 1
+ self.key_size = len(key)
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt a piece of data.
+
+ Args:
+ plaintext(bytes/bytearray/memoryview): The data to encrypt, of any size.
+ Keyword Args:
+ output(bytes/bytearray/memoryview): The location where the ciphertext
+ is written to. If ``None``, the ciphertext is returned.
+ Returns:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if output is None:
+ ciphertext = create_string_buffer(len(plaintext))
+ else:
+ ciphertext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(plaintext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = _raw_salsa20_lib.Salsa20_stream_encrypt(
+ self._state.get(),
+ c_uint8_ptr(plaintext),
+ c_uint8_ptr(ciphertext),
+ c_size_t(len(plaintext)))
+ if result:
+ raise ValueError("Error %d while encrypting with Salsa20" % result)
+
+ if output is None:
+ return get_raw_buffer(ciphertext)
+ else:
+ return None
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt a piece of data.
+
+ Args:
+ ciphertext(bytes/bytearray/memoryview): The data to decrypt, of any size.
+ Keyword Args:
+ output(bytes/bytearray/memoryview): The location where the plaintext
+ is written to. If ``None``, the plaintext is returned.
+ Returns:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ try:
+ return self.encrypt(ciphertext, output=output)
+ except ValueError as e:
+ raise ValueError(str(e).replace("enc", "dec"))
+
+
+def new(key, nonce=None):
+ """Create a new Salsa20 cipher
+
+ :keyword key: The secret key to use. It must be 16 or 32 bytes long.
+ :type key: bytes/bytearray/memoryview
+
+ :keyword nonce:
+ A value that must never be reused for any other encryption
+ done with this key. It must be 8 bytes long.
+
+ If not provided, a random byte string will be generated (you can read
+ it back via the ``nonce`` attribute of the returned object).
+ :type nonce: bytes/bytearray/memoryview
+
+ :Return: a :class:`Crypto.Cipher.Salsa20.Salsa20Cipher` object
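+
+ A minimal round-trip sketch (``secret`` is a placeholder for a 16- or
+ 32-byte key)::
+
+ cipher = new(secret)
+ ct = cipher.encrypt(b'attack at dawn')
+ pt = new(secret, nonce=cipher.nonce).decrypt(ct)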
+ """
+
+ if nonce is None:
+ nonce = get_random_bytes(8)
+
+ return Salsa20Cipher(key, nonce)
+
+# Size of a data block (in bytes)
+block_size = 1
+
+# Size of a key (in bytes)
+key_size = (16, 32)
+
diff --git a/lib/Crypto/Cipher/Salsa20.pyi b/lib/Crypto/Cipher/Salsa20.pyi
new file mode 100644
index 0000000..9178f0d
--- /dev/null
+++ b/lib/Crypto/Cipher/Salsa20.pyi
@@ -0,0 +1,27 @@
+from typing import Union, Tuple, Optional, overload
+
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class Salsa20Cipher:
+ nonce: bytes
+ block_size: int
+ key_size: int
+
+ def __init__(self,
+ key: Buffer,
+ nonce: Buffer) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, ciphertext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, ciphertext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
+def new(key: Buffer, nonce: Optional[Buffer] = ...) -> Salsa20Cipher: ...
+
+block_size: int
+key_size: Tuple[int, int]
+
diff --git a/lib/Crypto/Cipher/_ARC4.abi3.so b/lib/Crypto/Cipher/_ARC4.abi3.so
new file mode 100755
index 0000000..675eb60
--- /dev/null
+++ b/lib/Crypto/Cipher/_ARC4.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_EKSBlowfish.py b/lib/Crypto/Cipher/_EKSBlowfish.py
new file mode 100644
index 0000000..a844fae
--- /dev/null
+++ b/lib/Crypto/Cipher/_EKSBlowfish.py
@@ -0,0 +1,131 @@
+# ===================================================================
+#
+# Copyright (c) 2019, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import sys
+
+from Crypto.Cipher import _create_cipher
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer, c_size_t,
+ c_uint8_ptr, c_uint)
+
+_raw_blowfish_lib = load_pycryptodome_raw_lib(
+ "Crypto.Cipher._raw_eksblowfish",
+ """
+ int EKSBlowfish_start_operation(const uint8_t key[],
+ size_t key_len,
+ const uint8_t salt[16],
+ size_t salt_len,
+ unsigned cost,
+ unsigned invert,
+ void **pResult);
+ int EKSBlowfish_encrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int EKSBlowfish_decrypt(const void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int EKSBlowfish_stop_operation(void *state);
+ """
+ )
+
+
+def _create_base_cipher(dict_parameters):
+ """This method instantiates and returns a smart pointer to
+ a low-level base cipher. It will absorb named parameters in
+ the process."""
+
+ try:
+ key = dict_parameters.pop("key")
+ salt = dict_parameters.pop("salt")
+ cost = dict_parameters.pop("cost")
+ except KeyError as e:
+ raise TypeError("Missing EKSBlowfish parameter: " + str(e))
+ invert = dict_parameters.pop("invert", True)
+
+ if len(key) not in key_size:
+ raise ValueError("Incorrect EKSBlowfish key length (%d bytes)" % len(key))
+
+ start_operation = _raw_blowfish_lib.EKSBlowfish_start_operation
+ stop_operation = _raw_blowfish_lib.EKSBlowfish_stop_operation
+
+ void_p = VoidPointer()
+ result = start_operation(c_uint8_ptr(key),
+ c_size_t(len(key)),
+ c_uint8_ptr(salt),
+ c_size_t(len(salt)),
+ c_uint(cost),
+ c_uint(int(invert)),
+ void_p.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the EKSBlowfish cipher"
+ % result)
+ return SmartPointer(void_p.get(), stop_operation)
+
+
+def new(key, mode, salt, cost, invert):
+ """Create a new EKSBlowfish cipher
+
+ Args:
+
+ key (bytes, bytearray, memoryview):
+ The secret key to use in the symmetric cipher.
+ Its length can vary from 0 to 72 bytes.
+
+ mode (one of the supported ``MODE_*`` constants):
+ The chaining mode to use for encryption or decryption.
+
+ salt (bytes, bytearray, memoryview):
+ The salt that bcrypt uses to thwart rainbow table attacks
+
+ cost (integer):
+ The complexity factor in bcrypt
+
+ invert (bool):
+ If ``False``, in the inner loop use ``ExpandKey`` first over the salt
+ and then over the key, as defined in
+ the `original bcrypt specification <https://www.usenix.org/legacy/events/usenix99/provos/provos_html/node4.html>`_.
+ If ``True``, reverse the order, as in the first implementation of
+ `bcrypt` in OpenBSD.
+
+ :Return: an EKSBlowfish object
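+
+ A minimal sketch (``password`` and ``salt16`` are placeholders for the
+ bcrypt password bytes, at most 72 bytes, and a 16-byte salt; ``True``
+ selects the OpenBSD-style ordering described above)::
+
+ cipher = new(password, MODE_ECB, salt16, 12, True)
+ ct = cipher.encrypt(b'OrpheanBeholderScryDoubt') # as bcrypt does, repeatedly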
+ """
+
+ kwargs = { 'salt':salt, 'cost':cost, 'invert':invert }
+ return _create_cipher(sys.modules[__name__], key, mode, **kwargs)
+
+
+MODE_ECB = 1
+
+# Size of a data block (in bytes)
+block_size = 8
+# Size of a key (in bytes)
+key_size = range(0, 72 + 1)
diff --git a/lib/Crypto/Cipher/_EKSBlowfish.pyi b/lib/Crypto/Cipher/_EKSBlowfish.pyi
new file mode 100644
index 0000000..95db379
--- /dev/null
+++ b/lib/Crypto/Cipher/_EKSBlowfish.pyi
@@ -0,0 +1,15 @@
+from typing import Union, Iterable
+
+from Crypto.Cipher._mode_ecb import EcbMode
+
+MODE_ECB: int
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ mode: int,
+ salt: Buffer,
+ cost: int,
+ invert: bool) -> EcbMode: ...
+
+block_size: int
+key_size: Iterable[int]
diff --git a/lib/Crypto/Cipher/_Salsa20.abi3.so b/lib/Crypto/Cipher/_Salsa20.abi3.so
new file mode 100755
index 0000000..2b97f5d
--- /dev/null
+++ b/lib/Crypto/Cipher/_Salsa20.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/__init__.py b/lib/Crypto/Cipher/__init__.py
new file mode 100644
index 0000000..ba2d485
--- /dev/null
+++ b/lib/Crypto/Cipher/__init__.py
@@ -0,0 +1,79 @@
+#
+# A block cipher is instantiated as a combination of:
+# 1. A base cipher (such as AES)
+# 2. A mode of operation (such as CBC)
+#
+# Both items are implemented as C modules.
+#
+# The API of #1 is (replace "AES" with the name of the actual cipher):
+# - AES_start_operation(key) --> base_cipher_state
+# - AES_encrypt(base_cipher_state, in, out, length)
+# - AES_decrypt(base_cipher_state, in, out, length)
+# - AES_stop_operation(base_cipher_state)
+#
+# Where base_cipher_state is AES_State, a struct with BlockBase (set of
+# pointers to encrypt/decrypt/stop) followed by cipher-specific data.
+#
+# The API of #2 is (replace "CBC" with the name of the actual mode):
+# - CBC_start_operation(base_cipher_state) --> mode_state
+# - CBC_encrypt(mode_state, in, out, length)
+# - CBC_decrypt(mode_state, in, out, length)
+# - CBC_stop_operation(mode_state)
+#
+# where mode_state is a pointer to base_cipher_state plus mode-specific data.
+
+import os
+
+from Crypto.Cipher._mode_ecb import _create_ecb_cipher
+from Crypto.Cipher._mode_cbc import _create_cbc_cipher
+from Crypto.Cipher._mode_cfb import _create_cfb_cipher
+from Crypto.Cipher._mode_ofb import _create_ofb_cipher
+from Crypto.Cipher._mode_ctr import _create_ctr_cipher
+from Crypto.Cipher._mode_openpgp import _create_openpgp_cipher
+from Crypto.Cipher._mode_ccm import _create_ccm_cipher
+from Crypto.Cipher._mode_eax import _create_eax_cipher
+from Crypto.Cipher._mode_siv import _create_siv_cipher
+from Crypto.Cipher._mode_gcm import _create_gcm_cipher
+from Crypto.Cipher._mode_ocb import _create_ocb_cipher
+
+_modes = { 1:_create_ecb_cipher,
+ 2:_create_cbc_cipher,
+ 3:_create_cfb_cipher,
+ 5:_create_ofb_cipher,
+ 6:_create_ctr_cipher,
+ 7:_create_openpgp_cipher,
+ 9:_create_eax_cipher
+ }
+
+_extra_modes = { 8:_create_ccm_cipher,
+ 10:_create_siv_cipher,
+ 11:_create_gcm_cipher,
+ 12:_create_ocb_cipher
+ }
+
+def _create_cipher(factory, key, mode, *args, **kwargs):
+
+ kwargs["key"] = key
+
+ modes = dict(_modes)
+ if kwargs.pop("add_aes_modes", False):
+ modes.update(_extra_modes)
+ if not mode in modes:
+ raise ValueError("Mode not supported")
+
+ if args:
+ if mode in (8, 9, 10, 11, 12):
+ if len(args) > 1:
+ raise TypeError("Too many arguments for this mode")
+ kwargs["nonce"] = args[0]
+ elif mode in (2, 3, 5, 7):
+ if len(args) > 1:
+ raise TypeError("Too many arguments for this mode")
+ kwargs["IV"] = args[0]
+ elif mode == 6:
+ if len(args) > 0:
+ raise TypeError("Too many arguments for this mode")
+ elif mode == 1:
+ raise TypeError("IV is not meaningful for the ECB mode")
+
+ return modes[mode](factory, **kwargs)
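What the dispatch table above amounts to at the public API level; a small sketch assuming the vendored ``Crypto`` package is importable:

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes

key = get_random_bytes(16)
iv = get_random_bytes(16)

# AES.new() forwards to _create_cipher(<AES module>, key, mode, iv);
# mode 2 (MODE_CBC) routes the positional IV into kwargs["IV"] above.
cipher = AES.new(key, AES.MODE_CBC, iv)
assert cipher.IV == iv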
diff --git a/lib/Crypto/Cipher/__init__.pyi b/lib/Crypto/Cipher/__init__.pyi
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/Crypto/Cipher/__init__.pyi
diff --git a/lib/Crypto/Cipher/_chacha20.abi3.so b/lib/Crypto/Cipher/_chacha20.abi3.so
new file mode 100755
index 0000000..1549d63
--- /dev/null
+++ b/lib/Crypto/Cipher/_chacha20.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_mode_cbc.py b/lib/Crypto/Cipher/_mode_cbc.py
new file mode 100644
index 0000000..79c871a
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_cbc.py
@@ -0,0 +1,293 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+Ciphertext Block Chaining (CBC) mode.
+"""
+
+__all__ = ['CbcMode']
+
+from Crypto.Util.py3compat import _copy_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ create_string_buffer, get_raw_buffer,
+ SmartPointer, c_size_t, c_uint8_ptr,
+ is_writeable_buffer)
+
+from Crypto.Random import get_random_bytes
+
+raw_cbc_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_cbc", """
+ int CBC_start_operation(void *cipher,
+ const uint8_t iv[],
+ size_t iv_len,
+ void **pResult);
+ int CBC_encrypt(void *cbcState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CBC_decrypt(void *cbcState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CBC_stop_operation(void *state);
+ """
+ )
+
+
+class CbcMode(object):
+ """*Cipher-Block Chaining (CBC)*.
+
+ Each of the ciphertext blocks depends on the current
+ and all previous plaintext blocks.
+
+ An Initialization Vector (*IV*) is required.
+
+ See `NIST SP800-38A`_ , Section 6.2 .
+
+ .. _`NIST SP800-38A` : http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, block_cipher, iv):
+ """Create a new block cipher, configured in CBC mode.
+
+ :Parameters:
+ block_cipher : C pointer
+ A smart pointer to the low-level block cipher instance.
+
+ iv : bytes/bytearray/memoryview
+ The initialization vector to use for encryption or decryption.
+ It is as long as the cipher block.
+
+ **The IV must be unpredictable**. Ideally it is picked randomly.
+
+ Reusing the *IV* for encryptions performed with the same key
+ compromises confidentiality.
+ """
+
+ self._state = VoidPointer()
+ result = raw_cbc_lib.CBC_start_operation(block_cipher.get(),
+ c_uint8_ptr(iv),
+ c_size_t(len(iv)),
+ self._state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating the CBC mode"
+ % result)
+
+ # Ensure that object disposal of this Python object will (eventually)
+ # free the memory allocated by the raw library for the cipher mode
+ self._state = SmartPointer(self._state.get(),
+ raw_cbc_lib.CBC_stop_operation)
+
+ # Memory allocated for the underlying block cipher is now owned
+ # by the cipher mode
+ block_cipher.release()
+
+ self.block_size = len(iv)
+ """The block size of the underlying cipher, in bytes."""
+
+ self.iv = _copy_bytes(None, None, iv)
+ """The Initialization Vector originally used to create the object.
+ The value does not change."""
+
+ self.IV = self.iv
+ """Alias for `iv`"""
+
+ self._next = [ self.encrypt, self.decrypt ]
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ That also means that you cannot reuse an object for encrypting
+ or decrypting other data with the same key.
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ Its length must be a multiple of the cipher block size.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() cannot be called after decrypt()")
+ self._next = [ self.encrypt ]
+
+ if output is None:
+ ciphertext = create_string_buffer(len(plaintext))
+ else:
+ ciphertext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(plaintext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_cbc_lib.CBC_encrypt(self._state.get(),
+ c_uint8_ptr(plaintext),
+ c_uint8_ptr(ciphertext),
+ c_size_t(len(plaintext)))
+ if result:
+ if result == 3:
+ raise ValueError("Data must be padded to %d byte boundary in CBC mode" % self.block_size)
+ raise ValueError("Error %d while encrypting in CBC mode" % result)
+
+ if output is None:
+ return get_raw_buffer(ciphertext)
+ else:
+ return None
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ Its length must be a multiple of the cipher block size.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() cannot be called after encrypt()")
+ self._next = [ self.decrypt ]
+
+ if output is None:
+ plaintext = create_string_buffer(len(ciphertext))
+ else:
+ plaintext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(ciphertext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_cbc_lib.CBC_decrypt(self._state.get(),
+ c_uint8_ptr(ciphertext),
+ c_uint8_ptr(plaintext),
+ c_size_t(len(ciphertext)))
+ if result:
+ if result == 3:
+ raise ValueError("Data must be padded to %d byte boundary in CBC mode" % self.block_size)
+ raise ValueError("Error %d while decrypting in CBC mode" % result)
+
+ if output is None:
+ return get_raw_buffer(plaintext)
+ else:
+ return None
+
+
+def _create_cbc_cipher(factory, **kwargs):
+ """Instantiate a cipher object that performs CBC encryption/decryption.
+
+ :Parameters:
+ factory : module
+ The underlying block cipher, a module from ``Crypto.Cipher``.
+
+ :Keywords:
+ iv : bytes/bytearray/memoryview
+ The IV to use for CBC.
+
+ IV : bytes/bytearray/memoryview
+ Alias for ``iv``.
+
+ Any other keyword will be passed to the underlying block cipher.
+ See the relevant documentation for details (at least ``key`` will need
+ to be present).
+ """
+
+ cipher_state = factory._create_base_cipher(kwargs)
+ iv = kwargs.pop("IV", None)
+ IV = kwargs.pop("iv", None)
+
+ if (None, None) == (iv, IV):
+ iv = get_random_bytes(factory.block_size)
+ if iv is not None:
+ if IV is not None:
+ raise TypeError("You must either use 'iv' or 'IV', not both")
+ else:
+ iv = IV
+
+ if len(iv) != factory.block_size:
+ raise ValueError("Incorrect IV length (it must be %d bytes long)" %
+ factory.block_size)
+
+ if kwargs:
+ raise TypeError("Unknown parameters for CBC: %s" % str(kwargs))
+
+ return CbcMode(cipher_state, iv)
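A minimal usage sketch for the CBC mode defined above, driven through the public ``AES`` module; ``pad``/``unpad`` are the PKCS#7 helpers from ``Crypto.Util.Padding`` (assumed available in this tree):

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad

key = get_random_bytes(16)
data = pad(b"attack at dawn, attack at dusk", AES.block_size)  # full 16-byte blocks

enc = AES.new(key, AES.MODE_CBC)                 # IV is generated randomly
ct = enc.encrypt(data[:16]) + enc.encrypt(data[16:])   # chunked calls == one call

dec = AES.new(key, AES.MODE_CBC, iv=enc.iv)      # a fresh object for decryption
assert unpad(dec.decrypt(ct), AES.block_size) == b"attack at dawn, attack at dusk"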
diff --git a/lib/Crypto/Cipher/_mode_cbc.pyi b/lib/Crypto/Cipher/_mode_cbc.pyi
new file mode 100644
index 0000000..8b9fb16
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_cbc.pyi
@@ -0,0 +1,25 @@
+from typing import Union, overload
+
+from Crypto.Util._raw_api import SmartPointer
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['CbcMode']
+
+class CbcMode(object):
+ block_size: int
+ iv: Buffer
+ IV: Buffer
+
+ def __init__(self,
+ block_cipher: SmartPointer,
+ iv: Buffer) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
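A short sketch of the ``output`` overloads declared above: the caller supplies a pre-allocated writeable buffer and the methods return ``None`` instead of ``bytes`` (same assumptions as the previous sketch):

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes

key = get_random_bytes(16)
iv = get_random_bytes(16)
pt = bytes(32)                                   # two zero-filled blocks
ct = bytearray(len(pt))                          # pre-allocated destination

AES.new(key, AES.MODE_CBC, iv).encrypt(pt, output=ct)   # fills ct, returns None
out = bytearray(len(pt))
AES.new(key, AES.MODE_CBC, iv).decrypt(bytes(ct), output=out)
assert bytes(out) == pt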
diff --git a/lib/Crypto/Cipher/_mode_ccm.py b/lib/Crypto/Cipher/_mode_ccm.py
new file mode 100644
index 0000000..64077de
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ccm.py
@@ -0,0 +1,650 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+Counter with CBC-MAC (CCM) mode.
+"""
+
+__all__ = ['CcmMode']
+
+import struct
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import (byte_string, bord,
+ _copy_bytes)
+from Crypto.Util._raw_api import is_writeable_buffer
+
+from Crypto.Util.strxor import strxor
+from Crypto.Util.number import long_to_bytes
+
+from Crypto.Hash import BLAKE2s
+from Crypto.Random import get_random_bytes
+
+
+def enum(**enums):
+ return type('Enum', (), enums)
+
+MacStatus = enum(NOT_STARTED=0, PROCESSING_AUTH_DATA=1, PROCESSING_PLAINTEXT=2)
+
+
+class CcmMode(object):
+ """Counter with CBC-MAC (CCM).
+
+ This is an Authenticated Encryption with Associated Data (`AEAD`_) mode.
+ It provides both confidentiality and authenticity.
+
+ The header of the message may be left in the clear, if needed, and it will
+ still be subject to authentication. The decryption step tells the receiver
+ if the message comes from a source that really knows the secret key.
+ Additionally, decryption detects if any part of the message - including the
+ header - has been modified or corrupted.
+
+ This mode requires a nonce. The nonce shall never repeat for two
+ different messages encrypted with the same key, but it does not need
+ to be random.
+ Note that there is a trade-off between the size of the nonce and the
+ maximum size of a single message you can encrypt.
+
+ It is important to use a large nonce if the key is reused across several
+ messages and the nonce is chosen randomly.
+
+ It is acceptable to use a short nonce if the key is only used a few times or
+ if the nonce is taken from a counter.
+
+ The following table shows the trade-off when the nonce is chosen at
+ random. The column on the left shows how many messages it takes
+ for the keystream to repeat **on average**. In practice, you will want to
+ stop using the key way before that.
+
+ +--------------------+---------------+-------------------+
+ | Avg. # of messages | nonce | Max. message |
+ | before keystream | size | size |
+ | repeats | (bytes) | (bytes) |
+ +====================+===============+===================+
+ | 2^52 | 13 | 64K |
+ +--------------------+---------------+-------------------+
+ | 2^48 | 12 | 16M |
+ +--------------------+---------------+-------------------+
+ | 2^44 | 11 | 4G |
+ +--------------------+---------------+-------------------+
+ | 2^40 | 10 | 1T |
+ +--------------------+---------------+-------------------+
+ | 2^36 | 9 | 64P |
+ +--------------------+---------------+-------------------+
+ | 2^32 | 8 | 16E |
+ +--------------------+---------------+-------------------+
+
+ This mode is only available for ciphers that operate on 128-bit blocks
+ (e.g. AES but not TDES).
+
+ See `NIST SP800-38C`_ or RFC3610_.
+
+ .. _`NIST SP800-38C`: http://csrc.nist.gov/publications/nistpubs/800-38C/SP800-38C.pdf
+ .. _RFC3610: https://tools.ietf.org/html/rfc3610
+ .. _AEAD: http://blog.cryptographyengineering.com/2012/05/how-to-choose-authenticated-encryption.html
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, factory, key, nonce, mac_len, msg_len, assoc_len,
+ cipher_params):
+
+ self.block_size = factory.block_size
+ """The block size of the underlying cipher, in bytes."""
+
+ self.nonce = _copy_bytes(None, None, nonce)
+ """The nonce used for this cipher instance"""
+
+ self._factory = factory
+ self._key = _copy_bytes(None, None, key)
+ self._mac_len = mac_len
+ self._msg_len = msg_len
+ self._assoc_len = assoc_len
+ self._cipher_params = cipher_params
+
+ self._mac_tag = None # Cache for MAC tag
+
+ if self.block_size != 16:
+ raise ValueError("CCM mode is only available for ciphers"
+ " that operate on 128 bits blocks")
+
+ # MAC tag length (Tlen)
+ if mac_len not in (4, 6, 8, 10, 12, 14, 16):
+ raise ValueError("Parameter 'mac_len' must be even"
+ " and in the range 4..16 (not %d)" % mac_len)
+
+ # Nonce value
+ if not (nonce and 7 <= len(nonce) <= 13):
+ raise ValueError("Length of parameter 'nonce' must be"
+ " in the range 7..13 bytes")
+
+ # Create MAC object (the tag will be the last block
+ # bytes worth of ciphertext)
+ self._mac = self._factory.new(key,
+ factory.MODE_CBC,
+ iv=b'\x00' * 16,
+ **cipher_params)
+ self._mac_status = MacStatus.NOT_STARTED
+ self._t = None
+
+ # Allowed transitions after initialization
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ # Cumulative lengths
+ self._cumul_assoc_len = 0
+ self._cumul_msg_len = 0
+
+ # Cache for unaligned associated data/plaintext.
+ # This is a list with byte strings, but when the MAC starts,
+ # it will become a binary string no longer than the block size.
+ self._cache = []
+
+ # Start CTR cipher, by formatting the counter (A.3)
+ q = 15 - len(nonce) # length of Q, the encoded message length
+ self._cipher = self._factory.new(key,
+ self._factory.MODE_CTR,
+ nonce=struct.pack("B", q - 1) + self.nonce,
+ **cipher_params)
+
+ # S_0, step 6 in 6.1 for j=0
+ self._s_0 = self._cipher.encrypt(b'\x00' * 16)
+
+ # Try to start the MAC
+ if None not in (assoc_len, msg_len):
+ self._start_mac()
+
+ def _start_mac(self):
+
+ assert(self._mac_status == MacStatus.NOT_STARTED)
+ assert(None not in (self._assoc_len, self._msg_len))
+ assert(isinstance(self._cache, list))
+
+ # Formatting control information and nonce (A.2.1)
+ q = 15 - len(self.nonce) # length of Q, the encoded message length
+ flags = (64 * (self._assoc_len > 0) + 8 * ((self._mac_len - 2) // 2) +
+ (q - 1))
+ b_0 = struct.pack("B", flags) + self.nonce + long_to_bytes(self._msg_len, q)
+
+ # Formatting associated data (A.2.2)
+ # Encoded 'a' is concatenated with the associated data 'A'
+ assoc_len_encoded = b''
+ if self._assoc_len > 0:
+ if self._assoc_len < (2 ** 16 - 2 ** 8):
+ enc_size = 2
+ elif self._assoc_len < (2 ** 32):
+ assoc_len_encoded = b'\xFF\xFE'
+ enc_size = 4
+ else:
+ assoc_len_encoded = b'\xFF\xFF'
+ enc_size = 8
+ assoc_len_encoded += long_to_bytes(self._assoc_len, enc_size)
+
+ # b_0 and assoc_len_encoded must be processed first
+ self._cache.insert(0, b_0)
+ self._cache.insert(1, assoc_len_encoded)
+
+ # Process all the data cached so far
+ first_data_to_mac = b"".join(self._cache)
+ self._cache = b""
+ self._mac_status = MacStatus.PROCESSING_AUTH_DATA
+ self._update(first_data_to_mac)
+
+ def _pad_cache_and_update(self):
+
+ assert(self._mac_status != MacStatus.NOT_STARTED)
+ assert(len(self._cache) < self.block_size)
+
+ # Associated data is concatenated with the least number
+ # of zero bytes (possibly none) to reach alignment to
+ # the 16 byte boundary (A.2.3)
+ len_cache = len(self._cache)
+ if len_cache > 0:
+ self._update(b'\x00' * (self.block_size - len_cache))
+
+ def update(self, assoc_data):
+ """Protect associated data
+
+ If there is any associated data, the caller has to invoke
+ this function one or more times, before using
+ ``decrypt`` or ``encrypt``.
+
+ By *associated data* it is meant any data (e.g. packet headers) that
+ will not be encrypted and will be transmitted in the clear.
+ However, the receiver is still able to detect any modification to it.
+ In CCM, the *associated data* is also called
+ *additional authenticated data* (AAD).
+
+ If there is no associated data, this method must not be called.
+
+ The caller may split associated data in segments of any size, and
+ invoke this method multiple times, each time with the next segment.
+
+ :Parameters:
+ assoc_data : bytes/bytearray/memoryview
+ A piece of associated data. There are no restrictions on its size.
+ """
+
+ if self.update not in self._next:
+ raise TypeError("update() can only be called"
+ " immediately after initialization")
+
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ self._cumul_assoc_len += len(assoc_data)
+ if self._assoc_len is not None and \
+ self._cumul_assoc_len > self._assoc_len:
+ raise ValueError("Associated data is too long")
+
+ self._update(assoc_data)
+ return self
+
+ def _update(self, assoc_data_pt=b""):
+ """Update the MAC with associated data or plaintext
+ (without FSM checks)"""
+
+ # If MAC has not started yet, we just park the data into a list.
+ # If the data is mutable, we create a copy and store that instead.
+ if self._mac_status == MacStatus.NOT_STARTED:
+ if is_writeable_buffer(assoc_data_pt):
+ assoc_data_pt = _copy_bytes(None, None, assoc_data_pt)
+ self._cache.append(assoc_data_pt)
+ return
+
+ assert(len(self._cache) < self.block_size)
+
+ if len(self._cache) > 0:
+ filler = min(self.block_size - len(self._cache),
+ len(assoc_data_pt))
+ self._cache += _copy_bytes(None, filler, assoc_data_pt)
+ assoc_data_pt = _copy_bytes(filler, None, assoc_data_pt)
+
+ if len(self._cache) < self.block_size:
+ return
+
+ # The cache is exactly one block
+ self._t = self._mac.encrypt(self._cache)
+ self._cache = b""
+
+ update_len = len(assoc_data_pt) // self.block_size * self.block_size
+ self._cache = _copy_bytes(update_len, None, assoc_data_pt)
+ if update_len > 0:
+ self._t = self._mac.encrypt(assoc_data_pt[:update_len])[-16:]
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ This method can be called only **once** if ``msg_len`` was
+ not passed at initialization.
+
+ If ``msg_len`` was given, the data to encrypt can be broken
+ up in two or more pieces and `encrypt` can be called
+ multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() can only be called after"
+ " initialization or an update()")
+ self._next = [self.encrypt, self.digest]
+
+ # No more associated data allowed from now
+ if self._assoc_len is None:
+ assert(isinstance(self._cache, list))
+ self._assoc_len = sum([len(x) for x in self._cache])
+ if self._msg_len is not None:
+ self._start_mac()
+ else:
+ if self._cumul_assoc_len < self._assoc_len:
+ raise ValueError("Associated data is too short")
+
+ # Only one piece of plaintext is accepted if the message length
+ # was not declared in advance
+ if self._msg_len is None:
+ self._msg_len = len(plaintext)
+ self._start_mac()
+ self._next = [self.digest]
+
+ self._cumul_msg_len += len(plaintext)
+ if self._cumul_msg_len > self._msg_len:
+ raise ValueError("Message is too long")
+
+ if self._mac_status == MacStatus.PROCESSING_AUTH_DATA:
+ # Associated data is concatenated with the least number
+ # of zero bytes (possibly none) to reach alignment to
+ # the 16 byte boundary (A.2.3)
+ self._pad_cache_and_update()
+ self._mac_status = MacStatus.PROCESSING_PLAINTEXT
+
+ self._update(plaintext)
+ return self._cipher.encrypt(plaintext, output=output)
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ This method can be called only **once** if ``msg_len`` was
+ not passed at initialization.
+
+ If ``msg_len`` was given, the data to decrypt can be
+ broken up in two or more pieces and `decrypt` can be
+ called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() can only be called"
+ " after initialization or an update()")
+ self._next = [self.decrypt, self.verify]
+
+ # No more associated data allowed from now
+ if self._assoc_len is None:
+ assert(isinstance(self._cache, list))
+ self._assoc_len = sum([len(x) for x in self._cache])
+ if self._msg_len is not None:
+ self._start_mac()
+ else:
+ if self._cumul_assoc_len < self._assoc_len:
+ raise ValueError("Associated data is too short")
+
+ # Only one piece of ciphertext is accepted if the message length
+ # was not declared in advance
+ if self._msg_len is None:
+ self._msg_len = len(ciphertext)
+ self._start_mac()
+ self._next = [self.verify]
+
+ self._cumul_msg_len += len(ciphertext)
+ if self._cumul_msg_len > self._msg_len:
+ raise ValueError("Message is too long")
+
+ if self._mac_status == MacStatus.PROCESSING_AUTH_DATA:
+ # Associated data is concatenated with the least number
+ # of zero bytes (possibly none) to reach alignment to
+ # the 16 byte boundary (A.2.3)
+ self._pad_cache_and_update()
+ self._mac_status = MacStatus.PROCESSING_PLAINTEXT
+
+ # Encrypt is equivalent to decrypt with the CTR mode
+ plaintext = self._cipher.encrypt(ciphertext, output=output)
+ if output is None:
+ self._update(plaintext)
+ else:
+ self._update(output)
+ return plaintext
+
+ def digest(self):
+ """Compute the *binary* MAC tag.
+
+ The caller invokes this function at the very end.
+
+ This method returns the MAC that shall be sent to the receiver,
+ together with the ciphertext.
+
+ :Return: the MAC, as a byte string.
+ """
+
+ if self.digest not in self._next:
+ raise TypeError("digest() cannot be called when decrypting"
+ " or validating a message")
+ self._next = [self.digest]
+ return self._digest()
+
+ def _digest(self):
+ if self._mac_tag:
+ return self._mac_tag
+
+ if self._assoc_len is None:
+ assert(isinstance(self._cache, list))
+ self._assoc_len = sum([len(x) for x in self._cache])
+ if self._msg_len is not None:
+ self._start_mac()
+ else:
+ if self._cumul_assoc_len < self._assoc_len:
+ raise ValueError("Associated data is too short")
+
+ if self._msg_len is None:
+ self._msg_len = 0
+ self._start_mac()
+
+ if self._cumul_msg_len != self._msg_len:
+ raise ValueError("Message is too short")
+
+ # Both associated data and payload are concatenated with the least
+ # number of zero bytes (possibly none) that align it to the
+ # 16 byte boundary (A.2.2 and A.2.3)
+ self._pad_cache_and_update()
+
+ # Step 8 in 6.1 (T xor MSB_Tlen(S_0))
+ self._mac_tag = strxor(self._t, self._s_0)[:self._mac_len]
+
+ return self._mac_tag
+
+ def hexdigest(self):
+ """Compute the *printable* MAC tag.
+
+ This method is like `digest`.
+
+ :Return: the MAC, as a hexadecimal string.
+ """
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def verify(self, received_mac_tag):
+ """Validate the *binary* MAC tag.
+
+ The caller invokes this function at the very end.
+
+ This method checks if the decrypted message is indeed valid
+ (that is, if the key is correct) and it has not been
+ tampered with while in transit.
+
+ :Parameters:
+ received_mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ if self.verify not in self._next:
+ raise TypeError("verify() cannot be called"
+ " when encrypting a message")
+ self._next = [self.verify]
+
+ self._digest()
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=self._mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=received_mac_tag)
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Validate the *printable* MAC tag.
+
+ This method is like `verify`.
+
+ :Parameters:
+ hex_mac_tag : string
+ This is the *printable* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ self.verify(unhexlify(hex_mac_tag))
+
+ def encrypt_and_digest(self, plaintext, output=None):
+ """Perform encrypt() and digest() in one step.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ a tuple with two items:
+
+ - the ciphertext, as ``bytes``
+ - the MAC tag, as ``bytes``
+
+ The first item becomes ``None`` when the ``output`` parameter
+ specified a location for the result.
+ """
+
+ return self.encrypt(plaintext, output=output), self.digest()
+
+ def decrypt_and_verify(self, ciphertext, received_mac_tag, output=None):
+ """Perform decrypt() and verify() in one step.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ received_mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return: the plaintext as ``bytes`` or ``None`` when the ``output``
+ parameter specified a location for the result.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ plaintext = self.decrypt(ciphertext, output=output)
+ self.verify(received_mac_tag)
+ return plaintext
+
+
+def _create_ccm_cipher(factory, **kwargs):
+ """Create a new block cipher, configured in CCM mode.
+
+ :Parameters:
+ factory : module
+ A symmetric cipher module from `Crypto.Cipher` (like
+ `Crypto.Cipher.AES`).
+
+ :Keywords:
+ key : bytes/bytearray/memoryview
+ The secret key to use in the symmetric cipher.
+
+ nonce : bytes/bytearray/memoryview
+ A value that must never be reused for any other encryption.
+
+ Its length must be in the range ``[7..13]``.
+ 11 or 12 bytes are reasonable values in general. Bear in
+ mind that with CCM there is a trade-off between nonce length and
+ maximum message size.
+
+ If not specified, an 11-byte random string is used.
+
+ mac_len : integer
+ Length of the MAC, in bytes. It must be even and in
+ the range ``[4..16]``. The default is 16.
+
+ msg_len : integer
+ Length of the message to (de)cipher.
+ If not specified, ``encrypt`` or ``decrypt`` may only be called once.
+
+ assoc_len : integer
+ Length of the associated data.
+ If not specified, all data is internally buffered.
+ """
+
+ try:
+ key = kwargs.pop("key")
+ except KeyError as e:
+ raise TypeError("Missing parameter: " + str(e))
+
+ nonce = kwargs.pop("nonce", None) # N
+ if nonce is None:
+ nonce = get_random_bytes(11)
+ mac_len = kwargs.pop("mac_len", factory.block_size)
+ msg_len = kwargs.pop("msg_len", None) # p
+ assoc_len = kwargs.pop("assoc_len", None) # a
+ cipher_params = dict(kwargs)
+
+ return CcmMode(factory, key, nonce, mac_len, msg_len,
+ assoc_len, cipher_params)
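An end-to-end sketch of the CCM mode above via the public ``AES`` module. With the default 11-byte random nonce, q = 15 - 11 = 4, so a single message may be at most about 2**32 bytes (the 4G row of the table in the docstring):

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes

key = get_random_bytes(16)
header = b"packet-header"                        # authenticated, not encrypted (AAD)
payload = b"secret payload"

enc = AES.new(key, AES.MODE_CCM)                 # random 11-byte nonce, 16-byte tag
nonce = enc.nonce
enc.update(header)
ciphertext, tag = enc.encrypt_and_digest(payload)

dec = AES.new(key, AES.MODE_CCM, nonce=nonce)
dec.update(header)
# Raises ValueError if the header, ciphertext or tag were tampered with
assert dec.decrypt_and_verify(ciphertext, tag) == payload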
diff --git a/lib/Crypto/Cipher/_mode_ccm.pyi b/lib/Crypto/Cipher/_mode_ccm.pyi
new file mode 100644
index 0000000..4b9f620
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ccm.pyi
@@ -0,0 +1,47 @@
+from types import ModuleType
+from typing import Union, overload, Dict, Tuple, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['CcmMode']
+
+class CcmMode(object):
+ block_size: int
+ nonce: bytes
+
+ def __init__(self,
+ factory: ModuleType,
+ key: Buffer,
+ nonce: Buffer,
+ mac_len: int,
+ msg_len: int,
+ assoc_len: int,
+ cipher_params: Dict) -> None: ...
+
+ def update(self, assoc_data: Buffer) -> CcmMode: ...
+
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, received_mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer) -> Tuple[bytes, bytes]: ...
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer,
+ output: Buffer) -> Tuple[None, bytes]: ...
+ def decrypt_and_verify(self,
+ ciphertext: Buffer,
+ received_mac_tag: Buffer,
+ output: Optional[Union[bytearray, memoryview]] = ...) -> bytes: ...
diff --git a/lib/Crypto/Cipher/_mode_cfb.py b/lib/Crypto/Cipher/_mode_cfb.py
new file mode 100644
index 0000000..b3ee1c7
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_cfb.py
@@ -0,0 +1,293 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/mode_cfb.py : CFB mode
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""
+Cipher Feedback (CFB) mode.
+"""
+
+__all__ = ['CfbMode']
+
+from Crypto.Util.py3compat import _copy_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ create_string_buffer, get_raw_buffer,
+ SmartPointer, c_size_t, c_uint8_ptr,
+ is_writeable_buffer)
+
+from Crypto.Random import get_random_bytes
+
+raw_cfb_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_cfb","""
+ int CFB_start_operation(void *cipher,
+ const uint8_t iv[],
+ size_t iv_len,
+ size_t segment_len, /* In bytes */
+ void **pResult);
+ int CFB_encrypt(void *cfbState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CFB_decrypt(void *cfbState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CFB_stop_operation(void *state);"""
+ )
+
+
+class CfbMode(object):
+ """*Cipher FeedBack (CFB)*.
+
+ This mode is similar to CBC, but it transforms
+ the underlying block cipher into a stream cipher.
+
+ Plaintext and ciphertext are processed in *segments*
+ of **s** bits. The mode is therefore sometimes
+ labelled **s**-bit CFB.
+
+ An Initialization Vector (*IV*) is required.
+
+ See `NIST SP800-38A`_ , Section 6.3.
+
+ .. _`NIST SP800-38A` : http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, block_cipher, iv, segment_size):
+ """Create a new block cipher, configured in CFB mode.
+
+ :Parameters:
+ block_cipher : C pointer
+ A smart pointer to the low-level block cipher instance.
+
+ iv : bytes/bytearray/memoryview
+ The initialization vector to use for encryption or decryption.
+ It is as long as the cipher block.
+
+ **The IV must be unpredictable**. Ideally it is picked randomly.
+
+ Reusing the *IV* for encryptions performed with the same key
+ compromises confidentiality.
+
+ segment_size : integer
+ The number of bytes the plaintext and ciphertext are segmented in.
+ """
+
+ self._state = VoidPointer()
+ result = raw_cfb_lib.CFB_start_operation(block_cipher.get(),
+ c_uint8_ptr(iv),
+ c_size_t(len(iv)),
+ c_size_t(segment_size),
+ self._state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating the CFB mode" % result)
+
+ # Ensure that object disposal of this Python object will (eventually)
+ # free the memory allocated by the raw library for the cipher mode
+ self._state = SmartPointer(self._state.get(),
+ raw_cfb_lib.CFB_stop_operation)
+
+ # Memory allocated for the underlying block cipher is now owned
+ # by the cipher mode
+ block_cipher.release()
+
+ self.block_size = len(iv)
+ """The block size of the underlying cipher, in bytes."""
+
+ self.iv = _copy_bytes(None, None, iv)
+ """The Initialization Vector originally used to create the object.
+ The value does not change."""
+
+ self.IV = self.iv
+ """Alias for `iv`"""
+
+ self._next = [ self.encrypt, self.decrypt ]
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() cannot be called after decrypt()")
+ self._next = [ self.encrypt ]
+
+ if output is None:
+ ciphertext = create_string_buffer(len(plaintext))
+ else:
+ ciphertext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(plaintext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_cfb_lib.CFB_encrypt(self._state.get(),
+ c_uint8_ptr(plaintext),
+ c_uint8_ptr(ciphertext),
+ c_size_t(len(plaintext)))
+ if result:
+ raise ValueError("Error %d while encrypting in CFB mode" % result)
+
+ if output is None:
+ return get_raw_buffer(ciphertext)
+ else:
+ return None
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() cannot be called after encrypt()")
+ self._next = [ self.decrypt ]
+
+ if output is None:
+ plaintext = create_string_buffer(len(ciphertext))
+ else:
+ plaintext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(ciphertext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_cfb_lib.CFB_decrypt(self._state.get(),
+ c_uint8_ptr(ciphertext),
+ c_uint8_ptr(plaintext),
+ c_size_t(len(ciphertext)))
+ if result:
+ raise ValueError("Error %d while decrypting in CFB mode" % result)
+
+ if output is None:
+ return get_raw_buffer(plaintext)
+ else:
+ return None
+
+
+def _create_cfb_cipher(factory, **kwargs):
+ """Instantiate a cipher object that performs CFB encryption/decryption.
+
+ :Parameters:
+ factory : module
+ The underlying block cipher, a module from ``Crypto.Cipher``.
+
+ :Keywords:
+ iv : bytes/bytearray/memoryview
+ The IV to use for CFB.
+
+ IV : bytes/bytearray/memoryview
+ Alias for ``iv``.
+
+ segment_size : integer
+ The number of bits the plaintext and ciphertext are segmented in.
+ If not present, the default is 8.
+
+ Any other keyword will be passed to the underlying block cipher.
+ See the relevant documentation for details (at least ``key`` will need
+ to be present).
+ """
+
+ cipher_state = factory._create_base_cipher(kwargs)
+
+ iv = kwargs.pop("IV", None)
+ IV = kwargs.pop("iv", None)
+
+ if (None, None) == (iv, IV):
+ iv = get_random_bytes(factory.block_size)
+ if iv is not None:
+ if IV is not None:
+ raise TypeError("You must either use 'iv' or 'IV', not both")
+ else:
+ iv = IV
+
+ if len(iv) != factory.block_size:
+ raise ValueError("Incorrect IV length (it must be %d bytes long)" %
+ factory.block_size)
+
+ segment_size_bytes, rem = divmod(kwargs.pop("segment_size", 8), 8)
+ if segment_size_bytes == 0 or rem != 0:
+ raise ValueError("'segment_size' must be positive and multiple of 8 bits")
+
+ if kwargs:
+ raise TypeError("Unknown parameters for CFB: %s" % str(kwargs))
+ return CfbMode(cipher_state, iv, segment_size_bytes)
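A usage sketch for the CFB mode above: it turns the block cipher into a stream cipher, so no padding is required and messages may be of any length (``segment_size`` defaults to 8 bits):

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes

key = get_random_bytes(16)
iv = get_random_bytes(16)

enc = AES.new(key, AES.MODE_CFB, iv=iv)          # segment_size defaults to 8 bits
ct = enc.encrypt(b"CFB needs no padding")

dec = AES.new(key, AES.MODE_CFB, iv=iv)
assert dec.decrypt(ct) == b"CFB needs no padding"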
diff --git a/lib/Crypto/Cipher/_mode_cfb.pyi b/lib/Crypto/Cipher/_mode_cfb.pyi
new file mode 100644
index 0000000..e13a909
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_cfb.pyi
@@ -0,0 +1,26 @@
+from typing import Union, overload
+
+from Crypto.Util._raw_api import SmartPointer
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['CfbMode']
+
+
+class CfbMode(object):
+ block_size: int
+ iv: Buffer
+ IV: Buffer
+
+ def __init__(self,
+ block_cipher: SmartPointer,
+ iv: Buffer,
+ segment_size: int) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
diff --git a/lib/Crypto/Cipher/_mode_ctr.py b/lib/Crypto/Cipher/_mode_ctr.py
new file mode 100644
index 0000000..15c7e83
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ctr.py
@@ -0,0 +1,393 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/mode_ctr.py : CTR mode
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""
+Counter (CTR) mode.
+"""
+
+__all__ = ['CtrMode']
+
+import struct
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ create_string_buffer, get_raw_buffer,
+ SmartPointer, c_size_t, c_uint8_ptr,
+ is_writeable_buffer)
+
+from Crypto.Random import get_random_bytes
+from Crypto.Util.py3compat import _copy_bytes, is_native_int
+from Crypto.Util.number import long_to_bytes
+
+raw_ctr_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_ctr", """
+ int CTR_start_operation(void *cipher,
+ uint8_t initialCounterBlock[],
+ size_t initialCounterBlock_len,
+ size_t prefix_len,
+ unsigned counter_len,
+ unsigned littleEndian,
+ void **pResult);
+ int CTR_encrypt(void *ctrState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CTR_decrypt(void *ctrState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int CTR_stop_operation(void *ctrState);"""
+ )
+
+
+class CtrMode(object):
+ """*CounTeR (CTR)* mode.
+
+ This mode is very similar to ECB, in that
+ encryption of one block is done independently of all other blocks.
+
+ Unlike ECB, the block *position* contributes to the encryption
+ and no information leaks about symbol frequency.
+
+ Each message block is associated to a *counter* which
+ must be unique across all messages that get encrypted
+ with the same key (not just within the same message).
+ The counter is as big as the block size.
+
+ Counters can be generated in several ways. The most
+ straightforward one is to choose an *initial counter block*
+ (which can be made public, similarly to the *IV* for the
+ other modes) and increment its lowest **m** bits by one
+ (modulo *2^m*) for each block. In most cases, **m** is
+ chosen to be half the block size.
+
+ See `NIST SP800-38A`_, Section 6.5 (for the mode) and
+ Appendix B (for how to manage the *initial counter block*).
+
+ .. _`NIST SP800-38A` : http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, block_cipher, initial_counter_block,
+ prefix_len, counter_len, little_endian):
+ """Create a new block cipher, configured in CTR mode.
+
+ :Parameters:
+ block_cipher : C pointer
+ A smart pointer to the low-level block cipher instance.
+
+ initial_counter_block : bytes/bytearray/memoryview
+ The initial plaintext to use to generate the key stream.
+
+ It is as large as the cipher block, and it embeds
+ the initial value of the counter.
+
+ This value must not be reused.
+ It shall contain a nonce or a random component.
+ Reusing the *initial counter block* for encryptions
+ performed with the same key compromises confidentiality.
+
+ prefix_len : integer
+ The amount of bytes at the beginning of the counter block
+ that never change.
+
+ counter_len : integer
+ The length in bytes of the counter embedded in the counter
+ block.
+
+ little_endian : boolean
+ True if the counter in the counter block is an integer encoded
+ in little endian mode. If False, it is big endian.
+ """
+
+ if len(initial_counter_block) == prefix_len + counter_len:
+ self.nonce = _copy_bytes(None, prefix_len, initial_counter_block)
+ """Nonce; not available if there is a fixed suffix"""
+
+ self._state = VoidPointer()
+ result = raw_ctr_lib.CTR_start_operation(block_cipher.get(),
+ c_uint8_ptr(initial_counter_block),
+ c_size_t(len(initial_counter_block)),
+ c_size_t(prefix_len),
+ counter_len,
+ little_endian,
+ self._state.address_of())
+ if result:
+ raise ValueError("Error %X while instantiating the CTR mode"
+ % result)
+
+ # Ensure that object disposal of this Python object will (eventually)
+ # free the memory allocated by the raw library for the cipher mode
+ self._state = SmartPointer(self._state.get(),
+ raw_ctr_lib.CTR_stop_operation)
+
+ # Memory allocated for the underlying block cipher is now owned
+ # by the cipher mode
+ block_cipher.release()
+
+ self.block_size = len(initial_counter_block)
+ """The block size of the underlying cipher, in bytes."""
+
+ self._next = [self.encrypt, self.decrypt]
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() cannot be called after decrypt()")
+ self._next = [self.encrypt]
+
+ if output is None:
+ ciphertext = create_string_buffer(len(plaintext))
+ else:
+ ciphertext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(plaintext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_ctr_lib.CTR_encrypt(self._state.get(),
+ c_uint8_ptr(plaintext),
+ c_uint8_ptr(ciphertext),
+ c_size_t(len(plaintext)))
+ if result:
+ if result == 0x60002:
+ raise OverflowError("The counter has wrapped around in"
+ " CTR mode")
+ raise ValueError("Error %X while encrypting in CTR mode" % result)
+
+ if output is None:
+ return get_raw_buffer(ciphertext)
+ else:
+ return None
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() cannot be called after encrypt()")
+ self._next = [self.decrypt]
+
+ if output is None:
+ plaintext = create_string_buffer(len(ciphertext))
+ else:
+ plaintext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(ciphertext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_ctr_lib.CTR_decrypt(self._state.get(),
+ c_uint8_ptr(ciphertext),
+ c_uint8_ptr(plaintext),
+ c_size_t(len(ciphertext)))
+ if result:
+ if result == 0x60002:
+ raise OverflowError("The counter has wrapped around in"
+ " CTR mode")
+ raise ValueError("Error %X while decrypting in CTR mode" % result)
+
+ if output is None:
+ return get_raw_buffer(plaintext)
+ else:
+ return None
+
+
+def _create_ctr_cipher(factory, **kwargs):
+ """Instantiate a cipher object that performs CTR encryption/decryption.
+
+ :Parameters:
+ factory : module
+ The underlying block cipher, a module from ``Crypto.Cipher``.
+
+ :Keywords:
+ nonce : bytes/bytearray/memoryview
+ The fixed part at the beginning of the counter block - the rest is
+ the counter number that gets increased when processing the next block.
+ The nonce must be such that no two messages are encrypted under the
+ same key and the same nonce.
+
+ The nonce must be shorter than the block size (it can have
+ zero length; the counter is then as long as the block).
+
+ If this parameter is not present, a random nonce will be created with
+ length equal to half the block size. No random nonce shorter than
+ 64 bits will be created though - you must really think through all
+ security consequences of using such a short block size.
+
+ initial_value : positive integer or bytes/bytearray/memoryview
+ The initial value for the counter. If not present, the cipher will
+ start counting from 0. The value is incremented by one for each block.
+ The counter number is encoded in big endian mode.
+
+ counter : object
+ Instance of ``Crypto.Util.Counter``, which allows full customization
+ of the counter block. This parameter is incompatible with both ``nonce``
+ and ``initial_value``.
+
+ Any other keyword will be passed to the underlying block cipher.
+ See the relevant documentation for details (at least ``key`` will need
+ to be present).
+ """
+
+ cipher_state = factory._create_base_cipher(kwargs)
+
+ counter = kwargs.pop("counter", None)
+ nonce = kwargs.pop("nonce", None)
+ initial_value = kwargs.pop("initial_value", None)
+ if kwargs:
+ raise TypeError("Invalid parameters for CTR mode: %s" % str(kwargs))
+
+ if counter is not None and (nonce, initial_value) != (None, None):
+ raise TypeError("'counter' and 'nonce'/'initial_value'"
+ " are mutually exclusive")
+
+ if counter is None:
+ # Crypto.Util.Counter is not used
+ if nonce is None:
+ if factory.block_size < 16:
+ raise TypeError("Impossible to create a safe nonce for short"
+ " block sizes")
+ nonce = get_random_bytes(factory.block_size // 2)
+ else:
+ if len(nonce) >= factory.block_size:
+ raise ValueError("Nonce is too long")
+
+ # What is not nonce is counter
+ counter_len = factory.block_size - len(nonce)
+
+ if initial_value is None:
+ initial_value = 0
+
+ if is_native_int(initial_value):
+ if (1 << (counter_len * 8)) - 1 < initial_value:
+ raise ValueError("Initial counter value is too large")
+ initial_counter_block = nonce + long_to_bytes(initial_value, counter_len)
+ else:
+ if len(initial_value) != counter_len:
+ raise ValueError("Incorrect length for counter byte string (%d bytes, expected %d)" %
+ (len(initial_value), counter_len))
+ initial_counter_block = nonce + initial_value
+
+ return CtrMode(cipher_state,
+ initial_counter_block,
+ len(nonce), # prefix
+ counter_len,
+ False) # little_endian
+
+ # Crypto.Util.Counter is used
+
+ # 'counter' used to be a callable object, but now it is
+ # just a dictionary for backward compatibility.
+ _counter = dict(counter)
+ try:
+ counter_len = _counter.pop("counter_len")
+ prefix = _counter.pop("prefix")
+ suffix = _counter.pop("suffix")
+ initial_value = _counter.pop("initial_value")
+ little_endian = _counter.pop("little_endian")
+ except KeyError:
+ raise TypeError("Incorrect counter object"
+ " (use Crypto.Util.Counter.new)")
+
+ # Compute initial counter block
+ words = []
+ while initial_value > 0:
+ words.append(struct.pack('B', initial_value & 255))
+ initial_value >>= 8
+ words += [b'\x00'] * max(0, counter_len - len(words))
+ if not little_endian:
+ words.reverse()
+ initial_counter_block = prefix + b"".join(words) + suffix
+
+ if len(initial_counter_block) != factory.block_size:
+ raise ValueError("Size of the counter block (%d bytes) must match"
+ " block size (%d)" % (len(initial_counter_block),
+ factory.block_size))
+
+ return CtrMode(cipher_state, initial_counter_block,
+ len(prefix), counter_len, little_endian)
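A usage sketch for the CTR factory above. With an 8-byte nonce on a 16-byte AES block, the remaining 8 bytes form a big-endian counter starting at ``initial_value``:

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes

key = get_random_bytes(16)
nonce = get_random_bytes(8)

enc = AES.new(key, AES.MODE_CTR, nonce=nonce, initial_value=0)
ct = enc.encrypt(b"counter mode needs no padding")

dec = AES.new(key, AES.MODE_CTR, nonce=nonce, initial_value=0)
assert dec.decrypt(ct) == b"counter mode needs no padding"

# The legacy path instead builds the whole counter block with Crypto.Util.Counter:
#   ctr = Counter.new(64, prefix=nonce)
#   AES.new(key, AES.MODE_CTR, counter=ctr)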
diff --git a/lib/Crypto/Cipher/_mode_ctr.pyi b/lib/Crypto/Cipher/_mode_ctr.pyi
new file mode 100644
index 0000000..ce70855
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ctr.pyi
@@ -0,0 +1,27 @@
+from typing import Union, overload
+
+from Crypto.Util._raw_api import SmartPointer
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['CtrMode']
+
+class CtrMode(object):
+ block_size: int
+ nonce: bytes
+
+ def __init__(self,
+ block_cipher: SmartPointer,
+ initial_counter_block: Buffer,
+ prefix_len: int,
+ counter_len: int,
+ little_endian: bool) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
diff --git a/lib/Crypto/Cipher/_mode_eax.py b/lib/Crypto/Cipher/_mode_eax.py
new file mode 100644
index 0000000..d5fb135
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_eax.py
@@ -0,0 +1,408 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+EAX mode.
+"""
+
+__all__ = ['EaxMode']
+
+import struct
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import byte_string, bord, _copy_bytes
+
+from Crypto.Util._raw_api import is_buffer
+
+from Crypto.Util.strxor import strxor
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+
+from Crypto.Hash import CMAC, BLAKE2s
+from Crypto.Random import get_random_bytes
+
+
+class EaxMode(object):
+ """*EAX* mode.
+
+ This is an Authenticated Encryption with Associated Data
+ (`AEAD`_) mode. It provides both confidentiality and authenticity.
+
+ The header of the message may be left in the clear, if needed,
+ and it will still be subject to authentication.
+
+ The decryption step tells the receiver if the message comes
+ from a source that really knows the secret key.
+ Additionally, decryption detects if any part of the message -
+ including the header - has been modified or corrupted.
+
+ This mode requires a *nonce*.
+
+ This mode is only available for ciphers that operate on 64-bit or
+ 128-bit blocks.
+
+ There are no official standards defining EAX.
+ The implementation is based on `a proposal`__ that
+ was presented to NIST.
+
+ .. _AEAD: http://blog.cryptographyengineering.com/2012/05/how-to-choose-authenticated-encryption.html
+ .. __: http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/eax/eax-spec.pdf
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, factory, key, nonce, mac_len, cipher_params):
+ """EAX cipher mode"""
+
+ self.block_size = factory.block_size
+ """The block size of the underlying cipher, in bytes."""
+
+ self.nonce = _copy_bytes(None, None, nonce)
+ """The nonce originally used to create the object."""
+
+ self._mac_len = mac_len
+ self._mac_tag = None # Cache for MAC tag
+
+ # Allowed transitions after initialization
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ # MAC tag length
+ if not (4 <= self._mac_len <= self.block_size):
+ raise ValueError("Parameter 'mac_len' must not be larger than %d"
+ % self.block_size)
+
+ # Nonce cannot be empty and must be a byte string
+ if len(self.nonce) == 0:
+ raise ValueError("Nonce cannot be empty in EAX mode")
+ if not is_buffer(nonce):
+ raise TypeError("nonce must be bytes, bytearray or memoryview")
+
+ self._omac = [
+ CMAC.new(key,
+ b'\x00' * (self.block_size - 1) + struct.pack('B', i),
+ ciphermod=factory,
+ cipher_params=cipher_params)
+ for i in range(0, 3)
+ ]
+
+ # Compute MAC of nonce
+ self._omac[0].update(self.nonce)
+ self._signer = self._omac[1]
+
+ # MAC of the nonce is also the initial counter for CTR encryption
+ counter_int = bytes_to_long(self._omac[0].digest())
+ self._cipher = factory.new(key,
+ factory.MODE_CTR,
+ initial_value=counter_int,
+ nonce=b"",
+ **cipher_params)
+
+ def update(self, assoc_data):
+ """Protect associated data
+
+ If there is any associated data, the caller has to invoke
+ this function one or more times, before using
+ ``decrypt`` or ``encrypt``.
+
+ *Associated data* is any data (e.g. packet headers) that
+ will not be encrypted and will be transmitted in the clear.
+ However, the receiver is still able to detect any modification to it.
+
+ If there is no associated data, this method must not be called.
+
+ The caller may split associated data in segments of any size, and
+ invoke this method multiple times, each time with the next segment.
+
+ :Parameters:
+ assoc_data : bytes/bytearray/memoryview
+ A piece of associated data. There are no restrictions on its size.
+ """
+
+ if self.update not in self._next:
+ raise TypeError("update() can only be called"
+ " immediately after initialization")
+
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ self._signer.update(assoc_data)
+ return self
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() can only be called after"
+ " initialization or an update()")
+ self._next = [self.encrypt, self.digest]
+ ct = self._cipher.encrypt(plaintext, output=output)
+ if output is None:
+ self._omac[2].update(ct)
+ else:
+ self._omac[2].update(output)
+ return ct
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() can only be called"
+ " after initialization or an update()")
+ self._next = [self.decrypt, self.verify]
+ self._omac[2].update(ciphertext)
+ return self._cipher.decrypt(ciphertext, output=output)
+
+ def digest(self):
+ """Compute the *binary* MAC tag.
+
+ The caller invokes this function at the very end.
+
+ This method returns the MAC that shall be sent to the receiver,
+ together with the ciphertext.
+
+ :Return: the MAC, as a byte string.
+ """
+
+ if self.digest not in self._next:
+ raise TypeError("digest() cannot be called when decrypting"
+ " or validating a message")
+ self._next = [self.digest]
+
+ if not self._mac_tag:
+ tag = b'\x00' * self.block_size
+ for i in range(3):
+ tag = strxor(tag, self._omac[i].digest())
+ self._mac_tag = tag[:self._mac_len]
+
+ return self._mac_tag
+
+ def hexdigest(self):
+ """Compute the *printable* MAC tag.
+
+ This method is like `digest`.
+
+ :Return: the MAC, as a hexadecimal string.
+ """
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def verify(self, received_mac_tag):
+ """Validate the *binary* MAC tag.
+
+ The caller invokes this function at the very end.
+
+ This method checks if the decrypted message is indeed valid
+ (that is, if the key is correct) and it has not been
+ tampered with while in transit.
+
+ :Parameters:
+ received_mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Raises MacMismatchError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ if self.verify not in self._next:
+ raise TypeError("verify() cannot be called"
+ " when encrypting a message")
+ self._next = [self.verify]
+
+ if not self._mac_tag:
+ tag = b'\x00' * self.block_size
+ for i in range(3):
+ tag = strxor(tag, self._omac[i].digest())
+ self._mac_tag = tag[:self._mac_len]
+
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=self._mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=received_mac_tag)
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Validate the *printable* MAC tag.
+
+ This method is like `verify`.
+
+ :Parameters:
+ hex_mac_tag : string
+ This is the *printable* MAC, as received from the sender.
+ :Raises MacMismatchError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ self.verify(unhexlify(hex_mac_tag))
+
+ def encrypt_and_digest(self, plaintext, output=None):
+ """Perform encrypt() and digest() in one step.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ a tuple with two items:
+
+ - the ciphertext, as ``bytes``
+ - the MAC tag, as ``bytes``
+
+ The first item becomes ``None`` when the ``output`` parameter
+ specified a location for the result.
+ """
+
+ return self.encrypt(plaintext, output=output), self.digest()
+
+ def decrypt_and_verify(self, ciphertext, received_mac_tag, output=None):
+ """Perform decrypt() and verify() in one step.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ received_mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return: the plaintext as ``bytes`` or ``None`` when the ``output``
+ parameter specified a location for the result.
+ :Raises MacMismatchError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ pt = self.decrypt(ciphertext, output=output)
+ self.verify(received_mac_tag)
+ return pt
+
+
+def _create_eax_cipher(factory, **kwargs):
+ """Create a new block cipher, configured in EAX mode.
+
+ :Parameters:
+ factory : module
+ A symmetric cipher module from `Crypto.Cipher` (like
+ `Crypto.Cipher.AES`).
+
+ :Keywords:
+ key : bytes/bytearray/memoryview
+ The secret key to use in the symmetric cipher.
+
+ nonce : bytes/bytearray/memoryview
+ A value that must never be reused for any other encryption.
+ There are no restrictions on its length, but it is recommended to use
+ at least 16 bytes.
+
+ The nonce shall never repeat for two different messages encrypted with
+ the same key, but it does not need to be random.
+
+ If not specified, a random 16-byte nonce is used.
+
+ mac_len : integer
+ Length of the MAC, in bytes. It must not be larger than the cipher
+ block size in bytes (which is the default).
+ """
+
+ try:
+ key = kwargs.pop("key")
+ nonce = kwargs.pop("nonce", None)
+ if nonce is None:
+ nonce = get_random_bytes(16)
+ mac_len = kwargs.pop("mac_len", factory.block_size)
+ except KeyError as e:
+ raise TypeError("Missing parameter: " + str(e))
+
+ return EaxMode(factory, key, nonce, mac_len, kwargs)
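Editorial note (not part of the patch): a typical EAX round trip, assuming the vendored lib/Crypto tree is importable as the Crypto package. The nonce is transmitted alongside the ciphertext and tag; the header is authenticated but not encrypted:

    from Crypto.Cipher import AES
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(16)

    cipher = AES.new(key, AES.MODE_EAX)          # a random 16-byte nonce is generated
    cipher.update(b"header")                     # associated data: authenticated only
    ciphertext, tag = cipher.encrypt_and_digest(b"secret payload")

    receiver = AES.new(key, AES.MODE_EAX, nonce=cipher.nonce)
    receiver.update(b"header")
    payload = receiver.decrypt_and_verify(ciphertext, tag)  # raises ValueError on tampering
    assert payload == b"secret payload"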
diff --git a/lib/Crypto/Cipher/_mode_eax.pyi b/lib/Crypto/Cipher/_mode_eax.pyi
new file mode 100644
index 0000000..cbfa467
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_eax.pyi
@@ -0,0 +1,45 @@
+from types import ModuleType
+from typing import Any, Union, Tuple, Dict, overload, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['EaxMode']
+
+class EaxMode(object):
+ block_size: int
+ nonce: bytes
+
+ def __init__(self,
+ factory: ModuleType,
+ key: Buffer,
+ nonce: Buffer,
+ mac_len: int,
+ cipher_params: Dict) -> None: ...
+
+ def update(self, assoc_data: Buffer) -> EaxMode: ...
+
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, received_mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer) -> Tuple[bytes, bytes]: ...
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer,
+ output: Buffer) -> Tuple[None, bytes]: ...
+ def decrypt_and_verify(self,
+ ciphertext: Buffer,
+ received_mac_tag: Buffer,
+ output: Optional[Union[bytearray, memoryview]] = ...) -> bytes: ...
diff --git a/lib/Crypto/Cipher/_mode_ecb.py b/lib/Crypto/Cipher/_mode_ecb.py
new file mode 100644
index 0000000..3783357
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ecb.py
@@ -0,0 +1,220 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/mode_ecb.py : ECB mode
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""
+Electronic Code Book (ECB) mode.
+"""
+
+__all__ = [ 'EcbMode' ]
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, create_string_buffer,
+ get_raw_buffer, SmartPointer,
+ c_size_t, c_uint8_ptr,
+ is_writeable_buffer)
+
+raw_ecb_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_ecb", """
+ int ECB_start_operation(void *cipher,
+ void **pResult);
+ int ECB_encrypt(void *ecbState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int ECB_decrypt(void *ecbState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int ECB_stop_operation(void *state);
+ """
+ )
+
+
+class EcbMode(object):
+ """*Electronic Code Book (ECB)*.
+
+ This is the simplest encryption mode. Each of the plaintext blocks
+ is directly encrypted into a ciphertext block, independently of
+ any other block.
+
+ This mode is dangerous because it exposes the frequency of symbols
+ in your plaintext. Other modes (e.g. *CBC*) should be used instead.
+
+ See `NIST SP800-38A`_ , Section 6.1.
+
+ .. _`NIST SP800-38A` : http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, block_cipher):
+ """Create a new block cipher, configured in ECB mode.
+
+ :Parameters:
+ block_cipher : C pointer
+ A smart pointer to the low-level block cipher instance.
+ """
+ self.block_size = block_cipher.block_size
+
+ self._state = VoidPointer()
+ result = raw_ecb_lib.ECB_start_operation(block_cipher.get(),
+ self._state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating the ECB mode"
+ % result)
+
+ # Ensure that object disposal of this Python object will (eventually)
+ # free the memory allocated by the raw library for the cipher
+ # mode
+ self._state = SmartPointer(self._state.get(),
+ raw_ecb_lib.ECB_stop_operation)
+
+ # Memory allocated for the underlying block cipher is now owned
+ # by the cipher mode
+ block_cipher.release()
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key set at initialization.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ The length must be a multiple of the cipher block length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if output is None:
+ ciphertext = create_string_buffer(len(plaintext))
+ else:
+ ciphertext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(plaintext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_ecb_lib.ECB_encrypt(self._state.get(),
+ c_uint8_ptr(plaintext),
+ c_uint8_ptr(ciphertext),
+ c_size_t(len(plaintext)))
+ if result:
+ if result == 3:
+ raise ValueError("Data must be aligned to block boundary in ECB mode")
+ raise ValueError("Error %d while encrypting in ECB mode" % result)
+
+ if output is None:
+ return get_raw_buffer(ciphertext)
+ else:
+ return None
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key set at initialization.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ The length must be a multiple of the cipher block length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if output is None:
+ plaintext = create_string_buffer(len(ciphertext))
+ else:
+ plaintext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(ciphertext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_ecb_lib.ECB_decrypt(self._state.get(),
+ c_uint8_ptr(ciphertext),
+ c_uint8_ptr(plaintext),
+ c_size_t(len(ciphertext)))
+ if result:
+ if result == 3:
+ raise ValueError("Data must be aligned to block boundary in ECB mode")
+ raise ValueError("Error %d while decrypting in ECB mode" % result)
+
+ if output is None:
+ return get_raw_buffer(plaintext)
+ else:
+ return None
+
+
+def _create_ecb_cipher(factory, **kwargs):
+ """Instantiate a cipher object that performs ECB encryption/decryption.
+
+ :Parameters:
+ factory : module
+ The underlying block cipher, a module from ``Crypto.Cipher``.
+
+ All keywords are passed to the underlying block cipher.
+ See the relevant documentation for details (at least ``key`` will need
+ to be present"""
+
+ cipher_state = factory._create_base_cipher(kwargs)
+ cipher_state.block_size = factory.block_size
+ if kwargs:
+ raise TypeError("Unknown parameters for ECB: %s" % str(kwargs))
+ return EcbMode(cipher_state)
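Editorial note (not part of the patch): ECB accepts only block-aligned input, so arbitrary data has to be padded first. A minimal sketch, assuming the bundled copy also ships Crypto.Util.Padding:

    from Crypto.Cipher import AES
    from Crypto.Util.Padding import pad, unpad
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(16)
    data = b"not a multiple of 16 bytes"

    # Pad to the 16-byte boundary before encrypting, strip the padding after decrypting.
    ct = AES.new(key, AES.MODE_ECB).encrypt(pad(data, AES.block_size))
    pt = unpad(AES.new(key, AES.MODE_ECB).decrypt(ct), AES.block_size)
    assert pt == data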
diff --git a/lib/Crypto/Cipher/_mode_ecb.pyi b/lib/Crypto/Cipher/_mode_ecb.pyi
new file mode 100644
index 0000000..1772b23
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ecb.pyi
@@ -0,0 +1,19 @@
+from typing import Union, overload
+
+from Crypto.Util._raw_api import SmartPointer
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = [ 'EcbMode' ]
+
+class EcbMode(object):
+ def __init__(self, block_cipher: SmartPointer) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
diff --git a/lib/Crypto/Cipher/_mode_gcm.py b/lib/Crypto/Cipher/_mode_gcm.py
new file mode 100644
index 0000000..da8e337
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_gcm.py
@@ -0,0 +1,620 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+Galois/Counter Mode (GCM).
+"""
+
+__all__ = ['GcmMode']
+
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import bord, _copy_bytes
+
+from Crypto.Util._raw_api import is_buffer
+
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+from Crypto.Hash import BLAKE2s
+from Crypto.Random import get_random_bytes
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ create_string_buffer, get_raw_buffer,
+ SmartPointer, c_size_t, c_uint8_ptr)
+
+from Crypto.Util import _cpu_features
+
+
+# C API by module implementing GHASH
+_ghash_api_template = """
+ int ghash_%imp%(uint8_t y_out[16],
+ const uint8_t block_data[],
+ size_t len,
+ const uint8_t y_in[16],
+ const void *exp_key);
+ int ghash_expand_%imp%(const uint8_t h[16],
+ void **ghash_tables);
+ int ghash_destroy_%imp%(void *ghash_tables);
+"""
+
+def _build_impl(lib, postfix):
+ from collections import namedtuple
+
+ funcs = ( "ghash", "ghash_expand", "ghash_destroy" )
+ GHASH_Imp = namedtuple('_GHash_Imp', funcs)
+ try:
+ imp_funcs = [ getattr(lib, x + "_" + postfix) for x in funcs ]
+ except AttributeError: # Make sphinx stop complaining with its mocklib
+ imp_funcs = [ None ] * 3
+ params = dict(zip(funcs, imp_funcs))
+ return GHASH_Imp(**params)
+
+
+def _get_ghash_portable():
+ api = _ghash_api_template.replace("%imp%", "portable")
+ lib = load_pycryptodome_raw_lib("Crypto.Hash._ghash_portable", api)
+ result = _build_impl(lib, "portable")
+ return result
+_ghash_portable = _get_ghash_portable()
+
+
+def _get_ghash_clmul():
+ """Return None if CLMUL implementation is not available"""
+
+ if not _cpu_features.have_clmul():
+ return None
+ try:
+ api = _ghash_api_template.replace("%imp%", "clmul")
+ lib = load_pycryptodome_raw_lib("Crypto.Hash._ghash_clmul", api)
+ result = _build_impl(lib, "clmul")
+ except OSError:
+ result = None
+ return result
+_ghash_clmul = _get_ghash_clmul()
+
+
+class _GHASH(object):
+ """GHASH function defined in NIST SP 800-38D, Algorithm 2.
+
+ If X_1, X_2, .. X_m are the blocks of input data, the function
+ computes:
+
+ X_1*H^{m} + X_2*H^{m-1} + ... + X_m*H
+
+ in the Galois field GF(2^128) using the reducing polynomial
+ (x^128 + x^7 + x^2 + x + 1).
+ """
+
+ def __init__(self, subkey, ghash_c):
+ assert len(subkey) == 16
+
+ self.ghash_c = ghash_c
+
+ self._exp_key = VoidPointer()
+ result = ghash_c.ghash_expand(c_uint8_ptr(subkey),
+ self._exp_key.address_of())
+ if result:
+ raise ValueError("Error %d while expanding the GHASH key" % result)
+
+ self._exp_key = SmartPointer(self._exp_key.get(),
+ ghash_c.ghash_destroy)
+
+ # create_string_buffer always returns a string of zeroes
+ self._last_y = create_string_buffer(16)
+
+ def update(self, block_data):
+ assert len(block_data) % 16 == 0
+
+ result = self.ghash_c.ghash(self._last_y,
+ c_uint8_ptr(block_data),
+ c_size_t(len(block_data)),
+ self._last_y,
+ self._exp_key.get())
+ if result:
+ raise ValueError("Error %d while updating GHASH" % result)
+
+ return self
+
+ def digest(self):
+ return get_raw_buffer(self._last_y)
+
+
+def enum(**enums):
+ return type('Enum', (), enums)
+
+
+MacStatus = enum(PROCESSING_AUTH_DATA=1, PROCESSING_CIPHERTEXT=2)
+
+
+class GcmMode(object):
+ """Galois Counter Mode (GCM).
+
+ This is an Authenticated Encryption with Associated Data (`AEAD`_) mode.
+ It provides both confidentiality and authenticity.
+
+ The header of the message may be left in the clear, if needed, and it will
+ still be subject to authentication. The decryption step tells the receiver
+ if the message comes from a source that really knows the secret key.
+ Additionally, decryption detects if any part of the message - including the
+ header - has been modified or corrupted.
+
+ This mode requires a *nonce*.
+
+ This mode is only available for ciphers that operate on 128-bit blocks
+ (e.g. AES but not TDES).
+
+ See `NIST SP800-38D`_.
+
+ .. _`NIST SP800-38D`: http://csrc.nist.gov/publications/nistpubs/800-38D/SP-800-38D.pdf
+ .. _AEAD: http://blog.cryptographyengineering.com/2012/05/how-to-choose-authenticated-encryption.html
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, factory, key, nonce, mac_len, cipher_params, ghash_c):
+
+ self.block_size = factory.block_size
+ if self.block_size != 16:
+ raise ValueError("GCM mode is only available for ciphers"
+ " that operate on 128 bits blocks")
+
+ if len(nonce) == 0:
+ raise ValueError("Nonce cannot be empty")
+
+ if not is_buffer(nonce):
+ raise TypeError("Nonce must be bytes, bytearray or memoryview")
+
+ # See NIST SP 800 38D, 5.2.1.1
+ if len(nonce) > 2**64 - 1:
+ raise ValueError("Nonce exceeds maximum length")
+
+
+ self.nonce = _copy_bytes(None, None, nonce)
+ """Nonce"""
+
+ self._factory = factory
+ self._key = _copy_bytes(None, None, key)
+ self._tag = None # Cache for MAC tag
+
+ self._mac_len = mac_len
+ if not (4 <= mac_len <= 16):
+ raise ValueError("Parameter 'mac_len' must be in the range 4..16")
+
+ # Allowed transitions after initialization
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ self._no_more_assoc_data = False
+
+ # Length of associated data
+ self._auth_len = 0
+
+ # Length of the ciphertext or plaintext
+ self._msg_len = 0
+
+ # Step 1 in SP800-38D, Algorithm 4 (encryption) - Compute H
+ # See also Algorithm 5 (decryption)
+ hash_subkey = factory.new(key,
+ self._factory.MODE_ECB,
+ **cipher_params
+ ).encrypt(b'\x00' * 16)
+
+ # Step 2 - Compute J0
+ if len(self.nonce) == 12:
+ j0 = self.nonce + b"\x00\x00\x00\x01"
+ else:
+ fill = (16 - (len(nonce) % 16)) % 16 + 8
+ ghash_in = (self.nonce +
+ b'\x00' * fill +
+ long_to_bytes(8 * len(nonce), 8))
+ j0 = _GHASH(hash_subkey, ghash_c).update(ghash_in).digest()
+
+ # Step 3 - Prepare GCTR cipher for encryption/decryption
+ nonce_ctr = j0[:12]
+ iv_ctr = (bytes_to_long(j0) + 1) & 0xFFFFFFFF
+ self._cipher = factory.new(key,
+ self._factory.MODE_CTR,
+ initial_value=iv_ctr,
+ nonce=nonce_ctr,
+ **cipher_params)
+
+ # Step 5 - Bootstrap GHASH
+ self._signer = _GHASH(hash_subkey, ghash_c)
+
+ # Step 6 - Prepare GCTR cipher for GMAC
+ self._tag_cipher = factory.new(key,
+ self._factory.MODE_CTR,
+ initial_value=j0,
+ nonce=b"",
+ **cipher_params)
+
+ # Cache for data to authenticate
+ self._cache = b""
+
+ self._status = MacStatus.PROCESSING_AUTH_DATA
+
+ def update(self, assoc_data):
+ """Protect associated data
+
+ If there is any associated data, the caller has to invoke
+ this function one or more times, before using
+ ``decrypt`` or ``encrypt``.
+
+ *Associated data* is any data (e.g. packet headers) that
+ will not be encrypted and will be transmitted in the clear.
+ However, the receiver is still able to detect any modification to it.
+ In GCM, the *associated data* is also called
+ *additional authenticated data* (AAD).
+
+ If there is no associated data, this method must not be called.
+
+ The caller may split associated data in segments of any size, and
+ invoke this method multiple times, each time with the next segment.
+
+ :Parameters:
+ assoc_data : bytes/bytearray/memoryview
+ A piece of associated data. There are no restrictions on its size.
+ """
+
+ if self.update not in self._next:
+ raise TypeError("update() can only be called"
+ " immediately after initialization")
+
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ self._update(assoc_data)
+ self._auth_len += len(assoc_data)
+
+ # See NIST SP 800 38D, 5.2.1.1
+ if self._auth_len > 2**64 - 1:
+ raise ValueError("Additional Authenticated Data exceeds maximum length")
+
+ return self
+
+ def _update(self, data):
+ assert(len(self._cache) < 16)
+
+ if len(self._cache) > 0:
+ filler = min(16 - len(self._cache), len(data))
+ self._cache += _copy_bytes(None, filler, data)
+ data = data[filler:]
+
+ if len(self._cache) < 16:
+ return
+
+ # The cache is exactly one block
+ self._signer.update(self._cache)
+ self._cache = b""
+
+ update_len = len(data) // 16 * 16
+ self._cache = _copy_bytes(update_len, None, data)
+ if update_len > 0:
+ self._signer.update(data[:update_len])
+
+ def _pad_cache_and_update(self):
+ assert(len(self._cache) < 16)
+
+ # The authenticated data A is concatenated to the minimum
+ # number of zero bytes (possibly none) such that the
+ # - authenticated data A is aligned to the 16 byte boundary.
+ # See step 5 in section 7.1
+ # - ciphertext C is aligned to the 16 byte boundary.
+ # See step 6 in section 7.2
+ len_cache = len(self._cache)
+ if len_cache > 0:
+ self._update(b'\x00' * (16 - len_cache))
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() can only be called after"
+ " initialization or an update()")
+ self._next = [self.encrypt, self.digest]
+
+ ciphertext = self._cipher.encrypt(plaintext, output=output)
+
+ if self._status == MacStatus.PROCESSING_AUTH_DATA:
+ self._pad_cache_and_update()
+ self._status = MacStatus.PROCESSING_CIPHERTEXT
+
+ self._update(ciphertext if output is None else output)
+ self._msg_len += len(plaintext)
+
+ # See NIST SP 800 38D, 5.2.1.1
+ if self._msg_len > 2**39 - 256:
+ raise ValueError("Plaintext exceeds maximum length")
+
+ return ciphertext
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() can only be called"
+ " after initialization or an update()")
+ self._next = [self.decrypt, self.verify]
+
+ if self._status == MacStatus.PROCESSING_AUTH_DATA:
+ self._pad_cache_and_update()
+ self._status = MacStatus.PROCESSING_CIPHERTEXT
+
+ self._update(ciphertext)
+ self._msg_len += len(ciphertext)
+
+ return self._cipher.decrypt(ciphertext, output=output)
+
+ def digest(self):
+ """Compute the *binary* MAC tag in an AEAD mode.
+
+ The caller invokes this function at the very end.
+
+ This method returns the MAC that shall be sent to the receiver,
+ together with the ciphertext.
+
+ :Return: the MAC, as a byte string.
+ """
+
+ if self.digest not in self._next:
+ raise TypeError("digest() cannot be called when decrypting"
+ " or validating a message")
+ self._next = [self.digest]
+
+ return self._compute_mac()
+
+ def _compute_mac(self):
+ """Compute MAC without any FSM checks."""
+
+ if self._tag:
+ return self._tag
+
+ # Step 5 in NIST SP 800-38D, Algorithm 4 - Compute S
+ self._pad_cache_and_update()
+ self._update(long_to_bytes(8 * self._auth_len, 8))
+ self._update(long_to_bytes(8 * self._msg_len, 8))
+ s_tag = self._signer.digest()
+
+ # Step 6 - Compute T
+ self._tag = self._tag_cipher.encrypt(s_tag)[:self._mac_len]
+
+ return self._tag
+
+ def hexdigest(self):
+ """Compute the *printable* MAC tag.
+
+ This method is like `digest`.
+
+ :Return: the MAC, as a hexadecimal string.
+ """
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def verify(self, received_mac_tag):
+ """Validate the *binary* MAC tag.
+
+ The caller invokes this function at the very end.
+
+ This method checks if the decrypted message is indeed valid
+ (that is, if the key is correct) and it has not been
+ tampered with while in transit.
+
+ :Parameters:
+ received_mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ if self.verify not in self._next:
+ raise TypeError("verify() cannot be called"
+ " when encrypting a message")
+ self._next = [self.verify]
+
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret,
+ data=self._compute_mac())
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret,
+ data=received_mac_tag)
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Validate the *printable* MAC tag.
+
+ This method is like `verify`.
+
+ :Parameters:
+ hex_mac_tag : string
+ This is the *printable* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ self.verify(unhexlify(hex_mac_tag))
+
+ def encrypt_and_digest(self, plaintext, output=None):
+ """Perform encrypt() and digest() in one step.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ a tuple with two items:
+
+ - the ciphertext, as ``bytes``
+ - the MAC tag, as ``bytes``
+
+ The first item becomes ``None`` when the ``output`` parameter
+ specified a location for the result.
+ """
+
+ return self.encrypt(plaintext, output=output), self.digest()
+
+ def decrypt_and_verify(self, ciphertext, received_mac_tag, output=None):
+ """Perform decrypt() and verify() in one step.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ received_mac_tag : byte string
+ This is the *binary* MAC, as received from the sender.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return: the plaintext as ``bytes`` or ``None`` when the ``output``
+ parameter specified a location for the result.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ plaintext = self.decrypt(ciphertext, output=output)
+ self.verify(received_mac_tag)
+ return plaintext
+
+
+def _create_gcm_cipher(factory, **kwargs):
+ """Create a new block cipher, configured in Galois Counter Mode (GCM).
+
+ :Parameters:
+ factory : module
+ A block cipher module, taken from `Crypto.Cipher`.
+ The cipher must have a block length of 16 bytes.
+ GCM is only defined for `Crypto.Cipher.AES`.
+
+ :Keywords:
+ key : bytes/bytearray/memoryview
+ The secret key to use in the symmetric cipher.
+ It must be 16 (e.g. *AES-128*), 24 (e.g. *AES-192*)
+ or 32 (e.g. *AES-256*) bytes long.
+
+ nonce : bytes/bytearray/memoryview
+ A value that must never be reused for any other encryption.
+
+ There are no restrictions on its length,
+ but it is recommended to use at least 16 bytes.
+
+ The nonce shall never repeat for two
+ different messages encrypted with the same key,
+ but it does not need to be random.
+
+ If not provided, a 16 byte nonce will be randomly created.
+
+ mac_len : integer
+ Length of the MAC, in bytes.
+ It must be no larger than 16 bytes (which is the default).
+ """
+
+ try:
+ key = kwargs.pop("key")
+ except KeyError as e:
+ raise TypeError("Missing parameter:" + str(e))
+
+ nonce = kwargs.pop("nonce", None)
+ if nonce is None:
+ nonce = get_random_bytes(16)
+ mac_len = kwargs.pop("mac_len", 16)
+
+ # Not documented - only used for testing
+ use_clmul = kwargs.pop("use_clmul", True)
+ if use_clmul and _ghash_clmul:
+ ghash_c = _ghash_clmul
+ else:
+ ghash_c = _ghash_portable
+
+ return GcmMode(factory, key, nonce, mac_len, kwargs, ghash_c)
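Editorial note (not part of the patch): a typical AES-GCM round trip with additional authenticated data, assuming the vendored lib/Crypto tree is importable as the Crypto package:

    from Crypto.Cipher import AES
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(32)
    header = b"packet header"                    # sent in the clear, still authenticated

    cipher = AES.new(key, AES.MODE_GCM)          # a random 16-byte nonce is generated
    cipher.update(header)
    ciphertext, tag = cipher.encrypt_and_digest(b"payload")

    receiver = AES.new(key, AES.MODE_GCM, nonce=cipher.nonce)
    receiver.update(header)
    payload = receiver.decrypt_and_verify(ciphertext, tag)  # ValueError if the tag check fails
    assert payload == b"payload"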
diff --git a/lib/Crypto/Cipher/_mode_gcm.pyi b/lib/Crypto/Cipher/_mode_gcm.pyi
new file mode 100644
index 0000000..8912955
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_gcm.pyi
@@ -0,0 +1,45 @@
+from types import ModuleType
+from typing import Union, Tuple, Dict, overload, Optional
+
+__all__ = ['GcmMode']
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class GcmMode(object):
+ block_size: int
+ nonce: Buffer
+
+ def __init__(self,
+ factory: ModuleType,
+ key: Buffer,
+ nonce: Buffer,
+ mac_len: int,
+ cipher_params: Dict) -> None: ...
+
+ def update(self, assoc_data: Buffer) -> GcmMode: ...
+
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, received_mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer) -> Tuple[bytes, bytes]: ...
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer,
+ output: Buffer) -> Tuple[None, bytes]: ...
+ def decrypt_and_verify(self,
+ ciphertext: Buffer,
+ received_mac_tag: Buffer,
+ output: Optional[Union[bytearray, memoryview]] = ...) -> bytes: ...
diff --git a/lib/Crypto/Cipher/_mode_ocb.py b/lib/Crypto/Cipher/_mode_ocb.py
new file mode 100644
index 0000000..27758b1
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ocb.py
@@ -0,0 +1,525 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+Offset Codebook (OCB) mode.
+
+OCB is an Authenticated Encryption with Associated Data (AEAD) cipher mode
+designed by Prof. Phillip Rogaway and specified in `RFC7253`_.
+
+The algorithm provides both authenticity and privacy. It is very efficient,
+uses only one key, and can be used in online mode (so that encryption
+or decryption can start before the end of the message is available).
+
+This module implements the third and last variant of OCB (OCB3) and it only
+works in combination with a 128-bit block symmetric cipher, like AES.
+
+OCB is patented in the US, but `free licenses`_ exist for software implementations
+meant for non-military purposes.
+
+Example:
+ >>> from Crypto.Cipher import AES
+ >>> from Crypto.Random import get_random_bytes
+ >>>
+ >>> key = get_random_bytes(32)
+ >>> cipher = AES.new(key, AES.MODE_OCB)
+ >>> plaintext = b"Attack at dawn"
+ >>> ciphertext, mac = cipher.encrypt_and_digest(plaintext)
+ >>> nonce = cipher.nonce
+ >>> # Deliver nonce, ciphertext and mac
+ ...
+ >>> cipher = AES.new(key, AES.MODE_OCB, nonce=nonce)
+ >>> try:
+ ...     plaintext = cipher.decrypt_and_verify(ciphertext, mac)
+ ... except ValueError:
+ ...     print("Invalid message")
+ ... else:
+ ...     print(plaintext)
+
+:undocumented: __package__
+
+.. _RFC7253: http://www.rfc-editor.org/info/rfc7253
+.. _free licenses: http://web.cs.ucdavis.edu/~rogaway/ocb/license.htm
+"""
+
+import struct
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import bord, _copy_bytes
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+from Crypto.Util.strxor import strxor
+
+from Crypto.Hash import BLAKE2s
+from Crypto.Random import get_random_bytes
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ create_string_buffer, get_raw_buffer,
+ SmartPointer, c_size_t, c_uint8_ptr,
+ is_buffer)
+
+_raw_ocb_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_ocb", """
+ int OCB_start_operation(void *cipher,
+ const uint8_t *offset_0,
+ size_t offset_0_len,
+ void **pState);
+ int OCB_encrypt(void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int OCB_decrypt(void *state,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int OCB_update(void *state,
+ const uint8_t *in,
+ size_t data_len);
+ int OCB_digest(void *state,
+ uint8_t *tag,
+ size_t tag_len);
+ int OCB_stop_operation(void *state);
+ """)
+
+
+class OcbMode(object):
+ """Offset Codebook (OCB) mode.
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, factory, nonce, mac_len, cipher_params):
+
+ if factory.block_size != 16:
+ raise ValueError("OCB mode is only available for ciphers"
+ " that operate on 128 bits blocks")
+
+ self.block_size = 16
+ """The block size of the underlying cipher, in bytes."""
+
+ self.nonce = _copy_bytes(None, None, nonce)
+ """Nonce used for this session."""
+ if len(nonce) not in range(1, 16):
+ raise ValueError("Nonce must be at most 15 bytes long")
+ if not is_buffer(nonce):
+ raise TypeError("Nonce must be bytes, bytearray or memoryview")
+
+ self._mac_len = mac_len
+ if not 8 <= mac_len <= 16:
+ raise ValueError("MAC tag must be between 8 and 16 bytes long")
+
+ # Cache for MAC tag
+ self._mac_tag = None
+
+ # Cache for unaligned associated data
+ self._cache_A = b""
+
+ # Cache for unaligned ciphertext/plaintext
+ self._cache_P = b""
+
+ # Allowed transitions after initialization
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ # Compute Offset_0
+ params_without_key = dict(cipher_params)
+ key = params_without_key.pop("key")
+ nonce = (struct.pack('B', self._mac_len << 4 & 0xFF) +
+ b'\x00' * (14 - len(nonce)) +
+ b'\x01' + self.nonce)
+
+ bottom_bits = bord(nonce[15]) & 0x3F # 6 bits, 0..63
+ top_bits = bord(nonce[15]) & 0xC0 # 2 bits
+
+ ktop_cipher = factory.new(key,
+ factory.MODE_ECB,
+ **params_without_key)
+ ktop = ktop_cipher.encrypt(struct.pack('15sB',
+ nonce[:15],
+ top_bits))
+
+ stretch = ktop + strxor(ktop[:8], ktop[1:9]) # 192 bits
+ offset_0 = long_to_bytes(bytes_to_long(stretch) >>
+ (64 - bottom_bits), 24)[8:]
+
+ # Create low-level cipher instance
+ raw_cipher = factory._create_base_cipher(cipher_params)
+ if cipher_params:
+ raise TypeError("Unknown keywords: " + str(cipher_params))
+
+ self._state = VoidPointer()
+ result = _raw_ocb_lib.OCB_start_operation(raw_cipher.get(),
+ offset_0,
+ c_size_t(len(offset_0)),
+ self._state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating the OCB mode"
+ % result)
+
+ # Ensure that object disposal of this Python object will (eventually)
+ # free the memory allocated by the raw library for the cipher mode
+ self._state = SmartPointer(self._state.get(),
+ _raw_ocb_lib.OCB_stop_operation)
+
+ # Memory allocated for the underlying block cipher is now owned
+ # by the cipher mode
+ raw_cipher.release()
+
+ def _update(self, assoc_data, assoc_data_len):
+ result = _raw_ocb_lib.OCB_update(self._state.get(),
+ c_uint8_ptr(assoc_data),
+ c_size_t(assoc_data_len))
+ if result:
+ raise ValueError("Error %d while computing MAC in OCB mode" % result)
+
+ def update(self, assoc_data):
+ """Process the associated data.
+
+ If there is any associated data, the caller has to invoke
+ this method one or more times, before using
+ ``decrypt`` or ``encrypt``.
+
+ *Associated data* is any data (e.g. packet headers) that
+ will not be encrypted and will be transmitted in the clear.
+ However, the receiver is still able to detect any modification to it.
+
+ If there is no associated data, this method must not be called.
+
+ The caller may split associated data in segments of any size, and
+ invoke this method multiple times, each time with the next segment.
+
+ :Parameters:
+ assoc_data : bytes/bytearray/memoryview
+ A piece of associated data.
+ """
+
+ if self.update not in self._next:
+ raise TypeError("update() can only be called"
+ " immediately after initialization")
+
+ self._next = [self.encrypt, self.decrypt, self.digest,
+ self.verify, self.update]
+
+ if len(self._cache_A) > 0:
+ filler = min(16 - len(self._cache_A), len(assoc_data))
+ self._cache_A += _copy_bytes(None, filler, assoc_data)
+ assoc_data = assoc_data[filler:]
+
+ if len(self._cache_A) < 16:
+ return self
+
+ # Clear the cache, and proceed with any other aligned data
+ self._cache_A, seg = b"", self._cache_A
+ self.update(seg)
+
+ update_len = len(assoc_data) // 16 * 16
+ self._cache_A = _copy_bytes(update_len, None, assoc_data)
+ self._update(assoc_data, update_len)
+ return self
+
+ def _transcrypt_aligned(self, in_data, in_data_len,
+ trans_func, trans_desc):
+
+ out_data = create_string_buffer(in_data_len)
+ result = trans_func(self._state.get(),
+ in_data,
+ out_data,
+ c_size_t(in_data_len))
+ if result:
+ raise ValueError("Error %d while %sing in OCB mode"
+ % (result, trans_desc))
+ return get_raw_buffer(out_data)
+
+ def _transcrypt(self, in_data, trans_func, trans_desc):
+ # Last piece to encrypt/decrypt
+ if in_data is None:
+ out_data = self._transcrypt_aligned(self._cache_P,
+ len(self._cache_P),
+ trans_func,
+ trans_desc)
+ self._cache_P = b""
+ return out_data
+
+ # Try to fill up the cache, if it already contains something
+ prefix = b""
+ if len(self._cache_P) > 0:
+ filler = min(16 - len(self._cache_P), len(in_data))
+ self._cache_P += _copy_bytes(None, filler, in_data)
+ in_data = in_data[filler:]
+
+ if len(self._cache_P) < 16:
+ # We could not manage to fill the cache, so there is certainly
+ # no output yet.
+ return b""
+
+ # Clear the cache, and proceed with any other aligned data
+ prefix = self._transcrypt_aligned(self._cache_P,
+ len(self._cache_P),
+ trans_func,
+ trans_desc)
+ self._cache_P = b""
+
+ # Process data in multiples of the block size
+ trans_len = len(in_data) // 16 * 16
+ result = self._transcrypt_aligned(c_uint8_ptr(in_data),
+ trans_len,
+ trans_func,
+ trans_desc)
+ if prefix:
+ result = prefix + result
+
+ # Left-over
+ self._cache_P = _copy_bytes(trans_len, None, in_data)
+
+ return result
+
+ def encrypt(self, plaintext=None):
+ """Encrypt the next piece of plaintext.
+
+ After the entire plaintext has been passed (but before `digest`),
+ you **must** call this method one last time with no arguments to collect
+ the final piece of ciphertext.
+
+ If possible, use the method `encrypt_and_digest` instead.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The next piece of data to encrypt or ``None`` to signify
+ that encryption has finished and that any remaining ciphertext
+ has to be produced.
+ :Return:
+ the ciphertext, as a byte string.
+ Its length may not match the length of the *plaintext*.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() can only be called after"
+ " initialization or an update()")
+
+ if plaintext is None:
+ self._next = [self.digest]
+ else:
+ self._next = [self.encrypt]
+ return self._transcrypt(plaintext, _raw_ocb_lib.OCB_encrypt, "encrypt")
+
+ def decrypt(self, ciphertext=None):
+ """Decrypt the next piece of ciphertext.
+
+ After the entire ciphertext has been passed (but before `verify`),
+ you **must** call this method one last time with no arguments to collect
+ the remaining piece of plaintext.
+
+ If possible, use the method `decrypt_and_verify` instead.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The next piece of data to decrypt or ``None`` to signify
+ that decryption has finished and that any remaining plaintext
+ has to be produced.
+ :Return:
+ the plaintext, as a byte string.
+ Its length may not match the length of the *ciphertext*.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() can only be called after"
+ " initialization or an update()")
+
+ if ciphertext is None:
+ self._next = [self.verify]
+ else:
+ self._next = [self.decrypt]
+ return self._transcrypt(ciphertext,
+ _raw_ocb_lib.OCB_decrypt,
+ "decrypt")
+
+ def _compute_mac_tag(self):
+
+ if self._mac_tag is not None:
+ return
+
+ if self._cache_A:
+ self._update(self._cache_A, len(self._cache_A))
+ self._cache_A = b""
+
+ mac_tag = create_string_buffer(16)
+ result = _raw_ocb_lib.OCB_digest(self._state.get(),
+ mac_tag,
+ c_size_t(len(mac_tag))
+ )
+ if result:
+ raise ValueError("Error %d while computing digest in OCB mode"
+ % result)
+ self._mac_tag = get_raw_buffer(mac_tag)[:self._mac_len]
+
+ def digest(self):
+ """Compute the *binary* MAC tag.
+
+ Call this method after the final `encrypt` (the one with no arguments)
+ to obtain the MAC tag.
+
+ The MAC tag is needed by the receiver to determine authenticity
+ of the message.
+
+ :Return: the MAC, as a byte string.
+ """
+
+ if self.digest not in self._next:
+ raise TypeError("digest() cannot be called now for this cipher")
+
+ assert(len(self._cache_P) == 0)
+
+ self._next = [self.digest]
+
+ if self._mac_tag is None:
+ self._compute_mac_tag()
+
+ return self._mac_tag
+
+ def hexdigest(self):
+ """Compute the *printable* MAC tag.
+
+ This method is like `digest`.
+
+ :Return: the MAC, as a hexadecimal string.
+ """
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def verify(self, received_mac_tag):
+ """Validate the *binary* MAC tag.
+
+ Call this method after the final `decrypt` (the one with no arguments)
+ to check if the message is authentic and valid.
+
+ :Parameters:
+ received_mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ if self.verify not in self._next:
+ raise TypeError("verify() cannot be called now for this cipher")
+
+ assert(len(self._cache_P) == 0)
+
+ self._next = [self.verify]
+
+ if self._mac_tag is None:
+ self._compute_mac_tag()
+
+ secret = get_random_bytes(16)
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=self._mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=received_mac_tag)
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Validate the *printable* MAC tag.
+
+ This method is like `verify`.
+
+ :Parameters:
+ hex_mac_tag : string
+ This is the *printable* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ self.verify(unhexlify(hex_mac_tag))
+
+ def encrypt_and_digest(self, plaintext):
+ """Encrypt the message and create the MAC tag in one step.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The entire message to encrypt.
+ :Return:
+ a tuple with two byte strings:
+
+ - the encrypted data
+ - the MAC
+ """
+
+ return self.encrypt(plaintext) + self.encrypt(), self.digest()
+
+ def decrypt_and_verify(self, ciphertext, received_mac_tag):
+ """Decrypted the message and verify its authenticity in one step.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The entire message to decrypt.
+ received_mac_tag : byte string
+ This is the *binary* MAC, as received from the sender.
+
+ :Return: the decrypted data (byte string).
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ plaintext = self.decrypt(ciphertext) + self.decrypt()
+ self.verify(received_mac_tag)
+ return plaintext
+
+
+def _create_ocb_cipher(factory, **kwargs):
+ """Create a new block cipher, configured in OCB mode.
+
+ :Parameters:
+ factory : module
+ A symmetric cipher module from `Crypto.Cipher`
+ (like `Crypto.Cipher.AES`).
+
+ :Keywords:
+ nonce : bytes/bytearray/memoryview
+ A value that must never be reused for any other encryption.
+ Its length can vary from 1 to 15 bytes.
+ If not specified, a random 15 bytes long nonce is generated.
+
+ mac_len : integer
+ Length of the MAC, in bytes.
+ It must be in the range ``[8..16]``.
+ The default is 16 (128 bits).
+
+ Any other keyword will be passed to the underlying block cipher.
+ See the relevant documentation for details (at least ``key`` will need
+ to be present).
+ """
+
+ try:
+ nonce = kwargs.pop("nonce", None)
+ if nonce is None:
+ nonce = get_random_bytes(15)
+ mac_len = kwargs.pop("mac_len", 16)
+ except KeyError as e:
+ raise TypeError("Keyword missing: " + str(e))
+
+ return OcbMode(factory, nonce, mac_len, kwargs)
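+
+
+# Editor's sketch (not part of the upstream module): a minimal example of how
+# OCB is typically driven through the public AES interface, assuming the
+# standard PyCryptodome entry points (AES.new with AES.MODE_OCB).
+if __name__ == "__main__":
+    from Crypto.Cipher import AES
+    from Crypto.Random import get_random_bytes
+
+    key = get_random_bytes(16)
+    enc = AES.new(key, AES.MODE_OCB)             # a random 15-byte nonce is generated
+    enc.update(b"header")                        # associated data: authenticated only
+    ct, tag = enc.encrypt_and_digest(b"secret payload")
+
+    dec = AES.new(key, AES.MODE_OCB, nonce=enc.nonce)
+    dec.update(b"header")
+    pt = dec.decrypt_and_verify(ct, tag)         # raises ValueError if tampered with
+    assert pt == b"secret payload"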
diff --git a/lib/Crypto/Cipher/_mode_ocb.pyi b/lib/Crypto/Cipher/_mode_ocb.pyi
new file mode 100644
index 0000000..a1909fc
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ocb.pyi
@@ -0,0 +1,36 @@
+from types import ModuleType
+from typing import Union, Any, Optional, Tuple, Dict, overload
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class OcbMode(object):
+ block_size: int
+ nonce: Buffer
+
+ def __init__(self,
+ factory: ModuleType,
+ nonce: Buffer,
+ mac_len: int,
+ cipher_params: Dict) -> None: ...
+
+ def update(self, assoc_data: Buffer) -> OcbMode: ...
+
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, received_mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+ def encrypt_and_digest(self,
+ plaintext: Buffer) -> Tuple[bytes, bytes]: ...
+ def decrypt_and_verify(self,
+ ciphertext: Buffer,
+ received_mac_tag: Buffer) -> bytes: ...
diff --git a/lib/Crypto/Cipher/_mode_ofb.py b/lib/Crypto/Cipher/_mode_ofb.py
new file mode 100644
index 0000000..958f6d0
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ofb.py
@@ -0,0 +1,282 @@
+# -*- coding: utf-8 -*-
+#
+# Cipher/mode_ofb.py : OFB mode
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""
+Output Feedback (OFB) mode.
+"""
+
+__all__ = ['OfbMode']
+
+from Crypto.Util.py3compat import _copy_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ create_string_buffer, get_raw_buffer,
+ SmartPointer, c_size_t, c_uint8_ptr,
+ is_writeable_buffer)
+
+from Crypto.Random import get_random_bytes
+
+raw_ofb_lib = load_pycryptodome_raw_lib("Crypto.Cipher._raw_ofb", """
+ int OFB_start_operation(void *cipher,
+ const uint8_t iv[],
+ size_t iv_len,
+ void **pResult);
+ int OFB_encrypt(void *ofbState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int OFB_decrypt(void *ofbState,
+ const uint8_t *in,
+ uint8_t *out,
+ size_t data_len);
+ int OFB_stop_operation(void *state);
+ """
+ )
+
+
+class OfbMode(object):
+ """*Output FeedBack (OFB)*.
+
+    This mode is very similar to CFB, but it
+    transforms the underlying block cipher into a stream cipher
+    whose keystream does not depend on the data being processed.
+
+    The keystream is the iterated block encryption of the
+    previous keystream block (the IV is used for the first block).
+
+ An Initialization Vector (*IV*) is required.
+
+ See `NIST SP800-38A`_ , Section 6.4.
+
+ .. _`NIST SP800-38A` : http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, block_cipher, iv):
+ """Create a new block cipher, configured in OFB mode.
+
+ :Parameters:
+ block_cipher : C pointer
+ A smart pointer to the low-level block cipher instance.
+
+ iv : bytes/bytearray/memoryview
+ The initialization vector to use for encryption or decryption.
+ It is as long as the cipher block.
+
+            **The IV must never be reused for any other
+            message**. It shall be a nonce or a random value.
+
+ Reusing the *IV* for encryptions performed with the same key
+ compromises confidentiality.
+ """
+
+ self._state = VoidPointer()
+ result = raw_ofb_lib.OFB_start_operation(block_cipher.get(),
+ c_uint8_ptr(iv),
+ c_size_t(len(iv)),
+ self._state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating the OFB mode"
+ % result)
+
+ # Ensure that object disposal of this Python object will (eventually)
+ # free the memory allocated by the raw library for the cipher mode
+ self._state = SmartPointer(self._state.get(),
+ raw_ofb_lib.OFB_stop_operation)
+
+        # Memory allocated for the underlying block cipher is now owned
+        # by the cipher mode
+ block_cipher.release()
+
+ self.block_size = len(iv)
+ """The block size of the underlying cipher, in bytes."""
+
+ self.iv = _copy_bytes(None, None, iv)
+ """The Initialization Vector originally used to create the object.
+ The value does not change."""
+
+ self.IV = self.iv
+ """Alias for `iv`"""
+
+ self._next = [ self.encrypt, self.decrypt ]
+
+ def encrypt(self, plaintext, output=None):
+ """Encrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ If ``output`` is ``None``, the ciphertext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() cannot be called after decrypt()")
+ self._next = [ self.encrypt ]
+
+ if output is None:
+ ciphertext = create_string_buffer(len(plaintext))
+ else:
+ ciphertext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(plaintext) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(plaintext))
+
+ result = raw_ofb_lib.OFB_encrypt(self._state.get(),
+ c_uint8_ptr(plaintext),
+ c_uint8_ptr(ciphertext),
+ c_size_t(len(plaintext)))
+ if result:
+ raise ValueError("Error %d while encrypting in OFB mode" % result)
+
+ if output is None:
+ return get_raw_buffer(ciphertext)
+ else:
+ return None
+
+ def decrypt(self, ciphertext, output=None):
+ """Decrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ It can be of any length.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext is written to.
+ If ``None``, the plaintext is returned.
+ :Return:
+ If ``output`` is ``None``, the plaintext is returned as ``bytes``.
+ Otherwise, ``None``.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() cannot be called after encrypt()")
+ self._next = [ self.decrypt ]
+
+ if output is None:
+ plaintext = create_string_buffer(len(ciphertext))
+ else:
+ plaintext = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+            if len(ciphertext) != len(output):
+                raise ValueError("output must have the same length as the input"
+                                 " (%d bytes)" % len(ciphertext))
+
+ result = raw_ofb_lib.OFB_decrypt(self._state.get(),
+ c_uint8_ptr(ciphertext),
+ c_uint8_ptr(plaintext),
+ c_size_t(len(ciphertext)))
+ if result:
+ raise ValueError("Error %d while decrypting in OFB mode" % result)
+
+ if output is None:
+ return get_raw_buffer(plaintext)
+ else:
+ return None
+
+
+def _create_ofb_cipher(factory, **kwargs):
+ """Instantiate a cipher object that performs OFB encryption/decryption.
+
+ :Parameters:
+ factory : module
+ The underlying block cipher, a module from ``Crypto.Cipher``.
+
+ :Keywords:
+ iv : bytes/bytearray/memoryview
+ The IV to use for OFB.
+
+ IV : bytes/bytearray/memoryview
+ Alias for ``iv``.
+
+ Any other keyword will be passed to the underlying block cipher.
+ See the relevant documentation for details (at least ``key`` will need
+ to be present).
+ """
+
+ cipher_state = factory._create_base_cipher(kwargs)
+ iv = kwargs.pop("IV", None)
+ IV = kwargs.pop("iv", None)
+
+ if (None, None) == (iv, IV):
+ iv = get_random_bytes(factory.block_size)
+ if iv is not None:
+ if IV is not None:
+ raise TypeError("You must either use 'iv' or 'IV', not both")
+ else:
+ iv = IV
+
+ if len(iv) != factory.block_size:
+ raise ValueError("Incorrect IV length (it must be %d bytes long)" %
+ factory.block_size)
+
+ if kwargs:
+ raise TypeError("Unknown parameters for OFB: %s" % str(kwargs))
+
+ return OfbMode(cipher_state, iv)
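+
+
+# Editor's sketch (not part of the upstream module): minimal OFB usage through
+# the public AES interface, assuming the standard entry points (AES.new with
+# AES.MODE_OFB). Only the key and the IV need to be shared with the receiver.
+if __name__ == "__main__":
+    from Crypto.Cipher import AES
+
+    key = get_random_bytes(16)
+    enc = AES.new(key, AES.MODE_OFB)             # a random 16-byte IV is generated
+    ct = enc.encrypt(b"streamed plaintext of any length")
+
+    dec = AES.new(key, AES.MODE_OFB, iv=enc.iv)
+    assert dec.decrypt(ct) == b"streamed plaintext of any length"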
diff --git a/lib/Crypto/Cipher/_mode_ofb.pyi b/lib/Crypto/Cipher/_mode_ofb.pyi
new file mode 100644
index 0000000..60f7f00
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_ofb.pyi
@@ -0,0 +1,25 @@
+from typing import Union, overload
+
+from Crypto.Util._raw_api import SmartPointer
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['OfbMode']
+
+class OfbMode(object):
+ block_size: int
+ iv: Buffer
+ IV: Buffer
+
+ def __init__(self,
+ block_cipher: SmartPointer,
+ iv: Buffer) -> None: ...
+ @overload
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def encrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+ @overload
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+ @overload
+ def decrypt(self, plaintext: Buffer, output: Union[bytearray, memoryview]) -> None: ...
+
diff --git a/lib/Crypto/Cipher/_mode_openpgp.py b/lib/Crypto/Cipher/_mode_openpgp.py
new file mode 100644
index 0000000..d079d59
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_openpgp.py
@@ -0,0 +1,206 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+OpenPGP mode.
+"""
+
+__all__ = ['OpenPgpMode']
+
+from Crypto.Util.py3compat import _copy_bytes
+from Crypto.Random import get_random_bytes
+
+class OpenPgpMode(object):
+ """OpenPGP mode.
+
+ This mode is a variant of CFB, and it is only used in PGP and
+ OpenPGP_ applications. If in doubt, use another mode.
+
+ An Initialization Vector (*IV*) is required.
+
+ Unlike CFB, the *encrypted* IV (not the IV itself) is
+ transmitted to the receiver.
+
+    The IV is a random data block. For legacy reasons, two of its bytes are
+    duplicated to act as a quick check on the correctness of the key; that
+    mechanism is now known to be insecure and is ignored. The encrypted IV is
+    therefore 2 bytes longer than the clean IV.
+
+ .. _OpenPGP: http://tools.ietf.org/html/rfc4880
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, factory, key, iv, cipher_params):
+
+ #: The block size of the underlying cipher, in bytes.
+ self.block_size = factory.block_size
+
+ self._done_first_block = False # True after the first encryption
+
+ # Instantiate a temporary cipher to process the IV
+ IV_cipher = factory.new(
+ key,
+ factory.MODE_CFB,
+ IV=b'\x00' * self.block_size,
+ segment_size=self.block_size * 8,
+ **cipher_params)
+
+ iv = _copy_bytes(None, None, iv)
+
+ # The cipher will be used for...
+ if len(iv) == self.block_size:
+ # ... encryption
+ self._encrypted_IV = IV_cipher.encrypt(iv + iv[-2:])
+ elif len(iv) == self.block_size + 2:
+ # ... decryption
+ self._encrypted_IV = iv
+ # Last two bytes are for a deprecated "quick check" feature that
+ # should not be used. (https://eprint.iacr.org/2005/033)
+ iv = IV_cipher.decrypt(iv)[:-2]
+ else:
+ raise ValueError("Length of IV must be %d or %d bytes"
+ " for MODE_OPENPGP"
+ % (self.block_size, self.block_size + 2))
+
+ self.iv = self.IV = iv
+
+ # Instantiate the cipher for the real PGP data
+ self._cipher = factory.new(
+ key,
+ factory.MODE_CFB,
+ IV=self._encrypted_IV[-self.block_size:],
+ segment_size=self.block_size * 8,
+ **cipher_params)
+
+ def encrypt(self, plaintext):
+ """Encrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have encrypted a message
+ you cannot encrypt (or decrypt) another message using the same
+ object.
+
+ The data to encrypt can be broken up in two or
+ more pieces and `encrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.encrypt(a) + c.encrypt(b)
+
+ is equivalent to:
+
+ >>> c.encrypt(a+b)
+
+ This function does not add any padding to the plaintext.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+
+ :Return:
+ the encrypted data, as a byte string.
+ It is as long as *plaintext* with one exception:
+ when encrypting the first message chunk,
+            the encrypted IV is prepended to the returned ciphertext.
+ """
+
+ res = self._cipher.encrypt(plaintext)
+ if not self._done_first_block:
+ res = self._encrypted_IV + res
+ self._done_first_block = True
+ return res
+
+ def decrypt(self, ciphertext):
+ """Decrypt data with the key and the parameters set at initialization.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ The data to decrypt can be broken up in two or
+ more pieces and `decrypt` can be called multiple times.
+
+ That is, the statement:
+
+ >>> c.decrypt(a) + c.decrypt(b)
+
+ is equivalent to:
+
+ >>> c.decrypt(a+b)
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+
+ :Return: the decrypted data (byte string).
+ """
+
+ return self._cipher.decrypt(ciphertext)
+
+
+def _create_openpgp_cipher(factory, **kwargs):
+ """Create a new block cipher, configured in OpenPGP mode.
+
+ :Parameters:
+ factory : module
+        A symmetric cipher module from `Crypto.Cipher` (like `Crypto.Cipher.AES`).
+
+ :Keywords:
+ key : bytes/bytearray/memoryview
+ The secret key to use in the symmetric cipher.
+
+ IV : bytes/bytearray/memoryview
+ The initialization vector to use for encryption or decryption.
+
+ For encryption, the IV must be as long as the cipher block size.
+
+ For decryption, it must be 2 bytes longer (it is actually the
+ *encrypted* IV which was prefixed to the ciphertext).
+ """
+
+ iv = kwargs.pop("IV", None)
+ IV = kwargs.pop("iv", None)
+
+ if (None, None) == (iv, IV):
+ iv = get_random_bytes(factory.block_size)
+ if iv is not None:
+ if IV is not None:
+ raise TypeError("You must either use 'iv' or 'IV', not both")
+ else:
+ iv = IV
+
+ try:
+ key = kwargs.pop("key")
+ except KeyError as e:
+ raise TypeError("Missing component: " + str(e))
+
+ return OpenPgpMode(factory, key, iv, kwargs)
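+
+
+# Editor's sketch (not part of the upstream module): minimal OpenPGP-mode usage,
+# assuming the standard AES entry points. The first encrypt() call prepends the
+# encrypted IV (block_size + 2 = 18 bytes for AES), and that prefix is what the
+# decrypting side passes back as the IV.
+if __name__ == "__main__":
+    from Crypto.Cipher import AES
+
+    key = get_random_bytes(16)
+    enc = AES.new(key, AES.MODE_OPENPGP)         # a random IV is generated
+    ct = enc.encrypt(b"attack at dawn")          # 18-byte encrypted IV + ciphertext
+
+    eiv, body = ct[:18], ct[18:]
+    dec = AES.new(key, AES.MODE_OPENPGP, iv=eiv)
+    assert dec.decrypt(body) == b"attack at dawn"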
diff --git a/lib/Crypto/Cipher/_mode_openpgp.pyi b/lib/Crypto/Cipher/_mode_openpgp.pyi
new file mode 100644
index 0000000..14b8105
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_openpgp.pyi
@@ -0,0 +1,20 @@
+from types import ModuleType
+from typing import Union, Dict
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['OpenPgpMode']
+
+class OpenPgpMode(object):
+ block_size: int
+ iv: Union[bytes, bytearray, memoryview]
+ IV: Union[bytes, bytearray, memoryview]
+
+ def __init__(self,
+ factory: ModuleType,
+ key: Buffer,
+ iv: Buffer,
+ cipher_params: Dict) -> None: ...
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+
diff --git a/lib/Crypto/Cipher/_mode_siv.py b/lib/Crypto/Cipher/_mode_siv.py
new file mode 100644
index 0000000..d1eca2a
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_siv.py
@@ -0,0 +1,392 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+Synthetic Initialization Vector (SIV) mode.
+"""
+
+__all__ = ['SivMode']
+
+from binascii import hexlify, unhexlify
+
+from Crypto.Util.py3compat import bord, _copy_bytes
+
+from Crypto.Util._raw_api import is_buffer
+
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+from Crypto.Protocol.KDF import _S2V
+from Crypto.Hash import BLAKE2s
+from Crypto.Random import get_random_bytes
+
+
+class SivMode(object):
+ """Synthetic Initialization Vector (SIV).
+
+ This is an Authenticated Encryption with Associated Data (`AEAD`_) mode.
+ It provides both confidentiality and authenticity.
+
+ The header of the message may be left in the clear, if needed, and it will
+ still be subject to authentication. The decryption step tells the receiver
+    if the message comes from a source that really knows the secret key.
+ Additionally, decryption detects if any part of the message - including the
+ header - has been modified or corrupted.
+
+ Unlike other AEAD modes such as CCM, EAX or GCM, accidental reuse of a
+ nonce is not catastrophic for the confidentiality of the message. The only
+ effect is that an attacker can tell when the same plaintext (and same
+ associated data) is protected with the same key.
+
+ The length of the MAC is fixed to the block size of the underlying cipher.
+ The key size is twice the length of the key of the underlying cipher.
+
+ This mode is only available for AES ciphers.
+
+ +--------------------+---------------+-------------------+
+ | Cipher | SIV MAC size | SIV key length |
+ | | (bytes) | (bytes) |
+ +====================+===============+===================+
+ | AES-128 | 16 | 32 |
+ +--------------------+---------------+-------------------+
+ | AES-192 | 16 | 48 |
+ +--------------------+---------------+-------------------+
+ | AES-256 | 16 | 64 |
+ +--------------------+---------------+-------------------+
+
+ See `RFC5297`_ and the `original paper`__.
+
+ .. _RFC5297: https://tools.ietf.org/html/rfc5297
+ .. _AEAD: http://blog.cryptographyengineering.com/2012/05/how-to-choose-authenticated-encryption.html
+ .. __: http://www.cs.ucdavis.edu/~rogaway/papers/keywrap.pdf
+
+ :undocumented: __init__
+ """
+
+ def __init__(self, factory, key, nonce, kwargs):
+
+ self.block_size = factory.block_size
+ """The block size of the underlying cipher, in bytes."""
+
+ self._factory = factory
+
+ self._cipher_params = kwargs
+
+ if len(key) not in (32, 48, 64):
+ raise ValueError("Incorrect key length (%d bytes)" % len(key))
+
+ if nonce is not None:
+ if not is_buffer(nonce):
+ raise TypeError("When provided, the nonce must be bytes, bytearray or memoryview")
+
+ if len(nonce) == 0:
+ raise ValueError("When provided, the nonce must be non-empty")
+
+ self.nonce = _copy_bytes(None, None, nonce)
+ """Public attribute is only available in case of non-deterministic
+ encryption."""
+
+ subkey_size = len(key) // 2
+
+ self._mac_tag = None # Cache for MAC tag
+ self._kdf = _S2V(key[:subkey_size],
+ ciphermod=factory,
+ cipher_params=self._cipher_params)
+ self._subkey_cipher = key[subkey_size:]
+
+ # Purely for the purpose of verifying that cipher_params are OK
+ factory.new(key[:subkey_size], factory.MODE_ECB, **kwargs)
+
+ # Allowed transitions after initialization
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ def _create_ctr_cipher(self, v):
+ """Create a new CTR cipher from V in SIV mode"""
+
+ v_int = bytes_to_long(v)
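+        # Per RFC 5297, bits 31 and 63 of V are cleared ("clamping") before it
+        # is used as the initial counter block.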
+ q = v_int & 0xFFFFFFFFFFFFFFFF7FFFFFFF7FFFFFFF
+ return self._factory.new(
+ self._subkey_cipher,
+ self._factory.MODE_CTR,
+ initial_value=q,
+ nonce=b"",
+ **self._cipher_params)
+
+ def update(self, component):
+ """Protect one associated data component
+
+ For SIV, the associated data is a sequence (*vector*) of non-empty
+ byte strings (*components*).
+
+ This method consumes the next component. It must be called
+        once for each of the components that constitute the associated data.
+
+ Note that the components have clear boundaries, so that:
+
+ >>> cipher.update(b"builtin")
+ >>> cipher.update(b"securely")
+
+ is not equivalent to:
+
+ >>> cipher.update(b"built")
+ >>> cipher.update(b"insecurely")
+
+ If there is no associated data, this method must not be called.
+
+ :Parameters:
+ component : bytes/bytearray/memoryview
+ The next associated data component.
+ """
+
+ if self.update not in self._next:
+ raise TypeError("update() can only be called"
+ " immediately after initialization")
+
+ self._next = [self.update, self.encrypt, self.decrypt,
+ self.digest, self.verify]
+
+ return self._kdf.update(component)
+
+ def encrypt(self, plaintext):
+ """
+ For SIV, encryption and MAC authentication must take place at the same
+ point. This method shall not be used.
+
+ Use `encrypt_and_digest` instead.
+ """
+
+ raise TypeError("encrypt() not allowed for SIV mode."
+ " Use encrypt_and_digest() instead.")
+
+ def decrypt(self, ciphertext):
+ """
+ For SIV, decryption and verification must take place at the same
+ point. This method shall not be used.
+
+ Use `decrypt_and_verify` instead.
+ """
+
+ raise TypeError("decrypt() not allowed for SIV mode."
+ " Use decrypt_and_verify() instead.")
+
+ def digest(self):
+ """Compute the *binary* MAC tag.
+
+ The caller invokes this function at the very end.
+
+ This method returns the MAC that shall be sent to the receiver,
+ together with the ciphertext.
+
+ :Return: the MAC, as a byte string.
+ """
+
+ if self.digest not in self._next:
+ raise TypeError("digest() cannot be called when decrypting"
+ " or validating a message")
+ self._next = [self.digest]
+ if self._mac_tag is None:
+ self._mac_tag = self._kdf.derive()
+ return self._mac_tag
+
+ def hexdigest(self):
+ """Compute the *printable* MAC tag.
+
+ This method is like `digest`.
+
+ :Return: the MAC, as a hexadecimal string.
+ """
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def verify(self, received_mac_tag):
+ """Validate the *binary* MAC tag.
+
+ The caller invokes this function at the very end.
+
+ This method checks if the decrypted message is indeed valid
+ (that is, if the key is correct) and it has not been
+ tampered with while in transit.
+
+ :Parameters:
+ received_mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ if self.verify not in self._next:
+ raise TypeError("verify() cannot be called"
+ " when encrypting a message")
+ self._next = [self.verify]
+
+ if self._mac_tag is None:
+ self._mac_tag = self._kdf.derive()
+
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=self._mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=received_mac_tag)
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Validate the *printable* MAC tag.
+
+ This method is like `verify`.
+
+ :Parameters:
+ hex_mac_tag : string
+ This is the *printable* MAC, as received from the sender.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ self.verify(unhexlify(hex_mac_tag))
+
+ def encrypt_and_digest(self, plaintext, output=None):
+ """Perform encrypt() and digest() in one step.
+
+ :Parameters:
+ plaintext : bytes/bytearray/memoryview
+ The piece of data to encrypt.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the ciphertext must be written to.
+ If ``None``, the ciphertext is returned.
+ :Return:
+ a tuple with two items:
+
+ - the ciphertext, as ``bytes``
+ - the MAC tag, as ``bytes``
+
+ The first item becomes ``None`` when the ``output`` parameter
+ specified a location for the result.
+ """
+
+ if self.encrypt not in self._next:
+ raise TypeError("encrypt() can only be called after"
+ " initialization or an update()")
+
+ self._next = [ self.digest ]
+
+ # Compute V (MAC)
+ if hasattr(self, 'nonce'):
+ self._kdf.update(self.nonce)
+ self._kdf.update(plaintext)
+ self._mac_tag = self._kdf.derive()
+
+ cipher = self._create_ctr_cipher(self._mac_tag)
+
+ return cipher.encrypt(plaintext, output=output), self._mac_tag
+
+ def decrypt_and_verify(self, ciphertext, mac_tag, output=None):
+ """Perform decryption and verification in one step.
+
+ A cipher object is stateful: once you have decrypted a message
+ you cannot decrypt (or encrypt) another message with the same
+ object.
+
+ You cannot reuse an object for encrypting
+ or decrypting other data with the same key.
+
+ This function does not remove any padding from the plaintext.
+
+ :Parameters:
+ ciphertext : bytes/bytearray/memoryview
+ The piece of data to decrypt.
+ It can be of any length.
+ mac_tag : bytes/bytearray/memoryview
+ This is the *binary* MAC, as received from the sender.
+ :Keywords:
+ output : bytearray/memoryview
+ The location where the plaintext must be written to.
+ If ``None``, the plaintext is returned.
+ :Return: the plaintext as ``bytes`` or ``None`` when the ``output``
+ parameter specified a location for the result.
+ :Raises ValueError:
+ if the MAC does not match. The message has been tampered with
+ or the key is incorrect.
+ """
+
+ if self.decrypt not in self._next:
+ raise TypeError("decrypt() can only be called"
+ " after initialization or an update()")
+ self._next = [ self.verify ]
+
+ # Take the MAC and start the cipher for decryption
+ self._cipher = self._create_ctr_cipher(mac_tag)
+
+ plaintext = self._cipher.decrypt(ciphertext, output=output)
+
+ if hasattr(self, 'nonce'):
+ self._kdf.update(self.nonce)
+ self._kdf.update(plaintext if output is None else output)
+ self.verify(mac_tag)
+
+ return plaintext
+
+
+def _create_siv_cipher(factory, **kwargs):
+ """Create a new block cipher, configured in
+    Synthetic Initialization Vector (SIV) mode.
+
+ :Parameters:
+
+ factory : object
+ A symmetric cipher module from `Crypto.Cipher`
+ (like `Crypto.Cipher.AES`).
+
+ :Keywords:
+
+ key : bytes/bytearray/memoryview
+ The secret key to use in the symmetric cipher.
+ It must be 32, 48 or 64 bytes long.
+ If AES is the chosen cipher, the variants *AES-128*,
+        *AES-192* or *AES-256* will be used internally.
+
+ nonce : bytes/bytearray/memoryview
+ For deterministic encryption, it is not present.
+
+ Otherwise, it is a value that must never be reused
+        for encrypting messages under this key.
+
+ There are no restrictions on its length,
+ but it is recommended to use at least 16 bytes.
+ """
+
+ try:
+ key = kwargs.pop("key")
+ except KeyError as e:
+ raise TypeError("Missing parameter: " + str(e))
+
+ nonce = kwargs.pop("nonce", None)
+
+ return SivMode(factory, key, nonce, kwargs)
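+
+
+# Editor's sketch (not part of the upstream module): minimal SIV usage, assuming
+# the standard AES entry points. Note the double-length key (32 bytes selects
+# AES-128 internally) and that encryption must go through encrypt_and_digest().
+if __name__ == "__main__":
+    from Crypto.Cipher import AES
+
+    key = get_random_bytes(32)
+    nonce = get_random_bytes(16)
+
+    enc = AES.new(key, AES.MODE_SIV, nonce=nonce)
+    enc.update(b"associated data")
+    ct, tag = enc.encrypt_and_digest(b"message")
+
+    dec = AES.new(key, AES.MODE_SIV, nonce=nonce)
+    dec.update(b"associated data")
+    assert dec.decrypt_and_verify(ct, tag) == b"message"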
diff --git a/lib/Crypto/Cipher/_mode_siv.pyi b/lib/Crypto/Cipher/_mode_siv.pyi
new file mode 100644
index 0000000..2934f23
--- /dev/null
+++ b/lib/Crypto/Cipher/_mode_siv.pyi
@@ -0,0 +1,38 @@
+from types import ModuleType
+from typing import Union, Tuple, Dict, Optional, overload
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+__all__ = ['SivMode']
+
+class SivMode(object):
+ block_size: int
+ nonce: bytes
+
+ def __init__(self,
+ factory: ModuleType,
+ key: Buffer,
+ nonce: Buffer,
+ kwargs: Dict) -> None: ...
+
+ def update(self, component: Buffer) -> SivMode: ...
+
+ def encrypt(self, plaintext: Buffer) -> bytes: ...
+ def decrypt(self, plaintext: Buffer) -> bytes: ...
+
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, received_mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer) -> Tuple[bytes, bytes]: ...
+ @overload
+ def encrypt_and_digest(self,
+ plaintext: Buffer,
+ output: Buffer) -> Tuple[None, bytes]: ...
+ def decrypt_and_verify(self,
+ ciphertext: Buffer,
+ received_mac_tag: Buffer,
+ output: Optional[Union[bytearray, memoryview]] = ...) -> bytes: ...
diff --git a/lib/Crypto/Cipher/_pkcs1_decode.abi3.so b/lib/Crypto/Cipher/_pkcs1_decode.abi3.so
new file mode 100755
index 0000000..551a4c2
--- /dev/null
+++ b/lib/Crypto/Cipher/_pkcs1_decode.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_aes.abi3.so b/lib/Crypto/Cipher/_raw_aes.abi3.so
new file mode 100755
index 0000000..42c7919
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_aes.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_aesni.abi3.so b/lib/Crypto/Cipher/_raw_aesni.abi3.so
new file mode 100755
index 0000000..35701e4
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_aesni.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_arc2.abi3.so b/lib/Crypto/Cipher/_raw_arc2.abi3.so
new file mode 100755
index 0000000..acca255
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_arc2.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_blowfish.abi3.so b/lib/Crypto/Cipher/_raw_blowfish.abi3.so
new file mode 100755
index 0000000..2f4ddcf
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_blowfish.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_cast.abi3.so b/lib/Crypto/Cipher/_raw_cast.abi3.so
new file mode 100755
index 0000000..e8e5722
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_cast.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_cbc.abi3.so b/lib/Crypto/Cipher/_raw_cbc.abi3.so
new file mode 100755
index 0000000..7ef5187
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_cbc.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_cfb.abi3.so b/lib/Crypto/Cipher/_raw_cfb.abi3.so
new file mode 100755
index 0000000..45e31e1
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_cfb.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_ctr.abi3.so b/lib/Crypto/Cipher/_raw_ctr.abi3.so
new file mode 100755
index 0000000..b12ffc4
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_ctr.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_des.abi3.so b/lib/Crypto/Cipher/_raw_des.abi3.so
new file mode 100755
index 0000000..84d47d1
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_des.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_des3.abi3.so b/lib/Crypto/Cipher/_raw_des3.abi3.so
new file mode 100755
index 0000000..c11ef96
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_des3.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_ecb.abi3.so b/lib/Crypto/Cipher/_raw_ecb.abi3.so
new file mode 100755
index 0000000..44f0b0d
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_ecb.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_eksblowfish.abi3.so b/lib/Crypto/Cipher/_raw_eksblowfish.abi3.so
new file mode 100755
index 0000000..acb074f
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_eksblowfish.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_ocb.abi3.so b/lib/Crypto/Cipher/_raw_ocb.abi3.so
new file mode 100755
index 0000000..0c88d18
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_ocb.abi3.so
Binary files differ
diff --git a/lib/Crypto/Cipher/_raw_ofb.abi3.so b/lib/Crypto/Cipher/_raw_ofb.abi3.so
new file mode 100755
index 0000000..eeb94e7
--- /dev/null
+++ b/lib/Crypto/Cipher/_raw_ofb.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/BLAKE2b.py b/lib/Crypto/Hash/BLAKE2b.py
new file mode 100644
index 0000000..a00e0b4
--- /dev/null
+++ b/lib/Crypto/Hash/BLAKE2b.py
@@ -0,0 +1,247 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import bord, tobytes
+
+from Crypto.Random import get_random_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_blake2b_lib = load_pycryptodome_raw_lib("Crypto.Hash._BLAKE2b",
+ """
+ int blake2b_init(void **state,
+ const uint8_t *key,
+ size_t key_size,
+ size_t digest_size);
+ int blake2b_destroy(void *state);
+ int blake2b_update(void *state,
+ const uint8_t *buf,
+ size_t len);
+ int blake2b_digest(const void *state,
+ uint8_t digest[64]);
+ int blake2b_copy(const void *src, void *dst);
+ """)
+
+
+class BLAKE2b_Hash(object):
+ """A BLAKE2b hash object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 64
+
+ def __init__(self, data, key, digest_bytes, update_after_digest):
+
+ # The size of the resulting hash in bytes.
+ self.digest_size = digest_bytes
+
+ self._update_after_digest = update_after_digest
+ self._digest_done = False
+
+ # See https://tools.ietf.org/html/rfc7693
+ if digest_bytes in (20, 32, 48, 64) and not key:
+ self.oid = "1.3.6.1.4.1.1722.12.2.1." + str(digest_bytes)
+
+ state = VoidPointer()
+ result = _raw_blake2b_lib.blake2b_init(state.address_of(),
+ c_uint8_ptr(key),
+ c_size_t(len(key)),
+ c_size_t(digest_bytes)
+ )
+ if result:
+ raise ValueError("Error %d while instantiating BLAKE2b" % result)
+ self._state = SmartPointer(state.get(),
+ _raw_blake2b_lib.blake2b_destroy)
+ if data:
+ self.update(data)
+
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (bytes/bytearray/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._digest_done and not self._update_after_digest:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_blake2b_lib.blake2b_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while hashing BLAKE2b data" % result)
+ return self
+
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(64)
+ result = _raw_blake2b_lib.blake2b_digest(self._state.get(),
+ bfr)
+ if result:
+ raise ValueError("Error %d while creating BLAKE2b digest" % result)
+
+ self._digest_done = True
+
+ return get_raw_buffer(bfr)[:self.digest_size]
+
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in tuple(self.digest())])
+
+
+ def verify(self, mac_tag):
+ """Verify that a given **binary** MAC (computed by another party)
+ is valid.
+
+ Args:
+ mac_tag (bytes/bytearray/memoryview): the expected MAC of the message.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ secret = get_random_bytes(16)
+
+ mac1 = new(digest_bits=160, key=secret, data=mac_tag)
+ mac2 = new(digest_bits=160, key=secret, data=self.digest())
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+
+ def hexverify(self, hex_mac_tag):
+ """Verify that a given **printable** MAC (computed by another party)
+ is valid.
+
+ Args:
+ hex_mac_tag (string): the expected MAC of the message, as a hexadecimal string.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ self.verify(unhexlify(tobytes(hex_mac_tag)))
+
+
+ def new(self, **kwargs):
+ """Return a new instance of a BLAKE2b hash object.
+ See :func:`new`.
+ """
+
+ if "digest_bytes" not in kwargs and "digest_bits" not in kwargs:
+ kwargs["digest_bytes"] = self.digest_size
+
+ return new(**kwargs)
+
+
+def new(**kwargs):
+ """Create a new hash object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`BLAKE2b_Hash.update`.
+ digest_bytes (integer):
+ Optional. The size of the digest, in bytes (1 to 64). Default is 64.
+ digest_bits (integer):
+ Optional and alternative to ``digest_bytes``.
+ The size of the digest, in bits (8 to 512, in steps of 8).
+ Default is 512.
+ key (bytes/bytearray/memoryview):
+ Optional. The key to use to compute the MAC (1 to 64 bytes).
+ If not specified, no key will be used.
+ update_after_digest (boolean):
+ Optional. By default, a hash object cannot be updated anymore after
+ the digest is computed. When this flag is ``True``, such check
+ is no longer enforced.
+
+ Returns:
+ A :class:`BLAKE2b_Hash` hash object
+ """
+
+ data = kwargs.pop("data", None)
+ update_after_digest = kwargs.pop("update_after_digest", False)
+
+ digest_bytes = kwargs.pop("digest_bytes", None)
+ digest_bits = kwargs.pop("digest_bits", None)
+ if None not in (digest_bytes, digest_bits):
+ raise TypeError("Only one digest parameter must be provided")
+ if (None, None) == (digest_bytes, digest_bits):
+ digest_bytes = 64
+ if digest_bytes is not None:
+ if not (1 <= digest_bytes <= 64):
+ raise ValueError("'digest_bytes' not in range 1..64")
+ else:
+ if not (8 <= digest_bits <= 512) or (digest_bits % 8):
+ raise ValueError("'digest_bytes' not in range 8..512, "
+ "with steps of 8")
+ digest_bytes = digest_bits // 8
+
+ key = kwargs.pop("key", b"")
+ if len(key) > 64:
+ raise ValueError("BLAKE2s key cannot exceed 64 bytes")
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return BLAKE2b_Hash(data, key, digest_bytes, update_after_digest)
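+
+
+# Editor's sketch (not part of the upstream module): the two typical ways of
+# using this module, plain hashing and keyed MAC, as accepted by new() above.
+if __name__ == "__main__":
+    # Plain 512-bit hash, fed in chunks
+    h = new(digest_bits=512)
+    h.update(b"first chunk")
+    h.update(b"second chunk")
+    print(h.hexdigest())
+
+    # Keyed MAC: verify() raises ValueError on a mismatching tag
+    mac = new(digest_bits=256, key=b"secret key", data=b"message")
+    tag = mac.digest()
+    new(digest_bits=256, key=b"secret key", data=b"message").verify(tag)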
diff --git a/lib/Crypto/Hash/BLAKE2b.pyi b/lib/Crypto/Hash/BLAKE2b.pyi
new file mode 100644
index 0000000..d37c374
--- /dev/null
+++ b/lib/Crypto/Hash/BLAKE2b.pyi
@@ -0,0 +1,32 @@
+from typing import Any, Union
+from types import ModuleType
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class BLAKE2b_Hash(object):
+ block_size: int
+ digest_size: int
+ oid: str
+
+ def __init__(self,
+ data: Buffer,
+ key: Buffer,
+                 digest_bytes: int,
+ update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> BLAKE2b_Hash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+ def new(self,
+ data: Buffer = ...,
+ digest_bytes: int = ...,
+ digest_bits: int = ...,
+ key: Buffer = ...,
+ update_after_digest: bool = ...) -> BLAKE2b_Hash: ...
+
+def new(data: Buffer = ...,
+ digest_bytes: int = ...,
+ digest_bits: int = ...,
+ key: Buffer = ...,
+ update_after_digest: bool = ...) -> BLAKE2b_Hash: ...
diff --git a/lib/Crypto/Hash/BLAKE2s.py b/lib/Crypto/Hash/BLAKE2s.py
new file mode 100644
index 0000000..9b25c4a
--- /dev/null
+++ b/lib/Crypto/Hash/BLAKE2s.py
@@ -0,0 +1,247 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import bord, tobytes
+
+from Crypto.Random import get_random_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_blake2s_lib = load_pycryptodome_raw_lib("Crypto.Hash._BLAKE2s",
+ """
+ int blake2s_init(void **state,
+ const uint8_t *key,
+ size_t key_size,
+ size_t digest_size);
+ int blake2s_destroy(void *state);
+ int blake2s_update(void *state,
+ const uint8_t *buf,
+ size_t len);
+ int blake2s_digest(const void *state,
+ uint8_t digest[32]);
+ int blake2s_copy(const void *src, void *dst);
+ """)
+
+
+class BLAKE2s_Hash(object):
+ """A BLAKE2s hash object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 32
+
+ def __init__(self, data, key, digest_bytes, update_after_digest):
+
+ # The size of the resulting hash in bytes.
+ self.digest_size = digest_bytes
+
+ self._update_after_digest = update_after_digest
+ self._digest_done = False
+
+ # See https://tools.ietf.org/html/rfc7693
+ if digest_bytes in (16, 20, 28, 32) and not key:
+ self.oid = "1.3.6.1.4.1.1722.12.2.2." + str(digest_bytes)
+
+ state = VoidPointer()
+ result = _raw_blake2s_lib.blake2s_init(state.address_of(),
+ c_uint8_ptr(key),
+ c_size_t(len(key)),
+ c_size_t(digest_bytes)
+ )
+ if result:
+ raise ValueError("Error %d while instantiating BLAKE2s" % result)
+ self._state = SmartPointer(state.get(),
+ _raw_blake2s_lib.blake2s_destroy)
+ if data:
+ self.update(data)
+
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._digest_done and not self._update_after_digest:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_blake2s_lib.blake2s_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while hashing BLAKE2s data" % result)
+ return self
+
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(32)
+ result = _raw_blake2s_lib.blake2s_digest(self._state.get(),
+ bfr)
+ if result:
+ raise ValueError("Error %d while creating BLAKE2s digest" % result)
+
+ self._digest_done = True
+
+ return get_raw_buffer(bfr)[:self.digest_size]
+
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in tuple(self.digest())])
+
+
+ def verify(self, mac_tag):
+ """Verify that a given **binary** MAC (computed by another party)
+ is valid.
+
+ Args:
+ mac_tag (byte string/byte array/memoryview): the expected MAC of the message.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ secret = get_random_bytes(16)
+
+ mac1 = new(digest_bits=160, key=secret, data=mac_tag)
+ mac2 = new(digest_bits=160, key=secret, data=self.digest())
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+
+ def hexverify(self, hex_mac_tag):
+ """Verify that a given **printable** MAC (computed by another party)
+ is valid.
+
+ Args:
+ hex_mac_tag (string): the expected MAC of the message, as a hexadecimal string.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ self.verify(unhexlify(tobytes(hex_mac_tag)))
+
+
+ def new(self, **kwargs):
+ """Return a new instance of a BLAKE2s hash object.
+ See :func:`new`.
+ """
+
+ if "digest_bytes" not in kwargs and "digest_bits" not in kwargs:
+ kwargs["digest_bytes"] = self.digest_size
+
+ return new(**kwargs)
+
+
+def new(**kwargs):
+ """Create a new hash object.
+
+ Args:
+ data (byte string/byte array/memoryview):
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`BLAKE2s_Hash.update`.
+ digest_bytes (integer):
+ Optional. The size of the digest, in bytes (1 to 32). Default is 32.
+ digest_bits (integer):
+ Optional and alternative to ``digest_bytes``.
+ The size of the digest, in bits (8 to 256, in steps of 8).
+ Default is 256.
+ key (byte string):
+ Optional. The key to use to compute the MAC (1 to 64 bytes).
+ If not specified, no key will be used.
+ update_after_digest (boolean):
+ Optional. By default, a hash object cannot be updated anymore after
+ the digest is computed. When this flag is ``True``, such check
+ is no longer enforced.
+
+ Returns:
+ A :class:`BLAKE2s_Hash` hash object
+ """
+
+ data = kwargs.pop("data", None)
+ update_after_digest = kwargs.pop("update_after_digest", False)
+
+ digest_bytes = kwargs.pop("digest_bytes", None)
+ digest_bits = kwargs.pop("digest_bits", None)
+ if None not in (digest_bytes, digest_bits):
+ raise TypeError("Only one digest parameter must be provided")
+ if (None, None) == (digest_bytes, digest_bits):
+ digest_bytes = 32
+ if digest_bytes is not None:
+ if not (1 <= digest_bytes <= 32):
+ raise ValueError("'digest_bytes' not in range 1..32")
+ else:
+ if not (8 <= digest_bits <= 256) or (digest_bits % 8):
+ raise ValueError("'digest_bytes' not in range 8..256, "
+ "with steps of 8")
+ digest_bytes = digest_bits // 8
+
+ key = kwargs.pop("key", b"")
+ if len(key) > 32:
+ raise ValueError("BLAKE2s key cannot exceed 32 bytes")
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return BLAKE2s_Hash(data, key, digest_bytes, update_after_digest)
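+
+
+# Editor's sketch (not part of the upstream module): BLAKE2s used as a keyed
+# MAC, mirroring how the cipher modes in this package compare authentication
+# tags.
+if __name__ == "__main__":
+    secret = get_random_bytes(16)
+
+    mac = new(digest_bits=256, key=secret, data=b"message")
+    tag = mac.digest()
+
+    check = new(digest_bits=256, key=secret, data=b"message")
+    check.verify(tag)            # raises ValueError if the tag does not match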
diff --git a/lib/Crypto/Hash/BLAKE2s.pyi b/lib/Crypto/Hash/BLAKE2s.pyi
new file mode 100644
index 0000000..374b3a4
--- /dev/null
+++ b/lib/Crypto/Hash/BLAKE2s.pyi
@@ -0,0 +1,26 @@
+from typing import Any, Union
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class BLAKE2s_Hash(object):
+ block_size: int
+ digest_size: int
+ oid: str
+
+ def __init__(self,
+ data: Buffer,
+ key: Buffer,
+                 digest_bytes: int,
+ update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> BLAKE2s_Hash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+ def new(self, **kwargs: Any) -> BLAKE2s_Hash: ...
+
+def new(data: Buffer = ...,
+ digest_bytes: int = ...,
+ digest_bits: int = ...,
+ key: Buffer = ...,
+ update_after_digest: bool = ...) -> BLAKE2s_Hash: ...
diff --git a/lib/Crypto/Hash/CMAC.py b/lib/Crypto/Hash/CMAC.py
new file mode 100644
index 0000000..7585617
--- /dev/null
+++ b/lib/Crypto/Hash/CMAC.py
@@ -0,0 +1,302 @@
+# -*- coding: utf-8 -*-
+#
+# Hash/CMAC.py - Implements the CMAC algorithm
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from binascii import unhexlify
+
+from Crypto.Hash import BLAKE2s
+from Crypto.Util.strxor import strxor
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+from Crypto.Util.py3compat import bord, tobytes, _copy_bytes
+from Crypto.Random import get_random_bytes
+
+
+# The size of the authentication tag produced by the MAC.
+digest_size = None
+
+
+def _shift_bytes(bs, xor_lsb=0):
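+    # Left-shift the big-endian byte string by one bit and XOR ``xor_lsb`` into
+    # the lowest byte: the GF(2^b) "doubling" used for subkey derivation in
+    # NIST SP 800-38B.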
+ num = (bytes_to_long(bs) << 1) ^ xor_lsb
+ return long_to_bytes(num, len(bs))[-len(bs):]
+
+
+class CMAC(object):
+ """A CMAC hash object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar digest_size: the size in bytes of the resulting MAC tag
+ :vartype digest_size: integer
+ """
+
+ digest_size = None
+
+ def __init__(self, key, msg, ciphermod, cipher_params, mac_len,
+ update_after_digest):
+
+ self.digest_size = mac_len
+
+ self._key = _copy_bytes(None, None, key)
+ self._factory = ciphermod
+ self._cipher_params = cipher_params
+ self._block_size = bs = ciphermod.block_size
+ self._mac_tag = None
+ self._update_after_digest = update_after_digest
+
+ # Section 5.3 of NIST SP 800 38B and Appendix B
+ if bs == 8:
+ const_Rb = 0x1B
+ self._max_size = 8 * (2 ** 21)
+ elif bs == 16:
+ const_Rb = 0x87
+ self._max_size = 16 * (2 ** 48)
+ else:
+ raise TypeError("CMAC requires a cipher with a block size"
+ " of 8 or 16 bytes, not %d" % bs)
+
+ # Compute sub-keys
+ zero_block = b'\x00' * bs
+ self._ecb = ciphermod.new(key,
+ ciphermod.MODE_ECB,
+ **self._cipher_params)
+ L = self._ecb.encrypt(zero_block)
+ if bord(L[0]) & 0x80:
+ self._k1 = _shift_bytes(L, const_Rb)
+ else:
+ self._k1 = _shift_bytes(L)
+ if bord(self._k1[0]) & 0x80:
+ self._k2 = _shift_bytes(self._k1, const_Rb)
+ else:
+ self._k2 = _shift_bytes(self._k1)
+
+ # Initialize CBC cipher with zero IV
+ self._cbc = ciphermod.new(key,
+ ciphermod.MODE_CBC,
+ zero_block,
+ **self._cipher_params)
+
+ # Cache for outstanding data to authenticate
+ self._cache = bytearray(bs)
+ self._cache_n = 0
+
+ # Last piece of ciphertext produced
+ self._last_ct = zero_block
+
+ # Last block that was encrypted with AES
+ self._last_pt = None
+
+ # Counter for total message size
+ self._data_size = 0
+
+ if msg:
+ self.update(msg)
+
+ def update(self, msg):
+ """Authenticate the next chunk of message.
+
+ Args:
+ msg (byte string/byte array/memoryview): The next chunk of data
+ """
+
+ if self._mac_tag is not None and not self._update_after_digest:
+ raise TypeError("update() cannot be called after digest() or verify()")
+
+ self._data_size += len(msg)
+ bs = self._block_size
+
+ if self._cache_n > 0:
+ filler = min(bs - self._cache_n, len(msg))
+ self._cache[self._cache_n:self._cache_n+filler] = msg[:filler]
+ self._cache_n += filler
+
+ if self._cache_n < bs:
+ return self
+
+ msg = memoryview(msg)[filler:]
+ self._update(self._cache)
+ self._cache_n = 0
+
+ remain = len(msg) % bs
+ if remain > 0:
+ self._update(msg[:-remain])
+ self._cache[:remain] = msg[-remain:]
+ else:
+ self._update(msg)
+ self._cache_n = remain
+ return self
+
+ def _update(self, data_block):
+ """Update a block aligned to the block boundary"""
+
+ bs = self._block_size
+ assert len(data_block) % bs == 0
+
+ if len(data_block) == 0:
+ return
+
+ ct = self._cbc.encrypt(data_block)
+ if len(data_block) == bs:
+ second_last = self._last_ct
+ else:
+ second_last = ct[-bs*2:-bs]
+ self._last_ct = ct[-bs:]
+ self._last_pt = strxor(second_last, data_block[-bs:])
+
+ def copy(self):
+ """Return a copy ("clone") of the CMAC object.
+
+ The copy will have the same internal state as the original CMAC
+ object.
+ This can be used to efficiently compute the MAC tag of byte
+ strings that share a common initial substring.
+
+ :return: A :class:`CMAC` object
+ """
+
+ obj = self.__new__(CMAC)
+ obj.__dict__ = self.__dict__.copy()
+ obj._cbc = self._factory.new(self._key,
+ self._factory.MODE_CBC,
+ self._last_ct,
+ **self._cipher_params)
+ obj._cache = self._cache[:]
+ obj._last_ct = self._last_ct[:]
+ return obj
+
+ def digest(self):
+ """Return the **binary** (non-printable) MAC tag of the message
+ that has been authenticated so far.
+
+ :return: The MAC tag, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bs = self._block_size
+
+ if self._mac_tag is not None and not self._update_after_digest:
+ return self._mac_tag
+
+ if self._data_size > self._max_size:
+ raise ValueError("MAC is unsafe for this message")
+
+ if self._cache_n == 0 and self._data_size > 0:
+ # Last block was full
+ pt = strxor(self._last_pt, self._k1)
+ else:
+ # Last block is partial (or message length is zero)
+ partial = self._cache[:]
+ partial[self._cache_n:] = b'\x80' + b'\x00' * (bs - self._cache_n - 1)
+ pt = strxor(strxor(self._last_ct, partial), self._k2)
+
+ self._mac_tag = self._ecb.encrypt(pt)[:self.digest_size]
+
+ return self._mac_tag
+
+ def hexdigest(self):
+ """Return the **printable** MAC tag of the message authenticated so far.
+
+ :return: The MAC tag, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x)
+ for x in tuple(self.digest())])
+
+ def verify(self, mac_tag):
+ """Verify that a given **binary** MAC (computed by another party)
+ is valid.
+
+ Args:
+ mac_tag (byte string/byte array/memoryview): the expected MAC of the message.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=self.digest())
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Return the **printable** MAC tag of the message authenticated so far.
+
+ :return: The MAC tag, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ self.verify(unhexlify(tobytes(hex_mac_tag)))
+
+
+def new(key, msg=None, ciphermod=None, cipher_params=None, mac_len=None,
+ update_after_digest=False):
+ """Create a new MAC object.
+
+ Args:
+ key (byte string/byte array/memoryview):
+ key for the CMAC object.
+ The key must be valid for the underlying cipher algorithm.
+ For instance, it must be 16 bytes long for AES-128.
+ ciphermod (module):
+ A cipher module from :mod:`Crypto.Cipher`.
+ The cipher's block size has to be 128 bits,
+ like :mod:`Crypto.Cipher.AES`, to reduce the probability
+ of collisions.
+ msg (byte string/byte array/memoryview):
+ Optional. The very first chunk of the message to authenticate.
+ It is equivalent to an early call to `CMAC.update`. Optional.
+ cipher_params (dict):
+ Optional. A set of parameters to use when instantiating a cipher
+ object.
+ mac_len (integer):
+ Length of the MAC, in bytes.
+ It must be at least 4 bytes long.
+ The default (and recommended) length matches the size of a cipher block.
+ update_after_digest (boolean):
+ Optional. By default, a hash object cannot be updated anymore after
+ the digest is computed. When this flag is ``True``, such check
+ is no longer enforced.
+ Returns:
+ A :class:`CMAC` object
+ """
+
+ if ciphermod is None:
+ raise TypeError("ciphermod must be specified (try AES)")
+
+ cipher_params = {} if cipher_params is None else dict(cipher_params)
+
+ if mac_len is None:
+ mac_len = ciphermod.block_size
+
+ if mac_len < 4:
+ raise ValueError("MAC tag length must be at least 4 bytes long")
+
+ if mac_len > ciphermod.block_size:
+ raise ValueError("MAC tag length cannot be larger than a cipher block (%d) bytes" % ciphermod.block_size)
+
+ return CMAC(key, msg, ciphermod, cipher_params, mac_len,
+ update_after_digest)
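
A brief usage sketch for the CMAC module above (illustration only, not part of the patch); it assumes Crypto.Cipher.AES from the same library as the underlying block cipher:

    from Crypto.Cipher import AES
    from Crypto.Hash import CMAC

    key = b'Sixteen byte key'

    # Sender computes the tag over the message
    mac = CMAC.new(key, ciphermod=AES)
    mac.update(b'Hello')
    tag = mac.digest()

    # Receiver recomputes the tag and verifies it (raises ValueError on mismatch)
    CMAC.new(key, ciphermod=AES, msg=b'Hello').verify(tag)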
diff --git a/lib/Crypto/Hash/CMAC.pyi b/lib/Crypto/Hash/CMAC.pyi
new file mode 100644
index 0000000..acdf055
--- /dev/null
+++ b/lib/Crypto/Hash/CMAC.pyi
@@ -0,0 +1,30 @@
+from types import ModuleType
+from typing import Union, Dict, Any
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+digest_size: int
+
+class CMAC(object):
+ digest_size: int
+
+ def __init__(self,
+ key: Buffer,
+ msg: Buffer,
+ ciphermod: ModuleType,
+ cipher_params: Dict[str, Any],
+ mac_len: int, update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> CMAC: ...
+ def copy(self) -> CMAC: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+
+def new(key: Buffer,
+ msg: Buffer = ...,
+ ciphermod: ModuleType = ...,
+ cipher_params: Dict[str, Any] = ...,
+ mac_len: int = ...,
+ update_after_digest: bool = ...) -> CMAC: ...
diff --git a/lib/Crypto/Hash/HMAC.py b/lib/Crypto/Hash/HMAC.py
new file mode 100644
index 0000000..e82bb9d
--- /dev/null
+++ b/lib/Crypto/Hash/HMAC.py
@@ -0,0 +1,213 @@
+#
+# HMAC.py - Implements the HMAC algorithm as described by RFC 2104.
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord, tobytes
+
+from binascii import unhexlify
+
+from Crypto.Hash import MD5
+from Crypto.Hash import BLAKE2s
+from Crypto.Util.strxor import strxor
+from Crypto.Random import get_random_bytes
+
+__all__ = ['new', 'HMAC']
+
+
+class HMAC(object):
+ """An HMAC hash object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar digest_size: the size in bytes of the resulting MAC tag
+ :vartype digest_size: integer
+ """
+
+ def __init__(self, key, msg=b"", digestmod=None):
+
+ if digestmod is None:
+ digestmod = MD5
+
+ if msg is None:
+ msg = b""
+
+ # Size of the MAC tag
+ self.digest_size = digestmod.digest_size
+
+ self._digestmod = digestmod
+
+ if isinstance(key, memoryview):
+ key = key.tobytes()
+
+ try:
+ if len(key) <= digestmod.block_size:
+ # Step 1 or 2
+ key_0 = key + b"\x00" * (digestmod.block_size - len(key))
+ else:
+ # Step 3
+ hash_k = digestmod.new(key).digest()
+ key_0 = hash_k + b"\x00" * (digestmod.block_size - len(hash_k))
+ except AttributeError:
+ # Not all hash types have "block_size"
+ raise ValueError("Hash type incompatible to HMAC")
+
+ # Step 4
+ key_0_ipad = strxor(key_0, b"\x36" * len(key_0))
+
+ # Start step 5 and 6
+ self._inner = digestmod.new(key_0_ipad)
+ self._inner.update(msg)
+
+ # Step 7
+ key_0_opad = strxor(key_0, b"\x5c" * len(key_0))
+
+ # Start step 8 and 9
+ self._outer = digestmod.new(key_0_opad)
+
+ def update(self, msg):
+ """Authenticate the next chunk of message.
+
+ Args:
+ msg (byte string/byte array/memoryview): The next chunk of data
+ """
+
+ self._inner.update(msg)
+ return self
+
+ def _pbkdf2_hmac_assist(self, first_digest, iterations):
+ """Carry out the expensive inner loop for PBKDF2-HMAC"""
+
+ result = self._digestmod._pbkdf2_hmac_assist(
+ self._inner,
+ self._outer,
+ first_digest,
+ iterations)
+ return result
+
+ def copy(self):
+ """Return a copy ("clone") of the HMAC object.
+
+ The copy will have the same internal state as the original HMAC
+ object.
+ This can be used to efficiently compute the MAC tag of byte
+ strings that share a common initial substring.
+
+ :return: An :class:`HMAC`
+ """
+
+ new_hmac = HMAC(b"fake key", digestmod=self._digestmod)
+
+ # Synchronize the state
+ new_hmac._inner = self._inner.copy()
+ new_hmac._outer = self._outer.copy()
+
+ return new_hmac
+
+ def digest(self):
+ """Return the **binary** (non-printable) MAC tag of the message
+ authenticated so far.
+
+ :return: The MAC tag digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ frozen_outer_hash = self._outer.copy()
+ frozen_outer_hash.update(self._inner.digest())
+ return frozen_outer_hash.digest()
+
+ def verify(self, mac_tag):
+ """Verify that a given **binary** MAC (computed by another party)
+ is valid.
+
+ Args:
+ mac_tag (byte string/byte array/memoryview): the expected MAC of the message.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=self.digest())
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexdigest(self):
+ """Return the **printable** MAC tag of the message authenticated so far.
+
+ :return: The MAC tag, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x)
+ for x in tuple(self.digest())])
+
+ def hexverify(self, hex_mac_tag):
+ """Verify that a given **printable** MAC (computed by another party)
+ is valid.
+
+ Args:
+ hex_mac_tag (string): the expected MAC of the message,
+ as a hexadecimal string.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ self.verify(unhexlify(tobytes(hex_mac_tag)))
+
+
+def new(key, msg=b"", digestmod=None):
+ """Create a new MAC object.
+
+ Args:
+ key (bytes/bytearray/memoryview):
+ key for the MAC object.
+ It must be long enough to match the expected security level of the
+ MAC.
+ msg (bytes/bytearray/memoryview):
+ Optional. The very first chunk of the message to authenticate.
+ It is equivalent to an early call to :meth:`HMAC.update`.
+ digestmod (module):
+ The hash to use to implement the HMAC.
+ Default is :mod:`Crypto.Hash.MD5`.
+
+ Returns:
+ An :class:`HMAC` object
+ """
+
+ return HMAC(key, msg, digestmod)
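
A usage sketch for the HMAC module above (illustration, not part of the patch); it assumes the SHA256 module from the same Crypto.Hash package, which is not shown in this hunk:

    from Crypto.Hash import HMAC, SHA256

    secret = b'Swordfish'

    h = HMAC.new(secret, digestmod=SHA256)
    h.update(b'Hello')
    tag = h.digest()

    # Verification on the receiving side raises ValueError if the tags differ
    HMAC.new(secret, msg=b'Hello', digestmod=SHA256).verify(tag)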
diff --git a/lib/Crypto/Hash/HMAC.pyi b/lib/Crypto/Hash/HMAC.pyi
new file mode 100644
index 0000000..b577230
--- /dev/null
+++ b/lib/Crypto/Hash/HMAC.pyi
@@ -0,0 +1,25 @@
+from types import ModuleType
+from typing import Union, Dict
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+digest_size: int
+
+class HMAC(object):
+ digest_size: int
+
+ def __init__(self,
+ key: Buffer,
+ msg: Buffer,
+ digestmod: ModuleType) -> None: ...
+ def update(self, msg: Buffer) -> HMAC: ...
+ def copy(self) -> HMAC: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+
+def new(key: Buffer,
+ msg: Buffer = ...,
+ digestmod: ModuleType = ...) -> HMAC: ...
diff --git a/lib/Crypto/Hash/KMAC128.py b/lib/Crypto/Hash/KMAC128.py
new file mode 100644
index 0000000..05061fc
--- /dev/null
+++ b/lib/Crypto/Hash/KMAC128.py
@@ -0,0 +1,179 @@
+# ===================================================================
+#
+# Copyright (c) 2021, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import bord, tobytes, is_bytes
+from Crypto.Random import get_random_bytes
+
+from . import cSHAKE128, SHA3_256
+from .cSHAKE128 import _bytepad, _encode_str, _right_encode
+
+
+class KMAC_Hash(object):
+ """A KMAC hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+ """
+
+ def __init__(self, data, key, mac_len, custom,
+ oid_variant, cshake, rate):
+
+ # See https://tools.ietf.org/html/rfc8702
+ self.oid = "2.16.840.1.101.3.4.2." + oid_variant
+ self.digest_size = mac_len
+
+ self._mac = None
+
+ partial_newX = _bytepad(_encode_str(tobytes(key)), rate)
+ self._cshake = cshake._new(partial_newX, custom, b"KMAC")
+
+ if data:
+ self._cshake.update(data)
+
+ def update(self, data):
+ """Authenticate the next chunk of message.
+
+ Args:
+ data (bytes/bytearray/memoryview): The next chunk of the message to
+ authenticate.
+ """
+
+ if self._mac:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ self._cshake.update(data)
+ return self
+
+ def digest(self):
+ """Return the **binary** (non-printable) MAC tag of the message.
+
+ :return: The MAC tag. Binary form.
+ :rtype: byte string
+ """
+
+ if not self._mac:
+ self._cshake.update(_right_encode(self.digest_size * 8))
+ self._mac = self._cshake.read(self.digest_size)
+
+ return self._mac
+
+ def hexdigest(self):
+ """Return the **printable** MAC tag of the message.
+
+ :return: The MAC tag. Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in tuple(self.digest())])
+
+ def verify(self, mac_tag):
+ """Verify that a given **binary** MAC (computed by another party)
+ is valid.
+
+ Args:
+ mac_tag (bytes/bytearray/memoryview): the expected MAC of the message.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ secret = get_random_bytes(16)
+
+ mac1 = SHA3_256.new(secret + mac_tag)
+ mac2 = SHA3_256.new(secret + self.digest())
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Verify that a given **printable** MAC (computed by another party)
+ is valid.
+
+ Args:
+ hex_mac_tag (string): the expected MAC of the message, as a hexadecimal string.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ self.verify(unhexlify(tobytes(hex_mac_tag)))
+
+ def new(self, **kwargs):
+ """Return a new instance of a KMAC hash object.
+ See :func:`new`.
+ """
+
+ if "mac_len" not in kwargs:
+ kwargs["mac_len"] = self.digest_size
+
+ return new(**kwargs)
+
+
+def new(**kwargs):
+ """Create a new KMAC128 object.
+
+ Args:
+ key (bytes/bytearray/memoryview):
+ The key to use to compute the MAC.
+ It must be at least 128 bits long (16 bytes).
+ data (bytes/bytearray/memoryview):
+ Optional. The very first chunk of the message to authenticate.
+ It is equivalent to an early call to :meth:`KMAC_Hash.update`.
+ mac_len (integer):
+ Optional. The size of the authentication tag, in bytes.
+ Default is 64. Minimum is 8.
+ custom (bytes/bytearray/memoryview):
+ Optional. A customization byte string (``S`` in SP 800-185).
+
+ Returns:
+ A :class:`KMAC_Hash` hash object
+ """
+
+ key = kwargs.pop("key", None)
+ if not is_bytes(key):
+ raise TypeError("You must pass a key to KMAC128")
+ if len(key) < 16:
+ raise ValueError("The key must be at least 128 bits long (16 bytes)")
+
+ data = kwargs.pop("data", None)
+
+ mac_len = kwargs.pop("mac_len", 64)
+ if mac_len < 8:
+ raise ValueError("'mac_len' must be 8 bytes or more")
+
+ custom = kwargs.pop("custom", b"")
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return KMAC_Hash(data, key, mac_len, custom, "19", cSHAKE128, 168)
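
A usage sketch for KMAC128 (illustration, not part of the patch), following the keyword-only API documented above:

    from Crypto.Hash import KMAC128

    key = b'0123456789abcdef'        # at least 16 bytes for KMAC128
    mac = KMAC128.new(key=key, mac_len=16, custom=b'My Tagged Application')
    mac.update(b'Hello')
    print(mac.hexdigest())

    # The receiver verifies with the same key, customization string and mac_len
    KMAC128.new(key=key, data=b'Hello', mac_len=16,
                custom=b'My Tagged Application').verify(mac.digest())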
diff --git a/lib/Crypto/Hash/KMAC128.pyi b/lib/Crypto/Hash/KMAC128.pyi
new file mode 100644
index 0000000..8947dab
--- /dev/null
+++ b/lib/Crypto/Hash/KMAC128.pyi
@@ -0,0 +1,33 @@
+from typing import Union
+from types import ModuleType
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class KMAC_Hash(object):
+
+ def __init__(self,
+ data: Buffer,
+ key: Buffer,
+ mac_len: int,
+ custom: Buffer,
+ oid_variant: str,
+ cshake: ModuleType,
+ rate: int) -> None: ...
+
+ def update(self, data: Buffer) -> KMAC_Hash: ...
+
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+ def new(self,
+ data: Buffer = ...,
+ mac_len: int = ...,
+ key: Buffer = ...,
+ custom: Buffer = ...) -> KMAC_Hash: ...
+
+
+def new(key: Buffer,
+ data: Buffer = ...,
+ mac_len: int = ...,
+ custom: Buffer = ...) -> KMAC_Hash: ...
diff --git a/lib/Crypto/Hash/KMAC256.py b/lib/Crypto/Hash/KMAC256.py
new file mode 100644
index 0000000..2be8e2f
--- /dev/null
+++ b/lib/Crypto/Hash/KMAC256.py
@@ -0,0 +1,74 @@
+# ===================================================================
+#
+# Copyright (c) 2021, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import is_bytes
+
+from .KMAC128 import KMAC_Hash
+from . import cSHAKE256
+
+
+def new(**kwargs):
+ """Create a new KMAC256 object.
+
+ Args:
+ key (bytes/bytearray/memoryview):
+ The key to use to compute the MAC.
+ It must be at least 256 bits long (32 bytes).
+ data (bytes/bytearray/memoryview):
+ Optional. The very first chunk of the message to authenticate.
+ It is equivalent to an early call to :meth:`KMAC_Hash.update`.
+ mac_len (integer):
+ Optional. The size of the authentication tag, in bytes.
+ Default is 64. Minimum is 8.
+ custom (bytes/bytearray/memoryview):
+ Optional. A customization byte string (``S`` in SP 800-185).
+
+ Returns:
+ A :class:`KMAC_Hash` hash object
+ """
+
+ key = kwargs.pop("key", None)
+ if not is_bytes(key):
+ raise TypeError("You must pass a key to KMAC256")
+ if len(key) < 32:
+ raise ValueError("The key must be at least 256 bits long (32 bytes)")
+
+ data = kwargs.pop("data", None)
+
+ mac_len = kwargs.pop("mac_len", 64)
+ if mac_len < 8:
+ raise ValueError("'mac_len' must be 8 bytes or more")
+
+ custom = kwargs.pop("custom", b"")
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return KMAC_Hash(data, key, mac_len, custom, "20", cSHAKE256, 136)
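
KMAC256 is used the same way as KMAC128; only the minimum key length changes (sketch, not part of the patch):

    from Crypto.Hash import KMAC256

    key = b'0123456789abcdef' * 2    # at least 32 bytes for KMAC256
    tag = KMAC256.new(key=key, data=b'Hello', mac_len=32).digest()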
diff --git a/lib/Crypto/Hash/KMAC256.pyi b/lib/Crypto/Hash/KMAC256.pyi
new file mode 100644
index 0000000..86cc500
--- /dev/null
+++ b/lib/Crypto/Hash/KMAC256.pyi
@@ -0,0 +1,10 @@
+from typing import Union
+
+from .KMAC128 import KMAC_Hash
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(key: Buffer,
+ data: Buffer = ...,
+ mac_len: int = ...,
+ custom: Buffer = ...) -> KMAC_Hash: ...
diff --git a/lib/Crypto/Hash/KangarooTwelve.py b/lib/Crypto/Hash/KangarooTwelve.py
new file mode 100644
index 0000000..f5358d4
--- /dev/null
+++ b/lib/Crypto/Hash/KangarooTwelve.py
@@ -0,0 +1,262 @@
+# ===================================================================
+#
+# Copyright (c) 2021, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util._raw_api import (VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Util.number import long_to_bytes
+from Crypto.Util.py3compat import bchr
+
+from .keccak import _raw_keccak_lib
+
+
+def _length_encode(x):
+ if x == 0:
+ return b'\x00'
+
+ S = long_to_bytes(x)
+ return S + bchr(len(S))
+
+
+# Possible states for a KangarooTwelve instance, which depend on the amount of data processed so far.
+SHORT_MSG = 1 # Still within the first 8192 bytes, but it is not certain we will exceed them.
+LONG_MSG_S0 = 2 # Still within the first 8192 bytes, and it is certain we will exceed them.
+LONG_MSG_SX = 3 # Beyond the first 8192 bytes.
+SQUEEZING = 4 # No more data to process.
+
+
+class K12_XOF(object):
+ """A KangarooTwelve hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+ """
+
+ def __init__(self, data, custom):
+
+ if custom is None:
+ custom = b''
+
+ self._custom = custom + _length_encode(len(custom))
+ self._state = SHORT_MSG
+ self._padding = None # Final padding is only decided in read()
+
+ # Internal hash that consumes FinalNode
+ self._hash1 = self._create_keccak()
+ self._length1 = 0
+
+ # Internal hash that produces CV_i (reset each time)
+ self._hash2 = None
+ self._length2 = 0
+
+ # Incremented by one for each 8192-byte block
+ self._ctr = 0
+
+ if data:
+ self.update(data)
+
+ def _create_keccak(self):
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(32), # 32 bytes of capacity (256 bits)
+ c_ubyte(12)) # Reduced number of rounds
+ if result:
+ raise ValueError("Error %d while instantiating KangarooTwelve"
+ % result)
+ return SmartPointer(state.get(), _raw_keccak_lib.keccak_destroy)
+
+ def _update(self, data, hash_obj):
+ result = _raw_keccak_lib.keccak_absorb(hash_obj.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while updating KangarooTwelve state"
+ % result)
+
+ def _squeeze(self, hash_obj, length, padding):
+ bfr = create_string_buffer(length)
+ result = _raw_keccak_lib.keccak_squeeze(hash_obj.get(),
+ bfr,
+ c_size_t(length),
+ c_ubyte(padding))
+ if result:
+ raise ValueError("Error %d while extracting from KangarooTwelve"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def _reset(self, hash_obj):
+ result = _raw_keccak_lib.keccak_reset(hash_obj.get())
+ if result:
+ raise ValueError("Error %d while resetting KangarooTwelve state"
+ % result)
+
+ def update(self, data):
+ """Hash the next piece of data.
+
+ .. note::
+ For better performance, submit chunks with a length that is a multiple of 8192 bytes.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the
+ message to hash.
+ """
+
+ if self._state == SQUEEZING:
+ raise TypeError("You cannot call 'update' after the first 'read'")
+
+ if self._state == SHORT_MSG:
+ next_length = self._length1 + len(data)
+
+ if next_length + len(self._custom) <= 8192:
+ self._length1 = next_length
+ self._update(data, self._hash1)
+ return self
+
+ # Switch to tree hashing
+ self._state = LONG_MSG_S0
+
+ if self._state == LONG_MSG_S0:
+ data_mem = memoryview(data)
+ assert(self._length1 < 8192)
+ dtc = min(len(data), 8192 - self._length1)
+ self._update(data_mem[:dtc], self._hash1)
+ self._length1 += dtc
+
+ if self._length1 < 8192:
+ return self
+
+ # Finish hashing S_0 and start S_1
+ assert(self._length1 == 8192)
+
+ divider = b'\x03' + b'\x00' * 7
+ self._update(divider, self._hash1)
+ self._length1 += 8
+
+ self._hash2 = self._create_keccak()
+ self._length2 = 0
+ self._ctr = 1
+
+ self._state = LONG_MSG_SX
+ return self.update(data_mem[dtc:])
+
+ # LONG_MSG_SX
+ assert(self._state == LONG_MSG_SX)
+ index = 0
+ len_data = len(data)
+
+ # All iterations could actually run in parallel
+ data_mem = memoryview(data)
+ while index < len_data:
+
+ new_index = min(index + 8192 - self._length2, len_data)
+ self._update(data_mem[index:new_index], self._hash2)
+ self._length2 += new_index - index
+ index = new_index
+
+ if self._length2 == 8192:
+ cv_i = self._squeeze(self._hash2, 32, 0x0B)
+ self._update(cv_i, self._hash1)
+ self._length1 += 32
+ self._reset(self._hash2)
+ self._length2 = 0
+ self._ctr += 1
+
+ return self
+
+ def read(self, length):
+ """
+ Produce more bytes of the digest.
+
+ .. note::
+ You cannot use :meth:`update` anymore after the first call to
+ :meth:`read`.
+
+ Args:
+ length (integer): the amount of bytes this method must return
+
+ :return: the next piece of XOF output (of the given length)
+ :rtype: byte string
+ """
+
+ custom_was_consumed = False
+
+ if self._state == SHORT_MSG:
+ self._update(self._custom, self._hash1)
+ self._padding = 0x07
+ self._state = SQUEEZING
+
+ if self._state == LONG_MSG_S0:
+ self.update(self._custom)
+ custom_was_consumed = True
+ assert(self._state == LONG_MSG_SX)
+
+ if self._state == LONG_MSG_SX:
+ if not custom_was_consumed:
+ self.update(self._custom)
+
+ # Is there still some leftover data in hash2?
+ if self._length2 > 0:
+ cv_i = self._squeeze(self._hash2, 32, 0x0B)
+ self._update(cv_i, self._hash1)
+ self._length1 += 32
+ self._reset(self._hash2)
+ self._length2 = 0
+ self._ctr += 1
+
+ trailer = _length_encode(self._ctr - 1) + b'\xFF\xFF'
+ self._update(trailer, self._hash1)
+
+ self._padding = 0x06
+ self._state = SQUEEZING
+
+ return self._squeeze(self._hash1, length, self._padding)
+
+ def new(self, data=None, custom=b''):
+ return type(self)(data, custom)
+
+
+def new(data=None, custom=None):
+ """Return a fresh instance of a KangarooTwelve object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ Optional.
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ custom (bytes):
+ Optional.
+ A customization byte string.
+
+ :Return: A :class:`K12_XOF` object
+ """
+
+ return K12_XOF(data, custom)
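
A usage sketch for the KangarooTwelve XOF above (illustration, not part of the patch); read() can be called repeatedly to extend the output stream:

    from Crypto.Hash import KangarooTwelve

    xof = KangarooTwelve.new(custom=b'Email signature')
    xof.update(b'Hello')
    first = xof.read(32)     # first 32 bytes of output
    more = xof.read(32)      # next 32 bytes of the same stream
    # update() can no longer be called once read() has been used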
diff --git a/lib/Crypto/Hash/KangarooTwelve.pyi b/lib/Crypto/Hash/KangarooTwelve.pyi
new file mode 100644
index 0000000..8b3fd74
--- /dev/null
+++ b/lib/Crypto/Hash/KangarooTwelve.pyi
@@ -0,0 +1,16 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class K12_XOF(object):
+ def __init__(self,
+ data: Optional[Buffer] = ...,
+ custom: Optional[bytes] = ...) -> None: ...
+ def update(self, data: Buffer) -> K12_XOF: ...
+ def read(self, length: int) -> bytes: ...
+ def new(self,
+ data: Optional[Buffer] = ...,
+ custom: Optional[bytes] = ...) -> K12_XOF: ...
+
+def new(data: Optional[Buffer] = ...,
+ custom: Optional[Buffer] = ...) -> K12_XOF: ...
diff --git a/lib/Crypto/Hash/MD2.py b/lib/Crypto/Hash/MD2.py
new file mode 100644
index 0000000..41decbb
--- /dev/null
+++ b/lib/Crypto/Hash/MD2.py
@@ -0,0 +1,166 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_md2_lib = load_pycryptodome_raw_lib(
+ "Crypto.Hash._MD2",
+ """
+ int md2_init(void **shaState);
+ int md2_destroy(void *shaState);
+ int md2_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int md2_digest(const void *shaState,
+ uint8_t digest[20]);
+ int md2_copy(const void *src, void *dst);
+ """)
+
+
+class MD2Hash(object):
+ """An MD2 hash object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 16
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 16
+ # ASN.1 Object ID
+ oid = "1.2.840.113549.2.2"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_md2_lib.md2_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating MD2"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_md2_lib.md2_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_md2_lib.md2_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while instantiating MD2"
+ % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_md2_lib.md2_digest(self._state.get(),
+ bfr)
+ if result:
+ raise ValueError("Error %d while instantiating MD2"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = MD2Hash()
+ result = _raw_md2_lib.md2_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying MD2" % result)
+ return clone
+
+ def new(self, data=None):
+ return MD2Hash(data)
+
+
+def new(data=None):
+ """Create a new hash object.
+
+ :parameter data:
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`MD2Hash.update`.
+ :type data: bytes/bytearray/memoryview
+
+ :Return: A :class:`MD2Hash` hash object
+ """
+
+ return MD2Hash().new(data)
+
+# The size of the resulting hash in bytes.
+digest_size = MD2Hash.digest_size
+
+# The internal block size of the hash algorithm in bytes.
+block_size = MD2Hash.block_size
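
A minimal usage sketch for MD2 (illustration, not part of the patch); MD2 is a legacy algorithm and should only be used for compatibility with old protocols:

    from Crypto.Hash import MD2

    h = MD2.new(b'Hello')
    print(h.hexdigest())     # 16-byte digest, hex encoded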
diff --git a/lib/Crypto/Hash/MD2.pyi b/lib/Crypto/Hash/MD2.pyi
new file mode 100644
index 0000000..95a97a9
--- /dev/null
+++ b/lib/Crypto/Hash/MD2.pyi
@@ -0,0 +1,19 @@
+from typing import Union
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class MD2Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self, data: Buffer = ...) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> MD2Hash: ...
+ def new(self, data: Buffer = ...) -> MD2Hash: ...
+
+def new(data: Buffer = ...) -> MD2Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/MD4.py b/lib/Crypto/Hash/MD4.py
new file mode 100644
index 0000000..be12b19
--- /dev/null
+++ b/lib/Crypto/Hash/MD4.py
@@ -0,0 +1,185 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+MD4 is specified in RFC1320_ and produces the 128 bit digest of a message.
+
+ >>> from Crypto.Hash import MD4
+ >>>
+ >>> h = MD4.new()
+ >>> h.update(b'Hello')
+ >>> print(h.hexdigest())
+
+MD4 stands for Message Digest version 4, and it was invented by Rivest in 1990.
+This algorithm is insecure. Do not use it for new designs.
+
+.. _RFC1320: http://tools.ietf.org/html/rfc1320
+"""
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_md4_lib = load_pycryptodome_raw_lib(
+ "Crypto.Hash._MD4",
+ """
+ int md4_init(void **shaState);
+ int md4_destroy(void *shaState);
+ int md4_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int md4_digest(const void *shaState,
+ uint8_t digest[20]);
+ int md4_copy(const void *src, void *dst);
+ """)
+
+
+class MD4Hash(object):
+ """Class that implements an MD4 hash
+ """
+
+ #: The size of the resulting hash in bytes.
+ digest_size = 16
+ #: The internal block size of the hash algorithm in bytes.
+ block_size = 64
+ #: ASN.1 Object ID
+ oid = "1.2.840.113549.2.4"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_md4_lib.md4_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating MD4"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_md4_lib.md4_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Repeated calls are equivalent to a single call with the concatenation
+ of all the arguments. In other words:
+
+ >>> m.update(a); m.update(b)
+
+ is equivalent to:
+
+ >>> m.update(a+b)
+
+ :Parameters:
+ data : byte string/byte array/memoryview
+ The next chunk of the message being hashed.
+ """
+
+ result = _raw_md4_lib.md4_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while instantiating MD4"
+ % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that
+ has been hashed so far.
+
+ This method does not change the state of the hash object.
+ You can continue updating the object after calling this function.
+
+ :Return: A byte string of `digest_size` bytes. It may contain non-ASCII
+ characters, including null bytes.
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_md4_lib.md4_digest(self._state.get(),
+ bfr)
+ if result:
+ raise ValueError("Error %d while instantiating MD4"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been
+ hashed so far.
+
+ This method does not change the state of the hash object.
+
+ :Return: A string of 2* `digest_size` characters. It contains only
+ hexadecimal ASCII digits.
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :Return: A hash object of the same type
+ """
+
+ clone = MD4Hash()
+ result = _raw_md4_lib.md4_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying MD4" % result)
+ return clone
+
+ def new(self, data=None):
+ return MD4Hash(data)
+
+
+def new(data=None):
+ """Return a fresh instance of the hash object.
+
+ :Parameters:
+ data : byte string/byte array/memoryview
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to `MD4Hash.update()`.
+ Optional.
+
+ :Return: A `MD4Hash` object
+ """
+ return MD4Hash().new(data)
+
+#: The size of the resulting hash in bytes.
+digest_size = MD4Hash.digest_size
+
+#: The internal block size of the hash algorithm in bytes.
+block_size = MD4Hash.block_size
diff --git a/lib/Crypto/Hash/MD4.pyi b/lib/Crypto/Hash/MD4.pyi
new file mode 100644
index 0000000..a9a7295
--- /dev/null
+++ b/lib/Crypto/Hash/MD4.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class MD4Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self, data: Optional[Buffer] = ...) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> MD4Hash: ...
+ def new(self, data: Optional[Buffer] = ...) -> MD4Hash: ...
+
+def new(data: Optional[Buffer] = ...) -> MD4Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/MD5.py b/lib/Crypto/Hash/MD5.py
new file mode 100644
index 0000000..554b777
--- /dev/null
+++ b/lib/Crypto/Hash/MD5.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import *
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_md5_lib = load_pycryptodome_raw_lib("Crypto.Hash._MD5",
+ """
+ #define MD5_DIGEST_SIZE 16
+
+ int MD5_init(void **shaState);
+ int MD5_destroy(void *shaState);
+ int MD5_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int MD5_digest(const void *shaState,
+ uint8_t digest[MD5_DIGEST_SIZE]);
+ int MD5_copy(const void *src, void *dst);
+
+ int MD5_pbkdf2_hmac_assist(const void *inner,
+ const void *outer,
+ const uint8_t first_digest[MD5_DIGEST_SIZE],
+ uint8_t final_digest[MD5_DIGEST_SIZE],
+ size_t iterations);
+ """)
+
+class MD5Hash(object):
+ """A MD5 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 16
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 64
+ # ASN.1 Object ID
+ oid = "1.2.840.113549.2.5"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_md5_lib.MD5_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating MD5"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_md5_lib.MD5_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_md5_lib.MD5_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while instantiating MD5"
+ % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_md5_lib.MD5_digest(self._state.get(),
+ bfr)
+ if result:
+ raise ValueError("Error %d while instantiating MD5"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = MD5Hash()
+ result = _raw_md5_lib.MD5_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying MD5" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA-1 hash object."""
+
+ return MD5Hash(data)
+
+
+def new(data=None):
+ """Create a new hash object.
+
+ :parameter data:
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`MD5Hash.update`.
+ :type data: byte string/byte array/memoryview
+
+ :Return: A :class:`MD5Hash` hash object
+ """
+ return MD5Hash().new(data)
+
+# The size of the resulting hash in bytes.
+digest_size = 16
+
+# The internal block size of the hash algorithm in bytes.
+block_size = 64
+
+
+def _pbkdf2_hmac_assist(inner, outer, first_digest, iterations):
+ """Compute the expensive inner loop in PBKDF-HMAC."""
+
+ assert len(first_digest) == digest_size
+ assert iterations > 0
+
+ bfr = create_string_buffer(digest_size)
+ result = _raw_md5_lib.MD5_pbkdf2_hmac_assist(
+ inner._state.get(),
+ outer._state.get(),
+ first_digest,
+ bfr,
+ c_size_t(iterations))
+
+ if result:
+ raise ValueError("Error %d with PBKDF2-HMAC assis for MD5" % result)
+
+ return get_raw_buffer(bfr)
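
A minimal usage sketch for MD5 (illustration, not part of the patch); like MD2 and MD4 it is kept for legacy formats, not for new designs:

    from Crypto.Hash import MD5

    h = MD5.new()
    h.update(b'Hello, ')
    h.update(b'world')       # equivalent to a single update(b'Hello, world')
    print(h.hexdigest())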
diff --git a/lib/Crypto/Hash/MD5.pyi b/lib/Crypto/Hash/MD5.pyi
new file mode 100644
index 0000000..d819556
--- /dev/null
+++ b/lib/Crypto/Hash/MD5.pyi
@@ -0,0 +1,19 @@
+from typing import Union
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class MD5Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self, data: Buffer = ...) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> MD5Hash: ...
+ def new(self, data: Buffer = ...) -> MD5Hash: ...
+
+def new(data: Buffer = ...) -> MD5Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/Poly1305.py b/lib/Crypto/Hash/Poly1305.py
new file mode 100644
index 0000000..eb5e0da
--- /dev/null
+++ b/lib/Crypto/Hash/Poly1305.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+#
+# Hash/Poly1305.py - Implements the Poly1305 MAC
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import bord, tobytes, _copy_bytes
+
+from Crypto.Hash import BLAKE2s
+from Crypto.Random import get_random_bytes
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+
+_raw_poly1305 = load_pycryptodome_raw_lib("Crypto.Hash._poly1305",
+ """
+ int poly1305_init(void **state,
+ const uint8_t *r,
+ size_t r_len,
+ const uint8_t *s,
+ size_t s_len);
+ int poly1305_destroy(void *state);
+ int poly1305_update(void *state,
+ const uint8_t *in,
+ size_t len);
+ int poly1305_digest(const void *state,
+ uint8_t *digest,
+ size_t len);
+ """)
+
+
+class Poly1305_MAC(object):
+ """An Poly1305 MAC object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar digest_size: the size in bytes of the resulting MAC tag
+ :vartype digest_size: integer
+ """
+
+ digest_size = 16
+
+ def __init__(self, r, s, data):
+
+ if len(r) != 16:
+ raise ValueError("Parameter r is not 16 bytes long")
+ if len(s) != 16:
+ raise ValueError("Parameter s is not 16 bytes long")
+
+ self._mac_tag = None
+
+ state = VoidPointer()
+ result = _raw_poly1305.poly1305_init(state.address_of(),
+ c_uint8_ptr(r),
+ c_size_t(len(r)),
+ c_uint8_ptr(s),
+ c_size_t(len(s))
+ )
+ if result:
+ raise ValueError("Error %d while instantiating Poly1305" % result)
+ self._state = SmartPointer(state.get(),
+ _raw_poly1305.poly1305_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Authenticate the next chunk of message.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of data
+ """
+
+ if self._mac_tag:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_poly1305.poly1305_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while hashing Poly1305 data" % result)
+ return self
+
+ def copy(self):
+ raise NotImplementedError()
+
+ def digest(self):
+ """Return the **binary** (non-printable) MAC tag of the message
+ authenticated so far.
+
+ :return: The MAC tag digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ if self._mac_tag:
+ return self._mac_tag
+
+ bfr = create_string_buffer(16)
+ result = _raw_poly1305.poly1305_digest(self._state.get(),
+ bfr,
+ c_size_t(len(bfr)))
+ if result:
+ raise ValueError("Error %d while creating Poly1305 digest" % result)
+
+ self._mac_tag = get_raw_buffer(bfr)
+ return self._mac_tag
+
+ def hexdigest(self):
+ """Return the **printable** MAC tag of the message authenticated so far.
+
+ :return: The MAC tag, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x)
+ for x in tuple(self.digest())])
+
+ def verify(self, mac_tag):
+ """Verify that a given **binary** MAC (computed by another party)
+ is valid.
+
+ Args:
+ mac_tag (byte string/byte array/memoryview): the expected MAC of the message.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=mac_tag)
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=self.digest())
+
+ if mac1.digest() != mac2.digest():
+ raise ValueError("MAC check failed")
+
+ def hexverify(self, hex_mac_tag):
+ """Verify that a given **printable** MAC (computed by another party)
+ is valid.
+
+ Args:
+ hex_mac_tag (string): the expected MAC of the message,
+ as a hexadecimal string.
+
+ Raises:
+ ValueError: if the MAC does not match. It means that the message
+ has been tampered with or that the MAC key is incorrect.
+ """
+
+ self.verify(unhexlify(tobytes(hex_mac_tag)))
+
+
+
+def new(**kwargs):
+ """Create a new Poly1305 MAC object.
+
+ Args:
+ key (bytes/bytearray/memoryview):
+ The 32-byte key for the Poly1305 object.
+ cipher (module from ``Crypto.Cipher``):
+ The cipher algorithm to use for deriving the Poly1305
+ key pair *(r, s)*.
+ It can only be ``Crypto.Cipher.AES`` or ``Crypto.Cipher.ChaCha20``.
+ nonce (bytes/bytearray/memoryview):
+ Optional. The non-repeatable value to use for the MAC of this message.
+ It must be 16 bytes long for ``AES`` and 8 or 12 bytes for ``ChaCha20``.
+ If not passed, a random nonce is created; you will find it in the
+ ``nonce`` attribute of the new object.
+ data (bytes/bytearray/memoryview):
+ Optional. The very first chunk of the message to authenticate.
+ It is equivalent to an early call to ``update()``.
+
+ Returns:
+ A :class:`Poly1305_MAC` object
+ """
+
+ cipher = kwargs.pop("cipher", None)
+ if not hasattr(cipher, '_derive_Poly1305_key_pair'):
+ raise ValueError("Parameter 'cipher' must be AES or ChaCha20")
+
+ cipher_key = kwargs.pop("key", None)
+ if cipher_key is None:
+ raise TypeError("You must pass a parameter 'key'")
+
+ nonce = kwargs.pop("nonce", None)
+ data = kwargs.pop("data", None)
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ r, s, nonce = cipher._derive_Poly1305_key_pair(cipher_key, nonce)
+
+ new_mac = Poly1305_MAC(r, s, data)
+ new_mac.nonce = _copy_bytes(None, None, nonce) # nonce may still be just a memoryview
+ return new_mac
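
A usage sketch for the Poly1305 module above (illustration, not part of the patch); it assumes Crypto.Cipher.AES from the same library for deriving the one-time (r, s) pair:

    from Crypto.Cipher import AES
    from Crypto.Hash import Poly1305

    key = bytes(32)                   # 32-byte secret key (demo value only)
    mac = Poly1305.new(key=key, cipher=AES, data=b'Hello')
    nonce, tag = mac.nonce, mac.digest()

    # The verifier needs the same key, the nonce and the message
    Poly1305.new(key=key, cipher=AES, nonce=nonce,
                 data=b'Hello').verify(tag)    # raises ValueError on mismatch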
diff --git a/lib/Crypto/Hash/Poly1305.pyi b/lib/Crypto/Hash/Poly1305.pyi
new file mode 100644
index 0000000..f97a14a
--- /dev/null
+++ b/lib/Crypto/Hash/Poly1305.pyi
@@ -0,0 +1,24 @@
+from types import ModuleType
+from typing import Union
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class Poly1305_MAC(object):
+ block_size: int
+ digest_size: int
+ oid: str
+
+ def __init__(self,
+ r : int,
+ s : int,
+ data : Buffer) -> None: ...
+ def update(self, data: Buffer) -> Poly1305_MAC: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def verify(self, mac_tag: Buffer) -> None: ...
+ def hexverify(self, hex_mac_tag: str) -> None: ...
+
+def new(key: Buffer,
+ cipher: ModuleType,
+ nonce: Buffer = ...,
+ data: Buffer = ...) -> Poly1305_MAC: ...
diff --git a/lib/Crypto/Hash/RIPEMD.py b/lib/Crypto/Hash/RIPEMD.py
new file mode 100644
index 0000000..4e80235
--- /dev/null
+++ b/lib/Crypto/Hash/RIPEMD.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+# This file exists for backward compatibility with old code that refers to
+# Crypto.Hash.RIPEMD
+
+"""Deprecated alias for `Crypto.Hash.RIPEMD160`"""
+
+from Crypto.Hash.RIPEMD160 import new, block_size, digest_size
diff --git a/lib/Crypto/Hash/RIPEMD.pyi b/lib/Crypto/Hash/RIPEMD.pyi
new file mode 100644
index 0000000..e33eb2d
--- /dev/null
+++ b/lib/Crypto/Hash/RIPEMD.pyi
@@ -0,0 +1,3 @@
+# This file exists for backward compatibility with old code that refers to
+# Crypto.Hash.RIPEMD
+
diff --git a/lib/Crypto/Hash/RIPEMD160.py b/lib/Crypto/Hash/RIPEMD160.py
new file mode 100644
index 0000000..820b57d
--- /dev/null
+++ b/lib/Crypto/Hash/RIPEMD160.py
@@ -0,0 +1,169 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_ripemd160_lib = load_pycryptodome_raw_lib(
+ "Crypto.Hash._RIPEMD160",
+ """
+ int ripemd160_init(void **shaState);
+ int ripemd160_destroy(void *shaState);
+ int ripemd160_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int ripemd160_digest(const void *shaState,
+ uint8_t digest[20]);
+ int ripemd160_copy(const void *src, void *dst);
+ """)
+
+
+class RIPEMD160Hash(object):
+ """A RIPEMD-160 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 20
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 64
+ # ASN.1 Object ID
+ oid = "1.3.36.3.2.1"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_ripemd160_lib.ripemd160_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating RIPEMD160"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_ripemd160_lib.ripemd160_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_ripemd160_lib.ripemd160_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+            raise ValueError("Error %d while hashing data with RIPEMD160"
+                             % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_ripemd160_lib.ripemd160_digest(self._state.get(),
+ bfr)
+ if result:
+            raise ValueError("Error %d while making RIPEMD160 digest"
+                             % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = RIPEMD160Hash()
+ result = _raw_ripemd160_lib.ripemd160_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying ripemd160" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh RIPEMD-160 hash object."""
+
+ return RIPEMD160Hash(data)
+
+
+def new(data=None):
+ """Create a new hash object.
+
+ :parameter data:
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`RIPEMD160Hash.update`.
+ :type data: byte string/byte array/memoryview
+
+ :Return: A :class:`RIPEMD160Hash` hash object
+ """
+
+ return RIPEMD160Hash().new(data)
+
+# The size of the resulting hash in bytes.
+digest_size = RIPEMD160Hash.digest_size
+
+# The internal block size of the hash algorithm in bytes.
+block_size = RIPEMD160Hash.block_size
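
A short sketch of the incremental API exposed by the RIPEMD160 module above (assuming the vendored package is importable); hashing the message in chunks yields the same digest as a single call:

    from Crypto.Hash import RIPEMD160

    # One-shot hashing.
    h1 = RIPEMD160.new(data=b"message digest")

    # Incremental hashing over chunks of the same message.
    h2 = RIPEMD160.new()
    h2.update(b"message ")
    h2.update(b"digest")

    assert h1.hexdigest() == h2.hexdigest()
    print(h1.hexdigest())                  # 40 hex characters for the 20-byte digest
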
diff --git a/lib/Crypto/Hash/RIPEMD160.pyi b/lib/Crypto/Hash/RIPEMD160.pyi
new file mode 100644
index 0000000..b619473
--- /dev/null
+++ b/lib/Crypto/Hash/RIPEMD160.pyi
@@ -0,0 +1,19 @@
+from typing import Union
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class RIPEMD160Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self, data: Buffer = ...) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> RIPEMD160Hash: ...
+ def new(self, data: Buffer = ...) -> RIPEMD160Hash: ...
+
+def new(data: Buffer = ...) -> RIPEMD160Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA.py b/lib/Crypto/Hash/SHA.py
new file mode 100644
index 0000000..0cc141c
--- /dev/null
+++ b/lib/Crypto/Hash/SHA.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+# This file exists for backward compatibility with old code that refers to
+# Crypto.Hash.SHA
+
+from Crypto.Hash.SHA1 import __doc__, new, block_size, digest_size
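
The shim keeps old imports working; a one-line sketch (assuming both modules are importable):

    from Crypto.Hash import SHA, SHA1

    # The legacy name is a thin alias, so both produce identical SHA-1 digests.
    assert SHA.new(b"abc").hexdigest() == SHA1.new(b"abc").hexdigest()
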
diff --git a/lib/Crypto/Hash/SHA.pyi b/lib/Crypto/Hash/SHA.pyi
new file mode 100644
index 0000000..4d7d57e
--- /dev/null
+++ b/lib/Crypto/Hash/SHA.pyi
@@ -0,0 +1,4 @@
+# This file exists for backward compatibility with old code that refers to
+# Crypto.Hash.SHA
+
+from Crypto.Hash.SHA1 import __doc__, new, block_size, digest_size
diff --git a/lib/Crypto/Hash/SHA1.py b/lib/Crypto/Hash/SHA1.py
new file mode 100644
index 0000000..f79d825
--- /dev/null
+++ b/lib/Crypto/Hash/SHA1.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import *
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_sha1_lib = load_pycryptodome_raw_lib("Crypto.Hash._SHA1",
+ """
+ #define SHA1_DIGEST_SIZE 20
+
+ int SHA1_init(void **shaState);
+ int SHA1_destroy(void *shaState);
+ int SHA1_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int SHA1_digest(const void *shaState,
+ uint8_t digest[SHA1_DIGEST_SIZE]);
+ int SHA1_copy(const void *src, void *dst);
+
+ int SHA1_pbkdf2_hmac_assist(const void *inner,
+ const void *outer,
+ const uint8_t first_digest[SHA1_DIGEST_SIZE],
+ uint8_t final_digest[SHA1_DIGEST_SIZE],
+ size_t iterations);
+ """)
+
+class SHA1Hash(object):
+ """A SHA-1 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 20
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 64
+ # ASN.1 Object ID
+ oid = "1.3.14.3.2.26"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_sha1_lib.SHA1_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating SHA1"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_sha1_lib.SHA1_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_sha1_lib.SHA1_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+            raise ValueError("Error %d while hashing data with SHA1"
+                             % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_sha1_lib.SHA1_digest(self._state.get(),
+ bfr)
+ if result:
+            raise ValueError("Error %d while making SHA1 digest"
+                             % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = SHA1Hash()
+ result = _raw_sha1_lib.SHA1_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA1" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA-1 hash object."""
+
+ return SHA1Hash(data)
+
+
+def new(data=None):
+ """Create a new hash object.
+
+ :parameter data:
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`SHA1Hash.update`.
+ :type data: byte string/byte array/memoryview
+
+ :Return: A :class:`SHA1Hash` hash object
+ """
+ return SHA1Hash().new(data)
+
+
+# The size of the resulting hash in bytes.
+digest_size = SHA1Hash.digest_size
+
+# The internal block size of the hash algorithm in bytes.
+block_size = SHA1Hash.block_size
+
+
+def _pbkdf2_hmac_assist(inner, outer, first_digest, iterations):
+ """Compute the expensive inner loop in PBKDF-HMAC."""
+
+ assert len(first_digest) == digest_size
+ assert iterations > 0
+
+    bfr = create_string_buffer(digest_size)
+ result = _raw_sha1_lib.SHA1_pbkdf2_hmac_assist(
+ inner._state.get(),
+ outer._state.get(),
+ first_digest,
+ bfr,
+ c_size_t(iterations))
+
+ if result:
+        raise ValueError("Error %d with PBKDF2-HMAC assist for SHA1" % result)
+
+ return get_raw_buffer(bfr)
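
The copy() method documented above makes it cheap to hash many messages that share a prefix; a sketch with illustrative data (assumes the vendored package is importable):

    from Crypto.Hash import SHA1

    prefix = SHA1.new(b"common header|")   # hash the shared prefix once

    # copy() clones the internal state, so each record only pays for its suffix.
    digests = []
    for suffix in (b"record-1", b"record-2", b"record-3"):
        h = prefix.copy()
        h.update(suffix)
        digests.append(h.hexdigest())
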
diff --git a/lib/Crypto/Hash/SHA1.pyi b/lib/Crypto/Hash/SHA1.pyi
new file mode 100644
index 0000000..d6c8e25
--- /dev/null
+++ b/lib/Crypto/Hash/SHA1.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA1Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self, data: Optional[Buffer] = ...) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA1Hash: ...
+ def new(self, data: Optional[Buffer] = ...) -> SHA1Hash: ...
+
+def new(data: Optional[Buffer] = ...) -> SHA1Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA224.py b/lib/Crypto/Hash/SHA224.py
new file mode 100644
index 0000000..f788b06
--- /dev/null
+++ b/lib/Crypto/Hash/SHA224.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_sha224_lib = load_pycryptodome_raw_lib("Crypto.Hash._SHA224",
+ """
+ int SHA224_init(void **shaState);
+ int SHA224_destroy(void *shaState);
+ int SHA224_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int SHA224_digest(const void *shaState,
+ uint8_t *digest,
+ size_t digest_size);
+ int SHA224_copy(const void *src, void *dst);
+
+ int SHA224_pbkdf2_hmac_assist(const void *inner,
+ const void *outer,
+ const uint8_t *first_digest,
+ uint8_t *final_digest,
+ size_t iterations,
+ size_t digest_size);
+ """)
+
+class SHA224Hash(object):
+ """A SHA-224 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 28
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 64
+ # ASN.1 Object ID
+ oid = '2.16.840.1.101.3.4.2.4'
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_sha224_lib.SHA224_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating SHA224"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_sha224_lib.SHA224_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_sha224_lib.SHA224_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while hashing data with SHA224"
+ % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_sha224_lib.SHA224_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size))
+ if result:
+ raise ValueError("Error %d while making SHA224 digest"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = SHA224Hash()
+ result = _raw_sha224_lib.SHA224_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA224" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA-224 hash object."""
+
+ return SHA224Hash(data)
+
+
+def new(data=None):
+ """Create a new hash object.
+
+ :parameter data:
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`SHA224Hash.update`.
+ :type data: byte string/byte array/memoryview
+
+ :Return: A :class:`SHA224Hash` hash object
+ """
+ return SHA224Hash().new(data)
+
+
+# The size of the resulting hash in bytes.
+digest_size = SHA224Hash.digest_size
+
+# The internal block size of the hash algorithm in bytes.
+block_size = SHA224Hash.block_size
+
+
+def _pbkdf2_hmac_assist(inner, outer, first_digest, iterations):
+ """Compute the expensive inner loop in PBKDF-HMAC."""
+
+ assert iterations > 0
+
+    bfr = create_string_buffer(len(first_digest))
+ result = _raw_sha224_lib.SHA224_pbkdf2_hmac_assist(
+ inner._state.get(),
+ outer._state.get(),
+ first_digest,
+ bfr,
+ c_size_t(iterations),
+ c_size_t(len(first_digest)))
+
+ if result:
+ raise ValueError("Error %d with PBKDF2-HMAC assist for SHA224" % result)
+
+ return get_raw_buffer(bfr)
diff --git a/lib/Crypto/Hash/SHA224.pyi b/lib/Crypto/Hash/SHA224.pyi
new file mode 100644
index 0000000..613a7f9
--- /dev/null
+++ b/lib/Crypto/Hash/SHA224.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA224Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self, data: Optional[Buffer] = ...) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA224Hash: ...
+ def new(self, data: Optional[Buffer] = ...) -> SHA224Hash: ...
+
+def new(data: Optional[Buffer] = ...) -> SHA224Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA256.py b/lib/Crypto/Hash/SHA256.py
new file mode 100644
index 0000000..957aa37
--- /dev/null
+++ b/lib/Crypto/Hash/SHA256.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_sha256_lib = load_pycryptodome_raw_lib("Crypto.Hash._SHA256",
+ """
+ int SHA256_init(void **shaState);
+ int SHA256_destroy(void *shaState);
+ int SHA256_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int SHA256_digest(const void *shaState,
+ uint8_t *digest,
+ size_t digest_size);
+ int SHA256_copy(const void *src, void *dst);
+
+ int SHA256_pbkdf2_hmac_assist(const void *inner,
+ const void *outer,
+ const uint8_t *first_digest,
+ uint8_t *final_digest,
+ size_t iterations,
+ size_t digest_size);
+ """)
+
+class SHA256Hash(object):
+ """A SHA-256 hash object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 32
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 64
+ # ASN.1 Object ID
+ oid = "2.16.840.1.101.3.4.2.1"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_sha256_lib.SHA256_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating SHA256"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_sha256_lib.SHA256_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_sha256_lib.SHA256_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while hashing data with SHA256"
+ % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_sha256_lib.SHA256_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size))
+ if result:
+ raise ValueError("Error %d while making SHA256 digest"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = SHA256Hash()
+ result = _raw_sha256_lib.SHA256_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA256" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA-256 hash object."""
+
+ return SHA256Hash(data)
+
+def new(data=None):
+ """Create a new hash object.
+
+ :parameter data:
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`SHA256Hash.update`.
+ :type data: byte string/byte array/memoryview
+
+ :Return: A :class:`SHA256Hash` hash object
+ """
+
+ return SHA256Hash().new(data)
+
+
+# The size of the resulting hash in bytes.
+digest_size = SHA256Hash.digest_size
+
+# The internal block size of the hash algorithm in bytes.
+block_size = SHA256Hash.block_size
+
+
+def _pbkdf2_hmac_assist(inner, outer, first_digest, iterations):
+ """Compute the expensive inner loop in PBKDF-HMAC."""
+
+ assert iterations > 0
+
+    bfr = create_string_buffer(len(first_digest))
+ result = _raw_sha256_lib.SHA256_pbkdf2_hmac_assist(
+ inner._state.get(),
+ outer._state.get(),
+ first_digest,
+ bfr,
+ c_size_t(iterations),
+ c_size_t(len(first_digest)))
+
+ if result:
+ raise ValueError("Error %d with PBKDF2-HMAC assist for SHA256" % result)
+
+ return get_raw_buffer(bfr)
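
A quick sanity check of the vendored SHA256 module against the standard library (a sketch; assumes hashlib and the vendored package are both importable in the target environment):

    import hashlib
    from Crypto.Hash import SHA256

    data = b"The quick brown fox jumps over the lazy dog"

    # The C-backed implementation above should agree with hashlib byte for byte.
    assert SHA256.new(data).hexdigest() == hashlib.sha256(data).hexdigest()
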
diff --git a/lib/Crypto/Hash/SHA256.pyi b/lib/Crypto/Hash/SHA256.pyi
new file mode 100644
index 0000000..cbf21bf
--- /dev/null
+++ b/lib/Crypto/Hash/SHA256.pyi
@@ -0,0 +1,18 @@
+from typing import Union, Optional
+
+
+class SHA256Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+ def __init__(self, data: Optional[Union[bytes, bytearray, memoryview]]=None) -> None: ...
+ def update(self, data: Union[bytes, bytearray, memoryview]) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA256Hash: ...
+ def new(self, data: Optional[Union[bytes, bytearray, memoryview]]=None) -> SHA256Hash: ...
+
+def new(data: Optional[Union[bytes, bytearray, memoryview]]=None) -> SHA256Hash: ...
+
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA384.py b/lib/Crypto/Hash/SHA384.py
new file mode 100644
index 0000000..a98fa9a
--- /dev/null
+++ b/lib/Crypto/Hash/SHA384.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_sha384_lib = load_pycryptodome_raw_lib("Crypto.Hash._SHA384",
+ """
+ int SHA384_init(void **shaState);
+ int SHA384_destroy(void *shaState);
+ int SHA384_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int SHA384_digest(const void *shaState,
+ uint8_t *digest,
+ size_t digest_size);
+ int SHA384_copy(const void *src, void *dst);
+
+ int SHA384_pbkdf2_hmac_assist(const void *inner,
+ const void *outer,
+ const uint8_t *first_digest,
+ uint8_t *final_digest,
+ size_t iterations,
+ size_t digest_size);
+ """)
+
+class SHA384Hash(object):
+ """A SHA-384 hash object.
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 48
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 128
+ # ASN.1 Object ID
+ oid = '2.16.840.1.101.3.4.2.2'
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_sha384_lib.SHA384_init(state.address_of())
+ if result:
+ raise ValueError("Error %d while instantiating SHA384"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_sha384_lib.SHA384_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_sha384_lib.SHA384_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while hashing data with SHA384"
+ % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_sha384_lib.SHA384_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size))
+ if result:
+ raise ValueError("Error %d while making SHA384 digest"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = SHA384Hash()
+ result = _raw_sha384_lib.SHA384_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA384" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA-384 hash object."""
+
+ return SHA384Hash(data)
+
+
+def new(data=None):
+ """Create a new hash object.
+
+ :parameter data:
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`SHA384Hash.update`.
+ :type data: byte string/byte array/memoryview
+
+ :Return: A :class:`SHA384Hash` hash object
+ """
+
+ return SHA384Hash().new(data)
+
+
+# The size of the resulting hash in bytes.
+digest_size = SHA384Hash.digest_size
+
+# The internal block size of the hash algorithm in bytes.
+block_size = SHA384Hash.block_size
+
+
+def _pbkdf2_hmac_assist(inner, outer, first_digest, iterations):
+ """Compute the expensive inner loop in PBKDF-HMAC."""
+
+ assert iterations > 0
+
+    bfr = create_string_buffer(len(first_digest))
+ result = _raw_sha384_lib.SHA384_pbkdf2_hmac_assist(
+ inner._state.get(),
+ outer._state.get(),
+ first_digest,
+ bfr,
+ c_size_t(iterations),
+ c_size_t(len(first_digest)))
+
+ if result:
+ raise ValueError("Error %d with PBKDF2-HMAC assist for SHA384" % result)
+
+ return get_raw_buffer(bfr)
diff --git a/lib/Crypto/Hash/SHA384.pyi b/lib/Crypto/Hash/SHA384.pyi
new file mode 100644
index 0000000..c2aab9e
--- /dev/null
+++ b/lib/Crypto/Hash/SHA384.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA384Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self, data: Optional[Buffer] = ...) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA384Hash: ...
+ def new(self, data: Optional[Buffer] = ...) -> SHA384Hash: ...
+
+def new(data: Optional[Buffer] = ...) -> SHA384Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA3_224.py b/lib/Crypto/Hash/SHA3_224.py
new file mode 100644
index 0000000..54556d0
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_224.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Hash.keccak import _raw_keccak_lib
+
+class SHA3_224_Hash(object):
+ """A SHA3-224 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 28
+
+ # ASN.1 Object ID
+ oid = "2.16.840.1.101.3.4.2.7"
+
+ # Input block size for HMAC
+ block_size = 144
+
+ def __init__(self, data, update_after_digest):
+ self._update_after_digest = update_after_digest
+ self._digest_done = False
+ self._padding = 0x06
+
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(self.digest_size * 2),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating SHA-3/224"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._digest_done and not self._update_after_digest:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data))
+ )
+ if result:
+ raise ValueError("Error %d while updating SHA-3/224"
+ % result)
+ return self
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ self._digest_done = True
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_keccak_lib.keccak_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size),
+ c_ubyte(self._padding))
+ if result:
+            raise ValueError("Error %d while making SHA-3/224 digest"
+                             % result)
+
+ self._digest_value = get_raw_buffer(bfr)
+ return self._digest_value
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = self.new()
+ result = _raw_keccak_lib.keccak_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA3-224" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA3-224 hash object."""
+
+ return type(self)(data, self._update_after_digest)
+
+
+def new(*args, **kwargs):
+ """Create a new hash object.
+
+ Args:
+ data (byte string/byte array/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ update_after_digest (boolean):
+ Whether :meth:`digest` can be followed by another :meth:`update`
+ (default: ``False``).
+
+ :Return: A :class:`SHA3_224_Hash` hash object
+ """
+
+ data = kwargs.pop("data", None)
+ update_after_digest = kwargs.pop("update_after_digest", False)
+ if len(args) == 1:
+ if data:
+ raise ValueError("Initial data for hash specified twice")
+ data = args[0]
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return SHA3_224_Hash(data, update_after_digest)
+
+# The size of the resulting hash in bytes.
+digest_size = SHA3_224_Hash.digest_size
+
+# Input block size for HMAC
+block_size = 144
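
The update_after_digest flag accepted by new() controls whether the object stays usable after digest(); a small sketch (assuming the vendored package and its keccak extension load correctly):

    from Crypto.Hash import SHA3_224

    # Default behaviour: the object is finalised once digest() is called.
    h = SHA3_224.new(data=b"part one")
    h.digest()
    try:
        h.update(b"part two")
    except TypeError:
        print("finalised, as expected")

    # With update_after_digest=True the object accepts further data.
    h2 = SHA3_224.new(data=b"part one", update_after_digest=True)
    h2.digest()
    h2.update(b"part two")                 # allowed now
    print(h2.hexdigest())
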
diff --git a/lib/Crypto/Hash/SHA3_224.pyi b/lib/Crypto/Hash/SHA3_224.pyi
new file mode 100644
index 0000000..2180821
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_224.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA3_224_Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+ def __init__(self, data: Optional[Buffer], update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> SHA3_224_Hash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA3_224_Hash: ...
+ def new(self, data: Optional[Buffer]) -> SHA3_224_Hash: ...
+
+def new(__data: Buffer = ..., update_after_digest: bool = ...) -> SHA3_224_Hash: ...
+
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA3_256.py b/lib/Crypto/Hash/SHA3_256.py
new file mode 100644
index 0000000..b4f11ee
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_256.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Hash.keccak import _raw_keccak_lib
+
+class SHA3_256_Hash(object):
+ """A SHA3-256 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 32
+
+ # ASN.1 Object ID
+ oid = "2.16.840.1.101.3.4.2.8"
+
+ # Input block size for HMAC
+ block_size = 136
+
+ def __init__(self, data, update_after_digest):
+ self._update_after_digest = update_after_digest
+ self._digest_done = False
+ self._padding = 0x06
+
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(self.digest_size * 2),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating SHA-3/256"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._digest_done and not self._update_after_digest:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data))
+ )
+ if result:
+ raise ValueError("Error %d while updating SHA-3/256"
+ % result)
+ return self
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ self._digest_done = True
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_keccak_lib.keccak_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size),
+ c_ubyte(self._padding))
+ if result:
+            raise ValueError("Error %d while making SHA-3/256 digest"
+                             % result)
+
+ self._digest_value = get_raw_buffer(bfr)
+ return self._digest_value
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = self.new()
+ result = _raw_keccak_lib.keccak_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA3-256" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA3-256 hash object."""
+
+ return type(self)(data, self._update_after_digest)
+
+
+def new(*args, **kwargs):
+ """Create a new hash object.
+
+ Args:
+ data (byte string/byte array/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ update_after_digest (boolean):
+ Whether :meth:`digest` can be followed by another :meth:`update`
+ (default: ``False``).
+
+ :Return: A :class:`SHA3_256_Hash` hash object
+ """
+
+ data = kwargs.pop("data", None)
+ update_after_digest = kwargs.pop("update_after_digest", False)
+ if len(args) == 1:
+ if data:
+ raise ValueError("Initial data for hash specified twice")
+ data = args[0]
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return SHA3_256_Hash(data, update_after_digest)
+
+# The size of the resulting hash in bytes.
+digest_size = SHA3_256_Hash.digest_size
+
+# Input block size for HMAC
+block_size = 136
diff --git a/lib/Crypto/Hash/SHA3_256.pyi b/lib/Crypto/Hash/SHA3_256.pyi
new file mode 100644
index 0000000..88436bd
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_256.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA3_256_Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+ def __init__(self, data: Optional[Buffer], update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> SHA3_256_Hash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA3_256_Hash: ...
+ def new(self, data: Optional[Buffer]) -> SHA3_256_Hash: ...
+
+def new(__data: Buffer = ..., update_after_digest: bool = ...) -> SHA3_256_Hash: ...
+
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA3_384.py b/lib/Crypto/Hash/SHA3_384.py
new file mode 100644
index 0000000..12f61ce
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_384.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Hash.keccak import _raw_keccak_lib
+
+class SHA3_384_Hash(object):
+ """A SHA3-384 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 48
+
+ # ASN.1 Object ID
+ oid = "2.16.840.1.101.3.4.2.9"
+
+ # Input block size for HMAC
+ block_size = 104
+
+ def __init__(self, data, update_after_digest):
+ self._update_after_digest = update_after_digest
+ self._digest_done = False
+ self._padding = 0x06
+
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(self.digest_size * 2),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating SHA-3/384"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._digest_done and not self._update_after_digest:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while updating SHA-3/384"
+ % result)
+ return self
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ self._digest_done = True
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_keccak_lib.keccak_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size),
+ c_ubyte(self._padding))
+ if result:
+            raise ValueError("Error %d while making SHA-3/384 digest"
+                             % result)
+
+ self._digest_value = get_raw_buffer(bfr)
+ return self._digest_value
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = self.new()
+ result = _raw_keccak_lib.keccak_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA3-384" % result)
+ return clone
+
+    def new(self, data=None):
+        """Create a fresh SHA3-384 hash object."""
+
+        return type(self)(data, self._update_after_digest)
+
+
+def new(*args, **kwargs):
+ """Create a new hash object.
+
+ Args:
+ data (byte string/byte array/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ update_after_digest (boolean):
+ Whether :meth:`digest` can be followed by another :meth:`update`
+ (default: ``False``).
+
+ :Return: A :class:`SHA3_384_Hash` hash object
+ """
+
+ data = kwargs.pop("data", None)
+ update_after_digest = kwargs.pop("update_after_digest", False)
+ if len(args) == 1:
+ if data:
+ raise ValueError("Initial data for hash specified twice")
+ data = args[0]
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return SHA3_384_Hash(data, update_after_digest)
+
+# The size of the resulting hash in bytes.
+digest_size = SHA3_384_Hash.digest_size
+
+# Input block size for HMAC
+block_size = 104
diff --git a/lib/Crypto/Hash/SHA3_384.pyi b/lib/Crypto/Hash/SHA3_384.pyi
new file mode 100644
index 0000000..98d00c6
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_384.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA3_384_Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+ def __init__(self, data: Optional[Buffer], update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> SHA3_384_Hash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA3_384_Hash: ...
+ def new(self, data: Optional[Buffer]) -> SHA3_384_Hash: ...
+
+def new(__data: Buffer = ..., update_after_digest: bool = ...) -> SHA3_384_Hash: ...
+
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA3_512.py b/lib/Crypto/Hash/SHA3_512.py
new file mode 100644
index 0000000..de8880c
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_512.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Hash.keccak import _raw_keccak_lib
+
+class SHA3_512_Hash(object):
+ """A SHA3-512 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The size of the resulting hash in bytes.
+ digest_size = 64
+
+ # ASN.1 Object ID
+ oid = "2.16.840.1.101.3.4.2.10"
+
+ # Input block size for HMAC
+ block_size = 72
+
+ def __init__(self, data, update_after_digest):
+ self._update_after_digest = update_after_digest
+ self._digest_done = False
+ self._padding = 0x06
+
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(self.digest_size * 2),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating SHA-3/512"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._digest_done and not self._update_after_digest:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while updating SHA-3/512"
+ % result)
+ return self
+
+ def digest(self):
+
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ self._digest_done = True
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_keccak_lib.keccak_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size),
+ c_ubyte(self._padding))
+ if result:
+ raise ValueError("Error %d while computing SHA-3/512 digest"
+ % result)
+
+ self._digest_value = get_raw_buffer(bfr)
+ return self._digest_value
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = self.new()
+ result = _raw_keccak_lib.keccak_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA3-512" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA3-512 hash object."""
+
+ return type(self)(data, self._update_after_digest)
+
+
+def new(*args, **kwargs):
+ """Create a new hash object.
+
+ Args:
+ data (byte string/byte array/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ update_after_digest (boolean):
+ Whether :meth:`digest` can be followed by another :meth:`update`
+ (default: ``False``).
+
+ :Return: A :class:`SHA3_512_Hash` hash object
+ """
+
+ data = kwargs.pop("data", None)
+ update_after_digest = kwargs.pop("update_after_digest", False)
+ if len(args) == 1:
+ if data:
+ raise ValueError("Initial data for hash specified twice")
+ data = args[0]
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return SHA3_512_Hash(data, update_after_digest)
+
+# The size of the resulting hash in bytes.
+digest_size = SHA3_512_Hash.digest_size
+
+# Input block size for HMAC
+block_size = 72
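For illustration, a minimal usage sketch of the SHA3_512 module added above
(assumes lib/ is on the Python import path; the input bytes are arbitrary):

    from Crypto.Hash import SHA3_512

    # One-shot: pass the first chunk directly to new().
    h = SHA3_512.new(data=b"hello")
    print(h.hexdigest())              # 128 hex characters (64-byte digest)

    # Incremental: update() may be called repeatedly before digest().
    h2 = SHA3_512.new()
    h2.update(b"hel")
    h2.update(b"lo")
    assert h2.digest() == SHA3_512.new(data=b"hello").digest()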
diff --git a/lib/Crypto/Hash/SHA3_512.pyi b/lib/Crypto/Hash/SHA3_512.pyi
new file mode 100644
index 0000000..cdeec16
--- /dev/null
+++ b/lib/Crypto/Hash/SHA3_512.pyi
@@ -0,0 +1,19 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA3_512_Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+ def __init__(self, data: Optional[Buffer], update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> SHA3_512_Hash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA3_512_Hash: ...
+ def new(self, data: Optional[Buffer]) -> SHA3_512_Hash: ...
+
+def new(__data: Buffer = ..., update_after_digest: bool = ...) -> SHA3_512_Hash: ...
+
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHA512.py b/lib/Crypto/Hash/SHA512.py
new file mode 100644
index 0000000..403fe45
--- /dev/null
+++ b/lib/Crypto/Hash/SHA512.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr)
+
+_raw_sha512_lib = load_pycryptodome_raw_lib("Crypto.Hash._SHA512",
+ """
+ int SHA512_init(void **shaState,
+ size_t digest_size);
+ int SHA512_destroy(void *shaState);
+ int SHA512_update(void *hs,
+ const uint8_t *buf,
+ size_t len);
+ int SHA512_digest(const void *shaState,
+ uint8_t *digest,
+ size_t digest_size);
+ int SHA512_copy(const void *src, void *dst);
+
+ int SHA512_pbkdf2_hmac_assist(const void *inner,
+ const void *outer,
+ const uint8_t *first_digest,
+ uint8_t *final_digest,
+ size_t iterations,
+ size_t digest_size);
+ """)
+
+class SHA512Hash(object):
+ """A SHA-512 hash object (possibly in its truncated version SHA-512/224 or
+ SHA-512/256).
+ Do not instantiate directly. Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+
+ :ivar block_size: the size in bytes of the internal message block,
+ input to the compression function
+ :vartype block_size: integer
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ # The internal block size of the hash algorithm in bytes.
+ block_size = 128
+
+ def __init__(self, data, truncate):
+ self._truncate = truncate
+
+ if truncate is None:
+ self.oid = "2.16.840.1.101.3.4.2.3"
+ self.digest_size = 64
+ elif truncate == "224":
+ self.oid = "2.16.840.1.101.3.4.2.5"
+ self.digest_size = 28
+ elif truncate == "256":
+ self.oid = "2.16.840.1.101.3.4.2.6"
+ self.digest_size = 32
+ else:
+ raise ValueError("Incorrect truncation length. It must be '224' or '256'.")
+
+ state = VoidPointer()
+ result = _raw_sha512_lib.SHA512_init(state.address_of(),
+ c_size_t(self.digest_size))
+ if result:
+ raise ValueError("Error %d while instantiating SHA-512"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_sha512_lib.SHA512_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ result = _raw_sha512_lib.SHA512_update(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while hashing data with SHA512"
+ % result)
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_sha512_lib.SHA512_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size))
+ if result:
+ raise ValueError("Error %d while making SHA512 digest"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def copy(self):
+ """Return a copy ("clone") of the hash object.
+
+ The copy will have the same internal state as the original hash
+ object.
+ This can be used to efficiently compute the digests of strings that
+ share a common initial substring.
+
+ :return: A hash object of the same type
+ """
+
+ clone = SHA512Hash(None, self._truncate)
+ result = _raw_sha512_lib.SHA512_copy(self._state.get(),
+ clone._state.get())
+ if result:
+ raise ValueError("Error %d while copying SHA512" % result)
+ return clone
+
+ def new(self, data=None):
+ """Create a fresh SHA-512 hash object."""
+
+ return SHA512Hash(data, self._truncate)
+
+
+def new(data=None, truncate=None):
+ """Create a new hash object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ Optional. The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`SHA512Hash.update`.
+ truncate (string):
+ Optional. The desired length of the digest. It can be either "224" or
+ "256". If not present, the digest is 512 bits long.
+ Passing this parameter is **not** equivalent to simply truncating
+ the output digest.
+
+ :Return: A :class:`SHA512Hash` hash object
+ """
+
+ return SHA512Hash(data, truncate)
+
+
+# The size of the full SHA-512 hash in bytes.
+digest_size = 64
+
+# The internal block size of the hash algorithm in bytes.
+block_size = 128
+
+
+def _pbkdf2_hmac_assist(inner, outer, first_digest, iterations):
+ """Compute the expensive inner loop in PBKDF-HMAC."""
+
+ assert iterations > 0
+
+ bfr = create_string_buffer(len(first_digest))
+ result = _raw_sha512_lib.SHA512_pbkdf2_hmac_assist(
+ inner._state.get(),
+ outer._state.get(),
+ first_digest,
+ bfr,
+ c_size_t(iterations),
+ c_size_t(len(first_digest)))
+
+ if result:
+ raise ValueError("Error %d with PBKDF2-HMAC assist for SHA512" % result)
+
+ return get_raw_buffer(bfr)
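A short sketch of the SHA512 module above, including the truncated variants
(arbitrary input; assumes lib/ is importable as the Crypto package):

    from Crypto.Hash import SHA512

    full = SHA512.new(data=b"message")                   # SHA-512, 64-byte digest
    t256 = SHA512.new(data=b"message", truncate="256")   # SHA-512/256, 32-byte digest

    assert full.digest_size == 64
    assert t256.digest_size == 32
    # SHA-512/256 uses its own initial values, so its output is not simply
    # the first 32 bytes of the SHA-512 digest.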
diff --git a/lib/Crypto/Hash/SHA512.pyi b/lib/Crypto/Hash/SHA512.pyi
new file mode 100644
index 0000000..f219ee9
--- /dev/null
+++ b/lib/Crypto/Hash/SHA512.pyi
@@ -0,0 +1,22 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHA512Hash(object):
+ digest_size: int
+ block_size: int
+ oid: str
+
+ def __init__(self,
+ data: Optional[Buffer],
+ truncate: Optional[str]) -> None: ...
+ def update(self, data: Buffer) -> None: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def copy(self) -> SHA512Hash: ...
+ def new(self, data: Optional[Buffer] = ...) -> SHA512Hash: ...
+
+def new(data: Optional[Buffer] = ...,
+ truncate: Optional[str] = ...) -> SHA512Hash: ...
+digest_size: int
+block_size: int
diff --git a/lib/Crypto/Hash/SHAKE128.py b/lib/Crypto/Hash/SHAKE128.py
new file mode 100644
index 0000000..894a41b
--- /dev/null
+++ b/lib/Crypto/Hash/SHAKE128.py
@@ -0,0 +1,129 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Hash.keccak import _raw_keccak_lib
+
+class SHAKE128_XOF(object):
+ """A SHAKE128 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+ """
+
+ # ASN.1 Object ID
+ oid = "2.16.840.1.101.3.4.2.11"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(32),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating SHAKE128"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ self._is_squeezing = False
+ self._padding = 0x1F
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._is_squeezing:
+ raise TypeError("You cannot call 'update' after the first 'read'")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while updating SHAKE128 state"
+ % result)
+ return self
+
+ def read(self, length):
+ """
+ Compute the next piece of XOF output.
+
+ .. note::
+ You cannot use :meth:`update` anymore after the first call to
+ :meth:`read`.
+
+ Args:
+ length (integer): the number of bytes this method must return
+
+ :return: the next piece of XOF output (of the given length)
+ :rtype: byte string
+ """
+
+ self._is_squeezing = True
+ bfr = create_string_buffer(length)
+ result = _raw_keccak_lib.keccak_squeeze(self._state.get(),
+ bfr,
+ c_size_t(length),
+ c_ubyte(self._padding))
+ if result:
+ raise ValueError("Error %d while extracting from SHAKE128"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def new(self, data=None):
+ return type(self)(data=data)
+
+
+def new(data=None):
+ """Return a fresh instance of a SHAKE128 object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ Optional.
+
+ :Return: A :class:`SHAKE128_XOF` object
+ """
+
+ return SHAKE128_XOF(data=data)
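A minimal sketch of the SHAKE128 XOF above (arbitrary seed bytes):

    from Crypto.Hash import SHAKE128

    xof = SHAKE128.new(data=b"seed")
    first = xof.read(16)    # first 16 bytes of output
    more = xof.read(16)     # next 16 bytes; reads continue the same stream
    # xof.update(...) would now raise TypeError: no absorbing after the first read().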
diff --git a/lib/Crypto/Hash/SHAKE128.pyi b/lib/Crypto/Hash/SHAKE128.pyi
new file mode 100644
index 0000000..f618881
--- /dev/null
+++ b/lib/Crypto/Hash/SHAKE128.pyi
@@ -0,0 +1,13 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHAKE128_XOF(object):
+ oid: str
+ def __init__(self,
+ data: Optional[Buffer] = ...) -> None: ...
+ def update(self, data: Buffer) -> SHAKE128_XOF: ...
+ def read(self, length: int) -> bytes: ...
+ def new(self, data: Optional[Buffer] = ...) -> SHAKE128_XOF: ...
+
+def new(data: Optional[Buffer] = ...) -> SHAKE128_XOF: ...
diff --git a/lib/Crypto/Hash/SHAKE256.py b/lib/Crypto/Hash/SHAKE256.py
new file mode 100644
index 0000000..f75b822
--- /dev/null
+++ b/lib/Crypto/Hash/SHAKE256.py
@@ -0,0 +1,130 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Hash.keccak import _raw_keccak_lib
+
+class SHAKE256_XOF(object):
+ """A SHAKE256 hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar oid: ASN.1 Object ID
+ :vartype oid: string
+ """
+
+ # ASN.1 Object ID
+ oid = "2.16.840.1.101.3.4.2.12"
+
+ def __init__(self, data=None):
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(64),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating SHAKE256"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ self._is_squeezing = False
+ self._padding = 0x1F
+
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._is_squeezing:
+ raise TypeError("You cannot call 'update' after the first 'read'")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while updating SHAKE256 state"
+ % result)
+ return self
+
+ def read(self, length):
+ """
+ Compute the next piece of XOF output.
+
+ .. note::
+ You cannot use :meth:`update` anymore after the first call to
+ :meth:`read`.
+
+ Args:
+ length (integer): the number of bytes this method must return
+
+ :return: the next piece of XOF output (of the given length)
+ :rtype: byte string
+ """
+
+ self._is_squeezing = True
+ bfr = create_string_buffer(length)
+ result = _raw_keccak_lib.keccak_squeeze(self._state.get(),
+ bfr,
+ c_size_t(length),
+ c_ubyte(self._padding))
+ if result:
+ raise ValueError("Error %d while extracting from SHAKE256"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+ def new(self, data=None):
+ return type(self)(data=data)
+
+
+def new(data=None):
+ """Return a fresh instance of a SHAKE256 object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ Optional.
+
+ :Return: A :class:`SHAKE256_XOF` object
+ """
+
+ return SHAKE256_XOF(data=data)
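The SHAKE256 object behaves the same way; successive reads form a single
output stream, so chunked reads match one large read (sketch, arbitrary input):

    from Crypto.Hash import SHAKE256

    one = SHAKE256.new(data=b"seed").read(64)
    x = SHAKE256.new(data=b"seed")
    assert x.read(32) + x.read(32) == one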
diff --git a/lib/Crypto/Hash/SHAKE256.pyi b/lib/Crypto/Hash/SHAKE256.pyi
new file mode 100644
index 0000000..029347a
--- /dev/null
+++ b/lib/Crypto/Hash/SHAKE256.pyi
@@ -0,0 +1,13 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class SHAKE256_XOF(object):
+ oid: str
+ def __init__(self,
+ data: Optional[Buffer] = ...) -> None: ...
+ def update(self, data: Buffer) -> SHAKE256_XOF: ...
+ def read(self, length: int) -> bytes: ...
+ def new(self, data: Optional[Buffer] = ...) -> SHAKE256_XOF: ...
+
+def new(data: Optional[Buffer] = ...) -> SHAKE256_XOF: ...
diff --git a/lib/Crypto/Hash/TupleHash128.py b/lib/Crypto/Hash/TupleHash128.py
new file mode 100644
index 0000000..8fd3283
--- /dev/null
+++ b/lib/Crypto/Hash/TupleHash128.py
@@ -0,0 +1,138 @@
+# ===================================================================
+#
+# Copyright (c) 2021, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord, is_bytes, tobytes
+
+from . import cSHAKE128
+from .cSHAKE128 import _encode_str, _right_encode
+
+
+class TupleHash(object):
+ """A Tuple hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+ """
+
+ def __init__(self, custom, cshake, digest_size):
+
+ self.digest_size = digest_size
+
+ self._cshake = cshake._new(b'', custom, b'TupleHash')
+ self._digest = None
+
+ def update(self, data):
+ """Authenticate the next byte string in the tuple.
+
+ Args:
+ data (bytes/bytearray/memoryview): The next byte string.
+ """
+
+ if self._digest is not None:
+ raise TypeError("You cannot call 'update' after 'digest' or 'hexdigest'")
+
+ if not is_bytes(data):
+ raise TypeError("You can only call 'update' on bytes")
+
+ self._cshake.update(_encode_str(tobytes(data)))
+
+ return self
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the tuple of byte strings.
+
+ :return: The hash digest. Binary form.
+ :rtype: byte string
+ """
+
+ if self._digest is None:
+ self._cshake.update(_right_encode(self.digest_size * 8))
+ self._digest = self._cshake.read(self.digest_size)
+
+ return self._digest
+
+ def hexdigest(self):
+ """Return the **printable** digest of the tuple of byte strings.
+
+ :return: The hash digest. Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in tuple(self.digest())])
+
+ def new(self, **kwargs):
+ """Return a new instance of a TupleHash object.
+ See :func:`new`.
+ """
+
+ if "digest_bytes" not in kwargs and "digest_bits" not in kwargs:
+ kwargs["digest_bytes"] = self.digest_size
+
+ return new(**kwargs)
+
+
+def new(**kwargs):
+ """Create a new TupleHash128 object.
+
+ Args:
+ digest_bytes (integer):
+ Optional. The size of the digest, in bytes.
+ Default is 64. Minimum is 8.
+ digest_bits (integer):
+ Optional and alternative to ``digest_bytes``.
+ The size of the digest, in bits (and in steps of 8).
+ Default is 512. Minimum is 64.
+ custom (bytes):
+ Optional.
+ A customization bytestring (``S`` in SP 800-185).
+
+ :Return: A :class:`TupleHash` object
+ """
+
+ digest_bytes = kwargs.pop("digest_bytes", None)
+ digest_bits = kwargs.pop("digest_bits", None)
+ if None not in (digest_bytes, digest_bits):
+ raise TypeError("Only one digest parameter may be provided")
+ if (None, None) == (digest_bytes, digest_bits):
+ digest_bytes = 64
+ if digest_bytes is not None:
+ if digest_bytes < 8:
+ raise ValueError("'digest_bytes' must be at least 8")
+ else:
+ if digest_bits < 64 or digest_bits % 8:
+ raise ValueError("'digest_bits' must be at least 64 "
+ "and a multiple of 8")
+ digest_bytes = digest_bits // 8
+
+ custom = kwargs.pop("custom", b'')
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return TupleHash(custom, cSHAKE128, digest_bytes)
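A small sketch of the TupleHash128 module above: each update() authenticates
one tuple element, and the encoding is unambiguous, so splitting the same
bytes differently yields a different digest (arbitrary input):

    from Crypto.Hash import TupleHash128

    h1 = TupleHash128.new(digest_bytes=32)
    h1.update(b"ab")
    h1.update(b"c")

    h2 = TupleHash128.new(digest_bytes=32)
    h2.update(b"a")
    h2.update(b"bc")

    assert h1.digest() != h2.digest()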
diff --git a/lib/Crypto/Hash/TupleHash128.pyi b/lib/Crypto/Hash/TupleHash128.pyi
new file mode 100644
index 0000000..3b1e81e
--- /dev/null
+++ b/lib/Crypto/Hash/TupleHash128.pyi
@@ -0,0 +1,22 @@
+from typing import Any, Union
+from types import ModuleType
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class TupleHash(object):
+ digest_size: int
+ def __init__(self,
+ custom: bytes,
+ cshake: ModuleType,
+ digest_size: int) -> None: ...
+ def update(self, data: Buffer) -> TupleHash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def new(self,
+ digest_bytes: int = ...,
+ digest_bits: int = ...,
+ custom: bytes = ...) -> TupleHash: ...
+
+def new(digest_bytes: int = ...,
+ digest_bits: int = ...,
+ custom: bytes = ...) -> TupleHash: ...
diff --git a/lib/Crypto/Hash/TupleHash256.py b/lib/Crypto/Hash/TupleHash256.py
new file mode 100644
index 0000000..9b4fba0
--- /dev/null
+++ b/lib/Crypto/Hash/TupleHash256.py
@@ -0,0 +1,73 @@
+# ===================================================================
+#
+# Copyright (c) 2021, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from . import cSHAKE256
+from .TupleHash128 import TupleHash
+
+
+def new(**kwargs):
+ """Create a new TupleHash256 object.
+
+ Args:
+ digest_bytes (integer):
+ Optional. The size of the digest, in bytes.
+ Default is 64. Minimum is 8.
+ digest_bits (integer):
+ Optional and alternative to ``digest_bytes``.
+ The size of the digest, in bits (and in steps of 8).
+ Default is 512. Minimum is 64.
+ custom (bytes):
+ Optional.
+ A customization bytestring (``S`` in SP 800-185).
+
+ :Return: A :class:`TupleHash` object
+ """
+
+ digest_bytes = kwargs.pop("digest_bytes", None)
+ digest_bits = kwargs.pop("digest_bits", None)
+ if None not in (digest_bytes, digest_bits):
+ raise TypeError("Only one digest parameter may be provided")
+ if (None, None) == (digest_bytes, digest_bits):
+ digest_bytes = 64
+ if digest_bytes is not None:
+ if digest_bytes < 8:
+ raise ValueError("'digest_bytes' must be at least 8")
+ else:
+ if digest_bits < 64 or digest_bits % 8:
+ raise ValueError("'digest_bits' must be at least 64 "
+ "and a multiple of 8")
+ digest_bytes = digest_bits // 8
+
+ custom = kwargs.pop("custom", b'')
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return TupleHash(custom, cSHAKE256, digest_bytes)
diff --git a/lib/Crypto/Hash/TupleHash256.pyi b/lib/Crypto/Hash/TupleHash256.pyi
new file mode 100644
index 0000000..82d943f
--- /dev/null
+++ b/lib/Crypto/Hash/TupleHash256.pyi
@@ -0,0 +1,5 @@
+from .TupleHash128 import TupleHash
+
+def new(digest_bytes: int = ...,
+ digest_bits: int = ...,
+ custom: bytes = ...) -> TupleHash: ...
diff --git a/lib/Crypto/Hash/_BLAKE2b.abi3.so b/lib/Crypto/Hash/_BLAKE2b.abi3.so
new file mode 100755
index 0000000..d7bc5bf
--- /dev/null
+++ b/lib/Crypto/Hash/_BLAKE2b.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_BLAKE2s.abi3.so b/lib/Crypto/Hash/_BLAKE2s.abi3.so
new file mode 100755
index 0000000..1e7ac6c
--- /dev/null
+++ b/lib/Crypto/Hash/_BLAKE2s.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_MD2.abi3.so b/lib/Crypto/Hash/_MD2.abi3.so
new file mode 100755
index 0000000..7671f60
--- /dev/null
+++ b/lib/Crypto/Hash/_MD2.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_MD4.abi3.so b/lib/Crypto/Hash/_MD4.abi3.so
new file mode 100755
index 0000000..ddd85d1
--- /dev/null
+++ b/lib/Crypto/Hash/_MD4.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_MD5.abi3.so b/lib/Crypto/Hash/_MD5.abi3.so
new file mode 100755
index 0000000..a8b6683
--- /dev/null
+++ b/lib/Crypto/Hash/_MD5.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_RIPEMD160.abi3.so b/lib/Crypto/Hash/_RIPEMD160.abi3.so
new file mode 100755
index 0000000..d37c515
--- /dev/null
+++ b/lib/Crypto/Hash/_RIPEMD160.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_SHA1.abi3.so b/lib/Crypto/Hash/_SHA1.abi3.so
new file mode 100755
index 0000000..7080984
--- /dev/null
+++ b/lib/Crypto/Hash/_SHA1.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_SHA224.abi3.so b/lib/Crypto/Hash/_SHA224.abi3.so
new file mode 100755
index 0000000..5ceeaa5
--- /dev/null
+++ b/lib/Crypto/Hash/_SHA224.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_SHA256.abi3.so b/lib/Crypto/Hash/_SHA256.abi3.so
new file mode 100755
index 0000000..0b4f41b
--- /dev/null
+++ b/lib/Crypto/Hash/_SHA256.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_SHA384.abi3.so b/lib/Crypto/Hash/_SHA384.abi3.so
new file mode 100755
index 0000000..8f87972
--- /dev/null
+++ b/lib/Crypto/Hash/_SHA384.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_SHA512.abi3.so b/lib/Crypto/Hash/_SHA512.abi3.so
new file mode 100755
index 0000000..804ee9e
--- /dev/null
+++ b/lib/Crypto/Hash/_SHA512.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/__init__.py b/lib/Crypto/Hash/__init__.py
new file mode 100644
index 0000000..4bda084
--- /dev/null
+++ b/lib/Crypto/Hash/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+__all__ = ['HMAC', 'MD2', 'MD4', 'MD5', 'RIPEMD160', 'SHA1',
+ 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'CMAC', 'Poly1305',
+ 'cSHAKE128', 'cSHAKE256', 'KMAC128', 'KMAC256',
+ 'TupleHash128', 'TupleHash256', 'KangarooTwelve']
diff --git a/lib/Crypto/Hash/__init__.pyi b/lib/Crypto/Hash/__init__.pyi
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/Crypto/Hash/__init__.pyi
diff --git a/lib/Crypto/Hash/_ghash_clmul.abi3.so b/lib/Crypto/Hash/_ghash_clmul.abi3.so
new file mode 100755
index 0000000..77163c1
--- /dev/null
+++ b/lib/Crypto/Hash/_ghash_clmul.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_ghash_portable.abi3.so b/lib/Crypto/Hash/_ghash_portable.abi3.so
new file mode 100755
index 0000000..702cd16
--- /dev/null
+++ b/lib/Crypto/Hash/_ghash_portable.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_keccak.abi3.so b/lib/Crypto/Hash/_keccak.abi3.so
new file mode 100755
index 0000000..aaa33d7
--- /dev/null
+++ b/lib/Crypto/Hash/_keccak.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/_poly1305.abi3.so b/lib/Crypto/Hash/_poly1305.abi3.so
new file mode 100755
index 0000000..a795027
--- /dev/null
+++ b/lib/Crypto/Hash/_poly1305.abi3.so
Binary files differ
diff --git a/lib/Crypto/Hash/cSHAKE128.py b/lib/Crypto/Hash/cSHAKE128.py
new file mode 100644
index 0000000..92a4e5c
--- /dev/null
+++ b/lib/Crypto/Hash/cSHAKE128.py
@@ -0,0 +1,187 @@
+# ===================================================================
+#
+# Copyright (c) 2021, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bchr
+
+from Crypto.Util._raw_api import (VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+from Crypto.Util.number import long_to_bytes
+
+from Crypto.Hash.keccak import _raw_keccak_lib
+
+
+def _left_encode(x):
+ """Left encode function as defined in NIST SP 800-185"""
+
+ assert (x < (1 << 2040) and x >= 0)
+
+ # Get number of bytes needed to represent this integer.
+ num = 1 if x == 0 else (x.bit_length() + 7) // 8
+
+ return bchr(num) + long_to_bytes(x)
+
+
+def _right_encode(x):
+ """Right encode function as defined in NIST SP 800-185"""
+
+ assert (x < (1 << 2040) and x >= 0)
+
+ # Get number of bytes needed to represent this integer.
+ num = 1 if x == 0 else (x.bit_length() + 7) // 8
+
+ return long_to_bytes(x) + bchr(num)
+
+
+def _encode_str(x):
+ """Encode string function as defined in NIST SP 800-185"""
+
+ bitlen = len(x) * 8
+ if bitlen >= (1 << 2040):
+ raise ValueError("String too large to encode in cSHAKE")
+
+ return _left_encode(bitlen) + x
+
+
+def _bytepad(x, length):
+ """Zero pad byte string as defined in NIST SP 800-185"""
+
+ to_pad = _left_encode(length) + x
+
+ # Note: this implementation works with byte aligned strings,
+ # hence no additional bit padding is needed at this point.
+ npad = (length - len(to_pad) % length) % length
+
+ return to_pad + b'\x00' * npad
+
+
+class cSHAKE_XOF(object):
+ """A cSHAKE hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+ """
+
+ def __init__(self, data, custom, capacity, function):
+ state = VoidPointer()
+
+ if custom or function:
+ prefix_unpad = _encode_str(function) + _encode_str(custom)
+ prefix = _bytepad(prefix_unpad, (1600 - capacity)//8)
+ self._padding = 0x04
+ else:
+ prefix = None
+ self._padding = 0x1F # for SHAKE
+
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(capacity//8),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating cSHAKE"
+ % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ self._is_squeezing = False
+
+ if prefix:
+ self.update(prefix)
+
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._is_squeezing:
+ raise TypeError("You cannot call 'update' after the first 'read'")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while updating cSHAKE state"
+ % result)
+ return self
+
+ def read(self, length):
+ """
+ Compute the next piece of XOF output.
+
+ .. note::
+ You cannot use :meth:`update` anymore after the first call to
+ :meth:`read`.
+
+ Args:
+ length (integer): the number of bytes this method must return
+
+ :return: the next piece of XOF output (of the given length)
+ :rtype: byte string
+ """
+
+ self._is_squeezing = True
+ bfr = create_string_buffer(length)
+ result = _raw_keccak_lib.keccak_squeeze(self._state.get(),
+ bfr,
+ c_size_t(length),
+ c_ubyte(self._padding))
+ if result:
+ raise ValueError("Error %d while extracting from cSHAKE"
+ % result)
+
+ return get_raw_buffer(bfr)
+
+
+def _new(data, custom, function):
+ # Use Keccak[256]
+ return cSHAKE_XOF(data, custom, 256, function)
+
+
+def new(data=None, custom=None):
+ """Return a fresh instance of a cSHAKE128 object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ Optional.
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ custom (bytes):
+ Optional.
+ A customization bytestring (``S`` in SP 800-185).
+
+ :Return: A :class:`cSHAKE_XOF` object
+ """
+
+ # Use Keccak[256]
+ return cSHAKE_XOF(data, custom, 256, b'')
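A brief sketch of the cSHAKE128 module above: the customization string acts
as a domain separator, so the same input yields unrelated output streams
(arbitrary labels; with an empty customization string the object behaves
like plain SHAKE128):

    from Crypto.Hash import cSHAKE128

    a = cSHAKE128.new(data=b"input", custom=b"application A").read(32)
    b = cSHAKE128.new(data=b"input", custom=b"application B").read(32)
    assert a != b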
diff --git a/lib/Crypto/Hash/cSHAKE128.pyi b/lib/Crypto/Hash/cSHAKE128.pyi
new file mode 100644
index 0000000..1452fea
--- /dev/null
+++ b/lib/Crypto/Hash/cSHAKE128.pyi
@@ -0,0 +1,14 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class cSHAKE_XOF(object):
+ def __init__(self,
+ data: Optional[Buffer] = ...,
+ function: Optional[bytes] = ...,
+ custom: Optional[bytes] = ...) -> None: ...
+ def update(self, data: Buffer) -> cSHAKE_XOF: ...
+ def read(self, length: int) -> bytes: ...
+
+def new(data: Optional[Buffer] = ...,
+ custom: Optional[Buffer] = ...) -> cSHAKE_XOF: ...
diff --git a/lib/Crypto/Hash/cSHAKE256.py b/lib/Crypto/Hash/cSHAKE256.py
new file mode 100644
index 0000000..b3b31d6
--- /dev/null
+++ b/lib/Crypto/Hash/cSHAKE256.py
@@ -0,0 +1,56 @@
+# ===================================================================
+#
+# Copyright (c) 2021, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util._raw_api import c_size_t
+from Crypto.Hash.cSHAKE128 import cSHAKE_XOF
+
+
+def _new(data, custom, function):
+ # Use Keccak[512]
+ return cSHAKE_XOF(data, custom, 512, function)
+
+
+def new(data=None, custom=None):
+ """Return a fresh instance of a cSHAKE256 object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`update`.
+ Optional.
+ custom (bytes):
+ Optional.
+ A customization bytestring (``S`` in SP 800-185).
+
+ :Return: A :class:`cSHAKE_XOF` object
+ """
+
+ # Use Keccak[512]
+ return cSHAKE_XOF(data, custom, 512, b'')
diff --git a/lib/Crypto/Hash/cSHAKE256.pyi b/lib/Crypto/Hash/cSHAKE256.pyi
new file mode 100644
index 0000000..205b816
--- /dev/null
+++ b/lib/Crypto/Hash/cSHAKE256.pyi
@@ -0,0 +1,8 @@
+from typing import Union, Optional
+
+from Crypto.Hash.cSHAKE128 import cSHAKE_XOF
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def new(data: Optional[Buffer] = ...,
+ custom: Optional[Buffer] = ...) -> cSHAKE_XOF: ...
diff --git a/lib/Crypto/Hash/keccak.py b/lib/Crypto/Hash/keccak.py
new file mode 100644
index 0000000..f3f8bb5
--- /dev/null
+++ b/lib/Crypto/Hash/keccak.py
@@ -0,0 +1,181 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bord
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ VoidPointer, SmartPointer,
+ create_string_buffer,
+ get_raw_buffer, c_size_t,
+ c_uint8_ptr, c_ubyte)
+
+_raw_keccak_lib = load_pycryptodome_raw_lib("Crypto.Hash._keccak",
+ """
+ int keccak_init(void **state,
+ size_t capacity_bytes,
+ uint8_t rounds);
+ int keccak_destroy(void *state);
+ int keccak_absorb(void *state,
+ const uint8_t *in,
+ size_t len);
+ int keccak_squeeze(const void *state,
+ uint8_t *out,
+ size_t len,
+ uint8_t padding);
+ int keccak_digest(void *state,
+ uint8_t *digest,
+ size_t len,
+ uint8_t padding);
+ int keccak_copy(const void *src, void *dst);
+ int keccak_reset(void *state);
+ """)
+
+class Keccak_Hash(object):
+ """A Keccak hash object.
+ Do not instantiate directly.
+ Use the :func:`new` function.
+
+ :ivar digest_size: the size in bytes of the resulting hash
+ :vartype digest_size: integer
+ """
+
+ def __init__(self, data, digest_bytes, update_after_digest):
+ # The size of the resulting hash in bytes.
+ self.digest_size = digest_bytes
+
+ self._update_after_digest = update_after_digest
+ self._digest_done = False
+ self._padding = 0x01
+
+ state = VoidPointer()
+ result = _raw_keccak_lib.keccak_init(state.address_of(),
+ c_size_t(self.digest_size * 2),
+ c_ubyte(24))
+ if result:
+ raise ValueError("Error %d while instantiating keccak" % result)
+ self._state = SmartPointer(state.get(),
+ _raw_keccak_lib.keccak_destroy)
+ if data:
+ self.update(data)
+
+ def update(self, data):
+ """Continue hashing of a message by consuming the next chunk of data.
+
+ Args:
+ data (byte string/byte array/memoryview): The next chunk of the message being hashed.
+ """
+
+ if self._digest_done and not self._update_after_digest:
+ raise TypeError("You can only call 'digest' or 'hexdigest' on this object")
+
+ result = _raw_keccak_lib.keccak_absorb(self._state.get(),
+ c_uint8_ptr(data),
+ c_size_t(len(data)))
+ if result:
+ raise ValueError("Error %d while updating keccak" % result)
+ return self
+
+ def digest(self):
+ """Return the **binary** (non-printable) digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Binary form.
+ :rtype: byte string
+ """
+
+ self._digest_done = True
+ bfr = create_string_buffer(self.digest_size)
+ result = _raw_keccak_lib.keccak_digest(self._state.get(),
+ bfr,
+ c_size_t(self.digest_size),
+ c_ubyte(self._padding))
+ if result:
+ raise ValueError("Error %d while squeezing keccak" % result)
+
+ return get_raw_buffer(bfr)
+
+ def hexdigest(self):
+ """Return the **printable** digest of the message that has been hashed so far.
+
+ :return: The hash digest, computed over the data processed so far.
+ Hexadecimal encoded.
+ :rtype: string
+ """
+
+ return "".join(["%02x" % bord(x) for x in self.digest()])
+
+ def new(self, **kwargs):
+ """Create a fresh Keccak hash object."""
+
+ if "digest_bytes" not in kwargs and "digest_bits" not in kwargs:
+ kwargs["digest_bytes"] = self.digest_size
+
+ return new(**kwargs)
+
+
+def new(**kwargs):
+ """Create a new hash object.
+
+ Args:
+ data (bytes/bytearray/memoryview):
+ The very first chunk of the message to hash.
+ It is equivalent to an early call to :meth:`Keccak_Hash.update`.
+ digest_bytes (integer):
+ The size of the digest, in bytes (28, 32, 48, 64).
+ digest_bits (integer):
+ The size of the digest, in bits (224, 256, 384, 512).
+ update_after_digest (boolean):
+ Whether :meth:`Keccak.digest` can be followed by another
+ :meth:`Keccak.update` (default: ``False``).
+
+ :Return: A :class:`Keccak_Hash` hash object
+ """
+
+ data = kwargs.pop("data", None)
+ update_after_digest = kwargs.pop("update_after_digest", False)
+
+ digest_bytes = kwargs.pop("digest_bytes", None)
+ digest_bits = kwargs.pop("digest_bits", None)
+ if None not in (digest_bytes, digest_bits):
+ raise TypeError("Only one digest parameter may be provided")
+ if (None, None) == (digest_bytes, digest_bits):
+ raise TypeError("Digest size (bits, bytes) not provided")
+ if digest_bytes is not None:
+ if digest_bytes not in (28, 32, 48, 64):
+ raise ValueError("'digest_bytes' must be: 28, 32, 48 or 64")
+ else:
+ if digest_bits not in (224, 256, 384, 512):
+ raise ValueError("'digest_bits' must be: 224, 256, 384 or 512")
+ digest_bytes = digest_bits // 8
+
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ return Keccak_Hash(data, digest_bytes, update_after_digest)
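A minimal sketch of the keccak module above (original Keccak padding, not the
FIPS 202 SHA-3 padding); the digest size must always be given explicitly:

    from Crypto.Hash import keccak

    k = keccak.new(digest_bits=256)
    k.update(b"some data")
    print(k.hexdigest())    # 64 hex characters

    # Equivalent one-shot form:
    # keccak.new(data=b"some data", digest_bytes=32).hexdigest()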
diff --git a/lib/Crypto/Hash/keccak.pyi b/lib/Crypto/Hash/keccak.pyi
new file mode 100644
index 0000000..844d256
--- /dev/null
+++ b/lib/Crypto/Hash/keccak.pyi
@@ -0,0 +1,23 @@
+from typing import Union, Any
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+class Keccak_Hash(object):
+ digest_size: int
+ def __init__(self,
+ data: Buffer,
+ digest_bytes: int,
+ update_after_digest: bool) -> None: ...
+ def update(self, data: Buffer) -> Keccak_Hash: ...
+ def digest(self) -> bytes: ...
+ def hexdigest(self) -> str: ...
+ def new(self,
+ data: Buffer = ...,
+ digest_bytes: int = ...,
+ digest_bits: int = ...,
+ update_after_digest: bool = ...) -> Keccak_Hash: ...
+
+def new(data: Buffer = ...,
+ digest_bytes: int = ...,
+ digest_bits: int = ...,
+ update_after_digest: bool = ...) -> Keccak_Hash: ...
diff --git a/lib/Crypto/IO/PEM.py b/lib/Crypto/IO/PEM.py
new file mode 100644
index 0000000..4c07b25
--- /dev/null
+++ b/lib/Crypto/IO/PEM.py
@@ -0,0 +1,189 @@
+#
+# Util/PEM.py : Privacy Enhanced Mail utilities
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+__all__ = ['encode', 'decode']
+
+import re
+from binascii import a2b_base64, b2a_base64, hexlify, unhexlify
+
+from Crypto.Hash import MD5
+from Crypto.Util.Padding import pad, unpad
+from Crypto.Cipher import DES, DES3, AES
+from Crypto.Protocol.KDF import PBKDF1
+from Crypto.Random import get_random_bytes
+from Crypto.Util.py3compat import tobytes, tostr
+
+
+def encode(data, marker, passphrase=None, randfunc=None):
+ """Encode a piece of binary data into PEM format.
+
+ Args:
+ data (byte string):
+ The piece of binary data to encode.
+ marker (string):
+ The marker for the PEM block (e.g. "PUBLIC KEY").
+ Note that there is no official master list for all allowed markers.
+ Still, you can refer to the OpenSSL_ source code.
+ passphrase (byte string):
+ If given, the PEM block will be encrypted. The key is derived from
+ the passphrase.
+ randfunc (callable):
+ Random number generation function; it accepts an integer N and returns
+ a byte string of random data, N bytes long. If not given, a new one is
+ instantiated.
+
+ Returns:
+ The PEM block, as a string.
+
+ .. _OpenSSL: https://github.com/openssl/openssl/blob/master/include/openssl/pem.h
+ """
+
+ if randfunc is None:
+ randfunc = get_random_bytes
+
+ out = "-----BEGIN %s-----\n" % marker
+ if passphrase:
+ # We only support 3DES for encryption
+ salt = randfunc(8)
+ key = PBKDF1(passphrase, salt, 16, 1, MD5)
+ key += PBKDF1(key + passphrase, salt, 8, 1, MD5)
+ objenc = DES3.new(key, DES3.MODE_CBC, salt)
+ out += "Proc-Type: 4,ENCRYPTED\nDEK-Info: DES-EDE3-CBC,%s\n\n" %\
+ tostr(hexlify(salt).upper())
+ # Encrypt with PKCS#7 padding
+ data = objenc.encrypt(pad(data, objenc.block_size))
+ elif passphrase is not None:
+ raise ValueError("Empty password")
+
+ # Each BASE64 line can take up to 64 characters (=48 bytes of data)
+ # b2a_base64 adds a new line character!
+ chunks = [tostr(b2a_base64(data[i:i + 48]))
+ for i in range(0, len(data), 48)]
+ out += "".join(chunks)
+ out += "-----END %s-----" % marker
+ return out
+
+
+def _EVP_BytesToKey(data, salt, key_len):
+ d = [ b'' ]
+ m = (key_len + 15 ) // 16
+ for _ in range(m):
+ nd = MD5.new(d[-1] + data + salt).digest()
+ d.append(nd)
+ return b"".join(d)[:key_len]
+
+
+def decode(pem_data, passphrase=None):
+ """Decode a PEM block into binary.
+
+ Args:
+ pem_data (string):
+ The PEM block.
+ passphrase (byte string):
+ If given and the PEM block is encrypted,
+ the key will be derived from the passphrase.
+
+ Returns:
+ A tuple with the binary data, the marker string, and a boolean to
+ indicate if decryption was performed.
+
+ Raises:
+ ValueError: if decoding fails, if the PEM file is encrypted and no passphrase has
+ been provided or if the passphrase is incorrect.
+ """
+
+ # Verify Pre-Encapsulation Boundary
+ r = re.compile(r"\s*-----BEGIN (.*)-----\s+")
+ m = r.match(pem_data)
+ if not m:
+ raise ValueError("Not a valid PEM pre boundary")
+ marker = m.group(1)
+
+ # Verify Post-Encapsulation Boundary
+ r = re.compile(r"-----END (.*)-----\s*$")
+ m = r.search(pem_data)
+ if not m or m.group(1) != marker:
+ raise ValueError("Not a valid PEM post boundary")
+
+ # Removes spaces and splits into lines
+ lines = pem_data.replace(" ", '').split()
+
+ # Decrypts, if necessary
+ if lines[1].startswith('Proc-Type:4,ENCRYPTED'):
+ if not passphrase:
+ raise ValueError("PEM is encrypted, but no passphrase available")
+ DEK = lines[2].split(':')
+ if len(DEK) != 2 or DEK[0] != 'DEK-Info':
+ raise ValueError("PEM encryption format not supported.")
+ algo, salt = DEK[1].split(',')
+ salt = unhexlify(tobytes(salt))
+
+ padding = True
+
+ if algo == "DES-CBC":
+ key = _EVP_BytesToKey(passphrase, salt, 8)
+ objdec = DES.new(key, DES.MODE_CBC, salt)
+ elif algo == "DES-EDE3-CBC":
+ key = _EVP_BytesToKey(passphrase, salt, 24)
+ objdec = DES3.new(key, DES3.MODE_CBC, salt)
+ elif algo == "AES-128-CBC":
+ key = _EVP_BytesToKey(passphrase, salt[:8], 16)
+ objdec = AES.new(key, AES.MODE_CBC, salt)
+ elif algo == "AES-192-CBC":
+ key = _EVP_BytesToKey(passphrase, salt[:8], 24)
+ objdec = AES.new(key, AES.MODE_CBC, salt)
+ elif algo == "AES-256-CBC":
+ key = _EVP_BytesToKey(passphrase, salt[:8], 32)
+ objdec = AES.new(key, AES.MODE_CBC, salt)
+ elif algo.lower() == "id-aes256-gcm":
+ key = _EVP_BytesToKey(passphrase, salt[:8], 32)
+ objdec = AES.new(key, AES.MODE_GCM, nonce=salt)
+ padding = False
+ else:
+ raise ValueError("Unsupported PEM encryption algorithm (%s)." % algo)
+ lines = lines[2:]
+ else:
+ objdec = None
+
+ # Decode body
+ data = a2b_base64(''.join(lines[1:-1]))
+ enc_flag = False
+ if objdec:
+ if padding:
+ data = unpad(objdec.decrypt(data), objdec.block_size)
+ else:
+ # There is no tag, so we don't use decrypt_and_verify
+ data = objdec.decrypt(data)
+ enc_flag = True
+
+ return (data, marker, enc_flag)
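A short round-trip sketch for the PEM module above (the marker and payload
bytes are arbitrary examples):

    from Crypto.IO import PEM

    blob = b"\x30\x0d\x06\x09"   # arbitrary binary payload
    pem = PEM.encode(blob, "EXAMPLE DATA", passphrase=b"secret")
    data, marker, decrypted = PEM.decode(pem, passphrase=b"secret")
    assert data == blob and marker == "EXAMPLE DATA" and decrypted is True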
diff --git a/lib/Crypto/IO/PEM.pyi b/lib/Crypto/IO/PEM.pyi
new file mode 100644
index 0000000..2e324c4
--- /dev/null
+++ b/lib/Crypto/IO/PEM.pyi
@@ -0,0 +1,10 @@
+from typing import Tuple, Optional, Callable
+
+def encode(data: bytes,
+ marker: str,
+ passphrase: Optional[bytes] = ...,
+ randfunc: Optional[Callable[[int],bytes]] = ...) -> str: ...
+
+
+def decode(pem_data: str,
+ passphrase: Optional[bytes] = ...) -> Tuple[bytes, str, bool]: ...
diff --git a/lib/Crypto/IO/PKCS8.py b/lib/Crypto/IO/PKCS8.py
new file mode 100644
index 0000000..18dffae
--- /dev/null
+++ b/lib/Crypto/IO/PKCS8.py
@@ -0,0 +1,239 @@
+#
+# PublicKey/PKCS8.py : PKCS#8 functions
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+
+from Crypto.Util.py3compat import *
+
+from Crypto.Util.asn1 import (
+ DerNull,
+ DerSequence,
+ DerObjectId,
+ DerOctetString,
+ )
+
+from Crypto.IO._PBES import PBES1, PBES2, PbesError
+
+
+__all__ = ['wrap', 'unwrap']
+
+
+def wrap(private_key, key_oid, passphrase=None, protection=None,
+ prot_params=None, key_params=DerNull(), randfunc=None):
+ """Wrap a private key into a PKCS#8 blob (clear or encrypted).
+
+ Args:
+
+ private_key (byte string):
+ The private key encoded in binary form. The actual encoding is
+ algorithm specific. In most cases, it is DER.
+
+ key_oid (string):
+ The object identifier (OID) of the private key to wrap.
+ It is a dotted string, like ``1.2.840.113549.1.1.1`` (for RSA keys).
+
+ passphrase (byte string or string):
+ The secret passphrase from which the wrapping key is derived.
+ Set it only if encryption is required.
+
+ protection (string):
+ The identifier of the algorithm to use for securely wrapping the key.
+ The default value is ``PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC``.
+
+ prot_params (dictionary):
+ Parameters for the protection algorithm.
+
+ +------------------+-----------------------------------------------+
+ | Key | Description |
+ +==================+===============================================+
+ | iteration_count | The KDF algorithm is repeated several times to|
+ | | slow down brute force attacks on passwords |
+ | | (called *N* or CPU/memory cost in scrypt). |
+ | | The default value for PBKDF2 is 1000. |
+ | | The default value for scrypt is 16384. |
+ +------------------+-----------------------------------------------+
+ | salt_size | Salt is used to thwart dictionary and rainbow |
+ | | attacks on passwords. The default value is 8 |
+ | | bytes. |
+ +------------------+-----------------------------------------------+
+ | block_size | *(scrypt only)* Memory-cost (r). The default |
+ | | value is 8. |
+ +------------------+-----------------------------------------------+
+ | parallelization | *(scrypt only)* CPU-cost (p). The default |
+ | | value is 1. |
+ +------------------+-----------------------------------------------+
+
+ key_params (DER object or None):
+ The ``parameters`` field to use in the ``AlgorithmIdentifier``
+ SEQUENCE. If ``None``, no ``parameters`` field will be added.
+ By default, the ASN.1 type ``NULL`` is used.
+
+ randfunc (callable):
+ Random number generation function; it should accept a single integer
+ N and return a string of random data, N bytes long.
+ If not specified, a new RNG will be instantiated
+ from :mod:`Crypto.Random`.
+
+ Return:
+ The PKCS#8-wrapped private key (possibly encrypted), as a byte string.
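+
+ Example (an illustrative sketch only; ``rsa_der`` stands for a hypothetical
+ DER-encoded RSA private key obtained elsewhere)::
+
+ from Crypto.IO import PKCS8
+ wrapped = PKCS8.wrap(rsa_der, '1.2.840.113549.1.1.1',
+ passphrase=b'secret', protection='scryptAndAES128-CBC')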
+ """
+
+ #
+ # PrivateKeyInfo ::= SEQUENCE {
+ # version Version,
+ # privateKeyAlgorithm PrivateKeyAlgorithmIdentifier,
+ # privateKey PrivateKey,
+ # attributes [0] IMPLICIT Attributes OPTIONAL
+ # }
+ #
+ if key_params is None:
+ algorithm = DerSequence([DerObjectId(key_oid)])
+ else:
+ algorithm = DerSequence([DerObjectId(key_oid), key_params])
+
+ pk_info = DerSequence([
+ 0,
+ algorithm,
+ DerOctetString(private_key)
+ ])
+ pk_info_der = pk_info.encode()
+
+ if passphrase is None:
+ return pk_info_der
+
+ if not passphrase:
+ raise ValueError("Empty passphrase")
+
+ # Encryption with PBES2
+ passphrase = tobytes(passphrase)
+ if protection is None:
+ protection = 'PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC'
+ return PBES2.encrypt(pk_info_der, passphrase,
+ protection, prot_params, randfunc)
+
+
+def unwrap(p8_private_key, passphrase=None):
+ """Unwrap a private key from a PKCS#8 blob (clear or encrypted).
+
+ Args:
+ p8_private_key (byte string):
+ The private key wrapped into a PKCS#8 blob, DER encoded.
+ passphrase (byte string or string):
+ The passphrase to use to decrypt the blob (if it is encrypted).
+
+ Return:
+ A tuple containing
+
+ #. the algorithm identifier of the wrapped key (OID, dotted string)
+ #. the private key (byte string, DER encoded)
+ #. the associated parameters (byte string, DER encoded) or ``None``
+
+ Raises:
+ ValueError : if decoding fails
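+
+ Example (sketch, reusing the hypothetical ``wrapped`` blob produced by ``wrap``)::
+
+ algo_oid, key_der, params = PKCS8.unwrap(wrapped, b'secret')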
+ """
+
+ if passphrase:
+ passphrase = tobytes(passphrase)
+
+ found = False
+ try:
+ p8_private_key = PBES1.decrypt(p8_private_key, passphrase)
+ found = True
+ except PbesError as e:
+ error_str = "PBES1[%s]" % str(e)
+ except ValueError:
+ error_str = "PBES1[Invalid]"
+
+ if not found:
+ try:
+ p8_private_key = PBES2.decrypt(p8_private_key, passphrase)
+ found = True
+ except PbesError as e:
+ error_str += ",PBES2[%s]" % str(e)
+ except ValueError:
+ error_str += ",PBES2[Invalid]"
+
+ if not found:
+ raise ValueError("Error decoding PKCS#8 (%s)" % error_str)
+
+ pk_info = DerSequence().decode(p8_private_key, nr_elements=(2, 3, 4, 5))
+ if len(pk_info) == 2 and not passphrase:
+ raise ValueError("Not a valid clear PKCS#8 structure "
+ "(maybe it is encrypted?)")
+
+ # RFC5208, PKCS#8, version is v1(0)
+ #
+ # PrivateKeyInfo ::= SEQUENCE {
+ # version Version,
+ # privateKeyAlgorithm PrivateKeyAlgorithmIdentifier,
+ # privateKey PrivateKey,
+ # attributes [0] IMPLICIT Attributes OPTIONAL
+ # }
+ #
+ # RFC5915, Asymmetric Key Package, version is v2(1)
+ #
+ # OneAsymmetricKey ::= SEQUENCE {
+ # version Version,
+ # privateKeyAlgorithm PrivateKeyAlgorithmIdentifier,
+ # privateKey PrivateKey,
+ # attributes [0] Attributes OPTIONAL,
+ # ...,
+ # [[2: publicKey [1] PublicKey OPTIONAL ]],
+ # ...
+ # }
+
+ if pk_info[0] == 0:
+ if len(pk_info) not in (3, 4):
+ raise ValueError("Not a valid PrivateKeyInfo SEQUENCE")
+ elif pk_info[0] == 1:
+ if len(pk_info) not in (3, 4, 5):
+ raise ValueError("Not a valid PrivateKeyInfo SEQUENCE")
+ else:
+ raise ValueError("Not a valid PrivateKeyInfo SEQUENCE")
+
+ algo = DerSequence().decode(pk_info[1], nr_elements=(1, 2))
+ algo_oid = DerObjectId().decode(algo[0]).value
+ if len(algo) == 1:
+ algo_params = None
+ else:
+ try:
+ DerNull().decode(algo[1])
+ algo_params = None
+ except:
+ algo_params = algo[1]
+
+ # PrivateKey ::= OCTET STRING
+ private_key = DerOctetString().decode(pk_info[2]).payload
+
+ # We ignore attributes and (for v2 only) publickey
+
+ return (algo_oid, private_key, algo_params)
diff --git a/lib/Crypto/IO/PKCS8.pyi b/lib/Crypto/IO/PKCS8.pyi
new file mode 100644
index 0000000..2fed1b7
--- /dev/null
+++ b/lib/Crypto/IO/PKCS8.pyi
@@ -0,0 +1,14 @@
+from typing import Dict, Tuple, Optional, Union, Callable
+
+from Crypto.Util.asn1 import DerObject
+
+def wrap(private_key: bytes,
+ key_oid: str,
+ passphrase: Union[bytes, str] = ...,
+ protection: str = ...,
+ prot_params: Dict = ...,
+ key_params: Optional[DerObject] = ...,
+ randfunc: Optional[Callable[[int],str]] = ...) -> bytes: ...
+
+
+def unwrap(p8_private_key: bytes, passphrase: Optional[Union[bytes, str]] = ...) -> Tuple[str, bytes, Optional[bytes]]: ...
diff --git a/lib/Crypto/IO/_PBES.py b/lib/Crypto/IO/_PBES.py
new file mode 100644
index 0000000..a47c775
--- /dev/null
+++ b/lib/Crypto/IO/_PBES.py
@@ -0,0 +1,435 @@
+#
+# PublicKey/_PBES.py : Password-Based Encryption functions
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto import Random
+from Crypto.Util.asn1 import (
+ DerSequence, DerOctetString,
+ DerObjectId, DerInteger,
+ )
+
+from Crypto.Util.Padding import pad, unpad
+from Crypto.Hash import MD5, SHA1, SHA224, SHA256, SHA384, SHA512
+from Crypto.Cipher import DES, ARC2, DES3, AES
+from Crypto.Protocol.KDF import PBKDF1, PBKDF2, scrypt
+
+_OID_PBE_WITH_MD5_AND_DES_CBC = "1.2.840.113549.1.5.3"
+_OID_PBE_WITH_MD5_AND_RC2_CBC = "1.2.840.113549.1.5.6"
+_OID_PBE_WITH_SHA1_AND_DES_CBC = "1.2.840.113549.1.5.10"
+_OID_PBE_WITH_SHA1_AND_RC2_CBC = "1.2.840.113549.1.5.11"
+
+_OID_PBES2 = "1.2.840.113549.1.5.13"
+
+_OID_PBKDF2 = "1.2.840.113549.1.5.12"
+_OID_SCRYPT = "1.3.6.1.4.1.11591.4.11"
+
+_OID_HMAC_SHA1 = "1.2.840.113549.2.7"
+_OID_HMAC_SHA224 = "1.2.840.113549.2.8"
+_OID_HMAC_SHA256 = "1.2.840.113549.2.9"
+_OID_HMAC_SHA384 = "1.2.840.113549.2.10"
+_OID_HMAC_SHA512 = "1.2.840.113549.2.11"
+
+_OID_DES_EDE3_CBC = "1.2.840.113549.3.7"
+_OID_AES128_CBC = "2.16.840.1.101.3.4.1.2"
+_OID_AES192_CBC = "2.16.840.1.101.3.4.1.22"
+_OID_AES256_CBC = "2.16.840.1.101.3.4.1.42"
+
+
+class PbesError(ValueError):
+ pass
+
+# These are the ASN.1 definitions used by the PBES1/2 logic:
+#
+# EncryptedPrivateKeyInfo ::= SEQUENCE {
+# encryptionAlgorithm EncryptionAlgorithmIdentifier,
+# encryptedData EncryptedData
+# }
+#
+# EncryptionAlgorithmIdentifier ::= AlgorithmIdentifier
+#
+# EncryptedData ::= OCTET STRING
+#
+# AlgorithmIdentifier ::= SEQUENCE {
+# algorithm OBJECT IDENTIFIER,
+# parameters ANY DEFINED BY algorithm OPTIONAL
+# }
+#
+# PBEParameter ::= SEQUENCE {
+# salt OCTET STRING (SIZE(8)),
+# iterationCount INTEGER
+# }
+#
+# PBES2-params ::= SEQUENCE {
+# keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}},
+# encryptionScheme AlgorithmIdentifier {{PBES2-Encs}}
+# }
+#
+# PBKDF2-params ::= SEQUENCE {
+# salt CHOICE {
+# specified OCTET STRING,
+# otherSource AlgorithmIdentifier {{PBKDF2-SaltSources}}
+# },
+# iterationCount INTEGER (1..MAX),
+# keyLength INTEGER (1..MAX) OPTIONAL,
+# prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1
+# }
+#
+# scrypt-params ::= SEQUENCE {
+# salt OCTET STRING,
+# costParameter INTEGER (1..MAX),
+# blockSize INTEGER (1..MAX),
+# parallelizationParameter INTEGER (1..MAX),
+# keyLength INTEGER (1..MAX) OPTIONAL
+# }
+
+class PBES1(object):
+ """Deprecated encryption scheme with password-based key derivation
+ (originally defined in PKCS#5 v1.5, but still present in `v2.0`__).
+
+ .. __: http://www.ietf.org/rfc/rfc2898.txt
+ """
+
+ @staticmethod
+ def decrypt(data, passphrase):
+ """Decrypt a piece of data using a passphrase and *PBES1*.
+
+ The algorithm to use is automatically detected.
+
+ :Parameters:
+ data : byte string
+ The piece of data to decrypt.
+ passphrase : byte string
+ The passphrase to use for decrypting the data.
+ :Returns:
+ The decrypted data, as a binary string.
+ """
+
+ enc_private_key_info = DerSequence().decode(data)
+ encrypted_algorithm = DerSequence().decode(enc_private_key_info[0])
+ encrypted_data = DerOctetString().decode(enc_private_key_info[1]).payload
+
+ pbe_oid = DerObjectId().decode(encrypted_algorithm[0]).value
+ cipher_params = {}
+ if pbe_oid == _OID_PBE_WITH_MD5_AND_DES_CBC:
+ # PBE_MD5_DES_CBC
+ hashmod = MD5
+ ciphermod = DES
+ elif pbe_oid == _OID_PBE_WITH_MD5_AND_RC2_CBC:
+ # PBE_MD5_RC2_CBC
+ hashmod = MD5
+ ciphermod = ARC2
+ cipher_params['effective_keylen'] = 64
+ elif pbe_oid == _OID_PBE_WITH_SHA1_AND_DES_CBC:
+ # PBE_SHA1_DES_CBC
+ hashmod = SHA1
+ ciphermod = DES
+ elif pbe_oid == _OID_PBE_WITH_SHA1_AND_RC2_CBC:
+ # PBE_SHA1_RC2_CBC
+ hashmod = SHA1
+ ciphermod = ARC2
+ cipher_params['effective_keylen'] = 64
+ else:
+ raise PbesError("Unknown OID for PBES1")
+
+ pbe_params = DerSequence().decode(encrypted_algorithm[1], nr_elements=2)
+ salt = DerOctetString().decode(pbe_params[0]).payload
+ iterations = pbe_params[1]
+
+ key_iv = PBKDF1(passphrase, salt, 16, iterations, hashmod)
+ key, iv = key_iv[:8], key_iv[8:]
+
+ cipher = ciphermod.new(key, ciphermod.MODE_CBC, iv, **cipher_params)
+ pt = cipher.decrypt(encrypted_data)
+ return unpad(pt, cipher.block_size)
+
+
+class PBES2(object):
+ """Encryption scheme with password-based key derivation
+ (defined in `PKCS#5 v2.0`__).
+
+ .. __: http://www.ietf.org/rfc/rfc2898.txt"""
+
+ @staticmethod
+ def encrypt(data, passphrase, protection, prot_params=None, randfunc=None):
+ """Encrypt a piece of data using a passphrase and *PBES2*.
+
+ :Parameters:
+ data : byte string
+ The piece of data to encrypt.
+ passphrase : byte string
+ The passphrase to use for encrypting the data.
+ protection : string
+ The identifier of the encryption algorithm to use.
+ The default value is '``PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC``'.
+ prot_params : dictionary
+ Parameters of the protection algorithm.
+
+ +------------------+-----------------------------------------------+
+ | Key | Description |
+ +==================+===============================================+
+ | iteration_count | The KDF algorithm is repeated several times to|
+ | | slow down brute force attacks on passwords |
+ | | (called *N* or CPU/memory cost in scrypt). |
+ | | |
+ | | The default value for PBKDF2 is 1 000. |
+ | | The default value for scrypt is 16 384. |
+ +------------------+-----------------------------------------------+
+ | salt_size | Salt is used to thwart dictionary and rainbow |
+ | | attacks on passwords. The default value is 8 |
+ | | bytes. |
+ +------------------+-----------------------------------------------+
+ | block_size | *(scrypt only)* Memory-cost (r). The default |
+ | | value is 8. |
+ +------------------+-----------------------------------------------+
+ | parallelization | *(scrypt only)* CPU-cost (p). The default |
+ | | value is 1. |
+ +------------------+-----------------------------------------------+
+
+
+ randfunc : callable
+ Random number generation function; it should accept
+ a single integer N and return a string of random data,
+ N bytes long. If not specified, a new RNG will be
+ instantiated from ``Crypto.Random``.
+
+ :Returns:
+ The encrypted data, as a binary string.
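+
+ Example (sketch; ``secret_der`` is a hypothetical byte string to protect)::
+
+ ct = PBES2.encrypt(secret_der, b'passphrase', 'scryptAndAES256-CBC')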
+ """
+
+ if prot_params is None:
+ prot_params = {}
+
+ if randfunc is None:
+ randfunc = Random.new().read
+
+ if protection == 'PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC':
+ key_size = 24
+ module = DES3
+ cipher_mode = DES3.MODE_CBC
+ enc_oid = _OID_DES_EDE3_CBC
+ elif protection in ('PBKDF2WithHMAC-SHA1AndAES128-CBC',
+ 'scryptAndAES128-CBC'):
+ key_size = 16
+ module = AES
+ cipher_mode = AES.MODE_CBC
+ enc_oid = _OID_AES128_CBC
+ elif protection in ('PBKDF2WithHMAC-SHA1AndAES192-CBC',
+ 'scryptAndAES192-CBC'):
+ key_size = 24
+ module = AES
+ cipher_mode = AES.MODE_CBC
+ enc_oid = _OID_AES192_CBC
+ elif protection in ('PBKDF2WithHMAC-SHA1AndAES256-CBC',
+ 'scryptAndAES256-CBC'):
+ key_size = 32
+ module = AES
+ cipher_mode = AES.MODE_CBC
+ enc_oid = _OID_AES256_CBC
+ else:
+ raise ValueError("Unknown PBES2 mode")
+
+ # Get random data
+ iv = randfunc(module.block_size)
+ salt = randfunc(prot_params.get("salt_size", 8))
+
+ # Derive key from password
+ if protection.startswith('PBKDF2'):
+ count = prot_params.get("iteration_count", 1000)
+ key = PBKDF2(passphrase, salt, key_size, count)
+ kdf_info = DerSequence([
+ DerObjectId(_OID_PBKDF2), # PBKDF2
+ DerSequence([
+ DerOctetString(salt),
+ DerInteger(count)
+ ])
+ ])
+ else:
+ # It must be scrypt
+ count = prot_params.get("iteration_count", 16384)
+ scrypt_r = prot_params.get('block_size', 8)
+ scrypt_p = prot_params.get('parallelization', 1)
+ key = scrypt(passphrase, salt, key_size,
+ count, scrypt_r, scrypt_p)
+ kdf_info = DerSequence([
+ DerObjectId(_OID_SCRYPT), # scrypt
+ DerSequence([
+ DerOctetString(salt),
+ DerInteger(count),
+ DerInteger(scrypt_r),
+ DerInteger(scrypt_p)
+ ])
+ ])
+
+ # Create cipher and use it
+ cipher = module.new(key, cipher_mode, iv)
+ encrypted_data = cipher.encrypt(pad(data, cipher.block_size))
+ enc_info = DerSequence([
+ DerObjectId(enc_oid),
+ DerOctetString(iv)
+ ])
+
+ # Result
+ enc_private_key_info = DerSequence([
+ # encryptionAlgorithm
+ DerSequence([
+ DerObjectId(_OID_PBES2),
+ DerSequence([
+ kdf_info,
+ enc_info
+ ]),
+ ]),
+ DerOctetString(encrypted_data)
+ ])
+ return enc_private_key_info.encode()
+
+ @staticmethod
+ def decrypt(data, passphrase):
+ """Decrypt a piece of data using a passphrase and *PBES2*.
+
+ The algorithm to use is automatically detected.
+
+ :Parameters:
+ data : byte string
+ The piece of data to decrypt.
+ passphrase : byte string
+ The passphrase to use for decrypting the data.
+ :Returns:
+ The decrypted data, as a binary string.
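+
+ Example (sketch, reversing the hypothetical ``ct`` produced by ``encrypt``)::
+
+ pt = PBES2.decrypt(ct, b'passphrase')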
+ """
+
+ enc_private_key_info = DerSequence().decode(data, nr_elements=2)
+ enc_algo = DerSequence().decode(enc_private_key_info[0])
+ encrypted_data = DerOctetString().decode(enc_private_key_info[1]).payload
+
+ pbe_oid = DerObjectId().decode(enc_algo[0]).value
+ if pbe_oid != _OID_PBES2:
+ raise PbesError("Not a PBES2 object")
+
+ pbes2_params = DerSequence().decode(enc_algo[1], nr_elements=2)
+
+ ### Key Derivation Function selection
+ kdf_info = DerSequence().decode(pbes2_params[0], nr_elements=2)
+ kdf_oid = DerObjectId().decode(kdf_info[0]).value
+
+ kdf_key_length = None
+
+ # We only support PBKDF2 or scrypt
+ if kdf_oid == _OID_PBKDF2:
+
+ pbkdf2_params = DerSequence().decode(kdf_info[1], nr_elements=(2, 3, 4))
+ salt = DerOctetString().decode(pbkdf2_params[0]).payload
+ iteration_count = pbkdf2_params[1]
+
+ left = len(pbkdf2_params) - 2
+ idx = 2
+
+ if left > 0:
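+ # The subtraction forces a TypeError when this element is not an
+ # INTEGER (i.e. it is already the prf AlgorithmIdentifier)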
+ try:
+ kdf_key_length = pbkdf2_params[idx] - 0
+ left -= 1
+ idx += 1
+ except TypeError:
+ pass
+
+ # Default is HMAC-SHA1
+ pbkdf2_prf_oid = "1.2.840.113549.2.7"
+ if left > 0:
+ pbkdf2_prf_algo_id = DerSequence().decode(pbkdf2_params[idx])
+ pbkdf2_prf_oid = DerObjectId().decode(pbkdf2_prf_algo_id[0]).value
+
+ elif kdf_oid == _OID_SCRYPT:
+
+ scrypt_params = DerSequence().decode(kdf_info[1], nr_elements=(4, 5))
+ salt = DerOctetString().decode(scrypt_params[0]).payload
+ iteration_count, scrypt_r, scrypt_p = [scrypt_params[x]
+ for x in (1, 2, 3)]
+ if len(scrypt_params) > 4:
+ kdf_key_length = scrypt_params[4]
+ else:
+ kdf_key_length = None
+ else:
+ raise PbesError("Unsupported PBES2 KDF")
+
+ ### Cipher selection
+ enc_info = DerSequence().decode(pbes2_params[1])
+ enc_oid = DerObjectId().decode(enc_info[0]).value
+
+ if enc_oid == _OID_DES_EDE3_CBC:
+ # DES_EDE3_CBC
+ ciphermod = DES3
+ key_size = 24
+ elif enc_oid == _OID_AES128_CBC:
+ # AES128_CBC
+ ciphermod = AES
+ key_size = 16
+ elif enc_oid == _OID_AES192_CBC:
+ # AES192_CBC
+ ciphermod = AES
+ key_size = 24
+ elif enc_oid == _OID_AES256_CBC:
+ # AES256_CBC
+ ciphermod = AES
+ key_size = 32
+ else:
+ raise PbesError("Unsupported PBES2 cipher")
+
+ if kdf_key_length and kdf_key_length != key_size:
+ raise PbesError("Mismatch between PBES2 KDF parameters"
+ " and selected cipher")
+
+ IV = DerOctetString().decode(enc_info[1]).payload
+
+ # Create cipher
+ if kdf_oid == _OID_PBKDF2:
+ if pbkdf2_prf_oid == _OID_HMAC_SHA1:
+ hmac_hash_module = SHA1
+ elif pbkdf2_prf_oid == _OID_HMAC_SHA224:
+ hmac_hash_module = SHA224
+ elif pbkdf2_prf_oid == _OID_HMAC_SHA256:
+ hmac_hash_module = SHA256
+ elif pbkdf2_prf_oid == _OID_HMAC_SHA384:
+ hmac_hash_module = SHA384
+ elif pbkdf2_prf_oid == _OID_HMAC_SHA512:
+ hmac_hash_module = SHA512
+ else:
+ raise PbesError("Unsupported HMAC %s" % pbkdf2_prf_oid)
+
+ key = PBKDF2(passphrase, salt, key_size, iteration_count,
+ hmac_hash_module=hmac_hash_module)
+ else:
+ key = scrypt(passphrase, salt, key_size, iteration_count,
+ scrypt_r, scrypt_p)
+ cipher = ciphermod.new(key, ciphermod.MODE_CBC, IV)
+
+ # Decrypt data
+ pt = cipher.decrypt(encrypted_data)
+ return unpad(pt, cipher.block_size)
diff --git a/lib/Crypto/IO/_PBES.pyi b/lib/Crypto/IO/_PBES.pyi
new file mode 100644
index 0000000..a8a34ce
--- /dev/null
+++ b/lib/Crypto/IO/_PBES.pyi
@@ -0,0 +1,19 @@
+from typing import Dict, Optional, Callable
+
+class PbesError(ValueError):
+ ...
+
+class PBES1(object):
+ @staticmethod
+ def decrypt(data: bytes, passphrase: bytes) -> bytes: ...
+
+class PBES2(object):
+ @staticmethod
+ def encrypt(data: bytes,
+ passphrase: bytes,
+ protection: str,
+ prot_params: Optional[Dict] = ...,
+ randfunc: Optional[Callable[[int],bytes]] = ...) -> bytes: ...
+
+ @staticmethod
+ def decrypt(data:bytes, passphrase: bytes) -> bytes: ...
diff --git a/lib/Crypto/IO/__init__.py b/lib/Crypto/IO/__init__.py
new file mode 100644
index 0000000..85a0d0b
--- /dev/null
+++ b/lib/Crypto/IO/__init__.py
@@ -0,0 +1,31 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+__all__ = ['PEM', 'PKCS8']
diff --git a/lib/Crypto/Math/Numbers.py b/lib/Crypto/Math/Numbers.py
new file mode 100644
index 0000000..c2c4483
--- /dev/null
+++ b/lib/Crypto/Math/Numbers.py
@@ -0,0 +1,42 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+__all__ = ["Integer"]
+
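+# Pick the fastest available backend: GMP first, then the bundled C
+# implementation, then the pure Python fallback.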
+try:
+ from Crypto.Math._IntegerGMP import IntegerGMP as Integer
+ from Crypto.Math._IntegerGMP import implementation as _implementation
+except (ImportError, OSError, AttributeError):
+ try:
+ from Crypto.Math._IntegerCustom import IntegerCustom as Integer
+ from Crypto.Math._IntegerCustom import implementation as _implementation
+ except (ImportError, OSError):
+ from Crypto.Math._IntegerNative import IntegerNative as Integer
+ _implementation = {}
diff --git a/lib/Crypto/Math/Numbers.pyi b/lib/Crypto/Math/Numbers.pyi
new file mode 100644
index 0000000..126268c
--- /dev/null
+++ b/lib/Crypto/Math/Numbers.pyi
@@ -0,0 +1,4 @@
+from Crypto.Math._IntegerBase import IntegerBase
+
+class Integer(IntegerBase):
+ pass
diff --git a/lib/Crypto/Math/Primality.py b/lib/Crypto/Math/Primality.py
new file mode 100644
index 0000000..884c418
--- /dev/null
+++ b/lib/Crypto/Math/Primality.py
@@ -0,0 +1,369 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Functions to create and test prime numbers.
+
+:undocumented: __package__
+"""
+
+from Crypto import Random
+from Crypto.Math.Numbers import Integer
+
+from Crypto.Util.py3compat import iter_range
+
+COMPOSITE = 0
+PROBABLY_PRIME = 1
+
+
+def miller_rabin_test(candidate, iterations, randfunc=None):
+ """Perform a Miller-Rabin primality test on an integer.
+
+ The test is specified in Section C.3.1 of `FIPS PUB 186-4`__.
+
+ :Parameters:
+ candidate : integer
+ The number to test for primality.
+ iterations : integer
+ The maximum number of iterations to perform before
+ declaring a candidate a probable prime.
+ randfunc : callable
+ An RNG function where bases are taken from.
+
+ :Returns:
+ ``Primality.COMPOSITE`` or ``Primality.PROBABLY_PRIME``.
+
+ .. __: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
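+
+ Example (sketch)::
+
+ result = miller_rabin_test(2**127 - 1, 10) # 2**127 - 1 is a Mersenne prime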
+ """
+
+ if not isinstance(candidate, Integer):
+ candidate = Integer(candidate)
+
+ if candidate in (1, 2, 3, 5):
+ return PROBABLY_PRIME
+
+ if candidate.is_even():
+ return COMPOSITE
+
+ one = Integer(1)
+ minus_one = Integer(candidate - 1)
+
+ if randfunc is None:
+ randfunc = Random.new().read
+
+ # Step 1 and 2
+ m = Integer(minus_one)
+ a = 0
+ while m.is_even():
+ m >>= 1
+ a += 1
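+ # Here candidate - 1 == m * 2**a, with m odd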
+
+ # Skip step 3
+
+ # Step 4
+ for i in iter_range(iterations):
+
+ # Step 4.1-2
+ base = 1
+ while base in (one, minus_one):
+ base = Integer.random_range(min_inclusive=2,
+ max_inclusive=candidate - 2,
+ randfunc=randfunc)
+ assert(2 <= base <= candidate - 2)
+
+ # Step 4.3-4.4
+ z = pow(base, m, candidate)
+ if z in (one, minus_one):
+ continue
+
+ # Step 4.5
+ for j in iter_range(1, a):
+ z = pow(z, 2, candidate)
+ if z == minus_one:
+ break
+ if z == one:
+ return COMPOSITE
+ else:
+ return COMPOSITE
+
+ # Step 5
+ return PROBABLY_PRIME
+
+
+def lucas_test(candidate):
+ """Perform a Lucas primality test on an integer.
+
+ The test is specified in Section C.3.3 of `FIPS PUB 186-4`__.
+
+ :Parameters:
+ candidate : integer
+ The number to test for primality.
+
+ :Returns:
+ ``Primality.COMPOSITE`` or ``Primality.PROBABLY_PRIME``.
+
+ .. __: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
+ """
+
+ if not isinstance(candidate, Integer):
+ candidate = Integer(candidate)
+
+ # Step 1
+ if candidate in (1, 2, 3, 5):
+ return PROBABLY_PRIME
+ if candidate.is_even() or candidate.is_perfect_square():
+ return COMPOSITE
+
+ # Step 2
+ def alternate():
+ value = 5
+ while True:
+ yield value
+ if value > 0:
+ value += 2
+ else:
+ value -= 2
+ value = -value
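+ # alternate() yields D = 5, -7, 9, -11, 13, ... until jacobi(D, candidate) == -1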
+
+ for D in alternate():
+ if candidate in (D, -D):
+ continue
+ js = Integer.jacobi_symbol(D, candidate)
+ if js == 0:
+ return COMPOSITE
+ if js == -1:
+ break
+ # Found D. P=1 and Q=(1-D)/4 (note that Q is guaranteed to be an integer)
+
+ # Step 3
+ # This is \delta(n) = n - jacobi(D/n)
+ K = candidate + 1
+ # Step 4
+ r = K.size_in_bits() - 1
+ # Step 5
+ # U_1=1 and V_1=P
+ U_i = Integer(1)
+ V_i = Integer(1)
+ U_temp = Integer(0)
+ V_temp = Integer(0)
+ # Step 6
+ for i in iter_range(r - 1, -1, -1):
+ # Square
+ # U_temp = U_i * V_i % candidate
+ U_temp.set(U_i)
+ U_temp *= V_i
+ U_temp %= candidate
+ # V_temp = (((V_i ** 2 + (U_i ** 2 * D)) * K) >> 1) % candidate
+ V_temp.set(U_i)
+ V_temp *= U_i
+ V_temp *= D
+ V_temp.multiply_accumulate(V_i, V_i)
+ if V_temp.is_odd():
+ V_temp += candidate
+ V_temp >>= 1
+ V_temp %= candidate
+ # Multiply
+ if K.get_bit(i):
+ # U_i = (((U_temp + V_temp) * K) >> 1) % candidate
+ U_i.set(U_temp)
+ U_i += V_temp
+ if U_i.is_odd():
+ U_i += candidate
+ U_i >>= 1
+ U_i %= candidate
+ # V_i = (((V_temp + U_temp * D) * K) >> 1) % candidate
+ V_i.set(V_temp)
+ V_i.multiply_accumulate(U_temp, D)
+ if V_i.is_odd():
+ V_i += candidate
+ V_i >>= 1
+ V_i %= candidate
+ else:
+ U_i.set(U_temp)
+ V_i.set(V_temp)
+ # Step 7
+ if U_i == 0:
+ return PROBABLY_PRIME
+ return COMPOSITE
+
+
+from Crypto.Util.number import sieve_base as _sieve_base_large
+## The optimal number of small primes to use for the sieve
+## is probably dependent on the platform and the candidate size
+_sieve_base = set(_sieve_base_large[:100])
+
+
+def test_probable_prime(candidate, randfunc=None):
+ """Test if a number is prime.
+
+ A number is qualified as prime if it passes a certain
+ number of Miller-Rabin tests (dependent on the size
+ of the number, but such that probability of a false
+ positive is less than 10^-30) and a single Lucas test.
+
+ For instance, a 1024-bit candidate will need to pass
+ 4 Miller-Rabin tests.
+
+ :Parameters:
+ candidate : integer
+ The number to test for primality.
+ randfunc : callable
+ The routine to draw random bytes from to select Miller-Rabin bases.
+ :Returns:
+ ``PROBABLY_PRIME`` if the number is prime with very high probability.
+ ``COMPOSITE`` if the number is a composite.
+ For efficiency reasons, ``COMPOSITE`` is also returned for small primes.
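+
+ Example (sketch)::
+
+ result = test_probable_prime(2**255 - 19) # the Curve25519 prime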
+ """
+
+ if randfunc is None:
+ randfunc = Random.new().read
+
+ if not isinstance(candidate, Integer):
+ candidate = Integer(candidate)
+
+ # First, check trial division by the smallest primes
+ if int(candidate) in _sieve_base:
+ return PROBABLY_PRIME
+ try:
+ # (a bare map() would be lazy on Python 3, so iterate explicitly)
+ for small_prime in _sieve_base:
+ candidate.fail_if_divisible_by(small_prime)
+ except ValueError:
+ return COMPOSITE
+
+ # These are the number of Miller-Rabin iterations s.t. p(k, t) < 1E-30,
+ # with p(k, t) being the probability that a randomly chosen k-bit number
+ # is composite but still survives t MR iterations.
+ mr_ranges = ((220, 30), (280, 20), (390, 15), (512, 10),
+ (620, 7), (740, 6), (890, 5), (1200, 4),
+ (1700, 3), (3700, 2))
+
+ bit_size = candidate.size_in_bits()
+ try:
+ mr_iterations = list(filter(lambda x: bit_size < x[0],
+ mr_ranges))[0][1]
+ except IndexError:
+ mr_iterations = 1
+
+ if miller_rabin_test(candidate, mr_iterations,
+ randfunc=randfunc) == COMPOSITE:
+ return COMPOSITE
+ if lucas_test(candidate) == COMPOSITE:
+ return COMPOSITE
+ return PROBABLY_PRIME
+
+
+def generate_probable_prime(**kwargs):
+ """Generate a random probable prime.
+
+ The prime will not have any specific properties
+ (e.g. it will not be a *strong* prime).
+
+ Random numbers are evaluated for primality until one
+ passes all tests, consisting of a certain number of
+ Miller-Rabin tests with random bases followed by
+ a single Lucas test.
+
+ The number of Miller-Rabin iterations is chosen such that
+ the probability that the output number is a non-prime is
+ less than 1E-30 (roughly 2^{-100}).
+
+ This approach is compliant with `FIPS PUB 186-4`__.
+
+ :Keywords:
+ exact_bits : integer
+ The desired size in bits of the probable prime.
+ It must be at least 160.
+ randfunc : callable
+ An RNG function where candidate primes are taken from.
+ prime_filter : callable
+ A function that takes an Integer as parameter and returns
+ True if the number can be passed to further primality tests,
+ False if it should be immediately discarded.
+
+ :Return:
+ A probable prime in the range 2^exact_bits > p > 2^(exact_bits-1).
+
+ .. __: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
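+
+ Example (sketch)::
+
+ p = generate_probable_prime(exact_bits=1024)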
+ """
+
+ exact_bits = kwargs.pop("exact_bits", None)
+ randfunc = kwargs.pop("randfunc", None)
+ prime_filter = kwargs.pop("prime_filter", lambda x: True)
+ if kwargs:
+ raise ValueError("Unknown parameters: " + kwargs.keys())
+
+ if exact_bits is None:
+ raise ValueError("Missing exact_bits parameter")
+ if exact_bits < 160:
+ raise ValueError("Prime number is not big enough.")
+
+ if randfunc is None:
+ randfunc = Random.new().read
+
+ result = COMPOSITE
+ while result == COMPOSITE:
+ candidate = Integer.random(exact_bits=exact_bits,
+ randfunc=randfunc) | 1
+ if not prime_filter(candidate):
+ continue
+ result = test_probable_prime(candidate, randfunc)
+ return candidate
+
+
+def generate_probable_safe_prime(**kwargs):
+ """Generate a random, probable safe prime.
+
+ Note this operation is much slower than generating a simple prime.
+
+ :Keywords:
+ exact_bits : integer
+ The desired size in bits of the probable safe prime.
+ randfunc : callable
+ An RNG function where candidate primes are taken from.
+
+ :Return:
+ A probable safe prime in the range
+ 2^exact_bits > p > 2^(exact_bits-1).
+ """
+
+ exact_bits = kwargs.pop("exact_bits", None)
+ randfunc = kwargs.pop("randfunc", None)
+ if kwargs:
+ raise ValueError("Unknown parameters: " + kwargs.keys())
+
+ if randfunc is None:
+ randfunc = Random.new().read
+
+ result = COMPOSITE
+ while result == COMPOSITE:
+ q = generate_probable_prime(exact_bits=exact_bits - 1, randfunc=randfunc)
+ candidate = q * 2 + 1
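+ # A safe prime p satisfies p = 2*q + 1 with q itself prime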
+ if candidate.size_in_bits() != exact_bits:
+ continue
+ result = test_probable_prime(candidate, randfunc=randfunc)
+ return candidate
diff --git a/lib/Crypto/Math/Primality.pyi b/lib/Crypto/Math/Primality.pyi
new file mode 100644
index 0000000..7813483
--- /dev/null
+++ b/lib/Crypto/Math/Primality.pyi
@@ -0,0 +1,18 @@
+from typing import Callable, Optional, Union, Set
+
+PrimeResult = int
+
+COMPOSITE: PrimeResult
+PROBABLY_PRIME: PrimeResult
+
+def miller_rabin_test(candidate: int, iterations: int, randfunc: Optional[Callable[[int],bytes]]=None) -> PrimeResult: ...
+def lucas_test(candidate: int) -> PrimeResult: ...
+_sieve_base: Set[int]
+def test_probable_prime(candidate: int, randfunc: Optional[Callable[[int],bytes]]=None) -> PrimeResult: ...
+def generate_probable_prime(*,
+ exact_bits: int = ...,
+ randfunc: Callable[[int],bytes] = ...,
+ prime_filter: Callable[[int],bool] = ...) -> int: ...
+def generate_probable_safe_prime(*,
+ exact_bits: int = ...,
+ randfunc: Callable[[int],bytes] = ...) -> int: ...
diff --git a/lib/Crypto/Math/_IntegerBase.py b/lib/Crypto/Math/_IntegerBase.py
new file mode 100644
index 0000000..ec9cb47
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerBase.py
@@ -0,0 +1,392 @@
+# ===================================================================
+#
+# Copyright (c) 2018, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import abc
+
+from Crypto.Util.py3compat import iter_range, bord, bchr, ABC
+
+from Crypto import Random
+
+
+class IntegerBase(ABC):
+
+ # Conversions
+ @abc.abstractmethod
+ def __int__(self):
+ pass
+
+ @abc.abstractmethod
+ def __str__(self):
+ pass
+
+ @abc.abstractmethod
+ def __repr__(self):
+ pass
+
+ @abc.abstractmethod
+ def to_bytes(self, block_size=0, byteorder='big'):
+ pass
+
+ @staticmethod
+ @abc.abstractmethod
+ def from_bytes(byte_string, byteorder='big'):
+ pass
+
+ # Relations
+ @abc.abstractmethod
+ def __eq__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __ne__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __lt__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __le__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __gt__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __ge__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __nonzero__(self):
+ pass
+ __bool__ = __nonzero__
+
+ @abc.abstractmethod
+ def is_negative(self):
+ pass
+
+ # Arithmetic operations
+ @abc.abstractmethod
+ def __add__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __sub__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __mul__(self, factor):
+ pass
+
+ @abc.abstractmethod
+ def __floordiv__(self, divisor):
+ pass
+
+ @abc.abstractmethod
+ def __mod__(self, divisor):
+ pass
+
+ @abc.abstractmethod
+ def inplace_pow(self, exponent, modulus=None):
+ pass
+
+ @abc.abstractmethod
+ def __pow__(self, exponent, modulus=None):
+ pass
+
+ @abc.abstractmethod
+ def __abs__(self):
+ pass
+
+ @abc.abstractmethod
+ def sqrt(self, modulus=None):
+ pass
+
+ @abc.abstractmethod
+ def __iadd__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __isub__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __imul__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __imod__(self, term):
+ pass
+
+ # Boolean/bit operations
+ @abc.abstractmethod
+ def __and__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __or__(self, term):
+ pass
+
+ @abc.abstractmethod
+ def __rshift__(self, pos):
+ pass
+
+ @abc.abstractmethod
+ def __irshift__(self, pos):
+ pass
+
+ @abc.abstractmethod
+ def __lshift__(self, pos):
+ pass
+
+ @abc.abstractmethod
+ def __ilshift__(self, pos):
+ pass
+
+ @abc.abstractmethod
+ def get_bit(self, n):
+ pass
+
+ # Extra
+ @abc.abstractmethod
+ def is_odd(self):
+ pass
+
+ @abc.abstractmethod
+ def is_even(self):
+ pass
+
+ @abc.abstractmethod
+ def size_in_bits(self):
+ pass
+
+ @abc.abstractmethod
+ def size_in_bytes(self):
+ pass
+
+ @abc.abstractmethod
+ def is_perfect_square(self):
+ pass
+
+ @abc.abstractmethod
+ def fail_if_divisible_by(self, small_prime):
+ pass
+
+ @abc.abstractmethod
+ def multiply_accumulate(self, a, b):
+ pass
+
+ @abc.abstractmethod
+ def set(self, source):
+ pass
+
+ @abc.abstractmethod
+ def inplace_inverse(self, modulus):
+ pass
+
+ @abc.abstractmethod
+ def inverse(self, modulus):
+ pass
+
+ @abc.abstractmethod
+ def gcd(self, term):
+ pass
+
+ @abc.abstractmethod
+ def lcm(self, term):
+ pass
+
+ @staticmethod
+ @abc.abstractmethod
+ def jacobi_symbol(a, n):
+ pass
+
+ @staticmethod
+ def _tonelli_shanks(n, p):
+ """Tonelli-shanks algorithm for computing the square root
+ of n modulo a prime p.
+
+ n must be in the range [0..p-1].
+ p must be odd.
+
+ The return value r is the square root of n modulo p. If non-zero,
+ another solution will also exist (p-r).
+
+ Note we cannot assume that p is really a prime: if it's not,
+ we can either raise an exception or return the correct value.
+ """
+
+ # See https://rosettacode.org/wiki/Tonelli-Shanks_algorithm
+
+ if n in (0, 1):
+ return n
+
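+ # For p = 3 (mod 4) a square root is simply n**((p + 1) // 4) mod p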
+ if p % 4 == 3:
+ root = pow(n, (p + 1) // 4, p)
+ if pow(root, 2, p) != n:
+ raise ValueError("Cannot compute square root")
+ return root
+
+ s = 1
+ q = (p - 1) // 2
+ while not (q & 1):
+ s += 1
+ q >>= 1
+
+ z = n.__class__(2)
+ while True:
+ euler = pow(z, (p - 1) // 2, p)
+ if euler == 1:
+ z += 1
+ continue
+ if euler == p - 1:
+ break
+ # Most probably p is not a prime
+ raise ValueError("Cannot compute square root")
+
+ m = s
+ c = pow(z, q, p)
+ t = pow(n, q, p)
+ r = pow(n, (q + 1) // 2, p)
+
+ while t != 1:
+ for i in iter_range(0, m):
+ if pow(t, 2**i, p) == 1:
+ break
+ if i == m:
+ raise ValueError("Cannot compute square root of %d mod %d" % (n, p))
+ b = pow(c, 2**(m - i - 1), p)
+ m = i
+ c = b**2 % p
+ t = (t * b**2) % p
+ r = (r * b) % p
+
+ if pow(r, 2, p) != n:
+ raise ValueError("Cannot compute square root")
+
+ return r
+
+ @classmethod
+ def random(cls, **kwargs):
+ """Generate a random natural integer of a certain size.
+
+ :Keywords:
+ exact_bits : positive integer
+ The length in bits of the resulting random Integer number.
+ The number is guaranteed to fulfil the relation:
+
+ 2^bits > result >= 2^(bits - 1)
+
+ max_bits : positive integer
+ The maximum length in bits of the resulting random Integer number.
+ The number is guaranteed to fulfil the relation:
+
+ 2^bits > result >= 0
+
+ randfunc : callable
+ A function that returns a random byte string. The length of the
+ byte string is passed as parameter. Optional.
+ If not provided (or ``None``), randomness is read from the system RNG.
+
+ :Return: an Integer object
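+
+ Example (sketch, using the concrete ``Integer`` class from Crypto.Math.Numbers)::
+
+ from Crypto.Math.Numbers import Integer
+ x = Integer.random(exact_bits=256)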
+ """
+
+ exact_bits = kwargs.pop("exact_bits", None)
+ max_bits = kwargs.pop("max_bits", None)
+ randfunc = kwargs.pop("randfunc", None)
+
+ if randfunc is None:
+ randfunc = Random.new().read
+
+ if exact_bits is None and max_bits is None:
+ raise ValueError("Either 'exact_bits' or 'max_bits' must be specified")
+
+ if exact_bits is not None and max_bits is not None:
+ raise ValueError("'exact_bits' and 'max_bits' are mutually exclusive")
+
+ bits = exact_bits or max_bits
+ bytes_needed = ((bits - 1) // 8) + 1
+ significant_bits_msb = 8 - (bytes_needed * 8 - bits)
+ msb = bord(randfunc(1)[0])
+ if exact_bits is not None:
+ msb |= 1 << (significant_bits_msb - 1)
+ msb &= (1 << significant_bits_msb) - 1
+
+ return cls.from_bytes(bchr(msb) + randfunc(bytes_needed - 1))
+
+ @classmethod
+ def random_range(cls, **kwargs):
+ """Generate a random integer within a given internal.
+
+ :Keywords:
+ min_inclusive : integer
+ The lower end of the interval (inclusive).
+ max_inclusive : integer
+ The higher end of the interval (inclusive).
+ max_exclusive : integer
+ The higher end of the interval (exclusive).
+ randfunc : callable
+ A function that returns a random byte string. The length of the
+ byte string is passed as parameter. Optional.
+ If not provided (or ``None``), randomness is read from the system RNG.
+ :Returns:
+ An Integer randomly taken in the given interval.
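+
+ Example (sketch, using the concrete ``Integer`` class from Crypto.Math.Numbers)::
+
+ k = Integer.random_range(min_inclusive=1, max_exclusive=100)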
+ """
+
+ min_inclusive = kwargs.pop("min_inclusive", None)
+ max_inclusive = kwargs.pop("max_inclusive", None)
+ max_exclusive = kwargs.pop("max_exclusive", None)
+ randfunc = kwargs.pop("randfunc", None)
+
+ if kwargs:
+ raise ValueError("Unknown keywords: " + str(kwargs.keys))
+ if None not in (max_inclusive, max_exclusive):
+ raise ValueError("max_inclusive and max_exclusive cannot be both"
+ " specified")
+ if max_exclusive is not None:
+ max_inclusive = max_exclusive - 1
+ if None in (min_inclusive, max_inclusive):
+ raise ValueError("Missing keyword to identify the interval")
+
+ if randfunc is None:
+ randfunc = Random.new().read
+
+ norm_maximum = max_inclusive - min_inclusive
+ bits_needed = cls(norm_maximum).size_in_bits()
+
+ norm_candidate = -1
+ while not 0 <= norm_candidate <= norm_maximum:
+ norm_candidate = cls.random(
+ max_bits=bits_needed,
+ randfunc=randfunc
+ )
+ return norm_candidate + min_inclusive
+
diff --git a/lib/Crypto/Math/_IntegerBase.pyi b/lib/Crypto/Math/_IntegerBase.pyi
new file mode 100644
index 0000000..362c512
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerBase.pyi
@@ -0,0 +1,61 @@
+from typing import Optional, Union, Callable
+
+RandFunc = Callable[[int],int]
+
+class IntegerBase:
+
+ def __int__(self) -> int: ...
+ def __str__(self) -> str: ...
+ def __repr__(self) -> str: ...
+ def to_bytes(self, block_size: Optional[int]=0, byteorder: str= ...) -> bytes: ...
+ @staticmethod
+ def from_bytes(byte_string: bytes, byteorder: Optional[str] = ...) -> IntegerBase: ...
+ def __eq__(self, term: object) -> bool: ...
+ def __ne__(self, term: object) -> bool: ...
+ def __lt__(self, term: Union[IntegerBase, int]) -> bool: ...
+ def __le__(self, term: Union[IntegerBase, int]) -> bool: ...
+ def __gt__(self, term: Union[IntegerBase, int]) -> bool: ...
+ def __ge__(self, term: Union[IntegerBase, int]) -> bool: ...
+ def __nonzero__(self) -> bool: ...
+ def is_negative(self) -> bool: ...
+ def __add__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __sub__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __mul__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __floordiv__(self, divisor: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __mod__(self, divisor: Union[IntegerBase, int]) -> IntegerBase: ...
+ def inplace_pow(self, exponent: int, modulus: Optional[Union[IntegerBase, int]]=None) -> IntegerBase: ...
+ def __pow__(self, exponent: int, modulus: Optional[int]) -> IntegerBase: ...
+ def __abs__(self) -> IntegerBase: ...
+ def sqrt(self, modulus: Optional[int]) -> IntegerBase: ...
+ def __iadd__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __isub__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __imul__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __imod__(self, divisor: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __and__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __or__(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __rshift__(self, pos: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __irshift__(self, pos: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __lshift__(self, pos: Union[IntegerBase, int]) -> IntegerBase: ...
+ def __ilshift__(self, pos: Union[IntegerBase, int]) -> IntegerBase: ...
+ def get_bit(self, n: int) -> bool: ...
+ def is_odd(self) -> bool: ...
+ def is_even(self) -> bool: ...
+ def size_in_bits(self) -> int: ...
+ def size_in_bytes(self) -> int: ...
+ def is_perfect_square(self) -> bool: ...
+ def fail_if_divisible_by(self, small_prime: Union[IntegerBase, int]) -> None: ...
+ def multiply_accumulate(self, a: Union[IntegerBase, int], b: Union[IntegerBase, int]) -> IntegerBase: ...
+ def set(self, source: Union[IntegerBase, int]) -> IntegerBase: ...
+ def inplace_inverse(self, modulus: Union[IntegerBase, int]) -> IntegerBase: ...
+ def inverse(self, modulus: Union[IntegerBase, int]) -> IntegerBase: ...
+ def gcd(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ def lcm(self, term: Union[IntegerBase, int]) -> IntegerBase: ...
+ @staticmethod
+ def jacobi_symbol(a: Union[IntegerBase, int], n: Union[IntegerBase, int]) -> IntegerBase: ...
+ @staticmethod
+ def _tonelli_shanks(n: Union[IntegerBase, int], p: Union[IntegerBase, int]) -> IntegerBase : ...
+ @classmethod
+ def random(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
+ @classmethod
+ def random_range(cls, **kwargs: Union[int,RandFunc]) -> IntegerBase : ...
+
diff --git a/lib/Crypto/Math/_IntegerCustom.py b/lib/Crypto/Math/_IntegerCustom.py
new file mode 100644
index 0000000..d6f6f75
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerCustom.py
@@ -0,0 +1,118 @@
+# ===================================================================
+#
+# Copyright (c) 2018, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from ._IntegerNative import IntegerNative
+
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ create_string_buffer,
+ get_raw_buffer, backend,
+ c_size_t, c_ulonglong)
+
+
+from Crypto.Random.random import getrandbits
+
+c_defs = """
+int monty_pow(const uint8_t *base,
+ const uint8_t *exp,
+ const uint8_t *modulus,
+ uint8_t *out,
+ size_t len,
+ uint64_t seed);
+"""
+
+
+_raw_montgomery = load_pycryptodome_raw_lib("Crypto.Math._modexp", c_defs)
+implementation = {"library": "custom", "api": backend}
+
+
+class IntegerCustom(IntegerNative):
+
+ @staticmethod
+ def from_bytes(byte_string, byteorder='big'):
+ if byteorder == 'big':
+ pass
+ elif byteorder == 'little':
+ byte_string = bytearray(byte_string)
+ byte_string.reverse()
+ else:
+ raise ValueError("Incorrect byteorder")
+ return IntegerCustom(bytes_to_long(byte_string))
+
+ def inplace_pow(self, exponent, modulus=None):
+ exp_value = int(exponent)
+ if exp_value < 0:
+ raise ValueError("Exponent must not be negative")
+
+ # No modular reduction
+ if modulus is None:
+ self._value = pow(self._value, exp_value)
+ return self
+
+ # With modular reduction
+ mod_value = int(modulus)
+ if mod_value < 0:
+ raise ValueError("Modulus must be positive")
+ if mod_value == 0:
+ raise ZeroDivisionError("Modulus cannot be zero")
+
+ # C extension only works with odd moduli
+ if (mod_value & 1) == 0:
+ self._value = pow(self._value, exp_value, mod_value)
+ return self
+
+ # C extension only works with bases smaller than modulus
+ if self._value >= mod_value:
+ self._value %= mod_value
+
+ max_len = len(long_to_bytes(max(self._value, exp_value, mod_value)))
+
+ base_b = long_to_bytes(self._value, max_len)
+ exp_b = long_to_bytes(exp_value, max_len)
+ modulus_b = long_to_bytes(mod_value, max_len)
+
+ out = create_string_buffer(max_len)
+
+ error = _raw_montgomery.monty_pow(
+ out,
+ base_b,
+ exp_b,
+ modulus_b,
+ c_size_t(max_len),
+ c_ulonglong(getrandbits(64))
+ )
+
+ if error:
+ raise ValueError("monty_pow failed with error: %d" % error)
+
+ result = bytes_to_long(get_raw_buffer(out))
+ self._value = result
+ return self
diff --git a/lib/Crypto/Math/_IntegerCustom.pyi b/lib/Crypto/Math/_IntegerCustom.pyi
new file mode 100644
index 0000000..2dd75c7
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerCustom.pyi
@@ -0,0 +1,8 @@
+from typing import Any
+
+from ._IntegerNative import IntegerNative
+
+_raw_montgomery = Any
+
+class IntegerCustom(IntegerNative):
+ pass
diff --git a/lib/Crypto/Math/_IntegerGMP.py b/lib/Crypto/Math/_IntegerGMP.py
new file mode 100644
index 0000000..f552b71
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerGMP.py
@@ -0,0 +1,762 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import sys
+
+from Crypto.Util.py3compat import tobytes, is_native_int
+
+from Crypto.Util._raw_api import (backend, load_lib,
+ get_raw_buffer, get_c_string,
+ null_pointer, create_string_buffer,
+ c_ulong, c_size_t, c_uint8_ptr)
+
+from ._IntegerBase import IntegerBase
+
+gmp_defs = """typedef unsigned long UNIX_ULONG;
+ typedef struct { int a; int b; void *c; } MPZ;
+ typedef MPZ mpz_t[1];
+ typedef UNIX_ULONG mp_bitcnt_t;
+
+ void __gmpz_init (mpz_t x);
+ void __gmpz_init_set (mpz_t rop, const mpz_t op);
+ void __gmpz_init_set_ui (mpz_t rop, UNIX_ULONG op);
+
+ UNIX_ULONG __gmpz_get_ui (const mpz_t op);
+ void __gmpz_set (mpz_t rop, const mpz_t op);
+ void __gmpz_set_ui (mpz_t rop, UNIX_ULONG op);
+ void __gmpz_add (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ void __gmpz_add_ui (mpz_t rop, const mpz_t op1, UNIX_ULONG op2);
+ void __gmpz_sub_ui (mpz_t rop, const mpz_t op1, UNIX_ULONG op2);
+ void __gmpz_addmul (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ void __gmpz_addmul_ui (mpz_t rop, const mpz_t op1, UNIX_ULONG op2);
+ void __gmpz_submul_ui (mpz_t rop, const mpz_t op1, UNIX_ULONG op2);
+ void __gmpz_import (mpz_t rop, size_t count, int order, size_t size,
+ int endian, size_t nails, const void *op);
+ void * __gmpz_export (void *rop, size_t *countp, int order,
+ size_t size,
+ int endian, size_t nails, const mpz_t op);
+ size_t __gmpz_sizeinbase (const mpz_t op, int base);
+ void __gmpz_sub (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ void __gmpz_mul (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ void __gmpz_mul_ui (mpz_t rop, const mpz_t op1, UNIX_ULONG op2);
+ int __gmpz_cmp (const mpz_t op1, const mpz_t op2);
+ void __gmpz_powm (mpz_t rop, const mpz_t base, const mpz_t exp, const
+ mpz_t mod);
+ void __gmpz_powm_ui (mpz_t rop, const mpz_t base, UNIX_ULONG exp,
+ const mpz_t mod);
+ void __gmpz_pow_ui (mpz_t rop, const mpz_t base, UNIX_ULONG exp);
+ void __gmpz_sqrt(mpz_t rop, const mpz_t op);
+ void __gmpz_mod (mpz_t r, const mpz_t n, const mpz_t d);
+ void __gmpz_neg (mpz_t rop, const mpz_t op);
+ void __gmpz_abs (mpz_t rop, const mpz_t op);
+ void __gmpz_and (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ void __gmpz_ior (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ void __gmpz_clear (mpz_t x);
+ void __gmpz_tdiv_q_2exp (mpz_t q, const mpz_t n, mp_bitcnt_t b);
+ void __gmpz_fdiv_q (mpz_t q, const mpz_t n, const mpz_t d);
+ void __gmpz_mul_2exp (mpz_t rop, const mpz_t op1, mp_bitcnt_t op2);
+ int __gmpz_tstbit (const mpz_t op, mp_bitcnt_t bit_index);
+ int __gmpz_perfect_square_p (const mpz_t op);
+ int __gmpz_jacobi (const mpz_t a, const mpz_t b);
+ void __gmpz_gcd (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ UNIX_ULONG __gmpz_gcd_ui (mpz_t rop, const mpz_t op1,
+ UNIX_ULONG op2);
+ void __gmpz_lcm (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ int __gmpz_invert (mpz_t rop, const mpz_t op1, const mpz_t op2);
+ int __gmpz_divisible_p (const mpz_t n, const mpz_t d);
+ int __gmpz_divisible_ui_p (const mpz_t n, UNIX_ULONG d);
+ """
+
+if sys.platform == "win32":
+ raise ImportError("Not using GMP on Windows")
+
+lib = load_lib("gmp", gmp_defs)
+implementation = {"library": "gmp", "api": backend}
+
+if hasattr(lib, "__mpir_version"):
+ raise ImportError("MPIR library detected")
+
+# In order to create a function that returns a pointer to
+# a new MPZ structure, we need to break the abstraction
+# and know exactly what ffi backend we have
+if implementation["api"] == "ctypes":
+ from ctypes import Structure, c_int, c_void_p, byref
+
+ class _MPZ(Structure):
+ _fields_ = [('_mp_alloc', c_int),
+ ('_mp_size', c_int),
+ ('_mp_d', c_void_p)]
+
+ def new_mpz():
+ return byref(_MPZ())
+
+else:
+ # We are using CFFI
+ from Crypto.Util._raw_api import ffi
+
+ def new_mpz():
+ return ffi.new("MPZ*")
+
+
+# Lazy creation of GMP methods
+class _GMP(object):
+
+ def __getattr__(self, name):
+ if name.startswith("mpz_"):
+ func_name = "__gmpz_" + name[4:]
+ elif name.startswith("gmp_"):
+ func_name = "__gmp_" + name[4:]
+ else:
+ raise AttributeError("Attribute %s is invalid" % name)
+ func = getattr(lib, func_name)
+ setattr(self, name, func)
+ return func
+
+
+_gmp = _GMP()
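+# Illustrative note (not part of the original module): the first access to a
+# name such as _gmp.mpz_add is resolved by __getattr__ to lib.__gmpz_add and
+# cached on the _GMP instance, so later calls are plain attribute lookups.
+# A hypothetical sketch, assuming the GMP shared library loaded above:
+#
+#   _gmp.mpz_init(some_mpz)    # first call: looked up and cached as lib.__gmpz_init
+#   _gmp.mpz_init(other_mpz)   # subsequent calls: no __getattr__ involved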
+
+
+class IntegerGMP(IntegerBase):
+ """A fast, arbitrary precision integer"""
+
+ _zero_mpz_p = new_mpz()
+ _gmp.mpz_init_set_ui(_zero_mpz_p, c_ulong(0))
+
+ def __init__(self, value):
+ """Initialize the integer to the given value."""
+
+ self._mpz_p = new_mpz()
+ self._initialized = False
+
+ if isinstance(value, float):
+ raise ValueError("A floating point type is not a natural number")
+
+ if is_native_int(value):
+ _gmp.mpz_init(self._mpz_p)
+ self._initialized = True
+ if value == 0:
+ return
+
+ tmp = new_mpz()
+ _gmp.mpz_init(tmp)
+
+ try:
+ positive = value >= 0
+ reduce = abs(value)
+ slots = (reduce.bit_length() - 1) // 32 + 1
+
+ while slots > 0:
+ slots = slots - 1
+ _gmp.mpz_set_ui(tmp,
+ c_ulong(0xFFFFFFFF & (reduce >> (slots * 32))))
+ _gmp.mpz_mul_2exp(tmp, tmp, c_ulong(slots * 32))
+ _gmp.mpz_add(self._mpz_p, self._mpz_p, tmp)
+ finally:
+ _gmp.mpz_clear(tmp)
+
+ if not positive:
+ _gmp.mpz_neg(self._mpz_p, self._mpz_p)
+
+ elif isinstance(value, IntegerGMP):
+ _gmp.mpz_init_set(self._mpz_p, value._mpz_p)
+ self._initialized = True
+ else:
+ raise NotImplementedError
+
+
+ # Conversions
+ def __int__(self):
+ tmp = new_mpz()
+ _gmp.mpz_init_set(tmp, self._mpz_p)
+
+ try:
+ value = 0
+ slot = 0
+ while _gmp.mpz_cmp(tmp, self._zero_mpz_p) != 0:
+ lsb = _gmp.mpz_get_ui(tmp) & 0xFFFFFFFF
+ value |= lsb << (slot * 32)
+ _gmp.mpz_tdiv_q_2exp(tmp, tmp, c_ulong(32))
+ slot = slot + 1
+ finally:
+ _gmp.mpz_clear(tmp)
+
+ if self < 0:
+ value = -value
+ return int(value)
+
+ def __str__(self):
+ return str(int(self))
+
+ def __repr__(self):
+ return "Integer(%s)" % str(self)
+
+ # Only Python 2.x
+ def __hex__(self):
+ return hex(int(self))
+
+ # Only Python 3.x
+ def __index__(self):
+ return int(self)
+
+ def to_bytes(self, block_size=0, byteorder='big'):
+ """Convert the number into a byte string.
+
+ This method encodes the number in network order and prepends
+ as many zero bytes as required. It only works for non-negative
+ values.
+
+ :Parameters:
+ block_size : integer
+ The exact size the output byte string must have.
+ If zero, the string has the minimal length.
+ byteorder : string
+ 'big' for big-endian integers (default), 'little' for little-endian.
+ :Returns:
+ A byte string.
+ :Raise ValueError:
+ If the value is negative or if ``block_size`` is
+ provided and the length of the byte string would exceed it.
+ """
+
+ if self < 0:
+ raise ValueError("Conversion only valid for non-negative numbers")
+
+ buf_len = (_gmp.mpz_sizeinbase(self._mpz_p, 2) + 7) // 8
+ if buf_len > block_size > 0:
+ raise ValueError("Number is too big to convert to byte string"
+ " of prescribed length")
+ buf = create_string_buffer(buf_len)
+
+
+ _gmp.mpz_export(
+ buf,
+ null_pointer, # Ignore countp
+ 1, # Big endian
+ c_size_t(1), # Each word is 1 byte long
+ 0, # Endianness within a word - not relevant
+ c_size_t(0), # No nails
+ self._mpz_p)
+
+ result = b'\x00' * max(0, block_size - buf_len) + get_raw_buffer(buf)
+ if byteorder == 'big':
+ pass
+ elif byteorder == 'little':
+ result = bytearray(result)
+ result.reverse()
+ result = bytes(result)
+ else:
+ raise ValueError("Incorrect byteorder")
+ return result
+
+ @staticmethod
+ def from_bytes(byte_string, byteorder='big'):
+ """Convert a byte string into a number.
+
+ :Parameters:
+ byte_string : byte string
+ The input number, encoded in network order.
+ It can only be non-negative.
+ byteorder : string
+ 'big' for big-endian integers (default), 'little' for little-endian.
+
+ :Return:
+ The ``Integer`` object carrying the same value as the input.
+ """
+ result = IntegerGMP(0)
+ if byteorder == 'big':
+ pass
+ elif byteorder == 'little':
+ byte_string = bytearray(byte_string)
+ byte_string.reverse()
+ else:
+ raise ValueError("Incorrect byteorder")
+ _gmp.mpz_import(
+ result._mpz_p,
+ c_size_t(len(byte_string)), # Amount of words to read
+ 1, # Big endian
+ c_size_t(1), # Each word is 1 byte long
+ 0, # Endianness within a word - not relevant
+ c_size_t(0), # No nails
+ c_uint8_ptr(byte_string))
+ return result
+
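+ # Round-trip sketch (illustrative only, not part of the original source;
+ # it assumes the GMP shared library was found at import time):
+ #
+ #   n = IntegerGMP(4660)                      # 0x1234
+ #   raw = n.to_bytes(block_size=4)            # b'\x00\x00\x12\x34'
+ #   assert IntegerGMP.from_bytes(raw) == n    # the same value is recovered
+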
+ # Relations
+ def _apply_and_return(self, func, term):
+ if not isinstance(term, IntegerGMP):
+ term = IntegerGMP(term)
+ return func(self._mpz_p, term._mpz_p)
+
+ def __eq__(self, term):
+ if not (isinstance(term, IntegerGMP) or is_native_int(term)):
+ return False
+ return self._apply_and_return(_gmp.mpz_cmp, term) == 0
+
+ def __ne__(self, term):
+ if not (isinstance(term, IntegerGMP) or is_native_int(term)):
+ return True
+ return self._apply_and_return(_gmp.mpz_cmp, term) != 0
+
+ def __lt__(self, term):
+ return self._apply_and_return(_gmp.mpz_cmp, term) < 0
+
+ def __le__(self, term):
+ return self._apply_and_return(_gmp.mpz_cmp, term) <= 0
+
+ def __gt__(self, term):
+ return self._apply_and_return(_gmp.mpz_cmp, term) > 0
+
+ def __ge__(self, term):
+ return self._apply_and_return(_gmp.mpz_cmp, term) >= 0
+
+ def __nonzero__(self):
+ return _gmp.mpz_cmp(self._mpz_p, self._zero_mpz_p) != 0
+ __bool__ = __nonzero__
+
+ def is_negative(self):
+ return _gmp.mpz_cmp(self._mpz_p, self._zero_mpz_p) < 0
+
+ # Arithmetic operations
+ def __add__(self, term):
+ result = IntegerGMP(0)
+ if not isinstance(term, IntegerGMP):
+ try:
+ term = IntegerGMP(term)
+ except NotImplementedError:
+ return NotImplemented
+ _gmp.mpz_add(result._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return result
+
+ def __sub__(self, term):
+ result = IntegerGMP(0)
+ if not isinstance(term, IntegerGMP):
+ try:
+ term = IntegerGMP(term)
+ except NotImplementedError:
+ return NotImplemented
+ _gmp.mpz_sub(result._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return result
+
+ def __mul__(self, term):
+ result = IntegerGMP(0)
+ if not isinstance(term, IntegerGMP):
+ try:
+ term = IntegerGMP(term)
+ except NotImplementedError:
+ return NotImplemented
+ _gmp.mpz_mul(result._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return result
+
+ def __floordiv__(self, divisor):
+ if not isinstance(divisor, IntegerGMP):
+ divisor = IntegerGMP(divisor)
+ if _gmp.mpz_cmp(divisor._mpz_p,
+ self._zero_mpz_p) == 0:
+ raise ZeroDivisionError("Division by zero")
+ result = IntegerGMP(0)
+ _gmp.mpz_fdiv_q(result._mpz_p,
+ self._mpz_p,
+ divisor._mpz_p)
+ return result
+
+ def __mod__(self, divisor):
+ if not isinstance(divisor, IntegerGMP):
+ divisor = IntegerGMP(divisor)
+ comp = _gmp.mpz_cmp(divisor._mpz_p,
+ self._zero_mpz_p)
+ if comp == 0:
+ raise ZeroDivisionError("Division by zero")
+ if comp < 0:
+ raise ValueError("Modulus must be positive")
+ result = IntegerGMP(0)
+ _gmp.mpz_mod(result._mpz_p,
+ self._mpz_p,
+ divisor._mpz_p)
+ return result
+
+ def inplace_pow(self, exponent, modulus=None):
+
+ if modulus is None:
+ if exponent < 0:
+ raise ValueError("Exponent must not be negative")
+
+ # Normal exponentiation
+ if exponent > 256:
+ raise ValueError("Exponent is too big")
+ _gmp.mpz_pow_ui(self._mpz_p,
+ self._mpz_p, # Base
+ c_ulong(int(exponent))
+ )
+ else:
+ # Modular exponentiation
+ if not isinstance(modulus, IntegerGMP):
+ modulus = IntegerGMP(modulus)
+ if not modulus:
+ raise ZeroDivisionError("Division by zero")
+ if modulus.is_negative():
+ raise ValueError("Modulus must be positive")
+ if is_native_int(exponent):
+ if exponent < 0:
+ raise ValueError("Exponent must not be negative")
+ if exponent < 65536:
+ _gmp.mpz_powm_ui(self._mpz_p,
+ self._mpz_p,
+ c_ulong(exponent),
+ modulus._mpz_p)
+ return self
+ exponent = IntegerGMP(exponent)
+ elif exponent.is_negative():
+ raise ValueError("Exponent must not be negative")
+ _gmp.mpz_powm(self._mpz_p,
+ self._mpz_p,
+ exponent._mpz_p,
+ modulus._mpz_p)
+ return self
+
+ def __pow__(self, exponent, modulus=None):
+ result = IntegerGMP(self)
+ return result.inplace_pow(exponent, modulus)
+
+ def __abs__(self):
+ result = IntegerGMP(0)
+ _gmp.mpz_abs(result._mpz_p, self._mpz_p)
+ return result
+
+ def sqrt(self, modulus=None):
+ """Return the largest Integer that does not
+ exceed the square root"""
+
+ if modulus is None:
+ if self < 0:
+ raise ValueError("Square root of negative value")
+ result = IntegerGMP(0)
+ _gmp.mpz_sqrt(result._mpz_p,
+ self._mpz_p)
+ else:
+ if modulus <= 0:
+ raise ValueError("Modulus must be positive")
+ modulus = int(modulus)
+ result = IntegerGMP(self._tonelli_shanks(int(self) % modulus, modulus))
+
+ return result
+
+ def __iadd__(self, term):
+ if is_native_int(term):
+ if 0 <= term < 65536:
+ _gmp.mpz_add_ui(self._mpz_p,
+ self._mpz_p,
+ c_ulong(term))
+ return self
+ if -65535 < term < 0:
+ _gmp.mpz_sub_ui(self._mpz_p,
+ self._mpz_p,
+ c_ulong(-term))
+ return self
+ term = IntegerGMP(term)
+ _gmp.mpz_add(self._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return self
+
+ def __isub__(self, term):
+ if is_native_int(term):
+ if 0 <= term < 65536:
+ _gmp.mpz_sub_ui(self._mpz_p,
+ self._mpz_p,
+ c_ulong(term))
+ return self
+ if -65535 < term < 0:
+ _gmp.mpz_add_ui(self._mpz_p,
+ self._mpz_p,
+ c_ulong(-term))
+ return self
+ term = IntegerGMP(term)
+ _gmp.mpz_sub(self._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return self
+
+ def __imul__(self, term):
+ if is_native_int(term):
+ if 0 <= term < 65536:
+ _gmp.mpz_mul_ui(self._mpz_p,
+ self._mpz_p,
+ c_ulong(term))
+ return self
+ if -65535 < term < 0:
+ _gmp.mpz_mul_ui(self._mpz_p,
+ self._mpz_p,
+ c_ulong(-term))
+ _gmp.mpz_neg(self._mpz_p, self._mpz_p)
+ return self
+ term = IntegerGMP(term)
+ _gmp.mpz_mul(self._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return self
+
+ def __imod__(self, divisor):
+ if not isinstance(divisor, IntegerGMP):
+ divisor = IntegerGMP(divisor)
+ comp = _gmp.mpz_cmp(divisor._mpz_p,
+ divisor._zero_mpz_p)
+ if comp == 0:
+ raise ZeroDivisionError("Division by zero")
+ if comp < 0:
+ raise ValueError("Modulus must be positive")
+ _gmp.mpz_mod(self._mpz_p,
+ self._mpz_p,
+ divisor._mpz_p)
+ return self
+
+ # Boolean/bit operations
+ def __and__(self, term):
+ result = IntegerGMP(0)
+ if not isinstance(term, IntegerGMP):
+ term = IntegerGMP(term)
+ _gmp.mpz_and(result._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return result
+
+ def __or__(self, term):
+ result = IntegerGMP(0)
+ if not isinstance(term, IntegerGMP):
+ term = IntegerGMP(term)
+ _gmp.mpz_ior(result._mpz_p,
+ self._mpz_p,
+ term._mpz_p)
+ return result
+
+ def __rshift__(self, pos):
+ result = IntegerGMP(0)
+ if pos < 0:
+ raise ValueError("negative shift count")
+ if pos > 65536:
+ if self < 0:
+ return -1
+ else:
+ return 0
+ _gmp.mpz_tdiv_q_2exp(result._mpz_p,
+ self._mpz_p,
+ c_ulong(int(pos)))
+ return result
+
+ def __irshift__(self, pos):
+ if pos < 0:
+ raise ValueError("negative shift count")
+ if pos > 65536:
+ if self < 0:
+ return -1
+ else:
+ return 0
+ _gmp.mpz_tdiv_q_2exp(self._mpz_p,
+ self._mpz_p,
+ c_ulong(int(pos)))
+ return self
+
+ def __lshift__(self, pos):
+ result = IntegerGMP(0)
+ if not 0 <= pos < 65536:
+ raise ValueError("Incorrect shift count")
+ _gmp.mpz_mul_2exp(result._mpz_p,
+ self._mpz_p,
+ c_ulong(int(pos)))
+ return result
+
+ def __ilshift__(self, pos):
+ if not 0 <= pos < 65536:
+ raise ValueError("Incorrect shift count")
+ _gmp.mpz_mul_2exp(self._mpz_p,
+ self._mpz_p,
+ c_ulong(int(pos)))
+ return self
+
+ def get_bit(self, n):
+ """Return True if the n-th bit is set to 1.
+ Bit 0 is the least significant."""
+
+ if self < 0:
+ raise ValueError("no bit representation for negative values")
+ if n < 0:
+ raise ValueError("negative bit count")
+ if n > 65536:
+ return 0
+ return bool(_gmp.mpz_tstbit(self._mpz_p,
+ c_ulong(int(n))))
+
+ # Extra
+ def is_odd(self):
+ return _gmp.mpz_tstbit(self._mpz_p, 0) == 1
+
+ def is_even(self):
+ return _gmp.mpz_tstbit(self._mpz_p, 0) == 0
+
+ def size_in_bits(self):
+ """Return the minimum number of bits that can encode the number."""
+
+ if self < 0:
+ raise ValueError("Conversion only valid for non-negative numbers")
+ return _gmp.mpz_sizeinbase(self._mpz_p, 2)
+
+ def size_in_bytes(self):
+ """Return the minimum number of bytes that can encode the number."""
+ return (self.size_in_bits() - 1) // 8 + 1
+
+ def is_perfect_square(self):
+ return _gmp.mpz_perfect_square_p(self._mpz_p) != 0
+
+ def fail_if_divisible_by(self, small_prime):
+ """Raise an exception if the small prime is a divisor."""
+
+ if is_native_int(small_prime):
+ if 0 < small_prime < 65536:
+ if _gmp.mpz_divisible_ui_p(self._mpz_p,
+ c_ulong(small_prime)):
+ raise ValueError("The value is composite")
+ return
+ small_prime = IntegerGMP(small_prime)
+ if _gmp.mpz_divisible_p(self._mpz_p,
+ small_prime._mpz_p):
+ raise ValueError("The value is composite")
+
+ def multiply_accumulate(self, a, b):
+ """Increment the number by the product of a and b."""
+
+ if not isinstance(a, IntegerGMP):
+ a = IntegerGMP(a)
+ if is_native_int(b):
+ if 0 < b < 65536:
+ _gmp.mpz_addmul_ui(self._mpz_p,
+ a._mpz_p,
+ c_ulong(b))
+ return self
+ if -65535 < b < 0:
+ _gmp.mpz_submul_ui(self._mpz_p,
+ a._mpz_p,
+ c_ulong(-b))
+ return self
+ b = IntegerGMP(b)
+ _gmp.mpz_addmul(self._mpz_p,
+ a._mpz_p,
+ b._mpz_p)
+ return self
+
+ def set(self, source):
+ """Set the Integer to have the given value"""
+
+ if not isinstance(source, IntegerGMP):
+ source = IntegerGMP(source)
+ _gmp.mpz_set(self._mpz_p,
+ source._mpz_p)
+ return self
+
+ def inplace_inverse(self, modulus):
+ """Compute the inverse of this number in the ring of
+ modulo integers.
+
+ Raise an exception if no inverse exists.
+ """
+
+ if not isinstance(modulus, IntegerGMP):
+ modulus = IntegerGMP(modulus)
+
+ comp = _gmp.mpz_cmp(modulus._mpz_p,
+ self._zero_mpz_p)
+ if comp == 0:
+ raise ZeroDivisionError("Modulus cannot be zero")
+ if comp < 0:
+ raise ValueError("Modulus must be positive")
+
+ result = _gmp.mpz_invert(self._mpz_p,
+ self._mpz_p,
+ modulus._mpz_p)
+ if not result:
+ raise ValueError("No inverse value can be computed")
+ return self
+
+ def inverse(self, modulus):
+ result = IntegerGMP(self)
+ result.inplace_inverse(modulus)
+ return result
+
+ def gcd(self, term):
+ """Compute the greatest common denominator between this
+ number and another term."""
+
+ result = IntegerGMP(0)
+ if is_native_int(term):
+ if 0 < term < 65535:
+ _gmp.mpz_gcd_ui(result._mpz_p,
+ self._mpz_p,
+ c_ulong(term))
+ return result
+ term = IntegerGMP(term)
+ _gmp.mpz_gcd(result._mpz_p, self._mpz_p, term._mpz_p)
+ return result
+
+ def lcm(self, term):
+ """Compute the least common multiplier between this
+ number and another term."""
+
+ result = IntegerGMP(0)
+ if not isinstance(term, IntegerGMP):
+ term = IntegerGMP(term)
+ _gmp.mpz_lcm(result._mpz_p, self._mpz_p, term._mpz_p)
+ return result
+
+ @staticmethod
+ def jacobi_symbol(a, n):
+ """Compute the Jacobi symbol"""
+
+ if not isinstance(a, IntegerGMP):
+ a = IntegerGMP(a)
+ if not isinstance(n, IntegerGMP):
+ n = IntegerGMP(n)
+ if n <= 0 or n.is_even():
+ raise ValueError("n must be positive odd for the Jacobi symbol")
+ return _gmp.mpz_jacobi(a._mpz_p, n._mpz_p)
+
+ # Clean-up
+ def __del__(self):
+
+ try:
+ if self._mpz_p is not None:
+ if self._initialized:
+ _gmp.mpz_clear(self._mpz_p)
+
+ self._mpz_p = None
+ except AttributeError:
+ pass
diff --git a/lib/Crypto/Math/_IntegerGMP.pyi b/lib/Crypto/Math/_IntegerGMP.pyi
new file mode 100644
index 0000000..2181b47
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerGMP.pyi
@@ -0,0 +1,3 @@
+from ._IntegerBase import IntegerBase
+class IntegerGMP(IntegerBase):
+ pass
diff --git a/lib/Crypto/Math/_IntegerNative.py b/lib/Crypto/Math/_IntegerNative.py
new file mode 100644
index 0000000..a8bcb3d
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerNative.py
@@ -0,0 +1,395 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from ._IntegerBase import IntegerBase
+
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+
+
+class IntegerNative(IntegerBase):
+ """A class to model a natural integer (including zero)"""
+
+ def __init__(self, value):
+ if isinstance(value, float):
+ raise ValueError("A floating point type is not a natural number")
+ try:
+ self._value = value._value
+ except AttributeError:
+ self._value = value
+
+ # Conversions
+ def __int__(self):
+ return self._value
+
+ def __str__(self):
+ return str(int(self))
+
+ def __repr__(self):
+ return "Integer(%s)" % str(self)
+
+ # Only Python 2.x
+ def __hex__(self):
+ return hex(self._value)
+
+ # Only Python 3.x
+ def __index__(self):
+ return int(self._value)
+
+ def to_bytes(self, block_size=0, byteorder='big'):
+ if self._value < 0:
+ raise ValueError("Conversion only valid for non-negative numbers")
+ result = long_to_bytes(self._value, block_size)
+ if len(result) > block_size > 0:
+ raise ValueError("Value too large to encode")
+ if byteorder == 'big':
+ pass
+ elif byteorder == 'little':
+ result = bytearray(result)
+ result.reverse()
+ result = bytes(result)
+ else:
+ raise ValueError("Incorrect byteorder")
+ return result
+
+ @classmethod
+ def from_bytes(cls, byte_string, byteorder='big'):
+ if byteorder == 'big':
+ pass
+ elif byteorder == 'little':
+ byte_string = bytearray(byte_string)
+ byte_string.reverse()
+ else:
+ raise ValueError("Incorrect byteorder")
+ return cls(bytes_to_long(byte_string))
+
+ # Relations
+ def __eq__(self, term):
+ if term is None:
+ return False
+ return self._value == int(term)
+
+ def __ne__(self, term):
+ return not self.__eq__(term)
+
+ def __lt__(self, term):
+ return self._value < int(term)
+
+ def __le__(self, term):
+ return self.__lt__(term) or self.__eq__(term)
+
+ def __gt__(self, term):
+ return not self.__le__(term)
+
+ def __ge__(self, term):
+ return not self.__lt__(term)
+
+ def __nonzero__(self):
+ return self._value != 0
+ __bool__ = __nonzero__
+
+ def is_negative(self):
+ return self._value < 0
+
+ # Arithmetic operations
+ def __add__(self, term):
+ try:
+ return self.__class__(self._value + int(term))
+ except (ValueError, AttributeError, TypeError):
+ return NotImplemented
+
+ def __sub__(self, term):
+ try:
+ return self.__class__(self._value - int(term))
+ except (ValueError, AttributeError, TypeError):
+ return NotImplemented
+
+ def __mul__(self, factor):
+ try:
+ return self.__class__(self._value * int(factor))
+ except (ValueError, AttributeError, TypeError):
+ return NotImplemented
+
+ def __floordiv__(self, divisor):
+ return self.__class__(self._value // int(divisor))
+
+ def __mod__(self, divisor):
+ divisor_value = int(divisor)
+ if divisor_value < 0:
+ raise ValueError("Modulus must be positive")
+ return self.__class__(self._value % divisor_value)
+
+ def inplace_pow(self, exponent, modulus=None):
+ exp_value = int(exponent)
+ if exp_value < 0:
+ raise ValueError("Exponent must not be negative")
+
+ if modulus is not None:
+ mod_value = int(modulus)
+ if mod_value < 0:
+ raise ValueError("Modulus must be positive")
+ if mod_value == 0:
+ raise ZeroDivisionError("Modulus cannot be zero")
+ else:
+ mod_value = None
+ self._value = pow(self._value, exp_value, mod_value)
+ return self
+
+ def __pow__(self, exponent, modulus=None):
+ result = self.__class__(self)
+ return result.inplace_pow(exponent, modulus)
+
+ def __abs__(self):
+ return abs(self._value)
+
+ def sqrt(self, modulus=None):
+
+ value = self._value
+ if modulus is None:
+ if value < 0:
+ raise ValueError("Square root of negative value")
+ # http://stackoverflow.com/questions/15390807/integer-square-root-in-python
+
+ x = value
+ y = (x + 1) // 2
+ while y < x:
+ x = y
+ y = (x + value // x) // 2
+ result = x
+ else:
+ if modulus <= 0:
+ raise ValueError("Modulus must be positive")
+ result = self._tonelli_shanks(self % modulus, modulus)
+
+ return self.__class__(result)
+
+ def __iadd__(self, term):
+ self._value += int(term)
+ return self
+
+ def __isub__(self, term):
+ self._value -= int(term)
+ return self
+
+ def __imul__(self, term):
+ self._value *= int(term)
+ return self
+
+ def __imod__(self, term):
+ modulus = int(term)
+ if modulus == 0:
+ raise ZeroDivisionError("Division by zero")
+ if modulus < 0:
+ raise ValueError("Modulus must be positive")
+ self._value %= modulus
+ return self
+
+ # Boolean/bit operations
+ def __and__(self, term):
+ return self.__class__(self._value & int(term))
+
+ def __or__(self, term):
+ return self.__class__(self._value | int(term))
+
+ def __rshift__(self, pos):
+ try:
+ return self.__class__(self._value >> int(pos))
+ except OverflowError:
+ if self._value >= 0:
+ return 0
+ else:
+ return -1
+
+ def __irshift__(self, pos):
+ try:
+ self._value >>= int(pos)
+ except OverflowError:
+ if self._value >= 0:
+ return 0
+ else:
+ return -1
+ return self
+
+ def __lshift__(self, pos):
+ try:
+ return self.__class__(self._value << int(pos))
+ except OverflowError:
+ raise ValueError("Incorrect shift count")
+
+ def __ilshift__(self, pos):
+ try:
+ self._value <<= int(pos)
+ except OverflowError:
+ raise ValueError("Incorrect shift count")
+ return self
+
+ def get_bit(self, n):
+ if self._value < 0:
+ raise ValueError("no bit representation for negative values")
+ try:
+ try:
+ result = (self._value >> n._value) & 1
+ if n._value < 0:
+ raise ValueError("negative bit count")
+ except AttributeError:
+ result = (self._value >> n) & 1
+ if n < 0:
+ raise ValueError("negative bit count")
+ except OverflowError:
+ result = 0
+ return result
+
+ # Extra
+ def is_odd(self):
+ return (self._value & 1) == 1
+
+ def is_even(self):
+ return (self._value & 1) == 0
+
+ def size_in_bits(self):
+
+ if self._value < 0:
+ raise ValueError("Conversion only valid for non-negative numbers")
+
+ if self._value == 0:
+ return 1
+
+ bit_size = 0
+ tmp = self._value
+ while tmp:
+ tmp >>= 1
+ bit_size += 1
+
+ return bit_size
+
+ def size_in_bytes(self):
+ return (self.size_in_bits() - 1) // 8 + 1
+
+ def is_perfect_square(self):
+ if self._value < 0:
+ return False
+ if self._value in (0, 1):
+ return True
+
+ x = self._value // 2
+ square_x = x ** 2
+
+ while square_x > self._value:
+ x = (square_x + self._value) // (2 * x)
+ square_x = x ** 2
+
+ return self._value == x ** 2
+
+ def fail_if_divisible_by(self, small_prime):
+ if (self._value % int(small_prime)) == 0:
+ raise ValueError("Value is composite")
+
+ def multiply_accumulate(self, a, b):
+ self._value += int(a) * int(b)
+ return self
+
+ def set(self, source):
+ self._value = int(source)
+
+ def inplace_inverse(self, modulus):
+ modulus = int(modulus)
+ if modulus == 0:
+ raise ZeroDivisionError("Modulus cannot be zero")
+ if modulus < 0:
+ raise ValueError("Modulus cannot be negative")
+ r_p, r_n = self._value, modulus
+ s_p, s_n = 1, 0
+ while r_n > 0:
+ q = r_p // r_n
+ r_p, r_n = r_n, r_p - q * r_n
+ s_p, s_n = s_n, s_p - q * s_n
+ if r_p != 1:
+ raise ValueError("No inverse value can be computed" + str(r_p))
+ while s_p < 0:
+ s_p += modulus
+ self._value = s_p
+ return self
+
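+ # Worked sketch (illustrative, not in the original source): the extended
+ # Euclidean loop above yields the inverse of 3 modulo 7, which is 5,
+ # since 3 * 5 == 2 * 7 + 1:
+ #
+ #   assert int(IntegerNative(3).inverse(7)) == 5
+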
+ def inverse(self, modulus):
+ result = self.__class__(self)
+ result.inplace_inverse(modulus)
+ return result
+
+ def gcd(self, term):
+ r_p, r_n = abs(self._value), abs(int(term))
+ while r_n > 0:
+ q = r_p // r_n
+ r_p, r_n = r_n, r_p - q * r_n
+ return self.__class__(r_p)
+
+ def lcm(self, term):
+ term = int(term)
+ if self._value == 0 or term == 0:
+ return self.__class__(0)
+ return self.__class__(abs((self._value * term) // self.gcd(term)._value))
+
+ @staticmethod
+ def jacobi_symbol(a, n):
+ a = int(a)
+ n = int(n)
+
+ if n <= 0:
+ raise ValueError("n must be a positive integer")
+
+ if (n & 1) == 0:
+ raise ValueError("n must be odd for the Jacobi symbol")
+
+ # Step 1
+ a = a % n
+ # Step 2
+ if a == 1 or n == 1:
+ return 1
+ # Step 3
+ if a == 0:
+ return 0
+ # Step 4
+ e = 0
+ a1 = a
+ while (a1 & 1) == 0:
+ a1 >>= 1
+ e += 1
+ # Step 5
+ if (e & 1) == 0:
+ s = 1
+ elif n % 8 in (1, 7):
+ s = 1
+ else:
+ s = -1
+ # Step 6
+ if n % 4 == 3 and a1 % 4 == 3:
+ s = -s
+ # Step 7
+ n1 = n % a1
+ # Step 8
+ return s * IntegerNative.jacobi_symbol(n1, a1)
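+
+ # Worked example (illustrative, not in the original source): for the classic
+ # textbook pair a = 1001, n = 9907 the Jacobi symbol is -1, so
+ # IntegerNative.jacobi_symbol(1001, 9907) returns -1.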
diff --git a/lib/Crypto/Math/_IntegerNative.pyi b/lib/Crypto/Math/_IntegerNative.pyi
new file mode 100644
index 0000000..3f65a39
--- /dev/null
+++ b/lib/Crypto/Math/_IntegerNative.pyi
@@ -0,0 +1,3 @@
+from ._IntegerBase import IntegerBase
+class IntegerNative(IntegerBase):
+ pass
diff --git a/lib/Crypto/Math/__init__.py b/lib/Crypto/Math/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/Crypto/Math/__init__.py
diff --git a/lib/Crypto/Math/_modexp.abi3.so b/lib/Crypto/Math/_modexp.abi3.so
new file mode 100755
index 0000000..ffc9ffd
--- /dev/null
+++ b/lib/Crypto/Math/_modexp.abi3.so
Binary files differ
diff --git a/lib/Crypto/Protocol/KDF.py b/lib/Crypto/Protocol/KDF.py
new file mode 100644
index 0000000..1348265
--- /dev/null
+++ b/lib/Crypto/Protocol/KDF.py
@@ -0,0 +1,574 @@
+# coding=utf-8
+#
+# KDF.py : a collection of Key Derivation Functions
+#
+# Part of the Python Cryptography Toolkit
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+import re
+import struct
+from functools import reduce
+
+from Crypto.Util.py3compat import (tobytes, bord, _copy_bytes, iter_range,
+ tostr, bchr, bstr)
+
+from Crypto.Hash import SHA1, SHA256, HMAC, CMAC, BLAKE2s
+from Crypto.Util.strxor import strxor
+from Crypto.Random import get_random_bytes
+from Crypto.Util.number import size as bit_size, long_to_bytes, bytes_to_long
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ create_string_buffer,
+ get_raw_buffer, c_size_t)
+
+_raw_salsa20_lib = load_pycryptodome_raw_lib("Crypto.Cipher._Salsa20",
+ """
+ int Salsa20_8_core(const uint8_t *x, const uint8_t *y,
+ uint8_t *out);
+ """)
+
+_raw_scrypt_lib = load_pycryptodome_raw_lib("Crypto.Protocol._scrypt",
+ """
+ typedef int (core_t)(const uint8_t [64], const uint8_t [64], uint8_t [64]);
+ int scryptROMix(const uint8_t *data_in, uint8_t *data_out,
+ size_t data_len, unsigned N, core_t *core);
+ """)
+
+
+def PBKDF1(password, salt, dkLen, count=1000, hashAlgo=None):
+ """Derive one key from a password (or passphrase).
+
+ This function performs key derivation according to an old version of
+ the PKCS#5 standard (v1.5) or `RFC2898
+ <https://www.ietf.org/rfc/rfc2898.txt>`_.
+
+ Args:
+ password (string):
+ The secret password to generate the key from.
+ salt (byte string):
+ An 8 byte string to use for better protection from dictionary attacks.
+ This value does not need to be kept secret, but it should be randomly
+ chosen for each derivation.
+ dkLen (integer):
+ The length of the desired key, in bytes. A 16-byte key is, for
+ instance, suitable for :mod:`Crypto.Cipher.AES`.
+ count (integer):
+ The number of iterations to carry out. The recommendation is 1000 or
+ more.
+ hashAlgo (module):
+ The hash algorithm to use, as a module or an object from the :mod:`Crypto.Hash` package.
+ The digest length must be no shorter than ``dkLen``.
+ The default algorithm is :mod:`Crypto.Hash.SHA1`.
+
+ Return:
+ A byte string of length ``dkLen`` that can be used as key.
+ """
+
+ if not hashAlgo:
+ hashAlgo = SHA1
+ password = tobytes(password)
+ pHash = hashAlgo.new(password+salt)
+ digest = pHash.digest_size
+ if dkLen > digest:
+ raise TypeError("Selected hash algorithm has a too short digest (%d bytes)." % digest)
+ if len(salt) != 8:
+ raise ValueError("Salt is not 8 bytes long (%d bytes instead)." % len(salt))
+ for i in iter_range(count-1):
+ pHash = pHash.new(pHash.digest())
+ return pHash.digest()[:dkLen]
+
+
+def PBKDF2(password, salt, dkLen=16, count=1000, prf=None, hmac_hash_module=None):
+ """Derive one or more keys from a password (or passphrase).
+
+ This function performs key derivation according to the PKCS#5 standard (v2.0).
+
+ Args:
+ password (string or byte string):
+ The secret password to generate the key from.
+ salt (string or byte string):
+ A (byte) string to use for better protection from dictionary attacks.
+ This value does not need to be kept secret, but it should be randomly
+ chosen for each derivation. It is recommended to use at least 16 bytes.
+ dkLen (integer):
+ The cumulative length of the keys to produce.
+
+ Due to a flaw in the PBKDF2 design, you should not request more bytes
+ than the ``prf`` can output. For instance, ``dkLen`` should not exceed
+ 20 bytes in combination with ``HMAC-SHA1``.
+ count (integer):
+ The number of iterations to carry out. The higher the value, the slower
+ and the more secure the function becomes.
+
+ You should find the maximum number of iterations that keeps the
+ key derivation still acceptable on the slowest hardware you must support.
+
+ Although the default value is 1000, **it is recommended to use at least
+ 1000000 (1 million) iterations**.
+ prf (callable):
+ A pseudorandom function. It must be a function that returns a
+ pseudorandom byte string from two parameters: a secret and a salt.
+ The slower the algorithm, the more secure the derivation function.
+ If not specified, **HMAC-SHA1** is used.
+ hmac_hash_module (module):
+ A module from ``Crypto.Hash`` implementing a Merkle-Damgard cryptographic
+ hash, which PBKDF2 must use in combination with HMAC.
+ This parameter is mutually exclusive with ``prf``.
+
+ Return:
+ A byte string of length ``dkLen`` that can be used as key material.
+ If you want multiple keys, just break up this string into segments of the desired length.
+ """
+
+ password = tobytes(password)
+ salt = tobytes(salt)
+
+ if prf and hmac_hash_module:
+ raise ValueError("'prf' and 'hmac_hash_module' are mutually exlusive")
+
+ if prf is None and hmac_hash_module is None:
+ hmac_hash_module = SHA1
+
+ if prf or not hasattr(hmac_hash_module, "_pbkdf2_hmac_assist"):
+ # Generic (and slow) implementation
+
+ if prf is None:
+ prf = lambda p,s: HMAC.new(p, s, hmac_hash_module).digest()
+
+ def link(s):
+ s[0], s[1] = s[1], prf(password, s[1])
+ return s[0]
+
+ key = b''
+ i = 1
+ while len(key) < dkLen:
+ s = [ prf(password, salt + struct.pack(">I", i)) ] * 2
+ key += reduce(strxor, (link(s) for j in range(count)) )
+ i += 1
+
+ else:
+ # Optimized implementation
+ key = b''
+ i = 1
+ while len(key)<dkLen:
+ base = HMAC.new(password, b"", hmac_hash_module)
+ first_digest = base.copy().update(salt + struct.pack(">I", i)).digest()
+ key += base._pbkdf2_hmac_assist(first_digest, count)
+ i += 1
+
+ return key[:dkLen]
+
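+# Usage sketch (illustrative only, not part of the original module; the
+# password below is a made-up value):
+#
+#   salt = get_random_bytes(16)
+#   key = PBKDF2(b"correct horse battery staple", salt,
+#                dkLen=32, count=1000000, hmac_hash_module=SHA256)
+#   # 'key' is 32 bytes of key material, e.g. for an AES-256 cipher
+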
+
+class _S2V(object):
+ """String-to-vector PRF as defined in `RFC5297`_.
+
+ This class implements a pseudorandom function family
+ based on CMAC that takes as input a vector of strings.
+
+ .. _RFC5297: http://tools.ietf.org/html/rfc5297
+ """
+
+ def __init__(self, key, ciphermod, cipher_params=None):
+ """Initialize the S2V PRF.
+
+ :Parameters:
+ key : byte string
+ A secret that can be used as key for CMACs
+ based on ciphers from ``ciphermod``.
+ ciphermod : module
+ A block cipher module from `Crypto.Cipher`.
+ cipher_params : dictionary
+ A set of extra parameters to use to create a cipher instance.
+ """
+
+ self._key = _copy_bytes(None, None, key)
+ self._ciphermod = ciphermod
+ self._last_string = self._cache = b'\x00' * ciphermod.block_size
+
+ # Max number of update() call we can process
+ self._n_updates = ciphermod.block_size * 8 - 1
+
+ if cipher_params is None:
+ self._cipher_params = {}
+ else:
+ self._cipher_params = dict(cipher_params)
+
+ @staticmethod
+ def new(key, ciphermod):
+ """Create a new S2V PRF.
+
+ :Parameters:
+ key : byte string
+ A secret that can be used as key for CMACs
+ based on ciphers from ``ciphermod``.
+ ciphermod : module
+ A block cipher module from `Crypto.Cipher`.
+ """
+ return _S2V(key, ciphermod)
+
+ def _double(self, bs):
+ doubled = bytes_to_long(bs)<<1
+ if bord(bs[0]) & 0x80:
+ doubled ^= 0x87
+ return long_to_bytes(doubled, len(bs))[-len(bs):]
+
+ def update(self, item):
+ """Pass the next component of the vector.
+
+ The maximum number of components you can pass is equal to the block
+ length of the cipher (in bits) minus 1.
+
+ :Parameters:
+ item : byte string
+ The next component of the vector.
+ :Raise TypeError: when the limit on the number of components has been reached.
+ """
+
+ if self._n_updates == 0:
+ raise TypeError("Too many components passed to S2V")
+ self._n_updates -= 1
+
+ mac = CMAC.new(self._key,
+ msg=self._last_string,
+ ciphermod=self._ciphermod,
+ cipher_params=self._cipher_params)
+ self._cache = strxor(self._double(self._cache), mac.digest())
+ self._last_string = _copy_bytes(None, None, item)
+
+ def derive(self):
+ """"Derive a secret from the vector of components.
+
+ :Return: a byte string, as long as the block length of the cipher.
+ """
+
+ if len(self._last_string) >= 16:
+ # xorend
+ final = self._last_string[:-16] + strxor(self._last_string[-16:], self._cache)
+ else:
+ # zero-pad & xor
+ padded = (self._last_string + b'\x80' + b'\x00' * 15)[:16]
+ final = strxor(padded, self._double(self._cache))
+ mac = CMAC.new(self._key,
+ msg=final,
+ ciphermod=self._ciphermod,
+ cipher_params=self._cipher_params)
+ return mac.digest()
+
+
+def HKDF(master, key_len, salt, hashmod, num_keys=1, context=None):
+ """Derive one or more keys from a master secret using
+ the HMAC-based KDF defined in RFC5869_.
+
+ Args:
+ master (byte string):
+ The unguessable value used by the KDF to generate the other keys.
+ It must be a high-entropy secret, though not necessarily uniform.
+ It must not be a password.
+ salt (byte string):
+ A non-secret, reusable value that strengthens the randomness
+ extraction step.
+ Ideally, it is as long as the digest size of the chosen hash.
+ If empty, a string of zeroes is used.
+ key_len (integer):
+ The length in bytes of every derived key.
+ hashmod (module):
+ A cryptographic hash algorithm from :mod:`Crypto.Hash`.
+ :mod:`Crypto.Hash.SHA512` is a good choice.
+ num_keys (integer):
+ The number of keys to derive. Every key is :data:`key_len` bytes long.
+ The maximum cumulative length of all keys is
+ 255 times the digest size.
+ context (byte string):
+ Optional identifier describing what the keys are used for.
+
+ Return:
+ A byte string or a tuple of byte strings.
+
+ .. _RFC5869: http://tools.ietf.org/html/rfc5869
+ """
+
+ output_len = key_len * num_keys
+ if output_len > (255 * hashmod.digest_size):
+ raise ValueError("Too much secret data to derive")
+ if not salt:
+ salt = b'\x00' * hashmod.digest_size
+ if context is None:
+ context = b""
+
+ # Step 1: extract
+ hmac = HMAC.new(salt, master, digestmod=hashmod)
+ prk = hmac.digest()
+
+ # Step 2: expand
+ t = [ b"" ]
+ n = 1
+ tlen = 0
+ while tlen < output_len:
+ hmac = HMAC.new(prk, t[-1] + context + struct.pack('B', n), digestmod=hashmod)
+ t.append(hmac.digest())
+ tlen += hashmod.digest_size
+ n += 1
+ derived_output = b"".join(t)
+ if num_keys == 1:
+ return derived_output[:key_len]
+ kol = [derived_output[idx:idx + key_len]
+ for idx in iter_range(0, output_len, key_len)]
+ return list(kol[:num_keys])
+
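+# Usage sketch (illustrative only, not part of the original module;
+# 'master_secret' is a hypothetical high-entropy byte string):
+#
+#   from Crypto.Hash import SHA512
+#   salt = get_random_bytes(64)
+#   key1, key2 = HKDF(master_secret, 32, salt, SHA512, num_keys=2)
+#   # two independent 32-byte keys derived from the same master secret
+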
+
+
+def scrypt(password, salt, key_len, N, r, p, num_keys=1):
+ """Derive one or more keys from a passphrase.
+
+ Args:
+ password (string):
+ The secret pass phrase to generate the keys from.
+ salt (string):
+ A string to use for better protection from dictionary attacks.
+ This value does not need to be kept secret,
+ but it should be randomly chosen for each derivation.
+ It is recommended to be at least 16 bytes long.
+ key_len (integer):
+ The length in bytes of every derived key.
+ N (integer):
+ CPU/Memory cost parameter. It must be a power of 2 and less
+ than :math:`2^{32}`.
+ r (integer):
+ Block size parameter.
+ p (integer):
+ Parallelization parameter.
+ It must be no greater than :math:`(2^{32}-1)/(4r)`.
+ num_keys (integer):
+ The number of keys to derive. Every key is :data:`key_len` bytes long.
+ By default, only 1 key is generated.
+ The maximum cumulative length of all keys is :math:`(2^{32}-1)*32`
+ (that is, 128TB).
+
+ A good choice of parameters *(N, r, p)* was suggested
+ by Colin Percival in his `presentation in 2009`__:
+
+ - *( 2¹⁴, 8, 1 )* for interactive logins (≤100ms)
+ - *( 2²⁰, 8, 1 )* for file encryption (≤5s)
+
+ Return:
+ A byte string or a tuple of byte strings.
+
+ .. __: http://www.tarsnap.com/scrypt/scrypt-slides.pdf
+ """
+
+ if 2 ** (bit_size(N) - 1) != N:
+ raise ValueError("N must be a power of 2")
+ if N >= 2 ** 32:
+ raise ValueError("N is too big")
+ if p > ((2 ** 32 - 1) * 32) // (128 * r):
+ raise ValueError("p or r are too big")
+
+ prf_hmac_sha256 = lambda p, s: HMAC.new(p, s, SHA256).digest()
+
+ stage_1 = PBKDF2(password, salt, p * 128 * r, 1, prf=prf_hmac_sha256)
+
+ scryptROMix = _raw_scrypt_lib.scryptROMix
+ core = _raw_salsa20_lib.Salsa20_8_core
+
+ # Parallelize into p flows
+ data_out = []
+ for flow in iter_range(p):
+ idx = flow * 128 * r
+ buffer_out = create_string_buffer(128 * r)
+ result = scryptROMix(stage_1[idx : idx + 128 * r],
+ buffer_out,
+ c_size_t(128 * r),
+ N,
+ core)
+ if result:
+ raise ValueError("Error %X while running scrypt" % result)
+ data_out += [ get_raw_buffer(buffer_out) ]
+
+ dk = PBKDF2(password,
+ b"".join(data_out),
+ key_len * num_keys, 1,
+ prf=prf_hmac_sha256)
+
+ if num_keys == 1:
+ return dk
+
+ kol = [dk[idx:idx + key_len]
+ for idx in iter_range(0, key_len * num_keys, key_len)]
+ return kol
+
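+# Usage sketch (illustrative only, not part of the original module; the
+# parameters follow the interactive-login suggestion quoted in the docstring):
+#
+#   salt = get_random_bytes(16)
+#   key = scrypt(b"my passphrase", salt, key_len=32, N=2**14, r=8, p=1)
+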
+
+def _bcrypt_encode(data):
+ s = "./ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+
+ bits = []
+ for c in data:
+ bits_c = bin(bord(c))[2:].zfill(8)
+ bits.append(bstr(bits_c))
+ bits = b"".join(bits)
+
+ bits6 = [ bits[idx:idx+6] for idx in range(0, len(bits), 6) ]
+
+ result = []
+ for g in bits6[:-1]:
+ idx = int(g, 2)
+ result.append(s[idx])
+
+ g = bits6[-1]
+ idx = int(g, 2) << (6 - len(g))
+ result.append(s[idx])
+ result = "".join(result)
+
+ return tobytes(result)
+
+
+def _bcrypt_decode(data):
+ s = "./ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+
+ bits = []
+ for c in tostr(data):
+ idx = s.find(c)
+ bits6 = bin(idx)[2:].zfill(6)
+ bits.append(bits6)
+ bits = "".join(bits)
+
+ modulo4 = len(data) % 4
+ if modulo4 == 1:
+ raise ValueError("Incorrect length")
+ elif modulo4 == 2:
+ bits = bits[:-4]
+ elif modulo4 == 3:
+ bits = bits[:-2]
+
+ bits8 = [ bits[idx:idx+8] for idx in range(0, len(bits), 8) ]
+
+ result = []
+ for g in bits8:
+ result.append(bchr(int(g, 2)))
+ result = b"".join(result)
+
+ return result
+
+
+def _bcrypt_hash(password, cost, salt, constant, invert):
+ from Crypto.Cipher import _EKSBlowfish
+
+ if len(password) > 72:
+ raise ValueError("The password is too long. It must be 72 bytes at most.")
+
+ if not (4 <= cost <= 31):
+ raise ValueError("bcrypt cost factor must be in the range 4..31")
+
+ cipher = _EKSBlowfish.new(password, _EKSBlowfish.MODE_ECB, salt, cost, invert)
+ ctext = constant
+ for _ in range(64):
+ ctext = cipher.encrypt(ctext)
+ return ctext
+
+
+def bcrypt(password, cost, salt=None):
+ """Hash a password into a key, using the OpenBSD bcrypt protocol.
+
+ Args:
+ password (byte string or string):
+ The secret password or pass phrase.
+ It must be at most 72 bytes long.
+ It must not contain the zero byte.
+ Unicode strings will be encoded as UTF-8.
+ cost (integer):
+ The exponential factor that makes it slower to compute the hash.
+ It must be in the range 4 to 31.
+ A value of at least 12 is recommended.
+ salt (byte string):
+ Optional. Random byte string to thwart dictionary and rainbow table
+ attacks. It must be 16 bytes long.
+ If not passed, a random value is generated.
+
+ Return (byte string):
+ The bcrypt hash
+
+ Raises:
+ ValueError: if password is longer than 72 bytes or if it contains the zero byte
+
+ """
+
+ password = tobytes(password, "utf-8")
+
+ if password.find(bchr(0)[0]) != -1:
+ raise ValueError("The password contains the zero byte")
+
+ if len(password) < 72:
+ password += b"\x00"
+
+ if salt is None:
+ salt = get_random_bytes(16)
+ if len(salt) != 16:
+ raise ValueError("bcrypt salt must be 16 bytes long")
+
+ ctext = _bcrypt_hash(password, cost, salt, b"OrpheanBeholderScryDoubt", True)
+
+ cost_enc = b"$" + bstr(str(cost).zfill(2))
+ salt_enc = b"$" + _bcrypt_encode(salt)
+ hash_enc = _bcrypt_encode(ctext[:-1]) # only use 23 bytes, not 24
+ return b"$2a" + cost_enc + salt_enc + hash_enc
+
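+# Usage sketch (illustrative only, not part of the original module; the
+# password is a made-up value):
+#
+#   hashed = bcrypt(b"hunter2", cost=12)
+#   # 'hashed' is a 60-byte string of the form b"$2a$12$<22-char salt><31-char digest>"
+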
+
+def bcrypt_check(password, bcrypt_hash):
+ """Verify if the provided password matches the given bcrypt hash.
+
+ Args:
+ password (byte string or string):
+ The secret password or pass phrase to test.
+ It must be at most 72 bytes long.
+ It must not contain the zero byte.
+ Unicode strings will be encoded as UTF-8.
+ bcrypt_hash (byte string, bytearray):
+ The reference bcrypt hash the password needs to be checked against.
+
+ Raises:
+ ValueError: if the password does not match
+ """
+
+ bcrypt_hash = tobytes(bcrypt_hash)
+
+ if len(bcrypt_hash) != 60:
+ raise ValueError("Incorrect length of the bcrypt hash: %d bytes instead of 60" % len(bcrypt_hash))
+
+ if bcrypt_hash[:4] != b'$2a$':
+ raise ValueError("Unsupported prefix")
+
+ p = re.compile(br'\$2a\$([0-9][0-9])\$([A-Za-z0-9./]{22,22})([A-Za-z0-9./]{31,31})')
+ r = p.match(bcrypt_hash)
+ if not r:
+ raise ValueError("Incorrect bcrypt hash format")
+
+ cost = int(r.group(1))
+ if not (4 <= cost <= 31):
+ raise ValueError("Incorrect cost")
+
+ salt = _bcrypt_decode(r.group(2))
+
+ bcrypt_hash2 = bcrypt(password, cost, salt)
+
+ secret = get_random_bytes(16)
+
+ mac1 = BLAKE2s.new(digest_bits=160, key=secret, data=bcrypt_hash).digest()
+ mac2 = BLAKE2s.new(digest_bits=160, key=secret, data=bcrypt_hash2).digest()
+ if mac1 != mac2:
+ raise ValueError("Incorrect bcrypt hash")
diff --git a/lib/Crypto/Protocol/KDF.pyi b/lib/Crypto/Protocol/KDF.pyi
new file mode 100644
index 0000000..fb004bf
--- /dev/null
+++ b/lib/Crypto/Protocol/KDF.pyi
@@ -0,0 +1,24 @@
+from types import ModuleType
+from typing import Optional, Callable, Tuple, Union, Dict, Any
+
+RNG = Callable[[int], bytes]
+
+def PBKDF1(password: str, salt: bytes, dkLen: int, count: Optional[int]=1000, hashAlgo: Optional[ModuleType]=None) -> bytes: ...
+def PBKDF2(password: str, salt: bytes, dkLen: Optional[int]=16, count: Optional[int]=1000, prf: Optional[RNG]=None, hmac_hash_module: Optional[ModuleType]=None) -> bytes: ...
+
+class _S2V(object):
+ def __init__(self, key: bytes, ciphermod: ModuleType, cipher_params: Optional[Dict[Any, Any]]=None) -> None: ...
+
+ @staticmethod
+ def new(key: bytes, ciphermod: ModuleType) -> None: ...
+ def update(self, item: bytes) -> None: ...
+ def derive(self) -> bytes: ...
+
+def HKDF(master: bytes, key_len: int, salt: bytes, hashmod: ModuleType, num_keys: Optional[int]=1, context: Optional[bytes]=None) -> Union[bytes, Tuple[bytes, ...]]: ...
+
+def scrypt(password: str, salt: str, key_len: int, N: int, r: int, p: int, num_keys: Optional[int]=1) -> Union[bytes, Tuple[bytes, ...]]: ...
+
+def _bcrypt_decode(data: bytes) -> bytes: ...
+def _bcrypt_hash(password: bytes, cost: int, salt: bytes, constant: bytes, invert: bool) -> bytes: ...
+def bcrypt(password: Union[bytes, str], cost: int, salt: Optional[bytes]=None) -> bytes: ...
+def bcrypt_check(password: Union[bytes, str], bcrypt_hash: Union[bytes, bytearray, str]) -> None: ...
diff --git a/lib/Crypto/Protocol/SecretSharing.py b/lib/Crypto/Protocol/SecretSharing.py
new file mode 100644
index 0000000..a757e7c
--- /dev/null
+++ b/lib/Crypto/Protocol/SecretSharing.py
@@ -0,0 +1,278 @@
+#
+# SecretSharing.py : distribute a secret amongst a group of participants
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import is_native_int
+from Crypto.Util import number
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+from Crypto.Random import get_random_bytes as rng
+
+
+def _mult_gf2(f1, f2):
+ """Multiply two polynomials in GF(2)"""
+
+ # Ensure f2 is the smallest
+ if f2 > f1:
+ f1, f2 = f2, f1
+ z = 0
+ while f2:
+ if f2 & 1:
+ z ^= f1
+ f1 <<= 1
+ f2 >>= 1
+ return z
+
+
+def _div_gf2(a, b):
+ """
+ Compute division of polynomials over GF(2).
+ Given a and b, it finds two polynomials q and r such that:
+
+ a = b*q + r with deg(r)<deg(b)
+ """
+
+ if (a < b):
+ return 0, a
+
+ deg = number.size
+ q = 0
+ r = a
+ d = deg(b)
+ while deg(r) >= d:
+ s = 1 << (deg(r) - d)
+ q ^= s
+ r ^= _mult_gf2(b, s)
+ return (q, r)
+
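+# Worked example (illustrative, not in the original source): in GF(2),
+# (x^2 + 1) * (x + 1) = x^3 + x^2 + x + 1, i.e. _mult_gf2(0b101, 0b11) == 0b1111,
+# and _div_gf2(0b1111, 0b101) returns (0b11, 0): quotient x + 1, remainder 0.
+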
+
+class _Element(object):
+ """Element of GF(2^128) field"""
+
+ # The irreducible polynomial defining this field is 1+x+x^2+x^7+x^128
+ irr_poly = 1 + 2 + 4 + 128 + 2 ** 128
+
+ def __init__(self, encoded_value):
+ """Initialize the element to a certain value.
+
+ The value passed as parameter is internally encoded as
+ a 128-bit integer, where each bit represents a polynomial
+ coefficient. The LSB is the constant coefficient.
+ """
+
+ if is_native_int(encoded_value):
+ self._value = encoded_value
+ elif len(encoded_value) == 16:
+ self._value = bytes_to_long(encoded_value)
+ else:
+ raise ValueError("The encoded value must be an integer or a 16 byte string")
+
+ def __eq__(self, other):
+ return self._value == other._value
+
+ def __int__(self):
+ """Return the field element, encoded as a 128-bit integer."""
+ return self._value
+
+ def encode(self):
+ """Return the field element, encoded as a 16 byte string."""
+ return long_to_bytes(self._value, 16)
+
+ def __mul__(self, factor):
+
+ f1 = self._value
+ f2 = factor._value
+
+ # Make sure that f2 is the smallest, to speed up the loop
+ if f2 > f1:
+ f1, f2 = f2, f1
+
+ if self.irr_poly in (f1, f2):
+ return _Element(0)
+
+ mask1 = 2 ** 128
+ v, z = f1, 0
+ while f2:
+ # if f2 ^ 1: z ^= v
+ mask2 = int(bin(f2 & 1)[2:] * 128, base=2)
+ z = (mask2 & (z ^ v)) | ((mask1 - mask2 - 1) & z)
+ v <<= 1
+ # if v & mask1: v ^= self.irr_poly
+ mask3 = int(bin((v >> 128) & 1)[2:] * 128, base=2)
+ v = (mask3 & (v ^ self.irr_poly)) | ((mask1 - mask3 - 1) & v)
+ f2 >>= 1
+ return _Element(z)
+
+ def __add__(self, term):
+ return _Element(self._value ^ term._value)
+
+ def inverse(self):
+ """Return the inverse of this element in GF(2^128)."""
+
+ # We use the Extended GCD algorithm
+ # http://en.wikipedia.org/wiki/Polynomial_greatest_common_divisor
+
+ if self._value == 0:
+ raise ValueError("Inversion of zero")
+
+ r0, r1 = self._value, self.irr_poly
+ s0, s1 = 1, 0
+ while r1 > 0:
+ q = _div_gf2(r0, r1)[0]
+ r0, r1 = r1, r0 ^ _mult_gf2(q, r1)
+ s0, s1 = s1, s0 ^ _mult_gf2(q, s1)
+ return _Element(s0)
+
+ def __pow__(self, exponent):
+ result = _Element(self._value)
+ for _ in range(exponent - 1):
+ result = result * self
+ return result
+
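+# Property sketch (illustrative, not in the original source): every non-zero
+# element of GF(2^128) has a multiplicative inverse, so for any non-zero
+# 16-byte value v:
+#
+#   assert _Element(v) * _Element(v).inverse() == _Element(1)
+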
+
+class Shamir(object):
+ """Shamir's secret sharing scheme.
+
+ A secret is split into ``n`` shares, and it is sufficient to collect
+ ``k`` of them to reconstruct the secret.
+ """
+
+ @staticmethod
+ def split(k, n, secret, ssss=False):
+ """Split a secret into ``n`` shares.
+
+ The secret can be reconstructed later using just ``k`` shares
+ out of the original ``n``.
+ Each share must be kept confidential to the person it was
+ assigned to.
+
+ Each share is associated to an index (starting from 1).
+
+ Args:
+ k (integer):
+ The number of shares sufficient to reconstruct the secret (``k < n``).
+ n (integer):
+ The number of shares that this method will create.
+ secret (byte string):
+ A byte string of 16 bytes (e.g. the AES 128 key).
+ ssss (bool):
+ If ``True``, the shares can be used with the ``ssss`` utility.
+ Default: ``False``.
+
+ Return (tuples):
+ ``n`` tuples, one for each participant, each containing two items:
+
+ 1. the unique index (an integer)
+ 2. the share (a byte string, 16 bytes)
+ """
+
+ #
+ # We create a polynomial with random coefficients in GF(2^128):
+ #
+ # p(x) = \sum_{i=0}^{k-1} c_i * x^i
+ #
+ # c_0 is the encoded secret
+ #
+
+ coeffs = [_Element(rng(16)) for i in range(k - 1)]
+ coeffs.append(_Element(secret))
+
+ # Each share is y_i = p(x_i) where x_i is the public index
+ # associated to each of the n users.
+
+ def make_share(user, coeffs, ssss):
+ idx = _Element(user)
+ share = _Element(0)
+ for coeff in coeffs:
+ share = idx * share + coeff
+ if ssss:
+ share += _Element(user) ** len(coeffs)
+ return share.encode()
+
+ return [(i, make_share(i, coeffs, ssss)) for i in range(1, n + 1)]
+
+ @staticmethod
+ def combine(shares, ssss=False):
+ """Recombine a secret, if enough shares are presented.
+
+ Args:
+ shares (tuples):
+ The *k* tuples, each containing the index (an integer) and
+ the share (a byte string, 16 bytes long) that were assigned to
+ a participant.
+ ssss (bool):
+ If ``True``, the shares were produced by the ``ssss`` utility.
+ Default: ``False``.
+
+ Return:
+ The original secret, as a byte string (16 bytes long).
+ """
+
+ #
+ # Given k points (x,y), the interpolation polynomial of degree k-1 is:
+ #
+ # L(x) = \sum_{j=0}^{k-1} y_i * l_j(x)
+ #
+ # where:
+ #
+ # l_j(x) = \prod_{ \overset{0 \le m \le k-1}{m \ne j} }
+ # \frac{x - x_m}{x_j - x_m}
+ #
+ # However, in this case we are purely interested in the constant
+ # coefficient of L(x).
+ #
+
+ k = len(shares)
+
+ gf_shares = []
+ for x in shares:
+ idx = _Element(x[0])
+ value = _Element(x[1])
+ if any(y[0] == idx for y in gf_shares):
+ raise ValueError("Duplicate share")
+ if ssss:
+ value += idx ** k
+ gf_shares.append((idx, value))
+
+ result = _Element(0)
+ for j in range(k):
+ x_j, y_j = gf_shares[j]
+
+ numerator = _Element(1)
+ denominator = _Element(1)
+
+ for m in range(k):
+ x_m = gf_shares[m][0]
+ if m != j:
+ numerator *= x_m
+ denominator *= x_j + x_m
+ result += y_j * numerator * denominator.inverse()
+ return result.encode()
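+
+ # Round-trip sketch (illustrative only, not part of the original source):
+ #
+ #   key = rng(16)                              # i.e. get_random_bytes(16)
+ #   shares = Shamir.split(2, 5, key)           # 5 shares, any 2 reconstruct
+ #   assert Shamir.combine(shares[:2]) == key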
diff --git a/lib/Crypto/Protocol/SecretSharing.pyi b/lib/Crypto/Protocol/SecretSharing.pyi
new file mode 100644
index 0000000..5952c99
--- /dev/null
+++ b/lib/Crypto/Protocol/SecretSharing.pyi
@@ -0,0 +1,22 @@
+from typing import Union, List, Tuple, Optional
+
+def _mult_gf2(f1: int, f2: int) -> int: ...
+def _div_gf2(a: int, b: int) -> Tuple[int, int]: ...
+
+class _Element(object):
+ irr_poly: int
+ def __init__(self, encoded_value: Union[int, bytes]) -> None: ...
+ def __eq__(self, other) -> bool: ...
+ def __int__(self) -> int: ...
+ def encode(self) -> bytes: ...
+ def __mul__(self, factor: int) -> _Element: ...
+ def __add__(self, term: _Element) -> _Element: ...
+ def inverse(self) -> _Element: ...
+ def __pow__(self, exponent) -> _Element: ...
+
+class Shamir(object):
+ @staticmethod
+ def split(k: int, n: int, secret: bytes, ssss: Optional[bool]) -> List[Tuple[int, bytes]]: ...
+ @staticmethod
+ def combine(shares: List[Tuple[int, bytes]], ssss: Optional[bool]) -> bytes: ...
+
diff --git a/lib/Crypto/Protocol/__init__.py b/lib/Crypto/Protocol/__init__.py
new file mode 100644
index 0000000..efdf034
--- /dev/null
+++ b/lib/Crypto/Protocol/__init__.py
@@ -0,0 +1,31 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+__all__ = ['KDF', 'SecretSharing']
diff --git a/lib/Crypto/Protocol/__init__.pyi b/lib/Crypto/Protocol/__init__.pyi
new file mode 100644
index 0000000..377ed90
--- /dev/null
+++ b/lib/Crypto/Protocol/__init__.pyi
@@ -0,0 +1 @@
+__all__ = ['KDF.pyi', 'SecretSharing.pyi']
diff --git a/lib/Crypto/Protocol/_scrypt.abi3.so b/lib/Crypto/Protocol/_scrypt.abi3.so
new file mode 100755
index 0000000..3cf3eff
--- /dev/null
+++ b/lib/Crypto/Protocol/_scrypt.abi3.so
Binary files differ
diff --git a/lib/Crypto/PublicKey/DSA.py b/lib/Crypto/PublicKey/DSA.py
new file mode 100644
index 0000000..4c7f47b
--- /dev/null
+++ b/lib/Crypto/PublicKey/DSA.py
@@ -0,0 +1,682 @@
+# -*- coding: utf-8 -*-
+#
+# PublicKey/DSA.py : DSA signature primitive
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+__all__ = ['generate', 'construct', 'DsaKey', 'import_key' ]
+
+import binascii
+import struct
+import itertools
+
+from Crypto.Util.py3compat import bchr, bord, tobytes, tostr, iter_range
+
+from Crypto import Random
+from Crypto.IO import PKCS8, PEM
+from Crypto.Hash import SHA256
+from Crypto.Util.asn1 import (
+ DerObject, DerSequence,
+ DerInteger, DerObjectId,
+ DerBitString,
+ )
+
+from Crypto.Math.Numbers import Integer
+from Crypto.Math.Primality import (test_probable_prime, COMPOSITE,
+ PROBABLY_PRIME)
+
+from Crypto.PublicKey import (_expand_subject_public_key_info,
+ _create_subject_public_key_info,
+ _extract_subject_public_key_info)
+
+# ; The following ASN.1 types are relevant for DSA
+#
+# SubjectPublicKeyInfo ::= SEQUENCE {
+# algorithm AlgorithmIdentifier,
+# subjectPublicKey BIT STRING
+# }
+#
+# id-dsa ID ::= { iso(1) member-body(2) us(840) x9-57(10040) x9cm(4) 1 }
+#
+# ; See RFC3279
+# Dss-Parms ::= SEQUENCE {
+# p INTEGER,
+# q INTEGER,
+# g INTEGER
+# }
+#
+# DSAPublicKey ::= INTEGER
+#
+# DSSPrivateKey_OpenSSL ::= SEQUENCE {
+# version INTEGER,
+# p INTEGER,
+# q INTEGER,
+# g INTEGER,
+# y INTEGER,
+# x INTEGER
+# }
+#
+
+class DsaKey(object):
+ r"""Class defining an actual DSA key.
+ Do not instantiate directly.
+ Use :func:`generate`, :func:`construct` or :func:`import_key` instead.
+
+ :ivar p: DSA modulus
+ :vartype p: integer
+
+ :ivar q: Order of the subgroup
+ :vartype q: integer
+
+ :ivar g: Generator
+ :vartype g: integer
+
+ :ivar y: Public key
+ :vartype y: integer
+
+ :ivar x: Private key
+ :vartype x: integer
+
+ :undocumented: exportKey, publickey
+ """
+
+ _keydata = ['y', 'g', 'p', 'q', 'x']
+
+ def __init__(self, key_dict):
+ input_set = set(key_dict.keys())
+ public_set = set(('y' , 'g', 'p', 'q'))
+ if not public_set.issubset(input_set):
+ raise ValueError("Some DSA components are missing = %s" %
+ str(public_set - input_set))
+ extra_set = input_set - public_set
+ if extra_set and extra_set != set(('x',)):
+ raise ValueError("Unknown DSA components = %s" %
+ str(extra_set - set(('x',))))
+ self._key = dict(key_dict)
+
+ def _sign(self, m, k):
+ if not self.has_private():
+ raise TypeError("DSA public key cannot be used for signing")
+ if not (1 < k < self.q):
+ raise ValueError("k is not between 2 and q-1")
+
+ x, q, p, g = [self._key[comp] for comp in ['x', 'q', 'p', 'g']]
+
+ blind_factor = Integer.random_range(min_inclusive=1,
+ max_exclusive=q)
+ inv_blind_k = (blind_factor * k).inverse(q)
+ blind_x = x * blind_factor
+
+ r = pow(g, k, p) % q # r = (g**k mod p) mod q
+ s = (inv_blind_k * (blind_factor * m + blind_x * r)) % q
+ return map(int, (r, s))
+
+ def _verify(self, m, sig):
+ r, s = sig
+ y, q, p, g = [self._key[comp] for comp in ['y', 'q', 'p', 'g']]
+ if not (0 < r < q) or not (0 < s < q):
+ return False
+ w = Integer(s).inverse(q)
+ u1 = (w * m) % q
+ u2 = (w * r) % q
+ v = (pow(g, u1, p) * pow(y, u2, p) % p) % q
+ return v == r
+
+ def has_private(self):
+ """Whether this is a DSA private key"""
+
+ return 'x' in self._key
+
+ def can_encrypt(self): # legacy
+ return False
+
+ def can_sign(self): # legacy
+ return True
+
+ def public_key(self):
+ """A matching DSA public key.
+
+ Returns:
+ a new :class:`DsaKey` object
+ """
+
+ public_components = dict((k, self._key[k]) for k in ('y', 'g', 'p', 'q'))
+ return DsaKey(public_components)
+
+ def __eq__(self, other):
+ if bool(self.has_private()) != bool(other.has_private()):
+ return False
+
+ result = True
+ for comp in self._keydata:
+ result = result and (self._key.get(comp, None) ==
+ other._key.get(comp, None))
+ return result
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __getstate__(self):
+ # DSA key is not picklable
+ from pickle import PicklingError
+ raise PicklingError
+
+ def domain(self):
+ """The DSA domain parameters.
+
+ Returns:
+ tuple : (p, q, g)
+ """
+
+ return tuple(int(self._key[comp]) for comp in ('p', 'q', 'g'))
+
+ def __repr__(self):
+ attrs = []
+ for k in self._keydata:
+ if k == 'p':
+ bits = Integer(self.p).size_in_bits()
+ attrs.append("p(%d)" % (bits,))
+ elif hasattr(self, k):
+ attrs.append(k)
+ if self.has_private():
+ attrs.append("private")
+ # PY3K: This is meant to be text, do not change to bytes (data)
+ return "<%s @0x%x %s>" % (self.__class__.__name__, id(self), ",".join(attrs))
+
+ def __getattr__(self, item):
+ try:
+ return int(self._key[item])
+ except KeyError:
+ raise AttributeError(item)
+
+ def export_key(self, format='PEM', pkcs8=None, passphrase=None,
+ protection=None, randfunc=None):
+ """Export this DSA key.
+
+ Args:
+ format (string):
+ The encoding for the output:
+
+ - *'PEM'* (default). ASCII as per `RFC1421`_/ `RFC1423`_.
+ - *'DER'*. Binary ASN.1 encoding.
+ - *'OpenSSH'*. ASCII one-liner as per `RFC4253`_.
+ Only suitable for public keys, not for private keys.
+
+ passphrase (string):
+ *Private keys only*. The pass phrase to protect the output.
+
+ pkcs8 (boolean):
+ *Private keys only*. If ``True`` (default), the key is encoded
+ with `PKCS#8`_. If ``False``, it is encoded in the custom
+ OpenSSL/OpenSSH container.
+
+ protection (string):
+ *Only in combination with a pass phrase*.
+ The encryption scheme to use to protect the output.
+
+ If :data:`pkcs8` takes value ``True``, this is the PKCS#8
+ algorithm to use for deriving the secret and encrypting
+ the private DSA key.
+ For a complete list of algorithms, see :mod:`Crypto.IO.PKCS8`.
+ The default is *PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC*.
+
+ If :data:`pkcs8` is ``False``, the obsolete PEM encryption scheme is
+ used. It is based on MD5 for key derivation, and Triple DES for
+ encryption. Parameter :data:`protection` is then ignored.
+
+ The combination ``format='DER'`` and ``pkcs8=False`` is not allowed
+ if a passphrase is present.
+
+ randfunc (callable):
+ A function that returns random bytes.
+ By default it is :func:`Crypto.Random.get_random_bytes`.
+
+ Returns:
+ byte string : the encoded key
+
+ Raises:
+ ValueError : when the format is unknown, or when you try to encrypt a private
+ key exported in *DER* format with the legacy OpenSSL/OpenSSH container.
+
+ .. warning::
+ If you don't provide a pass phrase, the private key will be
+ exported in the clear!
+
+ .. _RFC1421: http://www.ietf.org/rfc/rfc1421.txt
+ .. _RFC1423: http://www.ietf.org/rfc/rfc1423.txt
+ .. _RFC4253: http://www.ietf.org/rfc/rfc4253.txt
+ .. _`PKCS#8`: http://www.ietf.org/rfc/rfc5208.txt
+ """
+
+ if passphrase is not None:
+ passphrase = tobytes(passphrase)
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ if format == 'OpenSSH':
+ tup1 = [self._key[x].to_bytes() for x in ('p', 'q', 'g', 'y')]
+
+ def func(x):
+ if (bord(x[0]) & 0x80):
+ return bchr(0) + x
+ else:
+ return x
+
+ tup2 = [func(x) for x in tup1]
+ keyparts = [b'ssh-dss'] + tup2
+ keystring = b''.join(
+ [struct.pack(">I", len(kp)) + kp for kp in keyparts]
+ )
+ return b'ssh-dss ' + binascii.b2a_base64(keystring)[:-1]
+
+ # DER format is always used, even in case of PEM, which simply
+ # encodes it into BASE64.
+ params = DerSequence([self.p, self.q, self.g])
+ if self.has_private():
+ if pkcs8 is None:
+ pkcs8 = True
+ if pkcs8:
+ if not protection:
+ protection = 'PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC'
+ private_key = DerInteger(self.x).encode()
+ binary_key = PKCS8.wrap(
+ private_key, oid, passphrase,
+ protection, key_params=params,
+ randfunc=randfunc
+ )
+ if passphrase:
+ key_type = 'ENCRYPTED PRIVATE'
+ else:
+ key_type = 'PRIVATE'
+ passphrase = None
+ else:
+ if format != 'PEM' and passphrase:
+ raise ValueError("DSA private key cannot be encrypted")
+ ints = [0, self.p, self.q, self.g, self.y, self.x]
+ binary_key = DerSequence(ints).encode()
+ key_type = "DSA PRIVATE"
+ else:
+ if pkcs8:
+ raise ValueError("PKCS#8 is only meaningful for private keys")
+
+ binary_key = _create_subject_public_key_info(oid,
+ DerInteger(self.y), params)
+ key_type = "PUBLIC"
+
+ if format == 'DER':
+ return binary_key
+ if format == 'PEM':
+ pem_str = PEM.encode(
+ binary_key, key_type + " KEY",
+ passphrase, randfunc
+ )
+ return tobytes(pem_str)
+ raise ValueError("Unknown key format '%s'. Cannot export the DSA key." % format)
+
+ # Backward-compatibility
+ exportKey = export_key
+ publickey = public_key
+
+ # Methods defined in PyCrypto that we don't support anymore
+
+ def sign(self, M, K):
+ raise NotImplementedError("Use module Crypto.Signature.DSS instead")
+
+ def verify(self, M, signature):
+ raise NotImplementedError("Use module Crypto.Signature.DSS instead")
+
+ def encrypt(self, plaintext, K):
+ raise NotImplementedError
+
+ def decrypt(self, ciphertext):
+ raise NotImplementedError
+
+ def blind(self, M, B):
+ raise NotImplementedError
+
+ def unblind(self, M, B):
+ raise NotImplementedError
+
+ def size(self):
+ raise NotImplementedError
+
+
+def _generate_domain(L, randfunc):
+ """Generate a new set of DSA domain parameters"""
+
+ N = { 1024:160, 2048:224, 3072:256 }.get(L)
+ if N is None:
+ raise ValueError("Invalid modulus length (%d)" % L)
+
+ outlen = SHA256.digest_size * 8
+ n = (L + outlen - 1) // outlen - 1 # ceil(L/outlen) -1
+ b_ = L - 1 - (n * outlen)
+
+ # Generate q (A.1.1.2)
+ q = Integer(4)
+ upper_bit = 1 << (N - 1)
+ while test_probable_prime(q, randfunc) != PROBABLY_PRIME:
+ seed = randfunc(64)
+ U = Integer.from_bytes(SHA256.new(seed).digest()) & (upper_bit - 1)
+ q = U | upper_bit | 1
+
+ assert(q.size_in_bits() == N)
+
+ # Generate p (A.1.1.2)
+ offset = 1
+ upper_bit = 1 << (L - 1)
+ while True:
+ V = [ SHA256.new(seed + Integer(offset + j).to_bytes()).digest()
+ for j in iter_range(n + 1) ]
+ V = [ Integer.from_bytes(v) for v in V ]
+ W = sum([V[i] * (1 << (i * outlen)) for i in iter_range(n)],
+ (V[n] & ((1 << b_) - 1)) * (1 << (n * outlen)))
+
+ X = Integer(W + upper_bit) # 2^{L-1} < X < 2^{L}
+ assert(X.size_in_bits() == L)
+
+ c = X % (q * 2)
+ p = X - (c - 1) # 2q divides (p-1)
+ if p.size_in_bits() == L and \
+ test_probable_prime(p, randfunc) == PROBABLY_PRIME:
+ break
+ offset += n + 1
+
+ # Generate g (A.2.3, index=1)
+ e = (p - 1) // q
+ for count in itertools.count(1):
+ U = seed + b"ggen" + bchr(1) + Integer(count).to_bytes()
+ W = Integer.from_bytes(SHA256.new(U).digest())
+ g = pow(W, e, p)
+ if g != 1:
+ break
+
+ return (p, q, g, seed)
+
+
+def generate(bits, randfunc=None, domain=None):
+ """Generate a new DSA key pair.
+
+ The algorithm follows Appendix A.1/A.2 and B.1 of `FIPS 186-4`_,
+ respectively for domain generation and key pair generation.
+
+ Args:
+ bits (integer):
+ Key length, or size (in bits) of the DSA modulus *p*.
+ It must be 1024, 2048 or 3072.
+
+ randfunc (callable):
+ Random number generation function; it accepts a single integer N
+ and returns a string of random data N bytes long.
+ If not specified, :func:`Crypto.Random.get_random_bytes` is used.
+
+ domain (tuple):
+ The DSA domain parameters *p*, *q* and *g* as a list of 3
+ integers. Size of *p* and *q* must comply to `FIPS 186-4`_.
+ If not specified, the parameters are created anew.
+
+ Returns:
+ :class:`DsaKey` : a new DSA key object
+
+ Raises:
+ ValueError : when **bits** is not one of 1024, 2048 or 3072.
+
+ .. _FIPS 186-4: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
+ """
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ if domain:
+ p, q, g = map(Integer, domain)
+
+ ## Perform consistency check on domain parameters
+ # P and Q must be prime
+ fmt_error = test_probable_prime(p) == COMPOSITE
+ fmt_error |= test_probable_prime(q) == COMPOSITE
+ # Verify Lagrange's theorem for sub-group
+ fmt_error |= ((p - 1) % q) != 0
+ fmt_error |= g <= 1 or g >= p
+ fmt_error |= pow(g, q, p) != 1
+ if fmt_error:
+ raise ValueError("Invalid DSA domain parameters")
+ else:
+ p, q, g, _ = _generate_domain(bits, randfunc)
+
+ L = p.size_in_bits()
+ N = q.size_in_bits()
+
+ if L != bits:
+ raise ValueError("Mismatch between size of modulus (%d)"
+ " and 'bits' parameter (%d)" % (L, bits))
+
+ if (L, N) not in [(1024, 160), (2048, 224),
+ (2048, 256), (3072, 256)]:
+ raise ValueError("Lengths of p and q (%d, %d) are not compatible"
+ "to FIPS 186-3" % (L, N))
+
+ if not 1 < g < p:
+ raise ValueError("Incorrent DSA generator")
+
+ # B.1.1
+ c = Integer.random(exact_bits=N + 64, randfunc=randfunc)
+ x = c % (q - 1) + 1 # 1 <= x <= q-1
+ y = pow(g, x, p)
+
+ key_dict = { 'y':y, 'g':g, 'p':p, 'q':q, 'x':x }
+ return DsaKey(key_dict)
+
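+# Usage sketch (editor's example): generate a key pair and sign a message with
+# the companion Crypto.Signature.DSS module referenced elsewhere in this file.
+#
+# from Crypto.PublicKey import DSA
+# from Crypto.Signature import DSS
+# from Crypto.Hash import SHA256
+#
+# key = DSA.generate(2048)
+# h = SHA256.new(b"message to sign")
+# signature = DSS.new(key, 'fips-186-3').sign(h)
+# DSS.new(key.public_key(), 'fips-186-3').verify(h, signature) # raises ValueError on failure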
+
+def construct(tup, consistency_check=True):
+ """Construct a DSA key from a tuple of valid DSA components.
+
+ Args:
+ tup (tuple):
+ A tuple of long integers, with 4 or 5 items
+ in the following order:
+
+ 1. Public key (*y*).
+ 2. Sub-group generator (*g*).
+ 3. Modulus, finite field order (*p*).
+ 4. Sub-group order (*q*).
+ 5. Private key (*x*). Optional.
+
+ consistency_check (boolean):
+ If ``True``, the library will verify that the provided components
+ fulfil the main DSA properties.
+
+ Raises:
+ ValueError: when the key being imported fails the most basic DSA validity checks.
+
+ Returns:
+ :class:`DsaKey` : a DSA key object
+ """
+
+ key_dict = dict(zip(('y', 'g', 'p', 'q', 'x'), map(Integer, tup)))
+ key = DsaKey(key_dict)
+
+ fmt_error = False
+ if consistency_check:
+ # P and Q must be prime
+ fmt_error = test_probable_prime(key.p) == COMPOSITE
+ fmt_error |= test_probable_prime(key.q) == COMPOSITE
+ # Verify Lagrange's theorem for sub-group
+ fmt_error |= ((key.p - 1) % key.q) != 0
+ fmt_error |= key.g <= 1 or key.g >= key.p
+ fmt_error |= pow(key.g, key.q, key.p) != 1
+ # Public key
+ fmt_error |= key.y <= 0 or key.y >= key.p
+ if hasattr(key, 'x'):
+ fmt_error |= key.x <= 0 or key.x >= key.q
+ fmt_error |= pow(key.g, key.x, key.p) != key.y
+
+ if fmt_error:
+ raise ValueError("Invalid DSA key components")
+
+ return key
+
+
+# Dss-Parms ::= SEQUENCE {
+# p INTEGER,
+# q INTEGER,
+# g INTEGER
+# }
+# DSAPublicKey ::= INTEGER -- public key, y
+
+def _import_openssl_private(encoded, passphrase, params):
+ if params:
+ raise ValueError("DSA private key already comes with parameters")
+ der = DerSequence().decode(encoded, nr_elements=6, only_ints_expected=True)
+ if der[0] != 0:
+ raise ValueError("No version found")
+ tup = [der[comp] for comp in (4, 3, 1, 2, 5)]
+ return construct(tup)
+
+
+def _import_subjectPublicKeyInfo(encoded, passphrase, params):
+
+ algoid, encoded_key, emb_params = _expand_subject_public_key_info(encoded)
+ if algoid != oid:
+ raise ValueError("No DSA subjectPublicKeyInfo")
+ if params and emb_params:
+ raise ValueError("Too many DSA parameters")
+
+ y = DerInteger().decode(encoded_key).value
+ p, q, g = list(DerSequence().decode(params or emb_params))
+ tup = (y, g, p, q)
+ return construct(tup)
+
+
+def _import_x509_cert(encoded, passphrase, params):
+
+ sp_info = _extract_subject_public_key_info(encoded)
+ return _import_subjectPublicKeyInfo(sp_info, None, params)
+
+
+def _import_pkcs8(encoded, passphrase, params):
+ if params:
+ raise ValueError("PKCS#8 already includes parameters")
+ k = PKCS8.unwrap(encoded, passphrase)
+ if k[0] != oid:
+ raise ValueError("No PKCS#8 encoded DSA key")
+ x = DerInteger().decode(k[1]).value
+ p, q, g = list(DerSequence().decode(k[2]))
+ tup = (pow(g, x, p), g, p, q, x)
+ return construct(tup)
+
+
+def _import_key_der(key_data, passphrase, params):
+ """Import a DSA key (public or private half), encoded in DER form."""
+
+ decodings = (_import_openssl_private,
+ _import_subjectPublicKeyInfo,
+ _import_x509_cert,
+ _import_pkcs8)
+
+ for decoding in decodings:
+ try:
+ return decoding(key_data, passphrase, params)
+ except ValueError:
+ pass
+
+ raise ValueError("DSA key format is not supported")
+
+
+def import_key(extern_key, passphrase=None):
+ """Import a DSA key.
+
+ Args:
+ extern_key (string or byte string):
+ The DSA key to import.
+
+ The following formats are supported for a DSA **public** key:
+
+ - X.509 certificate (binary DER or PEM)
+ - X.509 ``subjectPublicKeyInfo`` (binary DER or PEM)
+ - OpenSSH (ASCII one-liner, see `RFC4253`_)
+
+ The following formats are supported for a DSA **private** key:
+
+ - `PKCS#8`_ ``PrivateKeyInfo`` or ``EncryptedPrivateKeyInfo``
+ DER SEQUENCE (binary or PEM)
+ - OpenSSL/OpenSSH custom format (binary or PEM)
+
+ For details about the PEM encoding, see `RFC1421`_/`RFC1423`_.
+
+ passphrase (string):
+ In case of an encrypted private key, this is the pass phrase
+ from which the decryption key is derived.
+
+ Encryption may be applied either at the `PKCS#8`_ or at the PEM level.
+
+ Returns:
+ :class:`DsaKey` : a DSA key object
+
+ Raises:
+ ValueError : when the given key cannot be parsed (possibly because
+ the pass phrase is wrong).
+
+ .. _RFC1421: http://www.ietf.org/rfc/rfc1421.txt
+ .. _RFC1423: http://www.ietf.org/rfc/rfc1423.txt
+ .. _RFC4253: http://www.ietf.org/rfc/rfc4253.txt
+ .. _PKCS#8: http://www.ietf.org/rfc/rfc5208.txt
+ """
+
+ extern_key = tobytes(extern_key)
+ if passphrase is not None:
+ passphrase = tobytes(passphrase)
+
+ if extern_key.startswith(b'-----'):
+ # This is probably a PEM encoded key
+ (der, marker, enc_flag) = PEM.decode(tostr(extern_key), passphrase)
+ if enc_flag:
+ passphrase = None
+ return _import_key_der(der, passphrase, None)
+
+ if extern_key.startswith(b'ssh-dss '):
+ # This is probably a public OpenSSH key
+ keystring = binascii.a2b_base64(extern_key.split(b' ')[1])
+ keyparts = []
+ while len(keystring) > 4:
+ length = struct.unpack(">I", keystring[:4])[0]
+ keyparts.append(keystring[4:4 + length])
+ keystring = keystring[4 + length:]
+ if keyparts[0] == b"ssh-dss":
+ tup = [Integer.from_bytes(keyparts[x]) for x in (4, 3, 1, 2)]
+ return construct(tup)
+
+ if len(extern_key) > 0 and bord(extern_key[0]) == 0x30:
+ # This is probably a DER encoded key
+ return _import_key_der(extern_key, passphrase, None)
+
+ raise ValueError("DSA key format is not supported")
+
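+# Usage sketch (editor's example): export a public key and read it back in.
+#
+# key = generate(2048)
+# pub_pem = key.public_key().export_key(format='PEM')
+# assert import_key(pub_pem).y == key.y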
+
+# Backward compatibility
+importKey = import_key
+
+#: `Object ID`_ for a DSA key.
+#:
+#: id-dsa ID ::= { iso(1) member-body(2) us(840) x9-57(10040) x9cm(4) 1 }
+#:
+#: .. _`Object ID`: http://www.alvestrand.no/objectid/1.2.840.10040.4.1.html
+oid = "1.2.840.10040.4.1"
diff --git a/lib/Crypto/PublicKey/DSA.pyi b/lib/Crypto/PublicKey/DSA.pyi
new file mode 100644
index 0000000..354ac1f
--- /dev/null
+++ b/lib/Crypto/PublicKey/DSA.pyi
@@ -0,0 +1,31 @@
+from typing import Dict, Tuple, Callable, Union, Optional
+
+__all__ = ['generate', 'construct', 'DsaKey', 'import_key' ]
+
+RNG = Callable[[int], bytes]
+
+class DsaKey(object):
+ def __init__(self, key_dict: Dict[str, int]) -> None: ...
+ def has_private(self) -> bool: ...
+ def can_encrypt(self) -> bool: ... # legacy
+ def can_sign(self) -> bool: ... # legacy
+ def public_key(self) -> DsaKey: ...
+ def __eq__(self, other: object) -> bool: ...
+ def __ne__(self, other: object) -> bool: ...
+ def __getstate__(self) -> None: ...
+ def domain(self) -> Tuple[int, int, int]: ...
+ def __repr__(self) -> str: ...
+ def __getattr__(self, item: str) -> int: ...
+ def export_key(self, format: Optional[str]="PEM", pkcs8: Optional[bool]=None, passphrase: Optional[str]=None,
+ protection: Optional[str]=None, randfunc: Optional[RNG]=None) -> bytes: ...
+ # Backward-compatibility
+ exportKey = export_key
+ publickey = public_key
+
+def generate(bits: int, randfunc: Optional[RNG]=None, domain: Optional[Tuple[int, int, int]]=None) -> DsaKey: ...
+def construct(tup: Union[Tuple[int, int, int, int], Tuple[int, int, int, int, int]], consistency_check: Optional[bool]=True) -> DsaKey: ...
+def import_key(extern_key: Union[str, bytes], passphrase: Optional[str]=None) -> DsaKey: ...
+# Backward compatibility
+importKey = import_key
+
+oid: str
diff --git a/lib/Crypto/PublicKey/ECC.py b/lib/Crypto/PublicKey/ECC.py
new file mode 100644
index 0000000..0b605c4
--- /dev/null
+++ b/lib/Crypto/PublicKey/ECC.py
@@ -0,0 +1,1794 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from __future__ import print_function
+
+import re
+import struct
+import binascii
+from collections import namedtuple
+
+from Crypto.Util.py3compat import bord, tobytes, tostr, bchr, is_string
+from Crypto.Util.number import bytes_to_long, long_to_bytes
+
+from Crypto.Math.Numbers import Integer
+from Crypto.Util.asn1 import (DerObjectId, DerOctetString, DerSequence,
+ DerBitString)
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, VoidPointer,
+ SmartPointer, c_size_t, c_uint8_ptr,
+ c_ulonglong, null_pointer)
+
+from Crypto.PublicKey import (_expand_subject_public_key_info,
+ _create_subject_public_key_info,
+ _extract_subject_public_key_info)
+
+from Crypto.Hash import SHA512, SHAKE256
+
+from Crypto.Random import get_random_bytes
+from Crypto.Random.random import getrandbits
+
+
+_ec_lib = load_pycryptodome_raw_lib("Crypto.PublicKey._ec_ws", """
+typedef void EcContext;
+typedef void EcPoint;
+int ec_ws_new_context(EcContext **pec_ctx,
+ const uint8_t *modulus,
+ const uint8_t *b,
+ const uint8_t *order,
+ size_t len,
+ uint64_t seed);
+void ec_free_context(EcContext *ec_ctx);
+int ec_ws_new_point(EcPoint **pecp,
+ const uint8_t *x,
+ const uint8_t *y,
+ size_t len,
+ const EcContext *ec_ctx);
+void ec_ws_free_point(EcPoint *ecp);
+int ec_ws_get_xy(uint8_t *x,
+ uint8_t *y,
+ size_t len,
+ const EcPoint *ecp);
+int ec_ws_double(EcPoint *p);
+int ec_ws_add(EcPoint *ecpa, EcPoint *ecpb);
+int ec_ws_scalar(EcPoint *ecp,
+ const uint8_t *k,
+ size_t len,
+ uint64_t seed);
+int ec_ws_clone(EcPoint **pecp2, const EcPoint *ecp);
+int ec_ws_cmp(const EcPoint *ecp1, const EcPoint *ecp2);
+int ec_ws_neg(EcPoint *p);
+""")
+
+_ed25519_lib = load_pycryptodome_raw_lib("Crypto.PublicKey._ed25519", """
+typedef void Point;
+int ed25519_new_point(Point **out,
+ const uint8_t x[32],
+ const uint8_t y[32],
+ size_t modsize,
+ const void *context);
+int ed25519_clone(Point **P, const Point *Q);
+void ed25519_free_point(Point *p);
+int ed25519_cmp(const Point *p1, const Point *p2);
+int ed25519_neg(Point *p);
+int ed25519_get_xy(uint8_t *xb, uint8_t *yb, size_t modsize, Point *p);
+int ed25519_double(Point *p);
+int ed25519_add(Point *P1, const Point *P2);
+int ed25519_scalar(Point *P, uint8_t *scalar, size_t scalar_len, uint64_t seed);
+""")
+
+_ed448_lib = load_pycryptodome_raw_lib("Crypto.PublicKey._ed448", """
+typedef void EcContext;
+typedef void PointEd448;
+int ed448_new_context(EcContext **pec_ctx);
+void ed448_context(EcContext *ec_ctx);
+void ed448_free_context(EcContext *ec_ctx);
+int ed448_new_point(PointEd448 **out,
+ const uint8_t x[56],
+ const uint8_t y[56],
+ size_t len,
+ const EcContext *context);
+int ed448_clone(PointEd448 **P, const PointEd448 *Q);
+void ed448_free_point(PointEd448 *p);
+int ed448_cmp(const PointEd448 *p1, const PointEd448 *p2);
+int ed448_neg(PointEd448 *p);
+int ed448_get_xy(uint8_t *xb, uint8_t *yb, size_t len, const PointEd448 *p);
+int ed448_double(PointEd448 *p);
+int ed448_add(PointEd448 *P1, const PointEd448 *P2);
+int ed448_scalar(PointEd448 *P, const uint8_t *scalar, size_t scalar_len, uint64_t seed);
+""")
+
+
+def lib_func(ecc_obj, func_name):
+ if ecc_obj._curve.desc == "Ed25519":
+ result = getattr(_ed25519_lib, "ed25519_" + func_name)
+ elif ecc_obj._curve.desc == "Ed448":
+ result = getattr(_ed448_lib, "ed448_" + func_name)
+ else:
+ result = getattr(_ec_lib, "ec_ws_" + func_name)
+ return result
+
+#
+# _curves is a database of curve parameters. Items are indexed by their
+# human-friendly name, such as "P-256". Each item has the following fields:
+# - p: the prime number that defines the finite field for all modulo operations
+# - b: the constant in the Short Weierstrass curve equation
+# - order: the number of elements in the group with the generator below
+# - Gx: the affine coordinate X of the generator point
+# - Gy: the affine coordinate Y of the generator point
+# - G: the generator, as an EccPoint object
+# - modulus_bits: the minimum number of bits for encoding the modulus p
+# - oid: an ASCII string with the registered ASN.1 Object ID
+# - context: a raw pointer to memory holding a context for all curve operations (can be NULL)
+# - desc: an ASCII string describing the curve
+# - openssh: the ASCII string used in OpenSSH id files for public keys on this curve
+# - name: the ASCII string which is also a valid key in _curves
+
+
+_Curve = namedtuple("_Curve", "p b order Gx Gy G modulus_bits oid context desc openssh name")
+_curves = {}
+
+
+p192_names = ["p192", "NIST P-192", "P-192", "prime192v1", "secp192r1",
+ "nistp192"]
+
+
+def init_p192():
+ p = 0xfffffffffffffffffffffffffffffffeffffffffffffffff
+ b = 0x64210519e59c80e70fa7e9ab72243049feb8deecc146b9b1
+ order = 0xffffffffffffffffffffffff99def836146bc9b1b4d22831
+ Gx = 0x188da80eb03090f67cbf20eb43a18800f4ff0afd82ff1012
+ Gy = 0x07192b95ffc8da78631011ed6b24cdd573f977a11e794811
+
+ p192_modulus = long_to_bytes(p, 24)
+ p192_b = long_to_bytes(b, 24)
+ p192_order = long_to_bytes(order, 24)
+
+ ec_p192_context = VoidPointer()
+ result = _ec_lib.ec_ws_new_context(ec_p192_context.address_of(),
+ c_uint8_ptr(p192_modulus),
+ c_uint8_ptr(p192_b),
+ c_uint8_ptr(p192_order),
+ c_size_t(len(p192_modulus)),
+ c_ulonglong(getrandbits(64))
+ )
+ if result:
+ raise ImportError("Error %d initializing P-192 context" % result)
+
+ context = SmartPointer(ec_p192_context.get(), _ec_lib.ec_free_context)
+ p192 = _Curve(Integer(p),
+ Integer(b),
+ Integer(order),
+ Integer(Gx),
+ Integer(Gy),
+ None,
+ 192,
+ "1.2.840.10045.3.1.1", # ANSI X9.62 / SEC2
+ context,
+ "NIST P-192",
+ "ecdsa-sha2-nistp192",
+ "p192")
+ global p192_names
+ _curves.update(dict.fromkeys(p192_names, p192))
+
+
+init_p192()
+del init_p192
+
+
+p224_names = ["p224", "NIST P-224", "P-224", "prime224v1", "secp224r1",
+ "nistp224"]
+
+
+def init_p224():
+ p = 0xffffffffffffffffffffffffffffffff000000000000000000000001
+ b = 0xb4050a850c04b3abf54132565044b0b7d7bfd8ba270b39432355ffb4
+ order = 0xffffffffffffffffffffffffffff16a2e0b8f03e13dd29455c5c2a3d
+ Gx = 0xb70e0cbd6bb4bf7f321390b94a03c1d356c21122343280d6115c1d21
+ Gy = 0xbd376388b5f723fb4c22dfe6cd4375a05a07476444d5819985007e34
+
+ p224_modulus = long_to_bytes(p, 28)
+ p224_b = long_to_bytes(b, 28)
+ p224_order = long_to_bytes(order, 28)
+
+ ec_p224_context = VoidPointer()
+ result = _ec_lib.ec_ws_new_context(ec_p224_context.address_of(),
+ c_uint8_ptr(p224_modulus),
+ c_uint8_ptr(p224_b),
+ c_uint8_ptr(p224_order),
+ c_size_t(len(p224_modulus)),
+ c_ulonglong(getrandbits(64))
+ )
+ if result:
+ raise ImportError("Error %d initializing P-224 context" % result)
+
+ context = SmartPointer(ec_p224_context.get(), _ec_lib.ec_free_context)
+ p224 = _Curve(Integer(p),
+ Integer(b),
+ Integer(order),
+ Integer(Gx),
+ Integer(Gy),
+ None,
+ 224,
+ "1.3.132.0.33", # SEC 2
+ context,
+ "NIST P-224",
+ "ecdsa-sha2-nistp224",
+ "p224")
+ global p224_names
+ _curves.update(dict.fromkeys(p224_names, p224))
+
+
+init_p224()
+del init_p224
+
+
+p256_names = ["p256", "NIST P-256", "P-256", "prime256v1", "secp256r1",
+ "nistp256"]
+
+
+def init_p256():
+ p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff
+ b = 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b
+ order = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551
+ Gx = 0x6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296
+ Gy = 0x4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5
+
+ p256_modulus = long_to_bytes(p, 32)
+ p256_b = long_to_bytes(b, 32)
+ p256_order = long_to_bytes(order, 32)
+
+ ec_p256_context = VoidPointer()
+ result = _ec_lib.ec_ws_new_context(ec_p256_context.address_of(),
+ c_uint8_ptr(p256_modulus),
+ c_uint8_ptr(p256_b),
+ c_uint8_ptr(p256_order),
+ c_size_t(len(p256_modulus)),
+ c_ulonglong(getrandbits(64))
+ )
+ if result:
+ raise ImportError("Error %d initializing P-256 context" % result)
+
+ context = SmartPointer(ec_p256_context.get(), _ec_lib.ec_free_context)
+ p256 = _Curve(Integer(p),
+ Integer(b),
+ Integer(order),
+ Integer(Gx),
+ Integer(Gy),
+ None,
+ 256,
+ "1.2.840.10045.3.1.7", # ANSI X9.62 / SEC2
+ context,
+ "NIST P-256",
+ "ecdsa-sha2-nistp256",
+ "p256")
+ global p256_names
+ _curves.update(dict.fromkeys(p256_names, p256))
+
+
+init_p256()
+del init_p256
+
+
+p384_names = ["p384", "NIST P-384", "P-384", "prime384v1", "secp384r1",
+ "nistp384"]
+
+
+def init_p384():
+ p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff
+ b = 0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef
+ order = 0xffffffffffffffffffffffffffffffffffffffffffffffffc7634d81f4372ddf581a0db248b0a77aecec196accc52973
+ Gx = 0xaa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760aB7
+ Gy = 0x3617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e5F
+
+ p384_modulus = long_to_bytes(p, 48)
+ p384_b = long_to_bytes(b, 48)
+ p384_order = long_to_bytes(order, 48)
+
+ ec_p384_context = VoidPointer()
+ result = _ec_lib.ec_ws_new_context(ec_p384_context.address_of(),
+ c_uint8_ptr(p384_modulus),
+ c_uint8_ptr(p384_b),
+ c_uint8_ptr(p384_order),
+ c_size_t(len(p384_modulus)),
+ c_ulonglong(getrandbits(64))
+ )
+ if result:
+ raise ImportError("Error %d initializing P-384 context" % result)
+
+ context = SmartPointer(ec_p384_context.get(), _ec_lib.ec_free_context)
+ p384 = _Curve(Integer(p),
+ Integer(b),
+ Integer(order),
+ Integer(Gx),
+ Integer(Gy),
+ None,
+ 384,
+ "1.3.132.0.34", # SEC 2
+ context,
+ "NIST P-384",
+ "ecdsa-sha2-nistp384",
+ "p384")
+ global p384_names
+ _curves.update(dict.fromkeys(p384_names, p384))
+
+
+init_p384()
+del init_p384
+
+
+p521_names = ["p521", "NIST P-521", "P-521", "prime521v1", "secp521r1",
+ "nistp521"]
+
+
+def init_p521():
+ p = 0x000001ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
+ b = 0x00000051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00
+ order = 0x000001fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409
+ Gx = 0x000000c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66
+ Gy = 0x0000011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16650
+
+ p521_modulus = long_to_bytes(p, 66)
+ p521_b = long_to_bytes(b, 66)
+ p521_order = long_to_bytes(order, 66)
+
+ ec_p521_context = VoidPointer()
+ result = _ec_lib.ec_ws_new_context(ec_p521_context.address_of(),
+ c_uint8_ptr(p521_modulus),
+ c_uint8_ptr(p521_b),
+ c_uint8_ptr(p521_order),
+ c_size_t(len(p521_modulus)),
+ c_ulonglong(getrandbits(64))
+ )
+ if result:
+ raise ImportError("Error %d initializing P-521 context" % result)
+
+ context = SmartPointer(ec_p521_context.get(), _ec_lib.ec_free_context)
+ p521 = _Curve(Integer(p),
+ Integer(b),
+ Integer(order),
+ Integer(Gx),
+ Integer(Gy),
+ None,
+ 521,
+ "1.3.132.0.35", # SEC 2
+ context,
+ "NIST P-521",
+ "ecdsa-sha2-nistp521",
+ "p521")
+ global p521_names
+ _curves.update(dict.fromkeys(p521_names, p521))
+
+
+init_p521()
+del init_p521
+
+
+ed25519_names = ["ed25519", "Ed25519"]
+
+
+def init_ed25519():
+ p = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed # 2**255 - 19
+ order = 0x1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed
+ Gx = 0x216936d3cd6e53fec0a4e231fdd6dc5c692cc7609525a7b2c9562d608f25d51a
+ Gy = 0x6666666666666666666666666666666666666666666666666666666666666658
+
+ ed25519 = _Curve(Integer(p),
+ None,
+ Integer(order),
+ Integer(Gx),
+ Integer(Gy),
+ None,
+ 255,
+ "1.3.101.112", # RFC8410
+ None,
+ "Ed25519", # Used throughout; do not change
+ "ssh-ed25519",
+ "ed25519")
+ global ed25519_names
+ _curves.update(dict.fromkeys(ed25519_names, ed25519))
+
+
+init_ed25519()
+del init_ed25519
+
+
+ed448_names = ["ed448", "Ed448"]
+
+
+def init_ed448():
+ p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffffffffffffffffffffffffffffffffffffffffffffffffffff # 2**448 - 2**224 - 1
+ order = 0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c292ab5844f3
+ Gx = 0x4f1970c66bed0ded221d15a622bf36da9e146570470f1767ea6de324a3d3a46412ae1af72ab66511433b80e18b00938e2626a82bc70cc05e
+ Gy = 0x693f46716eb6bc248876203756c9c7624bea73736ca3984087789c1e05a0c2d73ad3ff1ce67c39c4fdbd132c4ed7c8ad9808795bf230fa14
+
+ ed448_context = VoidPointer()
+ result = _ed448_lib.ed448_new_context(ed448_context.address_of())
+ if result:
+ raise ImportError("Error %d initializing Ed448 context" % result)
+
+ context = SmartPointer(ed448_context.get(), _ed448_lib.ed448_free_context)
+
+ ed448 = _Curve(Integer(p),
+ None,
+ Integer(order),
+ Integer(Gx),
+ Integer(Gy),
+ None,
+ 448,
+ "1.3.101.113", # RFC8410
+ context,
+ "Ed448", # Used throughout; do not change
+ None,
+ "ed448")
+ global ed448_names
+ _curves.update(dict.fromkeys(ed448_names, ed448))
+
+
+init_ed448()
+del init_ed448
+
+
+class UnsupportedEccFeature(ValueError):
+ pass
+
+
+class EccPoint(object):
+ """A class to model a point on an Elliptic Curve.
+
+ The class supports operators for:
+
+ * Adding two points: ``R = S + T``
+ * In-place addition: ``S += T``
+ * Negating a point: ``R = -T``
+ * Comparing two points: ``if S == T: ...`` or ``if S != T: ...``
+ * Multiplying a point by a scalar: ``R = S*k``
+ * In-place multiplication by a scalar: ``T *= k``
+
+ :ivar x: The affine X-coordinate of the ECC point
+ :vartype x: integer
+
+ :ivar y: The affine Y-coordinate of the ECC point
+ :vartype y: integer
+
+ :ivar xy: The tuple with affine X- and Y- coordinates
+ """
+
+ def __init__(self, x, y, curve="p256"):
+
+ try:
+ self._curve = _curves[curve]
+ except KeyError:
+ raise ValueError("Unknown curve name %s" % str(curve))
+ self._curve_name = curve
+
+ modulus_bytes = self.size_in_bytes()
+
+ xb = long_to_bytes(x, modulus_bytes)
+ yb = long_to_bytes(y, modulus_bytes)
+ if len(xb) != modulus_bytes or len(yb) != modulus_bytes:
+ raise ValueError("Incorrect coordinate length")
+
+ new_point = lib_func(self, "new_point")
+ free_func = lib_func(self, "free_point")
+
+ self._point = VoidPointer()
+ try:
+ context = self._curve.context.get()
+ except AttributeError:
+ context = null_pointer
+ result = new_point(self._point.address_of(),
+ c_uint8_ptr(xb),
+ c_uint8_ptr(yb),
+ c_size_t(modulus_bytes),
+ context)
+
+ if result:
+ if result == 15:
+ raise ValueError("The EC point does not belong to the curve")
+ raise ValueError("Error %d while instantiating an EC point" % result)
+
+ # Ensure that object disposal of this Python object will (eventually)
+ # free the memory allocated by the raw library for the EC point
+ self._point = SmartPointer(self._point.get(), free_func)
+
+ def set(self, point):
+ clone = lib_func(self, "clone")
+ free_func = lib_func(self, "free_point")
+
+ self._point = VoidPointer()
+ result = clone(self._point.address_of(),
+ point._point.get())
+
+ if result:
+ raise ValueError("Error %d while cloning an EC point" % result)
+
+ self._point = SmartPointer(self._point.get(), free_func)
+ return self
+
+ def __eq__(self, point):
+ cmp_func = lib_func(self, "cmp")
+ return 0 == cmp_func(self._point.get(), point._point.get())
+
+ # Only needed for Python 2
+ def __ne__(self, point):
+ return not self == point
+
+ def __neg__(self):
+ neg_func = lib_func(self, "neg")
+ np = self.copy()
+ result = neg_func(np._point.get())
+ if result:
+ raise ValueError("Error %d while inverting an EC point" % result)
+ return np
+
+ def copy(self):
+ """Return a copy of this point."""
+ x, y = self.xy
+ np = EccPoint(x, y, self._curve_name)
+ return np
+
+ def _is_eddsa(self):
+ return self._curve.name in ("ed25519", "ed448")
+
+ def is_point_at_infinity(self):
+ """``True`` if this is the *point-at-infinity*."""
+
+ if self._is_eddsa():
+ return self.x == 0
+ else:
+ return self.xy == (0, 0)
+
+ def point_at_infinity(self):
+ """Return the *point-at-infinity* for the curve."""
+
+ if self._is_eddsa():
+ return EccPoint(0, 1, self._curve_name)
+ else:
+ return EccPoint(0, 0, self._curve_name)
+
+ @property
+ def x(self):
+ return self.xy[0]
+
+ @property
+ def y(self):
+ return self.xy[1]
+
+ @property
+ def xy(self):
+ modulus_bytes = self.size_in_bytes()
+ xb = bytearray(modulus_bytes)
+ yb = bytearray(modulus_bytes)
+ get_xy = lib_func(self, "get_xy")
+ result = get_xy(c_uint8_ptr(xb),
+ c_uint8_ptr(yb),
+ c_size_t(modulus_bytes),
+ self._point.get())
+ if result:
+ raise ValueError("Error %d while encoding an EC point" % result)
+
+ return (Integer(bytes_to_long(xb)), Integer(bytes_to_long(yb)))
+
+ def size_in_bytes(self):
+ """Size of each coordinate, in bytes."""
+ return (self.size_in_bits() + 7) // 8
+
+ def size_in_bits(self):
+ """Size of each coordinate, in bits."""
+ return self._curve.modulus_bits
+
+ def double(self):
+ """Double this point (in-place operation).
+
+ Returns:
+ This same object (to enable chaining).
+ """
+
+ double_func = lib_func(self, "double")
+ result = double_func(self._point.get())
+ if result:
+ raise ValueError("Error %d while doubling an EC point" % result)
+ return self
+
+ def __iadd__(self, point):
+ """Add a second point to this one"""
+
+ add_func = lib_func(self, "add")
+ result = add_func(self._point.get(), point._point.get())
+ if result:
+ if result == 16:
+ raise ValueError("EC points are not on the same curve")
+ raise ValueError("Error %d while adding two EC points" % result)
+ return self
+
+ def __add__(self, point):
+ """Return a new point, the addition of this one and another"""
+
+ np = self.copy()
+ np += point
+ return np
+
+ def __imul__(self, scalar):
+ """Multiply this point by a scalar"""
+
+ scalar_func = lib_func(self, "scalar")
+ if scalar < 0:
+ raise ValueError("Scalar multiplication is only defined for non-negative integers")
+ sb = long_to_bytes(scalar)
+ result = scalar_func(self._point.get(),
+ c_uint8_ptr(sb),
+ c_size_t(len(sb)),
+ c_ulonglong(getrandbits(64)))
+ if result:
+ raise ValueError("Error %d during scalar multiplication" % result)
+ return self
+
+ def __mul__(self, scalar):
+ """Return a new point, the scalar product of this one"""
+
+ np = self.copy()
+ np *= scalar
+ return np
+
+ def __rmul__(self, left_hand):
+ return self.__mul__(left_hand)
+
+
+# Last piece of initialization
+p192_G = EccPoint(_curves['p192'].Gx, _curves['p192'].Gy, "p192")
+p192 = _curves['p192']._replace(G=p192_G)
+_curves.update(dict.fromkeys(p192_names, p192))
+del p192_G, p192, p192_names
+
+p224_G = EccPoint(_curves['p224'].Gx, _curves['p224'].Gy, "p224")
+p224 = _curves['p224']._replace(G=p224_G)
+_curves.update(dict.fromkeys(p224_names, p224))
+del p224_G, p224, p224_names
+
+p256_G = EccPoint(_curves['p256'].Gx, _curves['p256'].Gy, "p256")
+p256 = _curves['p256']._replace(G=p256_G)
+_curves.update(dict.fromkeys(p256_names, p256))
+del p256_G, p256, p256_names
+
+p384_G = EccPoint(_curves['p384'].Gx, _curves['p384'].Gy, "p384")
+p384 = _curves['p384']._replace(G=p384_G)
+_curves.update(dict.fromkeys(p384_names, p384))
+del p384_G, p384, p384_names
+
+p521_G = EccPoint(_curves['p521'].Gx, _curves['p521'].Gy, "p521")
+p521 = _curves['p521']._replace(G=p521_G)
+_curves.update(dict.fromkeys(p521_names, p521))
+del p521_G, p521, p521_names
+
+ed25519_G = EccPoint(_curves['Ed25519'].Gx, _curves['Ed25519'].Gy, "Ed25519")
+ed25519 = _curves['Ed25519']._replace(G=ed25519_G)
+_curves.update(dict.fromkeys(ed25519_names, ed25519))
+del ed25519_G, ed25519, ed25519_names
+
+ed448_G = EccPoint(_curves['Ed448'].Gx, _curves['Ed448'].Gy, "Ed448")
+ed448 = _curves['Ed448']._replace(G=ed448_G)
+_curves.update(dict.fromkeys(ed448_names, ed448))
+del ed448_G, ed448, ed448_names
+
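+# Usage sketch (editor's example): basic point arithmetic with the operators
+# documented on EccPoint, using the P-256 generator initialized just above.
+#
+# G = _curves['p256'].G
+# Q = 2 * G # scalar multiplication (same result as G + G)
+# R = G + Q # point addition returns a new EccPoint
+# assert R == 3 * G
+# print(R.xy) # affine (x, y) coordinates as Integer objects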
+
+class EccKey(object):
+ r"""Class defining an ECC key.
+ Do not instantiate directly.
+ Use :func:`generate`, :func:`construct` or :func:`import_key` instead.
+
+ :ivar curve: The name of the curve as defined in the `ECC table`_.
+ :vartype curve: string
+
+ :ivar pointQ: an ECC point representing the public component.
+ :vartype pointQ: :class:`EccPoint`
+
+ :ivar d: A scalar that represents the private component
+ in NIST P curves. It is smaller than the
+ order of the generator point.
+ :vartype d: integer
+
+ :ivar seed: A seed that represents the private component
+ in EdDSA curves
+ (Ed25519, 32 bytes; Ed448, 57 bytes).
+ :vartype seed: bytes
+ """
+
+ def __init__(self, **kwargs):
+ """Create a new ECC key
+
+ Keywords:
+ curve : string
+ The name of the curve.
+ d : integer
+ Mandatory for a private key on NIST P curves.
+ It must be in the range ``[1..order-1]``.
+ seed : bytes
+ Mandatory for a private key on the Ed25519 (32 bytes)
+ or Ed448 (57 bytes) curve.
+ point : EccPoint
+ Mandatory for a public key. If provided for a private key,
+ the implementation will NOT check whether it matches ``d``.
+
+ Only one parameter among ``d``, ``seed`` or ``point`` may be used.
+ """
+
+ kwargs_ = dict(kwargs)
+ curve_name = kwargs_.pop("curve", None)
+ self._d = kwargs_.pop("d", None)
+ self._seed = kwargs_.pop("seed", None)
+ self._point = kwargs_.pop("point", None)
+ if curve_name is None and self._point:
+ curve_name = self._point._curve_name
+ if kwargs_:
+ raise TypeError("Unknown parameters: " + str(kwargs_))
+
+ if curve_name not in _curves:
+ raise ValueError("Unsupported curve (%s)" % curve_name)
+ self._curve = _curves[curve_name]
+ self.curve = curve_name
+
+ count = int(self._d is not None) + int(self._seed is not None)
+
+ if count == 0:
+ if self._point is None:
+ raise ValueError("At lest one between parameters 'point', 'd' or 'seed' must be specified")
+ return
+
+ if count == 2:
+ raise ValueError("Parameters d and seed are mutually exclusive")
+
+ # NIST P curves work with d, EdDSA works with seed
+
+ if not self._is_eddsa():
+ if self._seed is not None:
+ raise ValueError("Parameter 'seed' can only be used with Ed25519 or Ed448")
+ self._d = Integer(self._d)
+ if not 1 <= self._d < self._curve.order:
+ raise ValueError("Parameter d must be an integer smaller than the curve order")
+ else:
+ if self._d is not None:
+ raise ValueError("Parameter d can only be used with NIST P curves")
+ # RFC 8032, 5.1.5
+ if self._curve.name == "ed25519":
+ if len(self._seed) != 32:
+ raise ValueError("Parameter seed must be 32 bytes long for Ed25519")
+ seed_hash = SHA512.new(self._seed).digest() # h
+ self._prefix = seed_hash[32:]
+ tmp = bytearray(seed_hash[:32])
+ tmp[0] &= 0xF8
+ tmp[31] = (tmp[31] & 0x7F) | 0x40
+ # RFC 8032, 5.2.5
+ elif self._curve.name == "ed448":
+ if len(self._seed) != 57:
+ raise ValueError("Parameter seed must be 57 bytes long for Ed448")
+ seed_hash = SHAKE256.new(self._seed).read(114) # h
+ self._prefix = seed_hash[57:]
+ tmp = bytearray(seed_hash[:57])
+ tmp[0] &= 0xFC
+ tmp[55] |= 0x80
+ tmp[56] = 0
+ self._d = Integer.from_bytes(tmp, byteorder='little')
+
+ def _is_eddsa(self):
+ return self._curve.desc in ("Ed25519", "Ed448")
+
+ def __eq__(self, other):
+ if other.has_private() != self.has_private():
+ return False
+
+ return other.pointQ == self.pointQ
+
+ def __repr__(self):
+ if self.has_private():
+ if self._is_eddsa():
+ extra = ", seed=%s" % self._seed.hex()
+ else:
+ extra = ", d=%d" % int(self._d)
+ else:
+ extra = ""
+ x, y = self.pointQ.xy
+ return "EccKey(curve='%s', point_x=%d, point_y=%d%s)" % (self._curve.desc, x, y, extra)
+
+ def has_private(self):
+ """``True`` if this key can be used for making signatures or decrypting data."""
+
+ return self._d is not None
+
+ # ECDSA
+ def _sign(self, z, k):
+ assert 0 < k < self._curve.order
+
+ order = self._curve.order
+ blind = Integer.random_range(min_inclusive=1,
+ max_exclusive=order)
+
+ blind_d = self._d * blind
+ inv_blind_k = (blind * k).inverse(order)
+
+ r = (self._curve.G * k).x % order
+ s = inv_blind_k * (blind * z + blind_d * r) % order
+ return (r, s)
+
+ # ECDSA
+ def _verify(self, z, rs):
+ order = self._curve.order
+ sinv = rs[1].inverse(order)
+ point1 = self._curve.G * ((sinv * z) % order)
+ point2 = self.pointQ * ((sinv * rs[0]) % order)
+ return (point1 + point2).x == rs[0]
+
+ @property
+ def d(self):
+ if not self.has_private():
+ raise ValueError("This is not a private ECC key")
+ return self._d
+
+ @property
+ def seed(self):
+ if not self.has_private():
+ raise ValueError("This is not a private ECC key")
+ return self._seed
+
+ @property
+ def pointQ(self):
+ if self._point is None:
+ self._point = self._curve.G * self._d
+ return self._point
+
+ def public_key(self):
+ """A matching ECC public key.
+
+ Returns:
+ a new :class:`EccKey` object
+ """
+
+ return EccKey(curve=self._curve.desc, point=self.pointQ)
+
+ def _export_SEC1(self, compress):
+ if self._is_eddsa():
+ raise ValueError("SEC1 format is unsupported for EdDSA curves")
+
+ # See 2.2 in RFC5480 and 2.3.3 in SEC1
+ #
+ # The first byte is:
+ # - 0x02: compressed, only X-coordinate, Y-coordinate is even
+ # - 0x03: compressed, only X-coordinate, Y-coordinate is odd
+ # - 0x04: uncompressed, X-coordinate is followed by Y-coordinate
+ #
+ # PAI is in theory encoded as 0x00.
+
+ modulus_bytes = self.pointQ.size_in_bytes()
+
+ if compress:
+ if self.pointQ.y.is_odd():
+ first_byte = b'\x03'
+ else:
+ first_byte = b'\x02'
+ public_key = (first_byte +
+ self.pointQ.x.to_bytes(modulus_bytes))
+ else:
+ public_key = (b'\x04' +
+ self.pointQ.x.to_bytes(modulus_bytes) +
+ self.pointQ.y.to_bytes(modulus_bytes))
+ return public_key
+
+ def _export_eddsa(self):
+ x, y = self.pointQ.xy
+ if self._curve.name == "ed25519":
+ result = bytearray(y.to_bytes(32, byteorder='little'))
+ result[31] = ((x & 1) << 7) | result[31]
+ elif self._curve.name == "ed448":
+ result = bytearray(y.to_bytes(57, byteorder='little'))
+ result[56] = (x & 1) << 7
+ else:
+ raise ValueError("Not an EdDSA key to export")
+ return bytes(result)
+
+ def _export_subjectPublicKeyInfo(self, compress):
+ if self._is_eddsa():
+ oid = self._curve.oid
+ public_key = self._export_eddsa()
+ params = None
+ else:
+ oid = "1.2.840.10045.2.1" # unrestricted
+ public_key = self._export_SEC1(compress)
+ params = DerObjectId(self._curve.oid)
+
+ return _create_subject_public_key_info(oid,
+ public_key,
+ params)
+
+ def _export_rfc5915_private_der(self, include_ec_params=True):
+
+ assert self.has_private()
+
+ # ECPrivateKey ::= SEQUENCE {
+ # version INTEGER { ecPrivkeyVer1(1) } (ecPrivkeyVer1),
+ # privateKey OCTET STRING,
+ # parameters [0] ECParameters {{ NamedCurve }} OPTIONAL,
+ # publicKey [1] BIT STRING OPTIONAL
+ # }
+
+ # Public key - uncompressed form
+ modulus_bytes = self.pointQ.size_in_bytes()
+ public_key = (b'\x04' +
+ self.pointQ.x.to_bytes(modulus_bytes) +
+ self.pointQ.y.to_bytes(modulus_bytes))
+
+ seq = [1,
+ DerOctetString(self.d.to_bytes(modulus_bytes)),
+ DerObjectId(self._curve.oid, explicit=0),
+ DerBitString(public_key, explicit=1)]
+
+ if not include_ec_params:
+ del seq[2]
+
+ return DerSequence(seq).encode()
+
+ def _export_pkcs8(self, **kwargs):
+ from Crypto.IO import PKCS8
+
+ if kwargs.get('passphrase', None) is not None and 'protection' not in kwargs:
+ raise ValueError("At least the 'protection' parameter should be present")
+
+ if self._is_eddsa():
+ oid = self._curve.oid
+ private_key = DerOctetString(self._seed).encode()
+ params = None
+ else:
+ oid = "1.2.840.10045.2.1" # unrestricted
+ private_key = self._export_rfc5915_private_der(include_ec_params=False)
+ params = DerObjectId(self._curve.oid)
+
+ result = PKCS8.wrap(private_key,
+ oid,
+ key_params=params,
+ **kwargs)
+ return result
+
+ def _export_public_pem(self, compress):
+ from Crypto.IO import PEM
+
+ encoded_der = self._export_subjectPublicKeyInfo(compress)
+ return PEM.encode(encoded_der, "PUBLIC KEY")
+
+ def _export_private_pem(self, passphrase, **kwargs):
+ from Crypto.IO import PEM
+
+ encoded_der = self._export_rfc5915_private_der()
+ return PEM.encode(encoded_der, "EC PRIVATE KEY", passphrase, **kwargs)
+
+ def _export_private_clear_pkcs8_in_clear_pem(self):
+ from Crypto.IO import PEM
+
+ encoded_der = self._export_pkcs8()
+ return PEM.encode(encoded_der, "PRIVATE KEY")
+
+ def _export_private_encrypted_pkcs8_in_clear_pem(self, passphrase, **kwargs):
+ from Crypto.IO import PEM
+
+ assert passphrase
+ if 'protection' not in kwargs:
+ raise ValueError("At least the 'protection' parameter should be present")
+ encoded_der = self._export_pkcs8(passphrase=passphrase, **kwargs)
+ return PEM.encode(encoded_der, "ENCRYPTED PRIVATE KEY")
+
+ def _export_openssh(self, compress):
+ if self.has_private():
+ raise ValueError("Cannot export OpenSSH private keys")
+
+ desc = self._curve.openssh
+
+ if desc is None:
+ raise ValueError("Cannot export %s keys as OpenSSH" % self._curve.name)
+ elif desc == "ssh-ed25519":
+ public_key = self._export_eddsa()
+ comps = (tobytes(desc), tobytes(public_key))
+ else:
+ modulus_bytes = self.pointQ.size_in_bytes()
+
+ if compress:
+ first_byte = 2 + self.pointQ.y.is_odd()
+ public_key = (bchr(first_byte) +
+ self.pointQ.x.to_bytes(modulus_bytes))
+ else:
+ public_key = (b'\x04' +
+ self.pointQ.x.to_bytes(modulus_bytes) +
+ self.pointQ.y.to_bytes(modulus_bytes))
+
+ middle = desc.split("-")[2]
+ comps = (tobytes(desc), tobytes(middle), public_key)
+
+ blob = b"".join([struct.pack(">I", len(x)) + x for x in comps])
+ return desc + " " + tostr(binascii.b2a_base64(blob))
+
+ def export_key(self, **kwargs):
+ """Export this ECC key.
+
+ Args:
+ format (string):
+ The format to use for encoding the key:
+
+ - ``'DER'``. The key will be encoded in ASN.1 DER format (binary).
+ For a public key, the ASN.1 ``subjectPublicKeyInfo`` structure
+ defined in `RFC5480`_ will be used.
+ For a private key, the ASN.1 ``ECPrivateKey`` structure defined
+ in `RFC5915`_ is used instead (possibly within a PKCS#8 envelope,
+ see the ``use_pkcs8`` flag below).
+ - ``'PEM'``. The key will be encoded in a PEM_ envelope (ASCII).
+ - ``'OpenSSH'``. The key will be encoded in the OpenSSH_ format
+ (ASCII, public keys only).
+ - ``'SEC1'``. The public key (i.e., the EC point) will be encoded
+ into ``bytes`` according to Section 2.3.3 of `SEC1`_
+ (which is a subset of the older ANSI X9.62 standard).
+ Only for NIST P-curves.
+ - ``'raw'``. The public key will be encoded as ``bytes``,
+ without any metadata.
+
+ * For NIST P-curves: equivalent to ``'SEC1'``.
+ * For EdDSA curves: ``bytes`` in the format defined in `RFC8032`_.
+
+ passphrase (byte string or string):
+ The passphrase to use for protecting the private key.
+
+ use_pkcs8 (boolean):
+ Only relevant for private keys.
+
+ If ``True`` (default and recommended), the `PKCS#8`_ representation
+ will be used. It must be ``True`` for EdDSA curves.
+
+ protection (string):
+ When a private key is exported with password-protection
+ and PKCS#8 (both ``DER`` and ``PEM`` formats), this parameter MUST be
+ present and be a valid algorithm supported by :mod:`Crypto.IO.PKCS8`.
+ It is recommended to use ``PBKDF2WithHMAC-SHA1AndAES128-CBC``.
+
+ compress (boolean):
+ If ``True``, the method returns a more compact representation
+ of the public key, with the X-coordinate only.
+
+ If ``False`` (default), the method returns the full public key.
+
+ This parameter is ignored for EdDSA curves, as compression is
+ mandatory.
+
+ .. warning::
+ If you don't provide a passphrase, the private key will be
+ exported in the clear!
+
+ .. note::
+ When exporting a private key with password-protection and `PKCS#8`_
+ (both ``DER`` and ``PEM`` formats), any extra parameters
+ to ``export_key()`` will be passed to :mod:`Crypto.IO.PKCS8`.
+
+ .. _PEM: http://www.ietf.org/rfc/rfc1421.txt
+ .. _`PEM encryption`: http://www.ietf.org/rfc/rfc1423.txt
+ .. _OpenSSH: http://www.openssh.com/txt/rfc5656.txt
+ .. _RFC5480: https://tools.ietf.org/html/rfc5480
+ .. _RFC5915: https://tools.ietf.org/html/rfc5915
+ .. _RFC8032: https://tools.ietf.org/html/rfc8032
+ .. _`PKCS#8`: https://tools.ietf.org/html/rfc5208
+ .. _SEC1: https://www.secg.org/sec1-v2.pdf
+
+ Returns:
+ A multi-line string (for ``'PEM'`` and ``'OpenSSH'``) or
+ ``bytes`` (for ``'DER'``, ``'SEC1'``, and ``'raw'``) with the encoded key.
+ """
+
+ args = kwargs.copy()
+ ext_format = args.pop("format")
+ if ext_format not in ("PEM", "DER", "OpenSSH", "SEC1", "raw"):
+ raise ValueError("Unknown format '%s'" % ext_format)
+
+ compress = args.pop("compress", False)
+
+ if self.has_private():
+ passphrase = args.pop("passphrase", None)
+ if is_string(passphrase):
+ passphrase = tobytes(passphrase)
+ if not passphrase:
+ raise ValueError("Empty passphrase")
+ use_pkcs8 = args.pop("use_pkcs8", True)
+
+ if not use_pkcs8 and self._is_eddsa():
+ raise ValueError("'pkcs8' must be True for EdDSA curves")
+
+ if ext_format == "PEM":
+ if use_pkcs8:
+ if passphrase:
+ return self._export_private_encrypted_pkcs8_in_clear_pem(passphrase, **args)
+ else:
+ return self._export_private_clear_pkcs8_in_clear_pem()
+ else:
+ return self._export_private_pem(passphrase, **args)
+ elif ext_format == "DER":
+ # DER
+ if passphrase and not use_pkcs8:
+ raise ValueError("Private keys can only be encrpyted with DER using PKCS#8")
+ if use_pkcs8:
+ return self._export_pkcs8(passphrase=passphrase, **args)
+ else:
+ return self._export_rfc5915_private_der()
+ else:
+ raise ValueError("Private keys cannot be exported "
+ "in the '%s' format" % ext_format)
+ else: # Public key
+ if args:
+ raise ValueError("Unexpected parameters: '%s'" % args)
+ if ext_format == "PEM":
+ return self._export_public_pem(compress)
+ elif ext_format == "DER":
+ return self._export_subjectPublicKeyInfo(compress)
+ elif ext_format == "SEC1":
+ return self._export_SEC1(compress)
+ elif ext_format == "raw":
+ if self._curve.name in ('ed25519', 'ed448'):
+ return self._export_eddsa()
+ else:
+ return self._export_SEC1(compress)
+ else:
+ return self._export_openssh(compress)
+
+
+def generate(**kwargs):
+ """Generate a new private key on the given curve.
+
+ Args:
+
+ curve (string):
+ Mandatory. It must be a curve name defined in the `ECC table`_.
+
+ randfunc (callable):
+ Optional. The RNG to read randomness from.
+ If ``None``, :func:`Crypto.Random.get_random_bytes` is used.
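+
+ A rough usage sketch (any curve name from the table can be used)::
+
+ from Crypto.PublicKey import ECC
+
+ nist_key = ECC.generate(curve='p256')
+ eddsa_key = ECC.generate(curve='ed25519')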
+ """
+
+ curve_name = kwargs.pop("curve")
+ curve = _curves[curve_name]
+ randfunc = kwargs.pop("randfunc", get_random_bytes)
+ if kwargs:
+ raise TypeError("Unknown parameters: " + str(kwargs))
+
+ if _curves[curve_name].name == "ed25519":
+ seed = randfunc(32)
+ new_key = EccKey(curve=curve_name, seed=seed)
+ elif _curves[curve_name].name == "ed448":
+ seed = randfunc(57)
+ new_key = EccKey(curve=curve_name, seed=seed)
+ else:
+ d = Integer.random_range(min_inclusive=1,
+ max_exclusive=curve.order,
+ randfunc=randfunc)
+ new_key = EccKey(curve=curve_name, d=d)
+
+ return new_key
+
+
+def construct(**kwargs):
+ """Build a new ECC key (private or public) starting
+ from some base components.
+
+ In most cases, you will already have an existing key
+ which you can read in with :func:`import_key` instead
+ of this function.
+
+ Args:
+ curve (string):
+ Mandatory. The name of the elliptic curve, as defined in the `ECC table`_.
+
+ d (integer):
+ Mandatory for a private key and a NIST P-curve (e.g., P-256):
+ the integer in the range ``[1..order-1]`` that represents the key.
+
+ seed (bytes):
+ Mandatory for a private key and an EdDSA curve.
+ It must be 32 bytes for Ed25519, and 57 bytes for Ed448.
+
+ point_x (integer):
+ Mandatory for a public key: the X coordinate (affine) of the ECC point.
+
+ point_y (integer):
+ Mandatory for a public key: the Y coordinate (affine) of the ECC point.
+
+ Returns:
+ :class:`EccKey` : a new ECC key object
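+
+ A rough sketch of both directions (the scalar below is only an example value)::
+
+ from Crypto.PublicKey import ECC
+
+ # Private key built from its secret scalar d
+ priv = ECC.construct(curve='p256',
+ d=0xc51e4753afdec1e6b6c6a5b992f43f8dd0c7a8933072708b6522468b2ffb06fd)
+ # Public key built from the affine coordinates of the matching point
+ pub = ECC.construct(curve='p256', point_x=priv.pointQ.x, point_y=priv.pointQ.y)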
+ """
+
+ curve_name = kwargs["curve"]
+ curve = _curves[curve_name]
+ point_x = kwargs.pop("point_x", None)
+ point_y = kwargs.pop("point_y", None)
+
+ if "point" in kwargs:
+ raise TypeError("Unknown keyword: point")
+
+ if None not in (point_x, point_y):
+ # ValueError is raised if the point is not on the curve
+ kwargs["point"] = EccPoint(point_x, point_y, curve_name)
+
+ new_key = EccKey(**kwargs)
+
+ # Validate that the private key matches the public one
+ # because EccKey will not do that automatically
+ if new_key.has_private() and 'point' in kwargs:
+ pub_key = curve.G * new_key.d
+ if pub_key.xy != (point_x, point_y):
+ raise ValueError("Private and public ECC keys do not match")
+
+ return new_key
+
+
+def _import_public_der(ec_point, curve_oid=None, curve_name=None):
+ """Convert an encoded EC point into an EccKey object
+
+ ec_point: byte string with the EC point (SEC1-encoded)
+ curve_oid: string with the OID of the curve
+ curve_name: string with the name of the curve
+
+ Either curve_oid or curve_name must be specified
+
+ """
+
+ for _curve_name, curve in _curves.items():
+ if curve_oid and curve.oid == curve_oid:
+ break
+ if curve_name == _curve_name:
+ break
+ else:
+ if curve_oid:
+ raise UnsupportedEccFeature("Unsupported ECC curve (OID: %s)" % curve_oid)
+ else:
+ raise UnsupportedEccFeature("Unsupported ECC curve (%s)" % curve_name)
+
+ # See 2.2 in RFC5480 and 2.3.3 in SEC1
+ # The first byte is:
+ # - 0x02: compressed, only X-coordinate, Y-coordinate is even
+ # - 0x03: compressed, only X-coordinate, Y-coordinate is odd
+ # - 0x04: uncompressed, X-coordinate is followed by Y-coordinate
+ #
+ # PAI is in theory encoded as 0x00.
+
+ modulus_bytes = curve.p.size_in_bytes()
+ point_type = bord(ec_point[0])
+
+ # Uncompressed point
+ if point_type == 0x04:
+ if len(ec_point) != (1 + 2 * modulus_bytes):
+ raise ValueError("Incorrect EC point length")
+ x = Integer.from_bytes(ec_point[1:modulus_bytes+1])
+ y = Integer.from_bytes(ec_point[modulus_bytes+1:])
+ # Compressed point
+ elif point_type in (0x02, 0x03):
+ if len(ec_point) != (1 + modulus_bytes):
+ raise ValueError("Incorrect EC point length")
+ x = Integer.from_bytes(ec_point[1:])
+ # Right now, we only support Short Weierstrass curves
+ y = (x**3 - x*3 + curve.b).sqrt(curve.p)
+ if point_type == 0x02 and y.is_odd():
+ y = curve.p - y
+ if point_type == 0x03 and y.is_even():
+ y = curve.p - y
+ else:
+ raise ValueError("Incorrect EC point encoding")
+
+ return construct(curve=_curve_name, point_x=x, point_y=y)
+
+
+def _import_subjectPublicKeyInfo(encoded, *kwargs):
+ """Convert a subjectPublicKeyInfo into an EccKey object"""
+
+ # See RFC5480
+
+ # Parse the generic subjectPublicKeyInfo structure
+ oid, ec_point, params = _expand_subject_public_key_info(encoded)
+
+ nist_p_oids = (
+ "1.2.840.10045.2.1", # id-ecPublicKey (unrestricted)
+ "1.3.132.1.12", # id-ecDH
+ "1.3.132.1.13" # id-ecMQV
+ )
+ eddsa_oids = {
+ "1.3.101.112": ("Ed25519", _import_ed25519_public_key), # id-Ed25519
+ "1.3.101.113": ("Ed448", _import_ed448_public_key) # id-Ed448
+ }
+
+ if oid in nist_p_oids:
+ # See RFC5480
+
+ # Parameters are mandatory and encoded as ECParameters
+ # ECParameters ::= CHOICE {
+ # namedCurve OBJECT IDENTIFIER
+ # -- implicitCurve NULL
+ # -- specifiedCurve SpecifiedECDomain
+ # }
+ # implicitCurve and specifiedCurve are not supported (as per RFC)
+ if not params:
+ raise ValueError("Missing ECC parameters for ECC OID %s" % oid)
+ try:
+ curve_oid = DerObjectId().decode(params).value
+ except ValueError:
+ raise ValueError("Error decoding namedCurve")
+
+ # ECPoint ::= OCTET STRING
+ return _import_public_der(ec_point, curve_oid=curve_oid)
+
+ elif oid in eddsa_oids:
+ # See RFC8410
+ curve_name, import_eddsa_public_key = eddsa_oids[oid]
+
+ # Parameters must be absent
+ if params:
+ raise ValueError("Unexpected ECC parameters for ECC OID %s" % oid)
+
+ x, y = import_eddsa_public_key(ec_point)
+ return construct(point_x=x, point_y=y, curve=curve_name)
+ else:
+ raise UnsupportedEccFeature("Unsupported ECC OID: %s" % oid)
+
+
+def _import_rfc5915_der(encoded, passphrase, curve_oid=None):
+
+ # See RFC5915 https://tools.ietf.org/html/rfc5915
+ #
+ # ECPrivateKey ::= SEQUENCE {
+ # version INTEGER { ecPrivkeyVer1(1) } (ecPrivkeyVer1),
+ # privateKey OCTET STRING,
+ # parameters [0] ECParameters {{ NamedCurve }} OPTIONAL,
+ # publicKey [1] BIT STRING OPTIONAL
+ # }
+
+ private_key = DerSequence().decode(encoded, nr_elements=(3, 4))
+ if private_key[0] != 1:
+ raise ValueError("Incorrect ECC private key version")
+
+ try:
+ parameters = DerObjectId(explicit=0).decode(private_key[2]).value
+ if curve_oid is not None and parameters != curve_oid:
+ raise ValueError("Curve mismatch")
+ curve_oid = parameters
+ except ValueError:
+ pass
+
+ if curve_oid is None:
+ raise ValueError("No curve found")
+
+ for curve_name, curve in _curves.items():
+ if curve.oid == curve_oid:
+ break
+ else:
+ raise UnsupportedEccFeature("Unsupported ECC curve (OID: %s)" % curve_oid)
+
+ scalar_bytes = DerOctetString().decode(private_key[1]).payload
+ modulus_bytes = curve.p.size_in_bytes()
+ if len(scalar_bytes) != modulus_bytes:
+ raise ValueError("Private key is too small")
+ d = Integer.from_bytes(scalar_bytes)
+
+ # Decode public key (if any)
+ if len(private_key) == 4:
+ public_key_enc = DerBitString(explicit=1).decode(private_key[3]).value
+ public_key = _import_public_der(public_key_enc, curve_oid=curve_oid)
+ point_x = public_key.pointQ.x
+ point_y = public_key.pointQ.y
+ else:
+ point_x = point_y = None
+
+ return construct(curve=curve_name, d=d, point_x=point_x, point_y=point_y)
+
+
+def _import_pkcs8(encoded, passphrase):
+ from Crypto.IO import PKCS8
+
+ algo_oid, private_key, params = PKCS8.unwrap(encoded, passphrase)
+
+ nist_p_oids = (
+ "1.2.840.10045.2.1", # id-ecPublicKey (unrestricted)
+ "1.3.132.1.12", # id-ecDH
+ "1.3.132.1.13" # id-ecMQV
+ )
+ eddsa_oids = {
+ "1.3.101.112": "Ed25519", # id-Ed25519
+ "1.3.101.113": "Ed448", # id-Ed448
+ }
+
+ if algo_oid in nist_p_oids:
+ curve_oid = DerObjectId().decode(params).value
+ return _import_rfc5915_der(private_key, passphrase, curve_oid)
+ elif algo_oid in eddsa_oids:
+ if params is not None:
+ raise ValueError("EdDSA ECC private key must not have parameters")
+ curve_oid = None
+ seed = DerOctetString().decode(private_key).payload
+ return construct(curve=eddsa_oids[algo_oid], seed=seed)
+ else:
+ raise UnsupportedEccFeature("Unsupported ECC purpose (OID: %s)" % algo_oid)
+
+
+def _import_x509_cert(encoded, *kwargs):
+
+ sp_info = _extract_subject_public_key_info(encoded)
+ return _import_subjectPublicKeyInfo(sp_info)
+
+
+def _import_der(encoded, passphrase):
+
+ try:
+ return _import_subjectPublicKeyInfo(encoded, passphrase)
+ except UnsupportedEccFeature as err:
+ raise err
+ except (ValueError, TypeError, IndexError):
+ pass
+
+ try:
+ return _import_x509_cert(encoded, passphrase)
+ except UnsupportedEccFeature as err:
+ raise err
+ except (ValueError, TypeError, IndexError):
+ pass
+
+ try:
+ return _import_rfc5915_der(encoded, passphrase)
+ except UnsupportedEccFeature as err:
+ raise err
+ except (ValueError, TypeError, IndexError):
+ pass
+
+ try:
+ return _import_pkcs8(encoded, passphrase)
+ except UnsupportedEccFeature as err:
+ raise err
+ except (ValueError, TypeError, IndexError):
+ pass
+
+ raise ValueError("Not an ECC DER key")
+
+
+def _import_openssh_public(encoded):
+ parts = encoded.split(b' ')
+ if len(parts) not in (2, 3):
+ raise ValueError("Not an openssh public key")
+
+ try:
+ keystring = binascii.a2b_base64(parts[1])
+
+ keyparts = []
+ while len(keystring) > 4:
+ lk = struct.unpack(">I", keystring[:4])[0]
+ keyparts.append(keystring[4:4 + lk])
+ keystring = keystring[4 + lk:]
+
+ if parts[0] != keyparts[0]:
+ raise ValueError("Mismatch in openssh public key")
+
+ # NIST P curves
+ if parts[0].startswith(b"ecdsa-sha2-"):
+
+ for curve_name, curve in _curves.items():
+ if curve.openssh is None:
+ continue
+ if not curve.openssh.startswith("ecdsa-sha2"):
+ continue
+ middle = tobytes(curve.openssh.split("-")[2])
+ if keyparts[1] == middle:
+ break
+ else:
+ raise ValueError("Unsupported ECC curve: " + middle)
+
+ ecc_key = _import_public_der(keyparts[2], curve_oid=curve.oid)
+
+ # EdDSA
+ elif parts[0] == b"ssh-ed25519":
+ x, y = _import_ed25519_public_key(keyparts[1])
+ ecc_key = construct(curve="Ed25519", point_x=x, point_y=y)
+ else:
+ raise ValueError("Unsupported SSH key type: " + parts[0])
+
+ except (IndexError, TypeError, binascii.Error):
+ raise ValueError("Error parsing SSH key type: " + parts[0])
+
+ return ecc_key
+
+
+def _import_openssh_private_ecc(data, password):
+
+ from ._openssh import (import_openssh_private_generic,
+ read_bytes, read_string, check_padding)
+
+ key_type, decrypted = import_openssh_private_generic(data, password)
+
+ eddsa_keys = {
+ "ssh-ed25519": ("Ed25519", _import_ed25519_public_key, 32),
+ }
+
+ # https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent-04
+ if key_type.startswith("ecdsa-sha2"):
+
+ ecdsa_curve_name, decrypted = read_string(decrypted)
+ if ecdsa_curve_name not in _curves:
+ raise UnsupportedEccFeature("Unsupported ECC curve %s" % ecdsa_curve_name)
+ curve = _curves[ecdsa_curve_name]
+ modulus_bytes = (curve.modulus_bits + 7) // 8
+
+ public_key, decrypted = read_bytes(decrypted)
+
+ if bord(public_key[0]) != 4:
+ raise ValueError("Only uncompressed OpenSSH EC keys are supported")
+ if len(public_key) != 2 * modulus_bytes + 1:
+ raise ValueError("Incorrect public key length")
+
+ point_x = Integer.from_bytes(public_key[1:1+modulus_bytes])
+ point_y = Integer.from_bytes(public_key[1+modulus_bytes:])
+
+ private_key, decrypted = read_bytes(decrypted)
+ d = Integer.from_bytes(private_key)
+
+ params = {'d': d, 'curve': ecdsa_curve_name}
+
+ elif key_type in eddsa_keys:
+
+ curve_name, import_eddsa_public_key, seed_len = eddsa_keys[key_type]
+
+ public_key, decrypted = read_bytes(decrypted)
+ point_x, point_y = import_eddsa_public_key(public_key)
+
+ private_public_key, decrypted = read_bytes(decrypted)
+ seed = private_public_key[:seed_len]
+
+ params = {'seed': seed, 'curve': curve_name}
+ else:
+ raise ValueError("Unsupport SSH agent key type:" + key_type)
+
+ _, padded = read_string(decrypted) # Comment
+ check_padding(padded)
+
+ return construct(point_x=point_x, point_y=point_y, **params)
+
+
+def _import_ed25519_public_key(encoded):
+ """Import an Ed25519 ECC public key, encoded as raw bytes as described
+ in RFC8032_.
+
+ Args:
+ encoded (bytes):
+ The Ed25519 public key to import. It must be 32 bytes long.
+
+ Returns:
+ :class:`EccKey` : a new ECC key object
+
+ Raises:
+ ValueError: when the given key cannot be parsed.
+
+ .. _RFC8032: https://datatracker.ietf.org/doc/html/rfc8032
+ """
+
+ if len(encoded) != 32:
+ raise ValueError("Incorrect length. Only Ed25519 public keys are supported.")
+
+ p = Integer(0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed) # 2**255 - 19
+ d = 37095705934669439343138083508754565189542113879843219016388785533085940283555
+
+ y = bytearray(encoded)
+ x_lsb = y[31] >> 7
+ y[31] &= 0x7F
+ point_y = Integer.from_bytes(y, byteorder='little')
+ if point_y >= p:
+ raise ValueError("Invalid Ed25519 key (y)")
+ if point_y == 1:
+ return 0, 1
+
+ u = (point_y**2 - 1) % p
+ v = ((point_y**2 % p) * d + 1) % p
+ try:
+ v_inv = v.inverse(p)
+ x2 = (u * v_inv) % p
+ point_x = Integer._tonelli_shanks(x2, p)
+ if (point_x & 1) != x_lsb:
+ point_x = p - point_x
+ except ValueError:
+ raise ValueError("Invalid Ed25519 public key")
+ return point_x, point_y
+
+
+def _import_ed448_public_key(encoded):
+ """Import an Ed448 ECC public key, encoded as raw bytes as described
+ in RFC8032_.
+
+ Args:
+ encoded (bytes):
+ The Ed448 public key to import. It must be 57 bytes long.
+
+ Returns:
+ :class:`EccKey` : a new ECC key object
+
+ Raises:
+ ValueError: when the given key cannot be parsed.
+
+ .. _RFC8032: https://datatracker.ietf.org/doc/html/rfc8032
+ """
+
+ if len(encoded) != 57:
+ raise ValueError("Incorrect length. Only Ed448 public keys are supported.")
+
+ p = Integer(0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffffffffffffffffffffffffffffffffffffffffffffffffffff) # 2**448 - 2**224 - 1
+ d = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffffffffffffffffffffffffffffffffffffffffffffffff6756
+
+ y = encoded[:56]
+ x_lsb = bord(encoded[56]) >> 7
+ point_y = Integer.from_bytes(y, byteorder='little')
+ if point_y >= p:
+ raise ValueError("Invalid Ed448 key (y)")
+ if point_y == 1:
+ return 0, 1
+
+ u = (point_y**2 - 1) % p
+ v = ((point_y**2 % p) * d - 1) % p
+ try:
+ v_inv = v.inverse(p)
+ x2 = (u * v_inv) % p
+ point_x = Integer._tonelli_shanks(x2, p)
+ if (point_x & 1) != x_lsb:
+ point_x = p - point_x
+ except ValueError:
+ raise ValueError("Invalid Ed448 public key")
+ return point_x, point_y
+
+
+def import_key(encoded, passphrase=None, curve_name=None):
+ """Import an ECC key (public or private).
+
+ Args:
+ encoded (bytes or multi-line string):
+ The ECC key to import.
+ The function will try to automatically detect the right format.
+
+ Supported formats for an ECC **public** key:
+
+ * X.509 certificate: binary (DER) or ASCII (PEM).
+ * X.509 ``subjectPublicKeyInfo``: binary (DER) or ASCII (PEM).
+ * SEC1_ (or X9.62), as ``bytes``. NIST P curves only.
+ You must also provide the ``curve_name`` (with a value from the `ECC table`_)
+ * OpenSSH line, defined in RFC5656_ and RFC8709_ (ASCII).
+ This is normally the content of files like ``~/.ssh/id_ecdsa.pub``.
+
+ Supported formats for an ECC **private** key:
+
+ * A binary ``ECPrivateKey`` structure, as defined in `RFC5915`_ (DER).
+ NIST P curves only.
+ * A `PKCS#8`_ structure (or the more recent Asymmetric Key Package, RFC5958_): binary (DER) or ASCII (PEM).
+ * `OpenSSH 6.5`_ and newer versions (ASCII).
+
+ Private keys can be in the clear or password-protected.
+
+ For details about the PEM encoding, see `RFC1421`_/`RFC1423`_.
+
+ passphrase (byte string):
+ The passphrase to use for decrypting a private key.
+ Encryption may be applied at the PEM level (not recommended)
+ or at the PKCS#8 level (recommended).
+ This parameter is ignored if the key in input is not encrypted.
+
+ curve_name (string):
+ For a SEC1 encoding only. This is the name of the curve,
+ as defined in the `ECC table`_.
+
+ .. note::
+
+ To import EdDSA private and public keys, when encoded as raw ``bytes``, use:
+
+ * :func:`Crypto.Signature.eddsa.import_public_key`, or
+ * :func:`Crypto.Signature.eddsa.import_private_key`.
+
+ Returns:
+ :class:`EccKey` : a new ECC key object
+
+ Raises:
+ ValueError: when the given key cannot be parsed (possibly because
+ the pass phrase is wrong).
+
+ .. _RFC1421: https://datatracker.ietf.org/doc/html/rfc1421
+ .. _RFC1423: https://datatracker.ietf.org/doc/html/rfc1423
+ .. _RFC5915: https://datatracker.ietf.org/doc/html/rfc5915
+ .. _RFC5656: https://datatracker.ietf.org/doc/html/rfc5656
+ .. _RFC8709: https://datatracker.ietf.org/doc/html/rfc8709
+ .. _RFC5958: https://datatracker.ietf.org/doc/html/rfc5958
+ .. _`PKCS#8`: https://datatracker.ietf.org/doc/html/rfc5208
+ .. _`OpenSSH 6.5`: https://flak.tedunangst.com/post/new-openssh-key-format-and-bcrypt-pbkdf
+ .. _SEC1: https://www.secg.org/sec1-v2.pdf
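+
+ A minimal usage sketch (file names and passphrase are only placeholders)::
+
+ from Crypto.PublicKey import ECC
+
+ public_pem = open('ecc_public.pem').read()
+ public_key = ECC.import_key(public_pem)
+
+ private_pem = open('ecc_private.pem', 'rb').read()
+ private_key = ECC.import_key(private_pem, passphrase='secret')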
+ """
+
+ from Crypto.IO import PEM
+
+ encoded = tobytes(encoded)
+ if passphrase is not None:
+ passphrase = tobytes(passphrase)
+
+ # PEM
+ if encoded.startswith(b'-----BEGIN OPENSSH PRIVATE KEY'):
+ text_encoded = tostr(encoded)
+ openssh_encoded, marker, enc_flag = PEM.decode(text_encoded, passphrase)
+ result = _import_openssh_private_ecc(openssh_encoded, passphrase)
+ return result
+
+ elif encoded.startswith(b'-----'):
+
+ text_encoded = tostr(encoded)
+
+ # Remove any EC PARAMETERS section
+ # Ignore its content because the curve type must be already given in the key
+ ecparams_start = "-----BEGIN EC PARAMETERS-----"
+ ecparams_end = "-----END EC PARAMETERS-----"
+ text_encoded = re.sub(ecparams_start + ".*?" + ecparams_end, "",
+ text_encoded,
+ flags=re.DOTALL)
+
+ der_encoded, marker, enc_flag = PEM.decode(text_encoded, passphrase)
+ if enc_flag:
+ passphrase = None
+ try:
+ result = _import_der(der_encoded, passphrase)
+ except UnsupportedEccFeature as uef:
+ raise uef
+ except ValueError:
+ raise ValueError("Invalid DER encoding inside the PEM file")
+ return result
+
+ # OpenSSH
+ if encoded.startswith((b'ecdsa-sha2-', b'ssh-ed25519')):
+ return _import_openssh_public(encoded)
+
+ # DER
+ if len(encoded) > 0 and bord(encoded[0]) == 0x30:
+ return _import_der(encoded, passphrase)
+
+ # SEC1
+ if len(encoded) > 0 and bord(encoded[0]) in b'\x02\x03\x04':
+ if curve_name is None:
+ raise ValueError("No curve name was provided")
+ return _import_public_der(encoded, curve_name=curve_name)
+
+ raise ValueError("ECC key format is not supported")
+
+
+if __name__ == "__main__":
+
+ import time
+
+ d = 0xc51e4753afdec1e6b6c6a5b992f43f8dd0c7a8933072708b6522468b2ffb06fd
+
+ point = _curves['p256'].G.copy()
+ count = 3000
+
+ start = time.time()
+ for x in range(count):
+ pointX = point * d
+ print("(P-256 G)", (time.time() - start) / count * 1000, "ms")
+
+ start = time.time()
+ for x in range(count):
+ pointX = pointX * d
+ print("(P-256 arbitrary point)", (time.time() - start) / count * 1000, "ms")
diff --git a/lib/Crypto/PublicKey/ECC.pyi b/lib/Crypto/PublicKey/ECC.pyi
new file mode 100644
index 0000000..89f5a13
--- /dev/null
+++ b/lib/Crypto/PublicKey/ECC.pyi
@@ -0,0 +1,66 @@
+from typing import Union, Callable, Optional, NamedTuple, List, Tuple, Dict, Any
+
+from Crypto.Math.Numbers import Integer
+
+RNG = Callable[[int], bytes]
+
+class UnsupportedEccFeature(ValueError): ...
+class EccPoint(object):
+ def __init__(self, x: Union[int, Integer], y: Union[int, Integer], curve: Optional[str] = ...) -> None: ...
+ def set(self, point: EccPoint) -> EccPoint: ...
+ def __eq__(self, point: object) -> bool: ...
+ def __neg__(self) -> EccPoint: ...
+ def copy(self) -> EccPoint: ...
+ def is_point_at_infinity(self) -> bool: ...
+ def point_at_infinity(self) -> EccPoint: ...
+ @property
+ def x(self) -> int: ...
+ @property
+ def y(self) -> int: ...
+ @property
+ def xy(self) -> Tuple[int, int]: ...
+ def size_in_bytes(self) -> int: ...
+ def size_in_bits(self) -> int: ...
+ def double(self) -> EccPoint: ...
+ def __iadd__(self, point: EccPoint) -> EccPoint: ...
+ def __add__(self, point: EccPoint) -> EccPoint: ...
+ def __imul__(self, scalar: int) -> EccPoint: ...
+ def __mul__(self, scalar: int) -> EccPoint: ...
+
+class EccKey(object):
+ curve: str
+ def __init__(self, *, curve: str = ..., d: int = ..., point: EccPoint = ...) -> None: ...
+ def __eq__(self, other: object) -> bool: ...
+ def __repr__(self) -> str: ...
+ def has_private(self) -> bool: ...
+ @property
+ def d(self) -> int: ...
+ @property
+ def pointQ(self) -> EccPoint: ...
+ def public_key(self) -> EccKey: ...
+ def export_key(self, **kwargs: Union[str, bytes, bool]) -> Union[str,bytes]: ...
+
+
+_Curve = NamedTuple("_Curve", [('p', Integer),
+ ('order', Integer),
+ ('b', Integer),
+ ('Gx', Integer),
+ ('Gy', Integer),
+ ('G', EccPoint),
+ ('modulus_bits', int),
+ ('oid', str),
+ ('context', Any),
+ ('desc', str),
+ ('openssh', Union[str, None]),
+ ])
+
+_curves : Dict[str, _Curve]
+
+
+def generate(**kwargs: Union[str, RNG]) -> EccKey: ...
+def construct(**kwargs: Union[str, int]) -> EccKey: ...
+def import_key(encoded: Union[bytes, str],
+ passphrase: Optional[str]=None,
+ curve_name:Optional[str]=None) -> EccKey: ...
+def _import_ed25519_public_key(encoded: bytes) -> EccKey: ...
+def _import_ed448_public_key(encoded: bytes) -> EccKey: ...
diff --git a/lib/Crypto/PublicKey/ElGamal.py b/lib/Crypto/PublicKey/ElGamal.py
new file mode 100644
index 0000000..3b10840
--- /dev/null
+++ b/lib/Crypto/PublicKey/ElGamal.py
@@ -0,0 +1,286 @@
+#
+# ElGamal.py : ElGamal encryption/decryption and signatures
+#
+# Part of the Python Cryptography Toolkit
+#
+# Originally written by: A.M. Kuchling
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+__all__ = ['generate', 'construct', 'ElGamalKey']
+
+from Crypto import Random
+from Crypto.Math.Primality import ( generate_probable_safe_prime,
+ test_probable_prime, COMPOSITE )
+from Crypto.Math.Numbers import Integer
+
+# Generate an ElGamal key with N bits
+def generate(bits, randfunc):
+ """Randomly generate a fresh, new ElGamal key.
+
+ The key will be safe for use for both encryption and signature
+ (although it should be used for **only one** purpose).
+
+ Args:
+ bits (int):
+ Key length, or size (in bits) of the modulus *p*.
+ The recommended value is 2048.
+ randfunc (callable):
+ Random number generation function; it should accept
+ a single integer *N* and return a string of random
+ *N* random bytes.
+
+ Return:
+ an :class:`ElGamalKey` object
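+
+ A rough usage sketch (generating a 2048-bit safe prime can take a long time)::
+
+ from Crypto import Random
+ from Crypto.PublicKey import ElGamal
+
+ key = ElGamal.generate(2048, Random.new().read)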
+ """
+
+ obj=ElGamalKey()
+
+ # Generate a safe prime p
+ # See Algorithm 4.86 in Handbook of Applied Cryptography
+ obj.p = generate_probable_safe_prime(exact_bits=bits, randfunc=randfunc)
+ q = (obj.p - 1) >> 1
+
+ # Generate generator g
+ while 1:
+ # Choose a square residue; it will generate a cyclic group of order q.
+ obj.g = pow(Integer.random_range(min_inclusive=2,
+ max_exclusive=obj.p,
+ randfunc=randfunc), 2, obj.p)
+
+ # We must avoid g=2 because of Bleichenbacher's attack described
+ # in "Generating ElGamal signatures without knowning the secret key",
+ # 1996
+ if obj.g in (1, 2):
+ continue
+
+ # Discard g if it divides p-1 because of the attack described
+ # in Note 11.67 (iii) in HAC
+ if (obj.p - 1) % obj.g == 0:
+ continue
+
+ # g^{-1} must not divide p-1 because of Khadir's attack
+ # described in "Conditions of the generator for forging ElGamal
+ # signature", 2011
+ ginv = obj.g.inverse(obj.p)
+ if (obj.p - 1) % ginv == 0:
+ continue
+
+ # Found
+ break
+
+ # Generate private key x
+ obj.x = Integer.random_range(min_inclusive=2,
+ max_exclusive=obj.p-1,
+ randfunc=randfunc)
+ # Generate public key y
+ obj.y = pow(obj.g, obj.x, obj.p)
+ return obj
+
+def construct(tup):
+ r"""Construct an ElGamal key from a tuple of valid ElGamal components.
+
+ The modulus *p* must be a prime.
+ The following conditions must apply:
+
+ .. math::
+
+ \begin{align}
+ &1 < g < p-1 \\
+ &g^{p-1} = 1 \text{ mod } p \\
+ &1 < x < p-1 \\
+ &g^x = y \text{ mod } p
+ \end{align}
+
+ Args:
+ tup (tuple):
+ A tuple with either 3 or 4 integers,
+ in the following order:
+
+ 1. Modulus (*p*).
+ 2. Generator (*g*).
+ 3. Public key (*y*).
+ 4. Private key (*x*). Optional.
+
+ Raises:
+ ValueError: when the key being imported fails the most basic ElGamal validity checks.
+
+ Returns:
+ an :class:`ElGamalKey` object
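+
+ A minimal sketch, reusing the components of an existing key::
+
+ from Crypto import Random
+ from Crypto.PublicKey import ElGamal
+
+ key = ElGamal.generate(2048, Random.new().read)
+ # Rebuild only the public half from (p, g, y)
+ public_key = ElGamal.construct((key.p, key.g, key.y))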
+ """
+
+ obj=ElGamalKey()
+ if len(tup) not in [3,4]:
+ raise ValueError('argument for construct() wrong length')
+ for i in range(len(tup)):
+ field = obj._keydata[i]
+ setattr(obj, field, Integer(tup[i]))
+
+ fmt_error = test_probable_prime(obj.p) == COMPOSITE
+ fmt_error |= obj.g<=1 or obj.g>=obj.p
+ fmt_error |= pow(obj.g, obj.p-1, obj.p)!=1
+ fmt_error |= obj.y<1 or obj.y>=obj.p
+ if len(tup)==4:
+ fmt_error |= obj.x<=1 or obj.x>=obj.p
+ fmt_error |= pow(obj.g, obj.x, obj.p)!=obj.y
+
+ if fmt_error:
+ raise ValueError("Invalid ElGamal key components")
+
+ return obj
+
+class ElGamalKey(object):
+ r"""Class defining an ElGamal key.
+ Do not instantiate directly.
+ Use :func:`generate` or :func:`construct` instead.
+
+ :ivar p: Modulus
+ :vartype p: integer
+
+ :ivar g: Generator
+ :vartype g: integer
+
+ :ivar y: Public key component
+ :vartype y: integer
+
+ :ivar x: Private key component
+ :vartype x: integer
+ """
+
+ #: Dictionary of ElGamal parameters.
+ #:
+ #: A public key will only have the following entries:
+ #:
+ #: - **y**, the public key.
+ #: - **g**, the generator.
+ #: - **p**, the modulus.
+ #:
+ #: A private key will also have:
+ #:
+ #: - **x**, the private key.
+ _keydata=['p', 'g', 'y', 'x']
+
+ def __init__(self, randfunc=None):
+ if randfunc is None:
+ randfunc = Random.new().read
+ self._randfunc = randfunc
+
+ def _encrypt(self, M, K):
+ a=pow(self.g, K, self.p)
+ b=( pow(self.y, K, self.p)*M ) % self.p
+ return [int(a), int(b)]
+
+ def _decrypt(self, M):
+ if (not hasattr(self, 'x')):
+ raise TypeError('Private key not available in this object')
+ r = Integer.random_range(min_inclusive=2,
+ max_exclusive=self.p-1,
+ randfunc=self._randfunc)
+ a_blind = (pow(self.g, r, self.p) * M[0]) % self.p
+ ax=pow(a_blind, self.x, self.p)
+ plaintext_blind = (ax.inverse(self.p) * M[1] ) % self.p
+ plaintext = (plaintext_blind * pow(self.y, r, self.p)) % self.p
+ return int(plaintext)
+
+ def _sign(self, M, K):
+ if (not hasattr(self, 'x')):
+ raise TypeError('Private key not available in this object')
+ p1=self.p-1
+ K = Integer(K)
+ if (K.gcd(p1)!=1):
+ raise ValueError('Bad K value: GCD(K,p-1)!=1')
+ a=pow(self.g, K, self.p)
+ t=(Integer(M)-self.x*a) % p1
+ while t<0: t=t+p1
+ b=(t*K.inverse(p1)) % p1
+ return [int(a), int(b)]
+
+ def _verify(self, M, sig):
+ sig = [Integer(x) for x in sig]
+ if sig[0]<1 or sig[0]>self.p-1:
+ return 0
+ v1=pow(self.y, sig[0], self.p)
+ v1=(v1*pow(sig[0], sig[1], self.p)) % self.p
+ v2=pow(self.g, M, self.p)
+ if v1==v2:
+ return 1
+ return 0
+
+ def has_private(self):
+ """Whether this is an ElGamal private key"""
+
+ if hasattr(self, 'x'):
+ return 1
+ else:
+ return 0
+
+ def can_encrypt(self):
+ return True
+
+ def can_sign(self):
+ return True
+
+ def publickey(self):
+ """A matching ElGamal public key.
+
+ Returns:
+ a new :class:`ElGamalKey` object
+ """
+ return construct((self.p, self.g, self.y))
+
+ def __eq__(self, other):
+ if bool(self.has_private()) != bool(other.has_private()):
+ return False
+
+ result = True
+ for comp in self._keydata:
+ result = result and (getattr(self, comp, None) ==
+ getattr(other, comp, None))
+ return result
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __getstate__(self):
+ # ElGamal key is not picklable
+ from pickle import PicklingError
+ raise PicklingError
+
+ # Methods defined in PyCrypto that we don't support anymore
+
+ def sign(self, M, K):
+ raise NotImplementedError
+
+ def verify(self, M, signature):
+ raise NotImplementedError
+
+ def encrypt(self, plaintext, K):
+ raise NotImplementedError
+
+ def decrypt(self, ciphertext):
+ raise NotImplementedError
+
+ def blind(self, M, B):
+ raise NotImplementedError
+
+ def unblind(self, M, B):
+ raise NotImplementedError
+
+ def size(self):
+ raise NotImplementedError
diff --git a/lib/Crypto/PublicKey/ElGamal.pyi b/lib/Crypto/PublicKey/ElGamal.pyi
new file mode 100644
index 0000000..9048531
--- /dev/null
+++ b/lib/Crypto/PublicKey/ElGamal.pyi
@@ -0,0 +1,18 @@
+from typing import Callable, Union, Tuple, Optional
+
+__all__ = ['generate', 'construct', 'ElGamalKey']
+
+RNG = Callable[[int], bytes]
+
+def generate(bits: int, randfunc: RNG) -> ElGamalKey: ...
+def construct(tup: Union[Tuple[int, int, int], Tuple[int, int, int, int]]) -> ElGamalKey: ...
+
+class ElGamalKey(object):
+ def __init__(self, randfunc: Optional[RNG]=None) -> None: ...
+ def has_private(self) -> bool: ...
+ def can_encrypt(self) -> bool: ...
+ def can_sign(self) -> bool: ...
+ def publickey(self) -> ElGamalKey: ...
+ def __eq__(self, other: object) -> bool: ...
+ def __ne__(self, other: object) -> bool: ...
+ def __getstate__(self) -> None: ...
diff --git a/lib/Crypto/PublicKey/RSA.py b/lib/Crypto/PublicKey/RSA.py
new file mode 100644
index 0000000..0f5e589
--- /dev/null
+++ b/lib/Crypto/PublicKey/RSA.py
@@ -0,0 +1,802 @@
+# -*- coding: utf-8 -*-
+# ===================================================================
+#
+# Copyright (c) 2016, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+__all__ = ['generate', 'construct', 'import_key',
+ 'RsaKey', 'oid']
+
+import binascii
+import struct
+
+from Crypto import Random
+from Crypto.Util.py3compat import tobytes, bord, tostr
+from Crypto.Util.asn1 import DerSequence, DerNull
+
+from Crypto.Math.Numbers import Integer
+from Crypto.Math.Primality import (test_probable_prime,
+ generate_probable_prime, COMPOSITE)
+
+from Crypto.PublicKey import (_expand_subject_public_key_info,
+ _create_subject_public_key_info,
+ _extract_subject_public_key_info)
+
+
+class RsaKey(object):
+ r"""Class defining an actual RSA key.
+ Do not instantiate directly.
+ Use :func:`generate`, :func:`construct` or :func:`import_key` instead.
+
+ :ivar n: RSA modulus
+ :vartype n: integer
+
+ :ivar e: RSA public exponent
+ :vartype e: integer
+
+ :ivar d: RSA private exponent
+ :vartype d: integer
+
+ :ivar p: First factor of the RSA modulus
+ :vartype p: integer
+
+ :ivar q: Second factor of the RSA modulus
+ :vartype q: integer
+
+ :ivar u: Chinese remainder component (:math:`p^{-1} \text{mod } q`)
+ :vartype u: integer
+
+ :undocumented: exportKey, publickey
+ """
+
+ def __init__(self, **kwargs):
+ """Build an RSA key.
+
+ :Keywords:
+ n : integer
+ The modulus.
+ e : integer
+ The public exponent.
+ d : integer
+ The private exponent. Only required for private keys.
+ p : integer
+ The first factor of the modulus. Only required for private keys.
+ q : integer
+ The second factor of the modulus. Only required for private keys.
+ u : integer
+ The CRT coefficient (inverse of p modulo q). Only required for
+ private keys.
+ """
+
+ input_set = set(kwargs.keys())
+ public_set = set(('n', 'e'))
+ private_set = public_set | set(('p', 'q', 'd', 'u'))
+ if input_set not in (private_set, public_set):
+ raise ValueError("Some RSA components are missing")
+ for component, value in kwargs.items():
+ setattr(self, "_" + component, value)
+ if input_set == private_set:
+ self._dp = self._d % (self._p - 1) # = (e⁻¹) mod (p-1)
+ self._dq = self._d % (self._q - 1) # = (e⁻¹) mod (q-1)
+
+ @property
+ def n(self):
+ return int(self._n)
+
+ @property
+ def e(self):
+ return int(self._e)
+
+ @property
+ def d(self):
+ if not self.has_private():
+ raise AttributeError("No private exponent available for public keys")
+ return int(self._d)
+
+ @property
+ def p(self):
+ if not self.has_private():
+ raise AttributeError("No CRT component 'p' available for public keys")
+ return int(self._p)
+
+ @property
+ def q(self):
+ if not self.has_private():
+ raise AttributeError("No CRT component 'q' available for public keys")
+ return int(self._q)
+
+ @property
+ def u(self):
+ if not self.has_private():
+ raise AttributeError("No CRT component 'u' available for public keys")
+ return int(self._u)
+
+ def size_in_bits(self):
+ """Size of the RSA modulus in bits"""
+ return self._n.size_in_bits()
+
+ def size_in_bytes(self):
+ """The minimal amount of bytes that can hold the RSA modulus"""
+ return (self._n.size_in_bits() - 1) // 8 + 1
+
+ def _encrypt(self, plaintext):
+ if not 0 <= plaintext < self._n:
+ raise ValueError("Plaintext too large")
+ return int(pow(Integer(plaintext), self._e, self._n))
+
+ def _decrypt(self, ciphertext):
+ if not 0 <= ciphertext < self._n:
+ raise ValueError("Ciphertext too large")
+ if not self.has_private():
+ raise TypeError("This is not a private key")
+
+ # Blinded RSA decryption (to prevent timing attacks):
+ # Step 1: Generate random secret blinding factor r,
+ # such that 0 < r < n-1
+ r = Integer.random_range(min_inclusive=1, max_exclusive=self._n)
+ # Step 2: Compute c' = c * r**e mod n
+ cp = Integer(ciphertext) * pow(r, self._e, self._n) % self._n
+ # Step 3: Compute m' = c'**d mod n (normal RSA decryption)
+ m1 = pow(cp, self._dp, self._p)
+ m2 = pow(cp, self._dq, self._q)
+ h = ((m2 - m1) * self._u) % self._q
+ mp = h * self._p + m1
+ # Step 4: Compute m = m' * (r**(-1)) mod n
+ result = (r.inverse(self._n) * mp) % self._n
+ # Verify no faults occurred
+ if ciphertext != pow(result, self._e, self._n):
+ raise ValueError("Fault detected in RSA decryption")
+ return result
+
+ def has_private(self):
+ """Whether this is an RSA private key"""
+
+ return hasattr(self, "_d")
+
+ def can_encrypt(self): # legacy
+ return True
+
+ def can_sign(self): # legacy
+ return True
+
+ def public_key(self):
+ """A matching RSA public key.
+
+ Returns:
+ a new :class:`RsaKey` object
+ """
+ return RsaKey(n=self._n, e=self._e)
+
+ def __eq__(self, other):
+ if self.has_private() != other.has_private():
+ return False
+ if self.n != other.n or self.e != other.e:
+ return False
+ if not self.has_private():
+ return True
+ return (self.d == other.d)
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ def __getstate__(self):
+ # RSA key is not picklable
+ from pickle import PicklingError
+ raise PicklingError
+
+ def __repr__(self):
+ if self.has_private():
+ extra = ", d=%d, p=%d, q=%d, u=%d" % (int(self._d), int(self._p),
+ int(self._q), int(self._u))
+ else:
+ extra = ""
+ return "RsaKey(n=%d, e=%d%s)" % (int(self._n), int(self._e), extra)
+
+ def __str__(self):
+ if self.has_private():
+ key_type = "Private"
+ else:
+ key_type = "Public"
+ return "%s RSA key at 0x%X" % (key_type, id(self))
+
+ def export_key(self, format='PEM', passphrase=None, pkcs=1,
+ protection=None, randfunc=None):
+ """Export this RSA key.
+
+ Args:
+ format (string):
+ The format to use for wrapping the key:
+
+ - *'PEM'*. (*Default*) Text encoding, done according to `RFC1421`_/`RFC1423`_.
+ - *'DER'*. Binary encoding.
+ - *'OpenSSH'*. Textual encoding, done according to OpenSSH specification.
+ Only suitable for public keys (not private keys).
+
+ passphrase (string):
+ (*For private keys only*) The pass phrase used for protecting the output.
+
+ pkcs (integer):
+ (*For private keys only*) The ASN.1 structure to use for
+ serializing the key. Note that even in case of PEM
+ encoding, there is an inner ASN.1 DER structure.
+
+ With ``pkcs=1`` (*default*), the private key is encoded in a
+ simple `PKCS#1`_ structure (``RSAPrivateKey``).
+
+ With ``pkcs=8``, the private key is encoded in a `PKCS#8`_ structure
+ (``PrivateKeyInfo``).
+
+ .. note::
+ This parameter is ignored for a public key.
+ For DER and PEM, an ASN.1 DER ``SubjectPublicKeyInfo``
+ structure is always used.
+
+ protection (string):
+ (*For private keys only*)
+ The encryption scheme to use for protecting the private key.
+
+ If ``None`` (default), the behavior depends on :attr:`format`:
+
+ - For *'DER'*, the *PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC*
+ scheme is used. The following operations are performed:
+
+ 1. A 16 byte Triple DES key is derived from the passphrase
+ using :func:`Crypto.Protocol.KDF.PBKDF2` with 8 bytes salt,
+ and 1 000 iterations of :mod:`Crypto.Hash.HMAC`.
+ 2. The private key is encrypted using CBC.
+ 3. The encrypted key is encoded according to PKCS#8.
+
+ - For *'PEM'*, the obsolete PEM encryption scheme is used.
+ It is based on MD5 for key derivation, and Triple DES for encryption.
+
+ Specifying a value for :attr:`protection` is only meaningful for PKCS#8
+ (that is, ``pkcs=8``) and only if a pass phrase is present too.
+
+ The supported schemes for PKCS#8 are listed in the
+ :mod:`Crypto.IO.PKCS8` module (see :attr:`wrap_algo` parameter).
+
+ randfunc (callable):
+ A function that provides random bytes. Only used for PEM encoding.
+ The default is :func:`Crypto.Random.get_random_bytes`.
+
+ Returns:
+ byte string: the encoded key
+
+ Raises:
+ ValueError: when the format is unknown or when you try to encrypt a private
+ key with *DER* format and PKCS#1.
+
+ .. warning::
+ If you don't provide a pass phrase, the private key will be
+ exported in the clear!
+
+ .. _RFC1421: http://www.ietf.org/rfc/rfc1421.txt
+ .. _RFC1423: http://www.ietf.org/rfc/rfc1423.txt
+ .. _`PKCS#1`: http://www.ietf.org/rfc/rfc3447.txt
+ .. _`PKCS#8`: http://www.ietf.org/rfc/rfc5208.txt
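+
+ A minimal usage sketch (the passphrase below is only a placeholder)::
+
+ from Crypto.PublicKey import RSA
+
+ key = RSA.generate(2048)
+ # Private key, PKCS#8 + PEM, encrypted under a passphrase
+ private_pem = key.export_key(passphrase='secret', pkcs=8,
+ protection='PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC')
+ # Public half as a single OpenSSH line
+ public_ssh = key.public_key().export_key(format='OpenSSH')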
+ """
+
+ if passphrase is not None:
+ passphrase = tobytes(passphrase)
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ if format == 'OpenSSH':
+ e_bytes, n_bytes = [x.to_bytes() for x in (self._e, self._n)]
+ if bord(e_bytes[0]) & 0x80:
+ e_bytes = b'\x00' + e_bytes
+ if bord(n_bytes[0]) & 0x80:
+ n_bytes = b'\x00' + n_bytes
+ keyparts = [b'ssh-rsa', e_bytes, n_bytes]
+ keystring = b''.join([struct.pack(">I", len(kp)) + kp for kp in keyparts])
+ return b'ssh-rsa ' + binascii.b2a_base64(keystring)[:-1]
+
+ # DER format is always used, even in case of PEM, which simply
+ # encodes it into BASE64.
+ if self.has_private():
+ binary_key = DerSequence([0,
+ self.n,
+ self.e,
+ self.d,
+ self.p,
+ self.q,
+ self.d % (self.p-1),
+ self.d % (self.q-1),
+ Integer(self.q).inverse(self.p)
+ ]).encode()
+ if pkcs == 1:
+ key_type = 'RSA PRIVATE KEY'
+ if format == 'DER' and passphrase:
+ raise ValueError("PKCS#1 private key cannot be encrypted")
+ else: # PKCS#8
+ from Crypto.IO import PKCS8
+
+ if format == 'PEM' and protection is None:
+ key_type = 'PRIVATE KEY'
+ binary_key = PKCS8.wrap(binary_key, oid, None,
+ key_params=DerNull())
+ else:
+ key_type = 'ENCRYPTED PRIVATE KEY'
+ if not protection:
+ protection = 'PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC'
+ binary_key = PKCS8.wrap(binary_key, oid,
+ passphrase, protection,
+ key_params=DerNull())
+ passphrase = None
+ else:
+ key_type = "PUBLIC KEY"
+ binary_key = _create_subject_public_key_info(oid,
+ DerSequence([self.n,
+ self.e]),
+ DerNull()
+ )
+
+ if format == 'DER':
+ return binary_key
+ if format == 'PEM':
+ from Crypto.IO import PEM
+
+ pem_str = PEM.encode(binary_key, key_type, passphrase, randfunc)
+ return tobytes(pem_str)
+
+ raise ValueError("Unknown key format '%s'. Cannot export the RSA key." % format)
+
+ # Backward compatibility
+ exportKey = export_key
+ publickey = public_key
+
+ # Methods defined in PyCrypto that we don't support anymore
+ def sign(self, M, K):
+ raise NotImplementedError("Use module Crypto.Signature.pkcs1_15 instead")
+
+ def verify(self, M, signature):
+ raise NotImplementedError("Use module Crypto.Signature.pkcs1_15 instead")
+
+ def encrypt(self, plaintext, K):
+ raise NotImplementedError("Use module Crypto.Cipher.PKCS1_OAEP instead")
+
+ def decrypt(self, ciphertext):
+ raise NotImplementedError("Use module Crypto.Cipher.PKCS1_OAEP instead")
+
+ def blind(self, M, B):
+ raise NotImplementedError
+
+ def unblind(self, M, B):
+ raise NotImplementedError
+
+ def size(self):
+ raise NotImplementedError
+
+
+def generate(bits, randfunc=None, e=65537):
+ """Create a new RSA key pair.
+
+ The algorithm closely follows NIST `FIPS 186-4`_ in its
+ sections B.3.1 and B.3.3. The modulus is the product of
+ two non-strong probable primes.
+ Each prime passes a suitable number of Miller-Rabin tests
+ with random bases and a single Lucas test.
+
+ Args:
+ bits (integer):
+ Key length, or size (in bits) of the RSA modulus.
+ It must be at least 1024, but **2048 is recommended.**
+ The FIPS standard only defines 1024, 2048 and 3072.
+ randfunc (callable):
+ Function that returns random bytes.
+ The default is :func:`Crypto.Random.get_random_bytes`.
+ e (integer):
+ Public RSA exponent. It must be an odd positive integer.
+ It is typically a small number with very few ones in its
+ binary representation.
+ The FIPS standard requires the public exponent to be
+ at least 65537 (the default).
+
+ Returns: an RSA key object (:class:`RsaKey`, with private key).
+
+ .. _FIPS 186-4: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
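+
+ A rough usage sketch::
+
+ from Crypto.PublicKey import RSA
+
+ key = RSA.generate(2048)
+ private_pem = key.export_key()
+ public_pem = key.public_key().export_key()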
+ """
+
+ if bits < 1024:
+ raise ValueError("RSA modulus length must be >= 1024")
+ if e % 2 == 0 or e < 3:
+ raise ValueError("RSA public exponent must be a positive, odd integer larger than 2.")
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ d = n = Integer(1)
+ e = Integer(e)
+
+ while n.size_in_bits() != bits and d < (1 << (bits // 2)):
+ # Generate the prime factors of n: p and q.
+ # By construction, their product is always
+ # 2^{bits-1} < p*q < 2^bits.
+ size_q = bits // 2
+ size_p = bits - size_q
+
+ min_p = min_q = (Integer(1) << (2 * size_q - 1)).sqrt()
+ if size_q != size_p:
+ min_p = (Integer(1) << (2 * size_p - 1)).sqrt()
+
+ def filter_p(candidate):
+ return candidate > min_p and (candidate - 1).gcd(e) == 1
+
+ p = generate_probable_prime(exact_bits=size_p,
+ randfunc=randfunc,
+ prime_filter=filter_p)
+
+ min_distance = Integer(1) << (bits // 2 - 100)
+
+ def filter_q(candidate):
+ return (candidate > min_q and
+ (candidate - 1).gcd(e) == 1 and
+ abs(candidate - p) > min_distance)
+
+ q = generate_probable_prime(exact_bits=size_q,
+ randfunc=randfunc,
+ prime_filter=filter_q)
+
+ n = p * q
+ lcm = (p - 1).lcm(q - 1)
+ d = e.inverse(lcm)
+
+ if p > q:
+ p, q = q, p
+
+ u = p.inverse(q)
+
+ return RsaKey(n=n, e=e, d=d, p=p, q=q, u=u)
+
+
+def construct(rsa_components, consistency_check=True):
+ r"""Construct an RSA key from a tuple of valid RSA components.
+
+ The modulus **n** must be the product of two primes.
+ The public exponent **e** must be odd and larger than 1.
+
+ In case of a private key, the following equations must apply:
+
+ .. math::
+
+ \begin{align}
+ p*q &= n \\
+ e*d &\equiv 1 ( \text{mod lcm} [(p-1), (q-1)]) \\
+ p*u &\equiv 1 ( \text{mod } q)
+ \end{align}
+
+ Args:
+ rsa_components (tuple):
+ A tuple of integers, with at least 2 and no
+ more than 6 items. The items come in the following order:
+
+ 1. RSA modulus *n*.
+ 2. Public exponent *e*.
+ 3. Private exponent *d*.
+ Only required if the key is private.
+ 4. First factor of *n* (*p*).
+ Optional, but the other factor *q* must also be present.
+ 5. Second factor of *n* (*q*). Optional.
+ 6. CRT coefficient *u*, that is :math:`p^{-1} \text{mod }q`. Optional.
+
+ consistency_check (boolean):
+ If ``True``, the library will verify that the provided components
+ fulfil the main RSA properties.
+
+ Raises:
+ ValueError: when the key being imported fails the most basic RSA validity checks.
+
+ Returns: An RSA key object (:class:`RsaKey`).
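+
+ A minimal sketch, reusing the components of a freshly generated key::
+
+ from Crypto.PublicKey import RSA
+
+ key = RSA.generate(2048)
+ # Rebuild the full private key from its five main components
+ rebuilt = RSA.construct((key.n, key.e, key.d, key.p, key.q))
+ # Public-only key from the pair (n, e)
+ public_only = RSA.construct((key.n, key.e))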
+ """
+
+ class InputComps(object):
+ pass
+
+ input_comps = InputComps()
+ for (comp, value) in zip(('n', 'e', 'd', 'p', 'q', 'u'), rsa_components):
+ setattr(input_comps, comp, Integer(value))
+
+ n = input_comps.n
+ e = input_comps.e
+ if not hasattr(input_comps, 'd'):
+ key = RsaKey(n=n, e=e)
+ else:
+ d = input_comps.d
+ if hasattr(input_comps, 'q'):
+ p = input_comps.p
+ q = input_comps.q
+ else:
+ # Compute factors p and q from the private exponent d.
+ # We assume that n has no more than two factors.
+ # See 8.2.2(i) in Handbook of Applied Cryptography.
+ ktot = d * e - 1
+ # The quantity d*e-1 is a multiple of phi(n), even,
+ # and can be represented as t*2^s.
+ t = ktot
+ while t % 2 == 0:
+ t //= 2
+ # Cycle through all multiplicative inverses in Zn.
+ # The algorithm is non-deterministic, but there is a 50% chance
+ # any candidate a leads to successful factoring.
+ # See "Digitalized Signatures and Public Key Functions as Intractable
+ # as Factorization", M. Rabin, 1979
+ spotted = False
+ a = Integer(2)
+ while not spotted and a < 100:
+ k = Integer(t)
+ # Cycle through all values a^{t*2^i}=a^k
+ while k < ktot:
+ cand = pow(a, k, n)
+ # Check if a^k is a non-trivial root of unity (mod n)
+ if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
+ # We have found a number such that (cand-1)(cand+1)=0 (mod n).
+ # Either of the terms divides n.
+ p = Integer(n).gcd(cand + 1)
+ spotted = True
+ break
+ k *= 2
+ # This value was not any good... let's try another!
+ a += 2
+ if not spotted:
+ raise ValueError("Unable to compute factors p and q from exponent d.")
+ # Found !
+ assert ((n % p) == 0)
+ q = n // p
+
+ if hasattr(input_comps, 'u'):
+ u = input_comps.u
+ else:
+ u = p.inverse(q)
+
+ # Build key object
+ key = RsaKey(n=n, e=e, d=d, p=p, q=q, u=u)
+
+ # Verify consistency of the key
+ if consistency_check:
+
+ # Modulus and public exponent must be coprime
+ if e <= 1 or e >= n:
+ raise ValueError("Invalid RSA public exponent")
+ if Integer(n).gcd(e) != 1:
+ raise ValueError("RSA public exponent is not coprime to modulus")
+
+ # For RSA, modulus must be odd
+ if not n & 1:
+ raise ValueError("RSA modulus is not odd")
+
+ if key.has_private():
+ # Modulus and private exponent must be coprime
+ if d <= 1 or d >= n:
+ raise ValueError("Invalid RSA private exponent")
+ if Integer(n).gcd(d) != 1:
+ raise ValueError("RSA private exponent is not coprime to modulus")
+ # Modulus must be product of 2 primes
+ if p * q != n:
+ raise ValueError("RSA factors do not match modulus")
+ if test_probable_prime(p) == COMPOSITE:
+ raise ValueError("RSA factor p is composite")
+ if test_probable_prime(q) == COMPOSITE:
+ raise ValueError("RSA factor q is composite")
+ # See Carmichael theorem
+ phi = (p - 1) * (q - 1)
+ lcm = phi // (p - 1).gcd(q - 1)
+ if (e * d % int(lcm)) != 1:
+ raise ValueError("Invalid RSA condition")
+ if hasattr(key, 'u'):
+ # CRT coefficient
+ if u <= 1 or u >= q:
+ raise ValueError("Invalid RSA component u")
+ if (p * u % q) != 1:
+ raise ValueError("Invalid RSA component u with p")
+
+ return key
+
+
+def _import_pkcs1_private(encoded, *kwargs):
+ # RSAPrivateKey ::= SEQUENCE {
+ # version Version,
+ # modulus INTEGER, -- n
+ # publicExponent INTEGER, -- e
+ # privateExponent INTEGER, -- d
+ # prime1 INTEGER, -- p
+ # prime2 INTEGER, -- q
+ # exponent1 INTEGER, -- d mod (p-1)
+ # exponent2 INTEGER, -- d mod (q-1)
+ # coefficient INTEGER -- (inverse of q) mod p
+ # }
+ #
+ # Version ::= INTEGER
+ der = DerSequence().decode(encoded, nr_elements=9, only_ints_expected=True)
+ if der[0] != 0:
+ raise ValueError("No PKCS#1 encoding of an RSA private key")
+ return construct(der[1:6] + [Integer(der[4]).inverse(der[5])])
+
+
+def _import_pkcs1_public(encoded, *kwargs):
+ # RSAPublicKey ::= SEQUENCE {
+ # modulus INTEGER, -- n
+ # publicExponent INTEGER -- e
+ # }
+ der = DerSequence().decode(encoded, nr_elements=2, only_ints_expected=True)
+ return construct(der)
+
+
+def _import_subjectPublicKeyInfo(encoded, *kwargs):
+
+ algoid, encoded_key, params = _expand_subject_public_key_info(encoded)
+ if algoid != oid or params is not None:
+ raise ValueError("No RSA subjectPublicKeyInfo")
+ return _import_pkcs1_public(encoded_key)
+
+
+def _import_x509_cert(encoded, *kwargs):
+
+ sp_info = _extract_subject_public_key_info(encoded)
+ return _import_subjectPublicKeyInfo(sp_info)
+
+
+def _import_pkcs8(encoded, passphrase):
+ from Crypto.IO import PKCS8
+
+ k = PKCS8.unwrap(encoded, passphrase)
+ if k[0] != oid:
+ raise ValueError("No PKCS#8 encoded RSA key")
+ return _import_keyDER(k[1], passphrase)
+
+
+def _import_keyDER(extern_key, passphrase):
+ """Import an RSA key (public or private half), encoded in DER form."""
+
+ decodings = (_import_pkcs1_private,
+ _import_pkcs1_public,
+ _import_subjectPublicKeyInfo,
+ _import_x509_cert,
+ _import_pkcs8)
+
+ for decoding in decodings:
+ try:
+ return decoding(extern_key, passphrase)
+ except ValueError:
+ pass
+
+ raise ValueError("RSA key format is not supported")
+
+
+def _import_openssh_private_rsa(data, password):
+
+ from ._openssh import (import_openssh_private_generic,
+ read_bytes, read_string, check_padding)
+
+ ssh_name, decrypted = import_openssh_private_generic(data, password)
+
+ if ssh_name != "ssh-rsa":
+ raise ValueError("This SSH key is not RSA")
+
+ n, decrypted = read_bytes(decrypted)
+ e, decrypted = read_bytes(decrypted)
+ d, decrypted = read_bytes(decrypted)
+ iqmp, decrypted = read_bytes(decrypted)
+ p, decrypted = read_bytes(decrypted)
+ q, decrypted = read_bytes(decrypted)
+
+ _, padded = read_string(decrypted) # Comment
+ check_padding(padded)
+
+ build = [Integer.from_bytes(x) for x in (n, e, d, q, p, iqmp)]
+ return construct(build)
+
+
+def import_key(extern_key, passphrase=None):
+ """Import an RSA key (public or private).
+
+ Args:
+ extern_key (string or byte string):
+ The RSA key to import.
+
+ The following formats are supported for an RSA **public key**:
+
+ - X.509 certificate (binary or PEM format)
+ - X.509 ``subjectPublicKeyInfo`` DER SEQUENCE (binary or PEM
+ encoding)
+ - `PKCS#1`_ ``RSAPublicKey`` DER SEQUENCE (binary or PEM encoding)
+ - An OpenSSH line (e.g. the content of ``~/.ssh/id_rsa.pub``, ASCII)
+
+ The following formats are supported for an RSA **private key**:
+
+ - PKCS#1 ``RSAPrivateKey`` DER SEQUENCE (binary or PEM encoding)
+ - `PKCS#8`_ ``PrivateKeyInfo`` or ``EncryptedPrivateKeyInfo``
+ DER SEQUENCE (binary or PEM encoding)
+ - OpenSSH (text format, introduced in `OpenSSH 6.5`_)
+
+ For details about the PEM encoding, see `RFC1421`_/`RFC1423`_.
+
+ passphrase (string or byte string):
+ For private keys only, the pass phrase that encrypts the key.
+
+ Returns: An RSA key object (:class:`RsaKey`).
+
+ Raises:
+ ValueError/IndexError/TypeError:
+ When the given key cannot be parsed (possibly because the pass
+ phrase is wrong).
+
+ .. _RFC1421: http://www.ietf.org/rfc/rfc1421.txt
+ .. _RFC1423: http://www.ietf.org/rfc/rfc1423.txt
+ .. _`PKCS#1`: http://www.ietf.org/rfc/rfc3447.txt
+ .. _`PKCS#8`: http://www.ietf.org/rfc/rfc5208.txt
+ .. _`OpenSSH 6.5`: https://flak.tedunangst.com/post/new-openssh-key-format-and-bcrypt-pbkdf
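+
+ A minimal usage sketch (file name and passphrase are only placeholders)::
+
+ from Crypto.PublicKey import RSA
+
+ encoded = open('rsa_private.pem', 'rb').read()
+ key = RSA.import_key(encoded, passphrase='secret')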
+ """
+
+ from Crypto.IO import PEM
+
+ extern_key = tobytes(extern_key)
+ if passphrase is not None:
+ passphrase = tobytes(passphrase)
+
+ if extern_key.startswith(b'-----BEGIN OPENSSH PRIVATE KEY'):
+ text_encoded = tostr(extern_key)
+ openssh_encoded, marker, enc_flag = PEM.decode(text_encoded, passphrase)
+ result = _import_openssh_private_rsa(openssh_encoded, passphrase)
+ return result
+
+ if extern_key.startswith(b'-----'):
+ # This is probably a PEM encoded key.
+ (der, marker, enc_flag) = PEM.decode(tostr(extern_key), passphrase)
+ if enc_flag:
+ passphrase = None
+ return _import_keyDER(der, passphrase)
+
+ if extern_key.startswith(b'ssh-rsa '):
+ # This is probably an OpenSSH key
+ keystring = binascii.a2b_base64(extern_key.split(b' ')[1])
+ keyparts = []
+ while len(keystring) > 4:
+ length = struct.unpack(">I", keystring[:4])[0]
+ keyparts.append(keystring[4:4 + length])
+ keystring = keystring[4 + length:]
+ e = Integer.from_bytes(keyparts[1])
+ n = Integer.from_bytes(keyparts[2])
+ return construct([n, e])
+
+ if len(extern_key) > 0 and bord(extern_key[0]) == 0x30:
+ # This is probably a DER encoded key
+ return _import_keyDER(extern_key, passphrase)
+
+ raise ValueError("RSA key format is not supported")
+
+
+# Backward compatibility
+importKey = import_key
+
+#: `Object ID`_ for the RSA encryption algorithm. This OID often indicates
+#: a generic RSA key, even when such key will be actually used for digital
+#: signatures.
+#:
+#: .. _`Object ID`: http://www.alvestrand.no/objectid/1.2.840.113549.1.1.1.html
+oid = "1.2.840.113549.1.1.1"
diff --git a/lib/Crypto/PublicKey/RSA.pyi b/lib/Crypto/PublicKey/RSA.pyi
new file mode 100644
index 0000000..d436acf
--- /dev/null
+++ b/lib/Crypto/PublicKey/RSA.pyi
@@ -0,0 +1,51 @@
+from typing import Callable, Union, Tuple, Optional
+
+__all__ = ['generate', 'construct', 'import_key',
+ 'RsaKey', 'oid']
+
+RNG = Callable[[int], bytes]
+
+class RsaKey(object):
+ def __init__(self, **kwargs: int) -> None: ...
+ @property
+ def n(self) -> int: ...
+ @property
+ def e(self) -> int: ...
+ @property
+ def d(self) -> int: ...
+ @property
+ def p(self) -> int: ...
+ @property
+ def q(self) -> int: ...
+ @property
+ def u(self) -> int: ...
+ def size_in_bits(self) -> int: ...
+ def size_in_bytes(self) -> int: ...
+ def has_private(self) -> bool: ...
+ def can_encrypt(self) -> bool: ... # legacy
+    def can_sign(self) -> bool: ... # legacy
+ def public_key(self) -> RsaKey: ...
+ def __eq__(self, other: object) -> bool: ...
+ def __ne__(self, other: object) -> bool: ...
+ def __getstate__(self) -> None: ...
+ def __repr__(self) -> str: ...
+ def __str__(self) -> str: ...
+ def export_key(self, format: Optional[str]="PEM", passphrase: Optional[str]=None, pkcs: Optional[int]=1,
+ protection: Optional[str]=None, randfunc: Optional[RNG]=None) -> bytes: ...
+
+ # Backward compatibility
+ exportKey = export_key
+ publickey = public_key
+
+def generate(bits: int, randfunc: Optional[RNG]=None, e: Optional[int]=65537) -> RsaKey: ...
+def construct(rsa_components: Union[Tuple[int, int], # n, e
+ Tuple[int, int, int], # n, e, d
+ Tuple[int, int, int, int, int], # n, e, d, p, q
+ Tuple[int, int, int, int, int, int]], # n, e, d, p, q, crt_q
+ consistency_check: Optional[bool]=True) -> RsaKey: ...
+def import_key(extern_key: Union[str, bytes], passphrase: Optional[str]=None) -> RsaKey: ...
+
+# Backward compatibility
+importKey = import_key
+
+oid: str
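As a quick cross-check of the stub signatures above against the runtime module, a small sketch (the 2048-bit size is only an example):

    from Crypto.PublicKey import RSA

    key = RSA.generate(2048)              # RsaKey carrying the private parts
    pub = key.public_key()                # public half only

    # construct() accepts the two-element (n, e) form from the stub
    rebuilt = RSA.construct((pub.n, pub.e))
    assert rebuilt == pub and not rebuilt.has_private()

    pem = key.export_key(format="PEM")    # bytes, "-----BEGIN RSA PRIVATE KEY-----..."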
diff --git a/lib/Crypto/PublicKey/__init__.py b/lib/Crypto/PublicKey/__init__.py
new file mode 100644
index 0000000..cf3a238
--- /dev/null
+++ b/lib/Crypto/PublicKey/__init__.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from Crypto.Util.asn1 import (DerSequence, DerInteger, DerBitString,
+ DerObjectId, DerNull)
+
+
+def _expand_subject_public_key_info(encoded):
+ """Parse a SubjectPublicKeyInfo structure.
+
+ It returns a triple with:
+ * OID (string)
+ * encoded public key (bytes)
+ * Algorithm parameters (bytes or None)
+ """
+
+ #
+ # SubjectPublicKeyInfo ::= SEQUENCE {
+ # algorithm AlgorithmIdentifier,
+ # subjectPublicKey BIT STRING
+ # }
+ #
+ # AlgorithmIdentifier ::= SEQUENCE {
+ # algorithm OBJECT IDENTIFIER,
+ # parameters ANY DEFINED BY algorithm OPTIONAL
+ # }
+ #
+
+ spki = DerSequence().decode(encoded, nr_elements=2)
+ algo = DerSequence().decode(spki[0], nr_elements=(1,2))
+ algo_oid = DerObjectId().decode(algo[0])
+ spk = DerBitString().decode(spki[1]).value
+
+ if len(algo) == 1:
+ algo_params = None
+ else:
+ try:
+ DerNull().decode(algo[1])
+ algo_params = None
+ except:
+ algo_params = algo[1]
+
+ return algo_oid.value, spk, algo_params
+
+
+def _create_subject_public_key_info(algo_oid, public_key, params):
+
+ if params is None:
+ algorithm = DerSequence([DerObjectId(algo_oid)])
+ else:
+ algorithm = DerSequence([DerObjectId(algo_oid), params])
+
+ spki = DerSequence([algorithm,
+ DerBitString(public_key)
+ ])
+ return spki.encode()
+
+
+def _extract_subject_public_key_info(x509_certificate):
+ """Extract subjectPublicKeyInfo from a DER X.509 certificate."""
+
+ certificate = DerSequence().decode(x509_certificate, nr_elements=3)
+ tbs_certificate = DerSequence().decode(certificate[0],
+ nr_elements=range(6, 11))
+
+ index = 5
+ try:
+ tbs_certificate[0] + 1
+ # Version not present
+ version = 1
+ except TypeError:
+ version = DerInteger(explicit=0).decode(tbs_certificate[0]).value
+ if version not in (2, 3):
+ raise ValueError("Incorrect X.509 certificate version")
+ index = 6
+
+ return tbs_certificate[index]
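The two SubjectPublicKeyInfo helpers above are inverses of one another; a rough round-trip sketch of these semi-private functions, using an arbitrary DER blob as the key payload:

    from Crypto.PublicKey import (_create_subject_public_key_info,
                                  _expand_subject_public_key_info)

    rsa_oid = "1.2.840.113549.1.1.1"
    fake_key = b"\x30\x03\x02\x01\x05"        # any DER payload works for the demo

    spki = _create_subject_public_key_info(rsa_oid, fake_key, None)
    oid, key_bytes, params = _expand_subject_public_key_info(spki)
    assert (oid, key_bytes, params) == (rsa_oid, fake_key, None)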
diff --git a/lib/Crypto/PublicKey/__init__.pyi b/lib/Crypto/PublicKey/__init__.pyi
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/Crypto/PublicKey/__init__.pyi
diff --git a/lib/Crypto/PublicKey/_ec_ws.abi3.so b/lib/Crypto/PublicKey/_ec_ws.abi3.so
new file mode 100755
index 0000000..46ec795
--- /dev/null
+++ b/lib/Crypto/PublicKey/_ec_ws.abi3.so
Binary files differ
diff --git a/lib/Crypto/PublicKey/_ed25519.abi3.so b/lib/Crypto/PublicKey/_ed25519.abi3.so
new file mode 100755
index 0000000..891ec44
--- /dev/null
+++ b/lib/Crypto/PublicKey/_ed25519.abi3.so
Binary files differ
diff --git a/lib/Crypto/PublicKey/_ed448.abi3.so b/lib/Crypto/PublicKey/_ed448.abi3.so
new file mode 100755
index 0000000..a3ddd4b
--- /dev/null
+++ b/lib/Crypto/PublicKey/_ed448.abi3.so
Binary files differ
diff --git a/lib/Crypto/PublicKey/_openssh.py b/lib/Crypto/PublicKey/_openssh.py
new file mode 100644
index 0000000..88dacfc
--- /dev/null
+++ b/lib/Crypto/PublicKey/_openssh.py
@@ -0,0 +1,135 @@
+# ===================================================================
+#
+# Copyright (c) 2019, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import struct
+
+from Crypto.Cipher import AES
+from Crypto.Hash import SHA512
+from Crypto.Protocol.KDF import _bcrypt_hash
+from Crypto.Util.strxor import strxor
+from Crypto.Util.py3compat import tostr, bchr, bord
+
+
+def read_int4(data):
+ if len(data) < 4:
+ raise ValueError("Insufficient data")
+ value = struct.unpack(">I", data[:4])[0]
+ return value, data[4:]
+
+
+def read_bytes(data):
+ size, data = read_int4(data)
+ if len(data) < size:
+ raise ValueError("Insufficient data (V)")
+ return data[:size], data[size:]
+
+
+def read_string(data):
+ s, d = read_bytes(data)
+ return tostr(s), d
+
+
+def check_padding(pad):
+ for v, x in enumerate(pad):
+ if bord(x) != ((v + 1) & 0xFF):
+ raise ValueError("Incorrect padding")
+
+
+def import_openssh_private_generic(data, password):
+ # https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.key?annotate=HEAD
+ # https://github.com/openssh/openssh-portable/blob/master/sshkey.c
+ # https://coolaj86.com/articles/the-openssh-private-key-format/
+ # https://coolaj86.com/articles/the-ssh-public-key-format/
+
+ if not data.startswith(b'openssh-key-v1\x00'):
+ raise ValueError("Incorrect magic value")
+ data = data[15:]
+
+ ciphername, data = read_string(data)
+ kdfname, data = read_string(data)
+ kdfoptions, data = read_bytes(data)
+ number_of_keys, data = read_int4(data)
+
+ if number_of_keys != 1:
+ raise ValueError("We only handle 1 key at a time")
+
+ _, data = read_string(data) # Public key
+ encrypted, data = read_bytes(data)
+ if data:
+ raise ValueError("Too much data")
+
+ if len(encrypted) % 8 != 0:
+ raise ValueError("Incorrect payload length")
+
+ # Decrypt if necessary
+ if ciphername == 'none':
+ decrypted = encrypted
+ else:
+ if (ciphername, kdfname) != ('aes256-ctr', 'bcrypt'):
+ raise ValueError("Unsupported encryption scheme %s/%s" % (ciphername, kdfname))
+
+ salt, kdfoptions = read_bytes(kdfoptions)
+ iterations, kdfoptions = read_int4(kdfoptions)
+
+ if len(salt) != 16:
+ raise ValueError("Incorrect salt length")
+ if kdfoptions:
+ raise ValueError("Too much data in kdfoptions")
+
+ pwd_sha512 = SHA512.new(password).digest()
+ # We need 32+16 = 48 bytes, therefore 2 bcrypt outputs are sufficient
+ stripes = []
+ constant = b"OxychromaticBlowfishSwatDynamite"
+ for count in range(1, 3):
+ salt_sha512 = SHA512.new(salt + struct.pack(">I", count)).digest()
+ out_le = _bcrypt_hash(pwd_sha512, 6, salt_sha512, constant, False)
+ out = struct.pack("<IIIIIIII", *struct.unpack(">IIIIIIII", out_le))
+ acc = bytearray(out)
+ for _ in range(1, iterations):
+ out_le = _bcrypt_hash(pwd_sha512, 6, SHA512.new(out).digest(), constant, False)
+ out = struct.pack("<IIIIIIII", *struct.unpack(">IIIIIIII", out_le))
+ strxor(acc, out, output=acc)
+ stripes.append(acc[:24])
+
+ result = b"".join([bchr(a)+bchr(b) for (a, b) in zip(*stripes)])
+
+ cipher = AES.new(result[:32],
+ AES.MODE_CTR,
+ nonce=b"",
+ initial_value=result[32:32+16])
+ decrypted = cipher.decrypt(encrypted)
+
+ checkint1, decrypted = read_int4(decrypted)
+ checkint2, decrypted = read_int4(decrypted)
+ if checkint1 != checkint2:
+ raise ValueError("Incorrect checksum")
+ ssh_name, decrypted = read_string(decrypted)
+
+ return ssh_name, decrypted
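Every field of the openssh-key-v1 container is a 4-byte big-endian length followed by its payload, which is exactly what read_int4/read_bytes/read_string walk through above. A tiny framing sketch with made-up header values:

    import struct
    from Crypto.PublicKey._openssh import read_int4, read_string

    def frame(payload):
        # length-prefixed field, as used throughout the key container
        return struct.pack(">I", len(payload)) + payload

    blob = frame(b"aes256-ctr") + frame(b"bcrypt") + struct.pack(">I", 1)
    ciphername, blob = read_string(blob)      # -> "aes256-ctr"
    kdfname, blob = read_string(blob)         # -> "bcrypt"
    nkeys, blob = read_int4(blob)             # -> 1
    assert (ciphername, kdfname, nkeys, blob) == ("aes256-ctr", "bcrypt", 1, b"")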
diff --git a/lib/Crypto/PublicKey/_openssh.pyi b/lib/Crypto/PublicKey/_openssh.pyi
new file mode 100644
index 0000000..15f3677
--- /dev/null
+++ b/lib/Crypto/PublicKey/_openssh.pyi
@@ -0,0 +1,7 @@
+from typing import Tuple
+
+def read_int4(data: bytes) -> Tuple[int, bytes]: ...
+def read_bytes(data: bytes) -> Tuple[bytes, bytes]: ...
+def read_string(data: bytes) -> Tuple[str, bytes]: ...
+def check_padding(pad: bytes) -> None: ...
+def import_openssh_private_generic(data: bytes, password: bytes) -> Tuple[str, bytes]: ...
diff --git a/lib/Crypto/PublicKey/_x25519.abi3.so b/lib/Crypto/PublicKey/_x25519.abi3.so
new file mode 100755
index 0000000..afd3ee4
--- /dev/null
+++ b/lib/Crypto/PublicKey/_x25519.abi3.so
Binary files differ
diff --git a/lib/Crypto/Random/__init__.py b/lib/Crypto/Random/__init__.py
new file mode 100644
index 0000000..0f83a07
--- /dev/null
+++ b/lib/Crypto/Random/__init__.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+#
+# Random/__init__.py : PyCrypto random number generation
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+__all__ = ['new', 'get_random_bytes']
+
+from os import urandom
+
+class _UrandomRNG(object):
+
+ def read(self, n):
+ """Return a random byte string of the desired size."""
+ return urandom(n)
+
+ def flush(self):
+ """Method provided for backward compatibility only."""
+ pass
+
+ def reinit(self):
+ """Method provided for backward compatibility only."""
+ pass
+
+ def close(self):
+ """Method provided for backward compatibility only."""
+ pass
+
+
+def new(*args, **kwargs):
+ """Return a file-like object that outputs cryptographically random bytes."""
+ return _UrandomRNG()
+
+
+def atfork():
+ pass
+
+
+#: Function that returns a random byte string of the desired size.
+get_random_bytes = urandom
+
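Both entry points above delegate to os.urandom; a short sketch of the two equivalent call styles:

    from Crypto import Random
    from Crypto.Random import get_random_bytes

    nonce = get_random_bytes(16)      # direct call, the preferred spelling
    rng = Random.new()                # file-like wrapper kept for old callers
    iv = rng.read(16)
    assert len(nonce) == 16 and len(iv) == 16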
diff --git a/lib/Crypto/Random/__init__.pyi b/lib/Crypto/Random/__init__.pyi
new file mode 100644
index 0000000..ddc5b9b
--- /dev/null
+++ b/lib/Crypto/Random/__init__.pyi
@@ -0,0 +1,19 @@
+from typing import Any
+
+__all__ = ['new', 'get_random_bytes']
+
+from os import urandom
+
+class _UrandomRNG(object):
+
+    def read(self, n: int) -> bytes: ...
+ def flush(self) -> None: ...
+ def reinit(self) -> None: ...
+ def close(self) -> None: ...
+
+def new(*args: Any, **kwargs: Any) -> _UrandomRNG: ...
+
+def atfork() -> None: ...
+
+get_random_bytes = urandom
+
diff --git a/lib/Crypto/Random/random.py b/lib/Crypto/Random/random.py
new file mode 100644
index 0000000..5389b3b
--- /dev/null
+++ b/lib/Crypto/Random/random.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+#
+# Random/random.py : Strong alternative for the standard 'random' module
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+__all__ = ['StrongRandom', 'getrandbits', 'randrange', 'randint', 'choice', 'shuffle', 'sample']
+
+from Crypto import Random
+
+from Crypto.Util.py3compat import is_native_int
+
+class StrongRandom(object):
+ def __init__(self, rng=None, randfunc=None):
+ if randfunc is None and rng is None:
+ self._randfunc = None
+ elif randfunc is not None and rng is None:
+ self._randfunc = randfunc
+ elif randfunc is None and rng is not None:
+ self._randfunc = rng.read
+ else:
+ raise ValueError("Cannot specify both 'rng' and 'randfunc'")
+
+ def getrandbits(self, k):
+ """Return an integer with k random bits."""
+
+ if self._randfunc is None:
+ self._randfunc = Random.new().read
+ mask = (1 << k) - 1
+ return mask & bytes_to_long(self._randfunc(ceil_div(k, 8)))
+
+ def randrange(self, *args):
+ """randrange([start,] stop[, step]):
+ Return a randomly-selected element from range(start, stop, step)."""
+ if len(args) == 3:
+ (start, stop, step) = args
+ elif len(args) == 2:
+ (start, stop) = args
+ step = 1
+ elif len(args) == 1:
+ (stop,) = args
+ start = 0
+ step = 1
+ else:
+ raise TypeError("randrange expected at most 3 arguments, got %d" % (len(args),))
+ if (not is_native_int(start) or not is_native_int(stop) or not
+ is_native_int(step)):
+ raise TypeError("randrange requires integer arguments")
+ if step == 0:
+ raise ValueError("randrange step argument must not be zero")
+
+ num_choices = ceil_div(stop - start, step)
+ if num_choices < 0:
+ num_choices = 0
+ if num_choices < 1:
+ raise ValueError("empty range for randrange(%r, %r, %r)" % (start, stop, step))
+
+ # Pick a random number in the range of possible numbers
+ r = num_choices
+ while r >= num_choices:
+ r = self.getrandbits(size(num_choices))
+
+ return start + (step * r)
+
+ def randint(self, a, b):
+ """Return a random integer N such that a <= N <= b."""
+ if not is_native_int(a) or not is_native_int(b):
+ raise TypeError("randint requires integer arguments")
+ N = self.randrange(a, b+1)
+ assert a <= N <= b
+ return N
+
+ def choice(self, seq):
+ """Return a random element from a (non-empty) sequence.
+
+        If the sequence is empty, raises IndexError.
+ """
+ if len(seq) == 0:
+ raise IndexError("empty sequence")
+ return seq[self.randrange(len(seq))]
+
+ def shuffle(self, x):
+ """Shuffle the sequence in place."""
+ # Fisher-Yates shuffle. O(n)
+ # See http://en.wikipedia.org/wiki/Fisher-Yates_shuffle
+ # Working backwards from the end of the array, we choose a random item
+ # from the remaining items until all items have been chosen.
+ for i in range(len(x)-1, 0, -1): # iterate from len(x)-1 downto 1
+ j = self.randrange(0, i+1) # choose random j such that 0 <= j <= i
+ x[i], x[j] = x[j], x[i] # exchange x[i] and x[j]
+
+ def sample(self, population, k):
+ """Return a k-length list of unique elements chosen from the population sequence."""
+
+ num_choices = len(population)
+ if k > num_choices:
+ raise ValueError("sample larger than population")
+
+ retval = []
+ selected = {} # we emulate a set using a dict here
+ for i in range(k):
+ r = None
+ while r is None or r in selected:
+ r = self.randrange(num_choices)
+ retval.append(population[r])
+ selected[r] = 1
+ return retval
+
+_r = StrongRandom()
+getrandbits = _r.getrandbits
+randrange = _r.randrange
+randint = _r.randint
+choice = _r.choice
+shuffle = _r.shuffle
+sample = _r.sample
+
+# These are at the bottom to avoid problems with recursive imports
+from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes, size
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
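A brief sketch of the module-level helpers re-exported above (outputs are random, so only their ranges can be checked):

    from Crypto.Random import random as crypto_random

    k = crypto_random.getrandbits(64)               # 0 <= k < 2**64
    n = crypto_random.randrange(10, 100, 5)         # one of 10, 15, ..., 95
    die = crypto_random.randint(1, 6)
    hand = crypto_random.sample(list(range(52)), 5) # five distinct "cards"

    assert 0 <= k < 2**64 and n in range(10, 100, 5) and 1 <= die <= 6
    assert len(set(hand)) == 5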
diff --git a/lib/Crypto/Random/random.pyi b/lib/Crypto/Random/random.pyi
new file mode 100644
index 0000000..f873c4a
--- /dev/null
+++ b/lib/Crypto/Random/random.pyi
@@ -0,0 +1,20 @@
+from typing import Callable, Tuple, Union, Sequence, Any, Optional
+
+__all__ = ['StrongRandom', 'getrandbits', 'randrange', 'randint', 'choice', 'shuffle', 'sample']
+
+class StrongRandom(object):
+ def __init__(self, rng: Optional[Any]=None, randfunc: Optional[Callable]=None) -> None: ... # TODO What is rng?
+ def getrandbits(self, k: int) -> int: ...
+ def randrange(self, start: int, stop: int = ..., step: int = ...) -> int: ...
+ def randint(self, a: int, b: int) -> int: ...
+ def choice(self, seq: Sequence) -> object: ...
+ def shuffle(self, x: Sequence) -> None: ...
+ def sample(self, population: Sequence, k: int) -> list: ...
+
+_r = StrongRandom()
+getrandbits = _r.getrandbits
+randrange = _r.randrange
+randint = _r.randint
+choice = _r.choice
+shuffle = _r.shuffle
+sample = _r.sample
diff --git a/lib/Crypto/SelfTest/Cipher/__init__.py b/lib/Crypto/SelfTest/Cipher/__init__.py
new file mode 100644
index 0000000..05fc139
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/__init__.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/__init__.py: Self-test for cipher modules
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test for cipher modules"""
+
+__revision__ = "$Id$"
+
+def get_tests(config={}):
+ tests = []
+ from Crypto.SelfTest.Cipher import test_AES; tests += test_AES.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_ARC2; tests += test_ARC2.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_ARC4; tests += test_ARC4.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_Blowfish; tests += test_Blowfish.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_CAST; tests += test_CAST.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_DES3; tests += test_DES3.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_DES; tests += test_DES.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_Salsa20; tests += test_Salsa20.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_ChaCha20; tests += test_ChaCha20.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_ChaCha20_Poly1305; tests += test_ChaCha20_Poly1305.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_pkcs1_15; tests += test_pkcs1_15.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_pkcs1_oaep; tests += test_pkcs1_oaep.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_OCB; tests += test_OCB.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_CBC; tests += test_CBC.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_CFB; tests += test_CFB.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_OpenPGP; tests += test_OpenPGP.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_OFB; tests += test_OFB.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_CTR; tests += test_CTR.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_CCM; tests += test_CCM.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_EAX; tests += test_EAX.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_GCM; tests += test_GCM.get_tests(config=config)
+ from Crypto.SelfTest.Cipher import test_SIV; tests += test_SIV.get_tests(config=config)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
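get_tests() only aggregates TestCase instances from the per-cipher modules, so the suite can also be driven directly with unittest; a sketch:

    import unittest
    from Crypto.SelfTest.Cipher import get_tests

    suite = unittest.TestSuite(get_tests())
    result = unittest.TextTestRunner(verbosity=1).run(suite)
    print("cipher self-tests passed:", result.wasSuccessful())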
diff --git a/lib/Crypto/SelfTest/Cipher/common.py b/lib/Crypto/SelfTest/Cipher/common.py
new file mode 100644
index 0000000..c5bc755
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/common.py
@@ -0,0 +1,510 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/common.py: Common code for Crypto.SelfTest.Cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-testing for PyCrypto hash modules"""
+
+import unittest
+from binascii import a2b_hex, b2a_hex, hexlify
+
+from Crypto.Util.py3compat import b
+from Crypto.Util.strxor import strxor_c
+
+class _NoDefault: pass # sentinel object
+def _extract(d, k, default=_NoDefault):
+ """Get an item from a dictionary, and remove it from the dictionary."""
+ try:
+ retval = d[k]
+ except KeyError:
+ if default is _NoDefault:
+ raise
+ return default
+ del d[k]
+ return retval
+
+# Generic cipher test case
+class CipherSelfTest(unittest.TestCase):
+
+ def __init__(self, module, params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+
+ # Extract the parameters
+ params = params.copy()
+ self.description = _extract(params, 'description')
+ self.key = b(_extract(params, 'key'))
+ self.plaintext = b(_extract(params, 'plaintext'))
+ self.ciphertext = b(_extract(params, 'ciphertext'))
+ self.module_name = _extract(params, 'module_name', None)
+ self.assoc_data = _extract(params, 'assoc_data', None)
+ self.mac = _extract(params, 'mac', None)
+ if self.assoc_data:
+ self.mac = b(self.mac)
+
+ mode = _extract(params, 'mode', None)
+ self.mode_name = str(mode)
+
+ if mode is not None:
+ # Block cipher
+ self.mode = getattr(self.module, "MODE_" + mode)
+
+ self.iv = _extract(params, 'iv', None)
+ if self.iv is None:
+ self.iv = _extract(params, 'nonce', None)
+ if self.iv is not None:
+ self.iv = b(self.iv)
+
+ else:
+ # Stream cipher
+ self.mode = None
+ self.iv = _extract(params, 'iv', None)
+ if self.iv is not None:
+ self.iv = b(self.iv)
+
+ self.extra_params = params
+
+ def shortDescription(self):
+ return self.description
+
+ def _new(self):
+ params = self.extra_params.copy()
+ key = a2b_hex(self.key)
+
+ old_style = []
+ if self.mode is not None:
+ old_style = [ self.mode ]
+ if self.iv is not None:
+ old_style += [ a2b_hex(self.iv) ]
+
+ return self.module.new(key, *old_style, **params)
+
+ def isMode(self, name):
+ if not hasattr(self.module, "MODE_"+name):
+ return False
+ return self.mode == getattr(self.module, "MODE_"+name)
+
+ def runTest(self):
+ plaintext = a2b_hex(self.plaintext)
+ ciphertext = a2b_hex(self.ciphertext)
+ assoc_data = []
+ if self.assoc_data:
+ assoc_data = [ a2b_hex(b(x)) for x in self.assoc_data]
+
+ ct = None
+ pt = None
+
+ #
+ # Repeat the same encryption or decryption twice and verify
+ # that the result is always the same
+ #
+ for i in range(2):
+ cipher = self._new()
+ decipher = self._new()
+
+ # Only AEAD modes
+ for comp in assoc_data:
+ cipher.update(comp)
+ decipher.update(comp)
+
+ ctX = b2a_hex(cipher.encrypt(plaintext))
+ ptX = b2a_hex(decipher.decrypt(ciphertext))
+
+ if ct:
+ self.assertEqual(ct, ctX)
+ self.assertEqual(pt, ptX)
+ ct, pt = ctX, ptX
+
+ self.assertEqual(self.ciphertext, ct) # encrypt
+ self.assertEqual(self.plaintext, pt) # decrypt
+
+ if self.mac:
+ mac = b2a_hex(cipher.digest())
+ self.assertEqual(self.mac, mac)
+ decipher.verify(a2b_hex(self.mac))
+
+class CipherStreamingSelfTest(CipherSelfTest):
+
+ def shortDescription(self):
+ desc = self.module_name
+ if self.mode is not None:
+ desc += " in %s mode" % (self.mode_name,)
+ return "%s should behave like a stream cipher" % (desc,)
+
+ def runTest(self):
+ plaintext = a2b_hex(self.plaintext)
+ ciphertext = a2b_hex(self.ciphertext)
+
+ # The cipher should work like a stream cipher
+
+ # Test counter mode encryption, 3 bytes at a time
+ ct3 = []
+ cipher = self._new()
+ for i in range(0, len(plaintext), 3):
+ ct3.append(cipher.encrypt(plaintext[i:i+3]))
+ ct3 = b2a_hex(b("").join(ct3))
+ self.assertEqual(self.ciphertext, ct3) # encryption (3 bytes at a time)
+
+ # Test counter mode decryption, 3 bytes at a time
+ pt3 = []
+ cipher = self._new()
+ for i in range(0, len(ciphertext), 3):
+            pt3.append(cipher.decrypt(ciphertext[i:i+3]))
+ # PY3K: This is meant to be text, do not change to bytes (data)
+ pt3 = b2a_hex(b("").join(pt3))
+ self.assertEqual(self.plaintext, pt3) # decryption (3 bytes at a time)
+
+
+class RoundtripTest(unittest.TestCase):
+ def __init__(self, module, params):
+ from Crypto import Random
+ unittest.TestCase.__init__(self)
+ self.module = module
+ self.iv = Random.get_random_bytes(module.block_size)
+ self.key = b(params['key'])
+ self.plaintext = 100 * b(params['plaintext'])
+ self.module_name = params.get('module_name', None)
+
+ def shortDescription(self):
+ return """%s .decrypt() output of .encrypt() should not be garbled""" % (self.module_name,)
+
+ def runTest(self):
+
+ ## ECB mode
+ mode = self.module.MODE_ECB
+ encryption_cipher = self.module.new(a2b_hex(self.key), mode)
+ ciphertext = encryption_cipher.encrypt(self.plaintext)
+ decryption_cipher = self.module.new(a2b_hex(self.key), mode)
+ decrypted_plaintext = decryption_cipher.decrypt(ciphertext)
+ self.assertEqual(self.plaintext, decrypted_plaintext)
+
+
+class IVLengthTest(unittest.TestCase):
+ def __init__(self, module, params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+ self.key = b(params['key'])
+
+ def shortDescription(self):
+ return "Check that all modes except MODE_ECB and MODE_CTR require an IV of the proper length"
+
+ def runTest(self):
+ self.assertRaises(TypeError, self.module.new, a2b_hex(self.key),
+ self.module.MODE_ECB, b(""))
+
+ def _dummy_counter(self):
+ return "\0" * self.module.block_size
+
+
+class NoDefaultECBTest(unittest.TestCase):
+ def __init__(self, module, params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+ self.key = b(params['key'])
+
+ def runTest(self):
+ self.assertRaises(TypeError, self.module.new, a2b_hex(self.key))
+
+
+class BlockSizeTest(unittest.TestCase):
+ def __init__(self, module, params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+ self.key = a2b_hex(b(params['key']))
+
+ def runTest(self):
+ cipher = self.module.new(self.key, self.module.MODE_ECB)
+ self.assertEqual(cipher.block_size, self.module.block_size)
+
+
+class ByteArrayTest(unittest.TestCase):
+    """Verify we can use bytearrays for encrypting and decrypting"""
+
+ def __init__(self, module, params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+
+ # Extract the parameters
+ params = params.copy()
+ self.description = _extract(params, 'description')
+ self.key = b(_extract(params, 'key'))
+ self.plaintext = b(_extract(params, 'plaintext'))
+ self.ciphertext = b(_extract(params, 'ciphertext'))
+ self.module_name = _extract(params, 'module_name', None)
+ self.assoc_data = _extract(params, 'assoc_data', None)
+ self.mac = _extract(params, 'mac', None)
+ if self.assoc_data:
+ self.mac = b(self.mac)
+
+ mode = _extract(params, 'mode', None)
+ self.mode_name = str(mode)
+
+ if mode is not None:
+ # Block cipher
+ self.mode = getattr(self.module, "MODE_" + mode)
+
+ self.iv = _extract(params, 'iv', None)
+ if self.iv is None:
+ self.iv = _extract(params, 'nonce', None)
+ if self.iv is not None:
+ self.iv = b(self.iv)
+ else:
+ # Stream cipher
+ self.mode = None
+ self.iv = _extract(params, 'iv', None)
+ if self.iv is not None:
+ self.iv = b(self.iv)
+
+ self.extra_params = params
+
+ def _new(self):
+ params = self.extra_params.copy()
+ key = a2b_hex(self.key)
+
+ old_style = []
+ if self.mode is not None:
+ old_style = [ self.mode ]
+ if self.iv is not None:
+ old_style += [ a2b_hex(self.iv) ]
+
+ return self.module.new(key, *old_style, **params)
+
+ def runTest(self):
+
+ plaintext = a2b_hex(self.plaintext)
+ ciphertext = a2b_hex(self.ciphertext)
+ assoc_data = []
+ if self.assoc_data:
+ assoc_data = [ bytearray(a2b_hex(b(x))) for x in self.assoc_data]
+
+ cipher = self._new()
+ decipher = self._new()
+
+ # Only AEAD modes
+ for comp in assoc_data:
+ cipher.update(comp)
+ decipher.update(comp)
+
+ ct = b2a_hex(cipher.encrypt(bytearray(plaintext)))
+ pt = b2a_hex(decipher.decrypt(bytearray(ciphertext)))
+
+ self.assertEqual(self.ciphertext, ct) # encrypt
+ self.assertEqual(self.plaintext, pt) # decrypt
+
+ if self.mac:
+ mac = b2a_hex(cipher.digest())
+ self.assertEqual(self.mac, mac)
+ decipher.verify(bytearray(a2b_hex(self.mac)))
+
+
+class MemoryviewTest(unittest.TestCase):
+ """Verify we can use memoryviews for encrypting and decrypting"""
+
+ def __init__(self, module, params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+
+ # Extract the parameters
+ params = params.copy()
+ self.description = _extract(params, 'description')
+ self.key = b(_extract(params, 'key'))
+ self.plaintext = b(_extract(params, 'plaintext'))
+ self.ciphertext = b(_extract(params, 'ciphertext'))
+ self.module_name = _extract(params, 'module_name', None)
+ self.assoc_data = _extract(params, 'assoc_data', None)
+ self.mac = _extract(params, 'mac', None)
+ if self.assoc_data:
+ self.mac = b(self.mac)
+
+ mode = _extract(params, 'mode', None)
+ self.mode_name = str(mode)
+
+ if mode is not None:
+ # Block cipher
+ self.mode = getattr(self.module, "MODE_" + mode)
+
+ self.iv = _extract(params, 'iv', None)
+ if self.iv is None:
+ self.iv = _extract(params, 'nonce', None)
+ if self.iv is not None:
+ self.iv = b(self.iv)
+ else:
+ # Stream cipher
+ self.mode = None
+ self.iv = _extract(params, 'iv', None)
+ if self.iv is not None:
+ self.iv = b(self.iv)
+
+ self.extra_params = params
+
+ def _new(self):
+ params = self.extra_params.copy()
+ key = a2b_hex(self.key)
+
+ old_style = []
+ if self.mode is not None:
+ old_style = [ self.mode ]
+ if self.iv is not None:
+ old_style += [ a2b_hex(self.iv) ]
+
+ return self.module.new(key, *old_style, **params)
+
+ def runTest(self):
+
+ plaintext = a2b_hex(self.plaintext)
+ ciphertext = a2b_hex(self.ciphertext)
+ assoc_data = []
+ if self.assoc_data:
+ assoc_data = [ memoryview(a2b_hex(b(x))) for x in self.assoc_data]
+
+ cipher = self._new()
+ decipher = self._new()
+
+ # Only AEAD modes
+ for comp in assoc_data:
+ cipher.update(comp)
+ decipher.update(comp)
+
+ ct = b2a_hex(cipher.encrypt(memoryview(plaintext)))
+ pt = b2a_hex(decipher.decrypt(memoryview(ciphertext)))
+
+ self.assertEqual(self.ciphertext, ct) # encrypt
+ self.assertEqual(self.plaintext, pt) # decrypt
+
+ if self.mac:
+ mac = b2a_hex(cipher.digest())
+ self.assertEqual(self.mac, mac)
+ decipher.verify(memoryview(a2b_hex(self.mac)))
+
+
+def make_block_tests(module, module_name, test_data, additional_params=dict()):
+ tests = []
+ extra_tests_added = False
+ for i in range(len(test_data)):
+ row = test_data[i]
+
+ # Build the "params" dictionary with
+ # - plaintext
+ # - ciphertext
+ # - key
+ # - mode (default is ECB)
+ # - (optionally) description
+ # - (optionally) any other parameter that this cipher mode requires
+ params = {}
+ if len(row) == 3:
+ (params['plaintext'], params['ciphertext'], params['key']) = row
+ elif len(row) == 4:
+ (params['plaintext'], params['ciphertext'], params['key'], params['description']) = row
+ elif len(row) == 5:
+ (params['plaintext'], params['ciphertext'], params['key'], params['description'], extra_params) = row
+ params.update(extra_params)
+ else:
+ raise AssertionError("Unsupported tuple size %d" % (len(row),))
+
+        if "mode" not in params:
+ params["mode"] = "ECB"
+
+ # Build the display-name for the test
+ p2 = params.copy()
+ p_key = _extract(p2, 'key')
+ p_plaintext = _extract(p2, 'plaintext')
+ p_ciphertext = _extract(p2, 'ciphertext')
+ p_mode = _extract(p2, 'mode')
+ p_description = _extract(p2, 'description', None)
+
+ if p_description is not None:
+ description = p_description
+ elif p_mode == 'ECB' and not p2:
+ description = "p=%s, k=%s" % (p_plaintext, p_key)
+ else:
+ description = "p=%s, k=%s, %r" % (p_plaintext, p_key, p2)
+ name = "%s #%d: %s" % (module_name, i+1, description)
+ params['description'] = name
+ params['module_name'] = module_name
+ params.update(additional_params)
+
+ # Add extra test(s) to the test suite before the current test
+ if not extra_tests_added:
+ tests += [
+ RoundtripTest(module, params),
+ IVLengthTest(module, params),
+ NoDefaultECBTest(module, params),
+ ByteArrayTest(module, params),
+ BlockSizeTest(module, params),
+ ]
+ extra_tests_added = True
+
+ # Add the current test to the test suite
+ tests.append(CipherSelfTest(module, params))
+
+ return tests
+
+def make_stream_tests(module, module_name, test_data):
+ tests = []
+ extra_tests_added = False
+ for i in range(len(test_data)):
+ row = test_data[i]
+
+ # Build the "params" dictionary
+ params = {}
+ if len(row) == 3:
+ (params['plaintext'], params['ciphertext'], params['key']) = row
+ elif len(row) == 4:
+ (params['plaintext'], params['ciphertext'], params['key'], params['description']) = row
+ elif len(row) == 5:
+ (params['plaintext'], params['ciphertext'], params['key'], params['description'], extra_params) = row
+ params.update(extra_params)
+ else:
+ raise AssertionError("Unsupported tuple size %d" % (len(row),))
+
+ # Build the display-name for the test
+ p2 = params.copy()
+ p_key = _extract(p2, 'key')
+ p_plaintext = _extract(p2, 'plaintext')
+ p_ciphertext = _extract(p2, 'ciphertext')
+ p_description = _extract(p2, 'description', None)
+
+ if p_description is not None:
+ description = p_description
+ elif not p2:
+ description = "p=%s, k=%s" % (p_plaintext, p_key)
+ else:
+ description = "p=%s, k=%s, %r" % (p_plaintext, p_key, p2)
+ name = "%s #%d: %s" % (module_name, i+1, description)
+ params['description'] = name
+ params['module_name'] = module_name
+
+ # Add extra test(s) to the test suite before the current test
+ if not extra_tests_added:
+ tests += [
+ ByteArrayTest(module, params),
+ ]
+
+ tests.append(MemoryviewTest(module, params))
+ extra_tests_added = True
+
+ # Add the test to the test suite
+ tests.append(CipherSelfTest(module, params))
+ tests.append(CipherStreamingSelfTest(module, params))
+ return tests
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
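make_block_tests() above expands each test_data row into one CipherSelfTest plus the one-off extra tests; a minimal sketch that feeds it a single hand-written ECB vector (the same FIPS 197 AES-128 vector that opens test_AES.py below):

    import unittest
    from Crypto.Cipher import AES
    from Crypto.SelfTest.Cipher.common import make_block_tests

    # (plaintext, ciphertext, key, description) as hex strings; mode defaults to ECB
    row = ('00112233445566778899aabbccddeeff',
           '69c4e0d86a7b0430d8cdb78070b4c55a',
           '000102030405060708090a0b0c0d0e0f',
           'FIPS 197 C.1 (AES-128)')

    tests = make_block_tests(AES, "AES", [row])
    unittest.TextTestRunner().run(unittest.TestSuite(tests))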
diff --git a/lib/Crypto/SelfTest/Cipher/test_AES.py b/lib/Crypto/SelfTest/Cipher/test_AES.py
new file mode 100644
index 0000000..116deec
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_AES.py
@@ -0,0 +1,1351 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/test_AES.py: Self-test for the AES cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.AES"""
+
+from __future__ import print_function
+
+import unittest
+from Crypto.Hash import SHA256
+from Crypto.Cipher import AES
+from Crypto.Util.py3compat import *
+from binascii import hexlify
+
+# This is a list of (plaintext, ciphertext, key[, description[, params]]) tuples.
+test_data = [
+ # FIPS PUB 197 test vectors
+ # http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf
+
+ ('00112233445566778899aabbccddeeff', '69c4e0d86a7b0430d8cdb78070b4c55a',
+ '000102030405060708090a0b0c0d0e0f', 'FIPS 197 C.1 (AES-128)'),
+
+ ('00112233445566778899aabbccddeeff', 'dda97ca4864cdfe06eaf70a0ec0d7191',
+ '000102030405060708090a0b0c0d0e0f1011121314151617',
+ 'FIPS 197 C.2 (AES-192)'),
+
+ ('00112233445566778899aabbccddeeff', '8ea2b7ca516745bfeafc49904b496089',
+ '000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ 'FIPS 197 C.3 (AES-256)'),
+
+ # Rijndael128 test vectors
+ # Downloaded 2008-09-13 from
+ # http://www.iaik.tugraz.at/Research/krypto/AES/old/~rijmen/rijndael/testvalues.tar.gz
+
+ # ecb_tbl.txt, KEYSIZE=128
+ ('506812a45f08c889b97f5980038b8359', 'd8f532538289ef7d06b506a4fd5be9c9',
+ '00010203050607080a0b0c0d0f101112',
+ 'ecb-tbl-128: I=1'),
+ ('5c6d71ca30de8b8b00549984d2ec7d4b', '59ab30f4d4ee6e4ff9907ef65b1fb68c',
+ '14151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-128: I=2'),
+ ('53f3f4c64f8616e4e7c56199f48f21f6', 'bf1ed2fcb2af3fd41443b56d85025cb1',
+ '28292a2b2d2e2f30323334353738393a',
+ 'ecb-tbl-128: I=3'),
+ ('a1eb65a3487165fb0f1c27ff9959f703', '7316632d5c32233edcb0780560eae8b2',
+ '3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-128: I=4'),
+ ('3553ecf0b1739558b08e350a98a39bfa', '408c073e3e2538072b72625e68b8364b',
+ '50515253555657585a5b5c5d5f606162',
+ 'ecb-tbl-128: I=5'),
+ ('67429969490b9711ae2b01dc497afde8', 'e1f94dfa776597beaca262f2f6366fea',
+ '64656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-128: I=6'),
+ ('93385c1f2aec8bed192f5a8e161dd508', 'f29e986c6a1c27d7b29ffd7ee92b75f1',
+ '78797a7b7d7e7f80828384858788898a',
+ 'ecb-tbl-128: I=7'),
+ ('b5bf946be19beb8db3983b5f4c6e8ddb', '131c886a57f8c2e713aba6955e2b55b5',
+ '8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-128: I=8'),
+ ('41321ee10e21bd907227c4450ff42324', 'd2ab7662df9b8c740210e5eeb61c199d',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2',
+ 'ecb-tbl-128: I=9'),
+ ('00a82f59c91c8486d12c0a80124f6089', '14c10554b2859c484cab5869bbe7c470',
+ 'b4b5b6b7b9babbbcbebfc0c1c3c4c5c6',
+ 'ecb-tbl-128: I=10'),
+ ('7ce0fd076754691b4bbd9faf8a1372fe', 'db4d498f0a49cf55445d502c1f9ab3b5',
+ 'c8c9cacbcdcecfd0d2d3d4d5d7d8d9da',
+ 'ecb-tbl-128: I=11'),
+ ('23605a8243d07764541bc5ad355b3129', '6d96fef7d66590a77a77bb2056667f7f',
+ 'dcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-128: I=12'),
+ ('12a8cfa23ea764fd876232b4e842bc44', '316fb68edba736c53e78477bf913725c',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe010002',
+ 'ecb-tbl-128: I=13'),
+ ('bcaf32415e8308b3723e5fdd853ccc80', '6936f2b93af8397fd3a771fc011c8c37',
+ '04050607090a0b0c0e0f101113141516',
+ 'ecb-tbl-128: I=14'),
+ ('89afae685d801ad747ace91fc49adde0', 'f3f92f7a9c59179c1fcc2c2ba0b082cd',
+ '2c2d2e2f31323334363738393b3c3d3e',
+ 'ecb-tbl-128: I=15'),
+ ('f521d07b484357c4a69e76124a634216', '6a95ea659ee3889158e7a9152ff04ebc',
+ '40414243454647484a4b4c4d4f505152',
+ 'ecb-tbl-128: I=16'),
+ ('3e23b3bc065bcc152407e23896d77783', '1959338344e945670678a5d432c90b93',
+ '54555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-128: I=17'),
+ ('79f0fba002be1744670e7e99290d8f52', 'e49bddd2369b83ee66e6c75a1161b394',
+ '68696a6b6d6e6f70727374757778797a',
+ 'ecb-tbl-128: I=18'),
+ ('da23fe9d5bd63e1d72e3dafbe21a6c2a', 'd3388f19057ff704b70784164a74867d',
+ '7c7d7e7f81828384868788898b8c8d8e',
+ 'ecb-tbl-128: I=19'),
+ ('e3f5698ba90b6a022efd7db2c7e6c823', '23aa03e2d5e4cd24f3217e596480d1e1',
+ 'a4a5a6a7a9aaabacaeafb0b1b3b4b5b6',
+ 'ecb-tbl-128: I=20'),
+ ('bdc2691d4f1b73d2700679c3bcbf9c6e', 'c84113d68b666ab2a50a8bdb222e91b9',
+ 'e0e1e2e3e5e6e7e8eaebecedeff0f1f2',
+ 'ecb-tbl-128: I=21'),
+ ('ba74e02093217ee1ba1b42bd5624349a', 'ac02403981cd4340b507963db65cb7b6',
+ '08090a0b0d0e0f10121314151718191a',
+ 'ecb-tbl-128: I=22'),
+ ('b5c593b5851c57fbf8b3f57715e8f680', '8d1299236223359474011f6bf5088414',
+ '6c6d6e6f71727374767778797b7c7d7e',
+ 'ecb-tbl-128: I=23'),
+ ('3da9bd9cec072381788f9387c3bbf4ee', '5a1d6ab8605505f7977e55b9a54d9b90',
+ '80818283858687888a8b8c8d8f909192',
+ 'ecb-tbl-128: I=24'),
+ ('4197f3051121702ab65d316b3c637374', '72e9c2d519cf555e4208805aabe3b258',
+ '94959697999a9b9c9e9fa0a1a3a4a5a6',
+ 'ecb-tbl-128: I=25'),
+ ('9f46c62ec4f6ee3f6e8c62554bc48ab7', 'a8f3e81c4a23a39ef4d745dffe026e80',
+ 'a8a9aaabadaeafb0b2b3b4b5b7b8b9ba',
+ 'ecb-tbl-128: I=26'),
+ ('0220673fe9e699a4ebc8e0dbeb6979c8', '546f646449d31458f9eb4ef5483aee6c',
+ 'bcbdbebfc1c2c3c4c6c7c8c9cbcccdce',
+ 'ecb-tbl-128: I=27'),
+ ('b2b99171337ded9bc8c2c23ff6f18867', '4dbe4bc84ac797c0ee4efb7f1a07401c',
+ 'd0d1d2d3d5d6d7d8dadbdcdddfe0e1e2',
+ 'ecb-tbl-128: I=28'),
+ ('a7facf4e301e984e5efeefd645b23505', '25e10bfb411bbd4d625ac8795c8ca3b3',
+ 'e4e5e6e7e9eaebeceeeff0f1f3f4f5f6',
+ 'ecb-tbl-128: I=29'),
+ ('f7c762e4a9819160fd7acfb6c4eedcdd', '315637405054ec803614e43def177579',
+ 'f8f9fafbfdfefe00020304050708090a',
+ 'ecb-tbl-128: I=30'),
+ ('9b64fc21ea08709f4915436faa70f1be', '60c5bc8a1410247295c6386c59e572a8',
+ '0c0d0e0f11121314161718191b1c1d1e',
+ 'ecb-tbl-128: I=31'),
+ ('52af2c3de07ee6777f55a4abfc100b3f', '01366fc8ca52dfe055d6a00a76471ba6',
+ '20212223252627282a2b2c2d2f303132',
+ 'ecb-tbl-128: I=32'),
+ ('2fca001224386c57aa3f968cbe2c816f', 'ecc46595516ec612449c3f581e7d42ff',
+ '34353637393a3b3c3e3f404143444546',
+ 'ecb-tbl-128: I=33'),
+ ('4149c73658a4a9c564342755ee2c132f', '6b7ffe4c602a154b06ee9c7dab5331c9',
+ '48494a4b4d4e4f50525354555758595a',
+ 'ecb-tbl-128: I=34'),
+ ('af60005a00a1772f7c07a48a923c23d2', '7da234c14039a240dd02dd0fbf84eb67',
+ '5c5d5e5f61626364666768696b6c6d6e',
+ 'ecb-tbl-128: I=35'),
+ ('6fccbc28363759914b6f0280afaf20c6', 'c7dc217d9e3604ffe7e91f080ecd5a3a',
+ '70717273757677787a7b7c7d7f808182',
+ 'ecb-tbl-128: I=36'),
+ ('7d82a43ddf4fefa2fc5947499884d386', '37785901863f5c81260ea41e7580cda5',
+ '84858687898a8b8c8e8f909193949596',
+ 'ecb-tbl-128: I=37'),
+ ('5d5a990eaab9093afe4ce254dfa49ef9', 'a07b9338e92ed105e6ad720fccce9fe4',
+ '98999a9b9d9e9fa0a2a3a4a5a7a8a9aa',
+ 'ecb-tbl-128: I=38'),
+ ('4cd1e2fd3f4434b553aae453f0ed1a02', 'ae0fb9722418cc21a7da816bbc61322c',
+ 'acadaeafb1b2b3b4b6b7b8b9bbbcbdbe',
+ 'ecb-tbl-128: I=39'),
+ ('5a2c9a9641d4299125fa1b9363104b5e', 'c826a193080ff91ffb21f71d3373c877',
+ 'c0c1c2c3c5c6c7c8cacbcccdcfd0d1d2',
+ 'ecb-tbl-128: I=40'),
+ ('b517fe34c0fa217d341740bfd4fe8dd4', '1181b11b0e494e8d8b0aa6b1d5ac2c48',
+ 'd4d5d6d7d9dadbdcdedfe0e1e3e4e5e6',
+ 'ecb-tbl-128: I=41'),
+ ('014baf2278a69d331d5180103643e99a', '6743c3d1519ab4f2cd9a78ab09a511bd',
+ 'e8e9eaebedeeeff0f2f3f4f5f7f8f9fa',
+ 'ecb-tbl-128: I=42'),
+ ('b529bd8164f20d0aa443d4932116841c', 'dc55c076d52bacdf2eefd952946a439d',
+ 'fcfdfeff01020304060708090b0c0d0e',
+ 'ecb-tbl-128: I=43'),
+ ('2e596dcbb2f33d4216a1176d5bd1e456', '711b17b590ffc72b5c8e342b601e8003',
+ '10111213151617181a1b1c1d1f202122',
+ 'ecb-tbl-128: I=44'),
+ ('7274a1ea2b7ee2424e9a0e4673689143', '19983bb0950783a537e1339f4aa21c75',
+ '24252627292a2b2c2e2f303133343536',
+ 'ecb-tbl-128: I=45'),
+ ('ae20020bd4f13e9d90140bee3b5d26af', '3ba7762e15554169c0f4fa39164c410c',
+ '38393a3b3d3e3f40424344454748494a',
+ 'ecb-tbl-128: I=46'),
+ ('baac065da7ac26e855e79c8849d75a02', 'a0564c41245afca7af8aa2e0e588ea89',
+ '4c4d4e4f51525354565758595b5c5d5e',
+ 'ecb-tbl-128: I=47'),
+ ('7c917d8d1d45fab9e2540e28832540cc', '5e36a42a2e099f54ae85ecd92e2381ed',
+ '60616263656667686a6b6c6d6f707172',
+ 'ecb-tbl-128: I=48'),
+ ('bde6f89e16daadb0e847a2a614566a91', '770036f878cd0f6ca2268172f106f2fe',
+ '74757677797a7b7c7e7f808183848586',
+ 'ecb-tbl-128: I=49'),
+ ('c9de163725f1f5be44ebb1db51d07fbc', '7e4e03908b716116443ccf7c94e7c259',
+ '88898a8b8d8e8f90929394959798999a',
+ 'ecb-tbl-128: I=50'),
+ ('3af57a58f0c07dffa669572b521e2b92', '482735a48c30613a242dd494c7f9185d',
+ '9c9d9e9fa1a2a3a4a6a7a8a9abacadae',
+ 'ecb-tbl-128: I=51'),
+ ('3d5ebac306dde4604f1b4fbbbfcdae55', 'b4c0f6c9d4d7079addf9369fc081061d',
+ 'b0b1b2b3b5b6b7b8babbbcbdbfc0c1c2',
+ 'ecb-tbl-128: I=52'),
+ ('c2dfa91bceb76a1183c995020ac0b556', 'd5810fe0509ac53edcd74f89962e6270',
+ 'c4c5c6c7c9cacbcccecfd0d1d3d4d5d6',
+ 'ecb-tbl-128: I=53'),
+ ('c70f54305885e9a0746d01ec56c8596b', '03f17a16b3f91848269ecdd38ebb2165',
+ 'd8d9dadbdddedfe0e2e3e4e5e7e8e9ea',
+ 'ecb-tbl-128: I=54'),
+ ('c4f81b610e98012ce000182050c0c2b2', 'da1248c3180348bad4a93b4d9856c9df',
+ 'ecedeeeff1f2f3f4f6f7f8f9fbfcfdfe',
+ 'ecb-tbl-128: I=55'),
+ ('eaab86b1d02a95d7404eff67489f97d4', '3d10d7b63f3452c06cdf6cce18be0c2c',
+ '00010203050607080a0b0c0d0f101112',
+ 'ecb-tbl-128: I=56'),
+ ('7c55bdb40b88870b52bec3738de82886', '4ab823e7477dfddc0e6789018fcb6258',
+ '14151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-128: I=57'),
+ ('ba6eaa88371ff0a3bd875e3f2a975ce0', 'e6478ba56a77e70cfdaa5c843abde30e',
+ '28292a2b2d2e2f30323334353738393a',
+ 'ecb-tbl-128: I=58'),
+ ('08059130c4c24bd30cf0575e4e0373dc', '1673064895fbeaf7f09c5429ff75772d',
+ '3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-128: I=59'),
+ ('9a8eab004ef53093dfcf96f57e7eda82', '4488033ae9f2efd0ca9383bfca1a94e9',
+ '50515253555657585a5b5c5d5f606162',
+ 'ecb-tbl-128: I=60'),
+ ('0745b589e2400c25f117b1d796c28129', '978f3b8c8f9d6f46626cac3c0bcb9217',
+ '64656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-128: I=61'),
+ ('2f1777781216cec3f044f134b1b92bbe', 'e08c8a7e582e15e5527f1d9e2eecb236',
+ '78797a7b7d7e7f80828384858788898a',
+ 'ecb-tbl-128: I=62'),
+ ('353a779ffc541b3a3805d90ce17580fc', 'cec155b76ac5ffda4cf4f9ca91e49a7a',
+ '8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-128: I=63'),
+ ('1a1eae4415cefcf08c4ac1c8f68bea8f', 'd5ac7165763225dd2a38cdc6862c29ad',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2',
+ 'ecb-tbl-128: I=64'),
+ ('e6e7e4e5b0b3b2b5d4d5aaab16111013', '03680fe19f7ce7275452020be70e8204',
+ 'b4b5b6b7b9babbbcbebfc0c1c3c4c5c6',
+ 'ecb-tbl-128: I=65'),
+ ('f8f9fafbfbf8f9e677767170efe0e1e2', '461df740c9781c388e94bb861ceb54f6',
+ 'c8c9cacbcdcecfd0d2d3d4d5d7d8d9da',
+ 'ecb-tbl-128: I=66'),
+ ('63626160a1a2a3a445444b4a75727370', '451bd60367f96483042742219786a074',
+ 'dcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-128: I=67'),
+ ('717073720605040b2d2c2b2a05fafbf9', 'e4dfa42671a02e57ef173b85c0ea9f2b',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe010002',
+ 'ecb-tbl-128: I=68'),
+ ('78797a7beae9e8ef3736292891969794', 'ed11b89e76274282227d854700a78b9e',
+ '04050607090a0b0c0e0f101113141516',
+ 'ecb-tbl-128: I=69'),
+ ('838281803231300fdddcdbdaa0afaead', '433946eaa51ea47af33895f2b90b3b75',
+ '18191a1b1d1e1f20222324252728292a',
+ 'ecb-tbl-128: I=70'),
+ ('18191a1bbfbcbdba75747b7a7f78797a', '6bc6d616a5d7d0284a5910ab35022528',
+ '2c2d2e2f31323334363738393b3c3d3e',
+ 'ecb-tbl-128: I=71'),
+ ('848586879b989996a3a2a5a4849b9a99', 'd2a920ecfe919d354b5f49eae9719c98',
+ '40414243454647484a4b4c4d4f505152',
+ 'ecb-tbl-128: I=72'),
+ ('0001020322212027cacbf4f551565754', '3a061b17f6a92885efbd0676985b373d',
+ '54555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-128: I=73'),
+ ('cecfcccdafacadb2515057564a454447', 'fadeec16e33ea2f4688499d157e20d8f',
+ '68696a6b6d6e6f70727374757778797a',
+ 'ecb-tbl-128: I=74'),
+ ('92939091cdcecfc813121d1c80878685', '5cdefede59601aa3c3cda36fa6b1fa13',
+ '7c7d7e7f81828384868788898b8c8d8e',
+ 'ecb-tbl-128: I=75'),
+ ('d2d3d0d16f6c6d6259585f5ed1eeefec', '9574b00039844d92ebba7ee8719265f8',
+ '90919293959697989a9b9c9d9fa0a1a2',
+ 'ecb-tbl-128: I=76'),
+ ('acadaeaf878485820f0e1110d5d2d3d0', '9a9cf33758671787e5006928188643fa',
+ 'a4a5a6a7a9aaabacaeafb0b1b3b4b5b6',
+ 'ecb-tbl-128: I=77'),
+ ('9091929364676619e6e7e0e1757a7b78', '2cddd634c846ba66bb46cbfea4a674f9',
+ 'b8b9babbbdbebfc0c2c3c4c5c7c8c9ca',
+ 'ecb-tbl-128: I=78'),
+ ('babbb8b98a89888f74757a7b92959497', 'd28bae029393c3e7e26e9fafbbb4b98f',
+ 'cccdcecfd1d2d3d4d6d7d8d9dbdcddde',
+ 'ecb-tbl-128: I=79'),
+ ('8d8c8f8e6e6d6c633b3a3d3ccad5d4d7', 'ec27529b1bee0a9ab6a0d73ebc82e9b7',
+ 'e0e1e2e3e5e6e7e8eaebecedeff0f1f2',
+ 'ecb-tbl-128: I=80'),
+ ('86878485010203040808f7f767606162', '3cb25c09472aff6ee7e2b47ccd7ccb17',
+ 'f4f5f6f7f9fafbfcfefe010103040506',
+ 'ecb-tbl-128: I=81'),
+ ('8e8f8c8d656667788a8b8c8d010e0f0c', 'dee33103a7283370d725e44ca38f8fe5',
+ '08090a0b0d0e0f10121314151718191a',
+ 'ecb-tbl-128: I=82'),
+ ('c8c9cacb858687807a7b7475e7e0e1e2', '27f9bcd1aac64bffc11e7815702c1a69',
+ '1c1d1e1f21222324262728292b2c2d2e',
+ 'ecb-tbl-128: I=83'),
+ ('6d6c6f6e5053525d8c8d8a8badd2d3d0', '5df534ffad4ed0749a9988e9849d0021',
+ '30313233353637383a3b3c3d3f404142',
+ 'ecb-tbl-128: I=84'),
+ ('28292a2b393a3b3c0607181903040506', 'a48bee75db04fb60ca2b80f752a8421b',
+ '44454647494a4b4c4e4f505153545556',
+ 'ecb-tbl-128: I=85'),
+ ('a5a4a7a6b0b3b28ddbdadddcbdb2b3b0', '024c8cf70bc86ee5ce03678cb7af45f9',
+ '58595a5b5d5e5f60626364656768696a',
+ 'ecb-tbl-128: I=86'),
+ ('323330316467666130313e3f2c2b2a29', '3c19ac0f8a3a3862ce577831301e166b',
+ '6c6d6e6f71727374767778797b7c7d7e',
+ 'ecb-tbl-128: I=87'),
+ ('27262524080b0a05171611100b141516', 'c5e355b796a57421d59ca6be82e73bca',
+ '80818283858687888a8b8c8d8f909192',
+ 'ecb-tbl-128: I=88'),
+ ('040506074142434435340b0aa3a4a5a6', 'd94033276417abfb05a69d15b6e386e2',
+ '94959697999a9b9c9e9fa0a1a3a4a5a6',
+ 'ecb-tbl-128: I=89'),
+ ('242526271112130c61606766bdb2b3b0', '24b36559ea3a9b9b958fe6da3e5b8d85',
+ 'a8a9aaabadaeafb0b2b3b4b5b7b8b9ba',
+ 'ecb-tbl-128: I=90'),
+ ('4b4a4948252627209e9f9091cec9c8cb', '20fd4feaa0e8bf0cce7861d74ef4cb72',
+ 'bcbdbebfc1c2c3c4c6c7c8c9cbcccdce',
+ 'ecb-tbl-128: I=91'),
+ ('68696a6b6665646b9f9e9998d9e6e7e4', '350e20d5174277b9ec314c501570a11d',
+ 'd0d1d2d3d5d6d7d8dadbdcdddfe0e1e2',
+ 'ecb-tbl-128: I=92'),
+ ('34353637c5c6c7c0f0f1eeef7c7b7a79', '87a29d61b7c604d238fe73045a7efd57',
+ 'e4e5e6e7e9eaebeceeeff0f1f3f4f5f6',
+ 'ecb-tbl-128: I=93'),
+ ('32333031c2c1c13f0d0c0b0a050a0b08', '2c3164c1cc7d0064816bdc0faa362c52',
+ 'f8f9fafbfdfefe00020304050708090a',
+ 'ecb-tbl-128: I=94'),
+ ('cdcccfcebebdbcbbabaaa5a4181f1e1d', '195fe5e8a05a2ed594f6e4400eee10b3',
+ '0c0d0e0f11121314161718191b1c1d1e',
+ 'ecb-tbl-128: I=95'),
+ ('212023223635343ba0a1a6a7445b5a59', 'e4663df19b9a21a5a284c2bd7f905025',
+ '20212223252627282a2b2c2d2f303132',
+ 'ecb-tbl-128: I=96'),
+ ('0e0f0c0da8abaaad2f2e515002050407', '21b88714cfb4e2a933bd281a2c4743fd',
+ '34353637393a3b3c3e3f404143444546',
+ 'ecb-tbl-128: I=97'),
+ ('070605042a2928378e8f8889bdb2b3b0', 'cbfc3980d704fd0fc54378ab84e17870',
+ '48494a4b4d4e4f50525354555758595a',
+ 'ecb-tbl-128: I=98'),
+ ('cbcac9c893909196a9a8a7a6a5a2a3a0', 'bc5144baa48bdeb8b63e22e03da418ef',
+ '5c5d5e5f61626364666768696b6c6d6e',
+ 'ecb-tbl-128: I=99'),
+ ('80818283c1c2c3cc9c9d9a9b0cf3f2f1', '5a1dbaef1ee2984b8395da3bdffa3ccc',
+ '70717273757677787a7b7c7d7f808182',
+ 'ecb-tbl-128: I=100'),
+ ('1213101125262720fafbe4e5b1b6b7b4', 'f0b11cd0729dfcc80cec903d97159574',
+ '84858687898a8b8c8e8f909193949596',
+ 'ecb-tbl-128: I=101'),
+ ('7f7e7d7c3033320d97969190222d2c2f', '9f95314acfddc6d1914b7f19a9cc8209',
+ '98999a9b9d9e9fa0a2a3a4a5a7a8a9aa',
+ 'ecb-tbl-128: I=102'),
+ ('4e4f4c4d484b4a4d81808f8e53545556', '595736f6f0f70914a94e9e007f022519',
+ 'acadaeafb1b2b3b4b6b7b8b9bbbcbdbe',
+ 'ecb-tbl-128: I=103'),
+ ('dcdddedfb0b3b2bd15141312a1bebfbc', '1f19f57892cae586fcdfb4c694deb183',
+ 'c0c1c2c3c5c6c7c8cacbcccdcfd0d1d2',
+ 'ecb-tbl-128: I=104'),
+ ('93929190282b2a2dc4c5fafb92959497', '540700ee1f6f3dab0b3eddf6caee1ef5',
+ 'd4d5d6d7d9dadbdcdedfe0e1e3e4e5e6',
+ 'ecb-tbl-128: I=105'),
+ ('f5f4f7f6c4c7c6d9373631307e717073', '14a342a91019a331687a2254e6626ca2',
+ 'e8e9eaebedeeeff0f2f3f4f5f7f8f9fa',
+ 'ecb-tbl-128: I=106'),
+ ('93929190b6b5b4b364656a6b05020300', '7b25f3c3b2eea18d743ef283140f29ff',
+ 'fcfdfeff01020304060708090b0c0d0e',
+ 'ecb-tbl-128: I=107'),
+ ('babbb8b90d0e0f00a4a5a2a3043b3a39', '46c2587d66e5e6fa7f7ca6411ad28047',
+ '10111213151617181a1b1c1d1f202122',
+ 'ecb-tbl-128: I=108'),
+ ('d8d9dadb7f7c7d7a10110e0f787f7e7d', '09470e72229d954ed5ee73886dfeeba9',
+ '24252627292a2b2c2e2f303133343536',
+ 'ecb-tbl-128: I=109'),
+ ('fefffcfdefeced923b3a3d3c6768696a', 'd77c03de92d4d0d79ef8d4824ef365eb',
+ '38393a3b3d3e3f40424344454748494a',
+ 'ecb-tbl-128: I=110'),
+ ('d6d7d4d58a89888f96979899a5a2a3a0', '1d190219f290e0f1715d152d41a23593',
+ '4c4d4e4f51525354565758595b5c5d5e',
+ 'ecb-tbl-128: I=111'),
+ ('18191a1ba8abaaa5303136379b848586', 'a2cd332ce3a0818769616292e87f757b',
+ '60616263656667686a6b6c6d6f707172',
+ 'ecb-tbl-128: I=112'),
+ ('6b6a6968a4a7a6a1d6d72829b0b7b6b5', 'd54afa6ce60fbf9341a3690e21385102',
+ '74757677797a7b7c7e7f808183848586',
+ 'ecb-tbl-128: I=113'),
+ ('000102038a89889755545352a6a9a8ab', '06e5c364ded628a3f5e05e613e356f46',
+ '88898a8b8d8e8f90929394959798999a',
+ 'ecb-tbl-128: I=114'),
+ ('2d2c2f2eb3b0b1b6b6b7b8b9f2f5f4f7', 'eae63c0e62556dac85d221099896355a',
+ '9c9d9e9fa1a2a3a4a6a7a8a9abacadae',
+ 'ecb-tbl-128: I=115'),
+ ('979695943536373856575051e09f9e9d', '1fed060e2c6fc93ee764403a889985a2',
+ 'b0b1b2b3b5b6b7b8babbbcbdbfc0c1c2',
+ 'ecb-tbl-128: I=116'),
+ ('a4a5a6a7989b9a9db1b0afae7a7d7c7f', 'c25235c1a30fdec1c7cb5c5737b2a588',
+ 'c4c5c6c7c9cacbcccecfd0d1d3d4d5d6',
+ 'ecb-tbl-128: I=117'),
+ ('c1c0c3c2686b6a55a8a9aeafeae5e4e7', '796dbef95147d4d30873ad8b7b92efc0',
+ 'd8d9dadbdddedfe0e2e3e4e5e7e8e9ea',
+ 'ecb-tbl-128: I=118'),
+ ('c1c0c3c2141716118c8d828364636261', 'cbcf0fb34d98d0bd5c22ce37211a46bf',
+ 'ecedeeeff1f2f3f4f6f7f8f9fbfcfdfe',
+ 'ecb-tbl-128: I=119'),
+ ('93929190cccfcec196979091e0fffefd', '94b44da6466126cafa7c7fd09063fc24',
+ '00010203050607080a0b0c0d0f101112',
+ 'ecb-tbl-128: I=120'),
+ ('b4b5b6b7f9fafbfc25241b1a6e69686b', 'd78c5b5ebf9b4dbda6ae506c5074c8fe',
+ '14151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-128: I=121'),
+ ('868784850704051ac7c6c1c08788898a', '6c27444c27204b043812cf8cf95f9769',
+ '28292a2b2d2e2f30323334353738393a',
+ 'ecb-tbl-128: I=122'),
+ ('f4f5f6f7aaa9a8affdfcf3f277707172', 'be94524ee5a2aa50bba8b75f4c0aebcf',
+ '3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-128: I=123'),
+ ('d3d2d1d00605040bc3c2c5c43e010003', 'a0aeaae91ba9f31f51aeb3588cf3a39e',
+ '50515253555657585a5b5c5d5f606162',
+ 'ecb-tbl-128: I=124'),
+ ('73727170424140476a6b74750d0a0b08', '275297779c28266ef9fe4c6a13c08488',
+ '64656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-128: I=125'),
+ ('c2c3c0c10a0908f754555253a1aeafac', '86523d92bb8672cb01cf4a77fd725882',
+ '78797a7b7d7e7f80828384858788898a',
+ 'ecb-tbl-128: I=126'),
+ ('6d6c6f6ef8fbfafd82838c8df8fffefd', '4b8327640e9f33322a04dd96fcbf9a36',
+ '8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-128: I=127'),
+ ('f5f4f7f684878689a6a7a0a1d2cdcccf', 'ce52af650d088ca559425223f4d32694',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2',
+ 'ecb-tbl-128: I=128'),
+
+ # ecb_tbl.txt, KEYSIZE=192
+ ('2d33eef2c0430a8a9ebf45e809c40bb6', 'dff4945e0336df4c1c56bc700eff837f',
+ '00010203050607080a0b0c0d0f10111214151617191a1b1c',
+ 'ecb-tbl-192: I=1'),
+ ('6aa375d1fa155a61fb72353e0a5a8756', 'b6fddef4752765e347d5d2dc196d1252',
+ '1e1f20212324252628292a2b2d2e2f30323334353738393a',
+ 'ecb-tbl-192: I=2'),
+ ('bc3736518b9490dcb8ed60eb26758ed4', 'd23684e3d963b3afcf1a114aca90cbd6',
+ '3c3d3e3f41424344464748494b4c4d4e5051525355565758',
+ 'ecb-tbl-192: I=3'),
+ ('aa214402b46cffb9f761ec11263a311e', '3a7ac027753e2a18c2ceab9e17c11fd0',
+ '5a5b5c5d5f60616264656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-192: I=4'),
+ ('02aea86e572eeab66b2c3af5e9a46fd6', '8f6786bd007528ba26603c1601cdd0d8',
+ '78797a7b7d7e7f80828384858788898a8c8d8e8f91929394',
+ 'ecb-tbl-192: I=5'),
+ ('e2aef6acc33b965c4fa1f91c75ff6f36', 'd17d073b01e71502e28b47ab551168b3',
+ '969798999b9c9d9ea0a1a2a3a5a6a7a8aaabacadafb0b1b2',
+ 'ecb-tbl-192: I=6'),
+ ('0659df46427162b9434865dd9499f91d', 'a469da517119fab95876f41d06d40ffa',
+ 'b4b5b6b7b9babbbcbebfc0c1c3c4c5c6c8c9cacbcdcecfd0',
+ 'ecb-tbl-192: I=7'),
+ ('49a44239c748feb456f59c276a5658df', '6091aa3b695c11f5c0b6ad26d3d862ff',
+ 'd2d3d4d5d7d8d9dadcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-192: I=8'),
+ ('66208f6e9d04525bdedb2733b6a6be37', '70f9e67f9f8df1294131662dc6e69364',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe01000204050607090a0b0c',
+ 'ecb-tbl-192: I=9'),
+ ('3393f8dfc729c97f5480b950bc9666b0', 'd154dcafad8b207fa5cbc95e9996b559',
+ '0e0f10111314151618191a1b1d1e1f20222324252728292a',
+ 'ecb-tbl-192: I=10'),
+ ('606834c8ce063f3234cf1145325dbd71', '4934d541e8b46fa339c805a7aeb9e5da',
+ '2c2d2e2f31323334363738393b3c3d3e4041424345464748',
+ 'ecb-tbl-192: I=11'),
+ ('fec1c04f529bbd17d8cecfcc4718b17f', '62564c738f3efe186e1a127a0c4d3c61',
+ '4a4b4c4d4f50515254555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-192: I=12'),
+ ('32df99b431ed5dc5acf8caf6dc6ce475', '07805aa043986eb23693e23bef8f3438',
+ '68696a6b6d6e6f70727374757778797a7c7d7e7f81828384',
+ 'ecb-tbl-192: I=13'),
+ ('7fdc2b746f3f665296943b83710d1f82', 'df0b4931038bade848dee3b4b85aa44b',
+ '868788898b8c8d8e90919293959697989a9b9c9d9fa0a1a2',
+ 'ecb-tbl-192: I=14'),
+ ('8fba1510a3c5b87e2eaa3f7a91455ca2', '592d5fded76582e4143c65099309477c',
+ 'a4a5a6a7a9aaabacaeafb0b1b3b4b5b6b8b9babbbdbebfc0',
+ 'ecb-tbl-192: I=15'),
+ ('2c9b468b1c2eed92578d41b0716b223b', 'c9b8d6545580d3dfbcdd09b954ed4e92',
+ 'c2c3c4c5c7c8c9cacccdcecfd1d2d3d4d6d7d8d9dbdcddde',
+ 'ecb-tbl-192: I=16'),
+ ('0a2bbf0efc6bc0034f8a03433fca1b1a', '5dccd5d6eb7c1b42acb008201df707a0',
+ 'e0e1e2e3e5e6e7e8eaebecedeff0f1f2f4f5f6f7f9fafbfc',
+ 'ecb-tbl-192: I=17'),
+ ('25260e1f31f4104d387222e70632504b', 'a2a91682ffeb6ed1d34340946829e6f9',
+ 'fefe01010304050608090a0b0d0e0f10121314151718191a',
+ 'ecb-tbl-192: I=18'),
+ ('c527d25a49f08a5228d338642ae65137', 'e45d185b797000348d9267960a68435d',
+ '1c1d1e1f21222324262728292b2c2d2e3031323335363738',
+ 'ecb-tbl-192: I=19'),
+ ('3b49fc081432f5890d0e3d87e884a69e', '45e060dae5901cda8089e10d4f4c246b',
+ '3a3b3c3d3f40414244454647494a4b4c4e4f505153545556',
+ 'ecb-tbl-192: I=20'),
+ ('d173f9ed1e57597e166931df2754a083', 'f6951afacc0079a369c71fdcff45df50',
+ '58595a5b5d5e5f60626364656768696a6c6d6e6f71727374',
+ 'ecb-tbl-192: I=21'),
+ ('8c2b7cafa5afe7f13562daeae1adede0', '9e95e00f351d5b3ac3d0e22e626ddad6',
+ '767778797b7c7d7e80818283858687888a8b8c8d8f909192',
+ 'ecb-tbl-192: I=22'),
+ ('aaf4ec8c1a815aeb826cab741339532c', '9cb566ff26d92dad083b51fdc18c173c',
+ '94959697999a9b9c9e9fa0a1a3a4a5a6a8a9aaabadaeafb0',
+ 'ecb-tbl-192: I=23'),
+ ('40be8c5d9108e663f38f1a2395279ecf', 'c9c82766176a9b228eb9a974a010b4fb',
+ 'd0d1d2d3d5d6d7d8dadbdcdddfe0e1e2e4e5e6e7e9eaebec',
+ 'ecb-tbl-192: I=24'),
+ ('0c8ad9bc32d43e04716753aa4cfbe351', 'd8e26aa02945881d5137f1c1e1386e88',
+ '2a2b2c2d2f30313234353637393a3b3c3e3f404143444546',
+ 'ecb-tbl-192: I=25'),
+ ('1407b1d5f87d63357c8dc7ebbaebbfee', 'c0e024ccd68ff5ffa4d139c355a77c55',
+ '48494a4b4d4e4f50525354555758595a5c5d5e5f61626364',
+ 'ecb-tbl-192: I=26'),
+ ('e62734d1ae3378c4549e939e6f123416', '0b18b3d16f491619da338640df391d43',
+ '84858687898a8b8c8e8f90919394959698999a9b9d9e9fa0',
+ 'ecb-tbl-192: I=27'),
+ ('5a752cff2a176db1a1de77f2d2cdee41', 'dbe09ac8f66027bf20cb6e434f252efc',
+ 'a2a3a4a5a7a8a9aaacadaeafb1b2b3b4b6b7b8b9bbbcbdbe',
+ 'ecb-tbl-192: I=28'),
+ ('a9c8c3a4eabedc80c64730ddd018cd88', '6d04e5e43c5b9cbe05feb9606b6480fe',
+ 'c0c1c2c3c5c6c7c8cacbcccdcfd0d1d2d4d5d6d7d9dadbdc',
+ 'ecb-tbl-192: I=29'),
+ ('ee9b3dbbdb86180072130834d305999a', 'dd1d6553b96be526d9fee0fbd7176866',
+ '1a1b1c1d1f20212224252627292a2b2c2e2f303133343536',
+ 'ecb-tbl-192: I=30'),
+ ('a7fa8c3586b8ebde7568ead6f634a879', '0260ca7e3f979fd015b0dd4690e16d2a',
+ '38393a3b3d3e3f40424344454748494a4c4d4e4f51525354',
+ 'ecb-tbl-192: I=31'),
+ ('37e0f4a87f127d45ac936fe7ad88c10a', '9893734de10edcc8a67c3b110b8b8cc6',
+ '929394959798999a9c9d9e9fa1a2a3a4a6a7a8a9abacadae',
+ 'ecb-tbl-192: I=32'),
+ ('3f77d8b5d92bac148e4e46f697a535c5', '93b30b750516b2d18808d710c2ee84ef',
+ '464748494b4c4d4e50515253555657585a5b5c5d5f606162',
+ 'ecb-tbl-192: I=33'),
+ ('d25ebb686c40f7e2c4da1014936571ca', '16f65fa47be3cb5e6dfe7c6c37016c0e',
+ '828384858788898a8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-192: I=34'),
+ ('4f1c769d1e5b0552c7eca84dea26a549', 'f3847210d5391e2360608e5acb560581',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2b4b5b6b7b9babbbc',
+ 'ecb-tbl-192: I=35'),
+ ('8548e2f882d7584d0fafc54372b6633a', '8754462cd223366d0753913e6af2643d',
+ 'bebfc0c1c3c4c5c6c8c9cacbcdcecfd0d2d3d4d5d7d8d9da',
+ 'ecb-tbl-192: I=36'),
+ ('87d7a336cb476f177cd2a51af2a62cdf', '1ea20617468d1b806a1fd58145462017',
+ 'dcdddedfe1e2e3e4e6e7e8e9ebecedeef0f1f2f3f5f6f7f8',
+ 'ecb-tbl-192: I=37'),
+ ('03b1feac668c4e485c1065dfc22b44ee', '3b155d927355d737c6be9dda60136e2e',
+ 'fafbfcfdfe01000204050607090a0b0c0e0f101113141516',
+ 'ecb-tbl-192: I=38'),
+ ('bda15e66819fa72d653a6866aa287962', '26144f7b66daa91b6333dbd3850502b3',
+ '18191a1b1d1e1f20222324252728292a2c2d2e2f31323334',
+ 'ecb-tbl-192: I=39'),
+ ('4d0c7a0d2505b80bf8b62ceb12467f0a', 'e4f9a4ab52ced8134c649bf319ebcc90',
+ '363738393b3c3d3e40414243454647484a4b4c4d4f505152',
+ 'ecb-tbl-192: I=40'),
+ ('626d34c9429b37211330986466b94e5f', 'b9ddd29ac6128a6cab121e34a4c62b36',
+ '54555657595a5b5c5e5f60616364656668696a6b6d6e6f70',
+ 'ecb-tbl-192: I=41'),
+ ('333c3e6bf00656b088a17e5ff0e7f60a', '6fcddad898f2ce4eff51294f5eaaf5c9',
+ '727374757778797a7c7d7e7f81828384868788898b8c8d8e',
+ 'ecb-tbl-192: I=42'),
+ ('687ed0cdc0d2a2bc8c466d05ef9d2891', 'c9a6fe2bf4028080bea6f7fc417bd7e3',
+ '90919293959697989a9b9c9d9fa0a1a2a4a5a6a7a9aaabac',
+ 'ecb-tbl-192: I=43'),
+ ('487830e78cc56c1693e64b2a6660c7b6', '6a2026846d8609d60f298a9c0673127f',
+ 'aeafb0b1b3b4b5b6b8b9babbbdbebfc0c2c3c4c5c7c8c9ca',
+ 'ecb-tbl-192: I=44'),
+ ('7a48d6b7b52b29392aa2072a32b66160', '2cb25c005e26efea44336c4c97a4240b',
+ 'cccdcecfd1d2d3d4d6d7d8d9dbdcdddee0e1e2e3e5e6e7e8',
+ 'ecb-tbl-192: I=45'),
+ ('907320e64c8c5314d10f8d7a11c8618d', '496967ab8680ddd73d09a0e4c7dcc8aa',
+ 'eaebecedeff0f1f2f4f5f6f7f9fafbfcfefe010103040506',
+ 'ecb-tbl-192: I=46'),
+ ('b561f2ca2d6e65a4a98341f3ed9ff533', 'd5af94de93487d1f3a8c577cb84a66a4',
+ '08090a0b0d0e0f10121314151718191a1c1d1e1f21222324',
+ 'ecb-tbl-192: I=47'),
+ ('df769380d212792d026f049e2e3e48ef', '84bdac569cae2828705f267cc8376e90',
+ '262728292b2c2d2e30313233353637383a3b3c3d3f404142',
+ 'ecb-tbl-192: I=48'),
+ ('79f374bc445bdabf8fccb8843d6054c6', 'f7401dda5ad5ab712b7eb5d10c6f99b6',
+ '44454647494a4b4c4e4f50515354555658595a5b5d5e5f60',
+ 'ecb-tbl-192: I=49'),
+ ('4e02f1242fa56b05c68dbae8fe44c9d6', '1c9d54318539ebd4c3b5b7e37bf119f0',
+ '626364656768696a6c6d6e6f71727374767778797b7c7d7e',
+ 'ecb-tbl-192: I=50'),
+ ('cf73c93cbff57ac635a6f4ad2a4a1545', 'aca572d65fb2764cffd4a6eca090ea0d',
+ '80818283858687888a8b8c8d8f90919294959697999a9b9c',
+ 'ecb-tbl-192: I=51'),
+ ('9923548e2875750725b886566784c625', '36d9c627b8c2a886a10ccb36eae3dfbb',
+ '9e9fa0a1a3a4a5a6a8a9aaabadaeafb0b2b3b4b5b7b8b9ba',
+ 'ecb-tbl-192: I=52'),
+ ('4888336b723a022c9545320f836a4207', '010edbf5981e143a81d646e597a4a568',
+ 'bcbdbebfc1c2c3c4c6c7c8c9cbcccdced0d1d2d3d5d6d7d8',
+ 'ecb-tbl-192: I=53'),
+ ('f84d9a5561b0608b1160dee000c41ba8', '8db44d538dc20cc2f40f3067fd298e60',
+ 'dadbdcdddfe0e1e2e4e5e6e7e9eaebeceeeff0f1f3f4f5f6',
+ 'ecb-tbl-192: I=54'),
+ ('c23192a0418e30a19b45ae3e3625bf22', '930eb53bc71e6ac4b82972bdcd5aafb3',
+ 'f8f9fafbfdfefe00020304050708090a0c0d0e0f11121314',
+ 'ecb-tbl-192: I=55'),
+ ('b84e0690b28b0025381ad82a15e501a7', '6c42a81edcbc9517ccd89c30c95597b4',
+ '161718191b1c1d1e20212223252627282a2b2c2d2f303132',
+ 'ecb-tbl-192: I=56'),
+ ('acef5e5c108876c4f06269f865b8f0b0', 'da389847ad06df19d76ee119c71e1dd3',
+ '34353637393a3b3c3e3f40414344454648494a4b4d4e4f50',
+ 'ecb-tbl-192: I=57'),
+ ('0f1b3603e0f5ddea4548246153a5e064', 'e018fdae13d3118f9a5d1a647a3f0462',
+ '525354555758595a5c5d5e5f61626364666768696b6c6d6e',
+ 'ecb-tbl-192: I=58'),
+ ('fbb63893450d42b58c6d88cd3c1809e3', '2aa65db36264239d3846180fabdfad20',
+ '70717273757677787a7b7c7d7f80818284858687898a8b8c',
+ 'ecb-tbl-192: I=59'),
+ ('4bef736df150259dae0c91354e8a5f92', '1472163e9a4f780f1ceb44b07ecf4fdb',
+ '8e8f90919394959698999a9b9d9e9fa0a2a3a4a5a7a8a9aa',
+ 'ecb-tbl-192: I=60'),
+ ('7d2d46242056ef13d3c3fc93c128f4c7', 'c8273fdc8f3a9f72e91097614b62397c',
+ 'acadaeafb1b2b3b4b6b7b8b9bbbcbdbec0c1c2c3c5c6c7c8',
+ 'ecb-tbl-192: I=61'),
+ ('e9c1ba2df415657a256edb33934680fd', '66c8427dcd733aaf7b3470cb7d976e3f',
+ 'cacbcccdcfd0d1d2d4d5d6d7d9dadbdcdedfe0e1e3e4e5e6',
+ 'ecb-tbl-192: I=62'),
+ ('e23ee277b0aa0a1dfb81f7527c3514f1', '146131cb17f1424d4f8da91e6f80c1d0',
+ 'e8e9eaebedeeeff0f2f3f4f5f7f8f9fafcfdfeff01020304',
+ 'ecb-tbl-192: I=63'),
+ ('3e7445b0b63caaf75e4a911e12106b4c', '2610d0ad83659081ae085266a88770dc',
+ '060708090b0c0d0e10111213151617181a1b1c1d1f202122',
+ 'ecb-tbl-192: I=64'),
+ ('767774752023222544455a5be6e1e0e3', '38a2b5a974b0575c5d733917fb0d4570',
+ '24252627292a2b2c2e2f30313334353638393a3b3d3e3f40',
+ 'ecb-tbl-192: I=65'),
+ ('72737475717e7f7ce9e8ebea696a6b6c', 'e21d401ebc60de20d6c486e4f39a588b',
+ '424344454748494a4c4d4e4f51525354565758595b5c5d5e',
+ 'ecb-tbl-192: I=66'),
+ ('dfdedddc25262728c9c8cfcef1eeefec', 'e51d5f88c670b079c0ca1f0c2c4405a2',
+ '60616263656667686a6b6c6d6f70717274757677797a7b7c',
+ 'ecb-tbl-192: I=67'),
+ ('fffe0100707776755f5e5d5c7675746b', '246a94788a642fb3d1b823c8762380c8',
+ '7e7f80818384858688898a8b8d8e8f90929394959798999a',
+ 'ecb-tbl-192: I=68'),
+ ('e0e1e2e3424140479f9e9190292e2f2c', 'b80c391c5c41a4c3b30c68e0e3d7550f',
+ '9c9d9e9fa1a2a3a4a6a7a8a9abacadaeb0b1b2b3b5b6b7b8',
+ 'ecb-tbl-192: I=69'),
+ ('2120272690efeeed3b3a39384e4d4c4b', 'b77c4754fc64eb9a1154a9af0bb1f21c',
+ 'babbbcbdbfc0c1c2c4c5c6c7c9cacbcccecfd0d1d3d4d5d6',
+ 'ecb-tbl-192: I=70'),
+ ('ecedeeef5350516ea1a0a7a6a3acadae', 'fb554de520d159a06bf219fc7f34a02f',
+ 'd8d9dadbdddedfe0e2e3e4e5e7e8e9eaecedeeeff1f2f3f4',
+ 'ecb-tbl-192: I=71'),
+ ('32333c3d25222320e9e8ebeacecdccc3', 'a89fba152d76b4927beed160ddb76c57',
+ 'f6f7f8f9fbfcfdfe00010203050607080a0b0c0d0f101112',
+ 'ecb-tbl-192: I=72'),
+ ('40414243626160678a8bb4b511161714', '5676eab4a98d2e8473b3f3d46424247c',
+ '14151617191a1b1c1e1f20212324252628292a2b2d2e2f30',
+ 'ecb-tbl-192: I=73'),
+ ('94959293f5fafbf81f1e1d1c7c7f7e79', '4e8f068bd7ede52a639036ec86c33568',
+ '323334353738393a3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-192: I=74'),
+ ('bebfbcbd191a1b14cfcec9c8546b6a69', 'f0193c4d7aff1791ee4c07eb4a1824fc',
+ '50515253555657585a5b5c5d5f60616264656667696a6b6c',
+ 'ecb-tbl-192: I=75'),
+ ('2c2d3233898e8f8cbbbab9b8333031ce', 'ac8686eeca9ba761afe82d67b928c33f',
+ '6e6f70717374757678797a7b7d7e7f80828384858788898a',
+ 'ecb-tbl-192: I=76'),
+ ('84858687bfbcbdba37363938fdfafbf8', '5faf8573e33b145b6a369cd3606ab2c9',
+ '8c8d8e8f91929394969798999b9c9d9ea0a1a2a3a5a6a7a8',
+ 'ecb-tbl-192: I=77'),
+ ('828384857669686b909192930b08090e', '31587e9944ab1c16b844ecad0df2e7da',
+ 'aaabacadafb0b1b2b4b5b6b7b9babbbcbebfc0c1c3c4c5c6',
+ 'ecb-tbl-192: I=78'),
+ ('bebfbcbd9695948b707176779e919093', 'd017fecd91148aba37f6f3068aa67d8a',
+ 'c8c9cacbcdcecfd0d2d3d4d5d7d8d9dadcdddedfe1e2e3e4',
+ 'ecb-tbl-192: I=79'),
+ ('8b8a85846067666521202322d0d3d2dd', '788ef2f021a73cba2794b616078a8500',
+ 'e6e7e8e9ebecedeef0f1f2f3f5f6f7f8fafbfcfdfe010002',
+ 'ecb-tbl-192: I=80'),
+ ('76777475f1f2f3f4f8f9e6e777707172', '5d1ef20dced6bcbc12131ac7c54788aa',
+ '04050607090a0b0c0e0f10111314151618191a1b1d1e1f20',
+ 'ecb-tbl-192: I=81'),
+ ('a4a5a2a34f404142b4b5b6b727242522', 'b3c8cf961faf9ea05fdde6d1e4d8f663',
+ '222324252728292a2c2d2e2f31323334363738393b3c3d3e',
+ 'ecb-tbl-192: I=82'),
+ ('94959697e1e2e3ec16171011839c9d9e', '143075c70605861c7fac6526199e459f',
+ '40414243454647484a4b4c4d4f50515254555657595a5b5c',
+ 'ecb-tbl-192: I=83'),
+ ('03023d3c06010003dedfdcddfffcfde2', 'a5ae12eade9a87268d898bfc8fc0252a',
+ '5e5f60616364656668696a6b6d6e6f70727374757778797a',
+ 'ecb-tbl-192: I=84'),
+ ('10111213f1f2f3f4cecfc0c1dbdcddde', '0924f7cf2e877a4819f5244a360dcea9',
+ '7c7d7e7f81828384868788898b8c8d8e9091929395969798',
+ 'ecb-tbl-192: I=85'),
+ ('67666160724d4c4f1d1c1f1e73707176', '3d9e9635afcc3e291cc7ab3f27d1c99a',
+ '9a9b9c9d9fa0a1a2a4a5a6a7a9aaabacaeafb0b1b3b4b5b6',
+ 'ecb-tbl-192: I=86'),
+ ('e6e7e4e5a8abaad584858283909f9e9d', '9d80feebf87510e2b8fb98bb54fd788c',
+ 'b8b9babbbdbebfc0c2c3c4c5c7c8c9cacccdcecfd1d2d3d4',
+ 'ecb-tbl-192: I=87'),
+ ('71707f7e565150537d7c7f7e6162636c', '5f9d1a082a1a37985f174002eca01309',
+ 'd6d7d8d9dbdcdddee0e1e2e3e5e6e7e8eaebecedeff0f1f2',
+ 'ecb-tbl-192: I=88'),
+ ('64656667212223245555aaaa03040506', 'a390ebb1d1403930184a44b4876646e4',
+ 'f4f5f6f7f9fafbfcfefe01010304050608090a0b0d0e0f10',
+ 'ecb-tbl-192: I=89'),
+ ('9e9f9899aba4a5a6cfcecdcc2b28292e', '700fe918981c3195bb6c4bcb46b74e29',
+ '121314151718191a1c1d1e1f21222324262728292b2c2d2e',
+ 'ecb-tbl-192: I=90'),
+ ('c7c6c5c4d1d2d3dc626364653a454447', '907984406f7bf2d17fb1eb15b673d747',
+ '30313233353637383a3b3c3d3f40414244454647494a4b4c',
+ 'ecb-tbl-192: I=91'),
+ ('f6f7e8e9e0e7e6e51d1c1f1e5b585966', 'c32a956dcfc875c2ac7c7cc8b8cc26e1',
+ '4e4f50515354555658595a5b5d5e5f60626364656768696a',
+ 'ecb-tbl-192: I=92'),
+ ('bcbdbebf5d5e5f5868696667f4f3f2f1', '02646e2ebfa9b820cf8424e9b9b6eb51',
+ '6c6d6e6f71727374767778797b7c7d7e8081828385868788',
+ 'ecb-tbl-192: I=93'),
+ ('40414647b0afaead9b9a99989b98999e', '621fda3a5bbd54c6d3c685816bd4ead8',
+ '8a8b8c8d8f90919294959697999a9b9c9e9fa0a1a3a4a5a6',
+ 'ecb-tbl-192: I=94'),
+ ('69686b6a0201001f0f0e0908b4bbbab9', 'd4e216040426dfaf18b152469bc5ac2f',
+ 'a8a9aaabadaeafb0b2b3b4b5b7b8b9babcbdbebfc1c2c3c4',
+ 'ecb-tbl-192: I=95'),
+ ('c7c6c9c8d8dfdedd5a5b5859bebdbcb3', '9d0635b9d33b6cdbd71f5d246ea17cc8',
+ 'c6c7c8c9cbcccdced0d1d2d3d5d6d7d8dadbdcdddfe0e1e2',
+ 'ecb-tbl-192: I=96'),
+ ('dedfdcdd787b7a7dfffee1e0b2b5b4b7', '10abad1bd9bae5448808765583a2cc1a',
+ 'e4e5e6e7e9eaebeceeeff0f1f3f4f5f6f8f9fafbfdfefe00',
+ 'ecb-tbl-192: I=97'),
+ ('4d4c4b4a606f6e6dd0d1d2d3fbf8f9fe', '6891889e16544e355ff65a793c39c9a8',
+ '020304050708090a0c0d0e0f11121314161718191b1c1d1e',
+ 'ecb-tbl-192: I=98'),
+ ('b7b6b5b4d7d4d5dae5e4e3e2e1fefffc', 'cc735582e68072c163cd9ddf46b91279',
+ '20212223252627282a2b2c2d2f30313234353637393a3b3c',
+ 'ecb-tbl-192: I=99'),
+ ('cecfb0b1f7f0f1f2aeafacad3e3d3c23', 'c5c68b9aeeb7f878df578efa562f9574',
+ '3e3f40414344454648494a4b4d4e4f50525354555758595a',
+ 'ecb-tbl-192: I=100'),
+ ('cacbc8c9cdcecfc812131c1d494e4f4c', '5f4764395a667a47d73452955d0d2ce8',
+ '5c5d5e5f61626364666768696b6c6d6e7071727375767778',
+ 'ecb-tbl-192: I=101'),
+ ('9d9c9b9ad22d2c2fb1b0b3b20c0f0e09', '701448331f66106cefddf1eb8267c357',
+ '7a7b7c7d7f80818284858687898a8b8c8e8f909193949596',
+ 'ecb-tbl-192: I=102'),
+ ('7a7b787964676659959493924f404142', 'cb3ee56d2e14b4e1941666f13379d657',
+ '98999a9b9d9e9fa0a2a3a4a5a7a8a9aaacadaeafb1b2b3b4',
+ 'ecb-tbl-192: I=103'),
+ ('aaaba4a5cec9c8cb1f1e1d1caba8a9a6', '9fe16efd18ab6e1981191851fedb0764',
+ 'b6b7b8b9bbbcbdbec0c1c2c3c5c6c7c8cacbcccdcfd0d1d2',
+ 'ecb-tbl-192: I=104'),
+ ('93929190282b2a2dc4c5fafb92959497', '3dc9ba24e1b223589b147adceb4c8e48',
+ 'd4d5d6d7d9dadbdcdedfe0e1e3e4e5e6e8e9eaebedeeeff0',
+ 'ecb-tbl-192: I=105'),
+ ('efeee9e8ded1d0d339383b3a888b8a8d', '1c333032682e7d4de5e5afc05c3e483c',
+ 'f2f3f4f5f7f8f9fafcfdfeff01020304060708090b0c0d0e',
+ 'ecb-tbl-192: I=106'),
+ ('7f7e7d7ca2a1a0af78797e7f112e2f2c', 'd593cc99a95afef7e92038e05a59d00a',
+ '10111213151617181a1b1c1d1f20212224252627292a2b2c',
+ 'ecb-tbl-192: I=107'),
+ ('84859a9b2b2c2d2e868784852625245b', '51e7f96f53b4353923452c222134e1ec',
+ '2e2f30313334353638393a3b3d3e3f40424344454748494a',
+ 'ecb-tbl-192: I=108'),
+ ('b0b1b2b3070405026869666710171615', '4075b357a1a2b473400c3b25f32f81a4',
+ '4c4d4e4f51525354565758595b5c5d5e6061626365666768',
+ 'ecb-tbl-192: I=109'),
+ ('acadaaabbda2a3a00d0c0f0e595a5b5c', '302e341a3ebcd74f0d55f61714570284',
+ '6a6b6c6d6f70717274757677797a7b7c7e7f808183848586',
+ 'ecb-tbl-192: I=110'),
+ ('121310115655544b5253545569666764', '57abdd8231280da01c5042b78cf76522',
+ '88898a8b8d8e8f90929394959798999a9c9d9e9fa1a2a3a4',
+ 'ecb-tbl-192: I=111'),
+ ('dedfd0d166616063eaebe8e94142434c', '17f9ea7eea17ac1adf0e190fef799e92',
+ 'a6a7a8a9abacadaeb0b1b2b3b5b6b7b8babbbcbdbfc0c1c2',
+ 'ecb-tbl-192: I=112'),
+ ('dbdad9d81417161166677879e0e7e6e5', '2e1bdd563dd87ee5c338dd6d098d0a7a',
+ 'c4c5c6c7c9cacbcccecfd0d1d3d4d5d6d8d9dadbdddedfe0',
+ 'ecb-tbl-192: I=113'),
+ ('6a6b6c6de0efeeed2b2a2928c0c3c2c5', 'eb869996e6f8bfb2bfdd9e0c4504dbb2',
+ 'e2e3e4e5e7e8e9eaecedeeeff1f2f3f4f6f7f8f9fbfcfdfe',
+ 'ecb-tbl-192: I=114'),
+ ('b1b0b3b21714151a1a1b1c1d5649484b', 'c2e01549e9decf317468b3e018c61ba8',
+ '00010203050607080a0b0c0d0f10111214151617191a1b1c',
+ 'ecb-tbl-192: I=115'),
+ ('39380706a3a4a5a6c4c5c6c77271706f', '8da875d033c01dd463b244a1770f4a22',
+ '1e1f20212324252628292a2b2d2e2f30323334353738393a',
+ 'ecb-tbl-192: I=116'),
+ ('5c5d5e5f1013121539383736e2e5e4e7', '8ba0dcf3a186844f026d022f8839d696',
+ '3c3d3e3f41424344464748494b4c4d4e5051525355565758',
+ 'ecb-tbl-192: I=117'),
+ ('43424544ead5d4d72e2f2c2d64676661', 'e9691ff9a6cc6970e51670a0fd5b88c1',
+ '5a5b5c5d5f60616264656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-192: I=118'),
+ ('55545756989b9a65f8f9feff18171615', 'f2baec06faeed30f88ee63ba081a6e5b',
+ '78797a7b7d7e7f80828384858788898a8c8d8e8f91929394',
+ 'ecb-tbl-192: I=119'),
+ ('05040b0a525554573c3d3e3f4a494847', '9c39d4c459ae5753394d6094adc21e78',
+ '969798999b9c9d9ea0a1a2a3a5a6a7a8aaabacadafb0b1b2',
+ 'ecb-tbl-192: I=120'),
+ ('14151617595a5b5c8584fbfa8e89888b', '6345b532a11904502ea43ba99c6bd2b2',
+ 'b4b5b6b7b9babbbcbebfc0c1c3c4c5c6c8c9cacbcdcecfd0',
+ 'ecb-tbl-192: I=121'),
+ ('7c7d7a7bfdf2f3f029282b2a51525354', '5ffae3061a95172e4070cedce1e428c8',
+ 'd2d3d4d5d7d8d9dadcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-192: I=122'),
+ ('38393a3b1e1d1c1341404746c23d3c3e', '0a4566be4cdf9adce5dec865b5ab34cd',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe01000204050607090a0b0c',
+ 'ecb-tbl-192: I=123'),
+ ('8d8c939240474645818083827c7f7e41', 'ca17fcce79b7404f2559b22928f126fb',
+ '0e0f10111314151618191a1b1d1e1f20222324252728292a',
+ 'ecb-tbl-192: I=124'),
+ ('3b3a39381a19181f32333c3d45424340', '97ca39b849ed73a6470a97c821d82f58',
+ '2c2d2e2f31323334363738393b3c3d3e4041424345464748',
+ 'ecb-tbl-192: I=125'),
+ ('f0f1f6f738272625828380817f7c7d7a', '8198cb06bc684c6d3e9b7989428dcf7a',
+ '4a4b4c4d4f50515254555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-192: I=126'),
+ ('89888b8a0407061966676061141b1a19', 'f53c464c705ee0f28d9a4c59374928bd',
+ '68696a6b6d6e6f70727374757778797a7c7d7e7f81828384',
+ 'ecb-tbl-192: I=127'),
+ ('d3d2dddcaaadacaf9c9d9e9fe8ebeae5', '9adb3d4cca559bb98c3e2ed73dbf1154',
+ '868788898b8c8d8e90919293959697989a9b9c9d9fa0a1a2',
+ 'ecb-tbl-192: I=128'),
+
+ # ecb_tbl.txt, KEYSIZE=256
+ ('834eadfccac7e1b30664b1aba44815ab', '1946dabf6a03a2a2c3d0b05080aed6fc',
+ '00010203050607080a0b0c0d0f10111214151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-256: I=1'),
+ ('d9dc4dba3021b05d67c0518f72b62bf1', '5ed301d747d3cc715445ebdec62f2fb4',
+ '28292a2b2d2e2f30323334353738393a3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-256: I=2'),
+ ('a291d86301a4a739f7392173aa3c604c', '6585c8f43d13a6beab6419fc5935b9d0',
+ '50515253555657585a5b5c5d5f60616264656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-256: I=3'),
+ ('4264b2696498de4df79788a9f83e9390', '2a5b56a596680fcc0e05f5e0f151ecae',
+ '78797a7b7d7e7f80828384858788898a8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-256: I=4'),
+ ('ee9932b3721804d5a83ef5949245b6f6', 'f5d6ff414fd2c6181494d20c37f2b8c4',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2b4b5b6b7b9babbbcbebfc0c1c3c4c5c6',
+ 'ecb-tbl-256: I=5'),
+ ('e6248f55c5fdcbca9cbbb01c88a2ea77', '85399c01f59fffb5204f19f8482f00b8',
+ 'c8c9cacbcdcecfd0d2d3d4d5d7d8d9dadcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-256: I=6'),
+ ('b8358e41b9dff65fd461d55a99266247', '92097b4c88a041ddf98144bc8d22e8e7',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe01000204050607090a0b0c0e0f101113141516',
+ 'ecb-tbl-256: I=7'),
+ ('f0e2d72260af58e21e015ab3a4c0d906', '89bd5b73b356ab412aef9f76cea2d65c',
+ '18191a1b1d1e1f20222324252728292a2c2d2e2f31323334363738393b3c3d3e',
+ 'ecb-tbl-256: I=8'),
+ ('475b8b823ce8893db3c44a9f2a379ff7', '2536969093c55ff9454692f2fac2f530',
+ '40414243454647484a4b4c4d4f50515254555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-256: I=9'),
+ ('688f5281945812862f5f3076cf80412f', '07fc76a872843f3f6e0081ee9396d637',
+ '68696a6b6d6e6f70727374757778797a7c7d7e7f81828384868788898b8c8d8e',
+ 'ecb-tbl-256: I=10'),
+ ('08d1d2bc750af553365d35e75afaceaa', 'e38ba8ec2aa741358dcc93e8f141c491',
+ '90919293959697989a9b9c9d9fa0a1a2a4a5a6a7a9aaabacaeafb0b1b3b4b5b6',
+ 'ecb-tbl-256: I=11'),
+ ('8707121f47cc3efceca5f9a8474950a1', 'd028ee23e4a89075d0b03e868d7d3a42',
+ 'b8b9babbbdbebfc0c2c3c4c5c7c8c9cacccdcecfd1d2d3d4d6d7d8d9dbdcddde',
+ 'ecb-tbl-256: I=12'),
+ ('e51aa0b135dba566939c3b6359a980c5', '8cd9423dfc459e547155c5d1d522e540',
+ 'e0e1e2e3e5e6e7e8eaebecedeff0f1f2f4f5f6f7f9fafbfcfefe010103040506',
+ 'ecb-tbl-256: I=13'),
+ ('069a007fc76a459f98baf917fedf9521', '080e9517eb1677719acf728086040ae3',
+ '08090a0b0d0e0f10121314151718191a1c1d1e1f21222324262728292b2c2d2e',
+ 'ecb-tbl-256: I=14'),
+ ('726165c1723fbcf6c026d7d00b091027', '7c1700211a3991fc0ecded0ab3e576b0',
+ '30313233353637383a3b3c3d3f40414244454647494a4b4c4e4f505153545556',
+ 'ecb-tbl-256: I=15'),
+ ('d7c544de91d55cfcde1f84ca382200ce', 'dabcbcc855839251db51e224fbe87435',
+ '58595a5b5d5e5f60626364656768696a6c6d6e6f71727374767778797b7c7d7e',
+ 'ecb-tbl-256: I=16'),
+ ('fed3c9a161b9b5b2bd611b41dc9da357', '68d56fad0406947a4dd27a7448c10f1d',
+ '80818283858687888a8b8c8d8f90919294959697999a9b9c9e9fa0a1a3a4a5a6',
+ 'ecb-tbl-256: I=17'),
+ ('4f634cdc6551043409f30b635832cf82', 'da9a11479844d1ffee24bbf3719a9925',
+ 'a8a9aaabadaeafb0b2b3b4b5b7b8b9babcbdbebfc1c2c3c4c6c7c8c9cbcccdce',
+ 'ecb-tbl-256: I=18'),
+ ('109ce98db0dfb36734d9f3394711b4e6', '5e4ba572f8d23e738da9b05ba24b8d81',
+ 'd0d1d2d3d5d6d7d8dadbdcdddfe0e1e2e4e5e6e7e9eaebeceeeff0f1f3f4f5f6',
+ 'ecb-tbl-256: I=19'),
+ ('4ea6dfaba2d8a02ffdffa89835987242', 'a115a2065d667e3f0b883837a6e903f8',
+ '70717273757677787a7b7c7d7f80818284858687898a8b8c8e8f909193949596',
+ 'ecb-tbl-256: I=20'),
+ ('5ae094f54af58e6e3cdbf976dac6d9ef', '3e9e90dc33eac2437d86ad30b137e66e',
+ '98999a9b9d9e9fa0a2a3a4a5a7a8a9aaacadaeafb1b2b3b4b6b7b8b9bbbcbdbe',
+ 'ecb-tbl-256: I=21'),
+ ('764d8e8e0f29926dbe5122e66354fdbe', '01ce82d8fbcdae824cb3c48e495c3692',
+ 'c0c1c2c3c5c6c7c8cacbcccdcfd0d1d2d4d5d6d7d9dadbdcdedfe0e1e3e4e5e6',
+ 'ecb-tbl-256: I=22'),
+ ('3f0418f888cdf29a982bf6b75410d6a9', '0c9cff163ce936faaf083cfd3dea3117',
+ 'e8e9eaebedeeeff0f2f3f4f5f7f8f9fafcfdfeff01020304060708090b0c0d0e',
+ 'ecb-tbl-256: I=23'),
+ ('e4a3e7cb12cdd56aa4a75197a9530220', '5131ba9bd48f2bba85560680df504b52',
+ '10111213151617181a1b1c1d1f20212224252627292a2b2c2e2f303133343536',
+ 'ecb-tbl-256: I=24'),
+ ('211677684aac1ec1a160f44c4ebf3f26', '9dc503bbf09823aec8a977a5ad26ccb2',
+ '38393a3b3d3e3f40424344454748494a4c4d4e4f51525354565758595b5c5d5e',
+ 'ecb-tbl-256: I=25'),
+ ('d21e439ff749ac8f18d6d4b105e03895', '9a6db0c0862e506a9e397225884041d7',
+ '60616263656667686a6b6c6d6f70717274757677797a7b7c7e7f808183848586',
+ 'ecb-tbl-256: I=26'),
+ ('d9f6ff44646c4725bd4c0103ff5552a7', '430bf9570804185e1ab6365fc6a6860c',
+ '88898a8b8d8e8f90929394959798999a9c9d9e9fa1a2a3a4a6a7a8a9abacadae',
+ 'ecb-tbl-256: I=27'),
+ ('0b1256c2a00b976250cfc5b0c37ed382', '3525ebc02f4886e6a5a3762813e8ce8a',
+ 'b0b1b2b3b5b6b7b8babbbcbdbfc0c1c2c4c5c6c7c9cacbcccecfd0d1d3d4d5d6',
+ 'ecb-tbl-256: I=28'),
+ ('b056447ffc6dc4523a36cc2e972a3a79', '07fa265c763779cce224c7bad671027b',
+ 'd8d9dadbdddedfe0e2e3e4e5e7e8e9eaecedeeeff1f2f3f4f6f7f8f9fbfcfdfe',
+ 'ecb-tbl-256: I=29'),
+ ('5e25ca78f0de55802524d38da3fe4456', 'e8b72b4e8be243438c9fff1f0e205872',
+ '00010203050607080a0b0c0d0f10111214151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-256: I=30'),
+ ('a5bcf4728fa5eaad8567c0dc24675f83', '109d4f999a0e11ace1f05e6b22cbcb50',
+ '28292a2b2d2e2f30323334353738393a3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-256: I=31'),
+ ('814e59f97ed84646b78b2ca022e9ca43', '45a5e8d4c3ed58403ff08d68a0cc4029',
+ '50515253555657585a5b5c5d5f60616264656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-256: I=32'),
+ ('15478beec58f4775c7a7f5d4395514d7', '196865964db3d417b6bd4d586bcb7634',
+ '78797a7b7d7e7f80828384858788898a8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-256: I=33'),
+ ('253548ffca461c67c8cbc78cd59f4756', '60436ad45ac7d30d99195f815d98d2ae',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2b4b5b6b7b9babbbcbebfc0c1c3c4c5c6',
+ 'ecb-tbl-256: I=34'),
+ ('fd7ad8d73b9b0f8cc41600640f503d65', 'bb07a23f0b61014b197620c185e2cd75',
+ 'c8c9cacbcdcecfd0d2d3d4d5d7d8d9dadcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-256: I=35'),
+ ('06199de52c6cbf8af954cd65830bcd56', '5bc0b2850129c854423aff0751fe343b',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe01000204050607090a0b0c0e0f101113141516',
+ 'ecb-tbl-256: I=36'),
+ ('f17c4ffe48e44c61bd891e257e725794', '7541a78f96738e6417d2a24bd2beca40',
+ '18191a1b1d1e1f20222324252728292a2c2d2e2f31323334363738393b3c3d3e',
+ 'ecb-tbl-256: I=37'),
+ ('9a5b4a402a3e8a59be6bf5cd8154f029', 'b0a303054412882e464591f1546c5b9e',
+ '40414243454647484a4b4c4d4f50515254555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-256: I=38'),
+ ('79bd40b91a7e07dc939d441782ae6b17', '778c06d8a355eeee214fcea14b4e0eef',
+ '68696a6b6d6e6f70727374757778797a7c7d7e7f81828384868788898b8c8d8e',
+ 'ecb-tbl-256: I=39'),
+ ('d8ceaaf8976e5fbe1012d8c84f323799', '09614206d15cbace63227d06db6beebb',
+ '90919293959697989a9b9c9d9fa0a1a2a4a5a6a7a9aaabacaeafb0b1b3b4b5b6',
+ 'ecb-tbl-256: I=40'),
+ ('3316e2751e2e388b083da23dd6ac3fbe', '41b97fb20e427a9fdbbb358d9262255d',
+ 'b8b9babbbdbebfc0c2c3c4c5c7c8c9cacccdcecfd1d2d3d4d6d7d8d9dbdcddde',
+ 'ecb-tbl-256: I=41'),
+ ('8b7cfbe37de7dca793521819242c5816', 'c1940f703d845f957652c2d64abd7adf',
+ 'e0e1e2e3e5e6e7e8eaebecedeff0f1f2f4f5f6f7f9fafbfcfefe010103040506',
+ 'ecb-tbl-256: I=42'),
+ ('f23f033c0eebf8ec55752662fd58ce68', 'd2d44fcdae5332343366db297efcf21b',
+ '08090a0b0d0e0f10121314151718191a1c1d1e1f21222324262728292b2c2d2e',
+ 'ecb-tbl-256: I=43'),
+ ('59eb34f6c8bdbacc5fc6ad73a59a1301', 'ea8196b79dbe167b6aa9896e287eed2b',
+ '30313233353637383a3b3c3d3f40414244454647494a4b4c4e4f505153545556',
+ 'ecb-tbl-256: I=44'),
+ ('dcde8b6bd5cf7cc22d9505e3ce81261a', 'd6b0b0c4ba6c7dbe5ed467a1e3f06c2d',
+ '58595a5b5d5e5f60626364656768696a6c6d6e6f71727374767778797b7c7d7e',
+ 'ecb-tbl-256: I=45'),
+ ('e33cf7e524fed781e7042ff9f4b35dc7', 'ec51eb295250c22c2fb01816fb72bcae',
+ '80818283858687888a8b8c8d8f90919294959697999a9b9c9e9fa0a1a3a4a5a6',
+ 'ecb-tbl-256: I=46'),
+ ('27963c8facdf73062867d164df6d064c', 'aded6630a07ce9c7408a155d3bd0d36f',
+ 'a8a9aaabadaeafb0b2b3b4b5b7b8b9babcbdbebfc1c2c3c4c6c7c8c9cbcccdce',
+ 'ecb-tbl-256: I=47'),
+ ('77b1ce386b551b995f2f2a1da994eef8', '697c9245b9937f32f5d1c82319f0363a',
+ 'd0d1d2d3d5d6d7d8dadbdcdddfe0e1e2e4e5e6e7e9eaebeceeeff0f1f3f4f5f6',
+ 'ecb-tbl-256: I=48'),
+ ('f083388b013679efcf0bb9b15d52ae5c', 'aad5ad50c6262aaec30541a1b7b5b19c',
+ 'f8f9fafbfdfefe00020304050708090a0c0d0e0f11121314161718191b1c1d1e',
+ 'ecb-tbl-256: I=49'),
+ ('c5009e0dab55db0abdb636f2600290c8', '7d34b893855341ec625bd6875ac18c0d',
+ '20212223252627282a2b2c2d2f30313234353637393a3b3c3e3f404143444546',
+ 'ecb-tbl-256: I=50'),
+ ('7804881e26cd532d8514d3683f00f1b9', '7ef05105440f83862f5d780e88f02b41',
+ '48494a4b4d4e4f50525354555758595a5c5d5e5f61626364666768696b6c6d6e',
+ 'ecb-tbl-256: I=51'),
+ ('46cddcd73d1eb53e675ca012870a92a3', 'c377c06403382061af2c9c93a8e70df6',
+ '70717273757677787a7b7c7d7f80818284858687898a8b8c8e8f909193949596',
+ 'ecb-tbl-256: I=52'),
+ ('a9fb44062bb07fe130a8e8299eacb1ab', '1dbdb3ffdc052dacc83318853abc6de5',
+ '98999a9b9d9e9fa0a2a3a4a5a7a8a9aaacadaeafb1b2b3b4b6b7b8b9bbbcbdbe',
+ 'ecb-tbl-256: I=53'),
+ ('2b6ff8d7a5cc3a28a22d5a6f221af26b', '69a6eab00432517d0bf483c91c0963c7',
+ 'c0c1c2c3c5c6c7c8cacbcccdcfd0d1d2d4d5d6d7d9dadbdcdedfe0e1e3e4e5e6',
+ 'ecb-tbl-256: I=54'),
+ ('1a9527c29b8add4b0e3e656dbb2af8b4', '0797f41dc217c80446e1d514bd6ab197',
+ 'e8e9eaebedeeeff0f2f3f4f5f7f8f9fafcfdfeff01020304060708090b0c0d0e',
+ 'ecb-tbl-256: I=55'),
+ ('7f99cf2c75244df015eb4b0c1050aeae', '9dfd76575902a637c01343c58e011a03',
+ '10111213151617181a1b1c1d1f20212224252627292a2b2c2e2f303133343536',
+ 'ecb-tbl-256: I=56'),
+ ('e84ff85b0d9454071909c1381646c4ed', 'acf4328ae78f34b9fa9b459747cc2658',
+ '38393a3b3d3e3f40424344454748494a4c4d4e4f51525354565758595b5c5d5e',
+ 'ecb-tbl-256: I=57'),
+ ('89afd40f99521280d5399b12404f6db4', 'b0479aea12bac4fe2384cf98995150c6',
+ '60616263656667686a6b6c6d6f70717274757677797a7b7c7e7f808183848586',
+ 'ecb-tbl-256: I=58'),
+ ('a09ef32dbc5119a35ab7fa38656f0329', '9dd52789efe3ffb99f33b3da5030109a',
+ '88898a8b8d8e8f90929394959798999a9c9d9e9fa1a2a3a4a6a7a8a9abacadae',
+ 'ecb-tbl-256: I=59'),
+ ('61773457f068c376c7829b93e696e716', 'abbb755e4621ef8f1214c19f649fb9fd',
+ 'b0b1b2b3b5b6b7b8babbbcbdbfc0c1c2c4c5c6c7c9cacbcccecfd0d1d3d4d5d6',
+ 'ecb-tbl-256: I=60'),
+ ('a34f0cae726cce41dd498747d891b967', 'da27fb8174357bce2bed0e7354f380f9',
+ 'd8d9dadbdddedfe0e2e3e4e5e7e8e9eaecedeeeff1f2f3f4f6f7f8f9fbfcfdfe',
+ 'ecb-tbl-256: I=61'),
+ ('856f59496c7388ee2d2b1a27b7697847', 'c59a0663f0993838f6e5856593bdc5ef',
+ '00010203050607080a0b0c0d0f10111214151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-256: I=62'),
+ ('cb090c593ef7720bd95908fb93b49df4', 'ed60b264b5213e831607a99c0ce5e57e',
+ '28292a2b2d2e2f30323334353738393a3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-256: I=63'),
+ ('a0ac75cd2f1923d460fc4d457ad95baf', 'e50548746846f3eb77b8c520640884ed',
+ '50515253555657585a5b5c5d5f60616264656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-256: I=64'),
+ ('2a2b282974777689e8e9eeef525d5c5f', '28282cc7d21d6a2923641e52d188ef0c',
+ '78797a7b7d7e7f80828384858788898a8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-256: I=65'),
+ ('909192939390919e0f0e09089788898a', '0dfa5b02abb18e5a815305216d6d4f8e',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2b4b5b6b7b9babbbcbebfc0c1c3c4c5c6',
+ 'ecb-tbl-256: I=66'),
+ ('777675748d8e8f907170777649464744', '7359635c0eecefe31d673395fb46fb99',
+ 'c8c9cacbcdcecfd0d2d3d4d5d7d8d9dadcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-256: I=67'),
+ ('717073720605040b2d2c2b2a05fafbf9', '73c679f7d5aef2745c9737bb4c47fb36',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe01000204050607090a0b0c0e0f101113141516',
+ 'ecb-tbl-256: I=68'),
+ ('64656667fefdfcc31b1a1d1ca5aaaba8', 'b192bd472a4d2eafb786e97458967626',
+ '18191a1b1d1e1f20222324252728292a2c2d2e2f31323334363738393b3c3d3e',
+ 'ecb-tbl-256: I=69'),
+ ('dbdad9d86a696867b5b4b3b2c8d7d6d5', '0ec327f6c8a2b147598ca3fde61dc6a4',
+ '40414243454647484a4b4c4d4f50515254555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-256: I=70'),
+ ('5c5d5e5fe3e0e1fe31303736333c3d3e', 'fc418eb3c41b859b38d4b6f646629729',
+ '68696a6b6d6e6f70727374757778797a7c7d7e7f81828384868788898b8c8d8e',
+ 'ecb-tbl-256: I=71'),
+ ('545556574b48494673727574546b6a69', '30249e5ac282b1c981ea64b609f3a154',
+ '90919293959697989a9b9c9d9fa0a1a2a4a5a6a7a9aaabacaeafb0b1b3b4b5b6',
+ 'ecb-tbl-256: I=72'),
+ ('ecedeeefc6c5c4bb56575051f5fafbf8', '5e6e08646d12150776bb43c2d78a9703',
+ 'b8b9babbbdbebfc0c2c3c4c5c7c8c9cacccdcecfd1d2d3d4d6d7d8d9dbdcddde',
+ 'ecb-tbl-256: I=73'),
+ ('464744452724252ac9c8cfced2cdcccf', 'faeb3d5de652cd3447dceb343f30394a',
+ 'e0e1e2e3e5e6e7e8eaebecedeff0f1f2f4f5f6f7f9fafbfcfefe010103040506',
+ 'ecb-tbl-256: I=74'),
+ ('e6e7e4e54142435c878681801c131211', 'a8e88706823f6993ef80d05c1c7b2cf0',
+ '08090a0b0d0e0f10121314151718191a1c1d1e1f21222324262728292b2c2d2e',
+ 'ecb-tbl-256: I=75'),
+ ('72737071cfcccdc2f9f8fffe710e0f0c', '8ced86677e6e00a1a1b15968f2d3cce6',
+ '30313233353637383a3b3c3d3f40414244454647494a4b4c4e4f505153545556',
+ 'ecb-tbl-256: I=76'),
+ ('505152537370714ec3c2c5c4010e0f0c', '9fc7c23858be03bdebb84e90db6786a9',
+ '58595a5b5d5e5f60626364656768696a6c6d6e6f71727374767778797b7c7d7e',
+ 'ecb-tbl-256: I=77'),
+ ('a8a9aaab5c5f5e51aeafa8a93d222320', 'b4fbd65b33f70d8cf7f1111ac4649c36',
+ '80818283858687888a8b8c8d8f90919294959697999a9b9c9e9fa0a1a3a4a5a6',
+ 'ecb-tbl-256: I=78'),
+ ('dedfdcddf6f5f4eb10111617fef1f0f3', 'c5c32d5ed03c4b53cc8c1bd0ef0dbbf6',
+ 'a8a9aaabadaeafb0b2b3b4b5b7b8b9babcbdbebfc1c2c3c4c6c7c8c9cbcccdce',
+ 'ecb-tbl-256: I=79'),
+ ('bdbcbfbe5e5d5c530b0a0d0cfac5c4c7', 'd1a7f03b773e5c212464b63709c6a891',
+ 'd0d1d2d3d5d6d7d8dadbdcdddfe0e1e2e4e5e6e7e9eaebeceeeff0f1f3f4f5f6',
+ 'ecb-tbl-256: I=80'),
+ ('8a8b8889050606f8f4f5f2f3636c6d6e', '6b7161d8745947ac6950438ea138d028',
+ 'f8f9fafbfdfefe00020304050708090a0c0d0e0f11121314161718191b1c1d1e',
+ 'ecb-tbl-256: I=81'),
+ ('a6a7a4a54d4e4f40b2b3b4b539262724', 'fd47a9f7e366ee7a09bc508b00460661',
+ '20212223252627282a2b2c2d2f30313234353637393a3b3c3e3f404143444546',
+ 'ecb-tbl-256: I=82'),
+ ('9c9d9e9fe9eaebf40e0f08099b949596', '00d40b003dc3a0d9310b659b98c7e416',
+ '48494a4b4d4e4f50525354555758595a5c5d5e5f61626364666768696b6c6d6e',
+ 'ecb-tbl-256: I=83'),
+ ('2d2c2f2e1013121dcccdcacbed121310', 'eea4c79dcc8e2bda691f20ac48be0717',
+ '70717273757677787a7b7c7d7f80818284858687898a8b8c8e8f909193949596',
+ 'ecb-tbl-256: I=84'),
+ ('f4f5f6f7edeeefd0eaebecedf7f8f9fa', 'e78f43b11c204403e5751f89d05a2509',
+ '98999a9b9d9e9fa0a2a3a4a5a7a8a9aaacadaeafb1b2b3b4b6b7b8b9bbbcbdbe',
+ 'ecb-tbl-256: I=85'),
+ ('3d3c3f3e282b2a2573727574150a0b08', 'd0f0e3d1f1244bb979931e38dd1786ef',
+ 'c0c1c2c3c5c6c7c8cacbcccdcfd0d1d2d4d5d6d7d9dadbdcdedfe0e1e3e4e5e6',
+ 'ecb-tbl-256: I=86'),
+ ('b6b7b4b5f8fbfae5b4b5b2b3a0afaead', '042e639dc4e1e4dde7b75b749ea6f765',
+ 'e8e9eaebedeeeff0f2f3f4f5f7f8f9fafcfdfeff01020304060708090b0c0d0e',
+ 'ecb-tbl-256: I=87'),
+ ('b7b6b5b4989b9a95878681809ba4a5a6', 'bc032fdd0efe29503a980a7d07ab46a8',
+ '10111213151617181a1b1c1d1f20212224252627292a2b2c2e2f303133343536',
+ 'ecb-tbl-256: I=88'),
+ ('a8a9aaabe5e6e798e9e8efee4748494a', '0c93ac949c0da6446effb86183b6c910',
+ '38393a3b3d3e3f40424344454748494a4c4d4e4f51525354565758595b5c5d5e',
+ 'ecb-tbl-256: I=89'),
+ ('ecedeeefd9dadbd4b9b8bfbe657a7b78', 'e0d343e14da75c917b4a5cec4810d7c2',
+ '60616263656667686a6b6c6d6f70717274757677797a7b7c7e7f808183848586',
+ 'ecb-tbl-256: I=90'),
+ ('7f7e7d7c696a6b74cacbcccd929d9c9f', '0eafb821748408279b937b626792e619',
+ '88898a8b8d8e8f90929394959798999a9c9d9e9fa1a2a3a4a6a7a8a9abacadae',
+ 'ecb-tbl-256: I=91'),
+ ('08090a0b0605040bfffef9f8b9c6c7c4', 'fa1ac6e02d23b106a1fef18b274a553f',
+ 'b0b1b2b3b5b6b7b8babbbcbdbfc0c1c2c4c5c6c7c9cacbcccecfd0d1d3d4d5d6',
+ 'ecb-tbl-256: I=92'),
+ ('08090a0bf1f2f3ccfcfdfafb68676665', '0dadfe019cd12368075507df33c1a1e9',
+ 'd8d9dadbdddedfe0e2e3e4e5e7e8e9eaecedeeeff1f2f3f4f6f7f8f9fbfcfdfe',
+ 'ecb-tbl-256: I=93'),
+ ('cacbc8c93a393837050403020d121310', '3a0879b414465d9ffbaf86b33a63a1b9',
+ '00010203050607080a0b0c0d0f10111214151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-256: I=94'),
+ ('e9e8ebea8281809f8f8e8988343b3a39', '62199fadc76d0be1805d3ba0b7d914bf',
+ '28292a2b2d2e2f30323334353738393a3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-256: I=95'),
+ ('515053524645444bd0d1d6d7340b0a09', '1b06d6c5d333e742730130cf78e719b4',
+ '50515253555657585a5b5c5d5f60616264656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-256: I=96'),
+ ('42434041ecefee1193929594c6c9c8cb', 'f1f848824c32e9dcdcbf21580f069329',
+ '78797a7b7d7e7f80828384858788898a8c8d8e8f91929394969798999b9c9d9e',
+ 'ecb-tbl-256: I=97'),
+ ('efeeedecc2c1c0cf76777071455a5b58', '1a09050cbd684f784d8e965e0782f28a',
+ 'a0a1a2a3a5a6a7a8aaabacadafb0b1b2b4b5b6b7b9babbbcbebfc0c1c3c4c5c6',
+ 'ecb-tbl-256: I=98'),
+ ('5f5e5d5c3f3c3d221d1c1b1a19161714', '79c2969e7ded2ba7d088f3f320692360',
+ 'c8c9cacbcdcecfd0d2d3d4d5d7d8d9dadcdddedfe1e2e3e4e6e7e8e9ebecedee',
+ 'ecb-tbl-256: I=99'),
+ ('000102034142434c1c1d1a1b8d727371', '091a658a2f7444c16accb669450c7b63',
+ 'f0f1f2f3f5f6f7f8fafbfcfdfe01000204050607090a0b0c0e0f101113141516',
+ 'ecb-tbl-256: I=100'),
+ ('8e8f8c8db1b2b38c56575051050a0b08', '97c1e3a72cca65fa977d5ed0e8a7bbfc',
+ '18191a1b1d1e1f20222324252728292a2c2d2e2f31323334363738393b3c3d3e',
+ 'ecb-tbl-256: I=101'),
+ ('a7a6a5a4e8ebeae57f7e7978cad5d4d7', '70c430c6db9a17828937305a2df91a2a',
+ '40414243454647484a4b4c4d4f50515254555657595a5b5c5e5f606163646566',
+ 'ecb-tbl-256: I=102'),
+ ('8a8b888994979689454443429f909192', '629553457fbe2479098571c7c903fde8',
+ '68696a6b6d6e6f70727374757778797a7c7d7e7f81828384868788898b8c8d8e',
+ 'ecb-tbl-256: I=103'),
+ ('8c8d8e8fe0e3e2ed45444342f1cecfcc', 'a25b25a61f612669e7d91265c7d476ba',
+ '90919293959697989a9b9c9d9fa0a1a2a4a5a6a7a9aaabacaeafb0b1b3b4b5b6',
+ 'ecb-tbl-256: I=104'),
+ ('fffefdfc4c4f4e31d8d9dedfb6b9b8bb', 'eb7e4e49b8ae0f024570dda293254fed',
+ 'b8b9babbbdbebfc0c2c3c4c5c7c8c9cacccdcecfd1d2d3d4d6d7d8d9dbdcddde',
+ 'ecb-tbl-256: I=105'),
+ ('fdfcfffecccfcec12f2e29286679787b', '38fe15d61cca84516e924adce5014f67',
+ 'e0e1e2e3e5e6e7e8eaebecedeff0f1f2f4f5f6f7f9fafbfcfefe010103040506',
+ 'ecb-tbl-256: I=106'),
+ ('67666564bab9b8a77071767719161714', '3ad208492249108c9f3ebeb167ad0583',
+ '08090a0b0d0e0f10121314151718191a1c1d1e1f21222324262728292b2c2d2e',
+ 'ecb-tbl-256: I=107'),
+ ('9a9b98992d2e2f2084858283245b5a59', '299ba9f9bf5ab05c3580fc26edd1ed12',
+ '30313233353637383a3b3c3d3f40414244454647494a4b4c4e4f505153545556',
+ 'ecb-tbl-256: I=108'),
+ ('a4a5a6a70b0809365c5d5a5b2c232221', '19dc705b857a60fb07717b2ea5717781',
+ '58595a5b5d5e5f60626364656768696a6c6d6e6f71727374767778797b7c7d7e',
+ 'ecb-tbl-256: I=109'),
+ ('464744455754555af3f2f5f4afb0b1b2', 'ffc8aeb885b5efcad06b6dbebf92e76b',
+ '80818283858687888a8b8c8d8f90919294959697999a9b9c9e9fa0a1a3a4a5a6',
+ 'ecb-tbl-256: I=110'),
+ ('323330317675746b7273747549464744', 'f58900c5e0b385253ff2546250a0142b',
+ 'a8a9aaabadaeafb0b2b3b4b5b7b8b9babcbdbebfc1c2c3c4c6c7c8c9cbcccdce',
+ 'ecb-tbl-256: I=111'),
+ ('a8a9aaab181b1a15808186872b141516', '2ee67b56280bc462429cee6e3370cbc1',
+ 'd0d1d2d3d5d6d7d8dadbdcdddfe0e1e2e4e5e6e7e9eaebeceeeff0f1f3f4f5f6',
+ 'ecb-tbl-256: I=112'),
+ ('e7e6e5e4202323ddaaabacad343b3a39', '20db650a9c8e9a84ab4d25f7edc8f03f',
+ 'f8f9fafbfdfefe00020304050708090a0c0d0e0f11121314161718191b1c1d1e',
+ 'ecb-tbl-256: I=113'),
+ ('a8a9aaab2221202fedecebea1e010003', '3c36da169525cf818843805f25b78ae5',
+ '20212223252627282a2b2c2d2f30313234353637393a3b3c3e3f404143444546',
+ 'ecb-tbl-256: I=114'),
+ ('f9f8fbfa5f5c5d42424344450e010003', '9a781d960db9e45e37779042fea51922',
+ '48494a4b4d4e4f50525354555758595a5c5d5e5f61626364666768696b6c6d6e',
+ 'ecb-tbl-256: I=115'),
+ ('57565554f5f6f7f89697909120dfdedd', '6560395ec269c672a3c288226efdba77',
+ '70717273757677787a7b7c7d7f80818284858687898a8b8c8e8f909193949596',
+ 'ecb-tbl-256: I=116'),
+ ('f8f9fafbcccfcef1dddcdbda0e010003', '8c772b7a189ac544453d5916ebb27b9a',
+ '98999a9b9d9e9fa0a2a3a4a5a7a8a9aaacadaeafb1b2b3b4b6b7b8b9bbbcbdbe',
+ 'ecb-tbl-256: I=117'),
+ ('d9d8dbda7073727d80818687c2dddcdf', '77ca5468cc48e843d05f78eed9d6578f',
+ 'c0c1c2c3c5c6c7c8cacbcccdcfd0d1d2d4d5d6d7d9dadbdcdedfe0e1e3e4e5e6',
+ 'ecb-tbl-256: I=118'),
+ ('c5c4c7c6080b0a1588898e8f68676665', '72cdcc71dc82c60d4429c9e2d8195baa',
+ 'e8e9eaebedeeeff0f2f3f4f5f7f8f9fafcfdfeff01020304060708090b0c0d0e',
+ 'ecb-tbl-256: I=119'),
+ ('83828180dcdfded186878081f0cfcecd', '8080d68ce60e94b40b5b8b69eeb35afa',
+ '10111213151617181a1b1c1d1f20212224252627292a2b2c2e2f303133343536',
+ 'ecb-tbl-256: I=120'),
+ ('98999a9bdddedfa079787f7e0a050407', '44222d3cde299c04369d58ac0eba1e8e',
+ '38393a3b3d3e3f40424344454748494a4c4d4e4f51525354565758595b5c5d5e',
+ 'ecb-tbl-256: I=121'),
+ ('cecfcccd4f4c4d429f9e9998dfc0c1c2', '9b8721b0a8dfc691c5bc5885dbfcb27a',
+ '60616263656667686a6b6c6d6f70717274757677797a7b7c7e7f808183848586',
+ 'ecb-tbl-256: I=122'),
+ ('404142436665647b29282f2eaba4a5a6', '0dc015ce9a3a3414b5e62ec643384183',
+ '88898a8b8d8e8f90929394959798999a9c9d9e9fa1a2a3a4a6a7a8a9abacadae',
+ 'ecb-tbl-256: I=123'),
+ ('33323130e6e5e4eb23222524dea1a0a3', '705715448a8da412025ce38345c2a148',
+ 'b0b1b2b3b5b6b7b8babbbcbdbfc0c1c2c4c5c6c7c9cacbcccecfd0d1d3d4d5d6',
+ 'ecb-tbl-256: I=124'),
+ ('cfcecdccf6f5f4cbe6e7e0e199969794', 'c32b5b0b6fbae165266c569f4b6ecf0b',
+ 'd8d9dadbdddedfe0e2e3e4e5e7e8e9eaecedeeeff1f2f3f4f6f7f8f9fbfcfdfe',
+ 'ecb-tbl-256: I=125'),
+ ('babbb8b97271707fdcdddadb29363734', '4dca6c75192a01ddca9476af2a521e87',
+ '00010203050607080a0b0c0d0f10111214151617191a1b1c1e1f202123242526',
+ 'ecb-tbl-256: I=126'),
+ ('c9c8cbca4447465926272021545b5a59', '058691e627ecbc36ac07b6db423bd698',
+ '28292a2b2d2e2f30323334353738393a3c3d3e3f41424344464748494b4c4d4e',
+ 'ecb-tbl-256: I=127'),
+ ('050407067477767956575051221d1c1f', '7444527095838fe080fc2bcdd30847eb',
+ '50515253555657585a5b5c5d5f60616264656667696a6b6c6e6f707173747576',
+ 'ecb-tbl-256: I=128'),
+
+    # NIST SP 800-38A test vectors, 2001 edition, Appendix F.
+
+ ('6bc1bee22e409f96e93d7e117393172a'+'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411e5fbc1191a0a52ef'+'f69f2445df4f9b17ad2b417be66c3710',
+ '3ad77bb40d7a3660a89ecaf32466ef97'+'f5d3d58503b9699de785895a96fdbaaf'+
+ '43b1cd7f598ece23881b00e3ed030688'+'7b0c785e27e8ad3f8223207104725dd4',
+ '2b7e151628aed2a6abf7158809cf4f3c',
+ 'NIST 800-38A, F.1.1, ECB and AES-128'),
+
+ ('6bc1bee22e409f96e93d7e117393172a'+'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411e5fbc1191a0a52ef'+'f69f2445df4f9b17ad2b417be66c3710',
+ 'bd334f1d6e45f25ff712a214571fa5cc'+'974104846d0ad3ad7734ecb3ecee4eef'+
+ 'ef7afd2270e2e60adce0ba2face6444e'+'9a4b41ba738d6c72fb16691603c18e0e',
+ '8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b',
+ 'NIST 800-38A, F.1.3, ECB and AES-192'),
+
+ ('6bc1bee22e409f96e93d7e117393172a'+'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411e5fbc1191a0a52ef'+'f69f2445df4f9b17ad2b417be66c3710',
+ 'f3eed1bdb5d2a03c064b5a7e3db181f8'+'591ccb10d410ed26dc5ba74a31362870'+
+ 'b6ed21b99ca6f4f9f153e7b1beafed1d'+'23304b7a39f9f3ff067d8d8f9e24ecc7',
+ '603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4',
+     'NIST 800-38A, F.1.5, ECB and AES-256'),
+
+]
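+
+# Illustrative sanity check, not itself one of the test cases above: each
+# tuple maps directly onto the AES API. For example, the first block of the
+# NIST 800-38A F.1.1 vector can be reproduced as follows (the _unhexlify
+# alias and _aes_ct are local to this check).
+from binascii import unhexlify as _unhexlify
+_aes_ct = AES.new(_unhexlify('2b7e151628aed2a6abf7158809cf4f3c'),
+                  AES.MODE_ECB).encrypt(
+                      _unhexlify('6bc1bee22e409f96e93d7e117393172a'))
+assert _aes_ct == _unhexlify('3ad77bb40d7a3660a89ecaf32466ef97')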
+
+test_data_8_lanes = []
+for td in test_data:
+ test_data_8_lanes.append((td[0] * 8, td[1] * 8, td[2], td[3]))
+test_data += test_data_8_lanes
+
+class TestMultipleBlocks(unittest.TestCase):
+
+ def __init__(self, use_aesni):
+ unittest.TestCase.__init__(self)
+ self.use_aesni = use_aesni
+
+ def runTest(self):
+        # Encrypt data which is 8*2+4 = 20 blocks (320 bytes) long, so as to
+        # trigger (for the AESNI variant) both the path that parallelizes
+        # 8 lanes and the one that processes the remaining blocks serially
+
+ tvs = [
+ (b'a' * 16, 'c0b27011eb15bf144d2fc9fae80ea16d4c231cb230416c5fac02e6835ad9d7d0'),
+ (b'a' * 24, 'df8435ce361a78c535b41dcb57da952abbf9ee5954dc6fbcd75fd00fa626915d'),
+ (b'a' * 32, '211402de6c80db1f92ba255881178e1f70783b8cfd3b37808205e48b80486cd8')
+ ]
+
+ for key, expected in tvs:
+
+ cipher = AES.new(key, AES.MODE_ECB, use_aesni=self.use_aesni)
+
+ pt = b"".join([ tobytes('{0:016x}'.format(x)) for x in range(20) ])
+ ct = cipher.encrypt(pt)
+ self.assertEqual(SHA256.new(ct).hexdigest(), expected)
+
+
+class TestIncompleteBlocks(unittest.TestCase):
+
+ def __init__(self, use_aesni):
+ unittest.TestCase.__init__(self)
+ self.use_aesni = use_aesni
+
+ def runTest(self):
+ # Encrypt data with length not multiple of 16 bytes
+
+ cipher = AES.new(b'4'*16, AES.MODE_ECB, use_aesni=self.use_aesni)
+
+ for msg_len in range(1, 16):
+ self.assertRaises(ValueError, cipher.encrypt, b'1' * msg_len)
+ self.assertRaises(ValueError, cipher.encrypt, b'1' * (msg_len+16))
+ self.assertRaises(ValueError, cipher.decrypt, b'1' * msg_len)
+ self.assertRaises(ValueError, cipher.decrypt, b'1' * (msg_len+16))
+
+ self.assertEqual(cipher.encrypt(b''), b'')
+ self.assertEqual(cipher.decrypt(b''), b'')
+
+
+class TestOutput(unittest.TestCase):
+
+ def __init__(self, use_aesni):
+ unittest.TestCase.__init__(self)
+ self.use_aesni = use_aesni
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
+
+ cipher = AES.new(b'4'*16, AES.MODE_ECB, use_aesni=self.use_aesni)
+
+ pt = b'5' * 16
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(16)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(16))
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
+
+ shorter_output = bytearray(15)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ from Crypto.Util import _cpu_features
+ from .common import make_block_tests
+
+ tests = make_block_tests(AES, "AES", test_data, {'use_aesni': False})
+ tests += [ TestMultipleBlocks(False) ]
+ tests += [ TestIncompleteBlocks(False) ]
+ if _cpu_features.have_aes_ni():
+ # Run tests with AES-NI instructions if they are available.
+ tests += make_block_tests(AES, "AESNI", test_data, {'use_aesni': True})
+ tests += [ TestMultipleBlocks(True) ]
+ tests += [ TestIncompleteBlocks(True) ]
+ tests += [ TestOutput(True) ]
+ else:
+ print("Skipping AESNI tests")
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Cipher/test_ARC2.py b/lib/Crypto/SelfTest/Cipher/test_ARC2.py
new file mode 100644
index 0000000..fd9448c
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_ARC2.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/test_ARC2.py: Self-test for the Alleged-RC2 cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.ARC2"""
+
+import unittest
+
+from Crypto.Util.py3compat import b, bchr
+
+from Crypto.Cipher import ARC2
+
+# This is a list of (plaintext, ciphertext, key[, description[, extra_params]]) tuples.
+test_data = [
+ # Test vectors from RFC 2268
+
+ # 63-bit effective key length
+ ('0000000000000000', 'ebb773f993278eff', '0000000000000000',
+ 'RFC2268-1', dict(effective_keylen=63)),
+
+ # 64-bit effective key length
+ ('ffffffffffffffff', '278b27e42e2f0d49', 'ffffffffffffffff',
+ 'RFC2268-2', dict(effective_keylen=64)),
+ ('1000000000000001', '30649edf9be7d2c2', '3000000000000000',
+ 'RFC2268-3', dict(effective_keylen=64)),
+    # Disabled: the 1-byte key '88' is shorter than the minimum key length
+    # this implementation accepts (the KeyLength test below shows that even
+    # a 4-byte key raises ValueError).
+    #('0000000000000000', '61a8a244adacccf0', '88',
+    # 'RFC2268-4', dict(effective_keylen=64)),
+ ('0000000000000000', '6ccf4308974c267f', '88bca90e90875a',
+ 'RFC2268-5', dict(effective_keylen=64)),
+ ('0000000000000000', '1a807d272bbe5db1', '88bca90e90875a7f0f79c384627bafb2',
+ 'RFC2268-6', dict(effective_keylen=64)),
+
+ # 128-bit effective key length
+ ('0000000000000000', '2269552ab0f85ca6', '88bca90e90875a7f0f79c384627bafb2',
+ "RFC2268-7", dict(effective_keylen=128)),
+ ('0000000000000000', '5b78d3a43dfff1f1',
+ '88bca90e90875a7f0f79c384627bafb216f80a6f85920584c42fceb0be255daf1e',
+ "RFC2268-8", dict(effective_keylen=129)),
+
+ # Test vectors from PyCrypto 2.0.1's testdata.py
+ # 1024-bit effective key length
+ ('0000000000000000', '624fb3e887419e48', '5068696c6970476c617373',
+ 'PCTv201-0'),
+ ('ffffffffffffffff', '79cadef44c4a5a85', '5068696c6970476c617373',
+ 'PCTv201-1'),
+ ('0001020304050607', '90411525b34e4c2c', '5068696c6970476c617373',
+ 'PCTv201-2'),
+ ('0011223344556677', '078656aaba61cbfb', '5068696c6970476c617373',
+ 'PCTv201-3'),
+ ('0000000000000000', 'd7bcc5dbb4d6e56a', 'ffffffffffffffff',
+ 'PCTv201-4'),
+ ('ffffffffffffffff', '7259018ec557b357', 'ffffffffffffffff',
+ 'PCTv201-5'),
+ ('0001020304050607', '93d20a497f2ccb62', 'ffffffffffffffff',
+ 'PCTv201-6'),
+ ('0011223344556677', 'cb15a7f819c0014d', 'ffffffffffffffff',
+ 'PCTv201-7'),
+ ('0000000000000000', '63ac98cdf3843a7a', 'ffffffffffffffff5065746572477265656e6177617953e5ffe553',
+ 'PCTv201-8'),
+ ('ffffffffffffffff', '3fb49e2fa12371dd', 'ffffffffffffffff5065746572477265656e6177617953e5ffe553',
+ 'PCTv201-9'),
+ ('0001020304050607', '46414781ab387d5f', 'ffffffffffffffff5065746572477265656e6177617953e5ffe553',
+ 'PCTv201-10'),
+ ('0011223344556677', 'be09dc81feaca271', 'ffffffffffffffff5065746572477265656e6177617953e5ffe553',
+ 'PCTv201-11'),
+ ('0000000000000000', 'e64221e608be30ab', '53e5ffe553',
+ 'PCTv201-12'),
+ ('ffffffffffffffff', '862bc60fdcd4d9a9', '53e5ffe553',
+ 'PCTv201-13'),
+ ('0001020304050607', '6a34da50fa5e47de', '53e5ffe553',
+ 'PCTv201-14'),
+ ('0011223344556677', '584644c34503122c', '53e5ffe553',
+ 'PCTv201-15'),
+]
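+
+# Illustrative sanity check, not itself one of the test cases above: each
+# tuple maps directly onto the ARC2 API. For example, the 'RFC2268-1' vector
+# can be reproduced as follows (unhexlify, _rc2 and _rc2_ct are local to
+# this check).
+from binascii import unhexlify
+_rc2 = ARC2.new(unhexlify('0000000000000000'), ARC2.MODE_ECB,
+                effective_keylen=63)
+_rc2_ct = _rc2.encrypt(unhexlify('0000000000000000'))
+assert _rc2_ct == unhexlify('ebb773f993278eff')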
+
+class BufferOverflowTest(unittest.TestCase):
+ # Test a buffer overflow found in older versions of PyCrypto
+
+ def runTest(self):
+ """ARC2 with keylength > 128"""
+ key = b("x") * 16384
+ self.assertRaises(ValueError, ARC2.new, key, ARC2.MODE_ECB)
+
+class KeyLength(unittest.TestCase):
+
+ def runTest(self):
+ ARC2.new(b'\x00' * 16, ARC2.MODE_ECB, effective_keylen=40)
+ self.assertRaises(ValueError, ARC2.new, bchr(0) * 4, ARC2.MODE_ECB)
+ self.assertRaises(ValueError, ARC2.new, bchr(0) * 129, ARC2.MODE_ECB)
+
+ self.assertRaises(ValueError, ARC2.new, bchr(0) * 16, ARC2.MODE_ECB,
+ effective_keylen=39)
+ self.assertRaises(ValueError, ARC2.new, bchr(0) * 16, ARC2.MODE_ECB,
+ effective_keylen=1025)
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
+
+ cipher = ARC2.new(b'4'*16, ARC2.MODE_ECB)
+
+ pt = b'5' * 16
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(16)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(16))
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
+
+ shorter_output = bytearray(7)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ from Crypto.Cipher import ARC2
+ from .common import make_block_tests
+
+ tests = make_block_tests(ARC2, "ARC2", test_data)
+ tests.append(BufferOverflowTest())
+ tests.append(KeyLength())
+ tests += [TestOutput()]
+
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Cipher/test_ARC4.py b/lib/Crypto/SelfTest/Cipher/test_ARC4.py
new file mode 100644
index 0000000..1b9a7d8
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_ARC4.py
@@ -0,0 +1,466 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/test_ARC4.py: Self-test for the Alleged-RC4 cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.ARC4"""
+
+__revision__ = "$Id$"
+
+from Crypto.Util.py3compat import *
+from Crypto.SelfTest.st_common import *
+from binascii import unhexlify
+
+from Crypto.Cipher import ARC4
+
+# This is a list of (plaintext, ciphertext, key[, description]) tuples.
+test_data = [
+ # Test vectors from Eric Rescorla's message with the subject
+ # "RC4 compatibility testing", sent to the cipherpunks mailing list on
+    # "RC4 compatibility testing", sent to the cypherpunks mailing list on
+ # http://cypherpunks.venona.com/date/1994/09/msg00420.html
+
+ ('0123456789abcdef', '75b7878099e0c596', '0123456789abcdef',
+ 'Test vector 0'),
+
+ ('0000000000000000', '7494c2e7104b0879', '0123456789abcdef',
+ 'Test vector 1'),
+
+ ('0000000000000000', 'de188941a3375d3a', '0000000000000000',
+ 'Test vector 2'),
+
+    # Disabled: the 4-byte key 'ef012345' is presumably below the minimum
+    # key length accepted by this ARC4 implementation.
+    #('00000000000000000000', 'd6a141a7ec3c38dfbd61', 'ef012345',
+    # 'Test vector 3'),
+
+ ('01' * 512,
+ '7595c3e6114a09780c4ad452338e1ffd9a1be9498f813d76533449b6778dcad8'
+ + 'c78a8d2ba9ac66085d0e53d59c26c2d1c490c1ebbe0ce66d1b6b1b13b6b919b8'
+ + '47c25a91447a95e75e4ef16779cde8bf0a95850e32af9689444fd377108f98fd'
+ + 'cbd4e726567500990bcc7e0ca3c4aaa304a387d20f3b8fbbcd42a1bd311d7a43'
+ + '03dda5ab078896ae80c18b0af66dff319616eb784e495ad2ce90d7f772a81747'
+ + 'b65f62093b1e0db9e5ba532fafec47508323e671327df9444432cb7367cec82f'
+ + '5d44c0d00b67d650a075cd4b70dedd77eb9b10231b6b5b741347396d62897421'
+ + 'd43df9b42e446e358e9c11a9b2184ecbef0cd8e7a877ef968f1390ec9b3d35a5'
+ + '585cb009290e2fcde7b5ec66d9084be44055a619d9dd7fc3166f9487f7cb2729'
+ + '12426445998514c15d53a18c864ce3a2b7555793988126520eacf2e3066e230c'
+ + '91bee4dd5304f5fd0405b35bd99c73135d3d9bc335ee049ef69b3867bf2d7bd1'
+ + 'eaa595d8bfc0066ff8d31509eb0c6caa006c807a623ef84c3d33c195d23ee320'
+ + 'c40de0558157c822d4b8c569d849aed59d4e0fd7f379586b4b7ff684ed6a189f'
+ + '7486d49b9c4bad9ba24b96abf924372c8a8fffb10d55354900a77a3db5f205e1'
+ + 'b99fcd8660863a159ad4abe40fa48934163ddde542a6585540fd683cbfd8c00f'
+ + '12129a284deacc4cdefe58be7137541c047126c8d49e2755ab181ab7e940b0c0',
+ '0123456789abcdef',
+ "Test vector 4"),
+]
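+
+# Illustrative sanity check, not itself one of the test cases above: each
+# tuple maps directly onto the ARC4 API. For example, 'Test vector 0' can be
+# reproduced as follows (_rc4_ct is local to this check).
+_rc4_ct = ARC4.new(unhexlify('0123456789abcdef')).encrypt(
+    unhexlify('0123456789abcdef'))
+assert _rc4_ct == unhexlify('75b7878099e0c596')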
+
+class RFC6229_Tests(unittest.TestCase):
+ # Test vectors from RFC 6229. Each test vector is a tuple with two items:
+    # the ARC4 key and a dictionary. The dictionary maps each keystream offset
+    # to the 16 keystream bytes that start at that offset.
+ rfc6229_data = [
+ # Page 3
+ (
+ '0102030405',
+ {
+ 0: 'b2 39 63 05 f0 3d c0 27 cc c3 52 4a 0a 11 18 a8',
+ 16: '69 82 94 4f 18 fc 82 d5 89 c4 03 a4 7a 0d 09 19',
+ 240: '28 cb 11 32 c9 6c e2 86 42 1d ca ad b8 b6 9e ae',
+ 256: '1c fc f6 2b 03 ed db 64 1d 77 df cf 7f 8d 8c 93',
+ 496: '42 b7 d0 cd d9 18 a8 a3 3d d5 17 81 c8 1f 40 41',
+ 512: '64 59 84 44 32 a7 da 92 3c fb 3e b4 98 06 61 f6',
+ 752: 'ec 10 32 7b de 2b ee fd 18 f9 27 76 80 45 7e 22',
+ 768: 'eb 62 63 8d 4f 0b a1 fe 9f ca 20 e0 5b f8 ff 2b',
+ 1008:'45 12 90 48 e6 a0 ed 0b 56 b4 90 33 8f 07 8d a5',
+ 1024:'30 ab bc c7 c2 0b 01 60 9f 23 ee 2d 5f 6b b7 df',
+ 1520:'32 94 f7 44 d8 f9 79 05 07 e7 0f 62 e5 bb ce ea',
+ 1536:'d8 72 9d b4 18 82 25 9b ee 4f 82 53 25 f5 a1 30',
+ 2032:'1e b1 4a 0c 13 b3 bf 47 fa 2a 0b a9 3a d4 5b 8b',
+ 2048:'cc 58 2f 8b a9 f2 65 e2 b1 be 91 12 e9 75 d2 d7',
+ 3056:'f2 e3 0f 9b d1 02 ec bf 75 aa ad e9 bc 35 c4 3c',
+ 3072:'ec 0e 11 c4 79 dc 32 9d c8 da 79 68 fe 96 56 81',
+ 4080:'06 83 26 a2 11 84 16 d2 1f 9d 04 b2 cd 1c a0 50',
+ 4096:'ff 25 b5 89 95 99 67 07 e5 1f bd f0 8b 34 d8 75'
+ }
+ ),
+ # Page 4
+ (
+ '01020304050607',
+ {
+ 0: '29 3f 02 d4 7f 37 c9 b6 33 f2 af 52 85 fe b4 6b',
+ 16: 'e6 20 f1 39 0d 19 bd 84 e2 e0 fd 75 20 31 af c1',
+ 240: '91 4f 02 53 1c 92 18 81 0d f6 0f 67 e3 38 15 4c',
+ 256: 'd0 fd b5 83 07 3c e8 5a b8 39 17 74 0e c0 11 d5',
+ 496: '75 f8 14 11 e8 71 cf fa 70 b9 0c 74 c5 92 e4 54',
+ 512: '0b b8 72 02 93 8d ad 60 9e 87 a5 a1 b0 79 e5 e4',
+ 752: 'c2 91 12 46 b6 12 e7 e7 b9 03 df ed a1 da d8 66',
+ 768: '32 82 8f 91 50 2b 62 91 36 8d e8 08 1d e3 6f c2',
+ 1008:'f3 b9 a7 e3 b2 97 bf 9a d8 04 51 2f 90 63 ef f1',
+ 1024:'8e cb 67 a9 ba 1f 55 a5 a0 67 e2 b0 26 a3 67 6f',
+ 1520:'d2 aa 90 2b d4 2d 0d 7c fd 34 0c d4 58 10 52 9f',
+ 1536:'78 b2 72 c9 6e 42 ea b4 c6 0b d9 14 e3 9d 06 e3',
+ 2032:'f4 33 2f d3 1a 07 93 96 ee 3c ee 3f 2a 4f f0 49',
+ 2048:'05 45 97 81 d4 1f da 7f 30 c1 be 7e 12 46 c6 23',
+ 3056:'ad fd 38 68 b8 e5 14 85 d5 e6 10 01 7e 3d d6 09',
+ 3072:'ad 26 58 1c 0c 5b e4 5f 4c ea 01 db 2f 38 05 d5',
+ 4080:'f3 17 2c ef fc 3b 3d 99 7c 85 cc d5 af 1a 95 0c',
+ 4096:'e7 4b 0b 97 31 22 7f d3 7c 0e c0 8a 47 dd d8 b8'
+ }
+ ),
+ (
+ '0102030405060708',
+ {
+ 0: '97 ab 8a 1b f0 af b9 61 32 f2 f6 72 58 da 15 a8',
+ 16: '82 63 ef db 45 c4 a1 86 84 ef 87 e6 b1 9e 5b 09',
+ 240: '96 36 eb c9 84 19 26 f4 f7 d1 f3 62 bd df 6e 18',
+ 256: 'd0 a9 90 ff 2c 05 fe f5 b9 03 73 c9 ff 4b 87 0a',
+ 496: '73 23 9f 1d b7 f4 1d 80 b6 43 c0 c5 25 18 ec 63',
+ 512: '16 3b 31 99 23 a6 bd b4 52 7c 62 61 26 70 3c 0f',
+ 752: '49 d6 c8 af 0f 97 14 4a 87 df 21 d9 14 72 f9 66',
+ 768: '44 17 3a 10 3b 66 16 c5 d5 ad 1c ee 40 c8 63 d0',
+ 1008:'27 3c 9c 4b 27 f3 22 e4 e7 16 ef 53 a4 7d e7 a4',
+ 1024:'c6 d0 e7 b2 26 25 9f a9 02 34 90 b2 61 67 ad 1d',
+ 1520:'1f e8 98 67 13 f0 7c 3d 9a e1 c1 63 ff 8c f9 d3',
+ 1536:'83 69 e1 a9 65 61 0b e8 87 fb d0 c7 91 62 aa fb',
+ 2032:'0a 01 27 ab b4 44 84 b9 fb ef 5a bc ae 1b 57 9f',
+ 2048:'c2 cd ad c6 40 2e 8e e8 66 e1 f3 7b db 47 e4 2c',
+ 3056:'26 b5 1e a3 7d f8 e1 d6 f7 6f c3 b6 6a 74 29 b3',
+ 3072:'bc 76 83 20 5d 4f 44 3d c1 f2 9d da 33 15 c8 7b',
+ 4080:'d5 fa 5a 34 69 d2 9a aa f8 3d 23 58 9d b8 c8 5b',
+ 4096:'3f b4 6e 2c 8f 0f 06 8e dc e8 cd cd 7d fc 58 62'
+ }
+ ),
+ # Page 5
+ (
+ '0102030405060708090a',
+ {
+ 0: 'ed e3 b0 46 43 e5 86 cc 90 7d c2 18 51 70 99 02',
+ 16: '03 51 6b a7 8f 41 3b eb 22 3a a5 d4 d2 df 67 11',
+ 240: '3c fd 6c b5 8e e0 fd de 64 01 76 ad 00 00 04 4d',
+ 256: '48 53 2b 21 fb 60 79 c9 11 4c 0f fd 9c 04 a1 ad',
+ 496: '3e 8c ea 98 01 71 09 97 90 84 b1 ef 92 f9 9d 86',
+ 512: 'e2 0f b4 9b db 33 7e e4 8b 8d 8d c0 f4 af ef fe',
+ 752: '5c 25 21 ea cd 79 66 f1 5e 05 65 44 be a0 d3 15',
+ 768: 'e0 67 a7 03 19 31 a2 46 a6 c3 87 5d 2f 67 8a cb',
+ 1008:'a6 4f 70 af 88 ae 56 b6 f8 75 81 c0 e2 3e 6b 08',
+ 1024:'f4 49 03 1d e3 12 81 4e c6 f3 19 29 1f 4a 05 16',
+ 1520:'bd ae 85 92 4b 3c b1 d0 a2 e3 3a 30 c6 d7 95 99',
+ 1536:'8a 0f ed db ac 86 5a 09 bc d1 27 fb 56 2e d6 0a',
+ 2032:'b5 5a 0a 5b 51 a1 2a 8b e3 48 99 c3 e0 47 51 1a',
+ 2048:'d9 a0 9c ea 3c e7 5f e3 96 98 07 03 17 a7 13 39',
+ 3056:'55 22 25 ed 11 77 f4 45 84 ac 8c fa 6c 4e b5 fc',
+ 3072:'7e 82 cb ab fc 95 38 1b 08 09 98 44 21 29 c2 f8',
+ 4080:'1f 13 5e d1 4c e6 0a 91 36 9d 23 22 be f2 5e 3c',
+ 4096:'08 b6 be 45 12 4a 43 e2 eb 77 95 3f 84 dc 85 53'
+ }
+ ),
+ (
+ '0102030405060708090a0b0c0d0e0f10',
+ {
+ 0: '9a c7 cc 9a 60 9d 1e f7 b2 93 28 99 cd e4 1b 97',
+ 16: '52 48 c4 95 90 14 12 6a 6e 8a 84 f1 1d 1a 9e 1c',
+ 240: '06 59 02 e4 b6 20 f6 cc 36 c8 58 9f 66 43 2f 2b',
+ 256: 'd3 9d 56 6b c6 bc e3 01 07 68 15 15 49 f3 87 3f',
+ 496: 'b6 d1 e6 c4 a5 e4 77 1c ad 79 53 8d f2 95 fb 11',
+ 512: 'c6 8c 1d 5c 55 9a 97 41 23 df 1d bc 52 a4 3b 89',
+ 752: 'c5 ec f8 8d e8 97 fd 57 fe d3 01 70 1b 82 a2 59',
+ 768: 'ec cb e1 3d e1 fc c9 1c 11 a0 b2 6c 0b c8 fa 4d',
+ 1008:'e7 a7 25 74 f8 78 2a e2 6a ab cf 9e bc d6 60 65',
+ 1024:'bd f0 32 4e 60 83 dc c6 d3 ce dd 3c a8 c5 3c 16',
+ 1520:'b4 01 10 c4 19 0b 56 22 a9 61 16 b0 01 7e d2 97',
+ 1536:'ff a0 b5 14 64 7e c0 4f 63 06 b8 92 ae 66 11 81',
+ 2032:'d0 3d 1b c0 3c d3 3d 70 df f9 fa 5d 71 96 3e bd',
+ 2048:'8a 44 12 64 11 ea a7 8b d5 1e 8d 87 a8 87 9b f5',
+ 3056:'fa be b7 60 28 ad e2 d0 e4 87 22 e4 6c 46 15 a3',
+ 3072:'c0 5d 88 ab d5 03 57 f9 35 a6 3c 59 ee 53 76 23',
+ 4080:'ff 38 26 5c 16 42 c1 ab e8 d3 c2 fe 5e 57 2b f8',
+ 4096:'a3 6a 4c 30 1a e8 ac 13 61 0c cb c1 22 56 ca cc'
+ }
+ ),
+ # Page 6
+ (
+ '0102030405060708090a0b0c0d0e0f101112131415161718',
+ {
+ 0: '05 95 e5 7f e5 f0 bb 3c 70 6e da c8 a4 b2 db 11',
+ 16: 'df de 31 34 4a 1a f7 69 c7 4f 07 0a ee 9e 23 26',
+ 240: 'b0 6b 9b 1e 19 5d 13 d8 f4 a7 99 5c 45 53 ac 05',
+ 256: '6b d2 37 8e c3 41 c9 a4 2f 37 ba 79 f8 8a 32 ff',
+ 496: 'e7 0b ce 1d f7 64 5a db 5d 2c 41 30 21 5c 35 22',
+ 512: '9a 57 30 c7 fc b4 c9 af 51 ff da 89 c7 f1 ad 22',
+ 752: '04 85 05 5f d4 f6 f0 d9 63 ef 5a b9 a5 47 69 82',
+ 768: '59 1f c6 6b cd a1 0e 45 2b 03 d4 55 1f 6b 62 ac',
+ 1008:'27 53 cc 83 98 8a fa 3e 16 88 a1 d3 b4 2c 9a 02',
+ 1024:'93 61 0d 52 3d 1d 3f 00 62 b3 c2 a3 bb c7 c7 f0',
+ 1520:'96 c2 48 61 0a ad ed fe af 89 78 c0 3d e8 20 5a',
+ 1536:'0e 31 7b 3d 1c 73 b9 e9 a4 68 8f 29 6d 13 3a 19',
+ 2032:'bd f0 e6 c3 cc a5 b5 b9 d5 33 b6 9c 56 ad a1 20',
+ 2048:'88 a2 18 b6 e2 ec e1 e6 24 6d 44 c7 59 d1 9b 10',
+ 3056:'68 66 39 7e 95 c1 40 53 4f 94 26 34 21 00 6e 40',
+ 3072:'32 cb 0a 1e 95 42 c6 b3 b8 b3 98 ab c3 b0 f1 d5',
+ 4080:'29 a0 b8 ae d5 4a 13 23 24 c6 2e 42 3f 54 b4 c8',
+ 4096:'3c b0 f3 b5 02 0a 98 b8 2a f9 fe 15 44 84 a1 68'
+ }
+ ),
+ (
+ '0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20',
+ {
+ 0: 'ea a6 bd 25 88 0b f9 3d 3f 5d 1e 4c a2 61 1d 91',
+ 16: 'cf a4 5c 9f 7e 71 4b 54 bd fa 80 02 7c b1 43 80',
+ 240: '11 4a e3 44 de d7 1b 35 f2 e6 0f eb ad 72 7f d8',
+ 256: '02 e1 e7 05 6b 0f 62 39 00 49 64 22 94 3e 97 b6',
+ 496: '91 cb 93 c7 87 96 4e 10 d9 52 7d 99 9c 6f 93 6b',
+ 512: '49 b1 8b 42 f8 e8 36 7c be b5 ef 10 4b a1 c7 cd',
+ 752: '87 08 4b 3b a7 00 ba de 95 56 10 67 27 45 b3 74',
+ 768: 'e7 a7 b9 e9 ec 54 0d 5f f4 3b db 12 79 2d 1b 35',
+ 1008:'c7 99 b5 96 73 8f 6b 01 8c 76 c7 4b 17 59 bd 90',
+ 1024:'7f ec 5b fd 9f 9b 89 ce 65 48 30 90 92 d7 e9 58',
+ 1520:'40 f2 50 b2 6d 1f 09 6a 4a fd 4c 34 0a 58 88 15',
+ 1536:'3e 34 13 5c 79 db 01 02 00 76 76 51 cf 26 30 73',
+ 2032:'f6 56 ab cc f8 8d d8 27 02 7b 2c e9 17 d4 64 ec',
+ 2048:'18 b6 25 03 bf bc 07 7f ba bb 98 f2 0d 98 ab 34',
+ 3056:'8a ed 95 ee 5b 0d cb fb ef 4e b2 1d 3a 3f 52 f9',
+ 3072:'62 5a 1a b0 0e e3 9a 53 27 34 6b dd b0 1a 9c 18',
+ 4080:'a1 3a 7c 79 c7 e1 19 b5 ab 02 96 ab 28 c3 00 b9',
+ 4096:'f3 e4 c0 a2 e0 2d 1d 01 f7 f0 a7 46 18 af 2b 48'
+ }
+ ),
+ # Page 7
+ (
+ '833222772a',
+ {
+ 0: '80 ad 97 bd c9 73 df 8a 2e 87 9e 92 a4 97 ef da',
+ 16: '20 f0 60 c2 f2 e5 12 65 01 d3 d4 fe a1 0d 5f c0',
+ 240: 'fa a1 48 e9 90 46 18 1f ec 6b 20 85 f3 b2 0e d9',
+ 256: 'f0 da f5 ba b3 d5 96 83 98 57 84 6f 73 fb fe 5a',
+ 496: '1c 7e 2f c4 63 92 32 fe 29 75 84 b2 96 99 6b c8',
+ 512: '3d b9 b2 49 40 6c c8 ed ff ac 55 cc d3 22 ba 12',
+ 752: 'e4 f9 f7 e0 06 61 54 bb d1 25 b7 45 56 9b c8 97',
+ 768: '75 d5 ef 26 2b 44 c4 1a 9c f6 3a e1 45 68 e1 b9',
+ 1008:'6d a4 53 db f8 1e 82 33 4a 3d 88 66 cb 50 a1 e3',
+ 1024:'78 28 d0 74 11 9c ab 5c 22 b2 94 d7 a9 bf a0 bb',
+ 1520:'ad b8 9c ea 9a 15 fb e6 17 29 5b d0 4b 8c a0 5c',
+ 1536:'62 51 d8 7f d4 aa ae 9a 7e 4a d5 c2 17 d3 f3 00',
+ 2032:'e7 11 9b d6 dd 9b 22 af e8 f8 95 85 43 28 81 e2',
+ 2048:'78 5b 60 fd 7e c4 e9 fc b6 54 5f 35 0d 66 0f ab',
+ 3056:'af ec c0 37 fd b7 b0 83 8e b3 d7 0b cd 26 83 82',
+ 3072:'db c1 a7 b4 9d 57 35 8c c9 fa 6d 61 d7 3b 7c f0',
+ 4080:'63 49 d1 26 a3 7a fc ba 89 79 4f 98 04 91 4f dc',
+ 4096:'bf 42 c3 01 8c 2f 7c 66 bf de 52 49 75 76 81 15'
+ }
+ ),
+ (
+ '1910833222772a',
+ {
+ 0: 'bc 92 22 db d3 27 4d 8f c6 6d 14 cc bd a6 69 0b',
+ 16: '7a e6 27 41 0c 9a 2b e6 93 df 5b b7 48 5a 63 e3',
+ 240: '3f 09 31 aa 03 de fb 30 0f 06 01 03 82 6f 2a 64',
+ 256: 'be aa 9e c8 d5 9b b6 81 29 f3 02 7c 96 36 11 81',
+ 496: '74 e0 4d b4 6d 28 64 8d 7d ee 8a 00 64 b0 6c fe',
+ 512: '9b 5e 81 c6 2f e0 23 c5 5b e4 2f 87 bb f9 32 b8',
+ 752: 'ce 17 8f c1 82 6e fe cb c1 82 f5 79 99 a4 61 40',
+ 768: '8b df 55 cd 55 06 1c 06 db a6 be 11 de 4a 57 8a',
+ 1008:'62 6f 5f 4d ce 65 25 01 f3 08 7d 39 c9 2c c3 49',
+ 1024:'42 da ac 6a 8f 9a b9 a7 fd 13 7c 60 37 82 56 82',
+ 1520:'cc 03 fd b7 91 92 a2 07 31 2f 53 f5 d4 dc 33 d9',
+ 1536:'f7 0f 14 12 2a 1c 98 a3 15 5d 28 b8 a0 a8 a4 1d',
+ 2032:'2a 3a 30 7a b2 70 8a 9c 00 fe 0b 42 f9 c2 d6 a1',
+ 2048:'86 26 17 62 7d 22 61 ea b0 b1 24 65 97 ca 0a e9',
+ 3056:'55 f8 77 ce 4f 2e 1d db bf 8e 13 e2 cd e0 fd c8',
+ 3072:'1b 15 56 cb 93 5f 17 33 37 70 5f bb 5d 50 1f c1',
+ 4080:'ec d0 e9 66 02 be 7f 8d 50 92 81 6c cc f2 c2 e9',
+ 4096:'02 78 81 fa b4 99 3a 1c 26 20 24 a9 4f ff 3f 61'
+ }
+ ),
+ # Page 8
+ (
+ '641910833222772a',
+ {
+ 0: 'bb f6 09 de 94 13 17 2d 07 66 0c b6 80 71 69 26',
+ 16: '46 10 1a 6d ab 43 11 5d 6c 52 2b 4f e9 36 04 a9',
+ 240: 'cb e1 ff f2 1c 96 f3 ee f6 1e 8f e0 54 2c bd f0',
+ 256: '34 79 38 bf fa 40 09 c5 12 cf b4 03 4b 0d d1 a7',
+ 496: '78 67 a7 86 d0 0a 71 47 90 4d 76 dd f1 e5 20 e3',
+ 512: '8d 3e 9e 1c ae fc cc b3 fb f8 d1 8f 64 12 0b 32',
+ 752: '94 23 37 f8 fd 76 f0 fa e8 c5 2d 79 54 81 06 72',
+ 768: 'b8 54 8c 10 f5 16 67 f6 e6 0e 18 2f a1 9b 30 f7',
+ 1008:'02 11 c7 c6 19 0c 9e fd 12 37 c3 4c 8f 2e 06 c4',
+ 1024:'bd a6 4f 65 27 6d 2a ac b8 f9 02 12 20 3a 80 8e',
+ 1520:'bd 38 20 f7 32 ff b5 3e c1 93 e7 9d 33 e2 7c 73',
+ 1536:'d0 16 86 16 86 19 07 d4 82 e3 6c da c8 cf 57 49',
+ 2032:'97 b0 f0 f2 24 b2 d2 31 71 14 80 8f b0 3a f7 a0',
+ 2048:'e5 96 16 e4 69 78 79 39 a0 63 ce ea 9a f9 56 d1',
+ 3056:'c4 7e 0d c1 66 09 19 c1 11 01 20 8f 9e 69 aa 1f',
+ 3072:'5a e4 f1 28 96 b8 37 9a 2a ad 89 b5 b5 53 d6 b0',
+ 4080:'6b 6b 09 8d 0c 29 3b c2 99 3d 80 bf 05 18 b6 d9',
+ 4096:'81 70 cc 3c cd 92 a6 98 62 1b 93 9d d3 8f e7 b9'
+ }
+ ),
+ (
+ '8b37641910833222772a',
+ {
+ 0: 'ab 65 c2 6e dd b2 87 60 0d b2 fd a1 0d 1e 60 5c',
+ 16: 'bb 75 90 10 c2 96 58 f2 c7 2d 93 a2 d1 6d 29 30',
+ 240: 'b9 01 e8 03 6e d1 c3 83 cd 3c 4c 4d d0 a6 ab 05',
+ 256: '3d 25 ce 49 22 92 4c 55 f0 64 94 33 53 d7 8a 6c',
+ 496: '12 c1 aa 44 bb f8 7e 75 e6 11 f6 9b 2c 38 f4 9b',
+ 512: '28 f2 b3 43 4b 65 c0 98 77 47 00 44 c6 ea 17 0d',
+ 752: 'bd 9e f8 22 de 52 88 19 61 34 cf 8a f7 83 93 04',
+ 768: '67 55 9c 23 f0 52 15 84 70 a2 96 f7 25 73 5a 32',
+ 1008:'8b ab 26 fb c2 c1 2b 0f 13 e2 ab 18 5e ab f2 41',
+ 1024:'31 18 5a 6d 69 6f 0c fa 9b 42 80 8b 38 e1 32 a2',
+ 1520:'56 4d 3d ae 18 3c 52 34 c8 af 1e 51 06 1c 44 b5',
+ 1536:'3c 07 78 a7 b5 f7 2d 3c 23 a3 13 5c 7d 67 b9 f4',
+ 2032:'f3 43 69 89 0f cf 16 fb 51 7d ca ae 44 63 b2 dd',
+ 2048:'02 f3 1c 81 e8 20 07 31 b8 99 b0 28 e7 91 bf a7',
+ 3056:'72 da 64 62 83 22 8c 14 30 08 53 70 17 95 61 6f',
+ 3072:'4e 0a 8c 6f 79 34 a7 88 e2 26 5e 81 d6 d0 c8 f4',
+ 4080:'43 8d d5 ea fe a0 11 1b 6f 36 b4 b9 38 da 2a 68',
+ 4096:'5f 6b fc 73 81 58 74 d9 71 00 f0 86 97 93 57 d8'
+ }
+ ),
+ # Page 9
+ (
+ 'ebb46227c6cc8b37641910833222772a',
+ {
+ 0: '72 0c 94 b6 3e df 44 e1 31 d9 50 ca 21 1a 5a 30',
+ 16: 'c3 66 fd ea cf 9c a8 04 36 be 7c 35 84 24 d2 0b',
+ 240: 'b3 39 4a 40 aa bf 75 cb a4 22 82 ef 25 a0 05 9f',
+ 256: '48 47 d8 1d a4 94 2d bc 24 9d ef c4 8c 92 2b 9f',
+ 496: '08 12 8c 46 9f 27 53 42 ad da 20 2b 2b 58 da 95',
+ 512: '97 0d ac ef 40 ad 98 72 3b ac 5d 69 55 b8 17 61',
+ 752: '3c b8 99 93 b0 7b 0c ed 93 de 13 d2 a1 10 13 ac',
+ 768: 'ef 2d 67 6f 15 45 c2 c1 3d c6 80 a0 2f 4a db fe',
+ 1008:'b6 05 95 51 4f 24 bc 9f e5 22 a6 ca d7 39 36 44',
+ 1024:'b5 15 a8 c5 01 17 54 f5 90 03 05 8b db 81 51 4e',
+ 1520:'3c 70 04 7e 8c bc 03 8e 3b 98 20 db 60 1d a4 95',
+ 1536:'11 75 da 6e e7 56 de 46 a5 3e 2b 07 56 60 b7 70',
+ 2032:'00 a5 42 bb a0 21 11 cc 2c 65 b3 8e bd ba 58 7e',
+ 2048:'58 65 fd bb 5b 48 06 41 04 e8 30 b3 80 f2 ae de',
+ 3056:'34 b2 1a d2 ad 44 e9 99 db 2d 7f 08 63 f0 d9 b6',
+ 3072:'84 a9 21 8f c3 6e 8a 5f 2c cf be ae 53 a2 7d 25',
+ 4080:'a2 22 1a 11 b8 33 cc b4 98 a5 95 40 f0 54 5f 4a',
+ 4096:'5b be b4 78 7d 59 e5 37 3f db ea 6c 6f 75 c2 9b'
+ }
+ ),
+ (
+ 'c109163908ebe51debb46227c6cc8b37641910833222772a',
+ {
+ 0: '54 b6 4e 6b 5a 20 b5 e2 ec 84 59 3d c7 98 9d a7',
+ 16: 'c1 35 ee e2 37 a8 54 65 ff 97 dc 03 92 4f 45 ce',
+ 240: 'cf cc 92 2f b4 a1 4a b4 5d 61 75 aa bb f2 d2 01',
+ 256: '83 7b 87 e2 a4 46 ad 0e f7 98 ac d0 2b 94 12 4f',
+ 496: '17 a6 db d6 64 92 6a 06 36 b3 f4 c3 7a 4f 46 94',
+ 512: '4a 5f 9f 26 ae ee d4 d4 a2 5f 63 2d 30 52 33 d9',
+ 752: '80 a3 d0 1e f0 0c 8e 9a 42 09 c1 7f 4e eb 35 8c',
+ 768: 'd1 5e 7d 5f fa aa bc 02 07 bf 20 0a 11 77 93 a2',
+ 1008:'34 96 82 bf 58 8e aa 52 d0 aa 15 60 34 6a ea fa',
+ 1024:'f5 85 4c db 76 c8 89 e3 ad 63 35 4e 5f 72 75 e3',
+ 1520:'53 2c 7c ec cb 39 df 32 36 31 84 05 a4 b1 27 9c',
+ 1536:'ba ef e6 d9 ce b6 51 84 22 60 e0 d1 e0 5e 3b 90',
+ 2032:'e8 2d 8c 6d b5 4e 3c 63 3f 58 1c 95 2b a0 42 07',
+ 2048:'4b 16 e5 0a bd 38 1b d7 09 00 a9 cd 9a 62 cb 23',
+ 3056:'36 82 ee 33 bd 14 8b d9 f5 86 56 cd 8f 30 d9 fb',
+ 3072:'1e 5a 0b 84 75 04 5d 9b 20 b2 62 86 24 ed fd 9e',
+ 4080:'63 ed d6 84 fb 82 62 82 fe 52 8f 9c 0e 92 37 bc',
+ 4096:'e4 dd 2e 98 d6 96 0f ae 0b 43 54 54 56 74 33 91'
+ }
+ ),
+ # Page 10
+ (
+ '1ada31d5cf688221c109163908ebe51debb46227c6cc8b37641910833222772a',
+ {
+ 0: 'dd 5b cb 00 18 e9 22 d4 94 75 9d 7c 39 5d 02 d3',
+ 16: 'c8 44 6f 8f 77 ab f7 37 68 53 53 eb 89 a1 c9 eb',
+ 240: 'af 3e 30 f9 c0 95 04 59 38 15 15 75 c3 fb 90 98',
+ 256: 'f8 cb 62 74 db 99 b8 0b 1d 20 12 a9 8e d4 8f 0e',
+ 496: '25 c3 00 5a 1c b8 5d e0 76 25 98 39 ab 71 98 ab',
+ 512: '9d cb c1 83 e8 cb 99 4b 72 7b 75 be 31 80 76 9c',
+ 752: 'a1 d3 07 8d fa 91 69 50 3e d9 d4 49 1d ee 4e b2',
+ 768: '85 14 a5 49 58 58 09 6f 59 6e 4b cd 66 b1 06 65',
+ 1008:'5f 40 d5 9e c1 b0 3b 33 73 8e fa 60 b2 25 5d 31',
+ 1024:'34 77 c7 f7 64 a4 1b ac ef f9 0b f1 4f 92 b7 cc',
+ 1520:'ac 4e 95 36 8d 99 b9 eb 78 b8 da 8f 81 ff a7 95',
+ 1536:'8c 3c 13 f8 c2 38 8b b7 3f 38 57 6e 65 b7 c4 46',
+ 2032:'13 c4 b9 c1 df b6 65 79 ed dd 8a 28 0b 9f 73 16',
+ 2048:'dd d2 78 20 55 01 26 69 8e fa ad c6 4b 64 f6 6e',
+ 3056:'f0 8f 2e 66 d2 8e d1 43 f3 a2 37 cf 9d e7 35 59',
+ 3072:'9e a3 6c 52 55 31 b8 80 ba 12 43 34 f5 7b 0b 70',
+ 4080:'d5 a3 9e 3d fc c5 02 80 ba c4 a6 b5 aa 0d ca 7d',
+ 4096:'37 0b 1c 1f e6 55 91 6d 97 fd 0d 47 ca 1d 72 b8'
+ }
+ )
+ ]
+
+ def test_keystream(self):
+ for tv in self.rfc6229_data:
+ key = unhexlify(b((tv[0])))
+ cipher = ARC4.new(key)
+ count = 0
+ for offset in range(0,4096+1,16):
+ ct = cipher.encrypt(b('\x00')*16)
+ expected = tv[1].get(offset)
+ if expected:
+ expected = unhexlify(b(expected.replace(" ",'')))
+ self.assertEqual(ct, expected)
+ count += 1
+ self.assertEqual(count, len(tv[1]))
+
+class Drop_Tests(unittest.TestCase):
+ key = b('\xAA')*16
+ data = b('\x00')*5000
+
+ def setUp(self):
+ self.cipher = ARC4.new(self.key)
+
+ def test_drop256_encrypt(self):
+ cipher_drop = ARC4.new(self.key, 256)
+ ct_drop = cipher_drop.encrypt(self.data[:16])
+ ct = self.cipher.encrypt(self.data)[256:256+16]
+ self.assertEqual(ct_drop, ct)
+
+ def test_drop256_decrypt(self):
+ cipher_drop = ARC4.new(self.key, 256)
+ pt_drop = cipher_drop.decrypt(self.data[:16])
+ pt = self.cipher.decrypt(self.data)[256:256+16]
+ self.assertEqual(pt_drop, pt)
+
+
+class KeyLength(unittest.TestCase):
+
+ def runTest(self):
+ self.assertRaises(ValueError, ARC4.new, bchr(0) * 4)
+ self.assertRaises(ValueError, ARC4.new, bchr(0) * 257)
+
+
+def get_tests(config={}):
+ from .common import make_stream_tests
+ tests = make_stream_tests(ARC4, "ARC4", test_data)
+ tests += list_test_cases(RFC6229_Tests)
+ tests += list_test_cases(Drop_Tests)
+ tests.append(KeyLength())
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
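The vectors and the RC4-drop behaviour exercised above can also be replayed by
hand. The following is only an illustrative sketch, not part of the patch, and
assumes the Crypto package shipped in lib/ is importable:

    from binascii import hexlify, unhexlify
    from Crypto.Cipher import ARC4

    # "Test vector 0": key and plaintext are both 0123456789abcdef
    key = unhexlify('0123456789abcdef')
    assert hexlify(ARC4.new(key).encrypt(unhexlify('0123456789abcdef'))) == b'75b7878099e0c596'

    # RC4-drop[256]: the second argument to ARC4.new() skips the first 256
    # keystream bytes, so it must agree with a plain cipher whose first
    # 256 output bytes are discarded (this is what Drop_Tests checks).
    plain = ARC4.new(b'\xAA' * 16).encrypt(b'\x00' * 272)[256:]
    drop = ARC4.new(b'\xAA' * 16, 256).encrypt(b'\x00' * 16)
    assert plain == drop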
diff --git a/lib/Crypto/SelfTest/Cipher/test_Blowfish.py b/lib/Crypto/SelfTest/Cipher/test_Blowfish.py
new file mode 100644
index 0000000..4ce3a41
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_Blowfish.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/test_Blowfish.py: Self-test for the Blowfish cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.Blowfish"""
+
+import unittest
+
+from Crypto.Util.py3compat import bchr
+
+from Crypto.Cipher import Blowfish
+
+# This is a list of (plaintext, ciphertext, key) tuples.
+test_data = [
+ # Test vectors from http://www.schneier.com/code/vectors.txt
+ ('0000000000000000', '4ef997456198dd78', '0000000000000000'),
+ ('ffffffffffffffff', '51866fd5b85ecb8a', 'ffffffffffffffff'),
+ ('1000000000000001', '7d856f9a613063f2', '3000000000000000'),
+ ('1111111111111111', '2466dd878b963c9d', '1111111111111111'),
+ ('1111111111111111', '61f9c3802281b096', '0123456789abcdef'),
+ ('0123456789abcdef', '7d0cc630afda1ec7', '1111111111111111'),
+ ('0000000000000000', '4ef997456198dd78', '0000000000000000'),
+ ('0123456789abcdef', '0aceab0fc6a0a28d', 'fedcba9876543210'),
+ ('01a1d6d039776742', '59c68245eb05282b', '7ca110454a1a6e57'),
+ ('5cd54ca83def57da', 'b1b8cc0b250f09a0', '0131d9619dc1376e'),
+ ('0248d43806f67172', '1730e5778bea1da4', '07a1133e4a0b2686'),
+ ('51454b582ddf440a', 'a25e7856cf2651eb', '3849674c2602319e'),
+ ('42fd443059577fa2', '353882b109ce8f1a', '04b915ba43feb5b6'),
+ ('059b5e0851cf143a', '48f4d0884c379918', '0113b970fd34f2ce'),
+ ('0756d8e0774761d2', '432193b78951fc98', '0170f175468fb5e6'),
+ ('762514b829bf486a', '13f04154d69d1ae5', '43297fad38e373fe'),
+ ('3bdd119049372802', '2eedda93ffd39c79', '07a7137045da2a16'),
+ ('26955f6835af609a', 'd887e0393c2da6e3', '04689104c2fd3b2f'),
+ ('164d5e404f275232', '5f99d04f5b163969', '37d06bb516cb7546'),
+ ('6b056e18759f5cca', '4a057a3b24d3977b', '1f08260d1ac2465e'),
+ ('004bd6ef09176062', '452031c1e4fada8e', '584023641aba6176'),
+ ('480d39006ee762f2', '7555ae39f59b87bd', '025816164629b007'),
+ ('437540c8698f3cfa', '53c55f9cb49fc019', '49793ebc79b3258f'),
+ ('072d43a077075292', '7a8e7bfa937e89a3', '4fb05e1515ab73a7'),
+ ('02fe55778117f12a', 'cf9c5d7a4986adb5', '49e95d6d4ca229bf'),
+ ('1d9d5c5018f728c2', 'd1abb290658bc778', '018310dc409b26d6'),
+ ('305532286d6f295a', '55cb3774d13ef201', '1c587f1c13924fef'),
+ ('0123456789abcdef', 'fa34ec4847b268b2', '0101010101010101'),
+ ('0123456789abcdef', 'a790795108ea3cae', '1f1f1f1f0e0e0e0e'),
+ ('0123456789abcdef', 'c39e072d9fac631d', 'e0fee0fef1fef1fe'),
+ ('ffffffffffffffff', '014933e0cdaff6e4', '0000000000000000'),
+ ('0000000000000000', 'f21e9a77b71c49bc', 'ffffffffffffffff'),
+ ('0000000000000000', '245946885754369a', '0123456789abcdef'),
+ ('ffffffffffffffff', '6b5c5a9c5d9e0a5a', 'fedcba9876543210'),
+ #('fedcba9876543210', 'f9ad597c49db005e', 'f0'),
+ #('fedcba9876543210', 'e91d21c1d961a6d6', 'f0e1'),
+ #('fedcba9876543210', 'e9c2b70a1bc65cf3', 'f0e1d2'),
+ ('fedcba9876543210', 'be1e639408640f05', 'f0e1d2c3'),
+ ('fedcba9876543210', 'b39e44481bdb1e6e', 'f0e1d2c3b4'),
+ ('fedcba9876543210', '9457aa83b1928c0d', 'f0e1d2c3b4a5'),
+ ('fedcba9876543210', '8bb77032f960629d', 'f0e1d2c3b4a596'),
+ ('fedcba9876543210', 'e87a244e2cc85e82', 'f0e1d2c3b4a59687'),
+ ('fedcba9876543210', '15750e7a4f4ec577', 'f0e1d2c3b4a5968778'),
+ ('fedcba9876543210', '122ba70b3ab64ae0', 'f0e1d2c3b4a596877869'),
+ ('fedcba9876543210', '3a833c9affc537f6', 'f0e1d2c3b4a5968778695a'),
+ ('fedcba9876543210', '9409da87a90f6bf2', 'f0e1d2c3b4a5968778695a4b'),
+ ('fedcba9876543210', '884f80625060b8b4', 'f0e1d2c3b4a5968778695a4b3c'),
+ ('fedcba9876543210', '1f85031c19e11968', 'f0e1d2c3b4a5968778695a4b3c2d'),
+ ('fedcba9876543210', '79d9373a714ca34f', 'f0e1d2c3b4a5968778695a4b3c2d1e'),
+ ('fedcba9876543210', '93142887ee3be15c',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f'),
+ ('fedcba9876543210', '03429e838ce2d14b',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f00'),
+ ('fedcba9876543210', 'a4299e27469ff67b',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f0011'),
+ ('fedcba9876543210', 'afd5aed1c1bc96a8',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f001122'),
+ ('fedcba9876543210', '10851c0e3858da9f',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f00112233'),
+ ('fedcba9876543210', 'e6f51ed79b9db21f',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f0011223344'),
+ ('fedcba9876543210', '64a6e14afd36b46f',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f001122334455'),
+ ('fedcba9876543210', '80c7d7d45a5479ad',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f00112233445566'),
+ ('fedcba9876543210', '05044b62fa52d080',
+ 'f0e1d2c3b4a5968778695a4b3c2d1e0f0011223344556677'),
+]
+
+
+class KeyLength(unittest.TestCase):
+
+ def runTest(self):
+ self.assertRaises(ValueError, Blowfish.new, bchr(0) * 3,
+ Blowfish.MODE_ECB)
+ self.assertRaises(ValueError, Blowfish.new, bchr(0) * 57,
+ Blowfish.MODE_ECB)
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
+
+ cipher = Blowfish.new(b'4'*16, Blowfish.MODE_ECB)
+
+ pt = b'5' * 16
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(16)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(16))
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
+
+ shorter_output = bytearray(7)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ from .common import make_block_tests
+ tests = make_block_tests(Blowfish, "Blowfish", test_data)
+ tests.append(KeyLength())
+ tests += [TestOutput()]
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
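Each (plaintext, ciphertext, key) triple above describes one 8-byte block
encrypted with Blowfish in ECB mode. An illustrative sketch, not part of the
patch, using the Crypto package shipped in lib/:

    from binascii import hexlify, unhexlify
    from Crypto.Cipher import Blowfish

    pt, ct, key = ('0123456789abcdef', '0aceab0fc6a0a28d', 'fedcba9876543210')
    cipher = Blowfish.new(unhexlify(key), Blowfish.MODE_ECB)
    assert hexlify(cipher.encrypt(unhexlify(pt))) == ct.encode()
    assert hexlify(cipher.decrypt(unhexlify(ct))) == pt.encode()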
diff --git a/lib/Crypto/SelfTest/Cipher/test_CAST.py b/lib/Crypto/SelfTest/Cipher/test_CAST.py
new file mode 100644
index 0000000..ff13bd4
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_CAST.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/test_CAST.py: Self-test for the CAST-128 (CAST5) cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.CAST"""
+
+import unittest
+
+from Crypto.Util.py3compat import bchr
+
+from Crypto.Cipher import CAST
+
+# This is a list of (plaintext, ciphertext, key, description) tuples.
+test_data = [
+ # Test vectors from RFC 2144, B.1
+ ('0123456789abcdef', '238b4fe5847e44b2',
+ '0123456712345678234567893456789a',
+ '128-bit key'),
+
+ ('0123456789abcdef', 'eb6a711a2c02271b',
+ '01234567123456782345',
+ '80-bit key'),
+
+ ('0123456789abcdef', '7ac816d16e9b302e',
+ '0123456712',
+ '40-bit key'),
+]
+
+
+class KeyLength(unittest.TestCase):
+
+ def runTest(self):
+ self.assertRaises(ValueError, CAST.new, bchr(0) * 4, CAST.MODE_ECB)
+ self.assertRaises(ValueError, CAST.new, bchr(0) * 17, CAST.MODE_ECB)
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
+
+ cipher = CAST.new(b'4'*16, CAST.MODE_ECB)
+
+ pt = b'5' * 16
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(16)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(16))
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
+
+ shorter_output = bytearray(7)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ from .common import make_block_tests
+
+ tests = make_block_tests(CAST, "CAST", test_data)
+ tests.append(KeyLength())
+ tests.append(TestOutput())
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
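The RFC 2144 B.1 vectors above correspond to single-block CAST-128 encryptions
in ECB mode. An illustrative sketch, not part of the patch:

    from binascii import hexlify, unhexlify
    from Crypto.Cipher import CAST

    key = unhexlify('0123456712345678234567893456789a')  # the 128-bit key vector
    cipher = CAST.new(key, CAST.MODE_ECB)
    assert hexlify(cipher.encrypt(unhexlify('0123456789abcdef'))) == b'238b4fe5847e44b2'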
diff --git a/lib/Crypto/SelfTest/Cipher/test_CBC.py b/lib/Crypto/SelfTest/Cipher/test_CBC.py
new file mode 100644
index 0000000..374fb5a
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_CBC.py
@@ -0,0 +1,556 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.py3compat import tobytes, is_string
+from Crypto.Cipher import AES, DES3, DES
+from Crypto.Hash import SHAKE128
+
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+class BlockChainingTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ key_192 = get_tag_random("key_192", 24)
+ iv_128 = get_tag_random("iv_128", 16)
+ iv_64 = get_tag_random("iv_64", 8)
+ data_128 = get_tag_random("data_128", 16)
+
+ def test_loopback_128(self):
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_loopback_64(self):
+ cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
+ pt = get_tag_random("plaintext", 8 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_iv(self):
+ # If not passed, the iv is created randomly
+ cipher = AES.new(self.key_128, self.aes_mode)
+ iv1 = cipher.iv
+ cipher = AES.new(self.key_128, self.aes_mode)
+ iv2 = cipher.iv
+ self.assertNotEqual(iv1, iv2)
+ self.assertEqual(len(iv1), 16)
+
+        # The IV can be passed positionally or via either the 'iv' or 'IV' keyword
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ ct = cipher.encrypt(self.data_128)
+
+ cipher = AES.new(self.key_128, self.aes_mode, iv=self.iv_128)
+ self.assertEqual(ct, cipher.encrypt(self.data_128))
+
+ cipher = AES.new(self.key_128, self.aes_mode, IV=self.iv_128)
+ self.assertEqual(ct, cipher.encrypt(self.data_128))
+
+ def test_iv_must_be_bytes(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
+ iv = u'test1234567890-*')
+
+ def test_only_one_iv(self):
+ # Only one IV/iv keyword allowed
+ self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
+ iv=self.iv_128, IV=self.iv_128)
+
+ def test_iv_with_matching_length(self):
+ self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
+ b"")
+ self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
+ self.iv_128[:15])
+ self.assertRaises(ValueError, AES.new, self.key_128, self.aes_mode,
+ self.iv_128 + b"0")
+
+ def test_block_size_128(self):
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ self.assertEqual(cipher.block_size, AES.block_size)
+
+ def test_block_size_64(self):
+ cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
+ self.assertEqual(cipher.block_size, DES3.block_size)
+
+ def test_unaligned_data_128(self):
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ for wrong_length in range(1,16):
+ self.assertRaises(ValueError, cipher.encrypt, b"5" * wrong_length)
+
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ for wrong_length in range(1,16):
+ self.assertRaises(ValueError, cipher.decrypt, b"5" * wrong_length)
+
+ def test_unaligned_data_64(self):
+ cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
+ for wrong_length in range(1,8):
+ self.assertRaises(ValueError, cipher.encrypt, b"5" * wrong_length)
+
+ cipher = DES3.new(self.key_192, self.des3_mode, self.iv_64)
+ for wrong_length in range(1,8):
+ self.assertRaises(ValueError, cipher.decrypt, b"5" * wrong_length)
+
+ def test_IV_iv_attributes(self):
+ data = get_tag_random("data", 16 * 100)
+ for func in "encrypt", "decrypt":
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ getattr(cipher, func)(data)
+ self.assertEqual(cipher.iv, self.iv_128)
+ self.assertEqual(cipher.IV, self.iv_128)
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
+ self.iv_128, 7)
+ self.assertRaises(TypeError, AES.new, self.key_128, self.aes_mode,
+ iv=self.iv_128, unknown=7)
+ # But some are only known by the base cipher (e.g. use_aesni consumed by the AES module)
+ AES.new(self.key_128, self.aes_mode, iv=self.iv_128, use_aesni=False)
+
+ def test_null_encryption_decryption(self):
+ for func in "encrypt", "decrypt":
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ result = getattr(cipher, func)(b"")
+ self.assertEqual(result, b"")
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ cipher.encrypt(b"")
+ self.assertRaises(TypeError, cipher.decrypt, b"")
+
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ cipher.decrypt(b"")
+ self.assertRaises(TypeError, cipher.encrypt, b"")
+
+ def test_data_must_be_bytes(self):
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')
+
+ cipher = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
+
+ def test_bytearray(self):
+ data = b"1" * 128
+ data_ba = bytearray(data)
+
+ # Encrypt
+ key_ba = bytearray(self.key_128)
+ iv_ba = bytearray(self.iv_128)
+
+ cipher1 = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ ref1 = cipher1.encrypt(data)
+
+ cipher2 = AES.new(key_ba, self.aes_mode, iv_ba)
+ key_ba[:3] = b'\xFF\xFF\xFF'
+ iv_ba[:3] = b'\xFF\xFF\xFF'
+ ref2 = cipher2.encrypt(data_ba)
+
+ self.assertEqual(ref1, ref2)
+ self.assertEqual(cipher1.iv, cipher2.iv)
+
+ # Decrypt
+ key_ba = bytearray(self.key_128)
+ iv_ba = bytearray(self.iv_128)
+
+ cipher3 = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ ref3 = cipher3.decrypt(data)
+
+ cipher4 = AES.new(key_ba, self.aes_mode, iv_ba)
+ key_ba[:3] = b'\xFF\xFF\xFF'
+ iv_ba[:3] = b'\xFF\xFF\xFF'
+ ref4 = cipher4.decrypt(data_ba)
+
+ self.assertEqual(ref3, ref4)
+
+ def test_memoryview(self):
+ data = b"1" * 128
+ data_mv = memoryview(bytearray(data))
+
+ # Encrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ iv_mv = memoryview(bytearray(self.iv_128))
+
+ cipher1 = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ ref1 = cipher1.encrypt(data)
+
+ cipher2 = AES.new(key_mv, self.aes_mode, iv_mv)
+ key_mv[:3] = b'\xFF\xFF\xFF'
+ iv_mv[:3] = b'\xFF\xFF\xFF'
+ ref2 = cipher2.encrypt(data_mv)
+
+ self.assertEqual(ref1, ref2)
+ self.assertEqual(cipher1.iv, cipher2.iv)
+
+ # Decrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ iv_mv = memoryview(bytearray(self.iv_128))
+
+ cipher3 = AES.new(self.key_128, self.aes_mode, self.iv_128)
+ ref3 = cipher3.decrypt(data)
+
+ cipher4 = AES.new(key_mv, self.aes_mode, iv_mv)
+ key_mv[:3] = b'\xFF\xFF\xFF'
+ iv_mv[:3] = b'\xFF\xFF\xFF'
+ ref4 = cipher4.decrypt(data_mv)
+
+ self.assertEqual(ref3, ref4)
+
+ def test_output_param(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(128)
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+
+ def test_output_param_same_buffer(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ ct = cipher.encrypt(pt)
+
+ pt_ba = bytearray(pt)
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ res = cipher.encrypt(pt_ba, output=pt_ba)
+ self.assertEqual(ct, pt_ba)
+ self.assertEqual(res, None)
+
+ ct_ba = bytearray(ct)
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ res = cipher.decrypt(ct_ba, output=ct_ba)
+ self.assertEqual(pt, ct_ba)
+ self.assertEqual(res, None)
+
+
+ def test_output_param_memoryview(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ ct = cipher.encrypt(pt)
+
+ output = memoryview(bytearray(128))
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ def test_output_param_neg(self):
+ LEN_PT = 128
+
+ pt = b'5' * LEN_PT
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0' * LEN_PT)
+
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0' * LEN_PT)
+
+ shorter_output = bytearray(LEN_PT - 1)
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ cipher = AES.new(b'4'*16, self.aes_mode, iv=self.iv_128)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+class CbcTests(BlockChainingTests):
+ aes_mode = AES.MODE_CBC
+ des3_mode = DES3.MODE_CBC
+
+
+class NistBlockChainingVectors(unittest.TestCase):
+
+ def _do_kat_aes_test(self, file_name):
+
+ test_vectors = load_test_vectors(("Cipher", "AES"),
+ file_name,
+ "AES CBC KAT",
+ { "count" : lambda x: int(x) } )
+ if test_vectors is None:
+ return
+
+ direction = None
+ for tv in test_vectors:
+
+ # The test vector file contains some directive lines
+ if is_string(tv):
+ direction = tv
+ continue
+
+ self.description = tv.desc
+
+ cipher = AES.new(tv.key, self.aes_mode, tv.iv)
+ if direction == "[ENCRYPT]":
+ self.assertEqual(cipher.encrypt(tv.plaintext), tv.ciphertext)
+ elif direction == "[DECRYPT]":
+ self.assertEqual(cipher.decrypt(tv.ciphertext), tv.plaintext)
+ else:
+ assert False
+
+ # See Section 6.4.2 in AESAVS
+ def _do_mct_aes_test(self, file_name):
+
+ test_vectors = load_test_vectors(("Cipher", "AES"),
+ file_name,
+ "AES CBC Montecarlo",
+ { "count" : lambda x: int(x) } )
+ if test_vectors is None:
+ return
+
+ direction = None
+ for tv in test_vectors:
+
+ # The test vector file contains some directive lines
+ if is_string(tv):
+ direction = tv
+ continue
+
+ self.description = tv.desc
+ cipher = AES.new(tv.key, self.aes_mode, tv.iv)
+
+ if direction == '[ENCRYPT]':
+ cts = [ tv.iv ]
+ for count in range(1000):
+ cts.append(cipher.encrypt(tv.plaintext))
+ tv.plaintext = cts[-2]
+ self.assertEqual(cts[-1], tv.ciphertext)
+ elif direction == '[DECRYPT]':
+ pts = [ tv.iv]
+ for count in range(1000):
+ pts.append(cipher.decrypt(tv.ciphertext))
+ tv.ciphertext = pts[-2]
+ self.assertEqual(pts[-1], tv.plaintext)
+ else:
+ assert False
+
+ def _do_tdes_test(self, file_name):
+
+ test_vectors = load_test_vectors(("Cipher", "TDES"),
+ file_name,
+ "TDES CBC KAT",
+ { "count" : lambda x: int(x) } )
+ if test_vectors is None:
+ return
+
+ direction = None
+ for tv in test_vectors:
+
+ # The test vector file contains some directive lines
+ if is_string(tv):
+ direction = tv
+ continue
+
+ self.description = tv.desc
+ if hasattr(tv, "keys"):
+ cipher = DES.new(tv.keys, self.des_mode, tv.iv)
+ else:
+ if tv.key1 != tv.key3:
+ key = tv.key1 + tv.key2 + tv.key3 # Option 3
+ else:
+ key = tv.key1 + tv.key2 # Option 2
+ cipher = DES3.new(key, self.des3_mode, tv.iv)
+
+ if direction == "[ENCRYPT]":
+ self.assertEqual(cipher.encrypt(tv.plaintext), tv.ciphertext)
+ elif direction == "[DECRYPT]":
+ self.assertEqual(cipher.decrypt(tv.ciphertext), tv.plaintext)
+ else:
+ assert False
+
+
+class NistCbcVectors(NistBlockChainingVectors):
+ aes_mode = AES.MODE_CBC
+ des_mode = DES.MODE_CBC
+ des3_mode = DES3.MODE_CBC
+
+
+# Create one test method per file
+nist_aes_kat_mmt_files = (
+ # KAT
+ "CBCGFSbox128.rsp",
+ "CBCGFSbox192.rsp",
+ "CBCGFSbox256.rsp",
+ "CBCKeySbox128.rsp",
+ "CBCKeySbox192.rsp",
+ "CBCKeySbox256.rsp",
+ "CBCVarKey128.rsp",
+ "CBCVarKey192.rsp",
+ "CBCVarKey256.rsp",
+ "CBCVarTxt128.rsp",
+ "CBCVarTxt192.rsp",
+ "CBCVarTxt256.rsp",
+ # MMT
+ "CBCMMT128.rsp",
+ "CBCMMT192.rsp",
+ "CBCMMT256.rsp",
+ )
+nist_aes_mct_files = (
+ "CBCMCT128.rsp",
+ "CBCMCT192.rsp",
+ "CBCMCT256.rsp",
+ )
+
+for file_name in nist_aes_kat_mmt_files:
+ def new_func(self, file_name=file_name):
+ self._do_kat_aes_test(file_name)
+ setattr(NistCbcVectors, "test_AES_" + file_name, new_func)
+
+for file_name in nist_aes_mct_files:
+ def new_func(self, file_name=file_name):
+ self._do_mct_aes_test(file_name)
+ setattr(NistCbcVectors, "test_AES_" + file_name, new_func)
+del file_name, new_func
+
+nist_tdes_files = (
+ "TCBCMMT2.rsp", # 2TDES
+ "TCBCMMT3.rsp", # 3TDES
+ "TCBCinvperm.rsp", # Single DES
+ "TCBCpermop.rsp",
+ "TCBCsubtab.rsp",
+ "TCBCvarkey.rsp",
+ "TCBCvartext.rsp",
+ )
+
+for file_name in nist_tdes_files:
+ def new_func(self, file_name=file_name):
+ self._do_tdes_test(file_name)
+ setattr(NistCbcVectors, "test_TDES_" + file_name, new_func)
+
+# END OF NIST CBC TEST VECTORS
+
+
+class SP800TestVectors(unittest.TestCase):
+ """Class exercising the CBC test vectors found in Section F.2
+    of NIST SP 800-38A"""
+
+ def test_aes_128(self):
+ key = '2b7e151628aed2a6abf7158809cf4f3c'
+ iv = '000102030405060708090a0b0c0d0e0f'
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = '7649abac8119b246cee98e9b12e9197d' +\
+ '5086cb9b507219ee95db113a917678b2' +\
+ '73bed6b8e3c1743b7116e69e22229516' +\
+ '3ff1caa1681fac09120eca307586e1a7'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CBC, iv)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CBC, iv)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_192(self):
+ key = '8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b'
+ iv = '000102030405060708090a0b0c0d0e0f'
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = '4f021db243bc633d7178183a9fa071e8' +\
+ 'b4d9ada9ad7dedf4e5e738763f69145a' +\
+ '571b242012fb7ae07fa9baac3df102e0' +\
+ '08b0e27988598881d920a9e64f5615cd'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CBC, iv)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CBC, iv)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_256(self):
+ key = '603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4'
+ iv = '000102030405060708090a0b0c0d0e0f'
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = 'f58c4c04d6e5f1ba779eabfb5f7bfbd6' +\
+ '9cfc4e967edb808d679f777bc6702c7d' +\
+ '39f23369a9d9bacfa530e26304231461' +\
+ 'b2eb05e2c39be9fcda6c19078c6a9d1b'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CBC, iv)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CBC, iv)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(CbcTests)
+ if config.get('slow_tests'):
+ tests += list_test_cases(NistCbcVectors)
+ tests += list_test_cases(SP800TestVectors)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
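The SP 800-38A Section F.2 vectors exercised by SP800TestVectors reduce to
plain CBC calls. An illustrative sketch, not part of the patch, checking the
first AES-128-CBC block by hand:

    from binascii import unhexlify
    from Crypto.Cipher import AES

    key = unhexlify('2b7e151628aed2a6abf7158809cf4f3c')
    iv = unhexlify('000102030405060708090a0b0c0d0e0f')
    pt = unhexlify('6bc1bee22e409f96e93d7e117393172a')
    ct = unhexlify('7649abac8119b246cee98e9b12e9197d')

    # A CBC cipher object cannot mix encrypt() and decrypt() calls, so a
    # fresh object is created for each direction.
    assert AES.new(key, AES.MODE_CBC, iv).encrypt(pt) == ct
    assert AES.new(key, AES.MODE_CBC, iv).decrypt(ct) == pt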
diff --git a/lib/Crypto/SelfTest/Cipher/test_CCM.py b/lib/Crypto/SelfTest/Cipher/test_CCM.py
new file mode 100644
index 0000000..e8ebc0b
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_CCM.py
@@ -0,0 +1,936 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+from Crypto.Util.py3compat import tobytes, bchr
+from Crypto.Cipher import AES
+from Crypto.Hash import SHAKE128
+
+from Crypto.Util.strxor import strxor
+
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+class CcmTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data = get_tag_random("data", 128)
+
+ def test_loopback_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_nonce(self):
+ # If not passed, the nonce is created randomly
+ cipher = AES.new(self.key_128, AES.MODE_CCM)
+ nonce1 = cipher.nonce
+ cipher = AES.new(self.key_128, AES.MODE_CCM)
+ nonce2 = cipher.nonce
+ self.assertEqual(len(nonce1), 11)
+ self.assertNotEqual(nonce1, nonce2)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, self.nonce_96)
+ ct = cipher.encrypt(self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertEqual(ct, cipher.encrypt(self.data))
+
+ def test_nonce_must_be_bytes(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CCM,
+ nonce=u'test12345678')
+
+ def test_nonce_length(self):
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CCM,
+ nonce=b"")
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CCM,
+ nonce=bchr(1) * 6)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CCM,
+ nonce=bchr(1) * 14)
+ for x in range(7, 13 + 1):
+ AES.new(self.key_128, AES.MODE_CCM, nonce=bchr(1) * x)
+
+ def test_block_size(self):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertEqual(cipher.block_size, AES.block_size)
+
+ def test_nonce_attribute(self):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertEqual(cipher.nonce, self.nonce_96)
+
+        # By default, an 11-byte nonce is randomly generated
+ nonce1 = AES.new(self.key_128, AES.MODE_CCM).nonce
+ nonce2 = AES.new(self.key_128, AES.MODE_CCM).nonce
+ self.assertEqual(len(nonce1), 11)
+ self.assertNotEqual(nonce1, nonce2)
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CCM,
+ self.nonce_96, 7)
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96, unknown=7)
+
+ # But some are only known by the base cipher
+ # (e.g. use_aesni consumed by the AES module)
+ AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ use_aesni=False)
+
+ def test_null_encryption_decryption(self):
+ for func in "encrypt", "decrypt":
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ result = getattr(cipher, func)(b"")
+ self.assertEqual(result, b"")
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.encrypt(b"")
+ self.assertRaises(TypeError, cipher.decrypt, b"")
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.decrypt(b"")
+ self.assertRaises(TypeError, cipher.encrypt, b"")
+
+ def test_data_must_be_bytes(self):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
+
+ def test_mac_len(self):
+ # Invalid MAC length
+ for mac_len in range(3, 17 + 1, 2):
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96, mac_len=mac_len)
+
+ # Valid MAC length
+ for mac_len in range(4, 16 + 1, 2):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ mac_len=mac_len)
+ _, mac = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(len(mac), mac_len)
+
+ # Default MAC length
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ _, mac = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(len(mac), 16)
+
+ def test_invalid_mac(self):
+ from Crypto.Util.strxor import strxor_c
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ invalid_mac = strxor_c(mac, 0x01)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct,
+ invalid_mac)
+
+ def test_hex_mac(self):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ mac_hex = cipher.hexdigest()
+ self.assertEqual(cipher.digest(), unhexlify(mac_hex))
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.hexverify(mac_hex)
+
+ def test_longer_assoc_data_than_declared(self):
+ # More than zero
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ assoc_len=0)
+ self.assertRaises(ValueError, cipher.update, b"1")
+
+ # Too large
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ assoc_len=15)
+ self.assertRaises(ValueError, cipher.update, self.data)
+
+ def test_shorter_assoc_data_than_expected(self):
+ DATA_LEN = len(self.data)
+
+ # With plaintext
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ assoc_len=DATA_LEN + 1)
+ cipher.update(self.data)
+ self.assertRaises(ValueError, cipher.encrypt, self.data)
+
+ # With empty plaintext
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ assoc_len=DATA_LEN + 1)
+ cipher.update(self.data)
+ self.assertRaises(ValueError, cipher.digest)
+
+ # With ciphertext
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ assoc_len=DATA_LEN + 1)
+ cipher.update(self.data)
+ self.assertRaises(ValueError, cipher.decrypt, self.data)
+
+ # With empty ciphertext
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ mac = cipher.digest()
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ assoc_len=DATA_LEN + 1)
+ cipher.update(self.data)
+ self.assertRaises(ValueError, cipher.verify, mac)
+
+ def test_shorter_and_longer_plaintext_than_declared(self):
+ DATA_LEN = len(self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ msg_len=DATA_LEN + 1)
+ cipher.encrypt(self.data)
+ self.assertRaises(ValueError, cipher.digest)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ msg_len=DATA_LEN - 1)
+ self.assertRaises(ValueError, cipher.encrypt, self.data)
+
+ def test_shorter_ciphertext_than_declared(self):
+ DATA_LEN = len(self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ msg_len=DATA_LEN + 1)
+ cipher.decrypt(ct)
+ self.assertRaises(ValueError, cipher.verify, mac)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ msg_len=DATA_LEN - 1)
+ self.assertRaises(ValueError, cipher.decrypt, ct)
+
+ def test_message_chunks(self):
+ # Validate that both associated data and plaintext/ciphertext
+ # can be broken up in chunks of arbitrary length
+
+ auth_data = get_tag_random("authenticated data", 127)
+ plaintext = get_tag_random("plaintext", 127)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.update(auth_data)
+ ciphertext, ref_mac = cipher.encrypt_and_digest(plaintext)
+
+ def break_up(data, chunk_length):
+ return [data[i:i+chunk_length] for i in range(0, len(data),
+ chunk_length)]
+
+        # Decryption (ciphertext and associated data fed in chunks)
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ msg_len=127, assoc_len=127)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ pt2 = b""
+ for chunk in break_up(ciphertext, chunk_length):
+ pt2 += cipher.decrypt(chunk)
+ self.assertEqual(plaintext, pt2)
+ cipher.verify(ref_mac)
+
+        # Encryption (plaintext and associated data fed in chunks)
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96,
+ msg_len=127, assoc_len=127)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ ct2 = b""
+ for chunk in break_up(plaintext, chunk_length):
+ ct2 += cipher.encrypt(chunk)
+ self.assertEqual(ciphertext, ct2)
+ self.assertEqual(cipher.digest(), ref_mac)
+
+ def test_bytearray(self):
+
+ # Encrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data)
+ data_ba = bytearray(self.data)
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_CCM,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct = cipher1.encrypt(self.data)
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_ba,
+ AES.MODE_CCM,
+ nonce=nonce_ba)
+ key_ba[:3] = b"\xFF\xFF\xFF"
+ nonce_ba[:3] = b"\xFF\xFF\xFF"
+ cipher2.update(header_ba)
+ header_ba[:3] = b"\xFF\xFF\xFF"
+ ct_test = cipher2.encrypt(data_ba)
+ data_ba[:3] = b"\xFF\xFF\xFF"
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data)
+ del data_ba
+
+ cipher4 = AES.new(key_ba,
+ AES.MODE_CCM,
+ nonce=nonce_ba)
+ key_ba[:3] = b"\xFF\xFF\xFF"
+ nonce_ba[:3] = b"\xFF\xFF\xFF"
+ cipher4.update(header_ba)
+ header_ba[:3] = b"\xFF\xFF\xFF"
+ pt_test = cipher4.decrypt_and_verify(bytearray(ct_test), bytearray(tag_test))
+
+ self.assertEqual(self.data, pt_test)
+
+ def test_memoryview(self):
+
+ # Encrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data))
+ data_mv = memoryview(bytearray(self.data))
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_CCM,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct = cipher1.encrypt(self.data)
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_mv,
+ AES.MODE_CCM,
+ nonce=nonce_mv)
+ key_mv[:3] = b"\xFF\xFF\xFF"
+ nonce_mv[:3] = b"\xFF\xFF\xFF"
+ cipher2.update(header_mv)
+ header_mv[:3] = b"\xFF\xFF\xFF"
+ ct_test = cipher2.encrypt(data_mv)
+ data_mv[:3] = b"\xFF\xFF\xFF"
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data))
+ del data_mv
+
+ cipher4 = AES.new(key_mv,
+ AES.MODE_CCM,
+ nonce=nonce_mv)
+ key_mv[:3] = b"\xFF\xFF\xFF"
+ nonce_mv[:3] = b"\xFF\xFF\xFF"
+ cipher4.update(header_mv)
+ header_mv[:3] = b"\xFF\xFF\xFF"
+ pt_test = cipher4.decrypt_and_verify(memoryview(ct_test), memoryview(tag_test))
+
+ self.assertEqual(self.data, pt_test)
+
+ def test_output_param(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+ tag = cipher.digest()
+
+ output = bytearray(128)
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ res, tag_out = cipher.encrypt_and_digest(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+ self.assertEqual(tag, tag_out)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ res = cipher.decrypt_and_verify(ct, tag, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ def test_output_param_memoryview(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+
+ output = memoryview(bytearray(128))
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ def test_output_param_neg(self):
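+        # output= is expected to be a writable buffer (e.g. a bytearray) of
+        # exactly the right length; bytes objects and shorter buffers are
+        # rejected below.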
+
+ pt = b'5' * 16
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
+
+ shorter_output = bytearray(15)
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+class CcmFSMTests(unittest.TestCase):
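+    """Exercise the CCM state machine: from INIT the cipher may go through
+    UPDATE (associated data), then ENCRYPT or DECRYPT, and finally DIGEST or
+    VERIFY; orderings that CCM does not allow must raise TypeError."""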
+
+ key_128 = get_tag_random("key_128", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data = get_tag_random("data", 16)
+
+ def test_valid_init_encrypt_decrypt_digest_verify(self):
+ # No authenticated data, fixed plaintext
+ for assoc_len in (None, 0):
+ for msg_len in (None, len(self.data)):
+ # Verify path INIT->ENCRYPT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ assoc_len=assoc_len,
+ msg_len=msg_len)
+ ct = cipher.encrypt(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->DECRYPT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ assoc_len=assoc_len,
+ msg_len=msg_len)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_update_digest_verify(self):
+ # No plaintext, fixed authenticated data
+ for assoc_len in (None, len(self.data)):
+ for msg_len in (None, 0):
+ # Verify path INIT->UPDATE->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ assoc_len=assoc_len,
+ msg_len=msg_len)
+ cipher.update(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ assoc_len=assoc_len,
+ msg_len=msg_len)
+ cipher.update(self.data)
+ cipher.verify(mac)
+
+ def test_valid_full_path(self):
+ # Fixed authenticated data, fixed plaintext
+ for assoc_len in (None, len(self.data)):
+ for msg_len in (None, len(self.data)):
+ # Verify path INIT->UPDATE->ENCRYPT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ assoc_len=assoc_len,
+ msg_len=msg_len)
+ cipher.update(self.data)
+ ct = cipher.encrypt(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->DECRYPT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ assoc_len=assoc_len,
+ msg_len=msg_len)
+ cipher.update(self.data)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_digest(self):
+ # Verify path INIT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.digest()
+
+ def test_valid_init_verify(self):
+ # Verify path INIT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ mac = cipher.digest()
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.verify(mac)
+
+ def test_valid_multiple_encrypt_or_decrypt(self):
+ # Only possible if msg_len is declared in advance
+ for method_name in "encrypt", "decrypt":
+ for auth_data in (None, b"333", self.data,
+ self.data + b"3"):
+ if auth_data is None:
+ assoc_len = None
+ else:
+ assoc_len = len(auth_data)
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ msg_len=64,
+ assoc_len=assoc_len)
+ if auth_data is not None:
+ cipher.update(auth_data)
+ method = getattr(cipher, method_name)
+ method(self.data)
+ method(self.data)
+ method(self.data)
+ method(self.data)
+
+ def test_valid_multiple_digest_or_verify(self):
+ # Multiple calls to digest
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ first_mac = cipher.digest()
+ for x in range(4):
+ self.assertEqual(first_mac, cipher.digest())
+
+ # Multiple calls to verify
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ for x in range(5):
+ cipher.verify(first_mac)
+
+ def test_valid_encrypt_and_digest_decrypt_and_verify(self):
+ # encrypt_and_digest
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ # decrypt_and_verify
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ pt = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(self.data, pt)
+
+ def test_invalid_multiple_encrypt_decrypt_without_msg_len(self):
+ # Once per method, with or without assoc. data
+ for method_name in "encrypt", "decrypt":
+ for assoc_data_present in (True, False):
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96)
+ if assoc_data_present:
+ cipher.update(self.data)
+ method = getattr(cipher, method_name)
+ method(self.data)
+ self.assertRaises(TypeError, method, self.data)
+
+ def test_invalid_mixing_encrypt_decrypt(self):
+ # Once per method, with or without assoc. data
+ for method1_name, method2_name in (("encrypt", "decrypt"),
+ ("decrypt", "encrypt")):
+ for assoc_data_present in (True, False):
+ cipher = AES.new(self.key_128, AES.MODE_CCM,
+ nonce=self.nonce_96,
+ msg_len=32)
+ if assoc_data_present:
+ cipher.update(self.data)
+ getattr(cipher, method1_name)(self.data)
+ self.assertRaises(TypeError, getattr(cipher, method2_name),
+ self.data)
+
+ def test_invalid_encrypt_or_update_after_digest(self):
+ for method_name in "encrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.encrypt(self.data)
+ cipher.digest()
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.encrypt_and_digest(self.data)
+
+ def test_invalid_decrypt_or_update_after_verify(self):
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data)
+ mac = cipher.digest()
+
+ for method_name in "decrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+            cipher = AES.new(self.key_128, AES.MODE_CCM, nonce=self.nonce_96)
+            cipher.decrypt_and_verify(ct, mac)
+            self.assertRaises(TypeError, getattr(cipher, method_name),
+                              self.data)
+
+
+class TestVectors(unittest.TestCase):
+ """Class exercising the CCM test vectors found in Appendix C
+ of NIST SP 800-38C and in RFC 3610"""
+
+ # List of test vectors, each made up of:
+ # - authenticated data
+ # - plaintext
+ # - ciphertext
+ # - MAC
+ # - AES key
+ # - nonce
+ test_vectors_hex = [
+        # NIST SP 800-38C
+ ( '0001020304050607',
+ '20212223',
+ '7162015b',
+ '4dac255d',
+ '404142434445464748494a4b4c4d4e4f',
+ '10111213141516'),
+ ( '000102030405060708090a0b0c0d0e0f',
+ '202122232425262728292a2b2c2d2e2f',
+ 'd2a1f0e051ea5f62081a7792073d593d',
+ '1fc64fbfaccd',
+ '404142434445464748494a4b4c4d4e4f',
+ '1011121314151617'),
+ ( '000102030405060708090a0b0c0d0e0f10111213',
+ '202122232425262728292a2b2c2d2e2f3031323334353637',
+ 'e3b201a9f5b71a7a9b1ceaeccd97e70b6176aad9a4428aa5',
+ '484392fbc1b09951',
+ '404142434445464748494a4b4c4d4e4f',
+ '101112131415161718191a1b'),
+ ( (''.join(["%02X" % (x*16+y) for x in range(0,16) for y in range(0,16)]))*256,
+ '202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f',
+ '69915dad1e84c6376a68c2967e4dab615ae0fd1faec44cc484828529463ccf72',
+ 'b4ac6bec93e8598e7f0dadbcea5b',
+ '404142434445464748494a4b4c4d4e4f',
+ '101112131415161718191a1b1c'),
+ # RFC3610
+ ( '0001020304050607',
+ '08090a0b0c0d0e0f101112131415161718191a1b1c1d1e',
+ '588c979a61c663d2f066d0c2c0f989806d5f6b61dac384',
+ '17e8d12cfdf926e0',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '00000003020100a0a1a2a3a4a5'),
+ (
+ '0001020304050607',
+ '08090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ '72c91a36e135f8cf291ca894085c87e3cc15c439c9e43a3b',
+ 'a091d56e10400916',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '00000004030201a0a1a2a3a4a5'),
+ ( '0001020304050607',
+ '08090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20',
+ '51b1e5f44a197d1da46b0f8e2d282ae871e838bb64da859657',
+ '4adaa76fbd9fb0c5',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '00000005040302A0A1A2A3A4A5'),
+ ( '000102030405060708090a0b',
+ '0c0d0e0f101112131415161718191a1b1c1d1e',
+ 'a28c6865939a9a79faaa5c4c2a9d4a91cdac8c',
+ '96c861b9c9e61ef1',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '00000006050403a0a1a2a3a4a5'),
+ ( '000102030405060708090a0b',
+ '0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ 'dcf1fb7b5d9e23fb9d4e131253658ad86ebdca3e',
+ '51e83f077d9c2d93',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '00000007060504a0a1a2a3a4a5'),
+ ( '000102030405060708090a0b',
+ '0c0d0e0f101112131415161718191a1b1c1d1e1f20',
+ '6fc1b011f006568b5171a42d953d469b2570a4bd87',
+ '405a0443ac91cb94',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '00000008070605a0a1a2a3a4a5'),
+ ( '0001020304050607',
+ '08090a0b0c0d0e0f101112131415161718191a1b1c1d1e',
+ '0135d1b2c95f41d5d1d4fec185d166b8094e999dfed96c',
+ '048c56602c97acbb7490',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '00000009080706a0a1a2a3a4a5'),
+ ( '0001020304050607',
+ '08090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ '7b75399ac0831dd2f0bbd75879a2fd8f6cae6b6cd9b7db24',
+ 'c17b4433f434963f34b4',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '0000000a090807a0a1a2a3a4a5'),
+ ( '0001020304050607',
+ '08090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20',
+ '82531a60cc24945a4b8279181ab5c84df21ce7f9b73f42e197',
+ 'ea9c07e56b5eb17e5f4e',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '0000000b0a0908a0a1a2a3a4a5'),
+ ( '000102030405060708090a0b',
+ '0c0d0e0f101112131415161718191a1b1c1d1e',
+ '07342594157785152b074098330abb141b947b',
+ '566aa9406b4d999988dd',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '0000000c0b0a09a0a1a2a3a4a5'),
+ ( '000102030405060708090a0b',
+ '0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ '676bb20380b0e301e8ab79590a396da78b834934',
+ 'f53aa2e9107a8b6c022c',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '0000000d0c0b0aa0a1a2a3a4a5'),
+ ( '000102030405060708090a0b',
+ '0c0d0e0f101112131415161718191a1b1c1d1e1f20',
+ 'c0ffa0d6f05bdb67f24d43a4338d2aa4bed7b20e43',
+ 'cd1aa31662e7ad65d6db',
+ 'c0c1c2c3c4c5c6c7c8c9cacbcccdcecf',
+ '0000000e0d0c0ba0a1a2a3a4a5'),
+ ( '0be1a88bace018b1',
+ '08e8cf97d820ea258460e96ad9cf5289054d895ceac47c',
+ '4cb97f86a2a4689a877947ab8091ef5386a6ffbdd080f8',
+ 'e78cf7cb0cddd7b3',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '00412b4ea9cdbe3c9696766cfa'),
+ ( '63018f76dc8a1bcb',
+ '9020ea6f91bdd85afa0039ba4baff9bfb79c7028949cd0ec',
+ '4ccb1e7ca981befaa0726c55d378061298c85c92814abc33',
+ 'c52ee81d7d77c08a',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '0033568ef7b2633c9696766cfa'),
+ ( 'aa6cfa36cae86b40',
+ 'b916e0eacc1c00d7dcec68ec0b3bbb1a02de8a2d1aa346132e',
+ 'b1d23a2220ddc0ac900d9aa03c61fcf4a559a4417767089708',
+ 'a776796edb723506',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '00103fe41336713c9696766cfa'),
+ ( 'd0d0735c531e1becf049c244',
+ '12daac5630efa5396f770ce1a66b21f7b2101c',
+ '14d253c3967b70609b7cbb7c49916028324526',
+ '9a6f49975bcadeaf',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '00764c63b8058e3c9696766cfa'),
+ ( '77b60f011c03e1525899bcae',
+ 'e88b6a46c78d63e52eb8c546efb5de6f75e9cc0d',
+ '5545ff1a085ee2efbf52b2e04bee1e2336c73e3f',
+ '762c0c7744fe7e3c',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '00f8b678094e3b3c9696766cfa'),
+ ( 'cd9044d2b71fdb8120ea60c0',
+ '6435acbafb11a82e2f071d7ca4a5ebd93a803ba87f',
+ '009769ecabdf48625594c59251e6035722675e04c8',
+ '47099e5ae0704551',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '00d560912d3f703c9696766cfa'),
+ ( 'd85bc7e69f944fb8',
+ '8a19b950bcf71a018e5e6701c91787659809d67dbedd18',
+ 'bc218daa947427b6db386a99ac1aef23ade0b52939cb6a',
+ '637cf9bec2408897c6ba',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '0042fff8f1951c3c9696766cfa'),
+ ( '74a0ebc9069f5b37',
+ '1761433c37c5a35fc1f39f406302eb907c6163be38c98437',
+ '5810e6fd25874022e80361a478e3e9cf484ab04f447efff6',
+ 'f0a477cc2fc9bf548944',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '00920f40e56cdc3c9696766cfa'),
+ ( '44a3aa3aae6475ca',
+ 'a434a8e58500c6e41530538862d686ea9e81301b5ae4226bfa',
+ 'f2beed7bc5098e83feb5b31608f8e29c38819a89c8e776f154',
+ '4d4151a4ed3a8b87b9ce',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '0027ca0c7120bc3c9696766cfa'),
+ ( 'ec46bb63b02520c33c49fd70',
+ 'b96b49e21d621741632875db7f6c9243d2d7c2',
+ '31d750a09da3ed7fddd49a2032aabf17ec8ebf',
+ '7d22c8088c666be5c197',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '005b8ccbcd9af83c9696766cfa'),
+ ( '47a65ac78b3d594227e85e71',
+ 'e2fcfbb880442c731bf95167c8ffd7895e337076',
+ 'e882f1dbd38ce3eda7c23f04dd65071eb41342ac',
+ 'df7e00dccec7ae52987d',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '003ebe94044b9a3c9696766cfa'),
+ ( '6e37a6ef546d955d34ab6059',
+ 'abf21c0b02feb88f856df4a37381bce3cc128517d4',
+ 'f32905b88a641b04b9c9ffb58cc390900f3da12ab1',
+ '6dce9e82efa16da62059',
+ 'd7828d13b2b0bdc325a76236df93cc6b',
+ '008d493b30ae8b3c9696766cfa'),
+ ]
+
+ test_vectors = [[unhexlify(x) for x in tv] for tv in test_vectors_hex]
+
+ def runTest(self):
+ for assoc_data, pt, ct, mac, key, nonce in self.test_vectors:
+ # Encrypt
+ cipher = AES.new(key, AES.MODE_CCM, nonce, mac_len=len(mac))
+ cipher.update(assoc_data)
+ ct2, mac2 = cipher.encrypt_and_digest(pt)
+ self.assertEqual(ct, ct2)
+ self.assertEqual(mac, mac2)
+
+ # Decrypt
+ cipher = AES.new(key, AES.MODE_CCM, nonce, mac_len=len(mac))
+ cipher.update(assoc_data)
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(pt, pt2)
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings, **extra_params):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._extra_params = extra_params
+ self._id = "None"
+
+ def setUp(self):
+
+ def filter_tag(group):
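+            # Wycheproof reports the tag size in bits; mac_len is in bytes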
+ return group['tagSize'] // 8
+
+ self.tv = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ "aes_ccm_test.json",
+ "Wycheproof AES CCM",
+ group_tag={'tag_size': filter_tag})
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_encrypt(self, tv):
+ self._id = "Wycheproof Encrypt CCM Test #" + str(tv.id)
+
+ try:
+ cipher = AES.new(tv.key, AES.MODE_CCM, tv.iv, mac_len=tv.tag_size,
+ **self._extra_params)
+ except ValueError as e:
+            if len(tv.iv) not in range(7, 13 + 1) and "Length of parameter 'nonce'" in str(e):
+ assert not tv.valid
+ return
+ if tv.tag_size not in range(4, 16 + 1, 2) and "Parameter 'mac_len'" in str(e):
+ assert not tv.valid
+ return
+ raise e
+
+ cipher.update(tv.aad)
+ ct, tag = cipher.encrypt_and_digest(tv.msg)
+ if tv.valid:
+ self.assertEqual(ct, tv.ct)
+ self.assertEqual(tag, tv.tag)
+ self.warn(tv)
+
+ def test_decrypt(self, tv):
+ self._id = "Wycheproof Decrypt CCM Test #" + str(tv.id)
+
+ try:
+ cipher = AES.new(tv.key, AES.MODE_CCM, tv.iv, mac_len=tv.tag_size,
+ **self._extra_params)
+ except ValueError as e:
+            if len(tv.iv) not in range(7, 13 + 1) and "Length of parameter 'nonce'" in str(e):
+ assert not tv.valid
+ return
+ if tv.tag_size not in range(4, 16 + 1, 2) and "Parameter 'mac_len'" in str(e):
+ assert not tv.valid
+ return
+ raise e
+
+ cipher.update(tv.aad)
+ try:
+ pt = cipher.decrypt_and_verify(tv.ct, tv.tag)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+ self.warn(tv)
+
+ def test_corrupt_decrypt(self, tv):
+ self._id = "Wycheproof Corrupt Decrypt CCM Test #" + str(tv.id)
+        if len(tv.iv) not in range(7, 13 + 1) or len(tv.ct) == 0:
+ return
+ cipher = AES.new(tv.key, AES.MODE_CCM, tv.iv, mac_len=tv.tag_size,
+ **self._extra_params)
+ cipher.update(tv.aad)
+ ct_corrupt = strxor(tv.ct, b"\x00" * (len(tv.ct) - 1) + b"\x01")
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct_corrupt, tv.tag)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_encrypt(tv)
+ self.test_decrypt(tv)
+ self.test_corrupt_decrypt(tv)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(CcmTests)
+ tests += list_test_cases(CcmFSMTests)
+ tests += [TestVectors()]
+ tests += [TestVectorsWycheproof(wycheproof_warnings)]
+
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+        return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_CFB.py b/lib/Crypto/SelfTest/Cipher/test_CFB.py
new file mode 100644
index 0000000..cb0c352
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_CFB.py
@@ -0,0 +1,411 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.py3compat import tobytes, is_string
+from Crypto.Cipher import AES, DES3, DES
+from Crypto.Hash import SHAKE128
+
+from Crypto.SelfTest.Cipher.test_CBC import BlockChainingTests
+
+
+def get_tag_random(tag, length):
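+    # Deterministic pseudo-random bytes derived from the tag via SHAKE128,
+    # so the test data is reproducible across runs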
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+class CfbTests(BlockChainingTests):
+
+ aes_mode = AES.MODE_CFB
+ des3_mode = DES3.MODE_CFB
+
+ # Redefine test_unaligned_data_128/64
+
+ def test_unaligned_data_128(self):
+ plaintexts = [ b"7777777" ] * 100
+
+ cipher = AES.new(self.key_128, AES.MODE_CFB, self.iv_128, segment_size=8)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = AES.new(self.key_128, AES.MODE_CFB, self.iv_128, segment_size=8)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ cipher = AES.new(self.key_128, AES.MODE_CFB, self.iv_128, segment_size=128)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = AES.new(self.key_128, AES.MODE_CFB, self.iv_128, segment_size=128)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ def test_unaligned_data_64(self):
+ plaintexts = [ b"7777777" ] * 100
+ cipher = DES3.new(self.key_192, DES3.MODE_CFB, self.iv_64, segment_size=8)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = DES3.new(self.key_192, DES3.MODE_CFB, self.iv_64, segment_size=8)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ cipher = DES3.new(self.key_192, DES3.MODE_CFB, self.iv_64, segment_size=64)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = DES3.new(self.key_192, DES3.MODE_CFB, self.iv_64, segment_size=64)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ # Extra
+
+ def test_segment_size_128(self):
+ for bits in range(8, 129, 8):
+ cipher = AES.new(self.key_128, AES.MODE_CFB, self.iv_128,
+ segment_size=bits)
+
+ for bits in 0, 7, 9, 127, 129:
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CFB,
+ self.iv_128,
+ segment_size=bits)
+
+ def test_segment_size_64(self):
+ for bits in range(8, 65, 8):
+ cipher = DES3.new(self.key_192, DES3.MODE_CFB, self.iv_64,
+ segment_size=bits)
+
+ for bits in 0, 7, 9, 63, 65:
+            self.assertRaises(ValueError, DES3.new, self.key_192, DES3.MODE_CFB,
+ self.iv_64,
+ segment_size=bits)
+
+
+class NistCfbVectors(unittest.TestCase):
+
+ def _do_kat_aes_test(self, file_name, segment_size):
+
+ test_vectors = load_test_vectors(("Cipher", "AES"),
+ file_name,
+ "AES CFB%d KAT" % segment_size,
+ { "count" : lambda x: int(x) } )
+ if test_vectors is None:
+ return
+
+ direction = None
+ for tv in test_vectors:
+
+ # The test vector file contains some directive lines
+ if is_string(tv):
+ direction = tv
+ continue
+
+ self.description = tv.desc
+ cipher = AES.new(tv.key, AES.MODE_CFB, tv.iv,
+ segment_size=segment_size)
+ if direction == "[ENCRYPT]":
+ self.assertEqual(cipher.encrypt(tv.plaintext), tv.ciphertext)
+ elif direction == "[DECRYPT]":
+ self.assertEqual(cipher.decrypt(tv.ciphertext), tv.plaintext)
+ else:
+ assert False
+
+ # See Section 6.4.5 in AESAVS
+ def _do_mct_aes_test(self, file_name, segment_size):
+
+ test_vectors = load_test_vectors(("Cipher", "AES"),
+ file_name,
+ "AES CFB%d Montecarlo" % segment_size,
+ { "count" : lambda x: int(x) } )
+ if test_vectors is None:
+ return
+
+ assert(segment_size in (8, 128))
+
+ direction = None
+ for tv in test_vectors:
+
+ # The test vector file contains some directive lines
+ if is_string(tv):
+ direction = tv
+ continue
+
+ self.description = tv.desc
+ cipher = AES.new(tv.key, AES.MODE_CFB, tv.iv,
+ segment_size=segment_size)
+
+ def get_input(input_text, output_seq, j):
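+                # AESAVS MCT chaining (Section 6.4.5): after the initial
+                # iterations, the CFB128 input is the output from two
+                # iterations earlier; for CFB8, iterations 1-16 consume the
+                # IV one byte at a time and later ones reuse the output
+                # from 17 iterations back.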
+ # CFB128
+ if segment_size == 128:
+ if j >= 2:
+ return output_seq[-2]
+ return [input_text, tv.iv][j]
+ # CFB8
+ if j == 0:
+ return input_text
+ elif j <= 16:
+ return tv.iv[j - 1:j]
+ return output_seq[j - 17]
+
+ if direction == '[ENCRYPT]':
+ cts = []
+ for j in range(1000):
+ plaintext = get_input(tv.plaintext, cts, j)
+ cts.append(cipher.encrypt(plaintext))
+ self.assertEqual(cts[-1], tv.ciphertext)
+ elif direction == '[DECRYPT]':
+ pts = []
+ for j in range(1000):
+ ciphertext = get_input(tv.ciphertext, pts, j)
+ pts.append(cipher.decrypt(ciphertext))
+ self.assertEqual(pts[-1], tv.plaintext)
+ else:
+ assert False
+
+ def _do_tdes_test(self, file_name, segment_size):
+
+ test_vectors = load_test_vectors(("Cipher", "TDES"),
+ file_name,
+ "TDES CFB%d KAT" % segment_size,
+ { "count" : lambda x: int(x) } )
+ if test_vectors is None:
+ return
+
+ direction = None
+ for tv in test_vectors:
+
+ # The test vector file contains some directive lines
+ if is_string(tv):
+ direction = tv
+ continue
+
+ self.description = tv.desc
+ if hasattr(tv, "keys"):
+ cipher = DES.new(tv.keys, DES.MODE_CFB, tv.iv,
+ segment_size=segment_size)
+ else:
+                if tv.key1 != tv.key3:
+                    key = tv.key1 + tv.key2 + tv.key3   # 3TDES (three distinct keys)
+                else:
+                    key = tv.key1 + tv.key2              # 2TDES (key1 == key3)
+ cipher = DES3.new(key, DES3.MODE_CFB, tv.iv,
+ segment_size=segment_size)
+ if direction == "[ENCRYPT]":
+ self.assertEqual(cipher.encrypt(tv.plaintext), tv.ciphertext)
+ elif direction == "[DECRYPT]":
+ self.assertEqual(cipher.decrypt(tv.ciphertext), tv.plaintext)
+ else:
+ assert False
+
+
+# Create one test method per file
+nist_aes_kat_mmt_files = (
+ # KAT
+ "CFB?GFSbox128.rsp",
+ "CFB?GFSbox192.rsp",
+ "CFB?GFSbox256.rsp",
+ "CFB?KeySbox128.rsp",
+ "CFB?KeySbox192.rsp",
+ "CFB?KeySbox256.rsp",
+ "CFB?VarKey128.rsp",
+ "CFB?VarKey192.rsp",
+ "CFB?VarKey256.rsp",
+ "CFB?VarTxt128.rsp",
+ "CFB?VarTxt192.rsp",
+ "CFB?VarTxt256.rsp",
+ # MMT
+ "CFB?MMT128.rsp",
+ "CFB?MMT192.rsp",
+ "CFB?MMT256.rsp",
+ )
+nist_aes_mct_files = (
+ "CFB?MCT128.rsp",
+ "CFB?MCT192.rsp",
+ "CFB?MCT256.rsp",
+ )
+
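+# The default arguments capture the current file_name and bits at definition
+# time, so each generated method exercises its own vector file.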
+for file_gen_name in nist_aes_kat_mmt_files:
+ for bits in "8", "128":
+ file_name = file_gen_name.replace("?", bits)
+ def new_func(self, file_name=file_name, bits=bits):
+ self._do_kat_aes_test(file_name, int(bits))
+ setattr(NistCfbVectors, "test_AES_" + file_name, new_func)
+
+for file_gen_name in nist_aes_mct_files:
+ for bits in "8", "128":
+ file_name = file_gen_name.replace("?", bits)
+ def new_func(self, file_name=file_name, bits=bits):
+ self._do_mct_aes_test(file_name, int(bits))
+ setattr(NistCfbVectors, "test_AES_" + file_name, new_func)
+del file_name, new_func
+
+nist_tdes_files = (
+ "TCFB?MMT2.rsp", # 2TDES
+ "TCFB?MMT3.rsp", # 3TDES
+ "TCFB?invperm.rsp", # Single DES
+ "TCFB?permop.rsp",
+ "TCFB?subtab.rsp",
+ "TCFB?varkey.rsp",
+ "TCFB?vartext.rsp",
+ )
+
+for file_gen_name in nist_tdes_files:
+ for bits in "8", "64":
+ file_name = file_gen_name.replace("?", bits)
+ def new_func(self, file_name=file_name, bits=bits):
+ self._do_tdes_test(file_name, int(bits))
+ setattr(NistCfbVectors, "test_TDES_" + file_name, new_func)
+
+# END OF NIST CFB TEST VECTORS
+
+
+class SP800TestVectors(unittest.TestCase):
+ """Class exercising the CFB test vectors found in Section F.3
+    of NIST SP 800-38A"""
+
+ def test_aes_128_cfb8(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172aae2d'
+ ciphertext = '3b79424c9c0dd436bace9e0ed4586a4f32b9'
+ key = '2b7e151628aed2a6abf7158809cf4f3c'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=8)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=8)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_192_cfb8(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172aae2d'
+ ciphertext = 'cda2521ef0a905ca44cd057cbf0d47a0678a'
+ key = '8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=8)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=8)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_256_cfb8(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172aae2d'
+ ciphertext = 'dc1f1a8520a64db55fcc8ac554844e889700'
+ key = '603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=8)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=8)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_128_cfb128(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = '3b3fd92eb72dad20333449f8e83cfb4a' +\
+ 'c8a64537a0b3a93fcde3cdad9f1ce58b' +\
+ '26751f67a3cbb140b1808cf187a4f4df' +\
+ 'c04b05357c5d1c0eeac4c66f9ff7f2e6'
+ key = '2b7e151628aed2a6abf7158809cf4f3c'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=128)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=128)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_192_cfb128(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = 'cdc80d6fddf18cab34c25909c99a4174' +\
+ '67ce7f7f81173621961a2b70171d3d7a' +\
+ '2e1e8a1dd59b88b1c8e60fed1efac4c9' +\
+ 'c05f9f9ca9834fa042ae8fba584b09ff'
+ key = '8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=128)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=128)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_256_cfb128(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+
+ ciphertext = 'dc7e84bfda79164b7ecd8486985d3860' +\
+ '39ffed143b28b1c832113c6331e5407b' +\
+ 'df10132415e54b92a13ed0a8267ae2f9' +\
+ '75a385741ab9cef82031623d55b1e471'
+ key = '603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=128)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CFB, iv, segment_size=128)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(CfbTests)
+ if config.get('slow_tests'):
+ tests += list_test_cases(NistCfbVectors)
+ tests += list_test_cases(SP800TestVectors)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_CTR.py b/lib/Crypto/SelfTest/Cipher/test_CTR.py
new file mode 100644
index 0000000..6fc43ef
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_CTR.py
@@ -0,0 +1,472 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import hexlify, unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.py3compat import tobytes, bchr
+from Crypto.Cipher import AES, DES3
+from Crypto.Hash import SHAKE128, SHA256
+from Crypto.Util import Counter
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+class CtrTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ key_192 = get_tag_random("key_192", 24)
+ nonce_32 = get_tag_random("nonce_32", 4)
+ nonce_64 = get_tag_random("nonce_64", 8)
+ ctr_64 = Counter.new(32, prefix=nonce_32)
+ ctr_128 = Counter.new(64, prefix=nonce_64)
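+    # The prefix plus the counter field must fill one cipher block:
+    # 4-byte nonce + 32-bit counter for DES3 (8-byte blocks),
+    # 8-byte nonce + 64-bit counter for AES (16-byte blocks).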
+
+ def test_loopback_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_loopback_64(self):
+ cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+ pt = get_tag_random("plaintext", 8 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_invalid_counter_parameter(self):
+ # Counter object is required for ciphers with short block size
+        self.assertRaises(TypeError, DES3.new, self.key_192, DES3.MODE_CTR)
+ # Positional arguments are not allowed (Counter must be passed as
+ # keyword)
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CTR, self.ctr_128)
+
+ def test_nonce_attribute(self):
+ # Nonce attribute is the prefix passed to Counter (DES3)
+ cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+ self.assertEqual(cipher.nonce, self.nonce_32)
+
+ # Nonce attribute is the prefix passed to Counter (AES)
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ self.assertEqual(cipher.nonce, self.nonce_64)
+
+ # Nonce attribute is not defined if suffix is used in Counter
+ counter = Counter.new(64, prefix=self.nonce_32, suffix=self.nonce_32)
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ self.assertFalse(hasattr(cipher, "nonce"))
+
+ def test_nonce_parameter(self):
+ # Nonce parameter becomes nonce attribute
+ cipher1 = AES.new(self.key_128, AES.MODE_CTR, nonce=self.nonce_64)
+ self.assertEqual(cipher1.nonce, self.nonce_64)
+
+ counter = Counter.new(64, prefix=self.nonce_64, initial_value=0)
+ cipher2 = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ pt = get_tag_random("plaintext", 65536)
+ self.assertEqual(cipher1.encrypt(pt), cipher2.encrypt(pt))
+
+ # Nonce is implicitly created (for AES) when no parameters are passed
+ nonce1 = AES.new(self.key_128, AES.MODE_CTR).nonce
+ nonce2 = AES.new(self.key_128, AES.MODE_CTR).nonce
+ self.assertNotEqual(nonce1, nonce2)
+ self.assertEqual(len(nonce1), 8)
+
+ # Nonce can be zero-length
+ cipher = AES.new(self.key_128, AES.MODE_CTR, nonce=b"")
+ self.assertEqual(b"", cipher.nonce)
+ cipher.encrypt(b'0'*300)
+
+ # Nonce and Counter are mutually exclusive
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CTR,
+ counter=self.ctr_128, nonce=self.nonce_64)
+
+ def test_initial_value_parameter(self):
+ # Test with nonce parameter
+ cipher1 = AES.new(self.key_128, AES.MODE_CTR,
+ nonce=self.nonce_64, initial_value=0xFFFF)
+ counter = Counter.new(64, prefix=self.nonce_64, initial_value=0xFFFF)
+ cipher2 = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ pt = get_tag_random("plaintext", 65536)
+ self.assertEqual(cipher1.encrypt(pt), cipher2.encrypt(pt))
+
+ # Test without nonce parameter
+ cipher1 = AES.new(self.key_128, AES.MODE_CTR,
+ initial_value=0xFFFF)
+ counter = Counter.new(64, prefix=cipher1.nonce, initial_value=0xFFFF)
+ cipher2 = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ pt = get_tag_random("plaintext", 65536)
+ self.assertEqual(cipher1.encrypt(pt), cipher2.encrypt(pt))
+
+ # Initial_value and Counter are mutually exclusive
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CTR,
+ counter=self.ctr_128, initial_value=0)
+
+ def test_initial_value_bytes_parameter(self):
+ # Same result as when passing an integer
+ cipher1 = AES.new(self.key_128, AES.MODE_CTR,
+ nonce=self.nonce_64,
+ initial_value=b"\x00"*6+b"\xFF\xFF")
+ cipher2 = AES.new(self.key_128, AES.MODE_CTR,
+ nonce=self.nonce_64, initial_value=0xFFFF)
+ pt = get_tag_random("plaintext", 65536)
+ self.assertEqual(cipher1.encrypt(pt), cipher2.encrypt(pt))
+
+ # Fail if the iv is too large
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CTR,
+ initial_value=b"5"*17)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CTR,
+ nonce=self.nonce_64, initial_value=b"5"*9)
+
+ # Fail if the iv is too short
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CTR,
+ initial_value=b"5"*15)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CTR,
+ nonce=self.nonce_64, initial_value=b"5"*7)
+
+ def test_iv_with_matching_length(self):
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CTR,
+ counter=Counter.new(120))
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_CTR,
+ counter=Counter.new(136))
+
+ def test_block_size_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ self.assertEqual(cipher.block_size, AES.block_size)
+
+ def test_block_size_64(self):
+ cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+ self.assertEqual(cipher.block_size, DES3.block_size)
+
+ def test_unaligned_data_128(self):
+ plaintexts = [ b"7777777" ] * 100
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ def test_unaligned_data_64(self):
+ plaintexts = [ b"7777777" ] * 100
+        cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+        ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+        cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+        self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+        cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+        ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+        cipher = DES3.new(self.key_192, DES3.MODE_CTR, counter=self.ctr_64)
+        self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CTR,
+ 7, counter=self.ctr_128)
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_CTR,
+ counter=self.ctr_128, unknown=7)
+ # But some are only known by the base cipher (e.g. use_aesni consumed by the AES module)
+ AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128, use_aesni=False)
+
+ def test_null_encryption_decryption(self):
+ for func in "encrypt", "decrypt":
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ result = getattr(cipher, func)(b"")
+ self.assertEqual(result, b"")
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ cipher.encrypt(b"")
+ self.assertRaises(TypeError, cipher.decrypt, b"")
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=self.ctr_128)
+ cipher.decrypt(b"")
+ self.assertRaises(TypeError, cipher.encrypt, b"")
+
+ def test_wrap_around(self):
+ # Counter is only 8 bits, so we can only encrypt/decrypt 256 blocks (=4096 bytes)
+ counter = Counter.new(8, prefix=bchr(9) * 15)
+ max_bytes = 4096
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ cipher.encrypt(b'9' * max_bytes)
+ self.assertRaises(OverflowError, cipher.encrypt, b'9')
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ self.assertRaises(OverflowError, cipher.encrypt, b'9' * (max_bytes + 1))
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ cipher.decrypt(b'9' * max_bytes)
+ self.assertRaises(OverflowError, cipher.decrypt, b'9')
+
+ cipher = AES.new(self.key_128, AES.MODE_CTR, counter=counter)
+ self.assertRaises(OverflowError, cipher.decrypt, b'9' * (max_bytes + 1))
+
+ def test_bytearray(self):
+ data = b"1" * 16
+ iv = b"\x00" * 6 + b"\xFF\xFF"
+
+ # Encrypt
+ cipher1 = AES.new(self.key_128, AES.MODE_CTR,
+ nonce=self.nonce_64,
+ initial_value=iv)
+ ref1 = cipher1.encrypt(data)
+
+ cipher2 = AES.new(self.key_128, AES.MODE_CTR,
+ nonce=bytearray(self.nonce_64),
+ initial_value=bytearray(iv))
+ ref2 = cipher2.encrypt(bytearray(data))
+
+ self.assertEqual(ref1, ref2)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ cipher3 = AES.new(self.key_128, AES.MODE_CTR,
+ nonce=self.nonce_64,
+ initial_value=iv)
+ ref3 = cipher3.decrypt(data)
+
+ cipher4 = AES.new(self.key_128, AES.MODE_CTR,
+ nonce=bytearray(self.nonce_64),
+ initial_value=bytearray(iv))
+ ref4 = cipher4.decrypt(bytearray(data))
+
+ self.assertEqual(ref3, ref4)
+
+ def test_very_long_data(self):
+ cipher = AES.new(b'A' * 32, AES.MODE_CTR, nonce=b'')
+ ct = cipher.encrypt(b'B' * 1000000)
+ digest = SHA256.new(ct).hexdigest()
+ self.assertEqual(digest, "96204fc470476561a3a8f3b6fe6d24be85c87510b638142d1d0fb90989f8a6a6")
+
+ def test_output_param(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(128)
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ def test_output_param_memoryview(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ ct = cipher.encrypt(pt)
+
+ output = memoryview(bytearray(128))
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ def test_output_param_neg(self):
+ LEN_PT = 128
+
+ pt = b'5' * LEN_PT
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0' * LEN_PT)
+
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0' * LEN_PT)
+
+ shorter_output = bytearray(LEN_PT - 1)
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ cipher = AES.new(b'4'*16, AES.MODE_CTR, nonce=self.nonce_64)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+class SP800TestVectors(unittest.TestCase):
+ """Class exercising the CTR test vectors found in Section F.5
+ of NIST SP 800-38A"""
+
+ def test_aes_128(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = '874d6191b620e3261bef6864990db6ce' +\
+ '9806f66b7970fdff8617187bb9fffdff' +\
+ '5ae4df3edbd5d35e5b4f09020db03eab' +\
+ '1e031dda2fbe03d1792170a0f3009cee'
+ key = '2b7e151628aed2a6abf7158809cf4f3c'
+ counter = Counter.new(nbits=16,
+ prefix=unhexlify('f0f1f2f3f4f5f6f7f8f9fafbfcfd'),
+ initial_value=0xfeff)
+
+ key = unhexlify(key)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CTR, counter=counter)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CTR, counter=counter)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_192(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = '1abc932417521ca24f2b0459fe7e6e0b' +\
+ '090339ec0aa6faefd5ccc2c6f4ce8e94' +\
+ '1e36b26bd1ebc670d1bd1d665620abf7' +\
+ '4f78a7f6d29809585a97daec58c6b050'
+ key = '8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b'
+ counter = Counter.new(nbits=16,
+ prefix=unhexlify('f0f1f2f3f4f5f6f7f8f9fafbfcfd'),
+ initial_value=0xfeff)
+
+ key = unhexlify(key)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CTR, counter=counter)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CTR, counter=counter)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ def test_aes_256(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = '601ec313775789a5b7a7f504bbf3d228' +\
+ 'f443e3ca4d62b59aca84e990cacaf5c5' +\
+ '2b0930daa23de94ce87017ba2d84988d' +\
+ 'dfc9c58db67aada613c2dd08457941a6'
+ key = '603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4'
+ counter = Counter.new(nbits=16,
+ prefix=unhexlify('f0f1f2f3f4f5f6f7f8f9fafbfcfd'),
+ initial_value=0xfeff)
+ key = unhexlify(key)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_CTR, counter=counter)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_CTR, counter=counter)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+
+class RFC3686TestVectors(unittest.TestCase):
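+    """Test vectors from RFC 3686 (AES-CTR with a 32-bit counter that
+    follows a 4-byte nonce and an 8-byte IV)."""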
+
+ # Each item is a test vector with:
+ # - plaintext
+ # - ciphertext
+ # - key (AES 128, 192 or 256 bits)
+    # - counter prefix (4 byte nonce + 8 byte IV)
+ data = (
+ ('53696e676c6520626c6f636b206d7367',
+ 'e4095d4fb7a7b3792d6175a3261311b8',
+ 'ae6852f8121067cc4bf7a5765577f39e',
+ '000000300000000000000000'),
+ ('000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ '5104a106168a72d9790d41ee8edad388eb2e1efc46da57c8fce630df9141be28',
+ '7e24067817fae0d743d6ce1f32539163',
+ '006cb6dbc0543b59da48d90b'),
+ ('000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223',
+ 'c1cf48a89f2ffdd9cf4652e9efdb72d74540a42bde6d7836d59a5ceaaef3105325b2072f',
+ '7691be035e5020a8ac6e618529f9a0dc',
+ '00e0017b27777f3f4a1786f0'),
+ ('53696e676c6520626c6f636b206d7367',
+ '4b55384fe259c9c84e7935a003cbe928',
+ '16af5b145fc9f579c175f93e3bfb0eed863d06ccfdb78515',
+ '0000004836733c147d6d93cb'),
+ ('000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ '453243fc609b23327edfaafa7131cd9f8490701c5ad4a79cfc1fe0ff42f4fb00',
+ '7c5cb2401b3dc33c19e7340819e0f69c678c3db8e6f6a91a',
+ '0096b03b020c6eadc2cb500d'),
+ ('000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223',
+ '96893fc55e5c722f540b7dd1ddf7e758d288bc95c69165884536c811662f2188abee0935',
+ '02bf391ee8ecb159b959617b0965279bf59b60a786d3e0fe',
+ '0007bdfd5cbd60278dcc0912'),
+ ('53696e676c6520626c6f636b206d7367',
+ '145ad01dbf824ec7560863dc71e3e0c0',
+ '776beff2851db06f4c8a0542c8696f6c6a81af1eec96b4d37fc1d689e6c1c104',
+ '00000060db5672c97aa8f0b2'),
+ ('000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f',
+ 'f05e231b3894612c49ee000b804eb2a9b8306b508f839d6a5530831d9344af1c',
+ 'f6d66d6bd52d59bb0796365879eff886c66dd51a5b6a99744b50590c87a23884',
+ '00faac24c1585ef15a43d875'),
+ ('000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223',
+ 'eb6c52821d0bbbf7ce7594462aca4faab407df866569fd07f48cc0b583d6071f1ec0e6b8',
+ 'ff7a617ce69148e4f1726e2f43581de2aa62d9f805532edff1eed687fb54153d',
+ '001cc5b751a51d70a1c11148')
+ )
+
+ bindata = []
+ for tv in data:
+ bindata.append([unhexlify(x) for x in tv])
+
+ def runTest(self):
+ for pt, ct, key, prefix in self.bindata:
+ counter = Counter.new(32, prefix=prefix)
+ cipher = AES.new(key, AES.MODE_CTR, counter=counter)
+ result = cipher.encrypt(pt)
+ self.assertEqual(hexlify(ct), hexlify(result))
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(CtrTests)
+ tests += list_test_cases(SP800TestVectors)
+ tests += [ RFC3686TestVectors() ]
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_ChaCha20.py b/lib/Crypto/SelfTest/Cipher/test_ChaCha20.py
new file mode 100644
index 0000000..4396ac2
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_ChaCha20.py
@@ -0,0 +1,529 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import os
+import re
+import unittest
+from binascii import hexlify, unhexlify
+
+from Crypto.Util.py3compat import b, tobytes, bchr
+from Crypto.Util.strxor import strxor_c
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Cipher import ChaCha20
+
+
+class ChaCha20Test(unittest.TestCase):
+
+ def test_new_positive(self):
+ cipher = ChaCha20.new(key=b("0")*32, nonce=b"0"*8)
+ self.assertEqual(cipher.nonce, b"0" * 8)
+ cipher = ChaCha20.new(key=b("0")*32, nonce=b"0"*12)
+ self.assertEqual(cipher.nonce, b"0" * 12)
+
+ def test_new_negative(self):
+ new = ChaCha20.new
+ self.assertRaises(TypeError, new)
+ self.assertRaises(TypeError, new, nonce=b("0"))
+ self.assertRaises(ValueError, new, nonce=b("0")*8, key=b("0"))
+ self.assertRaises(ValueError, new, nonce=b("0"), key=b("0")*32)
+
+ def test_default_nonce(self):
+ cipher1 = ChaCha20.new(key=bchr(1) * 32)
+ cipher2 = ChaCha20.new(key=bchr(1) * 32)
+ self.assertEqual(len(cipher1.nonce), 8)
+ self.assertNotEqual(cipher1.nonce, cipher2.nonce)
+
+ def test_nonce(self):
+ key = b'A' * 32
+
+ nonce1 = b'P' * 8
+ cipher1 = ChaCha20.new(key=key, nonce=nonce1)
+ self.assertEqual(nonce1, cipher1.nonce)
+
+ nonce2 = b'Q' * 12
+ cipher2 = ChaCha20.new(key=key, nonce=nonce2)
+ self.assertEqual(nonce2, cipher2.nonce)
+
+    def test_either_encrypt_or_decrypt(self):
+ """Verify that a cipher cannot be used for both decrypting and encrypting"""
+
+ c1 = ChaCha20.new(key=b("5") * 32, nonce=b("6") * 8)
+ c1.encrypt(b("8"))
+ self.assertRaises(TypeError, c1.decrypt, b("9"))
+
+ c2 = ChaCha20.new(key=b("5") * 32, nonce=b("6") * 8)
+ c2.decrypt(b("8"))
+ self.assertRaises(TypeError, c2.encrypt, b("9"))
+
+ def test_round_trip(self):
+ pt = b("A") * 1024
+ c1 = ChaCha20.new(key=b("5") * 32, nonce=b("6") * 8)
+ c2 = ChaCha20.new(key=b("5") * 32, nonce=b("6") * 8)
+ ct = c1.encrypt(pt)
+ self.assertEqual(c2.decrypt(ct), pt)
+
+ self.assertEqual(c1.encrypt(b("")), b(""))
+ self.assertEqual(c2.decrypt(b("")), b(""))
+
+ def test_streaming(self):
+ """Verify that an arbitrary number of bytes can be encrypted/decrypted"""
+ from Crypto.Hash import SHA1
+
+ segments = (1, 3, 5, 7, 11, 17, 23)
+ total = sum(segments)
+
+ pt = b("")
+ while len(pt) < total:
+ pt += SHA1.new(pt).digest()
+
+ cipher1 = ChaCha20.new(key=b("7") * 32, nonce=b("t") * 8)
+ ct = cipher1.encrypt(pt)
+
+ cipher2 = ChaCha20.new(key=b("7") * 32, nonce=b("t") * 8)
+ cipher3 = ChaCha20.new(key=b("7") * 32, nonce=b("t") * 8)
+ idx = 0
+ for segment in segments:
+ self.assertEqual(cipher2.decrypt(ct[idx:idx+segment]), pt[idx:idx+segment])
+ self.assertEqual(cipher3.encrypt(pt[idx:idx+segment]), ct[idx:idx+segment])
+ idx += segment
+
+ def test_seek(self):
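+        # seek() jumps to an absolute byte offset in the keystream; the
+        # offset does not need to be aligned to the 64-byte block size.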
+ cipher1 = ChaCha20.new(key=b("9") * 32, nonce=b("e") * 8)
+
+ offset = 64 * 900 + 7
+ pt = b("1") * 64
+
+ cipher1.encrypt(b("0") * offset)
+ ct1 = cipher1.encrypt(pt)
+
+ cipher2 = ChaCha20.new(key=b("9") * 32, nonce=b("e") * 8)
+ cipher2.seek(offset)
+ ct2 = cipher2.encrypt(pt)
+
+ self.assertEqual(ct1, ct2)
+
+ def test_seek_tv(self):
+ # Test Vector #4, A.1 from
+ # http://tools.ietf.org/html/draft-nir-cfrg-chacha20-poly1305-04
+ key = bchr(0) + bchr(255) + bchr(0) * 30
+ nonce = bchr(0) * 8
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ cipher.seek(64 * 2)
+ expected_key_stream = unhexlify(b(
+ "72d54dfbf12ec44b362692df94137f32"
+ "8fea8da73990265ec1bbbea1ae9af0ca"
+ "13b25aa26cb4a648cb9b9d1be65b2c09"
+ "24a66c54d545ec1b7374f4872e99f096"
+ ))
+ ct = cipher.encrypt(bchr(0) * len(expected_key_stream))
+ self.assertEqual(expected_key_stream, ct)
+
+ def test_rfc7539(self):
+ # from https://tools.ietf.org/html/rfc7539 Annex A.1
+ # Each item is: key, nonce, block #, plaintext, ciphertext
+ tvs = [
+ # Test Vector #1
+ (
+ "00"*32,
+ "00"*12,
+ 0,
+ "00"*16*4,
+ "76b8e0ada0f13d90405d6ae55386bd28"
+ "bdd219b8a08ded1aa836efcc8b770dc7"
+ "da41597c5157488d7724e03fb8d84a37"
+ "6a43b8f41518a11cc387b669b2ee6586"
+ ),
+ # Test Vector #2
+ (
+ "00"*31 + "01",
+ "00"*11 + "02",
+ 1,
+ "416e79207375626d697373696f6e2074"
+ "6f20746865204945544620696e74656e"
+ "6465642062792074686520436f6e7472"
+ "696275746f7220666f72207075626c69"
+ "636174696f6e20617320616c6c206f72"
+ "2070617274206f6620616e2049455446"
+ "20496e7465726e65742d447261667420"
+ "6f722052464320616e6420616e792073"
+ "746174656d656e74206d616465207769"
+ "7468696e2074686520636f6e74657874"
+ "206f6620616e20494554462061637469"
+ "7669747920697320636f6e7369646572"
+ "656420616e20224945544620436f6e74"
+ "7269627574696f6e222e205375636820"
+ "73746174656d656e747320696e636c75"
+ "6465206f72616c2073746174656d656e"
+ "747320696e2049455446207365737369"
+ "6f6e732c2061732077656c6c20617320"
+ "7772697474656e20616e6420656c6563"
+ "74726f6e696320636f6d6d756e696361"
+ "74696f6e73206d61646520617420616e"
+ "792074696d65206f7220706c6163652c"
+ "20776869636820617265206164647265"
+ "7373656420746f",
+ "a3fbf07df3fa2fde4f376ca23e827370"
+ "41605d9f4f4f57bd8cff2c1d4b7955ec"
+ "2a97948bd3722915c8f3d337f7d37005"
+ "0e9e96d647b7c39f56e031ca5eb6250d"
+ "4042e02785ececfa4b4bb5e8ead0440e"
+ "20b6e8db09d881a7c6132f420e527950"
+ "42bdfa7773d8a9051447b3291ce1411c"
+ "680465552aa6c405b7764d5e87bea85a"
+ "d00f8449ed8f72d0d662ab052691ca66"
+ "424bc86d2df80ea41f43abf937d3259d"
+ "c4b2d0dfb48a6c9139ddd7f76966e928"
+ "e635553ba76c5c879d7b35d49eb2e62b"
+ "0871cdac638939e25e8a1e0ef9d5280f"
+ "a8ca328b351c3c765989cbcf3daa8b6c"
+ "cc3aaf9f3979c92b3720fc88dc95ed84"
+ "a1be059c6499b9fda236e7e818b04b0b"
+ "c39c1e876b193bfe5569753f88128cc0"
+ "8aaa9b63d1a16f80ef2554d7189c411f"
+ "5869ca52c5b83fa36ff216b9c1d30062"
+ "bebcfd2dc5bce0911934fda79a86f6e6"
+ "98ced759c3ff9b6477338f3da4f9cd85"
+ "14ea9982ccafb341b2384dd902f3d1ab"
+ "7ac61dd29c6f21ba5b862f3730e37cfd"
+ "c4fd806c22f221"
+ ),
+ # Test Vector #3
+ (
+ "1c9240a5eb55d38af333888604f6b5f0"
+ "473917c1402b80099dca5cbc207075c0",
+ "00"*11 + "02",
+ 42,
+ "2754776173206272696c6c69672c2061"
+ "6e642074686520736c6974687920746f"
+ "7665730a446964206779726520616e64"
+ "2067696d626c6520696e207468652077"
+ "6162653a0a416c6c206d696d73792077"
+ "6572652074686520626f726f676f7665"
+ "732c0a416e6420746865206d6f6d6520"
+ "7261746873206f757467726162652e",
+ "62e6347f95ed87a45ffae7426f27a1df"
+ "5fb69110044c0d73118effa95b01e5cf"
+ "166d3df2d721caf9b21e5fb14c616871"
+ "fd84c54f9d65b283196c7fe4f60553eb"
+ "f39c6402c42234e32a356b3e764312a6"
+ "1a5532055716ead6962568f87d3f3f77"
+ "04c6a8d1bcd1bf4d50d6154b6da731b1"
+ "87b58dfd728afa36757a797ac188d1"
+ )
+ ]
+
+ for tv in tvs:
+ key = unhexlify(tv[0])
+ nonce = unhexlify(tv[1])
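+ # the vectors give the start position as a block counter; seek() takes
+ # a byte offset, so multiply by the 64-byte ChaCha20 block size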
+ offset = tv[2] * 64
+ pt = unhexlify(tv[3])
+ ct_expect = unhexlify(tv[4])
+
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ if offset != 0:
+ cipher.seek(offset)
+ ct = cipher.encrypt(pt)
+ self.assertEqual(ct, ct_expect)
+
+
+class XChaCha20Test(unittest.TestCase):
+
+ # From https://tools.ietf.org/html/draft-arciszewski-xchacha-03
+
+ def test_hchacha20(self):
+ # Section 2.2.1
+
+ from Crypto.Cipher.ChaCha20 import _HChaCha20
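+ # HChaCha20 maps a 256-bit key and a 128-bit input to a fresh 256-bit
+ # subkey; XChaCha20 applies it to the first 16 bytes of its 24-byte
+ # nonce (section 2.2 of the draft referenced above)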
+
+ key = b"00:01:02:03:04:05:06:07:08:09:0a:0b:0c:0d:0e:0f:10:11:12:13:14:15:16:17:18:19:1a:1b:1c:1d:1e:1f"
+ key = unhexlify(key.replace(b":", b""))
+
+ nonce = b"00:00:00:09:00:00:00:4a:00:00:00:00:31:41:59:27"
+ nonce = unhexlify(nonce.replace(b":", b""))
+
+ subkey = _HChaCha20(key, nonce)
+
+ expected = b"82413b42 27b27bfe d30e4250 8a877d73 a0f9e4d5 8a74a853 c12ec413 26d3ecdc"
+ expected = unhexlify(expected.replace(b" ", b""))
+
+ self.assertEqual(subkey, expected)
+
+ def test_nonce(self):
+ key = b'A' * 32
+ nonce = b'P' * 24
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ self.assertEqual(nonce, cipher.nonce)
+
+ def test_encrypt(self):
+ # Section A.3.2
+
+ pt = b"""
+ 5468652064686f6c65202870726f6e6f756e6365642022646f6c652229206973
+ 20616c736f206b6e6f776e2061732074686520417369617469632077696c6420
+ 646f672c2072656420646f672c20616e642077686973746c696e6720646f672e
+ 2049742069732061626f7574207468652073697a65206f662061204765726d61
+ 6e20736865706865726420627574206c6f6f6b73206d6f7265206c696b652061
+ 206c6f6e672d6c656767656420666f782e205468697320686967686c7920656c
+ 757369766520616e6420736b696c6c6564206a756d70657220697320636c6173
+ 736966696564207769746820776f6c7665732c20636f796f7465732c206a6163
+ 6b616c732c20616e6420666f78657320696e20746865207461786f6e6f6d6963
+ 2066616d696c792043616e696461652e"""
+ pt = unhexlify(pt.replace(b"\n", b"").replace(b" ", b""))
+
+ key = unhexlify(b"808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f")
+ iv = unhexlify(b"404142434445464748494a4b4c4d4e4f5051525354555658")
+
+ ct = b"""
+ 7d0a2e6b7f7c65a236542630294e063b7ab9b555a5d5149aa21e4ae1e4fbce87
+ ecc8e08a8b5e350abe622b2ffa617b202cfad72032a3037e76ffdcdc4376ee05
+ 3a190d7e46ca1de04144850381b9cb29f051915386b8a710b8ac4d027b8b050f
+ 7cba5854e028d564e453b8a968824173fc16488b8970cac828f11ae53cabd201
+ 12f87107df24ee6183d2274fe4c8b1485534ef2c5fbc1ec24bfc3663efaa08bc
+ 047d29d25043532db8391a8a3d776bf4372a6955827ccb0cdd4af403a7ce4c63
+ d595c75a43e045f0cce1f29c8b93bd65afc5974922f214a40b7c402cdb91ae73
+ c0b63615cdad0480680f16515a7ace9d39236464328a37743ffc28f4ddb324f4
+ d0f5bbdc270c65b1749a6efff1fbaa09536175ccd29fb9e6057b307320d31683
+ 8a9c71f70b5b5907a66f7ea49aadc409"""
+ ct = unhexlify(ct.replace(b"\n", b"").replace(b" ", b""))
+
+ cipher = ChaCha20.new(key=key, nonce=iv)
+ cipher.seek(64) # Counter = 1
+ ct_test = cipher.encrypt(pt)
+ self.assertEqual(ct, ct_test)
+
+
+class ByteArrayTest(unittest.TestCase):
+ """Verify we can encrypt or decrypt bytearrays"""
+
+ def runTest(self):
+
+ data = b"0123"
+ key = b"9" * 32
+ nonce = b"t" * 8
+
+ # Encryption
+ data_ba = bytearray(data)
+ key_ba = bytearray(key)
+ nonce_ba = bytearray(nonce)
+
+ cipher1 = ChaCha20.new(key=key, nonce=nonce)
+ ct = cipher1.encrypt(data)
+
+ cipher2 = ChaCha20.new(key=key_ba, nonce=nonce_ba)
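+ # mutate the bytearrays after the cipher is constructed: the resulting
+ # ciphertext must not change, i.e. the cipher keeps its own copy of
+ # key and nonce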
+ key_ba[:1] = b'\xFF'
+ nonce_ba[:1] = b'\xFF'
+ ct_test = cipher2.encrypt(data_ba)
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decryption
+ key_ba = bytearray(key)
+ nonce_ba = bytearray(nonce)
+ ct_ba = bytearray(ct)
+
+ cipher3 = ChaCha20.new(key=key_ba, nonce=nonce_ba)
+ key_ba[:1] = b'\xFF'
+ nonce_ba[:1] = b'\xFF'
+ pt_test = cipher3.decrypt(ct_ba)
+
+ self.assertEqual(data, pt_test)
+
+
+class MemoryviewTest(unittest.TestCase):
+ """Verify we can encrypt or decrypt bytearrays"""
+
+ def runTest(self):
+
+ data = b"0123"
+ key = b"9" * 32
+ nonce = b"t" * 8
+
+ # Encryption
+ data_mv = memoryview(bytearray(data))
+ key_mv = memoryview(bytearray(key))
+ nonce_mv = memoryview(bytearray(nonce))
+
+ cipher1 = ChaCha20.new(key=key, nonce=nonce)
+ ct = cipher1.encrypt(data)
+
+ cipher2 = ChaCha20.new(key=key_mv, nonce=nonce_mv)
+ key_mv[:1] = b'\xFF'
+ nonce_mv[:1] = b'\xFF'
+ ct_test = cipher2.encrypt(data_mv)
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decryption
+ key_mv = memoryview(bytearray(key))
+ nonce_mv = memoryview(bytearray(nonce))
+ ct_mv = memoryview(bytearray(ct))
+
+ cipher3 = ChaCha20.new(key=key_mv, nonce=nonce_mv)
+ key_mv[:1] = b'\xFF'
+ nonce_mv[:1] = b'\xFF'
+ pt_test = cipher3.decrypt(ct_mv)
+
+ self.assertEqual(data, pt_test)
+
+
+class ChaCha20_AGL_NIR(unittest.TestCase):
+
+ # From http://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-04
+ # and http://tools.ietf.org/html/draft-nir-cfrg-chacha20-poly1305-04
+ tv = [
+ ( "00" * 32,
+ "00" * 8,
+ "76b8e0ada0f13d90405d6ae55386bd28bdd219b8a08ded1aa836efcc"
+ "8b770dc7da41597c5157488d7724e03fb8d84a376a43b8f41518a11c"
+ "c387b669b2ee6586"
+ "9f07e7be5551387a98ba977c732d080d"
+ "cb0f29a048e3656912c6533e32ee7aed"
+ "29b721769ce64e43d57133b074d839d5"
+ "31ed1f28510afb45ace10a1f4b794d6f"
+ ),
+ ( "00" * 31 + "01",
+ "00" * 8,
+ "4540f05a9f1fb296d7736e7b208e3c96eb4fe1834688d2604f450952"
+ "ed432d41bbe2a0b6ea7566d2a5d1e7e20d42af2c53d792b1c43fea81"
+ "7e9ad275ae546963"
+ "3aeb5224ecf849929b9d828db1ced4dd"
+ "832025e8018b8160b82284f3c949aa5a"
+ "8eca00bbb4a73bdad192b5c42f73f2fd"
+ "4e273644c8b36125a64addeb006c13a0"
+ ),
+ ( "00" * 32,
+ "00" * 7 + "01",
+ "de9cba7bf3d69ef5e786dc63973f653a0b49e015adbff7134fcb7df1"
+ "37821031e85a050278a7084527214f73efc7fa5b5277062eb7a0433e"
+ "445f41e3"
+ ),
+ ( "00" * 32,
+ "01" + "00" * 7,
+ "ef3fdfd6c61578fbf5cf35bd3dd33b8009631634d21e42ac33960bd1"
+ "38e50d32111e4caf237ee53ca8ad6426194a88545ddc497a0b466e7d"
+ "6bbdb0041b2f586b"
+ ),
+ ( "000102030405060708090a0b0c0d0e0f101112131415161718191a1b"
+ "1c1d1e1f",
+ "0001020304050607",
+ "f798a189f195e66982105ffb640bb7757f579da31602fc93ec01ac56"
+ "f85ac3c134a4547b733b46413042c9440049176905d3be59ea1c53f1"
+ "5916155c2be8241a38008b9a26bc35941e2444177c8ade6689de9526"
+ "4986d95889fb60e84629c9bd9a5acb1cc118be563eb9b3a4a472f82e"
+ "09a7e778492b562ef7130e88dfe031c79db9d4f7c7a899151b9a4750"
+ "32b63fc385245fe054e3dd5a97a5f576fe064025d3ce042c566ab2c5"
+ "07b138db853e3d6959660996546cc9c4a6eafdc777c040d70eaf46f7"
+ "6dad3979e5c5360c3317166a1c894c94a371876a94df7628fe4eaaf2"
+ "ccb27d5aaae0ad7ad0f9d4b6ad3b54098746d4524d38407a6deb3ab7"
+ "8fab78c9"
+ ),
+ ( "00" * 32,
+ "00" * 7 + "02",
+ "c2c64d378cd536374ae204b9ef933fcd"
+ "1a8b2288b3dfa49672ab765b54ee27c7"
+ "8a970e0e955c14f3a88e741b97c286f7"
+ "5f8fc299e8148362fa198a39531bed6d"
+ ),
+ ]
+
+ def runTest(self):
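+ # encrypting all-zero plaintext returns the raw keystream, which is
+ # compared directly against the published keystream vectors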
+ for (key, nonce, stream) in self.tv:
+ c = ChaCha20.new(key=unhexlify(b(key)), nonce=unhexlify(b(nonce)))
+ ct = unhexlify(b(stream))
+ pt = b("\x00") * len(ct)
+ self.assertEqual(c.encrypt(pt), ct)
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
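+ # `output` must be a pre-allocated writable buffer (bytearray or
+ # writable memoryview) of exactly len(pt) bytes; when it is given,
+ # encrypt()/decrypt() write the result in place and return None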
+
+ key = b'4' * 32
+ nonce = b'5' * 8
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+
+ pt = b'5' * 300
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(len(pt))
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(len(pt)))
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*len(pt))
+
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*len(pt))
+
+ shorter_output = bytearray(len(pt) - 1)
+
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+
+ cipher = ChaCha20.new(key=key, nonce=nonce)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(ChaCha20Test)
+ tests += list_test_cases(XChaCha20Test)
+ tests.append(ChaCha20_AGL_NIR())
+ tests.append(ByteArrayTest())
+ tests.append(MemoryviewTest())
+ tests.append(TestOutput())
+
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_ChaCha20_Poly1305.py b/lib/Crypto/SelfTest/Cipher/test_ChaCha20_Poly1305.py
new file mode 100644
index 0000000..67440d7
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_ChaCha20_Poly1305.py
@@ -0,0 +1,770 @@
+# ===================================================================
+#
+# Copyright (c) 2018, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+from Crypto.Util.py3compat import tobytes
+from Crypto.Cipher import ChaCha20_Poly1305
+from Crypto.Hash import SHAKE128
+
+from Crypto.Util._file_system import pycryptodome_filename
+from Crypto.Util.strxor import strxor
+
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+class ChaCha20Poly1305Tests(unittest.TestCase):
+
+ key_256 = get_tag_random("key_256", 32)
+ nonce_96 = get_tag_random("nonce_96", 12)
+ data_128 = get_tag_random("data_128", 16)
+
+ def test_loopback(self):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_nonce(self):
+ # Nonce can be 8 or 12 bytes (or 24 bytes for the XChaCha20 variant)
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=b'H' * 8)
+ self.assertEqual(len(cipher.nonce), 8)
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=b'H' * 12)
+ self.assertEqual(len(cipher.nonce), 12)
+
+ # If not passed, the nonce is created randomly
+ cipher = ChaCha20_Poly1305.new(key=self.key_256)
+ nonce1 = cipher.nonce
+ cipher = ChaCha20_Poly1305.new(key=self.key_256)
+ nonce2 = cipher.nonce
+ self.assertEqual(len(nonce1), 12)
+ self.assertNotEqual(nonce1, nonce2)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data_128)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ self.assertEqual(ct, cipher.encrypt(self.data_128))
+
+ def test_nonce_must_be_bytes(self):
+ self.assertRaises(TypeError,
+ ChaCha20_Poly1305.new,
+ key=self.key_256,
+ nonce=u'test12345678')
+
+ def test_nonce_length(self):
+ # a nonce of unsupported length is rejected
+ self.assertRaises(ValueError,
+ ChaCha20_Poly1305.new,
+ key=self.key_256,
+ nonce=b'0' * 7)
+ self.assertRaises(ValueError,
+ ChaCha20_Poly1305.new,
+ key=self.key_256,
+ nonce=b'')
+
+ def test_block_size(self):
+ # Not based on block ciphers
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ self.assertFalse(hasattr(cipher, 'block_size'))
+
+ def test_nonce_attribute(self):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ self.assertEqual(cipher.nonce, self.nonce_96)
+
+ # By default, a 12-byte nonce is generated randomly
+ nonce1 = ChaCha20_Poly1305.new(key=self.key_256).nonce
+ nonce2 = ChaCha20_Poly1305.new(key=self.key_256).nonce
+ self.assertEqual(len(nonce1), 12)
+ self.assertNotEqual(nonce1, nonce2)
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError,
+ ChaCha20_Poly1305.new,
+ key=self.key_256,
+ param=9)
+
+ def test_null_encryption_decryption(self):
+ for func in "encrypt", "decrypt":
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ result = getattr(cipher, func)(b"")
+ self.assertEqual(result, b"")
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.encrypt(b"")
+ self.assertRaises(TypeError, cipher.decrypt, b"")
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.decrypt(b"")
+ self.assertRaises(TypeError, cipher.encrypt, b"")
+
+ def test_data_must_be_bytes(self):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
+
+ def test_mac_len(self):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ _, mac = cipher.encrypt_and_digest(self.data_128)
+ self.assertEqual(len(mac), 16)
+
+ def test_invalid_mac(self):
+ from Crypto.Util.strxor import strxor_c
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ ct, mac = cipher.encrypt_and_digest(self.data_128)
+
+ invalid_mac = strxor_c(mac, 0x01)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct,
+ invalid_mac)
+
+ def test_hex_mac(self):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ mac_hex = cipher.hexdigest()
+ self.assertEqual(cipher.digest(), unhexlify(mac_hex))
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.hexverify(mac_hex)
+
+ def test_message_chunks(self):
+ # Validate that both the associated data and the plaintext/ciphertext
+ # can be fed in chunks of arbitrary length
+
+ auth_data = get_tag_random("authenticated data", 127)
+ plaintext = get_tag_random("plaintext", 127)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(auth_data)
+ ciphertext, ref_mac = cipher.encrypt_and_digest(plaintext)
+
+ def break_up(data, chunk_length):
+ return [data[i:i+chunk_length] for i in range(0, len(data),
+ chunk_length)]
+
+ # Encryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ pt2 = b""
+ for chunk in break_up(ciphertext, chunk_length):
+ pt2 += cipher.decrypt(chunk)
+ self.assertEqual(plaintext, pt2)
+ cipher.verify(ref_mac)
+
+ # Decryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ ct2 = b""
+ for chunk in break_up(plaintext, chunk_length):
+ ct2 += cipher.encrypt(chunk)
+ self.assertEqual(ciphertext, ct2)
+ self.assertEqual(cipher.digest(), ref_mac)
+
+ def test_bytearray(self):
+
+ # Encrypt
+ key_ba = bytearray(self.key_256)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data_128)
+ data_ba = bytearray(self.data_128)
+
+ cipher1 = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher1.update(self.data_128)
+ ct = cipher1.encrypt(self.data_128)
+ tag = cipher1.digest()
+
+ cipher2 = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ key_ba[:3] = b'\xFF\xFF\xFF'
+ nonce_ba[:3] = b'\xFF\xFF\xFF'
+ cipher2.update(header_ba)
+ header_ba[:3] = b'\xFF\xFF\xFF'
+ ct_test = cipher2.encrypt(data_ba)
+ data_ba[:3] = b'\x99\x99\x99'
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_ba = bytearray(self.key_256)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data_128)
+ ct_ba = bytearray(ct)
+ tag_ba = bytearray(tag)
+ del data_ba
+
+ cipher3 = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ key_ba[:3] = b'\xFF\xFF\xFF'
+ nonce_ba[:3] = b'\xFF\xFF\xFF'
+ cipher3.update(header_ba)
+ header_ba[:3] = b'\xFF\xFF\xFF'
+ pt_test = cipher3.decrypt(ct_ba)
+ ct_ba[:3] = b'\xFF\xFF\xFF'
+ cipher3.verify(tag_ba)
+
+ self.assertEqual(pt_test, self.data_128)
+
+ def test_memoryview(self):
+
+ # Encrypt
+ key_mv = memoryview(bytearray(self.key_256))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data_128))
+ data_mv = memoryview(bytearray(self.data_128))
+
+ cipher1 = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher1.update(self.data_128)
+ ct = cipher1.encrypt(self.data_128)
+ tag = cipher1.digest()
+
+ cipher2 = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ key_mv[:3] = b'\xFF\xFF\xFF'
+ nonce_mv[:3] = b'\xFF\xFF\xFF'
+ cipher2.update(header_mv)
+ header_mv[:3] = b'\xFF\xFF\xFF'
+ ct_test = cipher2.encrypt(data_mv)
+ data_mv[:3] = b'\x99\x99\x99'
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_mv = memoryview(bytearray(self.key_256))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data_128))
+ ct_mv = memoryview(bytearray(ct))
+ tag_mv = memoryview(bytearray(tag))
+ del data_mv
+
+ cipher3 = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ key_mv[:3] = b'\xFF\xFF\xFF'
+ nonce_mv[:3] = b'\xFF\xFF\xFF'
+ cipher3.update(header_mv)
+ header_mv[:3] = b'\xFF\xFF\xFF'
+ pt_test = cipher3.decrypt(ct_mv)
+ ct_mv[:3] = b'\x99\x99\x99'
+ cipher3.verify(tag_mv)
+
+ self.assertEqual(pt_test, self.data_128)
+
+
+class XChaCha20Poly1305Tests(unittest.TestCase):
+
+ def test_encrypt(self):
+ # From https://tools.ietf.org/html/draft-arciszewski-xchacha-03
+ # Section A.3.1
+
+ pt = b"""
+ 4c616469657320616e642047656e746c656d656e206f662074686520636c6173
+ 73206f66202739393a204966204920636f756c64206f6666657220796f75206f
+ 6e6c79206f6e652074697020666f7220746865206675747572652c2073756e73
+ 637265656e20776f756c642062652069742e"""
+ pt = unhexlify(pt.replace(b"\n", b"").replace(b" ", b""))
+
+ aad = unhexlify(b"50515253c0c1c2c3c4c5c6c7")
+ key = unhexlify(b"808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f")
+ iv = unhexlify(b"404142434445464748494a4b4c4d4e4f5051525354555657")
+
+ ct = b"""
+ bd6d179d3e83d43b9576579493c0e939572a1700252bfaccbed2902c21396cbb
+ 731c7f1b0b4aa6440bf3a82f4eda7e39ae64c6708c54c216cb96b72e1213b452
+ 2f8c9ba40db5d945b11b69b982c1bb9e3f3fac2bc369488f76b2383565d3fff9
+ 21f9664c97637da9768812f615c68b13b52e"""
+ ct = unhexlify(ct.replace(b"\n", b"").replace(b" ", b""))
+
+ tag = unhexlify(b"c0875924c1c7987947deafd8780acf49")
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=iv)
+ cipher.update(aad)
+ ct_test, tag_test = cipher.encrypt_and_digest(pt)
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=iv)
+ cipher.update(aad)
+ cipher.decrypt_and_verify(ct, tag)
+
+
+class ChaCha20Poly1305FSMTests(unittest.TestCase):
+
+ key_256 = get_tag_random("key_256", 32)
+ nonce_96 = get_tag_random("nonce_96", 12)
+ data_128 = get_tag_random("data_128", 16)
+
+ def test_valid_init_encrypt_decrypt_digest_verify(self):
+ # No authenticated data, fixed plaintext
+ # Verify path INIT->ENCRYPT->DIGEST
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data_128)
+ mac = cipher.digest()
+
+ # Verify path INIT->DECRYPT->VERIFY
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_update_digest_verify(self):
+ # No plaintext, fixed authenticated data
+ # Verify path INIT->UPDATE->DIGEST
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->VERIFY
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ cipher.verify(mac)
+
+ def test_valid_full_path(self):
+ # Fixed authenticated data, fixed plaintext
+ # Verify path INIT->UPDATE->ENCRYPT->DIGEST
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ ct = cipher.encrypt(self.data_128)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->DECRYPT->VERIFY
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_digest(self):
+ # Verify path INIT->DIGEST
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.digest()
+
+ def test_valid_init_verify(self):
+ # Verify path INIT->VERIFY
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ mac = cipher.digest()
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.verify(mac)
+
+ def test_valid_multiple_encrypt_or_decrypt(self):
+ for method_name in "encrypt", "decrypt":
+ for auth_data in (None, b"333", self.data_128,
+ self.data_128 + b"3"):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ if auth_data is not None:
+ cipher.update(auth_data)
+ method = getattr(cipher, method_name)
+ method(self.data_128)
+ method(self.data_128)
+ method(self.data_128)
+ method(self.data_128)
+
+ def test_valid_multiple_digest_or_verify(self):
+ # Multiple calls to digest
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ first_mac = cipher.digest()
+ for x in range(4):
+ self.assertEqual(first_mac, cipher.digest())
+
+ # Multiple calls to verify
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ for x in range(5):
+ cipher.verify(first_mac)
+
+ def test_valid_encrypt_and_digest_decrypt_and_verify(self):
+ # encrypt_and_digest
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ ct, mac = cipher.encrypt_and_digest(self.data_128)
+
+ # decrypt_and_verify
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ pt = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(self.data_128, pt)
+
+ def test_invalid_mixing_encrypt_decrypt(self):
+ # Once per method, with or without assoc. data
+ for method1_name, method2_name in (("encrypt", "decrypt"),
+ ("decrypt", "encrypt")):
+ for assoc_data_present in (True, False):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ if assoc_data_present:
+ cipher.update(self.data_128)
+ getattr(cipher, method1_name)(self.data_128)
+ self.assertRaises(TypeError, getattr(cipher, method2_name),
+ self.data_128)
+
+ def test_invalid_encrypt_or_update_after_digest(self):
+ for method_name in "encrypt", "update":
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.encrypt(self.data_128)
+ cipher.digest()
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data_128)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.encrypt_and_digest(self.data_128)
+
+ def test_invalid_decrypt_or_update_after_verify(self):
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data_128)
+ mac = cipher.digest()
+
+ for method_name in "decrypt", "update":
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data_128)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data_128)
+
+ cipher = ChaCha20_Poly1305.new(key=self.key_256,
+ nonce=self.nonce_96)
+ cipher.decrypt_and_verify(ct, mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data_128)
+
+
+def compact(x):
+ return unhexlify(x.replace(" ", "").replace(":", ""))
+
+
+class TestVectorsRFC(unittest.TestCase):
+ """Test cases from RFC7539"""
+
+ # AAD, PT, CT, MAC, KEY, NONCE
+ test_vectors_hex = [
+ ( '50 51 52 53 c0 c1 c2 c3 c4 c5 c6 c7',
+ '4c 61 64 69 65 73 20 61 6e 64 20 47 65 6e 74 6c'
+ '65 6d 65 6e 20 6f 66 20 74 68 65 20 63 6c 61 73'
+ '73 20 6f 66 20 27 39 39 3a 20 49 66 20 49 20 63'
+ '6f 75 6c 64 20 6f 66 66 65 72 20 79 6f 75 20 6f'
+ '6e 6c 79 20 6f 6e 65 20 74 69 70 20 66 6f 72 20'
+ '74 68 65 20 66 75 74 75 72 65 2c 20 73 75 6e 73'
+ '63 72 65 65 6e 20 77 6f 75 6c 64 20 62 65 20 69'
+ '74 2e',
+ 'd3 1a 8d 34 64 8e 60 db 7b 86 af bc 53 ef 7e c2'
+ 'a4 ad ed 51 29 6e 08 fe a9 e2 b5 a7 36 ee 62 d6'
+ '3d be a4 5e 8c a9 67 12 82 fa fb 69 da 92 72 8b'
+ '1a 71 de 0a 9e 06 0b 29 05 d6 a5 b6 7e cd 3b 36'
+ '92 dd bd 7f 2d 77 8b 8c 98 03 ae e3 28 09 1b 58'
+ 'fa b3 24 e4 fa d6 75 94 55 85 80 8b 48 31 d7 bc'
+ '3f f4 de f0 8e 4b 7a 9d e5 76 d2 65 86 ce c6 4b'
+ '61 16',
+ '1a:e1:0b:59:4f:09:e2:6a:7e:90:2e:cb:d0:60:06:91',
+ '80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f'
+ '90 91 92 93 94 95 96 97 98 99 9a 9b 9c 9d 9e 9f',
+ '07 00 00 00' + '40 41 42 43 44 45 46 47',
+ ),
+ ( 'f3 33 88 86 00 00 00 00 00 00 4e 91',
+ '49 6e 74 65 72 6e 65 74 2d 44 72 61 66 74 73 20'
+ '61 72 65 20 64 72 61 66 74 20 64 6f 63 75 6d 65'
+ '6e 74 73 20 76 61 6c 69 64 20 66 6f 72 20 61 20'
+ '6d 61 78 69 6d 75 6d 20 6f 66 20 73 69 78 20 6d'
+ '6f 6e 74 68 73 20 61 6e 64 20 6d 61 79 20 62 65'
+ '20 75 70 64 61 74 65 64 2c 20 72 65 70 6c 61 63'
+ '65 64 2c 20 6f 72 20 6f 62 73 6f 6c 65 74 65 64'
+ '20 62 79 20 6f 74 68 65 72 20 64 6f 63 75 6d 65'
+ '6e 74 73 20 61 74 20 61 6e 79 20 74 69 6d 65 2e'
+ '20 49 74 20 69 73 20 69 6e 61 70 70 72 6f 70 72'
+ '69 61 74 65 20 74 6f 20 75 73 65 20 49 6e 74 65'
+ '72 6e 65 74 2d 44 72 61 66 74 73 20 61 73 20 72'
+ '65 66 65 72 65 6e 63 65 20 6d 61 74 65 72 69 61'
+ '6c 20 6f 72 20 74 6f 20 63 69 74 65 20 74 68 65'
+ '6d 20 6f 74 68 65 72 20 74 68 61 6e 20 61 73 20'
+ '2f e2 80 9c 77 6f 72 6b 20 69 6e 20 70 72 6f 67'
+ '72 65 73 73 2e 2f e2 80 9d',
+ '64 a0 86 15 75 86 1a f4 60 f0 62 c7 9b e6 43 bd'
+ '5e 80 5c fd 34 5c f3 89 f1 08 67 0a c7 6c 8c b2'
+ '4c 6c fc 18 75 5d 43 ee a0 9e e9 4e 38 2d 26 b0'
+ 'bd b7 b7 3c 32 1b 01 00 d4 f0 3b 7f 35 58 94 cf'
+ '33 2f 83 0e 71 0b 97 ce 98 c8 a8 4a bd 0b 94 81'
+ '14 ad 17 6e 00 8d 33 bd 60 f9 82 b1 ff 37 c8 55'
+ '97 97 a0 6e f4 f0 ef 61 c1 86 32 4e 2b 35 06 38'
+ '36 06 90 7b 6a 7c 02 b0 f9 f6 15 7b 53 c8 67 e4'
+ 'b9 16 6c 76 7b 80 4d 46 a5 9b 52 16 cd e7 a4 e9'
+ '90 40 c5 a4 04 33 22 5e e2 82 a1 b0 a0 6c 52 3e'
+ 'af 45 34 d7 f8 3f a1 15 5b 00 47 71 8c bc 54 6a'
+ '0d 07 2b 04 b3 56 4e ea 1b 42 22 73 f5 48 27 1a'
+ '0b b2 31 60 53 fa 76 99 19 55 eb d6 31 59 43 4e'
+ 'ce bb 4e 46 6d ae 5a 10 73 a6 72 76 27 09 7a 10'
+ '49 e6 17 d9 1d 36 10 94 fa 68 f0 ff 77 98 71 30'
+ '30 5b ea ba 2e da 04 df 99 7b 71 4d 6c 6f 2c 29'
+ 'a6 ad 5c b4 02 2b 02 70 9b',
+ 'ee ad 9d 67 89 0c bb 22 39 23 36 fe a1 85 1f 38',
+ '1c 92 40 a5 eb 55 d3 8a f3 33 88 86 04 f6 b5 f0'
+ '47 39 17 c1 40 2b 80 09 9d ca 5c bc 20 70 75 c0',
+ '00 00 00 00 01 02 03 04 05 06 07 08',
+ )
+ ]
+
+ test_vectors = [[unhexlify(x.replace(" ","").replace(":","")) for x in tv] for tv in test_vectors_hex]
+
+ def runTest(self):
+ for assoc_data, pt, ct, mac, key, nonce in self.test_vectors:
+ # Encrypt
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ cipher.update(assoc_data)
+ ct2, mac2 = cipher.encrypt_and_digest(pt)
+ self.assertEqual(ct, ct2)
+ self.assertEqual(mac, mac2)
+
+ # Decrypt
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ cipher.update(assoc_data)
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(pt, pt2)
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._id = "None"
+
+ def load_tests(self, filename):
+
+ def filter_tag(group):
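+ # Wycheproof expresses the tag size in bits; convert it to bytes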
+ return group['tagSize'] // 8
+
+ def filter_algo(root):
+ return root['algorithm']
+
+ result = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ filename,
+ "Wycheproof ChaCha20-Poly1305",
+ root_tag={'algo': filter_algo},
+ group_tag={'tag_size': filter_tag})
+ return result
+
+ def setUp(self):
+ self.tv = []
+ self.tv.extend(self.load_tests("chacha20_poly1305_test.json"))
+ self.tv.extend(self.load_tests("xchacha20_poly1305_test.json"))
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_encrypt(self, tv):
+ self._id = "Wycheproof Encrypt %s Test #%s" % (tv.algo, tv.id)
+
+ try:
+ cipher = ChaCha20_Poly1305.new(key=tv.key, nonce=tv.iv)
+ except ValueError as e:
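+ # construction may only fail for vectors whose nonce length the
+ # cipher does not accept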
+ assert len(tv.iv) not in (8, 12) and "Nonce must be" in str(e)
+ return
+
+ cipher.update(tv.aad)
+ ct, tag = cipher.encrypt_and_digest(tv.msg)
+ if tv.valid:
+ self.assertEqual(ct, tv.ct)
+ self.assertEqual(tag, tv.tag)
+ self.warn(tv)
+
+ def test_decrypt(self, tv):
+ self._id = "Wycheproof Decrypt %s Test #%s" % (tv.algo, tv.id)
+
+ try:
+ cipher = ChaCha20_Poly1305.new(key=tv.key, nonce=tv.iv)
+ except ValueError as e:
+ assert len(tv.iv) not in (8, 12) and "Nonce must be" in str(e)
+ return
+
+ cipher.update(tv.aad)
+ try:
+ pt = cipher.decrypt_and_verify(tv.ct, tv.tag)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+ self.warn(tv)
+
+ def test_corrupt_decrypt(self, tv):
+ self._id = "Wycheproof Corrupt Decrypt ChaCha20-Poly1305 Test #" + str(tv.id)
+ if len(tv.iv) == 0 or len(tv.ct) < 1:
+ return
+ cipher = ChaCha20_Poly1305.new(key=tv.key, nonce=tv.iv)
+ cipher.update(tv.aad)
+ ct_corrupt = strxor(tv.ct, b"\x00" * (len(tv.ct) - 1) + b"\x01")
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct_corrupt, tv.tag)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_encrypt(tv)
+ self.test_decrypt(tv)
+ self.test_corrupt_decrypt(tv)
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
+
+ key = b'4' * 32
+ nonce = b'5' * 12
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+
+ pt = b'5' * 16
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(16)
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(16))
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
+
+ shorter_output = bytearray(7)
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+
+ cipher = ChaCha20_Poly1305.new(key=key, nonce=nonce)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(ChaCha20Poly1305Tests)
+ tests += list_test_cases(XChaCha20Poly1305Tests)
+ tests += list_test_cases(ChaCha20Poly1305FSMTests)
+ tests += [TestVectorsRFC()]
+ tests += [TestVectorsWycheproof(wycheproof_warnings)]
+ tests += [TestOutput()]
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_DES.py b/lib/Crypto/SelfTest/Cipher/test_DES.py
new file mode 100644
index 0000000..ee261bc
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_DES.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/DES.py: Self-test for the (Single) DES cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.DES"""
+
+import unittest
+
+from Crypto.Cipher import DES
+
+# This is a list of (plaintext, ciphertext, key, description) tuples.
+SP800_17_B1_KEY = '01' * 8
+SP800_17_B2_PT = '00' * 8
+test_data = [
+ # Test vectors from Appendix A of NIST SP 800-17
+ # "Modes of Operation Validation System (MOVS): Requirements and Procedures"
+ # http://csrc.nist.gov/publications/nistpubs/800-17/800-17.pdf
+
+ # Appendix A - "Sample Round Outputs for the DES"
+ ('0000000000000000', '82dcbafbdeab6602', '10316e028c8f3b4a',
+ "NIST SP800-17 A"),
+
+ # Table B.1 - Variable Plaintext Known Answer Test
+ ('8000000000000000', '95f8a5e5dd31d900', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #0'),
+ ('4000000000000000', 'dd7f121ca5015619', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #1'),
+ ('2000000000000000', '2e8653104f3834ea', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #2'),
+ ('1000000000000000', '4bd388ff6cd81d4f', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #3'),
+ ('0800000000000000', '20b9e767b2fb1456', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #4'),
+ ('0400000000000000', '55579380d77138ef', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #5'),
+ ('0200000000000000', '6cc5defaaf04512f', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #6'),
+ ('0100000000000000', '0d9f279ba5d87260', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #7'),
+ ('0080000000000000', 'd9031b0271bd5a0a', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #8'),
+ ('0040000000000000', '424250b37c3dd951', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #9'),
+ ('0020000000000000', 'b8061b7ecd9a21e5', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #10'),
+ ('0010000000000000', 'f15d0f286b65bd28', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #11'),
+ ('0008000000000000', 'add0cc8d6e5deba1', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #12'),
+ ('0004000000000000', 'e6d5f82752ad63d1', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #13'),
+ ('0002000000000000', 'ecbfe3bd3f591a5e', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #14'),
+ ('0001000000000000', 'f356834379d165cd', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #15'),
+ ('0000800000000000', '2b9f982f20037fa9', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #16'),
+ ('0000400000000000', '889de068a16f0be6', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #17'),
+ ('0000200000000000', 'e19e275d846a1298', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #18'),
+ ('0000100000000000', '329a8ed523d71aec', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #19'),
+ ('0000080000000000', 'e7fce22557d23c97', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #20'),
+ ('0000040000000000', '12a9f5817ff2d65d', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #21'),
+ ('0000020000000000', 'a484c3ad38dc9c19', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #22'),
+ ('0000010000000000', 'fbe00a8a1ef8ad72', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #23'),
+ ('0000008000000000', '750d079407521363', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #24'),
+ ('0000004000000000', '64feed9c724c2faf', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #25'),
+ ('0000002000000000', 'f02b263b328e2b60', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #26'),
+ ('0000001000000000', '9d64555a9a10b852', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #27'),
+ ('0000000800000000', 'd106ff0bed5255d7', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #28'),
+ ('0000000400000000', 'e1652c6b138c64a5', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #29'),
+ ('0000000200000000', 'e428581186ec8f46', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #30'),
+ ('0000000100000000', 'aeb5f5ede22d1a36', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #31'),
+ ('0000000080000000', 'e943d7568aec0c5c', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #32'),
+ ('0000000040000000', 'df98c8276f54b04b', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #33'),
+ ('0000000020000000', 'b160e4680f6c696f', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #34'),
+ ('0000000010000000', 'fa0752b07d9c4ab8', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #35'),
+ ('0000000008000000', 'ca3a2b036dbc8502', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #36'),
+ ('0000000004000000', '5e0905517bb59bcf', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #37'),
+ ('0000000002000000', '814eeb3b91d90726', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #38'),
+ ('0000000001000000', '4d49db1532919c9f', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #39'),
+ ('0000000000800000', '25eb5fc3f8cf0621', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #40'),
+ ('0000000000400000', 'ab6a20c0620d1c6f', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #41'),
+ ('0000000000200000', '79e90dbc98f92cca', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #42'),
+ ('0000000000100000', '866ecedd8072bb0e', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #43'),
+ ('0000000000080000', '8b54536f2f3e64a8', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #44'),
+ ('0000000000040000', 'ea51d3975595b86b', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #45'),
+ ('0000000000020000', 'caffc6ac4542de31', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #46'),
+ ('0000000000010000', '8dd45a2ddf90796c', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #47'),
+ ('0000000000008000', '1029d55e880ec2d0', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #48'),
+ ('0000000000004000', '5d86cb23639dbea9', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #49'),
+ ('0000000000002000', '1d1ca853ae7c0c5f', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #50'),
+ ('0000000000001000', 'ce332329248f3228', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #51'),
+ ('0000000000000800', '8405d1abe24fb942', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #52'),
+ ('0000000000000400', 'e643d78090ca4207', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #53'),
+ ('0000000000000200', '48221b9937748a23', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #54'),
+ ('0000000000000100', 'dd7c0bbd61fafd54', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #55'),
+ ('0000000000000080', '2fbc291a570db5c4', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #56'),
+ ('0000000000000040', 'e07c30d7e4e26e12', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #57'),
+ ('0000000000000020', '0953e2258e8e90a1', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #58'),
+ ('0000000000000010', '5b711bc4ceebf2ee', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #59'),
+ ('0000000000000008', 'cc083f1e6d9e85f6', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #60'),
+ ('0000000000000004', 'd2fd8867d50d2dfe', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #61'),
+ ('0000000000000002', '06e7ea22ce92708f', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #62'),
+ ('0000000000000001', '166b40b44aba4bd6', SP800_17_B1_KEY,
+ 'NIST SP800-17 B.1 #63'),
+
+ # Table B.2 - Variable Key Known Answer Test
+ (SP800_17_B2_PT, '95a8d72813daa94d', '8001010101010101',
+ 'NIST SP800-17 B.2 #0'),
+ (SP800_17_B2_PT, '0eec1487dd8c26d5', '4001010101010101',
+ 'NIST SP800-17 B.2 #1'),
+ (SP800_17_B2_PT, '7ad16ffb79c45926', '2001010101010101',
+ 'NIST SP800-17 B.2 #2'),
+ (SP800_17_B2_PT, 'd3746294ca6a6cf3', '1001010101010101',
+ 'NIST SP800-17 B.2 #3'),
+ (SP800_17_B2_PT, '809f5f873c1fd761', '0801010101010101',
+ 'NIST SP800-17 B.2 #4'),
+ (SP800_17_B2_PT, 'c02faffec989d1fc', '0401010101010101',
+ 'NIST SP800-17 B.2 #5'),
+ (SP800_17_B2_PT, '4615aa1d33e72f10', '0201010101010101',
+ 'NIST SP800-17 B.2 #6'),
+ (SP800_17_B2_PT, '2055123350c00858', '0180010101010101',
+ 'NIST SP800-17 B.2 #7'),
+ (SP800_17_B2_PT, 'df3b99d6577397c8', '0140010101010101',
+ 'NIST SP800-17 B.2 #8'),
+ (SP800_17_B2_PT, '31fe17369b5288c9', '0120010101010101',
+ 'NIST SP800-17 B.2 #9'),
+ (SP800_17_B2_PT, 'dfdd3cc64dae1642', '0110010101010101',
+ 'NIST SP800-17 B.2 #10'),
+ (SP800_17_B2_PT, '178c83ce2b399d94', '0108010101010101',
+ 'NIST SP800-17 B.2 #11'),
+ (SP800_17_B2_PT, '50f636324a9b7f80', '0104010101010101',
+ 'NIST SP800-17 B.2 #12'),
+ (SP800_17_B2_PT, 'a8468ee3bc18f06d', '0102010101010101',
+ 'NIST SP800-17 B.2 #13'),
+ (SP800_17_B2_PT, 'a2dc9e92fd3cde92', '0101800101010101',
+ 'NIST SP800-17 B.2 #14'),
+ (SP800_17_B2_PT, 'cac09f797d031287', '0101400101010101',
+ 'NIST SP800-17 B.2 #15'),
+ (SP800_17_B2_PT, '90ba680b22aeb525', '0101200101010101',
+ 'NIST SP800-17 B.2 #16'),
+ (SP800_17_B2_PT, 'ce7a24f350e280b6', '0101100101010101',
+ 'NIST SP800-17 B.2 #17'),
+ (SP800_17_B2_PT, '882bff0aa01a0b87', '0101080101010101',
+ 'NIST SP800-17 B.2 #18'),
+ (SP800_17_B2_PT, '25610288924511c2', '0101040101010101',
+ 'NIST SP800-17 B.2 #19'),
+ (SP800_17_B2_PT, 'c71516c29c75d170', '0101020101010101',
+ 'NIST SP800-17 B.2 #20'),
+ (SP800_17_B2_PT, '5199c29a52c9f059', '0101018001010101',
+ 'NIST SP800-17 B.2 #21'),
+ (SP800_17_B2_PT, 'c22f0a294a71f29f', '0101014001010101',
+ 'NIST SP800-17 B.2 #22'),
+ (SP800_17_B2_PT, 'ee371483714c02ea', '0101012001010101',
+ 'NIST SP800-17 B.2 #23'),
+ (SP800_17_B2_PT, 'a81fbd448f9e522f', '0101011001010101',
+ 'NIST SP800-17 B.2 #24'),
+ (SP800_17_B2_PT, '4f644c92e192dfed', '0101010801010101',
+ 'NIST SP800-17 B.2 #25'),
+ (SP800_17_B2_PT, '1afa9a66a6df92ae', '0101010401010101',
+ 'NIST SP800-17 B.2 #26'),
+ (SP800_17_B2_PT, 'b3c1cc715cb879d8', '0101010201010101',
+ 'NIST SP800-17 B.2 #27'),
+ (SP800_17_B2_PT, '19d032e64ab0bd8b', '0101010180010101',
+ 'NIST SP800-17 B.2 #28'),
+ (SP800_17_B2_PT, '3cfaa7a7dc8720dc', '0101010140010101',
+ 'NIST SP800-17 B.2 #29'),
+ (SP800_17_B2_PT, 'b7265f7f447ac6f3', '0101010120010101',
+ 'NIST SP800-17 B.2 #30'),
+ (SP800_17_B2_PT, '9db73b3c0d163f54', '0101010110010101',
+ 'NIST SP800-17 B.2 #31'),
+ (SP800_17_B2_PT, '8181b65babf4a975', '0101010108010101',
+ 'NIST SP800-17 B.2 #32'),
+ (SP800_17_B2_PT, '93c9b64042eaa240', '0101010104010101',
+ 'NIST SP800-17 B.2 #33'),
+ (SP800_17_B2_PT, '5570530829705592', '0101010102010101',
+ 'NIST SP800-17 B.2 #34'),
+ (SP800_17_B2_PT, '8638809e878787a0', '0101010101800101',
+ 'NIST SP800-17 B.2 #35'),
+ (SP800_17_B2_PT, '41b9a79af79ac208', '0101010101400101',
+ 'NIST SP800-17 B.2 #36'),
+ (SP800_17_B2_PT, '7a9be42f2009a892', '0101010101200101',
+ 'NIST SP800-17 B.2 #37'),
+ (SP800_17_B2_PT, '29038d56ba6d2745', '0101010101100101',
+ 'NIST SP800-17 B.2 #38'),
+ (SP800_17_B2_PT, '5495c6abf1e5df51', '0101010101080101',
+ 'NIST SP800-17 B.2 #39'),
+ (SP800_17_B2_PT, 'ae13dbd561488933', '0101010101040101',
+ 'NIST SP800-17 B.2 #40'),
+ (SP800_17_B2_PT, '024d1ffa8904e389', '0101010101020101',
+ 'NIST SP800-17 B.2 #41'),
+ (SP800_17_B2_PT, 'd1399712f99bf02e', '0101010101018001',
+ 'NIST SP800-17 B.2 #42'),
+ (SP800_17_B2_PT, '14c1d7c1cffec79e', '0101010101014001',
+ 'NIST SP800-17 B.2 #43'),
+ (SP800_17_B2_PT, '1de5279dae3bed6f', '0101010101012001',
+ 'NIST SP800-17 B.2 #44'),
+ (SP800_17_B2_PT, 'e941a33f85501303', '0101010101011001',
+ 'NIST SP800-17 B.2 #45'),
+ (SP800_17_B2_PT, 'da99dbbc9a03f379', '0101010101010801',
+ 'NIST SP800-17 B.2 #46'),
+ (SP800_17_B2_PT, 'b7fc92f91d8e92e9', '0101010101010401',
+ 'NIST SP800-17 B.2 #47'),
+ (SP800_17_B2_PT, 'ae8e5caa3ca04e85', '0101010101010201',
+ 'NIST SP800-17 B.2 #48'),
+ (SP800_17_B2_PT, '9cc62df43b6eed74', '0101010101010180',
+ 'NIST SP800-17 B.2 #49'),
+ (SP800_17_B2_PT, 'd863dbb5c59a91a0', '0101010101010140',
+ 'NIST SP800-17 B.2 #50'),
+ (SP800_17_B2_PT, 'a1ab2190545b91d7', '0101010101010120',
+ 'NIST SP800-17 B.2 #51'),
+ (SP800_17_B2_PT, '0875041e64c570f7', '0101010101010110',
+ 'NIST SP800-17 B.2 #52'),
+ (SP800_17_B2_PT, '5a594528bebef1cc', '0101010101010108',
+ 'NIST SP800-17 B.2 #53'),
+ (SP800_17_B2_PT, 'fcdb3291de21f0c0', '0101010101010104',
+ 'NIST SP800-17 B.2 #54'),
+ (SP800_17_B2_PT, '869efd7f9f265a09', '0101010101010102',
+ 'NIST SP800-17 B.2 #55'),
+]
+
+class RonRivestTest(unittest.TestCase):
+ """ Ronald L. Rivest's DES test, see
+ http://people.csail.mit.edu/rivest/Destest.txt
+ ABSTRACT
+ --------
+
+ We present a simple way to test the correctness of a DES implementation:
+ Use the recurrence relation:
+
+ X0 = 9474B8E8C73BCA7D (hexadecimal)
+
+ X(i+1) = IF (i is even) THEN E(Xi,Xi) ELSE D(Xi,Xi)
+
+ to compute a sequence of 64-bit values: X0, X1, X2, ..., X16. Here
+ E(X,K) denotes the DES encryption of X using key K, and D(X,K) denotes
+ the DES decryption of X using key K. If you obtain
+
+ X16 = 1B1A2DDB4C642438
+
+ your implementation does not have any of the 36,568 possible single-fault
+ errors described herein.
+ """
+ def runTest(self):
+ from binascii import b2a_hex
+
+ X = []
+ X[0:] = [b'\x94\x74\xB8\xE8\xC7\x3B\xCA\x7D']
+
+ for i in range(16):
+ c = DES.new(X[i],DES.MODE_ECB)
+ if not (i&1): # (num&1) returns 1 for odd numbers
+ X[i+1:] = [c.encrypt(X[i])] # even
+ else:
+ X[i+1:] = [c.decrypt(X[i])] # odd
+
+ self.assertEqual(b2a_hex(X[16]),
+ b2a_hex(b'\x1B\x1A\x2D\xDB\x4C\x64\x24\x38'))
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
+
+ cipher = DES.new(b'4'*8, DES.MODE_ECB)
+
+ pt = b'5' * 8
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(8)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(8))
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*8)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*8)
+
+ shorter_output = bytearray(7)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ from .common import make_block_tests
+ tests = make_block_tests(DES, "DES", test_data)
+ tests += [RonRivestTest()]
+ tests += [TestOutput()]
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Cipher/test_DES3.py b/lib/Crypto/SelfTest/Cipher/test_DES3.py
new file mode 100644
index 0000000..8d6a648
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_DES3.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/DES3.py: Self-test for the Triple-DES cipher
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.DES3"""
+
+import unittest
+from binascii import hexlify, unhexlify
+
+from Crypto.Cipher import DES3
+
+from Crypto.Util.strxor import strxor_c
+from Crypto.Util.py3compat import bchr, tostr
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+
+# This is a list of (plaintext, ciphertext, key, description) tuples.
+test_data = [
+ # Test vector from Appendix B of NIST SP 800-67
+ # "Recommendation for the Triple Data Encryption Algorithm (TDEA) Block
+ # Cipher"
+ # http://csrc.nist.gov/publications/nistpubs/800-67/SP800-67.pdf
+ ('54686520717566636b2062726f776e20666f78206a756d70',
+ 'a826fd8ce53b855fcce21c8112256fe668d5c05dd9b6b900',
+ '0123456789abcdef23456789abcdef01456789abcdef0123',
+ 'NIST SP800-67 B.1'),
+
+ # This test is designed to test the DES3 API, not the correctness of the
+ # output.
+ ('21e81b7ade88a259', '5c577d4d9b20c0f8',
+ '9b397ebf81b1181e282f4bb8adbadc6b', 'Two-key 3DES'),
+]
+
+# NIST CAVP test vectors
+
+nist_tdes_mmt_files = ("TECBMMT2.rsp", "TECBMMT3.rsp")
+
+for tdes_file in nist_tdes_mmt_files:
+
+ test_vectors = load_test_vectors(
+ ("Cipher", "TDES"),
+ tdes_file,
+ "TDES ECB (%s)" % tdes_file,
+ {"count": lambda x: int(x)}) or []
+
+ for index, tv in enumerate(test_vectors):
+
+ # The test vector file contains some directive lines
+ if isinstance(tv, str):
+ continue
+
+ key = tv.key1 + tv.key2 + tv.key3
+ test_data_item = (tostr(hexlify(tv.plaintext)),
+ tostr(hexlify(tv.ciphertext)),
+ tostr(hexlify(key)),
+ "%s (%s)" % (tdes_file, index))
+ test_data.append(test_data_item)
+
+
+class CheckParity(unittest.TestCase):
+
+ def test_parity_option2(self):
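+ # adjust_key_parity() recomputes the least significant bit of every key
+ # byte so that each byte has odd parity, as DES keys require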
+ before_2k = unhexlify("CABF326FA56734324FFCCABCDEFACABF")
+ after_2k = DES3.adjust_key_parity(before_2k)
+ self.assertEqual(after_2k,
+ unhexlify("CBBF326EA46734324FFDCBBCDFFBCBBF"))
+
+ def test_parity_option3(self):
+ before_3k = unhexlify("AAAAAAAAAAAAAAAABBBBBBBBBBBBBBBBCCCCCCCCCCCCCCCC")
+ after_3k = DES3.adjust_key_parity(before_3k)
+ self.assertEqual(after_3k,
+ unhexlify("ABABABABABABABABBABABABABABABABACDCDCDCDCDCDCDCD"))
+
+ def test_degradation(self):
+ sub_key1 = bchr(1) * 8
+ sub_key2 = bchr(255) * 8
+
+ # K1 == K2
+ self.assertRaises(ValueError, DES3.adjust_key_parity,
+ sub_key1 * 2 + sub_key2)
+
+ # K2 == K3
+ self.assertRaises(ValueError, DES3.adjust_key_parity,
+ sub_key1 + sub_key2 * 2)
+
+ # K1 == K2 == K3
+ self.assertRaises(ValueError, DES3.adjust_key_parity,
+ sub_key1 * 3)
+
+ # K1 == K2 (with different parity)
+ self.assertRaises(ValueError, DES3.adjust_key_parity,
+ sub_key1 + strxor_c(sub_key1, 1) + sub_key2)
+
+
+class DegenerateToDESTest(unittest.TestCase):
+
+ def runTest(self):
+ sub_key1 = bchr(1) * 8
+ sub_key2 = bchr(255) * 8
+
+ # K1 == K2
+ self.assertRaises(ValueError, DES3.new,
+ sub_key1 * 2 + sub_key2,
+ DES3.MODE_ECB)
+
+ # K2 == K3
+ self.assertRaises(ValueError, DES3.new,
+ sub_key1 + sub_key2 * 2,
+ DES3.MODE_ECB)
+
+ # K1 == K2 == K3
+ self.assertRaises(ValueError, DES3.new,
+ sub_key1 * 3,
+ DES3.MODE_ECB)
+
+ # K2 == K3 (parity is ignored)
+ self.assertRaises(ValueError, DES3.new,
+ sub_key1 + sub_key2 + strxor_c(sub_key2, 0x1),
+ DES3.MODE_ECB)
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
+
+ cipher = DES3.new(b'4'*8 + b'G'*8 + b'T'*8, DES3.MODE_ECB)
+
+ pt = b'5' * 16
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(16)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(16))
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*16)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*16)
+
+ shorter_output = bytearray(7)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ from .common import make_block_tests
+
+ tests = []
+ tests += make_block_tests(DES3, "DES3", test_data)
+ tests.append(DegenerateToDESTest())
+ tests += list_test_cases(CheckParity)
+ tests += [TestOutput()]
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+
+ def suite():
+ return unittest.TestSuite(get_tests())
+
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Cipher/test_EAX.py b/lib/Crypto/SelfTest/Cipher/test_EAX.py
new file mode 100644
index 0000000..fe93d71
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_EAX.py
@@ -0,0 +1,773 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+from Crypto.Util.py3compat import tobytes, bchr
+from Crypto.Cipher import AES, DES3
+from Crypto.Hash import SHAKE128
+
+from Crypto.Util.strxor import strxor
+
+
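+# SHAKE128 acts as a deterministic byte generator, so every test run operates
+# on the same "random" keys, nonces and payloads.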
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+class EaxTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ key_192 = get_tag_random("key_192", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data_128 = get_tag_random("data_128", 16)
+
+ def test_loopback_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_loopback_64(self):
+ cipher = DES3.new(self.key_192, DES3.MODE_EAX, nonce=self.nonce_96)
+ pt = get_tag_random("plaintext", 8 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = DES3.new(self.key_192, DES3.MODE_EAX, nonce=self.nonce_96)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_nonce(self):
+ # If not passed, the nonce is created randomly
+ cipher = AES.new(self.key_128, AES.MODE_EAX)
+ nonce1 = cipher.nonce
+ cipher = AES.new(self.key_128, AES.MODE_EAX)
+ nonce2 = cipher.nonce
+ self.assertEqual(len(nonce1), 16)
+ self.assertNotEqual(nonce1, nonce2)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, self.nonce_96)
+ ct = cipher.encrypt(self.data_128)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertEqual(ct, cipher.encrypt(self.data_128))
+
+ def test_nonce_must_be_bytes(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_EAX,
+ nonce=u'test12345678')
+
+ def test_nonce_length(self):
+ # nonce can be of any length (but not empty)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_EAX,
+ nonce=b"")
+
+ for x in range(1, 128):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=bchr(1) * x)
+ cipher.encrypt(bchr(1))
+
+ def test_block_size_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertEqual(cipher.block_size, AES.block_size)
+
+ def test_block_size_64(self):
+ cipher = DES3.new(self.key_192, DES3.MODE_EAX, nonce=self.nonce_96)
+ self.assertEqual(cipher.block_size, DES3.block_size)
+
+ def test_nonce_attribute(self):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertEqual(cipher.nonce, self.nonce_96)
+
+ # By default, a 16-byte nonce is randomly generated
+ nonce1 = AES.new(self.key_128, AES.MODE_EAX).nonce
+ nonce2 = AES.new(self.key_128, AES.MODE_EAX).nonce
+ self.assertEqual(len(nonce1), 16)
+ self.assertNotEqual(nonce1, nonce2)
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_EAX,
+ self.nonce_96, 7)
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96, unknown=7)
+
+ # But some are only known by the base cipher
+ # (e.g. use_aesni consumed by the AES module)
+ AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96,
+ use_aesni=False)
+
+ def test_null_encryption_decryption(self):
+ for func in "encrypt", "decrypt":
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ result = getattr(cipher, func)(b"")
+ self.assertEqual(result, b"")
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.encrypt(b"")
+ self.assertRaises(TypeError, cipher.decrypt, b"")
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.decrypt(b"")
+ self.assertRaises(TypeError, cipher.encrypt, b"")
+
+ def test_data_must_be_bytes(self):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
+
+ def test_mac_len(self):
+ # Invalid MAC length
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96, mac_len=3)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96, mac_len=16+1)
+
+ # Valid MAC length
+ for mac_len in range(5, 16 + 1):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96,
+ mac_len=mac_len)
+ _, mac = cipher.encrypt_and_digest(self.data_128)
+ self.assertEqual(len(mac), mac_len)
+
+ # Default MAC length
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ _, mac = cipher.encrypt_and_digest(self.data_128)
+ self.assertEqual(len(mac), 16)
+
+ def test_invalid_mac(self):
+ from Crypto.Util.strxor import strxor_c
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ ct, mac = cipher.encrypt_and_digest(self.data_128)
+
+ invalid_mac = strxor_c(mac, 0x01)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct,
+ invalid_mac)
+
+ def test_hex_mac(self):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ mac_hex = cipher.hexdigest()
+ self.assertEqual(cipher.digest(), unhexlify(mac_hex))
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.hexverify(mac_hex)
+
+ def test_message_chunks(self):
+ # Validate that both associated data and plaintext/ciphertext
+ # can be broken up in chunks of arbitrary length
+
+ auth_data = get_tag_random("authenticated data", 127)
+ plaintext = get_tag_random("plaintext", 127)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.update(auth_data)
+ ciphertext, ref_mac = cipher.encrypt_and_digest(plaintext)
+
+ def break_up(data, chunk_length):
+ return [data[i:i+chunk_length] for i in range(0, len(data),
+ chunk_length)]
+
+ # Encryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ pt2 = b""
+ for chunk in break_up(ciphertext, chunk_length):
+ pt2 += cipher.decrypt(chunk)
+ self.assertEqual(plaintext, pt2)
+ cipher.verify(ref_mac)
+
+ # Decryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ ct2 = b""
+ for chunk in break_up(plaintext, chunk_length):
+ ct2 += cipher.encrypt(chunk)
+ self.assertEqual(ciphertext, ct2)
+ self.assertEqual(cipher.digest(), ref_mac)
+
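+ # The buffers are mutated right after being handed to the cipher: identical
+ # results prove the cipher made its own copy of key, nonce, header and data.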
+ def test_bytearray(self):
+
+ # Encrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data_128)
+ data_ba = bytearray(self.data_128)
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_EAX,
+ nonce=self.nonce_96)
+ cipher1.update(self.data_128)
+ ct = cipher1.encrypt(self.data_128)
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_ba,
+ AES.MODE_EAX,
+ nonce=nonce_ba)
+ key_ba[:3] = b'\xFF\xFF\xFF'
+ nonce_ba[:3] = b'\xFF\xFF\xFF'
+ cipher2.update(header_ba)
+ header_ba[:3] = b'\xFF\xFF\xFF'
+ ct_test = cipher2.encrypt(data_ba)
+ data_ba[:3] = b'\x99\x99\x99'
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data_128)
+ ct_ba = bytearray(ct)
+ tag_ba = bytearray(tag)
+ del data_ba
+
+ cipher3 = AES.new(key_ba,
+ AES.MODE_EAX,
+ nonce=nonce_ba)
+ key_ba[:3] = b'\xFF\xFF\xFF'
+ nonce_ba[:3] = b'\xFF\xFF\xFF'
+ cipher3.update(header_ba)
+ header_ba[:3] = b'\xFF\xFF\xFF'
+ pt_test = cipher3.decrypt(ct_ba)
+ ct_ba[:3] = b'\xFF\xFF\xFF'
+ cipher3.verify(tag_ba)
+
+ self.assertEqual(pt_test, self.data_128)
+
+ def test_memoryview(self):
+
+ # Encrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data_128))
+ data_mv = memoryview(bytearray(self.data_128))
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_EAX,
+ nonce=self.nonce_96)
+ cipher1.update(self.data_128)
+ ct = cipher1.encrypt(self.data_128)
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_mv,
+ AES.MODE_EAX,
+ nonce=nonce_mv)
+ key_mv[:3] = b'\xFF\xFF\xFF'
+ nonce_mv[:3] = b'\xFF\xFF\xFF'
+ cipher2.update(header_mv)
+ header_mv[:3] = b'\xFF\xFF\xFF'
+ ct_test = cipher2.encrypt(data_mv)
+ data_mv[:3] = b'\x99\x99\x99'
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data_128))
+ ct_mv = memoryview(bytearray(ct))
+ tag_mv = memoryview(bytearray(tag))
+ del data_mv
+
+ cipher3 = AES.new(key_mv,
+ AES.MODE_EAX,
+ nonce=nonce_mv)
+ key_mv[:3] = b'\xFF\xFF\xFF'
+ nonce_mv[:3] = b'\xFF\xFF\xFF'
+ cipher3.update(header_mv)
+ header_mv[:3] = b'\xFF\xFF\xFF'
+ pt_test = cipher3.decrypt(ct_mv)
+ ct_mv[:3] = b'\x99\x99\x99'
+ cipher3.verify(tag_mv)
+
+ self.assertEqual(pt_test, self.data_128)
+
+ def test_output_param(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+ tag = cipher.digest()
+
+ output = bytearray(128)
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ res, tag_out = cipher.encrypt_and_digest(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+ self.assertEqual(tag, tag_out)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ res = cipher.decrypt_and_verify(ct, tag, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ def test_output_param_memoryview(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+
+ output = memoryview(bytearray(128))
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ def test_output_param_neg(self):
+ LEN_PT = 16
+
+ pt = b'5' * LEN_PT
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0' * LEN_PT)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0' * LEN_PT)
+
+ shorter_output = bytearray(LEN_PT - 1)
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
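+# The FSM tests exercise the state machine of the cipher object: encrypt()
+# and decrypt() cannot be mixed on the same object, and neither encrypt()
+# nor update() is accepted once digest() or verify() has been called.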
+class EaxFSMTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data_128 = get_tag_random("data_128", 16)
+
+ def test_valid_init_encrypt_decrypt_digest_verify(self):
+ # No authenticated data, fixed plaintext
+ # Verify path INIT->ENCRYPT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data_128)
+ mac = cipher.digest()
+
+ # Verify path INIT->DECRYPT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_update_digest_verify(self):
+ # No plaintext, fixed authenticated data
+ # Verify path INIT->UPDATE->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ cipher.verify(mac)
+
+ def test_valid_full_path(self):
+ # Fixed authenticated data, fixed plaintext
+ # Verify path INIT->UPDATE->ENCRYPT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ ct = cipher.encrypt(self.data_128)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->DECRYPT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_digest(self):
+ # Verify path INIT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.digest()
+
+ def test_valid_init_verify(self):
+ # Verify path INIT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ mac = cipher.digest()
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.verify(mac)
+
+ def test_valid_multiple_encrypt_or_decrypt(self):
+ for method_name in "encrypt", "decrypt":
+ for auth_data in (None, b"333", self.data_128,
+ self.data_128 + b"3"):
+ if auth_data is None:
+ assoc_len = None
+ else:
+ assoc_len = len(auth_data)
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ if auth_data is not None:
+ cipher.update(auth_data)
+ method = getattr(cipher, method_name)
+ method(self.data_128)
+ method(self.data_128)
+ method(self.data_128)
+ method(self.data_128)
+
+ def test_valid_multiple_digest_or_verify(self):
+ # Multiple calls to digest
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ first_mac = cipher.digest()
+ for x in range(4):
+ self.assertEqual(first_mac, cipher.digest())
+
+ # Multiple calls to verify
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ for x in range(5):
+ cipher.verify(first_mac)
+
+ def test_valid_encrypt_and_digest_decrypt_and_verify(self):
+ # encrypt_and_digest
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ ct, mac = cipher.encrypt_and_digest(self.data_128)
+
+ # decrypt_and_verify
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.update(self.data_128)
+ pt = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(self.data_128, pt)
+
+ def test_invalid_mixing_encrypt_decrypt(self):
+ # Once per method, with or without assoc. data
+ for method1_name, method2_name in (("encrypt", "decrypt"),
+ ("decrypt", "encrypt")):
+ for assoc_data_present in (True, False):
+ cipher = AES.new(self.key_128, AES.MODE_EAX,
+ nonce=self.nonce_96)
+ if assoc_data_present:
+ cipher.update(self.data_128)
+ getattr(cipher, method1_name)(self.data_128)
+ self.assertRaises(TypeError, getattr(cipher, method2_name),
+ self.data_128)
+
+ def test_invalid_encrypt_or_update_after_digest(self):
+ for method_name in "encrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.encrypt(self.data_128)
+ cipher.digest()
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data_128)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.encrypt_and_digest(self.data_128)
+
+ def test_invalid_decrypt_or_update_after_verify(self):
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data_128)
+ mac = cipher.digest()
+
+ for method_name in "decrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data_128)
+
+ cipher = AES.new(self.key_128, AES.MODE_EAX, nonce=self.nonce_96)
+ cipher.decrypt_and_verify(ct, mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data_128)
+
+
+class TestVectorsPaper(unittest.TestCase):
+ """Class exercising the EAX test vectors found in
+ http://www.cs.ucdavis.edu/~rogaway/papers/eax.pdf"""
+
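+ # List of test vectors, each made up of:
+ # - authenticated data
+ # - plaintext
+ # - ciphertext
+ # - MAC
+ # - AES key
+ # - nonce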
+ test_vectors_hex = [
+ ( '6bfb914fd07eae6b',
+ '',
+ '',
+ 'e037830e8389f27b025a2d6527e79d01',
+ '233952dee4d5ed5f9b9c6d6ff80ff478',
+ '62EC67F9C3A4A407FCB2A8C49031A8B3'
+ ),
+ (
+ 'fa3bfd4806eb53fa',
+ 'f7fb',
+ '19dd',
+ '5c4c9331049d0bdab0277408f67967e5',
+ '91945d3f4dcbee0bf45ef52255f095a4',
+ 'BECAF043B0A23D843194BA972C66DEBD'
+ ),
+ ( '234a3463c1264ac6',
+ '1a47cb4933',
+ 'd851d5bae0',
+ '3a59f238a23e39199dc9266626c40f80',
+ '01f74ad64077f2e704c0f60ada3dd523',
+ '70C3DB4F0D26368400A10ED05D2BFF5E'
+ ),
+ (
+ '33cce2eabff5a79d',
+ '481c9e39b1',
+ '632a9d131a',
+ 'd4c168a4225d8e1ff755939974a7bede',
+ 'd07cf6cbb7f313bdde66b727afd3c5e8',
+ '8408DFFF3C1A2B1292DC199E46B7D617'
+ ),
+ (
+ 'aeb96eaebe2970e9',
+ '40d0c07da5e4',
+ '071dfe16c675',
+ 'cb0677e536f73afe6a14b74ee49844dd',
+ '35b6d0580005bbc12b0587124557d2c2',
+ 'FDB6B06676EEDC5C61D74276E1F8E816'
+ ),
+ (
+ 'd4482d1ca78dce0f',
+ '4de3b35c3fc039245bd1fb7d',
+ '835bb4f15d743e350e728414',
+ 'abb8644fd6ccb86947c5e10590210a4f',
+ 'bd8e6e11475e60b268784c38c62feb22',
+ '6EAC5C93072D8E8513F750935E46DA1B'
+ ),
+ (
+ '65d2017990d62528',
+ '8b0a79306c9ce7ed99dae4f87f8dd61636',
+ '02083e3979da014812f59f11d52630da30',
+ '137327d10649b0aa6e1c181db617d7f2',
+ '7c77d6e813bed5ac98baa417477a2e7d',
+ '1A8C98DCD73D38393B2BF1569DEEFC19'
+ ),
+ (
+ '54b9f04e6a09189a',
+ '1bda122bce8a8dbaf1877d962b8592dd2d56',
+ '2ec47b2c4954a489afc7ba4897edcdae8cc3',
+ '3b60450599bd02c96382902aef7f832a',
+ '5fff20cafab119ca2fc73549e20f5b0d',
+ 'DDE59B97D722156D4D9AFF2BC7559826'
+ ),
+ (
+ '899a175897561d7e',
+ '6cf36720872b8513f6eab1a8a44438d5ef11',
+ '0de18fd0fdd91e7af19f1d8ee8733938b1e8',
+ 'e7f6d2231618102fdb7fe55ff1991700',
+ 'a4a4782bcffd3ec5e7ef6d8c34a56123',
+ 'B781FCF2F75FA5A8DE97A9CA48E522EC'
+ ),
+ (
+ '126735fcc320d25a',
+ 'ca40d7446e545ffaed3bd12a740a659ffbbb3ceab7',
+ 'cb8920f87a6c75cff39627b56e3ed197c552d295a7',
+ 'cfc46afc253b4652b1af3795b124ab6e',
+ '8395fcf1e95bebd697bd010bc766aac3',
+ '22E7ADD93CFC6393C57EC0B3C17D6B44'
+ ),
+ ]
+
+ test_vectors = [[unhexlify(x) for x in tv] for tv in test_vectors_hex]
+
+ def runTest(self):
+ for assoc_data, pt, ct, mac, key, nonce in self.test_vectors:
+ # Encrypt
+ cipher = AES.new(key, AES.MODE_EAX, nonce, mac_len=len(mac))
+ cipher.update(assoc_data)
+ ct2, mac2 = cipher.encrypt_and_digest(pt)
+ self.assertEqual(ct, ct2)
+ self.assertEqual(mac, mac2)
+
+ # Decrypt
+ cipher = AES.new(key, AES.MODE_EAX, nonce, mac_len=len(mac))
+ cipher.update(assoc_data)
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(pt, pt2)
+
+
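+# A single TestCase instance drives all Wycheproof vectors: runTest() walks
+# every entry loaded in setUp() and applies the encrypt, decrypt and
+# corrupt-decrypt checks, reporting the vector id via shortDescription().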
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._id = "None"
+
+ def setUp(self):
+
+ def filter_tag(group):
+ return group['tagSize'] // 8
+
+ self.tv = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ "aes_eax_test.json",
+ "Wycheproof EAX",
+ group_tag={'tag_size': filter_tag})
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_encrypt(self, tv):
+ self._id = "Wycheproof Encrypt EAX Test #" + str(tv.id)
+
+ try:
+ cipher = AES.new(tv.key, AES.MODE_EAX, tv.iv, mac_len=tv.tag_size)
+ except ValueError as e:
+ assert len(tv.iv) == 0 and "Nonce cannot be empty" in str(e)
+ return
+
+ cipher.update(tv.aad)
+ ct, tag = cipher.encrypt_and_digest(tv.msg)
+ if tv.valid:
+ self.assertEqual(ct, tv.ct)
+ self.assertEqual(tag, tv.tag)
+ self.warn(tv)
+
+ def test_decrypt(self, tv):
+ self._id = "Wycheproof Decrypt EAX Test #" + str(tv.id)
+
+ try:
+ cipher = AES.new(tv.key, AES.MODE_EAX, tv.iv, mac_len=tv.tag_size)
+ except ValueError as e:
+ assert len(tv.iv) == 0 and "Nonce cannot be empty" in str(e)
+ return
+
+ cipher.update(tv.aad)
+ try:
+ pt = cipher.decrypt_and_verify(tv.ct, tv.tag)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+ self.warn(tv)
+
+ def test_corrupt_decrypt(self, tv):
+ self._id = "Wycheproof Corrupt Decrypt EAX Test #" + str(tv.id)
+ if len(tv.iv) == 0 or len(tv.ct) < 1:
+ return
+ cipher = AES.new(tv.key, AES.MODE_EAX, tv.iv, mac_len=tv.tag_size)
+ cipher.update(tv.aad)
+ ct_corrupt = strxor(tv.ct, b"\x00" * (len(tv.ct) - 1) + b"\x01")
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct_corrupt, tv.tag)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_encrypt(tv)
+ self.test_decrypt(tv)
+ self.test_corrupt_decrypt(tv)
+
+
+class TestOtherCiphers(unittest.TestCase):
+
+ @classmethod
+ def create_test(cls, name, factory, key_size):
+
+ def test_template(self, factory=factory, key_size=key_size):
+ cipher = factory.new(get_tag_random("cipher", key_size),
+ factory.MODE_EAX,
+ nonce=b"nonce")
+ ct, mac = cipher.encrypt_and_digest(b"plaintext")
+
+ cipher = factory.new(get_tag_random("cipher", key_size),
+ factory.MODE_EAX,
+ nonce=b"nonce")
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+
+ self.assertEqual(b"plaintext", pt2)
+
+ setattr(cls, "test_" + name, test_template)
+
+
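+# EAX is defined for any block cipher, not just AES; a loopback test is
+# attached for each legacy cipher and every key size it supports.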
+from Crypto.Cipher import DES, DES3, ARC2, CAST, Blowfish
+
+TestOtherCiphers.create_test("DES_" + str(DES.key_size), DES, DES.key_size)
+for ks in DES3.key_size:
+ TestOtherCiphers.create_test("DES3_" + str(ks), DES3, ks)
+for ks in ARC2.key_size:
+ TestOtherCiphers.create_test("ARC2_" + str(ks), ARC2, ks)
+for ks in CAST.key_size:
+ TestOtherCiphers.create_test("CAST_" + str(ks), CAST, ks)
+for ks in Blowfish.key_size:
+ TestOtherCiphers.create_test("Blowfish_" + str(ks), Blowfish, ks)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(EaxTests)
+ tests += list_test_cases(EaxFSMTests)
+ tests += [ TestVectorsPaper() ]
+ tests += [ TestVectorsWycheproof(wycheproof_warnings) ]
+ tests += list_test_cases(TestOtherCiphers)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_GCM.py b/lib/Crypto/SelfTest/Cipher/test_GCM.py
new file mode 100644
index 0000000..dd8da2f
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_GCM.py
@@ -0,0 +1,951 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from __future__ import print_function
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors, load_test_vectors_wycheproof
+
+from Crypto.Util.py3compat import tobytes, bchr
+from Crypto.Cipher import AES
+from Crypto.Hash import SHAKE128, SHA256
+
+from Crypto.Util.strxor import strxor
+
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+class GcmTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data = get_tag_random("data", 128)
+
+ def test_loopback_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_nonce(self):
+ # Nonce is optional (a random one will be created)
+ AES.new(self.key_128, AES.MODE_GCM)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, self.nonce_96)
+ ct = cipher.encrypt(self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertEqual(ct, cipher.encrypt(self.data))
+
+ def test_nonce_must_be_bytes(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_GCM,
+ nonce=u'test12345678')
+
+ def test_nonce_length(self):
+ # nonce can be of any length (but not empty)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_GCM,
+ nonce=b"")
+
+ for x in range(1, 128):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=bchr(1) * x)
+ cipher.encrypt(bchr(1))
+
+ def test_block_size_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertEqual(cipher.block_size, AES.block_size)
+
+ def test_nonce_attribute(self):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertEqual(cipher.nonce, self.nonce_96)
+
+ # By default, a 16-byte nonce is randomly generated
+ nonce1 = AES.new(self.key_128, AES.MODE_GCM).nonce
+ nonce2 = AES.new(self.key_128, AES.MODE_GCM).nonce
+ self.assertEqual(len(nonce1), 16)
+ self.assertNotEqual(nonce1, nonce2)
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_GCM,
+ self.nonce_96, 7)
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96, unknown=7)
+
+ # But some are only known by the base cipher
+ # (e.g. use_aesni consumed by the AES module)
+ AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96,
+ use_aesni=False)
+
+ def test_null_encryption_decryption(self):
+ for func in "encrypt", "decrypt":
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ result = getattr(cipher, func)(b"")
+ self.assertEqual(result, b"")
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.encrypt(b"")
+ self.assertRaises(TypeError, cipher.decrypt, b"")
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.decrypt(b"")
+ self.assertRaises(TypeError, cipher.encrypt, b"")
+
+ def test_data_must_be_bytes(self):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
+
+ def test_mac_len(self):
+ # Invalid MAC length
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96, mac_len=3)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96, mac_len=16+1)
+
+ # Valid MAC length
+ for mac_len in range(5, 16 + 1):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96,
+ mac_len=mac_len)
+ _, mac = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(len(mac), mac_len)
+
+ # Default MAC length
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ _, mac = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(len(mac), 16)
+
+ def test_invalid_mac(self):
+ from Crypto.Util.strxor import strxor_c
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ invalid_mac = strxor_c(mac, 0x01)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct,
+ invalid_mac)
+
+ def test_hex_mac(self):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ mac_hex = cipher.hexdigest()
+ self.assertEqual(cipher.digest(), unhexlify(mac_hex))
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.hexverify(mac_hex)
+
+ def test_message_chunks(self):
+ # Validate that both associated data and plaintext/ciphertext
+ # can be broken up in chunks of arbitrary length
+
+ auth_data = get_tag_random("authenticated data", 127)
+ plaintext = get_tag_random("plaintext", 127)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.update(auth_data)
+ ciphertext, ref_mac = cipher.encrypt_and_digest(plaintext)
+
+ def break_up(data, chunk_length):
+ return [data[i:i+chunk_length] for i in range(0, len(data),
+ chunk_length)]
+
+ # Encryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ pt2 = b""
+ for chunk in break_up(ciphertext, chunk_length):
+ pt2 += cipher.decrypt(chunk)
+ self.assertEqual(plaintext, pt2)
+ cipher.verify(ref_mac)
+
+ # Decryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ ct2 = b""
+ for chunk in break_up(plaintext, chunk_length):
+ ct2 += cipher.encrypt(chunk)
+ self.assertEqual(ciphertext, ct2)
+ self.assertEqual(cipher.digest(), ref_mac)
+
+ def test_bytearray(self):
+
+ # Encrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data)
+ data_ba = bytearray(self.data)
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_GCM,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct = cipher1.encrypt(self.data)
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_ba,
+ AES.MODE_GCM,
+ nonce=nonce_ba)
+ key_ba[:3] = b"\xFF\xFF\xFF"
+ nonce_ba[:3] = b"\xFF\xFF\xFF"
+ cipher2.update(header_ba)
+ header_ba[:3] = b"\xFF\xFF\xFF"
+ ct_test = cipher2.encrypt(data_ba)
+ data_ba[:3] = b"\xFF\xFF\xFF"
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data)
+ del data_ba
+
+ cipher4 = AES.new(key_ba,
+ AES.MODE_GCM,
+ nonce=nonce_ba)
+ key_ba[:3] = b"\xFF\xFF\xFF"
+ nonce_ba[:3] = b"\xFF\xFF\xFF"
+ cipher4.update(header_ba)
+ header_ba[:3] = b"\xFF\xFF\xFF"
+ pt_test = cipher4.decrypt_and_verify(bytearray(ct_test), bytearray(tag_test))
+
+ self.assertEqual(self.data, pt_test)
+
+ def test_memoryview(self):
+
+ # Encrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data))
+ data_mv = memoryview(bytearray(self.data))
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_GCM,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct = cipher1.encrypt(self.data)
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_mv,
+ AES.MODE_GCM,
+ nonce=nonce_mv)
+ key_mv[:3] = b"\xFF\xFF\xFF"
+ nonce_mv[:3] = b"\xFF\xFF\xFF"
+ cipher2.update(header_mv)
+ header_mv[:3] = b"\xFF\xFF\xFF"
+ ct_test = cipher2.encrypt(data_mv)
+ data_mv[:3] = b"\xFF\xFF\xFF"
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data))
+ del data_mv
+
+ cipher4 = AES.new(key_mv,
+ AES.MODE_GCM,
+ nonce=nonce_mv)
+ key_mv[:3] = b"\xFF\xFF\xFF"
+ nonce_mv[:3] = b"\xFF\xFF\xFF"
+ cipher4.update(header_mv)
+ header_mv[:3] = b"\xFF\xFF\xFF"
+ pt_test = cipher4.decrypt_and_verify(memoryview(ct_test), memoryview(tag_test))
+
+ self.assertEqual(self.data, pt_test)
+
+ def test_output_param(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+ tag = cipher.digest()
+
+ output = bytearray(128)
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ res, tag_out = cipher.encrypt_and_digest(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+ self.assertEqual(tag, tag_out)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ res = cipher.decrypt_and_verify(ct, tag, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ def test_output_param_memoryview(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+
+ output = memoryview(bytearray(128))
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ def test_output_param_neg(self):
+ LEN_PT = 128
+
+ pt = b'5' * LEN_PT
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0' * LEN_PT)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0' * LEN_PT)
+
+ shorter_output = bytearray(LEN_PT - 1)
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+class GcmFSMTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data = get_tag_random("data", 128)
+
+ def test_valid_init_encrypt_decrypt_digest_verify(self):
+ # No authenticated data, fixed plaintext
+ # Verify path INIT->ENCRYPT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->DECRYPT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_update_digest_verify(self):
+ # No plaintext, fixed authenticated data
+ # Verify path INIT->UPDATE->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ cipher.verify(mac)
+
+ def test_valid_full_path(self):
+ # Fixed authenticated data, fixed plaintext
+ # Verify path INIT->UPDATE->ENCRYPT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ ct = cipher.encrypt(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->DECRYPT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+
+ def test_valid_init_digest(self):
+ # Verify path INIT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.digest()
+
+ def test_valid_init_verify(self):
+ # Verify path INIT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ mac = cipher.digest()
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.verify(mac)
+
+ def test_valid_multiple_encrypt_or_decrypt(self):
+ for method_name in "encrypt", "decrypt":
+ for auth_data in (None, b"333", self.data,
+ self.data + b"3"):
+ if auth_data is None:
+ assoc_len = None
+ else:
+ assoc_len = len(auth_data)
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ if auth_data is not None:
+ cipher.update(auth_data)
+ method = getattr(cipher, method_name)
+ method(self.data)
+ method(self.data)
+ method(self.data)
+ method(self.data)
+
+ def test_valid_multiple_digest_or_verify(self):
+ # Multiple calls to digest
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ first_mac = cipher.digest()
+ for x in range(4):
+ self.assertEqual(first_mac, cipher.digest())
+
+ # Multiple calls to verify
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ for x in range(5):
+ cipher.verify(first_mac)
+
+ def test_valid_encrypt_and_digest_decrypt_and_verify(self):
+ # encrypt_and_digest
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ # decrypt_and_verify
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.update(self.data)
+ pt = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(self.data, pt)
+
+ def test_invalid_mixing_encrypt_decrypt(self):
+ # Once per method, with or without assoc. data
+ for method1_name, method2_name in (("encrypt", "decrypt"),
+ ("decrypt", "encrypt")):
+ for assoc_data_present in (True, False):
+ cipher = AES.new(self.key_128, AES.MODE_GCM,
+ nonce=self.nonce_96)
+ if assoc_data_present:
+ cipher.update(self.data)
+ getattr(cipher, method1_name)(self.data)
+ self.assertRaises(TypeError, getattr(cipher, method2_name),
+ self.data)
+
+ def test_invalid_encrypt_or_update_after_digest(self):
+ for method_name in "encrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.encrypt(self.data)
+ cipher.digest()
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.encrypt_and_digest(self.data)
+
+ def test_invalid_decrypt_or_update_after_verify(self):
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data)
+ mac = cipher.digest()
+
+ for method_name in "decrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.verify(mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_GCM, nonce=self.nonce_96)
+ cipher.decrypt_and_verify(ct, mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+
+class TestVectors(unittest.TestCase):
+ """Class exercising the GCM test vectors found in
+ http://csrc.nist.gov/groups/ST/toolkit/BCM/documents/proposedmodes/gcm/gcm-revised-spec.pdf"""
+
+ # List of test vectors, each made up of:
+ # - authenticated data
+ # - plaintext
+ # - ciphertext
+ # - MAC
+ # - AES key
+ # - nonce
+ test_vectors_hex = [
+ (
+ '',
+ '',
+ '',
+ '58e2fccefa7e3061367f1d57a4e7455a',
+ '00000000000000000000000000000000',
+ '000000000000000000000000'
+ ),
+ (
+ '',
+ '00000000000000000000000000000000',
+ '0388dace60b6a392f328c2b971b2fe78',
+ 'ab6e47d42cec13bdf53a67b21257bddf',
+ '00000000000000000000000000000000',
+ '000000000000000000000000'
+ ),
+ (
+ '',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255',
+ '42831ec2217774244b7221b784d0d49ce3aa212f2c02a4e035c17e2329aca12e' +
+ '21d514b25466931c7d8f6a5aac84aa051ba30b396a0aac973d58e091473f5985',
+ '4d5c2af327cd64a62cf35abd2ba6fab4',
+ 'feffe9928665731c6d6a8f9467308308',
+ 'cafebabefacedbaddecaf888'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ '42831ec2217774244b7221b784d0d49ce3aa212f2c02a4e035c17e2329aca12e' +
+ '21d514b25466931c7d8f6a5aac84aa051ba30b396a0aac973d58e091',
+ '5bc94fbc3221a5db94fae95ae7121a47',
+ 'feffe9928665731c6d6a8f9467308308',
+ 'cafebabefacedbaddecaf888'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ '61353b4c2806934a777ff51fa22a4755699b2a714fcdc6f83766e5f97b6c7423' +
+ '73806900e49f24b22b097544d4896b424989b5e1ebac0f07c23f4598',
+ '3612d2e79e3b0785561be14aaca2fccb',
+ 'feffe9928665731c6d6a8f9467308308',
+ 'cafebabefacedbad'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ '8ce24998625615b603a033aca13fb894be9112a5c3a211a8ba262a3cca7e2ca7' +
+ '01e4a9a4fba43c90ccdcb281d48c7c6fd62875d2aca417034c34aee5',
+ '619cc5aefffe0bfa462af43c1699d050',
+ 'feffe9928665731c6d6a8f9467308308',
+ '9313225df88406e555909c5aff5269aa' +
+ '6a7a9538534f7da1e4c303d2a318a728c3c0c95156809539fcf0e2429a6b5254' +
+ '16aedbf5a0de6a57a637b39b'
+ ),
+ (
+ '',
+ '',
+ '',
+ 'cd33b28ac773f74ba00ed1f312572435',
+ '000000000000000000000000000000000000000000000000',
+ '000000000000000000000000'
+ ),
+ (
+ '',
+ '00000000000000000000000000000000',
+ '98e7247c07f0fe411c267e4384b0f600',
+ '2ff58d80033927ab8ef4d4587514f0fb',
+ '000000000000000000000000000000000000000000000000',
+ '000000000000000000000000'
+ ),
+ (
+ '',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255',
+ '3980ca0b3c00e841eb06fac4872a2757859e1ceaa6efd984628593b40ca1e19c' +
+ '7d773d00c144c525ac619d18c84a3f4718e2448b2fe324d9ccda2710acade256',
+ '9924a7c8587336bfb118024db8674a14',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c',
+ 'cafebabefacedbaddecaf888'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ '3980ca0b3c00e841eb06fac4872a2757859e1ceaa6efd984628593b40ca1e19c' +
+ '7d773d00c144c525ac619d18c84a3f4718e2448b2fe324d9ccda2710',
+ '2519498e80f1478f37ba55bd6d27618c',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c',
+ 'cafebabefacedbaddecaf888'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ '0f10f599ae14a154ed24b36e25324db8c566632ef2bbb34f8347280fc4507057' +
+ 'fddc29df9a471f75c66541d4d4dad1c9e93a19a58e8b473fa0f062f7',
+ '65dcc57fcf623a24094fcca40d3533f8',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c',
+ 'cafebabefacedbad'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ 'd27e88681ce3243c4830165a8fdcf9ff1de9a1d8e6b447ef6ef7b79828666e45' +
+ '81e79012af34ddd9e2f037589b292db3e67c036745fa22e7e9b7373b',
+ 'dcf566ff291c25bbb8568fc3d376a6d9',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c',
+ '9313225df88406e555909c5aff5269aa' +
+ '6a7a9538534f7da1e4c303d2a318a728c3c0c95156809539fcf0e2429a6b5254' +
+ '16aedbf5a0de6a57a637b39b'
+ ),
+ (
+ '',
+ '',
+ '',
+ '530f8afbc74536b9a963b4f1c4cb738b',
+ '0000000000000000000000000000000000000000000000000000000000000000',
+ '000000000000000000000000'
+ ),
+ (
+ '',
+ '00000000000000000000000000000000',
+ 'cea7403d4d606b6e074ec5d3baf39d18',
+ 'd0d1c8a799996bf0265b98b5d48ab919',
+ '0000000000000000000000000000000000000000000000000000000000000000',
+ '000000000000000000000000'
+ ),
+ ( '',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b391aafd255',
+ '522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa' +
+ '8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662898015ad',
+ 'b094dac5d93471bdec1a502270e3cc6c',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308',
+ 'cafebabefacedbaddecaf888'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ '522dc1f099567d07f47f37a32a84427d643a8cdcbfe5c0c97598a2bd2555d1aa' +
+ '8cb08e48590dbb3da7b08b1056828838c5f61e6393ba7a0abcc9f662',
+ '76fc6ece0f4e1768cddf8853bb2d551b',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308',
+ 'cafebabefacedbaddecaf888'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ 'c3762df1ca787d32ae47c13bf19844cbaf1ae14d0b976afac52ff7d79bba9de0' +
+ 'feb582d33934a4f0954cc2363bc73f7862ac430e64abe499f47c9b1f',
+ '3a337dbf46a792c45e454913fe2ea8f2',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308',
+ 'cafebabefacedbad'
+ ),
+ (
+ 'feedfacedeadbeeffeedfacedeadbeefabaddad2',
+ 'd9313225f88406e5a55909c5aff5269a86a7a9531534f7da2e4c303d8a318a72' +
+ '1c3c0c95956809532fcf0e2449a6b525b16aedf5aa0de657ba637b39',
+ '5a8def2f0c9e53f1f75d7853659e2a20eeb2b22aafde6419a058ab4f6f746bf4' +
+ '0fc0c3b780f244452da3ebf1c5d82cdea2418997200ef82e44ae7e3f',
+ 'a44a8266ee1c8eb0c8b5d4cf5ae9f19a',
+ 'feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308',
+ '9313225df88406e555909c5aff5269aa' +
+ '6a7a9538534f7da1e4c303d2a318a728c3c0c95156809539fcf0e2429a6b5254' +
+ '16aedbf5a0de6a57a637b39b'
+ )
+ ]
+
+ test_vectors = [[unhexlify(x) for x in tv] for tv in test_vectors_hex]
+
+ def runTest(self):
+ for assoc_data, pt, ct, mac, key, nonce in self.test_vectors:
+
+ # Encrypt
+ cipher = AES.new(key, AES.MODE_GCM, nonce, mac_len=len(mac))
+ cipher.update(assoc_data)
+ ct2, mac2 = cipher.encrypt_and_digest(pt)
+ self.assertEqual(ct, ct2)
+ self.assertEqual(mac, mac2)
+
+ # Decrypt
+ cipher = AES.new(key, AES.MODE_GCM, nonce, mac_len=len(mac))
+ cipher.update(assoc_data)
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(pt, pt2)
+
+
+class TestVectorsGueronKrasnov(unittest.TestCase):
+ """Class exercising the GCM test vectors found in
+ 'The fragility of AES-GCM authentication algorithm', Gueron, Krasnov
+ https://eprint.iacr.org/2013/157.pdf"""
+
+ def test_1(self):
+ key = unhexlify("3da6c536d6295579c0959a7043efb503")
+ iv = unhexlify("2b926197d34e091ef722db94")
+ aad = unhexlify("00000000000000000000000000000000" +
+ "000102030405060708090a0b0c0d0e0f" +
+ "101112131415161718191a1b1c1d1e1f" +
+ "202122232425262728292a2b2c2d2e2f" +
+ "303132333435363738393a3b3c3d3e3f")
+ digest = unhexlify("69dd586555ce3fcc89663801a71d957b")
+
+ cipher = AES.new(key, AES.MODE_GCM, iv).update(aad)
+ self.assertEqual(digest, cipher.digest())
+
+ def test_2(self):
+ key = unhexlify("843ffcf5d2b72694d19ed01d01249412")
+ iv = unhexlify("dbcca32ebf9b804617c3aa9e")
+ aad = unhexlify("00000000000000000000000000000000" +
+ "101112131415161718191a1b1c1d1e1f")
+ pt = unhexlify("000102030405060708090a0b0c0d0e0f" +
+ "101112131415161718191a1b1c1d1e1f" +
+ "202122232425262728292a2b2c2d2e2f" +
+ "303132333435363738393a3b3c3d3e3f" +
+ "404142434445464748494a4b4c4d4e4f")
+ ct = unhexlify("6268c6fa2a80b2d137467f092f657ac0" +
+ "4d89be2beaa623d61b5a868c8f03ff95" +
+ "d3dcee23ad2f1ab3a6c80eaf4b140eb0" +
+ "5de3457f0fbc111a6b43d0763aa422a3" +
+ "013cf1dc37fe417d1fbfc449b75d4cc5")
+ digest = unhexlify("3b629ccfbc1119b7319e1dce2cd6fd6d")
+
+ cipher = AES.new(key, AES.MODE_GCM, iv).update(aad)
+ ct2, digest2 = cipher.encrypt_and_digest(pt)
+
+ self.assertEqual(ct, ct2)
+ self.assertEqual(digest, digest2)
+
+
+class NISTTestVectorsGCM(unittest.TestCase):
+
+ def __init__(self, a):
+ self.use_clmul = True
+ unittest.TestCase.__init__(self, a)
+
+
+class NISTTestVectorsGCM_no_clmul(unittest.TestCase):
+
+ def __init__(self, a):
+ self.use_clmul = False
+ unittest.TestCase.__init__(self, a)
+
+
+test_vectors_nist = load_test_vectors(
+ ("Cipher", "AES"),
+ "gcmDecrypt128.rsp",
+ "GCM decrypt",
+ {"count": lambda x: int(x)}) or []
+
+test_vectors_nist += load_test_vectors(
+ ("Cipher", "AES"),
+ "gcmEncryptExtIV128.rsp",
+ "GCM encrypt",
+ {"count": lambda x: int(x)}) or []
+
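+# One test method per NIST vector is attached to both classes via setattr(),
+# so the same vectors are checked with and without CLMUL (PCLMULQDQ)
+# acceleration.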
+for idx, tv in enumerate(test_vectors_nist):
+
+ # The test vector file contains some directive lines
+ if isinstance(tv, str):
+ continue
+
+ def single_test(self, tv=tv):
+
+ self.description = tv.desc
+ cipher = AES.new(tv.key, AES.MODE_GCM, nonce=tv.iv,
+ mac_len=len(tv.tag), use_clmul=self.use_clmul)
+ cipher.update(tv.aad)
+ if "FAIL" in tv.others:
+ self.assertRaises(ValueError, cipher.decrypt_and_verify,
+ tv.ct, tv.tag)
+ else:
+ pt = cipher.decrypt_and_verify(tv.ct, tv.tag)
+ self.assertEqual(pt, tv.pt)
+
+ setattr(NISTTestVectorsGCM, "test_%d" % idx, single_test)
+ setattr(NISTTestVectorsGCM_no_clmul, "test_%d" % idx, single_test)
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings, **extra_params):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._extra_params = extra_params
+ self._id = "None"
+
+ def setUp(self):
+
+ def filter_tag(group):
+ return group['tagSize'] // 8
+
+ self.tv = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ "aes_gcm_test.json",
+ "Wycheproof GCM",
+ group_tag={'tag_size': filter_tag})
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_encrypt(self, tv):
+ self._id = "Wycheproof Encrypt GCM Test #" + str(tv.id)
+
+ try:
+ cipher = AES.new(tv.key, AES.MODE_GCM, tv.iv, mac_len=tv.tag_size,
+ **self._extra_params)
+ except ValueError as e:
+ if len(tv.iv) == 0 and "Nonce cannot be empty" in str(e):
+ return
+ raise e
+
+ cipher.update(tv.aad)
+ ct, tag = cipher.encrypt_and_digest(tv.msg)
+ if tv.valid:
+ self.assertEqual(ct, tv.ct)
+ self.assertEqual(tag, tv.tag)
+ self.warn(tv)
+
+ def test_decrypt(self, tv):
+ self._id = "Wycheproof Decrypt GCM Test #" + str(tv.id)
+
+ try:
+ cipher = AES.new(tv.key, AES.MODE_GCM, tv.iv, mac_len=tv.tag_size,
+ **self._extra_params)
+ except ValueError as e:
+ if len(tv.iv) == 0 and "Nonce cannot be empty" in str(e):
+ return
+ raise e
+
+ cipher.update(tv.aad)
+ try:
+ pt = cipher.decrypt_and_verify(tv.ct, tv.tag)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+ self.warn(tv)
+
+ def test_corrupt_decrypt(self, tv):
+ self._id = "Wycheproof Corrupt Decrypt GCM Test #" + str(tv.id)
+ if len(tv.iv) == 0 or len(tv.ct) < 1:
+ return
+ cipher = AES.new(tv.key, AES.MODE_GCM, tv.iv, mac_len=tv.tag_size,
+ **self._extra_params)
+ cipher.update(tv.aad)
+ ct_corrupt = strxor(tv.ct, b"\x00" * (len(tv.ct) - 1) + b"\x01")
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct_corrupt, tv.tag)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_encrypt(tv)
+ self.test_decrypt(tv)
+ self.test_corrupt_decrypt(tv)
+
+
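+# Regression-style check: ciphertexts and tags for message lengths 0..159
+# (each with its own nonce) are fed into SHA256 and the final digest is
+# compared against a fixed reference value.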
+class TestVariableLength(unittest.TestCase):
+
+ def __init__(self, **extra_params):
+ unittest.TestCase.__init__(self)
+ self._extra_params = extra_params
+
+ def runTest(self):
+ key = b'0' * 16
+ h = SHA256.new()
+
+ for length in range(160):
+ nonce = '{0:04d}'.format(length).encode('utf-8')
+ data = bchr(length) * length
+ cipher = AES.new(key, AES.MODE_GCM, nonce=nonce, **self._extra_params)
+ ct, tag = cipher.encrypt_and_digest(data)
+ h.update(ct)
+ h.update(tag)
+
+ self.assertEqual(h.hexdigest(), "7b7eb1ffbe67a2e53a912067c0ec8e62ebc7ce4d83490ea7426941349811bdf4")
+
+
+def get_tests(config={}):
+ from Crypto.Util import _cpu_features
+
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(GcmTests)
+ tests += list_test_cases(GcmFSMTests)
+ tests += [TestVectors()]
+ tests += [TestVectorsWycheproof(wycheproof_warnings)]
+ tests += list_test_cases(TestVectorsGueronKrasnov)
+ tests += [TestVariableLength()]
+ if config.get('slow_tests'):
+ tests += list_test_cases(NISTTestVectorsGCM)
+
+ if _cpu_features.have_clmul():
+ tests += [TestVectorsWycheproof(wycheproof_warnings, use_clmul=False)]
+ tests += [TestVariableLength(use_clmul=False)]
+ if config.get('slow_tests'):
+ tests += list_test_cases(NISTTestVectorsGCM_no_clmul)
+ else:
+ print("Skipping test of PCLMULDQD in AES GCM")
+
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_OCB.py b/lib/Crypto/SelfTest/Cipher/test_OCB.py
new file mode 100644
index 0000000..3a89122
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_OCB.py
@@ -0,0 +1,742 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import os
+import re
+import unittest
+from binascii import hexlify, unhexlify
+
+from Crypto.Util.py3compat import b, tobytes, bchr
+from Crypto.Util.strxor import strxor_c
+from Crypto.Util.number import long_to_bytes
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Cipher import AES
+from Crypto.Hash import SHAKE128
+
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+class OcbTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data = get_tag_random("data", 128)
+
+ def test_loopback_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct, mac = cipher.encrypt_and_digest(pt)
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(pt, pt2)
+
+ def test_nonce(self):
+ # Nonce is optional
+ AES.new(self.key_128, AES.MODE_OCB)
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, self.nonce_96)
+ ct = cipher.encrypt(self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ self.assertEqual(ct, cipher.encrypt(self.data))
+
+ def test_nonce_must_be_bytes(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_OCB,
+ nonce=u'test12345678')
+
+ def test_nonce_length(self):
+ # nonce cannot be empty
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_OCB,
+ nonce=b(""))
+
+ # nonce can be up to 15 bytes long
+ for length in range(1, 16):
+ AES.new(self.key_128, AES.MODE_OCB, nonce=self.data[:length])
+
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_OCB,
+ nonce=self.data)
+
+ def test_block_size_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ self.assertEqual(cipher.block_size, AES.block_size)
+
+ def test_nonce_attribute(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ self.assertEqual(cipher.nonce, self.nonce_96)
+
+ # By default, a 15 bytes long nonce is randomly generated
+ nonce1 = AES.new(self.key_128, AES.MODE_OCB).nonce
+ nonce2 = AES.new(self.key_128, AES.MODE_OCB).nonce
+ self.assertEqual(len(nonce1), 15)
+ self.assertNotEqual(nonce1, nonce2)
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_OCB,
+ self.nonce_96, 7)
+ self.assertRaises(TypeError, AES.new, self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96, unknown=7)
+
+ # But some are only known by the base cipher
+ # (e.g. use_aesni consumed by the AES module)
+ AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96,
+ use_aesni=False)
+
+ def test_null_encryption_decryption(self):
+ for func in "encrypt", "decrypt":
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ result = getattr(cipher, func)(b(""))
+ self.assertEqual(result, b(""))
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.encrypt(b("xyz"))
+ self.assertRaises(TypeError, cipher.decrypt, b("xyz"))
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.decrypt(b("xyz"))
+ self.assertRaises(TypeError, cipher.encrypt, b("xyz"))
+
+ def test_data_must_be_bytes(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, u'test1234567890-*')
+
+ def test_mac_len(self):
+ # Invalid MAC length
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96, mac_len=7)
+ self.assertRaises(ValueError, AES.new, self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96, mac_len=16+1)
+
+ # Valid MAC length
+ for mac_len in range(8, 16 + 1):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96,
+ mac_len=mac_len)
+ _, mac = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(len(mac), mac_len)
+
+ # Default MAC length
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ _, mac = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(len(mac), 16)
+
+ def test_invalid_mac(self):
+ from Crypto.Util.strxor import strxor_c
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ invalid_mac = strxor_c(mac, 0x01)
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct,
+ invalid_mac)
+
+ def test_hex_mac(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ mac_hex = cipher.hexdigest()
+ self.assertEqual(cipher.digest(), unhexlify(mac_hex))
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.hexverify(mac_hex)
+
+ def test_message_chunks(self):
+ # Validate that both associated data and plaintext/ciphertext
+ # can be broken up in chunks of arbitrary length
+
+ auth_data = get_tag_random("authenticated data", 127)
+ plaintext = get_tag_random("plaintext", 127)
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.update(auth_data)
+ ciphertext, ref_mac = cipher.encrypt_and_digest(plaintext)
+
+ def break_up(data, chunk_length):
+ return [data[i:i+chunk_length] for i in range(0, len(data),
+ chunk_length)]
+
+        # Decryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ pt2 = b("")
+ for chunk in break_up(ciphertext, chunk_length):
+ pt2 += cipher.decrypt(chunk)
+ pt2 += cipher.decrypt()
+ self.assertEqual(plaintext, pt2)
+ cipher.verify(ref_mac)
+
+        # Encryption
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+
+ for chunk in break_up(auth_data, chunk_length):
+ cipher.update(chunk)
+ ct2 = b("")
+ for chunk in break_up(plaintext, chunk_length):
+ ct2 += cipher.encrypt(chunk)
+ ct2 += cipher.encrypt()
+ self.assertEqual(ciphertext, ct2)
+ self.assertEqual(cipher.digest(), ref_mac)
+
+ def test_bytearray(self):
+
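+        # The key, nonce, header and data buffers are overwritten right after
+        # being handed to the cipher; the outputs must stay identical, which
+        # shows the mode copies its inputs rather than keeping references.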
+ # Encrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data)
+ data_ba = bytearray(self.data)
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct = cipher1.encrypt(self.data) + cipher1.encrypt()
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_ba,
+ AES.MODE_OCB,
+ nonce=nonce_ba)
+ key_ba[:3] = b"\xFF\xFF\xFF"
+ nonce_ba[:3] = b"\xFF\xFF\xFF"
+ cipher2.update(header_ba)
+ header_ba[:3] = b"\xFF\xFF\xFF"
+ ct_test = cipher2.encrypt(data_ba) + cipher2.encrypt()
+ data_ba[:3] = b"\xFF\xFF\xFF"
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_ba = bytearray(self.key_128)
+ nonce_ba = bytearray(self.nonce_96)
+ header_ba = bytearray(self.data)
+ del data_ba
+
+ cipher4 = AES.new(key_ba,
+ AES.MODE_OCB,
+ nonce=nonce_ba)
+ key_ba[:3] = b"\xFF\xFF\xFF"
+ nonce_ba[:3] = b"\xFF\xFF\xFF"
+ cipher4.update(header_ba)
+ header_ba[:3] = b"\xFF\xFF\xFF"
+ pt_test = cipher4.decrypt_and_verify(bytearray(ct_test), bytearray(tag_test))
+
+ self.assertEqual(self.data, pt_test)
+
+ def test_memoryview(self):
+
+ # Encrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data))
+ data_mv = memoryview(bytearray(self.data))
+
+ cipher1 = AES.new(self.key_128,
+ AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct = cipher1.encrypt(self.data) + cipher1.encrypt()
+ tag = cipher1.digest()
+
+ cipher2 = AES.new(key_mv,
+ AES.MODE_OCB,
+ nonce=nonce_mv)
+ key_mv[:3] = b"\xFF\xFF\xFF"
+ nonce_mv[:3] = b"\xFF\xFF\xFF"
+ cipher2.update(header_mv)
+ header_mv[:3] = b"\xFF\xFF\xFF"
+ ct_test = cipher2.encrypt(data_mv) + cipher2.encrypt()
+ data_mv[:3] = b"\xFF\xFF\xFF"
+ tag_test = cipher2.digest()
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key_mv = memoryview(bytearray(self.key_128))
+ nonce_mv = memoryview(bytearray(self.nonce_96))
+ header_mv = memoryview(bytearray(self.data))
+ del data_mv
+
+ cipher4 = AES.new(key_mv,
+ AES.MODE_OCB,
+ nonce=nonce_mv)
+ key_mv[:3] = b"\xFF\xFF\xFF"
+ nonce_mv[:3] = b"\xFF\xFF\xFF"
+ cipher4.update(header_mv)
+ header_mv[:3] = b"\xFF\xFF\xFF"
+ pt_test = cipher4.decrypt_and_verify(memoryview(ct_test), memoryview(tag_test))
+
+ self.assertEqual(self.data, pt_test)
+
+
+class OcbFSMTests(unittest.TestCase):
+
+ key_128 = get_tag_random("key_128", 16)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data = get_tag_random("data", 128)
+
+ def test_valid_init_encrypt_decrypt_digest_verify(self):
+ # No authenticated data, fixed plaintext
+ # Verify path INIT->ENCRYPT->ENCRYPT(NONE)->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data)
+ ct += cipher.encrypt()
+ mac = cipher.digest()
+
+        # Verify path INIT->DECRYPT->DECRYPT(NONE)->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.decrypt()
+ cipher.verify(mac)
+
+ def test_invalid_init_encrypt_decrypt_digest_verify(self):
+ # No authenticated data, fixed plaintext
+ # Verify path INIT->ENCRYPT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data)
+ self.assertRaises(TypeError, cipher.digest)
+
+ # Verify path INIT->DECRYPT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ self.assertRaises(TypeError, cipher.verify)
+
+ def test_valid_init_update_digest_verify(self):
+ # No plaintext, fixed authenticated data
+ # Verify path INIT->UPDATE->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ cipher.verify(mac)
+
+ def test_valid_full_path(self):
+ # Fixed authenticated data, fixed plaintext
+ # Verify path INIT->UPDATE->ENCRYPT->ENCRYPT(NONE)->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ ct = cipher.encrypt(self.data)
+ ct += cipher.encrypt()
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->DECRYPT->DECRYPT(NONE)->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ cipher.decrypt(ct)
+ cipher.decrypt()
+ cipher.verify(mac)
+
+ def test_invalid_encrypt_after_final(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ cipher.encrypt(self.data)
+ cipher.encrypt()
+ self.assertRaises(TypeError, cipher.encrypt, self.data)
+
+ def test_invalid_decrypt_after_final(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ cipher.decrypt(self.data)
+ cipher.decrypt()
+ self.assertRaises(TypeError, cipher.decrypt, self.data)
+
+ def test_valid_init_digest(self):
+ # Verify path INIT->DIGEST
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.digest()
+
+ def test_valid_init_verify(self):
+ # Verify path INIT->VERIFY
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ mac = cipher.digest()
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.verify(mac)
+
+ def test_valid_multiple_encrypt_or_decrypt(self):
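+        # OCB allows encrypt()/decrypt() to be called repeatedly; the final
+        # call without arguments flushes any buffered partial block.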
+ for method_name in "encrypt", "decrypt":
+ for auth_data in (None, b("333"), self.data,
+ self.data + b("3")):
+ if auth_data is None:
+ assoc_len = None
+ else:
+ assoc_len = len(auth_data)
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ if auth_data is not None:
+ cipher.update(auth_data)
+ method = getattr(cipher, method_name)
+ method(self.data)
+ method(self.data)
+ method(self.data)
+ method(self.data)
+ method()
+
+ def test_valid_multiple_digest_or_verify(self):
+ # Multiple calls to digest
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.update(self.data)
+ first_mac = cipher.digest()
+ for x in range(4):
+ self.assertEqual(first_mac, cipher.digest())
+
+ # Multiple calls to verify
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.update(self.data)
+ for x in range(5):
+ cipher.verify(first_mac)
+
+ def test_valid_encrypt_and_digest_decrypt_and_verify(self):
+ # encrypt_and_digest
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.update(self.data)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ # decrypt_and_verify
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.update(self.data)
+ pt = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(self.data, pt)
+
+ def test_invalid_mixing_encrypt_decrypt(self):
+ # Once per method, with or without assoc. data
+ for method1_name, method2_name in (("encrypt", "decrypt"),
+ ("decrypt", "encrypt")):
+ for assoc_data_present in (True, False):
+ cipher = AES.new(self.key_128, AES.MODE_OCB,
+ nonce=self.nonce_96)
+ if assoc_data_present:
+ cipher.update(self.data)
+ getattr(cipher, method1_name)(self.data)
+ self.assertRaises(TypeError, getattr(cipher, method2_name),
+ self.data)
+
+ def test_invalid_encrypt_or_update_after_digest(self):
+ for method_name in "encrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.encrypt(self.data)
+ cipher.encrypt()
+ cipher.digest()
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.encrypt_and_digest(self.data)
+
+ def test_invalid_decrypt_or_update_after_verify(self):
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ ct = cipher.encrypt(self.data)
+ ct += cipher.encrypt()
+ mac = cipher.digest()
+
+ for method_name in "decrypt", "update":
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.decrypt(ct)
+ cipher.decrypt()
+ cipher.verify(mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+ cipher = AES.new(self.key_128, AES.MODE_OCB, nonce=self.nonce_96)
+ cipher.decrypt_and_verify(ct, mac)
+ self.assertRaises(TypeError, getattr(cipher, method_name),
+ self.data)
+
+
+class OcbRfc7253Test(unittest.TestCase):
+
+    # Each item is a tuple with:
+    # - nonce
+    # - authenticated data
+    # - plaintext
+    # - ciphertext concatenated with the 16-byte MAC tag
+ tv1_key = "000102030405060708090A0B0C0D0E0F"
+ tv1 = (
+ (
+ "BBAA99887766554433221100",
+ "",
+ "",
+ "785407BFFFC8AD9EDCC5520AC9111EE6"
+ ),
+ (
+ "BBAA99887766554433221101",
+ "0001020304050607",
+ "0001020304050607",
+ "6820B3657B6F615A5725BDA0D3B4EB3A257C9AF1F8F03009"
+ ),
+ (
+ "BBAA99887766554433221102",
+ "0001020304050607",
+ "",
+ "81017F8203F081277152FADE694A0A00"
+ ),
+ (
+ "BBAA99887766554433221103",
+ "",
+ "0001020304050607",
+ "45DD69F8F5AAE72414054CD1F35D82760B2CD00D2F99BFA9"
+ ),
+ (
+ "BBAA99887766554433221104",
+ "000102030405060708090A0B0C0D0E0F",
+ "000102030405060708090A0B0C0D0E0F",
+ "571D535B60B277188BE5147170A9A22C3AD7A4FF3835B8C5"
+ "701C1CCEC8FC3358"
+ ),
+ (
+ "BBAA99887766554433221105",
+ "000102030405060708090A0B0C0D0E0F",
+ "",
+ "8CF761B6902EF764462AD86498CA6B97"
+ ),
+ (
+ "BBAA99887766554433221106",
+ "",
+ "000102030405060708090A0B0C0D0E0F",
+ "5CE88EC2E0692706A915C00AEB8B2396F40E1C743F52436B"
+ "DF06D8FA1ECA343D"
+ ),
+ (
+ "BBAA99887766554433221107",
+ "000102030405060708090A0B0C0D0E0F1011121314151617",
+ "000102030405060708090A0B0C0D0E0F1011121314151617",
+ "1CA2207308C87C010756104D8840CE1952F09673A448A122"
+ "C92C62241051F57356D7F3C90BB0E07F"
+ ),
+ (
+ "BBAA99887766554433221108",
+ "000102030405060708090A0B0C0D0E0F1011121314151617",
+ "",
+ "6DC225A071FC1B9F7C69F93B0F1E10DE"
+ ),
+ (
+ "BBAA99887766554433221109",
+ "",
+ "000102030405060708090A0B0C0D0E0F1011121314151617",
+ "221BD0DE7FA6FE993ECCD769460A0AF2D6CDED0C395B1C3C"
+ "E725F32494B9F914D85C0B1EB38357FF"
+ ),
+ (
+ "BBAA9988776655443322110A",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F",
+ "BD6F6C496201C69296C11EFD138A467ABD3C707924B964DE"
+ "AFFC40319AF5A48540FBBA186C5553C68AD9F592A79A4240"
+ ),
+ (
+ "BBAA9988776655443322110B",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F",
+ "",
+ "FE80690BEE8A485D11F32965BC9D2A32"
+ ),
+ (
+ "BBAA9988776655443322110C",
+ "",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F",
+ "2942BFC773BDA23CABC6ACFD9BFD5835BD300F0973792EF4"
+ "6040C53F1432BCDFB5E1DDE3BC18A5F840B52E653444D5DF"
+ ),
+ (
+ "BBAA9988776655443322110D",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F2021222324252627",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F2021222324252627",
+ "D5CA91748410C1751FF8A2F618255B68A0A12E093FF45460"
+ "6E59F9C1D0DDC54B65E8628E568BAD7AED07BA06A4A69483"
+ "A7035490C5769E60"
+ ),
+ (
+ "BBAA9988776655443322110E",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F2021222324252627",
+ "",
+ "C5CD9D1850C141E358649994EE701B68"
+ ),
+ (
+ "BBAA9988776655443322110F",
+ "",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F2021222324252627",
+ "4412923493C57D5DE0D700F753CCE0D1D2D95060122E9F15"
+ "A5DDBFC5787E50B5CC55EE507BCB084E479AD363AC366B95"
+ "A98CA5F3000B1479"
+ )
+ )
+
+ # Tuple with
+ # - key
+ # - nonce
+ # - authenticated data
+ # - plaintext
+ # - ciphertext and 12 byte MAC tag
+ tv2 = (
+ "0F0E0D0C0B0A09080706050403020100",
+ "BBAA9988776655443322110D",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F2021222324252627",
+ "000102030405060708090A0B0C0D0E0F1011121314151617"
+ "18191A1B1C1D1E1F2021222324252627",
+ "1792A4E31E0755FB03E31B22116E6C2DDF9EFD6E33D536F1"
+ "A0124B0A55BAE884ED93481529C76B6AD0C515F4D1CDD4FD"
+ "AC4F02AA"
+ )
+
+ # Tuple with
+ # - key length
+ # - MAC tag length
+ # - Expected output
+ tv3 = (
+ (128, 128, "67E944D23256C5E0B6C61FA22FDF1EA2"),
+ (192, 128, "F673F2C3E7174AAE7BAE986CA9F29E17"),
+ (256, 128, "D90EB8E9C977C88B79DD793D7FFA161C"),
+ (128, 96, "77A3D8E73589158D25D01209"),
+ (192, 96, "05D56EAD2752C86BE6932C5E"),
+ (256, 96, "5458359AC23B0CBA9E6330DD"),
+ (128, 64, "192C9B7BD90BA06A"),
+ (192, 64, "0066BC6E0EF34E24"),
+ (256, 64, "7D4EA5D445501CBE"),
+ )
+
+ def test1(self):
+ key = unhexlify(b(self.tv1_key))
+ for tv in self.tv1:
+ nonce, aad, pt, ct = [ unhexlify(b(x)) for x in tv ]
+ ct, mac_tag = ct[:-16], ct[-16:]
+
+ cipher = AES.new(key, AES.MODE_OCB, nonce=nonce)
+ cipher.update(aad)
+ ct2 = cipher.encrypt(pt) + cipher.encrypt()
+ self.assertEqual(ct, ct2)
+ self.assertEqual(mac_tag, cipher.digest())
+
+ cipher = AES.new(key, AES.MODE_OCB, nonce=nonce)
+ cipher.update(aad)
+ pt2 = cipher.decrypt(ct) + cipher.decrypt()
+ self.assertEqual(pt, pt2)
+ cipher.verify(mac_tag)
+
+ def test2(self):
+
+ key, nonce, aad, pt, ct = [ unhexlify(b(x)) for x in self.tv2 ]
+ ct, mac_tag = ct[:-12], ct[-12:]
+
+ cipher = AES.new(key, AES.MODE_OCB, nonce=nonce, mac_len=12)
+ cipher.update(aad)
+ ct2 = cipher.encrypt(pt) + cipher.encrypt()
+ self.assertEqual(ct, ct2)
+ self.assertEqual(mac_tag, cipher.digest())
+
+ cipher = AES.new(key, AES.MODE_OCB, nonce=nonce, mac_len=12)
+ cipher.update(aad)
+ pt2 = cipher.decrypt(ct) + cipher.decrypt()
+ self.assertEqual(pt, pt2)
+ cipher.verify(mac_tag)
+
+ def test3(self):
+
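+        # RFC 7253, Appendix A: for i in 0..127, encrypt an i-byte string of
+        # zero bytes as plaintext and/or associated data under three counter
+        # derived nonces, concatenate every ciphertext and tag into C, then
+        # authenticate C itself under nonce num2str(385, 96) and compare the
+        # resulting tag with the published value for each key/tag length.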
+ for keylen, taglen, result in self.tv3:
+
+ key = bchr(0) * (keylen // 8 - 1) + bchr(taglen)
+ C = b("")
+
+ for i in range(128):
+ S = bchr(0) * i
+
+ N = long_to_bytes(3 * i + 1, 12)
+ cipher = AES.new(key, AES.MODE_OCB, nonce=N, mac_len=taglen // 8)
+ cipher.update(S)
+ C += cipher.encrypt(S) + cipher.encrypt() + cipher.digest()
+
+ N = long_to_bytes(3 * i + 2, 12)
+ cipher = AES.new(key, AES.MODE_OCB, nonce=N, mac_len=taglen // 8)
+ C += cipher.encrypt(S) + cipher.encrypt() + cipher.digest()
+
+ N = long_to_bytes(3 * i + 3, 12)
+ cipher = AES.new(key, AES.MODE_OCB, nonce=N, mac_len=taglen // 8)
+ cipher.update(S)
+ C += cipher.encrypt() + cipher.digest()
+
+ N = long_to_bytes(385, 12)
+ cipher = AES.new(key, AES.MODE_OCB, nonce=N, mac_len=taglen // 8)
+ cipher.update(C)
+ result2 = cipher.encrypt() + cipher.digest()
+ self.assertEqual(unhexlify(b(result)), result2)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(OcbTests)
+ tests += list_test_cases(OcbFSMTests)
+ tests += list_test_cases(OcbRfc7253Test)
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_OFB.py b/lib/Crypto/SelfTest/Cipher/test_OFB.py
new file mode 100644
index 0000000..ec145ad
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_OFB.py
@@ -0,0 +1,238 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.py3compat import tobytes
+from Crypto.Cipher import AES, DES3, DES
+from Crypto.Hash import SHAKE128
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+from Crypto.SelfTest.Cipher.test_CBC import BlockChainingTests
+
+class OfbTests(BlockChainingTests):
+
+ aes_mode = AES.MODE_OFB
+ des3_mode = DES3.MODE_OFB
+
+ # Redefine test_unaligned_data_128/64
+
+    def test_unaligned_data_128(self):
+        # OFB is a stream mode, so unaligned data is acceptable: chunked and
+        # one-shot encryption must give the same result
+        plaintexts = [ b"7777777" ] * 100
+
+        cipher = AES.new(self.key_128, AES.MODE_OFB, self.iv_128)
+        ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+        cipher = AES.new(self.key_128, AES.MODE_OFB, self.iv_128)
+        self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+    def test_unaligned_data_64(self):
+        plaintexts = [ b"7777777" ] * 100
+
+        cipher = DES3.new(self.key_192, DES3.MODE_OFB, self.iv_64)
+        ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+        cipher = DES3.new(self.key_192, DES3.MODE_OFB, self.iv_64)
+        self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+
+from Crypto.SelfTest.Cipher.test_CBC import NistBlockChainingVectors
+
+class NistOfbVectors(NistBlockChainingVectors):
+ aes_mode = AES.MODE_OFB
+ des_mode = DES.MODE_OFB
+ des3_mode = DES3.MODE_OFB
+
+
+# Create one test method per file
+nist_aes_kat_mmt_files = (
+ # KAT
+ "OFBGFSbox128.rsp",
+ "OFBGFSbox192.rsp",
+ "OFBGFSbox256.rsp",
+ "OFBKeySbox128.rsp",
+ "OFBKeySbox192.rsp",
+ "OFBKeySbox256.rsp",
+ "OFBVarKey128.rsp",
+ "OFBVarKey192.rsp",
+ "OFBVarKey256.rsp",
+ "OFBVarTxt128.rsp",
+ "OFBVarTxt192.rsp",
+ "OFBVarTxt256.rsp",
+ # MMT
+ "OFBMMT128.rsp",
+ "OFBMMT192.rsp",
+ "OFBMMT256.rsp",
+ )
+nist_aes_mct_files = (
+ "OFBMCT128.rsp",
+ "OFBMCT192.rsp",
+ "OFBMCT256.rsp",
+ )
+
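+# file_name is bound as a default argument so that every generated test
+# method keeps its own file instead of the loop variable's final value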
+for file_name in nist_aes_kat_mmt_files:
+ def new_func(self, file_name=file_name):
+ self._do_kat_aes_test(file_name)
+ setattr(NistOfbVectors, "test_AES_" + file_name, new_func)
+
+for file_name in nist_aes_mct_files:
+ def new_func(self, file_name=file_name):
+ self._do_mct_aes_test(file_name)
+ setattr(NistOfbVectors, "test_AES_" + file_name, new_func)
+del file_name, new_func
+
+nist_tdes_files = (
+ "TOFBMMT2.rsp", # 2TDES
+ "TOFBMMT3.rsp", # 3TDES
+ "TOFBinvperm.rsp", # Single DES
+ "TOFBpermop.rsp",
+ "TOFBsubtab.rsp",
+ "TOFBvarkey.rsp",
+ "TOFBvartext.rsp",
+ )
+
+for file_name in nist_tdes_files:
+ def new_func(self, file_name=file_name):
+ self._do_tdes_test(file_name)
+ setattr(NistOfbVectors, "test_TDES_" + file_name, new_func)
+
+# END OF NIST OFB TEST VECTORS
+
+
+class SP800TestVectors(unittest.TestCase):
+ """Class exercising the OFB test vectors found in Section F.4
+    of NIST SP 800-38A"""
+
+ def test_aes_128(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = '3b3fd92eb72dad20333449f8e83cfb4a' +\
+ '7789508d16918f03f53c52dac54ed825' +\
+ '9740051e9c5fecf64344f7a82260edcc' +\
+ '304c6528f659c77866a510d9c1d6ae5e'
+ key = '2b7e151628aed2a6abf7158809cf4f3c'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
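+        # OFB behaves as a stream cipher, so the input does not have to be
+        # aligned to the block size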
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.encrypt(plaintext[:-8]), ciphertext[:-8])
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.decrypt(ciphertext[:-8]), plaintext[:-8])
+
+ def test_aes_192(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = 'cdc80d6fddf18cab34c25909c99a4174' +\
+ 'fcc28b8d4c63837c09e81700c1100401' +\
+ '8d9a9aeac0f6596f559c6d4daf59a5f2' +\
+ '6d9f200857ca6c3e9cac524bd9acc92a'
+ key = '8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.encrypt(plaintext[:-8]), ciphertext[:-8])
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.decrypt(ciphertext[:-8]), plaintext[:-8])
+
+ def test_aes_256(self):
+ plaintext = '6bc1bee22e409f96e93d7e117393172a' +\
+ 'ae2d8a571e03ac9c9eb76fac45af8e51' +\
+ '30c81c46a35ce411e5fbc1191a0a52ef' +\
+ 'f69f2445df4f9b17ad2b417be66c3710'
+ ciphertext = 'dc7e84bfda79164b7ecd8486985d3860' +\
+ '4febdc6740d20b3ac88f6ad82a4fb08d' +\
+ '71ab47a086e86eedf39d1c5bba97c408' +\
+ '0126141d67f37be8538f5a8be740e484'
+ key = '603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4'
+ iv = '000102030405060708090a0b0c0d0e0f'
+
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.encrypt(plaintext), ciphertext)
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.decrypt(ciphertext), plaintext)
+
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.encrypt(plaintext[:-8]), ciphertext[:-8])
+ cipher = AES.new(key, AES.MODE_OFB, iv)
+ self.assertEqual(cipher.decrypt(ciphertext[:-8]), plaintext[:-8])
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(OfbTests)
+ if config.get('slow_tests'):
+ tests += list_test_cases(NistOfbVectors)
+ tests += list_test_cases(SP800TestVectors)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_OpenPGP.py b/lib/Crypto/SelfTest/Cipher/test_OpenPGP.py
new file mode 100644
index 0000000..e6cae67
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_OpenPGP.py
@@ -0,0 +1,218 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.py3compat import tobytes
+from Crypto.Cipher import AES, DES3, DES
+from Crypto.Hash import SHAKE128
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+from Crypto.SelfTest.Cipher.test_CBC import BlockChainingTests
+
+class OpenPGPTests(BlockChainingTests):
+
+ aes_mode = AES.MODE_OPENPGP
+ des3_mode = DES3.MODE_OPENPGP
+
+ # Redefine test_unaligned_data_128/64
+
+ key_128 = get_tag_random("key_128", 16)
+ key_192 = get_tag_random("key_192", 24)
+ iv_128 = get_tag_random("iv_128", 16)
+ iv_64 = get_tag_random("iv_64", 8)
+ data_128 = get_tag_random("data_128", 16)
+
+ def test_loopback_128(self):
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, self.iv_128)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct = cipher.encrypt(pt)
+
+ eiv, ct = ct[:18], ct[18:]
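+        # In OpenPGP mode the ciphertext is prefixed by the encrypted IV,
+        # which is block_size + 2 bytes long (16 + 2 bytes for AES)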
+
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, eiv)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_loopback_64(self):
+ cipher = DES3.new(self.key_192, DES3.MODE_OPENPGP, self.iv_64)
+ pt = get_tag_random("plaintext", 8 * 100)
+ ct = cipher.encrypt(pt)
+
+ eiv, ct = ct[:10], ct[10:]
+
+ cipher = DES3.new(self.key_192, DES3.MODE_OPENPGP, eiv)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def test_IV_iv_attributes(self):
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, self.iv_128)
+ eiv = cipher.encrypt(b"")
+ self.assertEqual(cipher.iv, self.iv_128)
+
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, eiv)
+ self.assertEqual(cipher.iv, self.iv_128)
+
+ def test_null_encryption_decryption(self):
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, self.iv_128)
+ eiv = cipher.encrypt(b"")
+
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, eiv)
+ self.assertEqual(cipher.decrypt(b""), b"")
+
+ def test_either_encrypt_or_decrypt(self):
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, self.iv_128)
+ eiv = cipher.encrypt(b"")
+ self.assertRaises(TypeError, cipher.decrypt, b"")
+
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, eiv)
+ cipher.decrypt(b"")
+ self.assertRaises(TypeError, cipher.encrypt, b"")
+
+ def test_unaligned_data_128(self):
+ plaintexts = [ b"7777777" ] * 100
+
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, self.iv_128)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = AES.new(self.key_128, AES.MODE_OPENPGP, self.iv_128)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ def test_unaligned_data_64(self):
+ plaintexts = [ b"7777777" ] * 100
+
+ cipher = DES3.new(self.key_192, DES3.MODE_OPENPGP, self.iv_64)
+ ciphertexts = [ cipher.encrypt(x) for x in plaintexts ]
+ cipher = DES3.new(self.key_192, DES3.MODE_OPENPGP, self.iv_64)
+ self.assertEqual(b"".join(ciphertexts), cipher.encrypt(b"".join(plaintexts)))
+
+ def test_output_param(self):
+ pass
+
+ def test_output_param_same_buffer(self):
+ pass
+
+ def test_output_param_memoryview(self):
+ pass
+
+ def test_output_param_neg(self):
+ pass
+
+
+class TestVectors(unittest.TestCase):
+
+ def test_aes(self):
+ # The following test vectors have been generated with gpg v1.4.0.
+ # The command line used was:
+ #
+ # gpg -c -z 0 --cipher-algo AES --passphrase secret_passphrase \
+ # --disable-mdc --s2k-mode 0 --output ct pt
+ #
+ # As result, the content of the file 'pt' is encrypted with a key derived
+ # from 'secret_passphrase' and written to file 'ct'.
+ # Test vectors must be extracted from 'ct', which is a collection of
+ # TLVs (see RFC4880 for all details):
+ # - the encrypted data (with the encrypted IV as prefix) is the payload
+        #   of the TLV with tag 9 (Symmetrically Encrypted Data Packet).
+ # This is the ciphertext in the test vector.
+ # - inside the encrypted part, there is a further layer of TLVs. One must
+ # look for tag 11 (Literal Data Packet); in its payload, after a short
+ # but time dependent header, there is the content of file 'pt'.
+ # In the test vector, the plaintext is the complete set of TLVs that gets
+ # encrypted. It is not just the content of 'pt'.
+ # - the key is the leftmost 16 bytes of the SHA1 digest of the password.
+        #   The test vector contains this shortened digest.
+ #
+ # Note that encryption uses a clear IV, and decryption an encrypted IV
+
+ plaintext = 'ac18620270744fb4f647426c61636b4361745768697465436174'
+ ciphertext = 'dc6b9e1f095de609765c59983db5956ae4f63aea7405389d2ebb'
+ key = '5baa61e4c9b93f3f0682250b6cf8331b'
+ iv = '3d7d3e62282add7eb203eeba5c800733'
+ encrypted_iv='fd934601ef49cb58b6d9aebca6056bdb96ef'
+
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ encrypted_iv = unhexlify(encrypted_iv)
+
+ cipher = AES.new(key, AES.MODE_OPENPGP, iv)
+ ct = cipher.encrypt(plaintext)
+ self.assertEqual(ct[:18], encrypted_iv)
+ self.assertEqual(ct[18:], ciphertext)
+
+ cipher = AES.new(key, AES.MODE_OPENPGP, encrypted_iv)
+ pt = cipher.decrypt(ciphertext)
+ self.assertEqual(pt, plaintext)
+
+ def test_des3(self):
+ # The following test vectors have been generated with gpg v1.4.0.
+ # The command line used was:
+ # gpg -c -z 0 --cipher-algo 3DES --passphrase secret_passphrase \
+ # --disable-mdc --s2k-mode 0 --output ct pt
+        # For an explanation, see the comments in test_aes above.
+
+ plaintext = 'ac1762037074324fb53ba3596f73656d69746556616c6c6579'
+ ciphertext = '9979238528357b90e2e0be549cb0b2d5999b9a4a447e5c5c7d'
+ key = '7ade65b460f5ea9be35f9e14aa883a2048e3824aa616c0b2'
+ iv='cd47e2afb8b7e4b0'
+ encrypted_iv='6a7eef0b58050e8b904a'
+
+ plaintext = unhexlify(plaintext)
+ ciphertext = unhexlify(ciphertext)
+ key = unhexlify(key)
+ iv = unhexlify(iv)
+ encrypted_iv = unhexlify(encrypted_iv)
+
+ cipher = DES3.new(key, DES3.MODE_OPENPGP, iv)
+ ct = cipher.encrypt(plaintext)
+ self.assertEqual(ct[:10], encrypted_iv)
+ self.assertEqual(ct[10:], ciphertext)
+
+ cipher = DES3.new(key, DES3.MODE_OPENPGP, encrypted_iv)
+ pt = cipher.decrypt(ciphertext)
+ self.assertEqual(pt, plaintext)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(OpenPGPTests)
+ tests += list_test_cases(TestVectors)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_SIV.py b/lib/Crypto/SelfTest/Cipher/test_SIV.py
new file mode 100644
index 0000000..a80ddc1
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_SIV.py
@@ -0,0 +1,552 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import json
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+
+from Crypto.Util.py3compat import tobytes, bchr
+from Crypto.Cipher import AES
+from Crypto.Hash import SHAKE128
+
+from Crypto.Util.strxor import strxor
+
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
+
+
+class SivTests(unittest.TestCase):
+
+ key_256 = get_tag_random("key_256", 32)
+ key_384 = get_tag_random("key_384", 48)
+ key_512 = get_tag_random("key_512", 64)
+ nonce_96 = get_tag_random("nonce_128", 12)
+ data = get_tag_random("data", 128)
+
+ def test_loopback_128(self):
+ for key in self.key_256, self.key_384, self.key_512:
+ cipher = AES.new(key, AES.MODE_SIV, nonce=self.nonce_96)
+ pt = get_tag_random("plaintext", 16 * 100)
+ ct, mac = cipher.encrypt_and_digest(pt)
+
+ cipher = AES.new(key, AES.MODE_SIV, nonce=self.nonce_96)
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(pt, pt2)
+
+ def test_nonce(self):
+ # Deterministic encryption
+ AES.new(self.key_256, AES.MODE_SIV)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, self.nonce_96)
+ ct1, tag1 = cipher.encrypt_and_digest(self.data)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ ct2, tag2 = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(ct1 + tag1, ct2 + tag2)
+
+ def test_nonce_must_be_bytes(self):
+ self.assertRaises(TypeError, AES.new, self.key_256, AES.MODE_SIV,
+ nonce=u'test12345678')
+
+ def test_nonce_length(self):
+ # nonce can be of any length (but not empty)
+ self.assertRaises(ValueError, AES.new, self.key_256, AES.MODE_SIV,
+ nonce=b"")
+
+ for x in range(1, 128):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=bchr(1) * x)
+ cipher.encrypt_and_digest(b'\x01')
+
+ def test_block_size_128(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertEqual(cipher.block_size, AES.block_size)
+
+ def test_nonce_attribute(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertEqual(cipher.nonce, self.nonce_96)
+
+ # By default, no nonce is randomly generated
+ self.assertFalse(hasattr(AES.new(self.key_256, AES.MODE_SIV), "nonce"))
+
+ def test_unknown_parameters(self):
+ self.assertRaises(TypeError, AES.new, self.key_256, AES.MODE_SIV,
+ self.nonce_96, 7)
+ self.assertRaises(TypeError, AES.new, self.key_256, AES.MODE_SIV,
+ nonce=self.nonce_96, unknown=7)
+
+ # But some are only known by the base cipher
+ # (e.g. use_aesni consumed by the AES module)
+ AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96,
+ use_aesni=False)
+
+ def test_encrypt_excludes_decrypt(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.encrypt_and_digest(self.data)
+ self.assertRaises(TypeError, cipher.decrypt, self.data)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.encrypt_and_digest(self.data)
+ self.assertRaises(TypeError, cipher.decrypt_and_verify,
+ self.data, self.data)
+
+ def test_data_must_be_bytes(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, u'test1234567890-*')
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt_and_verify,
+ u'test1234567890-*', b"xxxx")
+
+ def test_mac_len(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ _, mac = cipher.encrypt_and_digest(self.data)
+ self.assertEqual(len(mac), 16)
+
+ def test_invalid_mac(self):
+ from Crypto.Util.strxor import strxor_c
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ invalid_mac = strxor_c(mac, 0x01)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct,
+ invalid_mac)
+
+ def test_hex_mac(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ mac_hex = cipher.hexdigest()
+ self.assertEqual(cipher.digest(), unhexlify(mac_hex))
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.hexverify(mac_hex)
+
+ def test_bytearray(self):
+
+ # Encrypt
+ key = bytearray(self.key_256)
+ nonce = bytearray(self.nonce_96)
+ data = bytearray(self.data)
+ header = bytearray(self.data)
+
+ cipher1 = AES.new(self.key_256,
+ AES.MODE_SIV,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct, tag = cipher1.encrypt_and_digest(self.data)
+
+ cipher2 = AES.new(key,
+ AES.MODE_SIV,
+ nonce=nonce)
+ key[:3] = b'\xFF\xFF\xFF'
+ nonce[:3] = b'\xFF\xFF\xFF'
+ cipher2.update(header)
+ header[:3] = b'\xFF\xFF\xFF'
+ ct_test, tag_test = cipher2.encrypt_and_digest(data)
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key = bytearray(self.key_256)
+ nonce = bytearray(self.nonce_96)
+ header = bytearray(self.data)
+ ct_ba = bytearray(ct)
+ tag_ba = bytearray(tag)
+
+ cipher3 = AES.new(key,
+ AES.MODE_SIV,
+ nonce=nonce)
+ key[:3] = b'\xFF\xFF\xFF'
+ nonce[:3] = b'\xFF\xFF\xFF'
+ cipher3.update(header)
+ header[:3] = b'\xFF\xFF\xFF'
+ pt_test = cipher3.decrypt_and_verify(ct_ba, tag_ba)
+
+ self.assertEqual(self.data, pt_test)
+
+ def test_memoryview(self):
+
+ # Encrypt
+ key = memoryview(bytearray(self.key_256))
+ nonce = memoryview(bytearray(self.nonce_96))
+ data = memoryview(bytearray(self.data))
+ header = memoryview(bytearray(self.data))
+
+ cipher1 = AES.new(self.key_256,
+ AES.MODE_SIV,
+ nonce=self.nonce_96)
+ cipher1.update(self.data)
+ ct, tag = cipher1.encrypt_and_digest(self.data)
+
+ cipher2 = AES.new(key,
+ AES.MODE_SIV,
+ nonce=nonce)
+ key[:3] = b'\xFF\xFF\xFF'
+ nonce[:3] = b'\xFF\xFF\xFF'
+ cipher2.update(header)
+ header[:3] = b'\xFF\xFF\xFF'
+ ct_test, tag_test= cipher2.encrypt_and_digest(data)
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(tag, tag_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decrypt
+ key = memoryview(bytearray(self.key_256))
+ nonce = memoryview(bytearray(self.nonce_96))
+ header = memoryview(bytearray(self.data))
+ ct_ba = memoryview(bytearray(ct))
+ tag_ba = memoryview(bytearray(tag))
+
+ cipher3 = AES.new(key,
+ AES.MODE_SIV,
+ nonce=nonce)
+ key[:3] = b'\xFF\xFF\xFF'
+ nonce[:3] = b'\xFF\xFF\xFF'
+ cipher3.update(header)
+ header[:3] = b'\xFF\xFF\xFF'
+ pt_test = cipher3.decrypt_and_verify(ct_ba, tag_ba)
+
+ self.assertEqual(self.data, pt_test)
+
+ def test_output_param(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ ct, tag = cipher.encrypt_and_digest(pt)
+
+ output = bytearray(128)
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ res, tag_out = cipher.encrypt_and_digest(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+ self.assertEqual(tag, tag_out)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ res = cipher.decrypt_and_verify(ct, tag, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ def test_output_param_memoryview(self):
+
+ pt = b'5' * 128
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ ct, tag = cipher.encrypt_and_digest(pt)
+
+ output = memoryview(bytearray(128))
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.encrypt_and_digest(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.decrypt_and_verify(ct, tag, output=output)
+ self.assertEqual(pt, output)
+
+ def test_output_param_neg(self):
+ LEN_PT = 128
+
+ pt = b'5' * LEN_PT
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ ct, tag = cipher.encrypt_and_digest(pt)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt_and_digest, pt, output=b'0' * LEN_PT)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt_and_verify, ct, tag, output=b'0' * LEN_PT)
+
+ shorter_output = bytearray(LEN_PT - 1)
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.encrypt_and_digest, pt, output=shorter_output)
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ self.assertRaises(ValueError, cipher.decrypt_and_verify, ct, tag, output=shorter_output)
+
+
+class SivFSMTests(unittest.TestCase):
+
+ key_256 = get_tag_random("key_256", 32)
+ nonce_96 = get_tag_random("nonce_96", 12)
+ data = get_tag_random("data", 128)
+
+ def test_invalid_init_encrypt(self):
+ # Path INIT->ENCRYPT fails
+ cipher = AES.new(self.key_256, AES.MODE_SIV,
+ nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.encrypt, b"xxx")
+
+ def test_invalid_init_decrypt(self):
+ # Path INIT->DECRYPT fails
+ cipher = AES.new(self.key_256, AES.MODE_SIV,
+ nonce=self.nonce_96)
+ self.assertRaises(TypeError, cipher.decrypt, b"xxx")
+
+ def test_valid_init_update_digest_verify(self):
+ # No plaintext, fixed authenticated data
+ # Verify path INIT->UPDATE->DIGEST
+ cipher = AES.new(self.key_256, AES.MODE_SIV,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ mac = cipher.digest()
+
+ # Verify path INIT->UPDATE->VERIFY
+ cipher = AES.new(self.key_256, AES.MODE_SIV,
+ nonce=self.nonce_96)
+ cipher.update(self.data)
+ cipher.verify(mac)
+
+ def test_valid_init_digest(self):
+ # Verify path INIT->DIGEST
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.digest()
+
+ def test_valid_init_verify(self):
+ # Verify path INIT->VERIFY
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ mac = cipher.digest()
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.verify(mac)
+
+ def test_valid_multiple_digest_or_verify(self):
+ # Multiple calls to digest
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.update(self.data)
+ first_mac = cipher.digest()
+ for x in range(4):
+ self.assertEqual(first_mac, cipher.digest())
+
+ # Multiple calls to verify
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.update(self.data)
+ for x in range(5):
+ cipher.verify(first_mac)
+
+ def test_valid_encrypt_and_digest_decrypt_and_verify(self):
+ # encrypt_and_digest
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.update(self.data)
+ ct, mac = cipher.encrypt_and_digest(self.data)
+
+ # decrypt_and_verify
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.update(self.data)
+ pt = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(self.data, pt)
+
+ def test_invalid_multiple_encrypt_and_digest(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ ct, tag = cipher.encrypt_and_digest(self.data)
+ self.assertRaises(TypeError, cipher.encrypt_and_digest, b'')
+
+ def test_invalid_multiple_decrypt_and_verify(self):
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ ct, tag = cipher.encrypt_and_digest(self.data)
+
+ cipher = AES.new(self.key_256, AES.MODE_SIV, nonce=self.nonce_96)
+ cipher.decrypt_and_verify(ct, tag)
+ self.assertRaises(TypeError, cipher.decrypt_and_verify, ct, tag)
+
+
+def transform(tv):
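+    # Turn one tuple of hex strings into (assoc_data_list, pt, ct, mac, key,
+    # nonce): the dash-separated header becomes a list of associated data
+    # components, and the nonce stays None for deterministic encryption.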
+ new_tv = [[unhexlify(x) for x in tv[0].split("-")]]
+ new_tv += [ unhexlify(x) for x in tv[1:5]]
+ if tv[5]:
+ nonce = unhexlify(tv[5])
+ else:
+ nonce = None
+ new_tv += [ nonce ]
+ return new_tv
+
+
+class TestVectors(unittest.TestCase):
+ """Class exercising the SIV test vectors found in RFC5297"""
+
+    # This is a list of tuples with 6 items (all hex encoded except the last):
+    #
+    #  1. Header, i.e. the associated data
+    #  2. Plaintext
+    #  3. Ciphertext
+    #  4. MAC tag
+    #  5. Key
+    #  6. Nonce (None for deterministic encryption)
+    #
+    # A "Header" is a dash ('-') separated sequence of components.
+    #
+ test_vectors_hex = [
+ (
+ '101112131415161718191a1b1c1d1e1f2021222324252627',
+ '112233445566778899aabbccddee',
+ '40c02b9690c4dc04daef7f6afe5c',
+ '85632d07c6e8f37f950acd320a2ecc93',
+ 'fffefdfcfbfaf9f8f7f6f5f4f3f2f1f0f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff',
+ None
+ ),
+ (
+ '00112233445566778899aabbccddeeffdeaddadadeaddadaffeeddccbbaa9988' +
+ '7766554433221100-102030405060708090a0',
+ '7468697320697320736f6d6520706c61696e7465787420746f20656e63727970' +
+ '74207573696e67205349562d414553',
+ 'cb900f2fddbe404326601965c889bf17dba77ceb094fa663b7a3f748ba8af829' +
+ 'ea64ad544a272e9c485b62a3fd5c0d',
+ '7bdb6e3b432667eb06f4d14bff2fbd0f',
+ '7f7e7d7c7b7a79787776757473727170404142434445464748494a4b4c4d4e4f',
+ '09f911029d74e35bd84156c5635688c0'
+ ),
+ ]
+
+ test_vectors = [ transform(tv) for tv in test_vectors_hex ]
+
+ def runTest(self):
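+        # SIV (RFC 5297) authenticates a vector of associated data strings;
+        # each component is passed through a separate call to update()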
+ for assoc_data, pt, ct, mac, key, nonce in self.test_vectors:
+
+ # Encrypt
+ cipher = AES.new(key, AES.MODE_SIV, nonce=nonce)
+ for x in assoc_data:
+ cipher.update(x)
+ ct2, mac2 = cipher.encrypt_and_digest(pt)
+ self.assertEqual(ct, ct2)
+ self.assertEqual(mac, mac2)
+
+ # Decrypt
+ cipher = AES.new(key, AES.MODE_SIV, nonce=nonce)
+ for x in assoc_data:
+ cipher.update(x)
+ pt2 = cipher.decrypt_and_verify(ct, mac)
+ self.assertEqual(pt, pt2)
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self):
+ unittest.TestCase.__init__(self)
+ self._id = "None"
+
+ def setUp(self):
+ self.tv = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ "aes_siv_cmac_test.json",
+ "Wycheproof AES SIV")
+
+ def shortDescription(self):
+ return self._id
+
+ def test_encrypt(self, tv):
+ self._id = "Wycheproof Encrypt AES-SIV Test #" + str(tv.id)
+
+ cipher = AES.new(tv.key, AES.MODE_SIV)
+ cipher.update(tv.aad)
+ ct, tag = cipher.encrypt_and_digest(tv.msg)
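+        # RFC 5297 output is SIV (tag) followed by the ciphertext, and the
+        # Wycheproof 'ct' field stores it in that order, hence tag + ct below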
+ if tv.valid:
+ self.assertEqual(tag + ct, tv.ct)
+
+ def test_decrypt(self, tv):
+        self._id = "Wycheproof Decrypt AES-SIV Test #" + str(tv.id)
+
+ cipher = AES.new(tv.key, AES.MODE_SIV)
+ cipher.update(tv.aad)
+ try:
+ pt = cipher.decrypt_and_verify(tv.ct[16:], tv.ct[:16])
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_encrypt(tv)
+ self.test_decrypt(tv)
+
+
+class TestVectorsWycheproof2(unittest.TestCase):
+
+ def __init__(self):
+ unittest.TestCase.__init__(self)
+ self._id = "None"
+
+ def setUp(self):
+ self.tv = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ "aead_aes_siv_cmac_test.json",
+ "Wycheproof AEAD SIV")
+
+ def shortDescription(self):
+ return self._id
+
+ def test_encrypt(self, tv):
+ self._id = "Wycheproof Encrypt AEAD-AES-SIV Test #" + str(tv.id)
+
+ cipher = AES.new(tv.key, AES.MODE_SIV, nonce=tv.iv)
+ cipher.update(tv.aad)
+ ct, tag = cipher.encrypt_and_digest(tv.msg)
+ if tv.valid:
+ self.assertEqual(ct, tv.ct)
+ self.assertEqual(tag, tv.tag)
+
+ def test_decrypt(self, tv):
+ self._id = "Wycheproof Decrypt AEAD-AES-SIV Test #" + str(tv.id)
+
+ cipher = AES.new(tv.key, AES.MODE_SIV, nonce=tv.iv)
+ cipher.update(tv.aad)
+ try:
+ pt = cipher.decrypt_and_verify(tv.ct, tv.tag)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_encrypt(tv)
+ self.test_decrypt(tv)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(SivTests)
+ tests += list_test_cases(SivFSMTests)
+ tests += [ TestVectors() ]
+ tests += [ TestVectorsWycheproof() ]
+ tests += [ TestVectorsWycheproof2() ]
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Cipher/test_Salsa20.py b/lib/Crypto/SelfTest/Cipher/test_Salsa20.py
new file mode 100644
index 0000000..a710462
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_Salsa20.py
@@ -0,0 +1,367 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/Salsa20.py: Self-test for the Salsa20 stream cipher
+#
+# Written in 2013 by Fabrizio Tarizzo <fabrizio@fabriziotarizzo.org>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Cipher.Salsa20"""
+
+import unittest
+
+from Crypto.Util.py3compat import bchr
+
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Cipher import Salsa20
+
+from .common import make_stream_tests
+
+# This is a list of (plaintext, ciphertext, key[, description[, params]])
+# tuples.
+test_data = [
+ # Test vectors are taken from
+ # http://www.ecrypt.eu.org/stream/svn/viewcvs.cgi/ecrypt/trunk/submissions/salsa20/full/verified.test-vectors
+ ( '00' * 512,
+ '4dfa5e481da23ea09a31022050859936da52fcee218005164f267cb65f5cfd7f'
+ + '2b4f97e0ff16924a52df269515110a07f9e460bc65ef95da58f740b7d1dbb0aa'
+ + 'd64cec189c7eb8c6bbf3d7376c80a481d43e628701f6a27afb9fe23919f24114'
+ + '8db44f70d7063efcc3dd55a0893a613c3c6fe1c127bd6f59910589293bb6ef9e'
+ + 'e24819066dee1a64f49b0bbad5988635272b169af861f85df881939f29ada6fd'
+ + '0241410e8d332ae4798d929434a2630de451ec4e0169694cbaa7ebb121ea6a2b'
+ + 'da9c1581f429e0a00f7d67e23b730676783b262e8eb43a25f55fb90b3e753aef'
+ + '8c6713ec66c51881111593ccb3e8cb8f8de124080501eeeb389c4bcb6977cf95'
+ + '7d5789631eb4554400e1e025935dfa7b3e9039d61bdc58a8697d36815bf1985c'
+ + 'efdf7ae112e5bb81e37ecf0616ce7147fc08a93a367e08631f23c03b00a8da2f'
+ + 'aa5024e5c8d30aca43fc2d5082067b21b234bc741d68fb292c6012c3764ccee3'
+ + '1e364a5403e00cfee338a21a01e7d3cefd5a770ca0ab48c435ea6116435f7ad8'
+ + '30b217b49f978a68e207ed9f462af7fb195b2115fe8f24f152e4ddc32202d6f2'
+ + 'b52fafbcfbc202d8a259a611e901d3f62d065eb13f09bbc45cd45119b843efaa'
+ + 'b375703739daced4dd4059fd71c3c47fc2f9939670fad4a46066adcc6a564578'
+ + '3308b90ffb72be04a6b147cbe38cc0c3b9267c296a92a7c69873f9f263be9703',
+ '80000000000000000000000000000000',
+ '128 bits key, set 1, vector 0',
+ dict (iv='00'*8)),
+
+ ( '00' * 512,
+ 'e3be8fdd8beca2e3ea8ef9475b29a6e7003951e1097a5c38d23b7a5fad9f6844'
+ + 'b22c97559e2723c7cbbd3fe4fc8d9a0744652a83e72a9c461876af4d7ef1a117'
+ + '8da2b74eef1b6283e7e20166abcae538e9716e4669e2816b6b20c5c356802001'
+ + 'cc1403a9a117d12a2669f456366d6ebb0f1246f1265150f793cdb4b253e348ae'
+ + '203d89bc025e802a7e0e00621d70aa36b7e07cb1e7d5b38d5e222b8b0e4b8407'
+ + '0142b1e29504767d76824850320b5368129fdd74e861b498e3be8d16f2d7d169'
+ + '57be81f47b17d9ae7c4ff15429a73e10acf250ed3a90a93c711308a74c6216a9'
+ + 'ed84cd126da7f28e8abf8bb63517e1ca98e712f4fb2e1a6aed9fdc73291faa17'
+ + '958211c4ba2ebd5838c635edb81f513a91a294e194f1c039aeec657dce40aa7e'
+ + '7c0af57cacefa40c9f14b71a4b3456a63e162ec7d8d10b8ffb1810d71001b618'
+ + '2f9f73da53b85405c11f7b2d890fa8ae0c7f2e926d8a98c7ec4e91b65120e988'
+ + '349631a700c6facec3471cb0413656e75e309456584084d7e12c5b43a41c43ed'
+ + '9a048abd9b880da65f6a665a20fe7b77cd292fe62cae644b7f7df69f32bdb331'
+ + '903e6505ce44fdc293920c6a9ec7057e23df7dad298f82ddf4efb7fdc7bfc622'
+ + '696afcfd0cddcc83c7e77f11a649d79acdc3354e9635ff137e929933a0bd6f53'
+ + '77efa105a3a4266b7c0d089d08f1e855cc32b15b93784a36e56a76cc64bc8477',
+ '8000000000000000000000000000000000000000000000000000000000000000',
+ '256 bits key, set 1, vector 0',
+ dict (iv='00'*8)),
+
+ ( '00' * 512,
+ '169060ccb42bea7bee4d8012a02f3635eb7bca12859fa159cd559094b3507db8'
+ + '01735d1a1300102a9c9415546829cbd2021ba217b39b81d89c55b13d0c603359'
+ + '3f84159a3c84f4b4f4a0edcd9d38ff261a737909e0b66d68b5cac496f3a5be99'
+ + 'cb12c321ab711afaab36cc0947955e1a9bb952ed54425e7711279fbc81bb83f5'
+ + '6e55cea44e6daddb05858a153ea6213b3350c12aa1a83ef2726f09485fa71790'
+ + 'f9b9f922c7dda1113b1f9d56658ed3402803f511bc1f122601d5e7f0ff036e23'
+ + '23ef24bb24195b9fd574823cd8a40c29d86bd35c191e2038779ff696c712b6d8'
+ + '2e7014dbe1ac5d527af076c088c4a8d44317958189f6ef54933a7e0816b5b916'
+ + 'd8f12ed8afe9422b85e5cc9b8adec9d6cfabe8dbc1082bccc02f5a7266aa074c'
+ + 'a284e583a35837798cc0e69d4ce937653b8cdd65ce414b89138615ccb165ad19'
+ + '3c6b9c3d05eef4be921a10ea811fe61d11c6867600188e065daff90b509ec56b'
+ + 'd41e7e8968c478c78d590c2d2ee24ea009c8f49bc3d81672cfc47895a9e21c9a'
+ + '471ebf8e294bee5d2de436ac8d052bf31111b345f1da23c3a4d13b9fc5f0900a'
+ + 'a298f98f538973b8fad40d4d159777de2cfe2a3dead1645ddb49794827dba040'
+ + 'f70a0ff4ecd155e0f033604693a51e2363880e2ecf98699e7174af7c2c6b0fc6'
+ + '59ae329599a3949272a37b9b2183a0910922a3f325ae124dcbdd735364055ceb',
+ '09090909090909090909090909090909',
+ '128 bits key, set 2, vector 9',
+ dict (iv='00'*8)),
+
+ ( '00' * 512,
+ '7041e747ceb22ed7812985465f50333124f971da1c5d6efe5ca201b886f31046'
+ + 'e757e5c3ec914f60ed1f6bce2819b6810953f12b8ba1199bf82d746a8b8a88f1'
+ + '142002978ec4c35b95dc2c82990f9e847a0ab45f2ca72625f5190c820f29f3aa'
+ + 'f5f0b5572b06b70a144f2a240c3b3098d4831fa1ce1459f8d1df226a6a79b0ab'
+ + '41e91799ef31b5ff3d756c19126b19025858ee70fbd69f2be955cb011c005e31'
+ + '32b271b378f39b0cb594e95c99ce6ff17735a541891845bbf0450afcb4a850b9'
+ + '4ee90afb713ae7e01295c74381180a3816d7020d5a396c0d97aaa783eaabb6ec'
+ + '44d5111157f2212d1b1b8fca7893e8b520cd482418c272ab119b569a2b9598eb'
+ + '355624d12e79adab81153b58cd22eaf1b2a32395dedc4a1c66f4d274070b9800'
+ + 'ea95766f0245a8295f8aadb36ddbbdfa936417c8dbc6235d19494036964d3e70'
+ + 'b125b0f800c3d53881d9d11e7970f827c2f9556935cd29e927b0aceb8cae5fd4'
+ + '0fd88a8854010a33db94c96c98735858f1c5df6844f864feaca8f41539313e7f'
+ + '3c0610214912cd5e6362197646207e2d64cd5b26c9dfe0822629dcbeb16662e8'
+ + '9ff5bf5cf2e499138a5e27bd5027329d0e68ddf53103e9e409523662e27f61f6'
+ + '5cf38c1232023e6a6ef66c315bcb2a4328642faabb7ca1e889e039e7c444b34b'
+ + 'b3443f596ac730f3df3dfcdb343c307c80f76e43e8898c5e8f43dc3bb280add0',
+ '0909090909090909090909090909090909090909090909090909090909090909',
+ '256 bits key, set 2, vector 9',
+ dict (iv='00'*8)),
+
+ ( '00' * 1024,
+ '71daee5142d0728b41b6597933ebf467e43279e30978677078941602629cbf68'
+ + 'b73d6bd2c95f118d2b3e6ec955dabb6dc61c4143bc9a9b32b99dbe6866166dc0'
+ + '8631b7d6553050303d7252c264d3a90d26c853634813e09ad7545a6ce7e84a5d'
+ + 'fc75ec43431207d5319970b0faadb0e1510625bb54372c8515e28e2accf0a993'
+ + '0ad15f431874923d2a59e20d9f2a5367dba6051564f150287debb1db536ff9b0'
+ + '9ad981f25e5010d85d76ee0c305f755b25e6f09341e0812f95c94f42eead346e'
+ + '81f39c58c5faa2c88953dc0cac90469db2063cb5cdb22c9eae22afbf0506fca4'
+ + '1dc710b846fbdfe3c46883dd118f3a5e8b11b6afd9e71680d8666557301a2daa'
+ + 'fb9496c559784d35a035360885f9b17bd7191977deea932b981ebdb29057ae3c'
+ + '92cfeff5e6c5d0cb62f209ce342d4e35c69646ccd14e53350e488bb310a32f8b'
+ + '0248e70acc5b473df537ced3f81a014d4083932bedd62ed0e447b6766cd2604b'
+ + '706e9b346c4468beb46a34ecf1610ebd38331d52bf33346afec15eefb2a7699e'
+ + '8759db5a1f636a48a039688e39de34d995df9f27ed9edc8dd795e39e53d9d925'
+ + 'b278010565ff665269042f05096d94da3433d957ec13d2fd82a0066283d0d1ee'
+ + 'b81bf0ef133b7fd90248b8ffb499b2414cd4fa003093ff0864575a43749bf596'
+ + '02f26c717fa96b1d057697db08ebc3fa664a016a67dcef8807577cc3a09385d3'
+ + 'f4dc79b34364bb3b166ce65fe1dd28e3950fe6fa81063f7b16ce1c0e6daac1f8'
+ + '188455b77752045e863c9b256ad92bc6e2d08314c5bba191c274f42dfbb3d652'
+ + 'bb771956555e880f84cd8b827a4c5a52f3a099fa0259bd4aac3efd541f191170'
+ + '4412d6e85fbcc628b335875b9fef24807f6e1bc66c3186159e1e7f5a13913e02'
+ + 'd241ce2efdbcaa275039fb14eac5923d17ffbc7f1abd3b45e92127575bfbabf9'
+ + '3a257ebef0aa1437b326e41b585af572f7239c33b32981a1577a4f629b027e1e'
+ + 'b49d58cc497e944d79cef44357c2bf25442ab779651e991147bf79d6fd3a8868'
+ + '0cd3b1748e07fd10d78aceef6db8a5e563570d40127f754146c34a440f2a991a'
+ + '23fa39d365141f255041f2135c5cba4373452c114da1801bacca38610e3a6524'
+ + '2b822d32de4ab5a7d3cf9b61b37493c863bd12e2cae10530cddcda2cb7a5436b'
+ + 'ef8988d4d24e8cdc31b2d2a3586340bc5141f8f6632d0dd543bfed81eb471ba1'
+ + 'f3dc2225a15ffddcc03eb48f44e27e2aa390598adf83f15c6608a5f18d4dfcf0'
+ + 'f547d467a4d70b281c83a595d7660d0b62de78b9cca023cca89d7b1f83484638'
+ + '0e228c25f049184a612ef5bb3d37454e6cfa5b10dceda619d898a699b3c8981a'
+ + '173407844bb89b4287bf57dd6600c79e352c681d74b03fa7ea0d7bf6ad69f8a6'
+ + '8ecb001963bd2dd8a2baa0083ec09751cd9742402ad716be16d5c052304cfca1',
+ '0F62B5085BAE0154A7FA4DA0F34699EC',
+ '128 bits key, Set 6, vector# 3',
+ dict (iv='288FF65DC42B92F9')),
+
+ ( '00' * 1024,
+ '5e5e71f90199340304abb22a37b6625bf883fb89ce3b21f54a10b81066ef87da'
+ + '30b77699aa7379da595c77dd59542da208e5954f89e40eb7aa80a84a6176663f'
+ + 'd910cde567cf1ff60f7040548d8f376bfd1f44c4774aac37410ede7d5c3463fc'
+ + '4508a603201d8495ad257894e5eb1914b53e8da5e4bf2bc83ac87ce55cc67df7'
+ + '093d9853d2a83a9c8be969175df7c807a17156df768445dd0874a9271c6537f5'
+ + 'ce0466473582375f067fa4fcdaf65dbc0139cd75e8c21a482f28c0fb8c3d9f94'
+ + '22606cc8e88fe28fe73ec3cb10ff0e8cc5f2a49e540f007265c65b7130bfdb98'
+ + '795b1df9522da46e48b30e55d9f0d787955ece720205b29c85f3ad9be33b4459'
+ + '7d21b54d06c9a60b04b8e640c64e566e51566730e86cf128ab14174f91bd8981'
+ + 'a6fb00fe587bbd6c38b5a1dfdb04ea7e61536fd229f957aa9b070ca931358e85'
+ + '11b92c53c523cb54828fb1513c5636fa9a0645b4a3c922c0db94986d92f314ff'
+ + '7852c03b231e4dceea5dd8cced621869cff818daf3c270ff3c8be2e5c74be767'
+ + 'a4e1fdf3327a934fe31e46df5a74ae2021cee021d958c4f615263d99a5ddae7f'
+ + 'eab45e6eccbafefe4761c57750847b7e75ee2e2f14333c0779ce4678f47b1e1b'
+ + '760a03a5f17d6e91d4b42313b3f1077ee270e432fe04917ed1fc8babebf7c941'
+ + '42b80dfb44a28a2a3e59093027606f6860bfb8c2e5897078cfccda7314c70035'
+ + 'f137de6f05daa035891d5f6f76e1df0fce1112a2ff0ac2bd3534b5d1bf4c7165'
+ + 'fb40a1b6eacb7f295711c4907ae457514a7010f3a342b4427593d61ba993bc59'
+ + '8bd09c56b9ee53aac5dd861fa4b4bb53888952a4aa9d8ca8671582de716270e1'
+ + '97375b3ee49e51fa2bf4ef32015dd9a764d966aa2ae541592d0aa650849e99ca'
+ + '5c6c39beebf516457cc32fe4c105bff314a12f1ec94bdf4d626f5d9b1cbbde42'
+ + 'e5733f0885765ba29e2e82c829d312f5fc7e180679ac84826c08d0a644b326d0'
+ + '44da0fdcc75fa53cfe4ced0437fa4df5a7ecbca8b4cb7c4a9ecf9a60d00a56eb'
+ + '81da52adc21f508dbb60a9503a3cc94a896616d86020d5b0e5c637329b6d396a'
+ + '41a21ba2c4a9493cf33fa2d4f10f77d5b12fdad7e478ccfe79b74851fc96a7ca'
+ + '6320c5efd561a222c0ab0fb44bbda0e42149611d2262bb7d1719150fa798718a'
+ + '0eec63ee297cad459869c8b0f06c4e2b56cbac03cd2605b2a924efedf85ec8f1'
+ + '9b0b6c90e7cbd933223ffeb1b3a3f9677657905829294c4c70acdb8b0891b47d'
+ + '0875d0cd6c0f4efe2917fc44b581ef0d1e4280197065d07da34ab33283364552'
+ + 'efad0bd9257b059acdd0a6f246812feb69e7e76065f27dbc2eee94da9cc41835'
+ + 'bf826e36e5cebe5d4d6a37a6a666246290ce51a0c082718ab0ec855668db1add'
+ + 'a658e5f257e0db39384d02e6145c4c00eaa079098f6d820d872de711b6ed08cf',
+ '0F62B5085BAE0154A7FA4DA0F34699EC3F92E5388BDE3184D72A7DD02376C91C',
+ '256 bits key, Set 6, vector# 3',
+ dict (iv='288FF65DC42B92F9')),
+
+]
+
+
+class KeyLength(unittest.TestCase):
+
+ def runTest(self):
+
+ nonce = bchr(0) * 8
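+        # Salsa20 only accepts 128-bit (16-byte) or 256-bit (32-byte) keys,
+        # so these near-miss lengths must be rejected with ValueError.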
+ for key_length in (15, 30, 33):
+ key = bchr(1) * key_length
+ self.assertRaises(ValueError, Salsa20.new, key, nonce)
+
+
+class NonceTests(unittest.TestCase):
+
+ def test_invalid_nonce_length(self):
+ key = bchr(1) * 16
+ self.assertRaises(ValueError, Salsa20.new, key, bchr(0) * 7)
+ self.assertRaises(ValueError, Salsa20.new, key, bchr(0) * 9)
+
+ def test_default_nonce(self):
+
+ cipher1 = Salsa20.new(bchr(1) * 16)
+ cipher2 = Salsa20.new(bchr(1) * 16)
+ self.assertEqual(len(cipher1.nonce), 8)
+ self.assertNotEqual(cipher1.nonce, cipher2.nonce)
+
+
+class ByteArrayTest(unittest.TestCase):
+ """Verify we can encrypt or decrypt bytearrays"""
+
+ def runTest(self):
+
+ data = b"0123"
+ key = b"9" * 32
+ nonce = b"t" * 8
+
+ # Encryption
+ data_ba = bytearray(data)
+ key_ba = bytearray(key)
+ nonce_ba = bytearray(nonce)
+
+ cipher1 = Salsa20.new(key=key, nonce=nonce)
+ ct = cipher1.encrypt(data)
+
+ cipher2 = Salsa20.new(key=key_ba, nonce=nonce_ba)
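+        # Mutating the caller's buffers after construction must not affect
+        # the cipher, which is expected to keep its own copy of key and nonce.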
+ key_ba[:1] = b'\xFF'
+ nonce_ba[:1] = b'\xFF'
+ ct_test = cipher2.encrypt(data_ba)
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decryption
+ key_ba = bytearray(key)
+ nonce_ba = bytearray(nonce)
+ ct_ba = bytearray(ct)
+
+ cipher3 = Salsa20.new(key=key_ba, nonce=nonce_ba)
+ key_ba[:1] = b'\xFF'
+ nonce_ba[:1] = b'\xFF'
+ pt_test = cipher3.decrypt(ct_ba)
+
+ self.assertEqual(data, pt_test)
+
+
+class MemoryviewTest(unittest.TestCase):
+ """Verify we can encrypt or decrypt bytearrays"""
+
+ def runTest(self):
+
+ data = b"0123"
+ key = b"9" * 32
+ nonce = b"t" * 8
+
+ # Encryption
+ data_mv = memoryview(bytearray(data))
+ key_mv = memoryview(bytearray(key))
+ nonce_mv = memoryview(bytearray(nonce))
+
+ cipher1 = Salsa20.new(key=key, nonce=nonce)
+ ct = cipher1.encrypt(data)
+
+ cipher2 = Salsa20.new(key=key_mv, nonce=nonce_mv)
+ key_mv[:1] = b'\xFF'
+ nonce_mv[:1] = b'\xFF'
+ ct_test = cipher2.encrypt(data_mv)
+
+ self.assertEqual(ct, ct_test)
+ self.assertEqual(cipher1.nonce, cipher2.nonce)
+
+ # Decryption
+ key_mv = memoryview(bytearray(key))
+ nonce_mv = memoryview(bytearray(nonce))
+ ct_mv = memoryview(bytearray(ct))
+
+ cipher3 = Salsa20.new(key=key_mv, nonce=nonce_mv)
+ key_mv[:1] = b'\xFF'
+ nonce_mv[:1] = b'\xFF'
+ pt_test = cipher3.decrypt(ct_mv)
+
+ self.assertEqual(data, pt_test)
+
+
+class TestOutput(unittest.TestCase):
+
+ def runTest(self):
+ # Encrypt/Decrypt data and test output parameter
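+        # (output must be a writable buffer of exactly the input length:
+        # immutable bytes raises TypeError, a short buffer raises ValueError,
+        # and a successful call writes in place and returns None)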
+
+ key = b'4' * 32
+ nonce = b'5' * 8
+ cipher = Salsa20.new(key=key, nonce=nonce)
+
+ pt = b'5' * 300
+ ct = cipher.encrypt(pt)
+
+ output = bytearray(len(pt))
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ res = cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+ self.assertEqual(res, None)
+
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ res = cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+ self.assertEqual(res, None)
+
+ output = memoryview(bytearray(len(pt)))
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ cipher.encrypt(pt, output=output)
+ self.assertEqual(ct, output)
+
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ cipher.decrypt(ct, output=output)
+ self.assertEqual(pt, output)
+
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ self.assertRaises(TypeError, cipher.encrypt, pt, output=b'0'*len(pt))
+
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ self.assertRaises(TypeError, cipher.decrypt, ct, output=b'0'*len(ct))
+
+ shorter_output = bytearray(len(pt) - 1)
+
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ self.assertRaises(ValueError, cipher.encrypt, pt, output=shorter_output)
+
+ cipher = Salsa20.new(key=key, nonce=nonce)
+ self.assertRaises(ValueError, cipher.decrypt, ct, output=shorter_output)
+
+
+def get_tests(config={}):
+ tests = make_stream_tests(Salsa20, "Salsa20", test_data)
+ tests.append(KeyLength())
+ tests += list_test_cases(NonceTests)
+ tests.append(ByteArrayTest())
+ tests.append(MemoryviewTest())
+ tests.append(TestOutput())
+
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Cipher/test_pkcs1_15.py b/lib/Crypto/SelfTest/Cipher/test_pkcs1_15.py
new file mode 100644
index 0000000..e16a543
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_pkcs1_15.py
@@ -0,0 +1,283 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/test_pkcs1_15.py: Self-test for PKCS#1 v1.5 encryption
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from __future__ import print_function
+
+import unittest
+
+from Crypto.PublicKey import RSA
+from Crypto.SelfTest.st_common import list_test_cases, a2b_hex
+from Crypto import Random
+from Crypto.Cipher import PKCS1_v1_5 as PKCS
+from Crypto.Util.py3compat import b
+from Crypto.Util.number import bytes_to_long, long_to_bytes
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+
+
+def rws(t):
+ """Remove white spaces, tabs, and new lines from a string"""
+ for c in ['\n', '\t', ' ']:
+ t = t.replace(c, '')
+ return t
+
+
+def t2b(t):
+ """Convert a text string with bytes in hex form to a byte string"""
+ clean = b(rws(t))
+ if len(clean) % 2 == 1:
+ raise ValueError("Even number of characters expected")
+ return a2b_hex(clean)
+
+
+class PKCS1_15_Tests(unittest.TestCase):
+
+ def setUp(self):
+ self.rng = Random.new().read
+ self.key1024 = RSA.generate(1024, self.rng)
+
+ # List of tuples with test data for PKCS#1 v1.5.
+ # Each tuple is made up by:
+ # Item #0: dictionary with RSA key component, or key to import
+ # Item #1: plaintext
+ # Item #2: ciphertext
+ # Item #3: random data
+
+ _testData = (
+
+ #
+ # Generated with openssl 0.9.8o
+ #
+ (
+ # Private key
+ '''-----BEGIN RSA PRIVATE KEY-----
+MIICXAIBAAKBgQDAiAnvIAOvqVwJTaYzsKnefZftgtXGE2hPJppGsWl78yz9jeXY
+W/FxX/gTPURArNhdnhP6n3p2ZaDIBrO2zizbgIXs0IsljTTcr4vnI8fMXzyNUOjA
+zP3nzMqZDZK6757XQAobOssMkBFqRWwilT/3DsBhRpl3iMUhF+wvpTSHewIDAQAB
+AoGAC4HV/inOrpgTvSab8Wj0riyZgQOZ3U3ZpSlsfR8ra9Ib9Uee3jCYnKscu6Gk
+y6zI/cdt8EPJ4PuwAWSNJzbpbVaDvUq25OD+CX8/uRT08yBS4J8TzBitZJTD4lS7
+atdTnKT0Wmwk+u8tDbhvMKwnUHdJLcuIsycts9rwJVapUtkCQQDvDpx2JMun0YKG
+uUttjmL8oJ3U0m3ZvMdVwBecA0eebZb1l2J5PvI3EJD97eKe91Nsw8T3lwpoN40k
+IocSVDklAkEAzi1HLHE6EzVPOe5+Y0kGvrIYRRhncOb72vCvBZvD6wLZpQgqo6c4
+d3XHFBBQWA6xcvQb5w+VVEJZzw64y25sHwJBAMYReRl6SzL0qA0wIYrYWrOt8JeQ
+8mthulcWHXmqTgC6FEXP9Es5GD7/fuKl4wqLKZgIbH4nqvvGay7xXLCXD/ECQH9a
+1JYNMtRen5unSAbIOxRcKkWz92F0LKpm9ZW/S9vFHO+mBcClMGoKJHiuQxLBsLbT
+NtEZfSJZAeS2sUtn3/0CQDb2M2zNBTF8LlM0nxmh0k9VGm5TVIyBEMcipmvOgqIs
+HKukWBcq9f/UOmS0oEhai/6g+Uf7VHJdWaeO5LzuvwU=
+-----END RSA PRIVATE KEY-----''',
+ # Plaintext
+ '''THIS IS PLAINTEXT\x0A''',
+ # Ciphertext
+ '''3f dc fd 3c cd 5c 9b 12 af 65 32 e3 f7 d0 da 36
+ 8f 8f d9 e3 13 1c 7f c8 b3 f9 c1 08 e4 eb 79 9c
+ 91 89 1f 96 3b 94 77 61 99 a4 b1 ee 5d e6 17 c9
+ 5d 0a b5 63 52 0a eb 00 45 38 2a fb b0 71 3d 11
+ f7 a1 9e a7 69 b3 af 61 c0 bb 04 5b 5d 4b 27 44
+ 1f 5b 97 89 ba 6a 08 95 ee 4f a2 eb 56 64 e5 0f
+ da 7c f9 9a 61 61 06 62 ed a0 bc 5f aa 6c 31 78
+ 70 28 1a bb 98 3c e3 6a 60 3c d1 0b 0f 5a f4 75''',
+ # Random data
+ '''eb d7 7d 86 a4 35 23 a3 54 7e 02 0b 42 1d
+ 61 6c af 67 b8 4e 17 56 80 66 36 04 64 34 26 8a
+ 47 dd 44 b3 1a b2 17 60 f4 91 2e e2 b5 95 64 cc
+ f9 da c8 70 94 54 86 4c ef 5b 08 7d 18 c4 ab 8d
+ 04 06 33 8f ca 15 5f 52 60 8a a1 0c f5 08 b5 4c
+ bb 99 b8 94 25 04 9c e6 01 75 e6 f9 63 7a 65 61
+ 13 8a a7 47 77 81 ae 0d b8 2c 4d 50 a5'''
+ ),
+ )
+
+ def testEncrypt1(self):
+ for test in self._testData:
+ # Build the key
+ key = RSA.importKey(test[0])
+ # RNG that takes its random numbers from a pool given
+ # at initialization
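+            # PKCS#1 v1.5 encryption pads with random non-zero bytes, so the
+            # padding bytes recorded in the vector must be replayed to obtain
+            # a deterministic, comparable ciphertext.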
+ class randGen:
+ def __init__(self, data):
+ self.data = data
+ self.idx = 0
+ def __call__(self, N):
+ r = self.data[self.idx:self.idx+N]
+ self.idx += N
+ return r
+ # The real test
+ cipher = PKCS.new(key, randfunc=randGen(t2b(test[3])))
+ ct = cipher.encrypt(b(test[1]))
+ self.assertEqual(ct, t2b(test[2]))
+
+ def testEncrypt2(self):
+        # Verify that encryption fails if plaintext is too long
+ pt = '\x00'*(128-11+1)
+ cipher = PKCS.new(self.key1024)
+ self.assertRaises(ValueError, cipher.encrypt, pt)
+
+ def testVerify1(self):
+ for test in self._testData:
+ key = RSA.importKey(test[0])
+ expected_pt = b(test[1])
+ ct = t2b(test[2])
+ cipher = PKCS.new(key)
+
+ # The real test
+ pt = cipher.decrypt(ct, None)
+ self.assertEqual(pt, expected_pt)
+
+ pt = cipher.decrypt(ct, b'\xFF' * len(expected_pt))
+ self.assertEqual(pt, expected_pt)
+
+ def testVerify2(self):
+ # Verify that decryption fails if ciphertext is not as long as
+ # RSA modulus
+ cipher = PKCS.new(self.key1024)
+ self.assertRaises(ValueError, cipher.decrypt, '\x00'*127, "---")
+ self.assertRaises(ValueError, cipher.decrypt, '\x00'*129, "---")
+
+        # Verify that decryption fails if there are fewer than 8 non-zero
+        # padding bytes
+ pt = b('\x00\x02' + '\xFF'*7 + '\x00' + '\x45'*118)
+ pt_int = bytes_to_long(pt)
+ ct_int = self.key1024._encrypt(pt_int)
+ ct = long_to_bytes(ct_int, 128)
+ self.assertEqual(b"---", cipher.decrypt(ct, b"---"))
+
+ def testEncryptVerify1(self):
+ # Encrypt/Verify messages of length [0..RSAlen-11]
+ # and therefore padding [8..117]
+ for pt_len in range(0, 128 - 11 + 1):
+ pt = self.rng(pt_len)
+ cipher = PKCS.new(self.key1024)
+ ct = cipher.encrypt(pt)
+ pt2 = cipher.decrypt(ct, b'\xAA' * pt_len)
+ self.assertEqual(pt, pt2)
+
+ def test_encrypt_verify_exp_pt_len(self):
+
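+        # The optional third argument of decrypt() is the expected plaintext
+        # length; if the recovered plaintext has a different length, the
+        # sentinel is returned instead.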
+ cipher = PKCS.new(self.key1024)
+ pt = b'5' * 16
+ ct = cipher.encrypt(pt)
+ sentinel = b'\xAA' * 16
+
+ pt_A = cipher.decrypt(ct, sentinel, 16)
+ self.assertEqual(pt, pt_A)
+
+ pt_B = cipher.decrypt(ct, sentinel, 15)
+ self.assertEqual(sentinel, pt_B)
+
+ pt_C = cipher.decrypt(ct, sentinel, 17)
+ self.assertEqual(sentinel, pt_C)
+
+ def testByteArray(self):
+ pt = b"XER"
+ cipher = PKCS.new(self.key1024)
+ ct = cipher.encrypt(bytearray(pt))
+        pt2 = cipher.decrypt(bytearray(ct), b'\xFF' * len(pt))
+ self.assertEqual(pt, pt2)
+
+ def testMemoryview(self):
+ pt = b"XER"
+ cipher = PKCS.new(self.key1024)
+ ct = cipher.encrypt(memoryview(bytearray(pt)))
+ pt2 = cipher.decrypt(memoryview(bytearray(ct)), b'\xFF' * len(pt))
+ self.assertEqual(pt, pt2)
+
+ def test_return_type(self):
+ pt = b"XYZ"
+ cipher = PKCS.new(self.key1024)
+ ct = cipher.encrypt(pt)
+ self.assertTrue(isinstance(ct, bytes))
+ pt2 = cipher.decrypt(ct, b'\xAA' * 3)
+ self.assertTrue(isinstance(pt2, bytes))
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings, skip_slow_tests):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._skip_slow_tests = skip_slow_tests
+ self._id = "None"
+
+ def load_tests(self, filename):
+
+ def filter_rsa(group):
+ return RSA.import_key(group['privateKeyPem'])
+
+ result = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ filename,
+ "Wycheproof PKCS#1v1.5 (%s)" % filename,
+ group_tag={'rsa_key': filter_rsa}
+ )
+ return result
+
+ def setUp(self):
+ self.tv = []
+ self.tv.extend(self.load_tests("rsa_pkcs1_2048_test.json"))
+ if not self._skip_slow_tests:
+ self.tv.extend(self.load_tests("rsa_pkcs1_3072_test.json"))
+ self.tv.extend(self.load_tests("rsa_pkcs1_4096_test.json"))
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_decrypt(self, tv):
+ self._id = "Wycheproof Decrypt PKCS#1v1.5 Test #%s" % tv.id
+ sentinel = b'\xAA' * max(3, len(tv.msg))
+ cipher = PKCS.new(tv.rsa_key)
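+        # A padding failure makes decrypt() return the sentinel rather than
+        # raise, so a sentinel result is only acceptable for invalid vectors.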
+ try:
+ pt = cipher.decrypt(tv.ct, sentinel=sentinel)
+ except ValueError:
+ assert not tv.valid
+ else:
+ if pt == sentinel:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+ self.warn(tv)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_decrypt(tv)
+
+
+def get_tests(config={}):
+ skip_slow_tests = not config.get('slow_tests')
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(PKCS1_15_Tests)
+ tests += [TestVectorsWycheproof(wycheproof_warnings, skip_slow_tests)]
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Cipher/test_pkcs1_oaep.py b/lib/Crypto/SelfTest/Cipher/test_pkcs1_oaep.py
new file mode 100644
index 0000000..1711581
--- /dev/null
+++ b/lib/Crypto/SelfTest/Cipher/test_pkcs1_oaep.py
@@ -0,0 +1,506 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Cipher/test_pkcs1_oaep.py: Self-test for PKCS#1 OAEP encryption
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+import unittest
+
+from Crypto.SelfTest.st_common import list_test_cases, a2b_hex
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+
+from Crypto.PublicKey import RSA
+from Crypto.Cipher import PKCS1_OAEP as PKCS
+from Crypto.Hash import MD2, MD5, SHA1, SHA256, RIPEMD160, SHA224, SHA384, SHA512
+from Crypto import Random
+from Crypto.Signature.pss import MGF1
+
+from Crypto.Util.py3compat import b, bchr
+
+
+def rws(t):
+ """Remove white spaces, tabs, and new lines from a string"""
+ for c in ['\n', '\t', ' ']:
+ t = t.replace(c, '')
+ return t
+
+
+def t2b(t):
+ """Convert a text string with bytes in hex form to a byte string"""
+ clean = rws(t)
+ if len(clean) % 2 == 1:
+ raise ValueError("Even number of characters expected")
+ return a2b_hex(clean)
+
+
+class PKCS1_OAEP_Tests(unittest.TestCase):
+
+ def setUp(self):
+ self.rng = Random.new().read
+ self.key1024 = RSA.generate(1024, self.rng)
+
+ # List of tuples with test data for PKCS#1 OAEP
+ # Each tuple is made up by:
+ # Item #0: dictionary with RSA key component
+ # Item #1: plaintext
+ # Item #2: ciphertext
+ # Item #3: random data (=seed)
+ # Item #4: hash object
+
+ _testData = (
+
+ #
+ # From in oaep-int.txt to be found in
+ # ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip
+ #
+ (
+ # Private key
+ {
+ 'n':'''bb f8 2f 09 06 82 ce 9c 23 38 ac 2b 9d a8 71 f7
+ 36 8d 07 ee d4 10 43 a4 40 d6 b6 f0 74 54 f5 1f
+ b8 df ba af 03 5c 02 ab 61 ea 48 ce eb 6f cd 48
+ 76 ed 52 0d 60 e1 ec 46 19 71 9d 8a 5b 8b 80 7f
+ af b8 e0 a3 df c7 37 72 3e e6 b4 b7 d9 3a 25 84
+ ee 6a 64 9d 06 09 53 74 88 34 b2 45 45 98 39 4e
+ e0 aa b1 2d 7b 61 a5 1f 52 7a 9a 41 f6 c1 68 7f
+ e2 53 72 98 ca 2a 8f 59 46 f8 e5 fd 09 1d bd cb''',
+ # Public key
+ 'e':'11',
+ # In the test vector, only p and q were given...
+ # d is computed offline as e^{-1} mod (p-1)(q-1)
+ 'd':'''a5dafc5341faf289c4b988db30c1cdf83f31251e0
+ 668b42784813801579641b29410b3c7998d6bc465745e5c3
+ 92669d6870da2c082a939e37fdcb82ec93edac97ff3ad595
+ 0accfbc111c76f1a9529444e56aaf68c56c092cd38dc3bef
+ 5d20a939926ed4f74a13eddfbe1a1cecc4894af9428c2b7b
+ 8883fe4463a4bc85b1cb3c1'''
+ }
+ ,
+ # Plaintext
+ '''d4 36 e9 95 69 fd 32 a7 c8 a0 5b bc 90 d3 2c 49''',
+ # Ciphertext
+ '''12 53 e0 4d c0 a5 39 7b b4 4a 7a b8 7e 9b f2 a0
+ 39 a3 3d 1e 99 6f c8 2a 94 cc d3 00 74 c9 5d f7
+ 63 72 20 17 06 9e 52 68 da 5d 1c 0b 4f 87 2c f6
+ 53 c1 1d f8 23 14 a6 79 68 df ea e2 8d ef 04 bb
+ 6d 84 b1 c3 1d 65 4a 19 70 e5 78 3b d6 eb 96 a0
+ 24 c2 ca 2f 4a 90 fe 9f 2e f5 c9 c1 40 e5 bb 48
+ da 95 36 ad 87 00 c8 4f c9 13 0a de a7 4e 55 8d
+ 51 a7 4d df 85 d8 b5 0d e9 68 38 d6 06 3e 09 55''',
+ # Random
+ '''aa fd 12 f6 59 ca e6 34 89 b4 79 e5 07 6d de c2
+ f0 6c b5 8f''',
+ # Hash
+ SHA1,
+ ),
+
+ #
+ # From in oaep-vect.txt to be found in Example 1.1
+ # ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip
+ #
+ (
+ # Private key
+ {
+ 'n':'''a8 b3 b2 84 af 8e b5 0b 38 70 34 a8 60 f1 46 c4
+ 91 9f 31 87 63 cd 6c 55 98 c8 ae 48 11 a1 e0 ab
+ c4 c7 e0 b0 82 d6 93 a5 e7 fc ed 67 5c f4 66 85
+ 12 77 2c 0c bc 64 a7 42 c6 c6 30 f5 33 c8 cc 72
+ f6 2a e8 33 c4 0b f2 58 42 e9 84 bb 78 bd bf 97
+ c0 10 7d 55 bd b6 62 f5 c4 e0 fa b9 84 5c b5 14
+ 8e f7 39 2d d3 aa ff 93 ae 1e 6b 66 7b b3 d4 24
+ 76 16 d4 f5 ba 10 d4 cf d2 26 de 88 d3 9f 16 fb''',
+ 'e':'''01 00 01''',
+ 'd':'''53 33 9c fd b7 9f c8 46 6a 65 5c 73 16 ac a8 5c
+ 55 fd 8f 6d d8 98 fd af 11 95 17 ef 4f 52 e8 fd
+ 8e 25 8d f9 3f ee 18 0f a0 e4 ab 29 69 3c d8 3b
+ 15 2a 55 3d 4a c4 d1 81 2b 8b 9f a5 af 0e 7f 55
+ fe 73 04 df 41 57 09 26 f3 31 1f 15 c4 d6 5a 73
+ 2c 48 31 16 ee 3d 3d 2d 0a f3 54 9a d9 bf 7c bf
+ b7 8a d8 84 f8 4d 5b eb 04 72 4d c7 36 9b 31 de
+ f3 7d 0c f5 39 e9 cf cd d3 de 65 37 29 ea d5 d1 '''
+ }
+ ,
+ # Plaintext
+ '''66 28 19 4e 12 07 3d b0 3b a9 4c da 9e f9 53 23
+ 97 d5 0d ba 79 b9 87 00 4a fe fe 34''',
+ # Ciphertext
+ '''35 4f e6 7b 4a 12 6d 5d 35 fe 36 c7 77 79 1a 3f
+ 7b a1 3d ef 48 4e 2d 39 08 af f7 22 fa d4 68 fb
+ 21 69 6d e9 5d 0b e9 11 c2 d3 17 4f 8a fc c2 01
+ 03 5f 7b 6d 8e 69 40 2d e5 45 16 18 c2 1a 53 5f
+ a9 d7 bf c5 b8 dd 9f c2 43 f8 cf 92 7d b3 13 22
+ d6 e8 81 ea a9 1a 99 61 70 e6 57 a0 5a 26 64 26
+ d9 8c 88 00 3f 84 77 c1 22 70 94 a0 d9 fa 1e 8c
+ 40 24 30 9c e1 ec cc b5 21 00 35 d4 7a c7 2e 8a''',
+ # Random
+ '''18 b7 76 ea 21 06 9d 69 77 6a 33 e9 6b ad 48 e1
+ dd a0 a5 ef''',
+ SHA1
+ ),
+
+ #
+ # From in oaep-vect.txt to be found in Example 2.1
+ # ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip
+ #
+ (
+ # Private key
+ {
+ 'n':'''01 94 7c 7f ce 90 42 5f 47 27 9e 70 85 1f 25 d5
+ e6 23 16 fe 8a 1d f1 93 71 e3 e6 28 e2 60 54 3e
+ 49 01 ef 60 81 f6 8c 0b 81 41 19 0d 2a e8 da ba
+ 7d 12 50 ec 6d b6 36 e9 44 ec 37 22 87 7c 7c 1d
+ 0a 67 f1 4b 16 94 c5 f0 37 94 51 a4 3e 49 a3 2d
+ de 83 67 0b 73 da 91 a1 c9 9b c2 3b 43 6a 60 05
+ 5c 61 0f 0b af 99 c1 a0 79 56 5b 95 a3 f1 52 66
+ 32 d1 d4 da 60 f2 0e da 25 e6 53 c4 f0 02 76 6f
+ 45''',
+ 'e':'''01 00 01''',
+ 'd':'''08 23 f2 0f ad b5 da 89 08 8a 9d 00 89 3e 21 fa
+ 4a 1b 11 fb c9 3c 64 a3 be 0b aa ea 97 fb 3b 93
+ c3 ff 71 37 04 c1 9c 96 3c 1d 10 7a ae 99 05 47
+ 39 f7 9e 02 e1 86 de 86 f8 7a 6d de fe a6 d8 cc
+ d1 d3 c8 1a 47 bf a7 25 5b e2 06 01 a4 a4 b2 f0
+ 8a 16 7b 5e 27 9d 71 5b 1b 45 5b dd 7e ab 24 59
+ 41 d9 76 8b 9a ce fb 3c cd a5 95 2d a3 ce e7 25
+ 25 b4 50 16 63 a8 ee 15 c9 e9 92 d9 24 62 fe 39'''
+ },
+ # Plaintext
+ '''8f f0 0c aa 60 5c 70 28 30 63 4d 9a 6c 3d 42 c6
+ 52 b5 8c f1 d9 2f ec 57 0b ee e7''',
+ # Ciphertext
+ '''01 81 af 89 22 b9 fc b4 d7 9d 92 eb e1 98 15 99
+ 2f c0 c1 43 9d 8b cd 49 13 98 a0 f4 ad 3a 32 9a
+ 5b d9 38 55 60 db 53 26 83 c8 b7 da 04 e4 b1 2a
+ ed 6a ac df 47 1c 34 c9 cd a8 91 ad dc c2 df 34
+ 56 65 3a a6 38 2e 9a e5 9b 54 45 52 57 eb 09 9d
+ 56 2b be 10 45 3f 2b 6d 13 c5 9c 02 e1 0f 1f 8a
+ bb 5d a0 d0 57 09 32 da cf 2d 09 01 db 72 9d 0f
+ ef cc 05 4e 70 96 8e a5 40 c8 1b 04 bc ae fe 72
+ 0e''',
+ # Random
+ '''8c 40 7b 5e c2 89 9e 50 99 c5 3e 8c e7 93 bf 94
+ e7 1b 17 82''',
+ SHA1
+ ),
+
+ #
+ # From in oaep-vect.txt to be found in Example 10.1
+ # ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip
+ #
+ (
+ # Private key
+ {
+ 'n':'''ae 45 ed 56 01 ce c6 b8 cc 05 f8 03 93 5c 67 4d
+ db e0 d7 5c 4c 09 fd 79 51 fc 6b 0c ae c3 13 a8
+ df 39 97 0c 51 8b ff ba 5e d6 8f 3f 0d 7f 22 a4
+ 02 9d 41 3f 1a e0 7e 4e be 9e 41 77 ce 23 e7 f5
+ 40 4b 56 9e 4e e1 bd cf 3c 1f b0 3e f1 13 80 2d
+ 4f 85 5e b9 b5 13 4b 5a 7c 80 85 ad ca e6 fa 2f
+ a1 41 7e c3 76 3b e1 71 b0 c6 2b 76 0e de 23 c1
+ 2a d9 2b 98 08 84 c6 41 f5 a8 fa c2 6b da d4 a0
+ 33 81 a2 2f e1 b7 54 88 50 94 c8 25 06 d4 01 9a
+ 53 5a 28 6a fe b2 71 bb 9b a5 92 de 18 dc f6 00
+ c2 ae ea e5 6e 02 f7 cf 79 fc 14 cf 3b dc 7c d8
+ 4f eb bb f9 50 ca 90 30 4b 22 19 a7 aa 06 3a ef
+ a2 c3 c1 98 0e 56 0c d6 4a fe 77 95 85 b6 10 76
+ 57 b9 57 85 7e fd e6 01 09 88 ab 7d e4 17 fc 88
+ d8 f3 84 c4 e6 e7 2c 3f 94 3e 0c 31 c0 c4 a5 cc
+ 36 f8 79 d8 a3 ac 9d 7d 59 86 0e aa da 6b 83 bb''',
+ 'e':'''01 00 01''',
+ 'd':'''05 6b 04 21 6f e5 f3 54 ac 77 25 0a 4b 6b 0c 85
+ 25 a8 5c 59 b0 bd 80 c5 64 50 a2 2d 5f 43 8e 59
+ 6a 33 3a a8 75 e2 91 dd 43 f4 8c b8 8b 9d 5f c0
+ d4 99 f9 fc d1 c3 97 f9 af c0 70 cd 9e 39 8c 8d
+ 19 e6 1d b7 c7 41 0a 6b 26 75 df bf 5d 34 5b 80
+ 4d 20 1a dd 50 2d 5c e2 df cb 09 1c e9 99 7b be
+ be 57 30 6f 38 3e 4d 58 81 03 f0 36 f7 e8 5d 19
+ 34 d1 52 a3 23 e4 a8 db 45 1d 6f 4a 5b 1b 0f 10
+ 2c c1 50 e0 2f ee e2 b8 8d ea 4a d4 c1 ba cc b2
+ 4d 84 07 2d 14 e1 d2 4a 67 71 f7 40 8e e3 05 64
+ fb 86 d4 39 3a 34 bc f0 b7 88 50 1d 19 33 03 f1
+ 3a 22 84 b0 01 f0 f6 49 ea f7 93 28 d4 ac 5c 43
+ 0a b4 41 49 20 a9 46 0e d1 b7 bc 40 ec 65 3e 87
+ 6d 09 ab c5 09 ae 45 b5 25 19 01 16 a0 c2 61 01
+ 84 82 98 50 9c 1c 3b f3 a4 83 e7 27 40 54 e1 5e
+ 97 07 50 36 e9 89 f6 09 32 80 7b 52 57 75 1e 79'''
+ },
+ # Plaintext
+ '''8b ba 6b f8 2a 6c 0f 86 d5 f1 75 6e 97 95 68 70
+ b0 89 53 b0 6b 4e b2 05 bc 16 94 ee''',
+ # Ciphertext
+ '''53 ea 5d c0 8c d2 60 fb 3b 85 85 67 28 7f a9 15
+ 52 c3 0b 2f eb fb a2 13 f0 ae 87 70 2d 06 8d 19
+ ba b0 7f e5 74 52 3d fb 42 13 9d 68 c3 c5 af ee
+ e0 bf e4 cb 79 69 cb f3 82 b8 04 d6 e6 13 96 14
+ 4e 2d 0e 60 74 1f 89 93 c3 01 4b 58 b9 b1 95 7a
+ 8b ab cd 23 af 85 4f 4c 35 6f b1 66 2a a7 2b fc
+ c7 e5 86 55 9d c4 28 0d 16 0c 12 67 85 a7 23 eb
+ ee be ff 71 f1 15 94 44 0a ae f8 7d 10 79 3a 87
+ 74 a2 39 d4 a0 4c 87 fe 14 67 b9 da f8 52 08 ec
+ 6c 72 55 79 4a 96 cc 29 14 2f 9a 8b d4 18 e3 c1
+ fd 67 34 4b 0c d0 82 9d f3 b2 be c6 02 53 19 62
+ 93 c6 b3 4d 3f 75 d3 2f 21 3d d4 5c 62 73 d5 05
+ ad f4 cc ed 10 57 cb 75 8f c2 6a ee fa 44 12 55
+ ed 4e 64 c1 99 ee 07 5e 7f 16 64 61 82 fd b4 64
+ 73 9b 68 ab 5d af f0 e6 3e 95 52 01 68 24 f0 54
+ bf 4d 3c 8c 90 a9 7b b6 b6 55 32 84 eb 42 9f cc''',
+ # Random
+ '''47 e1 ab 71 19 fe e5 6c 95 ee 5e aa d8 6f 40 d0
+ aa 63 bd 33''',
+ SHA1
+ ),
+ )
+
+ def testEncrypt1(self):
+ # Verify encryption using all test vectors
+ for test in self._testData:
+ # Build the key
+ comps = [int(rws(test[0][x]), 16) for x in ('n', 'e')]
+ key = RSA.construct(comps)
+
+ # RNG that takes its random numbers from a pool given
+ # at initialization
+ class randGen:
+
+ def __init__(self, data):
+ self.data = data
+ self.idx = 0
+
+ def __call__(self, N):
+                    r = self.data[self.idx:self.idx+N]
+ self.idx += N
+ return r
+
+ # The real test
+ cipher = PKCS.new(key, test[4], randfunc=randGen(t2b(test[3])))
+ ct = cipher.encrypt(t2b(test[1]))
+ self.assertEqual(ct, t2b(test[2]))
+
+ def testEncrypt2(self):
+ # Verify that encryption fails if plaintext is too long
+ pt = '\x00'*(128-2*20-2+1)
+ cipher = PKCS.new(self.key1024)
+ self.assertRaises(ValueError, cipher.encrypt, pt)
+
+ def testDecrypt1(self):
+ # Verify decryption using all test vectors
+ for test in self._testData:
+ # Build the key
+ comps = [int(rws(test[0][x]),16) for x in ('n', 'e', 'd')]
+ key = RSA.construct(comps)
+ # The real test
+ cipher = PKCS.new(key, test[4])
+ pt = cipher.decrypt(t2b(test[2]))
+ self.assertEqual(pt, t2b(test[1]))
+
+ def testDecrypt2(self):
+ # Simplest possible negative tests
+ for ct_size in (127, 128, 129):
+ cipher = PKCS.new(self.key1024)
+ self.assertRaises(ValueError, cipher.decrypt, bchr(0x00)*ct_size)
+
+ def testEncryptDecrypt1(self):
+ # Encrypt/Decrypt messages of length [0..128-2*20-2]
+        for pt_len in range(0, 128-2*20-2+1):
+ pt = self.rng(pt_len)
+ cipher = PKCS.new(self.key1024)
+ ct = cipher.encrypt(pt)
+ pt2 = cipher.decrypt(ct)
+ self.assertEqual(pt, pt2)
+
+ def testEncryptDecrypt2(self):
+ # Helper function to monitor what's requested from RNG
+ global asked
+
+ def localRng(N):
+ global asked
+ asked += N
+ return self.rng(N)
+
+ # Verify that OAEP is friendly to all hashes
+ for hashmod in (MD2, MD5, SHA1, SHA256, RIPEMD160):
+ # Verify that encrypt() asks for as many random bytes
+ # as the hash output size
+ asked = 0
+ pt = self.rng(40)
+ cipher = PKCS.new(self.key1024, hashmod, randfunc=localRng)
+ ct = cipher.encrypt(pt)
+ self.assertEqual(cipher.decrypt(ct), pt)
+ self.assertEqual(asked, hashmod.digest_size)
+
+ def testEncryptDecrypt3(self):
+ # Verify that OAEP supports labels
+ pt = self.rng(35)
+ xlabel = self.rng(22)
+ cipher = PKCS.new(self.key1024, label=xlabel)
+ ct = cipher.encrypt(pt)
+ self.assertEqual(cipher.decrypt(ct), pt)
+
+ def testEncryptDecrypt4(self):
+ # Verify that encrypt() uses the custom MGF
+ global mgfcalls
+ # Helper function to monitor what's requested from MGF
+
+ def newMGF(seed, maskLen):
+ global mgfcalls
+ mgfcalls += 1
+ return b'\x00' * maskLen
+
+ mgfcalls = 0
+ pt = self.rng(32)
+ cipher = PKCS.new(self.key1024, mgfunc=newMGF)
+ ct = cipher.encrypt(pt)
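+        # OAEP applies the MGF twice per encryption: once to mask the data
+        # block and once to mask the seed.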
+ self.assertEqual(mgfcalls, 2)
+ self.assertEqual(cipher.decrypt(ct), pt)
+
+ def testByteArray(self):
+ pt = b("XER")
+ cipher = PKCS.new(self.key1024)
+ ct = cipher.encrypt(bytearray(pt))
+ pt2 = cipher.decrypt(bytearray(ct))
+ self.assertEqual(pt, pt2)
+
+ def testMemoryview(self):
+ pt = b("XER")
+ cipher = PKCS.new(self.key1024)
+ ct = cipher.encrypt(memoryview(bytearray(pt)))
+ pt2 = cipher.decrypt(memoryview(bytearray(ct)))
+ self.assertEqual(pt, pt2)
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings, skip_slow_tests):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._skip_slow_tests = skip_slow_tests
+ self._id = "None"
+
+ def load_tests(self, filename):
+
+ def filter_rsa(group):
+ return RSA.import_key(group['privateKeyPem'])
+
+ def filter_sha(group):
+ if group['sha'] == "SHA-1":
+ return SHA1
+ elif group['sha'] == "SHA-224":
+ return SHA224
+ elif group['sha'] == "SHA-256":
+ return SHA256
+ elif group['sha'] == "SHA-384":
+ return SHA384
+ elif group['sha'] == "SHA-512":
+ return SHA512
+ else:
+ raise ValueError("Unknown sha " + group['sha'])
+
+ def filter_mgf(group):
+ if group['mgfSha'] == "SHA-1":
+ return lambda x, y: MGF1(x, y, SHA1)
+ elif group['mgfSha'] == "SHA-224":
+ return lambda x, y: MGF1(x, y, SHA224)
+ elif group['mgfSha'] == "SHA-256":
+ return lambda x, y: MGF1(x, y, SHA256)
+ elif group['mgfSha'] == "SHA-384":
+ return lambda x, y: MGF1(x, y, SHA384)
+ elif group['mgfSha'] == "SHA-512":
+ return lambda x, y: MGF1(x, y, SHA512)
+ else:
+ raise ValueError("Unknown mgf/sha " + group['mgfSha'])
+
+ def filter_algo(group):
+ return "%s with MGF1/%s" % (group['sha'], group['mgfSha'])
+
+ result = load_test_vectors_wycheproof(("Cipher", "wycheproof"),
+ filename,
+ "Wycheproof PKCS#1 OAEP (%s)" % filename,
+ group_tag={'rsa_key': filter_rsa,
+ 'hash_mod': filter_sha,
+ 'mgf': filter_mgf,
+ 'algo': filter_algo}
+ )
+ return result
+
+ def setUp(self):
+ self.tv = []
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha1_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha224_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha224_mgf1sha224_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha256_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha256_mgf1sha256_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha384_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha384_mgf1sha384_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha512_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_2048_sha512_mgf1sha512_test.json"))
+ if not self._skip_slow_tests:
+ self.tv.extend(self.load_tests("rsa_oaep_3072_sha256_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_3072_sha256_mgf1sha256_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_3072_sha512_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_3072_sha512_mgf1sha512_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_4096_sha256_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_4096_sha256_mgf1sha256_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_4096_sha512_mgf1sha1_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_4096_sha512_mgf1sha512_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_4096_sha512_mgf1sha512_test.json"))
+ self.tv.extend(self.load_tests("rsa_oaep_misc_test.json"))
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_decrypt(self, tv):
+ self._id = "Wycheproof Decrypt %s Test #%s" % (tv.algo, tv.id)
+
+ cipher = PKCS.new(tv.rsa_key, hashAlgo=tv.hash_mod, mgfunc=tv.mgf, label=tv.label)
+ try:
+ pt = cipher.decrypt(tv.ct)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.assertEqual(pt, tv.msg)
+ self.warn(tv)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_decrypt(tv)
+
+
+def get_tests(config={}):
+ skip_slow_tests = not config.get('slow_tests')
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(PKCS1_OAEP_Tests)
+ tests += [TestVectorsWycheproof(wycheproof_warnings, skip_slow_tests)]
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+        return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/__init__.py b/lib/Crypto/SelfTest/Hash/__init__.py
new file mode 100644
index 0000000..008a810
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/__init__.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/__init__.py: Self-test for hash modules
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test for hash modules"""
+
+__revision__ = "$Id$"
+
+def get_tests(config={}):
+ tests = []
+ from Crypto.SelfTest.Hash import test_HMAC; tests += test_HMAC.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_CMAC; tests += test_CMAC.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_MD2; tests += test_MD2.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_MD4; tests += test_MD4.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_MD5; tests += test_MD5.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_RIPEMD160; tests += test_RIPEMD160.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA1; tests += test_SHA1.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA224; tests += test_SHA224.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA256; tests += test_SHA256.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA384; tests += test_SHA384.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA512; tests += test_SHA512.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA3_224; tests += test_SHA3_224.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA3_256; tests += test_SHA3_256.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA3_384; tests += test_SHA3_384.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHA3_512; tests += test_SHA3_512.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_keccak; tests += test_keccak.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_SHAKE; tests += test_SHAKE.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_BLAKE2; tests += test_BLAKE2.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_Poly1305; tests += test_Poly1305.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_cSHAKE; tests += test_cSHAKE.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_KMAC; tests += test_KMAC.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_TupleHash; tests += test_TupleHash.get_tests(config=config)
+ from Crypto.SelfTest.Hash import test_KangarooTwelve; tests += test_KangarooTwelve.get_tests(config=config)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/common.py b/lib/Crypto/SelfTest/Hash/common.py
new file mode 100644
index 0000000..1578667
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/common.py
@@ -0,0 +1,290 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/common.py: Common code for Crypto.SelfTest.Hash
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-testing for PyCrypto hash modules"""
+
+import re
+import sys
+import unittest
+import binascii
+import Crypto.Hash
+from binascii import hexlify, unhexlify
+from Crypto.Util.py3compat import b, tobytes
+from Crypto.Util.strxor import strxor_c
+
+def t2b(hex_string):
+ shorter = re.sub(br'\s+', b'', tobytes(hex_string))
+ return unhexlify(shorter)
+
+
+class HashDigestSizeSelfTest(unittest.TestCase):
+
+ def __init__(self, hashmod, description, expected, extra_params):
+ unittest.TestCase.__init__(self)
+ self.hashmod = hashmod
+ self.expected = expected
+ self.description = description
+ self.extra_params = extra_params
+
+ def shortDescription(self):
+ return self.description
+
+ def runTest(self):
+ if "truncate" not in self.extra_params:
+ self.assertTrue(hasattr(self.hashmod, "digest_size"))
+ self.assertEqual(self.hashmod.digest_size, self.expected)
+ h = self.hashmod.new(**self.extra_params)
+ self.assertTrue(hasattr(h, "digest_size"))
+ self.assertEqual(h.digest_size, self.expected)
+
+
+class HashSelfTest(unittest.TestCase):
+
+ def __init__(self, hashmod, description, expected, input, extra_params):
+ unittest.TestCase.__init__(self)
+ self.hashmod = hashmod
+ self.expected = expected.lower()
+ self.input = input
+ self.description = description
+ self.extra_params = extra_params
+
+ def shortDescription(self):
+ return self.description
+
+ def runTest(self):
+ h = self.hashmod.new(**self.extra_params)
+ h.update(self.input)
+
+ out1 = binascii.b2a_hex(h.digest())
+ out2 = h.hexdigest()
+
+ h = self.hashmod.new(self.input, **self.extra_params)
+
+ out3 = h.hexdigest()
+ out4 = binascii.b2a_hex(h.digest())
+
+ # PY3K: hexdigest() should return str(), and digest() bytes
+ self.assertEqual(self.expected, out1) # h = .new(); h.update(data); h.digest()
+ if sys.version_info[0] == 2:
+ self.assertEqual(self.expected, out2) # h = .new(); h.update(data); h.hexdigest()
+ self.assertEqual(self.expected, out3) # h = .new(data); h.hexdigest()
+ else:
+ self.assertEqual(self.expected.decode(), out2) # h = .new(); h.update(data); h.hexdigest()
+ self.assertEqual(self.expected.decode(), out3) # h = .new(data); h.hexdigest()
+ self.assertEqual(self.expected, out4) # h = .new(data); h.digest()
+
+ # Verify that the .new() method produces a fresh hash object, except
+ # for MD5 and SHA1, which are hashlib objects. (But test any .new()
+ # method that does exist.)
+ if self.hashmod.__name__ not in ('Crypto.Hash.MD5', 'Crypto.Hash.SHA1') or hasattr(h, 'new'):
+ h2 = h.new()
+ h2.update(self.input)
+ out5 = binascii.b2a_hex(h2.digest())
+ self.assertEqual(self.expected, out5)
+
+
+class HashTestOID(unittest.TestCase):
+ def __init__(self, hashmod, oid, extra_params):
+ unittest.TestCase.__init__(self)
+ self.hashmod = hashmod
+ self.oid = oid
+ self.extra_params = extra_params
+
+ def runTest(self):
+ h = self.hashmod.new(**self.extra_params)
+ self.assertEqual(h.oid, self.oid)
+
+
+class ByteArrayTest(unittest.TestCase):
+
+ def __init__(self, module, extra_params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+ self.extra_params = extra_params
+
+ def runTest(self):
+ data = b("\x00\x01\x02")
+
+ # Data can be a bytearray (during initialization)
+ ba = bytearray(data)
+
+ h1 = self.module.new(data, **self.extra_params)
+ h2 = self.module.new(ba, **self.extra_params)
+ ba[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a bytearray (during operation)
+ ba = bytearray(data)
+
+ h1 = self.module.new(**self.extra_params)
+ h2 = self.module.new(**self.extra_params)
+
+ h1.update(data)
+ h2.update(ba)
+
+ ba[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class MemoryViewTest(unittest.TestCase):
+
+ def __init__(self, module, extra_params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+ self.extra_params = extra_params
+
+ def runTest(self):
+
+ data = b"\x00\x01\x02"
+
+ def get_mv_ro(data):
+ return memoryview(data)
+
+ def get_mv_rw(data):
+ return memoryview(bytearray(data))
+
+ for get_mv in get_mv_ro, get_mv_rw:
+
+ # Data can be a memoryview (during initialization)
+ mv = get_mv(data)
+
+ h1 = self.module.new(data, **self.extra_params)
+ h2 = self.module.new(mv, **self.extra_params)
+ if not mv.readonly:
+ mv[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a memoryview (during operation)
+ mv = get_mv(data)
+
+ h1 = self.module.new(**self.extra_params)
+ h2 = self.module.new(**self.extra_params)
+ h1.update(data)
+ h2.update(mv)
+ if not mv.readonly:
+ mv[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class MACSelfTest(unittest.TestCase):
+
+ def __init__(self, module, description, result, data, key, params):
+ unittest.TestCase.__init__(self)
+ self.module = module
+ self.result = t2b(result)
+ self.data = t2b(data)
+ self.key = t2b(key)
+ self.params = params
+ self.description = description
+
+ def shortDescription(self):
+ return self.description
+
+ def runTest(self):
+
+ result_hex = hexlify(self.result)
+
+ # Verify result
+ h = self.module.new(self.key, **self.params)
+ h.update(self.data)
+ self.assertEqual(self.result, h.digest())
+ self.assertEqual(hexlify(self.result).decode('ascii'), h.hexdigest())
+
+ # Verify that correct MAC does not raise any exception
+ h.verify(self.result)
+ h.hexverify(result_hex)
+
+ # Verify that incorrect MAC does raise ValueError exception
+ wrong_mac = strxor_c(self.result, 255)
+ self.assertRaises(ValueError, h.verify, wrong_mac)
+ self.assertRaises(ValueError, h.hexverify, "4556")
+
+ # Verify again, with data passed to new()
+ h = self.module.new(self.key, self.data, **self.params)
+ self.assertEqual(self.result, h.digest())
+ self.assertEqual(hexlify(self.result).decode('ascii'), h.hexdigest())
+
+ # Test .copy()
+ try:
+ h = self.module.new(self.key, self.data, **self.params)
+ h2 = h.copy()
+ h3 = h.copy()
+
+ # Verify that changing the copy does not change the original
+ h2.update(b"bla")
+ self.assertEqual(h3.digest(), self.result)
+
+ # Verify that both can reach the same state
+ h.update(b"bla")
+ self.assertEqual(h.digest(), h2.digest())
+ except NotImplementedError:
+ pass
+
+ # PY3K: Check that hexdigest() returns str and digest() returns bytes
+ self.assertTrue(isinstance(h.digest(), type(b"")))
+ self.assertTrue(isinstance(h.hexdigest(), type("")))
+
+ # PY3K: Check that .hexverify() accepts bytes or str
+ h.hexverify(h.hexdigest())
+ h.hexverify(h.hexdigest().encode('ascii'))
+
+
+def make_hash_tests(module, module_name, test_data, digest_size, oid=None,
+ extra_params={}):
+ tests = []
+ for i in range(len(test_data)):
+ row = test_data[i]
+ (expected, input) = map(tobytes,row[0:2])
+ if len(row) < 3:
+ description = repr(input)
+ else:
+ description = row[2]
+ name = "%s #%d: %s" % (module_name, i+1, description)
+ tests.append(HashSelfTest(module, name, expected, input, extra_params))
+
+ name = "%s #%d: digest_size" % (module_name, len(test_data) + 1)
+ tests.append(HashDigestSizeSelfTest(module, name, digest_size, extra_params))
+
+ if oid is not None:
+ tests.append(HashTestOID(module, oid, extra_params))
+
+ tests.append(ByteArrayTest(module, extra_params))
+
+ tests.append(MemoryViewTest(module, extra_params))
+
+ return tests
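+
+# A minimal sketch (not part of this module) of how a per-algorithm test file
+# might call make_hash_tests(); the MD5 vector below is only illustrative:
+#
+#     from Crypto.Hash import MD5
+#     test_data = [
+#         ("d41d8cd98f00b204e9800998ecf8427e", "", "'' (empty string)"),
+#     ]
+#     def get_tests(config={}):
+#         return make_hash_tests(MD5, "MD5", test_data,
+#                                digest_size=16,
+#                                oid="1.2.840.113549.2.5")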
+
+
+def make_mac_tests(module, module_name, test_data):
+ tests = []
+ for i, row in enumerate(test_data):
+ if len(row) == 4:
+ (key, data, results, description, params) = list(row) + [ {} ]
+ else:
+ (key, data, results, description, params) = row
+ name = "%s #%d: %s" % (module_name, i+1, description)
+ tests.append(MACSelfTest(module, name, results, data, key, params))
+ return tests
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_BLAKE2.py b/lib/Crypto/SelfTest/Hash/test_BLAKE2.py
new file mode 100644
index 0000000..f953eab
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_BLAKE2.py
@@ -0,0 +1,482 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import os
+import re
+import unittest
+import warnings
+from binascii import unhexlify, hexlify
+
+from Crypto.Util.py3compat import tobytes
+from Crypto.Util.strxor import strxor_c
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import BLAKE2b, BLAKE2s
+
+
+class Blake2Test(unittest.TestCase):
+
+ def test_new_positive(self):
+
+ h = self.BLAKE2.new(digest_bits=self.max_bits)
+ for new_func in self.BLAKE2.new, h.new:
+
+ for dbits in range(8, self.max_bits + 1, 8):
+ hobj = new_func(digest_bits=dbits)
+ self.assertEqual(hobj.digest_size, dbits // 8)
+
+ for dbytes in range(1, self.max_bytes + 1):
+ hobj = new_func(digest_bytes=dbytes)
+ self.assertEqual(hobj.digest_size, dbytes)
+
+ digest1 = new_func(data=b"\x90", digest_bytes=self.max_bytes).digest()
+ digest2 = new_func(digest_bytes=self.max_bytes).update(b"\x90").digest()
+ self.assertEqual(digest1, digest2)
+
+ new_func(data=b"A", key=b"5", digest_bytes=self.max_bytes)
+
+ hobj = h.new()
+ self.assertEqual(hobj.digest_size, self.max_bytes)
+
+ def test_new_negative(self):
+
+ h = self.BLAKE2.new(digest_bits=self.max_bits)
+ for new_func in self.BLAKE2.new, h.new:
+ self.assertRaises(TypeError, new_func,
+ digest_bytes=self.max_bytes,
+ digest_bits=self.max_bits)
+ self.assertRaises(ValueError, new_func, digest_bytes=0)
+ self.assertRaises(ValueError, new_func,
+ digest_bytes=self.max_bytes + 1)
+ self.assertRaises(ValueError, new_func, digest_bits=7)
+ self.assertRaises(ValueError, new_func, digest_bits=15)
+ self.assertRaises(ValueError, new_func,
+ digest_bits=self.max_bits + 1)
+ self.assertRaises(TypeError, new_func,
+ digest_bytes=self.max_bytes,
+ key=u"string")
+ self.assertRaises(TypeError, new_func,
+ digest_bytes=self.max_bytes,
+ data=u"string")
+
+ def test_default_digest_size(self):
+ digest = self.BLAKE2.new(data=b'abc').digest()
+ self.assertEqual(len(digest), self.max_bytes)
+
+ def test_update(self):
+ pieces = [b"\x0A" * 200, b"\x14" * 300]
+ h = self.BLAKE2.new(digest_bytes=self.max_bytes)
+ h.update(pieces[0]).update(pieces[1])
+ digest = h.digest()
+ h = self.BLAKE2.new(digest_bytes=self.max_bytes)
+ h.update(pieces[0] + pieces[1])
+ self.assertEqual(h.digest(), digest)
+
+ def test_update_negative(self):
+ h = self.BLAKE2.new(digest_bytes=self.max_bytes)
+ self.assertRaises(TypeError, h.update, u"string")
+
+ def test_digest(self):
+ h = self.BLAKE2.new(digest_bytes=self.max_bytes)
+ digest = h.digest()
+
+        # digest() does not change the state
+ self.assertEqual(h.digest(), digest)
+ # digest returns a byte string
+ self.assertTrue(isinstance(digest, type(b"digest")))
+
+ def test_update_after_digest(self):
+ msg = b"rrrrttt"
+
+ # Normally, update() cannot be done after digest()
+ h = self.BLAKE2.new(digest_bits=256, data=msg[:4])
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+ dig2 = self.BLAKE2.new(digest_bits=256, data=msg).digest()
+
+ # With the proper flag, it is allowed
+ h = self.BLAKE2.new(digest_bits=256, data=msg[:4], update_after_digest=True)
+ self.assertEqual(h.digest(), dig1)
+ # ... and the subsequent digest applies to the entire message
+ # up to that point
+ h.update(msg[4:])
+ self.assertEqual(h.digest(), dig2)
+
+ def test_hex_digest(self):
+ mac = self.BLAKE2.new(digest_bits=self.max_bits)
+ digest = mac.digest()
+ hexdigest = mac.hexdigest()
+
+ # hexdigest is equivalent to digest
+ self.assertEqual(hexlify(digest), tobytes(hexdigest))
+ # hexdigest does not change the state
+ self.assertEqual(mac.hexdigest(), hexdigest)
+ # hexdigest returns a string
+ self.assertTrue(isinstance(hexdigest, type("digest")))
+
+ def test_verify(self):
+ h = self.BLAKE2.new(digest_bytes=self.max_bytes, key=b"4")
+ mac = h.digest()
+ h.verify(mac)
+ wrong_mac = strxor_c(mac, 255)
+ self.assertRaises(ValueError, h.verify, wrong_mac)
+
+ def test_hexverify(self):
+ h = self.BLAKE2.new(digest_bytes=self.max_bytes, key=b"4")
+ mac = h.hexdigest()
+ h.hexverify(mac)
+ self.assertRaises(ValueError, h.hexverify, "4556")
+
+ def test_oid(self):
+
+ prefix = "1.3.6.1.4.1.1722.12.2." + self.oid_variant + "."
+
+ for digest_bits in self.digest_bits_oid:
+ h = self.BLAKE2.new(digest_bits=digest_bits)
+ self.assertEqual(h.oid, prefix + str(digest_bits // 8))
+
+ h = self.BLAKE2.new(digest_bits=digest_bits, key=b"secret")
+ self.assertRaises(AttributeError, lambda: h.oid)
+
+        for digest_bits in (8, self.max_bits):
+            if digest_bits in self.digest_bits_oid:
+                continue
+            h = self.BLAKE2.new(digest_bits=digest_bits)
+            self.assertRaises(AttributeError, lambda: h.oid)
+
+ def test_bytearray(self):
+
+ key = b'0' * 16
+ data = b"\x00\x01\x02"
+
+ # Data and key can be a bytearray (during initialization)
+ key_ba = bytearray(key)
+ data_ba = bytearray(data)
+
+ h1 = self.BLAKE2.new(data=data, key=key)
+ h2 = self.BLAKE2.new(data=data_ba, key=key_ba)
+ key_ba[:1] = b'\xFF'
+ data_ba[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a bytearray (during operation)
+ data_ba = bytearray(data)
+
+ h1 = self.BLAKE2.new()
+ h2 = self.BLAKE2.new()
+ h1.update(data)
+ h2.update(data_ba)
+ data_ba[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ def test_memoryview(self):
+
+ key = b'0' * 16
+ data = b"\x00\x01\x02"
+
+ def get_mv_ro(data):
+ return memoryview(data)
+
+ def get_mv_rw(data):
+ return memoryview(bytearray(data))
+
+ for get_mv in (get_mv_ro, get_mv_rw):
+
+ # Data and key can be a memoryview (during initialization)
+ key_mv = get_mv(key)
+ data_mv = get_mv(data)
+
+ h1 = self.BLAKE2.new(data=data, key=key)
+ h2 = self.BLAKE2.new(data=data_mv, key=key_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+ key_mv[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a memoryview (during operation)
+ data_mv = get_mv(data)
+
+ h1 = self.BLAKE2.new()
+ h2 = self.BLAKE2.new()
+ h1.update(data)
+ h2.update(data_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class Blake2bTest(Blake2Test):
+ #: Module
+ BLAKE2 = BLAKE2b
+ #: Max output size (in bits)
+ max_bits = 512
+ #: Max output size (in bytes)
+ max_bytes = 64
+ #: Bit size of the digests for which an ASN OID exists
+ digest_bits_oid = (160, 256, 384, 512)
+ # http://tools.ietf.org/html/draft-saarinen-blake2-02
+ oid_variant = "1"
+
+
+class Blake2sTest(Blake2Test):
+ #: Module
+ BLAKE2 = BLAKE2s
+ #: Max output size (in bits)
+ max_bits = 256
+ #: Max output size (in bytes)
+ max_bytes = 32
+ #: Bit size of the digests for which an ASN OID exists
+ digest_bits_oid = (128, 160, 224, 256)
+ # http://tools.ietf.org/html/draft-saarinen-blake2-02
+ oid_variant = "2"
+
+
+class Blake2OfficialTestVector(unittest.TestCase):
+
+ def _load_tests(self, test_vector_file):
+ expected = "in"
+ test_vectors = []
+ with open(test_vector_file, "rt") as test_vector_fd:
+ for line_number, line in enumerate(test_vector_fd):
+
+ if line.strip() == "" or line.startswith("#"):
+ continue
+
+ res = re.match("%s:\t([0-9A-Fa-f]*)" % expected, line)
+ if not res:
+ raise ValueError("Incorrect test vector format (line %d)"
+ % line_number)
+
+ if res.group(1):
+ bin_value = unhexlify(tobytes(res.group(1)))
+ else:
+ bin_value = b""
+ if expected == "in":
+ input_data = bin_value
+ expected = "key"
+ elif expected == "key":
+ key = bin_value
+ expected = "hash"
+ else:
+ result = bin_value
+ expected = "in"
+ test_vectors.append((input_data, key, result))
+ return test_vectors
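+
+    # For reference: _load_tests() above expects the official BLAKE2 vector
+    # file to contain groups of three tab-separated lines that cycle through
+    # "in", "key" and "hash" (hex values shortened here for illustration):
+    #
+    #   in:<TAB>000102...
+    #   key:<TAB>000102...
+    #   hash:<TAB>0cc70e...
+    #
+    # Blank lines and lines starting with '#' are ignored.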
+
+ def setUp(self):
+
+ dir_comps = ("Hash", self.name)
+ file_name = self.name.lower() + "-test.txt"
+ self.description = "%s tests" % self.name
+
+ try:
+ import pycryptodome_test_vectors # type: ignore
+ except ImportError:
+ warnings.warn("Warning: skipping extended tests for %s" % self.name,
+ UserWarning)
+ self.test_vectors = []
+ return
+
+ init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
+ full_file_name = os.path.join(os.path.join(init_dir, *dir_comps), file_name)
+ self.test_vectors = self._load_tests(full_file_name)
+
+ def runTest(self):
+ for (input_data, key, result) in self.test_vectors:
+ mac = self.BLAKE2.new(key=key, digest_bytes=self.max_bytes)
+ mac.update(input_data)
+ self.assertEqual(mac.digest(), result)
+
+
+class Blake2bOfficialTestVector(Blake2OfficialTestVector):
+ #: Module
+ BLAKE2 = BLAKE2b
+ #: Hash name
+ name = "BLAKE2b"
+ #: Max digest size
+ max_bytes = 64
+
+
+class Blake2sOfficialTestVector(Blake2OfficialTestVector):
+ #: Module
+ BLAKE2 = BLAKE2s
+ #: Hash name
+ name = "BLAKE2s"
+ #: Max digest size
+ max_bytes = 32
+
+
+class Blake2TestVector1(unittest.TestCase):
+
+ def _load_tests(self, test_vector_file):
+ test_vectors = []
+ with open(test_vector_file, "rt") as test_vector_fd:
+ for line_number, line in enumerate(test_vector_fd):
+ if line.strip() == "" or line.startswith("#"):
+ continue
+ res = re.match("digest: ([0-9A-Fa-f]*)", line)
+ if not res:
+ raise ValueError("Incorrect test vector format (line %d)"
+ % line_number)
+
+ test_vectors.append(unhexlify(tobytes(res.group(1))))
+ return test_vectors
+
+ def setUp(self):
+ dir_comps = ("Hash", self.name)
+ file_name = "tv1.txt"
+ self.description = "%s tests" % self.name
+
+ try:
+ import pycryptodome_test_vectors
+ except ImportError:
+ warnings.warn("Warning: skipping extended tests for %s" % self.name,
+ UserWarning)
+ self.test_vectors = []
+ return
+
+ init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
+ full_file_name = os.path.join(os.path.join(init_dir, *dir_comps), file_name)
+ self.test_vectors = self._load_tests(full_file_name)
+
+ def runTest(self):
+
+ for tv in self.test_vectors:
+ digest_bytes = len(tv)
+ next_data = b""
+ for _ in range(100):
+ h = self.BLAKE2.new(digest_bytes=digest_bytes)
+ h.update(next_data)
+ next_data = h.digest() + next_data
+ self.assertEqual(h.digest(), tv)
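+
+        # Note on the construction above: iteration i hashes the
+        # concatenation of all previously produced digests (most recent
+        # first), i.e. data_0 = b"" and data_{i+1} = digest_i + data_i,
+        # and only the 100th digest is compared against the test vector.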
+
+
+class Blake2bTestVector1(Blake2TestVector1):
+ #: Module
+ BLAKE2 = BLAKE2b
+ #: Hash name
+ name = "BLAKE2b"
+
+
+class Blake2sTestVector1(Blake2TestVector1):
+ #: Module
+ BLAKE2 = BLAKE2s
+ #: Hash name
+ name = "BLAKE2s"
+
+
+class Blake2TestVector2(unittest.TestCase):
+
+ def _load_tests(self, test_vector_file):
+ test_vectors = []
+ with open(test_vector_file, "rt") as test_vector_fd:
+ for line_number, line in enumerate(test_vector_fd):
+ if line.strip() == "" or line.startswith("#"):
+ continue
+ res = re.match(r"digest\(([0-9]+)\): ([0-9A-Fa-f]*)", line)
+ if not res:
+ raise ValueError("Incorrect test vector format (line %d)"
+ % line_number)
+ key_size = int(res.group(1))
+ result = unhexlify(tobytes(res.group(2)))
+ test_vectors.append((key_size, result))
+ return test_vectors
+
+ def setUp(self):
+ dir_comps = ("Hash", self.name)
+ file_name = "tv2.txt"
+ self.description = "%s tests" % self.name
+
+ try:
+ import pycryptodome_test_vectors # type: ignore
+ except ImportError:
+ warnings.warn("Warning: skipping extended tests for %s" % self.name,
+ UserWarning)
+ self.test_vectors = []
+ return
+
+ init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
+ full_file_name = os.path.join(os.path.join(init_dir, *dir_comps), file_name)
+ self.test_vectors = self._load_tests(full_file_name)
+
+ def runTest(self):
+
+ for key_size, result in self.test_vectors:
+ next_data = b""
+ for _ in range(100):
+ h = self.BLAKE2.new(digest_bytes=self.max_bytes,
+ key=b"A" * key_size)
+ h.update(next_data)
+ next_data = h.digest() + next_data
+ self.assertEqual(h.digest(), result)
+
+
+class Blake2bTestVector2(Blake2TestVector2):
+ #: Module
+ BLAKE2 = BLAKE2b
+ #: Hash name
+ name = "BLAKE2b"
+ #: Max digest size in bytes
+ max_bytes = 64
+
+
+class Blake2sTestVector2(Blake2TestVector2):
+ #: Module
+ BLAKE2 = BLAKE2s
+ #: Hash name
+ name = "BLAKE2s"
+ #: Max digest size in bytes
+ max_bytes = 32
+
+
+def get_tests(config={}):
+ tests = []
+
+ tests += list_test_cases(Blake2bTest)
+ tests.append(Blake2bOfficialTestVector())
+ tests.append(Blake2bTestVector1())
+ tests.append(Blake2bTestVector2())
+
+ tests += list_test_cases(Blake2sTest)
+ tests.append(Blake2sOfficialTestVector())
+ tests.append(Blake2sTestVector1())
+ tests.append(Blake2sTestVector2())
+
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_CMAC.py b/lib/Crypto/SelfTest/Hash/test_CMAC.py
new file mode 100644
index 0000000..f4763f2
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_CMAC.py
@@ -0,0 +1,448 @@
+#
+# SelfTest/Hash/CMAC.py: Self-test for the CMAC module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.CMAC"""
+
+import json
+import unittest
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import tobytes
+
+from Crypto.Hash import CMAC
+from Crypto.Cipher import AES, DES3
+from Crypto.Hash import SHAKE128
+
+from Crypto.Util.strxor import strxor
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+
+# This is a list of (key, data, result, description, module) tuples.
+test_data = [
+
+ ## Test vectors from RFC 4493 ##
+    ## They are also listed in NIST SP 800-38B, Appendix D.2 ##
+ ( '2b7e151628aed2a6abf7158809cf4f3c',
+ '',
+ 'bb1d6929e95937287fa37d129b756746',
+ 'RFC 4493 #1',
+ AES
+ ),
+
+ ( '2b7e151628aed2a6abf7158809cf4f3c',
+ '6bc1bee22e409f96e93d7e117393172a',
+ '070a16b46b4d4144f79bdd9dd04a287c',
+ 'RFC 4493 #2',
+ AES
+ ),
+
+ ( '2b7e151628aed2a6abf7158809cf4f3c',
+ '6bc1bee22e409f96e93d7e117393172a'+
+ 'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411',
+ 'dfa66747de9ae63030ca32611497c827',
+ 'RFC 4493 #3',
+ AES
+ ),
+
+ ( '2b7e151628aed2a6abf7158809cf4f3c',
+ '6bc1bee22e409f96e93d7e117393172a'+
+ 'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411e5fbc1191a0a52ef'+
+ 'f69f2445df4f9b17ad2b417be66c3710',
+ '51f0bebf7e3b9d92fc49741779363cfe',
+ 'RFC 4493 #4',
+ AES
+ ),
+
+    ## The remaining examples in Appendix D of NIST SP 800-38B were not
+    ## entirely correct: the values in Examples 14, 15, 18, and 19 were wrong.
+ ## The updated test values are published in:
+ ## http://csrc.nist.gov/publications/nistpubs/800-38B/Updated_CMAC_Examples.pdf
+
+ ( '8e73b0f7da0e6452c810f32b809079e5'+
+ '62f8ead2522c6b7b',
+ '',
+ 'd17ddf46adaacde531cac483de7a9367',
+ 'NIST SP 800 38B D.2 Example 5',
+ AES
+ ),
+
+ ( '8e73b0f7da0e6452c810f32b809079e5'+
+ '62f8ead2522c6b7b',
+ '6bc1bee22e409f96e93d7e117393172a',
+ '9e99a7bf31e710900662f65e617c5184',
+ 'NIST SP 800 38B D.2 Example 6',
+ AES
+ ),
+
+ ( '8e73b0f7da0e6452c810f32b809079e5'+
+ '62f8ead2522c6b7b',
+ '6bc1bee22e409f96e93d7e117393172a'+
+ 'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411',
+ '8a1de5be2eb31aad089a82e6ee908b0e',
+ 'NIST SP 800 38B D.2 Example 7',
+ AES
+ ),
+
+ ( '8e73b0f7da0e6452c810f32b809079e5'+
+ '62f8ead2522c6b7b',
+ '6bc1bee22e409f96e93d7e117393172a'+
+ 'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411e5fbc1191a0a52ef'+
+ 'f69f2445df4f9b17ad2b417be66c3710',
+ 'a1d5df0eed790f794d77589659f39a11',
+ 'NIST SP 800 38B D.2 Example 8',
+ AES
+ ),
+
+ ( '603deb1015ca71be2b73aef0857d7781'+
+ '1f352c073b6108d72d9810a30914dff4',
+ '',
+ '028962f61b7bf89efc6b551f4667d983',
+ 'NIST SP 800 38B D.3 Example 9',
+ AES
+ ),
+
+ ( '603deb1015ca71be2b73aef0857d7781'+
+ '1f352c073b6108d72d9810a30914dff4',
+ '6bc1bee22e409f96e93d7e117393172a',
+ '28a7023f452e8f82bd4bf28d8c37c35c',
+ 'NIST SP 800 38B D.3 Example 10',
+ AES
+ ),
+
+ ( '603deb1015ca71be2b73aef0857d7781'+
+ '1f352c073b6108d72d9810a30914dff4',
+ '6bc1bee22e409f96e93d7e117393172a'+
+ 'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411',
+ 'aaf3d8f1de5640c232f5b169b9c911e6',
+ 'NIST SP 800 38B D.3 Example 11',
+ AES
+ ),
+
+ ( '603deb1015ca71be2b73aef0857d7781'+
+ '1f352c073b6108d72d9810a30914dff4',
+ '6bc1bee22e409f96e93d7e117393172a'+
+ 'ae2d8a571e03ac9c9eb76fac45af8e51'+
+ '30c81c46a35ce411e5fbc1191a0a52ef'+
+ 'f69f2445df4f9b17ad2b417be66c3710',
+ 'e1992190549f6ed5696a2c056c315410',
+ 'NIST SP 800 38B D.3 Example 12',
+ AES
+ ),
+
+ ( '8aa83bf8cbda1062'+
+ '0bc1bf19fbb6cd58'+
+ 'bc313d4a371ca8b5',
+ '',
+ 'b7a688e122ffaf95',
+ 'NIST SP 800 38B D.4 Example 13',
+ DES3
+ ),
+
+ ( '8aa83bf8cbda1062'+
+ '0bc1bf19fbb6cd58'+
+ 'bc313d4a371ca8b5',
+ '6bc1bee22e409f96',
+ '8e8f293136283797',
+ 'NIST SP 800 38B D.4 Example 14',
+ DES3
+ ),
+
+ ( '8aa83bf8cbda1062'+
+ '0bc1bf19fbb6cd58'+
+ 'bc313d4a371ca8b5',
+ '6bc1bee22e409f96'+
+ 'e93d7e117393172a'+
+ 'ae2d8a57',
+ '743ddbe0ce2dc2ed',
+ 'NIST SP 800 38B D.4 Example 15',
+ DES3
+ ),
+
+ ( '8aa83bf8cbda1062'+
+ '0bc1bf19fbb6cd58'+
+ 'bc313d4a371ca8b5',
+ '6bc1bee22e409f96'+
+ 'e93d7e117393172a'+
+ 'ae2d8a571e03ac9c'+
+ '9eb76fac45af8e51',
+ '33e6b1092400eae5',
+ 'NIST SP 800 38B D.4 Example 16',
+ DES3
+ ),
+
+ ( '4cf15134a2850dd5'+
+ '8a3d10ba80570d38',
+ '',
+ 'bd2ebf9a3ba00361',
+ 'NIST SP 800 38B D.7 Example 17',
+ DES3
+ ),
+
+ ( '4cf15134a2850dd5'+
+ '8a3d10ba80570d38',
+ '6bc1bee22e409f96',
+ '4ff2ab813c53ce83',
+ 'NIST SP 800 38B D.7 Example 18',
+ DES3
+ ),
+
+ ( '4cf15134a2850dd5'+
+ '8a3d10ba80570d38',
+ '6bc1bee22e409f96'+
+ 'e93d7e117393172a'+
+ 'ae2d8a57',
+ '62dd1b471902bd4e',
+ 'NIST SP 800 38B D.7 Example 19',
+ DES3
+ ),
+
+ ( '4cf15134a2850dd5'+
+ '8a3d10ba80570d38',
+ '6bc1bee22e409f96'+
+ 'e93d7e117393172a'+
+ 'ae2d8a571e03ac9c'+
+ '9eb76fac45af8e51',
+ '31b1e431dabc4eb8',
+ 'NIST SP 800 38B D.7 Example 20',
+ DES3
+ ),
+
+]
+
+
+def get_tag_random(tag, length):
+ return SHAKE128.new(data=tobytes(tag)).read(length)
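+
+# get_tag_random() is a small helper that derives reproducible pseudo-random
+# test bytes from a label by reading them out of the SHAKE128 XOF; e.g.
+# get_tag_random("key", 16) always yields the same 16 bytes for "key".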
+
+
+class TestCMAC(unittest.TestCase):
+
+ def test_internal_caching(self):
+ """Verify that internal caching is implemented correctly"""
+
+ data_to_mac = get_tag_random("data_to_mac", 128)
+ key = get_tag_random("key", 16)
+ ref_mac = CMAC.new(key, msg=data_to_mac, ciphermod=AES).digest()
+
+ # Break up in chunks of different length
+ # The result must always be the same
+ for chunk_length in 1, 2, 3, 7, 10, 13, 16, 40, 80, 128:
+
+ chunks = [data_to_mac[i:i+chunk_length] for i in
+ range(0, len(data_to_mac), chunk_length)]
+
+ mac = CMAC.new(key, ciphermod=AES)
+ for chunk in chunks:
+ mac.update(chunk)
+ self.assertEqual(ref_mac, mac.digest())
+
+ def test_update_after_digest(self):
+ msg = b"rrrrttt"
+ key = b"4" * 16
+
+ # Normally, update() cannot be done after digest()
+ h = CMAC.new(key, msg[:4], ciphermod=AES)
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+ dig2 = CMAC.new(key, msg, ciphermod=AES).digest()
+
+ # With the proper flag, it is allowed
+ h2 = CMAC.new(key, msg[:4], ciphermod=AES, update_after_digest=True)
+ self.assertEqual(h2.digest(), dig1)
+ # ... and the subsequent digest applies to the entire message
+ # up to that point
+ h2.update(msg[4:])
+ self.assertEqual(h2.digest(), dig2)
+
+
+class ByteArrayTests(unittest.TestCase):
+
+ def runTest(self):
+
+ key = b"0" * 16
+ data = b"\x00\x01\x02"
+
+ # Data and key can be a bytearray (during initialization)
+ key_ba = bytearray(key)
+ data_ba = bytearray(data)
+
+ h1 = CMAC.new(key, data, ciphermod=AES)
+ h2 = CMAC.new(key_ba, data_ba, ciphermod=AES)
+ key_ba[:1] = b'\xFF'
+ data_ba[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a bytearray (during operation)
+ key_ba = bytearray(key)
+ data_ba = bytearray(data)
+
+ h1 = CMAC.new(key, ciphermod=AES)
+ h2 = CMAC.new(key, ciphermod=AES)
+ h1.update(data)
+ h2.update(data_ba)
+ data_ba[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class MemoryViewTests(unittest.TestCase):
+
+ def runTest(self):
+
+ key = b"0" * 16
+ data = b"\x00\x01\x02"
+
+ def get_mv_ro(data):
+ return memoryview(data)
+
+ def get_mv_rw(data):
+ return memoryview(bytearray(data))
+
+ for get_mv in (get_mv_ro, get_mv_rw):
+
+ # Data and key can be a memoryview (during initialization)
+ key_mv = get_mv(key)
+ data_mv = get_mv(data)
+
+ h1 = CMAC.new(key, data, ciphermod=AES)
+ h2 = CMAC.new(key_mv, data_mv, ciphermod=AES)
+ if not data_mv.readonly:
+ key_mv[:1] = b'\xFF'
+ data_mv[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a memoryview (during operation)
+ data_mv = get_mv(data)
+
+ h1 = CMAC.new(key, ciphermod=AES)
+ h2 = CMAC.new(key, ciphermod=AES)
+ h1.update(data)
+ h2.update(data_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._id = "None"
+
+ def setUp(self):
+
+ def filter_tag(group):
+ return group['tagSize'] // 8
+
+ self.tv = load_test_vectors_wycheproof(("Hash", "wycheproof"),
+ "aes_cmac_test.json",
+ "Wycheproof CMAC",
+ group_tag={'tag_size': filter_tag})
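+        # group_tag attaches derived attributes to every loaded vector; here
+        # each vector is given a tag_size attribute, i.e. the group's tagSize
+        # converted from bits to bytes (used below as mac_len).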
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_create_mac(self, tv):
+ self._id = "Wycheproof MAC creation Test #" + str(tv.id)
+
+ try:
+ tag = CMAC.new(tv.key, tv.msg, ciphermod=AES, mac_len=tv.tag_size).digest()
+ except ValueError as e:
+ if len(tv.key) not in (16, 24, 32) and "key length" in str(e):
+ return
+ raise e
+ if tv.valid:
+ self.assertEqual(tag, tv.tag)
+ self.warn(tv)
+
+ def test_verify_mac(self, tv):
+ self._id = "Wycheproof MAC verification Test #" + str(tv.id)
+
+ try:
+ mac = CMAC.new(tv.key, tv.msg, ciphermod=AES, mac_len=tv.tag_size)
+ except ValueError as e:
+ if len(tv.key) not in (16, 24, 32) and "key length" in str(e):
+ return
+ raise e
+ try:
+ mac.verify(tv.tag)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.warn(tv)
+
+ def runTest(self):
+
+ for tv in self.tv:
+ self.test_create_mac(tv)
+ self.test_verify_mac(tv)
+
+
+def get_tests(config={}):
+ global test_data
+ import types
+ from .common import make_mac_tests
+
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ # Add new() parameters to the back of each test vector
+ params_test_data = []
+ for row in test_data:
+ t = list(row)
+ t[4] = dict(ciphermod=t[4])
+ params_test_data.append(t)
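+    # e.g. a row (key, data, mac, 'RFC 4493 #1', AES) becomes
+    # (key, data, mac, 'RFC 4493 #1', {'ciphermod': AES}), which is the
+    # 5-tuple layout expected by make_mac_tests().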
+
+ tests = make_mac_tests(CMAC, "CMAC", params_test_data)
+ tests.append(ByteArrayTests())
+    tests += list_test_cases(TestCMAC)
+ tests.append(MemoryViewTests())
+ tests += [ TestVectorsWycheproof(wycheproof_warnings) ]
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_HMAC.py b/lib/Crypto/SelfTest/Hash/test_HMAC.py
new file mode 100644
index 0000000..26b7b24
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_HMAC.py
@@ -0,0 +1,548 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/HMAC.py: Self-test for the HMAC module
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.HMAC"""
+
+import unittest
+from binascii import hexlify
+from Crypto.Util.py3compat import tostr, tobytes
+
+from Crypto.Hash import (HMAC, MD5, SHA1, SHA256,
+ SHA224, SHA384, SHA512,
+ RIPEMD160,
+ SHA3_224, SHA3_256, SHA3_384, SHA3_512)
+
+
+hash_modules = dict(MD5=MD5, SHA1=SHA1, SHA256=SHA256,
+ SHA224=SHA224, SHA384=SHA384, SHA512=SHA512,
+ RIPEMD160=RIPEMD160,
+ SHA3_224=SHA3_224, SHA3_256=SHA3_256,
+ SHA3_384=SHA3_384, SHA3_512=SHA3_512)
+
+default_hash = None
+
+def xl(text):
+ return tostr(hexlify(tobytes(text)))
+
+# This is a list of (key, data, results, description) tuples.
+test_data = [
+ ## Test vectors from RFC 2202 ##
+ # Test that the default hashmod is MD5
+ ('0b' * 16,
+ '4869205468657265',
+ dict(default_hash='9294727a3638bb1c13f48ef8158bfc9d'),
+ 'default-is-MD5'),
+
+ # Test case 1 (MD5)
+ ('0b' * 16,
+ '4869205468657265',
+ dict(MD5='9294727a3638bb1c13f48ef8158bfc9d'),
+ 'RFC 2202 #1-MD5 (HMAC-MD5)'),
+
+ # Test case 1 (SHA1)
+ ('0b' * 20,
+ '4869205468657265',
+ dict(SHA1='b617318655057264e28bc0b6fb378c8ef146be00'),
+ 'RFC 2202 #1-SHA1 (HMAC-SHA1)'),
+
+ # Test case 2
+ ('4a656665',
+ '7768617420646f2079612077616e7420666f72206e6f7468696e673f',
+ dict(MD5='750c783e6ab0b503eaa86e310a5db738',
+ SHA1='effcdf6ae5eb2fa2d27416d5f184df9c259a7c79'),
+ 'RFC 2202 #2 (HMAC-MD5/SHA1)'),
+
+ # Test case 3 (MD5)
+ ('aa' * 16,
+ 'dd' * 50,
+ dict(MD5='56be34521d144c88dbb8c733f0e8b3f6'),
+ 'RFC 2202 #3-MD5 (HMAC-MD5)'),
+
+ # Test case 3 (SHA1)
+ ('aa' * 20,
+ 'dd' * 50,
+ dict(SHA1='125d7342b9ac11cd91a39af48aa17b4f63f175d3'),
+ 'RFC 2202 #3-SHA1 (HMAC-SHA1)'),
+
+ # Test case 4
+ ('0102030405060708090a0b0c0d0e0f10111213141516171819',
+ 'cd' * 50,
+ dict(MD5='697eaf0aca3a3aea3a75164746ffaa79',
+ SHA1='4c9007f4026250c6bc8414f9bf50c86c2d7235da'),
+ 'RFC 2202 #4 (HMAC-MD5/SHA1)'),
+
+ # Test case 5 (MD5)
+ ('0c' * 16,
+ '546573742057697468205472756e636174696f6e',
+ dict(MD5='56461ef2342edc00f9bab995690efd4c'),
+ 'RFC 2202 #5-MD5 (HMAC-MD5)'),
+
+ # Test case 5 (SHA1)
+ # NB: We do not implement hash truncation, so we only test the full hash here.
+ ('0c' * 20,
+ '546573742057697468205472756e636174696f6e',
+ dict(SHA1='4c1a03424b55e07fe7f27be1d58bb9324a9a5a04'),
+ 'RFC 2202 #5-SHA1 (HMAC-SHA1)'),
+
+ # Test case 6
+ ('aa' * 80,
+ '54657374205573696e67204c6172676572205468616e20426c6f636b2d53697a'
+ + '65204b6579202d2048617368204b6579204669727374',
+ dict(MD5='6b1ab7fe4bd7bf8f0b62e6ce61b9d0cd',
+ SHA1='aa4ae5e15272d00e95705637ce8a3b55ed402112'),
+ 'RFC 2202 #6 (HMAC-MD5/SHA1)'),
+
+ # Test case 7
+ ('aa' * 80,
+ '54657374205573696e67204c6172676572205468616e20426c6f636b2d53697a'
+ + '65204b657920616e64204c6172676572205468616e204f6e6520426c6f636b2d'
+ + '53697a652044617461',
+ dict(MD5='6f630fad67cda0ee1fb1f562db3aa53e',
+ SHA1='e8e99d0f45237d786d6bbaa7965c7808bbff1a91'),
+ 'RFC 2202 #7 (HMAC-MD5/SHA1)'),
+
+ ## Test vectors from RFC 4231 ##
+ # 4.2. Test Case 1
+ ('0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b',
+ '4869205468657265',
+ dict(SHA256='''
+ b0344c61d8db38535ca8afceaf0bf12b
+ 881dc200c9833da726e9376c2e32cff7
+ '''),
+ 'RFC 4231 #1 (HMAC-SHA256)'),
+
+ # 4.3. Test Case 2 - Test with a key shorter than the length of the HMAC
+ # output.
+ ('4a656665',
+ '7768617420646f2079612077616e7420666f72206e6f7468696e673f',
+ dict(SHA256='''
+ 5bdcc146bf60754e6a042426089575c7
+ 5a003f089d2739839dec58b964ec3843
+ '''),
+ 'RFC 4231 #2 (HMAC-SHA256)'),
+
+ # 4.4. Test Case 3 - Test with a combined length of key and data that is
+ # larger than 64 bytes (= block-size of SHA-224 and SHA-256).
+ ('aa' * 20,
+ 'dd' * 50,
+ dict(SHA256='''
+ 773ea91e36800e46854db8ebd09181a7
+ 2959098b3ef8c122d9635514ced565fe
+ '''),
+ 'RFC 4231 #3 (HMAC-SHA256)'),
+
+ # 4.5. Test Case 4 - Test with a combined length of key and data that is
+ # larger than 64 bytes (= block-size of SHA-224 and SHA-256).
+ ('0102030405060708090a0b0c0d0e0f10111213141516171819',
+ 'cd' * 50,
+ dict(SHA256='''
+ 82558a389a443c0ea4cc819899f2083a
+ 85f0faa3e578f8077a2e3ff46729665b
+ '''),
+ 'RFC 4231 #4 (HMAC-SHA256)'),
+
+ # 4.6. Test Case 5 - Test with a truncation of output to 128 bits.
+ #
+ # Not included because we do not implement hash truncation.
+ #
+
+ # 4.7. Test Case 6 - Test with a key larger than 128 bytes (= block-size of
+ # SHA-384 and SHA-512).
+ ('aa' * 131,
+ '54657374205573696e67204c6172676572205468616e20426c6f636b2d53697a'
+ + '65204b6579202d2048617368204b6579204669727374',
+ dict(SHA256='''
+ 60e431591ee0b67f0d8a26aacbf5b77f
+ 8e0bc6213728c5140546040f0ee37f54
+ '''),
+ 'RFC 4231 #6 (HMAC-SHA256)'),
+
+ # 4.8. Test Case 7 - Test with a key and data that is larger than 128 bytes
+ # (= block-size of SHA-384 and SHA-512).
+ ('aa' * 131,
+ '5468697320697320612074657374207573696e672061206c6172676572207468'
+ + '616e20626c6f636b2d73697a65206b657920616e642061206c61726765722074'
+ + '68616e20626c6f636b2d73697a6520646174612e20546865206b6579206e6565'
+ + '647320746f20626520686173686564206265666f7265206265696e6720757365'
+ + '642062792074686520484d414320616c676f726974686d2e',
+ dict(SHA256='''
+ 9b09ffa71b942fcb27635fbcd5b0e944
+ bfdc63644f0713938a7f51535c3a35e2
+ '''),
+ 'RFC 4231 #7 (HMAC-SHA256)'),
+
+ # Test case 8 (SHA224)
+ ('4a656665',
+ '7768617420646f2079612077616e74'
+ + '20666f72206e6f7468696e673f',
+ dict(SHA224='a30e01098bc6dbbf45690f3a7e9e6d0f8bbea2a39e6148008fd05e44'),
+ 'RFC 4634 8.4 SHA224 (HMAC-SHA224)'),
+
+ # Test case 9 (SHA384)
+ ('4a656665',
+ '7768617420646f2079612077616e74'
+ + '20666f72206e6f7468696e673f',
+ dict(SHA384='af45d2e376484031617f78d2b58a6b1b9c7ef464f5a01b47e42ec3736322445e8e2240ca5e69e2c78b3239ecfab21649'),
+ 'RFC 4634 8.4 SHA384 (HMAC-SHA384)'),
+
+ # Test case 10 (SHA512)
+ ('4a656665',
+ '7768617420646f2079612077616e74'
+ + '20666f72206e6f7468696e673f',
+ dict(SHA512='164b7a7bfcf819e2e395fbe73b56e0a387bd64222e831fd610270cd7ea2505549758bf75c05a994a6d034f65f8f0e6fdcaeab1a34d4a6b4b636e070a38bce737'),
+ 'RFC 4634 8.4 SHA512 (HMAC-SHA512)'),
+
+ # Test case 11 (RIPEMD)
+ ('0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b',
+ xl("Hi There"),
+ dict(RIPEMD160='24cb4bd67d20fc1a5d2ed7732dcc39377f0a5668'),
+ 'RFC 2286 #1 (HMAC-RIPEMD)'),
+
+ # Test case 12 (RIPEMD)
+ (xl("Jefe"),
+ xl("what do ya want for nothing?"),
+ dict(RIPEMD160='dda6c0213a485a9e24f4742064a7f033b43c4069'),
+ 'RFC 2286 #2 (HMAC-RIPEMD)'),
+
+ # Test case 13 (RIPEMD)
+ ('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
+ 'dd' * 50,
+ dict(RIPEMD160='b0b105360de759960ab4f35298e116e295d8e7c1'),
+ 'RFC 2286 #3 (HMAC-RIPEMD)'),
+
+ # Test case 14 (RIPEMD)
+ ('0102030405060708090a0b0c0d0e0f10111213141516171819',
+ 'cd' * 50,
+ dict(RIPEMD160='d5ca862f4d21d5e610e18b4cf1beb97a4365ecf4'),
+ 'RFC 2286 #4 (HMAC-RIPEMD)'),
+
+ # Test case 15 (RIPEMD)
+ ('0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c',
+ xl("Test With Truncation"),
+ dict(RIPEMD160='7619693978f91d90539ae786500ff3d8e0518e39'),
+ 'RFC 2286 #5 (HMAC-RIPEMD)'),
+
+ # Test case 16 (RIPEMD)
+ ('aa' * 80,
+ xl("Test Using Larger Than Block-Size Key - Hash Key First"),
+ dict(RIPEMD160='6466ca07ac5eac29e1bd523e5ada7605b791fd8b'),
+ 'RFC 2286 #6 (HMAC-RIPEMD)'),
+
+ # Test case 17 (RIPEMD)
+ ('aa' * 80,
+ xl("Test Using Larger Than Block-Size Key and Larger Than One Block-Size Data"),
+ dict(RIPEMD160='69ea60798d71616cce5fd0871e23754cd75d5a0a'),
+ 'RFC 2286 #7 (HMAC-RIPEMD)'),
+
+ # From https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/HMAC_SHA3-224.pdf
+ (
+ '000102030405060708090a0b0c0d0e0f'
+ '101112131415161718191a1b',
+ xl('Sample message for keylen<blocklen'),
+ dict(SHA3_224='332cfd59347fdb8e576e77260be4aba2d6dc53117b3bfb52c6d18c04'),
+ 'NIST CSRC Sample #1 (SHA3-224)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '404142434445464748494a4b4c4d4e4f'\
+ '505152535455565758595a5b5c5d5e5f'\
+ '606162636465666768696a6b6c6d6e6f'\
+ '707172737475767778797a7b7c7d7e7f'\
+ '808182838485868788898a8b8c8d8e8f',
+ xl('Sample message for keylen=blocklen'),
+ dict(SHA3_224='d8b733bcf66c644a12323d564e24dcf3fc75f231f3b67968359100c7'),
+ 'NIST CSRC Sample #2 (SHA3-224)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '404142434445464748494a4b4c4d4e4f'\
+ '505152535455565758595a5b5c5d5e5f'\
+ '606162636465666768696a6b6c6d6e6f'\
+ '707172737475767778797a7b7c7d7e7f'\
+ '808182838485868788898a8b8c8d8e8f'\
+ '909192939495969798999a9b9c9d9e9f'\
+ 'a0a1a2a3a4a5a6a7a8a9aaab',
+ xl('Sample message for keylen>blocklen'),
+ dict(SHA3_224='078695eecc227c636ad31d063a15dd05a7e819a66ec6d8de1e193e59'),
+ 'NIST CSRC Sample #3 (SHA3-224)'
+ ),
+
+ # From https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/HMAC_SHA3-256.pdf
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f',
+ xl('Sample message for keylen<blocklen'),
+ dict(SHA3_256='4fe8e202c4f058e8dddc23d8c34e467343e23555e24fc2f025d598f558f67205'),
+ 'NIST CSRC Sample #1 (SHA3-256)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '404142434445464748494a4b4c4d4e4f'\
+ '505152535455565758595a5b5c5d5e5f'\
+ '606162636465666768696a6b6c6d6e6f'\
+ '707172737475767778797a7b7c7d7e7f'\
+ '8081828384858687',
+ xl('Sample message for keylen=blocklen'),
+ dict(SHA3_256='68b94e2e538a9be4103bebb5aa016d47961d4d1aa906061313b557f8af2c3faa'),
+ 'NIST CSRC Sample #2 (SHA3-256)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '404142434445464748494a4b4c4d4e4f'\
+ '505152535455565758595a5b5c5d5e5f'\
+ '606162636465666768696a6b6c6d6e6f'\
+ '707172737475767778797a7b7c7d7e7f'\
+ '808182838485868788898a8b8c8d8e8f'\
+ '909192939495969798999a9b9c9d9e9f'\
+ 'a0a1a2a3a4a5a6a7',
+ xl('Sample message for keylen>blocklen'),
+ dict(SHA3_256='9bcf2c238e235c3ce88404e813bd2f3a97185ac6f238c63d6229a00b07974258'),
+ 'NIST CSRC Sample #3 (SHA3-256)'
+ ),
+
+ # From https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/HMAC_SHA3-384.pdf
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'
+ '202122232425262728292a2b2c2d2e2f',
+ xl('Sample message for keylen<blocklen'),
+ dict(SHA3_384='d588a3c51f3f2d906e8298c1199aa8ff6296218127f6b38a90b6afe2c5617725bc99987f79b22a557b6520db710b7f42'),
+ 'NIST CSRC Sample #1 (SHA3-384)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '404142434445464748494a4b4c4d4e4f'\
+ '505152535455565758595a5b5c5d5e5f'\
+ '6061626364656667',
+ xl('Sample message for keylen=blocklen'),
+ dict(SHA3_384='a27d24b592e8c8cbf6d4ce6fc5bf62d8fc98bf2d486640d9eb8099e24047837f5f3bffbe92dcce90b4ed5b1e7e44fa90'),
+ 'NIST CSRC Sample #2 (SHA3-384)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '404142434445464748494a4b4c4d4e4f'\
+ '505152535455565758595a5b5c5d5e5f'\
+ '606162636465666768696a6b6c6d6e6f'\
+ '707172737475767778797a7b7c7d7e7f'\
+ '808182838485868788898a8b8c8d8e8f'\
+ '9091929394959697',
+ xl('Sample message for keylen>blocklen'),
+ dict(SHA3_384='e5ae4c739f455279368ebf36d4f5354c95aa184c899d3870e460ebc288ef1f9470053f73f7c6da2a71bcaec38ce7d6ac'),
+ 'NIST CSRC Sample #3 (SHA3-384)'
+ ),
+
+ # From https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/HMAC_SHA3-512.pdf
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f',
+ xl('Sample message for keylen<blocklen'),
+ dict(SHA3_512='4efd629d6c71bf86162658f29943b1c308ce27cdfa6db0d9c3ce81763f9cbce5f7ebe9868031db1a8f8eb7b6b95e5c5e3f657a8996c86a2f6527e307f0213196'),
+ 'NIST CSRC Sample #1 (SHA3-512)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '4041424344454647',
+ xl('Sample message for keylen=blocklen'),
+ dict(SHA3_512='544e257ea2a3e5ea19a590e6a24b724ce6327757723fe2751b75bf007d80f6b360744bf1b7a88ea585f9765b47911976d3191cf83c039f5ffab0d29cc9d9b6da'),
+ 'NIST CSRC Sample #2 (SHA3-512)'
+ ),
+ (
+ '000102030405060708090a0b0c0d0e0f'\
+ '101112131415161718191a1b1c1d1e1f'\
+ '202122232425262728292a2b2c2d2e2f'\
+ '303132333435363738393a3b3c3d3e3f'\
+ '404142434445464748494a4b4c4d4e4f'\
+ '505152535455565758595a5b5c5d5e5f'\
+ '606162636465666768696a6b6c6d6e6f'\
+ '707172737475767778797a7b7c7d7e7f'\
+ '8081828384858687',
+ xl('Sample message for keylen>blocklen'),
+ dict(SHA3_512='5f464f5e5b7848e3885e49b2c385f0694985d0e38966242dc4a5fe3fea4b37d46b65ceced5dcf59438dd840bab22269f0ba7febdb9fcf74602a35666b2a32915'),
+ 'NIST CSRC Sample #3 (SHA3-512)'
+ ),
+
+]
+
+
+class HMAC_Module_and_Instance_Test(unittest.TestCase):
+ """Test the HMAC construction and verify that it does not
+ matter if you initialize it with a hash module or
+    with a hash instance.
+
+ See https://bugs.launchpad.net/pycrypto/+bug/1209399
+ """
+
+ def __init__(self, hashmods):
+ """Initialize the test with a dictionary of hash modules
+ indexed by their names"""
+
+ unittest.TestCase.__init__(self)
+ self.hashmods = hashmods
+ self.description = ""
+
+ def shortDescription(self):
+ return self.description
+
+ def runTest(self):
+ key = b"\x90\x91\x92\x93" * 4
+ payload = b"\x00" * 100
+
+ for hashname, hashmod in self.hashmods.items():
+ if hashmod is None:
+ continue
+ self.description = "Test HMAC in combination with " + hashname
+ one = HMAC.new(key, payload, hashmod).digest()
+ two = HMAC.new(key, payload, hashmod.new()).digest()
+ self.assertEqual(one, two)
+
+
+class HMAC_None(unittest.TestCase):
+
+ def runTest(self):
+
+ key = b"\x04" * 20
+ one = HMAC.new(key, b"", SHA1).digest()
+ two = HMAC.new(key, None, SHA1).digest()
+ self.assertEqual(one, two)
+
+
+class ByteArrayTests(unittest.TestCase):
+
+ def runTest(self):
+
+ key = b"0" * 16
+ data = b"\x00\x01\x02"
+
+ # Data and key can be a bytearray (during initialization)
+ key_ba = bytearray(key)
+ data_ba = bytearray(data)
+
+ h1 = HMAC.new(key, data)
+ h2 = HMAC.new(key_ba, data_ba)
+ key_ba[:1] = b'\xFF'
+ data_ba[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a bytearray (during operation)
+ key_ba = bytearray(key)
+ data_ba = bytearray(data)
+
+ h1 = HMAC.new(key)
+ h2 = HMAC.new(key)
+ h1.update(data)
+ h2.update(data_ba)
+ data_ba[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class MemoryViewTests(unittest.TestCase):
+
+ def runTest(self):
+
+ key = b"0" * 16
+ data = b"\x00\x01\x02"
+
+ def get_mv_ro(data):
+ return memoryview(data)
+
+ def get_mv_rw(data):
+ return memoryview(bytearray(data))
+
+ for get_mv in (get_mv_ro, get_mv_rw):
+
+ # Data and key can be a memoryview (during initialization)
+ key_mv = get_mv(key)
+ data_mv = get_mv(data)
+
+ h1 = HMAC.new(key, data)
+ h2 = HMAC.new(key_mv, data_mv)
+ if not data_mv.readonly:
+ key_mv[:1] = b'\xFF'
+ data_mv[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a memoryview (during operation)
+ data_mv = get_mv(data)
+
+ h1 = HMAC.new(key)
+ h2 = HMAC.new(key)
+ h1.update(data)
+ h2.update(data_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+def get_tests(config={}):
+ global test_data
+ import types
+ from .common import make_mac_tests
+
+ # A test vector contains multiple results, each one for a
+ # different hash algorithm.
+ # Here we expand each test vector into multiple ones,
+ # and add the relevant parameters that will be passed to new()
+ exp_test_data = []
+ for row in test_data:
+ for modname in row[2].keys():
+ t = list(row)
+ t[2] = row[2][modname]
+ t.append(dict(digestmod=globals()[modname]))
+ exp_test_data.append(t)
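+    # For example, a vector whose results are dict(MD5=..., SHA1=...) expands
+    # into two entries, one ending in dict(digestmod=MD5) and the other in
+    # dict(digestmod=SHA1), so every hash algorithm gets its own test case.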
+ tests = make_mac_tests(HMAC, "HMAC", exp_test_data)
+ tests.append(HMAC_Module_and_Instance_Test(hash_modules))
+ tests.append(HMAC_None())
+
+ tests.append(ByteArrayTests())
+ tests.append(MemoryViewTests())
+
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_KMAC.py b/lib/Crypto/SelfTest/Hash/test_KMAC.py
new file mode 100644
index 0000000..8e9bf70
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_KMAC.py
@@ -0,0 +1,346 @@
+import unittest
+from binascii import unhexlify, hexlify
+
+from Crypto.Util.py3compat import tobytes
+from Crypto.Util.strxor import strxor_c
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import KMAC128, KMAC256
+
+
+class KMACTest(unittest.TestCase):
+
+ def new(self, *args, **kwargs):
+ return self.KMAC.new(key=b'X' * (self.minimum_key_bits // 8), *args, **kwargs)
+
+ def test_new_positive(self):
+
+ key = b'X' * 32
+
+ h = self.new()
+ for new_func in self.KMAC.new, h.new:
+
+ for dbytes in range(self.minimum_bytes, 128 + 1):
+ hobj = new_func(key=key, mac_len=dbytes)
+ self.assertEqual(hobj.digest_size, dbytes)
+
+ digest1 = new_func(key=key, data=b"\x90").digest()
+ digest2 = new_func(key=key).update(b"\x90").digest()
+ self.assertEqual(digest1, digest2)
+
+ new_func(data=b"A", key=key, custom=b"g")
+
+ hobj = h.new(key=key)
+ self.assertEqual(hobj.digest_size, self.default_bytes)
+
+ def test_new_negative(self):
+
+ h = self.new()
+ for new_func in self.KMAC.new, h.new:
+ self.assertRaises(ValueError, new_func, key=b'X'*32,
+ mac_len=0)
+ self.assertRaises(ValueError, new_func, key=b'X'*32,
+ mac_len=self.minimum_bytes - 1)
+ self.assertRaises(TypeError, new_func,
+ key=u"string")
+ self.assertRaises(TypeError, new_func,
+ data=u"string")
+
+ def test_default_digest_size(self):
+ digest = self.new(data=b'abc').digest()
+ self.assertEqual(len(digest), self.default_bytes)
+
+ def test_update(self):
+ pieces = [b"\x0A" * 200, b"\x14" * 300]
+ h = self.new()
+ h.update(pieces[0]).update(pieces[1])
+ digest = h.digest()
+ h = self.new()
+ h.update(pieces[0] + pieces[1])
+ self.assertEqual(h.digest(), digest)
+
+ def test_update_negative(self):
+ h = self.new()
+ self.assertRaises(TypeError, h.update, u"string")
+
+ def test_digest(self):
+ h = self.new()
+ digest = h.digest()
+
+        # digest() does not change the state
+ self.assertEqual(h.digest(), digest)
+ # digest returns a byte string
+ self.assertTrue(isinstance(digest, type(b"digest")))
+
+ def test_update_after_digest(self):
+ msg = b"rrrrttt"
+
+ # Normally, update() cannot be done after digest()
+ h = self.new(mac_len=32, data=msg[:4])
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, dig1)
+
+ def test_hex_digest(self):
+ mac = self.new()
+ digest = mac.digest()
+ hexdigest = mac.hexdigest()
+
+ # hexdigest is equivalent to digest
+ self.assertEqual(hexlify(digest), tobytes(hexdigest))
+ # hexdigest does not change the state
+ self.assertEqual(mac.hexdigest(), hexdigest)
+ # hexdigest returns a string
+ self.assertTrue(isinstance(hexdigest, type("digest")))
+
+ def test_verify(self):
+ h = self.new()
+ mac = h.digest()
+ h.verify(mac)
+ wrong_mac = strxor_c(mac, 255)
+ self.assertRaises(ValueError, h.verify, wrong_mac)
+
+ def test_hexverify(self):
+ h = self.new()
+ mac = h.hexdigest()
+ h.hexverify(mac)
+ self.assertRaises(ValueError, h.hexverify, "4556")
+
+ def test_oid(self):
+
+ oid = "2.16.840.1.101.3.4.2." + self.oid_variant
+ h = self.new()
+ self.assertEqual(h.oid, oid)
+
+ def test_bytearray(self):
+
+ key = b'0' * 32
+ data = b"\x00\x01\x02"
+
+ # Data and key can be a bytearray (during initialization)
+ key_ba = bytearray(key)
+ data_ba = bytearray(data)
+
+ h1 = self.KMAC.new(data=data, key=key)
+ h2 = self.KMAC.new(data=data_ba, key=key_ba)
+ key_ba[:1] = b'\xFF'
+ data_ba[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a bytearray (during operation)
+ data_ba = bytearray(data)
+
+ h1 = self.new()
+ h2 = self.new()
+ h1.update(data)
+ h2.update(data_ba)
+ data_ba[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ def test_memoryview(self):
+
+ key = b'0' * 32
+ data = b"\x00\x01\x02"
+
+ def get_mv_ro(data):
+ return memoryview(data)
+
+ def get_mv_rw(data):
+ return memoryview(bytearray(data))
+
+ for get_mv in (get_mv_ro, get_mv_rw):
+
+ # Data and key can be a memoryview (during initialization)
+ key_mv = get_mv(key)
+ data_mv = get_mv(data)
+
+ h1 = self.KMAC.new(data=data, key=key)
+ h2 = self.KMAC.new(data=data_mv, key=key_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+ key_mv[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a memoryview (during operation)
+ data_mv = get_mv(data)
+
+ h1 = self.new()
+ h2 = self.new()
+ h1.update(data)
+ h2.update(data_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class KMAC128Test(KMACTest):
+
+ KMAC = KMAC128
+
+ minimum_key_bits = 128
+
+ minimum_bytes = 8
+ default_bytes = 64
+
+ oid_variant = "19"
+
+
+class KMAC256Test(KMACTest):
+
+ KMAC = KMAC256
+
+ minimum_key_bits = 256
+
+ minimum_bytes = 8
+ default_bytes = 64
+
+ oid_variant = "20"
+
+
+class NISTExampleTestVectors(unittest.TestCase):
+
+ # https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Standards-and-Guidelines/documents/examples/KMAC_samples.pdf
+ test_data = [
+ (
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F",
+ "00 01 02 03",
+ "",
+ "E5 78 0B 0D 3E A6 F7 D3 A4 29 C5 70 6A A4 3A 00"
+ "FA DB D7 D4 96 28 83 9E 31 87 24 3F 45 6E E1 4E",
+ "Sample #1 NIST",
+ KMAC128
+ ),
+ (
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F",
+ "00 01 02 03",
+ "My Tagged Application",
+ "3B 1F BA 96 3C D8 B0 B5 9E 8C 1A 6D 71 88 8B 71"
+ "43 65 1A F8 BA 0A 70 70 C0 97 9E 28 11 32 4A A5",
+ "Sample #2 NIST",
+ KMAC128
+ ),
+ (
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F",
+ "00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F"
+ "10 11 12 13 14 15 16 17 18 19 1A 1B 1C 1D 1E 1F"
+ "20 21 22 23 24 25 26 27 28 29 2A 2B 2C 2D 2E 2F"
+ "30 31 32 33 34 35 36 37 38 39 3A 3B 3C 3D 3E 3F"
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F"
+ "60 61 62 63 64 65 66 67 68 69 6A 6B 6C 6D 6E 6F"
+ "70 71 72 73 74 75 76 77 78 79 7A 7B 7C 7D 7E 7F"
+ "80 81 82 83 84 85 86 87 88 89 8A 8B 8C 8D 8E 8F"
+ "90 91 92 93 94 95 96 97 98 99 9A 9B 9C 9D 9E 9F"
+ "A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 AA AB AC AD AE AF"
+ "B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 BA BB BC BD BE BF"
+ "C0 C1 C2 C3 C4 C5 C6 C7",
+ "My Tagged Application",
+ "1F 5B 4E 6C CA 02 20 9E 0D CB 5C A6 35 B8 9A 15"
+ "E2 71 EC C7 60 07 1D FD 80 5F AA 38 F9 72 92 30",
+ "Sample #3 NIST",
+ KMAC128
+ ),
+ (
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F",
+ "00 01 02 03",
+ "My Tagged Application",
+ "20 C5 70 C3 13 46 F7 03 C9 AC 36 C6 1C 03 CB 64"
+ "C3 97 0D 0C FC 78 7E 9B 79 59 9D 27 3A 68 D2 F7"
+ "F6 9D 4C C3 DE 9D 10 4A 35 16 89 F2 7C F6 F5 95"
+ "1F 01 03 F3 3F 4F 24 87 10 24 D9 C2 77 73 A8 DD",
+ "Sample #4 NIST",
+ KMAC256
+ ),
+ (
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F",
+ "00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F"
+ "10 11 12 13 14 15 16 17 18 19 1A 1B 1C 1D 1E 1F"
+ "20 21 22 23 24 25 26 27 28 29 2A 2B 2C 2D 2E 2F"
+ "30 31 32 33 34 35 36 37 38 39 3A 3B 3C 3D 3E 3F"
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F"
+ "60 61 62 63 64 65 66 67 68 69 6A 6B 6C 6D 6E 6F"
+ "70 71 72 73 74 75 76 77 78 79 7A 7B 7C 7D 7E 7F"
+ "80 81 82 83 84 85 86 87 88 89 8A 8B 8C 8D 8E 8F"
+ "90 91 92 93 94 95 96 97 98 99 9A 9B 9C 9D 9E 9F"
+ "A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 AA AB AC AD AE AF"
+ "B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 BA BB BC BD BE BF"
+ "C0 C1 C2 C3 C4 C5 C6 C7",
+ "",
+ "75 35 8C F3 9E 41 49 4E 94 97 07 92 7C EE 0A F2"
+ "0A 3F F5 53 90 4C 86 B0 8F 21 CC 41 4B CF D6 91"
+ "58 9D 27 CF 5E 15 36 9C BB FF 8B 9A 4C 2E B1 78"
+ "00 85 5D 02 35 FF 63 5D A8 25 33 EC 6B 75 9B 69",
+ "Sample #5 NIST",
+ KMAC256
+ ),
+ (
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F",
+ "00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F"
+ "10 11 12 13 14 15 16 17 18 19 1A 1B 1C 1D 1E 1F"
+ "20 21 22 23 24 25 26 27 28 29 2A 2B 2C 2D 2E 2F"
+ "30 31 32 33 34 35 36 37 38 39 3A 3B 3C 3D 3E 3F"
+ "40 41 42 43 44 45 46 47 48 49 4A 4B 4C 4D 4E 4F"
+ "50 51 52 53 54 55 56 57 58 59 5A 5B 5C 5D 5E 5F"
+ "60 61 62 63 64 65 66 67 68 69 6A 6B 6C 6D 6E 6F"
+ "70 71 72 73 74 75 76 77 78 79 7A 7B 7C 7D 7E 7F"
+ "80 81 82 83 84 85 86 87 88 89 8A 8B 8C 8D 8E 8F"
+ "90 91 92 93 94 95 96 97 98 99 9A 9B 9C 9D 9E 9F"
+ "A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 AA AB AC AD AE AF"
+ "B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 BA BB BC BD BE BF"
+ "C0 C1 C2 C3 C4 C5 C6 C7",
+ "My Tagged Application",
+ "B5 86 18 F7 1F 92 E1 D5 6C 1B 8C 55 DD D7 CD 18"
+ "8B 97 B4 CA 4D 99 83 1E B2 69 9A 83 7D A2 E4 D9"
+ "70 FB AC FD E5 00 33 AE A5 85 F1 A2 70 85 10 C3"
+ "2D 07 88 08 01 BD 18 28 98 FE 47 68 76 FC 89 65",
+ "Sample #6 NIST",
+ KMAC256
+ ),
+ ]
+
+ def setUp(self):
+ td = []
+ for key, data, custom, mac, text, module in self.test_data:
+ ni = (
+ unhexlify(key.replace(" ", "")),
+ unhexlify(data.replace(" ", "")),
+ custom.encode(),
+ unhexlify(mac.replace(" ", "")),
+ text,
+ module
+ )
+ td.append(ni)
+ self.test_data = td
+
+ def runTest(self):
+
+ for key, data, custom, mac, text, module in self.test_data:
+ h = module.new(data=data, key=key, custom=custom, mac_len=len(mac))
+ mac_tag = h.digest()
+ self.assertEqual(mac_tag, mac, msg=text)
+
+
+def get_tests(config={}):
+ tests = []
+
+ tests += list_test_cases(KMAC128Test)
+ tests += list_test_cases(KMAC256Test)
+ tests.append(NISTExampleTestVectors())
+
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_KangarooTwelve.py b/lib/Crypto/SelfTest/Hash/test_KangarooTwelve.py
new file mode 100644
index 0000000..49aeaad
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_KangarooTwelve.py
@@ -0,0 +1,324 @@
+# ===================================================================
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.KangarooTwelve"""
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import KangarooTwelve as K12
+from Crypto.Util.py3compat import b, bchr
+
+
+class KangarooTwelveTest(unittest.TestCase):
+
+ def test_length_encode(self):
+ self.assertEqual(K12._length_encode(0), b'\x00')
+ self.assertEqual(K12._length_encode(12), b'\x0C\x01')
+ self.assertEqual(K12._length_encode(65538), b'\x01\x00\x02\x03')
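+
+        # The vectors above follow the KangarooTwelve length encoding: the
+        # value is written as its minimal big-endian byte string, followed by
+        # one byte holding the number of value bytes; e.g.
+        #   65538 -> b'\x01\x00\x02' (3 bytes) + count byte b'\x03'
+        # and 0 is encoded as no value bytes plus the count byte b'\x00'.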
+
+ def test_new_positive(self):
+
+ xof1 = K12.new()
+ xof2 = K12.new(data=b("90"))
+ xof3 = K12.new().update(b("90"))
+
+ self.assertNotEqual(xof1.read(10), xof2.read(10))
+ xof3.read(10)
+ self.assertEqual(xof2.read(10), xof3.read(10))
+
+ xof1 = K12.new()
+ ref = xof1.read(10)
+ xof2 = K12.new(custom=b(""))
+ xof3 = K12.new(custom=b("foo"))
+
+ self.assertEqual(ref, xof2.read(10))
+ self.assertNotEqual(ref, xof3.read(10))
+
+ xof1 = K12.new(custom=b("foo"))
+ xof2 = K12.new(custom=b("foo"), data=b("90"))
+ xof3 = K12.new(custom=b("foo")).update(b("90"))
+
+ self.assertNotEqual(xof1.read(10), xof2.read(10))
+ xof3.read(10)
+ self.assertEqual(xof2.read(10), xof3.read(10))
+
+ def test_update(self):
+ pieces = [bchr(10) * 200, bchr(20) * 300]
+ h = K12.new()
+ h.update(pieces[0]).update(pieces[1])
+ digest = h.read(10)
+ h = K12.new()
+ h.update(pieces[0] + pieces[1])
+ self.assertEqual(h.read(10), digest)
+
+ def test_update_negative(self):
+ h = K12.new()
+ self.assertRaises(TypeError, h.update, u"string")
+
+ def test_digest(self):
+ h = K12.new()
+ digest = h.read(90)
+
+ # read returns a byte string of the right length
+ self.assertTrue(isinstance(digest, type(b("digest"))))
+ self.assertEqual(len(digest), 90)
+
+ def test_update_after_read(self):
+ mac = K12.new()
+ mac.update(b("rrrr"))
+ mac.read(90)
+ self.assertRaises(TypeError, mac.update, b("ttt"))
+
+
+def txt2bin(txt):
+ clean = txt.replace(" ", "").replace("\n", "").replace("\r", "")
+ return unhexlify(clean)
+
+
+def ptn(n):
+ res = bytearray(n)
+ pattern = b"".join([bchr(x) for x in range(0, 0xFB)])
+ for base in range(0, n - 0xFB, 0xFB):
+ res[base:base + 0xFB] = pattern
+ remain = n % 0xFB
+ if remain:
+ base = (n // 0xFB) * 0xFB
+ res[base:] = pattern[:remain]
+ assert(len(res) == n)
+ return res
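+
+# ptn(n) reproduces the input pattern used by the XKCP KangarooTwelve test
+# vectors: the byte sequence 0x00..0xFA repeated cyclically and truncated to
+# n bytes, e.g. ptn(3) == bytearray(b'\x00\x01\x02').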
+
+
+def chunked(source, size):
+ for i in range(0, len(source), size):
+ yield source[i:i+size]
+
+
+# https://github.com/XKCP/XKCP/blob/master/tests/TestVectors/KangarooTwelve.txt
+class KangarooTwelveTV(unittest.TestCase):
+
+ def test_zero_1(self):
+ tv = """1A C2 D4 50 FC 3B 42 05 D1 9D A7 BF CA 1B 37 51
+ 3C 08 03 57 7A C7 16 7F 06 FE 2C E1 F0 EF 39 E5"""
+
+ btv = txt2bin(tv)
+ res = K12.new().read(32)
+ self.assertEqual(res, btv)
+
+ def test_zero_2(self):
+ tv = """1A C2 D4 50 FC 3B 42 05 D1 9D A7 BF CA 1B 37 51
+ 3C 08 03 57 7A C7 16 7F 06 FE 2C E1 F0 EF 39 E5
+ 42 69 C0 56 B8 C8 2E 48 27 60 38 B6 D2 92 96 6C
+ C0 7A 3D 46 45 27 2E 31 FF 38 50 81 39 EB 0A 71"""
+
+ btv = txt2bin(tv)
+ res = K12.new().read(64)
+ self.assertEqual(res, btv)
+
+ def test_zero_3(self):
+ tv = """E8 DC 56 36 42 F7 22 8C 84 68 4C 89 84 05 D3 A8
+ 34 79 91 58 C0 79 B1 28 80 27 7A 1D 28 E2 FF 6D"""
+
+ btv = txt2bin(tv)
+ res = K12.new().read(10032)
+ self.assertEqual(res[-32:], btv)
+
+ def test_ptn_1(self):
+ tv = """2B DA 92 45 0E 8B 14 7F 8A 7C B6 29 E7 84 A0 58
+ EF CA 7C F7 D8 21 8E 02 D3 45 DF AA 65 24 4A 1F"""
+
+ btv = txt2bin(tv)
+ res = K12.new(data=ptn(1)).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_17(self):
+ tv = """6B F7 5F A2 23 91 98 DB 47 72 E3 64 78 F8 E1 9B
+ 0F 37 12 05 F6 A9 A9 3A 27 3F 51 DF 37 12 28 88"""
+
+ btv = txt2bin(tv)
+ res = K12.new(data=ptn(17)).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_17_2(self):
+ tv = """0C 31 5E BC DE DB F6 14 26 DE 7D CF 8F B7 25 D1
+ E7 46 75 D7 F5 32 7A 50 67 F3 67 B1 08 EC B6 7C"""
+
+ btv = txt2bin(tv)
+ res = K12.new(data=ptn(17**2)).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_17_3(self):
+ tv = """CB 55 2E 2E C7 7D 99 10 70 1D 57 8B 45 7D DF 77
+ 2C 12 E3 22 E4 EE 7F E4 17 F9 2C 75 8F 0D 59 D0"""
+
+ btv = txt2bin(tv)
+ res = K12.new(data=ptn(17**3)).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_17_4(self):
+ tv = """87 01 04 5E 22 20 53 45 FF 4D DA 05 55 5C BB 5C
+ 3A F1 A7 71 C2 B8 9B AE F3 7D B4 3D 99 98 B9 FE"""
+
+ btv = txt2bin(tv)
+ data = ptn(17**4)
+
+ # All at once
+ res = K12.new(data=data).read(32)
+ self.assertEqual(res, btv)
+
+ # Byte by byte
+ k12 = K12.new()
+ for x in data:
+ k12.update(bchr(x))
+ res = k12.read(32)
+ self.assertEqual(res, btv)
+
+ # Chunks of various prime sizes
+ for chunk_size in (13, 17, 19, 23, 31):
+ k12 = K12.new()
+ for x in chunked(data, chunk_size):
+ k12.update(x)
+ res = k12.read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_17_5(self):
+ tv = """84 4D 61 09 33 B1 B9 96 3C BD EB 5A E3 B6 B0 5C
+ C7 CB D6 7C EE DF 88 3E B6 78 A0 A8 E0 37 16 82"""
+
+ btv = txt2bin(tv)
+ data = ptn(17**5)
+
+ # All at once
+ res = K12.new(data=data).read(32)
+ self.assertEqual(res, btv)
+
+ # Chunks
+ k12 = K12.new()
+ for chunk in chunked(data, 8192):
+ k12.update(chunk)
+ res = k12.read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_17_6(self):
+ tv = """3C 39 07 82 A8 A4 E8 9F A6 36 7F 72 FE AA F1 32
+ 55 C8 D9 58 78 48 1D 3C D8 CE 85 F5 8E 88 0A F8"""
+
+ btv = txt2bin(tv)
+ data = ptn(17**6)
+
+ # All at once
+ res = K12.new(data=data).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_c_1(self):
+ tv = """FA B6 58 DB 63 E9 4A 24 61 88 BF 7A F6 9A 13 30
+ 45 F4 6E E9 84 C5 6E 3C 33 28 CA AF 1A A1 A5 83"""
+
+ btv = txt2bin(tv)
+ custom = ptn(1)
+
+ # All at once
+ res = K12.new(custom=custom).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_c_41(self):
+ tv = """D8 48 C5 06 8C ED 73 6F 44 62 15 9B 98 67 FD 4C
+ 20 B8 08 AC C3 D5 BC 48 E0 B0 6B A0 A3 76 2E C4"""
+
+ btv = txt2bin(tv)
+ custom = ptn(41)
+
+ # All at once
+ res = K12.new(data=b'\xFF', custom=custom).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_c_41_2(self):
+ tv = """C3 89 E5 00 9A E5 71 20 85 4C 2E 8C 64 67 0A C0
+ 13 58 CF 4C 1B AF 89 44 7A 72 42 34 DC 7C ED 74"""
+
+ btv = txt2bin(tv)
+ custom = ptn(41**2)
+
+ # All at once
+ res = K12.new(data=b'\xFF' * 3, custom=custom).read(32)
+ self.assertEqual(res, btv)
+
+ def test_ptn_c_41_3(self):
+ tv = """75 D2 F8 6A 2E 64 45 66 72 6B 4F BC FC 56 57 B9
+ DB CF 07 0C 7B 0D CA 06 45 0A B2 91 D7 44 3B CF"""
+
+ btv = txt2bin(tv)
+ custom = ptn(41**3)
+
+ # All at once
+ res = K12.new(data=b'\xFF' * 7, custom=custom).read(32)
+ self.assertEqual(res, btv)
+
+ ###
+
+ def test_1(self):
+ tv = "fd608f91d81904a9916e78a18f65c157a78d63f93d8f6367db0524526a5ea2bb"
+
+ btv = txt2bin(tv)
+ res = K12.new(data=b'', custom=ptn(100)).read(32)
+ self.assertEqual(res, btv)
+
+ def test_2(self):
+ tv4 = "5a4ec9a649f81916d4ce1553492962f7868abf8dd1ceb2f0cb3682ea95cda6a6"
+ tv3 = "441688fe4fe4ae9425eb3105eb445eb2b3a6f67b66eff8e74ebfbc49371f6d4c"
+ tv2 = "17269a57759af0214c84a0fd9bc851f4d95f80554cfed4e7da8a6ee1ff080131"
+ tv1 = "33826990c09dc712ba7224f0d9be319e2720de95a4c1afbd2211507dae1c703a"
+ tv0 = "9f4d3aba908ddc096e4d3a71da954f917b9752f05052b9d26d916a6fbc75bf3e"
+
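+        # Message lengths just below and at the 8192-byte chunk boundary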
+ res = K12.new(data=b'A' * (8192 - 4), custom=b'B').read(32)
+ self.assertEqual(res, txt2bin(tv4))
+
+ res = K12.new(data=b'A' * (8192 - 3), custom=b'B').read(32)
+ self.assertEqual(res, txt2bin(tv3))
+
+ res = K12.new(data=b'A' * (8192 - 2), custom=b'B').read(32)
+ self.assertEqual(res, txt2bin(tv2))
+
+ res = K12.new(data=b'A' * (8192 - 1), custom=b'B').read(32)
+ self.assertEqual(res, txt2bin(tv1))
+
+ res = K12.new(data=b'A' * (8192 - 0), custom=b'B').read(32)
+ self.assertEqual(res, txt2bin(tv0))
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(KangarooTwelveTest)
+ tests += list_test_cases(KangarooTwelveTV)
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_MD2.py b/lib/Crypto/SelfTest/Hash/test_MD2.py
new file mode 100644
index 0000000..9375168
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_MD2.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/MD2.py: Self-test for the MD2 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.MD2"""
+
+from Crypto.Util.py3compat import *
+
+# This is a list of (expected_result, input[, description]) tuples.
+test_data = [
+ # Test vectors from RFC 1319
+ ('8350e5a3e24c153df2275c9f80692773', '', "'' (empty string)"),
+ ('32ec01ec4a6dac72c0ab96fb34c0b5d1', 'a'),
+ ('da853b0d3f88d99b30283a69e6ded6bb', 'abc'),
+ ('ab4f496bfb2a530b219ff33031fe06b0', 'message digest'),
+
+ ('4e8ddff3650292ab5a4108c3aa47940b', 'abcdefghijklmnopqrstuvwxyz',
+ 'a-z'),
+
+ ('da33def2a42df13975352846c30338cd',
+ 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789',
+ 'A-Z, a-z, 0-9'),
+
+ ('d5976f79d83d3a0dc9806c3c66f3efd8',
+ '1234567890123456789012345678901234567890123456'
+ + '7890123456789012345678901234567890',
+ "'1234567890' * 8"),
+]
+
+def get_tests(config={}):
+ from Crypto.Hash import MD2
+ from .common import make_hash_tests
+ return make_hash_tests(MD2, "MD2", test_data,
+ digest_size=16,
+ oid="1.2.840.113549.2.2")
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_MD4.py b/lib/Crypto/SelfTest/Hash/test_MD4.py
new file mode 100644
index 0000000..17b48a7
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_MD4.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/MD4.py: Self-test for the MD4 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.MD4"""
+
+__revision__ = "$Id$"
+
+from Crypto.Util.py3compat import *
+
+# This is a list of (expected_result, input[, description]) tuples.
+test_data = [
+ # Test vectors from RFC 1320
+ ('31d6cfe0d16ae931b73c59d7e0c089c0', '', "'' (empty string)"),
+ ('bde52cb31de33e46245e05fbdbd6fb24', 'a'),
+ ('a448017aaf21d8525fc10ae87aa6729d', 'abc'),
+ ('d9130a8164549fe818874806e1c7014b', 'message digest'),
+
+ ('d79e1c308aa5bbcdeea8ed63df412da9', 'abcdefghijklmnopqrstuvwxyz',
+ 'a-z'),
+
+ ('043f8582f241db351ce627e153e7f0e4',
+ 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789',
+ 'A-Z, a-z, 0-9'),
+
+ ('e33b4ddc9c38f2199c3e7b164fcc0536',
+ '1234567890123456789012345678901234567890123456'
+ + '7890123456789012345678901234567890',
+ "'1234567890' * 8"),
+]
+
+def get_tests(config={}):
+ from Crypto.Hash import MD4
+ from .common import make_hash_tests
+ return make_hash_tests(MD4, "MD4", test_data,
+ digest_size=16,
+ oid="1.2.840.113549.2.4")
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_MD5.py b/lib/Crypto/SelfTest/Hash/test_MD5.py
new file mode 100644
index 0000000..830ace7
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_MD5.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/MD5.py: Self-test for the MD5 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.MD5"""
+
+from Crypto.Util.py3compat import *
+from Crypto.Hash import MD5
+from binascii import unhexlify
+import unittest
+from Crypto.SelfTest.st_common import list_test_cases
+
+
+# This is a list of (expected_result, input[, description]) tuples.
+test_data = [
+ # Test vectors from RFC 1321
+ ('d41d8cd98f00b204e9800998ecf8427e', '', "'' (empty string)"),
+ ('0cc175b9c0f1b6a831c399e269772661', 'a'),
+ ('900150983cd24fb0d6963f7d28e17f72', 'abc'),
+ ('f96b697d7cb7938d525a2f31aaf161d0', 'message digest'),
+
+ ('c3fcd3d76192e4007dfb496cca67e13b', 'abcdefghijklmnopqrstuvwxyz',
+ 'a-z'),
+
+ ('d174ab98d277d9f5a5611c2c9f419d9f',
+ 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789',
+ 'A-Z, a-z, 0-9'),
+
+ ('57edf4a22be3c955ac49da2e2107b67a',
+ '1234567890123456789012345678901234567890123456'
+ + '7890123456789012345678901234567890',
+ "'1234567890' * 8"),
+
+ # https://www.cosic.esat.kuleuven.be/nessie/testvectors/hash/md5/Md5-128.unverified.test-vectors
+ ('57EDF4A22BE3C955AC49DA2E2107B67A', '1234567890' * 8, 'Set 1, vector #7'),
+ ('7707D6AE4E027C70EEA2A935C2296F21', 'a'*1000000, 'Set 1, vector #8'),
+]
+
+
+class Md5IterTest(unittest.TestCase):
+
+ def runTest(self):
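+        # Iterated MD5: keep hashing the previous digest, 100000 hashes in total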
+ message = b("\x00") * 16
+ result1 = "4AE71336E44BF9BF79D2752E234818A5".lower()
+ result2 = "1A83F51285E4D89403D00C46EF8508FE".lower()
+
+ h = MD5.new(message)
+ message = h.digest()
+ self.assertEqual(h.hexdigest(), result1)
+
+ for _ in range(99999):
+ h = MD5.new(message)
+ message = h.digest()
+
+ self.assertEqual(h.hexdigest(), result2)
+
+
+def get_tests(config={}):
+ from .common import make_hash_tests
+
+ tests = make_hash_tests(MD5, "MD5", test_data,
+ digest_size=16,
+ oid="1.2.840.113549.2.5")
+ if config.get('slow_tests'):
+ tests += [ Md5IterTest() ]
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_Poly1305.py b/lib/Crypto/SelfTest/Hash/test_Poly1305.py
new file mode 100644
index 0000000..0612d4e
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_Poly1305.py
@@ -0,0 +1,542 @@
+#
+# SelfTest/Hash/test_Poly1305.py: Self-test for the Poly1305 module
+#
+# ===================================================================
+#
+# Copyright (c) 2018, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash._Poly1305"""
+
+import json
+import unittest
+from binascii import unhexlify, hexlify
+
+from .common import make_mac_tests
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import Poly1305
+from Crypto.Cipher import AES, ChaCha20
+
+from Crypto.Util.py3compat import tobytes
+from Crypto.Util.strxor import strxor_c
+
+# This is a list of (r+s keypair, data, result, description) tuples.
+test_data_basic = [
+ (
+ "85d6be7857556d337f4452fe42d506a80103808afb0db2fd4abff6af4149f51b",
+ hexlify(b"Cryptographic Forum Research Group").decode(),
+ "a8061dc1305136c6c22b8baf0c0127a9",
+ "RFC7539"
+ ),
+ (
+ "746869732069732033322d62797465206b657920666f7220506f6c7931333035",
+ "0000000000000000000000000000000000000000000000000000000000000000",
+ "49ec78090e481ec6c26b33b91ccc0307",
+ "https://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-00#section-7 A",
+ ),
+ (
+ "746869732069732033322d62797465206b657920666f7220506f6c7931333035",
+ "48656c6c6f20776f726c6421",
+ "a6f745008f81c916a20dcc74eef2b2f0",
+ "https://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-00#section-7 B",
+ ),
+ (
+ "746869732069732033322d62797465206b657920666f7220506f6c7931333035",
+ "",
+ "6b657920666f7220506f6c7931333035",
+ "Generated with pure Python",
+ ),
+ (
+ "746869732069732033322d62797465206b657920666f7220506f6c7931333035",
+ "FF",
+ "f7e4e0ef4c46d106219da3d1bdaeb3ff",
+ "Generated with pure Python",
+ ),
+ (
+ "746869732069732033322d62797465206b657920666f7220506f6c7931333035",
+ "FF00",
+ "7471eceeb22988fc936da1d6e838b70e",
+ "Generated with pure Python",
+ ),
+ (
+ "746869732069732033322d62797465206b657920666f7220506f6c7931333035",
+ "AA" * 17,
+ "32590bc07cb2afaccca3f67f122975fe",
+ "Generated with pure Python",
+ ),
+ (
+ "00" * 32,
+ "00" * 64,
+ "00" * 16,
+ "RFC7539 A.3 #1",
+ ),
+ (
+ "0000000000000000000000000000000036e5f6b5c5e06070f0efca96227a863e",
+ hexlify(
+ b"Any submission t"
+ b"o the IETF inten"
+ b"ded by the Contr"
+ b"ibutor for publi"
+ b"cation as all or"
+ b" part of an IETF"
+ b" Internet-Draft "
+ b"or RFC and any s"
+ b"tatement made wi"
+ b"thin the context"
+ b" of an IETF acti"
+ b"vity is consider"
+ b"ed an \"IETF Cont"
+ b"ribution\". Such "
+ b"statements inclu"
+ b"de oral statemen"
+ b"ts in IETF sessi"
+ b"ons, as well as "
+ b"written and elec"
+ b"tronic communica"
+ b"tions made at an"
+ b"y time or place,"
+ b" which are addre"
+ b"ssed to").decode(),
+ "36e5f6b5c5e06070f0efca96227a863e",
+ "RFC7539 A.3 #2",
+ ),
+ (
+ "36e5f6b5c5e06070f0efca96227a863e00000000000000000000000000000000",
+ hexlify(
+ b"Any submission t"
+ b"o the IETF inten"
+ b"ded by the Contr"
+ b"ibutor for publi"
+ b"cation as all or"
+ b" part of an IETF"
+ b" Internet-Draft "
+ b"or RFC and any s"
+ b"tatement made wi"
+ b"thin the context"
+ b" of an IETF acti"
+ b"vity is consider"
+ b"ed an \"IETF Cont"
+ b"ribution\". Such "
+ b"statements inclu"
+ b"de oral statemen"
+ b"ts in IETF sessi"
+ b"ons, as well as "
+ b"written and elec"
+ b"tronic communica"
+ b"tions made at an"
+ b"y time or place,"
+ b" which are addre"
+ b"ssed to").decode(),
+ "f3477e7cd95417af89a6b8794c310cf0",
+ "RFC7539 A.3 #3",
+ ),
+ (
+ "1c9240a5eb55d38af333888604f6b5f0473917c1402b80099dca5cbc207075c0",
+ "2754776173206272696c6c69672c2061"
+ "6e642074686520736c6974687920746f"
+ "7665730a446964206779726520616e64"
+ "2067696d626c6520696e207468652077"
+ "6162653a0a416c6c206d696d73792077"
+ "6572652074686520626f726f676f7665"
+ "732c0a416e6420746865206d6f6d6520"
+ "7261746873206f757467726162652e",
+ "4541669a7eaaee61e708dc7cbcc5eb62",
+ "RFC7539 A.3 #4",
+ ),
+ (
+ "02" + "00" * 31,
+ "FF" * 16,
+ "03" + "00" * 15,
+ "RFC7539 A.3 #5",
+ ),
+ (
+ "02" + "00" * 15 + "FF" * 16,
+ "02" + "00" * 15,
+ "03" + "00" * 15,
+ "RFC7539 A.3 #6",
+ ),
+ (
+ "01" + "00" * 31,
+ "FF" * 16 + "F0" + "FF" * 15 + "11" + "00" * 15,
+ "05" + "00" * 15,
+ "RFC7539 A.3 #7",
+ ),
+ (
+ "01" + "00" * 31,
+ "FF" * 16 + "FB" + "FE" * 15 + "01" * 16,
+ "00" * 16,
+ "RFC7539 A.3 #8",
+ ),
+ (
+ "02" + "00" * 31,
+ "FD" + "FF" * 15,
+ "FA" + "FF" * 15,
+ "RFC7539 A.3 #9",
+ ),
+ (
+ "01 00 00 00 00 00 00 00 04 00 00 00 00 00 00 00"
+ "00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00",
+ "E3 35 94 D7 50 5E 43 B9 00 00 00 00 00 00 00 00"
+ "33 94 D7 50 5E 43 79 CD 01 00 00 00 00 00 00 00"
+ "00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00"
+ "01 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00",
+ "14 00 00 00 00 00 00 00 55 00 00 00 00 00 00 00",
+ "RFC7539 A.3 #10",
+ ),
+ (
+ "01 00 00 00 00 00 00 00 04 00 00 00 00 00 00 00"
+ "00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00",
+ "E3 35 94 D7 50 5E 43 B9 00 00 00 00 00 00 00 00"
+ "33 94 D7 50 5E 43 79 CD 01 00 00 00 00 00 00 00"
+ "00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00",
+ "13" + "00" * 15,
+ "RFC7539 A.3 #11",
+ ),
+]
+
+# This is a list of (key(k+r), data, result, description, keywords) tuples.
+test_data_aes = [
+ (
+ "ec074c835580741701425b623235add6851fc40c3467ac0be05cc20404f3f700",
+ "f3f6",
+ "f4c633c3044fc145f84f335cb81953de",
+ "http://cr.yp.to/mac/poly1305-20050329.pdf",
+ { 'cipher':AES, 'nonce':unhexlify("fb447350c4e868c52ac3275cf9d4327e") }
+ ),
+ (
+ "75deaa25c09f208e1dc4ce6b5cad3fbfa0f3080000f46400d0c7e9076c834403",
+ "",
+ "dd3fab2251f11ac759f0887129cc2ee7",
+ "http://cr.yp.to/mac/poly1305-20050329.pdf",
+ { 'cipher':AES, 'nonce':unhexlify("61ee09218d29b0aaed7e154a2c5509cc") }
+ ),
+ (
+ "6acb5f61a7176dd320c5c1eb2edcdc7448443d0bb0d21109c89a100b5ce2c208",
+ "663cea190ffb83d89593f3f476b6bc24"
+ "d7e679107ea26adb8caf6652d0656136",
+ "0ee1c16bb73f0f4fd19881753c01cdbe",
+ "http://cr.yp.to/mac/poly1305-20050329.pdf",
+ { 'cipher':AES, 'nonce':unhexlify("ae212a55399729595dea458bc621ff0e") }
+ ),
+ (
+ "e1a5668a4d5b66a5f68cc5424ed5982d12976a08c4426d0ce8a82407c4f48207",
+ "ab0812724a7f1e342742cbed374d94d1"
+ "36c6b8795d45b3819830f2c04491faf0"
+ "990c62e48b8018b2c3e4a0fa3134cb67"
+ "fa83e158c994d961c4cb21095c1bf9",
+ "5154ad0d2cb26e01274fc51148491f1b",
+ "http://cr.yp.to/mac/poly1305-20050329.pdf",
+ { 'cipher':AES, 'nonce':unhexlify("9ae831e743978d3a23527c7128149e3a") }
+ ),
+]
+
+test_data_chacha20 = [
+ (
+ "00" * 32,
+ "FF" * 15,
+ "13cc5bbadc36b03a5163928f0bcb65aa",
+ "RFC7539 A.4 #1",
+ { 'cipher':ChaCha20, 'nonce':unhexlify("00" * 12) }
+ ),
+ (
+ "00" * 31 + "01",
+ "FF" * 15,
+ "0baf33c1d6df211bdd50a6767e98e00a",
+ "RFC7539 A.4 #2",
+ { 'cipher':ChaCha20, 'nonce':unhexlify("00" * 11 + "02") }
+ ),
+ (
+ "1c 92 40 a5 eb 55 d3 8a f3 33 88 86 04 f6 b5 f0"
+ "47 39 17 c1 40 2b 80 09 9d ca 5c bc 20 70 75 c0",
+ "FF" * 15,
+ "e8b4c6db226cd8939e65e02eebf834ce",
+ "RFC7539 A.4 #3",
+ { 'cipher':ChaCha20, 'nonce':unhexlify("00" * 11 + "02") }
+ ),
+ (
+ "1c 92 40 a5 eb 55 d3 8a f3 33 88 86 04 f6 b5 f0"
+ "47 39 17 c1 40 2b 80 09 9d ca 5c bc 20 70 75 c0",
+ "f3 33 88 86 00 00 00 00 00 00 4e 91 00 00 00 00"
+ "64 a0 86 15 75 86 1a f4 60 f0 62 c7 9b e6 43 bd"
+ "5e 80 5c fd 34 5c f3 89 f1 08 67 0a c7 6c 8c b2"
+ "4c 6c fc 18 75 5d 43 ee a0 9e e9 4e 38 2d 26 b0"
+ "bd b7 b7 3c 32 1b 01 00 d4 f0 3b 7f 35 58 94 cf"
+ "33 2f 83 0e 71 0b 97 ce 98 c8 a8 4a bd 0b 94 81"
+ "14 ad 17 6e 00 8d 33 bd 60 f9 82 b1 ff 37 c8 55"
+ "97 97 a0 6e f4 f0 ef 61 c1 86 32 4e 2b 35 06 38"
+ "36 06 90 7b 6a 7c 02 b0 f9 f6 15 7b 53 c8 67 e4"
+ "b9 16 6c 76 7b 80 4d 46 a5 9b 52 16 cd e7 a4 e9"
+ "90 40 c5 a4 04 33 22 5e e2 82 a1 b0 a0 6c 52 3e"
+ "af 45 34 d7 f8 3f a1 15 5b 00 47 71 8c bc 54 6a"
+ "0d 07 2b 04 b3 56 4e ea 1b 42 22 73 f5 48 27 1a"
+ "0b b2 31 60 53 fa 76 99 19 55 eb d6 31 59 43 4e"
+ "ce bb 4e 46 6d ae 5a 10 73 a6 72 76 27 09 7a 10"
+ "49 e6 17 d9 1d 36 10 94 fa 68 f0 ff 77 98 71 30"
+ "30 5b ea ba 2e da 04 df 99 7b 71 4d 6c 6f 2c 29"
+ "a6 ad 5c b4 02 2b 02 70 9b 00 00 00 00 00 00 00"
+ "0c 00 00 00 00 00 00 00 09 01 00 00 00 00 00 00",
+ "ee ad 9d 67 89 0c bb 22 39 23 36 fe a1 85 1f 38",
+ "RFC7539 A.5",
+ { 'cipher':ChaCha20, 'nonce':unhexlify("000000000102030405060708") }
+ ),
+]
+
+
+class Poly1305Test_AES(unittest.TestCase):
+
+ key = b'\x11' * 32
+
+ def test_new_positive(self):
+
+ data = b'r' * 100
+
+ h1 = Poly1305.new(key=self.key, cipher=AES)
+ self.assertEqual(h1.digest_size, 16)
+ self.assertEqual(len(h1.nonce), 16)
+ d1 = h1.update(data).digest()
+ self.assertEqual(len(d1), 16)
+
+ h2 = Poly1305.new(key=self.key, nonce=h1.nonce, data=data, cipher=AES)
+ d2 = h2.digest()
+ self.assertEqual(h1.nonce, h2.nonce)
+ self.assertEqual(d1, d2)
+
+ def test_new_negative(self):
+ from Crypto.Cipher import DES3
+
+ self.assertRaises(ValueError, Poly1305.new, key=self.key[:31], cipher=AES)
+ self.assertRaises(ValueError, Poly1305.new, key=self.key, cipher=DES3)
+ self.assertRaises(ValueError, Poly1305.new, key=self.key, nonce=b'1' * 15, cipher=AES)
+ self.assertRaises(TypeError, Poly1305.new, key=u"2" * 32, cipher=AES)
+ self.assertRaises(TypeError, Poly1305.new, key=self.key, data=u"2" * 100, cipher=AES)
+
+ def test_update(self):
+ pieces = [b"\x0A" * 200, b"\x14" * 300]
+ h1 = Poly1305.new(key=self.key, cipher=AES)
+ h1.update(pieces[0]).update(pieces[1])
+ d1 = h1.digest()
+
+ h2 = Poly1305.new(key=self.key, cipher=AES, nonce=h1.nonce)
+ h2.update(pieces[0] + pieces[1])
+ d2 = h2.digest()
+ self.assertEqual(d1, d2)
+
+ def test_update_negative(self):
+ h = Poly1305.new(key=self.key, cipher=AES)
+ self.assertRaises(TypeError, h.update, u"string")
+
+ def test_digest(self):
+ h = Poly1305.new(key=self.key, cipher=AES)
+ digest = h.digest()
+
+        # digest() does not change the state
+ self.assertEqual(h.digest(), digest)
+ # digest returns a byte string
+ self.assertTrue(isinstance(digest, type(b"digest")))
+
+ def test_update_after_digest(self):
+ msg=b"rrrrttt"
+
+ # Normally, update() cannot be done after digest()
+ h = Poly1305.new(key=self.key, data=msg[:4], cipher=AES)
+ h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+
+ def test_hex_digest(self):
+ mac = Poly1305.new(key=self.key, cipher=AES)
+ digest = mac.digest()
+ hexdigest = mac.hexdigest()
+
+ # hexdigest is equivalent to digest
+ self.assertEqual(hexlify(digest), tobytes(hexdigest))
+ # hexdigest does not change the state
+ self.assertEqual(mac.hexdigest(), hexdigest)
+ # hexdigest returns a string
+ self.assertTrue(isinstance(hexdigest, type("digest")))
+
+ def test_verify(self):
+ h = Poly1305.new(key=self.key, cipher=AES)
+ mac = h.digest()
+ h.verify(mac)
+ wrong_mac = strxor_c(mac, 255)
+ self.assertRaises(ValueError, h.verify, wrong_mac)
+
+ def test_hexverify(self):
+ h = Poly1305.new(key=self.key, cipher=AES)
+ mac = h.hexdigest()
+ h.hexverify(mac)
+ self.assertRaises(ValueError, h.hexverify, "4556")
+
+ def test_bytearray(self):
+
+ data = b"\x00\x01\x02"
+ h0 = Poly1305.new(key=self.key, data=data, cipher=AES)
+ d_ref = h0.digest()
+
+ # Data and key can be a bytearray (during initialization)
+ key_ba = bytearray(self.key)
+ data_ba = bytearray(data)
+
+ h1 = Poly1305.new(key=self.key, data=data, cipher=AES, nonce=h0.nonce)
+ h2 = Poly1305.new(key=key_ba, data=data_ba, cipher=AES, nonce=h0.nonce)
+ key_ba[:1] = b'\xFF'
+ data_ba[:1] = b'\xEE'
+
+ self.assertEqual(h1.digest(), d_ref)
+ self.assertEqual(h2.digest(), d_ref)
+
+ # Data can be a bytearray (during operation)
+ data_ba = bytearray(data)
+
+ h1 = Poly1305.new(key=self.key, cipher=AES)
+ h2 = Poly1305.new(key=self.key, cipher=AES, nonce=h1.nonce)
+ h1.update(data)
+ h2.update(data_ba)
+ data_ba[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ def test_memoryview(self):
+
+ data = b"\x00\x01\x02"
+
+ def get_mv_ro(data):
+ return memoryview(data)
+
+ def get_mv_rw(data):
+ return memoryview(bytearray(data))
+
+ for get_mv in (get_mv_ro, get_mv_rw):
+
+ # Data and key can be a memoryview (during initialization)
+ key_mv = get_mv(self.key)
+ data_mv = get_mv(data)
+
+ h1 = Poly1305.new(key=self.key, data=data, cipher=AES)
+ h2 = Poly1305.new(key=key_mv, data=data_mv, cipher=AES,
+ nonce=h1.nonce)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+ key_mv[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ # Data can be a memoryview (during operation)
+ data_mv = get_mv(data)
+
+ h1 = Poly1305.new(key=self.key, cipher=AES)
+ h2 = Poly1305.new(key=self.key, cipher=AES, nonce=h1.nonce)
+ h1.update(data)
+ h2.update(data_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+
+class Poly1305Test_ChaCha20(unittest.TestCase):
+
+ key = b'\x11' * 32
+
+ def test_new_positive(self):
+ data = b'r' * 100
+
+ h1 = Poly1305.new(key=self.key, cipher=ChaCha20)
+ self.assertEqual(h1.digest_size, 16)
+ self.assertEqual(len(h1.nonce), 12)
+
+ h2 = Poly1305.new(key=self.key, cipher=ChaCha20, nonce = b'8' * 8)
+ self.assertEqual(len(h2.nonce), 8)
+ self.assertEqual(h2.nonce, b'8' * 8)
+
+ def test_new_negative(self):
+
+ self.assertRaises(ValueError, Poly1305.new, key=self.key, nonce=b'1' * 7, cipher=ChaCha20)
+
+
+#
+# make_mac_tests() expects a new() function with signature new(key, data,
+# **kwargs), so we adapt Poly1305's new(), which only accepts keyword arguments
+#
+class Poly1305_New(object):
+
+ @staticmethod
+ def new(key, *data, **kwds):
+ _kwds = dict(kwds)
+ if len(data) == 1:
+ _kwds['data'] = data[0]
+ _kwds['key'] = key
+ return Poly1305.new(**_kwds)
+
+
+class Poly1305_Basic(object):
+
+ @staticmethod
+ def new(key, *data, **kwds):
+ from Crypto.Hash.Poly1305 import Poly1305_MAC
+
+ if len(data) == 1:
+ msg = data[0]
+ else:
+ msg = None
+
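+        # Split the 32-byte key into r (first 16 bytes) and s (last 16 bytes)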
+ return Poly1305_MAC(key[:16], key[16:], msg)
+
+
+class Poly1305AES_MC(unittest.TestCase):
+
+ def runTest(self):
+ tag = unhexlify(b"fb447350c4e868c52ac3275cf9d4327e")
+
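+        # Monte Carlo test: each round derives the key and nonce from the
+        # previous tag and appends one byte of it to the growing message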
+ msg = b''
+ for msg_len in range(5000 + 1):
+ key = tag + strxor_c(tag, 0xFF)
+ nonce = tag[::-1]
+ if msg_len > 0:
+ msg = msg + tobytes(tag[0])
+ auth = Poly1305.new(key=key, nonce=nonce, cipher=AES, data=msg)
+ tag = auth.digest()
+
+ # Compare against output of original DJB's poly1305aes-20050218
+ self.assertEqual("CDFA436DDD629C7DC20E1128530BAED2", auth.hexdigest().upper())
+
+
+def get_tests(config={}):
+ tests = make_mac_tests(Poly1305_Basic, "Poly1305", test_data_basic)
+ tests += make_mac_tests(Poly1305_New, "Poly1305", test_data_aes)
+ tests += make_mac_tests(Poly1305_New, "Poly1305", test_data_chacha20)
+ tests += [ Poly1305AES_MC() ]
+ tests += list_test_cases(Poly1305Test_AES)
+ tests += list_test_cases(Poly1305Test_ChaCha20)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_RIPEMD160.py b/lib/Crypto/SelfTest/Hash/test_RIPEMD160.py
new file mode 100644
index 0000000..153c570
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_RIPEMD160.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_RIPEMD160.py: Self-test for the RIPEMD-160 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+#"""Self-test suite for Crypto.Hash.RIPEMD160"""
+
+from Crypto.Util.py3compat import *
+
+# This is a list of (expected_result, input[, description]) tuples.
+test_data = [
+ # Test vectors downloaded 2008-09-12 from
+ # http://homes.esat.kuleuven.be/~bosselae/ripemd160.html
+ ('9c1185a5c5e9fc54612808977ee8f548b2258d31', '', "'' (empty string)"),
+ ('0bdc9d2d256b3ee9daae347be6f4dc835a467ffe', 'a'),
+ ('8eb208f7e05d987a9b044a8e98c6b087f15a0bfc', 'abc'),
+ ('5d0689ef49d2fae572b881b123a85ffa21595f36', 'message digest'),
+
+ ('f71c27109c692c1b56bbdceb5b9d2865b3708dbc',
+ 'abcdefghijklmnopqrstuvwxyz',
+ 'a-z'),
+
+ ('12a053384a9c0c88e405a06c27dcf49ada62eb2b',
+ 'abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq',
+ 'abcdbcd...pnopq'),
+
+ ('b0e20b6e3116640286ed3a87a5713079b21f5189',
+ 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789',
+ 'A-Z, a-z, 0-9'),
+
+ ('9b752e45573d4b39f4dbd3323cab82bf63326bfb',
+ '1234567890' * 8,
+ "'1234567890' * 8"),
+
+ ('52783243c1697bdbe16d37f97f68f08325dc1528',
+ 'a' * 10**6,
+ '"a" * 10**6'),
+]
+
+def get_tests(config={}):
+ from Crypto.Hash import RIPEMD160
+ from .common import make_hash_tests
+ return make_hash_tests(RIPEMD160, "RIPEMD160", test_data,
+ digest_size=20,
+ oid="1.3.36.3.2.1")
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA1.py b/lib/Crypto/SelfTest/Hash/test_SHA1.py
new file mode 100644
index 0000000..a883a44
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA1.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/SHA1.py: Self-test for the SHA-1 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA"""
+
+from binascii import hexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+
+# Test vectors from various sources
+# This is a list of (expected_result, input[, description]) tuples.
+test_data_various = [
+ # FIPS PUB 180-2, A.1 - "One-Block Message"
+ ('a9993e364706816aba3e25717850c26c9cd0d89d', 'abc'),
+
+ # FIPS PUB 180-2, A.2 - "Multi-Block Message"
+ ('84983e441c3bd26ebaae4aa1f95129e5e54670f1',
+ 'abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq'),
+
+ # FIPS PUB 180-2, A.3 - "Long Message"
+# ('34aa973cd4c4daa4f61eeb2bdbad27316534016f',
+# 'a' * 10**6,
+# '"a" * 10**6'),
+
+ # RFC 3174: Section 7.3, "TEST4" (multiple of 512 bits)
+ ('dea356a2cddd90c7a7ecedc5ebb563934f460452',
+ '01234567' * 80,
+ '"01234567" * 80'),
+]
+
+def get_tests(config={}):
+ from Crypto.Hash import SHA1
+ from .common import make_hash_tests
+
+ tests = []
+
+ test_vectors = load_test_vectors(("Hash", "SHA1"),
+ "SHA1ShortMsg.rsp",
+ "KAT SHA-1",
+ { "len" : lambda x: int(x) } ) or []
+
+ test_data = test_data_various[:]
+ for tv in test_vectors:
+ try:
+ if tv.startswith('['):
+ continue
+ except AttributeError:
+ pass
+ if tv.len == 0:
+ tv.msg = b""
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+ tests = make_hash_tests(SHA1, "SHA1", test_data,
+ digest_size=20,
+ oid="1.3.14.3.2.26")
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA224.py b/lib/Crypto/SelfTest/Hash/test_SHA224.py
new file mode 100644
index 0000000..cf81ad9
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA224.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA224.py: Self-test for the SHA-224 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA224"""
+
+# Test vectors from various sources
+# This is a list of (expected_result, input[, description]) tuples.
+test_data = [
+
+    # RFC 3874: Section 3.1, "Test Vector #1"
+ ('23097d223405d8228642a477bda255b32aadbce4bda0b3f7e36c9da7', 'abc'),
+
+    # RFC 3874: Section 3.2, "Test Vector #2"
+ ('75388b16512776cc5dba5da1fd890150b0c6455cb4f58b1952522525', 'abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq'),
+
+    # RFC 3874: Section 3.3, "Test Vector #3"
+ ('20794655980c91d8bbb4c1ea97618a4bf03f42581948b2ee4ee7ad67', 'a' * 10**6, "'a' * 10**6"),
+
+ # Examples from http://de.wikipedia.org/wiki/Secure_Hash_Algorithm
+ ('d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f', ''),
+
+ ('49b08defa65e644cbf8a2dd9270bdededabc741997d1dadd42026d7b',
+ 'Franz jagt im komplett verwahrlosten Taxi quer durch Bayern'),
+
+ ('58911e7fccf2971a7d07f93162d8bd13568e71aa8fc86fc1fe9043d1',
+ 'Frank jagt im komplett verwahrlosten Taxi quer durch Bayern'),
+
+]
+
+def get_tests(config={}):
+ from Crypto.Hash import SHA224
+ from .common import make_hash_tests
+ return make_hash_tests(SHA224, "SHA224", test_data,
+ digest_size=28,
+ oid='2.16.840.1.101.3.4.2.4')
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA256.py b/lib/Crypto/SelfTest/Hash/test_SHA256.py
new file mode 100644
index 0000000..bb99326
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA256.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA256.py: Self-test for the SHA-256 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA256"""
+
+import unittest
+from Crypto.Util.py3compat import *
+
+class LargeSHA256Test(unittest.TestCase):
+ def runTest(self):
+ """SHA256: 512/520 MiB test"""
+ from Crypto.Hash import SHA256
+ zeros = bchr(0x00) * (1024*1024)
+
+ h = SHA256.new(zeros)
+ for i in range(511):
+ h.update(zeros)
+
+ # This test vector is from PyCrypto's old testdata.py file.
+ self.assertEqual('9acca8e8c22201155389f65abbf6bc9723edc7384ead80503839f49dcc56d767', h.hexdigest()) # 512 MiB
+
+ for i in range(8):
+ h.update(zeros)
+
+ # This test vector is from PyCrypto's old testdata.py file.
+ self.assertEqual('abf51ad954b246009dfe5a50ecd582fd5b8f1b8b27f30393853c3ef721e7fa6e', h.hexdigest()) # 520 MiB
+
+def get_tests(config={}):
+ # Test vectors from FIPS PUB 180-2
+ # This is a list of (expected_result, input[, description]) tuples.
+ test_data = [
+ # FIPS PUB 180-2, B.1 - "One-Block Message"
+ ('ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad',
+ 'abc'),
+
+ # FIPS PUB 180-2, B.2 - "Multi-Block Message"
+ ('248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1',
+ 'abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq'),
+
+ # FIPS PUB 180-2, B.3 - "Long Message"
+ ('cdc76e5c9914fb9281a1c7e284d73e67f1809a48a497200e046d39ccc7112cd0',
+ 'a' * 10**6,
+ '"a" * 10**6'),
+
+ # Test for an old PyCrypto bug.
+ ('f7fd017a3c721ce7ff03f3552c0813adcc48b7f33f07e5e2ba71e23ea393d103',
+ 'This message is precisely 55 bytes long, to test a bug.',
+ 'Length = 55 (mod 64)'),
+
+ # Example from http://de.wikipedia.org/wiki/Secure_Hash_Algorithm
+ ('e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855', ''),
+
+ ('d32b568cd1b96d459e7291ebf4b25d007f275c9f13149beeb782fac0716613f8',
+ 'Franz jagt im komplett verwahrlosten Taxi quer durch Bayern'),
+ ]
+
+ from Crypto.Hash import SHA256
+ from .common import make_hash_tests
+ tests = make_hash_tests(SHA256, "SHA256", test_data,
+ digest_size=32,
+ oid="2.16.840.1.101.3.4.2.1")
+
+ if config.get('slow_tests'):
+ tests += [LargeSHA256Test()]
+
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA384.py b/lib/Crypto/SelfTest/Hash/test_SHA384.py
new file mode 100644
index 0000000..c682eb4
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA384.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA384.py: Self-test for the SHA-384 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA384"""
+
+# Test vectors from various sources
+# This is a list of (expected_result, input[, description]) tuples.
+test_data = [
+
+    # RFC 4634: Section 8.4, "Test 1"
+ ('cb00753f45a35e8bb5a03d699ac65007272c32ab0eded1631a8b605a43ff5bed8086072ba1e7cc2358baeca134c825a7', 'abc'),
+
+    # RFC 4634: Section 8.4, "Test 2.2"
+ ('09330c33f71147e83d192fc782cd1b4753111b173b3b05d22fa08086e3b0f712fcc7c71a557e2db966c3e9fa91746039', 'abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu'),
+
+    # RFC 4634: Section 8.4, "Test 3"
+ ('9d0e1809716474cb086e834e310a4a1ced149e9c00f248527972cec5704c2a5b07b8b3dc38ecc4ebae97ddd87f3d8985', 'a' * 10**6, "'a' * 10**6"),
+
+ # Taken from http://de.wikipedia.org/wiki/Secure_Hash_Algorithm
+ ('38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b', ''),
+
+ # Example from http://de.wikipedia.org/wiki/Secure_Hash_Algorithm
+ ('71e8383a4cea32d6fd6877495db2ee353542f46fa44bc23100bca48f3366b84e809f0708e81041f427c6d5219a286677',
+ 'Franz jagt im komplett verwahrlosten Taxi quer durch Bayern'),
+
+]
+
+def get_tests(config={}):
+ from Crypto.Hash import SHA384
+ from .common import make_hash_tests
+ return make_hash_tests(SHA384, "SHA384", test_data,
+ digest_size=48,
+ oid='2.16.840.1.101.3.4.2.2')
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA3_224.py b/lib/Crypto/SelfTest/Hash/test_SHA3_224.py
new file mode 100644
index 0000000..f92147a
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA3_224.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA3_224.py: Self-test for the SHA-3/224 hash function
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA3_224"""
+
+import unittest
+from binascii import hexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Hash import SHA3_224 as SHA3
+from Crypto.Util.py3compat import b
+
+
+class APITest(unittest.TestCase):
+
+ def test_update_after_digest(self):
+ msg=b("rrrrttt")
+
+ # Normally, update() cannot be done after digest()
+ h = SHA3.new(data=msg[:4])
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+ dig2 = SHA3.new(data=msg).digest()
+
+ # With the proper flag, it is allowed
+ h = SHA3.new(data=msg[:4], update_after_digest=True)
+ self.assertEqual(h.digest(), dig1)
+ # ... and the subsequent digest applies to the entire message
+ # up to that point
+ h.update(msg[4:])
+ self.assertEqual(h.digest(), dig2)
+
+
+def get_tests(config={}):
+ from .common import make_hash_tests
+
+ tests = []
+
+ test_vectors = load_test_vectors(("Hash", "SHA3"),
+ "ShortMsgKAT_SHA3-224.txt",
+ "KAT SHA-3 224",
+ { "len" : lambda x: int(x) } ) or []
+
+ test_data = []
+ for tv in test_vectors:
+ if tv.len == 0:
+ tv.msg = b("")
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+ tests += make_hash_tests(SHA3, "SHA3_224", test_data,
+ digest_size=SHA3.digest_size,
+ oid="2.16.840.1.101.3.4.2.7")
+ tests += list_test_cases(APITest)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA3_256.py b/lib/Crypto/SelfTest/Hash/test_SHA3_256.py
new file mode 100644
index 0000000..432c932
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA3_256.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA3_256.py: Self-test for the SHA-3/256 hash function
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA3_256"""
+
+import unittest
+from binascii import hexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Hash import SHA3_256 as SHA3
+from Crypto.Util.py3compat import b
+
+
+class APITest(unittest.TestCase):
+
+ def test_update_after_digest(self):
+ msg=b("rrrrttt")
+
+ # Normally, update() cannot be done after digest()
+ h = SHA3.new(data=msg[:4])
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+ dig2 = SHA3.new(data=msg).digest()
+
+ # With the proper flag, it is allowed
+ h = SHA3.new(data=msg[:4], update_after_digest=True)
+ self.assertEqual(h.digest(), dig1)
+ # ... and the subsequent digest applies to the entire message
+ # up to that point
+ h.update(msg[4:])
+ self.assertEqual(h.digest(), dig2)
+
+
+def get_tests(config={}):
+ from .common import make_hash_tests
+
+ tests = []
+
+ test_vectors = load_test_vectors(("Hash", "SHA3"),
+ "ShortMsgKAT_SHA3-256.txt",
+ "KAT SHA-3 256",
+ { "len" : lambda x: int(x) } ) or []
+
+ test_data = []
+ for tv in test_vectors:
+ if tv.len == 0:
+ tv.msg = b("")
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+
+ tests += make_hash_tests(SHA3, "SHA3_256", test_data,
+ digest_size=SHA3.digest_size,
+ oid="2.16.840.1.101.3.4.2.8")
+ tests += list_test_cases(APITest)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA3_384.py b/lib/Crypto/SelfTest/Hash/test_SHA3_384.py
new file mode 100644
index 0000000..b0ba1bf
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA3_384.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA3_384.py: Self-test for the SHA-3/384 hash function
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA3_384"""
+
+import unittest
+from binascii import hexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Hash import SHA3_384 as SHA3
+from Crypto.Util.py3compat import b
+
+
+class APITest(unittest.TestCase):
+
+ def test_update_after_digest(self):
+ msg=b("rrrrttt")
+
+ # Normally, update() cannot be done after digest()
+ h = SHA3.new(data=msg[:4])
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+ dig2 = SHA3.new(data=msg).digest()
+
+ # With the proper flag, it is allowed
+ h = SHA3.new(data=msg[:4], update_after_digest=True)
+ self.assertEqual(h.digest(), dig1)
+ # ... and the subsequent digest applies to the entire message
+ # up to that point
+ h.update(msg[4:])
+ self.assertEqual(h.digest(), dig2)
+
+
+def get_tests(config={}):
+ from .common import make_hash_tests
+
+ tests = []
+
+ test_vectors = load_test_vectors(("Hash", "SHA3"),
+ "ShortMsgKAT_SHA3-384.txt",
+ "KAT SHA-3 384",
+ { "len" : lambda x: int(x) } ) or []
+
+ test_data = []
+ for tv in test_vectors:
+ if tv.len == 0:
+ tv.msg = b("")
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+ tests += make_hash_tests(SHA3, "SHA3_384", test_data,
+ digest_size=SHA3.digest_size,
+ oid="2.16.840.1.101.3.4.2.9")
+ tests += list_test_cases(APITest)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA3_512.py b/lib/Crypto/SelfTest/Hash/test_SHA3_512.py
new file mode 100644
index 0000000..7d1007a
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA3_512.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA3_512.py: Self-test for the SHA-3/512 hash function
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA3_512"""
+
+import unittest
+from binascii import hexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Hash import SHA3_512 as SHA3
+from Crypto.Util.py3compat import b
+
+
+class APITest(unittest.TestCase):
+
+ def test_update_after_digest(self):
+ msg=b("rrrrttt")
+
+ # Normally, update() cannot be done after digest()
+ h = SHA3.new(data=msg[:4])
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+ dig2 = SHA3.new(data=msg).digest()
+
+ # With the proper flag, it is allowed
+ h = SHA3.new(data=msg[:4], update_after_digest=True)
+ self.assertEqual(h.digest(), dig1)
+ # ... and the subsequent digest applies to the entire message
+ # up to that point
+ h.update(msg[4:])
+ self.assertEqual(h.digest(), dig2)
+
+
+def get_tests(config={}):
+ from .common import make_hash_tests
+
+ tests = []
+
+ test_vectors = load_test_vectors(("Hash", "SHA3"),
+ "ShortMsgKAT_SHA3-512.txt",
+ "KAT SHA-3 512",
+ { "len" : lambda x: int(x) } ) or []
+
+ test_data = []
+ for tv in test_vectors:
+ if tv.len == 0:
+ tv.msg = b("")
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+ tests += make_hash_tests(SHA3, "SHA3_512", test_data,
+ digest_size=SHA3.digest_size,
+ oid="2.16.840.1.101.3.4.2.10")
+ tests += list_test_cases(APITest)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_SHA512.py b/lib/Crypto/SelfTest/Hash/test_SHA512.py
new file mode 100644
index 0000000..20961ac
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHA512.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Hash/test_SHA512.py: Self-test for the SHA-512 hash function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHA512"""
+
+from binascii import hexlify
+
+from Crypto.Hash import SHA512
+from .common import make_hash_tests
+from Crypto.SelfTest.loader import load_test_vectors
+
+# Test vectors from various sources
+# This is a list of (expected_result, input[, description]) tuples.
+test_data_512_other = [
+
+    # RFC 4634: Section 8.4, "Test 1"
+ ('ddaf35a193617abacc417349ae20413112e6fa4e89a97ea20a9eeee64b55d39a2192992a274fc1a836ba3c23a3feebbd454d4423643ce80e2a9ac94fa54ca49f', 'abc'),
+
+    # RFC 4634: Section 8.4, "Test 2.1"
+ ('8e959b75dae313da8cf4f72814fc143f8f7779c6eb9f7fa17299aeadb6889018501d289e4900f7e4331b99dec4b5433ac7d329eeb6dd26545e96e55b874be909', 'abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu'),
+
+    # RFC 4634: Section 8.4, "Test 3"
+ ('e718483d0ce769644e2e42c7bc15b4638e1f98b13b2044285632a803afa973ebde0ff244877ea60a4cb0432ce577c31beb009c5c2c49aa2e4eadb217ad8cc09b', 'a' * 10**6, "'a' * 10**6"),
+
+ # Taken from http://de.wikipedia.org/wiki/Secure_Hash_Algorithm
+ ('cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e', ''),
+
+ ('af9ed2de700433b803240a552b41b5a472a6ef3fe1431a722b2063c75e9f07451f67a28e37d09cde769424c96aea6f8971389db9e1993d6c565c3c71b855723c', 'Franz jagt im komplett verwahrlosten Taxi quer durch Bayern'),
+]
+
+
+def get_tests_SHA512():
+
+ test_vectors = load_test_vectors(("Hash", "SHA2"),
+ "SHA512ShortMsg.rsp",
+ "KAT SHA-512",
+ {"len": lambda x: int(x)}) or []
+
+ test_data = test_data_512_other[:]
+ for tv in test_vectors:
+ try:
+ if tv.startswith('['):
+ continue
+ except AttributeError:
+ pass
+ if tv.len == 0:
+ tv.msg = b""
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+ tests = make_hash_tests(SHA512, "SHA512", test_data,
+ digest_size=64,
+ oid="2.16.840.1.101.3.4.2.3")
+ return tests
+
+
+def get_tests_SHA512_224():
+
+ test_vectors = load_test_vectors(("Hash", "SHA2"),
+ "SHA512_224ShortMsg.rsp",
+ "KAT SHA-512/224",
+ {"len": lambda x: int(x)}) or []
+
+ test_data = []
+ for tv in test_vectors:
+ try:
+ if tv.startswith('['):
+ continue
+ except AttributeError:
+ pass
+ if tv.len == 0:
+ tv.msg = b""
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+ tests = make_hash_tests(SHA512, "SHA512/224", test_data,
+ digest_size=28,
+ oid="2.16.840.1.101.3.4.2.5",
+ extra_params={ "truncate" : "224" })
+ return tests
+
+
+def get_tests_SHA512_256():
+
+ test_vectors = load_test_vectors(("Hash", "SHA2"),
+ "SHA512_256ShortMsg.rsp",
+ "KAT SHA-512/256",
+ {"len": lambda x: int(x)}) or []
+
+ test_data = []
+ for tv in test_vectors:
+ try:
+ if tv.startswith('['):
+ continue
+ except AttributeError:
+ pass
+ if tv.len == 0:
+ tv.msg = b""
+ test_data.append((hexlify(tv.md), tv.msg, tv.desc))
+
+ tests = make_hash_tests(SHA512, "SHA512/256", test_data,
+ digest_size=32,
+ oid="2.16.840.1.101.3.4.2.6",
+ extra_params={ "truncate" : "256" })
+ return tests
+
+
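+# Illustrative sketch (hypothetical helper, never called): the "truncate" extra
+# parameter used above is presumably forwarded to SHA512.new(), which selects
+# the SHA-512/224 and SHA-512/256 variants directly.
+def _example_truncated_sha512():
+    h224 = SHA512.new(data=b"abc", truncate="224")   # 28-byte digest
+    h256 = SHA512.new(data=b"abc", truncate="256")   # 32-byte digest
+    return h224.hexdigest(), h256.hexdigest()
+
+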
+def get_tests(config={}):
+
+ tests = []
+ tests += get_tests_SHA512()
+ tests += get_tests_SHA512_224()
+ tests += get_tests_SHA512_256()
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Hash/test_SHAKE.py b/lib/Crypto/SelfTest/Hash/test_SHAKE.py
new file mode 100644
index 0000000..29bd34e
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_SHAKE.py
@@ -0,0 +1,143 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.SHAKE128 and SHAKE256"""
+
+import unittest
+from binascii import hexlify, unhexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import SHAKE128, SHAKE256
+from Crypto.Util.py3compat import b, bchr, bord, tobytes
+
+class SHAKETest(unittest.TestCase):
+
+ def test_new_positive(self):
+
+ xof1 = self.shake.new()
+ xof2 = self.shake.new(data=b("90"))
+ xof3 = self.shake.new().update(b("90"))
+
+ self.assertNotEqual(xof1.read(10), xof2.read(10))
+ xof3.read(10)
+ self.assertEqual(xof2.read(10), xof3.read(10))
+
+ def test_update(self):
+ pieces = [bchr(10) * 200, bchr(20) * 300]
+ h = self.shake.new()
+ h.update(pieces[0]).update(pieces[1])
+ digest = h.read(10)
+ h = self.shake.new()
+ h.update(pieces[0] + pieces[1])
+ self.assertEqual(h.read(10), digest)
+
+ def test_update_negative(self):
+ h = self.shake.new()
+ self.assertRaises(TypeError, h.update, u"string")
+
+ def test_digest(self):
+ h = self.shake.new()
+ digest = h.read(90)
+
+ # read returns a byte string of the right length
+ self.assertTrue(isinstance(digest, type(b("digest"))))
+ self.assertEqual(len(digest), 90)
+
+ def test_update_after_read(self):
+ mac = self.shake.new()
+ mac.update(b("rrrr"))
+ mac.read(90)
+ self.assertRaises(TypeError, mac.update, b("ttt"))
+
+
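+# Illustrative sketch (hypothetical helper, not collected as a test): SHAKE is
+# an extendable-output function, so successive read() calls continue the same
+# output stream and their concatenation equals one longer read over the same
+# message.
+def _example_shake_xof():
+    one_shot = SHAKE128.new(data=b"message").read(32)
+    xof = SHAKE128.new(data=b"message")
+    streamed = xof.read(16) + xof.read(16)
+    assert one_shot == streamed
+    return one_shot
+
+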
+class SHAKE128Test(SHAKETest):
+ shake = SHAKE128
+
+
+class SHAKE256Test(SHAKETest):
+ shake = SHAKE256
+
+
+class SHAKEVectors(unittest.TestCase):
+ pass
+
+
+test_vectors_128 = load_test_vectors(("Hash", "SHA3"),
+ "ShortMsgKAT_SHAKE128.txt",
+ "Short Messages KAT SHAKE128",
+ { "len" : lambda x: int(x) } ) or []
+
+for idx, tv in enumerate(test_vectors_128):
+ if tv.len == 0:
+ data = b("")
+ else:
+ data = tobytes(tv.msg)
+
+ def new_test(self, data=data, result=tv.md):
+ hobj = SHAKE128.new(data=data)
+ digest = hobj.read(len(result))
+ self.assertEqual(digest, result)
+
+ setattr(SHAKEVectors, "test_128_%d" % idx, new_test)
+
+
+test_vectors_256 = load_test_vectors(("Hash", "SHA3"),
+ "ShortMsgKAT_SHAKE256.txt",
+ "Short Messages KAT SHAKE256",
+ { "len" : lambda x: int(x) } ) or []
+
+for idx, tv in enumerate(test_vectors_256):
+ if tv.len == 0:
+ data = b("")
+ else:
+ data = tobytes(tv.msg)
+
+ def new_test(self, data=data, result=tv.md):
+ hobj = SHAKE256.new(data=data)
+ digest = hobj.read(len(result))
+ self.assertEqual(digest, result)
+
+ setattr(SHAKEVectors, "test_256_%d" % idx, new_test)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(SHAKE128Test)
+ tests += list_test_cases(SHAKE256Test)
+ tests += list_test_cases(SHAKEVectors)
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_TupleHash.py b/lib/Crypto/SelfTest/Hash/test_TupleHash.py
new file mode 100644
index 0000000..803dc72
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_TupleHash.py
@@ -0,0 +1,286 @@
+import unittest
+from binascii import unhexlify, hexlify
+
+from Crypto.Util.py3compat import tobytes
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import TupleHash128, TupleHash256
+
+
+class TupleHashTest(unittest.TestCase):
+
+ def new(self, *args, **kwargs):
+ return self.TupleHash.new(*args, **kwargs)
+
+ def test_new_positive(self):
+
+ h = self.new()
+ for new_func in self.TupleHash.new, h.new:
+
+ for dbits in range(64, 1024 + 1, 8):
+ hobj = new_func(digest_bits=dbits)
+ self.assertEqual(hobj.digest_size * 8, dbits)
+
+ for dbytes in range(8, 128 + 1):
+ hobj = new_func(digest_bytes=dbytes)
+ self.assertEqual(hobj.digest_size, dbytes)
+
+ hobj = h.new()
+ self.assertEqual(hobj.digest_size, self.default_bytes)
+
+ def test_new_negative(self):
+
+ h = self.new()
+ for new_func in self.TupleHash.new, h.new:
+ self.assertRaises(TypeError, new_func,
+ digest_bytes=self.minimum_bytes,
+ digest_bits=self.minimum_bits)
+ self.assertRaises(ValueError, new_func, digest_bytes=0)
+ self.assertRaises(ValueError, new_func,
+ digest_bits=self.minimum_bits + 7)
+ self.assertRaises(ValueError, new_func,
+ digest_bits=self.minimum_bits - 8)
+            self.assertRaises(ValueError, new_func,
+                              digest_bytes=self.minimum_bytes - 1)
+
+ def test_default_digest_size(self):
+ digest = self.new().digest()
+ self.assertEqual(len(digest), self.default_bytes)
+
+ def test_update(self):
+ h = self.new()
+ h.update(b'')
+ h.digest()
+
+ h = self.new()
+ h.update(b'')
+ h.update(b'STRING1')
+ h.update(b'STRING2')
+ mac1 = h.digest()
+
+ h = self.new()
+ h.update(b'STRING1')
+ h.update(b'STRING2')
+ mac2 = h.digest()
+
+ self.assertNotEqual(mac1, mac2)
+
+ def test_update_negative(self):
+ h = self.new()
+ self.assertRaises(TypeError, h.update, u"string")
+ self.assertRaises(TypeError, h.update, None)
+
+ def test_digest(self):
+ h = self.new()
+ digest = h.digest()
+
+ # hexdigest does not change the state
+ self.assertEqual(h.digest(), digest)
+ # digest returns a byte string
+ self.assertTrue(isinstance(digest, type(b"digest")))
+
+ def test_update_after_digest(self):
+ msg = b"rrrrttt"
+
+ # Normally, update() cannot be done after digest()
+ h = self.new()
+ h.update(msg)
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, dig1)
+
+ def test_hex_digest(self):
+ mac = self.new()
+ digest = mac.digest()
+ hexdigest = mac.hexdigest()
+
+ # hexdigest is equivalent to digest
+ self.assertEqual(hexlify(digest), tobytes(hexdigest))
+ # hexdigest does not change the state
+ self.assertEqual(mac.hexdigest(), hexdigest)
+ # hexdigest returns a string
+ self.assertTrue(isinstance(hexdigest, type("digest")))
+
+ def test_bytearray(self):
+
+ data = b"\x00\x01\x02"
+
+ # Data can be a bytearray (during operation)
+ data_ba = bytearray(data)
+
+ h1 = self.new()
+ h2 = self.new()
+ h1.update(data)
+ h2.update(data_ba)
+ data_ba[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+ def test_memoryview(self):
+
+ data = b"\x00\x01\x02"
+
+ def get_mv_ro(data):
+ return memoryview(data)
+
+ def get_mv_rw(data):
+ return memoryview(bytearray(data))
+
+ for get_mv in (get_mv_ro, get_mv_rw):
+
+ # Data can be a memoryview (during operation)
+ data_mv = get_mv(data)
+
+ h1 = self.new()
+ h2 = self.new()
+ h1.update(data)
+ h2.update(data_mv)
+ if not data_mv.readonly:
+ data_mv[:1] = b'\xFF'
+
+ self.assertEqual(h1.digest(), h2.digest())
+
+
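+# Illustrative sketch (hypothetical helper, not collected as a test): TupleHash
+# encodes every update() call as a separate tuple element, so splitting the
+# same bytes differently yields a different digest, which is the property
+# test_update relies on above.
+def _example_tuple_boundaries():
+    h1 = TupleHash128.new()
+    h1.update(b"ab")
+    h1.update(b"c")
+    h2 = TupleHash128.new()
+    h2.update(b"a")
+    h2.update(b"bc")
+    assert h1.digest() != h2.digest()
+
+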
+class TupleHash128Test(TupleHashTest):
+
+ TupleHash = TupleHash128
+
+ minimum_bytes = 8
+ default_bytes = 64
+
+ minimum_bits = 64
+ default_bits = 512
+
+
+class TupleHash256Test(TupleHashTest):
+
+ TupleHash = TupleHash256
+
+ minimum_bytes = 8
+ default_bytes = 64
+
+ minimum_bits = 64
+ default_bits = 512
+
+
+class NISTExampleTestVectors(unittest.TestCase):
+
+ # http://csrc.nist.gov/groups/ST/toolkit/documents/Examples/TupleHash_samples.pdf
+ test_data = [
+ (
+ (
+ "00 01 02",
+ "10 11 12 13 14 15",
+ ),
+ "",
+ "C5 D8 78 6C 1A FB 9B 82 11 1A B3 4B 65 B2 C0 04"
+ "8F A6 4E 6D 48 E2 63 26 4C E1 70 7D 3F FC 8E D1",
+            "TupleHash128 Sample #1 NIST",
+ TupleHash128
+ ),
+ (
+ (
+ "00 01 02",
+ "10 11 12 13 14 15",
+ ),
+ "My Tuple App",
+ "75 CD B2 0F F4 DB 11 54 E8 41 D7 58 E2 41 60 C5"
+ "4B AE 86 EB 8C 13 E7 F5 F4 0E B3 55 88 E9 6D FB",
+            "TupleHash128 Sample #2 NIST",
+ TupleHash128
+ ),
+ (
+ (
+ "00 01 02",
+ "10 11 12 13 14 15",
+ "20 21 22 23 24 25 26 27 28",
+ ),
+ "My Tuple App",
+ "E6 0F 20 2C 89 A2 63 1E DA 8D 4C 58 8C A5 FD 07"
+ "F3 9E 51 51 99 8D EC CF 97 3A DB 38 04 BB 6E 84",
+            "TupleHash128 Sample #3 NIST",
+ TupleHash128
+ ),
+ (
+ (
+ "00 01 02",
+ "10 11 12 13 14 15",
+ ),
+ "",
+ "CF B7 05 8C AC A5 E6 68 F8 1A 12 A2 0A 21 95 CE"
+ "97 A9 25 F1 DB A3 E7 44 9A 56 F8 22 01 EC 60 73"
+ "11 AC 26 96 B1 AB 5E A2 35 2D F1 42 3B DE 7B D4"
+ "BB 78 C9 AE D1 A8 53 C7 86 72 F9 EB 23 BB E1 94",
+            "TupleHash256 Sample #4 NIST",
+ TupleHash256
+ ),
+ (
+ (
+ "00 01 02",
+ "10 11 12 13 14 15",
+ ),
+ "My Tuple App",
+ "14 7C 21 91 D5 ED 7E FD 98 DB D9 6D 7A B5 A1 16"
+ "92 57 6F 5F E2 A5 06 5F 3E 33 DE 6B BA 9F 3A A1"
+ "C4 E9 A0 68 A2 89 C6 1C 95 AA B3 0A EE 1E 41 0B"
+ "0B 60 7D E3 62 0E 24 A4 E3 BF 98 52 A1 D4 36 7E",
+            "TupleHash256 Sample #5 NIST",
+ TupleHash256
+ ),
+ (
+ (
+ "00 01 02",
+ "10 11 12 13 14 15",
+ "20 21 22 23 24 25 26 27 28",
+ ),
+ "My Tuple App",
+ "45 00 0B E6 3F 9B 6B FD 89 F5 47 17 67 0F 69 A9"
+ "BC 76 35 91 A4 F0 5C 50 D6 88 91 A7 44 BC C6 E7"
+ "D6 D5 B5 E8 2C 01 8D A9 99 ED 35 B0 BB 49 C9 67"
+ "8E 52 6A BD 8E 85 C1 3E D2 54 02 1D B9 E7 90 CE",
+            "TupleHash256 Sample #6 NIST",
+ TupleHash256
+ ),
+    ]
+
+ def setUp(self):
+ td = []
+ for tv_in in self.test_data:
+ tv_out = [None] * len(tv_in)
+
+ tv_out[0] = []
+ for string in tv_in[0]:
+ tv_out[0].append(unhexlify(string.replace(" ", "")))
+
+ tv_out[1] = tobytes(tv_in[1]) # Custom
+ tv_out[2] = unhexlify(tv_in[2].replace(" ", ""))
+ tv_out[3] = tv_in[3]
+ tv_out[4] = tv_in[4]
+ td.append(tv_out)
+ self.test_data = td
+
+ def runTest(self):
+
+ for data, custom, digest, text, module in self.test_data:
+ hd = module.new(custom=custom, digest_bytes=len(digest))
+ for string in data:
+ hd.update(string)
+ self.assertEqual(hd.digest(), digest, msg=text)
+
+
+def get_tests(config={}):
+ tests = []
+
+ tests += list_test_cases(TupleHash128Test)
+ tests += list_test_cases(TupleHash256Test)
+ tests.append(NISTExampleTestVectors())
+
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_cSHAKE.py b/lib/Crypto/SelfTest/Hash/test_cSHAKE.py
new file mode 100644
index 0000000..72ad341
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_cSHAKE.py
@@ -0,0 +1,178 @@
+# ===================================================================
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.cSHAKE128 and cSHAKE256"""
+
+import unittest
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import cSHAKE128, cSHAKE256, SHAKE128, SHAKE256
+from Crypto.Util.py3compat import b, bchr, tobytes
+
+
+class cSHAKETest(unittest.TestCase):
+
+ def test_left_encode(self):
+ from Crypto.Hash.cSHAKE128 import _left_encode
+ self.assertEqual(_left_encode(0), b'\x01\x00')
+ self.assertEqual(_left_encode(1), b'\x01\x01')
+ self.assertEqual(_left_encode(256), b'\x02\x01\x00')
+
+ def test_bytepad(self):
+ from Crypto.Hash.cSHAKE128 import _bytepad
+ self.assertEqual(_bytepad(b'', 4), b'\x01\x04\x00\x00')
+ self.assertEqual(_bytepad(b'A', 4), b'\x01\x04A\x00')
+ self.assertEqual(_bytepad(b'AA', 4), b'\x01\x04AA')
+ self.assertEqual(_bytepad(b'AAA', 4), b'\x01\x04AAA\x00\x00\x00')
+ self.assertEqual(_bytepad(b'AAAA', 4), b'\x01\x04AAAA\x00\x00')
+ self.assertEqual(_bytepad(b'AAAAA', 4), b'\x01\x04AAAAA\x00')
+ self.assertEqual(_bytepad(b'AAAAAA', 4), b'\x01\x04AAAAAA')
+ self.assertEqual(_bytepad(b'AAAAAAA', 4), b'\x01\x04AAAAAAA\x00\x00\x00')
+
+ def test_new_positive(self):
+
+ xof1 = self.cshake.new()
+ xof2 = self.cshake.new(data=b("90"))
+ xof3 = self.cshake.new().update(b("90"))
+
+ self.assertNotEqual(xof1.read(10), xof2.read(10))
+ xof3.read(10)
+ self.assertEqual(xof2.read(10), xof3.read(10))
+
+ xof1 = self.cshake.new()
+ ref = xof1.read(10)
+ xof2 = self.cshake.new(custom=b(""))
+ xof3 = self.cshake.new(custom=b("foo"))
+
+ self.assertEqual(ref, xof2.read(10))
+ self.assertNotEqual(ref, xof3.read(10))
+
+ xof1 = self.cshake.new(custom=b("foo"))
+ xof2 = self.cshake.new(custom=b("foo"), data=b("90"))
+ xof3 = self.cshake.new(custom=b("foo")).update(b("90"))
+
+ self.assertNotEqual(xof1.read(10), xof2.read(10))
+ xof3.read(10)
+ self.assertEqual(xof2.read(10), xof3.read(10))
+
+ def test_update(self):
+ pieces = [bchr(10) * 200, bchr(20) * 300]
+ h = self.cshake.new()
+ h.update(pieces[0]).update(pieces[1])
+ digest = h.read(10)
+ h = self.cshake.new()
+ h.update(pieces[0] + pieces[1])
+ self.assertEqual(h.read(10), digest)
+
+ def test_update_negative(self):
+ h = self.cshake.new()
+ self.assertRaises(TypeError, h.update, u"string")
+
+ def test_digest(self):
+ h = self.cshake.new()
+ digest = h.read(90)
+
+ # read returns a byte string of the right length
+ self.assertTrue(isinstance(digest, type(b("digest"))))
+ self.assertEqual(len(digest), 90)
+
+ def test_update_after_read(self):
+ mac = self.cshake.new()
+ mac.update(b("rrrr"))
+ mac.read(90)
+ self.assertRaises(TypeError, mac.update, b("ttt"))
+
+ def test_shake(self):
+ # When no customization string is passed, results must match SHAKE
+ for digest_len in range(64):
+ xof1 = self.cshake.new(b'TEST')
+ xof2 = self.shake.new(b'TEST')
+ self.assertEqual(xof1.read(digest_len), xof2.read(digest_len))
+
+
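+# Illustrative sketch (hypothetical helper, not collected as a test): the
+# customization string provides domain separation, and with an empty custom
+# string cSHAKE reduces to plain SHAKE, which is exactly what test_shake
+# checks above.
+def _example_cshake_custom():
+    tagged = cSHAKE128.new(data=b"payload", custom=b"Email Signature").read(32)
+    plain = cSHAKE128.new(data=b"payload").read(32)
+    assert tagged != plain
+    return tagged
+
+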
+class cSHAKE128Test(cSHAKETest):
+ cshake = cSHAKE128
+ shake = SHAKE128
+
+
+class cSHAKE256Test(cSHAKETest):
+ cshake = cSHAKE256
+ shake = SHAKE256
+
+
+class cSHAKEVectors(unittest.TestCase):
+ pass
+
+
+vector_files = [("ShortMsgSamples_cSHAKE128.txt", "Short Message Samples cSHAKE128", "128_cshake", cSHAKE128),
+ ("ShortMsgSamples_cSHAKE256.txt", "Short Message Samples cSHAKE256", "256_cshake", cSHAKE256),
+ ("CustomMsgSamples_cSHAKE128.txt", "Custom Message Samples cSHAKE128", "custom_128_cshake", cSHAKE128),
+ ("CustomMsgSamples_cSHAKE256.txt", "Custom Message Samples cSHAKE256", "custom_256_cshake", cSHAKE256),
+ ]
+
+for file, descr, tag, test_class in vector_files:
+
+ test_vectors = load_test_vectors(("Hash", "SHA3"), file, descr,
+ {"len": lambda x: int(x),
+ "nlen": lambda x: int(x),
+ "slen": lambda x: int(x)}) or []
+
+ for idx, tv in enumerate(test_vectors):
+ if getattr(tv, "len", 0) == 0:
+ data = b("")
+ else:
+ data = tobytes(tv.msg)
+ assert(tv.len == len(tv.msg)*8)
+ if getattr(tv, "nlen", 0) != 0:
+ raise ValueError("Unsupported cSHAKE test vector")
+ if getattr(tv, "slen", 0) == 0:
+ custom = b("")
+ else:
+ custom = tobytes(tv.s)
+ assert(tv.slen == len(tv.s)*8)
+
+ def new_test(self, data=data, result=tv.md, custom=custom, test_class=test_class):
+ hobj = test_class.new(data=data, custom=custom)
+ digest = hobj.read(len(result))
+ self.assertEqual(digest, result)
+
+ setattr(cSHAKEVectors, "test_%s_%d" % (tag, idx), new_test)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(cSHAKE128Test)
+ tests += list_test_cases(cSHAKE256Test)
+ tests += list_test_cases(cSHAKEVectors)
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Hash/test_keccak.py b/lib/Crypto/SelfTest/Hash/test_keccak.py
new file mode 100644
index 0000000..54cdf27
--- /dev/null
+++ b/lib/Crypto/SelfTest/Hash/test_keccak.py
@@ -0,0 +1,250 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Hash.keccak"""
+
+import unittest
+from binascii import hexlify, unhexlify
+
+from Crypto.SelfTest.loader import load_test_vectors
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Hash import keccak
+from Crypto.Util.py3compat import b, tobytes, bchr
+
+class KeccakTest(unittest.TestCase):
+
+ def test_new_positive(self):
+
+ for digest_bits in (224, 256, 384, 512):
+ hobj = keccak.new(digest_bits=digest_bits)
+ self.assertEqual(hobj.digest_size, digest_bits // 8)
+
+ hobj2 = hobj.new()
+ self.assertEqual(hobj2.digest_size, digest_bits // 8)
+
+ for digest_bytes in (28, 32, 48, 64):
+ hobj = keccak.new(digest_bytes=digest_bytes)
+ self.assertEqual(hobj.digest_size, digest_bytes)
+
+ hobj2 = hobj.new()
+ self.assertEqual(hobj2.digest_size, digest_bytes)
+
+ def test_new_positive2(self):
+
+ digest1 = keccak.new(data=b("\x90"), digest_bytes=64).digest()
+ digest2 = keccak.new(digest_bytes=64).update(b("\x90")).digest()
+ self.assertEqual(digest1, digest2)
+
+ def test_new_negative(self):
+
+ # keccak.new needs digest size
+ self.assertRaises(TypeError, keccak.new)
+
+ h = keccak.new(digest_bits=512)
+
+ # Either bits or bytes can be specified
+ self.assertRaises(TypeError, keccak.new,
+ digest_bytes=64,
+ digest_bits=512)
+
+ # Range
+ self.assertRaises(ValueError, keccak.new, digest_bytes=0)
+ self.assertRaises(ValueError, keccak.new, digest_bytes=1)
+ self.assertRaises(ValueError, keccak.new, digest_bytes=65)
+ self.assertRaises(ValueError, keccak.new, digest_bits=0)
+ self.assertRaises(ValueError, keccak.new, digest_bits=1)
+ self.assertRaises(ValueError, keccak.new, digest_bits=513)
+
+ def test_update(self):
+ pieces = [bchr(10) * 200, bchr(20) * 300]
+ h = keccak.new(digest_bytes=64)
+ h.update(pieces[0]).update(pieces[1])
+ digest = h.digest()
+ h = keccak.new(digest_bytes=64)
+ h.update(pieces[0] + pieces[1])
+ self.assertEqual(h.digest(), digest)
+
+ def test_update_negative(self):
+ h = keccak.new(digest_bytes=64)
+ self.assertRaises(TypeError, h.update, u"string")
+
+ def test_digest(self):
+ h = keccak.new(digest_bytes=64)
+ digest = h.digest()
+
+ # hexdigest does not change the state
+ self.assertEqual(h.digest(), digest)
+ # digest returns a byte string
+ self.assertTrue(isinstance(digest, type(b("digest"))))
+
+ def test_hex_digest(self):
+ mac = keccak.new(digest_bits=512)
+ digest = mac.digest()
+ hexdigest = mac.hexdigest()
+
+ # hexdigest is equivalent to digest
+ self.assertEqual(hexlify(digest), tobytes(hexdigest))
+ # hexdigest does not change the state
+ self.assertEqual(mac.hexdigest(), hexdigest)
+ # hexdigest returns a string
+ self.assertTrue(isinstance(hexdigest, type("digest")))
+
+ def test_update_after_digest(self):
+ msg=b("rrrrttt")
+
+ # Normally, update() cannot be done after digest()
+ h = keccak.new(digest_bits=512, data=msg[:4])
+ dig1 = h.digest()
+ self.assertRaises(TypeError, h.update, msg[4:])
+ dig2 = keccak.new(digest_bits=512, data=msg).digest()
+
+ # With the proper flag, it is allowed
+ h = keccak.new(digest_bits=512, data=msg[:4], update_after_digest=True)
+ self.assertEqual(h.digest(), dig1)
+ # ... and the subsequent digest applies to the entire message
+ # up to that point
+ h.update(msg[4:])
+ self.assertEqual(h.digest(), dig2)
+
+
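+# Illustrative sketch (hypothetical helper, never invoked): unlike the
+# fixed-size SHA-3 modules, keccak.new() has no default output size, so exactly
+# one of digest_bits or digest_bytes must be supplied, as test_new_negative
+# verifies above.
+def _example_keccak_new():
+    h = keccak.new(digest_bits=256, data=b"message")
+    return h.hexdigest()
+
+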
+class KeccakVectors(unittest.TestCase):
+ pass
+
+ # TODO: add ExtremelyLong tests
+
+
+test_vectors_224 = load_test_vectors(("Hash", "keccak"),
+ "ShortMsgKAT_224.txt",
+ "Short Messages KAT 224",
+ {"len": lambda x: int(x)}) or []
+
+test_vectors_224 += load_test_vectors(("Hash", "keccak"),
+ "LongMsgKAT_224.txt",
+ "Long Messages KAT 224",
+ {"len": lambda x: int(x)}) or []
+
+for idx, tv in enumerate(test_vectors_224):
+ if tv.len == 0:
+ data = b("")
+ else:
+ data = tobytes(tv.msg)
+
+ def new_test(self, data=data, result=tv.md):
+ hobj = keccak.new(digest_bits=224, data=data)
+ self.assertEqual(hobj.digest(), result)
+
+ setattr(KeccakVectors, "test_224_%d" % idx, new_test)
+
+# ---
+
+test_vectors_256 = load_test_vectors(("Hash", "keccak"),
+ "ShortMsgKAT_256.txt",
+ "Short Messages KAT 256",
+ { "len" : lambda x: int(x) } ) or []
+
+test_vectors_256 += load_test_vectors(("Hash", "keccak"),
+ "LongMsgKAT_256.txt",
+ "Long Messages KAT 256",
+ { "len" : lambda x: int(x) } ) or []
+
+for idx, tv in enumerate(test_vectors_256):
+ if tv.len == 0:
+ data = b("")
+ else:
+ data = tobytes(tv.msg)
+
+ def new_test(self, data=data, result=tv.md):
+ hobj = keccak.new(digest_bits=256, data=data)
+ self.assertEqual(hobj.digest(), result)
+
+ setattr(KeccakVectors, "test_256_%d" % idx, new_test)
+
+
+# ---
+
+test_vectors_384 = load_test_vectors(("Hash", "keccak"),
+ "ShortMsgKAT_384.txt",
+ "Short Messages KAT 384",
+ {"len": lambda x: int(x)}) or []
+
+test_vectors_384 += load_test_vectors(("Hash", "keccak"),
+ "LongMsgKAT_384.txt",
+ "Long Messages KAT 384",
+ {"len": lambda x: int(x)}) or []
+
+for idx, tv in enumerate(test_vectors_384):
+ if tv.len == 0:
+ data = b("")
+ else:
+ data = tobytes(tv.msg)
+
+ def new_test(self, data=data, result=tv.md):
+ hobj = keccak.new(digest_bits=384, data=data)
+ self.assertEqual(hobj.digest(), result)
+
+ setattr(KeccakVectors, "test_384_%d" % idx, new_test)
+
+# ---
+
+test_vectors_512 = load_test_vectors(("Hash", "keccak"),
+ "ShortMsgKAT_512.txt",
+ "Short Messages KAT 512",
+ {"len": lambda x: int(x)}) or []
+
+test_vectors_512 += load_test_vectors(("Hash", "keccak"),
+ "LongMsgKAT_512.txt",
+ "Long Messages KAT 512",
+ {"len": lambda x: int(x)}) or []
+
+for idx, tv in enumerate(test_vectors_512):
+ if tv.len == 0:
+ data = b("")
+ else:
+ data = tobytes(tv.msg)
+
+ def new_test(self, data=data, result=tv.md):
+ hobj = keccak.new(digest_bits=512, data=data)
+ self.assertEqual(hobj.digest(), result)
+
+ setattr(KeccakVectors, "test_512_%d" % idx, new_test)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(KeccakTest)
+ tests += list_test_cases(KeccakVectors)
+ return tests
+
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/IO/__init__.py b/lib/Crypto/SelfTest/IO/__init__.py
new file mode 100644
index 0000000..c04a2a7
--- /dev/null
+++ b/lib/Crypto/SelfTest/IO/__init__.py
@@ -0,0 +1,47 @@
+#
+# SelfTest/IO/__init__.py: Self-test for input/output module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test for I/O"""
+
+def get_tests(config={}):
+ tests = []
+    from Crypto.SelfTest.IO import test_PKCS8
+    from Crypto.SelfTest.IO import test_PBES
+    tests += test_PKCS8.get_tests(config=config)
+    tests += test_PBES.get_tests(config=config)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+
diff --git a/lib/Crypto/SelfTest/IO/test_PBES.py b/lib/Crypto/SelfTest/IO/test_PBES.py
new file mode 100644
index 0000000..b2a4f94
--- /dev/null
+++ b/lib/Crypto/SelfTest/IO/test_PBES.py
@@ -0,0 +1,93 @@
+#
+# SelfTest/IO/test_PBES.py: Self-test for the _PBES module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-tests for Crypto.IO._PBES module"""
+
+import unittest
+from Crypto.Util.py3compat import *
+
+from Crypto.IO._PBES import PBES2
+
+
+class TestPBES2(unittest.TestCase):
+
+ def setUp(self):
+ self.ref = b("Test data")
+ self.passphrase = b("Passphrase")
+
+ def test1(self):
+ ct = PBES2.encrypt(self.ref, self.passphrase,
+ 'PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC')
+ pt = PBES2.decrypt(ct, self.passphrase)
+ self.assertEqual(self.ref, pt)
+
+ def test2(self):
+ ct = PBES2.encrypt(self.ref, self.passphrase,
+ 'PBKDF2WithHMAC-SHA1AndAES128-CBC')
+ pt = PBES2.decrypt(ct, self.passphrase)
+ self.assertEqual(self.ref, pt)
+
+ def test3(self):
+ ct = PBES2.encrypt(self.ref, self.passphrase,
+ 'PBKDF2WithHMAC-SHA1AndAES192-CBC')
+ pt = PBES2.decrypt(ct, self.passphrase)
+ self.assertEqual(self.ref, pt)
+
+ def test4(self):
+ ct = PBES2.encrypt(self.ref, self.passphrase,
+ 'scryptAndAES128-CBC')
+ pt = PBES2.decrypt(ct, self.passphrase)
+ self.assertEqual(self.ref, pt)
+
+ def test5(self):
+ ct = PBES2.encrypt(self.ref, self.passphrase,
+ 'scryptAndAES192-CBC')
+ pt = PBES2.decrypt(ct, self.passphrase)
+ self.assertEqual(self.ref, pt)
+
+ def test6(self):
+ ct = PBES2.encrypt(self.ref, self.passphrase,
+ 'scryptAndAES256-CBC')
+ pt = PBES2.decrypt(ct, self.passphrase)
+ self.assertEqual(self.ref, pt)
+
+
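+# Illustrative sketch (hypothetical helper, never called): it restates the
+# round trip performed by the tests above, encrypting under a named PBES2
+# protection scheme and then decrypting with only the passphrase, since the
+# scheme parameters travel inside the result.
+def _example_pbes2_roundtrip():
+    ct = PBES2.encrypt(b"Test data", b"Passphrase", 'scryptAndAES128-CBC')
+    pt = PBES2.decrypt(ct, b"Passphrase")
+    assert pt == b"Test data"
+    return ct
+
+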
+def get_tests(config={}):
+ from Crypto.SelfTest.st_common import list_test_cases
+ listTests = []
+ listTests += list_test_cases(TestPBES2)
+ return listTests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/IO/test_PKCS8.py b/lib/Crypto/SelfTest/IO/test_PKCS8.py
new file mode 100644
index 0000000..cf91d69
--- /dev/null
+++ b/lib/Crypto/SelfTest/IO/test_PKCS8.py
@@ -0,0 +1,425 @@
+#
+# SelfTest/IO/test_PKCS8.py: Self-test for the PKCS8 module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-tests for Crypto.IO.PKCS8 module"""
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import *
+from Crypto.IO import PKCS8
+
+from Crypto.Util.asn1 import DerNull
+
+oid_key = '1.2.840.113549.1.1.1'
+
+# Original RSA key (in DER format)
+# hexdump -v -e '32/1 "%02x" "\n"' key.der
+clear_key="""
+308201ab020100025a00b94a7f7075ab9e79e8196f47be707781e80dd965cf16
+0c951a870b71783b6aaabbd550c0e65e5a3dfe15b8620009f6d7e5efec42a3f0
+6fe20faeebb0c356e79cdec6db4dd427e82d8ae4a5b90996227b8ba54ccfc4d2
+5c08050203010001025a00afa09c70d528299b7552fe766b5d20f9a221d66938
+c3b68371d48515359863ff96f0978d700e08cd6fd3d8a3f97066fc2e0d5f78eb
+3a50b8e17ba297b24d1b8e9cdfd18d608668198d724ad15863ef0329195dee89
+3f039395022d0ebe0518df702a8b25954301ec60a97efdcec8eaa4f2e76ca7e8
+8dfbc3f7e0bb83f9a0e8dc47c0f8c746e9df6b022d0c9195de13f09b7be1fdd7
+1f56ae7d973e08bd9fd2c3dfd8936bb05be9cc67bd32d663c7f00d70932a0be3
+c24f022d0ac334eb6cabf1933633db007b763227b0d9971a9ea36aca8b669ec9
+4fcf16352f6b3dcae28e4bd6137db4ddd3022d0400a09f15ee7b351a2481cb03
+09920905c236d09c87afd3022f3afc2a19e3b746672b635238956ee7e6dd62d5
+022d0cd88ed14fcfbda5bbf0257f700147137bbab9c797af7df866704b889aa3
+7e2e93df3ff1a0fd3490111dcdbc4c
+"""
+
+# Same key as above, wrapped in PKCS#8 but w/o password
+#
+# openssl pkcs8 -topk8 -inform DER -nocrypt -in key.der -outform DER -out keyp8.der
+# hexdump -v -e '32/1 "%02x" "\n"' keyp8.der
+wrapped_clear_key="""
+308201c5020100300d06092a864886f70d0101010500048201af308201ab0201
+00025a00b94a7f7075ab9e79e8196f47be707781e80dd965cf160c951a870b71
+783b6aaabbd550c0e65e5a3dfe15b8620009f6d7e5efec42a3f06fe20faeebb0
+c356e79cdec6db4dd427e82d8ae4a5b90996227b8ba54ccfc4d25c0805020301
+0001025a00afa09c70d528299b7552fe766b5d20f9a221d66938c3b68371d485
+15359863ff96f0978d700e08cd6fd3d8a3f97066fc2e0d5f78eb3a50b8e17ba2
+97b24d1b8e9cdfd18d608668198d724ad15863ef0329195dee893f039395022d
+0ebe0518df702a8b25954301ec60a97efdcec8eaa4f2e76ca7e88dfbc3f7e0bb
+83f9a0e8dc47c0f8c746e9df6b022d0c9195de13f09b7be1fdd71f56ae7d973e
+08bd9fd2c3dfd8936bb05be9cc67bd32d663c7f00d70932a0be3c24f022d0ac3
+34eb6cabf1933633db007b763227b0d9971a9ea36aca8b669ec94fcf16352f6b
+3dcae28e4bd6137db4ddd3022d0400a09f15ee7b351a2481cb0309920905c236
+d09c87afd3022f3afc2a19e3b746672b635238956ee7e6dd62d5022d0cd88ed1
+4fcfbda5bbf0257f700147137bbab9c797af7df866704b889aa37e2e93df3ff1
+a0fd3490111dcdbc4c
+"""
+
+###
+#
+# The key above will now be encrypted with different algorithms.
+# The password is always 'TestTest'.
+#
+# Each item in the wrapped_enc_keys list contains:
+# * wrap algorithm
+# * iteration count
+# * Salt
+# * IV
+# * Expected result
+###
+wrapped_enc_keys = []
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der -outform DER -out keyenc.der -v2 des3
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC',
+2048,
+"47EA7227D8B22E2F", # IV
+"E3F7A838AB911A4D", # Salt
+"""
+30820216304006092a864886f70d01050d3033301b06092a864886f70d01050c
+300e0408e3f7a838ab911a4d02020800301406082a864886f70d0307040847ea
+7227d8b22e2f048201d0ea388b374d2d0e4ceb7a5139f850fdff274884a6e6c0
+64326e09d00dbba9018834edb5a51a6ae3d1806e6e91eebf33788ce71fee0637
+a2ebf58859dd32afc644110c390274a6128b50c39b8d907823810ec471bada86
+6f5b75d8ea04ad310fad2e73621696db8e426cd511ee93ec1714a1a7db45e036
+4bf20d178d1f16bbb250b32c2d200093169d588de65f7d99aad9ddd0104b44f1
+326962e1520dfac3c2a800e8a14f678dff2b3d0bb23f69da635bf2a643ac934e
+219a447d2f4460b67149e860e54f365da130763deefa649c72b0dcd48966a2d3
+4a477444782e3e66df5a582b07bbb19778a79bd355074ce331f4a82eb966b0c4
+52a09eab6116f2722064d314ae433b3d6e81d2436e93fdf446112663cde93b87
+9c8be44beb45f18e2c78fee9b016033f01ecda51b9b142091fa69f65ab784d2c
+5ad8d34be6f7f1464adfc1e0ef3f7848f40d3bdea4412758f2fcb655c93d8f4d
+f6fa48fc5aa4b75dd1c017ab79ac9d737233a6d668f5364ccf47786debd37334
+9c10c9e6efbe78430a61f71c89948aa32cdc3cc7338cf994147819ce7ab23450
+c8f7d9b94c3bb377d17a3fa204b601526317824b142ff6bc843fa7815ece89c0
+839573f234dac8d80cc571a045353d61db904a4398d8ef3df5ac
+"""
+))
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der -outform DER -out keyenc.der
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'skip encryption', # pbeWithMD5AndDES-CBC, only decoding is supported
+-1,
+"",
+"",
+"""
+308201f1301b06092a864886f70d010503300e0408f9b990c89af1d41b020208
+00048201d0c6267fe8592903891933d559e71a7ca68b2e39150f19daca0f7921
+52f97e249d72f670d5140e9150433310ed7c7ee51927693fd39884cb9551cea5
+a7b746f7edf199f8787d4787a35dad930d7db057b2118851211b645ac8b90fa6
+b0e7d49ac8567cbd5fff226e87aa9129a0f52c45e9307752e8575c3b0ff756b7
+31fda6942d15ecb6b27ea19370ccc79773f47891e80d22b440d81259c4c28eac
+e0ca839524116bcf52d8c566e49a95ddb0e5493437279a770a39fd333f3fca91
+55884fad0ba5aaf273121f893059d37dd417da7dcfd0d6fa7494968f13b2cc95
+65633f2c891340193e5ec00e4ee0b0e90b3b93da362a4906360845771ade1754
+9df79140be5993f3424c012598eadd3e7c7c0b4db2c72cf103d7943a5cf61420
+93370b9702386c3dd4eb0a47f34b579624a46a108b2d13921fa1b367495fe345
+6aa128aa70f8ca80ae13eb301e96c380724ce67c54380bbea2316c1faf4d058e
+b4ca2e23442047606b9bc4b3bf65b432cb271bea4eb35dd3eb360d3be8612a87
+a50e96a2264490aeabdc07c6e78e5dbf4fe3388726d0e2a228346bf3c2907d68
+2a6276b22ae883fb30fa611f4e4193e7a08480fcd7db48308bacbd72bf4807aa
+11fd394859f97d22982f7fe890b2e2a0f7e7ffb693
+"""
+))
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der
+# -outform DER -out keyenc.der -v1 PBE-SHA1-RC2-64
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'skip encryption', # pbeWithSHA1AndRC2-CBC, only decoding is supported
+-1,
+"",
+"",
+"""
+308201f1301b06092a864886f70d01050b300e04083ee943bdae185008020208
+00048201d0e4614d9371d3ff10ceabc2f6a7a13a0f449f9a714144e46518ea55
+e3e6f0cde24031d01ef1f37ec40081449ef01914faf45983dde0d2bc496712de
+8dd15a5527dff4721d9016c13f34fb93e3ce68577e30146266d71b539f854e56
+753a192cf126ed4812734d86f81884374f1100772f78d0646e9946407637c565
+d070acab413c55952f7237437f2e48cae7fa0ff8d370de2bf446dd08049a3663
+d9c813ac197468c02e2b687e7ca994cf7f03f01b6eca87dbfed94502c2094157
+ea39f73fe4e591df1a68b04d19d9adab90bb9898467c1464ad20bf2b8fb9a5ff
+d3ec91847d1c67fd768a4b9cfb46572eccc83806601372b6fad0243f58f623b7
+1c5809dea0feb8278fe27e5560eed8448dc93f5612f546e5dd7c5f6404365eb2
+5bf3396814367ae8b15c5c432b57eaed1f882c05c7f6517ee9e42b87b7b8d071
+9d6125d1b52f7b2cca1f6bd5f584334bf90bce1a7d938274cafe27b68e629698
+b16e27ae528db28593af9adcfccbebb3b9e1f2af5cd5531b51968389caa6c091
+e7de1f1b96f0d258e54e540d961a7c0ef51fda45d6da5fddd33e9bbfd3a5f8d7
+d7ab2e971de495cddbc86d38444fee9f0ac097b00adaf7802dabe0cff5b43b45
+4f26b7b547016f89be52676866189911c53e2f2477"""
+))
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der
+# -outform DER -out keyenc.der -v1 PBE-MD5-RC2-64
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'skip encryption', # pbeWithMD5AndRC2-CBC, only decoding is supported
+-1,
+"",
+"",
+"""
+308201f1301b06092a864886f70d010506300e0408f5cd2fee56d9b4b8020208
+00048201d086454942d6166a19d6b108465bd111e7080911f573d54b1369c676
+df28600e84936bfec04f91023ff16499e2e07178c340904f12ffa6886ab66228
+32bf43c2bff5a0ed14e765918cf5fc543ad49566246f7eb3fc044fa5a9c25f40
+8fc8c8296b91658d3bb1067c0aba008c4fefd9e2bcdbbbd63fdc8085482bccf4
+f150cec9a084259ad441a017e5d81a1034ef2484696a7a50863836d0eeda45cd
+8cee8ecabfed703f8d9d4bbdf3a767d32a0ccdc38550ee2928d7fe3fa27eda5b
+5c7899e75ad55d076d2c2d3c37d6da3d95236081f9671dab9a99afdb1cbc890e
+332d1a91105d9a8ce08b6027aa07367bd1daec3059cb51f5d896124da16971e4
+0ca4bcadb06c854bdf39f42dd24174011414e51626d198775eff3449a982df7b
+ace874e77e045eb6d7c3faef0750792b29a068a6291f7275df1123fac5789c51
+27ace42836d81633faf9daf38f6787fff0394ea484bbcd465b57d4dbee3cf8df
+b77d1db287b3a6264c466805be5a4fe85cfbca180699859280f2dd8e2c2c10b5
+7a7d2ac670c6039d41952fbb0e4f99b560ebe1d020e1b96d02403283819c00cc
+529c51f0b0101555e4c58002ba3c6e3c12e3fde1aec94382792e96d9666a2b33
+3dc397b22ecab67ee38a552fec29a1d4ff8719c748"""
+))
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der
+# -outform DER -out keyenc.der -v1 PBE-SHA1-DES
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'skip encryption', # pbeWithSHA1AndDES-CBC, only decoding is supported
+-1,
+"",
+"",
+"""
+308201f1301b06092a864886f70d01050a300e04089bacc9cf1e8f734e020208
+00048201d03e502f3ceafe8fd19ab2939576bfdded26d719b2441db1459688f5
+9673218b41ec1f739edf1e460bd927bc28470c87b2d4fc8ea02ba17b47a63c49
+c5c1bee40529dadfd3ef8b4472c730bc136678c78abfb34670ec9d7dcd17ee3f
+892f93f2629e6e0f4b24ecb9f954069bf722f466dece3913bb6abbd2c471d9a5
+c5eea89b14aaccda43d30b0dd0f6eb6e9850d9747aa8aa8414c383ad01c374ee
+26d3552abec9ba22669cc9622ccf2921e3d0c8ecd1a70e861956de0bec6104b5
+b649ac994970c83f8a9e84b14a7dff7843d4ca3dd4af87cea43b5657e15ae0b5
+a940ce5047f006ab3596506600724764f23757205fe374fee04911336d655acc
+03e159ec27789191d1517c4f3f9122f5242d44d25eab8f0658cafb928566ca0e
+8f6589aa0c0ab13ca7a618008ae3eafd4671ee8fe0b562e70b3623b0e2a16eee
+97fd388087d2e03530c9fe7db6e52eccc7c48fd701ede35e08922861a9508d12
+bc8bbf24f0c6bee6e63dbcb489b603d4c4a78ce45bf2eab1d5d10456c42a65a8
+3a606f4e4b9b46eb13b57f2624b651859d3d2d5192b45dbd5a2ead14ff20ca76
+48f321309aa56d8c0c4a192b580821cc6c70c75e6f19d1c5414da898ec4dd39d
+b0eb93d6ba387a80702dfd2db610757ba340f63230
+"""
+))
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der
+# -outform DER -out keyenc.der -v2 aes128
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'PBKDF2WithHMAC-SHA1AndAES128-CBC',
+2048,
+"4F66EE5D3BCD531FE6EBF4B4E73016B8", # IV
+"479F25156176C53A", # Salt
+"""
+3082021f304906092a864886f70d01050d303c301b06092a864886f70d01050c
+300e0408479f25156176c53a02020800301d060960864801650304010204104f
+66ee5d3bcd531fe6ebf4b4e73016b8048201d0e33cfa560423f589d097d21533
+3b880a5ebac5b2ac58b4e73b0d787aee7764f034fe34ca1d1bd845c0a7c3316f
+afbfb2129e03dcaf5a5031394206492828dacef1e04639bee5935e0f46114202
+10bc6c37182f4889be11c5d0486c398f4be952e5740f65de9d8edeb275e2b406
+e19bc29ad5ebb97fa536344fc3d84c7e755696f12b810898de4e6f069b8a81c8
+0aab0d45d7d062303aaa4a10c2ce84fdb5a03114039cfe138e38bb15b2ced717
+93549cdad85e730b14d9e2198b663dfdc8d04a4349eb3de59b076ad40b116d4a
+25ed917c576bc7c883c95ef0f1180e28fc9981bea069594c309f1aa1b253ceab
+a2f0313bb1372bcb51a745056be93d77a1f235a762a45e8856512d436b2ca0f7
+dd60fbed394ba28978d2a2b984b028529d0a58d93aba46c6bbd4ac1e4013cbaa
+63b00988bc5f11ccc40141c346762d2b28f64435d4be98ec17c1884985e3807e
+e550db606600993efccf6de0dfc2d2d70b5336a3b018fa415d6bdd59f5777118
+16806b7bc17c4c7e20ad7176ebfa5a1aa3f6bc10f04b77afd443944642ac9cca
+d740e082b4a3bbb8bafdd34a0b3c5f2f3c2aceccccdccd092b78994b845bfa61
+706c3b9df5165ed1dbcbf1244fe41fc9bf993f52f7658e2f87e1baaeacb0f562
+9d905c
+"""
+))
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der
+# -outform DER -out keyenc.der -v2 aes192
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'PBKDF2WithHMAC-SHA1AndAES192-CBC',
+2048,
+"5CFC2A4FF7B63201A4A8A5B021148186", # IV
+"D718541C264944CE", # Salt
+"""
+3082021f304906092a864886f70d01050d303c301b06092a864886f70d01050c
+300e0408d718541c264944ce02020800301d060960864801650304011604105c
+fc2a4ff7b63201a4a8a5b021148186048201d08e74aaa21b8bcfb15b9790fe95
+b0e09ddb0f189b6fb1682fdb9f122b804650ddec3c67a1df093a828b3e5fbcc6
+286abbcc5354c482fd796d972e919ca8a5eba1eaa2293af1d648013ddad72106
+75622264dfba55dafdda39e338f058f1bdb9846041ffff803797d3fdf3693135
+8a192729ea8346a7e5e58e925a2e2e4af0818581859e8215d87370eb4194a5ff
+bae900857d4c591dbc651a241865a817eaede9987c9f9ae4f95c0bf930eea88c
+4d7596e535ffb7ca369988aba75027a96b9d0bc9c8b0b75f359067fd145a378b
+02aaa15e9db7a23176224da48a83249005460cc6e429168657f2efa8b1af7537
+d7d7042f2d683e8271b21d591090963eeb57aea6172f88da139e1614d6a7d1a2
+1002d5a7a93d6d21156e2b4777f6fc069287a85a1538c46b7722ccde591ab55c
+630e1ceeb1ac42d1b41f3f654e9da86b5efced43775ea68b2594e50e4005e052
+0fe753c0898120c2c07265367ff157f6538a1e4080d6f9d1ca9eb51939c9574e
+f2e4e1e87c1434affd5808563cddd376776dbbf790c6a40028f311a8b58dafa2
+0970ed34acd6e3e89d063987893b2b9570ddb8cc032b05a723bba9444933ebf3
+c624204be72f4190e0245197d0cb772bec933fd8442445f9a28bd042d5a3a1e9
+9a8a07
+"""
+))
+
+#
+# openssl pkcs8 -topk8 -passin pass:TestTest -inform DER -in key.der
+# -outform DER -out keyenc.der -v2 aes192
+# hexdump -v -e '32/1 "%02x" "\n"' keyenc.der
+#
+wrapped_enc_keys.append((
+'PBKDF2WithHMAC-SHA1AndAES256-CBC',
+2048,
+"323351F94462AC563E053A056252C2C4", # IV
+"02A6CD0D12E727B5", # Salt
+"""
+3082021f304906092a864886f70d01050d303c301b06092a864886f70d01050c
+300e040802a6cd0d12e727b502020800301d060960864801650304012a041032
+3351f94462ac563e053a056252c2c4048201d07f4ef1c7be21aae738a20c5632
+b8bdbbb9083b6e7f68822267b1f481fd27fdafd61a90660de6e4058790e4c912
+bf3f319a7c37e6eb3d956daaa143865020d554bf6215e8d7492359aaeef45d6e
+d85a686ed26c0bf7c18d071d827a86f0b73e1db0c0e7f3d42201544093302a90
+551ad530692468c47ac15c69500b8ca67d4a17b64d15cecc035ae50b768a36cf
+07c395afa091e9e6f86f665455fbdc1b21ad79c0908b73da5de75a9b43508d5d
+44dc97a870cd3cd9f01ca24452e9b11c1b4982946702cfcbfda5b2fcc0203fb5
+0b52a115760bd635c94d4c95ac2c640ee9a04ffaf6ccff5a8d953dd5d88ca478
+c377811c521f2191639c643d657a9e364af88bb7c14a356c2b0b4870a23c2f54
+d41f8157afff731471dccc6058b15e1151bcf84b39b5e622a3a1d65859c912a5
+591b85e034a1f6af664f030a6bfc8c3d20c70f32b54bcf4da9c2da83cef49cf8
+e9a74f0e5d358fe50b88acdce6a9db9a7ad61536212fc5f877ebfc7957b8bda4
+b1582a0f10d515a20ee06cf768db9c977aa6fbdca7540d611ff953012d009dac
+e8abd059f8e8ffea637c9c7721f817aaf0bb23403e26a0ef0ff0e2037da67d41
+af728481f53443551a9bff4cea023164e9622b5441a309e1f4bff98e5bf76677
+8d7cd9
+"""
+))
+
+def txt2bin(inputs):
+ s = b('').join([b(x) for x in inputs if not (x in '\n\r\t ')])
+ return unhexlify(s)
+
+class Rng:
+ def __init__(self, output):
+ self.output=output
+ self.idx=0
+ def __call__(self, n):
+ output = self.output[self.idx:self.idx+n]
+ self.idx += n
+ return output
+
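+
+# Illustrative sketch (hypothetical helper, never called): the Rng class above
+# replays fixed bytes so that PKCS8.wrap() consumes a known salt and IV and its
+# output can be compared byte for byte against the captured openssl vectors.
+# A plain, unencrypted round trip looks like this:
+def _example_pkcs8_roundtrip(clear_key_der):
+    wrapped = PKCS8.wrap(clear_key_der, oid_key)
+    algo_oid, private_key, params = PKCS8.unwrap(wrapped)
+    assert algo_oid == oid_key and private_key == clear_key_der
+    return wrapped
+
+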
+class PKCS8_Decrypt(unittest.TestCase):
+
+ def setUp(self):
+ self.oid_key = oid_key
+ self.clear_key = txt2bin(clear_key)
+ self.wrapped_clear_key = txt2bin(wrapped_clear_key)
+ self.wrapped_enc_keys = []
+ for t in wrapped_enc_keys:
+ self.wrapped_enc_keys.append((
+ t[0],
+ t[1],
+ txt2bin(t[2]),
+ txt2bin(t[3]),
+ txt2bin(t[4])
+ ))
+
+    ### NO ENCRYPTION
+
+ def test1(self):
+ """Verify unwrapping w/o encryption"""
+ res1, res2, res3 = PKCS8.unwrap(self.wrapped_clear_key)
+ self.assertEqual(res1, self.oid_key)
+ self.assertEqual(res2, self.clear_key)
+
+ def test2(self):
+ """Verify wrapping w/o encryption"""
+ wrapped = PKCS8.wrap(self.clear_key, self.oid_key)
+ res1, res2, res3 = PKCS8.unwrap(wrapped)
+ self.assertEqual(res1, self.oid_key)
+ self.assertEqual(res2, self.clear_key)
+
+    ### ENCRYPTION
+
+ def test3(self):
+ """Verify unwrapping with encryption"""
+
+ for t in self.wrapped_enc_keys:
+ res1, res2, res3 = PKCS8.unwrap(t[4], b("TestTest"))
+ self.assertEqual(res1, self.oid_key)
+ self.assertEqual(res2, self.clear_key)
+
+ def test4(self):
+ """Verify wrapping with encryption"""
+
+ for t in self.wrapped_enc_keys:
+ if t[0] == 'skip encryption':
+ continue
+ rng = Rng(t[2]+t[3])
+ params = { 'iteration_count':t[1] }
+ wrapped = PKCS8.wrap(
+ self.clear_key,
+ self.oid_key,
+ b("TestTest"),
+ protection=t[0],
+ prot_params=params,
+ key_params=DerNull(),
+ randfunc=rng)
+ self.assertEqual(wrapped, t[4])
+
+def get_tests(config={}):
+ from Crypto.SelfTest.st_common import list_test_cases
+ listTests = []
+ listTests += list_test_cases(PKCS8_Decrypt)
+ return listTests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
diff --git a/lib/Crypto/SelfTest/Math/__init__.py b/lib/Crypto/SelfTest/Math/__init__.py
new file mode 100644
index 0000000..18e83d1
--- /dev/null
+++ b/lib/Crypto/SelfTest/Math/__init__.py
@@ -0,0 +1,49 @@
+#
+# SelfTest/Math/__init__.py: Self-test for math module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test for Math"""
+
+def get_tests(config={}):
+ tests = []
+ from Crypto.SelfTest.Math import test_Numbers
+ from Crypto.SelfTest.Math import test_Primality
+ from Crypto.SelfTest.Math import test_modexp
+ tests += test_Numbers.get_tests(config=config)
+ tests += test_Primality.get_tests(config=config)
+ tests += test_modexp.get_tests(config=config)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Math/test_Numbers.py b/lib/Crypto/SelfTest/Math/test_Numbers.py
new file mode 100644
index 0000000..924eca4
--- /dev/null
+++ b/lib/Crypto/SelfTest/Math/test_Numbers.py
@@ -0,0 +1,797 @@
+#
+# SelfTest/Math/test_Numbers.py: Self-test for Numbers module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test for Math.Numbers"""
+
+import sys
+import unittest
+
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Util.py3compat import *
+
+from Crypto.Math._IntegerNative import IntegerNative
+
+
+class TestIntegerBase(unittest.TestCase):
+
+ def setUp(self):
+ raise NotImplementedError("To be implemented")
+
+ def Integers(self, *arg):
+ return map(self.Integer, arg)
+
+ def test_init_and_equality(self):
+ Integer = self.Integer
+
+ v1 = Integer(23)
+ v2 = Integer(v1)
+ v3 = Integer(-9)
+ self.assertRaises(ValueError, Integer, 1.0)
+
+ v4 = Integer(10**10)
+ v5 = Integer(-10**10)
+
+ v6 = Integer(0xFFFF)
+ v7 = Integer(0xFFFFFFFF)
+ v8 = Integer(0xFFFFFFFFFFFFFFFF)
+
+ self.assertEqual(v1, v1)
+ self.assertEqual(v1, 23)
+ self.assertEqual(v1, v2)
+ self.assertEqual(v3, -9)
+ self.assertEqual(v4, 10 ** 10)
+ self.assertEqual(v5, -10 ** 10)
+ self.assertEqual(v6, 0xFFFF)
+ self.assertEqual(v7, 0xFFFFFFFF)
+ self.assertEqual(v8, 0xFFFFFFFFFFFFFFFF)
+
+ self.assertFalse(v1 == v4)
+
+ # Init and comparison between Integers
+ v6 = Integer(v1)
+ self.assertEqual(v1, v6)
+
+ self.assertFalse(Integer(0) == None)
+
+ def test_conversion_to_int(self):
+ v1, v2 = self.Integers(-23, 2 ** 1000)
+ self.assertEqual(int(v1), -23)
+ self.assertEqual(int(v2), 2 ** 1000)
+
+ def test_equality_with_ints(self):
+ v1, v2, v3 = self.Integers(23, -89, 2 ** 1000)
+ self.assertTrue(v1 == 23)
+ self.assertTrue(v2 == -89)
+ self.assertFalse(v1 == 24)
+ self.assertTrue(v3 == 2 ** 1000)
+
+ def test_conversion_to_str(self):
+ v1, v2, v3, v4 = self.Integers(20, 0, -20, 2 ** 1000)
+ self.assertTrue(str(v1) == "20")
+ self.assertTrue(str(v2) == "0")
+ self.assertTrue(str(v3) == "-20")
+ self.assertTrue(str(v4) == "10715086071862673209484250490600018105614048117055336074437503883703510511249361224931983788156958581275946729175531468251871452856923140435984577574698574803934567774824230985421074605062371141877954182153046474983581941267398767559165543946077062914571196477686542167660429831652624386837205668069376")
+
+ def test_repr(self):
+ v1, v2 = self.Integers(-1, 2**80)
+ self.assertEqual(repr(v1), "Integer(-1)")
+ self.assertEqual(repr(v2), "Integer(1208925819614629174706176)")
+
+ def test_conversion_to_bytes(self):
+ Integer = self.Integer
+
+ v1 = Integer(0x17)
+ self.assertEqual(b("\x17"), v1.to_bytes())
+
+ v2 = Integer(0xFFFE)
+ self.assertEqual(b("\xFF\xFE"), v2.to_bytes())
+ self.assertEqual(b("\x00\xFF\xFE"), v2.to_bytes(3))
+ self.assertRaises(ValueError, v2.to_bytes, 1)
+
+ self.assertEqual(b("\xFE\xFF"), v2.to_bytes(byteorder='little'))
+ self.assertEqual(b("\xFE\xFF\x00"), v2.to_bytes(3, byteorder='little'))
+
+ v3 = Integer(-90)
+ self.assertRaises(ValueError, v3.to_bytes)
+ self.assertRaises(ValueError, v3.to_bytes, byteorder='bittle')
+
+ def test_conversion_from_bytes(self):
+ Integer = self.Integer
+
+ v1 = Integer.from_bytes(b"\x00")
+ self.assertTrue(isinstance(v1, Integer))
+ self.assertEqual(0, v1)
+
+ v2 = Integer.from_bytes(b"\x00\x01")
+ self.assertEqual(1, v2)
+
+ v3 = Integer.from_bytes(b"\xFF\xFF")
+ self.assertEqual(0xFFFF, v3)
+
+ v4 = Integer.from_bytes(b"\x00\x01", 'big')
+ self.assertEqual(1, v4)
+
+ v5 = Integer.from_bytes(b"\x00\x01", byteorder='big')
+ self.assertEqual(1, v5)
+
+ v6 = Integer.from_bytes(b"\x00\x01", byteorder='little')
+ self.assertEqual(0x0100, v6)
+
+ self.assertRaises(ValueError, Integer.from_bytes, b'\x09', 'bittle')
+
+ def test_inequality(self):
+ # Test Integer!=Integer and Integer!=int
+ v1, v2, v3, v4 = self.Integers(89, 89, 90, -8)
+ self.assertTrue(v1 != v3)
+ self.assertTrue(v1 != 90)
+ self.assertFalse(v1 != v2)
+ self.assertFalse(v1 != 89)
+ self.assertTrue(v1 != v4)
+ self.assertTrue(v4 != v1)
+ self.assertTrue(self.Integer(0) != None)
+
+ def test_less_than(self):
+ # Test Integer<Integer and Integer<int
+ v1, v2, v3, v4, v5 = self.Integers(13, 13, 14, -8, 2 ** 10)
+ self.assertTrue(v1 < v3)
+ self.assertTrue(v1 < 14)
+ self.assertFalse(v1 < v2)
+ self.assertFalse(v1 < 13)
+ self.assertTrue(v4 < v1)
+ self.assertFalse(v1 < v4)
+ self.assertTrue(v1 < v5)
+ self.assertFalse(v5 < v1)
+
+ def test_less_than_or_equal(self):
+ # Test Integer<=Integer and Integer<=int
+ v1, v2, v3, v4, v5 = self.Integers(13, 13, 14, -4, 2 ** 10)
+ self.assertTrue(v1 <= v1)
+ self.assertTrue(v1 <= 13)
+ self.assertTrue(v1 <= v2)
+ self.assertTrue(v1 <= 14)
+ self.assertTrue(v1 <= v3)
+ self.assertFalse(v1 <= v4)
+ self.assertTrue(v1 <= v5)
+ self.assertFalse(v5 <= v1)
+
+ def test_more_than(self):
+ # Test Integer>Integer and Integer>int
+ v1, v2, v3, v4, v5 = self.Integers(13, 13, 14, -8, 2 ** 10)
+ self.assertTrue(v3 > v1)
+ self.assertTrue(v3 > 13)
+ self.assertFalse(v1 > v1)
+ self.assertFalse(v1 > v2)
+ self.assertFalse(v1 > 13)
+ self.assertTrue(v1 > v4)
+ self.assertFalse(v4 > v1)
+ self.assertTrue(v5 > v1)
+ self.assertFalse(v1 > v5)
+
+ def test_more_than_or_equal(self):
+ # Test Integer>=Integer and Integer>=int
+ v1, v2, v3, v4 = self.Integers(13, 13, 14, -4)
+ self.assertTrue(v3 >= v1)
+ self.assertTrue(v3 >= 13)
+ self.assertTrue(v1 >= v2)
+ self.assertTrue(v1 >= v1)
+ self.assertTrue(v1 >= 13)
+ self.assertFalse(v4 >= v1)
+
+ def test_bool(self):
+ v1, v2, v3, v4 = self.Integers(0, 10, -9, 2 ** 10)
+ self.assertFalse(v1)
+ self.assertFalse(bool(v1))
+ self.assertTrue(v2)
+ self.assertTrue(bool(v2))
+ self.assertTrue(v3)
+ self.assertTrue(v4)
+
+ def test_is_negative(self):
+ v1, v2, v3, v4, v5 = self.Integers(-3 ** 100, -3, 0, 3, 3**100)
+ self.assertTrue(v1.is_negative())
+ self.assertTrue(v2.is_negative())
+ self.assertFalse(v4.is_negative())
+ self.assertFalse(v5.is_negative())
+
+ def test_addition(self):
+ # Test Integer+Integer and Integer+int
+ v1, v2, v3 = self.Integers(7, 90, -7)
+ self.assertTrue(isinstance(v1 + v2, self.Integer))
+ self.assertEqual(v1 + v2, 97)
+ self.assertEqual(v1 + 90, 97)
+ self.assertEqual(v1 + v3, 0)
+ self.assertEqual(v1 + (-7), 0)
+ self.assertEqual(v1 + 2 ** 10, 2 ** 10 + 7)
+
+ def test_subtraction(self):
+ # Test Integer-Integer and Integer-int
+ v1, v2, v3 = self.Integers(7, 90, -7)
+ self.assertTrue(isinstance(v1 - v2, self.Integer))
+ self.assertEqual(v2 - v1, 83)
+ self.assertEqual(v2 - 7, 83)
+ self.assertEqual(v2 - v3, 97)
+ self.assertEqual(v1 - (-7), 14)
+ self.assertEqual(v1 - 2 ** 10, 7 - 2 ** 10)
+
+ def test_multiplication(self):
+ # Test Integer*Integer and Integer*int
+ v1, v2, v3, v4 = self.Integers(4, 5, -2, 2 ** 10)
+ self.assertTrue(isinstance(v1 * v2, self.Integer))
+ self.assertEqual(v1 * v2, 20)
+ self.assertEqual(v1 * 5, 20)
+ self.assertEqual(v1 * -2, -8)
+ self.assertEqual(v1 * 2 ** 10, 4 * (2 ** 10))
+
+ def test_floor_div(self):
+ v1, v2, v3 = self.Integers(3, 8, 2 ** 80)
+ self.assertTrue(isinstance(v1 // v2, self.Integer))
+ self.assertEqual(v2 // v1, 2)
+ self.assertEqual(v2 // 3, 2)
+ self.assertEqual(v2 // -3, -3)
+ self.assertEqual(v3 // 2 ** 79, 2)
+ self.assertRaises(ZeroDivisionError, lambda: v1 // 0)
+
+ def test_remainder(self):
+ # Test Integer%Integer and Integer%int
+ v1, v2, v3 = self.Integers(23, 5, -4)
+ self.assertTrue(isinstance(v1 % v2, self.Integer))
+ self.assertEqual(v1 % v2, 3)
+ self.assertEqual(v1 % 5, 3)
+ self.assertEqual(v3 % 5, 1)
+ self.assertEqual(v1 % 2 ** 10, 23)
+ self.assertRaises(ZeroDivisionError, lambda: v1 % 0)
+ self.assertRaises(ValueError, lambda: v1 % -6)
+
+ def test_simple_exponentiation(self):
+ v1, v2, v3 = self.Integers(4, 3, -2)
+ self.assertTrue(isinstance(v1 ** v2, self.Integer))
+ self.assertEqual(v1 ** v2, 64)
+ self.assertEqual(pow(v1, v2), 64)
+ self.assertEqual(v1 ** 3, 64)
+ self.assertEqual(pow(v1, 3), 64)
+ self.assertEqual(v3 ** 2, 4)
+ self.assertEqual(v3 ** 3, -8)
+
+ self.assertRaises(ValueError, pow, v1, -3)
+
+ def test_modular_exponentiation(self):
+ v1, v2, v3 = self.Integers(23, 5, 17)
+
+ self.assertTrue(isinstance(pow(v1, v2, v3), self.Integer))
+ self.assertEqual(pow(v1, v2, v3), 7)
+ self.assertEqual(pow(v1, 5, v3), 7)
+ self.assertEqual(pow(v1, v2, 17), 7)
+ self.assertEqual(pow(v1, 5, 17), 7)
+ self.assertEqual(pow(v1, 0, 17), 1)
+ self.assertEqual(pow(v1, 1, 2 ** 80), 23)
+ self.assertEqual(pow(v1, 2 ** 80, 89298), 17689)
+
+ self.assertRaises(ZeroDivisionError, pow, v1, 5, 0)
+ self.assertRaises(ValueError, pow, v1, 5, -4)
+ self.assertRaises(ValueError, pow, v1, -3, 8)
+
+ def test_inplace_exponentiation(self):
+ v1 = self.Integer(4)
+ v1.inplace_pow(2)
+ self.assertEqual(v1, 16)
+
+ v1 = self.Integer(4)
+ v1.inplace_pow(2, 15)
+ self.assertEqual(v1, 1)
+
+ def test_abs(self):
+ v1, v2, v3, v4, v5 = self.Integers(-2 ** 100, -2, 0, 2, 2 ** 100)
+ self.assertEqual(abs(v1), 2 ** 100)
+ self.assertEqual(abs(v2), 2)
+ self.assertEqual(abs(v3), 0)
+ self.assertEqual(abs(v4), 2)
+ self.assertEqual(abs(v5), 2 ** 100)
+
+ def test_sqrt(self):
+ v1, v2, v3, v4 = self.Integers(-2, 0, 49, 10**100)
+
+ self.assertRaises(ValueError, v1.sqrt)
+ self.assertEqual(v2.sqrt(), 0)
+ self.assertEqual(v3.sqrt(), 7)
+ self.assertEqual(v4.sqrt(), 10**50)
+
+ def test_sqrt_module(self):
+
+ # Invalid modulus (non-positive)
+ self.assertRaises(ValueError, self.Integer(5).sqrt, 0)
+ self.assertRaises(ValueError, self.Integer(5).sqrt, -1)
+
+ # Simple cases
+ assert self.Integer(0).sqrt(5) == 0
+ assert self.Integer(1).sqrt(5) in (1, 4)
+
+ # Test with all quadratic residues in several fields
+ for p in (11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53):
+ for i in range(0, p):
+ square = i**2 % p
+ res = self.Integer(square).sqrt(p)
+ assert res in (i, p - i)
+
+ # 2 is a quadratic non-residue in Z_11
+ self.assertRaises(ValueError, self.Integer(2).sqrt, 11)
+
+ # 10 is not a prime
+ self.assertRaises(ValueError, self.Integer(4).sqrt, 10)
+
+ # 5 is a quadratic residue mod 11 (square roots 4 and 7); inputs outside [0, 11) are reduced first
+ assert self.Integer(5 - 11).sqrt(11) in (4, 7)
+ assert self.Integer(5 + 11).sqrt(11) in (4, 7)
+
+ def test_in_place_add(self):
+ v1, v2 = self.Integers(10, 20)
+
+ v1 += v2
+ self.assertEqual(v1, 30)
+ v1 += 10
+ self.assertEqual(v1, 40)
+ v1 += -1
+ self.assertEqual(v1, 39)
+ v1 += 2 ** 1000
+ self.assertEqual(v1, 39 + 2 ** 1000)
+
+ def test_in_place_sub(self):
+ v1, v2 = self.Integers(10, 20)
+
+ v1 -= v2
+ self.assertEqual(v1, -10)
+ v1 -= -100
+ self.assertEqual(v1, 90)
+ v1 -= 90000
+ self.assertEqual(v1, -89910)
+ v1 -= -100000
+ self.assertEqual(v1, 10090)
+
+ def test_in_place_mul(self):
+ v1, v2 = self.Integers(3, 5)
+
+ v1 *= v2
+ self.assertEqual(v1, 15)
+ v1 *= 2
+ self.assertEqual(v1, 30)
+ v1 *= -2
+ self.assertEqual(v1, -60)
+ v1 *= 2 ** 1000
+ self.assertEqual(v1, -60 * (2 ** 1000))
+
+ def test_in_place_modulus(self):
+ v1, v2 = self.Integers(20, 7)
+
+ v1 %= v2
+ self.assertEqual(v1, 6)
+ v1 %= 2 ** 1000
+ self.assertEqual(v1, 6)
+ v1 %= 2
+ self.assertEqual(v1, 0)
+ def t():
+ v3 = self.Integer(9)
+ v3 %= 0
+ self.assertRaises(ZeroDivisionError, t)
+
+ def test_and(self):
+ v1, v2, v3 = self.Integers(0xF4, 0x31, -0xF)
+ self.assertTrue(isinstance(v1 & v2, self.Integer))
+ self.assertEqual(v1 & v2, 0x30)
+ self.assertEqual(v1 & 0x31, 0x30)
+ self.assertEqual(v1 & v3, 0xF0)
+ self.assertEqual(v1 & -0xF, 0xF0)
+ self.assertEqual(v3 & -0xF, -0xF)
+ self.assertEqual(v2 & (2 ** 1000 + 0x31), 0x31)
+
+ def test_or(self):
+ v1, v2, v3 = self.Integers(0x40, 0x82, -0xF)
+ self.assertTrue(isinstance(v1 | v2, self.Integer))
+ self.assertEqual(v1 | v2, 0xC2)
+ self.assertEqual(v1 | 0x82, 0xC2)
+ self.assertEqual(v2 | v3, -0xD)
+ self.assertEqual(v2 | 2 ** 1000, 2 ** 1000 + 0x82)
+
+ def test_right_shift(self):
+ v1, v2, v3 = self.Integers(0x10, 1, -0x10)
+ self.assertEqual(v1 >> 0, v1)
+ self.assertTrue(isinstance(v1 >> v2, self.Integer))
+ self.assertEqual(v1 >> v2, 0x08)
+ self.assertEqual(v1 >> 1, 0x08)
+ self.assertRaises(ValueError, lambda: v1 >> -1)
+ self.assertEqual(v1 >> (2 ** 1000), 0)
+
+ self.assertEqual(v3 >> 1, -0x08)
+ self.assertEqual(v3 >> (2 ** 1000), -1)
+
+ def test_in_place_right_shift(self):
+ v1, v2, v3 = self.Integers(0x10, 1, -0x10)
+ v1 >>= 0
+ self.assertEqual(v1, 0x10)
+ v1 >>= 1
+ self.assertEqual(v1, 0x08)
+ v1 >>= v2
+ self.assertEqual(v1, 0x04)
+ v3 >>= 1
+ self.assertEqual(v3, -0x08)
+ def l():
+ v4 = self.Integer(0x90)
+ v4 >>= -1
+ self.assertRaises(ValueError, l)
+ def m1():
+ v4 = self.Integer(0x90)
+ v4 >>= 2 ** 1000
+ return v4
+ self.assertEqual(0, m1())
+ def m2():
+ v4 = self.Integer(-1)
+ v4 >>= 2 ** 1000
+ return v4
+ self.assertEqual(-1, m2())
+
+ def _test_left_shift(self):
+ v1, v2, v3 = self.Integers(0x10, 1, -0x10)
+ self.assertEqual(v1 << 0, v1)
+ self.assertTrue(isinstance(v1 << v2, self.Integer))
+ self.assertEqual(v1 << v2, 0x20)
+ self.assertEqual(v1 << 1, 0x20)
+ self.assertEqual(v3 << 1, -0x20)
+ self.assertRaises(ValueError, lambda: v1 << -1)
+ self.assertRaises(ValueError, lambda: v1 << (2 ** 1000))
+
+ def test_in_place_left_shift(self):
+ v1, v2, v3 = self.Integers(0x10, 1, -0x10)
+ v1 <<= 0
+ self.assertEqual(v1, 0x10)
+ v1 <<= 1
+ self.assertEqual(v1, 0x20)
+ v1 <<= v2
+ self.assertEqual(v1, 0x40)
+ v3 <<= 1
+ self.assertEqual(v3, -0x20)
+ def l():
+ v4 = self.Integer(0x90)
+ v4 <<= -1
+ self.assertRaises(ValueError, l)
+ def m():
+ v4 = self.Integer(0x90)
+ v4 <<= 2 ** 1000
+ self.assertRaises(ValueError, m)
+
+
+ def test_get_bit(self):
+ v1, v2, v3 = self.Integers(0x102, -3, 1)
+ self.assertEqual(v1.get_bit(0), 0)
+ self.assertEqual(v1.get_bit(1), 1)
+ self.assertEqual(v1.get_bit(v3), 1)
+ self.assertEqual(v1.get_bit(8), 1)
+ self.assertEqual(v1.get_bit(9), 0)
+
+ self.assertRaises(ValueError, v1.get_bit, -1)
+ self.assertEqual(v1.get_bit(2 ** 1000), 0)
+
+ self.assertRaises(ValueError, v2.get_bit, -1)
+ self.assertRaises(ValueError, v2.get_bit, 0)
+ self.assertRaises(ValueError, v2.get_bit, 1)
+ self.assertRaises(ValueError, v2.get_bit, 2 ** 1000)
+
+ def test_odd_even(self):
+ v1, v2, v3, v4, v5 = self.Integers(0, 4, 17, -4, -17)
+
+ self.assertTrue(v1.is_even())
+ self.assertTrue(v2.is_even())
+ self.assertFalse(v3.is_even())
+ self.assertTrue(v4.is_even())
+ self.assertFalse(v5.is_even())
+
+ self.assertFalse(v1.is_odd())
+ self.assertFalse(v2.is_odd())
+ self.assertTrue(v3.is_odd())
+ self.assertFalse(v4.is_odd())
+ self.assertTrue(v5.is_odd())
+
+ def test_size_in_bits(self):
+ v1, v2, v3, v4 = self.Integers(0, 1, 0x100, -90)
+ self.assertEqual(v1.size_in_bits(), 1)
+ self.assertEqual(v2.size_in_bits(), 1)
+ self.assertEqual(v3.size_in_bits(), 9)
+ self.assertRaises(ValueError, v4.size_in_bits)
+
+ def test_size_in_bytes(self):
+ v1, v2, v3, v4, v5, v6 = self.Integers(0, 1, 0xFF, 0x1FF, 0x10000, -9)
+ self.assertEqual(v1.size_in_bytes(), 1)
+ self.assertEqual(v2.size_in_bytes(), 1)
+ self.assertEqual(v3.size_in_bytes(), 1)
+ self.assertEqual(v4.size_in_bytes(), 2)
+ self.assertEqual(v5.size_in_bytes(), 3)
+ self.assertRaises(ValueError, v6.size_in_bytes)
+
+ def test_perfect_square(self):
+
+ self.assertFalse(self.Integer(-9).is_perfect_square())
+ self.assertTrue(self.Integer(0).is_perfect_square())
+ self.assertTrue(self.Integer(1).is_perfect_square())
+ self.assertFalse(self.Integer(2).is_perfect_square())
+ self.assertFalse(self.Integer(3).is_perfect_square())
+ self.assertTrue(self.Integer(4).is_perfect_square())
+ self.assertTrue(self.Integer(39*39).is_perfect_square())
+ self.assertFalse(self.Integer(39*39+1).is_perfect_square())
+
+ for x in range(100, 1000):
+ self.assertFalse(self.Integer(x**2+1).is_perfect_square())
+ self.assertTrue(self.Integer(x**2).is_perfect_square())
+
+ def test_fail_if_divisible_by(self):
+ v1, v2, v3 = self.Integers(12, -12, 4)
+
+ # No failure expected
+ v1.fail_if_divisible_by(7)
+ v2.fail_if_divisible_by(7)
+ v2.fail_if_divisible_by(2 ** 80)
+
+ # Failure expected
+ self.assertRaises(ValueError, v1.fail_if_divisible_by, 4)
+ self.assertRaises(ValueError, v1.fail_if_divisible_by, v3)
+
+ def test_multiply_accumulate(self):
+ v1, v2, v3 = self.Integers(4, 3, 2)
+ v1.multiply_accumulate(v2, v3)
+ self.assertEqual(v1, 10)
+ v1.multiply_accumulate(v2, 2)
+ self.assertEqual(v1, 16)
+ v1.multiply_accumulate(3, v3)
+ self.assertEqual(v1, 22)
+ v1.multiply_accumulate(1, -2)
+ self.assertEqual(v1, 20)
+ v1.multiply_accumulate(-2, 1)
+ self.assertEqual(v1, 18)
+ v1.multiply_accumulate(1, 2 ** 1000)
+ self.assertEqual(v1, 18 + 2 ** 1000)
+ v1.multiply_accumulate(2 ** 1000, 1)
+ self.assertEqual(v1, 18 + 2 ** 1001)
+
+ def test_set(self):
+ v1, v2 = self.Integers(3, 6)
+ v1.set(v2)
+ self.assertEqual(v1, 6)
+ v1.set(9)
+ self.assertEqual(v1, 9)
+ v1.set(-2)
+ self.assertEqual(v1, -2)
+ v1.set(2 ** 1000)
+ self.assertEqual(v1, 2 ** 1000)
+
+ def test_inverse(self):
+ v1, v2, v3, v4, v5, v6 = self.Integers(2, 5, -3, 0, 723872, 3433)
+
+ self.assertTrue(isinstance(v1.inverse(v2), self.Integer))
+ self.assertEqual(v1.inverse(v2), 3)
+ self.assertEqual(v1.inverse(5), 3)
+ self.assertEqual(v3.inverse(5), 3)
+ self.assertEqual(v5.inverse(92929921), 58610507)
+ self.assertEqual(v6.inverse(9912), 5353)
+
+ self.assertRaises(ValueError, v2.inverse, 10)
+ self.assertRaises(ValueError, v1.inverse, -3)
+ self.assertRaises(ValueError, v4.inverse, 10)
+ self.assertRaises(ZeroDivisionError, v2.inverse, 0)
+
+ def test_inplace_inverse(self):
+ v1, v2 = self.Integers(2, 5)
+
+ v1.inplace_inverse(v2)
+ self.assertEqual(v1, 3)
+
+ def test_gcd(self):
+ v1, v2, v3, v4 = self.Integers(6, 10, 17, -2)
+ self.assertTrue(isinstance(v1.gcd(v2), self.Integer))
+ self.assertEqual(v1.gcd(v2), 2)
+ self.assertEqual(v1.gcd(10), 2)
+ self.assertEqual(v1.gcd(v3), 1)
+ self.assertEqual(v1.gcd(-2), 2)
+ self.assertEqual(v4.gcd(6), 2)
+
+ def test_lcm(self):
+ v1, v2, v3, v4, v5 = self.Integers(6, 10, 17, -2, 0)
+ self.assertTrue(isinstance(v1.lcm(v2), self.Integer))
+ self.assertEqual(v1.lcm(v2), 30)
+ self.assertEqual(v1.lcm(10), 30)
+ self.assertEqual(v1.lcm(v3), 102)
+ self.assertEqual(v1.lcm(-2), 6)
+ self.assertEqual(v4.lcm(6), 6)
+ self.assertEqual(v1.lcm(0), 0)
+ self.assertEqual(v5.lcm(0), 0)
+
+ def test_jacobi_symbol(self):
+
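+ # Each tuple is (k, n, expected value of the Jacobi symbol (k/n))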
+ data = (
+ (1001, 1, 1),
+ (19, 45, 1),
+ (8, 21, -1),
+ (5, 21, 1),
+ (610, 987, -1),
+ (1001, 9907, -1),
+ (5, 3439601197, -1)
+ )
+
+ js = self.Integer.jacobi_symbol
+
+ # Jacobi symbol is always 1 for k==1 or n==1
+ for k in range(1, 30):
+ self.assertEqual(js(k, 1), 1)
+ for n in range(1, 30, 2):
+ self.assertEqual(js(1, n), 1)
+
+ # Fail if n is not positive odd
+ self.assertRaises(ValueError, js, 6, -2)
+ self.assertRaises(ValueError, js, 6, -1)
+ self.assertRaises(ValueError, js, 6, 0)
+ self.assertRaises(ValueError, js, 0, 0)
+ self.assertRaises(ValueError, js, 6, 2)
+ self.assertRaises(ValueError, js, 6, 4)
+ self.assertRaises(ValueError, js, 6, 6)
+ self.assertRaises(ValueError, js, 6, 8)
+
+ for tv in data:
+ self.assertEqual(js(tv[0], tv[1]), tv[2])
+ self.assertEqual(js(self.Integer(tv[0]), tv[1]), tv[2])
+ self.assertEqual(js(tv[0], self.Integer(tv[1])), tv[2])
+
+ def test_jacobi_symbol_wikipedia(self):
+
+ # Test vectors from https://en.wikipedia.org/wiki/Jacobi_symbol
+ tv = [
+ (3, [(1, 1), (2, -1), (3, 0), (4, 1), (5, -1), (6, 0), (7, 1), (8, -1), (9, 0), (10, 1), (11, -1), (12, 0), (13, 1), (14, -1), (15, 0), (16, 1), (17, -1), (18, 0), (19, 1), (20, -1), (21, 0), (22, 1), (23, -1), (24, 0), (25, 1), (26, -1), (27, 0), (28, 1), (29, -1), (30, 0)]),
+ (5, [(1, 1), (2, -1), (3, -1), (4, 1), (5, 0), (6, 1), (7, -1), (8, -1), (9, 1), (10, 0), (11, 1), (12, -1), (13, -1), (14, 1), (15, 0), (16, 1), (17, -1), (18, -1), (19, 1), (20, 0), (21, 1), (22, -1), (23, -1), (24, 1), (25, 0), (26, 1), (27, -1), (28, -1), (29, 1), (30, 0)]),
+ (7, [(1, 1), (2, 1), (3, -1), (4, 1), (5, -1), (6, -1), (7, 0), (8, 1), (9, 1), (10, -1), (11, 1), (12, -1), (13, -1), (14, 0), (15, 1), (16, 1), (17, -1), (18, 1), (19, -1), (20, -1), (21, 0), (22, 1), (23, 1), (24, -1), (25, 1), (26, -1), (27, -1), (28, 0), (29, 1), (30, 1)]),
+ (9, [(1, 1), (2, 1), (3, 0), (4, 1), (5, 1), (6, 0), (7, 1), (8, 1), (9, 0), (10, 1), (11, 1), (12, 0), (13, 1), (14, 1), (15, 0), (16, 1), (17, 1), (18, 0), (19, 1), (20, 1), (21, 0), (22, 1), (23, 1), (24, 0), (25, 1), (26, 1), (27, 0), (28, 1), (29, 1), (30, 0)]),
+ (11, [(1, 1), (2, -1), (3, 1), (4, 1), (5, 1), (6, -1), (7, -1), (8, -1), (9, 1), (10, -1), (11, 0), (12, 1), (13, -1), (14, 1), (15, 1), (16, 1), (17, -1), (18, -1), (19, -1), (20, 1), (21, -1), (22, 0), (23, 1), (24, -1), (25, 1), (26, 1), (27, 1), (28, -1), (29, -1), (30, -1)]),
+ (13, [(1, 1), (2, -1), (3, 1), (4, 1), (5, -1), (6, -1), (7, -1), (8, -1), (9, 1), (10, 1), (11, -1), (12, 1), (13, 0), (14, 1), (15, -1), (16, 1), (17, 1), (18, -1), (19, -1), (20, -1), (21, -1), (22, 1), (23, 1), (24, -1), (25, 1), (26, 0), (27, 1), (28, -1), (29, 1), (30, 1)]),
+ (15, [(1, 1), (2, 1), (3, 0), (4, 1), (5, 0), (6, 0), (7, -1), (8, 1), (9, 0), (10, 0), (11, -1), (12, 0), (13, -1), (14, -1), (15, 0), (16, 1), (17, 1), (18, 0), (19, 1), (20, 0), (21, 0), (22, -1), (23, 1), (24, 0), (25, 0), (26, -1), (27, 0), (28, -1), (29, -1), (30, 0)]),
+ (17, [(1, 1), (2, 1), (3, -1), (4, 1), (5, -1), (6, -1), (7, -1), (8, 1), (9, 1), (10, -1), (11, -1), (12, -1), (13, 1), (14, -1), (15, 1), (16, 1), (17, 0), (18, 1), (19, 1), (20, -1), (21, 1), (22, -1), (23, -1), (24, -1), (25, 1), (26, 1), (27, -1), (28, -1), (29, -1), (30, 1)]),
+ (19, [(1, 1), (2, -1), (3, -1), (4, 1), (5, 1), (6, 1), (7, 1), (8, -1), (9, 1), (10, -1), (11, 1), (12, -1), (13, -1), (14, -1), (15, -1), (16, 1), (17, 1), (18, -1), (19, 0), (20, 1), (21, -1), (22, -1), (23, 1), (24, 1), (25, 1), (26, 1), (27, -1), (28, 1), (29, -1), (30, 1)]),
+ (21, [(1, 1), (2, -1), (3, 0), (4, 1), (5, 1), (6, 0), (7, 0), (8, -1), (9, 0), (10, -1), (11, -1), (12, 0), (13, -1), (14, 0), (15, 0), (16, 1), (17, 1), (18, 0), (19, -1), (20, 1), (21, 0), (22, 1), (23, -1), (24, 0), (25, 1), (26, 1), (27, 0), (28, 0), (29, -1), (30, 0)]),
+ (23, [(1, 1), (2, 1), (3, 1), (4, 1), (5, -1), (6, 1), (7, -1), (8, 1), (9, 1), (10, -1), (11, -1), (12, 1), (13, 1), (14, -1), (15, -1), (16, 1), (17, -1), (18, 1), (19, -1), (20, -1), (21, -1), (22, -1), (23, 0), (24, 1), (25, 1), (26, 1), (27, 1), (28, -1), (29, 1), (30, -1)]),
+ (25, [(1, 1), (2, 1), (3, 1), (4, 1), (5, 0), (6, 1), (7, 1), (8, 1), (9, 1), (10, 0), (11, 1), (12, 1), (13, 1), (14, 1), (15, 0), (16, 1), (17, 1), (18, 1), (19, 1), (20, 0), (21, 1), (22, 1), (23, 1), (24, 1), (25, 0), (26, 1), (27, 1), (28, 1), (29, 1), (30, 0)]),
+ (27, [(1, 1), (2, -1), (3, 0), (4, 1), (5, -1), (6, 0), (7, 1), (8, -1), (9, 0), (10, 1), (11, -1), (12, 0), (13, 1), (14, -1), (15, 0), (16, 1), (17, -1), (18, 0), (19, 1), (20, -1), (21, 0), (22, 1), (23, -1), (24, 0), (25, 1), (26, -1), (27, 0), (28, 1), (29, -1), (30, 0)]),
+ (29, [(1, 1), (2, -1), (3, -1), (4, 1), (5, 1), (6, 1), (7, 1), (8, -1), (9, 1), (10, -1), (11, -1), (12, -1), (13, 1), (14, -1), (15, -1), (16, 1), (17, -1), (18, -1), (19, -1), (20, 1), (21, -1), (22, 1), (23, 1), (24, 1), (25, 1), (26, -1), (27, -1), (28, 1), (29, 0), (30, 1)]),
+ ]
+
+ js = self.Integer.jacobi_symbol
+
+ for n, kj in tv:
+ for k, j in kj:
+ self.assertEqual(js(k, n), j)
+
+ def test_hex(self):
+ v1, = self.Integers(0x10)
+ self.assertEqual(hex(v1), "0x10")
+
+
+class TestIntegerInt(TestIntegerBase):
+
+ def setUp(self):
+ self.Integer = IntegerNative
+
+
+class testIntegerRandom(unittest.TestCase):
+
+ def test_random_exact_bits(self):
+
+ for _ in range(1000):
+ a = IntegerNative.random(exact_bits=8)
+ self.assertFalse(a < 128)
+ self.assertFalse(a >= 256)
+
+ for bits_value in range(1024, 1024 + 8):
+ a = IntegerNative.random(exact_bits=bits_value)
+ self.assertFalse(a < 2**(bits_value - 1))
+ self.assertFalse(a >= 2**bits_value)
+
+ def test_random_max_bits(self):
+
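+ # With max_bits the top bit need not be set: 'flag' records whether at
+ # least one of the 1000 samples fell below 2**7, while every sample must
+ # still stay below 2**8.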
+ flag = False
+ for _ in range(1000):
+ a = IntegerNative.random(max_bits=8)
+ flag = flag or a < 128
+ self.assertFalse(a>=256)
+ self.assertTrue(flag)
+
+ for bits_value in range(1024, 1024 + 8):
+ a = IntegerNative.random(max_bits=bits_value)
+ self.assertFalse(a >= 2**bits_value)
+
+ def test_random_bits_custom_rng(self):
+
+ class CustomRNG(object):
+ def __init__(self):
+ self.counter = 0
+
+ def __call__(self, size):
+ self.counter += size
+ return bchr(0) * size
+
+ custom_rng = CustomRNG()
+ a = IntegerNative.random(exact_bits=32, randfunc=custom_rng)
+ self.assertEqual(custom_rng.counter, 4)
+
+ def test_random_range(self):
+
+ func = IntegerNative.random_range
+
+ for x in range(200):
+ a = func(min_inclusive=1, max_inclusive=15)
+ self.assertTrue(1 <= a <= 15)
+
+ for x in range(200):
+ a = func(min_inclusive=1, max_exclusive=15)
+ self.assertTrue(1 <= a < 15)
+
+ self.assertRaises(ValueError, func, min_inclusive=1, max_inclusive=2,
+ max_exclusive=3)
+ self.assertRaises(ValueError, func, max_inclusive=2, max_exclusive=3)
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(TestIntegerInt)
+
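+ # The GMP-backed and custom-modexp Integer backends are optional; their
+ # test classes are registered only if the corresponding import succeeds.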
+ try:
+ from Crypto.Math._IntegerGMP import IntegerGMP
+
+ class TestIntegerGMP(TestIntegerBase):
+ def setUp(self):
+ self.Integer = IntegerGMP
+
+ tests += list_test_cases(TestIntegerGMP)
+ except (ImportError, OSError) as e:
+ if sys.platform == "win32":
+ sys.stdout.write("Skipping GMP tests on Windows\n")
+ else:
+ sys.stdout.write("Skipping GMP tests (%s)\n" % str(e) )
+
+ try:
+ from Crypto.Math._IntegerCustom import IntegerCustom
+
+ class TestIntegerCustomModexp(TestIntegerBase):
+ def setUp(self):
+ self.Integer = IntegerCustom
+
+ tests += list_test_cases(TestIntegerCustomModexp)
+ except (ImportError, OSError) as e:
+ sys.stdout.write("Skipping custom modexp tests (%s)\n" % str(e) )
+
+ tests += list_test_cases(testIntegerRandom)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Math/test_Primality.py b/lib/Crypto/SelfTest/Math/test_Primality.py
new file mode 100644
index 0000000..38344f3
--- /dev/null
+++ b/lib/Crypto/SelfTest/Math/test_Primality.py
@@ -0,0 +1,118 @@
+#
+# SelfTest/Math/test_Primality.py: Self-test for Primality module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test for Math.Numbers"""
+
+import unittest
+
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Util.py3compat import *
+
+from Crypto.Math.Numbers import Integer
+from Crypto.Math.Primality import (
+ PROBABLY_PRIME, COMPOSITE,
+ miller_rabin_test, lucas_test,
+ test_probable_prime,
+ generate_probable_prime,
+ generate_probable_safe_prime,
+ )
+
+
+class TestPrimality(unittest.TestCase):
+
+ primes = (1, 2, 3, 5, 7, 11, 13, 17, 19, 23, 2**127-1, 175637383534939453397801320455508570374088202376942372758907369518414308188137781042871856139027160010343454418881888953150175357127346872102307696660678617989191485418582475696230580407111841072614783095326672517315988762029036079794994990250662362650625650262324085116467511357592728695033227611029693067539)
+ composites = (0, 4, 6, 8, 9, 10, 12, 14, 15, 16, 18, 20, 21, 7*23, (2**19-1)*(2**67-1), 9746347772161,)
+
+ def test_miller_rabin(self):
+ for prime in self.primes:
+ self.assertEqual(miller_rabin_test(prime, 3), PROBABLY_PRIME)
+ for composite in self.composites:
+ self.assertEqual(miller_rabin_test(composite, 3), COMPOSITE)
+ self.assertRaises(ValueError, miller_rabin_test, -1, 3)
+
+ def test_lucas(self):
+ for prime in self.primes:
+ res = lucas_test(prime)
+ self.assertEqual(res, PROBABLY_PRIME)
+ for composite in self.composites:
+ res = lucas_test(composite)
+ self.assertEqual(res, COMPOSITE)
+ self.assertRaises(ValueError, lucas_test, -1)
+
+ def test_is_prime(self):
+ primes = (170141183460469231731687303715884105727,
+ 19175002942688032928599,
+ 1363005552434666078217421284621279933627102780881053358473,
+ 2 ** 521 - 1)
+ for p in primes:
+ self.assertEqual(test_probable_prime(p), PROBABLY_PRIME)
+
+ not_primes = (
+ 4754868377601046732119933839981363081972014948522510826417784001,
+ 1334733877147062382486934807105197899496002201113849920496510541601,
+ 260849323075371835669784094383812120359260783810157225730623388382401,
+ )
+ for np in not_primes:
+ self.assertEqual(test_probable_prime(np), COMPOSITE)
+
+ from Crypto.Util.number import sieve_base
+ for p in sieve_base[:100]:
+ res = test_probable_prime(p)
+ self.assertEqual(res, PROBABLY_PRIME)
+
+ def test_generate_prime_bit_size(self):
+ p = generate_probable_prime(exact_bits=512)
+ self.assertEqual(p.size_in_bits(), 512)
+
+ def test_generate_prime_filter(self):
+ def ending_with_one(number):
+ return number % 10 == 1
+
+ for x in range(20):
+ q = generate_probable_prime(exact_bits=160,
+ prime_filter=ending_with_one)
+ self.assertEqual(q % 10, 1)
+
+ def test_generate_safe_prime(self):
+ p = generate_probable_safe_prime(exact_bits=161)
+ self.assertEqual(p.size_in_bits(), 161)
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(TestPrimality)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Math/test_modexp.py b/lib/Crypto/SelfTest/Math/test_modexp.py
new file mode 100644
index 0000000..b9eb869
--- /dev/null
+++ b/lib/Crypto/SelfTest/Math/test_modexp.py
@@ -0,0 +1,201 @@
+#
+# SelfTest/Math/test_modexp.py: Self-test for modular exponentiation
+#
+# ===================================================================
+#
+# Copyright (c) 2017, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-test for the custom module exponentiation"""
+
+import unittest
+
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+
+from Crypto.Util.py3compat import *
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib,
+ create_string_buffer,
+ get_raw_buffer,
+ c_size_t,
+ c_ulonglong)
+
+from Crypto.Hash import SHAKE128
+from Crypto.Math.Numbers import Integer
+from Crypto.Math._IntegerCustom import _raw_montgomery
+
+from Crypto.Random.random import StrongRandom
+
+
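+# create_rng() returns a deterministic RNG: StrongRandom draws its bytes from
+# a SHAKE128 XOF seeded with the given tag, so every stress test below
+# reproduces the same sequence of operands on each run.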
+def create_rng(tag):
+ rng = StrongRandom(SHAKE128.new(data=tag))
+ return rng
+
+class ExceptionModulus(ValueError):
+ pass
+
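+# monty_pow() is a thin wrapper around the raw Montgomery exponentiation
+# entry point: base, exponent and modulus are serialized to big-endian
+# buffers of equal length, the result is read back from 'out', and error
+# code 17 (an unusable modulus, e.g. zero or even, per the tests below) is
+# surfaced as ExceptionModulus.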
+def monty_pow(base, exp, modulus):
+ max_len = len(long_to_bytes(max(base, exp, modulus)))
+
+ base_b, exp_b, modulus_b = [ long_to_bytes(x, max_len) for x in
+ (base, exp, modulus) ]
+
+ out = create_string_buffer(max_len)
+ error = _raw_montgomery.monty_pow(
+ out,
+ base_b,
+ exp_b,
+ modulus_b,
+ c_size_t(max_len),
+ c_ulonglong(32)
+ )
+
+ if error == 17:
+ raise ExceptionModulus()
+ if error:
+ raise ValueError("monty_pow failed with error: %d" % error)
+
+ result = bytes_to_long(get_raw_buffer(out))
+ return result
+
+exponent1 = 0x2ce0af628901460a419a08ef950d498b9fd6f271a1a52ac293b86fe5c60efe8e8ba93fa1ebe1eb3d614d2e7b328cb60a2591440e163441a190ecf101ceec245f600fffdcf3f5b3a17a7baeacb96a424db1d7ec985e8ec998bb479fecfffed6a75f9a90fc97062fd973303bce855ad7b8d8272a94025e8532be9aabd54a183f303538d2a7e621b4131d59e823a4625f39bd7d518d7784f7c3a8f19061da74974ff42fa1c063dec2db97d461e291a7d6e721708a5229de166c1246363372854e27f3f08ae274bc16bfd205b028a4d81386494433d516dfbb35f495acba5e4e1d1843cb3c3129b6642a85fc7244ce5845fac071c7f622e4ee12ac43fabeeaa0cd01
+modulus1 = 0xd66691b20071be4d66d4b71032b37fa007cfabf579fcb91e50bfc2753b3f0ce7be74e216aef7e26d4ae180bc20d7bd3ea88a6cbf6f87380e613c8979b5b043b200a8ff8856a3b12875e36e98a7569f3852d028e967551000b02c19e9fa52e83115b89309aabb1e1cf1e2cb6369d637d46775ce4523ea31f64ad2794cbc365dd8a35e007ed3b57695877fbf102dbeb8b3212491398e494314e93726926e1383f8abb5889bea954eb8c0ca1c62c8e9d83f41888095c5e645ed6d32515fe0c58c1368cad84694e18da43668c6f43e61d7c9bca633ddcda7aef5b79bc396d4a9f48e2a9abe0836cc455e435305357228e93d25aaed46b952defae0f57339bf26f5a9
+
+
+class TestModExp(unittest.TestCase):
+
+ def test_small(self):
+ self.assertEqual(1, monty_pow(11,12,19))
+
+ def test_large_1(self):
+ base = 0xfffffffffffffffffffffffffffffffffffffffffffffffffff
+ expected = pow(base, exponent1, modulus1)
+ result = monty_pow(base, exponent1, modulus1)
+ self.assertEqual(result, expected)
+
+ def test_zero_exp(self):
+ base = 0xfffffffffffffffffffffffffffffffffffffffffffffffffff
+ result = monty_pow(base, 0, modulus1)
+ self.assertEqual(result, 1)
+
+ def test_zero_base(self):
+ result = monty_pow(0, exponent1, modulus1)
+ self.assertEqual(result, 0)
+
+ def test_zero_modulus(self):
+ base = 0xfffffffffffffffffffffffffffffffffffffffffffffffff
+ self.assertRaises(ExceptionModulus, monty_pow, base, exponent1, 0)
+ self.assertRaises(ExceptionModulus, monty_pow, 0, 0, 0)
+
+ def test_larger_exponent(self):
+ base = modulus1 - 0xFFFFFFF
+ expected = pow(base, modulus1<<64, modulus1)
+ result = monty_pow(base, modulus1<<64, modulus1)
+ self.assertEqual(result, expected)
+
+ def test_even_modulus(self):
+ base = modulus1 >> 4
+ self.assertRaises(ExceptionModulus, monty_pow, base, exponent1, modulus1-1)
+
+ def test_several_lengths(self):
+ prng = SHAKE128.new().update(b('Test'))
+ for length in range(1, 100):
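+ # Force the modulus to be odd; Montgomery reduction does not support
+ # even moduli (see test_even_modulus above)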
+ modulus2 = Integer.from_bytes(prng.read(length)) | 1
+ base = Integer.from_bytes(prng.read(length)) % modulus2
+ exponent2 = Integer.from_bytes(prng.read(length))
+
+ expected = pow(base, exponent2, modulus2)
+ result = monty_pow(base, exponent2, modulus2)
+ self.assertEqual(result, expected)
+
+ def test_variable_exponent(self):
+ prng = create_rng(b('Test variable exponent'))
+ for i in range(20):
+ for j in range(7):
+ modulus = prng.getrandbits(8*30) | 1
+ base = prng.getrandbits(8*30) % modulus
+ exponent = prng.getrandbits(i*8+j)
+
+ expected = pow(base, exponent, modulus)
+ result = monty_pow(base, exponent, modulus)
+ self.assertEqual(result, expected)
+
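+ # Complement the exponent within its i*8+j bit width and re-check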
+ exponent ^= (1 << (i*8+j)) - 1
+
+ expected = pow(base, exponent, modulus)
+ result = monty_pow(base, exponent, modulus)
+ self.assertEqual(result, expected)
+
+ def test_stress_63(self):
+ prng = create_rng(b('Test 63'))
+ length = 63
+ for _ in range(2000):
+ modulus = prng.getrandbits(8*length) | 1
+ base = prng.getrandbits(8*length) % modulus
+ exponent = prng.getrandbits(8*length)
+
+ expected = pow(base, exponent, modulus)
+ result = monty_pow(base, exponent, modulus)
+ self.assertEqual(result, expected)
+
+ def test_stress_64(self):
+ prng = create_rng(b('Test 64'))
+ length = 64
+ for _ in range(2000):
+ modulus = prng.getrandbits(8*length) | 1
+ base = prng.getrandbits(8*length) % modulus
+ exponent = prng.getrandbits(8*length)
+
+ expected = pow(base, exponent, modulus)
+ result = monty_pow(base, exponent, modulus)
+ self.assertEqual(result, expected)
+
+ def test_stress_65(self):
+ prng = create_rng(b('Test 65'))
+ length = 65
+ for _ in range(2000):
+ modulus = prng.getrandbits(8*length) | 1
+ base = prng.getrandbits(8*length) % modulus
+ exponent = prng.getrandbits(8*length)
+
+ expected = pow(base, exponent, modulus)
+ result = monty_pow(base, exponent, modulus)
+ self.assertEqual(result, expected)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(TestModExp)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Protocol/__init__.py b/lib/Crypto/SelfTest/Protocol/__init__.py
new file mode 100644
index 0000000..1c1c095
--- /dev/null
+++ b/lib/Crypto/SelfTest/Protocol/__init__.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Protocol/__init__.py: Self-tests for Crypto.Protocol
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test for Crypto.Protocol"""
+
+__revision__ = "$Id$"
+
+def get_tests(config={}):
+ tests = []
+ from Crypto.SelfTest.Protocol import test_rfc1751; tests += test_rfc1751.get_tests(config=config)
+ from Crypto.SelfTest.Protocol import test_KDF; tests += test_KDF.get_tests(config=config)
+
+ from Crypto.SelfTest.Protocol import test_SecretSharing;
+ tests += test_SecretSharing.get_tests(config=config)
+
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Protocol/test_KDF.py b/lib/Crypto/SelfTest/Protocol/test_KDF.py
new file mode 100644
index 0000000..b2869f8
--- /dev/null
+++ b/lib/Crypto/SelfTest/Protocol/test_KDF.py
@@ -0,0 +1,732 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Protocol/test_KDF.py: Self-test for key derivation functions
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import b, bchr
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+from Crypto.Hash import SHA1, HMAC, SHA256, MD5, SHA224, SHA384, SHA512
+from Crypto.Cipher import AES, DES3
+
+from Crypto.Protocol.KDF import (PBKDF1, PBKDF2, _S2V, HKDF, scrypt,
+ bcrypt, bcrypt_check)
+
+from Crypto.Protocol.KDF import _bcrypt_decode
+
+
+def t2b(t):
+ if t is None:
+ return None
+ t2 = t.replace(" ", "").replace("\n", "")
+ return unhexlify(b(t2))
+
+
+class TestVector(object):
+ pass
+
+
+class PBKDF1_Tests(unittest.TestCase):
+
+ # List of tuples with test data.
+ # Each tuple is made up by:
+ # Item #0: a pass phrase
+ # Item #1: salt (8 bytes encoded in hex)
+ # Item #2: output key length
+ # Item #3: iterations to use
+ # Item #4: expected result (encoded in hex)
+ _testData = (
+ # From http://www.di-mgt.com.au/cryptoKDFs.html#examplespbkdf
+ ("password", "78578E5A5D63CB06", 16, 1000, "DC19847E05C64D2FAF10EBFB4A3D2A20"),
+ )
+
+ def test1(self):
+ v = self._testData[0]
+ res = PBKDF1(v[0], t2b(v[1]), v[2], v[3], SHA1)
+ self.assertEqual(res, t2b(v[4]))
+
+
+class PBKDF2_Tests(unittest.TestCase):
+
+ # List of tuples with test data.
+ # Each tuple is made up by:
+ # Item #0: a pass phrase
+ # Item #1: salt (encoded in hex)
+ # Item #2: output key length
+ # Item #3: iterations to use
+ # Item #4: hash module
+ # Item #5: expected result (encoded in hex)
+ _testData = (
+ # From http://www.di-mgt.com.au/cryptoKDFs.html#examplespbkdf
+ ("password","78578E5A5D63CB06",24,2048, SHA1, "BFDE6BE94DF7E11DD409BCE20A0255EC327CB936FFE93643"),
+ # From RFC 6070
+ ("password","73616c74", 20, 1, SHA1, "0c60c80f961f0e71f3a9b524af6012062fe037a6"),
+ ("password","73616c74", 20, 2, SHA1, "ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957"),
+ ("password","73616c74", 20, 4096, SHA1, "4b007901b765489abead49d926f721d065a429c1"),
+ ("passwordPASSWORDpassword","73616c7453414c5473616c7453414c5473616c7453414c5473616c7453414c5473616c74",
+ 25, 4096, SHA1, "3d2eec4fe41c849b80c8d83662c0e44a8b291a964cf2f07038"),
+ ( 'pass\x00word',"7361006c74",16,4096, SHA1, "56fa6aa75548099dcc37d7f03425e0c3"),
+ # From draft-josefsson-scrypt-kdf-01, Chapter 10
+ ( 'passwd', '73616c74', 64, 1, SHA256, "55ac046e56e3089fec1691c22544b605f94185216dde0465e68b9d57c20dacbc49ca9cccf179b645991664b39d77ef317c71b845b1e30bd509112041d3a19783"),
+ ( 'Password', '4e61436c', 64, 80000, SHA256, "4ddcd8f60b98be21830cee5ef22701f9641a4418d04c0414aeff08876b34ab56a1d425a1225833549adb841b51c9b3176a272bdebba1d078478f62b397f33c8d"),
+ )
+
+ def test1(self):
+ # Test with HMAC-SHA1 and HMAC-SHA256 as PRF
+
+ def prf_SHA1(p,s):
+ return HMAC.new(p,s,SHA1).digest()
+
+ def prf_SHA256(p,s):
+ return HMAC.new(p,s,SHA256).digest()
+
+ for i in range(len(self._testData)):
+ v = self._testData[i]
+ password = v[0]
+ salt = t2b(v[1])
+ out_len = v[2]
+ iters = v[3]
+ hash_mod = v[4]
+ expected = t2b(v[5])
+
+ if hash_mod is SHA1:
+ res = PBKDF2(password, salt, out_len, iters)
+ self.assertEqual(res, expected)
+
+ res = PBKDF2(password, salt, out_len, iters, prf_SHA1)
+ self.assertEqual(res, expected)
+ else:
+ res = PBKDF2(password, salt, out_len, iters, prf_SHA256)
+ self.assertEqual(res, expected)
+
+ def test2(self):
+ # Verify that prf and hmac_hash_module are mutually exclusive
+ def prf_SHA1(p,s):
+ return HMAC.new(p,s,SHA1).digest()
+
+ self.assertRaises(ValueError, PBKDF2, b("xxx"), b("yyy"), 16, 100,
+ prf=prf_SHA1, hmac_hash_module=SHA1)
+
+ def test3(self):
+ # Verify that hmac_hash_module works like prf
+
+ password = b("xxx")
+ salt = b("yyy")
+
+ for hashmod in (MD5, SHA1, SHA224, SHA256, SHA384, SHA512):
+
+ pr1 = PBKDF2(password, salt, 16, 100,
+ prf=lambda p, s: HMAC.new(p,s,hashmod).digest())
+ pr2 = PBKDF2(password, salt, 16, 100, hmac_hash_module=hashmod)
+
+ self.assertEqual(pr1, pr2)
+
+ def test4(self):
+ # Verify that PBKDF2 can take bytes or strings as password or salt
+ k1 = PBKDF2("xxx", b("yyy"), 16, 10)
+ k2 = PBKDF2(b("xxx"), b("yyy"), 16, 10)
+ self.assertEqual(k1, k2)
+
+ k1 = PBKDF2(b("xxx"), "yyy", 16, 10)
+ k2 = PBKDF2(b("xxx"), b("yyy"), 16, 10)
+ self.assertEqual(k1, k2)
+
+
+class S2V_Tests(unittest.TestCase):
+
+ # Sequence of test vectors.
+ # Each test vector is made up by:
+ # Item #0: a tuple of strings
+ # Item #1: an AES key
+ # Item #2: the result
+ # Item #3: the cipher module S2V is based on
+ # Everything is hex encoded
+ _testData = [
+
+ # RFC5297, A.1
+ (
+ ( '101112131415161718191a1b1c1d1e1f2021222324252627',
+ '112233445566778899aabbccddee' ),
+ 'fffefdfcfbfaf9f8f7f6f5f4f3f2f1f0',
+ '85632d07c6e8f37f950acd320a2ecc93',
+ AES
+ ),
+
+ # RFC5297, A.2
+ (
+ ( '00112233445566778899aabbccddeeffdeaddadadeaddadaffeeddcc'+
+ 'bbaa99887766554433221100',
+ '102030405060708090a0',
+ '09f911029d74e35bd84156c5635688c0',
+ '7468697320697320736f6d6520706c61'+
+ '696e7465787420746f20656e63727970'+
+ '74207573696e67205349562d414553'),
+ '7f7e7d7c7b7a79787776757473727170',
+ '7bdb6e3b432667eb06f4d14bff2fbd0f',
+ AES
+ ),
+
+ ]
+
+ def test1(self):
+ """Verify correctness of test vector"""
+ for tv in self._testData:
+ s2v = _S2V.new(t2b(tv[1]), tv[3])
+ for s in tv[0]:
+ s2v.update(t2b(s))
+ result = s2v.derive()
+ self.assertEqual(result, t2b(tv[2]))
+
+ def test2(self):
+ """Verify that no more than 127(AES) and 63(TDES)
+ components are accepted."""
+ key = bchr(0) * 8 + bchr(255) * 8
+ for module in (AES, DES3):
+ s2v = _S2V.new(key, module)
+ max_comps = module.block_size*8-1
+ for i in range(max_comps):
+ s2v.update(b("XX"))
+ self.assertRaises(TypeError, s2v.update, b("YY"))
+
+
+class HKDF_Tests(unittest.TestCase):
+
+ # Test vectors from RFC5869, Appendix A
+ # Each tuple is made up by:
+ # Item #0: hash module
+ # Item #1: secret
+ # Item #2: salt
+ # Item #3: context
+ # Item #4: desired key length
+ # Item #5: expected result
+ _test_vector = (
+ (
+ SHA256,
+ "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
+ "000102030405060708090a0b0c",
+ "f0f1f2f3f4f5f6f7f8f9",
+ 42,
+ "3cb25f25faacd57a90434f64d0362f2a" +
+ "2d2d0a90cf1a5a4c5db02d56ecc4c5bf" +
+ "34007208d5b887185865"
+ ),
+ (
+ SHA256,
+ "000102030405060708090a0b0c0d0e0f" +
+ "101112131415161718191a1b1c1d1e1f" +
+ "202122232425262728292a2b2c2d2e2f" +
+ "303132333435363738393a3b3c3d3e3f" +
+ "404142434445464748494a4b4c4d4e4f",
+ "606162636465666768696a6b6c6d6e6f" +
+ "707172737475767778797a7b7c7d7e7f" +
+ "808182838485868788898a8b8c8d8e8f" +
+ "909192939495969798999a9b9c9d9e9f" +
+ "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
+ "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" +
+ "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" +
+ "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" +
+ "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" +
+ "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
+ 82,
+ "b11e398dc80327a1c8e7f78c596a4934" +
+ "4f012eda2d4efad8a050cc4c19afa97c" +
+ "59045a99cac7827271cb41c65e590e09" +
+ "da3275600c2f09b8367793a9aca3db71" +
+ "cc30c58179ec3e87c14c01d5c1f3434f" +
+ "1d87"
+ ),
+ (
+ SHA256,
+ "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
+ None,
+ None,
+ 42,
+ "8da4e775a563c18f715f802a063c5a31" +
+ "b8a11f5c5ee1879ec3454e5f3c738d2d" +
+ "9d201395faa4b61a96c8"
+ ),
+ (
+ SHA1,
+ "0b0b0b0b0b0b0b0b0b0b0b",
+ "000102030405060708090a0b0c",
+ "f0f1f2f3f4f5f6f7f8f9",
+ 42,
+ "085a01ea1b10f36933068b56efa5ad81" +
+ "a4f14b822f5b091568a9cdd4f155fda2" +
+ "c22e422478d305f3f896"
+ ),
+ (
+ SHA1,
+ "000102030405060708090a0b0c0d0e0f" +
+ "101112131415161718191a1b1c1d1e1f" +
+ "202122232425262728292a2b2c2d2e2f" +
+ "303132333435363738393a3b3c3d3e3f" +
+ "404142434445464748494a4b4c4d4e4f",
+ "606162636465666768696a6b6c6d6e6f" +
+ "707172737475767778797a7b7c7d7e7f" +
+ "808182838485868788898a8b8c8d8e8f" +
+ "909192939495969798999a9b9c9d9e9f" +
+ "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf",
+ "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" +
+ "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" +
+ "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" +
+ "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" +
+ "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
+ 82,
+ "0bd770a74d1160f7c9f12cd5912a06eb" +
+ "ff6adcae899d92191fe4305673ba2ffe" +
+ "8fa3f1a4e5ad79f3f334b3b202b2173c" +
+ "486ea37ce3d397ed034c7f9dfeb15c5e" +
+ "927336d0441f4c4300e2cff0d0900b52" +
+ "d3b4"
+ ),
+ (
+ SHA1,
+ "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
+ "",
+ "",
+ 42,
+ "0ac1af7002b3d761d1e55298da9d0506" +
+ "b9ae52057220a306e07b6b87e8df21d0" +
+ "ea00033de03984d34918"
+ ),
+ (
+ SHA1,
+ "0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c0c",
+ None,
+ "",
+ 42,
+ "2c91117204d745f3500d636a62f64f0a" +
+ "b3bae548aa53d423b0d1f27ebba6f5e5" +
+ "673a081d70cce7acfc48"
+ )
+ )
+
+ def test1(self):
+ for tv in self._test_vector:
+ secret, salt, info, exp = [ t2b(tv[x]) for x in (1,2,3,5) ]
+ key_len, hashmod = [ tv[x] for x in (4,0) ]
+
+ output = HKDF(secret, key_len, salt, hashmod, 1, info)
+ self.assertEqual(output, exp)
+
+ def test2(self):
+ ref = HKDF(b("XXXXXX"), 12, b("YYYY"), SHA1)
+
+ # Same output, but this time split over 2 keys
+ key1, key2 = HKDF(b("XXXXXX"), 6, b("YYYY"), SHA1, 2)
+ self.assertEqual((ref[:6], ref[6:]), (key1, key2))
+
+ # Same output, but this time split over 3 keys
+ key1, key2, key3 = HKDF(b("XXXXXX"), 4, b("YYYY"), SHA1, 3)
+ self.assertEqual((ref[:4], ref[4:8], ref[8:]), (key1, key2, key3))
+
+
+class scrypt_Tests(unittest.TestCase):
+
+ # Test vectors taken from
+ # https://tools.ietf.org/html/rfc7914
+ # - password
+ # - salt
+ # - N
+ # - r
+ # - p
+ # - expected output (hex)
+ data = (
+ (
+ "",
+ "",
+ 16, # 2K
+ 1,
+ 1,
+ """
+ 77 d6 57 62 38 65 7b 20 3b 19 ca 42 c1 8a 04 97
+ f1 6b 48 44 e3 07 4a e8 df df fa 3f ed e2 14 42
+ fc d0 06 9d ed 09 48 f8 32 6a 75 3a 0f c8 1f 17
+ e8 d3 e0 fb 2e 0d 36 28 cf 35 e2 0c 38 d1 89 06
+ """
+ ),
+ (
+ "password",
+ "NaCl",
+ 1024, # 1M
+ 8,
+ 16,
+ """
+ fd ba be 1c 9d 34 72 00 78 56 e7 19 0d 01 e9 fe
+ 7c 6a d7 cb c8 23 78 30 e7 73 76 63 4b 37 31 62
+ 2e af 30 d9 2e 22 a3 88 6f f1 09 27 9d 98 30 da
+ c7 27 af b9 4a 83 ee 6d 83 60 cb df a2 cc 06 40
+ """
+ ),
+ (
+ "pleaseletmein",
+ "SodiumChloride",
+ 16384, # 16M
+ 8,
+ 1,
+ """
+ 70 23 bd cb 3a fd 73 48 46 1c 06 cd 81 fd 38 eb
+ fd a8 fb ba 90 4f 8e 3e a9 b5 43 f6 54 5d a1 f2
+ d5 43 29 55 61 3f 0f cf 62 d4 97 05 24 2a 9a f9
+ e6 1e 85 dc 0d 65 1e 40 df cf 01 7b 45 57 58 87
+ """
+ ),
+ (
+ "pleaseletmein",
+ "SodiumChloride",
+ 1048576, # 1G
+ 8,
+ 1,
+ """
+ 21 01 cb 9b 6a 51 1a ae ad db be 09 cf 70 f8 81
+ ec 56 8d 57 4a 2f fd 4d ab e5 ee 98 20 ad aa 47
+ 8e 56 fd 8f 4b a5 d0 9f fa 1c 6d 92 7c 40 f4 c3
+ 37 30 40 49 e8 a9 52 fb cb f4 5c 6f a7 7a 41 a4
+ """
+ ),
+ )
+
+ def setUp(self):
+ new_test_vectors = []
+ for tv in self.data:
+ new_tv = TestVector()
+ new_tv.P = b(tv[0])
+ new_tv.S = b(tv[1])
+ new_tv.N = tv[2]
+ new_tv.r = tv[3]
+ new_tv.p = tv[4]
+ new_tv.output = t2b(tv[5])
+ new_tv.dkLen = len(new_tv.output)
+ new_test_vectors.append(new_tv)
+ self.data = new_test_vectors
+
+ def test2(self):
+
+ for tv in self.data:
+ try:
+ output = scrypt(tv.P, tv.S, tv.dkLen, tv.N, tv.r, tv.p)
+ except ValueError as e:
+ if " 2 " in str(e) and tv.N >= 1048576:
+ import warnings
+ warnings.warn("Not enough memory to unit test scrypt() with N=1048576", RuntimeWarning)
+ continue
+ else:
+ raise e
+ self.assertEqual(output, tv.output)
+
+ def test3(self):
+ ref = scrypt(b("password"), b("salt"), 12, 16, 1, 1)
+
+ # Same output, but this time split over 2 keys
+ key1, key2 = scrypt(b("password"), b("salt"), 6, 16, 1, 1, 2)
+ self.assertEqual((ref[:6], ref[6:]), (key1, key2))
+
+ # Same output, but this time split over 3 keys
+ key1, key2, key3 = scrypt(b("password"), b("salt"), 4, 16, 1, 1, 3)
+ self.assertEqual((ref[:4], ref[4:8], ref[8:]), (key1, key2, key3))
+
+
+class bcrypt_Tests(unittest.TestCase):
+
+ def test_negative_cases(self):
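+ # These must all be rejected: a password longer than 72 bytes or
+ # containing a NUL byte, a cost factor of 3 or 32, and a salt whose
+ # length is not 16 bytes.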
+ self.assertRaises(ValueError, bcrypt, b"1" * 73, 10)
+ self.assertRaises(ValueError, bcrypt, b"1" * 10, 3)
+ self.assertRaises(ValueError, bcrypt, b"1" * 10, 32)
+ self.assertRaises(ValueError, bcrypt, b"1" * 10, 4, salt=b"")
+ self.assertRaises(ValueError, bcrypt, b"1" * 10, 4, salt=b"1")
+ self.assertRaises(ValueError, bcrypt, b"1" * 10, 4, salt=b"1" * 17)
+ self.assertRaises(ValueError, bcrypt, b"1\x00" * 10, 4)
+
+ def test_bytearray_mismatch(self):
+ ref = bcrypt("pwd", 4)
+ bcrypt_check("pwd", ref)
+ bref = bytearray(ref)
+ bcrypt_check("pwd", bref)
+
+ wrong = ref[:-1] + bchr(bref[-1] ^ 0x01)
+ self.assertRaises(ValueError, bcrypt_check, "pwd", wrong)
+
+ wrong = b"x" + ref[1:]
+ self.assertRaises(ValueError, bcrypt_check, "pwd", wrong)
+
+ # https://github.com/patrickfav/bcrypt/wiki/Published-Test-Vectors
+
+ def test_empty_password(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ (b"", 4, b"zVHmKQtGGQob.b/Nc7l9NO", b"$2a$04$zVHmKQtGGQob.b/Nc7l9NO8UlrYcW05FiuCj/SxsFO/ZtiN9.mNzy"),
+ (b"", 5, b"zVHmKQtGGQob.b/Nc7l9NO", b"$2a$05$zVHmKQtGGQob.b/Nc7l9NOWES.1hkVBgy5IWImh9DOjKNU8atY4Iy"),
+ (b"", 6, b"zVHmKQtGGQob.b/Nc7l9NO", b"$2a$06$zVHmKQtGGQob.b/Nc7l9NOjOl7l4oz3WSh5fJ6414Uw8IXRAUoiaO"),
+ (b"", 7, b"zVHmKQtGGQob.b/Nc7l9NO", b"$2a$07$zVHmKQtGGQob.b/Nc7l9NOBsj1dQpBA1HYNGpIETIByoNX9jc.hOi"),
+ (b"", 8, b"zVHmKQtGGQob.b/Nc7l9NO", b"$2a$08$zVHmKQtGGQob.b/Nc7l9NOiLTUh/9MDpX86/DLyEzyiFjqjBFePgO"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_random_password_and_salt_short_pw(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ (b"<.S.2K(Zq'", 4, b"VYAclAMpaXY/oqAo9yUpku", b"$2a$04$VYAclAMpaXY/oqAo9yUpkuWmoYywaPzyhu56HxXpVltnBIfmO9tgu"),
+ (b"5.rApO%5jA", 5, b"kVNDrnYKvbNr5AIcxNzeIu", b"$2a$05$kVNDrnYKvbNr5AIcxNzeIuRcyIF5cZk6UrwHGxENbxP5dVv.WQM/G"),
+ (b"oW++kSrQW^", 6, b"QLKkRMH9Am6irtPeSKN5sO", b"$2a$06$QLKkRMH9Am6irtPeSKN5sObJGr3j47cO6Pdf5JZ0AsJXuze0IbsNm"),
+ (b"ggJ\\KbTnDG", 7, b"4H896R09bzjhapgCPS/LYu", b"$2a$07$4H896R09bzjhapgCPS/LYuMzAQluVgR5iu/ALF8L8Aln6lzzYXwbq"),
+ (b"49b0:;VkH/", 8, b"hfvO2retKrSrx5f2RXikWe", b"$2a$08$hfvO2retKrSrx5f2RXikWeFWdtSesPlbj08t/uXxCeZoHRWDz/xFe"),
+ (b">9N^5jc##'", 9, b"XZLvl7rMB3EvM0c1.JHivu", b"$2a$09$XZLvl7rMB3EvM0c1.JHivuIDPJWeNJPTVrpjZIEVRYYB/mF6cYgJK"),
+ (b"\\$ch)s4WXp", 10, b"aIjpMOLK5qiS9zjhcHR5TO", b"$2a$10$aIjpMOLK5qiS9zjhcHR5TOU7v2NFDmcsBmSFDt5EHOgp/jeTF3O/q"),
+ (b"RYoj\\_>2P7", 12, b"esIAHiQAJNNBrsr5V13l7.", b"$2a$12$esIAHiQAJNNBrsr5V13l7.RFWWJI2BZFtQlkFyiWXjou05GyuREZa"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_random_password_and_salt_long_pw(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ (b"^Q&\"]A`%/A(BVGt>QaX0M-#<Q148&f", 4, b"vrRP5vQxyD4LrqiLd/oWRO", b"$2a$04$vrRP5vQxyD4LrqiLd/oWROgrrGINsw3gb4Ga5x2sn01jNmiLVECl6"),
+ (b"nZa!rRf\\U;OL;R?>1ghq_+\":Y0CRmY", 5, b"YuQvhokOGVnevctykUYpKu", b"$2a$05$YuQvhokOGVnevctykUYpKutZD2pWeGGYn3auyLOasguMY3/0BbIyq"),
+ (b"F%uN/j>[GuB7-jB'_Yj!Tnb7Y!u^6)", 6, b"5L3vpQ0tG9O7k5gQ8nAHAe", b"$2a$06$5L3vpQ0tG9O7k5gQ8nAHAe9xxQiOcOLh8LGcI0PLWhIznsDt.S.C6"),
+ (b"Z>BobP32ub\"Cfe*Q<<WUq3rc=[GJr-", 7, b"hp8IdLueqE6qFh1zYycUZ.", b"$2a$07$hp8IdLueqE6qFh1zYycUZ.twmUH8eSTPQAEpdNXKMlwms9XfKqfea"),
+ (b"Ik&8N['7*[1aCc1lOm8\\jWeD*H$eZM", 8, b"2ANDTYCB9m7vf0Prh7rSru", b"$2a$08$2ANDTYCB9m7vf0Prh7rSrupqpO3jJOkIz2oW/QHB4lCmK7qMytGV6"),
+ (b"O)=%3[E$*q+>-q-=tRSjOBh8\\mLNW.", 9, b"nArqOfdCsD9kIbVnAixnwe", b"$2a$09$nArqOfdCsD9kIbVnAixnwe6s8QvyPYWtQBpEXKir2OJF9/oNBsEFe"),
+ (b"/MH51`!BP&0tj3%YCA;Xk%e3S`o\\EI", 10, b"ePiAc.s.yoBi3B6p1iQUCe", b"$2a$10$ePiAc.s.yoBi3B6p1iQUCezn3mraLwpVJ5XGelVyYFKyp5FZn/y.u"),
+ (b"ptAP\"mcg6oH.\";c0U2_oll.OKi<!ku", 12, b"aroG/pwwPj1tU5fl9a9pkO", b"$2a$12$aroG/pwwPj1tU5fl9a9pkO4rydAmkXRj/LqfHZOSnR6LGAZ.z.jwa"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_same_password_and_random_salt(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ (b"Q/A:k3DP;X@=<0\"hg&9c", 4, b"wbgDTvLMtyjQlNK7fjqwyO", b"$2a$04$wbgDTvLMtyjQlNK7fjqwyOakBoACQuYh11.VsKNarF4xUIOBWgD6S"),
+ (b"Q/A:k3DP;X@=<0\"hg&9c", 5, b"zbAaOmloOhxiKItjznRqru", b"$2a$05$zbAaOmloOhxiKItjznRqrunRqHlu3MAa7pMGv26Rr3WwyfGcwoRm6"),
+ (b"Q/A:k3DP;X@=<0\"hg&9c", 6, b"aOK0bWUvLI0qLkc3ti5jyu", b"$2a$06$aOK0bWUvLI0qLkc3ti5jyuAIQoqRzuqoK09kQqQ6Ou/YKDhW50/qa"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_same_password_and_salt_increasing_cost_factor(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ (b"o<&+X'F4AQ8H,LU,N`&r", 4, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$04$BK5u.QHk1Driey7bvnFTH.3smGwxd91PtoK2GxH5nZ7pcBsYX4lMq"),
+ (b"o<&+X'F4AQ8H,LU,N`&r", 5, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$05$BK5u.QHk1Driey7bvnFTH.t5P.jZvFBMzDB1IY4PwkkRPOyVbEtFG"),
+ (b"o<&+X'F4AQ8H,LU,N`&r", 6, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$06$BK5u.QHk1Driey7bvnFTH.6Ea1Z5db2p25CPXZbxb/3OyKQagg3pa"),
+ (b"o<&+X'F4AQ8H,LU,N`&r", 7, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$07$BK5u.QHk1Driey7bvnFTH.sruuQi8Lhv/0LWKDvNp3AGFk7ltdkm6"),
+ (b"o<&+X'F4AQ8H,LU,N`&r", 8, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$08$BK5u.QHk1Driey7bvnFTH.IE7KsaUzc4m7gzAMlyUPUeiYyACWe0q"),
+ (b"o<&+X'F4AQ8H,LU,N`&r", 9, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$09$BK5u.QHk1Driey7bvnFTH.1v4Xj1dwkp44QNg0cVAoQt4FQMMrvnS"),
+ (b"o<&+X'F4AQ8H,LU,N`&r", 10, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$10$BK5u.QHk1Driey7bvnFTH.ESINe9YntUMcVgFDfkC.Vbhc9vMhNX2"),
+ (b"o<&+X'F4AQ8H,LU,N`&r", 12, b"BK5u.QHk1Driey7bvnFTH.", b"$2a$12$BK5u.QHk1Driey7bvnFTH.QM1/nnGe/f5cTzb6XTTi/vMzcAnycqG"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_long_passwords(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ (b"g*3Q45=\"8NNgpT&mbMJ$Omfr.#ZeW?FP=CE$#roHd?97uL0F-]`?u73c\"\\[.\"*)qU34@VG",
+ 4, b"T2XJ5MOWvHQZRijl8LIKkO", b"$2a$04$T2XJ5MOWvHQZRijl8LIKkOQKIyX75KBfuLsuRYOJz5OjwBNF2lM8a"),
+ (b"\\M+*8;&QE=Ll[>5?Ui\"^ai#iQH7ZFtNMfs3AROnIncE9\"BNNoEgO[[*Yk8;RQ(#S,;I+aT",
+ 5, b"wgkOlGNXIVE2fWkT3gyRoO", b"$2a$05$wgkOlGNXIVE2fWkT3gyRoOqWi4gbi1Wv2Q2Jx3xVs3apl1w.Wtj8C"),
+ (b"M.E1=dt<.L0Q&p;94NfGm_Oo23+Kpl@M5?WIAL.[@/:'S)W96G8N^AWb7_smmC]>7#fGoB",
+ 6, b"W9zTCl35nEvUukhhFzkKMe", b"$2a$06$W9zTCl35nEvUukhhFzkKMekjT9/pj7M0lihRVEZrX3m8/SBNZRX7i"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_increasing_password_length(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ (b"a", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.l4WvgHIVg17ZawDIrDM2IjlE64GDNQS"),
+ (b"aa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.AyUxBk.ThHlsLvRTH7IqcG7yVHJ3SXq"),
+ (b"aaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.BxOVac5xPB6XFdRc/ZrzM9FgZkqmvbW"),
+ (b"aaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.Qbr209bpCtfl5hN7UQlG/L4xiD3AKau"),
+ (b"aaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.oWszihPjDZI0ypReKsaDOW1jBl7oOii"),
+ (b"aaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ./k.Xxn9YiqtV/sxh3EHbnOHd0Qsq27K"),
+ (b"aaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.PYJqRFQbgRbIjMd5VNKmdKS4sBVOyDe"),
+ (b"aaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ..VMYfzaw1wP/SGxowpLeGf13fxCCt.q"),
+ (b"aaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.5B0p054nO5WgAD1n04XslDY/bqY9RJi"),
+ (b"aaaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.INBTgqm7sdlBJDg.J5mLMSRK25ri04y"),
+ (b"aaaaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.s3y7CdFD0OR5p6rsZw/eZ.Dla40KLfm"),
+ (b"aaaaaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.Jx742Djra6Q7PqJWnTAS.85c28g.Siq"),
+ (b"aaaaaaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.oKMXW3EZcPHcUV0ib5vDBnh9HojXnLu"),
+ (b"aaaaaaaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.w6nIjWpDPNSH5pZUvLjC1q25ONEQpeS"),
+ (b"aaaaaaaaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.k1b2/r9A/hxdwKEKurg6OCn4MwMdiGq"),
+ (b"aaaaaaaaaaaaaaaa", 4, b"5DCebwootqWMCp59ISrMJ.", b"$2a$04$5DCebwootqWMCp59ISrMJ.3prCNHVX1Ws.7Hm2bJxFUnQOX9f7DFa"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_non_ascii_characters(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ ("àèìòùÀÈÌÒÙáéíóúýÁÉÍÓÚÝðÐ", 4, b"D3qS2aoTVyqM7z8v8crLm.", b"$2a$04$D3qS2aoTVyqM7z8v8crLm.3nKt4CzBZJbyFB.ZebmfCvRw7BGs.Xm"),
+ ("àèìòùÀÈÌÒÙáéíóúýÁÉÍÓÚÝðÐ", 5, b"VA1FujiOCMPkUHQ8kF7IaO", b"$2a$05$VA1FujiOCMPkUHQ8kF7IaOg7NGaNvpxwWzSluQutxEVmbZItRTsAa"),
+ ("àèìòùÀÈÌÒÙáéíóúýÁÉÍÓÚÝðÐ", 6, b"TXiaNrPeBSz5ugiQlehRt.", b"$2a$06$TXiaNrPeBSz5ugiQlehRt.gwpeDQnXWteQL4z2FulouBr6G7D9KUi"),
+ ("âêîôûÂÊÎÔÛãñõÃÑÕäëïöüÿ", 4, b"YTn1Qlvps8e1odqMn6G5x.", b"$2a$04$YTn1Qlvps8e1odqMn6G5x.85pqKql6w773EZJAExk7/BatYAI4tyO"),
+ ("âêîôûÂÊÎÔÛãñõÃÑÕäëïöüÿ", 5, b"C.8k5vJKD2NtfrRI9o17DO", b"$2a$05$C.8k5vJKD2NtfrRI9o17DOfIW0XnwItA529vJnh2jzYTb1QdoY0py"),
+ ("âêîôûÂÊÎÔÛãñõÃÑÕäëïöüÿ", 6, b"xqfRPj3RYAgwurrhcA6uRO", b"$2a$06$xqfRPj3RYAgwurrhcA6uROtGlXDp/U6/gkoDYHwlubtcVcNft5.vW"),
+ ("ÄËÏÖÜŸåÅæÆœŒßçÇøØ¢¿¡€", 4, b"y8vGgMmr9EdyxP9rmMKjH.", b"$2a$04$y8vGgMmr9EdyxP9rmMKjH.wv2y3r7yRD79gykQtmb3N3zrwjKsyay"),
+ ("ÄËÏÖÜŸåÅæÆœŒßçÇøØ¢¿¡€", 5, b"iYH4XIKAOOm/xPQs7xKP1u", b"$2a$05$iYH4XIKAOOm/xPQs7xKP1upD0cWyMn3Jf0ZWiizXbEkVpS41K1dcO"),
+ ("ÄËÏÖÜŸåÅæÆœŒßçÇøØ¢¿¡€", 6, b"wCOob.D0VV8twafNDB2ape", b"$2a$06$wCOob.D0VV8twafNDB2apegiGD5nqF6Y1e6K95q6Y.R8C4QGd265q"),
+ ("ΔημοσιεύθηκεστηνΕφημερίδατης", 4, b"E5SQtS6P4568MDXW7cyUp.", b"$2a$04$E5SQtS6P4568MDXW7cyUp.18wfDisKZBxifnPZjAI1d/KTYMfHPYO"),
+ ("АБбВвГгДдЕеЁёЖжЗзИиЙйКкЛлМмН", 4, b"03e26gQFHhQwRNf81/ww9.", b"$2a$04$03e26gQFHhQwRNf81/ww9.p1UbrNwxpzWjLuT.zpTLH4t/w5WhAhC"),
+ ("нОоПпРрСсТтУуФфХхЦцЧчШшЩщЪъЫыЬьЭэЮю", 4, b"PHNoJwpXCfe32nUtLv2Upu", b"$2a$04$PHNoJwpXCfe32nUtLv2UpuhJXOzd4k7IdFwnEpYwfJVCZ/f/.8Pje"),
+ ("電电電島岛島兔兔兎龜龟亀國国国區区区", 4, b"wU4/0i1TmNl2u.1jIwBX.u", b"$2a$04$wU4/0i1TmNl2u.1jIwBX.uZUaOL3Rc5ID7nlQRloQh6q5wwhV/zLW"),
+ ("诶比伊艾弗豆贝尔维吾艾尺开艾丝维贼德", 4, b"P4kreGLhCd26d4WIy7DJXu", b"$2a$04$P4kreGLhCd26d4WIy7DJXusPkhxLvBouzV6OXkL5EB0jux0osjsry"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+ def test_special_case_salt(self):
+ # password, cost, salt, bcrypt hash
+ tvs = [
+ ("-O_=*N!2JP", 4, b"......................", b"$2a$04$......................JjuKLOX9OOwo5PceZZXSkaLDvdmgb82"),
+ ("7B[$Q<4b>U", 5, b"......................", b"$2a$05$......................DRiedDQZRL3xq5A5FL8y7/6NM8a2Y5W"),
+ (">d5-I_8^.h", 6, b"......................", b"$2a$06$......................5Mq1Ng8jgDY.uHNU4h5p/x6BedzNH2W"),
+ (")V`/UM/]1t", 4, b".OC/.OC/.OC/.OC/.OC/.O", b"$2a$04$.OC/.OC/.OC/.OC/.OC/.OQIvKRDAam.Hm5/IaV/.hc7P8gwwIbmi"),
+ (":@t2.bWuH]", 5, b".OC/.OC/.OC/.OC/.OC/.O", b"$2a$05$.OC/.OC/.OC/.OC/.OC/.ONDbUvdOchUiKmQORX6BlkPofa/QxW9e"),
+ ("b(#KljF5s\"", 6, b".OC/.OC/.OC/.OC/.OC/.O", b"$2a$06$.OC/.OC/.OC/.OC/.OC/.OHfTd9e7svOu34vi1PCvOcAEq07ST7.K"),
+ ("@3YaJ^Xs]*", 4, b"eGA.eGA.eGA.eGA.eGA.e.", b"$2a$04$eGA.eGA.eGA.eGA.eGA.e.stcmvh.R70m.0jbfSFVxlONdj1iws0C"),
+ ("'\"5\\!k*C(p", 5, b"eGA.eGA.eGA.eGA.eGA.e.", b"$2a$05$eGA.eGA.eGA.eGA.eGA.e.vR37mVSbfdHwu.F0sNMvgn8oruQRghy"),
+ ("edEu7C?$'W", 6, b"eGA.eGA.eGA.eGA.eGA.e.", b"$2a$06$eGA.eGA.eGA.eGA.eGA.e.tSq0FN8MWHQXJXNFnHTPQKtA.n2a..G"),
+ ("N7dHmg\\PI^", 4, b"999999999999999999999u", b"$2a$04$999999999999999999999uCZfA/pLrlyngNDMq89r1uUk.bQ9icOu"),
+ ("\"eJuHh!)7*", 5, b"999999999999999999999u", b"$2a$05$999999999999999999999uj8Pfx.ufrJFAoWFLjapYBS5vVEQQ/hK"),
+ ("ZeDRJ:_tu:", 6, b"999999999999999999999u", b"$2a$06$999999999999999999999u6RB0P9UmbdbQgjoQFEJsrvrKe.BoU6q"),
+ ]
+
+ for (idx, (password, cost, salt64, result)) in enumerate(tvs):
+ x = bcrypt(password, cost, salt=_bcrypt_decode(salt64))
+ self.assertEqual(x, result)
+ bcrypt_check(password, result)
+
+
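The bcrypt vectors above all use the modular-crypt encoding $2a$<cost>$<22-character salt><31-character checksum>, and the loops exercise the same two public functions an application would call. A minimal usage sketch (illustrative only, not part of the patch; it assumes the same Crypto.Protocol.KDF bcrypt/bcrypt_check API the tests import):

from Crypto.Protocol.KDF import bcrypt, bcrypt_check
from Crypto.Random import get_random_bytes

# Hash a password (at most 72 bytes) with a random 16-byte salt and cost factor 12.
password = b"correct horse battery staple"
bcrypt_hash = bcrypt(password, 12, salt=get_random_bytes(16))

# Verification recomputes the hash; a mismatch raises ValueError.
bcrypt_check(password, bcrypt_hash)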
+class TestVectorsHKDFWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._id = "None"
+
+ def add_tests(self, filename):
+
+ def filter_algo(root):
+ algo_name = root['algorithm']
+ if algo_name == "HKDF-SHA-1":
+ return SHA1
+ elif algo_name == "HKDF-SHA-256":
+ return SHA256
+ elif algo_name == "HKDF-SHA-384":
+ return SHA384
+ elif algo_name == "HKDF-SHA-512":
+ return SHA512
+ else:
+ raise ValueError("Unknown algorithm " + algo_name)
+
+ def filter_size(unit):
+ return int(unit['size'])
+
+ result = load_test_vectors_wycheproof(("Protocol", "wycheproof"),
+ filename,
+ "Wycheproof HKDF (%s)" % filename,
+ root_tag={'hash_module': filter_algo},
+ unit_tag={'size': filter_size})
+ self.tv += result
+
+ def setUp(self):
+ self.tv = []
+ self.add_tests("hkdf_sha1_test.json")
+ self.add_tests("hkdf_sha256_test.json")
+ self.add_tests("hkdf_sha384_test.json")
+ self.add_tests("hkdf_sha512_test.json")
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_verify(self, tv):
+ self._id = "Wycheproof HKDF Test #%d (%s, %s)" % (tv.id, tv.comment, tv.filename)
+
+ try:
+ key = HKDF(tv.ikm, tv.size, tv.salt, tv.hash_module, 1, tv.info)
+ except ValueError:
+ assert not tv.valid
+ else:
+ if key != tv.okm:
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.warn(tv)
+
+ def runTest(self):
+ for tv in self.tv:
+ self.test_verify(tv)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ if not config.get('slow_tests'):
+ PBKDF2_Tests._testData = PBKDF2_Tests._testData[:3]
+ scrypt_Tests.data = scrypt_Tests.data[:3]
+
+ tests = []
+ tests += list_test_cases(PBKDF1_Tests)
+ tests += list_test_cases(PBKDF2_Tests)
+ tests += list_test_cases(S2V_Tests)
+ tests += list_test_cases(HKDF_Tests)
+ tests += [TestVectorsHKDFWycheproof(wycheproof_warnings)]
+ tests += list_test_cases(scrypt_Tests)
+ tests += list_test_cases(bcrypt_Tests)
+
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Protocol/test_SecretSharing.py b/lib/Crypto/SelfTest/Protocol/test_SecretSharing.py
new file mode 100644
index 0000000..0ea58a5
--- /dev/null
+++ b/lib/Crypto/SelfTest/Protocol/test_SecretSharing.py
@@ -0,0 +1,267 @@
+#
+# SelfTest/Protocol/test_secret_sharing.py: Self-test for secret sharing protocols
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from unittest import main, TestCase, TestSuite
+from binascii import unhexlify, hexlify
+
+from Crypto.Util.py3compat import *
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Protocol.SecretSharing import Shamir, _Element, \
+ _mult_gf2, _div_gf2
+
+class GF2_Tests(TestCase):
+
+ def test_mult_gf2(self):
+ # Prove mult by zero
+ x = _mult_gf2(0,0)
+ self.assertEqual(x, 0)
+
+ # Prove mult by unity
+ x = _mult_gf2(34, 1)
+ self.assertEqual(x, 34)
+
+ z = 3 # (x+1)
+ y = _mult_gf2(z, z)
+ self.assertEqual(y, 5) # (x+1)^2 = x^2 + 1
+ y = _mult_gf2(y, z)
+ self.assertEqual(y, 15) # (x+1)^3 = x^3 + x^2 + x + 1
+ y = _mult_gf2(y, z)
+ self.assertEqual(y, 17) # (x+1)^4 = x^4 + 1
+
+ # Prove linearity works
+ comps = [1, 4, 128, 2**34]
+ sum_comps = 1+4+128+2**34
+ y = 908
+ z = _mult_gf2(sum_comps, y)
+ w = 0
+ for x in comps:
+ w ^= _mult_gf2(x, y)
+ self.assertEqual(w, z)
+
+ def test_div_gf2(self):
+ from Crypto.Util.number import size as deg
+
+ x, y = _div_gf2(567, 7)
+ self.assertTrue(deg(y) < deg(7))
+
+ w = _mult_gf2(x, 7) ^ y
+ self.assertEqual(567, w)
+
+ x, y = _div_gf2(7, 567)
+ self.assertEqual(x, 0)
+ self.assertEqual(y, 7)
+
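The constants asserted in test_mult_gf2 above follow from multiplying polynomials over GF(2), where coefficients are reduced mod 2: (x+1)^2 = x^2 + 2x + 1 collapses to x^2 + 1, so 3*3 yields 5, and so on. A self-contained illustration (not part of the patch, and independent of the private _mult_gf2 helper):

# Carry-less multiplication of two non-negative integers, i.e. multiplication
# of the GF(2) polynomials that their bit patterns represent.
def clmul(a, b):
    result = 0
    while b:
        if b & 1:
            result ^= a
        a <<= 1
        b >>= 1
    return result

assert clmul(3, 3) == 5     # (x+1)^2 = x^2 + 1
assert clmul(5, 3) == 15    # (x+1)^3 = x^3 + x^2 + x + 1
assert clmul(15, 3) == 17   # (x+1)^4 = x^4 + 1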
+class Element_Tests(TestCase):
+
+ def test1(self):
+ # Test encodings
+ e = _Element(256)
+ self.assertEqual(int(e), 256)
+ self.assertEqual(e.encode(), bchr(0)*14 + b("\x01\x00"))
+
+ e = _Element(bchr(0)*14 + b("\x01\x10"))
+ self.assertEqual(int(e), 0x110)
+ self.assertEqual(e.encode(), bchr(0)*14 + b("\x01\x10"))
+
+ # Only 16-byte strings are a valid encoding
+ self.assertRaises(ValueError, _Element, bchr(0))
+
+ def test2(self):
+ # Test addition
+ e = _Element(0x10)
+ f = _Element(0x0A)
+ self.assertEqual(int(e+f), 0x1A)
+
+ def test3(self):
+ # Test multiplication
+ zero = _Element(0)
+ one = _Element(1)
+ two = _Element(2)
+
+ x = _Element(6) * zero
+ self.assertEqual(int(x), 0)
+
+ x = _Element(6) * one
+ self.assertEqual(int(x), 6)
+
+ x = _Element(2**127) * two
+ self.assertEqual(int(x), 1 + 2 + 4 + 128)
+
+ def test4(self):
+ # Test inversion
+ one = _Element(1)
+
+ x = one.inverse()
+ self.assertEqual(int(x), 1)
+
+ x = _Element(82323923)
+ y = x.inverse()
+ self.assertEqual(int(x * y), 1)
+
+class Shamir_Tests(TestCase):
+
+ def test1(self):
+ # Test splitting
+ shares = Shamir.split(2, 3, bchr(90)*16)
+ self.assertEqual(len(shares), 3)
+ for index in range(3):
+ self.assertEqual(shares[index][0], index+1)
+ self.assertEqual(len(shares[index][1]), 16)
+
+ def test2(self):
+ # Test recombine
+ from itertools import permutations
+
+ test_vectors = (
+ (2, "d9fe73909bae28b3757854c0af7ad405",
+ "1-594ae8964294174d95c33756d2504170",
+ "2-d897459d29da574eb40e93ec552ffe6e",
+ "3-5823de9bf0e068b054b5f07a28056b1b",
+ "4-db2c1f8bff46d748f795da995bd080cb"),
+ (2, "bf4f902d9a7efafd1f3ffd9291fd5de9",
+ "1-557bd3b0748064b533469722d1cc7935",
+ "2-6b2717164783c66d47cd28f2119f14d0",
+ "3-8113548ba97d58256bb4424251ae300c",
+ "4-179e9e5a218483ddaeda57539139cf04"),
+ (3, "ec96aa5c14c9faa699354cf1da74e904",
+ "1-64579fbf1908d66f7239bf6e2b4e41e1",
+ "2-6cd9428df8017b52322561e8c672ae3e",
+ "3-e418776ef5c0579bd9299277374806dd",
+ "4-ab3f77a0107398d23b323e581bb43f5d",
+ "5-23fe42431db2b41bd03ecdc7ea8e97ac"),
+ (3, "44cf249b68b80fcdc27b47be60c2c145",
+ "1-d6515a3905cd755119b86e311c801e31",
+ "2-16693d9ac9f10c254036ced5f8917fa3",
+ "3-84f74338a48476b99bf5e75a84d3a0d1",
+ "4-3fe8878dc4a5d35811cf3cbcd33dbe52",
+ "5-ad76f92fa9d0a9c4ca0c1533af7f6132"),
+ (5, "5398717c982db935d968eebe53a47f5a",
+ "1-be7be2dd4c068e7ef576aaa1b1c11b01",
+ "2-f821f5848441cb98b3eb467e2733ee21",
+ "3-25ee52f53e203f6e29a0297b5ab486b5",
+ "4-fc9fb58ef74dab947fbf9acd9d5d83cd",
+ "5-b1949cce46d81552e65f248d3f74cc5c",
+ "6-d64797f59977c4d4a7956ad916da7699",
+ "7-ab608a6546a8b9af8820ff832b1135c7"),
+ (5, "4a78db90fbf35da5545d2fb728e87596",
+ "1-08daf9a25d8aa184cfbf02b30a0ed6a0",
+ "2-dda28261e36f0b14168c2cf153fb734e",
+ "3-e9fdec5505d674a57f9836c417c1ecaa",
+ "4-4dce5636ae06dee42d2c82e65f06c735",
+ "5-3963dc118afc2ba798fa1d452b28ef00",
+ "6-6dfe6ff5b09e94d2f84c382b12f42424",
+ "7-6faea9d4d4a4e201bf6c90b9000630c3"),
+ (10, "eccbf6d66d680b49b073c4f1ddf804aa",
+ "01-7d8ac32fe4ae209ead1f3220fda34466",
+ "02-f9144e76988aad647d2e61353a6e96d5",
+ "03-b14c3b80179203363922d60760271c98",
+ "04-770bb2a8c28f6cee89e00f4d5cc7f861",
+ "05-6e3d7073ea368334ef67467871c66799",
+ "06-248792bc74a98ce024477c13c8fb5f8d",
+ "07-fcea4640d2db820c0604851e293d2487",
+ "08-2776c36fb714bb1f8525a0be36fc7dba",
+ "09-6ee7ac8be773e473a4bf75ee5f065762",
+ "10-33657fc073354cf91d4a68c735aacfc8",
+ "11-7645c65094a5868bf225c516fdee2d0c",
+ "12-840485aacb8226631ecd9c70e3018086"),
+ (10, "377e63bdbb5f7d4dc58a483d035212bb",
+ "01-32c53260103be431c843b1a633afe3bd",
+ "02-0107eb16cb8695084d452d2cc50bc7d6",
+ "03-df1e5c66cd755287fb0446faccd72a06",
+ "04-361bbcd5d40797f49dfa1898652da197",
+ "05-160d3ad1512f7dec7fd9344aed318591",
+ "06-659af6d95df4f25beca4fb9bfee3b7e8",
+ "07-37f3b208977bad50b3724566b72bfa9d",
+ "08-6c1de2dfc69c2986142c26a8248eb316",
+ "09-5e19220837a396bd4bc8cd685ff314c3",
+ "10-86e7b864fb0f3d628e46d50c1ba92f1c",
+ "11-065d0082c80b1aea18f4abe0c49df72e",
+ "12-84a09430c1d20ea9f388f3123c3733a3"),
+ )
+
+ def get_share(p):
+ pos = p.find('-')
+ return int(p[:pos]), unhexlify(p[pos + 1:])
+
+ for tv in test_vectors:
+ k = tv[0]
+ secret = unhexlify(tv[1])
+ max_perms = 10
+ for perm, shares_idx in enumerate(permutations(range(2, len(tv)), k)):
+ if perm > max_perms:
+ break
+ shares = [ get_share(tv[x]) for x in shares_idx ]
+ result = Shamir.combine(shares, True)
+ self.assertEqual(secret, result)
+
+ def test3(self):
+ # Loopback split/recombine
+ secret = unhexlify(b("000102030405060708090a0b0c0d0e0f"))
+
+ shares = Shamir.split(2, 3, secret)
+
+ secret2 = Shamir.combine(shares[:2])
+ self.assertEqual(secret, secret2)
+
+ secret3 = Shamir.combine([ shares[0], shares[2] ])
+ self.assertEqual(secret, secret3)
+
+ def test4(self):
+ # Loopback split/recombine (SSSS)
+ secret = unhexlify(b("000102030405060708090a0b0c0d0e0f"))
+
+ shares = Shamir.split(2, 3, secret, ssss=True)
+
+ secret2 = Shamir.combine(shares[:2], ssss=True)
+ self.assertEqual(secret, secret2)
+
+ def test5(self):
+ # Detect duplicate shares
+ secret = unhexlify(b("000102030405060708090a0b0c0d0e0f"))
+
+ shares = Shamir.split(2, 3, secret)
+ self.assertRaises(ValueError, Shamir.combine, (shares[0], shares[0]))
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(GF2_Tests)
+ tests += list_test_cases(Element_Tests)
+ tests += list_test_cases(Shamir_Tests)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: TestSuite(get_tests())
+ main(defaultTest='suite')
+
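The loopback tests above cover the entire public surface of the Shamir module. A compact standalone sketch of the same flow (illustrative only; it assumes, as the tests do, a 16-byte secret and a 2-of-3 split):

from Crypto.Protocol.SecretSharing import Shamir
from Crypto.Random import get_random_bytes

secret = get_random_bytes(16)         # secrets must be exactly 16 bytes
shares = Shamir.split(2, 3, secret)   # 3 shares, any 2 of which suffice

# Each share is an (index, 16-byte chunk) tuple; combining any two recovers the secret.
assert Shamir.combine(shares[:2]) == secret
assert Shamir.combine([shares[0], shares[2]]) == secret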
diff --git a/lib/Crypto/SelfTest/Protocol/test_rfc1751.py b/lib/Crypto/SelfTest/Protocol/test_rfc1751.py
new file mode 100644
index 0000000..0878cc5
--- /dev/null
+++ b/lib/Crypto/SelfTest/Protocol/test_rfc1751.py
@@ -0,0 +1,62 @@
+#
+# Test script for Crypto.Util.RFC1751.
+#
+# Part of the Python Cryptography Toolkit
+#
+# Written by Andrew Kuchling and others
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+__revision__ = "$Id$"
+
+import binascii
+import unittest
+from Crypto.Util import RFC1751
+from Crypto.Util.py3compat import *
+
+test_data = [('EB33F77EE73D4053', 'TIDE ITCH SLOW REIN RULE MOT'),
+ ('CCAC2AED591056BE4F90FD441C534766',
+ 'RASH BUSH MILK LOOK BAD BRIM AVID GAFF BAIT ROT POD LOVE'),
+ ('EFF81F9BFBC65350920CDD7416DE8009',
+ 'TROD MUTE TAIL WARM CHAR KONG HAAG CITY BORE O TEAL AWL')
+ ]
+
+class RFC1751Test_k2e (unittest.TestCase):
+
+ def runTest (self):
+ "Check converting keys to English"
+ for key, words in test_data:
+ key=binascii.a2b_hex(b(key))
+ self.assertEqual(RFC1751.key_to_english(key), words)
+
+class RFC1751Test_e2k (unittest.TestCase):
+
+ def runTest (self):
+ "Check converting English strings to keys"
+ for key, words in test_data:
+ key=binascii.a2b_hex(b(key))
+ self.assertEqual(RFC1751.english_to_key(words), key)
+
+# class RFC1751Test
+
+def get_tests(config={}):
+ return [RFC1751Test_k2e(), RFC1751Test_e2k()]
+
+if __name__ == "__main__":
+ unittest.main()
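For reference, the RFC 1751 conversion being tested is a plain round trip between a key whose length is a multiple of 8 bytes and a sentence of short English words. A minimal sketch using the first vector from the table above:

from binascii import unhexlify
from Crypto.Util import RFC1751

key = unhexlify("EB33F77EE73D4053")
words = RFC1751.key_to_english(key)
assert words == "TIDE ITCH SLOW REIN RULE MOT"
assert RFC1751.english_to_key(words) == key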
diff --git a/lib/Crypto/SelfTest/PublicKey/__init__.py b/lib/Crypto/SelfTest/PublicKey/__init__.py
new file mode 100644
index 0000000..437d3e4
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/__init__.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/PublicKey/__init__.py: Self-test for public key crypto
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test for public-key crypto"""
+
+import unittest
+from Crypto.SelfTest.PublicKey import (test_DSA, test_RSA,
+ test_ECC_NIST, test_ECC_25519, test_ECC_448,
+ test_import_DSA, test_import_RSA,
+ test_import_ECC, test_ElGamal)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += test_DSA.get_tests(config=config)
+ tests += test_RSA.get_tests(config=config)
+ tests += test_ECC_NIST.get_tests(config=config)
+ tests += test_ECC_25519.get_tests(config=config)
+ tests += test_ECC_448.get_tests(config=config)
+
+ tests += test_import_DSA.get_tests(config=config)
+ tests += test_import_RSA.get_tests(config=config)
+ tests += test_import_ECC.get_tests(config=config)
+
+ tests += test_ElGamal.get_tests(config=config)
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/PublicKey/test_DSA.py b/lib/Crypto/SelfTest/PublicKey/test_DSA.py
new file mode 100644
index 0000000..125cf6c
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_DSA.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/PublicKey/test_DSA.py: Self-test for the DSA primitive
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.PublicKey.DSA"""
+
+import os
+from Crypto.Util.py3compat import *
+
+import unittest
+from Crypto.SelfTest.st_common import list_test_cases, a2b_hex, b2a_hex
+
+def _sws(s):
+ """Remove whitespace from a text or byte string"""
+ if isinstance(s,str):
+ return "".join(s.split())
+ else:
+ return b("").join(s.split())
+
+class DSATest(unittest.TestCase):
+ # Test vector from "Appendix 5. Example of the DSA" of
+ # "Digital Signature Standard (DSS)",
+ # U.S. Department of Commerce/National Institute of Standards and Technology
+ # FIPS 186-2 (+Change Notice), 2000 January 27.
+ # http://csrc.nist.gov/publications/fips/fips186-2/fips186-2-change1.pdf
+
+ y = _sws("""19131871 d75b1612 a819f29d 78d1b0d7 346f7aa7 7bb62a85
+ 9bfd6c56 75da9d21 2d3a36ef 1672ef66 0b8c7c25 5cc0ec74
+ 858fba33 f44c0669 9630a76b 030ee333""")
+
+ g = _sws("""626d0278 39ea0a13 413163a5 5b4cb500 299d5522 956cefcb
+ 3bff10f3 99ce2c2e 71cb9de5 fa24babf 58e5b795 21925c9c
+ c42e9f6f 464b088c c572af53 e6d78802""")
+
+ p = _sws("""8df2a494 492276aa 3d25759b b06869cb eac0d83a fb8d0cf7
+ cbb8324f 0d7882e5 d0762fc5 b7210eaf c2e9adac 32ab7aac
+ 49693dfb f83724c2 ec0736ee 31c80291""")
+
+ q = _sws("""c773218c 737ec8ee 993b4f2d ed30f48e dace915f""")
+
+ x = _sws("""2070b322 3dba372f de1c0ffc 7b2e3b49 8b260614""")
+
+ k = _sws("""358dad57 1462710f 50e254cf 1a376b2b deaadfbf""")
+ k_inverse = _sws("""0d516729 8202e49b 4116ac10 4fc3f415 ae52f917""")
+ m = b2a_hex(b("abc"))
+ m_hash = _sws("""a9993e36 4706816a ba3e2571 7850c26c 9cd0d89d""")
+ r = _sws("""8bac1ab6 6410435c b7181f95 b16ab97c 92b341c0""")
+ s = _sws("""41e2345f 1f56df24 58f426d1 55b4ba2d b6dcd8c8""")
+
+ def setUp(self):
+ global DSA, Random, bytes_to_long, size
+ from Crypto.PublicKey import DSA
+ from Crypto import Random
+ from Crypto.Util.number import bytes_to_long, inverse, size
+
+ self.dsa = DSA
+
+ def test_generate_1arg(self):
+ """DSA (default implementation) generated key (1 argument)"""
+ dsaObj = self.dsa.generate(1024)
+ self._check_private_key(dsaObj)
+ pub = dsaObj.public_key()
+ self._check_public_key(pub)
+
+ def test_generate_2arg(self):
+ """DSA (default implementation) generated key (2 arguments)"""
+ dsaObj = self.dsa.generate(1024, Random.new().read)
+ self._check_private_key(dsaObj)
+ pub = dsaObj.public_key()
+ self._check_public_key(pub)
+
+ def test_construct_4tuple(self):
+ """DSA (default implementation) constructed key (4-tuple)"""
+ (y, g, p, q) = [bytes_to_long(a2b_hex(param)) for param in (self.y, self.g, self.p, self.q)]
+ dsaObj = self.dsa.construct((y, g, p, q))
+ self._test_verification(dsaObj)
+
+ def test_construct_5tuple(self):
+ """DSA (default implementation) constructed key (5-tuple)"""
+ (y, g, p, q, x) = [bytes_to_long(a2b_hex(param)) for param in (self.y, self.g, self.p, self.q, self.x)]
+ dsaObj = self.dsa.construct((y, g, p, q, x))
+ self._test_signing(dsaObj)
+ self._test_verification(dsaObj)
+
+ def test_construct_bad_key4(self):
+ (y, g, p, q) = [bytes_to_long(a2b_hex(param)) for param in (self.y, self.g, self.p, self.q)]
+ tup = (y, g, p+1, q)
+ self.assertRaises(ValueError, self.dsa.construct, tup)
+
+ tup = (y, g, p, q+1)
+ self.assertRaises(ValueError, self.dsa.construct, tup)
+
+ tup = (y, 1, p, q)
+ self.assertRaises(ValueError, self.dsa.construct, tup)
+
+ def test_construct_bad_key5(self):
+ (y, g, p, q, x) = [bytes_to_long(a2b_hex(param)) for param in (self.y, self.g, self.p, self.q, self.x)]
+ tup = (y, g, p, q, x+1)
+ self.assertRaises(ValueError, self.dsa.construct, tup)
+
+ tup = (y, g, p, q, q+10)
+ self.assertRaises(ValueError, self.dsa.construct, tup)
+
+ def _check_private_key(self, dsaObj):
+ # Check capabilities
+ self.assertEqual(1, dsaObj.has_private())
+ self.assertEqual(1, dsaObj.can_sign())
+ self.assertEqual(0, dsaObj.can_encrypt())
+
+ # Sanity check key data
+ self.assertEqual(1, dsaObj.p > dsaObj.q) # p > q
+ self.assertEqual(160, size(dsaObj.q)) # size(q) == 160 bits
+ self.assertEqual(0, (dsaObj.p - 1) % dsaObj.q) # q is a divisor of p-1
+ self.assertEqual(dsaObj.y, pow(dsaObj.g, dsaObj.x, dsaObj.p)) # y == g**x mod p
+ self.assertEqual(1, 0 < dsaObj.x < dsaObj.q) # 0 < x < q
+
+ def _check_public_key(self, dsaObj):
+ k = bytes_to_long(a2b_hex(self.k))
+ m_hash = bytes_to_long(a2b_hex(self.m_hash))
+
+ # Check capabilities
+ self.assertEqual(0, dsaObj.has_private())
+ self.assertEqual(1, dsaObj.can_sign())
+ self.assertEqual(0, dsaObj.can_encrypt())
+
+ # Check that private parameters are all missing
+ self.assertEqual(0, hasattr(dsaObj, 'x'))
+
+ # Sanity check key data
+ self.assertEqual(1, dsaObj.p > dsaObj.q) # p > q
+ self.assertEqual(160, size(dsaObj.q)) # size(q) == 160 bits
+ self.assertEqual(0, (dsaObj.p - 1) % dsaObj.q) # q is a divisor of p-1
+
+ # Public-only key objects should raise an error when .sign() is called
+ self.assertRaises(TypeError, dsaObj._sign, m_hash, k)
+
+ # Check __eq__ and __ne__
+ self.assertTrue(dsaObj.public_key() == dsaObj.public_key())
+ self.assertFalse(dsaObj.public_key() != dsaObj.public_key())
+
+ self.assertEqual(dsaObj.public_key(), dsaObj.publickey())
+
+ def _test_signing(self, dsaObj):
+ k = bytes_to_long(a2b_hex(self.k))
+ m_hash = bytes_to_long(a2b_hex(self.m_hash))
+ r = bytes_to_long(a2b_hex(self.r))
+ s = bytes_to_long(a2b_hex(self.s))
+ (r_out, s_out) = dsaObj._sign(m_hash, k)
+ self.assertEqual((r, s), (r_out, s_out))
+
+ def _test_verification(self, dsaObj):
+ m_hash = bytes_to_long(a2b_hex(self.m_hash))
+ r = bytes_to_long(a2b_hex(self.r))
+ s = bytes_to_long(a2b_hex(self.s))
+ self.assertTrue(dsaObj._verify(m_hash, (r, s)))
+ self.assertFalse(dsaObj._verify(m_hash + 1, (r, s)))
+
+ def test_repr(self):
+ (y, g, p, q) = [bytes_to_long(a2b_hex(param)) for param in (self.y, self.g, self.p, self.q)]
+ dsaObj = self.dsa.construct((y, g, p, q))
+ repr(dsaObj)
+
+
+class DSADomainTest(unittest.TestCase):
+
+ def test_domain1(self):
+ """Verify we can generate new keys in a given domain"""
+ dsa_key_1 = DSA.generate(1024)
+ domain_params = dsa_key_1.domain()
+
+ dsa_key_2 = DSA.generate(1024, domain=domain_params)
+ self.assertEqual(dsa_key_1.p, dsa_key_2.p)
+ self.assertEqual(dsa_key_1.q, dsa_key_2.q)
+ self.assertEqual(dsa_key_1.g, dsa_key_2.g)
+
+ self.assertEqual(dsa_key_1.domain(), dsa_key_2.domain())
+
+ def _get_weak_domain(self):
+
+ from Crypto.Math.Numbers import Integer
+ from Crypto.Math import Primality
+
+ p = Integer(4)
+ while p.size_in_bits() != 1024 or Primality.test_probable_prime(p) != Primality.PROBABLY_PRIME:
+ q1 = Integer.random(exact_bits=80)
+ q2 = Integer.random(exact_bits=80)
+ q = q1 * q2
+ z = Integer.random(exact_bits=1024-160)
+ p = z * q + 1
+
+ h = Integer(2)
+ g = 1
+ while g == 1:
+ g = pow(h, z, p)
+ h += 1
+
+ return (p, q, g)
+
+
+ def test_generate_error_weak_domain(self):
+ """Verify that domain parameters with composite q are rejected"""
+
+ domain_params = self._get_weak_domain()
+ self.assertRaises(ValueError, DSA.generate, 1024, domain=domain_params)
+
+
+ def test_construct_error_weak_domain(self):
+ """Verify that domain parameters with composite q are rejected"""
+
+ from Crypto.Math.Numbers import Integer
+
+ p, q, g = self._get_weak_domain()
+ y = pow(g, 89, p)
+ self.assertRaises(ValueError, DSA.construct, (y, g, p, q))
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(DSATest)
+ tests += list_test_cases(DSADomainTest)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
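The DSA tests above cover key generation, construction from (y, g, p, q[, x]) and the private _sign/_verify helpers. For completeness, a brief sketch of the public signing path, which goes through the separate Crypto.Signature.DSS front-end rather than the key object itself (illustrative only; DSS is not exercised by this test file):

from Crypto.PublicKey import DSA
from Crypto.Signature import DSS
from Crypto.Hash import SHA256

key = DSA.generate(2048)
h = SHA256.new(b"message to sign")

signature = DSS.new(key, 'fips-186-3').sign(h)

verifier = DSS.new(key.public_key(), 'fips-186-3')
verifier.verify(h, signature)   # raises ValueError if the signature is invalid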
diff --git a/lib/Crypto/SelfTest/PublicKey/test_ECC_25519.py b/lib/Crypto/SelfTest/PublicKey/test_ECC_25519.py
new file mode 100644
index 0000000..305c077
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_ECC_25519.py
@@ -0,0 +1,327 @@
+# ===================================================================
+#
+# Copyright (c) 2022, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors
+
+from Crypto.PublicKey import ECC
+from Crypto.PublicKey.ECC import EccPoint, _curves, EccKey
+
+from Crypto.Math.Numbers import Integer
+
+from Crypto.Hash import SHAKE128
+
+
+class TestEccPoint_Ed25519(unittest.TestCase):
+
+ Gxy = {"x": 15112221349535400772501151409588531511454012693041857206046113283949847762202,
+ "y": 46316835694926478169428394003475163141307993866256225615783033603165251855960}
+
+ G2xy = {"x": 24727413235106541002554574571675588834622768167397638456726423682521233608206,
+ "y": 15549675580280190176352668710449542251549572066445060580507079593062643049417}
+
+ G3xy = {"x": 46896733464454938657123544595386787789046198280132665686241321779790909858396,
+ "y": 8324843778533443976490377120369201138301417226297555316741202210403726505172}
+
+ pointG = EccPoint(Gxy['x'], Gxy['y'], curve="Ed25519")
+ pointG2 = EccPoint(G2xy['x'], G2xy['y'], curve="Ed25519")
+ pointG3 = EccPoint(G3xy['x'], G3xy['y'], curve="Ed25519")
+
+ def test_init_xy(self):
+ EccPoint(self.Gxy['x'], self.Gxy['y'], curve="Ed25519")
+
+ # Neutral point
+ pai = EccPoint(0, 1, curve="Ed25519")
+ self.assertEqual(pai.x, 0)
+ self.assertEqual(pai.y, 1)
+ self.assertEqual(pai.xy, (0, 1))
+
+ # G
+ bp = self.pointG.copy()
+ self.assertEqual(bp.x, 15112221349535400772501151409588531511454012693041857206046113283949847762202)
+ self.assertEqual(bp.y, 46316835694926478169428394003475163141307993866256225615783033603165251855960)
+ self.assertEqual(bp.xy, (bp.x, bp.y))
+
+ # 2G
+ bp2 = self.pointG2.copy()
+ self.assertEqual(bp2.x, 24727413235106541002554574571675588834622768167397638456726423682521233608206)
+ self.assertEqual(bp2.y, 15549675580280190176352668710449542251549572066445060580507079593062643049417)
+ self.assertEqual(bp2.xy, (bp2.x, bp2.y))
+
+ # 5G
+ EccPoint(x=33467004535436536005251147249499675200073690106659565782908757308821616914995,
+ y=43097193783671926753355113395909008640284023746042808659097434958891230611693,
+ curve="Ed25519")
+
+ # Catch if point is not on the curve
+ self.assertRaises(ValueError, EccPoint, 34, 35, curve="Ed25519")
+
+ def test_set(self):
+ pointW = EccPoint(0, 1, curve="Ed25519")
+ pointW.set(self.pointG)
+ self.assertEqual(pointW.x, self.pointG.x)
+ self.assertEqual(pointW.y, self.pointG.y)
+
+ def test_copy(self):
+ pointW = self.pointG.copy()
+ self.assertEqual(pointW.x, self.pointG.x)
+ self.assertEqual(pointW.y, self.pointG.y)
+
+ def test_equal(self):
+ pointH = self.pointG.copy()
+ pointI = self.pointG2.copy()
+ self.assertEqual(self.pointG, pointH)
+ self.assertNotEqual(self.pointG, pointI)
+
+ def test_pai(self):
+ pai = EccPoint(0, 1, curve="Ed25519")
+ self.assertTrue(pai.is_point_at_infinity())
+ self.assertEqual(pai, pai.point_at_infinity())
+
+ def test_negate(self):
+ negG = -self.pointG
+ sum = self.pointG + negG
+ self.assertTrue(sum.is_point_at_infinity())
+
+ def test_addition(self):
+ self.assertEqual(self.pointG + self.pointG2, self.pointG3)
+ self.assertEqual(self.pointG2 + self.pointG, self.pointG3)
+ self.assertEqual(self.pointG2 + self.pointG.point_at_infinity(), self.pointG2)
+ self.assertEqual(self.pointG.point_at_infinity() + self.pointG2, self.pointG2)
+
+ G5 = self.pointG2 + self.pointG3
+ self.assertEqual(G5.x, 33467004535436536005251147249499675200073690106659565782908757308821616914995)
+ self.assertEqual(G5.y, 43097193783671926753355113395909008640284023746042808659097434958891230611693)
+
+ def test_inplace_addition(self):
+ pointH = self.pointG.copy()
+ pointH += self.pointG
+ self.assertEqual(pointH, self.pointG2)
+ pointH += self.pointG
+ self.assertEqual(pointH, self.pointG3)
+ pointH += self.pointG.point_at_infinity()
+ self.assertEqual(pointH, self.pointG3)
+
+ def test_doubling(self):
+ pointH = self.pointG.copy()
+ pointH.double()
+ self.assertEqual(pointH.x, self.pointG2.x)
+ self.assertEqual(pointH.y, self.pointG2.y)
+
+ # 2*0
+ pai = self.pointG.point_at_infinity()
+ pointR = pai.copy()
+ pointR.double()
+ self.assertEqual(pointR, pai)
+
+ def test_scalar_multiply(self):
+ d = 0
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0)
+ self.assertEqual(pointH.y, 1)
+
+ d = 1
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, self.pointG.x)
+ self.assertEqual(pointH.y, self.pointG.y)
+
+ d = 2
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, self.pointG2.x)
+ self.assertEqual(pointH.y, self.pointG2.y)
+
+ d = 3
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, self.pointG3.x)
+ self.assertEqual(pointH.y, self.pointG3.y)
+
+ d = 4
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 14582954232372986451776170844943001818709880559417862259286374126315108956272)
+ self.assertEqual(pointH.y, 32483318716863467900234833297694612235682047836132991208333042722294373421359)
+
+ d = 5
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 33467004535436536005251147249499675200073690106659565782908757308821616914995)
+ self.assertEqual(pointH.y, 43097193783671926753355113395909008640284023746042808659097434958891230611693)
+
+ d = 10
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 43500613248243327786121022071801015118933854441360174117148262713429272820047)
+ self.assertEqual(pointH.y, 45005105423099817237495816771148012388779685712352441364231470781391834741548)
+
+ d = 20
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 46694936775300686710656303283485882876784402425210400817529601134760286812591)
+ self.assertEqual(pointH.y, 8786390172762935853260670851718824721296437982862763585171334833968259029560)
+
+ d = 255
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 36843863416400016952258312492144504209624961884991522125275155377549541182230)
+ self.assertEqual(pointH.y, 22327030283879720808995671630924669697661065034121040761798775626517750047180)
+
+ d = 256
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 42740085206947573681423002599456489563927820004573071834350074001818321593686)
+ self.assertEqual(pointH.y, 6935684722522267618220753829624209639984359598320562595061366101608187623111)
+
+ def test_sizes(self):
+ self.assertEqual(self.pointG.size_in_bits(), 255)
+ self.assertEqual(self.pointG.size_in_bytes(), 32)
+
+
+class TestEccKey_Ed25519(unittest.TestCase):
+
+ def test_private_key(self):
+ seed = unhexlify("9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60")
+ Px = 38815646466658113194383306759739515082307681141926459231621296960732224964046
+ Py = 11903303657706407974989296177215005343713679411332034699907763981919547054807
+
+ key = EccKey(curve="Ed25519", seed=seed)
+ self.assertEqual(key.seed, seed)
+ self.assertEqual(key.d, 36144925721603087658594284515452164870581325872720374094707712194495455132720)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ.x, Px)
+ self.assertEqual(key.pointQ.y, Py)
+
+ point = EccPoint(Px, Py, "ed25519")
+ key = EccKey(curve="Ed25519", seed=seed, point=point)
+ self.assertEqual(key.d, 36144925721603087658594284515452164870581325872720374094707712194495455132720)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ # Other names
+ key = EccKey(curve="ed25519", seed=seed)
+
+ # Must not accept d parameter
+ self.assertRaises(ValueError, EccKey, curve="ed25519", d=1)
+
+ def test_public_key(self):
+ point = EccPoint(_curves['ed25519'].Gx, _curves['ed25519'].Gy, curve='ed25519')
+ key = EccKey(curve="ed25519", point=point)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ def test_public_key_derived(self):
+ priv_key = EccKey(curve="ed25519", seed=b'H'*32)
+ pub_key = priv_key.public_key()
+ self.assertFalse(pub_key.has_private())
+ self.assertEqual(priv_key.pointQ, pub_key.pointQ)
+
+ def test_invalid_seed(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="ed25519", seed=b'H' * 31))
+
+ def test_equality(self):
+ private_key = ECC.construct(seed=b'H'*32, curve="Ed25519")
+ private_key2 = ECC.construct(seed=b'H'*32, curve="ed25519")
+ private_key3 = ECC.construct(seed=b'C'*32, curve="Ed25519")
+
+ public_key = private_key.public_key()
+ public_key2 = private_key2.public_key()
+ public_key3 = private_key3.public_key()
+
+ self.assertEqual(private_key, private_key2)
+ self.assertNotEqual(private_key, private_key3)
+
+ self.assertEqual(public_key, public_key2)
+ self.assertNotEqual(public_key, public_key3)
+
+ self.assertNotEqual(public_key, private_key)
+
+
+class TestEccModule_Ed25519(unittest.TestCase):
+
+ def test_generate(self):
+ key = ECC.generate(curve="Ed25519")
+ self.assertTrue(key.has_private())
+ point = EccPoint(_curves['Ed25519'].Gx, _curves['Ed25519'].Gy, curve="Ed25519") * key.d
+ self.assertEqual(key.pointQ, point)
+
+ # Always random
+ key2 = ECC.generate(curve="Ed25519")
+ self.assertNotEqual(key, key2)
+
+ # Other names
+ ECC.generate(curve="Ed25519")
+
+ # Random source
+ key1 = ECC.generate(curve="Ed25519", randfunc=SHAKE128.new().read)
+ key2 = ECC.generate(curve="Ed25519", randfunc=SHAKE128.new().read)
+ self.assertEqual(key1, key2)
+
+ def test_construct(self):
+ seed = unhexlify("9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60")
+ Px = 38815646466658113194383306759739515082307681141926459231621296960732224964046
+ Py = 11903303657706407974989296177215005343713679411332034699907763981919547054807
+ d = 36144925721603087658594284515452164870581325872720374094707712194495455132720
+ point = EccPoint(Px, Py, curve="Ed25519")
+
+ # Private key only
+ key = ECC.construct(curve="Ed25519", seed=seed)
+ self.assertEqual(key.pointQ, point)
+ self.assertTrue(key.has_private())
+
+ # Public key only
+ key = ECC.construct(curve="Ed25519", point_x=Px, point_y=Py)
+ self.assertEqual(key.pointQ, point)
+ self.assertFalse(key.has_private())
+
+ # Private and public key
+ key = ECC.construct(curve="Ed25519", seed=seed, point_x=Px, point_y=Py)
+ self.assertEqual(key.pointQ, point)
+ self.assertTrue(key.has_private())
+
+ # Other names
+ key = ECC.construct(curve="ed25519", seed=seed)
+
+ def test_negative_construct(self):
+ coord = dict(point_x=10, point_y=4)
+ coordG = dict(point_x=_curves['ed25519'].Gx, point_y=_curves['ed25519'].Gy)
+
+ self.assertRaises(ValueError, ECC.construct, curve="Ed25519", **coord)
+ self.assertRaises(ValueError, ECC.construct, curve="Ed25519", d=2, **coordG)
+ self.assertRaises(ValueError, ECC.construct, curve="Ed25519", seed=b'H'*31)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(TestEccPoint_Ed25519)
+ tests += list_test_cases(TestEccKey_Ed25519)
+ tests += list_test_cases(TestEccModule_Ed25519)
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
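The three entry points exercised above (ECC.generate, ECC.construct from a seed, and public_key) are all that application code needs for Ed25519 key handling. A short sketch (illustrative only, mirroring the calls made in the tests):

from Crypto.PublicKey import ECC

# Fresh key pair; the private part is a random 32-byte seed kept inside the object.
key = ECC.generate(curve="Ed25519")
assert key.has_private()

# Deterministic reconstruction from a stored 32-byte seed.
key2 = ECC.construct(curve="Ed25519", seed=b"H" * 32)
pub = key2.public_key()
assert not pub.has_private()
assert pub.pointQ == key2.pointQ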
diff --git a/lib/Crypto/SelfTest/PublicKey/test_ECC_448.py b/lib/Crypto/SelfTest/PublicKey/test_ECC_448.py
new file mode 100644
index 0000000..68bcaa3
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_ECC_448.py
@@ -0,0 +1,327 @@
+# ===================================================================
+#
+# Copyright (c) 2022, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors
+
+from Crypto.PublicKey import ECC
+from Crypto.PublicKey.ECC import EccPoint, _curves, EccKey
+
+from Crypto.Math.Numbers import Integer
+
+from Crypto.Hash import SHAKE128
+
+
+class TestEccPoint_Ed448(unittest.TestCase):
+
+ Gxy = {"x": 0x4f1970c66bed0ded221d15a622bf36da9e146570470f1767ea6de324a3d3a46412ae1af72ab66511433b80e18b00938e2626a82bc70cc05e,
+ "y": 0x693f46716eb6bc248876203756c9c7624bea73736ca3984087789c1e05a0c2d73ad3ff1ce67c39c4fdbd132c4ed7c8ad9808795bf230fa14}
+
+ G2xy = {"x": 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555,
+ "y": 0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed}
+
+ G3xy = {"x": 0x865886b9108af6455bd64316cb6943332241b8b8cda82c7e2ba077a4a3fcfe8daa9cbf7f6271fd6e862b769465da8575728173286ff2f8f,
+ "y": 0xe005a8dbd5125cf706cbda7ad43aa6449a4a8d952356c3b9fce43c82ec4e1d58bb3a331bdb6767f0bffa9a68fed02dafb822ac13588ed6fc}
+
+ pointG = EccPoint(Gxy['x'], Gxy['y'], curve="Ed448")
+ pointG2 = EccPoint(G2xy['x'], G2xy['y'], curve="Ed448")
+ pointG3 = EccPoint(G3xy['x'], G3xy['y'], curve="Ed448")
+
+ def test_init_xy(self):
+ EccPoint(self.Gxy['x'], self.Gxy['y'], curve="Ed448")
+
+ # Neutral point
+ pai = EccPoint(0, 1, curve="Ed448")
+ self.assertEqual(pai.x, 0)
+ self.assertEqual(pai.y, 1)
+ self.assertEqual(pai.xy, (0, 1))
+
+ # G
+ bp = self.pointG.copy()
+ self.assertEqual(bp.x, 0x4f1970c66bed0ded221d15a622bf36da9e146570470f1767ea6de324a3d3a46412ae1af72ab66511433b80e18b00938e2626a82bc70cc05e)
+ self.assertEqual(bp.y, 0x693f46716eb6bc248876203756c9c7624bea73736ca3984087789c1e05a0c2d73ad3ff1ce67c39c4fdbd132c4ed7c8ad9808795bf230fa14)
+ self.assertEqual(bp.xy, (bp.x, bp.y))
+
+ # 2G
+ bp2 = self.pointG2.copy()
+ self.assertEqual(bp2.x, 0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa955555555555555555555555555555555555555555555555555555555)
+ self.assertEqual(bp2.y, 0xae05e9634ad7048db359d6205086c2b0036ed7a035884dd7b7e36d728ad8c4b80d6565833a2a3098bbbcb2bed1cda06bdaeafbcdea9386ed)
+ self.assertEqual(bp2.xy, (bp2.x, bp2.y))
+
+ # 5G
+ EccPoint(x=0x7a9f9335a48dcb0e2ba7601eedb50def80cbcf728562ada756d761e8958812808bc0d57a920c3c96f07b2d8cefc6f950d0a99d1092030034,
+ y=0xadfd751a2517edd3b9109ce4fd580ade260ca1823ab18fced86551f7b698017127d7a4ee59d2b33c58405512881f225443b4731472f435eb,
+ curve="Ed448")
+
+ # Catch if point is not on the curve
+ self.assertRaises(ValueError, EccPoint, 34, 35, curve="Ed448")
+
+ def test_set(self):
+ pointW = EccPoint(0, 1, curve="Ed448")
+ pointW.set(self.pointG)
+ self.assertEqual(pointW.x, self.pointG.x)
+ self.assertEqual(pointW.y, self.pointG.y)
+
+ def test_copy(self):
+ pointW = self.pointG.copy()
+ self.assertEqual(pointW.x, self.pointG.x)
+ self.assertEqual(pointW.y, self.pointG.y)
+
+ def test_equal(self):
+ pointH = self.pointG.copy()
+ pointI = self.pointG2.copy()
+ self.assertEqual(self.pointG, pointH)
+ self.assertNotEqual(self.pointG, pointI)
+
+ def test_pai(self):
+ pai = EccPoint(0, 1, curve="Ed448")
+ self.assertTrue(pai.is_point_at_infinity())
+ self.assertEqual(pai, pai.point_at_infinity())
+
+ def test_negate(self):
+ negG = -self.pointG
+ sum = self.pointG + negG
+ self.assertTrue(sum.is_point_at_infinity())
+
+ def test_addition(self):
+ self.assertEqual(self.pointG + self.pointG2, self.pointG3)
+ self.assertEqual(self.pointG2 + self.pointG, self.pointG3)
+ self.assertEqual(self.pointG2 + self.pointG.point_at_infinity(), self.pointG2)
+ self.assertEqual(self.pointG.point_at_infinity() + self.pointG2, self.pointG2)
+
+ G5 = self.pointG2 + self.pointG3
+ self.assertEqual(G5.x, 0x7a9f9335a48dcb0e2ba7601eedb50def80cbcf728562ada756d761e8958812808bc0d57a920c3c96f07b2d8cefc6f950d0a99d1092030034)
+ self.assertEqual(G5.y, 0xadfd751a2517edd3b9109ce4fd580ade260ca1823ab18fced86551f7b698017127d7a4ee59d2b33c58405512881f225443b4731472f435eb)
+
+ def test_inplace_addition(self):
+ pointH = self.pointG.copy()
+ pointH += self.pointG
+ self.assertEqual(pointH, self.pointG2)
+ pointH += self.pointG
+ self.assertEqual(pointH, self.pointG3)
+ pointH += self.pointG.point_at_infinity()
+ self.assertEqual(pointH, self.pointG3)
+
+ def test_doubling(self):
+ pointH = self.pointG.copy()
+ pointH.double()
+ self.assertEqual(pointH.x, self.pointG2.x)
+ self.assertEqual(pointH.y, self.pointG2.y)
+
+ # 2*0
+ pai = self.pointG.point_at_infinity()
+ pointR = pai.copy()
+ pointR.double()
+ self.assertEqual(pointR, pai)
+
+ def test_scalar_multiply(self):
+ d = 0
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0)
+ self.assertEqual(pointH.y, 1)
+
+ d = 1
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, self.pointG.x)
+ self.assertEqual(pointH.y, self.pointG.y)
+
+ d = 2
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, self.pointG2.x)
+ self.assertEqual(pointH.y, self.pointG2.y)
+
+ d = 3
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, self.pointG3.x)
+ self.assertEqual(pointH.y, self.pointG3.y)
+
+ d = 4
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0x49dcbc5c6c0cce2c1419a17226f929ea255a09cf4e0891c693fda4be70c74cc301b7bdf1515dd8ba21aee1798949e120e2ce42ac48ba7f30)
+ self.assertEqual(pointH.y, 0xd49077e4accde527164b33a5de021b979cb7c02f0457d845c90dc3227b8a5bc1c0d8f97ea1ca9472b5d444285d0d4f5b32e236f86de51839)
+
+ d = 5
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0x7a9f9335a48dcb0e2ba7601eedb50def80cbcf728562ada756d761e8958812808bc0d57a920c3c96f07b2d8cefc6f950d0a99d1092030034)
+ self.assertEqual(pointH.y, 0xadfd751a2517edd3b9109ce4fd580ade260ca1823ab18fced86551f7b698017127d7a4ee59d2b33c58405512881f225443b4731472f435eb)
+
+ d = 10
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0x77486f9d19f6411cdd35d30d1c3235f71936452c787e5c034134d3e8172278aca61622bc805761ce3dab65118a0122d73b403165d0ed303d)
+ self.assertEqual(pointH.y, 0x4d2fea0b026be11024f1f0fe7e94e618e8ac17381ada1d1bf7ee293a68ff5d0bf93c1997dc1aabdc0c7e6381428d85b6b1954a89e4cddf67)
+
+ d = 20
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0x3c236422354600fe6763defcc1503737e4ed89e262d0de3ec1e552020f2a56fe3b9e1e012d021072598c3c2821e18268bb8fb8339c0d1216)
+ self.assertEqual(pointH.y, 0xb555b9721f630ccb05fc466de4c74d3d2781e69eca88e1b040844f04cab39fd946f91c688fa42402bb38fb9c3e61231017020b219b4396e1)
+
+ d = 255
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0xbeb7f8388b05cd9c1aa2e3c0dcf31e2b563659361826225390e7748654f627d5c36cbe627e9019936b56d15d4dad7c337c09bac64ff4197f)
+ self.assertEqual(pointH.y, 0x1e37312b2dd4e9440c43c6e7725fc4fa3d11e582d4863f1d018e28f50c0efdb1f53f9b01ada7c87fa162b1f0d72401015d57613d25f1ad53)
+
+ d = 256
+ pointH = d * self.pointG
+ self.assertEqual(pointH.x, 0xf19c34feb56730e3e2be761ac0a2a2b24853b281dda019fc35a5ab58e3696beb39609ae756b0d20fb7ccf0d79aaf5f3bca2e4fdb25bfac1c)
+ self.assertEqual(pointH.y, 0x3beb69cc9111bffcaddc61d363ce6fe5dd44da4aadce78f52e92e985d5442344ced72c4611ed0daac9f4f5661eab73d7a12d25ce8a30241e)
+
+ def test_sizes(self):
+ self.assertEqual(self.pointG.size_in_bits(), 448)
+ self.assertEqual(self.pointG.size_in_bytes(), 56)
+
+
+class TestEccKey_Ed448(unittest.TestCase):
+
+ def test_private_key(self):
+ seed = unhexlify("4adf5d37ac6785e83e99a924f92676d366a78690af59c92b6bdf14f9cdbcf26fdad478109607583d633b60078d61d51d81b7509c5433b0d4c9")
+ Px = 0x72a01eea003a35f9ac44231dc4aae2a382f351d80bf32508175b0855edcf389aa2bbf308dd961ce361a6e7c2091bc78957f6ebcf3002a617
+ Py = 0x9e0d08d84586e9aeefecacb41d049b831f1a3ee0c3eada63e34557b30702b50ab59fb372feff7c30b8cbb7dd51afbe88444ec56238722ec1
+
+ key = EccKey(curve="Ed448", seed=seed)
+ self.assertEqual(key.seed, seed)
+ self.assertEqual(key.d, 0xb07cf179604f83433186e5178760c759c15125ee54ff6f8dcde46e872b709ac82ed0bd0a4e036d774034dcb18a9fb11894657a1485895f80)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ.x, Px)
+ self.assertEqual(key.pointQ.y, Py)
+
+ point = EccPoint(Px, Py, "ed448")
+ key = EccKey(curve="Ed448", seed=seed, point=point)
+ self.assertEqual(key.d, 0xb07cf179604f83433186e5178760c759c15125ee54ff6f8dcde46e872b709ac82ed0bd0a4e036d774034dcb18a9fb11894657a1485895f80)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ # Other names
+ key = EccKey(curve="ed448", seed=seed)
+
+ # Must not accept d parameter
+ self.assertRaises(ValueError, EccKey, curve="ed448", d=1)
+
+ def test_public_key(self):
+ point = EccPoint(_curves['ed448'].Gx, _curves['ed448'].Gy, curve='ed448')
+ key = EccKey(curve="ed448", point=point)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ def test_public_key_derived(self):
+ priv_key = EccKey(curve="ed448", seed=b'H'*57)
+ pub_key = priv_key.public_key()
+ self.assertFalse(pub_key.has_private())
+ self.assertEqual(priv_key.pointQ, pub_key.pointQ)
+
+ def test_invalid_seed(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="ed448", seed=b'H' * 56))
+
+ def test_equality(self):
+ private_key = ECC.construct(seed=b'H'*57, curve="Ed448")
+ private_key2 = ECC.construct(seed=b'H'*57, curve="ed448")
+ private_key3 = ECC.construct(seed=b'C'*57, curve="Ed448")
+
+ public_key = private_key.public_key()
+ public_key2 = private_key2.public_key()
+ public_key3 = private_key3.public_key()
+
+ self.assertEqual(private_key, private_key2)
+ self.assertNotEqual(private_key, private_key3)
+
+ self.assertEqual(public_key, public_key2)
+ self.assertNotEqual(public_key, public_key3)
+
+ self.assertNotEqual(public_key, private_key)
+
+
+class TestEccModule_Ed448(unittest.TestCase):
+
+ def test_generate(self):
+ key = ECC.generate(curve="Ed448")
+ self.assertTrue(key.has_private())
+ point = EccPoint(_curves['Ed448'].Gx, _curves['Ed448'].Gy, curve="Ed448") * key.d
+ self.assertEqual(key.pointQ, point)
+
+ # Always random
+ key2 = ECC.generate(curve="Ed448")
+ self.assertNotEqual(key, key2)
+
+ # Other names
+ ECC.generate(curve="Ed448")
+
+ # Random source
+ key1 = ECC.generate(curve="Ed448", randfunc=SHAKE128.new().read)
+ key2 = ECC.generate(curve="Ed448", randfunc=SHAKE128.new().read)
+ self.assertEqual(key1, key2)
+
+ def test_construct(self):
+ seed = unhexlify("4adf5d37ac6785e83e99a924f92676d366a78690af59c92b6bdf14f9cdbcf26fdad478109607583d633b60078d61d51d81b7509c5433b0d4c9")
+ Px = 0x72a01eea003a35f9ac44231dc4aae2a382f351d80bf32508175b0855edcf389aa2bbf308dd961ce361a6e7c2091bc78957f6ebcf3002a617
+ Py = 0x9e0d08d84586e9aeefecacb41d049b831f1a3ee0c3eada63e34557b30702b50ab59fb372feff7c30b8cbb7dd51afbe88444ec56238722ec1
+ d = 0xb07cf179604f83433186e5178760c759c15125ee54ff6f8dcde46e872b709ac82ed0bd0a4e036d774034dcb18a9fb11894657a1485895f80
+ point = EccPoint(Px, Py, curve="Ed448")
+
+ # Private key only
+ key = ECC.construct(curve="Ed448", seed=seed)
+ self.assertEqual(key.pointQ, point)
+ self.assertTrue(key.has_private())
+
+ # Public key only
+ key = ECC.construct(curve="Ed448", point_x=Px, point_y=Py)
+ self.assertEqual(key.pointQ, point)
+ self.assertFalse(key.has_private())
+
+ # Private and public key
+ key = ECC.construct(curve="Ed448", seed=seed, point_x=Px, point_y=Py)
+ self.assertEqual(key.pointQ, point)
+ self.assertTrue(key.has_private())
+
+ # Other names
+ key = ECC.construct(curve="ed448", seed=seed)
+
+ def test_negative_construct(self):
+ coord = dict(point_x=10, point_y=4)
+ coordG = dict(point_x=_curves['ed448'].Gx, point_y=_curves['ed448'].Gy)
+
+ self.assertRaises(ValueError, ECC.construct, curve="Ed448", **coord)
+ self.assertRaises(ValueError, ECC.construct, curve="Ed448", d=2, **coordG)
+ self.assertRaises(ValueError, ECC.construct, curve="Ed448", seed=b'H'*58)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(TestEccPoint_Ed448)
+ tests += list_test_cases(TestEccKey_Ed448)
+ tests += list_test_cases(TestEccModule_Ed448)
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/PublicKey/test_ECC_NIST.py b/lib/Crypto/SelfTest/PublicKey/test_ECC_NIST.py
new file mode 100644
index 0000000..fc13f2d
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_ECC_NIST.py
@@ -0,0 +1,1389 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors
+
+from Crypto.PublicKey import ECC
+from Crypto.PublicKey.ECC import EccPoint, _curves, EccKey
+
+from Crypto.Math.Numbers import Integer
+
+
+class TestEccPoint(unittest.TestCase):
+
+ def test_mix(self):
+
+ p1 = ECC.generate(curve='P-256').pointQ
+ p2 = ECC.generate(curve='P-384').pointQ
+
+        try:
+            p1 + p2
+            self.fail("Adding points on different curves must fail")
+        except ValueError as e:
+            self.assertIn("not on the same curve", str(e))
+
+        try:
+            p1 += p2
+            self.fail("In-place adding points on different curves must fail")
+        except ValueError as e:
+            self.assertIn("not on the same curve", str(e))
+
+ def test_repr(self):
+ p1 = ECC.construct(curve='P-256',
+ d=75467964919405407085864614198393977741148485328036093939970922195112333446269,
+ point_x=20573031766139722500939782666697015100983491952082159880539639074939225934381,
+ point_y=108863130203210779921520632367477406025152638284581252625277850513266505911389)
+ self.assertEqual(repr(p1), "EccKey(curve='NIST P-256', point_x=20573031766139722500939782666697015100983491952082159880539639074939225934381, point_y=108863130203210779921520632367477406025152638284581252625277850513266505911389, d=75467964919405407085864614198393977741148485328036093939970922195112333446269)")
+
+
+class TestEccPoint_NIST_P192(unittest.TestCase):
+ """Tests defined in section 4.1 of https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.9073&rep=rep1&type=pdf"""
+
+ pointS = EccPoint(
+ 0xd458e7d127ae671b0c330266d246769353a012073e97acf8,
+ 0x325930500d851f336bddc050cf7fb11b5673a1645086df3b,
+ curve='p192')
+
+ pointT = EccPoint(
+ 0xf22c4395213e9ebe67ddecdd87fdbd01be16fb059b9753a4,
+ 0x264424096af2b3597796db48f8dfb41fa9cecc97691a9c79,
+ curve='p192')
+
+ def test_set(self):
+ pointW = EccPoint(0, 0)
+ pointW.set(self.pointS)
+ self.assertEqual(pointW, self.pointS)
+
+ def test_copy(self):
+ pointW = self.pointS.copy()
+ self.assertEqual(pointW, self.pointS)
+ pointW.set(self.pointT)
+ self.assertEqual(pointW, self.pointT)
+ self.assertNotEqual(self.pointS, self.pointT)
+
+ def test_negate(self):
+ negS = -self.pointS
+ sum = self.pointS + negS
+ self.assertEqual(sum, self.pointS.point_at_infinity())
+
+ def test_addition(self):
+ pointRx = 0x48e1e4096b9b8e5ca9d0f1f077b8abf58e843894de4d0290
+ pointRy = 0x408fa77c797cd7dbfb16aa48a3648d3d63c94117d7b6aa4b
+
+ pointR = self.pointS + self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS + pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai + self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai + pai
+ self.assertEqual(pointR, pai)
+
+ def test_inplace_addition(self):
+ pointRx = 0x48e1e4096b9b8e5ca9d0f1f077b8abf58e843894de4d0290
+ pointRy = 0x408fa77c797cd7dbfb16aa48a3648d3d63c94117d7b6aa4b
+
+ pointR = self.pointS.copy()
+ pointR += self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS.copy()
+ pointR += pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai.copy()
+ pointR += self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai.copy()
+ pointR += pai
+ self.assertEqual(pointR, pai)
+
+ def test_doubling(self):
+ pointRx = 0x30c5bc6b8c7da25354b373dc14dd8a0eba42d25a3f6e6962
+ pointRy = 0x0dde14bc4249a721c407aedbf011e2ddbbcb2968c9d889cf
+
+ pointR = self.pointS.copy()
+ pointR.double()
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 2*0
+ pai = self.pointS.point_at_infinity()
+ pointR = pai.copy()
+ pointR.double()
+ self.assertEqual(pointR, pai)
+
+ # S + S
+ pointR = self.pointS.copy()
+ pointR += pointR
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_scalar_multiply(self):
+ d = 0xa78a236d60baec0c5dd41b33a542463a8255391af64c74ee
+ pointRx = 0x1faee4205a4f669d2d0a8f25e3bcec9a62a6952965bf6d31
+ pointRy = 0x5ff2cdfa508a2581892367087c696f179e7a4d7e8260fb06
+
+ pointR = self.pointS * d
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 0*S
+ pai = self.pointS.point_at_infinity()
+ pointR = self.pointS * 0
+ self.assertEqual(pointR, pai)
+
+ # -1*S
+ self.assertRaises(ValueError, lambda: self.pointS * -1)
+
+ # Reverse order
+ pointR = d * self.pointS
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pointR = Integer(d) * self.pointS
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_joint_scalar_multiply(self):
+ d = 0xa78a236d60baec0c5dd41b33a542463a8255391af64c74ee
+ e = 0xc4be3d53ec3089e71e4de8ceab7cce889bc393cd85b972bc
+ pointRx = 0x019f64eed8fa9b72b7dfea82c17c9bfa60ecb9e1778b5bde
+ pointRy = 0x16590c5fcd8655fa4ced33fb800e2a7e3c61f35d83503644
+
+ pointR = self.pointS * d + self.pointT * e
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_sizes(self):
+ self.assertEqual(self.pointS.size_in_bits(), 192)
+ self.assertEqual(self.pointS.size_in_bytes(), 24)
+
+
+class TestEccPoint_NIST_P224(unittest.TestCase):
+ """Tests defined in section 4.2 of https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.9073&rep=rep1&type=pdf"""
+
+ pointS = EccPoint(
+ 0x6eca814ba59a930843dc814edd6c97da95518df3c6fdf16e9a10bb5b,
+ 0xef4b497f0963bc8b6aec0ca0f259b89cd80994147e05dc6b64d7bf22,
+ curve='p224')
+
+ pointT = EccPoint(
+ 0xb72b25aea5cb03fb88d7e842002969648e6ef23c5d39ac903826bd6d,
+ 0xc42a8a4d34984f0b71b5b4091af7dceb33ea729c1a2dc8b434f10c34,
+ curve='p224')
+
+ def test_set(self):
+ pointW = EccPoint(0, 0)
+ pointW.set(self.pointS)
+ self.assertEqual(pointW, self.pointS)
+
+ def test_copy(self):
+ pointW = self.pointS.copy()
+ self.assertEqual(pointW, self.pointS)
+ pointW.set(self.pointT)
+ self.assertEqual(pointW, self.pointT)
+ self.assertNotEqual(self.pointS, self.pointT)
+
+ def test_negate(self):
+ negS = -self.pointS
+ sum = self.pointS + negS
+ self.assertEqual(sum, self.pointS.point_at_infinity())
+
+ def test_addition(self):
+ pointRx = 0x236f26d9e84c2f7d776b107bd478ee0a6d2bcfcaa2162afae8d2fd15
+ pointRy = 0xe53cc0a7904ce6c3746f6a97471297a0b7d5cdf8d536ae25bb0fda70
+
+ pointR = self.pointS + self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS + pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai + self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai + pai
+ self.assertEqual(pointR, pai)
+
+ def test_inplace_addition(self):
+ pointRx = 0x236f26d9e84c2f7d776b107bd478ee0a6d2bcfcaa2162afae8d2fd15
+ pointRy = 0xe53cc0a7904ce6c3746f6a97471297a0b7d5cdf8d536ae25bb0fda70
+
+ pointR = self.pointS.copy()
+ pointR += self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS.copy()
+ pointR += pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai.copy()
+ pointR += self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai.copy()
+ pointR += pai
+ self.assertEqual(pointR, pai)
+
+ def test_doubling(self):
+ pointRx = 0xa9c96f2117dee0f27ca56850ebb46efad8ee26852f165e29cb5cdfc7
+ pointRy = 0xadf18c84cf77ced4d76d4930417d9579207840bf49bfbf5837dfdd7d
+
+ pointR = self.pointS.copy()
+ pointR.double()
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 2*0
+ pai = self.pointS.point_at_infinity()
+ pointR = pai.copy()
+ pointR.double()
+ self.assertEqual(pointR, pai)
+
+ # S + S
+ pointR = self.pointS.copy()
+ pointR += pointR
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_scalar_multiply(self):
+ d = 0xa78ccc30eaca0fcc8e36b2dd6fbb03df06d37f52711e6363aaf1d73b
+ pointRx = 0x96a7625e92a8d72bff1113abdb95777e736a14c6fdaacc392702bca4
+ pointRy = 0x0f8e5702942a3c5e13cd2fd5801915258b43dfadc70d15dbada3ed10
+
+ pointR = self.pointS * d
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 0*S
+ pai = self.pointS.point_at_infinity()
+ pointR = self.pointS * 0
+ self.assertEqual(pointR, pai)
+
+ # -1*S
+ self.assertRaises(ValueError, lambda: self.pointS * -1)
+
+ # Reverse order
+ pointR = d * self.pointS
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pointR = Integer(d) * self.pointS
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+    def test_joint_scalar_multiply(self):
+ d = 0xa78ccc30eaca0fcc8e36b2dd6fbb03df06d37f52711e6363aaf1d73b
+ e = 0x54d549ffc08c96592519d73e71e8e0703fc8177fa88aa77a6ed35736
+ pointRx = 0xdbfe2958c7b2cda1302a67ea3ffd94c918c5b350ab838d52e288c83e
+ pointRy = 0x2f521b83ac3b0549ff4895abcc7f0c5a861aacb87acbc5b8147bb18b
+
+ pointR = self.pointS * d + self.pointT * e
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_sizes(self):
+ self.assertEqual(self.pointS.size_in_bits(), 224)
+ self.assertEqual(self.pointS.size_in_bytes(), 28)
+
+
+class TestEccPoint_NIST_P256(unittest.TestCase):
+ """Tests defined in section 4.3 of https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.9073&rep=rep1&type=pdf"""
+
+ pointS = EccPoint(
+ 0xde2444bebc8d36e682edd27e0f271508617519b3221a8fa0b77cab3989da97c9,
+ 0xc093ae7ff36e5380fc01a5aad1e66659702de80f53cec576b6350b243042a256)
+
+ pointT = EccPoint(
+ 0x55a8b00f8da1d44e62f6b3b25316212e39540dc861c89575bb8cf92e35e0986b,
+ 0x5421c3209c2d6c704835d82ac4c3dd90f61a8a52598b9e7ab656e9d8c8b24316)
+
+ def test_set(self):
+ pointW = EccPoint(0, 0)
+ pointW.set(self.pointS)
+ self.assertEqual(pointW, self.pointS)
+
+ def test_copy(self):
+ pointW = self.pointS.copy()
+ self.assertEqual(pointW, self.pointS)
+ pointW.set(self.pointT)
+ self.assertEqual(pointW, self.pointT)
+ self.assertNotEqual(self.pointS, self.pointT)
+
+ def test_negate(self):
+ negS = -self.pointS
+ sum = self.pointS + negS
+ self.assertEqual(sum, self.pointS.point_at_infinity())
+
+ def test_addition(self):
+ pointRx = 0x72b13dd4354b6b81745195e98cc5ba6970349191ac476bd4553cf35a545a067e
+ pointRy = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264
+
+ pointR = self.pointS + self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS + pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai + self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai + pai
+ self.assertEqual(pointR, pai)
+
+ def test_inplace_addition(self):
+ pointRx = 0x72b13dd4354b6b81745195e98cc5ba6970349191ac476bd4553cf35a545a067e
+ pointRy = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264
+
+ pointR = self.pointS.copy()
+ pointR += self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS.copy()
+ pointR += pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai.copy()
+ pointR += self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai.copy()
+ pointR += pai
+ self.assertEqual(pointR, pai)
+
+ def test_doubling(self):
+ pointRx = 0x7669e6901606ee3ba1a8eef1e0024c33df6c22f3b17481b82a860ffcdb6127b0
+ pointRy = 0xfa878162187a54f6c39f6ee0072f33de389ef3eecd03023de10ca2c1db61d0c7
+
+ pointR = self.pointS.copy()
+ pointR.double()
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 2*0
+ pai = self.pointS.point_at_infinity()
+ pointR = pai.copy()
+ pointR.double()
+ self.assertEqual(pointR, pai)
+
+ # S + S
+ pointR = self.pointS.copy()
+ pointR += pointR
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_scalar_multiply(self):
+ d = 0xc51e4753afdec1e6b6c6a5b992f43f8dd0c7a8933072708b6522468b2ffb06fd
+ pointRx = 0x51d08d5f2d4278882946d88d83c97d11e62becc3cfc18bedacc89ba34eeca03f
+ pointRy = 0x75ee68eb8bf626aa5b673ab51f6e744e06f8fcf8a6c0cf3035beca956a7b41d5
+
+ pointR = self.pointS * d
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 0*S
+ pai = self.pointS.point_at_infinity()
+ pointR = self.pointS * 0
+ self.assertEqual(pointR, pai)
+
+ # -1*S
+ self.assertRaises(ValueError, lambda: self.pointS * -1)
+
+ # Reverse order
+ pointR = d * self.pointS
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pointR = Integer(d) * self.pointS
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+    def test_joint_scalar_multiply(self):
+ d = 0xc51e4753afdec1e6b6c6a5b992f43f8dd0c7a8933072708b6522468b2ffb06fd
+ e = 0xd37f628ece72a462f0145cbefe3f0b355ee8332d37acdd83a358016aea029db7
+ pointRx = 0xd867b4679221009234939221b8046245efcf58413daacbeff857b8588341f6b8
+ pointRy = 0xf2504055c03cede12d22720dad69c745106b6607ec7e50dd35d54bd80f615275
+
+ pointR = self.pointS * d + self.pointT * e
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_sizes(self):
+ self.assertEqual(self.pointS.size_in_bits(), 256)
+ self.assertEqual(self.pointS.size_in_bytes(), 32)
+
+
+class TestEccPoint_NIST_P384(unittest.TestCase):
+ """Tests defined in section 4.4 of https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.9073&rep=rep1&type=pdf"""
+
+ pointS = EccPoint(
+ 0xfba203b81bbd23f2b3be971cc23997e1ae4d89e69cb6f92385dda82768ada415ebab4167459da98e62b1332d1e73cb0e,
+ 0x5ffedbaefdeba603e7923e06cdb5d0c65b22301429293376d5c6944e3fa6259f162b4788de6987fd59aed5e4b5285e45,
+ "p384")
+
+ pointT = EccPoint(
+ 0xaacc05202e7fda6fc73d82f0a66220527da8117ee8f8330ead7d20ee6f255f582d8bd38c5a7f2b40bcdb68ba13d81051,
+ 0x84009a263fefba7c2c57cffa5db3634d286131afc0fca8d25afa22a7b5dce0d9470da89233cee178592f49b6fecb5092,
+ "p384")
+
+ def test_set(self):
+ pointW = EccPoint(0, 0, "p384")
+ pointW.set(self.pointS)
+ self.assertEqual(pointW, self.pointS)
+
+ def test_copy(self):
+ pointW = self.pointS.copy()
+ self.assertEqual(pointW, self.pointS)
+ pointW.set(self.pointT)
+ self.assertEqual(pointW, self.pointT)
+ self.assertNotEqual(self.pointS, self.pointT)
+
+ def test_negate(self):
+ negS = -self.pointS
+ sum = self.pointS + negS
+ self.assertEqual(sum, self.pointS.point_at_infinity())
+
+ def test_addition(self):
+ pointRx = 0x12dc5ce7acdfc5844d939f40b4df012e68f865b89c3213ba97090a247a2fc009075cf471cd2e85c489979b65ee0b5eed
+ pointRy = 0x167312e58fe0c0afa248f2854e3cddcb557f983b3189b67f21eee01341e7e9fe67f6ee81b36988efa406945c8804a4b0
+
+ pointR = self.pointS + self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS + pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai + self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai + pai
+ self.assertEqual(pointR, pai)
+
+ def _test_inplace_addition(self):
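+        # Note: the leading underscore keeps unittest from collecting this
+        # test; the expected coordinates below are the P-256 values, not the
+        # P-384 ones.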
+ pointRx = 0x72b13dd4354b6b81745195e98cc5ba6970349191ac476bd4553cf35a545a067e
+ pointRy = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264
+
+ pointR = self.pointS.copy()
+ pointR += self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS.copy()
+ pointR += pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai.copy()
+ pointR += self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai.copy()
+ pointR += pai
+ self.assertEqual(pointR, pai)
+
+ def test_doubling(self):
+ pointRx = 0x2a2111b1e0aa8b2fc5a1975516bc4d58017ff96b25e1bdff3c229d5fac3bacc319dcbec29f9478f42dee597b4641504c
+ pointRy = 0xfa2e3d9dc84db8954ce8085ef28d7184fddfd1344b4d4797343af9b5f9d837520b450f726443e4114bd4e5bdb2f65ddd
+
+ pointR = self.pointS.copy()
+ pointR.double()
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 2*0
+ pai = self.pointS.point_at_infinity()
+ pointR = pai.copy()
+ pointR.double()
+ self.assertEqual(pointR, pai)
+
+ # S + S
+ pointR = self.pointS.copy()
+ pointR += pointR
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_scalar_multiply(self):
+ d = 0xa4ebcae5a665983493ab3e626085a24c104311a761b5a8fdac052ed1f111a5c44f76f45659d2d111a61b5fdd97583480
+ pointRx = 0xe4f77e7ffeb7f0958910e3a680d677a477191df166160ff7ef6bb5261f791aa7b45e3e653d151b95dad3d93ca0290ef2
+ pointRy = 0xac7dee41d8c5f4a7d5836960a773cfc1376289d3373f8cf7417b0c6207ac32e913856612fc9ff2e357eb2ee05cf9667f
+
+ pointR = self.pointS * d
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 0*S
+ pai = self.pointS.point_at_infinity()
+ pointR = self.pointS * 0
+ self.assertEqual(pointR, pai)
+
+ # -1*S
+ self.assertRaises(ValueError, lambda: self.pointS * -1)
+
+    def test_joint_scalar_multiply(self):
+ d = 0xa4ebcae5a665983493ab3e626085a24c104311a761b5a8fdac052ed1f111a5c44f76f45659d2d111a61b5fdd97583480
+ e = 0xafcf88119a3a76c87acbd6008e1349b29f4ba9aa0e12ce89bcfcae2180b38d81ab8cf15095301a182afbc6893e75385d
+ pointRx = 0x917ea28bcd641741ae5d18c2f1bd917ba68d34f0f0577387dc81260462aea60e2417b8bdc5d954fc729d211db23a02dc
+ pointRy = 0x1a29f7ce6d074654d77b40888c73e92546c8f16a5ff6bcbd307f758d4aee684beff26f6742f597e2585c86da908f7186
+
+ pointR = self.pointS * d + self.pointT * e
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_sizes(self):
+ self.assertEqual(self.pointS.size_in_bits(), 384)
+ self.assertEqual(self.pointS.size_in_bytes(), 48)
+
+
+class TestEccPoint_NIST_P521(unittest.TestCase):
+ """Tests defined in section 4.5 of https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.9073&rep=rep1&type=pdf"""
+
+ pointS = EccPoint(
+ 0x000001d5c693f66c08ed03ad0f031f937443458f601fd098d3d0227b4bf62873af50740b0bb84aa157fc847bcf8dc16a8b2b8bfd8e2d0a7d39af04b089930ef6dad5c1b4,
+ 0x00000144b7770963c63a39248865ff36b074151eac33549b224af5c8664c54012b818ed037b2b7c1a63ac89ebaa11e07db89fcee5b556e49764ee3fa66ea7ae61ac01823,
+ "p521")
+
+ pointT = EccPoint(
+ 0x000000f411f2ac2eb971a267b80297ba67c322dba4bb21cec8b70073bf88fc1ca5fde3ba09e5df6d39acb2c0762c03d7bc224a3e197feaf760d6324006fe3be9a548c7d5,
+ 0x000001fdf842769c707c93c630df6d02eff399a06f1b36fb9684f0b373ed064889629abb92b1ae328fdb45534268384943f0e9222afe03259b32274d35d1b9584c65e305,
+ "p521")
+
+ def test_set(self):
+ pointW = EccPoint(0, 0)
+ pointW.set(self.pointS)
+ self.assertEqual(pointW, self.pointS)
+
+ def test_copy(self):
+ pointW = self.pointS.copy()
+ self.assertEqual(pointW, self.pointS)
+ pointW.set(self.pointT)
+ self.assertEqual(pointW, self.pointT)
+ self.assertNotEqual(self.pointS, self.pointT)
+
+ def test_negate(self):
+ negS = -self.pointS
+ sum = self.pointS + negS
+ self.assertEqual(sum, self.pointS.point_at_infinity())
+
+ def test_addition(self):
+ pointRx = 0x000001264ae115ba9cbc2ee56e6f0059e24b52c8046321602c59a339cfb757c89a59c358a9a8e1f86d384b3f3b255ea3f73670c6dc9f45d46b6a196dc37bbe0f6b2dd9e9
+ pointRy = 0x00000062a9c72b8f9f88a271690bfa017a6466c31b9cadc2fc544744aeb817072349cfddc5ad0e81b03f1897bd9c8c6efbdf68237dc3bb00445979fb373b20c9a967ac55
+
+ pointR = self.pointS + self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS + pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai + self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai + pai
+ self.assertEqual(pointR, pai)
+
+ def test_inplace_addition(self):
+ pointRx = 0x000001264ae115ba9cbc2ee56e6f0059e24b52c8046321602c59a339cfb757c89a59c358a9a8e1f86d384b3f3b255ea3f73670c6dc9f45d46b6a196dc37bbe0f6b2dd9e9
+ pointRy = 0x00000062a9c72b8f9f88a271690bfa017a6466c31b9cadc2fc544744aeb817072349cfddc5ad0e81b03f1897bd9c8c6efbdf68237dc3bb00445979fb373b20c9a967ac55
+
+ pointR = self.pointS.copy()
+ pointR += self.pointT
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ pai = pointR.point_at_infinity()
+
+ # S + 0
+ pointR = self.pointS.copy()
+ pointR += pai
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + S
+ pointR = pai.copy()
+ pointR += self.pointS
+ self.assertEqual(pointR, self.pointS)
+
+ # 0 + 0
+ pointR = pai.copy()
+ pointR += pai
+ self.assertEqual(pointR, pai)
+
+ def test_doubling(self):
+ pointRx = 0x0000012879442f2450c119e7119a5f738be1f1eba9e9d7c6cf41b325d9ce6d643106e9d61124a91a96bcf201305a9dee55fa79136dc700831e54c3ca4ff2646bd3c36bc6
+ pointRy = 0x0000019864a8b8855c2479cbefe375ae553e2393271ed36fadfc4494fc0583f6bd03598896f39854abeae5f9a6515a021e2c0eef139e71de610143f53382f4104dccb543
+
+ pointR = self.pointS.copy()
+ pointR.double()
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 2*0
+ pai = self.pointS.point_at_infinity()
+ pointR = pai.copy()
+ pointR.double()
+ self.assertEqual(pointR, pai)
+
+ # S + S
+ pointR = self.pointS.copy()
+ pointR += pointR
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_scalar_multiply(self):
+ d = 0x000001eb7f81785c9629f136a7e8f8c674957109735554111a2a866fa5a166699419bfa9936c78b62653964df0d6da940a695c7294d41b2d6600de6dfcf0edcfc89fdcb1
+ pointRx = 0x00000091b15d09d0ca0353f8f96b93cdb13497b0a4bb582ae9ebefa35eee61bf7b7d041b8ec34c6c00c0c0671c4ae063318fb75be87af4fe859608c95f0ab4774f8c95bb
+ pointRy = 0x00000130f8f8b5e1abb4dd94f6baaf654a2d5810411e77b7423965e0c7fd79ec1ae563c207bd255ee9828eb7a03fed565240d2cc80ddd2cecbb2eb50f0951f75ad87977f
+
+ pointR = self.pointS * d
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ # 0*S
+ pai = self.pointS.point_at_infinity()
+ pointR = self.pointS * 0
+ self.assertEqual(pointR, pai)
+
+ # -1*S
+ self.assertRaises(ValueError, lambda: self.pointS * -1)
+
+    def test_joint_scalar_multiply(self):
+ d = 0x000001eb7f81785c9629f136a7e8f8c674957109735554111a2a866fa5a166699419bfa9936c78b62653964df0d6da940a695c7294d41b2d6600de6dfcf0edcfc89fdcb1
+ e = 0x00000137e6b73d38f153c3a7575615812608f2bab3229c92e21c0d1c83cfad9261dbb17bb77a63682000031b9122c2f0cdab2af72314be95254de4291a8f85f7c70412e3
+ pointRx = 0x0000009d3802642b3bea152beb9e05fba247790f7fc168072d363340133402f2585588dc1385d40ebcb8552f8db02b23d687cae46185b27528adb1bf9729716e4eba653d
+ pointRy = 0x0000000fe44344e79da6f49d87c1063744e5957d9ac0a505bafa8281c9ce9ff25ad53f8da084a2deb0923e46501de5797850c61b229023dd9cf7fc7f04cd35ebb026d89d
+
+ pointR = self.pointS * d
+ pointR += self.pointT * e
+ self.assertEqual(pointR.x, pointRx)
+ self.assertEqual(pointR.y, pointRy)
+
+ def test_sizes(self):
+ self.assertEqual(self.pointS.size_in_bits(), 521)
+ self.assertEqual(self.pointS.size_in_bytes(), 66)
+
+
+class TestEccPoint_PAI_P192(unittest.TestCase):
+ """Test vectors from http://point-at-infinity.org/ecc/nisttv"""
+
+ curve = _curves['p192']
+ pointG = EccPoint(curve.Gx, curve.Gy, "p192")
+
+
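+# Each test vector (k, x, y with k*G = (x, y)) becomes its own test method.
+# The default arguments of new_test() bind the current vector's values at
+# definition time (a plain closure would only see the last loop value), and
+# "or []" turns the loop into a no-op when the vector file cannot be loaded.
+# The same pattern is repeated for the other NIST curves below.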
+tv_pai = load_test_vectors(("PublicKey", "ECC"),
+ "point-at-infinity.org-P192.txt",
+ "P-192 tests from point-at-infinity.org",
+ {"k": lambda k: int(k),
+ "x": lambda x: int(x, 16),
+ "y": lambda y: int(y, 16)}) or []
+for tv in tv_pai:
+ def new_test(self, scalar=tv.k, x=tv.x, y=tv.y):
+ result = self.pointG * scalar
+ self.assertEqual(result.x, x)
+ self.assertEqual(result.y, y)
+ setattr(TestEccPoint_PAI_P192, "test_%d" % tv.count, new_test)
+
+
+class TestEccPoint_PAI_P224(unittest.TestCase):
+ """Test vectors from http://point-at-infinity.org/ecc/nisttv"""
+
+ curve = _curves['p224']
+ pointG = EccPoint(curve.Gx, curve.Gy, "p224")
+
+
+tv_pai = load_test_vectors(("PublicKey", "ECC"),
+ "point-at-infinity.org-P224.txt",
+ "P-224 tests from point-at-infinity.org",
+ {"k": lambda k: int(k),
+ "x": lambda x: int(x, 16),
+ "y": lambda y: int(y, 16)}) or []
+for tv in tv_pai:
+ def new_test(self, scalar=tv.k, x=tv.x, y=tv.y):
+ result = self.pointG * scalar
+ self.assertEqual(result.x, x)
+ self.assertEqual(result.y, y)
+ setattr(TestEccPoint_PAI_P224, "test_%d" % tv.count, new_test)
+
+
+class TestEccPoint_PAI_P256(unittest.TestCase):
+ """Test vectors from http://point-at-infinity.org/ecc/nisttv"""
+
+ curve = _curves['p256']
+ pointG = EccPoint(curve.Gx, curve.Gy, "p256")
+
+
+tv_pai = load_test_vectors(("PublicKey", "ECC"),
+ "point-at-infinity.org-P256.txt",
+ "P-256 tests from point-at-infinity.org",
+ {"k": lambda k: int(k),
+ "x": lambda x: int(x, 16),
+ "y": lambda y: int(y, 16)}) or []
+for tv in tv_pai:
+ def new_test(self, scalar=tv.k, x=tv.x, y=tv.y):
+ result = self.pointG * scalar
+ self.assertEqual(result.x, x)
+ self.assertEqual(result.y, y)
+ setattr(TestEccPoint_PAI_P256, "test_%d" % tv.count, new_test)
+
+
+class TestEccPoint_PAI_P384(unittest.TestCase):
+ """Test vectors from http://point-at-infinity.org/ecc/nisttv"""
+
+ curve = _curves['p384']
+ pointG = EccPoint(curve.Gx, curve.Gy, "p384")
+
+
+tv_pai = load_test_vectors(("PublicKey", "ECC"),
+ "point-at-infinity.org-P384.txt",
+ "P-384 tests from point-at-infinity.org",
+ {"k": lambda k: int(k),
+ "x": lambda x: int(x, 16),
+ "y": lambda y: int(y, 16)}) or []
+for tv in tv_pai:
+ def new_test(self, scalar=tv.k, x=tv.x, y=tv.y):
+ result = self.pointG * scalar
+ self.assertEqual(result.x, x)
+ self.assertEqual(result.y, y)
+ setattr(TestEccPoint_PAI_P384, "test_%d" % tv.count, new_test)
+
+
+class TestEccPoint_PAI_P521(unittest.TestCase):
+ """Test vectors from http://point-at-infinity.org/ecc/nisttv"""
+
+ curve = _curves['p521']
+ pointG = EccPoint(curve.Gx, curve.Gy, "p521")
+
+
+tv_pai = load_test_vectors(("PublicKey", "ECC"),
+ "point-at-infinity.org-P521.txt",
+ "P-521 tests from point-at-infinity.org",
+ {"k": lambda k: int(k),
+ "x": lambda x: int(x, 16),
+ "y": lambda y: int(y, 16)}) or []
+for tv in tv_pai:
+ def new_test(self, scalar=tv.k, x=tv.x, y=tv.y):
+ result = self.pointG * scalar
+ self.assertEqual(result.x, x)
+ self.assertEqual(result.y, y)
+ setattr(TestEccPoint_PAI_P521, "test_%d" % tv.count, new_test)
+
+
+class TestEccKey_P192(unittest.TestCase):
+
+ def test_private_key(self):
+
+ key = EccKey(curve="P-192", d=1)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ.x, _curves['p192'].Gx)
+ self.assertEqual(key.pointQ.y, _curves['p192'].Gy)
+
+ point = EccPoint(_curves['p192'].Gx, _curves['p192'].Gy, curve='P-192')
+ key = EccKey(curve="P-192", d=1, point=point)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ # Other names
+ key = EccKey(curve="secp192r1", d=1)
+ key = EccKey(curve="prime192v1", d=1)
+
+ def test_public_key(self):
+
+ point = EccPoint(_curves['p192'].Gx, _curves['p192'].Gy, curve='P-192')
+ key = EccKey(curve="P-192", point=point)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ def test_public_key_derived(self):
+
+ priv_key = EccKey(curve="P-192", d=3)
+ pub_key = priv_key.public_key()
+ self.assertFalse(pub_key.has_private())
+ self.assertEqual(priv_key.pointQ, pub_key.pointQ)
+
+ def test_invalid_curve(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-193", d=1))
+
+ def test_invalid_d(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-192", d=0))
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-192",
+ d=_curves['p192'].order))
+
+ def test_equality(self):
+
+ private_key = ECC.construct(d=3, curve="P-192")
+ private_key2 = ECC.construct(d=3, curve="P-192")
+ private_key3 = ECC.construct(d=4, curve="P-192")
+
+ public_key = private_key.public_key()
+ public_key2 = private_key2.public_key()
+ public_key3 = private_key3.public_key()
+
+ self.assertEqual(private_key, private_key2)
+ self.assertNotEqual(private_key, private_key3)
+
+ self.assertEqual(public_key, public_key2)
+ self.assertNotEqual(public_key, public_key3)
+
+ self.assertNotEqual(public_key, private_key)
+
+
+class TestEccKey_P224(unittest.TestCase):
+
+ def test_private_key(self):
+
+ key = EccKey(curve="P-224", d=1)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ.x, _curves['p224'].Gx)
+ self.assertEqual(key.pointQ.y, _curves['p224'].Gy)
+
+ point = EccPoint(_curves['p224'].Gx, _curves['p224'].Gy, curve='P-224')
+ key = EccKey(curve="P-224", d=1, point=point)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ # Other names
+ key = EccKey(curve="secp224r1", d=1)
+ key = EccKey(curve="prime224v1", d=1)
+
+ def test_public_key(self):
+
+ point = EccPoint(_curves['p224'].Gx, _curves['p224'].Gy, curve='P-224')
+ key = EccKey(curve="P-224", point=point)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ def test_public_key_derived(self):
+
+ priv_key = EccKey(curve="P-224", d=3)
+ pub_key = priv_key.public_key()
+ self.assertFalse(pub_key.has_private())
+ self.assertEqual(priv_key.pointQ, pub_key.pointQ)
+
+ def test_invalid_curve(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-225", d=1))
+
+ def test_invalid_d(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-224", d=0))
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-224",
+ d=_curves['p224'].order))
+
+ def test_equality(self):
+
+ private_key = ECC.construct(d=3, curve="P-224")
+ private_key2 = ECC.construct(d=3, curve="P-224")
+ private_key3 = ECC.construct(d=4, curve="P-224")
+
+ public_key = private_key.public_key()
+ public_key2 = private_key2.public_key()
+ public_key3 = private_key3.public_key()
+
+ self.assertEqual(private_key, private_key2)
+ self.assertNotEqual(private_key, private_key3)
+
+ self.assertEqual(public_key, public_key2)
+ self.assertNotEqual(public_key, public_key3)
+
+ self.assertNotEqual(public_key, private_key)
+
+
+class TestEccKey_P256(unittest.TestCase):
+
+ def test_private_key(self):
+
+ key = EccKey(curve="P-256", d=1)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ.x, _curves['p256'].Gx)
+ self.assertEqual(key.pointQ.y, _curves['p256'].Gy)
+
+ point = EccPoint(_curves['p256'].Gx, _curves['p256'].Gy)
+ key = EccKey(curve="P-256", d=1, point=point)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ # Other names
+ key = EccKey(curve="secp256r1", d=1)
+ key = EccKey(curve="prime256v1", d=1)
+
+        # Must not accept seed parameter
+ self.assertRaises(ValueError, EccKey, curve="p256", seed=b'H'*32)
+
+ def test_public_key(self):
+
+ point = EccPoint(_curves['p256'].Gx, _curves['p256'].Gy)
+ key = EccKey(curve="P-256", point=point)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ def test_public_key_derived(self):
+
+ priv_key = EccKey(curve="P-256", d=3)
+ pub_key = priv_key.public_key()
+ self.assertFalse(pub_key.has_private())
+ self.assertEqual(priv_key.pointQ, pub_key.pointQ)
+
+ def test_invalid_curve(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-257", d=1))
+
+ def test_invalid_d(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-256", d=0))
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-256", d=_curves['p256'].order))
+
+ def test_equality(self):
+
+ private_key = ECC.construct(d=3, curve="P-256")
+ private_key2 = ECC.construct(d=3, curve="P-256")
+ private_key3 = ECC.construct(d=4, curve="P-256")
+
+ public_key = private_key.public_key()
+ public_key2 = private_key2.public_key()
+ public_key3 = private_key3.public_key()
+
+ self.assertEqual(private_key, private_key2)
+ self.assertNotEqual(private_key, private_key3)
+
+ self.assertEqual(public_key, public_key2)
+ self.assertNotEqual(public_key, public_key3)
+
+ self.assertNotEqual(public_key, private_key)
+
+
+class TestEccKey_P384(unittest.TestCase):
+
+ def test_private_key(self):
+
+ p384 = _curves['p384']
+
+ key = EccKey(curve="P-384", d=1)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ.x, p384.Gx)
+ self.assertEqual(key.pointQ.y, p384.Gy)
+
+ point = EccPoint(p384.Gx, p384.Gy, "p384")
+ key = EccKey(curve="P-384", d=1, point=point)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ # Other names
+ key = EccKey(curve="p384", d=1)
+ key = EccKey(curve="secp384r1", d=1)
+ key = EccKey(curve="prime384v1", d=1)
+
+ def test_public_key(self):
+
+ p384 = _curves['p384']
+ point = EccPoint(p384.Gx, p384.Gy, 'p384')
+ key = EccKey(curve="P-384", point=point)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ def test_public_key_derived(self):
+
+ priv_key = EccKey(curve="P-384", d=3)
+ pub_key = priv_key.public_key()
+ self.assertFalse(pub_key.has_private())
+ self.assertEqual(priv_key.pointQ, pub_key.pointQ)
+
+ def test_invalid_curve(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-385", d=1))
+
+ def test_invalid_d(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-384", d=0))
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-384",
+ d=_curves['p384'].order))
+
+ def test_equality(self):
+
+ private_key = ECC.construct(d=3, curve="P-384")
+ private_key2 = ECC.construct(d=3, curve="P-384")
+ private_key3 = ECC.construct(d=4, curve="P-384")
+
+ public_key = private_key.public_key()
+ public_key2 = private_key2.public_key()
+ public_key3 = private_key3.public_key()
+
+ self.assertEqual(private_key, private_key2)
+ self.assertNotEqual(private_key, private_key3)
+
+ self.assertEqual(public_key, public_key2)
+ self.assertNotEqual(public_key, public_key3)
+
+ self.assertNotEqual(public_key, private_key)
+
+
+class TestEccKey_P521(unittest.TestCase):
+
+ def test_private_key(self):
+
+ p521 = _curves['p521']
+
+ key = EccKey(curve="P-521", d=1)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ.x, p521.Gx)
+ self.assertEqual(key.pointQ.y, p521.Gy)
+
+ point = EccPoint(p521.Gx, p521.Gy, "p521")
+ key = EccKey(curve="P-521", d=1, point=point)
+ self.assertEqual(key.d, 1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ # Other names
+ key = EccKey(curve="p521", d=1)
+ key = EccKey(curve="secp521r1", d=1)
+ key = EccKey(curve="prime521v1", d=1)
+
+ def test_public_key(self):
+
+ p521 = _curves['p521']
+ point = EccPoint(p521.Gx, p521.Gy, 'p521')
+ key = EccKey(curve="P-384", point=point)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, point)
+
+ def test_public_key_derived(self):
+
+ priv_key = EccKey(curve="P-521", d=3)
+ pub_key = priv_key.public_key()
+ self.assertFalse(pub_key.has_private())
+ self.assertEqual(priv_key.pointQ, pub_key.pointQ)
+
+ def test_invalid_curve(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-522", d=1))
+
+ def test_invalid_d(self):
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-521", d=0))
+ self.assertRaises(ValueError, lambda: EccKey(curve="P-521",
+ d=_curves['p521'].order))
+
+ def test_equality(self):
+
+ private_key = ECC.construct(d=3, curve="P-521")
+ private_key2 = ECC.construct(d=3, curve="P-521")
+ private_key3 = ECC.construct(d=4, curve="P-521")
+
+ public_key = private_key.public_key()
+ public_key2 = private_key2.public_key()
+ public_key3 = private_key3.public_key()
+
+ self.assertEqual(private_key, private_key2)
+ self.assertNotEqual(private_key, private_key3)
+
+ self.assertEqual(public_key, public_key2)
+ self.assertNotEqual(public_key, public_key3)
+
+ self.assertNotEqual(public_key, private_key)
+
+
+class TestEccModule_P192(unittest.TestCase):
+
+ def test_generate(self):
+
+ key = ECC.generate(curve="P-192")
+ self.assertTrue(key.has_private())
+        self.assertEqual(key.pointQ,
+                         EccPoint(_curves['p192'].Gx,
+                                  _curves['p192'].Gy,
+                                  "P-192") * key.d)
+
+ # Other names
+ ECC.generate(curve="secp192r1")
+ ECC.generate(curve="prime192v1")
+
+ def test_construct(self):
+
+ key = ECC.construct(curve="P-192", d=1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p192'].G)
+
+ key = ECC.construct(curve="P-192", point_x=_curves['p192'].Gx,
+ point_y=_curves['p192'].Gy)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p192'].G)
+
+ # Other names
+ ECC.construct(curve="p192", d=1)
+ ECC.construct(curve="secp192r1", d=1)
+ ECC.construct(curve="prime192v1", d=1)
+
+ def test_negative_construct(self):
+ coord = dict(point_x=10, point_y=4)
+ coordG = dict(point_x=_curves['p192'].Gx, point_y=_curves['p192'].Gy)
+
+ self.assertRaises(ValueError, ECC.construct, curve="P-192", **coord)
+ self.assertRaises(ValueError, ECC.construct, curve="P-192", d=2, **coordG)
+
+
+class TestEccModule_P224(unittest.TestCase):
+
+ def test_generate(self):
+
+ key = ECC.generate(curve="P-224")
+ self.assertTrue(key.has_private())
+        self.assertEqual(key.pointQ,
+                         EccPoint(_curves['p224'].Gx,
+                                  _curves['p224'].Gy,
+                                  "P-224") * key.d)
+
+ # Other names
+ ECC.generate(curve="secp224r1")
+ ECC.generate(curve="prime224v1")
+
+ def test_construct(self):
+
+ key = ECC.construct(curve="P-224", d=1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p224'].G)
+
+ key = ECC.construct(curve="P-224", point_x=_curves['p224'].Gx,
+ point_y=_curves['p224'].Gy)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p224'].G)
+
+ # Other names
+ ECC.construct(curve="p224", d=1)
+ ECC.construct(curve="secp224r1", d=1)
+ ECC.construct(curve="prime224v1", d=1)
+
+ def test_negative_construct(self):
+ coord = dict(point_x=10, point_y=4)
+ coordG = dict(point_x=_curves['p224'].Gx, point_y=_curves['p224'].Gy)
+
+ self.assertRaises(ValueError, ECC.construct, curve="P-224", **coord)
+ self.assertRaises(ValueError, ECC.construct, curve="P-224", d=2, **coordG)
+
+
+class TestEccModule_P256(unittest.TestCase):
+
+ def test_generate(self):
+
+ key = ECC.generate(curve="P-256")
+ self.assertTrue(key.has_private())
+        self.assertEqual(key.pointQ,
+                         EccPoint(_curves['p256'].Gx,
+                                  _curves['p256'].Gy) * key.d)
+
+ # Other names
+ ECC.generate(curve="secp256r1")
+ ECC.generate(curve="prime256v1")
+
+ def test_construct(self):
+
+ key = ECC.construct(curve="P-256", d=1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p256'].G)
+
+ key = ECC.construct(curve="P-256", point_x=_curves['p256'].Gx,
+ point_y=_curves['p256'].Gy)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p256'].G)
+
+ # Other names
+ ECC.construct(curve="p256", d=1)
+ ECC.construct(curve="secp256r1", d=1)
+ ECC.construct(curve="prime256v1", d=1)
+
+ def test_negative_construct(self):
+ coord = dict(point_x=10, point_y=4)
+ coordG = dict(point_x=_curves['p256'].Gx, point_y=_curves['p256'].Gy)
+
+ self.assertRaises(ValueError, ECC.construct, curve="P-256", **coord)
+ self.assertRaises(ValueError, ECC.construct, curve="P-256", d=2, **coordG)
+
+
+class TestEccModule_P384(unittest.TestCase):
+
+ def test_generate(self):
+
+ curve = _curves['p384']
+ key = ECC.generate(curve="P-384")
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, EccPoint(curve.Gx, curve.Gy, "p384") * key.d)
+
+ # Other names
+ ECC.generate(curve="secp384r1")
+ ECC.generate(curve="prime384v1")
+
+ def test_construct(self):
+
+ curve = _curves['p384']
+ key = ECC.construct(curve="P-384", d=1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p384'].G)
+
+ key = ECC.construct(curve="P-384", point_x=curve.Gx, point_y=curve.Gy)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, curve.G)
+
+ # Other names
+ ECC.construct(curve="p384", d=1)
+ ECC.construct(curve="secp384r1", d=1)
+ ECC.construct(curve="prime384v1", d=1)
+
+ def test_negative_construct(self):
+ coord = dict(point_x=10, point_y=4)
+ coordG = dict(point_x=_curves['p384'].Gx, point_y=_curves['p384'].Gy)
+
+ self.assertRaises(ValueError, ECC.construct, curve="P-384", **coord)
+ self.assertRaises(ValueError, ECC.construct, curve="P-384", d=2, **coordG)
+
+
+class TestEccModule_P521(unittest.TestCase):
+
+ def test_generate(self):
+
+ curve = _curves['p521']
+ key = ECC.generate(curve="P-521")
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, EccPoint(curve.Gx, curve.Gy, "p521") * key.d)
+
+ # Other names
+ ECC.generate(curve="secp521r1")
+ ECC.generate(curve="prime521v1")
+
+ def test_construct(self):
+
+ curve = _curves['p521']
+ key = ECC.construct(curve="P-521", d=1)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.pointQ, _curves['p521'].G)
+
+ key = ECC.construct(curve="P-521", point_x=curve.Gx, point_y=curve.Gy)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.pointQ, curve.G)
+
+ # Other names
+ ECC.construct(curve="p521", d=1)
+ ECC.construct(curve="secp521r1", d=1)
+ ECC.construct(curve="prime521v1", d=1)
+
+ def test_negative_construct(self):
+ coord = dict(point_x=10, point_y=4)
+ coordG = dict(point_x=_curves['p521'].Gx, point_y=_curves['p521'].Gy)
+
+ self.assertRaises(ValueError, ECC.construct, curve="P-521", **coord)
+ self.assertRaises(ValueError, ECC.construct, curve="P-521", d=2, **coordG)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(TestEccPoint)
+ tests += list_test_cases(TestEccPoint_NIST_P192)
+ tests += list_test_cases(TestEccPoint_NIST_P224)
+ tests += list_test_cases(TestEccPoint_NIST_P256)
+ tests += list_test_cases(TestEccPoint_NIST_P384)
+ tests += list_test_cases(TestEccPoint_NIST_P521)
+ tests += list_test_cases(TestEccPoint_PAI_P192)
+ tests += list_test_cases(TestEccPoint_PAI_P224)
+ tests += list_test_cases(TestEccPoint_PAI_P256)
+ tests += list_test_cases(TestEccPoint_PAI_P384)
+ tests += list_test_cases(TestEccPoint_PAI_P521)
+ tests += list_test_cases(TestEccKey_P192)
+ tests += list_test_cases(TestEccKey_P224)
+ tests += list_test_cases(TestEccKey_P256)
+ tests += list_test_cases(TestEccKey_P384)
+ tests += list_test_cases(TestEccKey_P521)
+ tests += list_test_cases(TestEccModule_P192)
+ tests += list_test_cases(TestEccModule_P224)
+ tests += list_test_cases(TestEccModule_P256)
+ tests += list_test_cases(TestEccModule_P384)
+ tests += list_test_cases(TestEccModule_P521)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/PublicKey/test_ElGamal.py b/lib/Crypto/SelfTest/PublicKey/test_ElGamal.py
new file mode 100644
index 0000000..0b394ae
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_ElGamal.py
@@ -0,0 +1,217 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/PublicKey/test_ElGamal.py: Self-test for the ElGamal primitive
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.PublicKey.ElGamal"""
+
+__revision__ = "$Id$"
+
+import unittest
+from Crypto.SelfTest.st_common import list_test_cases, a2b_hex, b2a_hex
+from Crypto import Random
+from Crypto.PublicKey import ElGamal
+from Crypto.Util.number import bytes_to_long
+from Crypto.Util.py3compat import *
+
+class ElGamalTest(unittest.TestCase):
+
+ #
+ # Test vectors
+ #
+ # There seem to be no real ElGamal test vectors available in the
+ # public domain. The following test vectors have been generated
+ # with libgcrypt 1.5.0.
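+    # Each vector lists the key components (p, g, y, x), the per-message
+    # secret k passed to _encrypt()/_sign(), and the expected ciphertext or
+    # signature halves.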
+ #
+ # Encryption
+ tve=[
+ {
+ # 256 bits
+ 'p' :'BA4CAEAAED8CBE952AFD2126C63EB3B345D65C2A0A73D2A3AD4138B6D09BD933',
+ 'g' :'05',
+ 'y' :'60D063600ECED7C7C55146020E7A31C4476E9793BEAED420FEC9E77604CAE4EF',
+ 'x' :'1D391BA2EE3C37FE1BA175A69B2C73A11238AD77675932',
+ 'k' :'F5893C5BAB4131264066F57AB3D8AD89E391A0B68A68A1',
+ 'pt' :'48656C6C6F207468657265',
+ 'ct1':'32BFD5F487966CEA9E9356715788C491EC515E4ED48B58F0F00971E93AAA5EC7',
+ 'ct2':'7BE8FBFF317C93E82FCEF9BD515284BA506603FEA25D01C0CB874A31F315EE68'
+ },
+
+ {
+ # 512 bits
+ 'p' :'F1B18AE9F7B4E08FDA9A04832F4E919D89462FD31BF12F92791A93519F75076D6CE3942689CDFF2F344CAFF0F82D01864F69F3AECF566C774CBACF728B81A227',
+ 'g' :'07',
+ 'y' :'688628C676E4F05D630E1BE39D0066178CA7AA83836B645DE5ADD359B4825A12B02EF4252E4E6FA9BEC1DB0BE90F6D7C8629CABB6E531F472B2664868156E20C',
+ 'x' :'14E60B1BDFD33436C0DA8A22FDC14A2CCDBBED0627CE68',
+ 'k' :'38DBF14E1F319BDA9BAB33EEEADCAF6B2EA5250577ACE7',
+ 'pt' :'48656C6C6F207468657265',
+ 'ct1':'290F8530C2CC312EC46178724F196F308AD4C523CEABB001FACB0506BFED676083FE0F27AC688B5C749AB3CB8A80CD6F7094DBA421FB19442F5A413E06A9772B',
+ 'ct2':'1D69AAAD1DC50493FB1B8E8721D621D683F3BF1321BE21BC4A43E11B40C9D4D9C80DE3AAC2AB60D31782B16B61112E68220889D53C4C3136EE6F6CE61F8A23A0'
+ }
+ ]
+
+ # Signature
+ tvs=[
+ {
+ # 256 bits
+ 'p' :'D2F3C41EA66530838A704A48FFAC9334F4701ECE3A97CEE4C69DD01AE7129DD7',
+ 'g' :'05',
+ 'y' :'C3F9417DC0DAFEA6A05C1D2333B7A95E63B3F4F28CC962254B3256984D1012E7',
+ 'x' :'165E4A39BE44D5A2D8B1332D416BC559616F536BC735BB',
+ 'k' :'C7F0C794A7EAD726E25A47FF8928013680E73C51DD3D7D99BFDA8F492585928F',
+ 'h' :'48656C6C6F207468657265',
+ 'sig1':'35CA98133779E2073EF31165AFCDEB764DD54E96ADE851715495F9C635E1E7C2',
+ 'sig2':'0135B88B1151279FE5D8078D4FC685EE81177EE9802AB123A73925FC1CB059A7',
+ },
+ {
+ # 512 bits
+ 'p' :'E24CF3A4B8A6AF749DCA6D714282FE4AABEEE44A53BB6ED15FBE32B5D3C3EF9CC4124A2ECA331F3C1C1B667ACA3766825217E7B5F9856648D95F05330C6A19CF',
+ 'g' :'0B',
+ 'y' :'2AD3A1049CA5D4ED207B2431C79A8719BB4073D4A94E450EA6CEE8A760EB07ADB67C0D52C275EE85D7B52789061EE45F2F37D9B2AE522A51C28329766BFE68AC',
+ 'x' :'16CBB4F46D9ECCF24FF9F7E63CAA3BD8936341555062AB',
+ 'k' :'8A3D89A4E429FD2476D7D717251FB79BF900FFE77444E6BB8299DC3F84D0DD57ABAB50732AE158EA52F5B9E7D8813E81FD9F79470AE22F8F1CF9AEC820A78C69',
+ 'h' :'48656C6C6F207468657265',
+ 'sig1':'BE001AABAFFF976EC9016198FBFEA14CBEF96B000CCC0063D3324016F9E91FE80D8F9325812ED24DDB2B4D4CF4430B169880B3CE88313B53255BD4EC0378586F',
+ 'sig2':'5E266F3F837BA204E3BBB6DBECC0611429D96F8C7CE8F4EFDF9D4CB681C2A954468A357BF4242CEC7418B51DFC081BCD21299EF5B5A0DDEF3A139A1817503DDE',
+ }
+ ]
+
+ def test_generate_180(self):
+ self._test_random_key(180)
+
+ def test_encryption(self):
+ for tv in self.tve:
+ d = self.convert_tv(tv, True)
+ key = ElGamal.construct(d['key'])
+ ct = key._encrypt(d['pt'], d['k'])
+ self.assertEqual(ct[0], d['ct1'])
+ self.assertEqual(ct[1], d['ct2'])
+
+ def test_decryption(self):
+ for tv in self.tve:
+ d = self.convert_tv(tv, True)
+ key = ElGamal.construct(d['key'])
+ pt = key._decrypt((d['ct1'], d['ct2']))
+ self.assertEqual(pt, d['pt'])
+
+ def test_signing(self):
+ for tv in self.tvs:
+ d = self.convert_tv(tv, True)
+ key = ElGamal.construct(d['key'])
+ sig1, sig2 = key._sign(d['h'], d['k'])
+ self.assertEqual(sig1, d['sig1'])
+ self.assertEqual(sig2, d['sig2'])
+
+ def test_verification(self):
+ for tv in self.tvs:
+ d = self.convert_tv(tv, True)
+ key = ElGamal.construct(d['key'])
+ # Positive test
+ res = key._verify( d['h'], (d['sig1'],d['sig2']) )
+ self.assertTrue(res)
+ # Negative test
+ res = key._verify( d['h'], (d['sig1']+1,d['sig2']) )
+ self.assertFalse(res)
+
+ def test_bad_key3(self):
+        tup0 = list(self.convert_tv(self.tvs[0], 1)['key'])[:3]
+
+        tup = tup0[:]
+        tup[0] += 1                 # p += 1 (not prime)
+        self.assertRaises(ValueError, ElGamal.construct, tup)
+
+        tup = tup0[:]
+        tup[1] = 1                  # g = 1
+        self.assertRaises(ValueError, ElGamal.construct, tup)
+
+        tup = tup0[:]
+        tup[2] = tup[0]*2           # y = 2*p
+        self.assertRaises(ValueError, ElGamal.construct, tup)
+
+ def test_bad_key4(self):
+ tup = tup0 = list(self.convert_tv(self.tvs[0], 1)['key'])
+ tup[3] += 1 # x += 1
+ self.assertRaises(ValueError, ElGamal.construct, tup)
+
+ def convert_tv(self, tv, as_longs=0):
+ """Convert a test vector from textual form (hexadecimal ascii
+ to either integers or byte strings."""
+ key_comps = 'p','g','y','x'
+ tv2 = {}
+ for c in tv.keys():
+ tv2[c] = a2b_hex(tv[c])
+ if as_longs or c in key_comps or c in ('sig1','sig2'):
+ tv2[c] = bytes_to_long(tv2[c])
+ tv2['key']=[]
+ for c in key_comps:
+ tv2['key'] += [tv2[c]]
+ del tv2[c]
+ return tv2
+
+ def _test_random_key(self, bits):
+ elgObj = ElGamal.generate(bits, Random.new().read)
+ self._check_private_key(elgObj)
+ self._exercise_primitive(elgObj)
+ pub = elgObj.publickey()
+ self._check_public_key(pub)
+ self._exercise_public_primitive(elgObj)
+
+ def _check_private_key(self, elgObj):
+
+ # Check capabilities
+ self.assertTrue(elgObj.has_private())
+
+ # Sanity check key data
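+        # g and x must lie strictly between 1 and p-1, g**(p-1) must be
+        # 1 mod p (Fermat), and y must equal g**x mod p.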
+ self.assertTrue(1<elgObj.g<(elgObj.p-1))
+ self.assertEqual(pow(elgObj.g, elgObj.p-1, elgObj.p), 1)
+ self.assertTrue(1<elgObj.x<(elgObj.p-1))
+ self.assertEqual(pow(elgObj.g, elgObj.x, elgObj.p), elgObj.y)
+
+ def _check_public_key(self, elgObj):
+
+ # Check capabilities
+ self.assertFalse(elgObj.has_private())
+
+ # Sanity check key data
+ self.assertTrue(1<elgObj.g<(elgObj.p-1))
+ self.assertEqual(pow(elgObj.g, elgObj.p-1, elgObj.p), 1)
+
+ def _exercise_primitive(self, elgObj):
+ # Test encryption/decryption
+ plaintext = 127218
+ ciphertext = elgObj._encrypt(plaintext, 123456789)
+ plaintextP = elgObj._decrypt(ciphertext)
+ self.assertEqual(plaintext, plaintextP)
+
+ # Test signature/verification
+ signature = elgObj._sign(plaintext, 987654321)
+ elgObj._verify(plaintext, signature)
+
+ def _exercise_public_primitive(self, elgObj):
+ plaintext = 92987276
+ ciphertext = elgObj._encrypt(plaintext, 123456789)
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(ElGamalTest)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
diff --git a/lib/Crypto/SelfTest/PublicKey/test_RSA.py b/lib/Crypto/SelfTest/PublicKey/test_RSA.py
new file mode 100644
index 0000000..e7b5b90
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_RSA.py
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/PublicKey/test_RSA.py: Self-test for the RSA primitive
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.PublicKey.RSA"""
+
+__revision__ = "$Id$"
+
+import os
+import pickle
+from pickle import PicklingError
+from Crypto.Util.py3compat import *
+
+import unittest
+from Crypto.SelfTest.st_common import list_test_cases, a2b_hex, b2a_hex
+
+class RSATest(unittest.TestCase):
+ # Test vectors from "RSA-OAEP and RSA-PSS test vectors (.zip file)"
+ # ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1-vec.zip
+ # See RSADSI's PKCS#1 page at
+ # http://www.rsa.com/rsalabs/node.asp?id=2125
+
+ # from oaep-int.txt
+
+ # TODO: PyCrypto treats the message as starting *after* the leading "00"
+ # TODO: That behaviour should probably be changed in the future.
+ plaintext = """
+ eb 7a 19 ac e9 e3 00 63 50 e3 29 50 4b 45 e2
+ ca 82 31 0b 26 dc d8 7d 5c 68 f1 ee a8 f5 52 67
+ c3 1b 2e 8b b4 25 1f 84 d7 e0 b2 c0 46 26 f5 af
+ f9 3e dc fb 25 c9 c2 b3 ff 8a e1 0e 83 9a 2d db
+ 4c dc fe 4f f4 77 28 b4 a1 b7 c1 36 2b aa d2 9a
+ b4 8d 28 69 d5 02 41 21 43 58 11 59 1b e3 92 f9
+ 82 fb 3e 87 d0 95 ae b4 04 48 db 97 2f 3a c1 4f
+ 7b c2 75 19 52 81 ce 32 d2 f1 b7 6d 4d 35 3e 2d
+ """
+
+ ciphertext = """
+ 12 53 e0 4d c0 a5 39 7b b4 4a 7a b8 7e 9b f2 a0
+ 39 a3 3d 1e 99 6f c8 2a 94 cc d3 00 74 c9 5d f7
+ 63 72 20 17 06 9e 52 68 da 5d 1c 0b 4f 87 2c f6
+ 53 c1 1d f8 23 14 a6 79 68 df ea e2 8d ef 04 bb
+ 6d 84 b1 c3 1d 65 4a 19 70 e5 78 3b d6 eb 96 a0
+ 24 c2 ca 2f 4a 90 fe 9f 2e f5 c9 c1 40 e5 bb 48
+ da 95 36 ad 87 00 c8 4f c9 13 0a de a7 4e 55 8d
+ 51 a7 4d df 85 d8 b5 0d e9 68 38 d6 06 3e 09 55
+ """
+
+ modulus = """
+ bb f8 2f 09 06 82 ce 9c 23 38 ac 2b 9d a8 71 f7
+ 36 8d 07 ee d4 10 43 a4 40 d6 b6 f0 74 54 f5 1f
+ b8 df ba af 03 5c 02 ab 61 ea 48 ce eb 6f cd 48
+ 76 ed 52 0d 60 e1 ec 46 19 71 9d 8a 5b 8b 80 7f
+ af b8 e0 a3 df c7 37 72 3e e6 b4 b7 d9 3a 25 84
+ ee 6a 64 9d 06 09 53 74 88 34 b2 45 45 98 39 4e
+ e0 aa b1 2d 7b 61 a5 1f 52 7a 9a 41 f6 c1 68 7f
+ e2 53 72 98 ca 2a 8f 59 46 f8 e5 fd 09 1d bd cb
+ """
+
+ e = 0x11 # public exponent
+
+ prime_factor = """
+ c9 7f b1 f0 27 f4 53 f6 34 12 33 ea aa d1 d9 35
+ 3f 6c 42 d0 88 66 b1 d0 5a 0f 20 35 02 8b 9d 86
+ 98 40 b4 16 66 b4 2e 92 ea 0d a3 b4 32 04 b5 cf
+ ce 33 52 52 4d 04 16 a5 a4 41 e7 00 af 46 15 03
+ """
+
+ def setUp(self):
+ global RSA, Random, bytes_to_long
+ from Crypto.PublicKey import RSA
+ from Crypto import Random
+ from Crypto.Util.number import bytes_to_long, inverse
+ self.n = bytes_to_long(a2b_hex(self.modulus))
+ self.p = bytes_to_long(a2b_hex(self.prime_factor))
+
+ # Compute q, d, and u from n, e, and p
+ self.q = self.n // self.p
+ self.d = inverse(self.e, (self.p-1)*(self.q-1))
+ self.u = inverse(self.p, self.q) # u = p**-1 (mod q)
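+ # A minimal sketch of the relations assumed here (and essentially re-checked
+ # later in _check_private_key):
+ # n = p * q
+ # d * e == 1 (mod (p-1)*(q-1))
+ # p * u == 1 (mod q)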
+
+ self.rsa = RSA
+
+ def test_generate_1arg(self):
+ """RSA (default implementation) generated key (1 argument)"""
+ rsaObj = self.rsa.generate(1024)
+ self._check_private_key(rsaObj)
+ self._exercise_primitive(rsaObj)
+ pub = rsaObj.public_key()
+ self._check_public_key(pub)
+ self._exercise_public_primitive(rsaObj)
+
+ def test_generate_2arg(self):
+ """RSA (default implementation) generated key (2 arguments)"""
+ rsaObj = self.rsa.generate(1024, Random.new().read)
+ self._check_private_key(rsaObj)
+ self._exercise_primitive(rsaObj)
+ pub = rsaObj.public_key()
+ self._check_public_key(pub)
+ self._exercise_public_primitive(rsaObj)
+
+ def test_generate_3args(self):
+ """RSA (default implementation) generated key (3 arguments)"""
+ rsaObj = self.rsa.generate(1024, Random.new().read, e=65537)
+ self._check_private_key(rsaObj)
+ self._exercise_primitive(rsaObj)
+ pub = rsaObj.public_key()
+ self._check_public_key(pub)
+ self._exercise_public_primitive(rsaObj)
+ self.assertEqual(65537, rsaObj.e)
+
+ def test_construct_2tuple(self):
+ """RSA (default implementation) constructed key (2-tuple)"""
+ pub = self.rsa.construct((self.n, self.e))
+ self._check_public_key(pub)
+ self._check_encryption(pub)
+
+ def test_construct_3tuple(self):
+ """RSA (default implementation) constructed key (3-tuple)"""
+ rsaObj = self.rsa.construct((self.n, self.e, self.d))
+ self._check_encryption(rsaObj)
+ self._check_decryption(rsaObj)
+
+ def test_construct_4tuple(self):
+ """RSA (default implementation) constructed key (4-tuple)"""
+ rsaObj = self.rsa.construct((self.n, self.e, self.d, self.p))
+ self._check_encryption(rsaObj)
+ self._check_decryption(rsaObj)
+
+ def test_construct_5tuple(self):
+ """RSA (default implementation) constructed key (5-tuple)"""
+ rsaObj = self.rsa.construct((self.n, self.e, self.d, self.p, self.q))
+ self._check_private_key(rsaObj)
+ self._check_encryption(rsaObj)
+ self._check_decryption(rsaObj)
+
+ def test_construct_6tuple(self):
+ """RSA (default implementation) constructed key (6-tuple)"""
+ rsaObj = self.rsa.construct((self.n, self.e, self.d, self.p, self.q, self.u))
+ self._check_private_key(rsaObj)
+ self._check_encryption(rsaObj)
+ self._check_decryption(rsaObj)
+
+ def test_construct_bad_key2(self):
+ tup = (self.n, 1)
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ # An even modulus is wrong
+ tup = (self.n+1, self.e)
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ def test_construct_bad_key3(self):
+ tup = (self.n, self.e, self.d+1)
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ def test_construct_bad_key5(self):
+ tup = (self.n, self.e, self.d, self.p, self.p)
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ tup = (self.p*self.p, self.e, self.p, self.p)
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ tup = (self.p*self.p, 3, self.p, self.q)
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ def test_construct_bad_key6(self):
+ tup = (self.n, self.e, self.d, self.p, self.q, 10)
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ from Crypto.Util.number import inverse
+ tup = (self.n, self.e, self.d, self.p, self.q, inverse(self.q, self.p))
+ self.assertRaises(ValueError, self.rsa.construct, tup)
+
+ def test_factoring(self):
+ rsaObj = self.rsa.construct([self.n, self.e, self.d])
+ self.assertTrue(rsaObj.p==self.p or rsaObj.p==self.q)
+ self.assertTrue(rsaObj.q==self.p or rsaObj.q==self.q)
+ self.assertTrue(rsaObj.q*rsaObj.p == self.n)
+
+ self.assertRaises(ValueError, self.rsa.construct, [self.n, self.e, self.n-1])
+
+ def test_repr(self):
+ rsaObj = self.rsa.construct((self.n, self.e, self.d, self.p, self.q))
+ repr(rsaObj)
+
+ def test_serialization(self):
+ """RSA keys are unpickable"""
+
+ rsa_key = self.rsa.generate(1024)
+ self.assertRaises(PicklingError, pickle.dumps, rsa_key)
+
+ def test_raw_rsa_boundary(self):
+ # The argument of every RSA raw operation (encrypt/decrypt) must be
+ # non-negative and smaller than the modulus
+ rsa_obj = self.rsa.generate(1024)
+
+ self.assertRaises(ValueError, rsa_obj._decrypt, rsa_obj.n)
+ self.assertRaises(ValueError, rsa_obj._encrypt, rsa_obj.n)
+
+ self.assertRaises(ValueError, rsa_obj._decrypt, -1)
+ self.assertRaises(ValueError, rsa_obj._encrypt, -1)
+
+ def test_size(self):
+ pub = self.rsa.construct((self.n, self.e))
+ self.assertEqual(pub.size_in_bits(), 1024)
+ self.assertEqual(pub.size_in_bytes(), 128)
+
+ def _check_private_key(self, rsaObj):
+ from Crypto.Math.Numbers import Integer
+
+ # Check capabilities
+ self.assertEqual(1, rsaObj.has_private())
+
+ # Sanity check key data
+ self.assertEqual(rsaObj.n, rsaObj.p * rsaObj.q) # n = pq
+ lcm = int(Integer(rsaObj.p-1).lcm(rsaObj.q-1))
+ self.assertEqual(1, rsaObj.d * rsaObj.e % lcm) # ed = 1 (mod LCM(p-1, q-1))
+ self.assertEqual(1, rsaObj.p * rsaObj.u % rsaObj.q) # pu = 1 (mod q)
+ self.assertEqual(1, rsaObj.p > 1) # p > 1
+ self.assertEqual(1, rsaObj.q > 1) # q > 1
+ self.assertEqual(1, rsaObj.e > 1) # e > 1
+ self.assertEqual(1, rsaObj.d > 1) # d > 1
+
+ def _check_public_key(self, rsaObj):
+ ciphertext = a2b_hex(self.ciphertext)
+
+ # Check capabilities
+ self.assertEqual(0, rsaObj.has_private())
+
+ # Check rsaObj.[ne] -> rsaObj.[ne] mapping
+ self.assertEqual(rsaObj.n, rsaObj.n)
+ self.assertEqual(rsaObj.e, rsaObj.e)
+
+ # Check that private parameters are all missing
+ self.assertEqual(0, hasattr(rsaObj, 'd'))
+ self.assertEqual(0, hasattr(rsaObj, 'p'))
+ self.assertEqual(0, hasattr(rsaObj, 'q'))
+ self.assertEqual(0, hasattr(rsaObj, 'u'))
+
+ # Sanity check key data
+ self.assertEqual(1, rsaObj.e > 1) # e > 1
+
+ # Public keys should not be able to sign or decrypt
+ self.assertRaises(TypeError, rsaObj._decrypt,
+ bytes_to_long(ciphertext))
+
+ # Check __eq__ and __ne__
+ self.assertEqual(rsaObj.public_key() == rsaObj.public_key(), True) # i.e. assertTrue
+ self.assertEqual(rsaObj.public_key() != rsaObj.public_key(), False) # i.e. assertFalse
+
+ self.assertEqual(rsaObj.publickey(), rsaObj.public_key())
+
+ def _exercise_primitive(self, rsaObj):
+ # Since we're using a randomly-generated key, we can't check the test
+ # vector, but we can make sure encryption and decryption are inverse
+ # operations.
+ ciphertext = bytes_to_long(a2b_hex(self.ciphertext))
+
+ # Test decryption
+ plaintext = rsaObj._decrypt(ciphertext)
+
+ # Test encryption
+ new_ciphertext2 = rsaObj._encrypt(plaintext)
+ self.assertEqual(ciphertext, new_ciphertext2)
+
+ def _exercise_public_primitive(self, rsaObj):
+ plaintext = a2b_hex(self.plaintext)
+
+ # Test encryption
+ new_ciphertext2 = rsaObj._encrypt(bytes_to_long(plaintext))
+
+ def _check_encryption(self, rsaObj):
+ plaintext = a2b_hex(self.plaintext)
+ ciphertext = a2b_hex(self.ciphertext)
+
+ # Test encryption
+ new_ciphertext2 = rsaObj._encrypt(bytes_to_long(plaintext))
+ self.assertEqual(bytes_to_long(ciphertext), new_ciphertext2)
+
+ def _check_decryption(self, rsaObj):
+ plaintext = bytes_to_long(a2b_hex(self.plaintext))
+ ciphertext = bytes_to_long(a2b_hex(self.ciphertext))
+
+ # Test plain decryption
+ new_plaintext = rsaObj._decrypt(ciphertext)
+ self.assertEqual(plaintext, new_plaintext)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(RSATest)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/PublicKey/test_import_DSA.py b/lib/Crypto/SelfTest/PublicKey/test_import_DSA.py
new file mode 100644
index 0000000..266b46f
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_import_DSA.py
@@ -0,0 +1,554 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/PublicKey/test_import_DSA.py: Self-test for importing DSA keys
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+import unittest
+import re
+
+from Crypto.PublicKey import DSA
+from Crypto.SelfTest.st_common import *
+from Crypto.Util.py3compat import *
+
+from binascii import unhexlify
+
+class ImportKeyTests(unittest.TestCase):
+
+ y = 92137165128186062214622779787483327510946462589285775188003362705875131352591574106484271700740858696583623951844732128165434284507709057439633739849986759064015013893156866539696757799934634945787496920169462601722830899660681779448742875054459716726855443681559131362852474817534616736104831095601710736729
+ p = 162452170958135306109773853318304545923250830605675936228618290525164105310663722368377131295055868997377338797580997938253236213714988311430600065853662861806894003694743806769284131194035848116051021923956699231855223389086646903420682639786976554552864568460372266462812137447840653688476258666833303658691
+ q = 988791743931120302950649732173330531512663554851
+ g = 85583152299197514738065570254868711517748965097380456700369348466136657764813442044039878840094809620913085570225318356734366886985903212775602770761953571967834823306046501307810937486758039063386311593890777319935391363872375452381836756832784184928202587843258855704771836753434368484556809100537243908232
+ x = 540873410045082450874416847965843801027716145253
+
+ def setUp(self):
+
+ # It is easier to write test vectors in text form,
+ # and convert them to byte strings dynamically here
+ for mname, mvalue in ImportKeyTests.__dict__.items():
+ if mname[:4] in ('der_', 'pem_', 'ssh_'):
+ if mname[:4] == 'der_':
+ mvalue = unhexlify(tobytes(mvalue))
+ mvalue = tobytes(mvalue)
+ setattr(self, mname, mvalue)
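+ # For example, the hex text assigned to der_public below ends up as the raw
+ # DER byte string self.der_public, while pem_/ssh_ attributes are only
+ # re-encoded from str to bytes.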
+
+ # 1. SubjectPublicKeyInfo
+ der_public=\
+ '308201b73082012b06072a8648ce3804013082011e02818100e756ee1717f4b6'+\
+ '794c7c214724a19763742c45572b4b3f8ff3b44f3be9f44ce039a2757695ec91'+\
+ '5697da74ef914fcd1b05660e2419c761d639f45d2d79b802dbd23e7ab8b81b47'+\
+ '9a380e1f30932584ba2a0b955032342ebc83cb5ca906e7b0d7cd6fe656cecb4c'+\
+ '8b5a77123a8c6750a481e3b06057aff6aa6eba620b832d60c3021500ad32f48c'+\
+ 'd3ae0c45a198a61fa4b5e20320763b2302818079dfdc3d614fe635fceb7eaeae'+\
+ '3718dc2efefb45282993ac6749dc83c223d8c1887296316b3b0b54466cf444f3'+\
+ '4b82e3554d0b90a778faaf1306f025dae6a3e36c7f93dd5bac4052b92370040a'+\
+ 'ca70b8d5820599711900efbc961812c355dd9beffe0981da85c5548074b41c56'+\
+ 'ae43fd300d89262e4efd89943f99a651b03888038185000281810083352a69a1'+\
+ '32f34843d2a0eb995bff4e2f083a73f0049d2c91ea2f0ce43d144abda48199e4'+\
+ 'b003c570a8af83303d45105f606c5c48d925a40ed9c2630c2fa4cdbf838539de'+\
+ 'b9a29f919085f2046369f627ca84b2cb1e2c7940564b670f963ab1164d4e2ca2'+\
+ 'bf6ffd39f12f548928bf4d2d1b5e6980b4f1be4c92a91986fba559'
+
+ def testImportKey1(self):
+ key_obj = DSA.importKey(self.der_public)
+ self.assertFalse(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+
+ def testExportKey1(self):
+ tup = (self.y, self.g, self.p, self.q)
+ key = DSA.construct(tup)
+ encoded = key.export_key('DER')
+ self.assertEqual(self.der_public, encoded)
+
+ # 2.
+ pem_public="""\
+-----BEGIN PUBLIC KEY-----
+MIIBtzCCASsGByqGSM44BAEwggEeAoGBAOdW7hcX9LZ5THwhRyShl2N0LEVXK0s/
+j/O0Tzvp9EzgOaJ1dpXskVaX2nTvkU/NGwVmDiQZx2HWOfRdLXm4AtvSPnq4uBtH
+mjgOHzCTJYS6KguVUDI0LryDy1ypBuew181v5lbOy0yLWncSOoxnUKSB47BgV6/2
+qm66YguDLWDDAhUArTL0jNOuDEWhmKYfpLXiAyB2OyMCgYB539w9YU/mNfzrfq6u
+NxjcLv77RSgpk6xnSdyDwiPYwYhyljFrOwtURmz0RPNLguNVTQuQp3j6rxMG8CXa
+5qPjbH+T3VusQFK5I3AECspwuNWCBZlxGQDvvJYYEsNV3Zvv/gmB2oXFVIB0tBxW
+rkP9MA2JJi5O/YmUP5mmUbA4iAOBhQACgYEAgzUqaaEy80hD0qDrmVv/Ti8IOnPw
+BJ0skeovDOQ9FEq9pIGZ5LADxXCor4MwPUUQX2BsXEjZJaQO2cJjDC+kzb+DhTne
+uaKfkZCF8gRjafYnyoSyyx4seUBWS2cPljqxFk1OLKK/b/058S9UiSi/TS0bXmmA
+tPG+TJKpGYb7pVk=
+-----END PUBLIC KEY-----"""
+
+ def testImportKey2(self):
+ for pem in (self.pem_public, tostr(self.pem_public)):
+ key_obj = DSA.importKey(pem)
+ self.assertFalse(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+
+ def testExportKey2(self):
+ tup = (self.y, self.g, self.p, self.q)
+ key = DSA.construct(tup)
+ encoded = key.export_key('PEM')
+ self.assertEqual(self.pem_public, encoded)
+
+ # 3. OpenSSL/OpenSSH format
+ der_private=\
+ '308201bb02010002818100e756ee1717f4b6794c7c214724a19763742c45572b'+\
+ '4b3f8ff3b44f3be9f44ce039a2757695ec915697da74ef914fcd1b05660e2419'+\
+ 'c761d639f45d2d79b802dbd23e7ab8b81b479a380e1f30932584ba2a0b955032'+\
+ '342ebc83cb5ca906e7b0d7cd6fe656cecb4c8b5a77123a8c6750a481e3b06057'+\
+ 'aff6aa6eba620b832d60c3021500ad32f48cd3ae0c45a198a61fa4b5e2032076'+\
+ '3b2302818079dfdc3d614fe635fceb7eaeae3718dc2efefb45282993ac6749dc'+\
+ '83c223d8c1887296316b3b0b54466cf444f34b82e3554d0b90a778faaf1306f0'+\
+ '25dae6a3e36c7f93dd5bac4052b92370040aca70b8d5820599711900efbc9618'+\
+ '12c355dd9beffe0981da85c5548074b41c56ae43fd300d89262e4efd89943f99'+\
+ 'a651b038880281810083352a69a132f34843d2a0eb995bff4e2f083a73f0049d'+\
+ '2c91ea2f0ce43d144abda48199e4b003c570a8af83303d45105f606c5c48d925'+\
+ 'a40ed9c2630c2fa4cdbf838539deb9a29f919085f2046369f627ca84b2cb1e2c'+\
+ '7940564b670f963ab1164d4e2ca2bf6ffd39f12f548928bf4d2d1b5e6980b4f1'+\
+ 'be4c92a91986fba55902145ebd9a3f0b82069d98420986b314215025756065'
+
+ def testImportKey3(self):
+ key_obj = DSA.importKey(self.der_private)
+ self.assertTrue(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+ self.assertEqual(self.x, key_obj.x)
+
+ def testExportKey3(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ encoded = key.export_key('DER', pkcs8=False)
+ self.assertEqual(self.der_private, encoded)
+
+ # 4.
+ pem_private="""\
+-----BEGIN DSA PRIVATE KEY-----
+MIIBuwIBAAKBgQDnVu4XF/S2eUx8IUckoZdjdCxFVytLP4/ztE876fRM4DmidXaV
+7JFWl9p075FPzRsFZg4kGcdh1jn0XS15uALb0j56uLgbR5o4Dh8wkyWEuioLlVAy
+NC68g8tcqQbnsNfNb+ZWzstMi1p3EjqMZ1CkgeOwYFev9qpuumILgy1gwwIVAK0y
+9IzTrgxFoZimH6S14gMgdjsjAoGAed/cPWFP5jX8636urjcY3C7++0UoKZOsZ0nc
+g8Ij2MGIcpYxazsLVEZs9ETzS4LjVU0LkKd4+q8TBvAl2uaj42x/k91brEBSuSNw
+BArKcLjVggWZcRkA77yWGBLDVd2b7/4JgdqFxVSAdLQcVq5D/TANiSYuTv2JlD+Z
+plGwOIgCgYEAgzUqaaEy80hD0qDrmVv/Ti8IOnPwBJ0skeovDOQ9FEq9pIGZ5LAD
+xXCor4MwPUUQX2BsXEjZJaQO2cJjDC+kzb+DhTneuaKfkZCF8gRjafYnyoSyyx4s
+eUBWS2cPljqxFk1OLKK/b/058S9UiSi/TS0bXmmAtPG+TJKpGYb7pVkCFF69mj8L
+ggadmEIJhrMUIVAldWBl
+-----END DSA PRIVATE KEY-----"""
+
+ def testImportKey4(self):
+ for pem in (self.pem_private, tostr(self.pem_private)):
+ key_obj = DSA.importKey(pem)
+ self.assertTrue(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+ self.assertEqual(self.x, key_obj.x)
+
+ def testExportKey4(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ encoded = key.export_key('PEM', pkcs8=False)
+ self.assertEqual(self.pem_private, encoded)
+
+ # 5. PKCS8 (unencrypted)
+ der_pkcs8=\
+ '3082014a0201003082012b06072a8648ce3804013082011e02818100e756ee17'+\
+ '17f4b6794c7c214724a19763742c45572b4b3f8ff3b44f3be9f44ce039a27576'+\
+ '95ec915697da74ef914fcd1b05660e2419c761d639f45d2d79b802dbd23e7ab8'+\
+ 'b81b479a380e1f30932584ba2a0b955032342ebc83cb5ca906e7b0d7cd6fe656'+\
+ 'cecb4c8b5a77123a8c6750a481e3b06057aff6aa6eba620b832d60c3021500ad'+\
+ '32f48cd3ae0c45a198a61fa4b5e20320763b2302818079dfdc3d614fe635fceb'+\
+ '7eaeae3718dc2efefb45282993ac6749dc83c223d8c1887296316b3b0b54466c'+\
+ 'f444f34b82e3554d0b90a778faaf1306f025dae6a3e36c7f93dd5bac4052b923'+\
+ '70040aca70b8d5820599711900efbc961812c355dd9beffe0981da85c5548074'+\
+ 'b41c56ae43fd300d89262e4efd89943f99a651b03888041602145ebd9a3f0b82'+\
+ '069d98420986b314215025756065'
+
+ def testImportKey5(self):
+ key_obj = DSA.importKey(self.der_pkcs8)
+ self.assertTrue(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+ self.assertEqual(self.x, key_obj.x)
+
+ def testExportKey5(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ encoded = key.export_key('DER')
+ self.assertEqual(self.der_pkcs8, encoded)
+ encoded = key.export_key('DER', pkcs8=True)
+ self.assertEqual(self.der_pkcs8, encoded)
+
+ # 6.
+ pem_pkcs8="""\
+-----BEGIN PRIVATE KEY-----
+MIIBSgIBADCCASsGByqGSM44BAEwggEeAoGBAOdW7hcX9LZ5THwhRyShl2N0LEVX
+K0s/j/O0Tzvp9EzgOaJ1dpXskVaX2nTvkU/NGwVmDiQZx2HWOfRdLXm4AtvSPnq4
+uBtHmjgOHzCTJYS6KguVUDI0LryDy1ypBuew181v5lbOy0yLWncSOoxnUKSB47Bg
+V6/2qm66YguDLWDDAhUArTL0jNOuDEWhmKYfpLXiAyB2OyMCgYB539w9YU/mNfzr
+fq6uNxjcLv77RSgpk6xnSdyDwiPYwYhyljFrOwtURmz0RPNLguNVTQuQp3j6rxMG
+8CXa5qPjbH+T3VusQFK5I3AECspwuNWCBZlxGQDvvJYYEsNV3Zvv/gmB2oXFVIB0
+tBxWrkP9MA2JJi5O/YmUP5mmUbA4iAQWAhRevZo/C4IGnZhCCYazFCFQJXVgZQ==
+-----END PRIVATE KEY-----"""
+
+ def testImportKey6(self):
+ for pem in (self.pem_pkcs8, tostr(self.pem_pkcs8)):
+ key_obj = DSA.importKey(pem)
+ self.assertTrue(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+ self.assertEqual(self.x, key_obj.x)
+
+ def testExportKey6(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ encoded = key.export_key('PEM')
+ self.assertEqual(self.pem_pkcs8, encoded)
+ encoded = key.export_key('PEM', pkcs8=True)
+ self.assertEqual(self.pem_pkcs8, encoded)
+
+ # 7. OpenSSH/RFC4253
+ ssh_pub="""ssh-dss AAAAB3NzaC1kc3MAAACBAOdW7hcX9LZ5THwhRyShl2N0LEVXK0s/j/O0Tzvp9EzgOaJ1dpXskVaX2nTvkU/NGwVmDiQZx2HWOfRdLXm4AtvSPnq4uBtHmjgOHzCTJYS6KguVUDI0LryDy1ypBuew181v5lbOy0yLWncSOoxnUKSB47BgV6/2qm66YguDLWDDAAAAFQCtMvSM064MRaGYph+kteIDIHY7IwAAAIB539w9YU/mNfzrfq6uNxjcLv77RSgpk6xnSdyDwiPYwYhyljFrOwtURmz0RPNLguNVTQuQp3j6rxMG8CXa5qPjbH+T3VusQFK5I3AECspwuNWCBZlxGQDvvJYYEsNV3Zvv/gmB2oXFVIB0tBxWrkP9MA2JJi5O/YmUP5mmUbA4iAAAAIEAgzUqaaEy80hD0qDrmVv/Ti8IOnPwBJ0skeovDOQ9FEq9pIGZ5LADxXCor4MwPUUQX2BsXEjZJaQO2cJjDC+kzb+DhTneuaKfkZCF8gRjafYnyoSyyx4seUBWS2cPljqxFk1OLKK/b/058S9UiSi/TS0bXmmAtPG+TJKpGYb7pVk="""
+
+ def testImportKey7(self):
+ for ssh in (self.ssh_pub, tostr(self.ssh_pub)):
+ key_obj = DSA.importKey(ssh)
+ self.assertFalse(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+
+ def testExportKey7(self):
+ tup = (self.y, self.g, self.p, self.q)
+ key = DSA.construct(tup)
+ encoded = key.export_key('OpenSSH')
+ self.assertEqual(self.ssh_pub, encoded)
+
+ # 8. Encrypted OpenSSL/OpenSSH
+ pem_private_encrypted="""\
+-----BEGIN DSA PRIVATE KEY-----
+Proc-Type: 4,ENCRYPTED
+DEK-Info: AES-128-CBC,70B6908939D65E9F2EB999E8729788CE
+
+4V6GHRDpCrdZ8MBjbyp5AlGUrjvr2Pn2e2zVxy5RBt4FBj9/pa0ae0nnyUPMLSUU
+kKyOR0topRYTVRLElm4qVrb5uNZ3hRwfbklr+pSrB7O9eHz9V5sfOQxyODS07JxK
+k1OdOs70/ouMXLF9EWfAZOmWUccZKHNblUwg1p1UrZIz5jXw4dUE/zqhvXh6d+iC
+ADsICaBCjCrRQJKDp50h3+ndQjkYBKVH+pj8TiQ79U7lAvdp3+iMghQN6YXs9mdI
+gFpWw/f97oWM4GHZFqHJ+VSMNFjBiFhAvYV587d7Lk4dhD8sCfbxj42PnfRgUItc
+nnPqHxmhMQozBWzYM4mQuo3XbF2WlsNFbOzFVyGhw1Bx1s91qvXBVWJh2ozrW0s6
+HYDV7ZkcTml/4kjA/d+mve6LZ8kuuR1qCiZx6rkffhh1gDN/1Xz3HVvIy/dQ+h9s
+5zp7PwUoWbhqp3WCOr156P6gR8qo7OlT6wMh33FSXK/mxikHK136fV2shwTKQVII
+rJBvXpj8nACUmi7scKuTWGeUoXa+dwTZVVe+b+L2U1ZM7+h/neTJiXn7u99PFUwu
+xVJtxaV37m3aXxtCsPnbBg==
+-----END DSA PRIVATE KEY-----"""
+
+ def testImportKey8(self):
+ for pem in (self.pem_private_encrypted, tostr(self.pem_private_encrypted)):
+ key_obj = DSA.importKey(pem, "PWDTEST")
+ self.assertTrue(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+ self.assertEqual(self.x, key_obj.x)
+
+ def testExportKey8(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ encoded = key.export_key('PEM', pkcs8=False, passphrase="PWDTEST")
+ key = DSA.importKey(encoded, "PWDTEST")
+ self.assertEqual(self.y, key.y)
+ self.assertEqual(self.p, key.p)
+ self.assertEqual(self.q, key.q)
+ self.assertEqual(self.g, key.g)
+ self.assertEqual(self.x, key.x)
+
+ # 9. Encrypted PKCS8
+ # pbeWithMD5AndDES-CBC
+ pem_pkcs8_encrypted="""\
+-----BEGIN ENCRYPTED PRIVATE KEY-----
+MIIBcTAbBgkqhkiG9w0BBQMwDgQI0GC3BJ/jSw8CAggABIIBUHc1cXZpExIE9tC7
+7ryiW+5ihtF2Ekurq3e408GYSAu5smJjN2bvQXmzRFBz8W38K8eMf1sbWroZ4+zn
+kZSbb9nSm5kAa8lR2+oF2k+WRswMR/PTC3f/D9STO2X0QxdrzKgIHEcSGSHp5jTx
+aVvbkCDHo9vhBTl6S3ogZ48As/MEro76+9igUwJ1jNhIQZPJ7e20QH5qDpQFFJN4
+CKl2ENSEuwGiqBszItFy4dqH0g63ZGZV/xt9wSO9Rd7SK/EbA/dklOxBa5Y/VItM
+gnIhs9XDMoGYyn6F023EicNJm6g/bVQk81BTTma4tm+12TKGdYm+QkeZvCOMZylr
+Wv67cKwO3cAXt5C3QXMDgYR64XvuaT5h7C0igMp2afSXJlnbHEbFxQVJlv83T4FM
+eZ4k+NQDbEL8GiHmFxzDWQAuPPZKJWEEEV2p/To+WOh+kSDHQw==
+-----END ENCRYPTED PRIVATE KEY-----"""
+
+ def testImportKey9(self):
+ for pem in (self.pem_pkcs8_encrypted, tostr(self.pem_pkcs8_encrypted)):
+ key_obj = DSA.importKey(pem, "PWDTEST")
+ self.assertTrue(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+ self.assertEqual(self.x, key_obj.x)
+
+ # 10. Encrypted PKCS8
+ # pkcs5PBES2 /
+ # pkcs5PBKDF2 (rounds=1000, salt=D725BF1B6B8239F4) /
+ # des-EDE3-CBC (iv=27A1C66C42AFEECE)
+ #
+ der_pkcs8_encrypted=\
+ '30820196304006092a864886f70d01050d3033301b06092a864886f70d01050c'+\
+ '300e0408d725bf1b6b8239f4020203e8301406082a864886f70d0307040827a1'+\
+ 'c66c42afeece048201505cacfde7bf8edabb3e0d387950dc872662ea7e9b1ed4'+\
+ '400d2e7e6186284b64668d8d0328c33a9d9397e6f03df7cb68268b0a06b4e22f'+\
+ '7d132821449ecf998a8b696dbc6dd2b19e66d7eb2edfeb4153c1771d49702395'+\
+ '4f36072868b5fcccf93413a5ac4b2eb47d4b3f681c6bd67ae363ed776f45ae47'+\
+ '174a00098a7c930a50f820b227ddf50f9742d8e950d02586ff2dac0e3c372248'+\
+ 'e5f9b6a7a02f4004f20c87913e0f7b52bccc209b95d478256a890b31d4c9adec'+\
+ '21a4d157a179a93a3dad06f94f3ce486b46dfa7fc15fd852dd7680bbb2f17478'+\
+ '7e71bd8dbaf81eca7518d76c1d26256e95424864ba45ca5d47d7c5a421be02fa'+\
+ 'b94ab01e18593f66cf9094eb5c94b9ecf3aa08b854a195cf87612fbe5e96c426'+\
+ '2b0d573e52dc71ba3f5e468c601e816c49b7d32c698b22175e89aaef0c443770'+\
+ '5ef2f88a116d99d8e2869a4fd09a771b84b49e4ccb79aadcb1c9'
+
+ def testImportKey10(self):
+ key_obj = DSA.importKey(self.der_pkcs8_encrypted, "PWDTEST")
+ self.assertTrue(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+ self.assertEqual(self.x, key_obj.x)
+
+ def testExportKey10(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ randfunc = BytesIO(unhexlify(b("27A1C66C42AFEECE") + b("D725BF1B6B8239F4"))).read
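+ # randfunc replays a fixed byte string instead of fresh randomness; the two
+ # hex values are the same IV and salt quoted in the comment above, which is
+ # what makes the encrypted DER export deterministic and byte-comparable.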
+ encoded = key.export_key('DER', pkcs8=True, passphrase="PWDTEST", randfunc=randfunc)
+ self.assertEqual(self.der_pkcs8_encrypted, encoded)
+
+ # ----
+
+ def testImportError1(self):
+ self.assertRaises(ValueError, DSA.importKey, self.der_pkcs8_encrypted, "wrongpwd")
+
+ def testExportError2(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ self.assertRaises(ValueError, key.export_key, 'DER', pkcs8=False, passphrase="PWDTEST")
+
+ def test_import_key(self):
+ """Verify importKey is an alias to import_key"""
+
+ key_obj = DSA.import_key(self.der_public)
+ self.assertFalse(key_obj.has_private())
+ self.assertEqual(self.y, key_obj.y)
+ self.assertEqual(self.p, key_obj.p)
+ self.assertEqual(self.q, key_obj.q)
+ self.assertEqual(self.g, key_obj.g)
+
+ def test_exportKey(self):
+ tup = (self.y, self.g, self.p, self.q, self.x)
+ key = DSA.construct(tup)
+ self.assertEqual(key.exportKey(), key.export_key())
+
+
+ def test_import_empty(self):
+ self.assertRaises(ValueError, DSA.import_key, b'')
+
+
+class ImportKeyFromX509Cert(unittest.TestCase):
+
+ def test_x509v1(self):
+
+ # Sample V1 certificate with a 1024 bit DSA key
+ x509_v1_cert = """
+-----BEGIN CERTIFICATE-----
+MIIDUjCCArsCAQIwDQYJKoZIhvcNAQEFBQAwfjENMAsGA1UEChMEQWNtZTELMAkG
+A1UECxMCUkQxHDAaBgkqhkiG9w0BCQEWDXNwYW1AYWNtZS5vcmcxEzARBgNVBAcT
+Ck1ldHJvcG9saXMxETAPBgNVBAgTCE5ldyBZb3JrMQswCQYDVQQGEwJVUzENMAsG
+A1UEAxMEdGVzdDAeFw0xNDA3MTEyMDM4NDNaFw0xNzA0MDYyMDM4NDNaME0xCzAJ
+BgNVBAYTAlVTMREwDwYDVQQIEwhOZXcgWW9yazENMAsGA1UEChMEQWNtZTELMAkG
+A1UECxMCUkQxDzANBgNVBAMTBnBvbGFuZDCCAbYwggErBgcqhkjOOAQBMIIBHgKB
+gQDOrN4Ox4+t3T6wKeHfhzArhcrNEFMQ4Ss+4PIKyimDy9Bn64WPkL1B/9dvYIga
+23GLu6tVJmXo6EdJnVOHEMhr99EeOwuDWWeP7Awq7RSlKEejokr4BEzMTW/tExSD
+cO6/GI7xzh0eTH+VTTPDfyrJMYCkh0rJAfCP+5xrmPNetwIVALtXYOV1yoRrzJ2Q
+M5uEjidH6GiZAoGAfUqA1SAm5g5U68SILMVX9l5rq0OpB0waBMpJQ31/R/yXNDqo
+c3gGWZTOJFU4IzwNpGhrGNADUByz/lc1SAOAdEJIr0JVrhbGewQjB4pWqoLGbBKz
+RoavTNDc/zD7SYa12evWDHADwvlXoeQg+lWop1zS8OqaDC7aLGKpWN3/m8kDgYQA
+AoGAKoirPAfcp1rbbl4y2FFAIktfW8f4+T7d2iKSg73aiVfujhNOt1Zz1lfC0NI2
+eonLWO3tAM4XGKf1TLjb5UXngGn40okPsaA81YE6ZIKm20ywjlOY3QkAEdMaLVY3
+9PJvM8RGB9m7pLKxyHfGMfF40MVN4222zKeGp7xhM0CNiCUwDQYJKoZIhvcNAQEF
+BQADgYEAfbNZfpYa2KlALEM1FZnwvQDvJHntHz8LdeJ4WM7CXDlKi67wY2HKM30w
+s2xej75imkVOFd1kF2d0A8sjfriXLVIt1Hwq9ANZomhu4Edx0xpH8tqdh/bDtnM2
+TmduZNY9OWkb07h0CtWD6Zt8fhRllVsSSrlWd/2or7FXNC5weFQ=
+-----END CERTIFICATE-----
+ """.strip()
+
+ # DSA public key as dumped by openssl
+ y_str = """
+2a:88:ab:3c:07:dc:a7:5a:db:6e:5e:32:d8:51:40:
+22:4b:5f:5b:c7:f8:f9:3e:dd:da:22:92:83:bd:da:
+89:57:ee:8e:13:4e:b7:56:73:d6:57:c2:d0:d2:36:
+7a:89:cb:58:ed:ed:00:ce:17:18:a7:f5:4c:b8:db:
+e5:45:e7:80:69:f8:d2:89:0f:b1:a0:3c:d5:81:3a:
+64:82:a6:db:4c:b0:8e:53:98:dd:09:00:11:d3:1a:
+2d:56:37:f4:f2:6f:33:c4:46:07:d9:bb:a4:b2:b1:
+c8:77:c6:31:f1:78:d0:c5:4d:e3:6d:b6:cc:a7:86:
+a7:bc:61:33:40:8d:88:25
+ """
+ p_str = """
+00:ce:ac:de:0e:c7:8f:ad:dd:3e:b0:29:e1:df:87:
+30:2b:85:ca:cd:10:53:10:e1:2b:3e:e0:f2:0a:ca:
+29:83:cb:d0:67:eb:85:8f:90:bd:41:ff:d7:6f:60:
+88:1a:db:71:8b:bb:ab:55:26:65:e8:e8:47:49:9d:
+53:87:10:c8:6b:f7:d1:1e:3b:0b:83:59:67:8f:ec:
+0c:2a:ed:14:a5:28:47:a3:a2:4a:f8:04:4c:cc:4d:
+6f:ed:13:14:83:70:ee:bf:18:8e:f1:ce:1d:1e:4c:
+7f:95:4d:33:c3:7f:2a:c9:31:80:a4:87:4a:c9:01:
+f0:8f:fb:9c:6b:98:f3:5e:b7
+ """
+ q_str = """
+00:bb:57:60:e5:75:ca:84:6b:cc:9d:90:33:9b:84:
+8e:27:47:e8:68:99
+ """
+ g_str = """
+7d:4a:80:d5:20:26:e6:0e:54:eb:c4:88:2c:c5:57:
+f6:5e:6b:ab:43:a9:07:4c:1a:04:ca:49:43:7d:7f:
+47:fc:97:34:3a:a8:73:78:06:59:94:ce:24:55:38:
+23:3c:0d:a4:68:6b:18:d0:03:50:1c:b3:fe:57:35:
+48:03:80:74:42:48:af:42:55:ae:16:c6:7b:04:23:
+07:8a:56:aa:82:c6:6c:12:b3:46:86:af:4c:d0:dc:
+ff:30:fb:49:86:b5:d9:eb:d6:0c:70:03:c2:f9:57:
+a1:e4:20:fa:55:a8:a7:5c:d2:f0:ea:9a:0c:2e:da:
+2c:62:a9:58:dd:ff:9b:c9
+ """
+
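+ # The loop below looks each component's dump up via locals() and strips the
+ # colons, spaces and newlines so that the openssl output parses as one big
+ # hexadecimal integer.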
+ key = DSA.importKey(x509_v1_cert)
+ for comp_name in ('y', 'p', 'q', 'g'):
+ comp_str = locals()[comp_name + "_str"]
+ comp = int(re.sub("[^0-9a-f]", "", comp_str), 16)
+ self.assertEqual(getattr(key, comp_name), comp)
+ self.assertFalse(key.has_private())
+
+ def test_x509v3(self):
+
+ # Sample V3 certificate with a 1024 bit DSA key
+ x509_v3_cert = """
+-----BEGIN CERTIFICATE-----
+MIIFhjCCA26gAwIBAgIBAzANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJVUzEL
+MAkGA1UECAwCTUQxEjAQBgNVBAcMCUJhbHRpbW9yZTEQMA4GA1UEAwwHVGVzdCBD
+QTEfMB0GCSqGSIb3DQEJARYQdGVzdEBleGFtcGxlLmNvbTAeFw0xNDA3MTMyMDUz
+MjBaFw0xNzA0MDgyMDUzMjBaMEAxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJNRDES
+MBAGA1UEBwwJQmFsdGltb3JlMRAwDgYDVQQDDAdhdXN0cmlhMIIBtjCCASsGByqG
+SM44BAEwggEeAoGBALfd8gyEpVPA0ZI69Kp3nyJcu5N0ZZ3K1K9hleQLNqKEcZOh
+7a/C2J1TPdmHTLJ0rAwBZ1nWxnARSgRphziGDFspKCYQwYcSMz8KoFgvXbXpuchy
+oFACiQ2LqZnc5MakuLQtLcQciSYGYj3zmZdYMoa904F1aDWr+DxQI6DVC3/bAhUA
+hqXMCJ6fQK3G2O9S3/CC/yVZXCsCgYBRXROl3R2khX7l10LQjDEgo3B1IzjXU/jP
+McMBl6XO+nBJXxr/scbq8Ajiv7LTnGpSjgryHtvfj887kfvo8QbSS3kp3vq5uSqI
+ui7E7r3jguWaLj616AG1HWOctXJUjqsiabZwsp2h09gHTzmHEXBOmiARu8xFxKAH
+xsuo7onAbwOBhAACgYBylWjWSnKHE8mHx1A5m/0GQx6xnhWIe3+MJAnEhRGxA2J4
+SCsfWU0OwglIQToh1z5uUU9oDi9cYgNPBevOFRnDhc2yaJY6VAYnI+D+6J5IU6Yd
+0iaG/iSc4sV4bFr0axcPpse3SN0XaQxiKeSFBfFnoMqL+dd9Gb3QPZSllBcVD6OB
+1TCB0jAdBgNVHQ4EFgQUx5wN0Puotv388M9Tp/fsPbZpzAUwHwYDVR0jBBgwFoAU
+a0hkif3RMaraiWtsOOZZlLu9wJwwCQYDVR0TBAIwADALBgNVHQ8EBAMCBeAwSgYD
+VR0RBEMwQYILZXhhbXBsZS5jb22CD3d3dy5leGFtcGxlLmNvbYIQbWFpbC5leGFt
+cGxlLmNvbYIPZnRwLmV4YW1wbGUuY29tMCwGCWCGSAGG+EIBDQQfFh1PcGVuU1NM
+IEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTANBgkqhkiG9w0BAQsFAAOCAgEAyWf1TiJI
+aNEIA9o/PG8/JiGASTS2/HBVTJbkq03k6NkJVk/GxC1DPziTUJ+CdWlHWcAi1EOW
+Ach3QxNDRrVfCOfCMDgElIO1094/reJgdFYG00LRi8QkRJuxANV7YS4tLudhyHJC
+kR2lhdMNmEuzWK+s2y+5cLrdm7qdvdENQCcV67uvGPx4sc+EaE7x13SczKjWBtbo
+QCs6JTOW+EkPRl4Zo27K4OIZ43/J+GxvwU9QUVH3wPVdbbLNw+QeTFBYMTEcxyc4
+kv50HPBFaithziXBFyvdIs19FjkFzu0Uz/e0zb1+vMzQlJMD94HVOrMnIj5Sb2cL
+KKdYXS4uhxFJmdV091Xur5JkYYwEzuaGav7J3zOzYutrIGTgDluLCvA+VQkRcTsy
+jZ065SkY/v+38QHp+cmm8WRluupJTs8wYzVp6Fu0iFaaK7ztFmaZmHpiPIfDFjva
+aCIgzzT5NweJd/b71A2SyzHXJ14zBXsr1PMylMp2TpHIidhuuNuQL6I0HaollB4M
+Z3FsVBMhVDw4Z76qnFPr8mZE2tar33hSlJI/3pS/bBiukuBk8U7VB0X8OqaUnP3C
+7b2Z4G8GtqDVcKGMzkvMjT4n9rKd/Le+qHSsQOGO9W/0LB7UDAZSwUsfAPnoBgdS
+5t9tIomLCOstByXi+gGZue1TcdCa3Ph4kO0=
+-----END CERTIFICATE-----
+ """.strip()
+
+ # DSA public key as dumped by openssl
+ y_str = """
+72:95:68:d6:4a:72:87:13:c9:87:c7:50:39:9b:fd:
+06:43:1e:b1:9e:15:88:7b:7f:8c:24:09:c4:85:11:
+b1:03:62:78:48:2b:1f:59:4d:0e:c2:09:48:41:3a:
+21:d7:3e:6e:51:4f:68:0e:2f:5c:62:03:4f:05:eb:
+ce:15:19:c3:85:cd:b2:68:96:3a:54:06:27:23:e0:
+fe:e8:9e:48:53:a6:1d:d2:26:86:fe:24:9c:e2:c5:
+78:6c:5a:f4:6b:17:0f:a6:c7:b7:48:dd:17:69:0c:
+62:29:e4:85:05:f1:67:a0:ca:8b:f9:d7:7d:19:bd:
+d0:3d:94:a5:94:17:15:0f
+ """
+ p_str = """
+00:b7:dd:f2:0c:84:a5:53:c0:d1:92:3a:f4:aa:77:
+9f:22:5c:bb:93:74:65:9d:ca:d4:af:61:95:e4:0b:
+36:a2:84:71:93:a1:ed:af:c2:d8:9d:53:3d:d9:87:
+4c:b2:74:ac:0c:01:67:59:d6:c6:70:11:4a:04:69:
+87:38:86:0c:5b:29:28:26:10:c1:87:12:33:3f:0a:
+a0:58:2f:5d:b5:e9:b9:c8:72:a0:50:02:89:0d:8b:
+a9:99:dc:e4:c6:a4:b8:b4:2d:2d:c4:1c:89:26:06:
+62:3d:f3:99:97:58:32:86:bd:d3:81:75:68:35:ab:
+f8:3c:50:23:a0:d5:0b:7f:db
+ """
+ q_str = """
+00:86:a5:cc:08:9e:9f:40:ad:c6:d8:ef:52:df:f0:
+82:ff:25:59:5c:2b
+ """
+ g_str = """
+51:5d:13:a5:dd:1d:a4:85:7e:e5:d7:42:d0:8c:31:
+20:a3:70:75:23:38:d7:53:f8:cf:31:c3:01:97:a5:
+ce:fa:70:49:5f:1a:ff:b1:c6:ea:f0:08:e2:bf:b2:
+d3:9c:6a:52:8e:0a:f2:1e:db:df:8f:cf:3b:91:fb:
+e8:f1:06:d2:4b:79:29:de:fa:b9:b9:2a:88:ba:2e:
+c4:ee:bd:e3:82:e5:9a:2e:3e:b5:e8:01:b5:1d:63:
+9c:b5:72:54:8e:ab:22:69:b6:70:b2:9d:a1:d3:d8:
+07:4f:39:87:11:70:4e:9a:20:11:bb:cc:45:c4:a0:
+07:c6:cb:a8:ee:89:c0:6f
+ """
+
+ key = DSA.importKey(x509_v3_cert)
+ for comp_name in ('y', 'p', 'q', 'g'):
+ comp_str = locals()[comp_name + "_str"]
+ comp = int(re.sub("[^0-9a-f]", "", comp_str), 16)
+ self.assertEqual(getattr(key, comp_name), comp)
+ self.assertFalse(key.has_private())
+
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(ImportKeyTests)
+ tests += list_test_cases(ImportKeyFromX509Cert)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
diff --git a/lib/Crypto/SelfTest/PublicKey/test_import_ECC.py b/lib/Crypto/SelfTest/PublicKey/test_import_ECC.py
new file mode 100644
index 0000000..f9222c8
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_import_ECC.py
@@ -0,0 +1,2643 @@
+# ===================================================================
+#
+# Copyright (c) 2015, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import os
+import errno
+import warnings
+import unittest
+from binascii import unhexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.py3compat import bord, tostr, FileNotFoundError
+from Crypto.Util.asn1 import DerSequence, DerBitString
+from Crypto.Util.number import bytes_to_long
+from Crypto.Hash import SHAKE128
+
+from Crypto.PublicKey import ECC
+
+try:
+ import pycryptodome_test_vectors # type: ignore
+ test_vectors_available = True
+except ImportError:
+ test_vectors_available = False
+
+
+class MissingTestVectorException(ValueError):
+ pass
+
+
+def load_file(file_name, mode="rb"):
+ results = None
+
+ try:
+ if not test_vectors_available:
+ raise FileNotFoundError(errno.ENOENT,
+ os.strerror(errno.ENOENT),
+ file_name)
+
+ dir_comps = ("PublicKey", "ECC")
+ init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
+ full_file_name = os.path.join(os.path.join(init_dir, *dir_comps), file_name)
+ with open(full_file_name, mode) as file_in:
+ results = file_in.read()
+
+ except FileNotFoundError:
+ warnings.warn("Warning: skipping extended tests for ECC",
+ UserWarning,
+ stacklevel=2)
+
+ if results is None:
+ raise MissingTestVectorException("Missing %s" % file_name)
+
+ return results
+
+
+def compact(lines):
+ ext = b"".join(lines)
+ return unhexlify(tostr(ext).replace(" ", "").replace(":", ""))
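+# compact() glues the hex-dump lines of a test vector back together; as an
+# illustration (made-up input), compact([b"00:a1:b2 ", b"c3"]) would return
+# the same bytes as unhexlify("00a1b2c3").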
+
+
+def create_ref_keys_p192():
+ key_len = 24
+ key_lines = load_file("ecc_p192.txt").splitlines()
+ private_key_d = bytes_to_long(compact(key_lines[2:4]))
+ public_key_xy = compact(key_lines[5:9])
+ assert bord(public_key_xy[0]) == 4 # Uncompressed
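+ # SEC1 uncompressed point encoding: a 0x04 tag byte followed by the X and Y
+ # coordinates, each key_len bytes wide, which is what the two slices below
+ # split apart.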
+ public_key_x = bytes_to_long(public_key_xy[1:key_len+1])
+ public_key_y = bytes_to_long(public_key_xy[key_len+1:])
+
+ return (ECC.construct(curve="P-192", d=private_key_d),
+ ECC.construct(curve="P-192", point_x=public_key_x, point_y=public_key_y))
+
+
+def create_ref_keys_p224():
+ key_len = 28
+ key_lines = load_file("ecc_p224.txt").splitlines()
+ private_key_d = bytes_to_long(compact(key_lines[2:4]))
+ public_key_xy = compact(key_lines[5:9])
+ assert bord(public_key_xy[0]) == 4 # Uncompressed
+ public_key_x = bytes_to_long(public_key_xy[1:key_len+1])
+ public_key_y = bytes_to_long(public_key_xy[key_len+1:])
+
+ return (ECC.construct(curve="P-224", d=private_key_d),
+ ECC.construct(curve="P-224", point_x=public_key_x, point_y=public_key_y))
+
+
+def create_ref_keys_p256():
+ key_len = 32
+ key_lines = load_file("ecc_p256.txt").splitlines()
+ private_key_d = bytes_to_long(compact(key_lines[2:5]))
+ public_key_xy = compact(key_lines[6:11])
+ assert bord(public_key_xy[0]) == 4 # Uncompressed
+ public_key_x = bytes_to_long(public_key_xy[1:key_len+1])
+ public_key_y = bytes_to_long(public_key_xy[key_len+1:])
+
+ return (ECC.construct(curve="P-256", d=private_key_d),
+ ECC.construct(curve="P-256", point_x=public_key_x, point_y=public_key_y))
+
+
+def create_ref_keys_p384():
+ key_len = 48
+ key_lines = load_file("ecc_p384.txt").splitlines()
+ private_key_d = bytes_to_long(compact(key_lines[2:6]))
+ public_key_xy = compact(key_lines[7:14])
+ assert bord(public_key_xy[0]) == 4 # Uncompressed
+ public_key_x = bytes_to_long(public_key_xy[1:key_len+1])
+ public_key_y = bytes_to_long(public_key_xy[key_len+1:])
+
+ return (ECC.construct(curve="P-384", d=private_key_d),
+ ECC.construct(curve="P-384", point_x=public_key_x, point_y=public_key_y))
+
+
+def create_ref_keys_p521():
+ key_len = 66
+ key_lines = load_file("ecc_p521.txt").splitlines()
+ private_key_d = bytes_to_long(compact(key_lines[2:7]))
+ public_key_xy = compact(key_lines[8:17])
+ assert bord(public_key_xy[0]) == 4 # Uncompressed
+ public_key_x = bytes_to_long(public_key_xy[1:key_len+1])
+ public_key_y = bytes_to_long(public_key_xy[key_len+1:])
+
+ return (ECC.construct(curve="P-521", d=private_key_d),
+ ECC.construct(curve="P-521", point_x=public_key_x, point_y=public_key_y))
+
+
+def create_ref_keys_ed25519():
+ key_lines = load_file("ecc_ed25519.txt").splitlines()
+ seed = compact(key_lines[5:8])
+ key = ECC.construct(curve="Ed25519", seed=seed)
+ return (key, key.public_key())
+
+
+def create_ref_keys_ed448():
+ key_lines = load_file("ecc_ed448.txt").splitlines()
+ seed = compact(key_lines[6:10])
+ key = ECC.construct(curve="Ed448", seed=seed)
+ return (key, key.public_key())
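+# Unlike the NIST-curve helpers above, the Ed25519/Ed448 reference keys are
+# rebuilt from the private seed alone, and the matching public half is then
+# derived with .public_key() rather than read from the vector file.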
+
+
+# Create reference key pair
+# ref_private, ref_public = create_ref_keys_p521()
+
+def get_fixed_prng():
+ return SHAKE128.new().update(b"SEED").read
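+# SHAKE128 is an extendable-output function, so the bound .read method above
+# behaves like a randfunc that yields an arbitrarily long yet fully
+# deterministic byte stream derived from b"SEED", handy for reproducible
+# export tests.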
+
+
+def extract_bitstring_from_spki(data):
+ seq = DerSequence()
+ seq.decode(data)
+ bs = DerBitString()
+ bs.decode(seq[1])
+ return bs.value
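+# A SubjectPublicKeyInfo is a SEQUENCE holding an AlgorithmIdentifier followed
+# by a BIT STRING with the raw public key, so decoding seq[1] as a
+# DerBitString yields the bare SEC1 point that the SEC1 import tests feed back
+# into ECC.import_key.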
+
+
+class TestImport(unittest.TestCase):
+
+ def test_empty(self):
+ self.assertRaises(ValueError, ECC.import_key, b"")
+
+
+class TestImport_P192(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestImport_P192, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p192()
+
+ def test_import_public_der(self):
+ key_file = load_file("ecc_p192_public.der")
+
+ key = ECC._import_subjectPublicKeyInfo(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_uncompressed(self):
+ key_file = load_file("ecc_p192_public.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P192')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_compressed(self):
+ key_file = load_file("ecc_p192_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P192')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_rfc5915_der(self):
+ key_file = load_file("ecc_p192_private.der")
+
+ key = ECC._import_rfc5915_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p192_private_p8_clear.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_in_pem_clear(self):
+ key_file = load_file("ecc_p192_private_p8_clear.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_1(self):
+ key_file = load_file("ecc_p192_private_p8.der")
+
+ key = ECC._import_der(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_2(self):
+ key_file = load_file("ecc_p192_private_p8.pem")
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_der(self):
+ key_file = load_file("ecc_p192_x509.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_public_pem(self):
+ key_file = load_file("ecc_p192_public.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_private_pem(self):
+ key_file = load_file("ecc_p192_private.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pem_encrypted(self):
+ for algo in "des3", "aes128", "aes192", "aes256", "aes256_gcm":
+ key_file = load_file("ecc_p192_private_enc_%s.pem" % algo)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(tostr(key_file), b"secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_pem(self):
+ key_file = load_file("ecc_p192_x509.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+
+class TestImport_P224(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestImport_P224, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p224()
+
+ def test_import_public_der(self):
+ key_file = load_file("ecc_p224_public.der")
+
+ key = ECC._import_subjectPublicKeyInfo(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_uncompressed(self):
+ key_file = load_file("ecc_p224_public.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P224')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_compressed(self):
+ key_file = load_file("ecc_p224_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P224')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_rfc5915_der(self):
+ key_file = load_file("ecc_p224_private.der")
+
+ key = ECC._import_rfc5915_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p224_private_p8_clear.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_in_pem_clear(self):
+ key_file = load_file("ecc_p224_private_p8_clear.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_1(self):
+ key_file = load_file("ecc_p224_private_p8.der")
+
+ key = ECC._import_der(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_2(self):
+ key_file = load_file("ecc_p224_private_p8.pem")
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_der(self):
+ key_file = load_file("ecc_p224_x509.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_public_pem(self):
+ key_file = load_file("ecc_p224_public.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_private_pem(self):
+ key_file = load_file("ecc_p224_private.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pem_encrypted(self):
+ for algo in "des3", "aes128", "aes192", "aes256", "aes256_gcm":
+ key_file = load_file("ecc_p224_private_enc_%s.pem" % algo)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(tostr(key_file), b"secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_pem(self):
+ key_file = load_file("ecc_p224_x509.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+
+class TestImport_P256(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestImport_P256, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p256()
+
+ def test_import_public_der(self):
+ key_file = load_file("ecc_p256_public.der")
+
+ key = ECC._import_subjectPublicKeyInfo(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_uncompressed(self):
+ key_file = load_file("ecc_p256_public.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P256')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_compressed(self):
+ key_file = load_file("ecc_p256_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P256')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_rfc5915_der(self):
+ key_file = load_file("ecc_p256_private.der")
+
+ key = ECC._import_rfc5915_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p256_private_p8_clear.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_in_pem_clear(self):
+ key_file = load_file("ecc_p256_private_p8_clear.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_1(self):
+ key_file = load_file("ecc_p256_private_p8.der")
+
+ key = ECC._import_der(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_2(self):
+ key_file = load_file("ecc_p256_private_p8.pem")
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_der(self):
+ key_file = load_file("ecc_p256_x509.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_public_pem(self):
+ key_file = load_file("ecc_p256_public.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_private_pem(self):
+ key_file = load_file("ecc_p256_private.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pem_with_ecparams(self):
+ key_file = load_file("ecc_p256_private_ecparams.pem")
+ key = ECC.import_key(key_file)
+ # We just check if the import succeeds
+
+ def test_import_private_pem_encrypted(self):
+ for algo in "des3", "aes128", "aes192", "aes256", "aes256_gcm":
+ key_file = load_file("ecc_p256_private_enc_%s.pem" % algo)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(tostr(key_file), b"secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_pem(self):
+ key_file = load_file("ecc_p256_x509.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_openssh_public(self):
+ key_file = load_file("ecc_p256_public_openssh.txt")
+
+ key = ECC._import_openssh_public(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_openssh_private_clear(self):
+ key_file = load_file("ecc_p256_private_openssh.pem")
+ key_file_old = load_file("ecc_p256_private_openssh_old.pem")
+
+ key = ECC.import_key(key_file)
+ key_old = ECC.import_key(key_file_old)
+ self.assertEqual(key, key_old)
+
+ def test_import_openssh_private_password(self):
+ key_file = load_file("ecc_p256_private_openssh_pwd.pem")
+ key_file_old = load_file("ecc_p256_private_openssh_pwd_old.pem")
+
+ key = ECC.import_key(key_file, b"password")
+ key_old = ECC.import_key(key_file_old)
+ self.assertEqual(key, key_old)
+
+
+class TestImport_P384(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestImport_P384, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p384()
+
+ def test_import_public_der(self):
+ key_file = load_file("ecc_p384_public.der")
+
+ key = ECC._import_subjectPublicKeyInfo(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_uncompressed(self):
+ key_file = load_file("ecc_p384_public.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P384')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_compressed(self):
+ key_file = load_file("ecc_p384_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P384')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_rfc5915_der(self):
+ key_file = load_file("ecc_p384_private.der")
+
+ key = ECC._import_rfc5915_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p384_private_p8_clear.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_in_pem_clear(self):
+ key_file = load_file("ecc_p384_private_p8_clear.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_1(self):
+ key_file = load_file("ecc_p384_private_p8.der")
+
+ key = ECC._import_der(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_2(self):
+ key_file = load_file("ecc_p384_private_p8.pem")
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_der(self):
+ key_file = load_file("ecc_p384_x509.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_public_pem(self):
+ key_file = load_file("ecc_p384_public.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_private_pem(self):
+ key_file = load_file("ecc_p384_private.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pem_encrypted(self):
+ for algo in "des3", "aes128", "aes192", "aes256", "aes256_gcm":
+ key_file = load_file("ecc_p384_private_enc_%s.pem" % algo)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(tostr(key_file), b"secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_pem(self):
+ key_file = load_file("ecc_p384_x509.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_openssh_public(self):
+ key_file = load_file("ecc_p384_public_openssh.txt")
+
+ key = ECC._import_openssh_public(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_openssh_private_clear(self):
+ key_file = load_file("ecc_p384_private_openssh.pem")
+ key_file_old = load_file("ecc_p384_private_openssh_old.pem")
+
+ key = ECC.import_key(key_file)
+ key_old = ECC.import_key(key_file_old)
+ self.assertEqual(key, key_old)
+
+ def test_import_openssh_private_password(self):
+ key_file = load_file("ecc_p384_private_openssh_pwd.pem")
+ key_file_old = load_file("ecc_p384_private_openssh_pwd_old.pem")
+
+ key = ECC.import_key(key_file, b"password")
+ key_old = ECC.import_key(key_file_old)
+ self.assertEqual(key, key_old)
+
+
+class TestImport_P521(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestImport_P521, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p521()
+
+ def test_import_public_der(self):
+ key_file = load_file("ecc_p521_public.der")
+
+ key = ECC._import_subjectPublicKeyInfo(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_uncompressed(self):
+ key_file = load_file("ecc_p521_public.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P521')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_sec1_compressed(self):
+ key_file = load_file("ecc_p521_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file)
+ key = ECC.import_key(value, curve_name='P521')
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_rfc5915_der(self):
+ key_file = load_file("ecc_p521_private.der")
+
+ key = ECC._import_rfc5915_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p521_private_p8_clear.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_in_pem_clear(self):
+ key_file = load_file("ecc_p521_private_p8_clear.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_1(self):
+ key_file = load_file("ecc_p521_private_p8.der")
+
+ key = ECC._import_der(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_2(self):
+ key_file = load_file("ecc_p521_private_p8.pem")
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_der(self):
+ key_file = load_file("ecc_p521_x509.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_public_pem(self):
+ key_file = load_file("ecc_p521_public.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_private_pem(self):
+ key_file = load_file("ecc_p521_private.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pem_encrypted(self):
+ for algo in "des3", "aes128", "aes192", "aes256", "aes256_gcm":
+ key_file = load_file("ecc_p521_private_enc_%s.pem" % algo)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(tostr(key_file), b"secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_pem(self):
+ key_file = load_file("ecc_p521_x509.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_openssh_public(self):
+ key_file = load_file("ecc_p521_public_openssh.txt")
+
+ key = ECC._import_openssh_public(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_openssh_private_clear(self):
+ key_file = load_file("ecc_p521_private_openssh.pem")
+ key_file_old = load_file("ecc_p521_private_openssh_old.pem")
+
+ key = ECC.import_key(key_file)
+ key_old = ECC.import_key(key_file_old)
+ self.assertEqual(key, key_old)
+
+ def test_import_openssh_private_password(self):
+ key_file = load_file("ecc_p521_private_openssh_pwd.pem")
+ key_file_old = load_file("ecc_p521_private_openssh_pwd_old.pem")
+
+ key = ECC.import_key(key_file, b"password")
+ key_old = ECC.import_key(key_file_old)
+ self.assertEqual(key, key_old)
+
+
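+# The export tests mirror the import tests: each container format is produced
+# both via the internal _export_* helper and via the public export_key() API,
+# compared byte-for-byte with the corresponding fixture file, and (for the
+# encrypted variants) round-tripped through ECC.import_key().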
+class TestExport_P192(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestExport_P192, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p192()
+
+ def test_export_public_der_uncompressed(self):
+ key_file = load_file("ecc_p192_public.der")
+
+ encoded = self.ref_public._export_subjectPublicKeyInfo(False)
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_der_compressed(self):
+ key_file = load_file("ecc_p192_public.der")
+ pub_key = ECC.import_key(key_file)
+ key_file_compressed = pub_key.export_key(format="DER", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p192_public_compressed.der")
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
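+ # "SEC1" output is the bare encoded EC point, i.e. exactly the BIT STRING
+ # payload that extract_bitstring_from_spki() extracts from the DER
+ # SubjectPublicKeyInfo fixture.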
+ def test_export_public_sec1_uncompressed(self):
+ key_file = load_file("ecc_p192_public.der")
+ value = extract_bitstring_from_spki(key_file)
+
+ encoded = self.ref_public.export_key(format="SEC1")
+ self.assertEqual(value, encoded)
+
+ def test_export_public_sec1_compressed(self):
+ key_file = load_file("ecc_p192_public.der")
+ encoded = self.ref_public.export_key(format="SEC1", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p192_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file_compressed_ref)
+ self.assertEqual(value, encoded)
+
+ def test_export_rfc5915_private_der(self):
+ key_file = load_file("ecc_p192_private.der")
+
+ encoded = self.ref_private._export_rfc5915_private_der()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p192_private_p8_clear.der")
+
+ encoded = self.ref_private._export_pkcs8()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_encrypted(self):
+ encoded = self.ref_private._export_pkcs8(passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None)
+
+ decoded = ECC._import_pkcs8(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_public_pem_uncompressed(self):
+ key_file = load_file("ecc_p192_public.pem", "rt").strip()
+
+ encoded = self.ref_private._export_public_pem(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="PEM", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_pem_compressed(self):
+ key_file = load_file("ecc_p192_public.pem", "rt").strip()
+ pub_key = ECC.import_key(key_file)
+
+ key_file_compressed = pub_key.export_key(format="PEM", compress=True)
+ key_file_compressed_ref = load_file("ecc_p192_public_compressed.pem", "rt").strip()
+
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_private_pem_clear(self):
+ key_file = load_file("ecc_p192_private.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_pem(None)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pem_encrypted(self):
+ encoded = self.ref_private._export_private_pem(passphrase=b"secret")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "EC PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ use_pkcs8=False)
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_private_pkcs8_and_pem_1(self):
+ # PKCS8 inside PEM with both unencrypted
+ key_file = load_file("ecc_p192_private_p8_clear.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_clear_pkcs8_in_clear_pem()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_and_pem_2(self):
+ # PKCS8 inside PEM with PKCS8 encryption
+ encoded = self.ref_private._export_private_encrypted_pkcs8_in_clear_pem("secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "ENCRYPTED PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_prng(self):
+ # Test that password-protected containers use the provided PRNG
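+ # (the salt and IV are drawn from randfunc, so with a fixed PRNG two
+ # consecutive exports are expected to be byte-for-byte identical)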
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ # ---
+
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_byte_or_string_passphrase(self):
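+ # A text passphrase and its bytes equivalent should produce the same
+ # container when the PRNG is fixed (the string is encoded internally,
+ # presumably as UTF-8).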
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase=b"secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_error_params1(self):
+ # Unknown format
+ self.assertRaises(ValueError, self.ref_private.export_key, format="XXX")
+
+ # Missing 'protection' parameter when PKCS#8 is used
+ self.ref_private.export_key(format="PEM", passphrase="secret",
+ use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="secret")
+
+ # DER format but no PKCS#8
+ self.assertRaises(ValueError, self.ref_private.export_key, format="DER",
+ passphrase="secret",
+ use_pkcs8=False,
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # Incorrect parameters for public keys
+ self.assertRaises(ValueError, self.ref_public.export_key, format="DER",
+ use_pkcs8=False)
+
+ # Empty password
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="", use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ def test_compressed_curve(self):
+
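+ # The keys below carry the public point in compressed form (X plus the
+ # parity of Y). Checking the low 16 bits of the recovered Y verifies
+ # that decompression picked the square root with the expected parity.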
+ # Compressed P-192 curve (Y-point is even)
+ pem1 = """-----BEGIN EC PRIVATE KEY-----
+ MF8CAQEEGHvhXmIW95JxZYfd4AUPu9BwknjuvS36aqAKBggqhkjOPQMBAaE0AzIA
+ BLJZCyTu35DQIlqvMlBynn3k1Ig+dWfg/brRhHecxptrbloqFSP8ITw0CwbGF+2X
+ 5g==
+ -----END EC PRIVATE KEY-----"""
+
+ # Compressed P-192 curve (Y-point is odd)
+ pem2 = """-----BEGIN EC PRIVATE KEY-----
+ MF8CAQEEGA3rAotUaWl7d47eX6tz9JmLzOMJwl13XaAKBggqhkjOPQMBAaE0AzIA
+ BG4tHlTBBBGokcWmGm2xubVB0NvPC/Ou5AYwivs+3iCxmEjsymVAj6iiuX2Lxr6g
+ /Q==
+ -----END EC PRIVATE KEY-----"""
+
+ key1 = ECC.import_key(pem1)
+ low16 = int(key1.pointQ.y % 65536)
+ self.assertEqual(low16, 0x97E6)
+
+ key2 = ECC.import_key(pem2)
+ low16 = int(key2.pointQ.y % 65536)
+ self.assertEqual(low16, 0xA0FD)
+
+
+class TestExport_P224(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestExport_P224, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p224()
+
+ def test_export_public_der_uncompressed(self):
+ key_file = load_file("ecc_p224_public.der")
+
+ encoded = self.ref_public._export_subjectPublicKeyInfo(False)
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_der_compressed(self):
+ key_file = load_file("ecc_p224_public.der")
+ pub_key = ECC.import_key(key_file)
+ key_file_compressed = pub_key.export_key(format="DER", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p224_public_compressed.der")
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_public_sec1_uncompressed(self):
+ key_file = load_file("ecc_p224_public.der")
+ value = extract_bitstring_from_spki(key_file)
+
+ encoded = self.ref_public.export_key(format="SEC1")
+ self.assertEqual(value, encoded)
+
+ def test_export_public_sec1_compressed(self):
+ key_file = load_file("ecc_p224_public.der")
+ encoded = self.ref_public.export_key(format="SEC1", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p224_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file_compressed_ref)
+ self.assertEqual(value, encoded)
+
+ def test_export_rfc5915_private_der(self):
+ key_file = load_file("ecc_p224_private.der")
+
+ encoded = self.ref_private._export_rfc5915_private_der()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p224_private_p8_clear.der")
+
+ encoded = self.ref_private._export_pkcs8()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_encrypted(self):
+ encoded = self.ref_private._export_pkcs8(passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None)
+
+ decoded = ECC._import_pkcs8(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_public_pem_uncompressed(self):
+ key_file = load_file("ecc_p224_public.pem", "rt").strip()
+
+ encoded = self.ref_private._export_public_pem(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="PEM", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_pem_compressed(self):
+ key_file = load_file("ecc_p224_public.pem", "rt").strip()
+ pub_key = ECC.import_key(key_file)
+
+ key_file_compressed = pub_key.export_key(format="PEM", compress=True)
+ key_file_compressed_ref = load_file("ecc_p224_public_compressed.pem", "rt").strip()
+
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_private_pem_clear(self):
+ key_file = load_file("ecc_p224_private.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_pem(None)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pem_encrypted(self):
+ encoded = self.ref_private._export_private_pem(passphrase=b"secret")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "EC PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ use_pkcs8=False)
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_private_pkcs8_and_pem_1(self):
+ # PKCS8 inside PEM with both unencrypted
+ key_file = load_file("ecc_p224_private_p8_clear.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_clear_pkcs8_in_clear_pem()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_and_pem_2(self):
+ # PKCS8 inside PEM with PKCS8 encryption
+ encoded = self.ref_private._export_private_encrypted_pkcs8_in_clear_pem("secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "ENCRYPTED PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_prng(self):
+ # Test that password-protected containers use the provided PRNG
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ # ---
+
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_byte_or_string_passphrase(self):
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase=b"secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_error_params1(self):
+ # Unknown format
+ self.assertRaises(ValueError, self.ref_private.export_key, format="XXX")
+
+ # Missing 'protection' parameter when PKCS#8 is used
+ self.ref_private.export_key(format="PEM", passphrase="secret",
+ use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="secret")
+
+ # DER format but no PKCS#8
+ self.assertRaises(ValueError, self.ref_private.export_key, format="DER",
+ passphrase="secret",
+ use_pkcs8=False,
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # Incorrect parameters for public keys
+ self.assertRaises(ValueError, self.ref_public.export_key, format="DER",
+ use_pkcs8=False)
+
+ # Empty password
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="", use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ def test_compressed_curve(self):
+
+ # Compressed P-224 curve (Y-point is even)
+ pem1 = """-----BEGIN EC PRIVATE KEY-----
+ MGgCAQEEHPYicBNI9nd6wDKAX2l+f3A0Q+KWUQeMqSt5GoOgBwYFK4EEACGhPAM6
+ AATCL6rUIDT14zXKoS5GQUMDP/tpc+1iI/FyEZikt2roKDkhU5q08srmqaysbfJN
+ eUr7Xf1lnCVGag==
+ -----END EC PRIVATE KEY-----"""
+
+ # Compressed P-224 curve (Y-point is odd)
+ pem2 = """-----BEGIN EC PRIVATE KEY-----
+ MGgCAQEEHEFjbaVPLJ3ngZyCibCvT0RLUqSlHjC5Z3e0FtugBwYFK4EEACGhPAM6
+ AAT5IvL2V6m48y1JLMGr6ZbnOqNKP9hMf9mxyVkk6/SaRoBoJVkXrNIpYL0P7DS7
+ QF8E/OGeZRwvow==
+ -----END EC PRIVATE KEY-----"""
+
+ key1 = ECC.import_key(pem1)
+ low16 = int(key1.pointQ.y % 65536)
+ self.assertEqual(low16, 0x466A)
+
+ key2 = ECC.import_key(pem2)
+ low16 = int(key2.pointQ.y % 65536)
+ self.assertEqual(low16, 0x2FA3)
+
+
+class TestExport_P256(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestExport_P256, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p256()
+
+ def test_export_public_der_uncompressed(self):
+ key_file = load_file("ecc_p256_public.der")
+
+ encoded = self.ref_public._export_subjectPublicKeyInfo(False)
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_der_compressed(self):
+ key_file = load_file("ecc_p256_public.der")
+ pub_key = ECC.import_key(key_file)
+ key_file_compressed = pub_key.export_key(format="DER", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p256_public_compressed.der")
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_public_sec1_uncompressed(self):
+ key_file = load_file("ecc_p256_public.der")
+ value = extract_bitstring_from_spki(key_file)
+
+ encoded = self.ref_public.export_key(format="SEC1")
+ self.assertEqual(value, encoded)
+
+ def test_export_public_sec1_compressed(self):
+ key_file = load_file("ecc_p256_public.der")
+ encoded = self.ref_public.export_key(format="SEC1", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p256_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file_compressed_ref)
+ self.assertEqual(value, encoded)
+
+ def test_export_rfc5915_private_der(self):
+ key_file = load_file("ecc_p256_private.der")
+
+ encoded = self.ref_private._export_rfc5915_private_der()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p256_private_p8_clear.der")
+
+ encoded = self.ref_private._export_pkcs8()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_encrypted(self):
+ encoded = self.ref_private._export_pkcs8(passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None)
+
+ decoded = ECC._import_pkcs8(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_public_pem_uncompressed(self):
+ key_file = load_file("ecc_p256_public.pem", "rt").strip()
+
+ encoded = self.ref_private._export_public_pem(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="PEM", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_pem_compressed(self):
+ key_file = load_file("ecc_p256_public.pem", "rt").strip()
+ pub_key = ECC.import_key(key_file)
+
+ key_file_compressed = pub_key.export_key(format="PEM", compress=True)
+ key_file_compressed_ref = load_file("ecc_p256_public_compressed.pem", "rt").strip()
+
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_private_pem_clear(self):
+ key_file = load_file("ecc_p256_private.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_pem(None)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pem_encrypted(self):
+ encoded = self.ref_private._export_private_pem(passphrase=b"secret")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "EC PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ use_pkcs8=False)
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_private_pkcs8_and_pem_1(self):
+ # PKCS8 inside PEM with both unencrypted
+ key_file = load_file("ecc_p256_private_p8_clear.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_clear_pkcs8_in_clear_pem()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_and_pem_2(self):
+ # PKCS8 inside PEM with PKCS8 encryption
+ encoded = self.ref_private._export_private_encrypted_pkcs8_in_clear_pem("secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "ENCRYPTED PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_openssh_uncompressed(self):
+ key_file = load_file("ecc_p256_public_openssh.txt", "rt")
+
+ encoded = self.ref_public._export_openssh(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="OpenSSH")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="OpenSSH", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_openssh_compressed(self):
+ key_file = load_file("ecc_p256_public_openssh.txt", "rt")
+ pub_key = ECC.import_key(key_file)
+
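+ # The compressed blob omits the Y coordinate, so it must be strictly
+ # shorter, and re-importing it must give back the same public point.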
+ key_file_compressed = pub_key.export_key(format="OpenSSH", compress=True)
+ assert len(key_file) > len(key_file_compressed)
+ self.assertEqual(pub_key, ECC.import_key(key_file_compressed))
+
+ def test_prng(self):
+ # Test that password-protected containers use the provided PRNG
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ # ---
+
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_byte_or_string_passphrase(self):
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase=b"secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_error_params1(self):
+ # Unknown format
+ self.assertRaises(ValueError, self.ref_private.export_key, format="XXX")
+
+ # Missing 'protection' parameter when PKCS#8 is used
+ self.ref_private.export_key(format="PEM", passphrase="secret",
+ use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="secret")
+
+ # DER format but no PKCS#8
+ self.assertRaises(ValueError, self.ref_private.export_key, format="DER",
+ passphrase="secret",
+ use_pkcs8=False,
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # Incorrect parameters for public keys
+ self.assertRaises(ValueError, self.ref_public.export_key, format="DER",
+ use_pkcs8=False)
+
+ # Empty password
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="", use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # No private keys with OpenSSH
+ self.assertRaises(ValueError, self.ref_private.export_key, format="OpenSSH",
+ passphrase="secret")
+
+ def test_compressed_curve(self):
+
+ # Compressed P-256 curve (Y-point is even)
+ pem1 = """-----BEGIN EC PRIVATE KEY-----
+ MFcCAQEEIHTuc09jC51xXomV6MVCDN+DpAAvSmaJWZPTEHM6D5H1oAoGCCqGSM49
+ AwEHoSQDIgACWFuGbHe8yJ43rir7PMTE9w8vHz0BSpXHq90Xi7/s+a0=
+ -----END EC PRIVATE KEY-----"""
+
+ # Compressed P-256 curve (Y-point is odd)
+ pem2 = """-----BEGIN EC PRIVATE KEY-----
+ MFcCAQEEIFggiPN9SQP+FAPTCPp08fRUz7rHp2qNBRcBJ1DXhb3ZoAoGCCqGSM49
+ AwEHoSQDIgADLpph1trTIlVfa8NJvlMUPyWvL+wP+pW3BJITUL/wj9A=
+ -----END EC PRIVATE KEY-----"""
+
+ key1 = ECC.import_key(pem1)
+ low16 = int(key1.pointQ.y % 65536)
+ self.assertEqual(low16, 0xA6FC)
+
+ key2 = ECC.import_key(pem2)
+ low16 = int(key2.pointQ.y % 65536)
+ self.assertEqual(low16, 0x6E57)
+
+
+class TestExport_P384(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestExport_P384, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p384()
+
+ def test_export_public_der_uncompressed(self):
+ key_file = load_file("ecc_p384_public.der")
+
+ encoded = self.ref_public._export_subjectPublicKeyInfo(False)
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_der_compressed(self):
+ key_file = load_file("ecc_p384_public.der")
+ pub_key = ECC.import_key(key_file)
+ key_file_compressed = pub_key.export_key(format="DER", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p384_public_compressed.der")
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_public_sec1_uncompressed(self):
+ key_file = load_file("ecc_p384_public.der")
+ value = extract_bitstring_from_spki(key_file)
+
+ encoded = self.ref_public.export_key(format="SEC1")
+ self.assertEqual(value, encoded)
+
+ def test_export_public_sec1_compressed(self):
+ key_file = load_file("ecc_p384_public.der")
+ encoded = self.ref_public.export_key(format="SEC1", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p384_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file_compressed_ref)
+ self.assertEqual(value, encoded)
+
+ def test_export_rfc5915_private_der(self):
+ key_file = load_file("ecc_p384_private.der")
+
+ encoded = self.ref_private._export_rfc5915_private_der()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p384_private_p8_clear.der")
+
+ encoded = self.ref_private._export_pkcs8()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_encrypted(self):
+ encoded = self.ref_private._export_pkcs8(passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None)
+
+ decoded = ECC._import_pkcs8(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_public_pem_uncompressed(self):
+ key_file = load_file("ecc_p384_public.pem", "rt").strip()
+
+ encoded = self.ref_private._export_public_pem(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="PEM", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_pem_compressed(self):
+ key_file = load_file("ecc_p384_public.pem", "rt").strip()
+ pub_key = ECC.import_key(key_file)
+
+ key_file_compressed = pub_key.export_key(format="PEM", compress=True)
+ key_file_compressed_ref = load_file("ecc_p384_public_compressed.pem", "rt").strip()
+
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_private_pem_clear(self):
+ key_file = load_file("ecc_p384_private.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_pem(None)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pem_encrypted(self):
+ encoded = self.ref_private._export_private_pem(passphrase=b"secret")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "EC PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ use_pkcs8=False)
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_private_pkcs8_and_pem_1(self):
+ # PKCS8 inside PEM with both unencrypted
+ key_file = load_file("ecc_p384_private_p8_clear.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_clear_pkcs8_in_clear_pem()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_and_pem_2(self):
+ # PKCS8 inside PEM with PKCS8 encryption
+ encoded = self.ref_private._export_private_encrypted_pkcs8_in_clear_pem("secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "ENCRYPTED PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_openssh_uncompressed(self):
+ key_file = load_file("ecc_p384_public_openssh.txt", "rt")
+
+ encoded = self.ref_public._export_openssh(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="OpenSSH")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="OpenSSH", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_openssh_compressed(self):
+ key_file = load_file("ecc_p384_public_openssh.txt", "rt")
+ pub_key = ECC.import_key(key_file)
+
+ key_file_compressed = pub_key.export_key(format="OpenSSH", compress=True)
+ assert len(key_file) > len(key_file_compressed)
+ self.assertEqual(pub_key, ECC.import_key(key_file_compressed))
+
+ def test_prng(self):
+ # Test that password-protected containers use the provided PRNG
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ # ---
+
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_byte_or_string_passphrase(self):
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase=b"secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_error_params1(self):
+ # Unknown format
+ self.assertRaises(ValueError, self.ref_private.export_key, format="XXX")
+
+ # Missing 'protection' parameter when PKCS#8 is used
+ self.ref_private.export_key(format="PEM", passphrase="secret",
+ use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="secret")
+
+ # DER format but no PKCS#8
+ self.assertRaises(ValueError, self.ref_private.export_key, format="DER",
+ passphrase="secret",
+ use_pkcs8=False,
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # Incorrect parameters for public keys
+ self.assertRaises(ValueError, self.ref_public.export_key, format="DER",
+ use_pkcs8=False)
+
+ # Empty password
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="", use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # No private keys with OpenSSH
+ self.assertRaises(ValueError, self.ref_private.export_key, format="OpenSSH",
+ passphrase="secret")
+
+ def test_compressed_curve(self):
+
+ # Compressed P-384 curve (Y-point is even)
+ # openssl ecparam -name secp384r1 -genkey -noout -conv_form compressed -out /tmp/a.pem
+ # openssl ec -in /tmp/a.pem -text -noout
+ pem1 = """-----BEGIN EC PRIVATE KEY-----
+MIGkAgEBBDAM0lEIhvXuekK2SWtdbgOcZtBaxa9TxfpO/GcDFZLCJ3JVXaTgwken
+QT+C+XLtD6WgBwYFK4EEACKhZANiAATs0kZMhFDu8DoBC21jrSDPyAUn4aXZ/DM4
+ylhDfWmb4LEbeszXceIzfhIUaaGs5y1xXaqf5KXTiAAYx2pKUzAAM9lcGUHCGKJG
+k4AgUmVJON29XoUilcFrzjDmuye3B6Q=
+-----END EC PRIVATE KEY-----"""
+
+ # Compressed P-384 curve (Y-point is odd)
+ pem2 = """-----BEGIN EC PRIVATE KEY-----
+MIGkAgEBBDDHPFTslYLltE16fHdSDTtE/2HTmd3M8mqy5MttAm4wZ833KXiGS9oe
+kFdx9sNV0KygBwYFK4EEACKhZANiAASLIE5RqVMtNhtBH/u/p/ifqOAlKnK/+RrQ
+YC46ZRsnKNayw3wATdPjgja7L/DSII3nZK0G6KOOVwJBznT/e+zudUJYhZKaBLRx
+/bgXyxUtYClOXxb1Y/5N7txLstYRyP0=
+-----END EC PRIVATE KEY-----"""
+
+ key1 = ECC.import_key(pem1)
+ low16 = int(key1.pointQ.y % 65536)
+ self.assertEqual(low16, 0x07a4)
+
+ key2 = ECC.import_key(pem2)
+ low16 = int(key2.pointQ.y % 65536)
+ self.assertEqual(low16, 0xc8fd)
+
+
+class TestExport_P521(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestExport_P521, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_p521()
+
+ def test_export_public_der_uncompressed(self):
+ key_file = load_file("ecc_p521_public.der")
+
+ encoded = self.ref_public._export_subjectPublicKeyInfo(False)
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_der_compressed(self):
+ key_file = load_file("ecc_p521_public.der")
+ pub_key = ECC.import_key(key_file)
+ key_file_compressed = pub_key.export_key(format="DER", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p521_public_compressed.der")
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_public_sec1_uncompressed(self):
+ key_file = load_file("ecc_p521_public.der")
+ value = extract_bitstring_from_spki(key_file)
+
+ encoded = self.ref_public.export_key(format="SEC1")
+ self.assertEqual(value, encoded)
+
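+ # "raw" is expected to yield the same bare point bytes as "SEC1".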
+ encoded = self.ref_public.export_key(format="raw")
+ self.assertEqual(value, encoded)
+
+ def test_export_public_sec1_compressed(self):
+ key_file = load_file("ecc_p521_public.der")
+ encoded = self.ref_public.export_key(format="SEC1", compress=True)
+
+ key_file_compressed_ref = load_file("ecc_p521_public_compressed.der")
+ value = extract_bitstring_from_spki(key_file_compressed_ref)
+ self.assertEqual(value, encoded)
+
+ encoded = self.ref_public.export_key(format="raw", compress=True)
+ self.assertEqual(value, encoded)
+
+ def test_export_rfc5915_private_der(self):
+ key_file = load_file("ecc_p521_private.der")
+
+ encoded = self.ref_private._export_rfc5915_private_der()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_clear(self):
+ key_file = load_file("ecc_p521_private_p8_clear.der")
+
+ encoded = self.ref_private._export_pkcs8()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_encrypted(self):
+ encoded = self.ref_private._export_pkcs8(passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None)
+
+ decoded = ECC._import_pkcs8(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_public_pem_uncompressed(self):
+ key_file = load_file("ecc_p521_public.pem", "rt").strip()
+
+ encoded = self.ref_private._export_public_pem(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="PEM", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_pem_compressed(self):
+ key_file = load_file("ecc_p521_public.pem", "rt").strip()
+ pub_key = ECC.import_key(key_file)
+
+ key_file_compressed = pub_key.export_key(format="PEM", compress=True)
+ key_file_compressed_ref = load_file("ecc_p521_public_compressed.pem", "rt").strip()
+
+ self.assertEqual(key_file_compressed, key_file_compressed_ref)
+
+ def test_export_private_pem_clear(self):
+ key_file = load_file("ecc_p521_private.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_pem(None)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM", use_pkcs8=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pem_encrypted(self):
+ encoded = self.ref_private._export_private_pem(passphrase=b"secret")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "EC PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ use_pkcs8=False)
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_private_pkcs8_and_pem_1(self):
+ # PKCS8 inside PEM with both unencrypted
+ key_file = load_file("ecc_p521_private_p8_clear.pem", "rt").strip()
+
+ encoded = self.ref_private._export_private_clear_pkcs8_in_clear_pem()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM")
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pkcs8_and_pem_2(self):
+ # PKCS8 inside PEM with PKCS8 encryption
+ encoded = self.ref_private._export_private_encrypted_pkcs8_in_clear_pem("secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "ENCRYPTED PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_openssh_uncompressed(self):
+ key_file = load_file("ecc_p521_public_openssh.txt", "rt")
+
+ encoded = self.ref_public._export_openssh(False)
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_public.export_key(format="OpenSSH")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="OpenSSH", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_openssh_compressed(self):
+ key_file = load_file("ecc_p521_public_openssh.txt", "rt")
+ pub_key = ECC.import_key(key_file)
+
+ key_file_compressed = pub_key.export_key(format="OpenSSH", compress=True)
+ assert len(key_file) > len(key_file_compressed)
+ self.assertEqual(pub_key, ECC.import_key(key_file_compressed))
+
+ def test_prng(self):
+ # Test that password-protected containers use the provided PRNG
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ # ---
+
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_byte_or_string_passphrase(self):
+ encoded1 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase="secret",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ use_pkcs8=False,
+ passphrase=b"secret",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_error_params1(self):
+ # Unknown format
+ self.assertRaises(ValueError, self.ref_private.export_key, format="XXX")
+
+ # Missing 'protection' parameter when PKCS#8 is used
+ self.ref_private.export_key(format="PEM", passphrase="secret",
+ use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="secret")
+
+ # DER format but no PKCS#8
+ self.assertRaises(ValueError, self.ref_private.export_key, format="DER",
+ passphrase="secret",
+ use_pkcs8=False,
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # Incorrect parameters for public keys
+ self.assertRaises(ValueError, self.ref_public.export_key, format="DER",
+ use_pkcs8=False)
+
+ # Empty password
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="", use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # No private keys with OpenSSH
+ self.assertRaises(ValueError, self.ref_private.export_key, format="OpenSSH",
+ passphrase="secret")
+
+ def test_compressed_curve(self):
+
+ # Compressed P-521 curve (Y-point is even)
+ # openssl ecparam -name secp521r1 -genkey -noout -conv_form compressed -out /tmp/a.pem
+ # openssl ec -in /tmp/a.pem -text -noout
+ pem1 = """-----BEGIN EC PRIVATE KEY-----
+MIHcAgEBBEIAnm1CEjVjvNfXEN730p+D6su5l+mOztdc5XmTEoti+s2R4GQ4mAv3
+0zYLvyklvOHw0+yy8d0cyGEJGb8T3ZVKmg2gBwYFK4EEACOhgYkDgYYABAHzjTI1
+ckxQ3Togi0LAxiG0PucdBBBs5oIy3df95xv6SInp70z+4qQ2EltEmdNMssH8eOrl
+M5CYdZ6nbcHMVaJUvQEzTrYxvFjOgJiOd+E9eBWbLkbMNqsh1UKVO6HbMbW0ohCI
+uGxO8tM6r3w89/qzpG2SvFM/fvv3mIR30wSZDD84qA==
+-----END EC PRIVATE KEY-----"""
+
+ # Compressed P-521 curve (Y-point is odd)
+ pem2 = """-----BEGIN EC PRIVATE KEY-----
+MIHcAgEBBEIB84OfhJluLBRLn3+cC/RQ37C2SfQVP/t0gQK2tCsTf5avRcWYRrOJ
+PmX9lNnkC0Hobd75QFRmdxrB0Wd1/M4jZOWgBwYFK4EEACOhgYkDgYYABAAMZcdJ
+1YLCGHt3bHCEzdidVy6+brlJIbv1aQ9fPQLF7WKNv4c8w3H8d5a2+SDZilBOsk5c
+6cNJDMz2ExWQvxl4CwDJtJGt1+LHVKFGy73NANqVxMbRu+2F8lOxkNp/ziFTbVyV
+vv6oYkMIIi7r5oQWAiQDrR2mlrrFDL9V7GH/r8SWQw==
+-----END EC PRIVATE KEY-----"""
+
+ key1 = ECC.import_key(pem1)
+ low16 = int(key1.pointQ.y % 65536)
+ self.assertEqual(low16, 0x38a8)
+
+ key2 = ECC.import_key(pem2)
+ low16 = int(key2.pointQ.y % 65536)
+ self.assertEqual(low16, 0x9643)
+
+
+class TestImport_Ed25519(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestImport_Ed25519, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_ed25519()
+
+ def test_import_public_der(self):
+ key_file = load_file("ecc_ed25519_public.der")
+
+ key = ECC._import_subjectPublicKeyInfo(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_pkcs8_der(self):
+ key_file = load_file("ecc_ed25519_private.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_1(self):
+ key_file = load_file("ecc_ed25519_private_p8.der")
+
+ key = ECC._import_der(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_2(self):
+ key_file = load_file("ecc_ed25519_private_p8.pem")
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_der(self):
+ key_file = load_file("ecc_ed25519_x509.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_public_pem(self):
+ key_file = load_file("ecc_ed25519_public.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_private_pem(self):
+ key_file = load_file("ecc_ed25519_private.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pem_encrypted(self):
+ for algo in "des3", "aes128", "aes192", "aes256":
+ key_file = load_file("ecc_ed25519_private_enc_%s.pem" % algo)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(tostr(key_file), b"secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_pem(self):
+ key_file = load_file("ecc_ed25519_x509.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_openssh_public(self):
+ key_file = load_file("ecc_ed25519_public_openssh.txt")
+ key = ECC._import_openssh_public(key_file)
+ self.assertFalse(key.has_private())
+ key = ECC.import_key(key_file)
+ self.assertFalse(key.has_private())
+
+ def test_import_openssh_private_clear(self):
+ key_file = load_file("ecc_ed25519_private_openssh.pem")
+ key = ECC.import_key(key_file)
+
+ def test_import_openssh_private_password(self):
+ key_file = load_file("ecc_ed25519_private_openssh_pwd.pem")
+ key = ECC.import_key(key_file, b"password")
+
+
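+# For the EdDSA keys the SEC1 point machinery does not apply: export_key()
+# rejects format="SEC1", and private keys can only be wrapped in PKCS#8
+# (use_pkcs8=False raises), as the tests below check.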
+class TestExport_Ed25519(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestExport_Ed25519, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_ed25519()
+
+ def test_export_public_der(self):
+ key_file = load_file("ecc_ed25519_public.der")
+
+ encoded = self.ref_public._export_subjectPublicKeyInfo(True)
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_sec1(self):
+ self.assertRaises(ValueError, self.ref_public.export_key, format="SEC1")
+
+ def test_export_private_pkcs8_clear(self):
+ key_file = load_file("ecc_ed25519_private.der")
+
+ encoded = self.ref_private._export_pkcs8()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ self.assertRaises(ValueError, self.ref_private.export_key,
+ format="DER", use_pkcs8=False)
+
+ def test_export_private_pkcs8_encrypted(self):
+ encoded = self.ref_private._export_pkcs8(passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None)
+
+ decoded = ECC._import_pkcs8(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_public_pem(self):
+ key_file_ref = load_file("ecc_ed25519_public.pem", "rt").strip()
+ key_file = self.ref_public.export_key(format="PEM").strip()
+ self.assertEqual(key_file_ref, key_file)
+
+ def test_export_private_pem_clear(self):
+ key_file = load_file("ecc_ed25519_private.pem", "rt").strip()
+ encoded = self.ref_private.export_key(format="PEM").strip()
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pem_encrypted(self):
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase=b"secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "ENCRYPTED PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_openssh(self):
+ key_file = load_file("ecc_ed25519_public_openssh.txt", "rt")
+ public_key = ECC.import_key(key_file)
+ key_file = " ".join(key_file.split(' ')[:2]) # remove comment
+
+ encoded = public_key._export_openssh(False)
+ self.assertEqual(key_file, encoded.strip())
+
+ encoded = public_key.export_key(format="OpenSSH")
+ self.assertEqual(key_file, encoded.strip())
+
+ def test_export_raw(self):
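+ # "raw" for an Ed25519 public key is the 32-byte RFC 8032 point encoding.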
+ encoded = self.ref_public.export_key(format='raw')
+ self.assertEqual(encoded, unhexlify(b'bc85b8cf585d20a4de47e84d1cb6183f63d9ba96223fcbc886e363ffdea20cff'))
+
+ def test_prng(self):
+ # Test that password-protected containers use the provided PRNG
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_byte_or_string_passphrase(self):
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase=b"secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_error_params1(self):
+ # Unknown format
+ self.assertRaises(ValueError, self.ref_private.export_key, format="XXX")
+
+ # Missing 'protection' parameter when PKCS#8 is used
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="secret")
+
+ # Empty password
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="", use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # No private keys with OpenSSH
+ self.assertRaises(ValueError, self.ref_private.export_key, format="OpenSSH",
+ passphrase="secret")
+
+
+class TestImport_Ed448(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestImport_Ed448, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_ed448()
+
+ def test_import_public_der(self):
+ key_file = load_file("ecc_ed448_public.der")
+
+ key = ECC._import_subjectPublicKeyInfo(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_pkcs8_der(self):
+ key_file = load_file("ecc_ed448_private.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_1(self):
+ key_file = load_file("ecc_ed448_private_p8.der")
+
+ key = ECC._import_der(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pkcs8_encrypted_2(self):
+ key_file = load_file("ecc_ed448_private_p8.pem")
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_der(self):
+ key_file = load_file("ecc_ed448_x509.der")
+
+ key = ECC._import_der(key_file, None)
+ self.assertEqual(self.ref_public, key)
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_public_pem(self):
+ key_file = load_file("ecc_ed448_public.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+ def test_import_private_pem(self):
+ key_file = load_file("ecc_ed448_private.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_private_pem_encrypted(self):
+ for algo in "des3", "aes128", "aes192", "aes256":
+ key_file = load_file("ecc_ed448_private_enc_%s.pem" % algo)
+
+ key = ECC.import_key(key_file, "secret")
+ self.assertEqual(self.ref_private, key)
+
+ key = ECC.import_key(tostr(key_file), b"secret")
+ self.assertEqual(self.ref_private, key)
+
+ def test_import_x509_pem(self):
+ key_file = load_file("ecc_ed448_x509.pem")
+
+ key = ECC.import_key(key_file)
+ self.assertEqual(self.ref_public, key)
+
+
+class TestExport_Ed448(unittest.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(TestExport_Ed448, self).__init__(*args, **kwargs)
+ self.ref_private, self.ref_public = create_ref_keys_ed448()
+
+ def test_export_public_der(self):
+ key_file = load_file("ecc_ed448_public.der")
+
+ encoded = self.ref_public._export_subjectPublicKeyInfo(True)
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ encoded = self.ref_public.export_key(format="DER", compress=False)
+ self.assertEqual(key_file, encoded)
+
+ def test_export_public_sec1(self):
+ self.assertRaises(ValueError, self.ref_public.export_key, format="SEC1")
+
+ def test_export_private_pkcs8_clear(self):
+ key_file = load_file("ecc_ed448_private.der")
+
+ encoded = self.ref_private._export_pkcs8()
+ self.assertEqual(key_file, encoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER")
+ self.assertEqual(key_file, encoded)
+
+ self.assertRaises(ValueError, self.ref_private.export_key,
+ format="DER", use_pkcs8=False)
+
+ def test_export_private_pkcs8_encrypted(self):
+ encoded = self.ref_private._export_pkcs8(passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC._import_pkcs8, encoded, None)
+
+ decoded = ECC._import_pkcs8(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ # ---
+
+ encoded = self.ref_private.export_key(format="DER",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_public_pem(self):
+ key_file_ref = load_file("ecc_ed448_public.pem", "rt").strip()
+ key_file = self.ref_public.export_key(format="PEM").strip()
+ self.assertEqual(key_file_ref, key_file)
+
+ def test_export_private_pem_clear(self):
+ key_file = load_file("ecc_ed448_private.pem", "rt").strip()
+ encoded = self.ref_private.export_key(format="PEM").strip()
+ self.assertEqual(key_file, encoded)
+
+ def test_export_private_pem_encrypted(self):
+ encoded = self.ref_private.export_key(format="PEM",
+ passphrase=b"secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # This should prove that the output is password-protected
+ self.assertRaises(ValueError, ECC.import_key, encoded)
+
+ assert "ENCRYPTED PRIVATE KEY" in encoded
+
+ decoded = ECC.import_key(encoded, "secret")
+ self.assertEqual(self.ref_private, decoded)
+
+ def test_export_openssh(self):
+ # Not supported
+ self.assertRaises(ValueError, self.ref_public.export_key, format="OpenSSH")
+
+ def test_export_raw(self):
+ encoded = self.ref_public.export_key(format='raw')
+ self.assertEqual(encoded, unhexlify(b'899014ddc0a0e1260cfc1085afdf952019e9fd63372e3e366e26dad32b176624884330a14617237e3081febd9d1a15069e7499433d2f55dd80'))
+
+ def test_prng(self):
+ # Test that password-protected containers use the provided PRNG
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_byte_or_string_passphrase(self):
+ encoded1 = self.ref_private.export_key(format="PEM",
+ passphrase="secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ encoded2 = self.ref_private.export_key(format="PEM",
+ passphrase=b"secret",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC",
+ randfunc=get_fixed_prng())
+ self.assertEqual(encoded1, encoded2)
+
+ def test_error_params1(self):
+ # Unknown format
+ self.assertRaises(ValueError, self.ref_private.export_key, format="XXX")
+
+ # Missing 'protection' parameter when PKCS#8 is used
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="secret")
+
+ # Empty password
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="", use_pkcs8=False)
+ self.assertRaises(ValueError, self.ref_private.export_key, format="PEM",
+ passphrase="",
+ protection="PBKDF2WithHMAC-SHA1AndAES128-CBC")
+
+ # No private keys with OpenSSH
+ self.assertRaises(ValueError, self.ref_private.export_key, format="OpenSSH",
+ passphrase="secret")
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(TestImport)
+ try:
+ tests += list_test_cases(TestImport_P192)
+ tests += list_test_cases(TestImport_P224)
+ tests += list_test_cases(TestImport_P256)
+ tests += list_test_cases(TestImport_P384)
+ tests += list_test_cases(TestImport_P521)
+ tests += list_test_cases(TestImport_Ed25519)
+ tests += list_test_cases(TestImport_Ed448)
+
+ tests += list_test_cases(TestExport_P192)
+ tests += list_test_cases(TestExport_P224)
+ tests += list_test_cases(TestExport_P256)
+ tests += list_test_cases(TestExport_P384)
+ tests += list_test_cases(TestExport_P521)
+ tests += list_test_cases(TestExport_Ed25519)
+ tests += list_test_cases(TestExport_Ed448)
+
+ except MissingTestVectorException:
+ pass
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
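For reference, the export/import round trip exercised by the ECC tests above follows this pattern (a minimal sketch, assuming a pycryptodome build with EdDSA support; the curve name and passphrase are illustrative):

from Crypto.PublicKey import ECC

key = ECC.generate(curve='Ed25519')
# Encrypted PKCS#8 container inside a PEM envelope
pem = key.export_key(format='PEM',
                     passphrase='secret',
                     protection='PBKDF2WithHMAC-SHA1AndAES128-CBC')
assert 'ENCRYPTED PRIVATE KEY' in pem
restored = ECC.import_key(pem, passphrase='secret')
assert restored == key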
diff --git a/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py b/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py
new file mode 100644
index 0000000..fa92fb0
--- /dev/null
+++ b/lib/Crypto/SelfTest/PublicKey/test_import_RSA.py
@@ -0,0 +1,590 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/PublicKey/test_import_RSA.py: Self-test for importing RSA keys
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+import os
+import re
+import errno
+import warnings
+import unittest
+
+from Crypto.PublicKey import RSA
+from Crypto.SelfTest.st_common import a2b_hex, list_test_cases
+from Crypto.Util.py3compat import b, tostr, FileNotFoundError
+from Crypto.Util.number import inverse
+from Crypto.Util import asn1
+
+try:
+ import pycryptodome_test_vectors # type: ignore
+ test_vectors_available = True
+except ImportError:
+ test_vectors_available = False
+
+
+def load_file(file_name, mode="rb"):
+ results = None
+
+ try:
+ if not test_vectors_available:
+ raise FileNotFoundError(errno.ENOENT,
+ os.strerror(errno.ENOENT),
+ file_name)
+
+ dir_comps = ("PublicKey", "RSA")
+ init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
+ full_file_name = os.path.join(os.path.join(init_dir, *dir_comps), file_name)
+ with open(full_file_name, mode) as file_in:
+ results = file_in.read()
+
+ except FileNotFoundError:
+ warnings.warn("Warning: skipping extended tests for RSA",
+ UserWarning,
+ stacklevel=2)
+
+ return results
+
+
+def der2pem(der, text='PUBLIC'):
+ import binascii
+ chunks = [binascii.b2a_base64(der[i:i+48]) for i in range(0, len(der), 48)]
+ pem = b('-----BEGIN %s KEY-----\n' % text)
+ pem += b('').join(chunks)
+ pem += b('-----END %s KEY-----' % text)
+ return pem
+
+
+class ImportKeyTests(unittest.TestCase):
+ # 512-bit RSA key generated with openssl
+ rsaKeyPEM = u'''-----BEGIN RSA PRIVATE KEY-----
+MIIBOwIBAAJBAL8eJ5AKoIsjURpcEoGubZMxLD7+kT+TLr7UkvEtFrRhDDKMtuII
+q19FrL4pUIMymPMSLBn3hJLe30Dw48GQM4UCAwEAAQJACUSDEp8RTe32ftq8IwG8
+Wojl5mAd1wFiIOrZ/Uv8b963WJOJiuQcVN29vxU5+My9GPZ7RA3hrDBEAoHUDPrI
+OQIhAPIPLz4dphiD9imAkivY31Rc5AfHJiQRA7XixTcjEkojAiEAyh/pJHks/Mlr
++rdPNEpotBjfV4M4BkgGAA/ipcmaAjcCIQCHvhwwKVBLzzTscT2HeUdEeBMoiXXK
+JACAr3sJQJGxIQIgarRp+m1WSKV1MciwMaTOnbU7wxFs9DP1pva76lYBzgUCIQC9
+n0CnZCJ6IZYqSt0H5N7+Q+2Ro64nuwV/OSQfM6sBwQ==
+-----END RSA PRIVATE KEY-----'''
+
+ # As above, but this is actually an unencrypted PKCS#8 key
+ rsaKeyPEM8 = u'''-----BEGIN PRIVATE KEY-----
+MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAvx4nkAqgiyNRGlwS
+ga5tkzEsPv6RP5MuvtSS8S0WtGEMMoy24girX0WsvilQgzKY8xIsGfeEkt7fQPDj
+wZAzhQIDAQABAkAJRIMSnxFN7fZ+2rwjAbxaiOXmYB3XAWIg6tn9S/xv3rdYk4mK
+5BxU3b2/FTn4zL0Y9ntEDeGsMEQCgdQM+sg5AiEA8g8vPh2mGIP2KYCSK9jfVFzk
+B8cmJBEDteLFNyMSSiMCIQDKH+kkeSz8yWv6t080Smi0GN9XgzgGSAYAD+KlyZoC
+NwIhAIe+HDApUEvPNOxxPYd5R0R4EyiJdcokAICvewlAkbEhAiBqtGn6bVZIpXUx
+yLAxpM6dtTvDEWz0M/Wm9rvqVgHOBQIhAL2fQKdkInohlipK3Qfk3v5D7ZGjrie7
+BX85JB8zqwHB
+-----END PRIVATE KEY-----'''
+
+ # The same RSA private key as in rsaKeyPEM, but now encrypted
+ rsaKeyEncryptedPEM = (
+
+ # PEM encryption
+ # With DES and passphrase 'test'
+ ('test', u'''-----BEGIN RSA PRIVATE KEY-----
+Proc-Type: 4,ENCRYPTED
+DEK-Info: DES-CBC,AF8F9A40BD2FA2FC
+
+Ckl9ex1kaVEWhYC2QBmfaF+YPiR4NFkRXA7nj3dcnuFEzBnY5XULupqQpQI3qbfA
+u8GYS7+b3toWWiHZivHbAAUBPDIZG9hKDyB9Sq2VMARGsX1yW1zhNvZLIiVJzUHs
+C6NxQ1IJWOXzTew/xM2I26kPwHIvadq+/VaT8gLQdjdH0jOiVNaevjWnLgrn1mLP
+BCNRMdcexozWtAFNNqSzfW58MJL2OdMi21ED184EFytIc1BlB+FZiGZduwKGuaKy
+9bMbdb/1PSvsSzPsqW7KSSrTw6MgJAFJg6lzIYvR5F4poTVBxwBX3+EyEmShiaNY
+IRX3TgQI0IjrVuLmvlZKbGWP18FXj7I7k9tSsNOOzllTTdq3ny5vgM3A+ynfAaxp
+dysKznQ6P+IoqML1WxAID4aGRMWka+uArOJ148Rbj9s=
+-----END RSA PRIVATE KEY-----'''),
+
+ # PKCS8 encryption
+ ('winter', u'''-----BEGIN ENCRYPTED PRIVATE KEY-----
+MIIBpjBABgkqhkiG9w0BBQ0wMzAbBgkqhkiG9w0BBQwwDgQIeZIsbW3O+JcCAggA
+MBQGCCqGSIb3DQMHBAgSM2p0D8FilgSCAWBhFyP2tiGKVpGj3mO8qIBzinU60ApR
+3unvP+N6j7LVgnV2lFGaXbJ6a1PbQXe+2D6DUyBLo8EMXrKKVLqOMGkFMHc0UaV6
+R6MmrsRDrbOqdpTuVRW+NVd5J9kQQh4xnfU/QrcPPt7vpJvSf4GzG0n666Ki50OV
+M/feuVlIiyGXY6UWdVDpcOV72cq02eNUs/1JWdh2uEBvA9fCL0c07RnMrdT+CbJQ
+NjJ7f8ULtp7xvR9O3Al/yJ4Wv3i4VxF1f3MCXzhlUD4I0ONlr0kJWgeQ80q/cWhw
+ntvgJwnCn2XR1h6LA8Wp+0ghDTsL2NhJpWd78zClGhyU4r3hqu1XDjoXa7YCXCix
+jCV15+ViDJzlNCwg+W6lRg18sSLkCT7alviIE0U5tHc6UPbbHwT5QqAxAABaP+nZ
+CGqJGyiwBzrKebjgSm/KRd4C91XqcsysyH2kKPfT51MLAoD4xelOURBP
+-----END ENCRYPTED PRIVATE KEY-----'''
+ ),
+ )
+
+ rsaPublicKeyPEM = u'''-----BEGIN PUBLIC KEY-----
+MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAL8eJ5AKoIsjURpcEoGubZMxLD7+kT+T
+Lr7UkvEtFrRhDDKMtuIIq19FrL4pUIMymPMSLBn3hJLe30Dw48GQM4UCAwEAAQ==
+-----END PUBLIC KEY-----'''
+
+ # Obtained using 'ssh-keygen -i -m PKCS8 -f rsaPublicKeyPEM'
+ rsaPublicKeyOpenSSH = b('''ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAQQC/HieQCqCLI1EaXBKBrm2TMSw+/pE/ky6+1JLxLRa0YQwyjLbiCKtfRay+KVCDMpjzEiwZ94SS3t9A8OPBkDOF comment\n''')
+
+ # The private key, in PKCS#1 format encoded with DER
+ rsaKeyDER = a2b_hex(
+ '''3082013b020100024100bf1e27900aa08b23511a5c1281ae6d93312c3efe
+ 913f932ebed492f12d16b4610c328cb6e208ab5f45acbe2950833298f312
+ 2c19f78492dedf40f0e3c190338502030100010240094483129f114dedf6
+ 7edabc2301bc5a88e5e6601dd7016220ead9fd4bfc6fdeb75893898ae41c
+ 54ddbdbf1539f8ccbd18f67b440de1ac30440281d40cfac839022100f20f
+ 2f3e1da61883f62980922bd8df545ce407c726241103b5e2c53723124a23
+ 022100ca1fe924792cfcc96bfab74f344a68b418df578338064806000fe2
+ a5c99a023702210087be1c3029504bcf34ec713d877947447813288975ca
+ 240080af7b094091b12102206ab469fa6d5648a57531c8b031a4ce9db53b
+ c3116cf433f5a6f6bbea5601ce05022100bd9f40a764227a21962a4add07
+ e4defe43ed91a3ae27bb057f39241f33ab01c1
+ '''.replace(" ",""))
+
+ # The private key, in unencrypted PKCS#8 format encoded with DER
+ rsaKeyDER8 = a2b_hex(
+ '''30820155020100300d06092a864886f70d01010105000482013f3082013
+ b020100024100bf1e27900aa08b23511a5c1281ae6d93312c3efe913f932
+ ebed492f12d16b4610c328cb6e208ab5f45acbe2950833298f3122c19f78
+ 492dedf40f0e3c190338502030100010240094483129f114dedf67edabc2
+ 301bc5a88e5e6601dd7016220ead9fd4bfc6fdeb75893898ae41c54ddbdb
+ f1539f8ccbd18f67b440de1ac30440281d40cfac839022100f20f2f3e1da
+ 61883f62980922bd8df545ce407c726241103b5e2c53723124a23022100c
+ a1fe924792cfcc96bfab74f344a68b418df578338064806000fe2a5c99a0
+ 23702210087be1c3029504bcf34ec713d877947447813288975ca240080a
+ f7b094091b12102206ab469fa6d5648a57531c8b031a4ce9db53bc3116cf
+ 433f5a6f6bbea5601ce05022100bd9f40a764227a21962a4add07e4defe4
+ 3ed91a3ae27bb057f39241f33ab01c1
+ '''.replace(" ",""))
+
+ rsaPublicKeyDER = a2b_hex(
+ '''305c300d06092a864886f70d0101010500034b003048024100bf1e27900a
+ a08b23511a5c1281ae6d93312c3efe913f932ebed492f12d16b4610c328c
+ b6e208ab5f45acbe2950833298f3122c19f78492dedf40f0e3c190338502
+ 03010001
+ '''.replace(" ",""))
+
+ n = int('BF 1E 27 90 0A A0 8B 23 51 1A 5C 12 81 AE 6D 93 31 2C 3E FE 91 3F 93 2E BE D4 92 F1 2D 16 B4 61 0C 32 8C B6 E2 08 AB 5F 45 AC BE 29 50 83 32 98 F3 12 2C 19 F7 84 92 DE DF 40 F0 E3 C1 90 33 85'.replace(" ",""),16)
+ e = 65537
+ d = int('09 44 83 12 9F 11 4D ED F6 7E DA BC 23 01 BC 5A 88 E5 E6 60 1D D7 01 62 20 EA D9 FD 4B FC 6F DE B7 58 93 89 8A E4 1C 54 DD BD BF 15 39 F8 CC BD 18 F6 7B 44 0D E1 AC 30 44 02 81 D4 0C FA C8 39'.replace(" ",""),16)
+ p = int('00 F2 0F 2F 3E 1D A6 18 83 F6 29 80 92 2B D8 DF 54 5C E4 07 C7 26 24 11 03 B5 E2 C5 37 23 12 4A 23'.replace(" ",""),16)
+ q = int('00 CA 1F E9 24 79 2C FC C9 6B FA B7 4F 34 4A 68 B4 18 DF 57 83 38 06 48 06 00 0F E2 A5 C9 9A 02 37'.replace(" ",""),16)
+
+ # This is q^{-1} mod p. fastmath and slowmath use pInv (p^{-1}
+ # mod q) instead!
+ qInv = int('00 BD 9F 40 A7 64 22 7A 21 96 2A 4A DD 07 E4 DE FE 43 ED 91 A3 AE 27 BB 05 7F 39 24 1F 33 AB 01 C1'.replace(" ",""),16)
+ pInv = inverse(p,q)
+
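+ # Illustrative relation between the two conventions (a sketch using the
+ # 'inverse' helper imported above): qInv == inverse(q, p), while
+ # pInv == inverse(p, q); RSA.construct() in the tests below expects the latter.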
+ def testImportKey1(self):
+ """Verify import of RSAPrivateKey DER SEQUENCE"""
+ key = RSA.importKey(self.rsaKeyDER)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+ self.assertEqual(key.d, self.d)
+ self.assertEqual(key.p, self.p)
+ self.assertEqual(key.q, self.q)
+
+ def testImportKey2(self):
+ """Verify import of SubjectPublicKeyInfo DER SEQUENCE"""
+ key = RSA.importKey(self.rsaPublicKeyDER)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+
+ def testImportKey3unicode(self):
+ """Verify import of RSAPrivateKey DER SEQUENCE, encoded with PEM as unicode"""
+ key = RSA.importKey(self.rsaKeyPEM)
+ self.assertEqual(key.has_private(),True) # assert_
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+ self.assertEqual(key.d, self.d)
+ self.assertEqual(key.p, self.p)
+ self.assertEqual(key.q, self.q)
+
+ def testImportKey3bytes(self):
+ """Verify import of RSAPrivateKey DER SEQUENCE, encoded with PEM as byte string"""
+ key = RSA.importKey(b(self.rsaKeyPEM))
+ self.assertEqual(key.has_private(),True) # assert_
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+ self.assertEqual(key.d, self.d)
+ self.assertEqual(key.p, self.p)
+ self.assertEqual(key.q, self.q)
+
+ def testImportKey4unicode(self):
+ """Verify import of RSAPrivateKey DER SEQUENCE, encoded with PEM as unicode"""
+ key = RSA.importKey(self.rsaPublicKeyPEM)
+ self.assertEqual(key.has_private(),False) # assertFalse
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+
+ def testImportKey4bytes(self):
+ """Verify import of SubjectPublicKeyInfo DER SEQUENCE, encoded with PEM as byte string"""
+ key = RSA.importKey(b(self.rsaPublicKeyPEM))
+ self.assertEqual(key.has_private(),False) # assertFalse
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+
+ def testImportKey5(self):
+ """Verifies that the imported key is still a valid RSA pair"""
+ key = RSA.importKey(self.rsaKeyPEM)
+ idem = key._encrypt(key._decrypt(89))
+ self.assertEqual(idem, 89)
+
+ def testImportKey6(self):
+ """Verifies that the imported key is still a valid RSA pair"""
+ key = RSA.importKey(self.rsaKeyDER)
+ idem = key._encrypt(key._decrypt(65))
+ self.assertEqual(idem, 65)
+
+ def testImportKey7(self):
+ """Verify import of OpenSSH public key"""
+ key = RSA.importKey(self.rsaPublicKeyOpenSSH)
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+
+ def testImportKey8(self):
+ """Verify import of encrypted PrivateKeyInfo DER SEQUENCE"""
+ for t in self.rsaKeyEncryptedPEM:
+ key = RSA.importKey(t[1], t[0])
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+ self.assertEqual(key.d, self.d)
+ self.assertEqual(key.p, self.p)
+ self.assertEqual(key.q, self.q)
+
+ def testImportKey9(self):
+ """Verify import of unencrypted PrivateKeyInfo DER SEQUENCE"""
+ key = RSA.importKey(self.rsaKeyDER8)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+ self.assertEqual(key.d, self.d)
+ self.assertEqual(key.p, self.p)
+ self.assertEqual(key.q, self.q)
+
+ def testImportKey10(self):
+ """Verify import of unencrypted PrivateKeyInfo DER SEQUENCE, encoded with PEM"""
+ key = RSA.importKey(self.rsaKeyPEM8)
+ self.assertTrue(key.has_private())
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+ self.assertEqual(key.d, self.d)
+ self.assertEqual(key.p, self.p)
+ self.assertEqual(key.q, self.q)
+
+ def testImportKey11(self):
+ """Verify import of RSAPublicKey DER SEQUENCE"""
+ der = asn1.DerSequence([17, 3]).encode()
+ key = RSA.importKey(der)
+ self.assertEqual(key.n, 17)
+ self.assertEqual(key.e, 3)
+
+ def testImportKey12(self):
+ """Verify import of RSAPublicKey DER SEQUENCE, encoded with PEM"""
+ der = asn1.DerSequence([17, 3]).encode()
+ pem = der2pem(der)
+ key = RSA.importKey(pem)
+ self.assertEqual(key.n, 17)
+ self.assertEqual(key.e, 3)
+
+ def test_import_key_windows_cr_lf(self):
+ pem_cr_lf = "\r\n".join(self.rsaKeyPEM.splitlines())
+ key = RSA.importKey(pem_cr_lf)
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+ self.assertEqual(key.d, self.d)
+ self.assertEqual(key.p, self.p)
+ self.assertEqual(key.q, self.q)
+
+ def test_import_empty(self):
+ self.assertRaises(ValueError, RSA.import_key, b"")
+
+ ###
+ def testExportKey1(self):
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ derKey = key.export_key("DER")
+ self.assertEqual(derKey, self.rsaKeyDER)
+
+ def testExportKey2(self):
+ key = RSA.construct([self.n, self.e])
+ derKey = key.export_key("DER")
+ self.assertEqual(derKey, self.rsaPublicKeyDER)
+
+ def testExportKey3(self):
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ pemKey = key.export_key("PEM")
+ self.assertEqual(pemKey, b(self.rsaKeyPEM))
+
+ def testExportKey4(self):
+ key = RSA.construct([self.n, self.e])
+ pemKey = key.export_key("PEM")
+ self.assertEqual(pemKey, b(self.rsaPublicKeyPEM))
+
+ def testExportKey5(self):
+ key = RSA.construct([self.n, self.e])
+ openssh_1 = key.export_key("OpenSSH").split()
+ openssh_2 = self.rsaPublicKeyOpenSSH.split()
+ self.assertEqual(openssh_1[0], openssh_2[0])
+ self.assertEqual(openssh_1[1], openssh_2[1])
+
+ def testExportKey7(self):
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ derKey = key.export_key("DER", pkcs=8)
+ self.assertEqual(derKey, self.rsaKeyDER8)
+
+ def testExportKey8(self):
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ pemKey = key.export_key("PEM", pkcs=8)
+ self.assertEqual(pemKey, b(self.rsaKeyPEM8))
+
+ def testExportKey9(self):
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ self.assertRaises(ValueError, key.export_key, "invalid-format")
+
+ def testExportKey10(self):
+ # Export and re-import the encrypted key. It must match.
+ # PEM envelope, PKCS#1, old PEM encryption
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ outkey = key.export_key('PEM', 'test')
+ self.assertTrue(tostr(outkey).find('4,ENCRYPTED')!=-1)
+ self.assertTrue(tostr(outkey).find('BEGIN RSA PRIVATE KEY')!=-1)
+ inkey = RSA.importKey(outkey, 'test')
+ self.assertEqual(key.n, inkey.n)
+ self.assertEqual(key.e, inkey.e)
+ self.assertEqual(key.d, inkey.d)
+
+ def testExportKey11(self):
+ # Export and re-import the encrypted key. It must match.
+ # PEM envelope, PKCS#1, old PEM encryption
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ outkey = key.export_key('PEM', 'test', pkcs=1)
+ self.assertTrue(tostr(outkey).find('4,ENCRYPTED')!=-1)
+ self.assertTrue(tostr(outkey).find('BEGIN RSA PRIVATE KEY')!=-1)
+ inkey = RSA.importKey(outkey, 'test')
+ self.assertEqual(key.n, inkey.n)
+ self.assertEqual(key.e, inkey.e)
+ self.assertEqual(key.d, inkey.d)
+
+ def testExportKey12(self):
+ # Export and re-import the encrypted key. It must match.
+ # PEM envelope, PKCS#8, old PEM encryption
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ outkey = key.export_key('PEM', 'test', pkcs=8)
+ self.assertTrue(tostr(outkey).find('4,ENCRYPTED')!=-1)
+ self.assertTrue(tostr(outkey).find('BEGIN PRIVATE KEY')!=-1)
+ inkey = RSA.importKey(outkey, 'test')
+ self.assertEqual(key.n, inkey.n)
+ self.assertEqual(key.e, inkey.e)
+ self.assertEqual(key.d, inkey.d)
+
+ def testExportKey13(self):
+ # Export and re-import the encrypted key. It must match.
+ # PEM envelope, PKCS#8, PKCS#8 encryption
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ outkey = key.export_key('PEM', 'test', pkcs=8,
+ protection='PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC')
+ self.assertTrue(tostr(outkey).find('4,ENCRYPTED')==-1)
+ self.assertTrue(tostr(outkey).find('BEGIN ENCRYPTED PRIVATE KEY')!=-1)
+ inkey = RSA.importKey(outkey, 'test')
+ self.assertEqual(key.n, inkey.n)
+ self.assertEqual(key.e, inkey.e)
+ self.assertEqual(key.d, inkey.d)
+
+ def testExportKey14(self):
+ # Export and re-import the encrypted key. It must match.
+ # DER envelope, PKCS#8, PKCS#8 encryption
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ outkey = key.export_key('DER', 'test', pkcs=8)
+ inkey = RSA.importKey(outkey, 'test')
+ self.assertEqual(key.n, inkey.n)
+ self.assertEqual(key.e, inkey.e)
+ self.assertEqual(key.d, inkey.d)
+
+ def testExportKey15(self):
+ # Verify that an error condition is detected when trying to
+ # use a password with DER encoding and PKCS#1.
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ self.assertRaises(ValueError, key.export_key, 'DER', 'test', 1)
+
+ def test_import_key(self):
+ """Verify that import_key is an alias to importKey"""
+ key = RSA.import_key(self.rsaPublicKeyDER)
+ self.assertFalse(key.has_private())
+ self.assertEqual(key.n, self.n)
+ self.assertEqual(key.e, self.e)
+
+ def test_import_key_ba_mv(self):
+ """Verify that import_key can be used on bytearrays and memoryviews"""
+ key = RSA.import_key(bytearray(self.rsaPublicKeyDER))
+ key = RSA.import_key(memoryview(self.rsaPublicKeyDER))
+
+ def test_exportKey(self):
+ key = RSA.construct([self.n, self.e, self.d, self.p, self.q, self.pInv])
+ self.assertEqual(key.export_key(), key.exportKey())
+
+
+class ImportKeyFromX509Cert(unittest.TestCase):
+
+ def test_x509v1(self):
+
+ # Sample V1 certificate with a 1024 bit RSA key
+ x509_v1_cert = """
+-----BEGIN CERTIFICATE-----
+MIICOjCCAaMCAQEwDQYJKoZIhvcNAQEEBQAwfjENMAsGA1UEChMEQWNtZTELMAkG
+A1UECxMCUkQxHDAaBgkqhkiG9w0BCQEWDXNwYW1AYWNtZS5vcmcxEzARBgNVBAcT
+Ck1ldHJvcG9saXMxETAPBgNVBAgTCE5ldyBZb3JrMQswCQYDVQQGEwJVUzENMAsG
+A1UEAxMEdGVzdDAeFw0xNDA3MTExOTU3MjRaFw0xNzA0MDYxOTU3MjRaME0xCzAJ
+BgNVBAYTAlVTMREwDwYDVQQIEwhOZXcgWW9yazENMAsGA1UEChMEQWNtZTELMAkG
+A1UECxMCUkQxDzANBgNVBAMTBmxhdHZpYTCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
+gYkCgYEAyG+kytdRj3TFbRmHDYp3TXugVQ81chew0qeOxZWOz80IjtWpgdOaCvKW
+NCuc8wUR9BWrEQW+39SaRMLiQfQtyFSQZijc3nsEBu/Lo4uWZ0W/FHDRVSvkJA/V
+Ex5NL5ikI+wbUeCV5KajGNDalZ8F1pk32+CBs8h1xNx5DyxuEHUCAwEAATANBgkq
+hkiG9w0BAQQFAAOBgQCVQF9Y//Q4Psy+umEM38pIlbZ2hxC5xNz/MbVPwuCkNcGn
+KYNpQJP+JyVTsPpO8RLZsAQDzRueMI3S7fbbwTzAflN0z19wvblvu93xkaBytVok
+9VBAH28olVhy9b1MMeg2WOt5sUEQaFNPnwwsyiY9+HsRpvpRnPSQF+kyYVsshQ==
+-----END CERTIFICATE-----
+ """.strip()
+
+ # RSA public key as dumped by openssl
+ exponent = 65537
+ modulus_str = """
+00:c8:6f:a4:ca:d7:51:8f:74:c5:6d:19:87:0d:8a:
+77:4d:7b:a0:55:0f:35:72:17:b0:d2:a7:8e:c5:95:
+8e:cf:cd:08:8e:d5:a9:81:d3:9a:0a:f2:96:34:2b:
+9c:f3:05:11:f4:15:ab:11:05:be:df:d4:9a:44:c2:
+e2:41:f4:2d:c8:54:90:66:28:dc:de:7b:04:06:ef:
+cb:a3:8b:96:67:45:bf:14:70:d1:55:2b:e4:24:0f:
+d5:13:1e:4d:2f:98:a4:23:ec:1b:51:e0:95:e4:a6:
+a3:18:d0:da:95:9f:05:d6:99:37:db:e0:81:b3:c8:
+75:c4:dc:79:0f:2c:6e:10:75
+ """
+ modulus = int(re.sub("[^0-9a-f]","", modulus_str), 16)
+
+ key = RSA.importKey(x509_v1_cert)
+ self.assertEqual(key.e, exponent)
+ self.assertEqual(key.n, modulus)
+ self.assertFalse(key.has_private())
+
+ def test_x509v3(self):
+
+ # Sample V3 certificate with a 1024 bit RSA key
+ x509_v3_cert = """
+-----BEGIN CERTIFICATE-----
+MIIEcjCCAlqgAwIBAgIBATANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJVUzEL
+MAkGA1UECAwCTUQxEjAQBgNVBAcMCUJhbHRpbW9yZTEQMA4GA1UEAwwHVGVzdCBD
+QTEfMB0GCSqGSIb3DQEJARYQdGVzdEBleGFtcGxlLmNvbTAeFw0xNDA3MTIwOTM1
+MTJaFw0xNzA0MDcwOTM1MTJaMEQxCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJNRDES
+MBAGA1UEBwwJQmFsdGltb3JlMRQwEgYDVQQDDAtUZXN0IFNlcnZlcjCBnzANBgkq
+hkiG9w0BAQEFAAOBjQAwgYkCgYEA/S7GJV2OcFdyNMQ4K75KrYFtMEn3VnEFdPHa
+jyS37XlMxSh0oS4GeTGVUCJInl5Cpsv8WQdh03FfeOdvzp5IZ46OcjeOPiWnmjgl
+2G5j7e2bDH7RSchGV+OD6Fb1Agvuu2/9iy8fdf3rPQ/7eAddzKUrzwacVbnW+tg2
+QtSXKRcCAwEAAaOB1TCB0jAdBgNVHQ4EFgQU/WwCX7FfWMIPDFfJ+I8a2COG+l8w
+HwYDVR0jBBgwFoAUa0hkif3RMaraiWtsOOZZlLu9wJwwCQYDVR0TBAIwADALBgNV
+HQ8EBAMCBeAwSgYDVR0RBEMwQYILZXhhbXBsZS5jb22CD3d3dy5leGFtcGxlLmNv
+bYIQbWFpbC5leGFtcGxlLmNvbYIPZnRwLmV4YW1wbGUuY29tMCwGCWCGSAGG+EIB
+DQQfFh1PcGVuU1NMIEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTANBgkqhkiG9w0BAQsF
+AAOCAgEAvO6xfdsGbnoK4My3eJthodTAjMjPwFVY133LH04QLcCv54TxKhtUg1fi
+PgdjVe1HpTytPBfXy2bSZbXAN0abZCtw1rYrnn7o1g2pN8iypVq3zVn0iMTzQzxs
+zEPO3bpR/UhNSf90PmCsS5rqZpAAnXSaAy1ClwHWk/0eG2pYkhE1m1ABVMN2lsAW
+e9WxGk6IFqaI9O37NYQwmEypMs4DC+ECJEvbPFiqi3n0gbXCZJJ6omDA5xJldaYK
+Oa7KR3s/qjBsu9UAiWpLBuFoSTHIF2aeRKRFmUdmzwo43eVPep65pY6eQ4AdL2RF
+rqEuINbGlzI5oQyYhu71IwB+iPZXaZZPlwjLgOsuad/p2hOgDb5WxUi8FnDPursQ
+ujfpIpmrOP/zpvvQWnwePI3lI+5n41kTBSbefXEdv6rXpHk3QRzB90uPxnXPdxSC
+16ASA8bQT5an/1AgoE3k9CrcD2K0EmgaX0YI0HUhkyzbkg34EhpWJ6vvRUbRiNRo
+9cIbt/ya9Y9u0Ja8GLXv6dwX0l0IdJMkL8KifXUFAVCujp1FBrr/gdmwQn8itANy
++qbnWSxmOvtaY0zcaFAcONuHva0h51/WqXOMO1eb8PhR4HIIYU8p1oBwQp7dSni8
+THDi1F+GG5PsymMDj5cWK42f+QzjVw5PrVmFqqrrEoMlx8DWh5Y=
+-----END CERTIFICATE-----
+""".strip()
+
+ # RSA public key as dumped by openssl
+ exponent = 65537
+ modulus_str = """
+00:fd:2e:c6:25:5d:8e:70:57:72:34:c4:38:2b:be:
+4a:ad:81:6d:30:49:f7:56:71:05:74:f1:da:8f:24:
+b7:ed:79:4c:c5:28:74:a1:2e:06:79:31:95:50:22:
+48:9e:5e:42:a6:cb:fc:59:07:61:d3:71:5f:78:e7:
+6f:ce:9e:48:67:8e:8e:72:37:8e:3e:25:a7:9a:38:
+25:d8:6e:63:ed:ed:9b:0c:7e:d1:49:c8:46:57:e3:
+83:e8:56:f5:02:0b:ee:bb:6f:fd:8b:2f:1f:75:fd:
+eb:3d:0f:fb:78:07:5d:cc:a5:2b:cf:06:9c:55:b9:
+d6:fa:d8:36:42:d4:97:29:17
+ """
+ modulus = int(re.sub("[^0-9a-f]","", modulus_str), 16)
+
+ key = RSA.importKey(x509_v3_cert)
+ self.assertEqual(key.e, exponent)
+ self.assertEqual(key.n, modulus)
+ self.assertFalse(key.has_private())
+
+
+class TestImport_2048(unittest.TestCase):
+
+ def test_import_openssh_public(self):
+ key_file_ref = load_file("rsa2048_private.pem")
+ key_file = load_file("rsa2048_public_openssh.txt")
+
+ # Skip test if test vectors are not installed
+ if None in (key_file_ref, key_file):
+ return
+
+ key_ref = RSA.import_key(key_file_ref).public_key()
+ key = RSA.import_key(key_file)
+ self.assertEqual(key_ref, key)
+
+ def test_import_openssh_private_clear(self):
+ key_file = load_file("rsa2048_private_openssh.pem")
+ key_file_old = load_file("rsa2048_private_openssh_old.pem")
+
+ # Skip test if test vectors are not installed
+ if None in (key_file_old, key_file):
+ return
+
+ key = RSA.import_key(key_file)
+ key_old = RSA.import_key(key_file_old)
+
+ self.assertEqual(key, key_old)
+
+ def test_import_openssh_private_password(self):
+ key_file = load_file("rsa2048_private_openssh_pwd.pem")
+ key_file_old = load_file("rsa2048_private_openssh_pwd_old.pem")
+
+ # Skip test if test vectors are not installed
+ if None in (key_file_old, key_file):
+ return
+
+ key = RSA.import_key(key_file, b"password")
+ key_old = RSA.import_key(key_file_old)
+ self.assertEqual(key, key_old)
+
+
+if __name__ == '__main__':
+ unittest.main()
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(ImportKeyTests)
+ tests += list_test_cases(ImportKeyFromX509Cert)
+ tests += list_test_cases(TestImport_2048)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
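For reference, the RSA import/export behaviour covered above can be exercised along these lines (a minimal sketch; the key size, passphrase and protection scheme are illustrative):

from Crypto.PublicKey import RSA

key = RSA.generate(2048)
# Encrypted PKCS#8 container in DER form; PKCS#1 plus a passphrase is rejected for DER
blob = key.export_key(format='DER', passphrase='test', pkcs=8,
                      protection='PBKDF2WithHMAC-SHA1AndDES-EDE3-CBC')
restored = RSA.import_key(blob, passphrase='test')
assert (restored.n, restored.e, restored.d) == (key.n, key.e, key.d)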
diff --git a/lib/Crypto/SelfTest/Random/__init__.py b/lib/Crypto/SelfTest/Random/__init__.py
new file mode 100644
index 0000000..53061cc
--- /dev/null
+++ b/lib/Crypto/SelfTest/Random/__init__.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Random/__init__.py: Self-test for random number generation modules
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test for random number generators"""
+
+__revision__ = "$Id$"
+
+def get_tests(config={}):
+ tests = []
+ from Crypto.SelfTest.Random import test_random; tests += test_random.get_tests(config=config)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Random/test_random.py b/lib/Crypto/SelfTest/Random/test_random.py
new file mode 100644
index 0000000..8fadc53
--- /dev/null
+++ b/lib/Crypto/SelfTest/Random/test_random.py
@@ -0,0 +1,167 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Random/test_random.py: Self-test for the Crypto.Random.new() function
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test suite for Crypto.Random.new()"""
+
+import sys
+import unittest
+from Crypto.Util.py3compat import b
+
+class SimpleTest(unittest.TestCase):
+ def runTest(self):
+ """Crypto.Random.new()"""
+ # Import the Random module and try to use it
+ from Crypto import Random
+ randobj = Random.new()
+ x = randobj.read(16)
+ y = randobj.read(16)
+ self.assertNotEqual(x, y)
+ z = Random.get_random_bytes(16)
+ self.assertNotEqual(x, z)
+ self.assertNotEqual(y, z)
+ # Test the Random.random module, which
+ # implements a subset of Python's random API
+ # Not implemented:
+ # seed(), getstate(), setstate(), jumpahead()
+ # random(), uniform(), triangular(), betavariate()
+ # expovariate(), gammavariate(), gauss(),
+ # lognormvariate(), normalvariate(),
+ # vonmisesvariate(), paretovariate()
+ # weibullvariate()
+ # WichmannHill(), whseed(), SystemRandom()
+ from Crypto.Random import random
+ x = random.getrandbits(16*8)
+ y = random.getrandbits(16*8)
+ self.assertNotEqual(x, y)
+ # Test randrange
+ if x>y:
+ start = y
+ stop = x
+ else:
+ start = x
+ stop = y
+ for step in range(1,10):
+ x = random.randrange(start,stop,step)
+ y = random.randrange(start,stop,step)
+ self.assertNotEqual(x, y)
+ self.assertEqual(start <= x < stop, True)
+ self.assertEqual(start <= y < stop, True)
+ self.assertEqual((x - start) % step, 0)
+ self.assertEqual((y - start) % step, 0)
+ for i in range(10):
+ self.assertEqual(random.randrange(1,2), 1)
+ self.assertRaises(ValueError, random.randrange, start, start)
+ self.assertRaises(ValueError, random.randrange, stop, start, step)
+ self.assertRaises(TypeError, random.randrange, start, stop, step, step)
+ self.assertRaises(TypeError, random.randrange, start, stop, "1")
+ self.assertRaises(TypeError, random.randrange, "1", stop, step)
+ self.assertRaises(TypeError, random.randrange, 1, "2", step)
+ self.assertRaises(ValueError, random.randrange, start, stop, 0)
+ # Test randint
+ x = random.randint(start,stop)
+ y = random.randint(start,stop)
+ self.assertNotEqual(x, y)
+ self.assertEqual(start <= x <= stop, True)
+ self.assertEqual(start <= y <= stop, True)
+ for i in range(10):
+ self.assertEqual(random.randint(1,1), 1)
+ self.assertRaises(ValueError, random.randint, stop, start)
+ self.assertRaises(TypeError, random.randint, start, stop, step)
+ self.assertRaises(TypeError, random.randint, "1", stop)
+ self.assertRaises(TypeError, random.randint, 1, "2")
+ # Test choice
+ seq = range(10000)
+ x = random.choice(seq)
+ y = random.choice(seq)
+ self.assertNotEqual(x, y)
+ self.assertEqual(x in seq, True)
+ self.assertEqual(y in seq, True)
+ for i in range(10):
+ self.assertEqual(random.choice((1,2,3)) in (1,2,3), True)
+ self.assertEqual(random.choice([1,2,3]) in [1,2,3], True)
+ if sys.version_info[0] == 3:
+ self.assertEqual(random.choice(bytearray(b('123'))) in bytearray(b('123')), True)
+ self.assertEqual(1, random.choice([1]))
+ self.assertRaises(IndexError, random.choice, [])
+ self.assertRaises(TypeError, random.choice, 1)
+ # Test shuffle. Lacks random parameter to specify function.
+ # Make copies of seq
+ seq = range(500)
+ x = list(seq)
+ y = list(seq)
+ random.shuffle(x)
+ random.shuffle(y)
+ self.assertNotEqual(x, y)
+ self.assertEqual(len(seq), len(x))
+ self.assertEqual(len(seq), len(y))
+ for i in range(len(seq)):
+ self.assertEqual(x[i] in seq, True)
+ self.assertEqual(y[i] in seq, True)
+ self.assertEqual(seq[i] in x, True)
+ self.assertEqual(seq[i] in y, True)
+ z = [1]
+ random.shuffle(z)
+ self.assertEqual(z, [1])
+ if sys.version_info[0] == 3:
+ z = bytearray(b('12'))
+ random.shuffle(z)
+ self.assertEqual(b('1') in z, True)
+ self.assertRaises(TypeError, random.shuffle, b('12'))
+ self.assertRaises(TypeError, random.shuffle, 1)
+ self.assertRaises(TypeError, random.shuffle, "11")
+ self.assertRaises(TypeError, random.shuffle, (1,2))
+ # 2to3 wraps a list() around it, alas - but I want to shoot
+ # myself in the foot here! :D
+ # if sys.version_info[0] == 3:
+ # self.assertRaises(TypeError, random.shuffle, range(3))
+ # Test sample
+ x = random.sample(seq, 20)
+ y = random.sample(seq, 20)
+ self.assertNotEqual(x, y)
+ for i in range(20):
+ self.assertEqual(x[i] in seq, True)
+ self.assertEqual(y[i] in seq, True)
+ z = random.sample([1], 1)
+ self.assertEqual(z, [1])
+ z = random.sample((1,2,3), 1)
+ self.assertEqual(z[0] in (1,2,3), True)
+ z = random.sample("123", 1)
+ self.assertEqual(z[0] in "123", True)
+ z = random.sample(range(3), 1)
+ self.assertEqual(z[0] in range(3), True)
+ if sys.version_info[0] == 3:
+ z = random.sample(b("123"), 1)
+ self.assertEqual(z[0] in b("123"), True)
+ z = random.sample(bytearray(b("123")), 1)
+ self.assertEqual(z[0] in bytearray(b("123")), True)
+ self.assertRaises(TypeError, random.sample, 1)
+
+def get_tests(config={}):
+ return [SimpleTest()]
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
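For reference, the Crypto.Random API probed by the test above is used roughly as follows (a minimal sketch; the sizes and sequences are illustrative):

from Crypto import Random
from Crypto.Random import random

rng = Random.new()                       # file-like source of cryptographic randomness
nonce = rng.read(16)                     # 16 random bytes
bits = random.getrandbits(128)           # random 128-bit integer
step5 = random.randrange(0, 100, 5)      # random multiple of 5 in [0, 100)
pick = random.choice(['a', 'b', 'c'])    # uniform choice from a sequence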
diff --git a/lib/Crypto/SelfTest/Signature/__init__.py b/lib/Crypto/SelfTest/Signature/__init__.py
new file mode 100644
index 0000000..83cf0f3
--- /dev/null
+++ b/lib/Crypto/SelfTest/Signature/__init__.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Signature/__init__.py: Self-test for signature modules
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test for signature modules"""
+
+import unittest
+from . import test_pkcs1_15, test_pss, test_dss, test_eddsa
+
+
+def get_tests(config={}):
+ tests = []
+ tests += test_pkcs1_15.get_tests(config=config)
+ tests += test_pss.get_tests(config=config)
+ tests += test_dss.get_tests(config=config)
+ tests += test_eddsa.get_tests(config=config)
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
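For reference, the DSS signing flow exercised by the signature self-tests follows this pattern (a minimal sketch using an ECDSA key on P-256; the curve and message are illustrative):

from Crypto.PublicKey import ECC
from Crypto.Signature import DSS
from Crypto.Hash import SHA256

key = ECC.generate(curve='P-256')
h = SHA256.new(b'message to sign')
signature = DSS.new(key, 'fips-186-3').sign(h)
# verify() raises ValueError if the signature does not match
DSS.new(key.public_key(), 'fips-186-3').verify(h, signature)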
diff --git a/lib/Crypto/SelfTest/Signature/test_dss.py b/lib/Crypto/SelfTest/Signature/test_dss.py
new file mode 100644
index 0000000..d3f8dfc
--- /dev/null
+++ b/lib/Crypto/SelfTest/Signature/test_dss.py
@@ -0,0 +1,1369 @@
+#
+# SelfTest/Signature/test_dss.py: Self-test for DSS signatures
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import re
+import unittest
+from binascii import hexlify, unhexlify
+
+from Crypto.Util.py3compat import tobytes, bord, bchr
+
+from Crypto.Hash import (SHA1, SHA224, SHA256, SHA384, SHA512,
+ SHA3_224, SHA3_256, SHA3_384, SHA3_512)
+from Crypto.Signature import DSS
+from Crypto.PublicKey import DSA, ECC
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors, load_test_vectors_wycheproof
+from Crypto.Util.number import bytes_to_long, long_to_bytes
+
+
+def t2b(hexstring):
+ ws = hexstring.replace(" ", "").replace("\n", "")
+ return unhexlify(tobytes(ws))
+
+
+def t2l(hexstring):
+ ws = hexstring.replace(" ", "").replace("\n", "")
+ return int(ws, 16)
+
+
+def load_hash_by_name(hash_name):
+ return __import__("Crypto.Hash." + hash_name, globals(), locals(), ["new"])
+
+
+class StrRNG:
+
+ def __init__(self, randomness):
+ length = len(randomness)
+ self._idx = 0
+ # Fix required to get the right K (see how randint() works!)
+ self._randomness = long_to_bytes(bytes_to_long(randomness) - 1, length)
+
+ def __call__(self, n):
+ out = self._randomness[self._idx:self._idx + n]
+ self._idx += n
+ return out
+
+
+class FIPS_DSA_Tests(unittest.TestCase):
+
+ # 1st 1024 bit key from SigGen.txt
+ P = 0xa8f9cd201e5e35d892f85f80e4db2599a5676a3b1d4f190330ed3256b26d0e80a0e49a8fffaaad2a24f472d2573241d4d6d6c7480c80b4c67bb4479c15ada7ea8424d2502fa01472e760241713dab025ae1b02e1703a1435f62ddf4ee4c1b664066eb22f2e3bf28bb70a2a76e4fd5ebe2d1229681b5b06439ac9c7e9d8bde283
+ Q = 0xf85f0f83ac4df7ea0cdf8f469bfeeaea14156495
+ G = 0x2b3152ff6c62f14622b8f48e59f8af46883b38e79b8c74deeae9df131f8b856e3ad6c8455dab87cc0da8ac973417ce4f7878557d6cdf40b35b4a0ca3eb310c6a95d68ce284ad4e25ea28591611ee08b8444bd64b25f3f7c572410ddfb39cc728b9c936f85f419129869929cdb909a6a3a99bbe089216368171bd0ba81de4fe33
+ X = 0xc53eae6d45323164c7d07af5715703744a63fc3a
+ Y = 0x313fd9ebca91574e1c2eebe1517c57e0c21b0209872140c5328761bbb2450b33f1b18b409ce9ab7c4cd8fda3391e8e34868357c199e16a6b2eba06d6749def791d79e95d3a4d09b24c392ad89dbf100995ae19c01062056bb14bce005e8731efde175f95b975089bdcdaea562b32786d96f5a31aedf75364008ad4fffebb970b
+
+ key_pub = DSA.construct((Y, G, P, Q))
+ key_priv = DSA.construct((Y, G, P, Q, X))
+
+ def shortDescription(self):
+ return "FIPS DSA Tests"
+
+ def test_loopback(self):
+ hashed_msg = SHA512.new(b"test")
+ signer = DSS.new(self.key_priv, 'fips-186-3')
+ signature = signer.sign(hashed_msg)
+
+ verifier = DSS.new(self.key_pub, 'fips-186-3')
+ verifier.verify(hashed_msg, signature)
+
+ def test_negative_unapproved_hashes(self):
+ """Verify that unapproved hashes are rejected"""
+
+ from Crypto.Hash import RIPEMD160
+
+ self.description = "Unapproved hash (RIPEMD160) test"
+ hash_obj = RIPEMD160.new()
+ signer = DSS.new(self.key_priv, 'fips-186-3')
+ self.assertRaises(ValueError, signer.sign, hash_obj)
+ self.assertRaises(ValueError, signer.verify, hash_obj, b"\x00" * 40)
+
+ def test_negative_unknown_modes_encodings(self):
+ """Verify that unknown modes/encodings are rejected"""
+
+ self.description = "Unknown mode test"
+ self.assertRaises(ValueError, DSS.new, self.key_priv, 'fips-186-0')
+
+ self.description = "Unknown encoding test"
+ self.assertRaises(ValueError, DSS.new, self.key_priv, 'fips-186-3', 'xml')
+
+ def test_asn1_encoding(self):
+ """Verify ASN.1 encoding"""
+
+ self.description = "ASN.1 encoding test"
+ hash_obj = SHA1.new()
+ signer = DSS.new(self.key_priv, 'fips-186-3', 'der')
+ signature = signer.sign(hash_obj)
+
+ # Verify that output looks like a DER SEQUENCE
+ self.assertEqual(bord(signature[0]), 48)
+ signer.verify(hash_obj, signature)
+
+ # Verify that ASN.1 parsing fails as expected
+ signature = bchr(7) + signature[1:]
+ self.assertRaises(ValueError, signer.verify, hash_obj, signature)
+
+ def test_sign_verify(self):
+ """Verify public/private method"""
+
+ self.description = "can_sign() test"
+ signer = DSS.new(self.key_priv, 'fips-186-3')
+ self.assertTrue(signer.can_sign())
+
+ signer = DSS.new(self.key_pub, 'fips-186-3')
+ self.assertFalse(signer.can_sign())
+
+ try:
+ signer.sign(SHA256.new(b'xyz'))
+ except TypeError as e:
+ msg = str(e)
+ else:
+ msg = ""
+ self.assertTrue("Private key is needed" in msg)
+
+
+class FIPS_DSA_Tests_KAT(unittest.TestCase):
+ pass
+
+
+test_vectors_verify = load_test_vectors(("Signature", "DSA"),
+ "FIPS_186_3_SigVer.rsp",
+ "Signature Verification 186-3",
+ {'result': lambda x: x}) or []
+
+for idx, tv in enumerate(test_vectors_verify):
+
+ if isinstance(tv, str):
+ res = re.match(r"\[mod = L=([0-9]+), N=([0-9]+), ([a-zA-Z0-9-]+)\]", tv)
+ assert(res)
+ hash_name = res.group(3).replace("-", "")
+ hash_module = load_hash_by_name(hash_name)
+ continue
+
+ if hasattr(tv, "p"):
+ modulus = tv.p
+ generator = tv.g
+ suborder = tv.q
+ continue
+
+ hash_obj = hash_module.new(tv.msg)
+
+ comps = [bytes_to_long(x) for x in (tv.y, generator, modulus, suborder)]
+ key = DSA.construct(comps, False) # type: ignore
+ verifier = DSS.new(key, 'fips-186-3')
+
+ def positive_test(self, verifier=verifier, hash_obj=hash_obj, signature=tv.r+tv.s):
+ verifier.verify(hash_obj, signature)
+
+ def negative_test(self, verifier=verifier, hash_obj=hash_obj, signature=tv.r+tv.s):
+ self.assertRaises(ValueError, verifier.verify, hash_obj, signature)
+
+ if tv.result == 'p':
+ setattr(FIPS_DSA_Tests_KAT, "test_verify_positive_%d" % idx, positive_test)
+ else:
+ setattr(FIPS_DSA_Tests_KAT, "test_verify_negative_%d" % idx, negative_test)
+
+
+test_vectors_sign = load_test_vectors(("Signature", "DSA"),
+ "FIPS_186_3_SigGen.txt",
+ "Signature Creation 186-3",
+ {}) or []
+
+for idx, tv in enumerate(test_vectors_sign):
+
+ if isinstance(tv, str):
+ res = re.match(r"\[mod = L=([0-9]+), N=([0-9]+), ([a-zA-Z0-9-]+)\]", tv)
+ assert(res)
+ hash_name = res.group(3).replace("-", "")
+ hash_module = load_hash_by_name(hash_name)
+ continue
+
+ if hasattr(tv, "p"):
+ modulus = tv.p
+ generator = tv.g
+ suborder = tv.q
+ continue
+
+ hash_obj = hash_module.new(tv.msg)
+ comps_dsa = [bytes_to_long(x) for x in (tv.y, generator, modulus, suborder, tv.x)]
+ key = DSA.construct(comps_dsa, False) # type: ignore
+ signer = DSS.new(key, 'fips-186-3', randfunc=StrRNG(tv.k))
+
+ def new_test(self, signer=signer, hash_obj=hash_obj, signature=tv.r+tv.s):
+ self.assertEqual(signer.sign(hash_obj), signature)
+ setattr(FIPS_DSA_Tests_KAT, "test_sign_%d" % idx, new_test)
+
+
+class FIPS_ECDSA_Tests(unittest.TestCase):
+
+ key_priv = ECC.generate(curve="P-256")
+ key_pub = key_priv.public_key()
+
+ def shortDescription(self):
+ return "FIPS ECDSA Tests"
+
+ def test_loopback(self):
+ hashed_msg = SHA512.new(b"test")
+ signer = DSS.new(self.key_priv, 'fips-186-3')
+ signature = signer.sign(hashed_msg)
+
+ verifier = DSS.new(self.key_pub, 'fips-186-3')
+ verifier.verify(hashed_msg, signature)
+
+ def test_negative_unapproved_hashes(self):
+ """Verify that unapproved hashes are rejected"""
+
+ from Crypto.Hash import SHA1
+
+ self.description = "Unapproved hash (SHA-1) test"
+ hash_obj = SHA1.new()
+ signer = DSS.new(self.key_priv, 'fips-186-3')
+ self.assertRaises(ValueError, signer.sign, hash_obj)
+ self.assertRaises(ValueError, signer.verify, hash_obj, b"\x00" * 40)
+
+ def test_negative_eddsa_key(self):
+ key = ECC.generate(curve="ed25519")
+ self.assertRaises(ValueError, DSS.new, key, 'fips-186-3')
+
+ def test_sign_verify(self):
+ """Verify public/private method"""
+
+ self.description = "can_sign() test"
+ signer = DSS.new(self.key_priv, 'fips-186-3')
+ self.assertTrue(signer.can_sign())
+
+ signer = DSS.new(self.key_pub, 'fips-186-3')
+ self.assertFalse(signer.can_sign())
+ self.assertRaises(TypeError, signer.sign, SHA256.new(b'xyz'))
+
+ try:
+ signer.sign(SHA256.new(b'xyz'))
+ except TypeError as e:
+ msg = str(e)
+ else:
+ msg = ""
+ self.assertTrue("Private key is needed" in msg)
+
+ def test_negative_unknown_modes_encodings(self):
+ """Verify that unknown modes/encodings are rejected"""
+
+ self.description = "Unknown mode test"
+ self.assertRaises(ValueError, DSS.new, self.key_priv, 'fips-186-0')
+
+ self.description = "Unknown encoding test"
+ self.assertRaises(ValueError, DSS.new, self.key_priv, 'fips-186-3', 'xml')
+
+ def test_asn1_encoding(self):
+ """Verify ASN.1 encoding"""
+
+ self.description = "ASN.1 encoding test"
+ hash_obj = SHA256.new()
+ signer = DSS.new(self.key_priv, 'fips-186-3', 'der')
+ signature = signer.sign(hash_obj)
+
+ # Verify that output looks like a DER SEQUENCE
+ self.assertEqual(bord(signature[0]), 48)
+ signer.verify(hash_obj, signature)
+
+ # Verify that ASN.1 parsing fails as expected
+ signature = bchr(7) + signature[1:]
+ self.assertRaises(ValueError, signer.verify, hash_obj, signature)
+
+
+class FIPS_ECDSA_Tests_KAT(unittest.TestCase):
+ pass
+
+
+test_vectors_verify = load_test_vectors(("Signature", "ECDSA"),
+ "SigVer.rsp",
+ "ECDSA Signature Verification 186-3",
+ {'result': lambda x: x,
+ 'qx': lambda x: int(x, 16),
+ 'qy': lambda x: int(x, 16),
+ }) or []
+test_vectors_verify += load_test_vectors(("Signature", "ECDSA"),
+ "SigVer_TruncatedSHAs.rsp",
+ "ECDSA Signature Verification 186-3",
+ {'result': lambda x: x,
+ 'qx': lambda x: int(x, 16),
+ 'qy': lambda x: int(x, 16),
+ }) or []
+
+
+for idx, tv in enumerate(test_vectors_verify):
+
+ if isinstance(tv, str):
+ res = re.match(r"\[(P-[0-9]+),(SHA-[0-9]+)\]", tv)
+ assert res
+ curve_name = res.group(1)
+ hash_name = res.group(2).replace("-", "")
+ if hash_name in ("SHA512224", "SHA512256"):
+ truncate = hash_name[-3:]
+ hash_name = hash_name[:-3]
+ else:
+ truncate = None
+ hash_module = load_hash_by_name(hash_name)
+ continue
+
+ if truncate is None:
+ hash_obj = hash_module.new(tv.msg)
+ else:
+ hash_obj = hash_module.new(tv.msg, truncate=truncate)
+ ecc_key = ECC.construct(curve=curve_name, point_x=tv.qx, point_y=tv.qy)
+ verifier = DSS.new(ecc_key, 'fips-186-3')
+
+ def positive_test(self, verifier=verifier, hash_obj=hash_obj, signature=tv.r+tv.s):
+ verifier.verify(hash_obj, signature)
+
+ def negative_test(self, verifier=verifier, hash_obj=hash_obj, signature=tv.r+tv.s):
+ self.assertRaises(ValueError, verifier.verify, hash_obj, signature)
+
+ if tv.result.startswith('p'):
+ setattr(FIPS_ECDSA_Tests_KAT, "test_verify_positive_%d" % idx, positive_test)
+ else:
+ setattr(FIPS_ECDSA_Tests_KAT, "test_verify_negative_%d" % idx, negative_test)
+
+
+test_vectors_sign = load_test_vectors(("Signature", "ECDSA"),
+ "SigGen.txt",
+ "ECDSA Signature Verification 186-3",
+ {'d': lambda x: int(x, 16)}) or []
+
+for idx, tv in enumerate(test_vectors_sign):
+
+ if isinstance(tv, str):
+ res = re.match(r"\[(P-[0-9]+),(SHA-[0-9]+)\]", tv)
+ assert res
+ curve_name = res.group(1)
+ hash_name = res.group(2).replace("-", "")
+ hash_module = load_hash_by_name(hash_name)
+ continue
+
+ hash_obj = hash_module.new(tv.msg)
+ ecc_key = ECC.construct(curve=curve_name, d=tv.d)
+ signer = DSS.new(ecc_key, 'fips-186-3', randfunc=StrRNG(tv.k))
+
+ def sign_test(self, signer=signer, hash_obj=hash_obj, signature=tv.r+tv.s):
+ self.assertEqual(signer.sign(hash_obj), signature)
+ setattr(FIPS_ECDSA_Tests_KAT, "test_sign_%d" % idx, sign_test)
+
+
+class Det_DSA_Tests(unittest.TestCase):
+ """Tests from rfc6979"""
+
+ # Each key is (p, q, g, x, y, desc)
+ keys = [
+ (
+ """
+ 86F5CA03DCFEB225063FF830A0C769B9DD9D6153AD91D7CE27F787C43278B447
+ E6533B86B18BED6E8A48B784A14C252C5BE0DBF60B86D6385BD2F12FB763ED88
+ 73ABFD3F5BA2E0A8C0A59082EAC056935E529DAF7C610467899C77ADEDFC846C
+ 881870B7B19B2B58F9BE0521A17002E3BDD6B86685EE90B3D9A1B02B782B1779""",
+ "996F967F6C8E388D9E28D01E205FBA957A5698B1",
+ """
+ 07B0F92546150B62514BB771E2A0C0CE387F03BDA6C56B505209FF25FD3C133D
+ 89BBCD97E904E09114D9A7DEFDEADFC9078EA544D2E401AEECC40BB9FBBF78FD
+ 87995A10A1C27CB7789B594BA7EFB5C4326A9FE59A070E136DB77175464ADCA4
+ 17BE5DCE2F40D10A46A3A3943F26AB7FD9C0398FF8C76EE0A56826A8A88F1DBD""",
+ "411602CB19A6CCC34494D79D98EF1E7ED5AF25F7",
+ """
+ 5DF5E01DED31D0297E274E1691C192FE5868FEF9E19A84776454B100CF16F653
+ 92195A38B90523E2542EE61871C0440CB87C322FC4B4D2EC5E1E7EC766E1BE8D
+ 4CE935437DC11C3C8FD426338933EBFE739CB3465F4D3668C5E473508253B1E6
+ 82F65CBDC4FAE93C2EA212390E54905A86E2223170B44EAA7DA5DD9FFCFB7F3B""",
+ "DSA1024"
+ ),
+ (
+ """
+ 9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642F0B5C48
+ C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757264E5A1A44F
+ FE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F9716BFE6117C6B5
+ B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091EB51743BF33050C38DE2
+ 35567E1B34C3D6A5C0CEAA1A0F368213C3D19843D0B4B09DCB9FC72D39C8DE41
+ F1BF14D4BB4563CA28371621CAD3324B6A2D392145BEBFAC748805236F5CA2FE
+ 92B871CD8F9C36D3292B5509CA8CAA77A2ADFC7BFD77DDA6F71125A7456FEA15
+ 3E433256A2261C6A06ED3693797E7995FAD5AABBCFBE3EDA2741E375404AE25B""",
+ "F2C3119374CE76C9356990B465374A17F23F9ED35089BD969F61C6DDE9998C1F",
+ """
+ 5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E24809670716C613
+ D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D1AA58C4328A06C4
+ 6A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A338661D10461C0D135472
+ 085057F3494309FFA73C611F78B32ADBB5740C361C9F35BE90997DB2014E2EF5
+ AA61782F52ABEB8BD6432C4DD097BC5423B285DAFB60DC364E8161F4A2A35ACA
+ 3A10B1C4D203CC76A470A33AFDCBDD92959859ABD8B56E1725252D78EAC66E71
+ BA9AE3F1DD2487199874393CD4D832186800654760E1E34C09E4D155179F9EC0
+ DC4473F996BDCE6EED1CABED8B6F116F7AD9CF505DF0F998E34AB27514B0FFE7""",
+ "69C7548C21D0DFEA6B9A51C9EAD4E27C33D3B3F180316E5BCAB92C933F0E4DBC",
+ """
+ 667098C654426C78D7F8201EAC6C203EF030D43605032C2F1FA937E5237DBD94
+ 9F34A0A2564FE126DC8B715C5141802CE0979C8246463C40E6B6BDAA2513FA61
+ 1728716C2E4FD53BC95B89E69949D96512E873B9C8F8DFD499CC312882561ADE
+ CB31F658E934C0C197F2C4D96B05CBAD67381E7B768891E4DA3843D24D94CDFB
+ 5126E9B8BF21E8358EE0E0A30EF13FD6A664C0DCE3731F7FB49A4845A4FD8254
+ 687972A2D382599C9BAC4E0ED7998193078913032558134976410B89D2C171D1
+ 23AC35FD977219597AA7D15C1A9A428E59194F75C721EBCBCFAE44696A499AFA
+ 74E04299F132026601638CB87AB79190D4A0986315DA8EEC6561C938996BEADF""",
+ "DSA2048"
+ ),
+ ]
+
+ # This is a sequence of items:
+ # message, k, r, s, hash module, key label
+ signatures = [
+ (
+ "sample",
+ "7BDB6B0FF756E1BB5D53583EF979082F9AD5BD5B",
+ "2E1A0C2562B2912CAAF89186FB0F42001585DA55",
+ "29EFB6B0AFF2D7A68EB70CA313022253B9A88DF5",
+ SHA1,
+ 'DSA1024'
+ ),
+ (
+ "sample",
+ "562097C06782D60C3037BA7BE104774344687649",
+ "4BC3B686AEA70145856814A6F1BB53346F02101E",
+ "410697B92295D994D21EDD2F4ADA85566F6F94C1",
+ SHA224,
+ 'DSA1024'
+ ),
+ (
+ "sample",
+ "519BA0546D0C39202A7D34D7DFA5E760B318BCFB",
+ "81F2F5850BE5BC123C43F71A3033E9384611C545",
+ "4CDD914B65EB6C66A8AAAD27299BEE6B035F5E89",
+ SHA256,
+ 'DSA1024'
+ ),
+ (
+ "sample",
+ "95897CD7BBB944AA932DBC579C1C09EB6FCFC595",
+ "07F2108557EE0E3921BC1774F1CA9B410B4CE65A",
+ "54DF70456C86FAC10FAB47C1949AB83F2C6F7595",
+ SHA384,
+ 'DSA1024'
+ ),
+ (
+ "sample",
+ "09ECE7CA27D0F5A4DD4E556C9DF1D21D28104F8B",
+ "16C3491F9B8C3FBBDD5E7A7B667057F0D8EE8E1B",
+ "02C36A127A7B89EDBB72E4FFBC71DABC7D4FC69C",
+ SHA512,
+ 'DSA1024'
+ ),
+ (
+ "test",
+ "5C842DF4F9E344EE09F056838B42C7A17F4A6433",
+ "42AB2052FD43E123F0607F115052A67DCD9C5C77",
+ "183916B0230D45B9931491D4C6B0BD2FB4AAF088",
+ SHA1,
+ 'DSA1024'
+ ),
+ (
+ "test",
+ "4598B8EFC1A53BC8AECD58D1ABBB0C0C71E67297",
+ "6868E9964E36C1689F6037F91F28D5F2C30610F2",
+ "49CEC3ACDC83018C5BD2674ECAAD35B8CD22940F",
+ SHA224,
+ 'DSA1024'
+ ),
+ (
+ "test",
+ "5A67592E8128E03A417B0484410FB72C0B630E1A",
+ "22518C127299B0F6FDC9872B282B9E70D0790812",
+ "6837EC18F150D55DE95B5E29BE7AF5D01E4FE160",
+ SHA256,
+ 'DSA1024'
+ ),
+ (
+ "test",
+ "220156B761F6CA5E6C9F1B9CF9C24BE25F98CD89",
+ "854CF929B58D73C3CBFDC421E8D5430CD6DB5E66",
+ "91D0E0F53E22F898D158380676A871A157CDA622",
+ SHA384,
+ 'DSA1024'
+ ),
+ (
+ "test",
+ "65D2C2EEB175E370F28C75BFCDC028D22C7DBE9C",
+ "8EA47E475BA8AC6F2D821DA3BD212D11A3DEB9A0",
+ "7C670C7AD72B6C050C109E1790008097125433E8",
+ SHA512,
+ 'DSA1024'
+ ),
+ (
+ "sample",
+ "888FA6F7738A41BDC9846466ABDB8174C0338250AE50CE955CA16230F9CBD53E",
+ "3A1B2DBD7489D6ED7E608FD036C83AF396E290DBD602408E8677DAABD6E7445A",
+ "D26FCBA19FA3E3058FFC02CA1596CDBB6E0D20CB37B06054F7E36DED0CDBBCCF",
+ SHA1,
+ 'DSA2048'
+ ),
+ (
+ "sample",
+ "BC372967702082E1AA4FCE892209F71AE4AD25A6DFD869334E6F153BD0C4D806",
+ "DC9F4DEADA8D8FF588E98FED0AB690FFCE858DC8C79376450EB6B76C24537E2C",
+ "A65A9C3BC7BABE286B195D5DA68616DA8D47FA0097F36DD19F517327DC848CEC",
+ SHA224,
+ 'DSA2048'
+ ),
+ (
+ "sample",
+ "8926A27C40484216F052F4427CFD5647338B7B3939BC6573AF4333569D597C52",
+ "EACE8BDBBE353C432A795D9EC556C6D021F7A03F42C36E9BC87E4AC7932CC809",
+ "7081E175455F9247B812B74583E9E94F9EA79BD640DC962533B0680793A38D53",
+ SHA256,
+ 'DSA2048'
+ ),
+ (
+ "sample",
+ "C345D5AB3DA0A5BCB7EC8F8FB7A7E96069E03B206371EF7D83E39068EC564920",
+ "B2DA945E91858834FD9BF616EBAC151EDBC4B45D27D0DD4A7F6A22739F45C00B",
+ "19048B63D9FD6BCA1D9BAE3664E1BCB97F7276C306130969F63F38FA8319021B",
+ SHA384,
+ 'DSA2048'
+ ),
+ (
+ "sample",
+ "5A12994431785485B3F5F067221517791B85A597B7A9436995C89ED0374668FC",
+ "2016ED092DC5FB669B8EFB3D1F31A91EECB199879BE0CF78F02BA062CB4C942E",
+ "D0C76F84B5F091E141572A639A4FB8C230807EEA7D55C8A154A224400AFF2351",
+ SHA512,
+ 'DSA2048'
+ ),
+ (
+ "test",
+ "6EEA486F9D41A037B2C640BC5645694FF8FF4B98D066A25F76BE641CCB24BA4F",
+ "C18270A93CFC6063F57A4DFA86024F700D980E4CF4E2CB65A504397273D98EA0",
+ "414F22E5F31A8B6D33295C7539C1C1BA3A6160D7D68D50AC0D3A5BEAC2884FAA",
+ SHA1,
+ 'DSA2048'
+ ),
+ (
+ "test",
+ "06BD4C05ED74719106223BE33F2D95DA6B3B541DAD7BFBD7AC508213B6DA6670",
+ "272ABA31572F6CC55E30BF616B7A265312018DD325BE031BE0CC82AA17870EA3",
+ "E9CC286A52CCE201586722D36D1E917EB96A4EBDB47932F9576AC645B3A60806",
+ SHA224,
+ 'DSA2048'
+ ),
+ (
+ "test",
+ "1D6CE6DDA1C5D37307839CD03AB0A5CBB18E60D800937D67DFB4479AAC8DEAD7",
+ "8190012A1969F9957D56FCCAAD223186F423398D58EF5B3CEFD5A4146A4476F0",
+ "7452A53F7075D417B4B013B278D1BB8BBD21863F5E7B1CEE679CF2188E1AB19E",
+ SHA256,
+ 'DSA2048'
+ ),
+ (
+ "test",
+ "206E61F73DBE1B2DC8BE736B22B079E9DACD974DB00EEBBC5B64CAD39CF9F91C",
+ "239E66DDBE8F8C230A3D071D601B6FFBDFB5901F94D444C6AF56F732BEB954BE",
+ "6BD737513D5E72FE85D1C750E0F73921FE299B945AAD1C802F15C26A43D34961",
+ SHA384,
+ 'DSA2048'
+ ),
+ (
+ "test",
+ "AFF1651E4CD6036D57AA8B2A05CCF1A9D5A40166340ECBBDC55BE10B568AA0AA",
+ "89EC4BB1400ECCFF8E7D9AA515CD1DE7803F2DAFF09693EE7FD1353E90A68307",
+ "C9F0BDABCC0D880BB137A994CC7F3980CE91CC10FAF529FC46565B15CEA854E1",
+ SHA512,
+ 'DSA2048'
+ )
+ ]
+
+ def setUp(self):
+ # Convert DSA key components from hex strings to integers
+ # Each key is (p, q, g, x, y, desc)
+
+ from collections import namedtuple
+
+ TestKey = namedtuple('TestKey', 'p q g x y')
+ new_keys = {}
+ for k in self.keys:
+ tk = TestKey(*[t2l(y) for y in k[:-1]])
+ new_keys[k[-1]] = tk
+ self.keys = new_keys
+
+ # Convert signature encoding
+ TestSig = namedtuple('TestSig', 'message nonce result module test_key')
+ new_signatures = []
+ for message, nonce, r, s, module, test_key in self.signatures:
+ tsig = TestSig(
+ tobytes(message),
+ t2l(nonce),
+ t2b(r) + t2b(s),
+ module,
+ self.keys[test_key]
+ )
+ new_signatures.append(tsig)
+ self.signatures = new_signatures
+
+ def test1(self):
+ q = 0x4000000000000000000020108A2E0CC0D99F8A5EF
+ x = 0x09A4D6792295A7F730FC3F2B49CBC0F62E862272F
+ p = 2 * q + 1
+ y = pow(2, x, p)
+ key = DSA.construct([pow(y, 2, p), 2, p, q, x], False)
+ signer = DSS.new(key, 'deterministic-rfc6979')
+
+ # Test _int2octets
+ self.assertEqual(hexlify(signer._int2octets(x)),
+ b'009a4d6792295a7f730fc3f2b49cbc0f62e862272f')
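+ # (Per RFC 6979, int2octets() left-pads the integer to the octet length
+ # of q; q above is 163 bits, i.e. 21 octets, which is why the expected
+ # value carries a leading 00 byte.)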
+
+ # Test _bits2octets
+ h1 = SHA256.new(b"sample").digest()
+ self.assertEqual(hexlify(signer._bits2octets(h1)),
+ b'01795edf0d54db760f156d0dac04c0322b3a204224')
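+ # (bits2octets() first takes the leftmost qlen bits of the digest,
+ # reduces the result modulo q and then applies int2octets(), hence the
+ # value differs from the raw SHA-256 digest of b"sample".)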
+
+ def test2(self):
+
+ for sig in self.signatures:
+ tk = sig.test_key
+ key = DSA.construct([tk.y, tk.g, tk.p, tk.q, tk.x], False)
+ signer = DSS.new(key, 'deterministic-rfc6979')
+
+ hash_obj = sig.module.new(sig.message)
+ result = signer.sign(hash_obj)
+ self.assertEqual(sig.result, result)
+
+
+class Det_ECDSA_Tests(unittest.TestCase):
+
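+ # Fixed private keys for each NIST curve; these match the per-curve test
+ # vectors from RFC 6979 (appendix A.2) used further below.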
+ key_priv_p192 = ECC.construct(curve="P-192", d=0x6FAB034934E4C0FC9AE67F5B5659A9D7D1FEFD187EE09FD4)
+ key_pub_p192 = key_priv_p192.public_key()
+
+ key_priv_p224 = ECC.construct(curve="P-224", d=0xF220266E1105BFE3083E03EC7A3A654651F45E37167E88600BF257C1)
+ key_pub_p224 = key_priv_p224.public_key()
+
+ key_priv_p256 = ECC.construct(curve="P-256", d=0xC9AFA9D845BA75166B5C215767B1D6934E50C3DB36E89B127B8A622B120F6721)
+ key_pub_p256 = key_priv_p256.public_key()
+
+ key_priv_p384 = ECC.construct(curve="P-384", d=0x6B9D3DAD2E1B8C1C05B19875B6659F4DE23C3B667BF297BA9AA47740787137D896D5724E4C70A825F872C9EA60D2EDF5)
+ key_pub_p384 = key_priv_p384.public_key()
+
+ key_priv_p521 = ECC.construct(curve="P-521", d=0x0FAD06DAA62BA3B25D2FB40133DA757205DE67F5BB0018FEE8C86E1B68C7E75CAA896EB32F1F47C70855836A6D16FCC1466F6D8FBEC67DB89EC0C08B0E996B83538)
+ key_pub_p521 = key_priv_p521.public_key()
+
+ # This is a sequence of items:
+ # message, k, r, s, hash module
+ # taken from RFC 6979
+ signatures_p192_ = (
+ (
+ "sample",
+ "37D7CA00D2C7B0E5E412AC03BD44BA837FDD5B28CD3B0021",
+ "98C6BD12B23EAF5E2A2045132086BE3EB8EBD62ABF6698FF",
+ "57A22B07DEA9530F8DE9471B1DC6624472E8E2844BC25B64",
+ SHA1
+ ),
+ (
+ "sample",
+ "4381526B3FC1E7128F202E194505592F01D5FF4C5AF015D8",
+ "A1F00DAD97AEEC91C95585F36200C65F3C01812AA60378F5",
+ "E07EC1304C7C6C9DEBBE980B9692668F81D4DE7922A0F97A",
+ SHA224
+ ),
+ (
+ "sample",
+ "32B1B6D7D42A05CB449065727A84804FB1A3E34D8F261496",
+ "4B0B8CE98A92866A2820E20AA6B75B56382E0F9BFD5ECB55",
+ "CCDB006926EA9565CBADC840829D8C384E06DE1F1E381B85",
+ SHA256
+ ),
+ (
+ "sample",
+ "4730005C4FCB01834C063A7B6760096DBE284B8252EF4311",
+ "DA63BF0B9ABCF948FBB1E9167F136145F7A20426DCC287D5",
+ "C3AA2C960972BD7A2003A57E1C4C77F0578F8AE95E31EC5E",
+ SHA384
+ ),
+ (
+ "sample",
+ "A2AC7AB055E4F20692D49209544C203A7D1F2C0BFBC75DB1",
+ "4D60C5AB1996BD848343B31C00850205E2EA6922DAC2E4B8",
+ "3F6E837448F027A1BF4B34E796E32A811CBB4050908D8F67",
+ SHA512
+ ),
+ (
+ "test",
+ "D9CF9C3D3297D3260773A1DA7418DB5537AB8DD93DE7FA25",
+ "0F2141A0EBBC44D2E1AF90A50EBCFCE5E197B3B7D4DE036D",
+ "EB18BC9E1F3D7387500CB99CF5F7C157070A8961E38700B7",
+ SHA1
+ ),
+ (
+ "test",
+ "F5DC805F76EF851800700CCE82E7B98D8911B7D510059FBE",
+ "6945A1C1D1B2206B8145548F633BB61CEF04891BAF26ED34",
+ "B7FB7FDFC339C0B9BD61A9F5A8EAF9BE58FC5CBA2CB15293",
+ SHA224
+ ),
+ (
+ "test",
+ "5C4CE89CF56D9E7C77C8585339B006B97B5F0680B4306C6C",
+ "3A718BD8B4926C3B52EE6BBE67EF79B18CB6EB62B1AD97AE",
+ "5662E6848A4A19B1F1AE2F72ACD4B8BBE50F1EAC65D9124F",
+ SHA256
+ ),
+ (
+ "test",
+ "5AFEFB5D3393261B828DB6C91FBC68C230727B030C975693",
+ "B234B60B4DB75A733E19280A7A6034BD6B1EE88AF5332367",
+ "7994090B2D59BB782BE57E74A44C9A1C700413F8ABEFE77A",
+ SHA384
+ ),
+ (
+ "test",
+ "0758753A5254759C7CFBAD2E2D9B0792EEE44136C9480527",
+ "FE4F4AE86A58B6507946715934FE2D8FF9D95B6B098FE739",
+ "74CF5605C98FBA0E1EF34D4B5A1577A7DCF59457CAE52290",
+ SHA512
+ )
+ )
+
+ signatures_p224_ = (
+ (
+ "sample",
+ "7EEFADD91110D8DE6C2C470831387C50D3357F7F4D477054B8B426BC",
+ "22226F9D40A96E19C4A301CE5B74B115303C0F3A4FD30FC257FB57AC",
+ "66D1CDD83E3AF75605DD6E2FEFF196D30AA7ED7A2EDF7AF475403D69",
+ SHA1
+ ),
+ (
+ "sample",
+ "C1D1F2F10881088301880506805FEB4825FE09ACB6816C36991AA06D",
+ "1CDFE6662DDE1E4A1EC4CDEDF6A1F5A2FB7FBD9145C12113E6ABFD3E",
+ "A6694FD7718A21053F225D3F46197CA699D45006C06F871808F43EBC",
+ SHA224
+ ),
+ (
+ "sample",
+ "AD3029E0278F80643DE33917CE6908C70A8FF50A411F06E41DEDFCDC",
+ "61AA3DA010E8E8406C656BC477A7A7189895E7E840CDFE8FF42307BA",
+ "BC814050DAB5D23770879494F9E0A680DC1AF7161991BDE692B10101",
+ SHA256
+ ),
+ (
+ "sample",
+ "52B40F5A9D3D13040F494E83D3906C6079F29981035C7BD51E5CAC40",
+ "0B115E5E36F0F9EC81F1325A5952878D745E19D7BB3EABFABA77E953",
+ "830F34CCDFE826CCFDC81EB4129772E20E122348A2BBD889A1B1AF1D",
+ SHA384
+ ),
+ (
+ "sample",
+ "9DB103FFEDEDF9CFDBA05184F925400C1653B8501BAB89CEA0FBEC14",
+ "074BD1D979D5F32BF958DDC61E4FB4872ADCAFEB2256497CDAC30397",
+ "A4CECA196C3D5A1FF31027B33185DC8EE43F288B21AB342E5D8EB084",
+ SHA512
+ ),
+ (
+ "test",
+ "2519178F82C3F0E4F87ED5883A4E114E5B7A6E374043D8EFD329C253",
+ "DEAA646EC2AF2EA8AD53ED66B2E2DDAA49A12EFD8356561451F3E21C",
+ "95987796F6CF2062AB8135271DE56AE55366C045F6D9593F53787BD2",
+ SHA1
+ ),
+ (
+ "test",
+ "DF8B38D40DCA3E077D0AC520BF56B6D565134D9B5F2EAE0D34900524",
+ "C441CE8E261DED634E4CF84910E4C5D1D22C5CF3B732BB204DBEF019",
+ "902F42847A63BDC5F6046ADA114953120F99442D76510150F372A3F4",
+ SHA224
+ ),
+ (
+ "test",
+ "FF86F57924DA248D6E44E8154EB69F0AE2AEBAEE9931D0B5A969F904",
+ "AD04DDE87B84747A243A631EA47A1BA6D1FAA059149AD2440DE6FBA6",
+ "178D49B1AE90E3D8B629BE3DB5683915F4E8C99FDF6E666CF37ADCFD",
+ SHA256
+ ),
+ (
+ "test",
+ "7046742B839478C1B5BD31DB2E862AD868E1A45C863585B5F22BDC2D",
+ "389B92682E399B26518A95506B52C03BC9379A9DADF3391A21FB0EA4",
+ "414A718ED3249FF6DBC5B50C27F71F01F070944DA22AB1F78F559AAB",
+ SHA384
+ ),
+ (
+ "test",
+ "E39C2AA4EA6BE2306C72126D40ED77BF9739BB4D6EF2BBB1DCB6169D",
+ "049F050477C5ADD858CAC56208394B5A55BAEBBE887FDF765047C17C",
+ "077EB13E7005929CEFA3CD0403C7CDCC077ADF4E44F3C41B2F60ECFF",
+ SHA512
+ )
+ )
+
+ signatures_p256_ = (
+ (
+ "sample",
+ "882905F1227FD620FBF2ABF21244F0BA83D0DC3A9103DBBEE43A1FB858109DB4",
+ "61340C88C3AAEBEB4F6D667F672CA9759A6CCAA9FA8811313039EE4A35471D32",
+ "6D7F147DAC089441BB2E2FE8F7A3FA264B9C475098FDCF6E00D7C996E1B8B7EB",
+ SHA1
+ ),
+ (
+ "sample",
+ "103F90EE9DC52E5E7FB5132B7033C63066D194321491862059967C715985D473",
+ "53B2FFF5D1752B2C689DF257C04C40A587FABABB3F6FC2702F1343AF7CA9AA3F",
+ "B9AFB64FDC03DC1A131C7D2386D11E349F070AA432A4ACC918BEA988BF75C74C",
+ SHA224
+ ),
+ (
+ "sample",
+ "A6E3C57DD01ABE90086538398355DD4C3B17AA873382B0F24D6129493D8AAD60",
+ "EFD48B2AACB6A8FD1140DD9CD45E81D69D2C877B56AAF991C34D0EA84EAF3716",
+ "F7CB1C942D657C41D436C7A1B6E29F65F3E900DBB9AFF4064DC4AB2F843ACDA8",
+ SHA256
+ ),
+ (
+ "sample",
+ "09F634B188CEFD98E7EC88B1AA9852D734D0BC272F7D2A47DECC6EBEB375AAD4",
+ "0EAFEA039B20E9B42309FB1D89E213057CBF973DC0CFC8F129EDDDC800EF7719",
+ "4861F0491E6998B9455193E34E7B0D284DDD7149A74B95B9261F13ABDE940954",
+ SHA384
+ ),
+ (
+ "sample",
+ "5FA81C63109BADB88C1F367B47DA606DA28CAD69AA22C4FE6AD7DF73A7173AA5",
+ "8496A60B5E9B47C825488827E0495B0E3FA109EC4568FD3F8D1097678EB97F00",
+ "2362AB1ADBE2B8ADF9CB9EDAB740EA6049C028114F2460F96554F61FAE3302FE",
+ SHA512
+ ),
+ (
+ "test",
+ "8C9520267C55D6B980DF741E56B4ADEE114D84FBFA2E62137954164028632A2E",
+ "0CBCC86FD6ABD1D99E703E1EC50069EE5C0B4BA4B9AC60E409E8EC5910D81A89",
+ "01B9D7B73DFAA60D5651EC4591A0136F87653E0FD780C3B1BC872FFDEAE479B1",
+ SHA1
+ ),
+ (
+ "test",
+ "669F4426F2688B8BE0DB3A6BD1989BDAEFFF84B649EEB84F3DD26080F667FAA7",
+ "C37EDB6F0AE79D47C3C27E962FA269BB4F441770357E114EE511F662EC34A692",
+ "C820053A05791E521FCAAD6042D40AEA1D6B1A540138558F47D0719800E18F2D",
+ SHA224
+ ),
+ (
+ "test",
+ "D16B6AE827F17175E040871A1C7EC3500192C4C92677336EC2537ACAEE0008E0",
+ "F1ABB023518351CD71D881567B1EA663ED3EFCF6C5132B354F28D3B0B7D38367",
+ "019F4113742A2B14BD25926B49C649155F267E60D3814B4C0CC84250E46F0083",
+ SHA256
+ ),
+ (
+ "test",
+ "16AEFFA357260B04B1DD199693960740066C1A8F3E8EDD79070AA914D361B3B8",
+ "83910E8B48BB0C74244EBDF7F07A1C5413D61472BD941EF3920E623FBCCEBEB6",
+ "8DDBEC54CF8CD5874883841D712142A56A8D0F218F5003CB0296B6B509619F2C",
+ SHA384
+ ),
+ (
+ "test",
+ "6915D11632ACA3C40D5D51C08DAF9C555933819548784480E93499000D9F0B7F",
+ "461D93F31B6540894788FD206C07CFA0CC35F46FA3C91816FFF1040AD1581A04",
+ "39AF9F15DE0DB8D97E72719C74820D304CE5226E32DEDAE67519E840D1194E55",
+ SHA512
+ )
+ )
+
+ signatures_p384_ = (
+ (
+ "sample",
+ "4471EF7518BB2C7C20F62EAE1C387AD0C5E8E470995DB4ACF694466E6AB096630F29E5938D25106C3C340045A2DB01A7",
+ "EC748D839243D6FBEF4FC5C4859A7DFFD7F3ABDDF72014540C16D73309834FA37B9BA002899F6FDA3A4A9386790D4EB2",
+ "A3BCFA947BEEF4732BF247AC17F71676CB31A847B9FF0CBC9C9ED4C1A5B3FACF26F49CA031D4857570CCB5CA4424A443",
+ SHA1
+ ),
+ (
+ "sample",
+ "A4E4D2F0E729EB786B31FC20AD5D849E304450E0AE8E3E341134A5C1AFA03CAB8083EE4E3C45B06A5899EA56C51B5879",
+ "42356E76B55A6D9B4631C865445DBE54E056D3B3431766D0509244793C3F9366450F76EE3DE43F5A125333A6BE060122",
+ "9DA0C81787064021E78DF658F2FBB0B042BF304665DB721F077A4298B095E4834C082C03D83028EFBF93A3C23940CA8D",
+ SHA224
+ ),
+ (
+ "sample",
+ "180AE9F9AEC5438A44BC159A1FCB277C7BE54FA20E7CF404B490650A8ACC414E375572342863C899F9F2EDF9747A9B60",
+ "21B13D1E013C7FA1392D03C5F99AF8B30C570C6F98D4EA8E354B63A21D3DAA33BDE1E888E63355D92FA2B3C36D8FB2CD",
+ "F3AA443FB107745BF4BD77CB3891674632068A10CA67E3D45DB2266FA7D1FEEBEFDC63ECCD1AC42EC0CB8668A4FA0AB0",
+ SHA256
+ ),
+ (
+ "sample",
+ "94ED910D1A099DAD3254E9242AE85ABDE4BA15168EAF0CA87A555FD56D10FBCA2907E3E83BA95368623B8C4686915CF9",
+ "94EDBB92A5ECB8AAD4736E56C691916B3F88140666CE9FA73D64C4EA95AD133C81A648152E44ACF96E36DD1E80FABE46",
+ "99EF4AEB15F178CEA1FE40DB2603138F130E740A19624526203B6351D0A3A94FA329C145786E679E7B82C71A38628AC8",
+ SHA384
+ ),
+ (
+ "sample",
+ "92FC3C7183A883E24216D1141F1A8976C5B0DD797DFA597E3D7B32198BD35331A4E966532593A52980D0E3AAA5E10EC3",
+ "ED0959D5880AB2D869AE7F6C2915C6D60F96507F9CB3E047C0046861DA4A799CFE30F35CC900056D7C99CD7882433709",
+ "512C8CCEEE3890A84058CE1E22DBC2198F42323CE8ACA9135329F03C068E5112DC7CC3EF3446DEFCEB01A45C2667FDD5",
+ SHA512
+ ),
+ (
+ "test",
+ "66CC2C8F4D303FC962E5FF6A27BD79F84EC812DDAE58CF5243B64A4AD8094D47EC3727F3A3C186C15054492E30698497",
+ "4BC35D3A50EF4E30576F58CD96CE6BF638025EE624004A1F7789A8B8E43D0678ACD9D29876DAF46638645F7F404B11C7",
+ "D5A6326C494ED3FF614703878961C0FDE7B2C278F9A65FD8C4B7186201A2991695BA1C84541327E966FA7B50F7382282",
+ SHA1
+ ),
+ (
+ "test",
+ "18FA39DB95AA5F561F30FA3591DC59C0FA3653A80DAFFA0B48D1A4C6DFCBFF6E3D33BE4DC5EB8886A8ECD093F2935726",
+ "E8C9D0B6EA72A0E7837FEA1D14A1A9557F29FAA45D3E7EE888FC5BF954B5E62464A9A817C47FF78B8C11066B24080E72",
+ "07041D4A7A0379AC7232FF72E6F77B6DDB8F09B16CCE0EC3286B2BD43FA8C6141C53EA5ABEF0D8231077A04540A96B66",
+ SHA224
+ ),
+ (
+ "test",
+ "0CFAC37587532347DC3389FDC98286BBA8C73807285B184C83E62E26C401C0FAA48DD070BA79921A3457ABFF2D630AD7",
+ "6D6DEFAC9AB64DABAFE36C6BF510352A4CC27001263638E5B16D9BB51D451559F918EEDAF2293BE5B475CC8F0188636B",
+ "2D46F3BECBCC523D5F1A1256BF0C9B024D879BA9E838144C8BA6BAEB4B53B47D51AB373F9845C0514EEFB14024787265",
+ SHA256
+ ),
+ (
+ "test",
+ "015EE46A5BF88773ED9123A5AB0807962D193719503C527B031B4C2D225092ADA71F4A459BC0DA98ADB95837DB8312EA",
+ "8203B63D3C853E8D77227FB377BCF7B7B772E97892A80F36AB775D509D7A5FEB0542A7F0812998DA8F1DD3CA3CF023DB",
+ "DDD0760448D42D8A43AF45AF836FCE4DE8BE06B485E9B61B827C2F13173923E06A739F040649A667BF3B828246BAA5A5",
+ SHA384
+ ),
+ (
+ "test",
+ "3780C4F67CB15518B6ACAE34C9F83568D2E12E47DEAB6C50A4E4EE5319D1E8CE0E2CC8A136036DC4B9C00E6888F66B6C",
+ "A0D5D090C9980FAF3C2CE57B7AE951D31977DD11C775D314AF55F76C676447D06FB6495CD21B4B6E340FC236584FB277",
+ "976984E59B4C77B0E8E4460DCA3D9F20E07B9BB1F63BEEFAF576F6B2E8B224634A2092CD3792E0159AD9CEE37659C736",
+ SHA512
+ ),
+ )
+
+ signatures_p521_ = (
+ (
+ "sample",
+ "0089C071B419E1C2820962321787258469511958E80582E95D8378E0C2CCDB3CB42BEDE42F50E3FA3C71F5A76724281D31D9C89F0F91FC1BE4918DB1C03A5838D0F9",
+ "00343B6EC45728975EA5CBA6659BBB6062A5FF89EEA58BE3C80B619F322C87910FE092F7D45BB0F8EEE01ED3F20BABEC079D202AE677B243AB40B5431D497C55D75D",
+ "00E7B0E675A9B24413D448B8CC119D2BF7B2D2DF032741C096634D6D65D0DBE3D5694625FB9E8104D3B842C1B0E2D0B98BEA19341E8676AEF66AE4EBA3D5475D5D16",
+ SHA1
+ ),
+ (
+ "sample",
+ "0121415EC2CD7726330A61F7F3FA5DE14BE9436019C4DB8CB4041F3B54CF31BE0493EE3F427FB906393D895A19C9523F3A1D54BB8702BD4AA9C99DAB2597B92113F3",
+ "01776331CFCDF927D666E032E00CF776187BC9FDD8E69D0DABB4109FFE1B5E2A30715F4CC923A4A5E94D2503E9ACFED92857B7F31D7152E0F8C00C15FF3D87E2ED2E",
+ "0050CB5265417FE2320BBB5A122B8E1A32BD699089851128E360E620A30C7E17BA41A666AF126CE100E5799B153B60528D5300D08489CA9178FB610A2006C254B41F",
+ SHA224
+ ),
+ (
+ "sample",
+ "00EDF38AFCAAECAB4383358B34D67C9F2216C8382AAEA44A3DAD5FDC9C32575761793FEF24EB0FC276DFC4F6E3EC476752F043CF01415387470BCBD8678ED2C7E1A0",
+ "01511BB4D675114FE266FC4372B87682BAECC01D3CC62CF2303C92B3526012659D16876E25C7C1E57648F23B73564D67F61C6F14D527D54972810421E7D87589E1A7",
+ "004A171143A83163D6DF460AAF61522695F207A58B95C0644D87E52AA1A347916E4F7A72930B1BC06DBE22CE3F58264AFD23704CBB63B29B931F7DE6C9D949A7ECFC",
+ SHA256
+ ),
+ (
+ "sample",
+ "01546A108BC23A15D6F21872F7DED661FA8431DDBD922D0DCDB77CC878C8553FFAD064C95A920A750AC9137E527390D2D92F153E66196966EA554D9ADFCB109C4211",
+ "01EA842A0E17D2DE4F92C15315C63DDF72685C18195C2BB95E572B9C5136CA4B4B576AD712A52BE9730627D16054BA40CC0B8D3FF035B12AE75168397F5D50C67451",
+ "01F21A3CEE066E1961025FB048BD5FE2B7924D0CD797BABE0A83B66F1E35EEAF5FDE143FA85DC394A7DEE766523393784484BDF3E00114A1C857CDE1AA203DB65D61",
+ SHA384
+ ),
+ (
+ "sample",
+ "01DAE2EA071F8110DC26882D4D5EAE0621A3256FC8847FB9022E2B7D28E6F10198B1574FDD03A9053C08A1854A168AA5A57470EC97DD5CE090124EF52A2F7ECBFFD3",
+ "00C328FAFCBD79DD77850370C46325D987CB525569FB63C5D3BC53950E6D4C5F174E25A1EE9017B5D450606ADD152B534931D7D4E8455CC91F9B15BF05EC36E377FA",
+ "00617CCE7CF5064806C467F678D3B4080D6F1CC50AF26CA209417308281B68AF282623EAA63E5B5C0723D8B8C37FF0777B1A20F8CCB1DCCC43997F1EE0E44DA4A67A",
+ SHA512
+ ),
+ (
+ "test",
+ "00BB9F2BF4FE1038CCF4DABD7139A56F6FD8BB1386561BD3C6A4FC818B20DF5DDBA80795A947107A1AB9D12DAA615B1ADE4F7A9DC05E8E6311150F47F5C57CE8B222",
+ "013BAD9F29ABE20DE37EBEB823C252CA0F63361284015A3BF430A46AAA80B87B0693F0694BD88AFE4E661FC33B094CD3B7963BED5A727ED8BD6A3A202ABE009D0367",
+ "01E9BB81FF7944CA409AD138DBBEE228E1AFCC0C890FC78EC8604639CB0DBDC90F717A99EAD9D272855D00162EE9527567DD6A92CBD629805C0445282BBC916797FF",
+ SHA1
+ ),
+ (
+ "test",
+ "0040D09FCF3C8A5F62CF4FB223CBBB2B9937F6B0577C27020A99602C25A01136987E452988781484EDBBCF1C47E554E7FC901BC3085E5206D9F619CFF07E73D6F706",
+ "01C7ED902E123E6815546065A2C4AF977B22AA8EADDB68B2C1110E7EA44D42086BFE4A34B67DDC0E17E96536E358219B23A706C6A6E16BA77B65E1C595D43CAE17FB",
+ "0177336676304FCB343CE028B38E7B4FBA76C1C1B277DA18CAD2A8478B2A9A9F5BEC0F3BA04F35DB3E4263569EC6AADE8C92746E4C82F8299AE1B8F1739F8FD519A4",
+ SHA224
+ ),
+ (
+ "test",
+ "001DE74955EFAABC4C4F17F8E84D881D1310B5392D7700275F82F145C61E843841AF09035BF7A6210F5A431A6A9E81C9323354A9E69135D44EBD2FCAA7731B909258",
+ "000E871C4A14F993C6C7369501900C4BC1E9C7B0B4BA44E04868B30B41D8071042EB28C4C250411D0CE08CD197E4188EA4876F279F90B3D8D74A3C76E6F1E4656AA8",
+ "00CD52DBAA33B063C3A6CD8058A1FB0A46A4754B034FCC644766CA14DA8CA5CA9FDE00E88C1AD60CCBA759025299079D7A427EC3CC5B619BFBC828E7769BCD694E86",
+ SHA256
+ ),
+ (
+ "test",
+ "01F1FC4A349A7DA9A9E116BFDD055DC08E78252FF8E23AC276AC88B1770AE0B5DCEB1ED14A4916B769A523CE1E90BA22846AF11DF8B300C38818F713DADD85DE0C88",
+ "014BEE21A18B6D8B3C93FAB08D43E739707953244FDBE924FA926D76669E7AC8C89DF62ED8975C2D8397A65A49DCC09F6B0AC62272741924D479354D74FF6075578C",
+ "0133330865C067A0EAF72362A65E2D7BC4E461E8C8995C3B6226A21BD1AA78F0ED94FE536A0DCA35534F0CD1510C41525D163FE9D74D134881E35141ED5E8E95B979",
+ SHA384
+ ),
+ (
+ "test",
+ "016200813020EC986863BEDFC1B121F605C1215645018AEA1A7B215A564DE9EB1B38A67AA1128B80CE391C4FB71187654AAA3431027BFC7F395766CA988C964DC56D",
+ "013E99020ABF5CEE7525D16B69B229652AB6BDF2AFFCAEF38773B4B7D08725F10CDB93482FDCC54EDCEE91ECA4166B2A7C6265EF0CE2BD7051B7CEF945BABD47EE6D",
+ "01FBD0013C674AA79CB39849527916CE301C66EA7CE8B80682786AD60F98F7E78A19CA69EFF5C57400E3B3A0AD66CE0978214D13BAF4E9AC60752F7B155E2DE4DCE3",
+ SHA512
+ ),
+ )
+
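+ # Decode the textual vectors once at class-definition time: messages to
+ # bytes, k/r/s from hex to raw bytes, so each test can compare the raw
+ # r||s output of the signer directly.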
+ signatures_p192 = []
+ for a, b, c, d, e in signatures_p192_:
+ new_tv = (tobytes(a), unhexlify(b), unhexlify(c), unhexlify(d), e)
+ signatures_p192.append(new_tv)
+
+ signatures_p224 = []
+ for a, b, c, d, e in signatures_p224_:
+ new_tv = (tobytes(a), unhexlify(b), unhexlify(c), unhexlify(d), e)
+ signatures_p224.append(new_tv)
+
+ signatures_p256 = []
+ for a, b, c, d, e in signatures_p256_:
+ new_tv = (tobytes(a), unhexlify(b), unhexlify(c), unhexlify(d), e)
+ signatures_p256.append(new_tv)
+
+ signatures_p384 = []
+ for a, b, c, d, e in signatures_p384_:
+ new_tv = (tobytes(a), unhexlify(b), unhexlify(c), unhexlify(d), e)
+ signatures_p384.append(new_tv)
+
+ signatures_p521 = []
+ for a, b, c, d, e in signatures_p521_:
+ new_tv = (tobytes(a), unhexlify(b), unhexlify(c), unhexlify(d), e)
+ signatures_p521.append(new_tv)
+
+ def shortDescription(self):
+ return "Deterministic ECDSA Tests"
+
+ def test_loopback_p192(self):
+ hashed_msg = SHA512.new(b"test")
+ signer = DSS.new(self.key_priv_p192, 'deterministic-rfc6979')
+ signature = signer.sign(hashed_msg)
+
+ verifier = DSS.new(self.key_pub_p192, 'deterministic-rfc6979')
+ verifier.verify(hashed_msg, signature)
+
+ def test_loopback_p224(self):
+ hashed_msg = SHA512.new(b"test")
+ signer = DSS.new(self.key_priv_p224, 'deterministic-rfc6979')
+ signature = signer.sign(hashed_msg)
+
+ verifier = DSS.new(self.key_pub_p224, 'deterministic-rfc6979')
+ verifier.verify(hashed_msg, signature)
+
+ def test_loopback_p256(self):
+ hashed_msg = SHA512.new(b"test")
+ signer = DSS.new(self.key_priv_p256, 'deterministic-rfc6979')
+ signature = signer.sign(hashed_msg)
+
+ verifier = DSS.new(self.key_pub_p256, 'deterministic-rfc6979')
+ verifier.verify(hashed_msg, signature)
+
+ def test_loopback_p384(self):
+ hashed_msg = SHA512.new(b"test")
+ signer = DSS.new(self.key_priv_p384, 'deterministic-rfc6979')
+ signature = signer.sign(hashed_msg)
+
+ verifier = DSS.new(self.key_pub_p384, 'deterministic-rfc6979')
+ verifier.verify(hashed_msg, signature)
+
+ def test_loopback_p521(self):
+ hashed_msg = SHA512.new(b"test")
+ signer = DSS.new(self.key_priv_p521, 'deterministic-rfc6979')
+ signature = signer.sign(hashed_msg)
+
+ verifier = DSS.new(self.key_pub_p521, 'deterministic-rfc6979')
+ verifier.verify(hashed_msg, signature)
+
+ def test_data_rfc6979_p192(self):
+ signer = DSS.new(self.key_priv_p192, 'deterministic-rfc6979')
+ for message, k, r, s, module in self.signatures_p192:
+ hash_obj = module.new(message)
+ result = signer.sign(hash_obj)
+ self.assertEqual(r + s, result)
+
+ def test_data_rfc6979_p224(self):
+ signer = DSS.new(self.key_priv_p224, 'deterministic-rfc6979')
+ for message, k, r, s, module in self.signatures_p224:
+ hash_obj = module.new(message)
+ result = signer.sign(hash_obj)
+ self.assertEqual(r + s, result)
+
+ def test_data_rfc6979_p256(self):
+ signer = DSS.new(self.key_priv_p256, 'deterministic-rfc6979')
+ for message, k, r, s, module in self.signatures_p256:
+ hash_obj = module.new(message)
+ result = signer.sign(hash_obj)
+ self.assertEqual(r + s, result)
+
+ def test_data_rfc6979_p384(self):
+ signer = DSS.new(self.key_priv_p384, 'deterministic-rfc6979')
+ for message, k, r, s, module in self.signatures_p384:
+ hash_obj = module.new(message)
+ result = signer.sign(hash_obj)
+ self.assertEqual(r + s, result)
+
+ def test_data_rfc6979_p521(self):
+ signer = DSS.new(self.key_priv_p521, 'deterministic-rfc6979')
+ for message, k, r, s, module in self.signatures_p521:
+ hash_obj = module.new(message)
+ result = signer.sign(hash_obj)
+ self.assertEqual(r + s, result)
+
+
+def get_hash_module(hash_name):
+ if hash_name == "SHA-512":
+ hash_module = SHA512
+ elif hash_name == "SHA-512/224":
+ hash_module = SHA512.new(truncate="224")
+ elif hash_name == "SHA-512/256":
+ hash_module = SHA512.new(truncate="256")
+ elif hash_name == "SHA-384":
+ hash_module = SHA384
+ elif hash_name == "SHA-256":
+ hash_module = SHA256
+ elif hash_name == "SHA-224":
+ hash_module = SHA224
+ elif hash_name == "SHA-1":
+ hash_module = SHA1
+ elif hash_name == "SHA3-224":
+ hash_module = SHA3_224
+ elif hash_name == "SHA3-256":
+ hash_module = SHA3_256
+ elif hash_name == "SHA3-384":
+ hash_module = SHA3_384
+ elif hash_name == "SHA3-512":
+ hash_module = SHA3_512
+ else:
+ raise ValueError("Unknown hash algorithm: " + hash_name)
+ return hash_module
+
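+# Note that for "SHA-512/224" and "SHA-512/256" the function returns a
+# pre-configured hash object rather than a module. A minimal usage sketch,
+# assuming (as the callers below do) that pycryptodome hash objects also
+# expose a new() method:
+#
+#     h = get_hash_module("SHA-512/256")
+#     digest = h.new(b"message").digest()   # same call pattern as SHA256.new()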
+
+class TestVectorsDSAWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings, slow_tests):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._slow_tests = slow_tests
+ self._id = "None"
+ self.tv = []
+
+ def setUp(self):
+
+ def filter_dsa(group):
+ return DSA.import_key(group['keyPem'])
+
+ def filter_sha(group):
+ return get_hash_module(group['sha'])
+
+ def filter_type(group):
+ sig_type = group['type']
+ if sig_type != 'DsaVerify':
+ raise ValueError("Unknown signature type " + sig_type)
+ return sig_type
+
+ result = load_test_vectors_wycheproof(("Signature", "wycheproof"),
+ "dsa_test.json",
+ "Wycheproof DSA signature",
+ group_tag={'key': filter_dsa,
+ 'hash_module': filter_sha,
+ 'sig_type': filter_type})
+ self.tv += result
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_verify(self, tv):
+ self._id = "Wycheproof DSA Test #" + str(tv.id)
+
+ hashed_msg = tv.hash_module.new(tv.msg)
+ signer = DSS.new(tv.key, 'fips-186-3', encoding='der')
+ try:
+ signature = signer.verify(hashed_msg, tv.sig)
+ except ValueError as e:
+ if tv.warning:
+ return
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.warn(tv)
+
+ def runTest(self):
+ for tv in self.tv:
+ self.test_verify(tv)
+
+
+class TestVectorsECDSAWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings, slow_tests):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._slow_tests = slow_tests
+ self._id = "None"
+
+ def add_tests(self, filename):
+
+ def filter_ecc(group):
+ # These curves are not supported, so their test groups are skipped
+ if group['key']['curve'] in ('secp224k1', 'secp256k1',
+ 'brainpoolP224r1', 'brainpoolP224t1',
+ 'brainpoolP256r1', 'brainpoolP256t1',
+ 'brainpoolP320r1', 'brainpoolP320t1',
+ 'brainpoolP384r1', 'brainpoolP384t1',
+ 'brainpoolP512r1', 'brainpoolP512t1',
+ ):
+ return None
+ return ECC.import_key(group['keyPem'])
+
+ def filter_sha(group):
+ return get_hash_module(group['sha'])
+
+ def filter_encoding(group):
+ encoding_name = group['type']
+ if encoding_name == "EcdsaVerify":
+ return "der"
+ elif encoding_name == "EcdsaP1363Verify":
+ return "binary"
+ else:
+ raise ValueError("Unknown signature type " + encoding_name)
+
+ result = load_test_vectors_wycheproof(("Signature", "wycheproof"),
+ filename,
+ "Wycheproof ECDSA signature (%s)" % filename,
+ group_tag={'key': filter_ecc,
+ 'hash_module': filter_sha,
+ 'encoding': filter_encoding,
+ })
+ self.tv += result
+
+ def setUp(self):
+ self.tv = []
+ self.add_tests("ecdsa_secp224r1_sha224_p1363_test.json")
+ self.add_tests("ecdsa_secp224r1_sha224_test.json")
+ if self._slow_tests:
+ self.add_tests("ecdsa_secp224r1_sha256_p1363_test.json")
+ self.add_tests("ecdsa_secp224r1_sha256_test.json")
+ self.add_tests("ecdsa_secp224r1_sha3_224_test.json")
+ self.add_tests("ecdsa_secp224r1_sha3_256_test.json")
+ self.add_tests("ecdsa_secp224r1_sha3_512_test.json")
+ self.add_tests("ecdsa_secp224r1_sha512_p1363_test.json")
+ self.add_tests("ecdsa_secp224r1_sha512_test.json")
+ self.add_tests("ecdsa_secp256r1_sha256_p1363_test.json")
+ self.add_tests("ecdsa_secp256r1_sha256_test.json")
+ self.add_tests("ecdsa_secp256r1_sha3_256_test.json")
+ self.add_tests("ecdsa_secp256r1_sha3_512_test.json")
+ self.add_tests("ecdsa_secp256r1_sha512_p1363_test.json")
+ self.add_tests("ecdsa_secp256r1_sha512_test.json")
+ if self._slow_tests:
+ self.add_tests("ecdsa_secp384r1_sha3_384_test.json")
+ self.add_tests("ecdsa_secp384r1_sha3_512_test.json")
+ self.add_tests("ecdsa_secp384r1_sha384_p1363_test.json")
+ self.add_tests("ecdsa_secp384r1_sha384_test.json")
+ self.add_tests("ecdsa_secp384r1_sha512_p1363_test.json")
+ self.add_tests("ecdsa_secp384r1_sha512_test.json")
+ if self._slow_tests:
+ self.add_tests("ecdsa_secp521r1_sha3_512_test.json")
+ self.add_tests("ecdsa_secp521r1_sha512_p1363_test.json")
+ self.add_tests("ecdsa_secp521r1_sha512_test.json")
+ self.add_tests("ecdsa_test.json")
+ self.add_tests("ecdsa_webcrypto_test.json")
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_verify(self, tv):
+ self._id = "Wycheproof ECDSA Test #%d (%s, %s)" % (tv.id, tv.comment, tv.filename)
+
+ # Skip tests with unsupported curves
+ if tv.key is None:
+ return
+
+ hashed_msg = tv.hash_module.new(tv.msg)
+ signer = DSS.new(tv.key, 'fips-186-3', encoding=tv.encoding)
+ try:
+ signature = signer.verify(hashed_msg, tv.sig)
+ except ValueError as e:
+ if tv.warning:
+ return
+ if tv.comment == "k*G has a large x-coordinate":
+ return
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.warn(tv)
+
+ def runTest(self):
+ for tv in self.tv:
+ self.test_verify(tv)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(FIPS_DSA_Tests)
+ tests += list_test_cases(FIPS_ECDSA_Tests)
+ tests += list_test_cases(Det_DSA_Tests)
+ tests += list_test_cases(Det_ECDSA_Tests)
+
+ slow_tests = config.get('slow_tests')
+ if slow_tests:
+ tests += list_test_cases(FIPS_DSA_Tests_KAT)
+ tests += list_test_cases(FIPS_ECDSA_Tests_KAT)
+
+ tests += [TestVectorsDSAWycheproof(wycheproof_warnings, slow_tests)]
+ tests += [TestVectorsECDSAWycheproof(wycheproof_warnings, slow_tests)]
+
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Signature/test_eddsa.py b/lib/Crypto/SelfTest/Signature/test_eddsa.py
new file mode 100644
index 0000000..6a9a9b0
--- /dev/null
+++ b/lib/Crypto/SelfTest/Signature/test_eddsa.py
@@ -0,0 +1,578 @@
+#
+# Copyright (c) 2022, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify
+
+from Crypto.PublicKey import ECC
+from Crypto.Signature import eddsa
+from Crypto.Hash import SHA512, SHAKE256
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors_wycheproof
+from Crypto.Util.number import bytes_to_long
+
+rfc8032_tv_str = (
+ # 7.1 Ed25519
+ (
+ "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60",
+ "d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a",
+ "",
+ None,
+ "",
+ "e5564300c360ac729086e2cc806e828a"
+ "84877f1eb8e5d974d873e06522490155"
+ "5fb8821590a33bacc61e39701cf9b46b"
+ "d25bf5f0595bbe24655141438e7a100b"
+ ),
+ (
+ "4ccd089b28ff96da9db6c346ec114e0f5b8a319f35aba624da8cf6ed4fb8a6fb",
+ "3d4017c3e843895a92b70aa74d1b7ebc9c982ccf2ec4968cc0cd55f12af4660c",
+ "72",
+ None,
+ "",
+ "92a009a9f0d4cab8720e820b5f642540"
+ "a2b27b5416503f8fb3762223ebdb69da"
+ "085ac1e43e15996e458f3613d0f11d8c"
+ "387b2eaeb4302aeeb00d291612bb0c00"
+ ),
+ (
+ "c5aa8df43f9f837bedb7442f31dcb7b166d38535076f094b85ce3a2e0b4458f7",
+ "fc51cd8e6218a1a38da47ed00230f0580816ed13ba3303ac5deb911548908025",
+ "af82",
+ None,
+ "",
+ "6291d657deec24024827e69c3abe01a3"
+ "0ce548a284743a445e3680d7db5ac3ac"
+ "18ff9b538d16f290ae67f760984dc659"
+ "4a7c15e9716ed28dc027beceea1ec40a"
+ ),
+ (
+ "f5e5767cf153319517630f226876b86c8160cc583bc013744c6bf255f5cc0ee5",
+ "278117fc144c72340f67d0f2316e8386ceffbf2b2428c9c51fef7c597f1d426e",
+ "08b8b2b733424243760fe426a4b54908"
+ "632110a66c2f6591eabd3345e3e4eb98"
+ "fa6e264bf09efe12ee50f8f54e9f77b1"
+ "e355f6c50544e23fb1433ddf73be84d8"
+ "79de7c0046dc4996d9e773f4bc9efe57"
+ "38829adb26c81b37c93a1b270b20329d"
+ "658675fc6ea534e0810a4432826bf58c"
+ "941efb65d57a338bbd2e26640f89ffbc"
+ "1a858efcb8550ee3a5e1998bd177e93a"
+ "7363c344fe6b199ee5d02e82d522c4fe"
+ "ba15452f80288a821a579116ec6dad2b"
+ "3b310da903401aa62100ab5d1a36553e"
+ "06203b33890cc9b832f79ef80560ccb9"
+ "a39ce767967ed628c6ad573cb116dbef"
+ "efd75499da96bd68a8a97b928a8bbc10"
+ "3b6621fcde2beca1231d206be6cd9ec7"
+ "aff6f6c94fcd7204ed3455c68c83f4a4"
+ "1da4af2b74ef5c53f1d8ac70bdcb7ed1"
+ "85ce81bd84359d44254d95629e9855a9"
+ "4a7c1958d1f8ada5d0532ed8a5aa3fb2"
+ "d17ba70eb6248e594e1a2297acbbb39d"
+ "502f1a8c6eb6f1ce22b3de1a1f40cc24"
+ "554119a831a9aad6079cad88425de6bd"
+ "e1a9187ebb6092cf67bf2b13fd65f270"
+ "88d78b7e883c8759d2c4f5c65adb7553"
+ "878ad575f9fad878e80a0c9ba63bcbcc"
+ "2732e69485bbc9c90bfbd62481d9089b"
+ "eccf80cfe2df16a2cf65bd92dd597b07"
+ "07e0917af48bbb75fed413d238f5555a"
+ "7a569d80c3414a8d0859dc65a46128ba"
+ "b27af87a71314f318c782b23ebfe808b"
+ "82b0ce26401d2e22f04d83d1255dc51a"
+ "ddd3b75a2b1ae0784504df543af8969b"
+ "e3ea7082ff7fc9888c144da2af58429e"
+ "c96031dbcad3dad9af0dcbaaaf268cb8"
+ "fcffead94f3c7ca495e056a9b47acdb7"
+ "51fb73e666c6c655ade8297297d07ad1"
+ "ba5e43f1bca32301651339e22904cc8c"
+ "42f58c30c04aafdb038dda0847dd988d"
+ "cda6f3bfd15c4b4c4525004aa06eeff8"
+ "ca61783aacec57fb3d1f92b0fe2fd1a8"
+ "5f6724517b65e614ad6808d6f6ee34df"
+ "f7310fdc82aebfd904b01e1dc54b2927"
+ "094b2db68d6f903b68401adebf5a7e08"
+ "d78ff4ef5d63653a65040cf9bfd4aca7"
+ "984a74d37145986780fc0b16ac451649"
+ "de6188a7dbdf191f64b5fc5e2ab47b57"
+ "f7f7276cd419c17a3ca8e1b939ae49e4"
+ "88acba6b965610b5480109c8b17b80e1"
+ "b7b750dfc7598d5d5011fd2dcc5600a3"
+ "2ef5b52a1ecc820e308aa342721aac09"
+ "43bf6686b64b2579376504ccc493d97e"
+ "6aed3fb0f9cd71a43dd497f01f17c0e2"
+ "cb3797aa2a2f256656168e6c496afc5f"
+ "b93246f6b1116398a346f1a641f3b041"
+ "e989f7914f90cc2c7fff357876e506b5"
+ "0d334ba77c225bc307ba537152f3f161"
+ "0e4eafe595f6d9d90d11faa933a15ef1"
+ "369546868a7f3a45a96768d40fd9d034"
+ "12c091c6315cf4fde7cb68606937380d"
+ "b2eaaa707b4c4185c32eddcdd306705e"
+ "4dc1ffc872eeee475a64dfac86aba41c"
+ "0618983f8741c5ef68d3a101e8a3b8ca"
+ "c60c905c15fc910840b94c00a0b9d0",
+ None,
+ "",
+ "0aab4c900501b3e24d7cdf4663326a3a"
+ "87df5e4843b2cbdb67cbf6e460fec350"
+ "aa5371b1508f9f4528ecea23c436d94b"
+ "5e8fcd4f681e30a6ac00a9704a188a03"
+ ),
+ # 7.2 Ed25519ctx
+ (
+ "0305334e381af78f141cb666f6199f57"
+ "bc3495335a256a95bd2a55bf546663f6",
+ "dfc9425e4f968f7f0c29f0259cf5f9ae"
+ "d6851c2bb4ad8bfb860cfee0ab248292",
+ "f726936d19c800494e3fdaff20b276a8",
+ None,
+ "666f6f",
+ "55a4cc2f70a54e04288c5f4cd1e45a7b"
+ "b520b36292911876cada7323198dd87a"
+ "8b36950b95130022907a7fb7c4e9b2d5"
+ "f6cca685a587b4b21f4b888e4e7edb0d"
+ ),
+ (
+ "0305334e381af78f141cb666f6199f57"
+ "bc3495335a256a95bd2a55bf546663f6",
+ "dfc9425e4f968f7f0c29f0259cf5f9ae"
+ "d6851c2bb4ad8bfb860cfee0ab248292",
+ "f726936d19c800494e3fdaff20b276a8",
+ None,
+ "626172",
+ "fc60d5872fc46b3aa69f8b5b4351d580"
+ "8f92bcc044606db097abab6dbcb1aee3"
+ "216c48e8b3b66431b5b186d1d28f8ee1"
+ "5a5ca2df6668346291c2043d4eb3e90d"
+ ),
+ (
+ "0305334e381af78f141cb666f6199f57"
+ "bc3495335a256a95bd2a55bf546663f6",
+ "dfc9425e4f968f7f0c29f0259cf5f9ae"
+ "d6851c2bb4ad8bfb860cfee0ab248292",
+ "508e9e6882b979fea900f62adceaca35",
+ None,
+ "666f6f",
+ "8b70c1cc8310e1de20ac53ce28ae6e72"
+ "07f33c3295e03bb5c0732a1d20dc6490"
+ "8922a8b052cf99b7c4fe107a5abb5b2c"
+ "4085ae75890d02df26269d8945f84b0b"
+ ),
+ (
+ "ab9c2853ce297ddab85c993b3ae14bca"
+ "d39b2c682beabc27d6d4eb20711d6560",
+ "0f1d1274943b91415889152e893d80e9"
+ "3275a1fc0b65fd71b4b0dda10ad7d772",
+ "f726936d19c800494e3fdaff20b276a8",
+ None,
+ "666f6f",
+ "21655b5f1aa965996b3f97b3c849eafb"
+ "a922a0a62992f73b3d1b73106a84ad85"
+ "e9b86a7b6005ea868337ff2d20a7f5fb"
+ "d4cd10b0be49a68da2b2e0dc0ad8960f"
+ ),
+ # 7.3 Ed25519ph
+ (
+ "833fe62409237b9d62ec77587520911e"
+ "9a759cec1d19755b7da901b96dca3d42",
+ "ec172b93ad5e563bf4932c70e1245034"
+ "c35467ef2efd4d64ebf819683467e2bf",
+ "616263",
+ SHA512,
+ "",
+ "98a70222f0b8121aa9d30f813d683f80"
+ "9e462b469c7ff87639499bb94e6dae41"
+ "31f85042463c2a355a2003d062adf5aa"
+ "a10b8c61e636062aaad11c2a26083406"
+ ),
+ # 7.4 Ed448
+ (
+ "6c82a562cb808d10d632be89c8513ebf6c929f34ddfa8c9f63c9960ef6e348a3"
+ "528c8a3fcc2f044e39a3fc5b94492f8f032e7549a20098f95b",
+ "5fd7449b59b461fd2ce787ec616ad46a1da1342485a70e1f8a0ea75d80e96778"
+ "edf124769b46c7061bd6783df1e50f6cd1fa1abeafe8256180",
+ "",
+ None,
+ "",
+ "533a37f6bbe457251f023c0d88f976ae2dfb504a843e34d2074fd823d41a591f"
+ "2b233f034f628281f2fd7a22ddd47d7828c59bd0a21bfd3980ff0d2028d4b18a"
+ "9df63e006c5d1c2d345b925d8dc00b4104852db99ac5c7cdda8530a113a0f4db"
+ "b61149f05a7363268c71d95808ff2e652600"
+ ),
+ (
+ "c4eab05d357007c632f3dbb48489924d552b08fe0c353a0d4a1f00acda2c463a"
+ "fbea67c5e8d2877c5e3bc397a659949ef8021e954e0a12274e",
+ "43ba28f430cdff456ae531545f7ecd0ac834a55d9358c0372bfa0c6c6798c086"
+ "6aea01eb00742802b8438ea4cb82169c235160627b4c3a9480",
+ "03",
+ None,
+ "",
+ "26b8f91727bd62897af15e41eb43c377efb9c610d48f2335cb0bd0087810f435"
+ "2541b143c4b981b7e18f62de8ccdf633fc1bf037ab7cd779805e0dbcc0aae1cb"
+ "cee1afb2e027df36bc04dcecbf154336c19f0af7e0a6472905e799f1953d2a0f"
+ "f3348ab21aa4adafd1d234441cf807c03a00",
+ ),
+ (
+ "c4eab05d357007c632f3dbb48489924d552b08fe0c353a0d4a1f00acda2c463a"
+ "fbea67c5e8d2877c5e3bc397a659949ef8021e954e0a12274e",
+ "43ba28f430cdff456ae531545f7ecd0ac834a55d9358c0372bfa0c6c6798c086"
+ "6aea01eb00742802b8438ea4cb82169c235160627b4c3a9480",
+ "03",
+ None,
+ "666f6f",
+ "d4f8f6131770dd46f40867d6fd5d5055de43541f8c5e35abbcd001b32a89f7d2"
+ "151f7647f11d8ca2ae279fb842d607217fce6e042f6815ea000c85741de5c8da"
+ "1144a6a1aba7f96de42505d7a7298524fda538fccbbb754f578c1cad10d54d0d"
+ "5428407e85dcbc98a49155c13764e66c3c00",
+ ),
+ (
+ "cd23d24f714274e744343237b93290f511f6425f98e64459ff203e8985083ffd"
+ "f60500553abc0e05cd02184bdb89c4ccd67e187951267eb328",
+ "dcea9e78f35a1bf3499a831b10b86c90aac01cd84b67a0109b55a36e9328b1e3"
+ "65fce161d71ce7131a543ea4cb5f7e9f1d8b00696447001400",
+ "0c3e544074ec63b0265e0c",
+ None,
+ "",
+ "1f0a8888ce25e8d458a21130879b840a9089d999aaba039eaf3e3afa090a09d3"
+ "89dba82c4ff2ae8ac5cdfb7c55e94d5d961a29fe0109941e00b8dbdeea6d3b05"
+ "1068df7254c0cdc129cbe62db2dc957dbb47b51fd3f213fb8698f064774250a5"
+ "028961c9bf8ffd973fe5d5c206492b140e00",
+ ),
+ (
+ "258cdd4ada32ed9c9ff54e63756ae582fb8fab2ac721f2c8e676a72768513d93"
+ "9f63dddb55609133f29adf86ec9929dccb52c1c5fd2ff7e21b",
+ "3ba16da0c6f2cc1f30187740756f5e798d6bc5fc015d7c63cc9510ee3fd44adc"
+ "24d8e968b6e46e6f94d19b945361726bd75e149ef09817f580",
+ "64a65f3cdedcdd66811e2915",
+ None,
+ "",
+ "7eeeab7c4e50fb799b418ee5e3197ff6bf15d43a14c34389b59dd1a7b1b85b4a"
+ "e90438aca634bea45e3a2695f1270f07fdcdf7c62b8efeaf00b45c2c96ba457e"
+ "b1a8bf075a3db28e5c24f6b923ed4ad747c3c9e03c7079efb87cb110d3a99861"
+ "e72003cbae6d6b8b827e4e6c143064ff3c00",
+ ),
+ (
+ "7ef4e84544236752fbb56b8f31a23a10e42814f5f55ca037cdcc11c64c9a3b29"
+ "49c1bb60700314611732a6c2fea98eebc0266a11a93970100e",
+ "b3da079b0aa493a5772029f0467baebee5a8112d9d3a22532361da294f7bb381"
+ "5c5dc59e176b4d9f381ca0938e13c6c07b174be65dfa578e80",
+ "64a65f3cdedcdd66811e2915e7",
+ None,
+ "",
+ "6a12066f55331b6c22acd5d5bfc5d71228fbda80ae8dec26bdd306743c5027cb"
+ "4890810c162c027468675ecf645a83176c0d7323a2ccde2d80efe5a1268e8aca"
+ "1d6fbc194d3f77c44986eb4ab4177919ad8bec33eb47bbb5fc6e28196fd1caf5"
+ "6b4e7e0ba5519234d047155ac727a1053100",
+ ),
+ (
+ "d65df341ad13e008567688baedda8e9dcdc17dc024974ea5b4227b6530e339bf"
+ "f21f99e68ca6968f3cca6dfe0fb9f4fab4fa135d5542ea3f01",
+ "df9705f58edbab802c7f8363cfe5560ab1c6132c20a9f1dd163483a26f8ac53a"
+ "39d6808bf4a1dfbd261b099bb03b3fb50906cb28bd8a081f00",
+ "bd0f6a3747cd561bdddf4640a332461a4a30a12a434cd0bf40d766d9c6d458e5"
+ "512204a30c17d1f50b5079631f64eb3112182da3005835461113718d1a5ef944",
+ None,
+ "",
+ "554bc2480860b49eab8532d2a533b7d578ef473eeb58c98bb2d0e1ce488a98b1"
+ "8dfde9b9b90775e67f47d4a1c3482058efc9f40d2ca033a0801b63d45b3b722e"
+ "f552bad3b4ccb667da350192b61c508cf7b6b5adadc2c8d9a446ef003fb05cba"
+ "5f30e88e36ec2703b349ca229c2670833900",
+ ),
+ (
+ "2ec5fe3c17045abdb136a5e6a913e32ab75ae68b53d2fc149b77e504132d3756"
+ "9b7e766ba74a19bd6162343a21c8590aa9cebca9014c636df5",
+ "79756f014dcfe2079f5dd9e718be4171e2ef2486a08f25186f6bff43a9936b9b"
+ "fe12402b08ae65798a3d81e22e9ec80e7690862ef3d4ed3a00",
+ "15777532b0bdd0d1389f636c5f6b9ba734c90af572877e2d272dd078aa1e567c"
+ "fa80e12928bb542330e8409f3174504107ecd5efac61ae7504dabe2a602ede89"
+ "e5cca6257a7c77e27a702b3ae39fc769fc54f2395ae6a1178cab4738e543072f"
+ "c1c177fe71e92e25bf03e4ecb72f47b64d0465aaea4c7fad372536c8ba516a60"
+ "39c3c2a39f0e4d832be432dfa9a706a6e5c7e19f397964ca4258002f7c0541b5"
+ "90316dbc5622b6b2a6fe7a4abffd96105eca76ea7b98816af0748c10df048ce0"
+ "12d901015a51f189f3888145c03650aa23ce894c3bd889e030d565071c59f409"
+ "a9981b51878fd6fc110624dcbcde0bf7a69ccce38fabdf86f3bef6044819de11",
+ None,
+ "",
+ "c650ddbb0601c19ca11439e1640dd931f43c518ea5bea70d3dcde5f4191fe53f"
+ "00cf966546b72bcc7d58be2b9badef28743954e3a44a23f880e8d4f1cfce2d7a"
+ "61452d26da05896f0a50da66a239a8a188b6d825b3305ad77b73fbac0836ecc6"
+ "0987fd08527c1a8e80d5823e65cafe2a3d00",
+ ),
+ (
+ "872d093780f5d3730df7c212664b37b8a0f24f56810daa8382cd4fa3f77634ec"
+ "44dc54f1c2ed9bea86fafb7632d8be199ea165f5ad55dd9ce8",
+ "a81b2e8a70a5ac94ffdbcc9badfc3feb0801f258578bb114ad44ece1ec0e799d"
+ "a08effb81c5d685c0c56f64eecaef8cdf11cc38737838cf400",
+ "6ddf802e1aae4986935f7f981ba3f0351d6273c0a0c22c9c0e8339168e675412"
+ "a3debfaf435ed651558007db4384b650fcc07e3b586a27a4f7a00ac8a6fec2cd"
+ "86ae4bf1570c41e6a40c931db27b2faa15a8cedd52cff7362c4e6e23daec0fbc"
+ "3a79b6806e316efcc7b68119bf46bc76a26067a53f296dafdbdc11c77f7777e9"
+ "72660cf4b6a9b369a6665f02e0cc9b6edfad136b4fabe723d2813db3136cfde9"
+ "b6d044322fee2947952e031b73ab5c603349b307bdc27bc6cb8b8bbd7bd32321"
+ "9b8033a581b59eadebb09b3c4f3d2277d4f0343624acc817804728b25ab79717"
+ "2b4c5c21a22f9c7839d64300232eb66e53f31c723fa37fe387c7d3e50bdf9813"
+ "a30e5bb12cf4cd930c40cfb4e1fc622592a49588794494d56d24ea4b40c89fc0"
+ "596cc9ebb961c8cb10adde976a5d602b1c3f85b9b9a001ed3c6a4d3b1437f520"
+ "96cd1956d042a597d561a596ecd3d1735a8d570ea0ec27225a2c4aaff26306d1"
+ "526c1af3ca6d9cf5a2c98f47e1c46db9a33234cfd4d81f2c98538a09ebe76998"
+ "d0d8fd25997c7d255c6d66ece6fa56f11144950f027795e653008f4bd7ca2dee"
+ "85d8e90f3dc315130ce2a00375a318c7c3d97be2c8ce5b6db41a6254ff264fa6"
+ "155baee3b0773c0f497c573f19bb4f4240281f0b1f4f7be857a4e59d416c06b4"
+ "c50fa09e1810ddc6b1467baeac5a3668d11b6ecaa901440016f389f80acc4db9"
+ "77025e7f5924388c7e340a732e554440e76570f8dd71b7d640b3450d1fd5f041"
+ "0a18f9a3494f707c717b79b4bf75c98400b096b21653b5d217cf3565c9597456"
+ "f70703497a078763829bc01bb1cbc8fa04eadc9a6e3f6699587a9e75c94e5bab"
+ "0036e0b2e711392cff0047d0d6b05bd2a588bc109718954259f1d86678a579a3"
+ "120f19cfb2963f177aeb70f2d4844826262e51b80271272068ef5b3856fa8535"
+ "aa2a88b2d41f2a0e2fda7624c2850272ac4a2f561f8f2f7a318bfd5caf969614"
+ "9e4ac824ad3460538fdc25421beec2cc6818162d06bbed0c40a387192349db67"
+ "a118bada6cd5ab0140ee273204f628aad1c135f770279a651e24d8c14d75a605"
+ "9d76b96a6fd857def5e0b354b27ab937a5815d16b5fae407ff18222c6d1ed263"
+ "be68c95f32d908bd895cd76207ae726487567f9a67dad79abec316f683b17f2d"
+ "02bf07e0ac8b5bc6162cf94697b3c27cd1fea49b27f23ba2901871962506520c"
+ "392da8b6ad0d99f7013fbc06c2c17a569500c8a7696481c1cd33e9b14e40b82e"
+ "79a5f5db82571ba97bae3ad3e0479515bb0e2b0f3bfcd1fd33034efc6245eddd"
+ "7ee2086ddae2600d8ca73e214e8c2b0bdb2b047c6a464a562ed77b73d2d841c4"
+ "b34973551257713b753632efba348169abc90a68f42611a40126d7cb21b58695"
+ "568186f7e569d2ff0f9e745d0487dd2eb997cafc5abf9dd102e62ff66cba87",
+ None,
+ "",
+ "e301345a41a39a4d72fff8df69c98075a0cc082b802fc9b2b6bc503f926b65bd"
+ "df7f4c8f1cb49f6396afc8a70abe6d8aef0db478d4c6b2970076c6a0484fe76d"
+ "76b3a97625d79f1ce240e7c576750d295528286f719b413de9ada3e8eb78ed57"
+ "3603ce30d8bb761785dc30dbc320869e1a00"
+ ),
+ # 7.5 Ed448ph
+ (
+ "833fe62409237b9d62ec77587520911e9a759cec1d19755b7da901b96dca3d42"
+ "ef7822e0d5104127dc05d6dbefde69e3ab2cec7c867c6e2c49",
+ "259b71c19f83ef77a7abd26524cbdb3161b590a48f7d17de3ee0ba9c52beb743"
+ "c09428a131d6b1b57303d90d8132c276d5ed3d5d01c0f53880",
+ "616263",
+ SHAKE256,
+ "",
+ "822f6901f7480f3d5f562c592994d9693602875614483256505600bbc281ae38"
+ "1f54d6bce2ea911574932f52a4e6cadd78769375ec3ffd1b801a0d9b3f4030cd"
+ "433964b6457ea39476511214f97469b57dd32dbc560a9a94d00bff07620464a3"
+ "ad203df7dc7ce360c3cd3696d9d9fab90f00"
+ ),
+ (
+ "833fe62409237b9d62ec77587520911e9a759cec1d19755b7da901b96dca3d42"
+ "ef7822e0d5104127dc05d6dbefde69e3ab2cec7c867c6e2c49",
+ "259b71c19f83ef77a7abd26524cbdb3161b590a48f7d17de3ee0ba9c52beb743"
+ "c09428a131d6b1b57303d90d8132c276d5ed3d5d01c0f53880",
+ "616263",
+ SHAKE256,
+ "666f6f",
+ "c32299d46ec8ff02b54540982814dce9a05812f81962b649d528095916a2aa48"
+ "1065b1580423ef927ecf0af5888f90da0f6a9a85ad5dc3f280d91224ba9911a3"
+ "653d00e484e2ce232521481c8658df304bb7745a73514cdb9bf3e15784ab7128"
+ "4f8d0704a608c54a6b62d97beb511d132100",
+ ),
+)
+
+
+rfc8032_tv_bytes = []
+for tv_str in rfc8032_tv_str:
+ rfc8032_tv_bytes.append([unhexlify(i) if isinstance(i, str) else i for i in tv_str])
+
+
+class TestEdDSA(unittest.TestCase):
+
+ def test_sign(self):
+ for sk, _, msg, hashmod, ctx, exp_signature in rfc8032_tv_bytes:
+ key = eddsa.import_private_key(sk)
+ signer = eddsa.new(key, 'rfc8032', context=ctx)
+ if hashmod is None:
+ # PureEdDSA
+ signature = signer.sign(msg)
+ else:
+ # HashEdDSA
+ hashobj = hashmod.new(msg)
+ signature = signer.sign(hashobj)
+ self.assertEqual(exp_signature, signature)
+
+ def test_verify(self):
+ for _, pk, msg, hashmod, ctx, exp_signature in rfc8032_tv_bytes:
+ key = eddsa.import_public_key(pk)
+ verifier = eddsa.new(key, 'rfc8032', context=ctx)
+ if hashmod is None:
+ # PureEdDSA
+ verifier.verify(msg, exp_signature)
+ else:
+ # HashEdDSA
+ hashobj = hashmod.new(msg)
+ verifier.verify(hashobj, exp_signature)
+
+ def test_negative(self):
+ key = ECC.generate(curve="ed25519")
+ self.assertRaises(ValueError, eddsa.new, key, 'rfc9999')
+
+ nist_key = ECC.generate(curve="p256")
+ self.assertRaises(ValueError, eddsa.new, nist_key, 'rfc8032')
+
+
+class TestExport_Ed25519(unittest.TestCase):
+
+ def test_raw(self):
+ key = ECC.generate(curve="Ed25519")
+ x, y = key.pointQ.xy
+ raw = bytearray(key._export_eddsa())
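+ # RFC 8032 encodes a point as the little-endian y coordinate with the
+ # least significant bit of x stored in the top bit of the final byte.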
+ sign_x = raw[31] >> 7
+ raw[31] &= 0x7F
+ yt = bytes_to_long(raw[::-1])
+ self.assertEqual(y, yt)
+ self.assertEqual(x & 1, sign_x)
+
+ key = ECC.construct(point_x=0, point_y=1, curve="Ed25519")
+ out = key._export_eddsa()
+ self.assertEqual(b'\x01' + b'\x00' * 31, out)
+
+
+class TestExport_Ed448(unittest.TestCase):
+
+ def test_raw(self):
+ key = ECC.generate(curve="Ed448")
+ x, y = key.pointQ.xy
+ raw = bytearray(key._export_eddsa())
+ sign_x = raw[56] >> 7
+ raw[56] &= 0x7F
+ yt = bytes_to_long(raw[::-1])
+ self.assertEqual(y, yt)
+ self.assertEqual(x & 1, sign_x)
+
+ key = ECC.construct(point_x=0, point_y=1, curve="Ed448")
+ out = key._export_eddsa()
+ self.assertEqual(b'\x01' + b'\x00' * 56, out)
+
+
+class TestImport_Ed25519(unittest.TestCase):
+
+ def test_raw(self):
+ Px = 24407857220263921307776619664228778204996144802740950419837658238229122415920
+ Py = 56480760040633817885061096979765646085062883740629155052073094891081309750690
+ encoded = b'\xa2\x05\xd6\x00\xe1 \xe1\xc0\xff\x96\xee?V\x8e\xba/\xd3\x89\x06\xd7\xc4c\xe8$\xc2d\xd7a1\xfa\xde|'
+ key = eddsa.import_public_key(encoded)
+ self.assertEqual(Py, key.pointQ.y)
+ self.assertEqual(Px, key.pointQ.x)
+
+ encoded = b'\x01' + b'\x00' * 31
+ key = eddsa.import_public_key(encoded)
+ self.assertEqual(1, key.pointQ.y)
+ self.assertEqual(0, key.pointQ.x)
+
+
+class TestImport_Ed448(unittest.TestCase):
+
+ def test_raw(self):
+ Px = 0x153f42025aba3b0daecaa5cd79458b3146c7c9378c16c17b4a59bc3561113d90c169045bc12966c3f93e140c2ca0a3acc33d9205b9daf9b1
+ Py = 0x38f5c0015d3dedd576c232810dd90373b5b1d631a12894c043b7be529cbae03ede177d8fa490b56131dbcb2465d2aba777ef839fc1719b25
+ encoded = unhexlify("259b71c19f83ef77a7abd26524cbdb31"
+ "61b590a48f7d17de3ee0ba9c52beb743"
+ "c09428a131d6b1b57303d90d8132c276"
+ "d5ed3d5d01c0f53880")
+ key = eddsa.import_public_key(encoded)
+ self.assertEqual(Py, key.pointQ.y)
+ self.assertEqual(Px, key.pointQ.x)
+
+ encoded = b'\x01' + b'\x00' * 56
+ key = eddsa.import_public_key(encoded)
+ self.assertEqual(1, key.pointQ.y)
+ self.assertEqual(0, key.pointQ.x)
+
+
+class TestVectorsEdDSAWycheproof(unittest.TestCase):
+
+ def add_tests(self, filename):
+
+ def pk(group):
+ elem = group['key']['pk']
+ return unhexlify(elem)
+
+ def sk(group):
+ elem = group['key']['sk']
+ return unhexlify(elem)
+
+ result = load_test_vectors_wycheproof(("Signature", "wycheproof"),
+ filename,
+ "Wycheproof ECDSA signature (%s)"
+ % filename,
+ group_tag={'pk': pk, 'sk': sk})
+ self.tv += result
+
+ def setUp(self):
+ self.tv = []
+ self.add_tests("eddsa_test.json")
+ self.add_tests("ed448_test.json")
+
+ def test_sign(self, tv):
+ if not tv.valid:
+ return
+
+ self._id = "Wycheproof EdDSA Sign Test #%d (%s, %s)" % (tv.id, tv.comment, tv.filename)
+ key = eddsa.import_private_key(tv.sk)
+ signer = eddsa.new(key, 'rfc8032')
+ signature = signer.sign(tv.msg)
+ self.assertEqual(signature, tv.sig)
+
+ def test_verify(self, tv):
+ self._id = "Wycheproof EdDSA Verify Test #%d (%s, %s)" % (tv.id, tv.comment, tv.filename)
+ key = eddsa.import_public_key(tv.pk)
+ verifier = eddsa.new(key, 'rfc8032')
+ try:
+ verifier.verify(tv.msg, tv.sig)
+ except ValueError:
+ assert not tv.valid
+ else:
+ assert tv.valid
+
+ def runTest(self):
+ for tv in self.tv:
+ self.test_sign(tv)
+ self.test_verify(tv)
+
+
+def get_tests(config={}):
+
+ tests = []
+ tests += list_test_cases(TestExport_Ed25519)
+ tests += list_test_cases(TestExport_Ed448)
+ tests += list_test_cases(TestImport_Ed25519)
+ tests += list_test_cases(TestImport_Ed448)
+ tests += list_test_cases(TestEdDSA)
+ tests += [TestVectorsEdDSAWycheproof()]
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Signature/test_pkcs1_15.py b/lib/Crypto/SelfTest/Signature/test_pkcs1_15.py
new file mode 100644
index 0000000..8e2c6ee
--- /dev/null
+++ b/lib/Crypto/SelfTest/Signature/test_pkcs1_15.py
@@ -0,0 +1,348 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import json
+import unittest
+from binascii import unhexlify
+
+from Crypto.Util.py3compat import bchr
+from Crypto.Util.number import bytes_to_long
+from Crypto.Util.strxor import strxor
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors, load_test_vectors_wycheproof
+
+from Crypto.Hash import (SHA1, SHA224, SHA256, SHA384, SHA512, SHA3_384,
+ SHA3_224, SHA3_256, SHA3_512)
+from Crypto.PublicKey import RSA
+from Crypto.Signature import pkcs1_15
+from Crypto.Signature import PKCS1_v1_5
+
+from Crypto.Util._file_system import pycryptodome_filename
+
+
+def load_hash_by_name(hash_name):
+ return __import__("Crypto.Hash." + hash_name, globals(), locals(), ["new"])
+
+
+class FIPS_PKCS1_Verify_Tests(unittest.TestCase):
+
+ def shortDescription(self):
+ return "FIPS PKCS1 Tests (Verify)"
+
+ def test_can_sign(self):
+ test_public_key = RSA.generate(1024).public_key()
+ verifier = pkcs1_15.new(test_public_key)
+ self.assertEqual(verifier.can_sign(), False)
+
+
+class FIPS_PKCS1_Verify_Tests_KAT(unittest.TestCase):
+ pass
+
+
+test_vectors_verify = load_test_vectors(("Signature", "PKCS1-v1.5"),
+ "SigVer15_186-3.rsp",
+ "Signature Verification 186-3",
+ {'shaalg': lambda x: x,
+ 'd': lambda x: int(x),
+ 'result': lambda x: x}) or []
+
+
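+# The .rsp file interleaves key and signature records: records that carry a
+# modulus n update the current modulus, and every other record is turned into
+# one generated verification test on FIPS_PKCS1_Verify_Tests_KAT.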
+for count, tv in enumerate(test_vectors_verify):
+ if isinstance(tv, str):
+ continue
+ if hasattr(tv, "n"):
+ modulus = tv.n
+ continue
+
+ hash_module = load_hash_by_name(tv.shaalg.upper())
+ hash_obj = hash_module.new(tv.msg)
+ public_key = RSA.construct([bytes_to_long(x) for x in (modulus, tv.e)]) # type: ignore
+ verifier = pkcs1_15.new(public_key)
+
+ def positive_test(self, hash_obj=hash_obj, verifier=verifier, signature=tv.s):
+ verifier.verify(hash_obj, signature)
+
+ def negative_test(self, hash_obj=hash_obj, verifier=verifier, signature=tv.s):
+ self.assertRaises(ValueError, verifier.verify, hash_obj, signature)
+
+ if tv.result == 'f':
+ setattr(FIPS_PKCS1_Verify_Tests_KAT, "test_negative_%d" % count, negative_test)
+ else:
+ setattr(FIPS_PKCS1_Verify_Tests_KAT, "test_positive_%d" % count, positive_test)
+
+
+class FIPS_PKCS1_Sign_Tests(unittest.TestCase):
+
+ def shortDescription(self):
+ return "FIPS PKCS1 Tests (Sign)"
+
+ def test_can_sign(self):
+ test_private_key = RSA.generate(1024)
+ signer = pkcs1_15.new(test_private_key)
+ self.assertEqual(signer.can_sign(), True)
+
+
+class FIPS_PKCS1_Sign_Tests_KAT(unittest.TestCase):
+ pass
+
+
+test_vectors_sign = load_test_vectors(("Signature", "PKCS1-v1.5"),
+ "SigGen15_186-2.txt",
+ "Signature Generation 186-2",
+ {'shaalg': lambda x: x}) or []
+
+test_vectors_sign += load_test_vectors(("Signature", "PKCS1-v1.5"),
+ "SigGen15_186-3.txt",
+ "Signature Generation 186-3",
+ {'shaalg': lambda x: x}) or []
+
+for count, tv in enumerate(test_vectors_sign):
+ if isinstance(tv, str):
+ continue
+ if hasattr(tv, "n"):
+ modulus = tv.n
+ continue
+ if hasattr(tv, "e"):
+ private_key = RSA.construct([bytes_to_long(x) for x in (modulus, tv.e, tv.d)]) # type: ignore
+ signer = pkcs1_15.new(private_key)
+ continue
+
+ hash_module = load_hash_by_name(tv.shaalg.upper())
+ hash_obj = hash_module.new(tv.msg)
+
+ def new_test(self, hash_obj=hash_obj, signer=signer, result=tv.s):
+ signature = signer.sign(hash_obj)
+ self.assertEqual(signature, result)
+
+ setattr(FIPS_PKCS1_Sign_Tests_KAT, "test_%d" % count, new_test)
+
+
+class PKCS1_15_NoParams(unittest.TestCase):
+ """Verify that PKCS#1 v1.5 signatures pass even without NULL parameters in
+ the algorithm identifier (PyCrypto/LP bug #1119552)."""
+
+ rsakey = """-----BEGIN RSA PRIVATE KEY-----
+ MIIBOwIBAAJBAL8eJ5AKoIsjURpcEoGubZMxLD7+kT+TLr7UkvEtFrRhDDKMtuII
+ q19FrL4pUIMymPMSLBn3hJLe30Dw48GQM4UCAwEAAQJACUSDEp8RTe32ftq8IwG8
+ Wojl5mAd1wFiIOrZ/Uv8b963WJOJiuQcVN29vxU5+My9GPZ7RA3hrDBEAoHUDPrI
+ OQIhAPIPLz4dphiD9imAkivY31Rc5AfHJiQRA7XixTcjEkojAiEAyh/pJHks/Mlr
+ +rdPNEpotBjfV4M4BkgGAA/ipcmaAjcCIQCHvhwwKVBLzzTscT2HeUdEeBMoiXXK
+ JACAr3sJQJGxIQIgarRp+m1WSKV1MciwMaTOnbU7wxFs9DP1pva76lYBzgUCIQC9
+ n0CnZCJ6IZYqSt0H5N7+Q+2Ro64nuwV/OSQfM6sBwQ==
+ -----END RSA PRIVATE KEY-----"""
+
+ msg = b"This is a test\x0a"
+
+ # PKCS1 v1.5 signature of the message computed using SHA-1.
+ # The digestAlgorithm SEQUENCE does NOT contain the NULL parameter.
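+ # For reference, with SHA-1 the identifier that includes parameters encodes
+ # as 30 09 06 05 2b 0e 03 02 1a 05 00, while the form accepted here drops
+ # the trailing NULL (05 00) and encodes as 30 07 06 05 2b 0e 03 02 1a.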
+ sig_str = "a287a13517f716e72fb14eea8e33a8db4a4643314607e7ca3e3e28"\
+ "1893db74013dda8b855fd99f6fecedcb25fcb7a434f35cd0a101f8"\
+ "b19348e0bd7b6f152dfc"
+ signature = unhexlify(sig_str)
+
+ def runTest(self):
+ verifier = pkcs1_15.new(RSA.importKey(self.rsakey))
+ hashed = SHA1.new(self.msg)
+ verifier.verify(hashed, self.signature)
+
+
+class PKCS1_Legacy_Module_Tests(unittest.TestCase):
+ """Verify that the legacy module Crypto.Signature.PKCS1_v1_5
+ behaves as expected. The only difference is that the verify()
+ method returns True/False and does not raise exceptions."""
+
+ def shortDescription(self):
+ return "Test legacy Crypto.Signature.PKCS1_v1_5"
+
+ def runTest(self):
+ key = RSA.importKey(PKCS1_15_NoParams.rsakey)
+ hashed = SHA1.new(b"Test")
+ good_signature = PKCS1_v1_5.new(key).sign(hashed)
+ verifier = PKCS1_v1_5.new(key.public_key())
+
+ self.assertEqual(verifier.verify(hashed, good_signature), True)
+
+ # Flip a few bits in the signature
+ bad_signature = strxor(good_signature, bchr(1) * len(good_signature))
+ self.assertEqual(verifier.verify(hashed, bad_signature), False)
+
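+
+# Illustrative sketch (not part of the test suite) of the difference exercised
+# by PKCS1_Legacy_Module_Tests above: the new-style pkcs1_15 verifier raises
+# ValueError on a bad signature, while the legacy PKCS1_v1_5 verifier returns
+# a boolean. The helper name below is ours, not a library API.
+def _demo_verify_styles(key, hashed, signature):
+ try:
+ pkcs1_15.new(key.public_key()).verify(hashed, signature)
+ new_style_ok = True
+ except ValueError:
+ new_style_ok = False
+ legacy_ok = PKCS1_v1_5.new(key.public_key()).verify(hashed, signature)
+ # Both styles agree on the outcome; only the way it is reported differs.
+ return new_style_ok, legacy_ok
+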
+
+class PKCS1_All_Hashes_Tests(unittest.TestCase):
+
+ def shortDescription(self):
+ return "Test PKCS#1v1.5 signature in combination with all hashes"
+
+ def runTest(self):
+
+ key = RSA.generate(1024)
+ signer = pkcs1_15.new(key)
+ hash_names = ("MD2", "MD4", "MD5", "RIPEMD160", "SHA1",
+ "SHA224", "SHA256", "SHA384", "SHA512",
+ "SHA3_224", "SHA3_256", "SHA3_384", "SHA3_512")
+
+ for name in hash_names:
+ hashed = load_hash_by_name(name).new(b"Test")
+ signer.sign(hashed)
+
+ from Crypto.Hash import BLAKE2b, BLAKE2s
+ for hash_size in (20, 32, 48, 64):
+ hashed_b = BLAKE2b.new(digest_bytes=hash_size, data=b"Test")
+ signer.sign(hashed_b)
+ for hash_size in (16, 20, 28, 32):
+ hashed_s = BLAKE2s.new(digest_bytes=hash_size, data=b"Test")
+ signer.sign(hashed_s)
+
+
+class TestVectorsWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._id = "None"
+
+ def setUp(self):
+ self.tv = []
+ self.add_tests("rsa_sig_gen_misc_test.json")
+ self.add_tests("rsa_signature_2048_sha224_test.json")
+ self.add_tests("rsa_signature_2048_sha256_test.json")
+ self.add_tests("rsa_signature_2048_sha384_test.json")
+ self.add_tests("rsa_signature_2048_sha3_224_test.json")
+ self.add_tests("rsa_signature_2048_sha3_256_test.json")
+ self.add_tests("rsa_signature_2048_sha3_384_test.json")
+ self.add_tests("rsa_signature_2048_sha3_512_test.json")
+ self.add_tests("rsa_signature_2048_sha512_test.json")
+ self.add_tests("rsa_signature_2048_sha512_224_test.json")
+ self.add_tests("rsa_signature_2048_sha512_256_test.json")
+ self.add_tests("rsa_signature_3072_sha256_test.json")
+ self.add_tests("rsa_signature_3072_sha384_test.json")
+ self.add_tests("rsa_signature_3072_sha3_256_test.json")
+ self.add_tests("rsa_signature_3072_sha3_384_test.json")
+ self.add_tests("rsa_signature_3072_sha3_512_test.json")
+ self.add_tests("rsa_signature_3072_sha512_test.json")
+ self.add_tests("rsa_signature_3072_sha512_256_test.json")
+ self.add_tests("rsa_signature_4096_sha384_test.json")
+ self.add_tests("rsa_signature_4096_sha512_test.json")
+ self.add_tests("rsa_signature_4096_sha512_256_test.json")
+ self.add_tests("rsa_signature_test.json")
+
+ def add_tests(self, filename):
+
+ def filter_rsa(group):
+ return RSA.import_key(group['keyPem'])
+
+ def filter_sha(group):
+ hash_name = group['sha']
+ if hash_name == "SHA-512":
+ return SHA512
+ elif hash_name == "SHA-512/224":
+ return SHA512.new(truncate="224")
+ elif hash_name == "SHA-512/256":
+ return SHA512.new(truncate="256")
+ elif hash_name == "SHA3-512":
+ return SHA3_512
+ elif hash_name == "SHA-384":
+ return SHA384
+ elif hash_name == "SHA3-384":
+ return SHA3_384
+ elif hash_name == "SHA-256":
+ return SHA256
+ elif hash_name == "SHA3-256":
+ return SHA3_256
+ elif hash_name == "SHA-224":
+ return SHA224
+ elif hash_name == "SHA3-224":
+ return SHA3_224
+ elif hash_name == "SHA-1":
+ return SHA1
+ else:
+ raise ValueError("Unknown hash algorithm: " + hash_name)
+
+ def filter_type(group):
+ type_name = group['type']
+ if type_name not in ("RsassaPkcs1Verify", "RsassaPkcs1Generate"):
+ raise ValueError("Unknown type name " + type_name)
+
+ result = load_test_vectors_wycheproof(("Signature", "wycheproof"),
+ filename,
+ "Wycheproof PKCS#1v1.5 signature (%s)" % filename,
+ group_tag={'key': filter_rsa,
+ 'hash_module': filter_sha,
+ 'type': filter_type})
+ self.tv += result
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_verify(self, tv):
+ self._id = "Wycheproof RSA PKCS$#1 Test #" + str(tv.id)
+
+ hashed_msg = tv.hash_module.new(tv.msg)
+ verifier = pkcs1_15.new(tv.key)
+ try:
+ # "Acceptable" (warning) vectors may verify or not; valid vectors must
+ # verify and invalid ones must raise ValueError.
+ verifier.verify(hashed_msg, tv.sig)
+ except ValueError:
+ if tv.warning:
+ return
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.warn(tv)
+
+ def runTest(self):
+ for tv in self.tv:
+ self.test_verify(tv)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(FIPS_PKCS1_Verify_Tests)
+ tests += list_test_cases(FIPS_PKCS1_Sign_Tests)
+ tests += list_test_cases(PKCS1_15_NoParams)
+ tests += list_test_cases(PKCS1_Legacy_Module_Tests)
+ tests += list_test_cases(PKCS1_All_Hashes_Tests)
+ tests += [ TestVectorsWycheproof(wycheproof_warnings) ]
+
+ if config.get('slow_tests'):
+ tests += list_test_cases(FIPS_PKCS1_Verify_Tests_KAT)
+ tests += list_test_cases(FIPS_PKCS1_Sign_Tests_KAT)
+
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Signature/test_pss.py b/lib/Crypto/SelfTest/Signature/test_pss.py
new file mode 100644
index 0000000..535474b
--- /dev/null
+++ b/lib/Crypto/SelfTest/Signature/test_pss.py
@@ -0,0 +1,377 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+
+from Crypto.Util.py3compat import b, bchr
+from Crypto.Util.number import bytes_to_long
+from Crypto.Util.strxor import strxor
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.SelfTest.loader import load_test_vectors, load_test_vectors_wycheproof
+
+from Crypto.Hash import SHA1, SHA224, SHA256, SHA384, SHA512
+from Crypto.PublicKey import RSA
+from Crypto.Signature import pss
+from Crypto.Signature import PKCS1_PSS
+
+from Crypto.Signature.pss import MGF1
+
+
+def load_hash_by_name(hash_name):
+ return __import__("Crypto.Hash." + hash_name, globals(), locals(), ["new"])
+
+
+class PRNG(object):
+
+ def __init__(self, stream):
+ self.stream = stream
+ self.idx = 0
+
+ def __call__(self, rnd_size):
+ result = self.stream[self.idx:self.idx + rnd_size]
+ self.idx += rnd_size
+ return result
+
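+# Illustrative sketch (not part of the test suite): how the PRNG helper above
+# makes PSS deterministic. pss.new() accepts a rand_func callable that
+# supplies the salt bytes, so feeding it a fixed stream reproduces the salt of
+# a known-answer vector. The helper name and variables are ours.
+def _demo_deterministic_pss(private_key, message, fixed_salt):
+ hashed = SHA256.new(message)
+ signer = pss.new(private_key, salt_bytes=len(fixed_salt),
+ rand_func=PRNG(fixed_salt))
+ # The same key, message and salt always produce the same signature.
+ return signer.sign(hashed)
+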
+
+class PSS_Tests(unittest.TestCase):
+
+ rsa_key = b'-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQEAsvI34FgiTK8+txBvmooNGpNwk23YTU51dwNZi5yha3W4lA/Q\nvcZrDalkmD7ekWQwnduxVKa6pRSI13KBgeUOIqJoGXSWhntEtY3FEwvWOHW5AE7Q\njUzTzCiYT6TVaCcpa/7YLai+p6ai2g5f5Zfh4jSawa9uYeuggFygQq4IVW796MgV\nyqxYMM/arEj+/sKz3Viua9Rp9fFosertCYCX4DUTgW0mX9bwEnEOgjSI3pLOPXz1\n8vx+DRZS5wMCmwCUa0sKonLn3cAUPq+sGix7+eo7T0Z12MU8ud7IYVX/75r3cXiF\nPaYE2q8Le0kgOApIXbb+x74x0rNgyIh1yGygkwIDAQABAoIBABz4t1A0pLT6qHI2\nEIOaNz3mwhK0dZEqkz0GB1Dhtoax5ATgvKCFB98J3lYB08IBURe1snOsnMpOVUtg\naBRSM+QqnCUG6bnzKjAkuFP5liDE+oNQv1YpKp9CsUovuzdmI8Au3ewihl+ZTIN2\nUVNYMEOR1b5m+z2SSwWNOYsiJwpBrT7zkpdlDyjat7FiiPhMMIMXjhQFVxURMIcB\njUBtPzGvV/PG90cVDWi1wRGeeP1dDqti/jsnvykQ15KW1MqGrpeNKRmDdTy/Ucl1\nWIoYklKw3U456lgZ/rDTDB818+Tlnk35z4yF7d5ANPM8CKfqOPcnO1BCKVFzf4eq\n54wvUtkCgYEA1Zv2lp06l7rXMsvNtyYQjbFChezRDRnPwZmN4NCdRtTgGG1G0Ryd\nYz6WWoPGqZp0b4LAaaHd3W2GTcpXF8WXMKfMX1W+tMAxMozfsXRKMcHoypwuS5wT\nfJRXJCG4pvd57AB0iVUEJW2we+uGKU5Zxcx//id2nXGCpoRyViIplQsCgYEA1nVC\neHupHChht0Fh4N09cGqZHZzuwXjOUMzR3Vsfz+4WzVS3NvIgN4g5YgmQFOeKwo5y\niRq5yvubcNdFvf85eHWClg0zPAyxJCVUWigCrrOanGEhJo6re4idJvNVzu4Ucg0v\n6B3SJ1HsCda+ZSNz24bSyqRep8A+RoAaoVSFx5kCgYEAn3RvXPs9s+obnqWYiPF3\nRe5etE6Vt2vfNKwFxx6zaR6bsmBQjuUHcABWiHb6I71S0bMPI0tbrWGG8ibrYKl1\nNTLtUvVVCOS3VP7oNTWT9RTFTAnOXU7DFSo+6o/poWn3r36ff6zhDXeWWMr2OXtt\ndEQ1/2lCGEGVv+v61eVmmQUCgYABFHITPTwqwiFL1O5zPWnzyPWgaovhOYSAb6eW\n38CXQXGn8wdBJZL39J2lWrr4//l45VK6UgIhfYbY2JynSkO10ZGow8RARygVMILu\nOUlaK9lZdDvAf/NpGdUAvzTtZ9F+iYZ2OsA2JnlzyzsGM1l//3vMPWukmJk3ral0\nqoJJ8QKBgGRG3eVHnIegBbFVuMDp2NTcfuSuDVUQ1fGAwtPiFa8u81IodJnMk2pq\niXu2+0ytNA/M+SVrAnE2AgIzcaJbtr0p2srkuVM7KMWnG1vWFNjtXN8fAhf/joOv\nD+NmPL/N4uE57e40tbiU/H7KdyZaDt+5QiTmdhuyAe6CBjKsF2jy\n-----END RSA PRIVATE KEY-----'
+ msg = b'AAA'
+ tag = b'\x00[c5\xd8\xb0\x8b!D\x81\x83\x07\xc0\xdd\xb9\xb4\xb2`\x92\xe7\x02\xf1\xe1P\xea\xc3\xf0\xe3>\xddX5\xdd\x8e\xc5\x89\xef\xf3\xc2\xdc\xfeP\x02\x7f\x12+\xc9\xaf\xbb\xec\xfe\xb0\xa5\xb9\x08\x11P\x8fL\xee5\x9b\xb0k{=_\xd2\x14\xfb\x01R\xb7\xfe\x14}b\x03\x8d5Y\x89~}\xfc\xf2l\xd01-\xbd\xeb\x11\xcdV\x11\xe9l\x19k/o5\xa2\x0f\x15\xe7Q$\t=\xec\x1dAB\x19\xa5P\x9a\xaf\xa3G\x86"\xd6~\xf0<p5\x00\x86\xe0\xf3\x99\xc7+\xcfc,\\\x13)v\xcd\xff\x08o\x90\xc5\xd1\xca\x869\xf45\x1e\xfd\xa2\xf1n\xa3\xa6e\xc5\x11Q\xe4@\xbd\x17\x83x\xc9\x9b\xb5\xc7\xea\x03U\x9b\xa0\xccC\x17\xc9T\x86/\x05\x1c\xc7\x95hC\xf9b1\xbb\x05\xc3\xf0\x9a>j\xfcqkbs\x13\x84b\xe4\xbdm(\xed`\xa4F\xfb\x8f.\xe1\x8c)/_\x9eS\x98\xa4v\xb8\xdc\xfe\xf7/D\x18\x19\xb3T\x97:\xe2\x96s\xe8<\xa2\xb4\xb9\xf8/'
+
+ def test_positive_1(self):
+ key = RSA.import_key(self.rsa_key)
+ h = SHA256.new(self.msg)
+ verifier = pss.new(key)
+ verifier.verify(h, self.tag)
+
+ def test_negative_1(self):
+ key = RSA.import_key(self.rsa_key)
+ h = SHA256.new(self.msg + b'A')
+ verifier = pss.new(key)
+ tag = bytearray(self.tag)
+ self.assertRaises(ValueError, verifier.verify, h, tag)
+
+ def test_negative_2(self):
+ key = RSA.import_key(self.rsa_key)
+ h = SHA256.new(self.msg)
+ verifier = pss.new(key, salt_bytes=1000)
+ tag = bytearray(self.tag)
+ self.assertRaises(ValueError, verifier.verify, h, tag)
+
+
+class FIPS_PKCS1_Verify_Tests(unittest.TestCase):
+
+ def shortDescription(self):
+ return "FIPS PKCS1 Tests (Verify)"
+
+ def verify_positive(self, hashmod, message, public_key, salt, signature):
+ prng = PRNG(salt)
+ hashed = hashmod.new(message)
+ verifier = pss.new(public_key, salt_bytes=len(salt), rand_func=prng)
+ verifier.verify(hashed, signature)
+
+ def verify_negative(self, hashmod, message, public_key, salt, signature):
+ prng = PRNG(salt)
+ hashed = hashmod.new(message)
+ verifier = pss.new(public_key, salt_bytes=len(salt), rand_func=prng)
+ self.assertRaises(ValueError, verifier.verify, hashed, signature)
+
+ def test_can_sign(self):
+ test_public_key = RSA.generate(1024).public_key()
+ verifier = pss.new(test_public_key)
+ self.assertEqual(verifier.can_sign(), False)
+
+
+class FIPS_PKCS1_Verify_Tests_KAT(unittest.TestCase):
+ pass
+
+
+test_vectors_verify = load_test_vectors(("Signature", "PKCS1-PSS"),
+ "SigVerPSS_186-3.rsp",
+ "Signature Verification 186-3",
+ {'shaalg': lambda x: x,
+ 'result': lambda x: x}) or []
+
+
+for count, tv in enumerate(test_vectors_verify):
+ if isinstance(tv, str):
+ continue
+ if hasattr(tv, "n"):
+ modulus = tv.n
+ continue
+ if hasattr(tv, "p"):
+ continue
+
+ hash_module = load_hash_by_name(tv.shaalg.upper())
+ hash_obj = hash_module.new(tv.msg)
+ public_key = RSA.construct([bytes_to_long(x) for x in (modulus, tv.e)]) # type: ignore
+ if tv.saltval != b("\x00"):
+ prng = PRNG(tv.saltval)
+ verifier = pss.new(public_key, salt_bytes=len(tv.saltval), rand_func=prng)
+ else:
+ verifier = pss.new(public_key, salt_bytes=0)
+
+ def positive_test(self, hash_obj=hash_obj, verifier=verifier, signature=tv.s):
+ verifier.verify(hash_obj, signature)
+
+ def negative_test(self, hash_obj=hash_obj, verifier=verifier, signature=tv.s):
+ self.assertRaises(ValueError, verifier.verify, hash_obj, signature)
+
+ if tv.result == 'p':
+ setattr(FIPS_PKCS1_Verify_Tests_KAT, "test_positive_%d" % count, positive_test)
+ else:
+ setattr(FIPS_PKCS1_Verify_Tests_KAT, "test_negative_%d" % count, negative_test)
+
+
+class FIPS_PKCS1_Sign_Tests(unittest.TestCase):
+
+ def shortDescription(self):
+ return "FIPS PKCS1 Tests (Sign)"
+
+ def test_can_sign(self):
+ test_private_key = RSA.generate(1024)
+ signer = pss.new(test_private_key)
+ self.assertEqual(signer.can_sign(), True)
+
+
+class FIPS_PKCS1_Sign_Tests_KAT(unittest.TestCase):
+ pass
+
+
+test_vectors_sign = load_test_vectors(("Signature", "PKCS1-PSS"),
+ "SigGenPSS_186-2.txt",
+ "Signature Generation 186-2",
+ {'shaalg': lambda x: x}) or []
+
+test_vectors_sign += load_test_vectors(("Signature", "PKCS1-PSS"),
+ "SigGenPSS_186-3.txt",
+ "Signature Generation 186-3",
+ {'shaalg': lambda x: x}) or []
+
+for count, tv in enumerate(test_vectors_sign):
+ if isinstance(tv, str):
+ continue
+ if hasattr(tv, "n"):
+ modulus = tv.n
+ continue
+ if hasattr(tv, "e"):
+ private_key = RSA.construct([bytes_to_long(x) for x in (modulus, tv.e, tv.d)]) # type: ignore
+ continue
+
+ hash_module = load_hash_by_name(tv.shaalg.upper())
+ hash_obj = hash_module.new(tv.msg)
+ if tv.saltval != b("\x00"):
+ prng = PRNG(tv.saltval)
+ signer = pss.new(private_key, salt_bytes=len(tv.saltval), rand_func=prng)
+ else:
+ signer = pss.new(private_key, salt_bytes=0)
+
+ def new_test(self, hash_obj=hash_obj, signer=signer, result=tv.s):
+ signature = signer.sign(hash_obj)
+ self.assertEqual(signature, result)
+
+ setattr(FIPS_PKCS1_Sign_Tests_KAT, "test_%d" % count, new_test)
+
+
+class PKCS1_Legacy_Module_Tests(unittest.TestCase):
+ """Verify that the legacy module Crypto.Signature.PKCS1_PSS
+ behaves as expected. The only difference is that the verify()
+ method returns True/False and does not raise exceptions."""
+
+ def shortDescription(self):
+ return "Test legacy Crypto.Signature.PKCS1_PSS"
+
+ def runTest(self):
+ key = RSA.generate(1024)
+ hashed = SHA1.new(b("Test"))
+ good_signature = PKCS1_PSS.new(key).sign(hashed)
+ verifier = PKCS1_PSS.new(key.public_key())
+
+ self.assertEqual(verifier.verify(hashed, good_signature), True)
+
+ # Flip a few bits in the signature
+ bad_signature = strxor(good_signature, bchr(1) * len(good_signature))
+ self.assertEqual(verifier.verify(hashed, bad_signature), False)
+
+
+class PKCS1_All_Hashes_Tests(unittest.TestCase):
+
+ def shortDescription(self):
+ return "Test PKCS#1 PSS signature in combination with all hashes"
+
+ def runTest(self):
+
+ key = RSA.generate(1280)
+ signer = pss.new(key)
+ hash_names = ("MD2", "MD4", "MD5", "RIPEMD160", "SHA1",
+ "SHA224", "SHA256", "SHA384", "SHA512",
+ "SHA3_224", "SHA3_256", "SHA3_384", "SHA3_512")
+
+ for name in hash_names:
+ hashed = load_hash_by_name(name).new(b("Test"))
+ signer.sign(hashed)
+
+ from Crypto.Hash import BLAKE2b, BLAKE2s
+ for hash_size in (20, 32, 48, 64):
+ hashed_b = BLAKE2b.new(digest_bytes=hash_size, data=b("Test"))
+ signer.sign(hashed_b)
+ for hash_size in (16, 20, 28, 32):
+ hashed_s = BLAKE2s.new(digest_bytes=hash_size, data=b("Test"))
+ signer.sign(hashed_s)
+
+
+def get_hash_module(hash_name):
+ if hash_name == "SHA-512":
+ hash_module = SHA512
+ elif hash_name == "SHA-512/224":
+ hash_module = SHA512.new(truncate="224")
+ elif hash_name == "SHA-512/256":
+ hash_module = SHA512.new(truncate="256")
+ elif hash_name == "SHA-384":
+ hash_module = SHA384
+ elif hash_name == "SHA-256":
+ hash_module = SHA256
+ elif hash_name == "SHA-224":
+ hash_module = SHA224
+ elif hash_name == "SHA-1":
+ hash_module = SHA1
+ else:
+ raise ValueError("Unknown hash algorithm: " + hash_name)
+ return hash_module
+
+
+class TestVectorsPSSWycheproof(unittest.TestCase):
+
+ def __init__(self, wycheproof_warnings):
+ unittest.TestCase.__init__(self)
+ self._wycheproof_warnings = wycheproof_warnings
+ self._id = "None"
+
+ def add_tests(self, filename):
+
+ def filter_rsa(group):
+ return RSA.import_key(group['keyPem'])
+
+ def filter_sha(group):
+ return get_hash_module(group['sha'])
+
+ def filter_type(group):
+ type_name = group['type']
+ if type_name not in ("RsassaPssVerify", ):
+ raise ValueError("Unknown type name " + type_name)
+
+ def filter_slen(group):
+ return group['sLen']
+
+ def filter_mgf(group):
+ mgf = group['mgf']
+ if mgf not in ("MGF1", ):
+ raise ValueError("Unknown MGF " + mgf)
+ mgf1_hash = get_hash_module(group['mgfSha'])
+
+ def mgf(x, y, mh=mgf1_hash):
+ return MGF1(x, y, mh)
+
+ return mgf
+
+ result = load_test_vectors_wycheproof(("Signature", "wycheproof"),
+ filename,
+ "Wycheproof PSS signature (%s)" % filename,
+ group_tag={'key': filter_rsa,
+ 'hash_module': filter_sha,
+ 'sLen': filter_slen,
+ 'mgf': filter_mgf,
+ 'type': filter_type})
+ self.tv += result
+
+ def setUp(self):
+ self.tv = []
+ self.add_tests("rsa_pss_2048_sha1_mgf1_20_test.json")
+ self.add_tests("rsa_pss_2048_sha256_mgf1_0_test.json")
+ self.add_tests("rsa_pss_2048_sha256_mgf1_32_test.json")
+ self.add_tests("rsa_pss_2048_sha512_256_mgf1_28_test.json")
+ self.add_tests("rsa_pss_2048_sha512_256_mgf1_32_test.json")
+ self.add_tests("rsa_pss_3072_sha256_mgf1_32_test.json")
+ self.add_tests("rsa_pss_4096_sha256_mgf1_32_test.json")
+ self.add_tests("rsa_pss_4096_sha512_mgf1_32_test.json")
+ self.add_tests("rsa_pss_misc_test.json")
+
+ def shortDescription(self):
+ return self._id
+
+ def warn(self, tv):
+ if tv.warning and self._wycheproof_warnings:
+ import warnings
+ warnings.warn("Wycheproof warning: %s (%s)" % (self._id, tv.comment))
+
+ def test_verify(self, tv):
+ self._id = "Wycheproof RSA PSS Test #%d (%s)" % (tv.id, tv.comment)
+
+ hashed_msg = tv.hash_module.new(tv.msg)
+ verifier = pss.new(tv.key, mask_func=tv.mgf, salt_bytes=tv.sLen)
+ try:
+ verifier.verify(hashed_msg, tv.sig)
+ except ValueError:
+ if tv.warning:
+ return
+ assert not tv.valid
+ else:
+ assert tv.valid
+ self.warn(tv)
+
+ def runTest(self):
+ for tv in self.tv:
+ self.test_verify(tv)
+
+
+def get_tests(config={}):
+ wycheproof_warnings = config.get('wycheproof_warnings')
+
+ tests = []
+ tests += list_test_cases(PSS_Tests)
+ tests += list_test_cases(FIPS_PKCS1_Verify_Tests)
+ tests += list_test_cases(FIPS_PKCS1_Sign_Tests)
+ tests += list_test_cases(PKCS1_Legacy_Module_Tests)
+ tests += list_test_cases(PKCS1_All_Hashes_Tests)
+
+ if config.get('slow_tests'):
+ tests += list_test_cases(FIPS_PKCS1_Verify_Tests_KAT)
+ tests += list_test_cases(FIPS_PKCS1_Sign_Tests_KAT)
+
+ tests += [TestVectorsPSSWycheproof(wycheproof_warnings)]
+
+ return tests
+
+
+if __name__ == '__main__':
+ def suite():
+ return unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
diff --git a/lib/Crypto/SelfTest/Util/__init__.py b/lib/Crypto/SelfTest/Util/__init__.py
new file mode 100644
index 0000000..ee993db
--- /dev/null
+++ b/lib/Crypto/SelfTest/Util/__init__.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Util/__init__.py: Self-test for utility modules
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-test for utility modules"""
+
+__revision__ = "$Id$"
+
+import os
+
+def get_tests(config={}):
+ tests = []
+ from Crypto.SelfTest.Util import test_number; tests += test_number.get_tests(config=config)
+ from Crypto.SelfTest.Util import test_Counter; tests += test_Counter.get_tests(config=config)
+ from Crypto.SelfTest.Util import test_Padding; tests += test_Padding.get_tests(config=config)
+ from Crypto.SelfTest.Util import test_strxor; tests += test_strxor.get_tests(config=config)
+ from Crypto.SelfTest.Util import test_asn1; tests += test_asn1.get_tests(config=config)
+ from Crypto.SelfTest.Util import test_rfc1751; tests += test_rfc1751.get_tests(config=config)
+ return tests
+
+if __name__ == '__main__':
+ import unittest
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Util/test_Counter.py b/lib/Crypto/SelfTest/Util/test_Counter.py
new file mode 100644
index 0000000..8837a32
--- /dev/null
+++ b/lib/Crypto/SelfTest/Util/test_Counter.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Util/test_Counter: Self-test for the Crypto.Util.Counter module
+#
+# Written in 2009 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-tests for Crypto.Util.Counter"""
+
+from Crypto.Util.py3compat import *
+
+import unittest
+
+class CounterTests(unittest.TestCase):
+ def setUp(self):
+ global Counter
+ from Crypto.Util import Counter
+
+ def test_BE(self):
+ """Big endian"""
+ c = Counter.new(128)
+ c = Counter.new(128, little_endian=False)
+
+ def test_LE(self):
+ """Little endian"""
+ c = Counter.new(128, little_endian=True)
+
+ def test_nbits(self):
+ c = Counter.new(nbits=128)
+ self.assertRaises(ValueError, Counter.new, 129)
+
+ def test_prefix(self):
+ c = Counter.new(128, prefix=b("xx"))
+
+ def test_suffix(self):
+ c = Counter.new(128, suffix=b("xx"))
+
+ def test_iv(self):
+ c = Counter.new(128, initial_value=2)
+ self.assertRaises(ValueError, Counter.new, 16, initial_value=0x1FFFF)
+
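+# Illustrative sketch (not part of the test suite): the typical consumer of a
+# Counter object is CTR mode. The all-zero key and prefix below are dummy
+# values chosen only for illustration.
+def _demo_counter_with_ctr():
+ from Crypto.Util import Counter as _Counter
+ from Crypto.Cipher import AES
+ # A 64-bit fixed prefix plus a 64-bit counter fills the 128-bit AES block.
+ ctr = _Counter.new(64, prefix=b("\x00") * 8, initial_value=1)
+ cipher = AES.new(b("\x00") * 16, AES.MODE_CTR, counter=ctr)
+ return cipher.encrypt(b("Test"))
+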
+def get_tests(config={}):
+ from Crypto.SelfTest.st_common import list_test_cases
+ return list_test_cases(CounterTests)
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Util/test_Padding.py b/lib/Crypto/SelfTest/Util/test_Padding.py
new file mode 100644
index 0000000..12e2ec6
--- /dev/null
+++ b/lib/Crypto/SelfTest/Util/test_Padding.py
@@ -0,0 +1,154 @@
+#
+# SelfTest/Util/test_Padding.py: Self-test for padding functions
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify as uh
+
+from Crypto.Util.py3compat import *
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.Padding import pad, unpad
+
+class PKCS7_Tests(unittest.TestCase):
+
+ def test1(self):
+ padded = pad(b(""), 4)
+ self.assertTrue(padded == uh(b("04040404")))
+ padded = pad(b(""), 4, 'pkcs7')
+ self.assertTrue(padded == uh(b("04040404")))
+ back = unpad(padded, 4)
+ self.assertTrue(back == b(""))
+
+ def test2(self):
+ padded = pad(uh(b("12345678")), 4)
+ self.assertTrue(padded == uh(b("1234567804040404")))
+ back = unpad(padded, 4)
+ self.assertTrue(back == uh(b("12345678")))
+
+ def test3(self):
+ padded = pad(uh(b("123456")), 4)
+ self.assertTrue(padded == uh(b("12345601")))
+ back = unpad(padded, 4)
+ self.assertTrue(back == uh(b("123456")))
+
+ def test4(self):
+ padded = pad(uh(b("1234567890")), 4)
+ self.assertTrue(padded == uh(b("1234567890030303")))
+ back = unpad(padded, 4)
+ self.assertTrue(back == uh(b("1234567890")))
+
+ def testn1(self):
+ self.assertRaises(ValueError, pad, uh(b("12")), 4, 'pkcs8')
+
+ def testn2(self):
+ self.assertRaises(ValueError, unpad, b("\0\0\0"), 4)
+ self.assertRaises(ValueError, unpad, b(""), 4)
+
+ def testn3(self):
+ self.assertRaises(ValueError, unpad, b("123456\x02"), 4)
+ self.assertRaises(ValueError, unpad, b("123456\x00"), 4)
+ self.assertRaises(ValueError, unpad, b("123456\x05\x05\x05\x05\x05"), 4)
+
+class X923_Tests(unittest.TestCase):
+
+ def test1(self):
+ padded = pad(b(""), 4, 'x923')
+ self.assertTrue(padded == uh(b("00000004")))
+ back = unpad(padded, 4, 'x923')
+ self.assertTrue(back == b(""))
+
+ def test2(self):
+ padded = pad(uh(b("12345678")), 4, 'x923')
+ self.assertTrue(padded == uh(b("1234567800000004")))
+ back = unpad(padded, 4, 'x923')
+ self.assertTrue(back == uh(b("12345678")))
+
+ def test3(self):
+ padded = pad(uh(b("123456")), 4, 'x923')
+ self.assertTrue(padded == uh(b("12345601")))
+ back = unpad(padded, 4, 'x923')
+ self.assertTrue(back == uh(b("123456")))
+
+ def test4(self):
+ padded = pad(uh(b("1234567890")), 4, 'x923')
+ self.assertTrue(padded == uh(b("1234567890000003")))
+ back = unpad(padded, 4, 'x923')
+ self.assertTrue(back == uh(b("1234567890")))
+
+ def testn1(self):
+ self.assertRaises(ValueError, unpad, b("123456\x02"), 4, 'x923')
+ self.assertRaises(ValueError, unpad, b("123456\x00"), 4, 'x923')
+ self.assertRaises(ValueError, unpad, b("123456\x00\x00\x00\x00\x05"), 4, 'x923')
+ self.assertRaises(ValueError, unpad, b(""), 4, 'x923')
+
+class ISO7816_Tests(unittest.TestCase):
+
+ def test1(self):
+ padded = pad(b(""), 4, 'iso7816')
+ self.assertTrue(padded == uh(b("80000000")))
+ back = unpad(padded, 4, 'iso7816')
+ self.assertTrue(back == b(""))
+
+ def test2(self):
+ padded = pad(uh(b("12345678")), 4, 'iso7816')
+ self.assertTrue(padded == uh(b("1234567880000000")))
+ back = unpad(padded, 4, 'iso7816')
+ self.assertTrue(back == uh(b("12345678")))
+
+ def test3(self):
+ padded = pad(uh(b("123456")), 4, 'iso7816')
+ self.assertTrue(padded == uh(b("12345680")))
+ back = unpad(padded, 4, 'iso7816')
+ self.assertTrue(back == uh(b("123456")))
+
+ def test4(self):
+ padded = pad(uh(b("1234567890")), 4, 'iso7816')
+ self.assertTrue(padded == uh(b("1234567890800000")))
+ back = unpad(padded, 4, 'iso7816')
+ self.assertTrue(back == uh(b("1234567890")))
+
+ def testn1(self):
+ self.assertRaises(ValueError, unpad, b("123456\x81"), 4, 'iso7816')
+ self.assertRaises(ValueError, unpad, b(""), 4, 'iso7816')
+
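+# Illustrative sketch (not part of the test suite) summarizing the three
+# schemes exercised above for a 4-byte block and the 3-byte message 123456:
+# pkcs7 -> 12345601 (N padding bytes, each of value N)
+# x923 -> 12345601 (zero bytes, then a final length byte N)
+# iso7816 -> 12345680 (a 0x80 marker, then zero bytes)
+def _demo_padding_roundtrip(data, block_size):
+ results = {}
+ for style in ('pkcs7', 'x923', 'iso7816'):
+ padded = pad(data, block_size, style)
+ # unpad() must return the original data for every scheme.
+ assert unpad(padded, block_size, style) == data
+ results[style] = padded
+ return results
+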
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(PKCS7_Tests)
+ tests += list_test_cases(X923_Tests)
+ tests += list_test_cases(ISO7816_Tests)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
diff --git a/lib/Crypto/SelfTest/Util/test_asn1.py b/lib/Crypto/SelfTest/Util/test_asn1.py
new file mode 100644
index 0000000..68292f3
--- /dev/null
+++ b/lib/Crypto/SelfTest/Util/test_asn1.py
@@ -0,0 +1,784 @@
+#
+# SelfTest/Util/test_asn1.py: Self-test for the Crypto.Util.asn1 module
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Self-tests for Crypto.Util.asn1"""
+
+import unittest
+
+from Crypto.Util.py3compat import *
+from Crypto.Util.asn1 import (DerObject, DerSetOf, DerInteger,
+ DerBitString,
+ DerObjectId, DerNull, DerOctetString,
+ DerSequence)
+
+class DerObjectTests(unittest.TestCase):
+
+ def testObjInit1(self):
+ # Fail with invalid tag format (must be 1 byte)
+ self.assertRaises(ValueError, DerObject, b('\x00\x99'))
+ # Fail with invalid implicit tag (must be <0x1F)
+ self.assertRaises(ValueError, DerObject, 0x1F)
+
+ # ------
+
+ def testObjEncode1(self):
+ # No payload
+ der = DerObject(b('\x02'))
+ self.assertEqual(der.encode(), b('\x02\x00'))
+ # Small payload (primitive)
+ der.payload = b('\x45')
+ self.assertEqual(der.encode(), b('\x02\x01\x45'))
+ # Invariant
+ self.assertEqual(der.encode(), b('\x02\x01\x45'))
+ # Initialize with numerical tag
+ der = DerObject(0x04)
+ der.payload = b('\x45')
+ self.assertEqual(der.encode(), b('\x04\x01\x45'))
+ # Initialize with constructed type
+ der = DerObject(b('\x10'), constructed=True)
+ self.assertEqual(der.encode(), b('\x30\x00'))
+
+ def testObjEncode2(self):
+ # Initialize with payload
+ der = DerObject(0x03, b('\x12\x12'))
+ self.assertEqual(der.encode(), b('\x03\x02\x12\x12'))
+
+ def testObjEncode3(self):
+ # Long payload
+ der = DerObject(b('\x10'))
+ der.payload = b("0")*128
+ self.assertEqual(der.encode(), b('\x10\x81\x80' + "0"*128))
+
+ def testObjEncode4(self):
+ # Implicit tags (constructed)
+ der = DerObject(0x10, implicit=1, constructed=True)
+ der.payload = b('ppll')
+ self.assertEqual(der.encode(), b('\xa1\x04ppll'))
+ # Implicit tags (primitive)
+ der = DerObject(0x02, implicit=0x1E, constructed=False)
+ der.payload = b('ppll')
+ self.assertEqual(der.encode(), b('\x9E\x04ppll'))
+
+ def testObjEncode5(self):
+ # Encode type with explicit tag
+ der = DerObject(0x10, explicit=5)
+ der.payload = b("xxll")
+ self.assertEqual(der.encode(), b("\xa5\x06\x10\x04xxll"))
+
+ # -----
+
+ def testObjDecode1(self):
+ # Decode short payload
+ der = DerObject(0x02)
+ der.decode(b('\x02\x02\x01\x02'))
+ self.assertEqual(der.payload, b("\x01\x02"))
+ self.assertEqual(der._tag_octet, 0x02)
+
+ def testObjDecode2(self):
+ # Decode long payload
+ der = DerObject(0x02)
+ der.decode(b('\x02\x81\x80' + "1"*128))
+ self.assertEqual(der.payload, b("1")*128)
+ self.assertEqual(der._tag_octet, 0x02)
+
+ def testObjDecode3(self):
+ # Decode payload with too much data gives error
+ der = DerObject(0x02)
+ self.assertRaises(ValueError, der.decode, b('\x02\x02\x01\x02\xFF'))
+ # Decode payload with too little data gives error
+ der = DerObject(0x02)
+ self.assertRaises(ValueError, der.decode, b('\x02\x02\x01'))
+
+ def testObjDecode4(self):
+ # Decode implicit tag (primitive)
+ der = DerObject(0x02, constructed=False, implicit=0xF)
+ self.assertRaises(ValueError, der.decode, b('\x02\x02\x01\x02'))
+ der.decode(b('\x8F\x01\x00'))
+ self.assertEqual(der.payload, b('\x00'))
+ # Decode implicit tag (constructed)
+ der = DerObject(0x02, constructed=True, implicit=0xF)
+ self.assertRaises(ValueError, der.decode, b('\x02\x02\x01\x02'))
+ der.decode(b('\xAF\x01\x00'))
+ self.assertEqual(der.payload, b('\x00'))
+
+ def testObjDecode5(self):
+ # Decode payload with unexpected tag gives error
+ der = DerObject(0x02)
+ self.assertRaises(ValueError, der.decode, b('\x03\x02\x01\x02'))
+
+ def testObjDecode6(self):
+ # Arbitrary DER object
+ der = DerObject()
+ der.decode(b('\x65\x01\x88'))
+ self.assertEqual(der._tag_octet, 0x65)
+ self.assertEqual(der.payload, b('\x88'))
+
+ def testObjDecode7(self):
+ # Decode explicit tag
+ der = DerObject(0x10, explicit=5)
+ der.decode(b("\xa5\x06\x10\x04xxll"))
+ self.assertEqual(der._inner_tag_octet, 0x10)
+ self.assertEqual(der.payload, b('xxll'))
+
+ # Explicit tag may be 0
+ der = DerObject(0x10, explicit=0)
+ der.decode(b("\xa0\x06\x10\x04xxll"))
+ self.assertEqual(der._inner_tag_octet, 0x10)
+ self.assertEqual(der.payload, b('xxll'))
+
+ def testObjDecode8(self):
+ # Verify that decode returns the object
+ der = DerObject(0x02)
+ self.assertEqual(der, der.decode(b('\x02\x02\x01\x02')))
+
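+# Illustrative sketch (not part of the test suite): the DER byte layout relied
+# on by the tests above. Every element is tag | length | payload; payloads of
+# up to 127 bytes use a single length octet, longer ones use the long form
+# (0x81 means one length octet follows, as in testObjEncode3). The helper
+# below covers only these two forms and is ours, not a library function.
+def _demo_der_tlv(tag_octet, payload):
+ if len(payload) < 128:
+ length = bchr(len(payload))
+ else:
+ length = b('\x81') + bchr(len(payload)) # payloads of up to 255 bytes
+ return bchr(tag_octet) + length + payload
+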
+class DerIntegerTests(unittest.TestCase):
+
+ def testInit1(self):
+ der = DerInteger(1)
+ self.assertEqual(der.encode(), b('\x02\x01\x01'))
+
+ def testEncode1(self):
+ # Single-byte integers
+ # Value 0
+ der = DerInteger(0)
+ self.assertEqual(der.encode(), b('\x02\x01\x00'))
+ # Value 1
+ der = DerInteger(1)
+ self.assertEqual(der.encode(), b('\x02\x01\x01'))
+ # Value 127
+ der = DerInteger(127)
+ self.assertEqual(der.encode(), b('\x02\x01\x7F'))
+
+ def testEncode2(self):
+ # Multi-byte integers
+ # Value 128
+ der = DerInteger(128)
+ self.assertEqual(der.encode(), b('\x02\x02\x00\x80'))
+ # Value 0x180
+ der = DerInteger(0x180)
+ self.assertEqual(der.encode(), b('\x02\x02\x01\x80'))
+ # One very long integer
+ der = DerInteger(2**2048)
+ self.assertEqual(der.encode(),
+ b('\x02\x82\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00'))
+
+ def testEncode3(self):
+ # Negative integers
+ # Value -1
+ der = DerInteger(-1)
+ self.assertEqual(der.encode(), b('\x02\x01\xFF'))
+ # Value -128
+ der = DerInteger(-128)
+ self.assertEqual(der.encode(), b('\x02\x01\x80'))
+ # Value -87873
+ der = DerInteger(-87873)
+ self.assertEqual(der.encode(), b('\x02\x03\xFE\xA8\xBF'))
+
+ def testEncode4(self):
+ # Explicit encoding
+ number = DerInteger(0x34, explicit=3)
+ self.assertEqual(number.encode(), b('\xa3\x03\x02\x01\x34'))
+
+ # -----
+
+ def testDecode1(self):
+ # Single-byte integer
+ der = DerInteger()
+ # Value 0
+ der.decode(b('\x02\x01\x00'))
+ self.assertEqual(der.value, 0)
+ # Value 1
+ der.decode(b('\x02\x01\x01'))
+ self.assertEqual(der.value, 1)
+ # Value 127
+ der.decode(b('\x02\x01\x7F'))
+ self.assertEqual(der.value, 127)
+
+ def testDecode2(self):
+ # Multi-byte integer
+ der = DerInteger()
+ # Value 0x180
+ der.decode(b('\x02\x02\x01\x80'))
+ self.assertEqual(der.value,0x180)
+ # One very long integer
+ der.decode(
+ b('\x02\x82\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00'))
+ self.assertEqual(der.value,2**2048)
+
+ def testDecode3(self):
+ # Negative integer
+ der = DerInteger()
+ # Value -1
+ der.decode(b('\x02\x01\xFF'))
+ self.assertEqual(der.value, -1)
+ # Value -32768
+ der.decode(b('\x02\x02\x80\x00'))
+ self.assertEqual(der.value, -32768)
+
+ def testDecode5(self):
+ # We still accept BER integer format
+ der = DerInteger()
+ # Redundant leading zeroes
+ der.decode(b('\x02\x02\x00\x01'))
+ self.assertEqual(der.value, 1)
+ # Redundant leading 0xFF
+ der.decode(b('\x02\x02\xFF\xFF'))
+ self.assertEqual(der.value, -1)
+ # Empty payload
+ der.decode(b('\x02\x00'))
+ self.assertEqual(der.value, 0)
+
+ def testDecode6(self):
+ # Explicit encoding
+ number = DerInteger(explicit=3)
+ number.decode(b('\xa3\x03\x02\x01\x34'))
+ self.assertEqual(number.value, 0x34)
+
+ def testDecode7(self):
+ # Verify decode returns the DerInteger
+ der = DerInteger()
+ self.assertEqual(der, der.decode(b('\x02\x01\x7F')))
+
+ ###
+
+ def testStrict1(self):
+ number = DerInteger()
+
+ number.decode(b'\x02\x02\x00\x01')
+ number.decode(b'\x02\x02\x00\x7F')
+ self.assertRaises(ValueError, number.decode, b'\x02\x02\x00\x01', strict=True)
+ self.assertRaises(ValueError, number.decode, b'\x02\x02\x00\x7F', strict=True)
+
+ ###
+
+ def testErrDecode1(self):
+ # Wide length field
+ der = DerInteger()
+ self.assertRaises(ValueError, der.decode, b('\x02\x81\x01\x01'))
+
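+# Worked note on the encodings checked above: DER INTEGERs are big-endian
+# two's complement using the minimal number of octets. 127 fits in one octet
+# (0x7F); 128 needs a leading 0x00 (00 80) to keep the sign bit clear; and
+# -87873 = -0x15741, with 2**24 - 0x15741 = 0xFEA8BF, hence the three octets
+# FE A8 BF seen in testEncode3.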
+
+class DerSequenceTests(unittest.TestCase):
+
+ def testInit1(self):
+ der = DerSequence([1, DerInteger(2), b('0\x00')])
+ self.assertEqual(der.encode(), b('0\x08\x02\x01\x01\x02\x01\x020\x00'))
+
+ def testEncode1(self):
+ # Empty sequence
+ der = DerSequence()
+ self.assertEqual(der.encode(), b('0\x00'))
+ self.assertFalse(der.hasOnlyInts())
+ # One single-byte integer (zero)
+ der.append(0)
+ self.assertEqual(der.encode(), b('0\x03\x02\x01\x00'))
+ self.assertEqual(der.hasInts(),1)
+ self.assertEqual(der.hasInts(False),1)
+ self.assertTrue(der.hasOnlyInts())
+ self.assertTrue(der.hasOnlyInts(False))
+ # Invariant
+ self.assertEqual(der.encode(), b('0\x03\x02\x01\x00'))
+
+ def testEncode2(self):
+ # Indexing
+ der = DerSequence()
+ der.append(0)
+ der[0] = 1
+ self.assertEqual(len(der),1)
+ self.assertEqual(der[0],1)
+ self.assertEqual(der[-1],1)
+ self.assertEqual(der.encode(), b('0\x03\x02\x01\x01'))
+ #
+ der[:] = [1]
+ self.assertEqual(len(der),1)
+ self.assertEqual(der[0],1)
+ self.assertEqual(der.encode(), b('0\x03\x02\x01\x01'))
+
+ def testEncode3(self):
+ # One multi-byte integer (non-zero)
+ der = DerSequence()
+ der.append(0x180)
+ self.assertEqual(der.encode(), b('0\x04\x02\x02\x01\x80'))
+
+ def testEncode4(self):
+ # One very long integer
+ der = DerSequence()
+ der.append(2**2048)
+ self.assertEqual(der.encode(), b('0\x82\x01\x05')+
+ b('\x02\x82\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00'))
+
+ def testEncode5(self):
+ der = DerSequence()
+ der += 1
+ der += b('\x30\x00')
+ self.assertEqual(der.encode(), b('\x30\x05\x02\x01\x01\x30\x00'))
+
+ def testEncode6(self):
+ # Two positive integers
+ der = DerSequence()
+ der.append(0x180)
+ der.append(0xFF)
+ self.assertEqual(der.encode(), b('0\x08\x02\x02\x01\x80\x02\x02\x00\xff'))
+ self.assertTrue(der.hasOnlyInts())
+ self.assertTrue(der.hasOnlyInts(False))
+ # Two mixed integers
+ der = DerSequence()
+ der.append(2)
+ der.append(-2)
+ self.assertEqual(der.encode(), b('0\x06\x02\x01\x02\x02\x01\xFE'))
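+ # With the default argument only non-negative integers are counted, so the
+ # mixed sequence [2, -2] reports 1; passing False also counts the negative
+ # one and reports 2.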
+ self.assertEqual(der.hasInts(), 1)
+ self.assertEqual(der.hasInts(False), 2)
+ self.assertFalse(der.hasOnlyInts())
+ self.assertTrue(der.hasOnlyInts(False))
+ #
+ der.append(0x01)
+ der[1:] = [9,8]
+ self.assertEqual(len(der),3)
+ self.assertEqual(der[1:],[9,8])
+ self.assertEqual(der[1:-1],[9])
+ self.assertEqual(der.encode(), b('0\x09\x02\x01\x02\x02\x01\x09\x02\x01\x08'))
+
+ def testEncode7(self):
+ # One integer and another type (already encoded)
+ der = DerSequence()
+ der.append(0x180)
+ der.append(b('0\x03\x02\x01\x05'))
+ self.assertEqual(der.encode(), b('0\x09\x02\x02\x01\x800\x03\x02\x01\x05'))
+ self.assertFalse(der.hasOnlyInts())
+
+ def testEncode8(self):
+ # One integer and another type (yet to encode)
+ der = DerSequence()
+ der.append(0x180)
+ der.append(DerSequence([5]))
+ self.assertEqual(der.encode(), b('0\x09\x02\x02\x01\x800\x03\x02\x01\x05'))
+ self.assertFalse(der.hasOnlyInts())
+
+ ####
+
+ def testDecode1(self):
+ # Empty sequence
+ der = DerSequence()
+ der.decode(b('0\x00'))
+ self.assertEqual(len(der),0)
+ # One single-byte integer (zero)
+ der.decode(b('0\x03\x02\x01\x00'))
+ self.assertEqual(len(der),1)
+ self.assertEqual(der[0],0)
+ # Invariant
+ der.decode(b('0\x03\x02\x01\x00'))
+ self.assertEqual(len(der),1)
+ self.assertEqual(der[0],0)
+
+ def testDecode2(self):
+ # One single-byte integer (non-zero)
+ der = DerSequence()
+ der.decode(b('0\x03\x02\x01\x7f'))
+ self.assertEqual(len(der),1)
+ self.assertEqual(der[0],127)
+
+ def testDecode4(self):
+ # One very long integer
+ der = DerSequence()
+ der.decode(b('0\x82\x01\x05')+
+ b('\x02\x82\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')+
+ b('\x00\x00\x00\x00\x00\x00\x00\x00\x00'))
+ self.assertEqual(len(der),1)
+ self.assertEqual(der[0],2**2048)
+
+ def testDecode6(self):
+ # Two integers
+ der = DerSequence()
+ der.decode(b('0\x08\x02\x02\x01\x80\x02\x02\x00\xff'))
+ self.assertEqual(len(der),2)
+ self.assertEqual(der[0],0x180)
+ self.assertEqual(der[1],0xFF)
+
+ def testDecode7(self):
+ # One integer and 2 other types
+ der = DerSequence()
+ der.decode(b('0\x0A\x02\x02\x01\x80\x24\x02\xb6\x63\x12\x00'))
+ self.assertEqual(len(der),3)
+ self.assertEqual(der[0],0x180)
+ self.assertEqual(der[1],b('\x24\x02\xb6\x63'))
+ self.assertEqual(der[2],b('\x12\x00'))
+
+ def testDecode8(self):
+ # Only 2 other types
+ der = DerSequence()
+ der.decode(b('0\x06\x24\x02\xb6\x63\x12\x00'))
+ self.assertEqual(len(der),2)
+ self.assertEqual(der[0],b('\x24\x02\xb6\x63'))
+ self.assertEqual(der[1],b('\x12\x00'))
+ self.assertEqual(der.hasInts(), 0)
+ self.assertEqual(der.hasInts(False), 0)
+ self.assertFalse(der.hasOnlyInts())
+ self.assertFalse(der.hasOnlyInts(False))
+
+ def testDecode9(self):
+ # Verify that decode returns itself
+ der = DerSequence()
+ self.assertEqual(der, der.decode(b('0\x06\x24\x02\xb6\x63\x12\x00')))
+
+ ###
+
+ def testErrDecode1(self):
+ # Not a sequence
+ der = DerSequence()
+ self.assertRaises(ValueError, der.decode, b(''))
+ self.assertRaises(ValueError, der.decode, b('\x00'))
+ self.assertRaises(ValueError, der.decode, b('\x30'))
+
+ def testErrDecode2(self):
+ der = DerSequence()
+ # Too much data
+ self.assertRaises(ValueError, der.decode, b('\x30\x00\x00'))
+
+ def testErrDecode3(self):
+ # Wrong length format
+ der = DerSequence()
+ # Missing length in sub-item
+ self.assertRaises(ValueError, der.decode, b('\x30\x04\x02\x01\x01\x00'))
+ # Valid BER, but invalid DER length
+ self.assertRaises(ValueError, der.decode, b('\x30\x81\x03\x02\x01\x01'))
+ self.assertRaises(ValueError, der.decode, b('\x30\x04\x02\x81\x01\x01'))
+
+ def test_expected_nr_elements(self):
+ der_bin = DerSequence([1, 2, 3]).encode()
+
+ DerSequence().decode(der_bin, nr_elements=3)
+ DerSequence().decode(der_bin, nr_elements=(2,3))
+ self.assertRaises(ValueError, DerSequence().decode, der_bin, nr_elements=1)
+ self.assertRaises(ValueError, DerSequence().decode, der_bin, nr_elements=(4,5))
+
+ def test_expected_only_integers(self):
+
+ der_bin1 = DerSequence([1, 2, 3]).encode()
+ der_bin2 = DerSequence([1, 2, DerSequence([3, 4])]).encode()
+
+ DerSequence().decode(der_bin1, only_ints_expected=True)
+ DerSequence().decode(der_bin1, only_ints_expected=False)
+ DerSequence().decode(der_bin2, only_ints_expected=False)
+ self.assertRaises(ValueError, DerSequence().decode, der_bin2, only_ints_expected=True)
+
+
+class DerOctetStringTests(unittest.TestCase):
+
+ def testInit1(self):
+ der = DerOctetString(b('\xFF'))
+ self.assertEqual(der.encode(), b('\x04\x01\xFF'))
+
+ def testEncode1(self):
+ # Empty sequence
+ der = DerOctetString()
+ self.assertEqual(der.encode(), b('\x04\x00'))
+ # Small payload
+ der.payload = b('\x01\x02')
+ self.assertEqual(der.encode(), b('\x04\x02\x01\x02'))
+
+ ####
+
+ def testDecode1(self):
+ # Empty sequence
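+# Note: for the truncated SHA-512 variants this returns a hash object rather
+# than a module; the Wycheproof loop below only calls its new() method, which
+# is expected to produce fresh hashes with the same truncation.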
+ der = DerOctetString()
+ der.decode(b('\x04\x00'))
+ self.assertEqual(der.payload, b(''))
+ # Small payload
+ der.decode(b('\x04\x02\x01\x02'))
+ self.assertEqual(der.payload, b('\x01\x02'))
+
+ def testDecode2(self):
+ # Verify that decode returns the object
+ der = DerOctetString()
+ self.assertEqual(der, der.decode(b('\x04\x00')))
+
+ def testErrDecode1(self):
+ # No leftovers allowed
+ der = DerOctetString()
+ self.assertRaises(ValueError, der.decode, b('\x04\x01\x01\xff'))
+
+class DerNullTests(unittest.TestCase):
+
+ def testEncode1(self):
+ der = DerNull()
+ self.assertEqual(der.encode(), b('\x05\x00'))
+
+ ####
+
+ def testDecode1(self):
+ # Empty sequence
+ der = DerNull()
+ self.assertEqual(der, der.decode(b('\x05\x00')))
+
+class DerObjectIdTests(unittest.TestCase):
+
+ def testInit1(self):
+ der = DerObjectId("1.1")
+ self.assertEqual(der.encode(), b('\x06\x01)'))
+
+ def testEncode1(self):
+ der = DerObjectId('1.2.840.113549.1.1.1')
+ self.assertEqual(der.encode(), b('\x06\x09\x2A\x86\x48\x86\xF7\x0D\x01\x01\x01'))
+ #
+ der = DerObjectId()
+ der.value = '1.2.840.113549.1.1.1'
+ self.assertEqual(der.encode(), b('\x06\x09\x2A\x86\x48\x86\xF7\x0D\x01\x01\x01'))
+
+ ####
+
+ def testDecode1(self):
+ # Empty sequence
+ der = DerObjectId()
+ der.decode(b('\x06\x09\x2A\x86\x48\x86\xF7\x0D\x01\x01\x01'))
+ self.assertEqual(der.value, '1.2.840.113549.1.1.1')
+
+ def testDecode2(self):
+ # Verify that decode returns the object
+ der = DerObjectId()
+ self.assertEqual(der,
+ der.decode(b('\x06\x09\x2A\x86\x48\x86\xF7\x0D\x01\x01\x01')))
+
+ def testDecode3(self):
+ der = DerObjectId()
+ der.decode(b('\x06\x09\x2A\x86\x48\x86\xF7\x0D\x01\x00\x01'))
+ self.assertEqual(der.value, '1.2.840.113549.1.0.1')
+
+
+class DerBitStringTests(unittest.TestCase):
+
+ def testInit1(self):
+ der = DerBitString(b("\xFF"))
+ self.assertEqual(der.encode(), b('\x03\x02\x00\xFF'))
+
+ def testInit2(self):
+ der = DerBitString(DerInteger(1))
+ self.assertEqual(der.encode(), b('\x03\x04\x00\x02\x01\x01'))
+
+ def testEncode1(self):
+ # Empty sequence
+ der = DerBitString()
+ self.assertEqual(der.encode(), b('\x03\x01\x00'))
+ # Small payload
+ der = DerBitString(b('\x01\x02'))
+ self.assertEqual(der.encode(), b('\x03\x03\x00\x01\x02'))
+ # Small payload
+ der = DerBitString()
+ der.value = b('\x01\x02')
+ self.assertEqual(der.encode(), b('\x03\x03\x00\x01\x02'))
+
+ ####
+
+ def testDecode1(self):
+ # Empty sequence
+ der = DerBitString()
+ der.decode(b('\x03\x00'))
+ self.assertEqual(der.value, b(''))
+ # Small payload
+ der.decode(b('\x03\x03\x00\x01\x02'))
+ self.assertEqual(der.value, b('\x01\x02'))
+
+ def testDecode2(self):
+ # Verify that decode returns the object
+ der = DerBitString()
+ self.assertEqual(der, der.decode(b('\x03\x00')))
+
+
+class DerSetOfTests(unittest.TestCase):
+
+ def testInit1(self):
+ der = DerSetOf([DerInteger(1), DerInteger(2)])
+ self.assertEqual(der.encode(), b('1\x06\x02\x01\x01\x02\x01\x02'))
+
+ def testEncode1(self):
+ # Empty set
+ der = DerSetOf()
+ self.assertEqual(der.encode(), b('1\x00'))
+ # One single-byte integer (zero)
+ der.add(0)
+ self.assertEqual(der.encode(), b('1\x03\x02\x01\x00'))
+ # Invariant
+ self.assertEqual(der.encode(), b('1\x03\x02\x01\x00'))
+
+ def testEncode2(self):
+ # Two integers
+ der = DerSetOf()
+ der.add(0x180)
+ der.add(0xFF)
+ self.assertEqual(der.encode(), b('1\x08\x02\x02\x00\xff\x02\x02\x01\x80'))
+ # Initialize with integers
+ der = DerSetOf([0x180, 0xFF])
+ self.assertEqual(der.encode(), b('1\x08\x02\x02\x00\xff\x02\x02\x01\x80'))
+
+ def testEncode3(self):
+ # One integer and another type (no matter what it is)
+ der = DerSetOf()
+ der.add(0x180)
+ self.assertRaises(ValueError, der.add, b('\x00\x02\x00\x00'))
+
+ def testEncode4(self):
+ # Only non integers
+ der = DerSetOf()
+ der.add(b('\x01\x00'))
+ der.add(b('\x01\x01\x01'))
+ self.assertEqual(der.encode(), b('1\x05\x01\x00\x01\x01\x01'))
+
+ ####
+
+ def testDecode1(self):
+ # Empty sequence
+ der = DerSetOf()
+ der.decode(b('1\x00'))
+ self.assertEqual(len(der),0)
+ # One single-byte integer (zero)
+ der.decode(b('1\x03\x02\x01\x00'))
+ self.assertEqual(len(der),1)
+ self.assertEqual(list(der),[0])
+
+ def testDecode2(self):
+ # Two integers
+ der = DerSetOf()
+ der.decode(b('1\x08\x02\x02\x01\x80\x02\x02\x00\xff'))
+ self.assertEqual(len(der),2)
+ l = list(der)
+ self.assertTrue(0x180 in l)
+ self.assertTrue(0xFF in l)
+
+ def testDecode3(self):
+ # One integer and 2 other types
+ der = DerSetOf()
+ self.assertRaises(ValueError, der.decode,
+ b('0\x0A\x02\x02\x01\x80\x24\x02\xb6\x63\x12\x00'))
+
+ def testDecode4(self):
+ # Verify that decode returns the object
+ der = DerSetOf()
+ self.assertEqual(der,
+ der.decode(b('1\x08\x02\x02\x01\x80\x02\x02\x00\xff')))
+
+ ###
+
+ def testErrDecode1(self):
+ # No leftovers allowed
+ der = DerSetOf()
+ self.assertRaises(ValueError, der.decode,
+ b('1\x08\x02\x02\x01\x80\x02\x02\x00\xff\xAA'))
+
+def get_tests(config={}):
+ from Crypto.SelfTest.st_common import list_test_cases
+ listTests = []
+ listTests += list_test_cases(DerObjectTests)
+ listTests += list_test_cases(DerIntegerTests)
+ listTests += list_test_cases(DerSequenceTests)
+ listTests += list_test_cases(DerOctetStringTests)
+ listTests += list_test_cases(DerNullTests)
+ listTests += list_test_cases(DerObjectIdTests)
+ listTests += list_test_cases(DerBitStringTests)
+ listTests += list_test_cases(DerSetOfTests)
+ return listTests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
diff --git a/lib/Crypto/SelfTest/Util/test_number.py b/lib/Crypto/SelfTest/Util/test_number.py
new file mode 100644
index 0000000..bb143f3
--- /dev/null
+++ b/lib/Crypto/SelfTest/Util/test_number.py
@@ -0,0 +1,192 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/Util/test_number.py: Self-test for parts of the Crypto.Util.number module
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self-tests for (some of) Crypto.Util.number"""
+
+import math
+import unittest
+
+from Crypto.Util.py3compat import *
+from Crypto.SelfTest.st_common import list_test_cases
+
+from Crypto.Util import number
+from Crypto.Util.number import long_to_bytes
+
+
+class MyError(Exception):
+ """Dummy exception used for tests"""
+
+# NB: In some places, we compare tuples instead of just output values so that
+# if any inputs cause a test failure, we'll be able to tell which ones.
+
+class MiscTests(unittest.TestCase):
+
+ def test_ceil_div(self):
+ """Util.number.ceil_div"""
+ self.assertRaises(TypeError, number.ceil_div, "1", 1)
+ self.assertRaises(ZeroDivisionError, number.ceil_div, 1, 0)
+ self.assertRaises(ZeroDivisionError, number.ceil_div, -1, 0)
+
+ # b = 1
+ self.assertEqual(0, number.ceil_div(0, 1))
+ self.assertEqual(1, number.ceil_div(1, 1))
+ self.assertEqual(2, number.ceil_div(2, 1))
+ self.assertEqual(3, number.ceil_div(3, 1))
+
+ # b = 2
+ self.assertEqual(0, number.ceil_div(0, 2))
+ self.assertEqual(1, number.ceil_div(1, 2))
+ self.assertEqual(1, number.ceil_div(2, 2))
+ self.assertEqual(2, number.ceil_div(3, 2))
+ self.assertEqual(2, number.ceil_div(4, 2))
+ self.assertEqual(3, number.ceil_div(5, 2))
+
+ # b = 3
+ self.assertEqual(0, number.ceil_div(0, 3))
+ self.assertEqual(1, number.ceil_div(1, 3))
+ self.assertEqual(1, number.ceil_div(2, 3))
+ self.assertEqual(1, number.ceil_div(3, 3))
+ self.assertEqual(2, number.ceil_div(4, 3))
+ self.assertEqual(2, number.ceil_div(5, 3))
+ self.assertEqual(2, number.ceil_div(6, 3))
+ self.assertEqual(3, number.ceil_div(7, 3))
+
+ # b = 4
+ self.assertEqual(0, number.ceil_div(0, 4))
+ self.assertEqual(1, number.ceil_div(1, 4))
+ self.assertEqual(1, number.ceil_div(2, 4))
+ self.assertEqual(1, number.ceil_div(3, 4))
+ self.assertEqual(1, number.ceil_div(4, 4))
+ self.assertEqual(2, number.ceil_div(5, 4))
+ self.assertEqual(2, number.ceil_div(6, 4))
+ self.assertEqual(2, number.ceil_div(7, 4))
+ self.assertEqual(2, number.ceil_div(8, 4))
+ self.assertEqual(3, number.ceil_div(9, 4))
+
+ def test_getPrime(self):
+ """Util.number.getPrime"""
+ self.assertRaises(ValueError, number.getPrime, -100)
+ self.assertRaises(ValueError, number.getPrime, 0)
+ self.assertRaises(ValueError, number.getPrime, 1)
+
+ bits = 4
+ for i in range(100):
+ x = number.getPrime(bits)
+ self.assertEqual(x >= (1 << bits - 1), 1)
+ self.assertEqual(x < (1 << bits), 1)
+
+ bits = 512
+ x = number.getPrime(bits)
+ self.assertNotEqual(x % 2, 0)
+ self.assertEqual(x >= (1 << bits - 1), 1)
+ self.assertEqual(x < (1 << bits), 1)
+
+ def test_getStrongPrime(self):
+ """Util.number.getStrongPrime"""
+ self.assertRaises(ValueError, number.getStrongPrime, 256)
+ self.assertRaises(ValueError, number.getStrongPrime, 513)
+ bits = 512
+ x = number.getStrongPrime(bits)
+ self.assertNotEqual(x % 2, 0)
+ self.assertEqual(x > (1 << bits-1)-1, 1)
+ self.assertEqual(x < (1 << bits), 1)
+ e = 2**16+1
+ x = number.getStrongPrime(bits, e)
+ self.assertEqual(number.GCD(x-1, e), 1)
+ self.assertNotEqual(x % 2, 0)
+ self.assertEqual(x > (1 << bits-1)-1, 1)
+ self.assertEqual(x < (1 << bits), 1)
+ e = 2**16+2
+ x = number.getStrongPrime(bits, e)
+ self.assertEqual(number.GCD((x-1)>>1, e), 1)
+ self.assertNotEqual(x % 2, 0)
+ self.assertEqual(x > (1 << bits-1)-1, 1)
+ self.assertEqual(x < (1 << bits), 1)
+
+ def test_isPrime(self):
+ """Util.number.isPrime"""
+ self.assertEqual(number.isPrime(-3), False) # Regression test: negative numbers should not be prime
+ self.assertEqual(number.isPrime(-2), False) # Regression test: negative numbers should not be prime
+ self.assertEqual(number.isPrime(1), False) # Regression test: isPrime(1) caused some versions of PyCrypto to crash.
+ self.assertEqual(number.isPrime(2), True)
+ self.assertEqual(number.isPrime(3), True)
+ self.assertEqual(number.isPrime(4), False)
+ self.assertEqual(number.isPrime(2**1279-1), True)
+ self.assertEqual(number.isPrime(-(2**1279-1)), False) # Regression test: negative numbers should not be prime
+ # test some known gmp pseudo-primes taken from
+ # http://www.trnicely.net/misc/mpzspsp.html
+ for composite in (43 * 127 * 211, 61 * 151 * 211, 15259 * 30517,
+ 346141 * 692281, 1007119 * 2014237, 3589477 * 7178953,
+ 4859419 * 9718837, 2730439 * 5460877,
+ 245127919 * 490255837, 963939391 * 1927878781,
+ 4186358431 * 8372716861, 1576820467 * 3153640933):
+ self.assertEqual(number.isPrime(int(composite)), False)
+
+ def test_size(self):
+ self.assertEqual(number.size(2),2)
+ self.assertEqual(number.size(3),2)
+ self.assertEqual(number.size(0xa2),8)
+ self.assertEqual(number.size(0xa2ba40),8*3)
+ self.assertEqual(number.size(0xa2ba40ee07e3b2bd2f02ce227f36a195024486e49c19cb41bbbdfbba98b22b0e577c2eeaffa20d883a76e65e394c69d4b3c05a1e8fadda27edb2a42bc000fe888b9b32c22d15add0cd76b3e7936e19955b220dd17d4ea904b1ec102b2e4de7751222aa99151024c7cb41cc5ea21d00eeb41f7c800834d2c6e06bce3bce7ea9a5), 1024)
+ self.assertRaises(ValueError, number.size, -1)
+
+
+class LongTests(unittest.TestCase):
+
+ def test1(self):
+ self.assertEqual(long_to_bytes(0), b'\x00')
+ self.assertEqual(long_to_bytes(1), b'\x01')
+ self.assertEqual(long_to_bytes(0x100), b'\x01\x00')
+ self.assertEqual(long_to_bytes(0xFF00000000), b'\xFF\x00\x00\x00\x00')
+ self.assertEqual(long_to_bytes(0xFF00000000), b'\xFF\x00\x00\x00\x00')
+ self.assertEqual(long_to_bytes(0x1122334455667788), b'\x11\x22\x33\x44\x55\x66\x77\x88')
+ self.assertEqual(long_to_bytes(0x112233445566778899), b'\x11\x22\x33\x44\x55\x66\x77\x88\x99')
+
+ def test2(self):
+ self.assertEqual(long_to_bytes(0, 1), b'\x00')
+ self.assertEqual(long_to_bytes(0, 2), b'\x00\x00')
+ self.assertEqual(long_to_bytes(1, 3), b'\x00\x00\x01')
+ self.assertEqual(long_to_bytes(65535, 2), b'\xFF\xFF')
+ self.assertEqual(long_to_bytes(65536, 2), b'\x00\x01\x00\x00')
+ self.assertEqual(long_to_bytes(0x100, 1), b'\x01\x00')
+ self.assertEqual(long_to_bytes(0xFF00000001, 6), b'\x00\xFF\x00\x00\x00\x01')
+ self.assertEqual(long_to_bytes(0xFF00000001, 8), b'\x00\x00\x00\xFF\x00\x00\x00\x01')
+ self.assertEqual(long_to_bytes(0xFF00000001, 10), b'\x00\x00\x00\x00\x00\xFF\x00\x00\x00\x01')
+ self.assertEqual(long_to_bytes(0xFF00000001, 11), b'\x00\x00\x00\x00\x00\x00\xFF\x00\x00\x00\x01')
+
+ def test_err1(self):
+ self.assertRaises(ValueError, long_to_bytes, -1)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(MiscTests)
+ tests += list_test_cases(LongTests)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
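The LongTests above pin down the padding rule of long_to_bytes; a hedged summary sketch (bytes_to_long is assumed to be the inverse helper from the same module):

    from Crypto.Util.number import long_to_bytes, bytes_to_long

    assert long_to_bytes(1, 3) == b'\x00\x00\x01'            # left-padded up to the blocksize
    assert long_to_bytes(65536, 2) == b'\x00\x01\x00\x00'    # rounded up to a multiple of the blocksize
    assert bytes_to_long(long_to_bytes(0x1122334455)) == 0x1122334455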
diff --git a/lib/Crypto/SelfTest/Util/test_rfc1751.py b/lib/Crypto/SelfTest/Util/test_rfc1751.py
new file mode 100644
index 0000000..af0aa2b
--- /dev/null
+++ b/lib/Crypto/SelfTest/Util/test_rfc1751.py
@@ -0,0 +1,38 @@
+import unittest
+
+import binascii
+from Crypto.Util.RFC1751 import key_to_english, english_to_key
+
+
+class RFC1751_Tests(unittest.TestCase):
+
+ def test1(self):
+ data = [
+ ('EB33F77EE73D4053', 'TIDE ITCH SLOW REIN RULE MOT'),
+ ('CCAC2AED591056BE4F90FD441C534766', 'RASH BUSH MILK LOOK BAD BRIM AVID GAFF BAIT ROT POD LOVE'),
+ ('EFF81F9BFBC65350920CDD7416DE8009', 'TROD MUTE TAIL WARM CHAR KONG HAAG CITY BORE O TEAL AWL')
+ ]
+
+ for key_hex, words in data:
+ key_bin = binascii.a2b_hex(key_hex)
+
+ w2 = key_to_english(key_bin)
+ self.assertEqual(w2, words)
+
+ k2 = english_to_key(words)
+ self.assertEqual(k2, key_bin)
+
+ def test_error_key_to_english(self):
+
+ self.assertRaises(ValueError, key_to_english, b'0' * 7)
+
+
+def get_tests(config={}):
+ from Crypto.SelfTest.st_common import list_test_cases
+ tests = list_test_cases(RFC1751_Tests)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
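A hedged usage sketch for the RFC1751 helpers, reusing the first vector from test1 above:

    from binascii import unhexlify
    from Crypto.Util.RFC1751 import key_to_english, english_to_key

    key = unhexlify('EB33F77EE73D4053')
    words = key_to_english(key)             # 'TIDE ITCH SLOW REIN RULE MOT'
    assert english_to_key(words) == key     # the mapping is reversible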
diff --git a/lib/Crypto/SelfTest/Util/test_strxor.py b/lib/Crypto/SelfTest/Util/test_strxor.py
new file mode 100644
index 0000000..c91d38f
--- /dev/null
+++ b/lib/Crypto/SelfTest/Util/test_strxor.py
@@ -0,0 +1,280 @@
+#
+# SelfTest/Util/test_strxor.py: Self-test for XORing
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import unittest
+from binascii import unhexlify, hexlify
+
+from Crypto.SelfTest.st_common import list_test_cases
+from Crypto.Util.strxor import strxor, strxor_c
+
+
+class StrxorTests(unittest.TestCase):
+
+ def test1(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term2 = unhexlify(b"383d4ba020573314395b")
+ result = unhexlify(b"c70ed123c59a7fcb6f12")
+ self.assertEqual(strxor(term1, term2), result)
+ self.assertEqual(strxor(term2, term1), result)
+
+ def test2(self):
+ es = b""
+ self.assertEqual(strxor(es, es), es)
+
+ def test3(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ all_zeros = b"\x00" * len(term1)
+ self.assertEqual(strxor(term1, term1), all_zeros)
+
+ def test_wrong_length(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term2 = unhexlify(b"ff339a83e5cd4cdf564990")
+ self.assertRaises(ValueError, strxor, term1, term2)
+
+ def test_bytearray(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term1_ba = bytearray(term1)
+ term2 = unhexlify(b"383d4ba020573314395b")
+ result = unhexlify(b"c70ed123c59a7fcb6f12")
+
+ self.assertEqual(strxor(term1_ba, term2), result)
+
+ def test_memoryview(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term1_mv = memoryview(term1)
+ term2 = unhexlify(b"383d4ba020573314395b")
+ result = unhexlify(b"c70ed123c59a7fcb6f12")
+
+ self.assertEqual(strxor(term1_mv, term2), result)
+
+ def test_output_bytearray(self):
+ """Verify result can be stored in pre-allocated memory"""
+
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term2 = unhexlify(b"383d4ba020573314395b")
+ original_term1 = term1[:]
+ original_term2 = term2[:]
+ expected_xor = unhexlify(b"c70ed123c59a7fcb6f12")
+ output = bytearray(len(term1))
+
+ result = strxor(term1, term2, output=output)
+
+ self.assertEqual(result, None)
+ self.assertEqual(output, expected_xor)
+ self.assertEqual(term1, original_term1)
+ self.assertEqual(term2, original_term2)
+
+ def test_output_memoryview(self):
+ """Verify result can be stored in pre-allocated memory"""
+
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term2 = unhexlify(b"383d4ba020573314395b")
+ original_term1 = term1[:]
+ original_term2 = term2[:]
+ expected_xor = unhexlify(b"c70ed123c59a7fcb6f12")
+ output = memoryview(bytearray(len(term1)))
+
+ result = strxor(term1, term2, output=output)
+
+ self.assertEqual(result, None)
+ self.assertEqual(output, expected_xor)
+ self.assertEqual(term1, original_term1)
+ self.assertEqual(term2, original_term2)
+
+ def test_output_overlapping_bytearray(self):
+ """Verify result can be stored in overlapping memory"""
+
+ term1 = bytearray(unhexlify(b"ff339a83e5cd4cdf5649"))
+ term2 = unhexlify(b"383d4ba020573314395b")
+ original_term2 = term2[:]
+ expected_xor = unhexlify(b"c70ed123c59a7fcb6f12")
+
+ result = strxor(term1, term2, output=term1)
+
+ self.assertEqual(result, None)
+ self.assertEqual(term1, expected_xor)
+ self.assertEqual(term2, original_term2)
+
+ def test_output_overlapping_memoryview(self):
+ """Verify result can be stored in overlapping memory"""
+
+ term1 = memoryview(bytearray(unhexlify(b"ff339a83e5cd4cdf5649")))
+ term2 = unhexlify(b"383d4ba020573314395b")
+ original_term2 = term2[:]
+ expected_xor = unhexlify(b"c70ed123c59a7fcb6f12")
+
+ result = strxor(term1, term2, output=term1)
+
+ self.assertEqual(result, None)
+ self.assertEqual(term1, expected_xor)
+ self.assertEqual(term2, original_term2)
+
+ def test_output_ro_bytes(self):
+ """Verify result cannot be stored in read-only memory"""
+
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term2 = unhexlify(b"383d4ba020573314395b")
+
+ self.assertRaises(TypeError, strxor, term1, term2, output=term1)
+
+ def test_output_ro_memoryview(self):
+ """Verify result cannot be stored in read-only memory"""
+
+ term1 = memoryview(unhexlify(b"ff339a83e5cd4cdf5649"))
+ term2 = unhexlify(b"383d4ba020573314395b")
+
+ self.assertRaises(TypeError, strxor, term1, term2, output=term1)
+
+ def test_output_incorrect_length(self):
+ """Verify result cannot be stored in memory of incorrect length"""
+
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term2 = unhexlify(b"383d4ba020573314395b")
+ output = bytearray(len(term1) - 1)
+
+ self.assertRaises(ValueError, strxor, term1, term2, output=output)
+
+
+class Strxor_cTests(unittest.TestCase):
+
+ def test1(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ result = unhexlify(b"be72dbc2a48c0d9e1708")
+ self.assertEqual(strxor_c(term1, 65), result)
+
+ def test2(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ self.assertEqual(strxor_c(term1, 0), term1)
+
+ def test3(self):
+ self.assertEqual(strxor_c(b"", 90), b"")
+
+ def test_wrong_range(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ self.assertRaises(ValueError, strxor_c, term1, -1)
+ self.assertRaises(ValueError, strxor_c, term1, 256)
+
+ def test_bytearray(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term1_ba = bytearray(term1)
+ result = unhexlify(b"be72dbc2a48c0d9e1708")
+
+ self.assertEqual(strxor_c(term1_ba, 65), result)
+
+ def test_memoryview(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ term1_mv = memoryview(term1)
+ result = unhexlify(b"be72dbc2a48c0d9e1708")
+
+ self.assertEqual(strxor_c(term1_mv, 65), result)
+
+ def test_output_bytearray(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ original_term1 = term1[:]
+ expected_result = unhexlify(b"be72dbc2a48c0d9e1708")
+ output = bytearray(len(term1))
+
+ result = strxor_c(term1, 65, output=output)
+
+ self.assertEqual(result, None)
+ self.assertEqual(output, expected_result)
+ self.assertEqual(term1, original_term1)
+
+ def test_output_memoryview(self):
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ original_term1 = term1[:]
+ expected_result = unhexlify(b"be72dbc2a48c0d9e1708")
+ output = memoryview(bytearray(len(term1)))
+
+ result = strxor_c(term1, 65, output=output)
+
+ self.assertEqual(result, None)
+ self.assertEqual(output, expected_result)
+ self.assertEqual(term1, original_term1)
+
+ def test_output_overlapping_bytearray(self):
+ """Verify result can be stored in overlapping memory"""
+
+ term1 = bytearray(unhexlify(b"ff339a83e5cd4cdf5649"))
+ expected_xor = unhexlify(b"be72dbc2a48c0d9e1708")
+
+ result = strxor_c(term1, 65, output=term1)
+
+ self.assertEqual(result, None)
+ self.assertEqual(term1, expected_xor)
+
+ def test_output_overlapping_memoryview(self):
+ """Verify result can be stored in overlapping memory"""
+
+ term1 = memoryview(bytearray(unhexlify(b"ff339a83e5cd4cdf5649")))
+ expected_xor = unhexlify(b"be72dbc2a48c0d9e1708")
+
+ result = strxor_c(term1, 65, output=term1)
+
+ self.assertEqual(result, None)
+ self.assertEqual(term1, expected_xor)
+
+ def test_output_ro_bytes(self):
+ """Verify result cannot be stored in read-only memory"""
+
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+
+ self.assertRaises(TypeError, strxor_c, term1, 65, output=term1)
+
+ def test_output_ro_memoryview(self):
+ """Verify result cannot be stored in read-only memory"""
+
+ term1 = memoryview(unhexlify(b"ff339a83e5cd4cdf5649"))
+ term2 = unhexlify(b"383d4ba020573314395b")
+
+ self.assertRaises(TypeError, strxor_c, term1, 65, output=term1)
+
+ def test_output_incorrect_length(self):
+ """Verify result cannot be stored in memory of incorrect length"""
+
+ term1 = unhexlify(b"ff339a83e5cd4cdf5649")
+ output = bytearray(len(term1) - 1)
+
+ self.assertRaises(ValueError, strxor_c, term1, 65, output=output)
+
+
+def get_tests(config={}):
+ tests = []
+ tests += list_test_cases(StrxorTests)
+ tests += list_test_cases(Strxor_cTests)
+ return tests
+
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
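The tests above also exercise the optional output buffer; a minimal hedged sketch of both call styles (vectors reused from test1):

    from Crypto.Util.strxor import strxor, strxor_c

    term1 = bytes.fromhex("ff339a83e5cd4cdf5649")
    term2 = bytes.fromhex("383d4ba020573314395b")
    assert strxor(term1, term2).hex() == "c70ed123c59a7fcb6f12"

    out = bytearray(len(term1))
    strxor(term1, term2, output=out)            # the in-place variant returns None
    assert bytes(out) == strxor(term1, term2)
    assert strxor_c(term1, 0) == term1          # XOR with the constant 0 is the identity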
diff --git a/lib/Crypto/SelfTest/__init__.py b/lib/Crypto/SelfTest/__init__.py
new file mode 100644
index 0000000..12b7592
--- /dev/null
+++ b/lib/Crypto/SelfTest/__init__.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/__init__.py: Self-test for PyCrypto
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Self tests
+
+These tests should perform quickly and can ideally be used every time an
+application runs.
+"""
+
+__revision__ = "$Id$"
+
+import sys
+import unittest
+from Crypto.Util.py3compat import StringIO
+
+class SelfTestError(Exception):
+ def __init__(self, message, result):
+ Exception.__init__(self, message, result)
+ self.message = message
+ self.result = result
+
+def run(module=None, verbosity=0, stream=None, tests=None, config=None, **kwargs):
+ """Execute self-tests.
+
+ This raises SelfTestError if any test is unsuccessful.
+
+ You may optionally pass in a sub-module of SelfTest if you only want to
+ perform some of the tests. For example, the following would test only the
+ hash modules:
+
+ Crypto.SelfTest.run(Crypto.SelfTest.Hash)
+
+ """
+
+ if config is None:
+ config = {}
+ suite = unittest.TestSuite()
+ if module is None:
+ if tests is None:
+ tests = get_tests(config=config)
+ suite.addTests(tests)
+ else:
+ if tests is None:
+ suite.addTests(module.get_tests(config=config))
+ else:
+ raise ValueError("'module' and 'tests' arguments are mutually exclusive")
+ if stream is None:
+ kwargs['stream'] = StringIO()
+ else:
+ kwargs['stream'] = stream
+ runner = unittest.TextTestRunner(verbosity=verbosity, **kwargs)
+ result = runner.run(suite)
+ if not result.wasSuccessful():
+ if stream is None:
+ sys.stderr.write(kwargs['stream'].getvalue())
+ raise SelfTestError("Self-test failed", result)
+ return result
+
+def get_tests(config={}):
+ tests = []
+ from Crypto.SelfTest import Cipher; tests += Cipher.get_tests(config=config)
+ from Crypto.SelfTest import Hash; tests += Hash.get_tests(config=config)
+ from Crypto.SelfTest import Protocol; tests += Protocol.get_tests(config=config)
+ from Crypto.SelfTest import PublicKey; tests += PublicKey.get_tests(config=config)
+ from Crypto.SelfTest import Random; tests += Random.get_tests(config=config)
+ from Crypto.SelfTest import Util; tests += Util.get_tests(config=config)
+ from Crypto.SelfTest import Signature; tests += Signature.get_tests(config=config)
+ from Crypto.SelfTest import IO; tests += IO.get_tests(config=config)
+ from Crypto.SelfTest import Math; tests += Math.get_tests(config=config)
+ return tests
+
+if __name__ == '__main__':
+ suite = lambda: unittest.TestSuite(get_tests())
+ unittest.main(defaultTest='suite')
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
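A hedged sketch of calling the run() helper above programmatically (Hash is one of the sub-packages pulled in by get_tests):

    from Crypto import SelfTest
    from Crypto.SelfTest import Hash

    # Run only the hash self-tests; SelfTestError is raised if any test fails.
    SelfTest.run(Hash, verbosity=1)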
diff --git a/lib/Crypto/SelfTest/__main__.py b/lib/Crypto/SelfTest/__main__.py
new file mode 100644
index 0000000..9ab0912
--- /dev/null
+++ b/lib/Crypto/SelfTest/__main__.py
@@ -0,0 +1,38 @@
+#! /usr/bin/env python
+#
+# __main__.py : Stand-alone loader for PyCryptodome test suite
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from __future__ import print_function
+
+import sys
+
+from Crypto import SelfTest
+
+slow_tests = "--skip-slow-tests" not in sys.argv
+if not slow_tests:
+ print("Skipping slow tests")
+
+wycheproof_warnings = "--wycheproof-warnings" in sys.argv
+if wycheproof_warnings:
+ print("Printing Wycheproof warnings")
+
+config = {'slow_tests' : slow_tests, 'wycheproof_warnings' : wycheproof_warnings }
+SelfTest.run(stream=sys.stdout, verbosity=1, config=config)
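The two flags above are read straight from sys.argv, so a typical stand-alone invocation looks like this (illustrative):

    python -m Crypto.SelfTest --skip-slow-tests --wycheproof-warnings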
diff --git a/lib/Crypto/SelfTest/loader.py b/lib/Crypto/SelfTest/loader.py
new file mode 100644
index 0000000..18be270
--- /dev/null
+++ b/lib/Crypto/SelfTest/loader.py
@@ -0,0 +1,206 @@
+# ===================================================================
+#
+# Copyright (c) 2016, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import os
+import re
+import json
+import errno
+import binascii
+import warnings
+from binascii import unhexlify
+from Crypto.Util.py3compat import FileNotFoundError
+
+
+try:
+ import pycryptodome_test_vectors # type: ignore
+ test_vectors_available = True
+except ImportError:
+ test_vectors_available = False
+
+
+def _load_tests(dir_comps, file_in, description, conversions):
+ """Load and parse a test vector file
+
+ Return a list of objects: one per group of adjacent
+ key=value (KV) lines, and one per standalone line of the form "[.*]".
+
+ For a group of lines, the object has one attribute per line.
+ """
+
+ line_number = 0
+ results = []
+
+ class TestVector(object):
+ def __init__(self, description, count):
+ self.desc = description
+ self.count = count
+ self.others = []
+
+ test_vector = None
+ count = 0
+ new_group = True
+
+ while True:
+ line_number += 1
+ line = file_in.readline()
+ if not line:
+ if test_vector is not None:
+ results.append(test_vector)
+ break
+ line = line.strip()
+
+ # Skip comments and empty lines
+ if line.startswith('#') or not line:
+ new_group = True
+ continue
+
+ if line.startswith("["):
+ if test_vector is not None:
+ results.append(test_vector)
+ test_vector = None
+ results.append(line)
+ continue
+
+ if new_group:
+ count += 1
+ new_group = False
+ if test_vector is not None:
+ results.append(test_vector)
+ test_vector = TestVector("%s (#%d)" % (description, count), count)
+
+ res = re.match("([A-Za-z0-9]+) = ?(.*)", line)
+ if not res:
+ test_vector.others += [line]
+ else:
+ token = res.group(1).lower()
+ data = res.group(2).lower()
+
+ conversion = conversions.get(token, None)
+ if conversion is None:
+ if len(data) % 2 != 0:
+ data = "0" + data
+ setattr(test_vector, token, binascii.unhexlify(data))
+ else:
+ setattr(test_vector, token, conversion(data))
+
+ # This line is ignored
+ return results
+
+
+def load_test_vectors(dir_comps, file_name, description, conversions):
+ """Load and parse a test vector file
+
+ This function returns a list of objects: one per group of adjacent
+ key=value (KV) lines, and one per standalone line of the form "[.*]".
+
+ For a group of lines, the object has one attribute per line.
+ """
+
+ results = None
+
+ try:
+ if not test_vectors_available:
+ raise FileNotFoundError(errno.ENOENT,
+ os.strerror(errno.ENOENT),
+ file_name)
+
+ description = "%s test (%s)" % (description, file_name)
+
+ init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
+ full_file_name = os.path.join(os.path.join(init_dir, *dir_comps), file_name)
+ with open(full_file_name) as file_in:
+ results = _load_tests(dir_comps, file_in, description, conversions)
+
+ except FileNotFoundError:
+ warnings.warn("Warning: skipping extended tests for " + description,
+ UserWarning,
+ stacklevel=2)
+
+ return results
+
+
+def load_test_vectors_wycheproof(dir_comps, file_name, description,
+ root_tag={}, group_tag={}, unit_tag={}):
+
+ result = []
+ try:
+ if not test_vectors_available:
+ raise FileNotFoundError(errno.ENOENT,
+ os.strerror(errno.ENOENT),
+ file_name)
+
+ init_dir = os.path.dirname(pycryptodome_test_vectors.__file__)
+ full_file_name = os.path.join(os.path.join(init_dir, *dir_comps), file_name)
+ with open(full_file_name) as file_in:
+ tv_tree = json.load(file_in)
+
+ except FileNotFoundError:
+ warnings.warn("Warning: skipping extended tests for " + description,
+ UserWarning,
+ stacklevel=2)
+ return result
+
+ class TestVector(object):
+ pass
+
+ common_root = {}
+ for k, v in root_tag.items():
+ common_root[k] = v(tv_tree)
+
+ for group in tv_tree['testGroups']:
+
+ common_group = {}
+ for k, v in group_tag.items():
+ common_group[k] = v(group)
+
+ for test in group['tests']:
+ tv = TestVector()
+
+ for k, v in common_root.items():
+ setattr(tv, k, v)
+ for k, v in common_group.items():
+ setattr(tv, k, v)
+
+ tv.id = test['tcId']
+ tv.comment = test['comment']
+ for attr in 'key', 'iv', 'aad', 'msg', 'ct', 'tag', 'label', 'ikm', 'salt', 'info', 'okm', 'sig':
+ if attr in test:
+ setattr(tv, attr, unhexlify(test[attr]))
+ tv.filename = file_name
+
+ for k, v in unit_tag.items():
+ setattr(tv, k, v(test))
+
+ tv.valid = test['result'] != "invalid"
+ tv.warning = test['result'] == "acceptable"
+ result.append(tv)
+
+ return result
+
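A hedged sketch of how load_test_vectors above is typically called from a test module; the directory components, file name and conversion are illustrative, not taken from this patch:

    from Crypto.SelfTest.loader import load_test_vectors

    tvs = load_test_vectors(("Cipher", "AES"), "CBCGFSbox128.rsp",
                            "AES CBC known-answer test",
                            {"count": lambda x: int(x)}) or []   # None when the vector package is missing

    for tv in tvs:
        if isinstance(tv, str):       # "[...]" section markers are passed through as plain strings
            continue
        key = tv.key                  # unconverted fields arrive as unhexlified bytes
        count = tv.count              # converted with the lambda supplied above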
diff --git a/lib/Crypto/SelfTest/st_common.py b/lib/Crypto/SelfTest/st_common.py
new file mode 100644
index 0000000..e098d81
--- /dev/null
+++ b/lib/Crypto/SelfTest/st_common.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+#
+# SelfTest/st_common.py: Common functions for SelfTest modules
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Common functions for SelfTest modules"""
+
+import unittest
+import binascii
+from Crypto.Util.py3compat import b
+
+
+def list_test_cases(class_):
+ """Return a list of TestCase instances given a TestCase class
+
+ This is useful when you have defined test* methods on your TestCase class.
+ """
+ return unittest.TestLoader().loadTestsFromTestCase(class_)
+
+def strip_whitespace(s):
+ """Remove whitespace from a text or byte string"""
+ if isinstance(s,str):
+ return b("".join(s.split()))
+ else:
+ return b("").join(s.split())
+
+def a2b_hex(s):
+ """Convert hexadecimal to binary, ignoring whitespace"""
+ return binascii.a2b_hex(strip_whitespace(s))
+
+def b2a_hex(s):
+ """Convert binary to hexadecimal"""
+ # For completeness
+ return binascii.b2a_hex(s)
+
+# vim:set ts=4 sw=4 sts=4 expandtab:
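A small hedged sketch of the helpers above:

    import unittest
    from Crypto.SelfTest.st_common import list_test_cases, a2b_hex

    assert a2b_hex("de ad\nbe ef") == b'\xde\xad\xbe\xef'    # whitespace is stripped before unhexlifying

    class _Demo(unittest.TestCase):
        def test_trivial(self):
            self.assertTrue(True)

    suite = unittest.TestSuite(list_test_cases(_Demo))       # one TestCase instance per test_* method
    unittest.TextTestRunner(verbosity=0).run(suite)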
diff --git a/lib/Crypto/Signature/DSS.py b/lib/Crypto/Signature/DSS.py
new file mode 100644
index 0000000..fa84817
--- /dev/null
+++ b/lib/Crypto/Signature/DSS.py
@@ -0,0 +1,403 @@
+#
+# Signature/DSS.py : DSS.py
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.asn1 import DerSequence
+from Crypto.Util.number import long_to_bytes
+from Crypto.Math.Numbers import Integer
+
+from Crypto.Hash import HMAC
+from Crypto.PublicKey.ECC import EccKey
+from Crypto.PublicKey.DSA import DsaKey
+
+__all__ = ['DssSigScheme', 'new']
+
+
+class DssSigScheme(object):
+ """A (EC)DSA signature object.
+ Do not instantiate directly.
+ Use :func:`Crypto.Signature.DSS.new`.
+ """
+
+ def __init__(self, key, encoding, order):
+ """Create a new Digital Signature Standard (DSS) object.
+
+ Do not instantiate this object directly,
+ use `Crypto.Signature.DSS.new` instead.
+ """
+
+ self._key = key
+ self._encoding = encoding
+ self._order = order
+
+ self._order_bits = self._order.size_in_bits()
+ self._order_bytes = (self._order_bits - 1) // 8 + 1
+
+ def can_sign(self):
+ """Return ``True`` if this signature object can be used
+ for signing messages."""
+
+ return self._key.has_private()
+
+ def _compute_nonce(self, msg_hash):
+ raise NotImplementedError("To be provided by subclasses")
+
+ def _valid_hash(self, msg_hash):
+ raise NotImplementedError("To be provided by subclasses")
+
+ def sign(self, msg_hash):
+ """Compute the DSA/ECDSA signature of a message.
+
+ Args:
+ msg_hash (hash object):
+ The hash that was carried out over the message.
+ The object belongs to the :mod:`Crypto.Hash` package.
+ Under mode ``'fips-186-3'``, the hash must be a FIPS
+ approved secure hash (SHA-2 or SHA-3).
+
+ :return: The signature as ``bytes``
+ :raise ValueError: if the hash algorithm is incompatible with the (EC)DSA key
+ :raise TypeError: if the (EC)DSA key has no private half
+ """
+
+ if not self._key.has_private():
+ raise TypeError("Private key is needed to sign")
+
+ if not self._valid_hash(msg_hash):
+ raise ValueError("Hash is not sufficiently strong")
+
+ # Generate the nonce k (critical!)
+ nonce = self._compute_nonce(msg_hash)
+
+ # Perform signature using the raw API
+ z = Integer.from_bytes(msg_hash.digest()[:self._order_bytes])
+ sig_pair = self._key._sign(z, nonce)
+
+ # Encode the signature into a single byte string
+ if self._encoding == 'binary':
+ output = b"".join([long_to_bytes(x, self._order_bytes)
+ for x in sig_pair])
+ else:
+ # Dss-sig ::= SEQUENCE {
+ # r INTEGER,
+ # s INTEGER
+ # }
+ # Ecdsa-Sig-Value ::= SEQUENCE {
+ # r INTEGER,
+ # s INTEGER
+ # }
+ output = DerSequence(sig_pair).encode()
+
+ return output
+
+ def verify(self, msg_hash, signature):
+ """Check if a certain (EC)DSA signature is authentic.
+
+ Args:
+ msg_hash (hash object):
+ The hash that was carried out over the message.
+ This is an object belonging to the :mod:`Crypto.Hash` module.
+ Under mode ``'fips-186-3'``, the hash must be a FIPS
+ approved secure hash (SHA-2 or SHA-3).
+
+ signature (``bytes``):
+ The signature that needs to be validated.
+
+ :raise ValueError: if the signature is not authentic
+ """
+
+ if not self._valid_hash(msg_hash):
+ raise ValueError("Hash is not sufficiently strong")
+
+ if self._encoding == 'binary':
+ if len(signature) != (2 * self._order_bytes):
+ raise ValueError("The signature is not authentic (length)")
+ r_prime, s_prime = [Integer.from_bytes(x)
+ for x in (signature[:self._order_bytes],
+ signature[self._order_bytes:])]
+ else:
+ try:
+ der_seq = DerSequence().decode(signature, strict=True)
+ except (ValueError, IndexError):
+ raise ValueError("The signature is not authentic (DER)")
+ if len(der_seq) != 2 or not der_seq.hasOnlyInts():
+ raise ValueError("The signature is not authentic (DER content)")
+ r_prime, s_prime = Integer(der_seq[0]), Integer(der_seq[1])
+
+ if not (0 < r_prime < self._order) or not (0 < s_prime < self._order):
+ raise ValueError("The signature is not authentic (d)")
+
+ z = Integer.from_bytes(msg_hash.digest()[:self._order_bytes])
+ result = self._key._verify(z, (r_prime, s_prime))
+ if not result:
+ raise ValueError("The signature is not authentic")
+ # Make legacy PyCrypto code (which checks the boolean return value) fail
+ return False
+
+
+class DeterministicDsaSigScheme(DssSigScheme):
+ # Also applicable to ECDSA
+
+ def __init__(self, key, encoding, order, private_key):
+ super(DeterministicDsaSigScheme, self).__init__(key, encoding, order)
+ self._private_key = private_key
+
+ def _bits2int(self, bstr):
+ """See 2.3.2 in RFC6979"""
+
+ result = Integer.from_bytes(bstr)
+ q_len = self._order.size_in_bits()
+ b_len = len(bstr) * 8
+ if b_len > q_len:
+ # Only keep leftmost q_len bits
+ result >>= (b_len - q_len)
+ return result
+
+ def _int2octets(self, int_mod_q):
+ """See 2.3.3 in RFC6979"""
+
+ assert 0 < int_mod_q < self._order
+ return long_to_bytes(int_mod_q, self._order_bytes)
+
+ def _bits2octets(self, bstr):
+ """See 2.3.4 in RFC6979"""
+
+ z1 = self._bits2int(bstr)
+ if z1 < self._order:
+ z2 = z1
+ else:
+ z2 = z1 - self._order
+ return self._int2octets(z2)
+
+ def _compute_nonce(self, mhash):
+ """Generate k in a deterministic way"""
+
+ # See section 3.2 in RFC6979.txt
+ # Step a
+ h1 = mhash.digest()
+ # Step b
+ mask_v = b'\x01' * mhash.digest_size
+ # Step c
+ nonce_k = b'\x00' * mhash.digest_size
+
+ for int_oct in (b'\x00', b'\x01'):
+ # Step d/f
+ nonce_k = HMAC.new(nonce_k,
+ mask_v + int_oct +
+ self._int2octets(self._private_key) +
+ self._bits2octets(h1), mhash).digest()
+ # Step e/g
+ mask_v = HMAC.new(nonce_k, mask_v, mhash).digest()
+
+ nonce = -1
+ while not (0 < nonce < self._order):
+ # Step h.C (second part)
+ if nonce != -1:
+ nonce_k = HMAC.new(nonce_k, mask_v + b'\x00',
+ mhash).digest()
+ mask_v = HMAC.new(nonce_k, mask_v, mhash).digest()
+
+ # Step h.A
+ mask_t = b""
+
+ # Step h.B
+ while len(mask_t) < self._order_bytes:
+ mask_v = HMAC.new(nonce_k, mask_v, mhash).digest()
+ mask_t += mask_v
+
+ # Step h.C (first part)
+ nonce = self._bits2int(mask_t)
+ return nonce
+
+ def _valid_hash(self, msg_hash):
+ return True
+
+
+class FipsDsaSigScheme(DssSigScheme):
+
+ #: List of L (bit length of p) and N (bit length of q) combinations
+ #: that are allowed by FIPS 186-3. The security level is provided in
+ #: Table 2 of NIST SP 800-57 Part 1 (rev. 3).
+ _fips_186_3_L_N = (
+ (1024, 160), # 80 bits (SHA-1 or stronger)
+ (2048, 224), # 112 bits (SHA-224 or stronger)
+ (2048, 256), # 128 bits (SHA-256 or stronger)
+ (3072, 256) # 128 bits (SHA-256 or stronger)
+ )
+
+ def __init__(self, key, encoding, order, randfunc):
+ super(FipsDsaSigScheme, self).__init__(key, encoding, order)
+ self._randfunc = randfunc
+
+ L = Integer(key.p).size_in_bits()
+ if (L, self._order_bits) not in self._fips_186_3_L_N:
+ error = ("L/N (%d, %d) is not compliant to FIPS 186-3"
+ % (L, self._order_bits))
+ raise ValueError(error)
+
+ def _compute_nonce(self, msg_hash):
+ # hash is not used
+ return Integer.random_range(min_inclusive=1,
+ max_exclusive=self._order,
+ randfunc=self._randfunc)
+
+ def _valid_hash(self, msg_hash):
+ """Verify that SHA-1, SHA-2 or SHA-3 are used"""
+ return (msg_hash.oid == "1.3.14.3.2.26" or
+ msg_hash.oid.startswith("2.16.840.1.101.3.4.2."))
+
+
+class FipsEcDsaSigScheme(DssSigScheme):
+
+ def __init__(self, key, encoding, order, randfunc):
+ super(FipsEcDsaSigScheme, self).__init__(key, encoding, order)
+ self._randfunc = randfunc
+
+ def _compute_nonce(self, msg_hash):
+ return Integer.random_range(min_inclusive=1,
+ max_exclusive=self._key._curve.order,
+ randfunc=self._randfunc)
+
+ def _valid_hash(self, msg_hash):
+ """Verify that the strength of the hash matches or exceeds
+ the strength of the EC. We fail if the hash is too weak."""
+
+ modulus_bits = self._key.pointQ.size_in_bits()
+
+ # SHS: SHA-2, SHA-3, truncated SHA-512
+ sha224 = ("2.16.840.1.101.3.4.2.4", "2.16.840.1.101.3.4.2.7", "2.16.840.1.101.3.4.2.5")
+ sha256 = ("2.16.840.1.101.3.4.2.1", "2.16.840.1.101.3.4.2.8", "2.16.840.1.101.3.4.2.6")
+ sha384 = ("2.16.840.1.101.3.4.2.2", "2.16.840.1.101.3.4.2.9")
+ sha512 = ("2.16.840.1.101.3.4.2.3", "2.16.840.1.101.3.4.2.10")
+ shs = sha224 + sha256 + sha384 + sha512
+
+ try:
+ result = msg_hash.oid in shs
+ except AttributeError:
+ result = False
+ return result
+
+
+def new(key, mode, encoding='binary', randfunc=None):
+ """Create a signature object :class:`DssSigScheme` that
+ can perform (EC)DSA signature or verification.
+
+ .. note::
+ Refer to `NIST SP 800-57 Part 1 Rev 4`_ (or newer release) for an
+ overview of the recommended key lengths.
+
+ Args:
+ key (:class:`Crypto.PublicKey.DSA` or :class:`Crypto.PublicKey.ECC`):
+ The key to use for computing the signature (*private* keys only)
+ or for verifying one.
+ For DSA keys, let ``L`` and ``N`` be the bit lengths of the modulus ``p``
+ and of ``q``: the pair ``(L,N)`` must appear in the following list,
+ in compliance to section 4.2 of `FIPS 186-4`_:
+
+ - (1024, 160) *legacy only; do not create new signatures with this*
+ - (2048, 224) *deprecated; do not create new signatures with this*
+ - (2048, 256)
+ - (3072, 256)
+
+ For ECC, only keys over P-224, P-256, P-384, and P-521 are accepted.
+
+ mode (string):
+ The parameter can take these values:
+
+ - ``'fips-186-3'``. The signature generation is randomized and carried out
+ according to `FIPS 186-3`_: the nonce ``k`` is taken from the RNG.
+ - ``'deterministic-rfc6979'``. The signature generation is not
+ randomized. See RFC6979_.
+
+ encoding (string):
+ How the signature is encoded. This value determines the output of
+ :meth:`sign` and the input to :meth:`verify`.
+
+ The following values are accepted:
+
+ - ``'binary'`` (default), the signature is the raw concatenation
+ of ``r`` and ``s``. It is defined in the IEEE P.1363 standard.
+ For DSA, the size in bytes of the signature is ``N/4`` bytes
+ (e.g. 64 for ``N=256``).
+ For ECDSA, the signature is always twice the length of a point
+ coordinate (e.g. 64 bytes for P-256).
+
+ - ``'der'``, the signature is an ASN.1 DER SEQUENCE
+ with two INTEGERs (``r`` and ``s``). It is defined in RFC3279_.
+ The size of the signature is variable.
+
+ randfunc (callable):
+ A function that returns random ``bytes``, of a given length.
+ If omitted, the internal RNG is used.
+ Only applicable for the *'fips-186-3'* mode.
+
+ .. _FIPS 186-3: http://csrc.nist.gov/publications/fips/fips186-3/fips_186-3.pdf
+ .. _FIPS 186-4: http://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf
+ .. _NIST SP 800-57 Part 1 Rev 4: http://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-57pt1r4.pdf
+ .. _RFC6979: http://tools.ietf.org/html/rfc6979
+ .. _RFC3279: https://tools.ietf.org/html/rfc3279#section-2.2.2
+ """
+
+ # The goal of the 'mode' parameter is to avoid making
+ # the current version of the standard the implicit default.
+ #
+ # Over time, such version will be superseded by (for instance)
+ # FIPS 186-4 and it will be odd to have -3 as default.
+
+ if encoding not in ('binary', 'der'):
+ raise ValueError("Unknown encoding '%s'" % encoding)
+
+ if isinstance(key, EccKey):
+ order = key._curve.order
+ private_key_attr = 'd'
+ if key._curve.name == "ed25519":
+ raise ValueError("ECC key is not on a NIST P curve")
+ elif isinstance(key, DsaKey):
+ order = Integer(key.q)
+ private_key_attr = 'x'
+ else:
+ raise ValueError("Unsupported key type " + str(type(key)))
+
+ if key.has_private():
+ private_key = getattr(key, private_key_attr)
+ else:
+ private_key = None
+
+ if mode == 'deterministic-rfc6979':
+ return DeterministicDsaSigScheme(key, encoding, order, private_key)
+ elif mode == 'fips-186-3':
+ if isinstance(key, EccKey):
+ return FipsEcDsaSigScheme(key, encoding, order, randfunc)
+ else:
+ return FipsDsaSigScheme(key, encoding, order, randfunc)
+ else:
+ raise ValueError("Unknown DSS mode '%s'" % mode)
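A hedged end-to-end sketch of the new() factory above with an ECDSA key; any of the accepted modes and encodings could be substituted:

    from Crypto.Hash import SHA256
    from Crypto.PublicKey import ECC
    from Crypto.Signature import DSS

    key = ECC.generate(curve='P-256')
    h = SHA256.new(b'message to sign')

    signer = DSS.new(key, 'fips-186-3')                  # randomized nonce, binary r||s encoding
    signature = signer.sign(h)

    verifier = DSS.new(key.public_key(), 'fips-186-3')
    verifier.verify(SHA256.new(b'message to sign'), signature)   # raises ValueError if not authentic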
diff --git a/lib/Crypto/Signature/DSS.pyi b/lib/Crypto/Signature/DSS.pyi
new file mode 100644
index 0000000..08cad81
--- /dev/null
+++ b/lib/Crypto/Signature/DSS.pyi
@@ -0,0 +1,27 @@
+from typing import Union, Optional, Callable
+from typing_extensions import Protocol
+
+from Crypto.PublicKey.DSA import DsaKey
+from Crypto.PublicKey.ECC import EccKey
+
+class Hash(Protocol):
+ def digest(self) -> bytes: ...
+
+__all__ = ['new']
+
+class DssSigScheme:
+ def __init__(self, key: Union[DsaKey, EccKey], encoding: str, order: int) -> None: ...
+ def can_sign(self) -> bool: ...
+ def sign(self, msg_hash: Hash) -> bytes: ...
+ def verify(self, msg_hash: Hash, signature: bytes) -> bool: ...
+
+class DeterministicDsaSigScheme(DssSigScheme):
+ def __init__(self, key, encoding, order, private_key) -> None: ...
+
+class FipsDsaSigScheme(DssSigScheme):
+ def __init__(self, key: DsaKey, encoding: str, order: int, randfunc: Callable) -> None: ...
+
+class FipsEcDsaSigScheme(DssSigScheme):
+ def __init__(self, key: EccKey, encoding: str, order: int, randfunc: Callable) -> None: ...
+
+def new(key: Union[DsaKey, EccKey], mode: str, encoding: Optional[str]='binary', randfunc: Optional[Callable]=None) -> Union[DeterministicDsaSigScheme, FipsDsaSigScheme, FipsEcDsaSigScheme]: ...
diff --git a/lib/Crypto/Signature/PKCS1_PSS.py b/lib/Crypto/Signature/PKCS1_PSS.py
new file mode 100644
index 0000000..c39d388
--- /dev/null
+++ b/lib/Crypto/Signature/PKCS1_PSS.py
@@ -0,0 +1,55 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+Legacy module for PKCS#1 PSS signatures.
+
+:undocumented: __package__
+"""
+
+import types
+
+from Crypto.Signature import pss
+
+
+def _pycrypto_verify(self, hash_object, signature):
+ try:
+ self._verify(hash_object, signature)
+ except (ValueError, TypeError):
+ return False
+ return True
+
+
+def new(rsa_key, mgfunc=None, saltLen=None, randfunc=None):
+ pkcs1 = pss.new(rsa_key, mask_func=mgfunc,
+ salt_bytes=saltLen, rand_func=randfunc)
+ pkcs1._verify = pkcs1.verify
+ pkcs1.verify = types.MethodType(_pycrypto_verify, pkcs1)
+ return pkcs1
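A hedged sketch of the legacy wrapper above; unlike Crypto.Signature.pss, its verify() returns a boolean instead of raising:

    from Crypto.Hash import SHA256
    from Crypto.PublicKey import RSA
    from Crypto.Signature import PKCS1_PSS

    key = RSA.generate(2048)
    h = SHA256.new(b'message')

    signature = PKCS1_PSS.new(key).sign(h)
    assert PKCS1_PSS.new(key.publickey()).verify(h, signature) is True
    assert PKCS1_PSS.new(key.publickey()).verify(SHA256.new(b'tampered'), signature) is False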
diff --git a/lib/Crypto/Signature/PKCS1_PSS.pyi b/lib/Crypto/Signature/PKCS1_PSS.pyi
new file mode 100644
index 0000000..882cc8f
--- /dev/null
+++ b/lib/Crypto/Signature/PKCS1_PSS.pyi
@@ -0,0 +1,7 @@
+from typing import Optional, Callable
+
+from Crypto.PublicKey.RSA import RsaKey
+from Crypto.Signature.pss import PSS_SigScheme
+
+
+def new(rsa_key: RsaKey, mgfunc: Optional[Callable]=None, saltLen: Optional[int]=None, randfunc: Optional[Callable]=None) -> PSS_SigScheme: ...
diff --git a/lib/Crypto/Signature/PKCS1_v1_5.py b/lib/Crypto/Signature/PKCS1_v1_5.py
new file mode 100644
index 0000000..ac888ed
--- /dev/null
+++ b/lib/Crypto/Signature/PKCS1_v1_5.py
@@ -0,0 +1,53 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""
+Legacy module for PKCS#1 v1.5 signatures.
+
+:undocumented: __package__
+"""
+
+import types
+
+from Crypto.Signature import pkcs1_15
+
+def _pycrypto_verify(self, hash_object, signature):
+ try:
+ self._verify(hash_object, signature)
+ except (ValueError, TypeError):
+ return False
+ return True
+
+def new(rsa_key):
+ pkcs1 = pkcs1_15.new(rsa_key)
+ pkcs1._verify = pkcs1.verify
+ pkcs1.verify = types.MethodType(_pycrypto_verify, pkcs1)
+ return pkcs1
+
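The PKCS#1 v1.5 wrapper behaves the same way; a short hedged sketch:

    from Crypto.Hash import SHA256
    from Crypto.PublicKey import RSA
    from Crypto.Signature import PKCS1_v1_5

    key = RSA.generate(2048)
    h = SHA256.new(b'message')
    sig = PKCS1_v1_5.new(key).sign(h)
    assert PKCS1_v1_5.new(key.publickey()).verify(h, sig) is True   # boolean result instead of an exception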
diff --git a/lib/Crypto/Signature/PKCS1_v1_5.pyi b/lib/Crypto/Signature/PKCS1_v1_5.pyi
new file mode 100644
index 0000000..55b9637
--- /dev/null
+++ b/lib/Crypto/Signature/PKCS1_v1_5.pyi
@@ -0,0 +1,6 @@
+from Crypto.PublicKey.RSA import RsaKey
+
+from Crypto.Signature.pkcs1_15 import PKCS115_SigScheme
+
+
+def new(rsa_key: RsaKey) -> PKCS115_SigScheme: ... \ No newline at end of file
diff --git a/lib/Crypto/Signature/__init__.py b/lib/Crypto/Signature/__init__.py
new file mode 100644
index 0000000..11ca64c
--- /dev/null
+++ b/lib/Crypto/Signature/__init__.py
@@ -0,0 +1,36 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+"""Digital signature protocols
+
+A collection of standardized protocols to carry out digital signatures.
+"""
+
+__all__ = ['PKCS1_v1_5', 'PKCS1_PSS', 'DSS', 'pkcs1_15', 'pss', 'eddsa']
diff --git a/lib/Crypto/Signature/eddsa.py b/lib/Crypto/Signature/eddsa.py
new file mode 100644
index 0000000..97145ca
--- /dev/null
+++ b/lib/Crypto/Signature/eddsa.py
@@ -0,0 +1,341 @@
+# ===================================================================
+#
+# Copyright (c) 2022, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Math.Numbers import Integer
+
+from Crypto.Hash import SHA512, SHAKE256
+from Crypto.Util.py3compat import bchr, is_bytes
+from Crypto.PublicKey.ECC import (EccKey,
+ construct,
+ _import_ed25519_public_key,
+ _import_ed448_public_key)
+
+
+def import_public_key(encoded):
+ """Import an EdDSA ECC public key, when encoded as raw ``bytes`` as described
+ in RFC8032.
+
+ Args:
+ encoded (bytes):
+ The EdDSA public key to import.
+ It must be 32 bytes for Ed25519, and 57 bytes for Ed448.
+
+ Returns:
+ :class:`Crypto.PublicKey.EccKey` : a new ECC key object.
+
+ Raises:
+ ValueError: when the given key cannot be parsed.
+ """
+
+ if len(encoded) == 32:
+ x, y = _import_ed25519_public_key(encoded)
+ curve_name = "Ed25519"
+ elif len(encoded) == 57:
+ x, y = _import_ed448_public_key(encoded)
+ curve_name = "Ed448"
+ else:
+ raise ValueError("Not an EdDSA key (%d bytes)" % len(encoded))
+ return construct(curve=curve_name, point_x=x, point_y=y)
+
+
+def import_private_key(encoded):
+ """Import an EdDSA ECC private key, when encoded as raw ``bytes`` as described
+ in RFC8032.
+
+ Args:
+ encoded (bytes):
+ The EdDSA private key to import.
+ It must be 32 bytes for Ed25519, and 57 bytes for Ed448.
+
+ Returns:
+ :class:`Crypto.PublicKey.EccKey` : a new ECC key object.
+
+ Raises:
+ ValueError: when the given key cannot be parsed.
+ """
+
+ if len(encoded) == 32:
+ curve_name = "ed25519"
+ elif len(encoded) == 57:
+ curve_name = "ed448"
+ else:
+ raise ValueError("Incorrect length. Only EdDSA private keys are supported.")
+
+ # Note that the private key is truly a sequence of random bytes,
+ # so we cannot check its correctness in any way.
+
+ return construct(seed=encoded, curve=curve_name)
+
+
+class EdDSASigScheme(object):
+ """An EdDSA signature object.
+ Do not instantiate directly.
+ Use :func:`Crypto.Signature.eddsa.new`.
+ """
+
+ def __init__(self, key, context):
+ """Create a new EdDSA object.
+
+ Do not instantiate this object directly,
+ use `Crypto.Signature.eddsa.new` instead.
+ """
+
+ self._key = key
+ self._context = context
+ self._A = key._export_eddsa()
+ self._order = key._curve.order
+
+ def can_sign(self):
+ """Return ``True`` if this signature object can be used
+ for signing messages."""
+
+ return self._key.has_private()
+
+ def sign(self, msg_or_hash):
+ """Compute the EdDSA signature of a message.
+
+ Args:
+ msg_or_hash (bytes or a hash object):
+ The message to sign (``bytes``, in case of *PureEdDSA*) or
+ the hash that was carried out over the message (hash object, for *HashEdDSA*).
+
+ The hash object must be a :class:`Crypto.Hash.SHA512` object for Ed25519,
+ and a :class:`Crypto.Hash.SHAKE256` object for Ed448.
+
+ :return: The signature as ``bytes``. It is always 64 bytes for Ed25519, and 114 bytes for Ed448.
+ :raise TypeError: if the EdDSA key has no private half
+ """
+
+ if not self._key.has_private():
+ raise TypeError("Private key is needed to sign")
+
+ if self._key._curve.name == "ed25519":
+ ph = isinstance(msg_or_hash, SHA512.SHA512Hash)
+ if not (ph or is_bytes(msg_or_hash)):
+ raise TypeError("'msg_or_hash' must be bytes of a SHA-512 hash")
+ eddsa_sign_method = self._sign_ed25519
+
+ elif self._key._curve.name == "ed448":
+ ph = isinstance(msg_or_hash, SHAKE256.SHAKE256_XOF)
+ if not (ph or is_bytes(msg_or_hash)):
+ raise TypeError("'msg_or_hash' must be bytes of a SHAKE256 hash")
+ eddsa_sign_method = self._sign_ed448
+
+ else:
+ raise ValueError("Incorrect curve for EdDSA")
+
+ return eddsa_sign_method(msg_or_hash, ph)
+
+ def _sign_ed25519(self, msg_or_hash, ph):
+
+ if self._context or ph:
+ flag = int(ph)
+ # dom2(flag, self._context)
+ dom2 = b'SigEd25519 no Ed25519 collisions' + bchr(flag) + \
+ bchr(len(self._context)) + self._context
+ else:
+ dom2 = b''
+
+ PHM = msg_or_hash.digest() if ph else msg_or_hash
+
+ # See RFC 8032, section 5.1.6
+
+ # Step 2
+ r_hash = SHA512.new(dom2 + self._key._prefix + PHM).digest()
+ r = Integer.from_bytes(r_hash, 'little') % self._order
+ # Step 3
+ R_pk = EccKey(point=r * self._key._curve.G)._export_eddsa()
+ # Step 4
+ k_hash = SHA512.new(dom2 + R_pk + self._A + PHM).digest()
+ k = Integer.from_bytes(k_hash, 'little') % self._order
+ # Step 5
+ s = (r + k * self._key.d) % self._order
+
+ return R_pk + s.to_bytes(32, 'little')
+
+ def _sign_ed448(self, msg_or_hash, ph):
+
+ flag = int(ph)
+ # dom4(flag, self._context)
+ dom4 = b'SigEd448' + bchr(flag) + \
+ bchr(len(self._context)) + self._context
+
+ PHM = msg_or_hash.read(64) if ph else msg_or_hash
+
+ # See RFC 8032, section 5.2.6
+
+ # Step 2
+ r_hash = SHAKE256.new(dom4 + self._key._prefix + PHM).read(114)
+ r = Integer.from_bytes(r_hash, 'little') % self._order
+ # Step 3
+ R_pk = EccKey(point=r * self._key._curve.G)._export_eddsa()
+ # Step 4
+ k_hash = SHAKE256.new(dom4 + R_pk + self._A + PHM).read(114)
+ k = Integer.from_bytes(k_hash, 'little') % self._order
+ # Step 5
+ s = (r + k * self._key.d) % self._order
+
+ return R_pk + s.to_bytes(57, 'little')
+
+ def verify(self, msg_or_hash, signature):
+ """Check if an EdDSA signature is authentic.
+
+ Args:
+ msg_or_hash (bytes or a hash object):
+ The message to verify (``bytes``, in case of *PureEdDSA*) or
+ the hash that was carried out over the message (hash object, for *HashEdDSA*).
+
+ The hash object must be a :class:`Crypto.Hash.SHA512` object for Ed25519,
+ and a :class:`Crypto.Hash.SHAKE256` object for Ed448.
+
+ signature (``bytes``):
+ The signature that needs to be validated.
+ It must be 64 bytes for Ed25519, and 114 bytes for Ed448.
+
+ :raise ValueError: if the signature is not authentic
+ """
+
+ if self._key._curve.name == "ed25519":
+ ph = isinstance(msg_or_hash, SHA512.SHA512Hash)
+ if not (ph or is_bytes(msg_or_hash)):
+ raise TypeError("'msg_or_hash' must be bytes of a SHA-512 hash")
+ eddsa_verify_method = self._verify_ed25519
+
+ elif self._key._curve.name == "ed448":
+ ph = isinstance(msg_or_hash, SHAKE256.SHAKE256_XOF)
+ if not (ph or is_bytes(msg_or_hash)):
+ raise TypeError("'msg_or_hash' must be bytes of a SHAKE256 hash")
+ eddsa_verify_method = self._verify_ed448
+
+ else:
+ raise ValueError("Incorrect curve for EdDSA")
+
+ return eddsa_verify_method(msg_or_hash, signature, ph)
+
+ def _verify_ed25519(self, msg_or_hash, signature, ph):
+
+ if len(signature) != 64:
+ raise ValueError("The signature is not authentic (length)")
+
+ if self._context or ph:
+ flag = int(ph)
+ dom2 = b'SigEd25519 no Ed25519 collisions' + bchr(flag) + \
+ bchr(len(self._context)) + self._context
+ else:
+ dom2 = b''
+
+ PHM = msg_or_hash.digest() if ph else msg_or_hash
+
+ # Section 5.1.7
+
+ # Step 1
+ try:
+ R = import_public_key(signature[:32]).pointQ
+ except ValueError:
+ raise ValueError("The signature is not authentic (R)")
+ s = Integer.from_bytes(signature[32:], 'little')
+ if s > self._order:
+ raise ValueError("The signature is not authentic (S)")
+ # Step 2
+ k_hash = SHA512.new(dom2 + signature[:32] + self._A + PHM).digest()
+ k = Integer.from_bytes(k_hash, 'little') % self._order
+ # Step 3
+ point1 = s * 8 * self._key._curve.G
+ # OPTIMIZE: use double-scalar multiplication; no SCA countermeasures
+ # are needed here because only public values are involved
+ point2 = 8 * R + k * 8 * self._key.pointQ
+ if point1 != point2:
+ raise ValueError("The signature is not authentic")
+
+ def _verify_ed448(self, msg_or_hash, signature, ph):
+
+ if len(signature) != 114:
+ raise ValueError("The signature is not authentic (length)")
+
+ flag = int(ph)
+ # dom4(flag, self._context)
+ dom4 = b'SigEd448' + bchr(flag) + \
+ bchr(len(self._context)) + self._context
+
+ PHM = msg_or_hash.read(64) if ph else msg_or_hash
+
+ # Section 5.2.7
+
+ # Step 1
+ try:
+ R = import_public_key(signature[:57]).pointQ
+ except ValueError:
+ raise ValueError("The signature is not authentic (R)")
+ s = Integer.from_bytes(signature[57:], 'little')
+ if s > self._order:
+ raise ValueError("The signature is not authentic (S)")
+ # Step 2
+ k_hash = SHAKE256.new(dom4 + signature[:57] + self._A + PHM).read(114)
+ k = Integer.from_bytes(k_hash, 'little') % self._order
+ # Step 3
+ point1 = s * 8 * self._key._curve.G
+ # OPTIMIZE: use double-scalar multiplication; no SCA countermeasures
+ # are needed here because only public values are involved
+ point2 = 8 * R + k * 8 * self._key.pointQ
+ if point1 != point2:
+ raise ValueError("The signature is not authentic")
+
+
+def new(key, mode, context=None):
+ """Create a signature object :class:`EdDSASigScheme` that
+ can perform or verify an EdDSA signature.
+
+ Args:
+ key (:class:`Crypto.PublicKey.ECC` object):
+ The key to use for computing the signature (*private* keys only)
+ or for verifying one.
+ The key must be on the curve ``Ed25519`` or ``Ed448``.
+
+ mode (string):
+ This parameter must be ``'rfc8032'``.
+
+ context (bytes):
+ Up to 255 bytes of `context <https://datatracker.ietf.org/doc/html/rfc8032#page-41>`_,
+ which is a constant byte string to segregate different protocols or
+ different applications of the same key.
+ """
+
+ if not isinstance(key, EccKey) or not key._is_eddsa():
+ raise ValueError("EdDSA can only be used with EdDSA keys")
+
+ if mode != 'rfc8032':
+ raise ValueError("Mode must be 'rfc8032'")
+
+ if context is None:
+ context = b''
+ elif len(context) > 255:
+ raise ValueError("Context for EdDSA must not be longer than 255 bytes")
+
+ return EdDSASigScheme(key, context)
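+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: a PureEdDSA sign/verify round trip over Ed25519,
+    # using the helpers defined in this module plus Crypto.Random.get_random_bytes
+    # (assumed to be available elsewhere in this package).
+    from Crypto.Random import get_random_bytes
+
+    seed = get_random_bytes(32)                  # an Ed25519 private key is 32 random bytes
+    key = import_private_key(seed)
+    signer = new(key, 'rfc8032')
+    signature = signer.sign(b"message")          # 64-byte signature for Ed25519
+    verifier = new(key.public_key(), 'rfc8032')
+    verifier.verify(b"message", signature)       # raises ValueError if not authentic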
diff --git a/lib/Crypto/Signature/eddsa.pyi b/lib/Crypto/Signature/eddsa.pyi
new file mode 100644
index 0000000..3842ff7
--- /dev/null
+++ b/lib/Crypto/Signature/eddsa.pyi
@@ -0,0 +1,21 @@
+from typing import Union, Optional
+from typing_extensions import Protocol
+from Crypto.PublicKey.ECC import EccKey
+
+class Hash(Protocol):
+ def digest(self) -> bytes: ...
+
+class XOF(Protocol):
+ def read(self, len: int) -> bytes: ...
+
+def import_public_key(encoded: bytes) -> EccKey: ...
+def import_private_key(encoded: bytes) -> EccKey: ...
+
+class EdDSASigScheme(object):
+
+ def __init__(self, key: EccKey, context: bytes) -> None: ...
+ def can_sign(self) -> bool: ...
+ def sign(self, msg_or_hash: Union[bytes, Hash, XOF]) -> bytes: ...
+ def verify(self, msg_or_hash: Union[bytes, Hash, XOF], signature: bytes) -> None: ...
+
+def new(key: EccKey, mode: bytes, context: Optional[bytes]=None) -> EdDSASigScheme: ...
diff --git a/lib/Crypto/Signature/pkcs1_15.py b/lib/Crypto/Signature/pkcs1_15.py
new file mode 100644
index 0000000..54a4bf7
--- /dev/null
+++ b/lib/Crypto/Signature/pkcs1_15.py
@@ -0,0 +1,222 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import Crypto.Util.number
+from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
+from Crypto.Util.asn1 import DerSequence, DerNull, DerOctetString, DerObjectId
+
+class PKCS115_SigScheme:
+ """A signature object for ``RSASSA-PKCS1-v1_5``.
+ Do not instantiate directly.
+ Use :func:`Crypto.Signature.pkcs1_15.new`.
+ """
+
+ def __init__(self, rsa_key):
+ """Initialize this PKCS#1 v1.5 signature scheme object.
+
+ :Parameters:
+ rsa_key : an RSA key object
+ Creation of signatures is only possible if this is a *private*
+ RSA key. Verification of signatures is always possible.
+ """
+ self._key = rsa_key
+
+ def can_sign(self):
+ """Return ``True`` if this object can be used to sign messages."""
+ return self._key.has_private()
+
+ def sign(self, msg_hash):
+ """Create the PKCS#1 v1.5 signature of a message.
+
+ This function is also called ``RSASSA-PKCS1-V1_5-SIGN`` and
+ it is specified in
+ `section 8.2.1 of RFC8017 <https://tools.ietf.org/html/rfc8017#page-36>`_.
+
+ :parameter msg_hash:
+ This is an object from the :mod:`Crypto.Hash` package.
+ It has been used to digest the message to sign.
+ :type msg_hash: hash object
+
+ :return: the signature encoded as a *byte string*.
+ :raise ValueError: if the RSA key is not long enough for the given hash algorithm.
+ :raise TypeError: if the RSA key has no private half.
+ """
+
+ # See 8.2.1 in RFC3447
+ modBits = Crypto.Util.number.size(self._key.n)
+ k = ceil_div(modBits,8) # Convert from bits to bytes
+
+ # Step 1
+ em = _EMSA_PKCS1_V1_5_ENCODE(msg_hash, k)
+ # Step 2a (OS2IP)
+ em_int = bytes_to_long(em)
+ # Step 2b (RSASP1)
+ m_int = self._key._decrypt(em_int)
+ # Step 2c (I2OSP)
+ signature = long_to_bytes(m_int, k)
+ return signature
+
+ def verify(self, msg_hash, signature):
+ """Check if the PKCS#1 v1.5 signature over a message is valid.
+
+ This function is also called ``RSASSA-PKCS1-V1_5-VERIFY`` and
+ it is specified in
+ `section 8.2.2 of RFC8017 <https://tools.ietf.org/html/rfc8017#page-37>`_.
+
+ :parameter msg_hash:
+ The hash that was carried out over the message. This is an object
+ belonging to the :mod:`Crypto.Hash` module.
+ :type msg_hash: hash object
+
+ :parameter signature:
+ The signature that needs to be validated.
+ :type signature: byte string
+
+ :raise ValueError: if the signature is not valid.
+ """
+
+ # See 8.2.2 in RFC3447
+ modBits = Crypto.Util.number.size(self._key.n)
+ k = ceil_div(modBits, 8) # Convert from bits to bytes
+
+ # Step 1
+ if len(signature) != k:
+ raise ValueError("Invalid signature")
+ # Step 2a (O2SIP)
+ signature_int = bytes_to_long(signature)
+ # Step 2b (RSAVP1)
+ em_int = self._key._encrypt(signature_int)
+ # Step 2c (I2OSP)
+ em1 = long_to_bytes(em_int, k)
+ # Step 3
+ try:
+ possible_em1 = [ _EMSA_PKCS1_V1_5_ENCODE(msg_hash, k, True) ]
+ # MD2/4/5 hashes always require NULL params in AlgorithmIdentifier.
+ # For all others, it is optional.
+ try:
+ algorithm_is_md = msg_hash.oid.startswith('1.2.840.113549.2.')
+ except AttributeError:
+ algorithm_is_md = False
+ if not algorithm_is_md: # non-MD hash: NULL parameters may be omitted
+ possible_em1.append(_EMSA_PKCS1_V1_5_ENCODE(msg_hash, k, False))
+ except ValueError:
+ raise ValueError("Invalid signature")
+ # Step 4
+ # By comparing the full encodings (as opposed to checking each
+ # of its components one at a time) we avoid attacks to the padding
+ # scheme like Bleichenbacher's (see http://www.mail-archive.com/cryptography@metzdowd.com/msg06537).
+ #
+ if em1 not in possible_em1:
+ raise ValueError("Invalid signature")
+ pass
+
+
+def _EMSA_PKCS1_V1_5_ENCODE(msg_hash, emLen, with_hash_parameters=True):
+ """
+ Implement the ``EMSA-PKCS1-V1_5-ENCODE`` function, as defined
+ in PKCS#1 v2.1 (RFC3447, 9.2).
+
+ ``EMSA-PKCS1-V1_5-ENCODE`` actually accepts the message ``M`` as input
+ and hashes it internally. Here, we expect that the message has already
+ been hashed instead.
+
+ :Parameters:
+ msg_hash : hash object
+ The hash object that holds the digest of the message being signed.
+ emLen : int
+ The length the final encoding must have, in bytes.
+ with_hash_parameters : bool
+ If True (default), include NULL parameters for the hash
+ algorithm in the ``digestAlgorithm`` SEQUENCE.
+
+ :attention: the early standard (RFC2313) stated that ``DigestInfo``
+ had to be BER-encoded. This means that old signatures
+ might have length tags in indefinite form, which
+ is not supported in DER. Such encoding cannot be
+ reproduced by this function.
+
+ :Return: An ``emLen`` byte long string that encodes the hash.
+ """
+
+ # First, build the ASN.1 DER object DigestInfo:
+ #
+ # DigestInfo ::= SEQUENCE {
+ # digestAlgorithm AlgorithmIdentifier,
+ # digest OCTET STRING
+ # }
+ #
+ # where digestAlgorithm identifies the hash function and shall be an
+ # algorithm ID with an OID in the set PKCS1-v1-5DigestAlgorithms.
+ #
+ # PKCS1-v1-5DigestAlgorithms ALGORITHM-IDENTIFIER ::= {
+ # { OID id-md2 PARAMETERS NULL }|
+ # { OID id-md5 PARAMETERS NULL }|
+ # { OID id-sha1 PARAMETERS NULL }|
+ # { OID id-sha256 PARAMETERS NULL }|
+ # { OID id-sha384 PARAMETERS NULL }|
+ # { OID id-sha512 PARAMETERS NULL }
+ # }
+ #
+ # Appendix B.1 also says that for SHA-1/-2 algorithms, the parameters
+ # should be omitted. They may be present, but when they are, they shall
+ # have NULL value.
+
+ digestAlgo = DerSequence([ DerObjectId(msg_hash.oid).encode() ])
+
+ if with_hash_parameters:
+ digestAlgo.append(DerNull().encode())
+
+ digest = DerOctetString(msg_hash.digest())
+ digestInfo = DerSequence([
+ digestAlgo.encode(),
+ digest.encode()
+ ]).encode()
+
+ # We need at least 11 bytes for the remaining data: 3 fixed bytes and
+ # at least 8 bytes of padding.
+ if emLen<len(digestInfo)+11:
+ raise TypeError("Selected hash algorithm has a too long digest (%d bytes)." % len(digest))
+ PS = b'\xFF' * (emLen - len(digestInfo) - 3)
+ return b'\x00\x01' + PS + b'\x00' + digestInfo
+
+def new(rsa_key):
+ """Create a signature object for creating
+ or verifying PKCS#1 v1.5 signatures.
+
+ :parameter rsa_key:
+ The RSA key to use for signing or verifying the message.
+ This is a :class:`Crypto.PublicKey.RSA` object.
+ Signing is only possible when ``rsa_key`` is a **private** RSA key.
+ :type rsa_key: RSA object
+
+ :return: a :class:`PKCS115_SigScheme` signature object
+ """
+ return PKCS115_SigScheme(rsa_key)
+
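+
+if __name__ == "__main__":
+    # Illustrative sketch only: sign and verify one message with RSASSA-PKCS1-v1_5,
+    # assuming Crypto.PublicKey.RSA and Crypto.Hash.SHA256 from this same package.
+    from Crypto.PublicKey import RSA
+    from Crypto.Hash import SHA256
+
+    rsa_key = RSA.generate(2048)
+    digest = SHA256.new(b"message")
+    signature = new(rsa_key).sign(digest)
+    new(rsa_key.publickey()).verify(digest, signature)   # raises ValueError if invalid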
diff --git a/lib/Crypto/Signature/pkcs1_15.pyi b/lib/Crypto/Signature/pkcs1_15.pyi
new file mode 100644
index 0000000..c4dc1ab
--- /dev/null
+++ b/lib/Crypto/Signature/pkcs1_15.pyi
@@ -0,0 +1,17 @@
+from typing import Optional
+from typing_extensions import Protocol
+
+from Crypto.PublicKey.RSA import RsaKey
+
+class Hash(Protocol):
+ def digest(self) -> bytes: ...
+
+class PKCS115_SigScheme:
+ def __init__(self, rsa_key: RsaKey) -> None: ...
+ def can_sign(self) -> bool: ...
+ def sign(self, msg_hash: Hash) -> bytes: ...
+ def verify(self, msg_hash: Hash, signature: bytes) -> None: ...
+
+def _EMSA_PKCS1_V1_5_ENCODE(msg_hash: Hash, emLen: int, with_hash_parameters: Optional[bool]=True) -> bytes: ...
+
+def new(rsa_key: RsaKey) -> PKCS115_SigScheme: ...
diff --git a/lib/Crypto/Signature/pss.py b/lib/Crypto/Signature/pss.py
new file mode 100644
index 0000000..5f34ace
--- /dev/null
+++ b/lib/Crypto/Signature/pss.py
@@ -0,0 +1,386 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util.py3compat import bchr, bord, iter_range
+import Crypto.Util.number
+from Crypto.Util.number import (ceil_div,
+ long_to_bytes,
+ bytes_to_long
+ )
+from Crypto.Util.strxor import strxor
+from Crypto import Random
+
+
+class PSS_SigScheme:
+ """A signature object for ``RSASSA-PSS``.
+ Do not instantiate directly.
+ Use :func:`Crypto.Signature.pss.new`.
+ """
+
+ def __init__(self, key, mgfunc, saltLen, randfunc):
+ """Initialize this PKCS#1 PSS signature scheme object.
+
+ :Parameters:
+ key : an RSA key object
+ If a private half is given, both signature and
+ verification are possible.
+ If a public half is given, only verification is possible.
+ mgfunc : callable
+ A mask generation function that accepts two parameters:
+ a string to use as seed, and the length of the mask to
+ generate, in bytes.
+ saltLen : integer
+ Length of the salt, in bytes.
+ randfunc : callable
+ A function that returns random bytes.
+ """
+
+ self._key = key
+ self._saltLen = saltLen
+ self._mgfunc = mgfunc
+ self._randfunc = randfunc
+
+ def can_sign(self):
+ """Return ``True`` if this object can be used to sign messages."""
+ return self._key.has_private()
+
+ def sign(self, msg_hash):
+ """Create the PKCS#1 PSS signature of a message.
+
+ This function is also called ``RSASSA-PSS-SIGN`` and
+ it is specified in
+ `section 8.1.1 of RFC8017 <https://tools.ietf.org/html/rfc8017#section-8.1.1>`_.
+
+ :parameter msg_hash:
+ This is an object from the :mod:`Crypto.Hash` package.
+ It has been used to digest the message to sign.
+ :type msg_hash: hash object
+
+ :return: the signature encoded as a *byte string*.
+ :raise ValueError: if the RSA key is not long enough for the given hash algorithm.
+ :raise TypeError: if the RSA key has no private half.
+ """
+
+ # Set defaults for salt length and mask generation function
+ if self._saltLen is None:
+ sLen = msg_hash.digest_size
+ else:
+ sLen = self._saltLen
+
+ if self._mgfunc is None:
+ mgf = lambda x, y: MGF1(x, y, msg_hash)
+ else:
+ mgf = self._mgfunc
+
+ modBits = Crypto.Util.number.size(self._key.n)
+
+ # See 8.1.1 in RFC3447
+ k = ceil_div(modBits, 8) # k is length in bytes of the modulus
+ # Step 1
+ em = _EMSA_PSS_ENCODE(msg_hash, modBits-1, self._randfunc, mgf, sLen)
+ # Step 2a (OS2IP)
+ em_int = bytes_to_long(em)
+ # Step 2b (RSASP1)
+ m_int = self._key._decrypt(em_int)
+ # Step 2c (I2OSP)
+ signature = long_to_bytes(m_int, k)
+ return signature
+
+ def verify(self, msg_hash, signature):
+ """Check if the PKCS#1 PSS signature over a message is valid.
+
+ This function is also called ``RSASSA-PSS-VERIFY`` and
+ it is specified in
+ `section 8.1.2 of RFC8017 <https://tools.ietf.org/html/rfc8017#section-8.1.2>`_.
+
+ :parameter msg_hash:
+ The hash that was carried out over the message. This is an object
+ belonging to the :mod:`Crypto.Hash` module.
+ :type msg_hash: hash object
+
+ :parameter signature:
+ The signature that needs to be validated.
+ :type signature: bytes
+
+ :raise ValueError: if the signature is not valid.
+ """
+
+ # Set defaults for salt length and mask generation function
+ if self._saltLen is None:
+ sLen = msg_hash.digest_size
+ else:
+ sLen = self._saltLen
+ if self._mgfunc:
+ mgf = self._mgfunc
+ else:
+ mgf = lambda x, y: MGF1(x, y, msg_hash)
+
+ modBits = Crypto.Util.number.size(self._key.n)
+
+ # See 8.1.2 in RFC3447
+ k = ceil_div(modBits, 8) # Convert from bits to bytes
+ # Step 1
+ if len(signature) != k:
+ raise ValueError("Incorrect signature")
+ # Step 2a (O2SIP)
+ signature_int = bytes_to_long(signature)
+ # Step 2b (RSAVP1)
+ em_int = self._key._encrypt(signature_int)
+ # Step 2c (I2OSP)
+ emLen = ceil_div(modBits - 1, 8)
+ em = long_to_bytes(em_int, emLen)
+ # Step 3/4
+ _EMSA_PSS_VERIFY(msg_hash, em, modBits-1, mgf, sLen)
+
+
+def MGF1(mgfSeed, maskLen, hash_gen):
+ """Mask Generation Function, described in `B.2.1 of RFC8017
+ <https://tools.ietf.org/html/rfc8017>`_.
+
+ :param mgfSeed:
+ seed from which the mask is generated
+ :type mgfSeed: byte string
+
+ :param maskLen:
+ intended length in bytes of the mask
+ :type maskLen: integer
+
+ :param hash_gen:
+ A module or a hash object from :mod:`Crypto.Hash`
+ :type hash_gen: hash module or hash object
+
+ :return: the mask, as a *byte string*
+ """
+
+ T = b""
+ for counter in iter_range(ceil_div(maskLen, hash_gen.digest_size)):
+ c = long_to_bytes(counter, 4)
+ hobj = hash_gen.new()
+ hobj.update(mgfSeed + c)
+ T = T + hobj.digest()
+ assert(len(T) >= maskLen)
+ return T[:maskLen]
+
+
+def _EMSA_PSS_ENCODE(mhash, emBits, randFunc, mgf, sLen):
+ r"""
+ Implement the ``EMSA-PSS-ENCODE`` function, as defined
+ in PKCS#1 v2.1 (RFC3447, 9.1.1).
+
+ The original ``EMSA-PSS-ENCODE`` actually accepts the message ``M``
+ as input and hashes it internally. Here, we expect that the message
+ has already been hashed instead.
+
+ :Parameters:
+ mhash : hash object
+ The hash object that holds the digest of the message being signed.
+ emBits : int
+ Maximum length of the final encoding, in bits.
+ randFunc : callable
+ An RNG function that accepts as only parameter an int, and returns
+ a string of random bytes, to be used as salt.
+ mgf : callable
+ A mask generation function that accepts two parameters: a string to
+ use as seed, and the length of the mask to generate, in bytes.
+ sLen : int
+ Length of the salt, in bytes.
+
+ :Return: An ``emLen`` byte long string that encodes the hash
+ (with ``emLen = \ceil(emBits/8)``).
+
+ :Raise ValueError:
+ When digest or salt length are too big.
+ """
+
+ emLen = ceil_div(emBits, 8)
+
+ # Bitmask of the 8*emLen - emBits leading bits, which must be zero
+ lmask = 0
+ for i in iter_range(8*emLen-emBits):
+ lmask = lmask >> 1 | 0x80
+
+ # Step 1 and 2 have been already done
+ # Step 3
+ if emLen < mhash.digest_size+sLen+2:
+ raise ValueError("Digest or salt length are too long"
+ " for given key size.")
+ # Step 4
+ salt = randFunc(sLen)
+ # Step 5
+ m_prime = bchr(0)*8 + mhash.digest() + salt
+ # Step 6
+ h = mhash.new()
+ h.update(m_prime)
+ # Step 7
+ ps = bchr(0)*(emLen-sLen-mhash.digest_size-2)
+ # Step 8
+ db = ps + bchr(1) + salt
+ # Step 9
+ dbMask = mgf(h.digest(), emLen-mhash.digest_size-1)
+ # Step 10
+ maskedDB = strxor(db, dbMask)
+ # Step 11
+ maskedDB = bchr(bord(maskedDB[0]) & ~lmask) + maskedDB[1:]
+ # Step 12
+ em = maskedDB + h.digest() + bchr(0xBC)
+ return em
+
+
+def _EMSA_PSS_VERIFY(mhash, em, emBits, mgf, sLen):
+ """
+ Implement the ``EMSA-PSS-VERIFY`` function, as defined
+ in PKCS#1 v2.1 (RFC3447, 9.1.2).
+
+ ``EMSA-PSS-VERIFY`` actually accepts the message ``M`` as input
+ and hashes it internally. Here, we expect that the message has already
+ been hashed instead.
+
+ :Parameters:
+ mhash : hash object
+ The hash object that holds the digest of the message to be verified.
+ em : string
+ The encoded message (EM) recovered from the signature, which is
+ checked against the digest of the message being verified.
+ emBits : int
+ Length of the final encoding (em), in bits.
+ mgf : callable
+ A mask generation function that accepts two parameters: a string to
+ use as seed, and the length of the mask to generate, in bytes.
+ sLen : int
+ Length of the salt, in bytes.
+
+ :Raise ValueError:
+ When the encoding is inconsistent, or the digest or salt lengths
+ are too big.
+ """
+
+ emLen = ceil_div(emBits, 8)
+
+ # Bitmask of the 8*emLen - emBits leading bits, which must be zero
+ lmask = 0
+ for i in iter_range(8*emLen-emBits):
+ lmask = lmask >> 1 | 0x80
+
+ # Step 1 and 2 have been already done
+ # Step 3
+ if emLen < mhash.digest_size+sLen+2:
+ raise ValueError("Incorrect signature")
+ # Step 4
+ if ord(em[-1:]) != 0xBC:
+ raise ValueError("Incorrect signature")
+ # Step 5
+ maskedDB = em[:emLen-mhash.digest_size-1]
+ h = em[emLen-mhash.digest_size-1:-1]
+ # Step 6
+ if lmask & bord(em[0]):
+ raise ValueError("Incorrect signature")
+ # Step 7
+ dbMask = mgf(h, emLen-mhash.digest_size-1)
+ # Step 8
+ db = strxor(maskedDB, dbMask)
+ # Step 9
+ db = bchr(bord(db[0]) & ~lmask) + db[1:]
+ # Step 10
+ if not db.startswith(bchr(0)*(emLen-mhash.digest_size-sLen-2) + bchr(1)):
+ raise ValueError("Incorrect signature")
+ # Step 11
+ if sLen > 0:
+ salt = db[-sLen:]
+ else:
+ salt = b""
+ # Step 12
+ m_prime = bchr(0)*8 + mhash.digest() + salt
+ # Step 13
+ hobj = mhash.new()
+ hobj.update(m_prime)
+ hp = hobj.digest()
+ # Step 14
+ if h != hp:
+ raise ValueError("Incorrect signature")
+
+
+def new(rsa_key, **kwargs):
+ """Create an object for making or verifying PKCS#1 PSS signatures.
+
+ :parameter rsa_key:
+ The RSA key to use for signing or verifying the message.
+ This is a :class:`Crypto.PublicKey.RSA` object.
+ Signing is only possible when ``rsa_key`` is a **private** RSA key.
+ :type rsa_key: RSA object
+
+ :Keyword Arguments:
+
+ * *mask_func* (``callable``) --
+ A function that returns the mask (as `bytes`).
+ It must accept two parameters: a seed (as `bytes`)
+ and the length of the data to return.
+
+ If not specified, it will be the function :func:`MGF1` defined in
+ `RFC8017 <https://tools.ietf.org/html/rfc8017#page-67>`_ and
+ combined with the same hash algorithm applied to the
+ message to sign or verify.
+
+ If you want to use a different function, for instance still :func:`MGF1`
+ but together with another hash, you can do::
+
+ from Crypto.Hash import SHA256
+ from Crypto.Signature.pss import MGF1
+ mgf = lambda x, y: MGF1(x, y, SHA256)
+
+ * *salt_bytes* (``integer``) --
+ Length of the salt, in bytes.
+ It is a value between 0 and ``emLen - hLen - 2``, where ``emLen``
+ is the size of the RSA modulus in bytes and ``hLen`` is the size of
+ the digest applied to the message to sign or verify.
+
+ The salt is generated internally; you don't need to provide it.
+
+ If not specified, the salt length will be ``hLen``.
+ If it is zero, the signature scheme becomes deterministic.
+
+ Note that in some implementations such as OpenSSL the default
+ salt length is ``emLen - hLen - 2`` (even though it is not more
+ secure than ``hLen``).
+
+ * *rand_func* (``callable``) --
+ A function that returns random ``bytes``, of the desired length.
+ The default is :func:`Crypto.Random.get_random_bytes`.
+
+ :return: a :class:`PSS_SigScheme` signature object
+ """
+
+ mask_func = kwargs.pop("mask_func", None)
+ salt_len = kwargs.pop("salt_bytes", None)
+ rand_func = kwargs.pop("rand_func", None)
+ if rand_func is None:
+ rand_func = Random.get_random_bytes
+ if kwargs:
+ raise ValueError("Unknown keywords: " + str(kwargs.keys()))
+ return PSS_SigScheme(rsa_key, mask_func, salt_len, rand_func)
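+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: an RSASSA-PSS round trip with the default MGF1
+    # and salt settings described in new() above, assuming Crypto.PublicKey.RSA
+    # and Crypto.Hash.SHA256 from this same package.
+    from Crypto.PublicKey import RSA
+    from Crypto.Hash import SHA256
+
+    rsa_key = RSA.generate(2048)
+    digest = SHA256.new(b"message")
+    signature = new(rsa_key).sign(digest)        # probabilistic: a fresh salt each call
+    new(rsa_key.publickey()).verify(digest, signature)   # raises ValueError if invalid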
diff --git a/lib/Crypto/Signature/pss.pyi b/lib/Crypto/Signature/pss.pyi
new file mode 100644
index 0000000..4d216ca
--- /dev/null
+++ b/lib/Crypto/Signature/pss.pyi
@@ -0,0 +1,30 @@
+from typing import Union, Callable, Optional
+from typing_extensions import Protocol
+
+from Crypto.PublicKey.RSA import RsaKey
+
+
+class Hash(Protocol):
+ def digest(self) -> bytes: ...
+ def update(self, bytes) -> None: ...
+
+
+class HashModule(Protocol):
+ @staticmethod
+ def new(data: Optional[bytes]) -> Hash: ...
+
+
+MaskFunction = Callable[[bytes, int, Union[Hash, HashModule]], bytes]
+RndFunction = Callable[[int], bytes]
+
+class PSS_SigScheme:
+ def __init__(self, key: RsaKey, mgfunc: MaskFunction, saltLen: int, randfunc: RndFunction) -> None: ...
+ def can_sign(self) -> bool: ...
+ def sign(self, msg_hash: Hash) -> bytes: ...
+ def verify(self, msg_hash: Hash, signature: bytes) -> None: ...
+
+
+MGF1 : MaskFunction
+def _EMSA_PSS_ENCODE(mhash: Hash, emBits: int, randFunc: RndFunction, mgf: MaskFunction, sLen: int) -> bytes: ...
+def _EMSA_PSS_VERIFY(mhash: Hash, em: bytes, emBits: int, mgf: MaskFunction, sLen: int) -> None: ...
+def new(rsa_key: RsaKey, **kwargs: Union[MaskFunction, RndFunction, int]) -> PSS_SigScheme: ...
diff --git a/lib/Crypto/Util/Counter.py b/lib/Crypto/Util/Counter.py
new file mode 100644
index 0000000..c67bc95
--- /dev/null
+++ b/lib/Crypto/Util/Counter.py
@@ -0,0 +1,77 @@
+# -*- coding: ascii -*-
+#
+# Util/Counter.py : Fast counter for use with CTR-mode ciphers
+#
+# Written in 2008 by Dwayne C. Litzenberger <dlitz@dlitz.net>
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+def new(nbits, prefix=b"", suffix=b"", initial_value=1, little_endian=False, allow_wraparound=False):
+ """Create a stateful counter block function suitable for CTR encryption modes.
+
+ Each call to the function returns the next counter block.
+ Each counter block is made up by three parts:
+
+ +------+--------------+-------+
+ |prefix| counter value|postfix|
+ +------+--------------+-------+
+
+ The counter value is incremented by 1 at each call.
+
+ Args:
+ nbits (integer):
+ Length of the desired counter value, in bits. It must be a multiple of 8.
+ prefix (byte string):
+ The constant prefix of the counter block. By default, no prefix is
+ used.
+ suffix (byte string):
+ The constant postfix of the counter block. By default, no suffix is
+ used.
+ initial_value (integer):
+ The initial value of the counter. Default value is 1.
+ Its length in bits must not exceed the argument ``nbits``.
+ little_endian (boolean):
+ If ``True``, the counter number will be encoded in little endian format.
+ If ``False`` (default), in big endian format.
+ allow_wraparound (boolean):
+ This parameter is ignored.
+ Returns:
+ An object that can be passed with the :data:`counter` parameter to a CTR mode
+ cipher.
+
+ It must hold that *len(prefix) + nbits//8 + len(suffix)* matches the
+ block size of the underlying block cipher.
+ """
+
+ if (nbits % 8) != 0:
+ raise ValueError("'nbits' must be a multiple of 8")
+
+ iv_bl = initial_value.bit_length()
+ if iv_bl > nbits:
+ raise ValueError("Initial value takes %d bits but it is longer than "
+ "the counter (%d bits)" %
+ (iv_bl, nbits))
+
+ # Ignore wraparound
+ return {"counter_len": nbits // 8,
+ "prefix": prefix,
+ "suffix": suffix,
+ "initial_value": initial_value,
+ "little_endian": little_endian
+ }
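+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: the dictionary returned by new() is passed as
+    # the 'counter' argument of a CTR-mode cipher. AES and get_random_bytes are
+    # assumed to come from this same package.
+    from Crypto.Cipher import AES
+    from Crypto.Random import get_random_bytes
+
+    key = get_random_bytes(16)
+    nonce = get_random_bytes(8)
+    ctr = new(64, prefix=nonce)                  # 8-byte prefix + 8-byte counter = 16-byte AES block
+    cipher = AES.new(key, AES.MODE_CTR, counter=ctr)
+    ciphertext = cipher.encrypt(b"secret payload")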
diff --git a/lib/Crypto/Util/Counter.pyi b/lib/Crypto/Util/Counter.pyi
new file mode 100644
index 0000000..fa2ffdd
--- /dev/null
+++ b/lib/Crypto/Util/Counter.pyi
@@ -0,0 +1,5 @@
+from typing import Optional, Union, Dict
+
+def new(nbits: int, prefix: Optional[bytes]=..., suffix: Optional[bytes]=..., initial_value: Optional[int]=1,
+ little_endian: Optional[bool]=False, allow_wraparound: Optional[bool]=False) -> \
+ Dict[str, Union[int, bytes, bool]]: ...
diff --git a/lib/Crypto/Util/Padding.py b/lib/Crypto/Util/Padding.py
new file mode 100644
index 0000000..da69e55
--- /dev/null
+++ b/lib/Crypto/Util/Padding.py
@@ -0,0 +1,108 @@
+#
+# Util/Padding.py : Functions to manage padding
+#
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+__all__ = [ 'pad', 'unpad' ]
+
+from Crypto.Util.py3compat import *
+
+
+def pad(data_to_pad, block_size, style='pkcs7'):
+ """Apply standard padding.
+
+ Args:
+ data_to_pad (byte string):
+ The data that needs to be padded.
+ block_size (integer):
+ The block boundary to use for padding. The output length is guaranteed
+ to be a multiple of :data:`block_size`.
+ style (string):
+ Padding algorithm. It can be *'pkcs7'* (default), *'iso7816'* or *'x923'*.
+
+ Return:
+ byte string : the original data with the appropriate padding added at the end.
+ """
+
+ padding_len = block_size-len(data_to_pad)%block_size
+ if style == 'pkcs7':
+ padding = bchr(padding_len)*padding_len
+ elif style == 'x923':
+ padding = bchr(0)*(padding_len-1) + bchr(padding_len)
+ elif style == 'iso7816':
+ padding = bchr(128) + bchr(0)*(padding_len-1)
+ else:
+ raise ValueError("Unknown padding style")
+ return data_to_pad + padding
+
+
+def unpad(padded_data, block_size, style='pkcs7'):
+ """Remove standard padding.
+
+ Args:
+ padded_data (byte string):
+ A piece of data with padding that needs to be stripped.
+ block_size (integer):
+ The block boundary to use for padding. The input length
+ must be a multiple of :data:`block_size`.
+ style (string):
+ Padding algorithm. It can be *'pkcs7'* (default), *'iso7816'* or *'x923'*.
+ Return:
+ byte string : data without padding.
+ Raises:
+ ValueError: if the padding is incorrect.
+ """
+
+ pdata_len = len(padded_data)
+ if pdata_len == 0:
+ raise ValueError("Zero-length input cannot be unpadded")
+ if pdata_len % block_size:
+ raise ValueError("Input data is not padded")
+ if style in ('pkcs7', 'x923'):
+ padding_len = bord(padded_data[-1])
+ if padding_len<1 or padding_len>min(block_size, pdata_len):
+ raise ValueError("Padding is incorrect.")
+ if style == 'pkcs7':
+ if padded_data[-padding_len:]!=bchr(padding_len)*padding_len:
+ raise ValueError("PKCS#7 padding is incorrect.")
+ else:
+ if padded_data[-padding_len:-1]!=bchr(0)*(padding_len-1):
+ raise ValueError("ANSI X.923 padding is incorrect.")
+ elif style == 'iso7816':
+ padding_len = pdata_len - padded_data.rfind(bchr(128))
+ if padding_len<1 or padding_len>min(block_size, pdata_len):
+ raise ValueError("Padding is incorrect.")
+ if padding_len>1 and padded_data[1-padding_len:]!=bchr(0)*(padding_len-1):
+ raise ValueError("ISO 7816-4 padding is incorrect.")
+ else:
+ raise ValueError("Unknown padding style")
+ return padded_data[:-padding_len]
+
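+
+if __name__ == "__main__":
+    # Illustrative sketch only: a PKCS#7 pad/unpad round trip with a 16-byte block size,
+    # using just the two functions defined above.
+    padded = pad(b"hello", 16)
+    assert len(padded) % 16 == 0
+    assert unpad(padded, 16) == b"hello"
+    # ISO 7816-4 and ANSI X.923 work the same way via style='iso7816' / style='x923'.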
diff --git a/lib/Crypto/Util/Padding.pyi b/lib/Crypto/Util/Padding.pyi
new file mode 100644
index 0000000..4d8d30d
--- /dev/null
+++ b/lib/Crypto/Util/Padding.pyi
@@ -0,0 +1,6 @@
+from typing import Optional
+
+__all__ = [ 'pad', 'unpad' ]
+
+def pad(data_to_pad: bytes, block_size: int, style: Optional[str]='pkcs7') -> bytes: ...
+def unpad(padded_data: bytes, block_size: int, style: Optional[str]='pkcs7') -> bytes: ... \ No newline at end of file
diff --git a/lib/Crypto/Util/RFC1751.py b/lib/Crypto/Util/RFC1751.py
new file mode 100644
index 0000000..9ed52d2
--- /dev/null
+++ b/lib/Crypto/Util/RFC1751.py
@@ -0,0 +1,386 @@
+# rfc1751.py : Converts between 128-bit strings and a human-readable
+# sequence of words, as defined in RFC1751: "A Convention for
+# Human-Readable 128-bit Keys", by Daniel L. McDonald.
+#
+# Part of the Python Cryptography Toolkit
+#
+# Written by Andrew M. Kuchling and others
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+from __future__ import print_function
+
+import binascii
+
+from Crypto.Util.py3compat import bord, bchr
+
+binary = {0: '0000', 1: '0001', 2: '0010', 3: '0011', 4: '0100', 5: '0101',
+ 6: '0110', 7: '0111', 8: '1000', 9: '1001', 10: '1010', 11: '1011',
+ 12: '1100', 13: '1101', 14: '1110', 15: '1111'}
+
+
+def _key2bin(s):
+ "Convert a key into a string of binary digits"
+ kl = map(lambda x: bord(x), s)
+ kl = map(lambda x: binary[x >> 4] + binary[x & 15], kl)
+ return ''.join(kl)
+
+
+def _extract(key, start, length):
+ """Extract a bitstring(2.x)/bytestring(2.x) from a string of binary digits, and return its
+ numeric value."""
+
+ result = 0
+ for y in key[start:start+length]:
+ result = result * 2 + ord(y) - 48
+ return result
+
+
+def key_to_english(key):
+ """Transform an arbitrary key into a string containing English words.
+
+ Example::
+
+ >>> from Crypto.Util.RFC1751 import key_to_english
+ >>> key_to_english(b'66666666')
+ 'RAM LOIS GOAD CREW CARE HIT'
+
+ Args:
+ key (byte string):
+ The key to convert. Its length must be a multiple of 8.
+ Return:
+ A string of English words.
+ """
+
+ if len(key) % 8 != 0:
+ raise ValueError('The length of the key must be a multiple of 8.')
+
+ english = ''
+ for index in range(0, len(key), 8): # Loop over 8-byte subkeys
+ subkey = key[index:index + 8]
+ # Compute the parity of the key
+ skbin = _key2bin(subkey)
+ p = 0
+ for i in range(0, 64, 2):
+ p = p + _extract(skbin, i, 2)
+ # Append parity bits to the subkey
+ skbin = _key2bin(subkey + bchr((p << 6) & 255))
+ for i in range(0, 64, 11):
+ english = english + wordlist[_extract(skbin, i, 11)] + ' '
+
+ return english.strip()
+
+
+def english_to_key(s):
+ """Transform a string into a corresponding key.
+
+ Example::
+
+ >>> from Crypto.Util.RFC1751 import english_to_key
+ >>> english_to_key('RAM LOIS GOAD CREW CARE HIT')
+ b'66666666'
+
+ Args:
+ s (string): the string with the words separated by whitespace;
+ the number of words must be a multiple of 6.
+ Return:
+ A byte string.
+ """
+
+ L = s.upper().split()
+ key = b''
+ for index in range(0, len(L), 6):
+ sublist = L[index:index + 6]
+ char = 9 * [0]
+ bits = 0
+ for i in sublist:
+ index = wordlist.index(i)
+ shift = (8 - (bits + 11) % 8) % 8
+ y = index << shift
+ cl, cc, cr = (y >> 16), (y >> 8) & 0xff, y & 0xff
+ if (shift > 5):
+ char[bits >> 3] = char[bits >> 3] | cl
+ char[(bits >> 3) + 1] = char[(bits >> 3) + 1] | cc
+ char[(bits >> 3) + 2] = char[(bits >> 3) + 2] | cr
+ elif shift > -3:
+ char[bits >> 3] = char[bits >> 3] | cc
+ char[(bits >> 3) + 1] = char[(bits >> 3) + 1] | cr
+ else:
+ char[bits >> 3] = char[bits >> 3] | cr
+ bits = bits + 11
+
+ subkey = b''
+ for y in char:
+ subkey = subkey + bchr(y)
+
+ # Check the parity of the resulting key
+ skbin = _key2bin(subkey)
+ p = 0
+ for i in range(0, 64, 2):
+ p = p + _extract(skbin, i, 2)
+ if (p & 3) != _extract(skbin, 64, 2):
+ raise ValueError("Parity error in resulting key")
+ key = key + subkey[0:8]
+ return key
+
+
+wordlist = [
+ "A", "ABE", "ACE", "ACT", "AD", "ADA", "ADD",
+ "AGO", "AID", "AIM", "AIR", "ALL", "ALP", "AM", "AMY", "AN", "ANA",
+ "AND", "ANN", "ANT", "ANY", "APE", "APS", "APT", "ARC", "ARE", "ARK",
+ "ARM", "ART", "AS", "ASH", "ASK", "AT", "ATE", "AUG", "AUK", "AVE",
+ "AWE", "AWK", "AWL", "AWN", "AX", "AYE", "BAD", "BAG", "BAH", "BAM",
+ "BAN", "BAR", "BAT", "BAY", "BE", "BED", "BEE", "BEG", "BEN", "BET",
+ "BEY", "BIB", "BID", "BIG", "BIN", "BIT", "BOB", "BOG", "BON", "BOO",
+ "BOP", "BOW", "BOY", "BUB", "BUD", "BUG", "BUM", "BUN", "BUS", "BUT",
+ "BUY", "BY", "BYE", "CAB", "CAL", "CAM", "CAN", "CAP", "CAR", "CAT",
+ "CAW", "COD", "COG", "COL", "CON", "COO", "COP", "COT", "COW", "COY",
+ "CRY", "CUB", "CUE", "CUP", "CUR", "CUT", "DAB", "DAD", "DAM", "DAN",
+ "DAR", "DAY", "DEE", "DEL", "DEN", "DES", "DEW", "DID", "DIE", "DIG",
+ "DIN", "DIP", "DO", "DOE", "DOG", "DON", "DOT", "DOW", "DRY", "DUB",
+ "DUD", "DUE", "DUG", "DUN", "EAR", "EAT", "ED", "EEL", "EGG", "EGO",
+ "ELI", "ELK", "ELM", "ELY", "EM", "END", "EST", "ETC", "EVA", "EVE",
+ "EWE", "EYE", "FAD", "FAN", "FAR", "FAT", "FAY", "FED", "FEE", "FEW",
+ "FIB", "FIG", "FIN", "FIR", "FIT", "FLO", "FLY", "FOE", "FOG", "FOR",
+ "FRY", "FUM", "FUN", "FUR", "GAB", "GAD", "GAG", "GAL", "GAM", "GAP",
+ "GAS", "GAY", "GEE", "GEL", "GEM", "GET", "GIG", "GIL", "GIN", "GO",
+ "GOT", "GUM", "GUN", "GUS", "GUT", "GUY", "GYM", "GYP", "HA", "HAD",
+ "HAL", "HAM", "HAN", "HAP", "HAS", "HAT", "HAW", "HAY", "HE", "HEM",
+ "HEN", "HER", "HEW", "HEY", "HI", "HID", "HIM", "HIP", "HIS", "HIT",
+ "HO", "HOB", "HOC", "HOE", "HOG", "HOP", "HOT", "HOW", "HUB", "HUE",
+ "HUG", "HUH", "HUM", "HUT", "I", "ICY", "IDA", "IF", "IKE", "ILL",
+ "INK", "INN", "IO", "ION", "IQ", "IRA", "IRE", "IRK", "IS", "IT",
+ "ITS", "IVY", "JAB", "JAG", "JAM", "JAN", "JAR", "JAW", "JAY", "JET",
+ "JIG", "JIM", "JO", "JOB", "JOE", "JOG", "JOT", "JOY", "JUG", "JUT",
+ "KAY", "KEG", "KEN", "KEY", "KID", "KIM", "KIN", "KIT", "LA", "LAB",
+ "LAC", "LAD", "LAG", "LAM", "LAP", "LAW", "LAY", "LEA", "LED", "LEE",
+ "LEG", "LEN", "LEO", "LET", "LEW", "LID", "LIE", "LIN", "LIP", "LIT",
+ "LO", "LOB", "LOG", "LOP", "LOS", "LOT", "LOU", "LOW", "LOY", "LUG",
+ "LYE", "MA", "MAC", "MAD", "MAE", "MAN", "MAO", "MAP", "MAT", "MAW",
+ "MAY", "ME", "MEG", "MEL", "MEN", "MET", "MEW", "MID", "MIN", "MIT",
+ "MOB", "MOD", "MOE", "MOO", "MOP", "MOS", "MOT", "MOW", "MUD", "MUG",
+ "MUM", "MY", "NAB", "NAG", "NAN", "NAP", "NAT", "NAY", "NE", "NED",
+ "NEE", "NET", "NEW", "NIB", "NIL", "NIP", "NIT", "NO", "NOB", "NOD",
+ "NON", "NOR", "NOT", "NOV", "NOW", "NU", "NUN", "NUT", "O", "OAF",
+ "OAK", "OAR", "OAT", "ODD", "ODE", "OF", "OFF", "OFT", "OH", "OIL",
+ "OK", "OLD", "ON", "ONE", "OR", "ORB", "ORE", "ORR", "OS", "OTT",
+ "OUR", "OUT", "OVA", "OW", "OWE", "OWL", "OWN", "OX", "PA", "PAD",
+ "PAL", "PAM", "PAN", "PAP", "PAR", "PAT", "PAW", "PAY", "PEA", "PEG",
+ "PEN", "PEP", "PER", "PET", "PEW", "PHI", "PI", "PIE", "PIN", "PIT",
+ "PLY", "PO", "POD", "POE", "POP", "POT", "POW", "PRO", "PRY", "PUB",
+ "PUG", "PUN", "PUP", "PUT", "QUO", "RAG", "RAM", "RAN", "RAP", "RAT",
+ "RAW", "RAY", "REB", "RED", "REP", "RET", "RIB", "RID", "RIG", "RIM",
+ "RIO", "RIP", "ROB", "ROD", "ROE", "RON", "ROT", "ROW", "ROY", "RUB",
+ "RUE", "RUG", "RUM", "RUN", "RYE", "SAC", "SAD", "SAG", "SAL", "SAM",
+ "SAN", "SAP", "SAT", "SAW", "SAY", "SEA", "SEC", "SEE", "SEN", "SET",
+ "SEW", "SHE", "SHY", "SIN", "SIP", "SIR", "SIS", "SIT", "SKI", "SKY",
+ "SLY", "SO", "SOB", "SOD", "SON", "SOP", "SOW", "SOY", "SPA", "SPY",
+ "SUB", "SUD", "SUE", "SUM", "SUN", "SUP", "TAB", "TAD", "TAG", "TAN",
+ "TAP", "TAR", "TEA", "TED", "TEE", "TEN", "THE", "THY", "TIC", "TIE",
+ "TIM", "TIN", "TIP", "TO", "TOE", "TOG", "TOM", "TON", "TOO", "TOP",
+ "TOW", "TOY", "TRY", "TUB", "TUG", "TUM", "TUN", "TWO", "UN", "UP",
+ "US", "USE", "VAN", "VAT", "VET", "VIE", "WAD", "WAG", "WAR", "WAS",
+ "WAY", "WE", "WEB", "WED", "WEE", "WET", "WHO", "WHY", "WIN", "WIT",
+ "WOK", "WON", "WOO", "WOW", "WRY", "WU", "YAM", "YAP", "YAW", "YE",
+ "YEA", "YES", "YET", "YOU", "ABED", "ABEL", "ABET", "ABLE", "ABUT",
+ "ACHE", "ACID", "ACME", "ACRE", "ACTA", "ACTS", "ADAM", "ADDS",
+ "ADEN", "AFAR", "AFRO", "AGEE", "AHEM", "AHOY", "AIDA", "AIDE",
+ "AIDS", "AIRY", "AJAR", "AKIN", "ALAN", "ALEC", "ALGA", "ALIA",
+ "ALLY", "ALMA", "ALOE", "ALSO", "ALTO", "ALUM", "ALVA", "AMEN",
+ "AMES", "AMID", "AMMO", "AMOK", "AMOS", "AMRA", "ANDY", "ANEW",
+ "ANNA", "ANNE", "ANTE", "ANTI", "AQUA", "ARAB", "ARCH", "AREA",
+ "ARGO", "ARID", "ARMY", "ARTS", "ARTY", "ASIA", "ASKS", "ATOM",
+ "AUNT", "AURA", "AUTO", "AVER", "AVID", "AVIS", "AVON", "AVOW",
+ "AWAY", "AWRY", "BABE", "BABY", "BACH", "BACK", "BADE", "BAIL",
+ "BAIT", "BAKE", "BALD", "BALE", "BALI", "BALK", "BALL", "BALM",
+ "BAND", "BANE", "BANG", "BANK", "BARB", "BARD", "BARE", "BARK",
+ "BARN", "BARR", "BASE", "BASH", "BASK", "BASS", "BATE", "BATH",
+ "BAWD", "BAWL", "BEAD", "BEAK", "BEAM", "BEAN", "BEAR", "BEAT",
+ "BEAU", "BECK", "BEEF", "BEEN", "BEER",
+ "BEET", "BELA", "BELL", "BELT", "BEND", "BENT", "BERG", "BERN",
+ "BERT", "BESS", "BEST", "BETA", "BETH", "BHOY", "BIAS", "BIDE",
+ "BIEN", "BILE", "BILK", "BILL", "BIND", "BING", "BIRD", "BITE",
+ "BITS", "BLAB", "BLAT", "BLED", "BLEW", "BLOB", "BLOC", "BLOT",
+ "BLOW", "BLUE", "BLUM", "BLUR", "BOAR", "BOAT", "BOCA", "BOCK",
+ "BODE", "BODY", "BOGY", "BOHR", "BOIL", "BOLD", "BOLO", "BOLT",
+ "BOMB", "BONA", "BOND", "BONE", "BONG", "BONN", "BONY", "BOOK",
+ "BOOM", "BOON", "BOOT", "BORE", "BORG", "BORN", "BOSE", "BOSS",
+ "BOTH", "BOUT", "BOWL", "BOYD", "BRAD", "BRAE", "BRAG", "BRAN",
+ "BRAY", "BRED", "BREW", "BRIG", "BRIM", "BROW", "BUCK", "BUDD",
+ "BUFF", "BULB", "BULK", "BULL", "BUNK", "BUNT", "BUOY", "BURG",
+ "BURL", "BURN", "BURR", "BURT", "BURY", "BUSH", "BUSS", "BUST",
+ "BUSY", "BYTE", "CADY", "CAFE", "CAGE", "CAIN", "CAKE", "CALF",
+ "CALL", "CALM", "CAME", "CANE", "CANT", "CARD", "CARE", "CARL",
+ "CARR", "CART", "CASE", "CASH", "CASK", "CAST", "CAVE", "CEIL",
+ "CELL", "CENT", "CERN", "CHAD", "CHAR", "CHAT", "CHAW", "CHEF",
+ "CHEN", "CHEW", "CHIC", "CHIN", "CHOU", "CHOW", "CHUB", "CHUG",
+ "CHUM", "CITE", "CITY", "CLAD", "CLAM", "CLAN", "CLAW", "CLAY",
+ "CLOD", "CLOG", "CLOT", "CLUB", "CLUE", "COAL", "COAT", "COCA",
+ "COCK", "COCO", "CODA", "CODE", "CODY", "COED", "COIL", "COIN",
+ "COKE", "COLA", "COLD", "COLT", "COMA", "COMB", "COME", "COOK",
+ "COOL", "COON", "COOT", "CORD", "CORE", "CORK", "CORN", "COST",
+ "COVE", "COWL", "CRAB", "CRAG", "CRAM", "CRAY", "CREW", "CRIB",
+ "CROW", "CRUD", "CUBA", "CUBE", "CUFF", "CULL", "CULT", "CUNY",
+ "CURB", "CURD", "CURE", "CURL", "CURT", "CUTS", "DADE", "DALE",
+ "DAME", "DANA", "DANE", "DANG", "DANK", "DARE", "DARK", "DARN",
+ "DART", "DASH", "DATA", "DATE", "DAVE", "DAVY", "DAWN", "DAYS",
+ "DEAD", "DEAF", "DEAL", "DEAN", "DEAR", "DEBT", "DECK", "DEED",
+ "DEEM", "DEER", "DEFT", "DEFY", "DELL", "DENT", "DENY", "DESK",
+ "DIAL", "DICE", "DIED", "DIET", "DIME", "DINE", "DING", "DINT",
+ "DIRE", "DIRT", "DISC", "DISH", "DISK", "DIVE", "DOCK", "DOES",
+ "DOLE", "DOLL", "DOLT", "DOME", "DONE", "DOOM", "DOOR", "DORA",
+ "DOSE", "DOTE", "DOUG", "DOUR", "DOVE", "DOWN", "DRAB", "DRAG",
+ "DRAM", "DRAW", "DREW", "DRUB", "DRUG", "DRUM", "DUAL", "DUCK",
+ "DUCT", "DUEL", "DUET", "DUKE", "DULL", "DUMB", "DUNE", "DUNK",
+ "DUSK", "DUST", "DUTY", "EACH", "EARL", "EARN", "EASE", "EAST",
+ "EASY", "EBEN", "ECHO", "EDDY", "EDEN", "EDGE", "EDGY", "EDIT",
+ "EDNA", "EGAN", "ELAN", "ELBA", "ELLA", "ELSE", "EMIL", "EMIT",
+ "EMMA", "ENDS", "ERIC", "EROS", "EVEN", "EVER", "EVIL", "EYED",
+ "FACE", "FACT", "FADE", "FAIL", "FAIN", "FAIR", "FAKE", "FALL",
+ "FAME", "FANG", "FARM", "FAST", "FATE", "FAWN", "FEAR", "FEAT",
+ "FEED", "FEEL", "FEET", "FELL", "FELT", "FEND", "FERN", "FEST",
+ "FEUD", "FIEF", "FIGS", "FILE", "FILL", "FILM", "FIND", "FINE",
+ "FINK", "FIRE", "FIRM", "FISH", "FISK", "FIST", "FITS", "FIVE",
+ "FLAG", "FLAK", "FLAM", "FLAT", "FLAW", "FLEA", "FLED", "FLEW",
+ "FLIT", "FLOC", "FLOG", "FLOW", "FLUB", "FLUE", "FOAL", "FOAM",
+ "FOGY", "FOIL", "FOLD", "FOLK", "FOND", "FONT", "FOOD", "FOOL",
+ "FOOT", "FORD", "FORE", "FORK", "FORM", "FORT", "FOSS", "FOUL",
+ "FOUR", "FOWL", "FRAU", "FRAY", "FRED", "FREE", "FRET", "FREY",
+ "FROG", "FROM", "FUEL", "FULL", "FUME", "FUND", "FUNK", "FURY",
+ "FUSE", "FUSS", "GAFF", "GAGE", "GAIL", "GAIN", "GAIT", "GALA",
+ "GALE", "GALL", "GALT", "GAME", "GANG", "GARB", "GARY", "GASH",
+ "GATE", "GAUL", "GAUR", "GAVE", "GAWK", "GEAR", "GELD", "GENE",
+ "GENT", "GERM", "GETS", "GIBE", "GIFT", "GILD", "GILL", "GILT",
+ "GINA", "GIRD", "GIRL", "GIST", "GIVE", "GLAD", "GLEE", "GLEN",
+ "GLIB", "GLOB", "GLOM", "GLOW", "GLUE", "GLUM", "GLUT", "GOAD",
+ "GOAL", "GOAT", "GOER", "GOES", "GOLD", "GOLF", "GONE", "GONG",
+ "GOOD", "GOOF", "GORE", "GORY", "GOSH", "GOUT", "GOWN", "GRAB",
+ "GRAD", "GRAY", "GREG", "GREW", "GREY", "GRID", "GRIM", "GRIN",
+ "GRIT", "GROW", "GRUB", "GULF", "GULL", "GUNK", "GURU", "GUSH",
+ "GUST", "GWEN", "GWYN", "HAAG", "HAAS", "HACK", "HAIL", "HAIR",
+ "HALE", "HALF", "HALL", "HALO", "HALT", "HAND", "HANG", "HANK",
+ "HANS", "HARD", "HARK", "HARM", "HART", "HASH", "HAST", "HATE",
+ "HATH", "HAUL", "HAVE", "HAWK", "HAYS", "HEAD", "HEAL", "HEAR",
+ "HEAT", "HEBE", "HECK", "HEED", "HEEL", "HEFT", "HELD", "HELL",
+ "HELM", "HERB", "HERD", "HERE", "HERO", "HERS", "HESS", "HEWN",
+ "HICK", "HIDE", "HIGH", "HIKE", "HILL", "HILT", "HIND", "HINT",
+ "HIRE", "HISS", "HIVE", "HOBO", "HOCK", "HOFF", "HOLD", "HOLE",
+ "HOLM", "HOLT", "HOME", "HONE", "HONK", "HOOD", "HOOF", "HOOK",
+ "HOOT", "HORN", "HOSE", "HOST", "HOUR", "HOVE", "HOWE", "HOWL",
+ "HOYT", "HUCK", "HUED", "HUFF", "HUGE", "HUGH", "HUGO", "HULK",
+ "HULL", "HUNK", "HUNT", "HURD", "HURL", "HURT", "HUSH", "HYDE",
+ "HYMN", "IBIS", "ICON", "IDEA", "IDLE", "IFFY", "INCA", "INCH",
+ "INTO", "IONS", "IOTA", "IOWA", "IRIS", "IRMA", "IRON", "ISLE",
+ "ITCH", "ITEM", "IVAN", "JACK", "JADE", "JAIL", "JAKE", "JANE",
+ "JAVA", "JEAN", "JEFF", "JERK", "JESS", "JEST", "JIBE", "JILL",
+ "JILT", "JIVE", "JOAN", "JOBS", "JOCK", "JOEL", "JOEY", "JOHN",
+ "JOIN", "JOKE", "JOLT", "JOVE", "JUDD", "JUDE", "JUDO", "JUDY",
+ "JUJU", "JUKE", "JULY", "JUNE", "JUNK", "JUNO", "JURY", "JUST",
+ "JUTE", "KAHN", "KALE", "KANE", "KANT", "KARL", "KATE", "KEEL",
+ "KEEN", "KENO", "KENT", "KERN", "KERR", "KEYS", "KICK", "KILL",
+ "KIND", "KING", "KIRK", "KISS", "KITE", "KLAN", "KNEE", "KNEW",
+ "KNIT", "KNOB", "KNOT", "KNOW", "KOCH", "KONG", "KUDO", "KURD",
+ "KURT", "KYLE", "LACE", "LACK", "LACY", "LADY", "LAID", "LAIN",
+ "LAIR", "LAKE", "LAMB", "LAME", "LAND", "LANE", "LANG", "LARD",
+ "LARK", "LASS", "LAST", "LATE", "LAUD", "LAVA", "LAWN", "LAWS",
+ "LAYS", "LEAD", "LEAF", "LEAK", "LEAN", "LEAR", "LEEK", "LEER",
+ "LEFT", "LEND", "LENS", "LENT", "LEON", "LESK", "LESS", "LEST",
+ "LETS", "LIAR", "LICE", "LICK", "LIED", "LIEN", "LIES", "LIEU",
+ "LIFE", "LIFT", "LIKE", "LILA", "LILT", "LILY", "LIMA", "LIMB",
+ "LIME", "LIND", "LINE", "LINK", "LINT", "LION", "LISA", "LIST",
+ "LIVE", "LOAD", "LOAF", "LOAM", "LOAN", "LOCK", "LOFT", "LOGE",
+ "LOIS", "LOLA", "LONE", "LONG", "LOOK", "LOON", "LOOT", "LORD",
+ "LORE", "LOSE", "LOSS", "LOST", "LOUD", "LOVE", "LOWE", "LUCK",
+ "LUCY", "LUGE", "LUKE", "LULU", "LUND", "LUNG", "LURA", "LURE",
+ "LURK", "LUSH", "LUST", "LYLE", "LYNN", "LYON", "LYRA", "MACE",
+ "MADE", "MAGI", "MAID", "MAIL", "MAIN", "MAKE", "MALE", "MALI",
+ "MALL", "MALT", "MANA", "MANN", "MANY", "MARC", "MARE", "MARK",
+ "MARS", "MART", "MARY", "MASH", "MASK", "MASS", "MAST", "MATE",
+ "MATH", "MAUL", "MAYO", "MEAD", "MEAL", "MEAN", "MEAT", "MEEK",
+ "MEET", "MELD", "MELT", "MEMO", "MEND", "MENU", "MERT", "MESH",
+ "MESS", "MICE", "MIKE", "MILD", "MILE", "MILK", "MILL", "MILT",
+ "MIMI", "MIND", "MINE", "MINI", "MINK", "MINT", "MIRE", "MISS",
+ "MIST", "MITE", "MITT", "MOAN", "MOAT", "MOCK", "MODE", "MOLD",
+ "MOLE", "MOLL", "MOLT", "MONA", "MONK", "MONT", "MOOD", "MOON",
+ "MOOR", "MOOT", "MORE", "MORN", "MORT", "MOSS", "MOST", "MOTH",
+ "MOVE", "MUCH", "MUCK", "MUDD", "MUFF", "MULE", "MULL", "MURK",
+ "MUSH", "MUST", "MUTE", "MUTT", "MYRA", "MYTH", "NAGY", "NAIL",
+ "NAIR", "NAME", "NARY", "NASH", "NAVE", "NAVY", "NEAL", "NEAR",
+ "NEAT", "NECK", "NEED", "NEIL", "NELL", "NEON", "NERO", "NESS",
+ "NEST", "NEWS", "NEWT", "NIBS", "NICE", "NICK", "NILE", "NINA",
+ "NINE", "NOAH", "NODE", "NOEL", "NOLL", "NONE", "NOOK", "NOON",
+ "NORM", "NOSE", "NOTE", "NOUN", "NOVA", "NUDE", "NULL", "NUMB",
+ "OATH", "OBEY", "OBOE", "ODIN", "OHIO", "OILY", "OINT", "OKAY",
+ "OLAF", "OLDY", "OLGA", "OLIN", "OMAN", "OMEN", "OMIT", "ONCE",
+ "ONES", "ONLY", "ONTO", "ONUS", "ORAL", "ORGY", "OSLO", "OTIS",
+ "OTTO", "OUCH", "OUST", "OUTS", "OVAL", "OVEN", "OVER", "OWLY",
+ "OWNS", "QUAD", "QUIT", "QUOD", "RACE", "RACK", "RACY", "RAFT",
+ "RAGE", "RAID", "RAIL", "RAIN", "RAKE", "RANK", "RANT", "RARE",
+ "RASH", "RATE", "RAVE", "RAYS", "READ", "REAL", "REAM", "REAR",
+ "RECK", "REED", "REEF", "REEK", "REEL", "REID", "REIN", "RENA",
+ "REND", "RENT", "REST", "RICE", "RICH", "RICK", "RIDE", "RIFT",
+ "RILL", "RIME", "RING", "RINK", "RISE", "RISK", "RITE", "ROAD",
+ "ROAM", "ROAR", "ROBE", "ROCK", "RODE", "ROIL", "ROLL", "ROME",
+ "ROOD", "ROOF", "ROOK", "ROOM", "ROOT", "ROSA", "ROSE", "ROSS",
+ "ROSY", "ROTH", "ROUT", "ROVE", "ROWE", "ROWS", "RUBE", "RUBY",
+ "RUDE", "RUDY", "RUIN", "RULE", "RUNG", "RUNS", "RUNT", "RUSE",
+ "RUSH", "RUSK", "RUSS", "RUST", "RUTH", "SACK", "SAFE", "SAGE",
+ "SAID", "SAIL", "SALE", "SALK", "SALT", "SAME", "SAND", "SANE",
+ "SANG", "SANK", "SARA", "SAUL", "SAVE", "SAYS", "SCAN", "SCAR",
+ "SCAT", "SCOT", "SEAL", "SEAM", "SEAR", "SEAT", "SEED", "SEEK",
+ "SEEM", "SEEN", "SEES", "SELF", "SELL", "SEND", "SENT", "SETS",
+ "SEWN", "SHAG", "SHAM", "SHAW", "SHAY", "SHED", "SHIM", "SHIN",
+ "SHOD", "SHOE", "SHOT", "SHOW", "SHUN", "SHUT", "SICK", "SIDE",
+ "SIFT", "SIGH", "SIGN", "SILK", "SILL", "SILO", "SILT", "SINE",
+ "SING", "SINK", "SIRE", "SITE", "SITS", "SITU", "SKAT", "SKEW",
+ "SKID", "SKIM", "SKIN", "SKIT", "SLAB", "SLAM", "SLAT", "SLAY",
+ "SLED", "SLEW", "SLID", "SLIM", "SLIT", "SLOB", "SLOG", "SLOT",
+ "SLOW", "SLUG", "SLUM", "SLUR", "SMOG", "SMUG", "SNAG", "SNOB",
+ "SNOW", "SNUB", "SNUG", "SOAK", "SOAR", "SOCK", "SODA", "SOFA",
+ "SOFT", "SOIL", "SOLD", "SOME", "SONG", "SOON", "SOOT", "SORE",
+ "SORT", "SOUL", "SOUR", "SOWN", "STAB", "STAG", "STAN", "STAR",
+ "STAY", "STEM", "STEW", "STIR", "STOW", "STUB", "STUN", "SUCH",
+ "SUDS", "SUIT", "SULK", "SUMS", "SUNG", "SUNK", "SURE", "SURF",
+ "SWAB", "SWAG", "SWAM", "SWAN", "SWAT", "SWAY", "SWIM", "SWUM",
+ "TACK", "TACT", "TAIL", "TAKE", "TALE", "TALK", "TALL", "TANK",
+ "TASK", "TATE", "TAUT", "TEAL", "TEAM", "TEAR", "TECH", "TEEM",
+ "TEEN", "TEET", "TELL", "TEND", "TENT", "TERM", "TERN", "TESS",
+ "TEST", "THAN", "THAT", "THEE", "THEM", "THEN", "THEY", "THIN",
+ "THIS", "THUD", "THUG", "TICK", "TIDE", "TIDY", "TIED", "TIER",
+ "TILE", "TILL", "TILT", "TIME", "TINA", "TINE", "TINT", "TINY",
+ "TIRE", "TOAD", "TOGO", "TOIL", "TOLD", "TOLL", "TONE", "TONG",
+ "TONY", "TOOK", "TOOL", "TOOT", "TORE", "TORN", "TOTE", "TOUR",
+ "TOUT", "TOWN", "TRAG", "TRAM", "TRAY", "TREE", "TREK", "TRIG",
+ "TRIM", "TRIO", "TROD", "TROT", "TROY", "TRUE", "TUBA", "TUBE",
+ "TUCK", "TUFT", "TUNA", "TUNE", "TUNG", "TURF", "TURN", "TUSK",
+ "TWIG", "TWIN", "TWIT", "ULAN", "UNIT", "URGE", "USED", "USER",
+ "USES", "UTAH", "VAIL", "VAIN", "VALE", "VARY", "VASE", "VAST",
+ "VEAL", "VEDA", "VEIL", "VEIN", "VEND", "VENT", "VERB", "VERY",
+ "VETO", "VICE", "VIEW", "VINE", "VISE", "VOID", "VOLT", "VOTE",
+ "WACK", "WADE", "WAGE", "WAIL", "WAIT", "WAKE", "WALE", "WALK",
+ "WALL", "WALT", "WAND", "WANE", "WANG", "WANT", "WARD", "WARM",
+ "WARN", "WART", "WASH", "WAST", "WATS", "WATT", "WAVE", "WAVY",
+ "WAYS", "WEAK", "WEAL", "WEAN", "WEAR", "WEED", "WEEK", "WEIR",
+ "WELD", "WELL", "WELT", "WENT", "WERE", "WERT", "WEST", "WHAM",
+ "WHAT", "WHEE", "WHEN", "WHET", "WHOA", "WHOM", "WICK", "WIFE",
+ "WILD", "WILL", "WIND", "WINE", "WING", "WINK", "WINO", "WIRE",
+ "WISE", "WISH", "WITH", "WOLF", "WONT", "WOOD", "WOOL", "WORD",
+ "WORE", "WORK", "WORM", "WORN", "WOVE", "WRIT", "WYNN", "YALE",
+ "YANG", "YANK", "YARD", "YARN", "YAWL", "YAWN", "YEAH", "YEAR",
+ "YELL", "YOGA", "YOKE" ]
diff --git a/lib/Crypto/Util/RFC1751.pyi b/lib/Crypto/Util/RFC1751.pyi
new file mode 100644
index 0000000..6ad07ff
--- /dev/null
+++ b/lib/Crypto/Util/RFC1751.pyi
@@ -0,0 +1,7 @@
+from typing import Dict, List
+
+binary: Dict[int, str]
+wordlist: List[str]
+
+def key_to_english(key: bytes) -> str: ...
+def english_to_key(s: str) -> bytes: ...
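The two helpers declared above convert between binary keys and phrases built from the word list that closes RFC1751.py. A minimal usage sketch, assuming the bundled lib/ directory is importable; the key value is only an example:

from Crypto.Util.RFC1751 import key_to_english, english_to_key

key = bytes.fromhex("ccac2aed591056be4f90fd441c534766")  # arbitrary 128-bit example key
words = key_to_english(key)          # a phrase of short words drawn from the list above
assert english_to_key(words) == key  # the conversion is reversible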
diff --git a/lib/Crypto/Util/__init__.py b/lib/Crypto/Util/__init__.py
new file mode 100644
index 0000000..f12214d
--- /dev/null
+++ b/lib/Crypto/Util/__init__.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Miscellaneous modules
+
+Contains useful modules that don't belong in any of the
+other Crypto.* subpackages.
+
+======================== =============================================
+Module Description
+======================== =============================================
+`Crypto.Util.number` Number-theoretic functions (primality testing, etc.)
+`Crypto.Util.Counter` Fast counter functions for CTR cipher modes.
+`Crypto.Util.RFC1751` Converts between 128-bit keys and human-readable
+ strings of words.
+`Crypto.Util.asn1` Minimal support for ASN.1 DER encoding
+`Crypto.Util.Padding` Set of functions for adding and removing padding.
+======================== =============================================
+
+:undocumented: _galois, _number_new, cpuid, py3compat, _raw_api
+"""
+
+__all__ = ['RFC1751', 'number', 'strxor', 'asn1', 'Counter', 'Padding']
+
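The modules listed in the docstring above are added by hunks later in this commit. As a small, hedged illustration of the Crypto.Util.number helpers referenced there (function names taken from the number.py hunk further down; the bit size is arbitrary):

from Crypto.Util.number import getPrime, size, GCD, inverse

p = getPrime(256)                     # random 256-bit prime (uses Crypto.Random internally)
assert size(p) == 256                 # bit length, as computed by size()
assert GCD(p, 2) == 1                 # an odd prime is coprime to 2
assert (3 * inverse(3, p)) % p == 1   # inverse() returns the modular inverse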
diff --git a/lib/Crypto/Util/_cpu_features.py b/lib/Crypto/Util/_cpu_features.py
new file mode 100644
index 0000000..b3039b5
--- /dev/null
+++ b/lib/Crypto/Util/_cpu_features.py
@@ -0,0 +1,46 @@
+# ===================================================================
+#
+# Copyright (c) 2018, Helder Eijs <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util._raw_api import load_pycryptodome_raw_lib
+
+
+_raw_cpuid_lib = load_pycryptodome_raw_lib("Crypto.Util._cpuid_c",
+ """
+ int have_aes_ni(void);
+ int have_clmul(void);
+ """)
+
+
+def have_aes_ni():
+ return _raw_cpuid_lib.have_aes_ni()
+
+
+def have_clmul():
+ return _raw_cpuid_lib.have_clmul()
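A quick way to exercise the two feature probes above once the package is on sys.path; this only works where the bundled _cpuid_c.abi3.so below can be loaded (sketch only):

from Crypto.Util._cpu_features import have_aes_ni, have_clmul

# Both helpers return the C int produced by the native probe, not a bool.
print("AES-NI:", bool(have_aes_ni()))
print("CLMUL:", bool(have_clmul()))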
diff --git a/lib/Crypto/Util/_cpu_features.pyi b/lib/Crypto/Util/_cpu_features.pyi
new file mode 100644
index 0000000..10e669e
--- /dev/null
+++ b/lib/Crypto/Util/_cpu_features.pyi
@@ -0,0 +1,2 @@
+def have_aes_ni() -> int: ...
+def have_clmul() -> int: ...
diff --git a/lib/Crypto/Util/_cpuid_c.abi3.so b/lib/Crypto/Util/_cpuid_c.abi3.so
new file mode 100755
index 0000000..6665a81
--- /dev/null
+++ b/lib/Crypto/Util/_cpuid_c.abi3.so
Binary files differ
diff --git a/lib/Crypto/Util/_file_system.py b/lib/Crypto/Util/_file_system.py
new file mode 100644
index 0000000..1cb0c4b
--- /dev/null
+++ b/lib/Crypto/Util/_file_system.py
@@ -0,0 +1,54 @@
+# ===================================================================
+#
+# Copyright (c) 2016, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import os
+
+
+def pycryptodome_filename(dir_comps, filename):
+ """Return the complete file name for the module
+
+ dir_comps : list of string
+ The list of directory names in the PyCryptodome package.
+ The first element must be "Crypto".
+
+ filename : string
+ The filename (including extension) in the target directory.
+ """
+
+ if dir_comps[0] != "Crypto":
+ raise ValueError("Only available for modules under 'Crypto'")
+
+ dir_comps = list(dir_comps[1:]) + [filename]
+
+ util_lib, _ = os.path.split(os.path.abspath(__file__))
+ root_lib = os.path.join(util_lib, "..")
+
+ return os.path.join(root_lib, *dir_comps)
+
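For illustration, pycryptodome_filename() resolves a path relative to the installed Crypto package; the shared object named here is one of the binaries added elsewhere in this commit (a sketch, not part of the patch):

import os
from Crypto.Util._file_system import pycryptodome_filename

path = pycryptodome_filename(["Crypto", "Cipher"], "_raw_aes.abi3.so")
print(path, os.path.isfile(path))
# A dir_comps list that does not start with "Crypto" raises ValueError.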
diff --git a/lib/Crypto/Util/_file_system.pyi b/lib/Crypto/Util/_file_system.pyi
new file mode 100644
index 0000000..d54a126
--- /dev/null
+++ b/lib/Crypto/Util/_file_system.pyi
@@ -0,0 +1,4 @@
+from typing import List
+
+
+def pycryptodome_filename(dir_comps: List[str], filename: str) -> str: ... \ No newline at end of file
diff --git a/lib/Crypto/Util/_raw_api.py b/lib/Crypto/Util/_raw_api.py
new file mode 100644
index 0000000..f026e8f
--- /dev/null
+++ b/lib/Crypto/Util/_raw_api.py
@@ -0,0 +1,319 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+import os
+import abc
+import sys
+from Crypto.Util.py3compat import byte_string
+from Crypto.Util._file_system import pycryptodome_filename
+
+#
+# List of file suffixes for Python extensions
+#
+if sys.version_info[0] < 3:
+
+ import imp
+ extension_suffixes = []
+ for ext, mod, typ in imp.get_suffixes():
+ if typ == imp.C_EXTENSION:
+ extension_suffixes.append(ext)
+
+else:
+
+ from importlib import machinery
+ extension_suffixes = machinery.EXTENSION_SUFFIXES
+
+# Which types with buffer interface we support (apart from byte strings)
+_buffer_type = (bytearray, memoryview)
+
+
+class _VoidPointer(object):
+ @abc.abstractmethod
+ def get(self):
+ """Return the memory location we point to"""
+ return
+
+ @abc.abstractmethod
+ def address_of(self):
+ """Return a raw pointer to this pointer"""
+ return
+
+
+try:
+ # Starting from v2.18, pycparser (used by cffi for in-line ABI mode)
+ # stops working correctly when PYTHONOPTIMIZE==2 or the parameter -OO is
+ # passed. In that case, we fall back to ctypes.
+ # Note that PyPy ships with an old version of pycparser so we can keep
+ # using cffi there.
+ # See https://github.com/Legrandin/pycryptodome/issues/228
+ if '__pypy__' not in sys.builtin_module_names and sys.flags.optimize == 2:
+ raise ImportError("CFFI with optimize=2 fails due to pycparser bug.")
+
+ from cffi import FFI
+
+ ffi = FFI()
+ null_pointer = ffi.NULL
+ uint8_t_type = ffi.typeof(ffi.new("const uint8_t*"))
+
+ _Array = ffi.new("uint8_t[1]").__class__.__bases__
+
+ def load_lib(name, cdecl):
+ """Load a shared library and return a handle to it.
+
+ @name, either an absolute path or the name of a library
+ in the system search path.
+
+ @cdecl, the C function declarations.
+ """
+
+ if hasattr(ffi, "RTLD_DEEPBIND") and not os.getenv('PYCRYPTODOME_DISABLE_DEEPBIND'):
+ lib = ffi.dlopen(name, ffi.RTLD_DEEPBIND)
+ else:
+ lib = ffi.dlopen(name)
+ ffi.cdef(cdecl)
+ return lib
+
+ def c_ulong(x):
+ """Convert a Python integer to unsigned long"""
+ return x
+
+ c_ulonglong = c_ulong
+ c_uint = c_ulong
+ c_ubyte = c_ulong
+
+ def c_size_t(x):
+ """Convert a Python integer to size_t"""
+ return x
+
+ def create_string_buffer(init_or_size, size=None):
+ """Allocate the given amount of bytes (initially set to 0)"""
+
+ if isinstance(init_or_size, bytes):
+ size = max(len(init_or_size) + 1, size)
+ result = ffi.new("uint8_t[]", size)
+ result[:] = init_or_size
+ else:
+ if size:
+ raise ValueError("Size must be specified once only")
+ result = ffi.new("uint8_t[]", init_or_size)
+ return result
+
+ def get_c_string(c_string):
+ """Convert a C string into a Python byte sequence"""
+ return ffi.string(c_string)
+
+ def get_raw_buffer(buf):
+ """Convert a C buffer into a Python byte sequence"""
+ return ffi.buffer(buf)[:]
+
+ def c_uint8_ptr(data):
+ if isinstance(data, _buffer_type):
+ # This only works for cffi >= 1.7
+ return ffi.cast(uint8_t_type, ffi.from_buffer(data))
+ elif byte_string(data) or isinstance(data, _Array):
+ return data
+ else:
+ raise TypeError("Object type %s cannot be passed to C code" % type(data))
+
+ class VoidPointer_cffi(_VoidPointer):
+ """Model a newly allocated pointer to void"""
+
+ def __init__(self):
+ self._pp = ffi.new("void *[1]")
+
+ def get(self):
+ return self._pp[0]
+
+ def address_of(self):
+ return self._pp
+
+ def VoidPointer():
+ return VoidPointer_cffi()
+
+ backend = "cffi"
+
+except ImportError:
+
+ import ctypes
+ from ctypes import (CDLL, c_void_p, byref, c_ulong, c_ulonglong, c_size_t,
+ create_string_buffer, c_ubyte, c_uint)
+ from ctypes.util import find_library
+ from ctypes import Array as _Array
+
+ null_pointer = None
+ cached_architecture = []
+
+ def c_ubyte(c):
+ if not (0 <= c < 256):
+ raise OverflowError()
+ return ctypes.c_ubyte(c)
+
+ def load_lib(name, cdecl):
+ if not cached_architecture:
+ # platform.architecture() creates a subprocess, so caching the
+ # result makes successive imports faster.
+ import platform
+ cached_architecture[:] = platform.architecture()
+ bits, linkage = cached_architecture
+ if "." not in name and not linkage.startswith("Win"):
+ full_name = find_library(name)
+ if full_name is None:
+ raise OSError("Cannot load library '%s'" % name)
+ name = full_name
+ return CDLL(name)
+
+ def get_c_string(c_string):
+ return c_string.value
+
+ def get_raw_buffer(buf):
+ return buf.raw
+
+ # ---- Get raw pointer ---
+
+ _c_ssize_t = ctypes.c_ssize_t
+
+ _PyBUF_SIMPLE = 0
+ _PyObject_GetBuffer = ctypes.pythonapi.PyObject_GetBuffer
+ _PyBuffer_Release = ctypes.pythonapi.PyBuffer_Release
+ _py_object = ctypes.py_object
+ _c_ssize_p = ctypes.POINTER(_c_ssize_t)
+
+ # See Include/object.h for CPython
+ # and https://github.com/pallets/click/blob/master/src/click/_winconsole.py
+ class _Py_buffer(ctypes.Structure):
+ _fields_ = [
+ ('buf', c_void_p),
+ ('obj', ctypes.py_object),
+ ('len', _c_ssize_t),
+ ('itemsize', _c_ssize_t),
+ ('readonly', ctypes.c_int),
+ ('ndim', ctypes.c_int),
+ ('format', ctypes.c_char_p),
+ ('shape', _c_ssize_p),
+ ('strides', _c_ssize_p),
+ ('suboffsets', _c_ssize_p),
+ ('internal', c_void_p)
+ ]
+
+ # Extra field for CPython 2.6/2.7
+ if sys.version_info[0] == 2:
+ _fields_.insert(-1, ('smalltable', _c_ssize_t * 2))
+
+ def c_uint8_ptr(data):
+ if byte_string(data) or isinstance(data, _Array):
+ return data
+ elif isinstance(data, _buffer_type):
+ obj = _py_object(data)
+ buf = _Py_buffer()
+ _PyObject_GetBuffer(obj, byref(buf), _PyBUF_SIMPLE)
+ try:
+ buffer_type = ctypes.c_ubyte * buf.len
+ return buffer_type.from_address(buf.buf)
+ finally:
+ _PyBuffer_Release(byref(buf))
+ else:
+ raise TypeError("Object type %s cannot be passed to C code" % type(data))
+
+ # ---
+
+ class VoidPointer_ctypes(_VoidPointer):
+ """Model a newly allocated pointer to void"""
+
+ def __init__(self):
+ self._p = c_void_p()
+
+ def get(self):
+ return self._p
+
+ def address_of(self):
+ return byref(self._p)
+
+ def VoidPointer():
+ return VoidPointer_ctypes()
+
+ backend = "ctypes"
+
+
+class SmartPointer(object):
+ """Class to hold a non-managed piece of memory"""
+
+ def __init__(self, raw_pointer, destructor):
+ self._raw_pointer = raw_pointer
+ self._destructor = destructor
+
+ def get(self):
+ return self._raw_pointer
+
+ def release(self):
+ rp, self._raw_pointer = self._raw_pointer, None
+ return rp
+
+ def __del__(self):
+ try:
+ if self._raw_pointer is not None:
+ self._destructor(self._raw_pointer)
+ self._raw_pointer = None
+ except AttributeError:
+ pass
+
+
+def load_pycryptodome_raw_lib(name, cdecl):
+ """Load a shared library and return a handle to it.
+
+ @name, the name of the library expressed as a PyCryptodome module,
+ for instance Crypto.Cipher._raw_cbc.
+
+ @cdecl, the C function declarations.
+ """
+
+ split = name.split(".")
+ dir_comps, basename = split[:-1], split[-1]
+ attempts = []
+ for ext in extension_suffixes:
+ try:
+ filename = basename + ext
+ full_name = pycryptodome_filename(dir_comps, filename)
+ if not os.path.isfile(full_name):
+ attempts.append("Not found '%s'" % filename)
+ continue
+ return load_lib(full_name, cdecl)
+ except OSError as exp:
+ attempts.append("Cannot load '%s': %s" % (filename, str(exp)))
+ raise OSError("Cannot load native module '%s': %s" % (name, ", ".join(attempts)))
+
+
+def is_buffer(x):
+ """Return True if object x supports the buffer interface"""
+ return isinstance(x, (bytes, bytearray, memoryview))
+
+
+def is_writeable_buffer(x):
+ return (isinstance(x, bytearray) or
+ (isinstance(x, memoryview) and not x.readonly))
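To see which FFI backend this layer selected and how a native module is loaded through it, the declaration below simply mirrors the one used by _cpu_features.py earlier in this commit (a minimal sketch):

from Crypto.Util._raw_api import (backend, load_pycryptodome_raw_lib,
                                  is_buffer, is_writeable_buffer)

print("FFI backend:", backend)        # "cffi" or "ctypes"

lib = load_pycryptodome_raw_lib("Crypto.Util._cpuid_c",
                                "int have_aes_ni(void); int have_clmul(void);")
print("AES-NI:", bool(lib.have_aes_ni()))

print(is_buffer(b"abc"), is_writeable_buffer(bytearray(3)))   # True True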
diff --git a/lib/Crypto/Util/_raw_api.pyi b/lib/Crypto/Util/_raw_api.pyi
new file mode 100644
index 0000000..2bc5301
--- /dev/null
+++ b/lib/Crypto/Util/_raw_api.pyi
@@ -0,0 +1,27 @@
+from typing import Any, Optional, Union
+
+def load_lib(name: str, cdecl: str) -> Any : ...
+def c_ulong(x: int ) -> Any : ...
+def c_ulonglong(x: int ) -> Any : ...
+def c_size_t(x: int) -> Any : ...
+def create_string_buffer(init_or_size: Union[bytes,int], size: Optional[int]) -> Any : ...
+def get_c_string(c_string: Any) -> bytes : ...
+def get_raw_buffer(buf: Any) -> bytes : ...
+def c_uint8_ptr(data: Union[bytes, memoryview, bytearray]) -> Any : ...
+
+class VoidPointer(object):
+ def get(self) -> Any : ...
+ def address_of(self) -> Any : ...
+
+class SmartPointer(object):
+ def __init__(self, raw_pointer: Any, destructor: Any) -> None : ...
+ def get(self) -> Any : ...
+ def release(self) -> Any : ...
+
+backend : str
+null_pointer : Any
+ffi: Any
+
+def load_pycryptodome_raw_lib(name: str, cdecl: str) -> Any : ...
+def is_buffer(x: Any) -> bool : ...
+def is_writeable_buffer(x: Any) -> bool : ...
diff --git a/lib/Crypto/Util/_strxor.abi3.so b/lib/Crypto/Util/_strxor.abi3.so
new file mode 100755
index 0000000..97e4245
--- /dev/null
+++ b/lib/Crypto/Util/_strxor.abi3.so
Binary files differ
diff --git a/lib/Crypto/Util/asn1.py b/lib/Crypto/Util/asn1.py
new file mode 100644
index 0000000..c4571d4
--- /dev/null
+++ b/lib/Crypto/Util/asn1.py
@@ -0,0 +1,939 @@
+# -*- coding: ascii -*-
+#
+# Util/asn1.py : Minimal support for ASN.1 DER binary encoding.
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+import struct
+
+from Crypto.Util.py3compat import byte_string, b, bchr, bord
+
+from Crypto.Util.number import long_to_bytes, bytes_to_long
+
+__all__ = ['DerObject', 'DerInteger', 'DerOctetString', 'DerNull',
+ 'DerSequence', 'DerObjectId', 'DerBitString', 'DerSetOf']
+
+
+def _is_number(x, only_non_negative=False):
+ test = 0
+ try:
+ test = x + test
+ except TypeError:
+ return False
+ return not only_non_negative or x >= 0
+
+
+class BytesIO_EOF(object):
+ """This class differs from BytesIO in that a ValueError exception is
+ raised whenever EOF is reached."""
+
+ def __init__(self, initial_bytes):
+ self._buffer = initial_bytes
+ self._index = 0
+ self._bookmark = None
+
+ def set_bookmark(self):
+ self._bookmark = self._index
+
+ def data_since_bookmark(self):
+ assert self._bookmark is not None
+ return self._buffer[self._bookmark:self._index]
+
+ def remaining_data(self):
+ return len(self._buffer) - self._index
+
+ def read(self, length):
+ new_index = self._index + length
+ if new_index > len(self._buffer):
+ raise ValueError("Not enough data for DER decoding: expected %d bytes and found %d" % (new_index, len(self._buffer)))
+
+ result = self._buffer[self._index:new_index]
+ self._index = new_index
+ return result
+
+ def read_byte(self):
+ return bord(self.read(1)[0])
+
+
+class DerObject(object):
+ """Base class for defining a single DER object.
+
+ This class should never be directly instantiated.
+ """
+
+ def __init__(self, asn1Id=None, payload=b'', implicit=None,
+ constructed=False, explicit=None):
+ """Initialize the DER object according to a specific ASN.1 type.
+
+ :Parameters:
+ asn1Id : integer
+ The universal DER tag number for this object
+ (e.g. 0x10 for a SEQUENCE).
+ If None, the tag is not known yet.
+
+ payload : byte string
+ The initial payload of the object (that is,
+ the content octets).
+ If not specified, the payload is empty.
+
+ implicit : integer
+ The IMPLICIT tag number to use for the encoded object.
+ It overrides the universal tag *asn1Id*.
+
+ constructed : bool
+ True when the ASN.1 type is *constructed*.
+ False when it is *primitive*.
+
+ explicit : integer
+ The EXPLICIT tag number to use for the encoded object.
+ """
+
+ if asn1Id is None:
+ # The tag octet will be read in with ``decode``
+ self._tag_octet = None
+ return
+ asn1Id = self._convertTag(asn1Id)
+
+ self.payload = payload
+
+ # In a BER/DER identifier octet:
+ # * bits 4-0 contain the tag value
+ # * bit 5 is set if the type is 'constructed'
+ # and unset if 'primitive'
+ # * bits 7-6 depend on the encoding class
+ #
+ # Class | Bit 7, Bit 6
+ # ----------------------------------
+ # universal | 0 0
+ # application | 0 1
+ # context-spec | 1 0 (default for IMPLICIT/EXPLICIT)
+ # private | 1 1
+ #
+ if None not in (explicit, implicit):
+ raise ValueError("Explicit and implicit tags are"
+ " mutually exclusive")
+
+ if implicit is not None:
+ self._tag_octet = 0x80 | 0x20 * constructed | self._convertTag(implicit)
+ return
+
+ if explicit is not None:
+ self._tag_octet = 0xA0 | self._convertTag(explicit)
+ self._inner_tag_octet = 0x20 * constructed | asn1Id
+ return
+
+ self._tag_octet = 0x20 * constructed | asn1Id
+
+ def _convertTag(self, tag):
+ """Check if *tag* is a real DER tag.
+ Convert it from a character to number if necessary.
+ """
+ if not _is_number(tag):
+ if len(tag) == 1:
+ tag = bord(tag[0])
+ # Ensure that tag is a low tag
+ if not (_is_number(tag) and 0 <= tag < 0x1F):
+ raise ValueError("Wrong DER tag")
+ return tag
+
+ @staticmethod
+ def _definite_form(length):
+ """Build length octets according to BER/DER
+ definite form.
+ """
+ if length > 127:
+ encoding = long_to_bytes(length)
+ return bchr(len(encoding) + 128) + encoding
+ return bchr(length)
+
+ def encode(self):
+ """Return this DER element, fully encoded as a binary byte string."""
+
+ # Concatenate identifier octets, length octets,
+ # and contents octets
+
+ output_payload = self.payload
+
+ # In case of an EXPLICIT tag, first encode the inner
+ # element.
+ if hasattr(self, "_inner_tag_octet"):
+ output_payload = (bchr(self._inner_tag_octet) +
+ self._definite_form(len(self.payload)) +
+ self.payload)
+
+ return (bchr(self._tag_octet) +
+ self._definite_form(len(output_payload)) +
+ output_payload)
+
+ def _decodeLen(self, s):
+ """Decode DER length octets from a file."""
+
+ length = s.read_byte()
+
+ if length > 127:
+ encoded_length = s.read(length & 0x7F)
+ if bord(encoded_length[0]) == 0:
+ raise ValueError("Invalid DER: length has leading zero")
+ length = bytes_to_long(encoded_length)
+ if length <= 127:
+ raise ValueError("Invalid DER: length in long form but smaller than 128")
+
+ return length
+
+ def decode(self, der_encoded, strict=False):
+ """Decode a complete DER element, and re-initialize this
+ object with it.
+
+ Args:
+ der_encoded (byte string): A complete DER element.
+
+ Raises:
+ ValueError: in case of parsing errors.
+ """
+
+ if not byte_string(der_encoded):
+ raise ValueError("Input is not a byte string")
+
+ s = BytesIO_EOF(der_encoded)
+ self._decodeFromStream(s, strict)
+
+ # There shouldn't be other bytes left
+ if s.remaining_data() > 0:
+ raise ValueError("Unexpected extra data after the DER structure")
+
+ return self
+
+ def _decodeFromStream(self, s, strict):
+ """Decode a complete DER element from a file."""
+
+ idOctet = s.read_byte()
+ if self._tag_octet is not None:
+ if idOctet != self._tag_octet:
+ raise ValueError("Unexpected DER tag")
+ else:
+ self._tag_octet = idOctet
+ length = self._decodeLen(s)
+ self.payload = s.read(length)
+
+ # In case of an EXPLICIT tag, further decode the inner
+ # element.
+ if hasattr(self, "_inner_tag_octet"):
+ p = BytesIO_EOF(self.payload)
+ inner_octet = p.read_byte()
+ if inner_octet != self._inner_tag_octet:
+ raise ValueError("Unexpected internal DER tag")
+ length = self._decodeLen(p)
+ self.payload = p.read(length)
+
+ # There shouldn't be other bytes left
+ if p.remaining_data() > 0:
+ raise ValueError("Unexpected extra data after the DER structure")
+
+
+class DerInteger(DerObject):
+ """Class to model a DER INTEGER.
+
+ An example of encoding is::
+
+ >>> from Crypto.Util.asn1 import DerInteger
+ >>> from binascii import hexlify, unhexlify
+ >>> int_der = DerInteger(9)
+ >>> print hexlify(int_der.encode())
+
+ which will show ``020109``, the DER encoding of 9.
+
+ And for decoding::
+
+ >>> s = unhexlify(b'020109')
+ >>> try:
+ >>> int_der = DerInteger()
+ >>> int_der.decode(s)
+ >>> print int_der.value
+ >>> except ValueError:
+ >>> print "Not a valid DER INTEGER"
+
+ the output will be ``9``.
+
+ :ivar value: The integer value
+ :vartype value: integer
+ """
+
+ def __init__(self, value=0, implicit=None, explicit=None):
+ """Initialize the DER object as an INTEGER.
+
+ :Parameters:
+ value : integer
+ The value of the integer.
+
+ implicit : integer
+ The IMPLICIT tag to use for the encoded object.
+ It overrides the universal tag for INTEGER (2).
+ """
+
+ DerObject.__init__(self, 0x02, b'', implicit,
+ False, explicit)
+ self.value = value # The integer value
+
+ def encode(self):
+ """Return the DER INTEGER, fully encoded as a
+ binary string."""
+
+ number = self.value
+ self.payload = b''
+ while True:
+ self.payload = bchr(int(number & 255)) + self.payload
+ if 128 <= number <= 255:
+ self.payload = bchr(0x00) + self.payload
+ if -128 <= number <= 255:
+ break
+ number >>= 8
+ return DerObject.encode(self)
+
+ def decode(self, der_encoded, strict=False):
+ """Decode a complete DER INTEGER, and re-initialize this
+ object with it.
+
+ Args:
+ der_encoded (byte string): A complete INTEGER DER element.
+
+ Raises:
+ ValueError: in case of parsing errors.
+ """
+
+ return DerObject.decode(self, der_encoded, strict=strict)
+
+ def _decodeFromStream(self, s, strict):
+ """Decode a complete DER INTEGER from a file."""
+
+ # Fill up self.payload
+ DerObject._decodeFromStream(self, s, strict)
+
+ if strict:
+ if len(self.payload) == 0:
+ raise ValueError("Invalid encoding for DER INTEGER: empty payload")
+ if len(self.payload) >= 2 and struct.unpack('>H', self.payload[:2])[0] < 0x80:
+ raise ValueError("Invalid encoding for DER INTEGER: leading zero")
+
+ # Derive self.value from self.payload
+ self.value = 0
+ bits = 1
+ for i in self.payload:
+ self.value *= 256
+ self.value += bord(i)
+ bits <<= 8
+ if self.payload and bord(self.payload[0]) & 0x80:
+ self.value -= bits
+
+
+class DerSequence(DerObject):
+ """Class to model a DER SEQUENCE.
+
+ This object behaves like a dynamic Python sequence.
+
+ Sub-elements that are INTEGERs behave like Python integers.
+
+ Any other sub-element is a binary string encoded as a complete DER
+ sub-element (TLV).
+
+ An example of encoding is:
+
+ >>> from Crypto.Util.asn1 import DerSequence, DerInteger
+ >>> from binascii import hexlify, unhexlify
+ >>> obj_der = unhexlify('070102')
+ >>> seq_der = DerSequence([4])
+ >>> seq_der.append(9)
+ >>> seq_der.append(obj_der)
+ >>> print hexlify(seq_der.encode())
+
+ which will show ``3009020104020109070102``, the DER encoding of the
+ sequence containing ``4``, ``9``, and the object with payload ``02``.
+
+ For decoding:
+
+ >>> s = unhexlify(b'3009020104020109070102')
+ >>> try:
+ >>> seq_der = DerSequence()
+ >>> seq_der.decode(s)
+ >>> print len(seq_der)
+ >>> print seq_der[0]
+ >>> print seq_der[:]
+ >>> except ValueError:
+ >>> print "Not a valid DER SEQUENCE"
+
+ the output will be::
+
+ 3
+ 4
+ [4, 9, b'\x07\x01\x02']
+
+ """
+
+ def __init__(self, startSeq=None, implicit=None):
+ """Initialize the DER object as a SEQUENCE.
+
+ :Parameters:
+ startSeq : Python sequence
+ A sequence whose element are either integers or
+ other DER objects.
+
+ implicit : integer
+ The IMPLICIT tag to use for the encoded object.
+ It overrides the universal tag for SEQUENCE (16).
+ """
+
+ DerObject.__init__(self, 0x10, b'', implicit, True)
+ if startSeq is None:
+ self._seq = []
+ else:
+ self._seq = startSeq
+
+ # A few methods to make it behave like a python sequence
+
+ def __delitem__(self, n):
+ del self._seq[n]
+
+ def __getitem__(self, n):
+ return self._seq[n]
+
+ def __setitem__(self, key, value):
+ self._seq[key] = value
+
+ def __setslice__(self, i, j, sequence):
+ self._seq[i:j] = sequence
+
+ def __delslice__(self, i, j):
+ del self._seq[i:j]
+
+ def __getslice__(self, i, j):
+ return self._seq[max(0, i):max(0, j)]
+
+ def __len__(self):
+ return len(self._seq)
+
+ def __iadd__(self, item):
+ self._seq.append(item)
+ return self
+
+ def append(self, item):
+ self._seq.append(item)
+ return self
+
+ def hasInts(self, only_non_negative=True):
+ """Return the number of items in this sequence that are
+ integers.
+
+ Args:
+ only_non_negative (boolean):
+ If ``True``, negative integers are not counted.
+ """
+
+ items = [x for x in self._seq if _is_number(x, only_non_negative)]
+ return len(items)
+
+ def hasOnlyInts(self, only_non_negative=True):
+ """Return ``True`` if all items in this sequence are integers
+ or non-negative integers.
+
+ This function returns False if the sequence is empty,
+ or at least one member is not an integer.
+
+ Args:
+ only_non_negative (boolean):
+ If ``True``, the presence of negative integers
+ causes the method to return ``False``."""
+ return self._seq and self.hasInts(only_non_negative) == len(self._seq)
+
+ def encode(self):
+ """Return this DER SEQUENCE, fully encoded as a
+ binary string.
+
+ Raises:
+ ValueError: if some elements in the sequence are neither integers
+ nor byte strings.
+ """
+ self.payload = b''
+ for item in self._seq:
+ if byte_string(item):
+ self.payload += item
+ elif _is_number(item):
+ self.payload += DerInteger(item).encode()
+ else:
+ self.payload += item.encode()
+ return DerObject.encode(self)
+
+ def decode(self, der_encoded, strict=False, nr_elements=None, only_ints_expected=False):
+ """Decode a complete DER SEQUENCE, and re-initialize this
+ object with it.
+
+ Args:
+ der_encoded (byte string):
+ A complete SEQUENCE DER element.
+ nr_elements (None or integer or list of integers):
+ The number of members the SEQUENCE can have
+ only_ints_expected (boolean):
+ Whether the SEQUENCE is expected to contain only integers.
+ strict (boolean):
+ Whether decoding must check for strict DER compliancy.
+
+ Raises:
+ ValueError: in case of parsing errors.
+
+ DER INTEGERs are decoded into Python integers. Any other DER
+ element is not decoded. Its validity is not checked.
+ """
+
+ self._nr_elements = nr_elements
+ result = DerObject.decode(self, der_encoded, strict=strict)
+
+ if only_ints_expected and not self.hasOnlyInts():
+ raise ValueError("Some members are not INTEGERs")
+
+ return result
+
+ def _decodeFromStream(self, s, strict):
+ """Decode a complete DER SEQUENCE from a file."""
+
+ self._seq = []
+
+ # Fill up self.payload
+ DerObject._decodeFromStream(self, s, strict)
+
+ # Add one item at a time to self.seq, by scanning self.payload
+ p = BytesIO_EOF(self.payload)
+ while p.remaining_data() > 0:
+ p.set_bookmark()
+
+ der = DerObject()
+ der._decodeFromStream(p, strict)
+
+ # Parse INTEGERs differently
+ if der._tag_octet != 0x02:
+ self._seq.append(p.data_since_bookmark())
+ else:
+ derInt = DerInteger()
+ data = p.data_since_bookmark()
+ derInt.decode(data, strict=strict)
+ self._seq.append(derInt.value)
+
+ ok = True
+ if self._nr_elements is not None:
+ try:
+ ok = len(self._seq) in self._nr_elements
+ except TypeError:
+ ok = len(self._seq) == self._nr_elements
+
+ if not ok:
+ raise ValueError("Unexpected number of members (%d)"
+ " in the sequence" % len(self._seq))
+
+
+class DerOctetString(DerObject):
+ """Class to model a DER OCTET STRING.
+
+ An example of encoding is:
+
+ >>> from Crypto.Util.asn1 import DerOctetString
+ >>> from binascii import hexlify, unhexlify
+ >>> os_der = DerOctetString(b'\\xaa')
+ >>> os_der.payload += b'\\xbb'
+ >>> print hexlify(os_der.encode())
+
+ which will show ``0402aabb``, the DER encoding for the byte string
+ ``b'\\xAA\\xBB'``.
+
+ For decoding:
+
+ >>> s = unhexlify(b'0402aabb')
+ >>> try:
+ >>> os_der = DerOctetString()
+ >>> os_der.decode(s)
+ >>> print hexlify(os_der.payload)
+ >>> except ValueError:
+ >>> print "Not a valid DER OCTET STRING"
+
+ the output will be ``aabb``.
+
+ :ivar payload: The content of the string
+ :vartype payload: byte string
+ """
+
+ def __init__(self, value=b'', implicit=None):
+ """Initialize the DER object as an OCTET STRING.
+
+ :Parameters:
+ value : byte string
+ The initial payload of the object.
+ If not specified, the payload is empty.
+
+ implicit : integer
+ The IMPLICIT tag to use for the encoded object.
+ It overrides the universal tag for OCTET STRING (4).
+ """
+ DerObject.__init__(self, 0x04, value, implicit, False)
+
+
+class DerNull(DerObject):
+ """Class to model a DER NULL element."""
+
+ def __init__(self):
+ """Initialize the DER object as a NULL."""
+
+ DerObject.__init__(self, 0x05, b'', None, False)
+
+
+class DerObjectId(DerObject):
+ """Class to model a DER OBJECT ID.
+
+ An example of encoding is:
+
+ >>> from Crypto.Util.asn1 import DerObjectId
+ >>> from binascii import hexlify, unhexlify
+ >>> oid_der = DerObjectId("1.2")
+ >>> oid_der.value += ".840.113549.1.1.1"
+ >>> print hexlify(oid_der.encode())
+
+ which will show ``06092a864886f70d010101``, the DER encoding for the
+ RSA Object Identifier ``1.2.840.113549.1.1.1``.
+
+ For decoding:
+
+ >>> s = unhexlify(b'06092a864886f70d010101')
+ >>> try:
+ >>> oid_der = DerObjectId()
+ >>> oid_der.decode(s)
+ >>> print oid_der.value
+ >>> except ValueError:
+ >>> print "Not a valid DER OBJECT ID"
+
+ the output will be ``1.2.840.113549.1.1.1``.
+
+ :ivar value: The Object ID (OID), a dot separated list of integers
+ :vartype value: string
+ """
+
+ def __init__(self, value='', implicit=None, explicit=None):
+ """Initialize the DER object as an OBJECT ID.
+
+ :Parameters:
+ value : string
+ The initial Object Identifier (e.g. "1.2.0.0.6.2").
+ implicit : integer
+ The IMPLICIT tag to use for the encoded object.
+ It overrides the universal tag for OBJECT ID (6).
+ explicit : integer
+ The EXPLICIT tag to use for the encoded object.
+ """
+ DerObject.__init__(self, 0x06, b'', implicit, False, explicit)
+ self.value = value
+
+ def encode(self):
+ """Return the DER OBJECT ID, fully encoded as a
+ binary string."""
+
+ comps = [int(x) for x in self.value.split(".")]
+ if len(comps) < 2:
+ raise ValueError("Not a valid Object Identifier string")
+ self.payload = bchr(40*comps[0]+comps[1])
+ for v in comps[2:]:
+ if v == 0:
+ enc = [0]
+ else:
+ enc = []
+ while v:
+ enc.insert(0, (v & 0x7F) | 0x80)
+ v >>= 7
+ enc[-1] &= 0x7F
+ self.payload += b''.join([bchr(x) for x in enc])
+ return DerObject.encode(self)
+
+ def decode(self, der_encoded, strict=False):
+ """Decode a complete DER OBJECT ID, and re-initialize this
+ object with it.
+
+ Args:
+ der_encoded (byte string):
+ A complete DER OBJECT ID.
+ strict (boolean):
+ Whether decoding must check for strict DER compliancy.
+
+ Raises:
+ ValueError: in case of parsing errors.
+ """
+
+ return DerObject.decode(self, der_encoded, strict)
+
+ def _decodeFromStream(self, s, strict):
+ """Decode a complete DER OBJECT ID from a file."""
+
+ # Fill up self.payload
+ DerObject._decodeFromStream(self, s, strict)
+
+ # Derive self.value from self.payload
+ p = BytesIO_EOF(self.payload)
+ comps = [str(x) for x in divmod(p.read_byte(), 40)]
+ v = 0
+ while p.remaining_data():
+ c = p.read_byte()
+ v = v*128 + (c & 0x7F)
+ if not (c & 0x80):
+ comps.append(str(v))
+ v = 0
+ self.value = '.'.join(comps)
+
+
+class DerBitString(DerObject):
+ """Class to model a DER BIT STRING.
+
+ An example of encoding is:
+
+ >>> from Crypto.Util.asn1 import DerBitString
+ >>> bs_der = DerBitString(b'\\xAA')
+ >>> bs_der.value += b'\\xBB'
+ >>> print(bs_der.encode().hex())
+
+ which will show ``030300aabb``, the DER encoding for the bit string
+ ``b'\\xAA\\xBB'``.
+
+ For decoding:
+
+ >>> s = bytes.fromhex('030300aabb')
+ >>> try:
+ >>> bs_der = DerBitString()
+ >>> bs_der.decode(s)
+ >>> print(bs_der.value.hex())
+ >>> except ValueError:
+ >>> print("Not a valid DER BIT STRING")
+
+ the output will be ``aabb``.
+
+ :ivar value: The content of the string
+ :vartype value: byte string
+ """
+
+ def __init__(self, value=b'', implicit=None, explicit=None):
+ """Initialize the DER object as a BIT STRING.
+
+ :Parameters:
+ value : byte string or DER object
+ The initial, packed bit string.
+ If not specified, the bit string is empty.
+ implicit : integer
+ The IMPLICIT tag to use for the encoded object.
+ It overrides the universal tag for BIT STRING (3).
+ explicit : integer
+ The EXPLICIT tag to use for the encoded object.
+ """
+ DerObject.__init__(self, 0x03, b'', implicit, False, explicit)
+
+ # The bitstring value (packed)
+ if isinstance(value, DerObject):
+ self.value = value.encode()
+ else:
+ self.value = value
+
+ def encode(self):
+ """Return the DER BIT STRING, fully encoded as a
+ byte string."""
+
+ # Add padding count byte
+ self.payload = b'\x00' + self.value
+ return DerObject.encode(self)
+
+ def decode(self, der_encoded, strict=False):
+ """Decode a complete DER BIT STRING, and re-initialize this
+ object with it.
+
+ Args:
+ der_encoded (byte string): a complete DER BIT STRING.
+ strict (boolean):
+ Whether decoding must check for strict DER compliancy.
+
+ Raises:
+ ValueError: in case of parsing errors.
+ """
+
+ return DerObject.decode(self, der_encoded, strict)
+
+ def _decodeFromStream(self, s, strict):
+ """Decode a complete DER BIT STRING from a file."""
+
+ # Fill-up self.payload
+ DerObject._decodeFromStream(self, s, strict)
+
+ if self.payload and bord(self.payload[0]) != 0:
+ raise ValueError("Not a valid BIT STRING")
+
+ # Fill-up self.value
+ self.value = b''
+ # Remove padding count byte
+ if self.payload:
+ self.value = self.payload[1:]
+
+
+class DerSetOf(DerObject):
+ """Class to model a DER SET OF.
+
+ An example of encoding is:
+
+ >>> from Crypto.Util.asn1 import DerSetOf
+ >>> from binascii import hexlify, unhexlify
+ >>> so_der = DerSetOf([4,5])
+ >>> so_der.add(6)
+ >>> print hexlify(so_der.encode())
+
+ which will show ``3109020104020105020106``, the DER encoding
+ of a SET OF with items 4,5, and 6.
+
+ For decoding:
+
+ >>> s = unhexlify(b'3109020104020105020106')
+ >>> try:
+ >>> so_der = DerSetOf()
+ >>> so_der.decode(s)
+ >>> print [x for x in so_der]
+ >>> except ValueError:
+ >>> print "Not a valid DER SET OF"
+
+ the output will be ``[4, 5, 6]``.
+ """
+
+ def __init__(self, startSet=None, implicit=None):
+ """Initialize the DER object as a SET OF.
+
+ :Parameters:
+ startSet : container
+ The initial set of integers or DER encoded objects.
+ implicit : integer
+ The IMPLICIT tag to use for the encoded object.
+ It overrides the universal tag for SET OF (17).
+ """
+ DerObject.__init__(self, 0x11, b'', implicit, True)
+ self._seq = []
+
+ # All elements must be of the same type (and therefore have the
+ # same leading octet)
+ self._elemOctet = None
+
+ if startSet:
+ for e in startSet:
+ self.add(e)
+
+ def __getitem__(self, n):
+ return self._seq[n]
+
+ def __iter__(self):
+ return iter(self._seq)
+
+ def __len__(self):
+ return len(self._seq)
+
+ def add(self, elem):
+ """Add an element to the set.
+
+ Args:
+ elem (byte string or integer):
+ An element of the same type of objects already in the set.
+ It can be an integer or a DER encoded object.
+ """
+
+ if _is_number(elem):
+ eo = 0x02
+ elif isinstance(elem, DerObject):
+ eo = self._tag_octet
+ else:
+ eo = bord(elem[0])
+
+ if self._elemOctet != eo:
+ if self._elemOctet is not None:
+ raise ValueError("New element does not belong to the set")
+ self._elemOctet = eo
+
+ if elem not in self._seq:
+ self._seq.append(elem)
+
+ def decode(self, der_encoded, strict=False):
+ """Decode a complete DER SET OF element, and re-initialize this
+ object with it.
+
+ DER INTEGERs are decoded into Python integers. Any other DER
+ element is left undecoded; its validity is not checked.
+
+ Args:
+ der_encoded (byte string): a complete DER SET OF element.
+ strict (boolean):
+ Whether decoding must check for strict DER compliancy.
+
+ Raises:
+ ValueError: in case of parsing errors.
+ """
+
+ return DerObject.decode(self, der_encoded, strict)
+
+ def _decodeFromStream(self, s, strict):
+ """Decode a complete DER SET OF from a file."""
+
+ self._seq = []
+
+ # Fill up self.payload
+ DerObject._decodeFromStream(self, s, strict)
+
+ # Add one item at a time to self.seq, by scanning self.payload
+ p = BytesIO_EOF(self.payload)
+ setIdOctet = -1
+ while p.remaining_data() > 0:
+ p.set_bookmark()
+
+ der = DerObject()
+ der._decodeFromStream(p, strict)
+
+ # Verify that all members are of the same type
+ if setIdOctet < 0:
+ setIdOctet = der._tag_octet
+ else:
+ if setIdOctet != der._tag_octet:
+ raise ValueError("Not all elements are of the same DER type")
+
+ # Parse INTEGERs differently
+ if setIdOctet != 0x02:
+ self._seq.append(p.data_since_bookmark())
+ else:
+ derInt = DerInteger()
+ derInt.decode(p.data_since_bookmark(), strict)
+ self._seq.append(derInt.value)
+ # end
+
+ def encode(self):
+ """Return this SET OF DER element, fully encoded as a
+ binary string.
+ """
+
+ # Elements in the set must be ordered in lexicographic order
+ ordered = []
+ for item in self._seq:
+ if _is_number(item):
+ bys = DerInteger(item).encode()
+ elif isinstance(item, DerObject):
+ bys = item.encode()
+ else:
+ bys = item
+ ordered.append(bys)
+ ordered.sort()
+ self.payload = b''.join(ordered)
+ return DerObject.encode(self)
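Most of the docstrings above still use Python 2 style print statements; the same classes round-trip cleanly under Python 3. A short sketch, with the expected hex output derived from the encoding rules implemented above:

from Crypto.Util.asn1 import DerSequence, DerInteger, DerObjectId

seq = DerSequence([4, 9])
der = seq.encode()
print(der.hex())                                   # 3006020104020109

decoded = DerSequence().decode(der, strict=True, only_ints_expected=True)
print(decoded[:])                                  # [4, 9]

n = DerInteger()
n.decode(bytes.fromhex("020109"))
print(n.value)                                     # 9

oid = DerObjectId("1.2.840.113549.1.1.1")
print(oid.encode().hex())                          # 06092a864886f70d010101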
diff --git a/lib/Crypto/Util/asn1.pyi b/lib/Crypto/Util/asn1.pyi
new file mode 100644
index 0000000..dac023b
--- /dev/null
+++ b/lib/Crypto/Util/asn1.pyi
@@ -0,0 +1,74 @@
+from typing import Optional, Sequence, Union, Set, Iterable
+
+__all__ = ['DerObject', 'DerInteger', 'DerOctetString', 'DerNull',
+ 'DerSequence', 'DerObjectId', 'DerBitString', 'DerSetOf']
+
+# TODO: Make the encoded DerObjects their own type, so that DerSequence and
+# DerSetOf can check their contents better
+
+class BytesIO_EOF:
+ def __init__(self, initial_bytes: bytes) -> None: ...
+ def set_bookmark(self) -> None: ...
+ def data_since_bookmark(self) -> bytes: ...
+ def remaining_data(self) -> int: ...
+ def read(self, length: int) -> bytes: ...
+ def read_byte(self) -> int: ...
+
+class DerObject:
+ payload: bytes
+ def __init__(self, asn1Id: Optional[int]=None, payload: Optional[bytes]=..., implicit: Optional[int]=None,
+ constructed: Optional[bool]=False, explicit: Optional[int]=None) -> None: ...
+ def encode(self) -> bytes: ...
+ def decode(self, der_encoded: bytes, strict: Optional[bool]=False) -> DerObject: ...
+
+class DerInteger(DerObject):
+ value: int
+ def __init__(self, value: Optional[int]= 0, implicit: Optional[int]=None, explicit: Optional[int]=None) -> None: ...
+ def encode(self) -> bytes: ...
+ def decode(self, der_encoded: bytes, strict: Optional[bool]=False) -> DerInteger: ...
+
+class DerSequence(DerObject):
+ def __init__(self, startSeq: Optional[Sequence[Union[int, DerInteger, DerObject]]]=None, implicit: Optional[int]=None) -> None: ...
+ def __delitem__(self, n: int) -> None: ...
+ def __getitem__(self, n: int) -> None: ...
+ def __setitem__(self, key: int, value: DerObject) -> None: ...
+ def __setslice__(self, i: int, j: int, sequence: Sequence) -> None: ...
+ def __delslice__(self, i: int, j: int) -> None: ...
+ def __getslice__(self, i: int, j: int) -> DerSequence: ...
+ def __len__(self) -> int: ...
+ def __iadd__(self, item: DerObject) -> DerSequence: ...
+ def append(self, item: DerObject) -> DerSequence: ...
+ def hasInts(self, only_non_negative: Optional[bool]=True) -> int: ...
+ def hasOnlyInts(self, only_non_negative: Optional[bool]=True) -> bool: ...
+ def encode(self) -> bytes: ...
+ def decode(self, der_encoded: bytes, strict: Optional[bool]=False, nr_elements: Optional[int]=None, only_ints_expected: Optional[bool]=False) -> DerSequence: ...
+
+class DerOctetString(DerObject):
+ payload: bytes
+ def __init__(self, value: Optional[bytes]=..., implicit: Optional[int]=None) -> None: ...
+
+class DerNull(DerObject):
+ def __init__(self) -> None: ...
+
+class DerObjectId(DerObject):
+ value: str
+ def __init__(self, value: Optional[str]=..., implicit: Optional[int]=None, explicit: Optional[int]=None) -> None: ...
+ def encode(self) -> bytes: ...
+ def decode(self, der_encoded: bytes, strict: Optional[bool]=False) -> DerObjectId: ...
+
+class DerBitString(DerObject):
+ value: bytes
+ def __init__(self, value: Optional[bytes]=..., implicit: Optional[int]=None, explicit: Optional[int]=None) -> None: ...
+ def encode(self) -> bytes: ...
+ def decode(self, der_encoded: bytes, strict: Optional[bool]=False) -> DerBitString: ...
+
+DerSetElement = Union[bytes, int]
+
+class DerSetOf(DerObject):
+ def __init__(self, startSet: Optional[Set[DerSetElement]]=None, implicit: Optional[int]=None) -> None: ...
+ def __getitem__(self, n: int) -> DerSetElement: ...
+ def __iter__(self) -> Iterable: ...
+ def __len__(self) -> int: ...
+ def add(self, elem: DerSetElement) -> None: ...
+ def decode(self, der_encoded: bytes, strict: Optional[bool]=False) -> DerObject: ...
+ def encode(self) -> bytes: ...
diff --git a/lib/Crypto/Util/number.py b/lib/Crypto/Util/number.py
new file mode 100644
index 0000000..279ffe0
--- /dev/null
+++ b/lib/Crypto/Util/number.py
@@ -0,0 +1,1500 @@
+#
+# number.py : Number-theoretic functions
+#
+# Part of the Python Cryptography Toolkit
+#
+# Written by Andrew M. Kuchling, Barry A. Warsaw, and others
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+#
+
+import math
+import sys
+import struct
+from Crypto import Random
+from Crypto.Util.py3compat import iter_range
+
+# Backward compatibility
+_fastmath = None
+
+
+def ceil_div(n, d):
+ """Return ceil(n/d), that is, the smallest integer r such that r*d >= n"""
+
+ if d == 0:
+ raise ZeroDivisionError()
+ if (n < 0) or (d < 0):
+ raise ValueError("Non positive values")
+ r, q = divmod(n, d)
+ if (n != 0) and (q != 0):
+ r += 1
+ return r
+
+
+def size (N):
+ """Returns the size of the number N in bits."""
+
+ if N < 0:
+ raise ValueError("Size in bits only available for non-negative numbers")
+
+ bits = 0
+ while N >> bits:
+ bits += 1
+ return bits
+
+
+def getRandomInteger(N, randfunc=None):
+ """Return a random number at most N bits long.
+
+ If :data:`randfunc` is omitted, then :meth:`Random.get_random_bytes` is used.
+
+ .. deprecated:: 3.0
+ This function is for internal use only and may be renamed or removed in
+ the future. Use :func:`Crypto.Random.random.getrandbits` instead.
+ """
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ S = randfunc(N>>3)
+ odd_bits = N % 8
+ if odd_bits != 0:
+ rand_bits = ord(randfunc(1)) >> (8-odd_bits)
+ S = struct.pack('B', rand_bits) + S
+ value = bytes_to_long(S)
+ return value
+
+def getRandomRange(a, b, randfunc=None):
+ """Return a random number *n* so that *a <= n < b*.
+
+ If :data:`randfunc` is omitted, then :meth:`Random.get_random_bytes` is used.
+
+ .. deprecated:: 3.0
+ This function is for internal use only and may be renamed or removed in
+ the future. Use :func:`Crypto.Random.random.randrange` instead.
+ """
+
+ range_ = b - a - 1
+ bits = size(range_)
+ value = getRandomInteger(bits, randfunc)
+ while value > range_:
+ value = getRandomInteger(bits, randfunc)
+ return a + value
+
+def getRandomNBitInteger(N, randfunc=None):
+ """Return a random number with exactly N-bits,
+ i.e. a random number between 2**(N-1) and (2**N)-1.
+
+ If :data:`randfunc` is omitted, then :meth:`Random.get_random_bytes` is used.
+
+ .. deprecated:: 3.0
+ This function is for internal use only and may be renamed or removed in
+ the future.
+ """
+
+ value = getRandomInteger (N-1, randfunc)
+ value |= 2 ** (N-1) # Ensure high bit is set
+ assert size(value) >= N
+ return value
+
+def GCD(x,y):
+ """Greatest Common Divisor of :data:`x` and :data:`y`.
+ """
+
+ x = abs(x) ; y = abs(y)
+ while x > 0:
+ x, y = y % x, x
+ return y
+
+def inverse(u, v):
+ """The inverse of :data:`u` *mod* :data:`v`."""
+
+ u3, v3 = u, v
+ u1, v1 = 1, 0
+ while v3 > 0:
+ q = u3 // v3
+ u1, v1 = v1, u1 - v1*q
+ u3, v3 = v3, u3 - v3*q
+ while u1<0:
+ u1 = u1 + v
+ return u1
+
+# Given a number of bits to generate and a random generation function,
+# find a prime number of the appropriate size.
+
+def getPrime(N, randfunc=None):
+ """Return a random N-bit prime number.
+
+ N must be an integer larger than 1.
+ If randfunc is omitted, then :meth:`Random.get_random_bytes` is used.
+ """
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ if N < 2:
+ raise ValueError("N must be larger than 1")
+
+ while True:
+ number = getRandomNBitInteger(N, randfunc) | 1
+ if isPrime(number, randfunc=randfunc):
+ break
+ return number
+
+
+def _rabinMillerTest(n, rounds, randfunc=None):
+ """_rabinMillerTest(n:long, rounds:int, randfunc:callable):int
+ Tests if n is prime.
+ Returns 0 when n is definitely composite.
+ Returns 1 when n is probably prime.
+ Returns 2 when n is definitely prime.
+
+ If randfunc is omitted, then Random.new().read is used.
+
+ This function is for internal use only and may be renamed or removed in
+ the future.
+ """
+ # check special cases (n==2, n even, n < 2)
+ if n < 3 or (n & 1) == 0:
+ return n == 2
+ # n might be very large so it might be beneficial to precalculate n-1
+ n_1 = n - 1
+ # determine m and b so that 2**b * m = n - 1 and b maximal
+ b = 0
+ m = n_1
+ while (m & 1) == 0:
+ b += 1
+ m >>= 1
+
+ tested = []
+ # we need to do at most n-2 rounds.
+ for i in iter_range (min (rounds, n-2)):
+ # randomly choose a < n and make sure it hasn't been tested yet
+ a = getRandomRange (2, n, randfunc)
+ while a in tested:
+ a = getRandomRange (2, n, randfunc)
+ tested.append (a)
+ # do the rabin-miller test
+ z = pow (a, m, n) # (a**m) % n
+ if z == 1 or z == n_1:
+ continue
+ composite = 1
+ for r in iter_range(b):
+ z = (z * z) % n
+ if z == 1:
+ return 0
+ elif z == n_1:
+ composite = 0
+ break
+ if composite:
+ return 0
+ return 1
+
+def getStrongPrime(N, e=0, false_positive_prob=1e-6, randfunc=None):
+ r"""
+ Return a random strong *N*-bit prime number.
+ In this context, *p* is a strong prime if *p-1* and *p+1* have at
+ least one large prime factor.
+
+ Args:
+ N (integer): the exact length of the strong prime.
+ It must be a multiple of 128 and > 512.
+ e (integer): if provided, the returned prime (minus 1)
+ will be coprime to *e* and thus suitable for RSA where
+ *e* is the public exponent.
+ false_positive_prob (float):
+ The statistical probability for the result not to be actually a
+ prime. It defaults to 10\ :sup:`-6`.
+ Note that the real probability of a false-positive is far less. This is
+ just the mathematically provable limit.
+ randfunc (callable):
+ A function that takes a parameter *N* and that returns
+ a random byte string of such length.
+ If omitted, :func:`Crypto.Random.get_random_bytes` is used.
+ Return:
+ The new strong prime.
+
+ .. deprecated:: 3.0
+ This function is for internal use only and may be renamed or removed in
+ the future.
+ """
+
+ # This function was implemented following the
+ # instructions found in the paper:
+ # "FAST GENERATION OF RANDOM, STRONG RSA PRIMES"
+ # by Robert D. Silverman
+ # RSA Laboratories
+ # May 17, 1997
+ # which by the time of writing could be freely downloaded here:
+ # http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.17.2713&rep=rep1&type=pdf
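+ #
+ # In outline (as implemented below, paraphrasing rather than quoting the
+ # paper): two roughly 101-bit auxiliary primes p1 and p2 are found by
+ # sieving; a CRT-style offset R makes the candidate X satisfy
+ # X = 1 (mod p1) and X = -1 (mod p2), so p1 | X-1 and p2 | X+1; the
+ # candidate is then advanced in steps of p1*p2 (which preserves both
+ # congruences) until it survives the sieve and the Rabin-Miller rounds.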
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ # Use the accelerator if available
+ if _fastmath is not None:
+ return _fastmath.getStrongPrime(long(N), long(e), false_positive_prob,
+ randfunc)
+
+ if (N < 512) or ((N % 128) != 0):
+ raise ValueError ("bits must be a multiple of 128 and at least 512")
+
+ rabin_miller_rounds = int(math.ceil(-math.log(false_positive_prob)/math.log(4)))
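+ # (Each Rabin-Miller round misses a composite with probability at most 1/4,
+ # so this many rounds brings the overall error below false_positive_prob.)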
+
+ # calculate range for X
+ # lower_bound = sqrt(2) * 2^{511 + 128*x}
+ # upper_bound = 2^{512 + 128*x} - 1
+ x = (N - 512) >> 7
+ # We need to approximate the sqrt(2) in the lower_bound by an integer
+ # expression because floating point math overflows with these numbers
+ lower_bound = (14142135623730950489 * (2 ** (511 + 128*x))) // 10000000000000000000
+ upper_bound = (1 << (512 + 128*x)) - 1
+ # Randomly choose X in calculated range
+ X = getRandomRange (lower_bound, upper_bound, randfunc)
+
+ # generate p1 and p2
+ p = [0, 0]
+ for i in (0, 1):
+ # randomly choose 101-bit y
+ y = getRandomNBitInteger (101, randfunc)
+ # initialize the field for sieving
+ field = [0] * 5 * len (sieve_base)
+ # sieve the field
+ for prime in sieve_base:
+ offset = y % prime
+ for j in iter_range((prime - offset) % prime, len (field), prime):
+ field[j] = 1
+
+ # look for suitable p[i] starting at y
+ result = 0
+ for j in range(len(field)):
+ composite = field[j]
+ # look for next candidate
+ if composite:
+ continue
+ tmp = y + j
+ result = _rabinMillerTest (tmp, rabin_miller_rounds)
+ if result > 0:
+ p[i] = tmp
+ break
+ if result == 0:
+ raise RuntimeError ("Couldn't find prime in field. "
+ "Developer: Increase field_size")
+
+ # Calculate R
+ # R = (p2^{-1} mod p1) * p2 - (p1^{-1} mod p2) * p1
+ tmp1 = inverse (p[1], p[0]) * p[1] # (p2^-1 mod p1)*p2
+ tmp2 = inverse (p[0], p[1]) * p[0] # (p1^-1 mod p2)*p1
+ R = tmp1 - tmp2 # (p2^-1 mod p1)*p2 - (p1^-1 mod p2)*p1
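+ # By construction R = 1 (mod p1) and R = -1 (mod p2), so any X with
+ # X = R (mod p1*p2) has p1 | X-1 and p2 | X+1 -- the strong-prime
+ # property described in the docstring.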
+
+ # search for final prime number starting by Y0
+ # Y0 = X + (R - X mod p1p2)
+ increment = p[0] * p[1]
+ X = X + (R - (X % increment))
+ while 1:
+ is_possible_prime = 1
+ # first check candidate against sieve_base
+ for prime in sieve_base:
+ if (X % prime) == 0:
+ is_possible_prime = 0
+ break
+ # if e is given make sure that e and X-1 are coprime
+ # this is not necessarily a strong prime criterion but useful when
+ # creating them for RSA where the p-1 and q-1 should be coprime to
+ # the public exponent e
+ if e and is_possible_prime:
+ if e & 1:
+ if GCD(e, X-1) != 1:
+ is_possible_prime = 0
+ else:
+ if GCD(e, (X-1) // 2) != 1:
+ is_possible_prime = 0
+
+ # do some Rabin-Miller-Tests
+ if is_possible_prime:
+ result = _rabinMillerTest (X, rabin_miller_rounds)
+ if result > 0:
+ break
+ X += increment
+ # abort when X has more bits than requested
+ # TODO: maybe we shouldn't abort but rather start over.
+ if X >= 1 << N:
+ raise RuntimeError ("Couldn't find prime in field. "
+ "Developer: Increase field_size")
+ return X
+
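+# Illustrative usage of getStrongPrime (not part of the original module;
+# 1024 bits and e = 65537 are just the customary RSA-style example values):
+#
+#   >>> p = getStrongPrime(1024, e=65537)
+#   >>> p.bit_length()
+#   1024
+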
+def isPrime(N, false_positive_prob=1e-6, randfunc=None):
+ r"""Test if a number *N* is a prime.
+
+ Args:
+ false_positive_prob (float):
+ The statistical probability that the result is not actually a
+ prime. It defaults to 10\ :sup:`-6`.
+ Note that the real probability of a false-positive is far less.
+ This is just the mathematically provable limit.
+ randfunc (callable):
+ A function that takes a parameter *N* and returns
+ a random byte string of that length.
+ If omitted, :func:`Crypto.Random.get_random_bytes` is used.
+
+ Return:
+ `True` if the input is indeed prime.
+ """
+
+ if randfunc is None:
+ randfunc = Random.get_random_bytes
+
+ if _fastmath is not None:
+ return _fastmath.isPrime(long(N), false_positive_prob, randfunc)
+
+ if N < 3 or N & 1 == 0:
+ return N == 2
+ for p in sieve_base:
+ if N == p:
+ return 1
+ if N % p == 0:
+ return 0
+
+ rounds = int(math.ceil(-math.log(false_positive_prob)/math.log(4)))
+ return _rabinMillerTest(N, rounds, randfunc)
+
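+# Illustrative calls to isPrime (not part of the original module).  The return
+# value is truthy for primes (1, or True for N == 2) and falsy otherwise, so
+# it is safest consumed through bool():
+#
+#   >>> bool(isPrime(2**127 - 1))    # a known Mersenne prime
+#   True
+#   >>> bool(isPrime(2**127 + 1))    # divisible by 3
+#   False
+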
+
+# Improved conversion functions contributed by Barry Warsaw, after
+# careful benchmarking
+
+import struct
+
+def long_to_bytes(n, blocksize=0):
+ """Convert a positive integer to a byte string using big endian encoding.
+
+ If :data:`blocksize` is absent or zero, the byte string will
+ be of minimal length.
+
+ Otherwise, the length of the byte string is guaranteed to be a multiple
+ of :data:`blocksize`. If necessary, zeroes (``\\x00``) are added at the left.
+
+ .. note::
+ In Python 3, if you are sure that :data:`n` can fit into
+ :data:`blocksize` bytes, you can simply use the native method instead::
+
+ >>> n.to_bytes(blocksize, 'big')
+
+ For instance::
+
+ >>> n = 80
+ >>> n.to_bytes(2, 'big')
+ b'\\x00P'
+
+ However, and unlike this ``long_to_bytes()`` function,
+ an ``OverflowError`` exception is raised if :data:`n` does not fit.
+ """
+
+ if n < 0 or blocksize < 0:
+ raise ValueError("Values must be non-negative")
+
+ result = []
+ pack = struct.pack
+
+ # Fill the first block independently from the value of n
+ bsr = blocksize
+ while bsr >= 8:
+ result.insert(0, pack('>Q', n & 0xFFFFFFFFFFFFFFFF))
+ n = n >> 64
+ bsr -= 8
+
+ while bsr >= 4:
+ result.insert(0, pack('>I', n & 0xFFFFFFFF))
+ n = n >> 32
+ bsr -= 4
+
+ while bsr > 0:
+ result.insert(0, pack('>B', n & 0xFF))
+ n = n >> 8
+ bsr -= 1
+
+ if n == 0:
+ if len(result) == 0:
+ bresult = b'\x00'
+ else:
+ bresult = b''.join(result)
+ else:
+ # The encoded number exceeds the block size
+ while n > 0:
+ result.insert(0, pack('>Q', n & 0xFFFFFFFFFFFFFFFF))
+ n = n >> 64
+ result[0] = result[0].lstrip(b'\x00')
+ bresult = b''.join(result)
+ # bresult has minimum length here
+ if blocksize > 0:
+ target_len = ((len(bresult) - 1) // blocksize + 1) * blocksize
+ bresult = b'\x00' * (target_len - len(bresult)) + bresult
+
+ return bresult
+
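+# Illustrative calls matching the docstring above (not part of the original
+# module):
+#
+#   >>> long_to_bytes(80)
+#   b'P'
+#   >>> long_to_bytes(80, 2)
+#   b'\x00P'
+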
+
+def bytes_to_long(s):
+ """Convert a byte string to a long integer (big endian).
+
+ In Python 3.2+, use the native method instead::
+
+ >>> int.from_bytes(s, 'big')
+
+ For instance::
+
+ >>> int.from_bytes(b'\x00P', 'big')
+ 80
+
+ This is (essentially) the inverse of :func:`long_to_bytes`.
+ """
+ acc = 0
+
+ unpack = struct.unpack
+
+ # Up to Python 2.7.4, struct.unpack can't work with bytearrays or
+ # memoryviews
+ if sys.version_info[0:3] < (2, 7, 4):
+ if isinstance(s, bytearray):
+ s = bytes(s)
+ elif isinstance(s, memoryview):
+ s = s.tobytes()
+
+ length = len(s)
+ if length % 4:
+ extra = (4 - length % 4)
+ s = b'\x00' * extra + s
+ length = length + extra
+ for i in range(0, length, 4):
+ acc = (acc << 32) + unpack('>I', s[i:i+4])[0]
+ return acc
+
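+# Illustrative calls (not part of the original module): bytes_to_long inverts
+# long_to_bytes, up to leading zero bytes, which carry no value:
+#
+#   >>> bytes_to_long(b'\x00P')
+#   80
+#   >>> bytes_to_long(long_to_bytes(123456789))
+#   123456789
+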
+
+# For backwards compatibility...
+import warnings
+def long2str(n, blocksize=0):
+ warnings.warn("long2str() has been replaced by long_to_bytes()")
+ return long_to_bytes(n, blocksize)
+def str2long(s):
+ warnings.warn("str2long() has been replaced by bytes_to_long()")
+ return bytes_to_long(s)
+
+
+# The first 10000 primes used for checking primality.
+# This should be enough to eliminate most of the odd
+# numbers before needing to do a Rabin-Miller test at all.
+sieve_base = (
+ 2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
+ 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
+ 73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
+ 127, 131, 137, 139, 149, 151, 157, 163, 167, 173,
+ 179, 181, 191, 193, 197, 199, 211, 223, 227, 229,
+ 233, 239, 241, 251, 257, 263, 269, 271, 277, 281,
+ 283, 293, 307, 311, 313, 317, 331, 337, 347, 349,
+ 353, 359, 367, 373, 379, 383, 389, 397, 401, 409,
+ 419, 421, 431, 433, 439, 443, 449, 457, 461, 463,
+ 467, 479, 487, 491, 499, 503, 509, 521, 523, 541,
+ 547, 557, 563, 569, 571, 577, 587, 593, 599, 601,
+ 607, 613, 617, 619, 631, 641, 643, 647, 653, 659,
+ 661, 673, 677, 683, 691, 701, 709, 719, 727, 733,
+ 739, 743, 751, 757, 761, 769, 773, 787, 797, 809,
+ 811, 821, 823, 827, 829, 839, 853, 857, 859, 863,
+ 877, 881, 883, 887, 907, 911, 919, 929, 937, 941,
+ 947, 953, 967, 971, 977, 983, 991, 997, 1009, 1013,
+ 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 1063, 1069,
+ 1087, 1091, 1093, 1097, 1103, 1109, 1117, 1123, 1129, 1151,
+ 1153, 1163, 1171, 1181, 1187, 1193, 1201, 1213, 1217, 1223,
+ 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, 1289, 1291,
+ 1297, 1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373,
+ 1381, 1399, 1409, 1423, 1427, 1429, 1433, 1439, 1447, 1451,
+ 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493, 1499, 1511,
+ 1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583,
+ 1597, 1601, 1607, 1609, 1613, 1619, 1621, 1627, 1637, 1657,
+ 1663, 1667, 1669, 1693, 1697, 1699, 1709, 1721, 1723, 1733,
+ 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1789, 1801, 1811,
+ 1823, 1831, 1847, 1861, 1867, 1871, 1873, 1877, 1879, 1889,
+ 1901, 1907, 1913, 1931, 1933, 1949, 1951, 1973, 1979, 1987,
+ 1993, 1997, 1999, 2003, 2011, 2017, 2027, 2029, 2039, 2053,
+ 2063, 2069, 2081, 2083, 2087, 2089, 2099, 2111, 2113, 2129,
+ 2131, 2137, 2141, 2143, 2153, 2161, 2179, 2203, 2207, 2213,
+ 2221, 2237, 2239, 2243, 2251, 2267, 2269, 2273, 2281, 2287,
+ 2293, 2297, 2309, 2311, 2333, 2339, 2341, 2347, 2351, 2357,
+ 2371, 2377, 2381, 2383, 2389, 2393, 2399, 2411, 2417, 2423,
+ 2437, 2441, 2447, 2459, 2467, 2473, 2477, 2503, 2521, 2531,
+ 2539, 2543, 2549, 2551, 2557, 2579, 2591, 2593, 2609, 2617,
+ 2621, 2633, 2647, 2657, 2659, 2663, 2671, 2677, 2683, 2687,
+ 2689, 2693, 2699, 2707, 2711, 2713, 2719, 2729, 2731, 2741,
+ 2749, 2753, 2767, 2777, 2789, 2791, 2797, 2801, 2803, 2819,
+ 2833, 2837, 2843, 2851, 2857, 2861, 2879, 2887, 2897, 2903,
+ 2909, 2917, 2927, 2939, 2953, 2957, 2963, 2969, 2971, 2999,
+ 3001, 3011, 3019, 3023, 3037, 3041, 3049, 3061, 3067, 3079,
+ 3083, 3089, 3109, 3119, 3121, 3137, 3163, 3167, 3169, 3181,
+ 3187, 3191, 3203, 3209, 3217, 3221, 3229, 3251, 3253, 3257,
+ 3259, 3271, 3299, 3301, 3307, 3313, 3319, 3323, 3329, 3331,
+ 3343, 3347, 3359, 3361, 3371, 3373, 3389, 3391, 3407, 3413,
+ 3433, 3449, 3457, 3461, 3463, 3467, 3469, 3491, 3499, 3511,
+ 3517, 3527, 3529, 3533, 3539, 3541, 3547, 3557, 3559, 3571,
+ 3581, 3583, 3593, 3607, 3613, 3617, 3623, 3631, 3637, 3643,
+ 3659, 3671, 3673, 3677, 3691, 3697, 3701, 3709, 3719, 3727,
+ 3733, 3739, 3761, 3767, 3769, 3779, 3793, 3797, 3803, 3821,
+ 3823, 3833, 3847, 3851, 3853, 3863, 3877, 3881, 3889, 3907,
+ 3911, 3917, 3919, 3923, 3929, 3931, 3943, 3947, 3967, 3989,
+ 4001, 4003, 4007, 4013, 4019, 4021, 4027, 4049, 4051, 4057,
+ 4073, 4079, 4091, 4093, 4099, 4111, 4127, 4129, 4133, 4139,
+ 4153, 4157, 4159, 4177, 4201, 4211, 4217, 4219, 4229, 4231,
+ 4241, 4243, 4253, 4259, 4261, 4271, 4273, 4283, 4289, 4297,
+ 4327, 4337, 4339, 4349, 4357, 4363, 4373, 4391, 4397, 4409,
+ 4421, 4423, 4441, 4447, 4451, 4457, 4463, 4481, 4483, 4493,
+ 4507, 4513, 4517, 4519, 4523, 4547, 4549, 4561, 4567, 4583,
+ 4591, 4597, 4603, 4621, 4637, 4639, 4643, 4649, 4651, 4657,
+ 4663, 4673, 4679, 4691, 4703, 4721, 4723, 4729, 4733, 4751,
+ 4759, 4783, 4787, 4789, 4793, 4799, 4801, 4813, 4817, 4831,
+ 4861, 4871, 4877, 4889, 4903, 4909, 4919, 4931, 4933, 4937,
+ 4943, 4951, 4957, 4967, 4969, 4973, 4987, 4993, 4999, 5003,
+ 5009, 5011, 5021, 5023, 5039, 5051, 5059, 5077, 5081, 5087,
+ 5099, 5101, 5107, 5113, 5119, 5147, 5153, 5167, 5171, 5179,
+ 5189, 5197, 5209, 5227, 5231, 5233, 5237, 5261, 5273, 5279,
+ 5281, 5297, 5303, 5309, 5323, 5333, 5347, 5351, 5381, 5387,
+ 5393, 5399, 5407, 5413, 5417, 5419, 5431, 5437, 5441, 5443,
+ 5449, 5471, 5477, 5479, 5483, 5501, 5503, 5507, 5519, 5521,
+ 5527, 5531, 5557, 5563, 5569, 5573, 5581, 5591, 5623, 5639,
+ 5641, 5647, 5651, 5653, 5657, 5659, 5669, 5683, 5689, 5693,
+ 5701, 5711, 5717, 5737, 5741, 5743, 5749, 5779, 5783, 5791,
+ 5801, 5807, 5813, 5821, 5827, 5839, 5843, 5849, 5851, 5857,
+ 5861, 5867, 5869, 5879, 5881, 5897, 5903, 5923, 5927, 5939,
+ 5953, 5981, 5987, 6007, 6011, 6029, 6037, 6043, 6047, 6053,
+ 6067, 6073, 6079, 6089, 6091, 6101, 6113, 6121, 6131, 6133,
+ 6143, 6151, 6163, 6173, 6197, 6199, 6203, 6211, 6217, 6221,
+ 6229, 6247, 6257, 6263, 6269, 6271, 6277, 6287, 6299, 6301,
+ 6311, 6317, 6323, 6329, 6337, 6343, 6353, 6359, 6361, 6367,
+ 6373, 6379, 6389, 6397, 6421, 6427, 6449, 6451, 6469, 6473,
+ 6481, 6491, 6521, 6529, 6547, 6551, 6553, 6563, 6569, 6571,
+ 6577, 6581, 6599, 6607, 6619, 6637, 6653, 6659, 6661, 6673,
+ 6679, 6689, 6691, 6701, 6703, 6709, 6719, 6733, 6737, 6761,
+ 6763, 6779, 6781, 6791, 6793, 6803, 6823, 6827, 6829, 6833,
+ 6841, 6857, 6863, 6869, 6871, 6883, 6899, 6907, 6911, 6917,
+ 6947, 6949, 6959, 6961, 6967, 6971, 6977, 6983, 6991, 6997,
+ 7001, 7013, 7019, 7027, 7039, 7043, 7057, 7069, 7079, 7103,
+ 7109, 7121, 7127, 7129, 7151, 7159, 7177, 7187, 7193, 7207,
+ 7211, 7213, 7219, 7229, 7237, 7243, 7247, 7253, 7283, 7297,
+ 7307, 7309, 7321, 7331, 7333, 7349, 7351, 7369, 7393, 7411,
+ 7417, 7433, 7451, 7457, 7459, 7477, 7481, 7487, 7489, 7499,
+ 7507, 7517, 7523, 7529, 7537, 7541, 7547, 7549, 7559, 7561,
+ 7573, 7577, 7583, 7589, 7591, 7603, 7607, 7621, 7639, 7643,
+ 7649, 7669, 7673, 7681, 7687, 7691, 7699, 7703, 7717, 7723,
+ 7727, 7741, 7753, 7757, 7759, 7789, 7793, 7817, 7823, 7829,
+ 7841, 7853, 7867, 7873, 7877, 7879, 7883, 7901, 7907, 7919,
+ 7927, 7933, 7937, 7949, 7951, 7963, 7993, 8009, 8011, 8017,
+ 8039, 8053, 8059, 8069, 8081, 8087, 8089, 8093, 8101, 8111,
+ 8117, 8123, 8147, 8161, 8167, 8171, 8179, 8191, 8209, 8219,
+ 8221, 8231, 8233, 8237, 8243, 8263, 8269, 8273, 8287, 8291,
+ 8293, 8297, 8311, 8317, 8329, 8353, 8363, 8369, 8377, 8387,
+ 8389, 8419, 8423, 8429, 8431, 8443, 8447, 8461, 8467, 8501,
+ 8513, 8521, 8527, 8537, 8539, 8543, 8563, 8573, 8581, 8597,
+ 8599, 8609, 8623, 8627, 8629, 8641, 8647, 8663, 8669, 8677,
+ 8681, 8689, 8693, 8699, 8707, 8713, 8719, 8731, 8737, 8741,
+ 8747, 8753, 8761, 8779, 8783, 8803, 8807, 8819, 8821, 8831,
+ 8837, 8839, 8849, 8861, 8863, 8867, 8887, 8893, 8923, 8929,
+ 8933, 8941, 8951, 8963, 8969, 8971, 8999, 9001, 9007, 9011,
+ 9013, 9029, 9041, 9043, 9049, 9059, 9067, 9091, 9103, 9109,
+ 9127, 9133, 9137, 9151, 9157, 9161, 9173, 9181, 9187, 9199,
+ 9203, 9209, 9221, 9227, 9239, 9241, 9257, 9277, 9281, 9283,
+ 9293, 9311, 9319, 9323, 9337, 9341, 9343, 9349, 9371, 9377,
+ 9391, 9397, 9403, 9413, 9419, 9421, 9431, 9433, 9437, 9439,
+ 9461, 9463, 9467, 9473, 9479, 9491, 9497, 9511, 9521, 9533,
+ 9539, 9547, 9551, 9587, 9601, 9613, 9619, 9623, 9629, 9631,
+ 9643, 9649, 9661, 9677, 9679, 9689, 9697, 9719, 9721, 9733,
+ 9739, 9743, 9749, 9767, 9769, 9781, 9787, 9791, 9803, 9811,
+ 9817, 9829, 9833, 9839, 9851, 9857, 9859, 9871, 9883, 9887,
+ 9901, 9907, 9923, 9929, 9931, 9941, 9949, 9967, 9973, 10007,
+ 10009, 10037, 10039, 10061, 10067, 10069, 10079, 10091, 10093, 10099,
+ 10103, 10111, 10133, 10139, 10141, 10151, 10159, 10163, 10169, 10177,
+ 10181, 10193, 10211, 10223, 10243, 10247, 10253, 10259, 10267, 10271,
+ 10273, 10289, 10301, 10303, 10313, 10321, 10331, 10333, 10337, 10343,
+ 10357, 10369, 10391, 10399, 10427, 10429, 10433, 10453, 10457, 10459,
+ 10463, 10477, 10487, 10499, 10501, 10513, 10529, 10531, 10559, 10567,
+ 10589, 10597, 10601, 10607, 10613, 10627, 10631, 10639, 10651, 10657,
+ 10663, 10667, 10687, 10691, 10709, 10711, 10723, 10729, 10733, 10739,
+ 10753, 10771, 10781, 10789, 10799, 10831, 10837, 10847, 10853, 10859,
+ 10861, 10867, 10883, 10889, 10891, 10903, 10909, 10937, 10939, 10949,
+ 10957, 10973, 10979, 10987, 10993, 11003, 11027, 11047, 11057, 11059,
+ 11069, 11071, 11083, 11087, 11093, 11113, 11117, 11119, 11131, 11149,
+ 11159, 11161, 11171, 11173, 11177, 11197, 11213, 11239, 11243, 11251,
+ 11257, 11261, 11273, 11279, 11287, 11299, 11311, 11317, 11321, 11329,
+ 11351, 11353, 11369, 11383, 11393, 11399, 11411, 11423, 11437, 11443,
+ 11447, 11467, 11471, 11483, 11489, 11491, 11497, 11503, 11519, 11527,
+ 11549, 11551, 11579, 11587, 11593, 11597, 11617, 11621, 11633, 11657,
+ 11677, 11681, 11689, 11699, 11701, 11717, 11719, 11731, 11743, 11777,
+ 11779, 11783, 11789, 11801, 11807, 11813, 11821, 11827, 11831, 11833,
+ 11839, 11863, 11867, 11887, 11897, 11903, 11909, 11923, 11927, 11933,
+ 11939, 11941, 11953, 11959, 11969, 11971, 11981, 11987, 12007, 12011,
+ 12037, 12041, 12043, 12049, 12071, 12073, 12097, 12101, 12107, 12109,
+ 12113, 12119, 12143, 12149, 12157, 12161, 12163, 12197, 12203, 12211,
+ 12227, 12239, 12241, 12251, 12253, 12263, 12269, 12277, 12281, 12289,
+ 12301, 12323, 12329, 12343, 12347, 12373, 12377, 12379, 12391, 12401,
+ 12409, 12413, 12421, 12433, 12437, 12451, 12457, 12473, 12479, 12487,
+ 12491, 12497, 12503, 12511, 12517, 12527, 12539, 12541, 12547, 12553,
+ 12569, 12577, 12583, 12589, 12601, 12611, 12613, 12619, 12637, 12641,
+ 12647, 12653, 12659, 12671, 12689, 12697, 12703, 12713, 12721, 12739,
+ 12743, 12757, 12763, 12781, 12791, 12799, 12809, 12821, 12823, 12829,
+ 12841, 12853, 12889, 12893, 12899, 12907, 12911, 12917, 12919, 12923,
+ 12941, 12953, 12959, 12967, 12973, 12979, 12983, 13001, 13003, 13007,
+ 13009, 13033, 13037, 13043, 13049, 13063, 13093, 13099, 13103, 13109,
+ 13121, 13127, 13147, 13151, 13159, 13163, 13171, 13177, 13183, 13187,
+ 13217, 13219, 13229, 13241, 13249, 13259, 13267, 13291, 13297, 13309,
+ 13313, 13327, 13331, 13337, 13339, 13367, 13381, 13397, 13399, 13411,
+ 13417, 13421, 13441, 13451, 13457, 13463, 13469, 13477, 13487, 13499,
+ 13513, 13523, 13537, 13553, 13567, 13577, 13591, 13597, 13613, 13619,
+ 13627, 13633, 13649, 13669, 13679, 13681, 13687, 13691, 13693, 13697,
+ 13709, 13711, 13721, 13723, 13729, 13751, 13757, 13759, 13763, 13781,
+ 13789, 13799, 13807, 13829, 13831, 13841, 13859, 13873, 13877, 13879,
+ 13883, 13901, 13903, 13907, 13913, 13921, 13931, 13933, 13963, 13967,
+ 13997, 13999, 14009, 14011, 14029, 14033, 14051, 14057, 14071, 14081,
+ 14083, 14087, 14107, 14143, 14149, 14153, 14159, 14173, 14177, 14197,
+ 14207, 14221, 14243, 14249, 14251, 14281, 14293, 14303, 14321, 14323,
+ 14327, 14341, 14347, 14369, 14387, 14389, 14401, 14407, 14411, 14419,
+ 14423, 14431, 14437, 14447, 14449, 14461, 14479, 14489, 14503, 14519,
+ 14533, 14537, 14543, 14549, 14551, 14557, 14561, 14563, 14591, 14593,
+ 14621, 14627, 14629, 14633, 14639, 14653, 14657, 14669, 14683, 14699,
+ 14713, 14717, 14723, 14731, 14737, 14741, 14747, 14753, 14759, 14767,
+ 14771, 14779, 14783, 14797, 14813, 14821, 14827, 14831, 14843, 14851,
+ 14867, 14869, 14879, 14887, 14891, 14897, 14923, 14929, 14939, 14947,
+ 14951, 14957, 14969, 14983, 15013, 15017, 15031, 15053, 15061, 15073,
+ 15077, 15083, 15091, 15101, 15107, 15121, 15131, 15137, 15139, 15149,
+ 15161, 15173, 15187, 15193, 15199, 15217, 15227, 15233, 15241, 15259,
+ 15263, 15269, 15271, 15277, 15287, 15289, 15299, 15307, 15313, 15319,
+ 15329, 15331, 15349, 15359, 15361, 15373, 15377, 15383, 15391, 15401,
+ 15413, 15427, 15439, 15443, 15451, 15461, 15467, 15473, 15493, 15497,
+ 15511, 15527, 15541, 15551, 15559, 15569, 15581, 15583, 15601, 15607,
+ 15619, 15629, 15641, 15643, 15647, 15649, 15661, 15667, 15671, 15679,
+ 15683, 15727, 15731, 15733, 15737, 15739, 15749, 15761, 15767, 15773,
+ 15787, 15791, 15797, 15803, 15809, 15817, 15823, 15859, 15877, 15881,
+ 15887, 15889, 15901, 15907, 15913, 15919, 15923, 15937, 15959, 15971,
+ 15973, 15991, 16001, 16007, 16033, 16057, 16061, 16063, 16067, 16069,
+ 16073, 16087, 16091, 16097, 16103, 16111, 16127, 16139, 16141, 16183,
+ 16187, 16189, 16193, 16217, 16223, 16229, 16231, 16249, 16253, 16267,
+ 16273, 16301, 16319, 16333, 16339, 16349, 16361, 16363, 16369, 16381,
+ 16411, 16417, 16421, 16427, 16433, 16447, 16451, 16453, 16477, 16481,
+ 16487, 16493, 16519, 16529, 16547, 16553, 16561, 16567, 16573, 16603,
+ 16607, 16619, 16631, 16633, 16649, 16651, 16657, 16661, 16673, 16691,
+ 16693, 16699, 16703, 16729, 16741, 16747, 16759, 16763, 16787, 16811,
+ 16823, 16829, 16831, 16843, 16871, 16879, 16883, 16889, 16901, 16903,
+ 16921, 16927, 16931, 16937, 16943, 16963, 16979, 16981, 16987, 16993,
+ 17011, 17021, 17027, 17029, 17033, 17041, 17047, 17053, 17077, 17093,
+ 17099, 17107, 17117, 17123, 17137, 17159, 17167, 17183, 17189, 17191,
+ 17203, 17207, 17209, 17231, 17239, 17257, 17291, 17293, 17299, 17317,
+ 17321, 17327, 17333, 17341, 17351, 17359, 17377, 17383, 17387, 17389,
+ 17393, 17401, 17417, 17419, 17431, 17443, 17449, 17467, 17471, 17477,
+ 17483, 17489, 17491, 17497, 17509, 17519, 17539, 17551, 17569, 17573,
+ 17579, 17581, 17597, 17599, 17609, 17623, 17627, 17657, 17659, 17669,
+ 17681, 17683, 17707, 17713, 17729, 17737, 17747, 17749, 17761, 17783,
+ 17789, 17791, 17807, 17827, 17837, 17839, 17851, 17863, 17881, 17891,
+ 17903, 17909, 17911, 17921, 17923, 17929, 17939, 17957, 17959, 17971,
+ 17977, 17981, 17987, 17989, 18013, 18041, 18043, 18047, 18049, 18059,
+ 18061, 18077, 18089, 18097, 18119, 18121, 18127, 18131, 18133, 18143,
+ 18149, 18169, 18181, 18191, 18199, 18211, 18217, 18223, 18229, 18233,
+ 18251, 18253, 18257, 18269, 18287, 18289, 18301, 18307, 18311, 18313,
+ 18329, 18341, 18353, 18367, 18371, 18379, 18397, 18401, 18413, 18427,
+ 18433, 18439, 18443, 18451, 18457, 18461, 18481, 18493, 18503, 18517,
+ 18521, 18523, 18539, 18541, 18553, 18583, 18587, 18593, 18617, 18637,
+ 18661, 18671, 18679, 18691, 18701, 18713, 18719, 18731, 18743, 18749,
+ 18757, 18773, 18787, 18793, 18797, 18803, 18839, 18859, 18869, 18899,
+ 18911, 18913, 18917, 18919, 18947, 18959, 18973, 18979, 19001, 19009,
+ 19013, 19031, 19037, 19051, 19069, 19073, 19079, 19081, 19087, 19121,
+ 19139, 19141, 19157, 19163, 19181, 19183, 19207, 19211, 19213, 19219,
+ 19231, 19237, 19249, 19259, 19267, 19273, 19289, 19301, 19309, 19319,
+ 19333, 19373, 19379, 19381, 19387, 19391, 19403, 19417, 19421, 19423,
+ 19427, 19429, 19433, 19441, 19447, 19457, 19463, 19469, 19471, 19477,
+ 19483, 19489, 19501, 19507, 19531, 19541, 19543, 19553, 19559, 19571,
+ 19577, 19583, 19597, 19603, 19609, 19661, 19681, 19687, 19697, 19699,
+ 19709, 19717, 19727, 19739, 19751, 19753, 19759, 19763, 19777, 19793,
+ 19801, 19813, 19819, 19841, 19843, 19853, 19861, 19867, 19889, 19891,
+ 19913, 19919, 19927, 19937, 19949, 19961, 19963, 19973, 19979, 19991,
+ 19993, 19997, 20011, 20021, 20023, 20029, 20047, 20051, 20063, 20071,
+ 20089, 20101, 20107, 20113, 20117, 20123, 20129, 20143, 20147, 20149,
+ 20161, 20173, 20177, 20183, 20201, 20219, 20231, 20233, 20249, 20261,
+ 20269, 20287, 20297, 20323, 20327, 20333, 20341, 20347, 20353, 20357,
+ 20359, 20369, 20389, 20393, 20399, 20407, 20411, 20431, 20441, 20443,
+ 20477, 20479, 20483, 20507, 20509, 20521, 20533, 20543, 20549, 20551,
+ 20563, 20593, 20599, 20611, 20627, 20639, 20641, 20663, 20681, 20693,
+ 20707, 20717, 20719, 20731, 20743, 20747, 20749, 20753, 20759, 20771,
+ 20773, 20789, 20807, 20809, 20849, 20857, 20873, 20879, 20887, 20897,
+ 20899, 20903, 20921, 20929, 20939, 20947, 20959, 20963, 20981, 20983,
+ 21001, 21011, 21013, 21017, 21019, 21023, 21031, 21059, 21061, 21067,
+ 21089, 21101, 21107, 21121, 21139, 21143, 21149, 21157, 21163, 21169,
+ 21179, 21187, 21191, 21193, 21211, 21221, 21227, 21247, 21269, 21277,
+ 21283, 21313, 21317, 21319, 21323, 21341, 21347, 21377, 21379, 21383,
+ 21391, 21397, 21401, 21407, 21419, 21433, 21467, 21481, 21487, 21491,
+ 21493, 21499, 21503, 21517, 21521, 21523, 21529, 21557, 21559, 21563,
+ 21569, 21577, 21587, 21589, 21599, 21601, 21611, 21613, 21617, 21647,
+ 21649, 21661, 21673, 21683, 21701, 21713, 21727, 21737, 21739, 21751,
+ 21757, 21767, 21773, 21787, 21799, 21803, 21817, 21821, 21839, 21841,
+ 21851, 21859, 21863, 21871, 21881, 21893, 21911, 21929, 21937, 21943,
+ 21961, 21977, 21991, 21997, 22003, 22013, 22027, 22031, 22037, 22039,
+ 22051, 22063, 22067, 22073, 22079, 22091, 22093, 22109, 22111, 22123,
+ 22129, 22133, 22147, 22153, 22157, 22159, 22171, 22189, 22193, 22229,
+ 22247, 22259, 22271, 22273, 22277, 22279, 22283, 22291, 22303, 22307,
+ 22343, 22349, 22367, 22369, 22381, 22391, 22397, 22409, 22433, 22441,
+ 22447, 22453, 22469, 22481, 22483, 22501, 22511, 22531, 22541, 22543,
+ 22549, 22567, 22571, 22573, 22613, 22619, 22621, 22637, 22639, 22643,
+ 22651, 22669, 22679, 22691, 22697, 22699, 22709, 22717, 22721, 22727,
+ 22739, 22741, 22751, 22769, 22777, 22783, 22787, 22807, 22811, 22817,
+ 22853, 22859, 22861, 22871, 22877, 22901, 22907, 22921, 22937, 22943,
+ 22961, 22963, 22973, 22993, 23003, 23011, 23017, 23021, 23027, 23029,
+ 23039, 23041, 23053, 23057, 23059, 23063, 23071, 23081, 23087, 23099,
+ 23117, 23131, 23143, 23159, 23167, 23173, 23189, 23197, 23201, 23203,
+ 23209, 23227, 23251, 23269, 23279, 23291, 23293, 23297, 23311, 23321,
+ 23327, 23333, 23339, 23357, 23369, 23371, 23399, 23417, 23431, 23447,
+ 23459, 23473, 23497, 23509, 23531, 23537, 23539, 23549, 23557, 23561,
+ 23563, 23567, 23581, 23593, 23599, 23603, 23609, 23623, 23627, 23629,
+ 23633, 23663, 23669, 23671, 23677, 23687, 23689, 23719, 23741, 23743,
+ 23747, 23753, 23761, 23767, 23773, 23789, 23801, 23813, 23819, 23827,
+ 23831, 23833, 23857, 23869, 23873, 23879, 23887, 23893, 23899, 23909,
+ 23911, 23917, 23929, 23957, 23971, 23977, 23981, 23993, 24001, 24007,
+ 24019, 24023, 24029, 24043, 24049, 24061, 24071, 24077, 24083, 24091,
+ 24097, 24103, 24107, 24109, 24113, 24121, 24133, 24137, 24151, 24169,
+ 24179, 24181, 24197, 24203, 24223, 24229, 24239, 24247, 24251, 24281,
+ 24317, 24329, 24337, 24359, 24371, 24373, 24379, 24391, 24407, 24413,
+ 24419, 24421, 24439, 24443, 24469, 24473, 24481, 24499, 24509, 24517,
+ 24527, 24533, 24547, 24551, 24571, 24593, 24611, 24623, 24631, 24659,
+ 24671, 24677, 24683, 24691, 24697, 24709, 24733, 24749, 24763, 24767,
+ 24781, 24793, 24799, 24809, 24821, 24841, 24847, 24851, 24859, 24877,
+ 24889, 24907, 24917, 24919, 24923, 24943, 24953, 24967, 24971, 24977,
+ 24979, 24989, 25013, 25031, 25033, 25037, 25057, 25073, 25087, 25097,
+ 25111, 25117, 25121, 25127, 25147, 25153, 25163, 25169, 25171, 25183,
+ 25189, 25219, 25229, 25237, 25243, 25247, 25253, 25261, 25301, 25303,
+ 25307, 25309, 25321, 25339, 25343, 25349, 25357, 25367, 25373, 25391,
+ 25409, 25411, 25423, 25439, 25447, 25453, 25457, 25463, 25469, 25471,
+ 25523, 25537, 25541, 25561, 25577, 25579, 25583, 25589, 25601, 25603,
+ 25609, 25621, 25633, 25639, 25643, 25657, 25667, 25673, 25679, 25693,
+ 25703, 25717, 25733, 25741, 25747, 25759, 25763, 25771, 25793, 25799,
+ 25801, 25819, 25841, 25847, 25849, 25867, 25873, 25889, 25903, 25913,
+ 25919, 25931, 25933, 25939, 25943, 25951, 25969, 25981, 25997, 25999,
+ 26003, 26017, 26021, 26029, 26041, 26053, 26083, 26099, 26107, 26111,
+ 26113, 26119, 26141, 26153, 26161, 26171, 26177, 26183, 26189, 26203,
+ 26209, 26227, 26237, 26249, 26251, 26261, 26263, 26267, 26293, 26297,
+ 26309, 26317, 26321, 26339, 26347, 26357, 26371, 26387, 26393, 26399,
+ 26407, 26417, 26423, 26431, 26437, 26449, 26459, 26479, 26489, 26497,
+ 26501, 26513, 26539, 26557, 26561, 26573, 26591, 26597, 26627, 26633,
+ 26641, 26647, 26669, 26681, 26683, 26687, 26693, 26699, 26701, 26711,
+ 26713, 26717, 26723, 26729, 26731, 26737, 26759, 26777, 26783, 26801,
+ 26813, 26821, 26833, 26839, 26849, 26861, 26863, 26879, 26881, 26891,
+ 26893, 26903, 26921, 26927, 26947, 26951, 26953, 26959, 26981, 26987,
+ 26993, 27011, 27017, 27031, 27043, 27059, 27061, 27067, 27073, 27077,
+ 27091, 27103, 27107, 27109, 27127, 27143, 27179, 27191, 27197, 27211,
+ 27239, 27241, 27253, 27259, 27271, 27277, 27281, 27283, 27299, 27329,
+ 27337, 27361, 27367, 27397, 27407, 27409, 27427, 27431, 27437, 27449,
+ 27457, 27479, 27481, 27487, 27509, 27527, 27529, 27539, 27541, 27551,
+ 27581, 27583, 27611, 27617, 27631, 27647, 27653, 27673, 27689, 27691,
+ 27697, 27701, 27733, 27737, 27739, 27743, 27749, 27751, 27763, 27767,
+ 27773, 27779, 27791, 27793, 27799, 27803, 27809, 27817, 27823, 27827,
+ 27847, 27851, 27883, 27893, 27901, 27917, 27919, 27941, 27943, 27947,
+ 27953, 27961, 27967, 27983, 27997, 28001, 28019, 28027, 28031, 28051,
+ 28057, 28069, 28081, 28087, 28097, 28099, 28109, 28111, 28123, 28151,
+ 28163, 28181, 28183, 28201, 28211, 28219, 28229, 28277, 28279, 28283,
+ 28289, 28297, 28307, 28309, 28319, 28349, 28351, 28387, 28393, 28403,
+ 28409, 28411, 28429, 28433, 28439, 28447, 28463, 28477, 28493, 28499,
+ 28513, 28517, 28537, 28541, 28547, 28549, 28559, 28571, 28573, 28579,
+ 28591, 28597, 28603, 28607, 28619, 28621, 28627, 28631, 28643, 28649,
+ 28657, 28661, 28663, 28669, 28687, 28697, 28703, 28711, 28723, 28729,
+ 28751, 28753, 28759, 28771, 28789, 28793, 28807, 28813, 28817, 28837,
+ 28843, 28859, 28867, 28871, 28879, 28901, 28909, 28921, 28927, 28933,
+ 28949, 28961, 28979, 29009, 29017, 29021, 29023, 29027, 29033, 29059,
+ 29063, 29077, 29101, 29123, 29129, 29131, 29137, 29147, 29153, 29167,
+ 29173, 29179, 29191, 29201, 29207, 29209, 29221, 29231, 29243, 29251,
+ 29269, 29287, 29297, 29303, 29311, 29327, 29333, 29339, 29347, 29363,
+ 29383, 29387, 29389, 29399, 29401, 29411, 29423, 29429, 29437, 29443,
+ 29453, 29473, 29483, 29501, 29527, 29531, 29537, 29567, 29569, 29573,
+ 29581, 29587, 29599, 29611, 29629, 29633, 29641, 29663, 29669, 29671,
+ 29683, 29717, 29723, 29741, 29753, 29759, 29761, 29789, 29803, 29819,
+ 29833, 29837, 29851, 29863, 29867, 29873, 29879, 29881, 29917, 29921,
+ 29927, 29947, 29959, 29983, 29989, 30011, 30013, 30029, 30047, 30059,
+ 30071, 30089, 30091, 30097, 30103, 30109, 30113, 30119, 30133, 30137,
+ 30139, 30161, 30169, 30181, 30187, 30197, 30203, 30211, 30223, 30241,
+ 30253, 30259, 30269, 30271, 30293, 30307, 30313, 30319, 30323, 30341,
+ 30347, 30367, 30389, 30391, 30403, 30427, 30431, 30449, 30467, 30469,
+ 30491, 30493, 30497, 30509, 30517, 30529, 30539, 30553, 30557, 30559,
+ 30577, 30593, 30631, 30637, 30643, 30649, 30661, 30671, 30677, 30689,
+ 30697, 30703, 30707, 30713, 30727, 30757, 30763, 30773, 30781, 30803,
+ 30809, 30817, 30829, 30839, 30841, 30851, 30853, 30859, 30869, 30871,
+ 30881, 30893, 30911, 30931, 30937, 30941, 30949, 30971, 30977, 30983,
+ 31013, 31019, 31033, 31039, 31051, 31063, 31069, 31079, 31081, 31091,
+ 31121, 31123, 31139, 31147, 31151, 31153, 31159, 31177, 31181, 31183,
+ 31189, 31193, 31219, 31223, 31231, 31237, 31247, 31249, 31253, 31259,
+ 31267, 31271, 31277, 31307, 31319, 31321, 31327, 31333, 31337, 31357,
+ 31379, 31387, 31391, 31393, 31397, 31469, 31477, 31481, 31489, 31511,
+ 31513, 31517, 31531, 31541, 31543, 31547, 31567, 31573, 31583, 31601,
+ 31607, 31627, 31643, 31649, 31657, 31663, 31667, 31687, 31699, 31721,
+ 31723, 31727, 31729, 31741, 31751, 31769, 31771, 31793, 31799, 31817,
+ 31847, 31849, 31859, 31873, 31883, 31891, 31907, 31957, 31963, 31973,
+ 31981, 31991, 32003, 32009, 32027, 32029, 32051, 32057, 32059, 32063,
+ 32069, 32077, 32083, 32089, 32099, 32117, 32119, 32141, 32143, 32159,
+ 32173, 32183, 32189, 32191, 32203, 32213, 32233, 32237, 32251, 32257,
+ 32261, 32297, 32299, 32303, 32309, 32321, 32323, 32327, 32341, 32353,
+ 32359, 32363, 32369, 32371, 32377, 32381, 32401, 32411, 32413, 32423,
+ 32429, 32441, 32443, 32467, 32479, 32491, 32497, 32503, 32507, 32531,
+ 32533, 32537, 32561, 32563, 32569, 32573, 32579, 32587, 32603, 32609,
+ 32611, 32621, 32633, 32647, 32653, 32687, 32693, 32707, 32713, 32717,
+ 32719, 32749, 32771, 32779, 32783, 32789, 32797, 32801, 32803, 32831,
+ 32833, 32839, 32843, 32869, 32887, 32909, 32911, 32917, 32933, 32939,
+ 32941, 32957, 32969, 32971, 32983, 32987, 32993, 32999, 33013, 33023,
+ 33029, 33037, 33049, 33053, 33071, 33073, 33083, 33091, 33107, 33113,
+ 33119, 33149, 33151, 33161, 33179, 33181, 33191, 33199, 33203, 33211,
+ 33223, 33247, 33287, 33289, 33301, 33311, 33317, 33329, 33331, 33343,
+ 33347, 33349, 33353, 33359, 33377, 33391, 33403, 33409, 33413, 33427,
+ 33457, 33461, 33469, 33479, 33487, 33493, 33503, 33521, 33529, 33533,
+ 33547, 33563, 33569, 33577, 33581, 33587, 33589, 33599, 33601, 33613,
+ 33617, 33619, 33623, 33629, 33637, 33641, 33647, 33679, 33703, 33713,
+ 33721, 33739, 33749, 33751, 33757, 33767, 33769, 33773, 33791, 33797,
+ 33809, 33811, 33827, 33829, 33851, 33857, 33863, 33871, 33889, 33893,
+ 33911, 33923, 33931, 33937, 33941, 33961, 33967, 33997, 34019, 34031,
+ 34033, 34039, 34057, 34061, 34123, 34127, 34129, 34141, 34147, 34157,
+ 34159, 34171, 34183, 34211, 34213, 34217, 34231, 34253, 34259, 34261,
+ 34267, 34273, 34283, 34297, 34301, 34303, 34313, 34319, 34327, 34337,
+ 34351, 34361, 34367, 34369, 34381, 34403, 34421, 34429, 34439, 34457,
+ 34469, 34471, 34483, 34487, 34499, 34501, 34511, 34513, 34519, 34537,
+ 34543, 34549, 34583, 34589, 34591, 34603, 34607, 34613, 34631, 34649,
+ 34651, 34667, 34673, 34679, 34687, 34693, 34703, 34721, 34729, 34739,
+ 34747, 34757, 34759, 34763, 34781, 34807, 34819, 34841, 34843, 34847,
+ 34849, 34871, 34877, 34883, 34897, 34913, 34919, 34939, 34949, 34961,
+ 34963, 34981, 35023, 35027, 35051, 35053, 35059, 35069, 35081, 35083,
+ 35089, 35099, 35107, 35111, 35117, 35129, 35141, 35149, 35153, 35159,
+ 35171, 35201, 35221, 35227, 35251, 35257, 35267, 35279, 35281, 35291,
+ 35311, 35317, 35323, 35327, 35339, 35353, 35363, 35381, 35393, 35401,
+ 35407, 35419, 35423, 35437, 35447, 35449, 35461, 35491, 35507, 35509,
+ 35521, 35527, 35531, 35533, 35537, 35543, 35569, 35573, 35591, 35593,
+ 35597, 35603, 35617, 35671, 35677, 35729, 35731, 35747, 35753, 35759,
+ 35771, 35797, 35801, 35803, 35809, 35831, 35837, 35839, 35851, 35863,
+ 35869, 35879, 35897, 35899, 35911, 35923, 35933, 35951, 35963, 35969,
+ 35977, 35983, 35993, 35999, 36007, 36011, 36013, 36017, 36037, 36061,
+ 36067, 36073, 36083, 36097, 36107, 36109, 36131, 36137, 36151, 36161,
+ 36187, 36191, 36209, 36217, 36229, 36241, 36251, 36263, 36269, 36277,
+ 36293, 36299, 36307, 36313, 36319, 36341, 36343, 36353, 36373, 36383,
+ 36389, 36433, 36451, 36457, 36467, 36469, 36473, 36479, 36493, 36497,
+ 36523, 36527, 36529, 36541, 36551, 36559, 36563, 36571, 36583, 36587,
+ 36599, 36607, 36629, 36637, 36643, 36653, 36671, 36677, 36683, 36691,
+ 36697, 36709, 36713, 36721, 36739, 36749, 36761, 36767, 36779, 36781,
+ 36787, 36791, 36793, 36809, 36821, 36833, 36847, 36857, 36871, 36877,
+ 36887, 36899, 36901, 36913, 36919, 36923, 36929, 36931, 36943, 36947,
+ 36973, 36979, 36997, 37003, 37013, 37019, 37021, 37039, 37049, 37057,
+ 37061, 37087, 37097, 37117, 37123, 37139, 37159, 37171, 37181, 37189,
+ 37199, 37201, 37217, 37223, 37243, 37253, 37273, 37277, 37307, 37309,
+ 37313, 37321, 37337, 37339, 37357, 37361, 37363, 37369, 37379, 37397,
+ 37409, 37423, 37441, 37447, 37463, 37483, 37489, 37493, 37501, 37507,
+ 37511, 37517, 37529, 37537, 37547, 37549, 37561, 37567, 37571, 37573,
+ 37579, 37589, 37591, 37607, 37619, 37633, 37643, 37649, 37657, 37663,
+ 37691, 37693, 37699, 37717, 37747, 37781, 37783, 37799, 37811, 37813,
+ 37831, 37847, 37853, 37861, 37871, 37879, 37889, 37897, 37907, 37951,
+ 37957, 37963, 37967, 37987, 37991, 37993, 37997, 38011, 38039, 38047,
+ 38053, 38069, 38083, 38113, 38119, 38149, 38153, 38167, 38177, 38183,
+ 38189, 38197, 38201, 38219, 38231, 38237, 38239, 38261, 38273, 38281,
+ 38287, 38299, 38303, 38317, 38321, 38327, 38329, 38333, 38351, 38371,
+ 38377, 38393, 38431, 38447, 38449, 38453, 38459, 38461, 38501, 38543,
+ 38557, 38561, 38567, 38569, 38593, 38603, 38609, 38611, 38629, 38639,
+ 38651, 38653, 38669, 38671, 38677, 38693, 38699, 38707, 38711, 38713,
+ 38723, 38729, 38737, 38747, 38749, 38767, 38783, 38791, 38803, 38821,
+ 38833, 38839, 38851, 38861, 38867, 38873, 38891, 38903, 38917, 38921,
+ 38923, 38933, 38953, 38959, 38971, 38977, 38993, 39019, 39023, 39041,
+ 39043, 39047, 39079, 39089, 39097, 39103, 39107, 39113, 39119, 39133,
+ 39139, 39157, 39161, 39163, 39181, 39191, 39199, 39209, 39217, 39227,
+ 39229, 39233, 39239, 39241, 39251, 39293, 39301, 39313, 39317, 39323,
+ 39341, 39343, 39359, 39367, 39371, 39373, 39383, 39397, 39409, 39419,
+ 39439, 39443, 39451, 39461, 39499, 39503, 39509, 39511, 39521, 39541,
+ 39551, 39563, 39569, 39581, 39607, 39619, 39623, 39631, 39659, 39667,
+ 39671, 39679, 39703, 39709, 39719, 39727, 39733, 39749, 39761, 39769,
+ 39779, 39791, 39799, 39821, 39827, 39829, 39839, 39841, 39847, 39857,
+ 39863, 39869, 39877, 39883, 39887, 39901, 39929, 39937, 39953, 39971,
+ 39979, 39983, 39989, 40009, 40013, 40031, 40037, 40039, 40063, 40087,
+ 40093, 40099, 40111, 40123, 40127, 40129, 40151, 40153, 40163, 40169,
+ 40177, 40189, 40193, 40213, 40231, 40237, 40241, 40253, 40277, 40283,
+ 40289, 40343, 40351, 40357, 40361, 40387, 40423, 40427, 40429, 40433,
+ 40459, 40471, 40483, 40487, 40493, 40499, 40507, 40519, 40529, 40531,
+ 40543, 40559, 40577, 40583, 40591, 40597, 40609, 40627, 40637, 40639,
+ 40693, 40697, 40699, 40709, 40739, 40751, 40759, 40763, 40771, 40787,
+ 40801, 40813, 40819, 40823, 40829, 40841, 40847, 40849, 40853, 40867,
+ 40879, 40883, 40897, 40903, 40927, 40933, 40939, 40949, 40961, 40973,
+ 40993, 41011, 41017, 41023, 41039, 41047, 41051, 41057, 41077, 41081,
+ 41113, 41117, 41131, 41141, 41143, 41149, 41161, 41177, 41179, 41183,
+ 41189, 41201, 41203, 41213, 41221, 41227, 41231, 41233, 41243, 41257,
+ 41263, 41269, 41281, 41299, 41333, 41341, 41351, 41357, 41381, 41387,
+ 41389, 41399, 41411, 41413, 41443, 41453, 41467, 41479, 41491, 41507,
+ 41513, 41519, 41521, 41539, 41543, 41549, 41579, 41593, 41597, 41603,
+ 41609, 41611, 41617, 41621, 41627, 41641, 41647, 41651, 41659, 41669,
+ 41681, 41687, 41719, 41729, 41737, 41759, 41761, 41771, 41777, 41801,
+ 41809, 41813, 41843, 41849, 41851, 41863, 41879, 41887, 41893, 41897,
+ 41903, 41911, 41927, 41941, 41947, 41953, 41957, 41959, 41969, 41981,
+ 41983, 41999, 42013, 42017, 42019, 42023, 42043, 42061, 42071, 42073,
+ 42083, 42089, 42101, 42131, 42139, 42157, 42169, 42179, 42181, 42187,
+ 42193, 42197, 42209, 42221, 42223, 42227, 42239, 42257, 42281, 42283,
+ 42293, 42299, 42307, 42323, 42331, 42337, 42349, 42359, 42373, 42379,
+ 42391, 42397, 42403, 42407, 42409, 42433, 42437, 42443, 42451, 42457,
+ 42461, 42463, 42467, 42473, 42487, 42491, 42499, 42509, 42533, 42557,
+ 42569, 42571, 42577, 42589, 42611, 42641, 42643, 42649, 42667, 42677,
+ 42683, 42689, 42697, 42701, 42703, 42709, 42719, 42727, 42737, 42743,
+ 42751, 42767, 42773, 42787, 42793, 42797, 42821, 42829, 42839, 42841,
+ 42853, 42859, 42863, 42899, 42901, 42923, 42929, 42937, 42943, 42953,
+ 42961, 42967, 42979, 42989, 43003, 43013, 43019, 43037, 43049, 43051,
+ 43063, 43067, 43093, 43103, 43117, 43133, 43151, 43159, 43177, 43189,
+ 43201, 43207, 43223, 43237, 43261, 43271, 43283, 43291, 43313, 43319,
+ 43321, 43331, 43391, 43397, 43399, 43403, 43411, 43427, 43441, 43451,
+ 43457, 43481, 43487, 43499, 43517, 43541, 43543, 43573, 43577, 43579,
+ 43591, 43597, 43607, 43609, 43613, 43627, 43633, 43649, 43651, 43661,
+ 43669, 43691, 43711, 43717, 43721, 43753, 43759, 43777, 43781, 43783,
+ 43787, 43789, 43793, 43801, 43853, 43867, 43889, 43891, 43913, 43933,
+ 43943, 43951, 43961, 43963, 43969, 43973, 43987, 43991, 43997, 44017,
+ 44021, 44027, 44029, 44041, 44053, 44059, 44071, 44087, 44089, 44101,
+ 44111, 44119, 44123, 44129, 44131, 44159, 44171, 44179, 44189, 44201,
+ 44203, 44207, 44221, 44249, 44257, 44263, 44267, 44269, 44273, 44279,
+ 44281, 44293, 44351, 44357, 44371, 44381, 44383, 44389, 44417, 44449,
+ 44453, 44483, 44491, 44497, 44501, 44507, 44519, 44531, 44533, 44537,
+ 44543, 44549, 44563, 44579, 44587, 44617, 44621, 44623, 44633, 44641,
+ 44647, 44651, 44657, 44683, 44687, 44699, 44701, 44711, 44729, 44741,
+ 44753, 44771, 44773, 44777, 44789, 44797, 44809, 44819, 44839, 44843,
+ 44851, 44867, 44879, 44887, 44893, 44909, 44917, 44927, 44939, 44953,
+ 44959, 44963, 44971, 44983, 44987, 45007, 45013, 45053, 45061, 45077,
+ 45083, 45119, 45121, 45127, 45131, 45137, 45139, 45161, 45179, 45181,
+ 45191, 45197, 45233, 45247, 45259, 45263, 45281, 45289, 45293, 45307,
+ 45317, 45319, 45329, 45337, 45341, 45343, 45361, 45377, 45389, 45403,
+ 45413, 45427, 45433, 45439, 45481, 45491, 45497, 45503, 45523, 45533,
+ 45541, 45553, 45557, 45569, 45587, 45589, 45599, 45613, 45631, 45641,
+ 45659, 45667, 45673, 45677, 45691, 45697, 45707, 45737, 45751, 45757,
+ 45763, 45767, 45779, 45817, 45821, 45823, 45827, 45833, 45841, 45853,
+ 45863, 45869, 45887, 45893, 45943, 45949, 45953, 45959, 45971, 45979,
+ 45989, 46021, 46027, 46049, 46051, 46061, 46073, 46091, 46093, 46099,
+ 46103, 46133, 46141, 46147, 46153, 46171, 46181, 46183, 46187, 46199,
+ 46219, 46229, 46237, 46261, 46271, 46273, 46279, 46301, 46307, 46309,
+ 46327, 46337, 46349, 46351, 46381, 46399, 46411, 46439, 46441, 46447,
+ 46451, 46457, 46471, 46477, 46489, 46499, 46507, 46511, 46523, 46549,
+ 46559, 46567, 46573, 46589, 46591, 46601, 46619, 46633, 46639, 46643,
+ 46649, 46663, 46679, 46681, 46687, 46691, 46703, 46723, 46727, 46747,
+ 46751, 46757, 46769, 46771, 46807, 46811, 46817, 46819, 46829, 46831,
+ 46853, 46861, 46867, 46877, 46889, 46901, 46919, 46933, 46957, 46993,
+ 46997, 47017, 47041, 47051, 47057, 47059, 47087, 47093, 47111, 47119,
+ 47123, 47129, 47137, 47143, 47147, 47149, 47161, 47189, 47207, 47221,
+ 47237, 47251, 47269, 47279, 47287, 47293, 47297, 47303, 47309, 47317,
+ 47339, 47351, 47353, 47363, 47381, 47387, 47389, 47407, 47417, 47419,
+ 47431, 47441, 47459, 47491, 47497, 47501, 47507, 47513, 47521, 47527,
+ 47533, 47543, 47563, 47569, 47581, 47591, 47599, 47609, 47623, 47629,
+ 47639, 47653, 47657, 47659, 47681, 47699, 47701, 47711, 47713, 47717,
+ 47737, 47741, 47743, 47777, 47779, 47791, 47797, 47807, 47809, 47819,
+ 47837, 47843, 47857, 47869, 47881, 47903, 47911, 47917, 47933, 47939,
+ 47947, 47951, 47963, 47969, 47977, 47981, 48017, 48023, 48029, 48049,
+ 48073, 48079, 48091, 48109, 48119, 48121, 48131, 48157, 48163, 48179,
+ 48187, 48193, 48197, 48221, 48239, 48247, 48259, 48271, 48281, 48299,
+ 48311, 48313, 48337, 48341, 48353, 48371, 48383, 48397, 48407, 48409,
+ 48413, 48437, 48449, 48463, 48473, 48479, 48481, 48487, 48491, 48497,
+ 48523, 48527, 48533, 48539, 48541, 48563, 48571, 48589, 48593, 48611,
+ 48619, 48623, 48647, 48649, 48661, 48673, 48677, 48679, 48731, 48733,
+ 48751, 48757, 48761, 48767, 48779, 48781, 48787, 48799, 48809, 48817,
+ 48821, 48823, 48847, 48857, 48859, 48869, 48871, 48883, 48889, 48907,
+ 48947, 48953, 48973, 48989, 48991, 49003, 49009, 49019, 49031, 49033,
+ 49037, 49043, 49057, 49069, 49081, 49103, 49109, 49117, 49121, 49123,
+ 49139, 49157, 49169, 49171, 49177, 49193, 49199, 49201, 49207, 49211,
+ 49223, 49253, 49261, 49277, 49279, 49297, 49307, 49331, 49333, 49339,
+ 49363, 49367, 49369, 49391, 49393, 49409, 49411, 49417, 49429, 49433,
+ 49451, 49459, 49463, 49477, 49481, 49499, 49523, 49529, 49531, 49537,
+ 49547, 49549, 49559, 49597, 49603, 49613, 49627, 49633, 49639, 49663,
+ 49667, 49669, 49681, 49697, 49711, 49727, 49739, 49741, 49747, 49757,
+ 49783, 49787, 49789, 49801, 49807, 49811, 49823, 49831, 49843, 49853,
+ 49871, 49877, 49891, 49919, 49921, 49927, 49937, 49939, 49943, 49957,
+ 49991, 49993, 49999, 50021, 50023, 50033, 50047, 50051, 50053, 50069,
+ 50077, 50087, 50093, 50101, 50111, 50119, 50123, 50129, 50131, 50147,
+ 50153, 50159, 50177, 50207, 50221, 50227, 50231, 50261, 50263, 50273,
+ 50287, 50291, 50311, 50321, 50329, 50333, 50341, 50359, 50363, 50377,
+ 50383, 50387, 50411, 50417, 50423, 50441, 50459, 50461, 50497, 50503,
+ 50513, 50527, 50539, 50543, 50549, 50551, 50581, 50587, 50591, 50593,
+ 50599, 50627, 50647, 50651, 50671, 50683, 50707, 50723, 50741, 50753,
+ 50767, 50773, 50777, 50789, 50821, 50833, 50839, 50849, 50857, 50867,
+ 50873, 50891, 50893, 50909, 50923, 50929, 50951, 50957, 50969, 50971,
+ 50989, 50993, 51001, 51031, 51043, 51047, 51059, 51061, 51071, 51109,
+ 51131, 51133, 51137, 51151, 51157, 51169, 51193, 51197, 51199, 51203,
+ 51217, 51229, 51239, 51241, 51257, 51263, 51283, 51287, 51307, 51329,
+ 51341, 51343, 51347, 51349, 51361, 51383, 51407, 51413, 51419, 51421,
+ 51427, 51431, 51437, 51439, 51449, 51461, 51473, 51479, 51481, 51487,
+ 51503, 51511, 51517, 51521, 51539, 51551, 51563, 51577, 51581, 51593,
+ 51599, 51607, 51613, 51631, 51637, 51647, 51659, 51673, 51679, 51683,
+ 51691, 51713, 51719, 51721, 51749, 51767, 51769, 51787, 51797, 51803,
+ 51817, 51827, 51829, 51839, 51853, 51859, 51869, 51871, 51893, 51899,
+ 51907, 51913, 51929, 51941, 51949, 51971, 51973, 51977, 51991, 52009,
+ 52021, 52027, 52051, 52057, 52067, 52069, 52081, 52103, 52121, 52127,
+ 52147, 52153, 52163, 52177, 52181, 52183, 52189, 52201, 52223, 52237,
+ 52249, 52253, 52259, 52267, 52289, 52291, 52301, 52313, 52321, 52361,
+ 52363, 52369, 52379, 52387, 52391, 52433, 52453, 52457, 52489, 52501,
+ 52511, 52517, 52529, 52541, 52543, 52553, 52561, 52567, 52571, 52579,
+ 52583, 52609, 52627, 52631, 52639, 52667, 52673, 52691, 52697, 52709,
+ 52711, 52721, 52727, 52733, 52747, 52757, 52769, 52783, 52807, 52813,
+ 52817, 52837, 52859, 52861, 52879, 52883, 52889, 52901, 52903, 52919,
+ 52937, 52951, 52957, 52963, 52967, 52973, 52981, 52999, 53003, 53017,
+ 53047, 53051, 53069, 53077, 53087, 53089, 53093, 53101, 53113, 53117,
+ 53129, 53147, 53149, 53161, 53171, 53173, 53189, 53197, 53201, 53231,
+ 53233, 53239, 53267, 53269, 53279, 53281, 53299, 53309, 53323, 53327,
+ 53353, 53359, 53377, 53381, 53401, 53407, 53411, 53419, 53437, 53441,
+ 53453, 53479, 53503, 53507, 53527, 53549, 53551, 53569, 53591, 53593,
+ 53597, 53609, 53611, 53617, 53623, 53629, 53633, 53639, 53653, 53657,
+ 53681, 53693, 53699, 53717, 53719, 53731, 53759, 53773, 53777, 53783,
+ 53791, 53813, 53819, 53831, 53849, 53857, 53861, 53881, 53887, 53891,
+ 53897, 53899, 53917, 53923, 53927, 53939, 53951, 53959, 53987, 53993,
+ 54001, 54011, 54013, 54037, 54049, 54059, 54083, 54091, 54101, 54121,
+ 54133, 54139, 54151, 54163, 54167, 54181, 54193, 54217, 54251, 54269,
+ 54277, 54287, 54293, 54311, 54319, 54323, 54331, 54347, 54361, 54367,
+ 54371, 54377, 54401, 54403, 54409, 54413, 54419, 54421, 54437, 54443,
+ 54449, 54469, 54493, 54497, 54499, 54503, 54517, 54521, 54539, 54541,
+ 54547, 54559, 54563, 54577, 54581, 54583, 54601, 54617, 54623, 54629,
+ 54631, 54647, 54667, 54673, 54679, 54709, 54713, 54721, 54727, 54751,
+ 54767, 54773, 54779, 54787, 54799, 54829, 54833, 54851, 54869, 54877,
+ 54881, 54907, 54917, 54919, 54941, 54949, 54959, 54973, 54979, 54983,
+ 55001, 55009, 55021, 55049, 55051, 55057, 55061, 55073, 55079, 55103,
+ 55109, 55117, 55127, 55147, 55163, 55171, 55201, 55207, 55213, 55217,
+ 55219, 55229, 55243, 55249, 55259, 55291, 55313, 55331, 55333, 55337,
+ 55339, 55343, 55351, 55373, 55381, 55399, 55411, 55439, 55441, 55457,
+ 55469, 55487, 55501, 55511, 55529, 55541, 55547, 55579, 55589, 55603,
+ 55609, 55619, 55621, 55631, 55633, 55639, 55661, 55663, 55667, 55673,
+ 55681, 55691, 55697, 55711, 55717, 55721, 55733, 55763, 55787, 55793,
+ 55799, 55807, 55813, 55817, 55819, 55823, 55829, 55837, 55843, 55849,
+ 55871, 55889, 55897, 55901, 55903, 55921, 55927, 55931, 55933, 55949,
+ 55967, 55987, 55997, 56003, 56009, 56039, 56041, 56053, 56081, 56087,
+ 56093, 56099, 56101, 56113, 56123, 56131, 56149, 56167, 56171, 56179,
+ 56197, 56207, 56209, 56237, 56239, 56249, 56263, 56267, 56269, 56299,
+ 56311, 56333, 56359, 56369, 56377, 56383, 56393, 56401, 56417, 56431,
+ 56437, 56443, 56453, 56467, 56473, 56477, 56479, 56489, 56501, 56503,
+ 56509, 56519, 56527, 56531, 56533, 56543, 56569, 56591, 56597, 56599,
+ 56611, 56629, 56633, 56659, 56663, 56671, 56681, 56687, 56701, 56711,
+ 56713, 56731, 56737, 56747, 56767, 56773, 56779, 56783, 56807, 56809,
+ 56813, 56821, 56827, 56843, 56857, 56873, 56891, 56893, 56897, 56909,
+ 56911, 56921, 56923, 56929, 56941, 56951, 56957, 56963, 56983, 56989,
+ 56993, 56999, 57037, 57041, 57047, 57059, 57073, 57077, 57089, 57097,
+ 57107, 57119, 57131, 57139, 57143, 57149, 57163, 57173, 57179, 57191,
+ 57193, 57203, 57221, 57223, 57241, 57251, 57259, 57269, 57271, 57283,
+ 57287, 57301, 57329, 57331, 57347, 57349, 57367, 57373, 57383, 57389,
+ 57397, 57413, 57427, 57457, 57467, 57487, 57493, 57503, 57527, 57529,
+ 57557, 57559, 57571, 57587, 57593, 57601, 57637, 57641, 57649, 57653,
+ 57667, 57679, 57689, 57697, 57709, 57713, 57719, 57727, 57731, 57737,
+ 57751, 57773, 57781, 57787, 57791, 57793, 57803, 57809, 57829, 57839,
+ 57847, 57853, 57859, 57881, 57899, 57901, 57917, 57923, 57943, 57947,
+ 57973, 57977, 57991, 58013, 58027, 58031, 58043, 58049, 58057, 58061,
+ 58067, 58073, 58099, 58109, 58111, 58129, 58147, 58151, 58153, 58169,
+ 58171, 58189, 58193, 58199, 58207, 58211, 58217, 58229, 58231, 58237,
+ 58243, 58271, 58309, 58313, 58321, 58337, 58363, 58367, 58369, 58379,
+ 58391, 58393, 58403, 58411, 58417, 58427, 58439, 58441, 58451, 58453,
+ 58477, 58481, 58511, 58537, 58543, 58549, 58567, 58573, 58579, 58601,
+ 58603, 58613, 58631, 58657, 58661, 58679, 58687, 58693, 58699, 58711,
+ 58727, 58733, 58741, 58757, 58763, 58771, 58787, 58789, 58831, 58889,
+ 58897, 58901, 58907, 58909, 58913, 58921, 58937, 58943, 58963, 58967,
+ 58979, 58991, 58997, 59009, 59011, 59021, 59023, 59029, 59051, 59053,
+ 59063, 59069, 59077, 59083, 59093, 59107, 59113, 59119, 59123, 59141,
+ 59149, 59159, 59167, 59183, 59197, 59207, 59209, 59219, 59221, 59233,
+ 59239, 59243, 59263, 59273, 59281, 59333, 59341, 59351, 59357, 59359,
+ 59369, 59377, 59387, 59393, 59399, 59407, 59417, 59419, 59441, 59443,
+ 59447, 59453, 59467, 59471, 59473, 59497, 59509, 59513, 59539, 59557,
+ 59561, 59567, 59581, 59611, 59617, 59621, 59627, 59629, 59651, 59659,
+ 59663, 59669, 59671, 59693, 59699, 59707, 59723, 59729, 59743, 59747,
+ 59753, 59771, 59779, 59791, 59797, 59809, 59833, 59863, 59879, 59887,
+ 59921, 59929, 59951, 59957, 59971, 59981, 59999, 60013, 60017, 60029,
+ 60037, 60041, 60077, 60083, 60089, 60091, 60101, 60103, 60107, 60127,
+ 60133, 60139, 60149, 60161, 60167, 60169, 60209, 60217, 60223, 60251,
+ 60257, 60259, 60271, 60289, 60293, 60317, 60331, 60337, 60343, 60353,
+ 60373, 60383, 60397, 60413, 60427, 60443, 60449, 60457, 60493, 60497,
+ 60509, 60521, 60527, 60539, 60589, 60601, 60607, 60611, 60617, 60623,
+ 60631, 60637, 60647, 60649, 60659, 60661, 60679, 60689, 60703, 60719,
+ 60727, 60733, 60737, 60757, 60761, 60763, 60773, 60779, 60793, 60811,
+ 60821, 60859, 60869, 60887, 60889, 60899, 60901, 60913, 60917, 60919,
+ 60923, 60937, 60943, 60953, 60961, 61001, 61007, 61027, 61031, 61043,
+ 61051, 61057, 61091, 61099, 61121, 61129, 61141, 61151, 61153, 61169,
+ 61211, 61223, 61231, 61253, 61261, 61283, 61291, 61297, 61331, 61333,
+ 61339, 61343, 61357, 61363, 61379, 61381, 61403, 61409, 61417, 61441,
+ 61463, 61469, 61471, 61483, 61487, 61493, 61507, 61511, 61519, 61543,
+ 61547, 61553, 61559, 61561, 61583, 61603, 61609, 61613, 61627, 61631,
+ 61637, 61643, 61651, 61657, 61667, 61673, 61681, 61687, 61703, 61717,
+ 61723, 61729, 61751, 61757, 61781, 61813, 61819, 61837, 61843, 61861,
+ 61871, 61879, 61909, 61927, 61933, 61949, 61961, 61967, 61979, 61981,
+ 61987, 61991, 62003, 62011, 62017, 62039, 62047, 62053, 62057, 62071,
+ 62081, 62099, 62119, 62129, 62131, 62137, 62141, 62143, 62171, 62189,
+ 62191, 62201, 62207, 62213, 62219, 62233, 62273, 62297, 62299, 62303,
+ 62311, 62323, 62327, 62347, 62351, 62383, 62401, 62417, 62423, 62459,
+ 62467, 62473, 62477, 62483, 62497, 62501, 62507, 62533, 62539, 62549,
+ 62563, 62581, 62591, 62597, 62603, 62617, 62627, 62633, 62639, 62653,
+ 62659, 62683, 62687, 62701, 62723, 62731, 62743, 62753, 62761, 62773,
+ 62791, 62801, 62819, 62827, 62851, 62861, 62869, 62873, 62897, 62903,
+ 62921, 62927, 62929, 62939, 62969, 62971, 62981, 62983, 62987, 62989,
+ 63029, 63031, 63059, 63067, 63073, 63079, 63097, 63103, 63113, 63127,
+ 63131, 63149, 63179, 63197, 63199, 63211, 63241, 63247, 63277, 63281,
+ 63299, 63311, 63313, 63317, 63331, 63337, 63347, 63353, 63361, 63367,
+ 63377, 63389, 63391, 63397, 63409, 63419, 63421, 63439, 63443, 63463,
+ 63467, 63473, 63487, 63493, 63499, 63521, 63527, 63533, 63541, 63559,
+ 63577, 63587, 63589, 63599, 63601, 63607, 63611, 63617, 63629, 63647,
+ 63649, 63659, 63667, 63671, 63689, 63691, 63697, 63703, 63709, 63719,
+ 63727, 63737, 63743, 63761, 63773, 63781, 63793, 63799, 63803, 63809,
+ 63823, 63839, 63841, 63853, 63857, 63863, 63901, 63907, 63913, 63929,
+ 63949, 63977, 63997, 64007, 64013, 64019, 64033, 64037, 64063, 64067,
+ 64081, 64091, 64109, 64123, 64151, 64153, 64157, 64171, 64187, 64189,
+ 64217, 64223, 64231, 64237, 64271, 64279, 64283, 64301, 64303, 64319,
+ 64327, 64333, 64373, 64381, 64399, 64403, 64433, 64439, 64451, 64453,
+ 64483, 64489, 64499, 64513, 64553, 64567, 64577, 64579, 64591, 64601,
+ 64609, 64613, 64621, 64627, 64633, 64661, 64663, 64667, 64679, 64693,
+ 64709, 64717, 64747, 64763, 64781, 64783, 64793, 64811, 64817, 64849,
+ 64853, 64871, 64877, 64879, 64891, 64901, 64919, 64921, 64927, 64937,
+ 64951, 64969, 64997, 65003, 65011, 65027, 65029, 65033, 65053, 65063,
+ 65071, 65089, 65099, 65101, 65111, 65119, 65123, 65129, 65141, 65147,
+ 65167, 65171, 65173, 65179, 65183, 65203, 65213, 65239, 65257, 65267,
+ 65269, 65287, 65293, 65309, 65323, 65327, 65353, 65357, 65371, 65381,
+ 65393, 65407, 65413, 65419, 65423, 65437, 65447, 65449, 65479, 65497,
+ 65519, 65521, 65537, 65539, 65543, 65551, 65557, 65563, 65579, 65581,
+ 65587, 65599, 65609, 65617, 65629, 65633, 65647, 65651, 65657, 65677,
+ 65687, 65699, 65701, 65707, 65713, 65717, 65719, 65729, 65731, 65761,
+ 65777, 65789, 65809, 65827, 65831, 65837, 65839, 65843, 65851, 65867,
+ 65881, 65899, 65921, 65927, 65929, 65951, 65957, 65963, 65981, 65983,
+ 65993, 66029, 66037, 66041, 66047, 66067, 66071, 66083, 66089, 66103,
+ 66107, 66109, 66137, 66161, 66169, 66173, 66179, 66191, 66221, 66239,
+ 66271, 66293, 66301, 66337, 66343, 66347, 66359, 66361, 66373, 66377,
+ 66383, 66403, 66413, 66431, 66449, 66457, 66463, 66467, 66491, 66499,
+ 66509, 66523, 66529, 66533, 66541, 66553, 66569, 66571, 66587, 66593,
+ 66601, 66617, 66629, 66643, 66653, 66683, 66697, 66701, 66713, 66721,
+ 66733, 66739, 66749, 66751, 66763, 66791, 66797, 66809, 66821, 66841,
+ 66851, 66853, 66863, 66877, 66883, 66889, 66919, 66923, 66931, 66943,
+ 66947, 66949, 66959, 66973, 66977, 67003, 67021, 67033, 67043, 67049,
+ 67057, 67061, 67073, 67079, 67103, 67121, 67129, 67139, 67141, 67153,
+ 67157, 67169, 67181, 67187, 67189, 67211, 67213, 67217, 67219, 67231,
+ 67247, 67261, 67271, 67273, 67289, 67307, 67339, 67343, 67349, 67369,
+ 67391, 67399, 67409, 67411, 67421, 67427, 67429, 67433, 67447, 67453,
+ 67477, 67481, 67489, 67493, 67499, 67511, 67523, 67531, 67537, 67547,
+ 67559, 67567, 67577, 67579, 67589, 67601, 67607, 67619, 67631, 67651,
+ 67679, 67699, 67709, 67723, 67733, 67741, 67751, 67757, 67759, 67763,
+ 67777, 67783, 67789, 67801, 67807, 67819, 67829, 67843, 67853, 67867,
+ 67883, 67891, 67901, 67927, 67931, 67933, 67939, 67943, 67957, 67961,
+ 67967, 67979, 67987, 67993, 68023, 68041, 68053, 68059, 68071, 68087,
+ 68099, 68111, 68113, 68141, 68147, 68161, 68171, 68207, 68209, 68213,
+ 68219, 68227, 68239, 68261, 68279, 68281, 68311, 68329, 68351, 68371,
+ 68389, 68399, 68437, 68443, 68447, 68449, 68473, 68477, 68483, 68489,
+ 68491, 68501, 68507, 68521, 68531, 68539, 68543, 68567, 68581, 68597,
+ 68611, 68633, 68639, 68659, 68669, 68683, 68687, 68699, 68711, 68713,
+ 68729, 68737, 68743, 68749, 68767, 68771, 68777, 68791, 68813, 68819,
+ 68821, 68863, 68879, 68881, 68891, 68897, 68899, 68903, 68909, 68917,
+ 68927, 68947, 68963, 68993, 69001, 69011, 69019, 69029, 69031, 69061,
+ 69067, 69073, 69109, 69119, 69127, 69143, 69149, 69151, 69163, 69191,
+ 69193, 69197, 69203, 69221, 69233, 69239, 69247, 69257, 69259, 69263,
+ 69313, 69317, 69337, 69341, 69371, 69379, 69383, 69389, 69401, 69403,
+ 69427, 69431, 69439, 69457, 69463, 69467, 69473, 69481, 69491, 69493,
+ 69497, 69499, 69539, 69557, 69593, 69623, 69653, 69661, 69677, 69691,
+ 69697, 69709, 69737, 69739, 69761, 69763, 69767, 69779, 69809, 69821,
+ 69827, 69829, 69833, 69847, 69857, 69859, 69877, 69899, 69911, 69929,
+ 69931, 69941, 69959, 69991, 69997, 70001, 70003, 70009, 70019, 70039,
+ 70051, 70061, 70067, 70079, 70099, 70111, 70117, 70121, 70123, 70139,
+ 70141, 70157, 70163, 70177, 70181, 70183, 70199, 70201, 70207, 70223,
+ 70229, 70237, 70241, 70249, 70271, 70289, 70297, 70309, 70313, 70321,
+ 70327, 70351, 70373, 70379, 70381, 70393, 70423, 70429, 70439, 70451,
+ 70457, 70459, 70481, 70487, 70489, 70501, 70507, 70529, 70537, 70549,
+ 70571, 70573, 70583, 70589, 70607, 70619, 70621, 70627, 70639, 70657,
+ 70663, 70667, 70687, 70709, 70717, 70729, 70753, 70769, 70783, 70793,
+ 70823, 70841, 70843, 70849, 70853, 70867, 70877, 70879, 70891, 70901,
+ 70913, 70919, 70921, 70937, 70949, 70951, 70957, 70969, 70979, 70981,
+ 70991, 70997, 70999, 71011, 71023, 71039, 71059, 71069, 71081, 71089,
+ 71119, 71129, 71143, 71147, 71153, 71161, 71167, 71171, 71191, 71209,
+ 71233, 71237, 71249, 71257, 71261, 71263, 71287, 71293, 71317, 71327,
+ 71329, 71333, 71339, 71341, 71347, 71353, 71359, 71363, 71387, 71389,
+ 71399, 71411, 71413, 71419, 71429, 71437, 71443, 71453, 71471, 71473,
+ 71479, 71483, 71503, 71527, 71537, 71549, 71551, 71563, 71569, 71593,
+ 71597, 71633, 71647, 71663, 71671, 71693, 71699, 71707, 71711, 71713,
+ 71719, 71741, 71761, 71777, 71789, 71807, 71809, 71821, 71837, 71843,
+ 71849, 71861, 71867, 71879, 71881, 71887, 71899, 71909, 71917, 71933,
+ 71941, 71947, 71963, 71971, 71983, 71987, 71993, 71999, 72019, 72031,
+ 72043, 72047, 72053, 72073, 72077, 72089, 72091, 72101, 72103, 72109,
+ 72139, 72161, 72167, 72169, 72173, 72211, 72221, 72223, 72227, 72229,
+ 72251, 72253, 72269, 72271, 72277, 72287, 72307, 72313, 72337, 72341,
+ 72353, 72367, 72379, 72383, 72421, 72431, 72461, 72467, 72469, 72481,
+ 72493, 72497, 72503, 72533, 72547, 72551, 72559, 72577, 72613, 72617,
+ 72623, 72643, 72647, 72649, 72661, 72671, 72673, 72679, 72689, 72701,
+ 72707, 72719, 72727, 72733, 72739, 72763, 72767, 72797, 72817, 72823,
+ 72859, 72869, 72871, 72883, 72889, 72893, 72901, 72907, 72911, 72923,
+ 72931, 72937, 72949, 72953, 72959, 72973, 72977, 72997, 73009, 73013,
+ 73019, 73037, 73039, 73043, 73061, 73063, 73079, 73091, 73121, 73127,
+ 73133, 73141, 73181, 73189, 73237, 73243, 73259, 73277, 73291, 73303,
+ 73309, 73327, 73331, 73351, 73361, 73363, 73369, 73379, 73387, 73417,
+ 73421, 73433, 73453, 73459, 73471, 73477, 73483, 73517, 73523, 73529,
+ 73547, 73553, 73561, 73571, 73583, 73589, 73597, 73607, 73609, 73613,
+ 73637, 73643, 73651, 73673, 73679, 73681, 73693, 73699, 73709, 73721,
+ 73727, 73751, 73757, 73771, 73783, 73819, 73823, 73847, 73849, 73859,
+ 73867, 73877, 73883, 73897, 73907, 73939, 73943, 73951, 73961, 73973,
+ 73999, 74017, 74021, 74027, 74047, 74051, 74071, 74077, 74093, 74099,
+ 74101, 74131, 74143, 74149, 74159, 74161, 74167, 74177, 74189, 74197,
+ 74201, 74203, 74209, 74219, 74231, 74257, 74279, 74287, 74293, 74297,
+ 74311, 74317, 74323, 74353, 74357, 74363, 74377, 74381, 74383, 74411,
+ 74413, 74419, 74441, 74449, 74453, 74471, 74489, 74507, 74509, 74521,
+ 74527, 74531, 74551, 74561, 74567, 74573, 74587, 74597, 74609, 74611,
+ 74623, 74653, 74687, 74699, 74707, 74713, 74717, 74719, 74729, 74731,
+ 74747, 74759, 74761, 74771, 74779, 74797, 74821, 74827, 74831, 74843,
+ 74857, 74861, 74869, 74873, 74887, 74891, 74897, 74903, 74923, 74929,
+ 74933, 74941, 74959, 75011, 75013, 75017, 75029, 75037, 75041, 75079,
+ 75083, 75109, 75133, 75149, 75161, 75167, 75169, 75181, 75193, 75209,
+ 75211, 75217, 75223, 75227, 75239, 75253, 75269, 75277, 75289, 75307,
+ 75323, 75329, 75337, 75347, 75353, 75367, 75377, 75389, 75391, 75401,
+ 75403, 75407, 75431, 75437, 75479, 75503, 75511, 75521, 75527, 75533,
+ 75539, 75541, 75553, 75557, 75571, 75577, 75583, 75611, 75617, 75619,
+ 75629, 75641, 75653, 75659, 75679, 75683, 75689, 75703, 75707, 75709,
+ 75721, 75731, 75743, 75767, 75773, 75781, 75787, 75793, 75797, 75821,
+ 75833, 75853, 75869, 75883, 75913, 75931, 75937, 75941, 75967, 75979,
+ 75983, 75989, 75991, 75997, 76001, 76003, 76031, 76039, 76079, 76081,
+ 76091, 76099, 76103, 76123, 76129, 76147, 76157, 76159, 76163, 76207,
+ 76213, 76231, 76243, 76249, 76253, 76259, 76261, 76283, 76289, 76303,
+ 76333, 76343, 76367, 76369, 76379, 76387, 76403, 76421, 76423, 76441,
+ 76463, 76471, 76481, 76487, 76493, 76507, 76511, 76519, 76537, 76541,
+ 76543, 76561, 76579, 76597, 76603, 76607, 76631, 76649, 76651, 76667,
+ 76673, 76679, 76697, 76717, 76733, 76753, 76757, 76771, 76777, 76781,
+ 76801, 76819, 76829, 76831, 76837, 76847, 76871, 76873, 76883, 76907,
+ 76913, 76919, 76943, 76949, 76961, 76963, 76991, 77003, 77017, 77023,
+ 77029, 77041, 77047, 77069, 77081, 77093, 77101, 77137, 77141, 77153,
+ 77167, 77171, 77191, 77201, 77213, 77237, 77239, 77243, 77249, 77261,
+ 77263, 77267, 77269, 77279, 77291, 77317, 77323, 77339, 77347, 77351,
+ 77359, 77369, 77377, 77383, 77417, 77419, 77431, 77447, 77471, 77477,
+ 77479, 77489, 77491, 77509, 77513, 77521, 77527, 77543, 77549, 77551,
+ 77557, 77563, 77569, 77573, 77587, 77591, 77611, 77617, 77621, 77641,
+ 77647, 77659, 77681, 77687, 77689, 77699, 77711, 77713, 77719, 77723,
+ 77731, 77743, 77747, 77761, 77773, 77783, 77797, 77801, 77813, 77839,
+ 77849, 77863, 77867, 77893, 77899, 77929, 77933, 77951, 77969, 77977,
+ 77983, 77999, 78007, 78017, 78031, 78041, 78049, 78059, 78079, 78101,
+ 78121, 78137, 78139, 78157, 78163, 78167, 78173, 78179, 78191, 78193,
+ 78203, 78229, 78233, 78241, 78259, 78277, 78283, 78301, 78307, 78311,
+ 78317, 78341, 78347, 78367, 78401, 78427, 78437, 78439, 78467, 78479,
+ 78487, 78497, 78509, 78511, 78517, 78539, 78541, 78553, 78569, 78571,
+ 78577, 78583, 78593, 78607, 78623, 78643, 78649, 78653, 78691, 78697,
+ 78707, 78713, 78721, 78737, 78779, 78781, 78787, 78791, 78797, 78803,
+ 78809, 78823, 78839, 78853, 78857, 78877, 78887, 78889, 78893, 78901,
+ 78919, 78929, 78941, 78977, 78979, 78989, 79031, 79039, 79043, 79063,
+ 79087, 79103, 79111, 79133, 79139, 79147, 79151, 79153, 79159, 79181,
+ 79187, 79193, 79201, 79229, 79231, 79241, 79259, 79273, 79279, 79283,
+ 79301, 79309, 79319, 79333, 79337, 79349, 79357, 79367, 79379, 79393,
+ 79397, 79399, 79411, 79423, 79427, 79433, 79451, 79481, 79493, 79531,
+ 79537, 79549, 79559, 79561, 79579, 79589, 79601, 79609, 79613, 79621,
+ 79627, 79631, 79633, 79657, 79669, 79687, 79691, 79693, 79697, 79699,
+ 79757, 79769, 79777, 79801, 79811, 79813, 79817, 79823, 79829, 79841,
+ 79843, 79847, 79861, 79867, 79873, 79889, 79901, 79903, 79907, 79939,
+ 79943, 79967, 79973, 79979, 79987, 79997, 79999, 80021, 80039, 80051,
+ 80071, 80077, 80107, 80111, 80141, 80147, 80149, 80153, 80167, 80173,
+ 80177, 80191, 80207, 80209, 80221, 80231, 80233, 80239, 80251, 80263,
+ 80273, 80279, 80287, 80309, 80317, 80329, 80341, 80347, 80363, 80369,
+ 80387, 80407, 80429, 80447, 80449, 80471, 80473, 80489, 80491, 80513,
+ 80527, 80537, 80557, 80567, 80599, 80603, 80611, 80621, 80627, 80629,
+ 80651, 80657, 80669, 80671, 80677, 80681, 80683, 80687, 80701, 80713,
+ 80737, 80747, 80749, 80761, 80777, 80779, 80783, 80789, 80803, 80809,
+ 80819, 80831, 80833, 80849, 80863, 80897, 80909, 80911, 80917, 80923,
+ 80929, 80933, 80953, 80963, 80989, 81001, 81013, 81017, 81019, 81023,
+ 81031, 81041, 81043, 81047, 81049, 81071, 81077, 81083, 81097, 81101,
+ 81119, 81131, 81157, 81163, 81173, 81181, 81197, 81199, 81203, 81223,
+ 81233, 81239, 81281, 81283, 81293, 81299, 81307, 81331, 81343, 81349,
+ 81353, 81359, 81371, 81373, 81401, 81409, 81421, 81439, 81457, 81463,
+ 81509, 81517, 81527, 81533, 81547, 81551, 81553, 81559, 81563, 81569,
+ 81611, 81619, 81629, 81637, 81647, 81649, 81667, 81671, 81677, 81689,
+ 81701, 81703, 81707, 81727, 81737, 81749, 81761, 81769, 81773, 81799,
+ 81817, 81839, 81847, 81853, 81869, 81883, 81899, 81901, 81919, 81929,
+ 81931, 81937, 81943, 81953, 81967, 81971, 81973, 82003, 82007, 82009,
+ 82013, 82021, 82031, 82037, 82039, 82051, 82067, 82073, 82129, 82139,
+ 82141, 82153, 82163, 82171, 82183, 82189, 82193, 82207, 82217, 82219,
+ 82223, 82231, 82237, 82241, 82261, 82267, 82279, 82301, 82307, 82339,
+ 82349, 82351, 82361, 82373, 82387, 82393, 82421, 82457, 82463, 82469,
+ 82471, 82483, 82487, 82493, 82499, 82507, 82529, 82531, 82549, 82559,
+ 82561, 82567, 82571, 82591, 82601, 82609, 82613, 82619, 82633, 82651,
+ 82657, 82699, 82721, 82723, 82727, 82729, 82757, 82759, 82763, 82781,
+ 82787, 82793, 82799, 82811, 82813, 82837, 82847, 82883, 82889, 82891,
+ 82903, 82913, 82939, 82963, 82981, 82997, 83003, 83009, 83023, 83047,
+ 83059, 83063, 83071, 83077, 83089, 83093, 83101, 83117, 83137, 83177,
+ 83203, 83207, 83219, 83221, 83227, 83231, 83233, 83243, 83257, 83267,
+ 83269, 83273, 83299, 83311, 83339, 83341, 83357, 83383, 83389, 83399,
+ 83401, 83407, 83417, 83423, 83431, 83437, 83443, 83449, 83459, 83471,
+ 83477, 83497, 83537, 83557, 83561, 83563, 83579, 83591, 83597, 83609,
+ 83617, 83621, 83639, 83641, 83653, 83663, 83689, 83701, 83717, 83719,
+ 83737, 83761, 83773, 83777, 83791, 83813, 83833, 83843, 83857, 83869,
+ 83873, 83891, 83903, 83911, 83921, 83933, 83939, 83969, 83983, 83987,
+ 84011, 84017, 84047, 84053, 84059, 84061, 84067, 84089, 84121, 84127,
+ 84131, 84137, 84143, 84163, 84179, 84181, 84191, 84199, 84211, 84221,
+ 84223, 84229, 84239, 84247, 84263, 84299, 84307, 84313, 84317, 84319,
+ 84347, 84349, 84377, 84389, 84391, 84401, 84407, 84421, 84431, 84437,
+ 84443, 84449, 84457, 84463, 84467, 84481, 84499, 84503, 84509, 84521,
+ 84523, 84533, 84551, 84559, 84589, 84629, 84631, 84649, 84653, 84659,
+ 84673, 84691, 84697, 84701, 84713, 84719, 84731, 84737, 84751, 84761,
+ 84787, 84793, 84809, 84811, 84827, 84857, 84859, 84869, 84871, 84913,
+ 84919, 84947, 84961, 84967, 84977, 84979, 84991, 85009, 85021, 85027,
+ 85037, 85049, 85061, 85081, 85087, 85091, 85093, 85103, 85109, 85121,
+ 85133, 85147, 85159, 85193, 85199, 85201, 85213, 85223, 85229, 85237,
+ 85243, 85247, 85259, 85297, 85303, 85313, 85331, 85333, 85361, 85363,
+ 85369, 85381, 85411, 85427, 85429, 85439, 85447, 85451, 85453, 85469,
+ 85487, 85513, 85517, 85523, 85531, 85549, 85571, 85577, 85597, 85601,
+ 85607, 85619, 85621, 85627, 85639, 85643, 85661, 85667, 85669, 85691,
+ 85703, 85711, 85717, 85733, 85751, 85781, 85793, 85817, 85819, 85829,
+ 85831, 85837, 85843, 85847, 85853, 85889, 85903, 85909, 85931, 85933,
+ 85991, 85999, 86011, 86017, 86027, 86029, 86069, 86077, 86083, 86111,
+ 86113, 86117, 86131, 86137, 86143, 86161, 86171, 86179, 86183, 86197,
+ 86201, 86209, 86239, 86243, 86249, 86257, 86263, 86269, 86287, 86291,
+ 86293, 86297, 86311, 86323, 86341, 86351, 86353, 86357, 86369, 86371,
+ 86381, 86389, 86399, 86413, 86423, 86441, 86453, 86461, 86467, 86477,
+ 86491, 86501, 86509, 86531, 86533, 86539, 86561, 86573, 86579, 86587,
+ 86599, 86627, 86629, 86677, 86689, 86693, 86711, 86719, 86729, 86743,
+ 86753, 86767, 86771, 86783, 86813, 86837, 86843, 86851, 86857, 86861,
+ 86869, 86923, 86927, 86929, 86939, 86951, 86959, 86969, 86981, 86993,
+ 87011, 87013, 87037, 87041, 87049, 87071, 87083, 87103, 87107, 87119,
+ 87121, 87133, 87149, 87151, 87179, 87181, 87187, 87211, 87221, 87223,
+ 87251, 87253, 87257, 87277, 87281, 87293, 87299, 87313, 87317, 87323,
+ 87337, 87359, 87383, 87403, 87407, 87421, 87427, 87433, 87443, 87473,
+ 87481, 87491, 87509, 87511, 87517, 87523, 87539, 87541, 87547, 87553,
+ 87557, 87559, 87583, 87587, 87589, 87613, 87623, 87629, 87631, 87641,
+ 87643, 87649, 87671, 87679, 87683, 87691, 87697, 87701, 87719, 87721,
+ 87739, 87743, 87751, 87767, 87793, 87797, 87803, 87811, 87833, 87853,
+ 87869, 87877, 87881, 87887, 87911, 87917, 87931, 87943, 87959, 87961,
+ 87973, 87977, 87991, 88001, 88003, 88007, 88019, 88037, 88069, 88079,
+ 88093, 88117, 88129, 88169, 88177, 88211, 88223, 88237, 88241, 88259,
+ 88261, 88289, 88301, 88321, 88327, 88337, 88339, 88379, 88397, 88411,
+ 88423, 88427, 88463, 88469, 88471, 88493, 88499, 88513, 88523, 88547,
+ 88589, 88591, 88607, 88609, 88643, 88651, 88657, 88661, 88663, 88667,
+ 88681, 88721, 88729, 88741, 88747, 88771, 88789, 88793, 88799, 88801,
+ 88807, 88811, 88813, 88817, 88819, 88843, 88853, 88861, 88867, 88873,
+ 88883, 88897, 88903, 88919, 88937, 88951, 88969, 88993, 88997, 89003,
+ 89009, 89017, 89021, 89041, 89051, 89057, 89069, 89071, 89083, 89087,
+ 89101, 89107, 89113, 89119, 89123, 89137, 89153, 89189, 89203, 89209,
+ 89213, 89227, 89231, 89237, 89261, 89269, 89273, 89293, 89303, 89317,
+ 89329, 89363, 89371, 89381, 89387, 89393, 89399, 89413, 89417, 89431,
+ 89443, 89449, 89459, 89477, 89491, 89501, 89513, 89519, 89521, 89527,
+ 89533, 89561, 89563, 89567, 89591, 89597, 89599, 89603, 89611, 89627,
+ 89633, 89653, 89657, 89659, 89669, 89671, 89681, 89689, 89753, 89759,
+ 89767, 89779, 89783, 89797, 89809, 89819, 89821, 89833, 89839, 89849,
+ 89867, 89891, 89897, 89899, 89909, 89917, 89923, 89939, 89959, 89963,
+ 89977, 89983, 89989, 90001, 90007, 90011, 90017, 90019, 90023, 90031,
+ 90053, 90059, 90067, 90071, 90073, 90089, 90107, 90121, 90127, 90149,
+ 90163, 90173, 90187, 90191, 90197, 90199, 90203, 90217, 90227, 90239,
+ 90247, 90263, 90271, 90281, 90289, 90313, 90353, 90359, 90371, 90373,
+ 90379, 90397, 90401, 90403, 90407, 90437, 90439, 90469, 90473, 90481,
+ 90499, 90511, 90523, 90527, 90529, 90533, 90547, 90583, 90599, 90617,
+ 90619, 90631, 90641, 90647, 90659, 90677, 90679, 90697, 90703, 90709,
+ 90731, 90749, 90787, 90793, 90803, 90821, 90823, 90833, 90841, 90847,
+ 90863, 90887, 90901, 90907, 90911, 90917, 90931, 90947, 90971, 90977,
+ 90989, 90997, 91009, 91019, 91033, 91079, 91081, 91097, 91099, 91121,
+ 91127, 91129, 91139, 91141, 91151, 91153, 91159, 91163, 91183, 91193,
+ 91199, 91229, 91237, 91243, 91249, 91253, 91283, 91291, 91297, 91303,
+ 91309, 91331, 91367, 91369, 91373, 91381, 91387, 91393, 91397, 91411,
+ 91423, 91433, 91453, 91457, 91459, 91463, 91493, 91499, 91513, 91529,
+ 91541, 91571, 91573, 91577, 91583, 91591, 91621, 91631, 91639, 91673,
+ 91691, 91703, 91711, 91733, 91753, 91757, 91771, 91781, 91801, 91807,
+ 91811, 91813, 91823, 91837, 91841, 91867, 91873, 91909, 91921, 91939,
+ 91943, 91951, 91957, 91961, 91967, 91969, 91997, 92003, 92009, 92033,
+ 92041, 92051, 92077, 92083, 92107, 92111, 92119, 92143, 92153, 92173,
+ 92177, 92179, 92189, 92203, 92219, 92221, 92227, 92233, 92237, 92243,
+ 92251, 92269, 92297, 92311, 92317, 92333, 92347, 92353, 92357, 92363,
+ 92369, 92377, 92381, 92383, 92387, 92399, 92401, 92413, 92419, 92431,
+ 92459, 92461, 92467, 92479, 92489, 92503, 92507, 92551, 92557, 92567,
+ 92569, 92581, 92593, 92623, 92627, 92639, 92641, 92647, 92657, 92669,
+ 92671, 92681, 92683, 92693, 92699, 92707, 92717, 92723, 92737, 92753,
+ 92761, 92767, 92779, 92789, 92791, 92801, 92809, 92821, 92831, 92849,
+ 92857, 92861, 92863, 92867, 92893, 92899, 92921, 92927, 92941, 92951,
+ 92957, 92959, 92987, 92993, 93001, 93047, 93053, 93059, 93077, 93083,
+ 93089, 93097, 93103, 93113, 93131, 93133, 93139, 93151, 93169, 93179,
+ 93187, 93199, 93229, 93239, 93241, 93251, 93253, 93257, 93263, 93281,
+ 93283, 93287, 93307, 93319, 93323, 93329, 93337, 93371, 93377, 93383,
+ 93407, 93419, 93427, 93463, 93479, 93481, 93487, 93491, 93493, 93497,
+ 93503, 93523, 93529, 93553, 93557, 93559, 93563, 93581, 93601, 93607,
+ 93629, 93637, 93683, 93701, 93703, 93719, 93739, 93761, 93763, 93787,
+ 93809, 93811, 93827, 93851, 93871, 93887, 93889, 93893, 93901, 93911,
+ 93913, 93923, 93937, 93941, 93949, 93967, 93971, 93979, 93983, 93997,
+ 94007, 94009, 94033, 94049, 94057, 94063, 94079, 94099, 94109, 94111,
+ 94117, 94121, 94151, 94153, 94169, 94201, 94207, 94219, 94229, 94253,
+ 94261, 94273, 94291, 94307, 94309, 94321, 94327, 94331, 94343, 94349,
+ 94351, 94379, 94397, 94399, 94421, 94427, 94433, 94439, 94441, 94447,
+ 94463, 94477, 94483, 94513, 94529, 94531, 94541, 94543, 94547, 94559,
+ 94561, 94573, 94583, 94597, 94603, 94613, 94621, 94649, 94651, 94687,
+ 94693, 94709, 94723, 94727, 94747, 94771, 94777, 94781, 94789, 94793,
+ 94811, 94819, 94823, 94837, 94841, 94847, 94849, 94873, 94889, 94903,
+ 94907, 94933, 94949, 94951, 94961, 94993, 94999, 95003, 95009, 95021,
+ 95027, 95063, 95071, 95083, 95087, 95089, 95093, 95101, 95107, 95111,
+ 95131, 95143, 95153, 95177, 95189, 95191, 95203, 95213, 95219, 95231,
+ 95233, 95239, 95257, 95261, 95267, 95273, 95279, 95287, 95311, 95317,
+ 95327, 95339, 95369, 95383, 95393, 95401, 95413, 95419, 95429, 95441,
+ 95443, 95461, 95467, 95471, 95479, 95483, 95507, 95527, 95531, 95539,
+ 95549, 95561, 95569, 95581, 95597, 95603, 95617, 95621, 95629, 95633,
+ 95651, 95701, 95707, 95713, 95717, 95723, 95731, 95737, 95747, 95773,
+ 95783, 95789, 95791, 95801, 95803, 95813, 95819, 95857, 95869, 95873,
+ 95881, 95891, 95911, 95917, 95923, 95929, 95947, 95957, 95959, 95971,
+ 95987, 95989, 96001, 96013, 96017, 96043, 96053, 96059, 96079, 96097,
+ 96137, 96149, 96157, 96167, 96179, 96181, 96199, 96211, 96221, 96223,
+ 96233, 96259, 96263, 96269, 96281, 96289, 96293, 96323, 96329, 96331,
+ 96337, 96353, 96377, 96401, 96419, 96431, 96443, 96451, 96457, 96461,
+ 96469, 96479, 96487, 96493, 96497, 96517, 96527, 96553, 96557, 96581,
+ 96587, 96589, 96601, 96643, 96661, 96667, 96671, 96697, 96703, 96731,
+ 96737, 96739, 96749, 96757, 96763, 96769, 96779, 96787, 96797, 96799,
+ 96821, 96823, 96827, 96847, 96851, 96857, 96893, 96907, 96911, 96931,
+ 96953, 96959, 96973, 96979, 96989, 96997, 97001, 97003, 97007, 97021,
+ 97039, 97073, 97081, 97103, 97117, 97127, 97151, 97157, 97159, 97169,
+ 97171, 97177, 97187, 97213, 97231, 97241, 97259, 97283, 97301, 97303,
+ 97327, 97367, 97369, 97373, 97379, 97381, 97387, 97397, 97423, 97429,
+ 97441, 97453, 97459, 97463, 97499, 97501, 97511, 97523, 97547, 97549,
+ 97553, 97561, 97571, 97577, 97579, 97583, 97607, 97609, 97613, 97649,
+ 97651, 97673, 97687, 97711, 97729, 97771, 97777, 97787, 97789, 97813,
+ 97829, 97841, 97843, 97847, 97849, 97859, 97861, 97871, 97879, 97883,
+ 97919, 97927, 97931, 97943, 97961, 97967, 97973, 97987, 98009, 98011,
+ 98017, 98041, 98047, 98057, 98081, 98101, 98123, 98129, 98143, 98179,
+ 98207, 98213, 98221, 98227, 98251, 98257, 98269, 98297, 98299, 98317,
+ 98321, 98323, 98327, 98347, 98369, 98377, 98387, 98389, 98407, 98411,
+ 98419, 98429, 98443, 98453, 98459, 98467, 98473, 98479, 98491, 98507,
+ 98519, 98533, 98543, 98561, 98563, 98573, 98597, 98621, 98627, 98639,
+ 98641, 98663, 98669, 98689, 98711, 98713, 98717, 98729, 98731, 98737,
+ 98773, 98779, 98801, 98807, 98809, 98837, 98849, 98867, 98869, 98873,
+ 98887, 98893, 98897, 98899, 98909, 98911, 98927, 98929, 98939, 98947,
+ 98953, 98963, 98981, 98993, 98999, 99013, 99017, 99023, 99041, 99053,
+ 99079, 99083, 99089, 99103, 99109, 99119, 99131, 99133, 99137, 99139,
+ 99149, 99173, 99181, 99191, 99223, 99233, 99241, 99251, 99257, 99259,
+ 99277, 99289, 99317, 99347, 99349, 99367, 99371, 99377, 99391, 99397,
+ 99401, 99409, 99431, 99439, 99469, 99487, 99497, 99523, 99527, 99529,
+ 99551, 99559, 99563, 99571, 99577, 99581, 99607, 99611, 99623, 99643,
+ 99661, 99667, 99679, 99689, 99707, 99709, 99713, 99719, 99721, 99733,
+ 99761, 99767, 99787, 99793, 99809, 99817, 99823, 99829, 99833, 99839,
+ 99859, 99871, 99877, 99881, 99901, 99907, 99923, 99929, 99961, 99971,
+ 99989, 99991, 100003, 100019, 100043, 100049, 100057, 100069, 100103, 100109,
+100129, 100151, 100153, 100169, 100183, 100189, 100193, 100207, 100213, 100237,
+100267, 100271, 100279, 100291, 100297, 100313, 100333, 100343, 100357, 100361,
+100363, 100379, 100391, 100393, 100403, 100411, 100417, 100447, 100459, 100469,
+100483, 100493, 100501, 100511, 100517, 100519, 100523, 100537, 100547, 100549,
+100559, 100591, 100609, 100613, 100621, 100649, 100669, 100673, 100693, 100699,
+100703, 100733, 100741, 100747, 100769, 100787, 100799, 100801, 100811, 100823,
+100829, 100847, 100853, 100907, 100913, 100927, 100931, 100937, 100943, 100957,
+100981, 100987, 100999, 101009, 101021, 101027, 101051, 101063, 101081, 101089,
+101107, 101111, 101113, 101117, 101119, 101141, 101149, 101159, 101161, 101173,
+101183, 101197, 101203, 101207, 101209, 101221, 101267, 101273, 101279, 101281,
+101287, 101293, 101323, 101333, 101341, 101347, 101359, 101363, 101377, 101383,
+101399, 101411, 101419, 101429, 101449, 101467, 101477, 101483, 101489, 101501,
+101503, 101513, 101527, 101531, 101533, 101537, 101561, 101573, 101581, 101599,
+101603, 101611, 101627, 101641, 101653, 101663, 101681, 101693, 101701, 101719,
+101723, 101737, 101741, 101747, 101749, 101771, 101789, 101797, 101807, 101833,
+101837, 101839, 101863, 101869, 101873, 101879, 101891, 101917, 101921, 101929,
+101939, 101957, 101963, 101977, 101987, 101999, 102001, 102013, 102019, 102023,
+102031, 102043, 102059, 102061, 102071, 102077, 102079, 102101, 102103, 102107,
+102121, 102139, 102149, 102161, 102181, 102191, 102197, 102199, 102203, 102217,
+102229, 102233, 102241, 102251, 102253, 102259, 102293, 102299, 102301, 102317,
+102329, 102337, 102359, 102367, 102397, 102407, 102409, 102433, 102437, 102451,
+102461, 102481, 102497, 102499, 102503, 102523, 102533, 102539, 102547, 102551,
+102559, 102563, 102587, 102593, 102607, 102611, 102643, 102647, 102653, 102667,
+102673, 102677, 102679, 102701, 102761, 102763, 102769, 102793, 102797, 102811,
+102829, 102841, 102859, 102871, 102877, 102881, 102911, 102913, 102929, 102931,
+102953, 102967, 102983, 103001, 103007, 103043, 103049, 103067, 103069, 103079,
+103087, 103091, 103093, 103099, 103123, 103141, 103171, 103177, 103183, 103217,
+103231, 103237, 103289, 103291, 103307, 103319, 103333, 103349, 103357, 103387,
+103391, 103393, 103399, 103409, 103421, 103423, 103451, 103457, 103471, 103483,
+103511, 103529, 103549, 103553, 103561, 103567, 103573, 103577, 103583, 103591,
+103613, 103619, 103643, 103651, 103657, 103669, 103681, 103687, 103699, 103703,
+103723, 103769, 103787, 103801, 103811, 103813, 103837, 103841, 103843, 103867,
+103889, 103903, 103913, 103919, 103951, 103963, 103967, 103969, 103979, 103981,
+103991, 103993, 103997, 104003, 104009, 104021, 104033, 104047, 104053, 104059,
+104087, 104089, 104107, 104113, 104119, 104123, 104147, 104149, 104161, 104173,
+104179, 104183, 104207, 104231, 104233, 104239, 104243, 104281, 104287, 104297,
+104309, 104311, 104323, 104327, 104347, 104369, 104381, 104383, 104393, 104399,
+104417, 104459, 104471, 104473, 104479, 104491, 104513, 104527, 104537, 104543,
+104549, 104551, 104561, 104579, 104593, 104597, 104623, 104639, 104651, 104659,
+104677, 104681, 104683, 104693, 104701, 104707, 104711, 104717, 104723, 104729,
+)
diff --git a/lib/Crypto/Util/number.pyi b/lib/Crypto/Util/number.pyi
new file mode 100644
index 0000000..f8680bf
--- /dev/null
+++ b/lib/Crypto/Util/number.pyi
@@ -0,0 +1,19 @@
+from typing import List, Optional, Callable
+
+
+def ceil_div(n: int, d: int) -> int: ...
+def size(N: int) -> int: ...
+def getRandomInteger(N: int, randfunc: Optional[Callable]=None) -> int: ...
+def getRandomRange(a: int, b: int, randfunc: Optional[Callable]=None) -> int: ...
+def getRandomNBitInteger(N: int, randfunc: Optional[Callable]=None) -> int: ...
+def GCD(x: int, y: int) -> int: ...
+def inverse(u: int, v: int) -> int: ...
+def getPrime(N: int, randfunc: Optional[Callable]=None) -> int: ...
+def getStrongPrime(N: int, e: Optional[int]=0, false_positive_prob: Optional[float]=1e-6, randfunc: Optional[Callable]=None) -> int: ...
+def isPrime(N: int, false_positive_prob: Optional[float]=1e-6, randfunc: Optional[Callable]=None) -> bool: ...
+def long_to_bytes(n: int, blocksize: Optional[int]=0) -> bytes: ...
+def bytes_to_long(s: bytes) -> int: ...
+def long2str(n: int, blocksize: Optional[int]=0) -> bytes: ...
+def str2long(s: bytes) -> int: ...
+
+sieve_base: List[int]
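For orientation, a minimal usage sketch of the number-theory helpers whose signatures are stubbed above (a hypothetical snippet, assuming the bundled Crypto package under lib/ is importable; the 256-bit size is chosen only for illustration):

    from Crypto.Util.number import getPrime, isPrime, long_to_bytes, bytes_to_long

    # Generate a random 256-bit probable prime and sanity-check it.
    p = getPrime(256)
    assert isPrime(p)

    # Round-trip the integer through its big-endian byte representation.
    buf = long_to_bytes(p)
    assert bytes_to_long(buf) == p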
diff --git a/lib/Crypto/Util/py3compat.py b/lib/Crypto/Util/py3compat.py
new file mode 100644
index 0000000..3608357
--- /dev/null
+++ b/lib/Crypto/Util/py3compat.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+#
+# Util/py3compat.py : Compatibility code for handling Py3k / Python 2.x
+#
+# Written in 2010 by Thorsten Behrens
+#
+# ===================================================================
+# The contents of this file are dedicated to the public domain. To
+# the extent that dedication to the public domain is not available,
+# everyone is granted a worldwide, perpetual, royalty-free,
+# non-exclusive license to exercise all rights associated with the
+# contents of this file for any purpose whatsoever.
+# No rights are reserved.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# ===================================================================
+
+"""Compatibility code for handling string/bytes changes from Python 2.x to Py3k
+
+In Python 2.x, strings (of type ''str'') contain binary data, including encoded
+Unicode text (e.g. UTF-8). The separate type ''unicode'' holds Unicode text.
+Unicode literals are specified via the u'...' prefix. Indexing or slicing
+either type always produces a string of the same type as the original.
+Data read from a file is always of ''str'' type.
+
+In Python 3.x, strings (type ''str'') may only contain Unicode text. The u'...'
+prefix and the ''unicode'' type are now redundant. A new type (called
+''bytes'') has to be used for binary data (including any particular
+''encoding'' of a string). The b'...' prefix allows one to specify a binary
+literal. Indexing or slicing a string produces another string. Slicing a byte
+string produces another byte string, but the indexing operation produces an
+integer. Data read from a file is of ''str'' type if the file was opened in
+text mode, or of ''bytes'' type otherwise.
+
+Since PyCrypto aims at supporting both Python 2.x and 3.x, the following helper
+functions are used to keep the rest of the library as independent as possible
+from the actual Python version.
+
+In general, the code should always deal with binary strings, and use integers
+instead of 1-byte character strings.
+
+b(s)
+ Take a text string literal (with no prefix or with u'...' prefix) and
+ make a byte string.
+bchr(c)
+ Take an integer and make a 1-character byte string.
+bord(c)
+ Take the result of indexing on a byte string and make an integer.
+tobytes(s)
+    Take a text string, a byte string, or a sequence of characters taken from
+ a byte string and make a byte string.
+"""
+
+import sys
+import abc
+
+
+if sys.version_info[0] == 2:
+ def b(s):
+ return s
+ def bchr(s):
+ return chr(s)
+ def bstr(s):
+ return str(s)
+ def bord(s):
+ return ord(s)
+ def tobytes(s, encoding="latin-1"):
+ if isinstance(s, unicode):
+ return s.encode(encoding)
+ elif isinstance(s, str):
+ return s
+ elif isinstance(s, bytearray):
+ return bytes(s)
+ elif isinstance(s, memoryview):
+ return s.tobytes()
+ else:
+ return ''.join(s)
+ def tostr(bs):
+ return bs
+ def byte_string(s):
+ return isinstance(s, str)
+
+ from StringIO import StringIO
+ BytesIO = StringIO
+
+ from sys import maxint
+
+ iter_range = xrange
+
+ def is_native_int(x):
+ return isinstance(x, (int, long))
+
+ def is_string(x):
+ return isinstance(x, basestring)
+
+ def is_bytes(x):
+ return isinstance(x, str) or \
+ isinstance(x, bytearray) or \
+ isinstance(x, memoryview)
+
+ ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()})
+
+ FileNotFoundError = IOError
+
+else:
+ def b(s):
+ return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
+ def bchr(s):
+ return bytes([s])
+ def bstr(s):
+        if isinstance(s, str):
+            return bytes(s, "latin-1")
+ else:
+ return bytes(s)
+ def bord(s):
+ return s
+ def tobytes(s, encoding="latin-1"):
+ if isinstance(s, bytes):
+ return s
+ elif isinstance(s, bytearray):
+ return bytes(s)
+        elif isinstance(s, str):
+ return s.encode(encoding)
+ elif isinstance(s, memoryview):
+ return s.tobytes()
+ else:
+ return bytes([s])
+ def tostr(bs):
+ return bs.decode("latin-1")
+ def byte_string(s):
+ return isinstance(s, bytes)
+
+ from io import BytesIO
+ from io import StringIO
+ from sys import maxsize as maxint
+
+ iter_range = range
+
+ def is_native_int(x):
+ return isinstance(x, int)
+
+ def is_string(x):
+ return isinstance(x, str)
+
+ def is_bytes(x):
+ return isinstance(x, bytes) or \
+ isinstance(x, bytearray) or \
+ isinstance(x, memoryview)
+
+ from abc import ABC
+
+ FileNotFoundError = FileNotFoundError
+
+
+def _copy_bytes(start, end, seq):
+ """Return an immutable copy of a sequence (byte string, byte array, memoryview)
+    in a certain interval [start:end]."""
+
+ if isinstance(seq, memoryview):
+ return seq[start:end].tobytes()
+ elif isinstance(seq, bytearray):
+ return bytes(seq[start:end])
+ else:
+ return seq[start:end]
+
+del sys
+del abc
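A short sketch of how the compatibility helpers above behave on Python 3 (the sample values are assumptions for illustration, not part of the diff):

    from Crypto.Util.py3compat import b, bchr, bord, tobytes, tostr

    b("abc")            # b'abc'     - text literal encoded to bytes (latin-1)
    bchr(65)            # b'A'       - integer to a 1-byte string
    bord(b"A"[0])       # 65         - result of indexing a byte string, as an integer
    tobytes("caf\xe9")  # b'caf\xe9' - text encoded with the default latin-1
    tostr(b"abc")       # 'abc'      - bytes decoded back to text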
diff --git a/lib/Crypto/Util/py3compat.pyi b/lib/Crypto/Util/py3compat.pyi
new file mode 100644
index 0000000..74e04a2
--- /dev/null
+++ b/lib/Crypto/Util/py3compat.pyi
@@ -0,0 +1,33 @@
+from typing import Union, Any, Optional, IO
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+import sys
+
+def b(s: str) -> bytes: ...
+def bchr(s: int) -> bytes: ...
+def bord(s: bytes) -> int: ...
+def tobytes(s: Union[bytes, str]) -> bytes: ...
+def tostr(b: bytes) -> str: ...
+def byte_string(x: Any) -> bool: ...
+
+def is_native_int(s: Any) -> bool: ...
+def is_string(x: Any) -> bool: ...
+def is_bytes(x: Any) -> bool: ...
+
+def BytesIO(b: bytes) -> IO[bytes]: ...
+def StringIO(s: str) -> IO[str]: ...
+
+if sys.version_info[0] == 2:
+ from sys import maxint
+ iter_range = xrange
+
+else:
+ from sys import maxsize as maxint
+ iter_range = range
+
+class FileNotFoundError:
+ def __init__(self, err: int, msg: str, filename: str) -> None:
+ pass
+
+def _copy_bytes(start: Optional[int], end: Optional[int], seq: Buffer) -> bytes: ...
diff --git a/lib/Crypto/Util/strxor.py b/lib/Crypto/Util/strxor.py
new file mode 100644
index 0000000..362db6e
--- /dev/null
+++ b/lib/Crypto/Util/strxor.py
@@ -0,0 +1,146 @@
+# ===================================================================
+#
+# Copyright (c) 2014, Legrandin <helderijs@gmail.com>
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# 2. Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in
+# the documentation and/or other materials provided with the
+# distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+# ===================================================================
+
+from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, c_size_t,
+ create_string_buffer, get_raw_buffer,
+ c_uint8_ptr, is_writeable_buffer)
+
+_raw_strxor = load_pycryptodome_raw_lib(
+ "Crypto.Util._strxor",
+ """
+ void strxor(const uint8_t *in1,
+ const uint8_t *in2,
+ uint8_t *out, size_t len);
+ void strxor_c(const uint8_t *in,
+ uint8_t c,
+ uint8_t *out,
+ size_t len);
+ """)
+
+
+def strxor(term1, term2, output=None):
+ """From two byte strings of equal length,
+ create a third one which is the byte-by-byte XOR of the two.
+
+ Args:
+ term1 (bytes/bytearray/memoryview):
+ The first byte string to XOR.
+ term2 (bytes/bytearray/memoryview):
+ The second byte string to XOR.
+ output (bytearray/memoryview):
+ The location where the result will be written to.
+ It must have the same length as ``term1`` and ``term2``.
+ If ``None``, the result is returned.
+    Return:
+ If ``output`` is ``None``, a new byte string with the result.
+ Otherwise ``None``.
+
+ .. note::
+ ``term1`` and ``term2`` must have the same length.
+ """
+
+ if len(term1) != len(term2):
+ raise ValueError("Only byte strings of equal length can be xored")
+
+ if output is None:
+ result = create_string_buffer(len(term1))
+ else:
+ # Note: output may overlap with either input
+ result = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(term1) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(term1))
+
+ _raw_strxor.strxor(c_uint8_ptr(term1),
+ c_uint8_ptr(term2),
+ c_uint8_ptr(result),
+ c_size_t(len(term1)))
+
+ if output is None:
+ return get_raw_buffer(result)
+ else:
+ return None
+
+
+def strxor_c(term, c, output=None):
+ """From a byte string, create a second one of equal length
+    where each byte is XOR-ed with the same value.
+
+ Args:
+ term(bytes/bytearray/memoryview):
+ The byte string to XOR.
+ c (int):
+ Every byte in the string will be XOR-ed with this value.
+ It must be between 0 and 255 (included).
+ output (None or bytearray/memoryview):
+ The location where the result will be written to.
+ It must have the same length as ``term``.
+ If ``None``, the result is returned.
+
+ Return:
+ If ``output`` is ``None``, a new ``bytes`` string with the result.
+ Otherwise ``None``.
+ """
+
+ if not 0 <= c < 256:
+ raise ValueError("c must be in range(256)")
+
+ if output is None:
+ result = create_string_buffer(len(term))
+ else:
+ # Note: output may overlap with either input
+ result = output
+
+ if not is_writeable_buffer(output):
+ raise TypeError("output must be a bytearray or a writeable memoryview")
+
+ if len(term) != len(output):
+ raise ValueError("output must have the same length as the input"
+ " (%d bytes)" % len(term))
+
+ _raw_strxor.strxor_c(c_uint8_ptr(term),
+ c,
+ c_uint8_ptr(result),
+ c_size_t(len(term))
+ )
+
+ if output is None:
+ return get_raw_buffer(result)
+ else:
+ return None
+
+
+def _strxor_direct(term1, term2, result):
+ """Very fast XOR - check conditions!"""
+ _raw_strxor.strxor(term1, term2, result, c_size_t(len(term1)))
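A brief usage sketch for the two public helpers defined above (byte values chosen purely for illustration):

    from Crypto.Util.strxor import strxor, strxor_c

    strxor(b"\x00\xff\x0f", b"\xff\xff\xf0")   # b'\xff\x00\xff'
    strxor_c(b"abc", 0x20)                     # b'ABC' - flips the ASCII case bit

    # Writing into a preallocated buffer instead of allocating a new bytes object:
    out = bytearray(3)
    strxor(b"\x01\x02\x03", b"\x01\x02\x03", output=out)   # returns None
    assert bytes(out) == b"\x00\x00\x00"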
diff --git a/lib/Crypto/Util/strxor.pyi b/lib/Crypto/Util/strxor.pyi
new file mode 100644
index 0000000..ca896f3
--- /dev/null
+++ b/lib/Crypto/Util/strxor.pyi
@@ -0,0 +1,6 @@
+from typing import Union, Optional
+
+Buffer = Union[bytes, bytearray, memoryview]
+
+def strxor(term1: bytes, term2: bytes, output: Optional[Buffer]=...) -> bytes: ...
+def strxor_c(term: bytes, c: int, output: Optional[Buffer]=...) -> bytes: ...
diff --git a/lib/Crypto/__init__.py b/lib/Crypto/__init__.py
new file mode 100644
index 0000000..9c2f83b
--- /dev/null
+++ b/lib/Crypto/__init__.py
@@ -0,0 +1,6 @@
+__all__ = ['Cipher', 'Hash', 'Protocol', 'PublicKey', 'Util', 'Signature',
+ 'IO', 'Math']
+
+version_info = (3, 15, '0')
+
+__version__ = ".".join([str(x) for x in version_info])
diff --git a/lib/Crypto/__init__.pyi b/lib/Crypto/__init__.pyi
new file mode 100644
index 0000000..bc73446
--- /dev/null
+++ b/lib/Crypto/__init__.pyi
@@ -0,0 +1,4 @@
+from typing import Tuple, Union
+
+version_info : Tuple[int, int, Union[int, str]]
+__version__ : str
diff --git a/lib/Crypto/py.typed b/lib/Crypto/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/Crypto/py.typed
diff --git a/lib/SQLAlchemy-1.4.40.dist-info/INSTALLER b/lib/SQLAlchemy-1.4.40.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/SQLAlchemy-1.4.40.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/SQLAlchemy-1.4.40.dist-info/LICENSE b/lib/SQLAlchemy-1.4.40.dist-info/LICENSE
new file mode 100644
index 0000000..c933e4b
--- /dev/null
+++ b/lib/SQLAlchemy-1.4.40.dist-info/LICENSE
@@ -0,0 +1,19 @@
+Copyright 2005-2022 SQLAlchemy authors and contributors <see AUTHORS file>.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/lib/SQLAlchemy-1.4.40.dist-info/METADATA b/lib/SQLAlchemy-1.4.40.dist-info/METADATA
new file mode 100644
index 0000000..825633f
--- /dev/null
+++ b/lib/SQLAlchemy-1.4.40.dist-info/METADATA
@@ -0,0 +1,237 @@
+Metadata-Version: 2.1
+Name: SQLAlchemy
+Version: 1.4.40
+Summary: Database Abstraction Library
+Home-page: https://www.sqlalchemy.org
+Author: Mike Bayer
+Author-email: mike_mp@zzzcomputing.com
+License: MIT
+Project-URL: Documentation, https://docs.sqlalchemy.org
+Project-URL: Issue Tracker, https://github.com/sqlalchemy/sqlalchemy/
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Topic :: Database :: Front-Ends
+Requires-Python: !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7
+Description-Content-Type: text/x-rst
+License-File: LICENSE
+Requires-Dist: importlib-metadata ; python_version < "3.8"
+Requires-Dist: greenlet (!=0.4.17) ; python_version >= "3" and (platform_machine == "aarch64" or (platform_machine == "ppc64le" or (platform_machine == "x86_64" or (platform_machine == "amd64" or (platform_machine == "AMD64" or (platform_machine == "win32" or platform_machine == "WIN32"))))))
+Provides-Extra: aiomysql
+Requires-Dist: greenlet (!=0.4.17) ; (python_version >= "3") and extra == 'aiomysql'
+Requires-Dist: aiomysql ; (python_version >= "3") and extra == 'aiomysql'
+Provides-Extra: aiosqlite
+Requires-Dist: typing-extensions (!=3.10.0.1) ; extra == 'aiosqlite'
+Requires-Dist: greenlet (!=0.4.17) ; (python_version >= "3") and extra == 'aiosqlite'
+Requires-Dist: aiosqlite ; (python_version >= "3") and extra == 'aiosqlite'
+Provides-Extra: asyncio
+Requires-Dist: greenlet (!=0.4.17) ; (python_version >= "3") and extra == 'asyncio'
+Provides-Extra: asyncmy
+Requires-Dist: greenlet (!=0.4.17) ; (python_version >= "3") and extra == 'asyncmy'
+Requires-Dist: asyncmy (!=0.2.4,>=0.2.3) ; (python_version >= "3") and extra == 'asyncmy'
+Provides-Extra: mariadb_connector
+Requires-Dist: mariadb (!=1.1.2,>=1.0.1) ; (python_version >= "3") and extra == 'mariadb_connector'
+Provides-Extra: mssql
+Requires-Dist: pyodbc ; extra == 'mssql'
+Provides-Extra: mssql_pymssql
+Requires-Dist: pymssql ; extra == 'mssql_pymssql'
+Provides-Extra: mssql_pyodbc
+Requires-Dist: pyodbc ; extra == 'mssql_pyodbc'
+Provides-Extra: mypy
+Requires-Dist: sqlalchemy2-stubs ; extra == 'mypy'
+Requires-Dist: mypy (>=0.910) ; (python_version >= "3") and extra == 'mypy'
+Provides-Extra: mysql
+Requires-Dist: mysqlclient (<2,>=1.4.0) ; (python_version < "3") and extra == 'mysql'
+Requires-Dist: mysqlclient (>=1.4.0) ; (python_version >= "3") and extra == 'mysql'
+Provides-Extra: mysql_connector
+Requires-Dist: mysql-connector-python ; extra == 'mysql_connector'
+Provides-Extra: oracle
+Requires-Dist: cx-oracle (<8,>=7) ; (python_version < "3") and extra == 'oracle'
+Requires-Dist: cx-oracle (>=7) ; (python_version >= "3") and extra == 'oracle'
+Provides-Extra: postgresql
+Requires-Dist: psycopg2 (>=2.7) ; extra == 'postgresql'
+Provides-Extra: postgresql_asyncpg
+Requires-Dist: greenlet (!=0.4.17) ; (python_version >= "3") and extra == 'postgresql_asyncpg'
+Requires-Dist: asyncpg ; (python_version >= "3") and extra == 'postgresql_asyncpg'
+Provides-Extra: postgresql_pg8000
+Requires-Dist: pg8000 (!=1.29.0,>=1.16.6) ; extra == 'postgresql_pg8000'
+Provides-Extra: postgresql_psycopg2binary
+Requires-Dist: psycopg2-binary ; extra == 'postgresql_psycopg2binary'
+Provides-Extra: postgresql_psycopg2cffi
+Requires-Dist: psycopg2cffi ; extra == 'postgresql_psycopg2cffi'
+Provides-Extra: pymysql
+Requires-Dist: pymysql (<1) ; (python_version < "3") and extra == 'pymysql'
+Requires-Dist: pymysql ; (python_version >= "3") and extra == 'pymysql'
+Provides-Extra: sqlcipher
+Requires-Dist: sqlcipher3-binary ; (python_version >= "3") and extra == 'sqlcipher'
+
+SQLAlchemy
+==========
+
+|PyPI| |Python| |Downloads|
+
+.. |PyPI| image:: https://img.shields.io/pypi/v/sqlalchemy
+ :target: https://pypi.org/project/sqlalchemy
+ :alt: PyPI
+
+.. |Python| image:: https://img.shields.io/pypi/pyversions/sqlalchemy
+ :target: https://pypi.org/project/sqlalchemy
+ :alt: PyPI - Python Version
+
+.. |Downloads| image:: https://img.shields.io/pypi/dm/sqlalchemy
+ :target: https://pypi.org/project/sqlalchemy
+ :alt: PyPI - Downloads
+
+
+The Python SQL Toolkit and Object Relational Mapper
+
+Introduction
+-------------
+
+SQLAlchemy is the Python SQL toolkit and Object Relational Mapper
+that gives application developers the full power and
+flexibility of SQL. SQLAlchemy provides a full suite
+of well known enterprise-level persistence patterns,
+designed for efficient and high-performing database
+access, adapted into a simple and Pythonic domain
+language.
+
+Major SQLAlchemy features include:
+
+* An industrial strength ORM, built
+ from the core on the identity map, unit of work,
+ and data mapper patterns. These patterns
+ allow transparent persistence of objects
+ using a declarative configuration system.
+ Domain models
+ can be constructed and manipulated naturally,
+ and changes are synchronized with the
+ current transaction automatically.
+* A relationally-oriented query system, exposing
+ the full range of SQL's capabilities
+ explicitly, including joins, subqueries,
+ correlation, and most everything else,
+ in terms of the object model.
+ Writing queries with the ORM uses the same
+ techniques of relational composition you use
+ when writing SQL. While you can drop into
+ literal SQL at any time, it's virtually never
+ needed.
+* A comprehensive and flexible system
+ of eager loading for related collections and objects.
+ Collections are cached within a session,
+ and can be loaded on individual access, all
+ at once using joins, or by query per collection
+ across the full result set.
+* A Core SQL construction system and DBAPI
+ interaction layer. The SQLAlchemy Core is
+ separate from the ORM and is a full database
+ abstraction layer in its own right, and includes
+ an extensible Python-based SQL expression
+ language, schema metadata, connection pooling,
+ type coercion, and custom types.
+* All primary and foreign key constraints are
+ assumed to be composite and natural. Surrogate
+ integer primary keys are of course still the
+ norm, but SQLAlchemy never assumes or hardcodes
+ to this model.
+* Database introspection and generation. Database
+ schemas can be "reflected" in one step into
+ Python structures representing database metadata;
+ those same structures can then generate
+ CREATE statements right back out - all within
+ the Core, independent of the ORM.
+
+SQLAlchemy's philosophy:
+
+* SQL databases behave less and less like object
+ collections the more size and performance start to
+ matter; object collections behave less and less like
+ tables and rows the more abstraction starts to matter.
+ SQLAlchemy aims to accommodate both of these
+ principles.
+* An ORM doesn't need to hide the "R". A relational
+ database provides rich, set-based functionality
+ that should be fully exposed. SQLAlchemy's
+ ORM provides an open-ended set of patterns
+ that allow a developer to construct a custom
+ mediation layer between a domain model and
+ a relational schema, turning the so-called
+ "object relational impedance" issue into
+ a distant memory.
+* The developer, in all cases, makes all decisions
+ regarding the design, structure, and naming conventions
+ of both the object model as well as the relational
+ schema. SQLAlchemy only provides the means
+ to automate the execution of these decisions.
+* With SQLAlchemy, there's no such thing as
+ "the ORM generated a bad query" - you
+ retain full control over the structure of
+ queries, including how joins are organized,
+  how subqueries and correlation are used, what
+ columns are requested. Everything SQLAlchemy
+ does is ultimately the result of a developer-
+ initiated decision.
+* Don't use an ORM if the problem doesn't need one.
+ SQLAlchemy consists of a Core and separate ORM
+ component. The Core offers a full SQL expression
+ language that allows Pythonic construction
+ of SQL constructs that render directly to SQL
+ strings for a target database, returning
+ result sets that are essentially enhanced DBAPI
+ cursors.
+* Transactions should be the norm. With SQLAlchemy's
+ ORM, nothing goes to permanent storage until
+ commit() is called. SQLAlchemy encourages applications
+ to create a consistent means of delineating
+ the start and end of a series of operations.
+* Never render a literal value in a SQL statement.
+ Bound parameters are used to the greatest degree
+ possible, allowing query optimizers to cache
+ query plans effectively and making SQL injection
+ attacks a non-issue.
+
+Documentation
+-------------
+
+Latest documentation is at:
+
+https://www.sqlalchemy.org/docs/
+
+Installation / Requirements
+---------------------------
+
+Full documentation for installation is at
+`Installation <https://www.sqlalchemy.org/docs/intro.html#installation>`_.
+
+Getting Help / Development / Bug reporting
+------------------------------------------
+
+Please refer to the `SQLAlchemy Community Guide <https://www.sqlalchemy.org/support.html>`_.
+
+Code of Conduct
+---------------
+
+Above all, SQLAlchemy places great emphasis on polite, thoughtful, and
+constructive communication between users and developers.
+Please see our current Code of Conduct at
+`Code of Conduct <https://www.sqlalchemy.org/codeofconduct.html>`_.
+
+License
+-------
+
+SQLAlchemy is distributed under the `MIT license
+<https://www.opensource.org/licenses/mit-license.php>`_.
+
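To make the Core-versus-ORM distinction described in the README concrete, a minimal sketch using only the Core expression language against an in-memory SQLite database (table name and data are invented for illustration; assumes SQLAlchemy 1.4 as packaged here):

    from sqlalchemy import (MetaData, Table, Column, Integer, String,
                            create_engine, select)

    engine = create_engine("sqlite://")            # in-memory database
    metadata = MetaData()
    users = Table(
        "users", metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )
    metadata.create_all(engine)                    # Core emits the CREATE TABLE

    with engine.begin() as conn:                   # block commits on exit
        conn.execute(users.insert(), [{"name": "alice"}, {"name": "bob"}])

    with engine.connect() as conn:
        rows = conn.execute(select(users).where(users.c.name == "alice")).all()
        print(rows)                                # [(1, 'alice')]

Note how the WHERE clause value travels as a bound parameter rather than being rendered into the SQL string, matching the "never render a literal value" principle stated above.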
diff --git a/lib/SQLAlchemy-1.4.40.dist-info/RECORD b/lib/SQLAlchemy-1.4.40.dist-info/RECORD
new file mode 100644
index 0000000..a2a1187
--- /dev/null
+++ b/lib/SQLAlchemy-1.4.40.dist-info/RECORD
@@ -0,0 +1,486 @@
+SQLAlchemy-1.4.40.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+SQLAlchemy-1.4.40.dist-info/LICENSE,sha256=hZ3tJdo0wetz5uc230xfjOPtLtUpBmMXbwbncg2cmiA,1100
+SQLAlchemy-1.4.40.dist-info/METADATA,sha256=YzemP5m4ZlRnJJTzRT1bGGHKByzIY4hF2pZZUgI-cOo,9972
+SQLAlchemy-1.4.40.dist-info/RECORD,,
+SQLAlchemy-1.4.40.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+SQLAlchemy-1.4.40.dist-info/WHEEL,sha256=RvDNC7WG_jtA8tJl5Arh36KKeuDAxaR7gtO4xTAfLVM,217
+SQLAlchemy-1.4.40.dist-info/top_level.txt,sha256=rp-ZgB7D8G11ivXON5VGPjupT1voYmWqkciDt5Uaw_Q,11
+sqlalchemy/__init__.py,sha256=210SNMBE7tXBO_CWMol8KnRpMie57CPWuE_rL0lmF3Y,4114
+sqlalchemy/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/__pycache__/events.cpython-39.pyc,,
+sqlalchemy/__pycache__/exc.cpython-39.pyc,,
+sqlalchemy/__pycache__/inspection.cpython-39.pyc,,
+sqlalchemy/__pycache__/log.cpython-39.pyc,,
+sqlalchemy/__pycache__/processors.cpython-39.pyc,,
+sqlalchemy/__pycache__/schema.cpython-39.pyc,,
+sqlalchemy/__pycache__/types.cpython-39.pyc,,
+sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so,sha256=rupQuvoVfaUzECRjcEXMI2v-ogC0xKi68FUD1LGv8mg,53952
+sqlalchemy/connectors/__init__.py,sha256=2m_LPZFkNExkoaTw14fRActQCcyFl7W81WeYj2O10lM,279
+sqlalchemy/connectors/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/connectors/__pycache__/mxodbc.cpython-39.pyc,,
+sqlalchemy/connectors/__pycache__/pyodbc.cpython-39.pyc,,
+sqlalchemy/connectors/mxodbc.py,sha256=CApFVkPEL8amXL5HKcG83jU9RbbVg0EQSyxceLWh260,5784
+sqlalchemy/connectors/pyodbc.py,sha256=003bqMmK-Hpy-kZYa4vy2CNRz73Fvvj2zUsyhFQnkUc,6855
+sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so,sha256=nae6_co12AzkqeYd1bf5yzB6Qyjs7DWx8em_EV_nWXI,60640
+sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so,sha256=SHvjVzsNaFResbbfKoYz2fuxqUI7Om_O6AUY5-ib0Qo,92632
+sqlalchemy/databases/__init__.py,sha256=LAm4NHQgjg4sdCED02wUiZj9_0fKBEkStYtqvLWHArk,1010
+sqlalchemy/databases/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/__init__.py,sha256=52RcDU2JGS1nW2OHx2nIJ1B_IBI4puWFx09th8Hg-D0,2085
+sqlalchemy/dialects/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/firebird/__init__.py,sha256=iZH9WTMjUcsAf6Rl6-64CkcoLOixitP45TSZVSBQYL4,1153
+sqlalchemy/dialects/firebird/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/firebird/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/dialects/firebird/__pycache__/fdb.cpython-39.pyc,,
+sqlalchemy/dialects/firebird/__pycache__/kinterbasdb.cpython-39.pyc,,
+sqlalchemy/dialects/firebird/base.py,sha256=P0ycKcsMKJyglm6uikAVDSc_7UV0NPSIU7hL58HQaog,31171
+sqlalchemy/dialects/firebird/fdb.py,sha256=lQhO8S1P8PjUeEW3NXCC1vqNp1DGzBQIUN2eIi-fCC0,4116
+sqlalchemy/dialects/firebird/kinterbasdb.py,sha256=2_RZGXSw12FCEeZW0cXxbaR2Bl7GfMd7gGg5pgUiFzg,6479
+sqlalchemy/dialects/mssql/__init__.py,sha256=fvIR7jRTPH_4HellLg2kjwYIA3HM_jpNWSw9De0JciE,1788
+sqlalchemy/dialects/mssql/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/__pycache__/information_schema.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/__pycache__/json.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/__pycache__/mxodbc.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/__pycache__/provision.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/__pycache__/pymssql.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/__pycache__/pyodbc.cpython-39.pyc,,
+sqlalchemy/dialects/mssql/base.py,sha256=U0nbzzRNV-BH4HP7kAriTiZRiPN_O2LEcBhkdUhHuGE,115957
+sqlalchemy/dialects/mssql/information_schema.py,sha256=R0xpK7xppti2ToGahDksb9jHy9R9MyHTwCfgeNvw3BQ,7584
+sqlalchemy/dialects/mssql/json.py,sha256=K1RqVl5bslYyVMtk5CWGjRV_I4K1sszXjx2F_nbCVWI,4558
+sqlalchemy/dialects/mssql/mxodbc.py,sha256=HPIxqFtSUY9Ugz-ebNb2T_sLoLp4rQi7qrmezsIYIsM,4808
+sqlalchemy/dialects/mssql/provision.py,sha256=m7ofLZYZinDS91Vgs42fK7dhJNnH-J_Bw2x_tP59tCc,4255
+sqlalchemy/dialects/mssql/pymssql.py,sha256=smbS466-7-cr1o2VBRqkTHlfczy_UIycMgM4uerI5Xw,4843
+sqlalchemy/dialects/mssql/pyodbc.py,sha256=T__b7XXLrPAp0eo80ykgelUZQvncF9GcxccPDz_zOgw,24432
+sqlalchemy/dialects/mysql/__init__.py,sha256=4C8GY2nAGQOrdGj3CseZqF4NR-CkhVZ_CgXFoskGAJs,2190
+sqlalchemy/dialects/mysql/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/aiomysql.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/asyncmy.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/cymysql.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/dml.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/enumerated.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/expression.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/json.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/mariadb.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/mariadbconnector.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/mysqlconnector.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/mysqldb.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/oursql.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/provision.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/pymysql.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/pyodbc.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/reflection.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/reserved_words.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/__pycache__/types.cpython-39.pyc,,
+sqlalchemy/dialects/mysql/aiomysql.py,sha256=Xqfr0SjvUu-qQZgrDLBnxo4dRQF9ZrI6tpc4HgiXENE,9609
+sqlalchemy/dialects/mysql/asyncmy.py,sha256=D8slHiFP3hOvwxf8zMY_-72V1owEhnpO0LmQdkz4n4M,9885
+sqlalchemy/dialects/mysql/base.py,sha256=C8fmeH2RVPD9NloDL8ofFLmA-89j1Sd0CDkR0rv7duw,115228
+sqlalchemy/dialects/mysql/cymysql.py,sha256=zaVxpSLTg8rvIrI6BtlK0815BCLuLKp2ILHLs57thVA,2271
+sqlalchemy/dialects/mysql/dml.py,sha256=EXTHGjiXeNxGyt-jbRH5ZNIkRjTja25gQXAthTCCw8g,6226
+sqlalchemy/dialects/mysql/enumerated.py,sha256=Dv5BAF8DxCqfVXIkXt5kzGG-BxNygpdnXrZjyyzKyqM,9364
+sqlalchemy/dialects/mysql/expression.py,sha256=HJ4IO3LPJk4cQYIL-O-jN2vLWxVGCqem_K3h8kKNWzE,3741
+sqlalchemy/dialects/mysql/json.py,sha256=DMQnyo3PQ_XSPvDl8jt26Ya-fyMEaIJDXQBdLVmsdjE,2313
+sqlalchemy/dialects/mysql/mariadb.py,sha256=OBwN9RMQLP-xqLbNMAe5uoz7PEtqa68ln2HwwA6KUn8,585
+sqlalchemy/dialects/mysql/mariadbconnector.py,sha256=vLhoFmC9OFh30bHGRFBwWHv3ou3wTZ8WPZOamgmUuWs,7563
+sqlalchemy/dialects/mysql/mysqlconnector.py,sha256=CT4bFb2WaFHwBDfRSqK3ieltrkulTYwsX0kgbWPrRao,7690
+sqlalchemy/dialects/mysql/mysqldb.py,sha256=qvea9Iuf6SUqb4QSHeCEcbUf3c3FSckjT4jfQSTMlyw,10437
+sqlalchemy/dialects/mysql/oursql.py,sha256=fWWMyvhNZ6ywBGvvwJ8DqtBec8cUtziiIjYopBn2WVg,8523
+sqlalchemy/dialects/mysql/provision.py,sha256=P5ma4Xy5eSOFIcMjIe_zAwu_6ncSXSLVZYYSMS5Io9c,2649
+sqlalchemy/dialects/mysql/pymysql.py,sha256=D106c8jEME1O0wOMV7ZgSuwin7Pv61kKLWYFEEKPpUY,2770
+sqlalchemy/dialects/mysql/pyodbc.py,sha256=31587UnRrSQhep_NXt7ii0-3xkAVDJgCGQXSDCpDDuY,4290
+sqlalchemy/dialects/mysql/reflection.py,sha256=TcX8NovMj9BbEkGUALcom4JMAMPc_kL6nmQzSy5XXGs,18553
+sqlalchemy/dialects/mysql/reserved_words.py,sha256=vvAyUvobiAB46Lpd7DhyWPgp3cWdFaVu9_5P39TEXMM,9104
+sqlalchemy/dialects/mysql/types.py,sha256=MrMLGeFo-zJJfGMn39smAfxy5fPvQrgXv49cIrm6Img,24589
+sqlalchemy/dialects/oracle/__init__.py,sha256=POVn6bB3yD-b4ZT7CSYQlmNpxDRIRpfuJ8CTTYgphPM,1229
+sqlalchemy/dialects/oracle/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/oracle/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/dialects/oracle/__pycache__/cx_oracle.cpython-39.pyc,,
+sqlalchemy/dialects/oracle/__pycache__/provision.cpython-39.pyc,,
+sqlalchemy/dialects/oracle/base.py,sha256=8jixA3aDMW-cyclxBOFIGnpFCVJuixy1raBhmkoaau4,87563
+sqlalchemy/dialects/oracle/cx_oracle.py,sha256=78Igd2RmfFXNGSMllfhMPRu-AUbBVGKZ3_VI6a9ouh4,53202
+sqlalchemy/dialects/oracle/provision.py,sha256=GtHrw1rtW0bzPSa9dUE-IjDFGaElyJyw4rwHAK3QDVY,5806
+sqlalchemy/dialects/postgresql/__init__.py,sha256=thvDDu6Vp68lXdF78wagnnOTq7sFBCDwT5X9x8Mygn8,2509
+sqlalchemy/dialects/postgresql/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/array.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/asyncpg.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/dml.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/ext.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/hstore.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/json.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/pg8000.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/provision.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/psycopg2.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/psycopg2cffi.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/pygresql.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/pypostgresql.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/__pycache__/ranges.cpython-39.pyc,,
+sqlalchemy/dialects/postgresql/array.py,sha256=I-4mTmrRsJSr42EpoUy4OvMBys82PrDO3oMK8FusPbg,13131
+sqlalchemy/dialects/postgresql/asyncpg.py,sha256=FFVn3cctgxfTvRtVV1XUlzmXfIBFgLsqjDnNjj9K9GA,35265
+sqlalchemy/dialects/postgresql/base.py,sha256=MyxOUhYQFvOiKfX207ZlKR5ap5TbWUrwcGyE_IJF1T0,159101
+sqlalchemy/dialects/postgresql/dml.py,sha256=O7GBPR4liaOBBJWGlEU86vrfuzLMy3d3LIbeRZ-nSvc,9582
+sqlalchemy/dialects/postgresql/ext.py,sha256=oIjhNMC4OAYFOyUx21dX-8XIPRRsyqJxiG4IeBv0tVA,8439
+sqlalchemy/dialects/postgresql/hstore.py,sha256=UEjWkExqERMXkK-62Drv8ppJOyV64827xHxi3QpKV-I,12696
+sqlalchemy/dialects/postgresql/json.py,sha256=cIABYehcW9j7ctBCAYXhZFGFQeHgLkisVQB1k2ftnT4,10556
+sqlalchemy/dialects/postgresql/pg8000.py,sha256=_UztntjUclGLtty8nvVwlcNtCEFz_9lsQrf-HR7EpLE,17044
+sqlalchemy/dialects/postgresql/provision.py,sha256=ZDFEIOvtpBIgCnj1Q1R3-WDWx7lFnE6kdEGNTDFpzAw,4319
+sqlalchemy/dialects/postgresql/psycopg2.py,sha256=yUbR7QwBtu46n1TssONOtcF7ci6W2YERDZlyIRzVckI,40340
+sqlalchemy/dialects/postgresql/psycopg2cffi.py,sha256=pBRHxI6KgVePwPO_FAFaE7Nces43qPIviDwbtchi8f8,1691
+sqlalchemy/dialects/postgresql/pygresql.py,sha256=oZ847ZkhqqzPeo1BiQnIP7slX7SIbXdoo1OyC5ehChY,8585
+sqlalchemy/dialects/postgresql/pypostgresql.py,sha256=_Kw2eXUEAefflJVA1dZJ7aCGt2Lown3PW3i2ab2Eat0,3693
+sqlalchemy/dialects/postgresql/ranges.py,sha256=AP3ODSZoH9Yf9CeAPy_GpVVLMtK-4rdebmHWYjgKFug,4763
+sqlalchemy/dialects/sqlite/__init__.py,sha256=GwL23FcaoQOso1Sa1RlaF3i5SezqEVjfijvbp8hzRg0,1198
+sqlalchemy/dialects/sqlite/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/__pycache__/aiosqlite.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/__pycache__/dml.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/__pycache__/json.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/__pycache__/provision.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/__pycache__/pysqlcipher.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/__pycache__/pysqlite.cpython-39.pyc,,
+sqlalchemy/dialects/sqlite/aiosqlite.py,sha256=P4oLfXLU5lsjIjgwClPs-l25VRMLub7QtXWjTSQcbNs,9963
+sqlalchemy/dialects/sqlite/base.py,sha256=UZrriowzuSoAbQagvqKyC9HTCV0UjWuqIxB0SBmO07E,88435
+sqlalchemy/dialects/sqlite/dml.py,sha256=hFloxZoqsrew4tlzS0DSMyzdKJ9-HU0z-dLKWVgR5ns,6865
+sqlalchemy/dialects/sqlite/json.py,sha256=oFw4Rt8xw-tkD3IMlm3TDEGe1RqrTyvIuqjABsxn8EI,2518
+sqlalchemy/dialects/sqlite/provision.py,sha256=AQILXN5PBUSM05c-SFSFFhPdFqcQDwdoKtUnvLDac14,4676
+sqlalchemy/dialects/sqlite/pysqlcipher.py,sha256=1MmhAlAaUTnzm7guppjDzGXQ6_OxFtuGzchSiJ0PeRA,5605
+sqlalchemy/dialects/sqlite/pysqlite.py,sha256=_hIHqR-373bMLUr4fDaN3UXHtJzW-fbLatvXZNx2hWg,23441
+sqlalchemy/dialects/sybase/__init__.py,sha256=STn2xh97yskErTEYZAyrptb5vYOqPamvb9-QnYd3aG4,1364
+sqlalchemy/dialects/sybase/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/dialects/sybase/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/dialects/sybase/__pycache__/mxodbc.cpython-39.pyc,,
+sqlalchemy/dialects/sybase/__pycache__/pyodbc.cpython-39.pyc,,
+sqlalchemy/dialects/sybase/__pycache__/pysybase.cpython-39.pyc,,
+sqlalchemy/dialects/sybase/base.py,sha256=rOfZ2sN3BEtwIDo9nvIWe5VpgxVvjjLt4gSxFb9VyC0,32421
+sqlalchemy/dialects/sybase/mxodbc.py,sha256=7U4-Y4mf_o6qzFQraQ7XklDTB0PDddF8u6hFIpuAsCE,939
+sqlalchemy/dialects/sybase/pyodbc.py,sha256=bTbAjgvx2LRlhY94DYl_NXRkbVJAd71_LbIvRCtDPX0,2230
+sqlalchemy/dialects/sybase/pysybase.py,sha256=-i6vGx7UIVX2arQE9_9GM_YcqeiRCawqxcXnngjvRAY,3370
+sqlalchemy/engine/__init__.py,sha256=T44Oyjf2yPp77vDWs8g54h9XVt3FbGRZagKxGxu9XwU,2108
+sqlalchemy/engine/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/characteristics.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/create.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/cursor.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/default.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/events.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/interfaces.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/mock.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/reflection.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/result.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/row.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/strategies.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/url.cpython-39.pyc,,
+sqlalchemy/engine/__pycache__/util.cpython-39.pyc,,
+sqlalchemy/engine/base.py,sha256=Iv9_Fcju-spBWw_E-KAwaPzNXhFM5EE8XOnBUKLqHt4,124586
+sqlalchemy/engine/characteristics.py,sha256=qvd3T8HW470kIxN-x6OzycfjCFdnmbzcaFQeds7KHOw,1817
+sqlalchemy/engine/create.py,sha256=q47BzZWgZVxWAaex60SIbFxkfvDFHkDUH5RU0_WnwdA,30797
+sqlalchemy/engine/cursor.py,sha256=7M0w-Yc7llOGMbxIVYqqC36h5lHJLG66D_Ut6IO9PhQ,68126
+sqlalchemy/engine/default.py,sha256=rJCIDncGRqYn2yIGLunwoEkV9D8wAqFBfCUgPOYppsU,66872
+sqlalchemy/engine/events.py,sha256=_qeDo_mMNXXnpQBSAnaRkE1gg6c-r7P5IT78r0aBUuc,33422
+sqlalchemy/engine/interfaces.py,sha256=Os_4HO7Ebo1Hxl8Eym86oqB3h6E7K8T5LQrOTYtbSpY,58421
+sqlalchemy/engine/mock.py,sha256=wJIFZbvkHwAoi7jCupeyZzuE-J9lqyzhJ6VdrAyMNkw,3626
+sqlalchemy/engine/reflection.py,sha256=w0ix23go8S41ye3kM-UOLGVs-UiLUnS8oJqrWI-z9ow,38930
+sqlalchemy/engine/result.py,sha256=HwRxVtgpu62MdUxOdlv79HbZx4UKJJoN_uqoe1dQ2WA,58992
+sqlalchemy/engine/row.py,sha256=eFw7PtgqNkRSNwMTZPFxKNOBbwZ4V6_eOP8YpYAwRPE,18690
+sqlalchemy/engine/strategies.py,sha256=RzejZkLGzWq6QWWJ6a6fyYDdQac4VWCmORCTYEOZwCM,414
+sqlalchemy/engine/url.py,sha256=CZ1rnmmJJ1HhebnpHhej6xroYgyKk4-teJqWVQmvTYk,26473
+sqlalchemy/engine/util.py,sha256=drzyg95MX5NzC10bSQsqQ-dc3k4N4p009JhQuLUS8r0,8442
+sqlalchemy/event/__init__.py,sha256=I3Y3cjTy0wC_f-pJRX7B-9UizYje3nh3lIHOlL0Xf00,517
+sqlalchemy/event/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/event/__pycache__/api.cpython-39.pyc,,
+sqlalchemy/event/__pycache__/attr.cpython-39.pyc,,
+sqlalchemy/event/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/event/__pycache__/legacy.cpython-39.pyc,,
+sqlalchemy/event/__pycache__/registry.cpython-39.pyc,,
+sqlalchemy/event/api.py,sha256=yTMDO4cZp-CioTgeDfYGR0O4_zxfFZ-EFdNqM-dOw8E,8043
+sqlalchemy/event/attr.py,sha256=jWU2m7uuuq40HfwIlQK27eJ_3Gg92dtqI4kwu8vhmuk,14625
+sqlalchemy/event/base.py,sha256=FCifBVGLxkNkpr4mN608ZRcAraML8bcS5IU8_vAJjRQ,10936
+sqlalchemy/event/legacy.py,sha256=C09AtrcACXF2gL5c8adk2nLUo1oBfnhFHDkBpv3znUg,6270
+sqlalchemy/event/registry.py,sha256=5FuO494J1n2dUYImM9Yz1kl7C8NmO4c4GtKbk_l-S6k,8486
+sqlalchemy/events.py,sha256=VrZuUXHgwyx4kMKEielctzyTWqDlm2gvzMcc38jedoE,467
+sqlalchemy/exc.py,sha256=x9Z-nIkMQ1r3dqdNmVK5cHQq0zVFrdI6oKkXMw_QB3s,21116
+sqlalchemy/ext/__init__.py,sha256=4-X49d1TiOPC-T8JSpaFiMMVNP8JL9bDoBW19wBmXRY,322
+sqlalchemy/ext/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/associationproxy.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/automap.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/baked.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/compiler.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/horizontal_shard.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/hybrid.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/indexable.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/instrumentation.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/mutable.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/orderinglist.cpython-39.pyc,,
+sqlalchemy/ext/__pycache__/serializer.cpython-39.pyc,,
+sqlalchemy/ext/associationproxy.py,sha256=-687A1ZZMgToO6emMUy8kDOQb-GE8OqfM01xNkh3QtQ,51139
+sqlalchemy/ext/asyncio/__init__.py,sha256=XKCzBrSBP_LlqaCKpiMeSPUzwNdQFXUg9GL57EOM9-8,823
+sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/__pycache__/engine.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/__pycache__/events.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/__pycache__/exc.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/__pycache__/result.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/__pycache__/session.cpython-39.pyc,,
+sqlalchemy/ext/asyncio/base.py,sha256=VQmIq-CMEVQpZPMEa0K91tMxZMqKyCCAwJVuCLiG34w,2280
+sqlalchemy/ext/asyncio/engine.py,sha256=k_IBji10URA4P1xAcWG--WL4oHQz5Jo7dSJ6mWAcPKo,26535
+sqlalchemy/ext/asyncio/events.py,sha256=_rh2nSAD_6ZqoIRJihiCKUgzSMLBMxBuZ_gUWLpfbHg,1423
+sqlalchemy/ext/asyncio/exc.py,sha256=3tcIXQPCJROB3P_TkoHmkzy6o_dIIuMcnnu4tJB__ck,639
+sqlalchemy/ext/asyncio/result.py,sha256=OPsKEHnMNP80BJI8kLExY8OQovff_2Wj8Kvxd4t3Ht0,21238
+sqlalchemy/ext/asyncio/scoping.py,sha256=fckFlTcwgGjgurVnp69En-4IFwWRqgUV6ukGgPklDJ4,2960
+sqlalchemy/ext/asyncio/session.py,sha256=xnG1gDwtPK1rNHwPg8_4l4uMpQ0dcDIE4U1z46ugbAY,24025
+sqlalchemy/ext/automap.py,sha256=-x_Ls5a-opmgYwpjDGjmtrR1hqSy7AvKfUthK5UHD2A,45782
+sqlalchemy/ext/baked.py,sha256=DI4hcMk-poznDtAB6S38S7kvo5DXuvrt1CIAT8t5QPw,19969
+sqlalchemy/ext/compiler.py,sha256=Q3Dkj-viLi_1_OFL1EUKsz3RJ8aQk6bYwIflx6tbZR0,22629
+sqlalchemy/ext/declarative/__init__.py,sha256=NS6-oy4iI6AiMaGWGznzYSx4gnB1fOniOGtqPHxC0ms,1842
+sqlalchemy/ext/declarative/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/ext/declarative/__pycache__/extensions.cpython-39.pyc,,
+sqlalchemy/ext/declarative/extensions.py,sha256=QRFpuT6pGz11MzHNNh3L7EJmgkbmNmYinBV4aqPRaMY,16560
+sqlalchemy/ext/horizontal_shard.py,sha256=2NygP6u9SsOlOqCEqkzNbcSshdxtfxOI78XysnJw3S8,8922
+sqlalchemy/ext/hybrid.py,sha256=OSy2ZB-4i46Ai5NYncBQ4VAd19clflN6esAUGAgKxJE,41939
+sqlalchemy/ext/indexable.py,sha256=RZmG2074pMoM9-A3evs2ZKqMn3M9uTc3izAI1cN6HQc,11255
+sqlalchemy/ext/instrumentation.py,sha256=ReSLFxqbHgwAKNwoQQmKHoqYvWCob_WuXlPAEUJk4pk,14386
+sqlalchemy/ext/mutable.py,sha256=55s0_SIfISCtX4EMBlBQ9Xp-CZr5FIe7sb4WGBSx7R0,32328
+sqlalchemy/ext/mypy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sqlalchemy/ext/mypy/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/ext/mypy/__pycache__/apply.cpython-39.pyc,,
+sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-39.pyc,,
+sqlalchemy/ext/mypy/__pycache__/infer.cpython-39.pyc,,
+sqlalchemy/ext/mypy/__pycache__/names.cpython-39.pyc,,
+sqlalchemy/ext/mypy/__pycache__/plugin.cpython-39.pyc,,
+sqlalchemy/ext/mypy/__pycache__/util.cpython-39.pyc,,
+sqlalchemy/ext/mypy/apply.py,sha256=9FIH7jxh6Rl1YDE_3tsacpfNb_8floNQkTuHaNgL7XU,9610
+sqlalchemy/ext/mypy/decl_class.py,sha256=buWnXWGOR71CADPZ0_51S49imTXDo-LjTjWsWhhgee0,17343
+sqlalchemy/ext/mypy/infer.py,sha256=otnyujWtI9x7IqsYMu-c21_AJigyAtsaHW6XmVXcaBk,18028
+sqlalchemy/ext/mypy/names.py,sha256=exMWKhQ7ouSFXojttr0ZadmigT5O_wFQ1rmZ4r7Ks4g,7930
+sqlalchemy/ext/mypy/plugin.py,sha256=6JnnsFCOJVwkF1o6FmXRhBYszq5gmli_lqLZJKMhALA,9245
+sqlalchemy/ext/mypy/util.py,sha256=NuIWpY4W5CXES-3q3lviisWuQhwtaQmkAejOspfrGls,8242
+sqlalchemy/ext/orderinglist.py,sha256=JtRiLDROBsDJnME4kZMDzr3FI6rheP-bd1M-C6zxDPU,13875
+sqlalchemy/ext/serializer.py,sha256=RC0aOS6nlFdA0Agkw_-3iiw7Ah2bZnY7sZVZFGj7vHI,5956
+sqlalchemy/future/__init__.py,sha256=tDG3ddqc3cRE61x7Q32ekTBQONsdy30drnW6KnIB92g,525
+sqlalchemy/future/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/future/__pycache__/engine.cpython-39.pyc,,
+sqlalchemy/future/engine.py,sha256=Ly-M3NGamVrpnA9XOG_nVLra5f7OlmTMmg7dMb2tn4s,16184
+sqlalchemy/future/orm/__init__.py,sha256=EKGpGVxFh3-ZA34c1Ujfy51Z_2oG05CFiSxk48pE1R8,289
+sqlalchemy/future/orm/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/inspection.py,sha256=Bcoh4cUJMKjZHcGQP-_Nz-swGXLVVWidj36W2F35Trg,3051
+sqlalchemy/log.py,sha256=0zxWZ9_FkRwYyjTvTaBGW9wMlRG0dSmbAb7SvW42EfY,7143
+sqlalchemy/orm/__init__.py,sha256=ECAf9d5L7wG58S3ijtNRJaQrdgX3WxDJxTlVVPk0hvk,10964
+sqlalchemy/orm/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/attributes.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/clsregistry.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/collections.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/context.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/decl_api.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/decl_base.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/dependency.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/descriptor_props.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/dynamic.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/evaluator.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/events.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/exc.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/identity.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/instrumentation.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/interfaces.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/loading.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/mapper.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/path_registry.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/persistence.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/properties.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/query.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/relationships.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/scoping.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/session.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/state.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/strategies.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/strategy_options.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/sync.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/unitofwork.cpython-39.pyc,,
+sqlalchemy/orm/__pycache__/util.cpython-39.pyc,,
+sqlalchemy/orm/attributes.py,sha256=u3tFz0hQdKyh_mCD53rWSKzaPnerquZiM9C71MHsOa4,77098
+sqlalchemy/orm/base.py,sha256=HZu51CAOyCjJqGGPJbFqOgqbbA_yQ06Lucxpf-J1B54,15068
+sqlalchemy/orm/clsregistry.py,sha256=i8-S8jCSsslTUlOXmfaxoDDkxy3nYGUiZVUeJlpDERA,13286
+sqlalchemy/orm/collections.py,sha256=YXLS4MyQIGWVAV5S3sXLvJKdfVCFAQsKFymOgxzkSuU,54723
+sqlalchemy/orm/context.py,sha256=S6T5BUNFbFqaSGOU5wJdv-k-m8lb0bvLZBjWV3kDwGs,111549
+sqlalchemy/orm/decl_api.py,sha256=rZSz1jys3n_V2woNUZuV8nciN0VgDZFMAiQdNbLkr10,35564
+sqlalchemy/orm/decl_base.py,sha256=unKLbWcQZ3At3nbqh6wbK8YtyGpywuJoBoCB00KJse8,44746
+sqlalchemy/orm/dependency.py,sha256=RsQ6UtF0Ryl-hgMqw9mm5tqNCZa5bbW56_X1prm6R-8,46987
+sqlalchemy/orm/descriptor_props.py,sha256=mdVGbdKc4N8gCxV2RXDGMFZB3V2aWZARUUH9VOe0K1s,25987
+sqlalchemy/orm/dynamic.py,sha256=heJsZBQSckDO1k2fYd1x1tap6qEDoS2yogx9VapzIY4,15957
+sqlalchemy/orm/evaluator.py,sha256=Cc4vdrYRq8acVhpVmGenA-xn_GRodgkhuplLlqUfrdo,6852
+sqlalchemy/orm/events.py,sha256=_9TO_KPRfTTdL9Lh53w2z8k4GPHOdeG_LGWHW6JwPBQ,111523
+sqlalchemy/orm/exc.py,sha256=dCW9lmc-DpwTJaHo-q8TJac5dK2jWFc4Fes6V8Z_gUo,6532
+sqlalchemy/orm/identity.py,sha256=_UnI-6Beolu3xWGGERUQfVg0dc9sb-3G22Xv8yzfKFg,7233
+sqlalchemy/orm/instrumentation.py,sha256=L4pmTKaGvmRjd8UTThGF-QqScFIWWs9hx6AZA0-rOn0,20378
+sqlalchemy/orm/interfaces.py,sha256=kAu29kzWsyA9j44tkT-Iqia_jqmj2WZpTgTGuJ90_9g,30249
+sqlalchemy/orm/loading.py,sha256=5rAng8kIp5UOLtZd5nUjduDLIhUQ80Sodc9W-jSMc1E,49317
+sqlalchemy/orm/mapper.py,sha256=NHtbt5VmUWWiQLuMkeOWAiYcbSk-FrZLyGXHv78RZug,140420
+sqlalchemy/orm/path_registry.py,sha256=0Akeeayg-OM-pPOAxVCyggGINInYX8kXrQkYWOtesd0,16411
+sqlalchemy/orm/persistence.py,sha256=KW7iYNJpEHjUMVFr_pQkkyvoSC1cfSpmzRvVv1H_sgs,84250
+sqlalchemy/orm/properties.py,sha256=XmmjsU1XBTyIe1mX8DZ2EdavRutLWxO7QN1k2cJVJ4w,14665
+sqlalchemy/orm/query.py,sha256=9aBTx4yskglMfirPKc9u_RwjmtXz2s3Be7dKHCmcEtY,125553
+sqlalchemy/orm/relationships.py,sha256=XWAG1IsWxKnnIZ1riUonvVlxlegzU0ZJyqFrHMN7eyU,143246
+sqlalchemy/orm/scoping.py,sha256=K4sY8l969uQigmm9VV1GL4XmIA505r_x_1yeDZSRWMQ,7257
+sqlalchemy/orm/session.py,sha256=cql6nG4OabB-Q6d9I2X46oWS9hmMuucY_TWFnLazZxQ,160925
+sqlalchemy/orm/state.py,sha256=1cqhgv40Z7zLD2iC331H4VCMavIICCmxFD6jF7-UYSo,33536
+sqlalchemy/orm/strategies.py,sha256=PXl1qSS4ABH19vOz0wVmnlzqfP02XaggC2da1G09tOs,108207
+sqlalchemy/orm/strategy_options.py,sha256=0S100xkX9Lwor6TeRjZ1aNJgEexeCrlKTq6xME85RcQ,67319
+sqlalchemy/orm/sync.py,sha256=KRyKql_Pgjm_y8clsUOLe8jo5JzM1t6II2vCorbtRow,5824
+sqlalchemy/orm/unitofwork.py,sha256=XEMx8PhX-KdP9tQpVgB0mcqnPlVbpSPG4bSKW6zIMRE,27090
+sqlalchemy/orm/util.py,sha256=eR6Q9HW87ycIsAsC8Xs9Z_09VGUQqu3-kptJTXKJcyw,74969
+sqlalchemy/pool/__init__.py,sha256=dTuz0I0lQ1aj_BHoMzoBk4FW1rI-4ssLZfXi7826ja8,1603
+sqlalchemy/pool/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/pool/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/pool/__pycache__/dbapi_proxy.cpython-39.pyc,,
+sqlalchemy/pool/__pycache__/events.cpython-39.pyc,,
+sqlalchemy/pool/__pycache__/impl.cpython-39.pyc,,
+sqlalchemy/pool/base.py,sha256=qffJ_mAqPfxcERTWZkKc0sLHVl-BFsFEVY7R7BmKtpI,38552
+sqlalchemy/pool/dbapi_proxy.py,sha256=ZDa32bJzGunYw8OyS5g0GfLoRo-Qwrf7jcsGsA9StSg,4229
+sqlalchemy/pool/events.py,sha256=nVQfjW55gD6-DEtTIDUCx-cNHZCKtt7C3gsdqf-PFWg,10299
+sqlalchemy/pool/impl.py,sha256=m8kUBUGN3ZikSndBO8mcu2ym8kd_o8vEtLsDSycZXAI,15783
+sqlalchemy/processors.py,sha256=LWwr9g-qDHiike9UKqD1yX8ghCxjpAWRdQk7Mh5NepA,5745
+sqlalchemy/schema.py,sha256=FLG1OeHCucohyiShM_jvw4OJivdrWSAsI7MxPIX7Q1M,2413
+sqlalchemy/sql/__init__.py,sha256=ojeq7QnyQrUcO1Ia7nogzumgOfTKXk6Oib7HuH_hz6Y,4661
+sqlalchemy/sql/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/annotation.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/base.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/coercions.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/compiler.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/crud.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/ddl.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/default_comparator.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/dml.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/elements.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/events.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/expression.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/functions.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/lambdas.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/naming.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/operators.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/roles.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/schema.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/selectable.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/sqltypes.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/traversals.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/type_api.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/util.cpython-39.pyc,,
+sqlalchemy/sql/__pycache__/visitors.cpython-39.pyc,,
+sqlalchemy/sql/annotation.py,sha256=xGpbeieggvywgRlqerZxz6lYnuSob7C86rJQ87k6Va0,11502
+sqlalchemy/sql/base.py,sha256=grJ02HrUj2yoDqlrhbNR_J4RHSahsyFilmvVgnCKb2g,55897
+sqlalchemy/sql/coercions.py,sha256=r5bczqjtsm67jl6RiPxyY-ictLPqtPQO0OnhhSN2zCI,34530
+sqlalchemy/sql/compiler.py,sha256=ouvZ78SDPQ8M6dYXtwSWhRgNUdzUDQSG3eIJEO4TjvE,187816
+sqlalchemy/sql/crud.py,sha256=yMGTebDMvF2Hpdto3YSwK6GiRLPpSbRVcZby1zU3n4w,35967
+sqlalchemy/sql/ddl.py,sha256=OV8dpPN3tW0nepwxitfz05W804mGJX6I3HHNJsI0mDo,44208
+sqlalchemy/sql/default_comparator.py,sha256=GR_hgIHtrZWq6j6yTWpiOWTUjIts5gn-UBcE37JVvfk,11178
+sqlalchemy/sql/dml.py,sha256=xAI5vzJFY_Y8_AEhJCo1Cxj-2M9tZzljVcpQ7-iUnpM,54663
+sqlalchemy/sql/elements.py,sha256=Y3CsWkDSEUOWOtV596KK5VlICbPcvTdpbJSHNsmLBig,181521
+sqlalchemy/sql/events.py,sha256=tusqYeUwf421_pG-T39wIHz7RtzixftIPhJG6CP_6Io,13369
+sqlalchemy/sql/expression.py,sha256=cyzp-pgHBfrgQ6_mRxo4T4zNSKIIzd40PlRLgwXI5aM,8828
+sqlalchemy/sql/functions.py,sha256=QYHZ7IX23HqbjMyKddxtndgPsg9cinjKEwP3NQvt16I,48481
+sqlalchemy/sql/lambdas.py,sha256=Jh4K1h_Vqp9bKlVGYrIFGfbFZ6WjhitVPyMtpEpeLZw,44913
+sqlalchemy/sql/naming.py,sha256=bmjEtvUW0Ccrc5tzH0_PcoPeA5jAtDLPJ4QxtKaAwe8,6786
+sqlalchemy/sql/operators.py,sha256=GyqVaHQC41uVX3vVirO03CXL5ZMDlmUATzrC6oT8RxI,48199
+sqlalchemy/sql/roles.py,sha256=ZTgs4PY4rneDh2suTVbmn25yGZyW34sztNWX8cOUf3M,5638
+sqlalchemy/sql/schema.py,sha256=pbLkR844wkM0uzIXTAyauACab3vor1IhmUhBreoqG94,195347
+sqlalchemy/sql/selectable.py,sha256=-n-OqE8dXqD9wNezC2kfLe0U41StaISevxhBsZ9DcRA,237341
+sqlalchemy/sql/sqltypes.py,sha256=HfmfIUOTD3WxxkuvA3_Qoh94RNLLXTPlWJz_SdE0PlU,113509
+sqlalchemy/sql/traversals.py,sha256=P0GP8F8RlM-lpL5jm3gWj7-NnE8klIXEcDmHk5Dmc-c,52719
+sqlalchemy/sql/type_api.py,sha256=IHOZMFl05LgcJ8FfqGGr703bEQEc8c56ru9vJdX-PEU,71036
+sqlalchemy/sql/util.py,sha256=JI2eMLpaDzZQjG3Cd4AopUmIMfzQXFIQVUJj8TG8gWw,35856
+sqlalchemy/sql/visitors.py,sha256=XLRAf08NKf5ndsNDIRY3wPJaaEBIIxl3DDI_dTKrh_s,27329
+sqlalchemy/testing/__init__.py,sha256=TKwXQsqFFV4gjeO48VGaLhCE99qhIVSQNxFrKdP6uNk,2850
+sqlalchemy/testing/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/assertions.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/assertsql.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/asyncio.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/config.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/engines.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/entities.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/exclusions.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/fixtures.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/mock.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/pickleable.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/profiling.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/provision.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/requirements.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/schema.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/util.cpython-39.pyc,,
+sqlalchemy/testing/__pycache__/warnings.cpython-39.pyc,,
+sqlalchemy/testing/assertions.py,sha256=fcCcIUk04m2XgpotqK2mRD5nKXsyOHXV8tchAAnfQyk,26502
+sqlalchemy/testing/assertsql.py,sha256=OIt0QyHKlFJ4zxu6WrX8_ufmBD9KrMgFrjsXTGkU3ys,14964
+sqlalchemy/testing/asyncio.py,sha256=B6ZqYcQpT6QtM8gR3o3AcZX32J6ZbWDqTTZGklVo5-I,3671
+sqlalchemy/testing/config.py,sha256=XhmzFNkEN_djORr4r6owvoIl3G5zA6Eo5neUiEJXy0E,6543
+sqlalchemy/testing/engines.py,sha256=s4h7bKB2Bqmu1rlquR2O88UktP03n6UVrrWkTNhqm3w,13392
+sqlalchemy/testing/entities.py,sha256=sOd9BlmZFPQFrBdCUlkOR8lxGEQNExkJmS_V2U5WIOk,3253
+sqlalchemy/testing/exclusions.py,sha256=zOthfVJs07z9wN2iAH0rGT39Q76Y_2cBuk5dPEW4wOA,13329
+sqlalchemy/testing/fixtures.py,sha256=xdWPaYEK6w1SZmVYdgyPorx2Nkmvfcykib0T0jw-P_Q,26829
+sqlalchemy/testing/mock.py,sha256=RUTHkpnxCQfsDlEZ_aQttL_3SXLATwxt4olgmSxAsJw,894
+sqlalchemy/testing/pickleable.py,sha256=QlwC2Cr7vKkHlj86t2Wlq9eGteZFXkvPpGlWAl9_g7Y,2886
+sqlalchemy/testing/plugin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sqlalchemy/testing/plugin/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-39.pyc,,
+sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-39.pyc,,
+sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-39.pyc,,
+sqlalchemy/testing/plugin/__pycache__/reinvent_fixtures_py2k.cpython-39.pyc,,
+sqlalchemy/testing/plugin/bootstrap.py,sha256=038KOv89msOTFsWoDvCyPRb3ZTMv5eAOOKoGPHuZ7zs,1701
+sqlalchemy/testing/plugin/plugin_base.py,sha256=9Bg56KOsZSGW1jLHh_7fle85yFocyV8AGGVlswO9XAU,21540
+sqlalchemy/testing/plugin/pytestplugin.py,sha256=_NbB52E6sv6R9NJApMxMnwomH8y7iirfCYKnXvUH1g0,26133
+sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py,sha256=MdakbJzFh8N_7gUpX-nFbGPFs3AZRsmDAe-7zucf0ls,3288
+sqlalchemy/testing/profiling.py,sha256=ullStV2c-R4jTQJMK1tMKZE5qtSZ-PB1LzHod_hA230,10566
+sqlalchemy/testing/provision.py,sha256=IPpsZg4Pc42mXGScKdLri0SjeWJrURXbBF1S9m6ftY8,12070
+sqlalchemy/testing/requirements.py,sha256=G-l-20BjZ6eMA7TIy3FO4Ck_T6acLz9XwBheQI4Dql0,43499
+sqlalchemy/testing/schema.py,sha256=INOq15yhNyANmheylSQBUlm0IWRaAkEX22BpHSMqn08,6544
+sqlalchemy/testing/suite/__init__.py,sha256=_firVc2uS3TMZ3vH2baQzNb17ubM78RHtb9kniSybmk,476
+sqlalchemy/testing/suite/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_cte.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_insert.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_results.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_select.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_types.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-39.pyc,,
+sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-39.pyc,,
+sqlalchemy/testing/suite/test_cte.py,sha256=XuTuaWblSXyO1OOUTShBBmNch7fBdGnlMD84ooVTqFY,6183
+sqlalchemy/testing/suite/test_ddl.py,sha256=UwbfljXHdWUen3muIcgnOPi-A4AO6F1QzSOiHf9lU-A,11762
+sqlalchemy/testing/suite/test_deprecations.py,sha256=8oLDFUswey8KjPFKRUsqMyGT5sUMMoPQr7-XyIBMehw,5059
+sqlalchemy/testing/suite/test_dialect.py,sha256=eR1VVOb2fm955zavpWkmMjipCva3QvEE177U0OG-0LY,10895
+sqlalchemy/testing/suite/test_insert.py,sha256=oKtVjFuxqdSV5uKj5-OxdSABupLp0pECkWkSLd2U_QA,11134
+sqlalchemy/testing/suite/test_reflection.py,sha256=p-m2BjuWh7jW2vXvY_LxYsfjW47HqGs9O9PUpfm1HIs,58130
+sqlalchemy/testing/suite/test_results.py,sha256=xcoSl1ueaHo8LgKZp0Z1lJ44Mhjf2hxlWs_LjNLBNiE,13983
+sqlalchemy/testing/suite/test_rowcount.py,sha256=GQQRXIWbb6SfD5hwtBC8qvkGAgi1rI5Pv3c59eoumck,4877
+sqlalchemy/testing/suite/test_select.py,sha256=is3BbULeOWOJTRCoUwPnh6Crue15FXfkXKqAkxrFeGM,55464
+sqlalchemy/testing/suite/test_sequence.py,sha256=eCyOQlynF8T0cLrIMz0PO6WuW8ktpFVYq_fQp5CQ298,8431
+sqlalchemy/testing/suite/test_types.py,sha256=airX8OuJJdft4DU8okOLecJbcUhC15urr60Yu1U8Qe4,48044
+sqlalchemy/testing/suite/test_unicode_ddl.py,sha256=CndeAtV3DWJXxLbOoumqf4_mOOYcW_yNOrbKQ4cwFhw,6737
+sqlalchemy/testing/suite/test_update_delete.py,sha256=w9MMRqJCm7OW0Q5XaVjS6B8BGY_b_VvBeK3EWr7NKhU,1625
+sqlalchemy/testing/util.py,sha256=bvCWcESEPEO8QUTH0CcOa4Xg65EYK--V8Q_XeFcfGfE,12503
+sqlalchemy/testing/warnings.py,sha256=l9lI3heNOSbKreAhLcABpaA1e_6Ioi4l7q0mr5jY5OI,2270
+sqlalchemy/types.py,sha256=x8YDIEypMHOzWb7dzp67tW2WfDF7xtdh72HVDxm-aaY,2995
+sqlalchemy/util/__init__.py,sha256=75NADEtwE5GMCS27VcsEnTsTq1nSvXmJ2GY2aU3Q8hI,6373
+sqlalchemy/util/__pycache__/__init__.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/_collections.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/_compat_py3k.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/_concurrency_py3k.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/_preloaded.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/compat.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/concurrency.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/deprecations.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/langhelpers.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/queue.cpython-39.pyc,,
+sqlalchemy/util/__pycache__/topological.cpython-39.pyc,,
+sqlalchemy/util/_collections.py,sha256=Nulmym_NZYGN4OyE9cMtIVSoTwOzk3eJpSJ20l8j-lU,29139
+sqlalchemy/util/_compat_py3k.py,sha256=KibHVHAIlQfYdl8xs3ZhJQDlWEI6EhudTbOnMc2x9e4,2195
+sqlalchemy/util/_concurrency_py3k.py,sha256=5fTahmOgokaam-u-z7Xv0DYKR7YnK4TNjQqbVRYhoKQ,6598
+sqlalchemy/util/_preloaded.py,sha256=rx7QZ4T1zDZV5lktSvQlop3O0kdbCFVMmNDp5IOhpXQ,2396
+sqlalchemy/util/compat.py,sha256=cRcYIpcBc6aV_yboUTsKpmX1ssICP7kloCJRqEMsRBs,18281
+sqlalchemy/util/concurrency.py,sha256=LtozDo0PsiToyVmKzSDnu8qOMhRyGVjTNMsBiKro9d8,2278
+sqlalchemy/util/deprecations.py,sha256=RXg5M_MQhaopn00uTB0WEcz5yTTmPu2OCFPNklw5Uv4,11774
+sqlalchemy/util/langhelpers.py,sha256=RIlviqqBbBy1XhMnOwQHtmtAofNtMF79aCu3wa9Iycc,56288
+sqlalchemy/util/queue.py,sha256=FW6DSeO_GadaW0UA2EXjrBtFPRHO-dNGEoRwqHTfkMA,9293
+sqlalchemy/util/topological.py,sha256=MV1lkI2E0JdVIJVplggVo6iO_ZEVlUHRGvMW9AsXJRA,2859
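Each line in the RECORD file above follows the standard wheel convention "path,sha256=<digest>,size", where the digest is the unpadded URL-safe base64 of the file's SHA-256. As a purely illustrative sketch (not part of this commit), such an entry can be recomputed like this, assuming the path is readable relative to the install root:

    import base64, hashlib

    def record_entry(path):
        # Rebuild one RECORD line: "path,sha256=<urlsafe-b64 digest, no padding>,size"
        data = open(path, "rb").read()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
        return "%s,sha256=%s,%d" % (path, digest.decode("ascii"), len(data))

    # e.g. print(record_entry("sqlalchemy/util/topological.py"))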
diff --git a/lib/SQLAlchemy-1.4.40.dist-info/REQUESTED b/lib/SQLAlchemy-1.4.40.dist-info/REQUESTED
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/SQLAlchemy-1.4.40.dist-info/REQUESTED
diff --git a/lib/SQLAlchemy-1.4.40.dist-info/WHEEL b/lib/SQLAlchemy-1.4.40.dist-info/WHEEL
new file mode 100644
index 0000000..ba97021
--- /dev/null
+++ b/lib/SQLAlchemy-1.4.40.dist-info/WHEEL
@@ -0,0 +1,8 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.37.1)
+Root-Is-Purelib: false
+Tag: cp39-cp39-manylinux_2_5_x86_64
+Tag: cp39-cp39-manylinux1_x86_64
+Tag: cp39-cp39-manylinux_2_17_x86_64
+Tag: cp39-cp39-manylinux2014_x86_64
+
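The Tag lines above record which binary platforms this cp39 build claims compatibility with. A hedged sketch of checking one of those tags against the running interpreter, assuming the third-party packaging module is available (it is not part of this tree):

    from packaging.tags import parse_tag, sys_tags  # assumed to be installed separately

    supported = set(sys_tags())                      # tags the current interpreter accepts
    wanted = parse_tag("cp39-cp39-manylinux2014_x86_64")
    print(any(tag in supported for tag in wanted))   # True on a matching cp39 Linux host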
diff --git a/lib/SQLAlchemy-1.4.40.dist-info/top_level.txt b/lib/SQLAlchemy-1.4.40.dist-info/top_level.txt
new file mode 100644
index 0000000..39fb2be
--- /dev/null
+++ b/lib/SQLAlchemy-1.4.40.dist-info/top_level.txt
@@ -0,0 +1 @@
+sqlalchemy
diff --git a/lib/_cffi_backend.cpython-39-x86_64-linux-gnu.so b/lib/_cffi_backend.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..5f69b7e
--- /dev/null
+++ b/lib/_cffi_backend.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/_dbus_bindings.cpython-310-x86_64-linux-gnu.so b/lib/_dbus_bindings.cpython-310-x86_64-linux-gnu.so
new file mode 100755
index 0000000..e6b9dc2
--- /dev/null
+++ b/lib/_dbus_bindings.cpython-310-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/_dbus_glib_bindings.cpython-310-x86_64-linux-gnu.so b/lib/_dbus_glib_bindings.cpython-310-x86_64-linux-gnu.so
new file mode 100755
index 0000000..b204c39
--- /dev/null
+++ b/lib/_dbus_glib_bindings.cpython-310-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/_snack.so b/lib/_snack.so
new file mode 100755
index 0000000..7d3a27e
--- /dev/null
+++ b/lib/_snack.so
Binary files differ
diff --git a/lib/bin/chardetect b/lib/bin/chardetect
new file mode 100755
index 0000000..0c5ee27
--- /dev/null
+++ b/lib/bin/chardetect
@@ -0,0 +1,8 @@
+#!/opt/sunpy3/bin/python3.9
+# -*- coding: utf-8 -*-
+import re
+import sys
+from chardet.cli.chardetect import main
+if __name__ == '__main__':
+    sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
+    sys.exit(main())
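The script above is the usual pip-generated console wrapper: it normalizes argv[0] and delegates to chardet's CLI entry point. The same detection is available programmatically through chardet's documented detect() function; a minimal sketch (the input file name is hypothetical):

    import chardet

    with open("unknown-encoding.txt", "rb") as fh:   # hypothetical input file
        result = chardet.detect(fh.read())           # {'encoding': ..., 'confidence': ..., ...}
    print(result["encoding"], result["confidence"])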
diff --git a/lib/cffi-1.15.1.dist-info/INSTALLER b/lib/cffi-1.15.1.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/cffi-1.15.1.dist-info/LICENSE b/lib/cffi-1.15.1.dist-info/LICENSE
new file mode 100644
index 0000000..29225ee
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/LICENSE
@@ -0,0 +1,26 @@
+
+Except when otherwise stated (look for LICENSE files in directories or
+information at the beginning of each file) all software and
+documentation is licensed as follows:
+
+ The MIT License
+
+ Permission is hereby granted, free of charge, to any person
+ obtaining a copy of this software and associated documentation
+ files (the "Software"), to deal in the Software without
+ restriction, including without limitation the rights to use,
+ copy, modify, merge, publish, distribute, sublicense, and/or
+ sell copies of the Software, and to permit persons to whom the
+ Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included
+ in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ DEALINGS IN THE SOFTWARE.
+
diff --git a/lib/cffi-1.15.1.dist-info/METADATA b/lib/cffi-1.15.1.dist-info/METADATA
new file mode 100644
index 0000000..538e679
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/METADATA
@@ -0,0 +1,34 @@
+Metadata-Version: 2.1
+Name: cffi
+Version: 1.15.1
+Summary: Foreign Function Interface for Python calling C code.
+Home-page: http://cffi.readthedocs.org
+Author: Armin Rigo, Maciej Fijalkowski
+Author-email: python-cffi@googlegroups.com
+License: MIT
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: License :: OSI Approved :: MIT License
+License-File: LICENSE
+Requires-Dist: pycparser
+
+
+CFFI
+====
+
+Foreign Function Interface for Python calling C code.
+Please see the `Documentation <http://cffi.readthedocs.org/>`_.
+
+Contact
+-------
+
+`Mailing list <https://groups.google.com/forum/#!forum/python-cffi>`_
diff --git a/lib/cffi-1.15.1.dist-info/RECORD b/lib/cffi-1.15.1.dist-info/RECORD
new file mode 100644
index 0000000..95f86e9
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/RECORD
@@ -0,0 +1,45 @@
+_cffi_backend.cpython-39-x86_64-linux-gnu.so,sha256=qsQ76hjYdszjTdFV9NI0mnEkr6uMz56ZMzB7JbnKVTg,981144
+cffi-1.15.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+cffi-1.15.1.dist-info/LICENSE,sha256=BLgPWwd7vtaICM_rreteNSPyqMmpZJXFh72W3x6sKjM,1294
+cffi-1.15.1.dist-info/METADATA,sha256=KP4G3WmavRgDGwD2b8Y_eDsM1YeV6ckcG6Alz3-D8VY,1144
+cffi-1.15.1.dist-info/RECORD,,
+cffi-1.15.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cffi-1.15.1.dist-info/WHEEL,sha256=FNUt4eBsrBVn_Yc5KG3aXtKE40X3uNZrbGLNCbVxyFw,148
+cffi-1.15.1.dist-info/entry_points.txt,sha256=y6jTxnyeuLnL-XJcDv8uML3n6wyYiGRg8MTp_QGJ9Ho,75
+cffi-1.15.1.dist-info/top_level.txt,sha256=rE7WR3rZfNKxWI9-jn6hsHCAl7MDkB-FmuQbxWjFehQ,19
+cffi/__init__.py,sha256=6xB_tafGvhhM5Xvj0Ova3oPC2SEhVlLTEObVLnazeiM,513
+cffi/__pycache__/__init__.cpython-39.pyc,,
+cffi/__pycache__/api.cpython-39.pyc,,
+cffi/__pycache__/backend_ctypes.cpython-39.pyc,,
+cffi/__pycache__/cffi_opcode.cpython-39.pyc,,
+cffi/__pycache__/commontypes.cpython-39.pyc,,
+cffi/__pycache__/cparser.cpython-39.pyc,,
+cffi/__pycache__/error.cpython-39.pyc,,
+cffi/__pycache__/ffiplatform.cpython-39.pyc,,
+cffi/__pycache__/lock.cpython-39.pyc,,
+cffi/__pycache__/model.cpython-39.pyc,,
+cffi/__pycache__/pkgconfig.cpython-39.pyc,,
+cffi/__pycache__/recompiler.cpython-39.pyc,,
+cffi/__pycache__/setuptools_ext.cpython-39.pyc,,
+cffi/__pycache__/vengine_cpy.cpython-39.pyc,,
+cffi/__pycache__/vengine_gen.cpython-39.pyc,,
+cffi/__pycache__/verifier.cpython-39.pyc,,
+cffi/_cffi_errors.h,sha256=zQXt7uR_m8gUW-fI2hJg0KoSkJFwXv8RGUkEDZ177dQ,3908
+cffi/_cffi_include.h,sha256=tKnA1rdSoPHp23FnDL1mDGwFo-Uj6fXfA6vA6kcoEUc,14800
+cffi/_embedding.h,sha256=9tnjF44QRobR8z0FGqAmAZY-wMSBOae1SUPqHccowqc,17680
+cffi/api.py,sha256=yxJalIePbr1mz_WxAHokSwyP5CVYde44m-nolHnbJNo,42064
+cffi/backend_ctypes.py,sha256=h5ZIzLc6BFVXnGyc9xPqZWUS7qGy7yFSDqXe68Sa8z4,42454
+cffi/cffi_opcode.py,sha256=v9RdD_ovA8rCtqsC95Ivki5V667rAOhGgs3fb2q9xpM,5724
+cffi/commontypes.py,sha256=QS4uxCDI7JhtTyjh1hlnCA-gynmaszWxJaRRLGkJa1A,2689
+cffi/cparser.py,sha256=rO_1pELRw1gI1DE1m4gi2ik5JMfpxouAACLXpRPlVEA,44231
+cffi/error.py,sha256=v6xTiS4U0kvDcy4h_BDRo5v39ZQuj-IMRYLv5ETddZs,877
+cffi/ffiplatform.py,sha256=HMXqR8ks2wtdsNxGaWpQ_PyqIvtiuos_vf1qKCy-cwg,4046
+cffi/lock.py,sha256=l9TTdwMIMpi6jDkJGnQgE9cvTIR7CAntIJr8EGHt3pY,747
+cffi/model.py,sha256=_GH_UF1Rn9vC4AvmgJm6qj7RUXXG3eqKPc8bPxxyBKE,21768
+cffi/parse_c_type.h,sha256=OdwQfwM9ktq6vlCB43exFQmxDBtj2MBNdK8LYl15tjw,5976
+cffi/pkgconfig.py,sha256=LP1w7vmWvmKwyqLaU1Z243FOWGNQMrgMUZrvgFuOlco,4374
+cffi/recompiler.py,sha256=YgVYTh2CrXIobo-vMk7_K9mwAXdd_LqB4-IbYABQ488,64598
+cffi/setuptools_ext.py,sha256=RUR17N5f8gpiQBBlXL34P9FtOu1mhHIaAf3WJlg5S4I,8931
+cffi/vengine_cpy.py,sha256=YglN8YS-UaHEv2k2cxgotNWE87dHX20-68EyKoiKUYA,43320
+cffi/vengine_gen.py,sha256=5dX7s1DU6pTBOMI6oTVn_8Bnmru_lj932B6b4v29Hlg,26684
+cffi/verifier.py,sha256=ESwuXWXtXrKEagCKveLRDjFzLNCyaKdqAgAlKREcyhY,11253
diff --git a/lib/cffi-1.15.1.dist-info/REQUESTED b/lib/cffi-1.15.1.dist-info/REQUESTED
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/REQUESTED
diff --git a/lib/cffi-1.15.1.dist-info/WHEEL b/lib/cffi-1.15.1.dist-info/WHEEL
new file mode 100644
index 0000000..271bfec
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/WHEEL
@@ -0,0 +1,6 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.37.1)
+Root-Is-Purelib: false
+Tag: cp39-cp39-manylinux_2_17_x86_64
+Tag: cp39-cp39-manylinux2014_x86_64
+
diff --git a/lib/cffi-1.15.1.dist-info/entry_points.txt b/lib/cffi-1.15.1.dist-info/entry_points.txt
new file mode 100644
index 0000000..4b0274f
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/entry_points.txt
@@ -0,0 +1,2 @@
+[distutils.setup_keywords]
+cffi_modules = cffi.setuptools_ext:cffi_modules
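This entry point registers the cffi_modules keyword with setuptools, which is how cffi builds out-of-line API-mode extensions at install time. A minimal, hedged sketch of a build script wired up through that keyword (the file name build_example.py and module name _example are illustrative, not taken from this tree):

    # build_example.py -- illustrative only
    from cffi import FFI

    ffibuilder = FFI()
    ffibuilder.cdef("double cos(double x);")          # C declarations we want to call
    ffibuilder.set_source("_example", "#include <math.h>", libraries=["m"])

    if __name__ == "__main__":
        ffibuilder.compile(verbose=True)

    # setup.py would then reference it as:
    #   setup(..., setup_requires=["cffi"], cffi_modules=["build_example.py:ffibuilder"])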
diff --git a/lib/cffi-1.15.1.dist-info/top_level.txt b/lib/cffi-1.15.1.dist-info/top_level.txt
new file mode 100644
index 0000000..f645779
--- /dev/null
+++ b/lib/cffi-1.15.1.dist-info/top_level.txt
@@ -0,0 +1,2 @@
+_cffi_backend
+cffi
diff --git a/lib/cffi/__init__.py b/lib/cffi/__init__.py
new file mode 100644
index 0000000..90e2e65
--- /dev/null
+++ b/lib/cffi/__init__.py
@@ -0,0 +1,14 @@
+__all__ = ['FFI', 'VerificationError', 'VerificationMissing', 'CDefError',
+           'FFIError']
+
+from .api import FFI
+from .error import CDefError, FFIError, VerificationError, VerificationMissing
+from .error import PkgConfigError
+
+__version__ = "1.15.1"
+__version_info__ = (1, 15, 1)
+
+# The verifier module file names are based on the CRC32 of a string that
+# contains the following version number. It may be older than __version__
+# if nothing is clearly incompatible.
+__version_verifier_modules__ = "0.8.6"
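The package's public surface is essentially the FFI class re-exported above. For orientation only (not part of the commit), ABI-level use against the already loaded C library looks like this on a POSIX host:

    from cffi import FFI

    ffi = FFI()
    ffi.cdef("size_t strlen(const char *s);")   # declare the C function we intend to call
    C = ffi.dlopen(None)                        # None -> the process's standard C library (POSIX)
    print(C.strlen(b"sunhpc"))                  # 6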
diff --git a/lib/cffi/_cffi_errors.h b/lib/cffi/_cffi_errors.h
new file mode 100644
index 0000000..158e059
--- /dev/null
+++ b/lib/cffi/_cffi_errors.h
@@ -0,0 +1,149 @@
+#ifndef CFFI_MESSAGEBOX
+# ifdef _MSC_VER
+# define CFFI_MESSAGEBOX 1
+# else
+# define CFFI_MESSAGEBOX 0
+# endif
+#endif
+
+
+#if CFFI_MESSAGEBOX
+/* Windows only: logic to take the Python-CFFI embedding logic
+ initialization errors and display them in a background thread
+ with MessageBox. The idea is that if the whole program closes
+ as a result of this problem, then likely it is already a console
+ program and you can read the stderr output in the console too.
+ If it is not a console program, then it will likely show its own
+ dialog to complain, or generally not abruptly close, and for this
+ case the background thread should stay alive.
+*/
+static void *volatile _cffi_bootstrap_text;
+
+static PyObject *_cffi_start_error_capture(void)
+{
+ PyObject *result = NULL;
+ PyObject *x, *m, *bi;
+
+ if (InterlockedCompareExchangePointer(&_cffi_bootstrap_text,
+ (void *)1, NULL) != NULL)
+ return (PyObject *)1;
+
+ m = PyImport_AddModule("_cffi_error_capture");
+ if (m == NULL)
+ goto error;
+
+ result = PyModule_GetDict(m);
+ if (result == NULL)
+ goto error;
+
+#if PY_MAJOR_VERSION >= 3
+ bi = PyImport_ImportModule("builtins");
+#else
+ bi = PyImport_ImportModule("__builtin__");
+#endif
+ if (bi == NULL)
+ goto error;
+ PyDict_SetItemString(result, "__builtins__", bi);
+ Py_DECREF(bi);
+
+ x = PyRun_String(
+ "import sys\n"
+ "class FileLike:\n"
+ " def write(self, x):\n"
+ " try:\n"
+ " of.write(x)\n"
+ " except: pass\n"
+ " self.buf += x\n"
+ " def flush(self):\n"
+ " pass\n"
+ "fl = FileLike()\n"
+ "fl.buf = ''\n"
+ "of = sys.stderr\n"
+ "sys.stderr = fl\n"
+ "def done():\n"
+ " sys.stderr = of\n"
+ " return fl.buf\n", /* make sure the returned value stays alive */
+ Py_file_input,
+ result, result);
+ Py_XDECREF(x);
+
+ error:
+ if (PyErr_Occurred())
+ {
+ PyErr_WriteUnraisable(Py_None);
+ PyErr_Clear();
+ }
+ return result;
+}
+
+#pragma comment(lib, "user32.lib")
+
+static DWORD WINAPI _cffi_bootstrap_dialog(LPVOID ignored)
+{
+ Sleep(666); /* may be interrupted if the whole process is closing */
+#if PY_MAJOR_VERSION >= 3
+ MessageBoxW(NULL, (wchar_t *)_cffi_bootstrap_text,
+ L"Python-CFFI error",
+ MB_OK | MB_ICONERROR);
+#else
+ MessageBoxA(NULL, (char *)_cffi_bootstrap_text,
+ "Python-CFFI error",
+ MB_OK | MB_ICONERROR);
+#endif
+ _cffi_bootstrap_text = NULL;
+ return 0;
+}
+
+static void _cffi_stop_error_capture(PyObject *ecap)
+{
+ PyObject *s;
+ void *text;
+
+ if (ecap == (PyObject *)1)
+ return;
+
+ if (ecap == NULL)
+ goto error;
+
+ s = PyRun_String("done()", Py_eval_input, ecap, ecap);
+ if (s == NULL)
+ goto error;
+
+ /* Show a dialog box, but in a background thread, and
+ never show multiple dialog boxes at once. */
+#if PY_MAJOR_VERSION >= 3
+ text = PyUnicode_AsWideCharString(s, NULL);
+#else
+ text = PyString_AsString(s);
+#endif
+
+ _cffi_bootstrap_text = text;
+
+ if (text != NULL)
+ {
+ HANDLE h;
+ h = CreateThread(NULL, 0, _cffi_bootstrap_dialog,
+ NULL, 0, NULL);
+ if (h != NULL)
+ CloseHandle(h);
+ }
+ /* decref the string, but it should stay alive as 'fl.buf'
+ in the small module above. It will really be freed only if
+ we later get another similar error. So it's a leak of at
+ most one copy of the small module. That's fine for this
+ situation which is usually a "fatal error" anyway. */
+ Py_DECREF(s);
+ PyErr_Clear();
+ return;
+
+ error:
+ _cffi_bootstrap_text = NULL;
+ PyErr_Clear();
+}
+
+#else
+
+static PyObject *_cffi_start_error_capture(void) { return NULL; }
+static void _cffi_stop_error_capture(PyObject *ecap) { }
+
+#endif
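The PyRun_String block above installs a small stderr tee so that embedding failures can later be shown by the background MessageBox thread. In plain Python, the captured snippet (with its collapsed indentation restored) amounts to roughly the following:

    import sys

    class FileLike:
        def write(self, x):
            try:
                of.write(x)
            except:          # bare except, as in the embedded C string
                pass
            self.buf += x
        def flush(self):
            pass

    fl = FileLike()
    fl.buf = ''
    of = sys.stderr
    sys.stderr = fl          # everything written to stderr is now buffered in fl.buf

    def done():
        sys.stderr = of      # restore stderr and hand back the captured text
        return fl.buf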
diff --git a/lib/cffi/_cffi_include.h b/lib/cffi/_cffi_include.h
new file mode 100644
index 0000000..e4c0a67
--- /dev/null
+++ b/lib/cffi/_cffi_include.h
@@ -0,0 +1,385 @@
+#define _CFFI_
+
+/* We try to define Py_LIMITED_API before including Python.h.
+
+ Mess: we can only define it if Py_DEBUG, Py_TRACE_REFS and
+ Py_REF_DEBUG are not defined. This is a best-effort approximation:
+ we can learn about Py_DEBUG from pyconfig.h, but it is unclear if
+ the same works for the other two macros. Py_DEBUG implies them,
+ but not the other way around.
+
+ The implementation is messy (issue #350): on Windows, with _MSC_VER,
+ we have to define Py_LIMITED_API even before including pyconfig.h.
+ In that case, we guess what pyconfig.h will do to the macros above,
+ and check our guess after the #include.
+
+ Note that on Windows, with CPython 3.x, you need >= 3.5 and virtualenv
+ version >= 16.0.0. With older versions of either, you don't get a
+ copy of PYTHON3.DLL in the virtualenv. We can't check the version of
+ CPython *before* we even include pyconfig.h. ffi.set_source() puts
+ a ``#define _CFFI_NO_LIMITED_API'' at the start of this file if it is
+ running on Windows < 3.5, as an attempt at fixing it, but that's
+ arguably wrong because it may not be the target version of Python.
+ Still better than nothing I guess. As another workaround, you can
+ remove the definition of Py_LIMITED_API here.
+
+ See also 'py_limited_api' in cffi/setuptools_ext.py.
+*/
+#if !defined(_CFFI_USE_EMBEDDING) && !defined(Py_LIMITED_API)
+# ifdef _MSC_VER
+# if !defined(_DEBUG) && !defined(Py_DEBUG) && !defined(Py_TRACE_REFS) && !defined(Py_REF_DEBUG) && !defined(_CFFI_NO_LIMITED_API)
+# define Py_LIMITED_API
+# endif
+# include <pyconfig.h>
+ /* sanity-check: Py_LIMITED_API will cause crashes if any of these
+ are also defined. Normally, the Python file PC/pyconfig.h does not
+ cause any of these to be defined, with the exception that _DEBUG
+ causes Py_DEBUG. Double-check that. */
+# ifdef Py_LIMITED_API
+# if defined(Py_DEBUG)
+# error "pyconfig.h unexpectedly defines Py_DEBUG, but Py_LIMITED_API is set"
+# endif
+# if defined(Py_TRACE_REFS)
+# error "pyconfig.h unexpectedly defines Py_TRACE_REFS, but Py_LIMITED_API is set"
+# endif
+# if defined(Py_REF_DEBUG)
+# error "pyconfig.h unexpectedly defines Py_REF_DEBUG, but Py_LIMITED_API is set"
+# endif
+# endif
+# else
+# include <pyconfig.h>
+# if !defined(Py_DEBUG) && !defined(Py_TRACE_REFS) && !defined(Py_REF_DEBUG) && !defined(_CFFI_NO_LIMITED_API)
+# define Py_LIMITED_API
+# endif
+# endif
+#endif
+
+#include <Python.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+#include <stddef.h>
+#include "parse_c_type.h"
+
+/* this block of #ifs should be kept exactly identical between
+ c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py
+ and cffi/_cffi_include.h */
+#if defined(_MSC_VER)
+# include <malloc.h> /* for alloca() */
+# if _MSC_VER < 1600 /* MSVC < 2010 */
+ typedef __int8 int8_t;
+ typedef __int16 int16_t;
+ typedef __int32 int32_t;
+ typedef __int64 int64_t;
+ typedef unsigned __int8 uint8_t;
+ typedef unsigned __int16 uint16_t;
+ typedef unsigned __int32 uint32_t;
+ typedef unsigned __int64 uint64_t;
+ typedef __int8 int_least8_t;
+ typedef __int16 int_least16_t;
+ typedef __int32 int_least32_t;
+ typedef __int64 int_least64_t;
+ typedef unsigned __int8 uint_least8_t;
+ typedef unsigned __int16 uint_least16_t;
+ typedef unsigned __int32 uint_least32_t;
+ typedef unsigned __int64 uint_least64_t;
+ typedef __int8 int_fast8_t;
+ typedef __int16 int_fast16_t;
+ typedef __int32 int_fast32_t;
+ typedef __int64 int_fast64_t;
+ typedef unsigned __int8 uint_fast8_t;
+ typedef unsigned __int16 uint_fast16_t;
+ typedef unsigned __int32 uint_fast32_t;
+ typedef unsigned __int64 uint_fast64_t;
+ typedef __int64 intmax_t;
+ typedef unsigned __int64 uintmax_t;
+# else
+# include <stdint.h>
+# endif
+# if _MSC_VER < 1800 /* MSVC < 2013 */
+# ifndef __cplusplus
+ typedef unsigned char _Bool;
+# endif
+# endif
+#else
+# include <stdint.h>
+# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux)
+# include <alloca.h>
+# endif
+#endif
+
+#ifdef __GNUC__
+# define _CFFI_UNUSED_FN __attribute__((unused))
+#else
+# define _CFFI_UNUSED_FN /* nothing */
+#endif
+
+#ifdef __cplusplus
+# ifndef _Bool
+ typedef bool _Bool; /* semi-hackish: C++ has no _Bool; bool is builtin */
+# endif
+#endif
+
+/********** CPython-specific section **********/
+#ifndef PYPY_VERSION
+
+
+#if PY_MAJOR_VERSION >= 3
+# define PyInt_FromLong PyLong_FromLong
+#endif
+
+#define _cffi_from_c_double PyFloat_FromDouble
+#define _cffi_from_c_float PyFloat_FromDouble
+#define _cffi_from_c_long PyInt_FromLong
+#define _cffi_from_c_ulong PyLong_FromUnsignedLong
+#define _cffi_from_c_longlong PyLong_FromLongLong
+#define _cffi_from_c_ulonglong PyLong_FromUnsignedLongLong
+#define _cffi_from_c__Bool PyBool_FromLong
+
+#define _cffi_to_c_double PyFloat_AsDouble
+#define _cffi_to_c_float PyFloat_AsDouble
+
+#define _cffi_from_c_int(x, type) \
+ (((type)-1) > 0 ? /* unsigned */ \
+ (sizeof(type) < sizeof(long) ? \
+ PyInt_FromLong((long)x) : \
+ sizeof(type) == sizeof(long) ? \
+ PyLong_FromUnsignedLong((unsigned long)x) : \
+ PyLong_FromUnsignedLongLong((unsigned long long)x)) : \
+ (sizeof(type) <= sizeof(long) ? \
+ PyInt_FromLong((long)x) : \
+ PyLong_FromLongLong((long long)x)))
+
+#define _cffi_to_c_int(o, type) \
+ ((type)( \
+ sizeof(type) == 1 ? (((type)-1) > 0 ? (type)_cffi_to_c_u8(o) \
+ : (type)_cffi_to_c_i8(o)) : \
+ sizeof(type) == 2 ? (((type)-1) > 0 ? (type)_cffi_to_c_u16(o) \
+ : (type)_cffi_to_c_i16(o)) : \
+ sizeof(type) == 4 ? (((type)-1) > 0 ? (type)_cffi_to_c_u32(o) \
+ : (type)_cffi_to_c_i32(o)) : \
+ sizeof(type) == 8 ? (((type)-1) > 0 ? (type)_cffi_to_c_u64(o) \
+ : (type)_cffi_to_c_i64(o)) : \
+ (Py_FatalError("unsupported size for type " #type), (type)0)))
+
+#define _cffi_to_c_i8 \
+ ((int(*)(PyObject *))_cffi_exports[1])
+#define _cffi_to_c_u8 \
+ ((int(*)(PyObject *))_cffi_exports[2])
+#define _cffi_to_c_i16 \
+ ((int(*)(PyObject *))_cffi_exports[3])
+#define _cffi_to_c_u16 \
+ ((int(*)(PyObject *))_cffi_exports[4])
+#define _cffi_to_c_i32 \
+ ((int(*)(PyObject *))_cffi_exports[5])
+#define _cffi_to_c_u32 \
+ ((unsigned int(*)(PyObject *))_cffi_exports[6])
+#define _cffi_to_c_i64 \
+ ((long long(*)(PyObject *))_cffi_exports[7])
+#define _cffi_to_c_u64 \
+ ((unsigned long long(*)(PyObject *))_cffi_exports[8])
+#define _cffi_to_c_char \
+ ((int(*)(PyObject *))_cffi_exports[9])
+#define _cffi_from_c_pointer \
+ ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[10])
+#define _cffi_to_c_pointer \
+ ((char *(*)(PyObject *, struct _cffi_ctypedescr *))_cffi_exports[11])
+#define _cffi_get_struct_layout \
+ not used any more
+#define _cffi_restore_errno \
+ ((void(*)(void))_cffi_exports[13])
+#define _cffi_save_errno \
+ ((void(*)(void))_cffi_exports[14])
+#define _cffi_from_c_char \
+ ((PyObject *(*)(char))_cffi_exports[15])
+#define _cffi_from_c_deref \
+ ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[16])
+#define _cffi_to_c \
+ ((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[17])
+#define _cffi_from_c_struct \
+ ((PyObject *(*)(char *, struct _cffi_ctypedescr *))_cffi_exports[18])
+#define _cffi_to_c_wchar_t \
+ ((_cffi_wchar_t(*)(PyObject *))_cffi_exports[19])
+#define _cffi_from_c_wchar_t \
+ ((PyObject *(*)(_cffi_wchar_t))_cffi_exports[20])
+#define _cffi_to_c_long_double \
+ ((long double(*)(PyObject *))_cffi_exports[21])
+#define _cffi_to_c__Bool \
+ ((_Bool(*)(PyObject *))_cffi_exports[22])
+#define _cffi_prepare_pointer_call_argument \
+ ((Py_ssize_t(*)(struct _cffi_ctypedescr *, \
+ PyObject *, char **))_cffi_exports[23])
+#define _cffi_convert_array_from_object \
+ ((int(*)(char *, struct _cffi_ctypedescr *, PyObject *))_cffi_exports[24])
+#define _CFFI_CPIDX 25
+#define _cffi_call_python \
+ ((void(*)(struct _cffi_externpy_s *, char *))_cffi_exports[_CFFI_CPIDX])
+#define _cffi_to_c_wchar3216_t \
+ ((int(*)(PyObject *))_cffi_exports[26])
+#define _cffi_from_c_wchar3216_t \
+ ((PyObject *(*)(int))_cffi_exports[27])
+#define _CFFI_NUM_EXPORTS 28
+
+struct _cffi_ctypedescr;
+
+static void *_cffi_exports[_CFFI_NUM_EXPORTS];
+
+#define _cffi_type(index) ( \
+ assert((((uintptr_t)_cffi_types[index]) & 1) == 0), \
+ (struct _cffi_ctypedescr *)_cffi_types[index])
+
+static PyObject *_cffi_init(const char *module_name, Py_ssize_t version,
+ const struct _cffi_type_context_s *ctx)
+{
+ PyObject *module, *o_arg, *new_module;
+ void *raw[] = {
+ (void *)module_name,
+ (void *)version,
+ (void *)_cffi_exports,
+ (void *)ctx,
+ };
+
+ module = PyImport_ImportModule("_cffi_backend");
+ if (module == NULL)
+ goto failure;
+
+ o_arg = PyLong_FromVoidPtr((void *)raw);
+ if (o_arg == NULL)
+ goto failure;
+
+ new_module = PyObject_CallMethod(
+ module, (char *)"_init_cffi_1_0_external_module", (char *)"O", o_arg);
+
+ Py_DECREF(o_arg);
+ Py_DECREF(module);
+ return new_module;
+
+ failure:
+ Py_XDECREF(module);
+ return NULL;
+}
+
+
+#ifdef HAVE_WCHAR_H
+typedef wchar_t _cffi_wchar_t;
+#else
+typedef uint16_t _cffi_wchar_t; /* same random pick as _cffi_backend.c */
+#endif
+
+_CFFI_UNUSED_FN static uint16_t _cffi_to_c_char16_t(PyObject *o)
+{
+ if (sizeof(_cffi_wchar_t) == 2)
+ return (uint16_t)_cffi_to_c_wchar_t(o);
+ else
+ return (uint16_t)_cffi_to_c_wchar3216_t(o);
+}
+
+_CFFI_UNUSED_FN static PyObject *_cffi_from_c_char16_t(uint16_t x)
+{
+ if (sizeof(_cffi_wchar_t) == 2)
+ return _cffi_from_c_wchar_t((_cffi_wchar_t)x);
+ else
+ return _cffi_from_c_wchar3216_t((int)x);
+}
+
+_CFFI_UNUSED_FN static int _cffi_to_c_char32_t(PyObject *o)
+{
+ if (sizeof(_cffi_wchar_t) == 4)
+ return (int)_cffi_to_c_wchar_t(o);
+ else
+ return (int)_cffi_to_c_wchar3216_t(o);
+}
+
+_CFFI_UNUSED_FN static PyObject *_cffi_from_c_char32_t(unsigned int x)
+{
+ if (sizeof(_cffi_wchar_t) == 4)
+ return _cffi_from_c_wchar_t((_cffi_wchar_t)x);
+ else
+ return _cffi_from_c_wchar3216_t((int)x);
+}
+
+union _cffi_union_alignment_u {
+ unsigned char m_char;
+ unsigned short m_short;
+ unsigned int m_int;
+ unsigned long m_long;
+ unsigned long long m_longlong;
+ float m_float;
+ double m_double;
+ long double m_longdouble;
+};
+
+struct _cffi_freeme_s {
+ struct _cffi_freeme_s *next;
+ union _cffi_union_alignment_u alignment;
+};
+
+_CFFI_UNUSED_FN static int
+_cffi_convert_array_argument(struct _cffi_ctypedescr *ctptr, PyObject *arg,
+ char **output_data, Py_ssize_t datasize,
+ struct _cffi_freeme_s **freeme)
+{
+ char *p;
+ if (datasize < 0)
+ return -1;
+
+ p = *output_data;
+ if (p == NULL) {
+ struct _cffi_freeme_s *fp = (struct _cffi_freeme_s *)PyObject_Malloc(
+ offsetof(struct _cffi_freeme_s, alignment) + (size_t)datasize);
+ if (fp == NULL)
+ return -1;
+ fp->next = *freeme;
+ *freeme = fp;
+ p = *output_data = (char *)&fp->alignment;
+ }
+ memset((void *)p, 0, (size_t)datasize);
+ return _cffi_convert_array_from_object(p, ctptr, arg);
+}
+
+_CFFI_UNUSED_FN static void
+_cffi_free_array_arguments(struct _cffi_freeme_s *freeme)
+{
+ do {
+ void *p = (void *)freeme;
+ freeme = freeme->next;
+ PyObject_Free(p);
+ } while (freeme != NULL);
+}
+
+/********** end CPython-specific section **********/
+#else
+_CFFI_UNUSED_FN
+static void (*_cffi_call_python_org)(struct _cffi_externpy_s *, char *);
+# define _cffi_call_python _cffi_call_python_org
+#endif
+
+
+#define _cffi_array_len(array) (sizeof(array) / sizeof((array)[0]))
+
+#define _cffi_prim_int(size, sign) \
+ ((size) == 1 ? ((sign) ? _CFFI_PRIM_INT8 : _CFFI_PRIM_UINT8) : \
+ (size) == 2 ? ((sign) ? _CFFI_PRIM_INT16 : _CFFI_PRIM_UINT16) : \
+ (size) == 4 ? ((sign) ? _CFFI_PRIM_INT32 : _CFFI_PRIM_UINT32) : \
+ (size) == 8 ? ((sign) ? _CFFI_PRIM_INT64 : _CFFI_PRIM_UINT64) : \
+ _CFFI__UNKNOWN_PRIM)
+
+#define _cffi_prim_float(size) \
+ ((size) == sizeof(float) ? _CFFI_PRIM_FLOAT : \
+ (size) == sizeof(double) ? _CFFI_PRIM_DOUBLE : \
+ (size) == sizeof(long double) ? _CFFI__UNKNOWN_LONG_DOUBLE : \
+ _CFFI__UNKNOWN_FLOAT_PRIM)
+
+#define _cffi_check_int(got, got_nonpos, expected) \
+ ((got_nonpos) == (expected <= 0) && \
+ (got) == (unsigned long long)expected)
+
+#ifdef MS_WIN32
+# define _cffi_stdcall __stdcall
+#else
+# define _cffi_stdcall /* nothing */
+#endif
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/lib/cffi/_embedding.h b/lib/cffi/_embedding.h
new file mode 100644
index 0000000..8e8df88
--- /dev/null
+++ b/lib/cffi/_embedding.h
@@ -0,0 +1,528 @@
+
+/***** Support code for embedding *****/
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+#if defined(_WIN32)
+# define CFFI_DLLEXPORT __declspec(dllexport)
+#elif defined(__GNUC__)
+# define CFFI_DLLEXPORT __attribute__((visibility("default")))
+#else
+# define CFFI_DLLEXPORT /* nothing */
+#endif
+
+
+/* There are two global variables of type _cffi_call_python_fnptr:
+
+ * _cffi_call_python, which we declare just below, is the one called
+ by ``extern "Python"`` implementations.
+
+ * _cffi_call_python_org, which on CPython is actually part of the
+ _cffi_exports[] array, is the function pointer copied from
+ _cffi_backend. If _cffi_start_python() fails, then this is set
+ to NULL; otherwise, it should never be NULL.
+
+ After initialization is complete, both are equal. However, the
+ first one remains equal to &_cffi_start_and_call_python until the
+ very end of initialization, when we are (or should be) sure that
+ concurrent threads also see a completely initialized world, and
+ only then is it changed.
+*/
+#undef _cffi_call_python
+typedef void (*_cffi_call_python_fnptr)(struct _cffi_externpy_s *, char *);
+static void _cffi_start_and_call_python(struct _cffi_externpy_s *, char *);
+static _cffi_call_python_fnptr _cffi_call_python = &_cffi_start_and_call_python;
+
+
+#ifndef _MSC_VER
+ /* --- Assuming a GCC not infinitely old --- */
+# define cffi_compare_and_swap(l,o,n) __sync_bool_compare_and_swap(l,o,n)
+# define cffi_write_barrier() __sync_synchronize()
+# if !defined(__amd64__) && !defined(__x86_64__) && \
+ !defined(__i386__) && !defined(__i386)
+# define cffi_read_barrier() __sync_synchronize()
+# else
+# define cffi_read_barrier() (void)0
+# endif
+#else
+ /* --- Windows threads version --- */
+# include <Windows.h>
+# define cffi_compare_and_swap(l,o,n) \
+ (InterlockedCompareExchangePointer(l,n,o) == (o))
+# define cffi_write_barrier() InterlockedCompareExchange(&_cffi_dummy,0,0)
+# define cffi_read_barrier() (void)0
+static volatile LONG _cffi_dummy;
+#endif
+
+#ifdef WITH_THREAD
+# ifndef _MSC_VER
+# include <pthread.h>
+ static pthread_mutex_t _cffi_embed_startup_lock;
+# else
+ static CRITICAL_SECTION _cffi_embed_startup_lock;
+# endif
+ static char _cffi_embed_startup_lock_ready = 0;
+#endif
+
+static void _cffi_acquire_reentrant_mutex(void)
+{
+ static void *volatile lock = NULL;
+
+ while (!cffi_compare_and_swap(&lock, NULL, (void *)1)) {
+ /* should ideally do a spin loop instruction here, but
+ hard to do it portably and doesn't really matter I
+ think: pthread_mutex_init() should be very fast, and
+ this is only run at start-up anyway. */
+ }
+
+#ifdef WITH_THREAD
+ if (!_cffi_embed_startup_lock_ready) {
+# ifndef _MSC_VER
+ pthread_mutexattr_t attr;
+ pthread_mutexattr_init(&attr);
+ pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE);
+ pthread_mutex_init(&_cffi_embed_startup_lock, &attr);
+# else
+ InitializeCriticalSection(&_cffi_embed_startup_lock);
+# endif
+ _cffi_embed_startup_lock_ready = 1;
+ }
+#endif
+
+ while (!cffi_compare_and_swap(&lock, (void *)1, NULL))
+ ;
+
+#ifndef _MSC_VER
+ pthread_mutex_lock(&_cffi_embed_startup_lock);
+#else
+ EnterCriticalSection(&_cffi_embed_startup_lock);
+#endif
+}
+
+static void _cffi_release_reentrant_mutex(void)
+{
+#ifndef _MSC_VER
+ pthread_mutex_unlock(&_cffi_embed_startup_lock);
+#else
+ LeaveCriticalSection(&_cffi_embed_startup_lock);
+#endif
+}
+
+
+/********** CPython-specific section **********/
+#ifndef PYPY_VERSION
+
+#include "_cffi_errors.h"
+
+
+#define _cffi_call_python_org _cffi_exports[_CFFI_CPIDX]
+
+PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(void); /* forward */
+
+static void _cffi_py_initialize(void)
+{
+ /* XXX use initsigs=0, which "skips initialization registration of
+ signal handlers, which might be useful when Python is
+ embedded" according to the Python docs. But review and think
+ if it should be a user-controllable setting.
+
+ XXX we should also give a way to write errors to a buffer
+ instead of to stderr.
+
+ XXX if importing 'site' fails, CPython (any version) calls
+ exit(). Should we try to work around this behavior here?
+ */
+ Py_InitializeEx(0);
+}
+
+static int _cffi_initialize_python(void)
+{
+ /* This initializes Python, imports _cffi_backend, and then the
+ present .dll/.so is set up as a CPython C extension module.
+ */
+ int result;
+ PyGILState_STATE state;
+ PyObject *pycode=NULL, *global_dict=NULL, *x;
+ PyObject *builtins;
+
+ state = PyGILState_Ensure();
+
+ /* Call the initxxx() function from the present module. It will
+ create and initialize us as a CPython extension module, instead
+ of letting the startup Python code do it---it might reimport
+ the same .dll/.so and possibly get confused on some platforms.
+ It might also have trouble locating the .dll/.so again, for all
+ we know.
+ */
+ (void)_CFFI_PYTHON_STARTUP_FUNC();
+ if (PyErr_Occurred())
+ goto error;
+
+ /* Now run the Python code provided to ffi.embedding_init_code().
+ */
+ pycode = Py_CompileString(_CFFI_PYTHON_STARTUP_CODE,
+ "<init code for '" _CFFI_MODULE_NAME "'>",
+ Py_file_input);
+ if (pycode == NULL)
+ goto error;
+ global_dict = PyDict_New();
+ if (global_dict == NULL)
+ goto error;
+ builtins = PyEval_GetBuiltins();
+ if (builtins == NULL)
+ goto error;
+ if (PyDict_SetItemString(global_dict, "__builtins__", builtins) < 0)
+ goto error;
+ x = PyEval_EvalCode(
+#if PY_MAJOR_VERSION < 3
+ (PyCodeObject *)
+#endif
+ pycode, global_dict, global_dict);
+ if (x == NULL)
+ goto error;
+ Py_DECREF(x);
+
+ /* Done! Now if we've been called from
+ _cffi_start_and_call_python() in an ``extern "Python"``, we can
+ only hope that the Python code did correctly set up the
+ corresponding @ffi.def_extern() function. Otherwise, the
+ general logic of ``extern "Python"`` functions (inside the
+ _cffi_backend module) will find that the reference is still
+ missing and print an error.
+ */
+ result = 0;
+ done:
+ Py_XDECREF(pycode);
+ Py_XDECREF(global_dict);
+ PyGILState_Release(state);
+ return result;
+
+ error:;
+ {
+ /* Print as much information as is potentially useful.
+ Debugging load-time failures with embedding is not fun.
+ */
+ PyObject *ecap;
+ PyObject *exception, *v, *tb, *f, *modules, *mod;
+ PyErr_Fetch(&exception, &v, &tb);
+ ecap = _cffi_start_error_capture();
+ f = PySys_GetObject((char *)"stderr");
+ if (f != NULL && f != Py_None) {
+ PyFile_WriteString(
+ "Failed to initialize the Python-CFFI embedding logic:\n\n", f);
+ }
+
+ if (exception != NULL) {
+ PyErr_NormalizeException(&exception, &v, &tb);
+ PyErr_Display(exception, v, tb);
+ }
+ Py_XDECREF(exception);
+ Py_XDECREF(v);
+ Py_XDECREF(tb);
+
+ if (f != NULL && f != Py_None) {
+ PyFile_WriteString("\nFrom: " _CFFI_MODULE_NAME
+ "\ncompiled with cffi version: 1.15.1"
+ "\n_cffi_backend module: ", f);
+ modules = PyImport_GetModuleDict();
+ mod = PyDict_GetItemString(modules, "_cffi_backend");
+ if (mod == NULL) {
+ PyFile_WriteString("not loaded", f);
+ }
+ else {
+ v = PyObject_GetAttrString(mod, "__file__");
+ PyFile_WriteObject(v, f, 0);
+ Py_XDECREF(v);
+ }
+ PyFile_WriteString("\nsys.path: ", f);
+ PyFile_WriteObject(PySys_GetObject((char *)"path"), f, 0);
+ PyFile_WriteString("\n\n", f);
+ }
+ _cffi_stop_error_capture(ecap);
+ }
+ result = -1;
+ goto done;
+}
+
+#if PY_VERSION_HEX < 0x03080000
+PyAPI_DATA(char *) _PyParser_TokenNames[]; /* from CPython */
+#endif
+
+static int _cffi_carefully_make_gil(void)
+{
+ /* This does the basic initialization of Python. It can be called
+ completely concurrently from unrelated threads. It assumes
+ that we don't hold the GIL before (if it exists), and we don't
+ hold it afterwards.
+
+ (What it really does used to be completely different in Python 2
+ and Python 3, with the Python 2 solution avoiding the spin-lock
+ around the Py_InitializeEx() call. However, after recent changes
+ to CPython 2.7 (issue #358) it no longer works. So we use the
+ Python 3 solution everywhere.)
+
+ This initializes Python by calling Py_InitializeEx().
+ Important: this must not be called concurrently at all.
+ So we use a global variable as a simple spin lock. This global
+ variable must be from 'libpythonX.Y.so', not from this
+ cffi-based extension module, because it must be shared from
+ different cffi-based extension modules.
+
+ In Python < 3.8, we choose
+ _PyParser_TokenNames[0] as a completely arbitrary pointer value
+ that is never written to. The default is to point to the
+ string "ENDMARKER". We change it temporarily to point to the
+ next character in that string. (Yes, I know it's REALLY
+ obscure.)
+
+ In Python >= 3.8, this string array is no longer writable, so
+ instead we pick PyCapsuleType.tp_version_tag. We can't change
+ Python < 3.8 because someone might use a mixture of cffi
+ embedded modules, some of which were compiled before this file
+ changed.
+ */
+
+#ifdef WITH_THREAD
+# if PY_VERSION_HEX < 0x03080000
+ char *volatile *lock = (char *volatile *)_PyParser_TokenNames;
+ char *old_value, *locked_value;
+
+ while (1) { /* spin loop */
+ old_value = *lock;
+ locked_value = old_value + 1;
+ if (old_value[0] == 'E') {
+ assert(old_value[1] == 'N');
+ if (cffi_compare_and_swap(lock, old_value, locked_value))
+ break;
+ }
+ else {
+ assert(old_value[0] == 'N');
+ /* should ideally do a spin loop instruction here, but
+ hard to do it portably and doesn't really matter I
+ think: PyEval_InitThreads() should be very fast, and
+ this is only run at start-up anyway. */
+ }
+ }
+# else
+ int volatile *lock = (int volatile *)&PyCapsule_Type.tp_version_tag;
+ int old_value, locked_value;
+ assert(!(PyCapsule_Type.tp_flags & Py_TPFLAGS_HAVE_VERSION_TAG));
+
+ while (1) { /* spin loop */
+ old_value = *lock;
+ locked_value = -42;
+ if (old_value == 0) {
+ if (cffi_compare_and_swap(lock, old_value, locked_value))
+ break;
+ }
+ else {
+ assert(old_value == locked_value);
+ /* should ideally do a spin loop instruction here, but
+ hard to do it portably and doesn't really matter I
+ think: PyEval_InitThreads() should be very fast, and
+ this is only run at start-up anyway. */
+ }
+ }
+# endif
+#endif
+
+ /* call Py_InitializeEx() */
+ if (!Py_IsInitialized()) {
+ _cffi_py_initialize();
+#if PY_VERSION_HEX < 0x03070000
+ PyEval_InitThreads();
+#endif
+ PyEval_SaveThread(); /* release the GIL */
+ /* the returned tstate must be the one that has been stored into the
+ autoTLSkey by _PyGILState_Init() called from Py_Initialize(). */
+ }
+ else {
+#if PY_VERSION_HEX < 0x03070000
+ /* PyEval_InitThreads() is always a no-op from CPython 3.7 */
+ PyGILState_STATE state = PyGILState_Ensure();
+ PyEval_InitThreads();
+ PyGILState_Release(state);
+#endif
+ }
+
+#ifdef WITH_THREAD
+ /* release the lock */
+ while (!cffi_compare_and_swap(lock, locked_value, old_value))
+ ;
+#endif
+
+ return 0;
+}
+
+/********** end CPython-specific section **********/
+
+
+#else
+
+
+/********** PyPy-specific section **********/
+
+PyMODINIT_FUNC _CFFI_PYTHON_STARTUP_FUNC(const void *[]); /* forward */
+
+static struct _cffi_pypy_init_s {
+ const char *name;
+ void *func; /* function pointer */
+ const char *code;
+} _cffi_pypy_init = {
+ _CFFI_MODULE_NAME,
+ _CFFI_PYTHON_STARTUP_FUNC,
+ _CFFI_PYTHON_STARTUP_CODE,
+};
+
+extern int pypy_carefully_make_gil(const char *);
+extern int pypy_init_embedded_cffi_module(int, struct _cffi_pypy_init_s *);
+
+static int _cffi_carefully_make_gil(void)
+{
+ return pypy_carefully_make_gil(_CFFI_MODULE_NAME);
+}
+
+static int _cffi_initialize_python(void)
+{
+ return pypy_init_embedded_cffi_module(0xB011, &_cffi_pypy_init);
+}
+
+/********** end PyPy-specific section **********/
+
+
+#endif
+
+
+#ifdef __GNUC__
+__attribute__((noinline))
+#endif
+static _cffi_call_python_fnptr _cffi_start_python(void)
+{
+ /* Delicate logic to initialize Python. This function can be
+ called multiple times concurrently, e.g. when the process calls
+ its first ``extern "Python"`` functions in multiple threads at
+ once. It can also be called recursively, in which case we must
+ ignore it. We also have to consider what occurs if several
+ different cffi-based extensions reach this code in parallel
+ threads---it is a different copy of the code, then, and we
+ can't have any shared global variable unless it comes from
+ 'libpythonX.Y.so'.
+
+ Idea:
+
+ * _cffi_carefully_make_gil(): "carefully" call
+ PyEval_InitThreads() (possibly with Py_InitializeEx() first).
+
+ * then we use a (local) custom lock to make sure that a call to this
+ cffi-based extension will wait if another call to the *same*
+ extension is running the initialization in another thread.
+ It is reentrant, so that a recursive call will not block, but
+ only one from a different thread.
+
+ * then we grab the GIL and (Python 2) we call Py_InitializeEx().
+ At this point, concurrent calls to Py_InitializeEx() are not
+ possible: we have the GIL.
+
+ * do the rest of the specific initialization, which may
+ temporarily release the GIL but not the custom lock.
+ Only release the custom lock when we are done.
+ */
+ static char called = 0;
+
+ if (_cffi_carefully_make_gil() != 0)
+ return NULL;
+
+ _cffi_acquire_reentrant_mutex();
+
+ /* Here the GIL exists, but we don't have it. We're only protected
+ from concurrency by the reentrant mutex. */
+
+ /* This file only initializes the embedded module once, the first
+ time this is called, even if there are subinterpreters. */
+ if (!called) {
+ called = 1; /* invoke _cffi_initialize_python() only once,
+ but don't set '_cffi_call_python' right now,
+ otherwise concurrent threads won't call
+ this function at all (we need them to wait) */
+ if (_cffi_initialize_python() == 0) {
+ /* now initialization is finished. Switch to the fast-path. */
+
+ /* We would like nobody to see the new value of
+ '_cffi_call_python' without also seeing the rest of the
+ data initialized. However, this is not possible. But
+ the new value of '_cffi_call_python' is the function
+ 'cffi_call_python()' from _cffi_backend. So: */
+ cffi_write_barrier();
+ /* ^^^ we put a write barrier here, and a corresponding
+ read barrier at the start of cffi_call_python(). This
+ ensures that after that read barrier, we see everything
+ done here before the write barrier.
+ */
+
+ assert(_cffi_call_python_org != NULL);
+ _cffi_call_python = (_cffi_call_python_fnptr)_cffi_call_python_org;
+ }
+ else {
+ /* initialization failed. Reset this to NULL, even if it was
+ already set to some other value. Future calls to
+ _cffi_start_python() are still forced to occur, and will
+ always return NULL from now on. */
+ _cffi_call_python_org = NULL;
+ }
+ }
+
+ _cffi_release_reentrant_mutex();
+
+ return (_cffi_call_python_fnptr)_cffi_call_python_org;
+}
+
+static
+void _cffi_start_and_call_python(struct _cffi_externpy_s *externpy, char *args)
+{
+ _cffi_call_python_fnptr fnptr;
+ int current_err = errno;
+#ifdef _MSC_VER
+ int current_lasterr = GetLastError();
+#endif
+ fnptr = _cffi_start_python();
+ if (fnptr == NULL) {
+ fprintf(stderr, "function %s() called, but initialization code "
+ "failed. Returning 0.\n", externpy->name);
+ memset(args, 0, externpy->size_of_result);
+ }
+#ifdef _MSC_VER
+ SetLastError(current_lasterr);
+#endif
+ errno = current_err;
+
+ if (fnptr != NULL)
+ fnptr(externpy, args);
+}
+
+
+/* The cffi_start_python() function makes sure Python is initialized
+ and our cffi module is set up. It can be called manually from the
+ user C code. The same effect is obtained automatically from any
+ dll-exported ``extern "Python"`` function. This function returns
+ -1 if initialization failed, 0 if all is OK. */
+_CFFI_UNUSED_FN
+static int cffi_start_python(void)
+{
+ if (_cffi_call_python == &_cffi_start_and_call_python) {
+ if (_cffi_start_python() == NULL)
+ return -1;
+ }
+ cffi_read_barrier();
+ return 0;
+}
+
+#undef cffi_compare_and_swap
+#undef cffi_write_barrier
+#undef cffi_read_barrier
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/lib/cffi/api.py b/lib/cffi/api.py
new file mode 100644
index 0000000..999a8ae
--- /dev/null
+++ b/lib/cffi/api.py
@@ -0,0 +1,965 @@
+import sys, types
+from .lock import allocate_lock
+from .error import CDefError
+from . import model
+
+try:
+ callable
+except NameError:
+ # Python 3.1
+ from collections import Callable
+ callable = lambda x: isinstance(x, Callable)
+
+try:
+ basestring
+except NameError:
+ # Python 3.x
+ basestring = str
+
+_unspecified = object()
+
+
+
+class FFI(object):
+ r'''
+ The main top-level class that you instantiate once, or once per module.
+
+ Example usage:
+
+ ffi = FFI()
+ ffi.cdef("""
+ int printf(const char *, ...);
+ """)
+
+ C = ffi.dlopen(None) # standard library
+ -or-
+ C = ffi.verify() # use a C compiler: verify the decl above is right
+
+ C.printf("hello, %s!\n", ffi.new("char[]", "world"))
+ '''
+
+ def __init__(self, backend=None):
+ """Create an FFI instance. The 'backend' argument is used to
+ select a non-default backend, mostly for tests.
+ """
+ if backend is None:
+ # You need PyPy (>= 2.0 beta), or a CPython (>= 2.6) with
+ # _cffi_backend.so compiled.
+ import _cffi_backend as backend
+ from . import __version__
+ if backend.__version__ != __version__:
+ # bad version! Try to be as explicit as possible.
+ if hasattr(backend, '__file__'):
+ # CPython
+ raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. When we import the top-level '_cffi_backend' extension module, we get version %s, located in %r. The two versions should be equal; check your installation." % (
+ __version__, __file__,
+ backend.__version__, backend.__file__))
+ else:
+ # PyPy
+ raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. This interpreter comes with a built-in '_cffi_backend' module, which is version %s. The two versions should be equal; check your installation." % (
+ __version__, __file__, backend.__version__))
+ # (If you insist you can also try to pass the option
+ # 'backend=backend_ctypes.CTypesBackend()', but don't
+ # rely on it! It's probably not going to work well.)
+
+ from . import cparser
+ self._backend = backend
+ self._lock = allocate_lock()
+ self._parser = cparser.Parser()
+ self._cached_btypes = {}
+ self._parsed_types = types.ModuleType('parsed_types').__dict__
+ self._new_types = types.ModuleType('new_types').__dict__
+ self._function_caches = []
+ self._libraries = []
+ self._cdefsources = []
+ self._included_ffis = []
+ self._windows_unicode = None
+ self._init_once_cache = {}
+ self._cdef_version = None
+ self._embedding = None
+ self._typecache = model.get_typecache(backend)
+ if hasattr(backend, 'set_ffi'):
+ backend.set_ffi(self)
+ for name in list(backend.__dict__):
+ if name.startswith('RTLD_'):
+ setattr(self, name, getattr(backend, name))
+ #
+ with self._lock:
+ self.BVoidP = self._get_cached_btype(model.voidp_type)
+ self.BCharA = self._get_cached_btype(model.char_array_type)
+ if isinstance(backend, types.ModuleType):
+ # _cffi_backend: attach these constants to the class
+ if not hasattr(FFI, 'NULL'):
+ FFI.NULL = self.cast(self.BVoidP, 0)
+ FFI.CData, FFI.CType = backend._get_types()
+ else:
+ # ctypes backend: attach these constants to the instance
+ self.NULL = self.cast(self.BVoidP, 0)
+ self.CData, self.CType = backend._get_types()
+ self.buffer = backend.buffer
+
+ def cdef(self, csource, override=False, packed=False, pack=None):
+ """Parse the given C source. This registers all declared functions,
+ types, and global variables. The functions and global variables can
+ then be accessed via either 'ffi.dlopen()' or 'ffi.verify()'.
+ The types can be used in 'ffi.new()' and other functions.
+ If 'packed' is specified as True, all structs declared inside this
+ cdef are packed, i.e. laid out without any field alignment at all.
+ Alternatively, 'pack' can be a small integer, and requests for
+ alignment greater than that are ignored (pack=1 is equivalent to
+ packed=True).
+ """
+ self._cdef(csource, override=override, packed=packed, pack=pack)
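+ # Usage sketch (the struct names below are only examples):
+ #   ffi.cdef("typedef struct { int x, y; } point_t;")
+ #   ffi.cdef("struct pkt { char tag; int value; };", packed=True)  # no field padding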
+
+ def embedding_api(self, csource, packed=False, pack=None):
+ self._cdef(csource, packed=packed, pack=pack, dllexport=True)
+ if self._embedding is None:
+ self._embedding = ''
+
+ def _cdef(self, csource, override=False, **options):
+ if not isinstance(csource, str): # unicode, on Python 2
+ if not isinstance(csource, basestring):
+ raise TypeError("cdef() argument must be a string")
+ csource = csource.encode('ascii')
+ with self._lock:
+ self._cdef_version = object()
+ self._parser.parse(csource, override=override, **options)
+ self._cdefsources.append(csource)
+ if override:
+ for cache in self._function_caches:
+ cache.clear()
+ finishlist = self._parser._recomplete
+ if finishlist:
+ self._parser._recomplete = []
+ for tp in finishlist:
+ tp.finish_backend_type(self, finishlist)
+
+ def dlopen(self, name, flags=0):
+ """Load and return a dynamic library identified by 'name'.
+ The standard C library can be loaded by passing None.
+ Note that functions and types declared by 'ffi.cdef()' are not
+ linked to a particular library, just like C headers; in the
+ library we only look for the actual (untyped) symbols.
+ """
+ if not (isinstance(name, basestring) or
+ name is None or
+ isinstance(name, self.CData)):
+ raise TypeError("dlopen(name): name must be a file name, None, "
+ "or an already-opened 'void *' handle")
+ with self._lock:
+ lib, function_cache = _make_ffi_library(self, name, flags)
+ self._function_caches.append(function_cache)
+ self._libraries.append(lib)
+ return lib
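+ # Usage sketch ("libm.so.6" is an example library name; it varies by platform):
+ #   ffi.cdef("double sqrt(double x);")
+ #   libm = ffi.dlopen("libm.so.6")
+ #   libm.sqrt(2.0)                     # -> 1.4142...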
+
+ def dlclose(self, lib):
+ """Close a library obtained with ffi.dlopen(). After this call,
+ access to functions or variables from the library will fail
+ (possibly with a segmentation fault).
+ """
+ type(lib).__cffi_close__(lib)
+
+ def _typeof_locked(self, cdecl):
+ # call me with the lock!
+ key = cdecl
+ if key in self._parsed_types:
+ return self._parsed_types[key]
+ #
+ if not isinstance(cdecl, str): # unicode, on Python 2
+ cdecl = cdecl.encode('ascii')
+ #
+ type = self._parser.parse_type(cdecl)
+ really_a_function_type = type.is_raw_function
+ if really_a_function_type:
+ type = type.as_function_pointer()
+ btype = self._get_cached_btype(type)
+ result = btype, really_a_function_type
+ self._parsed_types[key] = result
+ return result
+
+ def _typeof(self, cdecl, consider_function_as_funcptr=False):
+ # string -> ctype object
+ try:
+ result = self._parsed_types[cdecl]
+ except KeyError:
+ with self._lock:
+ result = self._typeof_locked(cdecl)
+ #
+ btype, really_a_function_type = result
+ if really_a_function_type and not consider_function_as_funcptr:
+ raise CDefError("the type %r is a function type, not a "
+ "pointer-to-function type" % (cdecl,))
+ return btype
+
+ def typeof(self, cdecl):
+ """Parse the C type given as a string and return the
+ corresponding <ctype> object.
+ It can also be used on a 'cdata' instance to get its C type.
+ """
+ if isinstance(cdecl, basestring):
+ return self._typeof(cdecl)
+ if isinstance(cdecl, self.CData):
+ return self._backend.typeof(cdecl)
+ if isinstance(cdecl, types.BuiltinFunctionType):
+ res = _builtin_function_type(cdecl)
+ if res is not None:
+ return res
+ if (isinstance(cdecl, types.FunctionType)
+ and hasattr(cdecl, '_cffi_base_type')):
+ with self._lock:
+ return self._get_cached_btype(cdecl._cffi_base_type)
+ raise TypeError(type(cdecl))
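+ # Usage sketch:
+ #   ffi.typeof("int[10]")              # -> <ctype 'int[10]'>
+ #   ffi.typeof(ffi.new("int *"))       # -> <ctype 'int *'>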
+
+ def sizeof(self, cdecl):
+ """Return the size in bytes of the argument. It can be a
+ string naming a C type, or a 'cdata' instance.
+ """
+ if isinstance(cdecl, basestring):
+ BType = self._typeof(cdecl)
+ return self._backend.sizeof(BType)
+ else:
+ return self._backend.sizeof(cdecl)
+
+ def alignof(self, cdecl):
+ """Return the natural alignment size in bytes of the C type
+ given as a string.
+ """
+ if isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl)
+ return self._backend.alignof(cdecl)
+
+ def offsetof(self, cdecl, *fields_or_indexes):
+ """Return the offset of the named field inside the given
+ structure or array, which must be given as a C type name.
+ You can give several field names in case of nested structures.
+ You can also give numeric values which correspond to array
+ items, in case of an array type.
+ """
+ if isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl)
+ return self._typeoffsetof(cdecl, *fields_or_indexes)[1]
+
+ def new(self, cdecl, init=None):
+ """Allocate an instance according to the specified C type and
+ return a pointer to it. The specified C type must be either a
+ pointer or an array: ``new('X *')`` allocates an X and returns
+ a pointer to it, whereas ``new('X[n]')`` allocates an array of
+ n X'es and returns an array referencing it (which works
+ mostly like a pointer, like in C). You can also use
+ ``new('X[]', n)`` to allocate an array of a non-constant
+ length n.
+
+ The memory is initialized following the rules of declaring a
+ global variable in C: by default it is zero-initialized, but
+ an explicit initializer can be given which can be used to
+ fill all or part of the memory.
+
+ When the returned <cdata> object goes out of scope, the memory
+ is freed. In other words the returned <cdata> object has
+ ownership of the value of type 'cdecl' that it points to. This
+ means that the raw data can be used as long as this object is
+ kept alive, but must not be used for a longer time. Be careful
+ about that when copying the pointer to the memory somewhere
+ else, e.g. into another structure.
+ """
+ if isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl)
+ return self._backend.newp(cdecl, init)
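+ # Usage sketch:
+ #   p = ffi.new("int *")               # a single int, initialized to 0
+ #   a = ffi.new("int[10]")             # ten zero-initialized ints
+ #   s = ffi.new("char[]", b"hello")    # six bytes, including the final '\0'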
+
+ def new_allocator(self, alloc=None, free=None,
+ should_clear_after_alloc=True):
+ """Return a new allocator, i.e. a function that behaves like ffi.new()
+ but uses the provided low-level 'alloc' and 'free' functions.
+
+ 'alloc' is called with the size as argument. If it returns NULL, a
+ MemoryError is raised. 'free' is called with the result of 'alloc'
+ as argument. Both can be either Python function or directly C
+ functions. If 'free' is None, then no free function is called.
+ If both 'alloc' and 'free' are None, the default is used.
+
+ If 'should_clear_after_alloc' is set to False, then the memory
+ returned by 'alloc' is assumed to be already cleared (or you are
+ fine with garbage); otherwise CFFI will clear it.
+ """
+ compiled_ffi = self._backend.FFI()
+ allocator = compiled_ffi.new_allocator(alloc, free,
+ should_clear_after_alloc)
+ def allocate(cdecl, init=None):
+ if isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl)
+ return allocator(cdecl, init)
+ return allocate
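+ # Usage sketch (keep the default malloc()/free(), but skip the initial clearing):
+ #   alloc = ffi.new_allocator(should_clear_after_alloc=False)
+ #   p = alloc("int[100]")              # used like ffi.new("int[100]")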
+
+ def cast(self, cdecl, source):
+ """Similar to a C cast: returns an instance of the named C
+ type initialized with the given 'source'. The source is
+ cast between integers or pointers of any type.
+ """
+ if isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl)
+ return self._backend.cast(cdecl, source)
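+ # Usage sketch:
+ #   n = ffi.cast("long", 65)           # -> <cdata 'long' 65>
+ #   p = ffi.cast("char *", n)          # reinterpret the integer as a pointer value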
+
+ def string(self, cdata, maxlen=-1):
+ """Return a Python string (or unicode string) from the 'cdata'.
+ If 'cdata' is a pointer or array of characters or bytes, returns
+ the null-terminated string. The returned string extends until
+ the first null character, or at most 'maxlen' characters. If
+ 'cdata' is an array then 'maxlen' defaults to its length.
+
+ If 'cdata' is a pointer or array of wchar_t, returns a unicode
+ string following the same rules.
+
+ If 'cdata' is a single character or byte or a wchar_t, returns
+ it as a string or unicode string.
+
+ If 'cdata' is an enum, returns the value of the enumerator as a
+ string, or 'NUMBER' if the value is out of range.
+ """
+ return self._backend.string(cdata, maxlen)
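+ # Usage sketch:
+ #   p = ffi.new("char[]", b"hello\x00world")
+ #   ffi.string(p)                      # -> b'hello' (stops at the first null byte)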
+
+ def unpack(self, cdata, length):
+ """Unpack an array of C data of the given length,
+ returning a Python string/unicode/list.
+
+ If 'cdata' is a pointer to 'char', returns a byte string.
+ It does not stop at the first null. This is equivalent to:
+ ffi.buffer(cdata, length)[:]
+
+ If 'cdata' is a pointer to 'wchar_t', returns a unicode string.
+ 'length' is measured in wchar_t's; it is not the size in bytes.
+
+ If 'cdata' is a pointer to anything else, returns a list of
+ 'length' items. This is a faster equivalent to:
+ [cdata[i] for i in range(length)]
+ """
+ return self._backend.unpack(cdata, length)
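+ # Usage sketch:
+ #   a = ffi.new("int[]", [10, 20, 30])
+ #   ffi.unpack(a, 3)                   # -> [10, 20, 30]
+ #   c = ffi.new("char[]", b"ab\x00cd")
+ #   ffi.unpack(c, 5)                   # -> b'ab\x00cd' (does not stop at '\0')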
+
+ #def buffer(self, cdata, size=-1):
+ # """Return a read-write buffer object that references the raw C data
+ # pointed to by the given 'cdata'. The 'cdata' must be a pointer or
+ # an array. Can be passed to functions expecting a buffer, or directly
+ # manipulated with:
+ #
+ # buf[:] get a copy of it in a regular string, or
+ # buf[idx] as a single character
+ # buf[:] = ...
+ # buf[idx] = ... change the content
+ # """
+ # note that 'buffer' is a type, set on this instance by __init__
+
+ def from_buffer(self, cdecl, python_buffer=_unspecified,
+ require_writable=False):
+ """Return a cdata of the given type pointing to the data of the
+ given Python object, which must support the buffer interface.
+ Note that this is not meant to be used on the built-in types
+ str or unicode (you can build 'char[]' arrays explicitly)
+ but only on objects containing large quantities of raw data
+ in some other format, like 'array.array' or numpy arrays.
+
+ The first argument is optional and defaults to 'char[]'.
+ """
+ if python_buffer is _unspecified:
+ cdecl, python_buffer = self.BCharA, cdecl
+ elif isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl)
+ return self._backend.from_buffer(cdecl, python_buffer,
+ require_writable)
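+ # Usage sketch (no copy is made; keep 'a' alive while 'p' is in use):
+ #   import array
+ #   a = array.array('i', [1, 2, 3])
+ #   p = ffi.from_buffer("int[]", a)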
+
+ def memmove(self, dest, src, n):
+ """ffi.memmove(dest, src, n) copies n bytes of memory from src to dest.
+
+ Like the C function memmove(), the memory areas may overlap;
+ apart from that it behaves like the C function memcpy().
+
+ 'src' can be any cdata ptr or array, or any Python buffer object.
+ 'dest' can be any cdata ptr or array, or a writable Python buffer
+ object. The size to copy, 'n', is always measured in bytes.
+
+ Unlike other methods, this one supports all Python buffers, including
+ byte strings and bytearrays---but it still does not support
+ non-contiguous buffers.
+ """
+ return self._backend.memmove(dest, src, n)
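+ # Usage sketch:
+ #   dst = ffi.new("char[10]")
+ #   ffi.memmove(dst, b"abcdef", 6)     # copy 6 bytes out of a Python byte string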
+
+ def callback(self, cdecl, python_callable=None, error=None, onerror=None):
+ """Return a callback object or a decorator making such a
+ callback object. 'cdecl' must name a C function pointer type.
+ The callback invokes the specified 'python_callable' (which may
+ be provided either directly or via a decorator). Important: the
+ callback object must be manually kept alive for as long as the
+ callback may be invoked from the C level.
+ """
+ def callback_decorator_wrap(python_callable):
+ if not callable(python_callable):
+ raise TypeError("the 'python_callable' argument "
+ "is not callable")
+ return self._backend.callback(cdecl, python_callable,
+ error, onerror)
+ if isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl, consider_function_as_funcptr=True)
+ if python_callable is None:
+ return callback_decorator_wrap # decorator mode
+ else:
+ return callback_decorator_wrap(python_callable) # direct mode
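+ # Usage sketch (keep 'add' referenced for as long as C code may invoke it):
+ #   @ffi.callback("int(int, int)")
+ #   def add(x, y):
+ #       return x + y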
+
+ def getctype(self, cdecl, replace_with=''):
+ """Return a string giving the C type 'cdecl', which may be itself
+ a string or a <ctype> object. If 'replace_with' is given, it gives
+ extra text to append (or insert for more complicated C types), like
+ a variable name, or '*' to get actually the C type 'pointer-to-cdecl'.
+ """
+ if isinstance(cdecl, basestring):
+ cdecl = self._typeof(cdecl)
+ replace_with = replace_with.strip()
+ if (replace_with.startswith('*')
+ and '&[' in self._backend.getcname(cdecl, '&')):
+ replace_with = '(%s)' % replace_with
+ elif replace_with and not replace_with[0] in '[(':
+ replace_with = ' ' + replace_with
+ return self._backend.getcname(cdecl, replace_with)
+
+ def gc(self, cdata, destructor, size=0):
+ """Return a new cdata object that points to the same
+ data. Later, when this new cdata object is garbage-collected,
+ 'destructor(old_cdata_object)' will be called.
+
+ The optional 'size' gives an estimate of the size, used to
+ trigger the garbage collection more eagerly. So far only used
+ on PyPy. It tells the GC that the returned object keeps alive
+ roughly 'size' bytes of external memory.
+ """
+ return self._backend.gcp(cdata, destructor, size)
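+ # Usage sketch (assumes a 'lib' that exposes malloc() and free() via cdef()):
+ #   raw = lib.malloc(100)
+ #   p = ffi.gc(raw, lib.free)          # lib.free(raw) runs when 'p' is collected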
+
+ def _get_cached_btype(self, type):
+ assert self._lock.acquire(False) is False
+ # call me with the lock!
+ try:
+ BType = self._cached_btypes[type]
+ except KeyError:
+ finishlist = []
+ BType = type.get_cached_btype(self, finishlist)
+ for type in finishlist:
+ type.finish_backend_type(self, finishlist)
+ return BType
+
+ def verify(self, source='', tmpdir=None, **kwargs):
+ """Verify that the current ffi signatures compile on this
+ machine, and return a dynamic library object. The dynamic
+ library can be used to call functions and access global
+ variables declared in this 'ffi'. The library is compiled
+ by the C compiler: it gives you C-level API compatibility
+ (including calling macros). This is unlike 'ffi.dlopen()',
+ which requires binary compatibility in the signatures.
+ """
+ from .verifier import Verifier, _caller_dir_pycache
+ #
+ # If set_unicode(True) was called, insert the UNICODE and
+ # _UNICODE macro declarations
+ if self._windows_unicode:
+ self._apply_windows_unicode(kwargs)
+ #
+ # Set the tmpdir here, and not in Verifier.__init__: it picks
+ # up the caller's directory, which we want to be the caller of
+ # ffi.verify(), as opposed to the caller of Verifier().
+ tmpdir = tmpdir or _caller_dir_pycache()
+ #
+ # Make a Verifier() and use it to load the library.
+ self.verifier = Verifier(self, source, tmpdir, **kwargs)
+ lib = self.verifier.load_library()
+ #
+ # Save the loaded library for keep-alive purposes, even
+ # if the caller doesn't keep it alive itself (it should).
+ self._libraries.append(lib)
+ return lib
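+ # Usage sketch:
+ #   ffi.cdef("double sin(double x);")
+ #   lib = ffi.verify('#include <math.h>', libraries=['m'])
+ #   lib.sin(1.0)                       # -> 0.8414...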
+
+ def _get_errno(self):
+ return self._backend.get_errno()
+ def _set_errno(self, errno):
+ self._backend.set_errno(errno)
+ errno = property(_get_errno, _set_errno, None,
+ "the value of 'errno' from/to the C calls")
+
+ def getwinerror(self, code=-1):
+ return self._backend.getwinerror(code)
+
+ def _pointer_to(self, ctype):
+ with self._lock:
+ return model.pointer_cache(self, ctype)
+
+ def addressof(self, cdata, *fields_or_indexes):
+ """Return the address of a <cdata 'struct-or-union'>.
+ If 'fields_or_indexes' are given, returns the address of that
+ field or array item in the structure or array, recursively in
+ case of nested structures.
+ """
+ try:
+ ctype = self._backend.typeof(cdata)
+ except TypeError:
+ if '__addressof__' in type(cdata).__dict__:
+ return type(cdata).__addressof__(cdata, *fields_or_indexes)
+ raise
+ if fields_or_indexes:
+ ctype, offset = self._typeoffsetof(ctype, *fields_or_indexes)
+ else:
+ if ctype.kind == "pointer":
+ raise TypeError("addressof(pointer)")
+ offset = 0
+ ctypeptr = self._pointer_to(ctype)
+ return self._backend.rawaddressof(ctypeptr, cdata, offset)
+
+ def _typeoffsetof(self, ctype, field_or_index, *fields_or_indexes):
+ ctype, offset = self._backend.typeoffsetof(ctype, field_or_index)
+ for field1 in fields_or_indexes:
+ ctype, offset1 = self._backend.typeoffsetof(ctype, field1, 1)
+ offset += offset1
+ return ctype, offset
+
+ def include(self, ffi_to_include):
+ """Includes the typedefs, structs, unions and enums defined
+ in another FFI instance. Usage is similar to a #include in C,
+ where a part of the program might include types defined in
+ another part for its own usage. Note that the include()
+ method has no effect on functions, constants and global
+ variables, which must anyway be accessed directly from the
+ lib object returned by the original FFI instance.
+ """
+ if not isinstance(ffi_to_include, FFI):
+ raise TypeError("ffi.include() expects an argument that is also of"
+ " type cffi.FFI, not %r" % (
+ type(ffi_to_include).__name__,))
+ if ffi_to_include is self:
+ raise ValueError("self.include(self)")
+ with ffi_to_include._lock:
+ with self._lock:
+ self._parser.include(ffi_to_include._parser)
+ self._cdefsources.append('[')
+ self._cdefsources.extend(ffi_to_include._cdefsources)
+ self._cdefsources.append(']')
+ self._included_ffis.append(ffi_to_include)
+
+ def new_handle(self, x):
+ return self._backend.newp_handle(self.BVoidP, x)
+
+ def from_handle(self, x):
+ return self._backend.from_handle(x)
+
+ def release(self, x):
+ self._backend.release(x)
+
+ def set_unicode(self, enabled_flag):
+ """Windows: if 'enabled_flag' is True, enable the UNICODE and
+ _UNICODE defines in C, and declare the types like TCHAR and LPTCSTR
+ to be (pointers to) wchar_t. If 'enabled_flag' is False,
+ declare these types to be (pointers to) plain 8-bit characters.
+ This is mostly for backward compatibility; you usually want True.
+ """
+ if self._windows_unicode is not None:
+ raise ValueError("set_unicode() can only be called once")
+ enabled_flag = bool(enabled_flag)
+ if enabled_flag:
+ self.cdef("typedef wchar_t TBYTE;"
+ "typedef wchar_t TCHAR;"
+ "typedef const wchar_t *LPCTSTR;"
+ "typedef const wchar_t *PCTSTR;"
+ "typedef wchar_t *LPTSTR;"
+ "typedef wchar_t *PTSTR;"
+ "typedef TBYTE *PTBYTE;"
+ "typedef TCHAR *PTCHAR;")
+ else:
+ self.cdef("typedef char TBYTE;"
+ "typedef char TCHAR;"
+ "typedef const char *LPCTSTR;"
+ "typedef const char *PCTSTR;"
+ "typedef char *LPTSTR;"
+ "typedef char *PTSTR;"
+ "typedef TBYTE *PTBYTE;"
+ "typedef TCHAR *PTCHAR;")
+ self._windows_unicode = enabled_flag
+
+ def _apply_windows_unicode(self, kwds):
+ defmacros = kwds.get('define_macros', ())
+ if not isinstance(defmacros, (list, tuple)):
+ raise TypeError("'define_macros' must be a list or tuple")
+ defmacros = list(defmacros) + [('UNICODE', '1'),
+ ('_UNICODE', '1')]
+ kwds['define_macros'] = defmacros
+
+ def _apply_embedding_fix(self, kwds):
+ # must include an argument like "-lpython2.7" for the compiler
+ def ensure(key, value):
+ lst = kwds.setdefault(key, [])
+ if value not in lst:
+ lst.append(value)
+ #
+ if '__pypy__' in sys.builtin_module_names:
+ import os
+ if sys.platform == "win32":
+ # we need 'libpypy-c.lib'. Current distributions of
+ # pypy (>= 4.1) contain it as 'libs/python27.lib'.
+ pythonlib = "python{0[0]}{0[1]}".format(sys.version_info)
+ if hasattr(sys, 'prefix'):
+ ensure('library_dirs', os.path.join(sys.prefix, 'libs'))
+ else:
+ # we need 'libpypy-c.{so,dylib}', which should be by
+ # default located in 'sys.prefix/bin' for installed
+ # systems.
+ if sys.version_info < (3,):
+ pythonlib = "pypy-c"
+ else:
+ pythonlib = "pypy3-c"
+ if hasattr(sys, 'prefix'):
+ ensure('library_dirs', os.path.join(sys.prefix, 'bin'))
+ # On uninstalled pypy's, the libpypy-c is typically found in
+ # .../pypy/goal/.
+ if hasattr(sys, 'prefix'):
+ ensure('library_dirs', os.path.join(sys.prefix, 'pypy', 'goal'))
+ else:
+ if sys.platform == "win32":
+ template = "python%d%d"
+ if hasattr(sys, 'gettotalrefcount'):
+ template += '_d'
+ else:
+ try:
+ import sysconfig
+ except ImportError: # 2.6
+ from distutils import sysconfig
+ template = "python%d.%d"
+ if sysconfig.get_config_var('DEBUG_EXT'):
+ template += sysconfig.get_config_var('DEBUG_EXT')
+ pythonlib = (template %
+ (sys.hexversion >> 24, (sys.hexversion >> 16) & 0xff))
+ if hasattr(sys, 'abiflags'):
+ pythonlib += sys.abiflags
+ ensure('libraries', pythonlib)
+ if sys.platform == "win32":
+ ensure('extra_link_args', '/MANIFEST')
+
+ def set_source(self, module_name, source, source_extension='.c', **kwds):
+ import os
+ if hasattr(self, '_assigned_source'):
+ raise ValueError("set_source() cannot be called several times "
+ "per ffi object")
+ if not isinstance(module_name, basestring):
+ raise TypeError("'module_name' must be a string")
+ if os.sep in module_name or (os.altsep and os.altsep in module_name):
+ raise ValueError("'module_name' must not contain '/': use a dotted "
+ "name to make a 'package.module' location")
+ self._assigned_source = (str(module_name), source,
+ source_extension, kwds)
+
+ def set_source_pkgconfig(self, module_name, pkgconfig_libs, source,
+ source_extension='.c', **kwds):
+ from . import pkgconfig
+ if not isinstance(pkgconfig_libs, list):
+ raise TypeError("the pkgconfig_libs argument must be a list "
+ "of package names")
+ kwds2 = pkgconfig.flags_from_pkgconfig(pkgconfig_libs)
+ pkgconfig.merge_flags(kwds, kwds2)
+ self.set_source(module_name, source, source_extension, **kwds)
+
+ def distutils_extension(self, tmpdir='build', verbose=True):
+ from distutils.dir_util import mkpath
+ from .recompiler import recompile
+ #
+ if not hasattr(self, '_assigned_source'):
+ if hasattr(self, 'verifier'): # fallback, 'tmpdir' ignored
+ return self.verifier.get_extension()
+ raise ValueError("set_source() must be called before"
+ " distutils_extension()")
+ module_name, source, source_extension, kwds = self._assigned_source
+ if source is None:
+ raise TypeError("distutils_extension() is only for C extension "
+ "modules, not for dlopen()-style pure Python "
+ "modules")
+ mkpath(tmpdir)
+ ext, updated = recompile(self, module_name,
+ source, tmpdir=tmpdir, extradir=tmpdir,
+ source_extension=source_extension,
+ call_c_compiler=False, **kwds)
+ if verbose:
+ if updated:
+ sys.stderr.write("regenerated: %r\n" % (ext.sources[0],))
+ else:
+ sys.stderr.write("not modified: %r\n" % (ext.sources[0],))
+ return ext
+
+ def emit_c_code(self, filename):
+ from .recompiler import recompile
+ #
+ if not hasattr(self, '_assigned_source'):
+ raise ValueError("set_source() must be called before emit_c_code()")
+ module_name, source, source_extension, kwds = self._assigned_source
+ if source is None:
+ raise TypeError("emit_c_code() is only for C extension modules, "
+ "not for dlopen()-style pure Python modules")
+ recompile(self, module_name, source,
+ c_file=filename, call_c_compiler=False, **kwds)
+
+ def emit_python_code(self, filename):
+ from .recompiler import recompile
+ #
+ if not hasattr(self, '_assigned_source'):
+ raise ValueError("set_source() must be called before emit_c_code()")
+ module_name, source, source_extension, kwds = self._assigned_source
+ if source is not None:
+ raise TypeError("emit_python_code() is only for dlopen()-style "
+ "pure Python modules, not for C extension modules")
+ recompile(self, module_name, source,
+ c_file=filename, call_c_compiler=False, **kwds)
+
+ def compile(self, tmpdir='.', verbose=0, target=None, debug=None):
+ """The 'target' argument gives the final file name of the
+ compiled DLL. Use '*' to force distutils' choice, suitable for
+ regular CPython C API modules. Use a file name ending in '.*'
+ to ask for the system's default extension for dynamic libraries
+ (.so/.dll/.dylib).
+
+ The default is '*' when building a non-embedded C API extension,
+ and (module_name + '.*') when building an embedded library.
+ """
+ from .recompiler import recompile
+ #
+ if not hasattr(self, '_assigned_source'):
+ raise ValueError("set_source() must be called before compile()")
+ module_name, source, source_extension, kwds = self._assigned_source
+ return recompile(self, module_name, source, tmpdir=tmpdir,
+ target=target, source_extension=source_extension,
+ compiler_verbose=verbose, debug=debug, **kwds)
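+ # API-mode build sketch ("_example" is an example module name):
+ #   ffi.cdef("double sqrt(double x);")
+ #   ffi.set_source("_example", '#include <math.h>', libraries=['m'])
+ #   ffi.compile(verbose=True)          # writes the C source and builds the extension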
+
+ def init_once(self, func, tag):
+ # Read _init_once_cache[tag], which is either (False, lock) if
+ # we're calling the function now in some thread, or (True, result).
+ # Don't call setdefault() in most cases, to avoid allocating and
+ # immediately freeing a lock; but still use setdefault() to avoid
+ # races.
+ try:
+ x = self._init_once_cache[tag]
+ except KeyError:
+ x = self._init_once_cache.setdefault(tag, (False, allocate_lock()))
+ # Common case: we got (True, result), so we return the result.
+ if x[0]:
+ return x[1]
+ # Else, it's a lock. Acquire it to serialize the following tests.
+ with x[1]:
+ # Read again from _init_once_cache the current status.
+ x = self._init_once_cache[tag]
+ if x[0]:
+ return x[1]
+ # Call the function and store the result back.
+ result = func()
+ self._init_once_cache[tag] = (True, result)
+ return result
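+ # Usage sketch ('lib.init_library' stands for a hypothetical one-time C initializer):
+ #   def _setup():
+ #       return lib.init_library()
+ #   handle = ffi.init_once(_setup, "init")   # _setup() runs at most once per tag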
+
+ def embedding_init_code(self, pysource):
+ if self._embedding:
+ raise ValueError("embedding_init_code() can only be called once")
+ # fix 'pysource' before it gets dumped into the C file:
+ # - remove empty lines at the beginning, so it starts at "line 1"
+ # - dedent, if all non-empty lines are indented
+ # - check for SyntaxErrors
+ import re
+ match = re.match(r'\s*\n', pysource)
+ if match:
+ pysource = pysource[match.end():]
+ lines = pysource.splitlines() or ['']
+ prefix = re.match(r'\s*', lines[0]).group()
+ for i in range(1, len(lines)):
+ line = lines[i]
+ if line.rstrip():
+ while not line.startswith(prefix):
+ prefix = prefix[:-1]
+ i = len(prefix)
+ lines = [line[i:]+'\n' for line in lines]
+ pysource = ''.join(lines)
+ #
+ compile(pysource, "cffi_init", "exec")
+ #
+ self._embedding = pysource
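+ # Embedding sketch ("_example" is an example module name matching set_source();
+ # the exported prototypes must be declared with embedding_api()):
+ #   ffi.embedding_init_code("""
+ #       from _example import ffi
+ #       @ffi.def_extern()
+ #       def do_stuff(x):
+ #           return x + 1
+ #   """)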
+
+ def def_extern(self, *args, **kwds):
+ raise ValueError("ffi.def_extern() is only available on API-mode FFI "
+ "objects")
+
+ def list_types(self):
+ """Returns the user type names known to this FFI instance.
+ This returns a tuple containing three lists of names:
+ (typedef_names, names_of_structs, names_of_unions)
+ """
+ typedefs = []
+ structs = []
+ unions = []
+ for key in self._parser._declarations:
+ if key.startswith('typedef '):
+ typedefs.append(key[8:])
+ elif key.startswith('struct '):
+ structs.append(key[7:])
+ elif key.startswith('union '):
+ unions.append(key[6:])
+ typedefs.sort()
+ structs.sort()
+ unions.sort()
+ return (typedefs, structs, unions)
+
+
+def _load_backend_lib(backend, name, flags):
+ import os
+ if not isinstance(name, basestring):
+ if sys.platform != "win32" or name is not None:
+ return backend.load_library(name, flags)
+ name = "c" # Windows: load_library(None) fails, but this works
+ # on Python 2 (backward compatibility hack only)
+ first_error = None
+ if '.' in name or '/' in name or os.sep in name:
+ try:
+ return backend.load_library(name, flags)
+ except OSError as e:
+ first_error = e
+ import ctypes.util
+ path = ctypes.util.find_library(name)
+ if path is None:
+ if name == "c" and sys.platform == "win32" and sys.version_info >= (3,):
+ raise OSError("dlopen(None) cannot work on Windows for Python 3 "
+ "(see http://bugs.python.org/issue23606)")
+ msg = ("ctypes.util.find_library() did not manage "
+ "to locate a library called %r" % (name,))
+ if first_error is not None:
+ msg = "%s. Additionally, %s" % (first_error, msg)
+ raise OSError(msg)
+ return backend.load_library(path, flags)
+
+def _make_ffi_library(ffi, libname, flags):
+ backend = ffi._backend
+ backendlib = _load_backend_lib(backend, libname, flags)
+ #
+ def accessor_function(name):
+ key = 'function ' + name
+ tp, _ = ffi._parser._declarations[key]
+ BType = ffi._get_cached_btype(tp)
+ value = backendlib.load_function(BType, name)
+ library.__dict__[name] = value
+ #
+ def accessor_variable(name):
+ key = 'variable ' + name
+ tp, _ = ffi._parser._declarations[key]
+ BType = ffi._get_cached_btype(tp)
+ read_variable = backendlib.read_variable
+ write_variable = backendlib.write_variable
+ setattr(FFILibrary, name, property(
+ lambda self: read_variable(BType, name),
+ lambda self, value: write_variable(BType, name, value)))
+ #
+ def addressof_var(name):
+ try:
+ return addr_variables[name]
+ except KeyError:
+ with ffi._lock:
+ if name not in addr_variables:
+ key = 'variable ' + name
+ tp, _ = ffi._parser._declarations[key]
+ BType = ffi._get_cached_btype(tp)
+ if BType.kind != 'array':
+ BType = model.pointer_cache(ffi, BType)
+ p = backendlib.load_function(BType, name)
+ addr_variables[name] = p
+ return addr_variables[name]
+ #
+ def accessor_constant(name):
+ raise NotImplementedError("non-integer constant '%s' cannot be "
+ "accessed from a dlopen() library" % (name,))
+ #
+ def accessor_int_constant(name):
+ library.__dict__[name] = ffi._parser._int_constants[name]
+ #
+ accessors = {}
+ accessors_version = [False]
+ addr_variables = {}
+ #
+ def update_accessors():
+ if accessors_version[0] is ffi._cdef_version:
+ return
+ #
+ for key, (tp, _) in ffi._parser._declarations.items():
+ if not isinstance(tp, model.EnumType):
+ tag, name = key.split(' ', 1)
+ if tag == 'function':
+ accessors[name] = accessor_function
+ elif tag == 'variable':
+ accessors[name] = accessor_variable
+ elif tag == 'constant':
+ accessors[name] = accessor_constant
+ else:
+ for i, enumname in enumerate(tp.enumerators):
+ def accessor_enum(name, tp=tp, i=i):
+ tp.check_not_partial()
+ library.__dict__[name] = tp.enumvalues[i]
+ accessors[enumname] = accessor_enum
+ for name in ffi._parser._int_constants:
+ accessors.setdefault(name, accessor_int_constant)
+ accessors_version[0] = ffi._cdef_version
+ #
+ def make_accessor(name):
+ with ffi._lock:
+ if name in library.__dict__ or name in FFILibrary.__dict__:
+ return # added by another thread while waiting for the lock
+ if name not in accessors:
+ update_accessors()
+ if name not in accessors:
+ raise AttributeError(name)
+ accessors[name](name)
+ #
+ class FFILibrary(object):
+ def __getattr__(self, name):
+ make_accessor(name)
+ return getattr(self, name)
+ def __setattr__(self, name, value):
+ try:
+ property = getattr(self.__class__, name)
+ except AttributeError:
+ make_accessor(name)
+ setattr(self, name, value)
+ else:
+ property.__set__(self, value)
+ def __dir__(self):
+ with ffi._lock:
+ update_accessors()
+ return accessors.keys()
+ def __addressof__(self, name):
+ if name in library.__dict__:
+ return library.__dict__[name]
+ if name in FFILibrary.__dict__:
+ return addressof_var(name)
+ make_accessor(name)
+ if name in library.__dict__:
+ return library.__dict__[name]
+ if name in FFILibrary.__dict__:
+ return addressof_var(name)
+ raise AttributeError("cffi library has no function or "
+ "global variable named '%s'" % (name,))
+ def __cffi_close__(self):
+ backendlib.close_lib()
+ self.__dict__.clear()
+ #
+ if isinstance(libname, basestring):
+ try:
+ if not isinstance(libname, str): # unicode, on Python 2
+ libname = libname.encode('utf-8')
+ FFILibrary.__name__ = 'FFILibrary_%s' % libname
+ except UnicodeError:
+ pass
+ library = FFILibrary()
+ return library, library.__dict__
+
+def _builtin_function_type(func):
+ # a hack to make at least ffi.typeof(builtin_function) work,
+ # if the builtin function was obtained by 'vengine_cpy'.
+ import sys
+ try:
+ module = sys.modules[func.__module__]
+ ffi = module._cffi_original_ffi
+ types_of_builtin_funcs = module._cffi_types_of_builtin_funcs
+ tp = types_of_builtin_funcs[func]
+ except (KeyError, AttributeError, TypeError):
+ return None
+ else:
+ with ffi._lock:
+ return ffi._get_cached_btype(tp)
diff --git a/lib/cffi/backend_ctypes.py b/lib/cffi/backend_ctypes.py
new file mode 100644
index 0000000..e7956a7
--- /dev/null
+++ b/lib/cffi/backend_ctypes.py
@@ -0,0 +1,1121 @@
+import ctypes, ctypes.util, operator, sys
+from . import model
+
+if sys.version_info < (3,):
+ bytechr = chr
+else:
+ unicode = str
+ long = int
+ xrange = range
+ bytechr = lambda num: bytes([num])
+
+class CTypesType(type):
+ pass
+
+class CTypesData(object):
+ __metaclass__ = CTypesType
+ __slots__ = ['__weakref__']
+ __name__ = '<cdata>'
+
+ def __init__(self, *args):
+ raise TypeError("cannot instantiate %r" % (self.__class__,))
+
+ @classmethod
+ def _newp(cls, init):
+ raise TypeError("expected a pointer or array ctype, got '%s'"
+ % (cls._get_c_name(),))
+
+ @staticmethod
+ def _to_ctypes(value):
+ raise TypeError
+
+ @classmethod
+ def _arg_to_ctypes(cls, *value):
+ try:
+ ctype = cls._ctype
+ except AttributeError:
+ raise TypeError("cannot create an instance of %r" % (cls,))
+ if value:
+ res = cls._to_ctypes(*value)
+ if not isinstance(res, ctype):
+ res = cls._ctype(res)
+ else:
+ res = cls._ctype()
+ return res
+
+ @classmethod
+ def _create_ctype_obj(cls, init):
+ if init is None:
+ return cls._arg_to_ctypes()
+ else:
+ return cls._arg_to_ctypes(init)
+
+ @staticmethod
+ def _from_ctypes(ctypes_value):
+ raise TypeError
+
+ @classmethod
+ def _get_c_name(cls, replace_with=''):
+ return cls._reftypename.replace(' &', replace_with)
+
+ @classmethod
+ def _fix_class(cls):
+ cls.__name__ = 'CData<%s>' % (cls._get_c_name(),)
+ cls.__qualname__ = 'CData<%s>' % (cls._get_c_name(),)
+ cls.__module__ = 'ffi'
+
+ def _get_own_repr(self):
+ raise NotImplementedError
+
+ def _addr_repr(self, address):
+ if address == 0:
+ return 'NULL'
+ else:
+ if address < 0:
+ address += 1 << (8*ctypes.sizeof(ctypes.c_void_p))
+ return '0x%x' % address
+
+ def __repr__(self, c_name=None):
+ own = self._get_own_repr()
+ return '<cdata %r %s>' % (c_name or self._get_c_name(), own)
+
+ def _convert_to_address(self, BClass):
+ if BClass is None:
+ raise TypeError("cannot convert %r to an address" % (
+ self._get_c_name(),))
+ else:
+ raise TypeError("cannot convert %r to %r" % (
+ self._get_c_name(), BClass._get_c_name()))
+
+ @classmethod
+ def _get_size(cls):
+ return ctypes.sizeof(cls._ctype)
+
+ def _get_size_of_instance(self):
+ return ctypes.sizeof(self._ctype)
+
+ @classmethod
+ def _cast_from(cls, source):
+ raise TypeError("cannot cast to %r" % (cls._get_c_name(),))
+
+ def _cast_to_integer(self):
+ return self._convert_to_address(None)
+
+ @classmethod
+ def _alignment(cls):
+ return ctypes.alignment(cls._ctype)
+
+ def __iter__(self):
+ raise TypeError("cdata %r does not support iteration" % (
+ self._get_c_name()),)
+
+ def _make_cmp(name):
+ cmpfunc = getattr(operator, name)
+ def cmp(self, other):
+ v_is_ptr = not isinstance(self, CTypesGenericPrimitive)
+ w_is_ptr = (isinstance(other, CTypesData) and
+ not isinstance(other, CTypesGenericPrimitive))
+ if v_is_ptr and w_is_ptr:
+ return cmpfunc(self._convert_to_address(None),
+ other._convert_to_address(None))
+ elif v_is_ptr or w_is_ptr:
+ return NotImplemented
+ else:
+ if isinstance(self, CTypesGenericPrimitive):
+ self = self._value
+ if isinstance(other, CTypesGenericPrimitive):
+ other = other._value
+ return cmpfunc(self, other)
+ cmp.func_name = name
+ return cmp
+
+ __eq__ = _make_cmp('__eq__')
+ __ne__ = _make_cmp('__ne__')
+ __lt__ = _make_cmp('__lt__')
+ __le__ = _make_cmp('__le__')
+ __gt__ = _make_cmp('__gt__')
+ __ge__ = _make_cmp('__ge__')
+
+ def __hash__(self):
+ return hash(self._convert_to_address(None))
+
+ def _to_string(self, maxlen):
+ raise TypeError("string(): %r" % (self,))
+
+
+class CTypesGenericPrimitive(CTypesData):
+ __slots__ = []
+
+ def __hash__(self):
+ return hash(self._value)
+
+ def _get_own_repr(self):
+ return repr(self._from_ctypes(self._value))
+
+
+class CTypesGenericArray(CTypesData):
+ __slots__ = []
+
+ @classmethod
+ def _newp(cls, init):
+ return cls(init)
+
+ def __iter__(self):
+ for i in xrange(len(self)):
+ yield self[i]
+
+ def _get_own_repr(self):
+ return self._addr_repr(ctypes.addressof(self._blob))
+
+
+class CTypesGenericPtr(CTypesData):
+ __slots__ = ['_address', '_as_ctype_ptr']
+ _automatic_casts = False
+ kind = "pointer"
+
+ @classmethod
+ def _newp(cls, init):
+ return cls(init)
+
+ @classmethod
+ def _cast_from(cls, source):
+ if source is None:
+ address = 0
+ elif isinstance(source, CTypesData):
+ address = source._cast_to_integer()
+ elif isinstance(source, (int, long)):
+ address = source
+ else:
+ raise TypeError("bad type for cast to %r: %r" %
+ (cls, type(source).__name__))
+ return cls._new_pointer_at(address)
+
+ @classmethod
+ def _new_pointer_at(cls, address):
+ self = cls.__new__(cls)
+ self._address = address
+ self._as_ctype_ptr = ctypes.cast(address, cls._ctype)
+ return self
+
+ def _get_own_repr(self):
+ try:
+ return self._addr_repr(self._address)
+ except AttributeError:
+ return '???'
+
+ def _cast_to_integer(self):
+ return self._address
+
+ def __nonzero__(self):
+ return bool(self._address)
+ __bool__ = __nonzero__
+
+ @classmethod
+ def _to_ctypes(cls, value):
+ if not isinstance(value, CTypesData):
+ raise TypeError("unexpected %s object" % type(value).__name__)
+ address = value._convert_to_address(cls)
+ return ctypes.cast(address, cls._ctype)
+
+ @classmethod
+ def _from_ctypes(cls, ctypes_ptr):
+ address = ctypes.cast(ctypes_ptr, ctypes.c_void_p).value or 0
+ return cls._new_pointer_at(address)
+
+ @classmethod
+ def _initialize(cls, ctypes_ptr, value):
+ if value:
+ ctypes_ptr.contents = cls._to_ctypes(value).contents
+
+ def _convert_to_address(self, BClass):
+ if (BClass in (self.__class__, None) or BClass._automatic_casts
+ or self._automatic_casts):
+ return self._address
+ else:
+ return CTypesData._convert_to_address(self, BClass)
+
+
+class CTypesBaseStructOrUnion(CTypesData):
+ __slots__ = ['_blob']
+
+ @classmethod
+ def _create_ctype_obj(cls, init):
+ # may be overridden
+ raise TypeError("cannot instantiate opaque type %s" % (cls,))
+
+ def _get_own_repr(self):
+ return self._addr_repr(ctypes.addressof(self._blob))
+
+ @classmethod
+ def _offsetof(cls, fieldname):
+ return getattr(cls._ctype, fieldname).offset
+
+ def _convert_to_address(self, BClass):
+ if getattr(BClass, '_BItem', None) is self.__class__:
+ return ctypes.addressof(self._blob)
+ else:
+ return CTypesData._convert_to_address(self, BClass)
+
+ @classmethod
+ def _from_ctypes(cls, ctypes_struct_or_union):
+ self = cls.__new__(cls)
+ self._blob = ctypes_struct_or_union
+ return self
+
+ @classmethod
+ def _to_ctypes(cls, value):
+ return value._blob
+
+ def __repr__(self, c_name=None):
+ return CTypesData.__repr__(self, c_name or self._get_c_name(' &'))
+
+
+class CTypesBackend(object):
+
+ PRIMITIVE_TYPES = {
+ 'char': ctypes.c_char,
+ 'short': ctypes.c_short,
+ 'int': ctypes.c_int,
+ 'long': ctypes.c_long,
+ 'long long': ctypes.c_longlong,
+ 'signed char': ctypes.c_byte,
+ 'unsigned char': ctypes.c_ubyte,
+ 'unsigned short': ctypes.c_ushort,
+ 'unsigned int': ctypes.c_uint,
+ 'unsigned long': ctypes.c_ulong,
+ 'unsigned long long': ctypes.c_ulonglong,
+ 'float': ctypes.c_float,
+ 'double': ctypes.c_double,
+ '_Bool': ctypes.c_bool,
+ }
+
+ for _name in ['unsigned long long', 'unsigned long',
+ 'unsigned int', 'unsigned short', 'unsigned char']:
+ _size = ctypes.sizeof(PRIMITIVE_TYPES[_name])
+ PRIMITIVE_TYPES['uint%d_t' % (8*_size)] = PRIMITIVE_TYPES[_name]
+ if _size == ctypes.sizeof(ctypes.c_void_p):
+ PRIMITIVE_TYPES['uintptr_t'] = PRIMITIVE_TYPES[_name]
+ if _size == ctypes.sizeof(ctypes.c_size_t):
+ PRIMITIVE_TYPES['size_t'] = PRIMITIVE_TYPES[_name]
+
+ for _name in ['long long', 'long', 'int', 'short', 'signed char']:
+ _size = ctypes.sizeof(PRIMITIVE_TYPES[_name])
+ PRIMITIVE_TYPES['int%d_t' % (8*_size)] = PRIMITIVE_TYPES[_name]
+ if _size == ctypes.sizeof(ctypes.c_void_p):
+ PRIMITIVE_TYPES['intptr_t'] = PRIMITIVE_TYPES[_name]
+ PRIMITIVE_TYPES['ptrdiff_t'] = PRIMITIVE_TYPES[_name]
+ if _size == ctypes.sizeof(ctypes.c_size_t):
+ PRIMITIVE_TYPES['ssize_t'] = PRIMITIVE_TYPES[_name]
+
+
+ def __init__(self):
+ self.RTLD_LAZY = 0 # not supported anyway by ctypes
+ self.RTLD_NOW = 0
+ self.RTLD_GLOBAL = ctypes.RTLD_GLOBAL
+ self.RTLD_LOCAL = ctypes.RTLD_LOCAL
+
+ def set_ffi(self, ffi):
+ self.ffi = ffi
+
+ def _get_types(self):
+ return CTypesData, CTypesType
+
+ def load_library(self, path, flags=0):
+ cdll = ctypes.CDLL(path, flags)
+ return CTypesLibrary(self, cdll)
+
+ def new_void_type(self):
+ class CTypesVoid(CTypesData):
+ __slots__ = []
+ _reftypename = 'void &'
+ @staticmethod
+ def _from_ctypes(novalue):
+ return None
+ @staticmethod
+ def _to_ctypes(novalue):
+ if novalue is not None:
+ raise TypeError("None expected, got %s object" %
+ (type(novalue).__name__,))
+ return None
+ CTypesVoid._fix_class()
+ return CTypesVoid
+
+ def new_primitive_type(self, name):
+ if name == 'wchar_t':
+ raise NotImplementedError(name)
+ ctype = self.PRIMITIVE_TYPES[name]
+ if name == 'char':
+ kind = 'char'
+ elif name in ('float', 'double'):
+ kind = 'float'
+ else:
+ if name in ('signed char', 'unsigned char'):
+ kind = 'byte'
+ elif name == '_Bool':
+ kind = 'bool'
+ else:
+ kind = 'int'
+ is_signed = (ctype(-1).value == -1)
+ #
+ def _cast_source_to_int(source):
+ if isinstance(source, (int, long, float)):
+ source = int(source)
+ elif isinstance(source, CTypesData):
+ source = source._cast_to_integer()
+ elif isinstance(source, bytes):
+ source = ord(source)
+ elif source is None:
+ source = 0
+ else:
+ raise TypeError("bad type for cast to %r: %r" %
+ (CTypesPrimitive, type(source).__name__))
+ return source
+ #
+ kind1 = kind
+ class CTypesPrimitive(CTypesGenericPrimitive):
+ __slots__ = ['_value']
+ _ctype = ctype
+ _reftypename = '%s &' % name
+ kind = kind1
+
+ def __init__(self, value):
+ self._value = value
+
+ @staticmethod
+ def _create_ctype_obj(init):
+ if init is None:
+ return ctype()
+ return ctype(CTypesPrimitive._to_ctypes(init))
+
+ if kind == 'int' or kind == 'byte':
+ @classmethod
+ def _cast_from(cls, source):
+ source = _cast_source_to_int(source)
+ source = ctype(source).value # cast within range
+ return cls(source)
+ def __int__(self):
+ return self._value
+
+ if kind == 'bool':
+ @classmethod
+ def _cast_from(cls, source):
+ if not isinstance(source, (int, long, float)):
+ source = _cast_source_to_int(source)
+ return cls(bool(source))
+ def __int__(self):
+ return int(self._value)
+
+ if kind == 'char':
+ @classmethod
+ def _cast_from(cls, source):
+ source = _cast_source_to_int(source)
+ source = bytechr(source & 0xFF)
+ return cls(source)
+ def __int__(self):
+ return ord(self._value)
+
+ if kind == 'float':
+ @classmethod
+ def _cast_from(cls, source):
+ if isinstance(source, float):
+ pass
+ elif isinstance(source, CTypesGenericPrimitive):
+ if hasattr(source, '__float__'):
+ source = float(source)
+ else:
+ source = int(source)
+ else:
+ source = _cast_source_to_int(source)
+ source = ctype(source).value # fix precision
+ return cls(source)
+ def __int__(self):
+ return int(self._value)
+ def __float__(self):
+ return self._value
+
+ _cast_to_integer = __int__
+
+ if kind == 'int' or kind == 'byte' or kind == 'bool':
+ @staticmethod
+ def _to_ctypes(x):
+ if not isinstance(x, (int, long)):
+ if isinstance(x, CTypesData):
+ x = int(x)
+ else:
+ raise TypeError("integer expected, got %s" %
+ type(x).__name__)
+ if ctype(x).value != x:
+ if not is_signed and x < 0:
+ raise OverflowError("%s: negative integer" % name)
+ else:
+ raise OverflowError("%s: integer out of bounds"
+ % name)
+ return x
+
+ if kind == 'char':
+ @staticmethod
+ def _to_ctypes(x):
+ if isinstance(x, bytes) and len(x) == 1:
+ return x
+ if isinstance(x, CTypesPrimitive): # <CData <char>>
+ return x._value
+ raise TypeError("character expected, got %s" %
+ type(x).__name__)
+ def __nonzero__(self):
+ return ord(self._value) != 0
+ else:
+ def __nonzero__(self):
+ return self._value != 0
+ __bool__ = __nonzero__
+
+ if kind == 'float':
+ @staticmethod
+ def _to_ctypes(x):
+ if not isinstance(x, (int, long, float, CTypesData)):
+ raise TypeError("float expected, got %s" %
+ type(x).__name__)
+ return ctype(x).value
+
+ @staticmethod
+ def _from_ctypes(value):
+ return getattr(value, 'value', value)
+
+ @staticmethod
+ def _initialize(blob, init):
+ blob.value = CTypesPrimitive._to_ctypes(init)
+
+ if kind == 'char':
+ def _to_string(self, maxlen):
+ return self._value
+ if kind == 'byte':
+ def _to_string(self, maxlen):
+ return chr(self._value & 0xff)
+ #
+ CTypesPrimitive._fix_class()
+ return CTypesPrimitive
+
+ def new_pointer_type(self, BItem):
+ getbtype = self.ffi._get_cached_btype
+ if BItem is getbtype(model.PrimitiveType('char')):
+ kind = 'charp'
+ elif BItem in (getbtype(model.PrimitiveType('signed char')),
+ getbtype(model.PrimitiveType('unsigned char'))):
+ kind = 'bytep'
+ elif BItem is getbtype(model.void_type):
+ kind = 'voidp'
+ else:
+ kind = 'generic'
+ #
+ class CTypesPtr(CTypesGenericPtr):
+ __slots__ = ['_own']
+ if kind == 'charp':
+ __slots__ += ['__as_strbuf']
+ _BItem = BItem
+ if hasattr(BItem, '_ctype'):
+ _ctype = ctypes.POINTER(BItem._ctype)
+ _bitem_size = ctypes.sizeof(BItem._ctype)
+ else:
+ _ctype = ctypes.c_void_p
+ if issubclass(BItem, CTypesGenericArray):
+ _reftypename = BItem._get_c_name('(* &)')
+ else:
+ _reftypename = BItem._get_c_name(' * &')
+
+ def __init__(self, init):
+ ctypeobj = BItem._create_ctype_obj(init)
+ if kind == 'charp':
+ self.__as_strbuf = ctypes.create_string_buffer(
+ ctypeobj.value + b'\x00')
+ self._as_ctype_ptr = ctypes.cast(
+ self.__as_strbuf, self._ctype)
+ else:
+ self._as_ctype_ptr = ctypes.pointer(ctypeobj)
+ self._address = ctypes.cast(self._as_ctype_ptr,
+ ctypes.c_void_p).value
+ self._own = True
+
+ def __add__(self, other):
+ if isinstance(other, (int, long)):
+ return self._new_pointer_at(self._address +
+ other * self._bitem_size)
+ else:
+ return NotImplemented
+
+ def __sub__(self, other):
+ if isinstance(other, (int, long)):
+ return self._new_pointer_at(self._address -
+ other * self._bitem_size)
+ elif type(self) is type(other):
+ return (self._address - other._address) // self._bitem_size
+ else:
+ return NotImplemented
+
+ def __getitem__(self, index):
+ if getattr(self, '_own', False) and index != 0:
+ raise IndexError
+ return BItem._from_ctypes(self._as_ctype_ptr[index])
+
+ def __setitem__(self, index, value):
+ self._as_ctype_ptr[index] = BItem._to_ctypes(value)
+
+ if kind == 'charp' or kind == 'voidp':
+ @classmethod
+ def _arg_to_ctypes(cls, *value):
+ if value and isinstance(value[0], bytes):
+ return ctypes.c_char_p(value[0])
+ else:
+ return super(CTypesPtr, cls)._arg_to_ctypes(*value)
+
+ if kind == 'charp' or kind == 'bytep':
+ def _to_string(self, maxlen):
+ if maxlen < 0:
+ maxlen = sys.maxsize
+ p = ctypes.cast(self._as_ctype_ptr,
+ ctypes.POINTER(ctypes.c_char))
+ n = 0
+ while n < maxlen and p[n] != b'\x00':
+ n += 1
+ return b''.join([p[i] for i in range(n)])
+
+ def _get_own_repr(self):
+ if getattr(self, '_own', False):
+ return 'owning %d bytes' % (
+ ctypes.sizeof(self._as_ctype_ptr.contents),)
+ return super(CTypesPtr, self)._get_own_repr()
+ #
+ if (BItem is self.ffi._get_cached_btype(model.void_type) or
+ BItem is self.ffi._get_cached_btype(model.PrimitiveType('char'))):
+ CTypesPtr._automatic_casts = True
+ #
+ CTypesPtr._fix_class()
+ return CTypesPtr
+
+ def new_array_type(self, CTypesPtr, length):
+ if length is None:
+ brackets = ' &[]'
+ else:
+ brackets = ' &[%d]' % length
+ BItem = CTypesPtr._BItem
+ getbtype = self.ffi._get_cached_btype
+ if BItem is getbtype(model.PrimitiveType('char')):
+ kind = 'char'
+ elif BItem in (getbtype(model.PrimitiveType('signed char')),
+ getbtype(model.PrimitiveType('unsigned char'))):
+ kind = 'byte'
+ else:
+ kind = 'generic'
+ #
+ class CTypesArray(CTypesGenericArray):
+ __slots__ = ['_blob', '_own']
+ if length is not None:
+ _ctype = BItem._ctype * length
+ else:
+ __slots__.append('_ctype')
+ _reftypename = BItem._get_c_name(brackets)
+ _declared_length = length
+ _CTPtr = CTypesPtr
+
+ def __init__(self, init):
+ if length is None:
+ if isinstance(init, (int, long)):
+ len1 = init
+ init = None
+ elif kind == 'char' and isinstance(init, bytes):
+ len1 = len(init) + 1 # extra null
+ else:
+ init = tuple(init)
+ len1 = len(init)
+ self._ctype = BItem._ctype * len1
+ self._blob = self._ctype()
+ self._own = True
+ if init is not None:
+ self._initialize(self._blob, init)
+
+ @staticmethod
+ def _initialize(blob, init):
+ if isinstance(init, bytes):
+ init = [init[i:i+1] for i in range(len(init))]
+ else:
+ if isinstance(init, CTypesGenericArray):
+ if (len(init) != len(blob) or
+ not isinstance(init, CTypesArray)):
+ raise TypeError("length/type mismatch: %s" % (init,))
+ init = tuple(init)
+ if len(init) > len(blob):
+ raise IndexError("too many initializers")
+ addr = ctypes.cast(blob, ctypes.c_void_p).value
+ PTR = ctypes.POINTER(BItem._ctype)
+ itemsize = ctypes.sizeof(BItem._ctype)
+ for i, value in enumerate(init):
+ p = ctypes.cast(addr + i * itemsize, PTR)
+ BItem._initialize(p.contents, value)
+
+ def __len__(self):
+ return len(self._blob)
+
+ def __getitem__(self, index):
+ if not (0 <= index < len(self._blob)):
+ raise IndexError
+ return BItem._from_ctypes(self._blob[index])
+
+ def __setitem__(self, index, value):
+ if not (0 <= index < len(self._blob)):
+ raise IndexError
+ self._blob[index] = BItem._to_ctypes(value)
+
+ if kind == 'char' or kind == 'byte':
+ def _to_string(self, maxlen):
+ if maxlen < 0:
+ maxlen = len(self._blob)
+ p = ctypes.cast(self._blob,
+ ctypes.POINTER(ctypes.c_char))
+ n = 0
+ while n < maxlen and p[n] != b'\x00':
+ n += 1
+ return b''.join([p[i] for i in range(n)])
+
+ def _get_own_repr(self):
+ if getattr(self, '_own', False):
+ return 'owning %d bytes' % (ctypes.sizeof(self._blob),)
+ return super(CTypesArray, self)._get_own_repr()
+
+ def _convert_to_address(self, BClass):
+ if BClass in (CTypesPtr, None) or BClass._automatic_casts:
+ return ctypes.addressof(self._blob)
+ else:
+ return CTypesData._convert_to_address(self, BClass)
+
+ @staticmethod
+ def _from_ctypes(ctypes_array):
+ self = CTypesArray.__new__(CTypesArray)
+ self._blob = ctypes_array
+ return self
+
+ @staticmethod
+ def _arg_to_ctypes(value):
+ return CTypesPtr._arg_to_ctypes(value)
+
+ def __add__(self, other):
+ if isinstance(other, (int, long)):
+ return CTypesPtr._new_pointer_at(
+ ctypes.addressof(self._blob) +
+ other * ctypes.sizeof(BItem._ctype))
+ else:
+ return NotImplemented
+
+ @classmethod
+ def _cast_from(cls, source):
+ raise NotImplementedError("casting to %r" % (
+ cls._get_c_name(),))
+ #
+ CTypesArray._fix_class()
+ return CTypesArray
+
+ def _new_struct_or_union(self, kind, name, base_ctypes_class):
+ #
+ class struct_or_union(base_ctypes_class):
+ pass
+ struct_or_union.__name__ = '%s_%s' % (kind, name)
+ kind1 = kind
+ #
+ class CTypesStructOrUnion(CTypesBaseStructOrUnion):
+ __slots__ = ['_blob']
+ _ctype = struct_or_union
+ _reftypename = '%s &' % (name,)
+ _kind = kind = kind1
+ #
+ CTypesStructOrUnion._fix_class()
+ return CTypesStructOrUnion
+
+ def new_struct_type(self, name):
+ return self._new_struct_or_union('struct', name, ctypes.Structure)
+
+ def new_union_type(self, name):
+ return self._new_struct_or_union('union', name, ctypes.Union)
+
+ def complete_struct_or_union(self, CTypesStructOrUnion, fields, tp,
+ totalsize=-1, totalalignment=-1, sflags=0,
+ pack=0):
+ if totalsize >= 0 or totalalignment >= 0:
+ raise NotImplementedError("the ctypes backend of CFFI does not support "
+ "structures completed by verify(); please "
+ "compile and install the _cffi_backend module.")
+ struct_or_union = CTypesStructOrUnion._ctype
+ fnames = [fname for (fname, BField, bitsize) in fields]
+ btypes = [BField for (fname, BField, bitsize) in fields]
+ bitfields = [bitsize for (fname, BField, bitsize) in fields]
+ #
+ bfield_types = {}
+ cfields = []
+ for (fname, BField, bitsize) in fields:
+ if bitsize < 0:
+ cfields.append((fname, BField._ctype))
+ bfield_types[fname] = BField
+ else:
+ cfields.append((fname, BField._ctype, bitsize))
+ bfield_types[fname] = Ellipsis
+ if sflags & 8:
+ struct_or_union._pack_ = 1
+ elif pack:
+ struct_or_union._pack_ = pack
+ struct_or_union._fields_ = cfields
+ CTypesStructOrUnion._bfield_types = bfield_types
+ #
+ @staticmethod
+ def _create_ctype_obj(init):
+ result = struct_or_union()
+ if init is not None:
+ initialize(result, init)
+ return result
+ CTypesStructOrUnion._create_ctype_obj = _create_ctype_obj
+ #
+ def initialize(blob, init):
+ if is_union:
+ if len(init) > 1:
+ raise ValueError("union initializer: %d items given, but "
+ "only one supported (use a dict if needed)"
+ % (len(init),))
+ if not isinstance(init, dict):
+ if isinstance(init, (bytes, unicode)):
+ raise TypeError("union initializer: got a str")
+ init = tuple(init)
+ if len(init) > len(fnames):
+ raise ValueError("too many values for %s initializer" %
+ CTypesStructOrUnion._get_c_name())
+ init = dict(zip(fnames, init))
+ addr = ctypes.addressof(blob)
+ for fname, value in init.items():
+ BField, bitsize = name2fieldtype[fname]
+ assert bitsize < 0, \
+ "not implemented: initializer with bit fields"
+ offset = CTypesStructOrUnion._offsetof(fname)
+ PTR = ctypes.POINTER(BField._ctype)
+ p = ctypes.cast(addr + offset, PTR)
+ BField._initialize(p.contents, value)
+ is_union = CTypesStructOrUnion._kind == 'union'
+ name2fieldtype = dict(zip(fnames, zip(btypes, bitfields)))
+ #
+ for fname, BField, bitsize in fields:
+ if fname == '':
+ raise NotImplementedError("nested anonymous structs/unions")
+ if hasattr(CTypesStructOrUnion, fname):
+ raise ValueError("the field name %r conflicts in "
+ "the ctypes backend" % fname)
+ if bitsize < 0:
+ def getter(self, fname=fname, BField=BField,
+ offset=CTypesStructOrUnion._offsetof(fname),
+ PTR=ctypes.POINTER(BField._ctype)):
+ addr = ctypes.addressof(self._blob)
+ p = ctypes.cast(addr + offset, PTR)
+ return BField._from_ctypes(p.contents)
+ def setter(self, value, fname=fname, BField=BField):
+ setattr(self._blob, fname, BField._to_ctypes(value))
+ #
+ if issubclass(BField, CTypesGenericArray):
+ setter = None
+ if BField._declared_length == 0:
+ def getter(self, fname=fname, BFieldPtr=BField._CTPtr,
+ offset=CTypesStructOrUnion._offsetof(fname),
+ PTR=ctypes.POINTER(BField._ctype)):
+ addr = ctypes.addressof(self._blob)
+ p = ctypes.cast(addr + offset, PTR)
+ return BFieldPtr._from_ctypes(p)
+ #
+ else:
+ def getter(self, fname=fname, BField=BField):
+ return BField._from_ctypes(getattr(self._blob, fname))
+ def setter(self, value, fname=fname, BField=BField):
+ # xxx obscure workaround
+ value = BField._to_ctypes(value)
+ oldvalue = getattr(self._blob, fname)
+ setattr(self._blob, fname, value)
+ if value != getattr(self._blob, fname):
+ setattr(self._blob, fname, oldvalue)
+ raise OverflowError("value too large for bitfield")
+ setattr(CTypesStructOrUnion, fname, property(getter, setter))
+ #
+ CTypesPtr = self.ffi._get_cached_btype(model.PointerType(tp))
+ for fname in fnames:
+ if hasattr(CTypesPtr, fname):
+ raise ValueError("the field name %r conflicts in "
+ "the ctypes backend" % fname)
+ def getter(self, fname=fname):
+ return getattr(self[0], fname)
+ def setter(self, value, fname=fname):
+ setattr(self[0], fname, value)
+ setattr(CTypesPtr, fname, property(getter, setter))
+
+ def new_function_type(self, BArgs, BResult, has_varargs):
+ nameargs = [BArg._get_c_name() for BArg in BArgs]
+ if has_varargs:
+ nameargs.append('...')
+ nameargs = ', '.join(nameargs)
+ #
+ class CTypesFunctionPtr(CTypesGenericPtr):
+ __slots__ = ['_own_callback', '_name']
+ _ctype = ctypes.CFUNCTYPE(getattr(BResult, '_ctype', None),
+ *[BArg._ctype for BArg in BArgs],
+ use_errno=True)
+ _reftypename = BResult._get_c_name('(* &)(%s)' % (nameargs,))
+
+ def __init__(self, init, error=None):
+ # create a callback to the Python callable init()
+ import traceback
+ assert not has_varargs, "varargs not supported for callbacks"
+ if getattr(BResult, '_ctype', None) is not None:
+ error = BResult._from_ctypes(
+ BResult._create_ctype_obj(error))
+ else:
+ error = None
+ def callback(*args):
+ args2 = []
+ for arg, BArg in zip(args, BArgs):
+ args2.append(BArg._from_ctypes(arg))
+ try:
+ res2 = init(*args2)
+ res2 = BResult._to_ctypes(res2)
+ except:
+ traceback.print_exc()
+ res2 = error
+ if issubclass(BResult, CTypesGenericPtr):
+ if res2:
+ res2 = ctypes.cast(res2, ctypes.c_void_p).value
+ # .value: http://bugs.python.org/issue1574593
+ else:
+ res2 = None
+ #print repr(res2)
+ return res2
+ if issubclass(BResult, CTypesGenericPtr):
+ # The only pointers callbacks can return are void*s:
+ # http://bugs.python.org/issue5710
+ callback_ctype = ctypes.CFUNCTYPE(
+ ctypes.c_void_p,
+ *[BArg._ctype for BArg in BArgs],
+ use_errno=True)
+ else:
+ callback_ctype = CTypesFunctionPtr._ctype
+ self._as_ctype_ptr = callback_ctype(callback)
+ self._address = ctypes.cast(self._as_ctype_ptr,
+ ctypes.c_void_p).value
+ self._own_callback = init
+
+ @staticmethod
+ def _initialize(ctypes_ptr, value):
+ if value:
+ raise NotImplementedError("ctypes backend: not supported: "
+ "initializers for function pointers")
+
+ def __repr__(self):
+ c_name = getattr(self, '_name', None)
+ if c_name:
+ i = self._reftypename.index('(* &)')
+ if self._reftypename[i-1] not in ' )*':
+ c_name = ' ' + c_name
+ c_name = self._reftypename.replace('(* &)', c_name)
+ return CTypesData.__repr__(self, c_name)
+
+ def _get_own_repr(self):
+ if getattr(self, '_own_callback', None) is not None:
+ return 'calling %r' % (self._own_callback,)
+ return super(CTypesFunctionPtr, self)._get_own_repr()
+
+ def __call__(self, *args):
+ if has_varargs:
+ assert len(args) >= len(BArgs)
+ extraargs = args[len(BArgs):]
+ args = args[:len(BArgs)]
+ else:
+ assert len(args) == len(BArgs)
+ ctypes_args = []
+ for arg, BArg in zip(args, BArgs):
+ ctypes_args.append(BArg._arg_to_ctypes(arg))
+ if has_varargs:
+ for i, arg in enumerate(extraargs):
+ if arg is None:
+ ctypes_args.append(ctypes.c_void_p(0)) # NULL
+ continue
+ if not isinstance(arg, CTypesData):
+ raise TypeError(
+ "argument %d passed in the variadic part "
+ "needs to be a cdata object (got %s)" %
+ (1 + len(BArgs) + i, type(arg).__name__))
+ ctypes_args.append(arg._arg_to_ctypes(arg))
+ result = self._as_ctype_ptr(*ctypes_args)
+ return BResult._from_ctypes(result)
+ #
+ CTypesFunctionPtr._fix_class()
+ return CTypesFunctionPtr
+
+ def new_enum_type(self, name, enumerators, enumvalues, CTypesInt):
+ assert isinstance(name, str)
+ reverse_mapping = dict(zip(reversed(enumvalues),
+ reversed(enumerators)))
+ #
+ class CTypesEnum(CTypesInt):
+ __slots__ = []
+ _reftypename = '%s &' % name
+
+ def _get_own_repr(self):
+ value = self._value
+ try:
+ return '%d: %s' % (value, reverse_mapping[value])
+ except KeyError:
+ return str(value)
+
+ def _to_string(self, maxlen):
+ value = self._value
+ try:
+ return reverse_mapping[value]
+ except KeyError:
+ return str(value)
+ #
+ CTypesEnum._fix_class()
+ return CTypesEnum
+
+ def get_errno(self):
+ return ctypes.get_errno()
+
+ def set_errno(self, value):
+ ctypes.set_errno(value)
+
+ def string(self, b, maxlen=-1):
+ return b._to_string(maxlen)
+
+ def buffer(self, bptr, size=-1):
+ raise NotImplementedError("buffer() with ctypes backend")
+
+ def sizeof(self, cdata_or_BType):
+ if isinstance(cdata_or_BType, CTypesData):
+ return cdata_or_BType._get_size_of_instance()
+ else:
+ assert issubclass(cdata_or_BType, CTypesData)
+ return cdata_or_BType._get_size()
+
+ def alignof(self, BType):
+ assert issubclass(BType, CTypesData)
+ return BType._alignment()
+
+ def newp(self, BType, source):
+ if not issubclass(BType, CTypesData):
+ raise TypeError
+ return BType._newp(source)
+
+ def cast(self, BType, source):
+ return BType._cast_from(source)
+
+ def callback(self, BType, source, error, onerror):
+ assert onerror is None # XXX not implemented
+ return BType(source, error)
+
+ _weakref_cache_ref = None
+
+ def gcp(self, cdata, destructor, size=0):
+ if self._weakref_cache_ref is None:
+ import weakref
+ class MyRef(weakref.ref):
+ def __eq__(self, other):
+ myref = self()
+ return self is other or (
+ myref is not None and myref is other())
+ def __ne__(self, other):
+ return not (self == other)
+ def __hash__(self):
+ try:
+ return self._hash
+ except AttributeError:
+ self._hash = hash(self())
+ return self._hash
+ self._weakref_cache_ref = {}, MyRef
+ weak_cache, MyRef = self._weakref_cache_ref
+
+ if destructor is None:
+ try:
+ del weak_cache[MyRef(cdata)]
+ except KeyError:
+                raise TypeError("Can remove destructor only on an object "
+ "previously returned by ffi.gc()")
+ return None
+
+ def remove(k):
+ cdata, destructor = weak_cache.pop(k, (None, None))
+ if destructor is not None:
+ destructor(cdata)
+
+ new_cdata = self.cast(self.typeof(cdata), cdata)
+ assert new_cdata is not cdata
+ weak_cache[MyRef(new_cdata, remove)] = (cdata, destructor)
+ return new_cdata
+
+ typeof = type
+
+ def getcname(self, BType, replace_with):
+ return BType._get_c_name(replace_with)
+
+ def typeoffsetof(self, BType, fieldname, num=0):
+ if isinstance(fieldname, str):
+ if num == 0 and issubclass(BType, CTypesGenericPtr):
+ BType = BType._BItem
+ if not issubclass(BType, CTypesBaseStructOrUnion):
+ raise TypeError("expected a struct or union ctype")
+ BField = BType._bfield_types[fieldname]
+ if BField is Ellipsis:
+ raise TypeError("not supported for bitfields")
+ return (BField, BType._offsetof(fieldname))
+ elif isinstance(fieldname, (int, long)):
+ if issubclass(BType, CTypesGenericArray):
+ BType = BType._CTPtr
+ if not issubclass(BType, CTypesGenericPtr):
+ raise TypeError("expected an array or ptr ctype")
+ BItem = BType._BItem
+ offset = BItem._get_size() * fieldname
+ if offset > sys.maxsize:
+ raise OverflowError
+ return (BItem, offset)
+ else:
+ raise TypeError(type(fieldname))
+
+ def rawaddressof(self, BTypePtr, cdata, offset=None):
+ if isinstance(cdata, CTypesBaseStructOrUnion):
+ ptr = ctypes.pointer(type(cdata)._to_ctypes(cdata))
+ elif isinstance(cdata, CTypesGenericPtr):
+ if offset is None or not issubclass(type(cdata)._BItem,
+ CTypesBaseStructOrUnion):
+ raise TypeError("unexpected cdata type")
+ ptr = type(cdata)._to_ctypes(cdata)
+ elif isinstance(cdata, CTypesGenericArray):
+ ptr = type(cdata)._to_ctypes(cdata)
+ else:
+ raise TypeError("expected a <cdata 'struct-or-union'>")
+ if offset:
+ ptr = ctypes.cast(
+ ctypes.c_void_p(
+ ctypes.cast(ptr, ctypes.c_void_p).value + offset),
+ type(ptr))
+ return BTypePtr._from_ctypes(ptr)
+
+
+class CTypesLibrary(object):
+
+ def __init__(self, backend, cdll):
+ self.backend = backend
+ self.cdll = cdll
+
+ def load_function(self, BType, name):
+ c_func = getattr(self.cdll, name)
+ funcobj = BType._from_ctypes(c_func)
+ funcobj._name = name
+ return funcobj
+
+ def read_variable(self, BType, name):
+ try:
+ ctypes_obj = BType._ctype.in_dll(self.cdll, name)
+ except AttributeError as e:
+ raise NotImplementedError(e)
+ return BType._from_ctypes(ctypes_obj)
+
+ def write_variable(self, BType, name, value):
+ new_ctypes_obj = BType._to_ctypes(value)
+ ctypes_obj = BType._ctype.in_dll(self.cdll, name)
+ ctypes.memmove(ctypes.addressof(ctypes_obj),
+ ctypes.addressof(new_ctypes_obj),
+ ctypes.sizeof(BType._ctype))
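
The file above is cffi's pure-Python fallback backend built on ctypes. As a rough usage sketch (not part of this diff; it assumes the vendored lib/cffi package is importable as `cffi` and relies on the standard FFI(backend=...) hook), the backend can be selected explicitly instead of the compiled _cffi_backend extension:

    # Illustrative sketch: drive the ctypes backend directly (POSIX assumed).
    from cffi import FFI
    from cffi.backend_ctypes import CTypesBackend

    ffi = FFI(backend=CTypesBackend())      # FFI wires itself to this backend
    ffi.cdef("int puts(const char *);")
    C = ffi.dlopen(None)                    # load_library() -> ctypes.CDLL
    C.puts(b"hello from the ctypes backend")
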
diff --git a/lib/cffi/cffi_opcode.py b/lib/cffi/cffi_opcode.py
new file mode 100644
index 0000000..a0df98d
--- /dev/null
+++ b/lib/cffi/cffi_opcode.py
@@ -0,0 +1,187 @@
+from .error import VerificationError
+
+class CffiOp(object):
+ def __init__(self, op, arg):
+ self.op = op
+ self.arg = arg
+
+ def as_c_expr(self):
+ if self.op is None:
+ assert isinstance(self.arg, str)
+ return '(_cffi_opcode_t)(%s)' % (self.arg,)
+ classname = CLASS_NAME[self.op]
+ return '_CFFI_OP(_CFFI_OP_%s, %s)' % (classname, self.arg)
+
+ def as_python_bytes(self):
+ if self.op is None and self.arg.isdigit():
+ value = int(self.arg) # non-negative: '-' not in self.arg
+ if value >= 2**31:
+ raise OverflowError("cannot emit %r: limited to 2**31-1"
+ % (self.arg,))
+ return format_four_bytes(value)
+ if isinstance(self.arg, str):
+ raise VerificationError("cannot emit to Python: %r" % (self.arg,))
+ return format_four_bytes((self.arg << 8) | self.op)
+
+ def __str__(self):
+ classname = CLASS_NAME.get(self.op, self.op)
+ return '(%s %s)' % (classname, self.arg)
+
+def format_four_bytes(num):
+ return '\\x%02X\\x%02X\\x%02X\\x%02X' % (
+ (num >> 24) & 0xFF,
+ (num >> 16) & 0xFF,
+ (num >> 8) & 0xFF,
+ (num ) & 0xFF)
+
+OP_PRIMITIVE = 1
+OP_POINTER = 3
+OP_ARRAY = 5
+OP_OPEN_ARRAY = 7
+OP_STRUCT_UNION = 9
+OP_ENUM = 11
+OP_FUNCTION = 13
+OP_FUNCTION_END = 15
+OP_NOOP = 17
+OP_BITFIELD = 19
+OP_TYPENAME = 21
+OP_CPYTHON_BLTN_V = 23 # varargs
+OP_CPYTHON_BLTN_N = 25 # noargs
+OP_CPYTHON_BLTN_O = 27 # O (i.e. a single arg)
+OP_CONSTANT = 29
+OP_CONSTANT_INT = 31
+OP_GLOBAL_VAR = 33
+OP_DLOPEN_FUNC = 35
+OP_DLOPEN_CONST = 37
+OP_GLOBAL_VAR_F = 39
+OP_EXTERN_PYTHON = 41
+
+PRIM_VOID = 0
+PRIM_BOOL = 1
+PRIM_CHAR = 2
+PRIM_SCHAR = 3
+PRIM_UCHAR = 4
+PRIM_SHORT = 5
+PRIM_USHORT = 6
+PRIM_INT = 7
+PRIM_UINT = 8
+PRIM_LONG = 9
+PRIM_ULONG = 10
+PRIM_LONGLONG = 11
+PRIM_ULONGLONG = 12
+PRIM_FLOAT = 13
+PRIM_DOUBLE = 14
+PRIM_LONGDOUBLE = 15
+
+PRIM_WCHAR = 16
+PRIM_INT8 = 17
+PRIM_UINT8 = 18
+PRIM_INT16 = 19
+PRIM_UINT16 = 20
+PRIM_INT32 = 21
+PRIM_UINT32 = 22
+PRIM_INT64 = 23
+PRIM_UINT64 = 24
+PRIM_INTPTR = 25
+PRIM_UINTPTR = 26
+PRIM_PTRDIFF = 27
+PRIM_SIZE = 28
+PRIM_SSIZE = 29
+PRIM_INT_LEAST8 = 30
+PRIM_UINT_LEAST8 = 31
+PRIM_INT_LEAST16 = 32
+PRIM_UINT_LEAST16 = 33
+PRIM_INT_LEAST32 = 34
+PRIM_UINT_LEAST32 = 35
+PRIM_INT_LEAST64 = 36
+PRIM_UINT_LEAST64 = 37
+PRIM_INT_FAST8 = 38
+PRIM_UINT_FAST8 = 39
+PRIM_INT_FAST16 = 40
+PRIM_UINT_FAST16 = 41
+PRIM_INT_FAST32 = 42
+PRIM_UINT_FAST32 = 43
+PRIM_INT_FAST64 = 44
+PRIM_UINT_FAST64 = 45
+PRIM_INTMAX = 46
+PRIM_UINTMAX = 47
+PRIM_FLOATCOMPLEX = 48
+PRIM_DOUBLECOMPLEX = 49
+PRIM_CHAR16 = 50
+PRIM_CHAR32 = 51
+
+_NUM_PRIM = 52
+_UNKNOWN_PRIM = -1
+_UNKNOWN_FLOAT_PRIM = -2
+_UNKNOWN_LONG_DOUBLE = -3
+
+_IO_FILE_STRUCT = -1
+
+PRIMITIVE_TO_INDEX = {
+ 'char': PRIM_CHAR,
+ 'short': PRIM_SHORT,
+ 'int': PRIM_INT,
+ 'long': PRIM_LONG,
+ 'long long': PRIM_LONGLONG,
+ 'signed char': PRIM_SCHAR,
+ 'unsigned char': PRIM_UCHAR,
+ 'unsigned short': PRIM_USHORT,
+ 'unsigned int': PRIM_UINT,
+ 'unsigned long': PRIM_ULONG,
+ 'unsigned long long': PRIM_ULONGLONG,
+ 'float': PRIM_FLOAT,
+ 'double': PRIM_DOUBLE,
+ 'long double': PRIM_LONGDOUBLE,
+ 'float _Complex': PRIM_FLOATCOMPLEX,
+ 'double _Complex': PRIM_DOUBLECOMPLEX,
+ '_Bool': PRIM_BOOL,
+ 'wchar_t': PRIM_WCHAR,
+ 'char16_t': PRIM_CHAR16,
+ 'char32_t': PRIM_CHAR32,
+ 'int8_t': PRIM_INT8,
+ 'uint8_t': PRIM_UINT8,
+ 'int16_t': PRIM_INT16,
+ 'uint16_t': PRIM_UINT16,
+ 'int32_t': PRIM_INT32,
+ 'uint32_t': PRIM_UINT32,
+ 'int64_t': PRIM_INT64,
+ 'uint64_t': PRIM_UINT64,
+ 'intptr_t': PRIM_INTPTR,
+ 'uintptr_t': PRIM_UINTPTR,
+ 'ptrdiff_t': PRIM_PTRDIFF,
+ 'size_t': PRIM_SIZE,
+ 'ssize_t': PRIM_SSIZE,
+ 'int_least8_t': PRIM_INT_LEAST8,
+ 'uint_least8_t': PRIM_UINT_LEAST8,
+ 'int_least16_t': PRIM_INT_LEAST16,
+ 'uint_least16_t': PRIM_UINT_LEAST16,
+ 'int_least32_t': PRIM_INT_LEAST32,
+ 'uint_least32_t': PRIM_UINT_LEAST32,
+ 'int_least64_t': PRIM_INT_LEAST64,
+ 'uint_least64_t': PRIM_UINT_LEAST64,
+ 'int_fast8_t': PRIM_INT_FAST8,
+ 'uint_fast8_t': PRIM_UINT_FAST8,
+ 'int_fast16_t': PRIM_INT_FAST16,
+ 'uint_fast16_t': PRIM_UINT_FAST16,
+ 'int_fast32_t': PRIM_INT_FAST32,
+ 'uint_fast32_t': PRIM_UINT_FAST32,
+ 'int_fast64_t': PRIM_INT_FAST64,
+ 'uint_fast64_t': PRIM_UINT_FAST64,
+ 'intmax_t': PRIM_INTMAX,
+ 'uintmax_t': PRIM_UINTMAX,
+ }
+
+F_UNION = 0x01
+F_CHECK_FIELDS = 0x02
+F_PACKED = 0x04
+F_EXTERNAL = 0x08
+F_OPAQUE = 0x10
+
+G_FLAGS = dict([('_CFFI_' + _key, globals()[_key])
+ for _key in ['F_UNION', 'F_CHECK_FIELDS', 'F_PACKED',
+ 'F_EXTERNAL', 'F_OPAQUE']])
+
+CLASS_NAME = {}
+for _name, _value in list(globals().items()):
+ if _name.startswith('OP_') and isinstance(_value, int):
+ CLASS_NAME[_value] = _name[3:]
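
For orientation, the opcode encoding defined in cffi_opcode.py can be checked on its own. This is an illustrative sketch only (it assumes the vendored package imports as `cffi`); the expected values follow directly from the code above:

    # Illustrative sketch: an OP_POINTER opcode whose argument is index 5.
    from cffi.cffi_opcode import CffiOp, OP_POINTER

    op = CffiOp(OP_POINTER, 5)
    print(op.as_c_expr())        # _CFFI_OP(_CFFI_OP_POINTER, 5)
    print(op.as_python_bytes())  # \x00\x00\x05\x03, i.e. (5 << 8) | OP_POINTER
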
diff --git a/lib/cffi/commontypes.py b/lib/cffi/commontypes.py
new file mode 100644
index 0000000..8ec97c7
--- /dev/null
+++ b/lib/cffi/commontypes.py
@@ -0,0 +1,80 @@
+import sys
+from . import model
+from .error import FFIError
+
+
+COMMON_TYPES = {}
+
+try:
+ # fetch "bool" and all simple Windows types
+ from _cffi_backend import _get_common_types
+ _get_common_types(COMMON_TYPES)
+except ImportError:
+ pass
+
+COMMON_TYPES['FILE'] = model.unknown_type('FILE', '_IO_FILE')
+COMMON_TYPES['bool'] = '_Bool' # in case we got ImportError above
+
+for _type in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
+ if _type.endswith('_t'):
+ COMMON_TYPES[_type] = _type
+del _type
+
+_CACHE = {}
+
+def resolve_common_type(parser, commontype):
+ try:
+ return _CACHE[commontype]
+ except KeyError:
+ cdecl = COMMON_TYPES.get(commontype, commontype)
+ if not isinstance(cdecl, str):
+ result, quals = cdecl, 0 # cdecl is already a BaseType
+ elif cdecl in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
+ result, quals = model.PrimitiveType(cdecl), 0
+ elif cdecl == 'set-unicode-needed':
+ raise FFIError("The Windows type %r is only available after "
+ "you call ffi.set_unicode()" % (commontype,))
+ else:
+ if commontype == cdecl:
+ raise FFIError(
+ "Unsupported type: %r. Please look at "
+ "http://cffi.readthedocs.io/en/latest/cdef.html#ffi-cdef-limitations "
+ "and file an issue if you think this type should really "
+ "be supported." % (commontype,))
+ result, quals = parser.parse_type_and_quals(cdecl) # recursive
+
+ assert isinstance(result, model.BaseTypeByIdentity)
+ _CACHE[commontype] = result, quals
+ return result, quals
+
+
+# ____________________________________________________________
+# extra types for Windows (most of them are in commontypes.c)
+
+
+def win_common_types():
+ return {
+ "UNICODE_STRING": model.StructType(
+ "_UNICODE_STRING",
+ ["Length",
+ "MaximumLength",
+ "Buffer"],
+ [model.PrimitiveType("unsigned short"),
+ model.PrimitiveType("unsigned short"),
+ model.PointerType(model.PrimitiveType("wchar_t"))],
+ [-1, -1, -1]),
+ "PUNICODE_STRING": "UNICODE_STRING *",
+ "PCUNICODE_STRING": "const UNICODE_STRING *",
+
+ "TBYTE": "set-unicode-needed",
+ "TCHAR": "set-unicode-needed",
+ "LPCTSTR": "set-unicode-needed",
+ "PCTSTR": "set-unicode-needed",
+ "LPTSTR": "set-unicode-needed",
+ "PTSTR": "set-unicode-needed",
+ "PTBYTE": "set-unicode-needed",
+ "PTCHAR": "set-unicode-needed",
+ }
+
+if sys.platform == 'win32':
+ COMMON_TYPES.update(win_common_types())
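
A short sketch of how commontypes.py resolves a stdint alias (illustrative only; for names that are already primitive types the `parser` argument is never consulted, so None is fine here):

    # Illustrative sketch: 'uint32_t' resolves straight to a PrimitiveType.
    from cffi import model
    from cffi.commontypes import resolve_common_type

    tp, quals = resolve_common_type(None, 'uint32_t')
    assert isinstance(tp, model.PrimitiveType)
    assert tp.name == 'uint32_t' and quals == 0
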
diff --git a/lib/cffi/cparser.py b/lib/cffi/cparser.py
new file mode 100644
index 0000000..74830e9
--- /dev/null
+++ b/lib/cffi/cparser.py
@@ -0,0 +1,1006 @@
+from . import model
+from .commontypes import COMMON_TYPES, resolve_common_type
+from .error import FFIError, CDefError
+try:
+ from . import _pycparser as pycparser
+except ImportError:
+ import pycparser
+import weakref, re, sys
+
+try:
+ if sys.version_info < (3,):
+ import thread as _thread
+ else:
+ import _thread
+ lock = _thread.allocate_lock()
+except ImportError:
+ lock = None
+
+def _workaround_for_static_import_finders():
+    # Issue #392: packaging tools like cx_Freeze cannot find these
+    # because pycparser uses a dynamic exec-based import. This is an obscure
+ # workaround. This function is never called.
+ import pycparser.yacctab
+ import pycparser.lextab
+
+CDEF_SOURCE_STRING = "<cdef source string>"
+_r_comment = re.compile(r"/\*.*?\*/|//([^\n\\]|\\.)*?$",
+ re.DOTALL | re.MULTILINE)
+_r_define = re.compile(r"^\s*#\s*define\s+([A-Za-z_][A-Za-z_0-9]*)"
+ r"\b((?:[^\n\\]|\\.)*?)$",
+ re.DOTALL | re.MULTILINE)
+_r_line_directive = re.compile(r"^[ \t]*#[ \t]*(?:line|\d+)\b.*$", re.MULTILINE)
+_r_partial_enum = re.compile(r"=\s*\.\.\.\s*[,}]|\.\.\.\s*\}")
+_r_enum_dotdotdot = re.compile(r"__dotdotdot\d+__$")
+_r_partial_array = re.compile(r"\[\s*\.\.\.\s*\]")
+_r_words = re.compile(r"\w+|\S")
+_parser_cache = None
+_r_int_literal = re.compile(r"-?0?x?[0-9a-f]+[lu]*$", re.IGNORECASE)
+_r_stdcall1 = re.compile(r"\b(__stdcall|WINAPI)\b")
+_r_stdcall2 = re.compile(r"[(]\s*(__stdcall|WINAPI)\b")
+_r_cdecl = re.compile(r"\b__cdecl\b")
+_r_extern_python = re.compile(r'\bextern\s*"'
+ r'(Python|Python\s*\+\s*C|C\s*\+\s*Python)"\s*.')
+_r_star_const_space = re.compile( # matches "* const "
+ r"[*]\s*((const|volatile|restrict)\b\s*)+")
+_r_int_dotdotdot = re.compile(r"(\b(int|long|short|signed|unsigned|char)\s*)+"
+ r"\.\.\.")
+_r_float_dotdotdot = re.compile(r"\b(double|float)\s*\.\.\.")
+
+def _get_parser():
+ global _parser_cache
+ if _parser_cache is None:
+ _parser_cache = pycparser.CParser()
+ return _parser_cache
+
+def _workaround_for_old_pycparser(csource):
+ # Workaround for a pycparser issue (fixed between pycparser 2.10 and
+ # 2.14): "char*const***" gives us a wrong syntax tree, the same as
+ # for "char***(*const)". This means we can't tell the difference
+ # afterwards. But "char(*const(***))" gives us the right syntax
+ # tree. The issue only occurs if there are several stars in
+    # sequence with no parentheses in between, just possibly qualifiers.
+ # Attempt to fix it by adding some parentheses in the source: each
+ # time we see "* const" or "* const *", we add an opening
+ # parenthesis before each star---the hard part is figuring out where
+ # to close them.
+ parts = []
+ while True:
+ match = _r_star_const_space.search(csource)
+ if not match:
+ break
+ #print repr(''.join(parts)+csource), '=>',
+ parts.append(csource[:match.start()])
+ parts.append('('); closing = ')'
+ parts.append(match.group()) # e.g. "* const "
+ endpos = match.end()
+ if csource.startswith('*', endpos):
+ parts.append('('); closing += ')'
+ level = 0
+ i = endpos
+ while i < len(csource):
+ c = csource[i]
+ if c == '(':
+ level += 1
+ elif c == ')':
+ if level == 0:
+ break
+ level -= 1
+ elif c in ',;=':
+ if level == 0:
+ break
+ i += 1
+ csource = csource[endpos:i] + closing + csource[i:]
+ #print repr(''.join(parts)+csource)
+ parts.append(csource)
+ return ''.join(parts)
+
+def _preprocess_extern_python(csource):
+ # input: `extern "Python" int foo(int);` or
+ # `extern "Python" { int foo(int); }`
+ # output:
+ # void __cffi_extern_python_start;
+ # int foo(int);
+ # void __cffi_extern_python_stop;
+ #
+ # input: `extern "Python+C" int foo(int);`
+ # output:
+ # void __cffi_extern_python_plus_c_start;
+ # int foo(int);
+ # void __cffi_extern_python_stop;
+ parts = []
+ while True:
+ match = _r_extern_python.search(csource)
+ if not match:
+ break
+ endpos = match.end() - 1
+ #print
+ #print ''.join(parts)+csource
+ #print '=>'
+ parts.append(csource[:match.start()])
+ if 'C' in match.group(1):
+ parts.append('void __cffi_extern_python_plus_c_start; ')
+ else:
+ parts.append('void __cffi_extern_python_start; ')
+ if csource[endpos] == '{':
+ # grouping variant
+ closing = csource.find('}', endpos)
+ if closing < 0:
+ raise CDefError("'extern \"Python\" {': no '}' found")
+ if csource.find('{', endpos + 1, closing) >= 0:
+ raise NotImplementedError("cannot use { } inside a block "
+ "'extern \"Python\" { ... }'")
+ parts.append(csource[endpos+1:closing])
+ csource = csource[closing+1:]
+ else:
+ # non-grouping variant
+ semicolon = csource.find(';', endpos)
+ if semicolon < 0:
+                raise CDefError("'extern \"Python\"': no ';' found")
+ parts.append(csource[endpos:semicolon+1])
+ csource = csource[semicolon+1:]
+ parts.append(' void __cffi_extern_python_stop;')
+ #print ''.join(parts)+csource
+ #print
+ parts.append(csource)
+ return ''.join(parts)
+
+def _warn_for_string_literal(csource):
+ if '"' not in csource:
+ return
+ for line in csource.splitlines():
+ if '"' in line and not line.lstrip().startswith('#'):
+ import warnings
+ warnings.warn("String literal found in cdef() or type source. "
+ "String literals are ignored here, but you should "
+ "remove them anyway because some character sequences "
+ "confuse pre-parsing.")
+ break
+
+def _warn_for_non_extern_non_static_global_variable(decl):
+ if not decl.storage:
+ import warnings
+ warnings.warn("Global variable '%s' in cdef(): for consistency "
+ "with C it should have a storage class specifier "
+ "(usually 'extern')" % (decl.name,))
+
+def _remove_line_directives(csource):
+ # _r_line_directive matches whole lines, without the final \n, if they
+ # start with '#line' with some spacing allowed, or '#NUMBER'. This
+ # function stores them away and replaces them with exactly the string
+ # '#line@N', where N is the index in the list 'line_directives'.
+ line_directives = []
+ def replace(m):
+ i = len(line_directives)
+ line_directives.append(m.group())
+ return '#line@%d' % i
+ csource = _r_line_directive.sub(replace, csource)
+ return csource, line_directives
+
+def _put_back_line_directives(csource, line_directives):
+ def replace(m):
+ s = m.group()
+ if not s.startswith('#line@'):
+ raise AssertionError("unexpected #line directive "
+                                 "(should have been processed and removed)")
+ return line_directives[int(s[6:])]
+ return _r_line_directive.sub(replace, csource)
+
+def _preprocess(csource):
+ # First, remove the lines of the form '#line N "filename"' because
+ # the "filename" part could confuse the rest
+ csource, line_directives = _remove_line_directives(csource)
+    # Remove comments. NOTE: this only works because the cdef() section
+ # should not contain any string literals (except in line directives)!
+ def replace_keeping_newlines(m):
+ return ' ' + m.group().count('\n') * '\n'
+ csource = _r_comment.sub(replace_keeping_newlines, csource)
+ # Remove the "#define FOO x" lines
+ macros = {}
+ for match in _r_define.finditer(csource):
+ macroname, macrovalue = match.groups()
+ macrovalue = macrovalue.replace('\\\n', '').strip()
+ macros[macroname] = macrovalue
+ csource = _r_define.sub('', csource)
+ #
+ if pycparser.__version__ < '2.14':
+ csource = _workaround_for_old_pycparser(csource)
+ #
+ # BIG HACK: replace WINAPI or __stdcall with "volatile const".
+ # It doesn't make sense for the return type of a function to be
+ # "volatile volatile const", so we abuse it to detect __stdcall...
+ # Hack number 2 is that "int(volatile *fptr)();" is not valid C
+ # syntax, so we place the "volatile" before the opening parenthesis.
+ csource = _r_stdcall2.sub(' volatile volatile const(', csource)
+ csource = _r_stdcall1.sub(' volatile volatile const ', csource)
+ csource = _r_cdecl.sub(' ', csource)
+ #
+ # Replace `extern "Python"` with start/end markers
+ csource = _preprocess_extern_python(csource)
+ #
+ # Now there should not be any string literal left; warn if we get one
+ _warn_for_string_literal(csource)
+ #
+ # Replace "[...]" with "[__dotdotdotarray__]"
+ csource = _r_partial_array.sub('[__dotdotdotarray__]', csource)
+ #
+ # Replace "...}" with "__dotdotdotNUM__}". This construction should
+ # occur only at the end of enums; at the end of structs we have "...;}"
+ # and at the end of vararg functions "...);". Also replace "=...[,}]"
+ # with ",__dotdotdotNUM__[,}]": this occurs in the enums too, when
+ # giving an unknown value.
+ matches = list(_r_partial_enum.finditer(csource))
+ for number, match in enumerate(reversed(matches)):
+ p = match.start()
+ if csource[p] == '=':
+ p2 = csource.find('...', p, match.end())
+ assert p2 > p
+ csource = '%s,__dotdotdot%d__ %s' % (csource[:p], number,
+ csource[p2+3:])
+ else:
+ assert csource[p:p+3] == '...'
+ csource = '%s __dotdotdot%d__ %s' % (csource[:p], number,
+ csource[p+3:])
+ # Replace "int ..." or "unsigned long int..." with "__dotdotdotint__"
+ csource = _r_int_dotdotdot.sub(' __dotdotdotint__ ', csource)
+ # Replace "float ..." or "double..." with "__dotdotdotfloat__"
+ csource = _r_float_dotdotdot.sub(' __dotdotdotfloat__ ', csource)
+ # Replace all remaining "..." with the same name, "__dotdotdot__",
+ # which is declared with a typedef for the purpose of C parsing.
+ csource = csource.replace('...', ' __dotdotdot__ ')
+ # Finally, put back the line directives
+ csource = _put_back_line_directives(csource, line_directives)
+ return csource, macros
+
+def _common_type_names(csource):
+ # Look in the source for what looks like usages of types from the
+ # list of common types. A "usage" is approximated here as the
+ # appearance of the word, minus a "definition" of the type, which
+ # is the last word in a "typedef" statement. Approximative only
+ # but should be fine for all the common types.
+ look_for_words = set(COMMON_TYPES)
+ look_for_words.add(';')
+ look_for_words.add(',')
+ look_for_words.add('(')
+ look_for_words.add(')')
+ look_for_words.add('typedef')
+ words_used = set()
+ is_typedef = False
+ paren = 0
+ previous_word = ''
+ for word in _r_words.findall(csource):
+ if word in look_for_words:
+ if word == ';':
+ if is_typedef:
+ words_used.discard(previous_word)
+ look_for_words.discard(previous_word)
+ is_typedef = False
+ elif word == 'typedef':
+ is_typedef = True
+ paren = 0
+ elif word == '(':
+ paren += 1
+ elif word == ')':
+ paren -= 1
+ elif word == ',':
+ if is_typedef and paren == 0:
+ words_used.discard(previous_word)
+ look_for_words.discard(previous_word)
+ else: # word in COMMON_TYPES
+ words_used.add(word)
+ previous_word = word
+ return words_used
+
+
+class Parser(object):
+
+ def __init__(self):
+ self._declarations = {}
+ self._included_declarations = set()
+ self._anonymous_counter = 0
+ self._structnode2type = weakref.WeakKeyDictionary()
+ self._options = {}
+ self._int_constants = {}
+ self._recomplete = []
+ self._uses_new_feature = None
+
+ def _parse(self, csource):
+ csource, macros = _preprocess(csource)
+ # XXX: for more efficiency we would need to poke into the
+ # internals of CParser... the following registers the
+ # typedefs, because their presence or absence influences the
+ # parsing itself (but what they are typedef'ed to plays no role)
+ ctn = _common_type_names(csource)
+ typenames = []
+ for name in sorted(self._declarations):
+ if name.startswith('typedef '):
+ name = name[8:]
+ typenames.append(name)
+ ctn.discard(name)
+ typenames += sorted(ctn)
+ #
+ csourcelines = []
+ csourcelines.append('# 1 "<cdef automatic initialization code>"')
+ for typename in typenames:
+ csourcelines.append('typedef int %s;' % typename)
+ csourcelines.append('typedef int __dotdotdotint__, __dotdotdotfloat__,'
+ ' __dotdotdot__;')
+ # this forces pycparser to consider the following in the file
+ # called <cdef source string> from line 1
+ csourcelines.append('# 1 "%s"' % (CDEF_SOURCE_STRING,))
+ csourcelines.append(csource)
+ fullcsource = '\n'.join(csourcelines)
+ if lock is not None:
+ lock.acquire() # pycparser is not thread-safe...
+ try:
+ ast = _get_parser().parse(fullcsource)
+ except pycparser.c_parser.ParseError as e:
+ self.convert_pycparser_error(e, csource)
+ finally:
+ if lock is not None:
+ lock.release()
+ # csource will be used to find buggy source text
+ return ast, macros, csource
+
+ def _convert_pycparser_error(self, e, csource):
+ # xxx look for "<cdef source string>:NUM:" at the start of str(e)
+ # and interpret that as a line number. This will not work if
+ # the user gives explicit ``# NUM "FILE"`` directives.
+ line = None
+ msg = str(e)
+ match = re.match(r"%s:(\d+):" % (CDEF_SOURCE_STRING,), msg)
+ if match:
+ linenum = int(match.group(1), 10)
+ csourcelines = csource.splitlines()
+ if 1 <= linenum <= len(csourcelines):
+ line = csourcelines[linenum-1]
+ return line
+
+ def convert_pycparser_error(self, e, csource):
+ line = self._convert_pycparser_error(e, csource)
+
+ msg = str(e)
+ if line:
+ msg = 'cannot parse "%s"\n%s' % (line.strip(), msg)
+ else:
+ msg = 'parse error\n%s' % (msg,)
+ raise CDefError(msg)
+
+ def parse(self, csource, override=False, packed=False, pack=None,
+ dllexport=False):
+ if packed:
+ if packed != True:
+ raise ValueError("'packed' should be False or True; use "
+ "'pack' to give another value")
+ if pack:
+ raise ValueError("cannot give both 'pack' and 'packed'")
+ pack = 1
+ elif pack:
+ if pack & (pack - 1):
+ raise ValueError("'pack' must be a power of two, not %r" %
+ (pack,))
+ else:
+ pack = 0
+ prev_options = self._options
+ try:
+ self._options = {'override': override,
+ 'packed': pack,
+ 'dllexport': dllexport}
+ self._internal_parse(csource)
+ finally:
+ self._options = prev_options
+
+ def _internal_parse(self, csource):
+ ast, macros, csource = self._parse(csource)
+ # add the macros
+ self._process_macros(macros)
+ # find the first "__dotdotdot__" and use that as a separator
+ # between the repeated typedefs and the real csource
+ iterator = iter(ast.ext)
+ for decl in iterator:
+ if decl.name == '__dotdotdot__':
+ break
+ else:
+ assert 0
+ current_decl = None
+ #
+ try:
+ self._inside_extern_python = '__cffi_extern_python_stop'
+ for decl in iterator:
+ current_decl = decl
+ if isinstance(decl, pycparser.c_ast.Decl):
+ self._parse_decl(decl)
+ elif isinstance(decl, pycparser.c_ast.Typedef):
+ if not decl.name:
+ raise CDefError("typedef does not declare any name",
+ decl)
+ quals = 0
+ if (isinstance(decl.type.type, pycparser.c_ast.IdentifierType) and
+ decl.type.type.names[-1].startswith('__dotdotdot')):
+ realtype = self._get_unknown_type(decl)
+ elif (isinstance(decl.type, pycparser.c_ast.PtrDecl) and
+ isinstance(decl.type.type, pycparser.c_ast.TypeDecl) and
+ isinstance(decl.type.type.type,
+ pycparser.c_ast.IdentifierType) and
+ decl.type.type.type.names[-1].startswith('__dotdotdot')):
+ realtype = self._get_unknown_ptr_type(decl)
+ else:
+ realtype, quals = self._get_type_and_quals(
+ decl.type, name=decl.name, partial_length_ok=True,
+ typedef_example="*(%s *)0" % (decl.name,))
+ self._declare('typedef ' + decl.name, realtype, quals=quals)
+ elif decl.__class__.__name__ == 'Pragma':
+ pass # skip pragma, only in pycparser 2.15
+ else:
+ raise CDefError("unexpected <%s>: this construct is valid "
+ "C but not valid in cdef()" %
+ decl.__class__.__name__, decl)
+ except CDefError as e:
+ if len(e.args) == 1:
+ e.args = e.args + (current_decl,)
+ raise
+ except FFIError as e:
+ msg = self._convert_pycparser_error(e, csource)
+ if msg:
+ e.args = (e.args[0] + "\n *** Err: %s" % msg,)
+ raise
+
+ def _add_constants(self, key, val):
+ if key in self._int_constants:
+ if self._int_constants[key] == val:
+ return # ignore identical double declarations
+ raise FFIError(
+ "multiple declarations of constant: %s" % (key,))
+ self._int_constants[key] = val
+
+ def _add_integer_constant(self, name, int_str):
+ int_str = int_str.lower().rstrip("ul")
+ neg = int_str.startswith('-')
+ if neg:
+ int_str = int_str[1:]
+ # "010" is not valid oct in py3
+ if (int_str.startswith("0") and int_str != '0'
+ and not int_str.startswith("0x")):
+ int_str = "0o" + int_str[1:]
+ pyvalue = int(int_str, 0)
+ if neg:
+ pyvalue = -pyvalue
+ self._add_constants(name, pyvalue)
+ self._declare('macro ' + name, pyvalue)
+
+ def _process_macros(self, macros):
+ for key, value in macros.items():
+ value = value.strip()
+ if _r_int_literal.match(value):
+ self._add_integer_constant(key, value)
+ elif value == '...':
+ self._declare('macro ' + key, value)
+ else:
+ raise CDefError(
+ 'only supports one of the following syntax:\n'
+ ' #define %s ... (literally dot-dot-dot)\n'
+ ' #define %s NUMBER (with NUMBER an integer'
+ ' constant, decimal/hex/octal)\n'
+ 'got:\n'
+ ' #define %s %s'
+ % (key, key, key, value))
+
+ def _declare_function(self, tp, quals, decl):
+ tp = self._get_type_pointer(tp, quals)
+ if self._options.get('dllexport'):
+ tag = 'dllexport_python '
+ elif self._inside_extern_python == '__cffi_extern_python_start':
+ tag = 'extern_python '
+ elif self._inside_extern_python == '__cffi_extern_python_plus_c_start':
+ tag = 'extern_python_plus_c '
+ else:
+ tag = 'function '
+ self._declare(tag + decl.name, tp)
+
+ def _parse_decl(self, decl):
+ node = decl.type
+ if isinstance(node, pycparser.c_ast.FuncDecl):
+ tp, quals = self._get_type_and_quals(node, name=decl.name)
+ assert isinstance(tp, model.RawFunctionType)
+ self._declare_function(tp, quals, decl)
+ else:
+ if isinstance(node, pycparser.c_ast.Struct):
+ self._get_struct_union_enum_type('struct', node)
+ elif isinstance(node, pycparser.c_ast.Union):
+ self._get_struct_union_enum_type('union', node)
+ elif isinstance(node, pycparser.c_ast.Enum):
+ self._get_struct_union_enum_type('enum', node)
+ elif not decl.name:
+ raise CDefError("construct does not declare any variable",
+ decl)
+ #
+ if decl.name:
+ tp, quals = self._get_type_and_quals(node,
+ partial_length_ok=True)
+ if tp.is_raw_function:
+ self._declare_function(tp, quals, decl)
+ elif (tp.is_integer_type() and
+ hasattr(decl, 'init') and
+ hasattr(decl.init, 'value') and
+ _r_int_literal.match(decl.init.value)):
+ self._add_integer_constant(decl.name, decl.init.value)
+ elif (tp.is_integer_type() and
+ isinstance(decl.init, pycparser.c_ast.UnaryOp) and
+ decl.init.op == '-' and
+ hasattr(decl.init.expr, 'value') and
+ _r_int_literal.match(decl.init.expr.value)):
+ self._add_integer_constant(decl.name,
+ '-' + decl.init.expr.value)
+ elif (tp is model.void_type and
+ decl.name.startswith('__cffi_extern_python_')):
+ # hack: `extern "Python"` in the C source is replaced
+ # with "void __cffi_extern_python_start;" and
+ # "void __cffi_extern_python_stop;"
+ self._inside_extern_python = decl.name
+ else:
+                    if self._inside_extern_python != '__cffi_extern_python_stop':
+ raise CDefError(
+ "cannot declare constants or "
+ "variables with 'extern \"Python\"'")
+ if (quals & model.Q_CONST) and not tp.is_array_type:
+ self._declare('constant ' + decl.name, tp, quals=quals)
+ else:
+ _warn_for_non_extern_non_static_global_variable(decl)
+ self._declare('variable ' + decl.name, tp, quals=quals)
+
+ def parse_type(self, cdecl):
+ return self.parse_type_and_quals(cdecl)[0]
+
+ def parse_type_and_quals(self, cdecl):
+ ast, macros = self._parse('void __dummy(\n%s\n);' % cdecl)[:2]
+ assert not macros
+ exprnode = ast.ext[-1].type.args.params[0]
+ if isinstance(exprnode, pycparser.c_ast.ID):
+ raise CDefError("unknown identifier '%s'" % (exprnode.name,))
+ return self._get_type_and_quals(exprnode.type)
+
+ def _declare(self, name, obj, included=False, quals=0):
+ if name in self._declarations:
+ prevobj, prevquals = self._declarations[name]
+ if prevobj is obj and prevquals == quals:
+ return
+ if not self._options.get('override'):
+ raise FFIError(
+ "multiple declarations of %s (for interactive usage, "
+ "try cdef(xx, override=True))" % (name,))
+ assert '__dotdotdot__' not in name.split()
+ self._declarations[name] = (obj, quals)
+ if included:
+ self._included_declarations.add(obj)
+
+ def _extract_quals(self, type):
+ quals = 0
+ if isinstance(type, (pycparser.c_ast.TypeDecl,
+ pycparser.c_ast.PtrDecl)):
+ if 'const' in type.quals:
+ quals |= model.Q_CONST
+ if 'volatile' in type.quals:
+ quals |= model.Q_VOLATILE
+ if 'restrict' in type.quals:
+ quals |= model.Q_RESTRICT
+ return quals
+
+ def _get_type_pointer(self, type, quals, declname=None):
+ if isinstance(type, model.RawFunctionType):
+ return type.as_function_pointer()
+ if (isinstance(type, model.StructOrUnionOrEnum) and
+ type.name.startswith('$') and type.name[1:].isdigit() and
+ type.forcename is None and declname is not None):
+ return model.NamedPointerType(type, declname, quals)
+ return model.PointerType(type, quals)
+
+ def _get_type_and_quals(self, typenode, name=None, partial_length_ok=False,
+ typedef_example=None):
+ # first, dereference typedefs, if we have it already parsed, we're good
+ if (isinstance(typenode, pycparser.c_ast.TypeDecl) and
+ isinstance(typenode.type, pycparser.c_ast.IdentifierType) and
+ len(typenode.type.names) == 1 and
+ ('typedef ' + typenode.type.names[0]) in self._declarations):
+ tp, quals = self._declarations['typedef ' + typenode.type.names[0]]
+ quals |= self._extract_quals(typenode)
+ return tp, quals
+ #
+ if isinstance(typenode, pycparser.c_ast.ArrayDecl):
+ # array type
+ if typenode.dim is None:
+ length = None
+ else:
+ length = self._parse_constant(
+ typenode.dim, partial_length_ok=partial_length_ok)
+ # a hack: in 'typedef int foo_t[...][...];', don't use '...' as
+ # the length but use directly the C expression that would be
+ # generated by recompiler.py. This lets the typedef be used in
+ # many more places within recompiler.py
+ if typedef_example is not None:
+ if length == '...':
+ length = '_cffi_array_len(%s)' % (typedef_example,)
+ typedef_example = "*" + typedef_example
+ #
+ tp, quals = self._get_type_and_quals(typenode.type,
+ partial_length_ok=partial_length_ok,
+ typedef_example=typedef_example)
+ return model.ArrayType(tp, length), quals
+ #
+ if isinstance(typenode, pycparser.c_ast.PtrDecl):
+ # pointer type
+ itemtype, itemquals = self._get_type_and_quals(typenode.type)
+ tp = self._get_type_pointer(itemtype, itemquals, declname=name)
+ quals = self._extract_quals(typenode)
+ return tp, quals
+ #
+ if isinstance(typenode, pycparser.c_ast.TypeDecl):
+ quals = self._extract_quals(typenode)
+ type = typenode.type
+ if isinstance(type, pycparser.c_ast.IdentifierType):
+ # assume a primitive type. get it from .names, but reduce
+ # synonyms to a single chosen combination
+ names = list(type.names)
+ if names != ['signed', 'char']: # keep this unmodified
+ prefixes = {}
+ while names:
+ name = names[0]
+ if name in ('short', 'long', 'signed', 'unsigned'):
+ prefixes[name] = prefixes.get(name, 0) + 1
+ del names[0]
+ else:
+ break
+ # ignore the 'signed' prefix below, and reorder the others
+ newnames = []
+ for prefix in ('unsigned', 'short', 'long'):
+ for i in range(prefixes.get(prefix, 0)):
+ newnames.append(prefix)
+ if not names:
+ names = ['int'] # implicitly
+ if names == ['int']: # but kill it if 'short' or 'long'
+ if 'short' in prefixes or 'long' in prefixes:
+ names = []
+ names = newnames + names
+ ident = ' '.join(names)
+ if ident == 'void':
+ return model.void_type, quals
+ if ident == '__dotdotdot__':
+ raise FFIError(':%d: bad usage of "..."' %
+ typenode.coord.line)
+ tp0, quals0 = resolve_common_type(self, ident)
+ return tp0, (quals | quals0)
+ #
+ if isinstance(type, pycparser.c_ast.Struct):
+ # 'struct foobar'
+ tp = self._get_struct_union_enum_type('struct', type, name)
+ return tp, quals
+ #
+ if isinstance(type, pycparser.c_ast.Union):
+ # 'union foobar'
+ tp = self._get_struct_union_enum_type('union', type, name)
+ return tp, quals
+ #
+ if isinstance(type, pycparser.c_ast.Enum):
+ # 'enum foobar'
+ tp = self._get_struct_union_enum_type('enum', type, name)
+ return tp, quals
+ #
+ if isinstance(typenode, pycparser.c_ast.FuncDecl):
+ # a function type
+ return self._parse_function_type(typenode, name), 0
+ #
+ # nested anonymous structs or unions end up here
+ if isinstance(typenode, pycparser.c_ast.Struct):
+ return self._get_struct_union_enum_type('struct', typenode, name,
+ nested=True), 0
+ if isinstance(typenode, pycparser.c_ast.Union):
+ return self._get_struct_union_enum_type('union', typenode, name,
+ nested=True), 0
+ #
+ raise FFIError(":%d: bad or unsupported type declaration" %
+ typenode.coord.line)
+
+ def _parse_function_type(self, typenode, funcname=None):
+ params = list(getattr(typenode.args, 'params', []))
+ for i, arg in enumerate(params):
+ if not hasattr(arg, 'type'):
+ raise CDefError("%s arg %d: unknown type '%s'"
+ " (if you meant to use the old C syntax of giving"
+ " untyped arguments, it is not supported)"
+ % (funcname or 'in expression', i + 1,
+ getattr(arg, 'name', '?')))
+ ellipsis = (
+ len(params) > 0 and
+ isinstance(params[-1].type, pycparser.c_ast.TypeDecl) and
+ isinstance(params[-1].type.type,
+ pycparser.c_ast.IdentifierType) and
+ params[-1].type.type.names == ['__dotdotdot__'])
+ if ellipsis:
+ params.pop()
+ if not params:
+ raise CDefError(
+ "%s: a function with only '(...)' as argument"
+ " is not correct C" % (funcname or 'in expression'))
+ args = [self._as_func_arg(*self._get_type_and_quals(argdeclnode.type))
+ for argdeclnode in params]
+ if not ellipsis and args == [model.void_type]:
+ args = []
+ result, quals = self._get_type_and_quals(typenode.type)
+        # the 'quals' on the result type are ignored. HACK: we abuse them
+ # to detect __stdcall functions: we textually replace "__stdcall"
+ # with "volatile volatile const" above.
+ abi = None
+ if hasattr(typenode.type, 'quals'): # else, probable syntax error anyway
+ if typenode.type.quals[-3:] == ['volatile', 'volatile', 'const']:
+ abi = '__stdcall'
+ return model.RawFunctionType(tuple(args), result, ellipsis, abi)
+
+ def _as_func_arg(self, type, quals):
+ if isinstance(type, model.ArrayType):
+ return model.PointerType(type.item, quals)
+ elif isinstance(type, model.RawFunctionType):
+ return type.as_function_pointer()
+ else:
+ return type
+
+ def _get_struct_union_enum_type(self, kind, type, name=None, nested=False):
+ # First, a level of caching on the exact 'type' node of the AST.
+ # This is obscure, but needed because pycparser "unrolls" declarations
+ # such as "typedef struct { } foo_t, *foo_p" and we end up with
+ # an AST that is not a tree, but a DAG, with the "type" node of the
+ # two branches foo_t and foo_p of the trees being the same node.
+ # It's a bit silly but detecting "DAG-ness" in the AST tree seems
+ # to be the only way to distinguish this case from two independent
+ # structs. See test_struct_with_two_usages.
+ try:
+ return self._structnode2type[type]
+ except KeyError:
+ pass
+ #
+ # Note that this must handle parsing "struct foo" any number of
+ # times and always return the same StructType object. Additionally,
+ # one of these times (not necessarily the first), the fields of
+ # the struct can be specified with "struct foo { ...fields... }".
+ # If no name is given, then we have to create a new anonymous struct
+ # with no caching; in this case, the fields are either specified
+ # right now or never.
+ #
+ force_name = name
+ name = type.name
+ #
+ # get the type or create it if needed
+ if name is None:
+ # 'force_name' is used to guess a more readable name for
+ # anonymous structs, for the common case "typedef struct { } foo".
+ if force_name is not None:
+ explicit_name = '$%s' % force_name
+ else:
+ self._anonymous_counter += 1
+ explicit_name = '$%d' % self._anonymous_counter
+ tp = None
+ else:
+ explicit_name = name
+ key = '%s %s' % (kind, name)
+ tp, _ = self._declarations.get(key, (None, None))
+ #
+ if tp is None:
+ if kind == 'struct':
+ tp = model.StructType(explicit_name, None, None, None)
+ elif kind == 'union':
+ tp = model.UnionType(explicit_name, None, None, None)
+ elif kind == 'enum':
+ if explicit_name == '__dotdotdot__':
+ raise CDefError("Enums cannot be declared with ...")
+ tp = self._build_enum_type(explicit_name, type.values)
+ else:
+ raise AssertionError("kind = %r" % (kind,))
+ if name is not None:
+ self._declare(key, tp)
+ else:
+ if kind == 'enum' and type.values is not None:
+ raise NotImplementedError(
+ "enum %s: the '{}' declaration should appear on the first "
+ "time the enum is mentioned, not later" % explicit_name)
+ if not tp.forcename:
+ tp.force_the_name(force_name)
+ if tp.forcename and '$' in tp.name:
+ self._declare('anonymous %s' % tp.forcename, tp)
+ #
+ self._structnode2type[type] = tp
+ #
+ # enums: done here
+ if kind == 'enum':
+ return tp
+ #
+ # is there a 'type.decls'? If yes, then this is the place in the
+ # C sources that declare the fields. If no, then just return the
+ # existing type, possibly still incomplete.
+ if type.decls is None:
+ return tp
+ #
+ if tp.fldnames is not None:
+ raise CDefError("duplicate declaration of struct %s" % name)
+ fldnames = []
+ fldtypes = []
+ fldbitsize = []
+ fldquals = []
+ for decl in type.decls:
+ if (isinstance(decl.type, pycparser.c_ast.IdentifierType) and
+ ''.join(decl.type.names) == '__dotdotdot__'):
+ # XXX pycparser is inconsistent: 'names' should be a list
+ # of strings, but is sometimes just one string. Use
+ # str.join() as a way to cope with both.
+ self._make_partial(tp, nested)
+ continue
+ if decl.bitsize is None:
+ bitsize = -1
+ else:
+ bitsize = self._parse_constant(decl.bitsize)
+ self._partial_length = False
+ type, fqual = self._get_type_and_quals(decl.type,
+ partial_length_ok=True)
+ if self._partial_length:
+ self._make_partial(tp, nested)
+ if isinstance(type, model.StructType) and type.partial:
+ self._make_partial(tp, nested)
+ fldnames.append(decl.name or '')
+ fldtypes.append(type)
+ fldbitsize.append(bitsize)
+ fldquals.append(fqual)
+ tp.fldnames = tuple(fldnames)
+ tp.fldtypes = tuple(fldtypes)
+ tp.fldbitsize = tuple(fldbitsize)
+ tp.fldquals = tuple(fldquals)
+ if fldbitsize != [-1] * len(fldbitsize):
+ if isinstance(tp, model.StructType) and tp.partial:
+ raise NotImplementedError("%s: using both bitfields and '...;'"
+ % (tp,))
+ tp.packed = self._options.get('packed')
+ if tp.completed: # must be re-completed: it is not opaque any more
+ tp.completed = 0
+ self._recomplete.append(tp)
+ return tp
+
+ def _make_partial(self, tp, nested):
+ if not isinstance(tp, model.StructOrUnion):
+ raise CDefError("%s cannot be partial" % (tp,))
+ if not tp.has_c_name() and not nested:
+ raise NotImplementedError("%s is partial but has no C name" %(tp,))
+ tp.partial = True
+
+ def _parse_constant(self, exprnode, partial_length_ok=False):
+ # for now, limited to expressions that are an immediate number
+ # or positive/negative number
+ if isinstance(exprnode, pycparser.c_ast.Constant):
+ s = exprnode.value
+ if '0' <= s[0] <= '9':
+ s = s.rstrip('uUlL')
+ try:
+ if s.startswith('0'):
+ return int(s, 8)
+ else:
+ return int(s, 10)
+ except ValueError:
+ if len(s) > 1:
+ if s.lower()[0:2] == '0x':
+ return int(s, 16)
+ elif s.lower()[0:2] == '0b':
+ return int(s, 2)
+ raise CDefError("invalid constant %r" % (s,))
+ elif s[0] == "'" and s[-1] == "'" and (
+ len(s) == 3 or (len(s) == 4 and s[1] == "\\")):
+ return ord(s[-2])
+ else:
+ raise CDefError("invalid constant %r" % (s,))
+ #
+ if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and
+ exprnode.op == '+'):
+ return self._parse_constant(exprnode.expr)
+ #
+ if (isinstance(exprnode, pycparser.c_ast.UnaryOp) and
+ exprnode.op == '-'):
+ return -self._parse_constant(exprnode.expr)
+ # load previously defined int constant
+ if (isinstance(exprnode, pycparser.c_ast.ID) and
+ exprnode.name in self._int_constants):
+ return self._int_constants[exprnode.name]
+ #
+ if (isinstance(exprnode, pycparser.c_ast.ID) and
+ exprnode.name == '__dotdotdotarray__'):
+ if partial_length_ok:
+ self._partial_length = True
+ return '...'
+ raise FFIError(":%d: unsupported '[...]' here, cannot derive "
+ "the actual array length in this context"
+ % exprnode.coord.line)
+ #
+ if isinstance(exprnode, pycparser.c_ast.BinaryOp):
+ left = self._parse_constant(exprnode.left)
+ right = self._parse_constant(exprnode.right)
+ if exprnode.op == '+':
+ return left + right
+ elif exprnode.op == '-':
+ return left - right
+ elif exprnode.op == '*':
+ return left * right
+ elif exprnode.op == '/':
+ return self._c_div(left, right)
+ elif exprnode.op == '%':
+ return left - self._c_div(left, right) * right
+ elif exprnode.op == '<<':
+ return left << right
+ elif exprnode.op == '>>':
+ return left >> right
+ elif exprnode.op == '&':
+ return left & right
+ elif exprnode.op == '|':
+ return left | right
+ elif exprnode.op == '^':
+ return left ^ right
+ #
+ raise FFIError(":%d: unsupported expression: expected a "
+ "simple numeric constant" % exprnode.coord.line)
+
+ def _c_div(self, a, b):
+ result = a // b
+ if ((a < 0) ^ (b < 0)) and (a % b) != 0:
+ result += 1
+ return result
+
+ def _build_enum_type(self, explicit_name, decls):
+ if decls is not None:
+ partial = False
+ enumerators = []
+ enumvalues = []
+ nextenumvalue = 0
+ for enum in decls.enumerators:
+ if _r_enum_dotdotdot.match(enum.name):
+ partial = True
+ continue
+ if enum.value is not None:
+ nextenumvalue = self._parse_constant(enum.value)
+ enumerators.append(enum.name)
+ enumvalues.append(nextenumvalue)
+ self._add_constants(enum.name, nextenumvalue)
+ nextenumvalue += 1
+ enumerators = tuple(enumerators)
+ enumvalues = tuple(enumvalues)
+ tp = model.EnumType(explicit_name, enumerators, enumvalues)
+ tp.partial = partial
+ else: # opaque enum
+ tp = model.EnumType(explicit_name, (), ())
+ return tp
+
+ def include(self, other):
+ for name, (tp, quals) in other._declarations.items():
+ if name.startswith('anonymous $enum_$'):
+ continue # fix for test_anonymous_enum_include
+ kind = name.split(' ', 1)[0]
+ if kind in ('struct', 'union', 'enum', 'anonymous', 'typedef'):
+ self._declare(name, tp, included=True, quals=quals)
+ for k, v in other._int_constants.items():
+ self._add_constants(k, v)
+
+ def _get_unknown_type(self, decl):
+ typenames = decl.type.type.names
+ if typenames == ['__dotdotdot__']:
+ return model.unknown_type(decl.name)
+
+ if typenames == ['__dotdotdotint__']:
+ if self._uses_new_feature is None:
+ self._uses_new_feature = "'typedef int... %s'" % decl.name
+ return model.UnknownIntegerType(decl.name)
+
+ if typenames == ['__dotdotdotfloat__']:
+ # note: not for 'long double' so far
+ if self._uses_new_feature is None:
+ self._uses_new_feature = "'typedef float... %s'" % decl.name
+ return model.UnknownFloatType(decl.name)
+
+ raise FFIError(':%d: unsupported usage of "..." in typedef'
+ % decl.coord.line)
+
+ def _get_unknown_ptr_type(self, decl):
+ if decl.type.type.type.names == ['__dotdotdot__']:
+ return model.unknown_ptr_type(decl.name)
+ raise FFIError(':%d: unsupported usage of "..." in typedef'
+ % decl.coord.line)
diff --git a/lib/cffi/error.py b/lib/cffi/error.py
new file mode 100644
index 0000000..0a27247
--- /dev/null
+++ b/lib/cffi/error.py
@@ -0,0 +1,31 @@
+
+class FFIError(Exception):
+ __module__ = 'cffi'
+
+class CDefError(Exception):
+ __module__ = 'cffi'
+ def __str__(self):
+ try:
+ current_decl = self.args[1]
+ filename = current_decl.coord.file
+ linenum = current_decl.coord.line
+ prefix = '%s:%d: ' % (filename, linenum)
+ except (AttributeError, TypeError, IndexError):
+ prefix = ''
+ return '%s%s' % (prefix, self.args[0])
+
+class VerificationError(Exception):
+ """ An error raised when verification fails
+ """
+ __module__ = 'cffi'
+
+class VerificationMissing(Exception):
+ """ An error raised when incomplete structures are passed into
+ cdef, but no verification has been done
+ """
+ __module__ = 'cffi'
+
+class PkgConfigError(Exception):
+ """ An error raised for missing modules in pkg-config
+ """
+ __module__ = 'cffi'
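A hedged usage sketch of CDefError.__str__ from the file above: when the error is raised with a second argument that carries a pycparser-style .coord, the message is prefixed with "file:line: "; otherwise the prefix is silently dropped. The Fake* classes are hypothetical stand-ins, and the import assumes this lib/cffi copy is on the path.

    from cffi.error import CDefError

    class FakeCoord:                     # stand-in for a pycparser coord
        file = "header.h"
        line = 42

    class FakeDecl:                      # stand-in for a declaration node
        coord = FakeCoord()

    print(str(CDefError("unknown type name", FakeDecl())))
    # -> 'header.h:42: unknown type name'
    print(str(CDefError("unknown type name")))
    # -> 'unknown type name'   (no coord information available)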
diff --git a/lib/cffi/ffiplatform.py b/lib/cffi/ffiplatform.py
new file mode 100644
index 0000000..8531346
--- /dev/null
+++ b/lib/cffi/ffiplatform.py
@@ -0,0 +1,127 @@
+import sys, os
+from .error import VerificationError
+
+
+LIST_OF_FILE_NAMES = ['sources', 'include_dirs', 'library_dirs',
+ 'extra_objects', 'depends']
+
+def get_extension(srcfilename, modname, sources=(), **kwds):
+ _hack_at_distutils()
+ from distutils.core import Extension
+ allsources = [srcfilename]
+ for src in sources:
+ allsources.append(os.path.normpath(src))
+ return Extension(name=modname, sources=allsources, **kwds)
+
+def compile(tmpdir, ext, compiler_verbose=0, debug=None):
+ """Compile a C extension module using distutils."""
+
+ _hack_at_distutils()
+ saved_environ = os.environ.copy()
+ try:
+ outputfilename = _build(tmpdir, ext, compiler_verbose, debug)
+ outputfilename = os.path.abspath(outputfilename)
+ finally:
+ # workaround for a distutils bug where some env vars can
+ # become longer and longer every time distutils is used
+ for key, value in saved_environ.items():
+ if os.environ.get(key) != value:
+ os.environ[key] = value
+ return outputfilename
+
+def _build(tmpdir, ext, compiler_verbose=0, debug=None):
+ # XXX compact but horrible :-(
+ from distutils.core import Distribution
+ import distutils.errors, distutils.log
+ #
+ dist = Distribution({'ext_modules': [ext]})
+ dist.parse_config_files()
+ options = dist.get_option_dict('build_ext')
+ if debug is None:
+ debug = sys.flags.debug
+ options['debug'] = ('ffiplatform', debug)
+ options['force'] = ('ffiplatform', True)
+ options['build_lib'] = ('ffiplatform', tmpdir)
+ options['build_temp'] = ('ffiplatform', tmpdir)
+ #
+ try:
+ old_level = distutils.log.set_threshold(0) or 0
+ try:
+ distutils.log.set_verbosity(compiler_verbose)
+ dist.run_command('build_ext')
+ cmd_obj = dist.get_command_obj('build_ext')
+ [soname] = cmd_obj.get_outputs()
+ finally:
+ distutils.log.set_threshold(old_level)
+ except (distutils.errors.CompileError,
+ distutils.errors.LinkError) as e:
+ raise VerificationError('%s: %s' % (e.__class__.__name__, e))
+ #
+ return soname
+
+try:
+ from os.path import samefile
+except ImportError:
+ def samefile(f1, f2):
+ return os.path.abspath(f1) == os.path.abspath(f2)
+
+def maybe_relative_path(path):
+ if not os.path.isabs(path):
+ return path # already relative
+ dir = path
+ names = []
+ while True:
+ prevdir = dir
+ dir, name = os.path.split(prevdir)
+ if dir == prevdir or not dir:
+ return path # failed to make it relative
+ names.append(name)
+ try:
+ if samefile(dir, os.curdir):
+ names.reverse()
+ return os.path.join(*names)
+ except OSError:
+ pass
+
+# ____________________________________________________________
+
+try:
+ int_or_long = (int, long)
+ import cStringIO
+except NameError:
+ int_or_long = int # Python 3
+ import io as cStringIO
+
+def _flatten(x, f):
+ if isinstance(x, str):
+ f.write('%ds%s' % (len(x), x))
+ elif isinstance(x, dict):
+ keys = sorted(x.keys())
+ f.write('%dd' % len(keys))
+ for key in keys:
+ _flatten(key, f)
+ _flatten(x[key], f)
+ elif isinstance(x, (list, tuple)):
+ f.write('%dl' % len(x))
+ for value in x:
+ _flatten(value, f)
+ elif isinstance(x, int_or_long):
+ f.write('%di' % (x,))
+ else:
+ raise TypeError(
+ "the keywords to verify() contains unsupported object %r" % (x,))
+
+def flatten(x):
+ f = cStringIO.StringIO()
+ _flatten(x, f)
+ return f.getvalue()
+
+def _hack_at_distutils():
+ # Windows-only workaround for some configurations: see
+ # https://bugs.python.org/issue23246 (Python 2.7 with
+ # a specific MS compiler suite download)
+ if sys.platform == "win32":
+ try:
+ import setuptools # for side-effects, patches distutils
+ except ImportError:
+ pass
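As an aside (not part of the diff), flatten() above serializes the keyword arguments of verify() into a flat, order-independent string, with every element length-prefixed and dict keys sorted, so the result can be hashed into a stable module name. A small illustration, assuming this lib/cffi copy is importable:

    from cffi.ffiplatform import flatten

    key = flatten({"sources": ["foo.c"], "libraries": ["m"]})
    # '2d' opens a two-key dict, '9slibraries' is the 9-char string
    # 'libraries', '1l' a one-element list, and so on.
    print(key)   # -> '2d9slibraries1l1sm7ssources1l5sfoo.c'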
diff --git a/lib/cffi/lock.py b/lib/cffi/lock.py
new file mode 100644
index 0000000..db91b71
--- /dev/null
+++ b/lib/cffi/lock.py
@@ -0,0 +1,30 @@
+import sys
+
+if sys.version_info < (3,):
+ try:
+ from thread import allocate_lock
+ except ImportError:
+ from dummy_thread import allocate_lock
+else:
+ try:
+ from _thread import allocate_lock
+ except ImportError:
+ from _dummy_thread import allocate_lock
+
+
+##import sys
+##l1 = allocate_lock
+
+##class allocate_lock(object):
+## def __init__(self):
+## self._real = l1()
+## def __enter__(self):
+## for i in range(4, 0, -1):
+## print sys._getframe(i).f_code
+## print
+## return self._real.__enter__()
+## def __exit__(self, *args):
+## return self._real.__exit__(*args)
+## def acquire(self, f):
+## assert f is False
+## return self._real.acquire(f)
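lock.py only re-exports the interpreter's low-level lock under a single name for both Python 2 and 3; model.py below uses it as a plain context manager. A minimal hedged sketch, assuming this lib/cffi copy is importable:

    from cffi.lock import allocate_lock

    _lock = allocate_lock()
    with _lock:        # the same lock type the threading module builds on
        pass           # e.g. guard a shared cache, as model.global_cache() does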
diff --git a/lib/cffi/model.py b/lib/cffi/model.py
new file mode 100644
index 0000000..ad1c176
--- /dev/null
+++ b/lib/cffi/model.py
@@ -0,0 +1,617 @@
+import types
+import weakref
+
+from .lock import allocate_lock
+from .error import CDefError, VerificationError, VerificationMissing
+
+# type qualifiers
+Q_CONST = 0x01
+Q_RESTRICT = 0x02
+Q_VOLATILE = 0x04
+
+def qualify(quals, replace_with):
+ if quals & Q_CONST:
+ replace_with = ' const ' + replace_with.lstrip()
+ if quals & Q_VOLATILE:
+ replace_with = ' volatile ' + replace_with.lstrip()
+ if quals & Q_RESTRICT:
+ # It seems that __restrict is supported by gcc and msvc.
+ # If you hit some different compiler, add a #define in
+ # _cffi_include.h for it (and in its copies, documented there)
+ replace_with = ' __restrict ' + replace_with.lstrip()
+ return replace_with
+
+
+class BaseTypeByIdentity(object):
+ is_array_type = False
+ is_raw_function = False
+
+ def get_c_name(self, replace_with='', context='a C file', quals=0):
+ result = self.c_name_with_marker
+ assert result.count('&') == 1
+ # some logic duplication with ffi.getctype()... :-(
+ replace_with = replace_with.strip()
+ if replace_with:
+ if replace_with.startswith('*') and '&[' in result:
+ replace_with = '(%s)' % replace_with
+ elif not replace_with[0] in '[(':
+ replace_with = ' ' + replace_with
+ replace_with = qualify(quals, replace_with)
+ result = result.replace('&', replace_with)
+ if '$' in result:
+ raise VerificationError(
+ "cannot generate '%s' in %s: unknown type name"
+ % (self._get_c_name(), context))
+ return result
+
+ def _get_c_name(self):
+ return self.c_name_with_marker.replace('&', '')
+
+ def has_c_name(self):
+ return '$' not in self._get_c_name()
+
+ def is_integer_type(self):
+ return False
+
+ def get_cached_btype(self, ffi, finishlist, can_delay=False):
+ try:
+ BType = ffi._cached_btypes[self]
+ except KeyError:
+ BType = self.build_backend_type(ffi, finishlist)
+ BType2 = ffi._cached_btypes.setdefault(self, BType)
+ assert BType2 is BType
+ return BType
+
+ def __repr__(self):
+ return '<%s>' % (self._get_c_name(),)
+
+ def _get_items(self):
+ return [(name, getattr(self, name)) for name in self._attrs_]
+
+
+class BaseType(BaseTypeByIdentity):
+
+ def __eq__(self, other):
+ return (self.__class__ == other.__class__ and
+ self._get_items() == other._get_items())
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return hash((self.__class__, tuple(self._get_items())))
+
+
+class VoidType(BaseType):
+ _attrs_ = ()
+
+ def __init__(self):
+ self.c_name_with_marker = 'void&'
+
+ def build_backend_type(self, ffi, finishlist):
+ return global_cache(self, ffi, 'new_void_type')
+
+void_type = VoidType()
+
+
+class BasePrimitiveType(BaseType):
+ def is_complex_type(self):
+ return False
+
+
+class PrimitiveType(BasePrimitiveType):
+ _attrs_ = ('name',)
+
+ ALL_PRIMITIVE_TYPES = {
+ 'char': 'c',
+ 'short': 'i',
+ 'int': 'i',
+ 'long': 'i',
+ 'long long': 'i',
+ 'signed char': 'i',
+ 'unsigned char': 'i',
+ 'unsigned short': 'i',
+ 'unsigned int': 'i',
+ 'unsigned long': 'i',
+ 'unsigned long long': 'i',
+ 'float': 'f',
+ 'double': 'f',
+ 'long double': 'f',
+ 'float _Complex': 'j',
+ 'double _Complex': 'j',
+ '_Bool': 'i',
+ # the following types are not primitive in the C sense
+ 'wchar_t': 'c',
+ 'char16_t': 'c',
+ 'char32_t': 'c',
+ 'int8_t': 'i',
+ 'uint8_t': 'i',
+ 'int16_t': 'i',
+ 'uint16_t': 'i',
+ 'int32_t': 'i',
+ 'uint32_t': 'i',
+ 'int64_t': 'i',
+ 'uint64_t': 'i',
+ 'int_least8_t': 'i',
+ 'uint_least8_t': 'i',
+ 'int_least16_t': 'i',
+ 'uint_least16_t': 'i',
+ 'int_least32_t': 'i',
+ 'uint_least32_t': 'i',
+ 'int_least64_t': 'i',
+ 'uint_least64_t': 'i',
+ 'int_fast8_t': 'i',
+ 'uint_fast8_t': 'i',
+ 'int_fast16_t': 'i',
+ 'uint_fast16_t': 'i',
+ 'int_fast32_t': 'i',
+ 'uint_fast32_t': 'i',
+ 'int_fast64_t': 'i',
+ 'uint_fast64_t': 'i',
+ 'intptr_t': 'i',
+ 'uintptr_t': 'i',
+ 'intmax_t': 'i',
+ 'uintmax_t': 'i',
+ 'ptrdiff_t': 'i',
+ 'size_t': 'i',
+ 'ssize_t': 'i',
+ }
+
+ def __init__(self, name):
+ assert name in self.ALL_PRIMITIVE_TYPES
+ self.name = name
+ self.c_name_with_marker = name + '&'
+
+ def is_char_type(self):
+ return self.ALL_PRIMITIVE_TYPES[self.name] == 'c'
+ def is_integer_type(self):
+ return self.ALL_PRIMITIVE_TYPES[self.name] == 'i'
+ def is_float_type(self):
+ return self.ALL_PRIMITIVE_TYPES[self.name] == 'f'
+ def is_complex_type(self):
+ return self.ALL_PRIMITIVE_TYPES[self.name] == 'j'
+
+ def build_backend_type(self, ffi, finishlist):
+ return global_cache(self, ffi, 'new_primitive_type', self.name)
+
+
+class UnknownIntegerType(BasePrimitiveType):
+ _attrs_ = ('name',)
+
+ def __init__(self, name):
+ self.name = name
+ self.c_name_with_marker = name + '&'
+
+ def is_integer_type(self):
+ return True
+
+ def build_backend_type(self, ffi, finishlist):
+ raise NotImplementedError("integer type '%s' can only be used after "
+ "compilation" % self.name)
+
+class UnknownFloatType(BasePrimitiveType):
+ _attrs_ = ('name', )
+
+ def __init__(self, name):
+ self.name = name
+ self.c_name_with_marker = name + '&'
+
+ def build_backend_type(self, ffi, finishlist):
+ raise NotImplementedError("float type '%s' can only be used after "
+ "compilation" % self.name)
+
+
+class BaseFunctionType(BaseType):
+ _attrs_ = ('args', 'result', 'ellipsis', 'abi')
+
+ def __init__(self, args, result, ellipsis, abi=None):
+ self.args = args
+ self.result = result
+ self.ellipsis = ellipsis
+ self.abi = abi
+ #
+ reprargs = [arg._get_c_name() for arg in self.args]
+ if self.ellipsis:
+ reprargs.append('...')
+ reprargs = reprargs or ['void']
+ replace_with = self._base_pattern % (', '.join(reprargs),)
+ if abi is not None:
+ replace_with = replace_with[:1] + abi + ' ' + replace_with[1:]
+ self.c_name_with_marker = (
+ self.result.c_name_with_marker.replace('&', replace_with))
+
+
+class RawFunctionType(BaseFunctionType):
+ # Corresponds to a C type like 'int(int)', which is the C type of
+ # a function, but not a pointer-to-function. The backend has no
+ # notion of such a type; it's used temporarily by parsing.
+ _base_pattern = '(&)(%s)'
+ is_raw_function = True
+
+ def build_backend_type(self, ffi, finishlist):
+ raise CDefError("cannot render the type %r: it is a function "
+ "type, not a pointer-to-function type" % (self,))
+
+ def as_function_pointer(self):
+ return FunctionPtrType(self.args, self.result, self.ellipsis, self.abi)
+
+
+class FunctionPtrType(BaseFunctionType):
+ _base_pattern = '(*&)(%s)'
+
+ def build_backend_type(self, ffi, finishlist):
+ result = self.result.get_cached_btype(ffi, finishlist)
+ args = []
+ for tp in self.args:
+ args.append(tp.get_cached_btype(ffi, finishlist))
+ abi_args = ()
+ if self.abi == "__stdcall":
+ if not self.ellipsis: # __stdcall ignored for variadic funcs
+ try:
+ abi_args = (ffi._backend.FFI_STDCALL,)
+ except AttributeError:
+ pass
+ return global_cache(self, ffi, 'new_function_type',
+ tuple(args), result, self.ellipsis, *abi_args)
+
+ def as_raw_function(self):
+ return RawFunctionType(self.args, self.result, self.ellipsis, self.abi)
+
+
+class PointerType(BaseType):
+ _attrs_ = ('totype', 'quals')
+
+ def __init__(self, totype, quals=0):
+ self.totype = totype
+ self.quals = quals
+ extra = qualify(quals, " *&")
+ if totype.is_array_type:
+ extra = "(%s)" % (extra.lstrip(),)
+ self.c_name_with_marker = totype.c_name_with_marker.replace('&', extra)
+
+ def build_backend_type(self, ffi, finishlist):
+ BItem = self.totype.get_cached_btype(ffi, finishlist, can_delay=True)
+ return global_cache(self, ffi, 'new_pointer_type', BItem)
+
+voidp_type = PointerType(void_type)
+
+def ConstPointerType(totype):
+ return PointerType(totype, Q_CONST)
+
+const_voidp_type = ConstPointerType(void_type)
+
+
+class NamedPointerType(PointerType):
+ _attrs_ = ('totype', 'name')
+
+ def __init__(self, totype, name, quals=0):
+ PointerType.__init__(self, totype, quals)
+ self.name = name
+ self.c_name_with_marker = name + '&'
+
+
+class ArrayType(BaseType):
+ _attrs_ = ('item', 'length')
+ is_array_type = True
+
+ def __init__(self, item, length):
+ self.item = item
+ self.length = length
+ #
+ if length is None:
+ brackets = '&[]'
+ elif length == '...':
+ brackets = '&[/*...*/]'
+ else:
+ brackets = '&[%s]' % length
+ self.c_name_with_marker = (
+ self.item.c_name_with_marker.replace('&', brackets))
+
+ def length_is_unknown(self):
+ return isinstance(self.length, str)
+
+ def resolve_length(self, newlength):
+ return ArrayType(self.item, newlength)
+
+ def build_backend_type(self, ffi, finishlist):
+ if self.length_is_unknown():
+ raise CDefError("cannot render the type %r: unknown length" %
+ (self,))
+ self.item.get_cached_btype(ffi, finishlist) # force the item BType
+ BPtrItem = PointerType(self.item).get_cached_btype(ffi, finishlist)
+ return global_cache(self, ffi, 'new_array_type', BPtrItem, self.length)
+
+char_array_type = ArrayType(PrimitiveType('char'), None)
+
+
+class StructOrUnionOrEnum(BaseTypeByIdentity):
+ _attrs_ = ('name',)
+ forcename = None
+
+ def build_c_name_with_marker(self):
+ name = self.forcename or '%s %s' % (self.kind, self.name)
+ self.c_name_with_marker = name + '&'
+
+ def force_the_name(self, forcename):
+ self.forcename = forcename
+ self.build_c_name_with_marker()
+
+ def get_official_name(self):
+ assert self.c_name_with_marker.endswith('&')
+ return self.c_name_with_marker[:-1]
+
+
+class StructOrUnion(StructOrUnionOrEnum):
+ fixedlayout = None
+ completed = 0
+ partial = False
+ packed = 0
+
+ def __init__(self, name, fldnames, fldtypes, fldbitsize, fldquals=None):
+ self.name = name
+ self.fldnames = fldnames
+ self.fldtypes = fldtypes
+ self.fldbitsize = fldbitsize
+ self.fldquals = fldquals
+ self.build_c_name_with_marker()
+
+ def anonymous_struct_fields(self):
+ if self.fldtypes is not None:
+ for name, type in zip(self.fldnames, self.fldtypes):
+ if name == '' and isinstance(type, StructOrUnion):
+ yield type
+
+ def enumfields(self, expand_anonymous_struct_union=True):
+ fldquals = self.fldquals
+ if fldquals is None:
+ fldquals = (0,) * len(self.fldnames)
+ for name, type, bitsize, quals in zip(self.fldnames, self.fldtypes,
+ self.fldbitsize, fldquals):
+ if (name == '' and isinstance(type, StructOrUnion)
+ and expand_anonymous_struct_union):
+ # nested anonymous struct/union
+ for result in type.enumfields():
+ yield result
+ else:
+ yield (name, type, bitsize, quals)
+
+ def force_flatten(self):
+ # force the struct or union to have a declaration that lists
+ # directly all fields returned by enumfields(), flattening
+ # nested anonymous structs/unions.
+ names = []
+ types = []
+ bitsizes = []
+ fldquals = []
+ for name, type, bitsize, quals in self.enumfields():
+ names.append(name)
+ types.append(type)
+ bitsizes.append(bitsize)
+ fldquals.append(quals)
+ self.fldnames = tuple(names)
+ self.fldtypes = tuple(types)
+ self.fldbitsize = tuple(bitsizes)
+ self.fldquals = tuple(fldquals)
+
+ def get_cached_btype(self, ffi, finishlist, can_delay=False):
+ BType = StructOrUnionOrEnum.get_cached_btype(self, ffi, finishlist,
+ can_delay)
+ if not can_delay:
+ self.finish_backend_type(ffi, finishlist)
+ return BType
+
+ def finish_backend_type(self, ffi, finishlist):
+ if self.completed:
+ if self.completed != 2:
+ raise NotImplementedError("recursive structure declaration "
+ "for '%s'" % (self.name,))
+ return
+ BType = ffi._cached_btypes[self]
+ #
+ self.completed = 1
+ #
+ if self.fldtypes is None:
+ pass # not completing it: it's an opaque struct
+ #
+ elif self.fixedlayout is None:
+ fldtypes = [tp.get_cached_btype(ffi, finishlist)
+ for tp in self.fldtypes]
+ lst = list(zip(self.fldnames, fldtypes, self.fldbitsize))
+ extra_flags = ()
+ if self.packed:
+ if self.packed == 1:
+ extra_flags = (8,) # SF_PACKED
+ else:
+ extra_flags = (0, self.packed)
+ ffi._backend.complete_struct_or_union(BType, lst, self,
+ -1, -1, *extra_flags)
+ #
+ else:
+ fldtypes = []
+ fieldofs, fieldsize, totalsize, totalalignment = self.fixedlayout
+ for i in range(len(self.fldnames)):
+ fsize = fieldsize[i]
+ ftype = self.fldtypes[i]
+ #
+ if isinstance(ftype, ArrayType) and ftype.length_is_unknown():
+ # fix the length to match the total size
+ BItemType = ftype.item.get_cached_btype(ffi, finishlist)
+ nlen, nrest = divmod(fsize, ffi.sizeof(BItemType))
+ if nrest != 0:
+ self._verification_error(
+ "field '%s.%s' has a bogus size?" % (
+ self.name, self.fldnames[i] or '{}'))
+ ftype = ftype.resolve_length(nlen)
+ self.fldtypes = (self.fldtypes[:i] + (ftype,) +
+ self.fldtypes[i+1:])
+ #
+ BFieldType = ftype.get_cached_btype(ffi, finishlist)
+ if isinstance(ftype, ArrayType) and ftype.length is None:
+ assert fsize == 0
+ else:
+ bitemsize = ffi.sizeof(BFieldType)
+ if bitemsize != fsize:
+ self._verification_error(
+ "field '%s.%s' is declared as %d bytes, but is "
+ "really %d bytes" % (self.name,
+ self.fldnames[i] or '{}',
+ bitemsize, fsize))
+ fldtypes.append(BFieldType)
+ #
+ lst = list(zip(self.fldnames, fldtypes, self.fldbitsize, fieldofs))
+ ffi._backend.complete_struct_or_union(BType, lst, self,
+ totalsize, totalalignment)
+ self.completed = 2
+
+ def _verification_error(self, msg):
+ raise VerificationError(msg)
+
+ def check_not_partial(self):
+ if self.partial and self.fixedlayout is None:
+ raise VerificationMissing(self._get_c_name())
+
+ def build_backend_type(self, ffi, finishlist):
+ self.check_not_partial()
+ finishlist.append(self)
+ #
+ return global_cache(self, ffi, 'new_%s_type' % self.kind,
+ self.get_official_name(), key=self)
+
+
+class StructType(StructOrUnion):
+ kind = 'struct'
+
+
+class UnionType(StructOrUnion):
+ kind = 'union'
+
+
+class EnumType(StructOrUnionOrEnum):
+ kind = 'enum'
+ partial = False
+ partial_resolved = False
+
+ def __init__(self, name, enumerators, enumvalues, baseinttype=None):
+ self.name = name
+ self.enumerators = enumerators
+ self.enumvalues = enumvalues
+ self.baseinttype = baseinttype
+ self.build_c_name_with_marker()
+
+ def force_the_name(self, forcename):
+ StructOrUnionOrEnum.force_the_name(self, forcename)
+ if self.forcename is None:
+ name = self.get_official_name()
+ self.forcename = '$' + name.replace(' ', '_')
+
+ def check_not_partial(self):
+ if self.partial and not self.partial_resolved:
+ raise VerificationMissing(self._get_c_name())
+
+ def build_backend_type(self, ffi, finishlist):
+ self.check_not_partial()
+ base_btype = self.build_baseinttype(ffi, finishlist)
+ return global_cache(self, ffi, 'new_enum_type',
+ self.get_official_name(),
+ self.enumerators, self.enumvalues,
+ base_btype, key=self)
+
+ def build_baseinttype(self, ffi, finishlist):
+ if self.baseinttype is not None:
+ return self.baseinttype.get_cached_btype(ffi, finishlist)
+ #
+ if self.enumvalues:
+ smallest_value = min(self.enumvalues)
+ largest_value = max(self.enumvalues)
+ else:
+ import warnings
+ try:
+ # XXX! The goal is to ensure that the warnings.warn()
+ # will not suppress the warning. We want to get it
+ # several times if we reach this point several times.
+ __warningregistry__.clear()
+ except NameError:
+ pass
+ warnings.warn("%r has no values explicitly defined; "
+ "guessing that it is equivalent to 'unsigned int'"
+ % self._get_c_name())
+ smallest_value = largest_value = 0
+ if smallest_value < 0: # needs a signed type
+ sign = 1
+ candidate1 = PrimitiveType("int")
+ candidate2 = PrimitiveType("long")
+ else:
+ sign = 0
+ candidate1 = PrimitiveType("unsigned int")
+ candidate2 = PrimitiveType("unsigned long")
+ btype1 = candidate1.get_cached_btype(ffi, finishlist)
+ btype2 = candidate2.get_cached_btype(ffi, finishlist)
+ size1 = ffi.sizeof(btype1)
+ size2 = ffi.sizeof(btype2)
+ if (smallest_value >= ((-1) << (8*size1-1)) and
+ largest_value < (1 << (8*size1-sign))):
+ return btype1
+ if (smallest_value >= ((-1) << (8*size2-1)) and
+ largest_value < (1 << (8*size2-sign))):
+ return btype2
+ raise CDefError("%s values don't all fit into either 'long' "
+ "or 'unsigned long'" % self._get_c_name())
+
+def unknown_type(name, structname=None):
+ if structname is None:
+ structname = '$%s' % name
+ tp = StructType(structname, None, None, None)
+ tp.force_the_name(name)
+ tp.origin = "unknown_type"
+ return tp
+
+def unknown_ptr_type(name, structname=None):
+ if structname is None:
+ structname = '$$%s' % name
+ tp = StructType(structname, None, None, None)
+ return NamedPointerType(tp, name)
+
+
+global_lock = allocate_lock()
+_typecache_cffi_backend = weakref.WeakValueDictionary()
+
+def get_typecache(backend):
+ # returns _typecache_cffi_backend if backend is the _cffi_backend
+ # module, or type(backend).__typecache if backend is an instance of
+ # CTypesBackend (or some FakeBackend class during tests)
+ if isinstance(backend, types.ModuleType):
+ return _typecache_cffi_backend
+ with global_lock:
+ if not hasattr(type(backend), '__typecache'):
+ type(backend).__typecache = weakref.WeakValueDictionary()
+ return type(backend).__typecache
+
+def global_cache(srctype, ffi, funcname, *args, **kwds):
+ key = kwds.pop('key', (funcname, args))
+ assert not kwds
+ try:
+ return ffi._typecache[key]
+ except KeyError:
+ pass
+ try:
+ res = getattr(ffi._backend, funcname)(*args)
+ except NotImplementedError as e:
+ raise NotImplementedError("%s: %r: %s" % (funcname, srctype, e))
+ # note that setdefault() on WeakValueDictionary is not atomic
+ # and contains a rare bug (http://bugs.python.org/issue19542);
+ # we have to use a lock and do it ourselves
+ cache = ffi._typecache
+ with global_lock:
+ res1 = cache.get(key)
+ if res1 is None:
+ cache[key] = res
+ return res
+ else:
+ return res1
+
+def pointer_cache(ffi, BType):
+ return global_cache('?', ffi, 'new_pointer_type', BType)
+
+def attach_exception_info(e, name):
+ if e.args and type(e.args[0]) is str:
+ e.args = ('%s: %s' % (name, e.args[0]),) + e.args[1:]
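The model.py layer above builds C declarations by keeping a single '&' marker in c_name_with_marker and substituting the variable name (plus any qualifiers) into it via get_c_name(). A short illustration of that marker trick, assuming this lib/cffi copy is importable; it is only a demonstration of the classes shown above, not extra API:

    from cffi import model

    int_t = model.PrimitiveType('int')
    print(model.PointerType(int_t).get_c_name('x'))
    # -> 'int * x'

    arr10 = model.ArrayType(int_t, 10)
    print(model.PointerType(arr10).get_c_name('p'))
    # -> 'int(* p)[10]'   (pointer to an array of 10 ints; the '&[' case above)

    print(model.PointerType(int_t, model.Q_CONST).get_c_name('p'))
    # -> 'int const * p'  (the Q_CONST qualifier is folded in by qualify())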
diff --git a/lib/cffi/parse_c_type.h b/lib/cffi/parse_c_type.h
new file mode 100644
index 0000000..84e4ef8
--- /dev/null
+++ b/lib/cffi/parse_c_type.h
@@ -0,0 +1,181 @@
+
+/* This part is from file 'cffi/parse_c_type.h'. It is copied at the
+ beginning of C sources generated by CFFI's ffi.set_source(). */
+
+typedef void *_cffi_opcode_t;
+
+#define _CFFI_OP(opcode, arg) (_cffi_opcode_t)(opcode | (((uintptr_t)(arg)) << 8))
+#define _CFFI_GETOP(cffi_opcode) ((unsigned char)(uintptr_t)cffi_opcode)
+#define _CFFI_GETARG(cffi_opcode) (((intptr_t)cffi_opcode) >> 8)
+
+#define _CFFI_OP_PRIMITIVE 1
+#define _CFFI_OP_POINTER 3
+#define _CFFI_OP_ARRAY 5
+#define _CFFI_OP_OPEN_ARRAY 7
+#define _CFFI_OP_STRUCT_UNION 9
+#define _CFFI_OP_ENUM 11
+#define _CFFI_OP_FUNCTION 13
+#define _CFFI_OP_FUNCTION_END 15
+#define _CFFI_OP_NOOP 17
+#define _CFFI_OP_BITFIELD 19
+#define _CFFI_OP_TYPENAME 21
+#define _CFFI_OP_CPYTHON_BLTN_V 23 // varargs
+#define _CFFI_OP_CPYTHON_BLTN_N 25 // noargs
+#define _CFFI_OP_CPYTHON_BLTN_O 27 // O (i.e. a single arg)
+#define _CFFI_OP_CONSTANT 29
+#define _CFFI_OP_CONSTANT_INT 31
+#define _CFFI_OP_GLOBAL_VAR 33
+#define _CFFI_OP_DLOPEN_FUNC 35
+#define _CFFI_OP_DLOPEN_CONST 37
+#define _CFFI_OP_GLOBAL_VAR_F 39
+#define _CFFI_OP_EXTERN_PYTHON 41
+
+#define _CFFI_PRIM_VOID 0
+#define _CFFI_PRIM_BOOL 1
+#define _CFFI_PRIM_CHAR 2
+#define _CFFI_PRIM_SCHAR 3
+#define _CFFI_PRIM_UCHAR 4
+#define _CFFI_PRIM_SHORT 5
+#define _CFFI_PRIM_USHORT 6
+#define _CFFI_PRIM_INT 7
+#define _CFFI_PRIM_UINT 8
+#define _CFFI_PRIM_LONG 9
+#define _CFFI_PRIM_ULONG 10
+#define _CFFI_PRIM_LONGLONG 11
+#define _CFFI_PRIM_ULONGLONG 12
+#define _CFFI_PRIM_FLOAT 13
+#define _CFFI_PRIM_DOUBLE 14
+#define _CFFI_PRIM_LONGDOUBLE 15
+
+#define _CFFI_PRIM_WCHAR 16
+#define _CFFI_PRIM_INT8 17
+#define _CFFI_PRIM_UINT8 18
+#define _CFFI_PRIM_INT16 19
+#define _CFFI_PRIM_UINT16 20
+#define _CFFI_PRIM_INT32 21
+#define _CFFI_PRIM_UINT32 22
+#define _CFFI_PRIM_INT64 23
+#define _CFFI_PRIM_UINT64 24
+#define _CFFI_PRIM_INTPTR 25
+#define _CFFI_PRIM_UINTPTR 26
+#define _CFFI_PRIM_PTRDIFF 27
+#define _CFFI_PRIM_SIZE 28
+#define _CFFI_PRIM_SSIZE 29
+#define _CFFI_PRIM_INT_LEAST8 30
+#define _CFFI_PRIM_UINT_LEAST8 31
+#define _CFFI_PRIM_INT_LEAST16 32
+#define _CFFI_PRIM_UINT_LEAST16 33
+#define _CFFI_PRIM_INT_LEAST32 34
+#define _CFFI_PRIM_UINT_LEAST32 35
+#define _CFFI_PRIM_INT_LEAST64 36
+#define _CFFI_PRIM_UINT_LEAST64 37
+#define _CFFI_PRIM_INT_FAST8 38
+#define _CFFI_PRIM_UINT_FAST8 39
+#define _CFFI_PRIM_INT_FAST16 40
+#define _CFFI_PRIM_UINT_FAST16 41
+#define _CFFI_PRIM_INT_FAST32 42
+#define _CFFI_PRIM_UINT_FAST32 43
+#define _CFFI_PRIM_INT_FAST64 44
+#define _CFFI_PRIM_UINT_FAST64 45
+#define _CFFI_PRIM_INTMAX 46
+#define _CFFI_PRIM_UINTMAX 47
+#define _CFFI_PRIM_FLOATCOMPLEX 48
+#define _CFFI_PRIM_DOUBLECOMPLEX 49
+#define _CFFI_PRIM_CHAR16 50
+#define _CFFI_PRIM_CHAR32 51
+
+#define _CFFI__NUM_PRIM 52
+#define _CFFI__UNKNOWN_PRIM (-1)
+#define _CFFI__UNKNOWN_FLOAT_PRIM (-2)
+#define _CFFI__UNKNOWN_LONG_DOUBLE (-3)
+
+#define _CFFI__IO_FILE_STRUCT (-1)
+
+
+struct _cffi_global_s {
+ const char *name;
+ void *address;
+ _cffi_opcode_t type_op;
+ void *size_or_direct_fn; // OP_GLOBAL_VAR: size, or 0 if unknown
+ // OP_CPYTHON_BLTN_*: addr of direct function
+};
+
+struct _cffi_getconst_s {
+ unsigned long long value;
+ const struct _cffi_type_context_s *ctx;
+ int gindex;
+};
+
+struct _cffi_struct_union_s {
+ const char *name;
+ int type_index; // -> _cffi_types, on a OP_STRUCT_UNION
+ int flags; // _CFFI_F_* flags below
+ size_t size;
+ int alignment;
+ int first_field_index; // -> _cffi_fields array
+ int num_fields;
+};
+#define _CFFI_F_UNION 0x01 // is a union, not a struct
+#define _CFFI_F_CHECK_FIELDS 0x02 // complain if fields are not in the
+ // "standard layout" or if some are missing
+#define _CFFI_F_PACKED 0x04 // for CHECK_FIELDS, assume a packed struct
+#define _CFFI_F_EXTERNAL 0x08 // in some other ffi.include()
+#define _CFFI_F_OPAQUE 0x10 // opaque
+
+struct _cffi_field_s {
+ const char *name;
+ size_t field_offset;
+ size_t field_size;
+ _cffi_opcode_t field_type_op;
+};
+
+struct _cffi_enum_s {
+ const char *name;
+ int type_index; // -> _cffi_types, on a OP_ENUM
+ int type_prim; // _CFFI_PRIM_xxx
+ const char *enumerators; // comma-delimited string
+};
+
+struct _cffi_typename_s {
+ const char *name;
+ int type_index; /* if opaque, points to a possibly artificial
+ OP_STRUCT which is itself opaque */
+};
+
+struct _cffi_type_context_s {
+ _cffi_opcode_t *types;
+ const struct _cffi_global_s *globals;
+ const struct _cffi_field_s *fields;
+ const struct _cffi_struct_union_s *struct_unions;
+ const struct _cffi_enum_s *enums;
+ const struct _cffi_typename_s *typenames;
+ int num_globals;
+ int num_struct_unions;
+ int num_enums;
+ int num_typenames;
+ const char *const *includes;
+ int num_types;
+ int flags; /* future extension */
+};
+
+struct _cffi_parse_info_s {
+ const struct _cffi_type_context_s *ctx;
+ _cffi_opcode_t *output;
+ unsigned int output_size;
+ size_t error_location;
+ const char *error_message;
+};
+
+struct _cffi_externpy_s {
+ const char *name;
+ size_t size_of_result;
+ void *reserved1, *reserved2;
+};
+
+#ifdef _CFFI_INTERNAL
+static int parse_c_type(struct _cffi_parse_info_s *info, const char *input);
+static int search_in_globals(const struct _cffi_type_context_s *ctx,
+ const char *search, size_t search_len);
+static int search_in_struct_unions(const struct _cffi_type_context_s *ctx,
+ const char *search, size_t search_len);
+#endif
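parse_c_type.h packs an 8-bit opcode and an index argument into one pointer-sized word; the generated Python side mirrors the same layout. A hedged sketch of the arithmetic behind _CFFI_OP / _CFFI_GETOP / _CFFI_GETARG, ignoring the C-level pointer casts:

    OP_POINTER = 3        # _CFFI_OP_POINTER from the header above

    def cffi_op(opcode, arg):
        # _CFFI_OP: the low 8 bits hold the opcode, the rest the argument
        return opcode | (arg << 8)

    def getop(word):
        return word & 0xFF            # _CFFI_GETOP

    def getarg(word):
        return word >> 8              # _CFFI_GETARG (arithmetic shift in C)

    w = cffi_op(OP_POINTER, 17)       # "pointer to the type at index 17"
    assert getop(w) == OP_POINTER and getarg(w) == 17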
diff --git a/lib/cffi/pkgconfig.py b/lib/cffi/pkgconfig.py
new file mode 100644
index 0000000..5c93f15
--- /dev/null
+++ b/lib/cffi/pkgconfig.py
@@ -0,0 +1,121 @@
+# pkg-config, https://www.freedesktop.org/wiki/Software/pkg-config/ integration for cffi
+import sys, os, subprocess
+
+from .error import PkgConfigError
+
+
+def merge_flags(cfg1, cfg2):
+ """Merge values from cffi config flags cfg2 to cf1
+
+ Example:
+ merge_flags({"libraries": ["one"]}, {"libraries": ["two"]})
+ {"libraries": ["one", "two"]}
+ """
+ for key, value in cfg2.items():
+ if key not in cfg1:
+ cfg1[key] = value
+ else:
+ if not isinstance(cfg1[key], list):
+ raise TypeError("cfg1[%r] should be a list of strings" % (key,))
+ if not isinstance(value, list):
+ raise TypeError("cfg2[%r] should be a list of strings" % (key,))
+ cfg1[key].extend(value)
+ return cfg1
+
+
+def call(libname, flag, encoding=sys.getfilesystemencoding()):
+ """Calls pkg-config and returns the output if found
+ """
+ a = ["pkg-config", "--print-errors"]
+ a.append(flag)
+ a.append(libname)
+ try:
+ pc = subprocess.Popen(a, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ except EnvironmentError as e:
+ raise PkgConfigError("cannot run pkg-config: %s" % (str(e).strip(),))
+
+ bout, berr = pc.communicate()
+ if pc.returncode != 0:
+ try:
+ berr = berr.decode(encoding)
+ except Exception:
+ pass
+ raise PkgConfigError(berr.strip())
+
+ if sys.version_info >= (3,) and not isinstance(bout, str): # Python 3.x
+ try:
+ bout = bout.decode(encoding)
+ except UnicodeDecodeError:
+ raise PkgConfigError("pkg-config %s %s returned bytes that cannot "
+ "be decoded with encoding %r:\n%r" %
+ (flag, libname, encoding, bout))
+
+ if os.altsep != '\\' and '\\' in bout:
+ raise PkgConfigError("pkg-config %s %s returned an unsupported "
+ "backslash-escaped output:\n%r" %
+ (flag, libname, bout))
+ return bout
+
+
+def flags_from_pkgconfig(libs):
+ r"""Return compiler line flags for FFI.set_source based on pkg-config output
+
+ Usage
+ ...
+ ffibuilder.set_source("_foo", pkgconfig = ["libfoo", "libbar >= 1.8.3"])
+
+ If pkg-config is installed on the build machine, then the arguments
+ include_dirs, library_dirs, libraries, define_macros, extra_compile_args and
+ extra_link_args are extended with the output of pkg-config for libfoo and
+ libbar.
+
+ Raises PkgConfigError in case the pkg-config call fails.
+ """
+
+ def get_include_dirs(string):
+ return [x[2:] for x in string.split() if x.startswith("-I")]
+
+ def get_library_dirs(string):
+ return [x[2:] for x in string.split() if x.startswith("-L")]
+
+ def get_libraries(string):
+ return [x[2:] for x in string.split() if x.startswith("-l")]
+
+ # convert -Dfoo=bar to list of tuples [("foo", "bar")] expected by distutils
+ def get_macros(string):
+ def _macro(x):
+ x = x[2:] # drop "-D"
+ if '=' in x:
+ return tuple(x.split("=", 1)) # "-Dfoo=bar" => ("foo", "bar")
+ else:
+ return (x, None) # "-Dfoo" => ("foo", None)
+ return [_macro(x) for x in string.split() if x.startswith("-D")]
+
+ def get_other_cflags(string):
+ return [x for x in string.split() if not x.startswith("-I") and
+ not x.startswith("-D")]
+
+ def get_other_libs(string):
+ return [x for x in string.split() if not x.startswith("-L") and
+ not x.startswith("-l")]
+
+ # return kwargs for given libname
+ def kwargs(libname):
+ fse = sys.getfilesystemencoding()
+ all_cflags = call(libname, "--cflags")
+ all_libs = call(libname, "--libs")
+ return {
+ "include_dirs": get_include_dirs(all_cflags),
+ "library_dirs": get_library_dirs(all_libs),
+ "libraries": get_libraries(all_libs),
+ "define_macros": get_macros(all_cflags),
+ "extra_compile_args": get_other_cflags(all_cflags),
+ "extra_link_args": get_other_libs(all_libs),
+ }
+
+ # merge all arguments together
+ ret = {}
+ for libname in libs:
+ lib_flags = kwargs(libname)
+ merge_flags(ret, lib_flags)
+ return ret
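merge_flags() above is how the per-library pkg-config results are folded into one kwargs dict for set_source(): lists are concatenated, missing keys are copied over. A short usage sketch, assuming this lib/cffi copy is importable:

    from cffi.pkgconfig import merge_flags

    cfg = {"libraries": ["one"]}
    merge_flags(cfg, {"libraries": ["two"], "include_dirs": ["/opt/include"]})
    print(cfg)
    # -> {'libraries': ['one', 'two'], 'include_dirs': ['/opt/include']}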
diff --git a/lib/cffi/recompiler.py b/lib/cffi/recompiler.py
new file mode 100644
index 0000000..5d9d32d
--- /dev/null
+++ b/lib/cffi/recompiler.py
@@ -0,0 +1,1581 @@
+import os, sys, io
+from . import ffiplatform, model
+from .error import VerificationError
+from .cffi_opcode import *
+
+VERSION_BASE = 0x2601
+VERSION_EMBEDDED = 0x2701
+VERSION_CHAR16CHAR32 = 0x2801
+
+USE_LIMITED_API = (sys.platform != 'win32' or sys.version_info < (3, 0) or
+ sys.version_info >= (3, 5))
+
+
+class GlobalExpr:
+ def __init__(self, name, address, type_op, size=0, check_value=0):
+ self.name = name
+ self.address = address
+ self.type_op = type_op
+ self.size = size
+ self.check_value = check_value
+
+ def as_c_expr(self):
+ return ' { "%s", (void *)%s, %s, (void *)%s },' % (
+ self.name, self.address, self.type_op.as_c_expr(), self.size)
+
+ def as_python_expr(self):
+ return "b'%s%s',%d" % (self.type_op.as_python_bytes(), self.name,
+ self.check_value)
+
+class FieldExpr:
+ def __init__(self, name, field_offset, field_size, fbitsize, field_type_op):
+ self.name = name
+ self.field_offset = field_offset
+ self.field_size = field_size
+ self.fbitsize = fbitsize
+ self.field_type_op = field_type_op
+
+ def as_c_expr(self):
+ spaces = " " * len(self.name)
+ return (' { "%s", %s,\n' % (self.name, self.field_offset) +
+ ' %s %s,\n' % (spaces, self.field_size) +
+ ' %s %s },' % (spaces, self.field_type_op.as_c_expr()))
+
+ def as_python_expr(self):
+ raise NotImplementedError
+
+ def as_field_python_expr(self):
+ if self.field_type_op.op == OP_NOOP:
+ size_expr = ''
+ elif self.field_type_op.op == OP_BITFIELD:
+ size_expr = format_four_bytes(self.fbitsize)
+ else:
+ raise NotImplementedError
+ return "b'%s%s%s'" % (self.field_type_op.as_python_bytes(),
+ size_expr,
+ self.name)
+
+class StructUnionExpr:
+ def __init__(self, name, type_index, flags, size, alignment, comment,
+ first_field_index, c_fields):
+ self.name = name
+ self.type_index = type_index
+ self.flags = flags
+ self.size = size
+ self.alignment = alignment
+ self.comment = comment
+ self.first_field_index = first_field_index
+ self.c_fields = c_fields
+
+ def as_c_expr(self):
+ return (' { "%s", %d, %s,' % (self.name, self.type_index, self.flags)
+ + '\n %s, %s, ' % (self.size, self.alignment)
+ + '%d, %d ' % (self.first_field_index, len(self.c_fields))
+ + ('/* %s */ ' % self.comment if self.comment else '')
+ + '},')
+
+ def as_python_expr(self):
+ flags = eval(self.flags, G_FLAGS)
+ fields_expr = [c_field.as_field_python_expr()
+ for c_field in self.c_fields]
+ return "(b'%s%s%s',%s)" % (
+ format_four_bytes(self.type_index),
+ format_four_bytes(flags),
+ self.name,
+ ','.join(fields_expr))
+
+class EnumExpr:
+ def __init__(self, name, type_index, size, signed, allenums):
+ self.name = name
+ self.type_index = type_index
+ self.size = size
+ self.signed = signed
+ self.allenums = allenums
+
+ def as_c_expr(self):
+ return (' { "%s", %d, _cffi_prim_int(%s, %s),\n'
+ ' "%s" },' % (self.name, self.type_index,
+ self.size, self.signed, self.allenums))
+
+ def as_python_expr(self):
+ prim_index = {
+ (1, 0): PRIM_UINT8, (1, 1): PRIM_INT8,
+ (2, 0): PRIM_UINT16, (2, 1): PRIM_INT16,
+ (4, 0): PRIM_UINT32, (4, 1): PRIM_INT32,
+ (8, 0): PRIM_UINT64, (8, 1): PRIM_INT64,
+ }[self.size, self.signed]
+ return "b'%s%s%s\\x00%s'" % (format_four_bytes(self.type_index),
+ format_four_bytes(prim_index),
+ self.name, self.allenums)
+
+class TypenameExpr:
+ def __init__(self, name, type_index):
+ self.name = name
+ self.type_index = type_index
+
+ def as_c_expr(self):
+ return ' { "%s", %d },' % (self.name, self.type_index)
+
+ def as_python_expr(self):
+ return "b'%s%s'" % (format_four_bytes(self.type_index), self.name)
+
+
+# ____________________________________________________________
+
+
+class Recompiler:
+ _num_externpy = 0
+
+ def __init__(self, ffi, module_name, target_is_python=False):
+ self.ffi = ffi
+ self.module_name = module_name
+ self.target_is_python = target_is_python
+ self._version = VERSION_BASE
+
+ def needs_version(self, ver):
+ self._version = max(self._version, ver)
+
+ def collect_type_table(self):
+ self._typesdict = {}
+ self._generate("collecttype")
+ #
+ all_decls = sorted(self._typesdict, key=str)
+ #
+ # prepare all FUNCTION bytecode sequences first
+ self.cffi_types = []
+ for tp in all_decls:
+ if tp.is_raw_function:
+ assert self._typesdict[tp] is None
+ self._typesdict[tp] = len(self.cffi_types)
+ self.cffi_types.append(tp) # placeholder
+ for tp1 in tp.args:
+ assert isinstance(tp1, (model.VoidType,
+ model.BasePrimitiveType,
+ model.PointerType,
+ model.StructOrUnionOrEnum,
+ model.FunctionPtrType))
+ if self._typesdict[tp1] is None:
+ self._typesdict[tp1] = len(self.cffi_types)
+ self.cffi_types.append(tp1) # placeholder
+ self.cffi_types.append('END') # placeholder
+ #
+ # prepare all OTHER bytecode sequences
+ for tp in all_decls:
+ if not tp.is_raw_function and self._typesdict[tp] is None:
+ self._typesdict[tp] = len(self.cffi_types)
+ self.cffi_types.append(tp) # placeholder
+ if tp.is_array_type and tp.length is not None:
+ self.cffi_types.append('LEN') # placeholder
+ assert None not in self._typesdict.values()
+ #
+ # collect all structs and unions and enums
+ self._struct_unions = {}
+ self._enums = {}
+ for tp in all_decls:
+ if isinstance(tp, model.StructOrUnion):
+ self._struct_unions[tp] = None
+ elif isinstance(tp, model.EnumType):
+ self._enums[tp] = None
+ for i, tp in enumerate(sorted(self._struct_unions,
+ key=lambda tp: tp.name)):
+ self._struct_unions[tp] = i
+ for i, tp in enumerate(sorted(self._enums,
+ key=lambda tp: tp.name)):
+ self._enums[tp] = i
+ #
+ # emit all bytecode sequences now
+ for tp in all_decls:
+ method = getattr(self, '_emit_bytecode_' + tp.__class__.__name__)
+ method(tp, self._typesdict[tp])
+ #
+ # consistency check
+ for op in self.cffi_types:
+ assert isinstance(op, CffiOp)
+ self.cffi_types = tuple(self.cffi_types) # don't change any more
+
+ def _enum_fields(self, tp):
+ # When producing C, expand all anonymous struct/union fields.
+ # That's necessary to have C code checking the offsets of the
+ # individual fields contained in them. When producing Python,
+ # don't do it and instead write it like it is, with the
+ # corresponding fields having an empty name. Empty names are
+ # recognized at runtime when we import the generated Python
+ # file.
+ expand_anonymous_struct_union = not self.target_is_python
+ return tp.enumfields(expand_anonymous_struct_union)
+
+ def _do_collect_type(self, tp):
+ if not isinstance(tp, model.BaseTypeByIdentity):
+ if isinstance(tp, tuple):
+ for x in tp:
+ self._do_collect_type(x)
+ return
+ if tp not in self._typesdict:
+ self._typesdict[tp] = None
+ if isinstance(tp, model.FunctionPtrType):
+ self._do_collect_type(tp.as_raw_function())
+ elif isinstance(tp, model.StructOrUnion):
+ if tp.fldtypes is not None and (
+ tp not in self.ffi._parser._included_declarations):
+ for name1, tp1, _, _ in self._enum_fields(tp):
+ self._do_collect_type(self._field_type(tp, name1, tp1))
+ else:
+ for _, x in tp._get_items():
+ self._do_collect_type(x)
+
+ def _generate(self, step_name):
+ lst = self.ffi._parser._declarations.items()
+ for name, (tp, quals) in sorted(lst):
+ kind, realname = name.split(' ', 1)
+ try:
+ method = getattr(self, '_generate_cpy_%s_%s' % (kind,
+ step_name))
+ except AttributeError:
+ raise VerificationError(
+ "not implemented in recompile(): %r" % name)
+ try:
+ self._current_quals = quals
+ method(tp, realname)
+ except Exception as e:
+ model.attach_exception_info(e, name)
+ raise
+
+ # ----------
+
+ ALL_STEPS = ["global", "field", "struct_union", "enum", "typename"]
+
+ def collect_step_tables(self):
+ # collect the declarations for '_cffi_globals', '_cffi_typenames', etc.
+ self._lsts = {}
+ for step_name in self.ALL_STEPS:
+ self._lsts[step_name] = []
+ self._seen_struct_unions = set()
+ self._generate("ctx")
+ self._add_missing_struct_unions()
+ #
+ for step_name in self.ALL_STEPS:
+ lst = self._lsts[step_name]
+ if step_name != "field":
+ lst.sort(key=lambda entry: entry.name)
+ self._lsts[step_name] = tuple(lst) # don't change any more
+ #
+ # check for a possible internal inconsistency: _cffi_struct_unions
+ # should have been generated with exactly self._struct_unions
+ lst = self._lsts["struct_union"]
+ for tp, i in self._struct_unions.items():
+ assert i < len(lst)
+ assert lst[i].name == tp.name
+ assert len(lst) == len(self._struct_unions)
+ # same with enums
+ lst = self._lsts["enum"]
+ for tp, i in self._enums.items():
+ assert i < len(lst)
+ assert lst[i].name == tp.name
+ assert len(lst) == len(self._enums)
+
+ # ----------
+
+ def _prnt(self, what=''):
+ self._f.write(what + '\n')
+
+ def write_source_to_f(self, f, preamble):
+ if self.target_is_python:
+ assert preamble is None
+ self.write_py_source_to_f(f)
+ else:
+ assert preamble is not None
+ self.write_c_source_to_f(f, preamble)
+
+ def _rel_readlines(self, filename):
+ g = open(os.path.join(os.path.dirname(__file__), filename), 'r')
+ lines = g.readlines()
+ g.close()
+ return lines
+
+ def write_c_source_to_f(self, f, preamble):
+ self._f = f
+ prnt = self._prnt
+ if self.ffi._embedding is not None:
+ prnt('#define _CFFI_USE_EMBEDDING')
+ if not USE_LIMITED_API:
+ prnt('#define _CFFI_NO_LIMITED_API')
+ #
+ # first the '#include' (actually done by inlining the file's content)
+ lines = self._rel_readlines('_cffi_include.h')
+ i = lines.index('#include "parse_c_type.h"\n')
+ lines[i:i+1] = self._rel_readlines('parse_c_type.h')
+ prnt(''.join(lines))
+ #
+ # if we have ffi._embedding != None, we give it here as a macro
+ # and include an extra file
+ base_module_name = self.module_name.split('.')[-1]
+ if self.ffi._embedding is not None:
+ prnt('#define _CFFI_MODULE_NAME "%s"' % (self.module_name,))
+ prnt('static const char _CFFI_PYTHON_STARTUP_CODE[] = {')
+ self._print_string_literal_in_array(self.ffi._embedding)
+ prnt('0 };')
+ prnt('#ifdef PYPY_VERSION')
+ prnt('# define _CFFI_PYTHON_STARTUP_FUNC _cffi_pypyinit_%s' % (
+ base_module_name,))
+ prnt('#elif PY_MAJOR_VERSION >= 3')
+ prnt('# define _CFFI_PYTHON_STARTUP_FUNC PyInit_%s' % (
+ base_module_name,))
+ prnt('#else')
+ prnt('# define _CFFI_PYTHON_STARTUP_FUNC init%s' % (
+ base_module_name,))
+ prnt('#endif')
+ lines = self._rel_readlines('_embedding.h')
+ i = lines.index('#include "_cffi_errors.h"\n')
+ lines[i:i+1] = self._rel_readlines('_cffi_errors.h')
+ prnt(''.join(lines))
+ self.needs_version(VERSION_EMBEDDED)
+ #
+ # then paste the C source given by the user, verbatim.
+ prnt('/************************************************************/')
+ prnt()
+ prnt(preamble)
+ prnt()
+ prnt('/************************************************************/')
+ prnt()
+ #
+ # the declaration of '_cffi_types'
+ prnt('static void *_cffi_types[] = {')
+ typeindex2type = dict([(i, tp) for (tp, i) in self._typesdict.items()])
+ for i, op in enumerate(self.cffi_types):
+ comment = ''
+ if i in typeindex2type:
+ comment = ' // ' + typeindex2type[i]._get_c_name()
+ prnt('/* %2d */ %s,%s' % (i, op.as_c_expr(), comment))
+ if not self.cffi_types:
+ prnt(' 0')
+ prnt('};')
+ prnt()
+ #
+ # call generate_cpy_xxx_decl(), for every xxx found from
+ # ffi._parser._declarations. This generates all the functions.
+ self._seen_constants = set()
+ self._generate("decl")
+ #
+ # the declaration of '_cffi_globals' and '_cffi_typenames'
+ nums = {}
+ for step_name in self.ALL_STEPS:
+ lst = self._lsts[step_name]
+ nums[step_name] = len(lst)
+ if nums[step_name] > 0:
+ prnt('static const struct _cffi_%s_s _cffi_%ss[] = {' % (
+ step_name, step_name))
+ for entry in lst:
+ prnt(entry.as_c_expr())
+ prnt('};')
+ prnt()
+ #
+ # the declaration of '_cffi_includes'
+ if self.ffi._included_ffis:
+ prnt('static const char * const _cffi_includes[] = {')
+ for ffi_to_include in self.ffi._included_ffis:
+ try:
+ included_module_name, included_source = (
+ ffi_to_include._assigned_source[:2])
+ except AttributeError:
+ raise VerificationError(
+ "ffi object %r includes %r, but the latter has not "
+ "been prepared with set_source()" % (
+ self.ffi, ffi_to_include,))
+ if included_source is None:
+ raise VerificationError(
+ "not implemented yet: ffi.include() of a Python-based "
+ "ffi inside a C-based ffi")
+ prnt(' "%s",' % (included_module_name,))
+ prnt(' NULL')
+ prnt('};')
+ prnt()
+ #
+ # the declaration of '_cffi_type_context'
+ prnt('static const struct _cffi_type_context_s _cffi_type_context = {')
+ prnt(' _cffi_types,')
+ for step_name in self.ALL_STEPS:
+ if nums[step_name] > 0:
+ prnt(' _cffi_%ss,' % step_name)
+ else:
+ prnt(' NULL, /* no %ss */' % step_name)
+ for step_name in self.ALL_STEPS:
+ if step_name != "field":
+ prnt(' %d, /* num_%ss */' % (nums[step_name], step_name))
+ if self.ffi._included_ffis:
+ prnt(' _cffi_includes,')
+ else:
+ prnt(' NULL, /* no includes */')
+ prnt(' %d, /* num_types */' % (len(self.cffi_types),))
+ flags = 0
+ if self._num_externpy > 0 or self.ffi._embedding is not None:
+ flags |= 1 # set to mean that we use extern "Python"
+ prnt(' %d, /* flags */' % flags)
+ prnt('};')
+ prnt()
+ #
+ # the init function
+ prnt('#ifdef __GNUC__')
+ prnt('# pragma GCC visibility push(default) /* for -fvisibility= */')
+ prnt('#endif')
+ prnt()
+ prnt('#ifdef PYPY_VERSION')
+ prnt('PyMODINIT_FUNC')
+ prnt('_cffi_pypyinit_%s(const void *p[])' % (base_module_name,))
+ prnt('{')
+ if flags & 1:
+ prnt(' if (((intptr_t)p[0]) >= 0x0A03) {')
+ prnt(' _cffi_call_python_org = '
+ '(void(*)(struct _cffi_externpy_s *, char *))p[1];')
+ prnt(' }')
+ prnt(' p[0] = (const void *)0x%x;' % self._version)
+ prnt(' p[1] = &_cffi_type_context;')
+ prnt('#if PY_MAJOR_VERSION >= 3')
+ prnt(' return NULL;')
+ prnt('#endif')
+ prnt('}')
+ # on Windows, distutils insists on putting init_cffi_xyz in
+ # 'export_symbols', so instead of fighting it, just give up and
+ # give it one
+ prnt('# ifdef _MSC_VER')
+ prnt(' PyMODINIT_FUNC')
+ prnt('# if PY_MAJOR_VERSION >= 3')
+ prnt(' PyInit_%s(void) { return NULL; }' % (base_module_name,))
+ prnt('# else')
+ prnt(' init%s(void) { }' % (base_module_name,))
+ prnt('# endif')
+ prnt('# endif')
+ prnt('#elif PY_MAJOR_VERSION >= 3')
+ prnt('PyMODINIT_FUNC')
+ prnt('PyInit_%s(void)' % (base_module_name,))
+ prnt('{')
+ prnt(' return _cffi_init("%s", 0x%x, &_cffi_type_context);' % (
+ self.module_name, self._version))
+ prnt('}')
+ prnt('#else')
+ prnt('PyMODINIT_FUNC')
+ prnt('init%s(void)' % (base_module_name,))
+ prnt('{')
+ prnt(' _cffi_init("%s", 0x%x, &_cffi_type_context);' % (
+ self.module_name, self._version))
+ prnt('}')
+ prnt('#endif')
+ prnt()
+ prnt('#ifdef __GNUC__')
+ prnt('# pragma GCC visibility pop')
+ prnt('#endif')
+ self._version = None
+
+ def _to_py(self, x):
+ if isinstance(x, str):
+ return "b'%s'" % (x,)
+ if isinstance(x, (list, tuple)):
+ rep = [self._to_py(item) for item in x]
+ if len(rep) == 1:
+ rep.append('')
+ return "(%s)" % (','.join(rep),)
+ return x.as_python_expr() # Py2: unicode unexpected; Py3: bytes unexp.
+
+ def write_py_source_to_f(self, f):
+ self._f = f
+ prnt = self._prnt
+ #
+ # header
+ prnt("# auto-generated file")
+ prnt("import _cffi_backend")
+ #
+ # the 'import' of the included ffis
+ num_includes = len(self.ffi._included_ffis or ())
+ for i in range(num_includes):
+ ffi_to_include = self.ffi._included_ffis[i]
+ try:
+ included_module_name, included_source = (
+ ffi_to_include._assigned_source[:2])
+ except AttributeError:
+ raise VerificationError(
+ "ffi object %r includes %r, but the latter has not "
+ "been prepared with set_source()" % (
+ self.ffi, ffi_to_include,))
+ if included_source is not None:
+ raise VerificationError(
+ "not implemented yet: ffi.include() of a C-based "
+ "ffi inside a Python-based ffi")
+ prnt('from %s import ffi as _ffi%d' % (included_module_name, i))
+ prnt()
+ prnt("ffi = _cffi_backend.FFI('%s'," % (self.module_name,))
+ prnt(" _version = 0x%x," % (self._version,))
+ self._version = None
+ #
+ # the '_types' keyword argument
+ self.cffi_types = tuple(self.cffi_types) # don't change any more
+ types_lst = [op.as_python_bytes() for op in self.cffi_types]
+ prnt(' _types = %s,' % (self._to_py(''.join(types_lst)),))
+ typeindex2type = dict([(i, tp) for (tp, i) in self._typesdict.items()])
+ #
+ # the keyword arguments from ALL_STEPS
+ for step_name in self.ALL_STEPS:
+ lst = self._lsts[step_name]
+ if len(lst) > 0 and step_name != "field":
+ prnt(' _%ss = %s,' % (step_name, self._to_py(lst)))
+ #
+ # the '_includes' keyword argument
+ if num_includes > 0:
+ prnt(' _includes = (%s,),' % (
+ ', '.join(['_ffi%d' % i for i in range(num_includes)]),))
+ #
+ # the footer
+ prnt(')')
+
+ # ----------
+
+ def _gettypenum(self, type):
+ # a KeyError here is a bug. please report it! :-)
+ return self._typesdict[type]
+
+ def _convert_funcarg_to_c(self, tp, fromvar, tovar, errcode):
+ extraarg = ''
+ if isinstance(tp, model.BasePrimitiveType) and not tp.is_complex_type():
+ if tp.is_integer_type() and tp.name != '_Bool':
+ converter = '_cffi_to_c_int'
+ extraarg = ', %s' % tp.name
+ elif isinstance(tp, model.UnknownFloatType):
+ # don't check with is_float_type(): it may be a 'long
+ # double' here, and _cffi_to_c_double would lose precision
+ converter = '(%s)_cffi_to_c_double' % (tp.get_c_name(''),)
+ else:
+ cname = tp.get_c_name('')
+ converter = '(%s)_cffi_to_c_%s' % (cname,
+ tp.name.replace(' ', '_'))
+ if cname in ('char16_t', 'char32_t'):
+ self.needs_version(VERSION_CHAR16CHAR32)
+ errvalue = '-1'
+ #
+ elif isinstance(tp, model.PointerType):
+ self._convert_funcarg_to_c_ptr_or_array(tp, fromvar,
+ tovar, errcode)
+ return
+ #
+ elif (isinstance(tp, model.StructOrUnionOrEnum) or
+ isinstance(tp, model.BasePrimitiveType)):
+ # a struct (not a struct pointer) as a function argument;
+ # or, a complex (the same code works)
+ self._prnt(' if (_cffi_to_c((char *)&%s, _cffi_type(%d), %s) < 0)'
+ % (tovar, self._gettypenum(tp), fromvar))
+ self._prnt(' %s;' % errcode)
+ return
+ #
+ elif isinstance(tp, model.FunctionPtrType):
+ converter = '(%s)_cffi_to_c_pointer' % tp.get_c_name('')
+ extraarg = ', _cffi_type(%d)' % self._gettypenum(tp)
+ errvalue = 'NULL'
+ #
+ else:
+ raise NotImplementedError(tp)
+ #
+ self._prnt(' %s = %s(%s%s);' % (tovar, converter, fromvar, extraarg))
+ self._prnt(' if (%s == (%s)%s && PyErr_Occurred())' % (
+ tovar, tp.get_c_name(''), errvalue))
+ self._prnt(' %s;' % errcode)
+
+ def _extra_local_variables(self, tp, localvars, freelines):
+ if isinstance(tp, model.PointerType):
+ localvars.add('Py_ssize_t datasize')
+ localvars.add('struct _cffi_freeme_s *large_args_free = NULL')
+ freelines.add('if (large_args_free != NULL)'
+ ' _cffi_free_array_arguments(large_args_free);')
+
+ def _convert_funcarg_to_c_ptr_or_array(self, tp, fromvar, tovar, errcode):
+ self._prnt(' datasize = _cffi_prepare_pointer_call_argument(')
+ self._prnt(' _cffi_type(%d), %s, (char **)&%s);' % (
+ self._gettypenum(tp), fromvar, tovar))
+ self._prnt(' if (datasize != 0) {')
+ self._prnt(' %s = ((size_t)datasize) <= 640 ? '
+ '(%s)alloca((size_t)datasize) : NULL;' % (
+ tovar, tp.get_c_name('')))
+ self._prnt(' if (_cffi_convert_array_argument(_cffi_type(%d), %s, '
+ '(char **)&%s,' % (self._gettypenum(tp), fromvar, tovar))
+ self._prnt(' datasize, &large_args_free) < 0)')
+ self._prnt(' %s;' % errcode)
+ self._prnt(' }')
+
+ def _convert_expr_from_c(self, tp, var, context):
+ if isinstance(tp, model.BasePrimitiveType):
+ if tp.is_integer_type() and tp.name != '_Bool':
+ return '_cffi_from_c_int(%s, %s)' % (var, tp.name)
+ elif isinstance(tp, model.UnknownFloatType):
+ return '_cffi_from_c_double(%s)' % (var,)
+ elif tp.name != 'long double' and not tp.is_complex_type():
+ cname = tp.name.replace(' ', '_')
+ if cname in ('char16_t', 'char32_t'):
+ self.needs_version(VERSION_CHAR16CHAR32)
+ return '_cffi_from_c_%s(%s)' % (cname, var)
+ else:
+ return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ elif isinstance(tp, (model.PointerType, model.FunctionPtrType)):
+ return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ elif isinstance(tp, model.ArrayType):
+ return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % (
+ var, self._gettypenum(model.PointerType(tp.item)))
+ elif isinstance(tp, model.StructOrUnion):
+ if tp.fldnames is None:
+ raise TypeError("'%s' is used as %s, but is opaque" % (
+ tp._get_c_name(), context))
+ return '_cffi_from_c_struct((char *)&%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ elif isinstance(tp, model.EnumType):
+ return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ else:
+ raise NotImplementedError(tp)
+
+ # ----------
+ # typedefs
+
+ def _typedef_type(self, tp, name):
+ return self._global_type(tp, "(*(%s *)0)" % (name,))
+
+ def _generate_cpy_typedef_collecttype(self, tp, name):
+ self._do_collect_type(self._typedef_type(tp, name))
+
+ def _generate_cpy_typedef_decl(self, tp, name):
+ pass
+
+ def _typedef_ctx(self, tp, name):
+ type_index = self._typesdict[tp]
+ self._lsts["typename"].append(TypenameExpr(name, type_index))
+
+ def _generate_cpy_typedef_ctx(self, tp, name):
+ tp = self._typedef_type(tp, name)
+ self._typedef_ctx(tp, name)
+ if getattr(tp, "origin", None) == "unknown_type":
+ self._struct_ctx(tp, tp.name, approxname=None)
+ elif isinstance(tp, model.NamedPointerType):
+ self._struct_ctx(tp.totype, tp.totype.name, approxname=tp.name,
+ named_ptr=tp)
+
+ # ----------
+ # function declarations
+
+ def _generate_cpy_function_collecttype(self, tp, name):
+ self._do_collect_type(tp.as_raw_function())
+ if tp.ellipsis and not self.target_is_python:
+ self._do_collect_type(tp)
+
+ def _generate_cpy_function_decl(self, tp, name):
+ assert not self.target_is_python
+ assert isinstance(tp, model.FunctionPtrType)
+ if tp.ellipsis:
+ # cannot support vararg functions better than this: check for its
+ # exact type (including the fixed arguments), and build it as a
+ # constant function pointer (no CPython wrapper)
+ self._generate_cpy_constant_decl(tp, name)
+ return
+ prnt = self._prnt
+ numargs = len(tp.args)
+ if numargs == 0:
+ argname = 'noarg'
+ elif numargs == 1:
+ argname = 'arg0'
+ else:
+ argname = 'args'
+ #
+ # ------------------------------
+ # the 'd' version of the function, only for addressof(lib, 'func')
+ arguments = []
+ call_arguments = []
+ context = 'argument of %s' % name
+ for i, type in enumerate(tp.args):
+ arguments.append(type.get_c_name(' x%d' % i, context))
+ call_arguments.append('x%d' % i)
+ repr_arguments = ', '.join(arguments)
+ repr_arguments = repr_arguments or 'void'
+ if tp.abi:
+ abi = tp.abi + ' '
+ else:
+ abi = ''
+ name_and_arguments = '%s_cffi_d_%s(%s)' % (abi, name, repr_arguments)
+ prnt('static %s' % (tp.result.get_c_name(name_and_arguments),))
+ prnt('{')
+ call_arguments = ', '.join(call_arguments)
+ result_code = 'return '
+ if isinstance(tp.result, model.VoidType):
+ result_code = ''
+ prnt(' %s%s(%s);' % (result_code, name, call_arguments))
+ prnt('}')
+ #
+ prnt('#ifndef PYPY_VERSION') # ------------------------------
+ #
+ prnt('static PyObject *')
+ prnt('_cffi_f_%s(PyObject *self, PyObject *%s)' % (name, argname))
+ prnt('{')
+ #
+ context = 'argument of %s' % name
+ for i, type in enumerate(tp.args):
+ arg = type.get_c_name(' x%d' % i, context)
+ prnt(' %s;' % arg)
+ #
+ localvars = set()
+ freelines = set()
+ for type in tp.args:
+ self._extra_local_variables(type, localvars, freelines)
+ for decl in sorted(localvars):
+ prnt(' %s;' % (decl,))
+ #
+ if not isinstance(tp.result, model.VoidType):
+ result_code = 'result = '
+ context = 'result of %s' % name
+ result_decl = ' %s;' % tp.result.get_c_name(' result', context)
+ prnt(result_decl)
+ prnt(' PyObject *pyresult;')
+ else:
+ result_decl = None
+ result_code = ''
+ #
+ if len(tp.args) > 1:
+ rng = range(len(tp.args))
+ for i in rng:
+ prnt(' PyObject *arg%d;' % i)
+ prnt()
+ prnt(' if (!PyArg_UnpackTuple(args, "%s", %d, %d, %s))' % (
+ name, len(rng), len(rng),
+ ', '.join(['&arg%d' % i for i in rng])))
+ prnt(' return NULL;')
+ prnt()
+ #
+ for i, type in enumerate(tp.args):
+ self._convert_funcarg_to_c(type, 'arg%d' % i, 'x%d' % i,
+ 'return NULL')
+ prnt()
+ #
+ prnt(' Py_BEGIN_ALLOW_THREADS')
+ prnt(' _cffi_restore_errno();')
+ call_arguments = ['x%d' % i for i in range(len(tp.args))]
+ call_arguments = ', '.join(call_arguments)
+ prnt(' { %s%s(%s); }' % (result_code, name, call_arguments))
+ prnt(' _cffi_save_errno();')
+ prnt(' Py_END_ALLOW_THREADS')
+ prnt()
+ #
+ prnt(' (void)self; /* unused */')
+ if numargs == 0:
+ prnt(' (void)noarg; /* unused */')
+ if result_code:
+ prnt(' pyresult = %s;' %
+ self._convert_expr_from_c(tp.result, 'result', 'result type'))
+ for freeline in freelines:
+ prnt(' ' + freeline)
+ prnt(' return pyresult;')
+ else:
+ for freeline in freelines:
+ prnt(' ' + freeline)
+ prnt(' Py_INCREF(Py_None);')
+ prnt(' return Py_None;')
+ prnt('}')
+ #
+ prnt('#else') # ------------------------------
+ #
+ # the PyPy version: need to replace struct/union arguments with
+ # pointers, and if the result is a struct/union, insert a first
+ # arg that is a pointer to the result. We also do that for
+ # complex args and return type.
+ def need_indirection(type):
+ return (isinstance(type, model.StructOrUnion) or
+ (isinstance(type, model.PrimitiveType) and
+ type.is_complex_type()))
+ difference = False
+ arguments = []
+ call_arguments = []
+ context = 'argument of %s' % name
+ for i, type in enumerate(tp.args):
+ indirection = ''
+ if need_indirection(type):
+ indirection = '*'
+ difference = True
+ arg = type.get_c_name(' %sx%d' % (indirection, i), context)
+ arguments.append(arg)
+ call_arguments.append('%sx%d' % (indirection, i))
+ tp_result = tp.result
+ if need_indirection(tp_result):
+ context = 'result of %s' % name
+ arg = tp_result.get_c_name(' *result', context)
+ arguments.insert(0, arg)
+ tp_result = model.void_type
+ result_decl = None
+ result_code = '*result = '
+ difference = True
+ if difference:
+ repr_arguments = ', '.join(arguments)
+ repr_arguments = repr_arguments or 'void'
+ name_and_arguments = '%s_cffi_f_%s(%s)' % (abi, name,
+ repr_arguments)
+ prnt('static %s' % (tp_result.get_c_name(name_and_arguments),))
+ prnt('{')
+ if result_decl:
+ prnt(result_decl)
+ call_arguments = ', '.join(call_arguments)
+ prnt(' { %s%s(%s); }' % (result_code, name, call_arguments))
+ if result_decl:
+ prnt(' return result;')
+ prnt('}')
+ else:
+ prnt('# define _cffi_f_%s _cffi_d_%s' % (name, name))
+ #
+ prnt('#endif') # ------------------------------
+ prnt()
+
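+ # Illustrative sketch (user-side code, module name '_example' and function
+ # name assumed): the extra "_cffi_d_NAME" direct wrapper generated above is
+ # what backs ffi.addressof(lib, ...) in API mode, e.g.:
+ #
+ #     from _example import ffi, lib
+ #     fnptr = ffi.addressof(lib, "myfunc")   # plain C function pointer cdata
+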
+ def _generate_cpy_function_ctx(self, tp, name):
+ if tp.ellipsis and not self.target_is_python:
+ self._generate_cpy_constant_ctx(tp, name)
+ return
+ type_index = self._typesdict[tp.as_raw_function()]
+ numargs = len(tp.args)
+ if self.target_is_python:
+ meth_kind = OP_DLOPEN_FUNC
+ elif numargs == 0:
+ meth_kind = OP_CPYTHON_BLTN_N # 'METH_NOARGS'
+ elif numargs == 1:
+ meth_kind = OP_CPYTHON_BLTN_O # 'METH_O'
+ else:
+ meth_kind = OP_CPYTHON_BLTN_V # 'METH_VARARGS'
+ self._lsts["global"].append(
+ GlobalExpr(name, '_cffi_f_%s' % name,
+ CffiOp(meth_kind, type_index),
+ size='_cffi_d_%s' % name))
+
+ # ----------
+ # named structs or unions
+
+ def _field_type(self, tp_struct, field_name, tp_field):
+ if isinstance(tp_field, model.ArrayType):
+ actual_length = tp_field.length
+ if actual_length == '...':
+ ptr_struct_name = tp_struct.get_c_name('*')
+ actual_length = '_cffi_array_len(((%s)0)->%s)' % (
+ ptr_struct_name, field_name)
+ tp_item = self._field_type(tp_struct, '%s[0]' % field_name,
+ tp_field.item)
+ tp_field = model.ArrayType(tp_item, actual_length)
+ return tp_field
+
+ def _struct_collecttype(self, tp):
+ self._do_collect_type(tp)
+ if self.target_is_python:
+ # also requires nested anon struct/unions in ABI mode, recursively
+ for fldtype in tp.anonymous_struct_fields():
+ self._struct_collecttype(fldtype)
+
+ def _struct_decl(self, tp, cname, approxname):
+ if tp.fldtypes is None:
+ return
+ prnt = self._prnt
+ checkfuncname = '_cffi_checkfld_%s' % (approxname,)
+ prnt('_CFFI_UNUSED_FN')
+ prnt('static void %s(%s *p)' % (checkfuncname, cname))
+ prnt('{')
+ prnt(' /* only to generate compile-time warnings or errors */')
+ prnt(' (void)p;')
+ for fname, ftype, fbitsize, fqual in self._enum_fields(tp):
+ try:
+ if ftype.is_integer_type() or fbitsize >= 0:
+ # accept all integers, but complain on float or double
+ if fname != '':
+ prnt(" (void)((p->%s) | 0); /* check that '%s.%s' is "
+ "an integer */" % (fname, cname, fname))
+ continue
+ # only accept exactly the type declared, except that '[]'
+ # is interpreted as a '*' and so will match any array length.
+ # (It would also match '*', but that's harder to detect...)
+ while (isinstance(ftype, model.ArrayType)
+ and (ftype.length is None or ftype.length == '...')):
+ ftype = ftype.item
+ fname = fname + '[0]'
+ prnt(' { %s = &p->%s; (void)tmp; }' % (
+ ftype.get_c_name('*tmp', 'field %r'%fname, quals=fqual),
+ fname))
+ except VerificationError as e:
+ prnt(' /* %s */' % str(e)) # cannot verify it, ignore
+ prnt('}')
+ prnt('struct _cffi_align_%s { char x; %s y; };' % (approxname, cname))
+ prnt()
+
+ def _struct_ctx(self, tp, cname, approxname, named_ptr=None):
+ type_index = self._typesdict[tp]
+ reason_for_not_expanding = None
+ flags = []
+ if isinstance(tp, model.UnionType):
+ flags.append("_CFFI_F_UNION")
+ if tp.fldtypes is None:
+ flags.append("_CFFI_F_OPAQUE")
+ reason_for_not_expanding = "opaque"
+ if (tp not in self.ffi._parser._included_declarations and
+ (named_ptr is None or
+ named_ptr not in self.ffi._parser._included_declarations)):
+ if tp.fldtypes is None:
+ pass # opaque
+ elif tp.partial or any(tp.anonymous_struct_fields()):
+ pass # field layout obtained silently from the C compiler
+ else:
+ flags.append("_CFFI_F_CHECK_FIELDS")
+ if tp.packed:
+ if tp.packed > 1:
+ raise NotImplementedError(
+ "%r is declared with 'pack=%r'; only 0 or 1 are "
+ "supported in API mode (try to use \"...;\", which "
+ "does not require a 'pack' declaration)" %
+ (tp, tp.packed))
+ flags.append("_CFFI_F_PACKED")
+ else:
+ flags.append("_CFFI_F_EXTERNAL")
+ reason_for_not_expanding = "external"
+ flags = '|'.join(flags) or '0'
+ c_fields = []
+ if reason_for_not_expanding is None:
+ enumfields = list(self._enum_fields(tp))
+ for fldname, fldtype, fbitsize, fqual in enumfields:
+ fldtype = self._field_type(tp, fldname, fldtype)
+ self._check_not_opaque(fldtype,
+ "field '%s.%s'" % (tp.name, fldname))
+ # cname is None for _add_missing_struct_unions() only
+ op = OP_NOOP
+ if fbitsize >= 0:
+ op = OP_BITFIELD
+ size = '%d /* bits */' % fbitsize
+ elif cname is None or (
+ isinstance(fldtype, model.ArrayType) and
+ fldtype.length is None):
+ size = '(size_t)-1'
+ else:
+ size = 'sizeof(((%s)0)->%s)' % (
+ tp.get_c_name('*') if named_ptr is None
+ else named_ptr.name,
+ fldname)
+ if cname is None or fbitsize >= 0:
+ offset = '(size_t)-1'
+ elif named_ptr is not None:
+ offset = '((char *)&((%s)0)->%s) - (char *)0' % (
+ named_ptr.name, fldname)
+ else:
+ offset = 'offsetof(%s, %s)' % (tp.get_c_name(''), fldname)
+ c_fields.append(
+ FieldExpr(fldname, offset, size, fbitsize,
+ CffiOp(op, self._typesdict[fldtype])))
+ first_field_index = len(self._lsts["field"])
+ self._lsts["field"].extend(c_fields)
+ #
+ if cname is None: # unknown name, for _add_missing_struct_unions
+ size = '(size_t)-2'
+ align = -2
+ comment = "unnamed"
+ else:
+ if named_ptr is not None:
+ size = 'sizeof(*(%s)0)' % (named_ptr.name,)
+ align = '-1 /* unknown alignment */'
+ else:
+ size = 'sizeof(%s)' % (cname,)
+ align = 'offsetof(struct _cffi_align_%s, y)' % (approxname,)
+ comment = None
+ else:
+ size = '(size_t)-1'
+ align = -1
+ first_field_index = -1
+ comment = reason_for_not_expanding
+ self._lsts["struct_union"].append(
+ StructUnionExpr(tp.name, type_index, flags, size, align, comment,
+ first_field_index, c_fields))
+ self._seen_struct_unions.add(tp)
+
+ def _check_not_opaque(self, tp, location):
+ while isinstance(tp, model.ArrayType):
+ tp = tp.item
+ if isinstance(tp, model.StructOrUnion) and tp.fldtypes is None:
+ raise TypeError(
+ "%s is of an opaque type (not declared in cdef())" % location)
+
+ def _add_missing_struct_unions(self):
+ # not very nice, but some struct declarations might be missing
+ # because they don't have any known C name. Check that they are
+ # not partial (we can't complete or verify them!) and emit them
+ # anonymously.
+ lst = list(self._struct_unions.items())
+ lst.sort(key=lambda tp_order: tp_order[1])
+ for tp, order in lst:
+ if tp not in self._seen_struct_unions:
+ if tp.partial:
+ raise NotImplementedError("internal inconsistency: %r is "
+ "partial but was not seen at "
+ "this point" % (tp,))
+ if tp.name.startswith('$') and tp.name[1:].isdigit():
+ approxname = tp.name[1:]
+ elif tp.name == '_IO_FILE' and tp.forcename == 'FILE':
+ approxname = 'FILE'
+ self._typedef_ctx(tp, 'FILE')
+ else:
+ raise NotImplementedError("internal inconsistency: %r" %
+ (tp,))
+ self._struct_ctx(tp, None, approxname)
+
+ def _generate_cpy_struct_collecttype(self, tp, name):
+ self._struct_collecttype(tp)
+ _generate_cpy_union_collecttype = _generate_cpy_struct_collecttype
+
+ def _struct_names(self, tp):
+ cname = tp.get_c_name('')
+ if ' ' in cname:
+ return cname, cname.replace(' ', '_')
+ else:
+ return cname, '_' + cname
+
+ def _generate_cpy_struct_decl(self, tp, name):
+ self._struct_decl(tp, *self._struct_names(tp))
+ _generate_cpy_union_decl = _generate_cpy_struct_decl
+
+ def _generate_cpy_struct_ctx(self, tp, name):
+ self._struct_ctx(tp, *self._struct_names(tp))
+ _generate_cpy_union_ctx = _generate_cpy_struct_ctx
+
+ # ----------
+ # 'anonymous' declarations. These are produced for anonymous structs
+ # or unions; the 'name' is obtained by a typedef.
+
+ def _generate_cpy_anonymous_collecttype(self, tp, name):
+ if isinstance(tp, model.EnumType):
+ self._generate_cpy_enum_collecttype(tp, name)
+ else:
+ self._struct_collecttype(tp)
+
+ def _generate_cpy_anonymous_decl(self, tp, name):
+ if isinstance(tp, model.EnumType):
+ self._generate_cpy_enum_decl(tp)
+ else:
+ self._struct_decl(tp, name, 'typedef_' + name)
+
+ def _generate_cpy_anonymous_ctx(self, tp, name):
+ if isinstance(tp, model.EnumType):
+ self._enum_ctx(tp, name)
+ else:
+ self._struct_ctx(tp, name, 'typedef_' + name)
+
+ # ----------
+ # constants, declared with "static const ..."
+
+ def _generate_cpy_const(self, is_int, name, tp=None, category='const',
+ check_value=None):
+ if (category, name) in self._seen_constants:
+ raise VerificationError(
+ "duplicate declaration of %s '%s'" % (category, name))
+ self._seen_constants.add((category, name))
+ #
+ prnt = self._prnt
+ funcname = '_cffi_%s_%s' % (category, name)
+ if is_int:
+ prnt('static int %s(unsigned long long *o)' % funcname)
+ prnt('{')
+ prnt(' int n = (%s) <= 0;' % (name,))
+ prnt(' *o = (unsigned long long)((%s) | 0);'
+ ' /* check that %s is an integer */' % (name, name))
+ if check_value is not None:
+ if check_value > 0:
+ check_value = '%dU' % (check_value,)
+ prnt(' if (!_cffi_check_int(*o, n, %s))' % (check_value,))
+ prnt(' n |= 2;')
+ prnt(' return n;')
+ prnt('}')
+ else:
+ assert check_value is None
+ prnt('static void %s(char *o)' % funcname)
+ prnt('{')
+ prnt(' *(%s)o = %s;' % (tp.get_c_name('*'), name))
+ prnt('}')
+ prnt()
+
+ def _generate_cpy_constant_collecttype(self, tp, name):
+ is_int = tp.is_integer_type()
+ if not is_int or self.target_is_python:
+ self._do_collect_type(tp)
+
+ def _generate_cpy_constant_decl(self, tp, name):
+ is_int = tp.is_integer_type()
+ self._generate_cpy_const(is_int, name, tp)
+
+ def _generate_cpy_constant_ctx(self, tp, name):
+ if not self.target_is_python and tp.is_integer_type():
+ type_op = CffiOp(OP_CONSTANT_INT, -1)
+ else:
+ if self.target_is_python:
+ const_kind = OP_DLOPEN_CONST
+ else:
+ const_kind = OP_CONSTANT
+ type_index = self._typesdict[tp]
+ type_op = CffiOp(const_kind, type_index)
+ self._lsts["global"].append(
+ GlobalExpr(name, '_cffi_const_%s' % name, type_op))
+
+ # ----------
+ # enums
+
+ def _generate_cpy_enum_collecttype(self, tp, name):
+ self._do_collect_type(tp)
+
+ def _generate_cpy_enum_decl(self, tp, name=None):
+ for enumerator in tp.enumerators:
+ self._generate_cpy_const(True, enumerator)
+
+ def _enum_ctx(self, tp, cname):
+ type_index = self._typesdict[tp]
+ type_op = CffiOp(OP_ENUM, -1)
+ if self.target_is_python:
+ tp.check_not_partial()
+ for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
+ self._lsts["global"].append(
+ GlobalExpr(enumerator, '_cffi_const_%s' % enumerator, type_op,
+ check_value=enumvalue))
+ #
+ if cname is not None and '$' not in cname and not self.target_is_python:
+ size = "sizeof(%s)" % cname
+ signed = "((%s)-1) <= 0" % cname
+ else:
+ basetp = tp.build_baseinttype(self.ffi, [])
+ size = self.ffi.sizeof(basetp)
+ signed = int(int(self.ffi.cast(basetp, -1)) < 0)
+ allenums = ",".join(tp.enumerators)
+ self._lsts["enum"].append(
+ EnumExpr(tp.name, type_index, size, signed, allenums))
+
+ def _generate_cpy_enum_ctx(self, tp, name):
+ self._enum_ctx(tp, tp._get_c_name())
+
+ # ----------
+ # macros: for now only for integers
+
+ def _generate_cpy_macro_collecttype(self, tp, name):
+ pass
+
+ def _generate_cpy_macro_decl(self, tp, name):
+ if tp == '...':
+ check_value = None
+ else:
+ check_value = tp # an integer
+ self._generate_cpy_const(True, name, check_value=check_value)
+
+ def _generate_cpy_macro_ctx(self, tp, name):
+ if tp == '...':
+ if self.target_is_python:
+ raise VerificationError(
+ "cannot use the syntax '...' in '#define %s ...' when "
+ "using the ABI mode" % (name,))
+ check_value = None
+ else:
+ check_value = tp # an integer
+ type_op = CffiOp(OP_CONSTANT_INT, -1)
+ self._lsts["global"].append(
+ GlobalExpr(name, '_cffi_const_%s' % name, type_op,
+ check_value=check_value))
+
+ # ----------
+ # global variables
+
+ def _global_type(self, tp, global_name):
+ if isinstance(tp, model.ArrayType):
+ actual_length = tp.length
+ if actual_length == '...':
+ actual_length = '_cffi_array_len(%s)' % (global_name,)
+ tp_item = self._global_type(tp.item, '%s[0]' % global_name)
+ tp = model.ArrayType(tp_item, actual_length)
+ return tp
+
+ def _generate_cpy_variable_collecttype(self, tp, name):
+ self._do_collect_type(self._global_type(tp, name))
+
+ def _generate_cpy_variable_decl(self, tp, name):
+ prnt = self._prnt
+ tp = self._global_type(tp, name)
+ if isinstance(tp, model.ArrayType) and tp.length is None:
+ tp = tp.item
+ ampersand = ''
+ else:
+ ampersand = '&'
+ # This code assumes that casts from "tp *" to "void *" are
+ # no-ops, i.e. a function that returns a "tp *" can be called
+ # as if it returned a "void *". This should be generally true
+ # on any modern machine. The only exception to that rule (on
+ # uncommon architectures, and as far as I can tell) might be
+ # if 'tp' were a function type, but that is not possible here.
+ # (If 'tp' is a function _pointer_ type, then casts from "fn_t
+ # **" to "void *" are again no-ops, as far as I can tell.)
+ decl = '*_cffi_var_%s(void)' % (name,)
+ prnt('static ' + tp.get_c_name(decl, quals=self._current_quals))
+ prnt('{')
+ prnt(' return %s(%s);' % (ampersand, name))
+ prnt('}')
+ prnt()
+
+ def _generate_cpy_variable_ctx(self, tp, name):
+ tp = self._global_type(tp, name)
+ type_index = self._typesdict[tp]
+ if self.target_is_python:
+ op = OP_GLOBAL_VAR
+ else:
+ op = OP_GLOBAL_VAR_F
+ self._lsts["global"].append(
+ GlobalExpr(name, '_cffi_var_%s' % name, CffiOp(op, type_index)))
+
+ # ----------
+ # extern "Python"
+
+ def _generate_cpy_extern_python_collecttype(self, tp, name):
+ assert isinstance(tp, model.FunctionPtrType)
+ self._do_collect_type(tp)
+ _generate_cpy_dllexport_python_collecttype = \
+ _generate_cpy_extern_python_plus_c_collecttype = \
+ _generate_cpy_extern_python_collecttype
+
+ def _extern_python_decl(self, tp, name, tag_and_space):
+ prnt = self._prnt
+ if isinstance(tp.result, model.VoidType):
+ size_of_result = '0'
+ else:
+ context = 'result of %s' % name
+ size_of_result = '(int)sizeof(%s)' % (
+ tp.result.get_c_name('', context),)
+ prnt('static struct _cffi_externpy_s _cffi_externpy__%s =' % name)
+ prnt(' { "%s.%s", %s, 0, 0 };' % (
+ self.module_name, name, size_of_result))
+ prnt()
+ #
+ arguments = []
+ context = 'argument of %s' % name
+ for i, type in enumerate(tp.args):
+ arg = type.get_c_name(' a%d' % i, context)
+ arguments.append(arg)
+ #
+ repr_arguments = ', '.join(arguments)
+ repr_arguments = repr_arguments or 'void'
+ name_and_arguments = '%s(%s)' % (name, repr_arguments)
+ if tp.abi == "__stdcall":
+ name_and_arguments = '_cffi_stdcall ' + name_and_arguments
+ #
+ def may_need_128_bits(tp):
+ return (isinstance(tp, model.PrimitiveType) and
+ tp.name == 'long double')
+ #
+ size_of_a = max(len(tp.args)*8, 8)
+ if may_need_128_bits(tp.result):
+ size_of_a = max(size_of_a, 16)
+ if isinstance(tp.result, model.StructOrUnion):
+ size_of_a = 'sizeof(%s) > %d ? sizeof(%s) : %d' % (
+ tp.result.get_c_name(''), size_of_a,
+ tp.result.get_c_name(''), size_of_a)
+ prnt('%s%s' % (tag_and_space, tp.result.get_c_name(name_and_arguments)))
+ prnt('{')
+ prnt(' char a[%s];' % size_of_a)
+ prnt(' char *p = a;')
+ for i, type in enumerate(tp.args):
+ arg = 'a%d' % i
+ if (isinstance(type, model.StructOrUnion) or
+ may_need_128_bits(type)):
+ arg = '&' + arg
+ type = model.PointerType(type)
+ prnt(' *(%s)(p + %d) = %s;' % (type.get_c_name('*'), i*8, arg))
+ prnt(' _cffi_call_python(&_cffi_externpy__%s, p);' % name)
+ if not isinstance(tp.result, model.VoidType):
+ prnt(' return *(%s)p;' % (tp.result.get_c_name('*'),))
+ prnt('}')
+ prnt()
+ self._num_externpy += 1
+
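+ # Illustrative sketch (names assumed): the trampoline emitted above is the
+ # C side of an extern "Python" callback; the Python side is supplied with
+ # @ffi.def_extern() in user code, roughly:
+ #
+ #     # in the build script:
+ #     ffibuilder.cdef('extern "Python" int my_callback(int);')
+ #
+ #     # at run time:
+ #     from _example import ffi, lib
+ #     @ffi.def_extern()
+ #     def my_callback(x):
+ #         return x + 1
+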
+ def _generate_cpy_extern_python_decl(self, tp, name):
+ self._extern_python_decl(tp, name, 'static ')
+
+ def _generate_cpy_dllexport_python_decl(self, tp, name):
+ self._extern_python_decl(tp, name, 'CFFI_DLLEXPORT ')
+
+ def _generate_cpy_extern_python_plus_c_decl(self, tp, name):
+ self._extern_python_decl(tp, name, '')
+
+ def _generate_cpy_extern_python_ctx(self, tp, name):
+ if self.target_is_python:
+ raise VerificationError(
+ "cannot use 'extern \"Python\"' in the ABI mode")
+ if tp.ellipsis:
+ raise NotImplementedError("a vararg function is extern \"Python\"")
+ type_index = self._typesdict[tp]
+ type_op = CffiOp(OP_EXTERN_PYTHON, type_index)
+ self._lsts["global"].append(
+ GlobalExpr(name, '&_cffi_externpy__%s' % name, type_op, name))
+
+ _generate_cpy_dllexport_python_ctx = \
+ _generate_cpy_extern_python_plus_c_ctx = \
+ _generate_cpy_extern_python_ctx
+
+ def _print_string_literal_in_array(self, s):
+ prnt = self._prnt
+ prnt('// # NB. this is not a string because of a size limit in MSVC')
+ if not isinstance(s, bytes): # unicode
+ s = s.encode('utf-8') # -> bytes
+ else:
+ s.decode('utf-8') # got bytes, check for valid utf-8
+ try:
+ s.decode('ascii')
+ except UnicodeDecodeError:
+ s = b'# -*- encoding: utf8 -*-\n' + s
+ for line in s.splitlines(True):
+ comment = line
+ if type('//') is bytes: # python2
+ line = map(ord, line) # make a list of integers
+ else: # python3
+ # type(line) is bytes, which enumerates like a list of integers
+ comment = ascii(comment)[1:-1]
+ prnt(('// ' + comment).rstrip())
+ printed_line = ''
+ for c in line:
+ if len(printed_line) >= 76:
+ prnt(printed_line)
+ printed_line = ''
+ printed_line += '%d,' % (c,)
+ prnt(printed_line)
+
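+ # Illustrative sketch (standalone, Python 3, not used by this module): the
+ # C source is embedded as a comma-separated list of byte values because
+ # MSVC limits the length of string literals; the idea is simply:
+ #
+ #     src = b'int f(void);\n'
+ #     print(','.join(str(c) for c in src) + ',')   # 105,110,116,32,...,10,
+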
+ # ----------
+ # emitting the opcodes for individual types
+
+ def _emit_bytecode_VoidType(self, tp, index):
+ self.cffi_types[index] = CffiOp(OP_PRIMITIVE, PRIM_VOID)
+
+ def _emit_bytecode_PrimitiveType(self, tp, index):
+ prim_index = PRIMITIVE_TO_INDEX[tp.name]
+ self.cffi_types[index] = CffiOp(OP_PRIMITIVE, prim_index)
+
+ def _emit_bytecode_UnknownIntegerType(self, tp, index):
+ s = ('_cffi_prim_int(sizeof(%s), (\n'
+ ' ((%s)-1) | 0 /* check that %s is an integer type */\n'
+ ' ) <= 0)' % (tp.name, tp.name, tp.name))
+ self.cffi_types[index] = CffiOp(OP_PRIMITIVE, s)
+
+ def _emit_bytecode_UnknownFloatType(self, tp, index):
+ s = ('_cffi_prim_float(sizeof(%s) *\n'
+ ' (((%s)1) / 2) * 2 /* integer => 0, float => 1 */\n'
+ ' )' % (tp.name, tp.name))
+ self.cffi_types[index] = CffiOp(OP_PRIMITIVE, s)
+
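+ # Added note on the expression above: it relies on C integer division, so
+ # for a floating type T the factor (((T)1) / 2) * 2 evaluates to 1
+ # (0.5 * 2), while for an integer type it evaluates to 0 (1 / 2 truncates
+ # to 0); the size passed to _cffi_prim_float is thus sizeof(T) or 0.
+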
+ def _emit_bytecode_RawFunctionType(self, tp, index):
+ self.cffi_types[index] = CffiOp(OP_FUNCTION, self._typesdict[tp.result])
+ index += 1
+ for tp1 in tp.args:
+ realindex = self._typesdict[tp1]
+ if index != realindex:
+ if isinstance(tp1, model.PrimitiveType):
+ self._emit_bytecode_PrimitiveType(tp1, index)
+ else:
+ self.cffi_types[index] = CffiOp(OP_NOOP, realindex)
+ index += 1
+ flags = int(tp.ellipsis)
+ if tp.abi is not None:
+ if tp.abi == '__stdcall':
+ flags |= 2
+ else:
+ raise NotImplementedError("abi=%r" % (tp.abi,))
+ self.cffi_types[index] = CffiOp(OP_FUNCTION_END, flags)
+
+ def _emit_bytecode_PointerType(self, tp, index):
+ self.cffi_types[index] = CffiOp(OP_POINTER, self._typesdict[tp.totype])
+
+ _emit_bytecode_ConstPointerType = _emit_bytecode_PointerType
+ _emit_bytecode_NamedPointerType = _emit_bytecode_PointerType
+
+ def _emit_bytecode_FunctionPtrType(self, tp, index):
+ raw = tp.as_raw_function()
+ self.cffi_types[index] = CffiOp(OP_POINTER, self._typesdict[raw])
+
+ def _emit_bytecode_ArrayType(self, tp, index):
+ item_index = self._typesdict[tp.item]
+ if tp.length is None:
+ self.cffi_types[index] = CffiOp(OP_OPEN_ARRAY, item_index)
+ elif tp.length == '...':
+ raise VerificationError(
+ "type %s badly placed: the '...' array length can only be "
+ "used on global arrays or on fields of structures" % (
+ str(tp).replace('/*...*/', '...'),))
+ else:
+ assert self.cffi_types[index + 1] == 'LEN'
+ self.cffi_types[index] = CffiOp(OP_ARRAY, item_index)
+ self.cffi_types[index + 1] = CffiOp(None, str(tp.length))
+
+ def _emit_bytecode_StructType(self, tp, index):
+ struct_index = self._struct_unions[tp]
+ self.cffi_types[index] = CffiOp(OP_STRUCT_UNION, struct_index)
+ _emit_bytecode_UnionType = _emit_bytecode_StructType
+
+ def _emit_bytecode_EnumType(self, tp, index):
+ enum_index = self._enums[tp]
+ self.cffi_types[index] = CffiOp(OP_ENUM, enum_index)
+
+
+if sys.version_info >= (3,):
+ NativeIO = io.StringIO
+else:
+ class NativeIO(io.BytesIO):
+ def write(self, s):
+ if isinstance(s, unicode):
+ s = s.encode('ascii')
+ super(NativeIO, self).write(s)
+
+def _make_c_or_py_source(ffi, module_name, preamble, target_file, verbose):
+ if verbose:
+ print("generating %s" % (target_file,))
+ recompiler = Recompiler(ffi, module_name,
+ target_is_python=(preamble is None))
+ recompiler.collect_type_table()
+ recompiler.collect_step_tables()
+ f = NativeIO()
+ recompiler.write_source_to_f(f, preamble)
+ output = f.getvalue()
+ try:
+ with open(target_file, 'r') as f1:
+ if f1.read(len(output) + 1) != output:
+ raise IOError
+ if verbose:
+ print("(already up-to-date)")
+ return False # already up-to-date
+ except IOError:
+ tmp_file = '%s.~%d' % (target_file, os.getpid())
+ with open(tmp_file, 'w') as f1:
+ f1.write(output)
+ try:
+ os.rename(tmp_file, target_file)
+ except OSError:
+ os.unlink(target_file)
+ os.rename(tmp_file, target_file)
+ return True
+
+def make_c_source(ffi, module_name, preamble, target_c_file, verbose=False):
+ assert preamble is not None
+ return _make_c_or_py_source(ffi, module_name, preamble, target_c_file,
+ verbose)
+
+def make_py_source(ffi, module_name, target_py_file, verbose=False):
+ return _make_c_or_py_source(ffi, module_name, None, target_py_file,
+ verbose)
+
+def _modname_to_file(outputdir, modname, extension):
+ parts = modname.split('.')
+ try:
+ os.makedirs(os.path.join(outputdir, *parts[:-1]))
+ except OSError:
+ pass
+ parts[-1] += extension
+ return os.path.join(outputdir, *parts), parts
+
+
+# Aaargh. Distutils is not tested at all for the purpose of compiling
+# DLLs that are not extension modules. Here are some hacks to work
+# around that, in the _patch_for_*() functions...
+
+def _patch_meth(patchlist, cls, name, new_meth):
+ old = getattr(cls, name)
+ patchlist.append((cls, name, old))
+ setattr(cls, name, new_meth)
+ return old
+
+def _unpatch_meths(patchlist):
+ for cls, name, old_meth in reversed(patchlist):
+ setattr(cls, name, old_meth)
+
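+# Illustrative sketch (class and method names hypothetical): the two helpers
+# above implement a "patch now, restore in reverse order later" pattern:
+#
+#     patchlist = []
+#     _patch_meth(patchlist, SomeCompiler, 'link', my_link)
+#     try:
+#         run_build()            # uses the patched method
+#     finally:
+#         _unpatch_meths(patchlist)
+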
+def _patch_for_embedding(patchlist):
+ if sys.platform == 'win32':
+ # we must not remove the manifest when building for embedding!
+ from distutils.msvc9compiler import MSVCCompiler
+ _patch_meth(patchlist, MSVCCompiler, '_remove_visual_c_ref',
+ lambda self, manifest_file: manifest_file)
+
+ if sys.platform == 'darwin':
+ # we must not make a '-bundle', but a '-dynamiclib' instead
+ from distutils.ccompiler import CCompiler
+ def my_link_shared_object(self, *args, **kwds):
+ if '-bundle' in self.linker_so:
+ self.linker_so = list(self.linker_so)
+ i = self.linker_so.index('-bundle')
+ self.linker_so[i] = '-dynamiclib'
+ return old_link_shared_object(self, *args, **kwds)
+ old_link_shared_object = _patch_meth(patchlist, CCompiler,
+ 'link_shared_object',
+ my_link_shared_object)
+
+def _patch_for_target(patchlist, target):
+ from distutils.command.build_ext import build_ext
+ # if 'target' is different from '*', we need to patch some internal
+ # method to just return this 'target' value, instead of having it
+ # built from module_name
+ if target.endswith('.*'):
+ target = target[:-2]
+ if sys.platform == 'win32':
+ target += '.dll'
+ elif sys.platform == 'darwin':
+ target += '.dylib'
+ else:
+ target += '.so'
+ _patch_meth(patchlist, build_ext, 'get_ext_filename',
+ lambda self, ext_name: target)
+
+
+def recompile(ffi, module_name, preamble, tmpdir='.', call_c_compiler=True,
+ c_file=None, source_extension='.c', extradir=None,
+ compiler_verbose=1, target=None, debug=None, **kwds):
+ if not isinstance(module_name, str):
+ module_name = module_name.encode('ascii')
+ if ffi._windows_unicode:
+ ffi._apply_windows_unicode(kwds)
+ if preamble is not None:
+ embedding = (ffi._embedding is not None)
+ if embedding:
+ ffi._apply_embedding_fix(kwds)
+ if c_file is None:
+ c_file, parts = _modname_to_file(tmpdir, module_name,
+ source_extension)
+ if extradir:
+ parts = [extradir] + parts
+ ext_c_file = os.path.join(*parts)
+ else:
+ ext_c_file = c_file
+ #
+ if target is None:
+ if embedding:
+ target = '%s.*' % module_name
+ else:
+ target = '*'
+ #
+ ext = ffiplatform.get_extension(ext_c_file, module_name, **kwds)
+ updated = make_c_source(ffi, module_name, preamble, c_file,
+ verbose=compiler_verbose)
+ if call_c_compiler:
+ patchlist = []
+ cwd = os.getcwd()
+ try:
+ if embedding:
+ _patch_for_embedding(patchlist)
+ if target != '*':
+ _patch_for_target(patchlist, target)
+ if compiler_verbose:
+ if tmpdir == '.':
+ msg = 'the current directory is'
+ else:
+ msg = 'setting the current directory to'
+ print('%s %r' % (msg, os.path.abspath(tmpdir)))
+ os.chdir(tmpdir)
+ outputfilename = ffiplatform.compile('.', ext,
+ compiler_verbose, debug)
+ finally:
+ os.chdir(cwd)
+ _unpatch_meths(patchlist)
+ return outputfilename
+ else:
+ return ext, updated
+ else:
+ if c_file is None:
+ c_file, _ = _modname_to_file(tmpdir, module_name, '.py')
+ updated = make_py_source(ffi, module_name, c_file,
+ verbose=compiler_verbose)
+ if call_c_compiler:
+ return c_file
+ else:
+ return None, updated
+
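+# Illustrative sketch (module and symbol names assumed): recompile() is
+# normally reached through the public API from a build script, e.g.:
+#
+#     from cffi import FFI
+#     ffibuilder = FFI()
+#     ffibuilder.cdef("int add(int, int);")
+#     ffibuilder.set_source("_example",
+#         "static int add(int a, int b) { return a + b; }")
+#     ffibuilder.compile(verbose=True)   # ends up calling recompile() above
+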
diff --git a/lib/cffi/setuptools_ext.py b/lib/cffi/setuptools_ext.py
new file mode 100644
index 0000000..8fe3614
--- /dev/null
+++ b/lib/cffi/setuptools_ext.py
@@ -0,0 +1,219 @@
+import os
+import sys
+
+try:
+ basestring
+except NameError:
+ # Python 3.x
+ basestring = str
+
+def error(msg):
+ from distutils.errors import DistutilsSetupError
+ raise DistutilsSetupError(msg)
+
+
+def execfile(filename, glob):
+ # We use execfile() (here rewritten for Python 3) instead of
+ # __import__() to load the build script. The problem with
+ # a normal import is that in some packages, the intermediate
+ # __init__.py files may already try to import the file that
+ # we are generating.
+ with open(filename) as f:
+ src = f.read()
+ src += '\n' # Python 2.6 compatibility
+ code = compile(src, filename, 'exec')
+ exec(code, glob, glob)
+
+
+def add_cffi_module(dist, mod_spec):
+ from cffi.api import FFI
+
+ if not isinstance(mod_spec, basestring):
+ error("argument to 'cffi_modules=...' must be a str or a list of str,"
+ " not %r" % (type(mod_spec).__name__,))
+ mod_spec = str(mod_spec)
+ try:
+ build_file_name, ffi_var_name = mod_spec.split(':')
+ except ValueError:
+ error("%r must be of the form 'path/build.py:ffi_variable'" %
+ (mod_spec,))
+ if not os.path.exists(build_file_name):
+ ext = ''
+ rewritten = build_file_name.replace('.', '/') + '.py'
+ if os.path.exists(rewritten):
+ ext = ' (rewrite cffi_modules to [%r])' % (
+ rewritten + ':' + ffi_var_name,)
+ error("%r does not name an existing file%s" % (build_file_name, ext))
+
+ mod_vars = {'__name__': '__cffi__', '__file__': build_file_name}
+ execfile(build_file_name, mod_vars)
+
+ try:
+ ffi = mod_vars[ffi_var_name]
+ except KeyError:
+ error("%r: object %r not found in module" % (mod_spec,
+ ffi_var_name))
+ if not isinstance(ffi, FFI):
+ ffi = ffi() # maybe it's a function instead of directly an ffi
+ if not isinstance(ffi, FFI):
+ error("%r is not an FFI instance (got %r)" % (mod_spec,
+ type(ffi).__name__))
+ if not hasattr(ffi, '_assigned_source'):
+ error("%r: the set_source() method was not called" % (mod_spec,))
+ module_name, source, source_extension, kwds = ffi._assigned_source
+ if ffi._windows_unicode:
+ kwds = kwds.copy()
+ ffi._apply_windows_unicode(kwds)
+
+ if source is None:
+ _add_py_module(dist, ffi, module_name)
+ else:
+ _add_c_module(dist, ffi, module_name, source, source_extension, kwds)
+
+def _set_py_limited_api(Extension, kwds):
+ """
+ Add py_limited_api to kwds if setuptools >= 26 is in use.
+ Do not alter the setting if it already exists.
+ Setuptools takes care of ignoring the flag on Python 2 and PyPy.
+
+ CPython itself should ignore the flag in a debugging version
+ (by not listing .abi3.so in the extensions it supports), but
+ it doesn't so far, which causes trouble. That's why we check
+ for "not hasattr(sys, 'gettotalrefcount')" (the 2.7 compatible equivalent
+ of 'd' not in sys.abiflags). (http://bugs.python.org/issue28401)
+
+ On Windows, with CPython <= 3.4, it's better not to use py_limited_api
+ because virtualenv *still* doesn't copy PYTHON3.DLL on these versions.
+ Recently (2020) we started shipping only >= 3.5 wheels, though. So
+ we'll give it another try and set py_limited_api on Windows >= 3.5.
+ """
+ from cffi import recompiler
+
+ if ('py_limited_api' not in kwds and not hasattr(sys, 'gettotalrefcount')
+ and recompiler.USE_LIMITED_API):
+ import setuptools
+ try:
+ setuptools_major_version = int(setuptools.__version__.partition('.')[0])
+ if setuptools_major_version >= 26:
+ kwds['py_limited_api'] = True
+ except ValueError: # certain development versions of setuptools
+ # If we don't know the version number of setuptools, we
+ # try to set 'py_limited_api' anyway. At worst, we get a
+ # warning.
+ kwds['py_limited_api'] = True
+ return kwds
+
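+# Illustrative sketch (assumes setuptools >= 26, a non-debug CPython build
+# and recompiler.USE_LIMITED_API true): the helper above only injects one
+# keyword:
+#
+#     kwds = {'libraries': ['m']}
+#     kwds = _set_py_limited_api(Extension, kwds)
+#     # kwds is now {'libraries': ['m'], 'py_limited_api': True}
+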
+def _add_c_module(dist, ffi, module_name, source, source_extension, kwds):
+ from distutils.core import Extension
+ # We are a setuptools extension. Need this build_ext for py_limited_api.
+ from setuptools.command.build_ext import build_ext
+ from distutils.dir_util import mkpath
+ from distutils import log
+ from cffi import recompiler
+
+ allsources = ['$PLACEHOLDER']
+ allsources.extend(kwds.pop('sources', []))
+ kwds = _set_py_limited_api(Extension, kwds)
+ ext = Extension(name=module_name, sources=allsources, **kwds)
+
+ def make_mod(tmpdir, pre_run=None):
+ c_file = os.path.join(tmpdir, module_name + source_extension)
+ log.info("generating cffi module %r" % c_file)
+ mkpath(tmpdir)
+ # a setuptools-only, API-only hook: called with the "ext" and "ffi"
+ # arguments just before we turn the ffi into C code. To use it,
+ # subclass the 'distutils.command.build_ext.build_ext' class and
+ # add a method 'def pre_run(self, ext, ffi)'.
+ if pre_run is not None:
+ pre_run(ext, ffi)
+ updated = recompiler.make_c_source(ffi, module_name, source, c_file)
+ if not updated:
+ log.info("already up-to-date")
+ return c_file
+
+ if dist.ext_modules is None:
+ dist.ext_modules = []
+ dist.ext_modules.append(ext)
+
+ base_class = dist.cmdclass.get('build_ext', build_ext)
+ class build_ext_make_mod(base_class):
+ def run(self):
+ if ext.sources[0] == '$PLACEHOLDER':
+ pre_run = getattr(self, 'pre_run', None)
+ ext.sources[0] = make_mod(self.build_temp, pre_run)
+ base_class.run(self)
+ dist.cmdclass['build_ext'] = build_ext_make_mod
+ # NB. multiple runs here will create multiple 'build_ext_make_mod'
+ # classes. Even in this case the 'build_ext' command should be
+ # run once; but just in case, the logic above does nothing if
+ # called again.
+
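+# Illustrative sketch (hypothetical setup.py content): the pre_run() hook
+# mentioned inside make_mod() above is picked up from a custom build_ext
+# subclass passed through cmdclass, roughly:
+#
+#     from setuptools.command.build_ext import build_ext
+#
+#     class my_build_ext(build_ext):
+#         def pre_run(self, ext, ffi):
+#             print("about to generate", ext.name)
+#
+#     setup(..., cmdclass={'build_ext': my_build_ext},
+#           cffi_modules=["pkg/_build.py:ffibuilder"])
+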
+
+def _add_py_module(dist, ffi, module_name):
+ from distutils.dir_util import mkpath
+ from setuptools.command.build_py import build_py
+ from setuptools.command.build_ext import build_ext
+ from distutils import log
+ from cffi import recompiler
+
+ def generate_mod(py_file):
+ log.info("generating cffi module %r" % py_file)
+ mkpath(os.path.dirname(py_file))
+ updated = recompiler.make_py_source(ffi, module_name, py_file)
+ if not updated:
+ log.info("already up-to-date")
+
+ base_class = dist.cmdclass.get('build_py', build_py)
+ class build_py_make_mod(base_class):
+ def run(self):
+ base_class.run(self)
+ module_path = module_name.split('.')
+ module_path[-1] += '.py'
+ generate_mod(os.path.join(self.build_lib, *module_path))
+ def get_source_files(self):
+ # This is called from 'setup.py sdist' only. Exclude
+ # the generated .py module in this case.
+ saved_py_modules = self.py_modules
+ try:
+ if saved_py_modules:
+ self.py_modules = [m for m in saved_py_modules
+ if m != module_name]
+ return base_class.get_source_files(self)
+ finally:
+ self.py_modules = saved_py_modules
+ dist.cmdclass['build_py'] = build_py_make_mod
+
+ # distutils and setuptools have no notion (that I could find) of a
+ # generated python module. If we don't add module_name to
+ # dist.py_modules, then things mostly work but there are some
+ # combinations of options (--root and --record) that will miss
+ # the module. So we add it here, which gives a few apparently
+ # harmless warnings about not finding the file outside the
+ # build directory.
+ # Then we need to hack more in get_source_files(); see above.
+ if dist.py_modules is None:
+ dist.py_modules = []
+ dist.py_modules.append(module_name)
+
+ # the following is only for "build_ext -i"
+ base_class_2 = dist.cmdclass.get('build_ext', build_ext)
+ class build_ext_make_mod(base_class_2):
+ def run(self):
+ base_class_2.run(self)
+ if self.inplace:
+ # from get_ext_fullpath() in distutils/command/build_ext.py
+ module_path = module_name.split('.')
+ package = '.'.join(module_path[:-1])
+ build_py = self.get_finalized_command('build_py')
+ package_dir = build_py.get_package_dir(package)
+ file_name = module_path[-1] + '.py'
+ generate_mod(os.path.join(package_dir, file_name))
+ dist.cmdclass['build_ext'] = build_ext_make_mod
+
+def cffi_modules(dist, attr, value):
+ assert attr == 'cffi_modules'
+ if isinstance(value, basestring):
+ value = [value]
+
+ for cffi_module in value:
+ add_cffi_module(dist, cffi_module)
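+
+# Illustrative sketch (project and path names assumed): this module is wired
+# in as the setuptools handler for the 'cffi_modules' keyword, so a typical
+# setup.py only needs:
+#
+#     from setuptools import setup
+#
+#     setup(
+#         name="example",
+#         setup_requires=["cffi>=1.0.0"],
+#         cffi_modules=["pkg/_build.py:ffibuilder"],
+#         install_requires=["cffi>=1.0.0"],
+#     )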
diff --git a/lib/cffi/vengine_cpy.py b/lib/cffi/vengine_cpy.py
new file mode 100644
index 0000000..6de0df0
--- /dev/null
+++ b/lib/cffi/vengine_cpy.py
@@ -0,0 +1,1076 @@
+#
+# DEPRECATED: implementation for ffi.verify()
+#
+import sys, imp
+from . import model
+from .error import VerificationError
+
+
+class VCPythonEngine(object):
+ _class_key = 'x'
+ _gen_python_module = True
+
+ def __init__(self, verifier):
+ self.verifier = verifier
+ self.ffi = verifier.ffi
+ self._struct_pending_verification = {}
+ self._types_of_builtin_functions = {}
+
+ def patch_extension_kwds(self, kwds):
+ pass
+
+ def find_module(self, module_name, path, so_suffixes):
+ try:
+ f, filename, descr = imp.find_module(module_name, path)
+ except ImportError:
+ return None
+ if f is not None:
+ f.close()
+ # Note that after a setuptools installation, there are both .py
+ # and .so files with the same basename. The code here relies on
+ # imp.find_module() locating the .so in priority.
+ if descr[0] not in so_suffixes:
+ return None
+ return filename
+
+ def collect_types(self):
+ self._typesdict = {}
+ self._generate("collecttype")
+
+ def _prnt(self, what=''):
+ self._f.write(what + '\n')
+
+ def _gettypenum(self, type):
+ # a KeyError here is a bug. please report it! :-)
+ return self._typesdict[type]
+
+ def _do_collect_type(self, tp):
+ if ((not isinstance(tp, model.PrimitiveType)
+ or tp.name == 'long double')
+ and tp not in self._typesdict):
+ num = len(self._typesdict)
+ self._typesdict[tp] = num
+
+ def write_source_to_f(self):
+ self.collect_types()
+ #
+ # The new module will have a _cffi_setup() function that receives
+ # objects from the ffi world, and that calls some setup code in
+ # the module. This setup code is split in several independent
+ # functions, e.g. one per constant. The functions are "chained"
+ # by ending in a tail call to each other.
+ #
+ # This is further split in two chained lists, depending on if we
+ # can do it at import-time or if we must wait for _cffi_setup() to
+ # provide us with the <ctype> objects. This is needed because we
+ # need the values of the enum constants in order to build the
+ # <ctype 'enum'> that we may have to pass to _cffi_setup().
+ #
+ # The following two 'chained_list_constants' items contain
+ # the head of these two chained lists, as a string that gives the
+ # call to do, if any.
+ self._chained_list_constants = ['((void)lib,0)', '((void)lib,0)']
+ #
+ prnt = self._prnt
+ # first paste some standard set of lines that are mostly '#define'
+ prnt(cffimod_header)
+ prnt()
+ # then paste the C source given by the user, verbatim.
+ prnt(self.verifier.preamble)
+ prnt()
+ #
+ # call generate_cpy_xxx_decl(), for every xxx found from
+ # ffi._parser._declarations. This generates all the functions.
+ self._generate("decl")
+ #
+ # implement the function _cffi_setup_custom() as calling the
+ # head of the chained list.
+ self._generate_setup_custom()
+ prnt()
+ #
+ # produce the method table, including the entries for the
+ # generated Python->C function wrappers, which are done
+ # by generate_cpy_function_method().
+ prnt('static PyMethodDef _cffi_methods[] = {')
+ self._generate("method")
+ prnt(' {"_cffi_setup", _cffi_setup, METH_VARARGS, NULL},')
+ prnt(' {NULL, NULL, 0, NULL} /* Sentinel */')
+ prnt('};')
+ prnt()
+ #
+ # standard init.
+ modname = self.verifier.get_module_name()
+ constants = self._chained_list_constants[False]
+ prnt('#if PY_MAJOR_VERSION >= 3')
+ prnt()
+ prnt('static struct PyModuleDef _cffi_module_def = {')
+ prnt(' PyModuleDef_HEAD_INIT,')
+ prnt(' "%s",' % modname)
+ prnt(' NULL,')
+ prnt(' -1,')
+ prnt(' _cffi_methods,')
+ prnt(' NULL, NULL, NULL, NULL')
+ prnt('};')
+ prnt()
+ prnt('PyMODINIT_FUNC')
+ prnt('PyInit_%s(void)' % modname)
+ prnt('{')
+ prnt(' PyObject *lib;')
+ prnt(' lib = PyModule_Create(&_cffi_module_def);')
+ prnt(' if (lib == NULL)')
+ prnt(' return NULL;')
+ prnt(' if (%s < 0 || _cffi_init() < 0) {' % (constants,))
+ prnt(' Py_DECREF(lib);')
+ prnt(' return NULL;')
+ prnt(' }')
+ prnt(' return lib;')
+ prnt('}')
+ prnt()
+ prnt('#else')
+ prnt()
+ prnt('PyMODINIT_FUNC')
+ prnt('init%s(void)' % modname)
+ prnt('{')
+ prnt(' PyObject *lib;')
+ prnt(' lib = Py_InitModule("%s", _cffi_methods);' % modname)
+ prnt(' if (lib == NULL)')
+ prnt(' return;')
+ prnt(' if (%s < 0 || _cffi_init() < 0)' % (constants,))
+ prnt(' return;')
+ prnt(' return;')
+ prnt('}')
+ prnt()
+ prnt('#endif')
+
+ def load_library(self, flags=None):
+ # XXX review all usages of 'self' here!
+ # import it as a new extension module
+ imp.acquire_lock()
+ try:
+ if hasattr(sys, "getdlopenflags"):
+ previous_flags = sys.getdlopenflags()
+ try:
+ if hasattr(sys, "setdlopenflags") and flags is not None:
+ sys.setdlopenflags(flags)
+ module = imp.load_dynamic(self.verifier.get_module_name(),
+ self.verifier.modulefilename)
+ except ImportError as e:
+ error = "importing %r: %s" % (self.verifier.modulefilename, e)
+ raise VerificationError(error)
+ finally:
+ if hasattr(sys, "setdlopenflags"):
+ sys.setdlopenflags(previous_flags)
+ finally:
+ imp.release_lock()
+ #
+ # call loading_cpy_struct() to get the struct layout inferred by
+ # the C compiler
+ self._load(module, 'loading')
+ #
+ # the C code will need the <ctype> objects. Collect them in
+ # order in a list.
+ revmapping = dict([(value, key)
+ for (key, value) in self._typesdict.items()])
+ lst = [revmapping[i] for i in range(len(revmapping))]
+ lst = list(map(self.ffi._get_cached_btype, lst))
+ #
+ # build the FFILibrary class and instance and call _cffi_setup().
+ # this will set up some fields like '_cffi_types', and only then
+ # it will invoke the chained list of functions that will really
+ # build (notably) the constant objects, as <cdata> if they are
+ # pointers, and store them as attributes on the 'library' object.
+ class FFILibrary(object):
+ _cffi_python_module = module
+ _cffi_ffi = self.ffi
+ _cffi_dir = []
+ def __dir__(self):
+ return FFILibrary._cffi_dir + list(self.__dict__)
+ library = FFILibrary()
+ if module._cffi_setup(lst, VerificationError, library):
+ import warnings
+ warnings.warn("reimporting %r might overwrite older definitions"
+ % (self.verifier.get_module_name()))
+ #
+ # finally, call the loaded_cpy_xxx() functions. This will perform
+ # the final adjustments, like copying the Python->C wrapper
+ # functions from the module to the 'library' object, and setting
+ # up the FFILibrary class with properties for the global C variables.
+ self._load(module, 'loaded', library=library)
+ module._cffi_original_ffi = self.ffi
+ module._cffi_types_of_builtin_funcs = self._types_of_builtin_functions
+ return library
+
+ def _get_declarations(self):
+ lst = [(key, tp) for (key, (tp, qual)) in
+ self.ffi._parser._declarations.items()]
+ lst.sort()
+ return lst
+
+ def _generate(self, step_name):
+ for name, tp in self._get_declarations():
+ kind, realname = name.split(' ', 1)
+ try:
+ method = getattr(self, '_generate_cpy_%s_%s' % (kind,
+ step_name))
+ except AttributeError:
+ raise VerificationError(
+ "not implemented in verify(): %r" % name)
+ try:
+ method(tp, realname)
+ except Exception as e:
+ model.attach_exception_info(e, name)
+ raise
+
+ def _load(self, module, step_name, **kwds):
+ for name, tp in self._get_declarations():
+ kind, realname = name.split(' ', 1)
+ method = getattr(self, '_%s_cpy_%s' % (step_name, kind))
+ try:
+ method(tp, realname, module, **kwds)
+ except Exception as e:
+ model.attach_exception_info(e, name)
+ raise
+
+ def _generate_nothing(self, tp, name):
+ pass
+
+ def _loaded_noop(self, tp, name, module, **kwds):
+ pass
+
+ # ----------
+
+ def _convert_funcarg_to_c(self, tp, fromvar, tovar, errcode):
+ extraarg = ''
+ if isinstance(tp, model.PrimitiveType):
+ if tp.is_integer_type() and tp.name != '_Bool':
+ converter = '_cffi_to_c_int'
+ extraarg = ', %s' % tp.name
+ else:
+ converter = '(%s)_cffi_to_c_%s' % (tp.get_c_name(''),
+ tp.name.replace(' ', '_'))
+ errvalue = '-1'
+ #
+ elif isinstance(tp, model.PointerType):
+ self._convert_funcarg_to_c_ptr_or_array(tp, fromvar,
+ tovar, errcode)
+ return
+ #
+ elif isinstance(tp, (model.StructOrUnion, model.EnumType)):
+ # a struct (not a struct pointer) as a function argument
+ self._prnt(' if (_cffi_to_c((char *)&%s, _cffi_type(%d), %s) < 0)'
+ % (tovar, self._gettypenum(tp), fromvar))
+ self._prnt(' %s;' % errcode)
+ return
+ #
+ elif isinstance(tp, model.FunctionPtrType):
+ converter = '(%s)_cffi_to_c_pointer' % tp.get_c_name('')
+ extraarg = ', _cffi_type(%d)' % self._gettypenum(tp)
+ errvalue = 'NULL'
+ #
+ else:
+ raise NotImplementedError(tp)
+ #
+ self._prnt(' %s = %s(%s%s);' % (tovar, converter, fromvar, extraarg))
+ self._prnt(' if (%s == (%s)%s && PyErr_Occurred())' % (
+ tovar, tp.get_c_name(''), errvalue))
+ self._prnt(' %s;' % errcode)
+
+ def _extra_local_variables(self, tp, localvars, freelines):
+ if isinstance(tp, model.PointerType):
+ localvars.add('Py_ssize_t datasize')
+ localvars.add('struct _cffi_freeme_s *large_args_free = NULL')
+ freelines.add('if (large_args_free != NULL)'
+ ' _cffi_free_array_arguments(large_args_free);')
+
+ def _convert_funcarg_to_c_ptr_or_array(self, tp, fromvar, tovar, errcode):
+ self._prnt(' datasize = _cffi_prepare_pointer_call_argument(')
+ self._prnt(' _cffi_type(%d), %s, (char **)&%s);' % (
+ self._gettypenum(tp), fromvar, tovar))
+ self._prnt(' if (datasize != 0) {')
+ self._prnt(' %s = ((size_t)datasize) <= 640 ? '
+ 'alloca((size_t)datasize) : NULL;' % (tovar,))
+ self._prnt(' if (_cffi_convert_array_argument(_cffi_type(%d), %s, '
+ '(char **)&%s,' % (self._gettypenum(tp), fromvar, tovar))
+ self._prnt(' datasize, &large_args_free) < 0)')
+ self._prnt(' %s;' % errcode)
+ self._prnt(' }')
+
+ def _convert_expr_from_c(self, tp, var, context):
+ if isinstance(tp, model.PrimitiveType):
+ if tp.is_integer_type() and tp.name != '_Bool':
+ return '_cffi_from_c_int(%s, %s)' % (var, tp.name)
+ elif tp.name != 'long double':
+ return '_cffi_from_c_%s(%s)' % (tp.name.replace(' ', '_'), var)
+ else:
+ return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ elif isinstance(tp, (model.PointerType, model.FunctionPtrType)):
+ return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ elif isinstance(tp, model.ArrayType):
+ return '_cffi_from_c_pointer((char *)%s, _cffi_type(%d))' % (
+ var, self._gettypenum(model.PointerType(tp.item)))
+ elif isinstance(tp, model.StructOrUnion):
+ if tp.fldnames is None:
+ raise TypeError("'%s' is used as %s, but is opaque" % (
+ tp._get_c_name(), context))
+ return '_cffi_from_c_struct((char *)&%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ elif isinstance(tp, model.EnumType):
+ return '_cffi_from_c_deref((char *)&%s, _cffi_type(%d))' % (
+ var, self._gettypenum(tp))
+ else:
+ raise NotImplementedError(tp)
+
+ # ----------
+ # typedefs: generates no code so far
+
+ _generate_cpy_typedef_collecttype = _generate_nothing
+ _generate_cpy_typedef_decl = _generate_nothing
+ _generate_cpy_typedef_method = _generate_nothing
+ _loading_cpy_typedef = _loaded_noop
+ _loaded_cpy_typedef = _loaded_noop
+
+ # ----------
+ # function declarations
+
+ def _generate_cpy_function_collecttype(self, tp, name):
+ assert isinstance(tp, model.FunctionPtrType)
+ if tp.ellipsis:
+ self._do_collect_type(tp)
+ else:
+ # don't call _do_collect_type(tp) in this common case,
+ # otherwise test_autofilled_struct_as_argument fails
+ for type in tp.args:
+ self._do_collect_type(type)
+ self._do_collect_type(tp.result)
+
+ def _generate_cpy_function_decl(self, tp, name):
+ assert isinstance(tp, model.FunctionPtrType)
+ if tp.ellipsis:
+ # cannot support vararg functions better than this: check for its
+ # exact type (including the fixed arguments), and build it as a
+ # constant function pointer (no CPython wrapper)
+ self._generate_cpy_const(False, name, tp)
+ return
+ prnt = self._prnt
+ numargs = len(tp.args)
+ if numargs == 0:
+ argname = 'noarg'
+ elif numargs == 1:
+ argname = 'arg0'
+ else:
+ argname = 'args'
+ prnt('static PyObject *')
+ prnt('_cffi_f_%s(PyObject *self, PyObject *%s)' % (name, argname))
+ prnt('{')
+ #
+ context = 'argument of %s' % name
+ for i, type in enumerate(tp.args):
+ prnt(' %s;' % type.get_c_name(' x%d' % i, context))
+ #
+ localvars = set()
+ freelines = set()
+ for type in tp.args:
+ self._extra_local_variables(type, localvars, freelines)
+ for decl in sorted(localvars):
+ prnt(' %s;' % (decl,))
+ #
+ if not isinstance(tp.result, model.VoidType):
+ result_code = 'result = '
+ context = 'result of %s' % name
+ prnt(' %s;' % tp.result.get_c_name(' result', context))
+ prnt(' PyObject *pyresult;')
+ else:
+ result_code = ''
+ #
+ if len(tp.args) > 1:
+ rng = range(len(tp.args))
+ for i in rng:
+ prnt(' PyObject *arg%d;' % i)
+ prnt()
+ prnt(' if (!PyArg_ParseTuple(args, "%s:%s", %s))' % (
+ 'O' * numargs, name, ', '.join(['&arg%d' % i for i in rng])))
+ prnt(' return NULL;')
+ prnt()
+ #
+ for i, type in enumerate(tp.args):
+ self._convert_funcarg_to_c(type, 'arg%d' % i, 'x%d' % i,
+ 'return NULL')
+ prnt()
+ #
+ prnt(' Py_BEGIN_ALLOW_THREADS')
+ prnt(' _cffi_restore_errno();')
+ prnt(' { %s%s(%s); }' % (
+ result_code, name,
+ ', '.join(['x%d' % i for i in range(len(tp.args))])))
+ prnt(' _cffi_save_errno();')
+ prnt(' Py_END_ALLOW_THREADS')
+ prnt()
+ #
+ prnt(' (void)self; /* unused */')
+ if numargs == 0:
+ prnt(' (void)noarg; /* unused */')
+ if result_code:
+ prnt(' pyresult = %s;' %
+ self._convert_expr_from_c(tp.result, 'result', 'result type'))
+ for freeline in freelines:
+ prnt(' ' + freeline)
+ prnt(' return pyresult;')
+ else:
+ for freeline in freelines:
+ prnt(' ' + freeline)
+ prnt(' Py_INCREF(Py_None);')
+ prnt(' return Py_None;')
+ prnt('}')
+ prnt()
+
+ def _generate_cpy_function_method(self, tp, name):
+ if tp.ellipsis:
+ return
+ numargs = len(tp.args)
+ if numargs == 0:
+ meth = 'METH_NOARGS'
+ elif numargs == 1:
+ meth = 'METH_O'
+ else:
+ meth = 'METH_VARARGS'
+ self._prnt(' {"%s", _cffi_f_%s, %s, NULL},' % (name, name, meth))
+
+ _loading_cpy_function = _loaded_noop
+
+ def _loaded_cpy_function(self, tp, name, module, library):
+ if tp.ellipsis:
+ return
+ func = getattr(module, name)
+ setattr(library, name, func)
+ self._types_of_builtin_functions[func] = tp
+
+ # ----------
+ # named structs
+
+ _generate_cpy_struct_collecttype = _generate_nothing
+ def _generate_cpy_struct_decl(self, tp, name):
+ assert name == tp.name
+ self._generate_struct_or_union_decl(tp, 'struct', name)
+ def _generate_cpy_struct_method(self, tp, name):
+ self._generate_struct_or_union_method(tp, 'struct', name)
+ def _loading_cpy_struct(self, tp, name, module):
+ self._loading_struct_or_union(tp, 'struct', name, module)
+ def _loaded_cpy_struct(self, tp, name, module, **kwds):
+ self._loaded_struct_or_union(tp)
+
+ _generate_cpy_union_collecttype = _generate_nothing
+ def _generate_cpy_union_decl(self, tp, name):
+ assert name == tp.name
+ self._generate_struct_or_union_decl(tp, 'union', name)
+ def _generate_cpy_union_method(self, tp, name):
+ self._generate_struct_or_union_method(tp, 'union', name)
+ def _loading_cpy_union(self, tp, name, module):
+ self._loading_struct_or_union(tp, 'union', name, module)
+ def _loaded_cpy_union(self, tp, name, module, **kwds):
+ self._loaded_struct_or_union(tp)
+
+ def _generate_struct_or_union_decl(self, tp, prefix, name):
+ if tp.fldnames is None:
+ return # nothing to do with opaque structs
+ checkfuncname = '_cffi_check_%s_%s' % (prefix, name)
+ layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name)
+ cname = ('%s %s' % (prefix, name)).strip()
+ #
+ prnt = self._prnt
+ prnt('static void %s(%s *p)' % (checkfuncname, cname))
+ prnt('{')
+ prnt(' /* only to generate compile-time warnings or errors */')
+ prnt(' (void)p;')
+ for fname, ftype, fbitsize, fqual in tp.enumfields():
+ if (isinstance(ftype, model.PrimitiveType)
+ and ftype.is_integer_type()) or fbitsize >= 0:
+ # accept all integers, but complain on float or double
+ prnt(' (void)((p->%s) << 1);' % fname)
+ else:
+ # only accept exactly the type declared.
+ try:
+ prnt(' { %s = &p->%s; (void)tmp; }' % (
+ ftype.get_c_name('*tmp', 'field %r'%fname, quals=fqual),
+ fname))
+ except VerificationError as e:
+ prnt(' /* %s */' % str(e)) # cannot verify it, ignore
+ prnt('}')
+ prnt('static PyObject *')
+ prnt('%s(PyObject *self, PyObject *noarg)' % (layoutfuncname,))
+ prnt('{')
+ prnt(' struct _cffi_aligncheck { char x; %s y; };' % cname)
+ prnt(' static Py_ssize_t nums[] = {')
+ prnt(' sizeof(%s),' % cname)
+ prnt(' offsetof(struct _cffi_aligncheck, y),')
+ for fname, ftype, fbitsize, fqual in tp.enumfields():
+ if fbitsize >= 0:
+ continue # xxx ignore fbitsize for now
+ prnt(' offsetof(%s, %s),' % (cname, fname))
+ if isinstance(ftype, model.ArrayType) and ftype.length is None:
+ prnt(' 0, /* %s */' % ftype._get_c_name())
+ else:
+ prnt(' sizeof(((%s *)0)->%s),' % (cname, fname))
+ prnt(' -1')
+ prnt(' };')
+ prnt(' (void)self; /* unused */')
+ prnt(' (void)noarg; /* unused */')
+ prnt(' return _cffi_get_struct_layout(nums);')
+ prnt(' /* the next line is not executed, but compiled */')
+ prnt(' %s(0);' % (checkfuncname,))
+ prnt('}')
+ prnt()
+
+ def _generate_struct_or_union_method(self, tp, prefix, name):
+ if tp.fldnames is None:
+ return # nothing to do with opaque structs
+ layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name)
+ self._prnt(' {"%s", %s, METH_NOARGS, NULL},' % (layoutfuncname,
+ layoutfuncname))
+
+ def _loading_struct_or_union(self, tp, prefix, name, module):
+ if tp.fldnames is None:
+ return # nothing to do with opaque structs
+ layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name)
+ #
+ function = getattr(module, layoutfuncname)
+ layout = function()
+ if isinstance(tp, model.StructOrUnion) and tp.partial:
+ # use the function()'s sizes and offsets to guide the
+ # layout of the struct
+ totalsize = layout[0]
+ totalalignment = layout[1]
+ fieldofs = layout[2::2]
+ fieldsize = layout[3::2]
+ tp.force_flatten()
+ assert len(fieldofs) == len(fieldsize) == len(tp.fldnames)
+ tp.fixedlayout = fieldofs, fieldsize, totalsize, totalalignment
+ else:
+ cname = ('%s %s' % (prefix, name)).strip()
+ self._struct_pending_verification[tp] = layout, cname
+
+ def _loaded_struct_or_union(self, tp):
+ if tp.fldnames is None:
+ return # nothing to do with opaque structs
+ self.ffi._get_cached_btype(tp) # force 'fixedlayout' to be considered
+
+ if tp in self._struct_pending_verification:
+ # check that the layout sizes and offsets match the real ones
+ def check(realvalue, expectedvalue, msg):
+ if realvalue != expectedvalue:
+ raise VerificationError(
+ "%s (we have %d, but C compiler says %d)"
+ % (msg, expectedvalue, realvalue))
+ ffi = self.ffi
+ BStruct = ffi._get_cached_btype(tp)
+ layout, cname = self._struct_pending_verification.pop(tp)
+ check(layout[0], ffi.sizeof(BStruct), "wrong total size")
+ check(layout[1], ffi.alignof(BStruct), "wrong total alignment")
+ i = 2
+ for fname, ftype, fbitsize, fqual in tp.enumfields():
+ if fbitsize >= 0:
+ continue # xxx ignore fbitsize for now
+ check(layout[i], ffi.offsetof(BStruct, fname),
+ "wrong offset for field %r" % (fname,))
+ if layout[i+1] != 0:
+ BField = ffi._get_cached_btype(ftype)
+ check(layout[i+1], ffi.sizeof(BField),
+ "wrong size for field %r" % (fname,))
+ i += 2
+ assert i == len(layout)
+
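
Editor's note: the layout list consumed above is a flat sequence — total size, total alignment, then one (offset, size) pair per non-bitfield field — which is exactly what the slicing `layout[2::2]` / `layout[3::2]` relies on; the generated C array additionally ends with a -1 sentinel that never reaches the Python-level list. A minimal sketch, with hypothetical numbers for something like `struct point { int x; double y; }` on a typical 64-bit ABI:

    # Illustrative only; the concrete values depend on the compiler/ABI.
    layout = [16, 8,      # total size, total alignment
              0, 4,       # offset/size of field 'x'
              8, 8]       # offset/size of field 'y'
    totalsize, totalalignment = layout[0], layout[1]
    fieldofs, fieldsize = layout[2::2], layout[3::2]
    assert fieldofs == [0, 8] and fieldsize == [4, 8]
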
+ # ----------
+ # 'anonymous' declarations. These are produced for anonymous structs
+ # or unions; the 'name' is obtained by a typedef.
+
+ _generate_cpy_anonymous_collecttype = _generate_nothing
+
+ def _generate_cpy_anonymous_decl(self, tp, name):
+ if isinstance(tp, model.EnumType):
+ self._generate_cpy_enum_decl(tp, name, '')
+ else:
+ self._generate_struct_or_union_decl(tp, '', name)
+
+ def _generate_cpy_anonymous_method(self, tp, name):
+ if not isinstance(tp, model.EnumType):
+ self._generate_struct_or_union_method(tp, '', name)
+
+ def _loading_cpy_anonymous(self, tp, name, module):
+ if isinstance(tp, model.EnumType):
+ self._loading_cpy_enum(tp, name, module)
+ else:
+ self._loading_struct_or_union(tp, '', name, module)
+
+ def _loaded_cpy_anonymous(self, tp, name, module, **kwds):
+ if isinstance(tp, model.EnumType):
+ self._loaded_cpy_enum(tp, name, module, **kwds)
+ else:
+ self._loaded_struct_or_union(tp)
+
+ # ----------
+ # constants, likely declared with '#define'
+
+ def _generate_cpy_const(self, is_int, name, tp=None, category='const',
+ vartp=None, delayed=True, size_too=False,
+ check_value=None):
+ prnt = self._prnt
+ funcname = '_cffi_%s_%s' % (category, name)
+ prnt('static int %s(PyObject *lib)' % funcname)
+ prnt('{')
+ prnt(' PyObject *o;')
+ prnt(' int res;')
+ if not is_int:
+ prnt(' %s;' % (vartp or tp).get_c_name(' i', name))
+ else:
+ assert category == 'const'
+ #
+ if check_value is not None:
+ self._check_int_constant_value(name, check_value)
+ #
+ if not is_int:
+ if category == 'var':
+ realexpr = '&' + name
+ else:
+ realexpr = name
+ prnt(' i = (%s);' % (realexpr,))
+ prnt(' o = %s;' % (self._convert_expr_from_c(tp, 'i',
+ 'variable type'),))
+ assert delayed
+ else:
+ prnt(' o = _cffi_from_c_int_const(%s);' % name)
+ prnt(' if (o == NULL)')
+ prnt(' return -1;')
+ if size_too:
+ prnt(' {')
+ prnt(' PyObject *o1 = o;')
+ prnt(' o = Py_BuildValue("On", o1, (Py_ssize_t)sizeof(%s));'
+ % (name,))
+ prnt(' Py_DECREF(o1);')
+ prnt(' if (o == NULL)')
+ prnt(' return -1;')
+ prnt(' }')
+ prnt(' res = PyObject_SetAttrString(lib, "%s", o);' % name)
+ prnt(' Py_DECREF(o);')
+ prnt(' if (res < 0)')
+ prnt(' return -1;')
+ prnt(' return %s;' % self._chained_list_constants[delayed])
+ self._chained_list_constants[delayed] = funcname + '(lib)'
+ prnt('}')
+ prnt()
+
+ def _generate_cpy_constant_collecttype(self, tp, name):
+ is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type()
+ if not is_int:
+ self._do_collect_type(tp)
+
+ def _generate_cpy_constant_decl(self, tp, name):
+ is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type()
+ self._generate_cpy_const(is_int, name, tp)
+
+ _generate_cpy_constant_method = _generate_nothing
+ _loading_cpy_constant = _loaded_noop
+ _loaded_cpy_constant = _loaded_noop
+
+ # ----------
+ # enums
+
+ def _check_int_constant_value(self, name, value, err_prefix=''):
+ prnt = self._prnt
+ if value <= 0:
+ prnt(' if ((%s) > 0 || (long)(%s) != %dL) {' % (
+ name, name, value))
+ else:
+ prnt(' if ((%s) <= 0 || (unsigned long)(%s) != %dUL) {' % (
+ name, name, value))
+ prnt(' char buf[64];')
+ prnt(' if ((%s) <= 0)' % name)
+ prnt(' snprintf(buf, 63, "%%ld", (long)(%s));' % name)
+ prnt(' else')
+ prnt(' snprintf(buf, 63, "%%lu", (unsigned long)(%s));' %
+ name)
+ prnt(' PyErr_Format(_cffi_VerificationError,')
+ prnt(' "%s%s has the real value %s, not %s",')
+ prnt(' "%s", "%s", buf, "%d");' % (
+ err_prefix, name, value))
+ prnt(' return -1;')
+ prnt(' }')
+
+ def _enum_funcname(self, prefix, name):
+ # "$enum_$1" => "___D_enum____D_1"
+ name = name.replace('$', '___D_')
+ return '_cffi_e_%s_%s' % (prefix, name)
+
+ def _generate_cpy_enum_decl(self, tp, name, prefix='enum'):
+ if tp.partial:
+ for enumerator in tp.enumerators:
+ self._generate_cpy_const(True, enumerator, delayed=False)
+ return
+ #
+ funcname = self._enum_funcname(prefix, name)
+ prnt = self._prnt
+ prnt('static int %s(PyObject *lib)' % funcname)
+ prnt('{')
+ for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
+ self._check_int_constant_value(enumerator, enumvalue,
+ "enum %s: " % name)
+ prnt(' return %s;' % self._chained_list_constants[True])
+ self._chained_list_constants[True] = funcname + '(lib)'
+ prnt('}')
+ prnt()
+
+ _generate_cpy_enum_collecttype = _generate_nothing
+ _generate_cpy_enum_method = _generate_nothing
+
+ def _loading_cpy_enum(self, tp, name, module):
+ if tp.partial:
+ enumvalues = [getattr(module, enumerator)
+ for enumerator in tp.enumerators]
+ tp.enumvalues = tuple(enumvalues)
+ tp.partial_resolved = True
+
+ def _loaded_cpy_enum(self, tp, name, module, library):
+ for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
+ setattr(library, enumerator, enumvalue)
+
+ # ----------
+ # macros: for now only for integers
+
+ def _generate_cpy_macro_decl(self, tp, name):
+ if tp == '...':
+ check_value = None
+ else:
+ check_value = tp # an integer
+ self._generate_cpy_const(True, name, check_value=check_value)
+
+ _generate_cpy_macro_collecttype = _generate_nothing
+ _generate_cpy_macro_method = _generate_nothing
+ _loading_cpy_macro = _loaded_noop
+ _loaded_cpy_macro = _loaded_noop
+
+ # ----------
+ # global variables
+
+ def _generate_cpy_variable_collecttype(self, tp, name):
+ if isinstance(tp, model.ArrayType):
+ tp_ptr = model.PointerType(tp.item)
+ else:
+ tp_ptr = model.PointerType(tp)
+ self._do_collect_type(tp_ptr)
+
+ def _generate_cpy_variable_decl(self, tp, name):
+ if isinstance(tp, model.ArrayType):
+ tp_ptr = model.PointerType(tp.item)
+ self._generate_cpy_const(False, name, tp, vartp=tp_ptr,
+ size_too = tp.length_is_unknown())
+ else:
+ tp_ptr = model.PointerType(tp)
+ self._generate_cpy_const(False, name, tp_ptr, category='var')
+
+ _generate_cpy_variable_method = _generate_nothing
+ _loading_cpy_variable = _loaded_noop
+
+ def _loaded_cpy_variable(self, tp, name, module, library):
+ value = getattr(library, name)
+ if isinstance(tp, model.ArrayType): # int a[5] is "constant" in the
+ # sense that "a=..." is forbidden
+ if tp.length_is_unknown():
+ assert isinstance(value, tuple)
+ (value, size) = value
+ BItemType = self.ffi._get_cached_btype(tp.item)
+ length, rest = divmod(size, self.ffi.sizeof(BItemType))
+ if rest != 0:
+ raise VerificationError(
+ "bad size: %r does not seem to be an array of %s" %
+ (name, tp.item))
+ tp = tp.resolve_length(length)
+ # 'value' is a <cdata 'type *'> which we have to replace with
+ # a <cdata 'type[N]'> if the N is actually known
+ if tp.length is not None:
+ BArray = self.ffi._get_cached_btype(tp)
+ value = self.ffi.cast(BArray, value)
+ setattr(library, name, value)
+ return
+ # remove ptr=<cdata 'int *'> from the library instance, and replace
+ # it by a property on the class, which reads/writes into ptr[0].
+ ptr = value
+ delattr(library, name)
+ def getter(library):
+ return ptr[0]
+ def setter(library, value):
+ ptr[0] = value
+ setattr(type(library), name, property(getter, setter))
+ type(library)._cffi_dir.append(name)
+
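
Editor's note: for a non-array global the loader above removes the plain attribute and installs a property on the FFILibrary class, so every read or write goes through `ptr[0]`. A standalone sketch of the same trick, with a one-element list standing in for the `<cdata 'int *'>` and a hypothetical attribute name:

    class Lib(object):
        pass

    ptr = [42]                      # plays the role of ptr; ptr[0] is the C global
    def getter(lib):
        return ptr[0]
    def setter(lib, value):
        ptr[0] = value
    Lib.errcode = property(getter, setter)   # 'errcode' is a made-up name

    lib = Lib()
    lib.errcode = 7                 # writes through to ptr[0]
    assert ptr[0] == 7 and lib.errcode == 7
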
+ # ----------
+
+ def _generate_setup_custom(self):
+ prnt = self._prnt
+ prnt('static int _cffi_setup_custom(PyObject *lib)')
+ prnt('{')
+ prnt(' return %s;' % self._chained_list_constants[True])
+ prnt('}')
+
+cffimod_header = r'''
+#include <Python.h>
+#include <stddef.h>
+
+/* this block of #ifs should be kept exactly identical between
+ c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py
+ and cffi/_cffi_include.h */
+#if defined(_MSC_VER)
+# include <malloc.h> /* for alloca() */
+# if _MSC_VER < 1600 /* MSVC < 2010 */
+ typedef __int8 int8_t;
+ typedef __int16 int16_t;
+ typedef __int32 int32_t;
+ typedef __int64 int64_t;
+ typedef unsigned __int8 uint8_t;
+ typedef unsigned __int16 uint16_t;
+ typedef unsigned __int32 uint32_t;
+ typedef unsigned __int64 uint64_t;
+ typedef __int8 int_least8_t;
+ typedef __int16 int_least16_t;
+ typedef __int32 int_least32_t;
+ typedef __int64 int_least64_t;
+ typedef unsigned __int8 uint_least8_t;
+ typedef unsigned __int16 uint_least16_t;
+ typedef unsigned __int32 uint_least32_t;
+ typedef unsigned __int64 uint_least64_t;
+ typedef __int8 int_fast8_t;
+ typedef __int16 int_fast16_t;
+ typedef __int32 int_fast32_t;
+ typedef __int64 int_fast64_t;
+ typedef unsigned __int8 uint_fast8_t;
+ typedef unsigned __int16 uint_fast16_t;
+ typedef unsigned __int32 uint_fast32_t;
+ typedef unsigned __int64 uint_fast64_t;
+ typedef __int64 intmax_t;
+ typedef unsigned __int64 uintmax_t;
+# else
+# include <stdint.h>
+# endif
+# if _MSC_VER < 1800 /* MSVC < 2013 */
+# ifndef __cplusplus
+ typedef unsigned char _Bool;
+# endif
+# endif
+#else
+# include <stdint.h>
+# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux)
+# include <alloca.h>
+# endif
+#endif
+
+#if PY_MAJOR_VERSION < 3
+# undef PyCapsule_CheckExact
+# undef PyCapsule_GetPointer
+# define PyCapsule_CheckExact(capsule) (PyCObject_Check(capsule))
+# define PyCapsule_GetPointer(capsule, name) \
+ (PyCObject_AsVoidPtr(capsule))
+#endif
+
+#if PY_MAJOR_VERSION >= 3
+# define PyInt_FromLong PyLong_FromLong
+#endif
+
+#define _cffi_from_c_double PyFloat_FromDouble
+#define _cffi_from_c_float PyFloat_FromDouble
+#define _cffi_from_c_long PyInt_FromLong
+#define _cffi_from_c_ulong PyLong_FromUnsignedLong
+#define _cffi_from_c_longlong PyLong_FromLongLong
+#define _cffi_from_c_ulonglong PyLong_FromUnsignedLongLong
+#define _cffi_from_c__Bool PyBool_FromLong
+
+#define _cffi_to_c_double PyFloat_AsDouble
+#define _cffi_to_c_float PyFloat_AsDouble
+
+#define _cffi_from_c_int_const(x) \
+ (((x) > 0) ? \
+ ((unsigned long long)(x) <= (unsigned long long)LONG_MAX) ? \
+ PyInt_FromLong((long)(x)) : \
+ PyLong_FromUnsignedLongLong((unsigned long long)(x)) : \
+ ((long long)(x) >= (long long)LONG_MIN) ? \
+ PyInt_FromLong((long)(x)) : \
+ PyLong_FromLongLong((long long)(x)))
+
+#define _cffi_from_c_int(x, type) \
+ (((type)-1) > 0 ? /* unsigned */ \
+ (sizeof(type) < sizeof(long) ? \
+ PyInt_FromLong((long)x) : \
+ sizeof(type) == sizeof(long) ? \
+ PyLong_FromUnsignedLong((unsigned long)x) : \
+ PyLong_FromUnsignedLongLong((unsigned long long)x)) : \
+ (sizeof(type) <= sizeof(long) ? \
+ PyInt_FromLong((long)x) : \
+ PyLong_FromLongLong((long long)x)))
+
+#define _cffi_to_c_int(o, type) \
+ ((type)( \
+ sizeof(type) == 1 ? (((type)-1) > 0 ? (type)_cffi_to_c_u8(o) \
+ : (type)_cffi_to_c_i8(o)) : \
+ sizeof(type) == 2 ? (((type)-1) > 0 ? (type)_cffi_to_c_u16(o) \
+ : (type)_cffi_to_c_i16(o)) : \
+ sizeof(type) == 4 ? (((type)-1) > 0 ? (type)_cffi_to_c_u32(o) \
+ : (type)_cffi_to_c_i32(o)) : \
+ sizeof(type) == 8 ? (((type)-1) > 0 ? (type)_cffi_to_c_u64(o) \
+ : (type)_cffi_to_c_i64(o)) : \
+ (Py_FatalError("unsupported size for type " #type), (type)0)))
+
+#define _cffi_to_c_i8 \
+ ((int(*)(PyObject *))_cffi_exports[1])
+#define _cffi_to_c_u8 \
+ ((int(*)(PyObject *))_cffi_exports[2])
+#define _cffi_to_c_i16 \
+ ((int(*)(PyObject *))_cffi_exports[3])
+#define _cffi_to_c_u16 \
+ ((int(*)(PyObject *))_cffi_exports[4])
+#define _cffi_to_c_i32 \
+ ((int(*)(PyObject *))_cffi_exports[5])
+#define _cffi_to_c_u32 \
+ ((unsigned int(*)(PyObject *))_cffi_exports[6])
+#define _cffi_to_c_i64 \
+ ((long long(*)(PyObject *))_cffi_exports[7])
+#define _cffi_to_c_u64 \
+ ((unsigned long long(*)(PyObject *))_cffi_exports[8])
+#define _cffi_to_c_char \
+ ((int(*)(PyObject *))_cffi_exports[9])
+#define _cffi_from_c_pointer \
+ ((PyObject *(*)(char *, CTypeDescrObject *))_cffi_exports[10])
+#define _cffi_to_c_pointer \
+ ((char *(*)(PyObject *, CTypeDescrObject *))_cffi_exports[11])
+#define _cffi_get_struct_layout \
+ ((PyObject *(*)(Py_ssize_t[]))_cffi_exports[12])
+#define _cffi_restore_errno \
+ ((void(*)(void))_cffi_exports[13])
+#define _cffi_save_errno \
+ ((void(*)(void))_cffi_exports[14])
+#define _cffi_from_c_char \
+ ((PyObject *(*)(char))_cffi_exports[15])
+#define _cffi_from_c_deref \
+ ((PyObject *(*)(char *, CTypeDescrObject *))_cffi_exports[16])
+#define _cffi_to_c \
+ ((int(*)(char *, CTypeDescrObject *, PyObject *))_cffi_exports[17])
+#define _cffi_from_c_struct \
+ ((PyObject *(*)(char *, CTypeDescrObject *))_cffi_exports[18])
+#define _cffi_to_c_wchar_t \
+ ((wchar_t(*)(PyObject *))_cffi_exports[19])
+#define _cffi_from_c_wchar_t \
+ ((PyObject *(*)(wchar_t))_cffi_exports[20])
+#define _cffi_to_c_long_double \
+ ((long double(*)(PyObject *))_cffi_exports[21])
+#define _cffi_to_c__Bool \
+ ((_Bool(*)(PyObject *))_cffi_exports[22])
+#define _cffi_prepare_pointer_call_argument \
+ ((Py_ssize_t(*)(CTypeDescrObject *, PyObject *, char **))_cffi_exports[23])
+#define _cffi_convert_array_from_object \
+ ((int(*)(char *, CTypeDescrObject *, PyObject *))_cffi_exports[24])
+#define _CFFI_NUM_EXPORTS 25
+
+typedef struct _ctypedescr CTypeDescrObject;
+
+static void *_cffi_exports[_CFFI_NUM_EXPORTS];
+static PyObject *_cffi_types, *_cffi_VerificationError;
+
+static int _cffi_setup_custom(PyObject *lib); /* forward */
+
+static PyObject *_cffi_setup(PyObject *self, PyObject *args)
+{
+ PyObject *library;
+ int was_alive = (_cffi_types != NULL);
+ (void)self; /* unused */
+ if (!PyArg_ParseTuple(args, "OOO", &_cffi_types, &_cffi_VerificationError,
+ &library))
+ return NULL;
+ Py_INCREF(_cffi_types);
+ Py_INCREF(_cffi_VerificationError);
+ if (_cffi_setup_custom(library) < 0)
+ return NULL;
+ return PyBool_FromLong(was_alive);
+}
+
+union _cffi_union_alignment_u {
+ unsigned char m_char;
+ unsigned short m_short;
+ unsigned int m_int;
+ unsigned long m_long;
+ unsigned long long m_longlong;
+ float m_float;
+ double m_double;
+ long double m_longdouble;
+};
+
+struct _cffi_freeme_s {
+ struct _cffi_freeme_s *next;
+ union _cffi_union_alignment_u alignment;
+};
+
+#ifdef __GNUC__
+ __attribute__((unused))
+#endif
+static int _cffi_convert_array_argument(CTypeDescrObject *ctptr, PyObject *arg,
+ char **output_data, Py_ssize_t datasize,
+ struct _cffi_freeme_s **freeme)
+{
+ char *p;
+ if (datasize < 0)
+ return -1;
+
+ p = *output_data;
+ if (p == NULL) {
+ struct _cffi_freeme_s *fp = (struct _cffi_freeme_s *)PyObject_Malloc(
+ offsetof(struct _cffi_freeme_s, alignment) + (size_t)datasize);
+ if (fp == NULL)
+ return -1;
+ fp->next = *freeme;
+ *freeme = fp;
+ p = *output_data = (char *)&fp->alignment;
+ }
+ memset((void *)p, 0, (size_t)datasize);
+ return _cffi_convert_array_from_object(p, ctptr, arg);
+}
+
+#ifdef __GNUC__
+ __attribute__((unused))
+#endif
+static void _cffi_free_array_arguments(struct _cffi_freeme_s *freeme)
+{
+ do {
+ void *p = (void *)freeme;
+ freeme = freeme->next;
+ PyObject_Free(p);
+ } while (freeme != NULL);
+}
+
+static int _cffi_init(void)
+{
+ PyObject *module, *c_api_object = NULL;
+
+ module = PyImport_ImportModule("_cffi_backend");
+ if (module == NULL)
+ goto failure;
+
+ c_api_object = PyObject_GetAttrString(module, "_C_API");
+ if (c_api_object == NULL)
+ goto failure;
+ if (!PyCapsule_CheckExact(c_api_object)) {
+ PyErr_SetNone(PyExc_ImportError);
+ goto failure;
+ }
+ memcpy(_cffi_exports, PyCapsule_GetPointer(c_api_object, "cffi"),
+ _CFFI_NUM_EXPORTS * sizeof(void *));
+
+ Py_DECREF(module);
+ Py_DECREF(c_api_object);
+ return 0;
+
+ failure:
+ Py_XDECREF(module);
+ Py_XDECREF(c_api_object);
+ return -1;
+}
+
+#define _cffi_type(num) ((CTypeDescrObject *)PyList_GET_ITEM(_cffi_types, num))
+
+/**********/
+'''
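
Editor's note: throughout this engine, every generated constant/enum initializer ends with `return <previously registered initializer>(lib);` and then registers itself in `_chained_list_constants`, so `_cffi_setup_custom()` only has to call the head of the chain to run all of them. A plain-Python sketch of that chaining idea (names and values here are illustrative, not part of the diff):

    chain = lambda lib: 0            # initial "do nothing" link

    def make_link(name, prev):
        def link(lib):
            lib[name] = len(name)    # stand-in for "set the constant on lib"
            return prev(lib)
        return link

    for const in ['FOO', 'BAR', 'BAZ']:
        chain = make_link(const, chain)

    lib = {}
    chain(lib)                       # one call initializes every constant
    assert set(lib) == {'FOO', 'BAR', 'BAZ'}
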
diff --git a/lib/cffi/vengine_gen.py b/lib/cffi/vengine_gen.py
new file mode 100644
index 0000000..2642152
--- /dev/null
+++ b/lib/cffi/vengine_gen.py
@@ -0,0 +1,675 @@
+#
+# DEPRECATED: implementation for ffi.verify()
+#
+import sys, os
+import types
+
+from . import model
+from .error import VerificationError
+
+
+class VGenericEngine(object):
+ _class_key = 'g'
+ _gen_python_module = False
+
+ def __init__(self, verifier):
+ self.verifier = verifier
+ self.ffi = verifier.ffi
+ self.export_symbols = []
+ self._struct_pending_verification = {}
+
+ def patch_extension_kwds(self, kwds):
+ # add 'export_symbols' to the dictionary. Note that we add the
+ # list before filling it. When we fill it, it will thus also show
+ # up in kwds['export_symbols'].
+ kwds.setdefault('export_symbols', self.export_symbols)
+
+ def find_module(self, module_name, path, so_suffixes):
+ for so_suffix in so_suffixes:
+ basename = module_name + so_suffix
+ if path is None:
+ path = sys.path
+ for dirname in path:
+ filename = os.path.join(dirname, basename)
+ if os.path.isfile(filename):
+ return filename
+
+ def collect_types(self):
+ pass # not needed in the generic engine
+
+ def _prnt(self, what=''):
+ self._f.write(what + '\n')
+
+ def write_source_to_f(self):
+ prnt = self._prnt
+ # first paste some standard set of lines that are mostly '#include'
+ prnt(cffimod_header)
+ # then paste the C source given by the user, verbatim.
+ prnt(self.verifier.preamble)
+ #
+ # call generate_gen_xxx_decl(), for every xxx found from
+ # ffi._parser._declarations. This generates all the functions.
+ self._generate('decl')
+ #
+ # on Windows, distutils insists on putting init_cffi_xyz in
+ # 'export_symbols', so instead of fighting it, just give up and
+ # give it one
+ if sys.platform == 'win32':
+ if sys.version_info >= (3,):
+ prefix = 'PyInit_'
+ else:
+ prefix = 'init'
+ modname = self.verifier.get_module_name()
+ prnt("void %s%s(void) { }\n" % (prefix, modname))
+
+ def load_library(self, flags=0):
+ # import it with the CFFI backend
+ backend = self.ffi._backend
+ # needs to make a path that contains '/', on Posix
+ filename = os.path.join(os.curdir, self.verifier.modulefilename)
+ module = backend.load_library(filename, flags)
+ #
+ # call loading_gen_struct() to get the struct layout inferred by
+ # the C compiler
+ self._load(module, 'loading')
+
+ # build the FFILibrary class and instance, this is a module subclass
+ # because modules are expected to have usually-constant-attributes and
+ # in PyPy this means the JIT is able to treat attributes as constant,
+ # which we want.
+ class FFILibrary(types.ModuleType):
+ _cffi_generic_module = module
+ _cffi_ffi = self.ffi
+ _cffi_dir = []
+ def __dir__(self):
+ return FFILibrary._cffi_dir
+ library = FFILibrary("")
+ #
+ # finally, call the loaded_gen_xxx() functions. This will set
+ # up the 'library' object.
+ self._load(module, 'loaded', library=library)
+ return library
+
+ def _get_declarations(self):
+ lst = [(key, tp) for (key, (tp, qual)) in
+ self.ffi._parser._declarations.items()]
+ lst.sort()
+ return lst
+
+ def _generate(self, step_name):
+ for name, tp in self._get_declarations():
+ kind, realname = name.split(' ', 1)
+ try:
+ method = getattr(self, '_generate_gen_%s_%s' % (kind,
+ step_name))
+ except AttributeError:
+ raise VerificationError(
+ "not implemented in verify(): %r" % name)
+ try:
+ method(tp, realname)
+ except Exception as e:
+ model.attach_exception_info(e, name)
+ raise
+
+ def _load(self, module, step_name, **kwds):
+ for name, tp in self._get_declarations():
+ kind, realname = name.split(' ', 1)
+ method = getattr(self, '_%s_gen_%s' % (step_name, kind))
+ try:
+ method(tp, realname, module, **kwds)
+ except Exception as e:
+ model.attach_exception_info(e, name)
+ raise
+
+ def _generate_nothing(self, tp, name):
+ pass
+
+ def _loaded_noop(self, tp, name, module, **kwds):
+ pass
+
+ # ----------
+ # typedefs: generates no code so far
+
+ _generate_gen_typedef_decl = _generate_nothing
+ _loading_gen_typedef = _loaded_noop
+ _loaded_gen_typedef = _loaded_noop
+
+ # ----------
+ # function declarations
+
+ def _generate_gen_function_decl(self, tp, name):
+ assert isinstance(tp, model.FunctionPtrType)
+ if tp.ellipsis:
+ # cannot support vararg functions better than this: check for its
+ # exact type (including the fixed arguments), and build it as a
+ # constant function pointer (no _cffi_f_%s wrapper)
+ self._generate_gen_const(False, name, tp)
+ return
+ prnt = self._prnt
+ numargs = len(tp.args)
+ argnames = []
+ for i, type in enumerate(tp.args):
+ indirection = ''
+ if isinstance(type, model.StructOrUnion):
+ indirection = '*'
+ argnames.append('%sx%d' % (indirection, i))
+ context = 'argument of %s' % name
+ arglist = [type.get_c_name(' %s' % arg, context)
+ for type, arg in zip(tp.args, argnames)]
+ tpresult = tp.result
+ if isinstance(tpresult, model.StructOrUnion):
+ arglist.insert(0, tpresult.get_c_name(' *r', context))
+ tpresult = model.void_type
+ arglist = ', '.join(arglist) or 'void'
+ wrappername = '_cffi_f_%s' % name
+ self.export_symbols.append(wrappername)
+ if tp.abi:
+ abi = tp.abi + ' '
+ else:
+ abi = ''
+ funcdecl = ' %s%s(%s)' % (abi, wrappername, arglist)
+ context = 'result of %s' % name
+ prnt(tpresult.get_c_name(funcdecl, context))
+ prnt('{')
+ #
+ if isinstance(tp.result, model.StructOrUnion):
+ result_code = '*r = '
+ elif not isinstance(tp.result, model.VoidType):
+ result_code = 'return '
+ else:
+ result_code = ''
+ prnt(' %s%s(%s);' % (result_code, name, ', '.join(argnames)))
+ prnt('}')
+ prnt()
+
+ _loading_gen_function = _loaded_noop
+
+ def _loaded_gen_function(self, tp, name, module, library):
+ assert isinstance(tp, model.FunctionPtrType)
+ if tp.ellipsis:
+ newfunction = self._load_constant(False, tp, name, module)
+ else:
+ indirections = []
+ base_tp = tp
+ if (any(isinstance(typ, model.StructOrUnion) for typ in tp.args)
+ or isinstance(tp.result, model.StructOrUnion)):
+ indirect_args = []
+ for i, typ in enumerate(tp.args):
+ if isinstance(typ, model.StructOrUnion):
+ typ = model.PointerType(typ)
+ indirections.append((i, typ))
+ indirect_args.append(typ)
+ indirect_result = tp.result
+ if isinstance(indirect_result, model.StructOrUnion):
+ if indirect_result.fldtypes is None:
+ raise TypeError("'%s' is used as result type, "
+ "but is opaque" % (
+ indirect_result._get_c_name(),))
+ indirect_result = model.PointerType(indirect_result)
+ indirect_args.insert(0, indirect_result)
+ indirections.insert(0, ("result", indirect_result))
+ indirect_result = model.void_type
+ tp = model.FunctionPtrType(tuple(indirect_args),
+ indirect_result, tp.ellipsis)
+ BFunc = self.ffi._get_cached_btype(tp)
+ wrappername = '_cffi_f_%s' % name
+ newfunction = module.load_function(BFunc, wrappername)
+ for i, typ in indirections:
+ newfunction = self._make_struct_wrapper(newfunction, i, typ,
+ base_tp)
+ setattr(library, name, newfunction)
+ type(library)._cffi_dir.append(name)
+
+ def _make_struct_wrapper(self, oldfunc, i, tp, base_tp):
+ backend = self.ffi._backend
+ BType = self.ffi._get_cached_btype(tp)
+ if i == "result":
+ ffi = self.ffi
+ def newfunc(*args):
+ res = ffi.new(BType)
+ oldfunc(res, *args)
+ return res[0]
+ else:
+ def newfunc(*args):
+ args = args[:i] + (backend.newp(BType, args[i]),) + args[i+1:]
+ return oldfunc(*args)
+ newfunc._cffi_base_type = base_tp
+ return newfunc
+
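
Editor's note: in this generic engine, functions that take or return a struct by value are exported as `_cffi_f_*` wrappers that pass those values by pointer; `_make_struct_wrapper` then restores the by-value signature on the Python side. An illustrative sketch of the "result" case, with a one-element list standing in for `ffi.new(BType)`:

    def c_level_func(out, x):        # hypothetical _cffi_f_* wrapper behaviour
        out[0] = {'field': x * 2}    # "fills in *r"

    def newfunc(x):
        res = [None]                 # plays the role of ffi.new(BType)
        c_level_func(res, x)
        return res[0]                # as in the 'result' branch above

    assert newfunc(3) == {'field': 6}
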
+ # ----------
+ # named structs
+
+ def _generate_gen_struct_decl(self, tp, name):
+ assert name == tp.name
+ self._generate_struct_or_union_decl(tp, 'struct', name)
+
+ def _loading_gen_struct(self, tp, name, module):
+ self._loading_struct_or_union(tp, 'struct', name, module)
+
+ def _loaded_gen_struct(self, tp, name, module, **kwds):
+ self._loaded_struct_or_union(tp)
+
+ def _generate_gen_union_decl(self, tp, name):
+ assert name == tp.name
+ self._generate_struct_or_union_decl(tp, 'union', name)
+
+ def _loading_gen_union(self, tp, name, module):
+ self._loading_struct_or_union(tp, 'union', name, module)
+
+ def _loaded_gen_union(self, tp, name, module, **kwds):
+ self._loaded_struct_or_union(tp)
+
+ def _generate_struct_or_union_decl(self, tp, prefix, name):
+ if tp.fldnames is None:
+ return # nothing to do with opaque structs
+ checkfuncname = '_cffi_check_%s_%s' % (prefix, name)
+ layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name)
+ cname = ('%s %s' % (prefix, name)).strip()
+ #
+ prnt = self._prnt
+ prnt('static void %s(%s *p)' % (checkfuncname, cname))
+ prnt('{')
+ prnt(' /* only to generate compile-time warnings or errors */')
+ prnt(' (void)p;')
+ for fname, ftype, fbitsize, fqual in tp.enumfields():
+ if (isinstance(ftype, model.PrimitiveType)
+ and ftype.is_integer_type()) or fbitsize >= 0:
+ # accept all integers, but complain on float or double
+ prnt(' (void)((p->%s) << 1);' % fname)
+ else:
+ # only accept exactly the type declared.
+ try:
+ prnt(' { %s = &p->%s; (void)tmp; }' % (
+ ftype.get_c_name('*tmp', 'field %r'%fname, quals=fqual),
+ fname))
+ except VerificationError as e:
+ prnt(' /* %s */' % str(e)) # cannot verify it, ignore
+ prnt('}')
+ self.export_symbols.append(layoutfuncname)
+ prnt('intptr_t %s(intptr_t i)' % (layoutfuncname,))
+ prnt('{')
+ prnt(' struct _cffi_aligncheck { char x; %s y; };' % cname)
+ prnt(' static intptr_t nums[] = {')
+ prnt(' sizeof(%s),' % cname)
+ prnt(' offsetof(struct _cffi_aligncheck, y),')
+ for fname, ftype, fbitsize, fqual in tp.enumfields():
+ if fbitsize >= 0:
+ continue # xxx ignore fbitsize for now
+ prnt(' offsetof(%s, %s),' % (cname, fname))
+ if isinstance(ftype, model.ArrayType) and ftype.length is None:
+ prnt(' 0, /* %s */' % ftype._get_c_name())
+ else:
+ prnt(' sizeof(((%s *)0)->%s),' % (cname, fname))
+ prnt(' -1')
+ prnt(' };')
+ prnt(' return nums[i];')
+ prnt(' /* the next line is not executed, but compiled */')
+ prnt(' %s(0);' % (checkfuncname,))
+ prnt('}')
+ prnt()
+
+ def _loading_struct_or_union(self, tp, prefix, name, module):
+ if tp.fldnames is None:
+ return # nothing to do with opaque structs
+ layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name)
+ #
+ BFunc = self.ffi._typeof_locked("intptr_t(*)(intptr_t)")[0]
+ function = module.load_function(BFunc, layoutfuncname)
+ layout = []
+ num = 0
+ while True:
+ x = function(num)
+ if x < 0: break
+ layout.append(x)
+ num += 1
+ if isinstance(tp, model.StructOrUnion) and tp.partial:
+ # use the function()'s sizes and offsets to guide the
+ # layout of the struct
+ totalsize = layout[0]
+ totalalignment = layout[1]
+ fieldofs = layout[2::2]
+ fieldsize = layout[3::2]
+ tp.force_flatten()
+ assert len(fieldofs) == len(fieldsize) == len(tp.fldnames)
+ tp.fixedlayout = fieldofs, fieldsize, totalsize, totalalignment
+ else:
+ cname = ('%s %s' % (prefix, name)).strip()
+ self._struct_pending_verification[tp] = layout, cname
+
+ def _loaded_struct_or_union(self, tp):
+ if tp.fldnames is None:
+ return # nothing to do with opaque structs
+ self.ffi._get_cached_btype(tp) # force 'fixedlayout' to be considered
+
+ if tp in self._struct_pending_verification:
+ # check that the layout sizes and offsets match the real ones
+ def check(realvalue, expectedvalue, msg):
+ if realvalue != expectedvalue:
+ raise VerificationError(
+ "%s (we have %d, but C compiler says %d)"
+ % (msg, expectedvalue, realvalue))
+ ffi = self.ffi
+ BStruct = ffi._get_cached_btype(tp)
+ layout, cname = self._struct_pending_verification.pop(tp)
+ check(layout[0], ffi.sizeof(BStruct), "wrong total size")
+ check(layout[1], ffi.alignof(BStruct), "wrong total alignment")
+ i = 2
+ for fname, ftype, fbitsize, fqual in tp.enumfields():
+ if fbitsize >= 0:
+ continue # xxx ignore fbitsize for now
+ check(layout[i], ffi.offsetof(BStruct, fname),
+ "wrong offset for field %r" % (fname,))
+ if layout[i+1] != 0:
+ BField = ffi._get_cached_btype(ftype)
+ check(layout[i+1], ffi.sizeof(BField),
+ "wrong size for field %r" % (fname,))
+ i += 2
+ assert i == len(layout)
+
+ # ----------
+ # 'anonymous' declarations. These are produced for anonymous structs
+ # or unions; the 'name' is obtained by a typedef.
+
+ def _generate_gen_anonymous_decl(self, tp, name):
+ if isinstance(tp, model.EnumType):
+ self._generate_gen_enum_decl(tp, name, '')
+ else:
+ self._generate_struct_or_union_decl(tp, '', name)
+
+ def _loading_gen_anonymous(self, tp, name, module):
+ if isinstance(tp, model.EnumType):
+ self._loading_gen_enum(tp, name, module, '')
+ else:
+ self._loading_struct_or_union(tp, '', name, module)
+
+ def _loaded_gen_anonymous(self, tp, name, module, **kwds):
+ if isinstance(tp, model.EnumType):
+ self._loaded_gen_enum(tp, name, module, **kwds)
+ else:
+ self._loaded_struct_or_union(tp)
+
+ # ----------
+ # constants, likely declared with '#define'
+
+ def _generate_gen_const(self, is_int, name, tp=None, category='const',
+ check_value=None):
+ prnt = self._prnt
+ funcname = '_cffi_%s_%s' % (category, name)
+ self.export_symbols.append(funcname)
+ if check_value is not None:
+ assert is_int
+ assert category == 'const'
+ prnt('int %s(char *out_error)' % funcname)
+ prnt('{')
+ self._check_int_constant_value(name, check_value)
+ prnt(' return 0;')
+ prnt('}')
+ elif is_int:
+ assert category == 'const'
+ prnt('int %s(long long *out_value)' % funcname)
+ prnt('{')
+ prnt(' *out_value = (long long)(%s);' % (name,))
+ prnt(' return (%s) <= 0;' % (name,))
+ prnt('}')
+ else:
+ assert tp is not None
+ assert check_value is None
+ if category == 'var':
+ ampersand = '&'
+ else:
+ ampersand = ''
+ extra = ''
+ if category == 'const' and isinstance(tp, model.StructOrUnion):
+ extra = 'const *'
+ ampersand = '&'
+ prnt(tp.get_c_name(' %s%s(void)' % (extra, funcname), name))
+ prnt('{')
+ prnt(' return (%s%s);' % (ampersand, name))
+ prnt('}')
+ prnt()
+
+ def _generate_gen_constant_decl(self, tp, name):
+ is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type()
+ self._generate_gen_const(is_int, name, tp)
+
+ _loading_gen_constant = _loaded_noop
+
+ def _load_constant(self, is_int, tp, name, module, check_value=None):
+ funcname = '_cffi_const_%s' % name
+ if check_value is not None:
+ assert is_int
+ self._load_known_int_constant(module, funcname)
+ value = check_value
+ elif is_int:
+ BType = self.ffi._typeof_locked("long long*")[0]
+ BFunc = self.ffi._typeof_locked("int(*)(long long*)")[0]
+ function = module.load_function(BFunc, funcname)
+ p = self.ffi.new(BType)
+ negative = function(p)
+ value = int(p[0])
+ if value < 0 and not negative:
+ BLongLong = self.ffi._typeof_locked("long long")[0]
+ value += (1 << (8*self.ffi.sizeof(BLongLong)))
+ else:
+ assert check_value is None
+ fntypeextra = '(*)(void)'
+ if isinstance(tp, model.StructOrUnion):
+ fntypeextra = '*' + fntypeextra
+ BFunc = self.ffi._typeof_locked(tp.get_c_name(fntypeextra, name))[0]
+ function = module.load_function(BFunc, funcname)
+ value = function()
+ if isinstance(tp, model.StructOrUnion):
+ value = value[0]
+ return value
+
+ def _loaded_gen_constant(self, tp, name, module, library):
+ is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type()
+ value = self._load_constant(is_int, tp, name, module)
+ setattr(library, name, value)
+ type(library)._cffi_dir.append(name)
+
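
Editor's note: integer constants are read back through an `int _cffi_const_NAME(long long *out_value)` helper whose return value records whether the constant compares as non-positive; `_load_constant` uses that flag to undo the sign a large unsigned constant picks up when squeezed into a signed `long long`. A worked example of that correction, assuming an 8-byte `long long`:

    value = -1                       # what the long long* out-parameter holds
    negative = 0                     # helper returned "(X) <= 0", i.e. false
    if value < 0 and not negative:
        value += 1 << (8 * 8)        # 8 * sizeof(long long) bits
    assert value == 0xFFFFFFFFFFFFFFFF
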
+ # ----------
+ # enums
+
+ def _check_int_constant_value(self, name, value):
+ prnt = self._prnt
+ if value <= 0:
+ prnt(' if ((%s) > 0 || (long)(%s) != %dL) {' % (
+ name, name, value))
+ else:
+ prnt(' if ((%s) <= 0 || (unsigned long)(%s) != %dUL) {' % (
+ name, name, value))
+ prnt(' char buf[64];')
+ prnt(' if ((%s) <= 0)' % name)
+ prnt(' sprintf(buf, "%%ld", (long)(%s));' % name)
+ prnt(' else')
+ prnt(' sprintf(buf, "%%lu", (unsigned long)(%s));' %
+ name)
+ prnt(' sprintf(out_error, "%s has the real value %s, not %s",')
+ prnt(' "%s", buf, "%d");' % (name[:100], value))
+ prnt(' return -1;')
+ prnt(' }')
+
+ def _load_known_int_constant(self, module, funcname):
+ BType = self.ffi._typeof_locked("char[]")[0]
+ BFunc = self.ffi._typeof_locked("int(*)(char*)")[0]
+ function = module.load_function(BFunc, funcname)
+ p = self.ffi.new(BType, 256)
+ if function(p) < 0:
+ error = self.ffi.string(p)
+ if sys.version_info >= (3,):
+ error = str(error, 'utf-8')
+ raise VerificationError(error)
+
+ def _enum_funcname(self, prefix, name):
+ # "$enum_$1" => "___D_enum____D_1"
+ name = name.replace('$', '___D_')
+ return '_cffi_e_%s_%s' % (prefix, name)
+
+ def _generate_gen_enum_decl(self, tp, name, prefix='enum'):
+ if tp.partial:
+ for enumerator in tp.enumerators:
+ self._generate_gen_const(True, enumerator)
+ return
+ #
+ funcname = self._enum_funcname(prefix, name)
+ self.export_symbols.append(funcname)
+ prnt = self._prnt
+ prnt('int %s(char *out_error)' % funcname)
+ prnt('{')
+ for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
+ self._check_int_constant_value(enumerator, enumvalue)
+ prnt(' return 0;')
+ prnt('}')
+ prnt()
+
+ def _loading_gen_enum(self, tp, name, module, prefix='enum'):
+ if tp.partial:
+ enumvalues = [self._load_constant(True, tp, enumerator, module)
+ for enumerator in tp.enumerators]
+ tp.enumvalues = tuple(enumvalues)
+ tp.partial_resolved = True
+ else:
+ funcname = self._enum_funcname(prefix, name)
+ self._load_known_int_constant(module, funcname)
+
+ def _loaded_gen_enum(self, tp, name, module, library):
+ for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
+ setattr(library, enumerator, enumvalue)
+ type(library)._cffi_dir.append(enumerator)
+
+ # ----------
+ # macros: for now only for integers
+
+ def _generate_gen_macro_decl(self, tp, name):
+ if tp == '...':
+ check_value = None
+ else:
+ check_value = tp # an integer
+ self._generate_gen_const(True, name, check_value=check_value)
+
+ _loading_gen_macro = _loaded_noop
+
+ def _loaded_gen_macro(self, tp, name, module, library):
+ if tp == '...':
+ check_value = None
+ else:
+ check_value = tp # an integer
+ value = self._load_constant(True, tp, name, module,
+ check_value=check_value)
+ setattr(library, name, value)
+ type(library)._cffi_dir.append(name)
+
+ # ----------
+ # global variables
+
+ def _generate_gen_variable_decl(self, tp, name):
+ if isinstance(tp, model.ArrayType):
+ if tp.length_is_unknown():
+ prnt = self._prnt
+ funcname = '_cffi_sizeof_%s' % (name,)
+ self.export_symbols.append(funcname)
+ prnt("size_t %s(void)" % funcname)
+ prnt("{")
+ prnt(" return sizeof(%s);" % (name,))
+ prnt("}")
+ tp_ptr = model.PointerType(tp.item)
+ self._generate_gen_const(False, name, tp_ptr)
+ else:
+ tp_ptr = model.PointerType(tp)
+ self._generate_gen_const(False, name, tp_ptr, category='var')
+
+ _loading_gen_variable = _loaded_noop
+
+ def _loaded_gen_variable(self, tp, name, module, library):
+ if isinstance(tp, model.ArrayType): # int a[5] is "constant" in the
+ # sense that "a=..." is forbidden
+ if tp.length_is_unknown():
+ funcname = '_cffi_sizeof_%s' % (name,)
+ BFunc = self.ffi._typeof_locked('size_t(*)(void)')[0]
+ function = module.load_function(BFunc, funcname)
+ size = function()
+ BItemType = self.ffi._get_cached_btype(tp.item)
+ length, rest = divmod(size, self.ffi.sizeof(BItemType))
+ if rest != 0:
+ raise VerificationError(
+ "bad size: %r does not seem to be an array of %s" %
+ (name, tp.item))
+ tp = tp.resolve_length(length)
+ tp_ptr = model.PointerType(tp.item)
+ value = self._load_constant(False, tp_ptr, name, module)
+ # 'value' is a <cdata 'type *'> which we have to replace with
+ # a <cdata 'type[N]'> if the N is actually known
+ if tp.length is not None:
+ BArray = self.ffi._get_cached_btype(tp)
+ value = self.ffi.cast(BArray, value)
+ setattr(library, name, value)
+ type(library)._cffi_dir.append(name)
+ return
+ # remove ptr=<cdata 'int *'> from the library instance, and replace
+ # it by a property on the class, which reads/writes into ptr[0].
+ funcname = '_cffi_var_%s' % name
+ BFunc = self.ffi._typeof_locked(tp.get_c_name('*(*)(void)', name))[0]
+ function = module.load_function(BFunc, funcname)
+ ptr = function()
+ def getter(library):
+ return ptr[0]
+ def setter(library, value):
+ ptr[0] = value
+ setattr(type(library), name, property(getter, setter))
+ type(library)._cffi_dir.append(name)
+
+cffimod_header = r'''
+#include <stdio.h>
+#include <stddef.h>
+#include <stdarg.h>
+#include <errno.h>
+#include <sys/types.h> /* XXX for ssize_t on some platforms */
+
+/* this block of #ifs should be kept exactly identical between
+ c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py
+ and cffi/_cffi_include.h */
+#if defined(_MSC_VER)
+# include <malloc.h> /* for alloca() */
+# if _MSC_VER < 1600 /* MSVC < 2010 */
+ typedef __int8 int8_t;
+ typedef __int16 int16_t;
+ typedef __int32 int32_t;
+ typedef __int64 int64_t;
+ typedef unsigned __int8 uint8_t;
+ typedef unsigned __int16 uint16_t;
+ typedef unsigned __int32 uint32_t;
+ typedef unsigned __int64 uint64_t;
+ typedef __int8 int_least8_t;
+ typedef __int16 int_least16_t;
+ typedef __int32 int_least32_t;
+ typedef __int64 int_least64_t;
+ typedef unsigned __int8 uint_least8_t;
+ typedef unsigned __int16 uint_least16_t;
+ typedef unsigned __int32 uint_least32_t;
+ typedef unsigned __int64 uint_least64_t;
+ typedef __int8 int_fast8_t;
+ typedef __int16 int_fast16_t;
+ typedef __int32 int_fast32_t;
+ typedef __int64 int_fast64_t;
+ typedef unsigned __int8 uint_fast8_t;
+ typedef unsigned __int16 uint_fast16_t;
+ typedef unsigned __int32 uint_fast32_t;
+ typedef unsigned __int64 uint_fast64_t;
+ typedef __int64 intmax_t;
+ typedef unsigned __int64 uintmax_t;
+# else
+# include <stdint.h>
+# endif
+# if _MSC_VER < 1800 /* MSVC < 2013 */
+# ifndef __cplusplus
+ typedef unsigned char _Bool;
+# endif
+# endif
+#else
+# include <stdint.h>
+# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux)
+# include <alloca.h>
+# endif
+#endif
+'''
diff --git a/lib/cffi/verifier.py b/lib/cffi/verifier.py
new file mode 100644
index 0000000..a500c78
--- /dev/null
+++ b/lib/cffi/verifier.py
@@ -0,0 +1,307 @@
+#
+# DEPRECATED: implementation for ffi.verify()
+#
+import sys, os, binascii, shutil, io
+from . import __version_verifier_modules__
+from . import ffiplatform
+from .error import VerificationError
+
+if sys.version_info >= (3, 3):
+ import importlib.machinery
+ def _extension_suffixes():
+ return importlib.machinery.EXTENSION_SUFFIXES[:]
+else:
+ import imp
+ def _extension_suffixes():
+ return [suffix for suffix, _, type in imp.get_suffixes()
+ if type == imp.C_EXTENSION]
+
+
+if sys.version_info >= (3,):
+ NativeIO = io.StringIO
+else:
+ class NativeIO(io.BytesIO):
+ def write(self, s):
+ if isinstance(s, unicode):
+ s = s.encode('ascii')
+ super(NativeIO, self).write(s)
+
+
+class Verifier(object):
+
+ def __init__(self, ffi, preamble, tmpdir=None, modulename=None,
+ ext_package=None, tag='', force_generic_engine=False,
+ source_extension='.c', flags=None, relative_to=None, **kwds):
+ if ffi._parser._uses_new_feature:
+ raise VerificationError(
+ "feature not supported with ffi.verify(), but only "
+ "with ffi.set_source(): %s" % (ffi._parser._uses_new_feature,))
+ self.ffi = ffi
+ self.preamble = preamble
+ if not modulename:
+ flattened_kwds = ffiplatform.flatten(kwds)
+ vengine_class = _locate_engine_class(ffi, force_generic_engine)
+ self._vengine = vengine_class(self)
+ self._vengine.patch_extension_kwds(kwds)
+ self.flags = flags
+ self.kwds = self.make_relative_to(kwds, relative_to)
+ #
+ if modulename:
+ if tag:
+ raise TypeError("can't specify both 'modulename' and 'tag'")
+ else:
+ key = '\x00'.join(['%d.%d' % sys.version_info[:2],
+ __version_verifier_modules__,
+ preamble, flattened_kwds] +
+ ffi._cdefsources)
+ if sys.version_info >= (3,):
+ key = key.encode('utf-8')
+ k1 = hex(binascii.crc32(key[0::2]) & 0xffffffff)
+ k1 = k1.lstrip('0x').rstrip('L')
+ k2 = hex(binascii.crc32(key[1::2]) & 0xffffffff)
+ k2 = k2.lstrip('0').rstrip('L')
+ modulename = '_cffi_%s_%s%s%s' % (tag, self._vengine._class_key,
+ k1, k2)
+ suffix = _get_so_suffixes()[0]
+ self.tmpdir = tmpdir or _caller_dir_pycache()
+ self.sourcefilename = os.path.join(self.tmpdir, modulename + source_extension)
+ self.modulefilename = os.path.join(self.tmpdir, modulename + suffix)
+ self.ext_package = ext_package
+ self._has_source = False
+ self._has_module = False
+
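
Editor's note: when no explicit `modulename` is passed, the name is derived from a key built out of the Python version, `__version_verifier_modules__`, the preamble, the flattened keyword arguments and the cdef sources; two CRC32 checksums over the even and odd bytes of that key become the hex suffixes. A sketch of the same derivation with a made-up key (the `'g'` class key is the generic engine's):

    import binascii

    key = b'3.9\x00...preamble and cdef sources...'   # hypothetical key
    k1 = hex(binascii.crc32(key[0::2]) & 0xffffffff).lstrip('0x').rstrip('L')
    k2 = hex(binascii.crc32(key[1::2]) & 0xffffffff).lstrip('0').rstrip('L')
    modulename = '_cffi_%s_%s%s%s' % ('', 'g', k1, k2)
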
+ def write_source(self, file=None):
+ """Write the C source code. It is produced in 'self.sourcefilename',
+ which can be tweaked beforehand."""
+ with self.ffi._lock:
+ if self._has_source and file is None:
+ raise VerificationError(
+ "source code already written")
+ self._write_source(file)
+
+ def compile_module(self):
+ """Write the C source code (if not done already) and compile it.
+ This produces a dynamic link library in 'self.modulefilename'."""
+ with self.ffi._lock:
+ if self._has_module:
+ raise VerificationError("module already compiled")
+ if not self._has_source:
+ self._write_source()
+ self._compile_module()
+
+ def load_library(self):
+ """Get a C module from this Verifier instance.
+ Returns an instance of a FFILibrary class that behaves like the
+ objects returned by ffi.dlopen(), but that delegates all
+ operations to the C module. If necessary, the C code is written
+ and compiled first.
+ """
+ with self.ffi._lock:
+ if not self._has_module:
+ self._locate_module()
+ if not self._has_module:
+ if not self._has_source:
+ self._write_source()
+ self._compile_module()
+ return self._load_library()
+
+ def get_module_name(self):
+ basename = os.path.basename(self.modulefilename)
+ # kill both the .so extension and the other .'s, as introduced
+ # by Python 3: 'basename.cpython-33m.so'
+ basename = basename.split('.', 1)[0]
+ # and the _d added in Python 2 debug builds --- but try to be
+ # conservative and not kill a legitimate _d
+ if basename.endswith('_d') and hasattr(sys, 'gettotalrefcount'):
+ basename = basename[:-2]
+ return basename
+
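
Editor's note: a quick example of the stripping done by get_module_name() on a Python 3 style extension filename (the module name itself is hypothetical):

    import os
    fn = '/some/dir/__pycache__/_cffi_g1a2b3c4d.cpython-33m.so'
    basename = os.path.basename(fn).split('.', 1)[0]
    assert basename == '_cffi_g1a2b3c4d'
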
+ def get_extension(self):
+ ffiplatform._hack_at_distutils() # backward compatibility hack
+ if not self._has_source:
+ with self.ffi._lock:
+ if not self._has_source:
+ self._write_source()
+ sourcename = ffiplatform.maybe_relative_path(self.sourcefilename)
+ modname = self.get_module_name()
+ return ffiplatform.get_extension(sourcename, modname, **self.kwds)
+
+ def generates_python_module(self):
+ return self._vengine._gen_python_module
+
+ def make_relative_to(self, kwds, relative_to):
+ if relative_to and os.path.dirname(relative_to):
+ dirname = os.path.dirname(relative_to)
+ kwds = kwds.copy()
+ for key in ffiplatform.LIST_OF_FILE_NAMES:
+ if key in kwds:
+ lst = kwds[key]
+ if not isinstance(lst, (list, tuple)):
+ raise TypeError("keyword '%s' should be a list or tuple"
+ % (key,))
+ lst = [os.path.join(dirname, fn) for fn in lst]
+ kwds[key] = lst
+ return kwds
+
+ # ----------
+
+ def _locate_module(self):
+ if not os.path.isfile(self.modulefilename):
+ if self.ext_package:
+ try:
+ pkg = __import__(self.ext_package, None, None, ['__doc__'])
+ except ImportError:
+ return # cannot import the package itself, give up
+ # (e.g. it might be called differently before installation)
+ path = pkg.__path__
+ else:
+ path = None
+ filename = self._vengine.find_module(self.get_module_name(), path,
+ _get_so_suffixes())
+ if filename is None:
+ return
+ self.modulefilename = filename
+ self._vengine.collect_types()
+ self._has_module = True
+
+ def _write_source_to(self, file):
+ self._vengine._f = file
+ try:
+ self._vengine.write_source_to_f()
+ finally:
+ del self._vengine._f
+
+ def _write_source(self, file=None):
+ if file is not None:
+ self._write_source_to(file)
+ else:
+ # Write our source file to an in memory file.
+ f = NativeIO()
+ self._write_source_to(f)
+ source_data = f.getvalue()
+
+ # Determine if this matches the current file
+ if os.path.exists(self.sourcefilename):
+ with open(self.sourcefilename, "r") as fp:
+ needs_written = not (fp.read() == source_data)
+ else:
+ needs_written = True
+
+ # Actually write the file out if it doesn't match
+ if needs_written:
+ _ensure_dir(self.sourcefilename)
+ with open(self.sourcefilename, "w") as fp:
+ fp.write(source_data)
+
+ # Set this flag
+ self._has_source = True
+
+ def _compile_module(self):
+ # compile this C source
+ tmpdir = os.path.dirname(self.sourcefilename)
+ outputfilename = ffiplatform.compile(tmpdir, self.get_extension())
+ try:
+ same = ffiplatform.samefile(outputfilename, self.modulefilename)
+ except OSError:
+ same = False
+ if not same:
+ _ensure_dir(self.modulefilename)
+ shutil.move(outputfilename, self.modulefilename)
+ self._has_module = True
+
+ def _load_library(self):
+ assert self._has_module
+ if self.flags is not None:
+ return self._vengine.load_library(self.flags)
+ else:
+ return self._vengine.load_library()
+
+# ____________________________________________________________
+
+_FORCE_GENERIC_ENGINE = False # for tests
+
+def _locate_engine_class(ffi, force_generic_engine):
+ if _FORCE_GENERIC_ENGINE:
+ force_generic_engine = True
+ if not force_generic_engine:
+ if '__pypy__' in sys.builtin_module_names:
+ force_generic_engine = True
+ else:
+ try:
+ import _cffi_backend
+ except ImportError:
+ _cffi_backend = '?'
+ if ffi._backend is not _cffi_backend:
+ force_generic_engine = True
+ if force_generic_engine:
+ from . import vengine_gen
+ return vengine_gen.VGenericEngine
+ else:
+ from . import vengine_cpy
+ return vengine_cpy.VCPythonEngine
+
+# ____________________________________________________________
+
+_TMPDIR = None
+
+def _caller_dir_pycache():
+ if _TMPDIR:
+ return _TMPDIR
+ result = os.environ.get('CFFI_TMPDIR')
+ if result:
+ return result
+ filename = sys._getframe(2).f_code.co_filename
+ return os.path.abspath(os.path.join(os.path.dirname(filename),
+ '__pycache__'))
+
+def set_tmpdir(dirname):
+ """Set the temporary directory to use instead of __pycache__."""
+ global _TMPDIR
+ _TMPDIR = dirname
+
+def cleanup_tmpdir(tmpdir=None, keep_so=False):
+ """Clean up the temporary directory by removing all files in it
+ called `_cffi_*.{c,so}` as well as the `build` subdirectory."""
+ tmpdir = tmpdir or _caller_dir_pycache()
+ try:
+ filelist = os.listdir(tmpdir)
+ except OSError:
+ return
+ if keep_so:
+ suffix = '.c' # only remove .c files
+ else:
+ suffix = _get_so_suffixes()[0].lower()
+ for fn in filelist:
+ if fn.lower().startswith('_cffi_') and (
+ fn.lower().endswith(suffix) or fn.lower().endswith('.c')):
+ try:
+ os.unlink(os.path.join(tmpdir, fn))
+ except OSError:
+ pass
+ clean_dir = [os.path.join(tmpdir, 'build')]
+ for dir in clean_dir:
+ try:
+ for fn in os.listdir(dir):
+ fn = os.path.join(dir, fn)
+ if os.path.isdir(fn):
+ clean_dir.append(fn)
+ else:
+ os.unlink(fn)
+ except OSError:
+ pass
+
+def _get_so_suffixes():
+ suffixes = _extension_suffixes()
+ if not suffixes:
+ # bah, no C_EXTENSION available. Occurs on pypy without cpyext
+ if sys.platform == 'win32':
+ suffixes = [".pyd"]
+ else:
+ suffixes = [".so"]
+
+ return suffixes
+
+def _ensure_dir(filename):
+ dirname = os.path.dirname(filename)
+ if dirname and not os.path.isdir(dirname):
+ os.makedirs(dirname)
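
Editor's note: for context, a minimal use of the deprecated ffi.verify() path that these three vendored files implement — cdef a declaration, hand verify() a matching C preamble, and get back a compiled library object. This is a sketch of the public API, not part of the diff, and it needs a working C compiler at runtime:

    import cffi

    ffi = cffi.FFI()
    ffi.cdef("int add(int x, int y);")
    lib = ffi.verify("int add(int x, int y) { return x + y; }")
    assert lib.add(2, 3) == 5
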
diff --git a/lib/chardet-5.0.0.dist-info/INSTALLER b/lib/chardet-5.0.0.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/chardet-5.0.0.dist-info/LICENSE b/lib/chardet-5.0.0.dist-info/LICENSE
new file mode 100644
index 0000000..4362b49
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/LICENSE
@@ -0,0 +1,502 @@
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 2.1, February 1999
+
+ Copyright (C) 1991, 1999 Free Software Foundation, Inc.
+ 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+[This is the first released version of the Lesser GPL. It also counts
+ as the successor of the GNU Library Public License, version 2, hence
+ the version number 2.1.]
+
+ Preamble
+
+ The licenses for most software are designed to take away your
+freedom to share and change it. By contrast, the GNU General Public
+Licenses are intended to guarantee your freedom to share and change
+free software--to make sure the software is free for all its users.
+
+ This license, the Lesser General Public License, applies to some
+specially designated software packages--typically libraries--of the
+Free Software Foundation and other authors who decide to use it. You
+can use it too, but we suggest you first think carefully about whether
+this license or the ordinary General Public License is the better
+strategy to use in any particular case, based on the explanations below.
+
+ When we speak of free software, we are referring to freedom of use,
+not price. Our General Public Licenses are designed to make sure that
+you have the freedom to distribute copies of free software (and charge
+for this service if you wish); that you receive source code or can get
+it if you want it; that you can change the software and use pieces of
+it in new free programs; and that you are informed that you can do
+these things.
+
+ To protect your rights, we need to make restrictions that forbid
+distributors to deny you these rights or to ask you to surrender these
+rights. These restrictions translate to certain responsibilities for
+you if you distribute copies of the library or if you modify it.
+
+ For example, if you distribute copies of the library, whether gratis
+or for a fee, you must give the recipients all the rights that we gave
+you. You must make sure that they, too, receive or can get the source
+code. If you link other code with the library, you must provide
+complete object files to the recipients, so that they can relink them
+with the library after making changes to the library and recompiling
+it. And you must show them these terms so they know their rights.
+
+ We protect your rights with a two-step method: (1) we copyright the
+library, and (2) we offer you this license, which gives you legal
+permission to copy, distribute and/or modify the library.
+
+ To protect each distributor, we want to make it very clear that
+there is no warranty for the free library. Also, if the library is
+modified by someone else and passed on, the recipients should know
+that what they have is not the original version, so that the original
+author's reputation will not be affected by problems that might be
+introduced by others.
+
+ Finally, software patents pose a constant threat to the existence of
+any free program. We wish to make sure that a company cannot
+effectively restrict the users of a free program by obtaining a
+restrictive license from a patent holder. Therefore, we insist that
+any patent license obtained for a version of the library must be
+consistent with the full freedom of use specified in this license.
+
+ Most GNU software, including some libraries, is covered by the
+ordinary GNU General Public License. This license, the GNU Lesser
+General Public License, applies to certain designated libraries, and
+is quite different from the ordinary General Public License. We use
+this license for certain libraries in order to permit linking those
+libraries into non-free programs.
+
+ When a program is linked with a library, whether statically or using
+a shared library, the combination of the two is legally speaking a
+combined work, a derivative of the original library. The ordinary
+General Public License therefore permits such linking only if the
+entire combination fits its criteria of freedom. The Lesser General
+Public License permits more lax criteria for linking other code with
+the library.
+
+ We call this license the "Lesser" General Public License because it
+does Less to protect the user's freedom than the ordinary General
+Public License. It also provides other free software developers Less
+of an advantage over competing non-free programs. These disadvantages
+are the reason we use the ordinary General Public License for many
+libraries. However, the Lesser license provides advantages in certain
+special circumstances.
+
+ For example, on rare occasions, there may be a special need to
+encourage the widest possible use of a certain library, so that it becomes
+a de-facto standard. To achieve this, non-free programs must be
+allowed to use the library. A more frequent case is that a free
+library does the same job as widely used non-free libraries. In this
+case, there is little to gain by limiting the free library to free
+software only, so we use the Lesser General Public License.
+
+ In other cases, permission to use a particular library in non-free
+programs enables a greater number of people to use a large body of
+free software. For example, permission to use the GNU C Library in
+non-free programs enables many more people to use the whole GNU
+operating system, as well as its variant, the GNU/Linux operating
+system.
+
+ Although the Lesser General Public License is Less protective of the
+users' freedom, it does ensure that the user of a program that is
+linked with the Library has the freedom and the wherewithal to run
+that program using a modified version of the Library.
+
+ The precise terms and conditions for copying, distribution and
+modification follow. Pay close attention to the difference between a
+"work based on the library" and a "work that uses the library". The
+former contains code derived from the library, whereas the latter must
+be combined with the library in order to run.
+
+ GNU LESSER GENERAL PUBLIC LICENSE
+ TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
+
+ 0. This License Agreement applies to any software library or other
+program which contains a notice placed by the copyright holder or
+other authorized party saying it may be distributed under the terms of
+this Lesser General Public License (also called "this License").
+Each licensee is addressed as "you".
+
+ A "library" means a collection of software functions and/or data
+prepared so as to be conveniently linked with application programs
+(which use some of those functions and data) to form executables.
+
+ The "Library", below, refers to any such software library or work
+which has been distributed under these terms. A "work based on the
+Library" means either the Library or any derivative work under
+copyright law: that is to say, a work containing the Library or a
+portion of it, either verbatim or with modifications and/or translated
+straightforwardly into another language. (Hereinafter, translation is
+included without limitation in the term "modification".)
+
+ "Source code" for a work means the preferred form of the work for
+making modifications to it. For a library, complete source code means
+all the source code for all modules it contains, plus any associated
+interface definition files, plus the scripts used to control compilation
+and installation of the library.
+
+ Activities other than copying, distribution and modification are not
+covered by this License; they are outside its scope. The act of
+running a program using the Library is not restricted, and output from
+such a program is covered only if its contents constitute a work based
+on the Library (independent of the use of the Library in a tool for
+writing it). Whether that is true depends on what the Library does
+and what the program that uses the Library does.
+
+ 1. You may copy and distribute verbatim copies of the Library's
+complete source code as you receive it, in any medium, provided that
+you conspicuously and appropriately publish on each copy an
+appropriate copyright notice and disclaimer of warranty; keep intact
+all the notices that refer to this License and to the absence of any
+warranty; and distribute a copy of this License along with the
+Library.
+
+ You may charge a fee for the physical act of transferring a copy,
+and you may at your option offer warranty protection in exchange for a
+fee.
+
+ 2. You may modify your copy or copies of the Library or any portion
+of it, thus forming a work based on the Library, and copy and
+distribute such modifications or work under the terms of Section 1
+above, provided that you also meet all of these conditions:
+
+ a) The modified work must itself be a software library.
+
+ b) You must cause the files modified to carry prominent notices
+ stating that you changed the files and the date of any change.
+
+ c) You must cause the whole of the work to be licensed at no
+ charge to all third parties under the terms of this License.
+
+ d) If a facility in the modified Library refers to a function or a
+ table of data to be supplied by an application program that uses
+ the facility, other than as an argument passed when the facility
+ is invoked, then you must make a good faith effort to ensure that,
+ in the event an application does not supply such function or
+ table, the facility still operates, and performs whatever part of
+ its purpose remains meaningful.
+
+ (For example, a function in a library to compute square roots has
+ a purpose that is entirely well-defined independent of the
+ application. Therefore, Subsection 2d requires that any
+ application-supplied function or table used by this function must
+ be optional: if the application does not supply it, the square
+ root function must still compute square roots.)
+
+These requirements apply to the modified work as a whole. If
+identifiable sections of that work are not derived from the Library,
+and can be reasonably considered independent and separate works in
+themselves, then this License, and its terms, do not apply to those
+sections when you distribute them as separate works. But when you
+distribute the same sections as part of a whole which is a work based
+on the Library, the distribution of the whole must be on the terms of
+this License, whose permissions for other licensees extend to the
+entire whole, and thus to each and every part regardless of who wrote
+it.
+
+Thus, it is not the intent of this section to claim rights or contest
+your rights to work written entirely by you; rather, the intent is to
+exercise the right to control the distribution of derivative or
+collective works based on the Library.
+
+In addition, mere aggregation of another work not based on the Library
+with the Library (or with a work based on the Library) on a volume of
+a storage or distribution medium does not bring the other work under
+the scope of this License.
+
+ 3. You may opt to apply the terms of the ordinary GNU General Public
+License instead of this License to a given copy of the Library. To do
+this, you must alter all the notices that refer to this License, so
+that they refer to the ordinary GNU General Public License, version 2,
+instead of to this License. (If a newer version than version 2 of the
+ordinary GNU General Public License has appeared, then you can specify
+that version instead if you wish.) Do not make any other change in
+these notices.
+
+ Once this change is made in a given copy, it is irreversible for
+that copy, so the ordinary GNU General Public License applies to all
+subsequent copies and derivative works made from that copy.
+
+ This option is useful when you wish to copy part of the code of
+the Library into a program that is not a library.
+
+ 4. You may copy and distribute the Library (or a portion or
+derivative of it, under Section 2) in object code or executable form
+under the terms of Sections 1 and 2 above provided that you accompany
+it with the complete corresponding machine-readable source code, which
+must be distributed under the terms of Sections 1 and 2 above on a
+medium customarily used for software interchange.
+
+ If distribution of object code is made by offering access to copy
+from a designated place, then offering equivalent access to copy the
+source code from the same place satisfies the requirement to
+distribute the source code, even though third parties are not
+compelled to copy the source along with the object code.
+
+ 5. A program that contains no derivative of any portion of the
+Library, but is designed to work with the Library by being compiled or
+linked with it, is called a "work that uses the Library". Such a
+work, in isolation, is not a derivative work of the Library, and
+therefore falls outside the scope of this License.
+
+ However, linking a "work that uses the Library" with the Library
+creates an executable that is a derivative of the Library (because it
+contains portions of the Library), rather than a "work that uses the
+library". The executable is therefore covered by this License.
+Section 6 states terms for distribution of such executables.
+
+ When a "work that uses the Library" uses material from a header file
+that is part of the Library, the object code for the work may be a
+derivative work of the Library even though the source code is not.
+Whether this is true is especially significant if the work can be
+linked without the Library, or if the work is itself a library. The
+threshold for this to be true is not precisely defined by law.
+
+ If such an object file uses only numerical parameters, data
+structure layouts and accessors, and small macros and small inline
+functions (ten lines or less in length), then the use of the object
+file is unrestricted, regardless of whether it is legally a derivative
+work. (Executables containing this object code plus portions of the
+Library will still fall under Section 6.)
+
+ Otherwise, if the work is a derivative of the Library, you may
+distribute the object code for the work under the terms of Section 6.
+Any executables containing that work also fall under Section 6,
+whether or not they are linked directly with the Library itself.
+
+ 6. As an exception to the Sections above, you may also combine or
+link a "work that uses the Library" with the Library to produce a
+work containing portions of the Library, and distribute that work
+under terms of your choice, provided that the terms permit
+modification of the work for the customer's own use and reverse
+engineering for debugging such modifications.
+
+ You must give prominent notice with each copy of the work that the
+Library is used in it and that the Library and its use are covered by
+this License. You must supply a copy of this License. If the work
+during execution displays copyright notices, you must include the
+copyright notice for the Library among them, as well as a reference
+directing the user to the copy of this License. Also, you must do one
+of these things:
+
+ a) Accompany the work with the complete corresponding
+ machine-readable source code for the Library including whatever
+ changes were used in the work (which must be distributed under
+ Sections 1 and 2 above); and, if the work is an executable linked
+ with the Library, with the complete machine-readable "work that
+ uses the Library", as object code and/or source code, so that the
+ user can modify the Library and then relink to produce a modified
+ executable containing the modified Library. (It is understood
+ that the user who changes the contents of definitions files in the
+ Library will not necessarily be able to recompile the application
+ to use the modified definitions.)
+
+ b) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (1) uses at run time a
+ copy of the library already present on the user's computer system,
+ rather than copying library functions into the executable, and (2)
+ will operate properly with a modified version of the library, if
+ the user installs one, as long as the modified version is
+ interface-compatible with the version that the work was made with.
+
+ c) Accompany the work with a written offer, valid for at
+ least three years, to give the same user the materials
+ specified in Subsection 6a, above, for a charge no more
+ than the cost of performing this distribution.
+
+ d) If distribution of the work is made by offering access to copy
+ from a designated place, offer equivalent access to copy the above
+ specified materials from the same place.
+
+ e) Verify that the user has already received a copy of these
+ materials or that you have already sent this user a copy.
+
+ For an executable, the required form of the "work that uses the
+Library" must include any data and utility programs needed for
+reproducing the executable from it. However, as a special exception,
+the materials to be distributed need not include anything that is
+normally distributed (in either source or binary form) with the major
+components (compiler, kernel, and so on) of the operating system on
+which the executable runs, unless that component itself accompanies
+the executable.
+
+ It may happen that this requirement contradicts the license
+restrictions of other proprietary libraries that do not normally
+accompany the operating system. Such a contradiction means you cannot
+use both them and the Library together in an executable that you
+distribute.
+
+ 7. You may place library facilities that are a work based on the
+Library side-by-side in a single library together with other library
+facilities not covered by this License, and distribute such a combined
+library, provided that the separate distribution of the work based on
+the Library and of the other library facilities is otherwise
+permitted, and provided that you do these two things:
+
+ a) Accompany the combined library with a copy of the same work
+ based on the Library, uncombined with any other library
+ facilities. This must be distributed under the terms of the
+ Sections above.
+
+ b) Give prominent notice with the combined library of the fact
+ that part of it is a work based on the Library, and explaining
+ where to find the accompanying uncombined form of the same work.
+
+ 8. You may not copy, modify, sublicense, link with, or distribute
+the Library except as expressly provided under this License. Any
+attempt otherwise to copy, modify, sublicense, link with, or
+distribute the Library is void, and will automatically terminate your
+rights under this License. However, parties who have received copies,
+or rights, from you under this License will not have their licenses
+terminated so long as such parties remain in full compliance.
+
+ 9. You are not required to accept this License, since you have not
+signed it. However, nothing else grants you permission to modify or
+distribute the Library or its derivative works. These actions are
+prohibited by law if you do not accept this License. Therefore, by
+modifying or distributing the Library (or any work based on the
+Library), you indicate your acceptance of this License to do so, and
+all its terms and conditions for copying, distributing or modifying
+the Library or works based on it.
+
+ 10. Each time you redistribute the Library (or any work based on the
+Library), the recipient automatically receives a license from the
+original licensor to copy, distribute, link with or modify the Library
+subject to these terms and conditions. You may not impose any further
+restrictions on the recipients' exercise of the rights granted herein.
+You are not responsible for enforcing compliance by third parties with
+this License.
+
+ 11. If, as a consequence of a court judgment or allegation of patent
+infringement or for any other reason (not limited to patent issues),
+conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot
+distribute so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you
+may not distribute the Library at all. For example, if a patent
+license would not permit royalty-free redistribution of the Library by
+all those who receive copies directly or indirectly through you, then
+the only way you could satisfy both it and this License would be to
+refrain entirely from distribution of the Library.
+
+If any portion of this section is held invalid or unenforceable under any
+particular circumstance, the balance of the section is intended to apply,
+and the section as a whole is intended to apply in other circumstances.
+
+It is not the purpose of this section to induce you to infringe any
+patents or other property right claims or to contest validity of any
+such claims; this section has the sole purpose of protecting the
+integrity of the free software distribution system which is
+implemented by public license practices. Many people have made
+generous contributions to the wide range of software distributed
+through that system in reliance on consistent application of that
+system; it is up to the author/donor to decide if he or she is willing
+to distribute software through any other system and a licensee cannot
+impose that choice.
+
+This section is intended to make thoroughly clear what is believed to
+be a consequence of the rest of this License.
+
+ 12. If the distribution and/or use of the Library is restricted in
+certain countries either by patents or by copyrighted interfaces, the
+original copyright holder who places the Library under this License may add
+an explicit geographical distribution limitation excluding those countries,
+so that distribution is permitted only in or among countries not thus
+excluded. In such case, this License incorporates the limitation as if
+written in the body of this License.
+
+ 13. The Free Software Foundation may publish revised and/or new
+versions of the Lesser General Public License from time to time.
+Such new versions will be similar in spirit to the present version,
+but may differ in detail to address new problems or concerns.
+
+Each version is given a distinguishing version number. If the Library
+specifies a version number of this License which applies to it and
+"any later version", you have the option of following the terms and
+conditions either of that version or of any later version published by
+the Free Software Foundation. If the Library does not specify a
+license version number, you may choose any version ever published by
+the Free Software Foundation.
+
+ 14. If you wish to incorporate parts of the Library into other free
+programs whose distribution conditions are incompatible with these,
+write to the author to ask for permission. For software which is
+copyrighted by the Free Software Foundation, write to the Free
+Software Foundation; we sometimes make exceptions for this. Our
+decision will be guided by the two goals of preserving the free status
+of all derivatives of our free software and of promoting the sharing
+and reuse of software generally.
+
+ NO WARRANTY
+
+ 15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO
+WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW.
+EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR
+OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY
+KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE
+LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME
+THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN
+WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY
+AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU
+FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR
+CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE
+LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING
+RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A
+FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF
+SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGES.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Libraries
+
+ If you develop a new library, and you want it to be of the greatest
+possible use to the public, we recommend making it free software that
+everyone can redistribute and change. You can do so by permitting
+redistribution under these terms (or, alternatively, under the terms of the
+ordinary General Public License).
+
+ To apply these terms, attach the following notices to the library. It is
+safest to attach them to the start of each source file to most effectively
+convey the exclusion of warranty; and each file should have at least the
+"copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the library's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+Also add information on how to contact you by electronic and paper mail.
+
+You should also get your employer (if you work as a programmer) or your
+school, if any, to sign a "copyright disclaimer" for the library, if
+necessary. Here is a sample; alter the names:
+
+ Yoyodyne, Inc., hereby disclaims all copyright interest in the
+ library `Frob' (a library for tweaking knobs) written by James Random Hacker.
+
+ <signature of Ty Coon>, 1 April 1990
+ Ty Coon, President of Vice
+
+That's all there is to it!
diff --git a/lib/chardet-5.0.0.dist-info/METADATA b/lib/chardet-5.0.0.dist-info/METADATA
new file mode 100644
index 0000000..200e118
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/METADATA
@@ -0,0 +1,100 @@
+Metadata-Version: 2.1
+Name: chardet
+Version: 5.0.0
+Summary: Universal encoding detector for Python 3
+Home-page: https://github.com/chardet/chardet
+Author: Mark Pilgrim
+Author-email: mark@diveintomark.org
+Maintainer: Daniel Blanchard
+Maintainer-email: dan.blanchard@gmail.com
+License: LGPL
+Project-URL: Documentation, https://chardet.readthedocs.io/
+Project-URL: GitHub Project, https://github.com/chardet/chardet
+Project-URL: Issue Tracker, https://github.com/chardet/chardet/issues
+Keywords: encoding,i18n,xml
+Platform: UNKNOWN
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Classifier: Topic :: Text Processing :: Linguistic
+Requires-Python: >=3.6
+License-File: LICENSE
+
+Chardet: The Universal Character Encoding Detector
+--------------------------------------------------
+
+.. image:: https://img.shields.io/travis/chardet/chardet/stable.svg
+ :alt: Build status
+ :target: https://travis-ci.org/chardet/chardet
+
+.. image:: https://img.shields.io/coveralls/chardet/chardet/stable.svg
+ :target: https://coveralls.io/r/chardet/chardet
+
+.. image:: https://img.shields.io/pypi/v/chardet.svg
+ :target: https://warehouse.python.org/project/chardet/
+ :alt: Latest version on PyPI
+
+.. image:: https://img.shields.io/pypi/l/chardet.svg
+ :alt: License
+
+
+Detects
+ - ASCII, UTF-8, UTF-16 (2 variants), UTF-32 (4 variants)
+ - Big5, GB2312, EUC-TW, HZ-GB-2312, ISO-2022-CN (Traditional and Simplified Chinese)
+ - EUC-JP, SHIFT_JIS, CP932, ISO-2022-JP (Japanese)
+ - EUC-KR, ISO-2022-KR, Johab (Korean)
+ - KOI8-R, MacCyrillic, IBM855, IBM866, ISO-8859-5, windows-1251 (Cyrillic)
+ - ISO-8859-5, windows-1251 (Bulgarian)
+ - ISO-8859-1, windows-1252 (Western European languages)
+ - ISO-8859-7, windows-1253 (Greek)
+ - ISO-8859-8, windows-1255 (Visual and Logical Hebrew)
+ - TIS-620 (Thai)
+
+.. note::
+ Our ISO-8859-2 and windows-1250 (Hungarian) probers have been temporarily
+ disabled until we can retrain the models.
+
+Requires Python 3.6+.
+
+Installation
+------------
+
+Install from `PyPI <https://pypi.org/project/chardet/>`_::
+
+ pip install chardet
+
+Documentation
+-------------
+
+For users, docs are now available at https://chardet.readthedocs.io/.
+
+Command-line Tool
+-----------------
+
+chardet comes with a command-line script which reports on the encodings of one
+or more files::
+
+ % chardetect somefile someotherfile
+ somefile: windows-1252 with confidence 0.5
+ someotherfile: ascii with confidence 1.0
+
+About
+-----
+
+This is a continuation of Mark Pilgrim's excellent original chardet port from C, and `Ian Cordasco <https://github.com/sigmavirus24>`_'s
+`charade <https://github.com/sigmavirus24/charade>`_ Python 3-compatible fork.
+
+:maintainer: Dan Blanchard
+
+
diff --git a/lib/chardet-5.0.0.dist-info/RECORD b/lib/chardet-5.0.0.dist-info/RECORD
new file mode 100644
index 0000000..80cab97
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/RECORD
@@ -0,0 +1,99 @@
+../../bin/chardetect,sha256=_q2tAuY1zbFrthRWpCvDlyurIxNTLWpbSrbxr42VZDU,230
+chardet-5.0.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+chardet-5.0.0.dist-info/LICENSE,sha256=3GJlINzVOiL3J68-5Cx3DlbJemT-OtsGN5nYqwMv5VE,26530
+chardet-5.0.0.dist-info/METADATA,sha256=m6a6fLYw7G1z4XT9R2IJ3W1rMmxXtnhgNtH3DzClnSs,3423
+chardet-5.0.0.dist-info/RECORD,,
+chardet-5.0.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+chardet-5.0.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+chardet-5.0.0.dist-info/entry_points.txt,sha256=fAMmhu5eJ-zAJ-smfqQwRClQ3-nozOCmvJ6-E8lgGJo,60
+chardet-5.0.0.dist-info/top_level.txt,sha256=AowzBbZy4x8EirABDdJSLJZMkJ_53iIag8xfKR6D7kI,8
+chardet/__init__.py,sha256=9-r0i294avRciob2HKVcKf6GJmXPHpgMqIijVrqHBDU,3705
+chardet/__pycache__/__init__.cpython-39.pyc,,
+chardet/__pycache__/big5freq.cpython-39.pyc,,
+chardet/__pycache__/big5prober.cpython-39.pyc,,
+chardet/__pycache__/chardistribution.cpython-39.pyc,,
+chardet/__pycache__/charsetgroupprober.cpython-39.pyc,,
+chardet/__pycache__/charsetprober.cpython-39.pyc,,
+chardet/__pycache__/codingstatemachine.cpython-39.pyc,,
+chardet/__pycache__/cp949prober.cpython-39.pyc,,
+chardet/__pycache__/enums.cpython-39.pyc,,
+chardet/__pycache__/escprober.cpython-39.pyc,,
+chardet/__pycache__/escsm.cpython-39.pyc,,
+chardet/__pycache__/eucjpprober.cpython-39.pyc,,
+chardet/__pycache__/euckrfreq.cpython-39.pyc,,
+chardet/__pycache__/euckrprober.cpython-39.pyc,,
+chardet/__pycache__/euctwfreq.cpython-39.pyc,,
+chardet/__pycache__/euctwprober.cpython-39.pyc,,
+chardet/__pycache__/gb2312freq.cpython-39.pyc,,
+chardet/__pycache__/gb2312prober.cpython-39.pyc,,
+chardet/__pycache__/hebrewprober.cpython-39.pyc,,
+chardet/__pycache__/jisfreq.cpython-39.pyc,,
+chardet/__pycache__/johabfreq.cpython-39.pyc,,
+chardet/__pycache__/johabprober.cpython-39.pyc,,
+chardet/__pycache__/jpcntx.cpython-39.pyc,,
+chardet/__pycache__/langbulgarianmodel.cpython-39.pyc,,
+chardet/__pycache__/langgreekmodel.cpython-39.pyc,,
+chardet/__pycache__/langhebrewmodel.cpython-39.pyc,,
+chardet/__pycache__/langhungarianmodel.cpython-39.pyc,,
+chardet/__pycache__/langrussianmodel.cpython-39.pyc,,
+chardet/__pycache__/langthaimodel.cpython-39.pyc,,
+chardet/__pycache__/langturkishmodel.cpython-39.pyc,,
+chardet/__pycache__/latin1prober.cpython-39.pyc,,
+chardet/__pycache__/mbcharsetprober.cpython-39.pyc,,
+chardet/__pycache__/mbcsgroupprober.cpython-39.pyc,,
+chardet/__pycache__/mbcssm.cpython-39.pyc,,
+chardet/__pycache__/sbcharsetprober.cpython-39.pyc,,
+chardet/__pycache__/sbcsgroupprober.cpython-39.pyc,,
+chardet/__pycache__/sjisprober.cpython-39.pyc,,
+chardet/__pycache__/universaldetector.cpython-39.pyc,,
+chardet/__pycache__/utf1632prober.cpython-39.pyc,,
+chardet/__pycache__/utf8prober.cpython-39.pyc,,
+chardet/__pycache__/version.cpython-39.pyc,,
+chardet/big5freq.py,sha256=ltcfP-3PjlNHCoo5e4a7C4z-2DhBTXRfY6jbMbB7P30,31274
+chardet/big5prober.py,sha256=neUXIlq35507yibstiznZWFzyNcMn6EXrqJaUJVPWKg,1741
+chardet/chardistribution.py,sha256=M9NTKdM72KieFKy4TT5eml4PP0WaVcXuY5PpWSFD0FA,9608
+chardet/charsetgroupprober.py,sha256=CaIBAmNitEsYuSgMvgAsMREN4cLxMj5OYwMhVo6MAxk,3817
+chardet/charsetprober.py,sha256=Eo3w8sCmbvnVKOGNW1iy50KATVs8xV-gF7cQ0VG85dQ,4801
+chardet/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+chardet/cli/__pycache__/__init__.cpython-39.pyc,,
+chardet/cli/__pycache__/chardetect.cpython-39.pyc,,
+chardet/cli/chardetect.py,sha256=1qMxT3wrp5vP6ugSf1-Zz3BWwlbCWJ0jzeCuhgX85vw,2406
+chardet/codingstatemachine.py,sha256=BiGR9kgTYbS4gJI5qBmE52HMOBOR_roDvXf7aIehdEk,3559
+chardet/cp949prober.py,sha256=kCQEaOCzMntqv7pAyXEobWTRgIUxYfoiUr0btXO1nI8,1838
+chardet/enums.py,sha256=Rodw4p61Vg9U-oCo6eUuT7uDzKwIbCaA15HwbvCoCNk,1619
+chardet/escprober.py,sha256=girD61r3NsQLnMQXsWWBU4hHuRJzTH3V7-VfTUr-nQY,3864
+chardet/escsm.py,sha256=0Vs4iPPovberMoSxxnK5pI161Xf-mtKgOl14g5Xc7zg,12021
+chardet/eucjpprober.py,sha256=pGgs4lINwCEDV2bxqIZ6hXpaj2j4l2oLsMx6kuOK_zQ,3676
+chardet/euckrfreq.py,sha256=3mHuRvXfsq_QcQysDQFb8qSudvTiol71C6Ic2w57tKM,13566
+chardet/euckrprober.py,sha256=qBuSS2zXWaoUmGdzz3owAnD1GNhuKR_8bYzDC3yxe6I,1731
+chardet/euctwfreq.py,sha256=2alILE1Lh5eqiFJZjzRkMQXolNJRHY5oBQd-vmZYFFM,36913
+chardet/euctwprober.py,sha256=SLnCoJC94jZL8PJio60Q8PZACJA1rVPtUdWMa1W8Pwk,1731
+chardet/gb2312freq.py,sha256=49OrdXzD-HXqwavkqjo8Z7gvs58hONNzDhAyMENNkvY,20735
+chardet/gb2312prober.py,sha256=NS_i52jZE0TnWGkKqFduvu9fzW0nMcS2XbYJ8qSX8hY,1737
+chardet/hebrewprober.py,sha256=1l1hXF8-2IWDrPkf85UvAO1GVtMfY1r11kDgOqa-gU4,13919
+chardet/jisfreq.py,sha256=mm8tfrwqhpOd3wzZKS4NJqkYBQVcDfTM2JiQ5aW932E,25796
+chardet/johabfreq.py,sha256=dBpOYG34GRX6SL8k_LbS9rxZPMjLjoMlgZ03Pz5Hmqc,42498
+chardet/johabprober.py,sha256=C18osd4vMPfy9facw-Y1Lor_9UrW0PeV-zxM2fu441c,1730
+chardet/jpcntx.py,sha256=m1gDpPkRca4EDwym8XSL5YdoILFnFsDbNBYMQV7_-NE,26797
+chardet/langbulgarianmodel.py,sha256=bGoRpxBYtrbSHa6mX6PkEA26v30pWmhDjemhdxmkew8,104550
+chardet/langgreekmodel.py,sha256=3wMlEzQ8oU2MbrL2xN8lkuOB0dCMLBhW6heekxusoc0,98472
+chardet/langhebrewmodel.py,sha256=ZUTqusxMvR_earWPs5w-rH10xoe5sPjd9FLMu1DUIvE,98184
+chardet/langhungarianmodel.py,sha256=N-YtC2EiswyS7XsUicCPRycrIzRNj47Y048odp9qOoo,101351
+chardet/langrussianmodel.py,sha256=6v7RcZKGj0VH0864BHzizKNceAYbHvGts2p00ifC7w4,128023
+chardet/langthaimodel.py,sha256=Mr673U9U8rkQFfUDtLP01pp-0TOsl2o6sb75YEjvpcs,102762
+chardet/langturkishmodel.py,sha256=LkXCjWhGUEzqKXvfasHN0SFBigwKJ3xeWNVZ0EyI0kA,95360
+chardet/latin1prober.py,sha256=u_iGcQMUcZLXvj4B_WXx4caA0C5oaE2Qj1KTpz_RQ1I,5260
+chardet/mbcharsetprober.py,sha256=iKKuB6o_FF80NynRLBDT0UtwOnpLqmL_OspRPMib7CM,3367
+chardet/mbcsgroupprober.py,sha256=1D_kp9nv2_NQRddq9I2WDvB35OJh7Tfpo-OYTnL3B5o,2056
+chardet/mbcssm.py,sha256=EfORNu1WXgnFvpFarU8uJHS8KFif63xmgrHOB4DdDdY,30068
+chardet/metadata/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+chardet/metadata/__pycache__/__init__.cpython-39.pyc,,
+chardet/metadata/__pycache__/languages.cpython-39.pyc,,
+chardet/metadata/languages.py,sha256=HcaBygWtZq3gR8prIkJp_etvkhm2V4pUIToqjPZhgrc,13280
+chardet/sbcharsetprober.py,sha256=VvtWiNRLbHDZ5xgnofsmP1u8VQIkkaAuw3Ir9m1zDzQ,6199
+chardet/sbcsgroupprober.py,sha256=mekr4E3hgT4onmwi8oi1iEGW1CN-Z-BArG6kOtCunJw,4129
+chardet/sjisprober.py,sha256=sLfWS25PVFr5cDGhEf6h_s-RJsyeSteA-4ynsTl_UvA,3749
+chardet/universaldetector.py,sha256=BHeNWt1kn0yQgnR6xNtLAjiNmEQpSHYlKEvuZ9QyR1k,13288
+chardet/utf1632prober.py,sha256=N42YJEOkVDB67c38t5aJhXMG1QvnyWWDMNY5ERzniU0,8289
+chardet/utf8prober.py,sha256=mnLaSBV4gg-amt2WmxKFKWy4vVBedMNgjdbvgzBo0Dc,2709
+chardet/version.py,sha256=u_QYi-DXU1s7fyC_Rwa0I0-UcxMVmH7Co6c7QGKbe3g,242
diff --git a/lib/chardet-5.0.0.dist-info/REQUESTED b/lib/chardet-5.0.0.dist-info/REQUESTED
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/REQUESTED
diff --git a/lib/chardet-5.0.0.dist-info/WHEEL b/lib/chardet-5.0.0.dist-info/WHEEL
new file mode 100644
index 0000000..becc9a6
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/WHEEL
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.37.1)
+Root-Is-Purelib: true
+Tag: py3-none-any
+
diff --git a/lib/chardet-5.0.0.dist-info/entry_points.txt b/lib/chardet-5.0.0.dist-info/entry_points.txt
new file mode 100644
index 0000000..a884269
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/entry_points.txt
@@ -0,0 +1,3 @@
+[console_scripts]
+chardetect = chardet.cli.chardetect:main
+
diff --git a/lib/chardet-5.0.0.dist-info/top_level.txt b/lib/chardet-5.0.0.dist-info/top_level.txt
new file mode 100644
index 0000000..79236f2
--- /dev/null
+++ b/lib/chardet-5.0.0.dist-info/top_level.txt
@@ -0,0 +1 @@
+chardet
diff --git a/lib/chardet/__init__.py b/lib/chardet/__init__.py
new file mode 100644
index 0000000..e91ad61
--- /dev/null
+++ b/lib/chardet/__init__.py
@@ -0,0 +1,93 @@
+######################## BEGIN LICENSE BLOCK ########################
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .enums import InputState
+from .universaldetector import UniversalDetector
+from .version import VERSION, __version__
+
+__all__ = ["UniversalDetector", "detect", "detect_all", "__version__", "VERSION"]
+
+
+def detect(byte_str):
+ """
+ Detect the encoding of the given byte string.
+
+ :param byte_str: The byte sequence to examine.
+ :type byte_str: ``bytes`` or ``bytearray``
+ """
+ if not isinstance(byte_str, bytearray):
+ if not isinstance(byte_str, bytes):
+ raise TypeError(
+ f"Expected object of type bytes or bytearray, got: {type(byte_str)}"
+ )
+ byte_str = bytearray(byte_str)
+ detector = UniversalDetector()
+ detector.feed(byte_str)
+ return detector.close()
+
+
+def detect_all(byte_str, ignore_threshold=False):
+ """
+ Detect all the possible encodings of the given byte string.
+
+ :param byte_str: The byte sequence to examine.
+ :type byte_str: ``bytes`` or ``bytearray``
+ :param ignore_threshold: Include encodings that are below
+ ``UniversalDetector.MINIMUM_THRESHOLD``
+ in results.
+ :type ignore_threshold: ``bool``
+ """
+ if not isinstance(byte_str, bytearray):
+ if not isinstance(byte_str, bytes):
+ raise TypeError(
+ f"Expected object of type bytes or bytearray, got: {type(byte_str)}"
+ )
+ byte_str = bytearray(byte_str)
+
+ detector = UniversalDetector()
+ detector.feed(byte_str)
+ detector.close()
+
+ if detector.input_state == InputState.HIGH_BYTE:
+ results = []
+ probers = []
+ for prober in detector.charset_probers:
+ if hasattr(prober, "probers"):
+ probers.extend(p for p in prober.probers)
+ else:
+ probers.append(prober)
+ for prober in probers:
+ if ignore_threshold or prober.get_confidence() > detector.MINIMUM_THRESHOLD:
+ charset_name = prober.charset_name or ""
+ lower_charset_name = charset_name.lower()
+ # Use Windows encoding name instead of ISO-8859 if we saw any
+ # extra Windows-specific bytes
+ if lower_charset_name.startswith("iso-8859") and detector.has_win_bytes:
+ charset_name = detector.ISO_WIN_MAP.get(
+ lower_charset_name, charset_name
+ )
+ results.append(
+ {
+ "encoding": charset_name,
+ "confidence": prober.get_confidence(),
+ "language": prober.language,
+ }
+ )
+ if len(results) > 0:
+ return sorted(results, key=lambda result: -result["confidence"])
+
+ return [detector.result]
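
(Editorial note, not part of the patch: a minimal usage sketch of the detect()/detect_all() API added in lib/chardet/__init__.py above. It assumes the vendored package is importable as ``chardet``; the keys shown match the result dicts built in detect_all() and UniversalDetector.result.)

    # Sketch only: exercise the two public helpers defined above.
    import chardet

    raw = "Ceci est un test français.".encode("windows-1252")

    # detect() returns a single best guess as a dict with
    # 'encoding', 'confidence' and 'language' keys.
    print(chardet.detect(raw))

    # detect_all() returns every candidate above
    # UniversalDetector.MINIMUM_THRESHOLD, sorted by descending confidence
    # (pass ignore_threshold=True to include weaker candidates).
    for result in chardet.detect_all(raw):
        print(result["encoding"], result["confidence"], result["language"])
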
diff --git a/lib/chardet/big5freq.py b/lib/chardet/big5freq.py
new file mode 100644
index 0000000..87d9f97
--- /dev/null
+++ b/lib/chardet/big5freq.py
@@ -0,0 +1,386 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+# Big5 frequency table
+# by Taiwan's Mandarin Promotion Council
+# <http://www.edu.tw:81/mandr/>
+#
+# 128 --> 0.42261
+# 256 --> 0.57851
+# 512 --> 0.74851
+# 1024 --> 0.89384
+# 2048 --> 0.97583
+#
+# Ideal Distribution Ratio = 0.74851/(1-0.74851) =2.98
+# Random Distribution Ration = 512/(5401-512)=0.105
+#
+# Typical Distribution Ratio about 25% of Ideal one, still much higher than RDR
+
+BIG5_TYPICAL_DISTRIBUTION_RATIO = 0.75
+
+# Char to FreqOrder table
+BIG5_TABLE_SIZE = 5376
+# fmt: off
+BIG5_CHAR_TO_FREQ_ORDER = (
+ 1,1801,1506, 255,1431, 198, 9, 82, 6,5008, 177, 202,3681,1256,2821, 110, # 16
+3814, 33,3274, 261, 76, 44,2114, 16,2946,2187,1176, 659,3971, 26,3451,2653, # 32
+1198,3972,3350,4202, 410,2215, 302, 590, 361,1964, 8, 204, 58,4510,5009,1932, # 48
+ 63,5010,5011, 317,1614, 75, 222, 159,4203,2417,1480,5012,3555,3091, 224,2822, # 64
+3682, 3, 10,3973,1471, 29,2787,1135,2866,1940, 873, 130,3275,1123, 312,5013, # 80
+4511,2052, 507, 252, 682,5014, 142,1915, 124, 206,2947, 34,3556,3204, 64, 604, # 96
+5015,2501,1977,1978, 155,1991, 645, 641,1606,5016,3452, 337, 72, 406,5017, 80, # 112
+ 630, 238,3205,1509, 263, 939,1092,2654, 756,1440,1094,3453, 449, 69,2987, 591, # 128
+ 179,2096, 471, 115,2035,1844, 60, 50,2988, 134, 806,1869, 734,2036,3454, 180, # 144
+ 995,1607, 156, 537,2907, 688,5018, 319,1305, 779,2145, 514,2379, 298,4512, 359, # 160
+2502, 90,2716,1338, 663, 11, 906,1099,2553, 20,2441, 182, 532,1716,5019, 732, # 176
+1376,4204,1311,1420,3206, 25,2317,1056, 113, 399, 382,1950, 242,3455,2474, 529, # 192
+3276, 475,1447,3683,5020, 117, 21, 656, 810,1297,2300,2334,3557,5021, 126,4205, # 208
+ 706, 456, 150, 613,4513, 71,1118,2037,4206, 145,3092, 85, 835, 486,2115,1246, # 224
+1426, 428, 727,1285,1015, 800, 106, 623, 303,1281,5022,2128,2359, 347,3815, 221, # 240
+3558,3135,5023,1956,1153,4207, 83, 296,1199,3093, 192, 624, 93,5024, 822,1898, # 256
+2823,3136, 795,2065, 991,1554,1542,1592, 27, 43,2867, 859, 139,1456, 860,4514, # 272
+ 437, 712,3974, 164,2397,3137, 695, 211,3037,2097, 195,3975,1608,3559,3560,3684, # 288
+3976, 234, 811,2989,2098,3977,2233,1441,3561,1615,2380, 668,2077,1638, 305, 228, # 304
+1664,4515, 467, 415,5025, 262,2099,1593, 239, 108, 300, 200,1033, 512,1247,2078, # 320
+5026,5027,2176,3207,3685,2682, 593, 845,1062,3277, 88,1723,2038,3978,1951, 212, # 336
+ 266, 152, 149, 468,1899,4208,4516, 77, 187,5028,3038, 37, 5,2990,5029,3979, # 352
+5030,5031, 39,2524,4517,2908,3208,2079, 55, 148, 74,4518, 545, 483,1474,1029, # 368
+1665, 217,1870,1531,3138,1104,2655,4209, 24, 172,3562, 900,3980,3563,3564,4519, # 384
+ 32,1408,2824,1312, 329, 487,2360,2251,2717, 784,2683, 4,3039,3351,1427,1789, # 400
+ 188, 109, 499,5032,3686,1717,1790, 888,1217,3040,4520,5033,3565,5034,3352,1520, # 416
+3687,3981, 196,1034, 775,5035,5036, 929,1816, 249, 439, 38,5037,1063,5038, 794, # 432
+3982,1435,2301, 46, 178,3278,2066,5039,2381,5040, 214,1709,4521, 804, 35, 707, # 448
+ 324,3688,1601,2554, 140, 459,4210,5041,5042,1365, 839, 272, 978,2262,2580,3456, # 464
+2129,1363,3689,1423, 697, 100,3094, 48, 70,1231, 495,3139,2196,5043,1294,5044, # 480
+2080, 462, 586,1042,3279, 853, 256, 988, 185,2382,3457,1698, 434,1084,5045,3458, # 496
+ 314,2625,2788,4522,2335,2336, 569,2285, 637,1817,2525, 757,1162,1879,1616,3459, # 512
+ 287,1577,2116, 768,4523,1671,2868,3566,2526,1321,3816, 909,2418,5046,4211, 933, # 528
+3817,4212,2053,2361,1222,4524, 765,2419,1322, 786,4525,5047,1920,1462,1677,2909, # 544
+1699,5048,4526,1424,2442,3140,3690,2600,3353,1775,1941,3460,3983,4213, 309,1369, # 560
+1130,2825, 364,2234,1653,1299,3984,3567,3985,3986,2656, 525,1085,3041, 902,2001, # 576
+1475, 964,4527, 421,1845,1415,1057,2286, 940,1364,3141, 376,4528,4529,1381, 7, # 592
+2527, 983,2383, 336,1710,2684,1846, 321,3461, 559,1131,3042,2752,1809,1132,1313, # 608
+ 265,1481,1858,5049, 352,1203,2826,3280, 167,1089, 420,2827, 776, 792,1724,3568, # 624
+4214,2443,3281,5050,4215,5051, 446, 229, 333,2753, 901,3818,1200,1557,4530,2657, # 640
+1921, 395,2754,2685,3819,4216,1836, 125, 916,3209,2626,4531,5052,5053,3820,5054, # 656
+5055,5056,4532,3142,3691,1133,2555,1757,3462,1510,2318,1409,3569,5057,2146, 438, # 672
+2601,2910,2384,3354,1068, 958,3043, 461, 311,2869,2686,4217,1916,3210,4218,1979, # 688
+ 383, 750,2755,2627,4219, 274, 539, 385,1278,1442,5058,1154,1965, 384, 561, 210, # 704
+ 98,1295,2556,3570,5059,1711,2420,1482,3463,3987,2911,1257, 129,5060,3821, 642, # 720
+ 523,2789,2790,2658,5061, 141,2235,1333, 68, 176, 441, 876, 907,4220, 603,2602, # 736
+ 710, 171,3464, 404, 549, 18,3143,2398,1410,3692,1666,5062,3571,4533,2912,4534, # 752
+5063,2991, 368,5064, 146, 366, 99, 871,3693,1543, 748, 807,1586,1185, 22,2263, # 768
+ 379,3822,3211,5065,3212, 505,1942,2628,1992,1382,2319,5066, 380,2362, 218, 702, # 784
+1818,1248,3465,3044,3572,3355,3282,5067,2992,3694, 930,3283,3823,5068, 59,5069, # 800
+ 585, 601,4221, 497,3466,1112,1314,4535,1802,5070,1223,1472,2177,5071, 749,1837, # 816
+ 690,1900,3824,1773,3988,1476, 429,1043,1791,2236,2117, 917,4222, 447,1086,1629, # 832
+5072, 556,5073,5074,2021,1654, 844,1090, 105, 550, 966,1758,2828,1008,1783, 686, # 848
+1095,5075,2287, 793,1602,5076,3573,2603,4536,4223,2948,2302,4537,3825, 980,2503, # 864
+ 544, 353, 527,4538, 908,2687,2913,5077, 381,2629,1943,1348,5078,1341,1252, 560, # 880
+3095,5079,3467,2870,5080,2054, 973, 886,2081, 143,4539,5081,5082, 157,3989, 496, # 896
+4224, 57, 840, 540,2039,4540,4541,3468,2118,1445, 970,2264,1748,1966,2082,4225, # 912
+3144,1234,1776,3284,2829,3695, 773,1206,2130,1066,2040,1326,3990,1738,1725,4226, # 928
+ 279,3145, 51,1544,2604, 423,1578,2131,2067, 173,4542,1880,5083,5084,1583, 264, # 944
+ 610,3696,4543,2444, 280, 154,5085,5086,5087,1739, 338,1282,3096, 693,2871,1411, # 960
+1074,3826,2445,5088,4544,5089,5090,1240, 952,2399,5091,2914,1538,2688, 685,1483, # 976
+4227,2475,1436, 953,4228,2055,4545, 671,2400, 79,4229,2446,3285, 608, 567,2689, # 992
+3469,4230,4231,1691, 393,1261,1792,2401,5092,4546,5093,5094,5095,5096,1383,1672, # 1008
+3827,3213,1464, 522,1119, 661,1150, 216, 675,4547,3991,1432,3574, 609,4548,2690, # 1024
+2402,5097,5098,5099,4232,3045, 0,5100,2476, 315, 231,2447, 301,3356,4549,2385, # 1040
+5101, 233,4233,3697,1819,4550,4551,5102, 96,1777,1315,2083,5103, 257,5104,1810, # 1056
+3698,2718,1139,1820,4234,2022,1124,2164,2791,1778,2659,5105,3097, 363,1655,3214, # 1072
+5106,2993,5107,5108,5109,3992,1567,3993, 718, 103,3215, 849,1443, 341,3357,2949, # 1088
+1484,5110,1712, 127, 67, 339,4235,2403, 679,1412, 821,5111,5112, 834, 738, 351, # 1104
+2994,2147, 846, 235,1497,1881, 418,1993,3828,2719, 186,1100,2148,2756,3575,1545, # 1120
+1355,2950,2872,1377, 583,3994,4236,2581,2995,5113,1298,3699,1078,2557,3700,2363, # 1136
+ 78,3829,3830, 267,1289,2100,2002,1594,4237, 348, 369,1274,2197,2178,1838,4552, # 1152
+1821,2830,3701,2757,2288,2003,4553,2951,2758, 144,3358, 882,4554,3995,2759,3470, # 1168
+4555,2915,5114,4238,1726, 320,5115,3996,3046, 788,2996,5116,2831,1774,1327,2873, # 1184
+3997,2832,5117,1306,4556,2004,1700,3831,3576,2364,2660, 787,2023, 506, 824,3702, # 1200
+ 534, 323,4557,1044,3359,2024,1901, 946,3471,5118,1779,1500,1678,5119,1882,4558, # 1216
+ 165, 243,4559,3703,2528, 123, 683,4239, 764,4560, 36,3998,1793, 589,2916, 816, # 1232
+ 626,1667,3047,2237,1639,1555,1622,3832,3999,5120,4000,2874,1370,1228,1933, 891, # 1248
+2084,2917, 304,4240,5121, 292,2997,2720,3577, 691,2101,4241,1115,4561, 118, 662, # 1264
+5122, 611,1156, 854,2386,1316,2875, 2, 386, 515,2918,5123,5124,3286, 868,2238, # 1280
+1486, 855,2661, 785,2216,3048,5125,1040,3216,3578,5126,3146, 448,5127,1525,5128, # 1296
+2165,4562,5129,3833,5130,4242,2833,3579,3147, 503, 818,4001,3148,1568, 814, 676, # 1312
+1444, 306,1749,5131,3834,1416,1030, 197,1428, 805,2834,1501,4563,5132,5133,5134, # 1328
+1994,5135,4564,5136,5137,2198, 13,2792,3704,2998,3149,1229,1917,5138,3835,2132, # 1344
+5139,4243,4565,2404,3580,5140,2217,1511,1727,1120,5141,5142, 646,3836,2448, 307, # 1360
+5143,5144,1595,3217,5145,5146,5147,3705,1113,1356,4002,1465,2529,2530,5148, 519, # 1376
+5149, 128,2133, 92,2289,1980,5150,4003,1512, 342,3150,2199,5151,2793,2218,1981, # 1392
+3360,4244, 290,1656,1317, 789, 827,2365,5152,3837,4566, 562, 581,4004,5153, 401, # 1408
+4567,2252, 94,4568,5154,1399,2794,5155,1463,2025,4569,3218,1944,5156, 828,1105, # 1424
+4245,1262,1394,5157,4246, 605,4570,5158,1784,2876,5159,2835, 819,2102, 578,2200, # 1440
+2952,5160,1502, 436,3287,4247,3288,2836,4005,2919,3472,3473,5161,2721,2320,5162, # 1456
+5163,2337,2068, 23,4571, 193, 826,3838,2103, 699,1630,4248,3098, 390,1794,1064, # 1472
+3581,5164,1579,3099,3100,1400,5165,4249,1839,1640,2877,5166,4572,4573, 137,4250, # 1488
+ 598,3101,1967, 780, 104, 974,2953,5167, 278, 899, 253, 402, 572, 504, 493,1339, # 1504
+5168,4006,1275,4574,2582,2558,5169,3706,3049,3102,2253, 565,1334,2722, 863, 41, # 1520
+5170,5171,4575,5172,1657,2338, 19, 463,2760,4251, 606,5173,2999,3289,1087,2085, # 1536
+1323,2662,3000,5174,1631,1623,1750,4252,2691,5175,2878, 791,2723,2663,2339, 232, # 1552
+2421,5176,3001,1498,5177,2664,2630, 755,1366,3707,3290,3151,2026,1609, 119,1918, # 1568
+3474, 862,1026,4253,5178,4007,3839,4576,4008,4577,2265,1952,2477,5179,1125, 817, # 1584
+4254,4255,4009,1513,1766,2041,1487,4256,3050,3291,2837,3840,3152,5180,5181,1507, # 1600
+5182,2692, 733, 40,1632,1106,2879, 345,4257, 841,2531, 230,4578,3002,1847,3292, # 1616
+3475,5183,1263, 986,3476,5184, 735, 879, 254,1137, 857, 622,1300,1180,1388,1562, # 1632
+4010,4011,2954, 967,2761,2665,1349, 592,2134,1692,3361,3003,1995,4258,1679,4012, # 1648
+1902,2188,5185, 739,3708,2724,1296,1290,5186,4259,2201,2202,1922,1563,2605,2559, # 1664
+1871,2762,3004,5187, 435,5188, 343,1108, 596, 17,1751,4579,2239,3477,3709,5189, # 1680
+4580, 294,3582,2955,1693, 477, 979, 281,2042,3583, 643,2043,3710,2631,2795,2266, # 1696
+1031,2340,2135,2303,3584,4581, 367,1249,2560,5190,3585,5191,4582,1283,3362,2005, # 1712
+ 240,1762,3363,4583,4584, 836,1069,3153, 474,5192,2149,2532, 268,3586,5193,3219, # 1728
+1521,1284,5194,1658,1546,4260,5195,3587,3588,5196,4261,3364,2693,1685,4262, 961, # 1744
+1673,2632, 190,2006,2203,3841,4585,4586,5197, 570,2504,3711,1490,5198,4587,2633, # 1760
+3293,1957,4588, 584,1514, 396,1045,1945,5199,4589,1968,2449,5200,5201,4590,4013, # 1776
+ 619,5202,3154,3294, 215,2007,2796,2561,3220,4591,3221,4592, 763,4263,3842,4593, # 1792
+5203,5204,1958,1767,2956,3365,3712,1174, 452,1477,4594,3366,3155,5205,2838,1253, # 1808
+2387,2189,1091,2290,4264, 492,5206, 638,1169,1825,2136,1752,4014, 648, 926,1021, # 1824
+1324,4595, 520,4596, 997, 847,1007, 892,4597,3843,2267,1872,3713,2405,1785,4598, # 1840
+1953,2957,3103,3222,1728,4265,2044,3714,4599,2008,1701,3156,1551, 30,2268,4266, # 1856
+5207,2027,4600,3589,5208, 501,5209,4267, 594,3478,2166,1822,3590,3479,3591,3223, # 1872
+ 829,2839,4268,5210,1680,3157,1225,4269,5211,3295,4601,4270,3158,2341,5212,4602, # 1888
+4271,5213,4015,4016,5214,1848,2388,2606,3367,5215,4603, 374,4017, 652,4272,4273, # 1904
+ 375,1140, 798,5216,5217,5218,2366,4604,2269, 546,1659, 138,3051,2450,4605,5219, # 1920
+2254, 612,1849, 910, 796,3844,1740,1371, 825,3845,3846,5220,2920,2562,5221, 692, # 1936
+ 444,3052,2634, 801,4606,4274,5222,1491, 244,1053,3053,4275,4276, 340,5223,4018, # 1952
+1041,3005, 293,1168, 87,1357,5224,1539, 959,5225,2240, 721, 694,4277,3847, 219, # 1968
+1478, 644,1417,3368,2666,1413,1401,1335,1389,4019,5226,5227,3006,2367,3159,1826, # 1984
+ 730,1515, 184,2840, 66,4607,5228,1660,2958, 246,3369, 378,1457, 226,3480, 975, # 2000
+4020,2959,1264,3592, 674, 696,5229, 163,5230,1141,2422,2167, 713,3593,3370,4608, # 2016
+4021,5231,5232,1186, 15,5233,1079,1070,5234,1522,3224,3594, 276,1050,2725, 758, # 2032
+1126, 653,2960,3296,5235,2342, 889,3595,4022,3104,3007, 903,1250,4609,4023,3481, # 2048
+3596,1342,1681,1718, 766,3297, 286, 89,2961,3715,5236,1713,5237,2607,3371,3008, # 2064
+5238,2962,2219,3225,2880,5239,4610,2505,2533, 181, 387,1075,4024, 731,2190,3372, # 2080
+5240,3298, 310, 313,3482,2304, 770,4278, 54,3054, 189,4611,3105,3848,4025,5241, # 2096
+1230,1617,1850, 355,3597,4279,4612,3373, 111,4280,3716,1350,3160,3483,3055,4281, # 2112
+2150,3299,3598,5242,2797,4026,4027,3009, 722,2009,5243,1071, 247,1207,2343,2478, # 2128
+1378,4613,2010, 864,1437,1214,4614, 373,3849,1142,2220, 667,4615, 442,2763,2563, # 2144
+3850,4028,1969,4282,3300,1840, 837, 170,1107, 934,1336,1883,5244,5245,2119,4283, # 2160
+2841, 743,1569,5246,4616,4284, 582,2389,1418,3484,5247,1803,5248, 357,1395,1729, # 2176
+3717,3301,2423,1564,2241,5249,3106,3851,1633,4617,1114,2086,4285,1532,5250, 482, # 2192
+2451,4618,5251,5252,1492, 833,1466,5253,2726,3599,1641,2842,5254,1526,1272,3718, # 2208
+4286,1686,1795, 416,2564,1903,1954,1804,5255,3852,2798,3853,1159,2321,5256,2881, # 2224
+4619,1610,1584,3056,2424,2764, 443,3302,1163,3161,5257,5258,4029,5259,4287,2506, # 2240
+3057,4620,4030,3162,2104,1647,3600,2011,1873,4288,5260,4289, 431,3485,5261, 250, # 2256
+ 97, 81,4290,5262,1648,1851,1558, 160, 848,5263, 866, 740,1694,5264,2204,2843, # 2272
+3226,4291,4621,3719,1687, 950,2479, 426, 469,3227,3720,3721,4031,5265,5266,1188, # 2288
+ 424,1996, 861,3601,4292,3854,2205,2694, 168,1235,3602,4293,5267,2087,1674,4622, # 2304
+3374,3303, 220,2565,1009,5268,3855, 670,3010, 332,1208, 717,5269,5270,3603,2452, # 2320
+4032,3375,5271, 513,5272,1209,2882,3376,3163,4623,1080,5273,5274,5275,5276,2534, # 2336
+3722,3604, 815,1587,4033,4034,5277,3605,3486,3856,1254,4624,1328,3058,1390,4035, # 2352
+1741,4036,3857,4037,5278, 236,3858,2453,3304,5279,5280,3723,3859,1273,3860,4625, # 2368
+5281, 308,5282,4626, 245,4627,1852,2480,1307,2583, 430, 715,2137,2454,5283, 270, # 2384
+ 199,2883,4038,5284,3606,2727,1753, 761,1754, 725,1661,1841,4628,3487,3724,5285, # 2400
+5286, 587, 14,3305, 227,2608, 326, 480,2270, 943,2765,3607, 291, 650,1884,5287, # 2416
+1702,1226, 102,1547, 62,3488, 904,4629,3489,1164,4294,5288,5289,1224,1548,2766, # 2432
+ 391, 498,1493,5290,1386,1419,5291,2056,1177,4630, 813, 880,1081,2368, 566,1145, # 2448
+4631,2291,1001,1035,2566,2609,2242, 394,1286,5292,5293,2069,5294, 86,1494,1730, # 2464
+4039, 491,1588, 745, 897,2963, 843,3377,4040,2767,2884,3306,1768, 998,2221,2070, # 2480
+ 397,1827,1195,1970,3725,3011,3378, 284,5295,3861,2507,2138,2120,1904,5296,4041, # 2496
+2151,4042,4295,1036,3490,1905, 114,2567,4296, 209,1527,5297,5298,2964,2844,2635, # 2512
+2390,2728,3164, 812,2568,5299,3307,5300,1559, 737,1885,3726,1210, 885, 28,2695, # 2528
+3608,3862,5301,4297,1004,1780,4632,5302, 346,1982,2222,2696,4633,3863,1742, 797, # 2544
+1642,4043,1934,1072,1384,2152, 896,4044,3308,3727,3228,2885,3609,5303,2569,1959, # 2560
+4634,2455,1786,5304,5305,5306,4045,4298,1005,1308,3728,4299,2729,4635,4636,1528, # 2576
+2610, 161,1178,4300,1983, 987,4637,1101,4301, 631,4046,1157,3229,2425,1343,1241, # 2592
+1016,2243,2570, 372, 877,2344,2508,1160, 555,1935, 911,4047,5307, 466,1170, 169, # 2608
+1051,2921,2697,3729,2481,3012,1182,2012,2571,1251,2636,5308, 992,2345,3491,1540, # 2624
+2730,1201,2071,2406,1997,2482,5309,4638, 528,1923,2191,1503,1874,1570,2369,3379, # 2640
+3309,5310, 557,1073,5311,1828,3492,2088,2271,3165,3059,3107, 767,3108,2799,4639, # 2656
+1006,4302,4640,2346,1267,2179,3730,3230, 778,4048,3231,2731,1597,2667,5312,4641, # 2672
+5313,3493,5314,5315,5316,3310,2698,1433,3311, 131, 95,1504,4049, 723,4303,3166, # 2688
+1842,3610,2768,2192,4050,2028,2105,3731,5317,3013,4051,1218,5318,3380,3232,4052, # 2704
+4304,2584, 248,1634,3864, 912,5319,2845,3732,3060,3865, 654, 53,5320,3014,5321, # 2720
+1688,4642, 777,3494,1032,4053,1425,5322, 191, 820,2121,2846, 971,4643, 931,3233, # 2736
+ 135, 664, 783,3866,1998, 772,2922,1936,4054,3867,4644,2923,3234, 282,2732, 640, # 2752
+1372,3495,1127, 922, 325,3381,5323,5324, 711,2045,5325,5326,4055,2223,2800,1937, # 2768
+4056,3382,2224,2255,3868,2305,5327,4645,3869,1258,3312,4057,3235,2139,2965,4058, # 2784
+4059,5328,2225, 258,3236,4646, 101,1227,5329,3313,1755,5330,1391,3314,5331,2924, # 2800
+2057, 893,5332,5333,5334,1402,4305,2347,5335,5336,3237,3611,5337,5338, 878,1325, # 2816
+1781,2801,4647, 259,1385,2585, 744,1183,2272,4648,5339,4060,2509,5340, 684,1024, # 2832
+4306,5341, 472,3612,3496,1165,3315,4061,4062, 322,2153, 881, 455,1695,1152,1340, # 2848
+ 660, 554,2154,4649,1058,4650,4307, 830,1065,3383,4063,4651,1924,5342,1703,1919, # 2864
+5343, 932,2273, 122,5344,4652, 947, 677,5345,3870,2637, 297,1906,1925,2274,4653, # 2880
+2322,3316,5346,5347,4308,5348,4309, 84,4310, 112, 989,5349, 547,1059,4064, 701, # 2896
+3613,1019,5350,4311,5351,3497, 942, 639, 457,2306,2456, 993,2966, 407, 851, 494, # 2912
+4654,3384, 927,5352,1237,5353,2426,3385, 573,4312, 680, 921,2925,1279,1875, 285, # 2928
+ 790,1448,1984, 719,2168,5354,5355,4655,4065,4066,1649,5356,1541, 563,5357,1077, # 2944
+5358,3386,3061,3498, 511,3015,4067,4068,3733,4069,1268,2572,3387,3238,4656,4657, # 2960
+5359, 535,1048,1276,1189,2926,2029,3167,1438,1373,2847,2967,1134,2013,5360,4313, # 2976
+1238,2586,3109,1259,5361, 700,5362,2968,3168,3734,4314,5363,4315,1146,1876,1907, # 2992
+4658,2611,4070, 781,2427, 132,1589, 203, 147, 273,2802,2407, 898,1787,2155,4071, # 3008
+4072,5364,3871,2803,5365,5366,4659,4660,5367,3239,5368,1635,3872, 965,5369,1805, # 3024
+2699,1516,3614,1121,1082,1329,3317,4073,1449,3873, 65,1128,2848,2927,2769,1590, # 3040
+3874,5370,5371, 12,2668, 45, 976,2587,3169,4661, 517,2535,1013,1037,3240,5372, # 3056
+3875,2849,5373,3876,5374,3499,5375,2612, 614,1999,2323,3877,3110,2733,2638,5376, # 3072
+2588,4316, 599,1269,5377,1811,3735,5378,2700,3111, 759,1060, 489,1806,3388,3318, # 3088
+1358,5379,5380,2391,1387,1215,2639,2256, 490,5381,5382,4317,1759,2392,2348,5383, # 3104
+4662,3878,1908,4074,2640,1807,3241,4663,3500,3319,2770,2349, 874,5384,5385,3501, # 3120
+3736,1859, 91,2928,3737,3062,3879,4664,5386,3170,4075,2669,5387,3502,1202,1403, # 3136
+3880,2969,2536,1517,2510,4665,3503,2511,5388,4666,5389,2701,1886,1495,1731,4076, # 3152
+2370,4667,5390,2030,5391,5392,4077,2702,1216, 237,2589,4318,2324,4078,3881,4668, # 3168
+4669,2703,3615,3504, 445,4670,5393,5394,5395,5396,2771, 61,4079,3738,1823,4080, # 3184
+5397, 687,2046, 935, 925, 405,2670, 703,1096,1860,2734,4671,4081,1877,1367,2704, # 3200
+3389, 918,2106,1782,2483, 334,3320,1611,1093,4672, 564,3171,3505,3739,3390, 945, # 3216
+2641,2058,4673,5398,1926, 872,4319,5399,3506,2705,3112, 349,4320,3740,4082,4674, # 3232
+3882,4321,3741,2156,4083,4675,4676,4322,4677,2408,2047, 782,4084, 400, 251,4323, # 3248
+1624,5400,5401, 277,3742, 299,1265, 476,1191,3883,2122,4324,4325,1109, 205,5402, # 3264
+2590,1000,2157,3616,1861,5403,5404,5405,4678,5406,4679,2573, 107,2484,2158,4085, # 3280
+3507,3172,5407,1533, 541,1301, 158, 753,4326,2886,3617,5408,1696, 370,1088,4327, # 3296
+4680,3618, 579, 327, 440, 162,2244, 269,1938,1374,3508, 968,3063, 56,1396,3113, # 3312
+2107,3321,3391,5409,1927,2159,4681,3016,5410,3619,5411,5412,3743,4682,2485,5413, # 3328
+2804,5414,1650,4683,5415,2613,5416,5417,4086,2671,3392,1149,3393,4087,3884,4088, # 3344
+5418,1076, 49,5419, 951,3242,3322,3323, 450,2850, 920,5420,1812,2805,2371,4328, # 3360
+1909,1138,2372,3885,3509,5421,3243,4684,1910,1147,1518,2428,4685,3886,5422,4686, # 3376
+2393,2614, 260,1796,3244,5423,5424,3887,3324, 708,5425,3620,1704,5426,3621,1351, # 3392
+1618,3394,3017,1887, 944,4329,3395,4330,3064,3396,4331,5427,3744, 422, 413,1714, # 3408
+3325, 500,2059,2350,4332,2486,5428,1344,1911, 954,5429,1668,5430,5431,4089,2409, # 3424
+4333,3622,3888,4334,5432,2307,1318,2512,3114, 133,3115,2887,4687, 629, 31,2851, # 3440
+2706,3889,4688, 850, 949,4689,4090,2970,1732,2089,4335,1496,1853,5433,4091, 620, # 3456
+3245, 981,1242,3745,3397,1619,3746,1643,3326,2140,2457,1971,1719,3510,2169,5434, # 3472
+3246,5435,5436,3398,1829,5437,1277,4690,1565,2048,5438,1636,3623,3116,5439, 869, # 3488
+2852, 655,3890,3891,3117,4092,3018,3892,1310,3624,4691,5440,5441,5442,1733, 558, # 3504
+4692,3747, 335,1549,3065,1756,4336,3748,1946,3511,1830,1291,1192, 470,2735,2108, # 3520
+2806, 913,1054,4093,5443,1027,5444,3066,4094,4693, 982,2672,3399,3173,3512,3247, # 3536
+3248,1947,2807,5445, 571,4694,5446,1831,5447,3625,2591,1523,2429,5448,2090, 984, # 3552
+4695,3749,1960,5449,3750, 852, 923,2808,3513,3751, 969,1519, 999,2049,2325,1705, # 3568
+5450,3118, 615,1662, 151, 597,4095,2410,2326,1049, 275,4696,3752,4337, 568,3753, # 3584
+3626,2487,4338,3754,5451,2430,2275, 409,3249,5452,1566,2888,3514,1002, 769,2853, # 3600
+ 194,2091,3174,3755,2226,3327,4339, 628,1505,5453,5454,1763,2180,3019,4096, 521, # 3616
+1161,2592,1788,2206,2411,4697,4097,1625,4340,4341, 412, 42,3119, 464,5455,2642, # 3632
+4698,3400,1760,1571,2889,3515,2537,1219,2207,3893,2643,2141,2373,4699,4700,3328, # 3648
+1651,3401,3627,5456,5457,3628,2488,3516,5458,3756,5459,5460,2276,2092, 460,5461, # 3664
+4701,5462,3020, 962, 588,3629, 289,3250,2644,1116, 52,5463,3067,1797,5464,5465, # 3680
+5466,1467,5467,1598,1143,3757,4342,1985,1734,1067,4702,1280,3402, 465,4703,1572, # 3696
+ 510,5468,1928,2245,1813,1644,3630,5469,4704,3758,5470,5471,2673,1573,1534,5472, # 3712
+5473, 536,1808,1761,3517,3894,3175,2645,5474,5475,5476,4705,3518,2929,1912,2809, # 3728
+5477,3329,1122, 377,3251,5478, 360,5479,5480,4343,1529, 551,5481,2060,3759,1769, # 3744
+2431,5482,2930,4344,3330,3120,2327,2109,2031,4706,1404, 136,1468,1479, 672,1171, # 3760
+3252,2308, 271,3176,5483,2772,5484,2050, 678,2736, 865,1948,4707,5485,2014,4098, # 3776
+2971,5486,2737,2227,1397,3068,3760,4708,4709,1735,2931,3403,3631,5487,3895, 509, # 3792
+2854,2458,2890,3896,5488,5489,3177,3178,4710,4345,2538,4711,2309,1166,1010, 552, # 3808
+ 681,1888,5490,5491,2972,2973,4099,1287,1596,1862,3179, 358, 453, 736, 175, 478, # 3824
+1117, 905,1167,1097,5492,1854,1530,5493,1706,5494,2181,3519,2292,3761,3520,3632, # 3840
+4346,2093,4347,5495,3404,1193,2489,4348,1458,2193,2208,1863,1889,1421,3331,2932, # 3856
+3069,2182,3521, 595,2123,5496,4100,5497,5498,4349,1707,2646, 223,3762,1359, 751, # 3872
+3121, 183,3522,5499,2810,3021, 419,2374, 633, 704,3897,2394, 241,5500,5501,5502, # 3888
+ 838,3022,3763,2277,2773,2459,3898,1939,2051,4101,1309,3122,2246,1181,5503,1136, # 3904
+2209,3899,2375,1446,4350,2310,4712,5504,5505,4351,1055,2615, 484,3764,5506,4102, # 3920
+ 625,4352,2278,3405,1499,4353,4103,5507,4104,4354,3253,2279,2280,3523,5508,5509, # 3936
+2774, 808,2616,3765,3406,4105,4355,3123,2539, 526,3407,3900,4356, 955,5510,1620, # 3952
+4357,2647,2432,5511,1429,3766,1669,1832, 994, 928,5512,3633,1260,5513,5514,5515, # 3968
+1949,2293, 741,2933,1626,4358,2738,2460, 867,1184, 362,3408,1392,5516,5517,4106, # 3984
+4359,1770,1736,3254,2934,4713,4714,1929,2707,1459,1158,5518,3070,3409,2891,1292, # 4000
+1930,2513,2855,3767,1986,1187,2072,2015,2617,4360,5519,2574,2514,2170,3768,2490, # 4016
+3332,5520,3769,4715,5521,5522, 666,1003,3023,1022,3634,4361,5523,4716,1814,2257, # 4032
+ 574,3901,1603, 295,1535, 705,3902,4362, 283, 858, 417,5524,5525,3255,4717,4718, # 4048
+3071,1220,1890,1046,2281,2461,4107,1393,1599, 689,2575, 388,4363,5526,2491, 802, # 4064
+5527,2811,3903,2061,1405,2258,5528,4719,3904,2110,1052,1345,3256,1585,5529, 809, # 4080
+5530,5531,5532, 575,2739,3524, 956,1552,1469,1144,2328,5533,2329,1560,2462,3635, # 4096
+3257,4108, 616,2210,4364,3180,2183,2294,5534,1833,5535,3525,4720,5536,1319,3770, # 4112
+3771,1211,3636,1023,3258,1293,2812,5537,5538,5539,3905, 607,2311,3906, 762,2892, # 4128
+1439,4365,1360,4721,1485,3072,5540,4722,1038,4366,1450,2062,2648,4367,1379,4723, # 4144
+2593,5541,5542,4368,1352,1414,2330,2935,1172,5543,5544,3907,3908,4724,1798,1451, # 4160
+5545,5546,5547,5548,2936,4109,4110,2492,2351, 411,4111,4112,3637,3333,3124,4725, # 4176
+1561,2674,1452,4113,1375,5549,5550, 47,2974, 316,5551,1406,1591,2937,3181,5552, # 4192
+1025,2142,3125,3182, 354,2740, 884,2228,4369,2412, 508,3772, 726,3638, 996,2433, # 4208
+3639, 729,5553, 392,2194,1453,4114,4726,3773,5554,5555,2463,3640,2618,1675,2813, # 4224
+ 919,2352,2975,2353,1270,4727,4115, 73,5556,5557, 647,5558,3259,2856,2259,1550, # 4240
+1346,3024,5559,1332, 883,3526,5560,5561,5562,5563,3334,2775,5564,1212, 831,1347, # 4256
+4370,4728,2331,3909,1864,3073, 720,3910,4729,4730,3911,5565,4371,5566,5567,4731, # 4272
+5568,5569,1799,4732,3774,2619,4733,3641,1645,2376,4734,5570,2938, 669,2211,2675, # 4288
+2434,5571,2893,5572,5573,1028,3260,5574,4372,2413,5575,2260,1353,5576,5577,4735, # 4304
+3183, 518,5578,4116,5579,4373,1961,5580,2143,4374,5581,5582,3025,2354,2355,3912, # 4320
+ 516,1834,1454,4117,2708,4375,4736,2229,2620,1972,1129,3642,5583,2776,5584,2976, # 4336
+1422, 577,1470,3026,1524,3410,5585,5586, 432,4376,3074,3527,5587,2594,1455,2515, # 4352
+2230,1973,1175,5588,1020,2741,4118,3528,4737,5589,2742,5590,1743,1361,3075,3529, # 4368
+2649,4119,4377,4738,2295, 895, 924,4378,2171, 331,2247,3076, 166,1627,3077,1098, # 4384
+5591,1232,2894,2231,3411,4739, 657, 403,1196,2377, 542,3775,3412,1600,4379,3530, # 4400
+5592,4740,2777,3261, 576, 530,1362,4741,4742,2540,2676,3776,4120,5593, 842,3913, # 4416
+5594,2814,2032,1014,4121, 213,2709,3413, 665, 621,4380,5595,3777,2939,2435,5596, # 4432
+2436,3335,3643,3414,4743,4381,2541,4382,4744,3644,1682,4383,3531,1380,5597, 724, # 4448
+2282, 600,1670,5598,1337,1233,4745,3126,2248,5599,1621,4746,5600, 651,4384,5601, # 4464
+1612,4385,2621,5602,2857,5603,2743,2312,3078,5604, 716,2464,3079, 174,1255,2710, # 4480
+4122,3645, 548,1320,1398, 728,4123,1574,5605,1891,1197,3080,4124,5606,3081,3082, # 4496
+3778,3646,3779, 747,5607, 635,4386,4747,5608,5609,5610,4387,5611,5612,4748,5613, # 4512
+3415,4749,2437, 451,5614,3780,2542,2073,4388,2744,4389,4125,5615,1764,4750,5616, # 4528
+4390, 350,4751,2283,2395,2493,5617,4391,4126,2249,1434,4127, 488,4752, 458,4392, # 4544
+4128,3781, 771,1330,2396,3914,2576,3184,2160,2414,1553,2677,3185,4393,5618,2494, # 4560
+2895,2622,1720,2711,4394,3416,4753,5619,2543,4395,5620,3262,4396,2778,5621,2016, # 4576
+2745,5622,1155,1017,3782,3915,5623,3336,2313, 201,1865,4397,1430,5624,4129,5625, # 4592
+5626,5627,5628,5629,4398,1604,5630, 414,1866, 371,2595,4754,4755,3532,2017,3127, # 4608
+4756,1708, 960,4399, 887, 389,2172,1536,1663,1721,5631,2232,4130,2356,2940,1580, # 4624
+5632,5633,1744,4757,2544,4758,4759,5634,4760,5635,2074,5636,4761,3647,3417,2896, # 4640
+4400,5637,4401,2650,3418,2815, 673,2712,2465, 709,3533,4131,3648,4402,5638,1148, # 4656
+ 502, 634,5639,5640,1204,4762,3649,1575,4763,2623,3783,5641,3784,3128, 948,3263, # 4672
+ 121,1745,3916,1110,5642,4403,3083,2516,3027,4132,3785,1151,1771,3917,1488,4133, # 4688
+1987,5643,2438,3534,5644,5645,2094,5646,4404,3918,1213,1407,2816, 531,2746,2545, # 4704
+3264,1011,1537,4764,2779,4405,3129,1061,5647,3786,3787,1867,2897,5648,2018, 120, # 4720
+4406,4407,2063,3650,3265,2314,3919,2678,3419,1955,4765,4134,5649,3535,1047,2713, # 4736
+1266,5650,1368,4766,2858, 649,3420,3920,2546,2747,1102,2859,2679,5651,5652,2000, # 4752
+5653,1111,3651,2977,5654,2495,3921,3652,2817,1855,3421,3788,5655,5656,3422,2415, # 4768
+2898,3337,3266,3653,5657,2577,5658,3654,2818,4135,1460, 856,5659,3655,5660,2899, # 4784
+2978,5661,2900,3922,5662,4408, 632,2517, 875,3923,1697,3924,2296,5663,5664,4767, # 4800
+3028,1239, 580,4768,4409,5665, 914, 936,2075,1190,4136,1039,2124,5666,5667,5668, # 4816
+5669,3423,1473,5670,1354,4410,3925,4769,2173,3084,4137, 915,3338,4411,4412,3339, # 4832
+1605,1835,5671,2748, 398,3656,4413,3926,4138, 328,1913,2860,4139,3927,1331,4414, # 4848
+3029, 937,4415,5672,3657,4140,4141,3424,2161,4770,3425, 524, 742, 538,3085,1012, # 4864
+5673,5674,3928,2466,5675, 658,1103, 225,3929,5676,5677,4771,5678,4772,5679,3267, # 4880
+1243,5680,4142, 963,2250,4773,5681,2714,3658,3186,5682,5683,2596,2332,5684,4774, # 4896
+5685,5686,5687,3536, 957,3426,2547,2033,1931,2941,2467, 870,2019,3659,1746,2780, # 4912
+2781,2439,2468,5688,3930,5689,3789,3130,3790,3537,3427,3791,5690,1179,3086,5691, # 4928
+3187,2378,4416,3792,2548,3188,3131,2749,4143,5692,3428,1556,2549,2297, 977,2901, # 4944
+2034,4144,1205,3429,5693,1765,3430,3189,2125,1271, 714,1689,4775,3538,5694,2333, # 4960
+3931, 533,4417,3660,2184, 617,5695,2469,3340,3539,2315,5696,5697,3190,5698,5699, # 4976
+3932,1988, 618, 427,2651,3540,3431,5700,5701,1244,1690,5702,2819,4418,4776,5703, # 4992
+3541,4777,5704,2284,1576, 473,3661,4419,3432, 972,5705,3662,5706,3087,5707,5708, # 5008
+4778,4779,5709,3793,4145,4146,5710, 153,4780, 356,5711,1892,2902,4420,2144, 408, # 5024
+ 803,2357,5712,3933,5713,4421,1646,2578,2518,4781,4782,3934,5714,3935,4422,5715, # 5040
+2416,3433, 752,5716,5717,1962,3341,2979,5718, 746,3030,2470,4783,4423,3794, 698, # 5056
+4784,1893,4424,3663,2550,4785,3664,3936,5719,3191,3434,5720,1824,1302,4147,2715, # 5072
+3937,1974,4425,5721,4426,3192, 823,1303,1288,1236,2861,3542,4148,3435, 774,3938, # 5088
+5722,1581,4786,1304,2862,3939,4787,5723,2440,2162,1083,3268,4427,4149,4428, 344, # 5104
+1173, 288,2316, 454,1683,5724,5725,1461,4788,4150,2597,5726,5727,4789, 985, 894, # 5120
+5728,3436,3193,5729,1914,2942,3795,1989,5730,2111,1975,5731,4151,5732,2579,1194, # 5136
+ 425,5733,4790,3194,1245,3796,4429,5734,5735,2863,5736, 636,4791,1856,3940, 760, # 5152
+1800,5737,4430,2212,1508,4792,4152,1894,1684,2298,5738,5739,4793,4431,4432,2213, # 5168
+ 479,5740,5741, 832,5742,4153,2496,5743,2980,2497,3797, 990,3132, 627,1815,2652, # 5184
+4433,1582,4434,2126,2112,3543,4794,5744, 799,4435,3195,5745,4795,2113,1737,3031, # 5200
+1018, 543, 754,4436,3342,1676,4796,4797,4154,4798,1489,5746,3544,5747,2624,2903, # 5216
+4155,5748,5749,2981,5750,5751,5752,5753,3196,4799,4800,2185,1722,5754,3269,3270, # 5232
+1843,3665,1715, 481, 365,1976,1857,5755,5756,1963,2498,4801,5757,2127,3666,3271, # 5248
+ 433,1895,2064,2076,5758, 602,2750,5759,5760,5761,5762,5763,3032,1628,3437,5764, # 5264
+3197,4802,4156,2904,4803,2519,5765,2551,2782,5766,5767,5768,3343,4804,2905,5769, # 5280
+4805,5770,2864,4806,4807,1221,2982,4157,2520,5771,5772,5773,1868,1990,5774,5775, # 5296
+5776,1896,5777,5778,4808,1897,4158, 318,5779,2095,4159,4437,5780,5781, 485,5782, # 5312
+ 938,3941, 553,2680, 116,5783,3942,3667,5784,3545,2681,2783,3438,3344,2820,5785, # 5328
+3668,2943,4160,1747,2944,2983,5786,5787, 207,5788,4809,5789,4810,2521,5790,3033, # 5344
+ 890,3669,3943,5791,1878,3798,3439,5792,2186,2358,3440,1652,5793,5794,5795, 941, # 5360
+2299, 208,3546,4161,2020, 330,4438,3944,2906,2499,3799,4439,4811,5796,5797,5798, # 5376
+)
+# fmt: on
diff --git a/lib/chardet/big5prober.py b/lib/chardet/big5prober.py
new file mode 100644
index 0000000..e4dfa7a
--- /dev/null
+++ b/lib/chardet/big5prober.py
@@ -0,0 +1,47 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import Big5DistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import BIG5_SM_MODEL
+
+
+class Big5Prober(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(BIG5_SM_MODEL)
+ self.distribution_analyzer = Big5DistributionAnalysis()
+ self.reset()
+
+ @property
+ def charset_name(self):
+ return "Big5"
+
+ @property
+ def language(self):
+ return "Chinese"
diff --git a/lib/chardet/chardistribution.py b/lib/chardet/chardistribution.py
new file mode 100644
index 0000000..27b4a29
--- /dev/null
+++ b/lib/chardet/chardistribution.py
@@ -0,0 +1,259 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .big5freq import (
+ BIG5_CHAR_TO_FREQ_ORDER,
+ BIG5_TABLE_SIZE,
+ BIG5_TYPICAL_DISTRIBUTION_RATIO,
+)
+from .euckrfreq import (
+ EUCKR_CHAR_TO_FREQ_ORDER,
+ EUCKR_TABLE_SIZE,
+ EUCKR_TYPICAL_DISTRIBUTION_RATIO,
+)
+from .euctwfreq import (
+ EUCTW_CHAR_TO_FREQ_ORDER,
+ EUCTW_TABLE_SIZE,
+ EUCTW_TYPICAL_DISTRIBUTION_RATIO,
+)
+from .gb2312freq import (
+ GB2312_CHAR_TO_FREQ_ORDER,
+ GB2312_TABLE_SIZE,
+ GB2312_TYPICAL_DISTRIBUTION_RATIO,
+)
+from .jisfreq import (
+ JIS_CHAR_TO_FREQ_ORDER,
+ JIS_TABLE_SIZE,
+ JIS_TYPICAL_DISTRIBUTION_RATIO,
+)
+from .johabfreq import JOHAB_TO_EUCKR_ORDER_TABLE
+
+
+class CharDistributionAnalysis:
+ ENOUGH_DATA_THRESHOLD = 1024
+ SURE_YES = 0.99
+ SURE_NO = 0.01
+ MINIMUM_DATA_THRESHOLD = 3
+
+ def __init__(self):
+        # Mapping table to get frequency order from char order (as returned
+        # by get_order())
+ self._char_to_freq_order = tuple()
+ self._table_size = None # Size of above table
+        # This is a constant value that differs from language to language;
+        # it is used in calculating confidence. See
+ # http://www.mozilla.org/projects/intl/UniversalCharsetDetection.html
+ # for further detail.
+ self.typical_distribution_ratio = None
+ self._done = None
+ self._total_chars = None
+ self._freq_chars = None
+ self.reset()
+
+ def reset(self):
+ """reset analyser, clear any state"""
+        # If this flag is set to True, detection is done and a conclusion
+        # has been made
+ self._done = False
+ self._total_chars = 0 # Total characters encountered
+ # The number of characters whose frequency order is less than 512
+ self._freq_chars = 0
+
+ def feed(self, char, char_len):
+ """feed a character with known length"""
+ if char_len == 2:
+            # we only care about 2-byte characters in our distribution analysis
+ order = self.get_order(char)
+ else:
+ order = -1
+ if order >= 0:
+ self._total_chars += 1
+ # order is valid
+ if order < self._table_size:
+ if 512 > self._char_to_freq_order[order]:
+ self._freq_chars += 1
+
+ def get_confidence(self):
+ """return confidence based on existing data"""
+ # if we didn't receive any character in our consideration range,
+        # return a negative answer
+ if self._total_chars <= 0 or self._freq_chars <= self.MINIMUM_DATA_THRESHOLD:
+ return self.SURE_NO
+
+ if self._total_chars != self._freq_chars:
+ r = self._freq_chars / (
+ (self._total_chars - self._freq_chars) * self.typical_distribution_ratio
+ )
+ if r < self.SURE_YES:
+ return r
+
+ # normalize confidence (we don't want to be 100% sure)
+ return self.SURE_YES
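+
+    # A quick worked example of the ratio above, using EUC-KR's
+    # typical_distribution_ratio of 6.0: a stream in which 600 of 1000
+    # two-byte characters fall among the 512 most frequent orders gives
+    # r = 600 / ((1000 - 600) * 6.0) = 0.25 (weak evidence), while 950 of
+    # 1000 gives r ~= 3.17 and is therefore capped at SURE_YES (0.99).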
+
+ def got_enough_data(self):
+        # It is not necessary to receive all data to draw a conclusion.
+        # For charset detection, a certain amount of data is enough.
+ return self._total_chars > self.ENOUGH_DATA_THRESHOLD
+
+ def get_order(self, _):
+ # We do not handle characters based on the original encoding string,
+ # but convert this encoding string to a number, here called order.
+ # This allows multiple encodings of a language to share one frequency
+ # table.
+ return -1
+
+
+class EUCTWDistributionAnalysis(CharDistributionAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._char_to_freq_order = EUCTW_CHAR_TO_FREQ_ORDER
+ self._table_size = EUCTW_TABLE_SIZE
+ self.typical_distribution_ratio = EUCTW_TYPICAL_DISTRIBUTION_RATIO
+
+ def get_order(self, byte_str):
+ # for euc-TW encoding, we are interested
+ # first byte range: 0xc4 -- 0xfe
+ # second byte range: 0xa1 -- 0xfe
+ # no validation needed here. State machine has done that
+ first_char = byte_str[0]
+ if first_char >= 0xC4:
+ return 94 * (first_char - 0xC4) + byte_str[1] - 0xA1
+ return -1
+
+
+class EUCKRDistributionAnalysis(CharDistributionAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._char_to_freq_order = EUCKR_CHAR_TO_FREQ_ORDER
+ self._table_size = EUCKR_TABLE_SIZE
+ self.typical_distribution_ratio = EUCKR_TYPICAL_DISTRIBUTION_RATIO
+
+ def get_order(self, byte_str):
+ # for euc-KR encoding, we are interested
+ # first byte range: 0xb0 -- 0xfe
+ # second byte range: 0xa1 -- 0xfe
+ # no validation needed here. State machine has done that
+ first_char = byte_str[0]
+ if first_char >= 0xB0:
+ return 94 * (first_char - 0xB0) + byte_str[1] - 0xA1
+ return -1
+
+
+class JOHABDistributionAnalysis(CharDistributionAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._char_to_freq_order = EUCKR_CHAR_TO_FREQ_ORDER
+ self._table_size = EUCKR_TABLE_SIZE
+ self.typical_distribution_ratio = EUCKR_TYPICAL_DISTRIBUTION_RATIO
+
+ def get_order(self, byte_str):
+ first_char = byte_str[0]
+ if 0x88 <= first_char < 0xD4:
+ code = first_char * 256 + byte_str[1]
+ return JOHAB_TO_EUCKR_ORDER_TABLE.get(code, -1)
+ return -1
+
+
+class GB2312DistributionAnalysis(CharDistributionAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._char_to_freq_order = GB2312_CHAR_TO_FREQ_ORDER
+ self._table_size = GB2312_TABLE_SIZE
+ self.typical_distribution_ratio = GB2312_TYPICAL_DISTRIBUTION_RATIO
+
+ def get_order(self, byte_str):
+ # for GB2312 encoding, we are interested
+ # first byte range: 0xb0 -- 0xfe
+ # second byte range: 0xa1 -- 0xfe
+ # no validation needed here. State machine has done that
+ first_char, second_char = byte_str[0], byte_str[1]
+ if (first_char >= 0xB0) and (second_char >= 0xA1):
+ return 94 * (first_char - 0xB0) + second_char - 0xA1
+ return -1
+
+
+class Big5DistributionAnalysis(CharDistributionAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._char_to_freq_order = BIG5_CHAR_TO_FREQ_ORDER
+ self._table_size = BIG5_TABLE_SIZE
+ self.typical_distribution_ratio = BIG5_TYPICAL_DISTRIBUTION_RATIO
+
+ def get_order(self, byte_str):
+ # for big5 encoding, we are interested
+ # first byte range: 0xa4 -- 0xfe
+ # second byte range: 0x40 -- 0x7e , 0xa1 -- 0xfe
+ # no validation needed here. State machine has done that
+ first_char, second_char = byte_str[0], byte_str[1]
+ if first_char >= 0xA4:
+ if second_char >= 0xA1:
+ return 157 * (first_char - 0xA4) + second_char - 0xA1 + 63
+ return 157 * (first_char - 0xA4) + second_char - 0x40
+ return -1
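+
+    # Worked example of the mapping above: each Big5 lead byte (0xa4 -- 0xfe)
+    # owns a row of 157 cells -- 63 for trail bytes 0x40 -- 0x7e and 94 for
+    # trail bytes 0xa1 -- 0xfe.  So (0xa4, 0x40) maps to order 0,
+    # (0xa4, 0xa1) to 63, and (0xa5, 0x40) to 157.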
+
+
+class SJISDistributionAnalysis(CharDistributionAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._char_to_freq_order = JIS_CHAR_TO_FREQ_ORDER
+ self._table_size = JIS_TABLE_SIZE
+ self.typical_distribution_ratio = JIS_TYPICAL_DISTRIBUTION_RATIO
+
+ def get_order(self, byte_str):
+ # for sjis encoding, we are interested
+ # first byte range: 0x81 -- 0x9f , 0xe0 -- 0xfe
+        #   second byte range: 0x40 -- 0x7e, 0x81 -- 0xfe
+ # no validation needed here. State machine has done that
+ first_char, second_char = byte_str[0], byte_str[1]
+ if 0x81 <= first_char <= 0x9F:
+ order = 188 * (first_char - 0x81)
+ elif 0xE0 <= first_char <= 0xEF:
+ order = 188 * (first_char - 0xE0 + 31)
+ else:
+ return -1
+ order = order + second_char - 0x40
+ if second_char > 0x7F:
+ order = -1
+ return order
+
+
+class EUCJPDistributionAnalysis(CharDistributionAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._char_to_freq_order = JIS_CHAR_TO_FREQ_ORDER
+ self._table_size = JIS_TABLE_SIZE
+ self.typical_distribution_ratio = JIS_TYPICAL_DISTRIBUTION_RATIO
+
+ def get_order(self, byte_str):
+ # for euc-JP encoding, we are interested
+ # first byte range: 0xa0 -- 0xfe
+ # second byte range: 0xa1 -- 0xfe
+ # no validation needed here. State machine has done that
+ char = byte_str[0]
+ if char >= 0xA0:
+ return 94 * (char - 0xA1) + byte_str[1] - 0xA1
+ return -1
diff --git a/lib/chardet/charsetgroupprober.py b/lib/chardet/charsetgroupprober.py
new file mode 100644
index 0000000..778ff33
--- /dev/null
+++ b/lib/chardet/charsetgroupprober.py
@@ -0,0 +1,109 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .charsetprober import CharSetProber
+from .enums import ProbingState
+
+
+class CharSetGroupProber(CharSetProber):
+ def __init__(self, lang_filter=None):
+ super().__init__(lang_filter=lang_filter)
+ self._active_num = 0
+ self.probers = []
+ self._best_guess_prober = None
+
+ def reset(self):
+ super().reset()
+ self._active_num = 0
+ for prober in self.probers:
+ if prober:
+ prober.reset()
+ prober.active = True
+ self._active_num += 1
+ self._best_guess_prober = None
+
+ @property
+ def charset_name(self):
+ if not self._best_guess_prober:
+ self.get_confidence()
+ if not self._best_guess_prober:
+ return None
+ return self._best_guess_prober.charset_name
+
+ @property
+ def language(self):
+ if not self._best_guess_prober:
+ self.get_confidence()
+ if not self._best_guess_prober:
+ return None
+ return self._best_guess_prober.language
+
+ def feed(self, byte_str):
+ for prober in self.probers:
+ if not prober:
+ continue
+ if not prober.active:
+ continue
+ state = prober.feed(byte_str)
+ if not state:
+ continue
+ if state == ProbingState.FOUND_IT:
+ self._best_guess_prober = prober
+ self._state = ProbingState.FOUND_IT
+ return self.state
+ if state == ProbingState.NOT_ME:
+ prober.active = False
+ self._active_num -= 1
+ if self._active_num <= 0:
+ self._state = ProbingState.NOT_ME
+ return self.state
+ return self.state
+
+ def get_confidence(self):
+ state = self.state
+ if state == ProbingState.FOUND_IT:
+ return 0.99
+ if state == ProbingState.NOT_ME:
+ return 0.01
+ best_conf = 0.0
+ self._best_guess_prober = None
+ for prober in self.probers:
+ if not prober:
+ continue
+ if not prober.active:
+ self.logger.debug("%s not active", prober.charset_name)
+ continue
+ conf = prober.get_confidence()
+ self.logger.debug(
+ "%s %s confidence = %s", prober.charset_name, prober.language, conf
+ )
+ if best_conf < conf:
+ best_conf = conf
+ self._best_guess_prober = prober
+ if not self._best_guess_prober:
+ return 0.0
+ return best_conf
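+
+# In practice (elsewhere in this package), group probers such as the
+# multi-byte group prober fill self.probers with concrete probers
+# (Big5Prober, EUCJPProber, ...); feed() then fans each chunk out to every
+# active child, and get_confidence() reports the best child's score as the
+# group's own.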
diff --git a/lib/chardet/charsetprober.py b/lib/chardet/charsetprober.py
new file mode 100644
index 0000000..9f1afd9
--- /dev/null
+++ b/lib/chardet/charsetprober.py
@@ -0,0 +1,138 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 2001
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+# Shy Shalom - original C code
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+import logging
+import re
+
+from .enums import ProbingState
+
+INTERNATIONAL_WORDS_PATTERN = re.compile(
+ b"[a-zA-Z]*[\x80-\xFF]+[a-zA-Z]*[^a-zA-Z\x80-\xFF]?"
+)
+
+
+class CharSetProber:
+
+ SHORTCUT_THRESHOLD = 0.95
+
+ def __init__(self, lang_filter=None):
+ self._state = None
+ self.lang_filter = lang_filter
+ self.logger = logging.getLogger(__name__)
+
+ def reset(self):
+ self._state = ProbingState.DETECTING
+
+ @property
+ def charset_name(self):
+ return None
+
+ def feed(self, byte_str):
+ raise NotImplementedError
+
+ @property
+ def state(self):
+ return self._state
+
+ def get_confidence(self):
+ return 0.0
+
+ @staticmethod
+ def filter_high_byte_only(buf):
+ buf = re.sub(b"([\x00-\x7F])+", b" ", buf)
+ return buf
+
+ @staticmethod
+ def filter_international_words(buf):
+ """
+ We define three types of bytes:
+        alphabet: English letters [a-zA-Z]
+        international: international characters [\x80-\xFF]
+        marker: everything else [^a-zA-Z\x80-\xFF]
+        The input buffer can be thought of as a series of words delimited
+        by markers. This function keeps only the words that contain at
+        least one international character. All contiguous sequences of markers
+        are replaced by a single ASCII space character.
+ This filter applies to all scripts which do not use English characters.
+ """
+ filtered = bytearray()
+
+        # This regex matches only words that have at least one
+ # international character. The word may include one marker character at
+ # the end.
+ words = INTERNATIONAL_WORDS_PATTERN.findall(buf)
+
+ for word in words:
+ filtered.extend(word[:-1])
+
+ # If the last character in the word is a marker, replace it with a
+ # space as markers shouldn't affect our analysis (they are used
+ # similarly across all languages and may thus have similar
+ # frequencies).
+ last_char = word[-1:]
+ if not last_char.isalpha() and last_char < b"\x80":
+ last_char = b" "
+ filtered.extend(last_char)
+
+ return filtered
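+
+    # For instance, filter_international_words(b"abc d\xe9f ghi") keeps only
+    # the word containing a high byte plus its trailing marker, returning
+    # bytearray(b"d\xe9f ").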
+
+ @staticmethod
+ def remove_xml_tags(buf):
+ """
+ Returns a copy of ``buf`` that retains only the sequences of English
+ alphabet and high byte characters that are not between <> characters.
+ This filter can be applied to all scripts which contain both English
+ characters and extended ASCII characters, but is currently only used by
+ ``Latin1Prober``.
+ """
+ filtered = bytearray()
+ in_tag = False
+ prev = 0
+ buf = memoryview(buf).cast("c")
+
+ for curr, buf_char in enumerate(buf):
+ # Check if we're coming out of or entering an XML tag
+ if buf_char == b">":
+ prev = curr + 1
+ in_tag = False
+ elif buf_char == b"<":
+ if curr > prev and not in_tag:
+ # Keep everything after last non-extended-ASCII,
+ # non-alphabetic character
+ filtered.extend(buf[prev:curr])
+ # Output a space to delimit stretch we kept
+ filtered.extend(b" ")
+ in_tag = True
+
+ # If we're not in a tag...
+ if not in_tag:
+ # Keep everything after last non-extended-ASCII, non-alphabetic
+ # character
+ filtered.extend(buf[prev:])
+
+ return filtered
diff --git a/lib/chardet/cli/__init__.py b/lib/chardet/cli/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/chardet/cli/__init__.py
diff --git a/lib/chardet/cli/chardetect.py b/lib/chardet/cli/chardetect.py
new file mode 100644
index 0000000..7926fa3
--- /dev/null
+++ b/lib/chardet/cli/chardetect.py
@@ -0,0 +1,86 @@
+"""
+Script which takes one or more file paths and reports on their detected
+encodings
+
+Example::
+
+ % chardetect somefile someotherfile
+ somefile: windows-1252 with confidence 0.5
+ someotherfile: ascii with confidence 1.0
+
+If no paths are provided, it takes its input from stdin.
+
+"""
+
+
+import argparse
+import sys
+
+from .. import __version__
+from ..universaldetector import UniversalDetector
+
+
+def description_of(lines, name="stdin"):
+ """
+ Return a string describing the probable encoding of a file or
+ list of strings.
+
+ :param lines: The lines to get the encoding of.
+ :type lines: Iterable of bytes
+ :param name: Name of file or collection of lines
+ :type name: str
+ """
+ u = UniversalDetector()
+ for line in lines:
+ line = bytearray(line)
+ u.feed(line)
+ # shortcut out of the loop to save reading further - particularly useful if we read a BOM.
+ if u.done:
+ break
+ u.close()
+ result = u.result
+ if result["encoding"]:
+ return f'{name}: {result["encoding"]} with confidence {result["confidence"]}'
+ return f"{name}: no result"
+
+
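+# A minimal usage sketch (the file name is only an example): reading a file in
+# binary mode and handing its line iterator to description_of() prints
+# something like "example.txt: utf-8 with confidence 0.99".
+#
+#     with open("example.txt", "rb") as fp:
+#         print(description_of(fp, name=fp.name))
+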
+def main(argv=None):
+ """
+ Handles command line arguments and gets things started.
+
+ :param argv: List of arguments, as if specified on the command-line.
+ If None, ``sys.argv[1:]`` is used instead.
+ :type argv: list of str
+ """
+ # Get command line arguments
+ parser = argparse.ArgumentParser(
+ description="Takes one or more file paths and reports their detected \
+ encodings"
+ )
+ parser.add_argument(
+ "input",
+ help="File whose encoding we would like to determine. \
+ (default: stdin)",
+ type=argparse.FileType("rb"),
+ nargs="*",
+ default=[sys.stdin.buffer],
+ )
+ parser.add_argument(
+ "--version", action="version", version=f"%(prog)s {__version__}"
+ )
+ args = parser.parse_args(argv)
+
+ for f in args.input:
+ if f.isatty():
+ print(
+ "You are running chardetect interactively. Press "
+ "CTRL-D twice at the start of a blank line to signal the "
+ "end of your input. If you want help, run chardetect "
+ "--help\n",
+ file=sys.stderr,
+ )
+ print(description_of(f, f.name))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/lib/chardet/codingstatemachine.py b/lib/chardet/codingstatemachine.py
new file mode 100644
index 0000000..d3e3e82
--- /dev/null
+++ b/lib/chardet/codingstatemachine.py
@@ -0,0 +1,88 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+import logging
+
+from .enums import MachineState
+
+
+class CodingStateMachine:
+ """
+ A state machine to verify a byte sequence for a particular encoding. For
+ each byte the detector receives, it will feed that byte to every active
+ state machine available, one byte at a time. The state machine changes its
+ state based on its previous state and the byte it receives. There are 3
+ states in a state machine that are of interest to an auto-detector:
+
+    START state: This is the state to start with, or the state reached after a
+                 legal byte sequence (i.e. a valid code point) for a character
+                 has been identified.
+
+    ME state: This indicates that the state machine identified a byte sequence
+                 that is specific to the charset it is designed for and that
+                 there is no other possible encoding which can contain this
+                 byte sequence. This will lead to an immediate positive answer
+                 for the detector.
+
+    ERROR state: This indicates the state machine identified an illegal byte
+                 sequence for that encoding. This will lead to an immediate
+                 negative answer for this encoding. The detector will exclude
+                 this encoding from consideration from here on.
+ """
+
+ def __init__(self, sm):
+ self._model = sm
+ self._curr_byte_pos = 0
+ self._curr_char_len = 0
+ self._curr_state = None
+ self.logger = logging.getLogger(__name__)
+ self.reset()
+
+ def reset(self):
+ self._curr_state = MachineState.START
+
+ def next_state(self, c):
+ # for each byte we get its class
+        # if it is the first byte, we also get the byte length
+ byte_class = self._model["class_table"][c]
+ if self._curr_state == MachineState.START:
+ self._curr_byte_pos = 0
+ self._curr_char_len = self._model["char_len_table"][byte_class]
+ # from byte's class and state_table, we get its next state
+ curr_state = self._curr_state * self._model["class_factor"] + byte_class
+ self._curr_state = self._model["state_table"][curr_state]
+ self._curr_byte_pos += 1
+ return self._curr_state
+
+ def get_current_charlen(self):
+ return self._curr_char_len
+
+ def get_coding_state_machine(self):
+ return self._model["name"]
+
+ @property
+ def language(self):
+ return self._model["language"]
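+
+# A minimal usage sketch (encoding model chosen for illustration): the
+# detector drives one state machine per candidate encoding, one byte at a
+# time, and watches for MachineState.ERROR / MachineState.ITS_ME, e.g.
+#
+#     from .mbcssm import BIG5_SM_MODEL
+#     sm = CodingStateMachine(BIG5_SM_MODEL)
+#     for byte in b"\xa4\x40":            # one complete Big5 code point
+#         state = sm.next_state(byte)     # should never reach MachineState.ERROR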
diff --git a/lib/chardet/cp949prober.py b/lib/chardet/cp949prober.py
new file mode 100644
index 0000000..28a1f3d
--- /dev/null
+++ b/lib/chardet/cp949prober.py
@@ -0,0 +1,49 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import EUCKRDistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import CP949_SM_MODEL
+
+
+class CP949Prober(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(CP949_SM_MODEL)
+        # NOTE: CP949 is a superset of EUC-KR, so the distribution should not
+        # be different.
+ self.distribution_analyzer = EUCKRDistributionAnalysis()
+ self.reset()
+
+ @property
+ def charset_name(self):
+ return "CP949"
+
+ @property
+ def language(self):
+ return "Korean"
diff --git a/lib/chardet/enums.py b/lib/chardet/enums.py
new file mode 100644
index 0000000..32a77e7
--- /dev/null
+++ b/lib/chardet/enums.py
@@ -0,0 +1,82 @@
+"""
+All of the Enums that are used throughout the chardet package.
+
+:author: Dan Blanchard (dan.blanchard@gmail.com)
+"""
+
+
+class InputState:
+ """
+ This enum represents the different states a universal detector can be in.
+ """
+
+ PURE_ASCII = 0
+ ESC_ASCII = 1
+ HIGH_BYTE = 2
+
+
+class LanguageFilter:
+ """
+ This enum represents the different language filters we can apply to a
+ ``UniversalDetector``.
+ """
+
+ CHINESE_SIMPLIFIED = 0x01
+ CHINESE_TRADITIONAL = 0x02
+ JAPANESE = 0x04
+ KOREAN = 0x08
+ NON_CJK = 0x10
+ ALL = 0x1F
+ CHINESE = CHINESE_SIMPLIFIED | CHINESE_TRADITIONAL
+ CJK = CHINESE | JAPANESE | KOREAN
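+
+    # These values are bit flags, so they can be OR-ed together; for example,
+    # a caller that only expects Chinese and Japanese text could pass
+    # lang_filter=LanguageFilter.CHINESE | LanguageFilter.JAPANESE to a prober
+    # or detector that accepts a lang_filter argument.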
+
+
+class ProbingState:
+ """
+ This enum represents the different states a prober can be in.
+ """
+
+ DETECTING = 0
+ FOUND_IT = 1
+ NOT_ME = 2
+
+
+class MachineState:
+ """
+ This enum represents the different states a state machine can be in.
+ """
+
+ START = 0
+ ERROR = 1
+ ITS_ME = 2
+
+
+class SequenceLikelihood:
+ """
+ This enum represents the likelihood of a character following the previous one.
+ """
+
+ NEGATIVE = 0
+ UNLIKELY = 1
+ LIKELY = 2
+ POSITIVE = 3
+
+ @classmethod
+ def get_num_categories(cls):
+ """:returns: The number of likelihood categories in the enum."""
+ return 4
+
+
+class CharacterCategory:
+ """
+ This enum represents the different categories language models for
+ ``SingleByteCharsetProber`` put characters into.
+
+ Anything less than CONTROL is considered a letter.
+ """
+
+ UNDEFINED = 255
+ LINE_BREAK = 254
+ SYMBOL = 253
+ DIGIT = 252
+ CONTROL = 251
diff --git a/lib/chardet/escprober.py b/lib/chardet/escprober.py
new file mode 100644
index 0000000..d992611
--- /dev/null
+++ b/lib/chardet/escprober.py
@@ -0,0 +1,102 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .charsetprober import CharSetProber
+from .codingstatemachine import CodingStateMachine
+from .enums import LanguageFilter, MachineState, ProbingState
+from .escsm import (
+ HZ_SM_MODEL,
+ ISO2022CN_SM_MODEL,
+ ISO2022JP_SM_MODEL,
+ ISO2022KR_SM_MODEL,
+)
+
+
+class EscCharSetProber(CharSetProber):
+ """
+ This CharSetProber uses a "code scheme" approach for detecting encodings,
+ whereby easily recognizable escape or shift sequences are relied on to
+ identify these encodings.
+ """
+
+ def __init__(self, lang_filter=None):
+ super().__init__(lang_filter=lang_filter)
+ self.coding_sm = []
+ if self.lang_filter & LanguageFilter.CHINESE_SIMPLIFIED:
+ self.coding_sm.append(CodingStateMachine(HZ_SM_MODEL))
+ self.coding_sm.append(CodingStateMachine(ISO2022CN_SM_MODEL))
+ if self.lang_filter & LanguageFilter.JAPANESE:
+ self.coding_sm.append(CodingStateMachine(ISO2022JP_SM_MODEL))
+ if self.lang_filter & LanguageFilter.KOREAN:
+ self.coding_sm.append(CodingStateMachine(ISO2022KR_SM_MODEL))
+ self.active_sm_count = None
+ self._detected_charset = None
+ self._detected_language = None
+ self._state = None
+ self.reset()
+
+ def reset(self):
+ super().reset()
+ for coding_sm in self.coding_sm:
+ if not coding_sm:
+ continue
+ coding_sm.active = True
+ coding_sm.reset()
+ self.active_sm_count = len(self.coding_sm)
+ self._detected_charset = None
+ self._detected_language = None
+
+ @property
+ def charset_name(self):
+ return self._detected_charset
+
+ @property
+ def language(self):
+ return self._detected_language
+
+ def get_confidence(self):
+ return 0.99 if self._detected_charset else 0.00
+
+ def feed(self, byte_str):
+ for c in byte_str:
+ for coding_sm in self.coding_sm:
+ if not coding_sm or not coding_sm.active:
+ continue
+ coding_state = coding_sm.next_state(c)
+ if coding_state == MachineState.ERROR:
+ coding_sm.active = False
+ self.active_sm_count -= 1
+ if self.active_sm_count <= 0:
+ self._state = ProbingState.NOT_ME
+ return self.state
+ elif coding_state == MachineState.ITS_ME:
+ self._state = ProbingState.FOUND_IT
+ self._detected_charset = coding_sm.get_coding_state_machine()
+ self._detected_language = coding_sm.language
+ return self.state
+
+ return self.state
diff --git a/lib/chardet/escsm.py b/lib/chardet/escsm.py
new file mode 100644
index 0000000..3aa0f4d
--- /dev/null
+++ b/lib/chardet/escsm.py
@@ -0,0 +1,260 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .enums import MachineState
+
+# fmt: off
+HZ_CLS = (
+ 1, 0, 0, 0, 0, 0, 0, 0, # 00 - 07
+ 0, 0, 0, 0, 0, 0, 0, 0, # 08 - 0f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 10 - 17
+ 0, 0, 0, 1, 0, 0, 0, 0, # 18 - 1f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 20 - 27
+ 0, 0, 0, 0, 0, 0, 0, 0, # 28 - 2f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 30 - 37
+ 0, 0, 0, 0, 0, 0, 0, 0, # 38 - 3f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 40 - 47
+ 0, 0, 0, 0, 0, 0, 0, 0, # 48 - 4f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 50 - 57
+ 0, 0, 0, 0, 0, 0, 0, 0, # 58 - 5f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 60 - 67
+ 0, 0, 0, 0, 0, 0, 0, 0, # 68 - 6f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 70 - 77
+ 0, 0, 0, 4, 0, 5, 2, 0, # 78 - 7f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 80 - 87
+ 1, 1, 1, 1, 1, 1, 1, 1, # 88 - 8f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 90 - 97
+ 1, 1, 1, 1, 1, 1, 1, 1, # 98 - 9f
+ 1, 1, 1, 1, 1, 1, 1, 1, # a0 - a7
+ 1, 1, 1, 1, 1, 1, 1, 1, # a8 - af
+ 1, 1, 1, 1, 1, 1, 1, 1, # b0 - b7
+ 1, 1, 1, 1, 1, 1, 1, 1, # b8 - bf
+ 1, 1, 1, 1, 1, 1, 1, 1, # c0 - c7
+ 1, 1, 1, 1, 1, 1, 1, 1, # c8 - cf
+ 1, 1, 1, 1, 1, 1, 1, 1, # d0 - d7
+ 1, 1, 1, 1, 1, 1, 1, 1, # d8 - df
+ 1, 1, 1, 1, 1, 1, 1, 1, # e0 - e7
+ 1, 1, 1, 1, 1, 1, 1, 1, # e8 - ef
+ 1, 1, 1, 1, 1, 1, 1, 1, # f0 - f7
+ 1, 1, 1, 1, 1, 1, 1, 1, # f8 - ff
+)
+
+HZ_ST = (
+MachineState.START, MachineState.ERROR, 3, MachineState.START, MachineState.START, MachineState.START, MachineState.ERROR, MachineState.ERROR, # 00-07
+MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, # 08-0f
+MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ERROR, MachineState.ERROR, MachineState.START, MachineState.START, 4, MachineState.ERROR, # 10-17
+ 5, MachineState.ERROR, 6, MachineState.ERROR, 5, 5, 4, MachineState.ERROR, # 18-1f
+ 4, MachineState.ERROR, 4, 4, 4, MachineState.ERROR, 4, MachineState.ERROR, # 20-27
+ 4, MachineState.ITS_ME, MachineState.START, MachineState.START, MachineState.START, MachineState.START, MachineState.START, MachineState.START, # 28-2f
+)
+# fmt: on
+
+HZ_CHAR_LEN_TABLE = (0, 0, 0, 0, 0, 0)
+
+HZ_SM_MODEL = {
+ "class_table": HZ_CLS,
+ "class_factor": 6,
+ "state_table": HZ_ST,
+ "char_len_table": HZ_CHAR_LEN_TABLE,
+ "name": "HZ-GB-2312",
+ "language": "Chinese",
+}
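+
+# How these model dicts are consumed (see codingstatemachine.py): each input
+# byte is mapped to a class via class_table, and the next state is looked up
+# as state_table[current_state * class_factor + byte_class].  For example,
+# starting from MachineState.START, the HZ escape "~{" appears to walk
+# START -> 3 -> 4, since '~' (0x7e) has class 2 and '{' (0x7b) has class 4
+# in HZ_CLS above.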
+
+# fmt: off
+ISO2022CN_CLS = (
+ 2, 0, 0, 0, 0, 0, 0, 0, # 00 - 07
+ 0, 0, 0, 0, 0, 0, 0, 0, # 08 - 0f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 10 - 17
+ 0, 0, 0, 1, 0, 0, 0, 0, # 18 - 1f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 20 - 27
+ 0, 3, 0, 0, 0, 0, 0, 0, # 28 - 2f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 30 - 37
+ 0, 0, 0, 0, 0, 0, 0, 0, # 38 - 3f
+ 0, 0, 0, 4, 0, 0, 0, 0, # 40 - 47
+ 0, 0, 0, 0, 0, 0, 0, 0, # 48 - 4f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 50 - 57
+ 0, 0, 0, 0, 0, 0, 0, 0, # 58 - 5f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 60 - 67
+ 0, 0, 0, 0, 0, 0, 0, 0, # 68 - 6f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 70 - 77
+ 0, 0, 0, 0, 0, 0, 0, 0, # 78 - 7f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 80 - 87
+ 2, 2, 2, 2, 2, 2, 2, 2, # 88 - 8f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 90 - 97
+ 2, 2, 2, 2, 2, 2, 2, 2, # 98 - 9f
+ 2, 2, 2, 2, 2, 2, 2, 2, # a0 - a7
+ 2, 2, 2, 2, 2, 2, 2, 2, # a8 - af
+ 2, 2, 2, 2, 2, 2, 2, 2, # b0 - b7
+ 2, 2, 2, 2, 2, 2, 2, 2, # b8 - bf
+ 2, 2, 2, 2, 2, 2, 2, 2, # c0 - c7
+ 2, 2, 2, 2, 2, 2, 2, 2, # c8 - cf
+ 2, 2, 2, 2, 2, 2, 2, 2, # d0 - d7
+ 2, 2, 2, 2, 2, 2, 2, 2, # d8 - df
+ 2, 2, 2, 2, 2, 2, 2, 2, # e0 - e7
+ 2, 2, 2, 2, 2, 2, 2, 2, # e8 - ef
+ 2, 2, 2, 2, 2, 2, 2, 2, # f0 - f7
+ 2, 2, 2, 2, 2, 2, 2, 2, # f8 - ff
+)
+
+ISO2022CN_ST = (
+ MachineState.START, 3, MachineState.ERROR, MachineState.START, MachineState.START, MachineState.START, MachineState.START, MachineState.START, # 00-07
+ MachineState.START, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, # 08-0f
+ MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, # 10-17
+ MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, 4, MachineState.ERROR, # 18-1f
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, # 20-27
+ 5, 6, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, # 28-2f
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, # 30-37
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ERROR, MachineState.START, # 38-3f
+)
+# fmt: on
+
+ISO2022CN_CHAR_LEN_TABLE = (0, 0, 0, 0, 0, 0, 0, 0, 0)
+
+ISO2022CN_SM_MODEL = {
+ "class_table": ISO2022CN_CLS,
+ "class_factor": 9,
+ "state_table": ISO2022CN_ST,
+ "char_len_table": ISO2022CN_CHAR_LEN_TABLE,
+ "name": "ISO-2022-CN",
+ "language": "Chinese",
+}
+
+# fmt: off
+ISO2022JP_CLS = (
+ 2, 0, 0, 0, 0, 0, 0, 0, # 00 - 07
+ 0, 0, 0, 0, 0, 0, 2, 2, # 08 - 0f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 10 - 17
+ 0, 0, 0, 1, 0, 0, 0, 0, # 18 - 1f
+ 0, 0, 0, 0, 7, 0, 0, 0, # 20 - 27
+ 3, 0, 0, 0, 0, 0, 0, 0, # 28 - 2f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 30 - 37
+ 0, 0, 0, 0, 0, 0, 0, 0, # 38 - 3f
+ 6, 0, 4, 0, 8, 0, 0, 0, # 40 - 47
+ 0, 9, 5, 0, 0, 0, 0, 0, # 48 - 4f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 50 - 57
+ 0, 0, 0, 0, 0, 0, 0, 0, # 58 - 5f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 60 - 67
+ 0, 0, 0, 0, 0, 0, 0, 0, # 68 - 6f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 70 - 77
+ 0, 0, 0, 0, 0, 0, 0, 0, # 78 - 7f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 80 - 87
+ 2, 2, 2, 2, 2, 2, 2, 2, # 88 - 8f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 90 - 97
+ 2, 2, 2, 2, 2, 2, 2, 2, # 98 - 9f
+ 2, 2, 2, 2, 2, 2, 2, 2, # a0 - a7
+ 2, 2, 2, 2, 2, 2, 2, 2, # a8 - af
+ 2, 2, 2, 2, 2, 2, 2, 2, # b0 - b7
+ 2, 2, 2, 2, 2, 2, 2, 2, # b8 - bf
+ 2, 2, 2, 2, 2, 2, 2, 2, # c0 - c7
+ 2, 2, 2, 2, 2, 2, 2, 2, # c8 - cf
+ 2, 2, 2, 2, 2, 2, 2, 2, # d0 - d7
+ 2, 2, 2, 2, 2, 2, 2, 2, # d8 - df
+ 2, 2, 2, 2, 2, 2, 2, 2, # e0 - e7
+ 2, 2, 2, 2, 2, 2, 2, 2, # e8 - ef
+ 2, 2, 2, 2, 2, 2, 2, 2, # f0 - f7
+ 2, 2, 2, 2, 2, 2, 2, 2, # f8 - ff
+)
+
+ISO2022JP_ST = (
+ MachineState.START, 3, MachineState.ERROR, MachineState.START, MachineState.START, MachineState.START, MachineState.START, MachineState.START, # 00-07
+ MachineState.START, MachineState.START, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, # 08-0f
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, # 10-17
+ MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ERROR, MachineState.ERROR, # 18-1f
+ MachineState.ERROR, 5, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, 4, MachineState.ERROR, MachineState.ERROR, # 20-27
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, 6, MachineState.ITS_ME, MachineState.ERROR, MachineState.ITS_ME, MachineState.ERROR, # 28-2f
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ITS_ME, # 30-37
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, # 38-3f
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ERROR, MachineState.START, MachineState.START, # 40-47
+)
+# fmt: on
+
+ISO2022JP_CHAR_LEN_TABLE = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
+
+ISO2022JP_SM_MODEL = {
+ "class_table": ISO2022JP_CLS,
+ "class_factor": 10,
+ "state_table": ISO2022JP_ST,
+ "char_len_table": ISO2022JP_CHAR_LEN_TABLE,
+ "name": "ISO-2022-JP",
+ "language": "Japanese",
+}
+
+# fmt: off
+ISO2022KR_CLS = (
+ 2, 0, 0, 0, 0, 0, 0, 0, # 00 - 07
+ 0, 0, 0, 0, 0, 0, 0, 0, # 08 - 0f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 10 - 17
+ 0, 0, 0, 1, 0, 0, 0, 0, # 18 - 1f
+ 0, 0, 0, 0, 3, 0, 0, 0, # 20 - 27
+ 0, 4, 0, 0, 0, 0, 0, 0, # 28 - 2f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 30 - 37
+ 0, 0, 0, 0, 0, 0, 0, 0, # 38 - 3f
+ 0, 0, 0, 5, 0, 0, 0, 0, # 40 - 47
+ 0, 0, 0, 0, 0, 0, 0, 0, # 48 - 4f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 50 - 57
+ 0, 0, 0, 0, 0, 0, 0, 0, # 58 - 5f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 60 - 67
+ 0, 0, 0, 0, 0, 0, 0, 0, # 68 - 6f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 70 - 77
+ 0, 0, 0, 0, 0, 0, 0, 0, # 78 - 7f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 80 - 87
+ 2, 2, 2, 2, 2, 2, 2, 2, # 88 - 8f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 90 - 97
+ 2, 2, 2, 2, 2, 2, 2, 2, # 98 - 9f
+ 2, 2, 2, 2, 2, 2, 2, 2, # a0 - a7
+ 2, 2, 2, 2, 2, 2, 2, 2, # a8 - af
+ 2, 2, 2, 2, 2, 2, 2, 2, # b0 - b7
+ 2, 2, 2, 2, 2, 2, 2, 2, # b8 - bf
+ 2, 2, 2, 2, 2, 2, 2, 2, # c0 - c7
+ 2, 2, 2, 2, 2, 2, 2, 2, # c8 - cf
+ 2, 2, 2, 2, 2, 2, 2, 2, # d0 - d7
+ 2, 2, 2, 2, 2, 2, 2, 2, # d8 - df
+ 2, 2, 2, 2, 2, 2, 2, 2, # e0 - e7
+ 2, 2, 2, 2, 2, 2, 2, 2, # e8 - ef
+ 2, 2, 2, 2, 2, 2, 2, 2, # f0 - f7
+ 2, 2, 2, 2, 2, 2, 2, 2, # f8 - ff
+)
+
+ISO2022KR_ST = (
+ MachineState.START, 3, MachineState.ERROR, MachineState.START, MachineState.START, MachineState.START, MachineState.ERROR, MachineState.ERROR, # 00-07
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ITS_ME, # 08-0f
+ MachineState.ITS_ME, MachineState.ITS_ME, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, 4, MachineState.ERROR, MachineState.ERROR, # 10-17
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, 5, MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, # 18-1f
+ MachineState.ERROR, MachineState.ERROR, MachineState.ERROR, MachineState.ITS_ME, MachineState.START, MachineState.START, MachineState.START, MachineState.START, # 20-27
+)
+# fmt: on
+
+ISO2022KR_CHAR_LEN_TABLE = (0, 0, 0, 0, 0, 0)
+
+ISO2022KR_SM_MODEL = {
+ "class_table": ISO2022KR_CLS,
+ "class_factor": 6,
+ "state_table": ISO2022KR_ST,
+ "char_len_table": ISO2022KR_CHAR_LEN_TABLE,
+ "name": "ISO-2022-KR",
+ "language": "Korean",
+}
diff --git a/lib/chardet/eucjpprober.py b/lib/chardet/eucjpprober.py
new file mode 100644
index 0000000..abf2e66
--- /dev/null
+++ b/lib/chardet/eucjpprober.py
@@ -0,0 +1,95 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import EUCJPDistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .enums import MachineState, ProbingState
+from .jpcntx import EUCJPContextAnalysis
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import EUCJP_SM_MODEL
+
+
+class EUCJPProber(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(EUCJP_SM_MODEL)
+ self.distribution_analyzer = EUCJPDistributionAnalysis()
+ self.context_analyzer = EUCJPContextAnalysis()
+ self.reset()
+
+ def reset(self):
+ super().reset()
+ self.context_analyzer.reset()
+
+ @property
+ def charset_name(self):
+ return "EUC-JP"
+
+ @property
+ def language(self):
+ return "Japanese"
+
+ def feed(self, byte_str):
+ for i, byte in enumerate(byte_str):
+ # PY3K: byte_str is a byte array, so byte is an int, not a byte
+ coding_state = self.coding_sm.next_state(byte)
+ if coding_state == MachineState.ERROR:
+ self.logger.debug(
+ "%s %s prober hit error at byte %s",
+ self.charset_name,
+ self.language,
+ i,
+ )
+ self._state = ProbingState.NOT_ME
+ break
+ if coding_state == MachineState.ITS_ME:
+ self._state = ProbingState.FOUND_IT
+ break
+ if coding_state == MachineState.START:
+ char_len = self.coding_sm.get_current_charlen()
+ if i == 0:
+ self._last_char[1] = byte
+ self.context_analyzer.feed(self._last_char, char_len)
+ self.distribution_analyzer.feed(self._last_char, char_len)
+ else:
+ self.context_analyzer.feed(byte_str[i - 1 : i + 1], char_len)
+ self.distribution_analyzer.feed(byte_str[i - 1 : i + 1], char_len)
+
+ self._last_char[0] = byte_str[-1]
+
+ if self.state == ProbingState.DETECTING:
+ if self.context_analyzer.got_enough_data() and (
+ self.get_confidence() > self.SHORTCUT_THRESHOLD
+ ):
+ self._state = ProbingState.FOUND_IT
+
+ return self.state
+
+ def get_confidence(self):
+ context_conf = self.context_analyzer.get_confidence()
+ distrib_conf = self.distribution_analyzer.get_confidence()
+ return max(context_conf, distrib_conf)
diff --git a/lib/chardet/euckrfreq.py b/lib/chardet/euckrfreq.py
new file mode 100644
index 0000000..7dc3b10
--- /dev/null
+++ b/lib/chardet/euckrfreq.py
@@ -0,0 +1,196 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+# Sampled from about 20M of text material, including literature and computer technology
+
+# 128 --> 0.79
+# 256 --> 0.92
+# 512 --> 0.986
+# 1024 --> 0.99944
+# 2048 --> 0.99999
+#
+# Ideal Distribution Ratio = 0.98653 / (1 - 0.98653) = 73.24
+# Random Distribution Ratio = 512 / (2350 - 512) = 0.279
+#
+# Typical Distribution Ratio
+
+EUCKR_TYPICAL_DISTRIBUTION_RATIO = 6.0
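+
+# Rough sketch of how these constants are consumed (the actual logic lives in
+# chardistribution.CharDistributionAnalysis, so treat this as an approximation,
+# not a specification): each decoded character is mapped through the FreqOrder
+# table below, orders under 512 count as "frequent", and confidence is roughly
+#     freq_chars / ((total_chars - freq_chars) * EUCKR_TYPICAL_DISTRIBUTION_RATIO)
+# capped just below 1.0.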
+
+EUCKR_TABLE_SIZE = 2352
+
+# Char to FreqOrder table
+# fmt: off
+EUCKR_CHAR_TO_FREQ_ORDER = (
+ 13, 130, 120,1396, 481,1719,1720, 328, 609, 212,1721, 707, 400, 299,1722, 87,
+1397,1723, 104, 536,1117,1203,1724,1267, 685,1268, 508,1725,1726,1727,1728,1398,
+1399,1729,1730,1731, 141, 621, 326,1057, 368,1732, 267, 488, 20,1733,1269,1734,
+ 945,1400,1735, 47, 904,1270,1736,1737, 773, 248,1738, 409, 313, 786, 429,1739,
+ 116, 987, 813,1401, 683, 75,1204, 145,1740,1741,1742,1743, 16, 847, 667, 622,
+ 708,1744,1745,1746, 966, 787, 304, 129,1747, 60, 820, 123, 676,1748,1749,1750,
+1751, 617,1752, 626,1753,1754,1755,1756, 653,1757,1758,1759,1760,1761,1762, 856,
+ 344,1763,1764,1765,1766, 89, 401, 418, 806, 905, 848,1767,1768,1769, 946,1205,
+ 709,1770,1118,1771, 241,1772,1773,1774,1271,1775, 569,1776, 999,1777,1778,1779,
+1780, 337, 751,1058, 28, 628, 254,1781, 177, 906, 270, 349, 891,1079,1782, 19,
+1783, 379,1784, 315,1785, 629, 754,1402, 559,1786, 636, 203,1206,1787, 710, 567,
+1788, 935, 814,1789,1790,1207, 766, 528,1791,1792,1208,1793,1794,1795,1796,1797,
+1403,1798,1799, 533,1059,1404,1405,1156,1406, 936, 884,1080,1800, 351,1801,1802,
+1803,1804,1805, 801,1806,1807,1808,1119,1809,1157, 714, 474,1407,1810, 298, 899,
+ 885,1811,1120, 802,1158,1812, 892,1813,1814,1408, 659,1815,1816,1121,1817,1818,
+1819,1820,1821,1822, 319,1823, 594, 545,1824, 815, 937,1209,1825,1826, 573,1409,
+1022,1827,1210,1828,1829,1830,1831,1832,1833, 556, 722, 807,1122,1060,1834, 697,
+1835, 900, 557, 715,1836,1410, 540,1411, 752,1159, 294, 597,1211, 976, 803, 770,
+1412,1837,1838, 39, 794,1413, 358,1839, 371, 925,1840, 453, 661, 788, 531, 723,
+ 544,1023,1081, 869, 91,1841, 392, 430, 790, 602,1414, 677,1082, 457,1415,1416,
+1842,1843, 475, 327,1024,1417, 795, 121,1844, 733, 403,1418,1845,1846,1847, 300,
+ 119, 711,1212, 627,1848,1272, 207,1849,1850, 796,1213, 382,1851, 519,1852,1083,
+ 893,1853,1854,1855, 367, 809, 487, 671,1856, 663,1857,1858, 956, 471, 306, 857,
+1859,1860,1160,1084,1861,1862,1863,1864,1865,1061,1866,1867,1868,1869,1870,1871,
+ 282, 96, 574,1872, 502,1085,1873,1214,1874, 907,1875,1876, 827, 977,1419,1420,
+1421, 268,1877,1422,1878,1879,1880, 308,1881, 2, 537,1882,1883,1215,1884,1885,
+ 127, 791,1886,1273,1423,1887, 34, 336, 404, 643,1888, 571, 654, 894, 840,1889,
+ 0, 886,1274, 122, 575, 260, 908, 938,1890,1275, 410, 316,1891,1892, 100,1893,
+1894,1123, 48,1161,1124,1025,1895, 633, 901,1276,1896,1897, 115, 816,1898, 317,
+1899, 694,1900, 909, 734,1424, 572, 866,1425, 691, 85, 524,1010, 543, 394, 841,
+1901,1902,1903,1026,1904,1905,1906,1907,1908,1909, 30, 451, 651, 988, 310,1910,
+1911,1426, 810,1216, 93,1912,1913,1277,1217,1914, 858, 759, 45, 58, 181, 610,
+ 269,1915,1916, 131,1062, 551, 443,1000, 821,1427, 957, 895,1086,1917,1918, 375,
+1919, 359,1920, 687,1921, 822,1922, 293,1923,1924, 40, 662, 118, 692, 29, 939,
+ 887, 640, 482, 174,1925, 69,1162, 728,1428, 910,1926,1278,1218,1279, 386, 870,
+ 217, 854,1163, 823,1927,1928,1929,1930, 834,1931, 78,1932, 859,1933,1063,1934,
+1935,1936,1937, 438,1164, 208, 595,1938,1939,1940,1941,1219,1125,1942, 280, 888,
+1429,1430,1220,1431,1943,1944,1945,1946,1947,1280, 150, 510,1432,1948,1949,1950,
+1951,1952,1953,1954,1011,1087,1955,1433,1043,1956, 881,1957, 614, 958,1064,1065,
+1221,1958, 638,1001, 860, 967, 896,1434, 989, 492, 553,1281,1165,1959,1282,1002,
+1283,1222,1960,1961,1962,1963, 36, 383, 228, 753, 247, 454,1964, 876, 678,1965,
+1966,1284, 126, 464, 490, 835, 136, 672, 529, 940,1088,1435, 473,1967,1968, 467,
+ 50, 390, 227, 587, 279, 378, 598, 792, 968, 240, 151, 160, 849, 882,1126,1285,
+ 639,1044, 133, 140, 288, 360, 811, 563,1027, 561, 142, 523,1969,1970,1971, 7,
+ 103, 296, 439, 407, 506, 634, 990,1972,1973,1974,1975, 645,1976,1977,1978,1979,
+1980,1981, 236,1982,1436,1983,1984,1089, 192, 828, 618, 518,1166, 333,1127,1985,
+ 818,1223,1986,1987,1988,1989,1990,1991,1992,1993, 342,1128,1286, 746, 842,1994,
+1995, 560, 223,1287, 98, 8, 189, 650, 978,1288,1996,1437,1997, 17, 345, 250,
+ 423, 277, 234, 512, 226, 97, 289, 42, 167,1998, 201,1999,2000, 843, 836, 824,
+ 532, 338, 783,1090, 182, 576, 436,1438,1439, 527, 500,2001, 947, 889,2002,2003,
+2004,2005, 262, 600, 314, 447,2006, 547,2007, 693, 738,1129,2008, 71,1440, 745,
+ 619, 688,2009, 829,2010,2011, 147,2012, 33, 948,2013,2014, 74, 224,2015, 61,
+ 191, 918, 399, 637,2016,1028,1130, 257, 902,2017,2018,2019,2020,2021,2022,2023,
+2024,2025,2026, 837,2027,2028,2029,2030, 179, 874, 591, 52, 724, 246,2031,2032,
+2033,2034,1167, 969,2035,1289, 630, 605, 911,1091,1168,2036,2037,2038,1441, 912,
+2039, 623,2040,2041, 253,1169,1290,2042,1442, 146, 620, 611, 577, 433,2043,1224,
+ 719,1170, 959, 440, 437, 534, 84, 388, 480,1131, 159, 220, 198, 679,2044,1012,
+ 819,1066,1443, 113,1225, 194, 318,1003,1029,2045,2046,2047,2048,1067,2049,2050,
+2051,2052,2053, 59, 913, 112,2054, 632,2055, 455, 144, 739,1291,2056, 273, 681,
+ 499,2057, 448,2058,2059, 760,2060,2061, 970, 384, 169, 245,1132,2062,2063, 414,
+1444,2064,2065, 41, 235,2066, 157, 252, 877, 568, 919, 789, 580,2067, 725,2068,
+2069,1292,2070,2071,1445,2072,1446,2073,2074, 55, 588, 66,1447, 271,1092,2075,
+1226,2076, 960,1013, 372,2077,2078,2079,2080,2081,1293,2082,2083,2084,2085, 850,
+2086,2087,2088,2089,2090, 186,2091,1068, 180,2092,2093,2094, 109,1227, 522, 606,
+2095, 867,1448,1093, 991,1171, 926, 353,1133,2096, 581,2097,2098,2099,1294,1449,
+1450,2100, 596,1172,1014,1228,2101,1451,1295,1173,1229,2102,2103,1296,1134,1452,
+ 949,1135,2104,2105,1094,1453,1454,1455,2106,1095,2107,2108,2109,2110,2111,2112,
+2113,2114,2115,2116,2117, 804,2118,2119,1230,1231, 805,1456, 405,1136,2120,2121,
+2122,2123,2124, 720, 701,1297, 992,1457, 927,1004,2125,2126,2127,2128,2129,2130,
+ 22, 417,2131, 303,2132, 385,2133, 971, 520, 513,2134,1174, 73,1096, 231, 274,
+ 962,1458, 673,2135,1459,2136, 152,1137,2137,2138,2139,2140,1005,1138,1460,1139,
+2141,2142,2143,2144, 11, 374, 844,2145, 154,1232, 46,1461,2146, 838, 830, 721,
+1233, 106,2147, 90, 428, 462, 578, 566,1175, 352,2148,2149, 538,1234, 124,1298,
+2150,1462, 761, 565,2151, 686,2152, 649,2153, 72, 173,2154, 460, 415,2155,1463,
+2156,1235, 305,2157,2158,2159,2160,2161,2162, 579,2163,2164,2165,2166,2167, 747,
+2168,2169,2170,2171,1464, 669,2172,2173,2174,2175,2176,1465,2177, 23, 530, 285,
+2178, 335, 729,2179, 397,2180,2181,2182,1030,2183,2184, 698,2185,2186, 325,2187,
+2188, 369,2189, 799,1097,1015, 348,2190,1069, 680,2191, 851,1466,2192,2193, 10,
+2194, 613, 424,2195, 979, 108, 449, 589, 27, 172, 81,1031, 80, 774, 281, 350,
+1032, 525, 301, 582,1176,2196, 674,1045,2197,2198,1467, 730, 762,2199,2200,2201,
+2202,1468,2203, 993,2204,2205, 266,1070, 963,1140,2206,2207,2208, 664,1098, 972,
+2209,2210,2211,1177,1469,1470, 871,2212,2213,2214,2215,2216,1471,2217,2218,2219,
+2220,2221,2222,2223,2224,2225,2226,2227,1472,1236,2228,2229,2230,2231,2232,2233,
+2234,2235,1299,2236,2237, 200,2238, 477, 373,2239,2240, 731, 825, 777,2241,2242,
+2243, 521, 486, 548,2244,2245,2246,1473,1300, 53, 549, 137, 875, 76, 158,2247,
+1301,1474, 469, 396,1016, 278, 712,2248, 321, 442, 503, 767, 744, 941,1237,1178,
+1475,2249, 82, 178,1141,1179, 973,2250,1302,2251, 297,2252,2253, 570,2254,2255,
+2256, 18, 450, 206,2257, 290, 292,1142,2258, 511, 162, 99, 346, 164, 735,2259,
+1476,1477, 4, 554, 343, 798,1099,2260,1100,2261, 43, 171,1303, 139, 215,2262,
+2263, 717, 775,2264,1033, 322, 216,2265, 831,2266, 149,2267,1304,2268,2269, 702,
+1238, 135, 845, 347, 309,2270, 484,2271, 878, 655, 238,1006,1478,2272, 67,2273,
+ 295,2274,2275, 461,2276, 478, 942, 412,2277,1034,2278,2279,2280, 265,2281, 541,
+2282,2283,2284,2285,2286, 70, 852,1071,2287,2288,2289,2290, 21, 56, 509, 117,
+ 432,2291,2292, 331, 980, 552,1101, 148, 284, 105, 393,1180,1239, 755,2293, 187,
+2294,1046,1479,2295, 340,2296, 63,1047, 230,2297,2298,1305, 763,1306, 101, 800,
+ 808, 494,2299,2300,2301, 903,2302, 37,1072, 14, 5,2303, 79, 675,2304, 312,
+2305,2306,2307,2308,2309,1480, 6,1307,2310,2311,2312, 1, 470, 35, 24, 229,
+2313, 695, 210, 86, 778, 15, 784, 592, 779, 32, 77, 855, 964,2314, 259,2315,
+ 501, 380,2316,2317, 83, 981, 153, 689,1308,1481,1482,1483,2318,2319, 716,1484,
+2320,2321,2322,2323,2324,2325,1485,2326,2327, 128, 57, 68, 261,1048, 211, 170,
+1240, 31,2328, 51, 435, 742,2329,2330,2331, 635,2332, 264, 456,2333,2334,2335,
+ 425,2336,1486, 143, 507, 263, 943,2337, 363, 920,1487, 256,1488,1102, 243, 601,
+1489,2338,2339,2340,2341,2342,2343,2344, 861,2345,2346,2347,2348,2349,2350, 395,
+2351,1490,1491, 62, 535, 166, 225,2352,2353, 668, 419,1241, 138, 604, 928,2354,
+1181,2355,1492,1493,2356,2357,2358,1143,2359, 696,2360, 387, 307,1309, 682, 476,
+2361,2362, 332, 12, 222, 156,2363, 232,2364, 641, 276, 656, 517,1494,1495,1035,
+ 416, 736,1496,2365,1017, 586,2366,2367,2368,1497,2369, 242,2370,2371,2372,1498,
+2373, 965, 713,2374,2375,2376,2377, 740, 982,1499, 944,1500,1007,2378,2379,1310,
+1501,2380,2381,2382, 785, 329,2383,2384,1502,2385,2386,2387, 932,2388,1503,2389,
+2390,2391,2392,1242,2393,2394,2395,2396,2397, 994, 950,2398,2399,2400,2401,1504,
+1311,2402,2403,2404,2405,1049, 749,2406,2407, 853, 718,1144,1312,2408,1182,1505,
+2409,2410, 255, 516, 479, 564, 550, 214,1506,1507,1313, 413, 239, 444, 339,1145,
+1036,1508,1509,1314,1037,1510,1315,2411,1511,2412,2413,2414, 176, 703, 497, 624,
+ 593, 921, 302,2415, 341, 165,1103,1512,2416,1513,2417,2418,2419, 376,2420, 700,
+2421,2422,2423, 258, 768,1316,2424,1183,2425, 995, 608,2426,2427,2428,2429, 221,
+2430,2431,2432,2433,2434,2435,2436,2437, 195, 323, 726, 188, 897, 983,1317, 377,
+ 644,1050, 879,2438, 452,2439,2440,2441,2442,2443,2444, 914,2445,2446,2447,2448,
+ 915, 489,2449,1514,1184,2450,2451, 515, 64, 427, 495,2452, 583,2453, 483, 485,
+1038, 562, 213,1515, 748, 666,2454,2455,2456,2457, 334,2458, 780, 996,1008, 705,
+1243,2459,2460,2461,2462,2463, 114,2464, 493,1146, 366, 163,1516, 961,1104,2465,
+ 291,2466,1318,1105,2467,1517, 365,2468, 355, 951,1244,2469,1319,2470, 631,2471,
+2472, 218,1320, 364, 320, 756,1518,1519,1321,1520,1322,2473,2474,2475,2476, 997,
+2477,2478,2479,2480, 665,1185,2481, 916,1521,2482,2483,2484, 584, 684,2485,2486,
+ 797,2487,1051,1186,2488,2489,2490,1522,2491,2492, 370,2493,1039,1187, 65,2494,
+ 434, 205, 463,1188,2495, 125, 812, 391, 402, 826, 699, 286, 398, 155, 781, 771,
+ 585,2496, 590, 505,1073,2497, 599, 244, 219, 917,1018, 952, 646,1523,2498,1323,
+2499,2500, 49, 984, 354, 741,2501, 625,2502,1324,2503,1019, 190, 357, 757, 491,
+ 95, 782, 868,2504,2505,2506,2507,2508,2509, 134,1524,1074, 422,1525, 898,2510,
+ 161,2511,2512,2513,2514, 769,2515,1526,2516,2517, 411,1325,2518, 472,1527,2519,
+2520,2521,2522,2523,2524, 985,2525,2526,2527,2528,2529,2530, 764,2531,1245,2532,
+2533, 25, 204, 311,2534, 496,2535,1052,2536,2537,2538,2539,2540,2541,2542, 199,
+ 704, 504, 468, 758, 657,1528, 196, 44, 839,1246, 272, 750,2543, 765, 862,2544,
+2545,1326,2546, 132, 615, 933,2547, 732,2548,2549,2550,1189,1529,2551, 283,1247,
+1053, 607, 929,2552,2553,2554, 930, 183, 872, 616,1040,1147,2555,1148,1020, 441,
+ 249,1075,2556,2557,2558, 466, 743,2559,2560,2561, 92, 514, 426, 420, 526,2562,
+2563,2564,2565,2566,2567,2568, 185,2569,2570,2571,2572, 776,1530, 658,2573, 362,
+2574, 361, 922,1076, 793,2575,2576,2577,2578,2579,2580,1531, 251,2581,2582,2583,
+2584,1532, 54, 612, 237,1327,2585,2586, 275, 408, 647, 111,2587,1533,1106, 465,
+ 3, 458, 9, 38,2588, 107, 110, 890, 209, 26, 737, 498,2589,1534,2590, 431,
+ 202, 88,1535, 356, 287,1107, 660,1149,2591, 381,1536, 986,1150, 445,1248,1151,
+ 974,2592,2593, 846,2594, 446, 953, 184,1249,1250, 727,2595, 923, 193, 883,2596,
+2597,2598, 102, 324, 539, 817,2599, 421,1041,2600, 832,2601, 94, 175, 197, 406,
+2602, 459,2603,2604,2605,2606,2607, 330, 555,2608,2609,2610, 706,1108, 389,2611,
+2612,2613,2614, 233,2615, 833, 558, 931, 954,1251,2616,2617,1537, 546,2618,2619,
+1009,2620,2621,2622,1538, 690,1328,2623, 955,2624,1539,2625,2626, 772,2627,2628,
+2629,2630,2631, 924, 648, 863, 603,2632,2633, 934,1540, 864, 865,2634, 642,1042,
+ 670,1190,2635,2636,2637,2638, 168,2639, 652, 873, 542,1054,1541,2640,2641,2642, # 512, 256
+)
+# fmt: on
diff --git a/lib/chardet/euckrprober.py b/lib/chardet/euckrprober.py
new file mode 100644
index 0000000..154a6d2
--- /dev/null
+++ b/lib/chardet/euckrprober.py
@@ -0,0 +1,47 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import EUCKRDistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import EUCKR_SM_MODEL
+
+
+class EUCKRProber(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(EUCKR_SM_MODEL)
+ self.distribution_analyzer = EUCKRDistributionAnalysis()
+ self.reset()
+
+ @property
+ def charset_name(self):
+ return "EUC-KR"
+
+ @property
+ def language(self):
+ return "Korean"
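+
+# Unlike EUCJPProber, this prober defines no feed() of its own: byte feeding,
+# state tracking and confidence all come from MultiByteCharSetProber, driven by
+# the EUC-KR state machine and distribution analysis configured in __init__.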
diff --git a/lib/chardet/euctwfreq.py b/lib/chardet/euctwfreq.py
new file mode 100644
index 0000000..4900ccc
--- /dev/null
+++ b/lib/chardet/euctwfreq.py
@@ -0,0 +1,388 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+# EUCTW frequency table
+# Converted from big5 work
+# by Taiwan's Mandarin Promotion Council
+# <http://www.edu.tw:81/mandr/>
+
+# 128 --> 0.42261
+# 256 --> 0.57851
+# 512 --> 0.74851
+# 1024 --> 0.89384
+# 2048 --> 0.97583
+#
+# Ideal Distribution Ratio = 0.74851 / (1 - 0.74851) = 2.98
+# Random Distribution Ratio = 512 / (5401 - 512) = 0.105
+#
+# Typical Distribution Ratio is about 25% of the ideal one, still much higher than RDR
+
+EUCTW_TYPICAL_DISTRIBUTION_RATIO = 0.75
+
+# Char to FreqOrder table
+EUCTW_TABLE_SIZE = 5376
+
+# fmt: off
+EUCTW_CHAR_TO_FREQ_ORDER = (
+ 1, 1800, 1506, 255, 1431, 198, 9, 82, 6, 7310, 177, 202, 3615, 1256, 2808, 110, # 2742
+ 3735, 33, 3241, 261, 76, 44, 2113, 16, 2931, 2184, 1176, 659, 3868, 26, 3404, 2643, # 2758
+ 1198, 3869, 3313, 4060, 410, 2211, 302, 590, 361, 1963, 8, 204, 58, 4296, 7311, 1931, # 2774
+ 63, 7312, 7313, 317, 1614, 75, 222, 159, 4061, 2412, 1480, 7314, 3500, 3068, 224, 2809, # 2790
+ 3616, 3, 10, 3870, 1471, 29, 2774, 1135, 2852, 1939, 873, 130, 3242, 1123, 312, 7315, # 2806
+ 4297, 2051, 507, 252, 682, 7316, 142, 1914, 124, 206, 2932, 34, 3501, 3173, 64, 604, # 2822
+ 7317, 2494, 1976, 1977, 155, 1990, 645, 641, 1606, 7318, 3405, 337, 72, 406, 7319, 80, # 2838
+ 630, 238, 3174, 1509, 263, 939, 1092, 2644, 756, 1440, 1094, 3406, 449, 69, 2969, 591, # 2854
+ 179, 2095, 471, 115, 2034, 1843, 60, 50, 2970, 134, 806, 1868, 734, 2035, 3407, 180, # 2870
+ 995, 1607, 156, 537, 2893, 688, 7320, 319, 1305, 779, 2144, 514, 2374, 298, 4298, 359, # 2886
+ 2495, 90, 2707, 1338, 663, 11, 906, 1099, 2545, 20, 2436, 182, 532, 1716, 7321, 732, # 2902
+ 1376, 4062, 1311, 1420, 3175, 25, 2312, 1056, 113, 399, 382, 1949, 242, 3408, 2467, 529, # 2918
+ 3243, 475, 1447, 3617, 7322, 117, 21, 656, 810, 1297, 2295, 2329, 3502, 7323, 126, 4063, # 2934
+ 706, 456, 150, 613, 4299, 71, 1118, 2036, 4064, 145, 3069, 85, 835, 486, 2114, 1246, # 2950
+ 1426, 428, 727, 1285, 1015, 800, 106, 623, 303, 1281, 7324, 2127, 2354, 347, 3736, 221, # 2966
+ 3503, 3110, 7325, 1955, 1153, 4065, 83, 296, 1199, 3070, 192, 624, 93, 7326, 822, 1897, # 2982
+ 2810, 3111, 795, 2064, 991, 1554, 1542, 1592, 27, 43, 2853, 859, 139, 1456, 860, 4300, # 2998
+ 437, 712, 3871, 164, 2392, 3112, 695, 211, 3017, 2096, 195, 3872, 1608, 3504, 3505, 3618, # 3014
+ 3873, 234, 811, 2971, 2097, 3874, 2229, 1441, 3506, 1615, 2375, 668, 2076, 1638, 305, 228, # 3030
+ 1664, 4301, 467, 415, 7327, 262, 2098, 1593, 239, 108, 300, 200, 1033, 512, 1247, 2077, # 3046
+ 7328, 7329, 2173, 3176, 3619, 2673, 593, 845, 1062, 3244, 88, 1723, 2037, 3875, 1950, 212, # 3062
+ 266, 152, 149, 468, 1898, 4066, 4302, 77, 187, 7330, 3018, 37, 5, 2972, 7331, 3876, # 3078
+ 7332, 7333, 39, 2517, 4303, 2894, 3177, 2078, 55, 148, 74, 4304, 545, 483, 1474, 1029, # 3094
+ 1665, 217, 1869, 1531, 3113, 1104, 2645, 4067, 24, 172, 3507, 900, 3877, 3508, 3509, 4305, # 3110
+ 32, 1408, 2811, 1312, 329, 487, 2355, 2247, 2708, 784, 2674, 4, 3019, 3314, 1427, 1788, # 3126
+ 188, 109, 499, 7334, 3620, 1717, 1789, 888, 1217, 3020, 4306, 7335, 3510, 7336, 3315, 1520, # 3142
+ 3621, 3878, 196, 1034, 775, 7337, 7338, 929, 1815, 249, 439, 38, 7339, 1063, 7340, 794, # 3158
+ 3879, 1435, 2296, 46, 178, 3245, 2065, 7341, 2376, 7342, 214, 1709, 4307, 804, 35, 707, # 3174
+ 324, 3622, 1601, 2546, 140, 459, 4068, 7343, 7344, 1365, 839, 272, 978, 2257, 2572, 3409, # 3190
+ 2128, 1363, 3623, 1423, 697, 100, 3071, 48, 70, 1231, 495, 3114, 2193, 7345, 1294, 7346, # 3206
+ 2079, 462, 586, 1042, 3246, 853, 256, 988, 185, 2377, 3410, 1698, 434, 1084, 7347, 3411, # 3222
+ 314, 2615, 2775, 4308, 2330, 2331, 569, 2280, 637, 1816, 2518, 757, 1162, 1878, 1616, 3412, # 3238
+ 287, 1577, 2115, 768, 4309, 1671, 2854, 3511, 2519, 1321, 3737, 909, 2413, 7348, 4069, 933, # 3254
+ 3738, 7349, 2052, 2356, 1222, 4310, 765, 2414, 1322, 786, 4311, 7350, 1919, 1462, 1677, 2895, # 3270
+ 1699, 7351, 4312, 1424, 2437, 3115, 3624, 2590, 3316, 1774, 1940, 3413, 3880, 4070, 309, 1369, # 3286
+ 1130, 2812, 364, 2230, 1653, 1299, 3881, 3512, 3882, 3883, 2646, 525, 1085, 3021, 902, 2000, # 3302
+ 1475, 964, 4313, 421, 1844, 1415, 1057, 2281, 940, 1364, 3116, 376, 4314, 4315, 1381, 7, # 3318
+ 2520, 983, 2378, 336, 1710, 2675, 1845, 321, 3414, 559, 1131, 3022, 2742, 1808, 1132, 1313, # 3334
+ 265, 1481, 1857, 7352, 352, 1203, 2813, 3247, 167, 1089, 420, 2814, 776, 792, 1724, 3513, # 3350
+ 4071, 2438, 3248, 7353, 4072, 7354, 446, 229, 333, 2743, 901, 3739, 1200, 1557, 4316, 2647, # 3366
+ 1920, 395, 2744, 2676, 3740, 4073, 1835, 125, 916, 3178, 2616, 4317, 7355, 7356, 3741, 7357, # 3382
+ 7358, 7359, 4318, 3117, 3625, 1133, 2547, 1757, 3415, 1510, 2313, 1409, 3514, 7360, 2145, 438, # 3398
+ 2591, 2896, 2379, 3317, 1068, 958, 3023, 461, 311, 2855, 2677, 4074, 1915, 3179, 4075, 1978, # 3414
+ 383, 750, 2745, 2617, 4076, 274, 539, 385, 1278, 1442, 7361, 1154, 1964, 384, 561, 210, # 3430
+ 98, 1295, 2548, 3515, 7362, 1711, 2415, 1482, 3416, 3884, 2897, 1257, 129, 7363, 3742, 642, # 3446
+ 523, 2776, 2777, 2648, 7364, 141, 2231, 1333, 68, 176, 441, 876, 907, 4077, 603, 2592, # 3462
+ 710, 171, 3417, 404, 549, 18, 3118, 2393, 1410, 3626, 1666, 7365, 3516, 4319, 2898, 4320, # 3478
+ 7366, 2973, 368, 7367, 146, 366, 99, 871, 3627, 1543, 748, 807, 1586, 1185, 22, 2258, # 3494
+ 379, 3743, 3180, 7368, 3181, 505, 1941, 2618, 1991, 1382, 2314, 7369, 380, 2357, 218, 702, # 3510
+ 1817, 1248, 3418, 3024, 3517, 3318, 3249, 7370, 2974, 3628, 930, 3250, 3744, 7371, 59, 7372, # 3526
+ 585, 601, 4078, 497, 3419, 1112, 1314, 4321, 1801, 7373, 1223, 1472, 2174, 7374, 749, 1836, # 3542
+ 690, 1899, 3745, 1772, 3885, 1476, 429, 1043, 1790, 2232, 2116, 917, 4079, 447, 1086, 1629, # 3558
+ 7375, 556, 7376, 7377, 2020, 1654, 844, 1090, 105, 550, 966, 1758, 2815, 1008, 1782, 686, # 3574
+ 1095, 7378, 2282, 793, 1602, 7379, 3518, 2593, 4322, 4080, 2933, 2297, 4323, 3746, 980, 2496, # 3590
+ 544, 353, 527, 4324, 908, 2678, 2899, 7380, 381, 2619, 1942, 1348, 7381, 1341, 1252, 560, # 3606
+ 3072, 7382, 3420, 2856, 7383, 2053, 973, 886, 2080, 143, 4325, 7384, 7385, 157, 3886, 496, # 3622
+ 4081, 57, 840, 540, 2038, 4326, 4327, 3421, 2117, 1445, 970, 2259, 1748, 1965, 2081, 4082, # 3638
+ 3119, 1234, 1775, 3251, 2816, 3629, 773, 1206, 2129, 1066, 2039, 1326, 3887, 1738, 1725, 4083, # 3654
+ 279, 3120, 51, 1544, 2594, 423, 1578, 2130, 2066, 173, 4328, 1879, 7386, 7387, 1583, 264, # 3670
+ 610, 3630, 4329, 2439, 280, 154, 7388, 7389, 7390, 1739, 338, 1282, 3073, 693, 2857, 1411, # 3686
+ 1074, 3747, 2440, 7391, 4330, 7392, 7393, 1240, 952, 2394, 7394, 2900, 1538, 2679, 685, 1483, # 3702
+ 4084, 2468, 1436, 953, 4085, 2054, 4331, 671, 2395, 79, 4086, 2441, 3252, 608, 567, 2680, # 3718
+ 3422, 4087, 4088, 1691, 393, 1261, 1791, 2396, 7395, 4332, 7396, 7397, 7398, 7399, 1383, 1672, # 3734
+ 3748, 3182, 1464, 522, 1119, 661, 1150, 216, 675, 4333, 3888, 1432, 3519, 609, 4334, 2681, # 3750
+ 2397, 7400, 7401, 7402, 4089, 3025, 0, 7403, 2469, 315, 231, 2442, 301, 3319, 4335, 2380, # 3766
+ 7404, 233, 4090, 3631, 1818, 4336, 4337, 7405, 96, 1776, 1315, 2082, 7406, 257, 7407, 1809, # 3782
+ 3632, 2709, 1139, 1819, 4091, 2021, 1124, 2163, 2778, 1777, 2649, 7408, 3074, 363, 1655, 3183, # 3798
+ 7409, 2975, 7410, 7411, 7412, 3889, 1567, 3890, 718, 103, 3184, 849, 1443, 341, 3320, 2934, # 3814
+ 1484, 7413, 1712, 127, 67, 339, 4092, 2398, 679, 1412, 821, 7414, 7415, 834, 738, 351, # 3830
+ 2976, 2146, 846, 235, 1497, 1880, 418, 1992, 3749, 2710, 186, 1100, 2147, 2746, 3520, 1545, # 3846
+ 1355, 2935, 2858, 1377, 583, 3891, 4093, 2573, 2977, 7416, 1298, 3633, 1078, 2549, 3634, 2358, # 3862
+ 78, 3750, 3751, 267, 1289, 2099, 2001, 1594, 4094, 348, 369, 1274, 2194, 2175, 1837, 4338, # 3878
+ 1820, 2817, 3635, 2747, 2283, 2002, 4339, 2936, 2748, 144, 3321, 882, 4340, 3892, 2749, 3423, # 3894
+ 4341, 2901, 7417, 4095, 1726, 320, 7418, 3893, 3026, 788, 2978, 7419, 2818, 1773, 1327, 2859, # 3910
+ 3894, 2819, 7420, 1306, 4342, 2003, 1700, 3752, 3521, 2359, 2650, 787, 2022, 506, 824, 3636, # 3926
+ 534, 323, 4343, 1044, 3322, 2023, 1900, 946, 3424, 7421, 1778, 1500, 1678, 7422, 1881, 4344, # 3942
+ 165, 243, 4345, 3637, 2521, 123, 683, 4096, 764, 4346, 36, 3895, 1792, 589, 2902, 816, # 3958
+ 626, 1667, 3027, 2233, 1639, 1555, 1622, 3753, 3896, 7423, 3897, 2860, 1370, 1228, 1932, 891, # 3974
+ 2083, 2903, 304, 4097, 7424, 292, 2979, 2711, 3522, 691, 2100, 4098, 1115, 4347, 118, 662, # 3990
+ 7425, 611, 1156, 854, 2381, 1316, 2861, 2, 386, 515, 2904, 7426, 7427, 3253, 868, 2234, # 4006
+ 1486, 855, 2651, 785, 2212, 3028, 7428, 1040, 3185, 3523, 7429, 3121, 448, 7430, 1525, 7431, # 4022
+ 2164, 4348, 7432, 3754, 7433, 4099, 2820, 3524, 3122, 503, 818, 3898, 3123, 1568, 814, 676, # 4038
+ 1444, 306, 1749, 7434, 3755, 1416, 1030, 197, 1428, 805, 2821, 1501, 4349, 7435, 7436, 7437, # 4054
+ 1993, 7438, 4350, 7439, 7440, 2195, 13, 2779, 3638, 2980, 3124, 1229, 1916, 7441, 3756, 2131, # 4070
+ 7442, 4100, 4351, 2399, 3525, 7443, 2213, 1511, 1727, 1120, 7444, 7445, 646, 3757, 2443, 307, # 4086
+ 7446, 7447, 1595, 3186, 7448, 7449, 7450, 3639, 1113, 1356, 3899, 1465, 2522, 2523, 7451, 519, # 4102
+ 7452, 128, 2132, 92, 2284, 1979, 7453, 3900, 1512, 342, 3125, 2196, 7454, 2780, 2214, 1980, # 4118
+ 3323, 7455, 290, 1656, 1317, 789, 827, 2360, 7456, 3758, 4352, 562, 581, 3901, 7457, 401, # 4134
+ 4353, 2248, 94, 4354, 1399, 2781, 7458, 1463, 2024, 4355, 3187, 1943, 7459, 828, 1105, 4101, # 4150
+ 1262, 1394, 7460, 4102, 605, 4356, 7461, 1783, 2862, 7462, 2822, 819, 2101, 578, 2197, 2937, # 4166
+ 7463, 1502, 436, 3254, 4103, 3255, 2823, 3902, 2905, 3425, 3426, 7464, 2712, 2315, 7465, 7466, # 4182
+ 2332, 2067, 23, 4357, 193, 826, 3759, 2102, 699, 1630, 4104, 3075, 390, 1793, 1064, 3526, # 4198
+ 7467, 1579, 3076, 3077, 1400, 7468, 4105, 1838, 1640, 2863, 7469, 4358, 4359, 137, 4106, 598, # 4214
+ 3078, 1966, 780, 104, 974, 2938, 7470, 278, 899, 253, 402, 572, 504, 493, 1339, 7471, # 4230
+ 3903, 1275, 4360, 2574, 2550, 7472, 3640, 3029, 3079, 2249, 565, 1334, 2713, 863, 41, 7473, # 4246
+ 7474, 4361, 7475, 1657, 2333, 19, 463, 2750, 4107, 606, 7476, 2981, 3256, 1087, 2084, 1323, # 4262
+ 2652, 2982, 7477, 1631, 1623, 1750, 4108, 2682, 7478, 2864, 791, 2714, 2653, 2334, 232, 2416, # 4278
+ 7479, 2983, 1498, 7480, 2654, 2620, 755, 1366, 3641, 3257, 3126, 2025, 1609, 119, 1917, 3427, # 4294
+ 862, 1026, 4109, 7481, 3904, 3760, 4362, 3905, 4363, 2260, 1951, 2470, 7482, 1125, 817, 4110, # 4310
+ 4111, 3906, 1513, 1766, 2040, 1487, 4112, 3030, 3258, 2824, 3761, 3127, 7483, 7484, 1507, 7485, # 4326
+ 2683, 733, 40, 1632, 1106, 2865, 345, 4113, 841, 2524, 230, 4364, 2984, 1846, 3259, 3428, # 4342
+ 7486, 1263, 986, 3429, 7487, 735, 879, 254, 1137, 857, 622, 1300, 1180, 1388, 1562, 3907, # 4358
+ 3908, 2939, 967, 2751, 2655, 1349, 592, 2133, 1692, 3324, 2985, 1994, 4114, 1679, 3909, 1901, # 4374
+ 2185, 7488, 739, 3642, 2715, 1296, 1290, 7489, 4115, 2198, 2199, 1921, 1563, 2595, 2551, 1870, # 4390
+ 2752, 2986, 7490, 435, 7491, 343, 1108, 596, 17, 1751, 4365, 2235, 3430, 3643, 7492, 4366, # 4406
+ 294, 3527, 2940, 1693, 477, 979, 281, 2041, 3528, 643, 2042, 3644, 2621, 2782, 2261, 1031, # 4422
+ 2335, 2134, 2298, 3529, 4367, 367, 1249, 2552, 7493, 3530, 7494, 4368, 1283, 3325, 2004, 240, # 4438
+ 1762, 3326, 4369, 4370, 836, 1069, 3128, 474, 7495, 2148, 2525, 268, 3531, 7496, 3188, 1521, # 4454
+ 1284, 7497, 1658, 1546, 4116, 7498, 3532, 3533, 7499, 4117, 3327, 2684, 1685, 4118, 961, 1673, # 4470
+ 2622, 190, 2005, 2200, 3762, 4371, 4372, 7500, 570, 2497, 3645, 1490, 7501, 4373, 2623, 3260, # 4486
+ 1956, 4374, 584, 1514, 396, 1045, 1944, 7502, 4375, 1967, 2444, 7503, 7504, 4376, 3910, 619, # 4502
+ 7505, 3129, 3261, 215, 2006, 2783, 2553, 3189, 4377, 3190, 4378, 763, 4119, 3763, 4379, 7506, # 4518
+ 7507, 1957, 1767, 2941, 3328, 3646, 1174, 452, 1477, 4380, 3329, 3130, 7508, 2825, 1253, 2382, # 4534
+ 2186, 1091, 2285, 4120, 492, 7509, 638, 1169, 1824, 2135, 1752, 3911, 648, 926, 1021, 1324, # 4550
+ 4381, 520, 4382, 997, 847, 1007, 892, 4383, 3764, 2262, 1871, 3647, 7510, 2400, 1784, 4384, # 4566
+ 1952, 2942, 3080, 3191, 1728, 4121, 2043, 3648, 4385, 2007, 1701, 3131, 1551, 30, 2263, 4122, # 4582
+ 7511, 2026, 4386, 3534, 7512, 501, 7513, 4123, 594, 3431, 2165, 1821, 3535, 3432, 3536, 3192, # 4598
+ 829, 2826, 4124, 7514, 1680, 3132, 1225, 4125, 7515, 3262, 4387, 4126, 3133, 2336, 7516, 4388, # 4614
+ 4127, 7517, 3912, 3913, 7518, 1847, 2383, 2596, 3330, 7519, 4389, 374, 3914, 652, 4128, 4129, # 4630
+ 375, 1140, 798, 7520, 7521, 7522, 2361, 4390, 2264, 546, 1659, 138, 3031, 2445, 4391, 7523, # 4646
+ 2250, 612, 1848, 910, 796, 3765, 1740, 1371, 825, 3766, 3767, 7524, 2906, 2554, 7525, 692, # 4662
+ 444, 3032, 2624, 801, 4392, 4130, 7526, 1491, 244, 1053, 3033, 4131, 4132, 340, 7527, 3915, # 4678
+ 1041, 2987, 293, 1168, 87, 1357, 7528, 1539, 959, 7529, 2236, 721, 694, 4133, 3768, 219, # 4694
+ 1478, 644, 1417, 3331, 2656, 1413, 1401, 1335, 1389, 3916, 7530, 7531, 2988, 2362, 3134, 1825, # 4710
+ 730, 1515, 184, 2827, 66, 4393, 7532, 1660, 2943, 246, 3332, 378, 1457, 226, 3433, 975, # 4726
+ 3917, 2944, 1264, 3537, 674, 696, 7533, 163, 7534, 1141, 2417, 2166, 713, 3538, 3333, 4394, # 4742
+ 3918, 7535, 7536, 1186, 15, 7537, 1079, 1070, 7538, 1522, 3193, 3539, 276, 1050, 2716, 758, # 4758
+ 1126, 653, 2945, 3263, 7539, 2337, 889, 3540, 3919, 3081, 2989, 903, 1250, 4395, 3920, 3434, # 4774
+ 3541, 1342, 1681, 1718, 766, 3264, 286, 89, 2946, 3649, 7540, 1713, 7541, 2597, 3334, 2990, # 4790
+ 7542, 2947, 2215, 3194, 2866, 7543, 4396, 2498, 2526, 181, 387, 1075, 3921, 731, 2187, 3335, # 4806
+ 7544, 3265, 310, 313, 3435, 2299, 770, 4134, 54, 3034, 189, 4397, 3082, 3769, 3922, 7545, # 4822
+ 1230, 1617, 1849, 355, 3542, 4135, 4398, 3336, 111, 4136, 3650, 1350, 3135, 3436, 3035, 4137, # 4838
+ 2149, 3266, 3543, 7546, 2784, 3923, 3924, 2991, 722, 2008, 7547, 1071, 247, 1207, 2338, 2471, # 4854
+ 1378, 4399, 2009, 864, 1437, 1214, 4400, 373, 3770, 1142, 2216, 667, 4401, 442, 2753, 2555, # 4870
+ 3771, 3925, 1968, 4138, 3267, 1839, 837, 170, 1107, 934, 1336, 1882, 7548, 7549, 2118, 4139, # 4886
+ 2828, 743, 1569, 7550, 4402, 4140, 582, 2384, 1418, 3437, 7551, 1802, 7552, 357, 1395, 1729, # 4902
+ 3651, 3268, 2418, 1564, 2237, 7553, 3083, 3772, 1633, 4403, 1114, 2085, 4141, 1532, 7554, 482, # 4918
+ 2446, 4404, 7555, 7556, 1492, 833, 1466, 7557, 2717, 3544, 1641, 2829, 7558, 1526, 1272, 3652, # 4934
+ 4142, 1686, 1794, 416, 2556, 1902, 1953, 1803, 7559, 3773, 2785, 3774, 1159, 2316, 7560, 2867, # 4950
+ 4405, 1610, 1584, 3036, 2419, 2754, 443, 3269, 1163, 3136, 7561, 7562, 3926, 7563, 4143, 2499, # 4966
+ 3037, 4406, 3927, 3137, 2103, 1647, 3545, 2010, 1872, 4144, 7564, 4145, 431, 3438, 7565, 250, # 4982
+ 97, 81, 4146, 7566, 1648, 1850, 1558, 160, 848, 7567, 866, 740, 1694, 7568, 2201, 2830, # 4998
+ 3195, 4147, 4407, 3653, 1687, 950, 2472, 426, 469, 3196, 3654, 3655, 3928, 7569, 7570, 1188, # 5014
+ 424, 1995, 861, 3546, 4148, 3775, 2202, 2685, 168, 1235, 3547, 4149, 7571, 2086, 1674, 4408, # 5030
+ 3337, 3270, 220, 2557, 1009, 7572, 3776, 670, 2992, 332, 1208, 717, 7573, 7574, 3548, 2447, # 5046
+ 3929, 3338, 7575, 513, 7576, 1209, 2868, 3339, 3138, 4409, 1080, 7577, 7578, 7579, 7580, 2527, # 5062
+ 3656, 3549, 815, 1587, 3930, 3931, 7581, 3550, 3439, 3777, 1254, 4410, 1328, 3038, 1390, 3932, # 5078
+ 1741, 3933, 3778, 3934, 7582, 236, 3779, 2448, 3271, 7583, 7584, 3657, 3780, 1273, 3781, 4411, # 5094
+ 7585, 308, 7586, 4412, 245, 4413, 1851, 2473, 1307, 2575, 430, 715, 2136, 2449, 7587, 270, # 5110
+ 199, 2869, 3935, 7588, 3551, 2718, 1753, 761, 1754, 725, 1661, 1840, 4414, 3440, 3658, 7589, # 5126
+ 7590, 587, 14, 3272, 227, 2598, 326, 480, 2265, 943, 2755, 3552, 291, 650, 1883, 7591, # 5142
+ 1702, 1226, 102, 1547, 62, 3441, 904, 4415, 3442, 1164, 4150, 7592, 7593, 1224, 1548, 2756, # 5158
+ 391, 498, 1493, 7594, 1386, 1419, 7595, 2055, 1177, 4416, 813, 880, 1081, 2363, 566, 1145, # 5174
+ 4417, 2286, 1001, 1035, 2558, 2599, 2238, 394, 1286, 7596, 7597, 2068, 7598, 86, 1494, 1730, # 5190
+ 3936, 491, 1588, 745, 897, 2948, 843, 3340, 3937, 2757, 2870, 3273, 1768, 998, 2217, 2069, # 5206
+ 397, 1826, 1195, 1969, 3659, 2993, 3341, 284, 7599, 3782, 2500, 2137, 2119, 1903, 7600, 3938, # 5222
+ 2150, 3939, 4151, 1036, 3443, 1904, 114, 2559, 4152, 209, 1527, 7601, 7602, 2949, 2831, 2625, # 5238
+ 2385, 2719, 3139, 812, 2560, 7603, 3274, 7604, 1559, 737, 1884, 3660, 1210, 885, 28, 2686, # 5254
+ 3553, 3783, 7605, 4153, 1004, 1779, 4418, 7606, 346, 1981, 2218, 2687, 4419, 3784, 1742, 797, # 5270
+ 1642, 3940, 1933, 1072, 1384, 2151, 896, 3941, 3275, 3661, 3197, 2871, 3554, 7607, 2561, 1958, # 5286
+ 4420, 2450, 1785, 7608, 7609, 7610, 3942, 4154, 1005, 1308, 3662, 4155, 2720, 4421, 4422, 1528, # 5302
+ 2600, 161, 1178, 4156, 1982, 987, 4423, 1101, 4157, 631, 3943, 1157, 3198, 2420, 1343, 1241, # 5318
+ 1016, 2239, 2562, 372, 877, 2339, 2501, 1160, 555, 1934, 911, 3944, 7611, 466, 1170, 169, # 5334
+ 1051, 2907, 2688, 3663, 2474, 2994, 1182, 2011, 2563, 1251, 2626, 7612, 992, 2340, 3444, 1540, # 5350
+ 2721, 1201, 2070, 2401, 1996, 2475, 7613, 4424, 528, 1922, 2188, 1503, 1873, 1570, 2364, 3342, # 5366
+ 3276, 7614, 557, 1073, 7615, 1827, 3445, 2087, 2266, 3140, 3039, 3084, 767, 3085, 2786, 4425, # 5382
+ 1006, 4158, 4426, 2341, 1267, 2176, 3664, 3199, 778, 3945, 3200, 2722, 1597, 2657, 7616, 4427, # 5398
+ 7617, 3446, 7618, 7619, 7620, 3277, 2689, 1433, 3278, 131, 95, 1504, 3946, 723, 4159, 3141, # 5414
+ 1841, 3555, 2758, 2189, 3947, 2027, 2104, 3665, 7621, 2995, 3948, 1218, 7622, 3343, 3201, 3949, # 5430
+ 4160, 2576, 248, 1634, 3785, 912, 7623, 2832, 3666, 3040, 3786, 654, 53, 7624, 2996, 7625, # 5446
+ 1688, 4428, 777, 3447, 1032, 3950, 1425, 7626, 191, 820, 2120, 2833, 971, 4429, 931, 3202, # 5462
+ 135, 664, 783, 3787, 1997, 772, 2908, 1935, 3951, 3788, 4430, 2909, 3203, 282, 2723, 640, # 5478
+ 1372, 3448, 1127, 922, 325, 3344, 7627, 7628, 711, 2044, 7629, 7630, 3952, 2219, 2787, 1936, # 5494
+ 3953, 3345, 2220, 2251, 3789, 2300, 7631, 4431, 3790, 1258, 3279, 3954, 3204, 2138, 2950, 3955, # 5510
+ 3956, 7632, 2221, 258, 3205, 4432, 101, 1227, 7633, 3280, 1755, 7634, 1391, 3281, 7635, 2910, # 5526
+ 2056, 893, 7636, 7637, 7638, 1402, 4161, 2342, 7639, 7640, 3206, 3556, 7641, 7642, 878, 1325, # 5542
+ 1780, 2788, 4433, 259, 1385, 2577, 744, 1183, 2267, 4434, 7643, 3957, 2502, 7644, 684, 1024, # 5558
+ 4162, 7645, 472, 3557, 3449, 1165, 3282, 3958, 3959, 322, 2152, 881, 455, 1695, 1152, 1340, # 5574
+ 660, 554, 2153, 4435, 1058, 4436, 4163, 830, 1065, 3346, 3960, 4437, 1923, 7646, 1703, 1918, # 5590
+ 7647, 932, 2268, 122, 7648, 4438, 947, 677, 7649, 3791, 2627, 297, 1905, 1924, 2269, 4439, # 5606
+ 2317, 3283, 7650, 7651, 4164, 7652, 4165, 84, 4166, 112, 989, 7653, 547, 1059, 3961, 701, # 5622
+ 3558, 1019, 7654, 4167, 7655, 3450, 942, 639, 457, 2301, 2451, 993, 2951, 407, 851, 494, # 5638
+ 4440, 3347, 927, 7656, 1237, 7657, 2421, 3348, 573, 4168, 680, 921, 2911, 1279, 1874, 285, # 5654
+ 790, 1448, 1983, 719, 2167, 7658, 7659, 4441, 3962, 3963, 1649, 7660, 1541, 563, 7661, 1077, # 5670
+ 7662, 3349, 3041, 3451, 511, 2997, 3964, 3965, 3667, 3966, 1268, 2564, 3350, 3207, 4442, 4443, # 5686
+ 7663, 535, 1048, 1276, 1189, 2912, 2028, 3142, 1438, 1373, 2834, 2952, 1134, 2012, 7664, 4169, # 5702
+ 1238, 2578, 3086, 1259, 7665, 700, 7666, 2953, 3143, 3668, 4170, 7667, 4171, 1146, 1875, 1906, # 5718
+ 4444, 2601, 3967, 781, 2422, 132, 1589, 203, 147, 273, 2789, 2402, 898, 1786, 2154, 3968, # 5734
+ 3969, 7668, 3792, 2790, 7669, 7670, 4445, 4446, 7671, 3208, 7672, 1635, 3793, 965, 7673, 1804, # 5750
+ 2690, 1516, 3559, 1121, 1082, 1329, 3284, 3970, 1449, 3794, 65, 1128, 2835, 2913, 2759, 1590, # 5766
+ 3795, 7674, 7675, 12, 2658, 45, 976, 2579, 3144, 4447, 517, 2528, 1013, 1037, 3209, 7676, # 5782
+ 3796, 2836, 7677, 3797, 7678, 3452, 7679, 2602, 614, 1998, 2318, 3798, 3087, 2724, 2628, 7680, # 5798
+ 2580, 4172, 599, 1269, 7681, 1810, 3669, 7682, 2691, 3088, 759, 1060, 489, 1805, 3351, 3285, # 5814
+ 1358, 7683, 7684, 2386, 1387, 1215, 2629, 2252, 490, 7685, 7686, 4173, 1759, 2387, 2343, 7687, # 5830
+ 4448, 3799, 1907, 3971, 2630, 1806, 3210, 4449, 3453, 3286, 2760, 2344, 874, 7688, 7689, 3454, # 5846
+ 3670, 1858, 91, 2914, 3671, 3042, 3800, 4450, 7690, 3145, 3972, 2659, 7691, 3455, 1202, 1403, # 5862
+ 3801, 2954, 2529, 1517, 2503, 4451, 3456, 2504, 7692, 4452, 7693, 2692, 1885, 1495, 1731, 3973, # 5878
+ 2365, 4453, 7694, 2029, 7695, 7696, 3974, 2693, 1216, 237, 2581, 4174, 2319, 3975, 3802, 4454, # 5894
+ 4455, 2694, 3560, 3457, 445, 4456, 7697, 7698, 7699, 7700, 2761, 61, 3976, 3672, 1822, 3977, # 5910
+ 7701, 687, 2045, 935, 925, 405, 2660, 703, 1096, 1859, 2725, 4457, 3978, 1876, 1367, 2695, # 5926
+ 3352, 918, 2105, 1781, 2476, 334, 3287, 1611, 1093, 4458, 564, 3146, 3458, 3673, 3353, 945, # 5942
+ 2631, 2057, 4459, 7702, 1925, 872, 4175, 7703, 3459, 2696, 3089, 349, 4176, 3674, 3979, 4460, # 5958
+ 3803, 4177, 3675, 2155, 3980, 4461, 4462, 4178, 4463, 2403, 2046, 782, 3981, 400, 251, 4179, # 5974
+ 1624, 7704, 7705, 277, 3676, 299, 1265, 476, 1191, 3804, 2121, 4180, 4181, 1109, 205, 7706, # 5990
+ 2582, 1000, 2156, 3561, 1860, 7707, 7708, 7709, 4464, 7710, 4465, 2565, 107, 2477, 2157, 3982, # 6006
+ 3460, 3147, 7711, 1533, 541, 1301, 158, 753, 4182, 2872, 3562, 7712, 1696, 370, 1088, 4183, # 6022
+ 4466, 3563, 579, 327, 440, 162, 2240, 269, 1937, 1374, 3461, 968, 3043, 56, 1396, 3090, # 6038
+ 2106, 3288, 3354, 7713, 1926, 2158, 4467, 2998, 7714, 3564, 7715, 7716, 3677, 4468, 2478, 7717, # 6054
+ 2791, 7718, 1650, 4469, 7719, 2603, 7720, 7721, 3983, 2661, 3355, 1149, 3356, 3984, 3805, 3985, # 6070
+ 7722, 1076, 49, 7723, 951, 3211, 3289, 3290, 450, 2837, 920, 7724, 1811, 2792, 2366, 4184, # 6086
+ 1908, 1138, 2367, 3806, 3462, 7725, 3212, 4470, 1909, 1147, 1518, 2423, 4471, 3807, 7726, 4472, # 6102
+ 2388, 2604, 260, 1795, 3213, 7727, 7728, 3808, 3291, 708, 7729, 3565, 1704, 7730, 3566, 1351, # 6118
+ 1618, 3357, 2999, 1886, 944, 4185, 3358, 4186, 3044, 3359, 4187, 7731, 3678, 422, 413, 1714, # 6134
+ 3292, 500, 2058, 2345, 4188, 2479, 7732, 1344, 1910, 954, 7733, 1668, 7734, 7735, 3986, 2404, # 6150
+ 4189, 3567, 3809, 4190, 7736, 2302, 1318, 2505, 3091, 133, 3092, 2873, 4473, 629, 31, 2838, # 6166
+ 2697, 3810, 4474, 850, 949, 4475, 3987, 2955, 1732, 2088, 4191, 1496, 1852, 7737, 3988, 620, # 6182
+ 3214, 981, 1242, 3679, 3360, 1619, 3680, 1643, 3293, 2139, 2452, 1970, 1719, 3463, 2168, 7738, # 6198
+ 3215, 7739, 7740, 3361, 1828, 7741, 1277, 4476, 1565, 2047, 7742, 1636, 3568, 3093, 7743, 869, # 6214
+ 2839, 655, 3811, 3812, 3094, 3989, 3000, 3813, 1310, 3569, 4477, 7744, 7745, 7746, 1733, 558, # 6230
+ 4478, 3681, 335, 1549, 3045, 1756, 4192, 3682, 1945, 3464, 1829, 1291, 1192, 470, 2726, 2107, # 6246
+ 2793, 913, 1054, 3990, 7747, 1027, 7748, 3046, 3991, 4479, 982, 2662, 3362, 3148, 3465, 3216, # 6262
+ 3217, 1946, 2794, 7749, 571, 4480, 7750, 1830, 7751, 3570, 2583, 1523, 2424, 7752, 2089, 984, # 6278
+ 4481, 3683, 1959, 7753, 3684, 852, 923, 2795, 3466, 3685, 969, 1519, 999, 2048, 2320, 1705, # 6294
+ 7754, 3095, 615, 1662, 151, 597, 3992, 2405, 2321, 1049, 275, 4482, 3686, 4193, 568, 3687, # 6310
+ 3571, 2480, 4194, 3688, 7755, 2425, 2270, 409, 3218, 7756, 1566, 2874, 3467, 1002, 769, 2840, # 6326
+ 194, 2090, 3149, 3689, 2222, 3294, 4195, 628, 1505, 7757, 7758, 1763, 2177, 3001, 3993, 521, # 6342
+ 1161, 2584, 1787, 2203, 2406, 4483, 3994, 1625, 4196, 4197, 412, 42, 3096, 464, 7759, 2632, # 6358
+ 4484, 3363, 1760, 1571, 2875, 3468, 2530, 1219, 2204, 3814, 2633, 2140, 2368, 4485, 4486, 3295, # 6374
+ 1651, 3364, 3572, 7760, 7761, 3573, 2481, 3469, 7762, 3690, 7763, 7764, 2271, 2091, 460, 7765, # 6390
+ 4487, 7766, 3002, 962, 588, 3574, 289, 3219, 2634, 1116, 52, 7767, 3047, 1796, 7768, 7769, # 6406
+ 7770, 1467, 7771, 1598, 1143, 3691, 4198, 1984, 1734, 1067, 4488, 1280, 3365, 465, 4489, 1572, # 6422
+ 510, 7772, 1927, 2241, 1812, 1644, 3575, 7773, 4490, 3692, 7774, 7775, 2663, 1573, 1534, 7776, # 6438
+ 7777, 4199, 536, 1807, 1761, 3470, 3815, 3150, 2635, 7778, 7779, 7780, 4491, 3471, 2915, 1911, # 6454
+ 2796, 7781, 3296, 1122, 377, 3220, 7782, 360, 7783, 7784, 4200, 1529, 551, 7785, 2059, 3693, # 6470
+ 1769, 2426, 7786, 2916, 4201, 3297, 3097, 2322, 2108, 2030, 4492, 1404, 136, 1468, 1479, 672, # 6486
+ 1171, 3221, 2303, 271, 3151, 7787, 2762, 7788, 2049, 678, 2727, 865, 1947, 4493, 7789, 2013, # 6502
+ 3995, 2956, 7790, 2728, 2223, 1397, 3048, 3694, 4494, 4495, 1735, 2917, 3366, 3576, 7791, 3816, # 6518
+ 509, 2841, 2453, 2876, 3817, 7792, 7793, 3152, 3153, 4496, 4202, 2531, 4497, 2304, 1166, 1010, # 6534
+ 552, 681, 1887, 7794, 7795, 2957, 2958, 3996, 1287, 1596, 1861, 3154, 358, 453, 736, 175, # 6550
+ 478, 1117, 905, 1167, 1097, 7796, 1853, 1530, 7797, 1706, 7798, 2178, 3472, 2287, 3695, 3473, # 6566
+ 3577, 4203, 2092, 4204, 7799, 3367, 1193, 2482, 4205, 1458, 2190, 2205, 1862, 1888, 1421, 3298, # 6582
+ 2918, 3049, 2179, 3474, 595, 2122, 7800, 3997, 7801, 7802, 4206, 1707, 2636, 223, 3696, 1359, # 6598
+ 751, 3098, 183, 3475, 7803, 2797, 3003, 419, 2369, 633, 704, 3818, 2389, 241, 7804, 7805, # 6614
+ 7806, 838, 3004, 3697, 2272, 2763, 2454, 3819, 1938, 2050, 3998, 1309, 3099, 2242, 1181, 7807, # 6630
+ 1136, 2206, 3820, 2370, 1446, 4207, 2305, 4498, 7808, 7809, 4208, 1055, 2605, 484, 3698, 7810, # 6646
+ 3999, 625, 4209, 2273, 3368, 1499, 4210, 4000, 7811, 4001, 4211, 3222, 2274, 2275, 3476, 7812, # 6662
+ 7813, 2764, 808, 2606, 3699, 3369, 4002, 4212, 3100, 2532, 526, 3370, 3821, 4213, 955, 7814, # 6678
+ 1620, 4214, 2637, 2427, 7815, 1429, 3700, 1669, 1831, 994, 928, 7816, 3578, 1260, 7817, 7818, # 6694
+ 7819, 1948, 2288, 741, 2919, 1626, 4215, 2729, 2455, 867, 1184, 362, 3371, 1392, 7820, 7821, # 6710
+ 4003, 4216, 1770, 1736, 3223, 2920, 4499, 4500, 1928, 2698, 1459, 1158, 7822, 3050, 3372, 2877, # 6726
+ 1292, 1929, 2506, 2842, 3701, 1985, 1187, 2071, 2014, 2607, 4217, 7823, 2566, 2507, 2169, 3702, # 6742
+ 2483, 3299, 7824, 3703, 4501, 7825, 7826, 666, 1003, 3005, 1022, 3579, 4218, 7827, 4502, 1813, # 6758
+ 2253, 574, 3822, 1603, 295, 1535, 705, 3823, 4219, 283, 858, 417, 7828, 7829, 3224, 4503, # 6774
+ 4504, 3051, 1220, 1889, 1046, 2276, 2456, 4004, 1393, 1599, 689, 2567, 388, 4220, 7830, 2484, # 6790
+ 802, 7831, 2798, 3824, 2060, 1405, 2254, 7832, 4505, 3825, 2109, 1052, 1345, 3225, 1585, 7833, # 6806
+ 809, 7834, 7835, 7836, 575, 2730, 3477, 956, 1552, 1469, 1144, 2323, 7837, 2324, 1560, 2457, # 6822
+ 3580, 3226, 4005, 616, 2207, 3155, 2180, 2289, 7838, 1832, 7839, 3478, 4506, 7840, 1319, 3704, # 6838
+ 3705, 1211, 3581, 1023, 3227, 1293, 2799, 7841, 7842, 7843, 3826, 607, 2306, 3827, 762, 2878, # 6854
+ 1439, 4221, 1360, 7844, 1485, 3052, 7845, 4507, 1038, 4222, 1450, 2061, 2638, 4223, 1379, 4508, # 6870
+ 2585, 7846, 7847, 4224, 1352, 1414, 2325, 2921, 1172, 7848, 7849, 3828, 3829, 7850, 1797, 1451, # 6886
+ 7851, 7852, 7853, 7854, 2922, 4006, 4007, 2485, 2346, 411, 4008, 4009, 3582, 3300, 3101, 4509, # 6902
+ 1561, 2664, 1452, 4010, 1375, 7855, 7856, 47, 2959, 316, 7857, 1406, 1591, 2923, 3156, 7858, # 6918
+ 1025, 2141, 3102, 3157, 354, 2731, 884, 2224, 4225, 2407, 508, 3706, 726, 3583, 996, 2428, # 6934
+ 3584, 729, 7859, 392, 2191, 1453, 4011, 4510, 3707, 7860, 7861, 2458, 3585, 2608, 1675, 2800, # 6950
+ 919, 2347, 2960, 2348, 1270, 4511, 4012, 73, 7862, 7863, 647, 7864, 3228, 2843, 2255, 1550, # 6966
+ 1346, 3006, 7865, 1332, 883, 3479, 7866, 7867, 7868, 7869, 3301, 2765, 7870, 1212, 831, 1347, # 6982
+ 4226, 4512, 2326, 3830, 1863, 3053, 720, 3831, 4513, 4514, 3832, 7871, 4227, 7872, 7873, 4515, # 6998
+ 7874, 7875, 1798, 4516, 3708, 2609, 4517, 3586, 1645, 2371, 7876, 7877, 2924, 669, 2208, 2665, # 7014
+ 2429, 7878, 2879, 7879, 7880, 1028, 3229, 7881, 4228, 2408, 7882, 2256, 1353, 7883, 7884, 4518, # 7030
+ 3158, 518, 7885, 4013, 7886, 4229, 1960, 7887, 2142, 4230, 7888, 7889, 3007, 2349, 2350, 3833, # 7046
+ 516, 1833, 1454, 4014, 2699, 4231, 4519, 2225, 2610, 1971, 1129, 3587, 7890, 2766, 7891, 2961, # 7062
+ 1422, 577, 1470, 3008, 1524, 3373, 7892, 7893, 432, 4232, 3054, 3480, 7894, 2586, 1455, 2508, # 7078
+ 2226, 1972, 1175, 7895, 1020, 2732, 4015, 3481, 4520, 7896, 2733, 7897, 1743, 1361, 3055, 3482, # 7094
+ 2639, 4016, 4233, 4521, 2290, 895, 924, 4234, 2170, 331, 2243, 3056, 166, 1627, 3057, 1098, # 7110
+ 7898, 1232, 2880, 2227, 3374, 4522, 657, 403, 1196, 2372, 542, 3709, 3375, 1600, 4235, 3483, # 7126
+ 7899, 4523, 2767, 3230, 576, 530, 1362, 7900, 4524, 2533, 2666, 3710, 4017, 7901, 842, 3834, # 7142
+ 7902, 2801, 2031, 1014, 4018, 213, 2700, 3376, 665, 621, 4236, 7903, 3711, 2925, 2430, 7904, # 7158
+ 2431, 3302, 3588, 3377, 7905, 4237, 2534, 4238, 4525, 3589, 1682, 4239, 3484, 1380, 7906, 724, # 7174
+ 2277, 600, 1670, 7907, 1337, 1233, 4526, 3103, 2244, 7908, 1621, 4527, 7909, 651, 4240, 7910, # 7190
+ 1612, 4241, 2611, 7911, 2844, 7912, 2734, 2307, 3058, 7913, 716, 2459, 3059, 174, 1255, 2701, # 7206
+ 4019, 3590, 548, 1320, 1398, 728, 4020, 1574, 7914, 1890, 1197, 3060, 4021, 7915, 3061, 3062, # 7222
+ 3712, 3591, 3713, 747, 7916, 635, 4242, 4528, 7917, 7918, 7919, 4243, 7920, 7921, 4529, 7922, # 7238
+ 3378, 4530, 2432, 451, 7923, 3714, 2535, 2072, 4244, 2735, 4245, 4022, 7924, 1764, 4531, 7925, # 7254
+ 4246, 350, 7926, 2278, 2390, 2486, 7927, 4247, 4023, 2245, 1434, 4024, 488, 4532, 458, 4248, # 7270
+ 4025, 3715, 771, 1330, 2391, 3835, 2568, 3159, 2159, 2409, 1553, 2667, 3160, 4249, 7928, 2487, # 7286
+ 2881, 2612, 1720, 2702, 4250, 3379, 4533, 7929, 2536, 4251, 7930, 3231, 4252, 2768, 7931, 2015, # 7302
+ 2736, 7932, 1155, 1017, 3716, 3836, 7933, 3303, 2308, 201, 1864, 4253, 1430, 7934, 4026, 7935, # 7318
+ 7936, 7937, 7938, 7939, 4254, 1604, 7940, 414, 1865, 371, 2587, 4534, 4535, 3485, 2016, 3104, # 7334
+ 4536, 1708, 960, 4255, 887, 389, 2171, 1536, 1663, 1721, 7941, 2228, 4027, 2351, 2926, 1580, # 7350
+ 7942, 7943, 7944, 1744, 7945, 2537, 4537, 4538, 7946, 4539, 7947, 2073, 7948, 7949, 3592, 3380, # 7366
+ 2882, 4256, 7950, 4257, 2640, 3381, 2802, 673, 2703, 2460, 709, 3486, 4028, 3593, 4258, 7951, # 7382
+ 1148, 502, 634, 7952, 7953, 1204, 4540, 3594, 1575, 4541, 2613, 3717, 7954, 3718, 3105, 948, # 7398
+ 3232, 121, 1745, 3837, 1110, 7955, 4259, 3063, 2509, 3009, 4029, 3719, 1151, 1771, 3838, 1488, # 7414
+ 4030, 1986, 7956, 2433, 3487, 7957, 7958, 2093, 7959, 4260, 3839, 1213, 1407, 2803, 531, 2737, # 7430
+ 2538, 3233, 1011, 1537, 7960, 2769, 4261, 3106, 1061, 7961, 3720, 3721, 1866, 2883, 7962, 2017, # 7446
+ 120, 4262, 4263, 2062, 3595, 3234, 2309, 3840, 2668, 3382, 1954, 4542, 7963, 7964, 3488, 1047, # 7462
+ 2704, 1266, 7965, 1368, 4543, 2845, 649, 3383, 3841, 2539, 2738, 1102, 2846, 2669, 7966, 7967, # 7478
+ 1999, 7968, 1111, 3596, 2962, 7969, 2488, 3842, 3597, 2804, 1854, 3384, 3722, 7970, 7971, 3385, # 7494
+ 2410, 2884, 3304, 3235, 3598, 7972, 2569, 7973, 3599, 2805, 4031, 1460, 856, 7974, 3600, 7975, # 7510
+ 2885, 2963, 7976, 2886, 3843, 7977, 4264, 632, 2510, 875, 3844, 1697, 3845, 2291, 7978, 7979, # 7526
+ 4544, 3010, 1239, 580, 4545, 4265, 7980, 914, 936, 2074, 1190, 4032, 1039, 2123, 7981, 7982, # 7542
+ 7983, 3386, 1473, 7984, 1354, 4266, 3846, 7985, 2172, 3064, 4033, 915, 3305, 4267, 4268, 3306, # 7558
+ 1605, 1834, 7986, 2739, 398, 3601, 4269, 3847, 4034, 328, 1912, 2847, 4035, 3848, 1331, 4270, # 7574
+ 3011, 937, 4271, 7987, 3602, 4036, 4037, 3387, 2160, 4546, 3388, 524, 742, 538, 3065, 1012, # 7590
+ 7988, 7989, 3849, 2461, 7990, 658, 1103, 225, 3850, 7991, 7992, 4547, 7993, 4548, 7994, 3236, # 7606
+ 1243, 7995, 4038, 963, 2246, 4549, 7996, 2705, 3603, 3161, 7997, 7998, 2588, 2327, 7999, 4550, # 7622
+ 8000, 8001, 8002, 3489, 3307, 957, 3389, 2540, 2032, 1930, 2927, 2462, 870, 2018, 3604, 1746, # 7638
+ 2770, 2771, 2434, 2463, 8003, 3851, 8004, 3723, 3107, 3724, 3490, 3390, 3725, 8005, 1179, 3066, # 7654
+ 8006, 3162, 2373, 4272, 3726, 2541, 3163, 3108, 2740, 4039, 8007, 3391, 1556, 2542, 2292, 977, # 7670
+ 2887, 2033, 4040, 1205, 3392, 8008, 1765, 3393, 3164, 2124, 1271, 1689, 714, 4551, 3491, 8009, # 7686
+ 2328, 3852, 533, 4273, 3605, 2181, 617, 8010, 2464, 3308, 3492, 2310, 8011, 8012, 3165, 8013, # 7702
+ 8014, 3853, 1987, 618, 427, 2641, 3493, 3394, 8015, 8016, 1244, 1690, 8017, 2806, 4274, 4552, # 7718
+ 8018, 3494, 8019, 8020, 2279, 1576, 473, 3606, 4275, 3395, 972, 8021, 3607, 8022, 3067, 8023, # 7734
+ 8024, 4553, 4554, 8025, 3727, 4041, 4042, 8026, 153, 4555, 356, 8027, 1891, 2888, 4276, 2143, # 7750
+ 408, 803, 2352, 8028, 3854, 8029, 4277, 1646, 2570, 2511, 4556, 4557, 3855, 8030, 3856, 4278, # 7766
+ 8031, 2411, 3396, 752, 8032, 8033, 1961, 2964, 8034, 746, 3012, 2465, 8035, 4279, 3728, 698, # 7782
+ 4558, 1892, 4280, 3608, 2543, 4559, 3609, 3857, 8036, 3166, 3397, 8037, 1823, 1302, 4043, 2706, # 7798
+ 3858, 1973, 4281, 8038, 4282, 3167, 823, 1303, 1288, 1236, 2848, 3495, 4044, 3398, 774, 3859, # 7814
+ 8039, 1581, 4560, 1304, 2849, 3860, 4561, 8040, 2435, 2161, 1083, 3237, 4283, 4045, 4284, 344, # 7830
+ 1173, 288, 2311, 454, 1683, 8041, 8042, 1461, 4562, 4046, 2589, 8043, 8044, 4563, 985, 894, # 7846
+ 8045, 3399, 3168, 8046, 1913, 2928, 3729, 1988, 8047, 2110, 1974, 8048, 4047, 8049, 2571, 1194, # 7862
+ 425, 8050, 4564, 3169, 1245, 3730, 4285, 8051, 8052, 2850, 8053, 636, 4565, 1855, 3861, 760, # 7878
+ 1799, 8054, 4286, 2209, 1508, 4566, 4048, 1893, 1684, 2293, 8055, 8056, 8057, 4287, 4288, 2210, # 7894
+ 479, 8058, 8059, 832, 8060, 4049, 2489, 8061, 2965, 2490, 3731, 990, 3109, 627, 1814, 2642, # 7910
+ 4289, 1582, 4290, 2125, 2111, 3496, 4567, 8062, 799, 4291, 3170, 8063, 4568, 2112, 1737, 3013, # 7926
+ 1018, 543, 754, 4292, 3309, 1676, 4569, 4570, 4050, 8064, 1489, 8065, 3497, 8066, 2614, 2889, # 7942
+ 4051, 8067, 8068, 2966, 8069, 8070, 8071, 8072, 3171, 4571, 4572, 2182, 1722, 8073, 3238, 3239, # 7958
+ 1842, 3610, 1715, 481, 365, 1975, 1856, 8074, 8075, 1962, 2491, 4573, 8076, 2126, 3611, 3240, # 7974
+ 433, 1894, 2063, 2075, 8077, 602, 2741, 8078, 8079, 8080, 8081, 8082, 3014, 1628, 3400, 8083, # 7990
+ 3172, 4574, 4052, 2890, 4575, 2512, 8084, 2544, 2772, 8085, 8086, 8087, 3310, 4576, 2891, 8088, # 8006
+ 4577, 8089, 2851, 4578, 4579, 1221, 2967, 4053, 2513, 8090, 8091, 8092, 1867, 1989, 8093, 8094, # 8022
+ 8095, 1895, 8096, 8097, 4580, 1896, 4054, 318, 8098, 2094, 4055, 4293, 8099, 8100, 485, 8101, # 8038
+ 938, 3862, 553, 2670, 116, 8102, 3863, 3612, 8103, 3498, 2671, 2773, 3401, 3311, 2807, 8104, # 8054
+ 3613, 2929, 4056, 1747, 2930, 2968, 8105, 8106, 207, 8107, 8108, 2672, 4581, 2514, 8109, 3015, # 8070
+ 890, 3614, 3864, 8110, 1877, 3732, 3402, 8111, 2183, 2353, 3403, 1652, 8112, 8113, 8114, 941, # 8086
+ 2294, 208, 3499, 4057, 2019, 330, 4294, 3865, 2892, 2492, 3733, 4295, 8115, 8116, 8117, 8118, # 8102
+)
+# fmt: on
diff --git a/lib/chardet/euctwprober.py b/lib/chardet/euctwprober.py
new file mode 100644
index 0000000..ca10a23
--- /dev/null
+++ b/lib/chardet/euctwprober.py
@@ -0,0 +1,47 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import EUCTWDistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import EUCTW_SM_MODEL
+
+
+class EUCTWProber(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(EUCTW_SM_MODEL)
+ self.distribution_analyzer = EUCTWDistributionAnalysis()
+ self.reset()
+
+ @property
+ def charset_name(self):
+ return "EUC-TW"
+
+ @property
+ def language(self):
+ return "Taiwan"
diff --git a/lib/chardet/gb2312freq.py b/lib/chardet/gb2312freq.py
new file mode 100644
index 0000000..b32bfc7
--- /dev/null
+++ b/lib/chardet/gb2312freq.py
@@ -0,0 +1,284 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+# GB2312 most frequently used character table
+#
+# Char to FreqOrder table, from hz6763
+
+# 512 --> 0.79 -- 0.79
+# 1024 --> 0.92 -- 0.13
+# 2048 --> 0.98 -- 0.06
+# 6768 --> 1.00 -- 0.02
+#
+# Ideal Distribution Ratio = 0.79135 / (1 - 0.79135) = 3.79
+# Random Distribution Ratio = 512 / (3755 - 512) = 0.157
+#
+# Typical Distribution Ratio is about 25% of the ideal one, still much higher than RDR
+
+GB2312_TYPICAL_DISTRIBUTION_RATIO = 0.9
+
+GB2312_TABLE_SIZE = 3760
+
+# fmt: off
+GB2312_CHAR_TO_FREQ_ORDER = (
+1671, 749,1443,2364,3924,3807,2330,3921,1704,3463,2691,1511,1515, 572,3191,2205,
+2361, 224,2558, 479,1711, 963,3162, 440,4060,1905,2966,2947,3580,2647,3961,3842,
+2204, 869,4207, 970,2678,5626,2944,2956,1479,4048, 514,3595, 588,1346,2820,3409,
+ 249,4088,1746,1873,2047,1774, 581,1813, 358,1174,3590,1014,1561,4844,2245, 670,
+1636,3112, 889,1286, 953, 556,2327,3060,1290,3141, 613, 185,3477,1367, 850,3820,
+1715,2428,2642,2303,2732,3041,2562,2648,3566,3946,1349, 388,3098,2091,1360,3585,
+ 152,1687,1539, 738,1559, 59,1232,2925,2267,1388,1249,1741,1679,2960, 151,1566,
+1125,1352,4271, 924,4296, 385,3166,4459, 310,1245,2850, 70,3285,2729,3534,3575,
+2398,3298,3466,1960,2265, 217,3647, 864,1909,2084,4401,2773,1010,3269,5152, 853,
+3051,3121,1244,4251,1895, 364,1499,1540,2313,1180,3655,2268, 562, 715,2417,3061,
+ 544, 336,3768,2380,1752,4075, 950, 280,2425,4382, 183,2759,3272, 333,4297,2155,
+1688,2356,1444,1039,4540, 736,1177,3349,2443,2368,2144,2225, 565, 196,1482,3406,
+ 927,1335,4147, 692, 878,1311,1653,3911,3622,1378,4200,1840,2969,3149,2126,1816,
+2534,1546,2393,2760, 737,2494, 13, 447, 245,2747, 38,2765,2129,2589,1079, 606,
+ 360, 471,3755,2890, 404, 848, 699,1785,1236, 370,2221,1023,3746,2074,2026,2023,
+2388,1581,2119, 812,1141,3091,2536,1519, 804,2053, 406,1596,1090, 784, 548,4414,
+1806,2264,2936,1100, 343,4114,5096, 622,3358, 743,3668,1510,1626,5020,3567,2513,
+3195,4115,5627,2489,2991, 24,2065,2697,1087,2719, 48,1634, 315, 68, 985,2052,
+ 198,2239,1347,1107,1439, 597,2366,2172, 871,3307, 919,2487,2790,1867, 236,2570,
+1413,3794, 906,3365,3381,1701,1982,1818,1524,2924,1205, 616,2586,2072,2004, 575,
+ 253,3099, 32,1365,1182, 197,1714,2454,1201, 554,3388,3224,2748, 756,2587, 250,
+2567,1507,1517,3529,1922,2761,2337,3416,1961,1677,2452,2238,3153, 615, 911,1506,
+1474,2495,1265,1906,2749,3756,3280,2161, 898,2714,1759,3450,2243,2444, 563, 26,
+3286,2266,3769,3344,2707,3677, 611,1402, 531,1028,2871,4548,1375, 261,2948, 835,
+1190,4134, 353, 840,2684,1900,3082,1435,2109,1207,1674, 329,1872,2781,4055,2686,
+2104, 608,3318,2423,2957,2768,1108,3739,3512,3271,3985,2203,1771,3520,1418,2054,
+1681,1153, 225,1627,2929, 162,2050,2511,3687,1954, 124,1859,2431,1684,3032,2894,
+ 585,4805,3969,2869,2704,2088,2032,2095,3656,2635,4362,2209, 256, 518,2042,2105,
+3777,3657, 643,2298,1148,1779, 190, 989,3544, 414, 11,2135,2063,2979,1471, 403,
+3678, 126, 770,1563, 671,2499,3216,2877, 600,1179, 307,2805,4937,1268,1297,2694,
+ 252,4032,1448,1494,1331,1394, 127,2256, 222,1647,1035,1481,3056,1915,1048, 873,
+3651, 210, 33,1608,2516, 200,1520, 415, 102, 0,3389,1287, 817, 91,3299,2940,
+ 836,1814, 549,2197,1396,1669,2987,3582,2297,2848,4528,1070, 687, 20,1819, 121,
+1552,1364,1461,1968,2617,3540,2824,2083, 177, 948,4938,2291, 110,4549,2066, 648,
+3359,1755,2110,2114,4642,4845,1693,3937,3308,1257,1869,2123, 208,1804,3159,2992,
+2531,2549,3361,2418,1350,2347,2800,2568,1291,2036,2680, 72, 842,1990, 212,1233,
+1154,1586, 75,2027,3410,4900,1823,1337,2710,2676, 728,2810,1522,3026,4995, 157,
+ 755,1050,4022, 710, 785,1936,2194,2085,1406,2777,2400, 150,1250,4049,1206, 807,
+1910, 534, 529,3309,1721,1660, 274, 39,2827, 661,2670,1578, 925,3248,3815,1094,
+4278,4901,4252, 41,1150,3747,2572,2227,4501,3658,4902,3813,3357,3617,2884,2258,
+ 887, 538,4187,3199,1294,2439,3042,2329,2343,2497,1255, 107, 543,1527, 521,3478,
+3568, 194,5062, 15, 961,3870,1241,1192,2664, 66,5215,3260,2111,1295,1127,2152,
+3805,4135, 901,1164,1976, 398,1278, 530,1460, 748, 904,1054,1966,1426, 53,2909,
+ 509, 523,2279,1534, 536,1019, 239,1685, 460,2353, 673,1065,2401,3600,4298,2272,
+1272,2363, 284,1753,3679,4064,1695, 81, 815,2677,2757,2731,1386, 859, 500,4221,
+2190,2566, 757,1006,2519,2068,1166,1455, 337,2654,3203,1863,1682,1914,3025,1252,
+1409,1366, 847, 714,2834,2038,3209, 964,2970,1901, 885,2553,1078,1756,3049, 301,
+1572,3326, 688,2130,1996,2429,1805,1648,2930,3421,2750,3652,3088, 262,1158,1254,
+ 389,1641,1812, 526,1719, 923,2073,1073,1902, 468, 489,4625,1140, 857,2375,3070,
+3319,2863, 380, 116,1328,2693,1161,2244, 273,1212,1884,2769,3011,1775,1142, 461,
+3066,1200,2147,2212, 790, 702,2695,4222,1601,1058, 434,2338,5153,3640, 67,2360,
+4099,2502, 618,3472,1329, 416,1132, 830,2782,1807,2653,3211,3510,1662, 192,2124,
+ 296,3979,1739,1611,3684, 23, 118, 324, 446,1239,1225, 293,2520,3814,3795,2535,
+3116, 17,1074, 467,2692,2201, 387,2922, 45,1326,3055,1645,3659,2817, 958, 243,
+1903,2320,1339,2825,1784,3289, 356, 576, 865,2315,2381,3377,3916,1088,3122,1713,
+1655, 935, 628,4689,1034,1327, 441, 800, 720, 894,1979,2183,1528,5289,2702,1071,
+4046,3572,2399,1571,3281, 79, 761,1103, 327, 134, 758,1899,1371,1615, 879, 442,
+ 215,2605,2579, 173,2048,2485,1057,2975,3317,1097,2253,3801,4263,1403,1650,2946,
+ 814,4968,3487,1548,2644,1567,1285, 2, 295,2636, 97, 946,3576, 832, 141,4257,
+3273, 760,3821,3521,3156,2607, 949,1024,1733,1516,1803,1920,2125,2283,2665,3180,
+1501,2064,3560,2171,1592, 803,3518,1416, 732,3897,4258,1363,1362,2458, 119,1427,
+ 602,1525,2608,1605,1639,3175, 694,3064, 10, 465, 76,2000,4846,4208, 444,3781,
+1619,3353,2206,1273,3796, 740,2483, 320,1723,2377,3660,2619,1359,1137,1762,1724,
+2345,2842,1850,1862, 912, 821,1866, 612,2625,1735,2573,3369,1093, 844, 89, 937,
+ 930,1424,3564,2413,2972,1004,3046,3019,2011, 711,3171,1452,4178, 428, 801,1943,
+ 432, 445,2811, 206,4136,1472, 730, 349, 73, 397,2802,2547, 998,1637,1167, 789,
+ 396,3217, 154,1218, 716,1120,1780,2819,4826,1931,3334,3762,2139,1215,2627, 552,
+3664,3628,3232,1405,2383,3111,1356,2652,3577,3320,3101,1703, 640,1045,1370,1246,
+4996, 371,1575,2436,1621,2210, 984,4033,1734,2638, 16,4529, 663,2755,3255,1451,
+3917,2257,1253,1955,2234,1263,2951, 214,1229, 617, 485, 359,1831,1969, 473,2310,
+ 750,2058, 165, 80,2864,2419, 361,4344,2416,2479,1134, 796,3726,1266,2943, 860,
+2715, 938, 390,2734,1313,1384, 248, 202, 877,1064,2854, 522,3907, 279,1602, 297,
+2357, 395,3740, 137,2075, 944,4089,2584,1267,3802, 62,1533,2285, 178, 176, 780,
+2440, 201,3707, 590, 478,1560,4354,2117,1075, 30, 74,4643,4004,1635,1441,2745,
+ 776,2596, 238,1077,1692,1912,2844, 605, 499,1742,3947, 241,3053, 980,1749, 936,
+2640,4511,2582, 515,1543,2162,5322,2892,2993, 890,2148,1924, 665,1827,3581,1032,
+ 968,3163, 339,1044,1896, 270, 583,1791,1720,4367,1194,3488,3669, 43,2523,1657,
+ 163,2167, 290,1209,1622,3378, 550, 634,2508,2510, 695,2634,2384,2512,1476,1414,
+ 220,1469,2341,2138,2852,3183,2900,4939,2865,3502,1211,3680, 854,3227,1299,2976,
+3172, 186,2998,1459, 443,1067,3251,1495, 321,1932,3054, 909, 753,1410,1828, 436,
+2441,1119,1587,3164,2186,1258, 227, 231,1425,1890,3200,3942, 247, 959, 725,5254,
+2741, 577,2158,2079, 929, 120, 174, 838,2813, 591,1115, 417,2024, 40,3240,1536,
+1037, 291,4151,2354, 632,1298,2406,2500,3535,1825,1846,3451, 205,1171, 345,4238,
+ 18,1163, 811, 685,2208,1217, 425,1312,1508,1175,4308,2552,1033, 587,1381,3059,
+2984,3482, 340,1316,4023,3972, 792,3176, 519, 777,4690, 918, 933,4130,2981,3741,
+ 90,3360,2911,2200,5184,4550, 609,3079,2030, 272,3379,2736, 363,3881,1130,1447,
+ 286, 779, 357,1169,3350,3137,1630,1220,2687,2391, 747,1277,3688,2618,2682,2601,
+1156,3196,5290,4034,3102,1689,3596,3128, 874, 219,2783, 798, 508,1843,2461, 269,
+1658,1776,1392,1913,2983,3287,2866,2159,2372, 829,4076, 46,4253,2873,1889,1894,
+ 915,1834,1631,2181,2318, 298, 664,2818,3555,2735, 954,3228,3117, 527,3511,2173,
+ 681,2712,3033,2247,2346,3467,1652, 155,2164,3382, 113,1994, 450, 899, 494, 994,
+1237,2958,1875,2336,1926,3727, 545,1577,1550, 633,3473, 204,1305,3072,2410,1956,
+2471, 707,2134, 841,2195,2196,2663,3843,1026,4940, 990,3252,4997, 368,1092, 437,
+3212,3258,1933,1829, 675,2977,2893, 412, 943,3723,4644,3294,3283,2230,2373,5154,
+2389,2241,2661,2323,1404,2524, 593, 787, 677,3008,1275,2059, 438,2709,2609,2240,
+2269,2246,1446, 36,1568,1373,3892,1574,2301,1456,3962, 693,2276,5216,2035,1143,
+2720,1919,1797,1811,2763,4137,2597,1830,1699,1488,1198,2090, 424,1694, 312,3634,
+3390,4179,3335,2252,1214, 561,1059,3243,2295,2561, 975,5155,2321,2751,3772, 472,
+1537,3282,3398,1047,2077,2348,2878,1323,3340,3076, 690,2906, 51, 369, 170,3541,
+1060,2187,2688,3670,2541,1083,1683, 928,3918, 459, 109,4427, 599,3744,4286, 143,
+2101,2730,2490, 82,1588,3036,2121, 281,1860, 477,4035,1238,2812,3020,2716,3312,
+1530,2188,2055,1317, 843, 636,1808,1173,3495, 649, 181,1002, 147,3641,1159,2414,
+3750,2289,2795, 813,3123,2610,1136,4368, 5,3391,4541,2174, 420, 429,1728, 754,
+1228,2115,2219, 347,2223,2733, 735,1518,3003,2355,3134,1764,3948,3329,1888,2424,
+1001,1234,1972,3321,3363,1672,1021,1450,1584, 226, 765, 655,2526,3404,3244,2302,
+3665, 731, 594,2184, 319,1576, 621, 658,2656,4299,2099,3864,1279,2071,2598,2739,
+ 795,3086,3699,3908,1707,2352,2402,1382,3136,2475,1465,4847,3496,3865,1085,3004,
+2591,1084, 213,2287,1963,3565,2250, 822, 793,4574,3187,1772,1789,3050, 595,1484,
+1959,2770,1080,2650, 456, 422,2996, 940,3322,4328,4345,3092,2742, 965,2784, 739,
+4124, 952,1358,2498,2949,2565, 332,2698,2378, 660,2260,2473,4194,3856,2919, 535,
+1260,2651,1208,1428,1300,1949,1303,2942, 433,2455,2450,1251,1946, 614,1269, 641,
+1306,1810,2737,3078,2912, 564,2365,1419,1415,1497,4460,2367,2185,1379,3005,1307,
+3218,2175,1897,3063, 682,1157,4040,4005,1712,1160,1941,1399, 394, 402,2952,1573,
+1151,2986,2404, 862, 299,2033,1489,3006, 346, 171,2886,3401,1726,2932, 168,2533,
+ 47,2507,1030,3735,1145,3370,1395,1318,1579,3609,4560,2857,4116,1457,2529,1965,
+ 504,1036,2690,2988,2405, 745,5871, 849,2397,2056,3081, 863,2359,3857,2096, 99,
+1397,1769,2300,4428,1643,3455,1978,1757,3718,1440, 35,4879,3742,1296,4228,2280,
+ 160,5063,1599,2013, 166, 520,3479,1646,3345,3012, 490,1937,1545,1264,2182,2505,
+1096,1188,1369,1436,2421,1667,2792,2460,1270,2122, 727,3167,2143, 806,1706,1012,
+1800,3037, 960,2218,1882, 805, 139,2456,1139,1521, 851,1052,3093,3089, 342,2039,
+ 744,5097,1468,1502,1585,2087, 223, 939, 326,2140,2577, 892,2481,1623,4077, 982,
+3708, 135,2131, 87,2503,3114,2326,1106, 876,1616, 547,2997,2831,2093,3441,4530,
+4314, 9,3256,4229,4148, 659,1462,1986,1710,2046,2913,2231,4090,4880,5255,3392,
+3274,1368,3689,4645,1477, 705,3384,3635,1068,1529,2941,1458,3782,1509, 100,1656,
+2548, 718,2339, 408,1590,2780,3548,1838,4117,3719,1345,3530, 717,3442,2778,3220,
+2898,1892,4590,3614,3371,2043,1998,1224,3483, 891, 635, 584,2559,3355, 733,1766,
+1729,1172,3789,1891,2307, 781,2982,2271,1957,1580,5773,2633,2005,4195,3097,1535,
+3213,1189,1934,5693,3262, 586,3118,1324,1598, 517,1564,2217,1868,1893,4445,3728,
+2703,3139,1526,1787,1992,3882,2875,1549,1199,1056,2224,1904,2711,5098,4287, 338,
+1993,3129,3489,2689,1809,2815,1997, 957,1855,3898,2550,3275,3057,1105,1319, 627,
+1505,1911,1883,3526, 698,3629,3456,1833,1431, 746, 77,1261,2017,2296,1977,1885,
+ 125,1334,1600, 525,1798,1109,2222,1470,1945, 559,2236,1186,3443,2476,1929,1411,
+2411,3135,1777,3372,2621,1841,1613,3229, 668,1430,1839,2643,2916, 195,1989,2671,
+2358,1387, 629,3205,2293,5256,4439, 123,1310, 888,1879,4300,3021,3605,1003,1162,
+3192,2910,2010, 140,2395,2859, 55,1082,2012,2901, 662, 419,2081,1438, 680,2774,
+4654,3912,1620,1731,1625,5035,4065,2328, 512,1344, 802,5443,2163,2311,2537, 524,
+3399, 98,1155,2103,1918,2606,3925,2816,1393,2465,1504,3773,2177,3963,1478,4346,
+ 180,1113,4655,3461,2028,1698, 833,2696,1235,1322,1594,4408,3623,3013,3225,2040,
+3022, 541,2881, 607,3632,2029,1665,1219, 639,1385,1686,1099,2803,3231,1938,3188,
+2858, 427, 676,2772,1168,2025, 454,3253,2486,3556, 230,1950, 580, 791,1991,1280,
+1086,1974,2034, 630, 257,3338,2788,4903,1017, 86,4790, 966,2789,1995,1696,1131,
+ 259,3095,4188,1308, 179,1463,5257, 289,4107,1248, 42,3413,1725,2288, 896,1947,
+ 774,4474,4254, 604,3430,4264, 392,2514,2588, 452, 237,1408,3018, 988,4531,1970,
+3034,3310, 540,2370,1562,1288,2990, 502,4765,1147, 4,1853,2708, 207, 294,2814,
+4078,2902,2509, 684, 34,3105,3532,2551, 644, 709,2801,2344, 573,1727,3573,3557,
+2021,1081,3100,4315,2100,3681, 199,2263,1837,2385, 146,3484,1195,2776,3949, 997,
+1939,3973,1008,1091,1202,1962,1847,1149,4209,5444,1076, 493, 117,5400,2521, 972,
+1490,2934,1796,4542,2374,1512,2933,2657, 413,2888,1135,2762,2314,2156,1355,2369,
+ 766,2007,2527,2170,3124,2491,2593,2632,4757,2437, 234,3125,3591,1898,1750,1376,
+1942,3468,3138, 570,2127,2145,3276,4131, 962, 132,1445,4196, 19, 941,3624,3480,
+3366,1973,1374,4461,3431,2629, 283,2415,2275, 808,2887,3620,2112,2563,1353,3610,
+ 955,1089,3103,1053, 96, 88,4097, 823,3808,1583, 399, 292,4091,3313, 421,1128,
+ 642,4006, 903,2539,1877,2082, 596, 29,4066,1790, 722,2157, 130, 995,1569, 769,
+1485, 464, 513,2213, 288,1923,1101,2453,4316, 133, 486,2445, 50, 625, 487,2207,
+ 57, 423, 481,2962, 159,3729,1558, 491, 303, 482, 501, 240,2837, 112,3648,2392,
+1783, 362, 8,3433,3422, 610,2793,3277,1390,1284,1654, 21,3823, 734, 367, 623,
+ 193, 287, 374,1009,1483, 816, 476, 313,2255,2340,1262,2150,2899,1146,2581, 782,
+2116,1659,2018,1880, 255,3586,3314,1110,2867,2137,2564, 986,2767,5185,2006, 650,
+ 158, 926, 762, 881,3157,2717,2362,3587, 306,3690,3245,1542,3077,2427,1691,2478,
+2118,2985,3490,2438, 539,2305, 983, 129,1754, 355,4201,2386, 827,2923, 104,1773,
+2838,2771, 411,2905,3919, 376, 767, 122,1114, 828,2422,1817,3506, 266,3460,1007,
+1609,4998, 945,2612,4429,2274, 726,1247,1964,2914,2199,2070,4002,4108, 657,3323,
+1422, 579, 455,2764,4737,1222,2895,1670, 824,1223,1487,2525, 558, 861,3080, 598,
+2659,2515,1967, 752,2583,2376,2214,4180, 977, 704,2464,4999,2622,4109,1210,2961,
+ 819,1541, 142,2284, 44, 418, 457,1126,3730,4347,4626,1644,1876,3671,1864, 302,
+1063,5694, 624, 723,1984,3745,1314,1676,2488,1610,1449,3558,3569,2166,2098, 409,
+1011,2325,3704,2306, 818,1732,1383,1824,1844,3757, 999,2705,3497,1216,1423,2683,
+2426,2954,2501,2726,2229,1475,2554,5064,1971,1794,1666,2014,1343, 783, 724, 191,
+2434,1354,2220,5065,1763,2752,2472,4152, 131, 175,2885,3434, 92,1466,4920,2616,
+3871,3872,3866, 128,1551,1632, 669,1854,3682,4691,4125,1230, 188,2973,3290,1302,
+1213, 560,3266, 917, 763,3909,3249,1760, 868,1958, 764,1782,2097, 145,2277,3774,
+4462, 64,1491,3062, 971,2132,3606,2442, 221,1226,1617, 218, 323,1185,3207,3147,
+ 571, 619,1473,1005,1744,2281, 449,1887,2396,3685, 275, 375,3816,1743,3844,3731,
+ 845,1983,2350,4210,1377, 773, 967,3499,3052,3743,2725,4007,1697,1022,3943,1464,
+3264,2855,2722,1952,1029,2839,2467, 84,4383,2215, 820,1391,2015,2448,3672, 377,
+1948,2168, 797,2545,3536,2578,2645, 94,2874,1678, 405,1259,3071, 771, 546,1315,
+ 470,1243,3083, 895,2468, 981, 969,2037, 846,4181, 653,1276,2928, 14,2594, 557,
+3007,2474, 156, 902,1338,1740,2574, 537,2518, 973,2282,2216,2433,1928, 138,2903,
+1293,2631,1612, 646,3457, 839,2935, 111, 496,2191,2847, 589,3186, 149,3994,2060,
+4031,2641,4067,3145,1870, 37,3597,2136,1025,2051,3009,3383,3549,1121,1016,3261,
+1301, 251,2446,2599,2153, 872,3246, 637, 334,3705, 831, 884, 921,3065,3140,4092,
+2198,1944, 246,2964, 108,2045,1152,1921,2308,1031, 203,3173,4170,1907,3890, 810,
+1401,2003,1690, 506, 647,1242,2828,1761,1649,3208,2249,1589,3709,2931,5156,1708,
+ 498, 666,2613, 834,3817,1231, 184,2851,1124, 883,3197,2261,3710,1765,1553,2658,
+1178,2639,2351, 93,1193, 942,2538,2141,4402, 235,1821, 870,1591,2192,1709,1871,
+3341,1618,4126,2595,2334, 603, 651, 69, 701, 268,2662,3411,2555,1380,1606, 503,
+ 448, 254,2371,2646, 574,1187,2309,1770, 322,2235,1292,1801, 305, 566,1133, 229,
+2067,2057, 706, 167, 483,2002,2672,3295,1820,3561,3067, 316, 378,2746,3452,1112,
+ 136,1981, 507,1651,2917,1117, 285,4591, 182,2580,3522,1304, 335,3303,1835,2504,
+1795,1792,2248, 674,1018,2106,2449,1857,2292,2845, 976,3047,1781,2600,2727,1389,
+1281, 52,3152, 153, 265,3950, 672,3485,3951,4463, 430,1183, 365, 278,2169, 27,
+1407,1336,2304, 209,1340,1730,2202,1852,2403,2883, 979,1737,1062, 631,2829,2542,
+3876,2592, 825,2086,2226,3048,3625, 352,1417,3724, 542, 991, 431,1351,3938,1861,
+2294, 826,1361,2927,3142,3503,1738, 463,2462,2723, 582,1916,1595,2808, 400,3845,
+3891,2868,3621,2254, 58,2492,1123, 910,2160,2614,1372,1603,1196,1072,3385,1700,
+3267,1980, 696, 480,2430, 920, 799,1570,2920,1951,2041,4047,2540,1321,4223,2469,
+3562,2228,1271,2602, 401,2833,3351,2575,5157, 907,2312,1256, 410, 263,3507,1582,
+ 996, 678,1849,2316,1480, 908,3545,2237, 703,2322, 667,1826,2849,1531,2604,2999,
+2407,3146,2151,2630,1786,3711, 469,3542, 497,3899,2409, 858, 837,4446,3393,1274,
+ 786, 620,1845,2001,3311, 484, 308,3367,1204,1815,3691,2332,1532,2557,1842,2020,
+2724,1927,2333,4440, 567, 22,1673,2728,4475,1987,1858,1144,1597, 101,1832,3601,
+ 12, 974,3783,4391, 951,1412, 1,3720, 453,4608,4041, 528,1041,1027,3230,2628,
+1129, 875,1051,3291,1203,2262,1069,2860,2799,2149,2615,3278, 144,1758,3040, 31,
+ 475,1680, 366,2685,3184, 311,1642,4008,2466,5036,1593,1493,2809, 216,1420,1668,
+ 233, 304,2128,3284, 232,1429,1768,1040,2008,3407,2740,2967,2543, 242,2133, 778,
+1565,2022,2620, 505,2189,2756,1098,2273, 372,1614, 708, 553,2846,2094,2278, 169,
+3626,2835,4161, 228,2674,3165, 809,1454,1309, 466,1705,1095, 900,3423, 880,2667,
+3751,5258,2317,3109,2571,4317,2766,1503,1342, 866,4447,1118, 63,2076, 314,1881,
+1348,1061, 172, 978,3515,1747, 532, 511,3970, 6, 601, 905,2699,3300,1751, 276,
+1467,3725,2668, 65,4239,2544,2779,2556,1604, 578,2451,1802, 992,2331,2624,1320,
+3446, 713,1513,1013, 103,2786,2447,1661, 886,1702, 916, 654,3574,2031,1556, 751,
+2178,2821,2179,1498,1538,2176, 271, 914,2251,2080,1325, 638,1953,2937,3877,2432,
+2754, 95,3265,1716, 260,1227,4083, 775, 106,1357,3254, 426,1607, 555,2480, 772,
+1985, 244,2546, 474, 495,1046,2611,1851,2061, 71,2089,1675,2590, 742,3758,2843,
+3222,1433, 267,2180,2576,2826,2233,2092,3913,2435, 956,1745,3075, 856,2113,1116,
+ 451, 3,1988,2896,1398, 993,2463,1878,2049,1341,2718,2721,2870,2108, 712,2904,
+4363,2753,2324, 277,2872,2349,2649, 384, 987, 435, 691,3000, 922, 164,3939, 652,
+1500,1184,4153,2482,3373,2165,4848,2335,3775,3508,3154,2806,2830,1554,2102,1664,
+2530,1434,2408, 893,1547,2623,3447,2832,2242,2532,3169,2856,3223,2078, 49,3770,
+3469, 462, 318, 656,2259,3250,3069, 679,1629,2758, 344,1138,1104,3120,1836,1283,
+3115,2154,1437,4448, 934, 759,1999, 794,2862,1038, 533,2560,1722,2342, 855,2626,
+1197,1663,4476,3127, 85,4240,2528, 25,1111,1181,3673, 407,3470,4561,2679,2713,
+ 768,1925,2841,3986,1544,1165, 932, 373,1240,2146,1930,2673, 721,4766, 354,4333,
+ 391,2963, 187, 61,3364,1442,1102, 330,1940,1767, 341,3809,4118, 393,2496,2062,
+2211, 105, 331, 300, 439, 913,1332, 626, 379,3304,1557, 328, 689,3952, 309,1555,
+ 931, 317,2517,3027, 325, 569, 686,2107,3084, 60,1042,1333,2794, 264,3177,4014,
+1628, 258,3712, 7,4464,1176,1043,1778, 683, 114,1975, 78,1492, 383,1886, 510,
+ 386, 645,5291,2891,2069,3305,4138,3867,2939,2603,2493,1935,1066,1848,3588,1015,
+1282,1289,4609, 697,1453,3044,2666,3611,1856,2412, 54, 719,1330, 568,3778,2459,
+1748, 788, 492, 551,1191,1000, 488,3394,3763, 282,1799, 348,2016,1523,3155,2390,
+1049, 382,2019,1788,1170, 729,2968,3523, 897,3926,2785,2938,3292, 350,2319,3238,
+1718,1717,2655,3453,3143,4465, 161,2889,2980,2009,1421, 56,1908,1640,2387,2232,
+1917,1874,2477,4921, 148, 83,3438, 592,4245,2882,1822,1055, 741, 115,1496,1624,
+ 381,1638,4592,1020, 516,3214, 458, 947,4575,1432, 211,1514,2926,1865,2142, 189,
+ 852,1221,1400,1486, 882,2299,4036, 351, 28,1122, 700,6479,6480,6481,6482,6483, #last 512
+)
+# fmt: on
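
The header comment above frames this table as a char-to-frequency-order mapping whose 512 top-ranked characters cover roughly 79% of typical GB2312 text, and compares the typical distribution ratio against the random one. As an illustration of how such a table can drive a confidence score, here is a minimal sketch; it is not chardet's own implementation (that lives in chardistribution.py, outside this commit), and the function name and clamping value are assumptions made only for this example.

    # Illustrative only: turn a stream of FreqOrder indices into a rough
    # confidence score, following the "typical vs. random distribution ratio"
    # reasoning from the header comment above.
    FREQUENT_CUTOFF = 512  # the top-ranked characters cover ~79% of typical text

    def distribution_confidence(orders, table_size, typical_ratio):
        """orders: FreqOrder indices, one per decoded character (-1 if unknown)."""
        total = frequent = 0
        for order in orders:
            if order < 0 or order >= table_size:
                continue  # character not covered by the table
            total += 1
            if order < FREQUENT_CUTOFF:
                frequent += 1
        if total <= frequent:
            return 0.99 if total else 0.0
        # Compare the observed frequent/infrequent ratio against the typical
        # ratio for this encoding and clamp to just below certainty.
        return min(frequent / ((total - frequent) * typical_ratio), 0.99)

    # e.g. distribution_confidence(orders, GB2312_TABLE_SIZE,
    #                              GB2312_TYPICAL_DISTRIBUTION_RATIO)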
diff --git a/lib/chardet/gb2312prober.py b/lib/chardet/gb2312prober.py
new file mode 100644
index 0000000..251c042
--- /dev/null
+++ b/lib/chardet/gb2312prober.py
@@ -0,0 +1,47 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import GB2312DistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import GB2312_SM_MODEL
+
+
+class GB2312Prober(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(GB2312_SM_MODEL)
+ self.distribution_analyzer = GB2312DistributionAnalysis()
+ self.reset()
+
+ @property
+ def charset_name(self):
+ return "GB2312"
+
+ @property
+ def language(self):
+ return "Chinese"
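
The prober above only plugs the GB2312 coding state machine and distribution analyzer into MultiByteCharSetProber; feed(), state and get_confidence() are inherited from base classes that are not shown in this commit. Assuming that interface, a rough usage sketch might look like this (the import path and the input file name are illustrative only):

    from chardet.gb2312prober import GB2312Prober  # assumes lib/ is importable as "chardet"

    prober = GB2312Prober()
    with open("sample_gb2312.txt", "rb") as handle:  # hypothetical input file
        for chunk in iter(lambda: handle.read(4096), b""):
            prober.feed(chunk)  # inherited from the multi-byte base prober

    print(prober.charset_name, prober.language, prober.get_confidence())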
diff --git a/lib/chardet/hebrewprober.py b/lib/chardet/hebrewprober.py
new file mode 100644
index 0000000..3ca634b
--- /dev/null
+++ b/lib/chardet/hebrewprober.py
@@ -0,0 +1,302 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Shy Shalom
+# Portions created by the Initial Developer are Copyright (C) 2005
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .charsetprober import CharSetProber
+from .enums import ProbingState
+
+# This prober doesn't actually recognize a language or a charset.
+# It is a helper prober for the use of the Hebrew model probers
+
+### General ideas of the Hebrew charset recognition ###
+#
+# Four main charsets exist in Hebrew:
+# "ISO-8859-8" - Visual Hebrew
+# "windows-1255" - Logical Hebrew
+# "ISO-8859-8-I" - Logical Hebrew
+# "x-mac-hebrew" - ?? Logical Hebrew ??
+#
+# Both "ISO" charsets use a completely identical set of code points, whereas
+# "windows-1255" and "x-mac-hebrew" are two different proper supersets of
+# these code points. windows-1255 defines additional characters in the range
+# 0x80-0x9F as some misc punctuation marks as well as some Hebrew-specific
+# diacritics and additional 'Yiddish' ligature letters in the range 0xc0-0xd6.
+# x-mac-hebrew defines similar additional code points but with a different
+# mapping.
+#
+# As far as an average Hebrew text with no diacritics is concerned, all four
+# charsets are identical with respect to code points. Meaning that for the
+# main Hebrew alphabet, all four map the same values to all 27 Hebrew letters
+# (including final letters).
+#
+# The dominant difference between these charsets is their directionality.
+# "Visual" directionality means that the text is ordered as if the renderer is
+# not aware of a BIDI rendering algorithm. The renderer sees the text and
+# draws it from left to right. The text itself when ordered naturally is read
+# backwards. A buffer of Visual Hebrew generally looks like so:
+# "[last word of first line spelled backwards] [whole line ordered backwards
+# and spelled backwards] [first word of first line spelled backwards]
+# [end of line] [last word of second line] ... etc' "
+# Adding punctuation marks, numbers and English text to visual text is
+# naturally also "visual" and from left to right.
+#
+# "Logical" directionality means the text is ordered "naturally" according to
+# the order it is read. It is the responsibility of the renderer to display
+# the text from right to left. A BIDI algorithm is used to place general
+# punctuation marks, numbers and English text in the text.
+#
+# Texts in x-mac-hebrew are almost impossible to find on the Internet. From
+# what little evidence I could find, it seems that its general directionality
+# is Logical.
+#
+# To sum up all of the above, the Hebrew probing mechanism knows about two
+# charsets:
+# Visual Hebrew - "ISO-8859-8" - backwards text - Words and sentences are
+# backwards while line order is natural. For charset recognition purposes
+# the line order is unimportant (In fact, for this implementation, even
+# word order is unimportant).
+# Logical Hebrew - "windows-1255" - normal, naturally ordered text.
+#
+# "ISO-8859-8-I" is a subset of windows-1255 and doesn't need to be
+# specifically identified.
+# "x-mac-hebrew" is also identified as windows-1255. A text in x-mac-hebrew
+# that contains special punctuation marks or diacritics is displayed with
+# some unconverted characters showing as question marks. This problem might
+# be corrected using another model prober for x-mac-hebrew. Due to the fact
+# that x-mac-hebrew texts are so rare, writing another model prober isn't
+# worth the effort and performance hit.
+#
+#### The Prober ####
+#
+# The prober is divided between two SBCharSetProbers and a HebrewProber,
+# all of which are managed, created, fed data, inquired and deleted by the
+# SBCSGroupProber. The two SBCharSetProbers identify that the text is in
+# fact some kind of Hebrew, Logical or Visual. The final decision about which
+# one is it is made by the HebrewProber by combining final-letter scores
+# with the scores of the two SBCharSetProbers to produce a final answer.
+#
+# The SBCSGroupProber is responsible for stripping the original text of HTML
+# tags, English characters, numbers, low-ASCII punctuation characters, spaces
+# and new lines. It reduces any sequence of such characters to a single space.
+# The buffer fed to each prober in the SBCS group prober is pure text in
+# high-ASCII.
+# The two SBCharSetProbers (model probers) share the same language model:
+# Win1255Model.
+# The first SBCharSetProber uses the model normally as any other
+# SBCharSetProber does, to recognize windows-1255, upon which this model was
+# built. The second SBCharSetProber is told to make the pair-of-letter
+# lookup in the language model backwards. This in practice exactly simulates
+# a visual Hebrew model using the windows-1255 logical Hebrew model.
+#
+# The HebrewProber is not using any language model. All it does is look for
+# final-letter evidence suggesting the text is either logical Hebrew or visual
+# Hebrew. Disjointed from the model probers, the results of the HebrewProber
+# alone are meaningless. HebrewProber always returns 0.00 as confidence
+# since it never identifies a charset by itself. Instead, the pointer to the
+# HebrewProber is passed to the model probers as a helper "Name Prober".
+# When the Group prober receives a positive identification from any prober,
+# it asks for the name of the charset identified. If the prober queried is a
+# Hebrew model prober, the model prober forwards the call to the
+# HebrewProber to make the final decision. In the HebrewProber, the
+# decision is made according to the final-letter scores it maintains and both
+# model probers' scores. The answer is returned in the form of the name of the
+# charset identified, either "windows-1255" or "ISO-8859-8".
+
+
+class HebrewProber(CharSetProber):
+ # windows-1255 / ISO-8859-8 code points of interest
+ FINAL_KAF = 0xEA
+ NORMAL_KAF = 0xEB
+ FINAL_MEM = 0xED
+ NORMAL_MEM = 0xEE
+ FINAL_NUN = 0xEF
+ NORMAL_NUN = 0xF0
+ FINAL_PE = 0xF3
+ NORMAL_PE = 0xF4
+ FINAL_TSADI = 0xF5
+ NORMAL_TSADI = 0xF6
+
+ # Minimum Visual vs Logical final letter score difference.
+ # If the difference is below this, don't rely solely on the final letter score
+ # distance.
+ MIN_FINAL_CHAR_DISTANCE = 5
+
+ # Minimum Visual vs Logical model score difference.
+ # If the difference is below this, don't rely at all on the model score
+ # distance.
+ MIN_MODEL_DISTANCE = 0.01
+
+ VISUAL_HEBREW_NAME = "ISO-8859-8"
+ LOGICAL_HEBREW_NAME = "windows-1255"
+
+ def __init__(self):
+ super().__init__()
+ self._final_char_logical_score = None
+ self._final_char_visual_score = None
+ self._prev = None
+ self._before_prev = None
+ self._logical_prober = None
+ self._visual_prober = None
+ self.reset()
+
+ def reset(self):
+ self._final_char_logical_score = 0
+ self._final_char_visual_score = 0
+        # The two last characters seen in the previous buffer,
+        # mPrev and mBeforePrev are initialized to the space byte (0x20) in
+        # order to simulate a word delimiter at the beginning of the data
+        self._prev = 0x20
+        self._before_prev = 0x20
+ # These probers are owned by the group prober.
+
+ def set_model_probers(self, logical_prober, visual_prober):
+ self._logical_prober = logical_prober
+ self._visual_prober = visual_prober
+
+ def is_final(self, c):
+ return c in [
+ self.FINAL_KAF,
+ self.FINAL_MEM,
+ self.FINAL_NUN,
+ self.FINAL_PE,
+ self.FINAL_TSADI,
+ ]
+
+ def is_non_final(self, c):
+ # The normal Tsadi is not a good Non-Final letter due to words like
+ # 'lechotet' (to chat) containing an apostrophe after the tsadi. This
+ # apostrophe is converted to a space in FilterWithoutEnglishLetters
+ # causing the Non-Final tsadi to appear at an end of a word even
+ # though this is not the case in the original text.
+ # The letters Pe and Kaf rarely display a related behavior of not being
+ # a good Non-Final letter. Words like 'Pop', 'Winamp' and 'Mubarak'
+ # for example legally end with a Non-Final Pe or Kaf. However, the
+ # benefit of these letters as Non-Final letters outweighs the damage
+ # since these words are quite rare.
+ return c in [self.NORMAL_KAF, self.NORMAL_MEM, self.NORMAL_NUN, self.NORMAL_PE]
+
+ def feed(self, byte_str):
+ # Final letter analysis for logical-visual decision.
+ # Look for evidence that the received buffer is either logical Hebrew
+ # or visual Hebrew.
+ # The following cases are checked:
+ # 1) A word longer than 1 letter, ending with a final letter. This is
+ # an indication that the text is laid out "naturally" since the
+ # final letter really appears at the end. +1 for logical score.
+ # 2) A word longer than 1 letter, ending with a Non-Final letter. In
+ # normal Hebrew, words ending with Kaf, Mem, Nun, Pe or Tsadi,
+ # should not end with the Non-Final form of that letter. Exceptions
+ # to this rule are mentioned above in isNonFinal(). This is an
+ # indication that the text is laid out backwards. +1 for visual
+ # score
+ # 3) A word longer than 1 letter, starting with a final letter. Final
+ # letters should not appear at the beginning of a word. This is an
+ # indication that the text is laid out backwards. +1 for visual
+ # score.
+ #
+ # The visual score and logical score are accumulated throughout the
+ # text and are finally checked against each other in GetCharSetName().
+ # No checking for final letters in the middle of words is done since
+ # that case is not an indication for either Logical or Visual text.
+ #
+ # We automatically filter out all 7-bit characters (replace them with
+ # spaces) so the word boundary detection works properly. [MAP]
+
+ if self.state == ProbingState.NOT_ME:
+ # Both model probers say it's not them. No reason to continue.
+ return ProbingState.NOT_ME
+
+ byte_str = self.filter_high_byte_only(byte_str)
+
+ for cur in byte_str:
+            if cur == 0x20:
+                # We stand on a space - a word just ended
+                if self._before_prev != 0x20:
+ # next-to-last char was not a space so self._prev is not a
+ # 1 letter word
+ if self.is_final(self._prev):
+ # case (1) [-2:not space][-1:final letter][cur:space]
+ self._final_char_logical_score += 1
+ elif self.is_non_final(self._prev):
+ # case (2) [-2:not space][-1:Non-Final letter][
+ # cur:space]
+ self._final_char_visual_score += 1
+ else:
+ # Not standing on a space
+ if (
+                    (self._before_prev == 0x20)
+                    and (self.is_final(self._prev))
+                    and (cur != 0x20)
+ ):
+ # case (3) [-2:space][-1:final letter][cur:not space]
+ self._final_char_visual_score += 1
+ self._before_prev = self._prev
+ self._prev = cur
+
+ # Forever detecting, till the end or until both model probers return
+ # ProbingState.NOT_ME (handled above)
+ return ProbingState.DETECTING
+
+ @property
+ def charset_name(self):
+ # Make the decision: is it Logical or Visual?
+ # If the final letter score distance is dominant enough, rely on it.
+ finalsub = self._final_char_logical_score - self._final_char_visual_score
+ if finalsub >= self.MIN_FINAL_CHAR_DISTANCE:
+ return self.LOGICAL_HEBREW_NAME
+ if finalsub <= -self.MIN_FINAL_CHAR_DISTANCE:
+ return self.VISUAL_HEBREW_NAME
+
+ # It's not dominant enough, try to rely on the model scores instead.
+ modelsub = (
+ self._logical_prober.get_confidence() - self._visual_prober.get_confidence()
+ )
+ if modelsub > self.MIN_MODEL_DISTANCE:
+ return self.LOGICAL_HEBREW_NAME
+ if modelsub < -self.MIN_MODEL_DISTANCE:
+ return self.VISUAL_HEBREW_NAME
+
+ # Still no good, back to final letter distance, maybe it'll save the
+ # day.
+ if finalsub < 0.0:
+ return self.VISUAL_HEBREW_NAME
+
+ # (finalsub > 0 - Logical) or (don't know what to do) default to
+ # Logical.
+ return self.LOGICAL_HEBREW_NAME
+
+ @property
+ def language(self):
+ return "Hebrew"
+
+ @property
+ def state(self):
+ # Remain active as long as any of the model probers are active.
+ if (self._logical_prober.state == ProbingState.NOT_ME) and (
+ self._visual_prober.state == ProbingState.NOT_ME
+ ):
+ return ProbingState.NOT_ME
+ return ProbingState.DETECTING
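
The comments in feed() above spell out the three kinds of final-letter evidence: a final letter closing a multi-letter word points to logical Hebrew, while a non-final letter closing a word, or a final letter opening one, points to visual Hebrew. The standalone sketch below of that same scoring pass is included only for illustration and is not part of chardet; it takes raw byte values (space as 0x20), which is what filter_high_byte_only() produces.

    # Standalone illustration of the final-letter scoring (cases 1-3 above).
    FINAL_LETTERS = {0xEA, 0xED, 0xEF, 0xF3, 0xF5}   # final Kaf, Mem, Nun, Pe, Tsadi
    NON_FINAL_LETTERS = {0xEB, 0xEE, 0xF0, 0xF4}     # normal forms (Tsadi excluded)

    def score_final_letters(byte_str):
        logical = visual = 0
        before_prev, prev = 0x20, 0x20               # simulate a leading word break
        for cur in byte_str:
            if cur == 0x20:                          # a word just ended
                if before_prev != 0x20:              # ...and it was longer than one letter
                    if prev in FINAL_LETTERS:
                        logical += 1                 # case (1): final form at word end
                    elif prev in NON_FINAL_LETTERS:
                        visual += 1                  # case (2): normal form at word end
            elif before_prev == 0x20 and prev in FINAL_LETTERS:
                visual += 1                          # case (3): final form at word start
            before_prev, prev = prev, cur
        return logical, visual

When logical outscores visual by at least MIN_FINAL_CHAR_DISTANCE the text is taken as windows-1255, when visual dominates as ISO-8859-8, and otherwise the two model probers' confidences break the tie, as charset_name above does.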
diff --git a/lib/chardet/jisfreq.py b/lib/chardet/jisfreq.py
new file mode 100644
index 0000000..3293576
--- /dev/null
+++ b/lib/chardet/jisfreq.py
@@ -0,0 +1,325 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+# Sampled from about 20M of text material, including literature and computer technology
+#
+# Japanese frequency table, applied to both S-JIS and EUC-JP
+# Characters are ranked in order of frequency.
+
+# 128 --> 0.77094
+# 256 --> 0.85710
+# 512 --> 0.92635
+# 1024 --> 0.97130
+# 2048 --> 0.99431
+#
+# Ideal Distribution Ratio = 0.92635 / (1-0.92635) = 12.58
+# Random Distribution Ratio = 512 / (2965+62+83+86-512) = 0.191
+#
+# Typical Distribution Ratio, 25% of IDR
+
+JIS_TYPICAL_DISTRIBUTION_RATIO = 3.0
+
+# Char-to-FreqOrder table
+JIS_TABLE_SIZE = 4368
+
+# fmt: off
+JIS_CHAR_TO_FREQ_ORDER = (
+ 40, 1, 6, 182, 152, 180, 295,2127, 285, 381,3295,4304,3068,4606,3165,3510, # 16
+3511,1822,2785,4607,1193,2226,5070,4608, 171,2996,1247, 18, 179,5071, 856,1661, # 32
+1262,5072, 619, 127,3431,3512,3230,1899,1700, 232, 228,1294,1298, 284, 283,2041, # 48
+2042,1061,1062, 48, 49, 44, 45, 433, 434,1040,1041, 996, 787,2997,1255,4305, # 64
+2108,4609,1684,1648,5073,5074,5075,5076,5077,5078,3687,5079,4610,5080,3927,3928, # 80
+5081,3296,3432, 290,2285,1471,2187,5082,2580,2825,1303,2140,1739,1445,2691,3375, # 96
+1691,3297,4306,4307,4611, 452,3376,1182,2713,3688,3069,4308,5083,5084,5085,5086, # 112
+5087,5088,5089,5090,5091,5092,5093,5094,5095,5096,5097,5098,5099,5100,5101,5102, # 128
+5103,5104,5105,5106,5107,5108,5109,5110,5111,5112,4097,5113,5114,5115,5116,5117, # 144
+5118,5119,5120,5121,5122,5123,5124,5125,5126,5127,5128,5129,5130,5131,5132,5133, # 160
+5134,5135,5136,5137,5138,5139,5140,5141,5142,5143,5144,5145,5146,5147,5148,5149, # 176
+5150,5151,5152,4612,5153,5154,5155,5156,5157,5158,5159,5160,5161,5162,5163,5164, # 192
+5165,5166,5167,5168,5169,5170,5171,5172,5173,5174,5175,1472, 598, 618, 820,1205, # 208
+1309,1412,1858,1307,1692,5176,5177,5178,5179,5180,5181,5182,1142,1452,1234,1172, # 224
+1875,2043,2149,1793,1382,2973, 925,2404,1067,1241, 960,1377,2935,1491, 919,1217, # 240
+1865,2030,1406,1499,2749,4098,5183,5184,5185,5186,5187,5188,2561,4099,3117,1804, # 256
+2049,3689,4309,3513,1663,5189,3166,3118,3298,1587,1561,3433,5190,3119,1625,2998, # 272
+3299,4613,1766,3690,2786,4614,5191,5192,5193,5194,2161, 26,3377, 2,3929, 20, # 288
+3691, 47,4100, 50, 17, 16, 35, 268, 27, 243, 42, 155, 24, 154, 29, 184, # 304
+ 4, 91, 14, 92, 53, 396, 33, 289, 9, 37, 64, 620, 21, 39, 321, 5, # 320
+ 12, 11, 52, 13, 3, 208, 138, 0, 7, 60, 526, 141, 151,1069, 181, 275, # 336
+1591, 83, 132,1475, 126, 331, 829, 15, 69, 160, 59, 22, 157, 55,1079, 312, # 352
+ 109, 38, 23, 25, 10, 19, 79,5195, 61, 382,1124, 8, 30,5196,5197,5198, # 368
+5199,5200,5201,5202,5203,5204,5205,5206, 89, 62, 74, 34,2416, 112, 139, 196, # 384
+ 271, 149, 84, 607, 131, 765, 46, 88, 153, 683, 76, 874, 101, 258, 57, 80, # 400
+ 32, 364, 121,1508, 169,1547, 68, 235, 145,2999, 41, 360,3027, 70, 63, 31, # 416
+ 43, 259, 262,1383, 99, 533, 194, 66, 93, 846, 217, 192, 56, 106, 58, 565, # 432
+ 280, 272, 311, 256, 146, 82, 308, 71, 100, 128, 214, 655, 110, 261, 104,1140, # 448
+ 54, 51, 36, 87, 67,3070, 185,2618,2936,2020, 28,1066,2390,2059,5207,5208, # 464
+5209,5210,5211,5212,5213,5214,5215,5216,4615,5217,5218,5219,5220,5221,5222,5223, # 480
+5224,5225,5226,5227,5228,5229,5230,5231,5232,5233,5234,5235,5236,3514,5237,5238, # 496
+5239,5240,5241,5242,5243,5244,2297,2031,4616,4310,3692,5245,3071,5246,3598,5247, # 512
+4617,3231,3515,5248,4101,4311,4618,3808,4312,4102,5249,4103,4104,3599,5250,5251, # 528
+5252,5253,5254,5255,5256,5257,5258,5259,5260,5261,5262,5263,5264,5265,5266,5267, # 544
+5268,5269,5270,5271,5272,5273,5274,5275,5276,5277,5278,5279,5280,5281,5282,5283, # 560
+5284,5285,5286,5287,5288,5289,5290,5291,5292,5293,5294,5295,5296,5297,5298,5299, # 576
+5300,5301,5302,5303,5304,5305,5306,5307,5308,5309,5310,5311,5312,5313,5314,5315, # 592
+5316,5317,5318,5319,5320,5321,5322,5323,5324,5325,5326,5327,5328,5329,5330,5331, # 608
+5332,5333,5334,5335,5336,5337,5338,5339,5340,5341,5342,5343,5344,5345,5346,5347, # 624
+5348,5349,5350,5351,5352,5353,5354,5355,5356,5357,5358,5359,5360,5361,5362,5363, # 640
+5364,5365,5366,5367,5368,5369,5370,5371,5372,5373,5374,5375,5376,5377,5378,5379, # 656
+5380,5381, 363, 642,2787,2878,2788,2789,2316,3232,2317,3434,2011, 165,1942,3930, # 672
+3931,3932,3933,5382,4619,5383,4620,5384,5385,5386,5387,5388,5389,5390,5391,5392, # 688
+5393,5394,5395,5396,5397,5398,5399,5400,5401,5402,5403,5404,5405,5406,5407,5408, # 704
+5409,5410,5411,5412,5413,5414,5415,5416,5417,5418,5419,5420,5421,5422,5423,5424, # 720
+5425,5426,5427,5428,5429,5430,5431,5432,5433,5434,5435,5436,5437,5438,5439,5440, # 736
+5441,5442,5443,5444,5445,5446,5447,5448,5449,5450,5451,5452,5453,5454,5455,5456, # 752
+5457,5458,5459,5460,5461,5462,5463,5464,5465,5466,5467,5468,5469,5470,5471,5472, # 768
+5473,5474,5475,5476,5477,5478,5479,5480,5481,5482,5483,5484,5485,5486,5487,5488, # 784
+5489,5490,5491,5492,5493,5494,5495,5496,5497,5498,5499,5500,5501,5502,5503,5504, # 800
+5505,5506,5507,5508,5509,5510,5511,5512,5513,5514,5515,5516,5517,5518,5519,5520, # 816
+5521,5522,5523,5524,5525,5526,5527,5528,5529,5530,5531,5532,5533,5534,5535,5536, # 832
+5537,5538,5539,5540,5541,5542,5543,5544,5545,5546,5547,5548,5549,5550,5551,5552, # 848
+5553,5554,5555,5556,5557,5558,5559,5560,5561,5562,5563,5564,5565,5566,5567,5568, # 864
+5569,5570,5571,5572,5573,5574,5575,5576,5577,5578,5579,5580,5581,5582,5583,5584, # 880
+5585,5586,5587,5588,5589,5590,5591,5592,5593,5594,5595,5596,5597,5598,5599,5600, # 896
+5601,5602,5603,5604,5605,5606,5607,5608,5609,5610,5611,5612,5613,5614,5615,5616, # 912
+5617,5618,5619,5620,5621,5622,5623,5624,5625,5626,5627,5628,5629,5630,5631,5632, # 928
+5633,5634,5635,5636,5637,5638,5639,5640,5641,5642,5643,5644,5645,5646,5647,5648, # 944
+5649,5650,5651,5652,5653,5654,5655,5656,5657,5658,5659,5660,5661,5662,5663,5664, # 960
+5665,5666,5667,5668,5669,5670,5671,5672,5673,5674,5675,5676,5677,5678,5679,5680, # 976
+5681,5682,5683,5684,5685,5686,5687,5688,5689,5690,5691,5692,5693,5694,5695,5696, # 992
+5697,5698,5699,5700,5701,5702,5703,5704,5705,5706,5707,5708,5709,5710,5711,5712, # 1008
+5713,5714,5715,5716,5717,5718,5719,5720,5721,5722,5723,5724,5725,5726,5727,5728, # 1024
+5729,5730,5731,5732,5733,5734,5735,5736,5737,5738,5739,5740,5741,5742,5743,5744, # 1040
+5745,5746,5747,5748,5749,5750,5751,5752,5753,5754,5755,5756,5757,5758,5759,5760, # 1056
+5761,5762,5763,5764,5765,5766,5767,5768,5769,5770,5771,5772,5773,5774,5775,5776, # 1072
+5777,5778,5779,5780,5781,5782,5783,5784,5785,5786,5787,5788,5789,5790,5791,5792, # 1088
+5793,5794,5795,5796,5797,5798,5799,5800,5801,5802,5803,5804,5805,5806,5807,5808, # 1104
+5809,5810,5811,5812,5813,5814,5815,5816,5817,5818,5819,5820,5821,5822,5823,5824, # 1120
+5825,5826,5827,5828,5829,5830,5831,5832,5833,5834,5835,5836,5837,5838,5839,5840, # 1136
+5841,5842,5843,5844,5845,5846,5847,5848,5849,5850,5851,5852,5853,5854,5855,5856, # 1152
+5857,5858,5859,5860,5861,5862,5863,5864,5865,5866,5867,5868,5869,5870,5871,5872, # 1168
+5873,5874,5875,5876,5877,5878,5879,5880,5881,5882,5883,5884,5885,5886,5887,5888, # 1184
+5889,5890,5891,5892,5893,5894,5895,5896,5897,5898,5899,5900,5901,5902,5903,5904, # 1200
+5905,5906,5907,5908,5909,5910,5911,5912,5913,5914,5915,5916,5917,5918,5919,5920, # 1216
+5921,5922,5923,5924,5925,5926,5927,5928,5929,5930,5931,5932,5933,5934,5935,5936, # 1232
+5937,5938,5939,5940,5941,5942,5943,5944,5945,5946,5947,5948,5949,5950,5951,5952, # 1248
+5953,5954,5955,5956,5957,5958,5959,5960,5961,5962,5963,5964,5965,5966,5967,5968, # 1264
+5969,5970,5971,5972,5973,5974,5975,5976,5977,5978,5979,5980,5981,5982,5983,5984, # 1280
+5985,5986,5987,5988,5989,5990,5991,5992,5993,5994,5995,5996,5997,5998,5999,6000, # 1296
+6001,6002,6003,6004,6005,6006,6007,6008,6009,6010,6011,6012,6013,6014,6015,6016, # 1312
+6017,6018,6019,6020,6021,6022,6023,6024,6025,6026,6027,6028,6029,6030,6031,6032, # 1328
+6033,6034,6035,6036,6037,6038,6039,6040,6041,6042,6043,6044,6045,6046,6047,6048, # 1344
+6049,6050,6051,6052,6053,6054,6055,6056,6057,6058,6059,6060,6061,6062,6063,6064, # 1360
+6065,6066,6067,6068,6069,6070,6071,6072,6073,6074,6075,6076,6077,6078,6079,6080, # 1376
+6081,6082,6083,6084,6085,6086,6087,6088,6089,6090,6091,6092,6093,6094,6095,6096, # 1392
+6097,6098,6099,6100,6101,6102,6103,6104,6105,6106,6107,6108,6109,6110,6111,6112, # 1408
+6113,6114,2044,2060,4621, 997,1235, 473,1186,4622, 920,3378,6115,6116, 379,1108, # 1424
+4313,2657,2735,3934,6117,3809, 636,3233, 573,1026,3693,3435,2974,3300,2298,4105, # 1440
+ 854,2937,2463, 393,2581,2417, 539, 752,1280,2750,2480, 140,1161, 440, 708,1569, # 1456
+ 665,2497,1746,1291,1523,3000, 164,1603, 847,1331, 537,1997, 486, 508,1693,2418, # 1472
+1970,2227, 878,1220, 299,1030, 969, 652,2751, 624,1137,3301,2619, 65,3302,2045, # 1488
+1761,1859,3120,1930,3694,3516, 663,1767, 852, 835,3695, 269, 767,2826,2339,1305, # 1504
+ 896,1150, 770,1616,6118, 506,1502,2075,1012,2519, 775,2520,2975,2340,2938,4314, # 1520
+3028,2086,1224,1943,2286,6119,3072,4315,2240,1273,1987,3935,1557, 175, 597, 985, # 1536
+3517,2419,2521,1416,3029, 585, 938,1931,1007,1052,1932,1685,6120,3379,4316,4623, # 1552
+ 804, 599,3121,1333,2128,2539,1159,1554,2032,3810, 687,2033,2904, 952, 675,1467, # 1568
+3436,6121,2241,1096,1786,2440,1543,1924, 980,1813,2228, 781,2692,1879, 728,1918, # 1584
+3696,4624, 548,1950,4625,1809,1088,1356,3303,2522,1944, 502, 972, 373, 513,2827, # 1600
+ 586,2377,2391,1003,1976,1631,6122,2464,1084, 648,1776,4626,2141, 324, 962,2012, # 1616
+2177,2076,1384, 742,2178,1448,1173,1810, 222, 102, 301, 445, 125,2420, 662,2498, # 1632
+ 277, 200,1476,1165,1068, 224,2562,1378,1446, 450,1880, 659, 791, 582,4627,2939, # 1648
+3936,1516,1274, 555,2099,3697,1020,1389,1526,3380,1762,1723,1787,2229, 412,2114, # 1664
+1900,2392,3518, 512,2597, 427,1925,2341,3122,1653,1686,2465,2499, 697, 330, 273, # 1680
+ 380,2162, 951, 832, 780, 991,1301,3073, 965,2270,3519, 668,2523,2636,1286, 535, # 1696
+1407, 518, 671, 957,2658,2378, 267, 611,2197,3030,6123, 248,2299, 967,1799,2356, # 1712
+ 850,1418,3437,1876,1256,1480,2828,1718,6124,6125,1755,1664,2405,6126,4628,2879, # 1728
+2829, 499,2179, 676,4629, 557,2329,2214,2090, 325,3234, 464, 811,3001, 992,2342, # 1744
+2481,1232,1469, 303,2242, 466,1070,2163, 603,1777,2091,4630,2752,4631,2714, 322, # 1760
+2659,1964,1768, 481,2188,1463,2330,2857,3600,2092,3031,2421,4632,2318,2070,1849, # 1776
+2598,4633,1302,2254,1668,1701,2422,3811,2905,3032,3123,2046,4106,1763,1694,4634, # 1792
+1604, 943,1724,1454, 917, 868,2215,1169,2940, 552,1145,1800,1228,1823,1955, 316, # 1808
+1080,2510, 361,1807,2830,4107,2660,3381,1346,1423,1134,4108,6127, 541,1263,1229, # 1824
+1148,2540, 545, 465,1833,2880,3438,1901,3074,2482, 816,3937, 713,1788,2500, 122, # 1840
+1575, 195,1451,2501,1111,6128, 859, 374,1225,2243,2483,4317, 390,1033,3439,3075, # 1856
+2524,1687, 266, 793,1440,2599, 946, 779, 802, 507, 897,1081, 528,2189,1292, 711, # 1872
+1866,1725,1167,1640, 753, 398,2661,1053, 246, 348,4318, 137,1024,3440,1600,2077, # 1888
+2129, 825,4319, 698, 238, 521, 187,2300,1157,2423,1641,1605,1464,1610,1097,2541, # 1904
+1260,1436, 759,2255,1814,2150, 705,3235, 409,2563,3304, 561,3033,2005,2564, 726, # 1920
+1956,2343,3698,4109, 949,3812,3813,3520,1669, 653,1379,2525, 881,2198, 632,2256, # 1936
+1027, 778,1074, 733,1957, 514,1481,2466, 554,2180, 702,3938,1606,1017,1398,6129, # 1952
+1380,3521, 921, 993,1313, 594, 449,1489,1617,1166, 768,1426,1360, 495,1794,3601, # 1968
+1177,3602,1170,4320,2344, 476, 425,3167,4635,3168,1424, 401,2662,1171,3382,1998, # 1984
+1089,4110, 477,3169, 474,6130,1909, 596,2831,1842, 494, 693,1051,1028,1207,3076, # 2000
+ 606,2115, 727,2790,1473,1115, 743,3522, 630, 805,1532,4321,2021, 366,1057, 838, # 2016
+ 684,1114,2142,4322,2050,1492,1892,1808,2271,3814,2424,1971,1447,1373,3305,1090, # 2032
+1536,3939,3523,3306,1455,2199, 336, 369,2331,1035, 584,2393, 902, 718,2600,6131, # 2048
+2753, 463,2151,1149,1611,2467, 715,1308,3124,1268, 343,1413,3236,1517,1347,2663, # 2064
+2093,3940,2022,1131,1553,2100,2941,1427,3441,2942,1323,2484,6132,1980, 872,2368, # 2080
+2441,2943, 320,2369,2116,1082, 679,1933,3941,2791,3815, 625,1143,2023, 422,2200, # 2096
+3816,6133, 730,1695, 356,2257,1626,2301,2858,2637,1627,1778, 937, 883,2906,2693, # 2112
+3002,1769,1086, 400,1063,1325,3307,2792,4111,3077, 456,2345,1046, 747,6134,1524, # 2128
+ 884,1094,3383,1474,2164,1059, 974,1688,2181,2258,1047, 345,1665,1187, 358, 875, # 2144
+3170, 305, 660,3524,2190,1334,1135,3171,1540,1649,2542,1527, 927, 968,2793, 885, # 2160
+1972,1850, 482, 500,2638,1218,1109,1085,2543,1654,2034, 876, 78,2287,1482,1277, # 2176
+ 861,1675,1083,1779, 724,2754, 454, 397,1132,1612,2332, 893, 672,1237, 257,2259, # 2192
+2370, 135,3384, 337,2244, 547, 352, 340, 709,2485,1400, 788,1138,2511, 540, 772, # 2208
+1682,2260,2272,2544,2013,1843,1902,4636,1999,1562,2288,4637,2201,1403,1533, 407, # 2224
+ 576,3308,1254,2071, 978,3385, 170, 136,1201,3125,2664,3172,2394, 213, 912, 873, # 2240
+3603,1713,2202, 699,3604,3699, 813,3442, 493, 531,1054, 468,2907,1483, 304, 281, # 2256
+4112,1726,1252,2094, 339,2319,2130,2639, 756,1563,2944, 748, 571,2976,1588,2425, # 2272
+2715,1851,1460,2426,1528,1392,1973,3237, 288,3309, 685,3386, 296, 892,2716,2216, # 2288
+1570,2245, 722,1747,2217, 905,3238,1103,6135,1893,1441,1965, 251,1805,2371,3700, # 2304
+2601,1919,1078, 75,2182,1509,1592,1270,2640,4638,2152,6136,3310,3817, 524, 706, # 2320
+1075, 292,3818,1756,2602, 317, 98,3173,3605,3525,1844,2218,3819,2502, 814, 567, # 2336
+ 385,2908,1534,6137, 534,1642,3239, 797,6138,1670,1529, 953,4323, 188,1071, 538, # 2352
+ 178, 729,3240,2109,1226,1374,2000,2357,2977, 731,2468,1116,2014,2051,6139,1261, # 2368
+1593, 803,2859,2736,3443, 556, 682, 823,1541,6140,1369,2289,1706,2794, 845, 462, # 2384
+2603,2665,1361, 387, 162,2358,1740, 739,1770,1720,1304,1401,3241,1049, 627,1571, # 2400
+2427,3526,1877,3942,1852,1500, 431,1910,1503, 677, 297,2795, 286,1433,1038,1198, # 2416
+2290,1133,1596,4113,4639,2469,1510,1484,3943,6141,2442, 108, 712,4640,2372, 866, # 2432
+3701,2755,3242,1348, 834,1945,1408,3527,2395,3243,1811, 824, 994,1179,2110,1548, # 2448
+1453, 790,3003, 690,4324,4325,2832,2909,3820,1860,3821, 225,1748, 310, 346,1780, # 2464
+2470, 821,1993,2717,2796, 828, 877,3528,2860,2471,1702,2165,2910,2486,1789, 453, # 2480
+ 359,2291,1676, 73,1164,1461,1127,3311, 421, 604, 314,1037, 589, 116,2487, 737, # 2496
+ 837,1180, 111, 244, 735,6142,2261,1861,1362, 986, 523, 418, 581,2666,3822, 103, # 2512
+ 855, 503,1414,1867,2488,1091, 657,1597, 979, 605,1316,4641,1021,2443,2078,2001, # 2528
+1209, 96, 587,2166,1032, 260,1072,2153, 173, 94, 226,3244, 819,2006,4642,4114, # 2544
+2203, 231,1744, 782, 97,2667, 786,3387, 887, 391, 442,2219,4326,1425,6143,2694, # 2560
+ 633,1544,1202, 483,2015, 592,2052,1958,2472,1655, 419, 129,4327,3444,3312,1714, # 2576
+1257,3078,4328,1518,1098, 865,1310,1019,1885,1512,1734, 469,2444, 148, 773, 436, # 2592
+1815,1868,1128,1055,4329,1245,2756,3445,2154,1934,1039,4643, 579,1238, 932,2320, # 2608
+ 353, 205, 801, 115,2428, 944,2321,1881, 399,2565,1211, 678, 766,3944, 335,2101, # 2624
+1459,1781,1402,3945,2737,2131,1010, 844, 981,1326,1013, 550,1816,1545,2620,1335, # 2640
+1008, 371,2881, 936,1419,1613,3529,1456,1395,2273,1834,2604,1317,2738,2503, 416, # 2656
+1643,4330, 806,1126, 229, 591,3946,1314,1981,1576,1837,1666, 347,1790, 977,3313, # 2672
+ 764,2861,1853, 688,2429,1920,1462, 77, 595, 415,2002,3034, 798,1192,4115,6144, # 2688
+2978,4331,3035,2695,2582,2072,2566, 430,2430,1727, 842,1396,3947,3702, 613, 377, # 2704
+ 278, 236,1417,3388,3314,3174, 757,1869, 107,3530,6145,1194, 623,2262, 207,1253, # 2720
+2167,3446,3948, 492,1117,1935, 536,1838,2757,1246,4332, 696,2095,2406,1393,1572, # 2736
+3175,1782, 583, 190, 253,1390,2230, 830,3126,3389, 934,3245,1703,1749,2979,1870, # 2752
+2545,1656,2204, 869,2346,4116,3176,1817, 496,1764,4644, 942,1504, 404,1903,1122, # 2768
+1580,3606,2945,1022, 515, 372,1735, 955,2431,3036,6146,2797,1110,2302,2798, 617, # 2784
+6147, 441, 762,1771,3447,3607,3608,1904, 840,3037, 86, 939,1385, 572,1370,2445, # 2800
+1336, 114,3703, 898, 294, 203,3315, 703,1583,2274, 429, 961,4333,1854,1951,3390, # 2816
+2373,3704,4334,1318,1381, 966,1911,2322,1006,1155, 309, 989, 458,2718,1795,1372, # 2832
+1203, 252,1689,1363,3177, 517,1936, 168,1490, 562, 193,3823,1042,4117,1835, 551, # 2848
+ 470,4645, 395, 489,3448,1871,1465,2583,2641, 417,1493, 279,1295, 511,1236,1119, # 2864
+ 72,1231,1982,1812,3004, 871,1564, 984,3449,1667,2696,2096,4646,2347,2833,1673, # 2880
+3609, 695,3246,2668, 807,1183,4647, 890, 388,2333,1801,1457,2911,1765,1477,1031, # 2896
+3316,3317,1278,3391,2799,2292,2526, 163,3450,4335,2669,1404,1802,6148,2323,2407, # 2912
+1584,1728,1494,1824,1269, 298, 909,3318,1034,1632, 375, 776,1683,2061, 291, 210, # 2928
+1123, 809,1249,1002,2642,3038, 206,1011,2132, 144, 975, 882,1565, 342, 667, 754, # 2944
+1442,2143,1299,2303,2062, 447, 626,2205,1221,2739,2912,1144,1214,2206,2584, 760, # 2960
+1715, 614, 950,1281,2670,2621, 810, 577,1287,2546,4648, 242,2168, 250,2643, 691, # 2976
+ 123,2644, 647, 313,1029, 689,1357,2946,1650, 216, 771,1339,1306, 808,2063, 549, # 2992
+ 913,1371,2913,2914,6149,1466,1092,1174,1196,1311,2605,2396,1783,1796,3079, 406, # 3008
+2671,2117,3949,4649, 487,1825,2220,6150,2915, 448,2348,1073,6151,2397,1707, 130, # 3024
+ 900,1598, 329, 176,1959,2527,1620,6152,2275,4336,3319,1983,2191,3705,3610,2155, # 3040
+3706,1912,1513,1614,6153,1988, 646, 392,2304,1589,3320,3039,1826,1239,1352,1340, # 3056
+2916, 505,2567,1709,1437,2408,2547, 906,6154,2672, 384,1458,1594,1100,1329, 710, # 3072
+ 423,3531,2064,2231,2622,1989,2673,1087,1882, 333, 841,3005,1296,2882,2379, 580, # 3088
+1937,1827,1293,2585, 601, 574, 249,1772,4118,2079,1120, 645, 901,1176,1690, 795, # 3104
+2207, 478,1434, 516,1190,1530, 761,2080, 930,1264, 355, 435,1552, 644,1791, 987, # 3120
+ 220,1364,1163,1121,1538, 306,2169,1327,1222, 546,2645, 218, 241, 610,1704,3321, # 3136
+1984,1839,1966,2528, 451,6155,2586,3707,2568, 907,3178, 254,2947, 186,1845,4650, # 3152
+ 745, 432,1757, 428,1633, 888,2246,2221,2489,3611,2118,1258,1265, 956,3127,1784, # 3168
+4337,2490, 319, 510, 119, 457,3612, 274,2035,2007,4651,1409,3128, 970,2758, 590, # 3184
+2800, 661,2247,4652,2008,3950,1420,1549,3080,3322,3951,1651,1375,2111, 485,2491, # 3200
+1429,1156,6156,2548,2183,1495, 831,1840,2529,2446, 501,1657, 307,1894,3247,1341, # 3216
+ 666, 899,2156,1539,2549,1559, 886, 349,2208,3081,2305,1736,3824,2170,2759,1014, # 3232
+1913,1386, 542,1397,2948, 490, 368, 716, 362, 159, 282,2569,1129,1658,1288,1750, # 3248
+2674, 276, 649,2016, 751,1496, 658,1818,1284,1862,2209,2087,2512,3451, 622,2834, # 3264
+ 376, 117,1060,2053,1208,1721,1101,1443, 247,1250,3179,1792,3952,2760,2398,3953, # 3280
+6157,2144,3708, 446,2432,1151,2570,3452,2447,2761,2835,1210,2448,3082, 424,2222, # 3296
+1251,2449,2119,2836, 504,1581,4338, 602, 817, 857,3825,2349,2306, 357,3826,1470, # 3312
+1883,2883, 255, 958, 929,2917,3248, 302,4653,1050,1271,1751,2307,1952,1430,2697, # 3328
+2719,2359, 354,3180, 777, 158,2036,4339,1659,4340,4654,2308,2949,2248,1146,2232, # 3344
+3532,2720,1696,2623,3827,6158,3129,1550,2698,1485,1297,1428, 637, 931,2721,2145, # 3360
+ 914,2550,2587, 81,2450, 612, 827,2646,1242,4655,1118,2884, 472,1855,3181,3533, # 3376
+3534, 569,1353,2699,1244,1758,2588,4119,2009,2762,2171,3709,1312,1531,6159,1152, # 3392
+1938, 134,1830, 471,3710,2276,1112,1535,3323,3453,3535, 982,1337,2950, 488, 826, # 3408
+ 674,1058,1628,4120,2017, 522,2399, 211, 568,1367,3454, 350, 293,1872,1139,3249, # 3424
+1399,1946,3006,1300,2360,3324, 588, 736,6160,2606, 744, 669,3536,3828,6161,1358, # 3440
+ 199, 723, 848, 933, 851,1939,1505,1514,1338,1618,1831,4656,1634,3613, 443,2740, # 3456
+3829, 717,1947, 491,1914,6162,2551,1542,4121,1025,6163,1099,1223, 198,3040,2722, # 3472
+ 370, 410,1905,2589, 998,1248,3182,2380, 519,1449,4122,1710, 947, 928,1153,4341, # 3488
+2277, 344,2624,1511, 615, 105, 161,1212,1076,1960,3130,2054,1926,1175,1906,2473, # 3504
+ 414,1873,2801,6164,2309, 315,1319,3325, 318,2018,2146,2157, 963, 631, 223,4342, # 3520
+4343,2675, 479,3711,1197,2625,3712,2676,2361,6165,4344,4123,6166,2451,3183,1886, # 3536
+2184,1674,1330,1711,1635,1506, 799, 219,3250,3083,3954,1677,3713,3326,2081,3614, # 3552
+1652,2073,4657,1147,3041,1752, 643,1961, 147,1974,3955,6167,1716,2037, 918,3007, # 3568
+1994, 120,1537, 118, 609,3184,4345, 740,3455,1219, 332,1615,3830,6168,1621,2980, # 3584
+1582, 783, 212, 553,2350,3714,1349,2433,2082,4124, 889,6169,2310,1275,1410, 973, # 3600
+ 166,1320,3456,1797,1215,3185,2885,1846,2590,2763,4658, 629, 822,3008, 763, 940, # 3616
+1990,2862, 439,2409,1566,1240,1622, 926,1282,1907,2764, 654,2210,1607, 327,1130, # 3632
+3956,1678,1623,6170,2434,2192, 686, 608,3831,3715, 903,3957,3042,6171,2741,1522, # 3648
+1915,1105,1555,2552,1359, 323,3251,4346,3457, 738,1354,2553,2311,2334,1828,2003, # 3664
+3832,1753,2351,1227,6172,1887,4125,1478,6173,2410,1874,1712,1847, 520,1204,2607, # 3680
+ 264,4659, 836,2677,2102, 600,4660,3833,2278,3084,6174,4347,3615,1342, 640, 532, # 3696
+ 543,2608,1888,2400,2591,1009,4348,1497, 341,1737,3616,2723,1394, 529,3252,1321, # 3712
+ 983,4661,1515,2120, 971,2592, 924, 287,1662,3186,4349,2700,4350,1519, 908,1948, # 3728
+2452, 156, 796,1629,1486,2223,2055, 694,4126,1259,1036,3392,1213,2249,2742,1889, # 3744
+1230,3958,1015, 910, 408, 559,3617,4662, 746, 725, 935,4663,3959,3009,1289, 563, # 3760
+ 867,4664,3960,1567,2981,2038,2626, 988,2263,2381,4351, 143,2374, 704,1895,6175, # 3776
+1188,3716,2088, 673,3085,2362,4352, 484,1608,1921,2765,2918, 215, 904,3618,3537, # 3792
+ 894, 509, 976,3043,2701,3961,4353,2837,2982, 498,6176,6177,1102,3538,1332,3393, # 3808
+1487,1636,1637, 233, 245,3962, 383, 650, 995,3044, 460,1520,1206,2352, 749,3327, # 3824
+ 530, 700, 389,1438,1560,1773,3963,2264, 719,2951,2724,3834, 870,1832,1644,1000, # 3840
+ 839,2474,3717, 197,1630,3394, 365,2886,3964,1285,2133, 734, 922, 818,1106, 732, # 3856
+ 480,2083,1774,3458, 923,2279,1350, 221,3086, 85,2233,2234,3835,1585,3010,2147, # 3872
+1387,1705,2382,1619,2475, 133, 239,2802,1991,1016,2084,2383, 411,2838,1113, 651, # 3888
+1985,1160,3328, 990,1863,3087,1048,1276,2647, 265,2627,1599,3253,2056, 150, 638, # 3904
+2019, 656, 853, 326,1479, 680,1439,4354,1001,1759, 413,3459,3395,2492,1431, 459, # 3920
+4355,1125,3329,2265,1953,1450,2065,2863, 849, 351,2678,3131,3254,3255,1104,1577, # 3936
+ 227,1351,1645,2453,2193,1421,2887, 812,2121, 634, 95,2435, 201,2312,4665,1646, # 3952
+1671,2743,1601,2554,2702,2648,2280,1315,1366,2089,3132,1573,3718,3965,1729,1189, # 3968
+ 328,2679,1077,1940,1136, 558,1283, 964,1195, 621,2074,1199,1743,3460,3619,1896, # 3984
+1916,1890,3836,2952,1154,2112,1064, 862, 378,3011,2066,2113,2803,1568,2839,6178, # 4000
+3088,2919,1941,1660,2004,1992,2194, 142, 707,1590,1708,1624,1922,1023,1836,1233, # 4016
+1004,2313, 789, 741,3620,6179,1609,2411,1200,4127,3719,3720,4666,2057,3721, 593, # 4032
+2840, 367,2920,1878,6180,3461,1521, 628,1168, 692,2211,2649, 300, 720,2067,2571, # 4048
+2953,3396, 959,2504,3966,3539,3462,1977, 701,6181, 954,1043, 800, 681, 183,3722, # 4064
+1803,1730,3540,4128,2103, 815,2314, 174, 467, 230,2454,1093,2134, 755,3541,3397, # 4080
+1141,1162,6182,1738,2039, 270,3256,2513,1005,1647,2185,3837, 858,1679,1897,1719, # 4096
+2954,2324,1806, 402, 670, 167,4129,1498,2158,2104, 750,6183, 915, 189,1680,1551, # 4112
+ 455,4356,1501,2455, 405,1095,2955, 338,1586,1266,1819, 570, 641,1324, 237,1556, # 4128
+2650,1388,3723,6184,1368,2384,1343,1978,3089,2436, 879,3724, 792,1191, 758,3012, # 4144
+1411,2135,1322,4357, 240,4667,1848,3725,1574,6185, 420,3045,1546,1391, 714,4358, # 4160
+1967, 941,1864, 863, 664, 426, 560,1731,2680,1785,2864,1949,2363, 403,3330,1415, # 4176
+1279,2136,1697,2335, 204, 721,2097,3838, 90,6186,2085,2505, 191,3967, 124,2148, # 4192
+1376,1798,1178,1107,1898,1405, 860,4359,1243,1272,2375,2983,1558,2456,1638, 113, # 4208
+3621, 578,1923,2609, 880, 386,4130, 784,2186,2266,1422,2956,2172,1722, 497, 263, # 4224
+2514,1267,2412,2610, 177,2703,3542, 774,1927,1344, 616,1432,1595,1018, 172,4360, # 4240
+2325, 911,4361, 438,1468,3622, 794,3968,2024,2173,1681,1829,2957, 945, 895,3090, # 4256
+ 575,2212,2476, 475,2401,2681, 785,2744,1745,2293,2555,1975,3133,2865, 394,4668, # 4272
+3839, 635,4131, 639, 202,1507,2195,2766,1345,1435,2572,3726,1908,1184,1181,2457, # 4288
+3727,3134,4362, 843,2611, 437, 916,4669, 234, 769,1884,3046,3047,3623, 833,6187, # 4304
+1639,2250,2402,1355,1185,2010,2047, 999, 525,1732,1290,1488,2612, 948,1578,3728, # 4320
+2413,2477,1216,2725,2159, 334,3840,1328,3624,2921,1525,4132, 564,1056, 891,4363, # 4336
+1444,1698,2385,2251,3729,1365,2281,2235,1717,6188, 864,3841,2515, 444, 527,2767, # 4352
+2922,3625, 544, 461,6189, 566, 209,2437,3398,2098,1065,2068,3331,3626,3257,2137, # 4368 #last 512
+)
+# fmt: on
diff --git a/lib/chardet/johabfreq.py b/lib/chardet/johabfreq.py
new file mode 100644
index 0000000..c129699
--- /dev/null
+++ b/lib/chardet/johabfreq.py
@@ -0,0 +1,2382 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+# The frequency data itself is the same as for euc-kr.
+# This is just a mapping table from Johab code points to the euc-kr order.
+
+JOHAB_TO_EUCKR_ORDER_TABLE = {
+ 0x8861: 0,
+ 0x8862: 1,
+ 0x8865: 2,
+ 0x8868: 3,
+ 0x8869: 4,
+ 0x886A: 5,
+ 0x886B: 6,
+ 0x8871: 7,
+ 0x8873: 8,
+ 0x8874: 9,
+ 0x8875: 10,
+ 0x8876: 11,
+ 0x8877: 12,
+ 0x8878: 13,
+ 0x8879: 14,
+ 0x887B: 15,
+ 0x887C: 16,
+ 0x887D: 17,
+ 0x8881: 18,
+ 0x8882: 19,
+ 0x8885: 20,
+ 0x8889: 21,
+ 0x8891: 22,
+ 0x8893: 23,
+ 0x8895: 24,
+ 0x8896: 25,
+ 0x8897: 26,
+ 0x88A1: 27,
+ 0x88A2: 28,
+ 0x88A5: 29,
+ 0x88A9: 30,
+ 0x88B5: 31,
+ 0x88B7: 32,
+ 0x88C1: 33,
+ 0x88C5: 34,
+ 0x88C9: 35,
+ 0x88E1: 36,
+ 0x88E2: 37,
+ 0x88E5: 38,
+ 0x88E8: 39,
+ 0x88E9: 40,
+ 0x88EB: 41,
+ 0x88F1: 42,
+ 0x88F3: 43,
+ 0x88F5: 44,
+ 0x88F6: 45,
+ 0x88F7: 46,
+ 0x88F8: 47,
+ 0x88FB: 48,
+ 0x88FC: 49,
+ 0x88FD: 50,
+ 0x8941: 51,
+ 0x8945: 52,
+ 0x8949: 53,
+ 0x8951: 54,
+ 0x8953: 55,
+ 0x8955: 56,
+ 0x8956: 57,
+ 0x8957: 58,
+ 0x8961: 59,
+ 0x8962: 60,
+ 0x8963: 61,
+ 0x8965: 62,
+ 0x8968: 63,
+ 0x8969: 64,
+ 0x8971: 65,
+ 0x8973: 66,
+ 0x8975: 67,
+ 0x8976: 68,
+ 0x8977: 69,
+ 0x897B: 70,
+ 0x8981: 71,
+ 0x8985: 72,
+ 0x8989: 73,
+ 0x8993: 74,
+ 0x8995: 75,
+ 0x89A1: 76,
+ 0x89A2: 77,
+ 0x89A5: 78,
+ 0x89A8: 79,
+ 0x89A9: 80,
+ 0x89AB: 81,
+ 0x89AD: 82,
+ 0x89B0: 83,
+ 0x89B1: 84,
+ 0x89B3: 85,
+ 0x89B5: 86,
+ 0x89B7: 87,
+ 0x89B8: 88,
+ 0x89C1: 89,
+ 0x89C2: 90,
+ 0x89C5: 91,
+ 0x89C9: 92,
+ 0x89CB: 93,
+ 0x89D1: 94,
+ 0x89D3: 95,
+ 0x89D5: 96,
+ 0x89D7: 97,
+ 0x89E1: 98,
+ 0x89E5: 99,
+ 0x89E9: 100,
+ 0x89F3: 101,
+ 0x89F6: 102,
+ 0x89F7: 103,
+ 0x8A41: 104,
+ 0x8A42: 105,
+ 0x8A45: 106,
+ 0x8A49: 107,
+ 0x8A51: 108,
+ 0x8A53: 109,
+ 0x8A55: 110,
+ 0x8A57: 111,
+ 0x8A61: 112,
+ 0x8A65: 113,
+ 0x8A69: 114,
+ 0x8A73: 115,
+ 0x8A75: 116,
+ 0x8A81: 117,
+ 0x8A82: 118,
+ 0x8A85: 119,
+ 0x8A88: 120,
+ 0x8A89: 121,
+ 0x8A8A: 122,
+ 0x8A8B: 123,
+ 0x8A90: 124,
+ 0x8A91: 125,
+ 0x8A93: 126,
+ 0x8A95: 127,
+ 0x8A97: 128,
+ 0x8A98: 129,
+ 0x8AA1: 130,
+ 0x8AA2: 131,
+ 0x8AA5: 132,
+ 0x8AA9: 133,
+ 0x8AB6: 134,
+ 0x8AB7: 135,
+ 0x8AC1: 136,
+ 0x8AD5: 137,
+ 0x8AE1: 138,
+ 0x8AE2: 139,
+ 0x8AE5: 140,
+ 0x8AE9: 141,
+ 0x8AF1: 142,
+ 0x8AF3: 143,
+ 0x8AF5: 144,
+ 0x8B41: 145,
+ 0x8B45: 146,
+ 0x8B49: 147,
+ 0x8B61: 148,
+ 0x8B62: 149,
+ 0x8B65: 150,
+ 0x8B68: 151,
+ 0x8B69: 152,
+ 0x8B6A: 153,
+ 0x8B71: 154,
+ 0x8B73: 155,
+ 0x8B75: 156,
+ 0x8B77: 157,
+ 0x8B81: 158,
+ 0x8BA1: 159,
+ 0x8BA2: 160,
+ 0x8BA5: 161,
+ 0x8BA8: 162,
+ 0x8BA9: 163,
+ 0x8BAB: 164,
+ 0x8BB1: 165,
+ 0x8BB3: 166,
+ 0x8BB5: 167,
+ 0x8BB7: 168,
+ 0x8BB8: 169,
+ 0x8BBC: 170,
+ 0x8C61: 171,
+ 0x8C62: 172,
+ 0x8C63: 173,
+ 0x8C65: 174,
+ 0x8C69: 175,
+ 0x8C6B: 176,
+ 0x8C71: 177,
+ 0x8C73: 178,
+ 0x8C75: 179,
+ 0x8C76: 180,
+ 0x8C77: 181,
+ 0x8C7B: 182,
+ 0x8C81: 183,
+ 0x8C82: 184,
+ 0x8C85: 185,
+ 0x8C89: 186,
+ 0x8C91: 187,
+ 0x8C93: 188,
+ 0x8C95: 189,
+ 0x8C96: 190,
+ 0x8C97: 191,
+ 0x8CA1: 192,
+ 0x8CA2: 193,
+ 0x8CA9: 194,
+ 0x8CE1: 195,
+ 0x8CE2: 196,
+ 0x8CE3: 197,
+ 0x8CE5: 198,
+ 0x8CE9: 199,
+ 0x8CF1: 200,
+ 0x8CF3: 201,
+ 0x8CF5: 202,
+ 0x8CF6: 203,
+ 0x8CF7: 204,
+ 0x8D41: 205,
+ 0x8D42: 206,
+ 0x8D45: 207,
+ 0x8D51: 208,
+ 0x8D55: 209,
+ 0x8D57: 210,
+ 0x8D61: 211,
+ 0x8D65: 212,
+ 0x8D69: 213,
+ 0x8D75: 214,
+ 0x8D76: 215,
+ 0x8D7B: 216,
+ 0x8D81: 217,
+ 0x8DA1: 218,
+ 0x8DA2: 219,
+ 0x8DA5: 220,
+ 0x8DA7: 221,
+ 0x8DA9: 222,
+ 0x8DB1: 223,
+ 0x8DB3: 224,
+ 0x8DB5: 225,
+ 0x8DB7: 226,
+ 0x8DB8: 227,
+ 0x8DB9: 228,
+ 0x8DC1: 229,
+ 0x8DC2: 230,
+ 0x8DC9: 231,
+ 0x8DD6: 232,
+ 0x8DD7: 233,
+ 0x8DE1: 234,
+ 0x8DE2: 235,
+ 0x8DF7: 236,
+ 0x8E41: 237,
+ 0x8E45: 238,
+ 0x8E49: 239,
+ 0x8E51: 240,
+ 0x8E53: 241,
+ 0x8E57: 242,
+ 0x8E61: 243,
+ 0x8E81: 244,
+ 0x8E82: 245,
+ 0x8E85: 246,
+ 0x8E89: 247,
+ 0x8E90: 248,
+ 0x8E91: 249,
+ 0x8E93: 250,
+ 0x8E95: 251,
+ 0x8E97: 252,
+ 0x8E98: 253,
+ 0x8EA1: 254,
+ 0x8EA9: 255,
+ 0x8EB6: 256,
+ 0x8EB7: 257,
+ 0x8EC1: 258,
+ 0x8EC2: 259,
+ 0x8EC5: 260,
+ 0x8EC9: 261,
+ 0x8ED1: 262,
+ 0x8ED3: 263,
+ 0x8ED6: 264,
+ 0x8EE1: 265,
+ 0x8EE5: 266,
+ 0x8EE9: 267,
+ 0x8EF1: 268,
+ 0x8EF3: 269,
+ 0x8F41: 270,
+ 0x8F61: 271,
+ 0x8F62: 272,
+ 0x8F65: 273,
+ 0x8F67: 274,
+ 0x8F69: 275,
+ 0x8F6B: 276,
+ 0x8F70: 277,
+ 0x8F71: 278,
+ 0x8F73: 279,
+ 0x8F75: 280,
+ 0x8F77: 281,
+ 0x8F7B: 282,
+ 0x8FA1: 283,
+ 0x8FA2: 284,
+ 0x8FA5: 285,
+ 0x8FA9: 286,
+ 0x8FB1: 287,
+ 0x8FB3: 288,
+ 0x8FB5: 289,
+ 0x8FB7: 290,
+ 0x9061: 291,
+ 0x9062: 292,
+ 0x9063: 293,
+ 0x9065: 294,
+ 0x9068: 295,
+ 0x9069: 296,
+ 0x906A: 297,
+ 0x906B: 298,
+ 0x9071: 299,
+ 0x9073: 300,
+ 0x9075: 301,
+ 0x9076: 302,
+ 0x9077: 303,
+ 0x9078: 304,
+ 0x9079: 305,
+ 0x907B: 306,
+ 0x907D: 307,
+ 0x9081: 308,
+ 0x9082: 309,
+ 0x9085: 310,
+ 0x9089: 311,
+ 0x9091: 312,
+ 0x9093: 313,
+ 0x9095: 314,
+ 0x9096: 315,
+ 0x9097: 316,
+ 0x90A1: 317,
+ 0x90A2: 318,
+ 0x90A5: 319,
+ 0x90A9: 320,
+ 0x90B1: 321,
+ 0x90B7: 322,
+ 0x90E1: 323,
+ 0x90E2: 324,
+ 0x90E4: 325,
+ 0x90E5: 326,
+ 0x90E9: 327,
+ 0x90EB: 328,
+ 0x90EC: 329,
+ 0x90F1: 330,
+ 0x90F3: 331,
+ 0x90F5: 332,
+ 0x90F6: 333,
+ 0x90F7: 334,
+ 0x90FD: 335,
+ 0x9141: 336,
+ 0x9142: 337,
+ 0x9145: 338,
+ 0x9149: 339,
+ 0x9151: 340,
+ 0x9153: 341,
+ 0x9155: 342,
+ 0x9156: 343,
+ 0x9157: 344,
+ 0x9161: 345,
+ 0x9162: 346,
+ 0x9165: 347,
+ 0x9169: 348,
+ 0x9171: 349,
+ 0x9173: 350,
+ 0x9176: 351,
+ 0x9177: 352,
+ 0x917A: 353,
+ 0x9181: 354,
+ 0x9185: 355,
+ 0x91A1: 356,
+ 0x91A2: 357,
+ 0x91A5: 358,
+ 0x91A9: 359,
+ 0x91AB: 360,
+ 0x91B1: 361,
+ 0x91B3: 362,
+ 0x91B5: 363,
+ 0x91B7: 364,
+ 0x91BC: 365,
+ 0x91BD: 366,
+ 0x91C1: 367,
+ 0x91C5: 368,
+ 0x91C9: 369,
+ 0x91D6: 370,
+ 0x9241: 371,
+ 0x9245: 372,
+ 0x9249: 373,
+ 0x9251: 374,
+ 0x9253: 375,
+ 0x9255: 376,
+ 0x9261: 377,
+ 0x9262: 378,
+ 0x9265: 379,
+ 0x9269: 380,
+ 0x9273: 381,
+ 0x9275: 382,
+ 0x9277: 383,
+ 0x9281: 384,
+ 0x9282: 385,
+ 0x9285: 386,
+ 0x9288: 387,
+ 0x9289: 388,
+ 0x9291: 389,
+ 0x9293: 390,
+ 0x9295: 391,
+ 0x9297: 392,
+ 0x92A1: 393,
+ 0x92B6: 394,
+ 0x92C1: 395,
+ 0x92E1: 396,
+ 0x92E5: 397,
+ 0x92E9: 398,
+ 0x92F1: 399,
+ 0x92F3: 400,
+ 0x9341: 401,
+ 0x9342: 402,
+ 0x9349: 403,
+ 0x9351: 404,
+ 0x9353: 405,
+ 0x9357: 406,
+ 0x9361: 407,
+ 0x9362: 408,
+ 0x9365: 409,
+ 0x9369: 410,
+ 0x936A: 411,
+ 0x936B: 412,
+ 0x9371: 413,
+ 0x9373: 414,
+ 0x9375: 415,
+ 0x9377: 416,
+ 0x9378: 417,
+ 0x937C: 418,
+ 0x9381: 419,
+ 0x9385: 420,
+ 0x9389: 421,
+ 0x93A1: 422,
+ 0x93A2: 423,
+ 0x93A5: 424,
+ 0x93A9: 425,
+ 0x93AB: 426,
+ 0x93B1: 427,
+ 0x93B3: 428,
+ 0x93B5: 429,
+ 0x93B7: 430,
+ 0x93BC: 431,
+ 0x9461: 432,
+ 0x9462: 433,
+ 0x9463: 434,
+ 0x9465: 435,
+ 0x9468: 436,
+ 0x9469: 437,
+ 0x946A: 438,
+ 0x946B: 439,
+ 0x946C: 440,
+ 0x9470: 441,
+ 0x9471: 442,
+ 0x9473: 443,
+ 0x9475: 444,
+ 0x9476: 445,
+ 0x9477: 446,
+ 0x9478: 447,
+ 0x9479: 448,
+ 0x947D: 449,
+ 0x9481: 450,
+ 0x9482: 451,
+ 0x9485: 452,
+ 0x9489: 453,
+ 0x9491: 454,
+ 0x9493: 455,
+ 0x9495: 456,
+ 0x9496: 457,
+ 0x9497: 458,
+ 0x94A1: 459,
+ 0x94E1: 460,
+ 0x94E2: 461,
+ 0x94E3: 462,
+ 0x94E5: 463,
+ 0x94E8: 464,
+ 0x94E9: 465,
+ 0x94EB: 466,
+ 0x94EC: 467,
+ 0x94F1: 468,
+ 0x94F3: 469,
+ 0x94F5: 470,
+ 0x94F7: 471,
+ 0x94F9: 472,
+ 0x94FC: 473,
+ 0x9541: 474,
+ 0x9542: 475,
+ 0x9545: 476,
+ 0x9549: 477,
+ 0x9551: 478,
+ 0x9553: 479,
+ 0x9555: 480,
+ 0x9556: 481,
+ 0x9557: 482,
+ 0x9561: 483,
+ 0x9565: 484,
+ 0x9569: 485,
+ 0x9576: 486,
+ 0x9577: 487,
+ 0x9581: 488,
+ 0x9585: 489,
+ 0x95A1: 490,
+ 0x95A2: 491,
+ 0x95A5: 492,
+ 0x95A8: 493,
+ 0x95A9: 494,
+ 0x95AB: 495,
+ 0x95AD: 496,
+ 0x95B1: 497,
+ 0x95B3: 498,
+ 0x95B5: 499,
+ 0x95B7: 500,
+ 0x95B9: 501,
+ 0x95BB: 502,
+ 0x95C1: 503,
+ 0x95C5: 504,
+ 0x95C9: 505,
+ 0x95E1: 506,
+ 0x95F6: 507,
+ 0x9641: 508,
+ 0x9645: 509,
+ 0x9649: 510,
+ 0x9651: 511,
+ 0x9653: 512,
+ 0x9655: 513,
+ 0x9661: 514,
+ 0x9681: 515,
+ 0x9682: 516,
+ 0x9685: 517,
+ 0x9689: 518,
+ 0x9691: 519,
+ 0x9693: 520,
+ 0x9695: 521,
+ 0x9697: 522,
+ 0x96A1: 523,
+ 0x96B6: 524,
+ 0x96C1: 525,
+ 0x96D7: 526,
+ 0x96E1: 527,
+ 0x96E5: 528,
+ 0x96E9: 529,
+ 0x96F3: 530,
+ 0x96F5: 531,
+ 0x96F7: 532,
+ 0x9741: 533,
+ 0x9745: 534,
+ 0x9749: 535,
+ 0x9751: 536,
+ 0x9757: 537,
+ 0x9761: 538,
+ 0x9762: 539,
+ 0x9765: 540,
+ 0x9768: 541,
+ 0x9769: 542,
+ 0x976B: 543,
+ 0x9771: 544,
+ 0x9773: 545,
+ 0x9775: 546,
+ 0x9777: 547,
+ 0x9781: 548,
+ 0x97A1: 549,
+ 0x97A2: 550,
+ 0x97A5: 551,
+ 0x97A8: 552,
+ 0x97A9: 553,
+ 0x97B1: 554,
+ 0x97B3: 555,
+ 0x97B5: 556,
+ 0x97B6: 557,
+ 0x97B7: 558,
+ 0x97B8: 559,
+ 0x9861: 560,
+ 0x9862: 561,
+ 0x9865: 562,
+ 0x9869: 563,
+ 0x9871: 564,
+ 0x9873: 565,
+ 0x9875: 566,
+ 0x9876: 567,
+ 0x9877: 568,
+ 0x987D: 569,
+ 0x9881: 570,
+ 0x9882: 571,
+ 0x9885: 572,
+ 0x9889: 573,
+ 0x9891: 574,
+ 0x9893: 575,
+ 0x9895: 576,
+ 0x9896: 577,
+ 0x9897: 578,
+ 0x98E1: 579,
+ 0x98E2: 580,
+ 0x98E5: 581,
+ 0x98E9: 582,
+ 0x98EB: 583,
+ 0x98EC: 584,
+ 0x98F1: 585,
+ 0x98F3: 586,
+ 0x98F5: 587,
+ 0x98F6: 588,
+ 0x98F7: 589,
+ 0x98FD: 590,
+ 0x9941: 591,
+ 0x9942: 592,
+ 0x9945: 593,
+ 0x9949: 594,
+ 0x9951: 595,
+ 0x9953: 596,
+ 0x9955: 597,
+ 0x9956: 598,
+ 0x9957: 599,
+ 0x9961: 600,
+ 0x9976: 601,
+ 0x99A1: 602,
+ 0x99A2: 603,
+ 0x99A5: 604,
+ 0x99A9: 605,
+ 0x99B7: 606,
+ 0x99C1: 607,
+ 0x99C9: 608,
+ 0x99E1: 609,
+ 0x9A41: 610,
+ 0x9A45: 611,
+ 0x9A81: 612,
+ 0x9A82: 613,
+ 0x9A85: 614,
+ 0x9A89: 615,
+ 0x9A90: 616,
+ 0x9A91: 617,
+ 0x9A97: 618,
+ 0x9AC1: 619,
+ 0x9AE1: 620,
+ 0x9AE5: 621,
+ 0x9AE9: 622,
+ 0x9AF1: 623,
+ 0x9AF3: 624,
+ 0x9AF7: 625,
+ 0x9B61: 626,
+ 0x9B62: 627,
+ 0x9B65: 628,
+ 0x9B68: 629,
+ 0x9B69: 630,
+ 0x9B71: 631,
+ 0x9B73: 632,
+ 0x9B75: 633,
+ 0x9B81: 634,
+ 0x9B85: 635,
+ 0x9B89: 636,
+ 0x9B91: 637,
+ 0x9B93: 638,
+ 0x9BA1: 639,
+ 0x9BA5: 640,
+ 0x9BA9: 641,
+ 0x9BB1: 642,
+ 0x9BB3: 643,
+ 0x9BB5: 644,
+ 0x9BB7: 645,
+ 0x9C61: 646,
+ 0x9C62: 647,
+ 0x9C65: 648,
+ 0x9C69: 649,
+ 0x9C71: 650,
+ 0x9C73: 651,
+ 0x9C75: 652,
+ 0x9C76: 653,
+ 0x9C77: 654,
+ 0x9C78: 655,
+ 0x9C7C: 656,
+ 0x9C7D: 657,
+ 0x9C81: 658,
+ 0x9C82: 659,
+ 0x9C85: 660,
+ 0x9C89: 661,
+ 0x9C91: 662,
+ 0x9C93: 663,
+ 0x9C95: 664,
+ 0x9C96: 665,
+ 0x9C97: 666,
+ 0x9CA1: 667,
+ 0x9CA2: 668,
+ 0x9CA5: 669,
+ 0x9CB5: 670,
+ 0x9CB7: 671,
+ 0x9CE1: 672,
+ 0x9CE2: 673,
+ 0x9CE5: 674,
+ 0x9CE9: 675,
+ 0x9CF1: 676,
+ 0x9CF3: 677,
+ 0x9CF5: 678,
+ 0x9CF6: 679,
+ 0x9CF7: 680,
+ 0x9CFD: 681,
+ 0x9D41: 682,
+ 0x9D42: 683,
+ 0x9D45: 684,
+ 0x9D49: 685,
+ 0x9D51: 686,
+ 0x9D53: 687,
+ 0x9D55: 688,
+ 0x9D57: 689,
+ 0x9D61: 690,
+ 0x9D62: 691,
+ 0x9D65: 692,
+ 0x9D69: 693,
+ 0x9D71: 694,
+ 0x9D73: 695,
+ 0x9D75: 696,
+ 0x9D76: 697,
+ 0x9D77: 698,
+ 0x9D81: 699,
+ 0x9D85: 700,
+ 0x9D93: 701,
+ 0x9D95: 702,
+ 0x9DA1: 703,
+ 0x9DA2: 704,
+ 0x9DA5: 705,
+ 0x9DA9: 706,
+ 0x9DB1: 707,
+ 0x9DB3: 708,
+ 0x9DB5: 709,
+ 0x9DB7: 710,
+ 0x9DC1: 711,
+ 0x9DC5: 712,
+ 0x9DD7: 713,
+ 0x9DF6: 714,
+ 0x9E41: 715,
+ 0x9E45: 716,
+ 0x9E49: 717,
+ 0x9E51: 718,
+ 0x9E53: 719,
+ 0x9E55: 720,
+ 0x9E57: 721,
+ 0x9E61: 722,
+ 0x9E65: 723,
+ 0x9E69: 724,
+ 0x9E73: 725,
+ 0x9E75: 726,
+ 0x9E77: 727,
+ 0x9E81: 728,
+ 0x9E82: 729,
+ 0x9E85: 730,
+ 0x9E89: 731,
+ 0x9E91: 732,
+ 0x9E93: 733,
+ 0x9E95: 734,
+ 0x9E97: 735,
+ 0x9EA1: 736,
+ 0x9EB6: 737,
+ 0x9EC1: 738,
+ 0x9EE1: 739,
+ 0x9EE2: 740,
+ 0x9EE5: 741,
+ 0x9EE9: 742,
+ 0x9EF1: 743,
+ 0x9EF5: 744,
+ 0x9EF7: 745,
+ 0x9F41: 746,
+ 0x9F42: 747,
+ 0x9F45: 748,
+ 0x9F49: 749,
+ 0x9F51: 750,
+ 0x9F53: 751,
+ 0x9F55: 752,
+ 0x9F57: 753,
+ 0x9F61: 754,
+ 0x9F62: 755,
+ 0x9F65: 756,
+ 0x9F69: 757,
+ 0x9F71: 758,
+ 0x9F73: 759,
+ 0x9F75: 760,
+ 0x9F77: 761,
+ 0x9F78: 762,
+ 0x9F7B: 763,
+ 0x9F7C: 764,
+ 0x9FA1: 765,
+ 0x9FA2: 766,
+ 0x9FA5: 767,
+ 0x9FA9: 768,
+ 0x9FB1: 769,
+ 0x9FB3: 770,
+ 0x9FB5: 771,
+ 0x9FB7: 772,
+ 0xA061: 773,
+ 0xA062: 774,
+ 0xA065: 775,
+ 0xA067: 776,
+ 0xA068: 777,
+ 0xA069: 778,
+ 0xA06A: 779,
+ 0xA06B: 780,
+ 0xA071: 781,
+ 0xA073: 782,
+ 0xA075: 783,
+ 0xA077: 784,
+ 0xA078: 785,
+ 0xA07B: 786,
+ 0xA07D: 787,
+ 0xA081: 788,
+ 0xA082: 789,
+ 0xA085: 790,
+ 0xA089: 791,
+ 0xA091: 792,
+ 0xA093: 793,
+ 0xA095: 794,
+ 0xA096: 795,
+ 0xA097: 796,
+ 0xA098: 797,
+ 0xA0A1: 798,
+ 0xA0A2: 799,
+ 0xA0A9: 800,
+ 0xA0B7: 801,
+ 0xA0E1: 802,
+ 0xA0E2: 803,
+ 0xA0E5: 804,
+ 0xA0E9: 805,
+ 0xA0EB: 806,
+ 0xA0F1: 807,
+ 0xA0F3: 808,
+ 0xA0F5: 809,
+ 0xA0F7: 810,
+ 0xA0F8: 811,
+ 0xA0FD: 812,
+ 0xA141: 813,
+ 0xA142: 814,
+ 0xA145: 815,
+ 0xA149: 816,
+ 0xA151: 817,
+ 0xA153: 818,
+ 0xA155: 819,
+ 0xA156: 820,
+ 0xA157: 821,
+ 0xA161: 822,
+ 0xA162: 823,
+ 0xA165: 824,
+ 0xA169: 825,
+ 0xA175: 826,
+ 0xA176: 827,
+ 0xA177: 828,
+ 0xA179: 829,
+ 0xA181: 830,
+ 0xA1A1: 831,
+ 0xA1A2: 832,
+ 0xA1A4: 833,
+ 0xA1A5: 834,
+ 0xA1A9: 835,
+ 0xA1AB: 836,
+ 0xA1B1: 837,
+ 0xA1B3: 838,
+ 0xA1B5: 839,
+ 0xA1B7: 840,
+ 0xA1C1: 841,
+ 0xA1C5: 842,
+ 0xA1D6: 843,
+ 0xA1D7: 844,
+ 0xA241: 845,
+ 0xA245: 846,
+ 0xA249: 847,
+ 0xA253: 848,
+ 0xA255: 849,
+ 0xA257: 850,
+ 0xA261: 851,
+ 0xA265: 852,
+ 0xA269: 853,
+ 0xA273: 854,
+ 0xA275: 855,
+ 0xA281: 856,
+ 0xA282: 857,
+ 0xA283: 858,
+ 0xA285: 859,
+ 0xA288: 860,
+ 0xA289: 861,
+ 0xA28A: 862,
+ 0xA28B: 863,
+ 0xA291: 864,
+ 0xA293: 865,
+ 0xA295: 866,
+ 0xA297: 867,
+ 0xA29B: 868,
+ 0xA29D: 869,
+ 0xA2A1: 870,
+ 0xA2A5: 871,
+ 0xA2A9: 872,
+ 0xA2B3: 873,
+ 0xA2B5: 874,
+ 0xA2C1: 875,
+ 0xA2E1: 876,
+ 0xA2E5: 877,
+ 0xA2E9: 878,
+ 0xA341: 879,
+ 0xA345: 880,
+ 0xA349: 881,
+ 0xA351: 882,
+ 0xA355: 883,
+ 0xA361: 884,
+ 0xA365: 885,
+ 0xA369: 886,
+ 0xA371: 887,
+ 0xA375: 888,
+ 0xA3A1: 889,
+ 0xA3A2: 890,
+ 0xA3A5: 891,
+ 0xA3A8: 892,
+ 0xA3A9: 893,
+ 0xA3AB: 894,
+ 0xA3B1: 895,
+ 0xA3B3: 896,
+ 0xA3B5: 897,
+ 0xA3B6: 898,
+ 0xA3B7: 899,
+ 0xA3B9: 900,
+ 0xA3BB: 901,
+ 0xA461: 902,
+ 0xA462: 903,
+ 0xA463: 904,
+ 0xA464: 905,
+ 0xA465: 906,
+ 0xA468: 907,
+ 0xA469: 908,
+ 0xA46A: 909,
+ 0xA46B: 910,
+ 0xA46C: 911,
+ 0xA471: 912,
+ 0xA473: 913,
+ 0xA475: 914,
+ 0xA477: 915,
+ 0xA47B: 916,
+ 0xA481: 917,
+ 0xA482: 918,
+ 0xA485: 919,
+ 0xA489: 920,
+ 0xA491: 921,
+ 0xA493: 922,
+ 0xA495: 923,
+ 0xA496: 924,
+ 0xA497: 925,
+ 0xA49B: 926,
+ 0xA4A1: 927,
+ 0xA4A2: 928,
+ 0xA4A5: 929,
+ 0xA4B3: 930,
+ 0xA4E1: 931,
+ 0xA4E2: 932,
+ 0xA4E5: 933,
+ 0xA4E8: 934,
+ 0xA4E9: 935,
+ 0xA4EB: 936,
+ 0xA4F1: 937,
+ 0xA4F3: 938,
+ 0xA4F5: 939,
+ 0xA4F7: 940,
+ 0xA4F8: 941,
+ 0xA541: 942,
+ 0xA542: 943,
+ 0xA545: 944,
+ 0xA548: 945,
+ 0xA549: 946,
+ 0xA551: 947,
+ 0xA553: 948,
+ 0xA555: 949,
+ 0xA556: 950,
+ 0xA557: 951,
+ 0xA561: 952,
+ 0xA562: 953,
+ 0xA565: 954,
+ 0xA569: 955,
+ 0xA573: 956,
+ 0xA575: 957,
+ 0xA576: 958,
+ 0xA577: 959,
+ 0xA57B: 960,
+ 0xA581: 961,
+ 0xA585: 962,
+ 0xA5A1: 963,
+ 0xA5A2: 964,
+ 0xA5A3: 965,
+ 0xA5A5: 966,
+ 0xA5A9: 967,
+ 0xA5B1: 968,
+ 0xA5B3: 969,
+ 0xA5B5: 970,
+ 0xA5B7: 971,
+ 0xA5C1: 972,
+ 0xA5C5: 973,
+ 0xA5D6: 974,
+ 0xA5E1: 975,
+ 0xA5F6: 976,
+ 0xA641: 977,
+ 0xA642: 978,
+ 0xA645: 979,
+ 0xA649: 980,
+ 0xA651: 981,
+ 0xA653: 982,
+ 0xA661: 983,
+ 0xA665: 984,
+ 0xA681: 985,
+ 0xA682: 986,
+ 0xA685: 987,
+ 0xA688: 988,
+ 0xA689: 989,
+ 0xA68A: 990,
+ 0xA68B: 991,
+ 0xA691: 992,
+ 0xA693: 993,
+ 0xA695: 994,
+ 0xA697: 995,
+ 0xA69B: 996,
+ 0xA69C: 997,
+ 0xA6A1: 998,
+ 0xA6A9: 999,
+ 0xA6B6: 1000,
+ 0xA6C1: 1001,
+ 0xA6E1: 1002,
+ 0xA6E2: 1003,
+ 0xA6E5: 1004,
+ 0xA6E9: 1005,
+ 0xA6F7: 1006,
+ 0xA741: 1007,
+ 0xA745: 1008,
+ 0xA749: 1009,
+ 0xA751: 1010,
+ 0xA755: 1011,
+ 0xA757: 1012,
+ 0xA761: 1013,
+ 0xA762: 1014,
+ 0xA765: 1015,
+ 0xA769: 1016,
+ 0xA771: 1017,
+ 0xA773: 1018,
+ 0xA775: 1019,
+ 0xA7A1: 1020,
+ 0xA7A2: 1021,
+ 0xA7A5: 1022,
+ 0xA7A9: 1023,
+ 0xA7AB: 1024,
+ 0xA7B1: 1025,
+ 0xA7B3: 1026,
+ 0xA7B5: 1027,
+ 0xA7B7: 1028,
+ 0xA7B8: 1029,
+ 0xA7B9: 1030,
+ 0xA861: 1031,
+ 0xA862: 1032,
+ 0xA865: 1033,
+ 0xA869: 1034,
+ 0xA86B: 1035,
+ 0xA871: 1036,
+ 0xA873: 1037,
+ 0xA875: 1038,
+ 0xA876: 1039,
+ 0xA877: 1040,
+ 0xA87D: 1041,
+ 0xA881: 1042,
+ 0xA882: 1043,
+ 0xA885: 1044,
+ 0xA889: 1045,
+ 0xA891: 1046,
+ 0xA893: 1047,
+ 0xA895: 1048,
+ 0xA896: 1049,
+ 0xA897: 1050,
+ 0xA8A1: 1051,
+ 0xA8A2: 1052,
+ 0xA8B1: 1053,
+ 0xA8E1: 1054,
+ 0xA8E2: 1055,
+ 0xA8E5: 1056,
+ 0xA8E8: 1057,
+ 0xA8E9: 1058,
+ 0xA8F1: 1059,
+ 0xA8F5: 1060,
+ 0xA8F6: 1061,
+ 0xA8F7: 1062,
+ 0xA941: 1063,
+ 0xA957: 1064,
+ 0xA961: 1065,
+ 0xA962: 1066,
+ 0xA971: 1067,
+ 0xA973: 1068,
+ 0xA975: 1069,
+ 0xA976: 1070,
+ 0xA977: 1071,
+ 0xA9A1: 1072,
+ 0xA9A2: 1073,
+ 0xA9A5: 1074,
+ 0xA9A9: 1075,
+ 0xA9B1: 1076,
+ 0xA9B3: 1077,
+ 0xA9B7: 1078,
+ 0xAA41: 1079,
+ 0xAA61: 1080,
+ 0xAA77: 1081,
+ 0xAA81: 1082,
+ 0xAA82: 1083,
+ 0xAA85: 1084,
+ 0xAA89: 1085,
+ 0xAA91: 1086,
+ 0xAA95: 1087,
+ 0xAA97: 1088,
+ 0xAB41: 1089,
+ 0xAB57: 1090,
+ 0xAB61: 1091,
+ 0xAB65: 1092,
+ 0xAB69: 1093,
+ 0xAB71: 1094,
+ 0xAB73: 1095,
+ 0xABA1: 1096,
+ 0xABA2: 1097,
+ 0xABA5: 1098,
+ 0xABA9: 1099,
+ 0xABB1: 1100,
+ 0xABB3: 1101,
+ 0xABB5: 1102,
+ 0xABB7: 1103,
+ 0xAC61: 1104,
+ 0xAC62: 1105,
+ 0xAC64: 1106,
+ 0xAC65: 1107,
+ 0xAC68: 1108,
+ 0xAC69: 1109,
+ 0xAC6A: 1110,
+ 0xAC6B: 1111,
+ 0xAC71: 1112,
+ 0xAC73: 1113,
+ 0xAC75: 1114,
+ 0xAC76: 1115,
+ 0xAC77: 1116,
+ 0xAC7B: 1117,
+ 0xAC81: 1118,
+ 0xAC82: 1119,
+ 0xAC85: 1120,
+ 0xAC89: 1121,
+ 0xAC91: 1122,
+ 0xAC93: 1123,
+ 0xAC95: 1124,
+ 0xAC96: 1125,
+ 0xAC97: 1126,
+ 0xACA1: 1127,
+ 0xACA2: 1128,
+ 0xACA5: 1129,
+ 0xACA9: 1130,
+ 0xACB1: 1131,
+ 0xACB3: 1132,
+ 0xACB5: 1133,
+ 0xACB7: 1134,
+ 0xACC1: 1135,
+ 0xACC5: 1136,
+ 0xACC9: 1137,
+ 0xACD1: 1138,
+ 0xACD7: 1139,
+ 0xACE1: 1140,
+ 0xACE2: 1141,
+ 0xACE3: 1142,
+ 0xACE4: 1143,
+ 0xACE5: 1144,
+ 0xACE8: 1145,
+ 0xACE9: 1146,
+ 0xACEB: 1147,
+ 0xACEC: 1148,
+ 0xACF1: 1149,
+ 0xACF3: 1150,
+ 0xACF5: 1151,
+ 0xACF6: 1152,
+ 0xACF7: 1153,
+ 0xACFC: 1154,
+ 0xAD41: 1155,
+ 0xAD42: 1156,
+ 0xAD45: 1157,
+ 0xAD49: 1158,
+ 0xAD51: 1159,
+ 0xAD53: 1160,
+ 0xAD55: 1161,
+ 0xAD56: 1162,
+ 0xAD57: 1163,
+ 0xAD61: 1164,
+ 0xAD62: 1165,
+ 0xAD65: 1166,
+ 0xAD69: 1167,
+ 0xAD71: 1168,
+ 0xAD73: 1169,
+ 0xAD75: 1170,
+ 0xAD76: 1171,
+ 0xAD77: 1172,
+ 0xAD81: 1173,
+ 0xAD85: 1174,
+ 0xAD89: 1175,
+ 0xAD97: 1176,
+ 0xADA1: 1177,
+ 0xADA2: 1178,
+ 0xADA3: 1179,
+ 0xADA5: 1180,
+ 0xADA9: 1181,
+ 0xADAB: 1182,
+ 0xADB1: 1183,
+ 0xADB3: 1184,
+ 0xADB5: 1185,
+ 0xADB7: 1186,
+ 0xADBB: 1187,
+ 0xADC1: 1188,
+ 0xADC2: 1189,
+ 0xADC5: 1190,
+ 0xADC9: 1191,
+ 0xADD7: 1192,
+ 0xADE1: 1193,
+ 0xADE5: 1194,
+ 0xADE9: 1195,
+ 0xADF1: 1196,
+ 0xADF5: 1197,
+ 0xADF6: 1198,
+ 0xAE41: 1199,
+ 0xAE45: 1200,
+ 0xAE49: 1201,
+ 0xAE51: 1202,
+ 0xAE53: 1203,
+ 0xAE55: 1204,
+ 0xAE61: 1205,
+ 0xAE62: 1206,
+ 0xAE65: 1207,
+ 0xAE69: 1208,
+ 0xAE71: 1209,
+ 0xAE73: 1210,
+ 0xAE75: 1211,
+ 0xAE77: 1212,
+ 0xAE81: 1213,
+ 0xAE82: 1214,
+ 0xAE85: 1215,
+ 0xAE88: 1216,
+ 0xAE89: 1217,
+ 0xAE91: 1218,
+ 0xAE93: 1219,
+ 0xAE95: 1220,
+ 0xAE97: 1221,
+ 0xAE99: 1222,
+ 0xAE9B: 1223,
+ 0xAE9C: 1224,
+ 0xAEA1: 1225,
+ 0xAEB6: 1226,
+ 0xAEC1: 1227,
+ 0xAEC2: 1228,
+ 0xAEC5: 1229,
+ 0xAEC9: 1230,
+ 0xAED1: 1231,
+ 0xAED7: 1232,
+ 0xAEE1: 1233,
+ 0xAEE2: 1234,
+ 0xAEE5: 1235,
+ 0xAEE9: 1236,
+ 0xAEF1: 1237,
+ 0xAEF3: 1238,
+ 0xAEF5: 1239,
+ 0xAEF7: 1240,
+ 0xAF41: 1241,
+ 0xAF42: 1242,
+ 0xAF49: 1243,
+ 0xAF51: 1244,
+ 0xAF55: 1245,
+ 0xAF57: 1246,
+ 0xAF61: 1247,
+ 0xAF62: 1248,
+ 0xAF65: 1249,
+ 0xAF69: 1250,
+ 0xAF6A: 1251,
+ 0xAF71: 1252,
+ 0xAF73: 1253,
+ 0xAF75: 1254,
+ 0xAF77: 1255,
+ 0xAFA1: 1256,
+ 0xAFA2: 1257,
+ 0xAFA5: 1258,
+ 0xAFA8: 1259,
+ 0xAFA9: 1260,
+ 0xAFB0: 1261,
+ 0xAFB1: 1262,
+ 0xAFB3: 1263,
+ 0xAFB5: 1264,
+ 0xAFB7: 1265,
+ 0xAFBC: 1266,
+ 0xB061: 1267,
+ 0xB062: 1268,
+ 0xB064: 1269,
+ 0xB065: 1270,
+ 0xB069: 1271,
+ 0xB071: 1272,
+ 0xB073: 1273,
+ 0xB076: 1274,
+ 0xB077: 1275,
+ 0xB07D: 1276,
+ 0xB081: 1277,
+ 0xB082: 1278,
+ 0xB085: 1279,
+ 0xB089: 1280,
+ 0xB091: 1281,
+ 0xB093: 1282,
+ 0xB096: 1283,
+ 0xB097: 1284,
+ 0xB0B7: 1285,
+ 0xB0E1: 1286,
+ 0xB0E2: 1287,
+ 0xB0E5: 1288,
+ 0xB0E9: 1289,
+ 0xB0EB: 1290,
+ 0xB0F1: 1291,
+ 0xB0F3: 1292,
+ 0xB0F6: 1293,
+ 0xB0F7: 1294,
+ 0xB141: 1295,
+ 0xB145: 1296,
+ 0xB149: 1297,
+ 0xB185: 1298,
+ 0xB1A1: 1299,
+ 0xB1A2: 1300,
+ 0xB1A5: 1301,
+ 0xB1A8: 1302,
+ 0xB1A9: 1303,
+ 0xB1AB: 1304,
+ 0xB1B1: 1305,
+ 0xB1B3: 1306,
+ 0xB1B7: 1307,
+ 0xB1C1: 1308,
+ 0xB1C2: 1309,
+ 0xB1C5: 1310,
+ 0xB1D6: 1311,
+ 0xB1E1: 1312,
+ 0xB1F6: 1313,
+ 0xB241: 1314,
+ 0xB245: 1315,
+ 0xB249: 1316,
+ 0xB251: 1317,
+ 0xB253: 1318,
+ 0xB261: 1319,
+ 0xB281: 1320,
+ 0xB282: 1321,
+ 0xB285: 1322,
+ 0xB289: 1323,
+ 0xB291: 1324,
+ 0xB293: 1325,
+ 0xB297: 1326,
+ 0xB2A1: 1327,
+ 0xB2B6: 1328,
+ 0xB2C1: 1329,
+ 0xB2E1: 1330,
+ 0xB2E5: 1331,
+ 0xB357: 1332,
+ 0xB361: 1333,
+ 0xB362: 1334,
+ 0xB365: 1335,
+ 0xB369: 1336,
+ 0xB36B: 1337,
+ 0xB370: 1338,
+ 0xB371: 1339,
+ 0xB373: 1340,
+ 0xB381: 1341,
+ 0xB385: 1342,
+ 0xB389: 1343,
+ 0xB391: 1344,
+ 0xB3A1: 1345,
+ 0xB3A2: 1346,
+ 0xB3A5: 1347,
+ 0xB3A9: 1348,
+ 0xB3B1: 1349,
+ 0xB3B3: 1350,
+ 0xB3B5: 1351,
+ 0xB3B7: 1352,
+ 0xB461: 1353,
+ 0xB462: 1354,
+ 0xB465: 1355,
+ 0xB466: 1356,
+ 0xB467: 1357,
+ 0xB469: 1358,
+ 0xB46A: 1359,
+ 0xB46B: 1360,
+ 0xB470: 1361,
+ 0xB471: 1362,
+ 0xB473: 1363,
+ 0xB475: 1364,
+ 0xB476: 1365,
+ 0xB477: 1366,
+ 0xB47B: 1367,
+ 0xB47C: 1368,
+ 0xB481: 1369,
+ 0xB482: 1370,
+ 0xB485: 1371,
+ 0xB489: 1372,
+ 0xB491: 1373,
+ 0xB493: 1374,
+ 0xB495: 1375,
+ 0xB496: 1376,
+ 0xB497: 1377,
+ 0xB4A1: 1378,
+ 0xB4A2: 1379,
+ 0xB4A5: 1380,
+ 0xB4A9: 1381,
+ 0xB4AC: 1382,
+ 0xB4B1: 1383,
+ 0xB4B3: 1384,
+ 0xB4B5: 1385,
+ 0xB4B7: 1386,
+ 0xB4BB: 1387,
+ 0xB4BD: 1388,
+ 0xB4C1: 1389,
+ 0xB4C5: 1390,
+ 0xB4C9: 1391,
+ 0xB4D3: 1392,
+ 0xB4E1: 1393,
+ 0xB4E2: 1394,
+ 0xB4E5: 1395,
+ 0xB4E6: 1396,
+ 0xB4E8: 1397,
+ 0xB4E9: 1398,
+ 0xB4EA: 1399,
+ 0xB4EB: 1400,
+ 0xB4F1: 1401,
+ 0xB4F3: 1402,
+ 0xB4F4: 1403,
+ 0xB4F5: 1404,
+ 0xB4F6: 1405,
+ 0xB4F7: 1406,
+ 0xB4F8: 1407,
+ 0xB4FA: 1408,
+ 0xB4FC: 1409,
+ 0xB541: 1410,
+ 0xB542: 1411,
+ 0xB545: 1412,
+ 0xB549: 1413,
+ 0xB551: 1414,
+ 0xB553: 1415,
+ 0xB555: 1416,
+ 0xB557: 1417,
+ 0xB561: 1418,
+ 0xB562: 1419,
+ 0xB563: 1420,
+ 0xB565: 1421,
+ 0xB569: 1422,
+ 0xB56B: 1423,
+ 0xB56C: 1424,
+ 0xB571: 1425,
+ 0xB573: 1426,
+ 0xB574: 1427,
+ 0xB575: 1428,
+ 0xB576: 1429,
+ 0xB577: 1430,
+ 0xB57B: 1431,
+ 0xB57C: 1432,
+ 0xB57D: 1433,
+ 0xB581: 1434,
+ 0xB585: 1435,
+ 0xB589: 1436,
+ 0xB591: 1437,
+ 0xB593: 1438,
+ 0xB595: 1439,
+ 0xB596: 1440,
+ 0xB5A1: 1441,
+ 0xB5A2: 1442,
+ 0xB5A5: 1443,
+ 0xB5A9: 1444,
+ 0xB5AA: 1445,
+ 0xB5AB: 1446,
+ 0xB5AD: 1447,
+ 0xB5B0: 1448,
+ 0xB5B1: 1449,
+ 0xB5B3: 1450,
+ 0xB5B5: 1451,
+ 0xB5B7: 1452,
+ 0xB5B9: 1453,
+ 0xB5C1: 1454,
+ 0xB5C2: 1455,
+ 0xB5C5: 1456,
+ 0xB5C9: 1457,
+ 0xB5D1: 1458,
+ 0xB5D3: 1459,
+ 0xB5D5: 1460,
+ 0xB5D6: 1461,
+ 0xB5D7: 1462,
+ 0xB5E1: 1463,
+ 0xB5E2: 1464,
+ 0xB5E5: 1465,
+ 0xB5F1: 1466,
+ 0xB5F5: 1467,
+ 0xB5F7: 1468,
+ 0xB641: 1469,
+ 0xB642: 1470,
+ 0xB645: 1471,
+ 0xB649: 1472,
+ 0xB651: 1473,
+ 0xB653: 1474,
+ 0xB655: 1475,
+ 0xB657: 1476,
+ 0xB661: 1477,
+ 0xB662: 1478,
+ 0xB665: 1479,
+ 0xB669: 1480,
+ 0xB671: 1481,
+ 0xB673: 1482,
+ 0xB675: 1483,
+ 0xB677: 1484,
+ 0xB681: 1485,
+ 0xB682: 1486,
+ 0xB685: 1487,
+ 0xB689: 1488,
+ 0xB68A: 1489,
+ 0xB68B: 1490,
+ 0xB691: 1491,
+ 0xB693: 1492,
+ 0xB695: 1493,
+ 0xB697: 1494,
+ 0xB6A1: 1495,
+ 0xB6A2: 1496,
+ 0xB6A5: 1497,
+ 0xB6A9: 1498,
+ 0xB6B1: 1499,
+ 0xB6B3: 1500,
+ 0xB6B6: 1501,
+ 0xB6B7: 1502,
+ 0xB6C1: 1503,
+ 0xB6C2: 1504,
+ 0xB6C5: 1505,
+ 0xB6C9: 1506,
+ 0xB6D1: 1507,
+ 0xB6D3: 1508,
+ 0xB6D7: 1509,
+ 0xB6E1: 1510,
+ 0xB6E2: 1511,
+ 0xB6E5: 1512,
+ 0xB6E9: 1513,
+ 0xB6F1: 1514,
+ 0xB6F3: 1515,
+ 0xB6F5: 1516,
+ 0xB6F7: 1517,
+ 0xB741: 1518,
+ 0xB742: 1519,
+ 0xB745: 1520,
+ 0xB749: 1521,
+ 0xB751: 1522,
+ 0xB753: 1523,
+ 0xB755: 1524,
+ 0xB757: 1525,
+ 0xB759: 1526,
+ 0xB761: 1527,
+ 0xB762: 1528,
+ 0xB765: 1529,
+ 0xB769: 1530,
+ 0xB76F: 1531,
+ 0xB771: 1532,
+ 0xB773: 1533,
+ 0xB775: 1534,
+ 0xB777: 1535,
+ 0xB778: 1536,
+ 0xB779: 1537,
+ 0xB77A: 1538,
+ 0xB77B: 1539,
+ 0xB77C: 1540,
+ 0xB77D: 1541,
+ 0xB781: 1542,
+ 0xB785: 1543,
+ 0xB789: 1544,
+ 0xB791: 1545,
+ 0xB795: 1546,
+ 0xB7A1: 1547,
+ 0xB7A2: 1548,
+ 0xB7A5: 1549,
+ 0xB7A9: 1550,
+ 0xB7AA: 1551,
+ 0xB7AB: 1552,
+ 0xB7B0: 1553,
+ 0xB7B1: 1554,
+ 0xB7B3: 1555,
+ 0xB7B5: 1556,
+ 0xB7B6: 1557,
+ 0xB7B7: 1558,
+ 0xB7B8: 1559,
+ 0xB7BC: 1560,
+ 0xB861: 1561,
+ 0xB862: 1562,
+ 0xB865: 1563,
+ 0xB867: 1564,
+ 0xB868: 1565,
+ 0xB869: 1566,
+ 0xB86B: 1567,
+ 0xB871: 1568,
+ 0xB873: 1569,
+ 0xB875: 1570,
+ 0xB876: 1571,
+ 0xB877: 1572,
+ 0xB878: 1573,
+ 0xB881: 1574,
+ 0xB882: 1575,
+ 0xB885: 1576,
+ 0xB889: 1577,
+ 0xB891: 1578,
+ 0xB893: 1579,
+ 0xB895: 1580,
+ 0xB896: 1581,
+ 0xB897: 1582,
+ 0xB8A1: 1583,
+ 0xB8A2: 1584,
+ 0xB8A5: 1585,
+ 0xB8A7: 1586,
+ 0xB8A9: 1587,
+ 0xB8B1: 1588,
+ 0xB8B7: 1589,
+ 0xB8C1: 1590,
+ 0xB8C5: 1591,
+ 0xB8C9: 1592,
+ 0xB8E1: 1593,
+ 0xB8E2: 1594,
+ 0xB8E5: 1595,
+ 0xB8E9: 1596,
+ 0xB8EB: 1597,
+ 0xB8F1: 1598,
+ 0xB8F3: 1599,
+ 0xB8F5: 1600,
+ 0xB8F7: 1601,
+ 0xB8F8: 1602,
+ 0xB941: 1603,
+ 0xB942: 1604,
+ 0xB945: 1605,
+ 0xB949: 1606,
+ 0xB951: 1607,
+ 0xB953: 1608,
+ 0xB955: 1609,
+ 0xB957: 1610,
+ 0xB961: 1611,
+ 0xB965: 1612,
+ 0xB969: 1613,
+ 0xB971: 1614,
+ 0xB973: 1615,
+ 0xB976: 1616,
+ 0xB977: 1617,
+ 0xB981: 1618,
+ 0xB9A1: 1619,
+ 0xB9A2: 1620,
+ 0xB9A5: 1621,
+ 0xB9A9: 1622,
+ 0xB9AB: 1623,
+ 0xB9B1: 1624,
+ 0xB9B3: 1625,
+ 0xB9B5: 1626,
+ 0xB9B7: 1627,
+ 0xB9B8: 1628,
+ 0xB9B9: 1629,
+ 0xB9BD: 1630,
+ 0xB9C1: 1631,
+ 0xB9C2: 1632,
+ 0xB9C9: 1633,
+ 0xB9D3: 1634,
+ 0xB9D5: 1635,
+ 0xB9D7: 1636,
+ 0xB9E1: 1637,
+ 0xB9F6: 1638,
+ 0xB9F7: 1639,
+ 0xBA41: 1640,
+ 0xBA45: 1641,
+ 0xBA49: 1642,
+ 0xBA51: 1643,
+ 0xBA53: 1644,
+ 0xBA55: 1645,
+ 0xBA57: 1646,
+ 0xBA61: 1647,
+ 0xBA62: 1648,
+ 0xBA65: 1649,
+ 0xBA77: 1650,
+ 0xBA81: 1651,
+ 0xBA82: 1652,
+ 0xBA85: 1653,
+ 0xBA89: 1654,
+ 0xBA8A: 1655,
+ 0xBA8B: 1656,
+ 0xBA91: 1657,
+ 0xBA93: 1658,
+ 0xBA95: 1659,
+ 0xBA97: 1660,
+ 0xBAA1: 1661,
+ 0xBAB6: 1662,
+ 0xBAC1: 1663,
+ 0xBAE1: 1664,
+ 0xBAE2: 1665,
+ 0xBAE5: 1666,
+ 0xBAE9: 1667,
+ 0xBAF1: 1668,
+ 0xBAF3: 1669,
+ 0xBAF5: 1670,
+ 0xBB41: 1671,
+ 0xBB45: 1672,
+ 0xBB49: 1673,
+ 0xBB51: 1674,
+ 0xBB61: 1675,
+ 0xBB62: 1676,
+ 0xBB65: 1677,
+ 0xBB69: 1678,
+ 0xBB71: 1679,
+ 0xBB73: 1680,
+ 0xBB75: 1681,
+ 0xBB77: 1682,
+ 0xBBA1: 1683,
+ 0xBBA2: 1684,
+ 0xBBA5: 1685,
+ 0xBBA8: 1686,
+ 0xBBA9: 1687,
+ 0xBBAB: 1688,
+ 0xBBB1: 1689,
+ 0xBBB3: 1690,
+ 0xBBB5: 1691,
+ 0xBBB7: 1692,
+ 0xBBB8: 1693,
+ 0xBBBB: 1694,
+ 0xBBBC: 1695,
+ 0xBC61: 1696,
+ 0xBC62: 1697,
+ 0xBC65: 1698,
+ 0xBC67: 1699,
+ 0xBC69: 1700,
+ 0xBC6C: 1701,
+ 0xBC71: 1702,
+ 0xBC73: 1703,
+ 0xBC75: 1704,
+ 0xBC76: 1705,
+ 0xBC77: 1706,
+ 0xBC81: 1707,
+ 0xBC82: 1708,
+ 0xBC85: 1709,
+ 0xBC89: 1710,
+ 0xBC91: 1711,
+ 0xBC93: 1712,
+ 0xBC95: 1713,
+ 0xBC96: 1714,
+ 0xBC97: 1715,
+ 0xBCA1: 1716,
+ 0xBCA5: 1717,
+ 0xBCB7: 1718,
+ 0xBCE1: 1719,
+ 0xBCE2: 1720,
+ 0xBCE5: 1721,
+ 0xBCE9: 1722,
+ 0xBCF1: 1723,
+ 0xBCF3: 1724,
+ 0xBCF5: 1725,
+ 0xBCF6: 1726,
+ 0xBCF7: 1727,
+ 0xBD41: 1728,
+ 0xBD57: 1729,
+ 0xBD61: 1730,
+ 0xBD76: 1731,
+ 0xBDA1: 1732,
+ 0xBDA2: 1733,
+ 0xBDA5: 1734,
+ 0xBDA9: 1735,
+ 0xBDB1: 1736,
+ 0xBDB3: 1737,
+ 0xBDB5: 1738,
+ 0xBDB7: 1739,
+ 0xBDB9: 1740,
+ 0xBDC1: 1741,
+ 0xBDC2: 1742,
+ 0xBDC9: 1743,
+ 0xBDD6: 1744,
+ 0xBDE1: 1745,
+ 0xBDF6: 1746,
+ 0xBE41: 1747,
+ 0xBE45: 1748,
+ 0xBE49: 1749,
+ 0xBE51: 1750,
+ 0xBE53: 1751,
+ 0xBE77: 1752,
+ 0xBE81: 1753,
+ 0xBE82: 1754,
+ 0xBE85: 1755,
+ 0xBE89: 1756,
+ 0xBE91: 1757,
+ 0xBE93: 1758,
+ 0xBE97: 1759,
+ 0xBEA1: 1760,
+ 0xBEB6: 1761,
+ 0xBEB7: 1762,
+ 0xBEE1: 1763,
+ 0xBF41: 1764,
+ 0xBF61: 1765,
+ 0xBF71: 1766,
+ 0xBF75: 1767,
+ 0xBF77: 1768,
+ 0xBFA1: 1769,
+ 0xBFA2: 1770,
+ 0xBFA5: 1771,
+ 0xBFA9: 1772,
+ 0xBFB1: 1773,
+ 0xBFB3: 1774,
+ 0xBFB7: 1775,
+ 0xBFB8: 1776,
+ 0xBFBD: 1777,
+ 0xC061: 1778,
+ 0xC062: 1779,
+ 0xC065: 1780,
+ 0xC067: 1781,
+ 0xC069: 1782,
+ 0xC071: 1783,
+ 0xC073: 1784,
+ 0xC075: 1785,
+ 0xC076: 1786,
+ 0xC077: 1787,
+ 0xC078: 1788,
+ 0xC081: 1789,
+ 0xC082: 1790,
+ 0xC085: 1791,
+ 0xC089: 1792,
+ 0xC091: 1793,
+ 0xC093: 1794,
+ 0xC095: 1795,
+ 0xC096: 1796,
+ 0xC097: 1797,
+ 0xC0A1: 1798,
+ 0xC0A5: 1799,
+ 0xC0A7: 1800,
+ 0xC0A9: 1801,
+ 0xC0B1: 1802,
+ 0xC0B7: 1803,
+ 0xC0E1: 1804,
+ 0xC0E2: 1805,
+ 0xC0E5: 1806,
+ 0xC0E9: 1807,
+ 0xC0F1: 1808,
+ 0xC0F3: 1809,
+ 0xC0F5: 1810,
+ 0xC0F6: 1811,
+ 0xC0F7: 1812,
+ 0xC141: 1813,
+ 0xC142: 1814,
+ 0xC145: 1815,
+ 0xC149: 1816,
+ 0xC151: 1817,
+ 0xC153: 1818,
+ 0xC155: 1819,
+ 0xC157: 1820,
+ 0xC161: 1821,
+ 0xC165: 1822,
+ 0xC176: 1823,
+ 0xC181: 1824,
+ 0xC185: 1825,
+ 0xC197: 1826,
+ 0xC1A1: 1827,
+ 0xC1A2: 1828,
+ 0xC1A5: 1829,
+ 0xC1A9: 1830,
+ 0xC1B1: 1831,
+ 0xC1B3: 1832,
+ 0xC1B5: 1833,
+ 0xC1B7: 1834,
+ 0xC1C1: 1835,
+ 0xC1C5: 1836,
+ 0xC1C9: 1837,
+ 0xC1D7: 1838,
+ 0xC241: 1839,
+ 0xC245: 1840,
+ 0xC249: 1841,
+ 0xC251: 1842,
+ 0xC253: 1843,
+ 0xC255: 1844,
+ 0xC257: 1845,
+ 0xC261: 1846,
+ 0xC271: 1847,
+ 0xC281: 1848,
+ 0xC282: 1849,
+ 0xC285: 1850,
+ 0xC289: 1851,
+ 0xC291: 1852,
+ 0xC293: 1853,
+ 0xC295: 1854,
+ 0xC297: 1855,
+ 0xC2A1: 1856,
+ 0xC2B6: 1857,
+ 0xC2C1: 1858,
+ 0xC2C5: 1859,
+ 0xC2E1: 1860,
+ 0xC2E5: 1861,
+ 0xC2E9: 1862,
+ 0xC2F1: 1863,
+ 0xC2F3: 1864,
+ 0xC2F5: 1865,
+ 0xC2F7: 1866,
+ 0xC341: 1867,
+ 0xC345: 1868,
+ 0xC349: 1869,
+ 0xC351: 1870,
+ 0xC357: 1871,
+ 0xC361: 1872,
+ 0xC362: 1873,
+ 0xC365: 1874,
+ 0xC369: 1875,
+ 0xC371: 1876,
+ 0xC373: 1877,
+ 0xC375: 1878,
+ 0xC377: 1879,
+ 0xC3A1: 1880,
+ 0xC3A2: 1881,
+ 0xC3A5: 1882,
+ 0xC3A8: 1883,
+ 0xC3A9: 1884,
+ 0xC3AA: 1885,
+ 0xC3B1: 1886,
+ 0xC3B3: 1887,
+ 0xC3B5: 1888,
+ 0xC3B7: 1889,
+ 0xC461: 1890,
+ 0xC462: 1891,
+ 0xC465: 1892,
+ 0xC469: 1893,
+ 0xC471: 1894,
+ 0xC473: 1895,
+ 0xC475: 1896,
+ 0xC477: 1897,
+ 0xC481: 1898,
+ 0xC482: 1899,
+ 0xC485: 1900,
+ 0xC489: 1901,
+ 0xC491: 1902,
+ 0xC493: 1903,
+ 0xC495: 1904,
+ 0xC496: 1905,
+ 0xC497: 1906,
+ 0xC4A1: 1907,
+ 0xC4A2: 1908,
+ 0xC4B7: 1909,
+ 0xC4E1: 1910,
+ 0xC4E2: 1911,
+ 0xC4E5: 1912,
+ 0xC4E8: 1913,
+ 0xC4E9: 1914,
+ 0xC4F1: 1915,
+ 0xC4F3: 1916,
+ 0xC4F5: 1917,
+ 0xC4F6: 1918,
+ 0xC4F7: 1919,
+ 0xC541: 1920,
+ 0xC542: 1921,
+ 0xC545: 1922,
+ 0xC549: 1923,
+ 0xC551: 1924,
+ 0xC553: 1925,
+ 0xC555: 1926,
+ 0xC557: 1927,
+ 0xC561: 1928,
+ 0xC565: 1929,
+ 0xC569: 1930,
+ 0xC571: 1931,
+ 0xC573: 1932,
+ 0xC575: 1933,
+ 0xC576: 1934,
+ 0xC577: 1935,
+ 0xC581: 1936,
+ 0xC5A1: 1937,
+ 0xC5A2: 1938,
+ 0xC5A5: 1939,
+ 0xC5A9: 1940,
+ 0xC5B1: 1941,
+ 0xC5B3: 1942,
+ 0xC5B5: 1943,
+ 0xC5B7: 1944,
+ 0xC5C1: 1945,
+ 0xC5C2: 1946,
+ 0xC5C5: 1947,
+ 0xC5C9: 1948,
+ 0xC5D1: 1949,
+ 0xC5D7: 1950,
+ 0xC5E1: 1951,
+ 0xC5F7: 1952,
+ 0xC641: 1953,
+ 0xC649: 1954,
+ 0xC661: 1955,
+ 0xC681: 1956,
+ 0xC682: 1957,
+ 0xC685: 1958,
+ 0xC689: 1959,
+ 0xC691: 1960,
+ 0xC693: 1961,
+ 0xC695: 1962,
+ 0xC697: 1963,
+ 0xC6A1: 1964,
+ 0xC6A5: 1965,
+ 0xC6A9: 1966,
+ 0xC6B7: 1967,
+ 0xC6C1: 1968,
+ 0xC6D7: 1969,
+ 0xC6E1: 1970,
+ 0xC6E2: 1971,
+ 0xC6E5: 1972,
+ 0xC6E9: 1973,
+ 0xC6F1: 1974,
+ 0xC6F3: 1975,
+ 0xC6F5: 1976,
+ 0xC6F7: 1977,
+ 0xC741: 1978,
+ 0xC745: 1979,
+ 0xC749: 1980,
+ 0xC751: 1981,
+ 0xC761: 1982,
+ 0xC762: 1983,
+ 0xC765: 1984,
+ 0xC769: 1985,
+ 0xC771: 1986,
+ 0xC773: 1987,
+ 0xC777: 1988,
+ 0xC7A1: 1989,
+ 0xC7A2: 1990,
+ 0xC7A5: 1991,
+ 0xC7A9: 1992,
+ 0xC7B1: 1993,
+ 0xC7B3: 1994,
+ 0xC7B5: 1995,
+ 0xC7B7: 1996,
+ 0xC861: 1997,
+ 0xC862: 1998,
+ 0xC865: 1999,
+ 0xC869: 2000,
+ 0xC86A: 2001,
+ 0xC871: 2002,
+ 0xC873: 2003,
+ 0xC875: 2004,
+ 0xC876: 2005,
+ 0xC877: 2006,
+ 0xC881: 2007,
+ 0xC882: 2008,
+ 0xC885: 2009,
+ 0xC889: 2010,
+ 0xC891: 2011,
+ 0xC893: 2012,
+ 0xC895: 2013,
+ 0xC896: 2014,
+ 0xC897: 2015,
+ 0xC8A1: 2016,
+ 0xC8B7: 2017,
+ 0xC8E1: 2018,
+ 0xC8E2: 2019,
+ 0xC8E5: 2020,
+ 0xC8E9: 2021,
+ 0xC8EB: 2022,
+ 0xC8F1: 2023,
+ 0xC8F3: 2024,
+ 0xC8F5: 2025,
+ 0xC8F6: 2026,
+ 0xC8F7: 2027,
+ 0xC941: 2028,
+ 0xC942: 2029,
+ 0xC945: 2030,
+ 0xC949: 2031,
+ 0xC951: 2032,
+ 0xC953: 2033,
+ 0xC955: 2034,
+ 0xC957: 2035,
+ 0xC961: 2036,
+ 0xC965: 2037,
+ 0xC976: 2038,
+ 0xC981: 2039,
+ 0xC985: 2040,
+ 0xC9A1: 2041,
+ 0xC9A2: 2042,
+ 0xC9A5: 2043,
+ 0xC9A9: 2044,
+ 0xC9B1: 2045,
+ 0xC9B3: 2046,
+ 0xC9B5: 2047,
+ 0xC9B7: 2048,
+ 0xC9BC: 2049,
+ 0xC9C1: 2050,
+ 0xC9C5: 2051,
+ 0xC9E1: 2052,
+ 0xCA41: 2053,
+ 0xCA45: 2054,
+ 0xCA55: 2055,
+ 0xCA57: 2056,
+ 0xCA61: 2057,
+ 0xCA81: 2058,
+ 0xCA82: 2059,
+ 0xCA85: 2060,
+ 0xCA89: 2061,
+ 0xCA91: 2062,
+ 0xCA93: 2063,
+ 0xCA95: 2064,
+ 0xCA97: 2065,
+ 0xCAA1: 2066,
+ 0xCAB6: 2067,
+ 0xCAC1: 2068,
+ 0xCAE1: 2069,
+ 0xCAE2: 2070,
+ 0xCAE5: 2071,
+ 0xCAE9: 2072,
+ 0xCAF1: 2073,
+ 0xCAF3: 2074,
+ 0xCAF7: 2075,
+ 0xCB41: 2076,
+ 0xCB45: 2077,
+ 0xCB49: 2078,
+ 0xCB51: 2079,
+ 0xCB57: 2080,
+ 0xCB61: 2081,
+ 0xCB62: 2082,
+ 0xCB65: 2083,
+ 0xCB68: 2084,
+ 0xCB69: 2085,
+ 0xCB6B: 2086,
+ 0xCB71: 2087,
+ 0xCB73: 2088,
+ 0xCB75: 2089,
+ 0xCB81: 2090,
+ 0xCB85: 2091,
+ 0xCB89: 2092,
+ 0xCB91: 2093,
+ 0xCB93: 2094,
+ 0xCBA1: 2095,
+ 0xCBA2: 2096,
+ 0xCBA5: 2097,
+ 0xCBA9: 2098,
+ 0xCBB1: 2099,
+ 0xCBB3: 2100,
+ 0xCBB5: 2101,
+ 0xCBB7: 2102,
+ 0xCC61: 2103,
+ 0xCC62: 2104,
+ 0xCC63: 2105,
+ 0xCC65: 2106,
+ 0xCC69: 2107,
+ 0xCC6B: 2108,
+ 0xCC71: 2109,
+ 0xCC73: 2110,
+ 0xCC75: 2111,
+ 0xCC76: 2112,
+ 0xCC77: 2113,
+ 0xCC7B: 2114,
+ 0xCC81: 2115,
+ 0xCC82: 2116,
+ 0xCC85: 2117,
+ 0xCC89: 2118,
+ 0xCC91: 2119,
+ 0xCC93: 2120,
+ 0xCC95: 2121,
+ 0xCC96: 2122,
+ 0xCC97: 2123,
+ 0xCCA1: 2124,
+ 0xCCA2: 2125,
+ 0xCCE1: 2126,
+ 0xCCE2: 2127,
+ 0xCCE5: 2128,
+ 0xCCE9: 2129,
+ 0xCCF1: 2130,
+ 0xCCF3: 2131,
+ 0xCCF5: 2132,
+ 0xCCF6: 2133,
+ 0xCCF7: 2134,
+ 0xCD41: 2135,
+ 0xCD42: 2136,
+ 0xCD45: 2137,
+ 0xCD49: 2138,
+ 0xCD51: 2139,
+ 0xCD53: 2140,
+ 0xCD55: 2141,
+ 0xCD57: 2142,
+ 0xCD61: 2143,
+ 0xCD65: 2144,
+ 0xCD69: 2145,
+ 0xCD71: 2146,
+ 0xCD73: 2147,
+ 0xCD76: 2148,
+ 0xCD77: 2149,
+ 0xCD81: 2150,
+ 0xCD89: 2151,
+ 0xCD93: 2152,
+ 0xCD95: 2153,
+ 0xCDA1: 2154,
+ 0xCDA2: 2155,
+ 0xCDA5: 2156,
+ 0xCDA9: 2157,
+ 0xCDB1: 2158,
+ 0xCDB3: 2159,
+ 0xCDB5: 2160,
+ 0xCDB7: 2161,
+ 0xCDC1: 2162,
+ 0xCDD7: 2163,
+ 0xCE41: 2164,
+ 0xCE45: 2165,
+ 0xCE61: 2166,
+ 0xCE65: 2167,
+ 0xCE69: 2168,
+ 0xCE73: 2169,
+ 0xCE75: 2170,
+ 0xCE81: 2171,
+ 0xCE82: 2172,
+ 0xCE85: 2173,
+ 0xCE88: 2174,
+ 0xCE89: 2175,
+ 0xCE8B: 2176,
+ 0xCE91: 2177,
+ 0xCE93: 2178,
+ 0xCE95: 2179,
+ 0xCE97: 2180,
+ 0xCEA1: 2181,
+ 0xCEB7: 2182,
+ 0xCEE1: 2183,
+ 0xCEE5: 2184,
+ 0xCEE9: 2185,
+ 0xCEF1: 2186,
+ 0xCEF5: 2187,
+ 0xCF41: 2188,
+ 0xCF45: 2189,
+ 0xCF49: 2190,
+ 0xCF51: 2191,
+ 0xCF55: 2192,
+ 0xCF57: 2193,
+ 0xCF61: 2194,
+ 0xCF65: 2195,
+ 0xCF69: 2196,
+ 0xCF71: 2197,
+ 0xCF73: 2198,
+ 0xCF75: 2199,
+ 0xCFA1: 2200,
+ 0xCFA2: 2201,
+ 0xCFA5: 2202,
+ 0xCFA9: 2203,
+ 0xCFB1: 2204,
+ 0xCFB3: 2205,
+ 0xCFB5: 2206,
+ 0xCFB7: 2207,
+ 0xD061: 2208,
+ 0xD062: 2209,
+ 0xD065: 2210,
+ 0xD069: 2211,
+ 0xD06E: 2212,
+ 0xD071: 2213,
+ 0xD073: 2214,
+ 0xD075: 2215,
+ 0xD077: 2216,
+ 0xD081: 2217,
+ 0xD082: 2218,
+ 0xD085: 2219,
+ 0xD089: 2220,
+ 0xD091: 2221,
+ 0xD093: 2222,
+ 0xD095: 2223,
+ 0xD096: 2224,
+ 0xD097: 2225,
+ 0xD0A1: 2226,
+ 0xD0B7: 2227,
+ 0xD0E1: 2228,
+ 0xD0E2: 2229,
+ 0xD0E5: 2230,
+ 0xD0E9: 2231,
+ 0xD0EB: 2232,
+ 0xD0F1: 2233,
+ 0xD0F3: 2234,
+ 0xD0F5: 2235,
+ 0xD0F7: 2236,
+ 0xD141: 2237,
+ 0xD142: 2238,
+ 0xD145: 2239,
+ 0xD149: 2240,
+ 0xD151: 2241,
+ 0xD153: 2242,
+ 0xD155: 2243,
+ 0xD157: 2244,
+ 0xD161: 2245,
+ 0xD162: 2246,
+ 0xD165: 2247,
+ 0xD169: 2248,
+ 0xD171: 2249,
+ 0xD173: 2250,
+ 0xD175: 2251,
+ 0xD176: 2252,
+ 0xD177: 2253,
+ 0xD181: 2254,
+ 0xD185: 2255,
+ 0xD189: 2256,
+ 0xD193: 2257,
+ 0xD1A1: 2258,
+ 0xD1A2: 2259,
+ 0xD1A5: 2260,
+ 0xD1A9: 2261,
+ 0xD1AE: 2262,
+ 0xD1B1: 2263,
+ 0xD1B3: 2264,
+ 0xD1B5: 2265,
+ 0xD1B7: 2266,
+ 0xD1BB: 2267,
+ 0xD1C1: 2268,
+ 0xD1C2: 2269,
+ 0xD1C5: 2270,
+ 0xD1C9: 2271,
+ 0xD1D5: 2272,
+ 0xD1D7: 2273,
+ 0xD1E1: 2274,
+ 0xD1E2: 2275,
+ 0xD1E5: 2276,
+ 0xD1F5: 2277,
+ 0xD1F7: 2278,
+ 0xD241: 2279,
+ 0xD242: 2280,
+ 0xD245: 2281,
+ 0xD249: 2282,
+ 0xD253: 2283,
+ 0xD255: 2284,
+ 0xD257: 2285,
+ 0xD261: 2286,
+ 0xD265: 2287,
+ 0xD269: 2288,
+ 0xD273: 2289,
+ 0xD275: 2290,
+ 0xD281: 2291,
+ 0xD282: 2292,
+ 0xD285: 2293,
+ 0xD289: 2294,
+ 0xD28E: 2295,
+ 0xD291: 2296,
+ 0xD295: 2297,
+ 0xD297: 2298,
+ 0xD2A1: 2299,
+ 0xD2A5: 2300,
+ 0xD2A9: 2301,
+ 0xD2B1: 2302,
+ 0xD2B7: 2303,
+ 0xD2C1: 2304,
+ 0xD2C2: 2305,
+ 0xD2C5: 2306,
+ 0xD2C9: 2307,
+ 0xD2D7: 2308,
+ 0xD2E1: 2309,
+ 0xD2E2: 2310,
+ 0xD2E5: 2311,
+ 0xD2E9: 2312,
+ 0xD2F1: 2313,
+ 0xD2F3: 2314,
+ 0xD2F5: 2315,
+ 0xD2F7: 2316,
+ 0xD341: 2317,
+ 0xD342: 2318,
+ 0xD345: 2319,
+ 0xD349: 2320,
+ 0xD351: 2321,
+ 0xD355: 2322,
+ 0xD357: 2323,
+ 0xD361: 2324,
+ 0xD362: 2325,
+ 0xD365: 2326,
+ 0xD367: 2327,
+ 0xD368: 2328,
+ 0xD369: 2329,
+ 0xD36A: 2330,
+ 0xD371: 2331,
+ 0xD373: 2332,
+ 0xD375: 2333,
+ 0xD377: 2334,
+ 0xD37B: 2335,
+ 0xD381: 2336,
+ 0xD385: 2337,
+ 0xD389: 2338,
+ 0xD391: 2339,
+ 0xD393: 2340,
+ 0xD397: 2341,
+ 0xD3A1: 2342,
+ 0xD3A2: 2343,
+ 0xD3A5: 2344,
+ 0xD3A9: 2345,
+ 0xD3B1: 2346,
+ 0xD3B3: 2347,
+ 0xD3B5: 2348,
+ 0xD3B7: 2349,
+}
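
For readers skimming the patch, here is a minimal sketch of how a table like JOHAB_TO_EUCKR_ORDER_TABLE can be consulted, assuming the patched chardet package is importable. The helper below is hypothetical and not part of the diff: it combines a Johab lead/trail byte pair into the 16-bit key used above and returns the euc-kr frequency order, with -1 as a conventional "unmapped" fallback.

    # Hypothetical helper (not part of this patch): look up the euc-kr
    # frequency order for a Johab lead/trail byte pair.
    from chardet.johabfreq import JOHAB_TO_EUCKR_ORDER_TABLE

    def johab_order(lead: int, trail: int) -> int:
        # Combine the two bytes into the 16-bit code used as the dict key.
        code = (lead << 8) | trail          # e.g. 0x88, 0x61 -> 0x8861
        return JOHAB_TO_EUCKR_ORDER_TABLE.get(code, -1)

    print(johab_order(0x88, 0x61))          # 0, the first entry in the table
    print(johab_order(0xFF, 0xFF))          # -1, not a mapped Johab pair
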
diff --git a/lib/chardet/johabprober.py b/lib/chardet/johabprober.py
new file mode 100644
index 0000000..6f359d1
--- /dev/null
+++ b/lib/chardet/johabprober.py
@@ -0,0 +1,47 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import JOHABDistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import JOHAB_SM_MODEL
+
+
+class JOHABProber(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(JOHAB_SM_MODEL)
+ self.distribution_analyzer = JOHABDistributionAnalysis()
+ self.reset()
+
+ @property
+ def charset_name(self):
+ return "Johab"
+
+ @property
+ def language(self):
+ return "Korean"
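
Likewise, a minimal usage sketch for the new prober, assuming the feed()/get_confidence() interface inherited from the existing chardet prober base classes; the sample byte string is illustrative only, not real Korean text.

    from chardet.johabprober import JOHABProber

    prober = JOHABProber()
    prober.feed(b"\x88\x61\x88\x62\x88\x65")   # a few Johab byte pairs
    print(prober.charset_name)                 # "Johab"
    print(prober.language)                     # "Korean"
    print(prober.get_confidence())             # a float between 0.0 and 1.0
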
diff --git a/lib/chardet/jpcntx.py b/lib/chardet/jpcntx.py
new file mode 100644
index 0000000..7a8e5be
--- /dev/null
+++ b/lib/chardet/jpcntx.py
@@ -0,0 +1,237 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Communicator client code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+
+# This is the hiragana 2-char sequence table; the number in each cell represents its frequency category
+# fmt: off
+jp2_char_context = (
+ (0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1),
+ (2, 4, 0, 4, 0, 3, 0, 4, 0, 3, 4, 4, 4, 2, 4, 3, 3, 4, 3, 2, 3, 3, 4, 2, 3, 3, 3, 2, 4, 1, 4, 3, 3, 1, 5, 4, 3, 4, 3, 4, 3, 5, 3, 0, 3, 5, 4, 2, 0, 3, 1, 0, 3, 3, 0, 3, 3, 0, 1, 1, 0, 4, 3, 0, 3, 3, 0, 4, 0, 2, 0, 3, 5, 5, 5, 5, 4, 0, 4, 1, 0, 3, 4),
+ (0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2),
+ (0, 4, 0, 5, 0, 5, 0, 4, 0, 4, 5, 4, 4, 3, 5, 3, 5, 1, 5, 3, 4, 3, 4, 4, 3, 4, 3, 3, 4, 3, 5, 4, 4, 3, 5, 5, 3, 5, 5, 5, 3, 5, 5, 3, 4, 5, 5, 3, 1, 3, 2, 0, 3, 4, 0, 4, 2, 0, 4, 2, 1, 5, 3, 2, 3, 5, 0, 4, 0, 2, 0, 5, 4, 4, 5, 4, 5, 0, 4, 0, 0, 4, 4),
+ (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 3, 0, 4, 0, 3, 0, 3, 0, 4, 5, 4, 3, 3, 3, 3, 4, 3, 5, 4, 4, 3, 5, 4, 4, 3, 4, 3, 4, 4, 4, 4, 5, 3, 4, 4, 3, 4, 5, 5, 4, 5, 5, 1, 4, 5, 4, 3, 0, 3, 3, 1, 3, 3, 0, 4, 4, 0, 3, 3, 1, 5, 3, 3, 3, 5, 0, 4, 0, 3, 0, 4, 4, 3, 4, 3, 3, 0, 4, 1, 1, 3, 4),
+ (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 4, 0, 3, 0, 3, 0, 4, 0, 3, 4, 4, 3, 2, 2, 1, 2, 1, 3, 1, 3, 3, 3, 3, 3, 4, 3, 1, 3, 3, 5, 3, 3, 0, 4, 3, 0, 5, 4, 3, 3, 5, 4, 4, 3, 4, 4, 5, 0, 1, 2, 0, 1, 2, 0, 2, 2, 0, 1, 0, 0, 5, 2, 2, 1, 4, 0, 3, 0, 1, 0, 4, 4, 3, 5, 4, 3, 0, 2, 1, 0, 4, 3),
+ (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 3, 0, 5, 0, 4, 0, 2, 1, 4, 4, 2, 4, 1, 4, 2, 4, 2, 4, 3, 3, 3, 4, 3, 3, 3, 3, 1, 4, 2, 3, 3, 3, 1, 4, 4, 1, 1, 1, 4, 3, 3, 2, 0, 2, 4, 3, 2, 0, 3, 3, 0, 3, 1, 1, 0, 0, 0, 3, 3, 0, 4, 2, 2, 3, 4, 0, 4, 0, 3, 0, 4, 4, 5, 3, 4, 4, 0, 3, 0, 0, 1, 4),
+ (1, 4, 0, 4, 0, 4, 0, 4, 0, 3, 5, 4, 4, 3, 4, 3, 5, 4, 3, 3, 4, 3, 5, 4, 4, 4, 4, 3, 4, 2, 4, 3, 3, 1, 5, 4, 3, 2, 4, 5, 4, 5, 5, 4, 4, 5, 4, 4, 0, 3, 2, 2, 3, 3, 0, 4, 3, 1, 3, 2, 1, 4, 3, 3, 4, 5, 0, 3, 0, 2, 0, 4, 5, 5, 4, 5, 4, 0, 4, 0, 0, 5, 4),
+ (0, 5, 0, 5, 0, 4, 0, 3, 0, 4, 4, 3, 4, 3, 3, 3, 4, 0, 4, 4, 4, 3, 4, 3, 4, 3, 3, 1, 4, 2, 4, 3, 4, 0, 5, 4, 1, 4, 5, 4, 4, 5, 3, 2, 4, 3, 4, 3, 2, 4, 1, 3, 3, 3, 2, 3, 2, 0, 4, 3, 3, 4, 3, 3, 3, 4, 0, 4, 0, 3, 0, 4, 5, 4, 4, 4, 3, 0, 4, 1, 0, 1, 3),
+ (0, 3, 1, 4, 0, 3, 0, 2, 0, 3, 4, 4, 3, 1, 4, 2, 3, 3, 4, 3, 4, 3, 4, 3, 4, 4, 3, 2, 3, 1, 5, 4, 4, 1, 4, 4, 3, 5, 4, 4, 3, 5, 5, 4, 3, 4, 4, 3, 1, 2, 3, 1, 2, 2, 0, 3, 2, 0, 3, 1, 0, 5, 3, 3, 3, 4, 3, 3, 3, 3, 4, 4, 4, 4, 5, 4, 2, 0, 3, 3, 2, 4, 3),
+ (0, 2, 0, 3, 0, 1, 0, 1, 0, 0, 3, 2, 0, 0, 2, 0, 1, 0, 2, 1, 3, 3, 3, 1, 2, 3, 1, 0, 1, 0, 4, 2, 1, 1, 3, 3, 0, 4, 3, 3, 1, 4, 3, 3, 0, 3, 3, 2, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 0, 0, 0, 4, 1, 0, 2, 3, 2, 2, 2, 1, 3, 3, 3, 4, 4, 3, 2, 0, 3, 1, 0, 3, 3),
+ (0, 4, 0, 4, 0, 3, 0, 3, 0, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 3, 4, 2, 4, 3, 4, 3, 3, 2, 4, 3, 4, 5, 4, 1, 4, 5, 3, 5, 4, 5, 3, 5, 4, 0, 3, 5, 5, 3, 1, 3, 3, 2, 2, 3, 0, 3, 4, 1, 3, 3, 2, 4, 3, 3, 3, 4, 0, 4, 0, 3, 0, 4, 5, 4, 4, 5, 3, 0, 4, 1, 0, 3, 4),
+ (0, 2, 0, 3, 0, 3, 0, 0, 0, 2, 2, 2, 1, 0, 1, 0, 0, 0, 3, 0, 3, 0, 3, 0, 1, 3, 1, 0, 3, 1, 3, 3, 3, 1, 3, 3, 3, 0, 1, 3, 1, 3, 4, 0, 0, 3, 1, 1, 0, 3, 2, 0, 0, 0, 0, 1, 3, 0, 1, 0, 0, 3, 3, 2, 0, 3, 0, 0, 0, 0, 0, 3, 4, 3, 4, 3, 3, 0, 3, 0, 0, 2, 3),
+ (2, 3, 0, 3, 0, 2, 0, 1, 0, 3, 3, 4, 3, 1, 3, 1, 1, 1, 3, 1, 4, 3, 4, 3, 3, 3, 0, 0, 3, 1, 5, 4, 3, 1, 4, 3, 2, 5, 5, 4, 4, 4, 4, 3, 3, 4, 4, 4, 0, 2, 1, 1, 3, 2, 0, 1, 2, 0, 0, 1, 0, 4, 1, 3, 3, 3, 0, 3, 0, 1, 0, 4, 4, 4, 5, 5, 3, 0, 2, 0, 0, 4, 4),
+ (0, 2, 0, 1, 0, 3, 1, 3, 0, 2, 3, 3, 3, 0, 3, 1, 0, 0, 3, 0, 3, 2, 3, 1, 3, 2, 1, 1, 0, 0, 4, 2, 1, 0, 2, 3, 1, 4, 3, 2, 0, 4, 4, 3, 1, 3, 1, 3, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 4, 1, 1, 1, 2, 0, 3, 0, 0, 0, 3, 4, 2, 4, 3, 2, 0, 1, 0, 0, 3, 3),
+ (0, 1, 0, 4, 0, 5, 0, 4, 0, 2, 4, 4, 2, 3, 3, 2, 3, 3, 5, 3, 3, 3, 4, 3, 4, 2, 3, 0, 4, 3, 3, 3, 4, 1, 4, 3, 2, 1, 5, 5, 3, 4, 5, 1, 3, 5, 4, 2, 0, 3, 3, 0, 1, 3, 0, 4, 2, 0, 1, 3, 1, 4, 3, 3, 3, 3, 0, 3, 0, 1, 0, 3, 4, 4, 4, 5, 5, 0, 3, 0, 1, 4, 5),
+ (0, 2, 0, 3, 0, 3, 0, 0, 0, 2, 3, 1, 3, 0, 4, 0, 1, 1, 3, 0, 3, 4, 3, 2, 3, 1, 0, 3, 3, 2, 3, 1, 3, 0, 2, 3, 0, 2, 1, 4, 1, 2, 2, 0, 0, 3, 3, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 2, 2, 0, 3, 2, 1, 3, 3, 0, 2, 0, 2, 0, 0, 3, 3, 1, 2, 4, 0, 3, 0, 2, 2, 3),
+ (2, 4, 0, 5, 0, 4, 0, 4, 0, 2, 4, 4, 4, 3, 4, 3, 3, 3, 1, 2, 4, 3, 4, 3, 4, 4, 5, 0, 3, 3, 3, 3, 2, 0, 4, 3, 1, 4, 3, 4, 1, 4, 4, 3, 3, 4, 4, 3, 1, 2, 3, 0, 4, 2, 0, 4, 1, 0, 3, 3, 0, 4, 3, 3, 3, 4, 0, 4, 0, 2, 0, 3, 5, 3, 4, 5, 2, 0, 3, 0, 0, 4, 5),
+ (0, 3, 0, 4, 0, 1, 0, 1, 0, 1, 3, 2, 2, 1, 3, 0, 3, 0, 2, 0, 2, 0, 3, 0, 2, 0, 0, 0, 1, 0, 1, 1, 0, 0, 3, 1, 0, 0, 0, 4, 0, 3, 1, 0, 2, 1, 3, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 2, 2, 3, 1, 0, 3, 0, 0, 0, 1, 4, 4, 4, 3, 0, 0, 4, 0, 0, 1, 4),
+ (1, 4, 1, 5, 0, 3, 0, 3, 0, 4, 5, 4, 4, 3, 5, 3, 3, 4, 4, 3, 4, 1, 3, 3, 3, 3, 2, 1, 4, 1, 5, 4, 3, 1, 4, 4, 3, 5, 4, 4, 3, 5, 4, 3, 3, 4, 4, 4, 0, 3, 3, 1, 2, 3, 0, 3, 1, 0, 3, 3, 0, 5, 4, 4, 4, 4, 4, 4, 3, 3, 5, 4, 4, 3, 3, 5, 4, 0, 3, 2, 0, 4, 4),
+ (0, 2, 0, 3, 0, 1, 0, 0, 0, 1, 3, 3, 3, 2, 4, 1, 3, 0, 3, 1, 3, 0, 2, 2, 1, 1, 0, 0, 2, 0, 4, 3, 1, 0, 4, 3, 0, 4, 4, 4, 1, 4, 3, 1, 1, 3, 3, 1, 0, 2, 0, 0, 1, 3, 0, 0, 0, 0, 2, 0, 0, 4, 3, 2, 4, 3, 5, 4, 3, 3, 3, 4, 3, 3, 4, 3, 3, 0, 2, 1, 0, 3, 3),
+ (0, 2, 0, 4, 0, 3, 0, 2, 0, 2, 5, 5, 3, 4, 4, 4, 4, 1, 4, 3, 3, 0, 4, 3, 4, 3, 1, 3, 3, 2, 4, 3, 0, 3, 4, 3, 0, 3, 4, 4, 2, 4, 4, 0, 4, 5, 3, 3, 2, 2, 1, 1, 1, 2, 0, 1, 5, 0, 3, 3, 2, 4, 3, 3, 3, 4, 0, 3, 0, 2, 0, 4, 4, 3, 5, 5, 0, 0, 3, 0, 2, 3, 3),
+ (0, 3, 0, 4, 0, 3, 0, 1, 0, 3, 4, 3, 3, 1, 3, 3, 3, 0, 3, 1, 3, 0, 4, 3, 3, 1, 1, 0, 3, 0, 3, 3, 0, 0, 4, 4, 0, 1, 5, 4, 3, 3, 5, 0, 3, 3, 4, 3, 0, 2, 0, 1, 1, 1, 0, 1, 3, 0, 1, 2, 1, 3, 3, 2, 3, 3, 0, 3, 0, 1, 0, 1, 3, 3, 4, 4, 1, 0, 1, 2, 2, 1, 3),
+ (0, 1, 0, 4, 0, 4, 0, 3, 0, 1, 3, 3, 3, 2, 3, 1, 1, 0, 3, 0, 3, 3, 4, 3, 2, 4, 2, 0, 1, 0, 4, 3, 2, 0, 4, 3, 0, 5, 3, 3, 2, 4, 4, 4, 3, 3, 3, 4, 0, 1, 3, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 4, 2, 3, 3, 3, 0, 3, 0, 0, 0, 4, 4, 4, 5, 3, 2, 0, 3, 3, 0, 3, 5),
+ (0, 2, 0, 3, 0, 0, 0, 3, 0, 1, 3, 0, 2, 0, 0, 0, 1, 0, 3, 1, 1, 3, 3, 0, 0, 3, 0, 0, 3, 0, 2, 3, 1, 0, 3, 1, 0, 3, 3, 2, 0, 4, 2, 2, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 1, 0, 0, 0, 1, 3, 1, 2, 0, 0, 0, 1, 0, 0, 1, 4),
+ (0, 3, 0, 3, 0, 5, 0, 1, 0, 2, 4, 3, 1, 3, 3, 2, 1, 1, 5, 2, 1, 0, 5, 1, 2, 0, 0, 0, 3, 3, 2, 2, 3, 2, 4, 3, 0, 0, 3, 3, 1, 3, 3, 0, 2, 5, 3, 4, 0, 3, 3, 0, 1, 2, 0, 2, 2, 0, 3, 2, 0, 2, 2, 3, 3, 3, 0, 2, 0, 1, 0, 3, 4, 4, 2, 5, 4, 0, 3, 0, 0, 3, 5),
+ (0, 3, 0, 3, 0, 3, 0, 1, 0, 3, 3, 3, 3, 0, 3, 0, 2, 0, 2, 1, 1, 0, 2, 0, 1, 0, 0, 0, 2, 1, 0, 0, 1, 0, 3, 2, 0, 0, 3, 3, 1, 2, 3, 1, 0, 3, 3, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 3, 1, 2, 3, 0, 3, 0, 1, 0, 3, 2, 1, 0, 4, 3, 0, 1, 1, 0, 3, 3),
+ (0, 4, 0, 5, 0, 3, 0, 3, 0, 4, 5, 5, 4, 3, 5, 3, 4, 3, 5, 3, 3, 2, 5, 3, 4, 4, 4, 3, 4, 3, 4, 5, 5, 3, 4, 4, 3, 4, 4, 5, 4, 4, 4, 3, 4, 5, 5, 4, 2, 3, 4, 2, 3, 4, 0, 3, 3, 1, 4, 3, 2, 4, 3, 3, 5, 5, 0, 3, 0, 3, 0, 5, 5, 5, 5, 4, 4, 0, 4, 0, 1, 4, 4),
+ (0, 4, 0, 4, 0, 3, 0, 3, 0, 3, 5, 4, 4, 2, 3, 2, 5, 1, 3, 2, 5, 1, 4, 2, 3, 2, 3, 3, 4, 3, 3, 3, 3, 2, 5, 4, 1, 3, 3, 5, 3, 4, 4, 0, 4, 4, 3, 1, 1, 3, 1, 0, 2, 3, 0, 2, 3, 0, 3, 0, 0, 4, 3, 1, 3, 4, 0, 3, 0, 2, 0, 4, 4, 4, 3, 4, 5, 0, 4, 0, 0, 3, 4),
+ (0, 3, 0, 3, 0, 3, 1, 2, 0, 3, 4, 4, 3, 3, 3, 0, 2, 2, 4, 3, 3, 1, 3, 3, 3, 1, 1, 0, 3, 1, 4, 3, 2, 3, 4, 4, 2, 4, 4, 4, 3, 4, 4, 3, 2, 4, 4, 3, 1, 3, 3, 1, 3, 3, 0, 4, 1, 0, 2, 2, 1, 4, 3, 2, 3, 3, 5, 4, 3, 3, 5, 4, 4, 3, 3, 0, 4, 0, 3, 2, 2, 4, 4),
+ (0, 2, 0, 1, 0, 0, 0, 0, 0, 1, 2, 1, 3, 0, 0, 0, 0, 0, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 0, 0, 3, 0, 0, 1, 0, 1, 1, 3, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 3, 4, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1),
+ (0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0, 4, 1, 4, 0, 3, 0, 4, 0, 3, 0, 4, 0, 3, 0, 3, 0, 4, 1, 5, 1, 4, 0, 0, 3, 0, 5, 0, 5, 2, 0, 1, 0, 0, 0, 2, 1, 4, 0, 1, 3, 0, 0, 3, 0, 0, 3, 1, 1, 4, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0),
+ (1, 4, 0, 5, 0, 3, 0, 2, 0, 3, 5, 4, 4, 3, 4, 3, 5, 3, 4, 3, 3, 0, 4, 3, 3, 3, 3, 3, 3, 2, 4, 4, 3, 1, 3, 4, 4, 5, 4, 4, 3, 4, 4, 1, 3, 5, 4, 3, 3, 3, 1, 2, 2, 3, 3, 1, 3, 1, 3, 3, 3, 5, 3, 3, 4, 5, 0, 3, 0, 3, 0, 3, 4, 3, 4, 4, 3, 0, 3, 0, 2, 4, 3),
+ (0, 1, 0, 4, 0, 0, 0, 0, 0, 1, 4, 0, 4, 1, 4, 2, 4, 0, 3, 0, 1, 0, 1, 0, 0, 0, 0, 0, 2, 0, 3, 1, 1, 1, 0, 3, 0, 0, 0, 1, 2, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 3, 0, 0, 0, 0, 3, 2, 0, 2, 2, 0, 1, 0, 0, 0, 2, 3, 2, 3, 3, 0, 0, 0, 0, 2, 1, 0),
+ (0, 5, 1, 5, 0, 3, 0, 3, 0, 5, 4, 4, 5, 1, 5, 3, 3, 0, 4, 3, 4, 3, 5, 3, 4, 3, 3, 2, 4, 3, 4, 3, 3, 0, 3, 3, 1, 4, 4, 3, 4, 4, 4, 3, 4, 5, 5, 3, 2, 3, 1, 1, 3, 3, 1, 3, 1, 1, 3, 3, 2, 4, 5, 3, 3, 5, 0, 4, 0, 3, 0, 4, 4, 3, 5, 3, 3, 0, 3, 4, 0, 4, 3),
+ (0, 5, 0, 5, 0, 3, 0, 2, 0, 4, 4, 3, 5, 2, 4, 3, 3, 3, 4, 4, 4, 3, 5, 3, 5, 3, 3, 1, 4, 0, 4, 3, 3, 0, 3, 3, 0, 4, 4, 4, 4, 5, 4, 3, 3, 5, 5, 3, 2, 3, 1, 2, 3, 2, 0, 1, 0, 0, 3, 2, 2, 4, 4, 3, 1, 5, 0, 4, 0, 3, 0, 4, 3, 1, 3, 2, 1, 0, 3, 3, 0, 3, 3),
+ (0, 4, 0, 5, 0, 5, 0, 4, 0, 4, 5, 5, 5, 3, 4, 3, 3, 2, 5, 4, 4, 3, 5, 3, 5, 3, 4, 0, 4, 3, 4, 4, 3, 2, 4, 4, 3, 4, 5, 4, 4, 5, 5, 0, 3, 5, 5, 4, 1, 3, 3, 2, 3, 3, 1, 3, 1, 0, 4, 3, 1, 4, 4, 3, 4, 5, 0, 4, 0, 2, 0, 4, 3, 4, 4, 3, 3, 0, 4, 0, 0, 5, 5),
+ (0, 4, 0, 4, 0, 5, 0, 1, 1, 3, 3, 4, 4, 3, 4, 1, 3, 0, 5, 1, 3, 0, 3, 1, 3, 1, 1, 0, 3, 0, 3, 3, 4, 0, 4, 3, 0, 4, 4, 4, 3, 4, 4, 0, 3, 5, 4, 1, 0, 3, 0, 0, 2, 3, 0, 3, 1, 0, 3, 1, 0, 3, 2, 1, 3, 5, 0, 3, 0, 1, 0, 3, 2, 3, 3, 4, 4, 0, 2, 2, 0, 4, 4),
+ (2, 4, 0, 5, 0, 4, 0, 3, 0, 4, 5, 5, 4, 3, 5, 3, 5, 3, 5, 3, 5, 2, 5, 3, 4, 3, 3, 4, 3, 4, 5, 3, 2, 1, 5, 4, 3, 2, 3, 4, 5, 3, 4, 1, 2, 5, 4, 3, 0, 3, 3, 0, 3, 2, 0, 2, 3, 0, 4, 1, 0, 3, 4, 3, 3, 5, 0, 3, 0, 1, 0, 4, 5, 5, 5, 4, 3, 0, 4, 2, 0, 3, 5),
+ (0, 5, 0, 4, 0, 4, 0, 2, 0, 5, 4, 3, 4, 3, 4, 3, 3, 3, 4, 3, 4, 2, 5, 3, 5, 3, 4, 1, 4, 3, 4, 4, 4, 0, 3, 5, 0, 4, 4, 4, 4, 5, 3, 1, 3, 4, 5, 3, 3, 3, 3, 3, 3, 3, 0, 2, 2, 0, 3, 3, 2, 4, 3, 3, 3, 5, 3, 4, 1, 3, 3, 5, 3, 2, 0, 0, 0, 0, 4, 3, 1, 3, 3),
+ (0, 1, 0, 3, 0, 3, 0, 1, 0, 1, 3, 3, 3, 2, 3, 3, 3, 0, 3, 0, 0, 0, 3, 1, 3, 0, 0, 0, 2, 2, 2, 3, 0, 0, 3, 2, 0, 1, 2, 4, 1, 3, 3, 0, 0, 3, 3, 3, 0, 1, 0, 0, 2, 1, 0, 0, 3, 0, 3, 1, 0, 3, 0, 0, 1, 3, 0, 2, 0, 1, 0, 3, 3, 1, 3, 3, 0, 0, 1, 1, 0, 3, 3),
+ (0, 2, 0, 3, 0, 2, 1, 4, 0, 2, 2, 3, 1, 1, 3, 1, 1, 0, 2, 0, 3, 1, 2, 3, 1, 3, 0, 0, 1, 0, 4, 3, 2, 3, 3, 3, 1, 4, 2, 3, 3, 3, 3, 1, 0, 3, 1, 4, 0, 1, 1, 0, 1, 2, 0, 1, 1, 0, 1, 1, 0, 3, 1, 3, 2, 2, 0, 1, 0, 0, 0, 2, 3, 3, 3, 1, 0, 0, 0, 0, 0, 2, 3),
+ (0, 5, 0, 4, 0, 5, 0, 2, 0, 4, 5, 5, 3, 3, 4, 3, 3, 1, 5, 4, 4, 2, 4, 4, 4, 3, 4, 2, 4, 3, 5, 5, 4, 3, 3, 4, 3, 3, 5, 5, 4, 5, 5, 1, 3, 4, 5, 3, 1, 4, 3, 1, 3, 3, 0, 3, 3, 1, 4, 3, 1, 4, 5, 3, 3, 5, 0, 4, 0, 3, 0, 5, 3, 3, 1, 4, 3, 0, 4, 0, 1, 5, 3),
+ (0, 5, 0, 5, 0, 4, 0, 2, 0, 4, 4, 3, 4, 3, 3, 3, 3, 3, 5, 4, 4, 4, 4, 4, 4, 5, 3, 3, 5, 2, 4, 4, 4, 3, 4, 4, 3, 3, 4, 4, 5, 5, 3, 3, 4, 3, 4, 3, 3, 4, 3, 3, 3, 3, 1, 2, 2, 1, 4, 3, 3, 5, 4, 4, 3, 4, 0, 4, 0, 3, 0, 4, 4, 4, 4, 4, 1, 0, 4, 2, 0, 2, 4),
+ (0, 4, 0, 4, 0, 3, 0, 1, 0, 3, 5, 2, 3, 0, 3, 0, 2, 1, 4, 2, 3, 3, 4, 1, 4, 3, 3, 2, 4, 1, 3, 3, 3, 0, 3, 3, 0, 0, 3, 3, 3, 5, 3, 3, 3, 3, 3, 2, 0, 2, 0, 0, 2, 0, 0, 2, 0, 0, 1, 0, 0, 3, 1, 2, 2, 3, 0, 3, 0, 2, 0, 4, 4, 3, 3, 4, 1, 0, 3, 0, 0, 2, 4),
+ (0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 0, 3, 1, 3, 0, 3, 2, 0, 0, 0, 1, 0, 3, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 2, 0, 0, 0, 0, 0, 0, 2),
+ (0, 2, 1, 3, 0, 2, 0, 2, 0, 3, 3, 3, 3, 1, 3, 1, 3, 3, 3, 3, 3, 3, 4, 2, 2, 1, 2, 1, 4, 0, 4, 3, 1, 3, 3, 3, 2, 4, 3, 5, 4, 3, 3, 3, 3, 3, 3, 3, 0, 1, 3, 0, 2, 0, 0, 1, 0, 0, 1, 0, 0, 4, 2, 0, 2, 3, 0, 3, 3, 0, 3, 3, 4, 2, 3, 1, 4, 0, 1, 2, 0, 2, 3),
+ (0, 3, 0, 3, 0, 1, 0, 3, 0, 2, 3, 3, 3, 0, 3, 1, 2, 0, 3, 3, 2, 3, 3, 2, 3, 2, 3, 1, 3, 0, 4, 3, 2, 0, 3, 3, 1, 4, 3, 3, 2, 3, 4, 3, 1, 3, 3, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 4, 1, 1, 0, 3, 0, 3, 1, 0, 2, 3, 3, 3, 3, 3, 1, 0, 0, 2, 0, 3, 3),
+ (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 3, 0, 3, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 2, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 3),
+ (0, 2, 0, 3, 1, 3, 0, 3, 0, 2, 3, 3, 3, 1, 3, 1, 3, 1, 3, 1, 3, 3, 3, 1, 3, 0, 2, 3, 1, 1, 4, 3, 3, 2, 3, 3, 1, 2, 2, 4, 1, 3, 3, 0, 1, 4, 2, 3, 0, 1, 3, 0, 3, 0, 0, 1, 3, 0, 2, 0, 0, 3, 3, 2, 1, 3, 0, 3, 0, 2, 0, 3, 4, 4, 4, 3, 1, 0, 3, 0, 0, 3, 3),
+ (0, 2, 0, 1, 0, 2, 0, 0, 0, 1, 3, 2, 2, 1, 3, 0, 1, 1, 3, 0, 3, 2, 3, 1, 2, 0, 2, 0, 1, 1, 3, 3, 3, 0, 3, 3, 1, 1, 2, 3, 2, 3, 3, 1, 2, 3, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 0, 2, 1, 2, 1, 3, 0, 3, 0, 0, 0, 3, 4, 4, 4, 3, 2, 0, 2, 0, 0, 2, 4),
+ (0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 3, 1, 0, 0, 0, 0, 0, 0, 0, 3),
+ (0, 3, 0, 3, 0, 2, 0, 3, 0, 3, 3, 3, 2, 3, 2, 2, 2, 0, 3, 1, 3, 3, 3, 2, 3, 3, 0, 0, 3, 0, 3, 2, 2, 0, 2, 3, 1, 4, 3, 4, 3, 3, 2, 3, 1, 5, 4, 4, 0, 3, 1, 2, 1, 3, 0, 3, 1, 1, 2, 0, 2, 3, 1, 3, 1, 3, 0, 3, 0, 1, 0, 3, 3, 4, 4, 2, 1, 0, 2, 1, 0, 2, 4),
+ (0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 4, 2, 5, 1, 4, 0, 2, 0, 2, 1, 3, 1, 4, 0, 2, 1, 0, 0, 2, 1, 4, 1, 1, 0, 3, 3, 0, 5, 1, 3, 2, 3, 3, 1, 0, 3, 2, 3, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0, 1, 0, 3, 0, 2, 0, 1, 0, 3, 3, 3, 4, 3, 3, 0, 0, 0, 0, 2, 3),
+ (0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 0, 0, 3),
+ (0, 1, 0, 3, 0, 4, 0, 3, 0, 2, 4, 3, 1, 0, 3, 2, 2, 1, 3, 1, 2, 2, 3, 1, 1, 1, 2, 1, 3, 0, 1, 2, 0, 1, 3, 2, 1, 3, 0, 5, 5, 1, 0, 0, 1, 3, 2, 1, 0, 3, 0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 1, 1, 1, 3, 2, 0, 2, 0, 1, 0, 2, 3, 3, 1, 2, 3, 0, 1, 0, 1, 0, 4),
+ (0, 0, 0, 1, 0, 3, 0, 3, 0, 2, 2, 1, 0, 0, 4, 0, 3, 0, 3, 1, 3, 0, 3, 0, 3, 0, 1, 0, 3, 0, 3, 1, 3, 0, 3, 3, 0, 0, 1, 2, 1, 1, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 1, 2, 0, 0, 2, 0, 0, 0, 0, 2, 3, 3, 3, 3, 0, 0, 0, 0, 1, 4),
+ (0, 0, 0, 3, 0, 3, 0, 0, 0, 0, 3, 1, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 2, 0, 2, 3, 0, 0, 2, 2, 3, 1, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 2, 0, 0, 0, 0, 2, 3),
+ (2, 4, 0, 5, 0, 5, 0, 4, 0, 3, 4, 3, 3, 3, 4, 3, 3, 3, 4, 3, 4, 4, 5, 4, 5, 5, 5, 2, 3, 0, 5, 5, 4, 1, 5, 4, 3, 1, 5, 4, 3, 4, 4, 3, 3, 4, 3, 3, 0, 3, 2, 0, 2, 3, 0, 3, 0, 0, 3, 3, 0, 5, 3, 2, 3, 3, 0, 3, 0, 3, 0, 3, 4, 5, 4, 5, 3, 0, 4, 3, 0, 3, 4),
+ (0, 3, 0, 3, 0, 3, 0, 3, 0, 3, 3, 4, 3, 2, 3, 2, 3, 0, 4, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 2, 4, 3, 3, 1, 3, 4, 3, 4, 4, 4, 3, 4, 4, 3, 2, 4, 4, 1, 0, 2, 0, 0, 1, 1, 0, 2, 0, 0, 3, 1, 0, 5, 3, 2, 1, 3, 0, 3, 0, 1, 2, 4, 3, 2, 4, 3, 3, 0, 3, 2, 0, 4, 4),
+ (0, 3, 0, 3, 0, 1, 0, 0, 0, 1, 4, 3, 3, 2, 3, 1, 3, 1, 4, 2, 3, 2, 4, 2, 3, 4, 3, 0, 2, 2, 3, 3, 3, 0, 3, 3, 3, 0, 3, 4, 1, 3, 3, 0, 3, 4, 3, 3, 0, 1, 1, 0, 1, 0, 0, 0, 4, 0, 3, 0, 0, 3, 1, 2, 1, 3, 0, 4, 0, 1, 0, 4, 3, 3, 4, 3, 3, 0, 2, 0, 0, 3, 3),
+ (0, 3, 0, 4, 0, 1, 0, 3, 0, 3, 4, 3, 3, 0, 3, 3, 3, 1, 3, 1, 3, 3, 4, 3, 3, 3, 0, 0, 3, 1, 5, 3, 3, 1, 3, 3, 2, 5, 4, 3, 3, 4, 5, 3, 2, 5, 3, 4, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1, 0, 4, 2, 2, 1, 3, 0, 3, 0, 2, 0, 4, 4, 3, 5, 3, 2, 0, 1, 1, 0, 3, 4),
+ (0, 5, 0, 4, 0, 5, 0, 2, 0, 4, 4, 3, 3, 2, 3, 3, 3, 1, 4, 3, 4, 1, 5, 3, 4, 3, 4, 0, 4, 2, 4, 3, 4, 1, 5, 4, 0, 4, 4, 4, 4, 5, 4, 1, 3, 5, 4, 2, 1, 4, 1, 1, 3, 2, 0, 3, 1, 0, 3, 2, 1, 4, 3, 3, 3, 4, 0, 4, 0, 3, 0, 4, 4, 4, 3, 3, 3, 0, 4, 2, 0, 3, 4),
+ (1, 4, 0, 4, 0, 3, 0, 1, 0, 3, 3, 3, 1, 1, 3, 3, 2, 2, 3, 3, 1, 0, 3, 2, 2, 1, 2, 0, 3, 1, 2, 1, 2, 0, 3, 2, 0, 2, 2, 3, 3, 4, 3, 0, 3, 3, 1, 2, 0, 1, 1, 3, 1, 2, 0, 0, 3, 0, 1, 1, 0, 3, 2, 2, 3, 3, 0, 3, 0, 0, 0, 2, 3, 3, 4, 3, 3, 0, 1, 0, 0, 1, 4),
+ (0, 4, 0, 4, 0, 4, 0, 0, 0, 3, 4, 4, 3, 1, 4, 2, 3, 2, 3, 3, 3, 1, 4, 3, 4, 0, 3, 0, 4, 2, 3, 3, 2, 2, 5, 4, 2, 1, 3, 4, 3, 4, 3, 1, 3, 3, 4, 2, 0, 2, 1, 0, 3, 3, 0, 0, 2, 0, 3, 1, 0, 4, 4, 3, 4, 3, 0, 4, 0, 1, 0, 2, 4, 4, 4, 4, 4, 0, 3, 2, 0, 3, 3),
+ (0, 0, 0, 1, 0, 4, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 2, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2),
+ (0, 2, 0, 3, 0, 4, 0, 4, 0, 1, 3, 3, 3, 0, 4, 0, 2, 1, 2, 1, 1, 1, 2, 0, 3, 1, 1, 0, 1, 0, 3, 1, 0, 0, 3, 3, 2, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 2, 0, 3, 1, 0, 0, 1, 0, 1, 1, 0, 1, 2, 0, 3, 0, 0, 0, 0, 1, 0, 0, 3, 3, 4, 3, 1, 0, 1, 0, 3, 0, 2),
+ (0, 0, 0, 3, 0, 5, 0, 0, 0, 0, 1, 0, 2, 0, 3, 1, 0, 1, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 4, 0, 0, 0, 2, 3, 0, 1, 4, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 3),
+ (0, 2, 0, 5, 0, 5, 0, 1, 0, 2, 4, 3, 3, 2, 5, 1, 3, 2, 3, 3, 3, 0, 4, 1, 2, 0, 3, 0, 4, 0, 2, 2, 1, 1, 5, 3, 0, 0, 1, 4, 2, 3, 2, 0, 3, 3, 3, 2, 0, 2, 4, 1, 1, 2, 0, 1, 1, 0, 3, 1, 0, 1, 3, 1, 2, 3, 0, 2, 0, 0, 0, 1, 3, 5, 4, 4, 4, 0, 3, 0, 0, 1, 3),
+ (0, 4, 0, 5, 0, 4, 0, 4, 0, 4, 5, 4, 3, 3, 4, 3, 3, 3, 4, 3, 4, 4, 5, 3, 4, 5, 4, 2, 4, 2, 3, 4, 3, 1, 4, 4, 1, 3, 5, 4, 4, 5, 5, 4, 4, 5, 5, 5, 2, 3, 3, 1, 4, 3, 1, 3, 3, 0, 3, 3, 1, 4, 3, 4, 4, 4, 0, 3, 0, 4, 0, 3, 3, 4, 4, 5, 0, 0, 4, 3, 0, 4, 5),
+ (0, 4, 0, 4, 0, 3, 0, 3, 0, 3, 4, 4, 4, 3, 3, 2, 4, 3, 4, 3, 4, 3, 5, 3, 4, 3, 2, 1, 4, 2, 4, 4, 3, 1, 3, 4, 2, 4, 5, 5, 3, 4, 5, 4, 1, 5, 4, 3, 0, 3, 2, 2, 3, 2, 1, 3, 1, 0, 3, 3, 3, 5, 3, 3, 3, 5, 4, 4, 2, 3, 3, 4, 3, 3, 3, 2, 1, 0, 3, 2, 1, 4, 3),
+ (0, 4, 0, 5, 0, 4, 0, 3, 0, 3, 5, 5, 3, 2, 4, 3, 4, 0, 5, 4, 4, 1, 4, 4, 4, 3, 3, 3, 4, 3, 5, 5, 2, 3, 3, 4, 1, 2, 5, 5, 3, 5, 5, 2, 3, 5, 5, 4, 0, 3, 2, 0, 3, 3, 1, 1, 5, 1, 4, 1, 0, 4, 3, 2, 3, 5, 0, 4, 0, 3, 0, 5, 4, 3, 4, 3, 0, 0, 4, 1, 0, 4, 4),
+ (1, 3, 0, 4, 0, 2, 0, 2, 0, 2, 5, 5, 3, 3, 3, 3, 3, 0, 4, 2, 3, 4, 4, 4, 3, 4, 0, 0, 3, 4, 5, 4, 3, 3, 3, 3, 2, 5, 5, 4, 5, 5, 5, 4, 3, 5, 5, 5, 1, 3, 1, 0, 1, 0, 0, 3, 2, 0, 4, 2, 0, 5, 2, 3, 2, 4, 1, 3, 0, 3, 0, 4, 5, 4, 5, 4, 3, 0, 4, 2, 0, 5, 4),
+ (0, 3, 0, 4, 0, 5, 0, 3, 0, 3, 4, 4, 3, 2, 3, 2, 3, 3, 3, 3, 3, 2, 4, 3, 3, 2, 2, 0, 3, 3, 3, 3, 3, 1, 3, 3, 3, 0, 4, 4, 3, 4, 4, 1, 1, 4, 4, 2, 0, 3, 1, 0, 1, 1, 0, 4, 1, 0, 2, 3, 1, 3, 3, 1, 3, 4, 0, 3, 0, 1, 0, 3, 1, 3, 0, 0, 1, 0, 2, 0, 0, 4, 4),
+ (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
+ (0, 3, 0, 3, 0, 2, 0, 3, 0, 1, 5, 4, 3, 3, 3, 1, 4, 2, 1, 2, 3, 4, 4, 2, 4, 4, 5, 0, 3, 1, 4, 3, 4, 0, 4, 3, 3, 3, 2, 3, 2, 5, 3, 4, 3, 2, 2, 3, 0, 0, 3, 0, 2, 1, 0, 1, 2, 0, 0, 0, 0, 2, 1, 1, 3, 1, 0, 2, 0, 4, 0, 3, 4, 4, 4, 5, 2, 0, 2, 0, 0, 1, 3),
+ (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 4, 2, 1, 1, 0, 1, 0, 3, 2, 0, 0, 3, 1, 1, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 0, 0, 2, 0, 0, 0, 1, 4, 0, 4, 2, 1, 0, 0, 0, 0, 0, 1),
+ (0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 3, 1, 0, 0, 0, 2, 0, 2, 1, 0, 0, 1, 2, 1, 0, 1, 1, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 1, 0, 0, 0, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2),
+ (0, 4, 0, 4, 0, 4, 0, 3, 0, 4, 4, 3, 4, 2, 4, 3, 2, 0, 4, 4, 4, 3, 5, 3, 5, 3, 3, 2, 4, 2, 4, 3, 4, 3, 1, 4, 0, 2, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 4, 1, 3, 4, 3, 2, 1, 2, 1, 3, 3, 3, 4, 4, 3, 3, 5, 0, 4, 0, 3, 0, 4, 3, 3, 3, 2, 1, 0, 3, 0, 0, 3, 3),
+ (0, 4, 0, 3, 0, 3, 0, 3, 0, 3, 5, 5, 3, 3, 3, 3, 4, 3, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 4, 3, 5, 3, 3, 1, 3, 2, 4, 5, 5, 5, 5, 4, 3, 4, 5, 5, 3, 2, 2, 3, 3, 3, 3, 2, 3, 3, 1, 2, 3, 2, 4, 3, 3, 3, 4, 0, 4, 0, 2, 0, 4, 3, 2, 2, 1, 2, 0, 3, 0, 0, 4, 1),
+)
+# fmt: on
+
+
+class JapaneseContextAnalysis:
+ NUM_OF_CATEGORY = 6
+ DONT_KNOW = -1
+ ENOUGH_REL_THRESHOLD = 100
+ MAX_REL_THRESHOLD = 1000
+ MINIMUM_DATA_THRESHOLD = 4
+
+ def __init__(self):
+ self._total_rel = None
+ self._rel_sample = None
+ self._need_to_skip_char_num = None
+ self._last_char_order = None
+ self._done = None
+ self.reset()
+
+ def reset(self):
+ self._total_rel = 0 # total sequence received
+        # category counters; each integer counts the sequences in its category
+ self._rel_sample = [0] * self.NUM_OF_CATEGORY
+        # if the last byte in the current buffer is not the last byte of a
+        # character, we need to know how many bytes to skip in the next buffer
+ self._need_to_skip_char_num = 0
+ self._last_char_order = -1 # The order of previous char
+        # If this flag is set to True, detection is done and a conclusion has
+        # been reached
+ self._done = False
+
+ def feed(self, byte_str, num_bytes):
+ if self._done:
+ return
+
+        # The buffer we receive is byte oriented, and a character may span more
+        # than one buffer. If the last one or two bytes of the previous buffer
+        # did not form a complete character, we record how many bytes are needed
+        # to complete that character and skip them here. We could record those
+        # bytes as well and analyse the character once it is complete, but since
+        # a single character makes little difference, simply skipping it keeps
+        # our logic simple and improves performance.
+ i = self._need_to_skip_char_num
+ while i < num_bytes:
+ order, char_len = self.get_order(byte_str[i : i + 2])
+ i += char_len
+ if i > num_bytes:
+ self._need_to_skip_char_num = i - num_bytes
+ self._last_char_order = -1
+ else:
+ if (order != -1) and (self._last_char_order != -1):
+ self._total_rel += 1
+ if self._total_rel > self.MAX_REL_THRESHOLD:
+ self._done = True
+ break
+ self._rel_sample[
+ jp2_char_context[self._last_char_order][order]
+ ] += 1
+ self._last_char_order = order
+
+ def got_enough_data(self):
+ return self._total_rel > self.ENOUGH_REL_THRESHOLD
+
+ def get_confidence(self):
+ # This is just one way to calculate confidence. It works well for me.
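+        # For example, with 200 sequences counted and 50 of them in category 0,
+        # the confidence would be (200 - 50) / 200 = 0.75 (illustrative numbers
+        # only, not taken from real data).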
+ if self._total_rel > self.MINIMUM_DATA_THRESHOLD:
+ return (self._total_rel - self._rel_sample[0]) / self._total_rel
+ return self.DONT_KNOW
+
+ def get_order(self, _):
+ return -1, 1
+
+
+class SJISContextAnalysis(JapaneseContextAnalysis):
+ def __init__(self):
+ super().__init__()
+ self._charset_name = "SHIFT_JIS"
+
+ @property
+ def charset_name(self):
+ return self._charset_name
+
+ def get_order(self, byte_str):
+ if not byte_str:
+ return -1, 1
+ # find out current char's byte length
+ first_char = byte_str[0]
+ if (0x81 <= first_char <= 0x9F) or (0xE0 <= first_char <= 0xFC):
+ char_len = 2
+ if (first_char == 0x87) or (0xFA <= first_char <= 0xFC):
+ self._charset_name = "CP932"
+ else:
+ char_len = 1
+
+ # return its order if it is hiragana
+ if len(byte_str) > 1:
+ second_char = byte_str[1]
+ if (first_char == 202) and (0x9F <= second_char <= 0xF1):
+ return second_char - 0x9F, char_len
+
+ return -1, char_len
+
+
+class EUCJPContextAnalysis(JapaneseContextAnalysis):
+ def get_order(self, byte_str):
+ if not byte_str:
+ return -1, 1
+ # find out current char's byte length
+ first_char = byte_str[0]
+ if (first_char == 0x8E) or (0xA1 <= first_char <= 0xFE):
+ char_len = 2
+ elif first_char == 0x8F:
+ char_len = 3
+ else:
+ char_len = 1
+
+ # return its order if it is hiragana
+ if len(byte_str) > 1:
+ second_char = byte_str[1]
+ if (first_char == 0xA4) and (0xA1 <= second_char <= 0xF3):
+ return second_char - 0xA1, char_len
+
+ return -1, char_len
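For orientation, the context analysers added in jpcntx.py above are driven by feeding them raw bytes and then reading back a confidence score once enough two-byte sequences have been seen. A minimal sketch of that flow (not part of the patch itself; the sample text is made up, and any Shift_JIS byte stream would work the same way):

    from chardet.jpcntx import SJISContextAnalysis

    # made-up hiragana sample, encoded to raw Shift_JIS bytes
    sample = "ひらがなのれんぞく".encode("shift_jis")

    analyser = SJISContextAnalysis()
    analyser.feed(sample, len(sample))

    # got_enough_data() is True once more than ENOUGH_REL_THRESHOLD (100)
    # two-byte sequences have been counted; get_confidence() returns
    # DONT_KNOW (-1) until more than MINIMUM_DATA_THRESHOLD (4) are seen.
    print(analyser.got_enough_data(), analyser.get_confidence())

    # Reports "SHIFT_JIS", or "CP932" if CP932-specific lead bytes were seen.
    print(analyser.charset_name)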
diff --git a/lib/chardet/langbulgarianmodel.py b/lib/chardet/langbulgarianmodel.py
new file mode 100644
index 0000000..2f771bb
--- /dev/null
+++ b/lib/chardet/langbulgarianmodel.py
@@ -0,0 +1,4649 @@
+from chardet.sbcharsetprober import SingleByteCharSetModel
+
+# 3: Positive
+# 2: Likely
+# 1: Unlikely
+# 0: Negative
+
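+# As a rough guide to how these ratings are consumed: the single-byte prober
+# maps each incoming byte to a character order and looks up
+# BULGARIAN_LANG_MODEL[previous_order][current_order], tallying how often each
+# category occurs; text dominated by 3 ("Positive") pairs scores as likely
+# Bulgarian. For instance, BULGARIAN_LANG_MODEL[1][5] below rates the pair
+# 'а' followed by 'т' as 3.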
+BULGARIAN_LANG_MODEL = {
+ 63: { # 'e'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 1, # 'б'
+ 9: 1, # 'в'
+ 20: 1, # 'г'
+ 11: 1, # 'д'
+ 3: 1, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 0, # 'и'
+ 26: 1, # 'й'
+ 12: 1, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 1, # 'о'
+ 13: 1, # 'п'
+ 7: 1, # 'р'
+ 8: 1, # 'с'
+ 5: 1, # 'т'
+ 19: 0, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 45: { # '\xad'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 0, # 'Л'
+ 38: 1, # 'М'
+ 36: 0, # 'Н'
+ 41: 1, # 'О'
+ 30: 1, # 'П'
+ 39: 1, # 'Р'
+ 28: 1, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 0, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 0, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 0, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 31: { # 'А'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 1, # 'А'
+ 32: 1, # 'Б'
+ 35: 2, # 'В'
+ 43: 1, # 'Г'
+ 37: 2, # 'Д'
+ 44: 2, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 2, # 'З'
+ 40: 1, # 'И'
+ 59: 1, # 'Й'
+ 33: 1, # 'К'
+ 46: 2, # 'Л'
+ 38: 1, # 'М'
+ 36: 2, # 'Н'
+ 41: 1, # 'О'
+ 30: 2, # 'П'
+ 39: 2, # 'Р'
+ 28: 2, # 'С'
+ 34: 2, # 'Т'
+ 51: 1, # 'У'
+ 48: 2, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 2, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 1, # 'а'
+ 18: 2, # 'б'
+ 9: 2, # 'в'
+ 20: 2, # 'г'
+ 11: 2, # 'д'
+ 3: 1, # 'е'
+ 23: 1, # 'ж'
+ 15: 2, # 'з'
+ 2: 0, # 'и'
+ 26: 2, # 'й'
+ 12: 2, # 'к'
+ 10: 3, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 0, # 'о'
+ 13: 2, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 2, # 'т'
+ 19: 1, # 'у'
+ 29: 2, # 'ф'
+ 25: 1, # 'х'
+ 22: 1, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 32: { # 'Б'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 2, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 2, # 'Д'
+ 44: 1, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 2, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 2, # 'Н'
+ 41: 2, # 'О'
+ 30: 1, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 2, # 'Т'
+ 51: 1, # 'У'
+ 48: 2, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 1, # 'Щ'
+ 61: 2, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 2, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 2, # 'р'
+ 8: 1, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 2, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 35: { # 'В'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 2, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 1, # 'П'
+ 39: 2, # 'Р'
+ 28: 2, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 2, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 2, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 2, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 2, # 'н'
+ 4: 2, # 'о'
+ 13: 1, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 2, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 2, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 43: { # 'Г'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 2, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 0, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 0, # 'П'
+ 39: 1, # 'Р'
+ 28: 1, # 'С'
+ 34: 0, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 1, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 1, # 'б'
+ 9: 1, # 'в'
+ 20: 0, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 2, # 'о'
+ 13: 0, # 'п'
+ 7: 2, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 1, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 37: { # 'Д'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 2, # 'В'
+ 43: 1, # 'Г'
+ 37: 2, # 'Д'
+ 44: 2, # 'Е'
+ 55: 2, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 2, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 2, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 3, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 2, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 2, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 44: { # 'Е'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 1, # 'Б'
+ 35: 2, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 1, # 'З'
+ 40: 1, # 'И'
+ 59: 1, # 'Й'
+ 33: 2, # 'К'
+ 46: 2, # 'Л'
+ 38: 1, # 'М'
+ 36: 2, # 'Н'
+ 41: 2, # 'О'
+ 30: 1, # 'П'
+ 39: 2, # 'Р'
+ 28: 2, # 'С'
+ 34: 2, # 'Т'
+ 51: 1, # 'У'
+ 48: 2, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 2, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 1, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 0, # 'а'
+ 18: 1, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 2, # 'д'
+ 3: 0, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 0, # 'и'
+ 26: 1, # 'й'
+ 12: 2, # 'к'
+ 10: 2, # 'л'
+ 14: 2, # 'м'
+ 6: 2, # 'н'
+ 4: 0, # 'о'
+ 13: 1, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 1, # 'т'
+ 19: 1, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 55: { # 'Ж'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 0, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 1, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 1, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 2, # 'о'
+ 13: 1, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 47: { # 'З'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 2, # 'Н'
+ 41: 1, # 'О'
+ 30: 1, # 'П'
+ 39: 1, # 'Р'
+ 28: 1, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 0, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 2, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 1, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 1, # 'о'
+ 13: 0, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 40: { # 'И'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 1, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 2, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 2, # 'З'
+ 40: 1, # 'И'
+ 59: 1, # 'Й'
+ 33: 2, # 'К'
+ 46: 2, # 'Л'
+ 38: 2, # 'М'
+ 36: 2, # 'Н'
+ 41: 1, # 'О'
+ 30: 1, # 'П'
+ 39: 2, # 'Р'
+ 28: 2, # 'С'
+ 34: 2, # 'Т'
+ 51: 0, # 'У'
+ 48: 1, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 1, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 2, # 'Я'
+ 1: 1, # 'а'
+ 18: 1, # 'б'
+ 9: 3, # 'в'
+ 20: 2, # 'г'
+ 11: 1, # 'д'
+ 3: 1, # 'е'
+ 23: 0, # 'ж'
+ 15: 3, # 'з'
+ 2: 0, # 'и'
+ 26: 1, # 'й'
+ 12: 1, # 'к'
+ 10: 2, # 'л'
+ 14: 2, # 'м'
+ 6: 2, # 'н'
+ 4: 0, # 'о'
+ 13: 1, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 2, # 'т'
+ 19: 0, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 1, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 59: { # 'Й'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 1, # 'С'
+ 34: 1, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 0, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 1, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 0, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 2, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 33: { # 'К'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 0, # 'М'
+ 36: 2, # 'Н'
+ 41: 2, # 'О'
+ 30: 2, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 1, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 2, # 'е'
+ 23: 1, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 2, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 3, # 'р'
+ 8: 1, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 46: { # 'Л'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 2, # 'Г'
+ 37: 1, # 'Д'
+ 44: 2, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 0, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 1, # 'П'
+ 39: 0, # 'Р'
+ 28: 1, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 0, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 1, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 2, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 38: { # 'М'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 2, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 1, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 2, # 'л'
+ 14: 0, # 'м'
+ 6: 2, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 36: { # 'Н'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 2, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 2, # 'Д'
+ 44: 2, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 1, # 'Й'
+ 33: 2, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 1, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 2, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 1, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 0, # 'с'
+ 5: 1, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 2, # 'ю'
+ 16: 2, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 41: { # 'О'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 1, # 'Б'
+ 35: 2, # 'В'
+ 43: 1, # 'Г'
+ 37: 2, # 'Д'
+ 44: 1, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 1, # 'З'
+ 40: 1, # 'И'
+ 59: 1, # 'Й'
+ 33: 2, # 'К'
+ 46: 2, # 'Л'
+ 38: 2, # 'М'
+ 36: 2, # 'Н'
+ 41: 2, # 'О'
+ 30: 1, # 'П'
+ 39: 2, # 'Р'
+ 28: 2, # 'С'
+ 34: 2, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 1, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 1, # 'а'
+ 18: 2, # 'б'
+ 9: 2, # 'в'
+ 20: 2, # 'г'
+ 11: 1, # 'д'
+ 3: 1, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 0, # 'и'
+ 26: 1, # 'й'
+ 12: 2, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 0, # 'о'
+ 13: 2, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 3, # 'т'
+ 19: 1, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 1, # 'ц'
+ 21: 2, # 'ч'
+ 27: 0, # 'ш'
+ 24: 2, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 30: { # 'П'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 2, # 'П'
+ 39: 2, # 'Р'
+ 28: 2, # 'С'
+ 34: 1, # 'Т'
+ 51: 2, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 2, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 3, # 'л'
+ 14: 0, # 'м'
+ 6: 1, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 3, # 'р'
+ 8: 1, # 'с'
+ 5: 1, # 'т'
+ 19: 2, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 39: { # 'Р'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 2, # 'Г'
+ 37: 2, # 'Д'
+ 44: 2, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 0, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 2, # 'П'
+ 39: 1, # 'Р'
+ 28: 1, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 1, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 1, # 'с'
+ 5: 0, # 'т'
+ 19: 3, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 28: { # 'С'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 3, # 'А'
+ 32: 2, # 'Б'
+ 35: 2, # 'В'
+ 43: 1, # 'Г'
+ 37: 2, # 'Д'
+ 44: 2, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 1, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 2, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 2, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 2, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 1, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 2, # 'к'
+ 10: 3, # 'л'
+ 14: 2, # 'м'
+ 6: 1, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 2, # 'р'
+ 8: 0, # 'с'
+ 5: 3, # 'т'
+ 19: 2, # 'у'
+ 29: 2, # 'ф'
+ 25: 1, # 'х'
+ 22: 1, # 'ц'
+ 21: 1, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 34: { # 'Т'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 2, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 2, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 2, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 2, # 'О'
+ 30: 1, # 'П'
+ 39: 2, # 'Р'
+ 28: 2, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 1, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 1, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 1, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 1, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 3, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 2, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 51: { # 'У'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 1, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 2, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 1, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 0, # 'О'
+ 30: 1, # 'П'
+ 39: 1, # 'Р'
+ 28: 1, # 'С'
+ 34: 2, # 'Т'
+ 51: 0, # 'У'
+ 48: 1, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 1, # 'а'
+ 18: 1, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 1, # 'д'
+ 3: 2, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 2, # 'и'
+ 26: 1, # 'й'
+ 12: 2, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 2, # 'н'
+ 4: 2, # 'о'
+ 13: 1, # 'п'
+ 7: 1, # 'р'
+ 8: 2, # 'с'
+ 5: 1, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 2, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 48: { # 'Ф'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 0, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 2, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 1, # 'Т'
+ 51: 1, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 2, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 2, # 'о'
+ 13: 0, # 'п'
+ 7: 2, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 49: { # 'Х'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 0, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 1, # 'П'
+ 39: 1, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 1, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 0, # 'н'
+ 4: 2, # 'о'
+ 13: 0, # 'п'
+ 7: 2, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 53: { # 'Ц'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 0, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 2, # 'И'
+ 59: 0, # 'Й'
+ 33: 2, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 1, # 'Р'
+ 28: 2, # 'С'
+ 34: 0, # 'Т'
+ 51: 1, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 2, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 1, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 1, # 'о'
+ 13: 0, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 50: { # 'Ч'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 2, # 'А'
+ 32: 1, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 0, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 1, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 1, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 2, # 'о'
+ 13: 0, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 54: { # 'Ш'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 1, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 1, # 'Н'
+ 41: 1, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 1, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 2, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 2, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 2, # 'о'
+ 13: 1, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 57: { # 'Щ'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 1, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 1, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 1, # 'о'
+ 13: 0, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 1, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 61: { # 'Ъ'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 1, # 'Д'
+ 44: 0, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 1, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 2, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 0, # 'О'
+ 30: 1, # 'П'
+ 39: 2, # 'Р'
+ 28: 1, # 'С'
+ 34: 1, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 1, # 'Х'
+ 53: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 54: 1, # 'Ш'
+ 57: 1, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 0, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 0, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 1, # 'л'
+ 14: 0, # 'м'
+ 6: 1, # 'н'
+ 4: 0, # 'о'
+ 13: 0, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 60: { # 'Ю'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 1, # 'Б'
+ 35: 0, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 0, # 'Е'
+ 55: 1, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 0, # 'М'
+ 36: 1, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 1, # 'Р'
+ 28: 1, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 1, # 'б'
+ 9: 1, # 'в'
+ 20: 2, # 'г'
+ 11: 1, # 'д'
+ 3: 0, # 'е'
+ 23: 2, # 'ж'
+ 15: 1, # 'з'
+ 2: 1, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 0, # 'о'
+ 13: 1, # 'п'
+ 7: 1, # 'р'
+ 8: 1, # 'с'
+ 5: 1, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 56: { # 'Я'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 1, # 'Б'
+ 35: 1, # 'В'
+ 43: 1, # 'Г'
+ 37: 1, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 1, # 'Л'
+ 38: 1, # 'М'
+ 36: 1, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 1, # 'С'
+ 34: 2, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 1, # 'б'
+ 9: 1, # 'в'
+ 20: 1, # 'г'
+ 11: 1, # 'д'
+ 3: 0, # 'е'
+ 23: 0, # 'ж'
+ 15: 1, # 'з'
+ 2: 1, # 'и'
+ 26: 1, # 'й'
+ 12: 1, # 'к'
+ 10: 1, # 'л'
+ 14: 2, # 'м'
+ 6: 2, # 'н'
+ 4: 0, # 'о'
+ 13: 2, # 'п'
+ 7: 1, # 'р'
+ 8: 1, # 'с'
+ 5: 1, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 1: { # 'а'
+ 63: 1, # 'e'
+ 45: 1, # '\xad'
+ 31: 1, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 1, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 1, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 3, # 'ж'
+ 15: 3, # 'з'
+ 2: 3, # 'и'
+ 26: 3, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 2, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 3, # 'ф'
+ 25: 3, # 'х'
+ 22: 3, # 'ц'
+ 21: 3, # 'ч'
+ 27: 3, # 'ш'
+ 24: 3, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 18: { # 'б'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 3, # 'в'
+ 20: 1, # 'г'
+ 11: 2, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 3, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 1, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 0, # 'т'
+ 19: 3, # 'у'
+ 29: 0, # 'ф'
+ 25: 2, # 'х'
+ 22: 1, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 3, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 9: { # 'в'
+ 63: 1, # 'e'
+ 45: 1, # '\xad'
+ 31: 0, # 'А'
+ 32: 1, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 0, # 'в'
+ 20: 2, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 3, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 2, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 3, # 'ч'
+ 27: 2, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 20: { # 'г'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 2, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 3, # 'л'
+ 14: 1, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 1, # 'п'
+ 7: 3, # 'р'
+ 8: 2, # 'с'
+ 5: 2, # 'т'
+ 19: 3, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 11: { # 'д'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 2, # 'б'
+ 9: 3, # 'в'
+ 20: 2, # 'г'
+ 11: 2, # 'д'
+ 3: 3, # 'е'
+ 23: 3, # 'ж'
+ 15: 2, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 1, # 'т'
+ 19: 3, # 'у'
+ 29: 1, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 3: { # 'е'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 2, # 'е'
+ 23: 3, # 'ж'
+ 15: 3, # 'з'
+ 2: 2, # 'и'
+ 26: 3, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 2, # 'у'
+ 29: 3, # 'ф'
+ 25: 3, # 'х'
+ 22: 3, # 'ц'
+ 21: 3, # 'ч'
+ 27: 3, # 'ш'
+ 24: 3, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 23: { # 'ж'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 2, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 3, # 'н'
+ 4: 2, # 'о'
+ 13: 1, # 'п'
+ 7: 1, # 'р'
+ 8: 1, # 'с'
+ 5: 1, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 1, # 'ц'
+ 21: 1, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 15: { # 'з'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 1, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 2, # 'ш'
+ 24: 1, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 2, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 2: { # 'и'
+ 63: 1, # 'e'
+ 45: 1, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 1, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 1, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 1, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 1, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 3, # 'ж'
+ 15: 3, # 'з'
+ 2: 3, # 'и'
+ 26: 3, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 2, # 'у'
+ 29: 3, # 'ф'
+ 25: 3, # 'х'
+ 22: 3, # 'ц'
+ 21: 3, # 'ч'
+ 27: 3, # 'ш'
+ 24: 3, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 26: { # 'й'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 1, # 'а'
+ 18: 2, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 2, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 2, # 'з'
+ 2: 1, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 2, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 2, # 'о'
+ 13: 1, # 'п'
+ 7: 2, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 1, # 'у'
+ 29: 2, # 'ф'
+ 25: 1, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 12: { # 'к'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 1, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 1, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 3, # 'в'
+ 20: 2, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 2, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 3, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 1, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 3, # 'ц'
+ 21: 2, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 10: { # 'л'
+ 63: 1, # 'e'
+ 45: 1, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 1, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 2, # 'д'
+ 3: 3, # 'е'
+ 23: 3, # 'ж'
+ 15: 2, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 1, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 2, # 'п'
+ 7: 2, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 2, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 2, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 2, # 'ь'
+ 42: 3, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 14: { # 'м'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 1, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 1, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 2, # 'к'
+ 10: 3, # 'л'
+ 14: 1, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 1, # 'т'
+ 19: 3, # 'у'
+ 29: 2, # 'ф'
+ 25: 1, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 2, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 6: { # 'н'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 1, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 2, # 'б'
+ 9: 2, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 2, # 'ж'
+ 15: 2, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 1, # 'п'
+ 7: 2, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 3, # 'ф'
+ 25: 2, # 'х'
+ 22: 3, # 'ц'
+ 21: 3, # 'ч'
+ 27: 2, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 2, # 'ь'
+ 42: 2, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 4: { # 'о'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 2, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 3, # 'ж'
+ 15: 3, # 'з'
+ 2: 3, # 'и'
+ 26: 3, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 2, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 2, # 'у'
+ 29: 3, # 'ф'
+ 25: 3, # 'х'
+ 22: 3, # 'ц'
+ 21: 3, # 'ч'
+ 27: 3, # 'ш'
+ 24: 3, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 13: { # 'п'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 1, # 'й'
+ 12: 2, # 'к'
+ 10: 3, # 'л'
+ 14: 1, # 'м'
+ 6: 2, # 'н'
+ 4: 3, # 'о'
+ 13: 1, # 'п'
+ 7: 3, # 'р'
+ 8: 2, # 'с'
+ 5: 2, # 'т'
+ 19: 3, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 2, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 7: { # 'р'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 3, # 'е'
+ 23: 3, # 'ж'
+ 15: 2, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 2, # 'п'
+ 7: 1, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 2, # 'ф'
+ 25: 3, # 'х'
+ 22: 3, # 'ц'
+ 21: 2, # 'ч'
+ 27: 3, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 2, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 8: { # 'с'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 2, # 'б'
+ 9: 3, # 'в'
+ 20: 2, # 'г'
+ 11: 2, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 1, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 2, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 2, # 'ш'
+ 24: 0, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 2, # 'ь'
+ 42: 2, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 5: { # 'т'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 2, # 'г'
+ 11: 2, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 2, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 3, # 'у'
+ 29: 1, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 2, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 3, # 'ъ'
+ 52: 2, # 'ь'
+ 42: 2, # 'ю'
+ 16: 3, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 19: { # 'у'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 2, # 'е'
+ 23: 3, # 'ж'
+ 15: 3, # 'з'
+ 2: 2, # 'и'
+ 26: 2, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 2, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 1, # 'у'
+ 29: 2, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 3, # 'ч'
+ 27: 3, # 'ш'
+ 24: 2, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 29: { # 'ф'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 1, # 'в'
+ 20: 1, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 2, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 2, # 'т'
+ 19: 2, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 2, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 25: { # 'х'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 3, # 'в'
+ 20: 0, # 'г'
+ 11: 1, # 'д'
+ 3: 2, # 'е'
+ 23: 0, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 2, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 1, # 'п'
+ 7: 3, # 'р'
+ 8: 1, # 'с'
+ 5: 2, # 'т'
+ 19: 3, # 'у'
+ 29: 0, # 'ф'
+ 25: 1, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 22: { # 'ц'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 2, # 'в'
+ 20: 1, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 1, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 2, # 'к'
+ 10: 1, # 'л'
+ 14: 1, # 'м'
+ 6: 1, # 'н'
+ 4: 2, # 'о'
+ 13: 1, # 'п'
+ 7: 1, # 'р'
+ 8: 1, # 'с'
+ 5: 1, # 'т'
+ 19: 2, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 1, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 0, # 'ю'
+ 16: 2, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 21: { # 'ч'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 1, # 'б'
+ 9: 3, # 'в'
+ 20: 1, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 1, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 2, # 'л'
+ 14: 2, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 2, # 'р'
+ 8: 0, # 'с'
+ 5: 2, # 'т'
+ 19: 3, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 1, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 27: { # 'ш'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 2, # 'в'
+ 20: 0, # 'г'
+ 11: 1, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 3, # 'к'
+ 10: 2, # 'л'
+ 14: 1, # 'м'
+ 6: 3, # 'н'
+ 4: 2, # 'о'
+ 13: 2, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 1, # 'т'
+ 19: 2, # 'у'
+ 29: 1, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 1, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 2, # 'ъ'
+ 52: 1, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 24: { # 'щ'
+ 63: 1, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 3, # 'а'
+ 18: 0, # 'б'
+ 9: 1, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 3, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 3, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 2, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 1, # 'р'
+ 8: 0, # 'с'
+ 5: 2, # 'т'
+ 19: 3, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 1, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 2, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 17: { # 'ъ'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 1, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 3, # 'г'
+ 11: 3, # 'д'
+ 3: 2, # 'е'
+ 23: 3, # 'ж'
+ 15: 3, # 'з'
+ 2: 1, # 'и'
+ 26: 2, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 3, # 'о'
+ 13: 3, # 'п'
+ 7: 3, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 1, # 'у'
+ 29: 1, # 'ф'
+ 25: 2, # 'х'
+ 22: 2, # 'ц'
+ 21: 3, # 'ч'
+ 27: 2, # 'ш'
+ 24: 3, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 2, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 52: { # 'ь'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 1, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 0, # 'и'
+ 26: 0, # 'й'
+ 12: 1, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 1, # 'н'
+ 4: 3, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 0, # 'с'
+ 5: 1, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 1, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 1, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 42: { # 'ю'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 1, # 'а'
+ 18: 2, # 'б'
+ 9: 1, # 'в'
+ 20: 2, # 'г'
+ 11: 2, # 'д'
+ 3: 1, # 'е'
+ 23: 2, # 'ж'
+ 15: 2, # 'з'
+ 2: 1, # 'и'
+ 26: 1, # 'й'
+ 12: 2, # 'к'
+ 10: 2, # 'л'
+ 14: 2, # 'м'
+ 6: 2, # 'н'
+ 4: 1, # 'о'
+ 13: 1, # 'п'
+ 7: 2, # 'р'
+ 8: 2, # 'с'
+ 5: 2, # 'т'
+ 19: 1, # 'у'
+ 29: 1, # 'ф'
+ 25: 1, # 'х'
+ 22: 2, # 'ц'
+ 21: 3, # 'ч'
+ 27: 1, # 'ш'
+ 24: 1, # 'щ'
+ 17: 1, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 16: { # 'я'
+ 63: 0, # 'e'
+ 45: 1, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 3, # 'б'
+ 9: 3, # 'в'
+ 20: 2, # 'г'
+ 11: 3, # 'д'
+ 3: 2, # 'е'
+ 23: 1, # 'ж'
+ 15: 2, # 'з'
+ 2: 1, # 'и'
+ 26: 2, # 'й'
+ 12: 3, # 'к'
+ 10: 3, # 'л'
+ 14: 3, # 'м'
+ 6: 3, # 'н'
+ 4: 1, # 'о'
+ 13: 2, # 'п'
+ 7: 2, # 'р'
+ 8: 3, # 'с'
+ 5: 3, # 'т'
+ 19: 1, # 'у'
+ 29: 1, # 'ф'
+ 25: 3, # 'х'
+ 22: 2, # 'ц'
+ 21: 1, # 'ч'
+ 27: 1, # 'ш'
+ 24: 2, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 1, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 58: { # 'є'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 0, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 0, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 0, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+ 62: { # '№'
+ 63: 0, # 'e'
+ 45: 0, # '\xad'
+ 31: 0, # 'А'
+ 32: 0, # 'Б'
+ 35: 0, # 'В'
+ 43: 0, # 'Г'
+ 37: 0, # 'Д'
+ 44: 0, # 'Е'
+ 55: 0, # 'Ж'
+ 47: 0, # 'З'
+ 40: 0, # 'И'
+ 59: 0, # 'Й'
+ 33: 0, # 'К'
+ 46: 0, # 'Л'
+ 38: 0, # 'М'
+ 36: 0, # 'Н'
+ 41: 0, # 'О'
+ 30: 0, # 'П'
+ 39: 0, # 'Р'
+ 28: 0, # 'С'
+ 34: 0, # 'Т'
+ 51: 0, # 'У'
+ 48: 0, # 'Ф'
+ 49: 0, # 'Х'
+ 53: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 54: 0, # 'Ш'
+ 57: 0, # 'Щ'
+ 61: 0, # 'Ъ'
+ 60: 0, # 'Ю'
+ 56: 0, # 'Я'
+ 1: 0, # 'а'
+ 18: 0, # 'б'
+ 9: 0, # 'в'
+ 20: 0, # 'г'
+ 11: 0, # 'д'
+ 3: 0, # 'е'
+ 23: 0, # 'ж'
+ 15: 0, # 'з'
+ 2: 0, # 'и'
+ 26: 0, # 'й'
+ 12: 0, # 'к'
+ 10: 0, # 'л'
+ 14: 0, # 'м'
+ 6: 0, # 'н'
+ 4: 0, # 'о'
+ 13: 0, # 'п'
+ 7: 0, # 'р'
+ 8: 0, # 'с'
+ 5: 0, # 'т'
+ 19: 0, # 'у'
+ 29: 0, # 'ф'
+ 25: 0, # 'х'
+ 22: 0, # 'ц'
+ 21: 0, # 'ч'
+ 27: 0, # 'ш'
+ 24: 0, # 'щ'
+ 17: 0, # 'ъ'
+ 52: 0, # 'ь'
+ 42: 0, # 'ю'
+ 16: 0, # 'я'
+ 58: 0, # 'є'
+ 62: 0, # '№'
+ },
+}
+
+# 255: Undefined characters that did not exist in training text
+# 254: Carriage/Return
+# 253: symbol (punctuation) that does not belong to word
+# 252: 0 - 9
+# 251: Control characters
+
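+# A rough, hedged sketch of how a single-byte prober consults these tables
+# (variable names here are illustrative): in the ISO-8859-5 table below,
+# byte 208 ('а') maps to order 1, one of the roughly 64 frequent letters the
+# model tracks, while byte 13 ('\r') maps to 254 and is skipped.
+#
+#     order = ISO_8859_5_BULGARIAN_CHAR_TO_ORDER[byte]
+#     if order < 64:  # only modelled letters feed the bigram statistics
+#         category = BULGARIAN_LANG_MODEL[previous_order][order]
+#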
+# Character Mapping Table(s):
+ISO_8859_5_BULGARIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 77, # 'A'
+ 66: 90, # 'B'
+ 67: 99, # 'C'
+ 68: 100, # 'D'
+ 69: 72, # 'E'
+ 70: 109, # 'F'
+ 71: 107, # 'G'
+ 72: 101, # 'H'
+ 73: 79, # 'I'
+ 74: 185, # 'J'
+ 75: 81, # 'K'
+ 76: 102, # 'L'
+ 77: 76, # 'M'
+ 78: 94, # 'N'
+ 79: 82, # 'O'
+ 80: 110, # 'P'
+ 81: 186, # 'Q'
+ 82: 108, # 'R'
+ 83: 91, # 'S'
+ 84: 74, # 'T'
+ 85: 119, # 'U'
+ 86: 84, # 'V'
+ 87: 96, # 'W'
+ 88: 111, # 'X'
+ 89: 187, # 'Y'
+ 90: 115, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 65, # 'a'
+ 98: 69, # 'b'
+ 99: 70, # 'c'
+ 100: 66, # 'd'
+ 101: 63, # 'e'
+ 102: 68, # 'f'
+ 103: 112, # 'g'
+ 104: 103, # 'h'
+ 105: 92, # 'i'
+ 106: 194, # 'j'
+ 107: 104, # 'k'
+ 108: 95, # 'l'
+ 109: 86, # 'm'
+ 110: 87, # 'n'
+ 111: 71, # 'o'
+ 112: 116, # 'p'
+ 113: 195, # 'q'
+ 114: 85, # 'r'
+ 115: 93, # 's'
+ 116: 97, # 't'
+ 117: 113, # 'u'
+ 118: 196, # 'v'
+ 119: 197, # 'w'
+ 120: 198, # 'x'
+ 121: 199, # 'y'
+ 122: 200, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 194, # '\x80'
+ 129: 195, # '\x81'
+ 130: 196, # '\x82'
+ 131: 197, # '\x83'
+ 132: 198, # '\x84'
+ 133: 199, # '\x85'
+ 134: 200, # '\x86'
+ 135: 201, # '\x87'
+ 136: 202, # '\x88'
+ 137: 203, # '\x89'
+ 138: 204, # '\x8a'
+ 139: 205, # '\x8b'
+ 140: 206, # '\x8c'
+ 141: 207, # '\x8d'
+ 142: 208, # '\x8e'
+ 143: 209, # '\x8f'
+ 144: 210, # '\x90'
+ 145: 211, # '\x91'
+ 146: 212, # '\x92'
+ 147: 213, # '\x93'
+ 148: 214, # '\x94'
+ 149: 215, # '\x95'
+ 150: 216, # '\x96'
+ 151: 217, # '\x97'
+ 152: 218, # '\x98'
+ 153: 219, # '\x99'
+ 154: 220, # '\x9a'
+ 155: 221, # '\x9b'
+ 156: 222, # '\x9c'
+ 157: 223, # '\x9d'
+ 158: 224, # '\x9e'
+ 159: 225, # '\x9f'
+ 160: 81, # '\xa0'
+ 161: 226, # 'Ё'
+ 162: 227, # 'Ђ'
+ 163: 228, # 'Ѓ'
+ 164: 229, # 'Є'
+ 165: 230, # 'Ѕ'
+ 166: 105, # 'І'
+ 167: 231, # 'Ї'
+ 168: 232, # 'Ј'
+ 169: 233, # 'Љ'
+ 170: 234, # 'Њ'
+ 171: 235, # 'Ћ'
+ 172: 236, # 'Ќ'
+ 173: 45, # '\xad'
+ 174: 237, # 'Ў'
+ 175: 238, # 'Џ'
+ 176: 31, # 'А'
+ 177: 32, # 'Б'
+ 178: 35, # 'В'
+ 179: 43, # 'Г'
+ 180: 37, # 'Д'
+ 181: 44, # 'Е'
+ 182: 55, # 'Ж'
+ 183: 47, # 'З'
+ 184: 40, # 'И'
+ 185: 59, # 'Й'
+ 186: 33, # 'К'
+ 187: 46, # 'Л'
+ 188: 38, # 'М'
+ 189: 36, # 'Н'
+ 190: 41, # 'О'
+ 191: 30, # 'П'
+ 192: 39, # 'Р'
+ 193: 28, # 'С'
+ 194: 34, # 'Т'
+ 195: 51, # 'У'
+ 196: 48, # 'Ф'
+ 197: 49, # 'Х'
+ 198: 53, # 'Ц'
+ 199: 50, # 'Ч'
+ 200: 54, # 'Ш'
+ 201: 57, # 'Щ'
+ 202: 61, # 'Ъ'
+ 203: 239, # 'Ы'
+ 204: 67, # 'Ь'
+ 205: 240, # 'Э'
+ 206: 60, # 'Ю'
+ 207: 56, # 'Я'
+ 208: 1, # 'а'
+ 209: 18, # 'б'
+ 210: 9, # 'в'
+ 211: 20, # 'г'
+ 212: 11, # 'д'
+ 213: 3, # 'е'
+ 214: 23, # 'ж'
+ 215: 15, # 'з'
+ 216: 2, # 'и'
+ 217: 26, # 'й'
+ 218: 12, # 'к'
+ 219: 10, # 'л'
+ 220: 14, # 'м'
+ 221: 6, # 'н'
+ 222: 4, # 'о'
+ 223: 13, # 'п'
+ 224: 7, # 'р'
+ 225: 8, # 'с'
+ 226: 5, # 'т'
+ 227: 19, # 'у'
+ 228: 29, # 'ф'
+ 229: 25, # 'х'
+ 230: 22, # 'ц'
+ 231: 21, # 'ч'
+ 232: 27, # 'ш'
+ 233: 24, # 'щ'
+ 234: 17, # 'ъ'
+ 235: 75, # 'ы'
+ 236: 52, # 'ь'
+ 237: 241, # 'э'
+ 238: 42, # 'ю'
+ 239: 16, # 'я'
+ 240: 62, # '№'
+ 241: 242, # 'ё'
+ 242: 243, # 'ђ'
+ 243: 244, # 'ѓ'
+ 244: 58, # 'є'
+ 245: 245, # 'ѕ'
+ 246: 98, # 'і'
+ 247: 246, # 'ї'
+ 248: 247, # 'ј'
+ 249: 248, # 'љ'
+ 250: 249, # 'њ'
+ 251: 250, # 'ћ'
+ 252: 251, # 'ќ'
+ 253: 91, # '§'
+ 254: 252, # 'ў'
+ 255: 253, # 'џ'
+}
+
+ISO_8859_5_BULGARIAN_MODEL = SingleByteCharSetModel(
+ charset_name="ISO-8859-5",
+ language="Bulgarian",
+ char_to_order_map=ISO_8859_5_BULGARIAN_CHAR_TO_ORDER,
+ language_model=BULGARIAN_LANG_MODEL,
+ typical_positive_ratio=0.969392,
+ keep_ascii_letters=False,
+ alphabet="АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя",
+)
+
+WINDOWS_1251_BULGARIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 77, # 'A'
+ 66: 90, # 'B'
+ 67: 99, # 'C'
+ 68: 100, # 'D'
+ 69: 72, # 'E'
+ 70: 109, # 'F'
+ 71: 107, # 'G'
+ 72: 101, # 'H'
+ 73: 79, # 'I'
+ 74: 185, # 'J'
+ 75: 81, # 'K'
+ 76: 102, # 'L'
+ 77: 76, # 'M'
+ 78: 94, # 'N'
+ 79: 82, # 'O'
+ 80: 110, # 'P'
+ 81: 186, # 'Q'
+ 82: 108, # 'R'
+ 83: 91, # 'S'
+ 84: 74, # 'T'
+ 85: 119, # 'U'
+ 86: 84, # 'V'
+ 87: 96, # 'W'
+ 88: 111, # 'X'
+ 89: 187, # 'Y'
+ 90: 115, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 65, # 'a'
+ 98: 69, # 'b'
+ 99: 70, # 'c'
+ 100: 66, # 'd'
+ 101: 63, # 'e'
+ 102: 68, # 'f'
+ 103: 112, # 'g'
+ 104: 103, # 'h'
+ 105: 92, # 'i'
+ 106: 194, # 'j'
+ 107: 104, # 'k'
+ 108: 95, # 'l'
+ 109: 86, # 'm'
+ 110: 87, # 'n'
+ 111: 71, # 'o'
+ 112: 116, # 'p'
+ 113: 195, # 'q'
+ 114: 85, # 'r'
+ 115: 93, # 's'
+ 116: 97, # 't'
+ 117: 113, # 'u'
+ 118: 196, # 'v'
+ 119: 197, # 'w'
+ 120: 198, # 'x'
+ 121: 199, # 'y'
+ 122: 200, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 206, # 'Ђ'
+ 129: 207, # 'Ѓ'
+ 130: 208, # '‚'
+ 131: 209, # 'ѓ'
+ 132: 210, # '„'
+ 133: 211, # '…'
+ 134: 212, # '†'
+ 135: 213, # '‡'
+ 136: 120, # '€'
+ 137: 214, # '‰'
+ 138: 215, # 'Љ'
+ 139: 216, # '‹'
+ 140: 217, # 'Њ'
+ 141: 218, # 'Ќ'
+ 142: 219, # 'Ћ'
+ 143: 220, # 'Џ'
+ 144: 221, # 'ђ'
+ 145: 78, # '‘'
+ 146: 64, # '’'
+ 147: 83, # '“'
+ 148: 121, # '”'
+ 149: 98, # '•'
+ 150: 117, # '–'
+ 151: 105, # '—'
+ 152: 222, # None
+ 153: 223, # '™'
+ 154: 224, # 'љ'
+ 155: 225, # '›'
+ 156: 226, # 'њ'
+ 157: 227, # 'ќ'
+ 158: 228, # 'ћ'
+ 159: 229, # 'џ'
+ 160: 88, # '\xa0'
+ 161: 230, # 'Ў'
+ 162: 231, # 'ў'
+ 163: 232, # 'Ј'
+ 164: 233, # '¤'
+ 165: 122, # 'Ґ'
+ 166: 89, # '¦'
+ 167: 106, # '§'
+ 168: 234, # 'Ё'
+ 169: 235, # '©'
+ 170: 236, # 'Є'
+ 171: 237, # '«'
+ 172: 238, # '¬'
+ 173: 45, # '\xad'
+ 174: 239, # '®'
+ 175: 240, # 'Ї'
+ 176: 73, # '°'
+ 177: 80, # '±'
+ 178: 118, # 'І'
+ 179: 114, # 'і'
+ 180: 241, # 'ґ'
+ 181: 242, # 'µ'
+ 182: 243, # '¶'
+ 183: 244, # '·'
+ 184: 245, # 'ё'
+ 185: 62, # '№'
+ 186: 58, # 'є'
+ 187: 246, # '»'
+ 188: 247, # 'ј'
+ 189: 248, # 'Ѕ'
+ 190: 249, # 'ѕ'
+ 191: 250, # 'ї'
+ 192: 31, # 'А'
+ 193: 32, # 'Б'
+ 194: 35, # 'В'
+ 195: 43, # 'Г'
+ 196: 37, # 'Д'
+ 197: 44, # 'Е'
+ 198: 55, # 'Ж'
+ 199: 47, # 'З'
+ 200: 40, # 'И'
+ 201: 59, # 'Й'
+ 202: 33, # 'К'
+ 203: 46, # 'Л'
+ 204: 38, # 'М'
+ 205: 36, # 'Н'
+ 206: 41, # 'О'
+ 207: 30, # 'П'
+ 208: 39, # 'Р'
+ 209: 28, # 'С'
+ 210: 34, # 'Т'
+ 211: 51, # 'У'
+ 212: 48, # 'Ф'
+ 213: 49, # 'Х'
+ 214: 53, # 'Ц'
+ 215: 50, # 'Ч'
+ 216: 54, # 'Ш'
+ 217: 57, # 'Щ'
+ 218: 61, # 'Ъ'
+ 219: 251, # 'Ы'
+ 220: 67, # 'Ь'
+ 221: 252, # 'Э'
+ 222: 60, # 'Ю'
+ 223: 56, # 'Я'
+ 224: 1, # 'а'
+ 225: 18, # 'б'
+ 226: 9, # 'в'
+ 227: 20, # 'г'
+ 228: 11, # 'д'
+ 229: 3, # 'е'
+ 230: 23, # 'ж'
+ 231: 15, # 'з'
+ 232: 2, # 'и'
+ 233: 26, # 'й'
+ 234: 12, # 'к'
+ 235: 10, # 'л'
+ 236: 14, # 'м'
+ 237: 6, # 'н'
+ 238: 4, # 'о'
+ 239: 13, # 'п'
+ 240: 7, # 'р'
+ 241: 8, # 'с'
+ 242: 5, # 'т'
+ 243: 19, # 'у'
+ 244: 29, # 'ф'
+ 245: 25, # 'х'
+ 246: 22, # 'ц'
+ 247: 21, # 'ч'
+ 248: 27, # 'ш'
+ 249: 24, # 'щ'
+ 250: 17, # 'ъ'
+ 251: 75, # 'ы'
+ 252: 52, # 'ь'
+ 253: 253, # 'э'
+ 254: 42, # 'ю'
+ 255: 16, # 'я'
+}
+
+WINDOWS_1251_BULGARIAN_MODEL = SingleByteCharSetModel(
+ charset_name="windows-1251",
+ language="Bulgarian",
+ char_to_order_map=WINDOWS_1251_BULGARIAN_CHAR_TO_ORDER,
+ language_model=BULGARIAN_LANG_MODEL,
+ typical_positive_ratio=0.969392,
+ keep_ascii_letters=False,
+ alphabet="АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя",
+)
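+
+# Both Bulgarian models above share BULGARIAN_LANG_MODEL and differ only in
+# their byte-to-order maps.  A minimal usage sketch, assuming this vendored
+# chardet package is importable; the sample string and the exact result are
+# illustrative, not guaranteed:
+#
+#     import chardet
+#     raw = "Здравей, свят".encode("windows-1251")
+#     print(chardet.detect(raw))
+#     # -> a dict with 'encoding', 'confidence' and 'language' keys,
+#     #    e.g. {'encoding': 'windows-1251', ..., 'language': 'Bulgarian'}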
diff --git a/lib/chardet/langgreekmodel.py b/lib/chardet/langgreekmodel.py
new file mode 100644
index 0000000..0471d8b
--- /dev/null
+++ b/lib/chardet/langgreekmodel.py
@@ -0,0 +1,4397 @@
+from chardet.sbcharsetprober import SingleByteCharSetModel
+
+# 3: Positive
+# 2: Likely
+# 1: Unlikely
+# 0: Negative
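+#
+# For example, in the table below GREEK_LANG_MODEL[1][2] == 3: the model
+# rates 'τ' (order 2) following 'α' (order 1) as Positive, while a value of
+# 0 marks a following character the model treats as Negative, i.e. one that
+# essentially never appears after that character in the training text.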
+
+GREEK_LANG_MODEL = {
+ 60: { # 'e'
+ 60: 2, # 'e'
+ 55: 1, # 'o'
+ 58: 2, # 't'
+ 36: 1, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 1, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 55: { # 'o'
+ 60: 0, # 'e'
+ 55: 2, # 'o'
+ 58: 2, # 't'
+ 36: 1, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 1, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 1, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 58: { # 't'
+ 60: 2, # 'e'
+ 55: 1, # 'o'
+ 58: 1, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 1, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 36: { # '·'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 61: { # 'Ά'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 1, # 'γ'
+ 21: 2, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 1, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 46: { # 'Έ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 2, # 'β'
+ 20: 2, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 2, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 2, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 1, # 'σ'
+ 2: 2, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 3, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 54: { # 'Ό'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 2, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 2, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 2, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 2, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 31: { # 'Α'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 2, # 'Β'
+ 43: 2, # 'Γ'
+ 41: 1, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 2, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 2, # 'Κ'
+ 53: 2, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 1, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 2, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 2, # 'Υ'
+ 56: 2, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 2, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 1, # 'θ'
+ 5: 0, # 'ι'
+ 11: 2, # 'κ'
+ 16: 3, # 'λ'
+ 10: 2, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 2, # 'ς'
+ 7: 2, # 'σ'
+ 2: 0, # 'τ'
+ 12: 3, # 'υ'
+ 28: 2, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 51: { # 'Β'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 1, # 'Ε'
+ 40: 1, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 1, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 1, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 2, # 'ή'
+ 15: 0, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 43: { # 'Γ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 1, # 'Α'
+ 51: 0, # 'Β'
+ 43: 2, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 1, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 1, # 'Κ'
+ 53: 1, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 1, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 2, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 1, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 2, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 41: { # 'Δ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 2, # 'ή'
+ 15: 2, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 1, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 34: { # 'Ε'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 2, # 'Γ'
+ 41: 2, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 2, # 'Κ'
+ 53: 2, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 1, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 2, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 2, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 2, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 3, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 3, # 'γ'
+ 21: 2, # 'δ'
+ 3: 1, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 1, # 'θ'
+ 5: 2, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 2, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 3, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 2, # 'σ'
+ 2: 2, # 'τ'
+ 12: 2, # 'υ'
+ 28: 2, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 1, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 40: { # 'Η'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 1, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 2, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 2, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 2, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 1, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 1, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 1, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 52: { # 'Θ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 1, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 1, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 2, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 47: { # 'Ι'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 1, # 'Β'
+ 43: 1, # 'Γ'
+ 41: 2, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 2, # 'Κ'
+ 53: 2, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 2, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 2, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 1, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 2, # 'σ'
+ 2: 1, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 1, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 44: { # 'Κ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 1, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 1, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 1, # 'Τ'
+ 45: 2, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 1, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 2, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 53: { # 'Λ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 2, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 2, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 0, # 'ή'
+ 15: 2, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 1, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 2, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 38: { # 'Μ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 2, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 2, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 2, # 'ή'
+ 15: 2, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 3, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 2, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 49: { # 'Ν'
+ 60: 2, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 2, # 'έ'
+ 22: 0, # 'ή'
+ 15: 2, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 1, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 1, # 'ω'
+ 19: 2, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 59: { # 'Ξ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 1, # 'Ε'
+ 40: 1, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 1, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 2, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 39: { # 'Ο'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 1, # 'Β'
+ 43: 2, # 'Γ'
+ 41: 2, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 1, # 'Η'
+ 52: 2, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 2, # 'Κ'
+ 53: 2, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 2, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 2, # 'Υ'
+ 56: 2, # 'Φ'
+ 50: 2, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 2, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 2, # 'κ'
+ 16: 2, # 'λ'
+ 10: 2, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 2, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 2, # 'τ'
+ 12: 2, # 'υ'
+ 28: 1, # 'φ'
+ 23: 1, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 35: { # 'Π'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 2, # 'Λ'
+ 38: 1, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 1, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 1, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 1, # 'έ'
+ 22: 1, # 'ή'
+ 15: 2, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 2, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 2, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 48: { # 'Ρ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 1, # 'Γ'
+ 41: 1, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 1, # 'Τ'
+ 45: 1, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 1, # 'Χ'
+ 57: 1, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 2, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 1, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 3, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 0, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 37: { # 'Σ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 1, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 2, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 2, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 2, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 2, # 'ή'
+ 15: 2, # 'ί'
+ 1: 2, # 'α'
+ 29: 2, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 2, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 2, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 0, # 'φ'
+ 23: 2, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 0, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 33: { # 'Τ'
+ 60: 0, # 'e'
+ 55: 1, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 2, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 2, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 1, # 'Τ'
+ 45: 1, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 2, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 0, # 'ή'
+ 15: 2, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 2, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 2, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 2, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 45: { # 'Υ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 2, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 1, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 2, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 1, # 'Λ'
+ 38: 2, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 2, # 'Π'
+ 48: 1, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 1, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 3, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 56: { # 'Φ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 1, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 1, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 2, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 2, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 1, # 'ύ'
+ 27: 1, # 'ώ'
+ },
+ 50: { # 'Χ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 1, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 2, # 'Ε'
+ 40: 2, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 2, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 1, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 1, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 1, # 'Χ'
+ 57: 1, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 2, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 2, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 57: { # 'Ω'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 1, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 1, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 2, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 2, # 'Ρ'
+ 37: 2, # 'Σ'
+ 33: 2, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 2, # 'ρ'
+ 14: 2, # 'ς'
+ 7: 2, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 1, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 17: { # 'ά'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 3, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 3, # 'ε'
+ 32: 3, # 'ζ'
+ 13: 0, # 'η'
+ 25: 3, # 'θ'
+ 5: 2, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 3, # 'φ'
+ 23: 3, # 'χ'
+ 42: 3, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 18: { # 'έ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 3, # 'α'
+ 29: 2, # 'β'
+ 20: 3, # 'γ'
+ 21: 2, # 'δ'
+ 3: 3, # 'ε'
+ 32: 2, # 'ζ'
+ 13: 0, # 'η'
+ 25: 3, # 'θ'
+ 5: 0, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 3, # 'φ'
+ 23: 3, # 'χ'
+ 42: 3, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 22: { # 'ή'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 1, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 3, # 'θ'
+ 5: 0, # 'ι'
+ 11: 3, # 'κ'
+ 16: 2, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 15: { # 'ί'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 3, # 'α'
+ 29: 2, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 3, # 'ε'
+ 32: 3, # 'ζ'
+ 13: 3, # 'η'
+ 25: 3, # 'θ'
+ 5: 0, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 1, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 1: { # 'α'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 2, # 'έ'
+ 22: 0, # 'ή'
+ 15: 3, # 'ί'
+ 1: 0, # 'α'
+ 29: 3, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 2, # 'ε'
+ 32: 3, # 'ζ'
+ 13: 1, # 'η'
+ 25: 3, # 'θ'
+ 5: 3, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 3, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 2, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 29: { # 'β'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 2, # 'έ'
+ 22: 3, # 'ή'
+ 15: 2, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 2, # 'γ'
+ 21: 2, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 3, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 2, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 20: { # 'γ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 3, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 3, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 21: { # 'δ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 3, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 3: { # 'ε'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 3, # 'ί'
+ 1: 2, # 'α'
+ 29: 3, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 2, # 'ε'
+ 32: 2, # 'ζ'
+ 13: 0, # 'η'
+ 25: 3, # 'θ'
+ 5: 3, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 3, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 2, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 32: { # 'ζ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 2, # 'ή'
+ 15: 2, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 1, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 2, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 13: { # 'η'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 3, # 'γ'
+ 21: 2, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 3, # 'θ'
+ 5: 0, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 2, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 25: { # 'θ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 2, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 1, # 'λ'
+ 10: 3, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 3, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 5: { # 'ι'
+ 60: 0, # 'e'
+ 55: 1, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 1, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 0, # 'ί'
+ 1: 3, # 'α'
+ 29: 3, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 3, # 'ε'
+ 32: 2, # 'ζ'
+ 13: 3, # 'η'
+ 25: 3, # 'θ'
+ 5: 0, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 11: { # 'κ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 3, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 2, # 'θ'
+ 5: 3, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 2, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 2, # 'φ'
+ 23: 2, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 16: { # 'λ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 1, # 'β'
+ 20: 2, # 'γ'
+ 21: 1, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 2, # 'θ'
+ 5: 3, # 'ι'
+ 11: 2, # 'κ'
+ 16: 3, # 'λ'
+ 10: 2, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 2, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 10: { # 'μ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 1, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 3, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 2, # 'υ'
+ 28: 3, # 'φ'
+ 23: 0, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 6: { # 'ν'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 3, # 'δ'
+ 3: 3, # 'ε'
+ 32: 2, # 'ζ'
+ 13: 3, # 'η'
+ 25: 3, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 1, # 'λ'
+ 10: 0, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 30: { # 'ξ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 2, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 3, # 'τ'
+ 12: 2, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 2, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 1, # 'ώ'
+ },
+ 4: { # 'ο'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 2, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 2, # 'α'
+ 29: 3, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 3, # 'θ'
+ 5: 3, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 3, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 1, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 9: { # 'π'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 3, # 'λ'
+ 10: 0, # 'μ'
+ 6: 2, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 2, # 'ς'
+ 7: 0, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 0, # 'φ'
+ 23: 2, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 8: { # 'ρ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 2, # 'β'
+ 20: 3, # 'γ'
+ 21: 2, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 3, # 'θ'
+ 5: 3, # 'ι'
+ 11: 3, # 'κ'
+ 16: 1, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 2, # 'π'
+ 8: 2, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 2, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 3, # 'φ'
+ 23: 3, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 14: { # 'ς'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 2, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 0, # 'θ'
+ 5: 0, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 0, # 'τ'
+ 12: 0, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 7: { # 'σ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 3, # 'β'
+ 20: 0, # 'γ'
+ 21: 2, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 3, # 'θ'
+ 5: 3, # 'ι'
+ 11: 3, # 'κ'
+ 16: 2, # 'λ'
+ 10: 3, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 3, # 'φ'
+ 23: 3, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 2: { # 'τ'
+ 60: 0, # 'e'
+ 55: 2, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 2, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 3, # 'ι'
+ 11: 2, # 'κ'
+ 16: 2, # 'λ'
+ 10: 3, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 2, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 12: { # 'υ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 3, # 'ή'
+ 15: 2, # 'ί'
+ 1: 3, # 'α'
+ 29: 2, # 'β'
+ 20: 3, # 'γ'
+ 21: 2, # 'δ'
+ 3: 2, # 'ε'
+ 32: 2, # 'ζ'
+ 13: 2, # 'η'
+ 25: 3, # 'θ'
+ 5: 2, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 3, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 2, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 28: { # 'φ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 3, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 2, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 0, # 'μ'
+ 6: 1, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 1, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 2, # 'ύ'
+ 27: 2, # 'ώ'
+ },
+ 23: { # 'χ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 3, # 'ά'
+ 18: 2, # 'έ'
+ 22: 3, # 'ή'
+ 15: 3, # 'ί'
+ 1: 3, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 2, # 'θ'
+ 5: 3, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 2, # 'μ'
+ 6: 3, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 0, # 'π'
+ 8: 3, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 3, # 'τ'
+ 12: 3, # 'υ'
+ 28: 0, # 'φ'
+ 23: 2, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 3, # 'ω'
+ 19: 3, # 'ό'
+ 26: 3, # 'ύ'
+ 27: 3, # 'ώ'
+ },
+ 42: { # 'ψ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 2, # 'ά'
+ 18: 2, # 'έ'
+ 22: 1, # 'ή'
+ 15: 2, # 'ί'
+ 1: 2, # 'α'
+ 29: 0, # 'β'
+ 20: 0, # 'γ'
+ 21: 0, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 3, # 'η'
+ 25: 0, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 0, # 'λ'
+ 10: 0, # 'μ'
+ 6: 0, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 0, # 'π'
+ 8: 0, # 'ρ'
+ 14: 0, # 'ς'
+ 7: 0, # 'σ'
+ 2: 2, # 'τ'
+ 12: 1, # 'υ'
+ 28: 0, # 'φ'
+ 23: 0, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 24: { # 'ω'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 1, # 'ά'
+ 18: 0, # 'έ'
+ 22: 2, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 2, # 'β'
+ 20: 3, # 'γ'
+ 21: 2, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 0, # 'η'
+ 25: 3, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 0, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 2, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 19: { # 'ό'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 3, # 'β'
+ 20: 3, # 'γ'
+ 21: 3, # 'δ'
+ 3: 1, # 'ε'
+ 32: 2, # 'ζ'
+ 13: 2, # 'η'
+ 25: 2, # 'θ'
+ 5: 2, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 1, # 'ξ'
+ 4: 2, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 3, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 26: { # 'ύ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 2, # 'α'
+ 29: 2, # 'β'
+ 20: 2, # 'γ'
+ 21: 1, # 'δ'
+ 3: 3, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 2, # 'η'
+ 25: 3, # 'θ'
+ 5: 0, # 'ι'
+ 11: 3, # 'κ'
+ 16: 3, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 2, # 'ξ'
+ 4: 3, # 'ο'
+ 9: 3, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 2, # 'φ'
+ 23: 2, # 'χ'
+ 42: 2, # 'ψ'
+ 24: 2, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+ 27: { # 'ώ'
+ 60: 0, # 'e'
+ 55: 0, # 'o'
+ 58: 0, # 't'
+ 36: 0, # '·'
+ 61: 0, # 'Ά'
+ 46: 0, # 'Έ'
+ 54: 0, # 'Ό'
+ 31: 0, # 'Α'
+ 51: 0, # 'Β'
+ 43: 0, # 'Γ'
+ 41: 0, # 'Δ'
+ 34: 0, # 'Ε'
+ 40: 0, # 'Η'
+ 52: 0, # 'Θ'
+ 47: 0, # 'Ι'
+ 44: 0, # 'Κ'
+ 53: 0, # 'Λ'
+ 38: 0, # 'Μ'
+ 49: 0, # 'Ν'
+ 59: 0, # 'Ξ'
+ 39: 0, # 'Ο'
+ 35: 0, # 'Π'
+ 48: 0, # 'Ρ'
+ 37: 0, # 'Σ'
+ 33: 0, # 'Τ'
+ 45: 0, # 'Υ'
+ 56: 0, # 'Φ'
+ 50: 0, # 'Χ'
+ 57: 0, # 'Ω'
+ 17: 0, # 'ά'
+ 18: 0, # 'έ'
+ 22: 0, # 'ή'
+ 15: 0, # 'ί'
+ 1: 0, # 'α'
+ 29: 1, # 'β'
+ 20: 0, # 'γ'
+ 21: 3, # 'δ'
+ 3: 0, # 'ε'
+ 32: 0, # 'ζ'
+ 13: 1, # 'η'
+ 25: 2, # 'θ'
+ 5: 2, # 'ι'
+ 11: 0, # 'κ'
+ 16: 2, # 'λ'
+ 10: 3, # 'μ'
+ 6: 3, # 'ν'
+ 30: 1, # 'ξ'
+ 4: 0, # 'ο'
+ 9: 2, # 'π'
+ 8: 3, # 'ρ'
+ 14: 3, # 'ς'
+ 7: 3, # 'σ'
+ 2: 3, # 'τ'
+ 12: 0, # 'υ'
+ 28: 1, # 'φ'
+ 23: 1, # 'χ'
+ 42: 0, # 'ψ'
+ 24: 0, # 'ω'
+ 19: 0, # 'ό'
+ 26: 0, # 'ύ'
+ 27: 0, # 'ώ'
+ },
+}
+
+# 255: Undefined characters that did not exist in training text
+# 254: Carriage/Return
+# 253: symbol (punctuation) that does not belong to word
+# 252: 0 - 9
+# 251: Control characters
+
+# Character Mapping Table(s):
+WINDOWS_1253_GREEK_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 82, # 'A'
+ 66: 100, # 'B'
+ 67: 104, # 'C'
+ 68: 94, # 'D'
+ 69: 98, # 'E'
+ 70: 101, # 'F'
+ 71: 116, # 'G'
+ 72: 102, # 'H'
+ 73: 111, # 'I'
+ 74: 187, # 'J'
+ 75: 117, # 'K'
+ 76: 92, # 'L'
+ 77: 88, # 'M'
+ 78: 113, # 'N'
+ 79: 85, # 'O'
+ 80: 79, # 'P'
+ 81: 118, # 'Q'
+ 82: 105, # 'R'
+ 83: 83, # 'S'
+ 84: 67, # 'T'
+ 85: 114, # 'U'
+ 86: 119, # 'V'
+ 87: 95, # 'W'
+ 88: 99, # 'X'
+ 89: 109, # 'Y'
+ 90: 188, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 72, # 'a'
+ 98: 70, # 'b'
+ 99: 80, # 'c'
+ 100: 81, # 'd'
+ 101: 60, # 'e'
+ 102: 96, # 'f'
+ 103: 93, # 'g'
+ 104: 89, # 'h'
+ 105: 68, # 'i'
+ 106: 120, # 'j'
+ 107: 97, # 'k'
+ 108: 77, # 'l'
+ 109: 86, # 'm'
+ 110: 69, # 'n'
+ 111: 55, # 'o'
+ 112: 78, # 'p'
+ 113: 115, # 'q'
+ 114: 65, # 'r'
+ 115: 66, # 's'
+ 116: 58, # 't'
+ 117: 76, # 'u'
+ 118: 106, # 'v'
+ 119: 103, # 'w'
+ 120: 87, # 'x'
+ 121: 107, # 'y'
+ 122: 112, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 255, # '€'
+ 129: 255, # None
+ 130: 255, # '‚'
+ 131: 255, # 'ƒ'
+ 132: 255, # '„'
+ 133: 255, # '…'
+ 134: 255, # '†'
+ 135: 255, # '‡'
+ 136: 255, # None
+ 137: 255, # '‰'
+ 138: 255, # None
+ 139: 255, # '‹'
+ 140: 255, # None
+ 141: 255, # None
+ 142: 255, # None
+ 143: 255, # None
+ 144: 255, # None
+ 145: 255, # '‘'
+ 146: 255, # '’'
+ 147: 255, # '“'
+ 148: 255, # '”'
+ 149: 255, # '•'
+ 150: 255, # '–'
+ 151: 255, # '—'
+ 152: 255, # None
+ 153: 255, # '™'
+ 154: 255, # None
+ 155: 255, # '›'
+ 156: 255, # None
+ 157: 255, # None
+ 158: 255, # None
+ 159: 255, # None
+ 160: 253, # '\xa0'
+ 161: 233, # '΅'
+ 162: 61, # 'Ά'
+ 163: 253, # '£'
+ 164: 253, # '¤'
+ 165: 253, # '¥'
+ 166: 253, # '¦'
+ 167: 253, # '§'
+ 168: 253, # '¨'
+ 169: 253, # '©'
+ 170: 253, # None
+ 171: 253, # '«'
+ 172: 253, # '¬'
+ 173: 74, # '\xad'
+ 174: 253, # '®'
+ 175: 253, # '―'
+ 176: 253, # '°'
+ 177: 253, # '±'
+ 178: 253, # '²'
+ 179: 253, # '³'
+ 180: 247, # '΄'
+ 181: 253, # 'µ'
+ 182: 253, # '¶'
+ 183: 36, # '·'
+ 184: 46, # 'Έ'
+ 185: 71, # 'Ή'
+ 186: 73, # 'Ί'
+ 187: 253, # '»'
+ 188: 54, # 'Ό'
+ 189: 253, # '½'
+ 190: 108, # 'Ύ'
+ 191: 123, # 'Ώ'
+ 192: 110, # 'ΐ'
+ 193: 31, # 'Α'
+ 194: 51, # 'Β'
+ 195: 43, # 'Γ'
+ 196: 41, # 'Δ'
+ 197: 34, # 'Ε'
+ 198: 91, # 'Ζ'
+ 199: 40, # 'Η'
+ 200: 52, # 'Θ'
+ 201: 47, # 'Ι'
+ 202: 44, # 'Κ'
+ 203: 53, # 'Λ'
+ 204: 38, # 'Μ'
+ 205: 49, # 'Ν'
+ 206: 59, # 'Ξ'
+ 207: 39, # 'Ο'
+ 208: 35, # 'Π'
+ 209: 48, # 'Ρ'
+ 210: 250, # None
+ 211: 37, # 'Σ'
+ 212: 33, # 'Τ'
+ 213: 45, # 'Υ'
+ 214: 56, # 'Φ'
+ 215: 50, # 'Χ'
+ 216: 84, # 'Ψ'
+ 217: 57, # 'Ω'
+ 218: 120, # 'Ϊ'
+ 219: 121, # 'Ϋ'
+ 220: 17, # 'ά'
+ 221: 18, # 'έ'
+ 222: 22, # 'ή'
+ 223: 15, # 'ί'
+ 224: 124, # 'ΰ'
+ 225: 1, # 'α'
+ 226: 29, # 'β'
+ 227: 20, # 'γ'
+ 228: 21, # 'δ'
+ 229: 3, # 'ε'
+ 230: 32, # 'ζ'
+ 231: 13, # 'η'
+ 232: 25, # 'θ'
+ 233: 5, # 'ι'
+ 234: 11, # 'κ'
+ 235: 16, # 'λ'
+ 236: 10, # 'μ'
+ 237: 6, # 'ν'
+ 238: 30, # 'ξ'
+ 239: 4, # 'ο'
+ 240: 9, # 'π'
+ 241: 8, # 'ρ'
+ 242: 14, # 'ς'
+ 243: 7, # 'σ'
+ 244: 2, # 'τ'
+ 245: 12, # 'υ'
+ 246: 28, # 'φ'
+ 247: 23, # 'χ'
+ 248: 42, # 'ψ'
+ 249: 24, # 'ω'
+ 250: 64, # 'ϊ'
+ 251: 75, # 'ϋ'
+ 252: 19, # 'ό'
+ 253: 26, # 'ύ'
+ 254: 27, # 'ώ'
+ 255: 253, # None
+}
+
+WINDOWS_1253_GREEK_MODEL = SingleByteCharSetModel(
+ charset_name="windows-1253",
+ language="Greek",
+ char_to_order_map=WINDOWS_1253_GREEK_CHAR_TO_ORDER,
+ language_model=GREEK_LANG_MODEL,
+ typical_positive_ratio=0.982851,
+ keep_ascii_letters=False,
+ alphabet="ΆΈΉΊΌΎΏΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩάέήίαβγδεζηθικλμνξοπρςστυφχψωόύώ",
+)
+
+ISO_8859_7_GREEK_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 82, # 'A'
+ 66: 100, # 'B'
+ 67: 104, # 'C'
+ 68: 94, # 'D'
+ 69: 98, # 'E'
+ 70: 101, # 'F'
+ 71: 116, # 'G'
+ 72: 102, # 'H'
+ 73: 111, # 'I'
+ 74: 187, # 'J'
+ 75: 117, # 'K'
+ 76: 92, # 'L'
+ 77: 88, # 'M'
+ 78: 113, # 'N'
+ 79: 85, # 'O'
+ 80: 79, # 'P'
+ 81: 118, # 'Q'
+ 82: 105, # 'R'
+ 83: 83, # 'S'
+ 84: 67, # 'T'
+ 85: 114, # 'U'
+ 86: 119, # 'V'
+ 87: 95, # 'W'
+ 88: 99, # 'X'
+ 89: 109, # 'Y'
+ 90: 188, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 72, # 'a'
+ 98: 70, # 'b'
+ 99: 80, # 'c'
+ 100: 81, # 'd'
+ 101: 60, # 'e'
+ 102: 96, # 'f'
+ 103: 93, # 'g'
+ 104: 89, # 'h'
+ 105: 68, # 'i'
+ 106: 120, # 'j'
+ 107: 97, # 'k'
+ 108: 77, # 'l'
+ 109: 86, # 'm'
+ 110: 69, # 'n'
+ 111: 55, # 'o'
+ 112: 78, # 'p'
+ 113: 115, # 'q'
+ 114: 65, # 'r'
+ 115: 66, # 's'
+ 116: 58, # 't'
+ 117: 76, # 'u'
+ 118: 106, # 'v'
+ 119: 103, # 'w'
+ 120: 87, # 'x'
+ 121: 107, # 'y'
+ 122: 112, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 255, # '\x80'
+ 129: 255, # '\x81'
+ 130: 255, # '\x82'
+ 131: 255, # '\x83'
+ 132: 255, # '\x84'
+ 133: 255, # '\x85'
+ 134: 255, # '\x86'
+ 135: 255, # '\x87'
+ 136: 255, # '\x88'
+ 137: 255, # '\x89'
+ 138: 255, # '\x8a'
+ 139: 255, # '\x8b'
+ 140: 255, # '\x8c'
+ 141: 255, # '\x8d'
+ 142: 255, # '\x8e'
+ 143: 255, # '\x8f'
+ 144: 255, # '\x90'
+ 145: 255, # '\x91'
+ 146: 255, # '\x92'
+ 147: 255, # '\x93'
+ 148: 255, # '\x94'
+ 149: 255, # '\x95'
+ 150: 255, # '\x96'
+ 151: 255, # '\x97'
+ 152: 255, # '\x98'
+ 153: 255, # '\x99'
+ 154: 255, # '\x9a'
+ 155: 255, # '\x9b'
+ 156: 255, # '\x9c'
+ 157: 255, # '\x9d'
+ 158: 255, # '\x9e'
+ 159: 255, # '\x9f'
+ 160: 253, # '\xa0'
+ 161: 233, # '‘'
+ 162: 90, # '’'
+ 163: 253, # '£'
+ 164: 253, # '€'
+ 165: 253, # '₯'
+ 166: 253, # '¦'
+ 167: 253, # '§'
+ 168: 253, # '¨'
+ 169: 253, # '©'
+ 170: 253, # 'ͺ'
+ 171: 253, # '«'
+ 172: 253, # '¬'
+ 173: 74, # '\xad'
+ 174: 253, # None
+ 175: 253, # '―'
+ 176: 253, # '°'
+ 177: 253, # '±'
+ 178: 253, # '²'
+ 179: 253, # '³'
+ 180: 247, # '΄'
+ 181: 248, # '΅'
+ 182: 61, # 'Ά'
+ 183: 36, # '·'
+ 184: 46, # 'Έ'
+ 185: 71, # 'Ή'
+ 186: 73, # 'Ί'
+ 187: 253, # '»'
+ 188: 54, # 'Ό'
+ 189: 253, # '½'
+ 190: 108, # 'Ύ'
+ 191: 123, # 'Ώ'
+ 192: 110, # 'ΐ'
+ 193: 31, # 'Α'
+ 194: 51, # 'Β'
+ 195: 43, # 'Γ'
+ 196: 41, # 'Δ'
+ 197: 34, # 'Ε'
+ 198: 91, # 'Ζ'
+ 199: 40, # 'Η'
+ 200: 52, # 'Θ'
+ 201: 47, # 'Ι'
+ 202: 44, # 'Κ'
+ 203: 53, # 'Λ'
+ 204: 38, # 'Μ'
+ 205: 49, # 'Ν'
+ 206: 59, # 'Ξ'
+ 207: 39, # 'Ο'
+ 208: 35, # 'Π'
+ 209: 48, # 'Ρ'
+ 210: 250, # None
+ 211: 37, # 'Σ'
+ 212: 33, # 'Τ'
+ 213: 45, # 'Υ'
+ 214: 56, # 'Φ'
+ 215: 50, # 'Χ'
+ 216: 84, # 'Ψ'
+ 217: 57, # 'Ω'
+ 218: 120, # 'Ϊ'
+ 219: 121, # 'Ϋ'
+ 220: 17, # 'ά'
+ 221: 18, # 'έ'
+ 222: 22, # 'ή'
+ 223: 15, # 'ί'
+ 224: 124, # 'ΰ'
+ 225: 1, # 'α'
+ 226: 29, # 'β'
+ 227: 20, # 'γ'
+ 228: 21, # 'δ'
+ 229: 3, # 'ε'
+ 230: 32, # 'ζ'
+ 231: 13, # 'η'
+ 232: 25, # 'θ'
+ 233: 5, # 'ι'
+ 234: 11, # 'κ'
+ 235: 16, # 'λ'
+ 236: 10, # 'μ'
+ 237: 6, # 'ν'
+ 238: 30, # 'ξ'
+ 239: 4, # 'ο'
+ 240: 9, # 'π'
+ 241: 8, # 'ρ'
+ 242: 14, # 'ς'
+ 243: 7, # 'σ'
+ 244: 2, # 'τ'
+ 245: 12, # 'υ'
+ 246: 28, # 'φ'
+ 247: 23, # 'χ'
+ 248: 42, # 'ψ'
+ 249: 24, # 'ω'
+ 250: 64, # 'ϊ'
+ 251: 75, # 'ϋ'
+ 252: 19, # 'ό'
+ 253: 26, # 'ύ'
+ 254: 27, # 'ώ'
+ 255: 253, # None
+}
+
+ISO_8859_7_GREEK_MODEL = SingleByteCharSetModel(
+ charset_name="ISO-8859-7",
+ language="Greek",
+ char_to_order_map=ISO_8859_7_GREEK_CHAR_TO_ORDER,
+ language_model=GREEK_LANG_MODEL,
+ typical_positive_ratio=0.982851,
+ keep_ascii_letters=False,
+ alphabet="ΆΈΉΊΌΎΏΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩάέήίαβγδεζηθικλμνξοπρςστυφχψωόύώ",
+)
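
For reference, a minimal sketch of how the two model objects defined above are meant to be consumed. This is not chardet's actual SingleByteCharSetProber code; it assumes the usual chardet conventions that orders below 64 denote the modelled letters, that the reserved orders 251-255 mark control bytes, digits and symbols, and that a bigram value of 3 means "Positive". The helper name score_bytes is hypothetical.

    from chardet.langgreekmodel import (
        GREEK_LANG_MODEL,
        WINDOWS_1253_GREEK_CHAR_TO_ORDER,
    )

    def score_bytes(data, char_to_order, lang_model, sample_size=64):
        """Fraction of adjacent modelled-letter pairs rated 'Positive' (3)."""
        positive = total = 0
        prev_order = None
        for byte in data:                      # iterating bytes yields ints 0-255
            order = char_to_order.get(byte, 255)
            if order < sample_size:            # a frequent, modelled letter
                if prev_order is not None:
                    total += 1
                    if lang_model.get(prev_order, {}).get(order, 0) == 3:
                        positive += 1
                prev_order = order
            else:                              # digit, symbol, control, rare letter
                prev_order = None              # break the current bigram chain
        return positive / total if total else 0.0

    greek = "Καλημέρα κόσμε".encode("windows-1253")
    print(score_bytes(greek, WINDOWS_1253_GREEK_CHAR_TO_ORDER, GREEK_LANG_MODEL))

A Greek sample encoded as windows-1253 scores close to 1.0 here, while the same bytes run against a model for another language (or text in another encoding) would score much lower; typical_positive_ratio in the model definitions above records the ratio expected for genuine text in the target language.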
diff --git a/lib/chardet/langhebrewmodel.py b/lib/chardet/langhebrewmodel.py
new file mode 100644
index 0000000..86b3c5e
--- /dev/null
+++ b/lib/chardet/langhebrewmodel.py
@@ -0,0 +1,4380 @@
+from chardet.sbcharsetprober import SingleByteCharSetModel
+
+# 3: Positive
+# 2: Likely
+# 1: Unlikely
+# 0: Negative
+
+HEBREW_LANG_MODEL = {
+ 50: { # 'a'
+ 50: 0, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 2, # 'l'
+ 54: 2, # 'n'
+ 49: 0, # 'o'
+ 51: 2, # 'r'
+ 43: 1, # 's'
+ 44: 2, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 1, # 'ק'
+ 7: 0, # 'ר'
+ 10: 1, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 60: { # 'c'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 0, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 0, # 'n'
+ 49: 1, # 'o'
+ 51: 1, # 'r'
+ 43: 1, # 's'
+ 44: 2, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 61: { # 'd'
+ 50: 1, # 'a'
+ 60: 0, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 1, # 'n'
+ 49: 2, # 'o'
+ 51: 1, # 'r'
+ 43: 1, # 's'
+ 44: 0, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 1, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 42: { # 'e'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 2, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 2, # 'l'
+ 54: 2, # 'n'
+ 49: 1, # 'o'
+ 51: 2, # 'r'
+ 43: 2, # 's'
+ 44: 2, # 't'
+ 63: 1, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 1, # '–'
+ 52: 2, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 53: { # 'i'
+ 50: 1, # 'a'
+ 60: 2, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 0, # 'i'
+ 56: 1, # 'l'
+ 54: 2, # 'n'
+ 49: 2, # 'o'
+ 51: 1, # 'r'
+ 43: 2, # 's'
+ 44: 2, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 56: { # 'l'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 2, # 'e'
+ 53: 2, # 'i'
+ 56: 2, # 'l'
+ 54: 1, # 'n'
+ 49: 1, # 'o'
+ 51: 0, # 'r'
+ 43: 1, # 's'
+ 44: 1, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 54: { # 'n'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 1, # 'n'
+ 49: 1, # 'o'
+ 51: 0, # 'r'
+ 43: 1, # 's'
+ 44: 2, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 2, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 49: { # 'o'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 2, # 'n'
+ 49: 1, # 'o'
+ 51: 2, # 'r'
+ 43: 1, # 's'
+ 44: 1, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 51: { # 'r'
+ 50: 2, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 2, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 1, # 'n'
+ 49: 2, # 'o'
+ 51: 1, # 'r'
+ 43: 1, # 's'
+ 44: 1, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 2, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 43: { # 's'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 0, # 'd'
+ 42: 2, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 1, # 'n'
+ 49: 1, # 'o'
+ 51: 1, # 'r'
+ 43: 1, # 's'
+ 44: 2, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 2, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 44: { # 't'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 0, # 'd'
+ 42: 2, # 'e'
+ 53: 2, # 'i'
+ 56: 1, # 'l'
+ 54: 0, # 'n'
+ 49: 1, # 'o'
+ 51: 1, # 'r'
+ 43: 1, # 's'
+ 44: 1, # 't'
+ 63: 1, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 2, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 63: { # 'u'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 1, # 'n'
+ 49: 0, # 'o'
+ 51: 1, # 'r'
+ 43: 2, # 's'
+ 44: 1, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 34: { # '\xa0'
+ 50: 1, # 'a'
+ 60: 0, # 'c'
+ 61: 1, # 'd'
+ 42: 0, # 'e'
+ 53: 1, # 'i'
+ 56: 0, # 'l'
+ 54: 1, # 'n'
+ 49: 1, # 'o'
+ 51: 0, # 'r'
+ 43: 1, # 's'
+ 44: 1, # 't'
+ 63: 0, # 'u'
+ 34: 2, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 1, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 1, # 'ח'
+ 22: 1, # 'ט'
+ 1: 2, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 2, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 1, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 55: { # '´'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 1, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 2, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 1, # 'ן'
+ 12: 1, # 'נ'
+ 19: 1, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 48: { # '¼'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 1, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 39: { # '½'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 57: { # '¾'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 30: { # 'ְ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 2, # 'ב'
+ 20: 2, # 'ג'
+ 16: 2, # 'ד'
+ 3: 2, # 'ה'
+ 2: 2, # 'ו'
+ 24: 2, # 'ז'
+ 14: 2, # 'ח'
+ 22: 2, # 'ט'
+ 1: 2, # 'י'
+ 25: 2, # 'ך'
+ 15: 2, # 'כ'
+ 4: 2, # 'ל'
+ 11: 1, # 'ם'
+ 6: 2, # 'מ'
+ 23: 0, # 'ן'
+ 12: 2, # 'נ'
+ 19: 2, # 'ס'
+ 13: 2, # 'ע'
+ 26: 0, # 'ף'
+ 18: 2, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 2, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 59: { # 'ֱ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 1, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 1, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 2, # 'ל'
+ 11: 0, # 'ם'
+ 6: 2, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 41: { # 'ֲ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 2, # 'ב'
+ 20: 1, # 'ג'
+ 16: 2, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 1, # 'ח'
+ 22: 1, # 'ט'
+ 1: 1, # 'י'
+ 25: 1, # 'ך'
+ 15: 1, # 'כ'
+ 4: 2, # 'ל'
+ 11: 0, # 'ם'
+ 6: 2, # 'מ'
+ 23: 0, # 'ן'
+ 12: 2, # 'נ'
+ 19: 1, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 1, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 33: { # 'ִ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 1, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 1, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 1, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 1, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 2, # 'ב'
+ 20: 2, # 'ג'
+ 16: 2, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 2, # 'ז'
+ 14: 1, # 'ח'
+ 22: 1, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 2, # 'כ'
+ 4: 2, # 'ל'
+ 11: 2, # 'ם'
+ 6: 2, # 'מ'
+ 23: 2, # 'ן'
+ 12: 2, # 'נ'
+ 19: 2, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 2, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 2, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 37: { # 'ֵ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 1, # 'ַ'
+ 29: 1, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 2, # 'ב'
+ 20: 1, # 'ג'
+ 16: 2, # 'ד'
+ 3: 2, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 2, # 'ח'
+ 22: 1, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 1, # 'כ'
+ 4: 2, # 'ל'
+ 11: 2, # 'ם'
+ 6: 1, # 'מ'
+ 23: 2, # 'ן'
+ 12: 2, # 'נ'
+ 19: 1, # 'ס'
+ 13: 2, # 'ע'
+ 26: 1, # 'ף'
+ 18: 1, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 36: { # 'ֶ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 1, # 'ַ'
+ 29: 1, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 2, # 'ב'
+ 20: 1, # 'ג'
+ 16: 2, # 'ד'
+ 3: 2, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 2, # 'ח'
+ 22: 1, # 'ט'
+ 1: 2, # 'י'
+ 25: 2, # 'ך'
+ 15: 1, # 'כ'
+ 4: 2, # 'ל'
+ 11: 2, # 'ם'
+ 6: 2, # 'מ'
+ 23: 2, # 'ן'
+ 12: 2, # 'נ'
+ 19: 2, # 'ס'
+ 13: 1, # 'ע'
+ 26: 1, # 'ף'
+ 18: 1, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 31: { # 'ַ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 1, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 2, # 'ב'
+ 20: 2, # 'ג'
+ 16: 2, # 'ד'
+ 3: 2, # 'ה'
+ 2: 1, # 'ו'
+ 24: 2, # 'ז'
+ 14: 2, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 2, # 'כ'
+ 4: 2, # 'ל'
+ 11: 2, # 'ם'
+ 6: 2, # 'מ'
+ 23: 2, # 'ן'
+ 12: 2, # 'נ'
+ 19: 2, # 'ס'
+ 13: 2, # 'ע'
+ 26: 2, # 'ף'
+ 18: 2, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 2, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 29: { # 'ָ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 1, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 1, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 2, # 'ב'
+ 20: 2, # 'ג'
+ 16: 2, # 'ד'
+ 3: 3, # 'ה'
+ 2: 2, # 'ו'
+ 24: 2, # 'ז'
+ 14: 2, # 'ח'
+ 22: 1, # 'ט'
+ 1: 2, # 'י'
+ 25: 2, # 'ך'
+ 15: 2, # 'כ'
+ 4: 2, # 'ל'
+ 11: 2, # 'ם'
+ 6: 2, # 'מ'
+ 23: 2, # 'ן'
+ 12: 2, # 'נ'
+ 19: 1, # 'ס'
+ 13: 2, # 'ע'
+ 26: 1, # 'ף'
+ 18: 2, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 2, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 35: { # 'ֹ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 2, # 'ב'
+ 20: 1, # 'ג'
+ 16: 2, # 'ד'
+ 3: 2, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 1, # 'ח'
+ 22: 1, # 'ט'
+ 1: 1, # 'י'
+ 25: 1, # 'ך'
+ 15: 2, # 'כ'
+ 4: 2, # 'ל'
+ 11: 2, # 'ם'
+ 6: 2, # 'מ'
+ 23: 2, # 'ן'
+ 12: 2, # 'נ'
+ 19: 2, # 'ס'
+ 13: 2, # 'ע'
+ 26: 1, # 'ף'
+ 18: 2, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 2, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 62: { # 'ֻ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 1, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 1, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 2, # 'ל'
+ 11: 1, # 'ם'
+ 6: 1, # 'מ'
+ 23: 1, # 'ן'
+ 12: 1, # 'נ'
+ 19: 1, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 28: { # 'ּ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 3, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 1, # 'ֲ'
+ 33: 3, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 3, # 'ַ'
+ 29: 3, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 2, # 'ׁ'
+ 45: 1, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 2, # 'ב'
+ 20: 1, # 'ג'
+ 16: 2, # 'ד'
+ 3: 1, # 'ה'
+ 2: 2, # 'ו'
+ 24: 1, # 'ז'
+ 14: 1, # 'ח'
+ 22: 1, # 'ט'
+ 1: 2, # 'י'
+ 25: 2, # 'ך'
+ 15: 2, # 'כ'
+ 4: 2, # 'ל'
+ 11: 1, # 'ם'
+ 6: 2, # 'מ'
+ 23: 1, # 'ן'
+ 12: 2, # 'נ'
+ 19: 1, # 'ס'
+ 13: 2, # 'ע'
+ 26: 1, # 'ף'
+ 18: 1, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 2, # 'ר'
+ 10: 2, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 38: { # 'ׁ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 2, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 45: { # 'ׂ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 1, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 1, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 1, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 2, # 'ו'
+ 24: 0, # 'ז'
+ 14: 1, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 1, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 0, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 1, # 'ר'
+ 10: 0, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 9: { # 'א'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 1, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 2, # 'ֱ'
+ 41: 2, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 3, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 2, # 'ע'
+ 26: 3, # 'ף'
+ 18: 3, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 8: { # 'ב'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 1, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 3, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 2, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 1, # 'ף'
+ 18: 3, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 1, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 20: { # 'ג'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 2, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 1, # 'ִ'
+ 37: 1, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 3, # 'ב'
+ 20: 2, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 2, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 1, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 2, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 2, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 16: { # 'ד'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 1, # 'ז'
+ 14: 2, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 2, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 2, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 3: { # 'ה'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 1, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 1, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 1, # 'ְ'
+ 59: 1, # 'ֱ'
+ 41: 2, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 3, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 0, # 'ף'
+ 18: 3, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 1, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 2: { # 'ו'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 1, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 1, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 1, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 3, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 3, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 3, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 3, # 'ף'
+ 18: 3, # 'פ'
+ 27: 3, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 1, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 24: { # 'ז'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 1, # 'ֲ'
+ 33: 1, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 2, # 'ב'
+ 20: 2, # 'ג'
+ 16: 2, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 2, # 'ז'
+ 14: 2, # 'ח'
+ 22: 1, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 2, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 2, # 'נ'
+ 19: 1, # 'ס'
+ 13: 2, # 'ע'
+ 26: 1, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 1, # 'ש'
+ 5: 2, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 14: { # 'ח'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 1, # 'ֱ'
+ 41: 2, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 3, # 'ב'
+ 20: 2, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 2, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 2, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 1, # 'ע'
+ 26: 2, # 'ף'
+ 18: 2, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 22: { # 'ט'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 1, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 1, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 1, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 1, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 2, # 'ז'
+ 14: 3, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 2, # 'כ'
+ 4: 3, # 'ל'
+ 11: 2, # 'ם'
+ 6: 2, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 2, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 2, # 'ק'
+ 7: 3, # 'ר'
+ 10: 2, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 1: { # 'י'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 1, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 3, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 3, # 'ף'
+ 18: 3, # 'פ'
+ 27: 3, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 1, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 25: { # 'ך'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 1, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 1, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 1, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 15: { # 'כ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 3, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 2, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 3, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 2, # 'ע'
+ 26: 3, # 'ף'
+ 18: 3, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 2, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 4: { # 'ל'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 3, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 3, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 1, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 11: { # 'ם'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 1, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 1, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 0, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 6: { # 'מ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 0, # 'ף'
+ 18: 3, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 23: { # 'ן'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 1, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 1, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 0, # 'ז'
+ 14: 1, # 'ח'
+ 22: 1, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 1, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 1, # 'ס'
+ 13: 1, # 'ע'
+ 26: 1, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 1, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 1, # 'ת'
+ 32: 1, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 12: { # 'נ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 19: { # 'ס'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 1, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 1, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 2, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 1, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 2, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 2, # 'ס'
+ 13: 3, # 'ע'
+ 26: 3, # 'ף'
+ 18: 3, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 1, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 13: { # 'ע'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 1, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 1, # 'ְ'
+ 59: 1, # 'ֱ'
+ 41: 2, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 1, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 2, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 2, # 'ע'
+ 26: 1, # 'ף'
+ 18: 2, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 26: { # 'ף'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 1, # 'ו'
+ 24: 0, # 'ז'
+ 14: 1, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 1, # 'ס'
+ 13: 0, # 'ע'
+ 26: 1, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 1, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 18: { # 'פ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 1, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 1, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 2, # 'ב'
+ 20: 3, # 'ג'
+ 16: 2, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 2, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 2, # 'ם'
+ 6: 2, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 2, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 27: { # 'ץ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 1, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 1, # 'ר'
+ 10: 0, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 21: { # 'צ'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 2, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 1, # 'ז'
+ 14: 3, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 1, # 'כ'
+ 4: 3, # 'ל'
+ 11: 2, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 1, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 0, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 17: { # 'ק'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 1, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 2, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 2, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 1, # 'ך'
+ 15: 1, # 'כ'
+ 4: 3, # 'ל'
+ 11: 2, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 2, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 2, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 7: { # 'ר'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 2, # '´'
+ 48: 1, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 1, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 2, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 3, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 3, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 3, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 3, # 'ץ'
+ 21: 3, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 10: { # 'ש'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 1, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 1, # 'ִ'
+ 37: 1, # 'ֵ'
+ 36: 1, # 'ֶ'
+ 31: 1, # 'ַ'
+ 29: 1, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 3, # 'ׁ'
+ 45: 2, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 3, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 2, # 'ז'
+ 14: 3, # 'ח'
+ 22: 3, # 'ט'
+ 1: 3, # 'י'
+ 25: 3, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 2, # 'ן'
+ 12: 3, # 'נ'
+ 19: 2, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 1, # '…'
+ },
+ 5: { # 'ת'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 1, # '\xa0'
+ 55: 0, # '´'
+ 48: 1, # '¼'
+ 39: 1, # '½'
+ 57: 0, # '¾'
+ 30: 2, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 2, # 'ִ'
+ 37: 2, # 'ֵ'
+ 36: 2, # 'ֶ'
+ 31: 2, # 'ַ'
+ 29: 2, # 'ָ'
+ 35: 1, # 'ֹ'
+ 62: 1, # 'ֻ'
+ 28: 2, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 3, # 'א'
+ 8: 3, # 'ב'
+ 20: 3, # 'ג'
+ 16: 2, # 'ד'
+ 3: 3, # 'ה'
+ 2: 3, # 'ו'
+ 24: 2, # 'ז'
+ 14: 3, # 'ח'
+ 22: 2, # 'ט'
+ 1: 3, # 'י'
+ 25: 2, # 'ך'
+ 15: 3, # 'כ'
+ 4: 3, # 'ל'
+ 11: 3, # 'ם'
+ 6: 3, # 'מ'
+ 23: 3, # 'ן'
+ 12: 3, # 'נ'
+ 19: 2, # 'ס'
+ 13: 3, # 'ע'
+ 26: 2, # 'ף'
+ 18: 3, # 'פ'
+ 27: 1, # 'ץ'
+ 21: 2, # 'צ'
+ 17: 3, # 'ק'
+ 7: 3, # 'ר'
+ 10: 3, # 'ש'
+ 5: 3, # 'ת'
+ 32: 1, # '–'
+ 52: 1, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+ 32: { # '–'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 1, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 1, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 0, # 'ז'
+ 14: 1, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 1, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 0, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 52: { # '’'
+ 50: 1, # 'a'
+ 60: 0, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 1, # 'r'
+ 43: 2, # 's'
+ 44: 2, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 1, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 47: { # '“'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 1, # 'l'
+ 54: 1, # 'n'
+ 49: 1, # 'o'
+ 51: 1, # 'r'
+ 43: 1, # 's'
+ 44: 1, # 't'
+ 63: 1, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 2, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 1, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 1, # 'ח'
+ 22: 1, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 1, # 'ס'
+ 13: 1, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 1, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 46: { # '”'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 1, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 1, # 'ב'
+ 20: 1, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 1, # 'צ'
+ 17: 0, # 'ק'
+ 7: 1, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 0, # '†'
+ 40: 0, # '…'
+ },
+ 58: { # '†'
+ 50: 0, # 'a'
+ 60: 0, # 'c'
+ 61: 0, # 'd'
+ 42: 0, # 'e'
+ 53: 0, # 'i'
+ 56: 0, # 'l'
+ 54: 0, # 'n'
+ 49: 0, # 'o'
+ 51: 0, # 'r'
+ 43: 0, # 's'
+ 44: 0, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 0, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 0, # 'ה'
+ 2: 0, # 'ו'
+ 24: 0, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 0, # 'י'
+ 25: 0, # 'ך'
+ 15: 0, # 'כ'
+ 4: 0, # 'ל'
+ 11: 0, # 'ם'
+ 6: 0, # 'מ'
+ 23: 0, # 'ן'
+ 12: 0, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 0, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 0, # 'ר'
+ 10: 0, # 'ש'
+ 5: 0, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 0, # '”'
+ 58: 2, # '†'
+ 40: 0, # '…'
+ },
+ 40: { # '…'
+ 50: 1, # 'a'
+ 60: 1, # 'c'
+ 61: 1, # 'd'
+ 42: 1, # 'e'
+ 53: 1, # 'i'
+ 56: 0, # 'l'
+ 54: 1, # 'n'
+ 49: 0, # 'o'
+ 51: 1, # 'r'
+ 43: 1, # 's'
+ 44: 1, # 't'
+ 63: 0, # 'u'
+ 34: 0, # '\xa0'
+ 55: 0, # '´'
+ 48: 0, # '¼'
+ 39: 0, # '½'
+ 57: 0, # '¾'
+ 30: 0, # 'ְ'
+ 59: 0, # 'ֱ'
+ 41: 0, # 'ֲ'
+ 33: 0, # 'ִ'
+ 37: 0, # 'ֵ'
+ 36: 0, # 'ֶ'
+ 31: 0, # 'ַ'
+ 29: 0, # 'ָ'
+ 35: 0, # 'ֹ'
+ 62: 0, # 'ֻ'
+ 28: 0, # 'ּ'
+ 38: 0, # 'ׁ'
+ 45: 0, # 'ׂ'
+ 9: 1, # 'א'
+ 8: 0, # 'ב'
+ 20: 0, # 'ג'
+ 16: 0, # 'ד'
+ 3: 1, # 'ה'
+ 2: 1, # 'ו'
+ 24: 1, # 'ז'
+ 14: 0, # 'ח'
+ 22: 0, # 'ט'
+ 1: 1, # 'י'
+ 25: 0, # 'ך'
+ 15: 1, # 'כ'
+ 4: 1, # 'ל'
+ 11: 0, # 'ם'
+ 6: 1, # 'מ'
+ 23: 0, # 'ן'
+ 12: 1, # 'נ'
+ 19: 0, # 'ס'
+ 13: 0, # 'ע'
+ 26: 0, # 'ף'
+ 18: 1, # 'פ'
+ 27: 0, # 'ץ'
+ 21: 0, # 'צ'
+ 17: 0, # 'ק'
+ 7: 1, # 'ר'
+ 10: 1, # 'ש'
+ 5: 1, # 'ת'
+ 32: 0, # '–'
+ 52: 0, # '’'
+ 47: 0, # '“'
+ 46: 1, # '”'
+ 58: 0, # '†'
+ 40: 2, # '…'
+ },
+}
+
+# 255: Undefined characters that did not exist in training text
+# 254: Carriage/Return
+# 253: symbol (punctuation) that does not belong to word
+# 252: 0 - 9
+# 251: Control characters
+
+# Character Mapping Table(s):
+WINDOWS_1255_HEBREW_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 69, # 'A'
+ 66: 91, # 'B'
+ 67: 79, # 'C'
+ 68: 80, # 'D'
+ 69: 92, # 'E'
+ 70: 89, # 'F'
+ 71: 97, # 'G'
+ 72: 90, # 'H'
+ 73: 68, # 'I'
+ 74: 111, # 'J'
+ 75: 112, # 'K'
+ 76: 82, # 'L'
+ 77: 73, # 'M'
+ 78: 95, # 'N'
+ 79: 85, # 'O'
+ 80: 78, # 'P'
+ 81: 121, # 'Q'
+ 82: 86, # 'R'
+ 83: 71, # 'S'
+ 84: 67, # 'T'
+ 85: 102, # 'U'
+ 86: 107, # 'V'
+ 87: 84, # 'W'
+ 88: 114, # 'X'
+ 89: 103, # 'Y'
+ 90: 115, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 50, # 'a'
+ 98: 74, # 'b'
+ 99: 60, # 'c'
+ 100: 61, # 'd'
+ 101: 42, # 'e'
+ 102: 76, # 'f'
+ 103: 70, # 'g'
+ 104: 64, # 'h'
+ 105: 53, # 'i'
+ 106: 105, # 'j'
+ 107: 93, # 'k'
+ 108: 56, # 'l'
+ 109: 65, # 'm'
+ 110: 54, # 'n'
+ 111: 49, # 'o'
+ 112: 66, # 'p'
+ 113: 110, # 'q'
+ 114: 51, # 'r'
+ 115: 43, # 's'
+ 116: 44, # 't'
+ 117: 63, # 'u'
+ 118: 81, # 'v'
+ 119: 77, # 'w'
+ 120: 98, # 'x'
+ 121: 75, # 'y'
+ 122: 108, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 124, # '€'
+ 129: 202, # None
+ 130: 203, # '‚'
+ 131: 204, # 'ƒ'
+ 132: 205, # '„'
+ 133: 40, # '…'
+ 134: 58, # '†'
+ 135: 206, # '‡'
+ 136: 207, # 'ˆ'
+ 137: 208, # '‰'
+ 138: 209, # None
+ 139: 210, # '‹'
+ 140: 211, # None
+ 141: 212, # None
+ 142: 213, # None
+ 143: 214, # None
+ 144: 215, # None
+ 145: 83, # '‘'
+ 146: 52, # '’'
+ 147: 47, # '“'
+ 148: 46, # '”'
+ 149: 72, # '•'
+ 150: 32, # '–'
+ 151: 94, # '—'
+ 152: 216, # '˜'
+ 153: 113, # '™'
+ 154: 217, # None
+ 155: 109, # '›'
+ 156: 218, # None
+ 157: 219, # None
+ 158: 220, # None
+ 159: 221, # None
+ 160: 34, # '\xa0'
+ 161: 116, # '¡'
+ 162: 222, # '¢'
+ 163: 118, # '£'
+ 164: 100, # '₪'
+ 165: 223, # '¥'
+ 166: 224, # '¦'
+ 167: 117, # '§'
+ 168: 119, # '¨'
+ 169: 104, # '©'
+ 170: 125, # '×'
+ 171: 225, # '«'
+ 172: 226, # '¬'
+ 173: 87, # '\xad'
+ 174: 99, # '®'
+ 175: 227, # '¯'
+ 176: 106, # '°'
+ 177: 122, # '±'
+ 178: 123, # '²'
+ 179: 228, # '³'
+ 180: 55, # '´'
+ 181: 229, # 'µ'
+ 182: 230, # '¶'
+ 183: 101, # '·'
+ 184: 231, # '¸'
+ 185: 232, # '¹'
+ 186: 120, # '÷'
+ 187: 233, # '»'
+ 188: 48, # '¼'
+ 189: 39, # '½'
+ 190: 57, # '¾'
+ 191: 234, # '¿'
+ 192: 30, # 'ְ'
+ 193: 59, # 'ֱ'
+ 194: 41, # 'ֲ'
+ 195: 88, # 'ֳ'
+ 196: 33, # 'ִ'
+ 197: 37, # 'ֵ'
+ 198: 36, # 'ֶ'
+ 199: 31, # 'ַ'
+ 200: 29, # 'ָ'
+ 201: 35, # 'ֹ'
+ 202: 235, # None
+ 203: 62, # 'ֻ'
+ 204: 28, # 'ּ'
+ 205: 236, # 'ֽ'
+ 206: 126, # '־'
+ 207: 237, # 'ֿ'
+ 208: 238, # '׀'
+ 209: 38, # 'ׁ'
+ 210: 45, # 'ׂ'
+ 211: 239, # '׃'
+ 212: 240, # 'װ'
+ 213: 241, # 'ױ'
+ 214: 242, # 'ײ'
+ 215: 243, # '׳'
+ 216: 127, # '״'
+ 217: 244, # None
+ 218: 245, # None
+ 219: 246, # None
+ 220: 247, # None
+ 221: 248, # None
+ 222: 249, # None
+ 223: 250, # None
+ 224: 9, # 'א'
+ 225: 8, # 'ב'
+ 226: 20, # 'ג'
+ 227: 16, # 'ד'
+ 228: 3, # 'ה'
+ 229: 2, # 'ו'
+ 230: 24, # 'ז'
+ 231: 14, # 'ח'
+ 232: 22, # 'ט'
+ 233: 1, # 'י'
+ 234: 25, # 'ך'
+ 235: 15, # 'כ'
+ 236: 4, # 'ל'
+ 237: 11, # 'ם'
+ 238: 6, # 'מ'
+ 239: 23, # 'ן'
+ 240: 12, # 'נ'
+ 241: 19, # 'ס'
+ 242: 13, # 'ע'
+ 243: 26, # 'ף'
+ 244: 18, # 'פ'
+ 245: 27, # 'ץ'
+ 246: 21, # 'צ'
+ 247: 17, # 'ק'
+ 248: 7, # 'ר'
+ 249: 10, # 'ש'
+ 250: 5, # 'ת'
+ 251: 251, # None
+ 252: 252, # None
+ 253: 128, # '\u200e'
+ 254: 96, # '\u200f'
+ 255: 253, # None
+}
+
+WINDOWS_1255_HEBREW_MODEL = SingleByteCharSetModel(
+ charset_name="windows-1255",
+ language="Hebrew",
+ char_to_order_map=WINDOWS_1255_HEBREW_CHAR_TO_ORDER,
+ language_model=HEBREW_LANG_MODEL,
+ typical_positive_ratio=0.984004,
+ keep_ascii_letters=False,
+ alphabet="אבגדהוזחטיךכלםמןנסעףפץצקרשתװױײ",
+)
diff --git a/lib/chardet/langhungarianmodel.py b/lib/chardet/langhungarianmodel.py
new file mode 100644
index 0000000..bd6630a
--- /dev/null
+++ b/lib/chardet/langhungarianmodel.py
@@ -0,0 +1,4649 @@
+from chardet.sbcharsetprober import SingleByteCharSetModel
+
+# 3: Positive
+# 2: Likely
+# 1: Unlikely
+# 0: Negative
+
+HUNGARIAN_LANG_MODEL = {
+ 28: { # 'A'
+ 28: 0, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 2, # 'D'
+ 32: 1, # 'E'
+ 50: 1, # 'F'
+ 49: 2, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 2, # 'K'
+ 41: 2, # 'L'
+ 34: 1, # 'M'
+ 35: 2, # 'N'
+ 47: 1, # 'O'
+ 46: 2, # 'P'
+ 43: 2, # 'R'
+ 33: 2, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 2, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 2, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 1, # 'i'
+ 22: 1, # 'j'
+ 7: 2, # 'k'
+ 6: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 2, # 'n'
+ 8: 0, # 'o'
+ 23: 2, # 'p'
+ 10: 2, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 1, # 'u'
+ 19: 1, # 'v'
+ 62: 1, # 'x'
+ 16: 0, # 'y'
+ 11: 3, # 'z'
+ 51: 1, # 'Á'
+ 44: 0, # 'É'
+ 61: 1, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 40: { # 'B'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 0, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 3, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 2, # 'i'
+ 22: 1, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 3, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 54: { # 'C'
+ 28: 1, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 1, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 0, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 2, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 0, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 1, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 1, # 'h'
+ 9: 1, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 3, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 1, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 45: { # 'D'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 0, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 0, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 3, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 1, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 1, # 'o'
+ 23: 0, # 'p'
+ 10: 2, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 2, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 1, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 32: { # 'E'
+ 28: 1, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 1, # 'E'
+ 50: 1, # 'F'
+ 49: 2, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 2, # 'K'
+ 41: 2, # 'L'
+ 34: 2, # 'M'
+ 35: 2, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 2, # 'R'
+ 33: 2, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 1, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 2, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 3, # 'g'
+ 20: 1, # 'h'
+ 9: 1, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 2, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 2, # 's'
+ 3: 1, # 't'
+ 21: 2, # 'u'
+ 19: 1, # 'v'
+ 62: 1, # 'x'
+ 16: 0, # 'y'
+ 11: 3, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 0, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 1, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 50: { # 'F'
+ 28: 1, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 1, # 'E'
+ 50: 1, # 'F'
+ 49: 0, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 0, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 0, # 'V'
+ 55: 1, # 'Y'
+ 52: 0, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 1, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 2, # 'i'
+ 22: 1, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 2, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 0, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 2, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 49: { # 'G'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 2, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 1, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 2, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 2, # 'y'
+ 11: 0, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 0, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 38: { # 'H'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 0, # 'D'
+ 32: 1, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 1, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 1, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 1, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 0, # 'V'
+ 55: 1, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 2, # 'i'
+ 22: 1, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 0, # 'n'
+ 8: 3, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 2, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 2, # 'Á'
+ 44: 2, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 1, # 'é'
+ 30: 2, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 39: { # 'I'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 1, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 2, # 'K'
+ 41: 2, # 'L'
+ 34: 1, # 'M'
+ 35: 2, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 2, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 2, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 2, # 'd'
+ 1: 0, # 'e'
+ 27: 1, # 'f'
+ 12: 2, # 'g'
+ 20: 1, # 'h'
+ 9: 0, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 1, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 2, # 's'
+ 3: 2, # 't'
+ 21: 0, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 0, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 53: { # 'J'
+ 28: 2, # 'A'
+ 40: 0, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 1, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 1, # 'o'
+ 23: 0, # 'p'
+ 10: 0, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 2, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 0, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 1, # 'é'
+ 30: 0, # 'í'
+ 25: 2, # 'ó'
+ 24: 2, # 'ö'
+ 31: 1, # 'ú'
+ 29: 0, # 'ü'
+ 42: 1, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 36: { # 'K'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 0, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 0, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 1, # 'f'
+ 12: 0, # 'g'
+ 20: 1, # 'h'
+ 9: 3, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 2, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 2, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 2, # 'ö'
+ 31: 1, # 'ú'
+ 29: 2, # 'ü'
+ 42: 1, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 41: { # 'L'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 2, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 3, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 2, # 'i'
+ 22: 1, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 0, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 2, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 2, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 0, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 34: { # 'M'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 0, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 3, # 'a'
+ 18: 0, # 'b'
+ 26: 1, # 'c'
+ 17: 0, # 'd'
+ 1: 3, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 3, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 3, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 2, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 2, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 35: { # 'N'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 2, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 2, # 'Y'
+ 52: 1, # 'Z'
+ 2: 3, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 3, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 2, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 0, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 2, # 'y'
+ 11: 0, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 1, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 1, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 47: { # 'O'
+ 28: 1, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 1, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 2, # 'K'
+ 41: 2, # 'L'
+ 34: 2, # 'M'
+ 35: 2, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 2, # 'R'
+ 33: 2, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 1, # 'i'
+ 22: 1, # 'j'
+ 7: 2, # 'k'
+ 6: 2, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 1, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 1, # 's'
+ 3: 2, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 1, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 0, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 46: { # 'P'
+ 28: 1, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 1, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 0, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 2, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 1, # 'f'
+ 12: 0, # 'g'
+ 20: 1, # 'h'
+ 9: 2, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 2, # 'r'
+ 5: 1, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 2, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 3, # 'á'
+ 15: 2, # 'é'
+ 30: 0, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 0, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 43: { # 'R'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 2, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 1, # 'h'
+ 9: 2, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 0, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 2, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 2, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 2, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 33: { # 'S'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 2, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 3, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 1, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 1, # 'h'
+ 9: 2, # 'i'
+ 22: 0, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 1, # 'p'
+ 10: 0, # 'r'
+ 5: 0, # 's'
+ 3: 1, # 't'
+ 21: 1, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 3, # 'z'
+ 51: 2, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 37: { # 'T'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 1, # 'P'
+ 43: 2, # 'R'
+ 33: 1, # 'S'
+ 37: 2, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 2, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 1, # 'h'
+ 9: 2, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 0, # 't'
+ 21: 2, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 1, # 'z'
+ 51: 2, # 'Á'
+ 44: 2, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 2, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 57: { # 'U'
+ 28: 1, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 1, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 2, # 'S'
+ 37: 1, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 1, # 'e'
+ 27: 0, # 'f'
+ 12: 2, # 'g'
+ 20: 0, # 'h'
+ 9: 0, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 1, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 48: { # 'V'
+ 28: 2, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 0, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 2, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 2, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 2, # 'o'
+ 23: 0, # 'p'
+ 10: 0, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 2, # 'Á'
+ 44: 2, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 2, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 0, # 'ó'
+ 24: 1, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 55: { # 'Y'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 1, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 2, # 'Z'
+ 2: 1, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 1, # 'd'
+ 1: 1, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 0, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 8: 1, # 'o'
+ 23: 1, # 'p'
+ 10: 0, # 'r'
+ 5: 0, # 's'
+ 3: 0, # 't'
+ 21: 0, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 1, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 52: { # 'Z'
+ 28: 2, # 'A'
+ 40: 1, # 'B'
+ 54: 0, # 'C'
+ 45: 1, # 'D'
+ 32: 2, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 2, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 2, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 2, # 'S'
+ 37: 1, # 'T'
+ 57: 1, # 'U'
+ 48: 1, # 'V'
+ 55: 1, # 'Y'
+ 52: 1, # 'Z'
+ 2: 1, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 1, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 1, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 8: 1, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 2, # 's'
+ 3: 0, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 2, # 'Á'
+ 44: 1, # 'É'
+ 61: 1, # 'Í'
+ 58: 1, # 'Ó'
+ 59: 1, # 'Ö'
+ 60: 1, # 'Ú'
+ 63: 1, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 2: { # 'a'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 3, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 2, # 'e'
+ 27: 2, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 3, # 'i'
+ 22: 3, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 2, # 'o'
+ 23: 3, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 1, # 'x'
+ 16: 2, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 18: { # 'b'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 3, # 'i'
+ 22: 2, # 'j'
+ 7: 2, # 'k'
+ 6: 2, # 'l'
+ 13: 1, # 'm'
+ 4: 2, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 3, # 'r'
+ 5: 2, # 's'
+ 3: 1, # 't'
+ 21: 3, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 3, # 'ó'
+ 24: 2, # 'ö'
+ 31: 2, # 'ú'
+ 29: 2, # 'ü'
+ 42: 2, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 26: { # 'c'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 1, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 1, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 2, # 'a'
+ 18: 1, # 'b'
+ 26: 2, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 3, # 'h'
+ 9: 3, # 'i'
+ 22: 1, # 'j'
+ 7: 2, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 3, # 's'
+ 3: 2, # 't'
+ 21: 2, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 2, # 'á'
+ 15: 2, # 'é'
+ 30: 2, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 17: { # 'd'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 2, # 'b'
+ 26: 1, # 'c'
+ 17: 2, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 3, # 'j'
+ 7: 2, # 'k'
+ 6: 1, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 2, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 3, # 'í'
+ 25: 3, # 'ó'
+ 24: 3, # 'ö'
+ 31: 2, # 'ú'
+ 29: 2, # 'ü'
+ 42: 2, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 1: { # 'e'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 2, # 'a'
+ 18: 3, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 2, # 'e'
+ 27: 3, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 3, # 'i'
+ 22: 3, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 2, # 'o'
+ 23: 3, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 2, # 'u'
+ 19: 3, # 'v'
+ 62: 2, # 'x'
+ 16: 2, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 27: { # 'f'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 2, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 3, # 'i'
+ 22: 2, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 3, # 'o'
+ 23: 0, # 'p'
+ 10: 3, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 2, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 0, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 3, # 'ö'
+ 31: 1, # 'ú'
+ 29: 2, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 12: { # 'g'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 2, # 'c'
+ 17: 2, # 'd'
+ 1: 3, # 'e'
+ 27: 2, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 3, # 'i'
+ 22: 3, # 'j'
+ 7: 2, # 'k'
+ 6: 3, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 3, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 3, # 'ó'
+ 24: 2, # 'ö'
+ 31: 2, # 'ú'
+ 29: 2, # 'ü'
+ 42: 2, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 20: { # 'h'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 0, # 'd'
+ 1: 3, # 'e'
+ 27: 0, # 'f'
+ 12: 1, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 3, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 2, # 's'
+ 3: 1, # 't'
+ 21: 3, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 2, # 'y'
+ 11: 0, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 3, # 'í'
+ 25: 2, # 'ó'
+ 24: 2, # 'ö'
+ 31: 2, # 'ú'
+ 29: 1, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 9: { # 'i'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 3, # 'e'
+ 27: 3, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 2, # 'i'
+ 22: 2, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 2, # 'o'
+ 23: 2, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 1, # 'x'
+ 16: 1, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 3, # 'ó'
+ 24: 1, # 'ö'
+ 31: 2, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 22: { # 'j'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 2, # 'b'
+ 26: 1, # 'c'
+ 17: 3, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 2, # 'h'
+ 9: 1, # 'i'
+ 22: 2, # 'j'
+ 7: 2, # 'k'
+ 6: 2, # 'l'
+ 13: 1, # 'm'
+ 4: 2, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 2, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 1, # 'í'
+ 25: 3, # 'ó'
+ 24: 3, # 'ö'
+ 31: 3, # 'ú'
+ 29: 2, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 7: { # 'k'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 2, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 2, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 1, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 2, # 'v'
+ 62: 0, # 'x'
+ 16: 2, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 3, # 'í'
+ 25: 2, # 'ó'
+ 24: 3, # 'ö'
+ 31: 1, # 'ú'
+ 29: 3, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 6: { # 'l'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 1, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 1, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 2, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 3, # 'e'
+ 27: 3, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 3, # 'i'
+ 22: 3, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 2, # 'p'
+ 10: 2, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 3, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 3, # 'í'
+ 25: 3, # 'ó'
+ 24: 3, # 'ö'
+ 31: 2, # 'ú'
+ 29: 2, # 'ü'
+ 42: 3, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 13: { # 'm'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 2, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 2, # 'j'
+ 7: 1, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 8: 3, # 'o'
+ 23: 3, # 'p'
+ 10: 2, # 'r'
+ 5: 2, # 's'
+ 3: 2, # 't'
+ 21: 3, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 2, # 'ó'
+ 24: 2, # 'ö'
+ 31: 2, # 'ú'
+ 29: 2, # 'ü'
+ 42: 1, # 'ő'
+ 56: 2, # 'ű'
+ },
+ 4: { # 'n'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 3, # 'e'
+ 27: 2, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 3, # 'i'
+ 22: 2, # 'j'
+ 7: 3, # 'k'
+ 6: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 2, # 'p'
+ 10: 2, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 2, # 'v'
+ 62: 1, # 'x'
+ 16: 3, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 2, # 'ó'
+ 24: 3, # 'ö'
+ 31: 2, # 'ú'
+ 29: 3, # 'ü'
+ 42: 2, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 8: { # 'o'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 1, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 2, # 'a'
+ 18: 3, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 2, # 'e'
+ 27: 2, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 2, # 'i'
+ 22: 2, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 1, # 'o'
+ 23: 3, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 2, # 'u'
+ 19: 3, # 'v'
+ 62: 1, # 'x'
+ 16: 1, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 23: { # 'p'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 1, # 'b'
+ 26: 2, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 2, # 'j'
+ 7: 2, # 'k'
+ 6: 3, # 'l'
+ 13: 1, # 'm'
+ 4: 2, # 'n'
+ 8: 3, # 'o'
+ 23: 3, # 'p'
+ 10: 3, # 'r'
+ 5: 2, # 's'
+ 3: 2, # 't'
+ 21: 3, # 'u'
+ 19: 2, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 2, # 'ó'
+ 24: 2, # 'ö'
+ 31: 1, # 'ú'
+ 29: 2, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 10: { # 'r'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 3, # 'e'
+ 27: 2, # 'f'
+ 12: 3, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 3, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 2, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 1, # 'x'
+ 16: 2, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 3, # 'ó'
+ 24: 3, # 'ö'
+ 31: 3, # 'ú'
+ 29: 3, # 'ü'
+ 42: 2, # 'ő'
+ 56: 2, # 'ű'
+ },
+ 5: { # 's'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 2, # 'c'
+ 17: 2, # 'd'
+ 1: 3, # 'e'
+ 27: 2, # 'f'
+ 12: 2, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 1, # 'j'
+ 7: 3, # 'k'
+ 6: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 2, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 2, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 3, # 'í'
+ 25: 3, # 'ó'
+ 24: 3, # 'ö'
+ 31: 3, # 'ú'
+ 29: 3, # 'ü'
+ 42: 2, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 3: { # 't'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 3, # 'b'
+ 26: 2, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 2, # 'f'
+ 12: 1, # 'g'
+ 20: 3, # 'h'
+ 9: 3, # 'i'
+ 22: 3, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 3, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 3, # 'ó'
+ 24: 3, # 'ö'
+ 31: 3, # 'ú'
+ 29: 3, # 'ü'
+ 42: 3, # 'ő'
+ 56: 2, # 'ű'
+ },
+ 21: { # 'u'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 2, # 'b'
+ 26: 2, # 'c'
+ 17: 3, # 'd'
+ 1: 2, # 'e'
+ 27: 1, # 'f'
+ 12: 3, # 'g'
+ 20: 2, # 'h'
+ 9: 2, # 'i'
+ 22: 2, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 1, # 'o'
+ 23: 2, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 1, # 'u'
+ 19: 3, # 'v'
+ 62: 1, # 'x'
+ 16: 1, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 2, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 0, # 'ö'
+ 31: 1, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 19: { # 'v'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 2, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 3, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 1, # 'r'
+ 5: 2, # 's'
+ 3: 2, # 't'
+ 21: 2, # 'u'
+ 19: 2, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 2, # 'ó'
+ 24: 2, # 'ö'
+ 31: 1, # 'ú'
+ 29: 2, # 'ü'
+ 42: 1, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 62: { # 'x'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 0, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 1, # 'i'
+ 22: 0, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 1, # 'o'
+ 23: 1, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 1, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 1, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 16: { # 'y'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 2, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 3, # 'e'
+ 27: 2, # 'f'
+ 12: 2, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 2, # 'j'
+ 7: 2, # 'k'
+ 6: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 2, # 'p'
+ 10: 2, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 2, # 'í'
+ 25: 2, # 'ó'
+ 24: 3, # 'ö'
+ 31: 2, # 'ú'
+ 29: 2, # 'ü'
+ 42: 1, # 'ő'
+ 56: 2, # 'ű'
+ },
+ 11: { # 'z'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 3, # 'a'
+ 18: 2, # 'b'
+ 26: 1, # 'c'
+ 17: 3, # 'd'
+ 1: 3, # 'e'
+ 27: 1, # 'f'
+ 12: 2, # 'g'
+ 20: 2, # 'h'
+ 9: 3, # 'i'
+ 22: 1, # 'j'
+ 7: 3, # 'k'
+ 6: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 3, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 3, # 'u'
+ 19: 2, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 3, # 'á'
+ 15: 3, # 'é'
+ 30: 3, # 'í'
+ 25: 3, # 'ó'
+ 24: 3, # 'ö'
+ 31: 2, # 'ú'
+ 29: 3, # 'ü'
+ 42: 2, # 'ő'
+ 56: 1, # 'ű'
+ },
+ 51: { # 'Á'
+ 28: 0, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 0, # 'E'
+ 50: 1, # 'F'
+ 49: 2, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 2, # 'L'
+ 34: 1, # 'M'
+ 35: 2, # 'N'
+ 47: 0, # 'O'
+ 46: 1, # 'P'
+ 43: 2, # 'R'
+ 33: 2, # 'S'
+ 37: 1, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 0, # 'e'
+ 27: 0, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 0, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 1, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 44: { # 'É'
+ 28: 0, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 1, # 'E'
+ 50: 0, # 'F'
+ 49: 2, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 2, # 'L'
+ 34: 1, # 'M'
+ 35: 2, # 'N'
+ 47: 0, # 'O'
+ 46: 1, # 'P'
+ 43: 2, # 'R'
+ 33: 2, # 'S'
+ 37: 2, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 0, # 'e'
+ 27: 0, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 0, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 2, # 'l'
+ 13: 1, # 'm'
+ 4: 2, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 3, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 0, # 'Á'
+ 44: 1, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 61: { # 'Í'
+ 28: 0, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 0, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 1, # 'J'
+ 36: 0, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 0, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 0, # 'e'
+ 27: 0, # 'f'
+ 12: 2, # 'g'
+ 20: 0, # 'h'
+ 9: 0, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 0, # 'n'
+ 8: 0, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 0, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 58: { # 'Ó'
+ 28: 1, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 0, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 1, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 2, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 0, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 0, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 2, # 'h'
+ 9: 0, # 'i'
+ 22: 0, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 0, # 't'
+ 21: 0, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 1, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 59: { # 'Ö'
+ 28: 0, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 0, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 0, # 'O'
+ 46: 1, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 0, # 'b'
+ 26: 1, # 'c'
+ 17: 1, # 'd'
+ 1: 0, # 'e'
+ 27: 0, # 'f'
+ 12: 0, # 'g'
+ 20: 0, # 'h'
+ 9: 0, # 'i'
+ 22: 0, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 8: 0, # 'o'
+ 23: 0, # 'p'
+ 10: 2, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 60: { # 'Ú'
+ 28: 0, # 'A'
+ 40: 1, # 'B'
+ 54: 1, # 'C'
+ 45: 1, # 'D'
+ 32: 0, # 'E'
+ 50: 1, # 'F'
+ 49: 1, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 0, # 'b'
+ 26: 0, # 'c'
+ 17: 0, # 'd'
+ 1: 0, # 'e'
+ 27: 0, # 'f'
+ 12: 2, # 'g'
+ 20: 0, # 'h'
+ 9: 0, # 'i'
+ 22: 2, # 'j'
+ 7: 0, # 'k'
+ 6: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 8: 0, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 0, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 0, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 63: { # 'Ü'
+ 28: 0, # 'A'
+ 40: 1, # 'B'
+ 54: 0, # 'C'
+ 45: 1, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 1, # 'G'
+ 38: 1, # 'H'
+ 39: 0, # 'I'
+ 53: 1, # 'J'
+ 36: 1, # 'K'
+ 41: 1, # 'L'
+ 34: 1, # 'M'
+ 35: 1, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 1, # 'R'
+ 33: 1, # 'S'
+ 37: 1, # 'T'
+ 57: 0, # 'U'
+ 48: 1, # 'V'
+ 55: 0, # 'Y'
+ 52: 1, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 0, # 'c'
+ 17: 1, # 'd'
+ 1: 0, # 'e'
+ 27: 0, # 'f'
+ 12: 1, # 'g'
+ 20: 0, # 'h'
+ 9: 0, # 'i'
+ 22: 0, # 'j'
+ 7: 0, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 8: 0, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 1, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 14: { # 'á'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 3, # 'b'
+ 26: 3, # 'c'
+ 17: 3, # 'd'
+ 1: 1, # 'e'
+ 27: 2, # 'f'
+ 12: 3, # 'g'
+ 20: 2, # 'h'
+ 9: 2, # 'i'
+ 22: 3, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 1, # 'o'
+ 23: 2, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 2, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 1, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 2, # 'é'
+ 30: 1, # 'í'
+ 25: 0, # 'ó'
+ 24: 1, # 'ö'
+ 31: 0, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 15: { # 'é'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 3, # 'b'
+ 26: 2, # 'c'
+ 17: 3, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 3, # 'g'
+ 20: 3, # 'h'
+ 9: 2, # 'i'
+ 22: 2, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 1, # 'o'
+ 23: 3, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 0, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 30: { # 'í'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 0, # 'a'
+ 18: 1, # 'b'
+ 26: 2, # 'c'
+ 17: 1, # 'd'
+ 1: 0, # 'e'
+ 27: 1, # 'f'
+ 12: 3, # 'g'
+ 20: 0, # 'h'
+ 9: 0, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 3, # 'r'
+ 5: 2, # 's'
+ 3: 3, # 't'
+ 21: 0, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 25: { # 'ó'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 2, # 'a'
+ 18: 3, # 'b'
+ 26: 2, # 'c'
+ 17: 3, # 'd'
+ 1: 1, # 'e'
+ 27: 2, # 'f'
+ 12: 2, # 'g'
+ 20: 2, # 'h'
+ 9: 2, # 'i'
+ 22: 2, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 8: 1, # 'o'
+ 23: 2, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 1, # 'u'
+ 19: 2, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 0, # 'ó'
+ 24: 1, # 'ö'
+ 31: 1, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 24: { # 'ö'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 0, # 'a'
+ 18: 3, # 'b'
+ 26: 1, # 'c'
+ 17: 2, # 'd'
+ 1: 0, # 'e'
+ 27: 1, # 'f'
+ 12: 2, # 'g'
+ 20: 1, # 'h'
+ 9: 0, # 'i'
+ 22: 1, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 8: 0, # 'o'
+ 23: 2, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 3, # 't'
+ 21: 0, # 'u'
+ 19: 3, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 3, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 31: { # 'ú'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 1, # 'b'
+ 26: 2, # 'c'
+ 17: 1, # 'd'
+ 1: 1, # 'e'
+ 27: 2, # 'f'
+ 12: 3, # 'g'
+ 20: 1, # 'h'
+ 9: 1, # 'i'
+ 22: 3, # 'j'
+ 7: 1, # 'k'
+ 6: 3, # 'l'
+ 13: 1, # 'm'
+ 4: 2, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 3, # 'r'
+ 5: 3, # 's'
+ 3: 2, # 't'
+ 21: 1, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 1, # 'á'
+ 15: 1, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 29: { # 'ü'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 1, # 'b'
+ 26: 1, # 'c'
+ 17: 2, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 3, # 'g'
+ 20: 2, # 'h'
+ 9: 1, # 'i'
+ 22: 1, # 'j'
+ 7: 3, # 'k'
+ 6: 3, # 'l'
+ 13: 1, # 'm'
+ 4: 3, # 'n'
+ 8: 0, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 2, # 's'
+ 3: 2, # 't'
+ 21: 0, # 'u'
+ 19: 2, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 1, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 42: { # 'ő'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 2, # 'b'
+ 26: 1, # 'c'
+ 17: 2, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 1, # 'i'
+ 22: 1, # 'j'
+ 7: 2, # 'k'
+ 6: 3, # 'l'
+ 13: 1, # 'm'
+ 4: 2, # 'n'
+ 8: 1, # 'o'
+ 23: 1, # 'p'
+ 10: 2, # 'r'
+ 5: 2, # 's'
+ 3: 2, # 't'
+ 21: 1, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 1, # 'é'
+ 30: 1, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 1, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+ 56: { # 'ű'
+ 28: 0, # 'A'
+ 40: 0, # 'B'
+ 54: 0, # 'C'
+ 45: 0, # 'D'
+ 32: 0, # 'E'
+ 50: 0, # 'F'
+ 49: 0, # 'G'
+ 38: 0, # 'H'
+ 39: 0, # 'I'
+ 53: 0, # 'J'
+ 36: 0, # 'K'
+ 41: 0, # 'L'
+ 34: 0, # 'M'
+ 35: 0, # 'N'
+ 47: 0, # 'O'
+ 46: 0, # 'P'
+ 43: 0, # 'R'
+ 33: 0, # 'S'
+ 37: 0, # 'T'
+ 57: 0, # 'U'
+ 48: 0, # 'V'
+ 55: 0, # 'Y'
+ 52: 0, # 'Z'
+ 2: 1, # 'a'
+ 18: 1, # 'b'
+ 26: 0, # 'c'
+ 17: 1, # 'd'
+ 1: 1, # 'e'
+ 27: 1, # 'f'
+ 12: 1, # 'g'
+ 20: 1, # 'h'
+ 9: 1, # 'i'
+ 22: 1, # 'j'
+ 7: 1, # 'k'
+ 6: 1, # 'l'
+ 13: 0, # 'm'
+ 4: 2, # 'n'
+ 8: 0, # 'o'
+ 23: 0, # 'p'
+ 10: 1, # 'r'
+ 5: 1, # 's'
+ 3: 1, # 't'
+ 21: 0, # 'u'
+ 19: 1, # 'v'
+ 62: 0, # 'x'
+ 16: 0, # 'y'
+ 11: 2, # 'z'
+ 51: 0, # 'Á'
+ 44: 0, # 'É'
+ 61: 0, # 'Í'
+ 58: 0, # 'Ó'
+ 59: 0, # 'Ö'
+ 60: 0, # 'Ú'
+ 63: 0, # 'Ü'
+ 14: 0, # 'á'
+ 15: 0, # 'é'
+ 30: 0, # 'í'
+ 25: 0, # 'ó'
+ 24: 0, # 'ö'
+ 31: 0, # 'ú'
+ 29: 0, # 'ü'
+ 42: 0, # 'ő'
+ 56: 0, # 'ű'
+ },
+}
+
+# 255: Undefined characters that did not exist in training text
+# 254: Carriage Return / Line Feed ('\r' and '\n')
+# 253: symbol (punctuation) that does not belong to word
+# 252: 0 - 9
+# 251: Control characters
+
+# Character Mapping Table(s):
+WINDOWS_1250_HUNGARIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 28, # 'A'
+ 66: 40, # 'B'
+ 67: 54, # 'C'
+ 68: 45, # 'D'
+ 69: 32, # 'E'
+ 70: 50, # 'F'
+ 71: 49, # 'G'
+ 72: 38, # 'H'
+ 73: 39, # 'I'
+ 74: 53, # 'J'
+ 75: 36, # 'K'
+ 76: 41, # 'L'
+ 77: 34, # 'M'
+ 78: 35, # 'N'
+ 79: 47, # 'O'
+ 80: 46, # 'P'
+ 81: 72, # 'Q'
+ 82: 43, # 'R'
+ 83: 33, # 'S'
+ 84: 37, # 'T'
+ 85: 57, # 'U'
+ 86: 48, # 'V'
+ 87: 64, # 'W'
+ 88: 68, # 'X'
+ 89: 55, # 'Y'
+ 90: 52, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 2, # 'a'
+ 98: 18, # 'b'
+ 99: 26, # 'c'
+ 100: 17, # 'd'
+ 101: 1, # 'e'
+ 102: 27, # 'f'
+ 103: 12, # 'g'
+ 104: 20, # 'h'
+ 105: 9, # 'i'
+ 106: 22, # 'j'
+ 107: 7, # 'k'
+ 108: 6, # 'l'
+ 109: 13, # 'm'
+ 110: 4, # 'n'
+ 111: 8, # 'o'
+ 112: 23, # 'p'
+ 113: 67, # 'q'
+ 114: 10, # 'r'
+ 115: 5, # 's'
+ 116: 3, # 't'
+ 117: 21, # 'u'
+ 118: 19, # 'v'
+ 119: 65, # 'w'
+ 120: 62, # 'x'
+ 121: 16, # 'y'
+ 122: 11, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 161, # '€'
+ 129: 162, # None
+ 130: 163, # '‚'
+ 131: 164, # None
+ 132: 165, # '„'
+ 133: 166, # '…'
+ 134: 167, # '†'
+ 135: 168, # '‡'
+ 136: 169, # None
+ 137: 170, # '‰'
+ 138: 171, # 'Š'
+ 139: 172, # '‹'
+ 140: 173, # 'Ś'
+ 141: 174, # 'Ť'
+ 142: 175, # 'Ž'
+ 143: 176, # 'Ź'
+ 144: 177, # None
+ 145: 178, # '‘'
+ 146: 179, # '’'
+ 147: 180, # '“'
+ 148: 78, # '”'
+ 149: 181, # '•'
+ 150: 69, # '–'
+ 151: 182, # '—'
+ 152: 183, # None
+ 153: 184, # '™'
+ 154: 185, # 'š'
+ 155: 186, # '›'
+ 156: 187, # 'ś'
+ 157: 188, # 'ť'
+ 158: 189, # 'ž'
+ 159: 190, # 'ź'
+ 160: 191, # '\xa0'
+ 161: 192, # 'ˇ'
+ 162: 193, # '˘'
+ 163: 194, # 'Ł'
+ 164: 195, # '¤'
+ 165: 196, # 'Ą'
+ 166: 197, # '¦'
+ 167: 76, # '§'
+ 168: 198, # '¨'
+ 169: 199, # '©'
+ 170: 200, # 'Ş'
+ 171: 201, # '«'
+ 172: 202, # '¬'
+ 173: 203, # '\xad'
+ 174: 204, # '®'
+ 175: 205, # 'Ż'
+ 176: 81, # '°'
+ 177: 206, # '±'
+ 178: 207, # '˛'
+ 179: 208, # 'ł'
+ 180: 209, # '´'
+ 181: 210, # 'µ'
+ 182: 211, # '¶'
+ 183: 212, # '·'
+ 184: 213, # '¸'
+ 185: 214, # 'ą'
+ 186: 215, # 'ş'
+ 187: 216, # '»'
+ 188: 217, # 'Ľ'
+ 189: 218, # '˝'
+ 190: 219, # 'ľ'
+ 191: 220, # 'ż'
+ 192: 221, # 'Ŕ'
+ 193: 51, # 'Á'
+ 194: 83, # 'Â'
+ 195: 222, # 'Ă'
+ 196: 80, # 'Ä'
+ 197: 223, # 'Ĺ'
+ 198: 224, # 'Ć'
+ 199: 225, # 'Ç'
+ 200: 226, # 'Č'
+ 201: 44, # 'É'
+ 202: 227, # 'Ę'
+ 203: 228, # 'Ë'
+ 204: 229, # 'Ě'
+ 205: 61, # 'Í'
+ 206: 230, # 'Î'
+ 207: 231, # 'Ď'
+ 208: 232, # 'Đ'
+ 209: 233, # 'Ń'
+ 210: 234, # 'Ň'
+ 211: 58, # 'Ó'
+ 212: 235, # 'Ô'
+ 213: 66, # 'Ő'
+ 214: 59, # 'Ö'
+ 215: 236, # '×'
+ 216: 237, # 'Ř'
+ 217: 238, # 'Ů'
+ 218: 60, # 'Ú'
+ 219: 70, # 'Ű'
+ 220: 63, # 'Ü'
+ 221: 239, # 'Ý'
+ 222: 240, # 'Ţ'
+ 223: 241, # 'ß'
+ 224: 84, # 'ŕ'
+ 225: 14, # 'á'
+ 226: 75, # 'â'
+ 227: 242, # 'ă'
+ 228: 71, # 'ä'
+ 229: 82, # 'ĺ'
+ 230: 243, # 'ć'
+ 231: 73, # 'ç'
+ 232: 244, # 'č'
+ 233: 15, # 'é'
+ 234: 85, # 'ę'
+ 235: 79, # 'ë'
+ 236: 86, # 'ě'
+ 237: 30, # 'í'
+ 238: 77, # 'î'
+ 239: 87, # 'ď'
+ 240: 245, # 'đ'
+ 241: 246, # 'ń'
+ 242: 247, # 'ň'
+ 243: 25, # 'ó'
+ 244: 74, # 'ô'
+ 245: 42, # 'ő'
+ 246: 24, # 'ö'
+ 247: 248, # '÷'
+ 248: 249, # 'ř'
+ 249: 250, # 'ů'
+ 250: 31, # 'ú'
+ 251: 56, # 'ű'
+ 252: 29, # 'ü'
+ 253: 251, # 'ý'
+ 254: 252, # 'ţ'
+ 255: 253, # '˙'
+}
+
+WINDOWS_1250_HUNGARIAN_MODEL = SingleByteCharSetModel(
+ charset_name="windows-1250",
+ language="Hungarian",
+ char_to_order_map=WINDOWS_1250_HUNGARIAN_CHAR_TO_ORDER,
+ language_model=HUNGARIAN_LANG_MODEL,
+ typical_positive_ratio=0.947368,
+ keep_ascii_letters=True,
+ alphabet="ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzÁÉÍÓÖÚÜáéíóöúüŐőŰű",
+)
+
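(Editor's note, illustrative only; the sketch below is not part of the patch.) The two tables above cooperate inside chardet's single-byte prober: each input byte is first translated to an "order" via WINDOWS_1250_HUNGARIAN_CHAR_TO_ORDER, and consecutive pairs of low orders (frequent letters) are then looked up in HUNGARIAN_LANG_MODEL to rate how Hungarian-like the byte stream is. A minimal Python sketch of that lookup, using only the data defined in this file:

    def score_pair(prev_byte: int, cur_byte: int) -> int:
        # Map raw windows-1250 byte values to frequency orders.
        prev_order = WINDOWS_1250_HUNGARIAN_CHAR_TO_ORDER[prev_byte]
        cur_order = WINDOWS_1250_HUNGARIAN_CHAR_TO_ORDER[cur_byte]
        # Orders of 64 and above (rare letters, digits, punctuation,
        # control codes) fall outside the modelled alphabet; treat them
        # as neutral in this simplified sketch.
        if prev_order >= 64 or cur_order >= 64:
            return 0
        # 0-3 likelihood of cur following prev (3 = Positive, 0 = Negative).
        return HUNGARIAN_LANG_MODEL.get(prev_order, {}).get(cur_order, 0)

    # Example: 's' (order 5) followed by 'z' (order 11) is the common
    # Hungarian digraph "sz"; the table above rates that transition 3.
    assert score_pair(ord("s"), ord("z")) == 3

The real scorer (chardet.sbcharsetprober.SingleByteCharSetProber) accumulates these ratings over the whole input and turns them into a confidence value; the sketch only shows the table lookup itself.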
+ISO_8859_2_HUNGARIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 28, # 'A'
+ 66: 40, # 'B'
+ 67: 54, # 'C'
+ 68: 45, # 'D'
+ 69: 32, # 'E'
+ 70: 50, # 'F'
+ 71: 49, # 'G'
+ 72: 38, # 'H'
+ 73: 39, # 'I'
+ 74: 53, # 'J'
+ 75: 36, # 'K'
+ 76: 41, # 'L'
+ 77: 34, # 'M'
+ 78: 35, # 'N'
+ 79: 47, # 'O'
+ 80: 46, # 'P'
+ 81: 71, # 'Q'
+ 82: 43, # 'R'
+ 83: 33, # 'S'
+ 84: 37, # 'T'
+ 85: 57, # 'U'
+ 86: 48, # 'V'
+ 87: 64, # 'W'
+ 88: 68, # 'X'
+ 89: 55, # 'Y'
+ 90: 52, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 2, # 'a'
+ 98: 18, # 'b'
+ 99: 26, # 'c'
+ 100: 17, # 'd'
+ 101: 1, # 'e'
+ 102: 27, # 'f'
+ 103: 12, # 'g'
+ 104: 20, # 'h'
+ 105: 9, # 'i'
+ 106: 22, # 'j'
+ 107: 7, # 'k'
+ 108: 6, # 'l'
+ 109: 13, # 'm'
+ 110: 4, # 'n'
+ 111: 8, # 'o'
+ 112: 23, # 'p'
+ 113: 67, # 'q'
+ 114: 10, # 'r'
+ 115: 5, # 's'
+ 116: 3, # 't'
+ 117: 21, # 'u'
+ 118: 19, # 'v'
+ 119: 65, # 'w'
+ 120: 62, # 'x'
+ 121: 16, # 'y'
+ 122: 11, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 159, # '\x80'
+ 129: 160, # '\x81'
+ 130: 161, # '\x82'
+ 131: 162, # '\x83'
+ 132: 163, # '\x84'
+ 133: 164, # '\x85'
+ 134: 165, # '\x86'
+ 135: 166, # '\x87'
+ 136: 167, # '\x88'
+ 137: 168, # '\x89'
+ 138: 169, # '\x8a'
+ 139: 170, # '\x8b'
+ 140: 171, # '\x8c'
+ 141: 172, # '\x8d'
+ 142: 173, # '\x8e'
+ 143: 174, # '\x8f'
+ 144: 175, # '\x90'
+ 145: 176, # '\x91'
+ 146: 177, # '\x92'
+ 147: 178, # '\x93'
+ 148: 179, # '\x94'
+ 149: 180, # '\x95'
+ 150: 181, # '\x96'
+ 151: 182, # '\x97'
+ 152: 183, # '\x98'
+ 153: 184, # '\x99'
+ 154: 185, # '\x9a'
+ 155: 186, # '\x9b'
+ 156: 187, # '\x9c'
+ 157: 188, # '\x9d'
+ 158: 189, # '\x9e'
+ 159: 190, # '\x9f'
+ 160: 191, # '\xa0'
+ 161: 192, # 'Ą'
+ 162: 193, # '˘'
+ 163: 194, # 'Ł'
+ 164: 195, # '¤'
+ 165: 196, # 'Ľ'
+ 166: 197, # 'Ś'
+ 167: 75, # '§'
+ 168: 198, # '¨'
+ 169: 199, # 'Š'
+ 170: 200, # 'Ş'
+ 171: 201, # 'Ť'
+ 172: 202, # 'Ź'
+ 173: 203, # '\xad'
+ 174: 204, # 'Ž'
+ 175: 205, # 'Ż'
+ 176: 79, # '°'
+ 177: 206, # 'ą'
+ 178: 207, # '˛'
+ 179: 208, # 'ł'
+ 180: 209, # '´'
+ 181: 210, # 'ľ'
+ 182: 211, # 'ś'
+ 183: 212, # 'ˇ'
+ 184: 213, # '¸'
+ 185: 214, # 'š'
+ 186: 215, # 'ş'
+ 187: 216, # 'ť'
+ 188: 217, # 'ź'
+ 189: 218, # '˝'
+ 190: 219, # 'ž'
+ 191: 220, # 'ż'
+ 192: 221, # 'Ŕ'
+ 193: 51, # 'Á'
+ 194: 81, # 'Â'
+ 195: 222, # 'Ă'
+ 196: 78, # 'Ä'
+ 197: 223, # 'Ĺ'
+ 198: 224, # 'Ć'
+ 199: 225, # 'Ç'
+ 200: 226, # 'Č'
+ 201: 44, # 'É'
+ 202: 227, # 'Ę'
+ 203: 228, # 'Ë'
+ 204: 229, # 'Ě'
+ 205: 61, # 'Í'
+ 206: 230, # 'Î'
+ 207: 231, # 'Ď'
+ 208: 232, # 'Đ'
+ 209: 233, # 'Ń'
+ 210: 234, # 'Ň'
+ 211: 58, # 'Ó'
+ 212: 235, # 'Ô'
+ 213: 66, # 'Ő'
+ 214: 59, # 'Ö'
+ 215: 236, # '×'
+ 216: 237, # 'Ř'
+ 217: 238, # 'Ů'
+ 218: 60, # 'Ú'
+ 219: 69, # 'Ű'
+ 220: 63, # 'Ü'
+ 221: 239, # 'Ý'
+ 222: 240, # 'Ţ'
+ 223: 241, # 'ß'
+ 224: 82, # 'ŕ'
+ 225: 14, # 'á'
+ 226: 74, # 'â'
+ 227: 242, # 'ă'
+ 228: 70, # 'ä'
+ 229: 80, # 'ĺ'
+ 230: 243, # 'ć'
+ 231: 72, # 'ç'
+ 232: 244, # 'č'
+ 233: 15, # 'é'
+ 234: 83, # 'ę'
+ 235: 77, # 'ë'
+ 236: 84, # 'ě'
+ 237: 30, # 'í'
+ 238: 76, # 'î'
+ 239: 85, # 'ď'
+ 240: 245, # 'đ'
+ 241: 246, # 'ń'
+ 242: 247, # 'ň'
+ 243: 25, # 'ó'
+ 244: 73, # 'ô'
+ 245: 42, # 'ő'
+ 246: 24, # 'ö'
+ 247: 248, # '÷'
+ 248: 249, # 'ř'
+ 249: 250, # 'ů'
+ 250: 31, # 'ú'
+ 251: 56, # 'ű'
+ 252: 29, # 'ü'
+ 253: 251, # 'ý'
+ 254: 252, # 'ţ'
+ 255: 253, # '˙'
+}
+
+ISO_8859_2_HUNGARIAN_MODEL = SingleByteCharSetModel(
+ charset_name="ISO-8859-2",
+ language="Hungarian",
+ char_to_order_map=ISO_8859_2_HUNGARIAN_CHAR_TO_ORDER,
+ language_model=HUNGARIAN_LANG_MODEL,
+ typical_positive_ratio=0.947368,
+ keep_ascii_letters=True,
+ alphabet="ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzÁÉÍÓÖÚÜáéíóöúüŐőŰű",
+)
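For orientation, a minimal usage sketch (an annotation, not part of the vendored file): a bundled SingleByteCharSetModel like the one above is consumed indirectly through chardet's documented detect() entry point. The Hungarian sample string is only an assumed illustration; short inputs may be classified with low confidence, so the printed result is indicative rather than guaranteed.

    # Sketch only, not part of the patch: exercising the detector these models feed.
    import chardet

    # Hungarian pangram; all of its accented letters are representable in ISO-8859-2.
    sample = "Árvíztűrő tükörfúrógép".encode("iso-8859-2")
    print(chardet.detect(sample))
    # For sufficiently long Hungarian input this typically reports something like
    # {'encoding': 'ISO-8859-2', 'confidence': <float>, 'language': 'Hungarian'}.

The char-to-order map above feeds this process: for example, byte 97 ('a') is assigned order 2, and the model records typical_positive_ratio=0.947368 for Hungarian.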
diff --git a/lib/chardet/langrussianmodel.py b/lib/chardet/langrussianmodel.py
new file mode 100644
index 0000000..0d5b178
--- /dev/null
+++ b/lib/chardet/langrussianmodel.py
@@ -0,0 +1,5725 @@
+from chardet.sbcharsetprober import SingleByteCharSetModel
+
+# 3: Positive
+# 2: Likely
+# 1: Unlikely
+# 0: Negative
+
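A brief note on how an entry in the table below reads (annotation, not part of the vendored file): the outer and inner keys are the order indices assigned by each charset's char-to-order map, and the value is the 0-3 likelihood class listed in the comments above. A minimal sketch using values that appear in the table that follows, where order 37 is 'А', 21 is 'б', and 44 is 'Б':

    # Sketch only: RUSSIAN_LANG_MODEL[first][second] gives the likelihood class
    # of `second` following `first`, both expressed as char-to-order indices.
    assert RUSSIAN_LANG_MODEL[37][21] == 2  # 'б' after 'А' -> "Likely"
    assert RUSSIAN_LANG_MODEL[37][44] == 1  # 'Б' after 'А' -> "Unlikely"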
+RUSSIAN_LANG_MODEL = {
+ 37: { # 'А'
+ 37: 0, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 1, # 'З'
+ 42: 1, # 'И'
+ 60: 1, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 2, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 1, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 1, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 1, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 0, # 'е'
+ 24: 1, # 'ж'
+ 20: 1, # 'з'
+ 4: 0, # 'и'
+ 23: 1, # 'й'
+ 11: 2, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 2, # 'н'
+ 1: 0, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 2, # 'у'
+ 39: 2, # 'ф'
+ 26: 2, # 'х'
+ 28: 0, # 'ц'
+ 22: 1, # 'ч'
+ 25: 2, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 44: { # 'Б'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 1, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 2, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 2, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 33: { # 'В'
+ 37: 2, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 2, # 'а'
+ 21: 1, # 'б'
+ 10: 1, # 'в'
+ 19: 1, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 2, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 2, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 3, # 'с'
+ 6: 2, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 1, # 'ц'
+ 22: 2, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 1, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 0, # 'ю'
+ 16: 1, # 'я'
+ },
+ 46: { # 'Г'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 1, # 'в'
+ 19: 0, # 'г'
+ 13: 2, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 1, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 1, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 41: { # 'Д'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 2, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 3, # 'а'
+ 21: 0, # 'б'
+ 10: 2, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 3, # 'ж'
+ 20: 1, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 1, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 48: { # 'Е'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 1, # 'З'
+ 42: 1, # 'И'
+ 60: 1, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 2, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 2, # 'Р'
+ 32: 2, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 1, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 0, # 'а'
+ 21: 0, # 'б'
+ 10: 2, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 2, # 'е'
+ 24: 1, # 'ж'
+ 20: 1, # 'з'
+ 4: 0, # 'и'
+ 23: 2, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 1, # 'н'
+ 1: 0, # 'о'
+ 15: 1, # 'п'
+ 9: 1, # 'р'
+ 7: 3, # 'с'
+ 6: 0, # 'т'
+ 14: 0, # 'у'
+ 39: 1, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 2, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 56: { # 'Ж'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 1, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 1, # 'б'
+ 10: 0, # 'в'
+ 19: 1, # 'г'
+ 13: 1, # 'д'
+ 2: 2, # 'е'
+ 24: 1, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 1, # 'м'
+ 5: 0, # 'н'
+ 1: 2, # 'о'
+ 15: 0, # 'п'
+ 9: 1, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 2, # 'ю'
+ 16: 0, # 'я'
+ },
+ 51: { # 'З'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 2, # 'в'
+ 19: 0, # 'г'
+ 13: 2, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 1, # 'л'
+ 12: 1, # 'м'
+ 5: 2, # 'н'
+ 1: 2, # 'о'
+ 15: 0, # 'п'
+ 9: 1, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 1, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 1, # 'я'
+ },
+ 42: { # 'И'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 2, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 1, # 'З'
+ 42: 1, # 'И'
+ 60: 1, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 2, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 1, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 1, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 1, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 2, # 'з'
+ 4: 1, # 'и'
+ 23: 0, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 2, # 'н'
+ 1: 1, # 'о'
+ 15: 1, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 1, # 'у'
+ 39: 1, # 'ф'
+ 26: 2, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 60: { # 'Й'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 1, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 1, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 0, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 2, # 'о'
+ 15: 0, # 'п'
+ 9: 0, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 0, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 36: { # 'К'
+ 37: 2, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 1, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 1, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 2, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 0, # 'б'
+ 10: 1, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 2, # 'л'
+ 12: 0, # 'м'
+ 5: 1, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 49: { # 'Л'
+ 37: 2, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 1, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 0, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 0, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 1, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 1, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 1, # 'л'
+ 12: 0, # 'м'
+ 5: 1, # 'н'
+ 1: 2, # 'о'
+ 15: 0, # 'п'
+ 9: 0, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 2, # 'ю'
+ 16: 1, # 'я'
+ },
+ 38: { # 'М'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 1, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 1, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 3, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 1, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 1, # 'л'
+ 12: 1, # 'м'
+ 5: 2, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 1, # 'р'
+ 7: 1, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 31: { # 'Н'
+ 37: 2, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 1, # 'З'
+ 42: 2, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 1, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 1, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 3, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 1, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 3, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 2, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 34: { # 'О'
+ 37: 0, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 2, # 'Д'
+ 48: 1, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 1, # 'З'
+ 42: 1, # 'И'
+ 60: 1, # 'Й'
+ 36: 1, # 'К'
+ 49: 2, # 'Л'
+ 38: 1, # 'М'
+ 31: 2, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 2, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 1, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 1, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 1, # 'а'
+ 21: 2, # 'б'
+ 10: 1, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 0, # 'е'
+ 24: 1, # 'ж'
+ 20: 1, # 'з'
+ 4: 0, # 'и'
+ 23: 1, # 'й'
+ 11: 2, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 3, # 'н'
+ 1: 0, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 1, # 'у'
+ 39: 1, # 'ф'
+ 26: 2, # 'х'
+ 28: 1, # 'ц'
+ 22: 2, # 'ч'
+ 25: 2, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 35: { # 'П'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 1, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 2, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 2, # 'л'
+ 12: 0, # 'м'
+ 5: 1, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 3, # 'р'
+ 7: 1, # 'с'
+ 6: 1, # 'т'
+ 14: 2, # 'у'
+ 39: 1, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 2, # 'ь'
+ 30: 1, # 'э'
+ 27: 0, # 'ю'
+ 16: 2, # 'я'
+ },
+ 45: { # 'Р'
+ 37: 2, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 2, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 0, # 'З'
+ 42: 2, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 2, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 1, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 3, # 'а'
+ 21: 0, # 'б'
+ 10: 1, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 1, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 1, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 2, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 2, # 'я'
+ },
+ 32: { # 'С'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 2, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 1, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 2, # 'а'
+ 21: 1, # 'б'
+ 10: 2, # 'в'
+ 19: 1, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 1, # 'ж'
+ 20: 1, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 2, # 'н'
+ 1: 2, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 1, # 'с'
+ 6: 3, # 'т'
+ 14: 2, # 'у'
+ 39: 1, # 'ф'
+ 26: 1, # 'х'
+ 28: 1, # 'ц'
+ 22: 1, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 1, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 40: { # 'Т'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 2, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 1, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 2, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 1, # 'к'
+ 8: 1, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 1, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 52: { # 'У'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 1, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 1, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 1, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 1, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 1, # 'г'
+ 13: 2, # 'д'
+ 2: 1, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 2, # 'и'
+ 23: 1, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 1, # 'н'
+ 1: 2, # 'о'
+ 15: 1, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 0, # 'у'
+ 39: 1, # 'ф'
+ 26: 1, # 'х'
+ 28: 1, # 'ц'
+ 22: 2, # 'ч'
+ 25: 1, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 53: { # 'Ф'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 1, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 2, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 2, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 0, # 'с'
+ 6: 1, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 55: { # 'Х'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 2, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 0, # 'н'
+ 1: 2, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 1, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 1, # 'ь'
+ 30: 1, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 58: { # 'Ц'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 1, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 1, # 'а'
+ 21: 0, # 'б'
+ 10: 1, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 0, # 'о'
+ 15: 0, # 'п'
+ 9: 0, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 1, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 50: { # 'Ч'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 1, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 1, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 1, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 1, # 'о'
+ 15: 0, # 'п'
+ 9: 1, # 'р'
+ 7: 0, # 'с'
+ 6: 3, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 1, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 57: { # 'Ш'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 1, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 0, # 'б'
+ 10: 1, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 1, # 'и'
+ 23: 0, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 1, # 'н'
+ 1: 2, # 'о'
+ 15: 2, # 'п'
+ 9: 1, # 'р'
+ 7: 0, # 'с'
+ 6: 2, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 63: { # 'Щ'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 1, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 1, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 1, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 1, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 1, # 'о'
+ 15: 0, # 'п'
+ 9: 0, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 1, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 62: { # 'Ы'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 1, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 1, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 0, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 0, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 0, # 'о'
+ 15: 0, # 'п'
+ 9: 0, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 0, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 61: { # 'Ь'
+ 37: 0, # 'А'
+ 44: 1, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 1, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 0, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 1, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 1, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 1, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 1, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 0, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 0, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 0, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 0, # 'о'
+ 15: 0, # 'п'
+ 9: 0, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 0, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 47: { # 'Э'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 0, # 'Г'
+ 41: 1, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 1, # 'Й'
+ 36: 1, # 'К'
+ 49: 1, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 1, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 1, # 'а'
+ 21: 1, # 'б'
+ 10: 2, # 'в'
+ 19: 1, # 'г'
+ 13: 2, # 'д'
+ 2: 0, # 'е'
+ 24: 1, # 'ж'
+ 20: 0, # 'з'
+ 4: 0, # 'и'
+ 23: 2, # 'й'
+ 11: 2, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 2, # 'н'
+ 1: 0, # 'о'
+ 15: 1, # 'п'
+ 9: 2, # 'р'
+ 7: 1, # 'с'
+ 6: 3, # 'т'
+ 14: 1, # 'у'
+ 39: 1, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 59: { # 'Ю'
+ 37: 1, # 'А'
+ 44: 1, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 1, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 0, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 1, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 1, # 'б'
+ 10: 0, # 'в'
+ 19: 1, # 'г'
+ 13: 1, # 'д'
+ 2: 0, # 'е'
+ 24: 1, # 'ж'
+ 20: 0, # 'з'
+ 4: 0, # 'и'
+ 23: 0, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 2, # 'н'
+ 1: 0, # 'о'
+ 15: 1, # 'п'
+ 9: 1, # 'р'
+ 7: 1, # 'с'
+ 6: 0, # 'т'
+ 14: 0, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 43: { # 'Я'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 1, # 'В'
+ 46: 1, # 'Г'
+ 41: 0, # 'Д'
+ 48: 1, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 1, # 'С'
+ 40: 1, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 1, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 1, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 1, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 1, # 'Ю'
+ 43: 1, # 'Я'
+ 3: 0, # 'а'
+ 21: 1, # 'б'
+ 10: 1, # 'в'
+ 19: 1, # 'г'
+ 13: 1, # 'д'
+ 2: 0, # 'е'
+ 24: 0, # 'ж'
+ 20: 1, # 'з'
+ 4: 0, # 'и'
+ 23: 1, # 'й'
+ 11: 1, # 'к'
+ 8: 1, # 'л'
+ 12: 1, # 'м'
+ 5: 2, # 'н'
+ 1: 0, # 'о'
+ 15: 1, # 'п'
+ 9: 1, # 'р'
+ 7: 1, # 'с'
+ 6: 0, # 'т'
+ 14: 0, # 'у'
+ 39: 0, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 3: { # 'а'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 1, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 3, # 'б'
+ 10: 3, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 3, # 'з'
+ 4: 3, # 'и'
+ 23: 3, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 2, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 3, # 'х'
+ 28: 3, # 'ц'
+ 22: 3, # 'ч'
+ 25: 3, # 'ш'
+ 29: 3, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 2, # 'э'
+ 27: 3, # 'ю'
+ 16: 3, # 'я'
+ },
+ 21: { # 'б'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 1, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 1, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 1, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 1, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 0, # 'ф'
+ 26: 2, # 'х'
+ 28: 1, # 'ц'
+ 22: 1, # 'ч'
+ 25: 2, # 'ш'
+ 29: 3, # 'щ'
+ 54: 2, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 2, # 'ь'
+ 30: 1, # 'э'
+ 27: 2, # 'ю'
+ 16: 3, # 'я'
+ },
+ 10: { # 'в'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 2, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 1, # 'ж'
+ 20: 3, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 1, # 'ф'
+ 26: 2, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 3, # 'ш'
+ 29: 2, # 'щ'
+ 54: 2, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 3, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 3, # 'я'
+ },
+ 19: { # 'г'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 2, # 'в'
+ 19: 1, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 1, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 3, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 1, # 'ф'
+ 26: 1, # 'х'
+ 28: 1, # 'ц'
+ 22: 2, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 1, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 13: { # 'д'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 3, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 1, # 'ф'
+ 26: 2, # 'х'
+ 28: 3, # 'ц'
+ 22: 2, # 'ч'
+ 25: 2, # 'ш'
+ 29: 1, # 'щ'
+ 54: 2, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 3, # 'ь'
+ 30: 1, # 'э'
+ 27: 2, # 'ю'
+ 16: 3, # 'я'
+ },
+ 2: { # 'е'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 3, # 'б'
+ 10: 3, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 3, # 'з'
+ 4: 2, # 'и'
+ 23: 3, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 2, # 'у'
+ 39: 2, # 'ф'
+ 26: 3, # 'х'
+ 28: 3, # 'ц'
+ 22: 3, # 'ч'
+ 25: 3, # 'ш'
+ 29: 3, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 2, # 'ю'
+ 16: 3, # 'я'
+ },
+ 24: { # 'ж'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 1, # 'в'
+ 19: 2, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 1, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 3, # 'н'
+ 1: 2, # 'о'
+ 15: 1, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 1, # 'т'
+ 14: 3, # 'у'
+ 39: 1, # 'ф'
+ 26: 0, # 'х'
+ 28: 1, # 'ц'
+ 22: 2, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 2, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 20: { # 'з'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 3, # 'б'
+ 10: 3, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 3, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 1, # 'ц'
+ 22: 2, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 2, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 2, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 3, # 'я'
+ },
+ 4: { # 'и'
+ 37: 1, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 3, # 'б'
+ 10: 3, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 3, # 'з'
+ 4: 3, # 'и'
+ 23: 3, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 2, # 'у'
+ 39: 2, # 'ф'
+ 26: 3, # 'х'
+ 28: 3, # 'ц'
+ 22: 3, # 'ч'
+ 25: 3, # 'ш'
+ 29: 3, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 2, # 'э'
+ 27: 3, # 'ю'
+ 16: 3, # 'я'
+ },
+ 23: { # 'й'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 1, # 'а'
+ 21: 1, # 'б'
+ 10: 1, # 'в'
+ 19: 2, # 'г'
+ 13: 3, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 2, # 'з'
+ 4: 1, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 2, # 'о'
+ 15: 1, # 'п'
+ 9: 2, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 1, # 'у'
+ 39: 2, # 'ф'
+ 26: 1, # 'х'
+ 28: 2, # 'ц'
+ 22: 3, # 'ч'
+ 25: 2, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 2, # 'я'
+ },
+ 11: { # 'к'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 3, # 'в'
+ 19: 1, # 'г'
+ 13: 1, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 3, # 'л'
+ 12: 1, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 1, # 'ф'
+ 26: 2, # 'х'
+ 28: 2, # 'ц'
+ 22: 1, # 'ч'
+ 25: 2, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 1, # 'ы'
+ 17: 1, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 8: { # 'л'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 3, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 2, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 1, # 'р'
+ 7: 3, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 2, # 'х'
+ 28: 1, # 'ц'
+ 22: 3, # 'ч'
+ 25: 2, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 3, # 'ь'
+ 30: 1, # 'э'
+ 27: 3, # 'ю'
+ 16: 3, # 'я'
+ },
+ 12: { # 'м'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 2, # 'г'
+ 13: 1, # 'д'
+ 2: 3, # 'е'
+ 24: 1, # 'ж'
+ 20: 1, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 3, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 2, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 1, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 2, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 3, # 'я'
+ },
+ 5: { # 'н'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 1, # 'п'
+ 9: 2, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 2, # 'х'
+ 28: 3, # 'ц'
+ 22: 3, # 'ч'
+ 25: 2, # 'ш'
+ 29: 2, # 'щ'
+ 54: 1, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 3, # 'ь'
+ 30: 1, # 'э'
+ 27: 3, # 'ю'
+ 16: 3, # 'я'
+ },
+ 1: { # 'о'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 3, # 'б'
+ 10: 3, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 3, # 'з'
+ 4: 3, # 'и'
+ 23: 3, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 2, # 'у'
+ 39: 2, # 'ф'
+ 26: 3, # 'х'
+ 28: 2, # 'ц'
+ 22: 3, # 'ч'
+ 25: 3, # 'ш'
+ 29: 3, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 2, # 'э'
+ 27: 3, # 'ю'
+ 16: 3, # 'я'
+ },
+ 15: { # 'п'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 3, # 'л'
+ 12: 1, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 3, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 1, # 'ф'
+ 26: 0, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 1, # 'ш'
+ 29: 1, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 2, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 3, # 'я'
+ },
+ 9: { # 'р'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 3, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 2, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 2, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 3, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 3, # 'ш'
+ 29: 2, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 3, # 'ь'
+ 30: 2, # 'э'
+ 27: 2, # 'ю'
+ 16: 3, # 'я'
+ },
+ 7: { # 'с'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 1, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 3, # 'в'
+ 19: 2, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 3, # 'х'
+ 28: 2, # 'ц'
+ 22: 3, # 'ч'
+ 25: 2, # 'ш'
+ 29: 1, # 'щ'
+ 54: 2, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 3, # 'ь'
+ 30: 2, # 'э'
+ 27: 3, # 'ю'
+ 16: 3, # 'я'
+ },
+ 6: { # 'т'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 2, # 'б'
+ 10: 3, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 1, # 'ж'
+ 20: 1, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 2, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 2, # 'ш'
+ 29: 2, # 'щ'
+ 54: 2, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 3, # 'ь'
+ 30: 2, # 'э'
+ 27: 2, # 'ю'
+ 16: 3, # 'я'
+ },
+ 14: { # 'у'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 3, # 'б'
+ 10: 3, # 'в'
+ 19: 3, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 3, # 'з'
+ 4: 2, # 'и'
+ 23: 2, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 2, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 1, # 'у'
+ 39: 2, # 'ф'
+ 26: 3, # 'х'
+ 28: 2, # 'ц'
+ 22: 3, # 'ч'
+ 25: 3, # 'ш'
+ 29: 3, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 2, # 'э'
+ 27: 3, # 'ю'
+ 16: 2, # 'я'
+ },
+ 39: { # 'ф'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 0, # 'в'
+ 19: 1, # 'г'
+ 13: 0, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 1, # 'н'
+ 1: 3, # 'о'
+ 15: 1, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 2, # 'у'
+ 39: 2, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 1, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 2, # 'ы'
+ 17: 1, # 'ь'
+ 30: 2, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 26: { # 'х'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 0, # 'б'
+ 10: 3, # 'в'
+ 19: 1, # 'г'
+ 13: 1, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 1, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 1, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 1, # 'п'
+ 9: 3, # 'р'
+ 7: 2, # 'с'
+ 6: 2, # 'т'
+ 14: 2, # 'у'
+ 39: 1, # 'ф'
+ 26: 1, # 'х'
+ 28: 1, # 'ц'
+ 22: 1, # 'ч'
+ 25: 2, # 'ш'
+ 29: 0, # 'щ'
+ 54: 1, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 1, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 28: { # 'ц'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 2, # 'в'
+ 19: 1, # 'г'
+ 13: 1, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 1, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 2, # 'к'
+ 8: 1, # 'л'
+ 12: 1, # 'м'
+ 5: 1, # 'н'
+ 1: 3, # 'о'
+ 15: 0, # 'п'
+ 9: 1, # 'р'
+ 7: 0, # 'с'
+ 6: 1, # 'т'
+ 14: 3, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 1, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 3, # 'ы'
+ 17: 1, # 'ь'
+ 30: 0, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 22: { # 'ч'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 1, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 3, # 'е'
+ 24: 1, # 'ж'
+ 20: 0, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 2, # 'л'
+ 12: 1, # 'м'
+ 5: 3, # 'н'
+ 1: 2, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 1, # 'с'
+ 6: 3, # 'т'
+ 14: 3, # 'у'
+ 39: 1, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 1, # 'ч'
+ 25: 2, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 3, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 25: { # 'ш'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 1, # 'б'
+ 10: 2, # 'в'
+ 19: 1, # 'г'
+ 13: 0, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 2, # 'м'
+ 5: 3, # 'н'
+ 1: 3, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 1, # 'с'
+ 6: 2, # 'т'
+ 14: 3, # 'у'
+ 39: 2, # 'ф'
+ 26: 1, # 'х'
+ 28: 1, # 'ц'
+ 22: 1, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 3, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 0, # 'я'
+ },
+ 29: { # 'щ'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 3, # 'а'
+ 21: 0, # 'б'
+ 10: 1, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 3, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 3, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 1, # 'м'
+ 5: 2, # 'н'
+ 1: 1, # 'о'
+ 15: 0, # 'п'
+ 9: 2, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 2, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 2, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 0, # 'я'
+ },
+ 54: { # 'ъ'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 0, # 'б'
+ 10: 0, # 'в'
+ 19: 0, # 'г'
+ 13: 0, # 'д'
+ 2: 2, # 'е'
+ 24: 0, # 'ж'
+ 20: 0, # 'з'
+ 4: 0, # 'и'
+ 23: 0, # 'й'
+ 11: 0, # 'к'
+ 8: 0, # 'л'
+ 12: 0, # 'м'
+ 5: 0, # 'н'
+ 1: 0, # 'о'
+ 15: 0, # 'п'
+ 9: 0, # 'р'
+ 7: 0, # 'с'
+ 6: 0, # 'т'
+ 14: 0, # 'у'
+ 39: 0, # 'ф'
+ 26: 0, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 0, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 1, # 'ю'
+ 16: 2, # 'я'
+ },
+ 18: { # 'ы'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 3, # 'б'
+ 10: 3, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 2, # 'и'
+ 23: 3, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 1, # 'о'
+ 15: 3, # 'п'
+ 9: 3, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 1, # 'у'
+ 39: 0, # 'ф'
+ 26: 3, # 'х'
+ 28: 2, # 'ц'
+ 22: 3, # 'ч'
+ 25: 3, # 'ш'
+ 29: 2, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 0, # 'ю'
+ 16: 2, # 'я'
+ },
+ 17: { # 'ь'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 2, # 'б'
+ 10: 2, # 'в'
+ 19: 2, # 'г'
+ 13: 2, # 'д'
+ 2: 3, # 'е'
+ 24: 1, # 'ж'
+ 20: 3, # 'з'
+ 4: 2, # 'и'
+ 23: 0, # 'й'
+ 11: 3, # 'к'
+ 8: 0, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 2, # 'о'
+ 15: 2, # 'п'
+ 9: 1, # 'р'
+ 7: 3, # 'с'
+ 6: 2, # 'т'
+ 14: 0, # 'у'
+ 39: 2, # 'ф'
+ 26: 1, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 3, # 'ш'
+ 29: 2, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 3, # 'ю'
+ 16: 3, # 'я'
+ },
+ 30: { # 'э'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 1, # 'М'
+ 31: 1, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 1, # 'Р'
+ 32: 1, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 1, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 1, # 'б'
+ 10: 1, # 'в'
+ 19: 1, # 'г'
+ 13: 2, # 'д'
+ 2: 1, # 'е'
+ 24: 0, # 'ж'
+ 20: 1, # 'з'
+ 4: 0, # 'и'
+ 23: 2, # 'й'
+ 11: 2, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 2, # 'н'
+ 1: 0, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 2, # 'с'
+ 6: 3, # 'т'
+ 14: 1, # 'у'
+ 39: 2, # 'ф'
+ 26: 1, # 'х'
+ 28: 0, # 'ц'
+ 22: 0, # 'ч'
+ 25: 1, # 'ш'
+ 29: 0, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 1, # 'ю'
+ 16: 1, # 'я'
+ },
+ 27: { # 'ю'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 2, # 'а'
+ 21: 3, # 'б'
+ 10: 1, # 'в'
+ 19: 2, # 'г'
+ 13: 3, # 'д'
+ 2: 1, # 'е'
+ 24: 2, # 'ж'
+ 20: 2, # 'з'
+ 4: 1, # 'и'
+ 23: 1, # 'й'
+ 11: 2, # 'к'
+ 8: 2, # 'л'
+ 12: 2, # 'м'
+ 5: 2, # 'н'
+ 1: 1, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 0, # 'у'
+ 39: 1, # 'ф'
+ 26: 2, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 2, # 'ш'
+ 29: 3, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 1, # 'э'
+ 27: 2, # 'ю'
+ 16: 1, # 'я'
+ },
+ 16: { # 'я'
+ 37: 0, # 'А'
+ 44: 0, # 'Б'
+ 33: 0, # 'В'
+ 46: 0, # 'Г'
+ 41: 0, # 'Д'
+ 48: 0, # 'Е'
+ 56: 0, # 'Ж'
+ 51: 0, # 'З'
+ 42: 0, # 'И'
+ 60: 0, # 'Й'
+ 36: 0, # 'К'
+ 49: 0, # 'Л'
+ 38: 0, # 'М'
+ 31: 0, # 'Н'
+ 34: 0, # 'О'
+ 35: 0, # 'П'
+ 45: 0, # 'Р'
+ 32: 0, # 'С'
+ 40: 0, # 'Т'
+ 52: 0, # 'У'
+ 53: 0, # 'Ф'
+ 55: 0, # 'Х'
+ 58: 0, # 'Ц'
+ 50: 0, # 'Ч'
+ 57: 0, # 'Ш'
+ 63: 0, # 'Щ'
+ 62: 0, # 'Ы'
+ 61: 0, # 'Ь'
+ 47: 0, # 'Э'
+ 59: 0, # 'Ю'
+ 43: 0, # 'Я'
+ 3: 0, # 'а'
+ 21: 2, # 'б'
+ 10: 3, # 'в'
+ 19: 2, # 'г'
+ 13: 3, # 'д'
+ 2: 3, # 'е'
+ 24: 3, # 'ж'
+ 20: 3, # 'з'
+ 4: 2, # 'и'
+ 23: 2, # 'й'
+ 11: 3, # 'к'
+ 8: 3, # 'л'
+ 12: 3, # 'м'
+ 5: 3, # 'н'
+ 1: 0, # 'о'
+ 15: 2, # 'п'
+ 9: 2, # 'р'
+ 7: 3, # 'с'
+ 6: 3, # 'т'
+ 14: 1, # 'у'
+ 39: 1, # 'ф'
+ 26: 3, # 'х'
+ 28: 2, # 'ц'
+ 22: 2, # 'ч'
+ 25: 2, # 'ш'
+ 29: 3, # 'щ'
+ 54: 0, # 'ъ'
+ 18: 0, # 'ы'
+ 17: 0, # 'ь'
+ 30: 0, # 'э'
+ 27: 2, # 'ю'
+ 16: 2, # 'я'
+ },
+}
+
+# 255: Undefined characters that did not exist in training text
+# 254: Carriage return / line feed ('\r', '\n')
+# 253: symbol (punctuation) that does not belong to a word
+# 252: 0 - 9
+# 251: Control characters
+
+# Character Mapping Table(s):
+IBM866_RUSSIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 142, # 'A'
+ 66: 143, # 'B'
+ 67: 144, # 'C'
+ 68: 145, # 'D'
+ 69: 146, # 'E'
+ 70: 147, # 'F'
+ 71: 148, # 'G'
+ 72: 149, # 'H'
+ 73: 150, # 'I'
+ 74: 151, # 'J'
+ 75: 152, # 'K'
+ 76: 74, # 'L'
+ 77: 153, # 'M'
+ 78: 75, # 'N'
+ 79: 154, # 'O'
+ 80: 155, # 'P'
+ 81: 156, # 'Q'
+ 82: 157, # 'R'
+ 83: 158, # 'S'
+ 84: 159, # 'T'
+ 85: 160, # 'U'
+ 86: 161, # 'V'
+ 87: 162, # 'W'
+ 88: 163, # 'X'
+ 89: 164, # 'Y'
+ 90: 165, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 71, # 'a'
+ 98: 172, # 'b'
+ 99: 66, # 'c'
+ 100: 173, # 'd'
+ 101: 65, # 'e'
+ 102: 174, # 'f'
+ 103: 76, # 'g'
+ 104: 175, # 'h'
+ 105: 64, # 'i'
+ 106: 176, # 'j'
+ 107: 177, # 'k'
+ 108: 77, # 'l'
+ 109: 72, # 'm'
+ 110: 178, # 'n'
+ 111: 69, # 'o'
+ 112: 67, # 'p'
+ 113: 179, # 'q'
+ 114: 78, # 'r'
+ 115: 73, # 's'
+ 116: 180, # 't'
+ 117: 181, # 'u'
+ 118: 79, # 'v'
+ 119: 182, # 'w'
+ 120: 183, # 'x'
+ 121: 184, # 'y'
+ 122: 185, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 37, # 'А'
+ 129: 44, # 'Б'
+ 130: 33, # 'В'
+ 131: 46, # 'Г'
+ 132: 41, # 'Д'
+ 133: 48, # 'Е'
+ 134: 56, # 'Ж'
+ 135: 51, # 'З'
+ 136: 42, # 'И'
+ 137: 60, # 'Й'
+ 138: 36, # 'К'
+ 139: 49, # 'Л'
+ 140: 38, # 'М'
+ 141: 31, # 'Н'
+ 142: 34, # 'О'
+ 143: 35, # 'П'
+ 144: 45, # 'Р'
+ 145: 32, # 'С'
+ 146: 40, # 'Т'
+ 147: 52, # 'У'
+ 148: 53, # 'Ф'
+ 149: 55, # 'Х'
+ 150: 58, # 'Ц'
+ 151: 50, # 'Ч'
+ 152: 57, # 'Ш'
+ 153: 63, # 'Щ'
+ 154: 70, # 'Ъ'
+ 155: 62, # 'Ы'
+ 156: 61, # 'Ь'
+ 157: 47, # 'Э'
+ 158: 59, # 'Ю'
+ 159: 43, # 'Я'
+ 160: 3, # 'а'
+ 161: 21, # 'б'
+ 162: 10, # 'в'
+ 163: 19, # 'г'
+ 164: 13, # 'д'
+ 165: 2, # 'е'
+ 166: 24, # 'ж'
+ 167: 20, # 'з'
+ 168: 4, # 'и'
+ 169: 23, # 'й'
+ 170: 11, # 'к'
+ 171: 8, # 'л'
+ 172: 12, # 'м'
+ 173: 5, # 'н'
+ 174: 1, # 'о'
+ 175: 15, # 'п'
+ 176: 191, # '░'
+ 177: 192, # '▒'
+ 178: 193, # '▓'
+ 179: 194, # '│'
+ 180: 195, # '┤'
+ 181: 196, # '╡'
+ 182: 197, # '╢'
+ 183: 198, # '╖'
+ 184: 199, # '╕'
+ 185: 200, # '╣'
+ 186: 201, # '║'
+ 187: 202, # '╗'
+ 188: 203, # '╝'
+ 189: 204, # '╜'
+ 190: 205, # '╛'
+ 191: 206, # '┐'
+ 192: 207, # '└'
+ 193: 208, # '┴'
+ 194: 209, # '┬'
+ 195: 210, # '├'
+ 196: 211, # '─'
+ 197: 212, # '┼'
+ 198: 213, # '╞'
+ 199: 214, # '╟'
+ 200: 215, # '╚'
+ 201: 216, # '╔'
+ 202: 217, # '╩'
+ 203: 218, # '╦'
+ 204: 219, # '╠'
+ 205: 220, # '═'
+ 206: 221, # '╬'
+ 207: 222, # '╧'
+ 208: 223, # '╨'
+ 209: 224, # '╤'
+ 210: 225, # '╥'
+ 211: 226, # '╙'
+ 212: 227, # '╘'
+ 213: 228, # '╒'
+ 214: 229, # '╓'
+ 215: 230, # '╫'
+ 216: 231, # '╪'
+ 217: 232, # '┘'
+ 218: 233, # '┌'
+ 219: 234, # '█'
+ 220: 235, # '▄'
+ 221: 236, # '▌'
+ 222: 237, # '▐'
+ 223: 238, # '▀'
+ 224: 9, # 'р'
+ 225: 7, # 'с'
+ 226: 6, # 'т'
+ 227: 14, # 'у'
+ 228: 39, # 'ф'
+ 229: 26, # 'х'
+ 230: 28, # 'ц'
+ 231: 22, # 'ч'
+ 232: 25, # 'ш'
+ 233: 29, # 'щ'
+ 234: 54, # 'ъ'
+ 235: 18, # 'ы'
+ 236: 17, # 'ь'
+ 237: 30, # 'э'
+ 238: 27, # 'ю'
+ 239: 16, # 'я'
+ 240: 239, # 'Ё'
+ 241: 68, # 'ё'
+ 242: 240, # 'Є'
+ 243: 241, # 'є'
+ 244: 242, # 'Ї'
+ 245: 243, # 'ї'
+ 246: 244, # 'Ў'
+ 247: 245, # 'ў'
+ 248: 246, # '°'
+ 249: 247, # '∙'
+ 250: 248, # '·'
+ 251: 249, # '√'
+ 252: 250, # '№'
+ 253: 251, # '¤'
+ 254: 252, # '■'
+ 255: 255, # '\xa0'
+}
+
+IBM866_RUSSIAN_MODEL = SingleByteCharSetModel(
+ charset_name="IBM866",
+ language="Russian",
+ char_to_order_map=IBM866_RUSSIAN_CHAR_TO_ORDER,
+ language_model=RUSSIAN_LANG_MODEL,
+ typical_positive_ratio=0.976601,
+ keep_ascii_letters=False,
+ alphabet="ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+)
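+
+# Illustrative sketch (not part of chardet's API): the helper below only
+# demonstrates how a single-byte prober might consult a table such as
+# IBM866_RUSSIAN_CHAR_TO_ORDER.  Low order values (< 64) are the frequent
+# letters that feed the bigram language model above, while 251-255 flag
+# control characters, digits, punctuation, line breaks and untrained bytes,
+# as in the legend preceding the mapping tables.
+def _example_byte_orders(data, char_to_order_map=IBM866_RUSSIAN_CHAR_TO_ORDER):
+    """Map each byte of ``data`` to its frequency order (illustration only)."""
+    return [char_to_order_map[byte] for byte in data]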
+
+WINDOWS_1251_RUSSIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 142, # 'A'
+ 66: 143, # 'B'
+ 67: 144, # 'C'
+ 68: 145, # 'D'
+ 69: 146, # 'E'
+ 70: 147, # 'F'
+ 71: 148, # 'G'
+ 72: 149, # 'H'
+ 73: 150, # 'I'
+ 74: 151, # 'J'
+ 75: 152, # 'K'
+ 76: 74, # 'L'
+ 77: 153, # 'M'
+ 78: 75, # 'N'
+ 79: 154, # 'O'
+ 80: 155, # 'P'
+ 81: 156, # 'Q'
+ 82: 157, # 'R'
+ 83: 158, # 'S'
+ 84: 159, # 'T'
+ 85: 160, # 'U'
+ 86: 161, # 'V'
+ 87: 162, # 'W'
+ 88: 163, # 'X'
+ 89: 164, # 'Y'
+ 90: 165, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 71, # 'a'
+ 98: 172, # 'b'
+ 99: 66, # 'c'
+ 100: 173, # 'd'
+ 101: 65, # 'e'
+ 102: 174, # 'f'
+ 103: 76, # 'g'
+ 104: 175, # 'h'
+ 105: 64, # 'i'
+ 106: 176, # 'j'
+ 107: 177, # 'k'
+ 108: 77, # 'l'
+ 109: 72, # 'm'
+ 110: 178, # 'n'
+ 111: 69, # 'o'
+ 112: 67, # 'p'
+ 113: 179, # 'q'
+ 114: 78, # 'r'
+ 115: 73, # 's'
+ 116: 180, # 't'
+ 117: 181, # 'u'
+ 118: 79, # 'v'
+ 119: 182, # 'w'
+ 120: 183, # 'x'
+ 121: 184, # 'y'
+ 122: 185, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 191, # 'Ђ'
+ 129: 192, # 'Ѓ'
+ 130: 193, # '‚'
+ 131: 194, # 'ѓ'
+ 132: 195, # '„'
+ 133: 196, # '…'
+ 134: 197, # '†'
+ 135: 198, # '‡'
+ 136: 199, # '€'
+ 137: 200, # '‰'
+ 138: 201, # 'Љ'
+ 139: 202, # '‹'
+ 140: 203, # 'Њ'
+ 141: 204, # 'Ќ'
+ 142: 205, # 'Ћ'
+ 143: 206, # 'Џ'
+ 144: 207, # 'ђ'
+ 145: 208, # '‘'
+ 146: 209, # '’'
+ 147: 210, # '“'
+ 148: 211, # '”'
+ 149: 212, # '•'
+ 150: 213, # '–'
+ 151: 214, # '—'
+ 152: 215, # None
+ 153: 216, # '™'
+ 154: 217, # 'љ'
+ 155: 218, # '›'
+ 156: 219, # 'њ'
+ 157: 220, # 'ќ'
+ 158: 221, # 'ћ'
+ 159: 222, # 'џ'
+ 160: 223, # '\xa0'
+ 161: 224, # 'Ў'
+ 162: 225, # 'ў'
+ 163: 226, # 'Ј'
+ 164: 227, # '¤'
+ 165: 228, # 'Ґ'
+ 166: 229, # '¦'
+ 167: 230, # '§'
+ 168: 231, # 'Ё'
+ 169: 232, # '©'
+ 170: 233, # 'Є'
+ 171: 234, # '«'
+ 172: 235, # '¬'
+ 173: 236, # '\xad'
+ 174: 237, # '®'
+ 175: 238, # 'Ї'
+ 176: 239, # '°'
+ 177: 240, # '±'
+ 178: 241, # 'І'
+ 179: 242, # 'і'
+ 180: 243, # 'ґ'
+ 181: 244, # 'µ'
+ 182: 245, # '¶'
+ 183: 246, # '·'
+ 184: 68, # 'ё'
+ 185: 247, # '№'
+ 186: 248, # 'є'
+ 187: 249, # '»'
+ 188: 250, # 'ј'
+ 189: 251, # 'Ѕ'
+ 190: 252, # 'ѕ'
+ 191: 253, # 'ї'
+ 192: 37, # 'А'
+ 193: 44, # 'Б'
+ 194: 33, # 'В'
+ 195: 46, # 'Г'
+ 196: 41, # 'Д'
+ 197: 48, # 'Е'
+ 198: 56, # 'Ж'
+ 199: 51, # 'З'
+ 200: 42, # 'И'
+ 201: 60, # 'Й'
+ 202: 36, # 'К'
+ 203: 49, # 'Л'
+ 204: 38, # 'М'
+ 205: 31, # 'Н'
+ 206: 34, # 'О'
+ 207: 35, # 'П'
+ 208: 45, # 'Р'
+ 209: 32, # 'С'
+ 210: 40, # 'Т'
+ 211: 52, # 'У'
+ 212: 53, # 'Ф'
+ 213: 55, # 'Х'
+ 214: 58, # 'Ц'
+ 215: 50, # 'Ч'
+ 216: 57, # 'Ш'
+ 217: 63, # 'Щ'
+ 218: 70, # 'Ъ'
+ 219: 62, # 'Ы'
+ 220: 61, # 'Ь'
+ 221: 47, # 'Э'
+ 222: 59, # 'Ю'
+ 223: 43, # 'Я'
+ 224: 3, # 'а'
+ 225: 21, # 'б'
+ 226: 10, # 'в'
+ 227: 19, # 'г'
+ 228: 13, # 'д'
+ 229: 2, # 'е'
+ 230: 24, # 'ж'
+ 231: 20, # 'з'
+ 232: 4, # 'и'
+ 233: 23, # 'й'
+ 234: 11, # 'к'
+ 235: 8, # 'л'
+ 236: 12, # 'м'
+ 237: 5, # 'н'
+ 238: 1, # 'о'
+ 239: 15, # 'п'
+ 240: 9, # 'р'
+ 241: 7, # 'с'
+ 242: 6, # 'т'
+ 243: 14, # 'у'
+ 244: 39, # 'ф'
+ 245: 26, # 'х'
+ 246: 28, # 'ц'
+ 247: 22, # 'ч'
+ 248: 25, # 'ш'
+ 249: 29, # 'щ'
+ 250: 54, # 'ъ'
+ 251: 18, # 'ы'
+ 252: 17, # 'ь'
+ 253: 30, # 'э'
+ 254: 27, # 'ю'
+ 255: 16, # 'я'
+}
+
+WINDOWS_1251_RUSSIAN_MODEL = SingleByteCharSetModel(
+ charset_name="windows-1251",
+ language="Russian",
+ char_to_order_map=WINDOWS_1251_RUSSIAN_CHAR_TO_ORDER,
+ language_model=RUSSIAN_LANG_MODEL,
+ typical_positive_ratio=0.976601,
+ keep_ascii_letters=False,
+ alphabet="ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+)
+
+IBM855_RUSSIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 142, # 'A'
+ 66: 143, # 'B'
+ 67: 144, # 'C'
+ 68: 145, # 'D'
+ 69: 146, # 'E'
+ 70: 147, # 'F'
+ 71: 148, # 'G'
+ 72: 149, # 'H'
+ 73: 150, # 'I'
+ 74: 151, # 'J'
+ 75: 152, # 'K'
+ 76: 74, # 'L'
+ 77: 153, # 'M'
+ 78: 75, # 'N'
+ 79: 154, # 'O'
+ 80: 155, # 'P'
+ 81: 156, # 'Q'
+ 82: 157, # 'R'
+ 83: 158, # 'S'
+ 84: 159, # 'T'
+ 85: 160, # 'U'
+ 86: 161, # 'V'
+ 87: 162, # 'W'
+ 88: 163, # 'X'
+ 89: 164, # 'Y'
+ 90: 165, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 71, # 'a'
+ 98: 172, # 'b'
+ 99: 66, # 'c'
+ 100: 173, # 'd'
+ 101: 65, # 'e'
+ 102: 174, # 'f'
+ 103: 76, # 'g'
+ 104: 175, # 'h'
+ 105: 64, # 'i'
+ 106: 176, # 'j'
+ 107: 177, # 'k'
+ 108: 77, # 'l'
+ 109: 72, # 'm'
+ 110: 178, # 'n'
+ 111: 69, # 'o'
+ 112: 67, # 'p'
+ 113: 179, # 'q'
+ 114: 78, # 'r'
+ 115: 73, # 's'
+ 116: 180, # 't'
+ 117: 181, # 'u'
+ 118: 79, # 'v'
+ 119: 182, # 'w'
+ 120: 183, # 'x'
+ 121: 184, # 'y'
+ 122: 185, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 191, # 'ђ'
+ 129: 192, # 'Ђ'
+ 130: 193, # 'ѓ'
+ 131: 194, # 'Ѓ'
+ 132: 68, # 'ё'
+ 133: 195, # 'Ё'
+ 134: 196, # 'є'
+ 135: 197, # 'Є'
+ 136: 198, # 'ѕ'
+ 137: 199, # 'Ѕ'
+ 138: 200, # 'і'
+ 139: 201, # 'І'
+ 140: 202, # 'ї'
+ 141: 203, # 'Ї'
+ 142: 204, # 'ј'
+ 143: 205, # 'Ј'
+ 144: 206, # 'љ'
+ 145: 207, # 'Љ'
+ 146: 208, # 'њ'
+ 147: 209, # 'Њ'
+ 148: 210, # 'ћ'
+ 149: 211, # 'Ћ'
+ 150: 212, # 'ќ'
+ 151: 213, # 'Ќ'
+ 152: 214, # 'ў'
+ 153: 215, # 'Ў'
+ 154: 216, # 'џ'
+ 155: 217, # 'Џ'
+ 156: 27, # 'ю'
+ 157: 59, # 'Ю'
+ 158: 54, # 'ъ'
+ 159: 70, # 'Ъ'
+ 160: 3, # 'а'
+ 161: 37, # 'А'
+ 162: 21, # 'б'
+ 163: 44, # 'Б'
+ 164: 28, # 'ц'
+ 165: 58, # 'Ц'
+ 166: 13, # 'д'
+ 167: 41, # 'Д'
+ 168: 2, # 'е'
+ 169: 48, # 'Е'
+ 170: 39, # 'ф'
+ 171: 53, # 'Ф'
+ 172: 19, # 'г'
+ 173: 46, # 'Г'
+ 174: 218, # '«'
+ 175: 219, # '»'
+ 176: 220, # '░'
+ 177: 221, # '▒'
+ 178: 222, # '▓'
+ 179: 223, # '│'
+ 180: 224, # '┤'
+ 181: 26, # 'х'
+ 182: 55, # 'Х'
+ 183: 4, # 'и'
+ 184: 42, # 'И'
+ 185: 225, # '╣'
+ 186: 226, # '║'
+ 187: 227, # '╗'
+ 188: 228, # '╝'
+ 189: 23, # 'й'
+ 190: 60, # 'Й'
+ 191: 229, # '┐'
+ 192: 230, # '└'
+ 193: 231, # '┴'
+ 194: 232, # '┬'
+ 195: 233, # '├'
+ 196: 234, # '─'
+ 197: 235, # '┼'
+ 198: 11, # 'к'
+ 199: 36, # 'К'
+ 200: 236, # '╚'
+ 201: 237, # '╔'
+ 202: 238, # '╩'
+ 203: 239, # '╦'
+ 204: 240, # '╠'
+ 205: 241, # '═'
+ 206: 242, # '╬'
+ 207: 243, # '¤'
+ 208: 8, # 'л'
+ 209: 49, # 'Л'
+ 210: 12, # 'м'
+ 211: 38, # 'М'
+ 212: 5, # 'н'
+ 213: 31, # 'Н'
+ 214: 1, # 'о'
+ 215: 34, # 'О'
+ 216: 15, # 'п'
+ 217: 244, # '┘'
+ 218: 245, # '┌'
+ 219: 246, # '█'
+ 220: 247, # '▄'
+ 221: 35, # 'П'
+ 222: 16, # 'я'
+ 223: 248, # '▀'
+ 224: 43, # 'Я'
+ 225: 9, # 'р'
+ 226: 45, # 'Р'
+ 227: 7, # 'с'
+ 228: 32, # 'С'
+ 229: 6, # 'т'
+ 230: 40, # 'Т'
+ 231: 14, # 'у'
+ 232: 52, # 'У'
+ 233: 24, # 'ж'
+ 234: 56, # 'Ж'
+ 235: 10, # 'в'
+ 236: 33, # 'В'
+ 237: 17, # 'ь'
+ 238: 61, # 'Ь'
+ 239: 249, # '№'
+ 240: 250, # '\xad'
+ 241: 18, # 'ы'
+ 242: 62, # 'Ы'
+ 243: 20, # 'з'
+ 244: 51, # 'З'
+ 245: 25, # 'ш'
+ 246: 57, # 'Ш'
+ 247: 30, # 'э'
+ 248: 47, # 'Э'
+ 249: 29, # 'щ'
+ 250: 63, # 'Щ'
+ 251: 22, # 'ч'
+ 252: 50, # 'Ч'
+ 253: 251, # '§'
+ 254: 252, # '■'
+ 255: 255, # '\xa0'
+}
+
+IBM855_RUSSIAN_MODEL = SingleByteCharSetModel(
+ charset_name="IBM855",
+ language="Russian",
+ char_to_order_map=IBM855_RUSSIAN_CHAR_TO_ORDER,
+ language_model=RUSSIAN_LANG_MODEL,
+ typical_positive_ratio=0.976601,
+ keep_ascii_letters=False,
+ alphabet="ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+)
+
+KOI8_R_RUSSIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 142, # 'A'
+ 66: 143, # 'B'
+ 67: 144, # 'C'
+ 68: 145, # 'D'
+ 69: 146, # 'E'
+ 70: 147, # 'F'
+ 71: 148, # 'G'
+ 72: 149, # 'H'
+ 73: 150, # 'I'
+ 74: 151, # 'J'
+ 75: 152, # 'K'
+ 76: 74, # 'L'
+ 77: 153, # 'M'
+ 78: 75, # 'N'
+ 79: 154, # 'O'
+ 80: 155, # 'P'
+ 81: 156, # 'Q'
+ 82: 157, # 'R'
+ 83: 158, # 'S'
+ 84: 159, # 'T'
+ 85: 160, # 'U'
+ 86: 161, # 'V'
+ 87: 162, # 'W'
+ 88: 163, # 'X'
+ 89: 164, # 'Y'
+ 90: 165, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 71, # 'a'
+ 98: 172, # 'b'
+ 99: 66, # 'c'
+ 100: 173, # 'd'
+ 101: 65, # 'e'
+ 102: 174, # 'f'
+ 103: 76, # 'g'
+ 104: 175, # 'h'
+ 105: 64, # 'i'
+ 106: 176, # 'j'
+ 107: 177, # 'k'
+ 108: 77, # 'l'
+ 109: 72, # 'm'
+ 110: 178, # 'n'
+ 111: 69, # 'o'
+ 112: 67, # 'p'
+ 113: 179, # 'q'
+ 114: 78, # 'r'
+ 115: 73, # 's'
+ 116: 180, # 't'
+ 117: 181, # 'u'
+ 118: 79, # 'v'
+ 119: 182, # 'w'
+ 120: 183, # 'x'
+ 121: 184, # 'y'
+ 122: 185, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 191, # '─'
+ 129: 192, # '│'
+ 130: 193, # '┌'
+ 131: 194, # '┐'
+ 132: 195, # '└'
+ 133: 196, # '┘'
+ 134: 197, # '├'
+ 135: 198, # '┤'
+ 136: 199, # '┬'
+ 137: 200, # '┴'
+ 138: 201, # '┼'
+ 139: 202, # '▀'
+ 140: 203, # '▄'
+ 141: 204, # '█'
+ 142: 205, # '▌'
+ 143: 206, # '▐'
+ 144: 207, # '░'
+ 145: 208, # '▒'
+ 146: 209, # '▓'
+ 147: 210, # '⌠'
+ 148: 211, # '■'
+ 149: 212, # '∙'
+ 150: 213, # '√'
+ 151: 214, # '≈'
+ 152: 215, # '≤'
+ 153: 216, # '≥'
+ 154: 217, # '\xa0'
+ 155: 218, # '⌡'
+ 156: 219, # '°'
+ 157: 220, # '²'
+ 158: 221, # '·'
+ 159: 222, # '÷'
+ 160: 223, # '═'
+ 161: 224, # '║'
+ 162: 225, # '╒'
+ 163: 68, # 'ё'
+ 164: 226, # '╓'
+ 165: 227, # '╔'
+ 166: 228, # '╕'
+ 167: 229, # '╖'
+ 168: 230, # '╗'
+ 169: 231, # '╘'
+ 170: 232, # '╙'
+ 171: 233, # '╚'
+ 172: 234, # '╛'
+ 173: 235, # '╜'
+ 174: 236, # '╝'
+ 175: 237, # '╞'
+ 176: 238, # '╟'
+ 177: 239, # '╠'
+ 178: 240, # '╡'
+ 179: 241, # 'Ё'
+ 180: 242, # '╢'
+ 181: 243, # '╣'
+ 182: 244, # '╤'
+ 183: 245, # '╥'
+ 184: 246, # '╦'
+ 185: 247, # '╧'
+ 186: 248, # '╨'
+ 187: 249, # '╩'
+ 188: 250, # '╪'
+ 189: 251, # '╫'
+ 190: 252, # '╬'
+ 191: 253, # '©'
+ 192: 27, # 'ю'
+ 193: 3, # 'а'
+ 194: 21, # 'б'
+ 195: 28, # 'ц'
+ 196: 13, # 'д'
+ 197: 2, # 'е'
+ 198: 39, # 'ф'
+ 199: 19, # 'г'
+ 200: 26, # 'х'
+ 201: 4, # 'и'
+ 202: 23, # 'й'
+ 203: 11, # 'к'
+ 204: 8, # 'л'
+ 205: 12, # 'м'
+ 206: 5, # 'н'
+ 207: 1, # 'о'
+ 208: 15, # 'п'
+ 209: 16, # 'я'
+ 210: 9, # 'р'
+ 211: 7, # 'с'
+ 212: 6, # 'т'
+ 213: 14, # 'у'
+ 214: 24, # 'ж'
+ 215: 10, # 'в'
+ 216: 17, # 'ь'
+ 217: 18, # 'ы'
+ 218: 20, # 'з'
+ 219: 25, # 'ш'
+ 220: 30, # 'э'
+ 221: 29, # 'щ'
+ 222: 22, # 'ч'
+ 223: 54, # 'ъ'
+ 224: 59, # 'Ю'
+ 225: 37, # 'А'
+ 226: 44, # 'Б'
+ 227: 58, # 'Ц'
+ 228: 41, # 'Д'
+ 229: 48, # 'Е'
+ 230: 53, # 'Ф'
+ 231: 46, # 'Г'
+ 232: 55, # 'Х'
+ 233: 42, # 'И'
+ 234: 60, # 'Й'
+ 235: 36, # 'К'
+ 236: 49, # 'Л'
+ 237: 38, # 'М'
+ 238: 31, # 'Н'
+ 239: 34, # 'О'
+ 240: 35, # 'П'
+ 241: 43, # 'Я'
+ 242: 45, # 'Р'
+ 243: 32, # 'С'
+ 244: 40, # 'Т'
+ 245: 52, # 'У'
+ 246: 56, # 'Ж'
+ 247: 33, # 'В'
+ 248: 61, # 'Ь'
+ 249: 62, # 'Ы'
+ 250: 51, # 'З'
+ 251: 57, # 'Ш'
+ 252: 47, # 'Э'
+ 253: 63, # 'Щ'
+ 254: 50, # 'Ч'
+ 255: 70, # 'Ъ'
+}
+
+KOI8_R_RUSSIAN_MODEL = SingleByteCharSetModel(
+ charset_name="KOI8-R",
+ language="Russian",
+ char_to_order_map=KOI8_R_RUSSIAN_CHAR_TO_ORDER,
+ language_model=RUSSIAN_LANG_MODEL,
+ typical_positive_ratio=0.976601,
+ keep_ascii_letters=False,
+ alphabet="ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+)
+
+MACCYRILLIC_RUSSIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 142, # 'A'
+ 66: 143, # 'B'
+ 67: 144, # 'C'
+ 68: 145, # 'D'
+ 69: 146, # 'E'
+ 70: 147, # 'F'
+ 71: 148, # 'G'
+ 72: 149, # 'H'
+ 73: 150, # 'I'
+ 74: 151, # 'J'
+ 75: 152, # 'K'
+ 76: 74, # 'L'
+ 77: 153, # 'M'
+ 78: 75, # 'N'
+ 79: 154, # 'O'
+ 80: 155, # 'P'
+ 81: 156, # 'Q'
+ 82: 157, # 'R'
+ 83: 158, # 'S'
+ 84: 159, # 'T'
+ 85: 160, # 'U'
+ 86: 161, # 'V'
+ 87: 162, # 'W'
+ 88: 163, # 'X'
+ 89: 164, # 'Y'
+ 90: 165, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 71, # 'a'
+ 98: 172, # 'b'
+ 99: 66, # 'c'
+ 100: 173, # 'd'
+ 101: 65, # 'e'
+ 102: 174, # 'f'
+ 103: 76, # 'g'
+ 104: 175, # 'h'
+ 105: 64, # 'i'
+ 106: 176, # 'j'
+ 107: 177, # 'k'
+ 108: 77, # 'l'
+ 109: 72, # 'm'
+ 110: 178, # 'n'
+ 111: 69, # 'o'
+ 112: 67, # 'p'
+ 113: 179, # 'q'
+ 114: 78, # 'r'
+ 115: 73, # 's'
+ 116: 180, # 't'
+ 117: 181, # 'u'
+ 118: 79, # 'v'
+ 119: 182, # 'w'
+ 120: 183, # 'x'
+ 121: 184, # 'y'
+ 122: 185, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 37, # 'А'
+ 129: 44, # 'Б'
+ 130: 33, # 'В'
+ 131: 46, # 'Г'
+ 132: 41, # 'Д'
+ 133: 48, # 'Е'
+ 134: 56, # 'Ж'
+ 135: 51, # 'З'
+ 136: 42, # 'И'
+ 137: 60, # 'Й'
+ 138: 36, # 'К'
+ 139: 49, # 'Л'
+ 140: 38, # 'М'
+ 141: 31, # 'Н'
+ 142: 34, # 'О'
+ 143: 35, # 'П'
+ 144: 45, # 'Р'
+ 145: 32, # 'С'
+ 146: 40, # 'Т'
+ 147: 52, # 'У'
+ 148: 53, # 'Ф'
+ 149: 55, # 'Х'
+ 150: 58, # 'Ц'
+ 151: 50, # 'Ч'
+ 152: 57, # 'Ш'
+ 153: 63, # 'Щ'
+ 154: 70, # 'Ъ'
+ 155: 62, # 'Ы'
+ 156: 61, # 'Ь'
+ 157: 47, # 'Э'
+ 158: 59, # 'Ю'
+ 159: 43, # 'Я'
+ 160: 191, # '†'
+ 161: 192, # '°'
+ 162: 193, # 'Ґ'
+ 163: 194, # '£'
+ 164: 195, # '§'
+ 165: 196, # '•'
+ 166: 197, # '¶'
+ 167: 198, # 'І'
+ 168: 199, # '®'
+ 169: 200, # '©'
+ 170: 201, # '™'
+ 171: 202, # 'Ђ'
+ 172: 203, # 'ђ'
+ 173: 204, # '≠'
+ 174: 205, # 'Ѓ'
+ 175: 206, # 'ѓ'
+ 176: 207, # '∞'
+ 177: 208, # '±'
+ 178: 209, # '≤'
+ 179: 210, # '≥'
+ 180: 211, # 'і'
+ 181: 212, # 'µ'
+ 182: 213, # 'ґ'
+ 183: 214, # 'Ј'
+ 184: 215, # 'Є'
+ 185: 216, # 'є'
+ 186: 217, # 'Ї'
+ 187: 218, # 'ї'
+ 188: 219, # 'Љ'
+ 189: 220, # 'љ'
+ 190: 221, # 'Њ'
+ 191: 222, # 'њ'
+ 192: 223, # 'ј'
+ 193: 224, # 'Ѕ'
+ 194: 225, # '¬'
+ 195: 226, # '√'
+ 196: 227, # 'ƒ'
+ 197: 228, # '≈'
+ 198: 229, # '∆'
+ 199: 230, # '«'
+ 200: 231, # '»'
+ 201: 232, # '…'
+ 202: 233, # '\xa0'
+ 203: 234, # 'Ћ'
+ 204: 235, # 'ћ'
+ 205: 236, # 'Ќ'
+ 206: 237, # 'ќ'
+ 207: 238, # 'ѕ'
+ 208: 239, # '–'
+ 209: 240, # '—'
+ 210: 241, # '“'
+ 211: 242, # '”'
+ 212: 243, # '‘'
+ 213: 244, # '’'
+ 214: 245, # '÷'
+ 215: 246, # '„'
+ 216: 247, # 'Ў'
+ 217: 248, # 'ў'
+ 218: 249, # 'Џ'
+ 219: 250, # 'џ'
+ 220: 251, # '№'
+ 221: 252, # 'Ё'
+ 222: 68, # 'ё'
+ 223: 16, # 'я'
+ 224: 3, # 'а'
+ 225: 21, # 'б'
+ 226: 10, # 'в'
+ 227: 19, # 'г'
+ 228: 13, # 'д'
+ 229: 2, # 'е'
+ 230: 24, # 'ж'
+ 231: 20, # 'з'
+ 232: 4, # 'и'
+ 233: 23, # 'й'
+ 234: 11, # 'к'
+ 235: 8, # 'л'
+ 236: 12, # 'м'
+ 237: 5, # 'н'
+ 238: 1, # 'о'
+ 239: 15, # 'п'
+ 240: 9, # 'р'
+ 241: 7, # 'с'
+ 242: 6, # 'т'
+ 243: 14, # 'у'
+ 244: 39, # 'ф'
+ 245: 26, # 'х'
+ 246: 28, # 'ц'
+ 247: 22, # 'ч'
+ 248: 25, # 'ш'
+ 249: 29, # 'щ'
+ 250: 54, # 'ъ'
+ 251: 18, # 'ы'
+ 252: 17, # 'ь'
+ 253: 30, # 'э'
+ 254: 27, # 'ю'
+ 255: 255, # '€'
+}
+
+MACCYRILLIC_RUSSIAN_MODEL = SingleByteCharSetModel(
+ charset_name="MacCyrillic",
+ language="Russian",
+ char_to_order_map=MACCYRILLIC_RUSSIAN_CHAR_TO_ORDER,
+ language_model=RUSSIAN_LANG_MODEL,
+ typical_positive_ratio=0.976601,
+ keep_ascii_letters=False,
+ alphabet="ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+)
+
+ISO_8859_5_RUSSIAN_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 142, # 'A'
+ 66: 143, # 'B'
+ 67: 144, # 'C'
+ 68: 145, # 'D'
+ 69: 146, # 'E'
+ 70: 147, # 'F'
+ 71: 148, # 'G'
+ 72: 149, # 'H'
+ 73: 150, # 'I'
+ 74: 151, # 'J'
+ 75: 152, # 'K'
+ 76: 74, # 'L'
+ 77: 153, # 'M'
+ 78: 75, # 'N'
+ 79: 154, # 'O'
+ 80: 155, # 'P'
+ 81: 156, # 'Q'
+ 82: 157, # 'R'
+ 83: 158, # 'S'
+ 84: 159, # 'T'
+ 85: 160, # 'U'
+ 86: 161, # 'V'
+ 87: 162, # 'W'
+ 88: 163, # 'X'
+ 89: 164, # 'Y'
+ 90: 165, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 71, # 'a'
+ 98: 172, # 'b'
+ 99: 66, # 'c'
+ 100: 173, # 'd'
+ 101: 65, # 'e'
+ 102: 174, # 'f'
+ 103: 76, # 'g'
+ 104: 175, # 'h'
+ 105: 64, # 'i'
+ 106: 176, # 'j'
+ 107: 177, # 'k'
+ 108: 77, # 'l'
+ 109: 72, # 'm'
+ 110: 178, # 'n'
+ 111: 69, # 'o'
+ 112: 67, # 'p'
+ 113: 179, # 'q'
+ 114: 78, # 'r'
+ 115: 73, # 's'
+ 116: 180, # 't'
+ 117: 181, # 'u'
+ 118: 79, # 'v'
+ 119: 182, # 'w'
+ 120: 183, # 'x'
+ 121: 184, # 'y'
+ 122: 185, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 191, # '\x80'
+ 129: 192, # '\x81'
+ 130: 193, # '\x82'
+ 131: 194, # '\x83'
+ 132: 195, # '\x84'
+ 133: 196, # '\x85'
+ 134: 197, # '\x86'
+ 135: 198, # '\x87'
+ 136: 199, # '\x88'
+ 137: 200, # '\x89'
+ 138: 201, # '\x8a'
+ 139: 202, # '\x8b'
+ 140: 203, # '\x8c'
+ 141: 204, # '\x8d'
+ 142: 205, # '\x8e'
+ 143: 206, # '\x8f'
+ 144: 207, # '\x90'
+ 145: 208, # '\x91'
+ 146: 209, # '\x92'
+ 147: 210, # '\x93'
+ 148: 211, # '\x94'
+ 149: 212, # '\x95'
+ 150: 213, # '\x96'
+ 151: 214, # '\x97'
+ 152: 215, # '\x98'
+ 153: 216, # '\x99'
+ 154: 217, # '\x9a'
+ 155: 218, # '\x9b'
+ 156: 219, # '\x9c'
+ 157: 220, # '\x9d'
+ 158: 221, # '\x9e'
+ 159: 222, # '\x9f'
+ 160: 223, # '\xa0'
+ 161: 224, # 'Ё'
+ 162: 225, # 'Ђ'
+ 163: 226, # 'Ѓ'
+ 164: 227, # 'Є'
+ 165: 228, # 'Ѕ'
+ 166: 229, # 'І'
+ 167: 230, # 'Ї'
+ 168: 231, # 'Ј'
+ 169: 232, # 'Љ'
+ 170: 233, # 'Њ'
+ 171: 234, # 'Ћ'
+ 172: 235, # 'Ќ'
+ 173: 236, # '\xad'
+ 174: 237, # 'Ў'
+ 175: 238, # 'Џ'
+ 176: 37, # 'А'
+ 177: 44, # 'Б'
+ 178: 33, # 'В'
+ 179: 46, # 'Г'
+ 180: 41, # 'Д'
+ 181: 48, # 'Е'
+ 182: 56, # 'Ж'
+ 183: 51, # 'З'
+ 184: 42, # 'И'
+ 185: 60, # 'Й'
+ 186: 36, # 'К'
+ 187: 49, # 'Л'
+ 188: 38, # 'М'
+ 189: 31, # 'Н'
+ 190: 34, # 'О'
+ 191: 35, # 'П'
+ 192: 45, # 'Р'
+ 193: 32, # 'С'
+ 194: 40, # 'Т'
+ 195: 52, # 'У'
+ 196: 53, # 'Ф'
+ 197: 55, # 'Х'
+ 198: 58, # 'Ц'
+ 199: 50, # 'Ч'
+ 200: 57, # 'Ш'
+ 201: 63, # 'Щ'
+ 202: 70, # 'Ъ'
+ 203: 62, # 'Ы'
+ 204: 61, # 'Ь'
+ 205: 47, # 'Э'
+ 206: 59, # 'Ю'
+ 207: 43, # 'Я'
+ 208: 3, # 'а'
+ 209: 21, # 'б'
+ 210: 10, # 'в'
+ 211: 19, # 'г'
+ 212: 13, # 'д'
+ 213: 2, # 'е'
+ 214: 24, # 'ж'
+ 215: 20, # 'з'
+ 216: 4, # 'и'
+ 217: 23, # 'й'
+ 218: 11, # 'к'
+ 219: 8, # 'л'
+ 220: 12, # 'м'
+ 221: 5, # 'н'
+ 222: 1, # 'о'
+ 223: 15, # 'п'
+ 224: 9, # 'р'
+ 225: 7, # 'с'
+ 226: 6, # 'т'
+ 227: 14, # 'у'
+ 228: 39, # 'ф'
+ 229: 26, # 'х'
+ 230: 28, # 'ц'
+ 231: 22, # 'ч'
+ 232: 25, # 'ш'
+ 233: 29, # 'щ'
+ 234: 54, # 'ъ'
+ 235: 18, # 'ы'
+ 236: 17, # 'ь'
+ 237: 30, # 'э'
+ 238: 27, # 'ю'
+ 239: 16, # 'я'
+ 240: 239, # '№'
+ 241: 68, # 'ё'
+ 242: 240, # 'ђ'
+ 243: 241, # 'ѓ'
+ 244: 242, # 'є'
+ 245: 243, # 'ѕ'
+ 246: 244, # 'і'
+ 247: 245, # 'ї'
+ 248: 246, # 'ј'
+ 249: 247, # 'љ'
+ 250: 248, # 'њ'
+ 251: 249, # 'ћ'
+ 252: 250, # 'ќ'
+ 253: 251, # '§'
+ 254: 252, # 'ў'
+ 255: 255, # 'џ'
+}
+
+ISO_8859_5_RUSSIAN_MODEL = SingleByteCharSetModel(
+ charset_name="ISO-8859-5",
+ language="Russian",
+ char_to_order_map=ISO_8859_5_RUSSIAN_CHAR_TO_ORDER,
+ language_model=RUSSIAN_LANG_MODEL,
+ typical_positive_ratio=0.976601,
+ keep_ascii_letters=False,
+ alphabet="ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
+)
diff --git a/lib/chardet/langthaimodel.py b/lib/chardet/langthaimodel.py
new file mode 100644
index 0000000..883fdb1
--- /dev/null
+++ b/lib/chardet/langthaimodel.py
@@ -0,0 +1,4380 @@
+from chardet.sbcharsetprober import SingleByteCharSetModel
+
+# 3: Positive
+# 2: Likely
+# 1: Unlikely
+# 0: Negative
+
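+# Illustrative sketch (not part of chardet's API): the ratings above grade how
+# likely one frequent character is to follow another.  A prober could score a
+# pair of order values roughly like this; THAI_LANG_MODEL is defined just
+# below and is only looked up when the helper is called.
+def _example_pair_likelihood(first_order, second_order):
+    """Return the 0-3 likelihood that ``second_order`` follows ``first_order``."""
+    # Orders outside the model (digits, punctuation, control bytes) score 0.
+    return THAI_LANG_MODEL.get(first_order, {}).get(second_order, 0)
+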
+THAI_LANG_MODEL = {
+ 5: { # 'ก'
+ 5: 2, # 'ก'
+ 30: 2, # 'ข'
+ 24: 2, # 'ค'
+ 8: 2, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 3, # 'ฎ'
+ 57: 2, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 2, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 3, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 1, # 'บ'
+ 25: 2, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 1, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 1, # 'ย'
+ 2: 3, # 'ร'
+ 61: 2, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 3, # 'ว'
+ 42: 2, # 'ศ'
+ 46: 3, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 3, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 3, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 2, # 'ื'
+ 32: 2, # 'ุ'
+ 35: 1, # 'ู'
+ 11: 2, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 3, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 30: { # 'ข'
+ 5: 1, # 'ก'
+ 30: 0, # 'ข'
+ 24: 1, # 'ค'
+ 8: 1, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 2, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 2, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 1, # 'บ'
+ 25: 1, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 2, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 1, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 3, # 'ึ'
+ 27: 1, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 1, # '็'
+ 6: 2, # '่'
+ 7: 3, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 24: { # 'ค'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 2, # 'ค'
+ 8: 2, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 2, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 2, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 0, # 'บ'
+ 25: 1, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 2, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 3, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 2, # 'า'
+ 36: 3, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 1, # 'เ'
+ 28: 0, # 'แ'
+ 41: 3, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 1, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 3, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 8: { # 'ง'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 3, # 'ค'
+ 8: 2, # 'ง'
+ 26: 2, # 'จ'
+ 52: 1, # 'ฉ'
+ 34: 2, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 1, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 1, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 2, # 'ว'
+ 42: 2, # 'ศ'
+ 46: 1, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 3, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 1, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 1, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 3, # 'ๆ'
+ 37: 0, # '็'
+ 6: 2, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 26: { # 'จ'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 0, # 'ค'
+ 8: 2, # 'ง'
+ 26: 3, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 1, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 1, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 1, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 1, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 3, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 3, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 3, # 'ึ'
+ 27: 1, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 1, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 2, # '่'
+ 7: 2, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 52: { # 'ฉ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 3, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 3, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 1, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 1, # 'ั'
+ 1: 1, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 34: { # 'ช'
+ 5: 1, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 1, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 1, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 1, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 2, # 'ั'
+ 1: 3, # 'า'
+ 36: 1, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 1, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 1, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 51: { # 'ซ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 1, # 'ั'
+ 1: 1, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 3, # 'ึ'
+ 27: 2, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 1, # 'ู'
+ 11: 1, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 1, # '็'
+ 6: 1, # '่'
+ 7: 2, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 47: { # 'ญ'
+ 5: 1, # 'ก'
+ 30: 1, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 3, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 1, # 'บ'
+ 25: 1, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 2, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 2, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 1, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 0, # '็'
+ 6: 2, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 58: { # 'ฎ'
+ 5: 2, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 1, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 57: { # 'ฏ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 49: { # 'ฐ'
+ 5: 1, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 2, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 53: { # 'ฑ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 3, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 55: { # 'ฒ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 43: { # 'ณ'
+ 5: 1, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 3, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 3, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 1, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 3, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 1, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 3, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 20: { # 'ด'
+ 5: 2, # 'ก'
+ 30: 2, # 'ข'
+ 24: 2, # 'ค'
+ 8: 3, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 1, # 'บ'
+ 25: 1, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 3, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 2, # 'า'
+ 36: 2, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 1, # 'ึ'
+ 27: 2, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 2, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 2, # 'ๆ'
+ 37: 2, # '็'
+ 6: 1, # '่'
+ 7: 3, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 19: { # 'ต'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 1, # 'ค'
+ 8: 0, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 1, # 'ต'
+ 44: 2, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 1, # 'บ'
+ 25: 1, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 2, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 1, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 0, # 'ห'
+ 4: 3, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 2, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 1, # 'ึ'
+ 27: 1, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 1, # 'เ'
+ 28: 1, # 'แ'
+ 41: 1, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 2, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 44: { # 'ถ'
+ 5: 1, # 'ก'
+ 30: 0, # 'ข'
+ 24: 1, # 'ค'
+ 8: 0, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 2, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 2, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 3, # 'ึ'
+ 27: 2, # 'ื'
+ 32: 2, # 'ุ'
+ 35: 3, # 'ู'
+ 11: 1, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 2, # '่'
+ 7: 3, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 14: { # 'ท'
+ 5: 1, # 'ก'
+ 30: 1, # 'ข'
+ 24: 3, # 'ค'
+ 8: 1, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 3, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 3, # 'ย'
+ 2: 3, # 'ร'
+ 61: 1, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 2, # 'ว'
+ 42: 3, # 'ศ'
+ 46: 1, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 3, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 2, # 'ึ'
+ 27: 1, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 1, # 'ู'
+ 11: 0, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 1, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 48: { # 'ธ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 1, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 2, # 'า'
+ 36: 0, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 2, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 3, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 3: { # 'น'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 3, # 'ค'
+ 8: 1, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 1, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 2, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 3, # 'ธ'
+ 3: 2, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 2, # 'ย'
+ 2: 2, # 'ร'
+ 61: 1, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 3, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 3, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 3, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 3, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 3, # 'เ'
+ 28: 2, # 'แ'
+ 41: 3, # 'โ'
+ 29: 3, # 'ใ'
+ 33: 3, # 'ไ'
+ 50: 2, # 'ๆ'
+ 37: 1, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 17: { # 'บ'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 2, # 'ค'
+ 8: 1, # 'ง'
+ 26: 1, # 'จ'
+ 52: 1, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 2, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 0, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 3, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 2, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 2, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 2, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 2, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 1, # '็'
+ 6: 2, # '่'
+ 7: 2, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 25: { # 'ป'
+ 5: 2, # 'ก'
+ 30: 0, # 'ข'
+ 24: 1, # 'ค'
+ 8: 0, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 1, # 'ฎ'
+ 57: 3, # 'ฏ'
+ 49: 1, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 1, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 0, # 'บ'
+ 25: 1, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 1, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 0, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 1, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 1, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 1, # 'า'
+ 36: 0, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 1, # 'เ'
+ 28: 2, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 3, # '็'
+ 6: 1, # '่'
+ 7: 2, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 39: { # 'ผ'
+ 5: 1, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 1, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 2, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 1, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 1, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 3, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 1, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 62: { # 'ฝ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 1, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 2, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 2, # '่'
+ 7: 1, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 31: { # 'พ'
+ 5: 1, # 'ก'
+ 30: 1, # 'ข'
+ 24: 1, # 'ค'
+ 8: 1, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 1, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 0, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 2, # 'ย'
+ 2: 3, # 'ร'
+ 61: 2, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 1, # 'ห'
+ 4: 2, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 1, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 1, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 1, # '็'
+ 6: 0, # '่'
+ 7: 1, # '้'
+ 38: 3, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 54: { # 'ฟ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 2, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 2, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 1, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 2, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 45: { # 'ภ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 1, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 9: { # 'ม'
+ 5: 2, # 'ก'
+ 30: 2, # 'ข'
+ 24: 2, # 'ค'
+ 8: 2, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 1, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 3, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 1, # 'ย'
+ 2: 2, # 'ร'
+ 61: 2, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 2, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 1, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 3, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 3, # 'ู'
+ 11: 2, # 'เ'
+ 28: 2, # 'แ'
+ 41: 2, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 1, # '็'
+ 6: 3, # '่'
+ 7: 2, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 16: { # 'ย'
+ 5: 3, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 3, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 2, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 2, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 1, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 0, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 3, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 1, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 1, # 'ึ'
+ 27: 2, # 'ื'
+ 32: 2, # 'ุ'
+ 35: 3, # 'ู'
+ 11: 2, # 'เ'
+ 28: 1, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 2, # 'ๆ'
+ 37: 1, # '็'
+ 6: 3, # '่'
+ 7: 2, # '้'
+ 38: 3, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 2: { # 'ร'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 2, # 'ค'
+ 8: 3, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 2, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 3, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 3, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 2, # 'ต'
+ 44: 3, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 2, # 'น'
+ 17: 2, # 'บ'
+ 25: 3, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 1, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 2, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 3, # 'ว'
+ 42: 2, # 'ศ'
+ 46: 2, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 3, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 3, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 2, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 3, # 'ู'
+ 11: 3, # 'เ'
+ 28: 3, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 3, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 3, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 61: { # 'ฤ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 2, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 2, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 15: { # 'ล'
+ 5: 2, # 'ก'
+ 30: 3, # 'ข'
+ 24: 1, # 'ค'
+ 8: 3, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 3, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 1, # 'ห'
+ 4: 3, # 'อ'
+ 63: 2, # 'ฯ'
+ 22: 3, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 2, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 2, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 2, # 'ุ'
+ 35: 3, # 'ู'
+ 11: 2, # 'เ'
+ 28: 1, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 2, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 12: { # 'ว'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 1, # 'ค'
+ 8: 3, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 1, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 1, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 1, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 3, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 2, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 2, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 42: { # 'ศ'
+ 5: 1, # 'ก'
+ 30: 0, # 'ข'
+ 24: 1, # 'ค'
+ 8: 0, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 1, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 2, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 2, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 2, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 3, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 2, # 'ู'
+ 11: 0, # 'เ'
+ 28: 1, # 'แ'
+ 41: 0, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 46: { # 'ษ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 2, # 'ฎ'
+ 57: 1, # 'ฏ'
+ 49: 2, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 3, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 2, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 2, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 1, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 18: { # 'ส'
+ 5: 2, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 2, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 3, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 1, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 2, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 1, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 2, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 3, # 'ำ'
+ 23: 3, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 2, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 3, # 'ู'
+ 11: 2, # 'เ'
+ 28: 0, # 'แ'
+ 41: 1, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 1, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 21: { # 'ห'
+ 5: 3, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 1, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 2, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 3, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 0, # 'บ'
+ 25: 1, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 2, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 0, # 'ำ'
+ 23: 1, # 'ิ'
+ 13: 1, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 1, # 'ุ'
+ 35: 1, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 3, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 4: { # 'อ'
+ 5: 3, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 3, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 1, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 3, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 2, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 2, # 'ะ'
+ 10: 3, # 'ั'
+ 1: 3, # 'า'
+ 36: 2, # 'ำ'
+ 23: 2, # 'ิ'
+ 13: 3, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 3, # 'ื'
+ 32: 3, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 1, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 1, # '็'
+ 6: 2, # '่'
+ 7: 2, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 63: { # 'ฯ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 22: { # 'ะ'
+ 5: 3, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 1, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 3, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 2, # 'น'
+ 17: 3, # 'บ'
+ 25: 2, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 2, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 3, # 'ห'
+ 4: 2, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 10: { # 'ั'
+ 5: 3, # 'ก'
+ 30: 0, # 'ข'
+ 24: 1, # 'ค'
+ 8: 3, # 'ง'
+ 26: 3, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 3, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 2, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 3, # 'ฒ'
+ 43: 3, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 1, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 3, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 3, # 'ว'
+ 42: 2, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 1: { # 'า'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 3, # 'ค'
+ 8: 3, # 'ง'
+ 26: 3, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 3, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 2, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 3, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 2, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 2, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 1, # 'ฝ'
+ 31: 3, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 3, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 3, # 'ว'
+ 42: 2, # 'ศ'
+ 46: 3, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 3, # 'ห'
+ 4: 2, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 3, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 36: { # 'ำ'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 3, # 'ค'
+ 8: 2, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 1, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 1, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 1, # 'บ'
+ 25: 1, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 0, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 3, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 23: { # 'ิ'
+ 5: 3, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 3, # 'ง'
+ 26: 3, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 3, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 2, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 3, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 2, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 3, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 2, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 2, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 3, # 'ว'
+ 42: 3, # 'ศ'
+ 46: 2, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 3, # 'ห'
+ 4: 1, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 1, # 'แ'
+ 41: 1, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 2, # '้'
+ 38: 2, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 13: { # 'ี'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 2, # 'ค'
+ 8: 0, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 1, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 3, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 2, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 1, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 2, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 40: { # 'ึ'
+ 5: 3, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 3, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 27: { # 'ื'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 3, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 32: { # 'ุ'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 3, # 'ค'
+ 8: 3, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 2, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 1, # 'ฒ'
+ 43: 3, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 2, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 1, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 1, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 2, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 1, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 1, # 'เ'
+ 28: 0, # 'แ'
+ 41: 1, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 2, # '้'
+ 38: 1, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 35: { # 'ู'
+ 5: 3, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 2, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 2, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 1, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 2, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 2, # 'น'
+ 17: 0, # 'บ'
+ 25: 3, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 0, # 'ย'
+ 2: 1, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 1, # 'เ'
+ 28: 1, # 'แ'
+ 41: 1, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 3, # '่'
+ 7: 3, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 11: { # 'เ'
+ 5: 3, # 'ก'
+ 30: 3, # 'ข'
+ 24: 3, # 'ค'
+ 8: 2, # 'ง'
+ 26: 3, # 'จ'
+ 52: 3, # 'ฉ'
+ 34: 3, # 'ช'
+ 51: 2, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 1, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 3, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 3, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 1, # 'ฝ'
+ 31: 3, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 3, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 2, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 3, # 'ว'
+ 42: 2, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 3, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 28: { # 'แ'
+ 5: 3, # 'ก'
+ 30: 2, # 'ข'
+ 24: 2, # 'ค'
+ 8: 1, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 3, # 'ต'
+ 44: 2, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 2, # 'ป'
+ 39: 3, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 2, # 'พ'
+ 54: 2, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 2, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 3, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 41: { # 'โ'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 0, # 'ง'
+ 26: 1, # 'จ'
+ 52: 1, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 2, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 1, # 'บ'
+ 25: 3, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 1, # 'ภ'
+ 9: 1, # 'ม'
+ 16: 2, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 3, # 'ล'
+ 12: 0, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 0, # 'ห'
+ 4: 2, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 29: { # 'ใ'
+ 5: 2, # 'ก'
+ 30: 0, # 'ข'
+ 24: 1, # 'ค'
+ 8: 0, # 'ง'
+ 26: 3, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 3, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 1, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 3, # 'ส'
+ 21: 3, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 33: { # 'ไ'
+ 5: 1, # 'ก'
+ 30: 2, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 3, # 'ด'
+ 19: 1, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 3, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 1, # 'บ'
+ 25: 3, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 2, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 0, # 'ย'
+ 2: 3, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 3, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 2, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 50: { # 'ๆ'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 37: { # '็'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 2, # 'ง'
+ 26: 3, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 1, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 2, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 3, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 1, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 2, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 0, # 'ห'
+ 4: 1, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 1, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 6: { # '่'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 3, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 1, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 1, # 'ธ'
+ 3: 3, # 'น'
+ 17: 1, # 'บ'
+ 25: 2, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 1, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 3, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 2, # 'ล'
+ 12: 3, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 1, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 1, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 3, # 'า'
+ 36: 2, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 3, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 1, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 7: { # '้'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 2, # 'ค'
+ 8: 3, # 'ง'
+ 26: 2, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 1, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 1, # 'ด'
+ 19: 2, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 2, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 3, # 'น'
+ 17: 2, # 'บ'
+ 25: 2, # 'ป'
+ 39: 2, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 3, # 'ม'
+ 16: 2, # 'ย'
+ 2: 2, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 3, # 'ว'
+ 42: 1, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 2, # 'ส'
+ 21: 2, # 'ห'
+ 4: 3, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 3, # 'า'
+ 36: 2, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 2, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 2, # 'ใ'
+ 33: 2, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 38: { # '์'
+ 5: 2, # 'ก'
+ 30: 1, # 'ข'
+ 24: 1, # 'ค'
+ 8: 0, # 'ง'
+ 26: 1, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 1, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 2, # 'ด'
+ 19: 1, # 'ต'
+ 44: 1, # 'ถ'
+ 14: 1, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 1, # 'น'
+ 17: 1, # 'บ'
+ 25: 1, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 1, # 'พ'
+ 54: 1, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 2, # 'ม'
+ 16: 0, # 'ย'
+ 2: 1, # 'ร'
+ 61: 1, # 'ฤ'
+ 15: 1, # 'ล'
+ 12: 1, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 1, # 'ส'
+ 21: 1, # 'ห'
+ 4: 2, # 'อ'
+ 63: 1, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 2, # 'เ'
+ 28: 2, # 'แ'
+ 41: 1, # 'โ'
+ 29: 1, # 'ใ'
+ 33: 1, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 0, # '๑'
+ 59: 0, # '๒'
+ 60: 0, # '๕'
+ },
+ 56: { # '๑'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 2, # '๑'
+ 59: 1, # '๒'
+ 60: 1, # '๕'
+ },
+ 59: { # '๒'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 1, # '๑'
+ 59: 1, # '๒'
+ 60: 3, # '๕'
+ },
+ 60: { # '๕'
+ 5: 0, # 'ก'
+ 30: 0, # 'ข'
+ 24: 0, # 'ค'
+ 8: 0, # 'ง'
+ 26: 0, # 'จ'
+ 52: 0, # 'ฉ'
+ 34: 0, # 'ช'
+ 51: 0, # 'ซ'
+ 47: 0, # 'ญ'
+ 58: 0, # 'ฎ'
+ 57: 0, # 'ฏ'
+ 49: 0, # 'ฐ'
+ 53: 0, # 'ฑ'
+ 55: 0, # 'ฒ'
+ 43: 0, # 'ณ'
+ 20: 0, # 'ด'
+ 19: 0, # 'ต'
+ 44: 0, # 'ถ'
+ 14: 0, # 'ท'
+ 48: 0, # 'ธ'
+ 3: 0, # 'น'
+ 17: 0, # 'บ'
+ 25: 0, # 'ป'
+ 39: 0, # 'ผ'
+ 62: 0, # 'ฝ'
+ 31: 0, # 'พ'
+ 54: 0, # 'ฟ'
+ 45: 0, # 'ภ'
+ 9: 0, # 'ม'
+ 16: 0, # 'ย'
+ 2: 0, # 'ร'
+ 61: 0, # 'ฤ'
+ 15: 0, # 'ล'
+ 12: 0, # 'ว'
+ 42: 0, # 'ศ'
+ 46: 0, # 'ษ'
+ 18: 0, # 'ส'
+ 21: 0, # 'ห'
+ 4: 0, # 'อ'
+ 63: 0, # 'ฯ'
+ 22: 0, # 'ะ'
+ 10: 0, # 'ั'
+ 1: 0, # 'า'
+ 36: 0, # 'ำ'
+ 23: 0, # 'ิ'
+ 13: 0, # 'ี'
+ 40: 0, # 'ึ'
+ 27: 0, # 'ื'
+ 32: 0, # 'ุ'
+ 35: 0, # 'ู'
+ 11: 0, # 'เ'
+ 28: 0, # 'แ'
+ 41: 0, # 'โ'
+ 29: 0, # 'ใ'
+ 33: 0, # 'ไ'
+ 50: 0, # 'ๆ'
+ 37: 0, # '็'
+ 6: 0, # '่'
+ 7: 0, # '้'
+ 38: 0, # '์'
+ 56: 2, # '๑'
+ 59: 1, # '๒'
+ 60: 0, # '๕'
+ },
+}
+
+# 255: Undefined characters that did not exist in training text
+# 254: Carriage/Return
+# 253: symbol (punctuation) that does not belong to word
+# 252: 0 - 9
+# 251: Control characters
+
+# Character Mapping Table(s):
+TIS_620_THAI_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 254, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 254, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 253, # ' '
+ 33: 253, # '!'
+ 34: 253, # '"'
+ 35: 253, # '#'
+ 36: 253, # '$'
+ 37: 253, # '%'
+ 38: 253, # '&'
+ 39: 253, # "'"
+ 40: 253, # '('
+ 41: 253, # ')'
+ 42: 253, # '*'
+ 43: 253, # '+'
+ 44: 253, # ','
+ 45: 253, # '-'
+ 46: 253, # '.'
+ 47: 253, # '/'
+ 48: 252, # '0'
+ 49: 252, # '1'
+ 50: 252, # '2'
+ 51: 252, # '3'
+ 52: 252, # '4'
+ 53: 252, # '5'
+ 54: 252, # '6'
+ 55: 252, # '7'
+ 56: 252, # '8'
+ 57: 252, # '9'
+ 58: 253, # ':'
+ 59: 253, # ';'
+ 60: 253, # '<'
+ 61: 253, # '='
+ 62: 253, # '>'
+ 63: 253, # '?'
+ 64: 253, # '@'
+ 65: 182, # 'A'
+ 66: 106, # 'B'
+ 67: 107, # 'C'
+ 68: 100, # 'D'
+ 69: 183, # 'E'
+ 70: 184, # 'F'
+ 71: 185, # 'G'
+ 72: 101, # 'H'
+ 73: 94, # 'I'
+ 74: 186, # 'J'
+ 75: 187, # 'K'
+ 76: 108, # 'L'
+ 77: 109, # 'M'
+ 78: 110, # 'N'
+ 79: 111, # 'O'
+ 80: 188, # 'P'
+ 81: 189, # 'Q'
+ 82: 190, # 'R'
+ 83: 89, # 'S'
+ 84: 95, # 'T'
+ 85: 112, # 'U'
+ 86: 113, # 'V'
+ 87: 191, # 'W'
+ 88: 192, # 'X'
+ 89: 193, # 'Y'
+ 90: 194, # 'Z'
+ 91: 253, # '['
+ 92: 253, # '\\'
+ 93: 253, # ']'
+ 94: 253, # '^'
+ 95: 253, # '_'
+ 96: 253, # '`'
+ 97: 64, # 'a'
+ 98: 72, # 'b'
+ 99: 73, # 'c'
+ 100: 114, # 'd'
+ 101: 74, # 'e'
+ 102: 115, # 'f'
+ 103: 116, # 'g'
+ 104: 102, # 'h'
+ 105: 81, # 'i'
+ 106: 201, # 'j'
+ 107: 117, # 'k'
+ 108: 90, # 'l'
+ 109: 103, # 'm'
+ 110: 78, # 'n'
+ 111: 82, # 'o'
+ 112: 96, # 'p'
+ 113: 202, # 'q'
+ 114: 91, # 'r'
+ 115: 79, # 's'
+ 116: 84, # 't'
+ 117: 104, # 'u'
+ 118: 105, # 'v'
+ 119: 97, # 'w'
+ 120: 98, # 'x'
+ 121: 92, # 'y'
+ 122: 203, # 'z'
+ 123: 253, # '{'
+ 124: 253, # '|'
+ 125: 253, # '}'
+ 126: 253, # '~'
+ 127: 253, # '\x7f'
+ 128: 209, # '\x80'
+ 129: 210, # '\x81'
+ 130: 211, # '\x82'
+ 131: 212, # '\x83'
+ 132: 213, # '\x84'
+ 133: 88, # '\x85'
+ 134: 214, # '\x86'
+ 135: 215, # '\x87'
+ 136: 216, # '\x88'
+ 137: 217, # '\x89'
+ 138: 218, # '\x8a'
+ 139: 219, # '\x8b'
+ 140: 220, # '\x8c'
+ 141: 118, # '\x8d'
+ 142: 221, # '\x8e'
+ 143: 222, # '\x8f'
+ 144: 223, # '\x90'
+ 145: 224, # '\x91'
+ 146: 99, # '\x92'
+ 147: 85, # '\x93'
+ 148: 83, # '\x94'
+ 149: 225, # '\x95'
+ 150: 226, # '\x96'
+ 151: 227, # '\x97'
+ 152: 228, # '\x98'
+ 153: 229, # '\x99'
+ 154: 230, # '\x9a'
+ 155: 231, # '\x9b'
+ 156: 232, # '\x9c'
+ 157: 233, # '\x9d'
+ 158: 234, # '\x9e'
+ 159: 235, # '\x9f'
+ 160: 236, # None
+ 161: 5, # 'ก'
+ 162: 30, # 'ข'
+ 163: 237, # 'ฃ'
+ 164: 24, # 'ค'
+ 165: 238, # 'ฅ'
+ 166: 75, # 'ฆ'
+ 167: 8, # 'ง'
+ 168: 26, # 'จ'
+ 169: 52, # 'ฉ'
+ 170: 34, # 'ช'
+ 171: 51, # 'ซ'
+ 172: 119, # 'ฌ'
+ 173: 47, # 'ญ'
+ 174: 58, # 'ฎ'
+ 175: 57, # 'ฏ'
+ 176: 49, # 'ฐ'
+ 177: 53, # 'ฑ'
+ 178: 55, # 'ฒ'
+ 179: 43, # 'ณ'
+ 180: 20, # 'ด'
+ 181: 19, # 'ต'
+ 182: 44, # 'ถ'
+ 183: 14, # 'ท'
+ 184: 48, # 'ธ'
+ 185: 3, # 'น'
+ 186: 17, # 'บ'
+ 187: 25, # 'ป'
+ 188: 39, # 'ผ'
+ 189: 62, # 'ฝ'
+ 190: 31, # 'พ'
+ 191: 54, # 'ฟ'
+ 192: 45, # 'ภ'
+ 193: 9, # 'ม'
+ 194: 16, # 'ย'
+ 195: 2, # 'ร'
+ 196: 61, # 'ฤ'
+ 197: 15, # 'ล'
+ 198: 239, # 'ฦ'
+ 199: 12, # 'ว'
+ 200: 42, # 'ศ'
+ 201: 46, # 'ษ'
+ 202: 18, # 'ส'
+ 203: 21, # 'ห'
+ 204: 76, # 'ฬ'
+ 205: 4, # 'อ'
+ 206: 66, # 'ฮ'
+ 207: 63, # 'ฯ'
+ 208: 22, # 'ะ'
+ 209: 10, # 'ั'
+ 210: 1, # 'า'
+ 211: 36, # 'ำ'
+ 212: 23, # 'ิ'
+ 213: 13, # 'ี'
+ 214: 40, # 'ึ'
+ 215: 27, # 'ื'
+ 216: 32, # 'ุ'
+ 217: 35, # 'ู'
+ 218: 86, # 'ฺ'
+ 219: 240, # None
+ 220: 241, # None
+ 221: 242, # None
+ 222: 243, # None
+ 223: 244, # '฿'
+ 224: 11, # 'เ'
+ 225: 28, # 'แ'
+ 226: 41, # 'โ'
+ 227: 29, # 'ใ'
+ 228: 33, # 'ไ'
+ 229: 245, # 'ๅ'
+ 230: 50, # 'ๆ'
+ 231: 37, # '็'
+ 232: 6, # '่'
+ 233: 7, # '้'
+ 234: 67, # '๊'
+ 235: 77, # '๋'
+ 236: 38, # '์'
+ 237: 93, # 'ํ'
+ 238: 246, # '๎'
+ 239: 247, # '๏'
+ 240: 68, # '๐'
+ 241: 56, # '๑'
+ 242: 59, # '๒'
+ 243: 65, # '๓'
+ 244: 69, # '๔'
+ 245: 60, # '๕'
+ 246: 70, # '๖'
+ 247: 80, # '๗'
+ 248: 71, # '๘'
+ 249: 87, # '๙'
+ 250: 248, # '๚'
+ 251: 249, # '๛'
+ 252: 250, # None
+ 253: 251, # None
+ 254: 252, # None
+ 255: 253, # None
+}
+
+TIS_620_THAI_MODEL = SingleByteCharSetModel(
+ charset_name="TIS-620",
+ language="Thai",
+ char_to_order_map=TIS_620_THAI_CHAR_TO_ORDER,
+ language_model=THAI_LANG_MODEL,
+ typical_positive_ratio=0.926386,
+ keep_ascii_letters=False,
+ alphabet="กขฃคฅฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลฦวศษสหฬอฮฯะัาำิีึืฺุู฿เแโใไๅๆ็่้๊๋์ํ๎๏๐๑๒๓๔๕๖๗๘๙๚๛",
+)
diff --git a/lib/chardet/langturkishmodel.py b/lib/chardet/langturkishmodel.py
new file mode 100644
index 0000000..64c9433
--- /dev/null
+++ b/lib/chardet/langturkishmodel.py
@@ -0,0 +1,4380 @@
+from chardet.sbcharsetprober import SingleByteCharSetModel
+
+# 3: Positive
+# 2: Likely
+# 1: Unlikely
+# 0: Negative
+
+TURKISH_LANG_MODEL = {
+ 23: { # 'A'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 1, # 'h'
+ 3: 1, # 'i'
+ 24: 0, # 'j'
+ 10: 2, # 'k'
+ 5: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 1, # 'r'
+ 8: 1, # 's'
+ 9: 1, # 't'
+ 14: 1, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 0, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 37: { # 'B'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 2, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 1, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 1, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 0, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 47: { # 'C'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 1, # 'L'
+ 20: 0, # 'M'
+ 46: 1, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 1, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 2, # 'j'
+ 10: 1, # 'k'
+ 5: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 2, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 2, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 39: { # 'D'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 1, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 1, # 'l'
+ 13: 3, # 'm'
+ 4: 0, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 1, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 1, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 29: { # 'E'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 1, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 0, # 'h'
+ 3: 1, # 'i'
+ 24: 1, # 'j'
+ 10: 0, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 1, # 's'
+ 9: 1, # 't'
+ 14: 1, # 'u'
+ 32: 1, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 52: { # 'F'
+ 23: 0, # 'A'
+ 37: 1, # 'B'
+ 47: 1, # 'C'
+ 39: 1, # 'D'
+ 29: 1, # 'E'
+ 52: 2, # 'F'
+ 36: 0, # 'G'
+ 45: 2, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 1, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 1, # 'b'
+ 28: 1, # 'c'
+ 12: 1, # 'd'
+ 2: 0, # 'e'
+ 18: 1, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 2, # 'i'
+ 24: 1, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 2, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 2, # 'r'
+ 8: 1, # 's'
+ 9: 1, # 't'
+ 14: 1, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 1, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 2, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 2, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 2, # 'ş'
+ },
+ 36: { # 'G'
+ 23: 1, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 2, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 2, # 'N'
+ 42: 1, # 'O'
+ 48: 1, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 1, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 1, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 0, # 'r'
+ 8: 1, # 's'
+ 9: 1, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 1, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 2, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 45: { # 'H'
+ 23: 0, # 'A'
+ 37: 1, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 2, # 'G'
+ 45: 1, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 1, # 'L'
+ 20: 0, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 2, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 2, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 15: 1, # 'o'
+ 26: 1, # 'p'
+ 7: 1, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 2, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 0, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 53: { # 'I'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 0, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 60: { # 'J'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 1, # 'd'
+ 2: 0, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 1, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 1, # 's'
+ 9: 0, # 't'
+ 14: 0, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 0, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 16: { # 'K'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 3, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 1, # 'e'
+ 18: 3, # 'f'
+ 27: 3, # 'g'
+ 25: 3, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 0, # 'u'
+ 32: 3, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 2, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 49: { # 'L'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 2, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 2, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 0, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 2, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 2, # 'n'
+ 15: 1, # 'o'
+ 26: 1, # 'p'
+ 7: 1, # 'r'
+ 8: 1, # 's'
+ 9: 1, # 't'
+ 14: 0, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 2, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 1, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 20: { # 'M'
+ 23: 1, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 1, # 'h'
+ 3: 2, # 'i'
+ 24: 2, # 'j'
+ 10: 2, # 'k'
+ 5: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 3, # 'r'
+ 8: 0, # 's'
+ 9: 2, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 3, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 46: { # 'N'
+ 23: 0, # 'A'
+ 37: 1, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 1, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 2, # 'j'
+ 10: 1, # 'k'
+ 5: 1, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 1, # 'o'
+ 26: 1, # 'p'
+ 7: 1, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 1, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 2, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 42: { # 'O'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 0, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 1, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 0, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 2, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 2, # 'İ'
+ 6: 1, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 48: { # 'P'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 2, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 1, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 15: 2, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 2, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 2, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 0, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 44: { # 'R'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 1, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 1, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 1, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 35: { # 'S'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 1, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 1, # 'l'
+ 13: 2, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 1, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 2, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 3, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 31: { # 'T'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 0, # 'c'
+ 12: 1, # 'd'
+ 2: 3, # 'e'
+ 18: 2, # 'f'
+ 27: 2, # 'g'
+ 25: 0, # 'h'
+ 3: 1, # 'i'
+ 24: 1, # 'j'
+ 10: 2, # 'k'
+ 5: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 2, # 'p'
+ 7: 2, # 'r'
+ 8: 0, # 's'
+ 9: 2, # 't'
+ 14: 2, # 'u'
+ 32: 1, # 'v'
+ 57: 1, # 'w'
+ 58: 1, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 51: { # 'U'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 1, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 1, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 1, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 1, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 2, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 1, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 38: { # 'V'
+ 23: 1, # 'A'
+ 37: 1, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 1, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 2, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 15: 2, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 1, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 1, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 1, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 3, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 62: { # 'W'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 0, # 'd'
+ 2: 0, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 0, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 0, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 0, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 43: { # 'Y'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 0, # 'G'
+ 45: 1, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 2, # 'N'
+ 42: 0, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 1, # 'j'
+ 10: 1, # 'k'
+ 5: 1, # 'l'
+ 13: 3, # 'm'
+ 4: 0, # 'n'
+ 15: 2, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 1, # 'Ü'
+ 59: 1, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 0, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 56: { # 'Z'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 2, # 'Z'
+ 1: 2, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 2, # 'i'
+ 24: 1, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 1, # 'r'
+ 8: 1, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 1, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 1: { # 'a'
+ 23: 3, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 3, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 1, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 3, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 2, # 'Z'
+ 1: 2, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 2, # 'e'
+ 18: 3, # 'f'
+ 27: 3, # 'g'
+ 25: 3, # 'h'
+ 3: 3, # 'i'
+ 24: 3, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 3, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 3, # 'v'
+ 57: 2, # 'w'
+ 58: 0, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 1, # 'î'
+ 34: 1, # 'ö'
+ 17: 3, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 21: { # 'b'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 3, # 'g'
+ 25: 1, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 3, # 'p'
+ 7: 1, # 'r'
+ 8: 2, # 's'
+ 9: 2, # 't'
+ 14: 2, # 'u'
+ 32: 1, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 28: { # 'c'
+ 23: 0, # 'A'
+ 37: 1, # 'B'
+ 47: 1, # 'C'
+ 39: 1, # 'D'
+ 29: 2, # 'E'
+ 52: 0, # 'F'
+ 36: 2, # 'G'
+ 45: 2, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 2, # 'T'
+ 51: 2, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 3, # 'Y'
+ 56: 0, # 'Z'
+ 1: 1, # 'a'
+ 21: 1, # 'b'
+ 28: 2, # 'c'
+ 12: 2, # 'd'
+ 2: 1, # 'e'
+ 18: 1, # 'f'
+ 27: 2, # 'g'
+ 25: 2, # 'h'
+ 3: 3, # 'i'
+ 24: 1, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 15: 2, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 1, # 'u'
+ 32: 0, # 'v'
+ 57: 1, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 1, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 1, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 1, # 'î'
+ 34: 2, # 'ö'
+ 17: 2, # 'ü'
+ 30: 2, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 2, # 'ş'
+ },
+ 12: { # 'd'
+ 23: 1, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 2, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 1, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 1, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 1, # 'f'
+ 27: 3, # 'g'
+ 25: 3, # 'h'
+ 3: 2, # 'i'
+ 24: 3, # 'j'
+ 10: 2, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 2, # 's'
+ 9: 2, # 't'
+ 14: 3, # 'u'
+ 32: 1, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 3, # 'y'
+ 22: 1, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 2: { # 'e'
+ 23: 2, # 'A'
+ 37: 0, # 'B'
+ 47: 2, # 'C'
+ 39: 0, # 'D'
+ 29: 3, # 'E'
+ 52: 1, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 1, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 1, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 1, # 'R'
+ 35: 0, # 'S'
+ 31: 3, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 2, # 'e'
+ 18: 3, # 'f'
+ 27: 3, # 'g'
+ 25: 3, # 'h'
+ 3: 3, # 'i'
+ 24: 3, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 3, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 3, # 'v'
+ 57: 2, # 'w'
+ 58: 0, # 'x'
+ 11: 3, # 'y'
+ 22: 1, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 3, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 18: { # 'f'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 2, # 'f'
+ 27: 1, # 'g'
+ 25: 1, # 'h'
+ 3: 1, # 'i'
+ 24: 1, # 'j'
+ 10: 1, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 2, # 'p'
+ 7: 1, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 1, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 1, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 1, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 27: { # 'g'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 1, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 1, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 2, # 'g'
+ 25: 1, # 'h'
+ 3: 2, # 'i'
+ 24: 3, # 'j'
+ 10: 2, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 2, # 'r'
+ 8: 2, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 1, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 1, # 'y'
+ 22: 0, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 25: { # 'h'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 2, # 'h'
+ 3: 2, # 'i'
+ 24: 3, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 1, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 2, # 't'
+ 14: 3, # 'u'
+ 32: 2, # 'v'
+ 57: 1, # 'w'
+ 58: 0, # 'x'
+ 11: 1, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 3: { # 'i'
+ 23: 2, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 0, # 'N'
+ 42: 1, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 1, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 2, # 'f'
+ 27: 3, # 'g'
+ 25: 1, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 3, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 2, # 'v'
+ 57: 1, # 'w'
+ 58: 1, # 'x'
+ 11: 3, # 'y'
+ 22: 1, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 1, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 3, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 24: { # 'j'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 2, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 1, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 2, # 'f'
+ 27: 1, # 'g'
+ 25: 1, # 'h'
+ 3: 2, # 'i'
+ 24: 1, # 'j'
+ 10: 2, # 'k'
+ 5: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 2, # 'r'
+ 8: 3, # 's'
+ 9: 2, # 't'
+ 14: 3, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 2, # 'x'
+ 11: 1, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 10: { # 'k'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 3, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 3, # 'e'
+ 18: 1, # 'f'
+ 27: 2, # 'g'
+ 25: 2, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 2, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 3, # 'p'
+ 7: 2, # 'r'
+ 8: 2, # 's'
+ 9: 2, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 3, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 3, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 5: { # 'l'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 3, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 1, # 'e'
+ 18: 3, # 'f'
+ 27: 3, # 'g'
+ 25: 2, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 2, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 2, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 13: { # 'm'
+ 23: 1, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 3, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 3, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 2, # 'e'
+ 18: 3, # 'f'
+ 27: 3, # 'g'
+ 25: 3, # 'h'
+ 3: 3, # 'i'
+ 24: 3, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 2, # 'u'
+ 32: 2, # 'v'
+ 57: 1, # 'w'
+ 58: 0, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 3, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 4: { # 'n'
+ 23: 1, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 2, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 1, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 1, # 'f'
+ 27: 2, # 'g'
+ 25: 3, # 'h'
+ 3: 2, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 3, # 'p'
+ 7: 2, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 2, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 2, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 1, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 15: { # 'o'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 2, # 'L'
+ 20: 0, # 'M'
+ 46: 2, # 'N'
+ 42: 1, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 1, # 'i'
+ 24: 2, # 'j'
+ 10: 1, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 2, # 'o'
+ 26: 0, # 'p'
+ 7: 1, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 2, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 3, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 2, # 'ğ'
+ 41: 2, # 'İ'
+ 6: 3, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 2, # 'ş'
+ },
+ 26: { # 'p'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 1, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 1, # 'h'
+ 3: 2, # 'i'
+ 24: 3, # 'j'
+ 10: 1, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 0, # 'o'
+ 26: 2, # 'p'
+ 7: 2, # 'r'
+ 8: 1, # 's'
+ 9: 1, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 1, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 3, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 7: { # 'r'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 1, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 2, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 1, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 2, # 'g'
+ 25: 3, # 'h'
+ 3: 2, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 3, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 8: { # 's'
+ 23: 1, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 1, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 2, # 'g'
+ 25: 2, # 'h'
+ 3: 2, # 'i'
+ 24: 3, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 3, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 2, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 2, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 9: { # 't'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 2, # 'f'
+ 27: 2, # 'g'
+ 25: 2, # 'h'
+ 3: 2, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 3, # 'v'
+ 57: 0, # 'w'
+ 58: 2, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 3, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 2, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 14: { # 'u'
+ 23: 3, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 3, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 2, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 3, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 2, # 'Z'
+ 1: 2, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 2, # 'e'
+ 18: 2, # 'f'
+ 27: 3, # 'g'
+ 25: 3, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 3, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 2, # 'v'
+ 57: 2, # 'w'
+ 58: 0, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 3, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 32: { # 'v'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 1, # 'j'
+ 10: 1, # 'k'
+ 5: 3, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 1, # 'r'
+ 8: 2, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 1, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 1, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 57: { # 'w'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 1, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 1, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 1, # 's'
+ 9: 0, # 't'
+ 14: 1, # 'u'
+ 32: 0, # 'v'
+ 57: 2, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 0, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 58: { # 'x'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 1, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 1, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 2, # 'i'
+ 24: 2, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 2, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 1, # 'r'
+ 8: 2, # 's'
+ 9: 1, # 't'
+ 14: 0, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 11: { # 'y'
+ 23: 1, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 2, # 'g'
+ 25: 2, # 'h'
+ 3: 2, # 'i'
+ 24: 1, # 'j'
+ 10: 2, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 2, # 'r'
+ 8: 1, # 's'
+ 9: 2, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 1, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 3, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 2, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 22: { # 'z'
+ 23: 2, # 'A'
+ 37: 2, # 'B'
+ 47: 1, # 'C'
+ 39: 2, # 'D'
+ 29: 3, # 'E'
+ 52: 1, # 'F'
+ 36: 2, # 'G'
+ 45: 2, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 2, # 'N'
+ 42: 2, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 3, # 'T'
+ 51: 2, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 1, # 'Z'
+ 1: 1, # 'a'
+ 21: 2, # 'b'
+ 28: 1, # 'c'
+ 12: 2, # 'd'
+ 2: 2, # 'e'
+ 18: 3, # 'f'
+ 27: 2, # 'g'
+ 25: 2, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 15: 2, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 0, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 3, # 'y'
+ 22: 2, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 2, # 'Ü'
+ 59: 1, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 2, # 'ö'
+ 17: 2, # 'ü'
+ 30: 2, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 3, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 2, # 'ş'
+ },
+ 63: { # '·'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 0, # 'd'
+ 2: 1, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 0, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 54: { # 'Ç'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 1, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 1, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 1, # 'b'
+ 28: 0, # 'c'
+ 12: 1, # 'd'
+ 2: 0, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 0, # 'h'
+ 3: 3, # 'i'
+ 24: 0, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 2, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 2, # 'r'
+ 8: 0, # 's'
+ 9: 1, # 't'
+ 14: 0, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 2, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 50: { # 'Ö'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 1, # 'D'
+ 29: 2, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 2, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 1, # 'N'
+ 42: 2, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 2, # 'b'
+ 28: 1, # 'c'
+ 12: 2, # 'd'
+ 2: 0, # 'e'
+ 18: 1, # 'f'
+ 27: 1, # 'g'
+ 25: 1, # 'h'
+ 3: 2, # 'i'
+ 24: 0, # 'j'
+ 10: 2, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 3, # 'n'
+ 15: 2, # 'o'
+ 26: 2, # 'p'
+ 7: 3, # 'r'
+ 8: 1, # 's'
+ 9: 2, # 't'
+ 14: 0, # 'u'
+ 32: 1, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 2, # 'ö'
+ 17: 2, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 55: { # 'Ü'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 1, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 1, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 1, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 1, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 1, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 1, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 0, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 59: { # 'â'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 1, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 2, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 0, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 2, # 'm'
+ 4: 0, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 2, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 1, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 33: { # 'ç'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 3, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 0, # 'Z'
+ 1: 0, # 'a'
+ 21: 3, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 0, # 'e'
+ 18: 2, # 'f'
+ 27: 1, # 'g'
+ 25: 3, # 'h'
+ 3: 3, # 'i'
+ 24: 0, # 'j'
+ 10: 3, # 'k'
+ 5: 0, # 'l'
+ 13: 0, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 3, # 'r'
+ 8: 2, # 's'
+ 9: 3, # 't'
+ 14: 0, # 'u'
+ 32: 2, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 1, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 61: { # 'î'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 0, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 0, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 2, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 1, # 'j'
+ 10: 0, # 'k'
+ 5: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 1, # 'n'
+ 15: 0, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 1, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 1, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 1, # 'î'
+ 34: 0, # 'ö'
+ 17: 0, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 1, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 34: { # 'ö'
+ 23: 0, # 'A'
+ 37: 1, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 1, # 'G'
+ 45: 1, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 1, # 'L'
+ 20: 0, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 2, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 1, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 2, # 'c'
+ 12: 1, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 2, # 'g'
+ 25: 2, # 'h'
+ 3: 1, # 'i'
+ 24: 2, # 'j'
+ 10: 1, # 'k'
+ 5: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 2, # 'n'
+ 15: 2, # 'o'
+ 26: 0, # 'p'
+ 7: 0, # 'r'
+ 8: 3, # 's'
+ 9: 1, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 1, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 2, # 'ö'
+ 17: 0, # 'ü'
+ 30: 2, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 1, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 17: { # 'ü'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 0, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 1, # 'J'
+ 16: 1, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 0, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 0, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 0, # 'c'
+ 12: 1, # 'd'
+ 2: 3, # 'e'
+ 18: 1, # 'f'
+ 27: 2, # 'g'
+ 25: 0, # 'h'
+ 3: 1, # 'i'
+ 24: 1, # 'j'
+ 10: 2, # 'k'
+ 5: 3, # 'l'
+ 13: 2, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 2, # 'p'
+ 7: 2, # 'r'
+ 8: 3, # 's'
+ 9: 2, # 't'
+ 14: 3, # 'u'
+ 32: 1, # 'v'
+ 57: 1, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 2, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 30: { # 'ğ'
+ 23: 0, # 'A'
+ 37: 2, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 1, # 'G'
+ 45: 0, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 1, # 'M'
+ 46: 2, # 'N'
+ 42: 2, # 'O'
+ 48: 1, # 'P'
+ 44: 1, # 'R'
+ 35: 0, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 2, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 0, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 2, # 'e'
+ 18: 0, # 'f'
+ 27: 0, # 'g'
+ 25: 0, # 'h'
+ 3: 0, # 'i'
+ 24: 3, # 'j'
+ 10: 1, # 'k'
+ 5: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 0, # 'n'
+ 15: 1, # 'o'
+ 26: 0, # 'p'
+ 7: 1, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 2, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 0, # 'î'
+ 34: 2, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 2, # 'İ'
+ 6: 2, # 'ı'
+ 40: 2, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 41: { # 'İ'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 1, # 'D'
+ 29: 1, # 'E'
+ 52: 0, # 'F'
+ 36: 2, # 'G'
+ 45: 2, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 2, # 'P'
+ 44: 0, # 'R'
+ 35: 1, # 'S'
+ 31: 1, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 0, # 'Z'
+ 1: 1, # 'a'
+ 21: 2, # 'b'
+ 28: 1, # 'c'
+ 12: 2, # 'd'
+ 2: 1, # 'e'
+ 18: 0, # 'f'
+ 27: 3, # 'g'
+ 25: 2, # 'h'
+ 3: 2, # 'i'
+ 24: 2, # 'j'
+ 10: 2, # 'k'
+ 5: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 3, # 'n'
+ 15: 1, # 'o'
+ 26: 1, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 2, # 't'
+ 14: 0, # 'u'
+ 32: 0, # 'v'
+ 57: 1, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 1, # 'Ü'
+ 59: 1, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 1, # 'ö'
+ 17: 1, # 'ü'
+ 30: 2, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 1, # 'ş'
+ },
+ 6: { # 'ı'
+ 23: 2, # 'A'
+ 37: 0, # 'B'
+ 47: 0, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 2, # 'J'
+ 16: 3, # 'K'
+ 49: 0, # 'L'
+ 20: 3, # 'M'
+ 46: 1, # 'N'
+ 42: 0, # 'O'
+ 48: 0, # 'P'
+ 44: 0, # 'R'
+ 35: 0, # 'S'
+ 31: 2, # 'T'
+ 51: 0, # 'U'
+ 38: 0, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 1, # 'Z'
+ 1: 3, # 'a'
+ 21: 2, # 'b'
+ 28: 1, # 'c'
+ 12: 3, # 'd'
+ 2: 3, # 'e'
+ 18: 3, # 'f'
+ 27: 3, # 'g'
+ 25: 2, # 'h'
+ 3: 3, # 'i'
+ 24: 3, # 'j'
+ 10: 3, # 'k'
+ 5: 3, # 'l'
+ 13: 3, # 'm'
+ 4: 3, # 'n'
+ 15: 0, # 'o'
+ 26: 3, # 'p'
+ 7: 3, # 'r'
+ 8: 3, # 's'
+ 9: 3, # 't'
+ 14: 3, # 'u'
+ 32: 3, # 'v'
+ 57: 1, # 'w'
+ 58: 1, # 'x'
+ 11: 3, # 'y'
+ 22: 0, # 'z'
+ 63: 1, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 2, # 'ç'
+ 61: 0, # 'î'
+ 34: 0, # 'ö'
+ 17: 3, # 'ü'
+ 30: 0, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 3, # 'ı'
+ 40: 0, # 'Ş'
+ 19: 0, # 'ş'
+ },
+ 40: { # 'Ş'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 1, # 'D'
+ 29: 1, # 'E'
+ 52: 0, # 'F'
+ 36: 1, # 'G'
+ 45: 2, # 'H'
+ 53: 1, # 'I'
+ 60: 0, # 'J'
+ 16: 0, # 'K'
+ 49: 0, # 'L'
+ 20: 2, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 2, # 'P'
+ 44: 2, # 'R'
+ 35: 1, # 'S'
+ 31: 1, # 'T'
+ 51: 0, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 2, # 'Y'
+ 56: 1, # 'Z'
+ 1: 0, # 'a'
+ 21: 2, # 'b'
+ 28: 0, # 'c'
+ 12: 2, # 'd'
+ 2: 0, # 'e'
+ 18: 3, # 'f'
+ 27: 0, # 'g'
+ 25: 2, # 'h'
+ 3: 3, # 'i'
+ 24: 2, # 'j'
+ 10: 1, # 'k'
+ 5: 0, # 'l'
+ 13: 1, # 'm'
+ 4: 3, # 'n'
+ 15: 2, # 'o'
+ 26: 0, # 'p'
+ 7: 3, # 'r'
+ 8: 2, # 's'
+ 9: 2, # 't'
+ 14: 1, # 'u'
+ 32: 3, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 2, # 'y'
+ 22: 0, # 'z'
+ 63: 0, # '·'
+ 54: 0, # 'Ç'
+ 50: 0, # 'Ö'
+ 55: 1, # 'Ü'
+ 59: 0, # 'â'
+ 33: 0, # 'ç'
+ 61: 0, # 'î'
+ 34: 2, # 'ö'
+ 17: 1, # 'ü'
+ 30: 2, # 'ğ'
+ 41: 0, # 'İ'
+ 6: 2, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 2, # 'ş'
+ },
+ 19: { # 'ş'
+ 23: 0, # 'A'
+ 37: 0, # 'B'
+ 47: 1, # 'C'
+ 39: 0, # 'D'
+ 29: 0, # 'E'
+ 52: 2, # 'F'
+ 36: 1, # 'G'
+ 45: 0, # 'H'
+ 53: 0, # 'I'
+ 60: 0, # 'J'
+ 16: 3, # 'K'
+ 49: 2, # 'L'
+ 20: 0, # 'M'
+ 46: 1, # 'N'
+ 42: 1, # 'O'
+ 48: 1, # 'P'
+ 44: 1, # 'R'
+ 35: 1, # 'S'
+ 31: 0, # 'T'
+ 51: 1, # 'U'
+ 38: 1, # 'V'
+ 62: 0, # 'W'
+ 43: 1, # 'Y'
+ 56: 0, # 'Z'
+ 1: 3, # 'a'
+ 21: 1, # 'b'
+ 28: 2, # 'c'
+ 12: 0, # 'd'
+ 2: 3, # 'e'
+ 18: 0, # 'f'
+ 27: 2, # 'g'
+ 25: 1, # 'h'
+ 3: 1, # 'i'
+ 24: 0, # 'j'
+ 10: 2, # 'k'
+ 5: 2, # 'l'
+ 13: 3, # 'm'
+ 4: 0, # 'n'
+ 15: 0, # 'o'
+ 26: 1, # 'p'
+ 7: 3, # 'r'
+ 8: 0, # 's'
+ 9: 0, # 't'
+ 14: 3, # 'u'
+ 32: 0, # 'v'
+ 57: 0, # 'w'
+ 58: 0, # 'x'
+ 11: 0, # 'y'
+ 22: 2, # 'z'
+ 63: 0, # '·'
+ 54: 1, # 'Ç'
+ 50: 2, # 'Ö'
+ 55: 0, # 'Ü'
+ 59: 0, # 'â'
+ 33: 1, # 'ç'
+ 61: 1, # 'î'
+ 34: 2, # 'ö'
+ 17: 0, # 'ü'
+ 30: 1, # 'ğ'
+ 41: 1, # 'İ'
+ 6: 1, # 'ı'
+ 40: 1, # 'Ş'
+ 19: 1, # 'ş'
+ },
+}
+
+# 255: Undefined characters that did not exist in training text
+# 254: Carriage Return
+# 253: symbol (punctuation) that does not belong to word
+# 252: 0 - 9
+# 251: Control characters
+
+# Character Mapping Table(s):
+ISO_8859_9_TURKISH_CHAR_TO_ORDER = {
+ 0: 255, # '\x00'
+ 1: 255, # '\x01'
+ 2: 255, # '\x02'
+ 3: 255, # '\x03'
+ 4: 255, # '\x04'
+ 5: 255, # '\x05'
+ 6: 255, # '\x06'
+ 7: 255, # '\x07'
+ 8: 255, # '\x08'
+ 9: 255, # '\t'
+ 10: 255, # '\n'
+ 11: 255, # '\x0b'
+ 12: 255, # '\x0c'
+ 13: 255, # '\r'
+ 14: 255, # '\x0e'
+ 15: 255, # '\x0f'
+ 16: 255, # '\x10'
+ 17: 255, # '\x11'
+ 18: 255, # '\x12'
+ 19: 255, # '\x13'
+ 20: 255, # '\x14'
+ 21: 255, # '\x15'
+ 22: 255, # '\x16'
+ 23: 255, # '\x17'
+ 24: 255, # '\x18'
+ 25: 255, # '\x19'
+ 26: 255, # '\x1a'
+ 27: 255, # '\x1b'
+ 28: 255, # '\x1c'
+ 29: 255, # '\x1d'
+ 30: 255, # '\x1e'
+ 31: 255, # '\x1f'
+ 32: 255, # ' '
+ 33: 255, # '!'
+ 34: 255, # '"'
+ 35: 255, # '#'
+ 36: 255, # '$'
+ 37: 255, # '%'
+ 38: 255, # '&'
+ 39: 255, # "'"
+ 40: 255, # '('
+ 41: 255, # ')'
+ 42: 255, # '*'
+ 43: 255, # '+'
+ 44: 255, # ','
+ 45: 255, # '-'
+ 46: 255, # '.'
+ 47: 255, # '/'
+ 48: 255, # '0'
+ 49: 255, # '1'
+ 50: 255, # '2'
+ 51: 255, # '3'
+ 52: 255, # '4'
+ 53: 255, # '5'
+ 54: 255, # '6'
+ 55: 255, # '7'
+ 56: 255, # '8'
+ 57: 255, # '9'
+ 58: 255, # ':'
+ 59: 255, # ';'
+ 60: 255, # '<'
+ 61: 255, # '='
+ 62: 255, # '>'
+ 63: 255, # '?'
+ 64: 255, # '@'
+ 65: 23, # 'A'
+ 66: 37, # 'B'
+ 67: 47, # 'C'
+ 68: 39, # 'D'
+ 69: 29, # 'E'
+ 70: 52, # 'F'
+ 71: 36, # 'G'
+ 72: 45, # 'H'
+ 73: 53, # 'I'
+ 74: 60, # 'J'
+ 75: 16, # 'K'
+ 76: 49, # 'L'
+ 77: 20, # 'M'
+ 78: 46, # 'N'
+ 79: 42, # 'O'
+ 80: 48, # 'P'
+ 81: 69, # 'Q'
+ 82: 44, # 'R'
+ 83: 35, # 'S'
+ 84: 31, # 'T'
+ 85: 51, # 'U'
+ 86: 38, # 'V'
+ 87: 62, # 'W'
+ 88: 65, # 'X'
+ 89: 43, # 'Y'
+ 90: 56, # 'Z'
+ 91: 255, # '['
+ 92: 255, # '\\'
+ 93: 255, # ']'
+ 94: 255, # '^'
+ 95: 255, # '_'
+ 96: 255, # '`'
+ 97: 1, # 'a'
+ 98: 21, # 'b'
+ 99: 28, # 'c'
+ 100: 12, # 'd'
+ 101: 2, # 'e'
+ 102: 18, # 'f'
+ 103: 27, # 'g'
+ 104: 25, # 'h'
+ 105: 3, # 'i'
+ 106: 24, # 'j'
+ 107: 10, # 'k'
+ 108: 5, # 'l'
+ 109: 13, # 'm'
+ 110: 4, # 'n'
+ 111: 15, # 'o'
+ 112: 26, # 'p'
+ 113: 64, # 'q'
+ 114: 7, # 'r'
+ 115: 8, # 's'
+ 116: 9, # 't'
+ 117: 14, # 'u'
+ 118: 32, # 'v'
+ 119: 57, # 'w'
+ 120: 58, # 'x'
+ 121: 11, # 'y'
+ 122: 22, # 'z'
+ 123: 255, # '{'
+ 124: 255, # '|'
+ 125: 255, # '}'
+ 126: 255, # '~'
+ 127: 255, # '\x7f'
+ 128: 180, # '\x80'
+ 129: 179, # '\x81'
+ 130: 178, # '\x82'
+ 131: 177, # '\x83'
+ 132: 176, # '\x84'
+ 133: 175, # '\x85'
+ 134: 174, # '\x86'
+ 135: 173, # '\x87'
+ 136: 172, # '\x88'
+ 137: 171, # '\x89'
+ 138: 170, # '\x8a'
+ 139: 169, # '\x8b'
+ 140: 168, # '\x8c'
+ 141: 167, # '\x8d'
+ 142: 166, # '\x8e'
+ 143: 165, # '\x8f'
+ 144: 164, # '\x90'
+ 145: 163, # '\x91'
+ 146: 162, # '\x92'
+ 147: 161, # '\x93'
+ 148: 160, # '\x94'
+ 149: 159, # '\x95'
+ 150: 101, # '\x96'
+ 151: 158, # '\x97'
+ 152: 157, # '\x98'
+ 153: 156, # '\x99'
+ 154: 155, # '\x9a'
+ 155: 154, # '\x9b'
+ 156: 153, # '\x9c'
+ 157: 152, # '\x9d'
+ 158: 151, # '\x9e'
+ 159: 106, # '\x9f'
+ 160: 150, # '\xa0'
+ 161: 149, # '¡'
+ 162: 148, # '¢'
+ 163: 147, # '£'
+ 164: 146, # '¤'
+ 165: 145, # '¥'
+ 166: 144, # '¦'
+ 167: 100, # '§'
+ 168: 143, # '¨'
+ 169: 142, # '©'
+ 170: 141, # 'ª'
+ 171: 140, # '«'
+ 172: 139, # '¬'
+ 173: 138, # '\xad'
+ 174: 137, # '®'
+ 175: 136, # '¯'
+ 176: 94, # '°'
+ 177: 80, # '±'
+ 178: 93, # '²'
+ 179: 135, # '³'
+ 180: 105, # '´'
+ 181: 134, # 'µ'
+ 182: 133, # '¶'
+ 183: 63, # '·'
+ 184: 132, # '¸'
+ 185: 131, # '¹'
+ 186: 130, # 'º'
+ 187: 129, # '»'
+ 188: 128, # '¼'
+ 189: 127, # '½'
+ 190: 126, # '¾'
+ 191: 125, # '¿'
+ 192: 124, # 'À'
+ 193: 104, # 'Á'
+ 194: 73, # 'Â'
+ 195: 99, # 'Ã'
+ 196: 79, # 'Ä'
+ 197: 85, # 'Å'
+ 198: 123, # 'Æ'
+ 199: 54, # 'Ç'
+ 200: 122, # 'È'
+ 201: 98, # 'É'
+ 202: 92, # 'Ê'
+ 203: 121, # 'Ë'
+ 204: 120, # 'Ì'
+ 205: 91, # 'Í'
+ 206: 103, # 'Î'
+ 207: 119, # 'Ï'
+ 208: 68, # 'Ğ'
+ 209: 118, # 'Ñ'
+ 210: 117, # 'Ò'
+ 211: 97, # 'Ó'
+ 212: 116, # 'Ô'
+ 213: 115, # 'Õ'
+ 214: 50, # 'Ö'
+ 215: 90, # '×'
+ 216: 114, # 'Ø'
+ 217: 113, # 'Ù'
+ 218: 112, # 'Ú'
+ 219: 111, # 'Û'
+ 220: 55, # 'Ü'
+ 221: 41, # 'İ'
+ 222: 40, # 'Ş'
+ 223: 86, # 'ß'
+ 224: 89, # 'à'
+ 225: 70, # 'á'
+ 226: 59, # 'â'
+ 227: 78, # 'ã'
+ 228: 71, # 'ä'
+ 229: 82, # 'å'
+ 230: 88, # 'æ'
+ 231: 33, # 'ç'
+ 232: 77, # 'è'
+ 233: 66, # 'é'
+ 234: 84, # 'ê'
+ 235: 83, # 'ë'
+ 236: 110, # 'ì'
+ 237: 75, # 'í'
+ 238: 61, # 'î'
+ 239: 96, # 'ï'
+ 240: 30, # 'ğ'
+ 241: 67, # 'ñ'
+ 242: 109, # 'ò'
+ 243: 74, # 'ó'
+ 244: 87, # 'ô'
+ 245: 102, # 'õ'
+ 246: 34, # 'ö'
+ 247: 95, # '÷'
+ 248: 81, # 'ø'
+ 249: 108, # 'ù'
+ 250: 76, # 'ú'
+ 251: 72, # 'û'
+ 252: 17, # 'ü'
+ 253: 6, # 'ı'
+ 254: 19, # 'ş'
+ 255: 107, # 'ÿ'
+}
+
+ISO_8859_9_TURKISH_MODEL = SingleByteCharSetModel(
+ charset_name="ISO-8859-9",
+ language="Turkish",
+ char_to_order_map=ISO_8859_9_TURKISH_CHAR_TO_ORDER,
+ language_model=TURKISH_LANG_MODEL,
+ typical_positive_ratio=0.97029,
+ keep_ascii_letters=True,
+ alphabet="ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzÂÇÎÖÛÜâçîöûüĞğİıŞş",
+)
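
Editor's note: the model above combines two tables. ISO_8859_9_TURKISH_CHAR_TO_ORDER maps each byte value to a frequency order (with 255/254/253/252/251 reserved as listed in the comments), and TURKISH_LANG_MODEL rates pairs of the most frequent orders from 0 to 3 by how often the second character follows the first in the training text. The sketch below is only an illustration of how such a model can score a byte string; it is not chardet's actual SingleByteCharSetProber, and the 64-order cutoff, the "ratio of very-likely pairs" heuristic, and the sample sentence are assumptions made for the example.

    # Simplified, illustrative scorer for a SingleByteCharSetModel-style model.
    # Assumes the two tables defined above are in scope (e.g. imported from
    # this module).  The real prober tracks more state (reversed pair order,
    # negative-pair counts, ASCII filtering).
    SAMPLE_SIZE = 64      # orders at or above this are not in the pair matrix
    POSITIVE_CAT = 3      # pair value meaning "frequent in the training text"

    def score(byte_str, char_to_order_map, language_model):
        """Return the fraction of adjacent letter pairs rated 'very frequent'."""
        last_order = 255
        pairs = hits = 0
        for byte in byte_str:
            order = char_to_order_map.get(byte, 255)
            if order < SAMPLE_SIZE and last_order < SAMPLE_SIZE:
                pairs += 1
                if language_model.get(last_order, {}).get(order, 0) == POSITIVE_CAT:
                    hits += 1
            last_order = order
        return hits / pairs if pairs else 0.0

    # Hypothetical sample: Turkish text encoded as ISO-8859-9 should score high.
    sample = "şeker gibi bir gündü".encode("iso-8859-9")
    print(score(sample, ISO_8859_9_TURKISH_CHAR_TO_ORDER, TURKISH_LANG_MODEL))
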
diff --git a/lib/chardet/latin1prober.py b/lib/chardet/latin1prober.py
new file mode 100644
index 0000000..241f14a
--- /dev/null
+++ b/lib/chardet/latin1prober.py
@@ -0,0 +1,145 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 2001
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+# Shy Shalom - original C code
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .charsetprober import CharSetProber
+from .enums import ProbingState
+
+FREQ_CAT_NUM = 4
+
+UDF = 0 # undefined
+OTH = 1 # other
+ASC = 2 # ascii capital letter
+ASS = 3 # ascii small letter
+ACV = 4 # accent capital vowel
+ACO = 5 # accent capital other
+ASV = 6 # accent small vowel
+ASO = 7 # accent small other
+CLASS_NUM = 8 # total classes
+
+# fmt: off
+Latin1_CharToClass = (
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 00 - 07
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 08 - 0F
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 10 - 17
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 18 - 1F
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 20 - 27
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 28 - 2F
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 30 - 37
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 38 - 3F
+ OTH, ASC, ASC, ASC, ASC, ASC, ASC, ASC, # 40 - 47
+ ASC, ASC, ASC, ASC, ASC, ASC, ASC, ASC, # 48 - 4F
+ ASC, ASC, ASC, ASC, ASC, ASC, ASC, ASC, # 50 - 57
+ ASC, ASC, ASC, OTH, OTH, OTH, OTH, OTH, # 58 - 5F
+ OTH, ASS, ASS, ASS, ASS, ASS, ASS, ASS, # 60 - 67
+ ASS, ASS, ASS, ASS, ASS, ASS, ASS, ASS, # 68 - 6F
+ ASS, ASS, ASS, ASS, ASS, ASS, ASS, ASS, # 70 - 77
+ ASS, ASS, ASS, OTH, OTH, OTH, OTH, OTH, # 78 - 7F
+ OTH, UDF, OTH, ASO, OTH, OTH, OTH, OTH, # 80 - 87
+ OTH, OTH, ACO, OTH, ACO, UDF, ACO, UDF, # 88 - 8F
+ UDF, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # 90 - 97
+ OTH, OTH, ASO, OTH, ASO, UDF, ASO, ACO, # 98 - 9F
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # A0 - A7
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # A8 - AF
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # B0 - B7
+ OTH, OTH, OTH, OTH, OTH, OTH, OTH, OTH, # B8 - BF
+ ACV, ACV, ACV, ACV, ACV, ACV, ACO, ACO, # C0 - C7
+ ACV, ACV, ACV, ACV, ACV, ACV, ACV, ACV, # C8 - CF
+ ACO, ACO, ACV, ACV, ACV, ACV, ACV, OTH, # D0 - D7
+ ACV, ACV, ACV, ACV, ACV, ACO, ACO, ACO, # D8 - DF
+ ASV, ASV, ASV, ASV, ASV, ASV, ASO, ASO, # E0 - E7
+ ASV, ASV, ASV, ASV, ASV, ASV, ASV, ASV, # E8 - EF
+ ASO, ASO, ASV, ASV, ASV, ASV, ASV, OTH, # F0 - F7
+ ASV, ASV, ASV, ASV, ASV, ASO, ASO, ASO, # F8 - FF
+)
+
+# 0 : illegal
+# 1 : very unlikely
+# 2 : normal
+# 3 : very likely
+Latin1ClassModel = (
+# UDF OTH ASC ASS ACV ACO ASV ASO
+ 0, 0, 0, 0, 0, 0, 0, 0, # UDF
+ 0, 3, 3, 3, 3, 3, 3, 3, # OTH
+ 0, 3, 3, 3, 3, 3, 3, 3, # ASC
+ 0, 3, 3, 3, 1, 1, 3, 3, # ASS
+ 0, 3, 3, 3, 1, 2, 1, 2, # ACV
+ 0, 3, 3, 3, 3, 3, 3, 3, # ACO
+ 0, 3, 1, 3, 1, 1, 1, 3, # ASV
+ 0, 3, 1, 3, 1, 1, 3, 3, # ASO
+)
+# fmt: on
+
+
+class Latin1Prober(CharSetProber):
+ def __init__(self):
+ super().__init__()
+ self._last_char_class = None
+ self._freq_counter = None
+ self.reset()
+
+ def reset(self):
+ self._last_char_class = OTH
+ self._freq_counter = [0] * FREQ_CAT_NUM
+ super().reset()
+
+ @property
+ def charset_name(self):
+ return "ISO-8859-1"
+
+ @property
+ def language(self):
+ return ""
+
+ def feed(self, byte_str):
+ byte_str = self.remove_xml_tags(byte_str)
+ for c in byte_str:
+ char_class = Latin1_CharToClass[c]
+ freq = Latin1ClassModel[(self._last_char_class * CLASS_NUM) + char_class]
+ if freq == 0:
+ self._state = ProbingState.NOT_ME
+ break
+ self._freq_counter[freq] += 1
+ self._last_char_class = char_class
+
+ return self.state
+
+ def get_confidence(self):
+ if self.state == ProbingState.NOT_ME:
+ return 0.01
+
+ total = sum(self._freq_counter)
+ confidence = (
+ 0.0
+ if total < 0.01
+ else (self._freq_counter[3] - self._freq_counter[1] * 20.0) / total
+ )
+ confidence = max(confidence, 0.0)
+ # lower the confidence of latin1 so that other more accurate
+ # detectors can take priority.
+ confidence *= 0.73
+ return confidence
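
Editor's note: Latin1Prober classifies each byte into one of the eight classes above and walks the Latin1ClassModel pair table; any "illegal" pair rules ISO-8859-1 out, otherwise confidence is driven by the balance of "very likely" versus "very unlikely" pairs and then damped by 0.73. A small usage sketch follows; the import path is an assumption (it depends on how this repo exposes the vendored lib/chardet package), and the sample string is made up.

    # Usage sketch, assuming the vendored package is importable as "chardet".
    from chardet.latin1prober import Latin1Prober

    prober = Latin1Prober()
    prober.feed("café naïve déjà".encode("latin-1"))   # hypothetical sample text
    print(prober.charset_name)       # "ISO-8859-1"
    print(prober.get_confidence())   # deliberately damped by the 0.73 factor

    # In everyday use the top-level helper is usually enough:
    # import chardet; chardet.detect(b"...")
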
diff --git a/lib/chardet/mbcharsetprober.py b/lib/chardet/mbcharsetprober.py
new file mode 100644
index 0000000..bf96ad5
--- /dev/null
+++ b/lib/chardet/mbcharsetprober.py
@@ -0,0 +1,95 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 2001
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+# Shy Shalom - original C code
+# Proofpoint, Inc.
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .charsetprober import CharSetProber
+from .enums import MachineState, ProbingState
+
+
+class MultiByteCharSetProber(CharSetProber):
+ """
+ MultiByteCharSetProber
+ """
+
+ def __init__(self, lang_filter=None):
+ super().__init__(lang_filter=lang_filter)
+ self.distribution_analyzer = None
+ self.coding_sm = None
+ self._last_char = [0, 0]
+
+ def reset(self):
+ super().reset()
+ if self.coding_sm:
+ self.coding_sm.reset()
+ if self.distribution_analyzer:
+ self.distribution_analyzer.reset()
+ self._last_char = [0, 0]
+
+ @property
+ def charset_name(self):
+ raise NotImplementedError
+
+ @property
+ def language(self):
+ raise NotImplementedError
+
+ def feed(self, byte_str):
+ for i, byte in enumerate(byte_str):
+ coding_state = self.coding_sm.next_state(byte)
+ if coding_state == MachineState.ERROR:
+ self.logger.debug(
+ "%s %s prober hit error at byte %s",
+ self.charset_name,
+ self.language,
+ i,
+ )
+ self._state = ProbingState.NOT_ME
+ break
+ if coding_state == MachineState.ITS_ME:
+ self._state = ProbingState.FOUND_IT
+ break
+ if coding_state == MachineState.START:
+ char_len = self.coding_sm.get_current_charlen()
+ if i == 0:
+ self._last_char[1] = byte
+ self.distribution_analyzer.feed(self._last_char, char_len)
+ else:
+ self.distribution_analyzer.feed(byte_str[i - 1 : i + 1], char_len)
+
+ self._last_char[0] = byte_str[-1]
+
+ if self.state == ProbingState.DETECTING:
+ if self.distribution_analyzer.got_enough_data() and (
+ self.get_confidence() > self.SHORTCUT_THRESHOLD
+ ):
+ self._state = ProbingState.FOUND_IT
+
+ return self.state
+
+ def get_confidence(self):
+ return self.distribution_analyzer.get_confidence()
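
Concrete subclasses elsewhere in this change-set (Big5Prober, EUCKRProber, and so on) only have to plug in a state-machine model and a distribution analyzer; the feed() loop above does the rest. A minimal sketch of that pattern, assuming it lives inside the lib/chardet package so the relative imports resolve; it mirrors the added probers but is not itself one of the added files.

# Sketch of a concrete multi-byte prober (illustrative, not part of the diff).
from .chardistribution import Big5DistributionAnalysis
from .codingstatemachine import CodingStateMachine
from .mbcharsetprober import MultiByteCharSetProber
from .mbcssm import BIG5_SM_MODEL


class ExampleBig5Prober(MultiByteCharSetProber):
    def __init__(self, lang_filter=None):
        super().__init__(lang_filter=lang_filter)
        self.coding_sm = CodingStateMachine(BIG5_SM_MODEL)        # validates byte sequences
        self.distribution_analyzer = Big5DistributionAnalysis()   # scores character frequency
        self.reset()

    @property
    def charset_name(self):
        return "Big5"

    @property
    def language(self):
        return "Chinese"
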
diff --git a/lib/chardet/mbcsgroupprober.py b/lib/chardet/mbcsgroupprober.py
new file mode 100644
index 0000000..9448836
--- /dev/null
+++ b/lib/chardet/mbcsgroupprober.py
@@ -0,0 +1,56 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 2001
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+# Shy Shalom - original C code
+# Proofpoint, Inc.
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .big5prober import Big5Prober
+from .charsetgroupprober import CharSetGroupProber
+from .cp949prober import CP949Prober
+from .eucjpprober import EUCJPProber
+from .euckrprober import EUCKRProber
+from .euctwprober import EUCTWProber
+from .gb2312prober import GB2312Prober
+from .johabprober import JOHABProber
+from .sjisprober import SJISProber
+from .utf8prober import UTF8Prober
+
+
+class MBCSGroupProber(CharSetGroupProber):
+ def __init__(self, lang_filter=None):
+ super().__init__(lang_filter=lang_filter)
+ self.probers = [
+ UTF8Prober(),
+ SJISProber(),
+ EUCJPProber(),
+ GB2312Prober(),
+ EUCKRProber(),
+ CP949Prober(),
+ Big5Prober(),
+ EUCTWProber(),
+ JOHABProber(),
+ ]
+ self.reset()
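
Assuming lib/ is on sys.path so the package imports as chardet, the group prober can also be driven directly; it fans the same bytes out to every multi-byte prober and reports the best guess. An illustrative usage sketch (not part of the diff; the exact result depends on the input):

from chardet.enums import ProbingState
from chardet.mbcsgroupprober import MBCSGroupProber

prober = MBCSGroupProber()
prober.feed("漢字とかな".encode("shift_jis"))  # any multi-byte sample
if prober.state != ProbingState.NOT_ME:
    print(prober.charset_name, round(prober.get_confidence(), 3))
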
diff --git a/lib/chardet/mbcssm.py b/lib/chardet/mbcssm.py
new file mode 100644
index 0000000..d3b9c4b
--- /dev/null
+++ b/lib/chardet/mbcssm.py
@@ -0,0 +1,660 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .enums import MachineState
+
+# BIG5
+
+# fmt: off
+BIG5_CLS = (
+ 1, 1, 1, 1, 1, 1, 1, 1, # 00 - 07 #allow 0x00 as legal value
+ 1, 1, 1, 1, 1, 1, 0, 0, # 08 - 0f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 10 - 17
+ 1, 1, 1, 0, 1, 1, 1, 1, # 18 - 1f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 20 - 27
+ 1, 1, 1, 1, 1, 1, 1, 1, # 28 - 2f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 30 - 37
+ 1, 1, 1, 1, 1, 1, 1, 1, # 38 - 3f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 40 - 47
+ 2, 2, 2, 2, 2, 2, 2, 2, # 48 - 4f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 50 - 57
+ 2, 2, 2, 2, 2, 2, 2, 2, # 58 - 5f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 60 - 67
+ 2, 2, 2, 2, 2, 2, 2, 2, # 68 - 6f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 70 - 77
+ 2, 2, 2, 2, 2, 2, 2, 1, # 78 - 7f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 80 - 87
+ 4, 4, 4, 4, 4, 4, 4, 4, # 88 - 8f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 90 - 97
+ 4, 4, 4, 4, 4, 4, 4, 4, # 98 - 9f
+ 4, 3, 3, 3, 3, 3, 3, 3, # a0 - a7
+ 3, 3, 3, 3, 3, 3, 3, 3, # a8 - af
+ 3, 3, 3, 3, 3, 3, 3, 3, # b0 - b7
+ 3, 3, 3, 3, 3, 3, 3, 3, # b8 - bf
+ 3, 3, 3, 3, 3, 3, 3, 3, # c0 - c7
+ 3, 3, 3, 3, 3, 3, 3, 3, # c8 - cf
+ 3, 3, 3, 3, 3, 3, 3, 3, # d0 - d7
+ 3, 3, 3, 3, 3, 3, 3, 3, # d8 - df
+ 3, 3, 3, 3, 3, 3, 3, 3, # e0 - e7
+ 3, 3, 3, 3, 3, 3, 3, 3, # e8 - ef
+ 3, 3, 3, 3, 3, 3, 3, 3, # f0 - f7
+ 3, 3, 3, 3, 3, 3, 3, 0 # f8 - ff
+)
+
+BIG5_ST = (
+ MachineState.ERROR,MachineState.START,MachineState.START, 3,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#00-07
+ MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ERROR,#08-0f
+ MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START#10-17
+)
+# fmt: on
+
+BIG5_CHAR_LEN_TABLE = (0, 1, 1, 2, 0)
+
+BIG5_SM_MODEL = {
+ "class_table": BIG5_CLS,
+ "class_factor": 5,
+ "state_table": BIG5_ST,
+ "char_len_table": BIG5_CHAR_LEN_TABLE,
+ "name": "Big5",
+}
+
+# CP949
+# fmt: off
+CP949_CLS = (
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, # 00 - 0f
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, # 10 - 1f
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, # 20 - 2f
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, # 30 - 3f
+ 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, # 40 - 4f
+ 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 1, # 50 - 5f
+ 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, # 60 - 6f
+ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 1, # 70 - 7f
+ 0, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, # 80 - 8f
+ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, # 90 - 9f
+ 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, # a0 - af
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, # b0 - bf
+ 7, 7, 7, 7, 7, 7, 9, 2, 2, 3, 2, 2, 2, 2, 2, 2, # c0 - cf
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, # d0 - df
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, # e0 - ef
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, # f0 - ff
+)
+
+CP949_ST = (
+#cls= 0 1 2 3 4 5 6 7 8 9 # previous state =
+ MachineState.ERROR,MachineState.START, 3,MachineState.ERROR,MachineState.START,MachineState.START, 4, 5,MachineState.ERROR, 6, # MachineState.START
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, # MachineState.ERROR
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME, # MachineState.ITS_ME
+ MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START, # 3
+ MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START, # 4
+ MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START, # 5
+ MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START, # 6
+)
+# fmt: on
+
+CP949_CHAR_LEN_TABLE = (0, 1, 2, 0, 1, 1, 2, 2, 0, 2)
+
+CP949_SM_MODEL = {
+ "class_table": CP949_CLS,
+ "class_factor": 10,
+ "state_table": CP949_ST,
+ "char_len_table": CP949_CHAR_LEN_TABLE,
+ "name": "CP949",
+}
+
+# EUC-JP
+# fmt: off
+EUCJP_CLS = (
+ 4, 4, 4, 4, 4, 4, 4, 4, # 00 - 07
+ 4, 4, 4, 4, 4, 4, 5, 5, # 08 - 0f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 10 - 17
+ 4, 4, 4, 5, 4, 4, 4, 4, # 18 - 1f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 20 - 27
+ 4, 4, 4, 4, 4, 4, 4, 4, # 28 - 2f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 30 - 37
+ 4, 4, 4, 4, 4, 4, 4, 4, # 38 - 3f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 40 - 47
+ 4, 4, 4, 4, 4, 4, 4, 4, # 48 - 4f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 50 - 57
+ 4, 4, 4, 4, 4, 4, 4, 4, # 58 - 5f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 60 - 67
+ 4, 4, 4, 4, 4, 4, 4, 4, # 68 - 6f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 70 - 77
+ 4, 4, 4, 4, 4, 4, 4, 4, # 78 - 7f
+ 5, 5, 5, 5, 5, 5, 5, 5, # 80 - 87
+ 5, 5, 5, 5, 5, 5, 1, 3, # 88 - 8f
+ 5, 5, 5, 5, 5, 5, 5, 5, # 90 - 97
+ 5, 5, 5, 5, 5, 5, 5, 5, # 98 - 9f
+ 5, 2, 2, 2, 2, 2, 2, 2, # a0 - a7
+ 2, 2, 2, 2, 2, 2, 2, 2, # a8 - af
+ 2, 2, 2, 2, 2, 2, 2, 2, # b0 - b7
+ 2, 2, 2, 2, 2, 2, 2, 2, # b8 - bf
+ 2, 2, 2, 2, 2, 2, 2, 2, # c0 - c7
+ 2, 2, 2, 2, 2, 2, 2, 2, # c8 - cf
+ 2, 2, 2, 2, 2, 2, 2, 2, # d0 - d7
+ 2, 2, 2, 2, 2, 2, 2, 2, # d8 - df
+ 0, 0, 0, 0, 0, 0, 0, 0, # e0 - e7
+ 0, 0, 0, 0, 0, 0, 0, 0, # e8 - ef
+ 0, 0, 0, 0, 0, 0, 0, 0, # f0 - f7
+ 0, 0, 0, 0, 0, 0, 0, 5 # f8 - ff
+)
+
+EUCJP_ST = (
+ 3, 4, 3, 5,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#00-07
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,#08-0f
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.START,MachineState.ERROR,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#10-17
+ MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, 3,MachineState.ERROR,#18-1f
+ 3,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START#20-27
+)
+# fmt: on
+
+EUCJP_CHAR_LEN_TABLE = (2, 2, 2, 3, 1, 0)
+
+EUCJP_SM_MODEL = {
+ "class_table": EUCJP_CLS,
+ "class_factor": 6,
+ "state_table": EUCJP_ST,
+ "char_len_table": EUCJP_CHAR_LEN_TABLE,
+ "name": "EUC-JP",
+}
+
+# EUC-KR
+# fmt: off
+EUCKR_CLS = (
+ 1, 1, 1, 1, 1, 1, 1, 1, # 00 - 07
+ 1, 1, 1, 1, 1, 1, 0, 0, # 08 - 0f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 10 - 17
+ 1, 1, 1, 0, 1, 1, 1, 1, # 18 - 1f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 20 - 27
+ 1, 1, 1, 1, 1, 1, 1, 1, # 28 - 2f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 30 - 37
+ 1, 1, 1, 1, 1, 1, 1, 1, # 38 - 3f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 40 - 47
+ 1, 1, 1, 1, 1, 1, 1, 1, # 48 - 4f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 50 - 57
+ 1, 1, 1, 1, 1, 1, 1, 1, # 58 - 5f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 60 - 67
+ 1, 1, 1, 1, 1, 1, 1, 1, # 68 - 6f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 70 - 77
+ 1, 1, 1, 1, 1, 1, 1, 1, # 78 - 7f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 80 - 87
+ 0, 0, 0, 0, 0, 0, 0, 0, # 88 - 8f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 90 - 97
+ 0, 0, 0, 0, 0, 0, 0, 0, # 98 - 9f
+ 0, 2, 2, 2, 2, 2, 2, 2, # a0 - a7
+ 2, 2, 2, 2, 2, 3, 3, 3, # a8 - af
+ 2, 2, 2, 2, 2, 2, 2, 2, # b0 - b7
+ 2, 2, 2, 2, 2, 2, 2, 2, # b8 - bf
+ 2, 2, 2, 2, 2, 2, 2, 2, # c0 - c7
+ 2, 3, 2, 2, 2, 2, 2, 2, # c8 - cf
+ 2, 2, 2, 2, 2, 2, 2, 2, # d0 - d7
+ 2, 2, 2, 2, 2, 2, 2, 2, # d8 - df
+ 2, 2, 2, 2, 2, 2, 2, 2, # e0 - e7
+ 2, 2, 2, 2, 2, 2, 2, 2, # e8 - ef
+ 2, 2, 2, 2, 2, 2, 2, 2, # f0 - f7
+ 2, 2, 2, 2, 2, 2, 2, 0 # f8 - ff
+)
+
+EUCKR_ST = (
+ MachineState.ERROR,MachineState.START, 3,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#00-07
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START #08-0f
+)
+# fmt: on
+
+EUCKR_CHAR_LEN_TABLE = (0, 1, 2, 0)
+
+EUCKR_SM_MODEL = {
+ "class_table": EUCKR_CLS,
+ "class_factor": 4,
+ "state_table": EUCKR_ST,
+ "char_len_table": EUCKR_CHAR_LEN_TABLE,
+ "name": "EUC-KR",
+}
+
+# JOHAB
+# fmt: off
+JOHAB_CLS = (
+ 4,4,4,4,4,4,4,4, # 00 - 07
+ 4,4,4,4,4,4,0,0, # 08 - 0f
+ 4,4,4,4,4,4,4,4, # 10 - 17
+ 4,4,4,0,4,4,4,4, # 18 - 1f
+ 4,4,4,4,4,4,4,4, # 20 - 27
+ 4,4,4,4,4,4,4,4, # 28 - 2f
+ 4,3,3,3,3,3,3,3, # 30 - 37
+ 3,3,3,3,3,3,3,3, # 38 - 3f
+ 3,1,1,1,1,1,1,1, # 40 - 47
+ 1,1,1,1,1,1,1,1, # 48 - 4f
+ 1,1,1,1,1,1,1,1, # 50 - 57
+ 1,1,1,1,1,1,1,1, # 58 - 5f
+ 1,1,1,1,1,1,1,1, # 60 - 67
+ 1,1,1,1,1,1,1,1, # 68 - 6f
+ 1,1,1,1,1,1,1,1, # 70 - 77
+ 1,1,1,1,1,1,1,2, # 78 - 7f
+ 6,6,6,6,8,8,8,8, # 80 - 87
+ 8,8,8,8,8,8,8,8, # 88 - 8f
+ 8,7,7,7,7,7,7,7, # 90 - 97
+ 7,7,7,7,7,7,7,7, # 98 - 9f
+ 7,7,7,7,7,7,7,7, # a0 - a7
+ 7,7,7,7,7,7,7,7, # a8 - af
+ 7,7,7,7,7,7,7,7, # b0 - b7
+ 7,7,7,7,7,7,7,7, # b8 - bf
+ 7,7,7,7,7,7,7,7, # c0 - c7
+ 7,7,7,7,7,7,7,7, # c8 - cf
+ 7,7,7,7,5,5,5,5, # d0 - d7
+ 5,9,9,9,9,9,9,5, # d8 - df
+ 9,9,9,9,9,9,9,9, # e0 - e7
+ 9,9,9,9,9,9,9,9, # e8 - ef
+ 9,9,9,9,9,9,9,9, # f0 - f7
+ 9,9,5,5,5,5,5,0 # f8 - ff
+)
+
+JOHAB_ST = (
+# cls = 0 1 2 3 4 5 6 7 8 9
+ MachineState.ERROR ,MachineState.START ,MachineState.START ,MachineState.START ,MachineState.START ,MachineState.ERROR ,MachineState.ERROR ,3 ,3 ,4 , # MachineState.START
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME, # MachineState.ITS_ME
+ MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR ,MachineState.ERROR , # MachineState.ERROR
+ MachineState.ERROR ,MachineState.START ,MachineState.START ,MachineState.ERROR ,MachineState.ERROR ,MachineState.START ,MachineState.START ,MachineState.START ,MachineState.START ,MachineState.START , # 3
+ MachineState.ERROR ,MachineState.START ,MachineState.ERROR ,MachineState.START ,MachineState.ERROR ,MachineState.START ,MachineState.ERROR ,MachineState.START ,MachineState.ERROR ,MachineState.START , # 4
+)
+# fmt: on
+
+JOHAB_CHAR_LEN_TABLE = (0, 1, 1, 1, 1, 0, 0, 2, 2, 2)
+
+JOHAB_SM_MODEL = {
+ "class_table": JOHAB_CLS,
+ "class_factor": 10,
+ "state_table": JOHAB_ST,
+ "char_len_table": JOHAB_CHAR_LEN_TABLE,
+ "name": "Johab",
+}
+
+# EUC-TW
+# fmt: off
+EUCTW_CLS = (
+ 2, 2, 2, 2, 2, 2, 2, 2, # 00 - 07
+ 2, 2, 2, 2, 2, 2, 0, 0, # 08 - 0f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 10 - 17
+ 2, 2, 2, 0, 2, 2, 2, 2, # 18 - 1f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 20 - 27
+ 2, 2, 2, 2, 2, 2, 2, 2, # 28 - 2f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 30 - 37
+ 2, 2, 2, 2, 2, 2, 2, 2, # 38 - 3f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 40 - 47
+ 2, 2, 2, 2, 2, 2, 2, 2, # 48 - 4f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 50 - 57
+ 2, 2, 2, 2, 2, 2, 2, 2, # 58 - 5f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 60 - 67
+ 2, 2, 2, 2, 2, 2, 2, 2, # 68 - 6f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 70 - 77
+ 2, 2, 2, 2, 2, 2, 2, 2, # 78 - 7f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 80 - 87
+ 0, 0, 0, 0, 0, 0, 6, 0, # 88 - 8f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 90 - 97
+ 0, 0, 0, 0, 0, 0, 0, 0, # 98 - 9f
+ 0, 3, 4, 4, 4, 4, 4, 4, # a0 - a7
+ 5, 5, 1, 1, 1, 1, 1, 1, # a8 - af
+ 1, 1, 1, 1, 1, 1, 1, 1, # b0 - b7
+ 1, 1, 1, 1, 1, 1, 1, 1, # b8 - bf
+ 1, 1, 3, 1, 3, 3, 3, 3, # c0 - c7
+ 3, 3, 3, 3, 3, 3, 3, 3, # c8 - cf
+ 3, 3, 3, 3, 3, 3, 3, 3, # d0 - d7
+ 3, 3, 3, 3, 3, 3, 3, 3, # d8 - df
+ 3, 3, 3, 3, 3, 3, 3, 3, # e0 - e7
+ 3, 3, 3, 3, 3, 3, 3, 3, # e8 - ef
+ 3, 3, 3, 3, 3, 3, 3, 3, # f0 - f7
+ 3, 3, 3, 3, 3, 3, 3, 0 # f8 - ff
+)
+
+EUCTW_ST = (
+ MachineState.ERROR,MachineState.ERROR,MachineState.START, 3, 3, 3, 4,MachineState.ERROR,#00-07
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ITS_ME,#08-0f
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ERROR,MachineState.START,MachineState.ERROR,#10-17
+ MachineState.START,MachineState.START,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#18-1f
+ 5,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.ERROR,MachineState.START,MachineState.START,#20-27
+ MachineState.START,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START #28-2f
+)
+# fmt: on
+
+EUCTW_CHAR_LEN_TABLE = (0, 0, 1, 2, 2, 2, 3)
+
+EUCTW_SM_MODEL = {
+ "class_table": EUCTW_CLS,
+ "class_factor": 7,
+ "state_table": EUCTW_ST,
+ "char_len_table": EUCTW_CHAR_LEN_TABLE,
+ "name": "x-euc-tw",
+}
+
+# GB2312
+# fmt: off
+GB2312_CLS = (
+ 1, 1, 1, 1, 1, 1, 1, 1, # 00 - 07
+ 1, 1, 1, 1, 1, 1, 0, 0, # 08 - 0f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 10 - 17
+ 1, 1, 1, 0, 1, 1, 1, 1, # 18 - 1f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 20 - 27
+ 1, 1, 1, 1, 1, 1, 1, 1, # 28 - 2f
+ 3, 3, 3, 3, 3, 3, 3, 3, # 30 - 37
+ 3, 3, 1, 1, 1, 1, 1, 1, # 38 - 3f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 40 - 47
+ 2, 2, 2, 2, 2, 2, 2, 2, # 48 - 4f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 50 - 57
+ 2, 2, 2, 2, 2, 2, 2, 2, # 58 - 5f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 60 - 67
+ 2, 2, 2, 2, 2, 2, 2, 2, # 68 - 6f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 70 - 77
+ 2, 2, 2, 2, 2, 2, 2, 4, # 78 - 7f
+ 5, 6, 6, 6, 6, 6, 6, 6, # 80 - 87
+ 6, 6, 6, 6, 6, 6, 6, 6, # 88 - 8f
+ 6, 6, 6, 6, 6, 6, 6, 6, # 90 - 97
+ 6, 6, 6, 6, 6, 6, 6, 6, # 98 - 9f
+ 6, 6, 6, 6, 6, 6, 6, 6, # a0 - a7
+ 6, 6, 6, 6, 6, 6, 6, 6, # a8 - af
+ 6, 6, 6, 6, 6, 6, 6, 6, # b0 - b7
+ 6, 6, 6, 6, 6, 6, 6, 6, # b8 - bf
+ 6, 6, 6, 6, 6, 6, 6, 6, # c0 - c7
+ 6, 6, 6, 6, 6, 6, 6, 6, # c8 - cf
+ 6, 6, 6, 6, 6, 6, 6, 6, # d0 - d7
+ 6, 6, 6, 6, 6, 6, 6, 6, # d8 - df
+ 6, 6, 6, 6, 6, 6, 6, 6, # e0 - e7
+ 6, 6, 6, 6, 6, 6, 6, 6, # e8 - ef
+ 6, 6, 6, 6, 6, 6, 6, 6, # f0 - f7
+ 6, 6, 6, 6, 6, 6, 6, 0 # f8 - ff
+)
+
+GB2312_ST = (
+ MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START, 3,MachineState.ERROR,#00-07
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ITS_ME,#08-0f
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ERROR,MachineState.ERROR,MachineState.START,#10-17
+ 4,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#18-1f
+ MachineState.ERROR,MachineState.ERROR, 5,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ERROR,#20-27
+ MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.START #28-2f
+)
+# fmt: on
+
+# To be accurate, the length of class 6 can be either 2 or 4.
+# But it is not necessary to discriminate between the two, since
+# it is used for frequency analysis only and each code range is
+# validated there as well, so it is safe to set it to 2 here.
+GB2312_CHAR_LEN_TABLE = (0, 1, 1, 1, 1, 1, 2)
+
+GB2312_SM_MODEL = {
+ "class_table": GB2312_CLS,
+ "class_factor": 7,
+ "state_table": GB2312_ST,
+ "char_len_table": GB2312_CHAR_LEN_TABLE,
+ "name": "GB2312",
+}
+
+# Shift_JIS
+# fmt: off
+SJIS_CLS = (
+ 1, 1, 1, 1, 1, 1, 1, 1, # 00 - 07
+ 1, 1, 1, 1, 1, 1, 0, 0, # 08 - 0f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 10 - 17
+ 1, 1, 1, 0, 1, 1, 1, 1, # 18 - 1f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 20 - 27
+ 1, 1, 1, 1, 1, 1, 1, 1, # 28 - 2f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 30 - 37
+ 1, 1, 1, 1, 1, 1, 1, 1, # 38 - 3f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 40 - 47
+ 2, 2, 2, 2, 2, 2, 2, 2, # 48 - 4f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 50 - 57
+ 2, 2, 2, 2, 2, 2, 2, 2, # 58 - 5f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 60 - 67
+ 2, 2, 2, 2, 2, 2, 2, 2, # 68 - 6f
+ 2, 2, 2, 2, 2, 2, 2, 2, # 70 - 77
+ 2, 2, 2, 2, 2, 2, 2, 1, # 78 - 7f
+ 3, 3, 3, 3, 3, 2, 2, 3, # 80 - 87
+ 3, 3, 3, 3, 3, 3, 3, 3, # 88 - 8f
+ 3, 3, 3, 3, 3, 3, 3, 3, # 90 - 97
+ 3, 3, 3, 3, 3, 3, 3, 3, # 98 - 9f
+ # 0xa0 is illegal in Shift_JIS encoding, but some pages do
+ # contain such a byte, so we need to be more forgiving of errors here.
+ 2, 2, 2, 2, 2, 2, 2, 2, # a0 - a7
+ 2, 2, 2, 2, 2, 2, 2, 2, # a8 - af
+ 2, 2, 2, 2, 2, 2, 2, 2, # b0 - b7
+ 2, 2, 2, 2, 2, 2, 2, 2, # b8 - bf
+ 2, 2, 2, 2, 2, 2, 2, 2, # c0 - c7
+ 2, 2, 2, 2, 2, 2, 2, 2, # c8 - cf
+ 2, 2, 2, 2, 2, 2, 2, 2, # d0 - d7
+ 2, 2, 2, 2, 2, 2, 2, 2, # d8 - df
+ 3, 3, 3, 3, 3, 3, 3, 3, # e0 - e7
+ 3, 3, 3, 3, 3, 4, 4, 4, # e8 - ef
+ 3, 3, 3, 3, 3, 3, 3, 3, # f0 - f7
+ 3, 3, 3, 3, 3, 0, 0, 0, # f8 - ff
+)
+
+SJIS_ST = (
+ MachineState.ERROR,MachineState.START,MachineState.START, 3,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#00-07
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,#08-0f
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START #10-17
+)
+# fmt: on
+
+SJIS_CHAR_LEN_TABLE = (0, 1, 1, 2, 0, 0)
+
+SJIS_SM_MODEL = {
+ "class_table": SJIS_CLS,
+ "class_factor": 6,
+ "state_table": SJIS_ST,
+ "char_len_table": SJIS_CHAR_LEN_TABLE,
+ "name": "Shift_JIS",
+}
+
+# UCS2-BE
+# fmt: off
+UCS2BE_CLS = (
+ 0, 0, 0, 0, 0, 0, 0, 0, # 00 - 07
+ 0, 0, 1, 0, 0, 2, 0, 0, # 08 - 0f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 10 - 17
+ 0, 0, 0, 3, 0, 0, 0, 0, # 18 - 1f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 20 - 27
+ 0, 3, 3, 3, 3, 3, 0, 0, # 28 - 2f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 30 - 37
+ 0, 0, 0, 0, 0, 0, 0, 0, # 38 - 3f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 40 - 47
+ 0, 0, 0, 0, 0, 0, 0, 0, # 48 - 4f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 50 - 57
+ 0, 0, 0, 0, 0, 0, 0, 0, # 58 - 5f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 60 - 67
+ 0, 0, 0, 0, 0, 0, 0, 0, # 68 - 6f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 70 - 77
+ 0, 0, 0, 0, 0, 0, 0, 0, # 78 - 7f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 80 - 87
+ 0, 0, 0, 0, 0, 0, 0, 0, # 88 - 8f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 90 - 97
+ 0, 0, 0, 0, 0, 0, 0, 0, # 98 - 9f
+ 0, 0, 0, 0, 0, 0, 0, 0, # a0 - a7
+ 0, 0, 0, 0, 0, 0, 0, 0, # a8 - af
+ 0, 0, 0, 0, 0, 0, 0, 0, # b0 - b7
+ 0, 0, 0, 0, 0, 0, 0, 0, # b8 - bf
+ 0, 0, 0, 0, 0, 0, 0, 0, # c0 - c7
+ 0, 0, 0, 0, 0, 0, 0, 0, # c8 - cf
+ 0, 0, 0, 0, 0, 0, 0, 0, # d0 - d7
+ 0, 0, 0, 0, 0, 0, 0, 0, # d8 - df
+ 0, 0, 0, 0, 0, 0, 0, 0, # e0 - e7
+ 0, 0, 0, 0, 0, 0, 0, 0, # e8 - ef
+ 0, 0, 0, 0, 0, 0, 0, 0, # f0 - f7
+ 0, 0, 0, 0, 0, 0, 4, 5 # f8 - ff
+)
+
+UCS2BE_ST = (
+ 5, 7, 7,MachineState.ERROR, 4, 3,MachineState.ERROR,MachineState.ERROR,#00-07
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,#08-0f
+ MachineState.ITS_ME,MachineState.ITS_ME, 6, 6, 6, 6,MachineState.ERROR,MachineState.ERROR,#10-17
+ 6, 6, 6, 6, 6,MachineState.ITS_ME, 6, 6,#18-1f
+ 6, 6, 6, 6, 5, 7, 7,MachineState.ERROR,#20-27
+ 5, 8, 6, 6,MachineState.ERROR, 6, 6, 6,#28-2f
+ 6, 6, 6, 6,MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START #30-37
+)
+# fmt: on
+
+UCS2BE_CHAR_LEN_TABLE = (2, 2, 2, 0, 2, 2)
+
+UCS2BE_SM_MODEL = {
+ "class_table": UCS2BE_CLS,
+ "class_factor": 6,
+ "state_table": UCS2BE_ST,
+ "char_len_table": UCS2BE_CHAR_LEN_TABLE,
+ "name": "UTF-16BE",
+}
+
+# UCS2-LE
+# fmt: off
+UCS2LE_CLS = (
+ 0, 0, 0, 0, 0, 0, 0, 0, # 00 - 07
+ 0, 0, 1, 0, 0, 2, 0, 0, # 08 - 0f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 10 - 17
+ 0, 0, 0, 3, 0, 0, 0, 0, # 18 - 1f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 20 - 27
+ 0, 3, 3, 3, 3, 3, 0, 0, # 28 - 2f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 30 - 37
+ 0, 0, 0, 0, 0, 0, 0, 0, # 38 - 3f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 40 - 47
+ 0, 0, 0, 0, 0, 0, 0, 0, # 48 - 4f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 50 - 57
+ 0, 0, 0, 0, 0, 0, 0, 0, # 58 - 5f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 60 - 67
+ 0, 0, 0, 0, 0, 0, 0, 0, # 68 - 6f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 70 - 77
+ 0, 0, 0, 0, 0, 0, 0, 0, # 78 - 7f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 80 - 87
+ 0, 0, 0, 0, 0, 0, 0, 0, # 88 - 8f
+ 0, 0, 0, 0, 0, 0, 0, 0, # 90 - 97
+ 0, 0, 0, 0, 0, 0, 0, 0, # 98 - 9f
+ 0, 0, 0, 0, 0, 0, 0, 0, # a0 - a7
+ 0, 0, 0, 0, 0, 0, 0, 0, # a8 - af
+ 0, 0, 0, 0, 0, 0, 0, 0, # b0 - b7
+ 0, 0, 0, 0, 0, 0, 0, 0, # b8 - bf
+ 0, 0, 0, 0, 0, 0, 0, 0, # c0 - c7
+ 0, 0, 0, 0, 0, 0, 0, 0, # c8 - cf
+ 0, 0, 0, 0, 0, 0, 0, 0, # d0 - d7
+ 0, 0, 0, 0, 0, 0, 0, 0, # d8 - df
+ 0, 0, 0, 0, 0, 0, 0, 0, # e0 - e7
+ 0, 0, 0, 0, 0, 0, 0, 0, # e8 - ef
+ 0, 0, 0, 0, 0, 0, 0, 0, # f0 - f7
+ 0, 0, 0, 0, 0, 0, 4, 5 # f8 - ff
+)
+
+UCS2LE_ST = (
+ 6, 6, 7, 6, 4, 3,MachineState.ERROR,MachineState.ERROR,#00-07
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,#08-0f
+ MachineState.ITS_ME,MachineState.ITS_ME, 5, 5, 5,MachineState.ERROR,MachineState.ITS_ME,MachineState.ERROR,#10-17
+ 5, 5, 5,MachineState.ERROR, 5,MachineState.ERROR, 6, 6,#18-1f
+ 7, 6, 8, 8, 5, 5, 5,MachineState.ERROR,#20-27
+ 5, 5, 5,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, 5, 5,#28-2f
+ 5, 5, 5,MachineState.ERROR, 5,MachineState.ERROR,MachineState.START,MachineState.START #30-37
+)
+# fmt: on
+
+UCS2LE_CHAR_LEN_TABLE = (2, 2, 2, 2, 2, 2)
+
+UCS2LE_SM_MODEL = {
+ "class_table": UCS2LE_CLS,
+ "class_factor": 6,
+ "state_table": UCS2LE_ST,
+ "char_len_table": UCS2LE_CHAR_LEN_TABLE,
+ "name": "UTF-16LE",
+}
+
+# UTF-8
+# fmt: off
+UTF8_CLS = (
+ 1, 1, 1, 1, 1, 1, 1, 1, # 00 - 07 #allow 0x00 as a legal value
+ 1, 1, 1, 1, 1, 1, 0, 0, # 08 - 0f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 10 - 17
+ 1, 1, 1, 0, 1, 1, 1, 1, # 18 - 1f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 20 - 27
+ 1, 1, 1, 1, 1, 1, 1, 1, # 28 - 2f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 30 - 37
+ 1, 1, 1, 1, 1, 1, 1, 1, # 38 - 3f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 40 - 47
+ 1, 1, 1, 1, 1, 1, 1, 1, # 48 - 4f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 50 - 57
+ 1, 1, 1, 1, 1, 1, 1, 1, # 58 - 5f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 60 - 67
+ 1, 1, 1, 1, 1, 1, 1, 1, # 68 - 6f
+ 1, 1, 1, 1, 1, 1, 1, 1, # 70 - 77
+ 1, 1, 1, 1, 1, 1, 1, 1, # 78 - 7f
+ 2, 2, 2, 2, 3, 3, 3, 3, # 80 - 87
+ 4, 4, 4, 4, 4, 4, 4, 4, # 88 - 8f
+ 4, 4, 4, 4, 4, 4, 4, 4, # 90 - 97
+ 4, 4, 4, 4, 4, 4, 4, 4, # 98 - 9f
+ 5, 5, 5, 5, 5, 5, 5, 5, # a0 - a7
+ 5, 5, 5, 5, 5, 5, 5, 5, # a8 - af
+ 5, 5, 5, 5, 5, 5, 5, 5, # b0 - b7
+ 5, 5, 5, 5, 5, 5, 5, 5, # b8 - bf
+ 0, 0, 6, 6, 6, 6, 6, 6, # c0 - c7
+ 6, 6, 6, 6, 6, 6, 6, 6, # c8 - cf
+ 6, 6, 6, 6, 6, 6, 6, 6, # d0 - d7
+ 6, 6, 6, 6, 6, 6, 6, 6, # d8 - df
+ 7, 8, 8, 8, 8, 8, 8, 8, # e0 - e7
+ 8, 8, 8, 8, 8, 9, 8, 8, # e8 - ef
+ 10, 11, 11, 11, 11, 11, 11, 11, # f0 - f7
+ 12, 13, 13, 13, 14, 15, 0, 0 # f8 - ff
+)
+
+UTF8_ST = (
+ MachineState.ERROR,MachineState.START,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, 12, 10,#00-07
+ 9, 11, 8, 7, 6, 5, 4, 3,#08-0f
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#10-17
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#18-1f
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,#20-27
+ MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,MachineState.ITS_ME,#28-2f
+ MachineState.ERROR,MachineState.ERROR, 5, 5, 5, 5,MachineState.ERROR,MachineState.ERROR,#30-37
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#38-3f
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, 5, 5, 5,MachineState.ERROR,MachineState.ERROR,#40-47
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#48-4f
+ MachineState.ERROR,MachineState.ERROR, 7, 7, 7, 7,MachineState.ERROR,MachineState.ERROR,#50-57
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#58-5f
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, 7, 7,MachineState.ERROR,MachineState.ERROR,#60-67
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#68-6f
+ MachineState.ERROR,MachineState.ERROR, 9, 9, 9, 9,MachineState.ERROR,MachineState.ERROR,#70-77
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#78-7f
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, 9,MachineState.ERROR,MachineState.ERROR,#80-87
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#88-8f
+ MachineState.ERROR,MachineState.ERROR, 12, 12, 12, 12,MachineState.ERROR,MachineState.ERROR,#90-97
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#98-9f
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR, 12,MachineState.ERROR,MachineState.ERROR,#a0-a7
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#a8-af
+ MachineState.ERROR,MachineState.ERROR, 12, 12, 12,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#b0-b7
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,#b8-bf
+ MachineState.ERROR,MachineState.ERROR,MachineState.START,MachineState.START,MachineState.START,MachineState.START,MachineState.ERROR,MachineState.ERROR,#c0-c7
+ MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR,MachineState.ERROR #c8-cf
+)
+# fmt: on
+
+UTF8_CHAR_LEN_TABLE = (0, 1, 0, 0, 0, 0, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6)
+
+UTF8_SM_MODEL = {
+ "class_table": UTF8_CLS,
+ "class_factor": 16,
+ "state_table": UTF8_ST,
+ "char_len_table": UTF8_CHAR_LEN_TABLE,
+ "name": "UTF-8",
+}
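
Each *_SM_MODEL dict above is consumed by CodingStateMachine (codingstatemachine.py in this change-set): every byte is mapped to a class, and the flattened state table is indexed by current_state * class_factor + byte_class. A self-contained sketch of that lookup, for illustration only:

# Illustrative sketch (not part of the diff) mirroring CodingStateMachine.next_state().
def trace_states(sm_model, byte_str, start_state=0):
    state = start_state  # MachineState.START == 0
    for byte in byte_str:
        byte_class = sm_model["class_table"][byte]
        state = sm_model["state_table"][state * sm_model["class_factor"] + byte_class]
        yield state      # 1 == MachineState.ERROR, 2 == MachineState.ITS_ME

# e.g. list(trace_states(UTF8_SM_MODEL, "é".encode("utf-8"))) walks the UTF-8
# table through an intermediate state and back to START, never hitting ERROR.
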
diff --git a/lib/chardet/metadata/__init__.py b/lib/chardet/metadata/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/chardet/metadata/__init__.py
diff --git a/lib/chardet/metadata/languages.py b/lib/chardet/metadata/languages.py
new file mode 100644
index 0000000..1d37884
--- /dev/null
+++ b/lib/chardet/metadata/languages.py
@@ -0,0 +1,351 @@
+"""
+Metadata about languages used by our model training code for our
+SingleByteCharSetProbers. Could be used for other things in the future.
+
+This code is based on the language metadata from the uchardet project.
+"""
+
+from string import ascii_letters
+
+# TODO: Add Ukrainian (KOI8-U)
+
+
+class Language:
+ """Metadata about a language useful for training models
+
+ :ivar name: The human name for the language, in English.
+ :type name: str
+ :ivar iso_code: 2-letter ISO 639-1 if possible, 3-letter ISO code otherwise,
+ or use another catalog as a last resort.
+ :type iso_code: str
+ :ivar use_ascii: Whether or not ASCII letters should be included in trained
+ models.
+ :type use_ascii: bool
+ :ivar charsets: The charsets we want to support and create data for.
+ :type charsets: list of str
+ :ivar alphabet: The characters in the language's alphabet. If `use_ascii` is
+ `True`, you only need to add those not in the ASCII set.
+ :type alphabet: str
+ :ivar wiki_start_pages: The Wikipedia pages to start from if we're crawling
+ Wikipedia for training data.
+ :type wiki_start_pages: list of str
+ """
+
+ def __init__(
+ self,
+ name=None,
+ iso_code=None,
+ use_ascii=True,
+ charsets=None,
+ alphabet=None,
+ wiki_start_pages=None,
+ ):
+ super().__init__()
+ self.name = name
+ self.iso_code = iso_code
+ self.use_ascii = use_ascii
+ self.charsets = charsets
+ if self.use_ascii:
+ if alphabet:
+ alphabet += ascii_letters
+ else:
+ alphabet = ascii_letters
+ elif not alphabet:
+ raise ValueError("Must supply alphabet if use_ascii is False")
+ self.alphabet = "".join(sorted(set(alphabet))) if alphabet else None
+ self.wiki_start_pages = wiki_start_pages
+
+ def __repr__(self):
+ param_str = ", ".join(
+ f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")
+ )
+ return f"{self.__class__.__name__}({param_str})"
+
+
+LANGUAGES = {
+ "Arabic": Language(
+ name="Arabic",
+ iso_code="ar",
+ use_ascii=False,
+ # We only support encodings that use isolated
+ # forms, because the current recommendation is
+ # that the rendering system handles presentation
+ # forms. This means we purposefully skip IBM864.
+ charsets=["ISO-8859-6", "WINDOWS-1256", "CP720", "CP864"],
+ alphabet="ءآأؤإئابةتثجحخدذرزسشصضطظعغػؼؽؾؿـفقكلمنهوىيًٌٍَُِّ",
+ wiki_start_pages=["الصفحة_الرئيسية"],
+ ),
+ "Belarusian": Language(
+ name="Belarusian",
+ iso_code="be",
+ use_ascii=False,
+ charsets=["ISO-8859-5", "WINDOWS-1251", "IBM866", "MacCyrillic"],
+ alphabet="АБВГДЕЁЖЗІЙКЛМНОПРСТУЎФХЦЧШЫЬЭЮЯабвгдеёжзійклмнопрстуўфхцчшыьэюяʼ",
+ wiki_start_pages=["Галоўная_старонка"],
+ ),
+ "Bulgarian": Language(
+ name="Bulgarian",
+ iso_code="bg",
+ use_ascii=False,
+ charsets=["ISO-8859-5", "WINDOWS-1251", "IBM855"],
+ alphabet="АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя",
+ wiki_start_pages=["Начална_страница"],
+ ),
+ "Czech": Language(
+ name="Czech",
+ iso_code="cz",
+ use_ascii=True,
+ charsets=["ISO-8859-2", "WINDOWS-1250"],
+ alphabet="áčďéěíňóřšťúůýžÁČĎÉĚÍŇÓŘŠŤÚŮÝŽ",
+ wiki_start_pages=["Hlavní_strana"],
+ ),
+ "Danish": Language(
+ name="Danish",
+ iso_code="da",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "ISO-8859-15", "WINDOWS-1252"],
+ alphabet="æøåÆØÅ",
+ wiki_start_pages=["Forside"],
+ ),
+ "German": Language(
+ name="German",
+ iso_code="de",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "WINDOWS-1252"],
+ alphabet="äöüßÄÖÜ",
+ wiki_start_pages=["Wikipedia:Hauptseite"],
+ ),
+ "Greek": Language(
+ name="Greek",
+ iso_code="el",
+ use_ascii=False,
+ charsets=["ISO-8859-7", "WINDOWS-1253"],
+ alphabet="αβγδεζηθικλμνξοπρσςτυφχψωάέήίόύώΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΣΤΥΦΧΨΩΆΈΉΊΌΎΏ",
+ wiki_start_pages=["Πύλη:Κύρια"],
+ ),
+ "English": Language(
+ name="English",
+ iso_code="en",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "WINDOWS-1252"],
+ wiki_start_pages=["Main_Page"],
+ ),
+ "Esperanto": Language(
+ name="Esperanto",
+ iso_code="eo",
+ # Q, W, X, and Y not used at all
+ use_ascii=False,
+ charsets=["ISO-8859-3"],
+ alphabet="abcĉdefgĝhĥijĵklmnoprsŝtuŭvzABCĈDEFGĜHĤIJĴKLMNOPRSŜTUŬVZ",
+ wiki_start_pages=["Vikipedio:Ĉefpaĝo"],
+ ),
+ "Spanish": Language(
+ name="Spanish",
+ iso_code="es",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "ISO-8859-15", "WINDOWS-1252"],
+ alphabet="ñáéíóúüÑÁÉÍÓÚÜ",
+ wiki_start_pages=["Wikipedia:Portada"],
+ ),
+ "Estonian": Language(
+ name="Estonian",
+ iso_code="et",
+ use_ascii=False,
+ charsets=["ISO-8859-4", "ISO-8859-13", "WINDOWS-1257"],
+ # C, F, Š, Q, W, X, Y, Z, Ž are only for
+ # loanwords
+ alphabet="ABDEGHIJKLMNOPRSTUVÕÄÖÜabdeghijklmnoprstuvõäöü",
+ wiki_start_pages=["Esileht"],
+ ),
+ "Finnish": Language(
+ name="Finnish",
+ iso_code="fi",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "ISO-8859-15", "WINDOWS-1252"],
+ alphabet="ÅÄÖŠŽåäöšž",
+ wiki_start_pages=["Wikipedia:Etusivu"],
+ ),
+ "French": Language(
+ name="French",
+ iso_code="fr",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "ISO-8859-15", "WINDOWS-1252"],
+ alphabet="œàâçèéîïùûêŒÀÂÇÈÉÎÏÙÛÊ",
+ wiki_start_pages=["Wikipédia:Accueil_principal", "Bœuf (animal)"],
+ ),
+ "Hebrew": Language(
+ name="Hebrew",
+ iso_code="he",
+ use_ascii=False,
+ charsets=["ISO-8859-8", "WINDOWS-1255"],
+ alphabet="אבגדהוזחטיךכלםמןנסעףפץצקרשתװױײ",
+ wiki_start_pages=["עמוד_ראשי"],
+ ),
+ "Croatian": Language(
+ name="Croatian",
+ iso_code="hr",
+ # Q, W, X, Y are only used for foreign words.
+ use_ascii=False,
+ charsets=["ISO-8859-2", "WINDOWS-1250"],
+ alphabet="abcčćdđefghijklmnoprsštuvzžABCČĆDĐEFGHIJKLMNOPRSŠTUVZŽ",
+ wiki_start_pages=["Glavna_stranica"],
+ ),
+ "Hungarian": Language(
+ name="Hungarian",
+ iso_code="hu",
+ # Q, W, X, Y are only used for foreign words.
+ use_ascii=False,
+ charsets=["ISO-8859-2", "WINDOWS-1250"],
+ alphabet="abcdefghijklmnoprstuvzáéíóöőúüűABCDEFGHIJKLMNOPRSTUVZÁÉÍÓÖŐÚÜŰ",
+ wiki_start_pages=["Kezdőlap"],
+ ),
+ "Italian": Language(
+ name="Italian",
+ iso_code="it",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "ISO-8859-15", "WINDOWS-1252"],
+ alphabet="ÀÈÉÌÒÓÙàèéìòóù",
+ wiki_start_pages=["Pagina_principale"],
+ ),
+ "Lithuanian": Language(
+ name="Lithuanian",
+ iso_code="lt",
+ use_ascii=False,
+ charsets=["ISO-8859-13", "WINDOWS-1257", "ISO-8859-4"],
+ # Q, W, and X not used at all
+ alphabet="AĄBCČDEĘĖFGHIĮYJKLMNOPRSŠTUŲŪVZŽaąbcčdeęėfghiįyjklmnoprsštuųūvzž",
+ wiki_start_pages=["Pagrindinis_puslapis"],
+ ),
+ "Latvian": Language(
+ name="Latvian",
+ iso_code="lv",
+ use_ascii=False,
+ charsets=["ISO-8859-13", "WINDOWS-1257", "ISO-8859-4"],
+ # Q, W, X, Y are only for loanwords
+ alphabet="AĀBCČDEĒFGĢHIĪJKĶLĻMNŅOPRSŠTUŪVZŽaābcčdeēfgģhiījkķlļmnņoprsštuūvzž",
+ wiki_start_pages=["Sākumlapa"],
+ ),
+ "Macedonian": Language(
+ name="Macedonian",
+ iso_code="mk",
+ use_ascii=False,
+ charsets=["ISO-8859-5", "WINDOWS-1251", "MacCyrillic", "IBM855"],
+ alphabet="АБВГДЃЕЖЗЅИЈКЛЉМНЊОПРСТЌУФХЦЧЏШабвгдѓежзѕијклљмнњопрстќуфхцчџш",
+ wiki_start_pages=["Главна_страница"],
+ ),
+ "Dutch": Language(
+ name="Dutch",
+ iso_code="nl",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "WINDOWS-1252"],
+ wiki_start_pages=["Hoofdpagina"],
+ ),
+ "Polish": Language(
+ name="Polish",
+ iso_code="pl",
+ # Q and X are only used for foreign words.
+ use_ascii=False,
+ charsets=["ISO-8859-2", "WINDOWS-1250"],
+ alphabet="AĄBCĆDEĘFGHIJKLŁMNŃOÓPRSŚTUWYZŹŻaąbcćdeęfghijklłmnńoóprsśtuwyzźż",
+ wiki_start_pages=["Wikipedia:Strona_główna"],
+ ),
+ "Portuguese": Language(
+ name="Portuguese",
+ iso_code="pt",
+ use_ascii=True,
+ charsets=["ISO-8859-1", "ISO-8859-15", "WINDOWS-1252"],
+ alphabet="ÁÂÃÀÇÉÊÍÓÔÕÚáâãàçéêíóôõú",
+ wiki_start_pages=["Wikipédia:Página_principal"],
+ ),
+ "Romanian": Language(
+ name="Romanian",
+ iso_code="ro",
+ use_ascii=True,
+ charsets=["ISO-8859-2", "WINDOWS-1250"],
+ alphabet="ăâîșțĂÂÎȘȚ",
+ wiki_start_pages=["Pagina_principală"],
+ ),
+ "Russian": Language(
+ name="Russian",
+ iso_code="ru",
+ use_ascii=False,
+ charsets=[
+ "ISO-8859-5",
+ "WINDOWS-1251",
+ "KOI8-R",
+ "MacCyrillic",
+ "IBM866",
+ "IBM855",
+ ],
+ alphabet="абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ",
+ wiki_start_pages=["Заглавная_страница"],
+ ),
+ "Slovak": Language(
+ name="Slovak",
+ iso_code="sk",
+ use_ascii=True,
+ charsets=["ISO-8859-2", "WINDOWS-1250"],
+ alphabet="áäčďéíĺľňóôŕšťúýžÁÄČĎÉÍĹĽŇÓÔŔŠŤÚÝŽ",
+ wiki_start_pages=["Hlavná_stránka"],
+ ),
+ "Slovene": Language(
+ name="Slovene",
+ iso_code="sl",
+ # Q, W, X, Y are only used for foreign words.
+ use_ascii=False,
+ charsets=["ISO-8859-2", "WINDOWS-1250"],
+ alphabet="abcčdefghijklmnoprsštuvzžABCČDEFGHIJKLMNOPRSŠTUVZŽ",
+ wiki_start_pages=["Glavna_stran"],
+ ),
+ # Serbian can be written in both Latin and Cyrillic, but there's no
+ # simple way to get the Latin alphabet pages from Wikipedia through
+ # the API, so for now we just support Cyrillic.
+ "Serbian": Language(
+ name="Serbian",
+ iso_code="sr",
+ alphabet="АБВГДЂЕЖЗИЈКЛЉМНЊОПРСТЋУФХЦЧЏШабвгдђежзијклљмнњопрстћуфхцчџш",
+ charsets=["ISO-8859-5", "WINDOWS-1251", "MacCyrillic", "IBM855"],
+ wiki_start_pages=["Главна_страна"],
+ ),
+ "Thai": Language(
+ name="Thai",
+ iso_code="th",
+ use_ascii=False,
+ charsets=["ISO-8859-11", "TIS-620", "CP874"],
+ alphabet="กขฃคฅฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลฦวศษสหฬอฮฯะัาำิีึืฺุู฿เแโใไๅๆ็่้๊๋์ํ๎๏๐๑๒๓๔๕๖๗๘๙๚๛",
+ wiki_start_pages=["หน้าหลัก"],
+ ),
+ "Turkish": Language(
+ name="Turkish",
+ iso_code="tr",
+ # Q, W, and X are not used by Turkish
+ use_ascii=False,
+ charsets=["ISO-8859-3", "ISO-8859-9", "WINDOWS-1254"],
+ alphabet="abcçdefgğhıijklmnoöprsştuüvyzâîûABCÇDEFGĞHIİJKLMNOÖPRSŞTUÜVYZÂÎÛ",
+ wiki_start_pages=["Ana_Sayfa"],
+ ),
+ "Vietnamese": Language(
+ name="Vietnamese",
+ iso_code="vi",
+ use_ascii=False,
+ # Windows-1258 is the only common 8-bit
+ # Vietnamese encoding supported by Python.
+ # From Wikipedia:
+ # For systems that lack support for Unicode,
+ # dozens of 8-bit Vietnamese code pages are
+ # available.[1] The most common are VISCII
+ # (TCVN 5712:1993), VPS, and Windows-1258.[3]
+ # Where ASCII is required, such as when
+ # ensuring readability in plain text e-mail,
+ # Vietnamese letters are often encoded
+ # according to Vietnamese Quoted-Readable
+ # (VIQR) or VSCII Mnemonic (VSCII-MNEM),[4]
+ # though usage of either variable-width
+ # scheme has declined dramatically following
+ # the adoption of Unicode on the World Wide
+ # Web.
+ charsets=["WINDOWS-1258"],
+ alphabet="aăâbcdđeêghiklmnoôơpqrstuưvxyAĂÂBCDĐEÊGHIKLMNOÔƠPQRSTUƯVXY",
+ wiki_start_pages=["Chữ_Quốc_ngữ"],
+ ),
+}
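
The LANGUAGES mapping is primarily consumed by the model-training tooling, but it can also be queried directly. A small example, assuming lib/ is on sys.path so the package imports as chardet:

# Illustrative query (not part of the diff): which languages would get a
# model trained for a given single-byte charset.
from chardet.metadata.languages import LANGUAGES

windows_1250 = [
    lang.name for lang in LANGUAGES.values() if "WINDOWS-1250" in (lang.charsets or [])
]
print(sorted(windows_1250))  # e.g. ['Croatian', 'Czech', 'Hungarian', 'Polish', ...]
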
diff --git a/lib/chardet/sbcharsetprober.py b/lib/chardet/sbcharsetprober.py
new file mode 100644
index 0000000..31d70e1
--- /dev/null
+++ b/lib/chardet/sbcharsetprober.py
@@ -0,0 +1,160 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 2001
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+# Shy Shalom - original C code
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from collections import namedtuple
+
+from .charsetprober import CharSetProber
+from .enums import CharacterCategory, ProbingState, SequenceLikelihood
+
+SingleByteCharSetModel = namedtuple(
+ "SingleByteCharSetModel",
+ [
+ "charset_name",
+ "language",
+ "char_to_order_map",
+ "language_model",
+ "typical_positive_ratio",
+ "keep_ascii_letters",
+ "alphabet",
+ ],
+)
+
+
+class SingleByteCharSetProber(CharSetProber):
+ SAMPLE_SIZE = 64
+ SB_ENOUGH_REL_THRESHOLD = 1024 # 0.25 * SAMPLE_SIZE^2
+ POSITIVE_SHORTCUT_THRESHOLD = 0.95
+ NEGATIVE_SHORTCUT_THRESHOLD = 0.05
+
+ def __init__(self, model, is_reversed=False, name_prober=None):
+ super().__init__()
+ self._model = model
+ # TRUE if we need to reverse every pair in the model lookup
+ self._reversed = is_reversed
+ # Optional auxiliary prober for name decision
+ self._name_prober = name_prober
+ self._last_order = None
+ self._seq_counters = None
+ self._total_seqs = None
+ self._total_char = None
+ self._control_char = None
+ self._freq_char = None
+ self.reset()
+
+ def reset(self):
+ super().reset()
+ # char order of last character
+ self._last_order = 255
+ self._seq_counters = [0] * SequenceLikelihood.get_num_categories()
+ self._total_seqs = 0
+ self._total_char = 0
+ self._control_char = 0
+ # characters that fall in our sampling range
+ self._freq_char = 0
+
+ @property
+ def charset_name(self):
+ if self._name_prober:
+ return self._name_prober.charset_name
+ return self._model.charset_name
+
+ @property
+ def language(self):
+ if self._name_prober:
+ return self._name_prober.language
+ return self._model.language
+
+ def feed(self, byte_str):
+ # TODO: Make filter_international_words keep things in self.alphabet
+ if not self._model.keep_ascii_letters:
+ byte_str = self.filter_international_words(byte_str)
+ else:
+ byte_str = self.remove_xml_tags(byte_str)
+ if not byte_str:
+ return self.state
+ char_to_order_map = self._model.char_to_order_map
+ language_model = self._model.language_model
+ for char in byte_str:
+ order = char_to_order_map.get(char, CharacterCategory.UNDEFINED)
+ # XXX: This was SYMBOL_CAT_ORDER before, with a value of 250, but
+ # CharacterCategory.SYMBOL is actually 253, so we use CONTROL
+ # to make it closer to the original intent. The only difference
+ # is whether or not we count digits and control characters for
+ # _total_char purposes.
+ if order < CharacterCategory.CONTROL:
+ self._total_char += 1
+ if order < self.SAMPLE_SIZE:
+ self._freq_char += 1
+ if self._last_order < self.SAMPLE_SIZE:
+ self._total_seqs += 1
+ if not self._reversed:
+ lm_cat = language_model[self._last_order][order]
+ else:
+ lm_cat = language_model[order][self._last_order]
+ self._seq_counters[lm_cat] += 1
+ self._last_order = order
+
+ charset_name = self._model.charset_name
+ if self.state == ProbingState.DETECTING:
+ if self._total_seqs > self.SB_ENOUGH_REL_THRESHOLD:
+ confidence = self.get_confidence()
+ if confidence > self.POSITIVE_SHORTCUT_THRESHOLD:
+ self.logger.debug(
+ "%s confidence = %s, we have a winner", charset_name, confidence
+ )
+ self._state = ProbingState.FOUND_IT
+ elif confidence < self.NEGATIVE_SHORTCUT_THRESHOLD:
+ self.logger.debug(
+ "%s confidence = %s, below negative shortcut threshold %s",
+ charset_name,
+ confidence,
+ self.NEGATIVE_SHORTCUT_THRESHOLD,
+ )
+ self._state = ProbingState.NOT_ME
+
+ return self.state
+
+ def get_confidence(self):
+ r = 0.01
+ if self._total_seqs > 0:
+ r = (
+ (
+ self._seq_counters[SequenceLikelihood.POSITIVE]
+ + 0.25 * self._seq_counters[SequenceLikelihood.LIKELY]
+ )
+ / self._total_seqs
+ / self._model.typical_positive_ratio
+ )
+ # The more control characters (proportionally to the size
+ # of the text), the less confident we become in the current
+ # charset.
+ r = r * (self._total_char - self._control_char) / self._total_char
+ r = r * self._freq_char / self._total_char
+ if r >= 1.0:
+ r = 0.99
+ return r
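
Spelled out, the confidence above is the positive/likely sequence ratio normalized by the model's typical_positive_ratio, then damped by the share of control characters and by the share of characters inside the sampled frequency range. A worked example with invented counts (not taken from any model in this change):

# Illustrative arithmetic mirroring SingleByteCharSetProber.get_confidence().
positive, likely = 800, 100                      # SequenceLikelihood counts
total_seqs = 1100
total_char, control_char, freq_char = 1500, 20, 1200
typical_positive_ratio = 0.95

r = (positive + 0.25 * likely) / total_seqs / typical_positive_ratio  # ~0.789
r *= (total_char - control_char) / total_char    # penalize control characters   ~0.779
r *= freq_char / total_char                      # weight by in-sample characters ~0.623
r = min(r, 0.99)
print(round(r, 3))                               # 0.623
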
diff --git a/lib/chardet/sbcsgroupprober.py b/lib/chardet/sbcsgroupprober.py
new file mode 100644
index 0000000..cad001c
--- /dev/null
+++ b/lib/chardet/sbcsgroupprober.py
@@ -0,0 +1,88 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 2001
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+# Shy Shalom - original C code
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .charsetgroupprober import CharSetGroupProber
+from .hebrewprober import HebrewProber
+from .langbulgarianmodel import ISO_8859_5_BULGARIAN_MODEL, WINDOWS_1251_BULGARIAN_MODEL
+from .langgreekmodel import ISO_8859_7_GREEK_MODEL, WINDOWS_1253_GREEK_MODEL
+from .langhebrewmodel import WINDOWS_1255_HEBREW_MODEL
+
+# from .langhungarianmodel import (ISO_8859_2_HUNGARIAN_MODEL,
+# WINDOWS_1250_HUNGARIAN_MODEL)
+from .langrussianmodel import (
+ IBM855_RUSSIAN_MODEL,
+ IBM866_RUSSIAN_MODEL,
+ ISO_8859_5_RUSSIAN_MODEL,
+ KOI8_R_RUSSIAN_MODEL,
+ MACCYRILLIC_RUSSIAN_MODEL,
+ WINDOWS_1251_RUSSIAN_MODEL,
+)
+from .langthaimodel import TIS_620_THAI_MODEL
+from .langturkishmodel import ISO_8859_9_TURKISH_MODEL
+from .sbcharsetprober import SingleByteCharSetProber
+
+
+class SBCSGroupProber(CharSetGroupProber):
+ def __init__(self):
+ super().__init__()
+ hebrew_prober = HebrewProber()
+ logical_hebrew_prober = SingleByteCharSetProber(
+ WINDOWS_1255_HEBREW_MODEL, is_reversed=False, name_prober=hebrew_prober
+ )
+ # TODO: See if using ISO-8859-8 Hebrew model works better here, since
+ # it's actually the visual one
+ visual_hebrew_prober = SingleByteCharSetProber(
+ WINDOWS_1255_HEBREW_MODEL, is_reversed=True, name_prober=hebrew_prober
+ )
+ hebrew_prober.set_model_probers(logical_hebrew_prober, visual_hebrew_prober)
+ # TODO: ORDER MATTERS HERE. I changed the order vs what was in master
+ # and several tests failed that did not before. Some thought
+ # should be put into the ordering, and we should consider making
+ # order not matter here, because that is very counter-intuitive.
+ self.probers = [
+ SingleByteCharSetProber(WINDOWS_1251_RUSSIAN_MODEL),
+ SingleByteCharSetProber(KOI8_R_RUSSIAN_MODEL),
+ SingleByteCharSetProber(ISO_8859_5_RUSSIAN_MODEL),
+ SingleByteCharSetProber(MACCYRILLIC_RUSSIAN_MODEL),
+ SingleByteCharSetProber(IBM866_RUSSIAN_MODEL),
+ SingleByteCharSetProber(IBM855_RUSSIAN_MODEL),
+ SingleByteCharSetProber(ISO_8859_7_GREEK_MODEL),
+ SingleByteCharSetProber(WINDOWS_1253_GREEK_MODEL),
+ SingleByteCharSetProber(ISO_8859_5_BULGARIAN_MODEL),
+ SingleByteCharSetProber(WINDOWS_1251_BULGARIAN_MODEL),
+ # TODO: Restore Hungarian encodings (iso-8859-2 and windows-1250)
+ # after we retrain model.
+ # SingleByteCharSetProber(ISO_8859_2_HUNGARIAN_MODEL),
+ # SingleByteCharSetProber(WINDOWS_1250_HUNGARIAN_MODEL),
+ SingleByteCharSetProber(TIS_620_THAI_MODEL),
+ SingleByteCharSetProber(ISO_8859_9_TURKISH_MODEL),
+ hebrew_prober,
+ logical_hebrew_prober,
+ visual_hebrew_prober,
+ ]
+ self.reset()
diff --git a/lib/chardet/sjisprober.py b/lib/chardet/sjisprober.py
new file mode 100644
index 0000000..3bcbdb7
--- /dev/null
+++ b/lib/chardet/sjisprober.py
@@ -0,0 +1,98 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .chardistribution import SJISDistributionAnalysis
+from .codingstatemachine import CodingStateMachine
+from .enums import MachineState, ProbingState
+from .jpcntx import SJISContextAnalysis
+from .mbcharsetprober import MultiByteCharSetProber
+from .mbcssm import SJIS_SM_MODEL
+
+
+class SJISProber(MultiByteCharSetProber):
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(SJIS_SM_MODEL)
+ self.distribution_analyzer = SJISDistributionAnalysis()
+ self.context_analyzer = SJISContextAnalysis()
+ self.reset()
+
+ def reset(self):
+ super().reset()
+ self.context_analyzer.reset()
+
+ @property
+ def charset_name(self):
+ return self.context_analyzer.charset_name
+
+ @property
+ def language(self):
+ return "Japanese"
+
+ def feed(self, byte_str):
+ for i, byte in enumerate(byte_str):
+ coding_state = self.coding_sm.next_state(byte)
+ if coding_state == MachineState.ERROR:
+ self.logger.debug(
+ "%s %s prober hit error at byte %s",
+ self.charset_name,
+ self.language,
+ i,
+ )
+ self._state = ProbingState.NOT_ME
+ break
+ if coding_state == MachineState.ITS_ME:
+ self._state = ProbingState.FOUND_IT
+ break
+ if coding_state == MachineState.START:
+ char_len = self.coding_sm.get_current_charlen()
+ if i == 0:
+ self._last_char[1] = byte
+ self.context_analyzer.feed(
+ self._last_char[2 - char_len :], char_len
+ )
+ self.distribution_analyzer.feed(self._last_char, char_len)
+ else:
+ self.context_analyzer.feed(
+ byte_str[i + 1 - char_len : i + 3 - char_len], char_len
+ )
+ self.distribution_analyzer.feed(byte_str[i - 1 : i + 1], char_len)
+
+ self._last_char[0] = byte_str[-1]
+
+ if self.state == ProbingState.DETECTING:
+ if self.context_analyzer.got_enough_data() and (
+ self.get_confidence() > self.SHORTCUT_THRESHOLD
+ ):
+ self._state = ProbingState.FOUND_IT
+
+ return self.state
+
+ def get_confidence(self):
+ context_conf = self.context_analyzer.get_confidence()
+ distrib_conf = self.distribution_analyzer.get_confidence()
+ return max(context_conf, distrib_conf)
diff --git a/lib/chardet/universaldetector.py b/lib/chardet/universaldetector.py
new file mode 100644
index 0000000..22fcf82
--- /dev/null
+++ b/lib/chardet/universaldetector.py
@@ -0,0 +1,328 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is Mozilla Universal charset detector code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 2001
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+# Shy Shalom - original C code
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+"""
+Module containing the UniversalDetector class, which is the primary
+class a user of ``chardet`` should use.
+
+:author: Mark Pilgrim (initial port to Python)
+:author: Shy Shalom (original C code)
+:author: Dan Blanchard (major refactoring for 3.0)
+:author: Ian Cordasco
+"""
+
+
+import codecs
+import logging
+import re
+
+from .charsetgroupprober import CharSetGroupProber
+from .enums import InputState, LanguageFilter, ProbingState
+from .escprober import EscCharSetProber
+from .latin1prober import Latin1Prober
+from .mbcsgroupprober import MBCSGroupProber
+from .sbcsgroupprober import SBCSGroupProber
+from .utf1632prober import UTF1632Prober
+
+
+class UniversalDetector:
+ """
+ The ``UniversalDetector`` class underlies the ``chardet.detect`` function
+ and coordinates all of the different charset probers.
+
+ To get a ``dict`` containing an encoding and its confidence, you can simply
+ run:
+
+ .. code::
+
+ u = UniversalDetector()
+ u.feed(some_bytes)
+ u.close()
+ detected = u.result
+
+ """
+
+ MINIMUM_THRESHOLD = 0.20
+ HIGH_BYTE_DETECTOR = re.compile(b"[\x80-\xFF]")
+ ESC_DETECTOR = re.compile(b"(\033|~{)")
+ WIN_BYTE_DETECTOR = re.compile(b"[\x80-\x9F]")
+ ISO_WIN_MAP = {
+ "iso-8859-1": "Windows-1252",
+ "iso-8859-2": "Windows-1250",
+ "iso-8859-5": "Windows-1251",
+ "iso-8859-6": "Windows-1256",
+ "iso-8859-7": "Windows-1253",
+ "iso-8859-8": "Windows-1255",
+ "iso-8859-9": "Windows-1254",
+ "iso-8859-13": "Windows-1257",
+ }
+
+ def __init__(self, lang_filter=LanguageFilter.ALL):
+ self._esc_charset_prober = None
+ self._utf1632_prober = None
+ self._charset_probers = []
+ self.result = None
+ self.done = None
+ self._got_data = None
+ self._input_state = None
+ self._last_char = None
+ self.lang_filter = lang_filter
+ self.logger = logging.getLogger(__name__)
+ self._has_win_bytes = None
+ self.reset()
+
+ @property
+ def input_state(self):
+ return self._input_state
+
+ @property
+ def has_win_bytes(self):
+ return self._has_win_bytes
+
+ @property
+ def charset_probers(self):
+ return self._charset_probers
+
+ def reset(self):
+ """
+ Reset the UniversalDetector and all of its probers back to their
+ initial states. This is called by ``__init__``, so you only need to
+ call this directly in between analyses of different documents.
+ """
+ self.result = {"encoding": None, "confidence": 0.0, "language": None}
+ self.done = False
+ self._got_data = False
+ self._has_win_bytes = False
+ self._input_state = InputState.PURE_ASCII
+ self._last_char = b""
+ if self._esc_charset_prober:
+ self._esc_charset_prober.reset()
+ if self._utf1632_prober:
+ self._utf1632_prober.reset()
+ for prober in self._charset_probers:
+ prober.reset()
+
+ def feed(self, byte_str):
+ """
+ Takes a chunk of a document and feeds it through all of the relevant
+ charset probers.
+
+ After calling ``feed``, you can check the value of the ``done``
+ attribute to see if you need to continue feeding the
+ ``UniversalDetector`` more data, or if it has made a prediction
+ (in the ``result`` attribute).
+
+ .. note::
+ You should always call ``close`` when you're done feeding in your
+ document if ``done`` is not already ``True``.
+ """
+ if self.done:
+ return
+
+ if not byte_str:
+ return
+
+ if not isinstance(byte_str, bytearray):
+ byte_str = bytearray(byte_str)
+
+ # First check for known BOMs, since these are guaranteed to be correct
+ if not self._got_data:
+ # If the data starts with BOM, we know it is UTF
+ if byte_str.startswith(codecs.BOM_UTF8):
+ # EF BB BF UTF-8 with BOM
+ self.result = {
+ "encoding": "UTF-8-SIG",
+ "confidence": 1.0,
+ "language": "",
+ }
+ elif byte_str.startswith((codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE)):
+ # FF FE 00 00 UTF-32, little-endian BOM
+ # 00 00 FE FF UTF-32, big-endian BOM
+ self.result = {"encoding": "UTF-32", "confidence": 1.0, "language": ""}
+ elif byte_str.startswith(b"\xFE\xFF\x00\x00"):
+ # FE FF 00 00 UCS-4, unusual octet order BOM (3412)
+ self.result = {
+ "encoding": "X-ISO-10646-UCS-4-3412",
+ "confidence": 1.0,
+ "language": "",
+ }
+ elif byte_str.startswith(b"\x00\x00\xFF\xFE"):
+ # 00 00 FF FE UCS-4, unusual octet order BOM (2143)
+ self.result = {
+ "encoding": "X-ISO-10646-UCS-4-2143",
+ "confidence": 1.0,
+ "language": "",
+ }
+ elif byte_str.startswith((codecs.BOM_LE, codecs.BOM_BE)):
+ # FF FE UTF-16, little endian BOM
+ # FE FF UTF-16, big endian BOM
+ self.result = {"encoding": "UTF-16", "confidence": 1.0, "language": ""}
+
+ self._got_data = True
+ if self.result["encoding"] is not None:
+ self.done = True
+ return
+
+ # If none of those matched and we've only seen ASCII so far, check
+ # for high bytes and escape sequences
+ if self._input_state == InputState.PURE_ASCII:
+ if self.HIGH_BYTE_DETECTOR.search(byte_str):
+ self._input_state = InputState.HIGH_BYTE
+ elif (
+ self._input_state == InputState.PURE_ASCII
+ and self.ESC_DETECTOR.search(self._last_char + byte_str)
+ ):
+ self._input_state = InputState.ESC_ASCII
+
+ self._last_char = byte_str[-1:]
+
+ # next we will look to see if it appears to be either a UTF-16 or
+ # UTF-32 encoding
+ if not self._utf1632_prober:
+ self._utf1632_prober = UTF1632Prober()
+
+ if self._utf1632_prober.state == ProbingState.DETECTING:
+ if self._utf1632_prober.feed(byte_str) == ProbingState.FOUND_IT:
+ self.result = {
+ "encoding": self._utf1632_prober.charset_name,
+ "confidence": self._utf1632_prober.get_confidence(),
+ "language": "",
+ }
+ self.done = True
+ return
+
+ # If we've seen escape sequences, use the EscCharSetProber, which
+ # uses a simple state machine to check for known escape sequences in
+ # HZ and ISO-2022 encodings, since those are the only encodings that
+ # use such sequences.
+ if self._input_state == InputState.ESC_ASCII:
+ if not self._esc_charset_prober:
+ self._esc_charset_prober = EscCharSetProber(self.lang_filter)
+ if self._esc_charset_prober.feed(byte_str) == ProbingState.FOUND_IT:
+ self.result = {
+ "encoding": self._esc_charset_prober.charset_name,
+ "confidence": self._esc_charset_prober.get_confidence(),
+ "language": self._esc_charset_prober.language,
+ }
+ self.done = True
+ # If we've seen high bytes (i.e., those with values greater than 127),
+ # we need to do more complicated checks using all our multi-byte and
+ # single-byte probers that are left. The single-byte probers
+ # use character bigram distributions to determine the encoding, whereas
+ # the multi-byte probers use a combination of character unigram and
+ # bigram distributions.
+ elif self._input_state == InputState.HIGH_BYTE:
+ if not self._charset_probers:
+ self._charset_probers = [MBCSGroupProber(self.lang_filter)]
+ # If we're checking non-CJK encodings, use single-byte prober
+ if self.lang_filter & LanguageFilter.NON_CJK:
+ self._charset_probers.append(SBCSGroupProber())
+ self._charset_probers.append(Latin1Prober())
+ for prober in self._charset_probers:
+ if prober.feed(byte_str) == ProbingState.FOUND_IT:
+ self.result = {
+ "encoding": prober.charset_name,
+ "confidence": prober.get_confidence(),
+ "language": prober.language,
+ }
+ self.done = True
+ break
+ if self.WIN_BYTE_DETECTOR.search(byte_str):
+ self._has_win_bytes = True
+
+ def close(self):
+ """
+ Stop analyzing the current document and come up with a final
+ prediction.
+
+ :returns: The ``result`` attribute, a ``dict`` with the keys
+ `encoding`, `confidence`, and `language`.
+ """
+ # Don't bother with checks if we're already done
+ if self.done:
+ return self.result
+ self.done = True
+
+ if not self._got_data:
+ self.logger.debug("no data received!")
+
+ # Default to ASCII if it is all we've seen so far
+ elif self._input_state == InputState.PURE_ASCII:
+ self.result = {"encoding": "ascii", "confidence": 1.0, "language": ""}
+
+ # If we have seen non-ASCII, return the best that met MINIMUM_THRESHOLD
+ elif self._input_state == InputState.HIGH_BYTE:
+ prober_confidence = None
+ max_prober_confidence = 0.0
+ max_prober = None
+ for prober in self._charset_probers:
+ if not prober:
+ continue
+ prober_confidence = prober.get_confidence()
+ if prober_confidence > max_prober_confidence:
+ max_prober_confidence = prober_confidence
+ max_prober = prober
+ if max_prober and (max_prober_confidence > self.MINIMUM_THRESHOLD):
+ charset_name = max_prober.charset_name
+ lower_charset_name = max_prober.charset_name.lower()
+ confidence = max_prober.get_confidence()
+ # Use Windows encoding name instead of ISO-8859 if we saw any
+ # extra Windows-specific bytes
+ if lower_charset_name.startswith("iso-8859"):
+ if self._has_win_bytes:
+ charset_name = self.ISO_WIN_MAP.get(
+ lower_charset_name, charset_name
+ )
+ self.result = {
+ "encoding": charset_name,
+ "confidence": confidence,
+ "language": max_prober.language,
+ }
+
+ # Log all prober confidences if none met MINIMUM_THRESHOLD
+ if self.logger.getEffectiveLevel() <= logging.DEBUG:
+ if self.result["encoding"] is None:
+ self.logger.debug("no probers hit minimum threshold")
+ for group_prober in self._charset_probers:
+ if not group_prober:
+ continue
+ if isinstance(group_prober, CharSetGroupProber):
+ for prober in group_prober.probers:
+ self.logger.debug(
+ "%s %s confidence = %s",
+ prober.charset_name,
+ prober.language,
+ prober.get_confidence(),
+ )
+ else:
+ self.logger.debug(
+ "%s %s confidence = %s",
+ group_prober.charset_name,
+ group_prober.language,
+ group_prober.get_confidence(),
+ )
+ return self.result
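+
+# A minimal usage sketch expanding on the class docstring above (run e.g. as
+# ``python -m chardet.universaldetector`` with lib/ on sys.path;
+# "some_file.bin" is only a placeholder path):
+if __name__ == "__main__":
+    import sys
+
+    detector = UniversalDetector()
+    path = sys.argv[1] if len(sys.argv) > 1 else "some_file.bin"
+    with open(path, "rb") as handle:
+        # Feed the document in chunks and stop early once a prober is certain.
+        for chunk in iter(lambda: handle.read(4096), b""):
+            detector.feed(chunk)
+            if detector.done:
+                break
+    # close() finalizes the prediction and returns the ``result`` dict.
+    print(detector.close())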
diff --git a/lib/chardet/utf1632prober.py b/lib/chardet/utf1632prober.py
new file mode 100644
index 0000000..9fd1580
--- /dev/null
+++ b/lib/chardet/utf1632prober.py
@@ -0,0 +1,223 @@
+######################## BEGIN LICENSE BLOCK ########################
+#
+# Contributor(s):
+# Jason Zavaglia
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+from .charsetprober import CharSetProber
+from .enums import ProbingState
+
+
+class UTF1632Prober(CharSetProber):
+ """
+ This class simply looks for occurrences of zero bytes, and infers
+ whether the file is UTF-16 or UTF-32 (little-endian or big-endian).
+ For instance, files looking like ( \0 \0 \0 [nonzero] )+
+ have a good probability of being UTF-32BE. Files looking like ( \0 [nonzero] )+
+ may be guessed to be UTF-16BE, and conversely for the little-endian varieties.
+ """
+
+ # how many logical characters to scan before feeling confident of prediction
+ MIN_CHARS_FOR_DETECTION = 20
+ # a fixed constant ratio of expected zeros or non-zeros in modulo-position.
+ EXPECTED_RATIO = 0.94
+
+ def __init__(self):
+ super().__init__()
+ self.position = 0
+ self.zeros_at_mod = [0] * 4
+ self.nonzeros_at_mod = [0] * 4
+ self._state = ProbingState.DETECTING
+ self.quad = [0, 0, 0, 0]
+ self.invalid_utf16be = False
+ self.invalid_utf16le = False
+ self.invalid_utf32be = False
+ self.invalid_utf32le = False
+ self.first_half_surrogate_pair_detected_16be = False
+ self.first_half_surrogate_pair_detected_16le = False
+ self.reset()
+
+ def reset(self):
+ super().reset()
+ self.position = 0
+ self.zeros_at_mod = [0] * 4
+ self.nonzeros_at_mod = [0] * 4
+ self._state = ProbingState.DETECTING
+ self.invalid_utf16be = False
+ self.invalid_utf16le = False
+ self.invalid_utf32be = False
+ self.invalid_utf32le = False
+ self.first_half_surrogate_pair_detected_16be = False
+ self.first_half_surrogate_pair_detected_16le = False
+ self.quad = [0, 0, 0, 0]
+
+ @property
+ def charset_name(self):
+ if self.is_likely_utf32be():
+ return "utf-32be"
+ if self.is_likely_utf32le():
+ return "utf-32le"
+ if self.is_likely_utf16be():
+ return "utf-16be"
+ if self.is_likely_utf16le():
+ return "utf-16le"
+ # default to something valid
+ return "utf-16"
+
+ @property
+ def language(self):
+ return ""
+
+ def approx_32bit_chars(self):
+ return max(1.0, self.position / 4.0)
+
+ def approx_16bit_chars(self):
+ return max(1.0, self.position / 2.0)
+
+ def is_likely_utf32be(self):
+ approx_chars = self.approx_32bit_chars()
+ return approx_chars >= self.MIN_CHARS_FOR_DETECTION and (
+ self.zeros_at_mod[0] / approx_chars > self.EXPECTED_RATIO
+ and self.zeros_at_mod[1] / approx_chars > self.EXPECTED_RATIO
+ and self.zeros_at_mod[2] / approx_chars > self.EXPECTED_RATIO
+ and self.nonzeros_at_mod[3] / approx_chars > self.EXPECTED_RATIO
+ and not self.invalid_utf32be
+ )
+
+ def is_likely_utf32le(self):
+ approx_chars = self.approx_32bit_chars()
+ return approx_chars >= self.MIN_CHARS_FOR_DETECTION and (
+ self.nonzeros_at_mod[0] / approx_chars > self.EXPECTED_RATIO
+ and self.zeros_at_mod[1] / approx_chars > self.EXPECTED_RATIO
+ and self.zeros_at_mod[2] / approx_chars > self.EXPECTED_RATIO
+ and self.zeros_at_mod[3] / approx_chars > self.EXPECTED_RATIO
+ and not self.invalid_utf32le
+ )
+
+ def is_likely_utf16be(self):
+ approx_chars = self.approx_16bit_chars()
+ return approx_chars >= self.MIN_CHARS_FOR_DETECTION and (
+ (self.nonzeros_at_mod[1] + self.nonzeros_at_mod[3]) / approx_chars
+ > self.EXPECTED_RATIO
+ and (self.zeros_at_mod[0] + self.zeros_at_mod[2]) / approx_chars
+ > self.EXPECTED_RATIO
+ and not self.invalid_utf16be
+ )
+
+ def is_likely_utf16le(self):
+ approx_chars = self.approx_16bit_chars()
+ return approx_chars >= self.MIN_CHARS_FOR_DETECTION and (
+ (self.nonzeros_at_mod[0] + self.nonzeros_at_mod[2]) / approx_chars
+ > self.EXPECTED_RATIO
+ and (self.zeros_at_mod[1] + self.zeros_at_mod[3]) / approx_chars
+ > self.EXPECTED_RATIO
+ and not self.invalid_utf16le
+ )
+
+ def validate_utf32_characters(self, quad):
+ """
+ Validate if the quad of bytes is valid UTF-32.
+
+ UTF-32 is valid in the range 0x00000000 - 0x0010FFFF
+ excluding 0x0000D800 - 0x0000DFFF
+
+ https://en.wikipedia.org/wiki/UTF-32
+ """
+ if (
+ quad[0] != 0
+ or quad[1] > 0x10
+ or (quad[0] == 0 and quad[1] == 0 and 0xD8 <= quad[2] <= 0xDF)
+ ):
+ self.invalid_utf32be = True
+ if (
+ quad[3] != 0
+ or quad[2] > 0x10
+ or (quad[3] == 0 and quad[2] == 0 and 0xD8 <= quad[1] <= 0xDF)
+ ):
+ self.invalid_utf32le = True
+
+ def validate_utf16_characters(self, pair):
+ """
+ Validate if the pair of bytes is valid UTF-16.
+
+ UTF-16 is valid in the range 0x0000 - 0xFFFF excluding 0xD800 - 0xDFFF
+ with an exception for surrogate pairs, which must be in the range
+ 0xD800-0xDBFF followed by 0xDC00-0xDFFF
+
+ https://en.wikipedia.org/wiki/UTF-16
+ """
+ if not self.first_half_surrogate_pair_detected_16be:
+ if 0xD8 <= pair[0] <= 0xDB:
+ self.first_half_surrogate_pair_detected_16be = True
+ elif 0xDC <= pair[0] <= 0xDF:
+ self.invalid_utf16be = True
+ else:
+ if 0xDC <= pair[0] <= 0xDF:
+ self.first_half_surrogate_pair_detected_16be = False
+ else:
+ self.invalid_utf16be = True
+
+ if not self.first_half_surrogate_pair_detected_16le:
+ if 0xD8 <= pair[1] <= 0xDB:
+ self.first_half_surrogate_pair_detected_16le = True
+ elif 0xDC <= pair[1] <= 0xDF:
+ self.invalid_utf16le = True
+ else:
+ if 0xDC <= pair[1] <= 0xDF:
+ self.first_half_surrogate_pair_detected_16le = False
+ else:
+ self.invalid_utf16le = True
+
+ def feed(self, byte_str):
+ for c in byte_str:
+ mod4 = self.position % 4
+ self.quad[mod4] = c
+ if mod4 == 3:
+ self.validate_utf32_characters(self.quad)
+ self.validate_utf16_characters(self.quad[0:2])
+ self.validate_utf16_characters(self.quad[2:4])
+ if c == 0:
+ self.zeros_at_mod[mod4] += 1
+ else:
+ self.nonzeros_at_mod[mod4] += 1
+ self.position += 1
+ return self.state
+
+ @property
+ def state(self):
+ if self._state in {ProbingState.NOT_ME, ProbingState.FOUND_IT}:
+ # terminal, decided states
+ return self._state
+ if self.get_confidence() > 0.80:
+ self._state = ProbingState.FOUND_IT
+ elif self.position > 4 * 1024:
+ # if we get to 4kb into the file, and we can't conclude it's UTF,
+ # let's give up
+ self._state = ProbingState.NOT_ME
+ return self._state
+
+ def get_confidence(self):
+ return (
+ 0.85
+ if (
+ self.is_likely_utf16le()
+ or self.is_likely_utf16be()
+ or self.is_likely_utf32le()
+ or self.is_likely_utf32be()
+ )
+ else 0.00
+ )
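+
+# A small worked example of the heuristic described in the class docstring
+# (run e.g. as ``python -m chardet.utf1632prober`` with lib/ on sys.path):
+# ASCII text encoded as UTF-32BE produces the ( \0 \0 \0 [nonzero] )+ pattern,
+# so after enough characters the prober settles on "utf-32be".
+if __name__ == "__main__":
+    prober = UTF1632Prober()
+    prober.feed(("The quick brown fox jumps over the lazy dog. " * 2).encode("utf-32-be"))
+    print(prober.state, prober.charset_name, prober.get_confidence())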
diff --git a/lib/chardet/utf8prober.py b/lib/chardet/utf8prober.py
new file mode 100644
index 0000000..3aae09e
--- /dev/null
+++ b/lib/chardet/utf8prober.py
@@ -0,0 +1,80 @@
+######################## BEGIN LICENSE BLOCK ########################
+# The Original Code is mozilla.org code.
+#
+# The Initial Developer of the Original Code is
+# Netscape Communications Corporation.
+# Portions created by the Initial Developer are Copyright (C) 1998
+# the Initial Developer. All Rights Reserved.
+#
+# Contributor(s):
+# Mark Pilgrim - port to Python
+#
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+#
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
+# 02110-1301 USA
+######################### END LICENSE BLOCK #########################
+
+from .charsetprober import CharSetProber
+from .codingstatemachine import CodingStateMachine
+from .enums import MachineState, ProbingState
+from .mbcssm import UTF8_SM_MODEL
+
+
+class UTF8Prober(CharSetProber):
+ ONE_CHAR_PROB = 0.5
+
+ def __init__(self):
+ super().__init__()
+ self.coding_sm = CodingStateMachine(UTF8_SM_MODEL)
+ self._num_mb_chars = None
+ self.reset()
+
+ def reset(self):
+ super().reset()
+ self.coding_sm.reset()
+ self._num_mb_chars = 0
+
+ @property
+ def charset_name(self):
+ return "utf-8"
+
+ @property
+ def language(self):
+ return ""
+
+ def feed(self, byte_str):
+ for c in byte_str:
+ coding_state = self.coding_sm.next_state(c)
+ if coding_state == MachineState.ERROR:
+ self._state = ProbingState.NOT_ME
+ break
+ if coding_state == MachineState.ITS_ME:
+ self._state = ProbingState.FOUND_IT
+ break
+ if coding_state == MachineState.START:
+ if self.coding_sm.get_current_charlen() >= 2:
+ self._num_mb_chars += 1
+
+ if self.state == ProbingState.DETECTING:
+ if self.get_confidence() > self.SHORTCUT_THRESHOLD:
+ self._state = ProbingState.FOUND_IT
+
+ return self.state
+
+ def get_confidence(self):
+ unlike = 0.99
+ if self._num_mb_chars < 6:
+ unlike *= self.ONE_CHAR_PROB**self._num_mb_chars
+ return 1.0 - unlike
+ return unlike
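+
+# A minimal sketch (run e.g. as ``python -m chardet.utf8prober`` with lib/ on
+# sys.path): while fewer than six multi-byte sequences have been seen, each one
+# halves the "unlikely" factor, so confidence climbs quickly on accented text.
+if __name__ == "__main__":
+    prober = UTF8Prober()
+    prober.feed("naïve café résumé jalapeño piñata über".encode("utf-8"))
+    print(prober.get_confidence())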
diff --git a/lib/chardet/version.py b/lib/chardet/version.py
new file mode 100644
index 0000000..a08a06b
--- /dev/null
+++ b/lib/chardet/version.py
@@ -0,0 +1,9 @@
+"""
+This module exists only to simplify retrieving the version number of chardet
+from within setup.py and from chardet subpackages.
+
+:author: Dan Blanchard (dan.blanchard@gmail.com)
+"""
+
+__version__ = "5.0.0"
+VERSION = __version__.split(".")
diff --git a/lib/dbus/__init__.py b/lib/dbus/__init__.py
new file mode 100644
index 0000000..8cf3989
--- /dev/null
+++ b/lib/dbus/__init__.py
@@ -0,0 +1,93 @@
+"""\
+Implements the public API for a D-Bus client. See the dbus.service module
+to export objects or claim well-known names.
+"""
+
+# Copyright (C) 2003, 2004, 2005, 2006 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2003 David Zeuthen
+# Copyright (C) 2004 Rob Taylor
+# Copyright (C) 2005, 2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = [
+ # from _dbus
+ 'Bus', 'SystemBus', 'SessionBus', 'StarterBus',
+
+ # from proxies
+ 'Interface',
+
+ # from _dbus_bindings
+ 'get_default_main_loop', 'set_default_main_loop',
+
+ 'validate_interface_name', 'validate_member_name',
+ 'validate_bus_name', 'validate_object_path',
+ 'validate_error_name',
+
+ 'BUS_DAEMON_NAME', 'BUS_DAEMON_PATH', 'BUS_DAEMON_IFACE',
+ 'LOCAL_PATH', 'LOCAL_IFACE', 'PEER_IFACE',
+ 'INTROSPECTABLE_IFACE', 'PROPERTIES_IFACE',
+
+ 'ObjectPath', 'ByteArray', 'Signature', 'Byte', 'Boolean',
+ 'Int16', 'UInt16', 'Int32', 'UInt32', 'Int64', 'UInt64',
+ 'Double', 'String', 'Array', 'Struct', 'Dictionary',
+
+ # from exceptions
+ 'DBusException',
+ 'MissingErrorHandlerException', 'MissingReplyHandlerException',
+ 'ValidationException', 'IntrospectionParserException',
+ 'UnknownMethodException', 'NameExistsException',
+
+ # submodules
+ 'service', 'mainloop', 'lowlevel'
+ ]
+
+from dbus._compat import is_py2
+
+__docformat__ = 'restructuredtext'
+
+# OLPC Sugar compatibility
+import dbus.exceptions as exceptions
+import dbus.types as types
+
+from _dbus_bindings import __version__
+version = tuple(map(int, __version__.split('.')))
+
+from _dbus_bindings import (
+ get_default_main_loop, set_default_main_loop, validate_bus_name,
+ validate_error_name, validate_interface_name, validate_member_name,
+ validate_object_path)
+from _dbus_bindings import (
+ BUS_DAEMON_IFACE, BUS_DAEMON_NAME, BUS_DAEMON_PATH, INTROSPECTABLE_IFACE,
+ LOCAL_IFACE, LOCAL_PATH, PEER_IFACE, PROPERTIES_IFACE)
+
+from dbus.exceptions import (
+ DBusException, IntrospectionParserException, MissingErrorHandlerException,
+ MissingReplyHandlerException, NameExistsException, UnknownMethodException,
+ ValidationException)
+from _dbus_bindings import (
+ Array, Boolean, Byte, ByteArray, Dictionary, Double, Int16, Int32, Int64,
+ ObjectPath, Signature, String, Struct, UInt16, UInt32, UInt64)
+
+from dbus._dbus import Bus, SystemBus, SessionBus, StarterBus
+from dbus.proxies import Interface
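+
+# A minimal client sketch of the API exported above (requires the compiled
+# _dbus_bindings extension and a running session bus; org.freedesktop.DBus is
+# the bus daemon itself, so no extra services need to be installed):
+if __name__ == "__main__":
+    bus = SessionBus()
+    proxy = bus.get_object('org.freedesktop.DBus', '/org/freedesktop/DBus')
+    daemon = Interface(proxy, dbus_interface='org.freedesktop.DBus')
+    # Blocking calls work without a main loop; only signal handling needs one.
+    print(daemon.ListNames())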
diff --git a/lib/dbus/_compat.py b/lib/dbus/_compat.py
new file mode 100644
index 0000000..3e5f148
--- /dev/null
+++ b/lib/dbus/_compat.py
@@ -0,0 +1,15 @@
+# Python 2 / Python 3 compatibility helpers.
+# Copyright 2011 Barry Warsaw
+# Copyright 2021 Collabora Ltd.
+# SPDX-License-Identifier: MIT
+
+import sys
+
+is_py3 = True
+is_py2 = False
+
+if sys.version_info.major < 3:
+ raise AssertionError(
+ 'Python 2 has reached end-of-life, and dbus-python no longer '
+ 'supports it.'
+ )
diff --git a/lib/dbus/_dbus.py b/lib/dbus/_dbus.py
new file mode 100644
index 0000000..2891194
--- /dev/null
+++ b/lib/dbus/_dbus.py
@@ -0,0 +1,229 @@
+"""Implementation for dbus.Bus. Not to be imported directly."""
+
+# Copyright (C) 2003, 2004, 2005, 2006 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2003 David Zeuthen
+# Copyright (C) 2004 Rob Taylor
+# Copyright (C) 2005, 2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+from __future__ import generators
+
+__all__ = ('Bus', 'SystemBus', 'SessionBus', 'StarterBus')
+__docformat__ = 'reStructuredText'
+
+from dbus.exceptions import DBusException
+from _dbus_bindings import (
+ BUS_DAEMON_IFACE, BUS_DAEMON_NAME, BUS_DAEMON_PATH, BUS_SESSION,
+ BUS_STARTER, BUS_SYSTEM, DBUS_START_REPLY_ALREADY_RUNNING,
+ DBUS_START_REPLY_SUCCESS, validate_bus_name,
+ validate_interface_name, validate_member_name, validate_object_path)
+from dbus.bus import BusConnection
+from dbus.lowlevel import SignalMessage
+from dbus._compat import is_py2
+
+
+class Bus(BusConnection):
+ """A connection to one of three possible standard buses, the SESSION,
+ SYSTEM, or STARTER bus. This class manages shared connections to those
+ buses.
+
+ If you're trying to subclass `Bus`, you may be better off subclassing
+ `BusConnection`, which doesn't have all this magic.
+ """
+
+ _shared_instances = {}
+
+ def __new__(cls, bus_type=BusConnection.TYPE_SESSION, private=False,
+ mainloop=None):
+ """Constructor, returning an existing instance where appropriate.
+
+ The returned instance is actually always an instance of `SessionBus`,
+ `SystemBus` or `StarterBus`.
+
+ :Parameters:
+ `bus_type` : cls.TYPE_SESSION, cls.TYPE_SYSTEM or cls.TYPE_STARTER
+ Connect to the appropriate bus
+ `private` : bool
+ If true, never return an existing shared instance, but instead
+ return a private connection.
+
+ :Deprecated: since 0.82.3. Use dbus.bus.BusConnection for
+ private connections.
+
+ `mainloop` : dbus.mainloop.NativeMainLoop
+ The main loop to use. The default is to use the default
+ main loop if one has been set up, or raise an exception
+ if none has been.
+ :Changed: in dbus-python 0.80:
+ converted from a wrapper around a Connection to a Connection
+ subclass.
+ """
+ if (not private and bus_type in cls._shared_instances):
+ return cls._shared_instances[bus_type]
+
+ # this is a bit odd, but we create instances of the subtypes
+ # so we can return the shared instances if someone tries to
+ # construct one of them (otherwise we'd eg try and return an
+ # instance of Bus from __new__ in SessionBus). why are there
+ # three ways to construct this class? we just don't know.
+ if bus_type == BUS_SESSION:
+ subclass = SessionBus
+ elif bus_type == BUS_SYSTEM:
+ subclass = SystemBus
+ elif bus_type == BUS_STARTER:
+ subclass = StarterBus
+ else:
+ raise ValueError('invalid bus_type %s' % bus_type)
+
+ bus = BusConnection.__new__(subclass, bus_type, mainloop=mainloop)
+
+ bus._bus_type = bus_type
+
+ if not private:
+ cls._shared_instances[bus_type] = bus
+
+ return bus
+
+ def close(self):
+ t = self._bus_type
+ if self.__class__._shared_instances.get(t) is self:
+ del self.__class__._shared_instances[t]
+ super(Bus, self).close()
+
+ def get_connection(self):
+ """Return self, for backwards compatibility with earlier dbus-python
+ versions where Bus was not a subclass of Connection.
+
+ :Deprecated: since 0.80.0
+ """
+ return self
+ _connection = property(get_connection, None, None,
+ """self._connection == self, for backwards
+ compatibility with earlier dbus-python versions
+ where Bus was not a subclass of Connection.""")
+
+ def get_session(private=False):
+ """Static method that returns a connection to the session bus.
+
+ :Parameters:
+ `private` : bool
+ If true, do not return a shared connection.
+ """
+ return SessionBus(private=private)
+
+ get_session = staticmethod(get_session)
+
+ def get_system(private=False):
+ """Static method that returns a connection to the system bus.
+
+ :Parameters:
+ `private` : bool
+ If true, do not return a shared connection.
+ """
+ return SystemBus(private=private)
+
+ get_system = staticmethod(get_system)
+
+
+ def get_starter(private=False):
+ """Static method that returns a connection to the starter bus.
+
+ :Parameters:
+ `private` : bool
+ If true, do not return a shared connection.
+ """
+ return StarterBus(private=private)
+
+ get_starter = staticmethod(get_starter)
+
+ def __repr__(self):
+ if self._bus_type == BUS_SESSION:
+ name = 'session'
+ elif self._bus_type == BUS_SYSTEM:
+ name = 'system'
+ elif self._bus_type == BUS_STARTER:
+ name = 'starter'
+ else:
+ name = 'unknown bus type'
+
+ return '<%s.%s (%s) at %#x>' % (self.__class__.__module__,
+ self.__class__.__name__,
+ name, id(self))
+ __str__ = __repr__
+
+
+# FIXME: Drop the subclasses here? I can't think why we'd ever want
+# polymorphism
+class SystemBus(Bus):
+ """The system-wide message bus."""
+ def __new__(cls, private=False, mainloop=None):
+ """Return a connection to the system bus.
+
+ :Parameters:
+ `private` : bool
+ If true, never return an existing shared instance, but instead
+ return a private connection.
+ `mainloop` : dbus.mainloop.NativeMainLoop
+ The main loop to use. The default is to use the default
+ main loop if one has been set up, or raise an exception
+ if none has been.
+ """
+ return Bus.__new__(cls, Bus.TYPE_SYSTEM, mainloop=mainloop,
+ private=private)
+
+class SessionBus(Bus):
+ """The session (current login) message bus."""
+ def __new__(cls, private=False, mainloop=None):
+ """Return a connection to the session bus.
+
+ :Parameters:
+ `private` : bool
+ If true, never return an existing shared instance, but instead
+ return a private connection.
+ `mainloop` : dbus.mainloop.NativeMainLoop
+ The main loop to use. The default is to use the default
+ main loop if one has been set up, or raise an exception
+ if none has been.
+ """
+ return Bus.__new__(cls, Bus.TYPE_SESSION, private=private,
+ mainloop=mainloop)
+
+class StarterBus(Bus):
+ """The bus that activated this process (only valid if
+ this process was launched by DBus activation).
+ """
+ def __new__(cls, private=False, mainloop=None):
+ """Return a connection to the bus that activated this process.
+
+ :Parameters:
+ `private` : bool
+ If true, never return an existing shared instance, but instead
+ return a private connection.
+ `mainloop` : dbus.mainloop.NativeMainLoop
+ The main loop to use. The default is to use the default
+ main loop if one has been set up, or raise an exception
+ if none has been.
+ """
+ return Bus.__new__(cls, Bus.TYPE_STARTER, private=private,
+ mainloop=mainloop)
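+
+# A short sketch of the sharing behaviour implemented in Bus.__new__ (requires
+# _dbus_bindings and a running session bus): repeated constructions hand back
+# the same shared connection unless private=True is passed.
+if __name__ == "__main__":
+    shared_a = SessionBus()
+    shared_b = SessionBus()
+    private = SessionBus(private=True)
+    print(shared_a is shared_b)  # True: the shared instance is reused
+    print(shared_a is private)   # False: private connections are never cached
+    private.close()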
diff --git a/lib/dbus/_expat_introspect_parser.py b/lib/dbus/_expat_introspect_parser.py
new file mode 100644
index 0000000..2c6f341
--- /dev/null
+++ b/lib/dbus/_expat_introspect_parser.py
@@ -0,0 +1,87 @@
+# Copyright (C) 2003, 2004, 2005, 2006 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2003 David Zeuthen
+# Copyright (C) 2004 Rob Taylor
+# Copyright (C) 2005, 2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+from xml.parsers.expat import ParserCreate
+from dbus.exceptions import IntrospectionParserException
+
+class _Parser(object):
+ __slots__ = ('map', 'in_iface', 'in_method', 'sig')
+ def __init__(self):
+ self.map = {}
+ self.in_iface = ''
+ self.in_method = ''
+ self.sig = ''
+
+ def parse(self, data):
+ parser = ParserCreate('UTF-8', ' ')
+ parser.buffer_text = True
+ parser.StartElementHandler = self.StartElementHandler
+ parser.EndElementHandler = self.EndElementHandler
+ parser.Parse(data)
+ return self.map
+
+ def StartElementHandler(self, name, attributes):
+ if not self.in_iface:
+ if (not self.in_method and name == 'interface'):
+ self.in_iface = attributes['name']
+ else:
+ if (not self.in_method and name == 'method'):
+ self.in_method = attributes['name']
+ elif (self.in_method and name == 'arg'):
+ if attributes.get('direction', 'in') == 'in':
+ self.sig += attributes['type']
+
+ def EndElementHandler(self, name):
+ if self.in_iface:
+ if (not self.in_method and name == 'interface'):
+ self.in_iface = ''
+ elif (self.in_method and name == 'method'):
+ self.map[self.in_iface + '.' + self.in_method] = self.sig
+ self.in_method = ''
+ self.sig = ''
+
+def process_introspection_data(data):
+ """Return a dict mapping ``interface.method`` strings to the
+ concatenation of all their 'in' parameters, and mapping
+ ``interface.signal`` strings to the concatenation of all their
+ parameters.
+
+ Example output::
+
+ {
+ 'com.example.SignalEmitter.OneString': 's',
+ 'com.example.MethodImplementor.OneInt32Argument': 'i',
+ }
+
+ :Parameters:
+ `data` : str
+ The introspection XML. Must be an 8-bit string of UTF-8.
+ """
+ try:
+ return _Parser().parse(data)
+ except Exception as e:
+ raise IntrospectionParserException('%s: %s' % (e.__class__, e))
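+
+# A small worked example of the mapping described above (the interface and
+# method names are invented for illustration; only 'in' arguments contribute
+# to the signature string):
+if __name__ == "__main__":
+    example_xml = """<node>
+      <interface name="com.example.Calculator">
+        <method name="Add">
+          <arg name="a" type="i" direction="in"/>
+          <arg name="b" type="i" direction="in"/>
+          <arg name="sum" type="i" direction="out"/>
+        </method>
+      </interface>
+    </node>"""
+    print(process_introspection_data(example_xml))
+    # -> {'com.example.Calculator.Add': 'ii'}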
diff --git a/lib/dbus/bus.py b/lib/dbus/bus.py
new file mode 100644
index 0000000..fd5a281
--- /dev/null
+++ b/lib/dbus/bus.py
@@ -0,0 +1,434 @@
+# Copyright (C) 2007 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = ('BusConnection',)
+__docformat__ = 'reStructuredText'
+
+import logging
+import weakref
+
+from _dbus_bindings import (
+ BUS_DAEMON_IFACE, BUS_DAEMON_NAME, BUS_DAEMON_PATH, BUS_SESSION,
+ BUS_STARTER, BUS_SYSTEM, DBUS_START_REPLY_ALREADY_RUNNING,
+ DBUS_START_REPLY_SUCCESS, NAME_FLAG_ALLOW_REPLACEMENT,
+ NAME_FLAG_DO_NOT_QUEUE, NAME_FLAG_REPLACE_EXISTING,
+ RELEASE_NAME_REPLY_NON_EXISTENT, RELEASE_NAME_REPLY_NOT_OWNER,
+ RELEASE_NAME_REPLY_RELEASED, REQUEST_NAME_REPLY_ALREADY_OWNER,
+ REQUEST_NAME_REPLY_EXISTS, REQUEST_NAME_REPLY_IN_QUEUE,
+ REQUEST_NAME_REPLY_PRIMARY_OWNER, validate_bus_name, validate_error_name,
+ validate_interface_name, validate_member_name, validate_object_path)
+from dbus.connection import Connection
+from dbus.exceptions import DBusException
+from dbus.lowlevel import HANDLER_RESULT_NOT_YET_HANDLED
+from dbus._compat import is_py2
+
+
+_NAME_OWNER_CHANGE_MATCH = ("type='signal',sender='%s',"
+ "interface='%s',member='NameOwnerChanged',"
+ "path='%s',arg0='%%s'"
+ % (BUS_DAEMON_NAME, BUS_DAEMON_IFACE,
+ BUS_DAEMON_PATH))
+"""(_NAME_OWNER_CHANGE_MATCH % sender) matches relevant NameOwnerChange
+messages"""
+
+_NAME_HAS_NO_OWNER = 'org.freedesktop.DBus.Error.NameHasNoOwner'
+
+_logger = logging.getLogger('dbus.bus')
+
+
+class NameOwnerWatch(object):
+ __slots__ = ('_match', '_pending_call')
+
+ def __init__(self, bus_conn, bus_name, callback):
+ validate_bus_name(bus_name)
+
+ def signal_cb(owned, old_owner, new_owner):
+ callback(new_owner)
+
+ def error_cb(e):
+ if e.get_dbus_name() == _NAME_HAS_NO_OWNER:
+ callback('')
+ else:
+ logging.basicConfig()
+ _logger.debug('GetNameOwner(%s) failed:', bus_name,
+ exc_info=(e.__class__, e, None))
+
+ self._match = bus_conn.add_signal_receiver(signal_cb,
+ 'NameOwnerChanged',
+ BUS_DAEMON_IFACE,
+ BUS_DAEMON_NAME,
+ BUS_DAEMON_PATH,
+ arg0=bus_name)
+ self._pending_call = bus_conn.call_async(BUS_DAEMON_NAME,
+ BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE,
+ 'GetNameOwner',
+ 's', (bus_name,),
+ callback, error_cb)
+
+ def cancel(self):
+ if self._match is not None:
+ self._match.remove()
+ if self._pending_call is not None:
+ self._pending_call.cancel()
+ self._match = None
+ self._pending_call = None
+
+
+class BusConnection(Connection):
+ """A connection to a D-Bus daemon that implements the
+ ``org.freedesktop.DBus`` pseudo-service.
+
+ :Since: 0.81.0
+ """
+
+ TYPE_SESSION = BUS_SESSION
+ """Represents a session bus (same as the global dbus.BUS_SESSION)"""
+
+ TYPE_SYSTEM = BUS_SYSTEM
+ """Represents the system bus (same as the global dbus.BUS_SYSTEM)"""
+
+ TYPE_STARTER = BUS_STARTER
+ """Represents the bus that started this service by activation (same as
+ the global dbus.BUS_STARTER)"""
+
+ START_REPLY_SUCCESS = DBUS_START_REPLY_SUCCESS
+ START_REPLY_ALREADY_RUNNING = DBUS_START_REPLY_ALREADY_RUNNING
+
+ def __new__(cls, address_or_type=TYPE_SESSION, mainloop=None):
+ bus = cls._new_for_bus(address_or_type, mainloop=mainloop)
+
+ # _bus_names is used by dbus.service.BusName!
+ bus._bus_names = weakref.WeakValueDictionary()
+
+ bus._signal_sender_matches = {}
+ """Map from SignalMatch to NameOwnerWatch."""
+
+ return bus
+
+ def add_signal_receiver(self, handler_function, signal_name=None,
+ dbus_interface=None, bus_name=None,
+ path=None, **keywords):
+ named_service = keywords.pop('named_service', None)
+ if named_service is not None:
+ if bus_name is not None:
+ raise TypeError('bus_name and named_service cannot both be '
+ 'specified')
+ bus_name = named_service
+ from warnings import warn
+ warn('Passing the named_service parameter to add_signal_receiver '
+ 'by name is deprecated: please use positional parameters',
+ DeprecationWarning, stacklevel=2)
+
+ match = super(BusConnection, self).add_signal_receiver(
+ handler_function, signal_name, dbus_interface, bus_name,
+ path, **keywords)
+
+ if (bus_name is not None and bus_name != BUS_DAEMON_NAME):
+ if bus_name[:1] == ':':
+ def callback(new_owner):
+ if new_owner == '':
+ match.remove()
+ else:
+ callback = match.set_sender_name_owner
+ watch = self.watch_name_owner(bus_name, callback)
+ self._signal_sender_matches[match] = watch
+
+ self.add_match_string(str(match))
+
+ return match
+
+ def _clean_up_signal_match(self, match):
+ # The signals lock is no longer held here (it was in <= 0.81.0)
+ self.remove_match_string_non_blocking(str(match))
+ watch = self._signal_sender_matches.pop(match, None)
+ if watch is not None:
+ watch.cancel()
+
+ def activate_name_owner(self, bus_name):
+ if (bus_name is not None and bus_name[:1] != ':'
+ and bus_name != BUS_DAEMON_NAME):
+ try:
+ return self.get_name_owner(bus_name)
+ except DBusException as e:
+ if e.get_dbus_name() != _NAME_HAS_NO_OWNER:
+ raise
+ # else it doesn't exist: try to start it
+ self.start_service_by_name(bus_name)
+ return self.get_name_owner(bus_name)
+ else:
+ # already unique
+ return bus_name
+
+ def get_object(self, bus_name, object_path, introspect=True,
+ follow_name_owner_changes=False, **kwargs):
+ """Return a local proxy for the given remote object.
+
+ Method calls on the proxy are translated into method calls on the
+ remote object.
+
+ :Parameters:
+ `bus_name` : str
+ A bus name (either the unique name or a well-known name)
+ of the application owning the object. The keyword argument
+ named_service is a deprecated alias for this.
+ `object_path` : str
+ The object path of the desired object
+ `introspect` : bool
+ If true (default), attempt to introspect the remote
+ object to find out supported methods and their signatures
+ `follow_name_owner_changes` : bool
+ If the bus name is a well-known name and this parameter
+ is false (default), resolve the well-known name to the unique
+ name of its current owner and bind to that instead; if the
+ ownership of the well-known name changes in future,
+ keep communicating with the original owner.
+ This is necessary if the D-Bus API used is stateful.
+
+ If the bus name is a well-known name and this parameter
+ is true, whenever the well-known name changes ownership in
+ future, bind to the new owner, if any.
+
+ If the given bus name is a unique name, this parameter
+ has no effect.
+
+ :Returns: a `dbus.proxies.ProxyObject`
+ :Raises `DBusException`: if resolving the well-known name to a
+ unique name fails
+ """
+ if follow_name_owner_changes:
+ self._require_main_loop() # we don't get the signals otherwise
+
+ named_service = kwargs.pop('named_service', None)
+ if named_service is not None:
+ if bus_name is not None:
+ raise TypeError('bus_name and named_service cannot both '
+ 'be specified')
+ from warnings import warn
+ warn('Passing the named_service parameter to get_object by name '
+ 'is deprecated: please use positional parameters',
+ DeprecationWarning, stacklevel=2)
+ bus_name = named_service
+ if kwargs:
+ raise TypeError('get_object does not take these keyword '
+ 'arguments: %s' % ', '.join(kwargs.keys()))
+
+ return self.ProxyObjectClass(self, bus_name, object_path,
+ introspect=introspect,
+ follow_name_owner_changes=follow_name_owner_changes)
+
+ def get_unix_user(self, bus_name):
+ """Get the numeric uid of the process owning the given bus name.
+
+ :Parameters:
+ `bus_name` : str
+ A bus name, either unique or well-known
+ :Returns: a `dbus.UInt32`
+ :Since: 0.80.0
+ """
+ validate_bus_name(bus_name)
+ return self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'GetConnectionUnixUser',
+ 's', (bus_name,))
+
+ def start_service_by_name(self, bus_name, flags=0):
+ """Start a service which will implement the given bus name on this Bus.
+
+ :Parameters:
+ `bus_name` : str
+ The well-known bus name to be activated.
+ `flags` : dbus.UInt32
+ Flags to pass to StartServiceByName (currently none are
+ defined)
+
+ :Returns: A tuple of 2 elements. The first is always True, the
+ second is either START_REPLY_SUCCESS or
+ START_REPLY_ALREADY_RUNNING.
+
+ :Raises `DBusException`: if the service could not be started.
+ :Since: 0.80.0
+ """
+ validate_bus_name(bus_name)
+ return (True, self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE,
+ 'StartServiceByName',
+ 'su', (bus_name, flags)))
+
+ # XXX: it might be nice to signal IN_QUEUE, EXISTS by exception,
+ # but this would not be backwards-compatible
+ def request_name(self, name, flags=0):
+ """Request a bus name.
+
+ :Parameters:
+ `name` : str
+ The well-known name to be requested
+ `flags` : dbus.UInt32
+ A bitwise-OR of 0 or more of the flags
+ `NAME_FLAG_ALLOW_REPLACEMENT`,
+ `NAME_FLAG_REPLACE_EXISTING`
+ and `NAME_FLAG_DO_NOT_QUEUE`
+ :Returns: `REQUEST_NAME_REPLY_PRIMARY_OWNER`,
+ `REQUEST_NAME_REPLY_IN_QUEUE`,
+ `REQUEST_NAME_REPLY_EXISTS` or
+ `REQUEST_NAME_REPLY_ALREADY_OWNER`
+ :Raises `DBusException`: if the bus daemon cannot be contacted or
+ returns an error.
+ """
+ validate_bus_name(name, allow_unique=False)
+ return self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'RequestName',
+ 'su', (name, flags))
+
+ def release_name(self, name):
+ """Release a bus name.
+
+ :Parameters:
+ `name` : str
+ The well-known name to be released
+ :Returns: `RELEASE_NAME_REPLY_RELEASED`,
+ `RELEASE_NAME_REPLY_NON_EXISTENT`
+ or `RELEASE_NAME_REPLY_NOT_OWNER`
+ :Raises `DBusException`: if the bus daemon cannot be contacted or
+ returns an error.
+ """
+ validate_bus_name(name, allow_unique=False)
+ return self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'ReleaseName',
+ 's', (name,))
+
+ def list_names(self):
+ """Return a list of all currently-owned names on the bus.
+
+ :Returns: a dbus.Array of dbus.UTF8String
+ :Since: 0.81.0
+ """
+ return self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'ListNames',
+ '', ())
+
+ def list_activatable_names(self):
+ """Return a list of all names that can be activated on the bus.
+
+ :Returns: a dbus.Array of dbus.UTF8String
+ :Since: 0.81.0
+ """
+ return self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'ListActivatableNames',
+ '', ())
+
+ def get_name_owner(self, bus_name):
+ """Return the unique connection name of the primary owner of the
+ given name.
+
+ :Raises `DBusException`: if the `bus_name` has no owner
+ :Since: 0.81.0
+ """
+ validate_bus_name(bus_name, allow_unique=False)
+ return self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'GetNameOwner',
+ 's', (bus_name,))
+
+ def watch_name_owner(self, bus_name, callback):
+ """Watch the unique connection name of the primary owner of the
+ given name.
+
+ `callback` will be called with one argument, which is either the
+ unique connection name, or the empty string (meaning the name is
+ not owned).
+
+ :Since: 0.81.0
+ """
+ return NameOwnerWatch(self, bus_name, callback)
+
+ def name_has_owner(self, bus_name):
+ """Return True iff the given bus name has an owner on this bus.
+
+ :Parameters:
+ `bus_name` : str
+ The bus name to look up
+ :Returns: a `bool`
+ """
+ return bool(self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'NameHasOwner',
+ 's', (bus_name,)))
+
+ def add_match_string(self, rule):
+ """Arrange for this application to receive messages on the bus that
+ match the given rule. This version will block.
+
+ :Parameters:
+ `rule` : str
+ The match rule
+ :Raises `DBusException`: on error.
+ :Since: 0.80.0
+ """
+ self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'AddMatch', 's', (rule,))
+
+ # FIXME: add an async success/error handler capability?
+ # (and the same for remove_...)
+ def add_match_string_non_blocking(self, rule):
+ """Arrange for this application to receive messages on the bus that
+ match the given rule. This version will not block, but any errors
+ will be ignored.
+
+ :Parameters:
+ `rule` : str
+ The match rule
+ :Raises `DBusException`: on error.
+ :Since: 0.80.0
+ """
+ self.call_async(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'AddMatch', 's', (rule,),
+ None, None)
+
+ def remove_match_string(self, rule):
+ """Arrange for this application to receive messages on the bus that
+ match the given rule. This version will block.
+
+ :Parameters:
+ `rule` : str
+ The match rule
+ :Raises `DBusException`: on error.
+ :Since: 0.80.0
+ """
+ self.call_blocking(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'RemoveMatch', 's', (rule,))
+
+ def remove_match_string_non_blocking(self, rule):
+ """Arrange for this application to receive messages on the bus that
+ match the given rule. This version will not block, but any errors
+ will be ignored.
+
+ :Parameters:
+ `rule` : str
+ The match rule
+ :Raises `DBusException`: on error.
+ :Since: 0.80.0
+ """
+ self.call_async(BUS_DAEMON_NAME, BUS_DAEMON_PATH,
+ BUS_DAEMON_IFACE, 'RemoveMatch', 's', (rule,),
+ None, None)
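+
+# A brief usage sketch (requires _dbus_bindings and a running session bus; the
+# well-known name below is invented for illustration):
+if __name__ == "__main__":
+    bus = BusConnection(BusConnection.TYPE_SESSION)
+    print(bus.list_names()[:5])
+    print(bus.name_has_owner('com.example.DemoService'))  # False unless owned
+    reply = bus.request_name('com.example.DemoService', NAME_FLAG_DO_NOT_QUEUE)
+    print(reply == REQUEST_NAME_REPLY_PRIMARY_OWNER)      # normally True
+    bus.release_name('com.example.DemoService')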
diff --git a/lib/dbus/connection.py b/lib/dbus/connection.py
new file mode 100644
index 0000000..fd20800
--- /dev/null
+++ b/lib/dbus/connection.py
@@ -0,0 +1,651 @@
+# Copyright (C) 2007 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = ('Connection', 'SignalMatch')
+__docformat__ = 'reStructuredText'
+
+import logging
+import threading
+import weakref
+
+from _dbus_bindings import (
+ Connection as _Connection, LOCAL_IFACE, LOCAL_PATH, validate_bus_name,
+ validate_interface_name, validate_member_name, validate_object_path)
+from dbus.exceptions import DBusException
+from dbus.lowlevel import (
+ ErrorMessage, HANDLER_RESULT_NOT_YET_HANDLED, MethodCallMessage,
+ MethodReturnMessage, SignalMessage)
+from dbus.proxies import ProxyObject
+from dbus._compat import is_py2, is_py3
+
+from _dbus_bindings import String
+
+
+_logger = logging.getLogger('dbus.connection')
+
+
+def _noop(*args, **kwargs):
+ pass
+
+
+class SignalMatch(object):
+ _slots = ['_sender_name_owner', '_member', '_interface', '_sender',
+ '_path', '_handler', '_args_match', '_rule',
+ '_byte_arrays', '_conn_weakref',
+ '_destination_keyword', '_interface_keyword',
+ '_message_keyword', '_member_keyword',
+ '_sender_keyword', '_path_keyword', '_int_args_match']
+
+ __slots__ = tuple(_slots)
+
+ def __init__(self, conn, sender, object_path, dbus_interface,
+ member, handler, byte_arrays=False,
+ sender_keyword=None, path_keyword=None,
+ interface_keyword=None, member_keyword=None,
+ message_keyword=None, destination_keyword=None,
+ **kwargs):
+ if member is not None:
+ validate_member_name(member)
+ if dbus_interface is not None:
+ validate_interface_name(dbus_interface)
+ if sender is not None:
+ validate_bus_name(sender)
+ if object_path is not None:
+ validate_object_path(object_path)
+
+ self._rule = None
+ self._conn_weakref = weakref.ref(conn)
+ self._sender = sender
+ self._interface = dbus_interface
+ self._member = member
+ self._path = object_path
+ self._handler = handler
+
+ # if the connection is actually a bus, it's responsible for changing
+ # this later
+ self._sender_name_owner = sender
+
+ if 'utf8_strings' in kwargs:
+ raise TypeError("unexpected keyword argument 'utf8_strings'")
+
+ self._byte_arrays = byte_arrays
+ self._sender_keyword = sender_keyword
+ self._path_keyword = path_keyword
+ self._member_keyword = member_keyword
+ self._interface_keyword = interface_keyword
+ self._message_keyword = message_keyword
+ self._destination_keyword = destination_keyword
+
+ self._args_match = kwargs
+ if not kwargs:
+ self._int_args_match = None
+ else:
+ self._int_args_match = {}
+ for kwarg in kwargs:
+ if not kwarg.startswith('arg'):
+ raise TypeError('SignalMatch: unknown keyword argument %s'
+ % kwarg)
+ try:
+ index = int(kwarg[3:])
+ except ValueError:
+ raise TypeError('SignalMatch: unknown keyword argument %s'
+ % kwarg)
+ if index < 0 or index > 63:
+ raise TypeError('SignalMatch: arg match index must be in '
+ 'range(64), not %d' % index)
+ self._int_args_match[index] = kwargs[kwarg]
+
+ def __hash__(self):
+ """SignalMatch objects are compared by identity."""
+ return hash(id(self))
+
+ def __eq__(self, other):
+ """SignalMatch objects are compared by identity."""
+ return self is other
+
+ def __ne__(self, other):
+ """SignalMatch objects are compared by identity."""
+ return self is not other
+
+ sender = property(lambda self: self._sender)
+
+ def __str__(self):
+ if self._rule is None:
+ rule = ["type='signal'"]
+ if self._sender is not None:
+ rule.append("sender='%s'" % self._sender)
+ if self._path is not None:
+ rule.append("path='%s'" % self._path)
+ if self._interface is not None:
+ rule.append("interface='%s'" % self._interface)
+ if self._member is not None:
+ rule.append("member='%s'" % self._member)
+ if self._int_args_match is not None:
+ for index, value in self._int_args_match.items():
+ rule.append("arg%d='%s'" % (index, value))
+
+ self._rule = ','.join(rule)
+
+ return self._rule
+
+ def __repr__(self):
+ return ('<%s at %x "%s" on conn %r>'
+ % (self.__class__, id(self), self._rule, self._conn_weakref()))
+
+ def set_sender_name_owner(self, new_name):
+ self._sender_name_owner = new_name
+
+ def matches_removal_spec(self, sender, object_path,
+ dbus_interface, member, handler, **kwargs):
+ if handler not in (None, self._handler):
+ return False
+ if sender != self._sender:
+ return False
+ if object_path != self._path:
+ return False
+ if dbus_interface != self._interface:
+ return False
+ if member != self._member:
+ return False
+ if kwargs != self._args_match:
+ return False
+ return True
+
+ def maybe_handle_message(self, message):
+ args = None
+
+ # these haven't been checked yet by the match tree
+ if self._sender_name_owner not in (None, message.get_sender()):
+ return False
+ if self._int_args_match is not None:
+ # extracting args with byte_arrays is less work
+ kwargs = dict(byte_arrays=True)
+ args = message.get_args_list(**kwargs)
+ for index, value in self._int_args_match.items():
+ if (index >= len(args)
+ or not isinstance(args[index], String)
+ or args[index] != value):
+ return False
+
+ # these have likely already been checked by the match tree
+ if self._member not in (None, message.get_member()):
+ return False
+ if self._interface not in (None, message.get_interface()):
+ return False
+ if self._path not in (None, message.get_path()):
+ return False
+
+ try:
+ # minor optimization: if we already extracted the args with the
+ # right calling convention to do the args match, don't bother
+ # doing so again
+ if args is None or not self._byte_arrays:
+ args = message.get_args_list(byte_arrays=self._byte_arrays)
+ kwargs = {}
+ if self._sender_keyword is not None:
+ kwargs[self._sender_keyword] = message.get_sender()
+ if self._destination_keyword is not None:
+ kwargs[self._destination_keyword] = message.get_destination()
+ if self._path_keyword is not None:
+ kwargs[self._path_keyword] = message.get_path()
+ if self._member_keyword is not None:
+ kwargs[self._member_keyword] = message.get_member()
+ if self._interface_keyword is not None:
+ kwargs[self._interface_keyword] = message.get_interface()
+ if self._message_keyword is not None:
+ kwargs[self._message_keyword] = message
+ self._handler(*args, **kwargs)
+ except:
+ # basicConfig is a no-op if logging is already configured
+ logging.basicConfig()
+ _logger.error('Exception in handler for D-Bus signal:', exc_info=1)
+
+ return True
+
+ def remove(self):
+ conn = self._conn_weakref()
+ # do nothing if the connection has already vanished
+ if conn is not None:
+ conn.remove_signal_receiver(self, self._member,
+ self._interface, self._sender,
+ self._path,
+ **self._args_match)
+
+
+class Connection(_Connection):
+ """A connection to another application. In this base class there is
+ assumed to be no bus daemon.
+
+ :Since: 0.81.0
+ """
+
+ ProxyObjectClass = ProxyObject
+
+ def __init__(self, *args, **kwargs):
+ super(Connection, self).__init__(*args, **kwargs)
+
+ # this if-block is needed because shared bus connections can be
+ # __init__'ed more than once
+ if not hasattr(self, '_dbus_Connection_initialized'):
+ self._dbus_Connection_initialized = 1
+
+ self.__call_on_disconnection = []
+
+ self._signal_recipients_by_object_path = {}
+ """Map from object path to dict mapping dbus_interface to dict
+ mapping member to list of SignalMatch objects."""
+
+ self._signals_lock = threading.Lock()
+ """Lock used to protect signal data structures"""
+
+ self.add_message_filter(self.__class__._signal_func)
+
+ def activate_name_owner(self, bus_name):
+ """Return the unique name for the given bus name, activating it
+ if necessary and possible.
+
+ If the name is already unique or this connection is not to a
+ bus daemon, just return it.
+
+ :Returns: a bus name. If the given `bus_name` exists, the returned
+ name identifies its current owner; otherwise the returned name
+ does not exist.
+ :Raises DBusException: if the implementation has failed
+ to activate the given bus name.
+ :Since: 0.81.0
+ """
+ return bus_name
+
+ def get_object(self, bus_name=None, object_path=None, introspect=True,
+ **kwargs):
+ """Return a local proxy for the given remote object.
+
+ Method calls on the proxy are translated into method calls on the
+ remote object.
+
+ :Parameters:
+ `bus_name` : str
+ A bus name (either the unique name or a well-known name)
+ of the application owning the object. The keyword argument
+ named_service is a deprecated alias for this.
+ `object_path` : str
+ The object path of the desired object
+ `introspect` : bool
+ If true (default), attempt to introspect the remote
+ object to find out supported methods and their signatures
+
+ :Returns: a `dbus.proxies.ProxyObject`
+ """
+ named_service = kwargs.pop('named_service', None)
+ if named_service is not None:
+ if bus_name is not None:
+ raise TypeError('bus_name and named_service cannot both '
+ 'be specified')
+ from warnings import warn
+ warn('Passing the named_service parameter to get_object by name '
+ 'is deprecated: please use positional parameters',
+ DeprecationWarning, stacklevel=2)
+ bus_name = named_service
+ if kwargs:
+ raise TypeError('get_object does not take these keyword '
+ 'arguments: %s' % ', '.join(kwargs.keys()))
+
+ return self.ProxyObjectClass(self, bus_name, object_path,
+ introspect=introspect)
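+
+    # Illustrative sketch of typical use (assuming ``bus`` is a connected
+    # Bus; the well-known name and object path below are placeholders):
+    #
+    #     proxy = bus.get_object('org.freedesktop.Notifications',
+    #                            '/org/freedesktop/Notifications')
+    #     # method calls on ``proxy`` are forwarded to the remote object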
+
+ def add_signal_receiver(self, handler_function,
+ signal_name=None,
+ dbus_interface=None,
+ bus_name=None,
+ path=None,
+ **keywords):
+ """Arrange for the given function to be called when a signal matching
+ the parameters is received.
+
+ :Parameters:
+ `handler_function` : callable
+ The function to be called. Its positional arguments will
+ be the arguments of the signal. By default it will receive
+ no keyword arguments, but see the description of
+ the optional keyword arguments below.
+ `signal_name` : str
+ The signal name; None (the default) matches all names
+ `dbus_interface` : str
+ The D-Bus interface name with which to qualify the signal;
+ None (the default) matches all interface names
+ `bus_name` : str
+ A bus name for the sender, which will be resolved to a
+ unique name if it is not already; None (the default) matches
+ any sender.
+ `path` : str
+ The object path of the object which must have emitted the
+ signal; None (the default) matches any object path
+ :Keywords:
+            `utf8_strings` : bool
+                No longer supported: string arguments are always passed to
+                the handler as dbus.String objects (a subclass of str);
+                passing this keyword raises TypeError.
+ `byte_arrays` : bool
+ If True, the handler function will receive any byte-array
+ arguments as dbus.ByteArray objects (a subclass of str).
+ If False (default) it will receive any byte-array
+ arguments as a dbus.Array of dbus.Byte (subclasses of:
+ a list of ints).
+ `sender_keyword` : str
+ If not None (the default), the handler function will receive
+ the unique name of the sending endpoint as a keyword
+ argument with this name.
+ `destination_keyword` : str
+ If not None (the default), the handler function will receive
+ the bus name of the destination (or None if the signal is a
+ broadcast, as is usual) as a keyword argument with this name.
+ `interface_keyword` : str
+ If not None (the default), the handler function will receive
+ the signal interface as a keyword argument with this name.
+ `member_keyword` : str
+ If not None (the default), the handler function will receive
+ the signal name as a keyword argument with this name.
+ `path_keyword` : str
+ If not None (the default), the handler function will receive
+ the object-path of the sending object as a keyword argument
+ with this name.
+ `message_keyword` : str
+ If not None (the default), the handler function will receive
+ the `dbus.lowlevel.SignalMessage` as a keyword argument with
+ this name.
+ `arg...` : unicode or UTF-8 str
+ If there are additional keyword parameters of the form
+ ``arg``\\ *n*, match only signals where the *n*\\ th argument
+ is the value given for that keyword parameter. As of this
+ time only string arguments can be matched (in particular,
+ object paths and signatures can't).
+ `named_service` : str
+ A deprecated alias for `bus_name`.
+ """
+ self._require_main_loop()
+
+ named_service = keywords.pop('named_service', None)
+ if named_service is not None:
+ if bus_name is not None:
+ raise TypeError('bus_name and named_service cannot both be '
+ 'specified')
+ bus_name = named_service
+ from warnings import warn
+ warn('Passing the named_service parameter to add_signal_receiver '
+ 'by name is deprecated: please use positional parameters',
+ DeprecationWarning, stacklevel=2)
+
+ match = SignalMatch(self, bus_name, path, dbus_interface,
+ signal_name, handler_function, **keywords)
+
+ self._signals_lock.acquire()
+ try:
+ by_interface = self._signal_recipients_by_object_path.setdefault(
+ path, {})
+ by_member = by_interface.setdefault(dbus_interface, {})
+ matches = by_member.setdefault(signal_name, [])
+
+ matches.append(match)
+ finally:
+ self._signals_lock.release()
+
+ return match
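+
+    # Illustrative sketch: receiving NameOwnerChanged from the bus daemon
+    # (assuming ``bus`` is a connected Bus; the handler name is arbitrary):
+    #
+    #     def on_name_owner_changed(name, old_owner, new_owner, sender=None):
+    #         print('%s: %r -> %r (from %s)' % (name, old_owner, new_owner,
+    #                                           sender))
+    #
+    #     match = bus.add_signal_receiver(on_name_owner_changed,
+    #                                     signal_name='NameOwnerChanged',
+    #                                     dbus_interface='org.freedesktop.DBus',
+    #                                     sender_keyword='sender')
+    #     # later, to stop receiving the signal:
+    #     match.remove()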
+
+ def _iter_easy_matches(self, path, dbus_interface, member):
+ if path is not None:
+ path_keys = (None, path)
+ else:
+ path_keys = (None,)
+ if dbus_interface is not None:
+ interface_keys = (None, dbus_interface)
+ else:
+ interface_keys = (None,)
+ if member is not None:
+ member_keys = (None, member)
+ else:
+ member_keys = (None,)
+
+ for path in path_keys:
+ by_interface = self._signal_recipients_by_object_path.get(path)
+ if by_interface is None:
+ continue
+ for dbus_interface in interface_keys:
+ by_member = by_interface.get(dbus_interface, None)
+ if by_member is None:
+ continue
+ for member in member_keys:
+ matches = by_member.get(member, None)
+ if matches is None:
+ continue
+ for m in matches:
+ yield m
+
+ def remove_signal_receiver(self, handler_or_match,
+ signal_name=None,
+ dbus_interface=None,
+ bus_name=None,
+ path=None,
+ **keywords):
+ named_service = keywords.pop('named_service', None)
+ if named_service is not None:
+ if bus_name is not None:
+ raise TypeError('bus_name and named_service cannot both be '
+ 'specified')
+ bus_name = named_service
+ from warnings import warn
+ warn('Passing the named_service parameter to '
+ 'remove_signal_receiver by name is deprecated: please use '
+ 'positional parameters',
+ DeprecationWarning, stacklevel=2)
+
+ new = []
+ deletions = []
+ self._signals_lock.acquire()
+ try:
+ by_interface = self._signal_recipients_by_object_path.get(path,
+ None)
+ if by_interface is None:
+ return
+ by_member = by_interface.get(dbus_interface, None)
+ if by_member is None:
+ return
+ matches = by_member.get(signal_name, None)
+ if matches is None:
+ return
+
+ for match in matches:
+ if (handler_or_match is match
+ or match.matches_removal_spec(bus_name,
+ path,
+ dbus_interface,
+ signal_name,
+ handler_or_match,
+ **keywords)):
+ deletions.append(match)
+ else:
+ new.append(match)
+
+ if new:
+ by_member[signal_name] = new
+ else:
+ del by_member[signal_name]
+ if not by_member:
+ del by_interface[dbus_interface]
+ if not by_interface:
+ del self._signal_recipients_by_object_path[path]
+ finally:
+ self._signals_lock.release()
+
+ for match in deletions:
+ self._clean_up_signal_match(match)
+
+ def _clean_up_signal_match(self, match):
+ # Now called without the signals lock held (it was held in <= 0.81.0)
+ pass
+
+ def _signal_func(self, message):
+ """D-Bus filter function. Handle signals by dispatching to Python
+ callbacks kept in the match-rule tree.
+ """
+
+ if not isinstance(message, SignalMessage):
+ return HANDLER_RESULT_NOT_YET_HANDLED
+
+ dbus_interface = message.get_interface()
+ path = message.get_path()
+ signal_name = message.get_member()
+
+ for match in self._iter_easy_matches(path, dbus_interface,
+ signal_name):
+ match.maybe_handle_message(message)
+
+ if (dbus_interface == LOCAL_IFACE and
+ path == LOCAL_PATH and
+ signal_name == 'Disconnected'):
+ for cb in self.__call_on_disconnection:
+ try:
+ cb(self)
+ except Exception:
+ # basicConfig is a no-op if logging is already configured
+ logging.basicConfig()
+ _logger.error('Exception in handler for Disconnected '
+ 'signal:', exc_info=1)
+
+ return HANDLER_RESULT_NOT_YET_HANDLED
+
+ def call_async(self, bus_name, object_path, dbus_interface, method,
+ signature, args, reply_handler, error_handler,
+ timeout=-1.0, byte_arrays=False,
+ require_main_loop=True, **kwargs):
+ """Call the given method, asynchronously.
+
+ If the reply_handler is None, successful replies will be ignored.
+ If the error_handler is None, failures will be ignored. If both
+ are None, the implementation may request that no reply is sent.
+
+ :Returns: The dbus.lowlevel.PendingCall.
+ :Since: 0.81.0
+ """
+ if object_path == LOCAL_PATH:
+ raise DBusException('Methods may not be called on the reserved '
+ 'path %s' % LOCAL_PATH)
+ if dbus_interface == LOCAL_IFACE:
+ raise DBusException('Methods may not be called on the reserved '
+ 'interface %s' % LOCAL_IFACE)
+ # no need to validate other args - MethodCallMessage ctor will do
+
+ get_args_opts = dict(byte_arrays=byte_arrays)
+ if 'utf8_strings' in kwargs:
+ raise TypeError("unexpected keyword argument 'utf8_strings'")
+
+ message = MethodCallMessage(destination=bus_name,
+ path=object_path,
+ interface=dbus_interface,
+ method=method)
+ # Add the arguments to the function
+ try:
+ message.append(signature=signature, *args)
+ except Exception as e:
+ logging.basicConfig()
+ _logger.error('Unable to set arguments %r according to '
+ 'signature %r: %s: %s',
+ args, signature, e.__class__, e)
+ raise
+
+ if reply_handler is None and error_handler is None:
+ # we don't care what happens, so just send it
+ self.send_message(message)
+ return
+
+ if reply_handler is None:
+ reply_handler = _noop
+ if error_handler is None:
+ error_handler = _noop
+
+ def msg_reply_handler(message):
+ if isinstance(message, MethodReturnMessage):
+ reply_handler(*message.get_args_list(**get_args_opts))
+ elif isinstance(message, ErrorMessage):
+ error_handler(DBusException(name=message.get_error_name(),
+ *message.get_args_list()))
+ else:
+ error_handler(TypeError('Unexpected type for reply '
+ 'message: %r' % message))
+ return self.send_message_with_reply(message, msg_reply_handler,
+ timeout,
+ require_main_loop=require_main_loop)
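+
+    # Illustrative sketch: an asynchronous call to the bus daemon's GetId
+    # method (assumes a main loop has been set up; handler names are
+    # arbitrary):
+    #
+    #     def on_reply(bus_id):
+    #         print('bus id is %s' % bus_id)
+    #
+    #     def on_error(exc):
+    #         print('GetId failed: %s' % exc)
+    #
+    #     conn.call_async('org.freedesktop.DBus', '/org/freedesktop/DBus',
+    #                     'org.freedesktop.DBus', 'GetId', '', (),
+    #                     on_reply, on_error)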
+
+ def call_blocking(self, bus_name, object_path, dbus_interface, method,
+ signature, args, timeout=-1.0,
+ byte_arrays=False, **kwargs):
+ """Call the given method, synchronously.
+ :Since: 0.81.0
+ """
+ if object_path == LOCAL_PATH:
+ raise DBusException('Methods may not be called on the reserved '
+ 'path %s' % LOCAL_PATH)
+ if dbus_interface == LOCAL_IFACE:
+ raise DBusException('Methods may not be called on the reserved '
+ 'interface %s' % LOCAL_IFACE)
+ # no need to validate other args - MethodCallMessage ctor will do
+
+ get_args_opts = dict(byte_arrays=byte_arrays)
+ if 'utf8_strings' in kwargs:
+ raise TypeError("unexpected keyword argument 'utf8_strings'")
+
+ message = MethodCallMessage(destination=bus_name,
+ path=object_path,
+ interface=dbus_interface,
+ method=method)
+ # Add the arguments to the function
+ try:
+ message.append(signature=signature, *args)
+ except Exception as e:
+ logging.basicConfig()
+ _logger.error('Unable to set arguments %r according to '
+ 'signature %r: %s: %s',
+ args, signature, e.__class__, e)
+ raise
+
+ # make a blocking call
+ reply_message = self.send_message_with_reply_and_block(
+ message, timeout)
+ args_list = reply_message.get_args_list(**get_args_opts)
+ if len(args_list) == 0:
+ return None
+ elif len(args_list) == 1:
+ return args_list[0]
+ else:
+ return tuple(args_list)
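+
+    # Illustrative sketch: the equivalent blocking call, returning the
+    # single reply argument directly:
+    #
+    #     bus_id = conn.call_blocking('org.freedesktop.DBus',
+    #                                 '/org/freedesktop/DBus',
+    #                                 'org.freedesktop.DBus', 'GetId',
+    #                                 '', ())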
+
+ def call_on_disconnection(self, callable):
+ """Arrange for `callable` to be called with one argument (this
+ Connection object) when the Connection becomes
+ disconnected.
+
+ :Since: 0.83.0
+ """
+ self.__call_on_disconnection.append(callable)
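+
+    # Illustrative sketch (the callback name is arbitrary):
+    #
+    #     def on_disconnect(connection):
+    #         print('connection %r was disconnected' % connection)
+    #
+    #     conn.call_on_disconnection(on_disconnect)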
diff --git a/lib/dbus/decorators.py b/lib/dbus/decorators.py
new file mode 100644
index 0000000..3dc04a6
--- /dev/null
+++ b/lib/dbus/decorators.py
@@ -0,0 +1,362 @@
+"""Service-side D-Bus decorators."""
+
+# Copyright (C) 2003, 2004, 2005, 2006 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2003 David Zeuthen
+# Copyright (C) 2004 Rob Taylor
+# Copyright (C) 2005, 2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = ('method', 'signal')
+__docformat__ = 'restructuredtext'
+
+import inspect
+
+from dbus import validate_interface_name, Signature, validate_member_name
+from dbus.lowlevel import SignalMessage
+from dbus.exceptions import DBusException
+from dbus._compat import is_py2
+
+
+def method(dbus_interface, in_signature=None, out_signature=None,
+ async_callbacks=None,
+ sender_keyword=None, path_keyword=None, destination_keyword=None,
+ message_keyword=None, connection_keyword=None,
+ byte_arrays=False,
+ rel_path_keyword=None, **kwargs):
+ """Factory for decorators used to mark methods of a `dbus.service.Object`
+ to be exported on the D-Bus.
+
+ The decorated method will be exported over D-Bus as the method of the
+ same name on the given D-Bus interface.
+
+ :Parameters:
+ `dbus_interface` : str
+ Name of a D-Bus interface
+ `in_signature` : str or None
+ If not None, the signature of the method parameters in the usual
+ D-Bus notation
+ `out_signature` : str or None
+ If not None, the signature of the return value in the usual
+ D-Bus notation
+ `async_callbacks` : tuple containing (str,str), or None
+ If None (default) the decorated method is expected to return
+ values matching the `out_signature` as usual, or raise
+ an exception on error. If not None, the following applies:
+
+ `async_callbacks` contains the names of two keyword arguments to
+ the decorated function, which will be used to provide a success
+ callback and an error callback (in that order).
+
+ When the decorated method is called via the D-Bus, its normal
+ return value will be ignored; instead, a pair of callbacks are
+ passed as keyword arguments, and the decorated method is
+ expected to arrange for one of them to be called.
+
+ On success the success callback must be called, passing the
+ results of this method as positional parameters in the format
+ given by the `out_signature`.
+
+ On error the decorated method may either raise an exception
+ before it returns, or arrange for the error callback to be
+ called with an Exception instance as parameter.
+
+ `sender_keyword` : str or None
+ If not None, contains the name of a keyword argument to the
+ decorated function, conventionally ``'sender'``. When the
+ method is called, the sender's unique name will be passed as
+ this keyword argument.
+
+ `path_keyword` : str or None
+ If not None (the default), the decorated method will receive
+ the destination object path as a keyword argument with this
+ name. Normally you already know the object path, but in the
+ case of "fallback paths" you'll usually want to use the object
+ path in the method's implementation.
+
+ For fallback objects, `rel_path_keyword` (new in 0.82.2) is
+ likely to be more useful.
+
+ :Since: 0.80.0?
+
+ `rel_path_keyword` : str or None
+ If not None (the default), the decorated method will receive
+ the destination object path, relative to the path at which the
+ object was exported, as a keyword argument with this
+ name. For non-fallback objects the relative path will always be
+ '/'.
+
+ :Since: 0.82.2
+
+ `destination_keyword` : str or None
+ If not None (the default), the decorated method will receive
+ the destination bus name as a keyword argument with this name.
+ Included for completeness - you shouldn't need this.
+
+ :Since: 0.80.0?
+
+ `message_keyword` : str or None
+ If not None (the default), the decorated method will receive
+ the `dbus.lowlevel.MethodCallMessage` as a keyword argument
+ with this name.
+
+ :Since: 0.80.0?
+
+ `connection_keyword` : str or None
+ If not None (the default), the decorated method will receive
+ the `dbus.connection.Connection` as a keyword argument
+ with this name. This is generally only useful for objects
+ that are available on more than one connection.
+
+ :Since: 0.82.0
+
+        `utf8_strings` : bool
+            No longer supported: D-Bus strings are always passed to the
+            decorated method as dbus.String objects (a subclass of str);
+            passing this keyword raises TypeError.
+
+ `byte_arrays` : bool
+ If False (default), a byte array will be passed to the decorated
+ method as an `Array` (a list subclass) of `Byte` objects.
+
+ If True, a byte array will be passed to the decorated method as
+ a `ByteArray`, a str subclass. This is usually what you want,
+ but is switched off by default to keep dbus-python's API
+ consistent.
+
+ :Since: 0.80.0
+ """
+ validate_interface_name(dbus_interface)
+
+ def decorator(func):
+ if hasattr(inspect, 'Signature'):
+ args = []
+
+ for arg in inspect.signature(func).parameters.values():
+ if arg.kind in (inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD):
+ args.append(arg.name)
+ else:
+ args = inspect.getargspec(func)[0]
+
+ args.pop(0)
+
+ if async_callbacks:
+ if type(async_callbacks) != tuple:
+ raise TypeError('async_callbacks must be a tuple of (keyword for return callback, keyword for error callback)')
+ if len(async_callbacks) != 2:
+ raise ValueError('async_callbacks must be a tuple of (keyword for return callback, keyword for error callback)')
+ args.remove(async_callbacks[0])
+ args.remove(async_callbacks[1])
+
+ if sender_keyword:
+ args.remove(sender_keyword)
+ if rel_path_keyword:
+ args.remove(rel_path_keyword)
+ if path_keyword:
+ args.remove(path_keyword)
+ if destination_keyword:
+ args.remove(destination_keyword)
+ if message_keyword:
+ args.remove(message_keyword)
+ if connection_keyword:
+ args.remove(connection_keyword)
+
+ if in_signature:
+ in_sig = tuple(Signature(in_signature))
+
+ if len(in_sig) > len(args):
+ raise ValueError('input signature is longer than the number of arguments taken')
+ elif len(in_sig) < len(args):
+ raise ValueError('input signature is shorter than the number of arguments taken')
+
+ func._dbus_is_method = True
+ func._dbus_async_callbacks = async_callbacks
+ func._dbus_interface = dbus_interface
+ func._dbus_in_signature = in_signature
+ func._dbus_out_signature = out_signature
+ func._dbus_sender_keyword = sender_keyword
+ func._dbus_path_keyword = path_keyword
+ func._dbus_rel_path_keyword = rel_path_keyword
+ func._dbus_destination_keyword = destination_keyword
+ func._dbus_message_keyword = message_keyword
+ func._dbus_connection_keyword = connection_keyword
+ func._dbus_args = args
+ func._dbus_get_args_options = dict(byte_arrays=byte_arrays)
+ if 'utf8_strings' in kwargs:
+ raise TypeError("unexpected keyword argument 'utf8_strings'")
+ return func
+
+ return decorator
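+
+# Illustrative sketch of exporting methods with this decorator, used via its
+# dbus.service.method alias (the class and the interface name
+# 'com.example.Sample' are placeholders):
+#
+#     class Sample(dbus.service.Object):
+#         @dbus.service.method('com.example.Sample',
+#                              in_signature='s', out_signature='s')
+#         def Echo(self, text):
+#             return text
+#
+#         @dbus.service.method('com.example.Sample',
+#                              in_signature='s', out_signature='s',
+#                              async_callbacks=('reply_cb', 'error_cb'))
+#         def EchoLater(self, text, reply_cb, error_cb):
+#             # the normal return value is ignored; one of the two
+#             # callbacks must eventually be invoked
+#             reply_cb(text)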
+
+
+def signal(dbus_interface, signature=None, path_keyword=None,
+ rel_path_keyword=None):
+ """Factory for decorators used to mark methods of a `dbus.service.Object`
+ to emit signals on the D-Bus.
+
+ Whenever the decorated method is called in Python, after the method
+ body is executed, a signal with the same name as the decorated method,
+ with the given D-Bus interface, will be emitted from this object.
+
+ :Parameters:
+ `dbus_interface` : str
+ The D-Bus interface whose signal is emitted
+ `signature` : str
+ The signature of the signal in the usual D-Bus notation
+
+ `path_keyword` : str or None
+ A keyword argument to the decorated method. If not None,
+ that argument will not be emitted as an argument of
+ the signal, and when the signal is emitted, it will appear
+ to come from the object path given by the keyword argument.
+
+ Note that when calling the decorated method, you must always
+ pass in the object path as a keyword argument, not as a
+ positional argument.
+
+ This keyword argument cannot be used on objects where
+ the class attribute ``SUPPORTS_MULTIPLE_OBJECT_PATHS`` is true.
+
+ :Deprecated: since 0.82.0. Use `rel_path_keyword` instead.
+
+ `rel_path_keyword` : str or None
+ A keyword argument to the decorated method. If not None,
+ that argument will not be emitted as an argument of
+ the signal.
+
+ When the signal is emitted, if the named keyword argument is given,
+ the signal will appear to come from the object path obtained by
+ appending the keyword argument to the object's object path.
+ This is useful to implement "fallback objects" (objects which
+ own an entire subtree of the object-path tree).
+
+ If the object is available at more than one object-path on the
+ same or different connections, the signal will be emitted at
+ an appropriate object-path on each connection - for instance,
+ if the object is exported at /abc on connection 1 and at
+ /def and /x/y/z on connection 2, and the keyword argument is
+ /foo, then signals will be emitted from /abc/foo and /def/foo
+ on connection 1, and /x/y/z/foo on connection 2.
+
+ :Since: 0.82.0
+ """
+ validate_interface_name(dbus_interface)
+
+ if path_keyword is not None:
+ from warnings import warn
+ warn(DeprecationWarning('dbus.service.signal::path_keyword has been '
+ 'deprecated since dbus-python 0.82.0, and '
+ 'will not work on objects that support '
+ 'multiple object paths'),
+ DeprecationWarning, stacklevel=2)
+ if rel_path_keyword is not None:
+ raise TypeError('dbus.service.signal::path_keyword and '
+ 'rel_path_keyword cannot both be used')
+
+ def decorator(func):
+ member_name = func.__name__
+ validate_member_name(member_name)
+
+ def emit_signal(self, *args, **keywords):
+ abs_path = None
+ if path_keyword is not None:
+ if self.SUPPORTS_MULTIPLE_OBJECT_PATHS:
+ raise TypeError('path_keyword cannot be used on the '
+ 'signals of an object that supports '
+ 'multiple object paths')
+ abs_path = keywords.pop(path_keyword, None)
+ if (abs_path != self.__dbus_object_path__ and
+ not self.__dbus_object_path__.startswith(abs_path + '/')):
+                    raise ValueError('Path %r is not below %r'
+                                     % (abs_path, self.__dbus_object_path__))
+
+ rel_path = None
+ if rel_path_keyword is not None:
+ rel_path = keywords.pop(rel_path_keyword, None)
+
+ func(self, *args, **keywords)
+
+ for location in self.locations:
+ if abs_path is None:
+ # non-deprecated case
+ if rel_path is None or rel_path in ('/', ''):
+ object_path = location[1]
+ else:
+ # will be validated by SignalMessage ctor in a moment
+ object_path = location[1] + rel_path
+ else:
+ object_path = abs_path
+
+ message = SignalMessage(object_path,
+ dbus_interface,
+ member_name)
+ message.append(signature=signature, *args)
+
+ location[0].send_message(message)
+ # end emit_signal
+
+ if hasattr(inspect, 'Signature'):
+ args = []
+
+ for arg in inspect.signature(func).parameters.values():
+ if arg.kind in (inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD):
+ args.append(arg.name)
+ else:
+ args = inspect.getargspec(func)[0]
+
+ args.pop(0)
+
+ for keyword in rel_path_keyword, path_keyword:
+ if keyword is not None:
+ try:
+ args.remove(keyword)
+ except ValueError:
+ raise ValueError('function has no argument "%s"' % keyword)
+
+ if signature:
+ sig = tuple(Signature(signature))
+
+ if len(sig) > len(args):
+ raise ValueError('signal signature is longer than the number of arguments provided')
+ elif len(sig) < len(args):
+ raise ValueError('signal signature is shorter than the number of arguments provided')
+
+ emit_signal.__name__ = func.__name__
+ emit_signal.__doc__ = func.__doc__
+ emit_signal._dbus_is_signal = True
+ emit_signal._dbus_interface = dbus_interface
+ emit_signal._dbus_signature = signature
+ emit_signal._dbus_args = args
+ return emit_signal
+
+ return decorator
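+
+# Illustrative sketch of emitting a signal with this decorator, used via its
+# dbus.service.signal alias (the class and interface name are placeholders):
+#
+#     class Pinger(dbus.service.Object):
+#         @dbus.service.signal('com.example.Pinger', signature='s')
+#         def Ping(self, message):
+#             # the decorated body runs first; the signal is emitted
+#             # afterwards with ``message`` as its argument
+#             pass
+#
+#     # calling pinger.Ping('hello') emits com.example.Pinger.Ping('hello')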
diff --git a/lib/dbus/exceptions.py b/lib/dbus/exceptions.py
new file mode 100644
index 0000000..870b731
--- /dev/null
+++ b/lib/dbus/exceptions.py
@@ -0,0 +1,133 @@
+"""D-Bus exceptions."""
+
+# Copyright (C) 2007 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = ('DBusException', 'MissingErrorHandlerException',
+ 'MissingReplyHandlerException', 'ValidationException',
+ 'IntrospectionParserException', 'UnknownMethodException',
+ 'NameExistsException')
+
+from dbus._compat import is_py3
+
+
+class DBusException(Exception):
+
+ include_traceback = False
+ """If True, tracebacks will be included in the exception message sent to
+ D-Bus clients.
+
+ Exceptions that are not DBusException subclasses always behave
+ as though this is True. Set this to True on DBusException subclasses
+ that represent a programming error, and leave it False on subclasses that
+ represent an expected failure condition (e.g. a network server not
+ responding)."""
+
+ def __init__(self, *args, **kwargs):
+ name = kwargs.pop('name', None)
+ if name is not None or getattr(self, '_dbus_error_name', None) is None:
+ self._dbus_error_name = name
+ if kwargs:
+ raise TypeError('DBusException does not take keyword arguments: %s'
+ % ', '.join(kwargs.keys()))
+ Exception.__init__(self, *args)
+
+ def __unicode__(self):
+ """Return a unicode error"""
+ # We can't just use Exception.__unicode__ because it chains up weirdly.
+ # https://code.launchpad.net/~mvo/ubuntu/quantal/dbus-python/lp846044/+merge/129214
+ if len(self.args) > 1:
+ s = unicode(self.args)
+ else:
+ s = ''.join(self.args)
+
+ if self._dbus_error_name is not None:
+ return '%s: %s' % (self._dbus_error_name, s)
+ else:
+ return s
+
+ def __str__(self):
+ """Return a str error"""
+ s = Exception.__str__(self)
+ if self._dbus_error_name is not None:
+ return '%s: %s' % (self._dbus_error_name, s)
+ else:
+ return s
+
+ def get_dbus_message(self):
+ if len(self.args) > 1:
+ s = str(self.args)
+ else:
+ s = ''.join(self.args)
+
+ if isinstance(s, bytes):
+ return s.decode('utf-8', 'replace')
+
+ return s
+
+ def get_dbus_name(self):
+ return self._dbus_error_name
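+
+    # Illustrative sketch of a service-specific error (the error name is a
+    # placeholder):
+    #
+    #     class NotFoundError(dbus.DBusException):
+    #         _dbus_error_name = 'com.example.Error.NotFound'
+    #
+    # Raising NotFoundError('no such item') from an exported method sends
+    # the caller a D-Bus error named com.example.Error.NotFound.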
+
+class MissingErrorHandlerException(DBusException):
+
+ include_traceback = True
+
+ def __init__(self):
+ DBusException.__init__(self, "error_handler not defined: if you define a reply_handler you must also define an error_handler")
+
+class MissingReplyHandlerException(DBusException):
+
+ include_traceback = True
+
+ def __init__(self):
+ DBusException.__init__(self, "reply_handler not defined: if you define an error_handler you must also define a reply_handler")
+
+class ValidationException(DBusException):
+
+ include_traceback = True
+
+ def __init__(self, msg=''):
+ DBusException.__init__(self, "Error validating string: %s"%msg)
+
+class IntrospectionParserException(DBusException):
+
+ include_traceback = True
+
+ def __init__(self, msg=''):
+ DBusException.__init__(self, "Error parsing introspect data: %s"%msg)
+
+class UnknownMethodException(DBusException):
+
+ include_traceback = True
+ _dbus_error_name = 'org.freedesktop.DBus.Error.UnknownMethod'
+
+ def __init__(self, method):
+ DBusException.__init__(self, "Unknown method: %s"%method)
+
+class NameExistsException(DBusException):
+
+ include_traceback = True
+
+ def __init__(self, name):
+ DBusException.__init__(self, "Bus name already exists: %s"%name)
diff --git a/lib/dbus/gi_service.py b/lib/dbus/gi_service.py
new file mode 100644
index 0000000..f68b088
--- /dev/null
+++ b/lib/dbus/gi_service.py
@@ -0,0 +1,87 @@
+"""Support code for implementing D-Bus services via PyGI."""
+
+# Copyright (C) 2007 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = ['ExportedGObject']
+
+from gi.repository import GObject
+import dbus.service
+
+# The odd syntax used here is required so that the code is compatible with
+# both Python 2 and Python 3. It essentially creates a new class called
+# ExportedGObject with a metaclass of ExportedGObjectType and an __init__()
+# function.
+#
+# Because GObject and `dbus.service.Object` both have custom metaclasses, the
+# naive approach using simple multiple inheritance won't work. This class has
+# `ExportedGObjectType` as its metaclass, which is sufficient to make it work
+# correctly.
+
+class ExportedGObjectType(GObject.GObject.__class__, dbus.service.InterfaceType):
+ """A metaclass which inherits from both GObjectMeta and
+ `dbus.service.InterfaceType`. Used as the metaclass for `ExportedGObject`.
+ """
+ def __init__(cls, name, bases, dct):
+ GObject.GObject.__class__.__init__(cls, name, bases, dct)
+ dbus.service.InterfaceType.__init__(cls, name, bases, dct)
+
+
+def ExportedGObject__init__(self, conn=None, object_path=None, **kwargs):
+ """Initialize an exported GObject.
+
+ :Parameters:
+ `conn` : dbus.connection.Connection
+ The D-Bus connection or bus
+ `object_path` : str
+ The object path at which to register this object.
+ :Keywords:
+ `bus_name` : dbus.service.BusName
+ A bus name to be held on behalf of this object, or None.
+ `gobject_properties` : dict
+ GObject properties to be set on the constructed object.
+
+ Any unrecognised keyword arguments will also be interpreted
+ as GObject properties.
+ """
+ bus_name = kwargs.pop('bus_name', None)
+ gobject_properties = kwargs.pop('gobject_properties', None)
+
+ if gobject_properties is not None:
+ kwargs.update(gobject_properties)
+ GObject.GObject.__init__(self, **kwargs)
+ dbus.service.Object.__init__(self, conn=conn,
+ object_path=object_path,
+ bus_name=bus_name)
+
+ExportedGObject__doc__ = '''
+A GObject which is exported on D-Bus.
+'''
+
+ExportedGObject = ExportedGObjectType(
+ 'ExportedGObject',
+ (GObject.GObject, dbus.service.Object),
+ {'__init__': ExportedGObject__init__,
+ '__doc__': ExportedGObject__doc__,
+ })
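+
+# Illustrative sketch of use (assuming ``import dbus``; the bus name,
+# object path and interface are placeholders):
+#
+#     class Sample(ExportedGObject):
+#         @dbus.service.method('com.example.Sample', out_signature='s')
+#         def Hello(self):
+#             return 'hello'
+#
+#     bus = dbus.SessionBus()
+#     name = dbus.service.BusName('com.example.Sample', bus)
+#     obj = Sample(bus, '/com/example/Sample', bus_name=name)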
diff --git a/lib/dbus/glib.py b/lib/dbus/glib.py
new file mode 100644
index 0000000..b521fcf
--- /dev/null
+++ b/lib/dbus/glib.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2004 Anders Carlsson
+# Copyright (C) 2004, 2005, 2006 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2005, 2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+"""Deprecated module which sets the default GLib main context as the mainloop
+implementation within D-Bus, as a side-effect of being imported!
+
+This API is highly non-obvious, so instead of importing this module,
+new programs which don't need pre-0.80 compatibility should use this
+equivalent code::
+
+ from dbus.mainloop.glib import DBusGMainLoop
+ DBusGMainLoop(set_as_default=True)
+"""
+__docformat__ = 'restructuredtext'
+
+from dbus.mainloop.glib import DBusGMainLoop, threads_init
+from warnings import warn as _warn
+
+init_threads = threads_init
+
+DBusGMainLoop(set_as_default=True)
+
+_warn(DeprecationWarning("""\
+Importing dbus.glib to use the GLib main loop with dbus-python is deprecated.
+Instead, use this sequence:
+
+ from dbus.mainloop.glib import DBusGMainLoop
+
+ DBusGMainLoop(set_as_default=True)
+"""), DeprecationWarning, stacklevel=2)
diff --git a/lib/dbus/lowlevel.py b/lib/dbus/lowlevel.py
new file mode 100644
index 0000000..59bd8fe
--- /dev/null
+++ b/lib/dbus/lowlevel.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+"""Low-level interface to D-Bus."""
+
+__all__ = ('PendingCall', 'Message', 'MethodCallMessage',
+ 'MethodReturnMessage', 'ErrorMessage', 'SignalMessage',
+ 'HANDLER_RESULT_HANDLED', 'HANDLER_RESULT_NOT_YET_HANDLED',
+ 'MESSAGE_TYPE_INVALID', 'MESSAGE_TYPE_METHOD_CALL',
+ 'MESSAGE_TYPE_METHOD_RETURN', 'MESSAGE_TYPE_ERROR',
+ 'MESSAGE_TYPE_SIGNAL')
+
+from _dbus_bindings import (
+ ErrorMessage, HANDLER_RESULT_HANDLED, HANDLER_RESULT_NOT_YET_HANDLED,
+ MESSAGE_TYPE_ERROR, MESSAGE_TYPE_INVALID, MESSAGE_TYPE_METHOD_CALL,
+ MESSAGE_TYPE_METHOD_RETURN, MESSAGE_TYPE_SIGNAL, Message,
+ MethodCallMessage, MethodReturnMessage, PendingCall, SignalMessage)
diff --git a/lib/dbus/mainloop/__init__.py b/lib/dbus/mainloop/__init__.py
new file mode 100644
index 0000000..b0d20a9
--- /dev/null
+++ b/lib/dbus/mainloop/__init__.py
@@ -0,0 +1,64 @@
+# Copyright (C) 2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+"""Base definitions, etc. for main loop integration.
+
+"""
+
+import _dbus_bindings
+
+NativeMainLoop = _dbus_bindings.NativeMainLoop
+
+NULL_MAIN_LOOP = _dbus_bindings.NULL_MAIN_LOOP
+"""A null mainloop which doesn't actually do anything.
+
+For advanced users who want to dispatch events by hand. This is almost
+certainly a bad idea - if in doubt, use the GLib main loop found in
+`dbus.mainloop.glib`.
+"""
+
+WATCH_READABLE = _dbus_bindings.WATCH_READABLE
+"""Represents a file descriptor becoming readable.
+Used to implement file descriptor watches."""
+
+WATCH_WRITABLE = _dbus_bindings.WATCH_WRITABLE
+"""Represents a file descriptor becoming readable.
+Used to implement file descriptor watches."""
+
+WATCH_HANGUP = _dbus_bindings.WATCH_HANGUP
+"""Represents a file descriptor reaching end-of-file.
+Used to implement file descriptor watches."""
+
+WATCH_ERROR = _dbus_bindings.WATCH_ERROR
+"""Represents an error condition on a file descriptor.
+Used to implement file descriptor watches."""
+
+__all__ = (
+ # Imported into this module
+ 'NativeMainLoop', 'WATCH_READABLE', 'WATCH_WRITABLE',
+ 'WATCH_HANGUP', 'WATCH_ERROR', 'NULL_MAIN_LOOP',
+
+ # Submodules
+ 'glib'
+ )
diff --git a/lib/dbus/mainloop/glib.py b/lib/dbus/mainloop/glib.py
new file mode 100644
index 0000000..5bb2f2e
--- /dev/null
+++ b/lib/dbus/mainloop/glib.py
@@ -0,0 +1,43 @@
+# Copyright (C) 2004 Anders Carlsson
+# Copyright (C) 2004-2006 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2005-2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+"""GLib main loop integration using libdbus-glib."""
+
+__all__ = ('DBusGMainLoop', 'threads_init')
+
+from _dbus_glib_bindings import DBusGMainLoop, gthreads_init
+
+_dbus_gthreads_initialized = False
+def threads_init():
+ """Initialize threads in dbus-glib, if this has not already been done.
+
+ This must be called before creating a second thread in a program that
+ uses this module.
+ """
+ global _dbus_gthreads_initialized
+ if not _dbus_gthreads_initialized:
+ gthreads_init()
+ _dbus_gthreads_initialized = True
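+
+# Illustrative sketch: a program that uses D-Bus from more than one thread
+# might start with
+#
+#     import dbus.mainloop.glib
+#     dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
+#     dbus.mainloop.glib.threads_init()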
diff --git a/lib/dbus/proxies.py b/lib/dbus/proxies.py
new file mode 100644
index 0000000..487976c
--- /dev/null
+++ b/lib/dbus/proxies.py
@@ -0,0 +1,567 @@
+# Copyright (C) 2003-2007 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2003 David Zeuthen
+# Copyright (C) 2004 Rob Taylor
+# Copyright (C) 2005-2007 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import logging
+
+try:
+ from threading import RLock
+except ImportError:
+ from dummy_threading import RLock
+
+import _dbus_bindings
+from dbus._expat_introspect_parser import process_introspection_data
+from dbus.exceptions import (
+ DBusException, IntrospectionParserException, MissingErrorHandlerException,
+ MissingReplyHandlerException)
+
+__docformat__ = 'restructuredtext'
+
+
+_logger = logging.getLogger('dbus.proxies')
+
+from _dbus_bindings import (
+ BUS_DAEMON_IFACE, BUS_DAEMON_NAME, BUS_DAEMON_PATH, INTROSPECTABLE_IFACE,
+ LOCAL_PATH)
+from dbus._compat import is_py2
+
+
+class _DeferredMethod:
+ """A proxy method which will only get called once we have its
+ introspection reply.
+ """
+ def __init__(self, proxy_method, append, block):
+ self._proxy_method = proxy_method
+ # the test suite relies on the existence of this property
+ self._method_name = proxy_method._method_name
+ self._append = append
+ self._block = block
+
+ def __call__(self, *args, **keywords):
+ if ('reply_handler' in keywords or
+ keywords.get('ignore_reply', False)):
+ # defer the async call til introspection finishes
+ self._append(self._proxy_method, args, keywords)
+ return None
+ else:
+ # we're being synchronous, so block
+ self._block()
+ return self._proxy_method(*args, **keywords)
+
+ def call_async(self, *args, **keywords):
+ self._append(self._proxy_method, args, keywords)
+
+
+class _ProxyMethod:
+ """A proxy method.
+
+ Typically a member of a ProxyObject. Calls to the
+ method produce messages that travel over the Bus and are routed
+ to a specific named Service.
+ """
+ def __init__(self, proxy, connection, bus_name, object_path, method_name,
+ iface):
+ if object_path == LOCAL_PATH:
+ raise DBusException('Methods may not be called on the reserved '
+ 'path %s' % LOCAL_PATH)
+
+ # trust that the proxy, and the properties it had, are OK
+ self._proxy = proxy
+ self._connection = connection
+ self._named_service = bus_name
+ self._object_path = object_path
+ # fail early if the method name is bad
+ _dbus_bindings.validate_member_name(method_name)
+ # the test suite relies on the existence of this property
+ self._method_name = method_name
+ # fail early if the interface name is bad
+ if iface is not None:
+ _dbus_bindings.validate_interface_name(iface)
+ self._dbus_interface = iface
+
+ def __call__(self, *args, **keywords):
+ reply_handler = keywords.pop('reply_handler', None)
+ error_handler = keywords.pop('error_handler', None)
+ ignore_reply = keywords.pop('ignore_reply', False)
+ signature = keywords.pop('signature', None)
+
+ if reply_handler is not None or error_handler is not None:
+ if reply_handler is None:
+ raise MissingReplyHandlerException()
+ elif error_handler is None:
+ raise MissingErrorHandlerException()
+ elif ignore_reply:
+ raise TypeError('ignore_reply and reply_handler cannot be '
+ 'used together')
+
+ dbus_interface = keywords.pop('dbus_interface', self._dbus_interface)
+
+ if signature is None:
+ if dbus_interface is None:
+ key = self._method_name
+ else:
+ key = dbus_interface + '.' + self._method_name
+
+ signature = self._proxy._introspect_method_map.get(key, None)
+
+ if ignore_reply or reply_handler is not None:
+ self._connection.call_async(self._named_service,
+ self._object_path,
+ dbus_interface,
+ self._method_name,
+ signature,
+ args,
+ reply_handler,
+ error_handler,
+ **keywords)
+ else:
+ return self._connection.call_blocking(self._named_service,
+ self._object_path,
+ dbus_interface,
+ self._method_name,
+ signature,
+ args,
+ **keywords)
+
+ def call_async(self, *args, **keywords):
+ reply_handler = keywords.pop('reply_handler', None)
+ error_handler = keywords.pop('error_handler', None)
+ signature = keywords.pop('signature', None)
+
+ dbus_interface = keywords.pop('dbus_interface', self._dbus_interface)
+
+ if signature is None:
+ if dbus_interface:
+ key = dbus_interface + '.' + self._method_name
+ else:
+ key = self._method_name
+ signature = self._proxy._introspect_method_map.get(key, None)
+
+ self._connection.call_async(self._named_service,
+ self._object_path,
+ dbus_interface,
+ self._method_name,
+ signature,
+ args,
+ reply_handler,
+ error_handler,
+ **keywords)
+
+
+class ProxyObject(object):
+ """A proxy to the remote Object.
+
+ A ProxyObject is provided by the Bus. ProxyObjects
+ have member functions, and can be called like normal Python objects.
+ """
+ ProxyMethodClass = _ProxyMethod
+ DeferredMethodClass = _DeferredMethod
+
+ INTROSPECT_STATE_DONT_INTROSPECT = 0
+ INTROSPECT_STATE_INTROSPECT_IN_PROGRESS = 1
+ INTROSPECT_STATE_INTROSPECT_DONE = 2
+
+ def __init__(self, conn=None, bus_name=None, object_path=None,
+ introspect=True, follow_name_owner_changes=False, **kwargs):
+ """Initialize the proxy object.
+
+ :Parameters:
+ `conn` : `dbus.connection.Connection`
+ The bus or connection on which to find this object.
+ The keyword argument `bus` is a deprecated alias for this.
+ `bus_name` : str
+ A bus name for the application owning the object, to be used
+ as the destination for method calls and the sender for
+ signal matches. The keyword argument ``named_service`` is a
+ deprecated alias for this.
+ `object_path` : str
+ The object path at which the application exports the object
+ `introspect` : bool
+ If true (default), attempt to introspect the remote
+ object to find out supported methods and their signatures
+ `follow_name_owner_changes` : bool
+ If true (default is false) and the `bus_name` is a
+ well-known name, follow ownership changes for that name
+ """
+ bus = kwargs.pop('bus', None)
+ if bus is not None:
+ if conn is not None:
+ raise TypeError('conn and bus cannot both be specified')
+ conn = bus
+ from warnings import warn
+ warn('Passing the bus parameter to ProxyObject by name is '
+ 'deprecated: please use positional parameters',
+ DeprecationWarning, stacklevel=2)
+ named_service = kwargs.pop('named_service', None)
+ if named_service is not None:
+ if bus_name is not None:
+ raise TypeError('bus_name and named_service cannot both be '
+ 'specified')
+ bus_name = named_service
+ from warnings import warn
+ warn('Passing the named_service parameter to ProxyObject by name '
+ 'is deprecated: please use positional parameters',
+ DeprecationWarning, stacklevel=2)
+ if kwargs:
+ raise TypeError('ProxyObject.__init__ does not take these '
+ 'keyword arguments: %s'
+ % ', '.join(kwargs.keys()))
+
+ if follow_name_owner_changes:
+ # we don't get the signals unless the Bus has a main loop
+ # XXX: using Bus internals
+ conn._require_main_loop()
+
+ self._bus = conn
+
+ if bus_name is not None:
+ _dbus_bindings.validate_bus_name(bus_name)
+ # the attribute is still called _named_service for the moment,
+ # for the benefit of telepathy-python
+ self._named_service = self._requested_bus_name = bus_name
+
+ _dbus_bindings.validate_object_path(object_path)
+ self.__dbus_object_path__ = object_path
+
+ if not follow_name_owner_changes:
+ self._named_service = conn.activate_name_owner(bus_name)
+
+ #PendingCall object for Introspect call
+ self._pending_introspect = None
+ #queue of async calls waiting on the Introspect to return
+ self._pending_introspect_queue = []
+ #dictionary mapping method names to their input signatures
+ self._introspect_method_map = {}
+
+ # must be a recursive lock because block() is called while locked,
+ # and calls the callback which re-takes the lock
+ self._introspect_lock = RLock()
+
+ if not introspect or self.__dbus_object_path__ == LOCAL_PATH:
+ self._introspect_state = self.INTROSPECT_STATE_DONT_INTROSPECT
+ else:
+ self._introspect_state = self.INTROSPECT_STATE_INTROSPECT_IN_PROGRESS
+
+ self._pending_introspect = self._Introspect()
+
+ bus_name = property(lambda self: self._named_service, None, None,
+ """The bus name to which this proxy is bound. (Read-only,
+ may change.)
+
+ If the proxy was instantiated using a unique name, this property
+ is that unique name.
+
+ If the proxy was instantiated with a well-known name and with
+ ``follow_name_owner_changes`` set false (the default), this
+ property is the unique name of the connection that owned that
+ well-known name when the proxy was instantiated, which might
+ not actually own the requested well-known name any more.
+
+ If the proxy was instantiated with a well-known name and with
+ ``follow_name_owner_changes`` set true, this property is that
+ well-known name.
+ """)
+
+ requested_bus_name = property(lambda self: self._requested_bus_name,
+ None, None,
+ """The bus name which was requested when this proxy was
+ instantiated.
+ """)
+
+ object_path = property(lambda self: self.__dbus_object_path__,
+ None, None,
+ """The object-path of this proxy.""")
+
+ # XXX: We don't currently support this because it's the signal receiver
+ # that's responsible for tracking name owner changes, but it
+ # seems a natural thing to add in future.
+ #unique_bus_name = property(lambda self: something, None, None,
+ # """The unique name of the connection to which this proxy is
+ # currently bound. (Read-only, may change.)
+ # """)
+
+ def connect_to_signal(self, signal_name, handler_function, dbus_interface=None, **keywords):
+ """Arrange for the given function to be called when the given signal
+ is received.
+
+ :Parameters:
+ `signal_name` : str
+ The name of the signal
+ `handler_function` : callable
+ A function to be called when the signal is emitted by
+ the remote object. Its positional arguments will be the
+ arguments of the signal; optionally, it may be given
+ keyword arguments as described below.
+ `dbus_interface` : str
+ Optional interface with which to qualify the signal name.
+ If None (the default) the handler will be called whenever a
+ signal of the given member name is received, whatever
+ its interface.
+ :Keywords:
+            `utf8_strings` : bool
+                No longer supported: string arguments are always passed to
+                the handler as dbus.String objects (a subclass of str);
+                passing this keyword raises TypeError.
+ `byte_arrays` : bool
+ If True, the handler function will receive any byte-array
+ arguments as dbus.ByteArray objects (a subclass of str).
+ If False (default) it will receive any byte-array
+ arguments as a dbus.Array of dbus.Byte (subclasses of:
+ a list of ints).
+ `sender_keyword` : str
+ If not None (the default), the handler function will receive
+ the unique name of the sending endpoint as a keyword
+ argument with this name
+ `destination_keyword` : str
+ If not None (the default), the handler function will receive
+ the bus name of the destination (or None if the signal is a
+ broadcast, as is usual) as a keyword argument with this name.
+ `interface_keyword` : str
+ If not None (the default), the handler function will receive
+ the signal interface as a keyword argument with this name.
+ `member_keyword` : str
+ If not None (the default), the handler function will receive
+ the signal name as a keyword argument with this name.
+ `path_keyword` : str
+ If not None (the default), the handler function will receive
+ the object-path of the sending object as a keyword argument
+ with this name
+ `message_keyword` : str
+ If not None (the default), the handler function will receive
+ the `dbus.lowlevel.SignalMessage` as a keyword argument with
+ this name.
+ `arg...` : unicode or UTF-8 str
+ If there are additional keyword parameters of the form
+ ``arg``\\ *n*, match only signals where the *n*\\ th argument
+ is the value given for that keyword parameter. As of this time
+ only string arguments can be matched (in particular,
+ object paths and signatures can't).
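+
+        As an illustration only (the signal name, interface and handler
+        below are hypothetical, not part of this API's documentation)::
+
+            def on_ping(message, sender=None):
+                print('Ping from %s: %s' % (sender, message))
+
+            proxy.connect_to_signal('Ping', on_ping,
+                                    dbus_interface='com.example.Sample',
+                                    sender_keyword='sender')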
+ """
+ return \
+ self._bus.add_signal_receiver(handler_function,
+ signal_name=signal_name,
+ dbus_interface=dbus_interface,
+ bus_name=self._named_service,
+ path=self.__dbus_object_path__,
+ **keywords)
+
+ def _Introspect(self):
+ kwargs = {}
+ return self._bus.call_async(self._named_service,
+ self.__dbus_object_path__,
+ INTROSPECTABLE_IFACE, 'Introspect', '', (),
+ self._introspect_reply_handler,
+ self._introspect_error_handler,
+ require_main_loop=False, **kwargs)
+
+ def _introspect_execute_queue(self):
+ # FIXME: potential to flood the bus
+ # We should make sure mainloops all have idle handlers
+ # and do one message per idle
+ for (proxy_method, args, keywords) in self._pending_introspect_queue:
+ proxy_method(*args, **keywords)
+ self._pending_introspect_queue = []
+
+ def _introspect_reply_handler(self, data):
+ self._introspect_lock.acquire()
+ try:
+ try:
+ self._introspect_method_map = process_introspection_data(data)
+ except IntrospectionParserException as e:
+ self._introspect_error_handler(e)
+ return
+
+ self._introspect_state = self.INTROSPECT_STATE_INTROSPECT_DONE
+ self._pending_introspect = None
+ self._introspect_execute_queue()
+ finally:
+ self._introspect_lock.release()
+
+ def _introspect_error_handler(self, error):
+ logging.basicConfig()
+ _logger.error("Introspect error on %s:%s: %s.%s: %s",
+ self._named_service, self.__dbus_object_path__,
+ error.__class__.__module__, error.__class__.__name__,
+ error)
+ self._introspect_lock.acquire()
+ try:
+ _logger.debug('Executing introspect queue due to error')
+ self._introspect_state = self.INTROSPECT_STATE_DONT_INTROSPECT
+ self._pending_introspect = None
+ self._introspect_execute_queue()
+ finally:
+ self._introspect_lock.release()
+
+ def _introspect_block(self):
+ self._introspect_lock.acquire()
+ try:
+ if self._pending_introspect is not None:
+ self._pending_introspect.block()
+ # else someone still has a _DeferredMethod from before we
+ # finished introspection: no need to do anything special any more
+ finally:
+ self._introspect_lock.release()
+
+ def _introspect_add_to_queue(self, callback, args, kwargs):
+ self._introspect_lock.acquire()
+ try:
+ if self._introspect_state == self.INTROSPECT_STATE_INTROSPECT_IN_PROGRESS:
+ self._pending_introspect_queue.append((callback, args, kwargs))
+ else:
+ # someone still has a _DeferredMethod from before we
+ # finished introspection
+ callback(*args, **kwargs)
+ finally:
+ self._introspect_lock.release()
+
+ def __getattr__(self, member):
+ if member.startswith('__') and member.endswith('__'):
+ raise AttributeError(member)
+ else:
+ return self.get_dbus_method(member)
+
+ def get_dbus_method(self, member, dbus_interface=None):
+ """Return a proxy method representing the given D-Bus method. The
+ returned proxy method can be called in the usual way. For instance, ::
+
+ proxy.get_dbus_method("Foo", dbus_interface='com.example.Bar')(123)
+
+ is equivalent to::
+
+ proxy.Foo(123, dbus_interface='com.example.Bar')
+
+ or even::
+
+ getattr(proxy, "Foo")(123, dbus_interface='com.example.Bar')
+
+ However, using `get_dbus_method` is the only way to call D-Bus
+ methods with certain awkward names - if the author of a service
+ implements a method called ``connect_to_signal`` or even
+ ``__getattr__``, you'll need to use `get_dbus_method` to call them.
+
+ For services which follow the D-Bus convention of CamelCaseMethodNames
+ this won't be a problem.
+ """
+
+ ret = self.ProxyMethodClass(self, self._bus,
+ self._named_service,
+ self.__dbus_object_path__, member,
+ dbus_interface)
+
+ # this can be done without taking the lock - the worst that can
+ # happen is that we accidentally return a _DeferredMethod just after
+ # finishing introspection, in which case _introspect_add_to_queue and
+ # _introspect_block will do the right thing anyway
+ if self._introspect_state == self.INTROSPECT_STATE_INTROSPECT_IN_PROGRESS:
+ ret = self.DeferredMethodClass(ret, self._introspect_add_to_queue,
+ self._introspect_block)
+
+ return ret
+
+ def __repr__(self):
+ return '<ProxyObject wrapping %s %s %s at %#x>'%(
+ self._bus, self._named_service, self.__dbus_object_path__, id(self))
+ __str__ = __repr__
+
+
+class Interface(object):
+ """An interface into a remote object.
+
+ An Interface can be used to wrap ProxyObjects
+ so that calls can be routed to their correct
+ D-Bus interface.
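+
+    For illustration only (the bus name, object path and interface are
+    hypothetical)::
+
+        proxy = bus.get_object('com.example.Service', '/com/example/Object')
+        iface = dbus.Interface(proxy, dbus_interface='com.example.Sample')
+        iface.SomeMethod(123)   # routed to com.example.Sample.SomeMethod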
+ """
+
+ def __init__(self, object, dbus_interface):
+ """Construct a proxy for the given interface on the given object.
+
+ :Parameters:
+ `object` : `dbus.proxies.ProxyObject` or `dbus.Interface`
+ The remote object or another of its interfaces
+ `dbus_interface` : str
+ An interface the `object` implements
+ """
+ if isinstance(object, Interface):
+ self._obj = object.proxy_object
+ else:
+ self._obj = object
+ self._dbus_interface = dbus_interface
+
+ object_path = property (lambda self: self._obj.object_path, None, None,
+ "The D-Bus object path of the underlying object")
+ __dbus_object_path__ = object_path
+ bus_name = property (lambda self: self._obj.bus_name, None, None,
+ "The bus name to which the underlying proxy object "
+ "is bound")
+ requested_bus_name = property (lambda self: self._obj.requested_bus_name,
+ None, None,
+ "The bus name which was requested when the "
+ "underlying object was created")
+ proxy_object = property (lambda self: self._obj, None, None,
+ """The underlying proxy object""")
+ dbus_interface = property (lambda self: self._dbus_interface, None, None,
+ """The D-Bus interface represented""")
+
+ def connect_to_signal(self, signal_name, handler_function,
+ dbus_interface=None, **keywords):
+ """Arrange for a function to be called when the given signal is
+ emitted.
+
+ The parameters and keyword arguments are the same as for
+ `dbus.proxies.ProxyObject.connect_to_signal`, except that if
+ `dbus_interface` is None (the default), the D-Bus interface that
+ was passed to the `Interface` constructor is used.
+ """
+ if not dbus_interface:
+ dbus_interface = self._dbus_interface
+
+ return self._obj.connect_to_signal(signal_name, handler_function,
+ dbus_interface, **keywords)
+
+ def __getattr__(self, member):
+ if member.startswith('__') and member.endswith('__'):
+ raise AttributeError(member)
+ else:
+ return self._obj.get_dbus_method(member, self._dbus_interface)
+
+ def get_dbus_method(self, member, dbus_interface=None):
+ """Return a proxy method representing the given D-Bus method.
+
+ This is the same as `dbus.proxies.ProxyObject.get_dbus_method`
+ except that if `dbus_interface` is None (the default),
+ the D-Bus interface that was passed to the `Interface` constructor
+ is used.
+ """
+ if dbus_interface is None:
+ dbus_interface = self._dbus_interface
+ return self._obj.get_dbus_method(member, dbus_interface)
+
+ def __repr__(self):
+ return '<Interface %r implementing %r at %#x>'%(
+ self._obj, self._dbus_interface, id(self))
+ __str__ = __repr__
diff --git a/lib/dbus/server.py b/lib/dbus/server.py
new file mode 100644
index 0000000..40a7bb9
--- /dev/null
+++ b/lib/dbus/server.py
@@ -0,0 +1,119 @@
+# Copyright (C) 2008 Openismus GmbH <http://openismus.com/>
+# Copyright (C) 2008 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = ('Server', )
+__docformat__ = 'reStructuredText'
+
+from _dbus_bindings import _Server
+from dbus.connection import Connection
+
+class Server(_Server):
+ """An opaque object representing a server that listens for connections from
+ other applications.
+
+ This class is not useful to instantiate directly: you must subclass it and
+ either extend the method connection_added, or append to the
+ list on_connection_added.
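+
+    A minimal sketch (the subclass name and listen address are
+    illustrative only)::
+
+        class EchoServer(dbus.server.Server):
+            def connection_added(self, conn):
+                super(EchoServer, self).connection_added(conn)
+                print('new connection', conn)
+
+        server = EchoServer('unix:tmpdir=/tmp')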
+
+ :Since: 0.83
+ """
+
+ def __new__(cls, address, connection_class=Connection,
+ mainloop=None, auth_mechanisms=None):
+ """Construct a new Server.
+
+ :Parameters:
+ `address` : str
+ Listen on this address.
+ `connection_class` : type
+ When new connections come in, instantiate this subclass
+ of dbus.connection.Connection to represent them.
+ The default is Connection.
+ `mainloop` : dbus.mainloop.NativeMainLoop or None
+ The main loop with which to associate the new connections.
+ `auth_mechanisms` : sequence of str
+ Authentication mechanisms to allow. The default is to allow
+ any authentication mechanism supported by ``libdbus``.
+ """
+ return super(Server, cls).__new__(cls, address, connection_class,
+ mainloop, auth_mechanisms)
+
+ def __init__(self, *args, **kwargs):
+
+ self.__connections = {}
+
+ self.on_connection_added = []
+ """A list of callbacks to invoke when a connection is added.
+ They receive two arguments: this Server and the new Connection."""
+
+ self.on_connection_removed = []
+ """A list of callbacks to invoke when a connection becomes
+ disconnected. They receive two arguments: this Server and the removed
+ Connection."""
+
+ # This method name is hard-coded in _dbus_bindings._Server.
+ # This is not public API.
+ def _on_new_connection(self, conn):
+ conn.call_on_disconnection(self.connection_removed)
+ self.connection_added(conn)
+
+ def connection_added(self, conn):
+ """Respond to the creation of a new Connection.
+
+ This base-class implementation just invokes the callbacks in
+ the on_connection_added attribute.
+
+ :Parameters:
+ `conn` : dbus.connection.Connection
+ A D-Bus connection which has just been added.
+
+ The type of this parameter is whatever was passed
+ to the Server constructor as the ``connection_class``.
+ """
+ if self.on_connection_added:
+ for cb in self.on_connection_added:
+ cb(conn)
+
+ def connection_removed(self, conn):
+ """Respond to the disconnection of a Connection.
+
+ This base-class implementation just invokes the callbacks in
+ the on_connection_removed attribute.
+
+ :Parameters:
+ `conn` : dbus.connection.Connection
+ A D-Bus connection which has just become disconnected.
+
+ The type of this parameter is whatever was passed
+ to the Server constructor as the ``connection_class``.
+ """
+ if self.on_connection_removed:
+ for cb in self.on_connection_removed:
+ cb(conn)
+
+ address = property(_Server.get_address)
+ id = property(_Server.get_id)
+ is_connected = property(_Server.get_is_connected)
+
diff --git a/lib/dbus/service.py b/lib/dbus/service.py
new file mode 100644
index 0000000..2e13d3c
--- /dev/null
+++ b/lib/dbus/service.py
@@ -0,0 +1,840 @@
+# Copyright (C) 2003-2006 Red Hat Inc. <http://www.redhat.com/>
+# Copyright (C) 2003 David Zeuthen
+# Copyright (C) 2004 Rob Taylor
+# Copyright (C) 2005-2006 Collabora Ltd. <http://www.collabora.co.uk/>
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation
+# files (the "Software"), to deal in the Software without
+# restriction, including without limitation the rights to use, copy,
+# modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+__all__ = ('BusName', 'Object', 'FallbackObject', 'method', 'signal')
+__docformat__ = 'restructuredtext'
+
+import sys
+import logging
+import threading
+import traceback
+try:
+ from collections.abc import Sequence
+except ImportError:
+ # Python 2 (and 3.x < 3.3, but we don't support those)
+ from collections import Sequence
+
+import _dbus_bindings
+from dbus import (
+ INTROSPECTABLE_IFACE, ObjectPath, SessionBus, Signature, Struct,
+ validate_bus_name, validate_object_path)
+from dbus.decorators import method, signal
+from dbus.exceptions import (
+ DBusException, NameExistsException, UnknownMethodException)
+from dbus.lowlevel import ErrorMessage, MethodReturnMessage, MethodCallMessage
+from dbus.proxies import LOCAL_PATH
+from dbus._compat import is_py2
+
+
+_logger = logging.getLogger('dbus.service')
+
+
+class _VariantSignature(object):
+ """A fake method signature which, when iterated, yields an endless stream
+ of 'v' characters representing variants (handy with zip()).
+
+ It has no string representation.
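+
+    For example (illustrative only), ``zip(_VariantSignature(), ['a', 'b'])``
+    yields ``('v', 'a')`` and ``('v', 'b')``, pairing each argument with 'v'.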
+ """
+ def __iter__(self):
+ """Return self."""
+ return self
+
+ def __next__(self):
+ """Return 'v' whenever called."""
+ return 'v'
+
+
+class BusName(object):
+ """A base class for exporting your own Named Services across the Bus.
+
+ When instantiated, objects of this class attempt to claim the given
+ well-known name on the given bus for the current process. The name is
+ released when the BusName object becomes unreferenced.
+
+ If a well-known name is requested multiple times, multiple references
+ to the same BusName object will be returned.
+
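+    For illustration only (the well-known name is hypothetical)::
+
+        name = dbus.service.BusName('com.example.Sample',
+                                    bus=dbus.SessionBus())
+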
+ :Caveats:
+
+ - Assumes that named services are only ever requested using this class -
+ if you request names from the bus directly, confusion may occur.
+ - Does not handle queueing.
+ """
+    def __new__(cls, name, bus=None, allow_replacement=False, replace_existing=False, do_not_queue=False):
+ """Constructor, which may either return an existing cached object
+ or a new object.
+
+ :Parameters:
+ `name` : str
+ The well-known name to be advertised
+ `bus` : dbus.Bus
+ A Bus on which this service will be advertised.
+
+ Omitting this parameter or setting it to None has been
+ deprecated since version 0.82.1. For backwards compatibility,
+ if this is done, the global shared connection to the session
+ bus will be used.
+
+ `allow_replacement` : bool
+ If True, other processes trying to claim the same well-known
+ name will take precedence over this one.
+ `replace_existing` : bool
+ If True, this process can take over the well-known name
+ from other processes already holding it.
+ `do_not_queue` : bool
+ If True, this service will not be placed in the queue of
+ services waiting for the requested name if another service
+ already holds it.
+ """
+ validate_bus_name(name, allow_well_known=True, allow_unique=False)
+
+ # if necessary, get default bus (deprecated)
+ if bus is None:
+ import warnings
+ warnings.warn('Omitting the "bus" parameter to '
+ 'dbus.service.BusName.__init__ is deprecated',
+ DeprecationWarning, stacklevel=2)
+ bus = SessionBus()
+
+ # see if this name is already defined, return it if so
+ # FIXME: accessing internals of Bus
+ if name in bus._bus_names:
+ return bus._bus_names[name]
+
+ # otherwise register the name
+ name_flags = (
+ (allow_replacement and _dbus_bindings.NAME_FLAG_ALLOW_REPLACEMENT or 0) |
+ (replace_existing and _dbus_bindings.NAME_FLAG_REPLACE_EXISTING or 0) |
+ (do_not_queue and _dbus_bindings.NAME_FLAG_DO_NOT_QUEUE or 0))
+
+ retval = bus.request_name(name, name_flags)
+
+ # TODO: more intelligent tracking of bus name states?
+ if retval == _dbus_bindings.REQUEST_NAME_REPLY_PRIMARY_OWNER:
+ pass
+ elif retval == _dbus_bindings.REQUEST_NAME_REPLY_IN_QUEUE:
+ # queueing can happen by default, maybe we should
+ # track this better or let the user know if they're
+ # queued or not?
+ pass
+ elif retval == _dbus_bindings.REQUEST_NAME_REPLY_EXISTS:
+ raise NameExistsException(name)
+ elif retval == _dbus_bindings.REQUEST_NAME_REPLY_ALREADY_OWNER:
+ # if this is a shared bus which is being used by someone
+ # else in this process, this can happen legitimately
+ pass
+ else:
+ raise RuntimeError('requesting bus name %s returned unexpected value %s' % (name, retval))
+
+ # and create the object
+ bus_name = object.__new__(cls)
+ bus_name._bus = bus
+ bus_name._name = name
+
+ # cache instance (weak ref only)
+ # FIXME: accessing Bus internals again
+ bus._bus_names[name] = bus_name
+
+ return bus_name
+
+    # do nothing, because __init__ is called whether the bus name
+    # object was retrieved from the cache or newly created
+ def __init__(self, *args, **keywords):
+ pass
+
+ # we can delete the low-level name here because these objects
+ # are guaranteed to exist only once for each bus name
+ def __del__(self):
+        self._bus.release_name(self._name)
+
+ def get_bus(self):
+ """Get the Bus this Service is on"""
+ return self._bus
+
+ def get_name(self):
+ """Get the name of this service"""
+ return self._name
+
+ def __repr__(self):
+ return '<dbus.service.BusName %s on %r at %#x>' % (self._name, self._bus, id(self))
+ __str__ = __repr__
+
+
+def _method_lookup(self, method_name, dbus_interface):
+ """Walks the Python MRO of the given class to find the method to invoke.
+
+    Returns two methods: the one to call, and the one it inherits from,
+    which defines its D-Bus interface name, signature and attributes.
+ """
+ parent_method = None
+ candidate_class = None
+ successful = False
+
+ # split up the cases when we do and don't have an interface because the
+ # latter is much simpler
+ if dbus_interface:
+ # search through the class hierarchy in python MRO order
+ for cls in self.__class__.__mro__:
+ # if we haven't got a candidate class yet, and we find a class with a
+ # suitably named member, save this as a candidate class
+ if (not candidate_class and method_name in cls.__dict__):
+ if ("_dbus_is_method" in cls.__dict__[method_name].__dict__
+ and "_dbus_interface" in cls.__dict__[method_name].__dict__):
+ # however if it is annotated for a different interface
+ # than we are looking for, it cannot be a candidate
+ if cls.__dict__[method_name]._dbus_interface == dbus_interface:
+ candidate_class = cls
+ parent_method = cls.__dict__[method_name]
+ successful = True
+ break
+ else:
+ pass
+ else:
+ candidate_class = cls
+
+ # if we have a candidate class, carry on checking this and all
+            # superclasses for a method annotated as a dbus method
+ # on the correct interface
+ if (candidate_class and method_name in cls.__dict__
+ and "_dbus_is_method" in cls.__dict__[method_name].__dict__
+ and "_dbus_interface" in cls.__dict__[method_name].__dict__
+ and cls.__dict__[method_name]._dbus_interface == dbus_interface):
+ # the candidate class has a dbus method on the correct interface,
+ # or overrides a method that is, success!
+ parent_method = cls.__dict__[method_name]
+ successful = True
+ break
+
+ else:
+ # simpler version of above
+ for cls in self.__class__.__mro__:
+ if (not candidate_class and method_name in cls.__dict__):
+ candidate_class = cls
+
+ if (candidate_class and method_name in cls.__dict__
+ and "_dbus_is_method" in cls.__dict__[method_name].__dict__):
+ parent_method = cls.__dict__[method_name]
+ successful = True
+ break
+
+ if successful:
+ return (candidate_class.__dict__[method_name], parent_method)
+ else:
+ if dbus_interface:
+ raise UnknownMethodException('%s is not a valid method of interface %s' % (method_name, dbus_interface))
+ else:
+ raise UnknownMethodException('%s is not a valid method' % method_name)
+
+
+def _method_reply_return(connection, message, method_name, signature, *retval):
+ reply = MethodReturnMessage(message)
+ try:
+ reply.append(signature=signature, *retval)
+ except Exception as e:
+ logging.basicConfig()
+ if signature is None:
+ try:
+ signature = reply.guess_signature(retval) + ' (guessed)'
+ except Exception as e:
+ _logger.error('Unable to guess signature for arguments %r: '
+ '%s: %s', retval, e.__class__, e)
+ raise
+ _logger.error('Unable to append %r to message with signature %s: '
+ '%s: %s', retval, signature, e.__class__, e)
+ raise
+
+ if not message.get_no_reply():
+ connection.send_message(reply)
+
+
+def _method_reply_error(connection, message, exception):
+ name = getattr(exception, '_dbus_error_name', None)
+
+ if name is not None:
+ pass
+ elif getattr(exception, '__module__', '') in ('', '__main__'):
+ name = 'org.freedesktop.DBus.Python.%s' % exception.__class__.__name__
+ else:
+ name = 'org.freedesktop.DBus.Python.%s.%s' % (exception.__module__, exception.__class__.__name__)
+
+ et, ev, etb = sys.exc_info()
+ if isinstance(exception, DBusException) and not exception.include_traceback:
+ # We don't actually want the traceback anyway
+ contents = exception.get_dbus_message()
+ elif ev is exception:
+ # The exception was actually thrown, so we can get a traceback
+ contents = ''.join(traceback.format_exception(et, ev, etb))
+ else:
+ # We don't have any traceback for it, e.g.
+ # async_err_cb(MyException('Failed to badger the mushroom'))
+ # see also https://bugs.freedesktop.org/show_bug.cgi?id=12403
+ contents = ''.join(traceback.format_exception_only(exception.__class__,
+ exception))
+ reply = ErrorMessage(message, name, contents)
+
+ if not message.get_no_reply():
+ connection.send_message(reply)
+
+
+class InterfaceType(type):
+ def __init__(cls, name, bases, dct):
+ # these attributes are shared between all instances of the Interface
+ # object, so this has to be a dictionary that maps class names to
+ # the per-class introspection/interface data
+ class_table = getattr(cls, '_dbus_class_table', {})
+ cls._dbus_class_table = class_table
+ interface_table = class_table[cls.__module__ + '.' + name] = {}
+
+ # merge all the name -> method tables for all the interfaces
+ # implemented by our base classes into our own
+ for b in bases:
+ base_name = b.__module__ + '.' + b.__name__
+ if getattr(b, '_dbus_class_table', False):
+ for (interface, method_table) in class_table[base_name].items():
+ our_method_table = interface_table.setdefault(interface, {})
+ our_method_table.update(method_table)
+
+ # add in all the name -> method entries for our own methods/signals
+ for func in dct.values():
+ if getattr(func, '_dbus_interface', False):
+ method_table = interface_table.setdefault(func._dbus_interface, {})
+ method_table[func.__name__] = func
+
+ super(InterfaceType, cls).__init__(name, bases, dct)
+
+ # methods are different to signals, so we have two functions... :)
+ def _reflect_on_method(cls, func):
+ args = func._dbus_args
+
+ if func._dbus_in_signature:
+ # convert signature into a tuple so length refers to number of
+ # types, not number of characters. the length is checked by
+ # the decorator to make sure it matches the length of args.
+ in_sig = tuple(Signature(func._dbus_in_signature))
+ else:
+ # magic iterator which returns as many v's as we need
+ in_sig = _VariantSignature()
+
+ if func._dbus_out_signature:
+ out_sig = Signature(func._dbus_out_signature)
+ else:
+            # it's tempting to default to Signature('v'), but
+ # for methods that return nothing, providing incorrect
+ # introspection data is worse than providing none at all
+ out_sig = []
+
+ reflection_data = ' <method name="%s">\n' % (func.__name__)
+ for pair in zip(in_sig, args):
+ reflection_data += ' <arg direction="in" type="%s" name="%s" />\n' % pair
+ for type in out_sig:
+ reflection_data += ' <arg direction="out" type="%s" />\n' % type
+ reflection_data += ' </method>\n'
+
+ return reflection_data
+
+ def _reflect_on_signal(cls, func):
+ args = func._dbus_args
+
+ if func._dbus_signature:
+ # convert signature into a tuple so length refers to number of
+ # types, not number of characters
+ sig = tuple(Signature(func._dbus_signature))
+ else:
+ # magic iterator which returns as many v's as we need
+ sig = _VariantSignature()
+
+ reflection_data = ' <signal name="%s">\n' % (func.__name__)
+ for pair in zip(sig, args):
+ reflection_data = reflection_data + ' <arg type="%s" name="%s" />\n' % pair
+ reflection_data = reflection_data + ' </signal>\n'
+
+ return reflection_data
+
+
+# Define Interface as an instance of the metaclass InterfaceType, in a way
+# that is compatible across both Python 2 and Python 3.
+Interface = InterfaceType('Interface', (object,), {})
+
+
+#: A unique object used as the value of Object._object_path and
+#: Object._connection if it's actually in more than one place
+_MANY = object()
+
+class Object(Interface):
+ r"""A base class for exporting your own Objects across the Bus.
+
+ Just inherit from Object and mark exported methods with the
+ @\ `dbus.service.method` or @\ `dbus.service.signal` decorator.
+
+ Example::
+
+        class Example(dbus.service.Object):
+            def __init__(self, object_path):
+                dbus.service.Object.__init__(self, dbus.SessionBus(), object_path)
+ self._last_input = None
+
+ @dbus.service.method(interface='com.example.Sample',
+ in_signature='v', out_signature='s')
+ def StringifyVariant(self, var):
+ self.LastInputChanged(var) # emits the signal
+ return str(var)
+
+ @dbus.service.signal(interface='com.example.Sample',
+ signature='v')
+ def LastInputChanged(self, var):
+ # run just before the signal is actually emitted
+ # just put "pass" if nothing should happen
+ self._last_input = var
+
+ @dbus.service.method(interface='com.example.Sample',
+ in_signature='', out_signature='v')
+ def GetLastInput(self):
+ return self._last_input
+ """
+
+ #: If True, this object can be made available at more than one object path.
+ #: If True but `SUPPORTS_MULTIPLE_CONNECTIONS` is False, the object may
+ #: handle more than one object path, but they must all be on the same
+ #: connection.
+ SUPPORTS_MULTIPLE_OBJECT_PATHS = False
+
+ #: If True, this object can be made available on more than one connection.
+ #: If True but `SUPPORTS_MULTIPLE_OBJECT_PATHS` is False, the object must
+ #: have the same object path on all its connections.
+ SUPPORTS_MULTIPLE_CONNECTIONS = False
+
+ def __init__(self, conn=None, object_path=None, bus_name=None):
+ """Constructor. Either conn or bus_name is required; object_path
+ is also required.
+
+ :Parameters:
+ `conn` : dbus.connection.Connection or None
+ The connection on which to export this object.
+
+ If None, use the Bus associated with the given ``bus_name``.
+ If there is no ``bus_name`` either, the object is not
+ initially available on any Connection.
+
+ For backwards compatibility, if an instance of
+ dbus.service.BusName is passed as the first parameter,
+ this is equivalent to passing its associated Bus as
+ ``conn``, and passing the BusName itself as ``bus_name``.
+
+ `object_path` : str or None
+ A D-Bus object path at which to make this Object available
+ immediately. If this is not None, a `conn` or `bus_name` must
+ also be provided.
+
+ `bus_name` : dbus.service.BusName or None
+ Represents a well-known name claimed by this process. A
+ reference to the BusName object will be held by this
+ Object, preventing the name from being released during this
+ Object's lifetime (unless it's released manually).
+ """
+ if object_path is not None:
+ validate_object_path(object_path)
+
+ if isinstance(conn, BusName):
+ # someone's using the old API; don't gratuitously break them
+ bus_name = conn
+ conn = bus_name.get_bus()
+ elif conn is None:
+ if bus_name is not None:
+ # someone's using the old API but naming arguments, probably
+ conn = bus_name.get_bus()
+
+ #: Either an object path, None or _MANY
+ self._object_path = None
+ #: Either a dbus.connection.Connection, None or _MANY
+ self._connection = None
+ #: A list of tuples (Connection, object path, False) where the False
+ #: is for future expansion (to support fallback paths)
+ self._locations = []
+ #: Lock protecting `_locations`, `_connection` and `_object_path`
+ self._locations_lock = threading.Lock()
+
+ #: True if this is a fallback object handling a whole subtree.
+ self._fallback = False
+
+ self._name = bus_name
+
+ if conn is None and object_path is not None:
+ raise TypeError('If object_path is given, either conn or bus_name '
+ 'is required')
+ if conn is not None and object_path is not None:
+ self.add_to_connection(conn, object_path)
+
+ @property
+ def __dbus_object_path__(self):
+ """The object-path at which this object is available.
+ Access raises AttributeError if there is no object path, or more than
+ one object path.
+
+ Changed in 0.82.0: AttributeError can be raised.
+ """
+ if self._object_path is _MANY:
+ raise AttributeError('Object %r has more than one object path: '
+ 'use Object.locations instead' % self)
+ elif self._object_path is None:
+ raise AttributeError('Object %r has no object path yet' % self)
+ else:
+ return self._object_path
+
+ @property
+ def connection(self):
+ """The Connection on which this object is available.
+ Access raises AttributeError if there is no Connection, or more than
+ one Connection.
+
+ Changed in 0.82.0: AttributeError can be raised.
+ """
+ if self._connection is _MANY:
+ raise AttributeError('Object %r is on more than one Connection: '
+ 'use Object.locations instead' % self)
+ elif self._connection is None:
+ raise AttributeError('Object %r has no Connection yet' % self)
+ else:
+ return self._connection
+
+ @property
+ def locations(self):
+ """An iterable over tuples representing locations at which this
+ object is available.
+
+ Each tuple has at least two items, but may have more in future
+ versions of dbus-python, so do not rely on their exact length.
+ The first two items are the dbus.connection.Connection and the object
+ path.
+
+ :Since: 0.82.0
+ """
+ return iter(self._locations)
+
+ def add_to_connection(self, connection, path):
+ """Make this object accessible via the given D-Bus connection and
+ object path.
+
+ :Parameters:
+ `connection` : dbus.connection.Connection
+ Export the object on this connection. If the class attribute
+ SUPPORTS_MULTIPLE_CONNECTIONS is False (default), this object
+ can only be made available on one connection; if the class
+ attribute is set True by a subclass, the object can be made
+ available on more than one connection.
+
+ `path` : dbus.ObjectPath or other str
+ Place the object at this object path. If the class attribute
+ SUPPORTS_MULTIPLE_OBJECT_PATHS is False (default), this object
+ can only be made available at one object path; if the class
+ attribute is set True by a subclass, the object can be made
+ available with more than one object path.
+
+ :Raises ValueError: if the object's class attributes do not allow the
+ object to be exported in the desired way.
+ :Since: 0.82.0
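+
+        A minimal sketch (``MyExportedObject`` is a hypothetical
+        dbus.service.Object subclass; the path is illustrative only)::
+
+            obj = MyExportedObject()
+            obj.add_to_connection(dbus.SessionBus(), '/com/example/MyObject')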
+ """
+ if path == LOCAL_PATH:
+ raise ValueError('Objects may not be exported on the reserved '
+ 'path %s' % LOCAL_PATH)
+
+ self._locations_lock.acquire()
+ try:
+ if (self._connection is not None and
+ self._connection is not connection and
+ not self.SUPPORTS_MULTIPLE_CONNECTIONS):
+ raise ValueError('%r is already exported on '
+ 'connection %r' % (self, self._connection))
+
+ if (self._object_path is not None and
+ not self.SUPPORTS_MULTIPLE_OBJECT_PATHS and
+ self._object_path != path):
+ raise ValueError('%r is already exported at object '
+ 'path %s' % (self, self._object_path))
+
+ connection._register_object_path(path, self._message_cb,
+ self._unregister_cb,
+ self._fallback)
+
+ if self._connection is None:
+ self._connection = connection
+ elif self._connection is not connection:
+ self._connection = _MANY
+
+ if self._object_path is None:
+ self._object_path = path
+ elif self._object_path != path:
+ self._object_path = _MANY
+
+ self._locations.append((connection, path, self._fallback))
+ finally:
+ self._locations_lock.release()
+
+ def remove_from_connection(self, connection=None, path=None):
+ """Make this object inaccessible via the given D-Bus connection
+ and object path. If no connection or path is specified,
+ the object ceases to be accessible via any connection or path.
+
+ :Parameters:
+ `connection` : dbus.connection.Connection or None
+ Only remove the object from this Connection. If None,
+ remove from all Connections on which it's exported.
+ `path` : dbus.ObjectPath or other str, or None
+ Only remove the object from this object path. If None,
+ remove from all object paths.
+ :Raises LookupError:
+ if the object was not exported on the requested connection
+ or path, or (if both are None) was not exported at all.
+ :Since: 0.81.1
+ """
+ self._locations_lock.acquire()
+ try:
+ if self._object_path is None or self._connection is None:
+ raise LookupError('%r is not exported' % self)
+
+ if connection is not None or path is not None:
+ dropped = []
+ for location in self._locations:
+ if ((connection is None or location[0] is connection) and
+ (path is None or location[1] == path)):
+ dropped.append(location)
+ else:
+ dropped = self._locations
+ self._locations = []
+
+ if not dropped:
+ raise LookupError('%r is not exported at a location matching '
+ '(%r,%r)' % (self, connection, path))
+
+ for location in dropped:
+ try:
+ location[0]._unregister_object_path(location[1])
+ except LookupError:
+ pass
+ if self._locations:
+ try:
+ self._locations.remove(location)
+ except ValueError:
+ pass
+ finally:
+ self._locations_lock.release()
+
+ def _unregister_cb(self, connection):
+ # there's not really enough information to do anything useful here
+ _logger.info('Unregistering exported object %r from some path '
+ 'on %r', self, connection)
+
+ def _message_cb(self, connection, message):
+ if not isinstance(message, MethodCallMessage):
+ return
+
+ try:
+ # lookup candidate method and parent method
+ method_name = message.get_member()
+ interface_name = message.get_interface()
+ (candidate_method, parent_method) = _method_lookup(self, method_name, interface_name)
+
+ # set up method call parameters
+ args = message.get_args_list(**parent_method._dbus_get_args_options)
+ keywords = {}
+
+ if parent_method._dbus_out_signature is not None:
+ signature = Signature(parent_method._dbus_out_signature)
+ else:
+ signature = None
+
+ # set up async callback functions
+ if parent_method._dbus_async_callbacks:
+ (return_callback, error_callback) = parent_method._dbus_async_callbacks
+ keywords[return_callback] = lambda *retval: _method_reply_return(connection, message, method_name, signature, *retval)
+ keywords[error_callback] = lambda exception: _method_reply_error(connection, message, exception)
+
+ # include the sender etc. if desired
+ if parent_method._dbus_sender_keyword:
+ keywords[parent_method._dbus_sender_keyword] = message.get_sender()
+ if parent_method._dbus_path_keyword:
+ keywords[parent_method._dbus_path_keyword] = message.get_path()
+ if parent_method._dbus_rel_path_keyword:
+ path = message.get_path()
+ rel_path = path
+ for exp in self._locations:
+ # pathological case: if we're exported in two places,
+ # one of which is a subtree of the other, then pick the
+ # subtree by preference (i.e. minimize the length of
+ # rel_path)
+ if exp[0] is connection:
+ if path == exp[1]:
+ rel_path = '/'
+ break
+ if exp[1] == '/':
+ # we already have rel_path == path at the beginning
+ continue
+ if path.startswith(exp[1] + '/'):
+ # yes we're in this exported subtree
+ suffix = path[len(exp[1]):]
+ if len(suffix) < len(rel_path):
+ rel_path = suffix
+ rel_path = ObjectPath(rel_path)
+ keywords[parent_method._dbus_rel_path_keyword] = rel_path
+
+ if parent_method._dbus_destination_keyword:
+ keywords[parent_method._dbus_destination_keyword] = message.get_destination()
+ if parent_method._dbus_message_keyword:
+ keywords[parent_method._dbus_message_keyword] = message
+ if parent_method._dbus_connection_keyword:
+ keywords[parent_method._dbus_connection_keyword] = connection
+
+ # call method
+ retval = candidate_method(self, *args, **keywords)
+
+ # we're done - the method has got callback functions to reply with
+ if parent_method._dbus_async_callbacks:
+ return
+
+ # otherwise we send the return values in a reply. if we have a
+ # signature, use it to turn the return value into a tuple as
+ # appropriate
+ if signature is not None:
+ signature_tuple = tuple(signature)
+                # if we have zero or one return values we want to make a tuple
+ # for the _method_reply_return function, otherwise we need
+ # to check we're passing it a sequence
+ if len(signature_tuple) == 0:
+                    if retval is None:
+ retval = ()
+ else:
+ raise TypeError('%s has an empty output signature but did not return None' %
+ method_name)
+ elif len(signature_tuple) == 1:
+ retval = (retval,)
+ else:
+ if isinstance(retval, Sequence):
+ # multi-value signature, multi-value return... proceed
+ # unchanged
+ pass
+ else:
+ raise TypeError('%s has multiple output values in signature %s but did not return a sequence' %
+ (method_name, signature))
+
+ # no signature, so just turn the return into a tuple and send it as normal
+ else:
+ if retval is None:
+ retval = ()
+ elif (isinstance(retval, tuple)
+ and not isinstance(retval, Struct)):
+ # If the return is a tuple that is not a Struct, we use it
+ # as-is on the assumption that there are multiple return
+ # values - this is the usual Python idiom. (fd.o #10174)
+ pass
+ else:
+ retval = (retval,)
+
+ _method_reply_return(connection, message, method_name, signature, *retval)
+ except Exception as exception:
+ # send error reply
+ _method_reply_error(connection, message, exception)
+
+ @method(INTROSPECTABLE_IFACE, in_signature='', out_signature='s',
+ path_keyword='object_path', connection_keyword='connection')
+ def Introspect(self, object_path, connection):
+ """Return a string of XML encoding this object's supported interfaces,
+ methods and signals.
+ """
+ reflection_data = _dbus_bindings.DBUS_INTROSPECT_1_0_XML_DOCTYPE_DECL_NODE
+ reflection_data += '<node name="%s">\n' % object_path
+
+ interfaces = self._dbus_class_table[self.__class__.__module__ + '.' + self.__class__.__name__]
+ for (name, funcs) in interfaces.items():
+ reflection_data += ' <interface name="%s">\n' % (name)
+
+ for func in funcs.values():
+ if getattr(func, '_dbus_is_method', False):
+ reflection_data += self.__class__._reflect_on_method(func)
+ elif getattr(func, '_dbus_is_signal', False):
+ reflection_data += self.__class__._reflect_on_signal(func)
+
+ reflection_data += ' </interface>\n'
+
+ for name in connection.list_exported_child_objects(object_path):
+ reflection_data += ' <node name="%s"/>\n' % name
+
+ reflection_data += '</node>\n'
+
+ return reflection_data
+
+ def __repr__(self):
+ where = ''
+ if (self._object_path is not _MANY
+ and self._object_path is not None):
+ where = ' at %s' % self._object_path
+ return '<%s.%s%s at %#x>' % (self.__class__.__module__,
+ self.__class__.__name__, where,
+ id(self))
+ __str__ = __repr__
+
+class FallbackObject(Object):
+ """An object that implements an entire subtree of the object-path
+ tree.
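+
+    For illustration only (``MyFallback`` is a hypothetical subclass and
+    the path is arbitrary)::
+
+        tree = MyFallback(dbus.SessionBus(), '/com/example/Tree')
+        # handles /com/example/Tree and every path below it, except where
+        # a more specific object has been exported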
+
+ :Since: 0.82.0
+ """
+
+ SUPPORTS_MULTIPLE_OBJECT_PATHS = True
+
+ def __init__(self, conn=None, object_path=None):
+ """Constructor.
+
+ Note that the superclass' ``bus_name`` __init__ argument is not
+ supported here.
+
+ :Parameters:
+ `conn` : dbus.connection.Connection or None
+ The connection on which to export this object. If this is not
+ None, an `object_path` must also be provided.
+
+ If None, the object is not initially available on any
+ Connection.
+
+ `object_path` : str or None
+ A D-Bus object path at which to make this Object available
+ immediately. If this is not None, a `conn` must also be
+ provided.
+
+                This object will implement all object-paths in the subtree
+ starting at this object-path, except where a more specific
+ object has been added.
+ """
+ super(FallbackObject, self).__init__()
+ self._fallback = True
+
+ if conn is None:
+ if object_path is not None:
+ raise TypeError('If object_path is given, conn is required')
+ elif object_path is None:
+ raise TypeError('If conn is given, object_path is required')
+ else:
+ self.add_to_connection(conn, object_path)
diff --git a/lib/dbus/types.py b/lib/dbus/types.py
new file mode 100644
index 0000000..461639e
--- /dev/null
+++ b/lib/dbus/types.py
@@ -0,0 +1,15 @@
+# Copyright 2006-2021 Collabora Ltd.
+# Copyright 2011 Barry Warsaw
+# SPDX-License-Identifier: MIT
+
+__all__ = ['ObjectPath', 'ByteArray', 'Signature', 'Byte', 'Boolean',
+ 'Int16', 'UInt16', 'Int32', 'UInt32', 'Int64', 'UInt64',
+ 'Double', 'String', 'Array', 'Struct', 'Dictionary',
+ 'UnixFd']
+
+from _dbus_bindings import (
+ Array, Boolean, Byte, ByteArray, Dictionary, Double, Int16, Int32, Int64,
+ ObjectPath, Signature, String, Struct, UInt16, UInt32, UInt64,
+ UnixFd)
+
+from dbus._compat import is_py2
diff --git a/lib/greenlet-1.1.3.dist-info/AUTHORS b/lib/greenlet-1.1.3.dist-info/AUTHORS
new file mode 100644
index 0000000..42a5c22
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/AUTHORS
@@ -0,0 +1,51 @@
+Original Authors
+----------------
+* Armin Rigo
+* Christian Tismer
+
+Contributors
+------------
+* Al Stone
+* Alexander Schmidt
+* Alexey Borzenkov
+* Andreas Schwab
+* Armin Ronacher
+* Bin Wang <feisuzhu@163.com>
+* Bob Ippolito
+* ChangBo Guo
+* Christoph Gohlke
+* Denis Bilenko
+* Dirk Mueller
+* Donovan Preston
+* Fantix King
+* Floris Bruynooghe
+* Fredrik Fornwall
+* Gerd Woetzel
+* Giel van Schijndel
+* Gökhan Karabulut
+* Gustavo Niemeyer
+* Guy Rozendorn
+* Hye-Shik Chang
+* Jared Kuolt
+* Jason Madden
+* Josh Snyder
+* Kyle Ambroff
+* Laszlo Boszormenyi
+* Mao Han
+* Marc Abramowitz
+* Marc Schlaich
+* Marcin Bachry
+* Matt Madison
+* Matt Turner
+* Michael Ellerman
+* Michael Matz
+* Ralf Schmitt
+* Robie Basak
+* Ronny Pfannschmidt
+* Samual M. Rushing
+* Tony Bowles
+* Tony Breeds
+* Trevor Bowen
+* Tulio Magno Quites Machado Filho
+* Ulrich Weigand
+* Victor Stinner
diff --git a/lib/greenlet-1.1.3.dist-info/INSTALLER b/lib/greenlet-1.1.3.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/greenlet-1.1.3.dist-info/LICENSE b/lib/greenlet-1.1.3.dist-info/LICENSE
new file mode 100644
index 0000000..b73a4a1
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/LICENSE
@@ -0,0 +1,30 @@
+The following files are derived from Stackless Python and are subject to the
+same license as Stackless Python:
+
+ src/greenlet/slp_platformselect.h
+ files in src/greenlet/platform/ directory
+
+See LICENSE.PSF and http://www.stackless.com/ for details.
+
+Unless otherwise noted, the files in greenlet have been released under the
+following MIT license:
+
+Copyright (c) Armin Rigo, Christian Tismer and contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/lib/greenlet-1.1.3.dist-info/LICENSE.PSF b/lib/greenlet-1.1.3.dist-info/LICENSE.PSF
new file mode 100644
index 0000000..d3b509a
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/LICENSE.PSF
@@ -0,0 +1,47 @@
+PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
+--------------------------------------------
+
+1. This LICENSE AGREEMENT is between the Python Software Foundation
+("PSF"), and the Individual or Organization ("Licensee") accessing and
+otherwise using this software ("Python") in source or binary form and
+its associated documentation.
+
+2. Subject to the terms and conditions of this License Agreement, PSF hereby
+grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce,
+analyze, test, perform and/or display publicly, prepare derivative works,
+distribute, and otherwise use Python alone or in any derivative version,
+provided, however, that PSF's License Agreement and PSF's notice of copyright,
+i.e., "Copyright (c) 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
+2011 Python Software Foundation; All Rights Reserved" are retained in Python
+alone or in any derivative version prepared by Licensee.
+
+3. In the event Licensee prepares a derivative work that is based on
+or incorporates Python or any part thereof, and wants to make
+the derivative work available to others as provided herein, then
+Licensee hereby agrees to include in any such work a brief summary of
+the changes made to Python.
+
+4. PSF is making Python available to Licensee on an "AS IS"
+basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
+IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
+DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
+INFRINGE ANY THIRD PARTY RIGHTS.
+
+5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
+OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+
+6. This License Agreement will automatically terminate upon a material
+breach of its terms and conditions.
+
+7. Nothing in this License Agreement shall be deemed to create any
+relationship of agency, partnership, or joint venture between PSF and
+Licensee. This License Agreement does not grant permission to use PSF
+trademarks or trade name in a trademark sense to endorse or promote
+products or services of Licensee, or any third party.
+
+8. By copying, installing or otherwise using Python, Licensee
+agrees to be bound by the terms and conditions of this License
+Agreement.
diff --git a/lib/greenlet-1.1.3.dist-info/METADATA b/lib/greenlet-1.1.3.dist-info/METADATA
new file mode 100644
index 0000000..23b5f72
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/METADATA
@@ -0,0 +1,103 @@
+Metadata-Version: 2.1
+Name: greenlet
+Version: 1.1.3
+Summary: Lightweight in-process concurrent programming
+Home-page: https://greenlet.readthedocs.io/
+Author: Alexey Borzenkov
+Author-email: snaury@gmail.com
+Maintainer: Jason Madden
+Maintainer-email: jason@nextthought.com
+License: MIT License
+Project-URL: Bug Tracker, https://github.com/python-greenlet/greenlet/issues
+Project-URL: Source Code, https://github.com/python-greenlet/greenlet/
+Project-URL: Documentation, https://greenlet.readthedocs.io/
+Keywords: greenlet coroutine concurrency threads cooperative
+Platform: any
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Natural Language :: English
+Classifier: Programming Language :: C
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.5
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Operating System :: OS Independent
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Requires-Python: >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*
+Description-Content-Type: text/x-rst
+License-File: LICENSE
+License-File: LICENSE.PSF
+License-File: AUTHORS
+Provides-Extra: docs
+Requires-Dist: Sphinx ; extra == 'docs'
+Provides-Extra: test
+
+.. This file is included into docs/history.rst
+
+.. image:: https://github.com/python-greenlet/greenlet/workflows/tests/badge.svg
+ :target: https://github.com/python-greenlet/greenlet/actions
+
+Greenlets are lightweight coroutines for in-process concurrent
+programming.
+
+The "greenlet" package is a spin-off of `Stackless`_, a version of
+CPython that supports micro-threads called "tasklets". Tasklets run
+pseudo-concurrently (typically in a single or a few OS-level threads)
+and are synchronized with data exchanges on "channels".
+
+A "greenlet", on the other hand, is a still more primitive notion of
+micro-thread with no implicit scheduling; coroutines, in other words.
+This is useful when you want to control exactly when your code runs.
+You can build custom scheduled micro-threads on top of greenlet;
+however, it seems that greenlets are useful on their own as a way to
+make advanced control flow structures. For example, we can recreate
+generators; the difference with Python's own generators is that our
+generators can call nested functions and the nested functions can
+yield values too. (Additionally, you don't need a "yield" keyword. See
+the example in `test_generator.py
+<https://github.com/python-greenlet/greenlet/blob/adca19bf1f287b3395896a8f41f3f4fd1797fdc7/src/greenlet/tests/test_generator.py#L1>`_).
+
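+For illustration only (not part of the original package description), a
+minimal switch between two greenlets looks like this::
+
+    from greenlet import greenlet, getcurrent
+
+    def child_task():
+        print('in child')
+        main.switch()       # suspend the child, resume the main greenlet
+
+    main = getcurrent()
+    child = greenlet(child_task)
+    child.switch()          # runs child_task() until it switches back
+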
+Greenlets are provided as a C extension module for the regular unmodified
+interpreter.
+
+.. _`Stackless`: http://www.stackless.com
+
+
+Who is using Greenlet?
+======================
+
+There are several libraries that use Greenlet as a more flexible
+alternative to Python's built-in coroutine support:
+
+ - `Concurrence`_
+ - `Eventlet`_
+ - `Gevent`_
+
+.. _Concurrence: http://opensource.hyves.org/concurrence/
+.. _Eventlet: http://eventlet.net/
+.. _Gevent: http://www.gevent.org/
+
+Getting Greenlet
+================
+
+The easiest way to get Greenlet is to install it with pip::
+
+ pip install greenlet
+
+
+Source code archives and binary distributions are available on the
+Python Package Index at https://pypi.org/project/greenlet
+
+The source code repository is hosted on github:
+https://github.com/python-greenlet/greenlet
+
+Documentation is available on readthedocs.org:
+https://greenlet.readthedocs.io
diff --git a/lib/greenlet-1.1.3.dist-info/RECORD b/lib/greenlet-1.1.3.dist-info/RECORD
new file mode 100644
index 0000000..b62ed99
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/RECORD
@@ -0,0 +1,71 @@
+../../include/python/greenlet/greenlet.h,sha256=muQGuDPNWzBVjWoObFXddpDP_DLeE2GtdnF41cyYgy0,4648
+greenlet-1.1.3.dist-info/AUTHORS,sha256=swW28t2knVRxRkaEQNZtO7MP9Sgnompb7B6cNgJM8Gk,849
+greenlet-1.1.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+greenlet-1.1.3.dist-info/LICENSE,sha256=dpgx1uXfrywggC-sz_H6-0wgJd2PYlPfpH_K1Z1NCXk,1434
+greenlet-1.1.3.dist-info/LICENSE.PSF,sha256=5f88I8EQ5JTNfXNsEP2W1GJFe6_soxCEDbZScpjH1Gs,2424
+greenlet-1.1.3.dist-info/METADATA,sha256=DQAWGnxur5YBtMAo1zxHoCE1xEBsLwFP-df-a0A7oWU,3930
+greenlet-1.1.3.dist-info/RECORD,,
+greenlet-1.1.3.dist-info/WHEEL,sha256=FNUt4eBsrBVn_Yc5KG3aXtKE40X3uNZrbGLNCbVxyFw,148
+greenlet-1.1.3.dist-info/top_level.txt,sha256=YSnRsCRoO61JGlP57o8iKL6rdLWDWuiyKD8ekpWUsDc,9
+greenlet/__init__.py,sha256=f2pBI8kauTC7tFFi8r-JUUPXuthYvspSRCNENiqAH8k,1323
+greenlet/__pycache__/__init__.cpython-39.pyc,,
+greenlet/_greenlet.cpython-39-x86_64-linux-gnu.so,sha256=hzBW5fYe7jLbF4ciFhCXGf7tPXNUcAQA3GJXKdAPJZM,130456
+greenlet/greenlet.c,sha256=tTKIwaPu9MhiGwhtlSkWWbSTXPbddgY-Xoq7xUfvfpA,67295
+greenlet/greenlet.h,sha256=muQGuDPNWzBVjWoObFXddpDP_DLeE2GtdnF41cyYgy0,4648
+greenlet/platform/setup_switch_x64_masm.cmd,sha256=ZpClUJeU0ujEPSTWNSepP0W2f9XiYQKA8QKSoVou8EU,143
+greenlet/platform/switch_aarch64_gcc.h,sha256=TRH22e9TNRA_mys8hhLbNwz3efZk7BtKZhyhK7ucgyM,2385
+greenlet/platform/switch_alpha_unix.h,sha256=T6kOBiHy3hLmy1vrmFrxbnOnRu0EJkoG_yuWy7fykZ4,689
+greenlet/platform/switch_amd64_unix.h,sha256=KWB4PB2wcAaWvWbMzcq8tYBe02vEGPBCRMnHnfeI7gE,2610
+greenlet/platform/switch_arm32_gcc.h,sha256=wflI2cGZBfLzM_GGgYx3OrFeoOq7OTsJP53dKLsrxS0,2488
+greenlet/platform/switch_arm32_ios.h,sha256=yQZXCa0AZbyAIS9tKceyTCrRYlihpFBKDbiPCn_3im0,1901
+greenlet/platform/switch_csky_gcc.h,sha256=GHlaVXrzQuSkrDqgL7-Ji9YwZnprpFhjPznNyp0NnvU,1340
+greenlet/platform/switch_m68k_gcc.h,sha256=VSa6NpZhvyyvF-Q58CTIWSpEDo4FKygOyTz00whctlw,928
+greenlet/platform/switch_mips_unix.h,sha256=9ptMGEBXafee15RxOm5NrxiC2bEnwM9AkxJ7ktVatU8,1444
+greenlet/platform/switch_ppc64_aix.h,sha256=ADpifLPlr6pTdT76bt6ozcqPjHrfPsJ93lQfc1VNaug,3878
+greenlet/platform/switch_ppc64_linux.h,sha256=jqPKpTg09FzmCn59Kt6OJi2-40aoazFVJcf1YETLlwA,3833
+greenlet/platform/switch_ppc_aix.h,sha256=nClVVlsRlFAI-I3fmivSJyJK7Xzx3_8l3Wf8QNJ9FMU,2959
+greenlet/platform/switch_ppc_linux.h,sha256=J4eKMA73WbPYSaq0yAedzHB6J6ZKE8tIIzkqYxlaA2c,2777
+greenlet/platform/switch_ppc_macosx.h,sha256=bnL2MqIUm9--NHizb5NYijvSrqutvuJx4auYCdqXllM,2642
+greenlet/platform/switch_ppc_unix.h,sha256=5UW9c71NGJh6xksEbAOButBFH168QRyZ5O53yXdXGxg,2670
+greenlet/platform/switch_riscv_unix.h,sha256=c3v3GRDMooslDKQLM75IqokWivtelbAj3-XZK31vWlE,758
+greenlet/platform/switch_s390_unix.h,sha256=9oJkYnyUovPvXOAsVLXoj-Unl_Rr_DidkXYMaRXLS0w,2781
+greenlet/platform/switch_sparc_sun_gcc.h,sha256=0vHXNNCdz-1ioQsw-OtK0ridnBVIzErYWiK7bBu6OgM,2815
+greenlet/platform/switch_x32_unix.h,sha256=ie7Nxo6Cf_x4UVOSA_a3bJYPlRKZ1BvLWsclyQle_SY,1527
+greenlet/platform/switch_x64_masm.asm,sha256=nu6n2sWyXuXfpPx40d9YmLfHXUc1sHgeTvX1kUzuvEM,1841
+greenlet/platform/switch_x64_masm.obj,sha256=GNtTNxYdo7idFUYsQv-mrXWgyT5EJ93-9q90lN6svtQ,1078
+greenlet/platform/switch_x64_msvc.h,sha256=LIeasyKo_vHzspdMzMHbosRhrBfKI4BkQOh4qcTHyJw,1805
+greenlet/platform/switch_x86_msvc.h,sha256=hi0dgp-k14IhMCxwtJtcI_ciPnMGd37uMnMaHaeQVWg,2481
+greenlet/platform/switch_x86_unix.h,sha256=WvY2sNMFIEfoFVNVakl-osygJui3pSnlVj5jBrdaU08,3068
+greenlet/slp_platformselect.h,sha256=-J5Px9Yk7Ths4hQTecC3iadxfte1CYaFoeqfg1lUl-A,3095
+greenlet/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+greenlet/tests/__pycache__/__init__.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_contextvars.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_cpp.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_extension_interface.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_gc.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_generator.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_generator_nested.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_greenlet.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_leaks.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_stack_saved.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_throw.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_tracing.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_version.cpython-39.pyc,,
+greenlet/tests/__pycache__/test_weakref.cpython-39.pyc,,
+greenlet/tests/_test_extension.c,sha256=Tceb6kMFPSvAPW2LJ_zUlj--Wz_DtLzIPmgZcqkqAEU,5402
+greenlet/tests/_test_extension.cpython-39-x86_64-linux-gnu.so,sha256=uxoAsgRTLnJ6R4EMtaqHtjMqsllAE8E5zuoXsSBg9I0,34632
+greenlet/tests/_test_extension_cpp.cpp,sha256=zKfz0FxBXicq-53rItZ_NP8M406OBtyQFdH5bv_pRmk,3212
+greenlet/tests/_test_extension_cpp.cpython-39-x86_64-linux-gnu.so,sha256=VZ-YrzJf0Isn4iPYb-ANhp014Zz4A7TlQqFUDY5LqSk,47368
+greenlet/tests/test_contextvars.py,sha256=d69XSuRrdU80xAPmzdObLjrjXnbTQChG0MgsvBF_nGM,9205
+greenlet/tests/test_cpp.py,sha256=SXMuqsHTYTxFPBrasdbx5Sgplc89wvYEuPZvwafD-3k,488
+greenlet/tests/test_extension_interface.py,sha256=1FhUkxL-NrxmQV_sxUdlt8tvIWpDcGi27JcdQ6VyvFc,2521
+greenlet/tests/test_gc.py,sha256=oATPCmEAagdf1dZBYfZ0aiDklovLo_pQt5HZNTygCzk,2892
+greenlet/tests/test_generator.py,sha256=_MLDA1kBtZQR-9a74AOZZQECQCIFljMa7vbucE0cOxw,1280
+greenlet/tests/test_generator_nested.py,sha256=pGYRpNn_WjdhY_5ZHHBuBw10wskG_7mjJjR8IqleY3M,3579
+greenlet/tests/test_greenlet.py,sha256=SVDi0e1RrJtJhiOFggmoWTZL1sFdxRpdALFRCie-n60,23427
+greenlet/tests/test_leaks.py,sha256=STvFoZsFsZ_E24kYFaIASGBx97TRgTIur6uJXnoevWc,6677
+greenlet/tests/test_stack_saved.py,sha256=SyIHZycTBfm1TxFsq1VLCAgVm02t5GSke8tT28qwi7c,450
+greenlet/tests/test_throw.py,sha256=OOWfgcEaymvGVJQ3d4xDGzC5IVH0rZAiazWuyZV9270,2755
+greenlet/tests/test_tracing.py,sha256=hZ6Cl5NMq9IaeH7NGqWYl8aQ0_5nFUSYuo6TeSXvrKw,7455
+greenlet/tests/test_version.py,sha256=lHDe3qcLvfsOHcFKFW8yrcl5wBvy6UIxaNkZZzNlpHE,1229
+greenlet/tests/test_weakref.py,sha256=gqAQunjVzbwF6qEUZijhv6UqhH4apWNIRHeoWLUo9tM,884
diff --git a/lib/greenlet-1.1.3.dist-info/WHEEL b/lib/greenlet-1.1.3.dist-info/WHEEL
new file mode 100644
index 0000000..271bfec
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/WHEEL
@@ -0,0 +1,6 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.37.1)
+Root-Is-Purelib: false
+Tag: cp39-cp39-manylinux_2_17_x86_64
+Tag: cp39-cp39-manylinux2014_x86_64
+
diff --git a/lib/greenlet-1.1.3.dist-info/top_level.txt b/lib/greenlet-1.1.3.dist-info/top_level.txt
new file mode 100644
index 0000000..46725be
--- /dev/null
+++ b/lib/greenlet-1.1.3.dist-info/top_level.txt
@@ -0,0 +1 @@
+greenlet
diff --git a/lib/greenlet/__init__.py b/lib/greenlet/__init__.py
new file mode 100644
index 0000000..22db798
--- /dev/null
+++ b/lib/greenlet/__init__.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+"""
+The root of the greenlet package.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+__all__ = [
+ '__version__',
+ '_C_API',
+
+ 'GreenletExit',
+ 'error',
+
+ 'getcurrent',
+ 'greenlet',
+
+ 'gettrace',
+ 'settrace',
+]
+
+# pylint:disable=no-name-in-module
+
+###
+# Metadata
+###
+__version__ = '1.1.3'
+from ._greenlet import _C_API # pylint:disable=no-name-in-module
+
+###
+# Exceptions
+###
+from ._greenlet import GreenletExit
+from ._greenlet import error
+
+###
+# greenlets
+###
+from ._greenlet import getcurrent
+from ._greenlet import greenlet
+
+###
+# tracing
+###
+try:
+ from ._greenlet import gettrace
+ from ._greenlet import settrace
+except ImportError:
+ # Tracing wasn't supported.
+ # XXX: The option to disable it was removed in 1.0,
+ # so this branch should be dead code.
+ pass
+
+###
+# Constants
+# These constants aren't documented and aren't recommended.
+# In 1.0, USE_GC and USE_TRACING are always true, and USE_CONTEXT_VARS
+# is the same as ``sys.version_info[:2] >= (3, 7)``
+###
+from ._greenlet import GREENLET_USE_CONTEXT_VARS # pylint:disable=unused-import
+from ._greenlet import GREENLET_USE_GC # pylint:disable=unused-import
+from ._greenlet import GREENLET_USE_TRACING # pylint:disable=unused-import
diff --git a/lib/greenlet/_greenlet.cpython-39-x86_64-linux-gnu.so b/lib/greenlet/_greenlet.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..4414bc0
--- /dev/null
+++ b/lib/greenlet/_greenlet.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/greenlet/greenlet.c b/lib/greenlet/greenlet.c
new file mode 100644
index 0000000..2f3ad6e
--- /dev/null
+++ b/lib/greenlet/greenlet.c
@@ -0,0 +1,2170 @@
+/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
+/* Format with:
+ * clang-format -i --style=file src/greenlet/greenlet.c
+ *
+ *
+ * Fix missing braces with:
+ * clang-tidy src/greenlet/greenlet.c -fix -checks="readability-braces-around-statements"
+*/
+#define GREENLET_MODULE
+
+#include "greenlet.h"
+
+#include "structmember.h"
+
+#ifdef __clang__
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wunused-parameter"
+# pragma clang diagnostic ignored "-Wmissing-field-initializers"
+#endif
+
+/***********************************************************
+
+A PyGreenlet is a range of C stack addresses that must be
+saved and restored in such a way that the full range of the
+stack contains valid data when we switch to it.
+
+Stack layout for a greenlet:
+
+ | ^^^ |
+ | older data |
+ | |
+ stack_stop . |_______________|
+ . | |
+ . | greenlet data |
+ . | in stack |
+ . * |_______________| . . _____________ stack_copy + stack_saved
+ . | | | |
+ . | data | |greenlet data|
+ . | unrelated | | saved |
+ . | to | | in heap |
+ stack_start . | this | . . |_____________| stack_copy
+ | greenlet |
+ | |
+ | newer data |
+ | vvv |
+
+
+Note that a greenlet's stack data is typically partly at its correct
+place in the stack, and partly saved away in the heap, but always in
+the above configuration: two blocks, the more recent one in the heap
+and the older one still in the stack (either block may be empty).
+
+Greenlets are chained: each points to the previous greenlet, which is
+the one that owns the data currently in the C stack above my
+stack_stop. The currently running greenlet is the first element of
+this chain. The main (initial) greenlet is the last one. Greenlets
+whose stack is entirely in the heap can be skipped from the chain.
+
+The chain is not related to execution order, but only to the order
+in which bits of C stack happen to belong to greenlets at a particular
+point in time.
+
+The main greenlet doesn't have a stack_stop: it is responsible for the
+complete rest of the C stack, and we don't know where it begins. We
+use (char*) -1, the largest possible address.
+
+States:
+ stack_stop == NULL && stack_start == NULL: did not start yet
+ stack_stop != NULL && stack_start == NULL: already finished
+ stack_stop != NULL && stack_start != NULL: active
+
+The running greenlet's stack_start is undefined but not NULL.
+
+ ***********************************************************/
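+
+/* A disabled, purely illustrative sketch of how the three states above can
+   be queried.  The field tests are assumed to match the PyGreenlet_STARTED
+   and PyGreenlet_ACTIVE macros from greenlet.h that are used throughout
+   this file. */
+#if 0
+static int sketch_greenlet_state(const PyGreenlet* g)
+{
+    if (g->stack_stop == NULL && g->stack_start == NULL) {
+        return 0; /* did not start yet */
+    }
+    if (g->stack_start == NULL) {
+        return 2; /* already finished */
+    }
+    return 1; /* active; for the running greenlet stack_start is non-NULL */
+}
+#endif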
+
+/*** global state ***/
+
+/* In the presence of multithreading, this is a bit tricky:
+
+ - ts_current always stores a reference to a greenlet, but it is
+   not really the current greenlet after a thread switch has occurred.
+
+ - each *running* greenlet uses its run_info field to know which
+ thread it is attached to. A greenlet can only run in the thread
+ where it was created. This run_info is a ref to tstate->dict.
+
+ - the thread state dict is used to save and restore ts_current,
+ using the dictionary key 'ts_curkey'.
+*/
+
+extern PyTypeObject PyGreenlet_Type;
+
+#if PY_VERSION_HEX >= 0x030700A3
+# define GREENLET_PY37 1
+#else
+# define GREENLET_PY37 0
+#endif
+
+#if PY_VERSION_HEX >= 0x30A00B1
+/*
+Python 3.10 beta 1 changed tstate->use_tracing to a nested cframe member.
+See https://github.com/python/cpython/pull/25276
+We have to save and restore this as well.
+*/
+#define TSTATE_USE_TRACING(tstate) (tstate->cframe->use_tracing)
+#define GREENLET_USE_CFRAME 1
+#else
+#define TSTATE_USE_TRACING(tstate) (tstate->use_tracing)
+#define GREENLET_USE_CFRAME 0
+#endif
+
+#ifndef Py_SET_REFCNT
+/* Py_REFCNT and Py_SIZE macros are converted to functions
+https://bugs.python.org/issue39573 */
+# define Py_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt)
+#endif
+
+#ifndef _Py_DEC_REFTOTAL
+/* _Py_DEC_REFTOTAL macro has been removed from Python 3.9 by:
+ https://github.com/python/cpython/commit/49932fec62c616ec88da52642339d83ae719e924
+*/
+# ifdef Py_REF_DEBUG
+# define _Py_DEC_REFTOTAL _Py_RefTotal--
+# else
+# define _Py_DEC_REFTOTAL
+# endif
+#endif
+
+/* Weak reference to the switching-to greenlet during the slp switch */
+static PyGreenlet* volatile ts_target = NULL;
+/* Strong reference to the switching from greenlet after the switch */
+static PyGreenlet* volatile ts_origin = NULL;
+/* Strong reference to the current greenlet in this thread state */
+static PyGreenlet* volatile ts_current = NULL;
+/* NULL if error, otherwise args tuple to pass around during slp switch */
+static PyObject* volatile ts_passaround_args = NULL;
+static PyObject* volatile ts_passaround_kwargs = NULL;
+
+/* Used internally in ``g_switchstack()`` */
+#if GREENLET_USE_CFRAME
+static int volatile ts__g_switchstack_use_tracing = 0;
+#endif
+
+/***********************************************************/
+/* Thread-aware routines, switching global variables when needed */
+
+#define STATE_OK \
+ (ts_current->run_info == PyThreadState_GET()->dict || \
+ !green_updatecurrent())
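+/* STATE_OK (above) is true when ts_current already belongs to the running
+   thread, or when green_updatecurrent() has just (successfully, returning 0)
+   made it so; it is false only if that update failed with an exception set. */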
+
+static PyObject* ts_curkey;
+static PyObject* ts_delkey;
+static PyObject* ts_tracekey;
+static PyObject* ts_event_switch;
+static PyObject* ts_event_throw;
+static PyObject* PyExc_GreenletError;
+static PyObject* PyExc_GreenletExit;
+static PyObject* ts_empty_tuple;
+static PyObject* ts_empty_dict;
+
+#define GREENLET_GC_FLAGS Py_TPFLAGS_HAVE_GC
+#define GREENLET_tp_alloc PyType_GenericAlloc
+#define GREENLET_tp_free PyObject_GC_Del
+#define GREENLET_tp_traverse green_traverse
+#define GREENLET_tp_clear green_clear
+#define GREENLET_tp_is_gc green_is_gc
+
+static void
+green_clear_exc(PyGreenlet* g)
+{
+#if GREENLET_PY37
+ g->exc_info = NULL;
+ g->exc_state.exc_value = NULL;
+#if !GREENLET_PY311
+ g->exc_state.exc_type = NULL;
+ g->exc_state.exc_traceback = NULL;
+#endif
+ g->exc_state.previous_item = NULL;
+#else
+ g->exc_type = NULL;
+ g->exc_value = NULL;
+ g->exc_traceback = NULL;
+#endif
+}
+
+static PyGreenlet*
+green_create_main(void)
+{
+ PyGreenlet* gmain;
+ PyObject* dict = PyThreadState_GetDict();
+ if (dict == NULL) {
+ if (!PyErr_Occurred()) {
+ PyErr_NoMemory();
+ }
+ return NULL;
+ }
+
+ /* create the main greenlet for this thread */
+ gmain = (PyGreenlet*)PyType_GenericAlloc(&PyGreenlet_Type, 0);
+ if (gmain == NULL) {
+ return NULL;
+ }
+ gmain->stack_start = (char*)1;
+ gmain->stack_stop = (char*)-1;
+ /* GetDict() returns a borrowed reference. Make it strong. */
+ gmain->run_info = dict;
+ Py_INCREF(dict);
+ return gmain;
+}
+
+static int
+green_updatecurrent(void)
+{
+ PyObject *exc, *val, *tb;
+ PyThreadState* tstate;
+ PyGreenlet* current;
+ PyGreenlet* previous;
+ PyObject* deleteme;
+
+green_updatecurrent_restart:
+ /* save current exception */
+ PyErr_Fetch(&exc, &val, &tb);
+
+ /* get ts_current from the active tstate */
+ tstate = PyThreadState_GET();
+ if (tstate->dict &&
+ (current = (PyGreenlet*)PyDict_GetItem(tstate->dict, ts_curkey))) {
+ /* found -- remove it, to avoid keeping a ref */
+ Py_INCREF(current);
+ PyDict_DelItem(tstate->dict, ts_curkey);
+ }
+ else {
+ /* first time we see this tstate */
+ current = green_create_main();
+ if (current == NULL) {
+ Py_XDECREF(exc);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ return -1;
+ }
+ }
+ assert(current->run_info == tstate->dict);
+
+green_updatecurrent_retry:
+ /* update ts_current as soon as possible, in case of nested switches */
+ Py_INCREF(current);
+ previous = ts_current;
+ ts_current = current;
+
+ /* save ts_current as the current greenlet of its own thread */
+ if (PyDict_SetItem(previous->run_info, ts_curkey, (PyObject*)previous)) {
+ Py_DECREF(previous);
+ Py_DECREF(current);
+ Py_XDECREF(exc);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ return -1;
+ }
+ Py_DECREF(previous);
+
+ /* green_dealloc() cannot delete greenlets from other threads, so
+ it stores them in the thread dict; delete them now. */
+ deleteme = PyDict_GetItem(tstate->dict, ts_delkey);
+ if (deleteme != NULL) {
+ /* The only reference to these greenlets should be in this list, so
+ clearing the list should let them be deleted again, triggering
+ calls to green_dealloc() in the correct thread. This may run
+ arbitrary Python code?
+ */
+ PyList_SetSlice(deleteme, 0, INT_MAX, NULL);
+ }
+
+ if (ts_current != current) {
+ /* some Python code executed above and there was a thread switch,
+ * so ts_current points to some other thread again. We need to
+ * delete ts_curkey (it's likely there) and retry. */
+ PyDict_DelItem(tstate->dict, ts_curkey);
+ goto green_updatecurrent_retry;
+ }
+
+ /* release an extra reference */
+ Py_DECREF(current);
+ /* restore current exception */
+ PyErr_Restore(exc, val, tb);
+
+ /* thread switch could happen during PyErr_Restore, in that
+ case there's nothing to do except restart from scratch. */
+ if (ts_current->run_info != tstate->dict) {
+ goto green_updatecurrent_restart;
+ }
+ return 0;
+}
+
+static PyObject*
+green_statedict(PyGreenlet* g)
+{
+ while (!PyGreenlet_STARTED(g)) {
+ g = g->parent;
+ if (g == NULL) {
+ /* garbage collected greenlet in chain */
+ return NULL;
+ }
+ }
+ return g->run_info;
+}
+
+/***********************************************************/
+
+/* Some functions must not be inlined:
+ * slp_restore_state, when inlined into slp_switch might cause
+ it to restore stack over its own local variables
+ * slp_save_state, when inlined would add its own local
+ variables to the saved stack, wasting space
+ * slp_switch, cannot be inlined for obvious reasons
+ * g_initialstub, when inlined would receive a pointer into its
+ own stack frame, leading to incomplete stack save/restore
+*/
+
+#if defined(__GNUC__) && \
+ (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4))
+# define GREENLET_NOINLINE_SUPPORTED
+# define GREENLET_NOINLINE(name) __attribute__((noinline)) name
+#elif defined(_MSC_VER) && (_MSC_VER >= 1300)
+# define GREENLET_NOINLINE_SUPPORTED
+# define GREENLET_NOINLINE(name) __declspec(noinline) name
+#endif
+
+#ifdef GREENLET_NOINLINE_SUPPORTED
+/* add forward declarations */
+static void GREENLET_NOINLINE(slp_restore_state)(void);
+static int GREENLET_NOINLINE(slp_save_state)(char*);
+# if !(defined(MS_WIN64) && defined(_M_X64))
+static int GREENLET_NOINLINE(slp_switch)(void);
+# endif
+static int GREENLET_NOINLINE(g_initialstub)(void*);
+# define GREENLET_NOINLINE_INIT() \
+ do { \
+ } while (0)
+#else
+/* force compiler to call functions via pointers */
+static void (*slp_restore_state)(void);
+static int (*slp_save_state)(char*);
+static int (*slp_switch)(void);
+static int (*g_initialstub)(void*);
+# define GREENLET_NOINLINE(name) cannot_inline_##name
+# define GREENLET_NOINLINE_INIT() \
+ do { \
+ slp_restore_state = GREENLET_NOINLINE(slp_restore_state); \
+ slp_save_state = GREENLET_NOINLINE(slp_save_state); \
+ slp_switch = GREENLET_NOINLINE(slp_switch); \
+ g_initialstub = GREENLET_NOINLINE(g_initialstub); \
+ } while (0)
+#endif
+
+/*
+ * the following macros are spliced into the OS/compiler
+ * specific code, in order to simplify maintenance.
+ */
+
+#define SLP_SAVE_STATE(stackref, stsizediff) \
+ stackref += STACK_MAGIC; \
+ if (slp_save_state((char*)stackref)) \
+ return -1; \
+ if (!PyGreenlet_ACTIVE(ts_target)) \
+ return 1; \
+ stsizediff = ts_target->stack_start - (char*)stackref
+
+#define SLP_RESTORE_STATE() slp_restore_state()
+
+#define SLP_EVAL
+#define slp_switch GREENLET_NOINLINE(slp_switch)
+#include "slp_platformselect.h"
+#undef slp_switch
+
+#ifndef STACK_MAGIC
+# error \
+ "greenlet needs to be ported to this platform, or taught how to detect your compiler properly."
+#endif /* !STACK_MAGIC */
+
+#ifdef EXTERNAL_ASM
+/* CCP addition: Make these functions, to be called from assembler.
+ * The token include file for the given platform should enable the
+ * EXTERNAL_ASM define so that this is included.
+ */
+
+intptr_t
+slp_save_state_asm(intptr_t* ref)
+{
+ intptr_t diff;
+ SLP_SAVE_STATE(ref, diff);
+ return diff;
+}
+
+void
+slp_restore_state_asm(void)
+{
+ SLP_RESTORE_STATE();
+}
+
+extern int
+slp_switch(void);
+
+#endif
+
+/***********************************************************/
+
+static int
+g_save(PyGreenlet* g, char* stop)
+{
+ /* Save more of g's stack into the heap -- at least up to 'stop'
+
+ g->stack_stop |________|
+ | |
+ | __ stop . . . . .
+ | | ==> . .
+ |________| _______
+ | | | |
+ | | | |
+ g->stack_start | | |_______| g->stack_copy
+
+ */
+ intptr_t sz1 = g->stack_saved;
+ intptr_t sz2 = stop - g->stack_start;
+ assert(g->stack_start != NULL);
+ if (sz2 > sz1) {
+ char* c = (char*)PyMem_Realloc(g->stack_copy, sz2);
+ if (!c) {
+ PyErr_NoMemory();
+ return -1;
+ }
+ memcpy(c + sz1, g->stack_start + sz1, sz2 - sz1);
+ g->stack_copy = c;
+ g->stack_saved = sz2;
+ }
+ return 0;
+}
+
+static void GREENLET_NOINLINE(slp_restore_state)(void)
+{
+ PyGreenlet* g = ts_target;
+ PyGreenlet* owner = ts_current;
+
+#ifdef SLP_BEFORE_RESTORE_STATE
+ SLP_BEFORE_RESTORE_STATE();
+#endif
+
+ /* Restore the heap copy back into the C stack */
+ if (g->stack_saved != 0) {
+ memcpy(g->stack_start, g->stack_copy, g->stack_saved);
+ PyMem_Free(g->stack_copy);
+ g->stack_copy = NULL;
+ g->stack_saved = 0;
+ }
+ if (owner->stack_start == NULL) {
+ owner = owner->stack_prev; /* greenlet is dying, skip it */
+ }
+ while (owner && owner->stack_stop <= g->stack_stop) {
+ owner = owner->stack_prev; /* find greenlet with more stack */
+ }
+ g->stack_prev = owner;
+}
+
+static int GREENLET_NOINLINE(slp_save_state)(char* stackref)
+{
+ /* must free all the C stack up to target_stop */
+ char* target_stop = ts_target->stack_stop;
+ PyGreenlet* owner = ts_current;
+ assert(owner->stack_saved == 0);
+ if (owner->stack_start == NULL) {
+ owner = owner->stack_prev; /* not saved if dying */
+ }
+ else {
+ owner->stack_start = stackref;
+ }
+
+#ifdef SLP_BEFORE_SAVE_STATE
+ SLP_BEFORE_SAVE_STATE();
+#endif
+
+ while (owner->stack_stop < target_stop) {
+        /* ts_current is entirely within the area to free */
+ if (g_save(owner, owner->stack_stop)) {
+ return -1; /* XXX */
+ }
+ owner = owner->stack_prev;
+ }
+ if (owner != ts_target) {
+ if (g_save(owner, target_stop)) {
+ return -1; /* XXX */
+ }
+ }
+ return 0;
+}
+
+/**
+ Perform a stack switch according to some global variables
+ that must be set before calling this function. Those variables
+ are:
+
+ - ts_current: current greenlet (holds a reference)
+ - ts_target: greenlet to switch to (weak reference)
+ - ts_passaround_args: NULL if PyErr_Occurred(),
+ else a tuple of args sent to ts_target (holds a reference)
+ - ts_passaround_kwargs: switch kwargs (holds a reference)
+
+ Because the stack switch happens in this function, this function can't use
+ its own stack (local) variables, set before the switch, and then accessed after the
+ switch. Global variables beginning with ``ts__g_switchstack`` are used
+ internally instead.
+
+ On return results are passed via global variables as well:
+
+ - ts_origin: originating greenlet (holds a reference)
+ - ts_current: current greenlet (holds a reference)
+ - ts_passaround_args: NULL if PyErr_Occurred(),
+ else a tuple of args sent to ts_current (holds a reference)
+ - ts_passaround_kwargs: switch kwargs (holds a reference)
+
+ It is very important that stack switch is 'atomic', i.e. no
+ calls into other Python code allowed (except very few that
+ are safe), because global variables are very fragile.
+*/
+static int
+g_switchstack(void)
+{
+ int err;
+ { /* save state */
+ PyGreenlet* current = ts_current;
+ PyThreadState* tstate = PyThreadState_GET();
+#if GREENLET_PY311
+ current->recursion_depth = (tstate->recursion_limit
+ - tstate->recursion_remaining);
+#else
+ current->recursion_depth = tstate->recursion_depth;
+ current->top_frame = tstate->frame;
+#endif
+#if GREENLET_PY37
+ current->context = tstate->context;
+#endif
+#if GREENLET_PY37
+ current->exc_info = tstate->exc_info;
+ current->exc_state = tstate->exc_state;
+#else
+ current->exc_type = tstate->exc_type;
+ current->exc_value = tstate->exc_value;
+ current->exc_traceback = tstate->exc_traceback;
+#endif
+#if GREENLET_USE_CFRAME
+ /*
+ IMPORTANT: ``cframe`` is a pointer into the STACK.
+ Thus, because the call to ``slp_switch()``
+ changes the contents of the stack, you cannot read from
+ ``ts_current->cframe`` after that call and necessarily
+ get the same values you get from reading it here. Anything
+ you need to restore from now to then must be saved
+ in a global variable (because we can't use stack variables
+ here either).
+ */
+ current->cframe = tstate->cframe;
+ ts__g_switchstack_use_tracing = tstate->cframe->use_tracing;
+#if GREENLET_PY311
+ current->current_frame = tstate->cframe->current_frame;
+ current->datastack_chunk = tstate->datastack_chunk;
+ current->datastack_top = tstate->datastack_top;
+ current->datastack_limit = tstate->datastack_limit;
+ PyFrameObject *frame = PyThreadState_GetFrame(tstate);
+ Py_XDECREF(frame); /* PyThreadState_GetFrame gives us a new reference. */
+ current->top_frame = frame;
+#endif
+#endif
+ }
+
+ err = slp_switch();
+
+ if (err < 0) { /* error */
+ PyGreenlet* current = ts_current;
+ current->top_frame = NULL;
+#if GREENLET_PY37
+ green_clear_exc(current);
+#else
+ current->exc_type = NULL;
+ current->exc_value = NULL;
+ current->exc_traceback = NULL;
+#endif
+
+ assert(ts_origin == NULL);
+ ts_target = NULL;
+ }
+ else {
+ PyGreenlet* target = ts_target;
+ PyGreenlet* origin = ts_current;
+ PyThreadState* tstate = PyThreadState_GET();
+
+#if GREENLET_PY37
+ tstate->context = target->context;
+ target->context = NULL;
+ /* Incrementing this value invalidates the contextvars cache,
+ which would otherwise remain valid across switches */
+ tstate->context_ver++;
+#endif
+
+#if GREENLET_PY37
+ tstate->exc_state = target->exc_state;
+ tstate->exc_info =
+ target->exc_info ? target->exc_info : &tstate->exc_state;
+#else
+ tstate->exc_type = target->exc_type;
+ tstate->exc_value = target->exc_value;
+ tstate->exc_traceback = target->exc_traceback;
+#endif
+ green_clear_exc(target);
+
+#if GREENLET_USE_CFRAME
+ tstate->cframe = target->cframe;
+ /*
+ If we were tracing, we need to keep tracing.
+ There should never be the possibility of hitting the
+ root_cframe here. See note above about why we can't
+ just copy this from ``origin->cframe->use_tracing``.
+ */
+ tstate->cframe->use_tracing = ts__g_switchstack_use_tracing;
+#endif
+#if GREENLET_PY311
+ tstate->recursion_remaining = (tstate->recursion_limit
+ - target->recursion_depth);
+ tstate->cframe->current_frame = target->current_frame;
+ tstate->datastack_chunk = target->datastack_chunk;
+ tstate->datastack_top = target->datastack_top;
+ tstate->datastack_limit = target->datastack_limit;
+#else
+ tstate->recursion_depth = target->recursion_depth;
+ tstate->frame = target->top_frame;
+#endif
+ target->top_frame = NULL;
+ assert(ts_origin == NULL);
+ Py_INCREF(target);
+ ts_current = target;
+ ts_origin = origin;
+ ts_target = NULL;
+ }
+ return err;
+}
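+
+/* A disabled, simplified sketch of the protocol described above for an
+   already-active target, mirroring what g_switch() further below does:
+   set the globals, perform the switch, then immediately collect the
+   results before arbitrary Python code can run.  Illustrative only;
+   error paths are not handled as carefully as in g_switch(). */
+#if 0
+static PyObject* sketch_switch_to_active(PyGreenlet* target, PyObject* args)
+{
+    PyObject* result;
+    ts_target = target;          /* weak reference */
+    ts_passaround_args = args;   /* consumed by the switch */
+    ts_passaround_kwargs = NULL;
+    if (g_switchstack() < 0) {
+        return NULL;             /* error: ts_origin was never set */
+    }
+    result = ts_passaround_args; /* harvest results right away */
+    ts_passaround_args = NULL;
+    Py_DECREF(ts_origin);        /* we own a reference to the origin now */
+    ts_origin = NULL;
+    return result;
+}
+#endif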
+
+static int
+g_calltrace(PyObject* tracefunc, PyObject* event, PyGreenlet* origin,
+ PyGreenlet* target)
+{
+ PyObject* retval;
+ PyObject *exc_type, *exc_val, *exc_tb;
+ PyThreadState* tstate;
+ PyErr_Fetch(&exc_type, &exc_val, &exc_tb);
+ tstate = PyThreadState_GET();
+ tstate->tracing++;
+ TSTATE_USE_TRACING(tstate) = 0;
+ retval = PyObject_CallFunction(tracefunc, "O(OO)", event, origin, target);
+ tstate->tracing--;
+ TSTATE_USE_TRACING(tstate) =
+ (tstate->tracing <= 0 &&
+ ((tstate->c_tracefunc != NULL) || (tstate->c_profilefunc != NULL)));
+ if (retval == NULL) {
+        /* In case of exceptions, the trace function is removed */
+ if (PyDict_GetItem(tstate->dict, ts_tracekey)) {
+ PyDict_DelItem(tstate->dict, ts_tracekey);
+ }
+ Py_XDECREF(exc_type);
+ Py_XDECREF(exc_val);
+ Py_XDECREF(exc_tb);
+ return -1;
+ }
+ else {
+ Py_DECREF(retval);
+ }
+ PyErr_Restore(exc_type, exc_val, exc_tb);
+ return 0;
+}
+
+static PyObject*
+g_switch(PyGreenlet* target, PyObject* args, PyObject* kwargs)
+{
+ /* _consumes_ a reference to the args tuple and kwargs dict,
+ and return a new tuple reference */
+ int err = 0;
+ PyObject* run_info;
+
+ /* check ts_current */
+ if (!STATE_OK) {
+ Py_XDECREF(args);
+ Py_XDECREF(kwargs);
+ return NULL;
+ }
+ run_info = green_statedict(target);
+ if (run_info == NULL || run_info != ts_current->run_info) {
+ Py_XDECREF(args);
+ Py_XDECREF(kwargs);
+ PyErr_SetString(PyExc_GreenletError,
+ run_info ?
+ "cannot switch to a different thread" :
+ "cannot switch to a garbage collected greenlet");
+ return NULL;
+ }
+
+ ts_passaround_args = args;
+ ts_passaround_kwargs = kwargs;
+
+ /* find the real target by ignoring dead greenlets,
+ and if necessary starting a greenlet. */
+ while (target) {
+ if (PyGreenlet_ACTIVE(target)) {
+ ts_target = target;
+ err = g_switchstack();
+ break;
+ }
+ if (!PyGreenlet_STARTED(target)) {
+ void* dummymarker;
+ ts_target = target;
+ err = g_initialstub(&dummymarker);
+ if (err == 1) {
+ continue; /* retry the switch */
+ }
+ break;
+ }
+ target = target->parent;
+ }
+
+ /* For a very short time, immediately after the 'atomic'
+ g_switchstack() call, global variables are in a known state.
+ We need to save everything we need, before it is destroyed
+ by calls into arbitrary Python code. */
+ args = ts_passaround_args;
+ ts_passaround_args = NULL;
+ kwargs = ts_passaround_kwargs;
+ ts_passaround_kwargs = NULL;
+ if (err < 0) {
+ /* Turn switch errors into switch throws */
+ assert(ts_origin == NULL);
+ Py_CLEAR(kwargs);
+ Py_CLEAR(args);
+ }
+ else {
+ PyGreenlet* origin;
+ PyGreenlet* current;
+ PyObject* tracefunc;
+ origin = ts_origin;
+ ts_origin = NULL;
+ current = ts_current;
+ if ((tracefunc = PyDict_GetItem(current->run_info, ts_tracekey)) != NULL) {
+ Py_INCREF(tracefunc);
+ if (g_calltrace(tracefunc,
+ args ? ts_event_switch : ts_event_throw,
+ origin,
+ current) < 0) {
+ /* Turn trace errors into switch throws */
+ Py_CLEAR(kwargs);
+ Py_CLEAR(args);
+ }
+ Py_DECREF(tracefunc);
+ }
+
+ Py_DECREF(origin);
+ }
+
+ /* We need to figure out what values to pass to the target greenlet
+ based on the arguments that have been passed to greenlet.switch(). If
+ switch() was just passed an arg tuple, then we'll just return that.
+ If only keyword arguments were passed, then we'll pass the keyword
+ argument dict. Otherwise, we'll create a tuple of (args, kwargs) and
+ return both. */
+ if (kwargs == NULL) {
+ return args;
+ }
+ else if (PyDict_Size(kwargs) == 0) {
+ Py_DECREF(kwargs);
+ return args;
+ }
+ else if (PySequence_Length(args) == 0) {
+ Py_DECREF(args);
+ return kwargs;
+ }
+ else {
+ PyObject* tuple = PyTuple_New(2);
+ if (tuple == NULL) {
+ Py_DECREF(args);
+ Py_DECREF(kwargs);
+ return NULL;
+ }
+ PyTuple_SET_ITEM(tuple, 0, args);
+ PyTuple_SET_ITEM(tuple, 1, kwargs);
+ return tuple;
+ }
+}
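+
+/* Concretely, at the Python level (after single_result() below unwraps
+   1-tuples) the packing rules above mean: switch(1) delivers 1 to the
+   resumed greenlet, switch(1, 2) delivers (1, 2), switch(x=1) delivers
+   {'x': 1}, and switch(1, x=2) delivers ((1,), {'x': 2}). */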
+
+static PyObject*
+g_handle_exit(PyObject* result)
+{
+ if (result == NULL && PyErr_ExceptionMatches(PyExc_GreenletExit)) {
+ /* catch and ignore GreenletExit */
+ PyObject *exc, *val, *tb;
+ PyErr_Fetch(&exc, &val, &tb);
+ if (val == NULL) {
+ Py_INCREF(Py_None);
+ val = Py_None;
+ }
+ result = val;
+ Py_DECREF(exc);
+ Py_XDECREF(tb);
+ }
+ if (result != NULL) {
+ /* package the result into a 1-tuple */
+ PyObject* r = result;
+ result = PyTuple_New(1);
+ if (result) {
+ PyTuple_SET_ITEM(result, 0, r);
+ }
+ else {
+ Py_DECREF(r);
+ }
+ }
+ return result;
+}
+
+static int GREENLET_NOINLINE(g_initialstub)(void* mark)
+{
+ int err;
+ PyObject *o, *run;
+ PyObject *exc, *val, *tb;
+ PyObject* run_info;
+ PyGreenlet* self = ts_target;
+ PyObject* args = ts_passaround_args;
+ PyObject* kwargs = ts_passaround_kwargs;
+#if GREENLET_USE_CFRAME
+ /*
+ See green_new(). This is a stack-allocated variable used
+ while *self* is in PyObject_Call().
+ We want to defer copying the state info until we're sure
+ we need it and are in a stable place to do so.
+ */
+ _PyCFrame trace_info;
+#endif
+ /* save exception in case getattr clears it */
+ PyErr_Fetch(&exc, &val, &tb);
+ /* self.run is the object to call in the new greenlet */
+ run = PyObject_GetAttrString((PyObject*)self, "run");
+ if (run == NULL) {
+ Py_XDECREF(exc);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ return -1;
+ }
+ /* restore saved exception */
+ PyErr_Restore(exc, val, tb);
+
+ /* recheck the state in case getattr caused thread switches */
+ if (!STATE_OK) {
+ Py_DECREF(run);
+ return -1;
+ }
+
+ /* recheck run_info in case greenlet reparented anywhere above */
+ run_info = green_statedict(self);
+ if (run_info == NULL || run_info != ts_current->run_info) {
+ Py_DECREF(run);
+ PyErr_SetString(PyExc_GreenletError,
+ run_info ?
+ "cannot switch to a different thread" :
+ "cannot switch to a garbage collected greenlet");
+ return -1;
+ }
+
+ /* by the time we got here another start could happen elsewhere,
+ * that means it should now be a regular switch
+ */
+ if (PyGreenlet_STARTED(self)) {
+ Py_DECREF(run);
+ ts_passaround_args = args;
+ ts_passaround_kwargs = kwargs;
+ return 1;
+ }
+
+#if GREENLET_USE_CFRAME
+ /* OK, we need it, we're about to switch greenlets, save the state. */
+ trace_info = *PyThreadState_GET()->cframe;
+ /* Make the target greenlet refer to the stack value. */
+ self->cframe = &trace_info;
+ /*
+ And restore the link to the previous frame so this one gets
+      unlinked appropriately.
+ */
+ self->cframe->previous = &PyThreadState_GET()->root_cframe;
+#endif
+ /* start the greenlet */
+ self->stack_start = NULL;
+ self->stack_stop = (char*)mark;
+ if (ts_current->stack_start == NULL) {
+ /* ts_current is dying */
+ self->stack_prev = ts_current->stack_prev;
+ }
+ else {
+ self->stack_prev = ts_current;
+ }
+ self->top_frame = NULL;
+ green_clear_exc(self);
+#if GREENLET_PY311
+ self->recursion_depth = (PyThreadState_GET()->recursion_limit
+ - PyThreadState_GET()->recursion_remaining);
+#else
+ self->recursion_depth = PyThreadState_GET()->recursion_depth;
+#endif
+
+ /* restore arguments in case they are clobbered */
+ ts_target = self;
+ ts_passaround_args = args;
+ ts_passaround_kwargs = kwargs;
+
+ /* perform the initial switch */
+ err = g_switchstack();
+
+ /* returns twice!
+ The 1st time with ``err == 1``: we are in the new greenlet
+ The 2nd time with ``err <= 0``: back in the caller's greenlet
+ */
+ if (err == 1) {
+ /* in the new greenlet */
+ PyGreenlet* origin;
+ PyObject* tracefunc;
+ PyObject* result;
+ PyGreenlet* parent;
+ self->stack_start = (char*)1; /* running */
+
+ /* grab origin while we still can */
+ origin = ts_origin;
+ ts_origin = NULL;
+
+ /* now use run_info to store the statedict */
+ o = self->run_info;
+ self->run_info = green_statedict(self->parent);
+ Py_INCREF(self->run_info);
+ Py_XDECREF(o);
+
+ if ((tracefunc = PyDict_GetItem(self->run_info, ts_tracekey)) != NULL) {
+ Py_INCREF(tracefunc);
+ if (g_calltrace(tracefunc,
+ args ? ts_event_switch : ts_event_throw,
+ origin,
+ self) < 0) {
+ /* Turn trace errors into switch throws */
+ Py_CLEAR(kwargs);
+ Py_CLEAR(args);
+ }
+ Py_DECREF(tracefunc);
+ }
+
+ Py_DECREF(origin);
+
+ if (args == NULL) {
+ /* pending exception */
+ result = NULL;
+ }
+ else {
+ /* call g.run(*args, **kwargs) */
+ result = PyObject_Call(run, args, kwargs);
+ Py_DECREF(args);
+ Py_XDECREF(kwargs);
+ }
+ Py_DECREF(run);
+ result = g_handle_exit(result);
+
+ /* jump back to parent */
+ self->stack_start = NULL; /* dead */
+ for (parent = self->parent; parent != NULL; parent = parent->parent) {
+ result = g_switch(parent, result, NULL);
+ /* Return here means switch to parent failed,
+ * in which case we throw *current* exception
+ * to the next parent in chain.
+ */
+ assert(result == NULL);
+ }
+ /* We ran out of parents, cannot continue */
+ PyErr_WriteUnraisable((PyObject*)self);
+ Py_FatalError("greenlets cannot continue");
+ }
+ /* back in the parent */
+ if (err < 0) {
+ /* start failed badly, restore greenlet state */
+ self->stack_start = NULL;
+ self->stack_stop = NULL;
+ self->stack_prev = NULL;
+ }
+ return err;
+}
+
+/***********************************************************/
+
+static PyObject*
+green_new(PyTypeObject* type, PyObject* args, PyObject* kwds)
+{
+ PyObject* o =
+ PyBaseObject_Type.tp_new(type, ts_empty_tuple, ts_empty_dict);
+ if (o != NULL) {
+ if (!STATE_OK) {
+ Py_DECREF(o);
+ return NULL;
+ }
+ Py_INCREF(ts_current);
+ ((PyGreenlet*)o)->parent = ts_current;
+#if GREENLET_USE_CFRAME
+ /*
+ The PyThreadState->cframe pointer usually points to memory on the
+        stack, allocated in a call into PyEval_EvalFrameDefault.
+
+ Initially, before any evaluation begins, it points to the initial
+ PyThreadState object's ``root_cframe`` object, which is statically
+ allocated for the lifetime of the thread.
+
+ A greenlet can last for longer than a call to
+ PyEval_EvalFrameDefault, so we can't set its ``cframe`` pointer to
+ be the current ``PyThreadState->cframe``; nor could we use one from
+ the greenlet parent for the same reason. Yet a further no: we can't
+ allocate one scoped to the greenlet and then destroy it when the
+ greenlet is deallocated, because inside the interpreter the CFrame
+ objects form a linked list, and that too can result in accessing
+ memory beyond its dynamic lifetime (if the greenlet doesn't actually
+ finish before it dies, its entry could still be in the list).
+
+ Using the ``root_cframe`` is problematic, though, because its
+ members are never modified by the interpreter and are set to 0,
+ meaning that its ``use_tracing`` flag is never updated. We don't
+        want to modify that value in the ``root_cframe`` ourselves: it
+ *shouldn't* matter much because we should probably never get back to
+ the point where that's the only cframe on the stack; even if it did
+ matter, the major consequence of an incorrect value for
+        ``use_tracing`` is that if it's true the interpreter does some extra
+ work --- however, it's just good code hygiene.
+
+ Our solution: before a greenlet runs, after its initial creation,
+ it uses the ``root_cframe`` just to have something to put there.
+ However, once the greenlet is actually switched to for the first
+ time, ``g_initialstub`` (which doesn't actually "return" while the
+ greenlet is running) stores a new _PyCFrame on its local stack, and
+ copies the appropriate values from the currently running CFrame;
+ this is then made the _PyCFrame for the newly-minted greenlet.
+ ``g_initialstub`` then proceeds to call ``glet.run()``, which
+ results in ``PyEval_...`` adding the _PyCFrame to the list. Switches
+ continue as normal. Finally, when the greenlet finishes, the call to
+ ``glet.run()`` returns and the _PyCFrame is taken out of the linked
+ list and the stack value is now unused and free to expire.
+ */
+ ((PyGreenlet*)o)->cframe = &PyThreadState_GET()->root_cframe;
+#endif
+ }
+ return o;
+}
+
+static int
+green_setrun(PyGreenlet* self, PyObject* nrun, void* c);
+static int
+green_setparent(PyGreenlet* self, PyObject* nparent, void* c);
+
+static int
+green_init(PyGreenlet* self, PyObject* args, PyObject* kwargs)
+{
+ PyObject* run = NULL;
+ PyObject* nparent = NULL;
+ static char* kwlist[] = {"run", "parent", 0};
+ if (!PyArg_ParseTupleAndKeywords(
+ args, kwargs, "|OO:green", kwlist, &run, &nparent)) {
+ return -1;
+ }
+
+ if (run != NULL) {
+ if (green_setrun(self, run, NULL)) {
+ return -1;
+ }
+ }
+ if (nparent != NULL && nparent != Py_None) {
+ return green_setparent(self, nparent, NULL);
+ }
+ return 0;
+}
+
+static int
+kill_greenlet(PyGreenlet* self)
+{
+ /* Cannot raise an exception to kill the greenlet if
+ it is not running in the same thread! */
+ if (self->run_info == PyThreadState_GET()->dict) {
+ /* The dying greenlet cannot be a parent of ts_current
+ because the 'parent' field chain would hold a
+ reference */
+ PyObject* result;
+ PyGreenlet* oldparent;
+ PyGreenlet* tmp;
+ if (!STATE_OK) {
+ return -1;
+ }
+ oldparent = self->parent;
+ self->parent = ts_current;
+ Py_INCREF(self->parent);
+ /* Send the greenlet a GreenletExit exception. */
+ PyErr_SetNone(PyExc_GreenletExit);
+ result = g_switch(self, NULL, NULL);
+ tmp = self->parent;
+ self->parent = oldparent;
+ Py_XDECREF(tmp);
+ if (result == NULL) {
+ return -1;
+ }
+ Py_DECREF(result);
+ return 0;
+ }
+ else {
+ /* Not the same thread! Temporarily save the greenlet
+ into its thread's ts_delkey list. */
+ PyObject* lst;
+ lst = PyDict_GetItem(self->run_info, ts_delkey);
+ if (lst == NULL) {
+ lst = PyList_New(0);
+ if (lst == NULL
+ || PyDict_SetItem(self->run_info, ts_delkey, lst) < 0) {
+ return -1;
+ }
+ /* PyDict_SetItem now holds a strong reference. PyList_New also
+ returned a fresh reference. We need to DECREF it now and let
+               the dictionary keep sole ownership. From now on, we're working
+ with a borrowed reference that will go away when the thread
+ dies. */
+ Py_DECREF(lst);
+ }
+ if (PyList_Append(lst, (PyObject*)self) < 0) {
+ return -1;
+ }
+ if (!STATE_OK) { /* to force ts_delkey to be reconsidered */
+ return -1;
+ }
+ return 0;
+ }
+}
+
+static int
+green_traverse(PyGreenlet* self, visitproc visit, void* arg)
+{
+ /* We must only visit referenced objects, i.e. only objects
+ Py_INCREF'ed by this greenlet (directly or indirectly):
+ - stack_prev is not visited: holds previous stack pointer, but it's not
+ referenced
+ - frames are not visited: alive greenlets are not garbage collected
+ anyway */
+ Py_VISIT((PyObject*)self->parent);
+ Py_VISIT(self->run_info);
+#if GREENLET_PY37
+ Py_VISIT(self->context);
+#endif
+#if GREENLET_PY37
+ Py_VISIT(self->exc_state.exc_value);
+#if !GREENLET_PY311
+ Py_VISIT(self->exc_state.exc_type);
+ Py_VISIT(self->exc_state.exc_traceback);
+#endif
+#else
+ Py_VISIT(self->exc_type);
+ Py_VISIT(self->exc_value);
+ Py_VISIT(self->exc_traceback);
+#endif
+ Py_VISIT(self->dict);
+ return 0;
+}
+
+static int
+green_is_gc(PyGreenlet* self)
+{
+ /* Main greenlet can be garbage collected since it can only
+ become unreachable if the underlying thread exited.
+ Active greenlet cannot be garbage collected, however. */
+ if (PyGreenlet_MAIN(self) || !PyGreenlet_ACTIVE(self)) {
+ return 1;
+ }
+ return 0;
+}
+
+static int
+green_clear(PyGreenlet* self)
+{
+ /* Greenlet is only cleared if it is about to be collected.
+ Since active greenlets are not garbage collectable, we can
+ be sure that, even if they are deallocated during clear,
+       nothing they reference is in the unreachable set or in finalizers,
+ so even if it switches we are relatively safe. */
+ Py_CLEAR(self->parent);
+ Py_CLEAR(self->run_info);
+#if GREENLET_PY37
+ Py_CLEAR(self->context);
+#endif
+#if GREENLET_PY37
+ Py_CLEAR(self->exc_state.exc_value);
+#if !GREENLET_PY311
+ Py_CLEAR(self->exc_state.exc_type);
+ Py_CLEAR(self->exc_state.exc_traceback);
+#endif
+#else
+ Py_CLEAR(self->exc_type);
+ Py_CLEAR(self->exc_value);
+ Py_CLEAR(self->exc_traceback);
+#endif
+ Py_CLEAR(self->dict);
+ return 0;
+}
+
+static void
+green_dealloc(PyGreenlet* self)
+{
+ PyObject *error_type, *error_value, *error_traceback;
+ Py_ssize_t refcnt;
+
+ PyObject_GC_UnTrack(self);
+
+ if (PyGreenlet_ACTIVE(self) && self->run_info != NULL &&
+ !PyGreenlet_MAIN(self)) {
+ /* Hacks hacks hacks copied from instance_dealloc() */
+ /* Temporarily resurrect the greenlet. */
+ assert(Py_REFCNT(self) == 0);
+ Py_SET_REFCNT(self, 1);
+ /* Save the current exception, if any. */
+ PyErr_Fetch(&error_type, &error_value, &error_traceback);
+ if (kill_greenlet(self) < 0) {
+ PyErr_WriteUnraisable((PyObject*)self);
+ /* XXX what else should we do? */
+ }
+ /* Check for no resurrection must be done while we keep
+ * our internal reference, otherwise PyFile_WriteObject
+ * causes recursion if using Py_INCREF/Py_DECREF
+ */
+ if (Py_REFCNT(self) == 1 && PyGreenlet_ACTIVE(self)) {
+ /* Not resurrected, but still not dead!
+ XXX what else should we do? we complain. */
+ PyObject* f = PySys_GetObject("stderr");
+ Py_INCREF(self); /* leak! */
+ if (f != NULL) {
+ PyFile_WriteString("GreenletExit did not kill ", f);
+ PyFile_WriteObject((PyObject*)self, f, 0);
+ PyFile_WriteString("\n", f);
+ }
+ }
+ /* Restore the saved exception. */
+ PyErr_Restore(error_type, error_value, error_traceback);
+ /* Undo the temporary resurrection; can't use DECREF here,
+ * it would cause a recursive call.
+ */
+ assert(Py_REFCNT(self) > 0);
+
+ refcnt = Py_REFCNT(self) - 1;
+ Py_SET_REFCNT(self, refcnt);
+ if (refcnt != 0) {
+ /* Resurrected! */
+ _Py_NewReference((PyObject*)self);
+ Py_SET_REFCNT(self, refcnt);
+ /* Better to use tp_finalizer slot (PEP 442)
+ * and call ``PyObject_CallFinalizerFromDealloc``,
+ * but that's only supported in Python 3.4+; see
+ * Modules/_io/iobase.c for an example.
+ *
+ * The following approach is copied from iobase.c in CPython 2.7.
+ * (along with much of this function in general). Here's their
+ * comment:
+ *
+ * When called from a heap type's dealloc, the type will be
+ * decref'ed on return (see e.g. subtype_dealloc in typeobject.c). */
+ if (PyType_HasFeature(Py_TYPE(self), Py_TPFLAGS_HEAPTYPE)) {
+ Py_INCREF(Py_TYPE(self));
+ }
+
+ PyObject_GC_Track((PyObject*)self);
+
+ _Py_DEC_REFTOTAL;
+#ifdef COUNT_ALLOCS
+ --Py_TYPE(self)->tp_frees;
+ --Py_TYPE(self)->tp_allocs;
+#endif /* COUNT_ALLOCS */
+ return;
+ }
+ }
+ if (self->weakreflist != NULL) {
+ PyObject_ClearWeakRefs((PyObject*)self);
+ }
+ Py_CLEAR(self->parent);
+ Py_CLEAR(self->run_info);
+#if GREENLET_PY37
+ Py_CLEAR(self->context);
+#endif
+#if GREENLET_PY37
+ Py_CLEAR(self->exc_state.exc_value);
+#if !GREENLET_PY311
+ Py_CLEAR(self->exc_state.exc_type);
+ Py_CLEAR(self->exc_state.exc_traceback);
+#endif
+#else
+ Py_CLEAR(self->exc_type);
+ Py_CLEAR(self->exc_value);
+ Py_CLEAR(self->exc_traceback);
+#endif
+ Py_CLEAR(self->dict);
+ Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+static PyObject*
+single_result(PyObject* results)
+{
+ if (results != NULL && PyTuple_Check(results) &&
+ PyTuple_GET_SIZE(results) == 1) {
+ PyObject* result = PyTuple_GET_ITEM(results, 0);
+ Py_INCREF(result);
+ Py_DECREF(results);
+ return result;
+ }
+ else {
+ return results;
+ }
+}
+
+static PyObject*
+throw_greenlet(PyGreenlet* self, PyObject* typ, PyObject* val, PyObject* tb)
+{
+ /* Note: _consumes_ a reference to typ, val, tb */
+ PyObject* result = NULL;
+ PyErr_Restore(typ, val, tb);
+ if (PyGreenlet_STARTED(self) && !PyGreenlet_ACTIVE(self)) {
+ /* dead greenlet: turn GreenletExit into a regular return */
+ result = g_handle_exit(result);
+ }
+ return single_result(g_switch(self, result, NULL));
+}
+
+PyDoc_STRVAR(
+ green_switch_doc,
+ "switch(*args, **kwargs)\n"
+ "\n"
+ "Switch execution to this greenlet.\n"
+ "\n"
+ "If this greenlet has never been run, then this greenlet\n"
+ "will be switched to using the body of ``self.run(*args, **kwargs)``.\n"
+ "\n"
+ "If the greenlet is active (has been run, but was switch()'ed\n"
+ "out before leaving its run function), then this greenlet will\n"
+ "be resumed and the return value to its switch call will be\n"
+ "None if no arguments are given, the given argument if one\n"
+ "argument is given, or the args tuple and keyword args dict if\n"
+ "multiple arguments are given.\n"
+ "\n"
+ "If the greenlet is dead, or is the current greenlet then this\n"
+ "function will simply return the arguments using the same rules as\n"
+ "above.\n");
+
+static PyObject*
+green_switch(PyGreenlet* self, PyObject* args, PyObject* kwargs)
+{
+ Py_INCREF(args);
+ Py_XINCREF(kwargs);
+ return single_result(g_switch(self, args, kwargs));
+}
+
+PyDoc_STRVAR(
+ green_throw_doc,
+ "Switches execution to this greenlet, but immediately raises the\n"
+ "given exception in this greenlet. If no argument is provided, the "
+ "exception\n"
+ "defaults to `greenlet.GreenletExit`. The normal exception\n"
+ "propagation rules apply, as described for `switch`. Note that calling "
+ "this\n"
+ "method is almost equivalent to the following::\n"
+ "\n"
+ " def raiser():\n"
+ " raise typ, val, tb\n"
+ " g_raiser = greenlet(raiser, parent=g)\n"
+ " g_raiser.switch()\n"
+ "\n"
+ "except that this trick does not work for the\n"
+ "`greenlet.GreenletExit` exception, which would not propagate\n"
+ "from ``g_raiser`` to ``g``.\n");
+
+static PyObject*
+green_throw(PyGreenlet* self, PyObject* args)
+{
+ PyObject* typ = PyExc_GreenletExit;
+ PyObject* val = NULL;
+ PyObject* tb = NULL;
+
+ if (!PyArg_ParseTuple(args, "|OOO:throw", &typ, &val, &tb)) {
+ return NULL;
+ }
+
+ /* First, check the traceback argument, replacing None, with NULL */
+ if (tb == Py_None) {
+ tb = NULL;
+ }
+ else if (tb != NULL && !PyTraceBack_Check(tb)) {
+ PyErr_SetString(PyExc_TypeError,
+ "throw() third argument must be a traceback object");
+ return NULL;
+ }
+
+ Py_INCREF(typ);
+ Py_XINCREF(val);
+ Py_XINCREF(tb);
+
+ if (PyExceptionClass_Check(typ)) {
+ PyErr_NormalizeException(&typ, &val, &tb);
+ }
+ else if (PyExceptionInstance_Check(typ)) {
+ /* Raising an instance. The value should be a dummy. */
+ if (val && val != Py_None) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "instance exception may not have a separate value");
+ goto failed_throw;
+ }
+ else {
+ /* Normalize to raise <class>, <instance> */
+ Py_XDECREF(val);
+ val = typ;
+ typ = PyExceptionInstance_Class(typ);
+ Py_INCREF(typ);
+ }
+ }
+ else {
+ /* Not something you can raise. throw() fails. */
+ PyErr_Format(PyExc_TypeError,
+ "exceptions must be classes, or instances, not %s",
+ Py_TYPE(typ)->tp_name);
+ goto failed_throw;
+ }
+
+ return throw_greenlet(self, typ, val, tb);
+
+failed_throw:
+ /* Didn't use our arguments, so restore their original refcounts */
+ Py_DECREF(typ);
+ Py_XDECREF(val);
+ Py_XDECREF(tb);
+ return NULL;
+}
+
+static int
+green_bool(PyGreenlet* self)
+{
+ return PyGreenlet_ACTIVE(self);
+}
+
+static PyObject*
+green_getdict(PyGreenlet* self, void* c)
+{
+ if (self->dict == NULL) {
+ self->dict = PyDict_New();
+ if (self->dict == NULL) {
+ return NULL;
+ }
+ }
+ Py_INCREF(self->dict);
+ return self->dict;
+}
+
+static int
+green_setdict(PyGreenlet* self, PyObject* val, void* c)
+{
+ PyObject* tmp;
+
+ if (val == NULL) {
+ PyErr_SetString(PyExc_TypeError, "__dict__ may not be deleted");
+ return -1;
+ }
+ if (!PyDict_Check(val)) {
+ PyErr_SetString(PyExc_TypeError, "__dict__ must be a dictionary");
+ return -1;
+ }
+ tmp = self->dict;
+ Py_INCREF(val);
+ self->dict = val;
+ Py_XDECREF(tmp);
+ return 0;
+}
+
+static int
+_green_not_dead(PyGreenlet* self)
+{
+ return PyGreenlet_ACTIVE(self) || !PyGreenlet_STARTED(self);
+}
+
+
+static PyObject*
+green_getdead(PyGreenlet* self, void* c)
+{
+ if (_green_not_dead(self)) {
+ Py_RETURN_FALSE;
+ }
+ else {
+ Py_RETURN_TRUE;
+ }
+}
+
+static PyObject*
+green_get_stack_saved(PyGreenlet* self, void* c)
+{
+ return PyLong_FromSsize_t(self->stack_saved);
+}
+
+static PyObject*
+green_getrun(PyGreenlet* self, void* c)
+{
+ if (PyGreenlet_STARTED(self) || self->run_info == NULL) {
+ PyErr_SetString(PyExc_AttributeError, "run");
+ return NULL;
+ }
+ Py_INCREF(self->run_info);
+ return self->run_info;
+}
+
+static int
+green_setrun(PyGreenlet* self, PyObject* nrun, void* c)
+{
+ PyObject* o;
+ if (PyGreenlet_STARTED(self)) {
+ PyErr_SetString(PyExc_AttributeError,
+ "run cannot be set "
+ "after the start of the greenlet");
+ return -1;
+ }
+ o = self->run_info;
+ self->run_info = nrun;
+ Py_XINCREF(nrun);
+ Py_XDECREF(o);
+ return 0;
+}
+
+static PyObject*
+green_getparent(PyGreenlet* self, void* c)
+{
+ PyObject* result = self->parent ? (PyObject*)self->parent : Py_None;
+ Py_INCREF(result);
+ return result;
+}
+
+static int
+green_setparent(PyGreenlet* self, PyObject* nparent, void* c)
+{
+ PyGreenlet* p;
+ PyObject* run_info = NULL;
+ if (nparent == NULL) {
+ PyErr_SetString(PyExc_AttributeError, "can't delete attribute");
+ return -1;
+ }
+ if (!PyGreenlet_Check(nparent)) {
+ PyErr_SetString(PyExc_TypeError, "parent must be a greenlet");
+ return -1;
+ }
+ for (p = (PyGreenlet*)nparent; p; p = p->parent) {
+ if (p == self) {
+ PyErr_SetString(PyExc_ValueError, "cyclic parent chain");
+ return -1;
+ }
+ run_info = PyGreenlet_ACTIVE(p) ? p->run_info : NULL;
+ }
+ if (run_info == NULL) {
+ PyErr_SetString(PyExc_ValueError,
+ "parent must not be garbage collected");
+ return -1;
+ }
+ if (PyGreenlet_STARTED(self) && self->run_info != run_info) {
+ PyErr_SetString(PyExc_ValueError,
+ "parent cannot be on a different thread");
+ return -1;
+ }
+ p = self->parent;
+ self->parent = (PyGreenlet*)nparent;
+ Py_INCREF(nparent);
+ Py_XDECREF(p);
+ return 0;
+}
+
+#ifdef Py_CONTEXT_H
+# define GREENLET_NO_CONTEXTVARS_REASON "This build of greenlet"
+#else
+# define GREENLET_NO_CONTEXTVARS_REASON "This Python interpreter"
+#endif
+
+static PyObject*
+green_getcontext(PyGreenlet* self, void* c)
+{
+#if GREENLET_PY37
+ PyThreadState* tstate = PyThreadState_GET();
+ PyObject* result;
+
+ if (!STATE_OK) {
+ return NULL;
+ }
+ if (PyGreenlet_ACTIVE(self) && self->top_frame == NULL) {
+ /* Currently running greenlet: context is stored in the thread state,
+ not the greenlet object. */
+ if (self == ts_current) {
+ result = tstate->context;
+ }
+ else {
+ PyErr_SetString(PyExc_ValueError,
+ "cannot get context of a "
+ "greenlet that is running in a different thread");
+ return NULL;
+ }
+ }
+ else {
+ /* Greenlet is not running: just return context. */
+ result = self->context;
+ }
+ if (result == NULL) {
+ result = Py_None;
+ }
+ Py_INCREF(result);
+ return result;
+#else
+ PyErr_SetString(PyExc_AttributeError,
+ GREENLET_NO_CONTEXTVARS_REASON
+ " does not support context variables");
+ return NULL;
+#endif
+}
+
+static int
+green_setcontext(PyGreenlet* self, PyObject* nctx, void* c)
+{
+#if GREENLET_PY37
+ PyThreadState* tstate;
+ PyObject* octx = NULL;
+ if (!STATE_OK) {
+ return -1;
+ }
+ if (nctx == NULL) {
+ PyErr_SetString(PyExc_AttributeError, "can't delete attribute");
+ return -1;
+ }
+ if (nctx == Py_None) {
+ /* "Empty context" is stored as NULL, not None. */
+ nctx = NULL;
+ }
+ else if (!PyContext_CheckExact(nctx)) {
+ PyErr_SetString(PyExc_TypeError,
+ "greenlet context must be a "
+ "contextvars.Context or None");
+ return -1;
+ }
+ tstate = PyThreadState_GET();
+ if (PyGreenlet_ACTIVE(self) && self->top_frame == NULL) {
+ /* Currently running greenlet: context is stored in the thread state,
+ not the greenlet object. */
+ if (self == ts_current) {
+ octx = tstate->context;
+ tstate->context = nctx;
+ tstate->context_ver++;
+ Py_XINCREF(nctx);
+ }
+ else {
+ PyErr_SetString(PyExc_ValueError,
+ "cannot set context of a "
+ "greenlet that is running in a different thread");
+ return -1;
+ }
+ }
+ else {
+ /* Greenlet is not running: just set context. */
+ octx = self->context;
+ self->context = nctx;
+ Py_XINCREF(nctx);
+ }
+ Py_XDECREF(octx);
+ return 0;
+#else
+ PyErr_SetString(PyExc_AttributeError,
+ GREENLET_NO_CONTEXTVARS_REASON
+ " does not support context variables");
+ return -1;
+#endif
+}
+
+#undef GREENLET_NO_CONTEXTVARS_REASON
+
+static PyObject*
+green_getframe(PyGreenlet* self, void* c)
+{
+ PyObject* result = self->top_frame ? (PyObject*)self->top_frame : Py_None;
+ Py_INCREF(result);
+ return result;
+}
+
+static PyObject*
+green_getstate(PyGreenlet* self)
+{
+ PyErr_Format(PyExc_TypeError,
+ "cannot serialize '%s' object",
+ Py_TYPE(self)->tp_name);
+ return NULL;
+}
+
+static PyObject*
+green_repr(PyGreenlet* self)
+{
+ /*
+ Return a string like
+ <greenlet.greenlet at 0xdeadbeef [current][active started]|dead main>
+
+ The handling of greenlets across threads is not super good.
+ We mostly use the internal definitions of these terms, but they
+ generally should make sense to users as well.
+ */
+ PyObject* result;
+ int never_started = !PyGreenlet_STARTED(self) && !PyGreenlet_ACTIVE(self);
+
+ if (!STATE_OK) {
+ return NULL;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+# define GNative_FromFormat PyUnicode_FromFormat
+#else
+# define GNative_FromFormat PyString_FromFormat
+#endif
+
+ if (_green_not_dead(self)) {
+        /* XXX: The otid= is almost useless because you can't correlate it to
+ any thread identifier exposed to Python. We could use
+ PyThreadState_GET()->thread_id, but we'd need to save that in the
+ greenlet, or save the whole PyThreadState object itself.
+
+           As it stands, it's only useful for identifying greenlets from the same thread.
+ */
+ result = GNative_FromFormat(
+ "<%s object at %p (otid=%p)%s%s%s%s>",
+ Py_TYPE(self)->tp_name,
+ self,
+ self->run_info,
+ ts_current == self
+ ? " current"
+ : (PyGreenlet_STARTED(self) ? " suspended" : ""),
+ PyGreenlet_ACTIVE(self) ? " active" : "",
+ never_started ? " pending" : " started",
+ PyGreenlet_MAIN(self) ? " main" : ""
+ );
+ }
+ else {
+ /* main greenlets never really appear dead. */
+ result = GNative_FromFormat(
+ "<%s object at %p (otid=%p) dead>",
+ Py_TYPE(self)->tp_name,
+ self,
+ self->run_info
+ );
+ }
+#undef GNative_FromFormat
+
+ return result;
+}
+
+/*****************************************************************************
+ * C interface
+ *
+ * These are exported using the CObject API
+ */
+
+static PyGreenlet*
+PyGreenlet_GetCurrent(void)
+{
+ if (!STATE_OK) {
+ return NULL;
+ }
+ Py_INCREF(ts_current);
+ return ts_current;
+}
+
+static int
+PyGreenlet_SetParent(PyGreenlet* g, PyGreenlet* nparent)
+{
+ if (!PyGreenlet_Check(g)) {
+ PyErr_SetString(PyExc_TypeError, "parent must be a greenlet");
+ return -1;
+ }
+
+ return green_setparent((PyGreenlet*)g, (PyObject*)nparent, NULL);
+}
+
+static PyGreenlet*
+PyGreenlet_New(PyObject* run, PyGreenlet* parent)
+{
+ /* XXX: Why doesn't this call green_new()? There's some duplicate
+ code. */
+ PyGreenlet* g = NULL;
+ g = (PyGreenlet*)PyType_GenericAlloc(&PyGreenlet_Type, 0);
+ if (g == NULL) {
+ return NULL;
+ }
+
+ if (run != NULL) {
+ Py_INCREF(run);
+ g->run_info = run;
+ }
+
+ if (parent != NULL) {
+ if (PyGreenlet_SetParent(g, parent)) {
+ Py_DECREF(g);
+ return NULL;
+ }
+ }
+ else {
+ if ((g->parent = PyGreenlet_GetCurrent()) == NULL) {
+ Py_DECREF(g);
+ return NULL;
+ }
+ }
+#if GREENLET_USE_CFRAME
+ g->cframe = &PyThreadState_GET()->root_cframe;
+#endif
+ return g;
+}
+
+static PyObject*
+PyGreenlet_Switch(PyGreenlet* g, PyObject* args, PyObject* kwargs)
+{
+ PyGreenlet* self = (PyGreenlet*)g;
+
+ if (!PyGreenlet_Check(self)) {
+ PyErr_BadArgument();
+ return NULL;
+ }
+
+ if (args == NULL) {
+ args = Py_BuildValue("()");
+ }
+ else {
+ Py_INCREF(args);
+ }
+
+ if (kwargs != NULL && PyDict_Check(kwargs)) {
+ Py_INCREF(kwargs);
+ }
+ else {
+ kwargs = NULL;
+ }
+
+ return single_result(g_switch(self, args, kwargs));
+}
+
+static PyObject*
+PyGreenlet_Throw(PyGreenlet* self, PyObject* typ, PyObject* val, PyObject* tb)
+{
+ if (!PyGreenlet_Check(self)) {
+ PyErr_BadArgument();
+ return NULL;
+ }
+ Py_INCREF(typ);
+ Py_XINCREF(val);
+ Py_XINCREF(tb);
+ return throw_greenlet(self, typ, val, tb);
+}
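+
+/* A disabled sketch of how an embedder might drive the C interface above.
+   The ``sketch_run`` callable is hypothetical; real extension modules are
+   expected to reach these entry points through the capsule exported via
+   greenlet.h rather than by calling the static definitions directly. */
+#if 0
+static PyObject* sketch_spawn_and_switch(PyObject* sketch_run)
+{
+    PyObject* result;
+    PyGreenlet* g = PyGreenlet_New(sketch_run, NULL); /* parent = current */
+    if (g == NULL) {
+        return NULL;
+    }
+    result = PyGreenlet_Switch(g, NULL, NULL); /* runs sketch_run() */
+    Py_DECREF(g);
+    return result;
+}
+#endif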
+
+/** End C API ****************************************************************/
+
+static PyMethodDef green_methods[] = {
+ {"switch",
+ (PyCFunction)green_switch,
+ METH_VARARGS | METH_KEYWORDS,
+ green_switch_doc},
+ {"throw", (PyCFunction)green_throw, METH_VARARGS, green_throw_doc},
+ {"__getstate__", (PyCFunction)green_getstate, METH_NOARGS, NULL},
+ {NULL, NULL} /* sentinel */
+};
+
+static PyGetSetDef green_getsets[] = {
+ {"__dict__", (getter)green_getdict, (setter)green_setdict, /*XXX*/ NULL},
+ {"run", (getter)green_getrun, (setter)green_setrun, /*XXX*/ NULL},
+ {"parent", (getter)green_getparent, (setter)green_setparent, /*XXX*/ NULL},
+ {"gr_frame", (getter)green_getframe, NULL, /*XXX*/ NULL},
+ {"gr_context",
+ (getter)green_getcontext,
+ (setter)green_setcontext,
+ /*XXX*/ NULL},
+ {"dead", (getter)green_getdead, NULL, /*XXX*/ NULL},
+ {"_stack_saved", (getter)green_get_stack_saved, NULL, /*XXX*/ NULL},
+ {NULL}};
+
+static PyNumberMethods green_as_number = {
+ NULL, /* nb_add */
+ NULL, /* nb_subtract */
+ NULL, /* nb_multiply */
+#if PY_MAJOR_VERSION < 3
+ NULL, /* nb_divide */
+#endif
+ NULL, /* nb_remainder */
+ NULL, /* nb_divmod */
+ NULL, /* nb_power */
+ NULL, /* nb_negative */
+ NULL, /* nb_positive */
+ NULL, /* nb_absolute */
+ (inquiry)green_bool, /* nb_bool */
+};
+
+PyTypeObject PyGreenlet_Type = {
+ PyVarObject_HEAD_INIT(NULL, 0)
+ "greenlet.greenlet", /* tp_name */
+ sizeof(PyGreenlet), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ /* methods */
+ (destructor)green_dealloc, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_compare */
+ (reprfunc)green_repr, /* tp_repr */
+    &green_as_number,                   /* tp_as_number */
+    0,                                  /* tp_as_sequence */
+    0,                                  /* tp_as_mapping */
+ 0, /* tp_hash */
+ 0, /* tp_call */
+ 0, /* tp_str */
+ 0, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer*/
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
+ GREENLET_GC_FLAGS, /* tp_flags */
+ "greenlet(run=None, parent=None) -> greenlet\n\n"
+ "Creates a new greenlet object (without running it).\n\n"
+ " - *run* -- The callable to invoke.\n"
+ " - *parent* -- The parent greenlet. The default is the current "
+ "greenlet.", /* tp_doc */
+ (traverseproc)GREENLET_tp_traverse, /* tp_traverse */
+ (inquiry)GREENLET_tp_clear, /* tp_clear */
+ 0, /* tp_richcompare */
+ offsetof(PyGreenlet, weakreflist), /* tp_weaklistoffset */
+ 0, /* tp_iter */
+ 0, /* tp_iternext */
+ green_methods, /* tp_methods */
+ 0, /* tp_members */
+ green_getsets, /* tp_getset */
+ 0, /* tp_base */
+ 0, /* tp_dict */
+ 0, /* tp_descr_get */
+ 0, /* tp_descr_set */
+ offsetof(PyGreenlet, dict), /* tp_dictoffset */
+ (initproc)green_init, /* tp_init */
+ GREENLET_tp_alloc, /* tp_alloc */
+ green_new, /* tp_new */
+ GREENLET_tp_free, /* tp_free */
+ (inquiry)GREENLET_tp_is_gc, /* tp_is_gc */
+};
+
+PyDoc_STRVAR(mod_getcurrent_doc,
+ "getcurrent() -> greenlet\n"
+ "\n"
+ "Returns the current greenlet (i.e. the one which called this "
+ "function).\n");
+
+static PyObject*
+mod_getcurrent(PyObject* self)
+{
+ if (!STATE_OK) {
+ return NULL;
+ }
+ Py_INCREF(ts_current);
+ return (PyObject*)ts_current;
+}
+
+PyDoc_STRVAR(mod_settrace_doc,
+ "settrace(callback) -> object\n"
+ "\n"
+ "Sets a new tracing function and returns the previous one.\n");
+static PyObject*
+mod_settrace(PyObject* self, PyObject* args)
+{
+ int err;
+ PyObject* previous;
+ PyObject* tracefunc;
+ PyGreenlet* current;
+ if (!PyArg_ParseTuple(args, "O", &tracefunc)) {
+ return NULL;
+ }
+ if (!STATE_OK) {
+ return NULL;
+ }
+ current = ts_current;
+ previous = PyDict_GetItem(current->run_info, ts_tracekey);
+ if (previous == NULL) {
+ previous = Py_None;
+ }
+ Py_INCREF(previous);
+ if (tracefunc == Py_None) {
+ err = previous != Py_None ?
+ PyDict_DelItem(current->run_info, ts_tracekey) :
+ 0;
+ }
+ else {
+ err = PyDict_SetItem(current->run_info, ts_tracekey, tracefunc);
+ }
+ if (err < 0) {
+ Py_CLEAR(previous);
+ }
+ return previous;
+}
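+
+/* Note on the tracing contract (a brief sketch; the actual invocation lives
+ earlier in this file): the callable registered via settrace() is assumed to
+ be called as callback(event, args), where event is one of the interned
+ strings "switch" or "throw" (ts_event_switch / ts_event_throw, interned in
+ the module init below) and args is the (origin, target) pair of greenlets
+ involved. Passing None to settrace() removes the tracing function. */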
+
+PyDoc_STRVAR(mod_gettrace_doc,
+ "gettrace() -> object\n"
+ "\n"
+ "Returns the currently set tracing function, or None.\n");
+
+static PyObject*
+mod_gettrace(PyObject* self)
+{
+ PyObject* tracefunc;
+ if (!STATE_OK) {
+ return NULL;
+ }
+ tracefunc = PyDict_GetItem(ts_current->run_info, ts_tracekey);
+ if (tracefunc == NULL) {
+ tracefunc = Py_None;
+ }
+ Py_INCREF(tracefunc);
+ return tracefunc;
+}
+
+static PyMethodDef GreenMethods[] = {
+ {"getcurrent",
+ (PyCFunction)mod_getcurrent,
+ METH_NOARGS,
+ mod_getcurrent_doc},
+ {"settrace", (PyCFunction)mod_settrace, METH_VARARGS, mod_settrace_doc},
+ {"gettrace", (PyCFunction)mod_gettrace, METH_NOARGS, mod_gettrace_doc},
+ {NULL, NULL} /* Sentinel */
+};
+
+static char* copy_on_greentype[] = {
+ "getcurrent", "error", "GreenletExit", "settrace", "gettrace", NULL};
+
+#if PY_MAJOR_VERSION >= 3
+# define INITERROR return NULL
+
+static struct PyModuleDef greenlet_module_def = {
+ PyModuleDef_HEAD_INIT,
+ "greenlet._greenlet",
+ NULL,
+ -1,
+ GreenMethods,
+};
+
+PyMODINIT_FUNC
+PyInit__greenlet(void)
+#else
+# define INITERROR return
+
+PyMODINIT_FUNC
+init_greenlet(void)
+#endif
+{
+ PyObject* m = NULL;
+ char** p = NULL;
+ PyObject* c_api_object;
+ static void* _PyGreenlet_API[PyGreenlet_API_pointers];
+
+ GREENLET_NOINLINE_INIT();
+
+#if PY_MAJOR_VERSION >= 3
+ m = PyModule_Create(&greenlet_module_def);
+#else
+ m = Py_InitModule("greenlet._greenlet", GreenMethods);
+#endif
+ if (m == NULL) {
+ INITERROR;
+ }
+
+#if PY_MAJOR_VERSION >= 3
+# define Greenlet_Intern PyUnicode_InternFromString
+#else
+# define Greenlet_Intern PyString_InternFromString
+#endif
+ ts_curkey = Greenlet_Intern("__greenlet_ts_curkey");
+ ts_delkey = Greenlet_Intern("__greenlet_ts_delkey");
+ ts_tracekey = Greenlet_Intern("__greenlet_ts_tracekey");
+ ts_event_switch = Greenlet_Intern("switch");
+ ts_event_throw = Greenlet_Intern("throw");
+#undef Greenlet_Intern
+
+ if (ts_curkey == NULL || ts_delkey == NULL) {
+ INITERROR;
+ }
+ if (PyType_Ready(&PyGreenlet_Type) < 0) {
+ INITERROR;
+ }
+ PyExc_GreenletError = PyErr_NewException("greenlet.error", NULL, NULL);
+ if (PyExc_GreenletError == NULL) {
+ INITERROR;
+ }
+ PyExc_GreenletExit =
+ PyErr_NewException("greenlet.GreenletExit", PyExc_BaseException, NULL);
+ if (PyExc_GreenletExit == NULL) {
+ INITERROR;
+ }
+
+ ts_empty_tuple = PyTuple_New(0);
+ if (ts_empty_tuple == NULL) {
+ INITERROR;
+ }
+
+ ts_empty_dict = PyDict_New();
+ if (ts_empty_dict == NULL) {
+ INITERROR;
+ }
+
+ ts_current = green_create_main();
+ if (ts_current == NULL) {
+ INITERROR;
+ }
+
+ Py_INCREF(&PyGreenlet_Type);
+ PyModule_AddObject(m, "greenlet", (PyObject*)&PyGreenlet_Type);
+ Py_INCREF(PyExc_GreenletError);
+ PyModule_AddObject(m, "error", PyExc_GreenletError);
+ Py_INCREF(PyExc_GreenletExit);
+ PyModule_AddObject(m, "GreenletExit", PyExc_GreenletExit);
+
+ PyModule_AddObject(m, "GREENLET_USE_GC", PyBool_FromLong(1));
+ PyModule_AddObject(m, "GREENLET_USE_TRACING", PyBool_FromLong(1));
+ PyModule_AddObject(
+ m, "GREENLET_USE_CONTEXT_VARS", PyBool_FromLong(GREENLET_PY37));
+
+ /* also publish module-level data as attributes of the greentype. */
+ /* XXX: Why? */
+ for (p = copy_on_greentype; *p; p++) {
+ PyObject* o = PyObject_GetAttrString(m, *p);
+ if (!o) {
+ continue;
+ }
+ PyDict_SetItemString(PyGreenlet_Type.tp_dict, *p, o);
+ Py_DECREF(o);
+ }
+
+ /*
+ * Expose C API
+ */
+
+ /* types */
+ _PyGreenlet_API[PyGreenlet_Type_NUM] = (void*)&PyGreenlet_Type;
+
+ /* exceptions */
+ _PyGreenlet_API[PyExc_GreenletError_NUM] = (void*)PyExc_GreenletError;
+ _PyGreenlet_API[PyExc_GreenletExit_NUM] = (void*)PyExc_GreenletExit;
+
+ /* methods */
+ _PyGreenlet_API[PyGreenlet_New_NUM] = (void*)PyGreenlet_New;
+ _PyGreenlet_API[PyGreenlet_GetCurrent_NUM] = (void*)PyGreenlet_GetCurrent;
+ _PyGreenlet_API[PyGreenlet_Throw_NUM] = (void*)PyGreenlet_Throw;
+ _PyGreenlet_API[PyGreenlet_Switch_NUM] = (void*)PyGreenlet_Switch;
+ _PyGreenlet_API[PyGreenlet_SetParent_NUM] = (void*)PyGreenlet_SetParent;
+
+ /* XXX: Note that our module name is ``greenlet._greenlet``, but for
+ backwards compatibility with existing C code, we need the _C_API to
+ be directly in greenlet.
+ */
+ c_api_object =
+ PyCapsule_New((void*)_PyGreenlet_API, "greenlet._C_API", NULL);
+ if (c_api_object != NULL) {
+ PyModule_AddObject(m, "_C_API", c_api_object);
+ }
+
+#if PY_MAJOR_VERSION >= 3
+ return m;
+#endif
+}
+
+#ifdef __clang__
+# pragma clang diagnostic pop
+#endif
diff --git a/lib/greenlet/greenlet.h b/lib/greenlet/greenlet.h
new file mode 100644
index 0000000..c788b2f
--- /dev/null
+++ b/lib/greenlet/greenlet.h
@@ -0,0 +1,161 @@
+/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
+
+/* Greenlet object interface */
+
+#ifndef Py_GREENLETOBJECT_H
+#define Py_GREENLETOBJECT_H
+
+#include <Python.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* This is deprecated and undocumented. It does not change. */
+#define GREENLET_VERSION "1.0.0"
+
+#if PY_VERSION_HEX >= 0x30B00A6
+# define GREENLET_PY311 1
+ /* _PyInterpreterFrame moved to the internal C API in Python 3.11 */
+# include <internal/pycore_frame.h>
+#else
+# define GREENLET_PY311 0
+# define _PyCFrame CFrame
+#endif
+
+typedef struct _greenlet {
+ PyObject_HEAD
+ char* stack_start;
+ char* stack_stop;
+ char* stack_copy;
+ intptr_t stack_saved;
+ struct _greenlet* stack_prev;
+ struct _greenlet* parent;
+ PyObject* run_info;
+ struct _frame* top_frame;
+ int recursion_depth;
+#if GREENLET_PY311
+ _PyInterpreterFrame *current_frame;
+ _PyStackChunk *datastack_chunk;
+ PyObject **datastack_top;
+ PyObject **datastack_limit;
+#endif
+ PyObject* weakreflist;
+#if PY_VERSION_HEX >= 0x030700A3
+ _PyErr_StackItem* exc_info;
+ _PyErr_StackItem exc_state;
+#else
+ PyObject* exc_type;
+ PyObject* exc_value;
+ PyObject* exc_traceback;
+#endif
+ PyObject* dict;
+#if PY_VERSION_HEX >= 0x030700A3
+ PyObject* context;
+#endif
+#if PY_VERSION_HEX >= 0x30A00B1
+ _PyCFrame* cframe;
+#endif
+} PyGreenlet;
+
+#define PyGreenlet_Check(op) PyObject_TypeCheck(op, &PyGreenlet_Type)
+#define PyGreenlet_MAIN(op) (((PyGreenlet*)(op))->stack_stop == (char*)-1)
+#define PyGreenlet_STARTED(op) (((PyGreenlet*)(op))->stack_stop != NULL)
+#define PyGreenlet_ACTIVE(op) (((PyGreenlet*)(op))->stack_start != NULL)
+#define PyGreenlet_GET_PARENT(op) (((PyGreenlet*)(op))->parent)
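+
+/* Illustrative sketch of how an embedding extension could combine the state
+ macros above; the helper name describe_greenlet is hypothetical, and the
+ wording of the returned strings follows the usual greenlet terminology
+ (a greenlet that has started but is no longer active is "dead").
+
+ static const char* describe_greenlet(PyObject* op)
+ {
+ if (!PyGreenlet_Check(op)) return "not a greenlet";
+ if (PyGreenlet_MAIN(op)) return "main greenlet";
+ if (!PyGreenlet_STARTED(op)) return "not started yet";
+ if (PyGreenlet_ACTIVE(op)) return "started and still running";
+ return "started earlier but already finished (dead)";
+ }
+*/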
+
+/* C API functions */
+
+/* Total number of symbols that are exported */
+#define PyGreenlet_API_pointers 8
+
+#define PyGreenlet_Type_NUM 0
+#define PyExc_GreenletError_NUM 1
+#define PyExc_GreenletExit_NUM 2
+
+#define PyGreenlet_New_NUM 3
+#define PyGreenlet_GetCurrent_NUM 4
+#define PyGreenlet_Throw_NUM 5
+#define PyGreenlet_Switch_NUM 6
+#define PyGreenlet_SetParent_NUM 7
+
+#ifndef GREENLET_MODULE
+/* This section is used by modules that use the greenlet C API */
+static void** _PyGreenlet_API = NULL;
+
+# define PyGreenlet_Type \
+ (*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
+
+# define PyExc_GreenletError \
+ ((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
+
+# define PyExc_GreenletExit \
+ ((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
+
+/*
+ * PyGreenlet_New(PyObject *args)
+ *
+ * greenlet.greenlet(run, parent=None)
+ */
+# define PyGreenlet_New \
+ (*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
+ _PyGreenlet_API[PyGreenlet_New_NUM])
+
+/*
+ * PyGreenlet_GetCurrent(void)
+ *
+ * greenlet.getcurrent()
+ */
+# define PyGreenlet_GetCurrent \
+ (*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
+
+/*
+ * PyGreenlet_Throw(
+ * PyGreenlet *greenlet,
+ * PyObject *typ,
+ * PyObject *val,
+ * PyObject *tb)
+ *
+ * g.throw(...)
+ */
+# define PyGreenlet_Throw \
+ (*(PyObject * (*)(PyGreenlet * self, \
+ PyObject * typ, \
+ PyObject * val, \
+ PyObject * tb)) \
+ _PyGreenlet_API[PyGreenlet_Throw_NUM])
+
+/*
+ * PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
+ *
+ * g.switch(*args, **kwargs)
+ */
+# define PyGreenlet_Switch \
+ (*(PyObject * \
+ (*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
+ _PyGreenlet_API[PyGreenlet_Switch_NUM])
+
+/*
+ * PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
+ *
+ * g.parent = new_parent
+ */
+# define PyGreenlet_SetParent \
+ (*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
+ _PyGreenlet_API[PyGreenlet_SetParent_NUM])
+
+/* Macro that imports greenlet and initializes C API */
+/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
+ keep the older definition to be sure older code that might have a copy of
+ the header still works. */
+# define PyGreenlet_Import() \
+ { \
+ _PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
+ }
+
+#endif /* GREENLET_MODULE */
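+
+/* Usage sketch for the capsule-based C API above. The names bootstrap and
+ spawn_and_run are illustrative only, error handling is abbreviated, and
+ passing NULL as the parent is assumed to mean "use the current greenlet",
+ matching the Python-level default; NULL args/kwargs are accepted by
+ PyGreenlet_Switch (see its implementation in greenlet.c).
+
+ static PyObject*
+ spawn_and_run(PyObject* bootstrap)
+ {
+ PyGreenlet* g;
+ PyObject* result;
+
+ PyGreenlet_Import(); // resolves the greenlet._C_API capsule
+ if (_PyGreenlet_API == NULL) {
+ return NULL; // ImportError is already set
+ }
+ g = PyGreenlet_New(bootstrap, NULL); // parent defaults to current
+ if (g == NULL) {
+ return NULL;
+ }
+ result = PyGreenlet_Switch(g, NULL, NULL); // run until it switches back
+ Py_DECREF(g);
+ return result;
+ }
+*/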
+
+#ifdef __cplusplus
+}
+#endif
+#endif /* !Py_GREENLETOBJECT_H */
diff --git a/lib/greenlet/platform/setup_switch_x64_masm.cmd b/lib/greenlet/platform/setup_switch_x64_masm.cmd
new file mode 100644
index 0000000..0928595
--- /dev/null
+++ b/lib/greenlet/platform/setup_switch_x64_masm.cmd
@@ -0,0 +1,2 @@
+call "C:\Program Files (x86)\Microsoft Visual Studio 9.0\VC\vcvarsall.bat" amd64
+ml64 /nologo /c /Fo switch_x64_masm.obj switch_x64_masm.asm
diff --git a/lib/greenlet/platform/switch_aarch64_gcc.h b/lib/greenlet/platform/switch_aarch64_gcc.h
new file mode 100644
index 0000000..0b9d556
--- /dev/null
+++ b/lib/greenlet/platform/switch_aarch64_gcc.h
@@ -0,0 +1,69 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 07-Sep-16 Add clang support using x register naming. Fredrik Fornwall
+ * 13-Apr-13 Add support for strange GCC caller-save decisions
+ * 08-Apr-13 File creation. Michael Matz
+ *
+ * NOTES
+ *
+ * Simply save all callee saved registers
+ *
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+#define STACK_MAGIC 0
+#define REGS_TO_SAVE "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", \
+ "x27", "x28", "x30" /* aka lr */, \
+ "v8", "v9", "v10", "v11", \
+ "v12", "v13", "v14", "v15"
+
+static int
+slp_switch(void)
+{
+ int err;
+ void *fp;
+ register long *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("str x29, %0" : "=m"(fp) : : );
+ __asm__ ("mov %0, sp" : "=r" (stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "add sp,sp,%0\n"
+ "add x29,x29,%0\n"
+ :
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ /* SLP_SAVE_STATE macro contains some return statements
+ (of -1 and 1). It falls through only when
+ the return value of slp_save_state() is zero, which
+ is placed in x0.
+ In that case we (slp_switch) also want to return zero
+ (also in x0 of course).
+ Now, some GCC versions (seen with 4.8) think it's a
+ good idea to save/restore x0 around the call to
+ slp_restore_state(), instead of simply zeroing it
+ at the return below. But slp_restore_state
+ writes random values to the stack slot used for this
+ save/restore (from when it once was saved above in
+ SLP_SAVE_STATE, when it was still uninitialized), so
+ "restoring" that precious zero actually makes us
+ return random values. There are some ways to make
+ GCC not use that zero value in the normal return path
+ (e.g. making err volatile, but that costs a little
+ stack space), and the simplest is to call a function
+ that returns an unknown value (which happens to be zero),
+ so the saved/restored value is unused. */
+ __asm__ volatile ("mov %0, #0" : "=r" (err));
+ }
+ __asm__ volatile ("ldr x29, %0" : : "m" (fp) :);
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ return err;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_alpha_unix.h b/lib/greenlet/platform/switch_alpha_unix.h
new file mode 100644
index 0000000..216619f
--- /dev/null
+++ b/lib/greenlet/platform/switch_alpha_unix.h
@@ -0,0 +1,30 @@
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+#define STACK_MAGIC 0
+
+#define REGS_TO_SAVE "$9", "$10", "$11", "$12", "$13", "$14", "$15", \
+ "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", "$f8", "$f9"
+
+static int
+slp_switch(void)
+{
+ register int ret;
+ register long *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("mov $30, %0" : "=r" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "addq $30, %0, $30\n\t"
+ : /* no outputs */
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("mov $31, %0" : "=r" (ret) : );
+ return ret;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_amd64_unix.h b/lib/greenlet/platform/switch_amd64_unix.h
new file mode 100644
index 0000000..16b99b7
--- /dev/null
+++ b/lib/greenlet/platform/switch_amd64_unix.h
@@ -0,0 +1,84 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 3-May-13 Ralf Schmitt <ralf@systemexit.de>
+ * Add support for strange GCC caller-save decisions
+ * (ported from switch_aarch64_gcc.h)
+ * 18-Aug-11 Alexey Borzenkov <snaury@gmail.com>
+ * Correctly save rbp, csr and cw
+ * 01-Apr-04 Hye-Shik Chang <perky@FreeBSD.org>
+ * Ported from i386 to amd64.
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 31-Apr-02 Armin Rigo <arigo@ulb.ac.be>
+ * Added ebx, esi and edi register-saves.
+ * 01-Mar-02 Samual M. Rushing <rushing@ironport.com>
+ * Ported from i386.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+/* #define STACK_MAGIC 3 */
+/* the above works fine with gcc 2.96, but 2.95.3 wants this */
+#define STACK_MAGIC 0
+
+#define REGS_TO_SAVE "r12", "r13", "r14", "r15"
+
+static int
+slp_switch(void)
+{
+ int err;
+ void* rbp;
+ void* rbx;
+ unsigned int csr;
+ unsigned short cw;
+ register long *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("fstcw %0" : "=m" (cw));
+ __asm__ volatile ("stmxcsr %0" : "=m" (csr));
+ __asm__ volatile ("movq %%rbp, %0" : "=m" (rbp));
+ __asm__ volatile ("movq %%rbx, %0" : "=m" (rbx));
+ __asm__ ("movq %%rsp, %0" : "=g" (stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "addq %0, %%rsp\n"
+ "addq %0, %%rbp\n"
+ :
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ __asm__ volatile ("xorq %%rax, %%rax" : "=a" (err));
+ }
+ __asm__ volatile ("movq %0, %%rbx" : : "m" (rbx));
+ __asm__ volatile ("movq %0, %%rbp" : : "m" (rbp));
+ __asm__ volatile ("ldmxcsr %0" : : "m" (csr));
+ __asm__ volatile ("fldcw %0" : : "m" (cw));
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_arm32_gcc.h b/lib/greenlet/platform/switch_arm32_gcc.h
new file mode 100644
index 0000000..035d6b9
--- /dev/null
+++ b/lib/greenlet/platform/switch_arm32_gcc.h
@@ -0,0 +1,79 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 14-Aug-06 File creation. Ported from Arm Thumb. Sylvain Baro
+ * 3-Sep-06 Commented out saving of r1-r3 (r4 already commented out) as I
+ * read that these do not need to be saved. Also added notes and
+ * errors related to the frame pointer. Richard Tew.
+ *
+ * NOTES
+ *
+ * It is not possible to detect whether fp is used or not, so the supplied
+ * switch function handles it; you can remove that handling if
+ * it does not apply to you.
+ *
+ * POSSIBLE ERRORS
+ *
+ * "fp cannot be used in asm here"
+ *
+ * - Try commenting out "fp" in REGS_TO_SAVE.
+ *
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+#define STACK_MAGIC 0
+#define REG_SP "sp"
+#define REG_SPSP "sp,sp"
+#ifdef __thumb__
+#define REG_FP "r7"
+#define REG_FPFP "r7,r7"
+#define REGS_TO_SAVE_GENERAL "r4", "r5", "r6", "r8", "r9", "r10", "r11", "lr"
+#else
+#define REG_FP "fp"
+#define REG_FPFP "fp,fp"
+#define REGS_TO_SAVE_GENERAL "r4", "r5", "r6", "r7", "r8", "r9", "r10", "lr"
+#endif
+#if defined(__SOFTFP__)
+#define REGS_TO_SAVE REGS_TO_SAVE_GENERAL
+#elif defined(__VFP_FP__)
+#define REGS_TO_SAVE REGS_TO_SAVE_GENERAL, "d8", "d9", "d10", "d11", \
+ "d12", "d13", "d14", "d15"
+#elif defined(__MAVERICK__)
+#define REGS_TO_SAVE REGS_TO_SAVE_GENERAL, "mvf4", "mvf5", "mvf6", "mvf7", \
+ "mvf8", "mvf9", "mvf10", "mvf11", \
+ "mvf12", "mvf13", "mvf14", "mvf15"
+#else
+#define REGS_TO_SAVE REGS_TO_SAVE_GENERAL, "f4", "f5", "f6", "f7"
+#endif
+
+static int
+#ifdef __GNUC__
+__attribute__((optimize("no-omit-frame-pointer")))
+#endif
+slp_switch(void)
+{
+ void *fp;
+ register int *stackref, stsizediff;
+ int result;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("mov r0," REG_FP "\n\tstr r0,%0" : "=m" (fp) : : "r0");
+ __asm__ ("mov %0," REG_SP : "=r" (stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "add " REG_SPSP ",%0\n"
+ "add " REG_FPFP ",%0\n"
+ :
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("ldr r0,%1\n\tmov " REG_FP ",r0\n\tmov %0, #0" : "=r" (result) : "m" (fp) : "r0");
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ return result;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_arm32_ios.h b/lib/greenlet/platform/switch_arm32_ios.h
new file mode 100644
index 0000000..e993707
--- /dev/null
+++ b/lib/greenlet/platform/switch_arm32_ios.h
@@ -0,0 +1,67 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 31-May-15 iOS support. Ported from arm32. Proton <feisuzhu@163.com>
+ *
+ * NOTES
+ *
+ * It is not possible to detect whether fp is used or not, so the supplied
+ * switch function handles it; you can remove that handling if
+ * it does not apply to you.
+ *
+ * POSSIBLE ERRORS
+ *
+ * "fp cannot be used in asm here"
+ *
+ * - Try commenting out "fp" in REGS_TO_SAVE.
+ *
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 0
+#define REG_SP "sp"
+#define REG_SPSP "sp,sp"
+#define REG_FP "r7"
+#define REG_FPFP "r7,r7"
+#define REGS_TO_SAVE_GENERAL "r4", "r5", "r6", "r8", "r10", "r11", "lr"
+#define REGS_TO_SAVE REGS_TO_SAVE_GENERAL, "d8", "d9", "d10", "d11", \
+ "d12", "d13", "d14", "d15"
+
+static int
+#ifdef __GNUC__
+__attribute__((optimize("no-omit-frame-pointer")))
+#endif
+slp_switch(void)
+{
+ void *fp;
+ register int *stackref, stsizediff, result;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("str " REG_FP ",%0" : "=m" (fp));
+ __asm__ ("mov %0," REG_SP : "=r" (stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "add " REG_SPSP ",%0\n"
+ "add " REG_FPFP ",%0\n"
+ :
+ : "r" (stsizediff)
+ : REGS_TO_SAVE /* Clobber registers, force compiler to
+ * recalculate address of void *fp from REG_SP or REG_FP */
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile (
+ "ldr " REG_FP ", %1\n\t"
+ "mov %0, #0"
+ : "=r" (result)
+ : "m" (fp)
+ : REGS_TO_SAVE /* Force compiler to restore saved registers after this */
+ );
+ return result;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_csky_gcc.h b/lib/greenlet/platform/switch_csky_gcc.h
new file mode 100644
index 0000000..7486b94
--- /dev/null
+++ b/lib/greenlet/platform/switch_csky_gcc.h
@@ -0,0 +1,48 @@
+#ifdef SLP_EVAL
+#define STACK_MAGIC 0
+#define REG_FP "r8"
+#ifdef __CSKYABIV2__
+#define REGS_TO_SAVE_GENERAL "r4", "r5", "r6", "r7", "r9", "r10", "r11", "r15",\
+ "r16", "r17", "r18", "r19", "r20", "r21", "r22",\
+ "r23", "r24", "r25"
+
+#if defined (__CSKY_HARD_FLOAT__) || (__CSKY_VDSP__)
+#define REGS_TO_SAVE REGS_TO_SAVE_GENERAL, "vr8", "vr9", "vr10", "vr11", "vr12",\
+ "vr13", "vr14", "vr15"
+#else
+#define REGS_TO_SAVE REGS_TO_SAVE_GENERAL
+#endif
+#else
+#define REGS_TO_SAVE "r9", "r10", "r11", "r12", "r13", "r15"
+#endif
+
+
+static int
+#ifdef __GNUC__
+__attribute__((optimize("no-omit-frame-pointer")))
+#endif
+slp_switch(void)
+{
+ register int *stackref, stsizediff;
+ int result;
+
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ ("mov %0, sp" : "=r" (stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "addu sp,%0\n"
+ "addu "REG_FP",%0\n"
+ :
+ : "r" (stsizediff)
+ );
+
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("movi %0, 0" : "=r" (result));
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+
+ return result;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_m68k_gcc.h b/lib/greenlet/platform/switch_m68k_gcc.h
new file mode 100644
index 0000000..da761c2
--- /dev/null
+++ b/lib/greenlet/platform/switch_m68k_gcc.h
@@ -0,0 +1,38 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 2014-01-06 Andreas Schwab <schwab@linux-m68k.org>
+ * File created.
+ */
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 0
+
+#define REGS_TO_SAVE "%d2", "%d3", "%d4", "%d5", "%d6", "%d7", \
+ "%a2", "%a3", "%a4"
+
+static int
+slp_switch(void)
+{
+ int err;
+ int *stackref, stsizediff;
+ void *fp, *a5;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("move.l %%fp, %0" : "=m"(fp));
+ __asm__ volatile ("move.l %%a5, %0" : "=m"(a5));
+ __asm__ ("move.l %%sp, %0" : "=r"(stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile ("add.l %0, %%sp; add.l %0, %%fp" : : "r"(stsizediff));
+ SLP_RESTORE_STATE();
+ __asm__ volatile ("clr.l %0" : "=g" (err));
+ }
+ __asm__ volatile ("move.l %0, %%a5" : : "m"(a5));
+ __asm__ volatile ("move.l %0, %%fp" : : "m"(fp));
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ return err;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_mips_unix.h b/lib/greenlet/platform/switch_mips_unix.h
new file mode 100644
index 0000000..1916b26
--- /dev/null
+++ b/lib/greenlet/platform/switch_mips_unix.h
@@ -0,0 +1,64 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 20-Sep-14 Matt Madison <madison@bliss-m.org>
+ * Re-code the saving of the gp register for MIPS64.
+ * 05-Jan-08 Thiemo Seufer <ths@debian.org>
+ * Ported from ppc.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 0
+
+#define REGS_TO_SAVE "$16", "$17", "$18", "$19", "$20", "$21", "$22", \
+ "$23", "$30"
+static int
+slp_switch(void)
+{
+ register int err;
+ register int *stackref, stsizediff;
+#ifdef __mips64
+ uint64_t gpsave;
+#endif
+ __asm__ __volatile__ ("" : : : REGS_TO_SAVE);
+#ifdef __mips64
+ __asm__ __volatile__ ("sd $28,%0" : "=m" (gpsave) : : );
+#endif
+ __asm__ ("move %0, $29" : "=r" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ __volatile__ (
+#ifdef __mips64
+ "daddu $29, %0\n"
+#else
+ "addu $29, %0\n"
+#endif
+ : /* no outputs */
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ }
+#ifdef __mips64
+ __asm__ __volatile__ ("ld $28,%0" : : "m" (gpsave) : );
+#endif
+ __asm__ __volatile__ ("" : : : REGS_TO_SAVE);
+ __asm__ __volatile__ ("move %0, $0" : "=r" (err));
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_ppc64_aix.h b/lib/greenlet/platform/switch_ppc64_aix.h
new file mode 100644
index 0000000..e07b8de
--- /dev/null
+++ b/lib/greenlet/platform/switch_ppc64_aix.h
@@ -0,0 +1,103 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 16-Oct-20 Jesse Gorzinski <jgorzins@us.ibm.com>
+ * Copied from Linux PPC64 implementation
+ * 04-Sep-18 Alexey Borzenkov <snaury@gmail.com>
+ * Workaround a gcc bug using manual save/restore of r30
+ * 21-Mar-18 Tulio Magno Quites Machado Filho <tuliom@linux.vnet.ibm.com>
+ * Added r30 to the list of saved registers in order to fully comply with
+ * both ppc64 ELFv1 ABI and the ppc64le ELFv2 ABI, that classify this
+ * register as a nonvolatile register used for local variables.
+ * 21-Mar-18 Laszlo Boszormenyi <gcs@debian.org>
+ * Save r2 (TOC pointer) manually.
+ * 10-Dec-13 Ulrich Weigand <uweigand@de.ibm.com>
+ * Support ELFv2 ABI. Save float/vector registers.
+ * 09-Mar-12 Michael Ellerman <michael@ellerman.id.au>
+ * 64-bit implementation, copied from 32-bit.
+ * 07-Sep-05 (py-dev mailing list discussion)
+ * removed 'r31' from the saved registers. !!!! WARNING !!!!
+ * It means that this file can no longer be compiled statically!
+ * It is now only suitable as part of a dynamic library!
+ * 14-Jan-04 Bob Ippolito <bob@redivi.com>
+ * added cr2-cr4 to the registers to be saved.
+ * Open questions: Should we save FP registers?
+ * What about vector registers?
+ * Differences between darwin and unix?
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 04-Oct-02 Gustavo Niemeyer <niemeyer@conectiva.com>
+ * Ported from MacOS version.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 29-Jun-02 Christian Tismer <tismer@tismer.com>
+ * Added register 13-29, 31 saves. The same way as
+ * Armin Rigo did for the x86_unix version.
+ * This seems to be now fully functional!
+ * 04-Mar-02 Hye-Shik Chang <perky@fallin.lv>
+ * Ported from i386.
+ * 31-Jul-12 Trevor Bowen <trevorbowen@gmail.com>
+ * Changed memory constraints to register only.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 6
+
+#if defined(__ALTIVEC__)
+#define ALTIVEC_REGS \
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", \
+ "v28", "v29", "v30", "v31",
+#else
+#define ALTIVEC_REGS
+#endif
+
+#define REGS_TO_SAVE "r14", "r15", "r16", "r17", "r18", "r19", "r20", \
+ "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29", \
+ "r31", \
+ "fr14", "fr15", "fr16", "fr17", "fr18", "fr19", "fr20", "fr21", \
+ "fr22", "fr23", "fr24", "fr25", "fr26", "fr27", "fr28", "fr29", \
+ "fr30", "fr31", \
+ ALTIVEC_REGS \
+ "cr2", "cr3", "cr4"
+
+static int
+slp_switch(void)
+{
+ register int err;
+ register long *stackref, stsizediff;
+ void * toc;
+ void * r30;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("std 2, %0" : "=m" (toc));
+ __asm__ volatile ("std 30, %0" : "=m" (r30));
+ __asm__ ("mr %0, 1" : "=r" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "mr 11, %0\n"
+ "add 1, 1, 11\n"
+ : /* no outputs */
+ : "r" (stsizediff)
+ : "11"
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("ld 30, %0" : : "m" (r30));
+ __asm__ volatile ("ld 2, %0" : : "m" (toc));
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("li %0, 0" : "=r" (err));
+ return err;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_ppc64_linux.h b/lib/greenlet/platform/switch_ppc64_linux.h
new file mode 100644
index 0000000..88e6847
--- /dev/null
+++ b/lib/greenlet/platform/switch_ppc64_linux.h
@@ -0,0 +1,105 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 04-Sep-18 Alexey Borzenkov <snaury@gmail.com>
+ * Workaround a gcc bug using manual save/restore of r30
+ * 21-Mar-18 Tulio Magno Quites Machado Filho <tuliom@linux.vnet.ibm.com>
+ * Added r30 to the list of saved registers in order to fully comply with
+ * both ppc64 ELFv1 ABI and the ppc64le ELFv2 ABI, that classify this
+ * register as a nonvolatile register used for local variables.
+ * 21-Mar-18 Laszlo Boszormenyi <gcs@debian.org>
+ * Save r2 (TOC pointer) manually.
+ * 10-Dec-13 Ulrich Weigand <uweigand@de.ibm.com>
+ * Support ELFv2 ABI. Save float/vector registers.
+ * 09-Mar-12 Michael Ellerman <michael@ellerman.id.au>
+ * 64-bit implementation, copied from 32-bit.
+ * 07-Sep-05 (py-dev mailing list discussion)
+ * removed 'r31' from the saved registers. !!!! WARNING !!!!
+ * It means that this file can no longer be compiled statically!
+ * It is now only suitable as part of a dynamic library!
+ * 14-Jan-04 Bob Ippolito <bob@redivi.com>
+ * added cr2-cr4 to the registers to be saved.
+ * Open questions: Should we save FP registers?
+ * What about vector registers?
+ * Differences between darwin and unix?
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 04-Oct-02 Gustavo Niemeyer <niemeyer@conectiva.com>
+ * Ported from MacOS version.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 29-Jun-02 Christian Tismer <tismer@tismer.com>
+ * Added register 13-29, 31 saves. The same way as
+ * Armin Rigo did for the x86_unix version.
+ * This seems to be now fully functional!
+ * 04-Mar-02 Hye-Shik Chang <perky@fallin.lv>
+ * Ported from i386.
+ * 31-Jul-12 Trevor Bowen <trevorbowen@gmail.com>
+ * Changed memory constraints to register only.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#if _CALL_ELF == 2
+#define STACK_MAGIC 4
+#else
+#define STACK_MAGIC 6
+#endif
+
+#if defined(__ALTIVEC__)
+#define ALTIVEC_REGS \
+ "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", \
+ "v28", "v29", "v30", "v31",
+#else
+#define ALTIVEC_REGS
+#endif
+
+#define REGS_TO_SAVE "r14", "r15", "r16", "r17", "r18", "r19", "r20", \
+ "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29", \
+ "r31", \
+ "fr14", "fr15", "fr16", "fr17", "fr18", "fr19", "fr20", "fr21", \
+ "fr22", "fr23", "fr24", "fr25", "fr26", "fr27", "fr28", "fr29", \
+ "fr30", "fr31", \
+ ALTIVEC_REGS \
+ "cr2", "cr3", "cr4"
+
+static int
+slp_switch(void)
+{
+ register int err;
+ register long *stackref, stsizediff;
+ void * toc;
+ void * r30;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("std 2, %0" : "=m" (toc));
+ __asm__ volatile ("std 30, %0" : "=m" (r30));
+ __asm__ ("mr %0, 1" : "=r" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "mr 11, %0\n"
+ "add 1, 1, 11\n"
+ : /* no outputs */
+ : "r" (stsizediff)
+ : "11"
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("ld 30, %0" : : "m" (r30));
+ __asm__ volatile ("ld 2, %0" : : "m" (toc));
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("li %0, 0" : "=r" (err));
+ return err;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_ppc_aix.h b/lib/greenlet/platform/switch_ppc_aix.h
new file mode 100644
index 0000000..c7d476f
--- /dev/null
+++ b/lib/greenlet/platform/switch_ppc_aix.h
@@ -0,0 +1,87 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 07-Mar-11 Floris Bruynooghe <flub@devork.be>
+ * Do not add stsizediff to general purpose
+ * register (GPR) 30 as this is a non-volatile and
+ * unused by the PowerOpen Environment, therefore
+ * this was modifying a user register instead of the
+ * frame pointer (which does not seem to exist).
+ * 07-Sep-05 (py-dev mailing list discussion)
+ * removed 'r31' from the saved registers. !!!! WARNING !!!!
+ * It means that this file can no longer be compiled statically!
+ * It is now only suitable as part of a dynamic library!
+ * 14-Jan-04 Bob Ippolito <bob@redivi.com>
+ * added cr2-cr4 to the registers to be saved.
+ * Open questions: Should we save FP registers?
+ * What about vector registers?
+ * Differences between darwin and unix?
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 04-Oct-02 Gustavo Niemeyer <niemeyer@conectiva.com>
+ * Ported from MacOS version.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 29-Jun-02 Christian Tismer <tismer@tismer.com>
+ * Added register 13-29, 31 saves. The same way as
+ * Armin Rigo did for the x86_unix version.
+ * This seems to be now fully functional!
+ * 04-Mar-02 Hye-Shik Chang <perky@fallin.lv>
+ * Ported from i386.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 3
+
+/* !!!!WARNING!!!! need to add "r31" in the next line if this header file
+ * is meant to be compiled non-dynamically!
+ */
+#define REGS_TO_SAVE "r13", "r14", "r15", "r16", "r17", "r18", "r19", "r20", \
+ "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29", \
+ "cr2", "cr3", "cr4"
+static int
+slp_switch(void)
+{
+ register int err;
+ register int *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ ("mr %0, 1" : "=r" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "mr 11, %0\n"
+ "add 1, 1, 11\n"
+ : /* no outputs */
+ : "r" (stsizediff)
+ : "11"
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("li %0, 0" : "=r" (err));
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_ppc_linux.h b/lib/greenlet/platform/switch_ppc_linux.h
new file mode 100644
index 0000000..0a71255
--- /dev/null
+++ b/lib/greenlet/platform/switch_ppc_linux.h
@@ -0,0 +1,84 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 07-Sep-05 (py-dev mailing list discussion)
+ * removed 'r31' from the saved registers. !!!! WARNING !!!!
+ * It means that this file can no longer be compiled statically!
+ * It is now only suitable as part of a dynamic library!
+ * 14-Jan-04 Bob Ippolito <bob@redivi.com>
+ * added cr2-cr4 to the registers to be saved.
+ * Open questions: Should we save FP registers?
+ * What about vector registers?
+ * Differences between darwin and unix?
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 04-Oct-02 Gustavo Niemeyer <niemeyer@conectiva.com>
+ * Ported from MacOS version.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 29-Jun-02 Christian Tismer <tismer@tismer.com>
+ * Added register 13-29, 31 saves. The same way as
+ * Armin Rigo did for the x86_unix version.
+ * This seems to be now fully functional!
+ * 04-Mar-02 Hye-Shik Chang <perky@fallin.lv>
+ * Ported from i386.
+ * 31-Jul-12 Trevor Bowen <trevorbowen@gmail.com>
+ * Changed memory constraints to register only.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 3
+
+/* !!!!WARNING!!!! need to add "r31" in the next line if this header file
+ * is meant to be compiled non-dynamically!
+ */
+#define REGS_TO_SAVE "r13", "r14", "r15", "r16", "r17", "r18", "r19", "r20", \
+ "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29", \
+ "cr2", "cr3", "cr4"
+static int
+slp_switch(void)
+{
+ register int err;
+ register int *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ ("mr %0, 1" : "=r" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "mr 11, %0\n"
+ "add 1, 1, 11\n"
+ "add 30, 30, 11\n"
+ : /* no outputs */
+ : "r" (stsizediff)
+ : "11"
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("li %0, 0" : "=r" (err));
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_ppc_macosx.h b/lib/greenlet/platform/switch_ppc_macosx.h
new file mode 100644
index 0000000..56e573f
--- /dev/null
+++ b/lib/greenlet/platform/switch_ppc_macosx.h
@@ -0,0 +1,82 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 07-Sep-05 (py-dev mailing list discussion)
+ * removed 'r31' from the saved registers. !!!! WARNING !!!!
+ * It means that this file can no longer be compiled statically!
+ * It is now only suitable as part of a dynamic library!
+ * 14-Jan-04 Bob Ippolito <bob@redivi.com>
+ * added cr2-cr4 to the registers to be saved.
+ * Open questions: Should we save FP registers?
+ * What about vector registers?
+ * Differences between darwin and unix?
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 29-Jun-02 Christian Tismer <tismer@tismer.com>
+ * Added register 13-29, 31 saves. The same way as
+ * Armin Rigo did for the x86_unix version.
+ * This seems to be now fully functional!
+ * 04-Mar-02 Hye-Shik Chang <perky@fallin.lv>
+ * Ported from i386.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 3
+
+/* !!!!WARNING!!!! need to add "r31" in the next line if this header file
+ * is meant to be compiled non-dynamically!
+ */
+#define REGS_TO_SAVE "r13", "r14", "r15", "r16", "r17", "r18", "r19", "r20", \
+ "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29", \
+ "cr2", "cr3", "cr4"
+
+static int
+slp_switch(void)
+{
+ register int err;
+ register int *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ ("; asm block 2\n\tmr %0, r1" : "=g" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "; asm block 3\n"
+ "\tmr r11, %0\n"
+ "\tadd r1, r1, r11\n"
+ "\tadd r30, r30, r11\n"
+ : /* no outputs */
+ : "g" (stsizediff)
+ : "r11"
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("li %0, 0" : "=r" (err));
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_ppc_unix.h b/lib/greenlet/platform/switch_ppc_unix.h
new file mode 100644
index 0000000..2b3d307
--- /dev/null
+++ b/lib/greenlet/platform/switch_ppc_unix.h
@@ -0,0 +1,82 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 07-Sep-05 (py-dev mailing list discussion)
+ * removed 'r31' from the saved registers. !!!! WARNING !!!!
+ * It means that this file can no longer be compiled statically!
+ * It is now only suitable as part of a dynamic library!
+ * 14-Jan-04 Bob Ippolito <bob@redivi.com>
+ * added cr2-cr4 to the registers to be saved.
+ * Open questions: Should we save FP registers?
+ * What about vector registers?
+ * Differences between darwin and unix?
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 04-Oct-02 Gustavo Niemeyer <niemeyer@conectiva.com>
+ * Ported from MacOS version.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 29-Jun-02 Christian Tismer <tismer@tismer.com>
+ * Added register 13-29, 31 saves. The same way as
+ * Armin Rigo did for the x86_unix version.
+ * This seems to be now fully functional!
+ * 04-Mar-02 Hye-Shik Chang <perky@fallin.lv>
+ * Ported from i386.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 3
+
+/* !!!!WARNING!!!! need to add "r31" in the next line if this header file
+ * is meant to be compiled non-dynamically!
+ */
+#define REGS_TO_SAVE "r13", "r14", "r15", "r16", "r17", "r18", "r19", "r20", \
+ "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29", \
+ "cr2", "cr3", "cr4"
+static int
+slp_switch(void)
+{
+ register int err;
+ register int *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ ("mr %0, 1" : "=g" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "mr 11, %0\n"
+ "add 1, 1, 11\n"
+ "add 30, 30, 11\n"
+ : /* no outputs */
+ : "g" (stsizediff)
+ : "11"
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("li %0, 0" : "=r" (err));
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_riscv_unix.h b/lib/greenlet/platform/switch_riscv_unix.h
new file mode 100644
index 0000000..5b5ea98
--- /dev/null
+++ b/lib/greenlet/platform/switch_riscv_unix.h
@@ -0,0 +1,32 @@
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+#define STACK_MAGIC 0
+
+#define REGS_TO_SAVE "s0", "s1", "s2", "s3", "s4", "s5", \
+ "s6", "s7", "s8", "s9", "s10", "s11", "fs0", "fs1", \
+ "fs2", "fs3", "fs4", "fs5", "fs6", "fs7", "fs8", "fs9", \
+ "fs10", "fs11"
+
+static int
+slp_switch(void)
+{
+ register int ret;
+ register long *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("mv %0, sp" : "=r" (stackref) : );
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "add sp, sp, %0\n\t"
+ : /* no outputs */
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("mv %0, zero" : "=r" (ret) : );
+ return ret;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_s390_unix.h b/lib/greenlet/platform/switch_s390_unix.h
new file mode 100644
index 0000000..6641854
--- /dev/null
+++ b/lib/greenlet/platform/switch_s390_unix.h
@@ -0,0 +1,87 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 25-Jan-12 Alexey Borzenkov <snaury@gmail.com>
+ * Fixed Linux/S390 port to work correctly with
+ * different optimization options both on 31-bit
+ * and 64-bit. Thanks to Stefan Raabe for lots
+ * of testing.
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 06-Oct-02 Gustavo Niemeyer <niemeyer@conectiva.com>
+ * Ported to Linux/S390.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#ifdef __s390x__
+#define STACK_MAGIC 20 /* 20 * 8 = 160 bytes of function call area */
+#else
+#define STACK_MAGIC 24 /* 24 * 4 = 96 bytes of function call area */
+#endif
+
+/* Technically, r11-r13 also need saving, but function prolog starts
+ with stm(g) and since there are so many saved registers already
+ it won't be optimized, resulting in all r6-r15 being saved */
+#define REGS_TO_SAVE "r6", "r7", "r8", "r9", "r10", "r14", \
+ "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", \
+ "f8", "f9", "f10", "f11", "f12", "f13", "f14", "f15"
+
+static int
+slp_switch(void)
+{
+ register int ret;
+ register long *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+#ifdef __s390x__
+ __asm__ volatile ("lgr %0, 15" : "=r" (stackref) : );
+#else
+ __asm__ volatile ("lr %0, 15" : "=r" (stackref) : );
+#endif
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+/* N.B.
+ r11 may be used as the frame pointer, and in that case it cannot be
+ clobbered and needs offsetting just like the stack pointer (but in cases
+ where frame pointer isn't used we might clobber it accidentally). What's
+ scary is that r11 is 2nd (and even 1st when GOT is used) callee saved
+ register that gcc would choose for surviving function calls. However,
+ since r6-r10 are clobbered above, their cost for reuse is reduced, so
+ gcc IRA will choose them over r11 (not seeing r11 is implicitly saved),
+ making it relatively safe to offset in all cases. :) */
+ __asm__ volatile (
+#ifdef __s390x__
+ "agr 15, %0\n\t"
+ "agr 11, %0"
+#else
+ "ar 15, %0\n\t"
+ "ar 11, %0"
+#endif
+ : /* no outputs */
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("lhi %0, 0" : "=r" (ret) : );
+ return ret;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_sparc_sun_gcc.h b/lib/greenlet/platform/switch_sparc_sun_gcc.h
new file mode 100644
index 0000000..652b57f
--- /dev/null
+++ b/lib/greenlet/platform/switch_sparc_sun_gcc.h
@@ -0,0 +1,92 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 16-May-15 Alexey Borzenkov <snaury@gmail.com>
+ * Move stack spilling code inside save/restore functions
+ * 30-Aug-13 Floris Bruynooghe <flub@devork.be>
+ * Clean the register windows again before returning.
+ * This does not clobber the PIC register as it leaves
+ * the current window intact and is required for multi-
+ * threaded code to work correctly.
+ * 08-Mar-11 Floris Bruynooghe <flub@devork.be>
+ * No need to set return value register explicitly
+ * before the stack and framepointer are adjusted
+ * as none of the other registers are influenced by
+ * this. Also don't needlessly clean the windows
+ * ('ta %0" :: "i" (ST_CLEAN_WINDOWS)') as that
+ * clobbers the gcc PIC register (%l7).
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * added support for SunOS sparc with gcc
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+
+#define STACK_MAGIC 0
+
+
+#if defined(__sparcv9)
+#define SLP_FLUSHW __asm__ volatile ("flushw")
+#else
+#define SLP_FLUSHW __asm__ volatile ("ta 3") /* ST_FLUSH_WINDOWS */
+#endif
+
+/* On sparc we need to spill register windows inside save/restore functions */
+#define SLP_BEFORE_SAVE_STATE() SLP_FLUSHW
+#define SLP_BEFORE_RESTORE_STATE() SLP_FLUSHW
+
+
+static int
+slp_switch(void)
+{
+ register int err;
+ register int *stackref, stsizediff;
+
+ /* Put current stack pointer into stackref.
+ * Register spilling is done in save/restore.
+ */
+ __asm__ volatile ("mov %%sp, %0" : "=r" (stackref));
+
+ {
+ /* Thou shalt put SLP_SAVE_STATE into a local block */
+ /* Copy the current stack onto the heap */
+ SLP_SAVE_STATE(stackref, stsizediff);
+
+ /* Increment stack and frame pointer by stsizediff */
+ __asm__ volatile (
+ "add %0, %%sp, %%sp\n\t"
+ "add %0, %%fp, %%fp"
+ : : "r" (stsizediff));
+
+ /* Copy the new stack back from its saved copy on the heap */
+ SLP_RESTORE_STATE();
+
+ __asm__ volatile ("mov %1, %0" : "=r" (err) : "i" (0));
+ return err;
+ }
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_x32_unix.h b/lib/greenlet/platform/switch_x32_unix.h
new file mode 100644
index 0000000..cb14ec1
--- /dev/null
+++ b/lib/greenlet/platform/switch_x32_unix.h
@@ -0,0 +1,63 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 17-Aug-12 Fantix King <fantix.king@gmail.com>
+ * Ported from amd64.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 0
+
+#define REGS_TO_SAVE "r12", "r13", "r14", "r15"
+
+
+static int
+slp_switch(void)
+{
+ void* ebp;
+ void* ebx;
+ unsigned int csr;
+ unsigned short cw;
+ register int err;
+ register int *stackref, stsizediff;
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("fstcw %0" : "=m" (cw));
+ __asm__ volatile ("stmxcsr %0" : "=m" (csr));
+ __asm__ volatile ("movl %%ebp, %0" : "=m" (ebp));
+ __asm__ volatile ("movl %%ebx, %0" : "=m" (ebx));
+ __asm__ ("movl %%esp, %0" : "=g" (stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "addl %0, %%esp\n"
+ "addl %0, %%ebp\n"
+ :
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ }
+ __asm__ volatile ("movl %0, %%ebx" : : "m" (ebx));
+ __asm__ volatile ("movl %0, %%ebp" : : "m" (ebp));
+ __asm__ volatile ("ldmxcsr %0" : : "m" (csr));
+ __asm__ volatile ("fldcw %0" : : "m" (cw));
+ __asm__ volatile ("" : : : REGS_TO_SAVE);
+ __asm__ volatile ("xorl %%eax, %%eax" : "=a" (err));
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental and not
+ * essential yet.
+ */
diff --git a/lib/greenlet/platform/switch_x64_masm.asm b/lib/greenlet/platform/switch_x64_masm.asm
new file mode 100644
index 0000000..f5c72a2
--- /dev/null
+++ b/lib/greenlet/platform/switch_x64_masm.asm
@@ -0,0 +1,111 @@
+;
+; stack switching code for MASM on x64
+; Kristjan Valur Jonsson, sept 2005
+;
+
+
+;prototypes for our calls
+slp_save_state_asm PROTO
+slp_restore_state_asm PROTO
+
+
+pushxmm MACRO reg
+ sub rsp, 16
+ .allocstack 16
+ movaps [rsp], reg ; faster than movups, but we must be aligned
+ ; .savexmm128 reg, offset (don't know what offset is, no documentation)
+ENDM
+popxmm MACRO reg
+ movaps reg, [rsp] ; faster than movups, but we must be aligned
+ add rsp, 16
+ENDM
+
+pushreg MACRO reg
+ push reg
+ .pushreg reg
+ENDM
+popreg MACRO reg
+ pop reg
+ENDM
+
+
+.code
+slp_switch PROC FRAME
+ ;realign stack to 16 bytes after return address push, makes the following faster
+ sub rsp,8
+ .allocstack 8
+
+ pushxmm xmm15
+ pushxmm xmm14
+ pushxmm xmm13
+ pushxmm xmm12
+ pushxmm xmm11
+ pushxmm xmm10
+ pushxmm xmm9
+ pushxmm xmm8
+ pushxmm xmm7
+ pushxmm xmm6
+
+ pushreg r15
+ pushreg r14
+ pushreg r13
+ pushreg r12
+
+ pushreg rbp
+ pushreg rbx
+ pushreg rdi
+ pushreg rsi
+
+ sub rsp, 10h ;allocate the single function argument (must be multiple of 16)
+ .allocstack 10h
+.endprolog
+
+ lea rcx, [rsp+10h] ;load stack base that we are saving
+ call slp_save_state_asm ;pass stackpointer, return offset in eax
+ cmp rax, 1
+ je EXIT1
+ cmp rax, -1
+ je EXIT2
+ ;actual stack switch:
+ add rsp, rax
+ call slp_restore_state_asm
+ xor rax, rax ;return 0
+
+EXIT:
+
+ add rsp, 10h
+ popreg rsi
+ popreg rdi
+ popreg rbx
+ popreg rbp
+
+ popreg r12
+ popreg r13
+ popreg r14
+ popreg r15
+
+ popxmm xmm6
+ popxmm xmm7
+ popxmm xmm8
+ popxmm xmm9
+ popxmm xmm10
+ popxmm xmm11
+ popxmm xmm12
+ popxmm xmm13
+ popxmm xmm14
+ popxmm xmm15
+
+ add rsp, 8
+ ret
+
+EXIT1:
+ mov rax, 1
+ jmp EXIT
+
+EXIT2:
+ sar rax, 1
+ jmp EXIT
+
+slp_switch ENDP
+
+END \ No newline at end of file
diff --git a/lib/greenlet/platform/switch_x64_masm.obj b/lib/greenlet/platform/switch_x64_masm.obj
new file mode 100644
index 0000000..64e3e6b
--- /dev/null
+++ b/lib/greenlet/platform/switch_x64_masm.obj
Binary files differ
diff --git a/lib/greenlet/platform/switch_x64_msvc.h b/lib/greenlet/platform/switch_x64_msvc.h
new file mode 100644
index 0000000..601ea56
--- /dev/null
+++ b/lib/greenlet/platform/switch_x64_msvc.h
@@ -0,0 +1,60 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 26-Sep-02 Christian Tismer <tismer@tismer.com>
+ * again as a result of virtualized stack access,
+ * the compiler used fewer registers. Needed to
+ * explicitly mention registers in order to get them saved.
+ * Thanks to Jeff Senn for pointing this out and helping.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 01-Mar-02 Christian Tismer <tismer@tismer.com>
+ * Initial final version after lots of iterations for i386.
+ */
+
+/* Avoid alloca redefined warning on mingw64 */
+#ifndef alloca
+#define alloca _alloca
+#endif
+
+#define STACK_REFPLUS 1
+#define STACK_MAGIC 0
+
+/* Use the generic support for an external assembly language slp_switch function. */
+#define EXTERNAL_ASM
+
+#ifdef SLP_EVAL
+/* This always uses the external masm assembly file. */
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/* we have IsBadReadPtr available, so we can peek at objects */
+/*
+#define STACKLESS_SPY
+
+#ifdef IMPLEMENT_STACKLESSMODULE
+#include "Windows.h"
+#define CANNOT_READ_MEM(p, bytes) IsBadReadPtr(p, bytes)
+
+static int IS_ON_STACK(void*p)
+{
+ int stackref;
+ intptr_t stackbase = ((intptr_t)&stackref) & 0xfffff000;
+ return (intptr_t)p >= stackbase && (intptr_t)p < stackbase + 0x00100000;
+}
+
+#endif
+*/ \ No newline at end of file
diff --git a/lib/greenlet/platform/switch_x86_msvc.h b/lib/greenlet/platform/switch_x86_msvc.h
new file mode 100644
index 0000000..010a22c
--- /dev/null
+++ b/lib/greenlet/platform/switch_x86_msvc.h
@@ -0,0 +1,88 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 26-Sep-02 Christian Tismer <tismer@tismer.com>
+ * again as a result of virtualized stack access,
+ * the compiler used fewer registers. Needed to
+ * explicitly mention registers in order to get them saved.
+ * Thanks to Jeff Senn for pointing this out and helping.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 01-Mar-02 Christian Tismer <tismer@tismer.com>
+ * Initial final version after lots of iterations for i386.
+ */
+
+#define alloca _alloca
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+#define STACK_MAGIC 0
+
+/* Some magic to quell warnings and keep slp_switch() from crashing when built
+ with VC90. Disable global optimizations, and the warning: frame pointer
+ register 'ebp' modified by inline assembly code */
+#pragma optimize("g", off)
+#pragma warning(disable:4731)
+
+static int
+slp_switch(void)
+{
+ void* seh;
+ register int *stackref, stsizediff;
+ __asm mov eax, fs:[0]
+ __asm mov [seh], eax
+ __asm mov stackref, esp;
+ /* modify EBX, ESI and EDI in order to get them preserved */
+ __asm mov ebx, ebx;
+ __asm xchg esi, edi;
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm {
+ mov eax, stsizediff
+ add esp, eax
+ add ebp, eax
+ }
+ SLP_RESTORE_STATE();
+ }
+ __asm mov eax, [seh]
+ __asm mov fs:[0], eax
+ return 0;
+}
+
+/* re-enable ebp warning and global optimizations. */
+#pragma optimize("g", on)
+#pragma warning(default:4731)
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/* we have IsBadReadPtr available, so we can peek at objects */
+#define STACKLESS_SPY
+
+#ifdef IMPLEMENT_STACKLESSMODULE
+#include "Windows.h"
+#define CANNOT_READ_MEM(p, bytes) IsBadReadPtr(p, bytes)
+
+static int IS_ON_STACK(void*p)
+{
+ int stackref;
+ int stackbase = ((int)&stackref) & 0xfffff000;
+ return (int)p >= stackbase && (int)p < stackbase + 0x00100000;
+}
+
+#endif
diff --git a/lib/greenlet/platform/switch_x86_unix.h b/lib/greenlet/platform/switch_x86_unix.h
new file mode 100644
index 0000000..3a95186
--- /dev/null
+++ b/lib/greenlet/platform/switch_x86_unix.h
@@ -0,0 +1,105 @@
+/*
+ * this is the internal transfer function.
+ *
+ * HISTORY
+ * 3-May-13 Ralf Schmitt <ralf@systemexit.de>
+ * Add support for strange GCC caller-save decisions
+ * (ported from switch_aarch64_gcc.h)
+ * 19-Aug-11 Alexey Borzenkov <snaury@gmail.com>
+ * Correctly save ebp, ebx and cw
+ * 07-Sep-05 (py-dev mailing list discussion)
+ * removed 'ebx' from the saved registers. !!!! WARNING !!!!
+ * It means that this file can no longer be compiled statically!
+ * It is now only suitable as part of a dynamic library!
+ * 24-Nov-02 Christian Tismer <tismer@tismer.com>
+ * needed to add another magic constant to ensure
+ * that f in slp_eval_frame(PyFrameObject *f)
+ * gets included into the saved stack area.
+ * STACK_REFPLUS will probably be 1 in most cases.
+ * 17-Sep-02 Christian Tismer <tismer@tismer.com>
+ * after virtualizing stack save/restore, the
+ * stack size shrunk a bit. Needed to introduce
+ * an adjustment STACK_MAGIC per platform.
+ * 15-Sep-02 Gerd Woetzel <gerd.woetzel@GMD.DE>
+ * slightly changed framework for sparc
+ * 31-Avr-02 Armin Rigo <arigo@ulb.ac.be>
+ * Added ebx, esi and edi register-saves.
+ * 01-Mar-02 Samual M. Rushing <rushing@ironport.com>
+ * Ported from i386.
+ */
+
+#define STACK_REFPLUS 1
+
+#ifdef SLP_EVAL
+
+/* #define STACK_MAGIC 3 */
+/* the above works fine with gcc 2.96, but 2.95.3 wants this */
+#define STACK_MAGIC 0
+
+#if __GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 5)
+# define ATTR_NOCLONE __attribute__((noclone))
+#else
+# define ATTR_NOCLONE
+#endif
+
+static int
+slp_switch(void)
+{
+ int err;
+#ifdef _WIN32
+ void *seh;
+#endif
+ void *ebp, *ebx;
+ unsigned short cw;
+ register int *stackref, stsizediff;
+ __asm__ volatile ("" : : : "esi", "edi");
+ __asm__ volatile ("fstcw %0" : "=m" (cw));
+ __asm__ volatile ("movl %%ebp, %0" : "=m" (ebp));
+ __asm__ volatile ("movl %%ebx, %0" : "=m" (ebx));
+#ifdef _WIN32
+ __asm__ volatile (
+ "movl %%fs:0x0, %%eax\n"
+ "movl %%eax, %0\n"
+ : "=m" (seh)
+ :
+ : "eax");
+#endif
+ __asm__ ("movl %%esp, %0" : "=g" (stackref));
+ {
+ SLP_SAVE_STATE(stackref, stsizediff);
+ __asm__ volatile (
+ "addl %0, %%esp\n"
+ "addl %0, %%ebp\n"
+ :
+ : "r" (stsizediff)
+ );
+ SLP_RESTORE_STATE();
+ __asm__ volatile ("xorl %%eax, %%eax" : "=a" (err));
+ }
+#ifdef _WIN32
+ __asm__ volatile (
+ "movl %0, %%eax\n"
+ "movl %%eax, %%fs:0x0\n"
+ :
+ : "m" (seh)
+ : "eax");
+#endif
+ __asm__ volatile ("movl %0, %%ebx" : : "m" (ebx));
+ __asm__ volatile ("movl %0, %%ebp" : : "m" (ebp));
+ __asm__ volatile ("fldcw %0" : : "m" (cw));
+ __asm__ volatile ("" : : : "esi", "edi");
+ return err;
+}
+
+#endif
+
+/*
+ * further self-processing support
+ */
+
+/*
+ * if you want to add self-inspection tools, place them
+ * here. See the x86_msvc for the necessary defines.
+ * These features are highly experimental und not
+ * essential yet.
+ */
diff --git a/lib/greenlet/slp_platformselect.h b/lib/greenlet/slp_platformselect.h
new file mode 100644
index 0000000..b5e8eb6
--- /dev/null
+++ b/lib/greenlet/slp_platformselect.h
@@ -0,0 +1,58 @@
+/*
+ * Platform Selection for Stackless Python
+ */
+
+#if defined(MS_WIN32) && !defined(MS_WIN64) && defined(_M_IX86) && defined(_MSC_VER)
+#include "platform/switch_x86_msvc.h" /* MS Visual Studio on X86 */
+#elif defined(MS_WIN64) && defined(_M_X64) && defined(_MSC_VER) || defined(__MINGW64__)
+#include "platform/switch_x64_msvc.h" /* MS Visual Studio on X64 */
+#elif defined(__GNUC__) && defined(__amd64__) && defined(__ILP32__)
+#include "platform/switch_x32_unix.h" /* gcc on amd64 with x32 ABI */
+#elif defined(__GNUC__) && defined(__amd64__)
+#include "platform/switch_amd64_unix.h" /* gcc on amd64 */
+#elif defined(__GNUC__) && defined(__i386__)
+#include "platform/switch_x86_unix.h" /* gcc on X86 */
+#elif defined(__GNUC__) && defined(__powerpc64__) && (defined(__linux__) || defined(__FreeBSD__))
+#include "platform/switch_ppc64_linux.h" /* gcc on PowerPC 64-bit */
+#elif defined(__GNUC__) && defined(__PPC__) && (defined(__linux__) || defined(__FreeBSD__))
+#include "platform/switch_ppc_linux.h" /* gcc on PowerPC */
+#elif defined(__GNUC__) && defined(__ppc__) && defined(__APPLE__)
+#include "platform/switch_ppc_macosx.h" /* Apple MacOS X on PowerPC */
+#elif defined(__GNUC__) && defined(__powerpc64__) && defined(_AIX)
+#include "platform/switch_ppc64_aix.h" /* gcc on AIX/PowerPC 64-bit */
+#elif defined(__GNUC__) && defined(_ARCH_PPC) && defined(_AIX)
+#include "platform/switch_ppc_aix.h" /* gcc on AIX/PowerPC */
+#elif defined(__GNUC__) && defined(sparc)
+#include "platform/switch_sparc_sun_gcc.h" /* SunOS sparc with gcc */
+#elif defined(__SUNPRO_C) && defined(sparc) && defined(sun)
+#include "platform/switch_sparc_sun_gcc.h" /* SunStudio on amd64 */
+#elif defined(__SUNPRO_C) && defined(__amd64__) && defined(sun)
+#include "platform/switch_amd64_unix.h" /* SunStudio on amd64 */
+#elif defined(__SUNPRO_C) && defined(__i386__) && defined(sun)
+#include "platform/switch_x86_unix.h" /* SunStudio on x86 */
+#elif defined(__GNUC__) && defined(__s390__) && defined(__linux__)
+#include "platform/switch_s390_unix.h" /* Linux/S390 */
+#elif defined(__GNUC__) && defined(__s390x__) && defined(__linux__)
+#include "platform/switch_s390_unix.h" /* Linux/S390 zSeries (64-bit) */
+#elif defined(__GNUC__) && defined(__arm__)
+#ifdef __APPLE__
+#include <TargetConditionals.h>
+#endif
+#if TARGET_OS_IPHONE
+#include "platform/switch_arm32_ios.h" /* iPhone OS on arm32 */
+#else
+#include "platform/switch_arm32_gcc.h" /* gcc using arm32 */
+#endif
+#elif defined(__GNUC__) && defined(__mips__) && defined(__linux__)
+#include "platform/switch_mips_unix.h" /* Linux/MIPS */
+#elif defined(__GNUC__) && defined(__aarch64__)
+#include "platform/switch_aarch64_gcc.h" /* Aarch64 ABI */
+#elif defined(__GNUC__) && defined(__mc68000__)
+#include "platform/switch_m68k_gcc.h" /* gcc on m68k */
+#elif defined(__GNUC__) && defined(__csky__)
+#include "platform/switch_csky_gcc.h" /* gcc on csky */
+#elif defined(__GNUC__) && defined(__riscv)
+#include "platform/switch_riscv_unix.h" /* gcc on RISC-V */
+#elif defined(__GNUC__) && defined(__alpha__)
+#include "platform/switch_alpha_unix.h" /* gcc on DEC Alpha */
+#endif
diff --git a/lib/greenlet/tests/__init__.py b/lib/greenlet/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/greenlet/tests/__init__.py
diff --git a/lib/greenlet/tests/_test_extension.c b/lib/greenlet/tests/_test_extension.c
new file mode 100644
index 0000000..4fe087d
--- /dev/null
+++ b/lib/greenlet/tests/_test_extension.c
@@ -0,0 +1,216 @@
+/* This is a set of functions used by test_extension_interface.py to test the
+ * Greenlet C API.
+ */
+
+#include "../greenlet.h"
+
+#ifndef Py_RETURN_NONE
+# define Py_RETURN_NONE return Py_INCREF(Py_None), Py_None
+#endif
+
+#define TEST_MODULE_NAME "_test_extension"
+
+static PyObject*
+test_switch(PyObject* self, PyObject* greenlet)
+{
+ PyObject* result = NULL;
+
+ if (greenlet == NULL || !PyGreenlet_Check(greenlet)) {
+ PyErr_BadArgument();
+ return NULL;
+ }
+
+ result = PyGreenlet_Switch((PyGreenlet*)greenlet, NULL, NULL);
+ if (result == NULL) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_AssertionError,
+ "greenlet.switch() failed for some reason.");
+ }
+ return NULL;
+ }
+ Py_INCREF(result);
+ return result;
+}
+
+static PyObject*
+test_switch_kwargs(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+ PyGreenlet* g = NULL;
+ PyObject* result = NULL;
+
+ PyArg_ParseTuple(args, "O!", &PyGreenlet_Type, &g);
+
+ if (g == NULL || !PyGreenlet_Check(g)) {
+ PyErr_BadArgument();
+ return NULL;
+ }
+
+ result = PyGreenlet_Switch(g, NULL, kwargs);
+ if (result == NULL) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_AssertionError,
+ "greenlet.switch() failed for some reason.");
+ }
+ return NULL;
+ }
+ Py_XINCREF(result);
+ return result;
+}
+
+static PyObject*
+test_getcurrent(PyObject* self)
+{
+ PyGreenlet* g = PyGreenlet_GetCurrent();
+ if (g == NULL || !PyGreenlet_Check(g) || !PyGreenlet_ACTIVE(g)) {
+ PyErr_SetString(PyExc_AssertionError,
+ "getcurrent() returned an invalid greenlet");
+ Py_XDECREF(g);
+ return NULL;
+ }
+ Py_DECREF(g);
+ Py_RETURN_NONE;
+}
+
+static PyObject*
+test_setparent(PyObject* self, PyObject* arg)
+{
+ PyGreenlet* current;
+ PyGreenlet* greenlet = NULL;
+
+ if (arg == NULL || !PyGreenlet_Check(arg)) {
+ PyErr_BadArgument();
+ return NULL;
+ }
+ if ((current = PyGreenlet_GetCurrent()) == NULL) {
+ return NULL;
+ }
+ greenlet = (PyGreenlet*)arg;
+ if (PyGreenlet_SetParent(greenlet, current)) {
+ Py_DECREF(current);
+ return NULL;
+ }
+ Py_DECREF(current);
+ if (PyGreenlet_Switch(greenlet, NULL, NULL) == NULL) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+static PyObject*
+test_new_greenlet(PyObject* self, PyObject* callable)
+{
+ PyObject* result = NULL;
+ PyGreenlet* greenlet = PyGreenlet_New(callable, NULL);
+
+ if (!greenlet) {
+ return NULL;
+ }
+
+ result = PyGreenlet_Switch(greenlet, NULL, NULL);
+ if (result == NULL) {
+ return NULL;
+ }
+
+ Py_INCREF(result);
+ return result;
+}
+
+static PyObject*
+test_raise_dead_greenlet(PyObject* self)
+{
+ PyErr_SetString(PyExc_GreenletExit, "test GreenletExit exception.");
+ return NULL;
+}
+
+static PyObject*
+test_raise_greenlet_error(PyObject* self)
+{
+ PyErr_SetString(PyExc_GreenletError, "test greenlet.error exception");
+ return NULL;
+}
+
+static PyObject*
+test_throw(PyObject* self, PyGreenlet* g)
+{
+ const char msg[] = "take that sucka!";
+ PyObject* msg_obj = Py_BuildValue("s", msg);
+ PyGreenlet_Throw(g, PyExc_ValueError, msg_obj, NULL);
+ Py_DECREF(msg_obj);
+ Py_RETURN_NONE;
+}
+
+static PyMethodDef test_methods[] = {
+ {"test_switch",
+ (PyCFunction)test_switch,
+ METH_O,
+ "Switch to the provided greenlet sending provided arguments, and \n"
+ "return the results."},
+ {"test_switch_kwargs",
+ (PyCFunction)test_switch_kwargs,
+ METH_VARARGS | METH_KEYWORDS,
+ "Switch to the provided greenlet sending the provided keyword args."},
+ {"test_getcurrent",
+ (PyCFunction)test_getcurrent,
+ METH_NOARGS,
+ "Test PyGreenlet_GetCurrent()"},
+ {"test_setparent",
+ (PyCFunction)test_setparent,
+ METH_O,
+ "Se the parent of the provided greenlet and switch to it."},
+ {"test_new_greenlet",
+ (PyCFunction)test_new_greenlet,
+ METH_O,
+ "Test PyGreenlet_New()"},
+ {"test_raise_dead_greenlet",
+ (PyCFunction)test_raise_dead_greenlet,
+ METH_NOARGS,
+ "Just raise greenlet.GreenletExit"},
+ {"test_raise_greenlet_error",
+ (PyCFunction)test_raise_greenlet_error,
+ METH_NOARGS,
+ "Just raise greenlet.error"},
+ {"test_throw",
+ (PyCFunction)test_throw,
+ METH_O,
+ "Throw a ValueError at the provided greenlet"},
+ {NULL, NULL, 0, NULL}};
+
+#if PY_MAJOR_VERSION >= 3
+# define INITERROR return NULL
+
+static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT,
+ TEST_MODULE_NAME,
+ NULL,
+ 0,
+ test_methods,
+ NULL,
+ NULL,
+ NULL,
+ NULL};
+
+PyMODINIT_FUNC
+PyInit__test_extension(void)
+#else
+# define INITERROR return
+PyMODINIT_FUNC
+init_test_extension(void)
+#endif
+{
+ PyObject* module = NULL;
+
+#if PY_MAJOR_VERSION >= 3
+ module = PyModule_Create(&moduledef);
+#else
+ module = Py_InitModule(TEST_MODULE_NAME, test_methods);
+#endif
+
+ if (module == NULL) {
+ INITERROR;
+ }
+
+ PyGreenlet_Import();
+
+#if PY_MAJOR_VERSION >= 3
+ return module;
+#endif
+}
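
A rough usage sketch (assuming the compiled extension is importable as greenlet.tests._test_extension, which is how test_extension_interface.py further down imports it): the entry points map directly onto the C API calls they wrap.

    from greenlet import greenlet
    from greenlet.tests import _test_extension

    # test_switch() hands the greenlet to PyGreenlet_Switch() and returns
    # whatever the greenlet returns when it finishes.
    print(_test_extension.test_switch(greenlet(lambda: 50)))    # 50

    # test_new_greenlet() wraps the callable with PyGreenlet_New() and
    # switches to it.
    print(_test_extension.test_new_greenlet(lambda: -15))       # -15
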
diff --git a/lib/greenlet/tests/_test_extension.cpython-39-x86_64-linux-gnu.so b/lib/greenlet/tests/_test_extension.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..ea567a8
--- /dev/null
+++ b/lib/greenlet/tests/_test_extension.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/greenlet/tests/_test_extension_cpp.cpp b/lib/greenlet/tests/_test_extension_cpp.cpp
new file mode 100644
index 0000000..72e3d81
--- /dev/null
+++ b/lib/greenlet/tests/_test_extension_cpp.cpp
@@ -0,0 +1,121 @@
+/* This is a set of functions used to test that C++ exceptions are not
+ * broken during greenlet switches.
+ */
+
+#include "../greenlet.h"
+
+struct exception_t {
+ int depth;
+ exception_t(int depth) : depth(depth) {}
+};
+
+/* Functions are called via pointers to prevent inlining */
+static void (*p_test_exception_throw)(int depth);
+static PyObject* (*p_test_exception_switch_recurse)(int depth, int left);
+
+static void
+test_exception_throw(int depth)
+{
+ throw exception_t(depth);
+}
+
+static PyObject*
+test_exception_switch_recurse(int depth, int left)
+{
+ if (left > 0) {
+ return p_test_exception_switch_recurse(depth, left - 1);
+ }
+
+ PyObject* result = NULL;
+ PyGreenlet* self = PyGreenlet_GetCurrent();
+ if (self == NULL)
+ return NULL;
+
+ try {
+ PyGreenlet_Switch(self->parent, NULL, NULL);
+ p_test_exception_throw(depth);
+ PyErr_SetString(PyExc_RuntimeError,
+ "throwing C++ exception didn't work");
+ }
+ catch (exception_t& e) {
+ if (e.depth != depth)
+ PyErr_SetString(PyExc_AssertionError, "depth mismatch");
+ else
+ result = PyLong_FromLong(depth);
+ }
+ catch (...) {
+ PyErr_SetString(PyExc_RuntimeError, "unexpected C++ exception");
+ }
+
+ Py_DECREF(self);
+ return result;
+}
+
+/* test_exception_switch(int depth)
+ * - recurses depth times
+ * - switches to parent inside try/catch block
+ * - throws an exception (expected to be caught in the same function)
+ * - verifies depth matches (exceptions shouldn't be caught in other greenlets)
+ */
+static PyObject*
+test_exception_switch(PyObject* self, PyObject* args)
+{
+ int depth;
+ if (!PyArg_ParseTuple(args, "i", &depth))
+ return NULL;
+ return p_test_exception_switch_recurse(depth, depth);
+}
+
+static PyMethodDef test_methods[] = {
+ {"test_exception_switch",
+ (PyCFunction)&test_exception_switch,
+ METH_VARARGS,
+ "Switches to parent twice, to test exception handling and greenlet "
+ "switching."},
+ {NULL, NULL, 0, NULL}};
+
+#if PY_MAJOR_VERSION >= 3
+# define INITERROR return NULL
+
+static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT,
+ "greenlet.tests._test_extension_cpp",
+ NULL,
+ 0,
+ test_methods,
+ NULL,
+ NULL,
+ NULL,
+ NULL};
+
+PyMODINIT_FUNC
+PyInit__test_extension_cpp(void)
+#else
+# define INITERROR return
+PyMODINIT_FUNC
+init_test_extension_cpp(void)
+#endif
+{
+ PyObject* module = NULL;
+
+#if PY_MAJOR_VERSION >= 3
+ module = PyModule_Create(&moduledef);
+#else
+ module = Py_InitModule("greenlet.tests._test_extension_cpp", test_methods);
+#endif
+
+ if (module == NULL) {
+ INITERROR;
+ }
+
+ PyGreenlet_Import();
+ if (_PyGreenlet_API == NULL) {
+ INITERROR;
+ }
+
+ p_test_exception_throw = test_exception_throw;
+ p_test_exception_switch_recurse = test_exception_switch_recurse;
+
+#if PY_MAJOR_VERSION >= 3
+ return module;
+#endif
+}
diff --git a/lib/greenlet/tests/_test_extension_cpp.cpython-39-x86_64-linux-gnu.so b/lib/greenlet/tests/_test_extension_cpp.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..0e50cfe
--- /dev/null
+++ b/lib/greenlet/tests/_test_extension_cpp.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/greenlet/tests/test_contextvars.py b/lib/greenlet/tests/test_contextvars.py
new file mode 100644
index 0000000..49b7c0d
--- /dev/null
+++ b/lib/greenlet/tests/test_contextvars.py
@@ -0,0 +1,266 @@
+import unittest
+import gc
+import sys
+
+from functools import partial
+
+from greenlet import greenlet
+from greenlet import getcurrent
+
+
+try:
+ from contextvars import Context
+ from contextvars import ContextVar
+ from contextvars import copy_context
+except ImportError:
+ Context = ContextVar = copy_context = None
+
+# We don't support testing if greenlet's built-in context var support is disabled.
+@unittest.skipUnless(Context is not None, "ContextVar not supported")
+class ContextVarsTests(unittest.TestCase):
+ def _new_ctx_run(self, *args, **kwargs):
+ return copy_context().run(*args, **kwargs)
+
+ def _increment(self, greenlet_id, ctx_var, callback, counts, expect):
+ if expect is None:
+ self.assertIsNone(ctx_var.get())
+ else:
+ self.assertEqual(ctx_var.get(), expect)
+ ctx_var.set(greenlet_id)
+ for _ in range(2):
+ counts[ctx_var.get()] += 1
+ callback()
+
+ def _test_context(self, propagate_by):
+ id_var = ContextVar("id", default=None)
+ id_var.set(0)
+
+ callback = getcurrent().switch
+ counts = dict((i, 0) for i in range(5))
+
+ lets = [
+ greenlet(partial(
+ partial(
+ copy_context().run,
+ self._increment
+ ) if propagate_by == "run" else self._increment,
+ greenlet_id=i,
+ ctx_var=id_var,
+ callback=callback,
+ counts=counts,
+ expect=(
+ i - 1 if propagate_by == "share" else
+ 0 if propagate_by in ("set", "run") else None
+ )
+ ))
+ for i in range(1, 5)
+ ]
+
+ for let in lets:
+ if propagate_by == "set":
+ let.gr_context = copy_context()
+ elif propagate_by == "share":
+ let.gr_context = getcurrent().gr_context
+
+ for i in range(2):
+ counts[id_var.get()] += 1
+ for let in lets:
+ let.switch()
+
+ if propagate_by == "run":
+ # Must leave each context.run() in reverse order of entry
+ for let in reversed(lets):
+ let.switch()
+ else:
+ # No context.run(), so fine to exit in any order.
+ for let in lets:
+ let.switch()
+
+ for let in lets:
+ self.assertTrue(let.dead)
+ # When using run(), we leave the run() as the greenlet dies,
+ # and there's no context "underneath". When not using run(),
+ # gr_context still reflects the context the greenlet was
+ # running in.
+ self.assertEqual(let.gr_context is None, propagate_by == "run")
+
+ if propagate_by == "share":
+ self.assertEqual(counts, {0: 1, 1: 1, 2: 1, 3: 1, 4: 6})
+ else:
+ self.assertEqual(set(counts.values()), set([2]))
+
+ def test_context_propagated_by_context_run(self):
+ self._new_ctx_run(self._test_context, "run")
+
+ def test_context_propagated_by_setting_attribute(self):
+ self._new_ctx_run(self._test_context, "set")
+
+ def test_context_not_propagated(self):
+ self._new_ctx_run(self._test_context, None)
+
+ def test_context_shared(self):
+ self._new_ctx_run(self._test_context, "share")
+
+ def test_break_ctxvars(self):
+ let1 = greenlet(copy_context().run)
+ let2 = greenlet(copy_context().run)
+ let1.switch(getcurrent().switch)
+ let2.switch(getcurrent().switch)
+ # Since let2 entered the current context and let1 exits its own, the
+ # interpreter emits:
+ # RuntimeError: cannot exit context: thread state references a different context object
+ let1.switch()
+
+ def test_not_broken_if_using_attribute_instead_of_context_run(self):
+ let1 = greenlet(getcurrent().switch)
+ let2 = greenlet(getcurrent().switch)
+ let1.gr_context = copy_context()
+ let2.gr_context = copy_context()
+ let1.switch()
+ let2.switch()
+ let1.switch()
+ let2.switch()
+
+ def test_context_assignment_while_running(self):
+ id_var = ContextVar("id", default=None)
+
+ def target():
+ self.assertIsNone(id_var.get())
+ self.assertIsNone(gr.gr_context)
+
+ # Context is created on first use
+ id_var.set(1)
+ self.assertIsInstance(gr.gr_context, Context)
+ self.assertEqual(id_var.get(), 1)
+ self.assertEqual(gr.gr_context[id_var], 1)
+
+ # Clearing the context makes it get re-created as another
+ # empty context when next used
+ old_context = gr.gr_context
+ gr.gr_context = None # assign None while running
+ self.assertIsNone(id_var.get())
+ self.assertIsNone(gr.gr_context)
+ id_var.set(2)
+ self.assertIsInstance(gr.gr_context, Context)
+ self.assertEqual(id_var.get(), 2)
+ self.assertEqual(gr.gr_context[id_var], 2)
+
+ new_context = gr.gr_context
+ getcurrent().parent.switch((old_context, new_context))
+ # parent switches us back to old_context
+
+ self.assertEqual(id_var.get(), 1)
+ gr.gr_context = new_context # assign non-None while running
+ self.assertEqual(id_var.get(), 2)
+
+ getcurrent().parent.switch()
+ # parent switches us back to no context
+ self.assertIsNone(id_var.get())
+ self.assertIsNone(gr.gr_context)
+ gr.gr_context = old_context
+ self.assertEqual(id_var.get(), 1)
+
+ getcurrent().parent.switch()
+ # parent switches us back to no context
+ self.assertIsNone(id_var.get())
+ self.assertIsNone(gr.gr_context)
+
+ gr = greenlet(target)
+
+ with self.assertRaisesRegex(AttributeError, "can't delete attr"):
+ del gr.gr_context
+
+ self.assertIsNone(gr.gr_context)
+ old_context, new_context = gr.switch()
+ self.assertIs(new_context, gr.gr_context)
+ self.assertEqual(old_context[id_var], 1)
+ self.assertEqual(new_context[id_var], 2)
+ self.assertEqual(new_context.run(id_var.get), 2)
+ gr.gr_context = old_context # assign non-None while suspended
+ gr.switch()
+ self.assertIs(gr.gr_context, new_context)
+ gr.gr_context = None # assign None while suspended
+ gr.switch()
+ self.assertIs(gr.gr_context, old_context)
+ gr.gr_context = None
+ gr.switch()
+ self.assertIsNone(gr.gr_context)
+
+ # Make sure there are no reference leaks
+ gr = None
+ gc.collect()
+ self.assertEqual(sys.getrefcount(old_context), 2)
+ self.assertEqual(sys.getrefcount(new_context), 2)
+
+ def test_context_assignment_different_thread(self):
+ import threading
+
+ ctx = Context()
+ var = ContextVar("var", default=None)
+ is_running = threading.Event()
+ should_suspend = threading.Event()
+ did_suspend = threading.Event()
+ should_exit = threading.Event()
+ holder = []
+
+ def greenlet_in_thread_fn():
+ var.set(1)
+ is_running.set()
+ should_suspend.wait()
+ var.set(2)
+ getcurrent().parent.switch()
+ holder.append(var.get())
+
+ def thread_fn():
+ gr = greenlet(greenlet_in_thread_fn)
+ gr.gr_context = ctx
+ holder.append(gr)
+ gr.switch()
+ did_suspend.set()
+ should_exit.wait()
+ gr.switch()
+
+ thread = threading.Thread(target=thread_fn, daemon=True)
+ thread.start()
+ is_running.wait()
+ gr = holder[0]
+
+ # Can't access or modify context if the greenlet is running
+ # in a different thread
+ with self.assertRaisesRegex(ValueError, "running in a different"):
+ getattr(gr, 'gr_context')
+ with self.assertRaisesRegex(ValueError, "running in a different"):
+ gr.gr_context = None
+
+ should_suspend.set()
+ did_suspend.wait()
+
+ # OK to access and modify context if greenlet is suspended
+ self.assertIs(gr.gr_context, ctx)
+ self.assertEqual(gr.gr_context[var], 2)
+ gr.gr_context = None
+
+ should_exit.set()
+ thread.join()
+
+ self.assertEqual(holder, [gr, None])
+
+ # Context can still be accessed/modified when greenlet is dead:
+ self.assertIsNone(gr.gr_context)
+ gr.gr_context = ctx
+ self.assertIs(gr.gr_context, ctx)
+
+@unittest.skipIf(Context is not None, "ContextVar supported")
+class NoContextVarsTests(unittest.TestCase):
+ def test_contextvars_errors(self):
+ let1 = greenlet(getcurrent().switch)
+ self.assertFalse(hasattr(let1, 'gr_context'))
+ with self.assertRaises(AttributeError):
+ getattr(let1, 'gr_context')
+ with self.assertRaises(AttributeError):
+ let1.gr_context = None
+ let1.switch()
+ with self.assertRaises(AttributeError):
+ getattr(let1, 'gr_context')
+ with self.assertRaises(AttributeError):
+ let1.gr_context = None
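
A small sketch of the gr_context mechanism these tests exercise (assuming Python 3.7+ and a greenlet built with context-variable support; the variable names are illustrative):

    from contextvars import ContextVar, copy_context
    from greenlet import greenlet, getcurrent

    var = ContextVar("var", default=0)

    def child():
        var.set(1)                      # lands in the context assigned below
        getcurrent().parent.switch()

    g = greenlet(child)
    g.gr_context = copy_context()       # give the child its own snapshot
    g.switch()
    print(var.get())                    # 0 -- the main context is untouched
    print(g.gr_context[var])            # 1 -- visible through the saved context
    g.switch()                          # let the child finish
    print(g.gr_context[var])            # still 1: without Context.run(),
                                        # gr_context survives the greenlet's death
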
diff --git a/lib/greenlet/tests/test_cpp.py b/lib/greenlet/tests/test_cpp.py
new file mode 100644
index 0000000..741ea10
--- /dev/null
+++ b/lib/greenlet/tests/test_cpp.py
@@ -0,0 +1,18 @@
+from __future__ import print_function
+from __future__ import absolute_import
+
+import unittest
+
+import greenlet
+from . import _test_extension_cpp
+
+
+class CPPTests(unittest.TestCase):
+ def test_exception_switch(self):
+ greenlets = []
+ for i in range(4):
+ g = greenlet.greenlet(_test_extension_cpp.test_exception_switch)
+ g.switch(i)
+ greenlets.append(g)
+ for i, g in enumerate(greenlets):
+ self.assertEqual(g.switch(), i)
diff --git a/lib/greenlet/tests/test_extension_interface.py b/lib/greenlet/tests/test_extension_interface.py
new file mode 100644
index 0000000..a92ea1f
--- /dev/null
+++ b/lib/greenlet/tests/test_extension_interface.py
@@ -0,0 +1,77 @@
+from __future__ import print_function
+from __future__ import absolute_import
+
+import sys
+import unittest
+
+import greenlet
+from . import _test_extension
+
+
+class CAPITests(unittest.TestCase):
+ def test_switch(self):
+ self.assertEqual(
+ 50, _test_extension.test_switch(greenlet.greenlet(lambda: 50)))
+
+ def test_switch_kwargs(self):
+ def foo(x, y):
+ return x * y
+ g = greenlet.greenlet(foo)
+ self.assertEqual(6, _test_extension.test_switch_kwargs(g, x=3, y=2))
+
+ def test_setparent(self):
+ def foo():
+ def bar():
+ greenlet.getcurrent().parent.switch()
+
+ # This final switch should go back to the main greenlet, since
+ # the test_setparent() function in the C extension should have
+ # reparented this greenlet.
+ greenlet.getcurrent().parent.switch()
+ raise AssertionError("Should never have reached this code")
+ child = greenlet.greenlet(bar)
+ child.switch()
+ greenlet.getcurrent().parent.switch(child)
+ greenlet.getcurrent().parent.throw(
+ AssertionError("Should never reach this code"))
+ foo_child = greenlet.greenlet(foo).switch()
+ self.assertEqual(None, _test_extension.test_setparent(foo_child))
+
+ def test_getcurrent(self):
+ _test_extension.test_getcurrent()
+
+ def test_new_greenlet(self):
+ self.assertEqual(-15, _test_extension.test_new_greenlet(lambda: -15))
+
+ def test_raise_greenlet_dead(self):
+ self.assertRaises(
+ greenlet.GreenletExit, _test_extension.test_raise_dead_greenlet)
+
+ def test_raise_greenlet_error(self):
+ self.assertRaises(
+ greenlet.error, _test_extension.test_raise_greenlet_error)
+
+ def test_throw(self):
+ seen = []
+
+ def foo():
+ try:
+ greenlet.getcurrent().parent.switch()
+ except ValueError:
+ seen.append(sys.exc_info()[1])
+ except greenlet.GreenletExit:
+ raise AssertionError
+ g = greenlet.greenlet(foo)
+ g.switch()
+ _test_extension.test_throw(g)
+ self.assertEqual(len(seen), 1)
+ self.assertTrue(
+ isinstance(seen[0], ValueError),
+ "ValueError was not raised in foo()")
+ self.assertEqual(
+ str(seen[0]),
+ 'take that sucka!',
+ "message doesn't match")
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/lib/greenlet/tests/test_gc.py b/lib/greenlet/tests/test_gc.py
new file mode 100644
index 0000000..a2a41ca
--- /dev/null
+++ b/lib/greenlet/tests/test_gc.py
@@ -0,0 +1,77 @@
+import gc
+import sys
+import unittest
+import weakref
+
+import greenlet
+
+
+class GCTests(unittest.TestCase):
+ def test_dead_circular_ref(self):
+ o = weakref.ref(greenlet.greenlet(greenlet.getcurrent).switch())
+ gc.collect()
+ self.assertTrue(o() is None)
+ self.assertFalse(gc.garbage, gc.garbage)
+
+ if greenlet.GREENLET_USE_GC:
+ # These only work with greenlet gc support
+
+ def test_circular_greenlet(self):
+ class circular_greenlet(greenlet.greenlet):
+ pass
+ o = circular_greenlet()
+ o.self = o
+ o = weakref.ref(o)
+ gc.collect()
+ self.assertTrue(o() is None)
+ self.assertFalse(gc.garbage, gc.garbage)
+
+ def test_inactive_ref(self):
+ class inactive_greenlet(greenlet.greenlet):
+ def __init__(self):
+ greenlet.greenlet.__init__(self, run=self.run)
+
+ def run(self):
+ pass
+ o = inactive_greenlet()
+ o = weakref.ref(o)
+ gc.collect()
+ self.assertTrue(o() is None)
+ self.assertFalse(gc.garbage, gc.garbage)
+
+ def test_finalizer_crash(self):
+ # This test is designed to crash when active greenlets
+ # are made garbage collectable, until the underlying
+ # problem is resolved. How does it work:
+ # - order of object creation is important
+ # - array is created first, so it is moved to unreachable first
+ # - we create a cycle between a greenlet and this array
+ # - we create an object that participates in gc, is only
+ # referenced by a greenlet, and would corrupt gc lists
+ # on destruction, the easiest is to use an object with
+ # a finalizer
+ # - because array is the first object in unreachable it is
+ # cleared first, which causes all references to greenlet
+ # to disappear and causes greenlet to be destroyed, but since
+ # it is still live it causes a switch during gc, which causes
+ # an object with finalizer to be destroyed, which causes stack
+ # corruption and then a crash
+ class object_with_finalizer(object):
+ def __del__(self):
+ pass
+ array = []
+ parent = greenlet.getcurrent()
+ def greenlet_body():
+ greenlet.getcurrent().object = object_with_finalizer()
+ try:
+ parent.switch()
+ finally:
+ del greenlet.getcurrent().object
+ g = greenlet.greenlet(greenlet_body)
+ g.array = array
+ array.append(g)
+ g.switch()
+ del array
+ del g
+ greenlet.getcurrent()
+ gc.collect()
diff --git a/lib/greenlet/tests/test_generator.py b/lib/greenlet/tests/test_generator.py
new file mode 100644
index 0000000..62f9f26
--- /dev/null
+++ b/lib/greenlet/tests/test_generator.py
@@ -0,0 +1,59 @@
+import unittest
+from greenlet import greenlet
+
+
+class genlet(greenlet):
+
+ def __init__(self, *args, **kwds):
+ self.args = args
+ self.kwds = kwds
+
+ def run(self):
+ fn, = self.fn
+ fn(*self.args, **self.kwds)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ self.parent = greenlet.getcurrent()
+ result = self.switch()
+ if self:
+ return result
+ else:
+ raise StopIteration
+
+ # Hack: Python < 2.6 compatibility
+ next = __next__
+
+
+def Yield(value):
+ g = greenlet.getcurrent()
+ while not isinstance(g, genlet):
+ if g is None:
+ raise RuntimeError('yield outside a genlet')
+ g = g.parent
+ g.parent.switch(value)
+
+
+def generator(func):
+ class generator(genlet):
+ fn = (func,)
+ return generator
+
+# ____________________________________________________________
+
+
+class GeneratorTests(unittest.TestCase):
+ def test_generator(self):
+ seen = []
+
+ def g(n):
+ for i in range(n):
+ seen.append(i)
+ Yield(i)
+ g = generator(g)
+ for k in range(3):
+ for j in g(5):
+ seen.append(j)
+ self.assertEqual(seen, 3 * [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
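
The genlet/Yield machinery above is the classic greenlet-as-generator pattern. A hypothetical standalone sketch of the same idea (producer_to_iterator and emit are illustrative names, not greenlet APIs):

    from greenlet import greenlet, getcurrent

    def producer_to_iterator(func, *args):
        # Drive func (which calls emit()) as an ordinary Python iterator.
        g = greenlet(lambda: func(*args))
        while True:
            g.parent = getcurrent()     # make sure emit() switches back to us
            value = g.switch()
            if g.dead:                  # producer returned: iteration is over
                return
            yield value

    def emit(value):
        # Called by the producer; hands value to whoever is iterating.
        getcurrent().parent.switch(value)

    def count_up_to(n):
        for i in range(n):
            emit(i)

    print(list(producer_to_iterator(count_up_to, 3)))   # [0, 1, 2]
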
diff --git a/lib/greenlet/tests/test_generator_nested.py b/lib/greenlet/tests/test_generator_nested.py
new file mode 100644
index 0000000..6b4f023
--- /dev/null
+++ b/lib/greenlet/tests/test_generator_nested.py
@@ -0,0 +1,165 @@
+import unittest
+from greenlet import greenlet
+
+
+class genlet(greenlet):
+
+ def __init__(self, *args, **kwds):
+ self.args = args
+ self.kwds = kwds
+ self.child = None
+
+ def run(self):
+ fn, = self.fn
+ fn(*self.args, **self.kwds)
+
+ def __iter__(self):
+ return self
+
+ def set_child(self, child):
+ self.child = child
+
+ def __next__(self):
+ if self.child:
+ child = self.child
+ while child.child:
+ tmp = child
+ child = child.child
+ tmp.child = None
+
+ result = child.switch()
+ else:
+ self.parent = greenlet.getcurrent()
+ result = self.switch()
+
+ if self:
+ return result
+ else:
+ raise StopIteration
+
+ # Hack: Python < 2.6 compatibility
+ next = __next__
+
+
+def Yield(value, level=1):
+ g = greenlet.getcurrent()
+
+ while level != 0:
+ if not isinstance(g, genlet):
+ raise RuntimeError('yield outside a genlet')
+ if level > 1:
+ g.parent.set_child(g)
+ g = g.parent
+ level -= 1
+
+ g.switch(value)
+
+
+def Genlet(func):
+ class Genlet(genlet):
+ fn = (func,)
+ return Genlet
+
+# ____________________________________________________________
+
+
+def g1(n, seen):
+ for i in range(n):
+ seen.append(i + 1)
+ yield i
+
+
+def g2(n, seen):
+ for i in range(n):
+ seen.append(i + 1)
+ Yield(i)
+
+g2 = Genlet(g2)
+
+
+def nested(i):
+ Yield(i)
+
+
+def g3(n, seen):
+ for i in range(n):
+ seen.append(i + 1)
+ nested(i)
+g3 = Genlet(g3)
+
+
+def a(n):
+ if n == 0:
+ return
+ for ii in ax(n - 1):
+ Yield(ii)
+ Yield(n)
+ax = Genlet(a)
+
+
+def perms(l):
+ if len(l) > 1:
+ for e in l:
+ # No syntactical sugar for generator expressions
+ [Yield([e] + p) for p in perms([x for x in l if x != e])]
+ else:
+ Yield(l)
+perms = Genlet(perms)
+
+
+def gr1(n):
+ for ii in range(1, n):
+ Yield(ii)
+ Yield(ii * ii, 2)
+
+gr1 = Genlet(gr1)
+
+
+def gr2(n, seen):
+ for ii in gr1(n):
+ seen.append(ii)
+
+gr2 = Genlet(gr2)
+
+
+class NestedGeneratorTests(unittest.TestCase):
+ def test_layered_genlets(self):
+ seen = []
+ for ii in gr2(5, seen):
+ seen.append(ii)
+ self.assertEqual(seen, [1, 1, 2, 4, 3, 9, 4, 16])
+
+ def test_permutations(self):
+ gen_perms = perms(list(range(4)))
+ permutations = list(gen_perms)
+ self.assertEqual(len(permutations), 4 * 3 * 2 * 1)
+ self.assertTrue([0, 1, 2, 3] in permutations)
+ self.assertTrue([3, 2, 1, 0] in permutations)
+ res = []
+ for ii in zip(perms(list(range(4))), perms(list(range(3)))):
+ res.append(ii)
+ self.assertEqual(
+ res,
+ [([0, 1, 2, 3], [0, 1, 2]), ([0, 1, 3, 2], [0, 2, 1]),
+ ([0, 2, 1, 3], [1, 0, 2]), ([0, 2, 3, 1], [1, 2, 0]),
+ ([0, 3, 1, 2], [2, 0, 1]), ([0, 3, 2, 1], [2, 1, 0])])
+ # XXX Test to make sure we are working as a generator expression
+
+ def test_genlet_simple(self):
+ for g in [g1, g2, g3]:
+ seen = []
+ for k in range(3):
+ for j in g(5, seen):
+ seen.append(j)
+ self.assertEqual(seen, 3 * [1, 0, 2, 1, 3, 2, 4, 3, 5, 4])
+
+ def test_genlet_bad(self):
+ try:
+ Yield(10)
+ except RuntimeError:
+ pass
+
+ def test_nested_genlets(self):
+ seen = []
+ for ii in ax(5):
+ seen.append(ii)
diff --git a/lib/greenlet/tests/test_greenlet.py b/lib/greenlet/tests/test_greenlet.py
new file mode 100644
index 0000000..5509a8b
--- /dev/null
+++ b/lib/greenlet/tests/test_greenlet.py
@@ -0,0 +1,728 @@
+import gc
+import sys
+import time
+import threading
+import unittest
+from abc import ABCMeta, abstractmethod
+
+from greenlet import greenlet
+
+# We manually manage locks in many tests
+# pylint:disable=consider-using-with
+
+class SomeError(Exception):
+ pass
+
+
+def fmain(seen):
+ try:
+ greenlet.getcurrent().parent.switch()
+ except:
+ seen.append(sys.exc_info()[0])
+ raise
+ raise SomeError
+
+
+def send_exception(g, exc):
+ # note: send_exception(g, exc) can now be done with g.throw(exc).
+ # the purpose of this test is to explicitly check the propagation rules.
+ def crasher(exc):
+ raise exc
+ g1 = greenlet(crasher, parent=g)
+ g1.switch(exc)
+
+
+class TestGreenlet(unittest.TestCase):
+ def test_simple(self):
+ lst = []
+
+ def f():
+ lst.append(1)
+ greenlet.getcurrent().parent.switch()
+ lst.append(3)
+ g = greenlet(f)
+ lst.append(0)
+ g.switch()
+ lst.append(2)
+ g.switch()
+ lst.append(4)
+ self.assertEqual(lst, list(range(5)))
+
+ def test_parent_equals_None(self):
+ g = greenlet(parent=None)
+ self.assertIsNotNone(g)
+ self.assertIs(g.parent, greenlet.getcurrent())
+
+ def test_run_equals_None(self):
+ g = greenlet(run=None)
+ self.assertIsNotNone(g)
+ self.assertIsNone(g.run)
+
+ def test_two_children(self):
+ lst = []
+
+ def f():
+ lst.append(1)
+ greenlet.getcurrent().parent.switch()
+ lst.extend([1, 1])
+ g = greenlet(f)
+ h = greenlet(f)
+ g.switch()
+ self.assertEqual(len(lst), 1)
+ h.switch()
+ self.assertEqual(len(lst), 2)
+ h.switch()
+ self.assertEqual(len(lst), 4)
+ self.assertEqual(h.dead, True)
+ g.switch()
+ self.assertEqual(len(lst), 6)
+ self.assertEqual(g.dead, True)
+
+ def test_two_recursive_children(self):
+ lst = []
+
+ def f():
+ lst.append(1)
+ greenlet.getcurrent().parent.switch()
+
+ def g():
+ lst.append(1)
+ g = greenlet(f)
+ g.switch()
+ lst.append(1)
+ g = greenlet(g)
+ g.switch()
+ self.assertEqual(len(lst), 3)
+ self.assertEqual(sys.getrefcount(g), 2)
+
+ def test_threads(self):
+ success = []
+
+ def f():
+ self.test_simple()
+ success.append(True)
+ ths = [threading.Thread(target=f) for i in range(10)]
+ for th in ths:
+ th.start()
+ for th in ths:
+ th.join()
+ self.assertEqual(len(success), len(ths))
+
+ def test_exception(self):
+ seen = []
+ g1 = greenlet(fmain)
+ g2 = greenlet(fmain)
+ g1.switch(seen)
+ g2.switch(seen)
+ g2.parent = g1
+ self.assertEqual(seen, [])
+ self.assertRaises(SomeError, g2.switch)
+ self.assertEqual(seen, [SomeError])
+ g2.switch()
+ self.assertEqual(seen, [SomeError])
+
+ def test_send_exception(self):
+ seen = []
+ g1 = greenlet(fmain)
+ g1.switch(seen)
+ self.assertRaises(KeyError, send_exception, g1, KeyError)
+ self.assertEqual(seen, [KeyError])
+
+ def test_dealloc(self):
+ seen = []
+ g1 = greenlet(fmain)
+ g2 = greenlet(fmain)
+ g1.switch(seen)
+ g2.switch(seen)
+ self.assertEqual(seen, [])
+ del g1
+ gc.collect()
+ self.assertEqual(seen, [greenlet.GreenletExit])
+ del g2
+ gc.collect()
+ self.assertEqual(seen, [greenlet.GreenletExit, greenlet.GreenletExit])
+
+ def test_dealloc_other_thread(self):
+ seen = []
+ someref = []
+ lock = threading.Lock()
+ lock.acquire()
+ lock2 = threading.Lock()
+ lock2.acquire()
+
+ def f():
+ g1 = greenlet(fmain)
+ g1.switch(seen)
+ someref.append(g1)
+ del g1
+ gc.collect()
+ lock.release()
+ lock2.acquire()
+ greenlet() # trigger release
+ lock.release()
+ lock2.acquire()
+ t = threading.Thread(target=f)
+ t.start()
+ lock.acquire()
+ self.assertEqual(seen, [])
+ self.assertEqual(len(someref), 1)
+ del someref[:]
+ gc.collect()
+ # g1 is not released immediately because it's from another thread
+ self.assertEqual(seen, [])
+ lock2.release()
+ lock.acquire()
+ self.assertEqual(seen, [greenlet.GreenletExit])
+ lock2.release()
+ t.join()
+
+ def test_frame(self):
+ def f1():
+ f = sys._getframe(0) # pylint:disable=protected-access
+ self.assertEqual(f.f_back, None)
+ greenlet.getcurrent().parent.switch(f)
+ return "meaning of life"
+ g = greenlet(f1)
+ frame = g.switch()
+ self.assertTrue(frame is g.gr_frame)
+ self.assertTrue(g)
+
+ from_g = g.switch()
+ self.assertFalse(g)
+ self.assertEqual(from_g, 'meaning of life')
+ self.assertEqual(g.gr_frame, None)
+
+ def test_thread_bug(self):
+ def runner(x):
+ g = greenlet(lambda: time.sleep(x))
+ g.switch()
+ t1 = threading.Thread(target=runner, args=(0.2,))
+ t2 = threading.Thread(target=runner, args=(0.3,))
+ t1.start()
+ t2.start()
+ t1.join()
+ t2.join()
+
+ def test_switch_kwargs(self):
+ def run(a, b):
+ self.assertEqual(a, 4)
+ self.assertEqual(b, 2)
+ return 42
+ x = greenlet(run).switch(a=4, b=2)
+ self.assertEqual(x, 42)
+
+ def test_switch_kwargs_to_parent(self):
+ def run(x):
+ greenlet.getcurrent().parent.switch(x=x)
+ greenlet.getcurrent().parent.switch(2, x=3)
+ return x, x ** 2
+ g = greenlet(run)
+ self.assertEqual({'x': 3}, g.switch(3))
+ self.assertEqual(((2,), {'x': 3}), g.switch())
+ self.assertEqual((3, 9), g.switch())
+
+ def test_switch_to_another_thread(self):
+ data = {}
+ error = None
+ created_event = threading.Event()
+ done_event = threading.Event()
+
+ def run():
+ data['g'] = greenlet(lambda: None)
+ created_event.set()
+ done_event.wait()
+ thread = threading.Thread(target=run)
+ thread.start()
+ created_event.wait()
+ try:
+ data['g'].switch()
+ except greenlet.error:
+ error = sys.exc_info()[1]
+ self.assertIsNotNone(error, "greenlet.error was not raised!")
+ done_event.set()
+ thread.join()
+
+ def test_exc_state(self):
+ def f():
+ try:
+ raise ValueError('fun')
+ except: # pylint:disable=bare-except
+ exc_info = sys.exc_info()
+ greenlet(h).switch()
+ self.assertEqual(exc_info, sys.exc_info())
+
+ def h():
+ self.assertEqual(sys.exc_info(), (None, None, None))
+
+ greenlet(f).switch()
+
+ def test_instance_dict(self):
+ def f():
+ greenlet.getcurrent().test = 42
+ def deldict(g):
+ del g.__dict__
+ def setdict(g, value):
+ g.__dict__ = value
+ g = greenlet(f)
+ self.assertEqual(g.__dict__, {})
+ g.switch()
+ self.assertEqual(g.test, 42)
+ self.assertEqual(g.__dict__, {'test': 42})
+ g.__dict__ = g.__dict__
+ self.assertEqual(g.__dict__, {'test': 42})
+ self.assertRaises(TypeError, deldict, g)
+ self.assertRaises(TypeError, setdict, g, 42)
+
+ def test_threaded_reparent(self):
+ data = {}
+ created_event = threading.Event()
+ done_event = threading.Event()
+
+ def run():
+ data['g'] = greenlet(lambda: None)
+ created_event.set()
+ done_event.wait()
+
+ def blank():
+ greenlet.getcurrent().parent.switch()
+
+ def setparent(g, value):
+ g.parent = value
+
+ thread = threading.Thread(target=run)
+ thread.start()
+ created_event.wait()
+ g = greenlet(blank)
+ g.switch()
+ self.assertRaises(ValueError, setparent, g, data['g'])
+ done_event.set()
+ thread.join()
+
+ def test_deepcopy(self):
+ import copy
+ self.assertRaises(TypeError, copy.copy, greenlet())
+ self.assertRaises(TypeError, copy.deepcopy, greenlet())
+
+ def test_parent_restored_on_kill(self):
+ hub = greenlet(lambda: None)
+ main = greenlet.getcurrent()
+ result = []
+ def worker():
+ try:
+ # Wait to be killed
+ main.switch()
+ except greenlet.GreenletExit:
+ # Resurrect and switch to parent
+ result.append(greenlet.getcurrent().parent)
+ result.append(greenlet.getcurrent())
+ hub.switch()
+ g = greenlet(worker, parent=hub)
+ g.switch()
+ del g
+ self.assertTrue(result)
+ self.assertEqual(result[0], main)
+ self.assertEqual(result[1].parent, hub)
+
+ def test_parent_return_failure(self):
+ # No run causes AttributeError on switch
+ g1 = greenlet()
+ # Greenlet that implicitly switches to parent
+ g2 = greenlet(lambda: None, parent=g1)
+ # AttributeError should propagate to us, no fatal errors
+ self.assertRaises(AttributeError, g2.switch)
+
+ def test_throw_exception_not_lost(self):
+ class mygreenlet(greenlet):
+ def __getattribute__(self, name):
+ try:
+ raise Exception()
+ except: # pylint:disable=bare-except
+ pass
+ return greenlet.__getattribute__(self, name)
+ g = mygreenlet(lambda: None)
+ self.assertRaises(SomeError, g.throw, SomeError())
+
+ def test_throw_doesnt_crash(self):
+ result = []
+ def worker():
+ greenlet.getcurrent().parent.switch()
+ def creator():
+ g = greenlet(worker)
+ g.switch()
+ result.append(g)
+ t = threading.Thread(target=creator)
+ t.start()
+ t.join()
+ self.assertRaises(greenlet.error, result[0].throw, SomeError())
+
+ def test_recursive_startup(self):
+ class convoluted(greenlet):
+ def __init__(self):
+ greenlet.__init__(self)
+ self.count = 0
+ def __getattribute__(self, name):
+ if name == 'run' and self.count == 0:
+ self.count = 1
+ self.switch(43)
+ return greenlet.__getattribute__(self, name)
+ def run(self, value):
+ while True:
+ self.parent.switch(value)
+ g = convoluted()
+ self.assertEqual(g.switch(42), 43)
+
+ def test_unexpected_reparenting(self):
+ another = []
+ def worker():
+ g = greenlet(lambda: None)
+ another.append(g)
+ g.switch()
+ t = threading.Thread(target=worker)
+ t.start()
+ t.join()
+ class convoluted(greenlet):
+ def __getattribute__(self, name):
+ if name == 'run':
+ self.parent = another[0] # pylint:disable=attribute-defined-outside-init
+ return greenlet.__getattribute__(self, name)
+ g = convoluted(lambda: None)
+ self.assertRaises(greenlet.error, g.switch)
+
+ def test_threaded_updatecurrent(self):
+ # released when main thread should execute
+ lock1 = threading.Lock()
+ lock1.acquire()
+ # released when another thread should execute
+ lock2 = threading.Lock()
+ lock2.acquire()
+ class finalized(object):
+ def __del__(self):
+ # happens while the main greenlet is in green_updatecurrent();
+ # that code must be very careful not to accidentally call it again;
+ # at the same time we must make sure another thread executes
+ lock2.release()
+ lock1.acquire()
+ # now ts_current belongs to another thread
+ def deallocator():
+ greenlet.getcurrent().parent.switch()
+ def fthread():
+ lock2.acquire()
+ greenlet.getcurrent()
+ del g[0]
+ lock1.release()
+ lock2.acquire()
+ greenlet.getcurrent()
+ lock1.release()
+ main = greenlet.getcurrent()
+ g = [greenlet(deallocator)]
+ g[0].bomb = finalized()
+ g[0].switch()
+ t = threading.Thread(target=fthread)
+ t.start()
+ # let another thread grab ts_current and deallocate g[0]
+ lock2.release()
+ lock1.acquire()
+ # this is the cornerstone
+ # getcurrent() will notice that ts_current belongs to another thread
+ # and start the update process, which would notice that g[0] should
+ # be deallocated, and that will execute an object's finalizer. Now,
+ # that object will let another thread run so it can grab ts_current
+ # again, which would likely crash the interpreter if there's no
+ # check for this case at the end of green_updatecurrent(). This test
+ # passes if getcurrent() returns the correct result, but it's likely
+ # to crash randomly if it does not.
+ self.assertEqual(greenlet.getcurrent(), main)
+ # wait for another thread to complete, just in case
+ t.join()
+
+ def test_dealloc_switch_args_not_lost(self):
+ seen = []
+ def worker():
+ # wait for the value
+ value = greenlet.getcurrent().parent.switch()
+ # delete all references to ourself
+ del worker[0]
+ initiator.parent = greenlet.getcurrent().parent
+ # switch to main with the value, but because
+ # ts_current is the last reference to us we
+ # return immediately
+ try:
+ greenlet.getcurrent().parent.switch(value)
+ finally:
+ seen.append(greenlet.getcurrent())
+ def initiator():
+ return 42 # implicitly falls thru to parent
+ worker = [greenlet(worker)]
+ worker[0].switch() # prime worker
+ initiator = greenlet(initiator, worker[0])
+ value = initiator.switch()
+ self.assertTrue(seen)
+ self.assertEqual(value, 42)
+
+
+
+ def test_tuple_subclass(self):
+ if sys.version_info[0] > 2:
+ # There's no apply in Python 3.x
+ def _apply(func, a, k):
+ func(*a, **k)
+ else:
+ _apply = apply # pylint:disable=undefined-variable
+
+ class mytuple(tuple):
+ def __len__(self):
+ greenlet.getcurrent().switch()
+ return tuple.__len__(self)
+ args = mytuple()
+ kwargs = dict(a=42)
+ def switchapply():
+ _apply(greenlet.getcurrent().parent.switch, args, kwargs)
+ g = greenlet(switchapply)
+ self.assertEqual(g.switch(), kwargs)
+
+ def test_abstract_subclasses(self):
+ AbstractSubclass = ABCMeta(
+ 'AbstractSubclass',
+ (greenlet,),
+ {'run': abstractmethod(lambda self: None)})
+
+ class BadSubclass(AbstractSubclass):
+ pass
+
+ class GoodSubclass(AbstractSubclass):
+ def run(self):
+ pass
+
+ GoodSubclass() # should not raise
+ self.assertRaises(TypeError, BadSubclass)
+
+ def test_implicit_parent_with_threads(self):
+ if not gc.isenabled():
+ return # cannot test with disabled gc
+ N = gc.get_threshold()[0]
+ if N < 50:
+ return # cannot test with such a small N
+ def attempt():
+ lock1 = threading.Lock()
+ lock1.acquire()
+ lock2 = threading.Lock()
+ lock2.acquire()
+ recycled = [False]
+ def another_thread():
+ lock1.acquire() # wait for gc
+ greenlet.getcurrent() # update ts_current
+ lock2.release() # release gc
+ t = threading.Thread(target=another_thread)
+ t.start()
+ class gc_callback(object):
+ def __del__(self):
+ lock1.release()
+ lock2.acquire()
+ recycled[0] = True
+ class garbage(object):
+ def __init__(self):
+ self.cycle = self
+ self.callback = gc_callback()
+ l = []
+ x = range(N*2)
+ current = greenlet.getcurrent()
+ g = garbage()
+ for _ in x:
+ g = None # lose reference to garbage
+ if recycled[0]:
+ # gc callback called prematurely
+ t.join()
+ return False
+ last = greenlet()
+ if recycled[0]:
+ break # yes! gc called in green_new
+ l.append(last) # increase allocation counter
+ else:
+ # gc callback not called when expected
+ gc.collect()
+ if recycled[0]:
+ t.join()
+ return False
+ self.assertEqual(last.parent, current)
+ for g in l:
+ self.assertEqual(g.parent, current)
+ return True
+ for _ in range(5):
+ if attempt():
+ break
+
+ def test_issue_245_reference_counting_subclass_no_threads(self):
+ # https://github.com/python-greenlet/greenlet/issues/245
+ # Before the fix, this crashed pretty reliably on
+ # Python 3.10, at least on macOS; but much less reliably on other
+ # interpreters (memory layout must have changed).
+ # The threaded test crashed more reliably on more interpreters.
+ from greenlet import getcurrent
+ from greenlet import GreenletExit
+
+ class Greenlet(greenlet):
+ pass
+
+ initial_refs = sys.getrefcount(Greenlet)
+ # This has to be an instance variable because
+ # Python 2 raises a SyntaxError if we delete a local
+ # variable referenced in an inner scope.
+ self.glets = [] # pylint:disable=attribute-defined-outside-init
+
+ def greenlet_main():
+ try:
+ getcurrent().parent.switch()
+ except GreenletExit:
+ self.glets.append(getcurrent())
+
+ # Before the fix, this loop is where the crash described above happened.
+ for _ in range(10):
+ Greenlet(greenlet_main).switch()
+
+ del self.glets
+ self.assertEqual(sys.getrefcount(Greenlet), initial_refs)
+
+ def test_issue_245_reference_counting_subclass_threads(self):
+ # https://github.com/python-greenlet/greenlet/issues/245
+ from threading import Thread
+ from threading import Event
+
+ from greenlet import getcurrent
+
+ class MyGreenlet(greenlet):
+ pass
+
+ glets = []
+ ref_cleared = Event()
+
+ def greenlet_main():
+ getcurrent().parent.switch()
+
+ def thread_main(greenlet_running_event):
+ mine = MyGreenlet(greenlet_main)
+ glets.append(mine)
+ # The greenlets being deleted must be active
+ mine.switch()
+ # Don't keep any reference to it in this thread
+ del mine
+ # Let main know we published our greenlet.
+ greenlet_running_event.set()
+ # Wait for main to let us know the references are
+ # gone and the greenlet objects no longer reachable
+ ref_cleared.wait()
+ # The creating thread must call getcurrent() (or a few other
+ # greenlet APIs) because that's when the thread-local list of dead
+ # greenlets gets cleared.
+ getcurrent()
+
+ # We start with 3 references to the subclass:
+ # - This module
+ # - Its __mro__
+ # - The __subclasses__ attribute of greenlet
+ # - (If we call gc.get_referents(), we find four entries, including
+ # some other tuple ``(greenlet)`` that I'm not sure about but must be part
+ # of the machinery.)
+ #
+ # On Python 3.10 it's often enough to just run 3 threads; on Python 2.7,
+ # more threads are needed, and the results are still
+ # non-deterministic. Presumably the memory layouts are different
+ initial_refs = sys.getrefcount(MyGreenlet)
+ thread_ready_events = []
+ for _ in range(
+ initial_refs + 45
+ ):
+ event = Event()
+ thread = Thread(target=thread_main, args=(event,))
+ thread_ready_events.append(event)
+ thread.start()
+
+
+ for done_event in thread_ready_events:
+ done_event.wait()
+
+
+ del glets[:]
+ ref_cleared.set()
+ # Let any other thread run; it will crash the interpreter
+ # if not fixed (or silently corrupt memory and we possibly crash
+ # later).
+ time.sleep(1)
+ self.assertEqual(sys.getrefcount(MyGreenlet), initial_refs)
+
+
+class TestRepr(unittest.TestCase):
+
+ def assertEndsWith(self, got, suffix):
+ self.assertTrue(got.endswith(suffix), (got, suffix))
+
+ def test_main_while_running(self):
+ r = repr(greenlet.getcurrent())
+ self.assertEndsWith(r, " current active started main>")
+
+ def test_main_in_background(self):
+ main = greenlet.getcurrent()
+ def run():
+ return repr(main)
+
+ g = greenlet(run)
+ r = g.switch()
+ self.assertEndsWith(r, ' suspended active started main>')
+
+ def test_initial(self):
+ r = repr(greenlet())
+ self.assertEndsWith(r, ' pending>')
+
+ def test_main_from_other_thread(self):
+ main = greenlet.getcurrent()
+
+ class T(threading.Thread):
+ original_main = thread_main = None
+ main_glet = None
+ def run(self):
+ self.original_main = repr(main)
+ self.main_glet = greenlet.getcurrent()
+ self.thread_main = repr(self.main_glet)
+
+ t = T()
+ t.start()
+ t.join(10)
+
+ self.assertEndsWith(t.original_main, ' suspended active started main>')
+ self.assertEndsWith(t.thread_main, ' current active started main>')
+
+ r = repr(t.main_glet)
+ # main greenlets, even from dead threads, never really appear dead
+ # TODO: Can we find a better way to differentiate that?
+ assert not t.main_glet.dead
+ self.assertEndsWith(r, ' suspended active started main>')
+
+ def test_dead(self):
+ g = greenlet(lambda: None)
+ g.switch()
+ self.assertEndsWith(repr(g), ' dead>')
+ self.assertNotIn('suspended', repr(g))
+ self.assertNotIn('started', repr(g))
+ self.assertNotIn('active', repr(g))
+
+ def test_formatting_produces_native_str(self):
+ # https://github.com/python-greenlet/greenlet/issues/218
+ # %s formatting on Python 2 was producing unicode, not str.
+
+ g_dead = greenlet(lambda: None)
+ g_not_started = greenlet(lambda: None)
+ g_cur = greenlet.getcurrent()
+
+ for g in g_dead, g_not_started, g_cur:
+
+ self.assertIsInstance(
+ '%s' % (g,),
+ str
+ )
+ self.assertIsInstance(
+ '%r' % (g,),
+ str,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
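
As the note next to send_exception() above says, g.throw() is the modern spelling; a short sketch of the propagation rules these tests check (the names are illustrative):

    from greenlet import greenlet, getcurrent

    def waiter(log):
        try:
            getcurrent().parent.switch()    # suspend until resumed or thrown into
        except KeyError:
            log.append("raised inside the greenlet")
            raise                           # the dying greenlet re-raises into its parent

    log = []
    g = greenlet(waiter)
    g.switch(log)                           # start waiter; it suspends in parent.switch()
    try:
        g.throw(KeyError)                   # equivalent of send_exception(g, KeyError)
    except KeyError:
        log.append("and propagated back to the caller of throw()")
    print(log)
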
diff --git a/lib/greenlet/tests/test_leaks.py b/lib/greenlet/tests/test_leaks.py
new file mode 100644
index 0000000..2b02bfd
--- /dev/null
+++ b/lib/greenlet/tests/test_leaks.py
@@ -0,0 +1,178 @@
+import unittest
+import sys
+import gc
+
+import time
+import weakref
+import threading
+
+import greenlet
+
+class TestLeaks(unittest.TestCase):
+
+ def test_arg_refs(self):
+ args = ('a', 'b', 'c')
+ refcount_before = sys.getrefcount(args)
+ # pylint:disable=unnecessary-lambda
+ g = greenlet.greenlet(
+ lambda *args: greenlet.getcurrent().parent.switch(*args))
+ for _ in range(100):
+ g.switch(*args)
+ self.assertEqual(sys.getrefcount(args), refcount_before)
+
+ def test_kwarg_refs(self):
+ kwargs = {}
+ # pylint:disable=unnecessary-lambda
+ g = greenlet.greenlet(
+ lambda **kwargs: greenlet.getcurrent().parent.switch(**kwargs))
+ for _ in range(100):
+ g.switch(**kwargs)
+ self.assertEqual(sys.getrefcount(kwargs), 2)
+
+ assert greenlet.GREENLET_USE_GC # Option to disable this was removed in 1.0
+
+ def recycle_threads(self):
+ # By introducing a thread that sleeps we allow other threads that
+ # have triggered their __block condition, but have not yet had a
+ # chance to deallocate their thread state, to finally do so.
+ # The way it works is by requiring a GIL switch (different thread),
+ # which does a GIL release (sleep), which might do a GIL switch
+ # to finished threads and allow them to clean up.
+ def worker():
+ time.sleep(0.001)
+ t = threading.Thread(target=worker)
+ t.start()
+ time.sleep(0.001)
+ t.join()
+
+ def test_threaded_leak(self):
+ gg = []
+ def worker():
+ # only main greenlet present
+ gg.append(weakref.ref(greenlet.getcurrent()))
+ for _ in range(2):
+ t = threading.Thread(target=worker)
+ t.start()
+ t.join()
+ del t
+ greenlet.getcurrent() # update ts_current
+ self.recycle_threads()
+ greenlet.getcurrent() # update ts_current
+ gc.collect()
+ greenlet.getcurrent() # update ts_current
+ for g in gg:
+ self.assertIsNone(g())
+
+ def test_threaded_adv_leak(self):
+ gg = []
+ def worker():
+ # main and additional *finished* greenlets
+ ll = greenlet.getcurrent().ll = []
+ def additional():
+ ll.append(greenlet.getcurrent())
+ for _ in range(2):
+ greenlet.greenlet(additional).switch()
+ gg.append(weakref.ref(greenlet.getcurrent()))
+ for _ in range(2):
+ t = threading.Thread(target=worker)
+ t.start()
+ t.join()
+ del t
+ greenlet.getcurrent() # update ts_current
+ self.recycle_threads()
+ greenlet.getcurrent() # update ts_current
+ gc.collect()
+ greenlet.getcurrent() # update ts_current
+ for g in gg:
+ self.assertIsNone(g())
+
+ def test_issue251_killing_cross_thread_leaks_list(self, manually_collect_background=True):
+ # See https://github.com/python-greenlet/greenlet/issues/251
+ # Killing a greenlet (probably not the main one)
+ # in one thread from another thread would
+ # result in leaking a list (the ts_delkey list).
+
+ # For the test to be valid, even empty lists have to be tracked by the
+ # GC
+ assert gc.is_tracked([])
+
+ def count_objects(kind=list):
+ # pylint:disable=unidiomatic-typecheck
+ # Collect the garbage.
+ for _ in range(3):
+ gc.collect()
+ gc.collect()
+ return sum(
+ 1
+ for x in gc.get_objects()
+ if type(x) is kind
+ )
+
+ # XXX: The main greenlet of a dead thread is only released
+ # when one of the proper greenlet APIs is used from a different
+ # running thread. See #252 (https://github.com/python-greenlet/greenlet/issues/252)
+ greenlet.getcurrent()
+ greenlets_before = count_objects(greenlet.greenlet)
+
+ background_glet_running = threading.Event()
+ background_glet_killed = threading.Event()
+ background_greenlets = []
+ def background_greenlet():
+ # Throw control back to the main greenlet.
+ greenlet.getcurrent().parent.switch()
+
+ def background_thread():
+ glet = greenlet.greenlet(background_greenlet)
+ background_greenlets.append(glet)
+ glet.switch() # Be sure it's active.
+ # Control is ours again.
+ del glet # Delete one reference from the thread it runs in.
+ background_glet_running.set()
+ background_glet_killed.wait()
+ # To trigger the background collection of the dead
+ # greenlet, thus clearing out the contents of the list, we
+ # need to run some APIs. See issue 252.
+ if manually_collect_background:
+ greenlet.getcurrent()
+
+
+ t = threading.Thread(target=background_thread)
+ t.start()
+ background_glet_running.wait()
+
+ lists_before = count_objects()
+
+ assert len(background_greenlets) == 1
+ self.assertFalse(background_greenlets[0].dead)
+ # Delete the last reference to the background greenlet
+ # from a different thread. This puts it in the background thread's
+ # ts_delkey list.
+ del background_greenlets[:]
+ background_glet_killed.set()
+
+ # Now wait for the background thread to die.
+ t.join(10)
+ del t
+
+ # Free the background main greenlet by forcing greenlet to notice a difference.
+ greenlet.getcurrent()
+ greenlets_after = count_objects(greenlet.greenlet)
+
+ lists_after = count_objects()
+ # On 2.7, we observe that lists_after is smaller than
+ # lists_before. No idea what lists got cleaned up. All the
+ # Python 3 versions match exactly.
+ self.assertLessEqual(lists_after, lists_before)
+
+ self.assertEqual(greenlets_before, greenlets_after)
+
+ @unittest.expectedFailure
+ def test_issue251_issue252_need_to_collect_in_background(self):
+ # This still fails because the leak of the list
+ # still exists when we don't call a greenlet API before exiting the
+ # thread. The proximate cause is that neither of the two greenlets
+ # from the background thread are actually being destroyed, even though
+ # the GC is in fact visiting both objects.
+ # It's not clear where that leak is. For some reason the thread-local dict
+ # holding it isn't being cleaned up.
+ self.test_issue251_killing_cross_thread_leaks_list(manually_collect_background=False)
diff --git a/lib/greenlet/tests/test_stack_saved.py b/lib/greenlet/tests/test_stack_saved.py
new file mode 100644
index 0000000..6c7353b
--- /dev/null
+++ b/lib/greenlet/tests/test_stack_saved.py
@@ -0,0 +1,19 @@
+import greenlet
+import unittest
+
+
+class Test(unittest.TestCase):
+
+ def test_stack_saved(self):
+ main = greenlet.getcurrent()
+ self.assertEqual(main._stack_saved, 0)
+
+ def func():
+ main.switch(main._stack_saved)
+
+ g = greenlet.greenlet(func)
+ x = g.switch()
+ assert x > 0, x
+ assert g._stack_saved > 0, g._stack_saved
+ g.switch()
+ assert g._stack_saved == 0, g._stack_saved
diff --git a/lib/greenlet/tests/test_throw.py b/lib/greenlet/tests/test_throw.py
new file mode 100644
index 0000000..a2014a9
--- /dev/null
+++ b/lib/greenlet/tests/test_throw.py
@@ -0,0 +1,100 @@
+import sys
+import unittest
+
+from greenlet import greenlet
+
+
+def switch(*args):
+ return greenlet.getcurrent().parent.switch(*args)
+
+
+class ThrowTests(unittest.TestCase):
+ def test_class(self):
+ def f():
+ try:
+ switch("ok")
+ except RuntimeError:
+ switch("ok")
+ return
+ switch("fail")
+ g = greenlet(f)
+ res = g.switch()
+ self.assertEqual(res, "ok")
+ res = g.throw(RuntimeError)
+ self.assertEqual(res, "ok")
+
+ def test_val(self):
+ def f():
+ try:
+ switch("ok")
+ except RuntimeError:
+ val = sys.exc_info()[1]
+ if str(val) == "ciao":
+ switch("ok")
+ return
+ switch("fail")
+
+ g = greenlet(f)
+ res = g.switch()
+ self.assertEqual(res, "ok")
+ res = g.throw(RuntimeError("ciao"))
+ self.assertEqual(res, "ok")
+
+ g = greenlet(f)
+ res = g.switch()
+ self.assertEqual(res, "ok")
+ res = g.throw(RuntimeError, "ciao")
+ self.assertEqual(res, "ok")
+
+ def test_kill(self):
+ def f():
+ switch("ok")
+ switch("fail")
+ g = greenlet(f)
+ res = g.switch()
+ self.assertEqual(res, "ok")
+ res = g.throw()
+ self.assertTrue(isinstance(res, greenlet.GreenletExit))
+ self.assertTrue(g.dead)
+ res = g.throw() # immediately eaten by the already-dead greenlet
+ self.assertTrue(isinstance(res, greenlet.GreenletExit))
+
+ def test_throw_goes_to_original_parent(self):
+ main = greenlet.getcurrent()
+
+ def f1():
+ try:
+ main.switch("f1 ready to catch")
+ except IndexError:
+ return "caught"
+ else:
+ return "normal exit"
+
+ def f2():
+ main.switch("from f2")
+
+ g1 = greenlet(f1)
+ g2 = greenlet(f2, parent=g1)
+ self.assertRaises(IndexError, g2.throw, IndexError)
+ self.assertTrue(g2.dead)
+ self.assertTrue(g1.dead)
+
+ g1 = greenlet(f1)
+ g2 = greenlet(f2, parent=g1)
+ res = g1.switch()
+ self.assertEqual(res, "f1 ready to catch")
+ res = g2.throw(IndexError)
+ self.assertEqual(res, "caught")
+ self.assertTrue(g2.dead)
+ self.assertTrue(g1.dead)
+
+ g1 = greenlet(f1)
+ g2 = greenlet(f2, parent=g1)
+ res = g1.switch()
+ self.assertEqual(res, "f1 ready to catch")
+ res = g2.switch()
+ self.assertEqual(res, "from f2")
+ res = g2.throw(IndexError)
+ self.assertEqual(res, "caught")
+ self.assertTrue(g2.dead)
+ self.assertTrue(g1.dead)
diff --git a/lib/greenlet/tests/test_tracing.py b/lib/greenlet/tests/test_tracing.py
new file mode 100644
index 0000000..2ab4d71
--- /dev/null
+++ b/lib/greenlet/tests/test_tracing.py
@@ -0,0 +1,267 @@
+import sys
+import unittest
+import greenlet
+
+class SomeError(Exception):
+ pass
+
+class GreenletTracer(object):
+ oldtrace = None
+
+ def __init__(self, error_on_trace=False):
+ self.actions = []
+ self.error_on_trace = error_on_trace
+
+ def __call__(self, *args):
+ self.actions.append(args)
+ if self.error_on_trace:
+ raise SomeError
+
+ def __enter__(self):
+ self.oldtrace = greenlet.settrace(self)
+ return self.actions
+
+ def __exit__(self, *args):
+ greenlet.settrace(self.oldtrace)
+
+
+class TestGreenletTracing(unittest.TestCase):
+ """
+ Tests of ``greenlet.settrace()``
+ """
+
+ def test_greenlet_tracing(self):
+ main = greenlet.getcurrent()
+ def dummy():
+ pass
+ def dummyexc():
+ raise SomeError()
+
+ with GreenletTracer() as actions:
+ g1 = greenlet.greenlet(dummy)
+ g1.switch()
+ g2 = greenlet.greenlet(dummyexc)
+ self.assertRaises(SomeError, g2.switch)
+
+ self.assertEqual(actions, [
+ ('switch', (main, g1)),
+ ('switch', (g1, main)),
+ ('switch', (main, g2)),
+ ('throw', (g2, main)),
+ ])
+
+ def test_exception_disables_tracing(self):
+ main = greenlet.getcurrent()
+ def dummy():
+ main.switch()
+ g = greenlet.greenlet(dummy)
+ g.switch()
+ with GreenletTracer(error_on_trace=True) as actions:
+ self.assertRaises(SomeError, g.switch)
+ self.assertEqual(greenlet.gettrace(), None)
+
+ self.assertEqual(actions, [
+ ('switch', (main, g)),
+ ])
+
+
+class PythonTracer(object):
+ oldtrace = None
+
+ def __init__(self):
+ self.actions = []
+
+ def __call__(self, frame, event, arg):
+ # Record the co_name so we have an idea what function we're in.
+ self.actions.append((event, frame.f_code.co_name))
+
+ def __enter__(self):
+ self.oldtrace = sys.setprofile(self)
+ return self.actions
+
+ def __exit__(self, *args):
+ sys.setprofile(self.oldtrace)
+
+def tpt_callback():
+ return 42
+
+class TestPythonTracing(unittest.TestCase):
+ """
+ Tests of the interaction of ``sys.settrace()``
+ with greenlet facilities.
+
+ NOTE: Most of this is probably CPython specific.
+ """
+
+ maxDiff = None
+
+ def test_trace_events_trivial(self):
+ with PythonTracer() as actions:
+ tpt_callback()
+ # If we use the sys.settrace instead of setprofile, we get
+ # this:
+
+ # self.assertEqual(actions, [
+ # ('call', 'tpt_callback'),
+ # ('call', '__exit__'),
+ # ])
+
+ self.assertEqual(actions, [
+ ('return', '__enter__'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('call', '__exit__'),
+ ('c_call', '__exit__'),
+ ])
+
+ def _trace_switch(self, glet):
+ with PythonTracer() as actions:
+ glet.switch()
+ return actions
+
+ def _check_trace_events_func_already_set(self, glet):
+ actions = self._trace_switch(glet)
+ self.assertEqual(actions, [
+ ('return', '__enter__'),
+ ('c_call', '_trace_switch'),
+ ('call', 'run'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('return', 'run'),
+ ('c_return', '_trace_switch'),
+ ('call', '__exit__'),
+ ('c_call', '__exit__'),
+ ])
+
+ def test_trace_events_into_greenlet_func_already_set(self):
+ def run():
+ return tpt_callback()
+
+ self._check_trace_events_func_already_set(greenlet.greenlet(run))
+
+ def test_trace_events_into_greenlet_subclass_already_set(self):
+ class X(greenlet.greenlet):
+ def run(self):
+ return tpt_callback()
+ self._check_trace_events_func_already_set(X())
+
+ def _check_trace_events_from_greenlet_sets_profiler(self, g, tracer):
+ g.switch()
+ tpt_callback()
+ tracer.__exit__()
+ self.assertEqual(tracer.actions, [
+ ('return', '__enter__'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('return', 'run'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('call', '__exit__'),
+ ('c_call', '__exit__'),
+ ])
+
+
+ def test_trace_events_from_greenlet_func_sets_profiler(self):
+ tracer = PythonTracer()
+ def run():
+ tracer.__enter__()
+ return tpt_callback()
+
+ self._check_trace_events_from_greenlet_sets_profiler(greenlet.greenlet(run),
+ tracer)
+
+ def test_trace_events_from_greenlet_subclass_sets_profiler(self):
+ tracer = PythonTracer()
+ class X(greenlet.greenlet):
+ def run(self):
+ tracer.__enter__()
+ return tpt_callback()
+
+ self._check_trace_events_from_greenlet_sets_profiler(X(), tracer)
+
+
+ def test_trace_events_multiple_greenlets_switching(self):
+ tracer = PythonTracer()
+
+ g1 = None
+ g2 = None
+
+ def g1_run():
+ tracer.__enter__()
+ tpt_callback()
+ g2.switch()
+ tpt_callback()
+ return 42
+
+ def g2_run():
+ tpt_callback()
+ tracer.__exit__()
+ tpt_callback()
+ g1.switch()
+
+ g1 = greenlet.greenlet(g1_run)
+ g2 = greenlet.greenlet(g2_run)
+
+ x = g1.switch()
+ self.assertEqual(x, 42)
+ tpt_callback() # ensure not in the trace
+ self.assertEqual(tracer.actions, [
+ ('return', '__enter__'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('c_call', 'g1_run'),
+ ('call', 'g2_run'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('call', '__exit__'),
+ ('c_call', '__exit__'),
+ ])
+
+ def test_trace_events_multiple_greenlets_switching_siblings(self):
+ # Like the first version, but get both greenlets running first
+ # as "siblings" and then establish the tracing.
+ tracer = PythonTracer()
+
+ g1 = None
+ g2 = None
+
+ def g1_run():
+ greenlet.getcurrent().parent.switch()
+ tracer.__enter__()
+ tpt_callback()
+ g2.switch()
+ tpt_callback()
+ return 42
+
+ def g2_run():
+ greenlet.getcurrent().parent.switch()
+
+ tpt_callback()
+ tracer.__exit__()
+ tpt_callback()
+ g1.switch()
+
+ g1 = greenlet.greenlet(g1_run)
+ g2 = greenlet.greenlet(g2_run)
+
+ # Start g1
+ g1.switch()
+ # And it immediately returns control to us.
+ # Start g2
+ g2.switch()
+ # Which also returns. Now kick off the real part of the
+ # test.
+ x = g1.switch()
+ self.assertEqual(x, 42)
+
+ tpt_callback() # ensure not in the trace
+ self.assertEqual(tracer.actions, [
+ ('return', '__enter__'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('c_call', 'g1_run'),
+ ('call', 'tpt_callback'),
+ ('return', 'tpt_callback'),
+ ('call', '__exit__'),
+ ('c_call', '__exit__'),
+ ])
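
The tracer used throughout test_tracing.py comes down to greenlet.settrace(). A minimal illustrative sketch (not lines from the patch above): the callback receives an event name such as 'switch' or 'throw' plus an (origin, target) pair.

    import greenlet

    def tracer(event, args):
        # event is 'switch' or 'throw'; args is the (origin, target) pair
        origin, target = args
        print(event, origin, target)

    old = greenlet.settrace(tracer)          # install; returns the previous hook
    try:
        greenlet.greenlet(lambda: None).switch()
    finally:
        greenlet.settrace(old)               # restore whatever was there before
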
diff --git a/lib/greenlet/tests/test_version.py b/lib/greenlet/tests/test_version.py
new file mode 100644
index 0000000..0c9a497
--- /dev/null
+++ b/lib/greenlet/tests/test_version.py
@@ -0,0 +1,39 @@
+#! /usr/bin/env python
+from __future__ import absolute_import
+from __future__ import print_function
+
+import sys
+import os
+import unittest
+
+import greenlet
+
+class VersionTests(unittest.TestCase):
+ def test_version(self):
+ def find_dominating_file(name):
+ if os.path.exists(name):
+ return name
+
+ tried = []
+ here = os.path.abspath(os.path.dirname(__file__))
+ for i in range(10):
+ up = ['..'] * i
+ path = [here] + up + [name]
+ fname = os.path.join(*path)
+ fname = os.path.abspath(fname)
+ tried.append(fname)
+ if os.path.exists(fname):
+ return fname
+ raise AssertionError("Could not find file " + name + "; checked " + str(tried))
+
+ try:
+ setup_py = find_dominating_file('setup.py')
+ except AssertionError as e:
+ raise unittest.SkipTest("Unable to find setup.py; must be out of tree. " + str(e))
+
+
+ invoke_setup = "%s %s --version" % (sys.executable, setup_py)
+ with os.popen(invoke_setup) as f:
+ sversion = f.read().strip()
+
+ self.assertEqual(sversion, greenlet.__version__)
diff --git a/lib/greenlet/tests/test_weakref.py b/lib/greenlet/tests/test_weakref.py
new file mode 100644
index 0000000..6a2ff06
--- /dev/null
+++ b/lib/greenlet/tests/test_weakref.py
@@ -0,0 +1,34 @@
+import gc
+import greenlet
+import weakref
+import unittest
+
+
+class WeakRefTests(unittest.TestCase):
+ def test_dead_weakref(self):
+ def _dead_greenlet():
+ g = greenlet.greenlet(lambda: None)
+ g.switch()
+ return g
+ o = weakref.ref(_dead_greenlet())
+ gc.collect()
+ self.assertEqual(o(), None)
+
+ def test_inactive_weakref(self):
+ o = weakref.ref(greenlet.greenlet())
+ gc.collect()
+ self.assertEqual(o(), None)
+
+ def test_dealloc_weakref(self):
+ seen = []
+ def worker():
+ try:
+ greenlet.getcurrent().parent.switch()
+ finally:
+ seen.append(g())
+ g = greenlet.greenlet(worker)
+ g.switch()
+ g2 = greenlet.greenlet(lambda: None, g)
+ g = weakref.ref(g2)
+ g2 = None
+ self.assertEqual(seen, [None])
diff --git a/lib/include/python/greenlet/greenlet.h b/lib/include/python/greenlet/greenlet.h
new file mode 100644
index 0000000..c788b2f
--- /dev/null
+++ b/lib/include/python/greenlet/greenlet.h
@@ -0,0 +1,161 @@
+/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
+
+/* Greenlet object interface */
+
+#ifndef Py_GREENLETOBJECT_H
+#define Py_GREENLETOBJECT_H
+
+#include <Python.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* This is deprecated and undocumented. It does not change. */
+#define GREENLET_VERSION "1.0.0"
+
+#if PY_VERSION_HEX >= 0x30B00A6
+# define GREENLET_PY311 1
+ /* _PyInterpreterFrame moved to the internal C API in Python 3.11 */
+# include <internal/pycore_frame.h>
+#else
+# define GREENLET_PY311 0
+# define _PyCFrame CFrame
+#endif
+
+typedef struct _greenlet {
+ PyObject_HEAD
+ char* stack_start;
+ char* stack_stop;
+ char* stack_copy;
+ intptr_t stack_saved;
+ struct _greenlet* stack_prev;
+ struct _greenlet* parent;
+ PyObject* run_info;
+ struct _frame* top_frame;
+ int recursion_depth;
+#if GREENLET_PY311
+ _PyInterpreterFrame *current_frame;
+ _PyStackChunk *datastack_chunk;
+ PyObject **datastack_top;
+ PyObject **datastack_limit;
+#endif
+ PyObject* weakreflist;
+#if PY_VERSION_HEX >= 0x030700A3
+ _PyErr_StackItem* exc_info;
+ _PyErr_StackItem exc_state;
+#else
+ PyObject* exc_type;
+ PyObject* exc_value;
+ PyObject* exc_traceback;
+#endif
+ PyObject* dict;
+#if PY_VERSION_HEX >= 0x030700A3
+ PyObject* context;
+#endif
+#if PY_VERSION_HEX >= 0x30A00B1
+ _PyCFrame* cframe;
+#endif
+} PyGreenlet;
+
+#define PyGreenlet_Check(op) PyObject_TypeCheck(op, &PyGreenlet_Type)
+#define PyGreenlet_MAIN(op) (((PyGreenlet*)(op))->stack_stop == (char*)-1)
+#define PyGreenlet_STARTED(op) (((PyGreenlet*)(op))->stack_stop != NULL)
+#define PyGreenlet_ACTIVE(op) (((PyGreenlet*)(op))->stack_start != NULL)
+#define PyGreenlet_GET_PARENT(op) (((PyGreenlet*)(op))->parent)
+
+/* C API functions */
+
+/* Total number of symbols that are exported */
+#define PyGreenlet_API_pointers 8
+
+#define PyGreenlet_Type_NUM 0
+#define PyExc_GreenletError_NUM 1
+#define PyExc_GreenletExit_NUM 2
+
+#define PyGreenlet_New_NUM 3
+#define PyGreenlet_GetCurrent_NUM 4
+#define PyGreenlet_Throw_NUM 5
+#define PyGreenlet_Switch_NUM 6
+#define PyGreenlet_SetParent_NUM 7
+
+#ifndef GREENLET_MODULE
+/* This section is used by modules that use the greenlet C API */
+static void** _PyGreenlet_API = NULL;
+
+# define PyGreenlet_Type \
+ (*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
+
+# define PyExc_GreenletError \
+ ((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
+
+# define PyExc_GreenletExit \
+ ((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
+
+/*
+ * PyGreenlet_New(PyObject *args)
+ *
+ * greenlet.greenlet(run, parent=None)
+ */
+# define PyGreenlet_New \
+ (*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
+ _PyGreenlet_API[PyGreenlet_New_NUM])
+
+/*
+ * PyGreenlet_GetCurrent(void)
+ *
+ * greenlet.getcurrent()
+ */
+# define PyGreenlet_GetCurrent \
+ (*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
+
+/*
+ * PyGreenlet_Throw(
+ * PyGreenlet *greenlet,
+ * PyObject *typ,
+ * PyObject *val,
+ * PyObject *tb)
+ *
+ * g.throw(...)
+ */
+# define PyGreenlet_Throw \
+ (*(PyObject * (*)(PyGreenlet * self, \
+ PyObject * typ, \
+ PyObject * val, \
+ PyObject * tb)) \
+ _PyGreenlet_API[PyGreenlet_Throw_NUM])
+
+/*
+ * PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
+ *
+ * g.switch(*args, **kwargs)
+ */
+# define PyGreenlet_Switch \
+ (*(PyObject * \
+ (*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
+ _PyGreenlet_API[PyGreenlet_Switch_NUM])
+
+/*
+ * PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
+ *
+ * g.parent = new_parent
+ */
+# define PyGreenlet_SetParent \
+ (*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
+ _PyGreenlet_API[PyGreenlet_SetParent_NUM])
+
+/* Macro that imports greenlet and initializes C API */
+/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
+ keep the older definition to be sure older code that might have a copy of
+ the header still works. */
+# define PyGreenlet_Import() \
+ { \
+ _PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
+ }
+
+#endif /* GREENLET_MODULE */
+
+#ifdef __cplusplus
+}
+#endif
+#endif /* !Py_GREENLETOBJECT_H */
diff --git a/lib/pexpect-4.8.0.dist-info/INSTALLER b/lib/pexpect-4.8.0.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/pexpect-4.8.0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/pexpect-4.8.0.dist-info/LICENSE b/lib/pexpect-4.8.0.dist-info/LICENSE
new file mode 100644
index 0000000..754db5a
--- /dev/null
+++ b/lib/pexpect-4.8.0.dist-info/LICENSE
@@ -0,0 +1,20 @@
+ISC LICENSE
+
+ This license is approved by the OSI and FSF as GPL-compatible.
+ http://opensource.org/licenses/isc-license.txt
+
+ Copyright (c) 2013-2014, Pexpect development team
+ Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+
+ Permission to use, copy, modify, and/or distribute this software for any
+ purpose with or without fee is hereby granted, provided that the above
+ copyright notice and this permission notice appear in all copies.
+
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
diff --git a/lib/pexpect-4.8.0.dist-info/METADATA b/lib/pexpect-4.8.0.dist-info/METADATA
new file mode 100644
index 0000000..4c41722
--- /dev/null
+++ b/lib/pexpect-4.8.0.dist-info/METADATA
@@ -0,0 +1,49 @@
+Metadata-Version: 2.1
+Name: pexpect
+Version: 4.8.0
+Summary: Pexpect allows easy control of interactive console applications.
+Home-page: https://pexpect.readthedocs.io/
+Author: Noah Spurrier; Thomas Kluyver; Jeff Quast
+Author-email: noah@noah.org, thomas@kluyver.me.uk, contact@jeffquast.com
+License: ISC license
+Platform: UNIX
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Environment :: Console
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: System Administrators
+Classifier: License :: OSI Approved :: ISC License (ISCL)
+Classifier: Operating System :: POSIX
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Topic :: Software Development
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Classifier: Topic :: Software Development :: Quality Assurance
+Classifier: Topic :: Software Development :: Testing
+Classifier: Topic :: System
+Classifier: Topic :: System :: Archiving :: Packaging
+Classifier: Topic :: System :: Installation/Setup
+Classifier: Topic :: System :: Shells
+Classifier: Topic :: System :: Software Distribution
+Classifier: Topic :: Terminals
+Requires-Dist: ptyprocess (>=0.5)
+
+
+Pexpect is a pure Python module for spawning child applications; controlling
+them; and responding to expected patterns in their output. Pexpect works like
+Don Libes' Expect. Pexpect allows your script to spawn a child application and
+control it as if a human were typing commands.
+
+Pexpect can be used for automating interactive applications such as ssh, ftp,
+passwd, telnet, etc. It can be used to automate setup scripts for duplicating
+software package installations on different servers. It can be used for
+automated software testing. Pexpect is in the spirit of Don Libes' Expect, but
+Pexpect is pure Python.
+
+The main features of Pexpect require the pty module in the Python standard
+library, which is only available on Unix-like systems. Some features—waiting
+for patterns from file descriptors or subprocesses—are also available on
+Windows.
+
+
diff --git a/lib/pexpect-4.8.0.dist-info/RECORD b/lib/pexpect-4.8.0.dist-info/RECORD
new file mode 100644
index 0000000..8cb0e93
--- /dev/null
+++ b/lib/pexpect-4.8.0.dist-info/RECORD
@@ -0,0 +1,38 @@
+pexpect-4.8.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+pexpect-4.8.0.dist-info/LICENSE,sha256=Skg64cTcc4psi3P-tJB04YNdoCq1qmhvJnUCmQb6Nk0,987
+pexpect-4.8.0.dist-info/METADATA,sha256=gD4QNR0Xubt0nLzHH_YhdcXrQS5MIHGuxE-kJs0RNJo,2180
+pexpect-4.8.0.dist-info/RECORD,,
+pexpect-4.8.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pexpect-4.8.0.dist-info/WHEEL,sha256=8zNYZbwQSXoB9IfXOjPfeNwvAsALAjffgk27FqvCWbo,110
+pexpect-4.8.0.dist-info/top_level.txt,sha256=O-b3UY9VQZkW3yDAeFNatUOKO4GojVWO4TTHoI9-E7k,8
+pexpect/ANSI.py,sha256=aA-3tdXz_FZ4G7PAqFZi5g1KBGQ6PzJzS0gm3ALZKZw,12177
+pexpect/FSM.py,sha256=tluiyUGMyIH3q_wLG6Ak1NZVuXUAGNDjq6k6BK1q8RY,13419
+pexpect/__init__.py,sha256=xF4qylJdK-FRy40tmhAXu99fV9ewFasxTH3DSgoZjzQ,3902
+pexpect/__pycache__/ANSI.cpython-39.pyc,,
+pexpect/__pycache__/FSM.cpython-39.pyc,,
+pexpect/__pycache__/__init__.cpython-39.pyc,,
+pexpect/__pycache__/_async.cpython-39.pyc,,
+pexpect/__pycache__/exceptions.cpython-39.pyc,,
+pexpect/__pycache__/expect.cpython-39.pyc,,
+pexpect/__pycache__/fdpexpect.cpython-39.pyc,,
+pexpect/__pycache__/popen_spawn.cpython-39.pyc,,
+pexpect/__pycache__/pty_spawn.cpython-39.pyc,,
+pexpect/__pycache__/pxssh.cpython-39.pyc,,
+pexpect/__pycache__/replwrap.cpython-39.pyc,,
+pexpect/__pycache__/run.cpython-39.pyc,,
+pexpect/__pycache__/screen.cpython-39.pyc,,
+pexpect/__pycache__/spawnbase.cpython-39.pyc,,
+pexpect/__pycache__/utils.cpython-39.pyc,,
+pexpect/_async.py,sha256=UCUC9kbBZGjzG12YcR_M5yBjB4Dwc8nJOYNPklL-OdU,3304
+pexpect/bashrc.sh,sha256=CHK8qDg_HtDVdfyDULOV8MZDRDr4pOaIbo31XV58nQs,380
+pexpect/exceptions.py,sha256=A9C1PWbBc2j9AKvnv7UkPCawhFTEGYmeULW0vwbMvXQ,1068
+pexpect/expect.py,sha256=KKtBmx2MYa-yDE715XlHUcloKe5ndBD359a4OYVXD84,13827
+pexpect/fdpexpect.py,sha256=ugTrwveFi-zfl_nOPjbRyLUER1Wmhu8YxczCWtZgZWc,5828
+pexpect/popen_spawn.py,sha256=hVHOqr22jD2Pr-yVgsfwgqGAtULLi6kJLKQRrTBPvEg,6161
+pexpect/pty_spawn.py,sha256=ZygSYsdnVJ5acxiNM9gLvLrT2AVqgwJvbDcPaTxxv9E,37382
+pexpect/pxssh.py,sha256=bZHwFDOn1gC8U_Sl07eFFRlYfCjGCwEoC9WaZCHQo5Y,24279
+pexpect/replwrap.py,sha256=Raq9XgYfIlF-rH_CALgFbzK1H_A4o0NqmK9q45anmVA,5633
+pexpect/run.py,sha256=XK2GwW6_wbUZ6buIDbhouaOySVPnc5IahbgSjieks50,6628
+pexpect/screen.py,sha256=-twD4sIEp83nzuYH9lRDzwHfesoTgVGWglsBYWOK7Ks,13704
+pexpect/spawnbase.py,sha256=FoaNvkGYIXrD6xmBedrJmm_oVLMatqbGCXc027CugkQ,21247
+pexpect/utils.py,sha256=1jIhzU7eBvY3pbW3LZoJhCOU2KWqgty5HgQ6VBYIp5U,6019
diff --git a/lib/pexpect-4.8.0.dist-info/REQUESTED b/lib/pexpect-4.8.0.dist-info/REQUESTED
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/pexpect-4.8.0.dist-info/REQUESTED
diff --git a/lib/pexpect-4.8.0.dist-info/WHEEL b/lib/pexpect-4.8.0.dist-info/WHEEL
new file mode 100644
index 0000000..8b701e9
--- /dev/null
+++ b/lib/pexpect-4.8.0.dist-info/WHEEL
@@ -0,0 +1,6 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.33.6)
+Root-Is-Purelib: true
+Tag: py2-none-any
+Tag: py3-none-any
+
diff --git a/lib/pexpect-4.8.0.dist-info/top_level.txt b/lib/pexpect-4.8.0.dist-info/top_level.txt
new file mode 100644
index 0000000..808fb07
--- /dev/null
+++ b/lib/pexpect-4.8.0.dist-info/top_level.txt
@@ -0,0 +1 @@
+pexpect
diff --git a/lib/pexpect/ANSI.py b/lib/pexpect/ANSI.py
new file mode 100644
index 0000000..1cd2e90
--- /dev/null
+++ b/lib/pexpect/ANSI.py
@@ -0,0 +1,351 @@
+'''This implements an ANSI (VT100) terminal emulator as a subclass of screen.
+
+PEXPECT LICENSE
+
+ This license is approved by the OSI and FSF as GPL-compatible.
+ http://opensource.org/licenses/isc-license.txt
+
+ Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+ PERMISSION TO USE, COPY, MODIFY, AND/OR DISTRIBUTE THIS SOFTWARE FOR ANY
+ PURPOSE WITH OR WITHOUT FEE IS HEREBY GRANTED, PROVIDED THAT THE ABOVE
+ COPYRIGHT NOTICE AND THIS PERMISSION NOTICE APPEAR IN ALL COPIES.
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''
+
+# references:
+# http://en.wikipedia.org/wiki/ANSI_escape_code
+# http://www.retards.org/terminals/vt102.html
+# http://vt100.net/docs/vt102-ug/contents.html
+# http://vt100.net/docs/vt220-rm/
+# http://www.termsys.demon.co.uk/vtansi.htm
+
+from . import screen
+from . import FSM
+import string
+
+#
+# The 'Do.*' functions are helper functions for the ANSI class.
+#
+def DoEmit (fsm):
+
+ screen = fsm.memory[0]
+ screen.write_ch(fsm.input_symbol)
+
+def DoStartNumber (fsm):
+
+ fsm.memory.append (fsm.input_symbol)
+
+def DoBuildNumber (fsm):
+
+ ns = fsm.memory.pop()
+ ns = ns + fsm.input_symbol
+ fsm.memory.append (ns)
+
+def DoBackOne (fsm):
+
+ screen = fsm.memory[0]
+ screen.cursor_back ()
+
+def DoBack (fsm):
+
+ count = int(fsm.memory.pop())
+ screen = fsm.memory[0]
+ screen.cursor_back (count)
+
+def DoDownOne (fsm):
+
+ screen = fsm.memory[0]
+ screen.cursor_down ()
+
+def DoDown (fsm):
+
+ count = int(fsm.memory.pop())
+ screen = fsm.memory[0]
+ screen.cursor_down (count)
+
+def DoForwardOne (fsm):
+
+ screen = fsm.memory[0]
+ screen.cursor_forward ()
+
+def DoForward (fsm):
+
+ count = int(fsm.memory.pop())
+ screen = fsm.memory[0]
+ screen.cursor_forward (count)
+
+def DoUpReverse (fsm):
+
+ screen = fsm.memory[0]
+ screen.cursor_up_reverse()
+
+def DoUpOne (fsm):
+
+ screen = fsm.memory[0]
+ screen.cursor_up ()
+
+def DoUp (fsm):
+
+ count = int(fsm.memory.pop())
+ screen = fsm.memory[0]
+ screen.cursor_up (count)
+
+def DoHome (fsm):
+
+ c = int(fsm.memory.pop())
+ r = int(fsm.memory.pop())
+ screen = fsm.memory[0]
+ screen.cursor_home (r,c)
+
+def DoHomeOrigin (fsm):
+
+ c = 1
+ r = 1
+ screen = fsm.memory[0]
+ screen.cursor_home (r,c)
+
+def DoEraseDown (fsm):
+
+ screen = fsm.memory[0]
+ screen.erase_down()
+
+def DoErase (fsm):
+
+ arg = int(fsm.memory.pop())
+ screen = fsm.memory[0]
+ if arg == 0:
+ screen.erase_down()
+ elif arg == 1:
+ screen.erase_up()
+ elif arg == 2:
+ screen.erase_screen()
+
+def DoEraseEndOfLine (fsm):
+
+ screen = fsm.memory[0]
+ screen.erase_end_of_line()
+
+def DoEraseLine (fsm):
+
+ arg = int(fsm.memory.pop())
+ screen = fsm.memory[0]
+ if arg == 0:
+ screen.erase_end_of_line()
+ elif arg == 1:
+ screen.erase_start_of_line()
+ elif arg == 2:
+ screen.erase_line()
+
+def DoEnableScroll (fsm):
+
+ screen = fsm.memory[0]
+ screen.scroll_screen()
+
+def DoCursorSave (fsm):
+
+ screen = fsm.memory[0]
+ screen.cursor_save_attrs()
+
+def DoCursorRestore (fsm):
+
+ screen = fsm.memory[0]
+ screen.cursor_restore_attrs()
+
+def DoScrollRegion (fsm):
+
+ screen = fsm.memory[0]
+ r2 = int(fsm.memory.pop())
+ r1 = int(fsm.memory.pop())
+ screen.scroll_screen_rows (r1,r2)
+
+def DoMode (fsm):
+
+ screen = fsm.memory[0]
+ mode = fsm.memory.pop() # Should be 4
+ # screen.setReplaceMode ()
+
+def DoLog (fsm):
+
+ screen = fsm.memory[0]
+ fsm.memory = [screen]
+ fout = open ('log', 'a')
+ fout.write (fsm.input_symbol + ',' + fsm.current_state + '\n')
+ fout.close()
+
+class term (screen.screen):
+
+ '''This class is an abstract, generic terminal.
+ This does nothing. This is a placeholder that
+ provides a common base class for other terminals
+ such as an ANSI terminal. '''
+
+ def __init__ (self, r=24, c=80, *args, **kwargs):
+
+ screen.screen.__init__(self, r,c,*args,**kwargs)
+
+class ANSI (term):
+ '''This class implements an ANSI (VT100) terminal.
+ It is a stream filter that recognizes ANSI terminal
+ escape sequences and maintains the state of a screen object. '''
+
+ def __init__ (self, r=24,c=80,*args,**kwargs):
+
+ term.__init__(self,r,c,*args,**kwargs)
+
+ #self.screen = screen (24,80)
+ self.state = FSM.FSM ('INIT',[self])
+ self.state.set_default_transition (DoLog, 'INIT')
+ self.state.add_transition_any ('INIT', DoEmit, 'INIT')
+ self.state.add_transition ('\x1b', 'INIT', None, 'ESC')
+ self.state.add_transition_any ('ESC', DoLog, 'INIT')
+ self.state.add_transition ('(', 'ESC', None, 'G0SCS')
+ self.state.add_transition (')', 'ESC', None, 'G1SCS')
+ self.state.add_transition_list ('AB012', 'G0SCS', None, 'INIT')
+ self.state.add_transition_list ('AB012', 'G1SCS', None, 'INIT')
+ self.state.add_transition ('7', 'ESC', DoCursorSave, 'INIT')
+ self.state.add_transition ('8', 'ESC', DoCursorRestore, 'INIT')
+ self.state.add_transition ('M', 'ESC', DoUpReverse, 'INIT')
+ self.state.add_transition ('>', 'ESC', DoUpReverse, 'INIT')
+ self.state.add_transition ('<', 'ESC', DoUpReverse, 'INIT')
+ self.state.add_transition ('=', 'ESC', None, 'INIT') # Selects application keypad.
+ self.state.add_transition ('#', 'ESC', None, 'GRAPHICS_POUND')
+ self.state.add_transition_any ('GRAPHICS_POUND', None, 'INIT')
+ self.state.add_transition ('[', 'ESC', None, 'ELB')
+ # ELB means Escape Left Bracket. That is ^[[
+ self.state.add_transition ('H', 'ELB', DoHomeOrigin, 'INIT')
+ self.state.add_transition ('D', 'ELB', DoBackOne, 'INIT')
+ self.state.add_transition ('B', 'ELB', DoDownOne, 'INIT')
+ self.state.add_transition ('C', 'ELB', DoForwardOne, 'INIT')
+ self.state.add_transition ('A', 'ELB', DoUpOne, 'INIT')
+ self.state.add_transition ('J', 'ELB', DoEraseDown, 'INIT')
+ self.state.add_transition ('K', 'ELB', DoEraseEndOfLine, 'INIT')
+ self.state.add_transition ('r', 'ELB', DoEnableScroll, 'INIT')
+ self.state.add_transition ('m', 'ELB', self.do_sgr, 'INIT')
+ self.state.add_transition ('?', 'ELB', None, 'MODECRAP')
+ self.state.add_transition_list (string.digits, 'ELB', DoStartNumber, 'NUMBER_1')
+ self.state.add_transition_list (string.digits, 'NUMBER_1', DoBuildNumber, 'NUMBER_1')
+ self.state.add_transition ('D', 'NUMBER_1', DoBack, 'INIT')
+ self.state.add_transition ('B', 'NUMBER_1', DoDown, 'INIT')
+ self.state.add_transition ('C', 'NUMBER_1', DoForward, 'INIT')
+ self.state.add_transition ('A', 'NUMBER_1', DoUp, 'INIT')
+ self.state.add_transition ('J', 'NUMBER_1', DoErase, 'INIT')
+ self.state.add_transition ('K', 'NUMBER_1', DoEraseLine, 'INIT')
+ self.state.add_transition ('l', 'NUMBER_1', DoMode, 'INIT')
+ ### It gets worse... the 'm' code can have an infinite number of
+ ### number;number;number before it. I've never seen more than two,
+ ### but the specs say it's allowed. crap!
+ self.state.add_transition ('m', 'NUMBER_1', self.do_sgr, 'INIT')
+ ### LED control. Same implementation problem as 'm' code.
+ self.state.add_transition ('q', 'NUMBER_1', self.do_decsca, 'INIT')
+
+ # \E[?47h switch to alternate screen
+ # \E[?47l restores to normal screen from alternate screen.
+ self.state.add_transition_list (string.digits, 'MODECRAP', DoStartNumber, 'MODECRAP_NUM')
+ self.state.add_transition_list (string.digits, 'MODECRAP_NUM', DoBuildNumber, 'MODECRAP_NUM')
+ self.state.add_transition ('l', 'MODECRAP_NUM', self.do_modecrap, 'INIT')
+ self.state.add_transition ('h', 'MODECRAP_NUM', self.do_modecrap, 'INIT')
+
+#RM Reset Mode Esc [ Ps l none
+ self.state.add_transition (';', 'NUMBER_1', None, 'SEMICOLON')
+ self.state.add_transition_any ('SEMICOLON', DoLog, 'INIT')
+ self.state.add_transition_list (string.digits, 'SEMICOLON', DoStartNumber, 'NUMBER_2')
+ self.state.add_transition_list (string.digits, 'NUMBER_2', DoBuildNumber, 'NUMBER_2')
+ self.state.add_transition_any ('NUMBER_2', DoLog, 'INIT')
+ self.state.add_transition ('H', 'NUMBER_2', DoHome, 'INIT')
+ self.state.add_transition ('f', 'NUMBER_2', DoHome, 'INIT')
+ self.state.add_transition ('r', 'NUMBER_2', DoScrollRegion, 'INIT')
+ ### It gets worse... the 'm' code can have an infinite number of
+ ### number;number;number before it. I've never seen more than two,
+ ### but the specs say it's allowed. crap!
+ self.state.add_transition ('m', 'NUMBER_2', self.do_sgr, 'INIT')
+ ### LED control. Same problem as 'm' code.
+ self.state.add_transition ('q', 'NUMBER_2', self.do_decsca, 'INIT')
+ self.state.add_transition (';', 'NUMBER_2', None, 'SEMICOLON_X')
+
+ # Create a state for 'q' and 'm' which allows an infinite number of ignored numbers
+ self.state.add_transition_any ('SEMICOLON_X', DoLog, 'INIT')
+ self.state.add_transition_list (string.digits, 'SEMICOLON_X', DoStartNumber, 'NUMBER_X')
+ self.state.add_transition_list (string.digits, 'NUMBER_X', DoBuildNumber, 'NUMBER_X')
+ self.state.add_transition_any ('NUMBER_X', DoLog, 'INIT')
+ self.state.add_transition ('m', 'NUMBER_X', self.do_sgr, 'INIT')
+ self.state.add_transition ('q', 'NUMBER_X', self.do_decsca, 'INIT')
+ self.state.add_transition (';', 'NUMBER_X', None, 'SEMICOLON_X')
+
+ def process (self, c):
+ """Process a single character. Called by :meth:`write`."""
+ if isinstance(c, bytes):
+ c = self._decode(c)
+ self.state.process(c)
+
+ def process_list (self, l):
+
+ self.write(l)
+
+ def write (self, s):
+ """Process text, writing it to the virtual screen while handling
+ ANSI escape codes.
+ """
+ if isinstance(s, bytes):
+ s = self._decode(s)
+ for c in s:
+ self.process(c)
+
+ def flush (self):
+ pass
+
+ def write_ch (self, ch):
+ '''This puts a character at the current cursor position. The cursor
+ position is moved forward with wrap-around, but no scrolling is done if
+ the cursor hits the lower-right corner of the screen. '''
+
+ if isinstance(ch, bytes):
+ ch = self._decode(ch)
+
+ # \r and \n produce calls to cr() and lf(), respectively.
+ ch = ch[0]
+
+ if ch == u'\r':
+ self.cr()
+ return
+ if ch == u'\n':
+ self.crlf()
+ return
+ if ch == chr(screen.BS):
+ self.cursor_back()
+ return
+ self.put_abs(self.cur_r, self.cur_c, ch)
+ old_r = self.cur_r
+ old_c = self.cur_c
+ self.cursor_forward()
+ if old_c == self.cur_c:
+ self.cursor_down()
+ if old_r != self.cur_r:
+ self.cursor_home (self.cur_r, 1)
+ else:
+ self.scroll_up ()
+ self.cursor_home (self.cur_r, 1)
+ self.erase_line()
+
+ def do_sgr (self, fsm):
+ '''Select Graphic Rendition, e.g. color. '''
+ screen = fsm.memory[0]
+ fsm.memory = [screen]
+
+ def do_decsca (self, fsm):
+ '''Select character protection attribute. '''
+ screen = fsm.memory[0]
+ fsm.memory = [screen]
+
+ def do_modecrap (self, fsm):
+ '''Handler for \x1b[?<number>h and \x1b[?<number>l. If anyone
+ wanted to actually use these, they'd need to add more states to the
+ FSM rather than just improve or override this method. '''
+ screen = fsm.memory[0]
+ fsm.memory = [screen]
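
A small illustrative sketch of driving the emulator above (not lines from the patch): feed text containing VT100 escape codes and read back the rendered screen. It assumes that str() of the terminal renders its contents, as pexpect's screen class does.

    from pexpect import ANSI

    term = ANSI.ANSI(r=24, c=80)
    # ESC[2J clears the screen, then plain text is written at the cursor.
    term.write('scratch\x1b[2Jhello')
    print(str(term))
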
diff --git a/lib/pexpect/FSM.py b/lib/pexpect/FSM.py
new file mode 100644
index 0000000..46b392e
--- /dev/null
+++ b/lib/pexpect/FSM.py
@@ -0,0 +1,334 @@
+#!/usr/bin/env python
+
+'''This module implements a Finite State Machine (FSM). In addition to state
+this FSM also maintains a user defined "memory". So this FSM can be used as a
+Push-down Automata (PDA) since a PDA is a FSM + memory.
+
+The following describes how the FSM works, but you will probably also need to
+see the example function to understand how the FSM is used in practice.
+
+You define an FSM by building tables of transitions. For a given input symbol
+the process() method uses these tables to decide what action to call and what
+the next state will be. The FSM has a table of transitions that associate:
+
+ (input_symbol, current_state) --> (action, next_state)
+
+Where "action" is a function you define. The symbols and states can be any
+objects. You use the add_transition() and add_transition_list() methods to add
+to the transition table. The FSM also has a table of transitions that
+associate:
+
+ (current_state) --> (action, next_state)
+
+You use the add_transition_any() method to add to this transition table. The
+FSM also has one default transition that is not associated with any specific
+input_symbol or state. You use the set_default_transition() method to set the
+default transition.
+
+When an action function is called it is passed a reference to the FSM. The
+action function may then access attributes of the FSM such as input_symbol,
+current_state, or "memory". The "memory" attribute can be any object that you
+want to pass along to the action functions. It is not used by the FSM itself.
+For parsing you would typically pass a list to be used as a stack.
+
+The processing sequence is as follows. The process() method is given an
+input_symbol to process. The FSM will search the table of transitions that
+associate:
+
+ (input_symbol, current_state) --> (action, next_state)
+
+If the pair (input_symbol, current_state) is found then process() will call the
+associated action function and then set the current state to the next_state.
+
+If the FSM cannot find a match for (input_symbol, current_state) it will then
+search the table of transitions that associate:
+
+ (current_state) --> (action, next_state)
+
+If the current_state is found then the process() method will call the
+associated action function and then set the current state to the next_state.
+Notice that this table lacks an input_symbol. It lets you define transitions
+for a current_state and ANY input_symbol. Hence, it is called the "any" table.
+Remember, it is always checked after first searching the table for a specific
+(input_symbol, current_state).
+
+For the case where the FSM did not match either of the previous two cases the
+FSM will try to use the default transition. If the default transition is
+defined then the process() method will call the associated action function and
+then set the current state to the next_state. This lets you define a default
+transition as a catch-all case. You can think of it as an exception handler.
+There can be only one default transition.
+
+Finally, if none of the previous cases are defined for an input_symbol and
+current_state then the FSM will raise an exception. This may be desirable, but
+you can always prevent this just by defining a default transition.
+
+Noah Spurrier 20020822
+
+PEXPECT LICENSE
+
+ This license is approved by the OSI and FSF as GPL-compatible.
+ http://opensource.org/licenses/isc-license.txt
+
+ Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+ PERMISSION TO USE, COPY, MODIFY, AND/OR DISTRIBUTE THIS SOFTWARE FOR ANY
+ PURPOSE WITH OR WITHOUT FEE IS HEREBY GRANTED, PROVIDED THAT THE ABOVE
+ COPYRIGHT NOTICE AND THIS PERMISSION NOTICE APPEAR IN ALL COPIES.
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''
+
+class ExceptionFSM(Exception):
+
+ '''This is the FSM Exception class.'''
+
+ def __init__(self, value):
+ self.value = value
+
+ def __str__(self):
+ return 'ExceptionFSM: ' + str(self.value)
+
+class FSM:
+
+ '''This is a Finite State Machine (FSM).
+ '''
+
+ def __init__(self, initial_state, memory=None):
+
+ '''This creates the FSM. You set the initial state here. The "memory"
+ attribute is any object that you want to pass along to the action
+ functions. It is not used by the FSM. For parsing you would typically
+ pass a list to be used as a stack. '''
+
+ # Map (input_symbol, current_state) --> (action, next_state).
+ self.state_transitions = {}
+ # Map (current_state) --> (action, next_state).
+ self.state_transitions_any = {}
+ self.default_transition = None
+
+ self.input_symbol = None
+ self.initial_state = initial_state
+ self.current_state = self.initial_state
+ self.next_state = None
+ self.action = None
+ self.memory = memory
+
+ def reset (self):
+
+ '''This sets the current_state to the initial_state and sets
+ input_symbol to None. The initial state was set by the constructor
+ __init__(). '''
+
+ self.current_state = self.initial_state
+ self.input_symbol = None
+
+ def add_transition (self, input_symbol, state, action=None, next_state=None):
+
+ '''This adds a transition that associates:
+
+ (input_symbol, current_state) --> (action, next_state)
+
+ The action may be set to None in which case the process() method will
+ ignore the action and only set the next_state. The next_state may be
+ set to None in which case the current state will be unchanged.
+
+ You can also set transitions for a list of symbols by using
+ add_transition_list(). '''
+
+ if next_state is None:
+ next_state = state
+ self.state_transitions[(input_symbol, state)] = (action, next_state)
+
+ def add_transition_list (self, list_input_symbols, state, action=None, next_state=None):
+
+ '''This adds the same transition for a list of input symbols.
+ You can pass a list or a string. Note that it is handy to use
+ string.digits, string.whitespace, string.letters, etc. to add
+ transitions that match character classes.
+
+ The action may be set to None in which case the process() method will
+ ignore the action and only set the next_state. The next_state may be
+ set to None in which case the current state will be unchanged. '''
+
+ if next_state is None:
+ next_state = state
+ for input_symbol in list_input_symbols:
+ self.add_transition (input_symbol, state, action, next_state)
+
+ def add_transition_any (self, state, action=None, next_state=None):
+
+ '''This adds a transition that associates:
+
+ (current_state) --> (action, next_state)
+
+ That is, any input symbol will match the current state.
+ The process() method checks the "any" state associations after it first
+ checks for an exact match of (input_symbol, current_state).
+
+ The action may be set to None in which case the process() method will
+ ignore the action and only set the next_state. The next_state may be
+ set to None in which case the current state will be unchanged. '''
+
+ if next_state is None:
+ next_state = state
+ self.state_transitions_any [state] = (action, next_state)
+
+ def set_default_transition (self, action, next_state):
+
+ '''This sets the default transition. This defines an action and
+ next_state if the FSM cannot find the input symbol and the current
+ state in the transition list and if the FSM cannot find the
+ current_state in the transition_any list. This is useful as a final
+ fall-through state for catching errors and undefined states.
+
+ The default transition can be removed by setting the attribute
+ default_transition to None. '''
+
+ self.default_transition = (action, next_state)
+
+ def get_transition (self, input_symbol, state):
+
+ '''This returns (action, next state) given an input_symbol and state.
+ This does not modify the FSM state, so calling this method has no side
+ effects. Normally you do not call this method directly. It is called by
+ process().
+
+ The sequence of steps to check for a defined transition goes from the
+ most specific to the least specific.
+
+ 1. Check state_transitions[] that match exactly the tuple,
+ (input_symbol, state)
+
+ 2. Check state_transitions_any[] that match (state)
+ In other words, match a specific state and ANY input_symbol.
+
+ 3. Check if the default_transition is defined.
+ This catches any input_symbol and any state.
+ This is a handler for errors, undefined states, or defaults.
+
+ 4. No transition was defined. If we get here then raise an exception.
+ '''
+
+ if (input_symbol, state) in self.state_transitions:
+ return self.state_transitions[(input_symbol, state)]
+ elif state in self.state_transitions_any:
+ return self.state_transitions_any[state]
+ elif self.default_transition is not None:
+ return self.default_transition
+ else:
+ raise ExceptionFSM ('Transition is undefined: (%s, %s).' %
+ (str(input_symbol), str(state)) )
+
+ def process (self, input_symbol):
+
+ '''This is the main method that you call to process input. This may
+ cause the FSM to change state and call an action. This method calls
+ get_transition() to find the action and next_state associated with the
+ input_symbol and current_state. If the action is None then the action
+ is not called and only the current state is changed. This method
+ processes one complete input symbol. You can process a list of symbols
+ (or a string) by calling process_list(). '''
+
+ self.input_symbol = input_symbol
+ (self.action, self.next_state) = self.get_transition (self.input_symbol, self.current_state)
+ if self.action is not None:
+ self.action (self)
+ self.current_state = self.next_state
+ self.next_state = None
+
+ def process_list (self, input_symbols):
+
+ '''This takes a list and sends each element to process(). The list may
+ be a string or any iterable object. '''
+
+ for s in input_symbols:
+ self.process (s)
+
+##############################################################################
+# The following is an example that demonstrates the use of the FSM class to
+# process an RPN expression. Run this module from the command line. You will
+# get a prompt > for input. Enter an RPN Expression. Numbers may be integers.
+# Operators are * / + - Use the = sign to evaluate and print the expression.
+# For example:
+#
+# 167 3 2 2 * * * 1 - =
+#
+# will print:
+#
+# 2003
+##############################################################################
+
+import sys
+import string
+
+PY3 = (sys.version_info[0] >= 3)
+
+#
+# These define the actions.
+# Note that "memory" is a list being used as a stack.
+#
+
+def BeginBuildNumber (fsm):
+ fsm.memory.append (fsm.input_symbol)
+
+def BuildNumber (fsm):
+ s = fsm.memory.pop ()
+ s = s + fsm.input_symbol
+ fsm.memory.append (s)
+
+def EndBuildNumber (fsm):
+ s = fsm.memory.pop ()
+ fsm.memory.append (int(s))
+
+def DoOperator (fsm):
+ ar = fsm.memory.pop()
+ al = fsm.memory.pop()
+ if fsm.input_symbol == '+':
+ fsm.memory.append (al + ar)
+ elif fsm.input_symbol == '-':
+ fsm.memory.append (al - ar)
+ elif fsm.input_symbol == '*':
+ fsm.memory.append (al * ar)
+ elif fsm.input_symbol == '/':
+ fsm.memory.append (al / ar)
+
+def DoEqual (fsm):
+ print(str(fsm.memory.pop()))
+
+def Error (fsm):
+ print('That does not compute.')
+ print(str(fsm.input_symbol))
+
+def main():
+
+ '''This is where the example starts and the FSM state transitions are
+ defined. Note that states are strings (such as 'INIT'). This is not
+ necessary, but it makes the example easier to read. '''
+
+ f = FSM ('INIT', [])
+ f.set_default_transition (Error, 'INIT')
+ f.add_transition_any ('INIT', None, 'INIT')
+ f.add_transition ('=', 'INIT', DoEqual, 'INIT')
+ f.add_transition_list (string.digits, 'INIT', BeginBuildNumber, 'BUILDING_NUMBER')
+ f.add_transition_list (string.digits, 'BUILDING_NUMBER', BuildNumber, 'BUILDING_NUMBER')
+ f.add_transition_list (string.whitespace, 'BUILDING_NUMBER', EndBuildNumber, 'INIT')
+ f.add_transition_list ('+-*/', 'INIT', DoOperator, 'INIT')
+
+ print()
+ print('Enter an RPN Expression.')
+ print('Numbers may be integers. Operators are * / + -')
+ print('Use the = sign to evaluate and print the expression.')
+ print('For example: ')
+ print(' 167 3 2 2 * * * 1 - =')
+ inputstr = (input if PY3 else raw_input)('> ') # analysis:ignore
+ f.process_list(inputstr)
+
+
+if __name__ == '__main__':
+ main()
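
Beyond the RPN demo above, the transition tables described in the module docstring can be exercised in a few lines. An illustrative sketch (not lines from the patch), using a two-state toggle:

    from pexpect.FSM import FSM

    def flip(fsm):
        print('flipping on', fsm.input_symbol)

    f = FSM('OFF', memory=[])
    f.add_transition('1', 'OFF', flip, 'ON')
    f.add_transition('0', 'ON', flip, 'OFF')
    f.add_transition_any('OFF', None, 'OFF')   # ignore any other symbol
    f.add_transition_any('ON', None, 'ON')
    f.process_list('10x1')                     # OFF -> ON -> OFF -> OFF -> ON
    assert f.current_state == 'ON'
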
diff --git a/lib/pexpect/__init__.py b/lib/pexpect/__init__.py
new file mode 100644
index 0000000..7e30453
--- /dev/null
+++ b/lib/pexpect/__init__.py
@@ -0,0 +1,85 @@
+'''Pexpect is a Python module for spawning child applications and controlling
+them automatically. Pexpect can be used for automating interactive applications
+such as ssh, ftp, passwd, telnet, etc. It can be used to automate setup
+scripts for duplicating software package installations on different servers. It
+can be used for automated software testing. Pexpect is in the spirit of Don
+Libes' Expect, but Pexpect is pure Python. Other Expect-like modules for Python
+require TCL and Expect or require C extensions to be compiled. Pexpect does not
+use C, Expect, or TCL extensions. It should work on any platform that supports
+the standard Python pty module. The Pexpect interface focuses on ease of use so
+that simple tasks are easy.
+
+There are two main interfaces to the Pexpect system; these are the function,
+run() and the class, spawn. The spawn class is more powerful. The run()
+function is simpler than spawn, and is good for quickly calling a program. When
+you call the run() function it executes a given program and then returns the
+output. This is a handy replacement for os.system().
+
+For example::
+
+ pexpect.run('ls -la')
+
+The spawn class is the more powerful interface to the Pexpect system. You can
+use this to spawn a child program then interact with it by sending input and
+expecting responses (waiting for patterns in the child's output).
+
+For example::
+
+ child = pexpect.spawn('scp foo user@example.com:.')
+ child.expect('Password:')
+ child.sendline(mypassword)
+
+This works even for commands that ask for passwords or other input outside of
+the normal stdio streams. For example, ssh reads input directly from the TTY
+device which bypasses stdin.
+
+Credits: Noah Spurrier, Richard Holden, Marco Molteni, Kimberley Burchett,
+Robert Stone, Hartmut Goebel, Chad Schroeder, Erick Tryzelaar, Dave Kirby, Ids
+vander Molen, George Todd, Noel Taylor, Nicolas D. Cesar, Alexander Gattin,
+Jacques-Etienne Baudoux, Geoffrey Marshall, Francisco Lourenco, Glen Mabey,
+Karthik Gurusamy, Fernando Perez, Corey Minyard, Jon Cohen, Guillaume
+Chazarain, Andrew Ryan, Nick Craig-Wood, Andrew Stone, Jorgen Grahn, John
+Spiegel, Jan Grant, and Shane Kerr. Let me know if I forgot anyone.
+
+Pexpect is free, open source, and all that good stuff.
+http://pexpect.sourceforge.net/
+
+PEXPECT LICENSE
+
+ This license is approved by the OSI and FSF as GPL-compatible.
+ http://opensource.org/licenses/isc-license.txt
+
+ Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+ PERMISSION TO USE, COPY, MODIFY, AND/OR DISTRIBUTE THIS SOFTWARE FOR ANY
+ PURPOSE WITH OR WITHOUT FEE IS HEREBY GRANTED, PROVIDED THAT THE ABOVE
+ COPYRIGHT NOTICE AND THIS PERMISSION NOTICE APPEAR IN ALL COPIES.
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''
+
+import sys
+PY3 = (sys.version_info[0] >= 3)
+
+from .exceptions import ExceptionPexpect, EOF, TIMEOUT
+from .utils import split_command_line, which, is_executable_file
+from .expect import Expecter, searcher_re, searcher_string
+
+if sys.platform != 'win32':
+ # On Unix, these are available at the top level for backwards compatibility
+ from .pty_spawn import spawn, spawnu
+ from .run import run, runu
+
+__version__ = '4.8.0'
+__revision__ = ''
+__all__ = ['ExceptionPexpect', 'EOF', 'TIMEOUT', 'spawn', 'spawnu', 'run', 'runu',
+ 'which', 'split_command_line', '__version__', '__revision__']
+
+
+
+# vim: set shiftround expandtab tabstop=4 shiftwidth=4 ft=python autoindent :
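
To round out the docstring's run()/spawn snippets, a common pattern is to wait for EOF and read the collected output from child.before. An illustrative sketch (not lines from the patch):

    import pexpect

    child = pexpect.spawn('ls -la', encoding='utf-8')
    child.expect(pexpect.EOF)      # block until the program finishes
    print(child.before)            # everything the child printed before EOF
    child.close()
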
diff --git a/lib/pexpect/_async.py b/lib/pexpect/_async.py
new file mode 100644
index 0000000..dfbfeef
--- /dev/null
+++ b/lib/pexpect/_async.py
@@ -0,0 +1,103 @@
+import asyncio
+import errno
+import signal
+
+from pexpect import EOF
+
+@asyncio.coroutine
+def expect_async(expecter, timeout=None):
+ # First process data that was previously read - if it matches, we don't need
+ # async stuff.
+ idx = expecter.existing_data()
+ if idx is not None:
+ return idx
+ if not expecter.spawn.async_pw_transport:
+ pw = PatternWaiter()
+ pw.set_expecter(expecter)
+ transport, pw = yield from asyncio.get_event_loop()\
+ .connect_read_pipe(lambda: pw, expecter.spawn)
+ expecter.spawn.async_pw_transport = pw, transport
+ else:
+ pw, transport = expecter.spawn.async_pw_transport
+ pw.set_expecter(expecter)
+ transport.resume_reading()
+ try:
+ return (yield from asyncio.wait_for(pw.fut, timeout))
+ except asyncio.TimeoutError as e:
+ transport.pause_reading()
+ return expecter.timeout(e)
+
+@asyncio.coroutine
+def repl_run_command_async(repl, cmdlines, timeout=-1):
+ res = []
+ repl.child.sendline(cmdlines[0])
+ for line in cmdlines[1:]:
+ yield from repl._expect_prompt(timeout=timeout, async_=True)
+ res.append(repl.child.before)
+ repl.child.sendline(line)
+
+ # Command was fully submitted, now wait for the next prompt
+ prompt_idx = yield from repl._expect_prompt(timeout=timeout, async_=True)
+ if prompt_idx == 1:
+ # We got the continuation prompt - command was incomplete
+ repl.child.kill(signal.SIGINT)
+ yield from repl._expect_prompt(timeout=1, async_=True)
+ raise ValueError("Continuation prompt found - input was incomplete:")
+ return u''.join(res + [repl.child.before])
+
+class PatternWaiter(asyncio.Protocol):
+ transport = None
+
+ def set_expecter(self, expecter):
+ self.expecter = expecter
+ self.fut = asyncio.Future()
+
+ def found(self, result):
+ if not self.fut.done():
+ self.fut.set_result(result)
+ self.transport.pause_reading()
+
+ def error(self, exc):
+ if not self.fut.done():
+ self.fut.set_exception(exc)
+ self.transport.pause_reading()
+
+ def connection_made(self, transport):
+ self.transport = transport
+
+ def data_received(self, data):
+ spawn = self.expecter.spawn
+ s = spawn._decoder.decode(data)
+ spawn._log(s, 'read')
+
+ if self.fut.done():
+ spawn._before.write(s)
+ spawn._buffer.write(s)
+ return
+
+ try:
+ index = self.expecter.new_data(s)
+ if index is not None:
+ # Found a match
+ self.found(index)
+ except Exception as e:
+ self.expecter.errored()
+ self.error(e)
+
+ def eof_received(self):
+ # N.B. If this gets called, async will close the pipe (the spawn object)
+ # for us
+ try:
+ self.expecter.spawn.flag_eof = True
+ index = self.expecter.eof()
+ except EOF as e:
+ self.error(e)
+ else:
+ self.found(index)
+
+ def connection_lost(self, exc):
+ if isinstance(exc, OSError) and exc.errno == errno.EIO:
+ # We may get here without eof_received being called, e.g. on Linux
+ self.eof_received()
+ elif exc is not None:
+ self.error(exc)
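
expect_async() above is normally reached through spawn.expect(..., async_=True). An illustrative sketch (not lines from the patch), written against the modern asyncio.run() entry point:

    import asyncio
    import pexpect

    async def main():
        child = pexpect.spawn('cat', encoding='utf-8')
        child.sendline('ping')
        await child.expect('ping', async_=True)   # awaits expect_async under the hood
        child.sendeof()
        child.close()

    asyncio.run(main())
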
diff --git a/lib/pexpect/bashrc.sh b/lib/pexpect/bashrc.sh
new file mode 100644
index 0000000..c734ac9
--- /dev/null
+++ b/lib/pexpect/bashrc.sh
@@ -0,0 +1,16 @@
+# Different platforms have different names for the systemwide bashrc
+if [[ -f /etc/bashrc ]]; then
+ source /etc/bashrc
+fi
+if [[ -f /etc/bash.bashrc ]]; then
+ source /etc/bash.bashrc
+fi
+if [[ -f ~/.bashrc ]]; then
+ source ~/.bashrc
+fi
+
+# Reset PS1 so pexpect can find it
+PS1="$"
+
+# Unset PROMPT_COMMAND, so that it can't change PS1 to something unexpected.
+unset PROMPT_COMMAND
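This rc file is sourced by pexpect's replwrap bash helper so that PS1 ends up as a plain ``$`` the wrapper can rewrite and synchronize on. A small sketch, assuming the companion ``pexpect.replwrap`` module is installed alongside these files::

    from pexpect import replwrap

    # replwrap.bash() starts bash with the bashrc above, then replaces the
    # prompt with replwrap's own unique marker before handing control back.
    bash = replwrap.bash()
    print(bash.run_command('echo hello'))
    print(bash.run_command('pwd').strip())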
diff --git a/lib/pexpect/exceptions.py b/lib/pexpect/exceptions.py
new file mode 100644
index 0000000..cb360f0
--- /dev/null
+++ b/lib/pexpect/exceptions.py
@@ -0,0 +1,35 @@
+"""Exception classes used by Pexpect"""
+
+import traceback
+import sys
+
+class ExceptionPexpect(Exception):
+ '''Base class for all exceptions raised by this module.
+ '''
+
+ def __init__(self, value):
+ super(ExceptionPexpect, self).__init__(value)
+ self.value = value
+
+ def __str__(self):
+ return str(self.value)
+
+ def get_trace(self):
+ '''This returns an abbreviated stack trace with lines that only concern
+ the caller. In other words, the stack trace inside the Pexpect module
+ is not included. '''
+
+ tblist = traceback.extract_tb(sys.exc_info()[2])
+ tblist = [item for item in tblist if ('pexpect/__init__' not in item[0])
+ and ('pexpect/expect' not in item[0])]
+ tblist = traceback.format_list(tblist)
+ return ''.join(tblist)
+
+
+class EOF(ExceptionPexpect):
+ '''Raised when EOF is read from a child.
+ This usually means the child has exited.'''
+
+
+class TIMEOUT(ExceptionPexpect):
+ '''Raised when a read exceeds the timeout. '''
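Both exception classes are re-exported at package level, so application code normally catches them as ``pexpect.EOF`` and ``pexpect.TIMEOUT``. A minimal sketch, assuming a POSIX ``sleep`` binary::

    import pexpect

    child = pexpect.spawn('sleep 5', timeout=1)
    try:
        child.expect('never printed')       # sleep produces no output
    except pexpect.TIMEOUT:
        print('no match within the timeout')
    except pexpect.EOF:
        print('child exited before matching')
    finally:
        child.terminate(force=True)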
diff --git a/lib/pexpect/expect.py b/lib/pexpect/expect.py
new file mode 100644
index 0000000..d3409db
--- /dev/null
+++ b/lib/pexpect/expect.py
@@ -0,0 +1,371 @@
+import time
+
+from .exceptions import EOF, TIMEOUT
+
+class Expecter(object):
+ def __init__(self, spawn, searcher, searchwindowsize=-1):
+ self.spawn = spawn
+ self.searcher = searcher
+ # A value of -1 means to use the figure from spawn, which should
+ # be None or a positive number.
+ if searchwindowsize == -1:
+ searchwindowsize = spawn.searchwindowsize
+ self.searchwindowsize = searchwindowsize
+ self.lookback = None
+ if hasattr(searcher, 'longest_string'):
+ self.lookback = searcher.longest_string
+
+ def do_search(self, window, freshlen):
+ spawn = self.spawn
+ searcher = self.searcher
+ if freshlen > len(window):
+ freshlen = len(window)
+ index = searcher.search(window, freshlen, self.searchwindowsize)
+ if index >= 0:
+ spawn._buffer = spawn.buffer_type()
+ spawn._buffer.write(window[searcher.end:])
+ spawn.before = spawn._before.getvalue()[
+ 0:-(len(window) - searcher.start)]
+ spawn._before = spawn.buffer_type()
+ spawn._before.write(window[searcher.end:])
+ spawn.after = window[searcher.start:searcher.end]
+ spawn.match = searcher.match
+ spawn.match_index = index
+ # Found a match
+ return index
+ elif self.searchwindowsize or self.lookback:
+ maintain = self.searchwindowsize or self.lookback
+ if spawn._buffer.tell() > maintain:
+ spawn._buffer = spawn.buffer_type()
+ spawn._buffer.write(window[-maintain:])
+
+ def existing_data(self):
+ # First call from a new call to expect_loop or expect_async.
+ # self.searchwindowsize may have changed.
+ # Treat all data as fresh.
+ spawn = self.spawn
+ before_len = spawn._before.tell()
+ buf_len = spawn._buffer.tell()
+ freshlen = before_len
+ if before_len > buf_len:
+ if not self.searchwindowsize:
+ spawn._buffer = spawn.buffer_type()
+ window = spawn._before.getvalue()
+ spawn._buffer.write(window)
+ elif buf_len < self.searchwindowsize:
+ spawn._buffer = spawn.buffer_type()
+ spawn._before.seek(
+ max(0, before_len - self.searchwindowsize))
+ window = spawn._before.read()
+ spawn._buffer.write(window)
+ else:
+ spawn._buffer.seek(max(0, buf_len - self.searchwindowsize))
+ window = spawn._buffer.read()
+ else:
+ if self.searchwindowsize:
+ spawn._buffer.seek(max(0, buf_len - self.searchwindowsize))
+ window = spawn._buffer.read()
+ else:
+ window = spawn._buffer.getvalue()
+ return self.do_search(window, freshlen)
+
+ def new_data(self, data):
+ # A subsequent call, after a call to existing_data.
+ spawn = self.spawn
+ freshlen = len(data)
+ spawn._before.write(data)
+ if not self.searchwindowsize:
+ if self.lookback:
+ # search lookback + new data.
+ old_len = spawn._buffer.tell()
+ spawn._buffer.write(data)
+ spawn._buffer.seek(max(0, old_len - self.lookback))
+ window = spawn._buffer.read()
+ else:
+ # copy the whole buffer (really slow for large datasets).
+ spawn._buffer.write(data)
+ window = spawn.buffer
+ else:
+ if len(data) >= self.searchwindowsize or not spawn._buffer.tell():
+ window = data[-self.searchwindowsize:]
+ spawn._buffer = spawn.buffer_type()
+ spawn._buffer.write(window[-self.searchwindowsize:])
+ else:
+ spawn._buffer.write(data)
+ new_len = spawn._buffer.tell()
+ spawn._buffer.seek(max(0, new_len - self.searchwindowsize))
+ window = spawn._buffer.read()
+ return self.do_search(window, freshlen)
+
+ def eof(self, err=None):
+ spawn = self.spawn
+
+ spawn.before = spawn._before.getvalue()
+ spawn._buffer = spawn.buffer_type()
+ spawn._before = spawn.buffer_type()
+ spawn.after = EOF
+ index = self.searcher.eof_index
+ if index >= 0:
+ spawn.match = EOF
+ spawn.match_index = index
+ return index
+ else:
+ spawn.match = None
+ spawn.match_index = None
+ msg = str(spawn)
+ msg += '\nsearcher: %s' % self.searcher
+ if err is not None:
+ msg = str(err) + '\n' + msg
+
+ exc = EOF(msg)
+ exc.__cause__ = None # in Python 3.x we can use "raise exc from None"
+ raise exc
+
+ def timeout(self, err=None):
+ spawn = self.spawn
+
+ spawn.before = spawn._before.getvalue()
+ spawn.after = TIMEOUT
+ index = self.searcher.timeout_index
+ if index >= 0:
+ spawn.match = TIMEOUT
+ spawn.match_index = index
+ return index
+ else:
+ spawn.match = None
+ spawn.match_index = None
+ msg = str(spawn)
+ msg += '\nsearcher: %s' % self.searcher
+ if err is not None:
+ msg = str(err) + '\n' + msg
+
+ exc = TIMEOUT(msg)
+ exc.__cause__ = None # in Python 3.x we can use "raise exc from None"
+ raise exc
+
+ def errored(self):
+ spawn = self.spawn
+ spawn.before = spawn._before.getvalue()
+ spawn.after = None
+ spawn.match = None
+ spawn.match_index = None
+
+ def expect_loop(self, timeout=-1):
+ """Blocking expect"""
+ spawn = self.spawn
+
+ if timeout is not None:
+ end_time = time.time() + timeout
+
+ try:
+ idx = self.existing_data()
+ if idx is not None:
+ return idx
+ while True:
+ # No match at this point
+ if (timeout is not None) and (timeout < 0):
+ return self.timeout()
+ # Still have time left, so read more data
+ incoming = spawn.read_nonblocking(spawn.maxread, timeout)
+ if self.spawn.delayafterread is not None:
+ time.sleep(self.spawn.delayafterread)
+ idx = self.new_data(incoming)
+ # Keep reading until exception or return.
+ if idx is not None:
+ return idx
+ if timeout is not None:
+ timeout = end_time - time.time()
+ except EOF as e:
+ return self.eof(e)
+ except TIMEOUT as e:
+ return self.timeout(e)
+ except:
+ self.errored()
+ raise
+
+
+class searcher_string(object):
+ '''This is a plain string search helper for the spawn.expect_any() method.
+ This helper class is for speed. For more powerful regex patterns
+ see the helper class, searcher_re.
+
+ Attributes:
+
+ eof_index - index of EOF, or -1
+ timeout_index - index of TIMEOUT, or -1
+
+ After a successful match by the search() method the following attributes
+ are available:
+
+ start - index into the buffer, first byte of match
+ end - index into the buffer, first byte after match
+ match - the matching string itself
+
+ '''
+
+ def __init__(self, strings):
+ '''This creates an instance of searcher_string. The argument 'strings'
+ may be a list or other sequence of strings, or the EOF or TIMEOUT types. '''
+
+ self.eof_index = -1
+ self.timeout_index = -1
+ self._strings = []
+ self.longest_string = 0
+ for n, s in enumerate(strings):
+ if s is EOF:
+ self.eof_index = n
+ continue
+ if s is TIMEOUT:
+ self.timeout_index = n
+ continue
+ self._strings.append((n, s))
+ if len(s) > self.longest_string:
+ self.longest_string = len(s)
+
+ def __str__(self):
+ '''This returns a human-readable string that represents the state of
+ the object.'''
+
+ ss = [(ns[0], ' %d: %r' % ns) for ns in self._strings]
+ ss.append((-1, 'searcher_string:'))
+ if self.eof_index >= 0:
+ ss.append((self.eof_index, ' %d: EOF' % self.eof_index))
+ if self.timeout_index >= 0:
+ ss.append((self.timeout_index,
+ ' %d: TIMEOUT' % self.timeout_index))
+ ss.sort()
+ ss = list(zip(*ss))[1]
+ return '\n'.join(ss)
+
+ def search(self, buffer, freshlen, searchwindowsize=None):
+ '''This searches 'buffer' for the first occurrence of one of the search
+ strings. 'freshlen' must indicate the number of bytes at the end of
+ 'buffer' which have not been searched before. It helps to avoid
+ searching the same, possibly big, buffer over and over again.
+
+ See class spawn for the 'searchwindowsize' argument.
+
+ If there is a match this returns the index of that string, and sets
+ 'start', 'end' and 'match'. Otherwise, this returns -1. '''
+
+ first_match = None
+
+ # 'freshlen' helps a lot here. Further optimizations could
+ # possibly include:
+ #
+ # using something like the Boyer-Moore Fast String Searching
+ # Algorithm; pre-compiling the search through a list of
+ # strings into something that can scan the input once to
+ # search for all N strings; realize that if we search for
+ # ['bar', 'baz'] and the input is '...foo' we need not bother
+ # rescanning until we've read three more bytes.
+ #
+ # Sadly, I don't know enough about this interesting topic. /grahn
+
+ for index, s in self._strings:
+ if searchwindowsize is None:
+ # the match, if any, can only be in the fresh data,
+ # or at the very end of the old data
+ offset = -(freshlen + len(s))
+ else:
+ # better obey searchwindowsize
+ offset = -searchwindowsize
+ n = buffer.find(s, offset)
+ if n >= 0 and (first_match is None or n < first_match):
+ first_match = n
+ best_index, best_match = index, s
+ if first_match is None:
+ return -1
+ self.match = best_match
+ self.start = first_match
+ self.end = self.start + len(self.match)
+ return best_index
+
+
+class searcher_re(object):
+ '''This is a regular expression search helper for the
+ spawn.expect_any() method. This helper class is for powerful
+ pattern matching. For speed, see the helper class, searcher_string.
+
+ Attributes:
+
+ eof_index - index of EOF, or -1
+ timeout_index - index of TIMEOUT, or -1
+
+ After a successful match by the search() method the following attributes
+ are available:
+
+ start - index into the buffer, first byte of match
+ end - index into the buffer, first byte after match
+ match - the re.match object returned by a successful re.search
+
+ '''
+
+ def __init__(self, patterns):
+ '''This creates an instance that searches for 'patterns', where
+ 'patterns' may be a list or other sequence of compiled regular
+ expressions, or the EOF or TIMEOUT types.'''
+
+ self.eof_index = -1
+ self.timeout_index = -1
+ self._searches = []
+ for n, s in enumerate(patterns):
+ if s is EOF:
+ self.eof_index = n
+ continue
+ if s is TIMEOUT:
+ self.timeout_index = n
+ continue
+ self._searches.append((n, s))
+
+ def __str__(self):
+ '''This returns a human-readable string that represents the state of
+ the object.'''
+
+ #ss = [(n, ' %d: re.compile("%s")' %
+ # (n, repr(s.pattern))) for n, s in self._searches]
+ ss = list()
+ for n, s in self._searches:
+ ss.append((n, ' %d: re.compile(%r)' % (n, s.pattern)))
+ ss.append((-1, 'searcher_re:'))
+ if self.eof_index >= 0:
+ ss.append((self.eof_index, ' %d: EOF' % self.eof_index))
+ if self.timeout_index >= 0:
+ ss.append((self.timeout_index, ' %d: TIMEOUT' %
+ self.timeout_index))
+ ss.sort()
+ ss = list(zip(*ss))[1]
+ return '\n'.join(ss)
+
+ def search(self, buffer, freshlen, searchwindowsize=None):
+ '''This searches 'buffer' for the first occurrence of one of the regular
+ expressions. 'freshlen' must indicate the number of bytes at the end of
+ 'buffer' which have not been searched before.
+
+ See class spawn for the 'searchwindowsize' argument.
+
+ If there is a match this returns the index of that string, and sets
+ 'start', 'end' and 'match'. Otherwise, returns -1.'''
+
+ first_match = None
+ # 'freshlen' doesn't help here -- we cannot predict the
+ # length of a match, and the re module provides no help.
+ if searchwindowsize is None:
+ searchstart = 0
+ else:
+ searchstart = max(0, len(buffer) - searchwindowsize)
+ for index, s in self._searches:
+ match = s.search(buffer, searchstart)
+ if match is None:
+ continue
+ n = match.start()
+ if first_match is None or n < first_match:
+ first_match = n
+ the_match = match
+ best_index = index
+ if first_match is None:
+ return -1
+ self.start = first_match
+ self.match = the_match
+ self.end = self.match.end()
+ return best_index
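Application code rarely touches Expecter or the searcher classes directly: ``spawn.expect()`` builds a searcher_re from its pattern list and ``spawn.expect_exact()`` builds a searcher_string, and eof_index/timeout_index are what let EOF and TIMEOUT appear in that list instead of being raised. A short sketch of the pattern-list style, assuming ``echo`` is on the PATH::

    import pexpect

    child = pexpect.spawn('echo done', encoding='utf-8', timeout=2)
    # EOF and TIMEOUT map to eof_index/timeout_index in the searcher,
    # so expect() returns their list index instead of raising.
    index = child.expect(['done', pexpect.EOF, pexpect.TIMEOUT])
    if index == 0:
        print('matched:', child.after)
    elif index == 1:
        print('reached EOF first')
    else:
        print('timed out')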
diff --git a/lib/pexpect/fdpexpect.py b/lib/pexpect/fdpexpect.py
new file mode 100644
index 0000000..cddd50e
--- /dev/null
+++ b/lib/pexpect/fdpexpect.py
@@ -0,0 +1,148 @@
+'''This is like pexpect, but it will work with any file descriptor that you
+pass it. You are responsible for opening and closing the file descriptor.
+This allows you to use Pexpect with sockets and named pipes (FIFOs).
+
+PEXPECT LICENSE
+
+ This license is approved by the OSI and FSF as GPL-compatible.
+ http://opensource.org/licenses/isc-license.txt
+
+ Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+ PERMISSION TO USE, COPY, MODIFY, AND/OR DISTRIBUTE THIS SOFTWARE FOR ANY
+ PURPOSE WITH OR WITHOUT FEE IS HEREBY GRANTED, PROVIDED THAT THE ABOVE
+ COPYRIGHT NOTICE AND THIS PERMISSION NOTICE APPEAR IN ALL COPIES.
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''
+
+from .spawnbase import SpawnBase
+from .exceptions import ExceptionPexpect, TIMEOUT
+from .utils import select_ignore_interrupts, poll_ignore_interrupts
+import os
+
+__all__ = ['fdspawn']
+
+class fdspawn(SpawnBase):
+ '''This is like pexpect.spawn but allows you to supply your own open file
+ descriptor. For example, you could use it to read through a file looking
+ for patterns, or to control a modem or serial device. '''
+
+ def __init__ (self, fd, args=None, timeout=30, maxread=2000, searchwindowsize=None,
+ logfile=None, encoding=None, codec_errors='strict', use_poll=False):
+ '''This takes a file descriptor (an int) or an object that supports the
+ fileno() method (returning an int). All Python file-like objects
+ support fileno(). '''
+
+ if type(fd) != type(0) and hasattr(fd, 'fileno'):
+ fd = fd.fileno()
+
+ if type(fd) != type(0):
+ raise ExceptionPexpect('The fd argument is not an int. If this is a command string then maybe you want to use pexpect.spawn.')
+
+ try: # make sure fd is a valid file descriptor
+ os.fstat(fd)
+ except OSError:
+ raise ExceptionPexpect('The fd argument is not a valid file descriptor.')
+
+ self.args = None
+ self.command = None
+ SpawnBase.__init__(self, timeout, maxread, searchwindowsize, logfile,
+ encoding=encoding, codec_errors=codec_errors)
+ self.child_fd = fd
+ self.own_fd = False
+ self.closed = False
+ self.name = '<file descriptor %d>' % fd
+ self.use_poll = use_poll
+
+ def close (self):
+ """Close the file descriptor.
+
+ Calling this method a second time does nothing, but if the file
+ descriptor was closed elsewhere, :class:`OSError` will be raised.
+ """
+ if self.child_fd == -1:
+ return
+
+ self.flush()
+ os.close(self.child_fd)
+ self.child_fd = -1
+ self.closed = True
+
+ def isalive (self):
+ '''This checks if the file descriptor is still valid. If :func:`os.fstat`
+ does not raise an exception then we assume it is alive. '''
+
+ if self.child_fd == -1:
+ return False
+ try:
+ os.fstat(self.child_fd)
+ return True
+ except:
+ return False
+
+ def terminate (self, force=False): # pragma: no cover
+ '''Deprecated and invalid. Just raises an exception.'''
+ raise ExceptionPexpect('This method is not valid for file descriptors.')
+
+ # These four methods are left around for backwards compatibility, but not
+ # documented as part of fdpexpect. You're encouraged to use os.write
+ # directly.
+ def send(self, s):
+ "Write to fd, return number of bytes written"
+ s = self._coerce_send_string(s)
+ self._log(s, 'send')
+
+ b = self._encoder.encode(s, final=False)
+ return os.write(self.child_fd, b)
+
+ def sendline(self, s):
+ "Write to fd with trailing newline, return number of bytes written"
+ s = self._coerce_send_string(s)
+ return self.send(s + self.linesep)
+
+ def write(self, s):
+ "Write to fd, return None"
+ self.send(s)
+
+ def writelines(self, sequence):
+ "Call self.write() for each item in sequence"
+ for s in sequence:
+ self.write(s)
+
+ def read_nonblocking(self, size=1, timeout=-1):
+ """
+ Read from the file descriptor and return the result as a string.
+
+ The read_nonblocking method of :class:`SpawnBase` assumes that a call
+ to os.read will not block (timeout parameter is ignored). This is not
+ the case for POSIX file-like objects such as sockets and serial ports.
+
+ Here :func:`select.select` (or :func:`select.poll`) is used, so the
+ timeout is honoured on POSIX systems.
+
+ :param int size: Read at most *size* bytes.
+ :param int timeout: Wait timeout seconds for file descriptor to be
+ ready to read. When -1 (default), use self.timeout. When 0, poll.
+ :return: String containing the bytes read
+ """
+ if os.name == 'posix':
+ if timeout == -1:
+ timeout = self.timeout
+ rlist = [self.child_fd]
+ wlist = []
+ xlist = []
+ if self.use_poll:
+ rlist = poll_ignore_interrupts(rlist, timeout)
+ else:
+ rlist, wlist, xlist = select_ignore_interrupts(
+ rlist, wlist, xlist, timeout
+ )
+ if self.child_fd not in rlist:
+ raise TIMEOUT('Timeout exceeded.')
+ return super(fdspawn, self).read_nonblocking(size)
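A short fdspawn sketch; the path is only illustrative (it assumes a readable ``/etc/hosts``), and per the docstring above the caller keeps ownership of the descriptor::

    import os
    import pexpect
    from pexpect import fdpexpect

    fd = os.open('/etc/hosts', os.O_RDONLY)
    try:
        session = fdpexpect.fdspawn(fd, encoding='utf-8', timeout=5)
        index = session.expect(['localhost', pexpect.EOF, pexpect.TIMEOUT])
        print('found localhost' if index == 0 else 'no match')
    finally:
        os.close(fd)   # fdspawn does not close descriptors it did not open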
diff --git a/lib/pexpect/popen_spawn.py b/lib/pexpect/popen_spawn.py
new file mode 100644
index 0000000..4bb58cf
--- /dev/null
+++ b/lib/pexpect/popen_spawn.py
@@ -0,0 +1,188 @@
+"""Provides an interface like pexpect.spawn interface using subprocess.Popen
+"""
+import os
+import threading
+import subprocess
+import sys
+import time
+import signal
+import shlex
+
+try:
+ from queue import Queue, Empty # Python 3
+except ImportError:
+ from Queue import Queue, Empty # Python 2
+
+from .spawnbase import SpawnBase, PY3
+from .exceptions import EOF
+from .utils import string_types
+
+class PopenSpawn(SpawnBase):
+ def __init__(self, cmd, timeout=30, maxread=2000, searchwindowsize=None,
+ logfile=None, cwd=None, env=None, encoding=None,
+ codec_errors='strict', preexec_fn=None):
+ super(PopenSpawn, self).__init__(timeout=timeout, maxread=maxread,
+ searchwindowsize=searchwindowsize, logfile=logfile,
+ encoding=encoding, codec_errors=codec_errors)
+
+ # Note that `SpawnBase` initializes `self.crlf` to `\r\n`
+ # because the default behaviour for a PTY is to convert
+ # incoming LF to `\r\n` (see the `onlcr` flag and
+ # https://stackoverflow.com/a/35887657/5397009). Here we set
+ # it to `os.linesep` because that is what the spawned
+ # application outputs by default and `popen` doesn't translate
+ # anything.
+ if encoding is None:
+ self.crlf = os.linesep.encode ("ascii")
+ else:
+ self.crlf = self.string_type (os.linesep)
+
+ kwargs = dict(bufsize=0, stdin=subprocess.PIPE,
+ stderr=subprocess.STDOUT, stdout=subprocess.PIPE,
+ cwd=cwd, preexec_fn=preexec_fn, env=env)
+
+ if sys.platform == 'win32':
+ startupinfo = subprocess.STARTUPINFO()
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+ kwargs['startupinfo'] = startupinfo
+ kwargs['creationflags'] = subprocess.CREATE_NEW_PROCESS_GROUP
+
+ if isinstance(cmd, string_types) and sys.platform != 'win32':
+ cmd = shlex.split(cmd, posix=os.name == 'posix')
+
+ self.proc = subprocess.Popen(cmd, **kwargs)
+ self.pid = self.proc.pid
+ self.closed = False
+ self._buf = self.string_type()
+
+ self._read_queue = Queue()
+ self._read_thread = threading.Thread(target=self._read_incoming)
+ self._read_thread.setDaemon(True)
+ self._read_thread.start()
+
+ _read_reached_eof = False
+
+ def read_nonblocking(self, size, timeout):
+ buf = self._buf
+ if self._read_reached_eof:
+ # We have already finished reading. Use up any buffered data,
+ # then raise EOF
+ if buf:
+ self._buf = buf[size:]
+ return buf[:size]
+ else:
+ self.flag_eof = True
+ raise EOF('End Of File (EOF).')
+
+ if timeout == -1:
+ timeout = self.timeout
+ elif timeout is None:
+ timeout = 1e6
+
+ t0 = time.time()
+ while (time.time() - t0) < timeout and size and len(buf) < size:
+ try:
+ incoming = self._read_queue.get_nowait()
+ except Empty:
+ break
+ else:
+ if incoming is None:
+ self._read_reached_eof = True
+ break
+
+ buf += self._decoder.decode(incoming, final=False)
+
+ r, self._buf = buf[:size], buf[size:]
+
+ self._log(r, 'read')
+ return r
+
+ def _read_incoming(self):
+ """Run in a thread to move output from a pipe to a queue."""
+ fileno = self.proc.stdout.fileno()
+ while 1:
+ buf = b''
+ try:
+ buf = os.read(fileno, 1024)
+ except OSError as e:
+ self._log(e, 'read')
+
+ if not buf:
+ # This indicates we have reached EOF
+ self._read_queue.put(None)
+ return
+
+ self._read_queue.put(buf)
+
+ def write(self, s):
+ '''This is similar to send() except that there is no return value.
+ '''
+ self.send(s)
+
+ def writelines(self, sequence):
+ '''This calls write() for each element in the sequence.
+
+ The sequence can be any iterable object producing strings, typically a
+ list of strings. This does not add line separators. There is no return
+ value.
+ '''
+ for s in sequence:
+ self.send(s)
+
+ def send(self, s):
+ '''Send data to the subprocess' stdin.
+
+ Returns the number of bytes written.
+ '''
+ s = self._coerce_send_string(s)
+ self._log(s, 'send')
+
+ b = self._encoder.encode(s, final=False)
+ if PY3:
+ return self.proc.stdin.write(b)
+ else:
+ # On Python 2, .write() returns None, so we return the length of
+ # bytes written ourselves. This assumes they all got written.
+ self.proc.stdin.write(b)
+ return len(b)
+
+ def sendline(self, s=''):
+ '''Wraps send(), sending string ``s`` to child process, with os.linesep
+ automatically appended. Returns number of bytes written. '''
+
+ n = self.send(s)
+ return n + self.send(self.linesep)
+
+ def wait(self):
+ '''Wait for the subprocess to finish.
+
+ Returns the exit code.
+ '''
+ status = self.proc.wait()
+ if status >= 0:
+ self.exitstatus = status
+ self.signalstatus = None
+ else:
+ self.exitstatus = None
+ self.signalstatus = -status
+ self.terminated = True
+ return status
+
+ def kill(self, sig):
+ '''Sends a Unix signal to the subprocess.
+
+ Use constants from the :mod:`signal` module to specify which signal.
+ '''
+ if sys.platform == 'win32':
+ if sig in [signal.SIGINT, signal.CTRL_C_EVENT]:
+ sig = signal.CTRL_C_EVENT
+ elif sig in [signal.SIGBREAK, signal.CTRL_BREAK_EVENT]:
+ sig = signal.CTRL_BREAK_EVENT
+ else:
+ sig = signal.SIGTERM
+
+ os.kill(self.proc.pid, sig)
+
+ def sendeof(self):
+ '''Closes the stdin pipe from the writing end.'''
+ self.proc.stdin.close()
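PopenSpawn talks to the child over ordinary pipes rather than a pty, so programs that only prompt when attached to a terminal behave differently, but it also works where pty support is missing. A minimal sketch, assuming an ``echo`` executable on the PATH::

    import pexpect
    from pexpect.popen_spawn import PopenSpawn

    child = PopenSpawn('echo hello', encoding='utf-8', timeout=5)
    child.expect(pexpect.EOF)
    # No pty means no LF -> CRLF translation; crlf is os.linesep here.
    print(repr(child.before))   # "hello\n" on POSIX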
diff --git a/lib/pexpect/pty_spawn.py b/lib/pexpect/pty_spawn.py
new file mode 100644
index 0000000..8e28ca7
--- /dev/null
+++ b/lib/pexpect/pty_spawn.py
@@ -0,0 +1,860 @@
+import os
+import sys
+import time
+import pty
+import tty
+import errno
+import signal
+from contextlib import contextmanager
+
+import ptyprocess
+from ptyprocess.ptyprocess import use_native_pty_fork
+
+from .exceptions import ExceptionPexpect, EOF, TIMEOUT
+from .spawnbase import SpawnBase
+from .utils import (
+ which, split_command_line, select_ignore_interrupts, poll_ignore_interrupts
+)
+
+@contextmanager
+def _wrap_ptyprocess_err():
+ """Turn ptyprocess errors into our own ExceptionPexpect errors"""
+ try:
+ yield
+ except ptyprocess.PtyProcessError as e:
+ raise ExceptionPexpect(*e.args)
+
+PY3 = (sys.version_info[0] >= 3)
+
+class spawn(SpawnBase):
+ '''This is the main class interface for Pexpect. Use this class to start
+ and control child applications. '''
+
+ # This is purely informational now - changing it has no effect
+ use_native_pty_fork = use_native_pty_fork
+
+ def __init__(self, command, args=[], timeout=30, maxread=2000,
+ searchwindowsize=None, logfile=None, cwd=None, env=None,
+ ignore_sighup=False, echo=True, preexec_fn=None,
+ encoding=None, codec_errors='strict', dimensions=None,
+ use_poll=False):
+ '''This is the constructor. The command parameter may be a string that
+ includes a command and any arguments to the command. For example::
+
+ child = pexpect.spawn('/usr/bin/ftp')
+ child = pexpect.spawn('/usr/bin/ssh user@example.com')
+ child = pexpect.spawn('ls -latr /tmp')
+
+ You may also construct it with a list of arguments like so::
+
+ child = pexpect.spawn('/usr/bin/ftp', [])
+ child = pexpect.spawn('/usr/bin/ssh', ['user@example.com'])
+ child = pexpect.spawn('ls', ['-latr', '/tmp'])
+
+ After this the child application will be created and will be ready to
+ talk to. For normal use, see expect() and send() and sendline().
+
+ Remember that Pexpect does NOT interpret shell meta characters such as
+ redirect, pipe, or wild cards (``>``, ``|``, or ``*``). This is a
+ common mistake. If you want to run a command and pipe it through
+ another command then you must also start a shell. For example::
+
+ child = pexpect.spawn('/bin/bash -c "ls -l | grep LOG > logs.txt"')
+ child.expect(pexpect.EOF)
+
+ The second form of spawn (where you pass a list of arguments) is useful
+ in situations where you wish to spawn a command and pass it its own
+ argument list. This can make syntax more clear. For example, the
+ following is equivalent to the previous example::
+
+ shell_cmd = 'ls -l | grep LOG > logs.txt'
+ child = pexpect.spawn('/bin/bash', ['-c', shell_cmd])
+ child.expect(pexpect.EOF)
+
+ The maxread attribute sets the read buffer size. This is the maximum number
+ of bytes that Pexpect will try to read from a TTY at one time. Setting
+ the maxread size to 1 will turn off buffering. Setting the maxread
+ value higher may help performance in cases where large amounts of
+ output are read back from the child. This feature is useful in
+ conjunction with searchwindowsize.
+
+ When the keyword argument *searchwindowsize* is None (default), the
+ full buffer is searched at each iteration of receiving incoming data.
+ The default number of bytes scanned at each iteration is very large
+ and may be reduced to collaterally reduce search cost. After
+ :meth:`~.expect` returns, the full buffer attribute remains up to
+ size *maxread* irrespective of *searchwindowsize* value.
+
+ When the keyword argument ``timeout`` is specified as a number,
+ (default: *30*), then :class:`TIMEOUT` will be raised after the value
+ specified has elapsed, in seconds, for any of the :meth:`~.expect`
+ family of method calls. When None, TIMEOUT will not be raised, and
+ :meth:`~.expect` may block indefinitely until match.
+
+
+ The logfile member turns on or off logging. All input and output will
+ be copied to the given file object. Set logfile to None to stop
+ logging. This is the default. Set logfile to sys.stdout to echo
+ everything to standard output. The logfile is flushed after each write.
+
+ Example log input and output to a file::
+
+ child = pexpect.spawn('some_command')
+ fout = open('mylog.txt','wb')
+ child.logfile = fout
+
+ Example log to stdout::
+
+ # In Python 2:
+ child = pexpect.spawn('some_command')
+ child.logfile = sys.stdout
+
+ # In Python 3, we'll use the ``encoding`` argument to decode data
+ # from the subprocess and handle it as unicode:
+ child = pexpect.spawn('some_command', encoding='utf-8')
+ child.logfile = sys.stdout
+
+ The logfile_read and logfile_send members can be used to separately log
+ the input from the child and output sent to the child. Sometimes you
+ don't want to see everything you write to the child. You only want to
+ log what the child sends back. For example::
+
+ child = pexpect.spawn('some_command')
+ child.logfile_read = sys.stdout
+
+ You will need to pass an encoding to spawn in the above code if you are
+ using Python 3.
+
+ To separately log output sent to the child use logfile_send::
+
+ child.logfile_send = fout
+
+ If ``ignore_sighup`` is True, the child process will ignore SIGHUP
+ signals. The default is False from Pexpect 4.0, meaning that SIGHUP
+ will be handled normally by the child.
+
+ The delaybeforesend helps overcome a weird behavior that many users
+ were experiencing. The typical problem was that a user would expect() a
+ "Password:" prompt and then immediately call sendline() to send the
+ password. The user would then see that their password was echoed back
+ to them. Passwords don't normally echo. The problem is caused by the
+ fact that most applications print out the "Password" prompt and then
+ turn off stdin echo, but if you send your password before the
+ application turned off echo, then you get your password echoed.
+ Normally this wouldn't be a problem when interacting with a human at a
+ real keyboard. If you introduce a slight delay just before writing then
+ this seems to clear up the problem. This was such a common problem for
+ many users that I decided that the default pexpect behavior should be
+ to sleep just before writing to the child application. 1/20th of a
+ second (50 ms) seems to be enough to clear up the problem. You can set
+ delaybeforesend to None to return to the old behavior.
+
+ Note that spawn is clever about finding commands on your path.
+ It uses the same logic that "which" uses to find executables.
+
+ If you wish to get the exit status of the child you must call the
+ close() method. The exit or signal status of the child will be stored
+ in self.exitstatus or self.signalstatus. If the child exited normally
+ then exitstatus will store the exit return code and signalstatus will
+ be None. If the child was terminated abnormally with a signal then
+ signalstatus will store the signal value and exitstatus will be None::
+
+ child = pexpect.spawn('some_command')
+ child.close()
+ print(child.exitstatus, child.signalstatus)
+
+ If you need more detail you can also read the self.status member which
+ stores the status returned by os.waitpid. You can interpret this using
+ os.WIFEXITED/os.WEXITSTATUS or os.WIFSIGNALED/os.WTERMSIG.
+
+ The echo attribute may be set to False to disable echoing of input.
+ As a pseudo-terminal, all input echoed by the "keyboard" (send()
+ or sendline()) will be repeated to output. For many cases, it is
+ not desirable to have echo enabled, and it may be later disabled
+ using setecho(False) followed by waitnoecho(). However, for some
+ platforms such as Solaris, this is not possible, and should be
+ disabled immediately on spawn.
+
+ If preexec_fn is given, it will be called in the child process before
+ launching the given command. This is useful to e.g. reset inherited
+ signal handlers.
+
+ The dimensions attribute specifies the size of the pseudo-terminal as
+ seen by the subprocess, and is specified as a two-entry tuple (rows,
+ columns). If this is unspecified, the defaults in ptyprocess will apply.
+
+ The use_poll attribute enables using select.poll() over select.select()
+ for socket handling. This is handy if your system could have > 1024 fds
+ '''
+ super(spawn, self).__init__(timeout=timeout, maxread=maxread, searchwindowsize=searchwindowsize,
+ logfile=logfile, encoding=encoding, codec_errors=codec_errors)
+ self.STDIN_FILENO = pty.STDIN_FILENO
+ self.STDOUT_FILENO = pty.STDOUT_FILENO
+ self.STDERR_FILENO = pty.STDERR_FILENO
+ self.str_last_chars = 100
+ self.cwd = cwd
+ self.env = env
+ self.echo = echo
+ self.ignore_sighup = ignore_sighup
+ self.__irix_hack = sys.platform.lower().startswith('irix')
+ if command is None:
+ self.command = None
+ self.args = None
+ self.name = '<pexpect factory incomplete>'
+ else:
+ self._spawn(command, args, preexec_fn, dimensions)
+ self.use_poll = use_poll
+
+ def __str__(self):
+ '''This returns a human-readable string that represents the state of
+ the object. '''
+
+ s = []
+ s.append(repr(self))
+ s.append('command: ' + str(self.command))
+ s.append('args: %r' % (self.args,))
+ s.append('buffer (last %s chars): %r' % (self.str_last_chars,self.buffer[-self.str_last_chars:]))
+ s.append('before (last %s chars): %r' % (self.str_last_chars,self.before[-self.str_last_chars:] if self.before else ''))
+ s.append('after: %r' % (self.after,))
+ s.append('match: %r' % (self.match,))
+ s.append('match_index: ' + str(self.match_index))
+ s.append('exitstatus: ' + str(self.exitstatus))
+ if hasattr(self, 'ptyproc'):
+ s.append('flag_eof: ' + str(self.flag_eof))
+ s.append('pid: ' + str(self.pid))
+ s.append('child_fd: ' + str(self.child_fd))
+ s.append('closed: ' + str(self.closed))
+ s.append('timeout: ' + str(self.timeout))
+ s.append('delimiter: ' + str(self.delimiter))
+ s.append('logfile: ' + str(self.logfile))
+ s.append('logfile_read: ' + str(self.logfile_read))
+ s.append('logfile_send: ' + str(self.logfile_send))
+ s.append('maxread: ' + str(self.maxread))
+ s.append('ignorecase: ' + str(self.ignorecase))
+ s.append('searchwindowsize: ' + str(self.searchwindowsize))
+ s.append('delaybeforesend: ' + str(self.delaybeforesend))
+ s.append('delayafterclose: ' + str(self.delayafterclose))
+ s.append('delayafterterminate: ' + str(self.delayafterterminate))
+ return '\n'.join(s)
+
+ def _spawn(self, command, args=[], preexec_fn=None, dimensions=None):
+ '''This starts the given command in a child process. This does all the
+ fork/exec type of stuff for a pty. This is called by __init__. If args
+ is empty then command will be parsed (split on spaces) and args will be
+ set to parsed arguments. '''
+
+ # The pid and child_fd of this object get set by this method.
+ # Note that it is difficult for this method to fail.
+ # You cannot detect if the child process cannot start.
+ # So the only way you can tell if the child process started
+ # or not is to try to read from the file descriptor. If you get
+ # EOF immediately then it means that the child is already dead.
+ # That may not necessarily be bad because you may have spawned a child
+ # that performs some task; creates no stdout output; and then dies.
+
+ # If command is an int type then it may represent a file descriptor.
+ if isinstance(command, type(0)):
+ raise ExceptionPexpect('Command is an int type. ' +
+ 'If this is a file descriptor then maybe you want to ' +
+ 'use fdpexpect.fdspawn which takes an existing ' +
+ 'file descriptor instead of a command string.')
+
+ if not isinstance(args, type([])):
+ raise TypeError('The argument, args, must be a list.')
+
+ if args == []:
+ self.args = split_command_line(command)
+ self.command = self.args[0]
+ else:
+ # Make a shallow copy of the args list.
+ self.args = args[:]
+ self.args.insert(0, command)
+ self.command = command
+
+ command_with_path = which(self.command, env=self.env)
+ if command_with_path is None:
+ raise ExceptionPexpect('The command was not found or was not ' +
+ 'executable: %s.' % self.command)
+ self.command = command_with_path
+ self.args[0] = self.command
+
+ self.name = '<' + ' '.join(self.args) + '>'
+
+ assert self.pid is None, 'The pid member must be None.'
+ assert self.command is not None, 'The command member must not be None.'
+
+ kwargs = {'echo': self.echo, 'preexec_fn': preexec_fn}
+ if self.ignore_sighup:
+ def preexec_wrapper():
+ "Set SIGHUP to be ignored, then call the real preexec_fn"
+ signal.signal(signal.SIGHUP, signal.SIG_IGN)
+ if preexec_fn is not None:
+ preexec_fn()
+ kwargs['preexec_fn'] = preexec_wrapper
+
+ if dimensions is not None:
+ kwargs['dimensions'] = dimensions
+
+ if self.encoding is not None:
+ # Encode command line using the specified encoding
+ self.args = [a if isinstance(a, bytes) else a.encode(self.encoding)
+ for a in self.args]
+
+ self.ptyproc = self._spawnpty(self.args, env=self.env,
+ cwd=self.cwd, **kwargs)
+
+ self.pid = self.ptyproc.pid
+ self.child_fd = self.ptyproc.fd
+
+
+ self.terminated = False
+ self.closed = False
+
+ def _spawnpty(self, args, **kwargs):
+ '''Spawn a pty and return an instance of PtyProcess.'''
+ return ptyprocess.PtyProcess.spawn(args, **kwargs)
+
+ def close(self, force=True):
+ '''This closes the connection with the child application. Note that
+ calling close() more than once is valid. This emulates standard Python
+ behavior with files. Set force to True if you want to make sure that
+ the child is terminated (SIGKILL is sent if the child ignores SIGHUP
+ and SIGINT). '''
+
+ self.flush()
+ with _wrap_ptyprocess_err():
+ # PtyProcessError may be raised if it is not possible to terminate
+ # the child.
+ self.ptyproc.close(force=force)
+ self.isalive() # Update exit status from ptyproc
+ self.child_fd = -1
+ self.closed = True
+
+ def isatty(self):
+ '''This returns True if the file descriptor is open and connected to a
+ tty(-like) device, else False.
+
+ On SVR4-style platforms implementing streams, such as SunOS and HP-UX,
+ the child pty may not appear as a terminal device. This means
+ methods such as setecho(), setwinsize(), getwinsize() may raise an
+ IOError. '''
+
+ return os.isatty(self.child_fd)
+
+ def waitnoecho(self, timeout=-1):
+ '''This waits until the terminal ECHO flag is set False. This returns
+ True if the echo mode is off. This returns False if the ECHO flag was
+ not set False before the timeout. This can be used to detect when the
+ child is waiting for a password. Usually a child application will turn
+ off echo mode when it is waiting for the user to enter a password. For
+ example, instead of expecting the "password:" prompt you can wait for
+ the child to set ECHO off::
+
+ p = pexpect.spawn('ssh user@example.com')
+ p.waitnoecho()
+ p.sendline(mypassword)
+
+ If timeout==-1 then this method will use the value in self.timeout.
+ If timeout==None then this method will block until the ECHO flag is False.
+ '''
+
+ if timeout == -1:
+ timeout = self.timeout
+ if timeout is not None:
+ end_time = time.time() + timeout
+ while True:
+ if not self.getecho():
+ return True
+ if timeout < 0 and timeout is not None:
+ return False
+ if timeout is not None:
+ timeout = end_time - time.time()
+ time.sleep(0.1)
+
+ def getecho(self):
+ '''This returns the terminal echo mode. This returns True if echo is
+ on or False if echo is off. Child applications that are expecting you
+ to enter a password often set ECHO False. See waitnoecho().
+
+ Not supported on platforms where ``isatty()`` returns False. '''
+ return self.ptyproc.getecho()
+
+ def setecho(self, state):
+ '''This sets the terminal echo mode on or off. Note that anything the
+ child sent before the echo will be lost, so you should be sure that
+ your input buffer is empty before you call setecho(). For example, the
+ following will work as expected::
+
+ p = pexpect.spawn('cat') # Echo is on by default.
+ p.sendline('1234') # We expect to see this twice from the child...
+ p.expect(['1234']) # ... once from the tty echo...
+ p.expect(['1234']) # ... and again from cat itself.
+ p.setecho(False) # Turn off tty echo
+ p.sendline('abcd') # We will see this only once (echoed by cat).
+ p.sendline('wxyz') # We will see this only once (echoed by cat)
+ p.expect(['abcd'])
+ p.expect(['wxyz'])
+
+ The following WILL NOT WORK because the lines sent before the setecho
+ will be lost::
+
+ p = pexpect.spawn('cat')
+ p.sendline('1234')
+ p.setecho(False) # Turn off tty echo
+ p.sendline('abcd') # We will see this only once (echoed by cat).
+ p.sendline('wxyz') # We will see this only once (echoed by cat)
+ p.expect(['1234'])
+ p.expect(['1234'])
+ p.expect(['abcd'])
+ p.expect(['wxyz'])
+
+
+ Not supported on platforms where ``isatty()`` returns False.
+ '''
+ return self.ptyproc.setecho(state)
+
+ def read_nonblocking(self, size=1, timeout=-1):
+ '''This reads at most size characters from the child application. It
+ includes a timeout. If the read does not complete within the timeout
+ period then a TIMEOUT exception is raised. If the end of file is read
+ then an EOF exception will be raised. If a logfile is specified, a
+ copy is written to that log.
+
+ If timeout is None then the read may block indefinitely.
+ If timeout is -1 then the self.timeout value is used. If timeout is 0
+ then the child is polled and if there is no data immediately ready
+ then this will raise a TIMEOUT exception.
+
+ The timeout refers only to the amount of time to read at least one
+ character. This is not affected by the 'size' parameter, so if you call
+ read_nonblocking(size=100, timeout=30) and only one character is
+ available right away then one character will be returned immediately.
+ It will not wait for 30 seconds for another 99 characters to come in.
+
+ On the other hand, if there are bytes available to read immediately,
+ all those bytes will be read (up to the buffer size). So, if the
+ buffer size is 1 megabyte and there is 1 megabyte of data available
+ to read, the buffer will be filled, regardless of timeout.
+
+ This is a wrapper around os.read(). It uses select.select() or
+ select.poll() to implement the timeout. '''
+
+ if self.closed:
+ raise ValueError('I/O operation on closed file.')
+
+ if self.use_poll:
+ def select(timeout):
+ return poll_ignore_interrupts([self.child_fd], timeout)
+ else:
+ def select(timeout):
+ return select_ignore_interrupts([self.child_fd], [], [], timeout)[0]
+
+ # If there is data available to read right now, read as much as
+ # we can. We do this to increase performance if there are a lot
+ # of bytes to be read. This also avoids calling isalive() too
+ # often. See also:
+ # * https://github.com/pexpect/pexpect/pull/304
+ # * http://trac.sagemath.org/ticket/10295
+ if select(0):
+ try:
+ incoming = super(spawn, self).read_nonblocking(size)
+ except EOF:
+ # Maybe the child is dead: update some attributes in that case
+ self.isalive()
+ raise
+ while len(incoming) < size and select(0):
+ try:
+ incoming += super(spawn, self).read_nonblocking(size - len(incoming))
+ except EOF:
+ # Maybe the child is dead: update some attributes in that case
+ self.isalive()
+ # Don't raise EOF, just return what we read so far.
+ return incoming
+ return incoming
+
+ if timeout == -1:
+ timeout = self.timeout
+
+ if not self.isalive():
+ # The process is dead, but there may or may not be data
+ # available to read. Note that some systems such as Solaris
+ # do not give an EOF when the child dies. In fact, you can
+ # still try to read from the child_fd -- it will block
+ # forever or until TIMEOUT. For that reason, it's important
+ # to do this check before calling select() with timeout.
+ if select(0):
+ return super(spawn, self).read_nonblocking(size)
+ self.flag_eof = True
+ raise EOF('End Of File (EOF). Braindead platform.')
+ elif self.__irix_hack:
+ # Irix takes a long time before it realizes a child was terminated.
+ # Make sure that the timeout is at least 2 seconds.
+ # FIXME So does this mean Irix systems are forced to always have
+ # FIXME a 2 second delay when calling read_nonblocking? That sucks.
+ if timeout is not None and timeout < 2:
+ timeout = 2
+
+ # Because of the select(0) check above, we know that no data
+ # is available right now. But if a non-zero timeout is given
+ # (possibly timeout=None), we call select() with a timeout.
+ if (timeout != 0) and select(timeout):
+ return super(spawn, self).read_nonblocking(size)
+
+ if not self.isalive():
+ # Some platforms, such as Irix, will claim that their
+ # processes are alive; timeout on the select; and
+ # then finally admit that they are not alive.
+ self.flag_eof = True
+ raise EOF('End of File (EOF). Very slow platform.')
+ else:
+ raise TIMEOUT('Timeout exceeded.')
+
+ def write(self, s):
+ '''This is similar to send() except that there is no return value.
+ '''
+
+ self.send(s)
+
+ def writelines(self, sequence):
+ '''This calls write() for each element in the sequence. The sequence
+ can be any iterable object producing strings, typically a list of
+ strings. This does not add line separators. There is no return value.
+ '''
+
+ for s in sequence:
+ self.write(s)
+
+ def send(self, s):
+ '''Sends string ``s`` to the child process, returning the number of
+ bytes written. If a logfile is specified, a copy is written to that
+ log.
+
+ The default terminal input mode is canonical processing unless set
+ otherwise by the child process. This allows backspace and other line
+ processing to be performed prior to transmitting to the receiving
+ program. As this is buffered, there is a limited size of such buffer.
+
+ On Linux systems, this is 4096 (defined by N_TTY_BUF_SIZE). All
+ other systems honor the POSIX.1 definition PC_MAX_CANON -- 1024
+ on OSX, 256 on OpenSolaris, and 1920 on FreeBSD.
+
+ This value may be discovered using fpathconf(3)::
+
+ >>> from os import fpathconf
+ >>> print(fpathconf(0, 'PC_MAX_CANON'))
+ 256
+
+ On such a system, only 256 bytes may be received per line. Any
+ subsequent bytes received will be discarded. BEL (``'\a'``) is then
+ sent to output if IMAXBEL (termios.h) is set by the tty driver.
+ This is usually enabled by default. Linux does not honor this as
+ an option -- it behaves as though it is always set on.
+
+ Canonical input processing may be disabled altogether by executing
+ a shell, then stty(1), before executing the final program::
+
+ >>> bash = pexpect.spawn('/bin/bash', echo=False)
+ >>> bash.sendline('stty -icanon')
+ >>> bash.sendline('base64')
+ >>> bash.sendline('x' * 5000)
+ '''
+
+ if self.delaybeforesend is not None:
+ time.sleep(self.delaybeforesend)
+
+ s = self._coerce_send_string(s)
+ self._log(s, 'send')
+
+ b = self._encoder.encode(s, final=False)
+ return os.write(self.child_fd, b)
+
+ def sendline(self, s=''):
+ '''Wraps send(), sending string ``s`` to child process, with
+ ``os.linesep`` automatically appended. Returns number of bytes
+ written. Only a limited number of bytes may be sent for each
+ line in the default terminal mode, see docstring of :meth:`send`.
+ '''
+ s = self._coerce_send_string(s)
+ return self.send(s + self.linesep)
+
+ def _log_control(self, s):
+ """Write control characters to the appropriate log files"""
+ if self.encoding is not None:
+ s = s.decode(self.encoding, 'replace')
+ self._log(s, 'send')
+
+ def sendcontrol(self, char):
+ '''Helper method that wraps send() with mnemonic access for sending control
+ character to the child (such as Ctrl-C or Ctrl-D). For example, to send
+ Ctrl-G (ASCII 7, bell, '\a')::
+
+ child.sendcontrol('g')
+
+ See also, sendintr() and sendeof().
+ '''
+ n, byte = self.ptyproc.sendcontrol(char)
+ self._log_control(byte)
+ return n
+
+ def sendeof(self):
+ '''This sends an EOF to the child. This sends a character which causes
+ the pending parent output buffer to be sent to the waiting child
+ program without waiting for end-of-line. If it is the first character
+ of the line, the read() in the user program returns 0, which signifies
+ end-of-file. This means to work as expected a sendeof() has to be
+ called at the beginning of a line. This method does not send a newline.
+ It is the responsibility of the caller to ensure the eof is sent at the
+ beginning of a line. '''
+
+ n, byte = self.ptyproc.sendeof()
+ self._log_control(byte)
+
+ def sendintr(self):
+ '''This sends a SIGINT to the child. It does not require
+ the SIGINT to be the first character on a line. '''
+
+ n, byte = self.ptyproc.sendintr()
+ self._log_control(byte)
+
+ @property
+ def flag_eof(self):
+ return self.ptyproc.flag_eof
+
+ @flag_eof.setter
+ def flag_eof(self, value):
+ self.ptyproc.flag_eof = value
+
+ def eof(self):
+ '''This returns True if the EOF exception was ever raised.
+ '''
+ return self.flag_eof
+
+ def terminate(self, force=False):
+ '''This forces a child process to terminate. It starts nicely with
+ SIGHUP and SIGINT. If "force" is True then moves onto SIGKILL. This
+ returns True if the child was terminated. This returns False if the
+ child could not be terminated. '''
+
+ if not self.isalive():
+ return True
+ try:
+ self.kill(signal.SIGHUP)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ self.kill(signal.SIGCONT)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ self.kill(signal.SIGINT)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ if force:
+ self.kill(signal.SIGKILL)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ else:
+ return False
+ return False
+ except OSError:
+ # I think there are kernel timing issues that sometimes cause
+ # this to happen. I think isalive() reports True, but the
+ # process is dead to the kernel.
+ # Make one last attempt to see if the kernel is up to date.
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ else:
+ return False
+
+ def wait(self):
+ '''This waits until the child exits. This is a blocking call. This will
+ not read any data from the child, so this will block forever if the
+ child has unread output and has terminated. In other words, the child
+ may have printed output then called exit(), but, the child is
+ technically still alive until its output is read by the parent.
+
+ This method is non-blocking if :meth:`wait` has already been called
+ previously or :meth:`isalive` method returns False. It simply returns
+ the previously determined exit status.
+ '''
+
+ ptyproc = self.ptyproc
+ with _wrap_ptyprocess_err():
+ # An exception may occur if some other process is attempting
+ # job control with our child pid.
+ exitstatus = ptyproc.wait()
+ self.status = ptyproc.status
+ self.exitstatus = ptyproc.exitstatus
+ self.signalstatus = ptyproc.signalstatus
+ self.terminated = True
+
+ return exitstatus
+
+ def isalive(self):
+ '''This tests if the child process is running or not. This is
+ non-blocking. If the child was terminated then this will read the
+ exitstatus or signalstatus of the child. This returns True if the child
+ process appears to be running or False if not. It can take literally
+ SECONDS for Solaris to return the right status. '''
+
+ ptyproc = self.ptyproc
+ with _wrap_ptyprocess_err():
+ alive = ptyproc.isalive()
+
+ if not alive:
+ self.status = ptyproc.status
+ self.exitstatus = ptyproc.exitstatus
+ self.signalstatus = ptyproc.signalstatus
+ self.terminated = True
+
+ return alive
+
+ def kill(self, sig):
+
+ '''This sends the given signal to the child application. In keeping
+ with UNIX tradition it has a misleading name. It does not necessarily
+ kill the child unless you send the right signal. '''
+
+ # Same as os.kill, but the pid is given for you.
+ if self.isalive():
+ os.kill(self.pid, sig)
+
+ def getwinsize(self):
+ '''This returns the terminal window size of the child tty. The return
+ value is a tuple of (rows, cols). '''
+ return self.ptyproc.getwinsize()
+
+ def setwinsize(self, rows, cols):
+ '''This sets the terminal window size of the child tty. This will cause
+ a SIGWINCH signal to be sent to the child. This does not change the
+ physical window size. It changes the size reported to TTY-aware
+ applications like vi or curses -- applications that respond to the
+ SIGWINCH signal. '''
+ return self.ptyproc.setwinsize(rows, cols)
+
+
+ def interact(self, escape_character=chr(29),
+ input_filter=None, output_filter=None):
+
+ '''This gives control of the child process to the interactive user (the
+ human at the keyboard). Keystrokes are sent to the child process, and
+ the stdout and stderr output of the child process is printed. This
+ simply echoes the child stdout and child stderr to the real stdout and
+ it echoes the real stdin to the child stdin. When the user types the
+ escape_character this method will return None. The escape_character
+ will not be transmitted. The default escape_character is
+ ``Ctrl-]``, the same as in BSD telnet. To prevent
+ escaping, escape_character may be set to None.
+
+ If a logfile is specified, then the data sent and received from the
+ child process in interact mode is duplicated to the given log.
+
+ You may pass in optional input and output filter functions. These
+ functions should take a bytes array and return a bytes array too. Even
+ with ``encoding='utf-8'`` support, :meth:`interact` will always pass
+ input_filter and output_filter bytes. You may need to wrap your
+ function to decode and encode back to UTF-8.
+
+ The output_filter will be passed all the output from the child process.
+ The input_filter will be passed all the keyboard input from the user.
+ The input_filter is run BEFORE the check for the escape_character.
+
+ Note that if you change the window size of the parent the SIGWINCH
+ signal will not be passed through to the child. If you want the child
+ window size to change when the parent's window size changes then do
+ something like the following example::
+
+ import pexpect, struct, fcntl, termios, signal, sys
+ def sigwinch_passthrough (sig, data):
+ s = struct.pack("HHHH", 0, 0, 0, 0)
+ a = struct.unpack('hhhh', fcntl.ioctl(sys.stdout.fileno(),
+ termios.TIOCGWINSZ , s))
+ if not p.closed:
+ p.setwinsize(a[0],a[1])
+
+ # Note this 'p' is global and used in sigwinch_passthrough.
+ p = pexpect.spawn('/bin/bash')
+ signal.signal(signal.SIGWINCH, sigwinch_passthrough)
+ p.interact()
+ '''
+
+ # Flush the buffer.
+ self.write_to_stdout(self.buffer)
+ self.stdout.flush()
+ self._buffer = self.buffer_type()
+ mode = tty.tcgetattr(self.STDIN_FILENO)
+ tty.setraw(self.STDIN_FILENO)
+ if escape_character is not None and PY3:
+ escape_character = escape_character.encode('latin-1')
+ try:
+ self.__interact_copy(escape_character, input_filter, output_filter)
+ finally:
+ tty.tcsetattr(self.STDIN_FILENO, tty.TCSAFLUSH, mode)
+
+ def __interact_writen(self, fd, data):
+ '''This is used by the interact() method.
+ '''
+
+ while data != b'' and self.isalive():
+ n = os.write(fd, data)
+ data = data[n:]
+
+ def __interact_read(self, fd):
+ '''This is used by the interact() method.
+ '''
+
+ return os.read(fd, 1000)
+
+ def __interact_copy(
+ self, escape_character=None, input_filter=None, output_filter=None
+ ):
+
+ '''This is used by the interact() method.
+ '''
+
+ while self.isalive():
+ if self.use_poll:
+ r = poll_ignore_interrupts([self.child_fd, self.STDIN_FILENO])
+ else:
+ r, w, e = select_ignore_interrupts(
+ [self.child_fd, self.STDIN_FILENO], [], []
+ )
+ if self.child_fd in r:
+ try:
+ data = self.__interact_read(self.child_fd)
+ except OSError as err:
+ if err.args[0] == errno.EIO:
+ # Linux-style EOF
+ break
+ raise
+ if data == b'':
+ # BSD-style EOF
+ break
+ if output_filter:
+ data = output_filter(data)
+ self._log(data, 'read')
+ os.write(self.STDOUT_FILENO, data)
+ if self.STDIN_FILENO in r:
+ data = self.__interact_read(self.STDIN_FILENO)
+ if input_filter:
+ data = input_filter(data)
+ i = -1
+ if escape_character is not None:
+ i = data.rfind(escape_character)
+ if i != -1:
+ data = data[:i]
+ if data:
+ self._log(data, 'send')
+ self.__interact_writen(self.child_fd, data)
+ break
+ self._log(data, 'send')
+ self.__interact_writen(self.child_fd, data)
+
+
+def spawnu(*args, **kwargs):
+ """Deprecated: pass encoding to spawn() instead."""
+ kwargs.setdefault('encoding', 'utf-8')
+ return spawn(*args, **kwargs)
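A compact round trip through the spawn class above: send a line via the pty, match the echoed output, close the input with EOF and read the exit status. Assumes a POSIX host with ``cat`` on the PATH::

    import pexpect

    child = pexpect.spawn('cat', encoding='utf-8', echo=False, timeout=5)
    child.sendline('hello pty')
    child.expect('hello pty')    # cat writes the line back
    child.sendeof()              # EOF at start of line makes cat exit
    child.expect(pexpect.EOF)
    child.close()
    print(child.exitstatus, child.signalstatus)   # expect (0, None)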
diff --git a/lib/pexpect/pxssh.py b/lib/pexpect/pxssh.py
new file mode 100644
index 0000000..3d53bd9
--- /dev/null
+++ b/lib/pexpect/pxssh.py
@@ -0,0 +1,537 @@
+'''This class extends pexpect.spawn to specialize setting up SSH connections.
+This adds methods for login, logout, and expecting the shell prompt.
+
+PEXPECT LICENSE
+
+ This license is approved by the OSI and FSF as GPL-compatible.
+ http://opensource.org/licenses/isc-license.txt
+
+ Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+ PERMISSION TO USE, COPY, MODIFY, AND/OR DISTRIBUTE THIS SOFTWARE FOR ANY
+ PURPOSE WITH OR WITHOUT FEE IS HEREBY GRANTED, PROVIDED THAT THE ABOVE
+ COPYRIGHT NOTICE AND THIS PERMISSION NOTICE APPEAR IN ALL COPIES.
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''
+
+from pexpect import ExceptionPexpect, TIMEOUT, EOF, spawn
+import time
+import os
+import sys
+import re
+
+__all__ = ['ExceptionPxssh', 'pxssh']
+
+# Exception classes used by this module.
+class ExceptionPxssh(ExceptionPexpect):
+ '''Raised for pxssh exceptions.
+ '''
+
+if sys.version_info > (3, 0):
+ from shlex import quote
+else:
+ _find_unsafe = re.compile(r'[^\w@%+=:,./-]').search
+
+ def quote(s):
+ """Return a shell-escaped version of the string *s*."""
+ if not s:
+ return "''"
+ if _find_unsafe(s) is None:
+ return s
+
+ # use single quotes, and put single quotes into double quotes
+ # the string $'b is then quoted as '$'"'"'b'
+ return "'" + s.replace("'", "'\"'\"'") + "'"
+
+class pxssh (spawn):
+ '''This class extends pexpect.spawn to specialize setting up SSH
+ connections. This adds methods for login, logout, and expecting the shell
+ prompt. It does various tricky things to handle many situations in the SSH
+ login process. For example, if the session is your first login, then pxssh
+ automatically accepts the remote certificate; or if you have public key
+    authentication set up then pxssh won't wait for the password prompt.
+
+ pxssh uses the shell prompt to synchronize output from the remote host. In
+ order to make this more robust it sets the shell prompt to something more
+    unique than just $ or #. This should work on most Bourne/Bash or Csh style
+ shells.
+
+ Example that runs a few commands on a remote server and prints the result::
+
+ from pexpect import pxssh
+ import getpass
+ try:
+ s = pxssh.pxssh()
+ hostname = raw_input('hostname: ')
+ username = raw_input('username: ')
+ password = getpass.getpass('password: ')
+ s.login(hostname, username, password)
+ s.sendline('uptime') # run a command
+ s.prompt() # match the prompt
+ print(s.before) # print everything before the prompt.
+ s.sendline('ls -l')
+ s.prompt()
+ print(s.before)
+ s.sendline('df')
+ s.prompt()
+ print(s.before)
+ s.logout()
+ except pxssh.ExceptionPxssh as e:
+ print("pxssh failed on login.")
+ print(e)
+
+ Example showing how to specify SSH options::
+
+ from pexpect import pxssh
+ s = pxssh.pxssh(options={
+ "StrictHostKeyChecking": "no",
+ "UserKnownHostsFile": "/dev/null"})
+ ...
+
+ Note that if you have ssh-agent running while doing development with pxssh
+ then this can lead to a lot of confusion. Many X display managers (xdm,
+ gdm, kdm, etc.) will automatically start a GUI agent. You may see a GUI
+ dialog box popup asking for a password during development. You should turn
+ off any key agents during testing. The 'force_password' attribute will turn
+ off public key authentication. This will only work if the remote SSH server
+ is configured to allow password logins. Example of using 'force_password'
+ attribute::
+
+ s = pxssh.pxssh()
+ s.force_password = True
+ hostname = raw_input('hostname: ')
+ username = raw_input('username: ')
+ password = getpass.getpass('password: ')
+ s.login (hostname, username, password)
+
+    `debug_command_string` is only for the test suite to confirm that the string
+    generated for SSH is correct; using this will not allow you to do
+ anything other than get a string back from `pxssh.pxssh.login()`.
+ '''
+
+ def __init__ (self, timeout=30, maxread=2000, searchwindowsize=None,
+ logfile=None, cwd=None, env=None, ignore_sighup=True, echo=True,
+ options={}, encoding=None, codec_errors='strict',
+ debug_command_string=False, use_poll=False):
+
+ spawn.__init__(self, None, timeout=timeout, maxread=maxread,
+ searchwindowsize=searchwindowsize, logfile=logfile,
+ cwd=cwd, env=env, ignore_sighup=ignore_sighup, echo=echo,
+ encoding=encoding, codec_errors=codec_errors, use_poll=use_poll)
+
+ self.name = '<pxssh>'
+
+ #SUBTLE HACK ALERT! Note that the command that SETS the prompt uses a
+ #slightly different string than the regular expression to match it. This
+ #is because when you set the prompt the command will echo back, but we
+ #don't want to match the echoed command. So if we make the set command
+ #slightly different than the regex we eliminate the problem. To make the
+ #set command different we add a backslash in front of $. The $ doesn't
+ #need to be escaped, but it doesn't hurt and serves to make the set
+ #prompt command different than the regex.
+
+ # used to match the command-line prompt
+ self.UNIQUE_PROMPT = r"\[PEXPECT\][\$\#] "
+ self.PROMPT = self.UNIQUE_PROMPT
+
+ # used to set shell command-line prompt to UNIQUE_PROMPT.
+ self.PROMPT_SET_SH = r"PS1='[PEXPECT]\$ '"
+ self.PROMPT_SET_CSH = r"set prompt='[PEXPECT]\$ '"
+ self.SSH_OPTS = ("-o'RSAAuthentication=no'"
+ + " -o 'PubkeyAuthentication=no'")
+# Disabling host key checking makes you vulnerable to MITM attacks.
+# + " -o 'StrictHostKeyChecking=no'"
+# + " -o 'UserKnownHostsFile /dev/null' ")
+ # Disabling X11 forwarding gets rid of the annoying SSH_ASKPASS from
+ # displaying a GUI password dialog. I have not figured out how to
+ # disable only SSH_ASKPASS without also disabling X11 forwarding.
+ # Unsetting SSH_ASKPASS on the remote side doesn't disable it! Annoying!
+ #self.SSH_OPTS = "-x -o'RSAAuthentication=no' -o 'PubkeyAuthentication=no'"
+ self.force_password = False
+
+ self.debug_command_string = debug_command_string
+
+        # User defined SSH options, e.g.,
+        # ssh.options = dict(StrictHostKeyChecking="no",UserKnownHostsFile="/dev/null")
+ self.options = options
+
+ def levenshtein_distance(self, a, b):
+ '''This calculates the Levenshtein distance between a and b.
+ '''
+
+ n, m = len(a), len(b)
+ if n > m:
+ a,b = b,a
+ n,m = m,n
+ current = range(n+1)
+ for i in range(1,m+1):
+ previous, current = current, [i]+[0]*n
+ for j in range(1,n+1):
+ add, delete = previous[j]+1, current[j-1]+1
+ change = previous[j-1]
+ if a[j-1] != b[i-1]:
+ change = change + 1
+ current[j] = min(add, delete, change)
+ return current[n]
+
+ def try_read_prompt(self, timeout_multiplier):
+ '''This facilitates using communication timeouts to perform
+ synchronization as quickly as possible, while supporting high latency
+ connections with a tunable worst case performance. Fast connections
+ should be read almost immediately. Worst case performance for this
+ method is timeout_multiplier * 3 seconds.
+ '''
+
+ # maximum time allowed to read the first response
+ first_char_timeout = timeout_multiplier * 0.5
+
+ # maximum time allowed between subsequent characters
+ inter_char_timeout = timeout_multiplier * 0.1
+
+ # maximum time for reading the entire prompt
+ total_timeout = timeout_multiplier * 3.0
+
+ prompt = self.string_type()
+ begin = time.time()
+ expired = 0.0
+ timeout = first_char_timeout
+
+ while expired < total_timeout:
+ try:
+ prompt += self.read_nonblocking(size=1, timeout=timeout)
+ expired = time.time() - begin # updated total time expired
+ timeout = inter_char_timeout
+ except TIMEOUT:
+ break
+
+ return prompt
+
+ def sync_original_prompt (self, sync_multiplier=1.0):
+ '''This attempts to find the prompt. Basically, press enter and record
+ the response; press enter again and record the response; if the two
+ responses are similar then assume we are at the original prompt.
+ This can be a slow function. Worst case with the default sync_multiplier
+        can take 12 seconds. High latency connections are more likely to fail
+        with a low sync_multiplier. Best case sync time gets worse with a
+ high sync multiplier (500 ms with default). '''
+
+ # All of these timing pace values are magic.
+ # I came up with these based on what seemed reliable for
+ # connecting to a heavily loaded machine I have.
+ self.sendline()
+ time.sleep(0.1)
+
+ try:
+ # Clear the buffer before getting the prompt.
+ self.try_read_prompt(sync_multiplier)
+ except TIMEOUT:
+ pass
+
+ self.sendline()
+ x = self.try_read_prompt(sync_multiplier)
+
+ self.sendline()
+ a = self.try_read_prompt(sync_multiplier)
+
+ self.sendline()
+ b = self.try_read_prompt(sync_multiplier)
+
+ ld = self.levenshtein_distance(a,b)
+ len_a = len(a)
+ if len_a == 0:
+ return False
+ if float(ld)/len_a < 0.4:
+ return True
+ return False
+
+ ### TODO: This is getting messy and I'm pretty sure this isn't perfect.
+ ### TODO: I need to draw a flow chart for this.
+ ### TODO: Unit tests for SSH tunnels, remote SSH command exec, disabling original prompt sync
+ def login (self, server, username=None, password='', terminal_type='ansi',
+ original_prompt=r"[#$]", login_timeout=10, port=None,
+ auto_prompt_reset=True, ssh_key=None, quiet=True,
+ sync_multiplier=1, check_local_ip=True,
+ password_regex=r'(?i)(?:password:)|(?:passphrase for key)',
+ ssh_tunnels={}, spawn_local_ssh=True,
+ sync_original_prompt=True, ssh_config=None, cmd='ssh'):
+ '''This logs the user into the given server.
+
+ It uses 'original_prompt' to try to find the prompt right after login.
+ When it finds the prompt it immediately tries to reset the prompt to
+ something more easily matched. The default 'original_prompt' is very
+ optimistic and is easily fooled. It's more reliable to try to match the original
+ prompt as exactly as possible to prevent false matches by server
+ strings such as the "Message Of The Day". On many systems you can
+ disable the MOTD on the remote server by creating a zero-length file
+ called :file:`~/.hushlogin` on the remote server. If a prompt cannot be found
+ then this will not necessarily cause the login to fail. In the case of
+ a timeout when looking for the prompt we assume that the original
+ prompt was so weird that we could not match it, so we use a few tricks
+ to guess when we have reached the prompt. Then we hope for the best and
+ blindly try to reset the prompt to something more unique. If that fails
+ then login() raises an :class:`ExceptionPxssh` exception.
+
+ In some situations it is not possible or desirable to reset the
+ original prompt. In this case, pass ``auto_prompt_reset=False`` to
+ inhibit setting the prompt to the UNIQUE_PROMPT. Remember that pxssh
+ uses a unique prompt in the :meth:`prompt` method. If the original prompt is
+ not reset then this will disable the :meth:`prompt` method unless you
+ manually set the :attr:`PROMPT` attribute.
+
+ Set ``password_regex`` if there is a MOTD message with `password` in it.
+ Changing this is like playing in traffic, don't (p)expect it to match straight
+ away.
+
+        If you need to connect to another SSH server from your original SSH
+        connection, set ``spawn_local_ssh`` to `False` and the current session will
+        be used to do so. Setting this option to `False` without an active session
+        will trigger an error.
+
+ Set ``ssh_key`` to a file path to an SSH private key to use that SSH key
+ for the session authentication.
+ Set ``ssh_key`` to `True` to force passing the current SSH authentication socket
+ to the desired ``hostname``.
+
+ Set ``ssh_config`` to a file path string of an SSH client config file to pass that
+        file to the client to handle itself. You may set any options you wish in that
+        file; however, doing so means that if you run into issues you may have to share
+        extra configuration details that you would rather keep to yourself.
+
+        Alter the ``cmd`` to change the ssh client used, or to prepend it with network
+        namespaces. For example ```cmd="ip netns exec vlan2 ssh"``` executes ssh in the
+        network namespace named ```vlan2```.
+ '''
+
+ session_regex_array = ["(?i)are you sure you want to continue connecting", original_prompt, password_regex, "(?i)permission denied", "(?i)terminal type", TIMEOUT]
+ session_init_regex_array = []
+ session_init_regex_array.extend(session_regex_array)
+ session_init_regex_array.extend(["(?i)connection closed by remote host", EOF])
+
+ ssh_options = ''.join([" -o '%s=%s'" % (o, v) for (o, v) in self.options.items()])
+ if quiet:
+ ssh_options = ssh_options + ' -q'
+ if not check_local_ip:
+ ssh_options = ssh_options + " -o'NoHostAuthenticationForLocalhost=yes'"
+ if self.force_password:
+ ssh_options = ssh_options + ' ' + self.SSH_OPTS
+ if ssh_config is not None:
+ if spawn_local_ssh and not os.path.isfile(ssh_config):
+ raise ExceptionPxssh('SSH config does not exist or is not a file.')
+ ssh_options = ssh_options + ' -F ' + ssh_config
+ if port is not None:
+ ssh_options = ssh_options + ' -p %s'%(str(port))
+ if ssh_key is not None:
+ # Allow forwarding our SSH key to the current session
+ if ssh_key==True:
+ ssh_options = ssh_options + ' -A'
+ else:
+ if spawn_local_ssh and not os.path.isfile(ssh_key):
+ raise ExceptionPxssh('private ssh key does not exist or is not a file.')
+ ssh_options = ssh_options + ' -i %s' % (ssh_key)
+
+ # SSH tunnels, make sure you know what you're putting into the lists
+ # under each heading. Do not expect these to open 100% of the time,
+ # The port you're requesting might be bound.
+ #
+ # The structure should be like this:
+ # { 'local': ['2424:localhost:22'], # Local SSH tunnels
+ # 'remote': ['2525:localhost:22'], # Remote SSH tunnels
+ # 'dynamic': [8888] } # Dynamic/SOCKS tunnels
+ if ssh_tunnels!={} and isinstance({},type(ssh_tunnels)):
+ tunnel_types = {
+ 'local':'L',
+ 'remote':'R',
+ 'dynamic':'D'
+ }
+ for tunnel_type in tunnel_types:
+ cmd_type = tunnel_types[tunnel_type]
+ if tunnel_type in ssh_tunnels:
+ tunnels = ssh_tunnels[tunnel_type]
+ for tunnel in tunnels:
+ if spawn_local_ssh==False:
+ tunnel = quote(str(tunnel))
+ ssh_options = ssh_options + ' -' + cmd_type + ' ' + str(tunnel)
+
+ if username is not None:
+ ssh_options = ssh_options + ' -l ' + username
+ elif ssh_config is None:
+ raise TypeError('login() needs either a username or an ssh_config')
+ else: # make sure ssh_config has an entry for the server with a username
+ with open(ssh_config, 'rt') as f:
+ lines = [l.strip() for l in f.readlines()]
+
+ server_regex = r'^Host\s+%s\s*$' % server
+ user_regex = r'^User\s+\w+\s*$'
+ config_has_server = False
+ server_has_username = False
+ for line in lines:
+ if not config_has_server and re.match(server_regex, line, re.IGNORECASE):
+ config_has_server = True
+ elif config_has_server and 'hostname' in line.lower():
+ pass
+ elif config_has_server and 'host' in line.lower():
+ server_has_username = False # insurance
+ break # we have left the relevant section
+ elif config_has_server and re.match(user_regex, line, re.IGNORECASE):
+ server_has_username = True
+ break
+
+ if lines:
+ del line
+
+ del lines
+
+ if not config_has_server:
+ raise TypeError('login() ssh_config has no Host entry for %s' % server)
+ elif not server_has_username:
+ raise TypeError('login() ssh_config has no user entry for %s' % server)
+
+ cmd += " %s %s" % (ssh_options, server)
+ if self.debug_command_string:
+ return(cmd)
+
+ # Are we asking for a local ssh command or to spawn one in another session?
+ if spawn_local_ssh:
+ spawn._spawn(self, cmd)
+ else:
+ self.sendline(cmd)
+
+ # This does not distinguish between a remote server 'password' prompt
+ # and a local ssh 'passphrase' prompt (for unlocking a private key).
+ i = self.expect(session_init_regex_array, timeout=login_timeout)
+
+ # First phase
+ if i==0:
+ # New certificate -- always accept it.
+ # This is what you get if SSH does not have the remote host's
+ # public key stored in the 'known_hosts' cache.
+ self.sendline("yes")
+ i = self.expect(session_regex_array)
+ if i==2: # password or passphrase
+ self.sendline(password)
+ i = self.expect(session_regex_array)
+ if i==4:
+ self.sendline(terminal_type)
+ i = self.expect(session_regex_array)
+ if i==7:
+ self.close()
+ raise ExceptionPxssh('Could not establish connection to host')
+
+ # Second phase
+ if i==0:
+ # This is weird. This should not happen twice in a row.
+ self.close()
+ raise ExceptionPxssh('Weird error. Got "are you sure" prompt twice.')
+ elif i==1: # can occur if you have a public key pair set to authenticate.
+ ### TODO: May NOT be OK if expect() got tricked and matched a false prompt.
+ pass
+ elif i==2: # password prompt again
+ # For incorrect passwords, some ssh servers will
+ # ask for the password again, others return 'denied' right away.
+ # If we get the password prompt again then this means
+ # we didn't get the password right the first time.
+ self.close()
+ raise ExceptionPxssh('password refused')
+ elif i==3: # permission denied -- password was bad.
+ self.close()
+ raise ExceptionPxssh('permission denied')
+ elif i==4: # terminal type again? WTF?
+ self.close()
+ raise ExceptionPxssh('Weird error. Got "terminal type" prompt twice.')
+ elif i==5: # Timeout
+ #This is tricky... I presume that we are at the command-line prompt.
+ #It may be that the shell prompt was so weird that we couldn't match
+ #it. Or it may be that we couldn't log in for some other reason. I
+ #can't be sure, but it's safe to guess that we did login because if
+ #I presume wrong and we are not logged in then this should be caught
+ #later when I try to set the shell prompt.
+ pass
+ elif i==6: # Connection closed by remote host
+ self.close()
+ raise ExceptionPxssh('connection closed')
+ else: # Unexpected
+ self.close()
+ raise ExceptionPxssh('unexpected login response')
+ if sync_original_prompt:
+ if not self.sync_original_prompt(sync_multiplier):
+ self.close()
+ raise ExceptionPxssh('could not synchronize with original prompt')
+ # We appear to be in.
+ # set shell prompt to something unique.
+ if auto_prompt_reset:
+ if not self.set_unique_prompt():
+ self.close()
+ raise ExceptionPxssh('could not set shell prompt '
+ '(received: %r, expected: %r).' % (
+ self.before, self.PROMPT,))
+ return True
+
+ def logout (self):
+ '''Sends exit to the remote shell.
+
+ If there are stopped jobs then this automatically sends exit twice.
+ '''
+ self.sendline("exit")
+ index = self.expect([EOF, "(?i)there are stopped jobs"])
+ if index==1:
+ self.sendline("exit")
+ self.expect(EOF)
+ self.close()
+
+ def prompt(self, timeout=-1):
+ '''Match the next shell prompt.
+
+ This is little more than a short-cut to the :meth:`~pexpect.spawn.expect`
+ method. Note that if you called :meth:`login` with
+ ``auto_prompt_reset=False``, then before calling :meth:`prompt` you must
+ set the :attr:`PROMPT` attribute to a regex that it will use for
+ matching the prompt.
+
+ Calling :meth:`prompt` will erase the contents of the :attr:`before`
+ attribute even if no prompt is ever matched. If timeout is not given or
+ it is set to -1 then self.timeout is used.
+
+ :return: True if the shell prompt was matched, False if the timeout was
+ reached.
+ '''
+
+ if timeout == -1:
+ timeout = self.timeout
+ i = self.expect([self.PROMPT, TIMEOUT], timeout=timeout)
+ if i==1:
+ return False
+ return True
+
+ def set_unique_prompt(self):
+ '''This sets the remote prompt to something more unique than ``#`` or ``$``.
+ This makes it easier for the :meth:`prompt` method to match the shell prompt
+ unambiguously. This method is called automatically by the :meth:`login`
+ method, but you may want to call it manually if you somehow reset the
+ shell prompt. For example, if you 'su' to a different user then you
+ will need to manually reset the prompt. This sends shell commands to
+ the remote host to set the prompt, so this assumes the remote host is
+ ready to receive commands.
+
+ Alternatively, you may use your own prompt pattern. In this case you
+ should call :meth:`login` with ``auto_prompt_reset=False``; then set the
+ :attr:`PROMPT` attribute to a regular expression. After that, the
+ :meth:`prompt` method will try to match your prompt pattern.
+ '''
+
+ self.sendline("unset PROMPT_COMMAND")
+ self.sendline(self.PROMPT_SET_SH) # sh-style
+ i = self.expect ([TIMEOUT, self.PROMPT], timeout=10)
+ if i == 0: # csh-style
+ self.sendline(self.PROMPT_SET_CSH)
+ i = self.expect([TIMEOUT, self.PROMPT], timeout=10)
+ if i == 0:
+ return False
+ return True
+
+# vi:ts=4:sw=4:expandtab:ft=python:
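
The class docstring shows a plain login; the ssh_tunnels and options parameters of login() are easier to see in a concrete call. A minimal sketch, assuming a reachable host and valid credentials (the hostname, username and port numbers below are placeholders):

import getpass
from pexpect import pxssh

s = pxssh.pxssh(options={"StrictHostKeyChecking": "no"})
password = getpass.getpass('password: ')

# Structure expected by login(): forward local port 2424 to port 22 on the
# remote side and open a SOCKS proxy on local port 8888.
tunnels = {'local': ['2424:localhost:22'],
           'dynamic': [8888]}

s.login('server.example.com', 'myuser', password, port=22, ssh_tunnels=tunnels)
s.sendline('uname -a')
s.prompt()               # wait for the unique [PEXPECT] prompt
print(s.before)          # everything printed before that prompt
s.logout()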
diff --git a/lib/pexpect/replwrap.py b/lib/pexpect/replwrap.py
new file mode 100644
index 0000000..c930f1e
--- /dev/null
+++ b/lib/pexpect/replwrap.py
@@ -0,0 +1,130 @@
+"""Generic wrapper for read-eval-print-loops, a.k.a. interactive shells
+"""
+import os.path
+import signal
+import sys
+
+import pexpect
+
+PY3 = (sys.version_info[0] >= 3)
+
+if PY3:
+ basestring = str
+
+PEXPECT_PROMPT = u'[PEXPECT_PROMPT>'
+PEXPECT_CONTINUATION_PROMPT = u'[PEXPECT_PROMPT+'
+
+class REPLWrapper(object):
+ """Wrapper for a REPL.
+
+ :param cmd_or_spawn: This can either be an instance of :class:`pexpect.spawn`
+ in which a REPL has already been started, or a str command to start a new
+ REPL process.
+ :param str orig_prompt: The prompt to expect at first.
+ :param str prompt_change: A command to change the prompt to something more
+ unique. If this is ``None``, the prompt will not be changed. This will
+ be formatted with the new and continuation prompts as positional
+ parameters, so you can use ``{}`` style formatting to insert them into
+ the command.
+ :param str new_prompt: The more unique prompt to expect after the change.
+ :param str extra_init_cmd: Commands to do extra initialisation, such as
+ disabling pagers.
+ """
+ def __init__(self, cmd_or_spawn, orig_prompt, prompt_change,
+ new_prompt=PEXPECT_PROMPT,
+ continuation_prompt=PEXPECT_CONTINUATION_PROMPT,
+ extra_init_cmd=None):
+ if isinstance(cmd_or_spawn, basestring):
+ self.child = pexpect.spawn(cmd_or_spawn, echo=False, encoding='utf-8')
+ else:
+ self.child = cmd_or_spawn
+ if self.child.echo:
+ # Existing spawn instance has echo enabled, disable it
+ # to prevent our input from being repeated to output.
+ self.child.setecho(False)
+ self.child.waitnoecho()
+
+ if prompt_change is None:
+ self.prompt = orig_prompt
+ else:
+ self.set_prompt(orig_prompt,
+ prompt_change.format(new_prompt, continuation_prompt))
+ self.prompt = new_prompt
+ self.continuation_prompt = continuation_prompt
+
+ self._expect_prompt()
+
+ if extra_init_cmd is not None:
+ self.run_command(extra_init_cmd)
+
+ def set_prompt(self, orig_prompt, prompt_change):
+ self.child.expect(orig_prompt)
+ self.child.sendline(prompt_change)
+
+ def _expect_prompt(self, timeout=-1, async_=False):
+ return self.child.expect_exact([self.prompt, self.continuation_prompt],
+ timeout=timeout, async_=async_)
+
+ def run_command(self, command, timeout=-1, async_=False):
+ """Send a command to the REPL, wait for and return output.
+
+ :param str command: The command to send. Trailing newlines are not needed.
+ This should be a complete block of input that will trigger execution;
+ if a continuation prompt is found after sending input, :exc:`ValueError`
+ will be raised.
+ :param int timeout: How long to wait for the next prompt. -1 means the
+ default from the :class:`pexpect.spawn` object (default 30 seconds).
+ None means to wait indefinitely.
+ :param bool async_: On Python 3.4, or Python 3.3 with asyncio
+ installed, passing ``async_=True`` will make this return an
+ :mod:`asyncio` Future, which you can yield from to get the same
+ result that this method would normally give directly.
+ """
+ # Split up multiline commands and feed them in bit-by-bit
+ cmdlines = command.splitlines()
+ # splitlines ignores trailing newlines - add it back in manually
+ if command.endswith('\n'):
+ cmdlines.append('')
+ if not cmdlines:
+ raise ValueError("No command was given")
+
+ if async_:
+ from ._async import repl_run_command_async
+ return repl_run_command_async(self, cmdlines, timeout)
+
+ res = []
+ self.child.sendline(cmdlines[0])
+ for line in cmdlines[1:]:
+ self._expect_prompt(timeout=timeout)
+ res.append(self.child.before)
+ self.child.sendline(line)
+
+ # Command was fully submitted, now wait for the next prompt
+ if self._expect_prompt(timeout=timeout) == 1:
+ # We got the continuation prompt - command was incomplete
+ self.child.kill(signal.SIGINT)
+ self._expect_prompt(timeout=1)
+ raise ValueError("Continuation prompt found - input was incomplete:\n"
+ + command)
+ return u''.join(res + [self.child.before])
+
+def python(command="python"):
+ """Start a Python shell and return a :class:`REPLWrapper` object."""
+ return REPLWrapper(command, u">>> ", u"import sys; sys.ps1={0!r}; sys.ps2={1!r}")
+
+def bash(command="bash"):
+ """Start a bash shell and return a :class:`REPLWrapper` object."""
+ bashrc = os.path.join(os.path.dirname(__file__), 'bashrc.sh')
+ child = pexpect.spawn(command, ['--rcfile', bashrc], echo=False,
+ encoding='utf-8')
+
+ # If the user runs 'env', the value of PS1 will be in the output. To avoid
+ # replwrap seeing that as the next prompt, we'll embed the marker characters
+ # for invisible characters in the prompt; these show up when inspecting the
+ # environment variable, but not when bash displays the prompt.
+ ps1 = PEXPECT_PROMPT[:5] + u'\\[\\]' + PEXPECT_PROMPT[5:]
+ ps2 = PEXPECT_CONTINUATION_PROMPT[:5] + u'\\[\\]' + PEXPECT_CONTINUATION_PROMPT[5:]
+ prompt_change = u"PS1='{0}' PS2='{1}' PROMPT_COMMAND=''".format(ps1, ps2)
+
+ return REPLWrapper(child, u'\\$', prompt_change,
+ extra_init_cmd="export PAGER=cat")
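
A minimal sketch of driving the wrappers defined above, assuming bash and python are on PATH; the commands sent are illustrative:

from pexpect import replwrap

# bash() starts a shell whose prompt has already been switched to the unique
# PEXPECT_PROMPT markers, so run_command() can match it reliably.
sh = replwrap.bash()
print(sh.run_command('ls /tmp'))

# python() does the same for an interactive Python interpreter.
py = replwrap.python()
print(py.run_command('21 * 2'))

# Multi-line input is fed line by line past the continuation prompt.
print(py.run_command('for i in range(3):\n    print(i)\n'))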
diff --git a/lib/pexpect/run.py b/lib/pexpect/run.py
new file mode 100644
index 0000000..ff288a1
--- /dev/null
+++ b/lib/pexpect/run.py
@@ -0,0 +1,157 @@
+import sys
+import types
+
+from .exceptions import EOF, TIMEOUT
+from .pty_spawn import spawn
+
+def run(command, timeout=30, withexitstatus=False, events=None,
+ extra_args=None, logfile=None, cwd=None, env=None, **kwargs):
+
+ '''
+ This function runs the given command; waits for it to finish; then
+ returns all output as a string. STDERR is included in output. If the full
+ path to the command is not given then the path is searched.
+
+ Note that lines are terminated by CR/LF (\\r\\n) combination even on
+ UNIX-like systems because this is the standard for pseudottys. If you set
+ 'withexitstatus' to true, then run will return a tuple of (command_output,
+ exitstatus). If 'withexitstatus' is false then this returns just
+ command_output.
+
+ The run() function can often be used instead of creating a spawn instance.
+ For example, the following code uses spawn::
+
+ from pexpect import *
+ child = spawn('scp foo user@example.com:.')
+ child.expect('(?i)password')
+ child.sendline(mypassword)
+
+    The previous code can be replaced with the following::
+
+ from pexpect import *
+ run('scp foo user@example.com:.', events={'(?i)password': mypassword})
+
+ **Examples**
+
+ Start the apache daemon on the local machine::
+
+ from pexpect import *
+ run("/usr/local/apache/bin/apachectl start")
+
+ Check in a file using SVN::
+
+ from pexpect import *
+ run("svn ci -m 'automatic commit' my_file.py")
+
+ Run a command and capture exit status::
+
+ from pexpect import *
+ (command_output, exitstatus) = run('ls -l /bin', withexitstatus=1)
+
+ The following will run SSH and execute 'ls -l' on the remote machine. The
+ password 'secret' will be sent if the '(?i)password' pattern is ever seen::
+
+ run("ssh username@machine.example.com 'ls -l'",
+ events={'(?i)password':'secret\\n'})
+
+ This will start mencoder to rip a video from DVD. This will also display
+ progress ticks every 5 seconds as it runs. For example::
+
+ from pexpect import *
+ def print_ticks(d):
+ print d['event_count'],
+ run("mencoder dvd://1 -o video.avi -oac copy -ovc copy",
+ events={TIMEOUT:print_ticks}, timeout=5)
+
+ The 'events' argument should be either a dictionary or a tuple list that
+ contains patterns and responses. Whenever one of the patterns is seen
+ in the command output, run() will send the associated response string.
+    So, run() in the above example can also be written as::
+
+ run("mencoder dvd://1 -o video.avi -oac copy -ovc copy",
+ events=[(TIMEOUT,print_ticks)], timeout=5)
+
+    Use a tuple list for events if the command output requires delicate
+    control over which pattern should be matched, since the tuple list is passed
+ to pexpect() as its pattern list, with the order of patterns preserved.
+
+ Note that you should put newlines in your string if Enter is necessary.
+
+ Like the example above, the responses may also contain a callback, either
+ a function or method. It should accept a dictionary value as an argument.
+ The dictionary contains all the locals from the run() function, so you can
+ access the child spawn object or any other variable defined in run()
+ (event_count, child, and extra_args are the most useful). A callback may
+ return True to stop the current run process. Otherwise run() continues
+ until the next event. A callback may also return a string which will be
+    sent to the child. 'extra_args' is not used directly by run(). It provides
+    a way to pass data to a callback function through run(), via the locals
+    dictionary passed to the callback.
+
+ Like :class:`spawn`, passing *encoding* will make it work with unicode
+ instead of bytes. You can pass *codec_errors* to control how errors in
+ encoding and decoding are handled.
+ '''
+ if timeout == -1:
+ child = spawn(command, maxread=2000, logfile=logfile, cwd=cwd, env=env,
+ **kwargs)
+ else:
+ child = spawn(command, timeout=timeout, maxread=2000, logfile=logfile,
+ cwd=cwd, env=env, **kwargs)
+ if isinstance(events, list):
+ patterns= [x for x,y in events]
+ responses = [y for x,y in events]
+ elif isinstance(events, dict):
+ patterns = list(events.keys())
+ responses = list(events.values())
+ else:
+ # This assumes EOF or TIMEOUT will eventually cause run to terminate.
+ patterns = None
+ responses = None
+ child_result_list = []
+ event_count = 0
+ while True:
+ try:
+ index = child.expect(patterns)
+ if isinstance(child.after, child.allowed_string_types):
+ child_result_list.append(child.before + child.after)
+ else:
+ # child.after may have been a TIMEOUT or EOF,
+ # which we don't want appended to the list.
+ child_result_list.append(child.before)
+ if isinstance(responses[index], child.allowed_string_types):
+ child.send(responses[index])
+ elif (isinstance(responses[index], types.FunctionType) or
+ isinstance(responses[index], types.MethodType)):
+ callback_result = responses[index](locals())
+ sys.stdout.flush()
+ if isinstance(callback_result, child.allowed_string_types):
+ child.send(callback_result)
+ elif callback_result:
+ break
+ else:
+ raise TypeError("parameter `event' at index {index} must be "
+ "a string, method, or function: {value!r}"
+ .format(index=index, value=responses[index]))
+ event_count = event_count + 1
+ except TIMEOUT:
+ child_result_list.append(child.before)
+ break
+ except EOF:
+ child_result_list.append(child.before)
+ break
+ child_result = child.string_type().join(child_result_list)
+ if withexitstatus:
+ child.close()
+ return (child_result, child.exitstatus)
+ else:
+ return child_result
+
+def runu(command, timeout=30, withexitstatus=False, events=None,
+ extra_args=None, logfile=None, cwd=None, env=None, **kwargs):
+ """Deprecated: pass encoding to run() instead.
+ """
+ kwargs.setdefault('encoding', 'utf-8')
+ return run(command, timeout=timeout, withexitstatus=withexitstatus,
+ events=events, extra_args=extra_args, logfile=logfile, cwd=cwd,
+ env=env, **kwargs)
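
A small sketch tying together the pieces described in the run() docstring: an events dictionary keyed by TIMEOUT, a callback that reads extra_args and event_count from run()'s locals, and withexitstatus. The command, the one-second tick interval and the tick limit are illustrative:

from pexpect import run, TIMEOUT

def on_tick(d):
    # d is the locals() of run(); extra_args and event_count are available.
    if d['event_count'] >= d['extra_args']['max_ticks']:
        return True          # returning True stops run()
    return None              # anything falsy lets run() keep going

output, status = run('ping -c 10 localhost',
                     events={TIMEOUT: on_tick},
                     extra_args={'max_ticks': 3},
                     timeout=1,
                     withexitstatus=True)
print(status, len(output))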
diff --git a/lib/pexpect/screen.py b/lib/pexpect/screen.py
new file mode 100644
index 0000000..79f95c4
--- /dev/null
+++ b/lib/pexpect/screen.py
@@ -0,0 +1,431 @@
+'''This implements a virtual screen. This is used to support ANSI terminal
+emulation. The screen representation and state is implemented in this class.
+Most of the methods are inspired by ANSI screen control codes. The
+:class:`~pexpect.ANSI.ANSI` class extends this class to add parsing of ANSI
+escape codes.
+
+PEXPECT LICENSE
+
+ This license is approved by the OSI and FSF as GPL-compatible.
+ http://opensource.org/licenses/isc-license.txt
+
+ Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+ PERMISSION TO USE, COPY, MODIFY, AND/OR DISTRIBUTE THIS SOFTWARE FOR ANY
+ PURPOSE WITH OR WITHOUT FEE IS HEREBY GRANTED, PROVIDED THAT THE ABOVE
+ COPYRIGHT NOTICE AND THIS PERMISSION NOTICE APPEAR IN ALL COPIES.
+ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''
+
+import codecs
+import copy
+import sys
+
+import warnings
+
+warnings.warn(("pexpect.screen and pexpect.ANSI are deprecated. "
+ "We recommend using pyte to emulate a terminal screen: "
+ "https://pypi.python.org/pypi/pyte"),
+ stacklevel=2)
+
+NUL = 0 # Fill character; ignored on input.
+ENQ = 5 # Transmit answerback message.
+BEL = 7 # Ring the bell.
+BS = 8 # Move cursor left.
+HT = 9 # Move cursor to next tab stop.
+LF = 10 # Line feed.
+VT = 11 # Same as LF.
+FF = 12 # Same as LF.
+CR = 13 # Move cursor to left margin or newline.
+SO = 14 # Invoke G1 character set.
+SI = 15 # Invoke G0 character set.
+XON = 17 # Resume transmission.
+XOFF = 19 # Halt transmission.
+CAN = 24 # Cancel escape sequence.
+SUB = 26 # Same as CAN.
+ESC = 27 # Introduce a control sequence.
+DEL = 127 # Fill character; ignored on input.
+SPACE = u' ' # Space or blank character.
+
+PY3 = (sys.version_info[0] >= 3)
+if PY3:
+ unicode = str
+
+def constrain (n, min, max):
+
+ '''This returns a number, n constrained to the min and max bounds. '''
+
+ if n < min:
+ return min
+ if n > max:
+ return max
+ return n
+
+class screen:
+ '''This object maintains the state of a virtual text screen as a
+ rectangular array. This maintains a virtual cursor position and handles
+ scrolling as characters are added. This supports most of the methods needed
+ by an ANSI text screen. Row and column indexes are 1-based (not zero-based,
+ like arrays).
+
+ Characters are represented internally using unicode. Methods that accept
+ input characters, when passed 'bytes' (which in Python 2 is equivalent to
+ 'str'), convert them from the encoding specified in the 'encoding'
+ parameter to the constructor. Methods that return screen contents return
+ unicode strings, with the exception of __str__() under Python 2. Passing
+ ``encoding=None`` limits the API to only accept unicode input, so passing
+ bytes in will raise :exc:`TypeError`.
+ '''
+ def __init__(self, r=24, c=80, encoding='latin-1', encoding_errors='replace'):
+ '''This initializes a blank screen of the given dimensions.'''
+
+ self.rows = r
+ self.cols = c
+ self.encoding = encoding
+ self.encoding_errors = encoding_errors
+ if encoding is not None:
+ self.decoder = codecs.getincrementaldecoder(encoding)(encoding_errors)
+ else:
+ self.decoder = None
+ self.cur_r = 1
+ self.cur_c = 1
+ self.cur_saved_r = 1
+ self.cur_saved_c = 1
+ self.scroll_row_start = 1
+ self.scroll_row_end = self.rows
+ self.w = [ [SPACE] * self.cols for _ in range(self.rows)]
+
+ def _decode(self, s):
+ '''This converts from the external coding system (as passed to
+ the constructor) to the internal one (unicode). '''
+ if self.decoder is not None:
+ return self.decoder.decode(s)
+ else:
+ raise TypeError("This screen was constructed with encoding=None, "
+ "so it does not handle bytes.")
+
+ def _unicode(self):
+ '''This returns a printable representation of the screen as a unicode
+ string (which, under Python 3.x, is the same as 'str'). The end of each
+ screen line is terminated by a newline.'''
+
+ return u'\n'.join ([ u''.join(c) for c in self.w ])
+
+ if PY3:
+ __str__ = _unicode
+ else:
+ __unicode__ = _unicode
+
+ def __str__(self):
+ '''This returns a printable representation of the screen. The end of
+ each screen line is terminated by a newline. '''
+ encoding = self.encoding or 'ascii'
+ return self._unicode().encode(encoding, 'replace')
+
+ def dump (self):
+ '''This returns a copy of the screen as a unicode string. This is similar to
+ __str__/__unicode__ except that lines are not terminated with line
+ feeds.'''
+
+ return u''.join ([ u''.join(c) for c in self.w ])
+
+ def pretty (self):
+ '''This returns a copy of the screen as a unicode string with an ASCII
+ text box around the screen border. This is similar to
+ __str__/__unicode__ except that it adds a box.'''
+
+ top_bot = u'+' + u'-'*self.cols + u'+\n'
+ return top_bot + u'\n'.join([u'|'+line+u'|' for line in unicode(self).split(u'\n')]) + u'\n' + top_bot
+
+ def fill (self, ch=SPACE):
+
+ if isinstance(ch, bytes):
+ ch = self._decode(ch)
+
+ self.fill_region (1,1,self.rows,self.cols, ch)
+
+ def fill_region (self, rs,cs, re,ce, ch=SPACE):
+
+ if isinstance(ch, bytes):
+ ch = self._decode(ch)
+
+ rs = constrain (rs, 1, self.rows)
+ re = constrain (re, 1, self.rows)
+ cs = constrain (cs, 1, self.cols)
+ ce = constrain (ce, 1, self.cols)
+ if rs > re:
+ rs, re = re, rs
+ if cs > ce:
+ cs, ce = ce, cs
+ for r in range (rs, re+1):
+ for c in range (cs, ce + 1):
+ self.put_abs (r,c,ch)
+
+ def cr (self):
+ '''This moves the cursor to the beginning (col 1) of the current row.
+ '''
+
+ self.cursor_home (self.cur_r, 1)
+
+ def lf (self):
+ '''This moves the cursor down with scrolling.
+ '''
+
+ old_r = self.cur_r
+ self.cursor_down()
+ if old_r == self.cur_r:
+ self.scroll_up ()
+ self.erase_line()
+
+ def crlf (self):
+ '''This advances the cursor with CRLF properties.
+ The cursor will line wrap and the screen may scroll.
+ '''
+
+ self.cr ()
+ self.lf ()
+
+ def newline (self):
+ '''This is an alias for crlf().
+ '''
+
+ self.crlf()
+
+ def put_abs (self, r, c, ch):
+ '''Screen array starts at 1 index.'''
+
+ r = constrain (r, 1, self.rows)
+ c = constrain (c, 1, self.cols)
+ if isinstance(ch, bytes):
+ ch = self._decode(ch)[0]
+ else:
+ ch = ch[0]
+ self.w[r-1][c-1] = ch
+
+ def put (self, ch):
+        '''This puts a character at the current cursor position.
+ '''
+
+ if isinstance(ch, bytes):
+ ch = self._decode(ch)
+
+ self.put_abs (self.cur_r, self.cur_c, ch)
+
+ def insert_abs (self, r, c, ch):
+ '''This inserts a character at (r,c). Everything under
+ and to the right is shifted right one character.
+ The last character of the line is lost.
+ '''
+
+ if isinstance(ch, bytes):
+ ch = self._decode(ch)
+
+ r = constrain (r, 1, self.rows)
+ c = constrain (c, 1, self.cols)
+ for ci in range (self.cols, c, -1):
+ self.put_abs (r,ci, self.get_abs(r,ci-1))
+ self.put_abs (r,c,ch)
+
+ def insert (self, ch):
+
+ if isinstance(ch, bytes):
+ ch = self._decode(ch)
+
+ self.insert_abs (self.cur_r, self.cur_c, ch)
+
+ def get_abs (self, r, c):
+
+ r = constrain (r, 1, self.rows)
+ c = constrain (c, 1, self.cols)
+ return self.w[r-1][c-1]
+
+ def get (self):
+
+        return self.get_abs (self.cur_r, self.cur_c)
+
+ def get_region (self, rs,cs, re,ce):
+ '''This returns a list of lines representing the region.
+ '''
+
+ rs = constrain (rs, 1, self.rows)
+ re = constrain (re, 1, self.rows)
+ cs = constrain (cs, 1, self.cols)
+ ce = constrain (ce, 1, self.cols)
+ if rs > re:
+ rs, re = re, rs
+ if cs > ce:
+ cs, ce = ce, cs
+ sc = []
+ for r in range (rs, re+1):
+ line = u''
+ for c in range (cs, ce + 1):
+ ch = self.get_abs (r,c)
+ line = line + ch
+ sc.append (line)
+ return sc
+
+ def cursor_constrain (self):
+ '''This keeps the cursor within the screen area.
+ '''
+
+ self.cur_r = constrain (self.cur_r, 1, self.rows)
+ self.cur_c = constrain (self.cur_c, 1, self.cols)
+
+ def cursor_home (self, r=1, c=1): # <ESC>[{ROW};{COLUMN}H
+
+ self.cur_r = r
+ self.cur_c = c
+ self.cursor_constrain ()
+
+ def cursor_back (self,count=1): # <ESC>[{COUNT}D (not confused with down)
+
+ self.cur_c = self.cur_c - count
+ self.cursor_constrain ()
+
+ def cursor_down (self,count=1): # <ESC>[{COUNT}B (not confused with back)
+
+ self.cur_r = self.cur_r + count
+ self.cursor_constrain ()
+
+ def cursor_forward (self,count=1): # <ESC>[{COUNT}C
+
+ self.cur_c = self.cur_c + count
+ self.cursor_constrain ()
+
+ def cursor_up (self,count=1): # <ESC>[{COUNT}A
+
+ self.cur_r = self.cur_r - count
+ self.cursor_constrain ()
+
+ def cursor_up_reverse (self): # <ESC> M (called RI -- Reverse Index)
+
+ old_r = self.cur_r
+ self.cursor_up()
+ if old_r == self.cur_r:
+ self.scroll_up()
+
+ def cursor_force_position (self, r, c): # <ESC>[{ROW};{COLUMN}f
+ '''Identical to Cursor Home.'''
+
+ self.cursor_home (r, c)
+
+ def cursor_save (self): # <ESC>[s
+ '''Save current cursor position.'''
+
+ self.cursor_save_attrs()
+
+ def cursor_unsave (self): # <ESC>[u
+ '''Restores cursor position after a Save Cursor.'''
+
+ self.cursor_restore_attrs()
+
+ def cursor_save_attrs (self): # <ESC>7
+ '''Save current cursor position.'''
+
+ self.cur_saved_r = self.cur_r
+ self.cur_saved_c = self.cur_c
+
+ def cursor_restore_attrs (self): # <ESC>8
+ '''Restores cursor position after a Save Cursor.'''
+
+ self.cursor_home (self.cur_saved_r, self.cur_saved_c)
+
+ def scroll_constrain (self):
+ '''This keeps the scroll region within the screen region.'''
+
+ if self.scroll_row_start <= 0:
+ self.scroll_row_start = 1
+ if self.scroll_row_end > self.rows:
+ self.scroll_row_end = self.rows
+
+ def scroll_screen (self): # <ESC>[r
+ '''Enable scrolling for entire display.'''
+
+ self.scroll_row_start = 1
+ self.scroll_row_end = self.rows
+
+ def scroll_screen_rows (self, rs, re): # <ESC>[{start};{end}r
+ '''Enable scrolling from row {start} to row {end}.'''
+
+ self.scroll_row_start = rs
+ self.scroll_row_end = re
+ self.scroll_constrain()
+
+ def scroll_down (self): # <ESC>D
+ '''Scroll display down one line.'''
+
+ # Screen is indexed from 1, but arrays are indexed from 0.
+ s = self.scroll_row_start - 1
+ e = self.scroll_row_end - 1
+ self.w[s+1:e+1] = copy.deepcopy(self.w[s:e])
+
+ def scroll_up (self): # <ESC>M
+ '''Scroll display up one line.'''
+
+ # Screen is indexed from 1, but arrays are indexed from 0.
+ s = self.scroll_row_start - 1
+ e = self.scroll_row_end - 1
+ self.w[s:e] = copy.deepcopy(self.w[s+1:e+1])
+
+ def erase_end_of_line (self): # <ESC>[0K -or- <ESC>[K
+ '''Erases from the current cursor position to the end of the current
+ line.'''
+
+ self.fill_region (self.cur_r, self.cur_c, self.cur_r, self.cols)
+
+ def erase_start_of_line (self): # <ESC>[1K
+ '''Erases from the current cursor position to the start of the current
+ line.'''
+
+ self.fill_region (self.cur_r, 1, self.cur_r, self.cur_c)
+
+ def erase_line (self): # <ESC>[2K
+ '''Erases the entire current line.'''
+
+ self.fill_region (self.cur_r, 1, self.cur_r, self.cols)
+
+ def erase_down (self): # <ESC>[0J -or- <ESC>[J
+ '''Erases the screen from the current line down to the bottom of the
+ screen.'''
+
+ self.erase_end_of_line ()
+ self.fill_region (self.cur_r + 1, 1, self.rows, self.cols)
+
+ def erase_up (self): # <ESC>[1J
+ '''Erases the screen from the current line up to the top of the
+ screen.'''
+
+ self.erase_start_of_line ()
+ self.fill_region (self.cur_r-1, 1, 1, self.cols)
+
+ def erase_screen (self): # <ESC>[2J
+ '''Erases the screen with the background color.'''
+
+ self.fill ()
+
+ def set_tab (self): # <ESC>H
+ '''Sets a tab at the current position.'''
+
+ pass
+
+ def clear_tab (self): # <ESC>[g
+ '''Clears tab at the current position.'''
+
+ pass
+
+ def clear_all_tabs (self): # <ESC>[3g
+ '''Clears all tabs.'''
+
+ pass
+
+# Insert line Esc [ Pn L
+# Delete line Esc [ Pn M
+# Delete character Esc [ Pn P
+# Scrolling region Esc [ Pn(top);Pn(bot) r
+
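The module is deprecated in favor of pyte (as the warning above says), but a tiny sketch of the screen object may still help when reading the ANSI code built on top of it; the dimensions and text are arbitrary:

from pexpect import screen

# A 5-row by 20-column virtual screen; row/column coordinates are 1-based.
s = screen.screen(5, 20)

for ch in u'hello':
    s.put(ch)                # write at the cursor position...
    s.cursor_forward()       # ...put() does not advance the cursor itself

s.crlf()                     # carriage return + line feed, may scroll
s.insert_abs(2, 1, u'>')     # insert, shifting the rest of row 2 right

print(s.pretty())            # the screen contents inside an ASCII box
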
diff --git a/lib/pexpect/spawnbase.py b/lib/pexpect/spawnbase.py
new file mode 100644
index 0000000..59e9057
--- /dev/null
+++ b/lib/pexpect/spawnbase.py
@@ -0,0 +1,525 @@
+from io import StringIO, BytesIO
+import codecs
+import os
+import sys
+import re
+import errno
+from .exceptions import ExceptionPexpect, EOF, TIMEOUT
+from .expect import Expecter, searcher_string, searcher_re
+
+PY3 = (sys.version_info[0] >= 3)
+text_type = str if PY3 else unicode
+
+class _NullCoder(object):
+ """Pass bytes through unchanged."""
+ @staticmethod
+ def encode(b, final=False):
+ return b
+
+ @staticmethod
+ def decode(b, final=False):
+ return b
+
+class SpawnBase(object):
+ """A base class providing the backwards-compatible spawn API for Pexpect.
+
+ This should not be instantiated directly: use :class:`pexpect.spawn` or
+ :class:`pexpect.fdpexpect.fdspawn`.
+ """
+ encoding = None
+ pid = None
+ flag_eof = False
+
+ def __init__(self, timeout=30, maxread=2000, searchwindowsize=None,
+ logfile=None, encoding=None, codec_errors='strict'):
+ self.stdin = sys.stdin
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ self.searcher = None
+ self.ignorecase = False
+ self.before = None
+ self.after = None
+ self.match = None
+ self.match_index = None
+ self.terminated = True
+ self.exitstatus = None
+ self.signalstatus = None
+ # status returned by os.waitpid
+ self.status = None
+ # the child file descriptor is initially closed
+ self.child_fd = -1
+ self.timeout = timeout
+ self.delimiter = EOF
+ self.logfile = logfile
+ # input from child (read_nonblocking)
+ self.logfile_read = None
+ # output to send (send, sendline)
+ self.logfile_send = None
+ # max bytes to read at one time into buffer
+ self.maxread = maxread
+ # Data before searchwindowsize point is preserved, but not searched.
+ self.searchwindowsize = searchwindowsize
+ # Delay used before sending data to child. Time in seconds.
+ # Set this to None to skip the time.sleep() call completely.
+ self.delaybeforesend = 0.05
+ # Used by close() to give kernel time to update process status.
+ # Time in seconds.
+ self.delayafterclose = 0.1
+ # Used by terminate() to give kernel time to update process status.
+ # Time in seconds.
+ self.delayafterterminate = 0.1
+ # Delay in seconds to sleep after each call to read_nonblocking().
+ # Set this to None to skip the time.sleep() call completely: that
+ # would restore the behavior from pexpect-2.0 (for performance
+ # reasons or because you don't want to release Python's global
+ # interpreter lock).
+ self.delayafterread = 0.0001
+ self.softspace = False
+ self.name = '<' + repr(self) + '>'
+ self.closed = True
+
+ # Unicode interface
+ self.encoding = encoding
+ self.codec_errors = codec_errors
+ if encoding is None:
+ # bytes mode (accepts some unicode for backwards compatibility)
+ self._encoder = self._decoder = _NullCoder()
+ self.string_type = bytes
+ self.buffer_type = BytesIO
+ self.crlf = b'\r\n'
+ if PY3:
+ self.allowed_string_types = (bytes, str)
+ self.linesep = os.linesep.encode('ascii')
+ def write_to_stdout(b):
+ try:
+ return sys.stdout.buffer.write(b)
+ except AttributeError:
+ # If stdout has been replaced, it may not have .buffer
+ return sys.stdout.write(b.decode('ascii', 'replace'))
+ self.write_to_stdout = write_to_stdout
+ else:
+ self.allowed_string_types = (basestring,) # analysis:ignore
+ self.linesep = os.linesep
+ self.write_to_stdout = sys.stdout.write
+ else:
+ # unicode mode
+ self._encoder = codecs.getincrementalencoder(encoding)(codec_errors)
+ self._decoder = codecs.getincrementaldecoder(encoding)(codec_errors)
+ self.string_type = text_type
+ self.buffer_type = StringIO
+ self.crlf = u'\r\n'
+ self.allowed_string_types = (text_type, )
+ if PY3:
+ self.linesep = os.linesep
+ else:
+ self.linesep = os.linesep.decode('ascii')
+ # This can handle unicode in both Python 2 and 3
+ self.write_to_stdout = sys.stdout.write
+ # storage for async transport
+ self.async_pw_transport = None
+ # This is the read buffer. See maxread.
+ self._buffer = self.buffer_type()
+ # The buffer may be trimmed for efficiency reasons. This is the
+ # untrimmed buffer, used to create the before attribute.
+ self._before = self.buffer_type()
+
+ def _log(self, s, direction):
+ if self.logfile is not None:
+ self.logfile.write(s)
+ self.logfile.flush()
+ second_log = self.logfile_send if (direction=='send') else self.logfile_read
+ if second_log is not None:
+ second_log.write(s)
+ second_log.flush()
+
+ # For backwards compatibility, in bytes mode (when encoding is None)
+ # unicode is accepted for send and expect. Unicode mode is strictly unicode
+ # only.
+ def _coerce_expect_string(self, s):
+ if self.encoding is None and not isinstance(s, bytes):
+ return s.encode('ascii')
+ return s
+
+ def _coerce_send_string(self, s):
+ if self.encoding is None and not isinstance(s, bytes):
+ return s.encode('utf-8')
+ return s
+
+ def _get_buffer(self):
+ return self._buffer.getvalue()
+
+ def _set_buffer(self, value):
+ self._buffer = self.buffer_type()
+ self._buffer.write(value)
+
+    # This property is provided for backwards compatibility (self.buffer used
+ # to be a string/bytes object)
+ buffer = property(_get_buffer, _set_buffer)
+
+ def read_nonblocking(self, size=1, timeout=None):
+ """This reads data from the file descriptor.
+
+ This is a simple implementation suitable for a regular file. Subclasses using ptys or pipes should override it.
+
+ The timeout parameter is ignored.
+ """
+
+ try:
+ s = os.read(self.child_fd, size)
+ except OSError as err:
+ if err.args[0] == errno.EIO:
+ # Linux-style EOF
+ self.flag_eof = True
+ raise EOF('End Of File (EOF). Exception style platform.')
+ raise
+ if s == b'':
+ # BSD-style EOF
+ self.flag_eof = True
+ raise EOF('End Of File (EOF). Empty string style platform.')
+
+ s = self._decoder.decode(s, final=False)
+ self._log(s, 'read')
+ return s
+
+ def _pattern_type_err(self, pattern):
+ raise TypeError('got {badtype} ({badobj!r}) as pattern, must be one'
+ ' of: {goodtypes}, pexpect.EOF, pexpect.TIMEOUT'\
+ .format(badtype=type(pattern),
+ badobj=pattern,
+ goodtypes=', '.join([str(ast)\
+ for ast in self.allowed_string_types])
+ )
+ )
+
+ def compile_pattern_list(self, patterns):
+ '''This compiles a pattern-string or a list of pattern-strings.
+ Patterns must be a StringType, EOF, TIMEOUT, SRE_Pattern, or a list of
+ those. Patterns may also be None which results in an empty list (you
+ might do this if waiting for an EOF or TIMEOUT condition without
+ expecting any pattern).
+
+ This is used by expect() when calling expect_list(). Thus expect() is
+ nothing more than::
+
+ cpl = self.compile_pattern_list(pl)
+ return self.expect_list(cpl, timeout)
+
+ If you are using expect() within a loop it may be more
+ efficient to compile the patterns first and then call expect_list().
+        This avoids calls in a loop to compile_pattern_list()::
+
+ cpl = self.compile_pattern_list(my_pattern)
+ while some_condition:
+ ...
+ i = self.expect_list(cpl, timeout)
+ ...
+ '''
+
+ if patterns is None:
+ return []
+ if not isinstance(patterns, list):
+ patterns = [patterns]
+
+ # Allow dot to match \n
+ compile_flags = re.DOTALL
+ if self.ignorecase:
+ compile_flags = compile_flags | re.IGNORECASE
+ compiled_pattern_list = []
+ for idx, p in enumerate(patterns):
+ if isinstance(p, self.allowed_string_types):
+ p = self._coerce_expect_string(p)
+ compiled_pattern_list.append(re.compile(p, compile_flags))
+ elif p is EOF:
+ compiled_pattern_list.append(EOF)
+ elif p is TIMEOUT:
+ compiled_pattern_list.append(TIMEOUT)
+ elif isinstance(p, type(re.compile(''))):
+ compiled_pattern_list.append(p)
+ else:
+ self._pattern_type_err(p)
+ return compiled_pattern_list
+
+ def expect(self, pattern, timeout=-1, searchwindowsize=-1, async_=False, **kw):
+ '''This seeks through the stream until a pattern is matched. The
+ pattern is overloaded and may take several types. The pattern can be a
+ StringType, EOF, a compiled re, or a list of any of those types.
+ Strings will be compiled to re types. This returns the index into the
+ pattern list. If the pattern was not a list this returns index 0 on a
+ successful match. This may raise exceptions for EOF or TIMEOUT. To
+ avoid the EOF or TIMEOUT exceptions add EOF or TIMEOUT to the pattern
+ list. That will cause expect to match an EOF or TIMEOUT condition
+ instead of raising an exception.
+
+ If you pass a list of patterns and more than one matches, the first
+ match in the stream is chosen. If more than one pattern matches at that
+ point, the leftmost in the pattern list is chosen. For example::
+
+ # the input is 'foobar'
+ index = p.expect(['bar', 'foo', 'foobar'])
+ # returns 1('foo') even though 'foobar' is a "better" match
+
+ Please note, however, that buffering can affect this behavior, since
+ input arrives in unpredictable chunks. For example::
+
+ # the input is 'foobar'
+ index = p.expect(['foobar', 'foo'])
+ # returns 0('foobar') if all input is available at once,
+ # but returns 1('foo') if parts of the final 'bar' arrive late
+
+ When a match is found for the given pattern, the class instance
+ attribute *match* becomes an re.MatchObject result. Should an EOF
+ or TIMEOUT pattern match, then the match attribute will be an instance
+ of that exception class. The pairing before and after class
+ instance attributes are views of the data preceding and following
+ the matching pattern. On general exception, class attribute
+        *before* is all data received up to the exception, while the *match* and
+        *after* attributes are None.
+
+ When the keyword argument timeout is -1 (default), then TIMEOUT will
+ raise after the default value specified by the class timeout
+        attribute. When None, TIMEOUT will not be raised and expect() may block
+        indefinitely until a match is found.
+
+ When the keyword argument searchwindowsize is -1 (default), then the
+ value specified by the class maxread attribute is used.
+
+ A list entry may be EOF or TIMEOUT instead of a string. This will
+ catch these exceptions and return the index of the list entry instead
+ of raising the exception. The attribute 'after' will be set to the
+ exception type. The attribute 'match' will be None. This allows you to
+ write code like this::
+
+ index = p.expect(['good', 'bad', pexpect.EOF, pexpect.TIMEOUT])
+ if index == 0:
+ do_something()
+ elif index == 1:
+ do_something_else()
+ elif index == 2:
+ do_some_other_thing()
+ elif index == 3:
+ do_something_completely_different()
+
+ instead of code like this::
+
+ try:
+ index = p.expect(['good', 'bad'])
+ if index == 0:
+ do_something()
+ elif index == 1:
+ do_something_else()
+ except EOF:
+ do_some_other_thing()
+ except TIMEOUT:
+ do_something_completely_different()
+
+ These two forms are equivalent. It all depends on what you want. You
+ can also just expect the EOF if you are waiting for all output of a
+ child to finish. For example::
+
+ p = pexpect.spawn('/bin/ls')
+ p.expect(pexpect.EOF)
+ print p.before
+
+ If you are trying to optimize for speed then see expect_list().
+
+ On Python 3.4, or Python 3.3 with asyncio installed, passing
+ ``async_=True`` will make this return an :mod:`asyncio` coroutine,
+ which you can yield from to get the same result that this method would
+ normally give directly. So, inside a coroutine, you can replace this code::
+
+ index = p.expect(patterns)
+
+ With this non-blocking form::
+
+ index = yield from p.expect(patterns, async_=True)
+ '''
+ if 'async' in kw:
+ async_ = kw.pop('async')
+ if kw:
+ raise TypeError("Unknown keyword arguments: {}".format(kw))
+
+ compiled_pattern_list = self.compile_pattern_list(pattern)
+ return self.expect_list(compiled_pattern_list,
+ timeout, searchwindowsize, async_)
+
+ def expect_list(self, pattern_list, timeout=-1, searchwindowsize=-1,
+ async_=False, **kw):
+ '''This takes a list of compiled regular expressions and returns the
+ index into the pattern_list that matched the child output. The list may
+        also contain EOF or TIMEOUT (which are not compiled regular
+ expressions). This method is similar to the expect() method except that
+ expect_list() does not recompile the pattern list on every call. This
+ may help if you are trying to optimize for speed, otherwise just use
+ the expect() method. This is called by expect().
+
+
+ Like :meth:`expect`, passing ``async_=True`` will make this return an
+ asyncio coroutine.
+ '''
+ if timeout == -1:
+ timeout = self.timeout
+ if 'async' in kw:
+ async_ = kw.pop('async')
+ if kw:
+ raise TypeError("Unknown keyword arguments: {}".format(kw))
+
+ exp = Expecter(self, searcher_re(pattern_list), searchwindowsize)
+ if async_:
+ from ._async import expect_async
+ return expect_async(exp, timeout)
+ else:
+ return exp.expect_loop(timeout)
+
+ def expect_exact(self, pattern_list, timeout=-1, searchwindowsize=-1,
+ async_=False, **kw):
+
+ '''This is similar to expect(), but uses plain string matching instead
+ of compiled regular expressions in 'pattern_list'. The 'pattern_list'
+ may be a string; a list or other sequence of strings; or TIMEOUT and
+ EOF.
+
+ This call might be faster than expect() for two reasons: string
+ searching is faster than RE matching and it is possible to limit the
+ search to just the end of the input buffer.
+
+ This method is also useful when you don't want to have to worry about
+ escaping regular expression characters that you want to match.
+
+ Like :meth:`expect`, passing ``async_=True`` will make this return an
+ asyncio coroutine.
+ '''
+ if timeout == -1:
+ timeout = self.timeout
+ if 'async' in kw:
+ async_ = kw.pop('async')
+ if kw:
+ raise TypeError("Unknown keyword arguments: {}".format(kw))
+
+ if (isinstance(pattern_list, self.allowed_string_types) or
+ pattern_list in (TIMEOUT, EOF)):
+ pattern_list = [pattern_list]
+
+ def prepare_pattern(pattern):
+ if pattern in (TIMEOUT, EOF):
+ return pattern
+ if isinstance(pattern, self.allowed_string_types):
+ return self._coerce_expect_string(pattern)
+ self._pattern_type_err(pattern)
+
+ try:
+ pattern_list = iter(pattern_list)
+ except TypeError:
+ self._pattern_type_err(pattern_list)
+ pattern_list = [prepare_pattern(p) for p in pattern_list]
+
+ exp = Expecter(self, searcher_string(pattern_list), searchwindowsize)
+ if async_:
+ from ._async import expect_async
+ return expect_async(exp, timeout)
+ else:
+ return exp.expect_loop(timeout)
+
+ def expect_loop(self, searcher, timeout=-1, searchwindowsize=-1):
+ '''This is the common loop used inside expect. The 'searcher' should be
+ an instance of searcher_re or searcher_string, which describes how and
+ what to search for in the input.
+
+ See expect() for other arguments, return value and exceptions. '''
+
+ exp = Expecter(self, searcher, searchwindowsize)
+ return exp.expect_loop(timeout)
+
+ def read(self, size=-1):
+ '''This reads at most "size" bytes from the file (less if the read hits
+ EOF before obtaining size bytes). If the size argument is negative or
+ omitted, read all data until EOF is reached. The bytes are returned as
+ a string object. An empty string is returned when EOF is encountered
+ immediately. '''
+
+ if size == 0:
+ return self.string_type()
+ if size < 0:
+ # delimiter default is EOF
+ self.expect(self.delimiter)
+ return self.before
+
+ # I could have done this more directly by not using expect(), but
+ # I deliberately decided to couple read() to expect() so that
+ # I would catch any bugs early and ensure consistent behavior.
+ # It's a little less efficient, but there is less for me to
+ # worry about if I have to later modify read() or expect().
+ # Note, it's OK if size==-1 in the regex. That just means it
+ # will never match anything in which case we stop only on EOF.
+ cre = re.compile(self._coerce_expect_string('.{%d}' % size), re.DOTALL)
+ # delimiter default is EOF
+ index = self.expect([cre, self.delimiter])
+ if index == 0:
+ ### FIXME self.before should be ''. Should I assert this?
+ return self.after
+ return self.before
+
+ def readline(self, size=-1):
+ '''This reads and returns one entire line. The newline at the end of
+ line is returned as part of the string, unless the file ends without a
+ newline. An empty string is returned if EOF is encountered immediately.
+ This looks for a newline as a CR/LF pair (\\r\\n) even on UNIX because
+ this is what the pseudotty device returns. So contrary to what you may
+ expect you will receive newlines as \\r\\n.
+
+ If the size argument is 0 then an empty string is returned. In all
+ other cases the size argument is ignored, which is not standard
+ behavior for a file-like object. '''
+
+ if size == 0:
+ return self.string_type()
+ # delimiter default is EOF
+ index = self.expect([self.crlf, self.delimiter])
+ if index == 0:
+ return self.before + self.crlf
+ else:
+ return self.before
+
+ def __iter__(self):
+ '''This is to support iterators over a file-like object.
+ '''
+ return iter(self.readline, self.string_type())
+
+ def readlines(self, sizehint=-1):
+ '''This reads until EOF using readline() and returns a list containing
+ the lines thus read. The optional 'sizehint' argument is ignored.
+ Remember, because this reads until EOF that means the child
+ process should have closed its stdout. If you run this method on
+ a child that is still running with its stdout open then this
+ method will block until it times out.'''
+
+ lines = []
+ while True:
+ line = self.readline()
+ if not line:
+ break
+ lines.append(line)
+ return lines
+
+ def fileno(self):
+ '''Expose file descriptor for a file-like interface
+ '''
+ return self.child_fd
+
+ def flush(self):
+ '''This does nothing. It is here to support the interface for a
+ File-like object. '''
+ pass
+
+ def isatty(self):
+ """Overridden in subclass using tty"""
+ return False
+
+ # For 'with spawn(...) as child:'
+ def __enter__(self):
+ return self
+
+ def __exit__(self, etype, evalue, tb):
+ # We rely on subclasses to implement close(). If they don't, it's not
+ # clear what a context manager should do.
+ self.close()
diff --git a/lib/pexpect/utils.py b/lib/pexpect/utils.py
new file mode 100644
index 0000000..f774519
--- /dev/null
+++ b/lib/pexpect/utils.py
@@ -0,0 +1,187 @@
+import os
+import sys
+import stat
+import select
+import time
+import errno
+
+try:
+ InterruptedError
+except NameError:
+ # Alias Python2 exception to Python3
+ InterruptedError = select.error
+
+if sys.version_info[0] >= 3:
+ string_types = (str,)
+else:
+ string_types = (unicode, str)
+
+
+def is_executable_file(path):
+ """Checks that path is an executable regular file, or a symlink towards one.
+
+ This is roughly ``os.path.isfile(path) and os.access(path, os.X_OK)``.
+ """
+ # follow symlinks,
+ fpath = os.path.realpath(path)
+
+ if not os.path.isfile(fpath):
+ # non-files (directories, fifo, etc.)
+ return False
+
+ mode = os.stat(fpath).st_mode
+
+ if (sys.platform.startswith('sunos')
+ and os.getuid() == 0):
+ # When root on Solaris, os.X_OK is True for *all* files, regardless
+ # of their executability -- instead, any permission bit of any user,
+ # group, or other is fine enough.
+ #
+ # (This may be true for other "Unix98" OS's such as HP-UX and AIX)
+ return bool(mode & (stat.S_IXUSR |
+ stat.S_IXGRP |
+ stat.S_IXOTH))
+
+ return os.access(fpath, os.X_OK)
+
+
+def which(filename, env=None):
+ '''This takes a given filename; tries to find it in the environment path;
+ then checks if it is executable. This returns the full path to the filename
+ if found and executable. Otherwise this returns None.'''
+
+ # Special case where filename contains an explicit path.
+ if os.path.dirname(filename) != '' and is_executable_file(filename):
+ return filename
+ if env is None:
+ env = os.environ
+ p = env.get('PATH')
+ if not p:
+ p = os.defpath
+ pathlist = p.split(os.pathsep)
+ for path in pathlist:
+ ff = os.path.join(path, filename)
+ if is_executable_file(ff):
+ return ff
+ return None
+
+
+def split_command_line(command_line):
+
+ '''This splits a command line into a list of arguments. It splits arguments
+ on spaces, but handles embedded quotes, doublequotes, and escaped
+ characters. It's hard to do this robustly with a regular expression, so I
+ wrote a little state machine to parse the command line. '''
+
+ arg_list = []
+ arg = ''
+
+ # Constants to name the states we can be in.
+ state_basic = 0
+ state_esc = 1
+ state_singlequote = 2
+ state_doublequote = 3
+ # The state when consuming whitespace between commands.
+ state_whitespace = 4
+ state = state_basic
+
+ for c in command_line:
+ if state == state_basic or state == state_whitespace:
+ if c == '\\':
+ # Escape the next character
+ state = state_esc
+ elif c == r"'":
+ # Handle single quote
+ state = state_singlequote
+ elif c == r'"':
+ # Handle double quote
+ state = state_doublequote
+ elif c.isspace():
+ # Add arg to arg_list if we aren't in the middle of whitespace.
+ if state == state_whitespace:
+ # Already consuming whitespace; nothing to append yet.
+ pass
+ else:
+ arg_list.append(arg)
+ arg = ''
+ state = state_whitespace
+ else:
+ arg = arg + c
+ state = state_basic
+ elif state == state_esc:
+ arg = arg + c
+ state = state_basic
+ elif state == state_singlequote:
+ if c == r"'":
+ state = state_basic
+ else:
+ arg = arg + c
+ elif state == state_doublequote:
+ if c == r'"':
+ state = state_basic
+ else:
+ arg = arg + c
+
+ if arg != '':
+ arg_list.append(arg)
+ return arg_list
+
+
+def select_ignore_interrupts(iwtd, owtd, ewtd, timeout=None):
+
+ '''This is a wrapper around select.select() that retries when interrupted
+ by a signal. If select.select raises a select.error exception whose errno
+ is EINTR, the call is retried. Mainly this is used to ignore SIGWINCH
+ (terminal resize). '''
+
+ # if select() is interrupted by a signal (errno==EINTR) then
+ # we loop back and enter the select() again.
+ if timeout is not None:
+ end_time = time.time() + timeout
+ while True:
+ try:
+ return select.select(iwtd, owtd, ewtd, timeout)
+ except InterruptedError:
+ err = sys.exc_info()[1]
+ if err.args[0] == errno.EINTR:
+ # if we loop back we have to subtract the
+ # amount of time we already waited.
+ if timeout is not None:
+ timeout = end_time - time.time()
+ if timeout < 0:
+ return([], [], [])
+ else:
+ # something else caused the select.error, so
+ # this actually is an exception.
+ raise
+
+
+def poll_ignore_interrupts(fds, timeout=None):
+ '''Simple wrapper around poll to register file descriptors and
+ ignore signals.'''
+
+ if timeout is not None:
+ end_time = time.time() + timeout
+
+ poller = select.poll()
+ for fd in fds:
+ poller.register(fd, select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR)
+
+ while True:
+ try:
+ timeout_ms = None if timeout is None else timeout * 1000
+ results = poller.poll(timeout_ms)
+ return [afd for afd, _ in results]
+ except InterruptedError:
+ err = sys.exc_info()[1]
+ if err.args[0] == errno.EINTR:
+ # if we loop back we have to subtract the
+ # amount of time we already waited.
+ if timeout is not None:
+ timeout = end_time - time.time()
+ if timeout < 0:
+ return []
+ else:
+ # something else caused the select.error, so
+ # this actually is an exception.
+ raise
diff --git a/lib/prettytable-3.6.0.dist-info/INSTALLER b/lib/prettytable-3.6.0.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/prettytable-3.6.0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/prettytable-3.6.0.dist-info/METADATA b/lib/prettytable-3.6.0.dist-info/METADATA
new file mode 100644
index 0000000..5971cb6
--- /dev/null
+++ b/lib/prettytable-3.6.0.dist-info/METADATA
@@ -0,0 +1,702 @@
+Metadata-Version: 2.1
+Name: prettytable
+Version: 3.6.0
+Summary: A simple Python library for easily displaying tabular data in a visually appealing ASCII table format
+Project-URL: Homepage, https://github.com/jazzband/prettytable
+Project-URL: Source, https://github.com/jazzband/prettytable
+Author-email: Luke Maurits <luke@maurits.id.au>
+Maintainer: Jazzband
+License: BSD (3 clause)
+License-File: COPYING
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Topic :: Text Processing
+Classifier: Typing :: Typed
+Requires-Python: >=3.7
+Requires-Dist: importlib-metadata; python_version < '3.8'
+Requires-Dist: wcwidth
+Provides-Extra: tests
+Requires-Dist: pytest; extra == 'tests'
+Requires-Dist: pytest-cov; extra == 'tests'
+Requires-Dist: pytest-lazy-fixture; extra == 'tests'
+Description-Content-Type: text/markdown
+
+# PrettyTable
+
+[![Jazzband](https://jazzband.co/static/img/badge.svg)](https://jazzband.co/)
+[![PyPI version](https://img.shields.io/pypi/v/prettytable.svg?logo=pypi&logoColor=FFE873)](https://pypi.org/project/prettytable/)
+[![Supported Python versions](https://img.shields.io/pypi/pyversions/prettytable.svg?logo=python&logoColor=FFE873)](https://pypi.org/project/prettytable/)
+[![PyPI downloads](https://img.shields.io/pypi/dm/prettytable.svg)](https://pypistats.org/packages/prettytable)
+[![GitHub Actions status](https://github.com/jazzband/prettytable/workflows/Test/badge.svg)](https://github.com/jazzband/prettytable/actions)
+[![codecov](https://codecov.io/gh/jazzband/prettytable/branch/master/graph/badge.svg)](https://codecov.io/gh/jazzband/prettytable)
+[![Code style: Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
+PrettyTable lets you print tables in an attractive ASCII form:
+
+```
++-----------+------+------------+-----------------+
+| City name | Area | Population | Annual Rainfall |
++-----------+------+------------+-----------------+
+| Adelaide | 1295 | 1158259 | 600.5 |
+| Brisbane | 5905 | 1857594 | 1146.4 |
+| Darwin | 112 | 120900 | 1714.7 |
+| Hobart | 1357 | 205556 | 619.5 |
+| Melbourne | 1566 | 3806092 | 646.9 |
+| Perth | 5386 | 1554769 | 869.4 |
+| Sydney | 2058 | 4336374 | 1214.8 |
++-----------+------+------------+-----------------+
+```
+
+## Installation
+
+Install via pip:
+
+ python -m pip install -U prettytable
+
+Install latest development version:
+
+ python -m pip install -U git+https://github.com/jazzband/prettytable
+
+Or from `requirements.txt`:
+
+ -e git://github.com/jazzband/prettytable.git#egg=prettytable
+
+## Tutorial on how to use the PrettyTable API
+
+### Getting your data into (and out of) the table
+
+Let's suppose you have a shiny new PrettyTable:
+
+```python
+from prettytable import PrettyTable
+x = PrettyTable()
+```
+
+and you want to put some data into it. You have a few options.
+
+#### Row by row
+
+You can add data one row at a time. To do this you can set the field names first using
+the `field_names` attribute, and then add the rows one at a time using the `add_row`
+method:
+
+```python
+x.field_names = ["City name", "Area", "Population", "Annual Rainfall"]
+x.add_row(["Adelaide", 1295, 1158259, 600.5])
+x.add_row(["Brisbane", 5905, 1857594, 1146.4])
+x.add_row(["Darwin", 112, 120900, 1714.7])
+x.add_row(["Hobart", 1357, 205556, 619.5])
+x.add_row(["Sydney", 2058, 4336374, 1214.8])
+x.add_row(["Melbourne", 1566, 3806092, 646.9])
+x.add_row(["Perth", 5386, 1554769, 869.4])
+```
+
+#### All rows at once
+
+When you have a list of rows, you can add them in one go with `add_rows`:
+
+```python
+x.field_names = ["City name", "Area", "Population", "Annual Rainfall"]
+x.add_rows(
+ [
+ ["Adelaide", 1295, 1158259, 600.5],
+ ["Brisbane", 5905, 1857594, 1146.4],
+ ["Darwin", 112, 120900, 1714.7],
+ ["Hobart", 1357, 205556, 619.5],
+ ["Sydney", 2058, 4336374, 1214.8],
+ ["Melbourne", 1566, 3806092, 646.9],
+ ["Perth", 5386, 1554769, 869.4],
+ ]
+)
+```
+
+#### Column by column
+
+You can add data one column at a time as well. To do this you use the `add_column`
+method, which takes two arguments - a string which is the name for the field the column
+you are adding corresponds to, and a list or tuple which contains the column data:
+
+```python
+x.add_column("City name",
+["Adelaide","Brisbane","Darwin","Hobart","Sydney","Melbourne","Perth"])
+x.add_column("Area", [1295, 5905, 112, 1357, 2058, 1566, 5386])
+x.add_column("Population", [1158259, 1857594, 120900, 205556, 4336374, 3806092,
+1554769])
+x.add_column("Annual Rainfall",[600.5, 1146.4, 1714.7, 619.5, 1214.8, 646.9,
+869.4])
+```
+
+#### Mixing and matching
+
+If you really want to, you can even mix and match `add_row` and `add_column` and build
+some of your table in one way and some of it in the other. Tables built this way are
+kind of confusing for other people to read, though, so don't do this unless you have a
+good reason.
+
+#### Importing data from a CSV file
+
+If you have your table data in a comma-separated values file (.csv), you can read this
+data into a PrettyTable like this:
+
+```python
+from prettytable import from_csv
+with open("myfile.csv") as fp:
+ mytable = from_csv(fp)
+```
+
+#### Importing data from a database cursor
+
+If you have your table data in a database which you can access using a library which
+conforms to the Python DB-API (e.g. an SQLite database accessible using the `sqlite3`
+module), then you can build a PrettyTable using a cursor object, like this:
+
+```python
+import sqlite3
+from prettytable import from_db_cursor
+
+connection = sqlite3.connect("mydb.db")
+cursor = connection.cursor()
+cursor.execute("SELECT field1, field2, field3 FROM my_table")
+mytable = from_db_cursor(cursor)
+```
+
+#### Getting data out
+
+There are four ways to delete data from a PrettyTable, in increasing order of
+completeness (a short example follows this list):
+
+- The `del_row` method takes an integer index of a single row to delete.
+- The `del_column` method takes a field name of a single column to delete.
+- The `clear_rows` method takes no arguments and deletes all the rows in the table - but
+ keeps the field names as they were so that you can repopulate it with the same
+ kind of data.
+- The `clear` method takes no arguments and deletes all rows and all field names. It's
+ not quite the same as creating a fresh table instance, though - style related
+ settings, discussed later, are maintained.
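+
+As a rough sketch of how these fit together (the row index and field name below are
+just examples from the table built earlier):
+
+```python
+x.del_row(0)            # delete the first row
+x.del_column("Area")    # delete the "Area" column
+x.clear_rows()          # drop all rows, keep field names
+x.clear()               # drop rows and field names, keep style settings
+```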
+
+### Displaying your table in ASCII form
+
+PrettyTable's main goal is to let you print tables in an attractive ASCII form, like
+this:
+
+```
++-----------+------+------------+-----------------+
+| City name | Area | Population | Annual Rainfall |
++-----------+------+------------+-----------------+
+| Adelaide | 1295 | 1158259 | 600.5 |
+| Brisbane | 5905 | 1857594 | 1146.4 |
+| Darwin | 112 | 120900 | 1714.7 |
+| Hobart | 1357 | 205556 | 619.5 |
+| Melbourne | 1566 | 3806092 | 646.9 |
+| Perth | 5386 | 1554769 | 869.4 |
+| Sydney | 2058 | 4336374 | 1214.8 |
++-----------+------+------------+-----------------+
+```
+
+You can print tables like this to `stdout` or get string representations of them.
+
+#### Printing
+
+To print a table in ASCII form, you can just do this:
+
+```python
+print(x)
+```
+
+The old `x.printt()` method from versions 0.5 and earlier has been removed.
+
+To pass options changing the look of the table, use the `get_string()` method documented
+below:
+
+```python
+print(x.get_string())
+```
+
+#### Stringing
+
+If you don't want to actually print your table in ASCII form but just get a string
+containing what _would_ be printed if you use `print(x)`, you can use the `get_string`
+method:
+
+```python
+mystring = x.get_string()
+```
+
+This string is guaranteed to look exactly the same as what would be printed by doing
+`print(x)`. You can now do all the usual things you can do with a string, like write
+your table to a file or insert it into a GUI.
+
+#### Controlling which data gets displayed
+
+If you like, you can restrict the output of `print(x)` or `x.get_string` to only the
+fields or rows you like.
+
+The `fields` argument to these methods takes a list of field names to be printed:
+
+```python
+print(x.get_string(fields=["City name", "Population"]))
+```
+
+gives:
+
+```
++-----------+------------+
+| City name | Population |
++-----------+------------+
+| Adelaide | 1158259 |
+| Brisbane | 1857594 |
+| Darwin | 120900 |
+| Hobart | 205556 |
+| Melbourne | 3806092 |
+| Perth | 1554769 |
+| Sydney | 4336374 |
++-----------+------------+
+```
+
+The `start` and `end` arguments take the index of the first and last row to print
+respectively. Note that the indexing works like Python list slicing - to print the 2nd,
+3rd and 4th rows of the table, set `start` to 1 (the first row is row 0, so the second
+is row 1) and set `end` to 4 (the index of the 4th row, plus 1):
+
+```python
+print(x.get_string(start=1, end=4))
+```
+
+prints:
+
+```
++-----------+------+------------+-----------------+
+| City name | Area | Population | Annual Rainfall |
++-----------+------+------------+-----------------+
+| Brisbane | 5905 | 1857594 | 1146.4 |
+| Darwin | 112 | 120900 | 1714.7 |
+| Hobart | 1357 | 205556 | 619.5 |
++-----------+------+------------+-----------------+
+```
+
+#### Changing the alignment of columns
+
+By default, all columns in a table are centre aligned.
+
+##### All columns at once
+
+You can change the alignment of all the columns in a table at once by assigning a one
+character string to the `align` attribute. The allowed strings are `"l"`, `"r"` and
+`"c"` for left, right and centre alignment, respectively:
+
+```python
+x.align = "r"
+print(x)
+```
+
+gives:
+
+```
++-----------+------+------------+-----------------+
+| City name | Area | Population | Annual Rainfall |
++-----------+------+------------+-----------------+
+| Adelaide | 1295 | 1158259 | 600.5 |
+| Brisbane | 5905 | 1857594 | 1146.4 |
+| Darwin | 112 | 120900 | 1714.7 |
+| Hobart | 1357 | 205556 | 619.5 |
+| Melbourne | 1566 | 3806092 | 646.9 |
+| Perth | 5386 | 1554769 | 869.4 |
+| Sydney | 2058 | 4336374 | 1214.8 |
++-----------+------+------------+-----------------+
+```
+
+##### One column at a time
+
+You can also change the alignment of individual columns based on the corresponding field
+name by treating the `align` attribute as if it were a dictionary.
+
+```python
+x.align["City name"] = "l"
+x.align["Area"] = "c"
+x.align["Population"] = "r"
+x.align["Annual Rainfall"] = "c"
+print(x)
+```
+
+gives:
+
+```
++-----------+------+------------+-----------------+
+| City name | Area | Population | Annual Rainfall |
++-----------+------+------------+-----------------+
+| Adelaide | 1295 | 1158259 | 600.5 |
+| Brisbane | 5905 | 1857594 | 1146.4 |
+| Darwin | 112 | 120900 | 1714.7 |
+| Hobart | 1357 | 205556 | 619.5 |
+| Melbourne | 1566 | 3806092 | 646.9 |
+| Perth | 5386 | 1554769 | 869.4 |
+| Sydney | 2058 | 4336374 | 1214.8 |
++-----------+------+------------+-----------------+
+```
+
+##### Sorting your table by a field
+
+You can make sure that your ASCII tables are produced with the data sorted by one
+particular field by giving `get_string` a `sortby` keyword argument, which must be a
+string containing the name of one field.
+
+For example, to print the example table we built earlier of Australian capital city
+data, so that the most populated city comes last, we can do this:
+
+```python
+print(x.get_string(sortby="Population"))
+```
+
+to get:
+
+```
++-----------+------+------------+-----------------+
+| City name | Area | Population | Annual Rainfall |
++-----------+------+------------+-----------------+
+| Darwin | 112 | 120900 | 1714.7 |
+| Hobart | 1357 | 205556 | 619.5 |
+| Adelaide | 1295 | 1158259 | 600.5 |
+| Perth | 5386 | 1554769 | 869.4 |
+| Brisbane | 5905 | 1857594 | 1146.4 |
+| Melbourne | 1566 | 3806092 | 646.9 |
+| Sydney | 2058 | 4336374 | 1214.8 |
++-----------+------+------------+-----------------+
+```
+
+If we want the most populated city to come _first_, we can also give a
+`reversesort=True` argument.
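+
+For example, one way to combine the two keyword arguments described above:
+
+```python
+print(x.get_string(sortby="Population", reversesort=True))
+```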
+
+If you _always_ want your tables to be sorted in a certain way, you can make the setting
+long-term like this:
+
+```python
+x.sortby = "Population"
+print(x)
+print(x)
+print(x)
+```
+
+All three tables printed by this code will be sorted by population (you could do
+`x.reversesort = True` as well, if you wanted). The behaviour will persist until you
+turn it off:
+
+```python
+x.sortby = None
+```
+
+If you want to specify a custom sorting function, you can use the `sort_key` keyword
+argument. Pass this a function which accepts two lists of values and returns a negative
+or positive value depending on whether the first list should appear before or after the
+second one. If your table has n columns, each list will have n+1 elements. Each list
+corresponds to one row of the table. The first element will be whatever data is in the
+relevant row, in the column specified by the `sortby` argument. The remaining n
+elements are the data in each of the table's columns, in order, including a repeated
+instance of the data in the `sortby` column.
+
+### Changing the appearance of your table - the easy way
+
+By default, PrettyTable produces ASCII tables that look like the ones used in SQL
+database shells. But it can print them in a variety of other formats as well. If the
+format you want to use is common, PrettyTable makes this easy for you to do using the
+`set_style` method. If you want to produce an uncommon table, you'll have to do things
+the slightly harder way (see later).
+
+#### Setting a table style
+
+You can set the style for your table using the `set_style` method before any calls to
+`print` or `get_string`. Here's how to print a table in a format which works nicely with
+Microsoft Word's "Convert to table" feature:
+
+```python
+from prettytable import MSWORD_FRIENDLY
+x.set_style(MSWORD_FRIENDLY)
+print(x)
+```
+
+In addition to `MSWORD_FRIENDLY` you can use these in-built styles for your tables:
+
+- `DEFAULT` - The default look, used to undo any style changes you may have made
+- `PLAIN_COLUMNS` - A borderless style that works well with command line programs for
+ columnar data
+- `MARKDOWN` - A style that follows Markdown syntax
+- `ORGMODE` - A table style that fits [Org mode](https://orgmode.org/) syntax
+- `SINGLE_BORDER` and `DOUBLE_BORDER` - Styles that use continuous single/double border
+ lines with box-drawing characters for a fancier display in the terminal
+
+Other styles are likely to appear in future releases.
+
+### Changing the appearance of your table - the hard way
+
+If you want to display your table in a style other than one of the in-built styles
+listed above, you'll have to set things up the hard way.
+
+Don't worry, it's not really that hard!
+
+#### Style options
+
+PrettyTable has a number of style options which control various aspects of how tables
+are displayed. You have the freedom to set each of these options individually to
+whatever you prefer. The `set_style` method just does this automatically for you.
+
+The options are these:
+
+- `border` - A boolean option (must be `True` or `False`). Controls whether a border is
+ drawn inside and around the table.
+- `preserve_internal_border` - A boolean option (must be `True` or `False`). Controls
+ whether borders are still drawn within the table even when `border=False`.
+- `header` - A boolean option (must be `True` or `False`). Controls whether the first
+ row of the table is a header showing the names of all the fields.
+- `hrules` - Controls printing of horizontal rules after rows. Allowed values: `FRAME`,
+ `HEADER`, `ALL`, `NONE` - note that these are variables defined inside the
+ `prettytable` module so make sure you import them or use `prettytable.FRAME` etc.
+- `vrules` - Controls printing of vertical rules between columns. Allowed values:
+ `FRAME`, `ALL`, `NONE`.
+- `int_format` - A string which controls the way integer data is printed. This works
+ like: `print("%<int_format>d" % data)`
+- `float_format` - A string which controls the way floating point data is printed. This
+ works like: `print("%<float_format>f" % data)`
+- `custom_format` - A dictionary mapping field names to callables. This allows you to
+  set any format you want, e.g. `pf.custom_format["my_col_int"] = lambda f, v: f"{v:,}"`.
+  The type of the callable is `Callable[[str, Any], str]`
+- `padding_width` - Number of spaces on either side of column data (only used if left
+ and right paddings are `None`).
+- `left_padding_width` - Number of spaces on left-hand side of column data.
+- `right_padding_width` - Number of spaces on right-hand side of column data.
+- `vertical_char` - Single character string used to draw vertical lines. Default is `|`.
+- `horizontal_char` - Single character string used to draw horizontal lines. Default is
+ `-`.
+- `horizontal_align_char` - single character string used to indicate column alignment
+ in horizontal lines. Default is `:` for Markdown, otherwise `None`.
+- `junction_char` - Single character string used to draw line junctions. Default is `+`.
+- `top_junction_char` - single character string used to draw top line junctions. Default
+ is `junction_char`.
+- `bottom_junction_char` - single character string used to draw bottom line junctions.
+ Default is `junction_char`.
+- `right_junction_char` - single character string used to draw right line junctions.
+ Default is `junction_char`.
+- `left_junction_char` - single character string used to draw left line junctions.
+ Default is `junction_char`.
+- `top_right_junction_char` - single character string used to draw top-right line
+ junctions. Default is `junction_char`.
+- `top_left_junction_char` - single character string used to draw top-left line
+ junctions. Default is `junction_char`.
+- `bottom_right_junction_char` - single character string used to draw bottom-right line
+ junctions. Default is `junction_char`
+- `bottom_left_junction_char` - single character string used to draw bottom-left line
+ junctions. Default is `junction_char`.
+
+You can set the style options to your own settings in two ways:
+
+#### Setting style options for the long term
+
+If you want to print your table with a different style several times, you can set your
+option for the long term just by changing the appropriate attributes. If you never want
+your tables to have borders you can do this:
+
+```python
+x.border = False
+print(x)
+print(x)
+print(x)
+```
+
+None of the three tables printed by this will have borders, even if you do things like
+add extra rows in between them. The lack of borders will last until you do:
+
+```python
+x.border = True
+```
+
+to turn them on again. This sort of long-term setting is exactly how `set_style` works.
+`set_style` just sets a bunch of attributes to pre-set values for you.
+
+Note that if you know what style options you want at the moment you are creating your
+table, you can specify them using keyword arguments to the constructor. For example, the
+following two code blocks are equivalent:
+
+```python
+x = PrettyTable()
+x.border = False
+x.header = False
+x.padding_width = 5
+
+x = PrettyTable(border=False, header=False, padding_width=5)
+```
+
+#### Changing style options just once
+
+If you don't want to make long-term style changes by changing an attribute like in the
+previous section, you can make changes that last for just one `get_string` by giving
+those methods keyword arguments. To print two "normal" tables with one borderless table
+between them, you could do this:
+
+```python
+print(x)
+print(x.get_string(border=False))
+print(x)
+```
+
+### Changing the appearance of your table - with _colors_!
+
+PrettyTable has the functionality of printing your table with ANSI color codes. This
+includes support for most Windows versions through
+[Colorama](https://pypi.org/project/colorama/). To get started, import the `ColorTable`
+class instead of `PrettyTable`.
+
+```diff
+-from prettytable import PrettyTable
++from prettytable.colortable import ColorTable
+```
+
+The `ColorTable` class can be used the same as `PrettyTable`, but it adds an extra
+property. You can now specify a custom _theme_ that will format your table with colors.
+
+```python
+from prettytable.colortable import ColorTable, Themes
+
+x = ColorTable(theme=Themes.OCEAN)
+
+print(x)
+```
+
+#### Creating a custom theme
+
+The `Theme` class allows you to customize both the characters and colors used in your
+table.
+
+| Argument | Description |
+| ---------------------------------------------------------- | --------------------------------------------------------- |
+| `default_color` | The color to use as default |
+| `vertical_char`, `horizontal_char`, and `junction_char` | The characters used for creating the outline of the table |
+| `vertical_color`, `horizontal_color`, and `junction_color` | The colors used to style each character. |
+
+> **Note:** Colors are formatted with the `Theme.format_code(s: str)` function. It
+> accepts a string. If the string starts with an escape code (like `\x1b`) then it will
+> return the given string. If the string is just whitespace, it will return `""`. If the
+> string is a number (like `"34"`), it will automatically format it into an escape code.
+> I recommend you look into the source code for more information.
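+
+As a rough sketch, here is a hypothetical custom theme (the ANSI color codes shown are
+just examples):
+
+```python
+from prettytable.colortable import ColorTable, Theme
+
+MY_THEME = Theme(
+    default_color="35",    # magenta text
+    vertical_color="35",
+    horizontal_color="35",
+    junction_color="35",
+)
+
+x = ColorTable(theme=MY_THEME)
+print(x)
+```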
+
+### Displaying your table in JSON
+
+PrettyTable will also print your tables in JSON, as a list of fields and an array of
+rows. Just like in ASCII form, you can actually get a string representation - just use
+`get_json_string()`.
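+
+For example, to capture the JSON output as a string:
+
+```python
+json_string = x.get_json_string()
+print(json_string)
+```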
+
+### Displaying your table in HTML form
+
+PrettyTable will also print your tables in HTML form, as `<table>`s. Just like in ASCII
+form, you can actually get a string representation - just use `get_html_string()`. HTML
+printing supports the `fields`, `start`, `end`, `sortby` and `reversesort` arguments in
+exactly the same way as ASCII printing.
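+
+For instance, a short sketch reusing the arguments described above:
+
+```python
+html_string = x.get_html_string(sortby="Population", start=0, end=3)
+```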
+
+#### Styling HTML tables
+
+By default, PrettyTable outputs HTML for "vanilla" tables. The HTML code is quite
+simple. It looks like this:
+
+```html
+<table>
+ <thead>
+ <tr>
+ <th>City name</th>
+ <th>Area</th>
+ <th>Population</th>
+ <th>Annual Rainfall</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td>Adelaide</td>
+ <td>1295</td>
+ <td>1158259</td>
+ <td>600.5</td>
+ </tr>
+ <tr>
+ <td>Brisbane</td>
+ <td>5905</td>
+ <td>1857594</td>
+ <td>1146.4</td>
+ ...
+ </tr>
+ </tbody>
+</table>
+```
+
+If you like, you can ask PrettyTable to do its best to mimic the style options that your
+table has set using inline CSS. This is done by giving a `format=True` keyword argument
+to `get_html_string` method. Note that if you _always_ want to print formatted HTML you
+can do:
+
+```python
+x.format = True
+```
+
+and the setting will persist until you turn it off.
+
+Just like with ASCII tables, if you want to change the table's style for just one
+`get_html_string` call, you can pass keyword arguments to that method - exactly as you
+would with `get_string`.
+
+#### Setting HTML attributes
+
+You can provide a dictionary of HTML attribute name/value pairs to the `get_html_string`
+method using the `attributes` keyword argument. This lets you specify common HTML
+attributes like `id` and `class` that can be used for linking to your tables or
+customising their appearance using CSS. For example:
+
+```python
+print(x.get_html_string(attributes={"id":"my_table", "class":"red_table"}))
+```
+
+will print:
+
+```html
+<table id="my_table" class="red_table">
+ <thead>
+ <tr>
+ <th>City name</th>
+ <th>Area</th>
+ <th>Population</th>
+ <th>Annual Rainfall</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ ... ... ...
+ </tr>
+ </tbody>
+</table>
+```
+
+### Miscellaneous things
+
+#### Copying a table
+
+You can call the `copy` method on a PrettyTable object without arguments to return an
+identical independent copy of the table.
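+
+For example:
+
+```python
+y = x.copy()   # y is an independent copy; changes to y do not affect x
+```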
+
+If you want a copy of a PrettyTable object with just a subset of the rows, you can use
+list slicing notation:
+
+```python
+new_table = old_table[0:5]
+```
+
+## Contributing
+
+After editing files, use the [Black](https://github.com/psf/black) linter to auto-format
+changed lines.
+
+```sh
+python -m pip install black
+black prettytable*.py
+```
diff --git a/lib/prettytable-3.6.0.dist-info/RECORD b/lib/prettytable-3.6.0.dist-info/RECORD
new file mode 100644
index 0000000..8702d1e
--- /dev/null
+++ b/lib/prettytable-3.6.0.dist-info/RECORD
@@ -0,0 +1,13 @@
+prettytable-3.6.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+prettytable-3.6.0.dist-info/METADATA,sha256=fU3BZXeOl90w8gPN-jvtUpMXwbMH54OnidclKA7hLbM,25058
+prettytable-3.6.0.dist-info/RECORD,,
+prettytable-3.6.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+prettytable-3.6.0.dist-info/WHEEL,sha256=hKi7AIIx6qfnsRbr087vpeJnrVUuDokDHZacPPMW7-Y,87
+prettytable-3.6.0.dist-info/licenses/COPYING,sha256=DIrcwgTIr2zf33iH39H5nCmQH2wdUFp4u6lu-k-yywU,1612
+prettytable/__init__.py,sha256=f4jlohGU7tzRTutkAieeDrxnxt02322kr7ayXZH7ixs,920
+prettytable/__pycache__/__init__.cpython-39.pyc,,
+prettytable/__pycache__/colortable.cpython-39.pyc,,
+prettytable/__pycache__/prettytable.cpython-39.pyc,,
+prettytable/colortable.py,sha256=DFN-wtny2c-jdBtLTgwTlmnPWLrEAp41QqrrFQ0Ve8E,2445
+prettytable/prettytable.py,sha256=XMVnRju2YGrKqkkVSm9ImOt42uvaXXqOlqrspG_mwM4,88442
+prettytable/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
diff --git a/lib/prettytable-3.6.0.dist-info/REQUESTED b/lib/prettytable-3.6.0.dist-info/REQUESTED
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/prettytable-3.6.0.dist-info/REQUESTED
diff --git a/lib/prettytable-3.6.0.dist-info/WHEEL b/lib/prettytable-3.6.0.dist-info/WHEEL
new file mode 100644
index 0000000..8d5c0ce
--- /dev/null
+++ b/lib/prettytable-3.6.0.dist-info/WHEEL
@@ -0,0 +1,4 @@
+Wheel-Version: 1.0
+Generator: hatchling 1.12.2
+Root-Is-Purelib: true
+Tag: py3-none-any
diff --git a/lib/prettytable-3.6.0.dist-info/licenses/COPYING b/lib/prettytable-3.6.0.dist-info/licenses/COPYING
new file mode 100644
index 0000000..cb6fed3
--- /dev/null
+++ b/lib/prettytable-3.6.0.dist-info/licenses/COPYING
@@ -0,0 +1,30 @@
+# Copyright (c) 2009-2014 Luke Maurits <luke@maurits.id.au>
+# All rights reserved.
+# With contributions from:
+# * Chris Clark
+# * Klein Stephane
+# * John Filleau
+# * Vladimir Vrzić
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * The name of the author may not be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
diff --git a/lib/prettytable/__init__.py b/lib/prettytable/__init__.py
new file mode 100644
index 0000000..d0f5adc
--- /dev/null
+++ b/lib/prettytable/__init__.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from .prettytable import (
+ ALL,
+ DEFAULT,
+ DOUBLE_BORDER,
+ FRAME,
+ HEADER,
+ MARKDOWN,
+ MSWORD_FRIENDLY,
+ NONE,
+ ORGMODE,
+ PLAIN_COLUMNS,
+ RANDOM,
+ SINGLE_BORDER,
+ PrettyTable,
+ TableHandler,
+ from_csv,
+ from_db_cursor,
+ from_html,
+ from_html_one,
+ from_json,
+)
+
+__all__ = [
+ "ALL",
+ "DEFAULT",
+ "DOUBLE_BORDER",
+ "SINGLE_BORDER",
+ "FRAME",
+ "HEADER",
+ "MARKDOWN",
+ "MSWORD_FRIENDLY",
+ "NONE",
+ "ORGMODE",
+ "PLAIN_COLUMNS",
+ "RANDOM",
+ "PrettyTable",
+ "TableHandler",
+ "from_csv",
+ "from_db_cursor",
+ "from_html",
+ "from_html_one",
+ "from_json",
+]
+
+try:
+ # Python 3.8+
+ import importlib.metadata as importlib_metadata
+except ImportError:
+ # Python 3.7 and lower
+ import importlib_metadata # type: ignore
+
+__version__ = importlib_metadata.version(__name__)
diff --git a/lib/prettytable/colortable.py b/lib/prettytable/colortable.py
new file mode 100644
index 0000000..9a3a06b
--- /dev/null
+++ b/lib/prettytable/colortable.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+from .prettytable import PrettyTable
+
+try:
+ from colorama import init
+
+ init()
+except ImportError:
+ pass
+
+
+RESET_CODE = "\x1b[0m"
+
+
+class Theme:
+ def __init__(
+ self,
+ default_color: str = "",
+ vertical_char: str = "|",
+ vertical_color: str = "",
+ horizontal_char: str = "-",
+ horizontal_color: str = "",
+ junction_char: str = "+",
+ junction_color: str = "",
+ ) -> None:
+ self.default_color = Theme.format_code(default_color)
+ self.vertical_char = vertical_char
+ self.vertical_color = Theme.format_code(vertical_color)
+ self.horizontal_char = horizontal_char
+ self.horizontal_color = Theme.format_code(horizontal_color)
+ self.junction_char = junction_char
+ self.junction_color = Theme.format_code(junction_color)
+
+ @staticmethod
+ def format_code(s: str) -> str:
+ """Takes string and intelligently puts it into an ANSI escape sequence"""
+ if s.strip() == "":
+ return ""
+ elif s.startswith("\x1b["):
+ return s
+ else:
+ return f"\x1b[{s}m"
+
+
+class Themes:
+ DEFAULT = Theme()
+ OCEAN = Theme(
+ default_color="96",
+ vertical_color="34",
+ horizontal_color="34",
+ junction_color="36",
+ )
+
+
+class ColorTable(PrettyTable):
+ def __init__(self, field_names=None, **kwargs) -> None:
+ super().__init__(field_names=field_names, **kwargs)
+ # TODO: Validate option
+
+ self.theme = kwargs.get("theme") or Themes.DEFAULT
+
+ @property
+ def theme(self) -> Theme:
+ return self._theme
+
+ @theme.setter
+ def theme(self, value: Theme):
+ self._theme = value
+ self.update_theme()
+
+ def update_theme(self) -> None:
+ theme = self._theme
+
+ self._vertical_char = (
+ theme.vertical_color
+ + theme.vertical_char
+ + RESET_CODE
+ + theme.default_color
+ )
+
+ self._horizontal_char = (
+ theme.horizontal_color
+ + theme.horizontal_char
+ + RESET_CODE
+ + theme.default_color
+ )
+
+ self._junction_char = (
+ theme.junction_color
+ + theme.junction_char
+ + RESET_CODE
+ + theme.default_color
+ )
+
+ def get_string(self, **kwargs) -> str:
+ return super().get_string(**kwargs) + RESET_CODE
diff --git a/lib/prettytable/prettytable.py b/lib/prettytable/prettytable.py
new file mode 100644
index 0000000..bb977e7
--- /dev/null
+++ b/lib/prettytable/prettytable.py
@@ -0,0 +1,2531 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2009-2014, Luke Maurits <luke@maurits.id.au>
+# All rights reserved.
+# With contributions from:
+# * Chris Clark
+# * Klein Stephane
+# * John Filleau
+# * Vladimir Vrzić
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * The name of the author may not be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+from __future__ import annotations
+
+import copy
+import csv
+import io
+import json
+import math
+import random
+import re
+import textwrap
+from html import escape
+from html.parser import HTMLParser
+from typing import Any
+
+import wcwidth # type: ignore
+
+# hrule styles
+FRAME = 0
+ALL = 1
+NONE = 2
+HEADER = 3
+
+# Table styles
+DEFAULT = 10
+MSWORD_FRIENDLY = 11
+PLAIN_COLUMNS = 12
+MARKDOWN = 13
+ORGMODE = 14
+DOUBLE_BORDER = 15
+SINGLE_BORDER = 16
+RANDOM = 20
+BASE_ALIGN_VALUE = "base_align_value"
+
+_re = re.compile(r"\033\[[0-9;]*m|\033\(B")
+
+
+def _get_size(text):
+ lines = text.split("\n")
+ height = len(lines)
+ width = max(_str_block_width(line) for line in lines)
+ return width, height
+
+
+class PrettyTable:
+ def __init__(self, field_names=None, **kwargs) -> None:
+ """Return a new PrettyTable instance
+
+ Arguments:
+
+ encoding - Unicode encoding scheme used to decode any encoded input
+ title - optional table title
+ field_names - list or tuple of field names
+ fields - list or tuple of field names to include in displays
+ start - index of first data row to include in output
+ end - index of last data row to include in output PLUS ONE (list slice style)
+ header - print a header showing field names (True or False)
+ header_style - stylisation to apply to field names in header
+ ("cap", "title", "upper", "lower" or None)
+ border - print a border around the table (True or False)
+ preserve_internal_border - print a border inside the table even if
+ border is disabled (True or False)
+ hrules - controls printing of horizontal rules after rows.
+ Allowed values: FRAME, HEADER, ALL, NONE
+ vrules - controls printing of vertical rules between columns.
+ Allowed values: FRAME, ALL, NONE
+ int_format - controls formatting of integer data
+ float_format - controls formatting of floating point data
+ custom_format - controls formatting of any column using callable
+ min_table_width - minimum desired table width, in characters
+ max_table_width - maximum desired table width, in characters
+ min_width - minimum desired field width, in characters
+ max_width - maximum desired field width, in characters
+ padding_width - number of spaces on either side of column data
+ (only used if left and right paddings are None)
+ left_padding_width - number of spaces on left hand side of column data
+ right_padding_width - number of spaces on right hand side of column data
+ vertical_char - single character string used to draw vertical lines
+ horizontal_char - single character string used to draw horizontal lines
+ horizontal_align_char - single character string used to indicate alignment
+ junction_char - single character string used to draw line junctions
+ top_junction_char - single character string used to draw top line junctions
+ bottom_junction_char -
+ single character string used to draw bottom line junctions
+ right_junction_char - single character string used to draw right line junctions
+ left_junction_char - single character string used to draw left line junctions
+ top_right_junction_char -
+ single character string used to draw top-right line junctions
+ top_left_junction_char -
+ single character string used to draw top-left line junctions
+ bottom_right_junction_char -
+ single character string used to draw bottom-right line junctions
+ bottom_left_junction_char -
+ single character string used to draw bottom-left line junctions
+ sortby - name of field to sort rows by
+ sort_key - sorting key function, applied to data points before sorting
+ align - default align for each column (None, "l", "c" or "r")
+ valign - default valign for each row (None, "t", "m" or "b")
+ reversesort - True or False to sort in descending or ascending order
+ oldsortslice - Slice rows before sorting in the "old style" """
+
+ self.encoding = kwargs.get("encoding", "UTF-8")
+
+ # Data
+ self._field_names: list[str] = []
+ self._rows: list[list] = []
+ self.align = {}
+ self.valign = {}
+ self.max_width = {}
+ self.min_width = {}
+ self.int_format = {}
+ self.float_format = {}
+ self.custom_format = {}
+
+ if field_names:
+ self.field_names = field_names
+ else:
+ self._widths: list[int] = []
+
+ # Options
+ self._options = [
+ "title",
+ "start",
+ "end",
+ "fields",
+ "header",
+ "border",
+ "preserve_internal_border",
+ "sortby",
+ "reversesort",
+ "sort_key",
+ "attributes",
+ "format",
+ "hrules",
+ "vrules",
+ "int_format",
+ "float_format",
+ "custom_format",
+ "min_table_width",
+ "max_table_width",
+ "padding_width",
+ "left_padding_width",
+ "right_padding_width",
+ "vertical_char",
+ "horizontal_char",
+ "horizontal_align_char",
+ "junction_char",
+ "header_style",
+ "valign",
+ "xhtml",
+ "print_empty",
+ "oldsortslice",
+ "top_junction_char",
+ "bottom_junction_char",
+ "right_junction_char",
+ "left_junction_char",
+ "top_right_junction_char",
+ "top_left_junction_char",
+ "bottom_right_junction_char",
+ "bottom_left_junction_char",
+ "align",
+ "valign",
+ "max_width",
+ "min_width",
+ "none_format",
+ ]
+ for option in self._options:
+ if option in kwargs:
+ self._validate_option(option, kwargs[option])
+ else:
+ kwargs[option] = None
+
+ self._title = kwargs["title"] or None
+ self._start = kwargs["start"] or 0
+ self._end = kwargs["end"] or None
+ self._fields = kwargs["fields"] or None
+ self._none_format: dict[None, None] = {}
+
+ if kwargs["header"] in (True, False):
+ self._header = kwargs["header"]
+ else:
+ self._header = True
+ self._header_style = kwargs["header_style"] or None
+ if kwargs["border"] in (True, False):
+ self._border = kwargs["border"]
+ else:
+ self._border = True
+ if kwargs["preserve_internal_border"] in (True, False):
+ self._preserve_internal_border = kwargs["preserve_internal_border"]
+ else:
+ self._preserve_internal_border = False
+ self._hrules = kwargs["hrules"] or FRAME
+ self._vrules = kwargs["vrules"] or ALL
+
+ self._sortby = kwargs["sortby"] or None
+ if kwargs["reversesort"] in (True, False):
+ self._reversesort = kwargs["reversesort"]
+ else:
+ self._reversesort = False
+ self._sort_key = kwargs["sort_key"] or (lambda x: x)
+
+ # Column specific arguments, use property.setters
+ self.align = kwargs["align"] or {}
+ self.valign = kwargs["valign"] or {}
+ self.max_width = kwargs["max_width"] or {}
+ self.min_width = kwargs["min_width"] or {}
+ self.int_format = kwargs["int_format"] or {}
+ self.float_format = kwargs["float_format"] or {}
+ self.custom_format = kwargs["custom_format"] or {}
+ self.none_format = kwargs["none_format"] or {}
+
+ self._min_table_width = kwargs["min_table_width"] or None
+ self._max_table_width = kwargs["max_table_width"] or None
+ if kwargs["padding_width"] is None:
+ self._padding_width = 1
+ else:
+ self._padding_width = kwargs["padding_width"]
+ self._left_padding_width = kwargs["left_padding_width"] or None
+ self._right_padding_width = kwargs["right_padding_width"] or None
+
+ self._vertical_char = kwargs["vertical_char"] or "|"
+ self._horizontal_char = kwargs["horizontal_char"] or "-"
+ self._horizontal_align_char = kwargs["horizontal_align_char"]
+ self._junction_char = kwargs["junction_char"] or "+"
+ self._top_junction_char = kwargs["top_junction_char"]
+ self._bottom_junction_char = kwargs["bottom_junction_char"]
+ self._right_junction_char = kwargs["right_junction_char"]
+ self._left_junction_char = kwargs["left_junction_char"]
+ self._top_right_junction_char = kwargs["top_right_junction_char"]
+ self._top_left_junction_char = kwargs["top_left_junction_char"]
+ self._bottom_right_junction_char = kwargs["bottom_right_junction_char"]
+ self._bottom_left_junction_char = kwargs["bottom_left_junction_char"]
+
+ if kwargs["print_empty"] in (True, False):
+ self._print_empty = kwargs["print_empty"]
+ else:
+ self._print_empty = True
+ if kwargs["oldsortslice"] in (True, False):
+ self._oldsortslice = kwargs["oldsortslice"]
+ else:
+ self._oldsortslice = False
+ self._format = kwargs["format"] or False
+ self._xhtml = kwargs["xhtml"] or False
+ self._attributes = kwargs["attributes"] or {}
+
+ def _justify(self, text, width, align):
+ excess = width - _str_block_width(text)
+ if align == "l":
+ return text + excess * " "
+ elif align == "r":
+ return excess * " " + text
+ else:
+ if excess % 2:
+ # Uneven padding
+ # Put more space on right if text is of odd length...
+ if _str_block_width(text) % 2:
+ return (excess // 2) * " " + text + (excess // 2 + 1) * " "
+ # and more space on left if text is of even length
+ else:
+ return (excess // 2 + 1) * " " + text + (excess // 2) * " "
+ # Why distribute extra space this way? To match the behaviour of
+ # the inbuilt str.center() method.
+ else:
+ # Equal padding on either side
+ return (excess // 2) * " " + text + (excess // 2) * " "
+
+ def __getattr__(self, name):
+
+ if name == "rowcount":
+ return len(self._rows)
+ elif name == "colcount":
+ if self._field_names:
+ return len(self._field_names)
+ elif self._rows:
+ return len(self._rows[0])
+ else:
+ return 0
+ else:
+ raise AttributeError(name)
+
+ def __getitem__(self, index):
+
+ new = PrettyTable()
+ new.field_names = self.field_names
+ for attr in self._options:
+ setattr(new, "_" + attr, getattr(self, "_" + attr))
+ setattr(new, "_align", getattr(self, "_align"))
+ if isinstance(index, slice):
+ for row in self._rows[index]:
+ new.add_row(row)
+ elif isinstance(index, int):
+ new.add_row(self._rows[index])
+ else:
+ raise IndexError(f"Index {index} is invalid, must be an integer or slice")
+ return new
+
+ def __str__(self):
+ return self.get_string()
+
+ def __repr__(self):
+ return self.get_string()
+
+ def _repr_html_(self):
+ """
+ Returns get_html_string value by default
+ as the repr call in Jupyter notebook environment
+ """
+ return self.get_html_string()
+
+ ##############################
+ # ATTRIBUTE VALIDATORS #
+ ##############################
+
+ # The method _validate_option is all that should be used elsewhere in the code base
+ # to validate options. It will call the appropriate validation method for that
+ # option. The individual validation methods should never need to be called directly
+ # (although nothing bad will happen if they *are*).
+ # Validation happens in TWO places.
+ # Firstly, in the property setters defined in the ATTRIBUTE MANAGEMENT section.
+ # Secondly, in the _get_options method, where keyword arguments are mixed with
+ # persistent settings
+
+ def _validate_option(self, option, val):
+ if option == "field_names":
+ self._validate_field_names(val)
+ elif option == "none_format":
+ self._validate_none_format(val)
+ elif option in (
+ "start",
+ "end",
+ "max_width",
+ "min_width",
+ "min_table_width",
+ "max_table_width",
+ "padding_width",
+ "left_padding_width",
+ "right_padding_width",
+ "format",
+ ):
+ self._validate_nonnegative_int(option, val)
+ elif option == "sortby":
+ self._validate_field_name(option, val)
+ elif option == "sort_key":
+ self._validate_function(option, val)
+ elif option == "hrules":
+ self._validate_hrules(option, val)
+ elif option == "vrules":
+ self._validate_vrules(option, val)
+ elif option == "fields":
+ self._validate_all_field_names(option, val)
+ elif option in (
+ "header",
+ "border",
+ "preserve_internal_border",
+ "reversesort",
+ "xhtml",
+ "print_empty",
+ "oldsortslice",
+ ):
+ self._validate_true_or_false(option, val)
+ elif option == "header_style":
+ self._validate_header_style(val)
+ elif option == "int_format":
+ self._validate_int_format(option, val)
+ elif option == "float_format":
+ self._validate_float_format(option, val)
+ elif option == "custom_format":
+ for k, formatter in val.items():
+ self._validate_function(f"{option}.{k}", formatter)
+ elif option in (
+ "vertical_char",
+ "horizontal_char",
+ "horizontal_align_char",
+ "junction_char",
+ "top_junction_char",
+ "bottom_junction_char",
+ "right_junction_char",
+ "left_junction_char",
+ "top_right_junction_char",
+ "top_left_junction_char",
+ "bottom_right_junction_char",
+ "bottom_left_junction_char",
+ ):
+ self._validate_single_char(option, val)
+ elif option == "attributes":
+ self._validate_attributes(option, val)
+
+ def _validate_field_names(self, val):
+ # Check for appropriate length
+ if self._field_names:
+ try:
+ assert len(val) == len(self._field_names)
+ except AssertionError:
+ raise ValueError(
+ "Field name list has incorrect number of values, "
+ f"(actual) {len(val)}!={len(self._field_names)} (expected)"
+ )
+ if self._rows:
+ try:
+ assert len(val) == len(self._rows[0])
+ except AssertionError:
+ raise ValueError(
+ "Field name list has incorrect number of values, "
+ f"(actual) {len(val)}!={len(self._rows[0])} (expected)"
+ )
+ # Check for uniqueness
+ try:
+ assert len(val) == len(set(val))
+ except AssertionError:
+ raise ValueError("Field names must be unique")
+
+ def _validate_none_format(self, val):
+ try:
+ if val is not None:
+ assert isinstance(val, str)
+ except AssertionError:
+ raise TypeError(
+ "Replacement for None value must be a string if being supplied."
+ )
+
+ def _validate_header_style(self, val):
+ try:
+ assert val in ("cap", "title", "upper", "lower", None)
+ except AssertionError:
+ raise ValueError(
+ "Invalid header style, use cap, title, upper, lower or None"
+ )
+
+ def _validate_align(self, val):
+ try:
+ assert val in ["l", "c", "r"]
+ except AssertionError:
+ raise ValueError(f"Alignment {val} is invalid, use l, c or r")
+
+ def _validate_valign(self, val):
+ try:
+ assert val in ["t", "m", "b", None]
+ except AssertionError:
+ raise ValueError(f"Alignment {val} is invalid, use t, m, b or None")
+
+ def _validate_nonnegative_int(self, name, val):
+ try:
+ assert int(val) >= 0
+ except AssertionError:
+ raise ValueError(f"Invalid value for {name}: {val}")
+
+ def _validate_true_or_false(self, name, val):
+ try:
+ assert val in (True, False)
+ except AssertionError:
+ raise ValueError(f"Invalid value for {name}. Must be True or False.")
+
+ def _validate_int_format(self, name, val):
+ if val == "":
+ return
+ try:
+ assert isinstance(val, str)
+ assert val.isdigit()
+ except AssertionError:
+ raise ValueError(
+ f"Invalid value for {name}. Must be an integer format string."
+ )
+
+ def _validate_float_format(self, name, val):
+ if val == "":
+ return
+ try:
+ assert isinstance(val, str)
+ assert "." in val
+ bits = val.split(".")
+ assert len(bits) <= 2
+ assert bits[0] == "" or bits[0].isdigit()
+ assert (
+ bits[1] == ""
+ or bits[1].isdigit()
+ or (bits[1][-1] == "f" and bits[1].rstrip("f").isdigit())
+ )
+ except AssertionError:
+ raise ValueError(
+ f"Invalid value for {name}. Must be a float format string."
+ )
+
+ def _validate_function(self, name, val):
+ try:
+ assert hasattr(val, "__call__")
+ except AssertionError:
+ raise ValueError(f"Invalid value for {name}. Must be a function.")
+
+ def _validate_hrules(self, name, val):
+ try:
+ assert val in (ALL, FRAME, HEADER, NONE)
+ except AssertionError:
+ raise ValueError(
+ f"Invalid value for {name}. Must be ALL, FRAME, HEADER or NONE."
+ )
+
+ def _validate_vrules(self, name, val):
+ try:
+ assert val in (ALL, FRAME, NONE)
+ except AssertionError:
+ raise ValueError(f"Invalid value for {name}. Must be ALL, FRAME, or NONE.")
+
+ def _validate_field_name(self, name, val):
+ try:
+ assert (val in self._field_names) or (val is None)
+ except AssertionError:
+ raise ValueError(f"Invalid field name: {val}")
+
+ def _validate_all_field_names(self, name, val):
+ try:
+ for x in val:
+ self._validate_field_name(name, x)
+ except AssertionError:
+ raise ValueError("Fields must be a sequence of field names")
+
+ def _validate_single_char(self, name, val):
+ try:
+ assert _str_block_width(val) == 1
+ except AssertionError:
+ raise ValueError(f"Invalid value for {name}. Must be a string of length 1.")
+
+ def _validate_attributes(self, name, val):
+ try:
+ assert isinstance(val, dict)
+ except AssertionError:
+ raise TypeError("Attributes must be a dictionary of name/value pairs")
+
+ ##############################
+ # ATTRIBUTE MANAGEMENT #
+ ##############################
+ @property
+ def rows(self) -> list[Any]:
+ return self._rows[:]
+
+ @property
+ def xhtml(self) -> bool:
+ """Print <br/> tags if True, <br> tags if False"""
+ return self._xhtml
+
+ @xhtml.setter
+ def xhtml(self, val):
+ self._validate_option("xhtml", val)
+ self._xhtml = val
+
+ @property
+ def none_format(self):
+ return self._none_format
+
+ @none_format.setter
+ def none_format(self, val):
+ if not self._field_names:
+ self._none_format = {}
+ elif val is None or (isinstance(val, dict) and len(val) == 0):
+ for field in self._field_names:
+ self._none_format[field] = None
+ else:
+ self._validate_none_format(val)
+ for field in self._field_names:
+ self._none_format[field] = val
+
+ @property
+ def field_names(self):
+ """List or tuple of field names
+
+ When setting field_names, if there are already field names the new list
+ of field names must be the same length. Columns are renamed and row data
+ remains unchanged."""
+ return self._field_names
+
+ @field_names.setter
+ def field_names(self, val):
+ val = [str(x) for x in val]
+ self._validate_option("field_names", val)
+ old_names = None
+ if self._field_names:
+ old_names = self._field_names[:]
+ self._field_names = val
+ if self._align and old_names:
+ for old_name, new_name in zip(old_names, val):
+ self._align[new_name] = self._align[old_name]
+            for old_name in old_names:
+                if old_name not in self._field_names:
+                    self._align.pop(old_name)
+ elif self._align:
+ for field_name in self._field_names:
+ self._align[field_name] = self._align[BASE_ALIGN_VALUE]
+ else:
+ self.align = "c"
+ if self._valign and old_names:
+ for old_name, new_name in zip(old_names, val):
+ self._valign[new_name] = self._valign[old_name]
+            for old_name in old_names:
+                if old_name not in self._field_names:
+                    self._valign.pop(old_name)
+ else:
+ self.valign = "t"
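+
+    # Usage sketch for the renaming behaviour documented above (illustrative
+    # only; table contents are assumed, not taken from the upstream source):
+    #
+    #     t = PrettyTable(["a", "b"])
+    #     t.add_row([1, 2])
+    #     t.field_names = ["x", "y"]   # columns renamed, row data unchanged
+    #     print(t.get_string())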
+
+ @property
+ def align(self):
+ """Controls alignment of fields
+ Arguments:
+
+ align - alignment, one of "l", "c", or "r" """
+ return self._align
+
+ @align.setter
+ def align(self, val):
+ if val is None or (isinstance(val, dict) and len(val) == 0):
+ if not self._field_names:
+ self._align = {BASE_ALIGN_VALUE: "c"}
+ else:
+ for field in self._field_names:
+ self._align[field] = "c"
+ else:
+ self._validate_align(val)
+ if not self._field_names:
+ self._align = {BASE_ALIGN_VALUE: val}
+ else:
+ for field in self._field_names:
+ self._align[field] = val
+
+ @property
+ def valign(self):
+ """Controls vertical alignment of fields
+ Arguments:
+
+ valign - vertical alignment, one of "t", "m", or "b" """
+ return self._valign
+
+ @valign.setter
+ def valign(self, val):
+ if not self._field_names:
+ self._valign = {}
+ elif val is None or (isinstance(val, dict) and len(val) == 0):
+ for field in self._field_names:
+ self._valign[field] = "t"
+ else:
+ self._validate_valign(val)
+ for field in self._field_names:
+ self._valign[field] = val
+
+ @property
+ def max_width(self):
+ """Controls maximum width of fields
+ Arguments:
+
+ max_width - maximum width integer"""
+ return self._max_width
+
+ @max_width.setter
+ def max_width(self, val):
+ if val is None or (isinstance(val, dict) and len(val) == 0):
+ self._max_width = {}
+ else:
+ self._validate_option("max_width", val)
+ for field in self._field_names:
+ self._max_width[field] = val
+
+ @property
+ def min_width(self):
+ """Controls minimum width of fields
+ Arguments:
+
+ min_width - minimum width integer"""
+ return self._min_width
+
+ @min_width.setter
+ def min_width(self, val):
+ if val is None or (isinstance(val, dict) and len(val) == 0):
+ self._min_width = {}
+ else:
+ self._validate_option("min_width", val)
+ for field in self._field_names:
+ self._min_width[field] = val
+
+ @property
+ def min_table_width(self):
+ return self._min_table_width
+
+ @min_table_width.setter
+ def min_table_width(self, val):
+ self._validate_option("min_table_width", val)
+ self._min_table_width = val
+
+ @property
+ def max_table_width(self):
+ return self._max_table_width
+
+ @max_table_width.setter
+ def max_table_width(self, val):
+ self._validate_option("max_table_width", val)
+ self._max_table_width = val
+
+ @property
+ def fields(self):
+ """List or tuple of field names to include in displays"""
+ return self._fields
+
+ @fields.setter
+ def fields(self, val):
+ self._validate_option("fields", val)
+ self._fields = val
+
+ @property
+ def title(self):
+ """Optional table title
+
+ Arguments:
+
+ title - table title"""
+ return self._title
+
+ @title.setter
+ def title(self, val):
+ self._title = str(val)
+
+ @property
+ def start(self):
+ """Start index of the range of rows to print
+
+ Arguments:
+
+ start - index of first data row to include in output"""
+ return self._start
+
+ @start.setter
+ def start(self, val):
+ self._validate_option("start", val)
+ self._start = val
+
+ @property
+ def end(self):
+ """End index of the range of rows to print
+
+ Arguments:
+
+ end - index of last data row to include in output PLUS ONE (list slice style)"""
+ return self._end
+
+ @end.setter
+ def end(self, val):
+ self._validate_option("end", val)
+ self._end = val
+
+ @property
+ def sortby(self):
+ """Name of field by which to sort rows
+
+ Arguments:
+
+ sortby - field name to sort by"""
+ return self._sortby
+
+ @sortby.setter
+ def sortby(self, val):
+ self._validate_option("sortby", val)
+ self._sortby = val
+
+ @property
+ def reversesort(self):
+ """Controls direction of sorting (ascending vs descending)
+
+ Arguments:
+
+        reversesort - set to True to sort in descending order, or False to sort in
+            ascending order"""
+ return self._reversesort
+
+ @reversesort.setter
+ def reversesort(self, val):
+ self._validate_option("reversesort", val)
+ self._reversesort = val
+
+ @property
+ def sort_key(self):
+ """Sorting key function, applied to data points before sorting
+
+ Arguments:
+
+ sort_key - a function which takes one argument and returns something to be
+ sorted"""
+ return self._sort_key
+
+ @sort_key.setter
+ def sort_key(self, val):
+ self._validate_option("sort_key", val)
+ self._sort_key = val
+
+ @property
+ def header(self):
+ """Controls printing of table header with field names
+
+ Arguments:
+
+ header - print a header showing field names (True or False)"""
+ return self._header
+
+ @header.setter
+ def header(self, val):
+ self._validate_option("header", val)
+ self._header = val
+
+ @property
+ def header_style(self):
+ """Controls stylisation applied to field names in header
+
+ Arguments:
+
+ header_style - stylisation to apply to field names in header
+ ("cap", "title", "upper", "lower" or None)"""
+ return self._header_style
+
+ @header_style.setter
+ def header_style(self, val):
+ self._validate_header_style(val)
+ self._header_style = val
+
+ @property
+ def border(self):
+ """Controls printing of border around table
+
+ Arguments:
+
+ border - print a border around the table (True or False)"""
+ return self._border
+
+ @border.setter
+ def border(self, val):
+ self._validate_option("border", val)
+ self._border = val
+
+ @property
+ def preserve_internal_border(self):
+ """Controls printing of border inside table
+
+ Arguments:
+
+ preserve_internal_border - print a border inside the table even if
+ border is disabled (True or False)"""
+ return self._preserve_internal_border
+
+ @preserve_internal_border.setter
+ def preserve_internal_border(self, val):
+ self._validate_option("preserve_internal_border", val)
+ self._preserve_internal_border = val
+
+ @property
+ def hrules(self):
+ """Controls printing of horizontal rules after rows
+
+ Arguments:
+
+ hrules - horizontal rules style. Allowed values: FRAME, ALL, HEADER, NONE"""
+ return self._hrules
+
+ @hrules.setter
+ def hrules(self, val):
+ self._validate_option("hrules", val)
+ self._hrules = val
+
+ @property
+ def vrules(self):
+ """Controls printing of vertical rules between columns
+
+ Arguments:
+
+ vrules - vertical rules style. Allowed values: FRAME, ALL, NONE"""
+ return self._vrules
+
+ @vrules.setter
+ def vrules(self, val):
+ self._validate_option("vrules", val)
+ self._vrules = val
+
+ @property
+ def int_format(self):
+ """Controls formatting of integer data
+ Arguments:
+
+ int_format - integer format string"""
+ return self._int_format
+
+ @int_format.setter
+ def int_format(self, val):
+ if val is None or (isinstance(val, dict) and len(val) == 0):
+ self._int_format = {}
+ else:
+ self._validate_option("int_format", val)
+ for field in self._field_names:
+ self._int_format[field] = val
+
+ @property
+ def float_format(self):
+ """Controls formatting of floating point data
+ Arguments:
+
+ float_format - floating point format string"""
+ return self._float_format
+
+ @float_format.setter
+ def float_format(self, val):
+ if val is None or (isinstance(val, dict) and len(val) == 0):
+ self._float_format = {}
+ else:
+ self._validate_option("float_format", val)
+ for field in self._field_names:
+ self._float_format[field] = val
+
+ @property
+ def custom_format(self):
+ """Controls formatting of any column using callable
+ Arguments:
+
+ custom_format - Dictionary of field_name and callable"""
+ return self._custom_format
+
+ @custom_format.setter
+ def custom_format(self, val):
+ if val is None:
+ self._custom_format = {}
+ elif isinstance(val, dict):
+ for k, v in val.items():
+ self._validate_function(f"custom_value.{k}", v)
+ self._custom_format = val
+ elif hasattr(val, "__call__"):
+ self._validate_function("custom_value", val)
+ for field in self._field_names:
+ self._custom_format[field] = val
+ else:
+ raise TypeError(
+                "The custom_format property needs to be a dictionary or callable"
+ )
+
+ @property
+ def padding_width(self):
+ """The number of empty spaces between a column's edge and its content
+
+ Arguments:
+
+        padding_width - number of spaces, must be a non-negative integer"""
+ return self._padding_width
+
+ @padding_width.setter
+ def padding_width(self, val):
+ self._validate_option("padding_width", val)
+ self._padding_width = val
+
+ @property
+ def left_padding_width(self):
+ """The number of empty spaces between a column's left edge and its content
+
+ Arguments:
+
+        left_padding - number of spaces, must be a non-negative integer"""
+ return self._left_padding_width
+
+ @left_padding_width.setter
+ def left_padding_width(self, val):
+ self._validate_option("left_padding_width", val)
+ self._left_padding_width = val
+
+ @property
+ def right_padding_width(self):
+ """The number of empty spaces between a column's right edge and its content
+
+ Arguments:
+
+        right_padding - number of spaces, must be a non-negative integer"""
+ return self._right_padding_width
+
+ @right_padding_width.setter
+ def right_padding_width(self, val):
+ self._validate_option("right_padding_width", val)
+ self._right_padding_width = val
+
+ @property
+ def vertical_char(self):
+ """The character used when printing table borders to draw vertical lines
+
+ Arguments:
+
+ vertical_char - single character string used to draw vertical lines"""
+ return self._vertical_char
+
+ @vertical_char.setter
+ def vertical_char(self, val):
+ val = str(val)
+ self._validate_option("vertical_char", val)
+ self._vertical_char = val
+
+ @property
+ def horizontal_char(self):
+ """The character used when printing table borders to draw horizontal lines
+
+ Arguments:
+
+ horizontal_char - single character string used to draw horizontal lines"""
+ return self._horizontal_char
+
+ @horizontal_char.setter
+ def horizontal_char(self, val):
+ val = str(val)
+ self._validate_option("horizontal_char", val)
+ self._horizontal_char = val
+
+ @property
+ def horizontal_align_char(self):
+ """The character used to indicate column alignment in horizontal lines
+
+ Arguments:
+
+ horizontal_align_char - single character string used to indicate alignment"""
+        return self._horizontal_align_char or self.horizontal_char
+
+ @horizontal_align_char.setter
+ def horizontal_align_char(self, val):
+ val = str(val)
+ self._validate_option("horizontal_align_char", val)
+ self._horizontal_align_char = val
+
+ @property
+ def junction_char(self):
+ """The character used when printing table borders to draw line junctions
+
+ Arguments:
+
+ junction_char - single character string used to draw line junctions"""
+ return self._junction_char
+
+ @junction_char.setter
+ def junction_char(self, val):
+ val = str(val)
+ self._validate_option("junction_char", val)
+ self._junction_char = val
+
+ @property
+ def top_junction_char(self):
+ """The character used when printing table borders to draw top line junctions
+
+ Arguments:
+
+ top_junction_char - single character string used to draw top line junctions"""
+ return self._top_junction_char or self.junction_char
+
+ @top_junction_char.setter
+ def top_junction_char(self, val):
+ val = str(val)
+ self._validate_option("top_junction_char", val)
+ self._top_junction_char = val
+
+ @property
+ def bottom_junction_char(self):
+ """The character used when printing table borders to draw bottom line junctions
+
+ Arguments:
+
+ bottom_junction_char -
+ single character string used to draw bottom line junctions"""
+ return self._bottom_junction_char or self.junction_char
+
+ @bottom_junction_char.setter
+ def bottom_junction_char(self, val):
+ val = str(val)
+ self._validate_option("bottom_junction_char", val)
+ self._bottom_junction_char = val
+
+ @property
+ def right_junction_char(self):
+ """The character used when printing table borders to draw right line junctions
+
+ Arguments:
+
+ right_junction_char -
+ single character string used to draw right line junctions"""
+ return self._right_junction_char or self.junction_char
+
+ @right_junction_char.setter
+ def right_junction_char(self, val):
+ val = str(val)
+ self._validate_option("right_junction_char", val)
+ self._right_junction_char = val
+
+ @property
+ def left_junction_char(self):
+ """The character used when printing table borders to draw left line junctions
+
+ Arguments:
+
+ left_junction_char - single character string used to draw left line junctions"""
+ return self._left_junction_char or self.junction_char
+
+ @left_junction_char.setter
+ def left_junction_char(self, val):
+ val = str(val)
+ self._validate_option("left_junction_char", val)
+ self._left_junction_char = val
+
+ @property
+ def top_right_junction_char(self):
+ """
+ The character used when printing table borders to draw top-right line junctions
+
+ Arguments:
+
+ top_right_junction_char -
+ single character string used to draw top-right line junctions"""
+ return self._top_right_junction_char or self.junction_char
+
+ @top_right_junction_char.setter
+ def top_right_junction_char(self, val):
+ val = str(val)
+ self._validate_option("top_right_junction_char", val)
+ self._top_right_junction_char = val
+
+ @property
+ def top_left_junction_char(self):
+ """
+ The character used when printing table borders to draw top-left line junctions
+
+ Arguments:
+
+ top_left_junction_char -
+ single character string used to draw top-left line junctions"""
+ return self._top_left_junction_char or self.junction_char
+
+ @top_left_junction_char.setter
+ def top_left_junction_char(self, val):
+ val = str(val)
+ self._validate_option("top_left_junction_char", val)
+ self._top_left_junction_char = val
+
+ @property
+ def bottom_right_junction_char(self):
+ """The character used when printing table borders
+ to draw bottom-right line junctions
+
+ Arguments:
+
+ bottom_right_junction_char -
+ single character string used to draw bottom-right line junctions"""
+ return self._bottom_right_junction_char or self.junction_char
+
+ @bottom_right_junction_char.setter
+ def bottom_right_junction_char(self, val):
+ val = str(val)
+ self._validate_option("bottom_right_junction_char", val)
+ self._bottom_right_junction_char = val
+
+ @property
+ def bottom_left_junction_char(self):
+ """The character used when printing table borders
+ to draw bottom-left line junctions
+
+ Arguments:
+
+ bottom_left_junction_char -
+ single character string used to draw bottom-left line junctions"""
+ return self._bottom_left_junction_char or self.junction_char
+
+ @bottom_left_junction_char.setter
+ def bottom_left_junction_char(self, val):
+ val = str(val)
+ self._validate_option("bottom_left_junction_char", val)
+ self._bottom_left_junction_char = val
+
+ @property
+ def format(self):
+ """Controls whether or not HTML tables are formatted to match styling options
+
+ Arguments:
+
+ format - True or False"""
+ return self._format
+
+ @format.setter
+ def format(self, val):
+ self._validate_option("format", val)
+ self._format = val
+
+ @property
+ def print_empty(self):
+ """Controls whether or not empty tables produce a header and frame or just an
+ empty string
+
+ Arguments:
+
+ print_empty - True or False"""
+ return self._print_empty
+
+ @print_empty.setter
+ def print_empty(self, val):
+ self._validate_option("print_empty", val)
+ self._print_empty = val
+
+ @property
+ def attributes(self):
+ """A dictionary of HTML attribute name/value pairs to be included in the
+ <table> tag when printing HTML
+
+ Arguments:
+
+ attributes - dictionary of attributes"""
+ return self._attributes
+
+ @attributes.setter
+ def attributes(self, val):
+ self._validate_option("attributes", val)
+ self._attributes = val
+
+ @property
+ def oldsortslice(self):
+ """oldsortslice - Slice rows before sorting in the "old style" """
+ return self._oldsortslice
+
+ @oldsortslice.setter
+ def oldsortslice(self, val):
+ self._validate_option("oldsortslice", val)
+ self._oldsortslice = val
+
+ ##############################
+ # OPTION MIXER #
+ ##############################
+
+ def _get_options(self, kwargs):
+
+ options = {}
+ for option in self._options:
+ if option in kwargs:
+ self._validate_option(option, kwargs[option])
+ options[option] = kwargs[option]
+ else:
+ options[option] = getattr(self, option)
+ return options
+
+ ##############################
+ # PRESET STYLE LOGIC #
+ ##############################
+
+ def set_style(self, style) -> None:
+
+ if style == DEFAULT:
+ self._set_default_style()
+ elif style == MSWORD_FRIENDLY:
+ self._set_msword_style()
+ elif style == PLAIN_COLUMNS:
+ self._set_columns_style()
+ elif style == MARKDOWN:
+ self._set_markdown_style()
+ elif style == ORGMODE:
+ self._set_orgmode_style()
+ elif style == DOUBLE_BORDER:
+ self._set_double_border_style()
+ elif style == SINGLE_BORDER:
+ self._set_single_border_style()
+ elif style == RANDOM:
+ self._set_random_style()
+ else:
+ raise ValueError("Invalid pre-set style")
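+
+    # Usage sketch for the preset styles dispatched above; MARKDOWN and the
+    # other style constants are module-level names referenced in this method
+    # (example data is assumed, not from the upstream docs):
+    #
+    #     t = PrettyTable(["City", "Population"])
+    #     t.add_row(["Adelaide", 1295714])
+    #     t.set_style(MARKDOWN)
+    #     print(t.get_string())   # renders a Markdown-compatible table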
+
+ def _set_orgmode_style(self):
+ self._set_default_style()
+ self.orgmode = True
+
+ def _set_markdown_style(self):
+ self.header = True
+ self.border = True
+ self._hrules = None
+ self.padding_width = 1
+ self.left_padding_width = 1
+ self.right_padding_width = 1
+ self.vertical_char = "|"
+ self.junction_char = "|"
+ self._horizontal_align_char = ":"
+
+ def _set_default_style(self):
+
+ self.header = True
+ self.border = True
+ self._hrules = FRAME
+ self._vrules = ALL
+ self.padding_width = 1
+ self.left_padding_width = 1
+ self.right_padding_width = 1
+ self.vertical_char = "|"
+ self.horizontal_char = "-"
+ self._horizontal_align_char = None
+ self.junction_char = "+"
+ self._top_junction_char = None
+ self._bottom_junction_char = None
+ self._right_junction_char = None
+ self._left_junction_char = None
+ self._top_right_junction_char = None
+ self._top_left_junction_char = None
+ self._bottom_right_junction_char = None
+ self._bottom_left_junction_char = None
+
+ def _set_msword_style(self):
+
+ self.header = True
+ self.border = True
+ self._hrules = NONE
+ self.padding_width = 1
+ self.left_padding_width = 1
+ self.right_padding_width = 1
+ self.vertical_char = "|"
+
+ def _set_columns_style(self):
+
+ self.header = True
+ self.border = False
+ self.padding_width = 1
+ self.left_padding_width = 0
+ self.right_padding_width = 8
+
+ def _set_double_border_style(self):
+ self.horizontal_char = "═"
+ self.vertical_char = "║"
+ self.junction_char = "╬"
+ self.top_junction_char = "╦"
+ self.bottom_junction_char = "╩"
+ self.right_junction_char = "╣"
+ self.left_junction_char = "╠"
+ self.top_right_junction_char = "╗"
+ self.top_left_junction_char = "╔"
+ self.bottom_right_junction_char = "╝"
+ self.bottom_left_junction_char = "╚"
+
+ def _set_single_border_style(self):
+ self.horizontal_char = "─"
+ self.vertical_char = "│"
+ self.junction_char = "┼"
+ self.top_junction_char = "┬"
+ self.bottom_junction_char = "┴"
+ self.right_junction_char = "┤"
+ self.left_junction_char = "├"
+ self.top_right_junction_char = "┐"
+ self.top_left_junction_char = "┌"
+ self.bottom_right_junction_char = "┘"
+ self.bottom_left_junction_char = "└"
+
+ def _set_random_style(self):
+
+ # Just for fun!
+ self.header = random.choice((True, False))
+ self.border = random.choice((True, False))
+ self._hrules = random.choice((ALL, FRAME, HEADER, NONE))
+ self._vrules = random.choice((ALL, FRAME, NONE))
+ self.left_padding_width = random.randint(0, 5)
+ self.right_padding_width = random.randint(0, 5)
+ self.vertical_char = random.choice(r"~!@#$%^&*()_+|-=\{}[];':\",./;<>?")
+ self.horizontal_char = random.choice(r"~!@#$%^&*()_+|-=\{}[];':\",./;<>?")
+ self.junction_char = random.choice(r"~!@#$%^&*()_+|-=\{}[];':\",./;<>?")
+ self.preserve_internal_border = random.choice((True, False))
+
+ ##############################
+ # DATA INPUT METHODS #
+ ##############################
+
+ def add_rows(self, rows) -> None:
+
+ """Add rows to the table
+
+ Arguments:
+
+ rows - rows of data, should be an iterable of lists, each list with as many
+ elements as the table has fields"""
+ for row in rows:
+ self.add_row(row)
+
+ def add_row(self, row) -> None:
+
+ """Add a row to the table
+
+ Arguments:
+
+ row - row of data, should be a list with as many elements as the table
+ has fields"""
+
+ if self._field_names and len(row) != len(self._field_names):
+ raise ValueError(
+ "Row has incorrect number of values, "
+ f"(actual) {len(row)}!={len(self._field_names)} (expected)"
+ )
+ if not self._field_names:
+ self.field_names = [f"Field {n + 1}" for n in range(0, len(row))]
+ self._rows.append(list(row))
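+
+    # Usage sketch for the row-adding API described above (illustrative only):
+    #
+    #     t = PrettyTable()
+    #     t.field_names = ["Name", "Age"]
+    #     t.add_row(["Alice", 30])
+    #     t.add_rows([["Bob", 25], ["Carol", 41]])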
+
+ def del_row(self, row_index) -> None:
+
+ """Delete a row from the table
+
+ Arguments:
+
+ row_index - The index of the row you want to delete. Indexing starts at 0."""
+
+ if row_index > len(self._rows) - 1:
+ raise IndexError(
+ f"Can't delete row at index {row_index}, "
+ f"table only has {len(self._rows)} rows"
+ )
+ del self._rows[row_index]
+
+ def add_column(
+ self, fieldname, column, align: str = "c", valign: str = "t"
+ ) -> None:
+
+ """Add a column to the table.
+
+ Arguments:
+
+ fieldname - name of the field to contain the new column of data
+ column - column of data, should be a list with as many elements as the
+ table has rows
+ align - desired alignment for this column - "l" for left, "c" for centre and
+ "r" for right
+ valign - desired vertical alignment for new columns - "t" for top,
+ "m" for middle and "b" for bottom"""
+
+ if len(self._rows) in (0, len(column)):
+ self._validate_align(align)
+ self._validate_valign(valign)
+ self._field_names.append(fieldname)
+ self._align[fieldname] = align
+ self._valign[fieldname] = valign
+ for i in range(0, len(column)):
+ if len(self._rows) < i + 1:
+ self._rows.append([])
+ self._rows[i].append(column[i])
+ else:
+ raise ValueError(
+ f"Column length {len(column)} does not match number of rows "
+ f"{len(self._rows)}"
+ )
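+
+    # Usage sketch for column-wise construction as documented above (field
+    # names and data are assumed for illustration):
+    #
+    #     t = PrettyTable()
+    #     t.add_column("City", ["Adelaide", "Brisbane"], align="l")
+    #     t.add_column("Population", [1295714, 1857594], align="r")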
+
+ def add_autoindex(self, fieldname: str = "Index"):
+ """Add an auto-incrementing index column to the table.
+ Arguments:
+ fieldname - name of the field to contain the new column of data"""
+ self._field_names.insert(0, fieldname)
+ self._align[fieldname] = self.align
+ self._valign[fieldname] = self.valign
+ for i, row in enumerate(self._rows):
+ row.insert(0, i + 1)
+
+ def del_column(self, fieldname) -> None:
+
+ """Delete a column from the table
+
+ Arguments:
+
+ fieldname - The field name of the column you want to delete."""
+
+ if fieldname not in self._field_names:
+ raise ValueError(
+ "Can't delete column %r which is not a field name of this table."
+ " Field names are: %s"
+ % (fieldname, ", ".join(map(repr, self._field_names)))
+ )
+
+ col_index = self._field_names.index(fieldname)
+ del self._field_names[col_index]
+ for row in self._rows:
+ del row[col_index]
+
+ def clear_rows(self) -> None:
+
+ """Delete all rows from the table but keep the current field names"""
+
+ self._rows = []
+
+ def clear(self) -> None:
+
+ """Delete all rows and field names from the table, maintaining nothing but
+ styling options"""
+
+ self._rows = []
+ self._field_names = []
+ self._widths = []
+
+ ##############################
+ # MISC PUBLIC METHODS #
+ ##############################
+
+ def copy(self):
+ return copy.deepcopy(self)
+
+ ##############################
+ # MISC PRIVATE METHODS #
+ ##############################
+
+ def _format_value(self, field, value):
+ if isinstance(value, int) and field in self._int_format:
+ return ("%%%sd" % self._int_format[field]) % value
+ elif isinstance(value, float) and field in self._float_format:
+ return ("%%%sf" % self._float_format[field]) % value
+
+ formatter = self._custom_format.get(field, (lambda f, v: str(v)))
+ return formatter(field, value)
+
+ def _compute_table_width(self, options):
+ table_width = 2 if options["vrules"] in (FRAME, ALL) else 0
+ per_col_padding = sum(self._get_padding_widths(options))
+ for index, fieldname in enumerate(self.field_names):
+ if not options["fields"] or (
+ options["fields"] and fieldname in options["fields"]
+ ):
+ table_width += self._widths[index] + per_col_padding
+ return table_width
+
+ def _compute_widths(self, rows, options):
+ if options["header"]:
+ widths = [_get_size(field)[0] for field in self._field_names]
+ else:
+ widths = len(self.field_names) * [0]
+
+ for row in rows:
+ for index, value in enumerate(row):
+ fieldname = self.field_names[index]
+ if self.none_format.get(fieldname) is not None:
+ if value == "None" or value is None:
+ value = self.none_format.get(fieldname)
+ if fieldname in self.max_width:
+ widths[index] = max(
+ widths[index],
+ min(_get_size(value)[0], self.max_width[fieldname]),
+ )
+ else:
+ widths[index] = max(widths[index], _get_size(value)[0])
+ if fieldname in self.min_width:
+ widths[index] = max(widths[index], self.min_width[fieldname])
+ self._widths = widths
+
+ # Are we exceeding max_table_width?
+ if self._max_table_width:
+ table_width = self._compute_table_width(options)
+ if table_width > self._max_table_width:
+ # Shrink widths in proportion
+ scale = 1.0 * self._max_table_width / table_width
+ widths = [int(math.floor(w * scale)) for w in widths]
+ self._widths = widths
+
+ # Are we under min_table_width or title width?
+ if self._min_table_width or options["title"]:
+ if options["title"]:
+ title_width = len(options["title"]) + sum(
+ self._get_padding_widths(options)
+ )
+ if options["vrules"] in (FRAME, ALL):
+ title_width += 2
+ else:
+ title_width = 0
+ min_table_width = self.min_table_width or 0
+ min_width = max(title_width, min_table_width)
+ if options["border"]:
+ borders = len(widths) + 1
+ elif options["preserve_internal_border"]:
+ borders = len(widths)
+ else:
+ borders = 0
+
+ # Subtract padding for each column and borders
+ min_width -= (
+ sum([sum(self._get_padding_widths(options)) for _ in widths]) + borders
+ )
+ # What is being scaled is content so we sum column widths
+ content_width = sum(widths) or 1
+
+ if content_width < min_width:
+ # Grow widths in proportion
+ scale = 1.0 * min_width / content_width
+ widths = [int(math.floor(w * scale)) for w in widths]
+ if sum(widths) < min_width:
+ widths[-1] += min_width - sum(widths)
+ self._widths = widths
+
+ def _get_padding_widths(self, options):
+
+ if options["left_padding_width"] is not None:
+ lpad = options["left_padding_width"]
+ else:
+ lpad = options["padding_width"]
+ if options["right_padding_width"] is not None:
+ rpad = options["right_padding_width"]
+ else:
+ rpad = options["padding_width"]
+ return lpad, rpad
+
+ def _get_rows(self, options):
+ """Return only those data rows that should be printed, based on slicing and
+ sorting.
+
+ Arguments:
+
+ options - dictionary of option settings."""
+
+ if options["oldsortslice"]:
+ rows = copy.deepcopy(self._rows[options["start"] : options["end"]])
+ else:
+ rows = copy.deepcopy(self._rows)
+
+ # Sort
+ if options["sortby"]:
+ sortindex = self._field_names.index(options["sortby"])
+ # Decorate
+ rows = [[row[sortindex]] + row for row in rows]
+ # Sort
+ rows.sort(reverse=options["reversesort"], key=options["sort_key"])
+ # Undecorate
+ rows = [row[1:] for row in rows]
+
+ # Slice if necessary
+ if not options["oldsortslice"]:
+ rows = rows[options["start"] : options["end"]]
+
+ return rows
+
+ def _format_row(self, row):
+ return [
+ self._format_value(field, value)
+ for (field, value) in zip(self._field_names, row)
+ ]
+
+ def _format_rows(self, rows):
+ return [self._format_row(row) for row in rows]
+
+ ##############################
+ # PLAIN TEXT STRING METHODS #
+ ##############################
+
+ def get_string(self, **kwargs) -> str:
+
+ """Return string representation of table in current state.
+
+ Arguments:
+
+ title - optional table title
+ start - index of first data row to include in output
+ end - index of last data row to include in output PLUS ONE (list slice style)
+ fields - names of fields (columns) to include
+ header - print a header showing field names (True or False)
+ border - print a border around the table (True or False)
+ preserve_internal_border - print a border inside the table even if
+ border is disabled (True or False)
+ hrules - controls printing of horizontal rules after rows.
+ Allowed values: ALL, FRAME, HEADER, NONE
+ vrules - controls printing of vertical rules between columns.
+ Allowed values: FRAME, ALL, NONE
+ int_format - controls formatting of integer data
+ float_format - controls formatting of floating point data
+ custom_format - controls formatting of any column using callable
+ padding_width - number of spaces on either side of column data (only used if
+ left and right paddings are None)
+ left_padding_width - number of spaces on left hand side of column data
+ right_padding_width - number of spaces on right hand side of column data
+ vertical_char - single character string used to draw vertical lines
+ horizontal_char - single character string used to draw horizontal lines
+ horizontal_align_char - single character string used to indicate alignment
+ junction_char - single character string used to draw line junctions
+ top_junction_char - single character string used to draw top line junctions
+ bottom_junction_char -
+ single character string used to draw bottom line junctions
+ right_junction_char - single character string used to draw right line junctions
+ left_junction_char - single character string used to draw left line junctions
+ top_right_junction_char -
+ single character string used to draw top-right line junctions
+ top_left_junction_char -
+ single character string used to draw top-left line junctions
+ bottom_right_junction_char -
+ single character string used to draw bottom-right line junctions
+ bottom_left_junction_char -
+ single character string used to draw bottom-left line junctions
+ sortby - name of field to sort rows by
+ sort_key - sorting key function, applied to data points before sorting
+ reversesort - True or False to sort in descending or ascending order
+        print_empty - if True, stringify just the header for an empty table,
+ if False return an empty string"""
+
+ options = self._get_options(kwargs)
+
+ lines = []
+
+ # Don't think too hard about an empty table
+ # Is this the desired behaviour? Maybe we should still print the header?
+ if self.rowcount == 0 and (not options["print_empty"] or not options["border"]):
+ return ""
+
+ # Get the rows we need to print, taking into account slicing, sorting, etc.
+ rows = self._get_rows(options)
+
+ # Turn all data in all rows into Unicode, formatted as desired
+ formatted_rows = self._format_rows(rows)
+
+ # Compute column widths
+ self._compute_widths(formatted_rows, options)
+ self._hrule = self._stringify_hrule(options)
+
+ # Add title
+ title = options["title"] or self._title
+ if title:
+ lines.append(self._stringify_title(title, options))
+
+ # Add header or top of border
+ if options["header"]:
+ lines.append(self._stringify_header(options))
+ elif options["border"] and options["hrules"] in (ALL, FRAME):
+ lines.append(self._stringify_hrule(options, where="top_"))
+ if title and options["vrules"] in (ALL, FRAME):
+ lines[-1] = (
+ self.left_junction_char + lines[-1][1:-1] + self.right_junction_char
+ )
+
+ # Add rows
+ for row in formatted_rows[:-1]:
+ lines.append(self._stringify_row(row, options, self._hrule))
+ if formatted_rows:
+ lines.append(
+ self._stringify_row(
+ formatted_rows[-1],
+ options,
+ self._stringify_hrule(options, where="bottom_"),
+ )
+ )
+
+ # Add bottom of border
+ if options["border"] and options["hrules"] == FRAME:
+ lines.append(self._stringify_hrule(options, where="bottom_"))
+
+ if "orgmode" in self.__dict__ and self.orgmode is True:
+ tmp = list()
+ for line in lines:
+ tmp.extend(line.split("\n"))
+ lines = ["|" + line[1:-1] + "|" for line in tmp]
+
+ return "\n".join(lines)
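+
+    # Usage sketch: the keyword options listed in the docstring above can be
+    # supplied per call without mutating the table (field names assumed):
+    #
+    #     print(t.get_string(sortby="Population", reversesort=True,
+    #                        fields=["City", "Population"], border=True))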
+
+ def _stringify_hrule(self, options, where=""):
+
+ if not options["border"] and not options["preserve_internal_border"]:
+ return ""
+ lpad, rpad = self._get_padding_widths(options)
+ if options["vrules"] in (ALL, FRAME):
+ bits = [options[where + "left_junction_char"]]
+ else:
+ bits = [options["horizontal_char"]]
+ # For tables with no data or fieldnames
+ if not self._field_names:
+ bits.append(options[where + "right_junction_char"])
+ return "".join(bits)
+ for field, width in zip(self._field_names, self._widths):
+ if options["fields"] and field not in options["fields"]:
+ continue
+
+ line = (width + lpad + rpad) * options["horizontal_char"]
+
+ # If necessary, add column alignment characters (e.g. ":" for Markdown)
+ if self._horizontal_align_char:
+ if self._align[field] in ("l", "c"):
+ line = self._horizontal_align_char + line[1:]
+ if self._align[field] in ("c", "r"):
+ line = line[:-1] + self._horizontal_align_char
+
+ bits.append(line)
+ if options["vrules"] == ALL:
+ bits.append(options[where + "junction_char"])
+ else:
+ bits.append(options["horizontal_char"])
+ if options["vrules"] in (ALL, FRAME):
+ bits.pop()
+ bits.append(options[where + "right_junction_char"])
+
+ if options["preserve_internal_border"] and not options["border"]:
+ bits = bits[1:-1]
+
+ return "".join(bits)
+
+ def _stringify_title(self, title, options):
+
+ lines = []
+ lpad, rpad = self._get_padding_widths(options)
+ if options["border"]:
+ if options["vrules"] == ALL:
+ options["vrules"] = FRAME
+ lines.append(self._stringify_hrule(options, "top_"))
+ options["vrules"] = ALL
+ elif options["vrules"] == FRAME:
+ lines.append(self._stringify_hrule(options, "top_"))
+ bits = []
+ endpoint = (
+ options["vertical_char"]
+ if options["vrules"] in (ALL, FRAME) and options["border"]
+ else " "
+ )
+ bits.append(endpoint)
+ title = " " * lpad + title + " " * rpad
+ bits.append(self._justify(title, len(self._hrule) - 2, "c"))
+ bits.append(endpoint)
+ lines.append("".join(bits))
+ return "\n".join(lines)
+
+ def _stringify_header(self, options):
+
+ bits = []
+ lpad, rpad = self._get_padding_widths(options)
+ if options["border"]:
+ if options["hrules"] in (ALL, FRAME):
+ bits.append(self._stringify_hrule(options, "top_"))
+ if options["title"] and options["vrules"] in (ALL, FRAME):
+ bits[-1] = (
+ self.left_junction_char
+ + bits[-1][1:-1]
+ + self.right_junction_char
+ )
+ bits.append("\n")
+ if options["vrules"] in (ALL, FRAME):
+ bits.append(options["vertical_char"])
+ else:
+ bits.append(" ")
+ # For tables with no data or field names
+ if not self._field_names:
+ if options["vrules"] in (ALL, FRAME):
+ bits.append(options["vertical_char"])
+ else:
+ bits.append(" ")
+ for (field, width) in zip(self._field_names, self._widths):
+ if options["fields"] and field not in options["fields"]:
+ continue
+ if self._header_style == "cap":
+ fieldname = field.capitalize()
+ elif self._header_style == "title":
+ fieldname = field.title()
+ elif self._header_style == "upper":
+ fieldname = field.upper()
+ elif self._header_style == "lower":
+ fieldname = field.lower()
+ else:
+ fieldname = field
+ if _str_block_width(fieldname) > width:
+ fieldname = fieldname[:width]
+ bits.append(
+ " " * lpad
+ + self._justify(fieldname, width, self._align[field])
+ + " " * rpad
+ )
+ if options["border"] or options["preserve_internal_border"]:
+ if options["vrules"] == ALL:
+ bits.append(options["vertical_char"])
+ else:
+ bits.append(" ")
+
+ # If only preserve_internal_border is true, then we just appended
+ # a vertical character at the end when we wanted a space
+ if not options["border"] and options["preserve_internal_border"]:
+ bits.pop()
+ bits.append(" ")
+ # If vrules is FRAME, then we just appended a space at the end
+ # of the last field, when we really want a vertical character
+ if options["border"] and options["vrules"] == FRAME:
+ bits.pop()
+ bits.append(options["vertical_char"])
+ if (options["border"] or options["preserve_internal_border"]) and options[
+ "hrules"
+ ] != NONE:
+ bits.append("\n")
+ bits.append(self._hrule)
+ return "".join(bits)
+
+ def _stringify_row(self, row, options, hrule):
+
+ for (index, field, value, width) in zip(
+ range(0, len(row)), self._field_names, row, self._widths
+ ):
+ # Enforce max widths
+ lines = value.split("\n")
+ new_lines = []
+ for line in lines:
+ if line == "None" and self.none_format.get(field) is not None:
+ line = self.none_format[field]
+ if _str_block_width(line) > width:
+ line = textwrap.fill(line, width)
+ new_lines.append(line)
+ lines = new_lines
+ value = "\n".join(lines)
+ row[index] = value
+
+ row_height = 0
+ for c in row:
+ h = _get_size(c)[1]
+ if h > row_height:
+ row_height = h
+
+ bits = []
+ lpad, rpad = self._get_padding_widths(options)
+ for y in range(0, row_height):
+ bits.append([])
+ if options["border"]:
+ if options["vrules"] in (ALL, FRAME):
+ bits[y].append(self.vertical_char)
+ else:
+ bits[y].append(" ")
+
+ for (field, value, width) in zip(self._field_names, row, self._widths):
+
+ valign = self._valign[field]
+ lines = value.split("\n")
+ d_height = row_height - len(lines)
+ if d_height:
+ if valign == "m":
+ lines = (
+ [""] * int(d_height / 2)
+ + lines
+ + [""] * (d_height - int(d_height / 2))
+ )
+ elif valign == "b":
+ lines = [""] * d_height + lines
+ else:
+ lines = lines + [""] * d_height
+
+ y = 0
+ for line in lines:
+ if options["fields"] and field not in options["fields"]:
+ continue
+
+ bits[y].append(
+ " " * lpad
+ + self._justify(line, width, self._align[field])
+ + " " * rpad
+ )
+ if options["border"] or options["preserve_internal_border"]:
+ if options["vrules"] == ALL:
+ bits[y].append(self.vertical_char)
+ else:
+ bits[y].append(" ")
+ y += 1
+
+ # If only preserve_internal_border is true, then we just appended
+ # a vertical character at the end when we wanted a space
+ if not options["border"] and options["preserve_internal_border"]:
+ bits[-1].pop()
+ bits[-1].append(" ")
+
+ # If vrules is FRAME, then we just appended a space at the end
+ # of the last field, when we really want a vertical character
+ for y in range(0, row_height):
+ if options["border"] and options["vrules"] == FRAME:
+ bits[y].pop()
+ bits[y].append(options["vertical_char"])
+
+ if options["border"] and options["hrules"] == ALL:
+ bits[row_height - 1].append("\n")
+ bits[row_height - 1].append(hrule)
+
+ for y in range(0, row_height):
+ bits[y] = "".join(bits[y])
+
+ return "\n".join(bits)
+
+ def paginate(self, page_length: int = 58, line_break: str = "\f", **kwargs):
+
+ pages = []
+ kwargs["start"] = kwargs.get("start", 0)
+ true_end = kwargs.get("end", self.rowcount)
+ while True:
+ kwargs["end"] = min(kwargs["start"] + page_length, true_end)
+ pages.append(self.get_string(**kwargs))
+ if kwargs["end"] == true_end:
+ break
+ kwargs["start"] += page_length
+ return line_break.join(pages)
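+
+    # Usage sketch: paginate() slices the table into fixed-size pages joined
+    # with a form feed by default (data assumed for illustration):
+    #
+    #     pages = t.paginate(page_length=20)   # 20 data rows per page
+    #     print(pages.split("\f")[0])          # first page only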
+
+ ##############################
+ # CSV STRING METHODS #
+ ##############################
+ def get_csv_string(self, **kwargs) -> str:
+
+ """Return string representation of CSV formatted table in the current state
+
+ Keyword arguments are first interpreted as table formatting options, and
+ then any unused keyword arguments are passed to csv.writer(). For
+ example, get_csv_string(header=False, delimiter='\t') would use
+ header as a PrettyTable formatting option (skip the header row) and
+ delimiter as a csv.writer keyword argument.
+ """
+
+ options = self._get_options(kwargs)
+ csv_options = {
+ key: value for key, value in kwargs.items() if key not in options
+ }
+ csv_buffer = io.StringIO()
+ csv_writer = csv.writer(csv_buffer, **csv_options)
+
+ if options.get("header"):
+ csv_writer.writerow(self._field_names)
+ for row in self._get_rows(options):
+ csv_writer.writerow(row)
+
+ return csv_buffer.getvalue()
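+
+    # Usage sketch: as described above, keyword arguments that are not table
+    # options fall through to csv.writer() (illustrative only):
+    #
+    #     csv_text = t.get_csv_string(header=False, delimiter=";")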
+
+ ##############################
+ # JSON STRING METHODS #
+ ##############################
+ def get_json_string(self, **kwargs) -> str:
+
+ """Return string representation of JSON formatted table in the current state
+
+ Keyword arguments are first interpreted as table formatting options, and
+ then any unused keyword arguments are passed to json.dumps(). For
+ example, get_json_string(header=False, indent=2) would use header as
+ a PrettyTable formatting option (skip the header row) and indent as a
+ json.dumps keyword argument.
+ """
+
+ options = self._get_options(kwargs)
+ json_options: Any = dict(indent=4, separators=(",", ": "), sort_keys=True)
+ json_options.update(
+ {key: value for key, value in kwargs.items() if key not in options}
+ )
+ objects = []
+
+ if options.get("header"):
+ objects.append(self.field_names)
+ for row in self._get_rows(options):
+ objects.append(dict(zip(self._field_names, row)))
+
+ return json.dumps(objects, **json_options)
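+
+    # Usage sketch: the same splitting rule applies here, with leftover keyword
+    # arguments passed to json.dumps() (field name assumed):
+    #
+    #     json_text = t.get_json_string(fields=["City"], indent=2, sort_keys=False)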
+
+ ##############################
+ # HTML STRING METHODS #
+ ##############################
+
+ def get_html_string(self, **kwargs) -> str:
+ """Return string representation of HTML formatted version of table in current
+ state.
+
+ Arguments:
+
+ title - optional table title
+ start - index of first data row to include in output
+ end - index of last data row to include in output PLUS ONE (list slice style)
+ fields - names of fields (columns) to include
+ header - print a header showing field names (True or False)
+ border - print a border around the table (True or False)
+ preserve_internal_border - print a border inside the table even if
+ border is disabled (True or False)
+ hrules - controls printing of horizontal rules after rows.
+ Allowed values: ALL, FRAME, HEADER, NONE
+ vrules - controls printing of vertical rules between columns.
+ Allowed values: FRAME, ALL, NONE
+ int_format - controls formatting of integer data
+ float_format - controls formatting of floating point data
+ custom_format - controls formatting of any column using callable
+ padding_width - number of spaces on either side of column data (only used if
+ left and right paddings are None)
+ left_padding_width - number of spaces on left hand side of column data
+ right_padding_width - number of spaces on right hand side of column data
+ sortby - name of field to sort rows by
+ sort_key - sorting key function, applied to data points before sorting
+ attributes - dictionary of name/value pairs to include as HTML attributes in the
+ <table> tag
+ format - Controls whether or not HTML tables are formatted to match
+ styling options (True or False)
+ xhtml - print <br/> tags if True, <br> tags if False"""
+
+ options = self._get_options(kwargs)
+
+ if options["format"]:
+ string = self._get_formatted_html_string(options)
+ else:
+ string = self._get_simple_html_string(options)
+
+ return string
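+
+    # Usage sketch for the HTML options documented above (attribute values are
+    # assumed for illustration):
+    #
+    #     html = t.get_html_string(format=True, xhtml=True,
+    #                              attributes={"class": "stats"})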
+
+ def _get_simple_html_string(self, options):
+
+ lines = []
+ if options["xhtml"]:
+ linebreak = "<br/>"
+ else:
+ linebreak = "<br>"
+
+ open_tag = ["<table"]
+ if options["attributes"]:
+ for attr_name in options["attributes"]:
+ open_tag.append(f' {attr_name}="{options["attributes"][attr_name]}"')
+ open_tag.append(">")
+ lines.append("".join(open_tag))
+
+ # Title
+ title = options["title"] or self._title
+ if title:
+ lines.append(f" <caption>{title}</caption>")
+
+ # Headers
+ if options["header"]:
+ lines.append(" <thead>")
+ lines.append(" <tr>")
+ for field in self._field_names:
+ if options["fields"] and field not in options["fields"]:
+ continue
+ lines.append(
+ " <th>%s</th>" % escape(field).replace("\n", linebreak)
+ )
+ lines.append(" </tr>")
+ lines.append(" </thead>")
+
+ # Data
+ lines.append(" <tbody>")
+ rows = self._get_rows(options)
+ formatted_rows = self._format_rows(rows)
+ for row in formatted_rows:
+ lines.append(" <tr>")
+ for field, datum in zip(self._field_names, row):
+ if options["fields"] and field not in options["fields"]:
+ continue
+ lines.append(
+ " <td>%s</td>" % escape(datum).replace("\n", linebreak)
+ )
+ lines.append(" </tr>")
+ lines.append(" </tbody>")
+ lines.append("</table>")
+
+ return "\n".join(lines)
+
+ def _get_formatted_html_string(self, options):
+
+ lines = []
+ lpad, rpad = self._get_padding_widths(options)
+ if options["xhtml"]:
+ linebreak = "<br/>"
+ else:
+ linebreak = "<br>"
+
+ open_tag = ["<table"]
+ if options["border"]:
+ if options["hrules"] == ALL and options["vrules"] == ALL:
+ open_tag.append(' frame="box" rules="all"')
+ elif options["hrules"] == FRAME and options["vrules"] == FRAME:
+ open_tag.append(' frame="box"')
+ elif options["hrules"] == FRAME and options["vrules"] == ALL:
+ open_tag.append(' frame="box" rules="cols"')
+ elif options["hrules"] == FRAME:
+ open_tag.append(' frame="hsides"')
+ elif options["hrules"] == ALL:
+ open_tag.append(' frame="hsides" rules="rows"')
+ elif options["vrules"] == FRAME:
+ open_tag.append(' frame="vsides"')
+ elif options["vrules"] == ALL:
+ open_tag.append(' frame="vsides" rules="cols"')
+ if not options["border"] and options["preserve_internal_border"]:
+ open_tag.append(' rules="cols"')
+ if options["attributes"]:
+ for attr_name in options["attributes"]:
+ open_tag.append(f' {attr_name}="{options["attributes"][attr_name]}"')
+ open_tag.append(">")
+ lines.append("".join(open_tag))
+
+ # Title
+ title = options["title"] or self._title
+ if title:
+ lines.append(f" <caption>{title}</caption>")
+
+ # Headers
+ if options["header"]:
+ lines.append(" <thead>")
+ lines.append(" <tr>")
+ for field in self._field_names:
+ if options["fields"] and field not in options["fields"]:
+ continue
+ lines.append(
+ ' <th style="padding-left: %dem; padding-right: %dem; text-align: center">%s</th>' # noqa: E501
+ % (lpad, rpad, escape(field).replace("\n", linebreak))
+ )
+ lines.append(" </tr>")
+ lines.append(" </thead>")
+
+ # Data
+ lines.append(" <tbody>")
+ rows = self._get_rows(options)
+ formatted_rows = self._format_rows(rows)
+ aligns = []
+ valigns = []
+ for field in self._field_names:
+ aligns.append(
+ {"l": "left", "r": "right", "c": "center"}[self._align[field]]
+ )
+ valigns.append(
+ {"t": "top", "m": "middle", "b": "bottom"}[self._valign[field]]
+ )
+ for row in formatted_rows:
+ lines.append(" <tr>")
+ for field, datum, align, valign in zip(
+ self._field_names, row, aligns, valigns
+ ):
+ if options["fields"] and field not in options["fields"]:
+ continue
+ lines.append(
+ ' <td style="padding-left: %dem; padding-right: %dem; text-align: %s; vertical-align: %s">%s</td>' # noqa: E501
+ % (
+ lpad,
+ rpad,
+ align,
+ valign,
+ escape(datum).replace("\n", linebreak),
+ )
+ )
+ lines.append(" </tr>")
+ lines.append(" </tbody>")
+ lines.append("</table>")
+
+ return "\n".join(lines)
+
+ ##############################
+ # LATEX STRING METHODS #
+ ##############################
+
+ def get_latex_string(self, **kwargs) -> str:
+        """Return string representation of LaTeX formatted version of table in current
+ state.
+
+ Arguments:
+
+ start - index of first data row to include in output
+ end - index of last data row to include in output PLUS ONE (list slice style)
+ fields - names of fields (columns) to include
+ header - print a header showing field names (True or False)
+ border - print a border around the table (True or False)
+ preserve_internal_border - print a border inside the table even if
+ border is disabled (True or False)
+ hrules - controls printing of horizontal rules after rows.
+ Allowed values: ALL, FRAME, HEADER, NONE
+ vrules - controls printing of vertical rules between columns.
+ Allowed values: FRAME, ALL, NONE
+ int_format - controls formatting of integer data
+ float_format - controls formatting of floating point data
+ sortby - name of field to sort rows by
+ sort_key - sorting key function, applied to data points before sorting
+        format - Controls whether or not LaTeX tables are formatted to match
+ styling options (True or False)
+ """
+ options = self._get_options(kwargs)
+
+ if options["format"]:
+ string = self._get_formatted_latex_string(options)
+ else:
+ string = self._get_simple_latex_string(options)
+ return string
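+
+    # Usage sketch: LaTeX output takes the same option handling; ALL is one of
+    # the module-level rule constants used throughout this file:
+    #
+    #     latex = t.get_latex_string(format=True, hrules=ALL)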
+
+ def _get_simple_latex_string(self, options):
+ lines = []
+
+ wanted_fields = []
+ if options["fields"]:
+ wanted_fields = [
+ field for field in self._field_names if field in options["fields"]
+ ]
+ else:
+ wanted_fields = self._field_names
+
+ alignments = "".join([self._align[field] for field in wanted_fields])
+
+ begin_cmd = "\\begin{tabular}{%s}" % alignments
+ lines.append(begin_cmd)
+
+ # Headers
+ if options["header"]:
+ lines.append(" & ".join(wanted_fields) + " \\\\")
+
+ # Data
+ rows = self._get_rows(options)
+ formatted_rows = self._format_rows(rows)
+ for row in formatted_rows:
+ wanted_data = [
+ d for f, d in zip(self._field_names, row) if f in wanted_fields
+ ]
+ lines.append(" & ".join(wanted_data) + " \\\\")
+
+ lines.append("\\end{tabular}")
+
+ return "\r\n".join(lines)
+
+ def _get_formatted_latex_string(self, options):
+ lines = []
+
+ wanted_fields = []
+ if options["fields"]:
+ wanted_fields = [
+ field for field in self._field_names if field in options["fields"]
+ ]
+ else:
+ wanted_fields = self._field_names
+
+ wanted_alignments = [self._align[field] for field in wanted_fields]
+ if options["border"] and options["vrules"] == ALL:
+ alignment_str = "|".join(wanted_alignments)
+ elif not options["border"] and options["preserve_internal_border"]:
+ alignment_str = "|".join(wanted_alignments)
+ else:
+ alignment_str = "".join(wanted_alignments)
+
+ if options["border"] and options["vrules"] in [ALL, FRAME]:
+ alignment_str = "|" + alignment_str + "|"
+
+ begin_cmd = "\\begin{tabular}{%s}" % alignment_str
+ lines.append(begin_cmd)
+ if options["border"] and options["hrules"] in [ALL, FRAME]:
+ lines.append("\\hline")
+
+ # Headers
+ if options["header"]:
+ lines.append(" & ".join(wanted_fields) + " \\\\")
+ if (options["border"] or options["preserve_internal_border"]) and options[
+ "hrules"
+ ] in [ALL, HEADER]:
+ lines.append("\\hline")
+
+ # Data
+ rows = self._get_rows(options)
+ formatted_rows = self._format_rows(rows)
+ for row in formatted_rows:
+ wanted_data = [
+ d for f, d in zip(self._field_names, row) if f in wanted_fields
+ ]
+ lines.append(" & ".join(wanted_data) + " \\\\")
+ if options["border"] and options["hrules"] == ALL:
+ lines.append("\\hline")
+
+ if options["border"] and options["hrules"] == FRAME:
+ lines.append("\\hline")
+
+ lines.append("\\end{tabular}")
+
+ return "\r\n".join(lines)
+
+
+##############################
+# UNICODE WIDTH FUNCTION #
+##############################
+
+
+def _str_block_width(val):
+ return wcwidth.wcswidth(_re.sub("", val))
+
+
+##############################
+# TABLE FACTORIES #
+##############################
+
+
+def from_csv(fp, field_names: Any | None = None, **kwargs):
+ fmtparams = {}
+ for param in [
+ "delimiter",
+ "doublequote",
+ "escapechar",
+ "lineterminator",
+ "quotechar",
+ "quoting",
+ "skipinitialspace",
+ "strict",
+ ]:
+ if param in kwargs:
+ fmtparams[param] = kwargs.pop(param)
+ if fmtparams:
+ reader = csv.reader(fp, **fmtparams)
+ else:
+ dialect = csv.Sniffer().sniff(fp.read(1024))
+ fp.seek(0)
+ reader = csv.reader(fp, dialect)
+
+ table = PrettyTable(**kwargs)
+ if field_names:
+ table.field_names = field_names
+ else:
+ table.field_names = [x.strip() for x in next(reader)]
+
+ for row in reader:
+ table.add_row([x.strip() for x in row])
+
+ return table
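+
+# Usage sketch: from_csv() accepts any open text-mode file object; the file
+# name below is assumed for illustration:
+#
+#     with open("cities.csv", newline="") as fp:
+#         table = from_csv(fp, delimiter=",")
+#     print(table)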
+
+
+def from_db_cursor(cursor, **kwargs):
+ if cursor.description:
+ table = PrettyTable(**kwargs)
+ table.field_names = [col[0] for col in cursor.description]
+ for row in cursor.fetchall():
+ table.add_row(row)
+ return table
+
+
+def from_json(json_string, **kwargs):
+ table = PrettyTable(**kwargs)
+ objects = json.loads(json_string)
+ table.field_names = objects[0]
+ for obj in objects[1:]:
+ row = [obj[key] for key in table.field_names]
+ table.add_row(row)
+ return table
+
+
+class TableHandler(HTMLParser):
+ def __init__(self, **kwargs) -> None:
+ HTMLParser.__init__(self)
+ self.kwargs = kwargs
+ self.tables: list[list] = []
+ self.last_row: list[str] = []
+ self.rows: list[Any] = []
+ self.max_row_width = 0
+ self.active = None
+ self.last_content = ""
+ self.is_last_row_header = False
+ self.colspan = 0
+
+ def handle_starttag(self, tag, attrs) -> None:
+ self.active = tag
+ if tag == "th":
+ self.is_last_row_header = True
+ for (key, value) in attrs:
+ if key == "colspan":
+ self.colspan = int(value)
+
+ def handle_endtag(self, tag) -> None:
+ if tag in ["th", "td"]:
+ stripped_content = self.last_content.strip()
+ self.last_row.append(stripped_content)
+ if self.colspan:
+ for i in range(1, self.colspan):
+ self.last_row.append("")
+ self.colspan = 0
+
+ if tag == "tr":
+ self.rows.append((self.last_row, self.is_last_row_header))
+ self.max_row_width = max(self.max_row_width, len(self.last_row))
+ self.last_row = []
+ self.is_last_row_header = False
+ if tag == "table":
+ table = self.generate_table(self.rows)
+ self.tables.append(table)
+ self.rows = []
+ self.last_content = " "
+ self.active = None
+
+ def handle_data(self, data) -> None:
+ self.last_content += data
+
+ def generate_table(self, rows):
+ """
+        Generates a PrettyTable object from a list of rows.
+ """
+ table = PrettyTable(**self.kwargs)
+ for row in self.rows:
+ if len(row[0]) < self.max_row_width:
+ appends = self.max_row_width - len(row[0])
+ for i in range(1, appends):
+ row[0].append("-")
+
+ if row[1]:
+ self.make_fields_unique(row[0])
+ table.field_names = row[0]
+ else:
+ table.add_row(row[0])
+ return table
+
+ def make_fields_unique(self, fields) -> None:
+ """
+        Iterates over the row and makes each field unique.
+ """
+ for i in range(0, len(fields)):
+ for j in range(i + 1, len(fields)):
+ if fields[i] == fields[j]:
+ fields[j] += "'"
+
+
+def from_html(html_code, **kwargs):
+ """
+ Generates a list of PrettyTables from a string of HTML code. Each <table> in
+ the HTML becomes one PrettyTable object.
+ """
+
+ parser = TableHandler(**kwargs)
+ parser.feed(html_code)
+ return parser.tables
+
+
+def from_html_one(html_code, **kwargs):
+ """
+    Generates a single PrettyTable from a string of HTML code which
+    contains only a single <table>.
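+
+    A minimal, hypothetical snippet:
+
+    >>> html = "<table><tr><th>a</th></tr><tr><td>1</td></tr></table>"
+    >>> table = from_html_one(html)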
+ """
+
+ tables = from_html(html_code, **kwargs)
+ try:
+ assert len(tables) == 1
+ except AssertionError:
+ raise ValueError(
+ "More than one <table> in provided HTML code. Use from_html instead."
+ )
+ return tables[0]
diff --git a/lib/prettytable/py.typed b/lib/prettytable/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/prettytable/py.typed
diff --git a/lib/psutil-5.9.4.dist-info/INSTALLER b/lib/psutil-5.9.4.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/psutil-5.9.4.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/psutil-5.9.4.dist-info/LICENSE b/lib/psutil-5.9.4.dist-info/LICENSE
new file mode 100644
index 0000000..0bf4a7f
--- /dev/null
+++ b/lib/psutil-5.9.4.dist-info/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2009, Jay Loden, Dave Daeschler, Giampaolo Rodola'
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name of the psutil authors nor the names of its contributors
+ may be used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/lib/psutil-5.9.4.dist-info/METADATA b/lib/psutil-5.9.4.dist-info/METADATA
new file mode 100644
index 0000000..528b4cd
--- /dev/null
+++ b/lib/psutil-5.9.4.dist-info/METADATA
@@ -0,0 +1,526 @@
+Metadata-Version: 2.1
+Name: psutil
+Version: 5.9.4
+Summary: Cross-platform lib for process and system monitoring in Python.
+Home-page: https://github.com/giampaolo/psutil
+Author: Giampaolo Rodola
+Author-email: g.rodola@gmail.com
+License: BSD-3-Clause
+Keywords: ps,top,kill,free,lsof,netstat,nice,tty,ionice,uptime,taskmgr,process,df,iotop,iostat,ifconfig,taskset,who,pidof,pmap,smem,pstree,monitoring,ulimit,prlimit,smem,performance,metrics,agent,observability
+Platform: Platform Independent
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Environment :: Console
+Classifier: Environment :: Win32 (MS Windows)
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Information Technology
+Classifier: Intended Audience :: System Administrators
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Operating System :: Microsoft :: Windows :: Windows 10
+Classifier: Operating System :: Microsoft :: Windows :: Windows 7
+Classifier: Operating System :: Microsoft :: Windows :: Windows 8
+Classifier: Operating System :: Microsoft :: Windows :: Windows 8.1
+Classifier: Operating System :: Microsoft :: Windows :: Windows Server 2003
+Classifier: Operating System :: Microsoft :: Windows :: Windows Server 2008
+Classifier: Operating System :: Microsoft :: Windows :: Windows Vista
+Classifier: Operating System :: Microsoft
+Classifier: Operating System :: OS Independent
+Classifier: Operating System :: POSIX :: AIX
+Classifier: Operating System :: POSIX :: BSD :: FreeBSD
+Classifier: Operating System :: POSIX :: BSD :: NetBSD
+Classifier: Operating System :: POSIX :: BSD :: OpenBSD
+Classifier: Operating System :: POSIX :: BSD
+Classifier: Operating System :: POSIX :: Linux
+Classifier: Operating System :: POSIX :: SunOS/Solaris
+Classifier: Operating System :: POSIX
+Classifier: Programming Language :: C
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Programming Language :: Python
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Topic :: System :: Benchmark
+Classifier: Topic :: System :: Hardware :: Hardware Drivers
+Classifier: Topic :: System :: Hardware
+Classifier: Topic :: System :: Monitoring
+Classifier: Topic :: System :: Networking :: Monitoring :: Hardware Watchdog
+Classifier: Topic :: System :: Networking :: Monitoring
+Classifier: Topic :: System :: Networking
+Classifier: Topic :: System :: Operating System
+Classifier: Topic :: System :: Systems Administration
+Classifier: Topic :: Utilities
+Requires-Python: >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*
+Description-Content-Type: text/x-rst
+License-File: LICENSE
+Provides-Extra: test
+Requires-Dist: ipaddress ; (python_version < "3.0") and extra == 'test'
+Requires-Dist: mock ; (python_version < "3.0") and extra == 'test'
+Requires-Dist: enum34 ; (python_version <= "3.4") and extra == 'test'
+Requires-Dist: pywin32 ; (sys_platform == "win32") and extra == 'test'
+Requires-Dist: wmi ; (sys_platform == "win32") and extra == 'test'
+
+| |downloads| |stars| |forks| |contributors| |coverage|
+| |version| |py-versions| |packages| |license|
+| |github-actions| |appveyor| |doc| |twitter| |tidelift|
+
+.. |downloads| image:: https://img.shields.io/pypi/dm/psutil.svg
+ :target: https://pepy.tech/project/psutil
+ :alt: Downloads
+
+.. |stars| image:: https://img.shields.io/github/stars/giampaolo/psutil.svg
+ :target: https://github.com/giampaolo/psutil/stargazers
+ :alt: Github stars
+
+.. |forks| image:: https://img.shields.io/github/forks/giampaolo/psutil.svg
+ :target: https://github.com/giampaolo/psutil/network/members
+ :alt: Github forks
+
+.. |contributors| image:: https://img.shields.io/github/contributors/giampaolo/psutil.svg
+ :target: https://github.com/giampaolo/psutil/graphs/contributors
+ :alt: Contributors
+
+.. |github-actions| image:: https://img.shields.io/github/workflow/status/giampaolo/psutil/CI?label=Linux%2C%20macOS%2C%20FreeBSD
+ :target: https://github.com/giampaolo/psutil/actions?query=workflow%3Abuild
+ :alt: Linux, macOS, Windows tests
+
+.. |appveyor| image:: https://img.shields.io/appveyor/ci/giampaolo/psutil/master.svg?maxAge=3600&label=Windows
+ :target: https://ci.appveyor.com/project/giampaolo/psutil
+ :alt: Windows tests (Appveyor)
+
+.. |coverage| image:: https://coveralls.io/repos/github/giampaolo/psutil/badge.svg?branch=master
+ :target: https://coveralls.io/github/giampaolo/psutil?branch=master
+ :alt: Test coverage (coverall.io)
+
+.. |doc| image:: https://readthedocs.org/projects/psutil/badge/?version=latest
+ :target: https://psutil.readthedocs.io/en/latest/
+ :alt: Documentation Status
+
+.. |version| image:: https://img.shields.io/pypi/v/psutil.svg?label=pypi
+ :target: https://pypi.org/project/psutil
+ :alt: Latest version
+
+.. |py-versions| image:: https://img.shields.io/pypi/pyversions/psutil.svg
+ :alt: Supported Python versions
+
+.. |packages| image:: https://repology.org/badge/tiny-repos/python:psutil.svg
+ :target: https://repology.org/metapackage/python:psutil/versions
+ :alt: Binary packages
+
+.. |license| image:: https://img.shields.io/pypi/l/psutil.svg
+ :target: https://github.com/giampaolo/psutil/blob/master/LICENSE
+ :alt: License
+
+.. |twitter| image:: https://img.shields.io/twitter/follow/grodola.svg?label=follow&style=flat&logo=twitter&logoColor=4FADFF
+ :target: https://twitter.com/grodola
+ :alt: Twitter Follow
+
+.. |tidelift| image:: https://tidelift.com/badges/github/giampaolo/psutil?style=flat
+ :target: https://tidelift.com/subscription/pkg/pypi-psutil?utm_source=pypi-psutil&utm_medium=referral&utm_campaign=readme
+ :alt: Tidelift
+
+-----
+
+Quick links
+===========
+
+- `Home page <https://github.com/giampaolo/psutil>`_
+- `Install <https://github.com/giampaolo/psutil/blob/master/INSTALL.rst>`_
+- `Documentation <http://psutil.readthedocs.io>`_
+- `Download <https://pypi.org/project/psutil/#files>`_
+- `Forum <http://groups.google.com/group/psutil/topics>`_
+- `StackOverflow <https://stackoverflow.com/questions/tagged/psutil>`_
+- `Blog <https://gmpy.dev/tags/psutil>`_
+- `What's new <https://github.com/giampaolo/psutil/blob/master/HISTORY.rst>`_
+
+
+Summary
+=======
+
+psutil (process and system utilities) is a cross-platform library for
+retrieving information on **running processes** and **system utilization**
+(CPU, memory, disks, network, sensors) in Python.
+It is useful mainly for **system monitoring**, **profiling and limiting process
+resources** and **management of running processes**.
+It implements many functionalities offered by classic UNIX command line tools
+such as *ps, top, iotop, lsof, netstat, ifconfig, free* and others.
+psutil currently supports the following platforms:
+
+- **Linux**
+- **Windows**
+- **macOS**
+- **FreeBSD, OpenBSD**, **NetBSD**
+- **Sun Solaris**
+- **AIX**
+
+Supported Python versions are **2.7**, **3.4+** and
+`PyPy <http://pypy.org/>`__.
+
+Funding
+=======
+
+While psutil is free software and will always be, the project would benefit
+immensely from some funding.
+Keeping up with bug reports and maintenance has become hard to sustain for
+me alone in terms of time.
+If you're a company that's making significant use of psutil you can consider
+becoming a sponsor via `GitHub Sponsors <https://github.com/sponsors/giampaolo>`__,
+`Open Collective <https://opencollective.com/psutil>`__ or
+`PayPal <https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=A9ZS7PKKRM3S8>`__
+and have your logo displayed here and in the psutil `doc <https://psutil.readthedocs.io>`__.
+
+Sponsors
+========
+
+.. image:: https://github.com/giampaolo/psutil/raw/master/docs/_static/tidelift-logo.png
+ :width: 200
+   :alt: Tidelift logo
+
+`Add your logo <https://github.com/sponsors/giampaolo>`__.
+
+Example usages
+==============
+
+This represents pretty much the whole psutil API.
+
+CPU
+---
+
+.. code-block:: python
+
+ >>> import psutil
+ >>>
+ >>> psutil.cpu_times()
+    scputimes(user=3961.46, nice=169.729, system=2150.659, idle=16900.540, iowait=629.59, irq=0.0, softirq=19.42, steal=0.0, guest=0, guest_nice=0.0)
+ >>>
+ >>> for x in range(3):
+ ... psutil.cpu_percent(interval=1)
+ ...
+ 4.0
+ 5.9
+ 3.8
+ >>>
+ >>> for x in range(3):
+ ... psutil.cpu_percent(interval=1, percpu=True)
+ ...
+ [4.0, 6.9, 3.7, 9.2]
+ [7.0, 8.5, 2.4, 2.1]
+ [1.2, 9.0, 9.9, 7.2]
+ >>>
+ >>> for x in range(3):
+ ... psutil.cpu_times_percent(interval=1, percpu=False)
+ ...
+ scputimes(user=1.5, nice=0.0, system=0.5, idle=96.5, iowait=1.5, irq=0.0, softirq=0.0, steal=0.0, guest=0.0, guest_nice=0.0)
+ scputimes(user=1.0, nice=0.0, system=0.0, idle=99.0, iowait=0.0, irq=0.0, softirq=0.0, steal=0.0, guest=0.0, guest_nice=0.0)
+ scputimes(user=2.0, nice=0.0, system=0.0, idle=98.0, iowait=0.0, irq=0.0, softirq=0.0, steal=0.0, guest=0.0, guest_nice=0.0)
+ >>>
+ >>> psutil.cpu_count()
+ 4
+ >>> psutil.cpu_count(logical=False)
+ 2
+ >>>
+ >>> psutil.cpu_stats()
+ scpustats(ctx_switches=20455687, interrupts=6598984, soft_interrupts=2134212, syscalls=0)
+ >>>
+ >>> psutil.cpu_freq()
+ scpufreq(current=931.42925, min=800.0, max=3500.0)
+ >>>
+ >>> psutil.getloadavg() # also on Windows (emulated)
+ (3.14, 3.89, 4.67)
+
+Memory
+------
+
+.. code-block:: python
+
+ >>> psutil.virtual_memory()
+ svmem(total=10367352832, available=6472179712, percent=37.6, used=8186245120, free=2181107712, active=4748992512, inactive=2758115328, buffers=790724608, cached=3500347392, shared=787554304)
+ >>> psutil.swap_memory()
+ sswap(total=2097147904, used=296128512, free=1801019392, percent=14.1, sin=304193536, sout=677842944)
+ >>>
+
+Disks
+-----
+
+.. code-block:: python
+
+ >>> psutil.disk_partitions()
+ [sdiskpart(device='/dev/sda1', mountpoint='/', fstype='ext4', opts='rw,nosuid', maxfile=255, maxpath=4096),
+ sdiskpart(device='/dev/sda2', mountpoint='/home', fstype='ext', opts='rw', maxfile=255, maxpath=4096)]
+ >>>
+ >>> psutil.disk_usage('/')
+ sdiskusage(total=21378641920, used=4809781248, free=15482871808, percent=22.5)
+ >>>
+ >>> psutil.disk_io_counters(perdisk=False)
+ sdiskio(read_count=719566, write_count=1082197, read_bytes=18626220032, write_bytes=24081764352, read_time=5023392, write_time=63199568, read_merged_count=619166, write_merged_count=812396, busy_time=4523412)
+ >>>
+
+Network
+-------
+
+.. code-block:: python
+
+ >>> psutil.net_io_counters(pernic=True)
+ {'eth0': netio(bytes_sent=485291293, bytes_recv=6004858642, packets_sent=3251564, packets_recv=4787798, errin=0, errout=0, dropin=0, dropout=0),
+ 'lo': netio(bytes_sent=2838627, bytes_recv=2838627, packets_sent=30567, packets_recv=30567, errin=0, errout=0, dropin=0, dropout=0)}
+ >>>
+ >>> psutil.net_connections(kind='tcp')
+ [sconn(fd=115, family=<AddressFamily.AF_INET: 2>, type=<SocketType.SOCK_STREAM: 1>, laddr=addr(ip='10.0.0.1', port=48776), raddr=addr(ip='93.186.135.91', port=80), status='ESTABLISHED', pid=1254),
+ sconn(fd=117, family=<AddressFamily.AF_INET: 2>, type=<SocketType.SOCK_STREAM: 1>, laddr=addr(ip='10.0.0.1', port=43761), raddr=addr(ip='72.14.234.100', port=80), status='CLOSING', pid=2987),
+ ...]
+ >>>
+ >>> psutil.net_if_addrs()
+ {'lo': [snicaddr(family=<AddressFamily.AF_INET: 2>, address='127.0.0.1', netmask='255.0.0.0', broadcast='127.0.0.1', ptp=None),
+ snicaddr(family=<AddressFamily.AF_INET6: 10>, address='::1', netmask='ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff', broadcast=None, ptp=None),
+ snicaddr(family=<AddressFamily.AF_LINK: 17>, address='00:00:00:00:00:00', netmask=None, broadcast='00:00:00:00:00:00', ptp=None)],
+ 'wlan0': [snicaddr(family=<AddressFamily.AF_INET: 2>, address='192.168.1.3', netmask='255.255.255.0', broadcast='192.168.1.255', ptp=None),
+ snicaddr(family=<AddressFamily.AF_INET6: 10>, address='fe80::c685:8ff:fe45:641%wlan0', netmask='ffff:ffff:ffff:ffff::', broadcast=None, ptp=None),
+ snicaddr(family=<AddressFamily.AF_LINK: 17>, address='c4:85:08:45:06:41', netmask=None, broadcast='ff:ff:ff:ff:ff:ff', ptp=None)]}
+ >>>
+ >>> psutil.net_if_stats()
+ {'lo': snicstats(isup=True, duplex=<NicDuplex.NIC_DUPLEX_UNKNOWN: 0>, speed=0, mtu=65536, flags='up,loopback,running'),
+ 'wlan0': snicstats(isup=True, duplex=<NicDuplex.NIC_DUPLEX_FULL: 2>, speed=100, mtu=1500, flags='up,broadcast,running,multicast')}
+ >>>
+
+Sensors
+-------
+
+.. code-block:: python
+
+ >>> import psutil
+ >>> psutil.sensors_temperatures()
+ {'acpitz': [shwtemp(label='', current=47.0, high=103.0, critical=103.0)],
+ 'asus': [shwtemp(label='', current=47.0, high=None, critical=None)],
+ 'coretemp': [shwtemp(label='Physical id 0', current=52.0, high=100.0, critical=100.0),
+ shwtemp(label='Core 0', current=45.0, high=100.0, critical=100.0)]}
+ >>>
+ >>> psutil.sensors_fans()
+ {'asus': [sfan(label='cpu_fan', current=3200)]}
+ >>>
+ >>> psutil.sensors_battery()
+ sbattery(percent=93, secsleft=16628, power_plugged=False)
+ >>>
+
+Other system info
+-----------------
+
+.. code-block:: python
+
+ >>> import psutil
+ >>> psutil.users()
+ [suser(name='giampaolo', terminal='pts/2', host='localhost', started=1340737536.0, pid=1352),
+ suser(name='giampaolo', terminal='pts/3', host='localhost', started=1340737792.0, pid=1788)]
+ >>>
+ >>> psutil.boot_time()
+ 1365519115.0
+ >>>
+
+Process management
+------------------
+
+.. code-block:: python
+
+ >>> import psutil
+ >>> psutil.pids()
+ [1, 2, 3, 4, 5, 6, 7, 46, 48, 50, 51, 178, 182, 222, 223, 224, 268, 1215,
+ 1216, 1220, 1221, 1243, 1244, 1301, 1601, 2237, 2355, 2637, 2774, 3932,
+ 4176, 4177, 4185, 4187, 4189, 4225, 4243, 4245, 4263, 4282, 4306, 4311,
+ 4312, 4313, 4314, 4337, 4339, 4357, 4358, 4363, 4383, 4395, 4408, 4433,
+ 4443, 4445, 4446, 5167, 5234, 5235, 5252, 5318, 5424, 5644, 6987, 7054,
+ 7055, 7071]
+ >>>
+ >>> p = psutil.Process(7055)
+ >>> p
+ psutil.Process(pid=7055, name='python3', status='running', started='09:04:44')
+ >>> p.name()
+ 'python3'
+ >>> p.exe()
+ '/usr/bin/python3'
+ >>> p.cwd()
+ '/home/giampaolo'
+ >>> p.cmdline()
+ ['/usr/bin/python', 'main.py']
+ >>>
+ >>> p.pid
+ 7055
+ >>> p.ppid()
+ 7054
+ >>> p.children(recursive=True)
+ [psutil.Process(pid=29835, name='python3', status='sleeping', started='11:45:38'),
+ psutil.Process(pid=29836, name='python3', status='waking', started='11:43:39')]
+ >>>
+ >>> p.parent()
+ psutil.Process(pid=4699, name='bash', status='sleeping', started='09:06:44')
+ >>> p.parents()
+ [psutil.Process(pid=4699, name='bash', started='09:06:44'),
+ psutil.Process(pid=4689, name='gnome-terminal-server', status='sleeping', started='0:06:44'),
+ psutil.Process(pid=1, name='systemd', status='sleeping', started='05:56:55')]
+ >>>
+ >>> p.status()
+ 'running'
+ >>> p.username()
+ 'giampaolo'
+ >>> p.create_time()
+ 1267551141.5019531
+ >>> p.terminal()
+ '/dev/pts/0'
+ >>>
+ >>> p.uids()
+ puids(real=1000, effective=1000, saved=1000)
+ >>> p.gids()
+ pgids(real=1000, effective=1000, saved=1000)
+ >>>
+ >>> p.cpu_times()
+ pcputimes(user=1.02, system=0.31, children_user=0.32, children_system=0.1, iowait=0.0)
+ >>> p.cpu_percent(interval=1.0)
+ 12.1
+ >>> p.cpu_affinity()
+ [0, 1, 2, 3]
+ >>> p.cpu_affinity([0, 1]) # set
+ >>> p.cpu_num()
+ 1
+ >>>
+ >>> p.memory_info()
+ pmem(rss=10915840, vms=67608576, shared=3313664, text=2310144, lib=0, data=7262208, dirty=0)
+ >>> p.memory_full_info() # "real" USS memory usage (Linux, macOS, Win only)
+ pfullmem(rss=10199040, vms=52133888, shared=3887104, text=2867200, lib=0, data=5967872, dirty=0, uss=6545408, pss=6872064, swap=0)
+ >>> p.memory_percent()
+ 0.7823
+ >>> p.memory_maps()
+ [pmmap_grouped(path='/lib/x8664-linux-gnu/libutil-2.15.so', rss=32768, size=2125824, pss=32768, shared_clean=0, shared_dirty=0, private_clean=20480, private_dirty=12288, referenced=32768, anonymous=12288, swap=0),
+ pmmap_grouped(path='/lib/x8664-linux-gnu/libc-2.15.so', rss=3821568, size=3842048, pss=3821568, shared_clean=0, shared_dirty=0, private_clean=0, private_dirty=3821568, referenced=3575808, anonymous=3821568, swap=0),
+ pmmap_grouped(path='[heap]', rss=32768, size=139264, pss=32768, shared_clean=0, shared_dirty=0, private_clean=0, private_dirty=32768, referenced=32768, anonymous=32768, swap=0),
+ pmmap_grouped(path='[stack]', rss=2465792, size=2494464, pss=2465792, shared_clean=0, shared_dirty=0, private_clean=0, private_dirty=2465792, referenced=2277376, anonymous=2465792, swap=0),
+ ...]
+ >>>
+ >>> p.io_counters()
+ pio(read_count=478001, write_count=59371, read_bytes=700416, write_bytes=69632, read_chars=456232, write_chars=517543)
+ >>>
+ >>> p.open_files()
+ [popenfile(path='/home/giampaolo/monit.py', fd=3, position=0, mode='r', flags=32768),
+ popenfile(path='/var/log/monit.log', fd=4, position=235542, mode='a', flags=33793)]
+ >>>
+ >>> p.connections(kind='tcp')
+ [pconn(fd=115, family=<AddressFamily.AF_INET: 2>, type=<SocketType.SOCK_STREAM: 1>, laddr=addr(ip='10.0.0.1', port=48776), raddr=addr(ip='93.186.135.91', port=80), status='ESTABLISHED'),
+ pconn(fd=117, family=<AddressFamily.AF_INET: 2>, type=<SocketType.SOCK_STREAM: 1>, laddr=addr(ip='10.0.0.1', port=43761), raddr=addr(ip='72.14.234.100', port=80), status='CLOSING')]
+ >>>
+ >>> p.num_threads()
+ 4
+ >>> p.num_fds()
+ 8
+ >>> p.threads()
+ [pthread(id=5234, user_time=22.5, system_time=9.2891),
+ pthread(id=5237, user_time=0.0707, system_time=1.1)]
+ >>>
+ >>> p.num_ctx_switches()
+ pctxsw(voluntary=78, involuntary=19)
+ >>>
+ >>> p.nice()
+ 0
+ >>> p.nice(10) # set
+ >>>
+ >>> p.ionice(psutil.IOPRIO_CLASS_IDLE) # IO priority (Win and Linux only)
+ >>> p.ionice()
+ pionice(ioclass=<IOPriority.IOPRIO_CLASS_IDLE: 3>, value=0)
+ >>>
+ >>> p.rlimit(psutil.RLIMIT_NOFILE, (5, 5)) # set resource limits (Linux only)
+ >>> p.rlimit(psutil.RLIMIT_NOFILE)
+ (5, 5)
+ >>>
+ >>> p.environ()
+ {'LC_PAPER': 'it_IT.UTF-8', 'SHELL': '/bin/bash', 'GREP_OPTIONS': '--color=auto',
+ 'XDG_CONFIG_DIRS': '/etc/xdg/xdg-ubuntu:/usr/share/upstart/xdg:/etc/xdg',
+ ...}
+ >>>
+ >>> p.as_dict()
+ {'status': 'running', 'num_ctx_switches': pctxsw(voluntary=63, involuntary=1), 'pid': 5457, ...}
+ >>> p.is_running()
+ True
+ >>> p.suspend()
+ >>> p.resume()
+ >>>
+ >>> p.terminate()
+ >>> p.kill()
+ >>> p.wait(timeout=3)
+ <Exitcode.EX_OK: 0>
+ >>>
+ >>> psutil.test()
+ USER PID %CPU %MEM VSZ RSS TTY START TIME COMMAND
+ root 1 0.0 0.0 24584 2240 Jun17 00:00 init
+ root 2 0.0 0.0 0 0 Jun17 00:00 kthreadd
+ ...
+ giampaolo 31475 0.0 0.0 20760 3024 /dev/pts/0 Jun19 00:00 python2.4
+ giampaolo 31721 0.0 2.2 773060 181896 00:04 10:30 chrome
+ root 31763 0.0 0.0 0 0 00:05 00:00 kworker/0:1
+ >>>
+
+Further process APIs
+--------------------
+
+.. code-block:: python
+
+ >>> import psutil
+ >>> for proc in psutil.process_iter(['pid', 'name']):
+ ... print(proc.info)
+ ...
+ {'pid': 1, 'name': 'systemd'}
+ {'pid': 2, 'name': 'kthreadd'}
+ {'pid': 3, 'name': 'ksoftirqd/0'}
+ ...
+ >>>
+ >>> psutil.pid_exists(3)
+ True
+ >>>
+ >>> def on_terminate(proc):
+ ... print("process {} terminated".format(proc))
+ ...
+ >>> # waits for multiple processes to terminate
+ >>> gone, alive = psutil.wait_procs(procs_list, timeout=3, callback=on_terminate)
+ >>>
+
+Windows services
+----------------
+
+.. code-block:: python
+
+ >>> list(psutil.win_service_iter())
+ [<WindowsService(name='AeLookupSvc', display_name='Application Experience') at 38850096>,
+ <WindowsService(name='ALG', display_name='Application Layer Gateway Service') at 38850128>,
+ <WindowsService(name='APNMCP', display_name='Ask Update Service') at 38850160>,
+ <WindowsService(name='AppIDSvc', display_name='Application Identity') at 38850192>,
+ ...]
+ >>> s = psutil.win_service_get('alg')
+ >>> s.as_dict()
+ {'binpath': 'C:\\Windows\\System32\\alg.exe',
+ 'description': 'Provides support for 3rd party protocol plug-ins for Internet Connection Sharing',
+ 'display_name': 'Application Layer Gateway Service',
+ 'name': 'alg',
+ 'pid': None,
+ 'start_type': 'manual',
+ 'status': 'stopped',
+ 'username': 'NT AUTHORITY\\LocalService'}
+
+Projects using psutil
+=====================
+
+Here are some I find particularly interesting:
+
+- https://github.com/google/grr
+- https://github.com/facebook/osquery/
+- https://github.com/nicolargo/glances
+- https://github.com/Jahaja/psdash
+- https://github.com/ajenti/ajenti
+- https://github.com/home-assistant/home-assistant/
+
+Ports
+=====
+
+- Go: https://github.com/shirou/gopsutil
+- C: https://github.com/hamon-in/cpslib
+- Rust: https://github.com/rust-psutil/rust-psutil
+- Nim: https://github.com/johnscillieri/psutil-nim
+
+
+
diff --git a/lib/psutil-5.9.4.dist-info/RECORD b/lib/psutil-5.9.4.dist-info/RECORD
new file mode 100644
index 0000000..b73248d
--- /dev/null
+++ b/lib/psutil-5.9.4.dist-info/RECORD
@@ -0,0 +1,65 @@
+psutil-5.9.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+psutil-5.9.4.dist-info/LICENSE,sha256=JMEphFAMqgf_3OGe68BjlsXm0kS1c7xsQ49KbvjlbBs,1549
+psutil-5.9.4.dist-info/METADATA,sha256=-nyI-WFnIZIg0TvMn9lK4lTlgq21R_27dW01aqmhhzk,21427
+psutil-5.9.4.dist-info/RECORD,,
+psutil-5.9.4.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+psutil-5.9.4.dist-info/WHEEL,sha256=rgpVBmjjvbINeGKCkWEGd3f40VHMTsDkQj1Lgil82zE,221
+psutil-5.9.4.dist-info/top_level.txt,sha256=gCNhn57wzksDjSAISmgMJ0aiXzQulk0GJhb2-BAyYgw,7
+psutil/__init__.py,sha256=3UMoBMGdPgG_lhpnjnvBGSeeyCpcoFQOfT5LvgsSu5k,87339
+psutil/__pycache__/__init__.cpython-310.pyc,,
+psutil/__pycache__/_common.cpython-310.pyc,,
+psutil/__pycache__/_compat.cpython-310.pyc,,
+psutil/__pycache__/_psaix.cpython-310.pyc,,
+psutil/__pycache__/_psbsd.cpython-310.pyc,,
+psutil/__pycache__/_pslinux.cpython-310.pyc,,
+psutil/__pycache__/_psosx.cpython-310.pyc,,
+psutil/__pycache__/_psposix.cpython-310.pyc,,
+psutil/__pycache__/_pssunos.cpython-310.pyc,,
+psutil/__pycache__/_pswindows.cpython-310.pyc,,
+psutil/_common.py,sha256=-rxGetvopggWEZ9g0PZ79RX6o6xXXQezZt_CICT4Tv0,28228
+psutil/_compat.py,sha256=bQ-sKCbGKTQajkxmnjltTbS5YlrDRDTSJB4vXHJmhh0,15018
+psutil/_psaix.py,sha256=MkYLYA8WTIf9L9PGkkA1ObIOd91KfEExVeB0NdzRx0A,18683
+psutil/_psbsd.py,sha256=6ejDCfMEZmXrbda_xWAQIM-HRvfRwt9tm5_kC8ugwl0,31408
+psutil/_pslinux.py,sha256=AgvDtI6GrsVpAUetLpoxc5FutZwIOjm0Z5_D1VzMd44,86380
+psutil/_psosx.py,sha256=OBj01V3f5yiNgvray-Mf9Q8A5MyqNSkysRi_AtM9hh8,16275
+psutil/_psposix.py,sha256=9_6tt24W5vZljaIZvIFC2LunfvS3nUQoEnT53rQneIU,8245
+psutil/_pssunos.py,sha256=CC4rVr5F3UnruSHF9_5oxFxOEWW14UrP-RzNZi2Aq-U,25493
+psutil/_psutil_linux.abi3.so,sha256=YA_8r5-HWjCHPFgmzttXlF3KD6niApjfYnmSlLb0Bb8,107400
+psutil/_psutil_posix.abi3.so,sha256=nPWIeVjhgE8K8RZhhY-gLRhlP9uDmh3-vMrYLOKenik,71008
+psutil/_pswindows.py,sha256=opgk8yU4Bz3TFfhRfnmeZkIkMOCFdtaDbC1SahYw0FU,37442
+psutil/tests/__init__.py,sha256=bZApp61qs0hKYS0qKTF4ncTQnj4HSks9HaTVeIe7pNc,58793
+psutil/tests/__main__.py,sha256=hhM384jjFQtDF9sTj_DXaBQCXCVLwdyjLil4UTXke8Q,293
+psutil/tests/__pycache__/__init__.cpython-310.pyc,,
+psutil/tests/__pycache__/__main__.cpython-310.pyc,,
+psutil/tests/__pycache__/runner.cpython-310.pyc,,
+psutil/tests/__pycache__/test_aix.cpython-310.pyc,,
+psutil/tests/__pycache__/test_bsd.cpython-310.pyc,,
+psutil/tests/__pycache__/test_connections.cpython-310.pyc,,
+psutil/tests/__pycache__/test_contracts.cpython-310.pyc,,
+psutil/tests/__pycache__/test_linux.cpython-310.pyc,,
+psutil/tests/__pycache__/test_memleaks.cpython-310.pyc,,
+psutil/tests/__pycache__/test_misc.cpython-310.pyc,,
+psutil/tests/__pycache__/test_osx.cpython-310.pyc,,
+psutil/tests/__pycache__/test_posix.cpython-310.pyc,,
+psutil/tests/__pycache__/test_process.cpython-310.pyc,,
+psutil/tests/__pycache__/test_sunos.cpython-310.pyc,,
+psutil/tests/__pycache__/test_system.cpython-310.pyc,,
+psutil/tests/__pycache__/test_testutils.cpython-310.pyc,,
+psutil/tests/__pycache__/test_unicode.cpython-310.pyc,,
+psutil/tests/__pycache__/test_windows.cpython-310.pyc,,
+psutil/tests/runner.py,sha256=ezm1dJbuimOLEYRk_8LrAS1RF-hGT1Kkha_hb8720tY,11204
+psutil/tests/test_aix.py,sha256=B5zO6M4JF5noyt0Tui_GzQTvBh-MjG7Rk5AFzkOmXLM,4508
+psutil/tests/test_bsd.py,sha256=euNO0G8ZnDTB1FgobRdDwP29qaSfRVb64isUIkri4CM,20688
+psutil/tests/test_connections.py,sha256=QWXNRiMSBdROkaPjKJz0fez_dqbHULGDcXFd-N9iKrM,21362
+psutil/tests/test_contracts.py,sha256=H3jCmSeawWgpsPzES27W3TdzCAF-3owyqiPOKUOlkxU,27661
+psutil/tests/test_linux.py,sha256=pAfZBBWAz6_3zjEcuUAEaaZvH5QypTfM0iboWQHpCLM,94469
+psutil/tests/test_memleaks.py,sha256=f650fy6Wmi_-LmZC9QSU_PGnVlFqZwndU3TcZavkbBk,15028
+psutil/tests/test_misc.py,sha256=PmMa-UxpvODNadU-PdeBXtIV9Dohb39WJeW_ZmhSkkM,31471
+psutil/tests/test_osx.py,sha256=0EMdYSzKkG-UPx6Cpb30-wxjO8FeB-qzyW9AQeMbRH0,7762
+psutil/tests/test_posix.py,sha256=zu7HXWA9KRHWKk_EfG9Jt7cqoQ8mVQ-HZ2EhEE8FWb4,15497
+psutil/tests/test_process.py,sha256=4cTQ_1g-t1pIdXt3-NzugjtAjx3G8laz5ah7MUDKBmc,61821
+psutil/tests/test_sunos.py,sha256=-gnzJy9mc6rwovtoXKYJw_h71FaCXgWLp85POlL1RtE,1333
+psutil/tests/test_system.py,sha256=RdLtO-7Q8hMuewmjYhfcbXuIB3zAShSOoVAeA70IAGQ,35924
+psutil/tests/test_testutils.py,sha256=vZ0UiZNOyQsydXXA38Atz0FS1qE7MlyCjHvetaPpSqM,14427
+psutil/tests/test_unicode.py,sha256=HgK3AzYGTRnqJKc6ytzuTUiglF5nZSBiVtkPHavoOfg,12441
+psutil/tests/test_windows.py,sha256=TR5U5rs5fp31Mj13nMcYf2by4aqCNLZi6eVv3F80TOs,35167
diff --git a/lib/psutil-5.9.4.dist-info/REQUESTED b/lib/psutil-5.9.4.dist-info/REQUESTED
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/psutil-5.9.4.dist-info/REQUESTED
diff --git a/lib/psutil-5.9.4.dist-info/WHEEL b/lib/psutil-5.9.4.dist-info/WHEEL
new file mode 100644
index 0000000..cd91456
--- /dev/null
+++ b/lib/psutil-5.9.4.dist-info/WHEEL
@@ -0,0 +1,8 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.37.1)
+Root-Is-Purelib: false
+Tag: cp36-abi3-manylinux_2_12_x86_64
+Tag: cp36-abi3-manylinux2010_x86_64
+Tag: cp36-abi3-manylinux_2_17_x86_64
+Tag: cp36-abi3-manylinux2014_x86_64
+
diff --git a/lib/psutil-5.9.4.dist-info/top_level.txt b/lib/psutil-5.9.4.dist-info/top_level.txt
new file mode 100644
index 0000000..a4d92cc
--- /dev/null
+++ b/lib/psutil-5.9.4.dist-info/top_level.txt
@@ -0,0 +1 @@
+psutil
diff --git a/lib/psutil/__init__.py b/lib/psutil/__init__.py
new file mode 100644
index 0000000..5674279
--- /dev/null
+++ b/lib/psutil/__init__.py
@@ -0,0 +1,2421 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""psutil is a cross-platform library for retrieving information on
+running processes and system utilization (CPU, memory, disks, network,
+sensors) in Python. Supported platforms:
+
+ - Linux
+ - Windows
+ - macOS
+ - FreeBSD
+ - OpenBSD
+ - NetBSD
+ - Sun Solaris
+ - AIX
+
+Works with Python versions 2.7 and 3.4+.
+"""
+
+from __future__ import division
+
+import collections
+import contextlib
+import datetime
+import functools
+import os
+import signal
+import subprocess
+import sys
+import threading
+import time
+
+
+try:
+ import pwd
+except ImportError:
+ pwd = None
+
+from . import _common
+from ._common import AIX
+from ._common import BSD
+from ._common import CONN_CLOSE
+from ._common import CONN_CLOSE_WAIT
+from ._common import CONN_CLOSING
+from ._common import CONN_ESTABLISHED
+from ._common import CONN_FIN_WAIT1
+from ._common import CONN_FIN_WAIT2
+from ._common import CONN_LAST_ACK
+from ._common import CONN_LISTEN
+from ._common import CONN_NONE
+from ._common import CONN_SYN_RECV
+from ._common import CONN_SYN_SENT
+from ._common import CONN_TIME_WAIT
+from ._common import FREEBSD # NOQA
+from ._common import LINUX
+from ._common import MACOS
+from ._common import NETBSD # NOQA
+from ._common import NIC_DUPLEX_FULL
+from ._common import NIC_DUPLEX_HALF
+from ._common import NIC_DUPLEX_UNKNOWN
+from ._common import OPENBSD # NOQA
+from ._common import OSX # deprecated alias
+from ._common import POSIX # NOQA
+from ._common import POWER_TIME_UNKNOWN
+from ._common import POWER_TIME_UNLIMITED
+from ._common import STATUS_DEAD
+from ._common import STATUS_DISK_SLEEP
+from ._common import STATUS_IDLE
+from ._common import STATUS_LOCKED
+from ._common import STATUS_PARKED
+from ._common import STATUS_RUNNING
+from ._common import STATUS_SLEEPING
+from ._common import STATUS_STOPPED
+from ._common import STATUS_TRACING_STOP
+from ._common import STATUS_WAITING
+from ._common import STATUS_WAKING
+from ._common import STATUS_ZOMBIE
+from ._common import SUNOS
+from ._common import WINDOWS
+from ._common import AccessDenied
+from ._common import Error
+from ._common import NoSuchProcess
+from ._common import TimeoutExpired
+from ._common import ZombieProcess
+from ._common import memoize_when_activated
+from ._common import wrap_numbers as _wrap_numbers
+from ._compat import PY3 as _PY3
+from ._compat import PermissionError
+from ._compat import ProcessLookupError
+from ._compat import SubprocessTimeoutExpired as _SubprocessTimeoutExpired
+from ._compat import long
+
+
+if LINUX:
+ # This is public API and it will be retrieved from _pslinux.py
+ # via sys.modules.
+ PROCFS_PATH = "/proc"
+
+ from . import _pslinux as _psplatform
+ from ._pslinux import IOPRIO_CLASS_BE # NOQA
+ from ._pslinux import IOPRIO_CLASS_IDLE # NOQA
+ from ._pslinux import IOPRIO_CLASS_NONE # NOQA
+ from ._pslinux import IOPRIO_CLASS_RT # NOQA
+
+elif WINDOWS:
+ from . import _pswindows as _psplatform
+ from ._psutil_windows import ABOVE_NORMAL_PRIORITY_CLASS # NOQA
+ from ._psutil_windows import BELOW_NORMAL_PRIORITY_CLASS # NOQA
+ from ._psutil_windows import HIGH_PRIORITY_CLASS # NOQA
+ from ._psutil_windows import IDLE_PRIORITY_CLASS # NOQA
+ from ._psutil_windows import NORMAL_PRIORITY_CLASS # NOQA
+ from ._psutil_windows import REALTIME_PRIORITY_CLASS # NOQA
+ from ._pswindows import CONN_DELETE_TCB # NOQA
+ from ._pswindows import IOPRIO_HIGH # NOQA
+ from ._pswindows import IOPRIO_LOW # NOQA
+ from ._pswindows import IOPRIO_NORMAL # NOQA
+ from ._pswindows import IOPRIO_VERYLOW # NOQA
+
+elif MACOS:
+ from . import _psosx as _psplatform
+
+elif BSD:
+ from . import _psbsd as _psplatform
+
+elif SUNOS:
+ from . import _pssunos as _psplatform
+ from ._pssunos import CONN_BOUND # NOQA
+ from ._pssunos import CONN_IDLE # NOQA
+
+ # This is public writable API which is read from _pslinux.py and
+ # _pssunos.py via sys.modules.
+ PROCFS_PATH = "/proc"
+
+elif AIX:
+ from . import _psaix as _psplatform
+
+    # This is public API and it will be retrieved from _psaix.py
+    # via sys.modules.
+ PROCFS_PATH = "/proc"
+
+else: # pragma: no cover
+ raise NotImplementedError('platform %s is not supported' % sys.platform)
+
+
+__all__ = [
+ # exceptions
+ "Error", "NoSuchProcess", "ZombieProcess", "AccessDenied",
+ "TimeoutExpired",
+
+ # constants
+ "version_info", "__version__",
+
+ "STATUS_RUNNING", "STATUS_IDLE", "STATUS_SLEEPING", "STATUS_DISK_SLEEP",
+ "STATUS_STOPPED", "STATUS_TRACING_STOP", "STATUS_ZOMBIE", "STATUS_DEAD",
+    "STATUS_WAKING", "STATUS_LOCKED", "STATUS_WAITING",
+ "STATUS_PARKED",
+
+ "CONN_ESTABLISHED", "CONN_SYN_SENT", "CONN_SYN_RECV", "CONN_FIN_WAIT1",
+ "CONN_FIN_WAIT2", "CONN_TIME_WAIT", "CONN_CLOSE", "CONN_CLOSE_WAIT",
+ "CONN_LAST_ACK", "CONN_LISTEN", "CONN_CLOSING", "CONN_NONE",
+ # "CONN_IDLE", "CONN_BOUND",
+
+ "AF_LINK",
+
+ "NIC_DUPLEX_FULL", "NIC_DUPLEX_HALF", "NIC_DUPLEX_UNKNOWN",
+
+ "POWER_TIME_UNKNOWN", "POWER_TIME_UNLIMITED",
+
+ "BSD", "FREEBSD", "LINUX", "NETBSD", "OPENBSD", "MACOS", "OSX", "POSIX",
+ "SUNOS", "WINDOWS", "AIX",
+
+ # "RLIM_INFINITY", "RLIMIT_AS", "RLIMIT_CORE", "RLIMIT_CPU", "RLIMIT_DATA",
+ # "RLIMIT_FSIZE", "RLIMIT_LOCKS", "RLIMIT_MEMLOCK", "RLIMIT_NOFILE",
+ # "RLIMIT_NPROC", "RLIMIT_RSS", "RLIMIT_STACK", "RLIMIT_MSGQUEUE",
+ # "RLIMIT_NICE", "RLIMIT_RTPRIO", "RLIMIT_RTTIME", "RLIMIT_SIGPENDING",
+
+ # classes
+ "Process", "Popen",
+
+ # functions
+ "pid_exists", "pids", "process_iter", "wait_procs", # proc
+ "virtual_memory", "swap_memory", # memory
+ "cpu_times", "cpu_percent", "cpu_times_percent", "cpu_count", # cpu
+ "cpu_stats", # "cpu_freq", "getloadavg"
+ "net_io_counters", "net_connections", "net_if_addrs", # network
+ "net_if_stats",
+ "disk_io_counters", "disk_partitions", "disk_usage", # disk
+ # "sensors_temperatures", "sensors_battery", "sensors_fans" # sensors
+ "users", "boot_time", # others
+]
+
+
+__all__.extend(_psplatform.__extra__all__)
+
+# Linux, FreeBSD
+if hasattr(_psplatform.Process, "rlimit"):
+ # Populate global namespace with RLIM* constants.
+ from . import _psutil_posix
+
+ _globals = globals()
+ _name = None
+ for _name in dir(_psutil_posix):
+ if _name.startswith('RLIM') and _name.isupper():
+ _globals[_name] = getattr(_psutil_posix, _name)
+ __all__.append(_name)
+ del _globals, _name
+
+AF_LINK = _psplatform.AF_LINK
+
+__author__ = "Giampaolo Rodola'"
+__version__ = "5.9.4"
+version_info = tuple([int(num) for num in __version__.split('.')])
+
+_timer = getattr(time, 'monotonic', time.time)
+_TOTAL_PHYMEM = None
+_LOWEST_PID = None
+_SENTINEL = object()
+
+# Sanity check in case the user messed up with psutil installation
+# or did something weird with sys.path. In this case we might end
+# up importing a python module using a C extension module which
+# was compiled for a different version of psutil.
+# We want to prevent that by failing sooner rather than later.
+# See: https://github.com/giampaolo/psutil/issues/564
+if (int(__version__.replace('.', '')) !=
+ getattr(_psplatform.cext, 'version', None)):
+ msg = "version conflict: %r C extension module was built for another " \
+ "version of psutil" % _psplatform.cext.__file__
+ if hasattr(_psplatform.cext, 'version'):
+ msg += " (%s instead of %s)" % (
+ '.'.join([x for x in str(_psplatform.cext.version)]), __version__)
+ else:
+ msg += " (different than %s)" % __version__
+ msg += "; you may try to 'pip uninstall psutil', manually remove %s" % (
+ getattr(_psplatform.cext, "__file__",
+ "the existing psutil install directory"))
+ msg += " or clean the virtual env somehow, then reinstall"
+ raise ImportError(msg)
+
+
+# =====================================================================
+# --- Utils
+# =====================================================================
+
+
+if hasattr(_psplatform, 'ppid_map'):
+ # Faster version (Windows and Linux).
+ _ppid_map = _psplatform.ppid_map
+else: # pragma: no cover
+ def _ppid_map():
+ """Return a {pid: ppid, ...} dict for all running processes in
+ one shot. Used to speed up Process.children().
+ """
+ ret = {}
+ for pid in pids():
+ try:
+ ret[pid] = _psplatform.Process(pid).ppid()
+ except (NoSuchProcess, ZombieProcess):
+ pass
+ return ret
+
+
+def _assert_pid_not_reused(fun):
+ """Decorator which raises NoSuchProcess in case a process is no
+ longer running or its PID has been reused.
+ """
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ if not self.is_running():
+ if self._pid_reused:
+ msg = "process no longer exists and its PID has been reused"
+ else:
+ msg = None
+ raise NoSuchProcess(self.pid, self._name, msg=msg)
+ return fun(self, *args, **kwargs)
+ return wrapper
+
+
+def _pprint_secs(secs):
+ """Format seconds in a human readable form."""
+ now = time.time()
+ secs_ago = int(now - secs)
+ if secs_ago < 60 * 60 * 24:
+ fmt = "%H:%M:%S"
+ else:
+ fmt = "%Y-%m-%d %H:%M:%S"
+ return datetime.datetime.fromtimestamp(secs).strftime(fmt)
+
+
+# =====================================================================
+# --- Process class
+# =====================================================================
+
+
+class Process(object):
+ """Represents an OS process with the given PID.
+ If PID is omitted current process PID (os.getpid()) is used.
+ Raise NoSuchProcess if PID does not exist.
+
+    Note that most of the methods of this class do not make sure
+    that the PID of the process being queried has not been reused over
+    time. That means you might end up retrieving information referring
+    to another process, in case the original one this instance
+    refers to is gone in the meantime.
+
+ The only exceptions for which process identity is pre-emptively
+ checked and guaranteed are:
+
+ - parent()
+ - children()
+ - nice() (set)
+ - ionice() (set)
+ - rlimit() (set)
+ - cpu_affinity (set)
+ - suspend()
+ - resume()
+ - send_signal()
+ - terminate()
+ - kill()
+
+ To prevent this problem for all other methods you can:
+ - use is_running() before querying the process
+ - if you're continuously iterating over a set of Process
+ instances use process_iter() which pre-emptively checks
+ process identity for every yielded instance
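+
+    A minimal sketch of the first recommendation (the PID below is
+    hypothetical):
+
+    >>> import psutil
+    >>> p = psutil.Process(1234)
+    >>> if p.is_running():
+    ...     p.name()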
+ """
+
+ def __init__(self, pid=None):
+ self._init(pid)
+
+ def _init(self, pid, _ignore_nsp=False):
+ if pid is None:
+ pid = os.getpid()
+ else:
+ if not _PY3 and not isinstance(pid, (int, long)):
+ raise TypeError('pid must be an integer (got %r)' % pid)
+ if pid < 0:
+ raise ValueError('pid must be a positive integer (got %s)'
+ % pid)
+ self._pid = pid
+ self._name = None
+ self._exe = None
+ self._create_time = None
+ self._gone = False
+ self._pid_reused = False
+ self._hash = None
+ self._lock = threading.RLock()
+ # used for caching on Windows only (on POSIX ppid may change)
+ self._ppid = None
+ # platform-specific modules define an _psplatform.Process
+ # implementation class
+ self._proc = _psplatform.Process(pid)
+ self._last_sys_cpu_times = None
+ self._last_proc_cpu_times = None
+ self._exitcode = _SENTINEL
+ # cache creation time for later use in is_running() method
+ try:
+ self.create_time()
+ except AccessDenied:
+ # We should never get here as AFAIK we're able to get
+ # process creation time on all platforms even as a
+ # limited user.
+ pass
+ except ZombieProcess:
+ # Zombies can still be queried by this class (although
+ # not always) and pids() return them so just go on.
+ pass
+ except NoSuchProcess:
+ if not _ignore_nsp:
+ raise NoSuchProcess(pid, msg='process PID not found')
+ else:
+ self._gone = True
+        # This pair is supposed to identify a Process instance
+        # uniquely over time (the PID alone is not enough as
+        # it might refer to a process whose PID has been reused).
+ # This will be used later in __eq__() and is_running().
+ self._ident = (self.pid, self._create_time)
+
+ def __str__(self):
+ info = collections.OrderedDict()
+ info["pid"] = self.pid
+ if self._name:
+ info['name'] = self._name
+ with self.oneshot():
+ try:
+ info["name"] = self.name()
+ info["status"] = self.status()
+ except ZombieProcess:
+ info["status"] = "zombie"
+ except NoSuchProcess:
+ info["status"] = "terminated"
+ except AccessDenied:
+ pass
+ if self._exitcode not in (_SENTINEL, None):
+ info["exitcode"] = self._exitcode
+ if self._create_time:
+ info['started'] = _pprint_secs(self._create_time)
+ return "%s.%s(%s)" % (
+ self.__class__.__module__,
+ self.__class__.__name__,
+ ", ".join(["%s=%r" % (k, v) for k, v in info.items()]))
+
+ __repr__ = __str__
+
+ def __eq__(self, other):
+ # Test for equality with another Process object based
+ # on PID and creation time.
+ if not isinstance(other, Process):
+ return NotImplemented
+ return self._ident == other._ident
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ if self._hash is None:
+ self._hash = hash(self._ident)
+ return self._hash
+
+ @property
+ def pid(self):
+ """The process PID."""
+ return self._pid
+
+ # --- utility methods
+
+ @contextlib.contextmanager
+ def oneshot(self):
+ """Utility context manager which considerably speeds up the
+ retrieval of multiple process information at the same time.
+
+        Internally different process info (e.g. name, ppid, uids,
+        gids, ...) may be fetched by using the same routine, but
+        only one piece of information is returned and the others are
+        discarded. When using this context manager the internal
+        routine is executed once (in the example below on name())
+        and the other info is cached.
+
+ The cache is cleared when exiting the context manager block.
+        The advice is to use this every time you retrieve more than
+        one piece of information about the process. If you're lucky,
+        you'll get a hell of a speedup.
+
+ >>> import psutil
+ >>> p = psutil.Process()
+ >>> with p.oneshot():
+ ... p.name() # collect multiple info
+ ... p.cpu_times() # return cached value
+ ... p.cpu_percent() # return cached value
+ ... p.create_time() # return cached value
+ ...
+ >>>
+ """
+ with self._lock:
+ if hasattr(self, "_cache"):
+ # NOOP: this covers the use case where the user enters the
+ # context twice:
+ #
+ # >>> with p.oneshot():
+ # ... with p.oneshot():
+ # ...
+ #
+ # Also, since as_dict() internally uses oneshot()
+ # I expect that the code below will be a pretty common
+ # "mistake" that the user will make, so let's guard
+ # against that:
+ #
+ # >>> with p.oneshot():
+ # ... p.as_dict()
+ # ...
+ yield
+ else:
+ try:
+ # cached in case cpu_percent() is used
+ self.cpu_times.cache_activate(self)
+ # cached in case memory_percent() is used
+ self.memory_info.cache_activate(self)
+ # cached in case parent() is used
+ self.ppid.cache_activate(self)
+ # cached in case username() is used
+ if POSIX:
+ self.uids.cache_activate(self)
+ # specific implementation cache
+ self._proc.oneshot_enter()
+ yield
+ finally:
+ self.cpu_times.cache_deactivate(self)
+ self.memory_info.cache_deactivate(self)
+ self.ppid.cache_deactivate(self)
+ if POSIX:
+ self.uids.cache_deactivate(self)
+ self._proc.oneshot_exit()
+
+ def as_dict(self, attrs=None, ad_value=None):
+        """Utility method returning process information as a
+        dictionary.
+ If *attrs* is specified it must be a list of strings
+ reflecting available Process class' attribute names
+ (e.g. ['cpu_times', 'name']) else all public (read
+ only) attributes are assumed.
+ *ad_value* is the value which gets assigned in case
+ AccessDenied or ZombieProcess exception is raised when
+ retrieving that particular process information.
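+
+        Hypothetical call (the returned values are illustrative only):
+
+        >>> import psutil
+        >>> psutil.Process().as_dict(attrs=['pid', 'name'])
+        {'pid': 4242, 'name': 'python3'}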
+ """
+ valid_names = _as_dict_attrnames
+ if attrs is not None:
+ if not isinstance(attrs, (list, tuple, set, frozenset)):
+ raise TypeError("invalid attrs type %s" % type(attrs))
+ attrs = set(attrs)
+ invalid_names = attrs - valid_names
+ if invalid_names:
+ raise ValueError("invalid attr name%s %s" % (
+ "s" if len(invalid_names) > 1 else "",
+ ", ".join(map(repr, invalid_names))))
+
+ retdict = dict()
+ ls = attrs or valid_names
+ with self.oneshot():
+ for name in ls:
+ try:
+ if name == 'pid':
+ ret = self.pid
+ else:
+ meth = getattr(self, name)
+ ret = meth()
+ except (AccessDenied, ZombieProcess):
+ ret = ad_value
+ except NotImplementedError:
+ # in case of not implemented functionality (may happen
+ # on old or exotic systems) we want to crash only if
+ # the user explicitly asked for that particular attr
+ if attrs:
+ raise
+ continue
+ retdict[name] = ret
+ return retdict
+
+ def parent(self):
+ """Return the parent process as a Process object pre-emptively
+ checking whether PID has been reused.
+ If no parent is known return None.
+ """
+ lowest_pid = _LOWEST_PID if _LOWEST_PID is not None else pids()[0]
+ if self.pid == lowest_pid:
+ return None
+ ppid = self.ppid()
+ if ppid is not None:
+ ctime = self.create_time()
+ try:
+ parent = Process(ppid)
+ if parent.create_time() <= ctime:
+ return parent
+ # ...else ppid has been reused by another process
+ except NoSuchProcess:
+ pass
+
+ def parents(self):
+ """Return the parents of this process as a list of Process
+ instances. If no parents are known return an empty list.
+ """
+ parents = []
+ proc = self.parent()
+ while proc is not None:
+ parents.append(proc)
+ proc = proc.parent()
+ return parents
+
+ def is_running(self):
+ """Return whether this process is running.
+ It also checks if PID has been reused by another process in
+ which case return False.
+ """
+ if self._gone or self._pid_reused:
+ return False
+ try:
+ # Checking if PID is alive is not enough as the PID might
+ # have been reused by another process: we also want to
+ # verify process identity.
+ # Process identity / uniqueness over time is guaranteed by
+ # (PID + creation time) and that is verified in __eq__.
+ self._pid_reused = self != Process(self.pid)
+ return not self._pid_reused
+ except ZombieProcess:
+ # We should never get here as it's already handled in
+ # Process.__init__; here just for extra safety.
+ return True
+ except NoSuchProcess:
+ self._gone = True
+ return False
+
+ # --- actual API
+
+ @memoize_when_activated
+ def ppid(self):
+ """The process parent PID.
+ On Windows the return value is cached after first call.
+ """
+ # On POSIX we don't want to cache the ppid as it may unexpectedly
+ # change to 1 (init) in case this process turns into a zombie:
+ # https://github.com/giampaolo/psutil/issues/321
+ # http://stackoverflow.com/questions/356722/
+
+ # XXX should we check creation time here rather than in
+ # Process.parent()?
+ if POSIX:
+ return self._proc.ppid()
+ else: # pragma: no cover
+ self._ppid = self._ppid or self._proc.ppid()
+ return self._ppid
+
+ def name(self):
+ """The process name. The return value is cached after first call."""
+ # Process name is only cached on Windows as on POSIX it may
+ # change, see:
+ # https://github.com/giampaolo/psutil/issues/692
+ if WINDOWS and self._name is not None:
+ return self._name
+ name = self._proc.name()
+ if POSIX and len(name) >= 15:
+ # On UNIX the name gets truncated to the first 15 characters.
+ # If it matches the first part of the cmdline we return that
+ # one instead because it's usually more explicative.
+ # Examples are "gnome-keyring-d" vs. "gnome-keyring-daemon".
+ try:
+ cmdline = self.cmdline()
+ except AccessDenied:
+ pass
+ else:
+ if cmdline:
+ extended_name = os.path.basename(cmdline[0])
+ if extended_name.startswith(name):
+ name = extended_name
+ self._name = name
+ self._proc._name = name
+ return name
+
+ def exe(self):
+ """The process executable as an absolute path.
+ May also be an empty string.
+ The return value is cached after first call.
+ """
+ def guess_it(fallback):
+ # try to guess exe from cmdline[0] in absence of a native
+ # exe representation
+ cmdline = self.cmdline()
+ if cmdline and hasattr(os, 'access') and hasattr(os, 'X_OK'):
+ exe = cmdline[0] # the possible exe
+ # Attempt to guess only in case of an absolute path.
+ # It is not safe otherwise as the process might have
+ # changed cwd.
+ if (os.path.isabs(exe) and
+ os.path.isfile(exe) and
+ os.access(exe, os.X_OK)):
+ return exe
+ if isinstance(fallback, AccessDenied):
+ raise fallback
+ return fallback
+
+ if self._exe is None:
+ try:
+ exe = self._proc.exe()
+ except AccessDenied as err:
+ return guess_it(fallback=err)
+ else:
+ if not exe:
+ # underlying implementation can legitimately return an
+ # empty string; if that's the case we don't want to
+ # raise AD while guessing from the cmdline
+ try:
+ exe = guess_it(fallback=exe)
+ except AccessDenied:
+ pass
+ self._exe = exe
+ return self._exe
+
+ def cmdline(self):
+ """The command line this process has been called with."""
+ return self._proc.cmdline()
+
+ def status(self):
+ """The process current status as a STATUS_* constant."""
+ try:
+ return self._proc.status()
+ except ZombieProcess:
+ return STATUS_ZOMBIE
+
+ def username(self):
+ """The name of the user that owns the process.
+ On UNIX this is calculated by using *real* process uid.
+ """
+ if POSIX:
+ if pwd is None:
+ # might happen if python was installed from sources
+ raise ImportError(
+ "requires pwd module shipped with standard python")
+ real_uid = self.uids().real
+ try:
+ return pwd.getpwuid(real_uid).pw_name
+ except KeyError:
+ # the uid can't be resolved by the system
+ return str(real_uid)
+ else:
+ return self._proc.username()
+
+ def create_time(self):
+ """The process creation time as a floating point number
+ expressed in seconds since the epoch.
+ The return value is cached after first call.
+ """
+ if self._create_time is None:
+ self._create_time = self._proc.create_time()
+ return self._create_time
+
+ def cwd(self):
+ """Process current working directory as an absolute path."""
+ return self._proc.cwd()
+
+ def nice(self, value=None):
+ """Get or set process niceness (priority)."""
+ if value is None:
+ return self._proc.nice_get()
+ else:
+ if not self.is_running():
+ raise NoSuchProcess(self.pid, self._name)
+ self._proc.nice_set(value)
+
+ if POSIX:
+
+ @memoize_when_activated
+ def uids(self):
+ """Return process UIDs as a (real, effective, saved)
+ namedtuple.
+ """
+ return self._proc.uids()
+
+ def gids(self):
+ """Return process GIDs as a (real, effective, saved)
+ namedtuple.
+ """
+ return self._proc.gids()
+
+ def terminal(self):
+ """The terminal associated with this process, if any,
+ else None.
+ """
+ return self._proc.terminal()
+
+ def num_fds(self):
+ """Return the number of file descriptors opened by this
+ process (POSIX only).
+ """
+ return self._proc.num_fds()
+
+ # Linux, BSD, AIX and Windows only
+ if hasattr(_psplatform.Process, "io_counters"):
+
+ def io_counters(self):
+ """Return process I/O statistics as a
+ (read_count, write_count, read_bytes, write_bytes)
+ namedtuple.
+ Those are the number of read/write calls performed and the
+ amount of bytes read and written by the process.
+ """
+ return self._proc.io_counters()
+
+ # Linux and Windows
+ if hasattr(_psplatform.Process, "ionice_get"):
+
+ def ionice(self, ioclass=None, value=None):
+ """Get or set process I/O niceness (priority).
+
+ On Linux *ioclass* is one of the IOPRIO_CLASS_* constants.
+ *value* is a number which goes from 0 to 7. The higher the
+ value, the lower the I/O priority of the process.
+
+ On Windows only *ioclass* is used and it can be set to 2
+ (normal), 1 (low) or 0 (very low).
+
+ Available on Linux and Windows > Vista only.
+ """
+ if ioclass is None:
+ if value is not None:
+ raise ValueError("'ioclass' argument must be specified")
+ return self._proc.ionice_get()
+ else:
+ return self._proc.ionice_set(ioclass, value)
+
+ # Linux / FreeBSD only
+ if hasattr(_psplatform.Process, "rlimit"):
+
+ def rlimit(self, resource, limits=None):
+ """Get or set process resource limits as a (soft, hard)
+ tuple.
+
+ *resource* is one of the RLIMIT_* constants.
+ *limits* is supposed to be a (soft, hard) tuple.
+
+ See "man prlimit" for further info.
+ Available on Linux and FreeBSD only.
+ """
+ return self._proc.rlimit(resource, limits)
+
+ # Windows, Linux and FreeBSD only
+ if hasattr(_psplatform.Process, "cpu_affinity_get"):
+
+ def cpu_affinity(self, cpus=None):
+ """Get or set process CPU affinity.
+ If specified, *cpus* must be a list of CPUs for which you
+ want to set the affinity (e.g. [0, 1]).
+            If an empty list is passed, all eligible CPUs are assumed
+ (and set).
+ (Windows, Linux and BSD only).
+ """
+ if cpus is None:
+ return sorted(set(self._proc.cpu_affinity_get()))
+ else:
+ if not cpus:
+ if hasattr(self._proc, "_get_eligible_cpus"):
+ cpus = self._proc._get_eligible_cpus()
+ else:
+ cpus = tuple(range(len(cpu_times(percpu=True))))
+ self._proc.cpu_affinity_set(list(set(cpus)))
+
+ # Linux, FreeBSD, SunOS
+ if hasattr(_psplatform.Process, "cpu_num"):
+
+ def cpu_num(self):
+ """Return what CPU this process is currently running on.
+ The returned number should be <= psutil.cpu_count()
+ and <= len(psutil.cpu_percent(percpu=True)).
+ It may be used in conjunction with
+ psutil.cpu_percent(percpu=True) to observe the system
+ workload distributed across CPUs.
+ """
+ return self._proc.cpu_num()
+
+    # All platforms have it, but maybe not in the future.
+ if hasattr(_psplatform.Process, "environ"):
+
+ def environ(self):
+ """The environment variables of the process as a dict. Note: this
+ might not reflect changes made after the process started. """
+ return self._proc.environ()
+
+ if WINDOWS:
+
+ def num_handles(self):
+ """Return the number of handles opened by this process
+ (Windows only).
+ """
+ return self._proc.num_handles()
+
+ def num_ctx_switches(self):
+ """Return the number of voluntary and involuntary context
+ switches performed by this process.
+ """
+ return self._proc.num_ctx_switches()
+
+ def num_threads(self):
+ """Return the number of threads used by this process."""
+ return self._proc.num_threads()
+
+ if hasattr(_psplatform.Process, "threads"):
+
+ def threads(self):
+ """Return threads opened by process as a list of
+ (id, user_time, system_time) namedtuples representing
+ thread id and thread CPU times (user/system).
+ On OpenBSD this method requires root access.
+ """
+ return self._proc.threads()
+
+ @_assert_pid_not_reused
+ def children(self, recursive=False):
+ """Return the children of this process as a list of Process
+ instances, pre-emptively checking whether PID has been reused.
+        If *recursive* is True return all the descendants.
+
+ Example (A == this process):
+
+ A ─┐
+ │
+ ├─ B (child) ─┐
+ │ └─ X (grandchild) ─┐
+ │ └─ Y (great grandchild)
+ ├─ C (child)
+ └─ D (child)
+
+ >>> import psutil
+ >>> p = psutil.Process()
+ >>> p.children()
+ B, C, D
+ >>> p.children(recursive=True)
+ B, X, Y, C, D
+
+ Note that in the example above if process X disappears
+ process Y won't be listed as the reference to process A
+ is lost.
+ """
+ ppid_map = _ppid_map()
+ ret = []
+ if not recursive:
+ for pid, ppid in ppid_map.items():
+ if ppid == self.pid:
+ try:
+ child = Process(pid)
+ # if child happens to be older than its parent
+ # (self) it means child's PID has been reused
+ if self.create_time() <= child.create_time():
+ ret.append(child)
+ except (NoSuchProcess, ZombieProcess):
+ pass
+ else:
+ # Construct a {pid: [child pids]} dict
+ reverse_ppid_map = collections.defaultdict(list)
+ for pid, ppid in ppid_map.items():
+ reverse_ppid_map[ppid].append(pid)
+ # Recursively traverse that dict, starting from self.pid,
+ # such that we only call Process() on actual children
+ seen = set()
+ stack = [self.pid]
+ while stack:
+ pid = stack.pop()
+ if pid in seen:
+ # Since pids can be reused while the ppid_map is
+ # constructed, there may be rare instances where
+ # there's a cycle in the recorded process "tree".
+ continue
+ seen.add(pid)
+ for child_pid in reverse_ppid_map[pid]:
+ try:
+ child = Process(child_pid)
+ # if child happens to be older than its parent
+ # (self) it means child's PID has been reused
+ intime = self.create_time() <= child.create_time()
+ if intime:
+ ret.append(child)
+ stack.append(child_pid)
+ except (NoSuchProcess, ZombieProcess):
+ pass
+ return ret
+
+ def cpu_percent(self, interval=None):
+ """Return a float representing the current process CPU
+ utilization as a percentage.
+
+ When *interval* is 0.0 or None (default) compares process times
+ to system CPU times elapsed since last call, returning
+ immediately (non-blocking). That means that the first time
+        this is called it will return a meaningless 0.0 value which
+        you should ignore.
+
+ When *interval* is > 0.0 compares process times to system CPU
+ times elapsed before and after the interval (blocking).
+
+        In this case it is recommended for accuracy that this function
+ be called with at least 0.1 seconds between calls.
+
+ A value > 100.0 can be returned in case of processes running
+ multiple threads on different CPU cores.
+
+ The returned value is explicitly NOT split evenly between
+ all available logical CPUs. This means that a busy loop process
+ running on a system with 2 logical CPUs will be reported as
+ having 100% CPU utilization instead of 50%.
+
+ Examples:
+
+        >>> import psutil
+        >>> p = psutil.Process()
+ >>> # blocking
+ >>> p.cpu_percent(interval=1)
+ 2.0
+ >>> # non-blocking (percentage since last call)
+ >>> p.cpu_percent(interval=None)
+ 2.9
+ >>>
+ """
+ blocking = interval is not None and interval > 0.0
+ if interval is not None and interval < 0:
+ raise ValueError("interval is not positive (got %r)" % interval)
+ num_cpus = cpu_count() or 1
+
+ def timer():
+ return _timer() * num_cpus
+
+ if blocking:
+ st1 = timer()
+ pt1 = self._proc.cpu_times()
+ time.sleep(interval)
+ st2 = timer()
+ pt2 = self._proc.cpu_times()
+ else:
+ st1 = self._last_sys_cpu_times
+ pt1 = self._last_proc_cpu_times
+ st2 = timer()
+ pt2 = self._proc.cpu_times()
+ if st1 is None or pt1 is None:
+ self._last_sys_cpu_times = st2
+ self._last_proc_cpu_times = pt2
+ return 0.0
+
+ delta_proc = (pt2.user - pt1.user) + (pt2.system - pt1.system)
+ delta_time = st2 - st1
+ # reset values for next call in case of interval == None
+ self._last_sys_cpu_times = st2
+ self._last_proc_cpu_times = pt2
+
+ try:
+ # This is the utilization split evenly between all CPUs.
+ # E.g. a busy loop process on a 2-CPU-cores system at this
+ # point is reported as 50% instead of 100%.
+ overall_cpus_percent = ((delta_proc / delta_time) * 100)
+ except ZeroDivisionError:
+ # interval was too low
+ return 0.0
+ else:
+ # Note 1:
+            # in order to emulate "top" we multiply the value by the
+            # number of CPU cores. This way the busy process will be
+            # reported as having 100% (or more) usage.
+ #
+ # Note 2:
+ # taskmgr.exe on Windows differs in that it will show 50%
+ # instead.
+ #
+ # Note 3:
+ # a percentage > 100 is legitimate as it can result from a
+ # process with multiple threads running on different CPU
+ # cores (top does the same), see:
+ # http://stackoverflow.com/questions/1032357
+ # https://github.com/giampaolo/psutil/issues/474
+ single_cpu_percent = overall_cpus_percent * num_cpus
+ return round(single_cpu_percent, 1)
+
+ @memoize_when_activated
+ def cpu_times(self):
+ """Return a (user, system, children_user, children_system)
+ namedtuple representing the accumulated process time, in
+ seconds.
+ This is similar to os.times() but per-process.
+ On macOS and Windows children_user and children_system are
+ always set to 0.
+ """
+ return self._proc.cpu_times()
+
+ @memoize_when_activated
+ def memory_info(self):
+ """Return a namedtuple with variable fields depending on the
+ platform, representing memory information about the process.
+
+ The "portable" fields available on all platforms are `rss` and `vms`.
+
+ All numbers are expressed in bytes.
+ """
+ return self._proc.memory_info()
+
+ @_common.deprecated_method(replacement="memory_info")
+ def memory_info_ex(self):
+ return self.memory_info()
+
+ def memory_full_info(self):
+ """This method returns the same information as memory_info(),
+ plus, on some platform (Linux, macOS, Windows), also provides
+ additional metrics (USS, PSS and swap).
+ The additional metrics provide a better representation of actual
+ process memory usage.
+
+ Namely USS is the memory which is unique to a process and which
+ would be freed if the process was terminated right now.
+
+        It does so by traversing the whole process address space.
+ As such it usually requires higher user privileges than
+ memory_info() and is considerably slower.
+ """
+ return self._proc.memory_full_info()
+
+ def memory_percent(self, memtype="rss"):
+ """Compare process memory to total physical system memory and
+ calculate process memory utilization as a percentage.
+ *memtype* argument is a string that dictates what type of
+ process memory you want to compare against (defaults to "rss").
+ The list of available strings can be obtained like this:
+
+ >>> psutil.Process().memory_info()._fields
+ ('rss', 'vms', 'shared', 'text', 'lib', 'data', 'dirty', 'uss', 'pss')
+ """
+ valid_types = list(_psplatform.pfullmem._fields)
+ if memtype not in valid_types:
+ raise ValueError("invalid memtype %r; valid types are %r" % (
+ memtype, tuple(valid_types)))
+ fun = self.memory_info if memtype in _psplatform.pmem._fields else \
+ self.memory_full_info
+ metrics = fun()
+ value = getattr(metrics, memtype)
+
+ # use cached value if available
+ total_phymem = _TOTAL_PHYMEM or virtual_memory().total
+ if not total_phymem > 0:
+ # we should never get here
+ raise ValueError(
+ "can't calculate process memory percent because "
+ "total physical system memory is not positive (%r)"
+ % total_phymem)
+ return (value / float(total_phymem)) * 100
+
+ if hasattr(_psplatform.Process, "memory_maps"):
+ def memory_maps(self, grouped=True):
+ """Return process' mapped memory regions as a list of namedtuples
+ whose fields are variable depending on the platform.
+
+ If *grouped* is True the mapped regions with the same 'path'
+ are grouped together and the different memory fields are summed.
+
+ If *grouped* is False every mapped region is shown as a single
+ entity and the namedtuple will also include the mapped region's
+ address space ('addr') and permission set ('perms').
+ """
+ it = self._proc.memory_maps()
+ if grouped:
+ d = {}
+ for tupl in it:
+ path = tupl[2]
+ nums = tupl[3:]
+ try:
+ d[path] = map(lambda x, y: x + y, d[path], nums)
+ except KeyError:
+ d[path] = nums
+ nt = _psplatform.pmmap_grouped
+ return [nt(path, *d[path]) for path in d] # NOQA
+ else:
+ nt = _psplatform.pmmap_ext
+ return [nt(*x) for x in it]
+
+ def open_files(self):
+ """Return files opened by process as a list of
+ (path, fd) namedtuples including the absolute file name
+ and file descriptor number.
+ """
+ return self._proc.open_files()
+
+ def connections(self, kind='inet'):
+ """Return socket connections opened by process as a list of
+ (fd, family, type, laddr, raddr, status) namedtuples.
+ The *kind* parameter filters for connections that match the
+ following criteria:
+
+ +------------+----------------------------------------------------+
+ | Kind Value | Connections using |
+ +------------+----------------------------------------------------+
+ | inet | IPv4 and IPv6 |
+ | inet4 | IPv4 |
+ | inet6 | IPv6 |
+ | tcp | TCP |
+ | tcp4 | TCP over IPv4 |
+ | tcp6 | TCP over IPv6 |
+ | udp | UDP |
+ | udp4 | UDP over IPv4 |
+ | udp6 | UDP over IPv6 |
+ | unix | UNIX socket (both UDP and TCP protocols) |
+ | all | the sum of all the possible families and protocols |
+ +------------+----------------------------------------------------+
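+
+        Example (illustrative output, trimmed for readability):
+
+        >>> import psutil
+        >>> p = psutil.Process()
+        >>> p.connections(kind='tcp')
+        [pconn(fd=115, family=<AddressFamily.AF_INET: 2>,
+               type=<SocketKind.SOCK_STREAM: 1>,
+               laddr=addr(ip='10.0.0.1', port=48776),
+               raddr=addr(ip='93.186.135.91', port=80),
+               status='ESTABLISHED')]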
+ """
+ return self._proc.connections(kind)
+
+ # --- signals
+
+ if POSIX:
+ def _send_signal(self, sig):
+ assert not self.pid < 0, self.pid
+ if self.pid == 0:
+ # see "man 2 kill"
+ raise ValueError(
+ "preventing sending signal to process with PID 0 as it "
+ "would affect every process in the process group of the "
+ "calling process (os.getpid()) instead of PID 0")
+ try:
+ os.kill(self.pid, sig)
+ except ProcessLookupError:
+ if OPENBSD and pid_exists(self.pid):
+ # We do this because os.kill() lies in case of
+ # zombie processes.
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ else:
+ self._gone = True
+ raise NoSuchProcess(self.pid, self._name)
+ except PermissionError:
+ raise AccessDenied(self.pid, self._name)
+
+ @_assert_pid_not_reused
+ def send_signal(self, sig):
+ """Send a signal *sig* to process pre-emptively checking
+        whether PID has been reused (see signal module constants).
+ On Windows only SIGTERM is valid and is treated as an alias
+ for kill().
+ """
+ if POSIX:
+ self._send_signal(sig)
+ else: # pragma: no cover
+ self._proc.send_signal(sig)
+
+ @_assert_pid_not_reused
+ def suspend(self):
+ """Suspend process execution with SIGSTOP pre-emptively checking
+ whether PID has been reused.
+ On Windows this has the effect of suspending all process threads.
+ """
+ if POSIX:
+ self._send_signal(signal.SIGSTOP)
+ else: # pragma: no cover
+ self._proc.suspend()
+
+ @_assert_pid_not_reused
+ def resume(self):
+ """Resume process execution with SIGCONT pre-emptively checking
+ whether PID has been reused.
+ On Windows this has the effect of resuming all process threads.
+ """
+ if POSIX:
+ self._send_signal(signal.SIGCONT)
+ else: # pragma: no cover
+ self._proc.resume()
+
+ @_assert_pid_not_reused
+ def terminate(self):
+ """Terminate the process with SIGTERM pre-emptively checking
+ whether PID has been reused.
+ On Windows this is an alias for kill().
+ """
+ if POSIX:
+ self._send_signal(signal.SIGTERM)
+ else: # pragma: no cover
+ self._proc.kill()
+
+ @_assert_pid_not_reused
+ def kill(self):
+ """Kill the current process with SIGKILL pre-emptively checking
+ whether PID has been reused.
+ """
+ if POSIX:
+ self._send_signal(signal.SIGKILL)
+ else: # pragma: no cover
+ self._proc.kill()
+
+ def wait(self, timeout=None):
+        """Wait for process to terminate and, if process is a child
+        of os.getpid(), also return its exit code, else None.
+ On Windows there's no such limitation (exit code is always
+ returned).
+
+ If the process is already terminated immediately return None
+ instead of raising NoSuchProcess.
+
+ If *timeout* (in seconds) is specified and process is still
+ alive raise TimeoutExpired.
+
+ To wait for multiple Process(es) use psutil.wait_procs().
+ """
+ if timeout is not None and not timeout >= 0:
+            raise ValueError("timeout must not be negative")
+ if self._exitcode is not _SENTINEL:
+ return self._exitcode
+ self._exitcode = self._proc.wait(timeout)
+ return self._exitcode
+
+
+# The valid attr names which can be processed by Process.as_dict().
+_as_dict_attrnames = set(
+ [x for x in dir(Process) if not x.startswith('_') and x not in
+ ['send_signal', 'suspend', 'resume', 'terminate', 'kill', 'wait',
+ 'is_running', 'as_dict', 'parent', 'parents', 'children', 'rlimit',
+ 'memory_info_ex', 'oneshot']])
+
+
+# =====================================================================
+# --- Popen class
+# =====================================================================
+
+
+class Popen(Process):
+ """Same as subprocess.Popen, but in addition it provides all
+ psutil.Process methods in a single class.
+ For the following methods which are common to both classes, psutil
+ implementation takes precedence:
+
+ * send_signal()
+ * terminate()
+ * kill()
+
+ This is done in order to avoid killing another process in case its
+ PID has been reused, fixing BPO-6973.
+
+ >>> import psutil
+ >>> from subprocess import PIPE
+ >>> p = psutil.Popen(["python", "-c", "print 'hi'"], stdout=PIPE)
+ >>> p.name()
+ 'python'
+ >>> p.uids()
+ user(real=1000, effective=1000, saved=1000)
+ >>> p.username()
+ 'giampaolo'
+ >>> p.communicate()
+ ('hi\n', None)
+ >>> p.terminate()
+ >>> p.wait(timeout=2)
+ 0
+ >>>
+ """
+
+ def __init__(self, *args, **kwargs):
+        # Explicitly avoid raising NoSuchProcess in case the process
+ # spawned by subprocess.Popen terminates too quickly, see:
+ # https://github.com/giampaolo/psutil/issues/193
+ self.__subproc = subprocess.Popen(*args, **kwargs)
+ self._init(self.__subproc.pid, _ignore_nsp=True)
+
+ def __dir__(self):
+ return sorted(set(dir(Popen) + dir(subprocess.Popen)))
+
+ def __enter__(self):
+ if hasattr(self.__subproc, '__enter__'):
+ self.__subproc.__enter__()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ if hasattr(self.__subproc, '__exit__'):
+ return self.__subproc.__exit__(*args, **kwargs)
+ else:
+ if self.stdout:
+ self.stdout.close()
+ if self.stderr:
+ self.stderr.close()
+ try:
+ # Flushing a BufferedWriter may raise an error.
+ if self.stdin:
+ self.stdin.close()
+ finally:
+ # Wait for the process to terminate, to avoid zombies.
+ self.wait()
+
+ def __getattribute__(self, name):
+ try:
+ return object.__getattribute__(self, name)
+ except AttributeError:
+ try:
+ return object.__getattribute__(self.__subproc, name)
+ except AttributeError:
+ raise AttributeError("%s instance has no attribute '%s'"
+ % (self.__class__.__name__, name))
+
+ def wait(self, timeout=None):
+ if self.__subproc.returncode is not None:
+ return self.__subproc.returncode
+ ret = super(Popen, self).wait(timeout)
+ self.__subproc.returncode = ret
+ return ret
+
+
+# =====================================================================
+# --- system processes related functions
+# =====================================================================
+
+
+def pids():
+ """Return a list of current running PIDs."""
+ global _LOWEST_PID
+ ret = sorted(_psplatform.pids())
+ _LOWEST_PID = ret[0]
+ return ret
+
+
+def pid_exists(pid):
+ """Return True if given PID exists in the current process list.
+ This is faster than doing "pid in psutil.pids()" and
+ should be preferred.
+ """
+ if pid < 0:
+ return False
+ elif pid == 0 and POSIX:
+ # On POSIX we use os.kill() to determine PID existence.
+ # According to "man 2 kill" PID 0 has a special meaning
+ # though: it refers to <<every process in the process
+        # group of the calling process>> and that is not what we want
+ # to do here.
+ return pid in pids()
+ else:
+ return _psplatform.pid_exists(pid)
+
+
+_pmap = {}
+
+
+def process_iter(attrs=None, ad_value=None):
+ """Return a generator yielding a Process instance for all
+ running processes.
+
+ Every new Process instance is only created once and then cached
+ into an internal table which is updated every time this is used.
+
+ Cached Process instances are checked for identity so that you're
+ safe in case a PID has been reused by another process, in which
+ case the cached instance is updated.
+
+ The sorting order in which processes are yielded is based on
+ their PIDs.
+
+ *attrs* and *ad_value* have the same meaning as in
+ Process.as_dict(). If *attrs* is specified as_dict() is called
+ and the resulting dict is stored as a 'info' attribute attached
+ to returned Process instance.
+ If *attrs* is an empty list it will retrieve all process info
+ (slow).
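+
+    Example (a minimal sketch; output values are illustrative):
+
+    >>> import psutil
+    >>> for proc in psutil.process_iter(['pid', 'name']):
+    ...     print(proc.info)
+    ...
+    {'pid': 1, 'name': 'systemd'}
+    {'pid': 2, 'name': 'kthreadd'}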
+ """
+ global _pmap
+
+ def add(pid):
+ proc = Process(pid)
+ if attrs is not None:
+ proc.info = proc.as_dict(attrs=attrs, ad_value=ad_value)
+ pmap[proc.pid] = proc
+ return proc
+
+ def remove(pid):
+ pmap.pop(pid, None)
+
+ pmap = _pmap.copy()
+ a = set(pids())
+ b = set(pmap.keys())
+ new_pids = a - b
+ gone_pids = b - a
+ for pid in gone_pids:
+ remove(pid)
+ try:
+ ls = sorted(list(pmap.items()) + list(dict.fromkeys(new_pids).items()))
+ for pid, proc in ls:
+ try:
+ if proc is None: # new process
+ yield add(pid)
+ else:
+ # use is_running() to check whether PID has been
+ # reused by another process in which case yield a
+ # new Process instance
+ if proc.is_running():
+ if attrs is not None:
+ proc.info = proc.as_dict(
+ attrs=attrs, ad_value=ad_value)
+ yield proc
+ else:
+ yield add(pid)
+ except NoSuchProcess:
+ remove(pid)
+ except AccessDenied:
+ # Process creation time can't be determined hence there's
+ # no way to tell whether the pid of the cached process
+ # has been reused. Just return the cached version.
+ if proc is None and pid in pmap:
+ try:
+ yield pmap[pid]
+ except KeyError:
+ # If we get here it is likely that 2 threads were
+ # using process_iter().
+ pass
+ else:
+ raise
+ finally:
+ _pmap = pmap
+
+
+def wait_procs(procs, timeout=None, callback=None):
+ """Convenience function which waits for a list of processes to
+ terminate.
+
+ Return a (gone, alive) tuple indicating which processes
+ are gone and which ones are still alive.
+
+ The gone ones will have a new *returncode* attribute indicating
+ process exit status (may be None).
+
+ *callback* is a function which gets called every time a process
+ terminates (a Process instance is passed as callback argument).
+
+ Function will return as soon as all processes terminate or when
+ *timeout* occurs.
+ Differently from Process.wait() it will not raise TimeoutExpired if
+ *timeout* occurs.
+
+ Typical use case is:
+
+ - send SIGTERM to a list of processes
+ - give them some time to terminate
+ - send SIGKILL to those ones which are still alive
+
+ Example:
+
+ >>> def on_terminate(proc):
+ ... print("process {} terminated".format(proc))
+ ...
+ >>> for p in procs:
+ ... p.terminate()
+ ...
+ >>> gone, alive = wait_procs(procs, timeout=3, callback=on_terminate)
+ >>> for p in alive:
+ ... p.kill()
+ """
+ def check_gone(proc, timeout):
+ try:
+ returncode = proc.wait(timeout=timeout)
+ except TimeoutExpired:
+ pass
+ except _SubprocessTimeoutExpired:
+ pass
+ else:
+ if returncode is not None or not proc.is_running():
+ # Set new Process instance attribute.
+ proc.returncode = returncode
+ gone.add(proc)
+ if callback is not None:
+ callback(proc)
+
+ if timeout is not None and not timeout >= 0:
+        msg = "timeout must not be negative, got %s" % timeout
+ raise ValueError(msg)
+ gone = set()
+ alive = set(procs)
+ if callback is not None and not callable(callback):
+        raise TypeError("callback %r is not a callable" % callback)
+ if timeout is not None:
+ deadline = _timer() + timeout
+
+ while alive:
+ if timeout is not None and timeout <= 0:
+ break
+ for proc in alive:
+ # Make sure that every complete iteration (all processes)
+ # will last max 1 sec.
+ # We do this because we don't want to wait too long on a
+ # single process: in case it terminates too late other
+ # processes may disappear in the meantime and their PID
+ # reused.
+ max_timeout = 1.0 / len(alive)
+ if timeout is not None:
+ timeout = min((deadline - _timer()), max_timeout)
+ if timeout <= 0:
+ break
+ check_gone(proc, timeout)
+ else:
+ check_gone(proc, max_timeout)
+ alive = alive - gone
+
+ if alive:
+ # Last attempt over processes survived so far.
+ # timeout == 0 won't make this function wait any further.
+ for proc in alive:
+ check_gone(proc, 0)
+ alive = alive - gone
+
+ return (list(gone), list(alive))
+
+
+# =====================================================================
+# --- CPU related functions
+# =====================================================================
+
+
+def cpu_count(logical=True):
+ """Return the number of logical CPUs in the system (same as
+ os.cpu_count() in Python 3.4).
+
+ If *logical* is False return the number of physical cores only
+ (e.g. hyper thread CPUs are excluded).
+
+ Return None if undetermined.
+
+ The return value is cached after first call.
+ If desired cache can be cleared like this:
+
+ >>> psutil.cpu_count.cache_clear()
+ """
+ if logical:
+ ret = _psplatform.cpu_count_logical()
+ else:
+ ret = _psplatform.cpu_count_cores()
+ if ret is not None and ret < 1:
+ ret = None
+ return ret
+
+
+def cpu_times(percpu=False):
+ """Return system-wide CPU times as a namedtuple.
+ Every CPU time represents the seconds the CPU has spent in the
+    given mode. The availability of the namedtuple's fields varies
+    depending on the platform:
+
+ - user
+ - system
+ - idle
+ - nice (UNIX)
+ - iowait (Linux)
+ - irq (Linux, FreeBSD)
+ - softirq (Linux)
+ - steal (Linux >= 2.6.11)
+ - guest (Linux >= 2.6.24)
+ - guest_nice (Linux >= 3.2.0)
+
+ When *percpu* is True return a list of namedtuples for each CPU.
+ First element of the list refers to first CPU, second element
+ to second CPU and so on.
+ The order of the list is consistent across calls.
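+
+    Example (Linux; the values shown are illustrative):
+
+    >>> import psutil
+    >>> psutil.cpu_times()
+    scputimes(user=17411.7, nice=77.99, system=3797.02, idle=51266.57,
+              iowait=732.58, irq=0.01, softirq=142.43, steal=0.0,
+              guest=0.0, guest_nice=0.0)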
+ """
+ if not percpu:
+ return _psplatform.cpu_times()
+ else:
+ return _psplatform.per_cpu_times()
+
+
+try:
+ _last_cpu_times = cpu_times()
+except Exception:
+ # Don't want to crash at import time.
+ _last_cpu_times = None
+
+try:
+ _last_per_cpu_times = cpu_times(percpu=True)
+except Exception:
+ # Don't want to crash at import time.
+ _last_per_cpu_times = None
+
+
+def _cpu_tot_time(times):
+ """Given a cpu_time() ntuple calculates the total CPU time
+ (including idle time).
+ """
+ tot = sum(times)
+ if LINUX:
+ # On Linux guest times are already accounted in "user" or
+ # "nice" times, so we subtract them from total.
+ # Htop does the same. References:
+ # https://github.com/giampaolo/psutil/pull/940
+ # http://unix.stackexchange.com/questions/178045
+ # https://github.com/torvalds/linux/blob/
+ # 447976ef4fd09b1be88b316d1a81553f1aa7cd07/kernel/sched/
+ # cputime.c#L158
+ tot -= getattr(times, "guest", 0) # Linux 2.6.24+
+ tot -= getattr(times, "guest_nice", 0) # Linux 3.2.0+
+ return tot
+
+
+def _cpu_busy_time(times):
+ """Given a cpu_time() ntuple calculates the busy CPU time.
+ We do so by subtracting all idle CPU times.
+ """
+ busy = _cpu_tot_time(times)
+ busy -= times.idle
+ # Linux: "iowait" is time during which the CPU does not do anything
+ # (waits for IO to complete). On Linux IO wait is *not* accounted
+ # in "idle" time so we subtract it. Htop does the same.
+ # References:
+ # https://github.com/torvalds/linux/blob/
+ # 447976ef4fd09b1be88b316d1a81553f1aa7cd07/kernel/sched/cputime.c#L244
+ busy -= getattr(times, "iowait", 0)
+ return busy
+
+
+def _cpu_times_deltas(t1, t2):
+ assert t1._fields == t2._fields, (t1, t2)
+ field_deltas = []
+ for field in _psplatform.scputimes._fields:
+ field_delta = getattr(t2, field) - getattr(t1, field)
+ # CPU times are always supposed to increase over time
+ # or at least remain the same and that's because time
+ # cannot go backwards.
+ # Surprisingly sometimes this might not be the case (at
+ # least on Windows and Linux), see:
+ # https://github.com/giampaolo/psutil/issues/392
+ # https://github.com/giampaolo/psutil/issues/645
+ # https://github.com/giampaolo/psutil/issues/1210
+ # Trim negative deltas to zero to ignore decreasing fields.
+ # top does the same. Reference:
+ # https://gitlab.com/procps-ng/procps/blob/v3.3.12/top/top.c#L5063
+ field_delta = max(0, field_delta)
+ field_deltas.append(field_delta)
+ return _psplatform.scputimes(*field_deltas)
+
+
+def cpu_percent(interval=None, percpu=False):
+ """Return a float representing the current system-wide CPU
+ utilization as a percentage.
+
+ When *interval* is > 0.0 compares system CPU times elapsed before
+ and after the interval (blocking).
+
+ When *interval* is 0.0 or None compares system CPU times elapsed
+ since last call or module import, returning immediately (non
+ blocking). That means the first time this is called it will
+ return a meaningless 0.0 value which you should ignore.
+    In this case it is recommended for accuracy that this function be
+ called with at least 0.1 seconds between calls.
+
+ When *percpu* is True returns a list of floats representing the
+ utilization as a percentage for each CPU.
+ First element of the list refers to first CPU, second element
+ to second CPU and so on.
+ The order of the list is consistent across calls.
+
+ Examples:
+
+ >>> # blocking, system-wide
+ >>> psutil.cpu_percent(interval=1)
+ 2.0
+ >>>
+ >>> # blocking, per-cpu
+ >>> psutil.cpu_percent(interval=1, percpu=True)
+ [2.0, 1.0]
+ >>>
+ >>> # non-blocking (percentage since last call)
+ >>> psutil.cpu_percent(interval=None)
+ 2.9
+ >>>
+ """
+ global _last_cpu_times
+ global _last_per_cpu_times
+ blocking = interval is not None and interval > 0.0
+ if interval is not None and interval < 0:
+ raise ValueError("interval is not positive (got %r)" % interval)
+
+ def calculate(t1, t2):
+ times_delta = _cpu_times_deltas(t1, t2)
+ all_delta = _cpu_tot_time(times_delta)
+ busy_delta = _cpu_busy_time(times_delta)
+
+ try:
+ busy_perc = (busy_delta / all_delta) * 100
+ except ZeroDivisionError:
+ return 0.0
+ else:
+ return round(busy_perc, 1)
+
+ # system-wide usage
+ if not percpu:
+ if blocking:
+ t1 = cpu_times()
+ time.sleep(interval)
+ else:
+ t1 = _last_cpu_times
+ if t1 is None:
+ # Something bad happened at import time. We'll
+ # get a meaningful result on the next call. See:
+ # https://github.com/giampaolo/psutil/pull/715
+ t1 = cpu_times()
+ _last_cpu_times = cpu_times()
+ return calculate(t1, _last_cpu_times)
+ # per-cpu usage
+ else:
+ ret = []
+ if blocking:
+ tot1 = cpu_times(percpu=True)
+ time.sleep(interval)
+ else:
+ tot1 = _last_per_cpu_times
+ if tot1 is None:
+ # Something bad happened at import time. We'll
+ # get a meaningful result on the next call. See:
+ # https://github.com/giampaolo/psutil/pull/715
+ tot1 = cpu_times(percpu=True)
+ _last_per_cpu_times = cpu_times(percpu=True)
+ for t1, t2 in zip(tot1, _last_per_cpu_times):
+ ret.append(calculate(t1, t2))
+ return ret
+
+
+# Use separate global vars for cpu_times_percent() so that it's
+# independent from cpu_percent() and they can both be used within
+# the same program.
+_last_cpu_times_2 = _last_cpu_times
+_last_per_cpu_times_2 = _last_per_cpu_times
+
+
+def cpu_times_percent(interval=None, percpu=False):
+ """Same as cpu_percent() but provides utilization percentages
+ for each specific CPU time as is returned by cpu_times().
+ For instance, on Linux we'll get:
+
+ >>> cpu_times_percent()
+ cpupercent(user=4.8, nice=0.0, system=4.8, idle=90.5, iowait=0.0,
+ irq=0.0, softirq=0.0, steal=0.0, guest=0.0, guest_nice=0.0)
+ >>>
+
+ *interval* and *percpu* arguments have the same meaning as in
+ cpu_percent().
+ """
+ global _last_cpu_times_2
+ global _last_per_cpu_times_2
+ blocking = interval is not None and interval > 0.0
+ if interval is not None and interval < 0:
+ raise ValueError("interval is not positive (got %r)" % interval)
+
+ def calculate(t1, t2):
+ nums = []
+ times_delta = _cpu_times_deltas(t1, t2)
+ all_delta = _cpu_tot_time(times_delta)
+ # "scale" is the value to multiply each delta with to get percentages.
+ # We use "max" to avoid division by zero (if all_delta is 0, then all
+ # fields are 0 so percentages will be 0 too. all_delta cannot be a
+ # fraction because cpu times are integers)
+ scale = 100.0 / max(1, all_delta)
+ for field_delta in times_delta:
+ field_perc = field_delta * scale
+ field_perc = round(field_perc, 1)
+ # make sure we don't return negative values or values over 100%
+ field_perc = min(max(0.0, field_perc), 100.0)
+ nums.append(field_perc)
+ return _psplatform.scputimes(*nums)
+
+ # system-wide usage
+ if not percpu:
+ if blocking:
+ t1 = cpu_times()
+ time.sleep(interval)
+ else:
+ t1 = _last_cpu_times_2
+ if t1 is None:
+ # Something bad happened at import time. We'll
+ # get a meaningful result on the next call. See:
+ # https://github.com/giampaolo/psutil/pull/715
+ t1 = cpu_times()
+ _last_cpu_times_2 = cpu_times()
+ return calculate(t1, _last_cpu_times_2)
+ # per-cpu usage
+ else:
+ ret = []
+ if blocking:
+ tot1 = cpu_times(percpu=True)
+ time.sleep(interval)
+ else:
+ tot1 = _last_per_cpu_times_2
+ if tot1 is None:
+ # Something bad happened at import time. We'll
+ # get a meaningful result on the next call. See:
+ # https://github.com/giampaolo/psutil/pull/715
+ tot1 = cpu_times(percpu=True)
+ _last_per_cpu_times_2 = cpu_times(percpu=True)
+ for t1, t2 in zip(tot1, _last_per_cpu_times_2):
+ ret.append(calculate(t1, t2))
+ return ret
+
+
+def cpu_stats():
+ """Return CPU statistics."""
+ return _psplatform.cpu_stats()
+
+
+if hasattr(_psplatform, "cpu_freq"):
+
+ def cpu_freq(percpu=False):
+ """Return CPU frequency as a namedtuple including current,
+        min and max frequency expressed in MHz.
+
+ If *percpu* is True and the system supports per-cpu frequency
+ retrieval (Linux only) a list of frequencies is returned for
+        each CPU. If not, a list with one element is returned.
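+
+        Example (the frequencies shown are illustrative, in MHz):
+
+        >>> import psutil
+        >>> psutil.cpu_freq()
+        scpufreq(current=931.42, min=800.0, max=3500.0)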
+ """
+ ret = _psplatform.cpu_freq()
+ if percpu:
+ return ret
+ else:
+ num_cpus = float(len(ret))
+ if num_cpus == 0:
+ return None
+ elif num_cpus == 1:
+ return ret[0]
+ else:
+ currs, mins, maxs = 0.0, 0.0, 0.0
+ set_none = False
+ for cpu in ret:
+ currs += cpu.current
+ # On Linux if /proc/cpuinfo is used min/max are set
+ # to None.
+ if LINUX and cpu.min is None:
+ set_none = True
+ continue
+ mins += cpu.min
+ maxs += cpu.max
+
+ current = currs / num_cpus
+
+ if set_none:
+ min_ = max_ = None
+ else:
+ min_ = mins / num_cpus
+ max_ = maxs / num_cpus
+
+ return _common.scpufreq(current, min_, max_)
+
+ __all__.append("cpu_freq")
+
+
+if hasattr(os, "getloadavg") or hasattr(_psplatform, "getloadavg"):
+    # Perform this hasattr check once at import time to either use the
+ # platform based code or proxy straight from the os module.
+ if hasattr(os, "getloadavg"):
+ getloadavg = os.getloadavg
+ else:
+ getloadavg = _psplatform.getloadavg
+
+ __all__.append("getloadavg")
+
+
+# =====================================================================
+# --- system memory related functions
+# =====================================================================
+
+
+def virtual_memory():
+ """Return statistics about system memory usage as a namedtuple
+ including the following fields, expressed in bytes:
+
+ - total:
+ total physical memory available.
+
+ - available:
+ the memory that can be given instantly to processes without the
+ system going into swap.
+ This is calculated by summing different memory values depending
+ on the platform and it is supposed to be used to monitor actual
+ memory usage in a cross platform fashion.
+
+ - percent:
+ the percentage usage calculated as (total - available) / total * 100
+
+ - used:
+ memory used, calculated differently depending on the platform and
+ designed for informational purposes only:
+ macOS: active + wired
+ BSD: active + wired + cached
+ Linux: total - free
+
+ - free:
+ memory not being used at all (zeroed) that is readily available;
+ note that this doesn't reflect the actual memory available
+ (use 'available' instead)
+
+ Platform-specific fields:
+
+ - active (UNIX):
+ memory currently in use or very recently used, and so it is in RAM.
+
+ - inactive (UNIX):
+ memory that is marked as not used.
+
+ - buffers (BSD, Linux):
+ cache for things like file system metadata.
+
+ - cached (BSD, macOS):
+ cache for various things.
+
+ - wired (macOS, BSD):
+ memory that is marked to always stay in RAM. It is never moved to disk.
+
+ - shared (BSD):
+ memory that may be simultaneously accessed by multiple processes.
+
+ The sum of 'used' and 'available' does not necessarily equal total.
+ On Windows 'available' and 'free' are the same.
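+
+    Example (the values shown are illustrative):
+
+    >>> import psutil
+    >>> mem = psutil.virtual_memory()
+    >>> mem.total, mem.available, mem.percent
+    (8374149120, 1247113216, 85.1)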
+ """
+ global _TOTAL_PHYMEM
+ ret = _psplatform.virtual_memory()
+ # cached for later use in Process.memory_percent()
+ _TOTAL_PHYMEM = ret.total
+ return ret
+
+
+def swap_memory():
+ """Return system swap memory statistics as a namedtuple including
+ the following fields:
+
+ - total: total swap memory in bytes
+ - used: used swap memory in bytes
+ - free: free swap memory in bytes
+ - percent: the percentage usage
+ - sin: no. of bytes the system has swapped in from disk (cumulative)
+ - sout: no. of bytes the system has swapped out from disk (cumulative)
+
+ 'sin' and 'sout' on Windows are meaningless and always set to 0.
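+
+    Example (the values shown are illustrative):
+
+    >>> import psutil
+    >>> psutil.swap_memory()
+    sswap(total=2097147904, used=886620160, free=1210527744, percent=42.3,
+          sin=1050411008, sout=1906720768)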
+ """
+ return _psplatform.swap_memory()
+
+
+# =====================================================================
+# --- disks/partitions related functions
+# =====================================================================
+
+
+def disk_usage(path):
+ """Return disk usage statistics about the given *path* as a
+ namedtuple including total, used and free space expressed in bytes
+ plus the percentage usage.
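+
+    Example (the values shown are illustrative):
+
+    >>> import psutil
+    >>> psutil.disk_usage('/')
+    sdiskusage(total=21378641920, used=4809781248, free=15482871808,
+               percent=22.5)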
+ """
+ return _psplatform.disk_usage(path)
+
+
+def disk_partitions(all=False):
+ """Return mounted partitions as a list of
+ (device, mountpoint, fstype, opts) namedtuple.
+    'opts' field is a raw comma-separated string indicating mount
+    options, which may vary depending on the platform.
+
+ If *all* parameter is False return physical devices only and ignore
+ all others.
+ """
+ def pathconf(path, name):
+ try:
+ return os.pathconf(path, name)
+ except (OSError, AttributeError):
+ pass
+
+ ret = _psplatform.disk_partitions(all)
+ if POSIX:
+ new = []
+ for item in ret:
+ nt = item._replace(
+ maxfile=pathconf(item.mountpoint, 'PC_NAME_MAX'),
+ maxpath=pathconf(item.mountpoint, 'PC_PATH_MAX'))
+ new.append(nt)
+ return new
+ else:
+ return ret
+
+
+def disk_io_counters(perdisk=False, nowrap=True):
+ """Return system disk I/O statistics as a namedtuple including
+ the following fields:
+
+ - read_count: number of reads
+ - write_count: number of writes
+ - read_bytes: number of bytes read
+ - write_bytes: number of bytes written
+ - read_time: time spent reading from disk (in ms)
+ - write_time: time spent writing to disk (in ms)
+
+ Platform specific:
+
+ - busy_time: (Linux, FreeBSD) time spent doing actual I/Os (in ms)
+ - read_merged_count (Linux): number of merged reads
+ - write_merged_count (Linux): number of merged writes
+
+ If *perdisk* is True return the same information for every
+ physical disk installed on the system as a dictionary
+ with partition names as the keys and the namedtuple
+ described above as the values.
+
+    If *nowrap* is True it detects and adjusts the numbers which overflow
+    and wrap (restart from 0) and adds "old value" to "new value" so that
+ the returned numbers will always be increasing or remain the same,
+ but never decrease.
+ "disk_io_counters.cache_clear()" can be used to invalidate the
+ cache.
+
+ On recent Windows versions 'diskperf -y' command may need to be
+ executed first otherwise this function won't find any disk.
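+
+    Example (the values shown are illustrative):
+
+    >>> import psutil
+    >>> psutil.disk_io_counters()
+    sdiskio(read_count=8141, write_count=2431, read_bytes=290203,
+            write_bytes=537676, read_time=5868, write_time=94922)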
+ """
+ kwargs = dict(perdisk=perdisk) if LINUX else {}
+ rawdict = _psplatform.disk_io_counters(**kwargs)
+ if not rawdict:
+ return {} if perdisk else None
+ if nowrap:
+ rawdict = _wrap_numbers(rawdict, 'psutil.disk_io_counters')
+ nt = getattr(_psplatform, "sdiskio", _common.sdiskio)
+ if perdisk:
+ for disk, fields in rawdict.items():
+ rawdict[disk] = nt(*fields)
+ return rawdict
+ else:
+ return nt(*(sum(x) for x in zip(*rawdict.values())))
+
+
+disk_io_counters.cache_clear = functools.partial(
+ _wrap_numbers.cache_clear, 'psutil.disk_io_counters')
+disk_io_counters.cache_clear.__doc__ = "Clears nowrap argument cache"
+
+
+# =====================================================================
+# --- network related functions
+# =====================================================================
+
+
+def net_io_counters(pernic=False, nowrap=True):
+ """Return network I/O statistics as a namedtuple including
+ the following fields:
+
+ - bytes_sent: number of bytes sent
+ - bytes_recv: number of bytes received
+ - packets_sent: number of packets sent
+ - packets_recv: number of packets received
+ - errin: total number of errors while receiving
+ - errout: total number of errors while sending
+ - dropin: total number of incoming packets which were dropped
+ - dropout: total number of outgoing packets which were dropped
+ (always 0 on macOS and BSD)
+
+ If *pernic* is True return the same information for every
+ network interface installed on the system as a dictionary
+ with network interface names as the keys and the namedtuple
+ described above as the values.
+
+    If *nowrap* is True it detects and adjusts the numbers which overflow
+    and wrap (restart from 0) and adds "old value" to "new value" so that
+    the returned numbers will always be increasing or remain the same,
+    but never decrease.
+    "net_io_counters.cache_clear()" can be used to invalidate the
+ cache.
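+
+    Example (the values shown are illustrative):
+
+    >>> import psutil
+    >>> psutil.net_io_counters()
+    snetio(bytes_sent=14508483, bytes_recv=62749361, packets_sent=84311,
+           packets_recv=94888, errin=0, errout=0, dropin=0, dropout=0)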
+ """
+ rawdict = _psplatform.net_io_counters()
+ if not rawdict:
+ return {} if pernic else None
+ if nowrap:
+ rawdict = _wrap_numbers(rawdict, 'psutil.net_io_counters')
+ if pernic:
+ for nic, fields in rawdict.items():
+ rawdict[nic] = _common.snetio(*fields)
+ return rawdict
+ else:
+ return _common.snetio(*[sum(x) for x in zip(*rawdict.values())])
+
+
+net_io_counters.cache_clear = functools.partial(
+ _wrap_numbers.cache_clear, 'psutil.net_io_counters')
+net_io_counters.cache_clear.__doc__ = "Clears nowrap argument cache"
+
+
+def net_connections(kind='inet'):
+ """Return system-wide socket connections as a list of
+ (fd, family, type, laddr, raddr, status, pid) namedtuples.
+ In case of limited privileges 'fd' and 'pid' may be set to -1
+ and None respectively.
+ The *kind* parameter filters for connections that fit the
+ following criteria:
+
+ +------------+----------------------------------------------------+
+ | Kind Value | Connections using |
+ +------------+----------------------------------------------------+
+ | inet | IPv4 and IPv6 |
+ | inet4 | IPv4 |
+ | inet6 | IPv6 |
+ | tcp | TCP |
+ | tcp4 | TCP over IPv4 |
+ | tcp6 | TCP over IPv6 |
+ | udp | UDP |
+ | udp4 | UDP over IPv4 |
+ | udp6 | UDP over IPv6 |
+ | unix | UNIX socket (both UDP and TCP protocols) |
+ | all | the sum of all the possible families and protocols |
+ +------------+----------------------------------------------------+
+
+ On macOS this function requires root privileges.
+ """
+ return _psplatform.net_connections(kind)
+
+
+def net_if_addrs():
+ """Return the addresses associated to each NIC (network interface
+    card) installed on the system as a dictionary whose keys are the
+    NIC names and whose values are a list of namedtuples for each
+    address assigned to the NIC. Each namedtuple includes 5 fields:
+
+ - family: can be either socket.AF_INET, socket.AF_INET6 or
+ psutil.AF_LINK, which refers to a MAC address.
+ - address: is the primary address and it is always set.
+    - netmask: the netmask address; like 'broadcast' and 'ptp' it may
+      be None.
+ - ptp: stands for "point to point" and references the
+ destination address on a point to point interface
+ (typically a VPN).
+    - broadcast: the broadcast address; 'broadcast' and 'ptp' are
+      mutually exclusive.
+
+ Note: you can have more than one address of the same family
+ associated with each interface.
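+
+    Example (illustrative output for an assumed Linux-style 'lo'
+    loopback interface):
+
+    >>> import psutil
+    >>> psutil.net_if_addrs()['lo']
+    [snicaddr(family=<AddressFamily.AF_INET: 2>, address='127.0.0.1',
+              netmask='255.0.0.0', broadcast=None, ptp=None)]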
+ """
+ has_enums = sys.version_info >= (3, 4)
+ if has_enums:
+ import socket
+ rawlist = _psplatform.net_if_addrs()
+ rawlist.sort(key=lambda x: x[1]) # sort by family
+ ret = collections.defaultdict(list)
+ for name, fam, addr, mask, broadcast, ptp in rawlist:
+ if has_enums:
+ try:
+ fam = socket.AddressFamily(fam)
+ except ValueError:
+ if WINDOWS and fam == -1:
+ fam = _psplatform.AF_LINK
+ elif (hasattr(_psplatform, "AF_LINK") and
+ _psplatform.AF_LINK == fam):
+ # Linux defines AF_LINK as an alias for AF_PACKET.
+ # We re-set the family here so that repr(family)
+ # will show AF_LINK rather than AF_PACKET
+ fam = _psplatform.AF_LINK
+ if fam == _psplatform.AF_LINK:
+ # The underlying C function may return an incomplete MAC
+ # address in which case we fill it with null bytes, see:
+ # https://github.com/giampaolo/psutil/issues/786
+ separator = ":" if POSIX else "-"
+ while addr.count(separator) < 5:
+ addr += "%s00" % separator
+ ret[name].append(_common.snicaddr(fam, addr, mask, broadcast, ptp))
+ return dict(ret)
+
+
+def net_if_stats():
+ """Return information about each NIC (network interface card)
+    installed on the system as a dictionary whose keys are the
+    NIC names and whose values are namedtuples with the following fields:
+
+ - isup: whether the interface is up (bool)
+ - duplex: can be either NIC_DUPLEX_FULL, NIC_DUPLEX_HALF or
+ NIC_DUPLEX_UNKNOWN
+    - speed: the NIC speed expressed in megabits per second; if it can't
+ be determined (e.g. 'localhost') it will be set to 0.
+ - mtu: the maximum transmission unit expressed in bytes.
+ """
+ return _psplatform.net_if_stats()
+
+
+# =====================================================================
+# --- sensors
+# =====================================================================
+
+
+# Linux, macOS
+if hasattr(_psplatform, "sensors_temperatures"):
+
+ def sensors_temperatures(fahrenheit=False):
+ """Return hardware temperatures. Each entry is a namedtuple
+        representing a certain hardware sensor (it may be a CPU, a
+        hard disk or something else, depending on the OS and its
+        configuration).
+        All temperatures are expressed in Celsius unless *fahrenheit*
+ is set to True.
+ """
+ def convert(n):
+ if n is not None:
+ return (float(n) * 9 / 5) + 32 if fahrenheit else n
+
+ ret = collections.defaultdict(list)
+ rawdict = _psplatform.sensors_temperatures()
+
+ for name, values in rawdict.items():
+ while values:
+ label, current, high, critical = values.pop(0)
+ current = convert(current)
+ high = convert(high)
+ critical = convert(critical)
+
+ if high and not critical:
+ critical = high
+ elif critical and not high:
+ high = critical
+
+ ret[name].append(
+ _common.shwtemp(label, current, high, critical))
+
+ return dict(ret)
+
+ __all__.append("sensors_temperatures")
+
+
+# Linux
+if hasattr(_psplatform, "sensors_fans"):
+
+ def sensors_fans():
+        """Return fan speeds. Each entry is a namedtuple
+        representing a certain hardware sensor.
+        All speeds are expressed in RPM (revolutions per minute).
+ """
+ return _psplatform.sensors_fans()
+
+ __all__.append("sensors_fans")
+
+
+# Linux, Windows, FreeBSD, macOS
+if hasattr(_psplatform, "sensors_battery"):
+
+ def sensors_battery():
+ """Return battery information. If no battery is installed
+ returns None.
+
+ - percent: battery power left as a percentage.
+ - secsleft: a rough approximation of how many seconds are left
+ before the battery runs out of power. May be
+                    POWER_TIME_UNKNOWN or POWER_TIME_UNLIMITED.
+ - power_plugged: True if the AC power cable is connected.
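+
+        Example (the values shown are illustrative):
+
+        >>> import psutil
+        >>> psutil.sensors_battery()
+        sbattery(percent=93, secsleft=16628, power_plugged=False)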
+ """
+ return _psplatform.sensors_battery()
+
+ __all__.append("sensors_battery")
+
+
+# =====================================================================
+# --- other system related functions
+# =====================================================================
+
+
+def boot_time():
+ """Return the system boot time expressed in seconds since the epoch."""
+ # Note: we are not caching this because it is subject to
+ # system clock updates.
+ return _psplatform.boot_time()
+
+
+def users():
+ """Return users currently connected on the system as a list of
+ namedtuples including the following fields.
+
+ - user: the name of the user
+ - terminal: the tty or pseudo-tty associated with the user, if any.
+ - host: the host name associated with the entry, if any.
+ - started: the creation time as a floating point number expressed in
+ seconds since the epoch.
+ """
+ return _psplatform.users()
+
+
+# =====================================================================
+# --- Windows services
+# =====================================================================
+
+
+if WINDOWS:
+
+ def win_service_iter():
+ """Return a generator yielding a WindowsService instance for all
+ Windows services installed.
+ """
+ return _psplatform.win_service_iter()
+
+ def win_service_get(name):
+ """Get a Windows service by *name*.
+ Raise NoSuchProcess if no service with such name exists.
+ """
+ return _psplatform.win_service_get(name)
+
+
+# =====================================================================
+
+
+def _set_debug(value):
+ """Enable or disable PSUTIL_DEBUG option, which prints debugging
+ messages to stderr.
+ """
+ import psutil._common
+ psutil._common.PSUTIL_DEBUG = bool(value)
+ _psplatform.cext.set_debug(bool(value))
+
+
+def test(): # pragma: no cover
+ from ._common import bytes2human
+ from ._compat import get_terminal_size
+
+ today_day = datetime.date.today()
+ templ = "%-10s %5s %5s %7s %7s %5s %6s %6s %6s %s"
+ attrs = ['pid', 'memory_percent', 'name', 'cmdline', 'cpu_times',
+ 'create_time', 'memory_info', 'status', 'nice', 'username']
+ print(templ % ("USER", "PID", "%MEM", "VSZ", "RSS", "NICE", # NOQA
+ "STATUS", "START", "TIME", "CMDLINE"))
+ for p in process_iter(attrs, ad_value=None):
+ if p.info['create_time']:
+ ctime = datetime.datetime.fromtimestamp(p.info['create_time'])
+ if ctime.date() == today_day:
+ ctime = ctime.strftime("%H:%M")
+ else:
+ ctime = ctime.strftime("%b%d")
+ else:
+ ctime = ''
+ if p.info['cpu_times']:
+ cputime = time.strftime("%M:%S",
+ time.localtime(sum(p.info['cpu_times'])))
+ else:
+ cputime = ''
+
+ user = p.info['username'] or ''
+ if not user and POSIX:
+ try:
+ user = p.uids()[0]
+ except Error:
+ pass
+ if user and WINDOWS and '\\' in user:
+ user = user.split('\\')[1]
+ user = user[:9]
+ vms = bytes2human(p.info['memory_info'].vms) if \
+ p.info['memory_info'] is not None else ''
+ rss = bytes2human(p.info['memory_info'].rss) if \
+ p.info['memory_info'] is not None else ''
+ memp = round(p.info['memory_percent'], 1) if \
+ p.info['memory_percent'] is not None else ''
+ nice = int(p.info['nice']) if p.info['nice'] else ''
+ if p.info['cmdline']:
+ cmdline = ' '.join(p.info['cmdline'])
+ else:
+ cmdline = p.info['name']
+ status = p.info['status'][:5] if p.info['status'] else ''
+
+ line = templ % (
+ user[:10],
+ p.info['pid'],
+ memp,
+ vms,
+ rss,
+ nice,
+ status,
+ ctime,
+ cputime,
+ cmdline)
+ print(line[:get_terminal_size()[0]]) # NOQA
+
+
+del memoize_when_activated, division
+if sys.version_info[0] < 3:
+ del num, x
+
+if __name__ == "__main__":
+ test()
diff --git a/lib/psutil/_common.py b/lib/psutil/_common.py
new file mode 100644
index 0000000..3414e8c
--- /dev/null
+++ b/lib/psutil/_common.py
@@ -0,0 +1,899 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Common objects shared by __init__.py and _ps*.py modules."""
+
+# Note: this module is imported by setup.py so it should not import
+# psutil or third-party modules.
+
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import contextlib
+import errno
+import functools
+import os
+import socket
+import stat
+import sys
+import threading
+import warnings
+from collections import namedtuple
+from socket import AF_INET
+from socket import SOCK_DGRAM
+from socket import SOCK_STREAM
+
+
+try:
+ from socket import AF_INET6
+except ImportError:
+ AF_INET6 = None
+try:
+ from socket import AF_UNIX
+except ImportError:
+ AF_UNIX = None
+
+if sys.version_info >= (3, 4):
+ import enum
+else:
+ enum = None
+
+
+# can't take it from _common.py as this script is imported by setup.py
+PY3 = sys.version_info[0] == 3
+PSUTIL_DEBUG = bool(os.getenv('PSUTIL_DEBUG', 0))
+_DEFAULT = object()
+
+__all__ = [
+ # OS constants
+ 'FREEBSD', 'BSD', 'LINUX', 'NETBSD', 'OPENBSD', 'MACOS', 'OSX', 'POSIX',
+ 'SUNOS', 'WINDOWS',
+ # connection constants
+ 'CONN_CLOSE', 'CONN_CLOSE_WAIT', 'CONN_CLOSING', 'CONN_ESTABLISHED',
+ 'CONN_FIN_WAIT1', 'CONN_FIN_WAIT2', 'CONN_LAST_ACK', 'CONN_LISTEN',
+ 'CONN_NONE', 'CONN_SYN_RECV', 'CONN_SYN_SENT', 'CONN_TIME_WAIT',
+ # net constants
+ 'NIC_DUPLEX_FULL', 'NIC_DUPLEX_HALF', 'NIC_DUPLEX_UNKNOWN',
+ # process status constants
+ 'STATUS_DEAD', 'STATUS_DISK_SLEEP', 'STATUS_IDLE', 'STATUS_LOCKED',
+ 'STATUS_RUNNING', 'STATUS_SLEEPING', 'STATUS_STOPPED', 'STATUS_SUSPENDED',
+ 'STATUS_TRACING_STOP', 'STATUS_WAITING', 'STATUS_WAKE_KILL',
+ 'STATUS_WAKING', 'STATUS_ZOMBIE', 'STATUS_PARKED',
+ # other constants
+ 'ENCODING', 'ENCODING_ERRS', 'AF_INET6',
+ # named tuples
+ 'pconn', 'pcputimes', 'pctxsw', 'pgids', 'pio', 'pionice', 'popenfile',
+ 'pthread', 'puids', 'sconn', 'scpustats', 'sdiskio', 'sdiskpart',
+ 'sdiskusage', 'snetio', 'snicaddr', 'snicstats', 'sswap', 'suser',
+ # utility functions
+ 'conn_tmap', 'deprecated_method', 'isfile_strict', 'memoize',
+ 'parse_environ_block', 'path_exists_strict', 'usage_percent',
+ 'supports_ipv6', 'sockfam_to_enum', 'socktype_to_enum', "wrap_numbers",
+ 'open_text', 'open_binary', 'cat', 'bcat',
+ 'bytes2human', 'conn_to_ntuple', 'debug',
+ # shell utils
+ 'hilite', 'term_supports_colors', 'print_color',
+]
+
+
+# ===================================================================
+# --- OS constants
+# ===================================================================
+
+
+POSIX = os.name == "posix"
+WINDOWS = os.name == "nt"
+LINUX = sys.platform.startswith("linux")
+MACOS = sys.platform.startswith("darwin")
+OSX = MACOS # deprecated alias
+FREEBSD = sys.platform.startswith(("freebsd", "midnightbsd"))
+OPENBSD = sys.platform.startswith("openbsd")
+NETBSD = sys.platform.startswith("netbsd")
+BSD = FREEBSD or OPENBSD or NETBSD
+SUNOS = sys.platform.startswith(("sunos", "solaris"))
+AIX = sys.platform.startswith("aix")
+
+
+# ===================================================================
+# --- API constants
+# ===================================================================
+
+
+# Process.status()
+STATUS_RUNNING = "running"
+STATUS_SLEEPING = "sleeping"
+STATUS_DISK_SLEEP = "disk-sleep"
+STATUS_STOPPED = "stopped"
+STATUS_TRACING_STOP = "tracing-stop"
+STATUS_ZOMBIE = "zombie"
+STATUS_DEAD = "dead"
+STATUS_WAKE_KILL = "wake-kill"
+STATUS_WAKING = "waking"
+STATUS_IDLE = "idle" # Linux, macOS, FreeBSD
+STATUS_LOCKED = "locked" # FreeBSD
+STATUS_WAITING = "waiting" # FreeBSD
+STATUS_SUSPENDED = "suspended" # NetBSD
+STATUS_PARKED = "parked" # Linux
+
+# Process.connections() and psutil.net_connections()
+CONN_ESTABLISHED = "ESTABLISHED"
+CONN_SYN_SENT = "SYN_SENT"
+CONN_SYN_RECV = "SYN_RECV"
+CONN_FIN_WAIT1 = "FIN_WAIT1"
+CONN_FIN_WAIT2 = "FIN_WAIT2"
+CONN_TIME_WAIT = "TIME_WAIT"
+CONN_CLOSE = "CLOSE"
+CONN_CLOSE_WAIT = "CLOSE_WAIT"
+CONN_LAST_ACK = "LAST_ACK"
+CONN_LISTEN = "LISTEN"
+CONN_CLOSING = "CLOSING"
+CONN_NONE = "NONE"
+
+# net_if_stats()
+if enum is None:
+ NIC_DUPLEX_FULL = 2
+ NIC_DUPLEX_HALF = 1
+ NIC_DUPLEX_UNKNOWN = 0
+else:
+ class NicDuplex(enum.IntEnum):
+ NIC_DUPLEX_FULL = 2
+ NIC_DUPLEX_HALF = 1
+ NIC_DUPLEX_UNKNOWN = 0
+
+ globals().update(NicDuplex.__members__)
+
+# sensors_battery()
+if enum is None:
+ POWER_TIME_UNKNOWN = -1
+ POWER_TIME_UNLIMITED = -2
+else:
+ class BatteryTime(enum.IntEnum):
+ POWER_TIME_UNKNOWN = -1
+ POWER_TIME_UNLIMITED = -2
+
+ globals().update(BatteryTime.__members__)
+
+# --- others
+
+ENCODING = sys.getfilesystemencoding()
+if not PY3:
+ ENCODING_ERRS = "replace"
+else:
+ try:
+ ENCODING_ERRS = sys.getfilesystemencodeerrors() # py 3.6
+ except AttributeError:
+ ENCODING_ERRS = "surrogateescape" if POSIX else "replace"
+
+
+# ===================================================================
+# --- namedtuples
+# ===================================================================
+
+# --- for system functions
+
+# psutil.swap_memory()
+sswap = namedtuple('sswap', ['total', 'used', 'free', 'percent', 'sin',
+ 'sout'])
+# psutil.disk_usage()
+sdiskusage = namedtuple('sdiskusage', ['total', 'used', 'free', 'percent'])
+# psutil.disk_io_counters()
+sdiskio = namedtuple('sdiskio', ['read_count', 'write_count',
+ 'read_bytes', 'write_bytes',
+ 'read_time', 'write_time'])
+# psutil.disk_partitions()
+sdiskpart = namedtuple('sdiskpart', ['device', 'mountpoint', 'fstype', 'opts',
+ 'maxfile', 'maxpath'])
+# psutil.net_io_counters()
+snetio = namedtuple('snetio', ['bytes_sent', 'bytes_recv',
+ 'packets_sent', 'packets_recv',
+ 'errin', 'errout',
+ 'dropin', 'dropout'])
+# psutil.users()
+suser = namedtuple('suser', ['name', 'terminal', 'host', 'started', 'pid'])
+# psutil.net_connections()
+sconn = namedtuple('sconn', ['fd', 'family', 'type', 'laddr', 'raddr',
+ 'status', 'pid'])
+# psutil.net_if_addrs()
+snicaddr = namedtuple('snicaddr',
+ ['family', 'address', 'netmask', 'broadcast', 'ptp'])
+# psutil.net_if_stats()
+snicstats = namedtuple('snicstats',
+ ['isup', 'duplex', 'speed', 'mtu', 'flags'])
+# psutil.cpu_stats()
+scpustats = namedtuple(
+ 'scpustats', ['ctx_switches', 'interrupts', 'soft_interrupts', 'syscalls'])
+# psutil.cpu_freq()
+scpufreq = namedtuple('scpufreq', ['current', 'min', 'max'])
+# psutil.sensors_temperatures()
+shwtemp = namedtuple(
+ 'shwtemp', ['label', 'current', 'high', 'critical'])
+# psutil.sensors_battery()
+sbattery = namedtuple('sbattery', ['percent', 'secsleft', 'power_plugged'])
+# psutil.sensors_fans()
+sfan = namedtuple('sfan', ['label', 'current'])
+
+# --- for Process methods
+
+# psutil.Process.cpu_times()
+pcputimes = namedtuple('pcputimes',
+ ['user', 'system', 'children_user', 'children_system'])
+# psutil.Process.open_files()
+popenfile = namedtuple('popenfile', ['path', 'fd'])
+# psutil.Process.threads()
+pthread = namedtuple('pthread', ['id', 'user_time', 'system_time'])
+# psutil.Process.uids()
+puids = namedtuple('puids', ['real', 'effective', 'saved'])
+# psutil.Process.gids()
+pgids = namedtuple('pgids', ['real', 'effective', 'saved'])
+# psutil.Process.io_counters()
+pio = namedtuple('pio', ['read_count', 'write_count',
+ 'read_bytes', 'write_bytes'])
+# psutil.Process.ionice()
+pionice = namedtuple('pionice', ['ioclass', 'value'])
+# psutil.Process.ctx_switches()
+pctxsw = namedtuple('pctxsw', ['voluntary', 'involuntary'])
+# psutil.Process.connections()
+pconn = namedtuple('pconn', ['fd', 'family', 'type', 'laddr', 'raddr',
+ 'status'])
+
+# psutil.connections() and psutil.Process.connections()
+addr = namedtuple('addr', ['ip', 'port'])
+
+
+# ===================================================================
+# --- Process.connections() 'kind' parameter mapping
+# ===================================================================
+
+
+conn_tmap = {
+ "all": ([AF_INET, AF_INET6, AF_UNIX], [SOCK_STREAM, SOCK_DGRAM]),
+ "tcp": ([AF_INET, AF_INET6], [SOCK_STREAM]),
+ "tcp4": ([AF_INET], [SOCK_STREAM]),
+ "udp": ([AF_INET, AF_INET6], [SOCK_DGRAM]),
+ "udp4": ([AF_INET], [SOCK_DGRAM]),
+ "inet": ([AF_INET, AF_INET6], [SOCK_STREAM, SOCK_DGRAM]),
+ "inet4": ([AF_INET], [SOCK_STREAM, SOCK_DGRAM]),
+ "inet6": ([AF_INET6], [SOCK_STREAM, SOCK_DGRAM]),
+}
+
+if AF_INET6 is not None:
+ conn_tmap.update({
+ "tcp6": ([AF_INET6], [SOCK_STREAM]),
+ "udp6": ([AF_INET6], [SOCK_DGRAM]),
+ })
+
+if AF_UNIX is not None:
+ conn_tmap.update({
+ "unix": ([AF_UNIX], [SOCK_STREAM, SOCK_DGRAM]),
+ })
+
+
+# =====================================================================
+# --- Exceptions
+# =====================================================================
+
+
+class Error(Exception):
+ """Base exception class. All other psutil exceptions inherit
+ from this one.
+ """
+ __module__ = 'psutil'
+
+ def _infodict(self, attrs):
+ info = collections.OrderedDict()
+ for name in attrs:
+ value = getattr(self, name, None)
+ if value:
+ info[name] = value
+ elif name == "pid" and value == 0:
+ info[name] = value
+ return info
+
+ def __str__(self):
+ # invoked on `raise Error`
+ info = self._infodict(("pid", "ppid", "name"))
+ if info:
+ details = "(%s)" % ", ".join(
+ ["%s=%r" % (k, v) for k, v in info.items()])
+ else:
+ details = None
+ return " ".join([x for x in (getattr(self, "msg", ""), details) if x])
+
+ def __repr__(self):
+ # invoked on `repr(Error)`
+ info = self._infodict(("pid", "ppid", "name", "seconds", "msg"))
+ details = ", ".join(["%s=%r" % (k, v) for k, v in info.items()])
+ return "psutil.%s(%s)" % (self.__class__.__name__, details)
+
+
+class NoSuchProcess(Error):
+ """Exception raised when a process with a certain PID doesn't
+    exist or no longer exists.
+ """
+ __module__ = 'psutil'
+
+ def __init__(self, pid, name=None, msg=None):
+ Error.__init__(self)
+ self.pid = pid
+ self.name = name
+ self.msg = msg or "process no longer exists"
+
+
+class ZombieProcess(NoSuchProcess):
+ """Exception raised when querying a zombie process. This is
+ raised on macOS, BSD and Solaris only, and not always: depending
+ on the query the OS may be able to succeed anyway.
+    On Linux all zombie processes are queryable (hence this is never
+ raised). Windows doesn't have zombie processes.
+ """
+ __module__ = 'psutil'
+
+ def __init__(self, pid, name=None, ppid=None, msg=None):
+ NoSuchProcess.__init__(self, pid, name, msg)
+ self.ppid = ppid
+ self.msg = msg or "PID still exists but it's a zombie"
+
+
+class AccessDenied(Error):
+ """Exception raised when permission to perform an action is denied."""
+ __module__ = 'psutil'
+
+ def __init__(self, pid=None, name=None, msg=None):
+ Error.__init__(self)
+ self.pid = pid
+ self.name = name
+ self.msg = msg or ""
+
+
+class TimeoutExpired(Error):
+ """Raised on Process.wait(timeout) if timeout expires and process
+ is still alive.
+ """
+ __module__ = 'psutil'
+
+ def __init__(self, seconds, pid=None, name=None):
+ Error.__init__(self)
+ self.seconds = seconds
+ self.pid = pid
+ self.name = name
+ self.msg = "timeout after %s seconds" % seconds
+
+
+# ===================================================================
+# --- utils
+# ===================================================================
+
+
+def usage_percent(used, total, round_=None):
+ """Calculate percentage usage of 'used' against 'total'."""
+ try:
+ ret = (float(used) / total) * 100
+ except ZeroDivisionError:
+ return 0.0
+ else:
+ if round_ is not None:
+ ret = round(ret, round_)
+ return ret
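+
+
+# Illustrative sketch (editor's addition, not in upstream psutil): what
+# usage_percent() returns for a few sample inputs, assuming the
+# implementation above.
+# usage_percent(50, 200)            # -> 25.0
+# usage_percent(1, 3, round_=2)     # -> 33.33
+# usage_percent(1, 0)               # -> 0.0 (ZeroDivisionError swallowed)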
+
+
+def memoize(fun):
+ """A simple memoize decorator for functions supporting (hashable)
+ positional arguments.
+ It also provides a cache_clear() function for clearing the cache:
+
+ >>> @memoize
+ ... def foo():
+ ... return 1
+ ...
+ >>> foo()
+ 1
+ >>> foo.cache_clear()
+ >>>
+ """
+ @functools.wraps(fun)
+ def wrapper(*args, **kwargs):
+ key = (args, frozenset(sorted(kwargs.items())))
+ try:
+ return cache[key]
+ except KeyError:
+ ret = cache[key] = fun(*args, **kwargs)
+ return ret
+
+ def cache_clear():
+ """Clear cache."""
+ cache.clear()
+
+ cache = {}
+ wrapper.cache_clear = cache_clear
+ return wrapper
+
+
+def memoize_when_activated(fun):
+ """A memoize decorator which is disabled by default. It can be
+ activated and deactivated on request.
+ For efficiency reasons it can be used only against class methods
+ accepting no arguments.
+
+ >>> class Foo:
+ ... @memoize_when_activated
+ ... def foo(self):
+ ... print(1)
+ ...
+ >>> f = Foo()
+ >>> # deactivated (default)
+ >>> f.foo()
+ 1
+ >>> f.foo()
+ 1
+ >>>
+ >>> # activated
+ >>> Foo.foo.cache_activate(f)
+ >>> f.foo()
+ 1
+ >>> f.foo()
+ >>> f.foo()
+ >>>
+ """
+ @functools.wraps(fun)
+ def wrapper(self):
+ try:
+ # case 1: we previously entered oneshot() ctx
+ ret = self._cache[fun]
+ except AttributeError:
+ # case 2: we never entered oneshot() ctx
+ return fun(self)
+ except KeyError:
+ # case 3: we entered oneshot() ctx but there's no cache
+ # for this entry yet
+ ret = fun(self)
+ try:
+ self._cache[fun] = ret
+ except AttributeError:
+ # multi-threading race condition, see:
+ # https://github.com/giampaolo/psutil/issues/1948
+ pass
+ return ret
+
+ def cache_activate(proc):
+ """Activate cache. Expects a Process instance. Cache will be
+ stored as a "_cache" instance attribute."""
+ proc._cache = {}
+
+ def cache_deactivate(proc):
+ """Deactivate and clear cache."""
+ try:
+ del proc._cache
+ except AttributeError:
+ pass
+
+ wrapper.cache_activate = cache_activate
+ wrapper.cache_deactivate = cache_deactivate
+ return wrapper
+
+
+def isfile_strict(path):
+ """Same as os.path.isfile() but does not swallow EACCES / EPERM
+ exceptions, see:
+ http://mail.python.org/pipermail/python-dev/2012-June/120787.html
+ """
+ try:
+ st = os.stat(path)
+ except OSError as err:
+ if err.errno in (errno.EPERM, errno.EACCES):
+ raise
+ return False
+ else:
+ return stat.S_ISREG(st.st_mode)
+
+
+def path_exists_strict(path):
+ """Same as os.path.exists() but does not swallow EACCES / EPERM
+ exceptions, see:
+ http://mail.python.org/pipermail/python-dev/2012-June/120787.html
+ """
+ try:
+ os.stat(path)
+ except OSError as err:
+ if err.errno in (errno.EPERM, errno.EACCES):
+ raise
+ return False
+ else:
+ return True
+
+
+@memoize
+def supports_ipv6():
+ """Return True if IPv6 is supported on this platform."""
+ if not socket.has_ipv6 or AF_INET6 is None:
+ return False
+ try:
+ sock = socket.socket(AF_INET6, socket.SOCK_STREAM)
+ with contextlib.closing(sock):
+ sock.bind(("::1", 0))
+ return True
+ except socket.error:
+ return False
+
+
+def parse_environ_block(data):
+ """Parse a C environ block of environment variables into a dictionary."""
+ # The block is usually raw data from the target process. It might contain
+ # trailing garbage and lines that do not look like assignments.
+ ret = {}
+ pos = 0
+
+ # localize global variable to speed up access.
+ WINDOWS_ = WINDOWS
+ while True:
+ next_pos = data.find("\0", pos)
+ # nul byte at the beginning or double nul byte means finish
+ if next_pos <= pos:
+ break
+ # there might not be an equals sign
+ equal_pos = data.find("=", pos, next_pos)
+ if equal_pos > pos:
+ key = data[pos:equal_pos]
+ value = data[equal_pos + 1:next_pos]
+ # Windows expects environment variables to be uppercase only
+ if WINDOWS_:
+ key = key.upper()
+ ret[key] = value
+ pos = next_pos + 1
+
+ return ret
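+
+
+# Illustrative sketch (editor's addition, not in upstream psutil): a C
+# environ block is a sequence of NUL-terminated "KEY=VALUE" strings
+# ending with a double NUL; keys are uppercased on Windows.
+# parse_environ_block("PATH=/usr/bin\x00HOME=/root\x00\x00")
+# # -> {'PATH': '/usr/bin', 'HOME': '/root'}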
+
+
+def sockfam_to_enum(num):
+ """Convert a numeric socket family value to an IntEnum member.
+ If it's not a known member, return the numeric value itself.
+ """
+ if enum is None:
+ return num
+ else: # pragma: no cover
+ try:
+ return socket.AddressFamily(num)
+ except ValueError:
+ return num
+
+
+def socktype_to_enum(num):
+ """Convert a numeric socket type value to an IntEnum member.
+ If it's not a known member, return the numeric value itself.
+ """
+ if enum is None:
+ return num
+ else: # pragma: no cover
+ try:
+ return socket.SocketKind(num)
+ except ValueError:
+ return num
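+
+
+# Illustrative example (editor's addition, not in upstream psutil):
+# sockfam_to_enum(2) returns socket.AddressFamily.AF_INET (2 == AF_INET)
+# and socktype_to_enum(1) returns socket.SocketKind.SOCK_STREAM; both
+# fall back to the plain integer when the enum module is unavailable.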
+
+
+def conn_to_ntuple(fd, fam, type_, laddr, raddr, status, status_map, pid=None):
+ """Convert a raw connection tuple to a proper ntuple."""
+ if fam in (socket.AF_INET, AF_INET6):
+ if laddr:
+ laddr = addr(*laddr)
+ if raddr:
+ raddr = addr(*raddr)
+ if type_ == socket.SOCK_STREAM and fam in (AF_INET, AF_INET6):
+ status = status_map.get(status, CONN_NONE)
+ else:
+ status = CONN_NONE # ignore whatever C returned to us
+ fam = sockfam_to_enum(fam)
+ type_ = socktype_to_enum(type_)
+ if pid is None:
+ return pconn(fd, fam, type_, laddr, raddr, status)
+ else:
+ return sconn(fd, fam, type_, laddr, raddr, status, pid)
+
+
+def deprecated_method(replacement):
+ """A decorator which can be used to mark a method as deprecated
+ 'replcement' is the method name which will be called instead.
+ """
+ def outer(fun):
+ msg = "%s() is deprecated and will be removed; use %s() instead" % (
+ fun.__name__, replacement)
+ if fun.__doc__ is None:
+ fun.__doc__ = msg
+
+ @functools.wraps(fun)
+ def inner(self, *args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+ return getattr(self, replacement)(*args, **kwargs)
+ return inner
+ return outer
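+
+
+# Illustrative sketch (editor's addition, not in upstream psutil): typical
+# use of deprecated_method() on a hypothetical class exposing an old/new
+# API pair.
+#
+# class Foo(object):
+#     def new_method(self):
+#         return 42
+#
+#     @deprecated_method(replacement="new_method")
+#     def old_method(self):
+#         pass
+#
+# Foo().old_method()   # emits a DeprecationWarning and returns 42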
+
+
+class _WrapNumbers:
+ """Watches numbers so that they don't overflow and wrap
+ (reset to zero).
+ """
+
+ def __init__(self):
+ self.lock = threading.Lock()
+ self.cache = {}
+ self.reminders = {}
+ self.reminder_keys = {}
+
+ def _add_dict(self, input_dict, name):
+ assert name not in self.cache
+ assert name not in self.reminders
+ assert name not in self.reminder_keys
+ self.cache[name] = input_dict
+ self.reminders[name] = collections.defaultdict(int)
+ self.reminder_keys[name] = collections.defaultdict(set)
+
+ def _remove_dead_reminders(self, input_dict, name):
+ """In case the number of keys changed between calls (e.g. a
+ disk disappears) this removes the entry from self.reminders.
+ """
+ old_dict = self.cache[name]
+ gone_keys = set(old_dict.keys()) - set(input_dict.keys())
+ for gone_key in gone_keys:
+ for remkey in self.reminder_keys[name][gone_key]:
+ del self.reminders[name][remkey]
+ del self.reminder_keys[name][gone_key]
+
+ def run(self, input_dict, name):
+ """Cache dict and sum numbers which overflow and wrap.
+ Return an updated copy of `input_dict`
+ """
+ if name not in self.cache:
+ # This was the first call.
+ self._add_dict(input_dict, name)
+ return input_dict
+
+ self._remove_dead_reminders(input_dict, name)
+
+ old_dict = self.cache[name]
+ new_dict = {}
+ for key in input_dict.keys():
+ input_tuple = input_dict[key]
+ try:
+ old_tuple = old_dict[key]
+ except KeyError:
+ # The input dict has a new key (e.g. a new disk or NIC)
+ # which didn't exist in the previous call.
+ new_dict[key] = input_tuple
+ continue
+
+ bits = []
+ for i in range(len(input_tuple)):
+ input_value = input_tuple[i]
+ old_value = old_tuple[i]
+ remkey = (key, i)
+ if input_value < old_value:
+ # it wrapped!
+ self.reminders[name][remkey] += old_value
+ self.reminder_keys[name][key].add(remkey)
+ bits.append(input_value + self.reminders[name][remkey])
+
+ new_dict[key] = tuple(bits)
+
+ self.cache[name] = input_dict
+ return new_dict
+
+ def cache_clear(self, name=None):
+ """Clear the internal cache, optionally only for function 'name'."""
+ with self.lock:
+ if name is None:
+ self.cache.clear()
+ self.reminders.clear()
+ self.reminder_keys.clear()
+ else:
+ self.cache.pop(name, None)
+ self.reminders.pop(name, None)
+ self.reminder_keys.pop(name, None)
+
+ def cache_info(self):
+ """Return internal cache dicts as a tuple of 3 elements."""
+ with self.lock:
+ return (self.cache, self.reminders, self.reminder_keys)
+
+
+def wrap_numbers(input_dict, name):
+ """Given an `input_dict` and a function `name`, adjust the numbers
+ which "wrap" (restart from zero) across different calls by adding
+ "old value" to "new value" and return an updated dict.
+ """
+ with _wn.lock:
+ return _wn.run(input_dict, name)
+
+
+_wn = _WrapNumbers()
+wrap_numbers.cache_clear = _wn.cache_clear
+wrap_numbers.cache_info = _wn.cache_info
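+
+
+# Illustrative sketch (editor's addition, not in upstream psutil): how
+# wrap_numbers() compensates for a counter that wrapped between two calls
+# (the values below are made up).
+# wrap_numbers({'eth0': (100, 200)}, 'net_io_counters')
+# # -> {'eth0': (100, 200)}    first call: cached and returned as-is
+# wrap_numbers({'eth0': (10, 210)}, 'net_io_counters')
+# # -> {'eth0': (110, 210)}    first counter wrapped: 10 + old 100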
+
+
+# The read buffer size for open() builtin. This (also) dictates how
+# much data we read(2) when iterating over file lines as in:
+# >>> with open(file) as f:
+# ... for line in f:
+# ... ...
+# Default per-line buffer size for binary files is 1K; for text files
+# it is 8K. We use a bigger buffer (32K) in order to have more consistent
+# results when reading /proc pseudo files on Linux, see:
+# https://github.com/giampaolo/psutil/issues/2050
+# On Python 2 this also speeds up the reading of big files:
+# (namely /proc/{pid}/smaps and /proc/net/*):
+# https://github.com/giampaolo/psutil/issues/708
+FILE_READ_BUFFER_SIZE = 32 * 1024
+
+
+def open_binary(fname):
+ return open(fname, "rb", buffering=FILE_READ_BUFFER_SIZE)
+
+
+def open_text(fname):
+ """On Python 3 opens a file in text mode by using fs encoding and
+ a proper en/decoding errors handler.
+ On Python 2 this is just an alias for open(name, 'rt').
+ """
+ if not PY3:
+ return open(fname, "rt", buffering=FILE_READ_BUFFER_SIZE)
+
+ # See:
+ # https://github.com/giampaolo/psutil/issues/675
+ # https://github.com/giampaolo/psutil/pull/733
+ fobj = open(fname, "rt", buffering=FILE_READ_BUFFER_SIZE,
+ encoding=ENCODING, errors=ENCODING_ERRS)
+ try:
+ # Dictates per-line read(2) buffer size. Default is 8k. See:
+ # https://github.com/giampaolo/psutil/issues/2050#issuecomment-1013387546
+ fobj._CHUNK_SIZE = FILE_READ_BUFFER_SIZE
+ except AttributeError:
+ pass
+ except Exception:
+ fobj.close()
+ raise
+
+ return fobj
+
+
+def cat(fname, fallback=_DEFAULT, _open=open_text):
+ """Read entire file content and return it as a string. File is
+ opened in text mode. If specified, `fallback` is the value
+ returned in case of error, either if the file does not exist or
+ it can't be read().
+ """
+ if fallback is _DEFAULT:
+ with _open(fname) as f:
+ return f.read()
+ else:
+ try:
+ with _open(fname) as f:
+ return f.read()
+ except (IOError, OSError):
+ return fallback
+
+
+def bcat(fname, fallback=_DEFAULT):
+ """Same as above but opens file in binary mode."""
+ return cat(fname, fallback=fallback, _open=open_binary)
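+
+
+# Illustrative example (editor's addition, not in upstream psutil; the
+# path is hypothetical):
+# cat('/proc/version', fallback='')    # '' if the file can't be read
+# bcat('/proc/version', fallback=b'')  # same, but returns bytes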
+
+
+def bytes2human(n, format="%(value).1f%(symbol)s"):
+ """Used by various scripts. See:
+ http://goo.gl/zeJZl
+
+ >>> bytes2human(10000)
+ '9.8K'
+ >>> bytes2human(100001221)
+ '95.4M'
+ """
+ symbols = ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y')
+ prefix = {}
+ for i, s in enumerate(symbols[1:]):
+ prefix[s] = 1 << (i + 1) * 10
+ for symbol in reversed(symbols[1:]):
+ if n >= prefix[symbol]:
+ value = float(n) / prefix[symbol]
+ return format % locals()
+ return format % dict(symbol=symbols[0], value=n)
+
+
+def get_procfs_path():
+ """Return updated psutil.PROCFS_PATH constant."""
+ return sys.modules['psutil'].PROCFS_PATH
+
+
+if PY3:
+ def decode(s):
+ return s.decode(encoding=ENCODING, errors=ENCODING_ERRS)
+else:
+ def decode(s):
+ return s
+
+
+# =====================================================================
+# --- shell utils
+# =====================================================================
+
+
+@memoize
+def term_supports_colors(file=sys.stdout): # pragma: no cover
+ if os.name == 'nt':
+ return True
+ try:
+ import curses
+ assert file.isatty()
+ curses.setupterm()
+ assert curses.tigetnum("colors") > 0
+ except Exception:
+ return False
+ else:
+ return True
+
+
+def hilite(s, color=None, bold=False): # pragma: no cover
+ """Return an highlighted version of 'string'."""
+ if not term_supports_colors():
+ return s
+ attr = []
+ colors = dict(green='32', red='91', brown='33', yellow='93', blue='34',
+ violet='35', lightblue='36', grey='37', darkgrey='30')
+ colors[None] = '29'
+ try:
+ color = colors[color]
+ except KeyError:
+ raise ValueError("invalid color %r; choose between %s" % (
+ list(colors.keys())))
+ attr.append(color)
+ if bold:
+ attr.append('1')
+ return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), s)
+
+
+def print_color(
+ s, color=None, bold=False, file=sys.stdout): # pragma: no cover
+ """Print a colorized version of string."""
+ if not term_supports_colors():
+ print(s, file=file) # NOQA
+ elif POSIX:
+ print(hilite(s, color, bold), file=file) # NOQA
+ else:
+ import ctypes
+
+ DEFAULT_COLOR = 7
+ GetStdHandle = ctypes.windll.Kernel32.GetStdHandle
+ SetConsoleTextAttribute = \
+ ctypes.windll.Kernel32.SetConsoleTextAttribute
+
+ colors = dict(green=2, red=4, brown=6, yellow=6)
+ colors[None] = DEFAULT_COLOR
+ try:
+ color = colors[color]
+ except KeyError:
+ raise ValueError("invalid color %r; choose between %r" % (
+ color, list(colors.keys())))
+ if bold and color <= 7:
+ color += 8
+
+ handle_id = -12 if file is sys.stderr else -11
+ GetStdHandle.restype = ctypes.c_ulong
+ handle = GetStdHandle(handle_id)
+ SetConsoleTextAttribute(handle, color)
+ try:
+ print(s, file=file) # NOQA
+ finally:
+ SetConsoleTextAttribute(handle, DEFAULT_COLOR)
+
+
+def debug(msg):
+ """If PSUTIL_DEBUG env var is set, print a debug message to stderr."""
+ if PSUTIL_DEBUG:
+ import inspect
+ fname, lineno, func_name, lines, index = inspect.getframeinfo(
+ inspect.currentframe().f_back)
+ if isinstance(msg, Exception):
+ if isinstance(msg, (OSError, IOError, EnvironmentError)):
+ # ...because str(exc) may contain info about the file name
+ msg = "ignoring %s" % msg
+ else:
+ msg = "ignoring %r" % msg
+ print("psutil-debug [%s:%s]> %s" % (fname, lineno, msg), # NOQA
+ file=sys.stderr)
diff --git a/lib/psutil/_compat.py b/lib/psutil/_compat.py
new file mode 100644
index 0000000..52e762b
--- /dev/null
+++ b/lib/psutil/_compat.py
@@ -0,0 +1,450 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Module which provides compatibility with older Python versions.
+It is biased toward forward compatibility rather than the opposite
+(the latest Python 3 way of doing things is preferred).
+"""
+
+import collections
+import contextlib
+import errno
+import functools
+import os
+import sys
+import types
+
+
+__all__ = [
+ # constants
+ "PY3",
+ # builtins
+ "long", "range", "super", "unicode", "basestring",
+ # literals
+ "u", "b",
+ # collections module
+ "lru_cache",
+ # shutil module
+ "which", "get_terminal_size",
+ # contextlib module
+ "redirect_stderr",
+ # python 3 exceptions
+ "FileNotFoundError", "PermissionError", "ProcessLookupError",
+ "InterruptedError", "ChildProcessError", "FileExistsError"]
+
+
+PY3 = sys.version_info[0] == 3
+_SENTINEL = object()
+
+if PY3:
+ long = int
+ xrange = range
+ unicode = str
+ basestring = str
+ range = range
+
+ def u(s):
+ return s
+
+ def b(s):
+ return s.encode("latin-1")
+else:
+ long = long
+ range = xrange
+ unicode = unicode
+ basestring = basestring
+
+ def u(s):
+ return unicode(s, "unicode_escape")
+
+ def b(s):
+ return s
+
+
+# --- builtins
+
+
+# Python 3 super().
+# Taken from "future" package.
+# Credit: Ryan Kelly
+if PY3:
+ super = super
+else:
+ _builtin_super = super
+
+ def super(type_=_SENTINEL, type_or_obj=_SENTINEL, framedepth=1):
+ """Like Python 3 builtin super(). If called without any arguments
+ it attempts to infer them at runtime.
+ """
+ if type_ is _SENTINEL:
+ f = sys._getframe(framedepth)
+ try:
+ # Get the function's first positional argument.
+ type_or_obj = f.f_locals[f.f_code.co_varnames[0]]
+ except (IndexError, KeyError):
+ raise RuntimeError('super() used in a function with no args')
+ try:
+ # Get the MRO so we can crawl it.
+ mro = type_or_obj.__mro__
+ except (AttributeError, RuntimeError):
+ try:
+ mro = type_or_obj.__class__.__mro__
+ except AttributeError:
+ raise RuntimeError('super() used in a non-newstyle class')
+ for type_ in mro:
+ # Find the class that owns the currently-executing method.
+ for meth in type_.__dict__.values():
+ # Drill down through any wrappers to the underlying func.
+ # This handles e.g. classmethod() and staticmethod().
+ try:
+ while not isinstance(meth, types.FunctionType):
+ if isinstance(meth, property):
+ # Calling __get__ on the property will invoke
+ # user code which might throw exceptions or
+ # have side effects
+ meth = meth.fget
+ else:
+ try:
+ meth = meth.__func__
+ except AttributeError:
+ meth = meth.__get__(type_or_obj, type_)
+ except (AttributeError, TypeError):
+ continue
+ if meth.func_code is f.f_code:
+ break # found
+ else:
+ # Not found. Move onto the next class in MRO.
+ continue
+ break # found
+ else:
+ raise RuntimeError('super() called outside a method')
+
+ # Dispatch to builtin super().
+ if type_or_obj is not _SENTINEL:
+ return _builtin_super(type_, type_or_obj)
+ return _builtin_super(type_)
+
+
+# --- exceptions
+
+
+if PY3:
+ FileNotFoundError = FileNotFoundError # NOQA
+ PermissionError = PermissionError # NOQA
+ ProcessLookupError = ProcessLookupError # NOQA
+ InterruptedError = InterruptedError # NOQA
+ ChildProcessError = ChildProcessError # NOQA
+ FileExistsError = FileExistsError # NOQA
+else:
+ # https://github.com/PythonCharmers/python-future/blob/exceptions/
+ # src/future/types/exceptions/pep3151.py
+ import platform
+
+ def _instance_checking_exception(base_exception=Exception):
+ def wrapped(instance_checker):
+ class TemporaryClass(base_exception):
+
+ def __init__(self, *args, **kwargs):
+ if len(args) == 1 and isinstance(args[0], TemporaryClass):
+ unwrap_me = args[0]
+ for attr in dir(unwrap_me):
+ if not attr.startswith('__'):
+ setattr(self, attr, getattr(unwrap_me, attr))
+ else:
+ super(TemporaryClass, self).__init__(*args, **kwargs)
+
+ class __metaclass__(type):
+ def __instancecheck__(cls, inst):
+ return instance_checker(inst)
+
+ def __subclasscheck__(cls, classinfo):
+ value = sys.exc_info()[1]
+ return isinstance(value, cls)
+
+ TemporaryClass.__name__ = instance_checker.__name__
+ TemporaryClass.__doc__ = instance_checker.__doc__
+ return TemporaryClass
+
+ return wrapped
+
+ @_instance_checking_exception(EnvironmentError)
+ def FileNotFoundError(inst):
+ return getattr(inst, 'errno', _SENTINEL) == errno.ENOENT
+
+ @_instance_checking_exception(EnvironmentError)
+ def ProcessLookupError(inst):
+ return getattr(inst, 'errno', _SENTINEL) == errno.ESRCH
+
+ @_instance_checking_exception(EnvironmentError)
+ def PermissionError(inst):
+ return getattr(inst, 'errno', _SENTINEL) in (
+ errno.EACCES, errno.EPERM)
+
+ @_instance_checking_exception(EnvironmentError)
+ def InterruptedError(inst):
+ return getattr(inst, 'errno', _SENTINEL) == errno.EINTR
+
+ @_instance_checking_exception(EnvironmentError)
+ def ChildProcessError(inst):
+ return getattr(inst, 'errno', _SENTINEL) == errno.ECHILD
+
+ @_instance_checking_exception(EnvironmentError)
+ def FileExistsError(inst):
+ return getattr(inst, 'errno', _SENTINEL) == errno.EEXIST
+
+ if platform.python_implementation() != "CPython":
+ try:
+ raise OSError(errno.EEXIST, "perm")
+ except FileExistsError:
+ pass
+ except OSError:
+ raise RuntimeError(
+ "broken or incompatible Python implementation, see: "
+ "https://github.com/giampaolo/psutil/issues/1659")
+
+
+# --- stdlib additions
+
+
+# py 3.2 functools.lru_cache
+# Taken from: http://code.activestate.com/recipes/578078
+# Credit: Raymond Hettinger
+try:
+ from functools import lru_cache
+except ImportError:
+ try:
+ from threading import RLock
+ except ImportError:
+ from dummy_threading import RLock
+
+ _CacheInfo = collections.namedtuple(
+ "CacheInfo", ["hits", "misses", "maxsize", "currsize"])
+
+ class _HashedSeq(list):
+ __slots__ = 'hashvalue'
+
+ def __init__(self, tup, hash=hash):
+ self[:] = tup
+ self.hashvalue = hash(tup)
+
+ def __hash__(self):
+ return self.hashvalue
+
+ def _make_key(args, kwds, typed,
+ kwd_mark=(_SENTINEL, ),
+ fasttypes=set((int, str, frozenset, type(None))), # noqa
+ sorted=sorted, tuple=tuple, type=type, len=len):
+ key = args
+ if kwds:
+ sorted_items = sorted(kwds.items())
+ key += kwd_mark
+ for item in sorted_items:
+ key += item
+ if typed:
+ key += tuple(type(v) for v in args)
+ if kwds:
+ key += tuple(type(v) for k, v in sorted_items)
+ elif len(key) == 1 and type(key[0]) in fasttypes:
+ return key[0]
+ return _HashedSeq(key)
+
+ def lru_cache(maxsize=100, typed=False):
+ """Least-recently-used cache decorator, see:
+ http://docs.python.org/3/library/functools.html#functools.lru_cache
+ """
+ def decorating_function(user_function):
+ cache = dict()
+ stats = [0, 0]
+ HITS, MISSES = 0, 1
+ make_key = _make_key
+ cache_get = cache.get
+ _len = len
+ lock = RLock()
+ root = []
+ root[:] = [root, root, None, None]
+ nonlocal_root = [root]
+ PREV, NEXT, KEY, RESULT = 0, 1, 2, 3
+ if maxsize == 0:
+ def wrapper(*args, **kwds):
+ result = user_function(*args, **kwds)
+ stats[MISSES] += 1
+ return result
+ elif maxsize is None:
+ def wrapper(*args, **kwds):
+ key = make_key(args, kwds, typed)
+ result = cache_get(key, root)
+ if result is not root:
+ stats[HITS] += 1
+ return result
+ result = user_function(*args, **kwds)
+ cache[key] = result
+ stats[MISSES] += 1
+ return result
+ else:
+ def wrapper(*args, **kwds):
+ if kwds or typed:
+ key = make_key(args, kwds, typed)
+ else:
+ key = args
+ lock.acquire()
+ try:
+ link = cache_get(key)
+ if link is not None:
+ root, = nonlocal_root
+ link_prev, link_next, key, result = link
+ link_prev[NEXT] = link_next
+ link_next[PREV] = link_prev
+ last = root[PREV]
+ last[NEXT] = root[PREV] = link
+ link[PREV] = last
+ link[NEXT] = root
+ stats[HITS] += 1
+ return result
+ finally:
+ lock.release()
+ result = user_function(*args, **kwds)
+ lock.acquire()
+ try:
+ root, = nonlocal_root
+ if key in cache:
+ pass
+ elif _len(cache) >= maxsize:
+ oldroot = root
+ oldroot[KEY] = key
+ oldroot[RESULT] = result
+ root = nonlocal_root[0] = oldroot[NEXT]
+ oldkey = root[KEY]
+ root[KEY] = root[RESULT] = None
+ del cache[oldkey]
+ cache[key] = oldroot
+ else:
+ last = root[PREV]
+ link = [last, root, key, result]
+ last[NEXT] = root[PREV] = cache[key] = link
+ stats[MISSES] += 1
+ finally:
+ lock.release()
+ return result
+
+ def cache_info():
+ """Report cache statistics"""
+ lock.acquire()
+ try:
+ return _CacheInfo(stats[HITS], stats[MISSES], maxsize,
+ len(cache))
+ finally:
+ lock.release()
+
+ def cache_clear():
+ """Clear the cache and cache statistics"""
+ lock.acquire()
+ try:
+ cache.clear()
+ root = nonlocal_root[0]
+ root[:] = [root, root, None, None]
+ stats[:] = [0, 0]
+ finally:
+ lock.release()
+
+ wrapper.__wrapped__ = user_function
+ wrapper.cache_info = cache_info
+ wrapper.cache_clear = cache_clear
+ return functools.update_wrapper(wrapper, user_function)
+
+ return decorating_function
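+
+
+# Illustrative usage (editor's addition, not in upstream psutil): the
+# lru_cache backport above mirrors functools.lru_cache, e.g.:
+#
+# @lru_cache(maxsize=32)
+# def fib(n):
+#     return n if n < 2 else fib(n - 1) + fib(n - 2)
+#
+# fib(10)            # -> 55
+# fib.cache_info()   # -> CacheInfo(hits=..., misses=..., maxsize=32, ...)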
+
+
+# python 3.3
+try:
+ from shutil import which
+except ImportError:
+ def which(cmd, mode=os.F_OK | os.X_OK, path=None):
+ """Given a command, mode, and a PATH string, return the path which
+ conforms to the given mode on the PATH, or None if there is no such
+ file.
+
+ `mode` defaults to os.F_OK | os.X_OK. `path` defaults to the result
+ of os.environ.get("PATH"), or can be overridden with a custom search
+ path.
+ """
+ def _access_check(fn, mode):
+ return (os.path.exists(fn) and os.access(fn, mode) and
+ not os.path.isdir(fn))
+
+ if os.path.dirname(cmd):
+ if _access_check(cmd, mode):
+ return cmd
+ return None
+
+ if path is None:
+ path = os.environ.get("PATH", os.defpath)
+ if not path:
+ return None
+ path = path.split(os.pathsep)
+
+ if sys.platform == "win32":
+ if os.curdir not in path:
+ path.insert(0, os.curdir)
+
+ pathext = os.environ.get("PATHEXT", "").split(os.pathsep)
+ if any(cmd.lower().endswith(ext.lower()) for ext in pathext):
+ files = [cmd]
+ else:
+ files = [cmd + ext for ext in pathext]
+ else:
+ files = [cmd]
+
+ seen = set()
+ for dir in path:
+ normdir = os.path.normcase(dir)
+ if normdir not in seen:
+ seen.add(normdir)
+ for thefile in files:
+ name = os.path.join(dir, thefile)
+ if _access_check(name, mode):
+ return name
+ return None
+
+
+# python 3.3
+try:
+ from shutil import get_terminal_size
+except ImportError:
+ def get_terminal_size(fallback=(80, 24)):
+ try:
+ import fcntl
+ import struct
+ import termios
+ except ImportError:
+ return fallback
+ else:
+ try:
+ # This should work on Linux.
+ res = struct.unpack(
+ 'hh', fcntl.ioctl(1, termios.TIOCGWINSZ, '1234'))
+ return (res[1], res[0])
+ except Exception:
+ return fallback
+
+
+# python 3.3
+try:
+ from subprocess import TimeoutExpired as SubprocessTimeoutExpired
+except ImportError:
+ class SubprocessTimeoutExpired:
+ pass
+
+
+# python 3.5
+try:
+ from contextlib import redirect_stderr
+except ImportError:
+ @contextlib.contextmanager
+ def redirect_stderr(new_target):
+ original = sys.stderr
+ try:
+ sys.stderr = new_target
+ yield new_target
+ finally:
+ sys.stderr = original
diff --git a/lib/psutil/_psaix.py b/lib/psutil/_psaix.py
new file mode 100644
index 0000000..2391478
--- /dev/null
+++ b/lib/psutil/_psaix.py
@@ -0,0 +1,555 @@
+# Copyright (c) 2009, Giampaolo Rodola'
+# Copyright (c) 2017, Arnon Yaari
+# All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""AIX platform implementation."""
+
+import functools
+import glob
+import os
+import re
+import subprocess
+import sys
+from collections import namedtuple
+
+from . import _common
+from . import _psposix
+from . import _psutil_aix as cext
+from . import _psutil_posix as cext_posix
+from ._common import NIC_DUPLEX_FULL
+from ._common import NIC_DUPLEX_HALF
+from ._common import NIC_DUPLEX_UNKNOWN
+from ._common import AccessDenied
+from ._common import NoSuchProcess
+from ._common import ZombieProcess
+from ._common import conn_to_ntuple
+from ._common import get_procfs_path
+from ._common import memoize_when_activated
+from ._common import usage_percent
+from ._compat import PY3
+from ._compat import FileNotFoundError
+from ._compat import PermissionError
+from ._compat import ProcessLookupError
+
+
+__extra__all__ = ["PROCFS_PATH"]
+
+
+# =====================================================================
+# --- globals
+# =====================================================================
+
+
+HAS_THREADS = hasattr(cext, "proc_threads")
+HAS_NET_IO_COUNTERS = hasattr(cext, "net_io_counters")
+HAS_PROC_IO_COUNTERS = hasattr(cext, "proc_io_counters")
+
+PAGE_SIZE = cext_posix.getpagesize()
+AF_LINK = cext_posix.AF_LINK
+
+PROC_STATUSES = {
+ cext.SIDL: _common.STATUS_IDLE,
+ cext.SZOMB: _common.STATUS_ZOMBIE,
+ cext.SACTIVE: _common.STATUS_RUNNING,
+ cext.SSWAP: _common.STATUS_RUNNING, # TODO what status is this?
+ cext.SSTOP: _common.STATUS_STOPPED,
+}
+
+TCP_STATUSES = {
+ cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
+ cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
+ cext.TCPS_SYN_RCVD: _common.CONN_SYN_RECV,
+ cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
+ cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
+ cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
+ cext.TCPS_CLOSED: _common.CONN_CLOSE,
+ cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
+ cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
+ cext.TCPS_LISTEN: _common.CONN_LISTEN,
+ cext.TCPS_CLOSING: _common.CONN_CLOSING,
+ cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
+}
+
+proc_info_map = dict(
+ ppid=0,
+ rss=1,
+ vms=2,
+ create_time=3,
+ nice=4,
+ num_threads=5,
+ status=6,
+ ttynr=7)
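+
+
+# Illustrative note (editor's addition, not in upstream psutil):
+# cext.proc_basic_info() returns a flat tuple; proc_info_map maps field
+# names to tuple indexes, e.g. ret[proc_info_map['rss']] is the RSS value
+# used by memory_info() further below.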
+
+
+# =====================================================================
+# --- named tuples
+# =====================================================================
+
+
+# psutil.Process.memory_info()
+pmem = namedtuple('pmem', ['rss', 'vms'])
+# psutil.Process.memory_full_info()
+pfullmem = pmem
+# psutil.Process.cpu_times()
+scputimes = namedtuple('scputimes', ['user', 'system', 'idle', 'iowait'])
+# psutil.virtual_memory()
+svmem = namedtuple('svmem', ['total', 'available', 'percent', 'used', 'free'])
+
+
+# =====================================================================
+# --- memory
+# =====================================================================
+
+
+def virtual_memory():
+ total, avail, free, pinned, inuse = cext.virtual_mem()
+ percent = usage_percent((total - avail), total, round_=1)
+ return svmem(total, avail, percent, inuse, free)
+
+
+def swap_memory():
+ """Swap system memory as a (total, used, free, sin, sout) tuple."""
+ total, free, sin, sout = cext.swap_mem()
+ used = total - free
+ percent = usage_percent(used, total, round_=1)
+ return _common.sswap(total, used, free, percent, sin, sout)
+
+
+# =====================================================================
+# --- CPU
+# =====================================================================
+
+
+def cpu_times():
+ """Return system-wide CPU times as a named tuple"""
+ ret = cext.per_cpu_times()
+ return scputimes(*[sum(x) for x in zip(*ret)])
+
+
+def per_cpu_times():
+ """Return system per-CPU times as a list of named tuples"""
+ ret = cext.per_cpu_times()
+ return [scputimes(*x) for x in ret]
+
+
+def cpu_count_logical():
+ """Return the number of logical CPUs in the system."""
+ try:
+ return os.sysconf("SC_NPROCESSORS_ONLN")
+ except ValueError:
+ # mimic os.cpu_count() behavior
+ return None
+
+
+def cpu_count_cores():
+ cmd = "lsdev -Cc processor"
+ p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ stdout, stderr = p.communicate()
+ if PY3:
+ stdout, stderr = [x.decode(sys.stdout.encoding)
+ for x in (stdout, stderr)]
+ if p.returncode != 0:
+ raise RuntimeError("%r command error\n%s" % (cmd, stderr))
+ processors = stdout.strip().splitlines()
+ return len(processors) or None
+
+
+def cpu_stats():
+ """Return various CPU stats as a named tuple."""
+ ctx_switches, interrupts, soft_interrupts, syscalls = cext.cpu_stats()
+ return _common.scpustats(
+ ctx_switches, interrupts, soft_interrupts, syscalls)
+
+
+# =====================================================================
+# --- disks
+# =====================================================================
+
+
+disk_io_counters = cext.disk_io_counters
+disk_usage = _psposix.disk_usage
+
+
+def disk_partitions(all=False):
+ """Return system disk partitions."""
+ # TODO - the filtering logic should be better checked so that
+ # it tries to reflect 'df' as much as possible
+ retlist = []
+ partitions = cext.disk_partitions()
+ for partition in partitions:
+ device, mountpoint, fstype, opts = partition
+ if device == 'none':
+ device = ''
+ if not all:
+ # Unlike, say, Linux, we don't have a list of common
+ # fs types, so the best we can do, AFAIK, is to filter
+ # by filesystems having a total size > 0.
+ if not disk_usage(mountpoint).total:
+ continue
+ maxfile = maxpath = None # set later
+ ntuple = _common.sdiskpart(device, mountpoint, fstype, opts,
+ maxfile, maxpath)
+ retlist.append(ntuple)
+ return retlist
+
+
+# =====================================================================
+# --- network
+# =====================================================================
+
+
+net_if_addrs = cext_posix.net_if_addrs
+
+if HAS_NET_IO_COUNTERS:
+ net_io_counters = cext.net_io_counters
+
+
+def net_connections(kind, _pid=-1):
+ """Return socket connections. If pid == -1 return system-wide
+ connections (as opposed to connections opened by one process only).
+ """
+ cmap = _common.conn_tmap
+ if kind not in cmap:
+ raise ValueError("invalid %r kind argument; choose between %s"
+ % (kind, ', '.join([repr(x) for x in cmap])))
+ families, types = _common.conn_tmap[kind]
+ rawlist = cext.net_connections(_pid)
+ ret = []
+ for item in rawlist:
+ fd, fam, type_, laddr, raddr, status, pid = item
+ if fam not in families:
+ continue
+ if type_ not in types:
+ continue
+ nt = conn_to_ntuple(fd, fam, type_, laddr, raddr, status,
+ TCP_STATUSES, pid=pid if _pid == -1 else None)
+ ret.append(nt)
+ return ret
+
+
+def net_if_stats():
+ """Get NIC stats (isup, duplex, speed, mtu)."""
+ duplex_map = {"Full": NIC_DUPLEX_FULL,
+ "Half": NIC_DUPLEX_HALF}
+ names = set([x[0] for x in net_if_addrs()])
+ ret = {}
+ for name in names:
+ mtu = cext_posix.net_if_mtu(name)
+ flags = cext_posix.net_if_flags(name)
+
+ # try to get speed and duplex
+ # TODO: rewrite this in C (entstat forks, so use truss -f to follow.
+ # looks like it is using an undocumented ioctl?)
+ duplex = ""
+ speed = 0
+ p = subprocess.Popen(["/usr/bin/entstat", "-d", name],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = p.communicate()
+ if PY3:
+ stdout, stderr = [x.decode(sys.stdout.encoding)
+ for x in (stdout, stderr)]
+ if p.returncode == 0:
+ re_result = re.search(
+ r"Running: (\d+) Mbps.*?(\w+) Duplex", stdout)
+ if re_result is not None:
+ speed = int(re_result.group(1))
+ duplex = re_result.group(2)
+
+ output_flags = ','.join(flags)
+ isup = 'running' in flags
+ duplex = duplex_map.get(duplex, NIC_DUPLEX_UNKNOWN)
+ ret[name] = _common.snicstats(isup, duplex, speed, mtu, output_flags)
+ return ret
+
+
+# =====================================================================
+# --- other system functions
+# =====================================================================
+
+
+def boot_time():
+ """The system boot time expressed in seconds since the epoch."""
+ return cext.boot_time()
+
+
+def users():
+ """Return currently connected users as a list of namedtuples."""
+ retlist = []
+ rawlist = cext.users()
+ localhost = (':0.0', ':0')
+ for item in rawlist:
+ user, tty, hostname, tstamp, user_process, pid = item
+ # note: the underlying C function includes entries about
+ # system boot, run level and others. We might want
+ # to use them in the future.
+ if not user_process:
+ continue
+ if hostname in localhost:
+ hostname = 'localhost'
+ nt = _common.suser(user, tty, hostname, tstamp, pid)
+ retlist.append(nt)
+ return retlist
+
+
+# =====================================================================
+# --- processes
+# =====================================================================
+
+
+def pids():
+ """Returns a list of PIDs currently running on the system."""
+ return [int(x) for x in os.listdir(get_procfs_path()) if x.isdigit()]
+
+
+def pid_exists(pid):
+ """Check for the existence of a unix pid."""
+ return os.path.exists(os.path.join(get_procfs_path(), str(pid), "psinfo"))
+
+
+def wrap_exceptions(fun):
+ """Call callable into a try/except clause and translate ENOENT,
+ EACCES and EPERM in NoSuchProcess or AccessDenied exceptions.
+ """
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return fun(self, *args, **kwargs)
+ except (FileNotFoundError, ProcessLookupError):
+ # ENOENT (no such file or directory) gets raised on open().
+ # ESRCH (no such process) can get raised on read() if
+ # process is gone in meantime.
+ if not pid_exists(self.pid):
+ raise NoSuchProcess(self.pid, self._name)
+ else:
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ except PermissionError:
+ raise AccessDenied(self.pid, self._name)
+ return wrapper
+
+
+class Process(object):
+ """Wrapper class around underlying C implementation."""
+
+ __slots__ = ["pid", "_name", "_ppid", "_procfs_path", "_cache"]
+
+ def __init__(self, pid):
+ self.pid = pid
+ self._name = None
+ self._ppid = None
+ self._procfs_path = get_procfs_path()
+
+ def oneshot_enter(self):
+ self._proc_basic_info.cache_activate(self)
+ self._proc_cred.cache_activate(self)
+
+ def oneshot_exit(self):
+ self._proc_basic_info.cache_deactivate(self)
+ self._proc_cred.cache_deactivate(self)
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _proc_basic_info(self):
+ return cext.proc_basic_info(self.pid, self._procfs_path)
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _proc_cred(self):
+ return cext.proc_cred(self.pid, self._procfs_path)
+
+ @wrap_exceptions
+ def name(self):
+ if self.pid == 0:
+ return "swapper"
+ # note: max 16 characters
+ return cext.proc_name(self.pid, self._procfs_path).rstrip("\x00")
+
+ @wrap_exceptions
+ def exe(self):
+ # there is no way to get executable path in AIX other than to guess,
+ # and guessing is more complex than what's in the wrapping class
+ cmdline = self.cmdline()
+ if not cmdline:
+ return ''
+ exe = cmdline[0]
+ if os.path.sep in exe:
+ # relative or absolute path
+ if not os.path.isabs(exe):
+ # if cwd has changed, we're out of luck - this may be wrong!
+ exe = os.path.abspath(os.path.join(self.cwd(), exe))
+ if (os.path.isabs(exe) and
+ os.path.isfile(exe) and
+ os.access(exe, os.X_OK)):
+ return exe
+ # not found, move to search in PATH using basename only
+ exe = os.path.basename(exe)
+ # search for exe name in PATH
+ for path in os.environ["PATH"].split(":"):
+ possible_exe = os.path.abspath(os.path.join(path, exe))
+ if (os.path.isfile(possible_exe) and
+ os.access(possible_exe, os.X_OK)):
+ return possible_exe
+ return ''
+
+ @wrap_exceptions
+ def cmdline(self):
+ return cext.proc_args(self.pid)
+
+ @wrap_exceptions
+ def environ(self):
+ return cext.proc_environ(self.pid)
+
+ @wrap_exceptions
+ def create_time(self):
+ return self._proc_basic_info()[proc_info_map['create_time']]
+
+ @wrap_exceptions
+ def num_threads(self):
+ return self._proc_basic_info()[proc_info_map['num_threads']]
+
+ if HAS_THREADS:
+ @wrap_exceptions
+ def threads(self):
+ rawlist = cext.proc_threads(self.pid)
+ retlist = []
+ for thread_id, utime, stime in rawlist:
+ ntuple = _common.pthread(thread_id, utime, stime)
+ retlist.append(ntuple)
+ # The underlying C implementation retrieves all OS threads
+ # and filters them by PID. At this point we can't tell whether
+ # an empty list means there were no threads for the process or
+ # the process is no longer active, so we force NSP in case the
+ # PID is no longer there.
+ if not retlist:
+ # will raise NSP if process is gone
+ os.stat('%s/%s' % (self._procfs_path, self.pid))
+ return retlist
+
+ @wrap_exceptions
+ def connections(self, kind='inet'):
+ ret = net_connections(kind, _pid=self.pid)
+ # The underlying C implementation retrieves all OS connections
+ # and filters them by PID. At this point we can't tell whether
+ # an empty list means there were no connections for process or
+ # process is no longer active so we force NSP in case the PID
+ # is no longer there.
+ if not ret:
+ # will raise NSP if process is gone
+ os.stat('%s/%s' % (self._procfs_path, self.pid))
+ return ret
+
+ @wrap_exceptions
+ def nice_get(self):
+ return cext_posix.getpriority(self.pid)
+
+ @wrap_exceptions
+ def nice_set(self, value):
+ return cext_posix.setpriority(self.pid, value)
+
+ @wrap_exceptions
+ def ppid(self):
+ self._ppid = self._proc_basic_info()[proc_info_map['ppid']]
+ return self._ppid
+
+ @wrap_exceptions
+ def uids(self):
+ real, effective, saved, _, _, _ = self._proc_cred()
+ return _common.puids(real, effective, saved)
+
+ @wrap_exceptions
+ def gids(self):
+ _, _, _, real, effective, saved = self._proc_cred()
+ return _common.puids(real, effective, saved)
+
+ @wrap_exceptions
+ def cpu_times(self):
+ cpu_times = cext.proc_cpu_times(self.pid, self._procfs_path)
+ return _common.pcputimes(*cpu_times)
+
+ @wrap_exceptions
+ def terminal(self):
+ ttydev = self._proc_basic_info()[proc_info_map['ttynr']]
+ # convert from 64-bit dev_t to 32-bit dev_t and then map the device
+ ttydev = (((ttydev & 0x0000FFFF00000000) >> 16) | (ttydev & 0xFFFF))
+ # try to match the rdev of /dev/**/* files against ttydev
+ for dev in glob.glob("/dev/**/*"):
+ if os.stat(dev).st_rdev == ttydev:
+ return dev
+ return None
+
+ @wrap_exceptions
+ def cwd(self):
+ procfs_path = self._procfs_path
+ try:
+ result = os.readlink("%s/%s/cwd" % (procfs_path, self.pid))
+ return result.rstrip('/')
+ except FileNotFoundError:
+ os.stat("%s/%s" % (procfs_path, self.pid)) # raise NSP or AD
+ return None
+
+ @wrap_exceptions
+ def memory_info(self):
+ ret = self._proc_basic_info()
+ rss = ret[proc_info_map['rss']] * 1024
+ vms = ret[proc_info_map['vms']] * 1024
+ return pmem(rss, vms)
+
+ memory_full_info = memory_info
+
+ @wrap_exceptions
+ def status(self):
+ code = self._proc_basic_info()[proc_info_map['status']]
+ # XXX is '?' legit? (we're not supposed to return it anyway)
+ return PROC_STATUSES.get(code, '?')
+
+ def open_files(self):
+ # TODO rewrite without using procfiles (stat /proc/pid/fd/* and then
+ # find matching name of the inode)
+ p = subprocess.Popen(["/usr/bin/procfiles", "-n", str(self.pid)],
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = p.communicate()
+ if PY3:
+ stdout, stderr = [x.decode(sys.stdout.encoding)
+ for x in (stdout, stderr)]
+ if "no such process" in stderr.lower():
+ raise NoSuchProcess(self.pid, self._name)
+ procfiles = re.findall(r"(\d+): S_IFREG.*\s*.*name:(.*)\n", stdout)
+ retlist = []
+ for fd, path in procfiles:
+ path = path.strip()
+ if path.startswith("//"):
+ path = path[1:]
+ if path.lower() == "cannot be retrieved":
+ continue
+ retlist.append(_common.popenfile(path, int(fd)))
+ return retlist
+
+ @wrap_exceptions
+ def num_fds(self):
+ if self.pid == 0: # no /proc/0/fd
+ return 0
+ return len(os.listdir("%s/%s/fd" % (self._procfs_path, self.pid)))
+
+ @wrap_exceptions
+ def num_ctx_switches(self):
+ return _common.pctxsw(
+ *cext.proc_num_ctx_switches(self.pid))
+
+ @wrap_exceptions
+ def wait(self, timeout=None):
+ return _psposix.wait_pid(self.pid, timeout, self._name)
+
+ if HAS_PROC_IO_COUNTERS:
+ @wrap_exceptions
+ def io_counters(self):
+ try:
+ rc, wc, rb, wb = cext.proc_io_counters(self.pid)
+ except OSError:
+ # if process is terminated, proc_io_counters returns OSError
+ # instead of NSP
+ if not pid_exists(self.pid):
+ raise NoSuchProcess(self.pid, self._name)
+ raise
+ return _common.pio(rc, wc, rb, wb)
diff --git a/lib/psutil/_psbsd.py b/lib/psutil/_psbsd.py
new file mode 100644
index 0000000..a25c96c
--- /dev/null
+++ b/lib/psutil/_psbsd.py
@@ -0,0 +1,927 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""FreeBSD, OpenBSD and NetBSD platforms implementation."""
+
+import contextlib
+import errno
+import functools
+import os
+import xml.etree.ElementTree as ET
+from collections import defaultdict
+from collections import namedtuple
+
+from . import _common
+from . import _psposix
+from . import _psutil_bsd as cext
+from . import _psutil_posix as cext_posix
+from ._common import FREEBSD
+from ._common import NETBSD
+from ._common import OPENBSD
+from ._common import AccessDenied
+from ._common import NoSuchProcess
+from ._common import ZombieProcess
+from ._common import conn_tmap
+from ._common import conn_to_ntuple
+from ._common import memoize
+from ._common import memoize_when_activated
+from ._common import usage_percent
+from ._compat import FileNotFoundError
+from ._compat import PermissionError
+from ._compat import ProcessLookupError
+from ._compat import which
+
+
+__extra__all__ = []
+
+
+# =====================================================================
+# --- globals
+# =====================================================================
+
+
+if FREEBSD:
+ PROC_STATUSES = {
+ cext.SIDL: _common.STATUS_IDLE,
+ cext.SRUN: _common.STATUS_RUNNING,
+ cext.SSLEEP: _common.STATUS_SLEEPING,
+ cext.SSTOP: _common.STATUS_STOPPED,
+ cext.SZOMB: _common.STATUS_ZOMBIE,
+ cext.SWAIT: _common.STATUS_WAITING,
+ cext.SLOCK: _common.STATUS_LOCKED,
+ }
+elif OPENBSD:
+ PROC_STATUSES = {
+ cext.SIDL: _common.STATUS_IDLE,
+ cext.SSLEEP: _common.STATUS_SLEEPING,
+ cext.SSTOP: _common.STATUS_STOPPED,
+ # According to /usr/include/sys/proc.h SZOMB is unused.
+ # test_zombie_process() shows that SDEAD is the right
+ # equivalent. Also it appears there's no equivalent of
+ # psutil.STATUS_DEAD. SDEAD really means STATUS_ZOMBIE.
+ # cext.SZOMB: _common.STATUS_ZOMBIE,
+ cext.SDEAD: _common.STATUS_ZOMBIE,
+ cext.SZOMB: _common.STATUS_ZOMBIE,
+ # From http://www.eecs.harvard.edu/~margo/cs161/videos/proc.h.txt
+ # OpenBSD has SRUN and SONPROC: SRUN indicates that a process
+ # is runnable but *not* yet running, i.e. is on a run queue.
+ # SONPROC indicates that the process is actually executing on
+ # a CPU, i.e. it is no longer on a run queue.
+ # As such we'll map SRUN to STATUS_WAKING and SONPROC to
+ # STATUS_RUNNING
+ cext.SRUN: _common.STATUS_WAKING,
+ cext.SONPROC: _common.STATUS_RUNNING,
+ }
+elif NETBSD:
+ PROC_STATUSES = {
+ cext.SIDL: _common.STATUS_IDLE,
+ cext.SSLEEP: _common.STATUS_SLEEPING,
+ cext.SSTOP: _common.STATUS_STOPPED,
+ cext.SZOMB: _common.STATUS_ZOMBIE,
+ cext.SRUN: _common.STATUS_WAKING,
+ cext.SONPROC: _common.STATUS_RUNNING,
+ }
+
+TCP_STATUSES = {
+ cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
+ cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
+ cext.TCPS_SYN_RECEIVED: _common.CONN_SYN_RECV,
+ cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
+ cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
+ cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
+ cext.TCPS_CLOSED: _common.CONN_CLOSE,
+ cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
+ cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
+ cext.TCPS_LISTEN: _common.CONN_LISTEN,
+ cext.TCPS_CLOSING: _common.CONN_CLOSING,
+ cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
+}
+
+PAGESIZE = cext_posix.getpagesize()
+AF_LINK = cext_posix.AF_LINK
+
+HAS_PER_CPU_TIMES = hasattr(cext, "per_cpu_times")
+HAS_PROC_NUM_THREADS = hasattr(cext, "proc_num_threads")
+HAS_PROC_OPEN_FILES = hasattr(cext, 'proc_open_files')
+HAS_PROC_NUM_FDS = hasattr(cext, 'proc_num_fds')
+
+kinfo_proc_map = dict(
+ ppid=0,
+ status=1,
+ real_uid=2,
+ effective_uid=3,
+ saved_uid=4,
+ real_gid=5,
+ effective_gid=6,
+ saved_gid=7,
+ ttynr=8,
+ create_time=9,
+ ctx_switches_vol=10,
+ ctx_switches_unvol=11,
+ read_io_count=12,
+ write_io_count=13,
+ user_time=14,
+ sys_time=15,
+ ch_user_time=16,
+ ch_sys_time=17,
+ rss=18,
+ vms=19,
+ memtext=20,
+ memdata=21,
+ memstack=22,
+ cpunum=23,
+ name=24,
+)
+
+
+# =====================================================================
+# --- named tuples
+# =====================================================================
+
+
+# psutil.virtual_memory()
+svmem = namedtuple(
+ 'svmem', ['total', 'available', 'percent', 'used', 'free',
+ 'active', 'inactive', 'buffers', 'cached', 'shared', 'wired'])
+# psutil.cpu_times()
+scputimes = namedtuple(
+ 'scputimes', ['user', 'nice', 'system', 'idle', 'irq'])
+# psutil.Process.memory_info()
+pmem = namedtuple('pmem', ['rss', 'vms', 'text', 'data', 'stack'])
+# psutil.Process.memory_full_info()
+pfullmem = pmem
+# psutil.Process.cpu_times()
+pcputimes = namedtuple('pcputimes',
+ ['user', 'system', 'children_user', 'children_system'])
+# psutil.Process.memory_maps(grouped=True)
+pmmap_grouped = namedtuple(
+ 'pmmap_grouped', 'path rss, private, ref_count, shadow_count')
+# psutil.Process.memory_maps(grouped=False)
+pmmap_ext = namedtuple(
+ 'pmmap_ext', 'addr, perms path rss, private, ref_count, shadow_count')
+# psutil.disk_io_counters()
+if FREEBSD:
+ sdiskio = namedtuple('sdiskio', ['read_count', 'write_count',
+ 'read_bytes', 'write_bytes',
+ 'read_time', 'write_time',
+ 'busy_time'])
+else:
+ sdiskio = namedtuple('sdiskio', ['read_count', 'write_count',
+ 'read_bytes', 'write_bytes'])
+
+
+# =====================================================================
+# --- memory
+# =====================================================================
+
+
+def virtual_memory():
+ """System virtual memory as a namedtuple."""
+ mem = cext.virtual_mem()
+ total, free, active, inactive, wired, cached, buffers, shared = mem
+ if NETBSD:
+ # On NetBSD buffers and shared mem are determined via /proc.
+ # The C ext sets them to 0.
+ with open('/proc/meminfo', 'rb') as f:
+ for line in f:
+ if line.startswith(b'Buffers:'):
+ buffers = int(line.split()[1]) * 1024
+ elif line.startswith(b'MemShared:'):
+ shared = int(line.split()[1]) * 1024
+ elif line.startswith(b'Cached:'):
+ cached = int(line.split()[1]) * 1024
+ avail = inactive + cached + free
+ used = active + wired + cached
+ percent = usage_percent((total - avail), total, round_=1)
+ return svmem(total, avail, percent, used, free,
+ active, inactive, buffers, cached, shared, wired)
+
+
+def swap_memory():
+ """System swap memory as (total, used, free, sin, sout) namedtuple."""
+ total, used, free, sin, sout = cext.swap_mem()
+ percent = usage_percent(used, total, round_=1)
+ return _common.sswap(total, used, free, percent, sin, sout)
+
+
+# =====================================================================
+# --- CPU
+# =====================================================================
+
+
+def cpu_times():
+ """Return system per-CPU times as a namedtuple"""
+ user, nice, system, idle, irq = cext.cpu_times()
+ return scputimes(user, nice, system, idle, irq)
+
+
+if HAS_PER_CPU_TIMES:
+ def per_cpu_times():
+ """Return system CPU times as a namedtuple"""
+ ret = []
+ for cpu_t in cext.per_cpu_times():
+ user, nice, system, idle, irq = cpu_t
+ item = scputimes(user, nice, system, idle, irq)
+ ret.append(item)
+ return ret
+else:
+ # XXX
+ # Ok, this is very dirty.
+ # On FreeBSD < 8 we cannot gather per-cpu information, see:
+ # https://github.com/giampaolo/psutil/issues/226
+ # If num cpus > 1, on first call we return single cpu times to avoid a
+ # crash at psutil import time.
+ # Next calls will fail with NotImplementedError
+ def per_cpu_times():
+ """Return system CPU times as a namedtuple"""
+ if cpu_count_logical() == 1:
+ return [cpu_times()]
+ if per_cpu_times.__called__:
+ raise NotImplementedError("supported only starting from FreeBSD 8")
+ per_cpu_times.__called__ = True
+ return [cpu_times()]
+
+ per_cpu_times.__called__ = False
+
+
+def cpu_count_logical():
+ """Return the number of logical CPUs in the system."""
+ return cext.cpu_count_logical()
+
+
+if OPENBSD or NETBSD:
+ def cpu_count_cores():
+ # OpenBSD and NetBSD do not implement this.
+ return 1 if cpu_count_logical() == 1 else None
+else:
+ def cpu_count_cores():
+ """Return the number of CPU cores in the system."""
+ # From the C module we'll get an XML string similar to this:
+ # http://manpages.ubuntu.com/manpages/precise/man4/smp.4freebsd.html
+ # We may get None in case "sysctl kern.sched.topology_spec"
+ # is not supported on this BSD version, in which case we'll mimic
+ # os.cpu_count() and return None.
+ ret = None
+ s = cext.cpu_topology()
+ if s is not None:
+ # get rid of padding chars appended at the end of the string
+ index = s.rfind("</groups>")
+ if index != -1:
+ s = s[:index + 9]
+ root = ET.fromstring(s)
+ try:
+ ret = len(root.findall('group/children/group/cpu')) or None
+ finally:
+ # needed otherwise it will memleak
+ root.clear()
+ if not ret:
+ # If logical CPUs == 1 it's obvious we have only 1 core.
+ if cpu_count_logical() == 1:
+ return 1
+ return ret
+
+
+def cpu_stats():
+ """Return various CPU stats as a named tuple."""
+ if FREEBSD:
+ # Note: the C ext is returning some metrics we are not exposing:
+ # traps.
+ ctxsw, intrs, soft_intrs, syscalls, traps = cext.cpu_stats()
+ elif NETBSD:
+ # XXX
+ # Note about intrs: the C extension returns 0. intrs
+ # can be determined via /proc/stat; it has the same value as
+ # soft_intrs though, so the kernel is faking it (?).
+ #
+ # Note about syscalls: the C extension always sets it to 0 (?).
+ #
+ # Note: the C ext is returning some metrics we are not exposing:
+ # traps, faults and forks.
+ ctxsw, intrs, soft_intrs, syscalls, traps, faults, forks = \
+ cext.cpu_stats()
+ with open('/proc/stat', 'rb') as f:
+ for line in f:
+ if line.startswith(b'intr'):
+ intrs = int(line.split()[1])
+ elif OPENBSD:
+ # Note: the C ext is returning some metrics we are not exposing:
+ # traps, faults and forks.
+ ctxsw, intrs, soft_intrs, syscalls, traps, faults, forks = \
+ cext.cpu_stats()
+ return _common.scpustats(ctxsw, intrs, soft_intrs, syscalls)
+
+
+if FREEBSD:
+ def cpu_freq():
+ """Return frequency metrics for CPUs. As of Dec 2018 only
+ CPU 0 appears to be supported by FreeBSD and all other cores
+ match the frequency of CPU 0.
+ """
+ ret = []
+ num_cpus = cpu_count_logical()
+ for cpu in range(num_cpus):
+ try:
+ current, available_freq = cext.cpu_freq(cpu)
+ except NotImplementedError:
+ continue
+ if available_freq:
+ try:
+ min_freq = int(available_freq.split(" ")[-1].split("/")[0])
+ except (IndexError, ValueError):
+ min_freq = None
+ try:
+ max_freq = int(available_freq.split(" ")[0].split("/")[0])
+ except (IndexError, ValueError):
+ max_freq = None
+ ret.append(_common.scpufreq(current, min_freq, max_freq))
+ return ret
+elif OPENBSD:
+ def cpu_freq():
+ curr = float(cext.cpu_freq())
+ return [_common.scpufreq(curr, 0.0, 0.0)]
+
+
+# =====================================================================
+# --- disks
+# =====================================================================
+
+
+def disk_partitions(all=False):
+ """Return mounted disk partitions as a list of namedtuples.
+ 'all' argument is ignored, see:
+ https://github.com/giampaolo/psutil/issues/906
+ """
+ retlist = []
+ partitions = cext.disk_partitions()
+ for partition in partitions:
+ device, mountpoint, fstype, opts = partition
+ maxfile = maxpath = None # set later
+ ntuple = _common.sdiskpart(device, mountpoint, fstype, opts,
+ maxfile, maxpath)
+ retlist.append(ntuple)
+ return retlist
+
+
+disk_usage = _psposix.disk_usage
+disk_io_counters = cext.disk_io_counters
+
+
+# =====================================================================
+# --- network
+# =====================================================================
+
+
+net_io_counters = cext.net_io_counters
+net_if_addrs = cext_posix.net_if_addrs
+
+
+def net_if_stats():
+ """Get NIC stats (isup, duplex, speed, mtu)."""
+ names = net_io_counters().keys()
+ ret = {}
+ for name in names:
+ try:
+ mtu = cext_posix.net_if_mtu(name)
+ flags = cext_posix.net_if_flags(name)
+ duplex, speed = cext_posix.net_if_duplex_speed(name)
+ except OSError as err:
+ # https://github.com/giampaolo/psutil/issues/1279
+ if err.errno != errno.ENODEV:
+ raise
+ else:
+ if hasattr(_common, 'NicDuplex'):
+ duplex = _common.NicDuplex(duplex)
+ output_flags = ','.join(flags)
+ isup = 'running' in flags
+ ret[name] = _common.snicstats(isup, duplex, speed, mtu,
+ output_flags)
+ return ret
+
+
+def net_connections(kind):
+ """System-wide network connections."""
+ if OPENBSD:
+ ret = []
+ for pid in pids():
+ try:
+ cons = Process(pid).connections(kind)
+ except (NoSuchProcess, ZombieProcess):
+ continue
+ else:
+ for conn in cons:
+ conn = list(conn)
+ conn.append(pid)
+ ret.append(_common.sconn(*conn))
+ return ret
+
+ if kind not in _common.conn_tmap:
+ raise ValueError("invalid %r kind argument; choose between %s"
+ % (kind, ', '.join([repr(x) for x in conn_tmap])))
+ families, types = conn_tmap[kind]
+ ret = set()
+ if NETBSD:
+ rawlist = cext.net_connections(-1)
+ else:
+ rawlist = cext.net_connections()
+ for item in rawlist:
+ fd, fam, type, laddr, raddr, status, pid = item
+ # TODO: apply filter at C level
+ if fam in families and type in types:
+ nt = conn_to_ntuple(fd, fam, type, laddr, raddr, status,
+ TCP_STATUSES, pid)
+ ret.add(nt)
+ return list(ret)
+
+
+# =====================================================================
+# --- sensors
+# =====================================================================
+
+
+if FREEBSD:
+
+ def sensors_battery():
+ """Return battery info."""
+ try:
+ percent, minsleft, power_plugged = cext.sensors_battery()
+ except NotImplementedError:
+ # See: https://github.com/giampaolo/psutil/issues/1074
+ return None
+ power_plugged = power_plugged == 1
+ if power_plugged:
+ secsleft = _common.POWER_TIME_UNLIMITED
+ elif minsleft == -1:
+ secsleft = _common.POWER_TIME_UNKNOWN
+ else:
+ secsleft = minsleft * 60
+ return _common.sbattery(percent, secsleft, power_plugged)
+
+ def sensors_temperatures():
+ """Return CPU cores temperatures if available, else an empty dict."""
+ ret = defaultdict(list)
+ num_cpus = cpu_count_logical()
+ for cpu in range(num_cpus):
+ try:
+ current, high = cext.sensors_cpu_temperature(cpu)
+ if high <= 0:
+ high = None
+ name = "Core %s" % cpu
+ ret["coretemp"].append(
+ _common.shwtemp(name, current, high, high))
+ except NotImplementedError:
+ pass
+
+ return ret
+
+
+# =====================================================================
+# --- other system functions
+# =====================================================================
+
+
+def boot_time():
+ """The system boot time expressed in seconds since the epoch."""
+ return cext.boot_time()
+
+
+def users():
+ """Return currently connected users as a list of namedtuples."""
+ retlist = []
+ rawlist = cext.users()
+ for item in rawlist:
+ user, tty, hostname, tstamp, pid = item
+ if pid == -1:
+ assert OPENBSD
+ pid = None
+ if tty == '~':
+ continue # reboot or shutdown
+ nt = _common.suser(user, tty or None, hostname, tstamp, pid)
+ retlist.append(nt)
+ return retlist
+
+
+# =====================================================================
+# --- processes
+# =====================================================================
+
+
+@memoize
+def _pid_0_exists():
+ try:
+ Process(0).name()
+ except NoSuchProcess:
+ return False
+ except AccessDenied:
+ return True
+ else:
+ return True
+
+
+def pids():
+ """Returns a list of PIDs currently running on the system."""
+ ret = cext.pids()
+ if OPENBSD and (0 not in ret) and _pid_0_exists():
+ # On OpenBSD the kernel does not return PID 0 (neither does
+ # ps) but it is actually queryable (Process(0) will succeed).
+ ret.insert(0, 0)
+ return ret
+
+
+if OPENBSD or NETBSD:
+ def pid_exists(pid):
+ """Return True if pid exists."""
+ exists = _psposix.pid_exists(pid)
+ if not exists:
+ # We do this because _psposix.pid_exists() lies in case of
+ # zombie processes.
+ return pid in pids()
+ else:
+ return True
+else:
+ pid_exists = _psposix.pid_exists
+
+
+def is_zombie(pid):
+ try:
+ st = cext.proc_oneshot_info(pid)[kinfo_proc_map['status']]
+ return st == cext.SZOMB
+ except Exception:
+ return False
+
+
+def wrap_exceptions(fun):
+ """Decorator which translates bare OSError exceptions into
+ NoSuchProcess and AccessDenied.
+ """
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return fun(self, *args, **kwargs)
+ except ProcessLookupError:
+ if is_zombie(self.pid):
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ else:
+ raise NoSuchProcess(self.pid, self._name)
+ except PermissionError:
+ raise AccessDenied(self.pid, self._name)
+ except OSError:
+ if self.pid == 0:
+ if 0 in pids():
+ raise AccessDenied(self.pid, self._name)
+ else:
+ raise
+ raise
+ return wrapper
+
+
+@contextlib.contextmanager
+def wrap_exceptions_procfs(inst):
+ """Same as above, for routines relying on reading /proc fs."""
+ try:
+ yield
+ except (ProcessLookupError, FileNotFoundError):
+ # ENOENT (no such file or directory) gets raised on open().
+ # ESRCH (no such process) can get raised on read() if
+ # process is gone in meantime.
+ if is_zombie(inst.pid):
+ raise ZombieProcess(inst.pid, inst._name, inst._ppid)
+ else:
+ raise NoSuchProcess(inst.pid, inst._name)
+ except PermissionError:
+ raise AccessDenied(inst.pid, inst._name)
+
+
+class Process(object):
+ """Wrapper class around underlying C implementation."""
+
+ __slots__ = ["pid", "_name", "_ppid", "_cache"]
+
+ def __init__(self, pid):
+ self.pid = pid
+ self._name = None
+ self._ppid = None
+
+ def _assert_alive(self):
+ """Raise NSP if the process disappeared on us."""
+ # For those C functions which do not raise NSP and may return an
+ # incorrect or incomplete result.
+ cext.proc_name(self.pid)
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def oneshot(self):
+ """Retrieves multiple process info in one shot as a raw tuple."""
+ ret = cext.proc_oneshot_info(self.pid)
+ assert len(ret) == len(kinfo_proc_map)
+ return ret
+
+ def oneshot_enter(self):
+ self.oneshot.cache_activate(self)
+
+ def oneshot_exit(self):
+ self.oneshot.cache_deactivate(self)
+
+ @wrap_exceptions
+ def name(self):
+ name = self.oneshot()[kinfo_proc_map['name']]
+ return name if name is not None else cext.proc_name(self.pid)
+
+ @wrap_exceptions
+ def exe(self):
+ if FREEBSD:
+ if self.pid == 0:
+ return '' # else NSP
+ return cext.proc_exe(self.pid)
+ elif NETBSD:
+ if self.pid == 0:
+ # /proc/0 dir exists but /proc/0/exe doesn't
+ return ""
+ with wrap_exceptions_procfs(self):
+ return os.readlink("/proc/%s/exe" % self.pid)
+ else:
+ # OpenBSD: exe cannot be determined; references:
+ # https://chromium.googlesource.com/chromium/src/base/+/
+ # master/base_paths_posix.cc
+ # We try our best guess by using which against the first
+ # cmdline arg (may return None).
+ cmdline = self.cmdline()
+ if cmdline:
+ return which(cmdline[0]) or ""
+ else:
+ return ""
+
+ @wrap_exceptions
+ def cmdline(self):
+ if OPENBSD and self.pid == 0:
+ return [] # ...else it crashes
+ elif NETBSD:
+ # XXX - most of the time the underlying sysctl() call on NetBSD
+ # and OpenBSD returns a truncated string. /proc/pid/cmdline
+ # behaves the same way, so this looks like a kernel bug.
+ try:
+ return cext.proc_cmdline(self.pid)
+ except OSError as err:
+ if err.errno == errno.EINVAL:
+ if is_zombie(self.pid):
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ elif not pid_exists(self.pid):
+ raise NoSuchProcess(self.pid, self._name, self._ppid)
+ else:
+ # XXX: this happens with unicode tests. It means the C
+ # routine is unable to decode invalid unicode chars.
+ return []
+ else:
+ raise
+ else:
+ return cext.proc_cmdline(self.pid)
+
+ @wrap_exceptions
+ def environ(self):
+ return cext.proc_environ(self.pid)
+
+ @wrap_exceptions
+ def terminal(self):
+ tty_nr = self.oneshot()[kinfo_proc_map['ttynr']]
+ tmap = _psposix.get_terminal_map()
+ try:
+ return tmap[tty_nr]
+ except KeyError:
+ return None
+
+ @wrap_exceptions
+ def ppid(self):
+ self._ppid = self.oneshot()[kinfo_proc_map['ppid']]
+ return self._ppid
+
+ @wrap_exceptions
+ def uids(self):
+ rawtuple = self.oneshot()
+ return _common.puids(
+ rawtuple[kinfo_proc_map['real_uid']],
+ rawtuple[kinfo_proc_map['effective_uid']],
+ rawtuple[kinfo_proc_map['saved_uid']])
+
+ @wrap_exceptions
+ def gids(self):
+ rawtuple = self.oneshot()
+ return _common.pgids(
+ rawtuple[kinfo_proc_map['real_gid']],
+ rawtuple[kinfo_proc_map['effective_gid']],
+ rawtuple[kinfo_proc_map['saved_gid']])
+
+ @wrap_exceptions
+ def cpu_times(self):
+ rawtuple = self.oneshot()
+ return _common.pcputimes(
+ rawtuple[kinfo_proc_map['user_time']],
+ rawtuple[kinfo_proc_map['sys_time']],
+ rawtuple[kinfo_proc_map['ch_user_time']],
+ rawtuple[kinfo_proc_map['ch_sys_time']])
+
+ if FREEBSD:
+ @wrap_exceptions
+ def cpu_num(self):
+ return self.oneshot()[kinfo_proc_map['cpunum']]
+
+ @wrap_exceptions
+ def memory_info(self):
+ rawtuple = self.oneshot()
+ return pmem(
+ rawtuple[kinfo_proc_map['rss']],
+ rawtuple[kinfo_proc_map['vms']],
+ rawtuple[kinfo_proc_map['memtext']],
+ rawtuple[kinfo_proc_map['memdata']],
+ rawtuple[kinfo_proc_map['memstack']])
+
+ memory_full_info = memory_info
+
+ @wrap_exceptions
+ def create_time(self):
+ return self.oneshot()[kinfo_proc_map['create_time']]
+
+ @wrap_exceptions
+ def num_threads(self):
+ if HAS_PROC_NUM_THREADS:
+ # FreeBSD
+ return cext.proc_num_threads(self.pid)
+ else:
+ return len(self.threads())
+
+ @wrap_exceptions
+ def num_ctx_switches(self):
+ rawtuple = self.oneshot()
+ return _common.pctxsw(
+ rawtuple[kinfo_proc_map['ctx_switches_vol']],
+ rawtuple[kinfo_proc_map['ctx_switches_unvol']])
+
+ @wrap_exceptions
+ def threads(self):
+ # Note: on OpenBSD this (/dev/mem) requires root access.
+ rawlist = cext.proc_threads(self.pid)
+ retlist = []
+ for thread_id, utime, stime in rawlist:
+ ntuple = _common.pthread(thread_id, utime, stime)
+ retlist.append(ntuple)
+ if OPENBSD:
+ self._assert_alive()
+ return retlist
+
+ @wrap_exceptions
+ def connections(self, kind='inet'):
+ if kind not in conn_tmap:
+ raise ValueError("invalid %r kind argument; choose between %s"
+ % (kind, ', '.join([repr(x) for x in conn_tmap])))
+
+ if NETBSD:
+ families, types = conn_tmap[kind]
+ ret = []
+ rawlist = cext.net_connections(self.pid)
+ for item in rawlist:
+ fd, fam, type, laddr, raddr, status, pid = item
+ assert pid == self.pid
+ if fam in families and type in types:
+ nt = conn_to_ntuple(fd, fam, type, laddr, raddr, status,
+ TCP_STATUSES)
+ ret.append(nt)
+ self._assert_alive()
+ return list(ret)
+
+ families, types = conn_tmap[kind]
+ rawlist = cext.proc_connections(self.pid, families, types)
+ ret = []
+ for item in rawlist:
+ fd, fam, type, laddr, raddr, status = item
+ nt = conn_to_ntuple(fd, fam, type, laddr, raddr, status,
+ TCP_STATUSES)
+ ret.append(nt)
+
+ if OPENBSD:
+ self._assert_alive()
+
+ return ret
+
+ @wrap_exceptions
+ def wait(self, timeout=None):
+ return _psposix.wait_pid(self.pid, timeout, self._name)
+
+ @wrap_exceptions
+ def nice_get(self):
+ return cext_posix.getpriority(self.pid)
+
+ @wrap_exceptions
+ def nice_set(self, value):
+ return cext_posix.setpriority(self.pid, value)
+
+ @wrap_exceptions
+ def status(self):
+ code = self.oneshot()[kinfo_proc_map['status']]
+ # XXX is '?' legit? (we're not supposed to return it anyway)
+ return PROC_STATUSES.get(code, '?')
+
+ @wrap_exceptions
+ def io_counters(self):
+ rawtuple = self.oneshot()
+ return _common.pio(
+ rawtuple[kinfo_proc_map['read_io_count']],
+ rawtuple[kinfo_proc_map['write_io_count']],
+ -1,
+ -1)
+
+ @wrap_exceptions
+ def cwd(self):
+ """Return process current working directory."""
+ # sometimes we get an empty string, in which case we turn
+ # it into None
+ if OPENBSD and self.pid == 0:
+ return None # ...else it would raise EINVAL
+ elif NETBSD or HAS_PROC_OPEN_FILES:
+ # FreeBSD < 8 does not support functions based on
+ # kinfo_getfile() and kinfo_getvmmap()
+ return cext.proc_cwd(self.pid) or None
+ else:
+ raise NotImplementedError(
+ "supported only starting from FreeBSD 8" if
+ FREEBSD else "")
+
+ nt_mmap_grouped = namedtuple(
+ 'mmap', 'path rss, private, ref_count, shadow_count')
+ nt_mmap_ext = namedtuple(
+ 'mmap', 'addr, perms path rss, private, ref_count, shadow_count')
+
+ def _not_implemented(self):
+ raise NotImplementedError
+
+ # FreeBSD < 8 does not support functions based on kinfo_getfile()
+ # and kinfo_getvmmap()
+ if HAS_PROC_OPEN_FILES:
+ @wrap_exceptions
+ def open_files(self):
+ """Return files opened by process as a list of namedtuples."""
+ rawlist = cext.proc_open_files(self.pid)
+ return [_common.popenfile(path, fd) for path, fd in rawlist]
+ else:
+ open_files = _not_implemented
+
+ # FreeBSD < 8 does not support functions based on kinfo_getfile()
+ # and kinfo_getvmmap()
+ if HAS_PROC_NUM_FDS:
+ @wrap_exceptions
+ def num_fds(self):
+ """Return the number of file descriptors opened by this process."""
+ ret = cext.proc_num_fds(self.pid)
+ if NETBSD:
+ self._assert_alive()
+ return ret
+ else:
+ num_fds = _not_implemented
+
+ # --- FreeBSD only APIs
+
+ if FREEBSD:
+
+ @wrap_exceptions
+ def cpu_affinity_get(self):
+ return cext.proc_cpu_affinity_get(self.pid)
+
+ @wrap_exceptions
+ def cpu_affinity_set(self, cpus):
+ # Pre-emptively check if CPUs are valid because the C
+ # function has a weird behavior in case of invalid CPUs,
+ # see: https://github.com/giampaolo/psutil/issues/586
+ allcpus = tuple(range(len(per_cpu_times())))
+ for cpu in cpus:
+ if cpu not in allcpus:
+ raise ValueError("invalid CPU #%i (choose between %s)"
+ % (cpu, allcpus))
+ try:
+ cext.proc_cpu_affinity_set(self.pid, cpus)
+ except OSError as err:
+ # 'man cpuset_setaffinity' about EDEADLK:
+ # <<the call would leave a thread without a valid CPU to run
+ # on because the set does not overlap with the thread's
+ # anonymous mask>>
+ if err.errno in (errno.EINVAL, errno.EDEADLK):
+ for cpu in cpus:
+ if cpu not in allcpus:
+ raise ValueError(
+ "invalid CPU #%i (choose between %s)" % (
+ cpu, allcpus))
+ raise
+
+ @wrap_exceptions
+ def memory_maps(self):
+ return cext.proc_memory_maps(self.pid)
+
+ @wrap_exceptions
+ def rlimit(self, resource, limits=None):
+ if limits is None:
+ return cext.proc_getrlimit(self.pid, resource)
+ else:
+ if len(limits) != 2:
+ raise ValueError(
+ "second argument must be a (soft, hard) tuple, "
+ "got %s" % repr(limits))
+ soft, hard = limits
+ return cext.proc_setrlimit(self.pid, resource, soft, hard)
diff --git a/lib/psutil/_pslinux.py b/lib/psutil/_pslinux.py
new file mode 100644
index 0000000..9dc9643
--- /dev/null
+++ b/lib/psutil/_pslinux.py
@@ -0,0 +1,2257 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Linux platform implementation."""
+
+from __future__ import division
+
+import base64
+import collections
+import errno
+import functools
+import glob
+import os
+import re
+import socket
+import struct
+import sys
+import traceback
+import warnings
+from collections import defaultdict
+from collections import namedtuple
+
+from . import _common
+from . import _psposix
+from . import _psutil_linux as cext
+from . import _psutil_posix as cext_posix
+from ._common import NIC_DUPLEX_FULL
+from ._common import NIC_DUPLEX_HALF
+from ._common import NIC_DUPLEX_UNKNOWN
+from ._common import AccessDenied
+from ._common import NoSuchProcess
+from ._common import ZombieProcess
+from ._common import bcat
+from ._common import cat
+from ._common import debug
+from ._common import decode
+from ._common import get_procfs_path
+from ._common import isfile_strict
+from ._common import memoize
+from ._common import memoize_when_activated
+from ._common import open_binary
+from ._common import open_text
+from ._common import parse_environ_block
+from ._common import path_exists_strict
+from ._common import supports_ipv6
+from ._common import usage_percent
+from ._compat import PY3
+from ._compat import FileNotFoundError
+from ._compat import PermissionError
+from ._compat import ProcessLookupError
+from ._compat import b
+from ._compat import basestring
+
+
+if sys.version_info >= (3, 4):
+ import enum
+else:
+ enum = None
+
+
+__extra__all__ = [
+ #
+ 'PROCFS_PATH',
+ # io prio constants
+ "IOPRIO_CLASS_NONE", "IOPRIO_CLASS_RT", "IOPRIO_CLASS_BE",
+ "IOPRIO_CLASS_IDLE",
+ # connection status constants
+ "CONN_ESTABLISHED", "CONN_SYN_SENT", "CONN_SYN_RECV", "CONN_FIN_WAIT1",
+ "CONN_FIN_WAIT2", "CONN_TIME_WAIT", "CONN_CLOSE", "CONN_CLOSE_WAIT",
+ "CONN_LAST_ACK", "CONN_LISTEN", "CONN_CLOSING", ]
+
+
+# =====================================================================
+# --- globals
+# =====================================================================
+
+
+POWER_SUPPLY_PATH = "/sys/class/power_supply"
+HAS_PROC_SMAPS = os.path.exists('/proc/%s/smaps' % os.getpid())
+HAS_PROC_SMAPS_ROLLUP = os.path.exists('/proc/%s/smaps_rollup' % os.getpid())
+HAS_PROC_IO_PRIORITY = hasattr(cext, "proc_ioprio_get")
+HAS_CPU_AFFINITY = hasattr(cext, "proc_cpu_affinity_get")
+
+# Number of clock ticks per second
+CLOCK_TICKS = os.sysconf("SC_CLK_TCK")
+PAGESIZE = cext_posix.getpagesize()
+BOOT_TIME = None # set later
+LITTLE_ENDIAN = sys.byteorder == 'little'
+
+# "man iostat" states that sectors are equivalent with blocks and have
+# a size of 512 bytes. Despite this value can be queried at runtime
+# via /sys/block/{DISK}/queue/hw_sector_size and results may vary
+# between 1k, 2k, or 4k... 512 appears to be a magic constant used
+# throughout Linux source code:
+# * https://stackoverflow.com/a/38136179/376587
+# * https://lists.gt.net/linux/kernel/2241060
+# * https://github.com/giampaolo/psutil/issues/1305
+# * https://github.com/torvalds/linux/blob/
+# 4f671fe2f9523a1ea206f63fe60a7c7b3a56d5c7/include/linux/bio.h#L99
+# * https://lkml.org/lkml/2015/8/17/234
+DISK_SECTOR_SIZE = 512
+
+if enum is None:
+ AF_LINK = socket.AF_PACKET
+else:
+ AddressFamily = enum.IntEnum('AddressFamily',
+ {'AF_LINK': int(socket.AF_PACKET)})
+ AF_LINK = AddressFamily.AF_LINK
+
+# ioprio_* constants http://linux.die.net/man/2/ioprio_get
+if enum is None:
+ IOPRIO_CLASS_NONE = 0
+ IOPRIO_CLASS_RT = 1
+ IOPRIO_CLASS_BE = 2
+ IOPRIO_CLASS_IDLE = 3
+else:
+ class IOPriority(enum.IntEnum):
+ IOPRIO_CLASS_NONE = 0
+ IOPRIO_CLASS_RT = 1
+ IOPRIO_CLASS_BE = 2
+ IOPRIO_CLASS_IDLE = 3
+
+ globals().update(IOPriority.__members__)
+
+# See:
+# https://github.com/torvalds/linux/blame/master/fs/proc/array.c
+# ...and (TASK_* constants):
+# https://github.com/torvalds/linux/blob/master/include/linux/sched.h
+PROC_STATUSES = {
+ "R": _common.STATUS_RUNNING,
+ "S": _common.STATUS_SLEEPING,
+ "D": _common.STATUS_DISK_SLEEP,
+ "T": _common.STATUS_STOPPED,
+ "t": _common.STATUS_TRACING_STOP,
+ "Z": _common.STATUS_ZOMBIE,
+ "X": _common.STATUS_DEAD,
+ "x": _common.STATUS_DEAD,
+ "K": _common.STATUS_WAKE_KILL,
+ "W": _common.STATUS_WAKING,
+ "I": _common.STATUS_IDLE,
+ "P": _common.STATUS_PARKED,
+}
+
+# https://github.com/torvalds/linux/blob/master/include/net/tcp_states.h
+TCP_STATUSES = {
+ "01": _common.CONN_ESTABLISHED,
+ "02": _common.CONN_SYN_SENT,
+ "03": _common.CONN_SYN_RECV,
+ "04": _common.CONN_FIN_WAIT1,
+ "05": _common.CONN_FIN_WAIT2,
+ "06": _common.CONN_TIME_WAIT,
+ "07": _common.CONN_CLOSE,
+ "08": _common.CONN_CLOSE_WAIT,
+ "09": _common.CONN_LAST_ACK,
+ "0A": _common.CONN_LISTEN,
+ "0B": _common.CONN_CLOSING
+}
+
+
+# =====================================================================
+# --- named tuples
+# =====================================================================
+
+
+# psutil.virtual_memory()
+svmem = namedtuple(
+ 'svmem', ['total', 'available', 'percent', 'used', 'free',
+ 'active', 'inactive', 'buffers', 'cached', 'shared', 'slab'])
+# psutil.disk_io_counters()
+sdiskio = namedtuple(
+ 'sdiskio', ['read_count', 'write_count',
+ 'read_bytes', 'write_bytes',
+ 'read_time', 'write_time',
+ 'read_merged_count', 'write_merged_count',
+ 'busy_time'])
+# psutil.Process().open_files()
+popenfile = namedtuple(
+ 'popenfile', ['path', 'fd', 'position', 'mode', 'flags'])
+# psutil.Process().memory_info()
+pmem = namedtuple('pmem', 'rss vms shared text lib data dirty')
+# psutil.Process().memory_full_info()
+pfullmem = namedtuple('pfullmem', pmem._fields + ('uss', 'pss', 'swap'))
+# psutil.Process().memory_maps(grouped=True)
+pmmap_grouped = namedtuple(
+ 'pmmap_grouped',
+ ['path', 'rss', 'size', 'pss', 'shared_clean', 'shared_dirty',
+ 'private_clean', 'private_dirty', 'referenced', 'anonymous', 'swap'])
+# psutil.Process().memory_maps(grouped=False)
+pmmap_ext = namedtuple(
+ 'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields))
+# psutil.Process.io_counters()
+pio = namedtuple('pio', ['read_count', 'write_count',
+ 'read_bytes', 'write_bytes',
+ 'read_chars', 'write_chars'])
+# psutil.Process.cpu_times()
+pcputimes = namedtuple('pcputimes',
+ ['user', 'system', 'children_user', 'children_system',
+ 'iowait'])
+
+
+# =====================================================================
+# --- utils
+# =====================================================================
+
+
+def readlink(path):
+ """Wrapper around os.readlink()."""
+ assert isinstance(path, basestring), path
+ path = os.readlink(path)
+ # readlink() might return paths containing null bytes ('\x00')
+ # resulting in "TypeError: must be encoded string without NULL
+ # bytes, not str" errors when the string is passed to other
+ # fs-related functions (os.*, open(), ...).
+ # Apparently everything after '\x00' is garbage (we can have
+ # ' (deleted)', 'new' and possibly others), see:
+ # https://github.com/giampaolo/psutil/issues/717
+ path = path.split('\x00')[0]
+ # Certain paths have ' (deleted)' appended. Usually this is
+ # bogus as the file actually exists. Even if it doesn't we
+ # don't care.
+ if path.endswith(' (deleted)') and not path_exists_strict(path):
+ path = path[:-10]
+ return path
+
+
+def file_flags_to_mode(flags):
+ """Convert file's open() flags into a readable string.
+ Used by Process.open_files().
+ """
+ modes_map = {os.O_RDONLY: 'r', os.O_WRONLY: 'w', os.O_RDWR: 'w+'}
+ mode = modes_map[flags & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR)]
+ if flags & os.O_APPEND:
+ mode = mode.replace('w', 'a', 1)
+ mode = mode.replace('w+', 'r+')
+ # possible values: r, w, a, r+, a+
+ return mode
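+
+# A quick illustration of the mapping above (using os module flags):
+#   file_flags_to_mode(os.O_RDONLY)                -> 'r'
+#   file_flags_to_mode(os.O_WRONLY)                -> 'w'
+#   file_flags_to_mode(os.O_RDWR)                  -> 'r+'
+#   file_flags_to_mode(os.O_WRONLY | os.O_APPEND)  -> 'a'
+#   file_flags_to_mode(os.O_RDWR | os.O_APPEND)    -> 'a+'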
+
+
+def is_storage_device(name):
+ """Return True if the given name refers to a root device (e.g.
+ "sda", "nvme0n1") as opposed to a logical partition (e.g. "sda1",
+ "nvme0n1p1"). If name is a virtual device (e.g. "loop1", "ram")
+ return True.
+ """
+ # Re-adapted from iostat source code, see:
+ # https://github.com/sysstat/sysstat/blob/
+ # 97912938cd476645b267280069e83b1c8dc0e1c7/common.c#L208
+ # Some devices may have a slash in their name (e.g. cciss/c0d0...).
+ name = name.replace('/', '!')
+ including_virtual = True
+ if including_virtual:
+ path = "/sys/block/%s" % name
+ else:
+ path = "/sys/block/%s/device" % name
+ return os.access(path, os.F_OK)
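+
+# For example, on a typical system /sys/block contains "sda" but not
+# "sda1", so is_storage_device("sda") is True while
+# is_storage_device("sda1") is False; virtual devices such as "loop0"
+# also appear under /sys/block and therefore return True as well.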
+
+
+@memoize
+def set_scputimes_ntuple(procfs_path):
+ """Set a namedtuple of variable fields depending on the CPU times
+ available on this Linux kernel version which may be:
+ (user, nice, system, idle, iowait, irq, softirq, [steal, [guest,
+ [guest_nice]]])
+ Used by cpu_times() function.
+ """
+ global scputimes
+ with open_binary('%s/stat' % procfs_path) as f:
+ values = f.readline().split()[1:]
+ fields = ['user', 'nice', 'system', 'idle', 'iowait', 'irq', 'softirq']
+ vlen = len(values)
+ if vlen >= 8:
+ # Linux >= 2.6.11
+ fields.append('steal')
+ if vlen >= 9:
+ # Linux >= 2.6.24
+ fields.append('guest')
+ if vlen >= 10:
+ # Linux >= 3.2.0
+ fields.append('guest_nice')
+ scputimes = namedtuple('scputimes', fields)
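+
+# E.g. a /proc/stat first line such as
+#   "cpu  4705 356 584 3699176 23060 0 23 0 0 0"
+# carries 10 values, so the resulting namedtuple also gains the
+# optional 'steal', 'guest' and 'guest_nice' fields.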
+
+
+try:
+ set_scputimes_ntuple("/proc")
+except Exception: # pragma: no cover
+ # Don't want to crash at import time.
+ traceback.print_exc()
+ scputimes = namedtuple('scputimes', 'user system idle')(0.0, 0.0, 0.0)
+
+
+# =====================================================================
+# --- prlimit
+# =====================================================================
+
+# Backport of resource.prlimit() for Python 2. Originally this was done
+# in C, but CentOS-6 which we use to create manylinux wheels is too old
+# and does not support prlimit() syscall. As such the resulting wheel
+# would not include prlimit(), even when installed on newer systems.
+# This is the only part of psutil using ctypes.
+
+prlimit = None
+try:
+ from resource import prlimit # python >= 3.4
+except ImportError:
+ import ctypes
+
+ libc = ctypes.CDLL(None, use_errno=True)
+
+ if hasattr(libc, "prlimit"):
+
+ def prlimit(pid, resource_, limits=None):
+ class StructRlimit(ctypes.Structure):
+ _fields_ = [('rlim_cur', ctypes.c_longlong),
+ ('rlim_max', ctypes.c_longlong)]
+
+ current = StructRlimit()
+ if limits is None:
+ # get
+ ret = libc.prlimit(pid, resource_, None, ctypes.byref(current))
+ else:
+ # set
+ new = StructRlimit()
+ new.rlim_cur = limits[0]
+ new.rlim_max = limits[1]
+ ret = libc.prlimit(
+ pid, resource_, ctypes.byref(new), ctypes.byref(current))
+
+ if ret != 0:
+ errno = ctypes.get_errno()
+ raise OSError(errno, os.strerror(errno))
+ return (current.rlim_cur, current.rlim_max)
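+
+# Usage mirrors resource.prlimit() from Python >= 3.4 (sketch; "some_pid"
+# is a hypothetical PID, RLIMIT_NOFILE comes from the stdlib resource
+# module or the RLIM* constants re-exported below):
+#   soft, hard = prlimit(some_pid, resource.RLIMIT_NOFILE)      # get
+#   prlimit(some_pid, resource.RLIMIT_NOFILE, (1024, 4096))     # set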
+
+
+if prlimit is not None:
+ __extra__all__.extend(
+ [x for x in dir(cext) if x.startswith('RLIM') and x.isupper()])
+
+
+# =====================================================================
+# --- system memory
+# =====================================================================
+
+
+def calculate_avail_vmem(mems):
+ """Fallback for kernels < 3.14 where /proc/meminfo does not provide
+ "MemAvailable:" column, see:
+ https://blog.famzah.net/2014/09/24/
+ This code reimplements the algorithm outlined here:
+ https://git.kernel.org/cgit/linux/kernel/git/torvalds/linux.git/
+ commit/?id=34e431b0ae398fc54ea69ff85ec700722c9da773
+
+ XXX: on recent kernels this calculation differs by ~1.5% from
+ "MemAvailable:" as it is calculated slightly differently, see:
+ https://gitlab.com/procps-ng/procps/issues/42
+ https://github.com/famzah/linux-memavailable-procfs/issues/2
+ It is still way more realistic than doing (free + cached) though.
+ """
+ # Fallback for very old distros. According to
+ # https://git.kernel.org/cgit/linux/kernel/git/torvalds/linux.git/
+ # commit/?id=34e431b0ae398fc54ea69ff85ec700722c9da773
+ # ...long ago "avail" was calculated as (free + cached).
+ # We may fall back to that in such cases:
+ # "Active(file)" not available: 2.6.28 / Dec 2008
+ # "Inactive(file)" not available: 2.6.28 / Dec 2008
+ # "SReclaimable:" not available: 2.6.19 / Nov 2006
+ # /proc/zoneinfo not available: 2.6.13 / Aug 2005
+ free = mems[b'MemFree:']
+ fallback = free + mems.get(b"Cached:", 0)
+ try:
+ lru_active_file = mems[b'Active(file):']
+ lru_inactive_file = mems[b'Inactive(file):']
+ slab_reclaimable = mems[b'SReclaimable:']
+ except KeyError:
+ return fallback
+ try:
+ f = open_binary('%s/zoneinfo' % get_procfs_path())
+ except IOError:
+ return fallback # kernel 2.6.13
+
+ watermark_low = 0
+ with f:
+ for line in f:
+ line = line.strip()
+ if line.startswith(b'low'):
+ watermark_low += int(line.split()[1])
+ watermark_low *= PAGESIZE
+
+ avail = free - watermark_low
+ pagecache = lru_active_file + lru_inactive_file
+ pagecache -= min(pagecache / 2, watermark_low)
+ avail += pagecache
+ avail += slab_reclaimable - min(slab_reclaimable / 2.0, watermark_low)
+ return int(avail)
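+
+# A worked example with hypothetical values (in MB): free=500,
+# watermark_low=10, Active(file)=200, Inactive(file)=100,
+# SReclaimable=50 gives
+#   avail = (500 - 10) + (300 - min(150, 10)) + (50 - min(25, 10))
+#         = 490 + 290 + 40 = 820 MB.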
+
+
+def virtual_memory():
+ """Report virtual memory stats.
+ This implementation matches "free" and "vmstat -s" cmdline
+ utility values; the procps-ng-3.3.12 source was used as a reference
+ (2016-09-18):
+ https://gitlab.com/procps-ng/procps/blob/
+ 24fd2605c51fccc375ab0287cec33aa767f06718/proc/sysinfo.c
+ For reference, procps-ng-3.3.10 is the version available on Ubuntu
+ 16.04.
+
+ Note about "available" memory: up until psutil 4.3 it was
+ calculated as "avail = (free + buffers + cached)". Now
+ "MemAvailable:" column (kernel 3.14) from /proc/meminfo is used as
+ it's more accurate.
+ That matches "available" column in newer versions of "free".
+ """
+ missing_fields = []
+ mems = {}
+ with open_binary('%s/meminfo' % get_procfs_path()) as f:
+ for line in f:
+ fields = line.split()
+ mems[fields[0]] = int(fields[1]) * 1024
+
+ # /proc doc states that the available fields in /proc/meminfo vary
+ # by architecture and compile options, but these 3 values are also
+ # returned by sysinfo(2); as such we assume they are always there.
+ total = mems[b'MemTotal:']
+ free = mems[b'MemFree:']
+ try:
+ buffers = mems[b'Buffers:']
+ except KeyError:
+ # https://github.com/giampaolo/psutil/issues/1010
+ buffers = 0
+ missing_fields.append('buffers')
+ try:
+ cached = mems[b"Cached:"]
+ except KeyError:
+ cached = 0
+ missing_fields.append('cached')
+ else:
+ # "free" cmdline utility sums reclaimable to cached.
+ # Older versions of procps used to add slab memory instead.
+ # This got changed in:
+ # https://gitlab.com/procps-ng/procps/commit/
+ # 05d751c4f076a2f0118b914c5e51cfbb4762ad8e
+ cached += mems.get(b"SReclaimable:", 0) # since kernel 2.6.19
+
+ try:
+ shared = mems[b'Shmem:'] # since kernel 2.6.32
+ except KeyError:
+ try:
+ shared = mems[b'MemShared:'] # kernels 2.4
+ except KeyError:
+ shared = 0
+ missing_fields.append('shared')
+
+ try:
+ active = mems[b"Active:"]
+ except KeyError:
+ active = 0
+ missing_fields.append('active')
+
+ try:
+ inactive = mems[b"Inactive:"]
+ except KeyError:
+ try:
+ inactive = \
+ mems[b"Inact_dirty:"] + \
+ mems[b"Inact_clean:"] + \
+ mems[b"Inact_laundry:"]
+ except KeyError:
+ inactive = 0
+ missing_fields.append('inactive')
+
+ try:
+ slab = mems[b"Slab:"]
+ except KeyError:
+ slab = 0
+
+ used = total - free - cached - buffers
+ if used < 0:
+ # May be symptomatic of running within an LXC container where such
+ # values will be dramatically distorted compared to those of the host.
+ used = total - free
+
+ # - starting from 4.4.0 we match free's "available" column.
+ # Before 4.4.0 we calculated it as (free + buffers + cached)
+ # which matched htop.
+ # - free and htop available memory differs as per:
+ # http://askubuntu.com/a/369589
+ # http://unix.stackexchange.com/a/65852/168884
+ # - MemAvailable has been introduced in kernel 3.14
+ try:
+ avail = mems[b'MemAvailable:']
+ except KeyError:
+ avail = calculate_avail_vmem(mems)
+
+ if avail < 0:
+ avail = 0
+ missing_fields.append('available')
+
+ # If avail is greater than total or our calculation overflows,
+ # that's symptomatic of running within an LXC container where such
+ # values will be dramatically distorted compared to those of the host.
+ # https://gitlab.com/procps-ng/procps/blob/
+ # 24fd2605c51fccc375ab0287cec33aa767f06718/proc/sysinfo.c#L764
+ if avail > total:
+ avail = free
+
+ percent = usage_percent((total - avail), total, round_=1)
+
+ # Warn about missing metrics which are set to 0.
+ if missing_fields:
+ msg = "%s memory stats couldn't be determined and %s set to 0" % (
+ ", ".join(missing_fields),
+ "was" if len(missing_fields) == 1 else "were")
+ warnings.warn(msg, RuntimeWarning)
+
+ return svmem(total, avail, percent, used, free,
+ active, inactive, buffers, cached, shared, slab)
+
+
+def swap_memory():
+ """Return swap memory metrics."""
+ mems = {}
+ with open_binary('%s/meminfo' % get_procfs_path()) as f:
+ for line in f:
+ fields = line.split()
+ mems[fields[0]] = int(fields[1]) * 1024
+ # We prefer /proc/meminfo over the sysinfo() syscall so that
+ # psutil.PROCFS_PATH can be used, allowing retrieval from within
+ # Linux containers, see:
+ # https://github.com/giampaolo/psutil/issues/1015
+ try:
+ total = mems[b'SwapTotal:']
+ free = mems[b'SwapFree:']
+ except KeyError:
+ _, _, _, _, total, free, unit_multiplier = cext.linux_sysinfo()
+ total *= unit_multiplier
+ free *= unit_multiplier
+
+ used = total - free
+ percent = usage_percent(used, total, round_=1)
+ # get pgin/pgouts
+ try:
+ f = open_binary("%s/vmstat" % get_procfs_path())
+ except IOError as err:
+ # see https://github.com/giampaolo/psutil/issues/722
+ msg = "'sin' and 'sout' swap memory stats couldn't " \
+ "be determined and were set to 0 (%s)" % str(err)
+ warnings.warn(msg, RuntimeWarning)
+ sin = sout = 0
+ else:
+ with f:
+ sin = sout = None
+ for line in f:
+ # values are expressed in 4 kB units; we want
+ # bytes instead
+ if line.startswith(b'pswpin'):
+ sin = int(line.split(b' ')[1]) * 4 * 1024
+ elif line.startswith(b'pswpout'):
+ sout = int(line.split(b' ')[1]) * 4 * 1024
+ if sin is not None and sout is not None:
+ break
+ else:
+ # we might get here when dealing with exotic Linux
+ # flavors, see:
+ # https://github.com/giampaolo/psutil/issues/313
+ msg = "'sin' and 'sout' swap memory stats couldn't " \
+ "be determined and were set to 0"
+ warnings.warn(msg, RuntimeWarning)
+ sin = sout = 0
+ return _common.sswap(total, used, free, percent, sin, sout)
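+
+# E.g. a /proc/vmstat line reading "pswpin 2048" translates to
+# sin = 2048 * 4 * 1024 = 8388608 bytes, since the counter is
+# expressed in 4 kB pages.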
+
+
+# =====================================================================
+# --- CPU
+# =====================================================================
+
+
+def cpu_times():
+ """Return a named tuple representing the following system-wide
+ CPU times:
+ (user, nice, system, idle, iowait, irq, softirq [steal, [guest,
+ [guest_nice]]])
+ Last 3 fields may not be available on all Linux kernel versions.
+ """
+ procfs_path = get_procfs_path()
+ set_scputimes_ntuple(procfs_path)
+ with open_binary('%s/stat' % procfs_path) as f:
+ values = f.readline().split()
+ fields = values[1:len(scputimes._fields) + 1]
+ fields = [float(x) / CLOCK_TICKS for x in fields]
+ return scputimes(*fields)
+
+
+def per_cpu_times():
+ """Return a list of namedtuple representing the CPU times
+ for every CPU available on the system.
+ """
+ procfs_path = get_procfs_path()
+ set_scputimes_ntuple(procfs_path)
+ cpus = []
+ with open_binary('%s/stat' % procfs_path) as f:
+ # get rid of the first line which refers to system wide CPU stats
+ f.readline()
+ for line in f:
+ if line.startswith(b'cpu'):
+ values = line.split()
+ fields = values[1:len(scputimes._fields) + 1]
+ fields = [float(x) / CLOCK_TICKS for x in fields]
+ entry = scputimes(*fields)
+ cpus.append(entry)
+ return cpus
+
+
+def cpu_count_logical():
+ """Return the number of logical CPUs in the system."""
+ try:
+ return os.sysconf("SC_NPROCESSORS_ONLN")
+ except ValueError:
+ # as a second fallback we try to parse /proc/cpuinfo
+ num = 0
+ with open_binary('%s/cpuinfo' % get_procfs_path()) as f:
+ for line in f:
+ if line.lower().startswith(b'processor'):
+ num += 1
+
+ # unknown format (e.g. armel/sparc architectures), see:
+ # https://github.com/giampaolo/psutil/issues/200
+ # try to parse /proc/stat as a last resort
+ if num == 0:
+ search = re.compile(r'cpu\d')
+ with open_text('%s/stat' % get_procfs_path()) as f:
+ for line in f:
+ line = line.split(' ')[0]
+ if search.match(line):
+ num += 1
+
+ if num == 0:
+ # mimic os.cpu_count()
+ return None
+ return num
+
+
+def cpu_count_cores():
+ """Return the number of CPU cores in the system."""
+ # Method #1
+ ls = set()
+ # These 2 files are the same but */core_cpus_list is newer while
+ # */thread_siblings_list is deprecated and may disappear in the future.
+ # https://www.kernel.org/doc/Documentation/admin-guide/cputopology.rst
+ # https://github.com/giampaolo/psutil/pull/1727#issuecomment-707624964
+ # https://lkml.org/lkml/2019/2/26/41
+ p1 = "/sys/devices/system/cpu/cpu[0-9]*/topology/core_cpus_list"
+ p2 = "/sys/devices/system/cpu/cpu[0-9]*/topology/thread_siblings_list"
+ for path in glob.glob(p1) or glob.glob(p2):
+ with open_binary(path) as f:
+ ls.add(f.read().strip())
+ result = len(ls)
+ if result != 0:
+ return result
+
+ # Method #2
+ mapping = {}
+ current_info = {}
+ with open_binary('%s/cpuinfo' % get_procfs_path()) as f:
+ for line in f:
+ line = line.strip().lower()
+ if not line:
+ # new section
+ try:
+ mapping[current_info[b'physical id']] = \
+ current_info[b'cpu cores']
+ except KeyError:
+ pass
+ current_info = {}
+ else:
+ # ongoing section
+ if line.startswith((b'physical id', b'cpu cores')):
+ key, value = line.split(b'\t:', 1)
+ current_info[key] = int(value)
+
+ result = sum(mapping.values())
+ return result or None # mimic os.cpu_count()
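+
+# E.g. a /proc/cpuinfo with two sections reporting "physical id : 0" /
+# "cpu cores : 4" and two reporting "physical id : 1" / "cpu cores : 4"
+# yields mapping == {0: 4, 1: 4}, hence 8 cores.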
+
+
+def cpu_stats():
+ """Return various CPU stats as a named tuple."""
+ with open_binary('%s/stat' % get_procfs_path()) as f:
+ ctx_switches = None
+ interrupts = None
+ soft_interrupts = None
+ for line in f:
+ if line.startswith(b'ctxt'):
+ ctx_switches = int(line.split()[1])
+ elif line.startswith(b'intr'):
+ interrupts = int(line.split()[1])
+ elif line.startswith(b'softirq'):
+ soft_interrupts = int(line.split()[1])
+ if ctx_switches is not None and soft_interrupts is not None \
+ and interrupts is not None:
+ break
+ syscalls = 0
+ return _common.scpustats(
+ ctx_switches, interrupts, soft_interrupts, syscalls)
+
+
+def _cpu_get_cpuinfo_freq():
+ """Return current CPU frequency from cpuinfo if available.
+ """
+ ret = []
+ with open_binary('%s/cpuinfo' % get_procfs_path()) as f:
+ for line in f:
+ if line.lower().startswith(b'cpu mhz'):
+ ret.append(float(line.split(b':', 1)[1]))
+ return ret
+
+
+if os.path.exists("/sys/devices/system/cpu/cpufreq/policy0") or \
+ os.path.exists("/sys/devices/system/cpu/cpu0/cpufreq"):
+ def cpu_freq():
+ """Return frequency metrics for all CPUs.
+ Contrary to other OSes, Linux updates these values in
+ real time.
+ """
+ cpuinfo_freqs = _cpu_get_cpuinfo_freq()
+ paths = \
+ glob.glob("/sys/devices/system/cpu/cpufreq/policy[0-9]*") or \
+ glob.glob("/sys/devices/system/cpu/cpu[0-9]*/cpufreq")
+ paths.sort(key=lambda x: int(re.search(r"[0-9]+", x).group()))
+ ret = []
+ pjoin = os.path.join
+ for i, path in enumerate(paths):
+ if len(paths) == len(cpuinfo_freqs):
+ # take cached value from cpuinfo if available, see:
+ # https://github.com/giampaolo/psutil/issues/1851
+ curr = cpuinfo_freqs[i] * 1000
+ else:
+ curr = bcat(pjoin(path, "scaling_cur_freq"), fallback=None)
+ if curr is None:
+ # Likely an old RedHat, see:
+ # https://github.com/giampaolo/psutil/issues/1071
+ curr = bcat(pjoin(path, "cpuinfo_cur_freq"), fallback=None)
+ if curr is None:
+ raise NotImplementedError(
+ "can't find current frequency file")
+ curr = int(curr) / 1000
+ max_ = int(bcat(pjoin(path, "scaling_max_freq"))) / 1000
+ min_ = int(bcat(pjoin(path, "scaling_min_freq"))) / 1000
+ ret.append(_common.scpufreq(curr, min_, max_))
+ return ret
+
+else:
+ def cpu_freq():
+ """Alternate implementation using /proc/cpuinfo.
+ min and max frequencies are not available and are set to 0.0.
+ """
+ return [_common.scpufreq(x, 0., 0.) for x in _cpu_get_cpuinfo_freq()]
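+
+# In both variants the resulting frequencies are expressed in MHz, e.g. a
+# scaling_cur_freq of "2400000" (kHz) becomes 2400.0, matching the
+# "cpu MHz" values reported by /proc/cpuinfo.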
+
+
+# =====================================================================
+# --- network
+# =====================================================================
+
+
+net_if_addrs = cext_posix.net_if_addrs
+
+
+class _Ipv6UnsupportedError(Exception):
+ pass
+
+
+class Connections:
+ """A wrapper on top of /proc/net/* files, retrieving per-process
+ and system-wide open connections (TCP, UDP, UNIX) similarly to
+ "netstat -an".
+
+ Note: in case of UNIX sockets we're only able to determine the
+ local endpoint/path, not the one it's connected to.
+ According to [1] it would be possible but not easily.
+
+ [1] http://serverfault.com/a/417946
+ """
+
+ def __init__(self):
+ # The string represents the basename of the corresponding
+ # /proc/net/{proto_name} file.
+ tcp4 = ("tcp", socket.AF_INET, socket.SOCK_STREAM)
+ tcp6 = ("tcp6", socket.AF_INET6, socket.SOCK_STREAM)
+ udp4 = ("udp", socket.AF_INET, socket.SOCK_DGRAM)
+ udp6 = ("udp6", socket.AF_INET6, socket.SOCK_DGRAM)
+ unix = ("unix", socket.AF_UNIX, None)
+ self.tmap = {
+ "all": (tcp4, tcp6, udp4, udp6, unix),
+ "tcp": (tcp4, tcp6),
+ "tcp4": (tcp4,),
+ "tcp6": (tcp6,),
+ "udp": (udp4, udp6),
+ "udp4": (udp4,),
+ "udp6": (udp6,),
+ "unix": (unix,),
+ "inet": (tcp4, tcp6, udp4, udp6),
+ "inet4": (tcp4, udp4),
+ "inet6": (tcp6, udp6),
+ }
+ self._procfs_path = None
+
+ def get_proc_inodes(self, pid):
+ inodes = defaultdict(list)
+ for fd in os.listdir("%s/%s/fd" % (self._procfs_path, pid)):
+ try:
+ inode = readlink("%s/%s/fd/%s" % (self._procfs_path, pid, fd))
+ except (FileNotFoundError, ProcessLookupError):
+ # ENOENT == file which is gone in the meantime;
+ # os.stat('/proc/%s' % self.pid) will be done later
+ # to force NSP (if it's the case)
+ continue
+ except OSError as err:
+ if err.errno == errno.EINVAL:
+ # not a link
+ continue
+ if err.errno == errno.ENAMETOOLONG:
+ # file name too long
+ debug(err)
+ continue
+ raise
+ else:
+ if inode.startswith('socket:['):
+ # the process is using a socket
+ inode = inode[8:][:-1]
+ inodes[inode].append((pid, int(fd)))
+ return inodes
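+
+ # E.g. if readlink("/proc/1234/fd/5") returns "socket:[987654]"
+ # (hypothetical PID and fd), the dict returned above maps
+ # "987654" -> [(1234, 5)].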
+
+ def get_all_inodes(self):
+ inodes = {}
+ for pid in pids():
+ try:
+ inodes.update(self.get_proc_inodes(pid))
+ except (FileNotFoundError, ProcessLookupError, PermissionError):
+ # os.listdir() is gonna raise a lot of access denied
+ # exceptions in case of unprivileged user; that's fine
+ # as we'll just end up returning a connection with PID
+ # and fd set to None anyway.
+ # Both netstat -an and lsof do the same, so it's
+ # unlikely we can do any better.
+ # ENOENT just means a PID disappeared on us.
+ continue
+ return inodes
+
+ @staticmethod
+ def decode_address(addr, family):
+ """Accept an "ip:port" address as displayed in /proc/net/*
+ and convert it into a human readable form, like:
+
+ "0500000A:0016" -> ("10.0.0.5", 22)
+ "0000000000000000FFFF00000100007F:9E49" -> ("::ffff:127.0.0.1", 40521)
+
+ The IP address portion is a four-byte hexadecimal number whose byte
+ order depends on the host endianness; on little endian machines the
+ least significant byte is listed first, so we need to reverse the
+ byte order to convert it to an IP address.
+ The port is represented as a two-byte hexadecimal number.
+
+ Reference:
+ http://linuxdevcenter.com/pub/a/linux/2000/11/16/LinuxAdmin.html
+ """
+ ip, port = addr.split(':')
+ port = int(port, 16)
+ # this usually refers to a local socket in listen mode with
+ # no end-points connected
+ if not port:
+ return ()
+ if PY3:
+ ip = ip.encode('ascii')
+ if family == socket.AF_INET:
+ # see: https://github.com/giampaolo/psutil/issues/201
+ if LITTLE_ENDIAN:
+ ip = socket.inet_ntop(family, base64.b16decode(ip)[::-1])
+ else:
+ ip = socket.inet_ntop(family, base64.b16decode(ip))
+ else: # IPv6
+ ip = base64.b16decode(ip)
+ try:
+ # see: https://github.com/giampaolo/psutil/issues/201
+ if LITTLE_ENDIAN:
+ ip = socket.inet_ntop(
+ socket.AF_INET6,
+ struct.pack('>4I', *struct.unpack('<4I', ip)))
+ else:
+ ip = socket.inet_ntop(
+ socket.AF_INET6,
+ struct.pack('<4I', *struct.unpack('<4I', ip)))
+ except ValueError:
+ # see: https://github.com/giampaolo/psutil/issues/623
+ if not supports_ipv6():
+ raise _Ipv6UnsupportedError
+ else:
+ raise
+ return _common.addr(ip, port)
+
+ @staticmethod
+ def process_inet(file, family, type_, inodes, filter_pid=None):
+ """Parse /proc/net/tcp* and /proc/net/udp* files."""
+ if file.endswith('6') and not os.path.exists(file):
+ # IPv6 not supported
+ return
+ with open_text(file) as f:
+ f.readline() # skip the first line
+ for lineno, line in enumerate(f, 1):
+ try:
+ _, laddr, raddr, status, _, _, _, _, _, inode = \
+ line.split()[:10]
+ except ValueError:
+ raise RuntimeError(
+ "error while parsing %s; malformed line %s %r" % (
+ file, lineno, line))
+ if inode in inodes:
+ # # We assume inet sockets are unique, so we error
+ # # out if there are multiple references to the
+ # # same inode. We won't do this for UNIX sockets.
+ # if len(inodes[inode]) > 1 and family != socket.AF_UNIX:
+ # raise ValueError("ambiguous inode with multiple "
+ # "PIDs references")
+ pid, fd = inodes[inode][0]
+ else:
+ pid, fd = None, -1
+ if filter_pid is not None and filter_pid != pid:
+ continue
+ else:
+ if type_ == socket.SOCK_STREAM:
+ status = TCP_STATUSES[status]
+ else:
+ status = _common.CONN_NONE
+ try:
+ laddr = Connections.decode_address(laddr, family)
+ raddr = Connections.decode_address(raddr, family)
+ except _Ipv6UnsupportedError:
+ continue
+ yield (fd, family, type_, laddr, raddr, status, pid)
+
+ @staticmethod
+ def process_unix(file, family, inodes, filter_pid=None):
+ """Parse /proc/net/unix files."""
+ with open_text(file) as f:
+ f.readline() # skip the first line
+ for line in f:
+ tokens = line.split()
+ try:
+ _, _, _, _, type_, _, inode = tokens[0:7]
+ except ValueError:
+ if ' ' not in line:
+ # see: https://github.com/giampaolo/psutil/issues/766
+ continue
+ raise RuntimeError(
+ "error while parsing %s; malformed line %r" % (
+ file, line))
+ if inode in inodes:
+ # With UNIX sockets we can have a single inode
+ # referencing many file descriptors.
+ pairs = inodes[inode]
+ else:
+ pairs = [(None, -1)]
+ for pid, fd in pairs:
+ if filter_pid is not None and filter_pid != pid:
+ continue
+ else:
+ if len(tokens) == 8:
+ path = tokens[-1]
+ else:
+ path = ""
+ type_ = _common.socktype_to_enum(int(type_))
+ # XXX: determining the remote endpoint of a
+ # UNIX socket on Linux is not possible, see:
+ # https://serverfault.com/questions/252723/
+ raddr = ""
+ status = _common.CONN_NONE
+ yield (fd, family, type_, path, raddr, status, pid)
+
+ def retrieve(self, kind, pid=None):
+ if kind not in self.tmap:
+ raise ValueError("invalid %r kind argument; choose between %s"
+ % (kind, ', '.join([repr(x) for x in self.tmap])))
+ self._procfs_path = get_procfs_path()
+ if pid is not None:
+ inodes = self.get_proc_inodes(pid)
+ if not inodes:
+ # no connections for this process
+ return []
+ else:
+ inodes = self.get_all_inodes()
+ ret = set()
+ for proto_name, family, type_ in self.tmap[kind]:
+ path = "%s/net/%s" % (self._procfs_path, proto_name)
+ if family in (socket.AF_INET, socket.AF_INET6):
+ ls = self.process_inet(
+ path, family, type_, inodes, filter_pid=pid)
+ else:
+ ls = self.process_unix(
+ path, family, inodes, filter_pid=pid)
+ for fd, family, type_, laddr, raddr, status, bound_pid in ls:
+ if pid:
+ conn = _common.pconn(fd, family, type_, laddr, raddr,
+ status)
+ else:
+ conn = _common.sconn(fd, family, type_, laddr, raddr,
+ status, bound_pid)
+ ret.add(conn)
+ return list(ret)
+
+
+_connections = Connections()
+
+
+def net_connections(kind='inet'):
+ """Return system-wide open connections."""
+ return _connections.retrieve(kind)
+
+
+def net_io_counters():
+ """Return network I/O statistics for every network interface
+ installed on the system as a dict of raw tuples.
+ """
+ with open_text("%s/net/dev" % get_procfs_path()) as f:
+ lines = f.readlines()
+ retdict = {}
+ for line in lines[2:]:
+ colon = line.rfind(':')
+ assert colon > 0, repr(line)
+ name = line[:colon].strip()
+ fields = line[colon + 1:].strip().split()
+
+ # in
+ (bytes_recv,
+ packets_recv,
+ errin,
+ dropin,
+ fifoin, # unused
+ framein, # unused
+ compressedin, # unused
+ multicastin, # unused
+ # out
+ bytes_sent,
+ packets_sent,
+ errout,
+ dropout,
+ fifoout, # unused
+ collisionsout, # unused
+ carrierout, # unused
+ compressedout) = map(int, fields)
+
+ retdict[name] = (bytes_sent, bytes_recv, packets_sent, packets_recv,
+ errin, errout, dropin, dropout)
+ return retdict
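+
+# E.g. a /proc/net/dev line (hypothetical values)
+#   "eth0: 1234 10 0 0 0 0 0 0 5678 9 0 0 0 0 0 0"
+# becomes retdict["eth0"] == (5678, 1234, 9, 10, 0, 0, 0, 0), i.e.
+# (bytes_sent, bytes_recv, packets_sent, packets_recv, errin, errout,
+#  dropin, dropout).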
+
+
+def net_if_stats():
+ """Get NIC stats (isup, duplex, speed, mtu)."""
+ duplex_map = {cext.DUPLEX_FULL: NIC_DUPLEX_FULL,
+ cext.DUPLEX_HALF: NIC_DUPLEX_HALF,
+ cext.DUPLEX_UNKNOWN: NIC_DUPLEX_UNKNOWN}
+ names = net_io_counters().keys()
+ ret = {}
+ for name in names:
+ try:
+ mtu = cext_posix.net_if_mtu(name)
+ flags = cext_posix.net_if_flags(name)
+ duplex, speed = cext.net_if_duplex_speed(name)
+ except OSError as err:
+ # https://github.com/giampaolo/psutil/issues/1279
+ if err.errno != errno.ENODEV:
+ raise
+ else:
+ debug(err)
+ else:
+ output_flags = ','.join(flags)
+ isup = 'running' in flags
+ ret[name] = _common.snicstats(isup, duplex_map[duplex], speed, mtu,
+ output_flags)
+ return ret
+
+
+# =====================================================================
+# --- disks
+# =====================================================================
+
+
+disk_usage = _psposix.disk_usage
+
+
+def disk_io_counters(perdisk=False):
+ """Return disk I/O statistics for every disk installed on the
+ system as a dict of raw tuples.
+ """
+ def read_procfs():
+ # OK, this is a bit confusing. The format of /proc/diskstats can
+ # have 3 variations.
+ # On Linux 2.4 each line has always 15 fields, e.g.:
+ # "3 0 8 hda 8 8 8 8 8 8 8 8 8 8 8"
+ # On Linux 2.6+ each line *usually* has 14 fields, and the disk
+ # name is in another position, like this:
+ # "3 0 hda 8 8 8 8 8 8 8 8 8 8 8"
+ # ...unless (Linux 2.6) the line refers to a partition instead
+ # of a disk, in which case the line has fewer fields (7):
+ # "3 1 hda1 8 8 8 8"
+ # 4.18+ has 4 fields added:
+ # "3 0 hda 8 8 8 8 8 8 8 8 8 8 8 0 0 0 0"
+ # 5.5 has 2 more fields.
+ # See:
+ # https://www.kernel.org/doc/Documentation/iostats.txt
+ # https://www.kernel.org/doc/Documentation/ABI/testing/procfs-diskstats
+ with open_text("%s/diskstats" % get_procfs_path()) as f:
+ lines = f.readlines()
+ for line in lines:
+ fields = line.split()
+ flen = len(fields)
+ if flen == 15:
+ # Linux 2.4
+ name = fields[3]
+ reads = int(fields[2])
+ (reads_merged, rbytes, rtime, writes, writes_merged,
+ wbytes, wtime, _, busy_time, _) = map(int, fields[4:14])
+ elif flen == 14 or flen >= 18:
+ # Linux 2.6+, line referring to a disk
+ name = fields[2]
+ (reads, reads_merged, rbytes, rtime, writes, writes_merged,
+ wbytes, wtime, _, busy_time, _) = map(int, fields[3:14])
+ elif flen == 7:
+ # Linux 2.6+, line referring to a partition
+ name = fields[2]
+ reads, rbytes, writes, wbytes = map(int, fields[3:])
+ rtime = wtime = reads_merged = writes_merged = busy_time = 0
+ else:
+ raise ValueError("not sure how to interpret line %r" % line)
+ yield (name, reads, writes, rbytes, wbytes, rtime, wtime,
+ reads_merged, writes_merged, busy_time)
+
+ def read_sysfs():
+ for block in os.listdir('/sys/block'):
+ for root, _, files in os.walk(os.path.join('/sys/block', block)):
+ if 'stat' not in files:
+ continue
+ with open_text(os.path.join(root, 'stat')) as f:
+ fields = f.read().strip().split()
+ name = os.path.basename(root)
+ (reads, reads_merged, rbytes, rtime, writes, writes_merged,
+ wbytes, wtime, _, busy_time) = map(int, fields[:10])
+ yield (name, reads, writes, rbytes, wbytes, rtime,
+ wtime, reads_merged, writes_merged, busy_time)
+
+ if os.path.exists('%s/diskstats' % get_procfs_path()):
+ gen = read_procfs()
+ elif os.path.exists('/sys/block'):
+ gen = read_sysfs()
+ else:
+ raise NotImplementedError(
+ "neither %s/diskstats nor the /sys/block filesystem are "
+ "available on this system" % get_procfs_path())
+
+ retdict = {}
+ for entry in gen:
+ (name, reads, writes, rbytes, wbytes, rtime, wtime, reads_merged,
+ writes_merged, busy_time) = entry
+ if not perdisk and not is_storage_device(name):
+ # perdisk=False means we want to calculate totals so we skip
+ # partitions (e.g. 'sda1', 'nvme0n1p1') and only include
+ # base disk devices (e.g. 'sda', 'nvme0n1'). Base disks
+ # include a total of all their partitions + some extra size
+ # of their own:
+ # $ cat /proc/diskstats
+ # 259 0 sda 10485760 ...
+ # 259 1 sda1 5186039 ...
+ # 259 1 sda2 5082039 ...
+ # See:
+ # https://github.com/giampaolo/psutil/pull/1313
+ continue
+
+ rbytes *= DISK_SECTOR_SIZE
+ wbytes *= DISK_SECTOR_SIZE
+ retdict[name] = (reads, writes, rbytes, wbytes, rtime, wtime,
+ reads_merged, writes_merged, busy_time)
+
+ return retdict
+
+
+class RootFsDeviceFinder:
+ """disk_partitions() may return partitions with device == "/dev/root"
+ or "rootfs". This container class uses different strategies to try to
+ obtain the real device path. Resources:
+ https://bootlin.com/blog/find-root-device/
+ https://www.systutorials.com/how-to-find-the-disk-where-root-is-on-in-bash-on-linux/
+ """
+ __slots__ = ['major', 'minor']
+
+ def __init__(self):
+ dev = os.stat("/").st_dev
+ self.major = os.major(dev)
+ self.minor = os.minor(dev)
+
+ def ask_proc_partitions(self):
+ with open_text("%s/partitions" % get_procfs_path()) as f:
+ for line in f.readlines()[2:]:
+ fields = line.split()
+ if len(fields) < 4: # just for extra safety
+ continue
+ major = int(fields[0]) if fields[0].isdigit() else None
+ minor = int(fields[1]) if fields[1].isdigit() else None
+ name = fields[3]
+ if major == self.major and minor == self.minor:
+ if name: # just for extra safety
+ return "/dev/%s" % name
+
+ def ask_sys_dev_block(self):
+ path = "/sys/dev/block/%s:%s/uevent" % (self.major, self.minor)
+ with open_text(path) as f:
+ for line in f:
+ if line.startswith("DEVNAME="):
+ name = line.strip().rpartition("DEVNAME=")[2]
+ if name: # just for extra safety
+ return "/dev/%s" % name
+
+ def ask_sys_class_block(self):
+ needle = "%s:%s" % (self.major, self.minor)
+ files = glob.iglob("/sys/class/block/*/dev")
+ for file in files:
+ try:
+ f = open_text(file)
+ except FileNotFoundError: # race condition
+ continue
+ else:
+ with f:
+ data = f.read().strip()
+ if data == needle:
+ name = os.path.basename(os.path.dirname(file))
+ return "/dev/%s" % name
+
+ def find(self):
+ path = None
+ if path is None:
+ try:
+ path = self.ask_proc_partitions()
+ except (IOError, OSError) as err:
+ debug(err)
+ if path is None:
+ try:
+ path = self.ask_sys_dev_block()
+ except (IOError, OSError) as err:
+ debug(err)
+ if path is None:
+ try:
+ path = self.ask_sys_class_block()
+ except (IOError, OSError) as err:
+ debug(err)
+ # We use exists() because the "/dev/*" part of the path is hard
+ # coded, so we want to be sure.
+ if path is not None and os.path.exists(path):
+ return path
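+
+# Typical (illustrative) use, mirroring disk_partitions() below:
+#   device = RootFsDeviceFinder().find() or device
+# which may resolve "/dev/root" to something like "/dev/sda1", or leave
+# the original name untouched if none of the strategies succeeds.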
+
+
+def disk_partitions(all=False):
+ """Return mounted disk partitions as a list of namedtuples."""
+ fstypes = set()
+ procfs_path = get_procfs_path()
+ with open_text("%s/filesystems" % procfs_path) as f:
+ for line in f:
+ line = line.strip()
+ if not line.startswith("nodev"):
+ fstypes.add(line.strip())
+ else:
+ # ignore all lines starting with "nodev" except "nodev zfs"
+ fstype = line.split("\t")[1]
+ if fstype == "zfs":
+ fstypes.add("zfs")
+
+ # See: https://github.com/giampaolo/psutil/issues/1307
+ if procfs_path == "/proc" and os.path.isfile('/etc/mtab'):
+ mounts_path = os.path.realpath("/etc/mtab")
+ else:
+ mounts_path = os.path.realpath("%s/self/mounts" % procfs_path)
+
+ retlist = []
+ partitions = cext.disk_partitions(mounts_path)
+ for partition in partitions:
+ device, mountpoint, fstype, opts = partition
+ if device == 'none':
+ device = ''
+ if device in ("/dev/root", "rootfs"):
+ device = RootFsDeviceFinder().find() or device
+ if not all:
+ if device == '' or fstype not in fstypes:
+ continue
+ maxfile = maxpath = None # set later
+ ntuple = _common.sdiskpart(device, mountpoint, fstype, opts,
+ maxfile, maxpath)
+ retlist.append(ntuple)
+
+ return retlist
+
+
+# =====================================================================
+# --- sensors
+# =====================================================================
+
+
+def sensors_temperatures():
+ """Return hardware (CPU and others) temperatures as a dict
+ including hardware name, label, current, max and critical
+ temperatures.
+
+ Implementation notes:
+ - /sys/class/hwmon looks like the most recent interface to
+ retrieve this info, and this implementation relies on it
+ only (old distros will probably use something else)
+ - lm-sensors on Ubuntu 16.04 relies on /sys/class/hwmon
+ - /sys/class/thermal/thermal_zone* is another one but it's more
+ difficult to parse
+ """
+ ret = collections.defaultdict(list)
+ basenames = glob.glob('/sys/class/hwmon/hwmon*/temp*_*')
+ # CentOS has an intermediate /device directory:
+ # https://github.com/giampaolo/psutil/issues/971
+ # https://github.com/nicolargo/glances/issues/1060
+ basenames.extend(glob.glob('/sys/class/hwmon/hwmon*/device/temp*_*'))
+ basenames = sorted(set([x.split('_')[0] for x in basenames]))
+
+ # Only add the coretemp hwmon entries if they're not already in
+ # /sys/class/hwmon/
+ # https://github.com/giampaolo/psutil/issues/1708
+ # https://github.com/giampaolo/psutil/pull/1648
+ basenames2 = glob.glob(
+ '/sys/devices/platform/coretemp.*/hwmon/hwmon*/temp*_*')
+ repl = re.compile('/sys/devices/platform/coretemp.*/hwmon/')
+ for name in basenames2:
+ altname = repl.sub('/sys/class/hwmon/', name)
+ if altname not in basenames:
+ basenames.append(name)
+
+ for base in basenames:
+ try:
+ path = base + '_input'
+ current = float(bcat(path)) / 1000.0
+ path = os.path.join(os.path.dirname(base), 'name')
+ unit_name = cat(path).strip()
+ except (IOError, OSError, ValueError):
+ # A lot of things can go wrong here, so let's just skip the
+ # whole entry. Sure thing is Linux's /sys/class/hwmon really
+ # is a stinky broken mess.
+ # https://github.com/giampaolo/psutil/issues/1009
+ # https://github.com/giampaolo/psutil/issues/1101
+ # https://github.com/giampaolo/psutil/issues/1129
+ # https://github.com/giampaolo/psutil/issues/1245
+ # https://github.com/giampaolo/psutil/issues/1323
+ continue
+
+ high = bcat(base + '_max', fallback=None)
+ critical = bcat(base + '_crit', fallback=None)
+ label = cat(base + '_label', fallback='').strip()
+
+ if high is not None:
+ try:
+ high = float(high) / 1000.0
+ except ValueError:
+ high = None
+ if critical is not None:
+ try:
+ critical = float(critical) / 1000.0
+ except ValueError:
+ critical = None
+
+ ret[unit_name].append((label, current, high, critical))
+
+ # Indication that no sensors were detected in /sys/class/hwmon/
+ if not basenames:
+ basenames = glob.glob('/sys/class/thermal/thermal_zone*')
+ basenames = sorted(set(basenames))
+
+ for base in basenames:
+ try:
+ path = os.path.join(base, 'temp')
+ current = float(bcat(path)) / 1000.0
+ path = os.path.join(base, 'type')
+ unit_name = cat(path).strip()
+ except (IOError, OSError, ValueError) as err:
+ debug(err)
+ continue
+
+ trip_paths = glob.glob(base + '/trip_point*')
+ trip_points = set(['_'.join(
+ os.path.basename(p).split('_')[0:3]) for p in trip_paths])
+ critical = None
+ high = None
+ for trip_point in trip_points:
+ path = os.path.join(base, trip_point + "_type")
+ trip_type = cat(path, fallback='').strip()
+ if trip_type == 'critical':
+ critical = bcat(os.path.join(base, trip_point + "_temp"),
+ fallback=None)
+ elif trip_type == 'high':
+ high = bcat(os.path.join(base, trip_point + "_temp"),
+ fallback=None)
+
+ if high is not None:
+ try:
+ high = float(high) / 1000.0
+ except ValueError:
+ high = None
+ if critical is not None:
+ try:
+ critical = float(critical) / 1000.0
+ except ValueError:
+ critical = None
+
+ ret[unit_name].append(('', current, high, critical))
+
+ return dict(ret)
+
+
+def sensors_fans():
+ """Return hardware fans info (for CPU and other peripherals) as a
+ dict including hardware label and current speed.
+
+ Implementation notes:
+ - /sys/class/hwmon looks like the most recent interface to
+ retrieve this info, and this implementation relies on it
+ only (old distros will probably use something else)
+ - lm-sensors on Ubuntu 16.04 relies on /sys/class/hwmon
+ """
+ ret = collections.defaultdict(list)
+ basenames = glob.glob('/sys/class/hwmon/hwmon*/fan*_*')
+ if not basenames:
+ # CentOS has an intermediate /device directory:
+ # https://github.com/giampaolo/psutil/issues/971
+ basenames = glob.glob('/sys/class/hwmon/hwmon*/device/fan*_*')
+
+ basenames = sorted(set([x.split('_')[0] for x in basenames]))
+ for base in basenames:
+ try:
+ current = int(bcat(base + '_input'))
+ except (IOError, OSError) as err:
+ debug(err)
+ continue
+ unit_name = cat(os.path.join(os.path.dirname(base), 'name')).strip()
+ label = cat(base + '_label', fallback='').strip()
+ ret[unit_name].append(_common.sfan(label, current))
+
+ return dict(ret)
+
+
+def sensors_battery():
+ """Return battery information.
+ Implementation note: it appears /sys/class/power_supply/BAT0/
+ directory structure may vary and provide files with the same
+ meaning but under different names, see:
+ https://github.com/giampaolo/psutil/issues/966
+ """
+ null = object()
+
+ def multi_bcat(*paths):
+ """Attempt to read the content of multiple files which may
+ not exist. If none of them exist return None.
+ """
+ for path in paths:
+ ret = bcat(path, fallback=null)
+ if ret is not null:
+ try:
+ return int(ret)
+ except ValueError:
+ return ret.strip()
+ return None
+
+ bats = [x for x in os.listdir(POWER_SUPPLY_PATH) if x.startswith('BAT') or
+ 'battery' in x.lower()]
+ if not bats:
+ return None
+ # Get the first available battery. Usually this is "BAT0", with
+ # some rare exceptions:
+ # https://github.com/giampaolo/psutil/issues/1238
+ root = os.path.join(POWER_SUPPLY_PATH, sorted(bats)[0])
+
+ # Base metrics.
+ energy_now = multi_bcat(
+ root + "/energy_now",
+ root + "/charge_now")
+ power_now = multi_bcat(
+ root + "/power_now",
+ root + "/current_now")
+ energy_full = multi_bcat(
+ root + "/energy_full",
+ root + "/charge_full")
+ time_to_empty = multi_bcat(root + "/time_to_empty_now")
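+ # Note (assumption based on the sysfs power_supply ABI): energy_*
+ # files are typically expressed in uWh and charge_* files in uAh
+ # (similarly power_now is in uW and current_now in uA). Since only
+ # ratios between matching pairs are computed below (percent and
+ # energy/power), the unit difference does not affect the results.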
+
+ # Percent. If we have energy_full the percentage will be more
+ # accurate than reading the /capacity file (float vs. int).
+ if energy_full is not None and energy_now is not None:
+ try:
+ percent = 100.0 * energy_now / energy_full
+ except ZeroDivisionError:
+ percent = 0.0
+ else:
+ percent = int(cat(root + "/capacity", fallback=-1))
+ if percent == -1:
+ return None
+
+ # Is AC power cable plugged in?
+ # Note: AC0 is not always available and sometimes (e.g. CentOS7)
+ # it's called "AC".
+ power_plugged = None
+ online = multi_bcat(
+ os.path.join(POWER_SUPPLY_PATH, "AC0/online"),
+ os.path.join(POWER_SUPPLY_PATH, "AC/online"))
+ if online is not None:
+ power_plugged = online == 1
+ else:
+ status = cat(root + "/status", fallback="").strip().lower()
+ if status == "discharging":
+ power_plugged = False
+ elif status in ("charging", "full"):
+ power_plugged = True
+
+ # Seconds left.
+ # Note to self: we may also calculate the charging ETA as per:
+ # https://github.com/thialfihar/dotfiles/blob/
+ # 013937745fd9050c30146290e8f963d65c0179e6/bin/battery.py#L55
+ if power_plugged:
+ secsleft = _common.POWER_TIME_UNLIMITED
+ elif energy_now is not None and power_now is not None:
+ try:
+ secsleft = int(energy_now / power_now * 3600)
+ except ZeroDivisionError:
+ secsleft = _common.POWER_TIME_UNKNOWN
+ elif time_to_empty is not None:
+ secsleft = int(time_to_empty * 60)
+ if secsleft < 0:
+ secsleft = _common.POWER_TIME_UNKNOWN
+ else:
+ secsleft = _common.POWER_TIME_UNKNOWN
+
+ return _common.sbattery(percent, secsleft, power_plugged)
+
+
+# =====================================================================
+# --- other system functions
+# =====================================================================
+
+
+def users():
+ """Return currently connected users as a list of namedtuples."""
+ retlist = []
+ rawlist = cext.users()
+ for item in rawlist:
+ user, tty, hostname, tstamp, user_process, pid = item
+ # note: the underlying C function includes entries about
+ # system boot, run level and others. We might want
+ # to use them in the future.
+ if not user_process:
+ continue
+ if hostname in (':0.0', ':0'):
+ hostname = 'localhost'
+ nt = _common.suser(user, tty or None, hostname, tstamp, pid)
+ retlist.append(nt)
+ return retlist
+
+
+def boot_time():
+ """Return the system boot time expressed in seconds since the epoch."""
+ global BOOT_TIME
+ path = '%s/stat' % get_procfs_path()
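+ # /proc/stat contains a line such as "btime 1571190914" (the value
+ # is illustrative): boot time expressed in seconds since the epoch.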
+ with open_binary(path) as f:
+ for line in f:
+ if line.startswith(b'btime'):
+ ret = float(line.strip().split()[1])
+ BOOT_TIME = ret
+ return ret
+ raise RuntimeError(
+ "line 'btime' not found in %s" % path)
+
+
+# =====================================================================
+# --- processes
+# =====================================================================
+
+
+def pids():
+ """Returns a list of PIDs currently running on the system."""
+ return [int(x) for x in os.listdir(b(get_procfs_path())) if x.isdigit()]
+
+
+def pid_exists(pid):
+ """Check for the existence of a unix PID. Linux TIDs are not
+ supported (always return False).
+ """
+ if not _psposix.pid_exists(pid):
+ return False
+ else:
+ # Linux apparently does not distinguish between PIDs and TIDs
+ # (thread IDs).
+ # listdir("/proc") won't show any TID (only PIDs) but
+ # os.stat("/proc/{tid}") will succeed if {tid} exists.
+ # os.kill() can also be passed a TID. This is quite confusing.
+ # Here we want to enforce this distinction and support PIDs
+ # only, see:
+ # https://github.com/giampaolo/psutil/issues/687
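+ # For example (illustrative): /proc/1234/status contains a line
+ # "Tgid:\t1234" for a process leader, while for one of its threads
+ # Tgid reports the leader's PID instead of the TID itself.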
+ try:
+ # Note: already checked that this is faster than using a
+ # regular expr. Also (a lot) faster than doing
+ # 'return pid in pids()'
+ path = "%s/%s/status" % (get_procfs_path(), pid)
+ with open_binary(path) as f:
+ for line in f:
+ if line.startswith(b"Tgid:"):
+ tgid = int(line.split()[1])
+ # If tgid and pid are the same then we're
+ # dealing with a process PID.
+ return tgid == pid
+ raise ValueError("'Tgid' line not found in %s" % path)
+ except (EnvironmentError, ValueError):
+ return pid in pids()
+
+
+def ppid_map():
+ """Obtain a {pid: ppid, ...} dict for all running processes in
+ one shot. Used to speed up Process.children().
+ """
+ ret = {}
+ procfs_path = get_procfs_path()
+ for pid in pids():
+ try:
+ with open_binary("%s/%s/stat" % (procfs_path, pid)) as f:
+ data = f.read()
+ except (FileNotFoundError, ProcessLookupError):
+ # Note: we should be able to access /stat for all processes,
+ # so it's unlikely we'll bump into EPERM, which is good.
+ pass
+ else:
+ rpar = data.rfind(b')')
+ dset = data[rpar + 2:].split()
+ ppid = int(dset[1])
+ ret[pid] = ppid
+ return ret
+
+
+def wrap_exceptions(fun):
+ """Decorator which translates bare OSError and IOError exceptions
+ into NoSuchProcess and AccessDenied.
+ """
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return fun(self, *args, **kwargs)
+ except PermissionError:
+ raise AccessDenied(self.pid, self._name)
+ except ProcessLookupError:
+ raise NoSuchProcess(self.pid, self._name)
+ except FileNotFoundError:
+ if not os.path.exists("%s/%s" % (self._procfs_path, self.pid)):
+ raise NoSuchProcess(self.pid, self._name)
+ # Note: zombies will keep existing under /proc until they're
+ # gone so there's no way to distinguish them in here.
+ raise
+ return wrapper
+
+
+class Process(object):
+ """Linux process implementation."""
+
+ __slots__ = ["pid", "_name", "_ppid", "_procfs_path", "_cache"]
+
+ def __init__(self, pid):
+ self.pid = pid
+ self._name = None
+ self._ppid = None
+ self._procfs_path = get_procfs_path()
+
+ def _assert_alive(self):
+ """Raise NSP if the process disappeared on us."""
+ # For those C functions which do not raise NSP and may return an
+ # incorrect or incomplete result.
+ os.stat('%s/%s' % (self._procfs_path, self.pid))
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _parse_stat_file(self):
+ """Parse /proc/{pid}/stat file and return a dict with various
+ process info.
+ Using "man proc" as a reference: where "man proc" refers to
+ position N always subtract 3 (e.g ppid position 4 in
+ 'man proc' == position 1 in here).
+ The return value is cached in case oneshot() ctx manager is
+ in use.
+ """
+ data = bcat("%s/%s/stat" % (self._procfs_path, self.pid))
+ # Process name is between parentheses. It can contain spaces and
+ # other parentheses. This is taken into account by looking for
+ # the first occurrence of "(" and the last occurrence of ")".
+ rpar = data.rfind(b')')
+ name = data[data.find(b'(') + 1:rpar]
+ fields = data[rpar + 2:].split()
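+ # Illustrative example: for data == b'1 (systemd) S 0 1 ...' we get
+ # name == b'systemd', fields[0] == b'S' (status) and
+ # fields[1] == b'0' (ppid).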
+
+ ret = {}
+ ret['name'] = name
+ ret['status'] = fields[0]
+ ret['ppid'] = fields[1]
+ ret['ttynr'] = fields[4]
+ ret['utime'] = fields[11]
+ ret['stime'] = fields[12]
+ ret['children_utime'] = fields[13]
+ ret['children_stime'] = fields[14]
+ ret['create_time'] = fields[19]
+ ret['cpu_num'] = fields[36]
+ ret['blkio_ticks'] = fields[39] # aka 'delayacct_blkio_ticks'
+
+ return ret
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _read_status_file(self):
+ """Read /proc/{pid}/stat file and return its content.
+ The return value is cached in case oneshot() ctx manager is
+ in use.
+ """
+ with open_binary("%s/%s/status" % (self._procfs_path, self.pid)) as f:
+ return f.read()
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _read_smaps_file(self):
+ with open_binary("%s/%s/smaps" % (self._procfs_path, self.pid)) as f:
+ return f.read().strip()
+
+ def oneshot_enter(self):
+ self._parse_stat_file.cache_activate(self)
+ self._read_status_file.cache_activate(self)
+ self._read_smaps_file.cache_activate(self)
+
+ def oneshot_exit(self):
+ self._parse_stat_file.cache_deactivate(self)
+ self._read_status_file.cache_deactivate(self)
+ self._read_smaps_file.cache_deactivate(self)
+
+ @wrap_exceptions
+ def name(self):
+ name = self._parse_stat_file()['name']
+ if PY3:
+ name = decode(name)
+ # XXX - gets changed later and probably needs refactoring
+ return name
+
+ def exe(self):
+ try:
+ return readlink("%s/%s/exe" % (self._procfs_path, self.pid))
+ except (FileNotFoundError, ProcessLookupError):
+ # no such file error; it may also be raised if the
+ # path actually exists for system processes with
+ # low pids (about 0-20)
+ if os.path.lexists("%s/%s" % (self._procfs_path, self.pid)):
+ return ""
+ else:
+ if not pid_exists(self.pid):
+ raise NoSuchProcess(self.pid, self._name)
+ else:
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ except PermissionError:
+ raise AccessDenied(self.pid, self._name)
+
+ @wrap_exceptions
+ def cmdline(self):
+ with open_text("%s/%s/cmdline" % (self._procfs_path, self.pid)) as f:
+ data = f.read()
+ if not data:
+ # may happen in case of zombie process
+ return []
+ # 'man proc' states that args are separated by null bytes '\0'
+ # and the last char is supposed to be a null byte. Nevertheless
+ # some processes may change their cmdline after being started
+ # (via setproctitle() or similar); such processes are usually
+ # not compliant with this rule and use spaces instead. The
+ # Google Chrome process is an example. See:
+ # https://github.com/giampaolo/psutil/issues/1179
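+ # Illustrative example: data == 'cat\x00/etc/fstab\x00' is split on
+ # '\x00' into ['cat', '/etc/fstab'].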
+ sep = '\x00' if data.endswith('\x00') else ' '
+ if data.endswith(sep):
+ data = data[:-1]
+ cmdline = data.split(sep)
+ # Sometimes last char is a null byte '\0' but the args are
+ # separated by spaces, see: https://github.com/giampaolo/psutil/
+ # issues/1179#issuecomment-552984549
+ if sep == '\x00' and len(cmdline) == 1 and ' ' in data:
+ cmdline = data.split(' ')
+ return cmdline
+
+ @wrap_exceptions
+ def environ(self):
+ with open_text("%s/%s/environ" % (self._procfs_path, self.pid)) as f:
+ data = f.read()
+ return parse_environ_block(data)
+
+ @wrap_exceptions
+ def terminal(self):
+ tty_nr = int(self._parse_stat_file()['ttynr'])
+ tmap = _psposix.get_terminal_map()
+ try:
+ return tmap[tty_nr]
+ except KeyError:
+ return None
+
+ # May not be available on old kernels.
+ if os.path.exists('/proc/%s/io' % os.getpid()):
+ @wrap_exceptions
+ def io_counters(self):
+ fname = "%s/%s/io" % (self._procfs_path, self.pid)
+ fields = {}
+ with open_binary(fname) as f:
+ for line in f:
+ # https://github.com/giampaolo/psutil/issues/1004
+ line = line.strip()
+ if line:
+ try:
+ name, value = line.split(b': ')
+ except ValueError:
+ # https://github.com/giampaolo/psutil/issues/1004
+ continue
+ else:
+ fields[name] = int(value)
+ if not fields:
+ raise RuntimeError("%s file was empty" % fname)
+ try:
+ return pio(
+ fields[b'syscr'], # read syscalls
+ fields[b'syscw'], # write syscalls
+ fields[b'read_bytes'], # read bytes
+ fields[b'write_bytes'], # write bytes
+ fields[b'rchar'], # read chars
+ fields[b'wchar'], # write chars
+ )
+ except KeyError as err:
+ raise ValueError("%r field was not found in %s; found fields "
+ "are %r" % (err[0], fname, fields))
+
+ @wrap_exceptions
+ def cpu_times(self):
+ values = self._parse_stat_file()
+ utime = float(values['utime']) / CLOCK_TICKS
+ stime = float(values['stime']) / CLOCK_TICKS
+ children_utime = float(values['children_utime']) / CLOCK_TICKS
+ children_stime = float(values['children_stime']) / CLOCK_TICKS
+ iowait = float(values['blkio_ticks']) / CLOCK_TICKS
+ return pcputimes(utime, stime, children_utime, children_stime, iowait)
+
+ @wrap_exceptions
+ def cpu_num(self):
+ """What CPU the process is on."""
+ return int(self._parse_stat_file()['cpu_num'])
+
+ @wrap_exceptions
+ def wait(self, timeout=None):
+ return _psposix.wait_pid(self.pid, timeout, self._name)
+
+ @wrap_exceptions
+ def create_time(self):
+ ctime = float(self._parse_stat_file()['create_time'])
+ # According to the documentation, starttime is in field 21 and the
+ # unit is jiffies (clock ticks).
+ # We first divide it by clock ticks and then add the system boot
+ # time, returning seconds since the epoch.
+ # Also use the cached BOOT_TIME value if available.
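+ # Worked example with illustrative numbers: starttime == 7435712
+ # ticks and CLOCK_TICKS == 100 give 74357.12 seconds after boot,
+ # so with a boot time of 1571190914.0 the result is 1571265271.12.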
+ bt = BOOT_TIME or boot_time()
+ return (ctime / CLOCK_TICKS) + bt
+
+ @wrap_exceptions
+ def memory_info(self):
+ # ============================================================
+ # | FIELD | DESCRIPTION | AKA | TOP |
+ # ============================================================
+ # | rss | resident set size | | RES |
+ # | vms | total program size | size | VIRT |
+ # | shared | shared pages (from shared mappings) | | SHR |
+ # | text | text ('code') | trs | CODE |
+ # | lib | library (unused in Linux 2.6) | lrs | |
+ # | data | data + stack | drs | DATA |
+ # | dirty | dirty pages (unused in Linux 2.6) | dt | |
+ # ============================================================
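+ # /proc/{pid}/statm holds 7 space-separated integers expressed in
+ # pages, e.g. "2040 226 177 9 0 89 0" (illustrative values), which
+ # are converted to bytes by multiplying by PAGESIZE.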
+ with open_binary("%s/%s/statm" % (self._procfs_path, self.pid)) as f:
+ vms, rss, shared, text, lib, data, dirty = \
+ [int(x) * PAGESIZE for x in f.readline().split()[:7]]
+ return pmem(rss, vms, shared, text, lib, data, dirty)
+
+ if HAS_PROC_SMAPS_ROLLUP or HAS_PROC_SMAPS:
+
+ @wrap_exceptions
+ def _parse_smaps_rollup(self):
+ # /proc/pid/smaps_rollup was added to Linux in 2017. Faster
+ # than /proc/pid/smaps. It reports higher PSS than */smaps
+ # (from 1k up to 200k higher; tested against all processes).
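+ # Lines in smaps_rollup look like b"Private_Clean:  144 kB"
+ # (illustrative value); amounts are in kB, hence the * 1024 below.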
+ uss = pss = swap = 0
+ try:
+ with open_binary("{}/{}/smaps_rollup".format(
+ self._procfs_path, self.pid)) as f:
+ for line in f:
+ if line.startswith(b"Private_"):
+ # Private_Clean, Private_Dirty, Private_Hugetlb
+ uss += int(line.split()[1]) * 1024
+ elif line.startswith(b"Pss:"):
+ pss = int(line.split()[1]) * 1024
+ elif line.startswith(b"Swap:"):
+ swap = int(line.split()[1]) * 1024
+ except ProcessLookupError: # happens on readline()
+ if not pid_exists(self.pid):
+ raise NoSuchProcess(self.pid, self._name)
+ else:
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ return (uss, pss, swap)
+
+ @wrap_exceptions
+ def _parse_smaps(
+ self,
+ # Gets Private_Clean, Private_Dirty, Private_Hugetlb.
+ _private_re=re.compile(br"\nPrivate.*:\s+(\d+)"),
+ _pss_re=re.compile(br"\nPss\:\s+(\d+)"),
+ _swap_re=re.compile(br"\nSwap\:\s+(\d+)")):
+ # /proc/pid/smaps does not exist on kernels < 2.6.14 or if
+ # CONFIG_MMU kernel configuration option is not enabled.
+
+ # Note: using 3 regexes is faster than reading the file
+ # line by line.
+ # XXX: on Python 3 these regexes are 30% slower than on
+ # Python 2 though. Figure out why.
+ #
+ # You might be tempted to calculate USS by subtracting
+ # the "shared" value from the "resident" value in
+ # /proc/<pid>/statm. But at least on Linux, statm's "shared"
+ # value actually counts pages backed by files, which has
+ # little to do with whether the pages are actually shared.
+ # /proc/self/smaps on the other hand appears to give us the
+ # correct information.
+ smaps_data = self._read_smaps_file()
+ # Note: the smaps file can be empty for certain processes.
+ # The code below will not crash though; it will simply produce 0.
+ uss = sum(map(int, _private_re.findall(smaps_data))) * 1024
+ pss = sum(map(int, _pss_re.findall(smaps_data))) * 1024
+ swap = sum(map(int, _swap_re.findall(smaps_data))) * 1024
+ return (uss, pss, swap)
+
+ def memory_full_info(self):
+ if HAS_PROC_SMAPS_ROLLUP: # faster
+ uss, pss, swap = self._parse_smaps_rollup()
+ else:
+ uss, pss, swap = self._parse_smaps()
+ basic_mem = self.memory_info()
+ return pfullmem(*basic_mem + (uss, pss, swap))
+
+ else:
+ memory_full_info = memory_info
+
+ if HAS_PROC_SMAPS:
+
+ @wrap_exceptions
+ def memory_maps(self):
+ """Return process's mapped memory regions as a list of named
+ tuples. Fields are explained in 'man proc'; here is an updated
+ (Apr 2012) version: http://goo.gl/fmebo
+
+ /proc/{PID}/smaps does not exist on kernels < 2.6.14 or if
+ CONFIG_MMU kernel configuration option is not enabled.
+ """
+ def get_blocks(lines, current_block):
+ data = {}
+ for line in lines:
+ fields = line.split(None, 5)
+ if not fields[0].endswith(b':'):
+ # new block section
+ yield (current_block.pop(), data)
+ current_block.append(line)
+ else:
+ try:
+ data[fields[0]] = int(fields[1]) * 1024
+ except ValueError:
+ if fields[0].startswith(b'VmFlags:'):
+ # see issue #369
+ continue
+ else:
+ raise ValueError("don't know how to inte"
+ "rpret line %r" % line)
+ yield (current_block.pop(), data)
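+ # A block header line in /proc/{pid}/smaps looks like (illustrative):
+ # "00400000-0040c000 r-xp 00000000 fd:01 1234567  /usr/bin/cat"
+ # i.e. address range, perms, offset, device, inode and path.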
+
+ data = self._read_smaps_file()
+ # Note: smaps file can be empty for certain processes.
+ if not data:
+ return []
+ lines = data.split(b'\n')
+ ls = []
+ first_line = lines.pop(0)
+ current_block = [first_line]
+ for header, data in get_blocks(lines, current_block):
+ hfields = header.split(None, 5)
+ try:
+ addr, perms, offset, dev, inode, path = hfields
+ except ValueError:
+ addr, perms, offset, dev, inode, path = \
+ hfields + ['']
+ if not path:
+ path = '[anon]'
+ else:
+ if PY3:
+ path = decode(path)
+ path = path.strip()
+ if (path.endswith(' (deleted)') and not
+ path_exists_strict(path)):
+ path = path[:-10]
+ ls.append((
+ decode(addr), decode(perms), path,
+ data.get(b'Rss:', 0),
+ data.get(b'Size:', 0),
+ data.get(b'Pss:', 0),
+ data.get(b'Shared_Clean:', 0),
+ data.get(b'Shared_Dirty:', 0),
+ data.get(b'Private_Clean:', 0),
+ data.get(b'Private_Dirty:', 0),
+ data.get(b'Referenced:', 0),
+ data.get(b'Anonymous:', 0),
+ data.get(b'Swap:', 0)
+ ))
+ return ls
+
+ @wrap_exceptions
+ def cwd(self):
+ try:
+ return readlink("%s/%s/cwd" % (self._procfs_path, self.pid))
+ except (FileNotFoundError, ProcessLookupError):
+ # https://github.com/giampaolo/psutil/issues/986
+ if not pid_exists(self.pid):
+ raise NoSuchProcess(self.pid, self._name)
+ else:
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+
+ @wrap_exceptions
+ def num_ctx_switches(self,
+ _ctxsw_re=re.compile(br'ctxt_switches:\t(\d+)')):
+ data = self._read_status_file()
+ ctxsw = _ctxsw_re.findall(data)
+ if not ctxsw:
+ raise NotImplementedError(
+ "'voluntary_ctxt_switches' and 'nonvoluntary_ctxt_switches'"
+ "lines were not found in %s/%s/status; the kernel is "
+ "probably older than 2.6.23" % (
+ self._procfs_path, self.pid))
+ else:
+ return _common.pctxsw(int(ctxsw[0]), int(ctxsw[1]))
+
+ @wrap_exceptions
+ def num_threads(self, _num_threads_re=re.compile(br'Threads:\t(\d+)')):
+ # Note: on Python 3 using a regex is faster than iterating over the
+ # file line by line. On Python 2 it is the exact opposite, and
+ # iterating over a file on Python 3 is slower than on Python 2.
+ data = self._read_status_file()
+ return int(_num_threads_re.findall(data)[0])
+
+ @wrap_exceptions
+ def threads(self):
+ thread_ids = os.listdir("%s/%s/task" % (self._procfs_path, self.pid))
+ thread_ids.sort()
+ retlist = []
+ hit_enoent = False
+ for thread_id in thread_ids:
+ fname = "%s/%s/task/%s/stat" % (
+ self._procfs_path, self.pid, thread_id)
+ try:
+ with open_binary(fname) as f:
+ st = f.read().strip()
+ except (FileNotFoundError, ProcessLookupError):
+ # no such file or directory or no such process;
+ # it means thread disappeared on us
+ hit_enoent = True
+ continue
+ # ignore the first two values ("pid (exe)")
+ st = st[st.find(b')') + 2:]
+ values = st.split(b' ')
+ utime = float(values[11]) / CLOCK_TICKS
+ stime = float(values[12]) / CLOCK_TICKS
+ ntuple = _common.pthread(int(thread_id), utime, stime)
+ retlist.append(ntuple)
+ if hit_enoent:
+ self._assert_alive()
+ return retlist
+
+ @wrap_exceptions
+ def nice_get(self):
+ # with open_text('%s/%s/stat' % (self._procfs_path, self.pid)) as f:
+ # data = f.read()
+ # return int(data.split()[18])
+
+ # Use C implementation
+ return cext_posix.getpriority(self.pid)
+
+ @wrap_exceptions
+ def nice_set(self, value):
+ return cext_posix.setpriority(self.pid, value)
+
+ # starting from CentOS 6.
+ if HAS_CPU_AFFINITY:
+
+ @wrap_exceptions
+ def cpu_affinity_get(self):
+ return cext.proc_cpu_affinity_get(self.pid)
+
+ def _get_eligible_cpus(
+ self, _re=re.compile(br"Cpus_allowed_list:\t(\d+)-(\d+)")):
+ # See: https://github.com/giampaolo/psutil/issues/956
+ data = self._read_status_file()
+ match = _re.findall(data)
+ if match:
+ return list(range(int(match[0][0]), int(match[0][1]) + 1))
+ else:
+ return list(range(len(per_cpu_times())))
+
+ @wrap_exceptions
+ def cpu_affinity_set(self, cpus):
+ try:
+ cext.proc_cpu_affinity_set(self.pid, cpus)
+ except (OSError, ValueError) as err:
+ if isinstance(err, ValueError) or err.errno == errno.EINVAL:
+ eligible_cpus = self._get_eligible_cpus()
+ all_cpus = tuple(range(len(per_cpu_times())))
+ for cpu in cpus:
+ if cpu not in all_cpus:
+ raise ValueError(
+ "invalid CPU number %r; choose between %s" % (
+ cpu, eligible_cpus))
+ if cpu not in eligible_cpus:
+ raise ValueError(
+ "CPU number %r is not eligible; choose "
+ "between %s" % (cpu, eligible_cpus))
+ raise
+
+ # only starting from kernel 2.6.13
+ if HAS_PROC_IO_PRIORITY:
+
+ @wrap_exceptions
+ def ionice_get(self):
+ ioclass, value = cext.proc_ioprio_get(self.pid)
+ if enum is not None:
+ ioclass = IOPriority(ioclass)
+ return _common.pionice(ioclass, value)
+
+ @wrap_exceptions
+ def ionice_set(self, ioclass, value):
+ if value is None:
+ value = 0
+ if value and ioclass in (IOPRIO_CLASS_IDLE, IOPRIO_CLASS_NONE):
+ raise ValueError("%r ioclass accepts no value" % ioclass)
+ if value < 0 or value > 7:
+ raise ValueError("value not in 0-7 range")
+ return cext.proc_ioprio_set(self.pid, ioclass, value)
+
+ if prlimit is not None:
+
+ @wrap_exceptions
+ def rlimit(self, resource_, limits=None):
+ # If pid is 0 prlimit() applies to the calling process and
+ # we don't want that. We should never get here though as
+ # PID 0 is not supported on Linux.
+ if self.pid == 0:
+ raise ValueError("can't use prlimit() against PID 0 process")
+ try:
+ if limits is None:
+ # get
+ return prlimit(self.pid, resource_)
+ else:
+ # set
+ if len(limits) != 2:
+ raise ValueError(
+ "second argument must be a (soft, hard) tuple, "
+ "got %s" % repr(limits))
+ prlimit(self.pid, resource_, limits)
+ except OSError as err:
+ if err.errno == errno.ENOSYS and pid_exists(self.pid):
+ # I saw this happening on Travis:
+ # https://travis-ci.org/giampaolo/psutil/jobs/51368273
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ else:
+ raise
+
+ @wrap_exceptions
+ def status(self):
+ letter = self._parse_stat_file()['status']
+ if PY3:
+ letter = letter.decode()
+ # XXX is '?' legit? (we're not supposed to return it anyway)
+ return PROC_STATUSES.get(letter, '?')
+
+ @wrap_exceptions
+ def open_files(self):
+ retlist = []
+ files = os.listdir("%s/%s/fd" % (self._procfs_path, self.pid))
+ hit_enoent = False
+ for fd in files:
+ file = "%s/%s/fd/%s" % (self._procfs_path, self.pid, fd)
+ try:
+ path = readlink(file)
+ except (FileNotFoundError, ProcessLookupError):
+ # ENOENT == file which is gone in the meantime
+ hit_enoent = True
+ continue
+ except OSError as err:
+ if err.errno == errno.EINVAL:
+ # not a link
+ continue
+ if err.errno == errno.ENAMETOOLONG:
+ # file name too long
+ debug(err)
+ continue
+ raise
+ else:
+ # If the path is not absolute there's no way to tell
+ # whether it's a regular file or not, so we skip it.
+ # A regular file is always supposed to have an
+ # absolute path though.
+ if path.startswith('/') and isfile_strict(path):
+ # Get file position and flags.
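+ # /proc/{pid}/fdinfo/{fd} starts with lines such as
+ # "pos:\t1234" and "flags:\t0100002" (illustrative values;
+ # flags are expressed in octal, hence int(..., 8)).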
+ file = "%s/%s/fdinfo/%s" % (
+ self._procfs_path, self.pid, fd)
+ try:
+ with open_binary(file) as f:
+ pos = int(f.readline().split()[1])
+ flags = int(f.readline().split()[1], 8)
+ except (FileNotFoundError, ProcessLookupError):
+ # fd gone in the meantime; process may
+ # still be alive
+ hit_enoent = True
+ else:
+ mode = file_flags_to_mode(flags)
+ ntuple = popenfile(
+ path, int(fd), int(pos), mode, flags)
+ retlist.append(ntuple)
+ if hit_enoent:
+ self._assert_alive()
+ return retlist
+
+ @wrap_exceptions
+ def connections(self, kind='inet'):
+ ret = _connections.retrieve(kind, self.pid)
+ self._assert_alive()
+ return ret
+
+ @wrap_exceptions
+ def num_fds(self):
+ return len(os.listdir("%s/%s/fd" % (self._procfs_path, self.pid)))
+
+ @wrap_exceptions
+ def ppid(self):
+ return int(self._parse_stat_file()['ppid'])
+
+ @wrap_exceptions
+ def uids(self, _uids_re=re.compile(br'Uid:\t(\d+)\t(\d+)\t(\d+)')):
+ data = self._read_status_file()
+ real, effective, saved = _uids_re.findall(data)[0]
+ return _common.puids(int(real), int(effective), int(saved))
+
+ @wrap_exceptions
+ def gids(self, _gids_re=re.compile(br'Gid:\t(\d+)\t(\d+)\t(\d+)')):
+ data = self._read_status_file()
+ real, effective, saved = _gids_re.findall(data)[0]
+ return _common.pgids(int(real), int(effective), int(saved))
diff --git a/lib/psutil/_psosx.py b/lib/psutil/_psosx.py
new file mode 100644
index 0000000..58359bc
--- /dev/null
+++ b/lib/psutil/_psosx.py
@@ -0,0 +1,543 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""macOS platform implementation."""
+
+import errno
+import functools
+import os
+from collections import namedtuple
+
+from . import _common
+from . import _psposix
+from . import _psutil_osx as cext
+from . import _psutil_posix as cext_posix
+from ._common import AccessDenied
+from ._common import NoSuchProcess
+from ._common import ZombieProcess
+from ._common import conn_tmap
+from ._common import conn_to_ntuple
+from ._common import isfile_strict
+from ._common import memoize_when_activated
+from ._common import parse_environ_block
+from ._common import usage_percent
+from ._compat import PermissionError
+from ._compat import ProcessLookupError
+
+
+__extra__all__ = []
+
+
+# =====================================================================
+# --- globals
+# =====================================================================
+
+
+PAGESIZE = cext_posix.getpagesize()
+AF_LINK = cext_posix.AF_LINK
+
+TCP_STATUSES = {
+ cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
+ cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
+ cext.TCPS_SYN_RECEIVED: _common.CONN_SYN_RECV,
+ cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
+ cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
+ cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
+ cext.TCPS_CLOSED: _common.CONN_CLOSE,
+ cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
+ cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
+ cext.TCPS_LISTEN: _common.CONN_LISTEN,
+ cext.TCPS_CLOSING: _common.CONN_CLOSING,
+ cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
+}
+
+PROC_STATUSES = {
+ cext.SIDL: _common.STATUS_IDLE,
+ cext.SRUN: _common.STATUS_RUNNING,
+ cext.SSLEEP: _common.STATUS_SLEEPING,
+ cext.SSTOP: _common.STATUS_STOPPED,
+ cext.SZOMB: _common.STATUS_ZOMBIE,
+}
+
+kinfo_proc_map = dict(
+ ppid=0,
+ ruid=1,
+ euid=2,
+ suid=3,
+ rgid=4,
+ egid=5,
+ sgid=6,
+ ttynr=7,
+ ctime=8,
+ status=9,
+ name=10,
+)
+
+pidtaskinfo_map = dict(
+ cpuutime=0,
+ cpustime=1,
+ rss=2,
+ vms=3,
+ pfaults=4,
+ pageins=5,
+ numthreads=6,
+ volctxsw=7,
+)
+
+
+# =====================================================================
+# --- named tuples
+# =====================================================================
+
+
+# psutil.cpu_times()
+scputimes = namedtuple('scputimes', ['user', 'nice', 'system', 'idle'])
+# psutil.virtual_memory()
+svmem = namedtuple(
+ 'svmem', ['total', 'available', 'percent', 'used', 'free',
+ 'active', 'inactive', 'wired'])
+# psutil.Process.memory_info()
+pmem = namedtuple('pmem', ['rss', 'vms', 'pfaults', 'pageins'])
+# psutil.Process.memory_full_info()
+pfullmem = namedtuple('pfullmem', pmem._fields + ('uss', ))
+
+
+# =====================================================================
+# --- memory
+# =====================================================================
+
+
+def virtual_memory():
+ """System virtual memory as a namedtuple."""
+ total, active, inactive, wired, free, speculative = cext.virtual_mem()
+ # This is how Zabbix calculates avail and used mem:
+ # https://github.com/zabbix/zabbix/blob/trunk/src/libs/zbxsysinfo/
+ # osx/memory.c
+ # Also see: https://github.com/giampaolo/psutil/issues/1277
+ avail = inactive + free
+ used = active + wired
+ # This is NOT how Zabbix calculates free mem but it matches the
+ # "free" cmdline utility.
+ free -= speculative
+ percent = usage_percent((total - avail), total, round_=1)
+ return svmem(total, avail, percent, used, free,
+ active, inactive, wired)
+
+
+def swap_memory():
+ """Swap system memory as a (total, used, free, sin, sout) tuple."""
+ total, used, free, sin, sout = cext.swap_mem()
+ percent = usage_percent(used, total, round_=1)
+ return _common.sswap(total, used, free, percent, sin, sout)
+
+
+# =====================================================================
+# --- CPU
+# =====================================================================
+
+
+def cpu_times():
+ """Return system CPU times as a namedtuple."""
+ user, nice, system, idle = cext.cpu_times()
+ return scputimes(user, nice, system, idle)
+
+
+def per_cpu_times():
+ """Return system CPU times as a named tuple"""
+ ret = []
+ for cpu_t in cext.per_cpu_times():
+ user, nice, system, idle = cpu_t
+ item = scputimes(user, nice, system, idle)
+ ret.append(item)
+ return ret
+
+
+def cpu_count_logical():
+ """Return the number of logical CPUs in the system."""
+ return cext.cpu_count_logical()
+
+
+def cpu_count_cores():
+ """Return the number of CPU cores in the system."""
+ return cext.cpu_count_cores()
+
+
+def cpu_stats():
+ ctx_switches, interrupts, soft_interrupts, syscalls, traps = \
+ cext.cpu_stats()
+ return _common.scpustats(
+ ctx_switches, interrupts, soft_interrupts, syscalls)
+
+
+def cpu_freq():
+ """Return CPU frequency.
+ On macOS per-cpu frequency is not supported.
+ Also, the returned frequency never changes, see:
+ https://arstechnica.com/civis/viewtopic.php?f=19&t=465002
+ """
+ curr, min_, max_ = cext.cpu_freq()
+ return [_common.scpufreq(curr, min_, max_)]
+
+
+# =====================================================================
+# --- disks
+# =====================================================================
+
+
+disk_usage = _psposix.disk_usage
+disk_io_counters = cext.disk_io_counters
+
+
+def disk_partitions(all=False):
+ """Return mounted disk partitions as a list of namedtuples."""
+ retlist = []
+ partitions = cext.disk_partitions()
+ for partition in partitions:
+ device, mountpoint, fstype, opts = partition
+ if device == 'none':
+ device = ''
+ if not all:
+ if not os.path.isabs(device) or not os.path.exists(device):
+ continue
+ maxfile = maxpath = None # set later
+ ntuple = _common.sdiskpart(device, mountpoint, fstype, opts,
+ maxfile, maxpath)
+ retlist.append(ntuple)
+ return retlist
+
+
+# =====================================================================
+# --- sensors
+# =====================================================================
+
+
+def sensors_battery():
+ """Return battery information."""
+ try:
+ percent, minsleft, power_plugged = cext.sensors_battery()
+ except NotImplementedError:
+ # no power source - return None according to interface
+ return None
+ power_plugged = power_plugged == 1
+ if power_plugged:
+ secsleft = _common.POWER_TIME_UNLIMITED
+ elif minsleft == -1:
+ secsleft = _common.POWER_TIME_UNKNOWN
+ else:
+ secsleft = minsleft * 60
+ return _common.sbattery(percent, secsleft, power_plugged)
+
+
+# =====================================================================
+# --- network
+# =====================================================================
+
+
+net_io_counters = cext.net_io_counters
+net_if_addrs = cext_posix.net_if_addrs
+
+
+def net_connections(kind='inet'):
+ """System-wide network connections."""
+ # Note: on macOS this will fail with AccessDenied unless
+ # the process is owned by root.
+ ret = []
+ for pid in pids():
+ try:
+ cons = Process(pid).connections(kind)
+ except NoSuchProcess:
+ continue
+ else:
+ if cons:
+ for c in cons:
+ c = list(c) + [pid]
+ ret.append(_common.sconn(*c))
+ return ret
+
+
+def net_if_stats():
+ """Get NIC stats (isup, duplex, speed, mtu)."""
+ names = net_io_counters().keys()
+ ret = {}
+ for name in names:
+ try:
+ mtu = cext_posix.net_if_mtu(name)
+ flags = cext_posix.net_if_flags(name)
+ duplex, speed = cext_posix.net_if_duplex_speed(name)
+ except OSError as err:
+ # https://github.com/giampaolo/psutil/issues/1279
+ if err.errno != errno.ENODEV:
+ raise
+ else:
+ if hasattr(_common, 'NicDuplex'):
+ duplex = _common.NicDuplex(duplex)
+ output_flags = ','.join(flags)
+ isup = 'running' in flags
+ ret[name] = _common.snicstats(isup, duplex, speed, mtu,
+ output_flags)
+ return ret
+
+
+# =====================================================================
+# --- other system functions
+# =====================================================================
+
+
+def boot_time():
+ """The system boot time expressed in seconds since the epoch."""
+ return cext.boot_time()
+
+
+def users():
+ """Return currently connected users as a list of namedtuples."""
+ retlist = []
+ rawlist = cext.users()
+ for item in rawlist:
+ user, tty, hostname, tstamp, pid = item
+ if tty == '~':
+ continue # reboot or shutdown
+ if not tstamp:
+ continue
+ nt = _common.suser(user, tty or None, hostname or None, tstamp, pid)
+ retlist.append(nt)
+ return retlist
+
+
+# =====================================================================
+# --- processes
+# =====================================================================
+
+
+def pids():
+ ls = cext.pids()
+ if 0 not in ls:
+ # On certain macOS versions the C pids() function doesn't return
+ # PID 0, but "ps" does and the process is queryable via sysctl():
+ # https://travis-ci.org/giampaolo/psutil/jobs/309619941
+ try:
+ Process(0).create_time()
+ ls.insert(0, 0)
+ except NoSuchProcess:
+ pass
+ except AccessDenied:
+ ls.insert(0, 0)
+ return ls
+
+
+pid_exists = _psposix.pid_exists
+
+
+def is_zombie(pid):
+ try:
+ st = cext.proc_kinfo_oneshot(pid)[kinfo_proc_map['status']]
+ return st == cext.SZOMB
+ except Exception:
+ return False
+
+
+def wrap_exceptions(fun):
+ """Decorator which translates bare OSError exceptions into
+ NoSuchProcess and AccessDenied.
+ """
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return fun(self, *args, **kwargs)
+ except ProcessLookupError:
+ if is_zombie(self.pid):
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ else:
+ raise NoSuchProcess(self.pid, self._name)
+ except PermissionError:
+ raise AccessDenied(self.pid, self._name)
+ except cext.ZombieProcessError:
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ return wrapper
+
+
+class Process(object):
+ """Wrapper class around underlying C implementation."""
+
+ __slots__ = ["pid", "_name", "_ppid", "_cache"]
+
+ def __init__(self, pid):
+ self.pid = pid
+ self._name = None
+ self._ppid = None
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _get_kinfo_proc(self):
+ # Note: should work with all PIDs without permission issues.
+ ret = cext.proc_kinfo_oneshot(self.pid)
+ assert len(ret) == len(kinfo_proc_map)
+ return ret
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _get_pidtaskinfo(self):
+ # Note: should work for PIDs owned by user only.
+ ret = cext.proc_pidtaskinfo_oneshot(self.pid)
+ assert len(ret) == len(pidtaskinfo_map)
+ return ret
+
+ def oneshot_enter(self):
+ self._get_kinfo_proc.cache_activate(self)
+ self._get_pidtaskinfo.cache_activate(self)
+
+ def oneshot_exit(self):
+ self._get_kinfo_proc.cache_deactivate(self)
+ self._get_pidtaskinfo.cache_deactivate(self)
+
+ @wrap_exceptions
+ def name(self):
+ name = self._get_kinfo_proc()[kinfo_proc_map['name']]
+ return name if name is not None else cext.proc_name(self.pid)
+
+ @wrap_exceptions
+ def exe(self):
+ return cext.proc_exe(self.pid)
+
+ @wrap_exceptions
+ def cmdline(self):
+ return cext.proc_cmdline(self.pid)
+
+ @wrap_exceptions
+ def environ(self):
+ return parse_environ_block(cext.proc_environ(self.pid))
+
+ @wrap_exceptions
+ def ppid(self):
+ self._ppid = self._get_kinfo_proc()[kinfo_proc_map['ppid']]
+ return self._ppid
+
+ @wrap_exceptions
+ def cwd(self):
+ return cext.proc_cwd(self.pid)
+
+ @wrap_exceptions
+ def uids(self):
+ rawtuple = self._get_kinfo_proc()
+ return _common.puids(
+ rawtuple[kinfo_proc_map['ruid']],
+ rawtuple[kinfo_proc_map['euid']],
+ rawtuple[kinfo_proc_map['suid']])
+
+ @wrap_exceptions
+ def gids(self):
+ rawtuple = self._get_kinfo_proc()
+ return _common.pgids(
+ rawtuple[kinfo_proc_map['rgid']],
+ rawtuple[kinfo_proc_map['egid']],
+ rawtuple[kinfo_proc_map['sgid']])
+
+ @wrap_exceptions
+ def terminal(self):
+ tty_nr = self._get_kinfo_proc()[kinfo_proc_map['ttynr']]
+ tmap = _psposix.get_terminal_map()
+ try:
+ return tmap[tty_nr]
+ except KeyError:
+ return None
+
+ @wrap_exceptions
+ def memory_info(self):
+ rawtuple = self._get_pidtaskinfo()
+ return pmem(
+ rawtuple[pidtaskinfo_map['rss']],
+ rawtuple[pidtaskinfo_map['vms']],
+ rawtuple[pidtaskinfo_map['pfaults']],
+ rawtuple[pidtaskinfo_map['pageins']],
+ )
+
+ @wrap_exceptions
+ def memory_full_info(self):
+ basic_mem = self.memory_info()
+ uss = cext.proc_memory_uss(self.pid)
+ return pfullmem(*basic_mem + (uss, ))
+
+ @wrap_exceptions
+ def cpu_times(self):
+ rawtuple = self._get_pidtaskinfo()
+ return _common.pcputimes(
+ rawtuple[pidtaskinfo_map['cpuutime']],
+ rawtuple[pidtaskinfo_map['cpustime']],
+ # children user / system times are not retrievable (set to 0)
+ 0.0, 0.0)
+
+ @wrap_exceptions
+ def create_time(self):
+ return self._get_kinfo_proc()[kinfo_proc_map['ctime']]
+
+ @wrap_exceptions
+ def num_ctx_switches(self):
+ # The involuntary value does not seem to be available;
+ # getrusage() numbers seem to confirm this.
+ # We set it to 0.
+ vol = self._get_pidtaskinfo()[pidtaskinfo_map['volctxsw']]
+ return _common.pctxsw(vol, 0)
+
+ @wrap_exceptions
+ def num_threads(self):
+ return self._get_pidtaskinfo()[pidtaskinfo_map['numthreads']]
+
+ @wrap_exceptions
+ def open_files(self):
+ if self.pid == 0:
+ return []
+ files = []
+ rawlist = cext.proc_open_files(self.pid)
+ for path, fd in rawlist:
+ if isfile_strict(path):
+ ntuple = _common.popenfile(path, fd)
+ files.append(ntuple)
+ return files
+
+ @wrap_exceptions
+ def connections(self, kind='inet'):
+ if kind not in conn_tmap:
+ raise ValueError("invalid %r kind argument; choose between %s"
+ % (kind, ', '.join([repr(x) for x in conn_tmap])))
+ families, types = conn_tmap[kind]
+ rawlist = cext.proc_connections(self.pid, families, types)
+ ret = []
+ for item in rawlist:
+ fd, fam, type, laddr, raddr, status = item
+ nt = conn_to_ntuple(fd, fam, type, laddr, raddr, status,
+ TCP_STATUSES)
+ ret.append(nt)
+ return ret
+
+ @wrap_exceptions
+ def num_fds(self):
+ if self.pid == 0:
+ return 0
+ return cext.proc_num_fds(self.pid)
+
+ @wrap_exceptions
+ def wait(self, timeout=None):
+ return _psposix.wait_pid(self.pid, timeout, self._name)
+
+ @wrap_exceptions
+ def nice_get(self):
+ return cext_posix.getpriority(self.pid)
+
+ @wrap_exceptions
+ def nice_set(self, value):
+ return cext_posix.setpriority(self.pid, value)
+
+ @wrap_exceptions
+ def status(self):
+ code = self._get_kinfo_proc()[kinfo_proc_map['status']]
+ # XXX is '?' legit? (we're not supposed to return it anyway)
+ return PROC_STATUSES.get(code, '?')
+
+ @wrap_exceptions
+ def threads(self):
+ rawlist = cext.proc_threads(self.pid)
+ retlist = []
+ for thread_id, utime, stime in rawlist:
+ ntuple = _common.pthread(thread_id, utime, stime)
+ retlist.append(ntuple)
+ return retlist
diff --git a/lib/psutil/_psposix.py b/lib/psutil/_psposix.py
new file mode 100644
index 0000000..1d250bf
--- /dev/null
+++ b/lib/psutil/_psposix.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Routines common to all posix systems."""
+
+import glob
+import os
+import signal
+import sys
+import time
+
+from ._common import MACOS
+from ._common import TimeoutExpired
+from ._common import memoize
+from ._common import sdiskusage
+from ._common import usage_percent
+from ._compat import PY3
+from ._compat import ChildProcessError
+from ._compat import FileNotFoundError
+from ._compat import InterruptedError
+from ._compat import PermissionError
+from ._compat import ProcessLookupError
+from ._compat import unicode
+
+
+if MACOS:
+ from . import _psutil_osx
+
+
+if sys.version_info >= (3, 4):
+ import enum
+else:
+ enum = None
+
+
+__all__ = ['pid_exists', 'wait_pid', 'disk_usage', 'get_terminal_map']
+
+
+def pid_exists(pid):
+ """Check whether pid exists in the current process table."""
+ if pid == 0:
+ # According to "man 2 kill" PID 0 has a special meaning:
+ # it refers to <<every process in the process group of the
+ # calling process>> so we don't want to go any further.
+ # If we get here it means this UNIX platform *does* have
+ # a process with id 0.
+ return True
+ try:
+ os.kill(pid, 0)
+ except ProcessLookupError:
+ return False
+ except PermissionError:
+ # EPERM clearly means there's a process to deny access to
+ return True
+ # According to "man 2 kill" possible error values are
+ # (EINVAL, EPERM, ESRCH)
+ else:
+ return True
+
+
+# Python 3.5 signals enum (contributed by me ^^):
+# https://bugs.python.org/issue21076
+if enum is not None and hasattr(signal, "Signals"):
+ Negsignal = enum.IntEnum(
+ 'Negsignal', dict([(x.name, -x.value) for x in signal.Signals]))
+
+ def negsig_to_enum(num):
+ """Convert a negative signal value to an enum."""
+ try:
+ return Negsignal(num)
+ except ValueError:
+ return num
+else: # pragma: no cover
+ def negsig_to_enum(num):
+ return num
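+ # Illustrative example: on Python >= 3.5, negsig_to_enum(-15) returns
+ # Negsignal.SIGTERM (SIGTERM is 15 on most platforms); otherwise the
+ # plain integer -15 is returned unchanged.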
+
+
+def wait_pid(pid, timeout=None, proc_name=None,
+ _waitpid=os.waitpid,
+ _timer=getattr(time, 'monotonic', time.time),
+ _min=min,
+ _sleep=time.sleep,
+ _pid_exists=pid_exists):
+ """Wait for a process PID to terminate.
+
+ If the process terminated normally by calling exit(3) or _exit(2),
+ or by returning from main(), the return value is the positive integer
+ passed to *exit().
+
+ If it was terminated by a signal it returns the negated value of the
+ signal which caused the termination (e.g. -SIGTERM).
+
+ If PID is not a child of os.getpid() (the current process) just
+ wait until the process disappears and return None.
+
+ If PID does not exist at all return None immediately.
+
+ If *timeout* != None and process is still alive raise TimeoutExpired.
+ timeout=0 is also possible (either return immediately or raise).
+ """
+ if pid <= 0:
+ raise ValueError("can't wait for PID 0") # see "man waitpid"
+ interval = 0.0001
+ flags = 0
+ if timeout is not None:
+ flags |= os.WNOHANG
+ stop_at = _timer() + timeout
+
+ def sleep(interval):
+ # Sleep for some time and return a new increased interval.
+ if timeout is not None:
+ if _timer() >= stop_at:
+ raise TimeoutExpired(timeout, pid=pid, name=proc_name)
+ _sleep(interval)
+ return _min(interval * 2, 0.04)
+
+ # See: https://linux.die.net/man/2/waitpid
+ while True:
+ try:
+ retpid, status = os.waitpid(pid, flags)
+ except InterruptedError:
+ interval = sleep(interval)
+ except ChildProcessError:
+ # This has two meanings:
+ # - PID is not a child of os.getpid() in which case
+ # we keep polling until it's gone
+ # - PID never existed in the first place
+ # In both cases we'll eventually return None as we
+ # can't determine its exit status code.
+ while _pid_exists(pid):
+ interval = sleep(interval)
+ return
+ else:
+ if retpid == 0:
+ # WNOHANG flag was used and PID is still running.
+ interval = sleep(interval)
+ continue
+ elif os.WIFEXITED(status):
+ # Process terminated normally by calling exit(3) or _exit(2),
+ # or by returning from main(). The return value is the
+ # positive integer passed to *exit().
+ return os.WEXITSTATUS(status)
+ elif os.WIFSIGNALED(status):
+ # Process exited due to a signal. Return the negative value
+ # of that signal.
+ return negsig_to_enum(-os.WTERMSIG(status))
+ # elif os.WIFSTOPPED(status):
+ # # Process was stopped via SIGSTOP or is being traced, and
+ # # waitpid() was called with WUNTRACED flag. PID is still
+ # # alive. From now on waitpid() will keep returning (0, 0)
+ # # until the process state doesn't change.
+ # # It may make sense to catch/enable this since stopped PIDs
+ # # ignore SIGTERM.
+ # interval = sleep(interval)
+ # continue
+ # elif os.WIFCONTINUED(status):
+ # # Process was resumed via SIGCONT and waitpid() was called
+ # # with WCONTINUED flag.
+ # interval = sleep(interval)
+ # continue
+ else:
+ # Should never happen.
+ raise ValueError("unknown process exit status %r" % status)
+
+
+def disk_usage(path):
+ """Return disk usage associated with path.
+ Note: UNIX usually reserves 5% of the disk space, which is not
+ accessible by the user. In this function "total" and "used" reflect
+ the total and used disk space, whereas "free" and "percent"
+ represent the disk space and usage percentage available to the user.
+ """
+ if PY3:
+ st = os.statvfs(path)
+ else: # pragma: no cover
+ # os.statvfs() does not support unicode on Python 2:
+ # - https://github.com/giampaolo/psutil/issues/416
+ # - http://bugs.python.org/issue18695
+ try:
+ st = os.statvfs(path)
+ except UnicodeEncodeError:
+ if isinstance(path, unicode):
+ try:
+ path = path.encode(sys.getfilesystemencoding())
+ except UnicodeEncodeError:
+ pass
+ st = os.statvfs(path)
+ else:
+ raise
+
+ # Total space which is only available to root (unless changed
+ # at system level).
+ total = (st.f_blocks * st.f_frsize)
+ # Remaining free space usable by root.
+ avail_to_root = (st.f_bfree * st.f_frsize)
+ # Remaining free space usable by user.
+ avail_to_user = (st.f_bavail * st.f_frsize)
+ # Total space being used in general.
+ used = (total - avail_to_root)
+ if MACOS:
+ # see: https://github.com/giampaolo/psutil/pull/2152
+ used = _psutil_osx.disk_usage_used(path, used)
+ # Total space which is available to user (same as 'total' but
+ # for the user).
+ total_user = used + avail_to_user
+ # User usage percent compared to the total amount of space
+ # the user can use. This number would be higher if compared
+ # to root's because the user has less space (usually -5%).
+ usage_percent_user = usage_percent(used, total_user, round_=1)
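+ # Worked example with illustrative numbers: total=100 GiB,
+ # avail_to_root=40 GiB, avail_to_user=35 GiB -> used=60 GiB,
+ # total_user=95 GiB and usage_percent_user ~= 63.2%.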
+
+ # NB: the percentage is about 5% lower than what is shown by df
+ # due to reserved blocks that we are currently not considering:
+ # https://github.com/giampaolo/psutil/issues/829#issuecomment-223750462
+ return sdiskusage(
+ total=total, used=used, free=avail_to_user, percent=usage_percent_user)
+
+
+@memoize
+def get_terminal_map():
+ """Get a map of device-id -> path as a dict.
+ Used by Process.terminal()
+ """
+ ret = {}
+ ls = glob.glob('/dev/tty*') + glob.glob('/dev/pts/*')
+ for name in ls:
+ assert name not in ret, name
+ try:
+ ret[os.stat(name).st_rdev] = name
+ except FileNotFoundError:
+ pass
+ return ret
diff --git a/lib/psutil/_pssunos.py b/lib/psutil/_pssunos.py
new file mode 100644
index 0000000..541c1aa
--- /dev/null
+++ b/lib/psutil/_pssunos.py
@@ -0,0 +1,727 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Sun OS Solaris platform implementation."""
+
+import errno
+import functools
+import os
+import socket
+import subprocess
+import sys
+from collections import namedtuple
+from socket import AF_INET
+
+from . import _common
+from . import _psposix
+from . import _psutil_posix as cext_posix
+from . import _psutil_sunos as cext
+from ._common import AF_INET6
+from ._common import AccessDenied
+from ._common import NoSuchProcess
+from ._common import ZombieProcess
+from ._common import debug
+from ._common import get_procfs_path
+from ._common import isfile_strict
+from ._common import memoize_when_activated
+from ._common import sockfam_to_enum
+from ._common import socktype_to_enum
+from ._common import usage_percent
+from ._compat import PY3
+from ._compat import FileNotFoundError
+from ._compat import PermissionError
+from ._compat import ProcessLookupError
+from ._compat import b
+
+
+__extra__all__ = ["CONN_IDLE", "CONN_BOUND", "PROCFS_PATH"]
+
+
+# =====================================================================
+# --- globals
+# =====================================================================
+
+
+PAGE_SIZE = cext_posix.getpagesize()
+AF_LINK = cext_posix.AF_LINK
+IS_64_BIT = sys.maxsize > 2**32
+
+CONN_IDLE = "IDLE"
+CONN_BOUND = "BOUND"
+
+PROC_STATUSES = {
+ cext.SSLEEP: _common.STATUS_SLEEPING,
+ cext.SRUN: _common.STATUS_RUNNING,
+ cext.SZOMB: _common.STATUS_ZOMBIE,
+ cext.SSTOP: _common.STATUS_STOPPED,
+ cext.SIDL: _common.STATUS_IDLE,
+ cext.SONPROC: _common.STATUS_RUNNING, # same as run
+ cext.SWAIT: _common.STATUS_WAITING,
+}
+
+TCP_STATUSES = {
+ cext.TCPS_ESTABLISHED: _common.CONN_ESTABLISHED,
+ cext.TCPS_SYN_SENT: _common.CONN_SYN_SENT,
+ cext.TCPS_SYN_RCVD: _common.CONN_SYN_RECV,
+ cext.TCPS_FIN_WAIT_1: _common.CONN_FIN_WAIT1,
+ cext.TCPS_FIN_WAIT_2: _common.CONN_FIN_WAIT2,
+ cext.TCPS_TIME_WAIT: _common.CONN_TIME_WAIT,
+ cext.TCPS_CLOSED: _common.CONN_CLOSE,
+ cext.TCPS_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
+ cext.TCPS_LAST_ACK: _common.CONN_LAST_ACK,
+ cext.TCPS_LISTEN: _common.CONN_LISTEN,
+ cext.TCPS_CLOSING: _common.CONN_CLOSING,
+ cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
+ cext.TCPS_IDLE: CONN_IDLE, # sunos specific
+ cext.TCPS_BOUND: CONN_BOUND, # sunos specific
+}
+
+proc_info_map = dict(
+ ppid=0,
+ rss=1,
+ vms=2,
+ create_time=3,
+ nice=4,
+ num_threads=5,
+ status=6,
+ ttynr=7,
+ uid=8,
+ euid=9,
+ gid=10,
+ egid=11)
+
+
+# =====================================================================
+# --- named tuples
+# =====================================================================
+
+
+# psutil.cpu_times()
+scputimes = namedtuple('scputimes', ['user', 'system', 'idle', 'iowait'])
+# psutil.cpu_times(percpu=True)
+pcputimes = namedtuple('pcputimes',
+ ['user', 'system', 'children_user', 'children_system'])
+# psutil.virtual_memory()
+svmem = namedtuple('svmem', ['total', 'available', 'percent', 'used', 'free'])
+# psutil.Process.memory_info()
+pmem = namedtuple('pmem', ['rss', 'vms'])
+pfullmem = pmem
+# psutil.Process.memory_maps(grouped=True)
+pmmap_grouped = namedtuple('pmmap_grouped',
+ ['path', 'rss', 'anonymous', 'locked'])
+# psutil.Process.memory_maps(grouped=False)
+pmmap_ext = namedtuple(
+ 'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields))
+
+
+# =====================================================================
+# --- memory
+# =====================================================================
+
+
+def virtual_memory():
+ """Report virtual memory metrics."""
+ # we could have done this with kstat, but IMHO this is good enough
+ total = os.sysconf('SC_PHYS_PAGES') * PAGE_SIZE
+ # note: there's no difference on Solaris
+ free = avail = os.sysconf('SC_AVPHYS_PAGES') * PAGE_SIZE
+ used = total - free
+ percent = usage_percent(used, total, round_=1)
+ return svmem(total, avail, percent, used, free)
+
+
+def swap_memory():
+ """Report swap memory metrics."""
+ sin, sout = cext.swap_mem()
+ # XXX
+ # we are supposed to get total/free by doing so:
+ # http://cvs.opensolaris.org/source/xref/onnv/onnv-gate/
+ # usr/src/cmd/swap/swap.c
+ # ...nevertheless I can't manage to obtain the same numbers as the
+ # 'swap' cmdline utility, so let's parse its output (sigh!)
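+ # A 'swap -l' output line looks roughly like (illustrative):
+ # "/dev/zvol/dsk/rpool/swap 256,1 16 2097136 2097136"
+ # where the last two columns are size and free space expressed in
+ # 512-byte blocks.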
+ p = subprocess.Popen(['/usr/bin/env', 'PATH=/usr/sbin:/sbin:%s' %
+ os.environ['PATH'], 'swap', '-l'],
+ stdout=subprocess.PIPE)
+ stdout, stderr = p.communicate()
+ if PY3:
+ stdout = stdout.decode(sys.stdout.encoding)
+ if p.returncode != 0:
+ raise RuntimeError("'swap -l' failed (retcode=%s)" % p.returncode)
+
+ lines = stdout.strip().split('\n')[1:]
+ if not lines:
+ raise RuntimeError('no swap device(s) configured')
+ total = free = 0
+ for line in lines:
+ line = line.split()
+ t, f = line[3:5]
+ total += int(int(t) * 512)
+ free += int(int(f) * 512)
+ used = total - free
+ percent = usage_percent(used, total, round_=1)
+ return _common.sswap(total, used, free, percent,
+ sin * PAGE_SIZE, sout * PAGE_SIZE)
+
+
+# =====================================================================
+# --- CPU
+# =====================================================================
+
+
+def cpu_times():
+ """Return system-wide CPU times as a named tuple"""
+ ret = cext.per_cpu_times()
+ return scputimes(*[sum(x) for x in zip(*ret)])
+
+
+def per_cpu_times():
+ """Return system per-CPU times as a list of named tuples"""
+ ret = cext.per_cpu_times()
+ return [scputimes(*x) for x in ret]
+
+
+def cpu_count_logical():
+ """Return the number of logical CPUs in the system."""
+ try:
+ return os.sysconf("SC_NPROCESSORS_ONLN")
+ except ValueError:
+ # mimic os.cpu_count() behavior
+ return None
+
+
+def cpu_count_cores():
+ """Return the number of CPU cores in the system."""
+ return cext.cpu_count_cores()
+
+
+def cpu_stats():
+ """Return various CPU stats as a named tuple."""
+ ctx_switches, interrupts, syscalls, traps = cext.cpu_stats()
+ soft_interrupts = 0
+ return _common.scpustats(ctx_switches, interrupts, soft_interrupts,
+ syscalls)
+
+
+# =====================================================================
+# --- disks
+# =====================================================================
+
+
+disk_io_counters = cext.disk_io_counters
+disk_usage = _psposix.disk_usage
+
+
+def disk_partitions(all=False):
+ """Return system disk partitions."""
+ # TODO - the filtering logic should be better checked so that
+ # it tries to reflect 'df' as much as possible
+ retlist = []
+ partitions = cext.disk_partitions()
+ for partition in partitions:
+ device, mountpoint, fstype, opts = partition
+ if device == 'none':
+ device = ''
+ if not all:
+ # Unlike, say, Linux, we don't have a list of
+ # common fs types, so the best we can do, AFAIK, is to
+ # filter by filesystems having a total size > 0.
+ try:
+ if not disk_usage(mountpoint).total:
+ continue
+ except OSError as err:
+ # https://github.com/giampaolo/psutil/issues/1674
+ debug("skipping %r: %s" % (mountpoint, err))
+ continue
+ maxfile = maxpath = None # set later
+ ntuple = _common.sdiskpart(device, mountpoint, fstype, opts,
+ maxfile, maxpath)
+ retlist.append(ntuple)
+ return retlist
+
+
+# =====================================================================
+# --- network
+# =====================================================================
+
+
+net_io_counters = cext.net_io_counters
+net_if_addrs = cext_posix.net_if_addrs
+
+
+def net_connections(kind, _pid=-1):
+ """Return socket connections. If pid == -1 return system-wide
+ connections (as opposed to connections opened by one process only).
+ Only INET sockets are returned (UNIX are not).
+ """
+ cmap = _common.conn_tmap.copy()
+ if _pid == -1:
+ cmap.pop('unix', 0)
+ if kind not in cmap:
+ raise ValueError("invalid %r kind argument; choose between %s"
+ % (kind, ', '.join([repr(x) for x in cmap])))
+ families, types = _common.conn_tmap[kind]
+ rawlist = cext.net_connections(_pid)
+ ret = set()
+ for item in rawlist:
+ fd, fam, type_, laddr, raddr, status, pid = item
+ if fam not in families:
+ continue
+ if type_ not in types:
+ continue
+ # TODO: refactor and use _common.conn_to_ntuple.
+ if fam in (AF_INET, AF_INET6):
+ if laddr:
+ laddr = _common.addr(*laddr)
+ if raddr:
+ raddr = _common.addr(*raddr)
+ status = TCP_STATUSES[status]
+ fam = sockfam_to_enum(fam)
+ type_ = socktype_to_enum(type_)
+ if _pid == -1:
+ nt = _common.sconn(fd, fam, type_, laddr, raddr, status, pid)
+ else:
+ nt = _common.pconn(fd, fam, type_, laddr, raddr, status)
+ ret.add(nt)
+ return list(ret)
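+
+ # Usage sketch (illustrative only): net_connections('tcp') returns
+ # system-wide sconn tuples (which include a pid field), while
+ # net_connections('tcp', _pid=1234) returns pconn tuples for that
+ # process only; note that the 'unix' kind is popped from the map for
+ # system-wide queries, so it is only accepted together with a _pid.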
+
+
+def net_if_stats():
+ """Get NIC stats (isup, duplex, speed, mtu)."""
+ ret = cext.net_if_stats()
+ for name, items in ret.items():
+ isup, duplex, speed, mtu = items
+ if hasattr(_common, 'NicDuplex'):
+ duplex = _common.NicDuplex(duplex)
+ ret[name] = _common.snicstats(isup, duplex, speed, mtu, '')
+ return ret
+
+
+# =====================================================================
+# --- other system functions
+# =====================================================================
+
+
+def boot_time():
+ """The system boot time expressed in seconds since the epoch."""
+ return cext.boot_time()
+
+
+def users():
+ """Return currently connected users as a list of namedtuples."""
+ retlist = []
+ rawlist = cext.users()
+ localhost = (':0.0', ':0')
+ for item in rawlist:
+ user, tty, hostname, tstamp, user_process, pid = item
+ # note: the underlying C function includes entries about
+ # system boot, run level and others. We might want
+ # to use them in the future.
+ if not user_process:
+ continue
+ if hostname in localhost:
+ hostname = 'localhost'
+ nt = _common.suser(user, tty, hostname, tstamp, pid)
+ retlist.append(nt)
+ return retlist
+
+
+# =====================================================================
+# --- processes
+# =====================================================================
+
+
+def pids():
+ """Returns a list of PIDs currently running on the system."""
+ return [int(x) for x in os.listdir(b(get_procfs_path())) if x.isdigit()]
+
+
+def pid_exists(pid):
+ """Check for the existence of a unix pid."""
+ return _psposix.pid_exists(pid)
+
+
+def wrap_exceptions(fun):
+ """Call callable into a try/except clause and translate ENOENT,
+ EACCES and EPERM in NoSuchProcess or AccessDenied exceptions.
+ """
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return fun(self, *args, **kwargs)
+ except (FileNotFoundError, ProcessLookupError):
+ # ENOENT (no such file or directory) gets raised on open().
+ # ESRCH (no such process) can get raised on read() if
+ # process is gone in meantime.
+ if not pid_exists(self.pid):
+ raise NoSuchProcess(self.pid, self._name)
+ else:
+ raise ZombieProcess(self.pid, self._name, self._ppid)
+ except PermissionError:
+ raise AccessDenied(self.pid, self._name)
+ except OSError:
+ if self.pid == 0:
+ if 0 in pids():
+ raise AccessDenied(self.pid, self._name)
+ else:
+ raise
+ raise
+ return wrapper
+
+
+class Process(object):
+ """Wrapper class around underlying C implementation."""
+
+ __slots__ = ["pid", "_name", "_ppid", "_procfs_path", "_cache"]
+
+ def __init__(self, pid):
+ self.pid = pid
+ self._name = None
+ self._ppid = None
+ self._procfs_path = get_procfs_path()
+
+ def _assert_alive(self):
+ """Raise NSP if the process disappeared on us."""
+ # For those C functions which do not raise NSP, possibly returning
+ # an incorrect or incomplete result.
+ os.stat('%s/%s' % (self._procfs_path, self.pid))
+
+ def oneshot_enter(self):
+ self._proc_name_and_args.cache_activate(self)
+ self._proc_basic_info.cache_activate(self)
+ self._proc_cred.cache_activate(self)
+
+ def oneshot_exit(self):
+ self._proc_name_and_args.cache_deactivate(self)
+ self._proc_basic_info.cache_deactivate(self)
+ self._proc_cred.cache_deactivate(self)
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _proc_name_and_args(self):
+ return cext.proc_name_and_args(self.pid, self._procfs_path)
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _proc_basic_info(self):
+ if self.pid == 0 and not \
+ os.path.exists('%s/%s/psinfo' % (self._procfs_path, self.pid)):
+ raise AccessDenied(self.pid)
+ ret = cext.proc_basic_info(self.pid, self._procfs_path)
+ assert len(ret) == len(proc_info_map)
+ return ret
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def _proc_cred(self):
+ return cext.proc_cred(self.pid, self._procfs_path)
+
+ @wrap_exceptions
+ def name(self):
+ # note: max len == 15
+ return self._proc_name_and_args()[0]
+
+ @wrap_exceptions
+ def exe(self):
+ try:
+ return os.readlink(
+ "%s/%s/path/a.out" % (self._procfs_path, self.pid))
+ except OSError:
+ pass # continue and guess the exe name from the cmdline
+ # Will be guessed later from cmdline but we want to explicitly
+ # invoke cmdline here in order to get an AccessDenied
+ # exception if the user does not have enough privileges.
+ self.cmdline()
+ return ""
+
+ @wrap_exceptions
+ def cmdline(self):
+ return self._proc_name_and_args()[1].split(' ')
+
+ @wrap_exceptions
+ def environ(self):
+ return cext.proc_environ(self.pid, self._procfs_path)
+
+ @wrap_exceptions
+ def create_time(self):
+ return self._proc_basic_info()[proc_info_map['create_time']]
+
+ @wrap_exceptions
+ def num_threads(self):
+ return self._proc_basic_info()[proc_info_map['num_threads']]
+
+ @wrap_exceptions
+ def nice_get(self):
+ # Note #1: getpriority(3) doesn't work for realtime processes.
+ # Psinfo is what ps uses, see:
+ # https://github.com/giampaolo/psutil/issues/1194
+ return self._proc_basic_info()[proc_info_map['nice']]
+
+ @wrap_exceptions
+ def nice_set(self, value):
+ if self.pid in (2, 3):
+ # Special case PIDs: internally setpriority(3) returns ESRCH
+ # (no such process), no matter what.
+ # The process actually exists though, as it has a name,
+ # creation time, etc.
+ raise AccessDenied(self.pid, self._name)
+ return cext_posix.setpriority(self.pid, value)
+
+ @wrap_exceptions
+ def ppid(self):
+ self._ppid = self._proc_basic_info()[proc_info_map['ppid']]
+ return self._ppid
+
+ @wrap_exceptions
+ def uids(self):
+ try:
+ real, effective, saved, _, _, _ = self._proc_cred()
+ except AccessDenied:
+ real = self._proc_basic_info()[proc_info_map['uid']]
+ effective = self._proc_basic_info()[proc_info_map['euid']]
+ saved = None
+ return _common.puids(real, effective, saved)
+
+ @wrap_exceptions
+ def gids(self):
+ try:
+ _, _, _, real, effective, saved = self._proc_cred()
+ except AccessDenied:
+ real = self._proc_basic_info()[proc_info_map['gid']]
+ effective = self._proc_basic_info()[proc_info_map['egid']]
+ saved = None
+ return _common.puids(real, effective, saved)
+
+ @wrap_exceptions
+ def cpu_times(self):
+ try:
+ times = cext.proc_cpu_times(self.pid, self._procfs_path)
+ except OSError as err:
+ if err.errno == errno.EOVERFLOW and not IS_64_BIT:
+ # We may get here if we attempt to query a 64bit process
+ # with a 32bit python.
+ # Error originates from read() and also tools like "cat"
+ # fail in the same way (!).
+ # Since there simply is no way to determine CPU times we
+ # return 0.0 as a fallback. See:
+ # https://github.com/giampaolo/psutil/issues/857
+ times = (0.0, 0.0, 0.0, 0.0)
+ else:
+ raise
+ return _common.pcputimes(*times)
+
+ @wrap_exceptions
+ def cpu_num(self):
+ return cext.proc_cpu_num(self.pid, self._procfs_path)
+
+ @wrap_exceptions
+ def terminal(self):
+ procfs_path = self._procfs_path
+ hit_enoent = False
+ tty = self._proc_basic_info()[proc_info_map['ttynr']]
+ if tty != cext.PRNODEV:
+ for x in (0, 1, 2, 255):
+ try:
+ return os.readlink(
+ '%s/%d/path/%d' % (procfs_path, self.pid, x))
+ except FileNotFoundError:
+ hit_enoent = True
+ continue
+ if hit_enoent:
+ self._assert_alive()
+
+ @wrap_exceptions
+ def cwd(self):
+ # /proc/PID/path/cwd may not be resolved by readlink() even if
+ # it exists (ls shows it). If that's the case and the process
+ # is still alive return None (we can return None also on BSD).
+ # Reference: http://goo.gl/55XgO
+ procfs_path = self._procfs_path
+ try:
+ return os.readlink("%s/%s/path/cwd" % (procfs_path, self.pid))
+ except FileNotFoundError:
+ os.stat("%s/%s" % (procfs_path, self.pid)) # raise NSP or AD
+ return None
+
+ @wrap_exceptions
+ def memory_info(self):
+ ret = self._proc_basic_info()
+ rss = ret[proc_info_map['rss']] * 1024
+ vms = ret[proc_info_map['vms']] * 1024
+ return pmem(rss, vms)
+
+ memory_full_info = memory_info
+
+ @wrap_exceptions
+ def status(self):
+ code = self._proc_basic_info()[proc_info_map['status']]
+ # XXX is '?' legit? (we're not supposed to return it anyway)
+ return PROC_STATUSES.get(code, '?')
+
+ @wrap_exceptions
+ def threads(self):
+ procfs_path = self._procfs_path
+ ret = []
+ tids = os.listdir('%s/%d/lwp' % (procfs_path, self.pid))
+ hit_enoent = False
+ for tid in tids:
+ tid = int(tid)
+ try:
+ utime, stime = cext.query_process_thread(
+ self.pid, tid, procfs_path)
+ except EnvironmentError as err:
+ if err.errno == errno.EOVERFLOW and not IS_64_BIT:
+ # We may get here if we attempt to query a 64bit process
+ # with a 32bit python.
+ # Error originates from read() and also tools like "cat"
+ # fail in the same way (!).
+ # Since there simply is no way to determine the thread's
+ # CPU times we just skip it. See:
+ # https://github.com/giampaolo/psutil/issues/857
+ continue
+ # ENOENT == thread gone in meantime
+ if err.errno == errno.ENOENT:
+ hit_enoent = True
+ continue
+ raise
+ else:
+ nt = _common.pthread(tid, utime, stime)
+ ret.append(nt)
+ if hit_enoent:
+ self._assert_alive()
+ return ret
+
+ @wrap_exceptions
+ def open_files(self):
+ retlist = []
+ hit_enoent = False
+ procfs_path = self._procfs_path
+ pathdir = '%s/%d/path' % (procfs_path, self.pid)
+ for fd in os.listdir('%s/%d/fd' % (procfs_path, self.pid)):
+ path = os.path.join(pathdir, fd)
+ if os.path.islink(path):
+ try:
+ file = os.readlink(path)
+ except FileNotFoundError:
+ hit_enoent = True
+ continue
+ else:
+ if isfile_strict(file):
+ retlist.append(_common.popenfile(file, int(fd)))
+ if hit_enoent:
+ self._assert_alive()
+ return retlist
+
+ def _get_unix_sockets(self, pid):
+ """Get UNIX sockets used by process by parsing 'pfiles' output."""
+ # TODO: rewrite this in C (...but the damn netstat source code
+ # does not include this part! Argh!!)
+ cmd = "pfiles %s" % pid
+ p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ stdout, stderr = p.communicate()
+ if PY3:
+ stdout, stderr = [x.decode(sys.stdout.encoding)
+ for x in (stdout, stderr)]
+ if p.returncode != 0:
+ if 'permission denied' in stderr.lower():
+ raise AccessDenied(self.pid, self._name)
+ if 'no such process' in stderr.lower():
+ raise NoSuchProcess(self.pid, self._name)
+ raise RuntimeError("%r command error\n%s" % (cmd, stderr))
+
+ lines = stdout.split('\n')[2:]
+ for i, line in enumerate(lines):
+ line = line.lstrip()
+ if line.startswith('sockname: AF_UNIX'):
+ path = line.split(' ', 2)[2]
+ type = lines[i - 2].strip()
+ if type == 'SOCK_STREAM':
+ type = socket.SOCK_STREAM
+ elif type == 'SOCK_DGRAM':
+ type = socket.SOCK_DGRAM
+ else:
+ type = -1
+ yield (-1, socket.AF_UNIX, type, path, "", _common.CONN_NONE)
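+
+ # Parsing sketch (hypothetical pfiles excerpt, for illustration only):
+ # 3: S_IFSOCK mode:0666 dev:... ino:... uid:0 gid:0
+ # SOCK_STREAM
+ # sockname: AF_UNIX /var/run/example.sock
+ # the generator above picks up the "sockname: AF_UNIX <path>" line,
+ # looks two lines back for SOCK_STREAM/SOCK_DGRAM and yields
+ # (-1, AF_UNIX, SOCK_STREAM, '/var/run/example.sock', "", CONN_NONE).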
+
+ @wrap_exceptions
+ def connections(self, kind='inet'):
+ ret = net_connections(kind, _pid=self.pid)
+ # The underlying C implementation retrieves all OS connections
+ # and filters them by PID. At this point we can't tell whether
+ # an empty list means there were no connections for process or
+ # process is no longer active so we force NSP in case the PID
+ # is no longer there.
+ if not ret:
+ # will raise NSP if process is gone
+ os.stat('%s/%s' % (self._procfs_path, self.pid))
+
+ # UNIX sockets
+ if kind in ('all', 'unix'):
+ ret.extend([_common.pconn(*conn) for conn in
+ self._get_unix_sockets(self.pid)])
+ return ret
+
+ nt_mmap_grouped = namedtuple('mmap', 'path rss anon locked')
+ nt_mmap_ext = namedtuple('mmap', 'addr perms path rss anon locked')
+
+ @wrap_exceptions
+ def memory_maps(self):
+ def toaddr(start, end):
+ return '%s-%s' % (hex(start)[2:].strip('L'),
+ hex(end)[2:].strip('L'))
+
+ procfs_path = self._procfs_path
+ retlist = []
+ try:
+ rawlist = cext.proc_memory_maps(self.pid, procfs_path)
+ except OSError as err:
+ if err.errno == errno.EOVERFLOW and not IS_64_BIT:
+ # We may get here if we attempt to query a 64bit process
+ # with a 32bit python.
+ # Error originates from read() and also tools like "cat"
+ # fail in the same way (!).
+ # Since there simply is no way to determine the memory
+ # maps we return an empty list as a fallback. See:
+ # https://github.com/giampaolo/psutil/issues/857
+ return []
+ else:
+ raise
+ hit_enoent = False
+ for item in rawlist:
+ addr, addrsize, perm, name, rss, anon, locked = item
+ addr = toaddr(addr, addrsize)
+ if not name.startswith('['):
+ try:
+ name = os.readlink(
+ '%s/%s/path/%s' % (procfs_path, self.pid, name))
+ except OSError as err:
+ if err.errno == errno.ENOENT:
+ # sometimes the link may not be resolved by
+ # readlink() even if it exists (ls shows it).
+ # If that's the case we just return the
+ # unresolved link path.
+ # This seems an inconsistency with /proc similar
+ # to: http://goo.gl/55XgO
+ name = '%s/%s/path/%s' % (procfs_path, self.pid, name)
+ hit_enoent = True
+ else:
+ raise
+ retlist.append((addr, perm, name, rss, anon, locked))
+ if hit_enoent:
+ self._assert_alive()
+ return retlist
+
+ @wrap_exceptions
+ def num_fds(self):
+ return len(os.listdir("%s/%s/fd" % (self._procfs_path, self.pid)))
+
+ @wrap_exceptions
+ def num_ctx_switches(self):
+ return _common.pctxsw(
+ *cext.proc_num_ctx_switches(self.pid, self._procfs_path))
+
+ @wrap_exceptions
+ def wait(self, timeout=None):
+ return _psposix.wait_pid(self.pid, timeout, self._name)
diff --git a/lib/psutil/_psutil_linux.abi3.so b/lib/psutil/_psutil_linux.abi3.so
new file mode 100755
index 0000000..d6aa4c1
--- /dev/null
+++ b/lib/psutil/_psutil_linux.abi3.so
Binary files differ
diff --git a/lib/psutil/_psutil_posix.abi3.so b/lib/psutil/_psutil_posix.abi3.so
new file mode 100755
index 0000000..0156ed1
--- /dev/null
+++ b/lib/psutil/_psutil_posix.abi3.so
Binary files differ
diff --git a/lib/psutil/_pswindows.py b/lib/psutil/_pswindows.py
new file mode 100644
index 0000000..49f8b05
--- /dev/null
+++ b/lib/psutil/_pswindows.py
@@ -0,0 +1,1120 @@
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Windows platform implementation."""
+
+import contextlib
+import errno
+import functools
+import os
+import signal
+import sys
+import time
+from collections import namedtuple
+
+from . import _common
+from ._common import ENCODING
+from ._common import ENCODING_ERRS
+from ._common import AccessDenied
+from ._common import NoSuchProcess
+from ._common import TimeoutExpired
+from ._common import conn_tmap
+from ._common import conn_to_ntuple
+from ._common import debug
+from ._common import isfile_strict
+from ._common import memoize
+from ._common import memoize_when_activated
+from ._common import parse_environ_block
+from ._common import usage_percent
+from ._compat import PY3
+from ._compat import long
+from ._compat import lru_cache
+from ._compat import range
+from ._compat import unicode
+from ._psutil_windows import ABOVE_NORMAL_PRIORITY_CLASS
+from ._psutil_windows import BELOW_NORMAL_PRIORITY_CLASS
+from ._psutil_windows import HIGH_PRIORITY_CLASS
+from ._psutil_windows import IDLE_PRIORITY_CLASS
+from ._psutil_windows import NORMAL_PRIORITY_CLASS
+from ._psutil_windows import REALTIME_PRIORITY_CLASS
+
+
+try:
+ from . import _psutil_windows as cext
+except ImportError as err:
+ if str(err).lower().startswith("dll load failed") and \
+ sys.getwindowsversion()[0] < 6:
+ # We may get here if:
+ # 1) we are on an old Windows version
+ # 2) psutil was installed via pip + wheel
+ # See: https://github.com/giampaolo/psutil/issues/811
+ msg = "this Windows version is too old (< Windows Vista); "
+ msg += "psutil 3.4.2 is the latest version which supports Windows "
+ msg += "2000, XP and 2003 server"
+ raise RuntimeError(msg)
+ else:
+ raise
+
+if sys.version_info >= (3, 4):
+ import enum
+else:
+ enum = None
+
+# process priority constants, import from __init__.py:
+# http://msdn.microsoft.com/en-us/library/ms686219(v=vs.85).aspx
+__extra__all__ = [
+ "win_service_iter", "win_service_get",
+ # Process priority
+ "ABOVE_NORMAL_PRIORITY_CLASS", "BELOW_NORMAL_PRIORITY_CLASS",
+ "HIGH_PRIORITY_CLASS", "IDLE_PRIORITY_CLASS", "NORMAL_PRIORITY_CLASS",
+ "REALTIME_PRIORITY_CLASS",
+ # IO priority
+ "IOPRIO_VERYLOW", "IOPRIO_LOW", "IOPRIO_NORMAL", "IOPRIO_HIGH",
+ # others
+ "CONN_DELETE_TCB", "AF_LINK",
+]
+
+
+# =====================================================================
+# --- globals
+# =====================================================================
+
+CONN_DELETE_TCB = "DELETE_TCB"
+ERROR_PARTIAL_COPY = 299
+PYPY = '__pypy__' in sys.builtin_module_names
+
+if enum is None:
+ AF_LINK = -1
+else:
+ AddressFamily = enum.IntEnum('AddressFamily', {'AF_LINK': -1})
+ AF_LINK = AddressFamily.AF_LINK
+
+TCP_STATUSES = {
+ cext.MIB_TCP_STATE_ESTAB: _common.CONN_ESTABLISHED,
+ cext.MIB_TCP_STATE_SYN_SENT: _common.CONN_SYN_SENT,
+ cext.MIB_TCP_STATE_SYN_RCVD: _common.CONN_SYN_RECV,
+ cext.MIB_TCP_STATE_FIN_WAIT1: _common.CONN_FIN_WAIT1,
+ cext.MIB_TCP_STATE_FIN_WAIT2: _common.CONN_FIN_WAIT2,
+ cext.MIB_TCP_STATE_TIME_WAIT: _common.CONN_TIME_WAIT,
+ cext.MIB_TCP_STATE_CLOSED: _common.CONN_CLOSE,
+ cext.MIB_TCP_STATE_CLOSE_WAIT: _common.CONN_CLOSE_WAIT,
+ cext.MIB_TCP_STATE_LAST_ACK: _common.CONN_LAST_ACK,
+ cext.MIB_TCP_STATE_LISTEN: _common.CONN_LISTEN,
+ cext.MIB_TCP_STATE_CLOSING: _common.CONN_CLOSING,
+ cext.MIB_TCP_STATE_DELETE_TCB: CONN_DELETE_TCB,
+ cext.PSUTIL_CONN_NONE: _common.CONN_NONE,
+}
+
+if enum is not None:
+ class Priority(enum.IntEnum):
+ ABOVE_NORMAL_PRIORITY_CLASS = ABOVE_NORMAL_PRIORITY_CLASS
+ BELOW_NORMAL_PRIORITY_CLASS = BELOW_NORMAL_PRIORITY_CLASS
+ HIGH_PRIORITY_CLASS = HIGH_PRIORITY_CLASS
+ IDLE_PRIORITY_CLASS = IDLE_PRIORITY_CLASS
+ NORMAL_PRIORITY_CLASS = NORMAL_PRIORITY_CLASS
+ REALTIME_PRIORITY_CLASS = REALTIME_PRIORITY_CLASS
+
+ globals().update(Priority.__members__)
+
+if enum is None:
+ IOPRIO_VERYLOW = 0
+ IOPRIO_LOW = 1
+ IOPRIO_NORMAL = 2
+ IOPRIO_HIGH = 3
+else:
+ class IOPriority(enum.IntEnum):
+ IOPRIO_VERYLOW = 0
+ IOPRIO_LOW = 1
+ IOPRIO_NORMAL = 2
+ IOPRIO_HIGH = 3
+ globals().update(IOPriority.__members__)
+
+pinfo_map = dict(
+ num_handles=0,
+ ctx_switches=1,
+ user_time=2,
+ kernel_time=3,
+ create_time=4,
+ num_threads=5,
+ io_rcount=6,
+ io_wcount=7,
+ io_rbytes=8,
+ io_wbytes=9,
+ io_count_others=10,
+ io_bytes_others=11,
+ num_page_faults=12,
+ peak_wset=13,
+ wset=14,
+ peak_paged_pool=15,
+ paged_pool=16,
+ peak_non_paged_pool=17,
+ non_paged_pool=18,
+ pagefile=19,
+ peak_pagefile=20,
+ mem_private=21,
+)
+
+
+# =====================================================================
+# --- named tuples
+# =====================================================================
+
+
+# psutil.cpu_times()
+scputimes = namedtuple('scputimes',
+ ['user', 'system', 'idle', 'interrupt', 'dpc'])
+# psutil.virtual_memory()
+svmem = namedtuple('svmem', ['total', 'available', 'percent', 'used', 'free'])
+# psutil.Process.memory_info()
+pmem = namedtuple(
+ 'pmem', ['rss', 'vms',
+ 'num_page_faults', 'peak_wset', 'wset', 'peak_paged_pool',
+ 'paged_pool', 'peak_nonpaged_pool', 'nonpaged_pool',
+ 'pagefile', 'peak_pagefile', 'private'])
+# psutil.Process.memory_full_info()
+pfullmem = namedtuple('pfullmem', pmem._fields + ('uss', ))
+# psutil.Process.memory_maps(grouped=True)
+pmmap_grouped = namedtuple('pmmap_grouped', ['path', 'rss'])
+# psutil.Process.memory_maps(grouped=False)
+pmmap_ext = namedtuple(
+ 'pmmap_ext', 'addr perms ' + ' '.join(pmmap_grouped._fields))
+# psutil.Process.io_counters()
+pio = namedtuple('pio', ['read_count', 'write_count',
+ 'read_bytes', 'write_bytes',
+ 'other_count', 'other_bytes'])
+
+
+# =====================================================================
+# --- utils
+# =====================================================================
+
+
+@lru_cache(maxsize=512)
+def convert_dos_path(s):
+ r"""Convert paths using native DOS format like:
+ "\Device\HarddiskVolume1\Windows\systemew\file.txt"
+ into:
+ "C:\Windows\systemew\file.txt"
+ """
+ rawdrive = '\\'.join(s.split('\\')[:3])
+ driveletter = cext.QueryDosDevice(rawdrive)
+ remainder = s[len(rawdrive):]
+ return os.path.join(driveletter, remainder)
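+
+ # Worked example (illustrative only): for the docstring's sample path
+ # rawdrive becomes "\Device\HarddiskVolume1" (the first three
+ # backslash-separated components), QueryDosDevice() maps it to a drive
+ # letter such as "C:", and the remainder "\Windows\systemew\file.txt"
+ # is re-joined to it, yielding "C:\Windows\systemew\file.txt".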
+
+
+def py2_strencode(s):
+ """Encode a unicode string to a byte string by using the default fs
+ encoding + "replace" error handler.
+ """
+ if PY3:
+ return s
+ else:
+ if isinstance(s, str):
+ return s
+ else:
+ return s.encode(ENCODING, ENCODING_ERRS)
+
+
+@memoize
+def getpagesize():
+ return cext.getpagesize()
+
+
+# =====================================================================
+# --- memory
+# =====================================================================
+
+
+def virtual_memory():
+ """System virtual memory as a namedtuple."""
+ mem = cext.virtual_mem()
+ totphys, availphys, totsys, availsys = mem
+ #
+ total = totphys
+ avail = availphys
+ free = availphys
+ used = total - avail
+ percent = usage_percent((total - avail), total, round_=1)
+ return svmem(total, avail, percent, used, free)
+
+
+def swap_memory():
+ """Swap system memory as a (total, used, free, sin, sout) tuple."""
+ mem = cext.virtual_mem()
+
+ total_phys = mem[0]
+ total_system = mem[2]
+
+ # System memory (commit total/limit) is the sum of physical and swap,
+ # thus physical memory values need to be subtracted to get swap values.
+ total = total_system - total_phys
+ # Commit total is incremented immediately (decrementing free_system)
+ # while the corresponding free physical value is not decremented until
+ # pages are accessed, so we can't use free system memory for swap.
+ # Instead, we calculate page file usage based on a performance counter.
+ if (total > 0):
+ percentswap = cext.getpercentswap()
+ used = int(0.01 * percentswap * total)
+ else:
+ used = 0
+ free = total - used
+ percent = usage_percent(used, total, round_=1)
+ return _common.sswap(total, used, free, percent, 0, 0)
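+
+ # Worked example (hypothetical numbers): with a commit limit of 12 GiB
+ # and 8 GiB of physical RAM, total swap is 12 - 8 = 4 GiB; if the page
+ # file usage counter reports 25, then used = 0.01 * 25 * 4 GiB = 1 GiB,
+ # free = 3 GiB and percent = 25.0.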
+
+
+# =====================================================================
+# --- disk
+# =====================================================================
+
+
+disk_io_counters = cext.disk_io_counters
+
+
+def disk_usage(path):
+ """Return disk usage associated with path."""
+ if PY3 and isinstance(path, bytes):
+ # XXX: do we want to use "strict"? Probably yes, in order
+ # to fail immediately. After all we are accepting input here...
+ path = path.decode(ENCODING, errors="strict")
+ total, free = cext.disk_usage(path)
+ used = total - free
+ percent = usage_percent(used, total, round_=1)
+ return _common.sdiskusage(total, used, free, percent)
+
+
+def disk_partitions(all):
+ """Return disk partitions."""
+ rawlist = cext.disk_partitions(all)
+ return [_common.sdiskpart(*x) for x in rawlist]
+
+
+# =====================================================================
+# --- CPU
+# =====================================================================
+
+
+def cpu_times():
+ """Return system CPU times as a named tuple."""
+ user, system, idle = cext.cpu_times()
+ # Internally, GetSystemTimes() is used, and it doesn't return
+ # interrupt and dpc times. cext.per_cpu_times() does, so we
+ # rely on it to get those only.
+ percpu_summed = scputimes(*[sum(n) for n in zip(*cext.per_cpu_times())])
+ return scputimes(user, system, idle,
+ percpu_summed.interrupt, percpu_summed.dpc)
+
+
+def per_cpu_times():
+ """Return system per-CPU times as a list of named tuples."""
+ ret = []
+ for user, system, idle, interrupt, dpc in cext.per_cpu_times():
+ item = scputimes(user, system, idle, interrupt, dpc)
+ ret.append(item)
+ return ret
+
+
+def cpu_count_logical():
+ """Return the number of logical CPUs in the system."""
+ return cext.cpu_count_logical()
+
+
+def cpu_count_cores():
+ """Return the number of CPU cores in the system."""
+ return cext.cpu_count_cores()
+
+
+def cpu_stats():
+ """Return CPU statistics."""
+ ctx_switches, interrupts, dpcs, syscalls = cext.cpu_stats()
+ soft_interrupts = 0
+ return _common.scpustats(ctx_switches, interrupts, soft_interrupts,
+ syscalls)
+
+
+def cpu_freq():
+ """Return CPU frequency.
+ On Windows per-cpu frequency is not supported.
+ """
+ curr, max_ = cext.cpu_freq()
+ min_ = 0.0
+ return [_common.scpufreq(float(curr), min_, float(max_))]
+
+
+_loadavg_initialized = False
+
+
+def getloadavg():
+ """Return the number of processes in the system run queue averaged
+ over the last 1, 5, and 15 minutes respectively as a tuple"""
+ global _loadavg_initialized
+
+ if not _loadavg_initialized:
+ cext.init_loadavg_counter()
+ _loadavg_initialized = True
+
+ # Drop to 2 decimal points which is what Linux does
+ raw_loads = cext.getloadavg()
+ return tuple([round(load, 2) for load in raw_loads])
+
+
+# =====================================================================
+# --- network
+# =====================================================================
+
+
+def net_connections(kind, _pid=-1):
+ """Return socket connections. If pid == -1 return system-wide
+ connections (as opposed to connections opened by one process only).
+ """
+ if kind not in conn_tmap:
+ raise ValueError("invalid %r kind argument; choose between %s"
+ % (kind, ', '.join([repr(x) for x in conn_tmap])))
+ families, types = conn_tmap[kind]
+ rawlist = cext.net_connections(_pid, families, types)
+ ret = set()
+ for item in rawlist:
+ fd, fam, type, laddr, raddr, status, pid = item
+ nt = conn_to_ntuple(fd, fam, type, laddr, raddr, status, TCP_STATUSES,
+ pid=pid if _pid == -1 else None)
+ ret.add(nt)
+ return list(ret)
+
+
+def net_if_stats():
+ """Get NIC stats (isup, duplex, speed, mtu)."""
+ ret = {}
+ rawdict = cext.net_if_stats()
+ for name, items in rawdict.items():
+ if not PY3:
+ assert isinstance(name, unicode), type(name)
+ name = py2_strencode(name)
+ isup, duplex, speed, mtu = items
+ if hasattr(_common, 'NicDuplex'):
+ duplex = _common.NicDuplex(duplex)
+ ret[name] = _common.snicstats(isup, duplex, speed, mtu, '')
+ return ret
+
+
+def net_io_counters():
+ """Return network I/O statistics for every network interface
+ installed on the system as a dict of raw tuples.
+ """
+ ret = cext.net_io_counters()
+ return dict([(py2_strencode(k), v) for k, v in ret.items()])
+
+
+def net_if_addrs():
+ """Return the addresses associated to each NIC."""
+ ret = []
+ for items in cext.net_if_addrs():
+ items = list(items)
+ items[0] = py2_strencode(items[0])
+ ret.append(items)
+ return ret
+
+
+# =====================================================================
+# --- sensors
+# =====================================================================
+
+
+def sensors_battery():
+ """Return battery information."""
+ # For constants meaning see:
+ # https://msdn.microsoft.com/en-us/library/windows/desktop/
+ # aa373232(v=vs.85).aspx
+ acline_status, flags, percent, secsleft = cext.sensors_battery()
+ power_plugged = acline_status == 1
+ no_battery = bool(flags & 128)
+ charging = bool(flags & 8)
+
+ if no_battery:
+ return None
+ if power_plugged or charging:
+ secsleft = _common.POWER_TIME_UNLIMITED
+ elif secsleft == -1:
+ secsleft = _common.POWER_TIME_UNKNOWN
+
+ return _common.sbattery(percent, secsleft, power_plugged)
+
+
+# =====================================================================
+# --- other system functions
+# =====================================================================
+
+
+_last_btime = 0
+
+
+def boot_time():
+ """The system boot time expressed in seconds since the epoch."""
+ # This dirty hack is to adjust the precision of the returned
+ # value which may have a 1 second fluctuation, see:
+ # https://github.com/giampaolo/psutil/issues/1007
+ global _last_btime
+ ret = float(cext.boot_time())
+ if abs(ret - _last_btime) <= 1:
+ return _last_btime
+ else:
+ _last_btime = ret
+ return ret
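+
+ # Example (illustrative): if a previous call cached 1652000000.0 and a
+ # later call computes 1652000000.8, the absolute difference (0.8) is
+ # within the 1 second tolerance, so the cached value is returned and
+ # boot_time() stays stable across calls.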
+
+
+def users():
+ """Return currently connected users as a list of namedtuples."""
+ retlist = []
+ rawlist = cext.users()
+ for item in rawlist:
+ user, hostname, tstamp = item
+ user = py2_strencode(user)
+ nt = _common.suser(user, None, hostname, tstamp, None)
+ retlist.append(nt)
+ return retlist
+
+
+# =====================================================================
+# --- Windows services
+# =====================================================================
+
+
+def win_service_iter():
+ """Yields a list of WindowsService instances."""
+ for name, display_name in cext.winservice_enumerate():
+ yield WindowsService(py2_strencode(name), py2_strencode(display_name))
+
+
+def win_service_get(name):
+ """Open a Windows service and return it as a WindowsService instance."""
+ service = WindowsService(name, None)
+ service._display_name = service._query_config()['display_name']
+ return service
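+
+ # Usage sketch (illustrative, hypothetical service names):
+ # >>> for service in win_service_iter():
+ # ... print(service.name(), service.status())
+ # ...
+ # Dhcp running
+ # >>> win_service_get('Dhcp').as_dict()['start_type']
+ # 'automatic'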
+
+
+class WindowsService(object):
+ """Represents an installed Windows service."""
+
+ def __init__(self, name, display_name):
+ self._name = name
+ self._display_name = display_name
+
+ def __str__(self):
+ details = "(name=%r, display_name=%r)" % (
+ self._name, self._display_name)
+ return "%s%s" % (self.__class__.__name__, details)
+
+ def __repr__(self):
+ return "<%s at %s>" % (self.__str__(), id(self))
+
+ def __eq__(self, other):
+ # Test for equality with another WindowsService object based
+ # on name.
+ if not isinstance(other, WindowsService):
+ return NotImplemented
+ return self._name == other._name
+
+ def __ne__(self, other):
+ return not self == other
+
+ def _query_config(self):
+ with self._wrap_exceptions():
+ display_name, binpath, username, start_type = \
+ cext.winservice_query_config(self._name)
+ # XXX - update self._display_name?
+ return dict(
+ display_name=py2_strencode(display_name),
+ binpath=py2_strencode(binpath),
+ username=py2_strencode(username),
+ start_type=py2_strencode(start_type))
+
+ def _query_status(self):
+ with self._wrap_exceptions():
+ status, pid = cext.winservice_query_status(self._name)
+ if pid == 0:
+ pid = None
+ return dict(status=status, pid=pid)
+
+ @contextlib.contextmanager
+ def _wrap_exceptions(self):
+ """Ctx manager which translates bare OSError and WindowsError
+ exceptions into NoSuchProcess and AccessDenied.
+ """
+ try:
+ yield
+ except OSError as err:
+ if is_permission_err(err):
+ raise AccessDenied(
+ pid=None, name=self._name,
+ msg="service %r is not querable (not enough privileges)" %
+ self._name)
+ elif err.winerror in (cext.ERROR_INVALID_NAME,
+ cext.ERROR_SERVICE_DOES_NOT_EXIST):
+ raise NoSuchProcess(
+ pid=None, name=self._name,
+ msg="service %r does not exist)" % self._name)
+ else:
+ raise
+
+ # config query
+
+ def name(self):
+ """The service name. This string is how a service is referenced
+ and can be passed to win_service_get() to get a new
+ WindowsService instance.
+ """
+ return self._name
+
+ def display_name(self):
+ """The service display name. The value is cached when this class
+ is instantiated.
+ """
+ return self._display_name
+
+ def binpath(self):
+ """The fully qualified path to the service binary/exe file as
+ a string, including command line arguments.
+ """
+ return self._query_config()['binpath']
+
+ def username(self):
+ """The name of the user that owns this service."""
+ return self._query_config()['username']
+
+ def start_type(self):
+ """A string which can either be "automatic", "manual" or
+ "disabled".
+ """
+ return self._query_config()['start_type']
+
+ # status query
+
+ def pid(self):
+ """The process PID, if any, else None. This can be passed
+ to Process class to control the service's process.
+ """
+ return self._query_status()['pid']
+
+ def status(self):
+ """Service status as a string."""
+ return self._query_status()['status']
+
+ def description(self):
+ """Service long description."""
+ return py2_strencode(cext.winservice_query_descr(self.name()))
+
+ # utils
+
+ def as_dict(self):
+ """Utility method retrieving all the information above as a
+ dictionary.
+ """
+ d = self._query_config()
+ d.update(self._query_status())
+ d['name'] = self.name()
+ d['display_name'] = self.display_name()
+ d['description'] = self.description()
+ return d
+
+ # actions
+ # XXX: the necessary C bindings for start() and stop() are
+ # implemented but for now I prefer not to expose them.
+ # I may change my mind in the future. Reasons:
+ # - they require Administrator privileges
+ # - can't implement a timeout for stop() (unless by using a thread,
+ # which sucks)
+ # - would require adding ServiceAlreadyStarted and
+ # ServiceAlreadyStopped exceptions, adding two new APIs.
+ # - we might also want to have modify(), which would basically mean
+ # rewriting win32serviceutil.ChangeServiceConfig, which involves a
+ # lot of stuff (and API constants which would pollute the API), see:
+ # http://pyxr.sourceforge.net/PyXR/c/python24/lib/site-packages/
+ # win32/lib/win32serviceutil.py.html#0175
+ # - psutil is typically about "read only" monitoring stuff;
+ # win_service_* APIs should only be used to retrieve a service and
+ # check whether it's running
+
+ # def start(self, timeout=None):
+ # with self._wrap_exceptions():
+ # cext.winservice_start(self.name())
+ # if timeout:
+ # giveup_at = time.time() + timeout
+ # while True:
+ # if self.status() == "running":
+ # return
+ # else:
+ # if time.time() > giveup_at:
+ # raise TimeoutExpired(timeout)
+ # else:
+ # time.sleep(.1)
+
+ # def stop(self):
+ # # Note: timeout is not implemented because it's just not
+ # # possible, see:
+ # # http://stackoverflow.com/questions/11973228/
+ # with self._wrap_exceptions():
+ # return cext.winservice_stop(self.name())
+
+
+# =====================================================================
+# --- processes
+# =====================================================================
+
+
+pids = cext.pids
+pid_exists = cext.pid_exists
+ppid_map = cext.ppid_map # used internally by Process.children()
+
+
+def is_permission_err(exc):
+ """Return True if this is a permission error."""
+ assert isinstance(exc, OSError), exc
+ # On Python 2 OSError doesn't always have 'winerror'. Sometimes
+ # it does, in which case the original exception was WindowsError
+ # (which is a subclass of OSError).
+ return exc.errno in (errno.EPERM, errno.EACCES) or \
+ getattr(exc, "winerror", -1) in (cext.ERROR_ACCESS_DENIED,
+ cext.ERROR_PRIVILEGE_NOT_HELD)
+
+
+def convert_oserror(exc, pid=None, name=None):
+ """Convert OSError into NoSuchProcess or AccessDenied."""
+ assert isinstance(exc, OSError), exc
+ if is_permission_err(exc):
+ return AccessDenied(pid=pid, name=name)
+ if exc.errno == errno.ESRCH:
+ return NoSuchProcess(pid=pid, name=name)
+ raise exc
+
+
+def wrap_exceptions(fun):
+ """Decorator which converts OSError into NoSuchProcess or AccessDenied."""
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return fun(self, *args, **kwargs)
+ except OSError as err:
+ raise convert_oserror(err, pid=self.pid, name=self._name)
+ return wrapper
+
+
+def retry_error_partial_copy(fun):
+ """Workaround for https://github.com/giampaolo/psutil/issues/875.
+ See: https://stackoverflow.com/questions/4457745#4457745
+ """
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ delay = 0.0001
+ times = 33
+ for x in range(times): # retries for roughly 1 second
+ try:
+ return fun(self, *args, **kwargs)
+ except WindowsError as _:
+ err = _
+ if err.winerror == ERROR_PARTIAL_COPY:
+ time.sleep(delay)
+ delay = min(delay * 2, 0.04)
+ continue
+ else:
+ raise
+ else:
+ msg = "%s retried %s times, converted to AccessDenied as it's " \
+ "still returning %r" % (fun, times, err)
+ raise AccessDenied(pid=self.pid, name=self._name, msg=msg)
+ return wrapper
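+
+ # Backoff sketch (illustrative): the delay doubles from 0.0001s and is
+ # capped at 0.04s, so the 33 attempts sleep roughly
+ # 0.0001 + 0.0002 + ... + 0.04 + 0.04 + ... ~= 1 second in total before
+ # the persistent ERROR_PARTIAL_COPY failure is converted to AccessDenied.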
+
+
+class Process(object):
+ """Wrapper class around underlying C implementation."""
+
+ __slots__ = ["pid", "_name", "_ppid", "_cache"]
+
+ def __init__(self, pid):
+ self.pid = pid
+ self._name = None
+ self._ppid = None
+
+ # --- oneshot() stuff
+
+ def oneshot_enter(self):
+ self._proc_info.cache_activate(self)
+ self.exe.cache_activate(self)
+
+ def oneshot_exit(self):
+ self._proc_info.cache_deactivate(self)
+ self.exe.cache_deactivate(self)
+
+ @memoize_when_activated
+ def _proc_info(self):
+ """Return multiple information about this process as a
+ raw tuple.
+ """
+ ret = cext.proc_info(self.pid)
+ assert len(ret) == len(pinfo_map)
+ return ret
+
+ def name(self):
+ """Return process name, which on Windows is always the final
+ part of the executable.
+ """
+ # This is how PIDs 0 and 4 are always represented in taskmgr
+ # and process-hacker.
+ if self.pid == 0:
+ return "System Idle Process"
+ if self.pid == 4:
+ return "System"
+ return os.path.basename(self.exe())
+
+ @wrap_exceptions
+ @memoize_when_activated
+ def exe(self):
+ if PYPY:
+ try:
+ exe = cext.proc_exe(self.pid)
+ except WindowsError as err:
+ # 24 = ERROR_TOO_MANY_OPEN_FILES. Not sure why this happens
+ # (perhaps PyPy's JIT delaying garbage collection of files?).
+ if err.errno == 24:
+ debug("%r translated into AccessDenied" % err)
+ raise AccessDenied(self.pid, self._name)
+ raise
+ else:
+ exe = cext.proc_exe(self.pid)
+ if not PY3:
+ exe = py2_strencode(exe)
+ if exe.startswith('\\'):
+ return convert_dos_path(exe)
+ return exe # May be "Registry", "MemCompression", ...
+
+ @wrap_exceptions
+ @retry_error_partial_copy
+ def cmdline(self):
+ if cext.WINVER >= cext.WINDOWS_8_1:
+ # PEB method detects cmdline changes but requires more
+ # privileges: https://github.com/giampaolo/psutil/pull/1398
+ try:
+ ret = cext.proc_cmdline(self.pid, use_peb=True)
+ except OSError as err:
+ if is_permission_err(err):
+ ret = cext.proc_cmdline(self.pid, use_peb=False)
+ else:
+ raise
+ else:
+ ret = cext.proc_cmdline(self.pid, use_peb=True)
+ if PY3:
+ return ret
+ else:
+ return [py2_strencode(s) for s in ret]
+
+ @wrap_exceptions
+ @retry_error_partial_copy
+ def environ(self):
+ ustr = cext.proc_environ(self.pid)
+ if ustr and not PY3:
+ assert isinstance(ustr, unicode), type(ustr)
+ return parse_environ_block(py2_strencode(ustr))
+
+ def ppid(self):
+ try:
+ return ppid_map()[self.pid]
+ except KeyError:
+ raise NoSuchProcess(self.pid, self._name)
+
+ def _get_raw_meminfo(self):
+ try:
+ return cext.proc_memory_info(self.pid)
+ except OSError as err:
+ if is_permission_err(err):
+ # TODO: the C ext can probably be refactored in order
+ # to get this from cext.proc_info()
+ info = self._proc_info()
+ return (
+ info[pinfo_map['num_page_faults']],
+ info[pinfo_map['peak_wset']],
+ info[pinfo_map['wset']],
+ info[pinfo_map['peak_paged_pool']],
+ info[pinfo_map['paged_pool']],
+ info[pinfo_map['peak_non_paged_pool']],
+ info[pinfo_map['non_paged_pool']],
+ info[pinfo_map['pagefile']],
+ info[pinfo_map['peak_pagefile']],
+ info[pinfo_map['mem_private']],
+ )
+ raise
+
+ @wrap_exceptions
+ def memory_info(self):
+ # on Windows RSS == WorkingSetSize and VMS == PagefileUsage.
+ # Underlying C function returns fields of PROCESS_MEMORY_COUNTERS
+ # struct.
+ t = self._get_raw_meminfo()
+ rss = t[2] # wset
+ vms = t[7] # pagefile
+ return pmem(*(rss, vms, ) + t)
+
+ @wrap_exceptions
+ def memory_full_info(self):
+ basic_mem = self.memory_info()
+ uss = cext.proc_memory_uss(self.pid)
+ uss *= getpagesize()
+ return pfullmem(*basic_mem + (uss, ))
+
+ def memory_maps(self):
+ try:
+ raw = cext.proc_memory_maps(self.pid)
+ except OSError as err:
+ # XXX - can't use wrap_exceptions decorator as we're
+ # returning a generator; probably needs refactoring.
+ raise convert_oserror(err, self.pid, self._name)
+ else:
+ for addr, perm, path, rss in raw:
+ path = convert_dos_path(path)
+ if not PY3:
+ path = py2_strencode(path)
+ addr = hex(addr)
+ yield (addr, perm, path, rss)
+
+ @wrap_exceptions
+ def kill(self):
+ return cext.proc_kill(self.pid)
+
+ @wrap_exceptions
+ def send_signal(self, sig):
+ if sig == signal.SIGTERM:
+ cext.proc_kill(self.pid)
+ # py >= 2.7
+ elif sig in (getattr(signal, "CTRL_C_EVENT", object()),
+ getattr(signal, "CTRL_BREAK_EVENT", object())):
+ os.kill(self.pid, sig)
+ else:
+ raise ValueError(
+ "only SIGTERM, CTRL_C_EVENT and CTRL_BREAK_EVENT signals "
+ "are supported on Windows")
+
+ @wrap_exceptions
+ def wait(self, timeout=None):
+ if timeout is None:
+ cext_timeout = cext.INFINITE
+ else:
+ # WaitForSingleObject() expects time in milliseconds.
+ cext_timeout = int(timeout * 1000)
+
+ timer = getattr(time, 'monotonic', time.time)
+ stop_at = timer() + timeout if timeout is not None else None
+
+ try:
+ # Exit code is supposed to come from GetExitCodeProcess().
+ # May also be None if OpenProcess() failed with
+ # ERROR_INVALID_PARAMETER, meaning PID is already gone.
+ exit_code = cext.proc_wait(self.pid, cext_timeout)
+ except cext.TimeoutExpired:
+ # WaitForSingleObject() returned WAIT_TIMEOUT. Just raise.
+ raise TimeoutExpired(timeout, self.pid, self._name)
+ except cext.TimeoutAbandoned:
+ # WaitForSingleObject() returned WAIT_ABANDONED, see:
+ # https://github.com/giampaolo/psutil/issues/1224
+ # We'll just rely on the internal polling and return None
+ # when the PID disappears. Subprocess module does the same
+ # (return None):
+ # https://github.com/python/cpython/blob/
+ # be50a7b627d0aa37e08fa8e2d5568891f19903ce/
+ # Lib/subprocess.py#L1193-L1194
+ exit_code = None
+
+ # At this point WaitForSingleObject() returned WAIT_OBJECT_0,
+ # meaning the process is gone. Stupidly there are cases where
+ # its PID may still stick around so we do a further internal
+ # polling.
+ delay = 0.0001
+ while True:
+ if not pid_exists(self.pid):
+ return exit_code
+ if stop_at and timer() >= stop_at:
+ raise TimeoutExpired(timeout, pid=self.pid, name=self._name)
+ time.sleep(delay)
+ delay = min(delay * 2, 0.04) # incremental delay
+
+ @wrap_exceptions
+ def username(self):
+ if self.pid in (0, 4):
+ return 'NT AUTHORITY\\SYSTEM'
+ domain, user = cext.proc_username(self.pid)
+ return py2_strencode(domain) + '\\' + py2_strencode(user)
+
+ @wrap_exceptions
+ def create_time(self):
+ # Note: proc_times() not put under oneshot() 'cause create_time()
+ # is already cached by the main Process class.
+ try:
+ user, system, created = cext.proc_times(self.pid)
+ return created
+ except OSError as err:
+ if is_permission_err(err):
+ return self._proc_info()[pinfo_map['create_time']]
+ raise
+
+ @wrap_exceptions
+ def num_threads(self):
+ return self._proc_info()[pinfo_map['num_threads']]
+
+ @wrap_exceptions
+ def threads(self):
+ rawlist = cext.proc_threads(self.pid)
+ retlist = []
+ for thread_id, utime, stime in rawlist:
+ ntuple = _common.pthread(thread_id, utime, stime)
+ retlist.append(ntuple)
+ return retlist
+
+ @wrap_exceptions
+ def cpu_times(self):
+ try:
+ user, system, created = cext.proc_times(self.pid)
+ except OSError as err:
+ if not is_permission_err(err):
+ raise
+ info = self._proc_info()
+ user = info[pinfo_map['user_time']]
+ system = info[pinfo_map['kernel_time']]
+ # Children user/system times are not retrievable (set to 0).
+ return _common.pcputimes(user, system, 0.0, 0.0)
+
+ @wrap_exceptions
+ def suspend(self):
+ cext.proc_suspend_or_resume(self.pid, True)
+
+ @wrap_exceptions
+ def resume(self):
+ cext.proc_suspend_or_resume(self.pid, False)
+
+ @wrap_exceptions
+ @retry_error_partial_copy
+ def cwd(self):
+ if self.pid in (0, 4):
+ raise AccessDenied(self.pid, self._name)
+ # return a normalized pathname since the native C function appends
+ # "\\" at the and of the path
+ path = cext.proc_cwd(self.pid)
+ return py2_strencode(os.path.normpath(path))
+
+ @wrap_exceptions
+ def open_files(self):
+ if self.pid in (0, 4):
+ return []
+ ret = set()
+ # Filenames come in native format like:
+ # "\Device\HarddiskVolume1\Windows\systemew\file.txt"
+ # Convert the first part into the corresponding drive letter
+ # (e.g. "C:\") by using Windows' QueryDosDevice().
+ raw_file_names = cext.proc_open_files(self.pid)
+ for _file in raw_file_names:
+ _file = convert_dos_path(_file)
+ if isfile_strict(_file):
+ if not PY3:
+ _file = py2_strencode(_file)
+ ntuple = _common.popenfile(_file, -1)
+ ret.add(ntuple)
+ return list(ret)
+
+ @wrap_exceptions
+ def connections(self, kind='inet'):
+ return net_connections(kind, _pid=self.pid)
+
+ @wrap_exceptions
+ def nice_get(self):
+ value = cext.proc_priority_get(self.pid)
+ if enum is not None:
+ value = Priority(value)
+ return value
+
+ @wrap_exceptions
+ def nice_set(self, value):
+ return cext.proc_priority_set(self.pid, value)
+
+ @wrap_exceptions
+ def ionice_get(self):
+ ret = cext.proc_io_priority_get(self.pid)
+ if enum is not None:
+ ret = IOPriority(ret)
+ return ret
+
+ @wrap_exceptions
+ def ionice_set(self, ioclass, value):
+ if value:
+ raise TypeError("value argument not accepted on Windows")
+ if ioclass not in (IOPRIO_VERYLOW, IOPRIO_LOW, IOPRIO_NORMAL,
+ IOPRIO_HIGH):
+ raise ValueError("%s is not a valid priority" % ioclass)
+ cext.proc_io_priority_set(self.pid, ioclass)
+
+ @wrap_exceptions
+ def io_counters(self):
+ try:
+ ret = cext.proc_io_counters(self.pid)
+ except OSError as err:
+ if not is_permission_err(err):
+ raise
+ info = self._proc_info()
+ ret = (
+ info[pinfo_map['io_rcount']],
+ info[pinfo_map['io_wcount']],
+ info[pinfo_map['io_rbytes']],
+ info[pinfo_map['io_wbytes']],
+ info[pinfo_map['io_count_others']],
+ info[pinfo_map['io_bytes_others']],
+ )
+ return pio(*ret)
+
+ @wrap_exceptions
+ def status(self):
+ suspended = cext.proc_is_suspended(self.pid)
+ if suspended:
+ return _common.STATUS_STOPPED
+ else:
+ return _common.STATUS_RUNNING
+
+ @wrap_exceptions
+ def cpu_affinity_get(self):
+ def from_bitmask(x):
+ return [i for i in range(64) if (1 << i) & x]
+ bitmask = cext.proc_cpu_affinity_get(self.pid)
+ return from_bitmask(bitmask)
+
+ @wrap_exceptions
+ def cpu_affinity_set(self, value):
+ def to_bitmask(ls):
+ if not ls:
+ raise ValueError("invalid argument %r" % ls)
+ out = 0
+ for b in ls:
+ out |= 2 ** b
+ return out
+
+ # SetProcessAffinityMask() states that ERROR_INVALID_PARAMETER
+ # is returned for an invalid CPU but this seems not to be true,
+ # therefore we check CPU validity beforehand.
+ allcpus = list(range(len(per_cpu_times())))
+ for cpu in value:
+ if cpu not in allcpus:
+ if not isinstance(cpu, (int, long)):
+ raise TypeError(
+ "invalid CPU %r; an integer is required" % cpu)
+ else:
+ raise ValueError("invalid CPU %r" % cpu)
+
+ bitmask = to_bitmask(value)
+ cext.proc_cpu_affinity_set(self.pid, bitmask)
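+
+ # Worked example (illustrative): from_bitmask(0b1010) yields [1, 3]
+ # (bits 1 and 3 are set), and to_bitmask([1, 3]) maps back to
+ # 2**1 + 2**3 == 10 (0b1010), which is the mask passed down to
+ # proc_cpu_affinity_set().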
+
+ @wrap_exceptions
+ def num_handles(self):
+ try:
+ return cext.proc_num_handles(self.pid)
+ except OSError as err:
+ if is_permission_err(err):
+ return self._proc_info()[pinfo_map['num_handles']]
+ raise
+
+ @wrap_exceptions
+ def num_ctx_switches(self):
+ ctx_switches = self._proc_info()[pinfo_map['ctx_switches']]
+ # only voluntary ctx switches are supported
+ return _common.pctxsw(ctx_switches, 0)
diff --git a/lib/psutil/tests/__init__.py b/lib/psutil/tests/__init__.py
new file mode 100644
index 0000000..ec9c748
--- /dev/null
+++ b/lib/psutil/tests/__init__.py
@@ -0,0 +1,1820 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Test utilities.
+"""
+
+from __future__ import print_function
+
+import atexit
+import contextlib
+import ctypes
+import errno
+import functools
+import gc
+import inspect
+import os
+import platform
+import random
+import re
+import select
+import shlex
+import shutil
+import signal
+import socket
+import stat
+import subprocess
+import sys
+import tempfile
+import textwrap
+import threading
+import time
+import unittest
+import warnings
+from socket import AF_INET
+from socket import AF_INET6
+from socket import SOCK_STREAM
+
+import psutil
+from psutil import AIX
+from psutil import FREEBSD
+from psutil import LINUX
+from psutil import MACOS
+from psutil import POSIX
+from psutil import SUNOS
+from psutil import WINDOWS
+from psutil._common import bytes2human
+from psutil._common import memoize
+from psutil._common import print_color
+from psutil._common import supports_ipv6
+from psutil._compat import PY3
+from psutil._compat import FileExistsError
+from psutil._compat import FileNotFoundError
+from psutil._compat import range
+from psutil._compat import super
+from psutil._compat import u
+from psutil._compat import unicode
+from psutil._compat import which
+
+
+try:
+ from unittest import mock # py3
+except ImportError:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ import mock # NOQA - requires "pip install mock"
+
+if sys.version_info >= (3, 4):
+ import enum
+else:
+ enum = None
+
+if POSIX:
+ from psutil._psposix import wait_pid
+
+
+__all__ = [
+ # constants
+ 'APPVEYOR', 'DEVNULL', 'GLOBAL_TIMEOUT', 'TOLERANCE_SYS_MEM', 'NO_RETRIES',
+ 'PYPY', 'PYTHON_EXE', 'ROOT_DIR', 'SCRIPTS_DIR', 'TESTFN_PREFIX',
+ 'UNICODE_SUFFIX', 'INVALID_UNICODE_SUFFIX',
+ 'CI_TESTING', 'VALID_PROC_STATUSES', 'TOLERANCE_DISK_USAGE', 'IS_64BIT',
+ "HAS_CPU_AFFINITY", "HAS_CPU_FREQ", "HAS_ENVIRON", "HAS_PROC_IO_COUNTERS",
+ "HAS_IONICE", "HAS_MEMORY_MAPS", "HAS_PROC_CPU_NUM", "HAS_RLIMIT",
+ "HAS_SENSORS_BATTERY", "HAS_BATTERY", "HAS_SENSORS_FANS",
+ "HAS_SENSORS_TEMPERATURES", "HAS_MEMORY_FULL_INFO", "MACOS_11PLUS",
+ "MACOS_12PLUS",
+ # subprocesses
+ 'pyrun', 'terminate', 'reap_children', 'spawn_testproc', 'spawn_zombie',
+ 'spawn_children_pair',
+ # threads
+ 'ThreadTask',
+ # test utils
+ 'unittest', 'skip_on_access_denied', 'skip_on_not_implemented',
+ 'retry_on_failure', 'TestMemoryLeak', 'PsutilTestCase',
+ 'process_namespace', 'system_namespace', 'print_sysinfo',
+ # install utils
+ 'install_pip', 'install_test_deps',
+ # fs utils
+ 'chdir', 'safe_rmpath', 'create_exe', 'decode_path', 'encode_path',
+ 'get_testfn',
+ # os
+ 'get_winver', 'kernel_version',
+ # sync primitives
+ 'call_until', 'wait_for_pid', 'wait_for_file',
+ # network
+ 'check_net_address',
+ 'get_free_port', 'bind_socket', 'bind_unix_socket', 'tcp_socketpair',
+ 'unix_socketpair', 'create_sockets',
+ # compat
+ 'reload_module', 'import_module_by_path',
+ # others
+ 'warn', 'copyload_shared_lib', 'is_namedtuple',
+]
+
+
+# ===================================================================
+# --- constants
+# ===================================================================
+
+# --- platforms
+
+PYPY = '__pypy__' in sys.builtin_module_names
+# whether we're running this test suite on a Continuous Integration service
+APPVEYOR = 'APPVEYOR' in os.environ
+GITHUB_ACTIONS = 'GITHUB_ACTIONS' in os.environ or 'CIBUILDWHEEL' in os.environ
+CI_TESTING = APPVEYOR or GITHUB_ACTIONS
+# are we a 64 bit process?
+IS_64BIT = sys.maxsize > 2 ** 32
+
+
+@memoize
+def macos_version():
+ version_str = platform.mac_ver()[0]
+ version = tuple(map(int, version_str.split(".")[:2]))
+ if version == (10, 16):
+ # When built against an older macOS SDK, Python will report
+ # macOS 10.16 instead of the real version.
+ version_str = subprocess.check_output(
+ [
+ sys.executable,
+ "-sS",
+ "-c",
+ "import platform; print(platform.mac_ver()[0])",
+ ],
+ env={"SYSTEM_VERSION_COMPAT": "0"},
+ universal_newlines=True,
+ )
+ version = tuple(map(int, version_str.split(".")[:2]))
+ return version
+
+
+if MACOS:
+ MACOS_11PLUS = macos_version() > (10, 15)
+ MACOS_12PLUS = macos_version() >= (12, 0)
+else:
+ MACOS_11PLUS = False
+ MACOS_12PLUS = False
+
+
+# --- configurable defaults
+
+# how many times retry_on_failure() decorator will retry
+NO_RETRIES = 10
+# bytes tolerance for system-wide related tests
+TOLERANCE_SYS_MEM = 5 * 1024 * 1024 # 5MB
+TOLERANCE_DISK_USAGE = 10 * 1024 * 1024 # 10MB
+# the timeout used in functions which have to wait
+GLOBAL_TIMEOUT = 5
+# be more tolerant if we're on CI in order to avoid false positives
+if CI_TESTING:
+ NO_RETRIES *= 3
+ GLOBAL_TIMEOUT *= 3
+ TOLERANCE_SYS_MEM *= 4
+ TOLERANCE_DISK_USAGE *= 3
+
+# --- file names
+
+# Disambiguate TESTFN for parallel testing.
+if os.name == 'java':
+ # Jython disallows @ in module names
+ TESTFN_PREFIX = '$psutil-%s-' % os.getpid()
+else:
+ TESTFN_PREFIX = '@psutil-%s-' % os.getpid()
+UNICODE_SUFFIX = u("-ƒőő")
+# An invalid unicode string.
+if PY3:
+ INVALID_UNICODE_SUFFIX = b"f\xc0\x80".decode('utf8', 'surrogateescape')
+else:
+ INVALID_UNICODE_SUFFIX = "f\xc0\x80"
+ASCII_FS = sys.getfilesystemencoding().lower() in ('ascii', 'us-ascii')
+
+# --- paths
+
+ROOT_DIR = os.path.realpath(
+ os.path.join(os.path.dirname(__file__), '..', '..'))
+SCRIPTS_DIR = os.path.join(ROOT_DIR, 'scripts')
+HERE = os.path.realpath(os.path.dirname(__file__))
+
+# --- support
+
+HAS_CONNECTIONS_UNIX = POSIX and not SUNOS
+HAS_CPU_AFFINITY = hasattr(psutil.Process, "cpu_affinity")
+HAS_CPU_FREQ = hasattr(psutil, "cpu_freq")
+HAS_GETLOADAVG = hasattr(psutil, "getloadavg")
+HAS_ENVIRON = hasattr(psutil.Process, "environ")
+HAS_IONICE = hasattr(psutil.Process, "ionice")
+HAS_MEMORY_MAPS = hasattr(psutil.Process, "memory_maps")
+HAS_NET_IO_COUNTERS = hasattr(psutil, "net_io_counters")
+HAS_PROC_CPU_NUM = hasattr(psutil.Process, "cpu_num")
+HAS_PROC_IO_COUNTERS = hasattr(psutil.Process, "io_counters")
+HAS_RLIMIT = hasattr(psutil.Process, "rlimit")
+HAS_SENSORS_BATTERY = hasattr(psutil, "sensors_battery")
+try:
+ HAS_BATTERY = HAS_SENSORS_BATTERY and bool(psutil.sensors_battery())
+except Exception:
+ HAS_BATTERY = False
+HAS_SENSORS_FANS = hasattr(psutil, "sensors_fans")
+HAS_SENSORS_TEMPERATURES = hasattr(psutil, "sensors_temperatures")
+HAS_THREADS = hasattr(psutil.Process, "threads")
+SKIP_SYSCONS = (MACOS or AIX) and os.getuid() != 0
+
+# --- misc
+
+
+def _get_py_exe():
+ def attempt(exe):
+ try:
+ subprocess.check_call(
+ [exe, "-V"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ except Exception:
+ return None
+ else:
+ return exe
+
+ if GITHUB_ACTIONS:
+ if PYPY:
+ return which("pypy3") if PY3 else which("pypy")
+ elif FREEBSD:
+ return os.path.realpath(sys.executable)
+ else:
+ return which('python')
+ elif MACOS:
+ exe = \
+ attempt(sys.executable) or \
+ attempt(os.path.realpath(sys.executable)) or \
+ attempt(which("python%s.%s" % sys.version_info[:2])) or \
+ attempt(psutil.Process().exe())
+ if not exe:
+ raise ValueError("can't find python exe real abspath")
+ return exe
+ else:
+ exe = os.path.realpath(sys.executable)
+ assert os.path.exists(exe), exe
+ return exe
+
+
+PYTHON_EXE = _get_py_exe()
+DEVNULL = open(os.devnull, 'r+')
+atexit.register(DEVNULL.close)
+
+VALID_PROC_STATUSES = [getattr(psutil, x) for x in dir(psutil)
+ if x.startswith('STATUS_')]
+AF_UNIX = getattr(socket, "AF_UNIX", object())
+
+_subprocesses_started = set()
+_pids_started = set()
+
+
+# ===================================================================
+# --- threads
+# ===================================================================
+
+
+class ThreadTask(threading.Thread):
+ """A thread task which does nothing expect staying alive."""
+
+ def __init__(self):
+ super().__init__()
+ self._running = False
+ self._interval = 0.001
+ self._flag = threading.Event()
+
+ def __repr__(self):
+ name = self.__class__.__name__
+ return '<%s running=%s at %#x>' % (name, self._running, id(self))
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.stop()
+
+ def start(self):
+ """Start thread and keep it running until an explicit
+ stop() request. Polls for shutdown every 'timeout' seconds.
+ """
+ if self._running:
+ raise ValueError("already started")
+ threading.Thread.start(self)
+ self._flag.wait()
+
+ def run(self):
+ self._running = True
+ self._flag.set()
+ while self._running:
+ time.sleep(self._interval)
+
+ def stop(self):
+ """Stop thread execution and and waits until it is stopped."""
+ if not self._running:
+ raise ValueError("already stopped")
+ self._running = False
+ self.join()
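+
+ # Usage sketch (illustrative): ThreadTask is typically used as a
+ # context manager in tests that need an extra live thread, e.g.:
+ # >>> with ThreadTask():
+ # ... assert psutil.Process().num_threads() >= 2
+ # start() blocks on the internal Event until run() has actually begun,
+ # and stop() joins the thread before returning.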
+
+
+# ===================================================================
+# --- subprocesses
+# ===================================================================
+
+
+def _reap_children_on_err(fun):
+ @functools.wraps(fun)
+ def wrapper(*args, **kwargs):
+ try:
+ return fun(*args, **kwargs)
+ except Exception:
+ reap_children()
+ raise
+ return wrapper
+
+
+@_reap_children_on_err
+def spawn_testproc(cmd=None, **kwds):
+ """Creates a python subprocess which does nothing for 60 secs and
+ return it as a subprocess.Popen instance.
+ If "cmd" is specified that is used instead of python.
+ By default stdin and stdout are redirected to /dev/null.
+ It also attempts to make sure the process is in a reasonably
+ initialized state.
+ The process is registered for cleanup on reap_children().
+ """
+ kwds.setdefault("stdin", DEVNULL)
+ kwds.setdefault("stdout", DEVNULL)
+ kwds.setdefault("cwd", os.getcwd())
+ kwds.setdefault("env", os.environ)
+ if WINDOWS:
+ # Prevent the subprocess from opening error dialogs. This will
+ # also cause stderr to be suppressed, which is suboptimal when
+ # debugging broken tests.
+ CREATE_NO_WINDOW = 0x8000000
+ kwds.setdefault("creationflags", CREATE_NO_WINDOW)
+ if cmd is None:
+ testfn = get_testfn()
+ try:
+ safe_rmpath(testfn)
+ pyline = "from time import sleep;" \
+ "open(r'%s', 'w').close();" \
+ "sleep(60);" % testfn
+ cmd = [PYTHON_EXE, "-c", pyline]
+ sproc = subprocess.Popen(cmd, **kwds)
+ _subprocesses_started.add(sproc)
+ wait_for_file(testfn, delete=True, empty=True)
+ finally:
+ safe_rmpath(testfn)
+ else:
+ sproc = subprocess.Popen(cmd, **kwds)
+ _subprocesses_started.add(sproc)
+ wait_for_pid(sproc.pid)
+ return sproc
+
+
+@_reap_children_on_err
+def spawn_children_pair():
+ """Create a subprocess which creates another one as in:
+ A (us) -> B (child) -> C (grandchild).
+ Return a (child, grandchild) tuple.
+ The 2 processes are fully initialized, will live for 60 secs,
+ and are registered for cleanup on reap_children().
+ """
+ tfile = None
+ testfn = get_testfn(dir=os.getcwd())
+ try:
+ s = textwrap.dedent("""\
+ import subprocess, os, sys, time
+ s = "import os, time;"
+ s += "f = open('%s', 'w');"
+ s += "f.write(str(os.getpid()));"
+ s += "f.close();"
+ s += "time.sleep(60);"
+ p = subprocess.Popen([r'%s', '-c', s])
+ p.wait()
+ """ % (os.path.basename(testfn), PYTHON_EXE))
+ # On Windows if we create a subprocess with CREATE_NO_WINDOW flag
+ # set (which is the default) a "conhost.exe" extra process will be
+ # spawned as a child. We don't want that.
+ if WINDOWS:
+ subp, tfile = pyrun(s, creationflags=0)
+ else:
+ subp, tfile = pyrun(s)
+ child = psutil.Process(subp.pid)
+ grandchild_pid = int(wait_for_file(testfn, delete=True, empty=False))
+ _pids_started.add(grandchild_pid)
+ grandchild = psutil.Process(grandchild_pid)
+ return (child, grandchild)
+ finally:
+ safe_rmpath(testfn)
+ if tfile is not None:
+ safe_rmpath(tfile)
+
+
+def spawn_zombie():
+ """Create a zombie process and return a (parent, zombie) process tuple.
+ In order to kill the zombie, the parent must be terminate()d first,
+ then the zombie must be wait()ed on.
+ """
+ assert psutil.POSIX
+ unix_file = get_testfn()
+ src = textwrap.dedent("""\
+ import os, sys, time, socket, contextlib
+ child_pid = os.fork()
+ if child_pid > 0:
+ time.sleep(3000)
+ else:
+ # this is the zombie process
+ s = socket.socket(socket.AF_UNIX)
+ with contextlib.closing(s):
+ s.connect('%s')
+ if sys.version_info < (3, ):
+ pid = str(os.getpid())
+ else:
+ pid = bytes(str(os.getpid()), 'ascii')
+ s.sendall(pid)
+ """ % unix_file)
+ tfile = None
+ sock = bind_unix_socket(unix_file)
+ try:
+ sock.settimeout(GLOBAL_TIMEOUT)
+ parent, tfile = pyrun(src)
+ conn, _ = sock.accept()
+ try:
+ select.select([conn.fileno()], [], [], GLOBAL_TIMEOUT)
+ zpid = int(conn.recv(1024))
+ _pids_started.add(zpid)
+ zombie = psutil.Process(zpid)
+ call_until(lambda: zombie.status(), "ret == psutil.STATUS_ZOMBIE")
+ return (parent, zombie)
+ finally:
+ conn.close()
+ finally:
+ sock.close()
+ safe_rmpath(unix_file)
+ if tfile is not None:
+ safe_rmpath(tfile)
+
+
+@_reap_children_on_err
+def pyrun(src, **kwds):
+ """Run python 'src' code string in a separate interpreter.
+ Returns a subprocess.Popen instance and the test file where the source
+ code was written.
+ """
+ kwds.setdefault("stdout", None)
+ kwds.setdefault("stderr", None)
+ srcfile = get_testfn()
+ try:
+ with open(srcfile, 'wt') as f:
+ f.write(src)
+ subp = spawn_testproc([PYTHON_EXE, f.name], **kwds)
+ wait_for_pid(subp.pid)
+ return (subp, srcfile)
+ except Exception:
+ safe_rmpath(srcfile)
+ raise
+
+
+@_reap_children_on_err
+def sh(cmd, **kwds):
+ """run cmd in a subprocess and return its output.
+ raises RuntimeError on error.
+ """
+ # Prevents subprocess to open error dialogs in case of error.
+ flags = 0x8000000 if WINDOWS else 0
+ kwds.setdefault("stdout", subprocess.PIPE)
+ kwds.setdefault("stderr", subprocess.PIPE)
+ kwds.setdefault("universal_newlines", True)
+ kwds.setdefault("creationflags", flags)
+ if isinstance(cmd, str):
+ cmd = shlex.split(cmd)
+ p = subprocess.Popen(cmd, **kwds)
+ _subprocesses_started.add(p)
+ if PY3:
+ stdout, stderr = p.communicate(timeout=GLOBAL_TIMEOUT)
+ else:
+ stdout, stderr = p.communicate()
+ if p.returncode != 0:
+ raise RuntimeError(stderr)
+ if stderr:
+ warn(stderr)
+ if stdout.endswith('\n'):
+ stdout = stdout[:-1]
+ return stdout
+
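+# sh() is essentially a strict check_output() with the defaults this test
+# suite needs; the platform-specific tests use it to compare psutil against
+# system utilities, e.g. (illustrative, BSD-flavoured):
+#
+#     assert int(sh("sysctl -n hw.ncpu")) == psutil.cpu_count(logical=True)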
+
+def terminate(proc_or_pid, sig=signal.SIGTERM, wait_timeout=GLOBAL_TIMEOUT):
+ """Terminate a process and wait() for it.
+ Process can be a PID or an instance of psutil.Process(),
+ subprocess.Popen() or psutil.Popen().
+ If it's a subprocess.Popen() or psutil.Popen() instance also closes
+ its stdin / stdout / stderr fds.
+ The PID is wait()ed on even if the process is already gone (this reaps zombies).
+ Does nothing if the process does not exist.
+ Return process exit status.
+ """
+ def wait(proc, timeout):
+ if isinstance(proc, subprocess.Popen) and not PY3:
+ proc.wait()
+ else:
+ proc.wait(timeout)
+ if WINDOWS and isinstance(proc, subprocess.Popen):
+ # Otherwise PID may still hang around.
+ try:
+ return psutil.Process(proc.pid).wait(timeout)
+ except psutil.NoSuchProcess:
+ pass
+
+ def sendsig(proc, sig):
+ # XXX: otherwise the build hangs for some reason.
+ if MACOS and GITHUB_ACTIONS:
+ sig = signal.SIGKILL
+ # If the process received SIGSTOP, SIGCONT is necessary first,
+ # otherwise SIGTERM won't work.
+ if POSIX and sig != signal.SIGKILL:
+ proc.send_signal(signal.SIGCONT)
+ proc.send_signal(sig)
+
+ def term_subprocess_proc(proc, timeout):
+ try:
+ sendsig(proc, sig)
+ except OSError as err:
+ if WINDOWS and err.winerror == 6: # "invalid handle"
+ pass
+ elif err.errno != errno.ESRCH:
+ raise
+ return wait(proc, timeout)
+
+ def term_psutil_proc(proc, timeout):
+ try:
+ sendsig(proc, sig)
+ except psutil.NoSuchProcess:
+ pass
+ return wait(proc, timeout)
+
+ def term_pid(pid, timeout):
+ try:
+ proc = psutil.Process(pid)
+ except psutil.NoSuchProcess:
+ # Needed to kill zombies.
+ if POSIX:
+ return wait_pid(pid, timeout)
+ else:
+ return term_psutil_proc(proc, timeout)
+
+ def flush_popen(proc):
+ if proc.stdout:
+ proc.stdout.close()
+ if proc.stderr:
+ proc.stderr.close()
+ # Flushing a BufferedWriter may raise an error.
+ if proc.stdin:
+ proc.stdin.close()
+
+ p = proc_or_pid
+ try:
+ if isinstance(p, int):
+ return term_pid(p, wait_timeout)
+ elif isinstance(p, (psutil.Process, psutil.Popen)):
+ return term_psutil_proc(p, wait_timeout)
+ elif isinstance(p, subprocess.Popen):
+ return term_subprocess_proc(p, wait_timeout)
+ else:
+ raise TypeError("wrong type %r" % p)
+ finally:
+ if isinstance(p, (subprocess.Popen, psutil.Popen)):
+ flush_popen(p)
+ pid = p if isinstance(p, int) else p.pid
+ assert not psutil.pid_exists(pid), pid
+
+
+def reap_children(recursive=False):
+ """Terminate and wait() any subprocess started by this test suite
+ and any children currently running, ensuring that no processes stick
+ around to hog resources.
+ If recursive is True it also tries to terminate and wait()
+ all grandchildren started by this process.
+ """
+ # Get the children here, before terminating them: in case of
+ # recursive=True we don't want to lose the intermediate references
+ # pointing to the grandchildren.
+ children = psutil.Process().children(recursive=recursive)
+
+ # Terminate subprocess.Popen.
+ while _subprocesses_started:
+ subp = _subprocesses_started.pop()
+ terminate(subp)
+
+ # Collect started pids.
+ while _pids_started:
+ pid = _pids_started.pop()
+ terminate(pid)
+
+ # Terminate children.
+ if children:
+ for p in children:
+ terminate(p, wait_timeout=None)
+ gone, alive = psutil.wait_procs(children, timeout=GLOBAL_TIMEOUT)
+ for p in alive:
+ warn("couldn't terminate process %r; attempting kill()" % p)
+ terminate(p, sig=signal.SIGKILL)
+
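+# Typical lifecycle of the helpers above (illustrative sketch): spawn a
+# throw-away process, inspect it via psutil, terminate it, and let
+# reap_children() sweep up anything left behind:
+#
+#     sproc = spawn_testproc()
+#     try:
+#         assert psutil.Process(sproc.pid).is_running()
+#     finally:
+#         terminate(sproc)
+#         reap_children()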
+
+# ===================================================================
+# --- OS
+# ===================================================================
+
+
+def kernel_version():
+ """Return a tuple such as (2, 6, 36)."""
+ if not POSIX:
+ raise NotImplementedError("not POSIX")
+ s = ""
+ uname = os.uname()[2]
+ for c in uname:
+ if c.isdigit() or c == '.':
+ s += c
+ else:
+ break
+ if not s:
+ raise ValueError("can't parse %r" % uname)
+ minor = 0
+ micro = 0
+ nums = s.split('.')
+ major = int(nums[0])
+ if len(nums) >= 2:
+ minor = int(nums[1])
+ if len(nums) >= 3:
+ micro = int(nums[2])
+ return (major, minor, micro)
+
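+# kernel_version() is meant for tuple comparisons, e.g. skipping a test on
+# kernels which lack a feature (illustrative; prlimit() needs Linux >= 2.6.36):
+#
+#     @unittest.skipIf(LINUX and kernel_version() < (2, 6, 36), "old kernel")
+#     def test_rlimit(self):
+#         ...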
+
+def get_winver():
+ if not WINDOWS:
+ raise NotImplementedError("not WINDOWS")
+ wv = sys.getwindowsversion()
+ if hasattr(wv, 'service_pack_major'): # python >= 2.7
+ sp = wv.service_pack_major or 0
+ else:
+ r = re.search(r"\s\d$", wv[4])
+ if r:
+ sp = int(r.group(0))
+ else:
+ sp = 0
+ return (wv[0], wv[1], sp)
+
+
+# ===================================================================
+# --- sync primitives
+# ===================================================================
+
+
+class retry(object):
+ """A retry decorator."""
+
+ def __init__(self,
+ exception=Exception,
+ timeout=None,
+ retries=None,
+ interval=0.001,
+ logfun=None,
+ ):
+ if timeout and retries:
+ raise ValueError("timeout and retries args are mutually exclusive")
+ self.exception = exception
+ self.timeout = timeout
+ self.retries = retries
+ self.interval = interval
+ self.logfun = logfun
+
+ def __iter__(self):
+ if self.timeout:
+ stop_at = time.time() + self.timeout
+ while time.time() < stop_at:
+ yield
+ elif self.retries:
+ for _ in range(self.retries):
+ yield
+ else:
+ while True:
+ yield
+
+ def sleep(self):
+ if self.interval is not None:
+ time.sleep(self.interval)
+
+ def __call__(self, fun):
+ @functools.wraps(fun)
+ def wrapper(*args, **kwargs):
+ exc = None
+ for _ in self:
+ try:
+ return fun(*args, **kwargs)
+ except self.exception as _: # NOQA
+ exc = _
+ if self.logfun is not None:
+ self.logfun(exc)
+ self.sleep()
+ continue
+ if PY3:
+ raise exc
+ else:
+ raise
+
+ # This way the user of the decorated function can change config
+ # parameters.
+ wrapper.decorator = self
+ return wrapper
+
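+# The retry decorator backs wait_for_pid() and wait_for_file() below; a
+# standalone sketch, with an assumed flaky_read() helper, looks like this:
+#
+#     @retry(exception=IOError, timeout=5, interval=0.1)
+#     def read_when_ready(path):
+#         return flaky_read(path)   # re-attempted on IOError for up to 5 secs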
+
+@retry(exception=psutil.NoSuchProcess, logfun=None, timeout=GLOBAL_TIMEOUT,
+ interval=0.001)
+def wait_for_pid(pid):
+ """Wait for pid to show up in the process list then return.
+ Used in the test suite to give the sub process time to initialize.
+ """
+ psutil.Process(pid)
+ if WINDOWS:
+ # give it some more time to allow better initialization
+ time.sleep(0.01)
+
+
+@retry(exception=(FileNotFoundError, AssertionError), logfun=None,
+ timeout=GLOBAL_TIMEOUT, interval=0.001)
+def wait_for_file(fname, delete=True, empty=False):
+ """Wait for a file to be written on disk with some content."""
+ with open(fname, "rb") as f:
+ data = f.read()
+ if not empty:
+ assert data
+ if delete:
+ safe_rmpath(fname)
+ return data
+
+
+@retry(exception=AssertionError, logfun=None, timeout=GLOBAL_TIMEOUT,
+ interval=0.001)
+def call_until(fun, expr):
+ """Keep calling function for timeout secs and exit if eval()
+ expression is True.
+ """
+ ret = fun()
+ assert eval(expr)
+ return ret
+
+
+# ===================================================================
+# --- fs
+# ===================================================================
+
+
+def safe_rmpath(path):
+ """Convenience function for removing temporary test files or dirs."""
+ def retry_fun(fun):
+ # On Windows it may happen that the file or directory has
+ # open handles or references preventing the delete operation
+ # from succeeding immediately, so we retry for a while. See:
+ # https://bugs.python.org/issue33240
+ stop_at = time.time() + GLOBAL_TIMEOUT
+ while time.time() < stop_at:
+ try:
+ return fun()
+ except FileNotFoundError:
+ pass
+ except WindowsError as _:
+ err = _
+ warn("ignoring %s" % (str(err)))
+ time.sleep(0.01)
+ raise err
+
+ try:
+ st = os.stat(path)
+ if stat.S_ISDIR(st.st_mode):
+ fun = functools.partial(shutil.rmtree, path)
+ else:
+ fun = functools.partial(os.remove, path)
+ if POSIX:
+ fun()
+ else:
+ retry_fun(fun)
+ except FileNotFoundError:
+ pass
+
+
+def safe_mkdir(dir):
+ """Convenience function for creating a directory."""
+ try:
+ os.mkdir(dir)
+ except FileExistsError:
+ pass
+
+
+@contextlib.contextmanager
+def chdir(dirname):
+ """Context manager which temporarily changes the current directory."""
+ curdir = os.getcwd()
+ try:
+ os.chdir(dirname)
+ yield
+ finally:
+ os.chdir(curdir)
+
+
+def create_exe(outpath, c_code=None):
+ """Creates an executable file in the given location."""
+ assert not os.path.exists(outpath), outpath
+ if c_code:
+ if not which("gcc"):
+ raise ValueError("gcc is not installed")
+ if isinstance(c_code, bool): # c_code is True
+ c_code = textwrap.dedent(
+ """
+ #include <unistd.h>
+ int main() {
+ pause();
+ return 1;
+ }
+ """)
+ assert isinstance(c_code, str), c_code
+ with open(get_testfn(suffix='.c'), 'wt') as f:
+ f.write(c_code)
+ try:
+ subprocess.check_call(["gcc", f.name, "-o", outpath])
+ finally:
+ safe_rmpath(f.name)
+ else:
+ # copy python executable
+ shutil.copyfile(PYTHON_EXE, outpath)
+ if POSIX:
+ st = os.stat(outpath)
+ os.chmod(outpath, st.st_mode | stat.S_IEXEC)
+
+
+def get_testfn(suffix="", dir=None):
+ """Return an absolute pathname of a file or dir that did not
+ exist at the time this call is made. Also schedule it for safe
+ deletion at interpreter exit. It's technically racy, but the
+ time-based name makes a collision unlikely in practice.
+ """
+ while True:
+ name = tempfile.mktemp(prefix=TESTFN_PREFIX, suffix=suffix, dir=dir)
+ if not os.path.exists(name): # also include dirs
+ return os.path.realpath(name) # needed for OSX
+
+
+# ===================================================================
+# --- testing
+# ===================================================================
+
+
+class TestCase(unittest.TestCase):
+
+ # Print a full path representation of the single unit test
+ # being run.
+ def __str__(self):
+ fqmod = self.__class__.__module__
+ if not fqmod.startswith('psutil.'):
+ fqmod = 'psutil.tests.' + fqmod
+ return "%s.%s.%s" % (
+ fqmod, self.__class__.__name__, self._testMethodName)
+
+ # assertRaisesRegexp renamed to assertRaisesRegex in 3.3;
+ # add support for the new name.
+ if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
+ assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
+
+ # ...otherwise multiprocessing.Pool complains
+ if not PY3:
+ def runTest(self):
+ pass
+
+ @contextlib.contextmanager
+ def subTest(self, *args, **kw):
+ # fake it for python 2.7
+ yield
+
+
+# monkey patch default unittest.TestCase
+unittest.TestCase = TestCase
+
+
+class PsutilTestCase(TestCase):
+ """Test class providing auto-cleanup wrappers on top of process
+ test utilities.
+ """
+
+ def get_testfn(self, suffix="", dir=None):
+ fname = get_testfn(suffix=suffix, dir=dir)
+ self.addCleanup(safe_rmpath, fname)
+ return fname
+
+ def spawn_testproc(self, *args, **kwds):
+ sproc = spawn_testproc(*args, **kwds)
+ self.addCleanup(terminate, sproc)
+ return sproc
+
+ def spawn_children_pair(self):
+ child1, child2 = spawn_children_pair()
+ self.addCleanup(terminate, child2)
+ self.addCleanup(terminate, child1) # executed first
+ return (child1, child2)
+
+ def spawn_zombie(self):
+ parent, zombie = spawn_zombie()
+ self.addCleanup(terminate, zombie)
+ self.addCleanup(terminate, parent) # executed first
+ return (parent, zombie)
+
+ def pyrun(self, *args, **kwds):
+ sproc, srcfile = pyrun(*args, **kwds)
+ self.addCleanup(safe_rmpath, srcfile)
+ self.addCleanup(terminate, sproc) # executed first
+ return sproc
+
+ def assertProcessGone(self, proc):
+ self.assertRaises(psutil.NoSuchProcess, psutil.Process, proc.pid)
+ if isinstance(proc, (psutil.Process, psutil.Popen)):
+ assert not proc.is_running()
+ try:
+ status = proc.status()
+ except psutil.NoSuchProcess:
+ pass
+ else:
+ raise AssertionError("Process.status() didn't raise exception "
+ "(status=%s)" % status)
+ proc.wait(timeout=0) # assert not raise TimeoutExpired
+ assert not psutil.pid_exists(proc.pid), proc.pid
+ self.assertNotIn(proc.pid, psutil.pids())
+
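+# Sketch of a concrete test case built on top of these wrappers (names are
+# illustrative); every resource is registered via addCleanup(), so no manual
+# tearDown() is needed:
+#
+#     class TestTermination(PsutilTestCase):
+#
+#         def test_terminate(self):
+#             sproc = self.spawn_testproc()
+#             p = psutil.Process(sproc.pid)
+#             p.terminate()
+#             p.wait(timeout=GLOBAL_TIMEOUT)
+#             self.assertProcessGone(p)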
+
+@unittest.skipIf(PYPY, "unreliable on PYPY")
+class TestMemoryLeak(PsutilTestCase):
+ """Test framework class for detecting function memory leaks,
+ typically functions implemented in C which forgot to free() memory
+ from the heap. It does so by checking whether the process memory
+ usage increased before and after calling the function many times.
+
+ Note that this is hard (probably impossible) to do reliably, due
+ to how the OS handles memory, the GC and so on (memory can even
+ decrease!). In order to avoid false positives, in case of failure
+ (mem > 0) we retry the test for up to 5 times, increasing call
+ repetitions each time. If the memory keeps increasing then it's a
+ failure.
+
+ If available (Linux, OSX, Windows), USS memory is used for comparison,
+ since it's supposed to be more precise, see:
+ https://gmpy.dev/blog/2016/real-process-memory-and-environ-in-python
+ If not, RSS memory is used. mallinfo() on Linux and _heapwalk() on
+ Windows may give even more precision, but at the moment are not
+ implemented.
+
+ PyPy appears to be completely unstable for this framework, probably
+ because of its JIT, so tests on PYPY are skipped.
+
+ Usage:
+
+ class TestLeaks(psutil.tests.TestMemoryLeak):
+
+ def test_fun(self):
+ self.execute(some_function)
+ """
+ # Configurable class attrs.
+ times = 200
+ warmup_times = 10
+ tolerance = 0 # memory
+ retries = 10 if CI_TESTING else 5
+ verbose = True
+ _thisproc = psutil.Process()
+ _psutil_debug_orig = bool(os.getenv('PSUTIL_DEBUG', 0))
+
+ @classmethod
+ def setUpClass(cls):
+ psutil._set_debug(False) # avoid spamming to stderr
+
+ @classmethod
+ def tearDownClass(cls):
+ psutil._set_debug(cls._psutil_debug_orig)
+
+ def _get_mem(self):
+ # USS is the closest thing we have to "real" memory usage and it
+ # should be less likely to produce false positives.
+ mem = self._thisproc.memory_full_info()
+ return getattr(mem, "uss", mem.rss)
+
+ def _get_num_fds(self):
+ if POSIX:
+ return self._thisproc.num_fds()
+ else:
+ return self._thisproc.num_handles()
+
+ def _log(self, msg):
+ if self.verbose:
+ print_color(msg, color="yellow", file=sys.stderr)
+
+ def _check_fds(self, fun):
+ """Makes sure num_fds() (POSIX) or num_handles() (Windows) does
+ not increase after calling a function. Used to discover forgotten
+ close(2) and CloseHandle syscalls.
+ """
+ before = self._get_num_fds()
+ self.call(fun)
+ after = self._get_num_fds()
+ diff = after - before
+ if diff < 0:
+ raise self.fail("negative diff %r (gc probably collected a "
+ "resource from a previous test)" % diff)
+ if diff > 0:
+ type_ = "fd" if POSIX else "handle"
+ if diff > 1:
+ type_ += "s"
+ msg = "%s unclosed %s after calling %r" % (diff, type_, fun)
+ raise self.fail(msg)
+
+ def _call_ntimes(self, fun, times):
+ """Get 2 distinct memory samples, before and after having
+ called fun repeatedly, and return the memory difference.
+ """
+ gc.collect(generation=1)
+ mem1 = self._get_mem()
+ for x in range(times):
+ ret = self.call(fun)
+ del x, ret
+ gc.collect(generation=1)
+ mem2 = self._get_mem()
+ self.assertEqual(gc.garbage, [])
+ diff = mem2 - mem1 # can also be negative
+ return diff
+
+ def _check_mem(self, fun, times, warmup_times, retries, tolerance):
+ messages = []
+ prev_mem = 0
+ increase = times
+ for idx in range(1, retries + 1):
+ mem = self._call_ntimes(fun, times)
+ msg = "Run #%s: extra-mem=%s, per-call=%s, calls=%s" % (
+ idx, bytes2human(mem), bytes2human(mem / times), times)
+ messages.append(msg)
+ success = mem <= tolerance or mem <= prev_mem
+ if success:
+ if idx > 1:
+ self._log(msg)
+ return
+ else:
+ if idx == 1:
+ print() # NOQA
+ self._log(msg)
+ times += increase
+ prev_mem = mem
+ raise self.fail(". ".join(messages))
+
+ # ---
+
+ def call(self, fun):
+ return fun()
+
+ def execute(self, fun, times=None, warmup_times=None, retries=None,
+ tolerance=None):
+ """Test a callable."""
+ times = times if times is not None else self.times
+ warmup_times = warmup_times if warmup_times is not None \
+ else self.warmup_times
+ retries = retries if retries is not None else self.retries
+ tolerance = tolerance if tolerance is not None else self.tolerance
+ try:
+ assert times >= 1, "times must be >= 1"
+ assert warmup_times >= 0, "warmup_times must be >= 0"
+ assert retries >= 0, "retries must be >= 0"
+ assert tolerance >= 0, "tolerance must be >= 0"
+ except AssertionError as err:
+ raise ValueError(str(err))
+
+ self._call_ntimes(fun, warmup_times) # warm up
+ self._check_fds(fun)
+ self._check_mem(fun, times=times, warmup_times=warmup_times,
+ retries=retries, tolerance=tolerance)
+
+ def execute_w_exc(self, exc, fun, **kwargs):
+ """Convenience method to test a callable while making sure it
+ raises an exception on every call.
+ """
+ def call():
+ self.assertRaises(exc, fun)
+
+ self.execute(call, **kwargs)
+
+
+def print_sysinfo():
+ import collections
+ import datetime
+ import getpass
+ import locale
+ import platform
+ import pprint
+ try:
+ import pip
+ except ImportError:
+ pip = None
+ try:
+ import wheel
+ except ImportError:
+ wheel = None
+
+ info = collections.OrderedDict()
+
+ # OS
+ if psutil.LINUX and which('lsb_release'):
+ info['OS'] = sh('lsb_release -d -s')
+ elif psutil.OSX:
+ info['OS'] = 'Darwin %s' % platform.mac_ver()[0]
+ elif psutil.WINDOWS:
+ info['OS'] = "Windows " + ' '.join(
+ map(str, platform.win32_ver()))
+ if hasattr(platform, 'win32_edition'):
+ info['OS'] += ", " + platform.win32_edition()
+ else:
+ info['OS'] = "%s %s" % (platform.system(), platform.version())
+ info['arch'] = ', '.join(
+ list(platform.architecture()) + [platform.machine()])
+ if psutil.POSIX:
+ info['kernel'] = platform.uname()[2]
+
+ # python
+ info['python'] = ', '.join([
+ platform.python_implementation(),
+ platform.python_version(),
+ platform.python_compiler()])
+ info['pip'] = getattr(pip, '__version__', 'not installed')
+ if wheel is not None:
+ info['pip'] += " (wheel=%s)" % wheel.__version__
+
+ # UNIX
+ if psutil.POSIX:
+ if which('gcc'):
+ out = sh(['gcc', '--version'])
+ info['gcc'] = str(out).split('\n')[0]
+ else:
+ info['gcc'] = 'not installed'
+ s = platform.libc_ver()[1]
+ if s:
+ info['glibc'] = s
+
+ # system
+ info['fs-encoding'] = sys.getfilesystemencoding()
+ lang = locale.getlocale()
+ info['lang'] = '%s, %s' % (lang[0], lang[1])
+ info['boot-time'] = datetime.datetime.fromtimestamp(
+ psutil.boot_time()).strftime("%Y-%m-%d %H:%M:%S")
+ info['time'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ info['user'] = getpass.getuser()
+ info['home'] = os.path.expanduser("~")
+ info['cwd'] = os.getcwd()
+ info['pyexe'] = PYTHON_EXE
+ info['hostname'] = platform.node()
+ info['PID'] = os.getpid()
+
+ # metrics
+ info['cpus'] = psutil.cpu_count()
+ info['loadavg'] = "%.1f%%, %.1f%%, %.1f%%" % (
+ tuple([x / psutil.cpu_count() * 100 for x in psutil.getloadavg()]))
+ mem = psutil.virtual_memory()
+ info['memory'] = "%s%%, used=%s, total=%s" % (
+ int(mem.percent), bytes2human(mem.used), bytes2human(mem.total))
+ swap = psutil.swap_memory()
+ info['swap'] = "%s%%, used=%s, total=%s" % (
+ int(swap.percent), bytes2human(swap.used), bytes2human(swap.total))
+ info['pids'] = len(psutil.pids())
+ pinfo = psutil.Process().as_dict()
+ pinfo.pop('memory_maps', None)
+ info['proc'] = pprint.pformat(pinfo)
+
+ print("=" * 70, file=sys.stderr) # NOQA
+ for k, v in info.items():
+ print("%-17s %s" % (k + ':', v), file=sys.stderr) # NOQA
+ print("=" * 70, file=sys.stderr) # NOQA
+ sys.stdout.flush()
+
+
+def _get_eligible_cpu():
+ p = psutil.Process()
+ if hasattr(p, "cpu_num"):
+ return p.cpu_num()
+ elif hasattr(p, "cpu_affinity"):
+ return random.choice(p.cpu_affinity())
+ return 0
+
+
+class process_namespace:
+ """A container that lists all Process class method names + some
+ reasonable parameters to be called with. Utility methods (parent(),
+ children(), ...) are excluded.
+
+ >>> ns = process_namespace(psutil.Process())
+ >>> for fun, name in ns.iter(ns.getters):
+ ... fun()
+ """
+ utils = [
+ ('cpu_percent', (), {}),
+ ('memory_percent', (), {}),
+ ]
+
+ ignored = [
+ ('as_dict', (), {}),
+ ('children', (), {'recursive': True}),
+ ('is_running', (), {}),
+ ('memory_info_ex', (), {}),
+ ('oneshot', (), {}),
+ ('parent', (), {}),
+ ('parents', (), {}),
+ ('pid', (), {}),
+ ('wait', (0, ), {}),
+ ]
+
+ getters = [
+ ('cmdline', (), {}),
+ ('connections', (), {'kind': 'all'}),
+ ('cpu_times', (), {}),
+ ('create_time', (), {}),
+ ('cwd', (), {}),
+ ('exe', (), {}),
+ ('memory_full_info', (), {}),
+ ('memory_info', (), {}),
+ ('name', (), {}),
+ ('nice', (), {}),
+ ('num_ctx_switches', (), {}),
+ ('num_threads', (), {}),
+ ('open_files', (), {}),
+ ('ppid', (), {}),
+ ('status', (), {}),
+ ('threads', (), {}),
+ ('username', (), {}),
+ ]
+ if POSIX:
+ getters += [('uids', (), {})]
+ getters += [('gids', (), {})]
+ getters += [('terminal', (), {})]
+ getters += [('num_fds', (), {})]
+ if HAS_PROC_IO_COUNTERS:
+ getters += [('io_counters', (), {})]
+ if HAS_IONICE:
+ getters += [('ionice', (), {})]
+ if HAS_RLIMIT:
+ getters += [('rlimit', (psutil.RLIMIT_NOFILE, ), {})]
+ if HAS_CPU_AFFINITY:
+ getters += [('cpu_affinity', (), {})]
+ if HAS_PROC_CPU_NUM:
+ getters += [('cpu_num', (), {})]
+ if HAS_ENVIRON:
+ getters += [('environ', (), {})]
+ if WINDOWS:
+ getters += [('num_handles', (), {})]
+ if HAS_MEMORY_MAPS:
+ getters += [('memory_maps', (), {'grouped': False})]
+
+ setters = []
+ if POSIX:
+ setters += [('nice', (0, ), {})]
+ else:
+ setters += [('nice', (psutil.NORMAL_PRIORITY_CLASS, ), {})]
+ if HAS_RLIMIT:
+ setters += [('rlimit', (psutil.RLIMIT_NOFILE, (1024, 4096)), {})]
+ if HAS_IONICE:
+ if LINUX:
+ setters += [('ionice', (psutil.IOPRIO_CLASS_NONE, 0), {})]
+ else:
+ setters += [('ionice', (psutil.IOPRIO_NORMAL, ), {})]
+ if HAS_CPU_AFFINITY:
+ setters += [('cpu_affinity', ([_get_eligible_cpu()], ), {})]
+
+ killers = [
+ ('send_signal', (signal.SIGTERM, ), {}),
+ ('suspend', (), {}),
+ ('resume', (), {}),
+ ('terminate', (), {}),
+ ('kill', (), {}),
+ ]
+ if WINDOWS:
+ killers += [('send_signal', (signal.CTRL_C_EVENT, ), {})]
+ killers += [('send_signal', (signal.CTRL_BREAK_EVENT, ), {})]
+
+ all = utils + getters + setters + killers
+
+ def __init__(self, proc):
+ self._proc = proc
+
+ def iter(self, ls, clear_cache=True):
+ """Given a list of tuples yields a set of (fun, fun_name) tuples
+ in random order.
+ """
+ ls = list(ls)
+ random.shuffle(ls)
+ for fun_name, args, kwds in ls:
+ if clear_cache:
+ self.clear_cache()
+ fun = getattr(self._proc, fun_name)
+ fun = functools.partial(fun, *args, **kwds)
+ yield (fun, fun_name)
+
+ def clear_cache(self):
+ """Clear the cache of a Process instance."""
+ self._proc._init(self._proc.pid, _ignore_nsp=True)
+
+ @classmethod
+ def test_class_coverage(cls, test_class, ls):
+ """Given a TestCase instance and a list of tuples checks that
+ the class defines the required test method names.
+ """
+ for fun_name, _, _ in ls:
+ meth_name = 'test_' + fun_name
+ if not hasattr(test_class, meth_name):
+ msg = "%r class should define a '%s' method" % (
+ test_class.__class__.__name__, meth_name)
+ raise AttributeError(msg)
+
+ @classmethod
+ def test(cls):
+ this = set([x[0] for x in cls.all])
+ ignored = set([x[0] for x in cls.ignored])
+ klass = set([x for x in dir(psutil.Process) if x[0] != '_'])
+ leftout = (this | ignored) ^ klass
+ if leftout:
+ raise ValueError("uncovered Process class names: %r" % leftout)
+
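+# process_namespace drives "fuzz"-style tests which call every applicable
+# Process method and only check that no unexpected exception leaks out, e.g.
+# against a terminated or zombie process (illustrative sketch):
+#
+#     ns = process_namespace(psutil.Process())
+#     for fun, name in ns.iter(ns.getters):
+#         try:
+#             fun()
+#         except (psutil.AccessDenied, psutil.ZombieProcess):
+#             pass  # acceptable for restricted or dead processes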
+
+class system_namespace:
+ """A container that lists all the module-level, system-related APIs.
+ Utilities such as cpu_percent() are excluded. Usage:
+
+ >>> ns = system_namespace
+ >>> for fun, name in ns.iter(ns.getters):
+ ... fun()
+ """
+ getters = [
+ ('boot_time', (), {}),
+ ('cpu_count', (), {'logical': False}),
+ ('cpu_count', (), {'logical': True}),
+ ('cpu_stats', (), {}),
+ ('cpu_times', (), {'percpu': False}),
+ ('cpu_times', (), {'percpu': True}),
+ ('disk_io_counters', (), {'perdisk': True}),
+ ('disk_partitions', (), {'all': True}),
+ ('disk_usage', (os.getcwd(), ), {}),
+ ('net_connections', (), {'kind': 'all'}),
+ ('net_if_addrs', (), {}),
+ ('net_if_stats', (), {}),
+ ('net_io_counters', (), {'pernic': True}),
+ ('pid_exists', (os.getpid(), ), {}),
+ ('pids', (), {}),
+ ('swap_memory', (), {}),
+ ('users', (), {}),
+ ('virtual_memory', (), {}),
+ ]
+ if HAS_CPU_FREQ:
+ getters += [('cpu_freq', (), {'percpu': True})]
+ if HAS_GETLOADAVG:
+ getters += [('getloadavg', (), {})]
+ if HAS_SENSORS_TEMPERATURES:
+ getters += [('sensors_temperatures', (), {})]
+ if HAS_SENSORS_FANS:
+ getters += [('sensors_fans', (), {})]
+ if HAS_SENSORS_BATTERY:
+ getters += [('sensors_battery', (), {})]
+ if WINDOWS:
+ getters += [('win_service_iter', (), {})]
+ getters += [('win_service_get', ('alg', ), {})]
+
+ ignored = [
+ ('process_iter', (), {}),
+ ('wait_procs', ([psutil.Process()], ), {}),
+ ('cpu_percent', (), {}),
+ ('cpu_times_percent', (), {}),
+ ]
+
+ all = getters
+
+ @staticmethod
+ def iter(ls):
+ """Given a list of tuples yields a set of (fun, fun_name) tuples
+ in random order.
+ """
+ ls = list(ls)
+ random.shuffle(ls)
+ for fun_name, args, kwds in ls:
+ fun = getattr(psutil, fun_name)
+ fun = functools.partial(fun, *args, **kwds)
+ yield (fun, fun_name)
+
+ test_class_coverage = process_namespace.test_class_coverage
+
+
+def serialrun(klass):
+ """A decorator to mark a TestCase class. When running parallel tests,
+ class' unit tests will be run serially (1 process).
+ """
+ # assert issubclass(klass, unittest.TestCase), klass
+ assert inspect.isclass(klass), klass
+ klass._serialrun = True
+ return klass
+
+
+def retry_on_failure(retries=NO_RETRIES):
+ """Decorator which runs a test function and retries N times before
+ actually failing.
+ """
+ def logfun(exc):
+ print("%r, retrying" % exc, file=sys.stderr) # NOQA
+
+ return retry(exception=AssertionError, timeout=None, retries=retries,
+ logfun=logfun)
+
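+# retry_on_failure() typically wraps tests comparing two values sampled at
+# slightly different times, where a transient mismatch is not a real bug
+# (illustrative; free_memory_from_vmstat() is an assumed helper):
+#
+#     @retry_on_failure()
+#     def test_free_memory(self):
+#         self.assertAlmostEqual(psutil.virtual_memory().free,
+#                                free_memory_from_vmstat(),
+#                                delta=TOLERANCE_SYS_MEM)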
+
+def skip_on_access_denied(only_if=None):
+ """Decorator to Ignore AccessDenied exceptions."""
+ def decorator(fun):
+ @functools.wraps(fun)
+ def wrapper(*args, **kwargs):
+ try:
+ return fun(*args, **kwargs)
+ except psutil.AccessDenied:
+ if only_if is not None:
+ if not only_if:
+ raise
+ raise unittest.SkipTest("raises AccessDenied")
+ return wrapper
+ return decorator
+
+
+def skip_on_not_implemented(only_if=None):
+ """Decorator to Ignore NotImplementedError exceptions."""
+ def decorator(fun):
+ @functools.wraps(fun)
+ def wrapper(*args, **kwargs):
+ try:
+ return fun(*args, **kwargs)
+ except NotImplementedError:
+ if only_if is not None:
+ if not only_if:
+ raise
+ msg = "%r was skipped because it raised NotImplementedError" \
+ % fun.__name__
+ raise unittest.SkipTest(msg)
+ return wrapper
+ return decorator
+
+
+# ===================================================================
+# --- network
+# ===================================================================
+
+
+# XXX: no longer used
+def get_free_port(host='127.0.0.1'):
+ """Return an unused TCP port. Subject to race conditions."""
+ with contextlib.closing(socket.socket()) as sock:
+ sock.bind((host, 0))
+ return sock.getsockname()[1]
+
+
+def bind_socket(family=AF_INET, type=SOCK_STREAM, addr=None):
+ """Binds a generic socket."""
+ if addr is None and family in (AF_INET, AF_INET6):
+ addr = ("", 0)
+ sock = socket.socket(family, type)
+ try:
+ if os.name not in ('nt', 'cygwin'):
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(addr)
+ if type == socket.SOCK_STREAM:
+ sock.listen(5)
+ return sock
+ except Exception:
+ sock.close()
+ raise
+
+
+def bind_unix_socket(name, type=socket.SOCK_STREAM):
+ """Bind a UNIX socket."""
+ assert psutil.POSIX
+ assert not os.path.exists(name), name
+ sock = socket.socket(socket.AF_UNIX, type)
+ try:
+ sock.bind(name)
+ if type == socket.SOCK_STREAM:
+ sock.listen(5)
+ except Exception:
+ sock.close()
+ raise
+ return sock
+
+
+def tcp_socketpair(family, addr=("", 0)):
+ """Build a pair of TCP sockets connected to each other.
+ Return a (server, client) tuple.
+ """
+ with contextlib.closing(socket.socket(family, SOCK_STREAM)) as ll:
+ ll.bind(addr)
+ ll.listen(5)
+ addr = ll.getsockname()
+ c = socket.socket(family, SOCK_STREAM)
+ try:
+ c.connect(addr)
+ caddr = c.getsockname()
+ while True:
+ a, addr = ll.accept()
+ # check that we've got the correct client
+ if addr == caddr:
+ return (a, c)
+ a.close()
+ except OSError:
+ c.close()
+ raise
+
+
+def unix_socketpair(name):
+ """Build a pair of UNIX sockets connected to each other through
+ the same UNIX file name.
+ Return a (server, client) tuple.
+ """
+ assert psutil.POSIX
+ server = client = None
+ try:
+ server = bind_unix_socket(name, type=socket.SOCK_STREAM)
+ server.setblocking(0)
+ client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ client.setblocking(0)
+ client.connect(name)
+ # new = server.accept()
+ except Exception:
+ if server is not None:
+ server.close()
+ if client is not None:
+ client.close()
+ raise
+ return (server, client)
+
+
+@contextlib.contextmanager
+def create_sockets():
+ """Open as many socket families / types as possible."""
+ socks = []
+ fname1 = fname2 = None
+ try:
+ socks.append(bind_socket(socket.AF_INET, socket.SOCK_STREAM))
+ socks.append(bind_socket(socket.AF_INET, socket.SOCK_DGRAM))
+ if supports_ipv6():
+ socks.append(bind_socket(socket.AF_INET6, socket.SOCK_STREAM))
+ socks.append(bind_socket(socket.AF_INET6, socket.SOCK_DGRAM))
+ if POSIX and HAS_CONNECTIONS_UNIX:
+ fname1 = get_testfn()
+ fname2 = get_testfn()
+ s1, s2 = unix_socketpair(fname1)
+ s3 = bind_unix_socket(fname2, type=socket.SOCK_DGRAM)
+ for s in (s1, s2, s3):
+ socks.append(s)
+ yield socks
+ finally:
+ for s in socks:
+ s.close()
+ for fname in (fname1, fname2):
+ if fname is not None:
+ safe_rmpath(fname)
+
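+# create_sockets() pairs naturally with the check_* helpers below: open one
+# socket per supported family / type, then validate whatever psutil reports
+# for this process (illustrative sketch):
+#
+#     with create_sockets():
+#         for conn in psutil.Process().connections(kind='all'):
+#             check_connection_ntuple(conn)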
+
+def check_net_address(addr, family):
+ """Check a net address validity. Supported families are IPv4,
+ IPv6 and MAC addresses.
+ """
+ import ipaddress # python >= 3.3 / requires "pip install ipaddress"
+ if enum and PY3 and not PYPY:
+ assert isinstance(family, enum.IntEnum), family
+ if family == socket.AF_INET:
+ octs = [int(x) for x in addr.split('.')]
+ assert len(octs) == 4, addr
+ for num in octs:
+ assert 0 <= num <= 255, addr
+ if not PY3:
+ addr = unicode(addr)
+ ipaddress.IPv4Address(addr)
+ elif family == socket.AF_INET6:
+ assert isinstance(addr, str), addr
+ if not PY3:
+ addr = unicode(addr)
+ ipaddress.IPv6Address(addr)
+ elif family == psutil.AF_LINK:
+ assert re.match(r'([a-fA-F0-9]{2}[:|\-]?){6}', addr) is not None, addr
+ else:
+ raise ValueError("unknown family %r", family)
+
+
+def check_connection_ntuple(conn):
+ """Check validity of a connection namedtuple."""
+ def check_ntuple(conn):
+ has_pid = len(conn) == 7
+ assert len(conn) in (6, 7), len(conn)
+ assert conn[0] == conn.fd, conn.fd
+ assert conn[1] == conn.family, conn.family
+ assert conn[2] == conn.type, conn.type
+ assert conn[3] == conn.laddr, conn.laddr
+ assert conn[4] == conn.raddr, conn.raddr
+ assert conn[5] == conn.status, conn.status
+ if has_pid:
+ assert conn[6] == conn.pid, conn.pid
+
+ def check_family(conn):
+ assert conn.family in (AF_INET, AF_INET6, AF_UNIX), conn.family
+ if enum is not None:
+ assert isinstance(conn.family, enum.IntEnum), conn
+ else:
+ assert isinstance(conn.family, int), conn
+ if conn.family == AF_INET:
+ # actually try to bind the local socket; ignore IPv6
+ # sockets as their address might be represented as
+ # an IPv4-mapped-address (e.g. "::127.0.0.1")
+ # and that's rejected by bind()
+ s = socket.socket(conn.family, conn.type)
+ with contextlib.closing(s):
+ try:
+ s.bind((conn.laddr[0], 0))
+ except socket.error as err:
+ if err.errno != errno.EADDRNOTAVAIL:
+ raise
+ elif conn.family == AF_UNIX:
+ assert conn.status == psutil.CONN_NONE, conn.status
+
+ def check_type(conn):
+ # SOCK_SEQPACKET may happen in case of AF_UNIX socks
+ SOCK_SEQPACKET = getattr(socket, "SOCK_SEQPACKET", object())
+ assert conn.type in (socket.SOCK_STREAM, socket.SOCK_DGRAM,
+ SOCK_SEQPACKET), conn.type
+ if enum is not None:
+ assert isinstance(conn.type, enum.IntEnum), conn
+ else:
+ assert isinstance(conn.type, int), conn
+ if conn.type == socket.SOCK_DGRAM:
+ assert conn.status == psutil.CONN_NONE, conn.status
+
+ def check_addrs(conn):
+ # check IP address and port sanity
+ for addr in (conn.laddr, conn.raddr):
+ if conn.family in (AF_INET, AF_INET6):
+ assert isinstance(addr, tuple), type(addr)
+ if not addr:
+ continue
+ assert isinstance(addr.port, int), type(addr.port)
+ assert 0 <= addr.port <= 65535, addr.port
+ check_net_address(addr.ip, conn.family)
+ elif conn.family == AF_UNIX:
+ assert isinstance(addr, str), type(addr)
+
+ def check_status(conn):
+ assert isinstance(conn.status, str), conn.status
+ valids = [getattr(psutil, x) for x in dir(psutil)
+ if x.startswith('CONN_')]
+ assert conn.status in valids, conn.status
+ if conn.family in (AF_INET, AF_INET6) and conn.type == SOCK_STREAM:
+ assert conn.status != psutil.CONN_NONE, conn.status
+ else:
+ assert conn.status == psutil.CONN_NONE, conn.status
+
+ check_ntuple(conn)
+ check_family(conn)
+ check_type(conn)
+ check_addrs(conn)
+ check_status(conn)
+
+
+# ===================================================================
+# --- compatibility
+# ===================================================================
+
+
+def reload_module(module):
+ """Backport of importlib.reload of Python 3.3+."""
+ try:
+ import importlib
+ if not hasattr(importlib, 'reload'): # python <=3.3
+ raise ImportError
+ except ImportError:
+ import imp
+ return imp.reload(module)
+ else:
+ return importlib.reload(module)
+
+
+def import_module_by_path(path):
+ name = os.path.splitext(os.path.basename(path))[0]
+ if sys.version_info[0] == 2:
+ import imp
+ return imp.load_source(name, path)
+ elif sys.version_info[:2] <= (3, 4):
+ from importlib.machinery import SourceFileLoader
+ return SourceFileLoader(name, path).load_module()
+ else:
+ import importlib.util
+ spec = importlib.util.spec_from_file_location(name, path)
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ return mod
+
+
+# ===================================================================
+# --- others
+# ===================================================================
+
+
+def warn(msg):
+ """Raise a warning msg."""
+ warnings.warn(msg, UserWarning)
+
+
+def is_namedtuple(x):
+ """Check if object is an instance of namedtuple."""
+ t = type(x)
+ b = t.__bases__
+ if len(b) != 1 or b[0] != tuple:
+ return False
+ f = getattr(t, '_fields', None)
+ if not isinstance(f, tuple):
+ return False
+ return all(type(n) == str for n in f)
+
+
+if POSIX:
+ @contextlib.contextmanager
+ def copyload_shared_lib(suffix=""):
+ """Ctx manager which picks up a random shared CO lib used
+ by this process, copies it in another location and loads it
+ in memory via ctypes. Return the new absolutized path.
+ """
+ exe = 'pypy' if PYPY else 'python'
+ ext = ".so"
+ dst = get_testfn(suffix=suffix + ext)
+ libs = [x.path for x in psutil.Process().memory_maps() if
+ os.path.splitext(x.path)[1] == ext and
+ exe in x.path.lower()]
+ src = random.choice(libs)
+ shutil.copyfile(src, dst)
+ try:
+ ctypes.CDLL(dst)
+ yield dst
+ finally:
+ safe_rmpath(dst)
+else:
+ @contextlib.contextmanager
+ def copyload_shared_lib(suffix=""):
+ """Ctx manager which picks up a random shared DLL lib used
+ by this process, copies it in another location and loads it
+ in memory via ctypes.
+ Return the new absolutized, normcased path.
+ """
+ from ctypes import WinError
+ from ctypes import wintypes
+ ext = ".dll"
+ dst = get_testfn(suffix=suffix + ext)
+ libs = [x.path for x in psutil.Process().memory_maps() if
+ x.path.lower().endswith(ext) and
+ 'python' in os.path.basename(x.path).lower() and
+ 'wow64' not in x.path.lower()]
+ if PYPY and not libs:
+ libs = [x.path for x in psutil.Process().memory_maps() if
+ 'pypy' in os.path.basename(x.path).lower()]
+ src = random.choice(libs)
+ shutil.copyfile(src, dst)
+ cfile = None
+ try:
+ cfile = ctypes.WinDLL(dst)
+ yield dst
+ finally:
+ # Work around OverflowError:
+ # - https://ci.appveyor.com/project/giampaolo/psutil/build/1207/
+ # job/o53330pbnri9bcw7
+ # - http://bugs.python.org/issue30286
+ # - http://stackoverflow.com/questions/23522055
+ if cfile is not None:
+ FreeLibrary = ctypes.windll.kernel32.FreeLibrary
+ FreeLibrary.argtypes = [wintypes.HMODULE]
+ ret = FreeLibrary(cfile._handle)
+ if ret == 0:
+ WinError()
+ safe_rmpath(dst)
+
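+# copyload_shared_lib() gives tests a shared library which is guaranteed to
+# appear in this process' memory maps, e.g. (illustrative sketch):
+#
+#     with copyload_shared_lib() as path:
+#         mapped = [os.path.normcase(x.path)
+#                   for x in psutil.Process().memory_maps()]
+#         assert os.path.normcase(path) in mapped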
+
+# ===================================================================
+# --- Exit funs (first is executed last)
+# ===================================================================
+
+
+# this is executed first
+@atexit.register
+def cleanup_test_procs():
+ reap_children(recursive=True)
+
+
+# atexit module does not execute exit functions in case of SIGTERM, which
+# gets sent to test subprocesses, which is a problem if they import this
+# module. With this it will. See:
+# https://gmpy.dev/blog/2016/how-to-always-execute-exit-functions-in-python
+if POSIX:
+ signal.signal(signal.SIGTERM, lambda sig, frame: sys.exit(sig))
diff --git a/lib/psutil/tests/__main__.py b/lib/psutil/tests/__main__.py
new file mode 100644
index 0000000..e677352
--- /dev/null
+++ b/lib/psutil/tests/__main__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Run unit tests. This is invoked by:
+$ python -m psutil.tests
+"""
+
+from .runner import main
+
+
+main()
diff --git a/lib/psutil/tests/runner.py b/lib/psutil/tests/runner.py
new file mode 100644
index 0000000..2e6f83e
--- /dev/null
+++ b/lib/psutil/tests/runner.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Unit test runner, providing new features on top of the unittest module:
+- colourized output
+- parallel run (UNIX only)
+- print failures/tracebacks on CTRL+C
+- re-run failed tests only (make test-failed)
+
+Invocation examples:
+- make test
+- make test-failed
+
+Parallel:
+- make test-parallel
+- make test-process ARGS=--parallel
+"""
+
+from __future__ import print_function
+
+import atexit
+import optparse
+import os
+import sys
+import textwrap
+import time
+import unittest
+
+
+try:
+ import ctypes
+except ImportError:
+ ctypes = None
+
+try:
+ import concurrencytest # pip install concurrencytest
+except ImportError:
+ concurrencytest = None
+
+import psutil
+from psutil._common import hilite
+from psutil._common import print_color
+from psutil._common import term_supports_colors
+from psutil._compat import super
+from psutil.tests import CI_TESTING
+from psutil.tests import import_module_by_path
+from psutil.tests import print_sysinfo
+from psutil.tests import reap_children
+from psutil.tests import safe_rmpath
+
+
+VERBOSITY = 2
+FAILED_TESTS_FNAME = '.failed-tests.txt'
+NWORKERS = psutil.cpu_count() or 1
+USE_COLORS = not CI_TESTING and term_supports_colors()
+
+HERE = os.path.abspath(os.path.dirname(__file__))
+loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase
+
+
+def cprint(msg, color, bold=False, file=None):
+ if file is None:
+ file = sys.stderr if color == 'red' else sys.stdout
+ if USE_COLORS:
+ print_color(msg, color, bold=bold, file=file)
+ else:
+ print(msg, file=file)
+
+
+class TestLoader:
+
+ testdir = HERE
+ skip_files = ['test_memleaks.py']
+ if "WHEELHOUSE_UPLOADER_USERNAME" in os.environ:
+ skip_files.extend(['test_osx.py', 'test_linux.py', 'test_posix.py'])
+
+ def _get_testmods(self):
+ return [os.path.join(self.testdir, x)
+ for x in os.listdir(self.testdir)
+ if x.startswith('test_') and x.endswith('.py') and
+ x not in self.skip_files]
+
+ def _iter_testmod_classes(self):
+ """Iterate over all test files in this directory and return
+ all TestCase classes in them.
+ """
+ for path in self._get_testmods():
+ mod = import_module_by_path(path)
+ for name in dir(mod):
+ obj = getattr(mod, name)
+ if isinstance(obj, type) and \
+ issubclass(obj, unittest.TestCase):
+ yield obj
+
+ def all(self):
+ suite = unittest.TestSuite()
+ for obj in self._iter_testmod_classes():
+ test = loadTestsFromTestCase(obj)
+ suite.addTest(test)
+ return suite
+
+ def last_failed(self):
+ # ...from previously failed test run
+ suite = unittest.TestSuite()
+ if not os.path.isfile(FAILED_TESTS_FNAME):
+ return suite
+ with open(FAILED_TESTS_FNAME, 'rt') as f:
+ names = f.read().split()
+ for n in names:
+ test = unittest.defaultTestLoader.loadTestsFromName(n)
+ suite.addTest(test)
+ return suite
+
+ def from_name(self, name):
+ if name.endswith('.py'):
+ name = os.path.splitext(os.path.basename(name))[0]
+ return unittest.defaultTestLoader.loadTestsFromName(name)
+
+
+class ColouredResult(unittest.TextTestResult):
+
+ def addSuccess(self, test):
+ unittest.TestResult.addSuccess(self, test)
+ cprint("OK", "green")
+
+ def addError(self, test, err):
+ unittest.TestResult.addError(self, test, err)
+ cprint("ERROR", "red", bold=True)
+
+ def addFailure(self, test, err):
+ unittest.TestResult.addFailure(self, test, err)
+ cprint("FAIL", "red")
+
+ def addSkip(self, test, reason):
+ unittest.TestResult.addSkip(self, test, reason)
+ cprint("skipped: %s" % reason.strip(), "brown")
+
+ def printErrorList(self, flavour, errors):
+ flavour = hilite(flavour, "red", bold=flavour == 'ERROR')
+ super().printErrorList(flavour, errors)
+
+
+class ColouredTextRunner(unittest.TextTestRunner):
+ """
+ A coloured text runner which also prints failed tests on KeyboardInterrupt
+ and saves them to a file so that they can be re-run.
+ """
+ resultclass = ColouredResult if USE_COLORS else unittest.TextTestResult
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.failed_tnames = set()
+
+ def _makeResult(self):
+ # Store result instance so that it can be accessed on
+ # KeyboardInterrupt.
+ self.result = super()._makeResult()
+ return self.result
+
+ def _write_last_failed(self):
+ if self.failed_tnames:
+ with open(FAILED_TESTS_FNAME, 'wt') as f:
+ for tname in self.failed_tnames:
+ f.write(tname + '\n')
+
+ def _save_result(self, result):
+ if not result.wasSuccessful():
+ for t in result.errors + result.failures:
+ tname = t[0].id()
+ self.failed_tnames.add(tname)
+
+ def _run(self, suite):
+ try:
+ result = super().run(suite)
+ except (KeyboardInterrupt, SystemExit):
+ result = self.result
+ result.printErrors()
+ raise sys.exit(1)
+ else:
+ self._save_result(result)
+ return result
+
+ def _exit(self, success):
+ if success:
+ cprint("SUCCESS", "green", bold=True)
+ safe_rmpath(FAILED_TESTS_FNAME)
+ sys.exit(0)
+ else:
+ cprint("FAILED", "red", bold=True)
+ self._write_last_failed()
+ sys.exit(1)
+
+ def run(self, suite):
+ result = self._run(suite)
+ self._exit(result.wasSuccessful())
+
+
+class ParallelRunner(ColouredTextRunner):
+
+ @staticmethod
+ def _parallelize(suite):
+ def fdopen(fd, mode, *kwds):
+ stream = orig_fdopen(fd, mode)
+ atexit.register(stream.close)
+ return stream
+
+ # Monkey patch concurrencytest lib bug (fdopen() stream not closed).
+ # https://github.com/cgoldberg/concurrencytest/issues/11
+ orig_fdopen = os.fdopen
+ concurrencytest.os.fdopen = fdopen
+ forker = concurrencytest.fork_for_tests(NWORKERS)
+ return concurrencytest.ConcurrentTestSuite(suite, forker)
+
+ @staticmethod
+ def _split_suite(suite):
+ serial = unittest.TestSuite()
+ parallel = unittest.TestSuite()
+ for test in suite:
+ if test.countTestCases() == 0:
+ continue
+ elif isinstance(test, unittest.TestSuite):
+ test_class = test._tests[0].__class__
+ elif isinstance(test, unittest.TestCase):
+ test_class = test
+ else:
+ raise TypeError("can't recognize type %r" % test)
+
+ if getattr(test_class, '_serialrun', False):
+ serial.addTest(test)
+ else:
+ parallel.addTest(test)
+ return (serial, parallel)
+
+ def run(self, suite):
+ ser_suite, par_suite = self._split_suite(suite)
+ par_suite = self._parallelize(par_suite)
+
+ # run parallel
+ cprint("starting parallel tests using %s workers" % NWORKERS,
+ "green", bold=True)
+ t = time.time()
+ par = self._run(par_suite)
+ par_elapsed = time.time() - t
+
+ # At this point we should have N zombies (the workers), which
+ # will disappear with wait().
+ orphans = psutil.Process().children()
+ gone, alive = psutil.wait_procs(orphans, timeout=1)
+ if alive:
+ cprint("alive processes %s" % alive, "red")
+ reap_children()
+
+ # run serial
+ t = time.time()
+ ser = self._run(ser_suite)
+ ser_elapsed = time.time() - t
+
+ # print
+ if not par.wasSuccessful() and ser_suite.countTestCases() > 0:
+ par.printErrors() # print them again at the bottom
+ par_fails, par_errs, par_skips = map(len, (par.failures,
+ par.errors,
+ par.skipped))
+ ser_fails, ser_errs, ser_skips = map(len, (ser.failures,
+ ser.errors,
+ ser.skipped))
+ print(textwrap.dedent("""
+ +----------+----------+----------+----------+----------+----------+
+ | | total | failures | errors | skipped | time |
+ +----------+----------+----------+----------+----------+----------+
+ | parallel | %3s | %3s | %3s | %3s | %.2fs |
+ +----------+----------+----------+----------+----------+----------+
+ | serial | %3s | %3s | %3s | %3s | %.2fs |
+ +----------+----------+----------+----------+----------+----------+
+ """ % (par.testsRun, par_fails, par_errs, par_skips, par_elapsed,
+ ser.testsRun, ser_fails, ser_errs, ser_skips, ser_elapsed)))
+ print("Ran %s tests in %.3fs using %s workers" % (
+ par.testsRun + ser.testsRun, par_elapsed + ser_elapsed, NWORKERS))
+ ok = par.wasSuccessful() and ser.wasSuccessful()
+ self._exit(ok)
+
+
+def get_runner(parallel=False):
+ def warn(msg):
+ cprint(msg + " Running serial tests instead.", "red")
+ if parallel:
+ if psutil.WINDOWS:
+ warn("Can't run parallel tests on Windows.")
+ elif concurrencytest is None:
+ warn("concurrencytest module is not installed.")
+ elif NWORKERS == 1:
+ warn("Only 1 CPU available.")
+ else:
+ return ParallelRunner(verbosity=VERBOSITY)
+ return ColouredTextRunner(verbosity=VERBOSITY)
+
+
+# Used by test_*.py modules.
+def run_from_name(name):
+ if CI_TESTING:
+ print_sysinfo()
+ suite = TestLoader().from_name(name)
+ runner = get_runner()
+ runner.run(suite)
+
+
+def setup():
+ psutil._set_debug(True)
+
+
+def main():
+ setup()
+ usage = "python3 -m psutil.tests [opts] [test-name]"
+ parser = optparse.OptionParser(usage=usage, description="run unit tests")
+ parser.add_option("--last-failed",
+ action="store_true", default=False,
+ help="only run last failed tests")
+ parser.add_option("--parallel",
+ action="store_true", default=False,
+ help="run tests in parallel")
+ opts, args = parser.parse_args()
+
+ if not opts.last_failed:
+ safe_rmpath(FAILED_TESTS_FNAME)
+
+ # loader
+ loader = TestLoader()
+ if args:
+ if len(args) > 1:
+ parser.print_usage()
+ return sys.exit(1)
+ else:
+ suite = loader.from_name(args[0])
+ elif opts.last_failed:
+ suite = loader.last_failed()
+ else:
+ suite = loader.all()
+
+ if CI_TESTING:
+ print_sysinfo()
+ runner = get_runner(opts.parallel)
+ runner.run(suite)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/lib/psutil/tests/test_aix.py b/lib/psutil/tests/test_aix.py
new file mode 100644
index 0000000..4a23b77
--- /dev/null
+++ b/lib/psutil/tests/test_aix.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'
+# Copyright (c) 2017, Arnon Yaari
+# All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""AIX specific tests."""
+
+import re
+import unittest
+
+import psutil
+from psutil import AIX
+from psutil.tests import PsutilTestCase
+from psutil.tests import sh
+
+
+@unittest.skipIf(not AIX, "AIX only")
+class AIXSpecificTestCase(PsutilTestCase):
+
+ def test_virtual_memory(self):
+ out = sh('/usr/bin/svmon -O unit=KB')
+ re_pattern = r"memory\s*"
+ for field in ("size inuse free pin virtual available mmode").split():
+ re_pattern += r"(?P<%s>\S+)\s+" % (field,)
+ matchobj = re.search(re_pattern, out)
+
+ self.assertIsNotNone(
+ matchobj, "svmon command returned unexpected output")
+
+ KB = 1024
+ total = int(matchobj.group("size")) * KB
+ available = int(matchobj.group("available")) * KB
+ used = int(matchobj.group("inuse")) * KB
+ free = int(matchobj.group("free")) * KB
+
+ psutil_result = psutil.virtual_memory()
+
+ # TOLERANCE_SYS_MEM from psutil.tests is not enough. For some reason
+ # we're seeing differences of ~1.2 MB. 2 MB is still a good tolerance
+ # when compared to GBs.
+ TOLERANCE_SYS_MEM = 2 * KB * KB # 2 MB
+ self.assertEqual(psutil_result.total, total)
+ self.assertAlmostEqual(
+ psutil_result.used, used, delta=TOLERANCE_SYS_MEM)
+ self.assertAlmostEqual(
+ psutil_result.available, available, delta=TOLERANCE_SYS_MEM)
+ self.assertAlmostEqual(
+ psutil_result.free, free, delta=TOLERANCE_SYS_MEM)
+
+ def test_swap_memory(self):
+ out = sh('/usr/sbin/lsps -a')
+ # From the man page, "The size is given in megabytes" so we assume
+ # we'll always have 'MB' in the result
+ # TODO maybe try to use "swap -l" to check "used" too, but its units
+ # are not guaranteed to be "MB" so parsing may not be consistent
+ matchobj = re.search(r"(?P<space>\S+)\s+"
+ r"(?P<vol>\S+)\s+"
+ r"(?P<vg>\S+)\s+"
+ r"(?P<size>\d+)MB", out)
+
+ self.assertIsNotNone(
+ matchobj, "lsps command returned unexpected output")
+
+ total_mb = int(matchobj.group("size"))
+ MB = 1024 ** 2
+ psutil_result = psutil.swap_memory()
+ # we divide our result by MB instead of multiplying the lsps value by
+ # MB because lsps may round down, so we round down too
+ self.assertEqual(int(psutil_result.total / MB), total_mb)
+
+ def test_cpu_stats(self):
+ out = sh('/usr/bin/mpstat -a')
+
+ re_pattern = r"ALL\s*"
+ for field in ("min maj mpcs mpcr dev soft dec ph cs ics bound rq "
+ "push S3pull S3grd S0rd S1rd S2rd S3rd S4rd S5rd "
+ "sysc").split():
+ re_pattern += r"(?P<%s>\S+)\s+" % (field,)
+ matchobj = re.search(re_pattern, out)
+
+ self.assertIsNotNone(
+ matchobj, "mpstat command returned unexpected output")
+
+ # numbers are usually in the millions so 1000 is ok for tolerance
+ CPU_STATS_TOLERANCE = 1000
+ psutil_result = psutil.cpu_stats()
+ self.assertAlmostEqual(
+ psutil_result.ctx_switches,
+ int(matchobj.group("cs")),
+ delta=CPU_STATS_TOLERANCE)
+ self.assertAlmostEqual(
+ psutil_result.syscalls,
+ int(matchobj.group("sysc")),
+ delta=CPU_STATS_TOLERANCE)
+ self.assertAlmostEqual(
+ psutil_result.interrupts,
+ int(matchobj.group("dev")),
+ delta=CPU_STATS_TOLERANCE)
+ self.assertAlmostEqual(
+ psutil_result.soft_interrupts,
+ int(matchobj.group("soft")),
+ delta=CPU_STATS_TOLERANCE)
+
+ def test_cpu_count_logical(self):
+ out = sh('/usr/bin/mpstat -a')
+ mpstat_lcpu = int(re.search(r"lcpu=(\d+)", out).group(1))
+ psutil_lcpu = psutil.cpu_count(logical=True)
+ self.assertEqual(mpstat_lcpu, psutil_lcpu)
+
+ def test_net_if_addrs_names(self):
+ out = sh('/etc/ifconfig -l')
+ ifconfig_names = set(out.split())
+ psutil_names = set(psutil.net_if_addrs().keys())
+ self.assertSetEqual(ifconfig_names, psutil_names)
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_bsd.py b/lib/psutil/tests/test_bsd.py
new file mode 100644
index 0000000..e541547
--- /dev/null
+++ b/lib/psutil/tests/test_bsd.py
@@ -0,0 +1,568 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+# TODO: (FreeBSD) add test for comparing connections with 'sockstat' cmd.
+
+
+"""Tests specific to all BSD platforms."""
+
+
+import datetime
+import os
+import re
+import time
+import unittest
+
+import psutil
+from psutil import BSD
+from psutil import FREEBSD
+from psutil import NETBSD
+from psutil import OPENBSD
+from psutil.tests import HAS_BATTERY
+from psutil.tests import TOLERANCE_SYS_MEM
+from psutil.tests import PsutilTestCase
+from psutil.tests import retry_on_failure
+from psutil.tests import sh
+from psutil.tests import spawn_testproc
+from psutil.tests import terminate
+from psutil.tests import which
+
+
+if BSD:
+ from psutil._psutil_posix import getpagesize
+
+ PAGESIZE = getpagesize()
+ # muse requires root privileges
+ MUSE_AVAILABLE = True if os.getuid() == 0 and which('muse') else False
+else:
+ PAGESIZE = None
+ MUSE_AVAILABLE = False
+
+
+def sysctl(cmdline):
+ """Expects a sysctl command with an argument and parse the result
+ returning only the value of interest.
+ """
+ result = sh("sysctl " + cmdline)
+ if FREEBSD:
+ result = result[result.find(": ") + 2:]
+ elif OPENBSD or NETBSD:
+ result = result[result.find("=") + 1:]
+ try:
+ return int(result)
+ except ValueError:
+ return result
+
+
+def muse(field):
+ """Thin wrapper around 'muse' cmdline utility."""
+ out = sh('muse')
+ for line in out.split('\n'):
+ if line.startswith(field):
+ break
+ else:
+ raise ValueError("line not found")
+ return int(line.split()[1])
+
+
+# =====================================================================
+# --- All BSD*
+# =====================================================================
+
+
+@unittest.skipIf(not BSD, "BSD only")
+class BSDTestCase(PsutilTestCase):
+ """Generic tests common to all BSD variants."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.pid = spawn_testproc().pid
+
+ @classmethod
+ def tearDownClass(cls):
+ terminate(cls.pid)
+
+ @unittest.skipIf(NETBSD, "-o lstart doesn't work on NETBSD")
+ def test_process_create_time(self):
+ output = sh("ps -o lstart -p %s" % self.pid)
+ start_ps = output.replace('STARTED', '').strip()
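+ # start_ps is expected to look like 'Mon Jan  2 15:04:05 2023'
+ # (hypothetical date), i.e. the same layout produced by the
+ # strftime() format string used below.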
+ start_psutil = psutil.Process(self.pid).create_time()
+ start_psutil = time.strftime("%a %b %e %H:%M:%S %Y",
+ time.localtime(start_psutil))
+ self.assertEqual(start_ps, start_psutil)
+
+ def test_disks(self):
+ # test psutil.disk_usage() and psutil.disk_partitions()
+ # against "df -a"
+ def df(path):
+ out = sh('df -k "%s"' % path).strip()
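+ # 'df -k' output is assumed to be a header row followed by e.g.
+ # '/dev/ada0p2 100000 60000 40000 60% /' (hypothetical numbers,
+ # sizes in 1K blocks), hence the * 1024 conversions below.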
+ lines = out.split('\n')
+ lines.pop(0)
+ line = lines.pop(0)
+ dev, total, used, free = line.split()[:4]
+ if dev == 'none':
+ dev = ''
+ total = int(total) * 1024
+ used = int(used) * 1024
+ free = int(free) * 1024
+ return dev, total, used, free
+
+ for part in psutil.disk_partitions(all=False):
+ usage = psutil.disk_usage(part.mountpoint)
+ dev, total, used, free = df(part.mountpoint)
+ self.assertEqual(part.device, dev)
+ self.assertEqual(usage.total, total)
+ # 10 MB tolerance
+ if abs(usage.free - free) > 10 * 1024 * 1024:
+ self.fail("psutil=%s, df=%s" % (usage.free, free))
+ if abs(usage.used - used) > 10 * 1024 * 1024:
+ self.fail("psutil=%s, df=%s" % (usage.used, used))
+
+ @unittest.skipIf(not which('sysctl'), "sysctl cmd not available")
+ def test_cpu_count_logical(self):
+ syst = sysctl("hw.ncpu")
+ self.assertEqual(psutil.cpu_count(logical=True), syst)
+
+ @unittest.skipIf(not which('sysctl'), "sysctl cmd not available")
+ def test_virtual_memory_total(self):
+ num = sysctl('hw.physmem')
+ self.assertEqual(num, psutil.virtual_memory().total)
+
+ def test_net_if_stats(self):
+ for name, stats in psutil.net_if_stats().items():
+ try:
+ out = sh("ifconfig %s" % name)
+ except RuntimeError:
+ pass
+ else:
+ self.assertEqual(stats.isup, 'RUNNING' in out, msg=out)
+ if "mtu" in out:
+ self.assertEqual(stats.mtu,
+ int(re.findall(r'mtu (\d+)', out)[0]))
+
+
+# =====================================================================
+# --- FreeBSD
+# =====================================================================
+
+
+@unittest.skipIf(not FREEBSD, "FREEBSD only")
+class FreeBSDPsutilTestCase(PsutilTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.pid = spawn_testproc().pid
+
+ @classmethod
+ def tearDownClass(cls):
+ terminate(cls.pid)
+
+ @retry_on_failure()
+ def test_memory_maps(self):
+ out = sh('procstat -v %s' % self.pid)
+ maps = psutil.Process(self.pid).memory_maps(grouped=False)
+ lines = out.split('\n')[1:]
+ while lines:
+ line = lines.pop()
+ fields = line.split()
+ _, start, stop, perms, res = fields[:5]
+ map = maps.pop()
+ self.assertEqual("%s-%s" % (start, stop), map.addr)
+ self.assertEqual(int(res), map.rss)
+ if not map.path.startswith('['):
+ self.assertEqual(fields[10], map.path)
+
+ def test_exe(self):
+ out = sh('procstat -b %s' % self.pid)
+ self.assertEqual(psutil.Process(self.pid).exe(),
+ out.split('\n')[1].split()[-1])
+
+ def test_cmdline(self):
+ out = sh('procstat -c %s' % self.pid)
+ self.assertEqual(' '.join(psutil.Process(self.pid).cmdline()),
+ ' '.join(out.split('\n')[1].split()[2:]))
+
+ def test_uids_gids(self):
+ out = sh('procstat -s %s' % self.pid)
+ euid, ruid, suid, egid, rgid, sgid = out.split('\n')[1].split()[2:8]
+ p = psutil.Process(self.pid)
+ uids = p.uids()
+ gids = p.gids()
+ self.assertEqual(uids.real, int(ruid))
+ self.assertEqual(uids.effective, int(euid))
+ self.assertEqual(uids.saved, int(suid))
+ self.assertEqual(gids.real, int(rgid))
+ self.assertEqual(gids.effective, int(egid))
+ self.assertEqual(gids.saved, int(sgid))
+
+ @retry_on_failure()
+ def test_ctx_switches(self):
+ tested = []
+ out = sh('procstat -r %s' % self.pid)
+ p = psutil.Process(self.pid)
+ for line in out.split('\n'):
+ line = line.lower().strip()
+ if ' voluntary context' in line:
+ pstat_value = int(line.split()[-1])
+ psutil_value = p.num_ctx_switches().voluntary
+ self.assertEqual(pstat_value, psutil_value)
+ tested.append(None)
+ elif ' involuntary context' in line:
+ pstat_value = int(line.split()[-1])
+ psutil_value = p.num_ctx_switches().involuntary
+ self.assertEqual(pstat_value, psutil_value)
+ tested.append(None)
+ if len(tested) != 2:
+ raise RuntimeError("couldn't find lines match in procstat out")
+
+ @retry_on_failure()
+ def test_cpu_times(self):
+ tested = []
+ out = sh('procstat -r %s' % self.pid)
+ p = psutil.Process(self.pid)
+ for line in out.split('\n'):
+ line = line.lower().strip()
+ if 'user time' in line:
+ pstat_value = float('0.' + line.split()[-1].split('.')[-1])
+ psutil_value = p.cpu_times().user
+ self.assertEqual(pstat_value, psutil_value)
+ tested.append(None)
+ elif 'system time' in line:
+ pstat_value = float('0.' + line.split()[-1].split('.')[-1])
+ psutil_value = p.cpu_times().system
+ self.assertEqual(pstat_value, psutil_value)
+ tested.append(None)
+ if len(tested) != 2:
+ raise RuntimeError("couldn't find lines match in procstat out")
+
+
+@unittest.skipIf(not FREEBSD, "FREEBSD only")
+class FreeBSDSystemTestCase(PsutilTestCase):
+
+ @staticmethod
+ def parse_swapinfo():
+ # the last line is always the total
+ output = sh("swapinfo -k").splitlines()[-1]
+ parts = re.split(r'\s+', output)
+
+ if not parts:
+ raise ValueError("Can't parse swapinfo: %s" % output)
+
+ # the size is in 1k units, so multiply by 1024
+ total, used, free = (int(p) * 1024 for p in parts[1:4])
+ return total, used, free
+
+ def test_cpu_frequency_against_sysctl(self):
+ # Currently only cpu 0's frequency is supported on FreeBSD.
+ # All other cores use the same frequency.
+ sensor = "dev.cpu.0.freq"
+ try:
+ sysctl_result = int(sysctl(sensor))
+ except RuntimeError:
+ self.skipTest("frequencies not supported by kernel")
+ self.assertEqual(psutil.cpu_freq().current, sysctl_result)
+
+ sensor = "dev.cpu.0.freq_levels"
+ sysctl_result = sysctl(sensor)
+ # sysctl returns a string of the format:
+ # <freq_level_1>/<voltage_level_1> <freq_level_2>/<voltage_level_2>...
+ # Ordered highest available to lowest available.
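+ # e.g. (hypothetical): '2401/32000 2400/32000 1600/16000 600/8000'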
+ max_freq = int(sysctl_result.split()[0].split("/")[0])
+ min_freq = int(sysctl_result.split()[-1].split("/")[0])
+ self.assertEqual(psutil.cpu_freq().max, max_freq)
+ self.assertEqual(psutil.cpu_freq().min, min_freq)
+
+ # --- virtual_memory(); tests against sysctl
+
+ @retry_on_failure()
+ def test_vmem_active(self):
+ syst = sysctl("vm.stats.vm.v_active_count") * PAGESIZE
+ self.assertAlmostEqual(psutil.virtual_memory().active, syst,
+ delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_inactive(self):
+ syst = sysctl("vm.stats.vm.v_inactive_count") * PAGESIZE
+ self.assertAlmostEqual(psutil.virtual_memory().inactive, syst,
+ delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_wired(self):
+ syst = sysctl("vm.stats.vm.v_wire_count") * PAGESIZE
+ self.assertAlmostEqual(psutil.virtual_memory().wired, syst,
+ delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_cached(self):
+ syst = sysctl("vm.stats.vm.v_cache_count") * PAGESIZE
+ self.assertAlmostEqual(psutil.virtual_memory().cached, syst,
+ delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_free(self):
+ syst = sysctl("vm.stats.vm.v_free_count") * PAGESIZE
+ self.assertAlmostEqual(psutil.virtual_memory().free, syst,
+ delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_buffers(self):
+ syst = sysctl("vfs.bufspace")
+ self.assertAlmostEqual(psutil.virtual_memory().buffers, syst,
+ delta=TOLERANCE_SYS_MEM)
+
+ # --- virtual_memory(); tests against muse
+
+ @unittest.skipIf(not MUSE_AVAILABLE, "muse not installed")
+ def test_muse_vmem_total(self):
+ num = muse('Total')
+ self.assertEqual(psutil.virtual_memory().total, num)
+
+ @unittest.skipIf(not MUSE_AVAILABLE, "muse not installed")
+ @retry_on_failure()
+ def test_muse_vmem_active(self):
+ num = muse('Active')
+ self.assertAlmostEqual(psutil.virtual_memory().active, num,
+ delta=TOLERANCE_SYS_MEM)
+
+ @unittest.skipIf(not MUSE_AVAILABLE, "muse not installed")
+ @retry_on_failure()
+ def test_muse_vmem_inactive(self):
+ num = muse('Inactive')
+ self.assertAlmostEqual(psutil.virtual_memory().inactive, num,
+ delta=TOLERANCE_SYS_MEM)
+
+ @unittest.skipIf(not MUSE_AVAILABLE, "muse not installed")
+ @retry_on_failure()
+ def test_muse_vmem_wired(self):
+ num = muse('Wired')
+ self.assertAlmostEqual(psutil.virtual_memory().wired, num,
+ delta=TOLERANCE_SYS_MEM)
+
+ @unittest.skipIf(not MUSE_AVAILABLE, "muse not installed")
+ @retry_on_failure()
+ def test_muse_vmem_cached(self):
+ num = muse('Cache')
+ self.assertAlmostEqual(psutil.virtual_memory().cached, num,
+ delta=TOLERANCE_SYS_MEM)
+
+ @unittest.skipIf(not MUSE_AVAILABLE, "muse not installed")
+ @retry_on_failure()
+ def test_muse_vmem_free(self):
+ num = muse('Free')
+ self.assertAlmostEqual(psutil.virtual_memory().free, num,
+ delta=TOLERANCE_SYS_MEM)
+
+ @unittest.skipIf(not MUSE_AVAILABLE, "muse not installed")
+ @retry_on_failure()
+ def test_muse_vmem_buffers(self):
+ num = muse('Buffer')
+ self.assertAlmostEqual(psutil.virtual_memory().buffers, num,
+ delta=TOLERANCE_SYS_MEM)
+
+ def test_cpu_stats_ctx_switches(self):
+ self.assertAlmostEqual(psutil.cpu_stats().ctx_switches,
+ sysctl('vm.stats.sys.v_swtch'), delta=1000)
+
+ def test_cpu_stats_interrupts(self):
+ self.assertAlmostEqual(psutil.cpu_stats().interrupts,
+ sysctl('vm.stats.sys.v_intr'), delta=1000)
+
+ def test_cpu_stats_soft_interrupts(self):
+ self.assertAlmostEqual(psutil.cpu_stats().soft_interrupts,
+ sysctl('vm.stats.sys.v_soft'), delta=1000)
+
+ @retry_on_failure()
+ def test_cpu_stats_syscalls(self):
+ # pretty high tolerance but it looks like it's OK.
+ self.assertAlmostEqual(psutil.cpu_stats().syscalls,
+ sysctl('vm.stats.sys.v_syscall'), delta=200000)
+
+ # def test_cpu_stats_traps(self):
+ # self.assertAlmostEqual(psutil.cpu_stats().traps,
+ # sysctl('vm.stats.sys.v_trap'), delta=1000)
+
+ # --- swap memory
+
+ def test_swapmem_free(self):
+ total, used, free = self.parse_swapinfo()
+ self.assertAlmostEqual(
+ psutil.swap_memory().free, free, delta=TOLERANCE_SYS_MEM)
+
+ def test_swapmem_used(self):
+ total, used, free = self.parse_swapinfo()
+ self.assertAlmostEqual(
+ psutil.swap_memory().used, used, delta=TOLERANCE_SYS_MEM)
+
+ def test_swapmem_total(self):
+ total, used, free = self.parse_swapinfo()
+ self.assertAlmostEqual(
+ psutil.swap_memory().total, total, delta=TOLERANCE_SYS_MEM)
+
+ # --- others
+
+ def test_boot_time(self):
+ s = sysctl('kern.boottime')
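+ # kern.boottime is assumed to look something like
+ # '{ sec = 1666624768, usec = 123456 } Mon Oct 24 ...' (hypothetical);
+ # only the integer after ' sec = ' is extracted below.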
+ s = s[s.find(" sec = ") + 7:]
+ s = s[:s.find(',')]
+ btime = int(s)
+ self.assertEqual(btime, psutil.boot_time())
+
+ # --- sensors_battery
+
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_sensors_battery(self):
+ def secs2hours(secs):
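+ # e.g. secs2hours(5400) -> '1:30', assumed to match the hh:mm
+ # layout acpiconf reports for the remaining time.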
+ m, s = divmod(secs, 60)
+ h, m = divmod(m, 60)
+ return "%d:%02d" % (h, m)
+
+ out = sh("acpiconf -i 0")
+ fields = dict([(x.split('\t')[0], x.split('\t')[-1])
+ for x in out.split("\n")])
+ metrics = psutil.sensors_battery()
+ percent = int(fields['Remaining capacity:'].replace('%', ''))
+ remaining_time = fields['Remaining time:']
+ self.assertEqual(metrics.percent, percent)
+ if remaining_time == 'unknown':
+ self.assertEqual(metrics.secsleft, psutil.POWER_TIME_UNLIMITED)
+ else:
+ self.assertEqual(secs2hours(metrics.secsleft), remaining_time)
+
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_sensors_battery_against_sysctl(self):
+ self.assertEqual(psutil.sensors_battery().percent,
+ sysctl("hw.acpi.battery.life"))
+ self.assertEqual(psutil.sensors_battery().power_plugged,
+ sysctl("hw.acpi.acline") == 1)
+ secsleft = psutil.sensors_battery().secsleft
+ if secsleft < 0:
+ self.assertEqual(sysctl("hw.acpi.battery.time"), -1)
+ else:
+ self.assertEqual(secsleft, sysctl("hw.acpi.battery.time") * 60)
+
+ @unittest.skipIf(HAS_BATTERY, "has battery")
+ def test_sensors_battery_no_battery(self):
+ # If no battery is present, one of these calls is supposed
+ # to fail, see:
+ # https://github.com/giampaolo/psutil/issues/1074
+ with self.assertRaises(RuntimeError):
+ sysctl("hw.acpi.battery.life")
+ sysctl("hw.acpi.battery.time")
+ sysctl("hw.acpi.acline")
+ self.assertIsNone(psutil.sensors_battery())
+
+ # --- sensors_temperatures
+
+ def test_sensors_temperatures_against_sysctl(self):
+ num_cpus = psutil.cpu_count(True)
+ for cpu in range(num_cpus):
+ sensor = "dev.cpu.%s.temperature" % cpu
+ # sysctl returns a string in the format 46.0C
+ try:
+ sysctl_result = int(float(sysctl(sensor)[:-1]))
+ except RuntimeError:
+ self.skipTest("temperatures not supported by kernel")
+ self.assertAlmostEqual(
+ psutil.sensors_temperatures()["coretemp"][cpu].current,
+ sysctl_result, delta=10)
+
+ sensor = "dev.cpu.%s.coretemp.tjmax" % cpu
+ sysctl_result = int(float(sysctl(sensor)[:-1]))
+ self.assertEqual(
+ psutil.sensors_temperatures()["coretemp"][cpu].high,
+ sysctl_result)
+
+
+# =====================================================================
+# --- OpenBSD
+# =====================================================================
+
+
+@unittest.skipIf(not OPENBSD, "OPENBSD only")
+class OpenBSDTestCase(PsutilTestCase):
+
+ def test_boot_time(self):
+ s = sysctl('kern.boottime')
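+ # On OpenBSD kern.boottime is assumed to be a plain date string,
+ # e.g. 'Tue Jan  3 10:20:30 2023' (hypothetical), parsed via strptime.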
+ sys_bt = datetime.datetime.strptime(s, "%a %b %d %H:%M:%S %Y")
+ psutil_bt = datetime.datetime.fromtimestamp(psutil.boot_time())
+ self.assertEqual(sys_bt, psutil_bt)
+
+
+# =====================================================================
+# --- NetBSD
+# =====================================================================
+
+
+@unittest.skipIf(not NETBSD, "NETBSD only")
+class NetBSDTestCase(PsutilTestCase):
+
+ @staticmethod
+ def parse_meminfo(look_for):
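+ # /proc/meminfo lines are of the form 'MemTotal:  8245556 kB'
+ # (hypothetical number); values are in KiB, hence the * 1024 below.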
+ with open('/proc/meminfo', 'rt') as f:
+ for line in f:
+ if line.startswith(look_for):
+ return int(line.split()[1]) * 1024
+ raise ValueError("can't find %s" % look_for)
+
+ def test_vmem_total(self):
+ self.assertEqual(
+ psutil.virtual_memory().total, self.parse_meminfo("MemTotal:"))
+
+ def test_vmem_free(self):
+ self.assertAlmostEqual(
+ psutil.virtual_memory().free, self.parse_meminfo("MemFree:"),
+ delta=TOLERANCE_SYS_MEM)
+
+ def test_vmem_buffers(self):
+ self.assertAlmostEqual(
+ psutil.virtual_memory().buffers, self.parse_meminfo("Buffers:"),
+ delta=TOLERANCE_SYS_MEM)
+
+ def test_vmem_shared(self):
+ self.assertAlmostEqual(
+ psutil.virtual_memory().shared, self.parse_meminfo("MemShared:"),
+ delta=TOLERANCE_SYS_MEM)
+
+ def test_swapmem_total(self):
+ self.assertAlmostEqual(
+ psutil.swap_memory().total, self.parse_meminfo("SwapTotal:"),
+ delta=TOLERANCE_SYS_MEM)
+
+ def test_swapmem_free(self):
+ self.assertAlmostEqual(
+ psutil.swap_memory().free, self.parse_meminfo("SwapFree:"),
+ delta=TOLERANCE_SYS_MEM)
+
+ def test_swapmem_used(self):
+ smem = psutil.swap_memory()
+ self.assertEqual(smem.used, smem.total - smem.free)
+
+ def test_cpu_stats_interrupts(self):
+ with open('/proc/stat', 'rb') as f:
+ for line in f:
+ if line.startswith(b'intr'):
+ interrupts = int(line.split()[1])
+ break
+ else:
+ raise ValueError("couldn't find line")
+ self.assertAlmostEqual(
+ psutil.cpu_stats().interrupts, interrupts, delta=1000)
+
+ def test_cpu_stats_ctx_switches(self):
+ with open('/proc/stat', 'rb') as f:
+ for line in f:
+ if line.startswith(b'ctxt'):
+ ctx_switches = int(line.split()[1])
+ break
+ else:
+ raise ValueError("couldn't find line")
+ self.assertAlmostEqual(
+ psutil.cpu_stats().ctx_switches, ctx_switches, delta=1000)
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_connections.py b/lib/psutil/tests/test_connections.py
new file mode 100644
index 0000000..f3b1f83
--- /dev/null
+++ b/lib/psutil/tests/test_connections.py
@@ -0,0 +1,554 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Tests for net_connections() and Process.connections() APIs."""
+
+import os
+import socket
+import textwrap
+import unittest
+from contextlib import closing
+from socket import AF_INET
+from socket import AF_INET6
+from socket import SOCK_DGRAM
+from socket import SOCK_STREAM
+
+import psutil
+from psutil import FREEBSD
+from psutil import LINUX
+from psutil import MACOS
+from psutil import NETBSD
+from psutil import OPENBSD
+from psutil import POSIX
+from psutil import SUNOS
+from psutil import WINDOWS
+from psutil._common import supports_ipv6
+from psutil._compat import PY3
+from psutil.tests import AF_UNIX
+from psutil.tests import HAS_CONNECTIONS_UNIX
+from psutil.tests import SKIP_SYSCONS
+from psutil.tests import PsutilTestCase
+from psutil.tests import bind_socket
+from psutil.tests import bind_unix_socket
+from psutil.tests import check_connection_ntuple
+from psutil.tests import create_sockets
+from psutil.tests import reap_children
+from psutil.tests import retry_on_failure
+from psutil.tests import serialrun
+from psutil.tests import skip_on_access_denied
+from psutil.tests import tcp_socketpair
+from psutil.tests import unix_socketpair
+from psutil.tests import wait_for_file
+
+
+thisproc = psutil.Process()
+SOCK_SEQPACKET = getattr(socket, "SOCK_SEQPACKET", object())
+
+
+@serialrun
+class ConnectionTestCase(PsutilTestCase):
+
+ def setUp(self):
+ if not (NETBSD or FREEBSD):
+ # On NetBSD and FreeBSD the process opens a UNIX socket to /var/run/log.
+ cons = thisproc.connections(kind='all')
+ assert not cons, cons
+
+ def tearDown(self):
+ if not (FREEBSD or NETBSD):
+ # Make sure we closed all resources.
+ # NetBSD opens a UNIX socket to /var/run/log.
+ cons = thisproc.connections(kind='all')
+ assert not cons, cons
+
+ def compare_procsys_connections(self, pid, proc_cons, kind='all'):
+ """Given a process PID and its list of connections compare
+ those against system-wide connections retrieved via
+ psutil.net_connections.
+ """
+ try:
+ sys_cons = psutil.net_connections(kind=kind)
+ except psutil.AccessDenied:
+ # On MACOS, system-wide connections are retrieved by iterating
+ # over all processes
+ if MACOS:
+ return
+ else:
+ raise
+ # Filter for this proc PID and exclude PIDs from the tuple.
+ sys_cons = [c[:-1] for c in sys_cons if c.pid == pid]
+ sys_cons.sort()
+ proc_cons.sort()
+ self.assertEqual(proc_cons, sys_cons)
+
+
+class TestBasicOperations(ConnectionTestCase):
+
+ @unittest.skipIf(SKIP_SYSCONS, "requires root")
+ def test_system(self):
+ with create_sockets():
+ for conn in psutil.net_connections(kind='all'):
+ check_connection_ntuple(conn)
+
+ def test_process(self):
+ with create_sockets():
+ for conn in psutil.Process().connections(kind='all'):
+ check_connection_ntuple(conn)
+
+ def test_invalid_kind(self):
+ self.assertRaises(ValueError, thisproc.connections, kind='???')
+ self.assertRaises(ValueError, psutil.net_connections, kind='???')
+
+
+@serialrun
+class TestUnconnectedSockets(ConnectionTestCase):
+ """Tests sockets which are open but not connected to anything."""
+
+ def get_conn_from_sock(self, sock):
+ cons = thisproc.connections(kind='all')
+ smap = dict([(c.fd, c) for c in cons])
+ if NETBSD or FREEBSD:
+ # NetBSD opens a UNIX socket to /var/run/log
+ # so there may be more connections.
+ return smap[sock.fileno()]
+ else:
+ self.assertEqual(len(cons), 1)
+ if cons[0].fd != -1:
+ self.assertEqual(smap[sock.fileno()].fd, sock.fileno())
+ return cons[0]
+
+ def check_socket(self, sock):
+ """Given a socket, makes sure it matches the one obtained
+ via psutil. It assumes this process created one connection
+ only (the one supposed to be checked).
+ """
+ conn = self.get_conn_from_sock(sock)
+ check_connection_ntuple(conn)
+
+ # fd, family, type
+ if conn.fd != -1:
+ self.assertEqual(conn.fd, sock.fileno())
+ self.assertEqual(conn.family, sock.family)
+ # see: http://bugs.python.org/issue30204
+ self.assertEqual(
+ conn.type, sock.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE))
+
+ # local address
+ laddr = sock.getsockname()
+ if not laddr and PY3 and isinstance(laddr, bytes):
+ # See: http://bugs.python.org/issue30205
+ laddr = laddr.decode()
+ if sock.family == AF_INET6:
+ laddr = laddr[:2]
+ if sock.family == AF_UNIX and OPENBSD:
+ # No addresses are set for UNIX sockets on OpenBSD.
+ pass
+ else:
+ self.assertEqual(conn.laddr, laddr)
+
+ # XXX Solaris can't retrieve system-wide UNIX sockets
+ if sock.family == AF_UNIX and HAS_CONNECTIONS_UNIX:
+ cons = thisproc.connections(kind='all')
+ self.compare_procsys_connections(os.getpid(), cons, kind='all')
+ return conn
+
+ def test_tcp_v4(self):
+ addr = ("127.0.0.1", 0)
+ with closing(bind_socket(AF_INET, SOCK_STREAM, addr=addr)) as sock:
+ conn = self.check_socket(sock)
+ assert not conn.raddr
+ self.assertEqual(conn.status, psutil.CONN_LISTEN)
+
+ @unittest.skipIf(not supports_ipv6(), "IPv6 not supported")
+ def test_tcp_v6(self):
+ addr = ("::1", 0)
+ with closing(bind_socket(AF_INET6, SOCK_STREAM, addr=addr)) as sock:
+ conn = self.check_socket(sock)
+ assert not conn.raddr
+ self.assertEqual(conn.status, psutil.CONN_LISTEN)
+
+ def test_udp_v4(self):
+ addr = ("127.0.0.1", 0)
+ with closing(bind_socket(AF_INET, SOCK_DGRAM, addr=addr)) as sock:
+ conn = self.check_socket(sock)
+ assert not conn.raddr
+ self.assertEqual(conn.status, psutil.CONN_NONE)
+
+ @unittest.skipIf(not supports_ipv6(), "IPv6 not supported")
+ def test_udp_v6(self):
+ addr = ("::1", 0)
+ with closing(bind_socket(AF_INET6, SOCK_DGRAM, addr=addr)) as sock:
+ conn = self.check_socket(sock)
+ assert not conn.raddr
+ self.assertEqual(conn.status, psutil.CONN_NONE)
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_unix_tcp(self):
+ testfn = self.get_testfn()
+ with closing(bind_unix_socket(testfn, type=SOCK_STREAM)) as sock:
+ conn = self.check_socket(sock)
+ assert not conn.raddr
+ self.assertEqual(conn.status, psutil.CONN_NONE)
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_unix_udp(self):
+ testfn = self.get_testfn()
+ with closing(bind_unix_socket(testfn, type=SOCK_STREAM)) as sock:
+ conn = self.check_socket(sock)
+ assert not conn.raddr
+ self.assertEqual(conn.status, psutil.CONN_NONE)
+
+
+@serialrun
+class TestConnectedSocket(ConnectionTestCase):
+ """Test socket pairs which are are actually connected to
+ each other.
+ """
+
+ # On SunOS, even after we close() it, the server socket stays around
+ # in TIME_WAIT state.
+ @unittest.skipIf(SUNOS, "unreliable on SUNOS")
+ def test_tcp(self):
+ addr = ("127.0.0.1", 0)
+ assert not thisproc.connections(kind='tcp4')
+ server, client = tcp_socketpair(AF_INET, addr=addr)
+ try:
+ cons = thisproc.connections(kind='tcp4')
+ self.assertEqual(len(cons), 2)
+ self.assertEqual(cons[0].status, psutil.CONN_ESTABLISHED)
+ self.assertEqual(cons[1].status, psutil.CONN_ESTABLISHED)
+ # May not be fast enough to change state so it stays
+ # commented out.
+ # client.close()
+ # cons = thisproc.connections(kind='all')
+ # self.assertEqual(len(cons), 1)
+ # self.assertEqual(cons[0].status, psutil.CONN_CLOSE_WAIT)
+ finally:
+ server.close()
+ client.close()
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_unix(self):
+ testfn = self.get_testfn()
+ server, client = unix_socketpair(testfn)
+ try:
+ cons = thisproc.connections(kind='unix')
+ assert not (cons[0].laddr and cons[0].raddr)
+ assert not (cons[1].laddr and cons[1].raddr)
+ if NETBSD or FREEBSD:
+ # On NetBSD creating a UNIX socket will cause
+ # a UNIX connection to /var/run/log.
+ cons = [c for c in cons if c.raddr != '/var/run/log']
+ self.assertEqual(len(cons), 2, msg=cons)
+ if LINUX or FREEBSD or SUNOS:
+ # remote path is never set
+ self.assertEqual(cons[0].raddr, "")
+ self.assertEqual(cons[1].raddr, "")
+ # one local address should be set, though
+ self.assertEqual(testfn, cons[0].laddr or cons[1].laddr)
+ elif OPENBSD:
+ # No addresses whatsoever here.
+ for addr in (cons[0].laddr, cons[0].raddr,
+ cons[1].laddr, cons[1].raddr):
+ self.assertEqual(addr, "")
+ else:
+ # On other systems either the laddr or raddr
+ # of both peers are set.
+ self.assertEqual(cons[0].laddr or cons[1].laddr, testfn)
+ self.assertEqual(cons[0].raddr or cons[1].raddr, testfn)
+ finally:
+ server.close()
+ client.close()
+
+
+class TestFilters(ConnectionTestCase):
+
+ def test_filters(self):
+ def check(kind, families, types):
+ for conn in thisproc.connections(kind=kind):
+ self.assertIn(conn.family, families)
+ self.assertIn(conn.type, types)
+ if not SKIP_SYSCONS:
+ for conn in psutil.net_connections(kind=kind):
+ self.assertIn(conn.family, families)
+ self.assertIn(conn.type, types)
+
+ with create_sockets():
+ check('all',
+ [AF_INET, AF_INET6, AF_UNIX],
+ [SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET])
+ check('inet',
+ [AF_INET, AF_INET6],
+ [SOCK_STREAM, SOCK_DGRAM])
+ check('inet4',
+ [AF_INET],
+ [SOCK_STREAM, SOCK_DGRAM])
+ check('tcp',
+ [AF_INET, AF_INET6],
+ [SOCK_STREAM])
+ check('tcp4',
+ [AF_INET],
+ [SOCK_STREAM])
+ check('tcp6',
+ [AF_INET6],
+ [SOCK_STREAM])
+ check('udp',
+ [AF_INET, AF_INET6],
+ [SOCK_DGRAM])
+ check('udp4',
+ [AF_INET],
+ [SOCK_DGRAM])
+ check('udp6',
+ [AF_INET6],
+ [SOCK_DGRAM])
+ if HAS_CONNECTIONS_UNIX:
+ check('unix',
+ [AF_UNIX],
+ [SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET])
+
+ @skip_on_access_denied(only_if=MACOS)
+ def test_combos(self):
+ reap_children()
+
+ def check_conn(proc, conn, family, type, laddr, raddr, status, kinds):
+ all_kinds = ("all", "inet", "inet4", "inet6", "tcp", "tcp4",
+ "tcp6", "udp", "udp4", "udp6")
+ check_connection_ntuple(conn)
+ self.assertEqual(conn.family, family)
+ self.assertEqual(conn.type, type)
+ self.assertEqual(conn.laddr, laddr)
+ self.assertEqual(conn.raddr, raddr)
+ self.assertEqual(conn.status, status)
+ for kind in all_kinds:
+ cons = proc.connections(kind=kind)
+ if kind in kinds:
+ assert cons
+ else:
+ assert not cons, cons
+ # compare against system-wide connections
+ # XXX Solaris can't retrieve system-wide UNIX
+ # sockets.
+ if HAS_CONNECTIONS_UNIX:
+ self.compare_procsys_connections(proc.pid, [conn])
+
+ tcp_template = textwrap.dedent("""
+ import socket, time
+ s = socket.socket({family}, socket.SOCK_STREAM)
+ s.bind(('{addr}', 0))
+ s.listen(5)
+ with open('{testfn}', 'w') as f:
+ f.write(str(s.getsockname()[:2]))
+ time.sleep(60)
+ """)
+
+ udp_template = textwrap.dedent("""
+ import socket, time
+ s = socket.socket({family}, socket.SOCK_DGRAM)
+ s.bind(('{addr}', 0))
+ with open('{testfn}', 'w') as f:
+ f.write(str(s.getsockname()[:2]))
+ time.sleep(60)
+ """)
+
+ # must be relative on Windows
+ testfile = os.path.basename(self.get_testfn(dir=os.getcwd()))
+ tcp4_template = tcp_template.format(
+ family=int(AF_INET), addr="127.0.0.1", testfn=testfile)
+ udp4_template = udp_template.format(
+ family=int(AF_INET), addr="127.0.0.1", testfn=testfile)
+ tcp6_template = tcp_template.format(
+ family=int(AF_INET6), addr="::1", testfn=testfile)
+ udp6_template = udp_template.format(
+ family=int(AF_INET6), addr="::1", testfn=testfile)
+
+ # launch various subprocesses, each instantiating a socket of a
+ # different family and type, to enrich psutil results
+ tcp4_proc = self.pyrun(tcp4_template)
+ tcp4_addr = eval(wait_for_file(testfile, delete=True))
+ udp4_proc = self.pyrun(udp4_template)
+ udp4_addr = eval(wait_for_file(testfile, delete=True))
+ if supports_ipv6():
+ tcp6_proc = self.pyrun(tcp6_template)
+ tcp6_addr = eval(wait_for_file(testfile, delete=True))
+ udp6_proc = self.pyrun(udp6_template)
+ udp6_addr = eval(wait_for_file(testfile, delete=True))
+ else:
+ tcp6_proc = None
+ udp6_proc = None
+ tcp6_addr = None
+ udp6_addr = None
+
+ for p in thisproc.children():
+ cons = p.connections()
+ self.assertEqual(len(cons), 1)
+ for conn in cons:
+ # TCP v4
+ if p.pid == tcp4_proc.pid:
+ check_conn(p, conn, AF_INET, SOCK_STREAM, tcp4_addr, (),
+ psutil.CONN_LISTEN,
+ ("all", "inet", "inet4", "tcp", "tcp4"))
+ # UDP v4
+ elif p.pid == udp4_proc.pid:
+ check_conn(p, conn, AF_INET, SOCK_DGRAM, udp4_addr, (),
+ psutil.CONN_NONE,
+ ("all", "inet", "inet4", "udp", "udp4"))
+ # TCP v6
+ elif p.pid == getattr(tcp6_proc, "pid", None):
+ check_conn(p, conn, AF_INET6, SOCK_STREAM, tcp6_addr, (),
+ psutil.CONN_LISTEN,
+ ("all", "inet", "inet6", "tcp", "tcp6"))
+ # UDP v6
+ elif p.pid == getattr(udp6_proc, "pid", None):
+ check_conn(p, conn, AF_INET6, SOCK_DGRAM, udp6_addr, (),
+ psutil.CONN_NONE,
+ ("all", "inet", "inet6", "udp", "udp6"))
+
+ def test_count(self):
+ with create_sockets():
+ # tcp
+ cons = thisproc.connections(kind='tcp')
+ self.assertEqual(len(cons), 2 if supports_ipv6() else 1)
+ for conn in cons:
+ self.assertIn(conn.family, (AF_INET, AF_INET6))
+ self.assertEqual(conn.type, SOCK_STREAM)
+ # tcp4
+ cons = thisproc.connections(kind='tcp4')
+ self.assertEqual(len(cons), 1)
+ self.assertEqual(cons[0].family, AF_INET)
+ self.assertEqual(cons[0].type, SOCK_STREAM)
+ # tcp6
+ if supports_ipv6():
+ cons = thisproc.connections(kind='tcp6')
+ self.assertEqual(len(cons), 1)
+ self.assertEqual(cons[0].family, AF_INET6)
+ self.assertEqual(cons[0].type, SOCK_STREAM)
+ # udp
+ cons = thisproc.connections(kind='udp')
+ self.assertEqual(len(cons), 2 if supports_ipv6() else 1)
+ for conn in cons:
+ self.assertIn(conn.family, (AF_INET, AF_INET6))
+ self.assertEqual(conn.type, SOCK_DGRAM)
+ # udp4
+ cons = thisproc.connections(kind='udp4')
+ self.assertEqual(len(cons), 1)
+ self.assertEqual(cons[0].family, AF_INET)
+ self.assertEqual(cons[0].type, SOCK_DGRAM)
+ # udp6
+ if supports_ipv6():
+ cons = thisproc.connections(kind='udp6')
+ self.assertEqual(len(cons), 1)
+ self.assertEqual(cons[0].family, AF_INET6)
+ self.assertEqual(cons[0].type, SOCK_DGRAM)
+ # inet
+ cons = thisproc.connections(kind='inet')
+ self.assertEqual(len(cons), 4 if supports_ipv6() else 2)
+ for conn in cons:
+ self.assertIn(conn.family, (AF_INET, AF_INET6))
+ self.assertIn(conn.type, (SOCK_STREAM, SOCK_DGRAM))
+ # inet6
+ if supports_ipv6():
+ cons = thisproc.connections(kind='inet6')
+ self.assertEqual(len(cons), 2)
+ for conn in cons:
+ self.assertEqual(conn.family, AF_INET6)
+ self.assertIn(conn.type, (SOCK_STREAM, SOCK_DGRAM))
+ # Skipped on BSD because by default the Python process
+ # creates a UNIX socket to '/var/run/log'.
+ if HAS_CONNECTIONS_UNIX and not (FREEBSD or NETBSD):
+ cons = thisproc.connections(kind='unix')
+ self.assertEqual(len(cons), 3)
+ for conn in cons:
+ self.assertEqual(conn.family, AF_UNIX)
+ self.assertIn(conn.type, (SOCK_STREAM, SOCK_DGRAM))
+
+
+@unittest.skipIf(SKIP_SYSCONS, "requires root")
+class TestSystemWideConnections(ConnectionTestCase):
+ """Tests for net_connections()."""
+
+ def test_it(self):
+ def check(cons, families, types_):
+ for conn in cons:
+ self.assertIn(conn.family, families, msg=conn)
+ if conn.family != AF_UNIX:
+ self.assertIn(conn.type, types_, msg=conn)
+ check_connection_ntuple(conn)
+
+ with create_sockets():
+ from psutil._common import conn_tmap
+ for kind, groups in conn_tmap.items():
+ # XXX: SunOS does not retrieve UNIX sockets.
+ if kind == 'unix' and not HAS_CONNECTIONS_UNIX:
+ continue
+ families, types_ = groups
+ cons = psutil.net_connections(kind)
+ self.assertEqual(len(cons), len(set(cons)))
+ check(cons, families, types_)
+
+ @retry_on_failure()
+ def test_multi_sockets_procs(self):
+ # Creates multiple sub processes, each creating different
+ # sockets. For each process check that proc.connections()
+ # and net_connections() return the same results.
+ # This is done mainly to check whether net_connections()'s
+ # pid is properly set, see:
+ # https://github.com/giampaolo/psutil/issues/1013
+ with create_sockets() as socks:
+ expected = len(socks)
+ pids = []
+ times = 10
+ fnames = []
+ for i in range(times):
+ fname = self.get_testfn()
+ fnames.append(fname)
+ src = textwrap.dedent("""\
+ import time, os
+ from psutil.tests import create_sockets
+ with create_sockets():
+ with open(r'%s', 'w') as f:
+ f.write("hello")
+ time.sleep(60)
+ """ % fname)
+ sproc = self.pyrun(src)
+ pids.append(sproc.pid)
+
+ # sync
+ for fname in fnames:
+ wait_for_file(fname)
+
+ syscons = [x for x in psutil.net_connections(kind='all') if x.pid
+ in pids]
+ for pid in pids:
+ self.assertEqual(len([x for x in syscons if x.pid == pid]),
+ expected)
+ p = psutil.Process(pid)
+ self.assertEqual(len(p.connections('all')), expected)
+
+
+class TestMisc(PsutilTestCase):
+
+ def test_connection_constants(self):
+ ints = []
+ strs = []
+ for name in dir(psutil):
+ if name.startswith('CONN_'):
+ num = getattr(psutil, name)
+ str_ = str(num)
+ assert str_.isupper(), str_
+ self.assertNotIn(str, strs)
+ self.assertNotIn(num, ints)
+ ints.append(num)
+ strs.append(str_)
+ if SUNOS:
+ psutil.CONN_IDLE
+ psutil.CONN_BOUND
+ if WINDOWS:
+ psutil.CONN_DELETE_TCB
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_contracts.py b/lib/psutil/tests/test_contracts.py
new file mode 100644
index 0000000..3b806ee
--- /dev/null
+++ b/lib/psutil/tests/test_contracts.py
@@ -0,0 +1,751 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Contracts tests. These tests mainly check API sanity in terms of
+returned types and API availability.
+Some of these are duplicates of tests in test_system.py and test_process.py.
+"""
+
+import errno
+import multiprocessing
+import os
+import platform
+import signal
+import stat
+import sys
+import time
+import traceback
+import unittest
+
+import psutil
+from psutil import AIX
+from psutil import BSD
+from psutil import FREEBSD
+from psutil import LINUX
+from psutil import MACOS
+from psutil import NETBSD
+from psutil import OPENBSD
+from psutil import OSX
+from psutil import POSIX
+from psutil import SUNOS
+from psutil import WINDOWS
+from psutil._compat import FileNotFoundError
+from psutil._compat import long
+from psutil._compat import range
+from psutil._compat import unicode
+from psutil.tests import APPVEYOR
+from psutil.tests import CI_TESTING
+from psutil.tests import GITHUB_ACTIONS
+from psutil.tests import HAS_CPU_FREQ
+from psutil.tests import HAS_NET_IO_COUNTERS
+from psutil.tests import HAS_SENSORS_FANS
+from psutil.tests import HAS_SENSORS_TEMPERATURES
+from psutil.tests import PYPY
+from psutil.tests import SKIP_SYSCONS
+from psutil.tests import VALID_PROC_STATUSES
+from psutil.tests import PsutilTestCase
+from psutil.tests import check_connection_ntuple
+from psutil.tests import create_sockets
+from psutil.tests import enum
+from psutil.tests import is_namedtuple
+from psutil.tests import kernel_version
+from psutil.tests import process_namespace
+from psutil.tests import serialrun
+
+
+# ===================================================================
+# --- APIs availability
+# ===================================================================
+
+# Make sure code reflects what doc promises in terms of APIs
+# availability.
+
+class TestAvailConstantsAPIs(PsutilTestCase):
+
+ def test_PROCFS_PATH(self):
+ self.assertEqual(hasattr(psutil, "PROCFS_PATH"),
+ LINUX or SUNOS or AIX)
+
+ def test_win_priority(self):
+ ae = self.assertEqual
+ ae(hasattr(psutil, "ABOVE_NORMAL_PRIORITY_CLASS"), WINDOWS)
+ ae(hasattr(psutil, "BELOW_NORMAL_PRIORITY_CLASS"), WINDOWS)
+ ae(hasattr(psutil, "HIGH_PRIORITY_CLASS"), WINDOWS)
+ ae(hasattr(psutil, "IDLE_PRIORITY_CLASS"), WINDOWS)
+ ae(hasattr(psutil, "NORMAL_PRIORITY_CLASS"), WINDOWS)
+ ae(hasattr(psutil, "REALTIME_PRIORITY_CLASS"), WINDOWS)
+
+ def test_linux_ioprio_linux(self):
+ ae = self.assertEqual
+ ae(hasattr(psutil, "IOPRIO_CLASS_NONE"), LINUX)
+ ae(hasattr(psutil, "IOPRIO_CLASS_RT"), LINUX)
+ ae(hasattr(psutil, "IOPRIO_CLASS_BE"), LINUX)
+ ae(hasattr(psutil, "IOPRIO_CLASS_IDLE"), LINUX)
+
+ def test_linux_ioprio_windows(self):
+ ae = self.assertEqual
+ ae(hasattr(psutil, "IOPRIO_HIGH"), WINDOWS)
+ ae(hasattr(psutil, "IOPRIO_NORMAL"), WINDOWS)
+ ae(hasattr(psutil, "IOPRIO_LOW"), WINDOWS)
+ ae(hasattr(psutil, "IOPRIO_VERYLOW"), WINDOWS)
+
+ @unittest.skipIf(GITHUB_ACTIONS and LINUX,
+ "unsupported on GITHUB_ACTIONS + LINUX")
+ def test_rlimit(self):
+ ae = self.assertEqual
+ ae(hasattr(psutil, "RLIM_INFINITY"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_AS"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_CORE"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_CPU"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_DATA"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_FSIZE"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_MEMLOCK"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_NOFILE"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_NPROC"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_RSS"), LINUX or FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_STACK"), LINUX or FREEBSD)
+
+ ae(hasattr(psutil, "RLIMIT_LOCKS"), LINUX)
+ if POSIX:
+ if kernel_version() >= (2, 6, 8):
+ ae(hasattr(psutil, "RLIMIT_MSGQUEUE"), LINUX)
+ if kernel_version() >= (2, 6, 12):
+ ae(hasattr(psutil, "RLIMIT_NICE"), LINUX)
+ if kernel_version() >= (2, 6, 12):
+ ae(hasattr(psutil, "RLIMIT_RTPRIO"), LINUX)
+ if kernel_version() >= (2, 6, 25):
+ ae(hasattr(psutil, "RLIMIT_RTTIME"), LINUX)
+ if kernel_version() >= (2, 6, 8):
+ ae(hasattr(psutil, "RLIMIT_SIGPENDING"), LINUX)
+
+ ae(hasattr(psutil, "RLIMIT_SWAP"), FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_SBSIZE"), FREEBSD)
+ ae(hasattr(psutil, "RLIMIT_NPTS"), FREEBSD)
+
+
+class TestAvailSystemAPIs(PsutilTestCase):
+
+ def test_win_service_iter(self):
+ self.assertEqual(hasattr(psutil, "win_service_iter"), WINDOWS)
+
+ def test_win_service_get(self):
+ self.assertEqual(hasattr(psutil, "win_service_get"), WINDOWS)
+
+ def test_cpu_freq(self):
+ self.assertEqual(hasattr(psutil, "cpu_freq"),
+ LINUX or MACOS or WINDOWS or FREEBSD or OPENBSD)
+
+ def test_sensors_temperatures(self):
+ self.assertEqual(
+ hasattr(psutil, "sensors_temperatures"), LINUX or FREEBSD)
+
+ def test_sensors_fans(self):
+ self.assertEqual(hasattr(psutil, "sensors_fans"), LINUX)
+
+ def test_battery(self):
+ self.assertEqual(hasattr(psutil, "sensors_battery"),
+ LINUX or WINDOWS or FREEBSD or MACOS)
+
+
+class TestAvailProcessAPIs(PsutilTestCase):
+
+ def test_environ(self):
+ self.assertEqual(hasattr(psutil.Process, "environ"),
+ LINUX or MACOS or WINDOWS or AIX or SUNOS or
+ FREEBSD or OPENBSD or NETBSD)
+
+ def test_uids(self):
+ self.assertEqual(hasattr(psutil.Process, "uids"), POSIX)
+
+ def test_gids(self):
+ self.assertEqual(hasattr(psutil.Process, "uids"), POSIX)
+
+ def test_terminal(self):
+ self.assertEqual(hasattr(psutil.Process, "terminal"), POSIX)
+
+ def test_ionice(self):
+ self.assertEqual(hasattr(psutil.Process, "ionice"), LINUX or WINDOWS)
+
+ @unittest.skipIf(GITHUB_ACTIONS and LINUX,
+ "unsupported on GITHUB_ACTIONS + LINUX")
+ def test_rlimit(self):
+ self.assertEqual(hasattr(psutil.Process, "rlimit"), LINUX or FREEBSD)
+
+ def test_io_counters(self):
+ hasit = hasattr(psutil.Process, "io_counters")
+ self.assertEqual(hasit, not (MACOS or SUNOS))
+
+ def test_num_fds(self):
+ self.assertEqual(hasattr(psutil.Process, "num_fds"), POSIX)
+
+ def test_num_handles(self):
+ self.assertEqual(hasattr(psutil.Process, "num_handles"), WINDOWS)
+
+ def test_cpu_affinity(self):
+ self.assertEqual(hasattr(psutil.Process, "cpu_affinity"),
+ LINUX or WINDOWS or FREEBSD)
+
+ def test_cpu_num(self):
+ self.assertEqual(hasattr(psutil.Process, "cpu_num"),
+ LINUX or FREEBSD or SUNOS)
+
+ def test_memory_maps(self):
+ hasit = hasattr(psutil.Process, "memory_maps")
+ self.assertEqual(hasit, not (OPENBSD or NETBSD or AIX or MACOS))
+
+
+# ===================================================================
+# --- API types
+# ===================================================================
+
+
+class TestSystemAPITypes(PsutilTestCase):
+ """Check the return types of system related APIs.
+ Mainly we want to test we never return unicode on Python 2, see:
+ https://github.com/giampaolo/psutil/issues/1039
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ cls.proc = psutil.Process()
+
+ def assert_ntuple_of_nums(self, nt, type_=float, gezero=True):
+ assert is_namedtuple(nt)
+ for n in nt:
+ self.assertIsInstance(n, type_)
+ if gezero:
+ self.assertGreaterEqual(n, 0)
+
+ def test_cpu_times(self):
+ self.assert_ntuple_of_nums(psutil.cpu_times())
+ for nt in psutil.cpu_times(percpu=True):
+ self.assert_ntuple_of_nums(nt)
+
+ def test_cpu_percent(self):
+ self.assertIsInstance(psutil.cpu_percent(interval=None), float)
+ self.assertIsInstance(psutil.cpu_percent(interval=0.00001), float)
+
+ def test_cpu_times_percent(self):
+ self.assert_ntuple_of_nums(psutil.cpu_times_percent(interval=None))
+ self.assert_ntuple_of_nums(psutil.cpu_times_percent(interval=0.0001))
+
+ def test_cpu_count(self):
+ self.assertIsInstance(psutil.cpu_count(), int)
+
+ # TODO: remove this once 1892 is fixed
+ @unittest.skipIf(MACOS and platform.machine() == 'arm64',
+ "skipped due to #1892")
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_cpu_freq(self):
+ if psutil.cpu_freq() is None:
+ raise self.skipTest("cpu_freq() returns None")
+ self.assert_ntuple_of_nums(psutil.cpu_freq(), type_=(float, int, long))
+
+ def test_disk_io_counters(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for k, v in psutil.disk_io_counters(perdisk=True).items():
+ self.assertIsInstance(k, str)
+ self.assert_ntuple_of_nums(v, type_=(int, long))
+
+ def test_disk_partitions(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for disk in psutil.disk_partitions():
+ self.assertIsInstance(disk.device, str)
+ self.assertIsInstance(disk.mountpoint, str)
+ self.assertIsInstance(disk.fstype, str)
+ self.assertIsInstance(disk.opts, str)
+ self.assertIsInstance(disk.maxfile, int)
+ self.assertIsInstance(disk.maxpath, int)
+
+ @unittest.skipIf(SKIP_SYSCONS, "requires root")
+ def test_net_connections(self):
+ with create_sockets():
+ ret = psutil.net_connections('all')
+ self.assertEqual(len(ret), len(set(ret)))
+ for conn in ret:
+ assert is_namedtuple(conn)
+
+ def test_net_if_addrs(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for ifname, addrs in psutil.net_if_addrs().items():
+ self.assertIsInstance(ifname, str)
+ for addr in addrs:
+ if enum is not None and not PYPY:
+ self.assertIsInstance(addr.family, enum.IntEnum)
+ else:
+ self.assertIsInstance(addr.family, int)
+ self.assertIsInstance(addr.address, str)
+ self.assertIsInstance(addr.netmask, (str, type(None)))
+ self.assertIsInstance(addr.broadcast, (str, type(None)))
+
+ def test_net_if_stats(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for ifname, info in psutil.net_if_stats().items():
+ self.assertIsInstance(ifname, str)
+ self.assertIsInstance(info.isup, bool)
+ if enum is not None:
+ self.assertIsInstance(info.duplex, enum.IntEnum)
+ else:
+ self.assertIsInstance(info.duplex, int)
+ self.assertIsInstance(info.speed, int)
+ self.assertIsInstance(info.mtu, int)
+
+ @unittest.skipIf(not HAS_NET_IO_COUNTERS, 'not supported')
+ def test_net_io_counters(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for ifname, _ in psutil.net_io_counters(pernic=True).items():
+ self.assertIsInstance(ifname, str)
+
+ @unittest.skipIf(not HAS_SENSORS_FANS, "not supported")
+ def test_sensors_fans(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for name, units in psutil.sensors_fans().items():
+ self.assertIsInstance(name, str)
+ for unit in units:
+ self.assertIsInstance(unit.label, str)
+ self.assertIsInstance(unit.current, (float, int, type(None)))
+
+ @unittest.skipIf(not HAS_SENSORS_TEMPERATURES, "not supported")
+ def test_sensors_temperatures(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for name, units in psutil.sensors_temperatures().items():
+ self.assertIsInstance(name, str)
+ for unit in units:
+ self.assertIsInstance(unit.label, str)
+ self.assertIsInstance(unit.current, (float, int, type(None)))
+ self.assertIsInstance(unit.high, (float, int, type(None)))
+ self.assertIsInstance(unit.critical, (float, int, type(None)))
+
+ def test_boot_time(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ self.assertIsInstance(psutil.boot_time(), float)
+
+ def test_users(self):
+ # Duplicate of test_system.py. Keep it anyway.
+ for user in psutil.users():
+ self.assertIsInstance(user.name, str)
+ self.assertIsInstance(user.terminal, (str, type(None)))
+ self.assertIsInstance(user.host, (str, type(None)))
+ self.assertIsInstance(user.pid, (int, type(None)))
+
+
+class TestProcessWaitType(PsutilTestCase):
+
+ @unittest.skipIf(not POSIX, "not POSIX")
+ def test_negative_signal(self):
+ p = psutil.Process(self.spawn_testproc().pid)
+ p.terminate()
+ code = p.wait()
+ self.assertEqual(code, -signal.SIGTERM)
+ if enum is not None:
+ self.assertIsInstance(code, enum.IntEnum)
+ else:
+ self.assertIsInstance(code, int)
+
+
+# ===================================================================
+# --- Fetch all processes test
+# ===================================================================
+
+
+def proc_info(pid):
+ tcase = PsutilTestCase()
+
+ def check_exception(exc, proc, name, ppid):
+ tcase.assertEqual(exc.pid, pid)
+ tcase.assertEqual(exc.name, name)
+ if isinstance(exc, psutil.ZombieProcess):
+ if exc.ppid is not None:
+ tcase.assertGreaterEqual(exc.ppid, 0)
+ tcase.assertEqual(exc.ppid, ppid)
+ elif isinstance(exc, psutil.NoSuchProcess):
+ tcase.assertProcessGone(proc)
+ str(exc)
+
+ def do_wait():
+ if pid != 0:
+ try:
+ proc.wait(0)
+ except psutil.Error as exc:
+ check_exception(exc, proc, name, ppid)
+
+ try:
+ proc = psutil.Process(pid)
+ d = proc.as_dict(['ppid', 'name'])
+ except psutil.NoSuchProcess:
+ return {}
+
+ name, ppid = d['name'], d['ppid']
+ info = {'pid': proc.pid}
+ ns = process_namespace(proc)
+ # We don't use oneshot() in order not to fool check_exception()
+ # in case of NSP (NoSuchProcess).
+ for fun, fun_name in ns.iter(ns.getters, clear_cache=False):
+ try:
+ info[fun_name] = fun()
+ except psutil.Error as exc:
+ check_exception(exc, proc, name, ppid)
+ continue
+ do_wait()
+ return info
+
+
+@serialrun
+class TestFetchAllProcesses(PsutilTestCase):
+ """Test which iterates over all running processes and performs
+ some sanity checks against Process API's returned values.
+ Uses a process pool to get info about all processes.
+ """
+
+ def setUp(self):
+ # Using a pool in a CI env may result in deadlock, see:
+ # https://github.com/giampaolo/psutil/issues/2104
+ if not CI_TESTING:
+ self.pool = multiprocessing.Pool()
+
+ def tearDown(self):
+ if not CI_TESTING:
+ self.pool.terminate()
+ self.pool.join()
+
+ def iter_proc_info(self):
+ # Fixes "can't pickle <function proc_info>: it's not the
+ # same object as test_contracts.proc_info".
+ from psutil.tests.test_contracts import proc_info
+
+ if not CI_TESTING:
+ return self.pool.imap_unordered(proc_info, psutil.pids())
+ else:
+ ls = []
+ for pid in psutil.pids():
+ ls.append(proc_info(pid))
+ return ls
+
+ def test_all(self):
+ failures = []
+ for info in self.iter_proc_info():
+ for name, value in info.items():
+ meth = getattr(self, name)
+ try:
+ meth(value, info)
+ except AssertionError:
+ s = '\n' + '=' * 70 + '\n'
+ s += "FAIL: test_%s pid=%s, ret=%s\n" % (
+ name, info['pid'], repr(value))
+ s += '-' * 70
+ s += "\n%s" % traceback.format_exc()
+ s = "\n".join((" " * 4) + i for i in s.splitlines())
+ s += '\n'
+ failures.append(s)
+ else:
+ if value not in (0, 0.0, [], None, '', {}):
+ assert value, value
+ if failures:
+ self.fail(''.join(failures))
+
+ def cmdline(self, ret, info):
+ self.assertIsInstance(ret, list)
+ for part in ret:
+ self.assertIsInstance(part, str)
+
+ def exe(self, ret, info):
+ self.assertIsInstance(ret, (str, unicode, type(None)))
+ if not ret:
+ self.assertEqual(ret, '')
+ else:
+ if WINDOWS and not ret.endswith('.exe'):
+ return # May be "Registry", "MemCompression", ...
+ assert os.path.isabs(ret), ret
+ # Note: os.path.isfile() may return False even if the file is there,
+ # hence we skip the test, see:
+ # http://stackoverflow.com/questions/3112546/os-path-exists-lies
+ if POSIX and os.path.isfile(ret):
+ if hasattr(os, 'access') and hasattr(os, "X_OK"):
+ # XXX: may fail on MACOS
+ try:
+ assert os.access(ret, os.X_OK)
+ except AssertionError:
+ if os.path.exists(ret) and not CI_TESTING:
+ raise
+
+ def pid(self, ret, info):
+ self.assertIsInstance(ret, int)
+ self.assertGreaterEqual(ret, 0)
+
+ def ppid(self, ret, info):
+ self.assertIsInstance(ret, (int, long))
+ self.assertGreaterEqual(ret, 0)
+
+ def name(self, ret, info):
+ self.assertIsInstance(ret, (str, unicode))
+ if APPVEYOR and not ret and info['status'] == 'stopped':
+ return
+ # on AIX, "<exiting>" processes don't have names
+ if not AIX:
+ assert ret
+
+ def create_time(self, ret, info):
+ self.assertIsInstance(ret, float)
+ try:
+ self.assertGreaterEqual(ret, 0)
+ except AssertionError:
+ # XXX
+ if OPENBSD and info['status'] == psutil.STATUS_ZOMBIE:
+ pass
+ else:
+ raise
+ # this can't be taken for granted on all platforms
+ # self.assertGreaterEqual(ret, psutil.boot_time())
+ # make sure returned value can be pretty printed
+ # with strftime
+ time.strftime("%Y %m %d %H:%M:%S", time.localtime(ret))
+
+ def uids(self, ret, info):
+ assert is_namedtuple(ret)
+ for uid in ret:
+ self.assertIsInstance(uid, int)
+ self.assertGreaterEqual(uid, 0)
+
+ def gids(self, ret, info):
+ assert is_namedtuple(ret)
+ # note: testing all gids as above seems not to be reliable for
+ # gid == 30 (nobody); not sure why.
+ for gid in ret:
+ self.assertIsInstance(gid, int)
+ if not MACOS and not NETBSD:
+ self.assertGreaterEqual(gid, 0)
+
+ def username(self, ret, info):
+ self.assertIsInstance(ret, str)
+ assert ret
+
+ def status(self, ret, info):
+ self.assertIsInstance(ret, str)
+ assert ret
+ self.assertNotEqual(ret, '?') # XXX
+ self.assertIn(ret, VALID_PROC_STATUSES)
+
+ def io_counters(self, ret, info):
+ assert is_namedtuple(ret)
+ for field in ret:
+ self.assertIsInstance(field, (int, long))
+ if field != -1:
+ self.assertGreaterEqual(field, 0)
+
+ def ionice(self, ret, info):
+ if LINUX:
+ self.assertIsInstance(ret.ioclass, int)
+ self.assertIsInstance(ret.value, int)
+ self.assertGreaterEqual(ret.ioclass, 0)
+ self.assertGreaterEqual(ret.value, 0)
+ else: # Windows, Cygwin
+ choices = [
+ psutil.IOPRIO_VERYLOW,
+ psutil.IOPRIO_LOW,
+ psutil.IOPRIO_NORMAL,
+ psutil.IOPRIO_HIGH]
+ self.assertIsInstance(ret, int)
+ self.assertGreaterEqual(ret, 0)
+ self.assertIn(ret, choices)
+
+ def num_threads(self, ret, info):
+ self.assertIsInstance(ret, int)
+ if APPVEYOR and not ret and info['status'] == 'stopped':
+ return
+ self.assertGreaterEqual(ret, 1)
+
+ def threads(self, ret, info):
+ self.assertIsInstance(ret, list)
+ for t in ret:
+ assert is_namedtuple(t)
+ self.assertGreaterEqual(t.id, 0)
+ self.assertGreaterEqual(t.user_time, 0)
+ self.assertGreaterEqual(t.system_time, 0)
+ for field in t:
+ self.assertIsInstance(field, (int, float))
+
+ def cpu_times(self, ret, info):
+ assert is_namedtuple(ret)
+ for n in ret:
+ self.assertIsInstance(n, float)
+ self.assertGreaterEqual(n, 0)
+ # TODO: check ntuple fields
+
+ def cpu_percent(self, ret, info):
+ self.assertIsInstance(ret, float)
+ assert 0.0 <= ret <= 100.0, ret
+
+ def cpu_num(self, ret, info):
+ self.assertIsInstance(ret, int)
+ if FREEBSD and ret == -1:
+ return
+ self.assertGreaterEqual(ret, 0)
+ if psutil.cpu_count() == 1:
+ self.assertEqual(ret, 0)
+ self.assertIn(ret, list(range(psutil.cpu_count())))
+
+ def memory_info(self, ret, info):
+ assert is_namedtuple(ret)
+ for value in ret:
+ self.assertIsInstance(value, (int, long))
+ self.assertGreaterEqual(value, 0)
+ if WINDOWS:
+ self.assertGreaterEqual(ret.peak_wset, ret.wset)
+ self.assertGreaterEqual(ret.peak_paged_pool, ret.paged_pool)
+ self.assertGreaterEqual(ret.peak_nonpaged_pool, ret.nonpaged_pool)
+ self.assertGreaterEqual(ret.peak_pagefile, ret.pagefile)
+
+ def memory_full_info(self, ret, info):
+ assert is_namedtuple(ret)
+ total = psutil.virtual_memory().total
+ for name in ret._fields:
+ value = getattr(ret, name)
+ self.assertIsInstance(value, (int, long))
+ self.assertGreaterEqual(value, 0, msg=(name, value))
+ if (LINUX or OSX) and name in ('vms', 'data'):
+ # On Linux there are processes (e.g. 'goa-daemon') whose
+ # VMS is incredibly high for some reason.
+ continue
+ self.assertLessEqual(value, total, msg=(name, value, total))
+
+ if LINUX:
+ self.assertGreaterEqual(ret.pss, ret.uss)
+
+ def open_files(self, ret, info):
+ self.assertIsInstance(ret, list)
+ for f in ret:
+ self.assertIsInstance(f.fd, int)
+ self.assertIsInstance(f.path, str)
+ if WINDOWS:
+ self.assertEqual(f.fd, -1)
+ elif LINUX:
+ self.assertIsInstance(f.position, int)
+ self.assertIsInstance(f.mode, str)
+ self.assertIsInstance(f.flags, int)
+ self.assertGreaterEqual(f.position, 0)
+ self.assertIn(f.mode, ('r', 'w', 'a', 'r+', 'a+'))
+ self.assertGreater(f.flags, 0)
+ elif BSD and not f.path:
+ # XXX see: https://github.com/giampaolo/psutil/issues/595
+ continue
+ assert os.path.isabs(f.path), f
+ try:
+ st = os.stat(f.path)
+ except FileNotFoundError:
+ pass
+ else:
+ assert stat.S_ISREG(st.st_mode), f
+
+ def num_fds(self, ret, info):
+ self.assertIsInstance(ret, int)
+ self.assertGreaterEqual(ret, 0)
+
+ def connections(self, ret, info):
+ with create_sockets():
+ self.assertEqual(len(ret), len(set(ret)))
+ for conn in ret:
+ assert is_namedtuple(conn)
+ check_connection_ntuple(conn)
+
+ def cwd(self, ret, info):
+ if ret: # 'ret' can be None or empty
+ self.assertIsInstance(ret, str)
+ assert os.path.isabs(ret), ret
+ try:
+ st = os.stat(ret)
+ except OSError as err:
+ if WINDOWS and err.errno in \
+ psutil._psplatform.ACCESS_DENIED_SET:
+ pass
+ # directory has been removed in the meantime
+ elif err.errno != errno.ENOENT:
+ raise
+ else:
+ assert stat.S_ISDIR(st.st_mode)
+
+ def memory_percent(self, ret, info):
+ self.assertIsInstance(ret, float)
+ assert 0 <= ret <= 100, ret
+
+ def is_running(self, ret, info):
+ self.assertIsInstance(ret, bool)
+
+ def cpu_affinity(self, ret, info):
+ self.assertIsInstance(ret, list)
+ assert ret != [], ret
+ cpus = list(range(psutil.cpu_count()))
+ for n in ret:
+ self.assertIsInstance(n, int)
+ self.assertIn(n, cpus)
+
+ def terminal(self, ret, info):
+ self.assertIsInstance(ret, (str, type(None)))
+ if ret is not None:
+ assert os.path.isabs(ret), ret
+ assert os.path.exists(ret), ret
+
+ def memory_maps(self, ret, info):
+ for nt in ret:
+ self.assertIsInstance(nt.addr, str)
+ self.assertIsInstance(nt.perms, str)
+ self.assertIsInstance(nt.path, str)
+ for fname in nt._fields:
+ value = getattr(nt, fname)
+ if fname == 'path':
+ if not value.startswith(("[", "anon_inode:")):
+ assert os.path.isabs(nt.path), nt.path
+ # commented as on Linux we might get
+ # '/foo/bar (deleted)'
+ # assert os.path.exists(nt.path), nt.path
+ elif fname == 'addr':
+ assert value, repr(value)
+ elif fname == 'perms':
+ if not WINDOWS:
+ assert value, repr(value)
+ else:
+ self.assertIsInstance(value, (int, long))
+ self.assertGreaterEqual(value, 0)
+
+ def num_handles(self, ret, info):
+ self.assertIsInstance(ret, int)
+ self.assertGreaterEqual(ret, 0)
+
+ def nice(self, ret, info):
+ self.assertIsInstance(ret, int)
+ if POSIX:
+ assert -20 <= ret <= 20, ret
+ else:
+ priorities = [getattr(psutil, x) for x in dir(psutil)
+ if x.endswith('_PRIORITY_CLASS')]
+ self.assertIn(ret, priorities)
+ if sys.version_info > (3, 4):
+ self.assertIsInstance(ret, enum.IntEnum)
+ else:
+ self.assertIsInstance(ret, int)
+
+ def num_ctx_switches(self, ret, info):
+ assert is_namedtuple(ret)
+ for value in ret:
+ self.assertIsInstance(value, (int, long))
+ self.assertGreaterEqual(value, 0)
+
+ def rlimit(self, ret, info):
+ self.assertIsInstance(ret, tuple)
+ self.assertEqual(len(ret), 2)
+ self.assertGreaterEqual(ret[0], -1)
+ self.assertGreaterEqual(ret[1], -1)
+
+ def environ(self, ret, info):
+ self.assertIsInstance(ret, dict)
+ for k, v in ret.items():
+ self.assertIsInstance(k, str)
+ self.assertIsInstance(v, str)
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_linux.py b/lib/psutil/tests/test_linux.py
new file mode 100644
index 0000000..3e1afc4
--- /dev/null
+++ b/lib/psutil/tests/test_linux.py
@@ -0,0 +1,2286 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Linux specific tests."""
+
+from __future__ import division
+
+import collections
+import contextlib
+import errno
+import glob
+import io
+import os
+import re
+import shutil
+import socket
+import struct
+import textwrap
+import time
+import unittest
+import warnings
+
+import psutil
+from psutil import LINUX
+from psutil._compat import PY3
+from psutil._compat import FileNotFoundError
+from psutil._compat import basestring
+from psutil._compat import u
+from psutil.tests import GITHUB_ACTIONS
+from psutil.tests import GLOBAL_TIMEOUT
+from psutil.tests import HAS_BATTERY
+from psutil.tests import HAS_CPU_FREQ
+from psutil.tests import HAS_GETLOADAVG
+from psutil.tests import HAS_RLIMIT
+from psutil.tests import PYPY
+from psutil.tests import TOLERANCE_DISK_USAGE
+from psutil.tests import TOLERANCE_SYS_MEM
+from psutil.tests import PsutilTestCase
+from psutil.tests import ThreadTask
+from psutil.tests import call_until
+from psutil.tests import mock
+from psutil.tests import reload_module
+from psutil.tests import retry_on_failure
+from psutil.tests import safe_rmpath
+from psutil.tests import sh
+from psutil.tests import skip_on_not_implemented
+from psutil.tests import which
+
+
+if LINUX:
+ from psutil._pslinux import CLOCK_TICKS
+ from psutil._pslinux import RootFsDeviceFinder
+ from psutil._pslinux import calculate_avail_vmem
+ from psutil._pslinux import open_binary
+
+
+HERE = os.path.abspath(os.path.dirname(__file__))
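+# ioctl request numbers used by the address helpers below; the values
+# match the usual Linux SIOCGIF* constants (see <linux/sockios.h>).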
+SIOCGIFADDR = 0x8915
+SIOCGIFCONF = 0x8912
+SIOCGIFHWADDR = 0x8927
+SIOCGIFNETMASK = 0x891b
+SIOCGIFBRDADDR = 0x8919
+if LINUX:
+ SECTOR_SIZE = 512
+EMPTY_TEMPERATURES = not glob.glob('/sys/class/hwmon/hwmon*')
+
+
+# =====================================================================
+# --- utils
+# =====================================================================
+
+
+def get_ipv4_address(ifname):
+ import fcntl
+ ifname = ifname[:15]
+ if PY3:
+ ifname = bytes(ifname, 'ascii')
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ with contextlib.closing(s):
+ return socket.inet_ntoa(
+ fcntl.ioctl(s.fileno(),
+ SIOCGIFADDR,
+ struct.pack('256s', ifname))[20:24])
+
+
+def get_ipv4_netmask(ifname):
+ import fcntl
+ ifname = ifname[:15]
+ if PY3:
+ ifname = bytes(ifname, 'ascii')
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ with contextlib.closing(s):
+ return socket.inet_ntoa(
+ fcntl.ioctl(s.fileno(),
+ SIOCGIFNETMASK,
+ struct.pack('256s', ifname))[20:24])
+
+
+def get_ipv4_broadcast(ifname):
+ import fcntl
+ ifname = ifname[:15]
+ if PY3:
+ ifname = bytes(ifname, 'ascii')
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ with contextlib.closing(s):
+ return socket.inet_ntoa(
+ fcntl.ioctl(s.fileno(),
+ SIOCGIFBRDADDR,
+ struct.pack('256s', ifname))[20:24])
+
+
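+# Note: each /proc/net/if_inet6 line looks roughly like
+# "fe800000000000000000000000000001 02 40 20 80 eth0"
+# (hex address, ifindex, prefix length, scope, flags, interface name).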
+def get_ipv6_addresses(ifname):
+ with open("/proc/net/if_inet6", 'rt') as f:
+ all_fields = []
+ for line in f.readlines():
+ fields = line.split()
+ if fields[-1] == ifname:
+ all_fields.append(fields)
+
+ if len(all_fields) == 0:
+ raise ValueError("could not find interface %r" % ifname)
+
+ for i in range(0, len(all_fields)):
+ unformatted = all_fields[i][0]
+ groups = []
+ for j in range(0, len(unformatted), 4):
+ groups.append(unformatted[j:j + 4])
+ formatted = ":".join(groups)
+ packed = socket.inet_pton(socket.AF_INET6, formatted)
+ all_fields[i] = socket.inet_ntop(socket.AF_INET6, packed)
+ return all_fields
+
+
+def get_mac_address(ifname):
+ import fcntl
+ ifname = ifname[:15]
+ if PY3:
+ ifname = bytes(ifname, 'ascii')
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ with contextlib.closing(s):
+ info = fcntl.ioctl(
+ s.fileno(), SIOCGIFHWADDR, struct.pack('256s', ifname))
+ if PY3:
+ def ord(x):
+ return x
+ else:
+ import __builtin__
+ ord = __builtin__.ord
+ return ''.join(['%02x:' % ord(char) for char in info[18:24]])[:-1]
+
+
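+# Both helpers below parse "free -b" output, which looks roughly like
+# this (numbers are made up):
+#
+#          total        used        free      shared  buff/cache  available
+# Mem:   16325648    9444728     2057400      577588     4823547    6574984
+# Swap:   2147479          0     2147479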
+def free_swap():
+    """Parse 'free' cmd and return swap memory's total, used and free
+ values.
+ """
+ out = sh(["free", "-b"], env={"LANG": "C.UTF-8"})
+ lines = out.split('\n')
+ for line in lines:
+ if line.startswith('Swap'):
+ _, total, used, free = line.split()
+ nt = collections.namedtuple('free', 'total used free')
+ return nt(int(total), int(used), int(free))
+ raise ValueError(
+ "can't find 'Swap' in 'free' output:\n%s" % '\n'.join(lines))
+
+
+def free_physmem():
+ """Parse 'free' cmd and return physical memory's total, used
+ and free values.
+ """
+    # Note: 'free' can print its output in 2 different formats; the
+    # 'shared' and 'cached' columns may appear in different positions,
+    # so we do not return them.
+ # https://github.com/giampaolo/psutil/issues/538#issuecomment-57059946
+ out = sh(["free", "-b"], env={"LANG": "C.UTF-8"})
+ lines = out.split('\n')
+ for line in lines:
+ if line.startswith('Mem'):
+ total, used, free, shared = \
+ [int(x) for x in line.split()[1:5]]
+ nt = collections.namedtuple(
+ 'free', 'total used free shared output')
+ return nt(total, used, free, shared, out)
+ raise ValueError(
+ "can't find 'Mem' in 'free' output:\n%s" % '\n'.join(lines))
+
+
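+# "vmstat -s" prints one statistic per line, roughly like
+# "     16325648 K total memory"; the helper below returns the leading
+# number of the requested statistic.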
+def vmstat(stat):
+ out = sh(["vmstat", "-s"], env={"LANG": "C.UTF-8"})
+ for line in out.split("\n"):
+ line = line.strip()
+ if stat in line:
+ return int(line.split(' ')[0])
+ raise ValueError("can't find %r in 'vmstat' output" % stat)
+
+
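+# "free -V" prints something like "free from procps-ng 3.3.17"; only the
+# trailing version number is of interest here.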
+def get_free_version_info():
+ out = sh(["free", "-V"]).strip()
+ if 'UNKNOWN' in out:
+ raise unittest.SkipTest("can't determine free version")
+ return tuple(map(int, out.split()[-1].split('.')))
+
+
+@contextlib.contextmanager
+def mock_open_content(for_path, content):
+    """Mock the open() builtin, forcing it to return a certain `content`
+    on read() if the path being opened matches `for_path`.
+ """
+ def open_mock(name, *args, **kwargs):
+ if name == for_path:
+ if PY3:
+ if isinstance(content, basestring):
+ return io.StringIO(content)
+ else:
+ return io.BytesIO(content)
+ else:
+ return io.BytesIO(content)
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, create=True, side_effect=open_mock) as m:
+ yield m
+
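+# Minimal usage sketch for mock_open_content() (illustrative only; the
+# '/proc/fake' path is hypothetical):
+#
+#   with mock_open_content('/proc/fake', "hello") as m:
+#       assert open('/proc/fake').read() == "hello"
+#       assert m.called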
+
+@contextlib.contextmanager
+def mock_open_exception(for_path, exc):
+    """Mock the open() builtin, raising `exc` if the path being opened
+    matches `for_path`.
+ """
+ def open_mock(name, *args, **kwargs):
+ if name == for_path:
+ raise exc
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, create=True, side_effect=open_mock) as m:
+ yield m
+
+
+# =====================================================================
+# --- system virtual memory
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemVirtualMemory(PsutilTestCase):
+
+ def test_total(self):
+ # free_value = free_physmem().total
+ # psutil_value = psutil.virtual_memory().total
+ # self.assertEqual(free_value, psutil_value)
+ vmstat_value = vmstat('total memory') * 1024
+ psutil_value = psutil.virtual_memory().total
+ self.assertAlmostEqual(
+ vmstat_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_used(self):
+ # Older versions of procps used slab memory to calculate used memory.
+ # This got changed in:
+ # https://gitlab.com/procps-ng/procps/commit/
+ # 05d751c4f076a2f0118b914c5e51cfbb4762ad8e
+ if get_free_version_info() < (3, 3, 12):
+ raise self.skipTest("old free version")
+ free = free_physmem()
+ free_value = free.used
+ psutil_value = psutil.virtual_memory().used
+ self.assertAlmostEqual(
+ free_value, psutil_value, delta=TOLERANCE_SYS_MEM,
+ msg='%s %s \n%s' % (free_value, psutil_value, free.output))
+
+ @retry_on_failure()
+ def test_free(self):
+ vmstat_value = vmstat('free memory') * 1024
+ psutil_value = psutil.virtual_memory().free
+ self.assertAlmostEqual(
+ vmstat_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_buffers(self):
+ vmstat_value = vmstat('buffer memory') * 1024
+ psutil_value = psutil.virtual_memory().buffers
+ self.assertAlmostEqual(
+ vmstat_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_active(self):
+ vmstat_value = vmstat('active memory') * 1024
+ psutil_value = psutil.virtual_memory().active
+ self.assertAlmostEqual(
+ vmstat_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_inactive(self):
+ vmstat_value = vmstat('inactive memory') * 1024
+ psutil_value = psutil.virtual_memory().inactive
+ self.assertAlmostEqual(
+ vmstat_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_shared(self):
+ free = free_physmem()
+ free_value = free.shared
+ if free_value == 0:
+ raise unittest.SkipTest("free does not support 'shared' column")
+ psutil_value = psutil.virtual_memory().shared
+ self.assertAlmostEqual(
+ free_value, psutil_value, delta=TOLERANCE_SYS_MEM,
+ msg='%s %s \n%s' % (free_value, psutil_value, free.output))
+
+ @retry_on_failure()
+ def test_available(self):
+ # "free" output format has changed at some point:
+ # https://github.com/giampaolo/psutil/issues/538#issuecomment-147192098
+ out = sh(["free", "-b"])
+ lines = out.split('\n')
+ if 'available' not in lines[0]:
+ raise unittest.SkipTest("free does not support 'available' column")
+ else:
+ free_value = int(lines[1].split()[-1])
+ psutil_value = psutil.virtual_memory().available
+ self.assertAlmostEqual(
+ free_value, psutil_value, delta=TOLERANCE_SYS_MEM,
+ msg='%s %s \n%s' % (free_value, psutil_value, out))
+
+ def test_warnings_on_misses(self):
+        # Emulate a case where /proc/meminfo provides only some fields.
+ # psutil is supposed to set the missing fields to 0 and
+ # raise a warning.
+ with mock_open_content(
+ '/proc/meminfo',
+ textwrap.dedent("""\
+ Active(anon): 6145416 kB
+ Active(file): 2950064 kB
+ Inactive(anon): 574764 kB
+ Inactive(file): 1567648 kB
+ MemAvailable: -1 kB
+ MemFree: 2057400 kB
+ MemTotal: 16325648 kB
+ SReclaimable: 346648 kB
+ """).encode()) as m:
+ with warnings.catch_warnings(record=True) as ws:
+ warnings.simplefilter("always")
+ ret = psutil.virtual_memory()
+ assert m.called
+ self.assertEqual(len(ws), 1)
+ w = ws[0]
+ assert w.filename.endswith('psutil/_pslinux.py')
+ self.assertIn(
+ "memory stats couldn't be determined", str(w.message))
+ self.assertIn("cached", str(w.message))
+ self.assertIn("shared", str(w.message))
+ self.assertIn("active", str(w.message))
+ self.assertIn("inactive", str(w.message))
+ self.assertIn("buffers", str(w.message))
+ self.assertIn("available", str(w.message))
+ self.assertEqual(ret.cached, 0)
+ self.assertEqual(ret.active, 0)
+ self.assertEqual(ret.inactive, 0)
+ self.assertEqual(ret.shared, 0)
+ self.assertEqual(ret.buffers, 0)
+ self.assertEqual(ret.available, 0)
+ self.assertEqual(ret.slab, 0)
+
+ @retry_on_failure()
+ def test_avail_old_percent(self):
+ # Make sure that our calculation of avail mem for old kernels
+        # is off by at most 15%.
+ mems = {}
+ with open_binary('/proc/meminfo') as f:
+ for line in f:
+ fields = line.split()
+ mems[fields[0]] = int(fields[1]) * 1024
+
+ a = calculate_avail_vmem(mems)
+ if b'MemAvailable:' in mems:
+ b = mems[b'MemAvailable:']
+ diff_percent = abs(a - b) / a * 100
+ self.assertLess(diff_percent, 15)
+
+ def test_avail_old_comes_from_kernel(self):
+        # Make sure the "MemAvailable:" column is used instead of relying
+ # on our internal algorithm to calculate avail mem.
+ with mock_open_content(
+ '/proc/meminfo',
+ textwrap.dedent("""\
+ Active: 9444728 kB
+ Active(anon): 6145416 kB
+ Active(file): 2950064 kB
+ Buffers: 287952 kB
+ Cached: 4818144 kB
+ Inactive(file): 1578132 kB
+ Inactive(anon): 574764 kB
+ Inactive(file): 1567648 kB
+ MemAvailable: 6574984 kB
+ MemFree: 2057400 kB
+ MemTotal: 16325648 kB
+ Shmem: 577588 kB
+ SReclaimable: 346648 kB
+ """).encode()) as m:
+ with warnings.catch_warnings(record=True) as ws:
+ ret = psutil.virtual_memory()
+ assert m.called
+ self.assertEqual(ret.available, 6574984 * 1024)
+ w = ws[0]
+ self.assertIn(
+ "inactive memory stats couldn't be determined", str(w.message))
+
+ def test_avail_old_missing_fields(self):
+ # Remove Active(file), Inactive(file) and SReclaimable
+ # from /proc/meminfo and make sure the fallback is used
+        # (free + cached).
+ with mock_open_content(
+ "/proc/meminfo",
+ textwrap.dedent("""\
+ Active: 9444728 kB
+ Active(anon): 6145416 kB
+ Buffers: 287952 kB
+ Cached: 4818144 kB
+ Inactive(file): 1578132 kB
+ Inactive(anon): 574764 kB
+ MemFree: 2057400 kB
+ MemTotal: 16325648 kB
+ Shmem: 577588 kB
+ """).encode()) as m:
+ with warnings.catch_warnings(record=True) as ws:
+ ret = psutil.virtual_memory()
+ assert m.called
+ self.assertEqual(ret.available, 2057400 * 1024 + 4818144 * 1024)
+ w = ws[0]
+ self.assertIn(
+ "inactive memory stats couldn't be determined", str(w.message))
+
+ def test_avail_old_missing_zoneinfo(self):
+ # Remove /proc/zoneinfo file. Make sure fallback is used
+ # (free + cached).
+ with mock_open_content(
+ "/proc/meminfo",
+ textwrap.dedent("""\
+ Active: 9444728 kB
+ Active(anon): 6145416 kB
+ Active(file): 2950064 kB
+ Buffers: 287952 kB
+ Cached: 4818144 kB
+ Inactive(file): 1578132 kB
+ Inactive(anon): 574764 kB
+ Inactive(file): 1567648 kB
+ MemFree: 2057400 kB
+ MemTotal: 16325648 kB
+ Shmem: 577588 kB
+ SReclaimable: 346648 kB
+ """).encode()):
+ with mock_open_exception(
+ "/proc/zoneinfo",
+ IOError(errno.ENOENT, 'no such file or directory')):
+ with warnings.catch_warnings(record=True) as ws:
+ ret = psutil.virtual_memory()
+ self.assertEqual(
+ ret.available, 2057400 * 1024 + 4818144 * 1024)
+ w = ws[0]
+ self.assertIn(
+ "inactive memory stats couldn't be determined",
+ str(w.message))
+
+ def test_virtual_memory_mocked(self):
+ # Emulate /proc/meminfo because neither vmstat nor free return slab.
+ def open_mock(name, *args, **kwargs):
+ if name == '/proc/meminfo':
+ return io.BytesIO(textwrap.dedent("""\
+ MemTotal: 100 kB
+ MemFree: 2 kB
+ MemAvailable: 3 kB
+ Buffers: 4 kB
+ Cached: 5 kB
+ SwapCached: 6 kB
+ Active: 7 kB
+ Inactive: 8 kB
+ Active(anon): 9 kB
+ Inactive(anon): 10 kB
+ Active(file): 11 kB
+ Inactive(file): 12 kB
+ Unevictable: 13 kB
+ Mlocked: 14 kB
+ SwapTotal: 15 kB
+ SwapFree: 16 kB
+ Dirty: 17 kB
+ Writeback: 18 kB
+ AnonPages: 19 kB
+ Mapped: 20 kB
+ Shmem: 21 kB
+ Slab: 22 kB
+ SReclaimable: 23 kB
+ SUnreclaim: 24 kB
+ KernelStack: 25 kB
+ PageTables: 26 kB
+ NFS_Unstable: 27 kB
+ Bounce: 28 kB
+ WritebackTmp: 29 kB
+ CommitLimit: 30 kB
+ Committed_AS: 31 kB
+ VmallocTotal: 32 kB
+ VmallocUsed: 33 kB
+ VmallocChunk: 34 kB
+ HardwareCorrupted: 35 kB
+ AnonHugePages: 36 kB
+ ShmemHugePages: 37 kB
+ ShmemPmdMapped: 38 kB
+ CmaTotal: 39 kB
+ CmaFree: 40 kB
+ HugePages_Total: 41 kB
+ HugePages_Free: 42 kB
+ HugePages_Rsvd: 43 kB
+ HugePages_Surp: 44 kB
+ Hugepagesize: 45 kB
+ DirectMap46k: 46 kB
+ DirectMap47M: 47 kB
+ DirectMap48G: 48 kB
+ """).encode())
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, create=True, side_effect=open_mock) as m:
+ mem = psutil.virtual_memory()
+ assert m.called
+ self.assertEqual(mem.total, 100 * 1024)
+ self.assertEqual(mem.free, 2 * 1024)
+ self.assertEqual(mem.buffers, 4 * 1024)
+ # cached mem also includes reclaimable memory
+ self.assertEqual(mem.cached, (5 + 23) * 1024)
+ self.assertEqual(mem.shared, 21 * 1024)
+ self.assertEqual(mem.active, 7 * 1024)
+ self.assertEqual(mem.inactive, 8 * 1024)
+ self.assertEqual(mem.slab, 22 * 1024)
+ self.assertEqual(mem.available, 3 * 1024)
+
+
+# =====================================================================
+# --- system swap memory
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemSwapMemory(PsutilTestCase):
+
+ @staticmethod
+ def meminfo_has_swap_info():
+ """Return True if /proc/meminfo provides swap metrics."""
+ with open("/proc/meminfo") as f:
+ data = f.read()
+ return 'SwapTotal:' in data and 'SwapFree:' in data
+
+ def test_total(self):
+ free_value = free_swap().total
+ psutil_value = psutil.swap_memory().total
+ return self.assertAlmostEqual(
+ free_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_used(self):
+ free_value = free_swap().used
+ psutil_value = psutil.swap_memory().used
+ return self.assertAlmostEqual(
+ free_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_free(self):
+ free_value = free_swap().free
+ psutil_value = psutil.swap_memory().free
+ return self.assertAlmostEqual(
+ free_value, psutil_value, delta=TOLERANCE_SYS_MEM)
+
+ def test_missing_sin_sout(self):
+ with mock.patch('psutil._common.open', create=True) as m:
+ with warnings.catch_warnings(record=True) as ws:
+ warnings.simplefilter("always")
+ ret = psutil.swap_memory()
+ assert m.called
+ self.assertEqual(len(ws), 1)
+ w = ws[0]
+ assert w.filename.endswith('psutil/_pslinux.py')
+ self.assertIn(
+ "'sin' and 'sout' swap memory stats couldn't "
+ "be determined", str(w.message))
+ self.assertEqual(ret.sin, 0)
+ self.assertEqual(ret.sout, 0)
+
+ def test_no_vmstat_mocked(self):
+ # see https://github.com/giampaolo/psutil/issues/722
+ with mock_open_exception(
+ "/proc/vmstat",
+ IOError(errno.ENOENT, 'no such file or directory')) as m:
+ with warnings.catch_warnings(record=True) as ws:
+ warnings.simplefilter("always")
+ ret = psutil.swap_memory()
+ assert m.called
+ self.assertEqual(len(ws), 1)
+ w = ws[0]
+ assert w.filename.endswith('psutil/_pslinux.py')
+ self.assertIn(
+ "'sin' and 'sout' swap memory stats couldn't "
+ "be determined and were set to 0",
+ str(w.message))
+ self.assertEqual(ret.sin, 0)
+ self.assertEqual(ret.sout, 0)
+
+ def test_meminfo_against_sysinfo(self):
+ # Make sure the content of /proc/meminfo about swap memory
+ # matches sysinfo() syscall, see:
+ # https://github.com/giampaolo/psutil/issues/1015
+ if not self.meminfo_has_swap_info():
+            raise unittest.SkipTest("/proc/meminfo has no swap metrics")
+ with mock.patch('psutil._pslinux.cext.linux_sysinfo') as m:
+ swap = psutil.swap_memory()
+ assert not m.called
+ import psutil._psutil_linux as cext
+ _, _, _, _, total, free, unit_multiplier = cext.linux_sysinfo()
+ total *= unit_multiplier
+ free *= unit_multiplier
+ self.assertEqual(swap.total, total)
+ self.assertAlmostEqual(swap.free, free, delta=TOLERANCE_SYS_MEM)
+
+ def test_emulate_meminfo_has_no_metrics(self):
+ # Emulate a case where /proc/meminfo provides no swap metrics
+ # in which case sysinfo() syscall is supposed to be used
+ # as a fallback.
+ with mock_open_content("/proc/meminfo", b"") as m:
+ psutil.swap_memory()
+ assert m.called
+
+
+# =====================================================================
+# --- system CPU
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemCPUTimes(PsutilTestCase):
+
+ def test_fields(self):
+ fields = psutil.cpu_times()._fields
+ kernel_ver = re.findall(r'\d+\.\d+\.\d+', os.uname()[2])[0]
+ kernel_ver_info = tuple(map(int, kernel_ver.split('.')))
+ if kernel_ver_info >= (2, 6, 11):
+ self.assertIn('steal', fields)
+ else:
+ self.assertNotIn('steal', fields)
+ if kernel_ver_info >= (2, 6, 24):
+ self.assertIn('guest', fields)
+ else:
+ self.assertNotIn('guest', fields)
+ if kernel_ver_info >= (3, 2, 0):
+ self.assertIn('guest_nice', fields)
+ else:
+ self.assertNotIn('guest_nice', fields)
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemCPUCountLogical(PsutilTestCase):
+
+ @unittest.skipIf(not os.path.exists("/sys/devices/system/cpu/online"),
+ "/sys/devices/system/cpu/online does not exist")
+ def test_against_sysdev_cpu_online(self):
+ with open("/sys/devices/system/cpu/online") as f:
+ value = f.read().strip()
+ if "-" in str(value):
+ value = int(value.split('-')[1]) + 1
+ self.assertEqual(psutil.cpu_count(), value)
+
+ @unittest.skipIf(not os.path.exists("/sys/devices/system/cpu"),
+ "/sys/devices/system/cpu does not exist")
+ def test_against_sysdev_cpu_num(self):
+ ls = os.listdir("/sys/devices/system/cpu")
+ count = len([x for x in ls if re.search(r"cpu\d+$", x) is not None])
+ self.assertEqual(psutil.cpu_count(), count)
+
+ @unittest.skipIf(not which("nproc"), "nproc utility not available")
+ def test_against_nproc(self):
+ num = int(sh("nproc --all"))
+ self.assertEqual(psutil.cpu_count(logical=True), num)
+
+ @unittest.skipIf(not which("lscpu"), "lscpu utility not available")
+ def test_against_lscpu(self):
+ out = sh("lscpu -p")
+ num = len([x for x in out.split('\n') if not x.startswith('#')])
+ self.assertEqual(psutil.cpu_count(logical=True), num)
+
+ def test_emulate_fallbacks(self):
+ import psutil._pslinux
+ original = psutil._pslinux.cpu_count_logical()
+ # Here we want to mock os.sysconf("SC_NPROCESSORS_ONLN") in
+ # order to cause the parsing of /proc/cpuinfo and /proc/stat.
+ with mock.patch(
+ 'psutil._pslinux.os.sysconf', side_effect=ValueError) as m:
+ self.assertEqual(psutil._pslinux.cpu_count_logical(), original)
+ assert m.called
+
+ # Let's have open() return empty data and make sure None is
+ # returned ('cause we mimic os.cpu_count()).
+ with mock.patch('psutil._common.open', create=True) as m:
+ self.assertIsNone(psutil._pslinux.cpu_count_logical())
+ self.assertEqual(m.call_count, 2)
+ # /proc/stat should be the last one
+ self.assertEqual(m.call_args[0][0], '/proc/stat')
+
+ # Let's push this a bit further and make sure /proc/cpuinfo
+ # parsing works as expected.
+ with open('/proc/cpuinfo', 'rb') as f:
+ cpuinfo_data = f.read()
+ fake_file = io.BytesIO(cpuinfo_data)
+ with mock.patch('psutil._common.open',
+ return_value=fake_file, create=True) as m:
+ self.assertEqual(psutil._pslinux.cpu_count_logical(), original)
+
+ # Finally, let's make /proc/cpuinfo return meaningless data;
+ # this way we'll fall back on relying on /proc/stat
+ with mock_open_content('/proc/cpuinfo', b"") as m:
+ self.assertEqual(psutil._pslinux.cpu_count_logical(), original)
+            assert m.called
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemCPUCountCores(PsutilTestCase):
+
+ @unittest.skipIf(not which("lscpu"), "lscpu utility not available")
+ def test_against_lscpu(self):
+ out = sh("lscpu -p")
+ core_ids = set()
+ for line in out.split('\n'):
+ if not line.startswith('#'):
+ fields = line.split(',')
+ core_ids.add(fields[1])
+ self.assertEqual(psutil.cpu_count(logical=False), len(core_ids))
+
+ def test_method_2(self):
+ meth_1 = psutil._pslinux.cpu_count_cores()
+ with mock.patch('glob.glob', return_value=[]) as m:
+ meth_2 = psutil._pslinux.cpu_count_cores()
+ assert m.called
+ if meth_1 is not None:
+ self.assertEqual(meth_1, meth_2)
+
+ def test_emulate_none(self):
+ with mock.patch('glob.glob', return_value=[]) as m1:
+ with mock.patch('psutil._common.open', create=True) as m2:
+ self.assertIsNone(psutil._pslinux.cpu_count_cores())
+ assert m1.called
+ assert m2.called
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemCPUFrequency(PsutilTestCase):
+
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_emulate_use_second_file(self):
+ # https://github.com/giampaolo/psutil/issues/981
+ def path_exists_mock(path):
+ if path.startswith("/sys/devices/system/cpu/cpufreq/policy"):
+ return False
+ else:
+ return orig_exists(path)
+
+ orig_exists = os.path.exists
+ with mock.patch("os.path.exists", side_effect=path_exists_mock,
+ create=True):
+ assert psutil.cpu_freq()
+
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_emulate_use_cpuinfo(self):
+ # Emulate a case where /sys/devices/system/cpu/cpufreq* does not
+ # exist and /proc/cpuinfo is used instead.
+ def path_exists_mock(path):
+ if path.startswith('/sys/devices/system/cpu/'):
+ return False
+ else:
+ return os_path_exists(path)
+
+ os_path_exists = os.path.exists
+ try:
+ with mock.patch("os.path.exists", side_effect=path_exists_mock):
+ reload_module(psutil._pslinux)
+ ret = psutil.cpu_freq()
+ assert ret
+ self.assertEqual(ret.max, 0.0)
+ self.assertEqual(ret.min, 0.0)
+ for freq in psutil.cpu_freq(percpu=True):
+                    self.assertEqual(freq.max, 0.0)
+                    self.assertEqual(freq.min, 0.0)
+ finally:
+ reload_module(psutil._pslinux)
+ reload_module(psutil)
+
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_emulate_data(self):
+ def open_mock(name, *args, **kwargs):
+ if (name.endswith('/scaling_cur_freq') and
+ name.startswith("/sys/devices/system/cpu/cpufreq/policy")):
+ return io.BytesIO(b"500000")
+ elif (name.endswith('/scaling_min_freq') and
+ name.startswith("/sys/devices/system/cpu/cpufreq/policy")):
+ return io.BytesIO(b"600000")
+ elif (name.endswith('/scaling_max_freq') and
+ name.startswith("/sys/devices/system/cpu/cpufreq/policy")):
+ return io.BytesIO(b"700000")
+ elif name == '/proc/cpuinfo':
+ return io.BytesIO(b"cpu MHz : 500")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock):
+ with mock.patch(
+ 'os.path.exists', return_value=True):
+ freq = psutil.cpu_freq()
+ self.assertEqual(freq.current, 500.0)
+ # when /proc/cpuinfo is used min and max frequencies are not
+ # available and are set to 0.
+ if freq.min != 0.0:
+ self.assertEqual(freq.min, 600.0)
+ if freq.max != 0.0:
+ self.assertEqual(freq.max, 700.0)
+
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_emulate_multi_cpu(self):
+ def open_mock(name, *args, **kwargs):
+ n = name
+ if (n.endswith('/scaling_cur_freq') and
+ n.startswith("/sys/devices/system/cpu/cpufreq/policy0")):
+ return io.BytesIO(b"100000")
+ elif (n.endswith('/scaling_min_freq') and
+ n.startswith("/sys/devices/system/cpu/cpufreq/policy0")):
+ return io.BytesIO(b"200000")
+ elif (n.endswith('/scaling_max_freq') and
+ n.startswith("/sys/devices/system/cpu/cpufreq/policy0")):
+ return io.BytesIO(b"300000")
+ elif (n.endswith('/scaling_cur_freq') and
+ n.startswith("/sys/devices/system/cpu/cpufreq/policy1")):
+ return io.BytesIO(b"400000")
+ elif (n.endswith('/scaling_min_freq') and
+ n.startswith("/sys/devices/system/cpu/cpufreq/policy1")):
+ return io.BytesIO(b"500000")
+ elif (n.endswith('/scaling_max_freq') and
+ n.startswith("/sys/devices/system/cpu/cpufreq/policy1")):
+ return io.BytesIO(b"600000")
+ elif name == '/proc/cpuinfo':
+ return io.BytesIO(b"cpu MHz : 100\n"
+ b"cpu MHz : 400")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock):
+ with mock.patch('os.path.exists', return_value=True):
+ with mock.patch('psutil._pslinux.cpu_count_logical',
+ return_value=2):
+ freq = psutil.cpu_freq(percpu=True)
+ self.assertEqual(freq[0].current, 100.0)
+ if freq[0].min != 0.0:
+ self.assertEqual(freq[0].min, 200.0)
+ if freq[0].max != 0.0:
+ self.assertEqual(freq[0].max, 300.0)
+ self.assertEqual(freq[1].current, 400.0)
+ if freq[1].min != 0.0:
+ self.assertEqual(freq[1].min, 500.0)
+ if freq[1].max != 0.0:
+ self.assertEqual(freq[1].max, 600.0)
+
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_emulate_no_scaling_cur_freq_file(self):
+ # See: https://github.com/giampaolo/psutil/issues/1071
+ def open_mock(name, *args, **kwargs):
+ if name.endswith('/scaling_cur_freq'):
+ raise IOError(errno.ENOENT, "")
+ elif name.endswith('/cpuinfo_cur_freq'):
+ return io.BytesIO(b"200000")
+ elif name == '/proc/cpuinfo':
+ return io.BytesIO(b"cpu MHz : 200")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock):
+ with mock.patch('os.path.exists', return_value=True):
+ with mock.patch('psutil._pslinux.cpu_count_logical',
+ return_value=1):
+ freq = psutil.cpu_freq()
+ self.assertEqual(freq.current, 200)
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemCPUStats(PsutilTestCase):
+
+ def test_ctx_switches(self):
+ vmstat_value = vmstat("context switches")
+ psutil_value = psutil.cpu_stats().ctx_switches
+ self.assertAlmostEqual(vmstat_value, psutil_value, delta=500)
+
+ def test_interrupts(self):
+ vmstat_value = vmstat("interrupts")
+ psutil_value = psutil.cpu_stats().interrupts
+ self.assertAlmostEqual(vmstat_value, psutil_value, delta=500)
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestLoadAvg(PsutilTestCase):
+
+ @unittest.skipIf(not HAS_GETLOADAVG, "not supported")
+ def test_getloadavg(self):
+ psutil_value = psutil.getloadavg()
+ with open("/proc/loadavg", "r") as f:
+ proc_value = f.read().split()
+
+ self.assertAlmostEqual(float(proc_value[0]), psutil_value[0], delta=1)
+ self.assertAlmostEqual(float(proc_value[1]), psutil_value[1], delta=1)
+ self.assertAlmostEqual(float(proc_value[2]), psutil_value[2], delta=1)
+
+
+# =====================================================================
+# --- system network
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemNetIfAddrs(PsutilTestCase):
+
+ def test_ips(self):
+ for name, addrs in psutil.net_if_addrs().items():
+ for addr in addrs:
+ if addr.family == psutil.AF_LINK:
+ self.assertEqual(addr.address, get_mac_address(name))
+ elif addr.family == socket.AF_INET:
+ self.assertEqual(addr.address, get_ipv4_address(name))
+ self.assertEqual(addr.netmask, get_ipv4_netmask(name))
+ if addr.broadcast is not None:
+ self.assertEqual(addr.broadcast,
+ get_ipv4_broadcast(name))
+ else:
+ self.assertEqual(get_ipv4_broadcast(name), '0.0.0.0')
+ elif addr.family == socket.AF_INET6:
+ # IPv6 addresses can have a percent symbol at the end.
+ # E.g. these 2 are equivalent:
+ # "fe80::1ff:fe23:4567:890a"
+ # "fe80::1ff:fe23:4567:890a%eth0"
+ # That is the "zone id" portion, which usually is the name
+ # of the network interface.
+ address = addr.address.split('%')[0]
+ self.assertIn(address, get_ipv6_addresses(name))
+
+    # XXX - not reliable when virtual NICs are installed by Docker.
+ # @unittest.skipIf(not which('ip'), "'ip' utility not available")
+ # def test_net_if_names(self):
+ # out = sh("ip addr").strip()
+ # nics = [x for x in psutil.net_if_addrs().keys() if ':' not in x]
+ # found = 0
+ # for line in out.split('\n'):
+ # line = line.strip()
+ # if re.search(r"^\d+:", line):
+ # found += 1
+ # name = line.split(':')[1].strip()
+ # self.assertIn(name, nics)
+ # self.assertEqual(len(nics), found, msg="%s\n---\n%s" % (
+ # pprint.pformat(nics), out))
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemNetIfStats(PsutilTestCase):
+
+ @unittest.skipIf(not which("ifconfig"), "ifconfig utility not available")
+ def test_against_ifconfig(self):
+ for name, stats in psutil.net_if_stats().items():
+ try:
+ out = sh("ifconfig %s" % name)
+ except RuntimeError:
+ pass
+ else:
+ self.assertEqual(stats.isup, 'RUNNING' in out, msg=out)
+ self.assertEqual(stats.mtu,
+ int(re.findall(r'(?i)MTU[: ](\d+)', out)[0]))
+
+ def test_mtu(self):
+ for name, stats in psutil.net_if_stats().items():
+ with open("/sys/class/net/%s/mtu" % name, "rt") as f:
+ self.assertEqual(stats.mtu, int(f.read().strip()))
+
+ @unittest.skipIf(not which("ifconfig"), "ifconfig utility not available")
+ def test_flags(self):
+ # first line looks like this:
+ # "eth0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST> mtu 1500"
+ matches_found = 0
+ for name, stats in psutil.net_if_stats().items():
+ try:
+ out = sh("ifconfig %s" % name)
+ except RuntimeError:
+ pass
+ else:
+ match = re.search(r"flags=(\d+)?<(.*?)>", out)
+ if match and len(match.groups()) >= 2:
+ matches_found += 1
+ ifconfig_flags = set(match.group(2).lower().split(","))
+ psutil_flags = set(stats.flags.split(","))
+ self.assertEqual(ifconfig_flags, psutil_flags)
+ else:
+ # ifconfig has a different output on CentOS 6
+ # let's try that
+ match = re.search(r"(.*) MTU:(\d+) Metric:(\d+)", out)
+ if match and len(match.groups()) >= 3:
+ matches_found += 1
+ ifconfig_flags = set(match.group(1).lower().split())
+ psutil_flags = set(stats.flags.split(","))
+ self.assertEqual(ifconfig_flags, psutil_flags)
+
+ if not matches_found:
+            self.fail("no matches were found")
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemNetIOCounters(PsutilTestCase):
+
+ @unittest.skipIf(not which("ifconfig"), "ifconfig utility not available")
+ @retry_on_failure()
+ def test_against_ifconfig(self):
+ def ifconfig(nic):
+ ret = {}
+ out = sh("ifconfig %s" % name)
+ ret['packets_recv'] = int(
+ re.findall(r'RX packets[: ](\d+)', out)[0])
+ ret['packets_sent'] = int(
+ re.findall(r'TX packets[: ](\d+)', out)[0])
+ ret['errin'] = int(re.findall(r'errors[: ](\d+)', out)[0])
+ ret['errout'] = int(re.findall(r'errors[: ](\d+)', out)[1])
+ ret['dropin'] = int(re.findall(r'dropped[: ](\d+)', out)[0])
+ ret['dropout'] = int(re.findall(r'dropped[: ](\d+)', out)[1])
+ ret['bytes_recv'] = int(
+ re.findall(r'RX (?:packets \d+ +)?bytes[: ](\d+)', out)[0])
+ ret['bytes_sent'] = int(
+ re.findall(r'TX (?:packets \d+ +)?bytes[: ](\d+)', out)[0])
+ return ret
+
+ nio = psutil.net_io_counters(pernic=True, nowrap=False)
+ for name, stats in nio.items():
+ try:
+ ifconfig_ret = ifconfig(name)
+ except RuntimeError:
+ continue
+ self.assertAlmostEqual(
+ stats.bytes_recv, ifconfig_ret['bytes_recv'], delta=1024 * 5)
+ self.assertAlmostEqual(
+ stats.bytes_sent, ifconfig_ret['bytes_sent'], delta=1024 * 5)
+ self.assertAlmostEqual(
+ stats.packets_recv, ifconfig_ret['packets_recv'], delta=1024)
+ self.assertAlmostEqual(
+ stats.packets_sent, ifconfig_ret['packets_sent'], delta=1024)
+ self.assertAlmostEqual(
+ stats.errin, ifconfig_ret['errin'], delta=10)
+ self.assertAlmostEqual(
+ stats.errout, ifconfig_ret['errout'], delta=10)
+ self.assertAlmostEqual(
+ stats.dropin, ifconfig_ret['dropin'], delta=10)
+ self.assertAlmostEqual(
+ stats.dropout, ifconfig_ret['dropout'], delta=10)
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemNetConnections(PsutilTestCase):
+
+ @mock.patch('psutil._pslinux.socket.inet_ntop', side_effect=ValueError)
+ @mock.patch('psutil._pslinux.supports_ipv6', return_value=False)
+ def test_emulate_ipv6_unsupported(self, supports_ipv6, inet_ntop):
+ # see: https://github.com/giampaolo/psutil/issues/623
+ try:
+ s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ self.addCleanup(s.close)
+ s.bind(("::1", 0))
+ except socket.error:
+ pass
+ psutil.net_connections(kind='inet6')
+
+ def test_emulate_unix(self):
+ with mock_open_content(
+ '/proc/net/unix',
+ textwrap.dedent("""\
+ 0: 00000003 000 000 0001 03 462170 @/tmp/dbus-Qw2hMPIU3n
+ 0: 00000003 000 000 0001 03 35010 @/tmp/dbus-tB2X8h69BQ
+ 0: 00000003 000 000 0001 03 34424 @/tmp/dbus-cHy80Y8O
+ 000000000000000000000000000000000000000000000000000000
+ """)) as m:
+ psutil.net_connections(kind='unix')
+ assert m.called
+
+
+# =====================================================================
+# --- system disks
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemDiskPartitions(PsutilTestCase):
+
+ @unittest.skipIf(not hasattr(os, 'statvfs'), "os.statvfs() not available")
+ @skip_on_not_implemented()
+ def test_against_df(self):
+ # test psutil.disk_usage() and psutil.disk_partitions()
+ # against "df -a"
+ def df(path):
+ out = sh('df -P -B 1 "%s"' % path).strip()
+ lines = out.split('\n')
+ lines.pop(0)
+ line = lines.pop(0)
+ dev, total, used, free = line.split()[:4]
+ if dev == 'none':
+ dev = ''
+ total, used, free = int(total), int(used), int(free)
+ return dev, total, used, free
+
+ for part in psutil.disk_partitions(all=False):
+ usage = psutil.disk_usage(part.mountpoint)
+ dev, total, used, free = df(part.mountpoint)
+ self.assertEqual(usage.total, total)
+ self.assertAlmostEqual(usage.free, free,
+ delta=TOLERANCE_DISK_USAGE)
+ self.assertAlmostEqual(usage.used, used,
+ delta=TOLERANCE_DISK_USAGE)
+
+ def test_zfs_fs(self):
+ # Test that ZFS partitions are returned.
+ with open("/proc/filesystems", "r") as f:
+ data = f.read()
+ if 'zfs' in data:
+ for part in psutil.disk_partitions():
+ if part.fstype == 'zfs':
+ break
+ else:
+                self.fail("couldn't find any ZFS partition")
+ else:
+ # No ZFS partitions on this system. Let's fake one.
+ fake_file = io.StringIO(u("nodev\tzfs\n"))
+ with mock.patch('psutil._common.open',
+ return_value=fake_file, create=True) as m1:
+ with mock.patch(
+ 'psutil._pslinux.cext.disk_partitions',
+ return_value=[('/dev/sdb3', '/', 'zfs', 'rw')]) as m2:
+ ret = psutil.disk_partitions()
+ assert m1.called
+ assert m2.called
+ assert ret
+ self.assertEqual(ret[0].fstype, 'zfs')
+
+ def test_emulate_realpath_fail(self):
+ # See: https://github.com/giampaolo/psutil/issues/1307
+ try:
+ with mock.patch('os.path.realpath',
+ return_value='/non/existent') as m:
+ with self.assertRaises(FileNotFoundError):
+ psutil.disk_partitions()
+ assert m.called
+ finally:
+ psutil.PROCFS_PATH = "/proc"
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSystemDiskIoCounters(PsutilTestCase):
+
+ def test_emulate_kernel_2_4(self):
+ # Tests /proc/diskstats parsing format for 2.4 kernels, see:
+ # https://github.com/giampaolo/psutil/issues/767
+ with mock_open_content(
+ '/proc/diskstats',
+ " 3 0 1 hda 2 3 4 5 6 7 8 9 10 11 12"):
+ with mock.patch('psutil._pslinux.is_storage_device',
+ return_value=True):
+ ret = psutil.disk_io_counters(nowrap=False)
+ self.assertEqual(ret.read_count, 1)
+ self.assertEqual(ret.read_merged_count, 2)
+ self.assertEqual(ret.read_bytes, 3 * SECTOR_SIZE)
+ self.assertEqual(ret.read_time, 4)
+ self.assertEqual(ret.write_count, 5)
+ self.assertEqual(ret.write_merged_count, 6)
+ self.assertEqual(ret.write_bytes, 7 * SECTOR_SIZE)
+ self.assertEqual(ret.write_time, 8)
+ self.assertEqual(ret.busy_time, 10)
+
+ def test_emulate_kernel_2_6_full(self):
+ # Tests /proc/diskstats parsing format for 2.6 kernels,
+ # lines reporting all metrics:
+ # https://github.com/giampaolo/psutil/issues/767
+ with mock_open_content(
+ '/proc/diskstats',
+ " 3 0 hda 1 2 3 4 5 6 7 8 9 10 11"):
+ with mock.patch('psutil._pslinux.is_storage_device',
+ return_value=True):
+ ret = psutil.disk_io_counters(nowrap=False)
+ self.assertEqual(ret.read_count, 1)
+ self.assertEqual(ret.read_merged_count, 2)
+ self.assertEqual(ret.read_bytes, 3 * SECTOR_SIZE)
+ self.assertEqual(ret.read_time, 4)
+ self.assertEqual(ret.write_count, 5)
+ self.assertEqual(ret.write_merged_count, 6)
+ self.assertEqual(ret.write_bytes, 7 * SECTOR_SIZE)
+ self.assertEqual(ret.write_time, 8)
+ self.assertEqual(ret.busy_time, 10)
+
+ def test_emulate_kernel_2_6_limited(self):
+ # Tests /proc/diskstats parsing format for 2.6 kernels,
+        # where one line of /proc/partitions returns a limited
+ # amount of metrics when it bumps into a partition
+ # (instead of a disk). See:
+ # https://github.com/giampaolo/psutil/issues/767
+ with mock_open_content(
+ '/proc/diskstats',
+ " 3 1 hda 1 2 3 4"):
+ with mock.patch('psutil._pslinux.is_storage_device',
+ return_value=True):
+ ret = psutil.disk_io_counters(nowrap=False)
+ self.assertEqual(ret.read_count, 1)
+ self.assertEqual(ret.read_bytes, 2 * SECTOR_SIZE)
+ self.assertEqual(ret.write_count, 3)
+ self.assertEqual(ret.write_bytes, 4 * SECTOR_SIZE)
+
+ self.assertEqual(ret.read_merged_count, 0)
+ self.assertEqual(ret.read_time, 0)
+ self.assertEqual(ret.write_merged_count, 0)
+ self.assertEqual(ret.write_time, 0)
+ self.assertEqual(ret.busy_time, 0)
+
+ def test_emulate_include_partitions(self):
+ # Make sure that when perdisk=True disk partitions are returned,
+ # see:
+ # https://github.com/giampaolo/psutil/pull/1313#issuecomment-408626842
+ with mock_open_content(
+ '/proc/diskstats',
+ textwrap.dedent("""\
+ 3 0 nvme0n1 1 2 3 4 5 6 7 8 9 10 11
+ 3 0 nvme0n1p1 1 2 3 4 5 6 7 8 9 10 11
+ """)):
+ with mock.patch('psutil._pslinux.is_storage_device',
+ return_value=False):
+ ret = psutil.disk_io_counters(perdisk=True, nowrap=False)
+ self.assertEqual(len(ret), 2)
+ self.assertEqual(ret['nvme0n1'].read_count, 1)
+ self.assertEqual(ret['nvme0n1p1'].read_count, 1)
+ self.assertEqual(ret['nvme0n1'].write_count, 5)
+ self.assertEqual(ret['nvme0n1p1'].write_count, 5)
+
+ def test_emulate_exclude_partitions(self):
+ # Make sure that when perdisk=False partitions (e.g. 'sda1',
+ # 'nvme0n1p1') are skipped and not included in the total count.
+ # https://github.com/giampaolo/psutil/pull/1313#issuecomment-408626842
+ with mock_open_content(
+ '/proc/diskstats',
+ textwrap.dedent("""\
+ 3 0 nvme0n1 1 2 3 4 5 6 7 8 9 10 11
+ 3 0 nvme0n1p1 1 2 3 4 5 6 7 8 9 10 11
+ """)):
+ with mock.patch('psutil._pslinux.is_storage_device',
+ return_value=False):
+ ret = psutil.disk_io_counters(perdisk=False, nowrap=False)
+ self.assertIsNone(ret)
+
+        # Now pretend only 'nvme0n1' is a storage device; the totals
+        # should then be computed from that disk alone.
+ def is_storage_device(name):
+ return name == 'nvme0n1'
+
+ with mock_open_content(
+ '/proc/diskstats',
+ textwrap.dedent("""\
+ 3 0 nvme0n1 1 2 3 4 5 6 7 8 9 10 11
+ 3 0 nvme0n1p1 1 2 3 4 5 6 7 8 9 10 11
+ """)):
+ with mock.patch('psutil._pslinux.is_storage_device',
+ create=True, side_effect=is_storage_device):
+ ret = psutil.disk_io_counters(perdisk=False, nowrap=False)
+ self.assertEqual(ret.read_count, 1)
+ self.assertEqual(ret.write_count, 5)
+
+ def test_emulate_use_sysfs(self):
+ def exists(path):
+ if path == '/proc/diskstats':
+ return False
+ return True
+
+ wprocfs = psutil.disk_io_counters(perdisk=True)
+ with mock.patch('psutil._pslinux.os.path.exists',
+ create=True, side_effect=exists):
+ wsysfs = psutil.disk_io_counters(perdisk=True)
+ self.assertEqual(len(wprocfs), len(wsysfs))
+
+ def test_emulate_not_impl(self):
+ def exists(path):
+ return False
+
+ with mock.patch('psutil._pslinux.os.path.exists',
+ create=True, side_effect=exists):
+ self.assertRaises(NotImplementedError, psutil.disk_io_counters)
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestRootFsDeviceFinder(PsutilTestCase):
+
+ def setUp(self):
+ dev = os.stat("/").st_dev
+ self.major = os.major(dev)
+ self.minor = os.minor(dev)
+
+ def test_call_methods(self):
+ finder = RootFsDeviceFinder()
+ if os.path.exists("/proc/partitions"):
+ finder.ask_proc_partitions()
+ else:
+ self.assertRaises(FileNotFoundError, finder.ask_proc_partitions)
+ if os.path.exists("/sys/dev/block/%s:%s/uevent" % (
+ self.major, self.minor)):
+ finder.ask_sys_dev_block()
+ else:
+ self.assertRaises(FileNotFoundError, finder.ask_sys_dev_block)
+ finder.ask_sys_class_block()
+
+ @unittest.skipIf(GITHUB_ACTIONS, "unsupported on GITHUB_ACTIONS")
+ def test_comparisons(self):
+ finder = RootFsDeviceFinder()
+ self.assertIsNotNone(finder.find())
+
+ a = b = c = None
+ if os.path.exists("/proc/partitions"):
+ a = finder.ask_proc_partitions()
+ if os.path.exists("/sys/dev/block/%s:%s/uevent" % (
+ self.major, self.minor)):
+ b = finder.ask_sys_class_block()
+ c = finder.ask_sys_dev_block()
+
+ base = a or b or c
+ if base and a:
+ self.assertEqual(base, a)
+ if base and b:
+ self.assertEqual(base, b)
+ if base and c:
+ self.assertEqual(base, c)
+
+ @unittest.skipIf(not which("findmnt"), "findmnt utility not available")
+ @unittest.skipIf(GITHUB_ACTIONS, "unsupported on GITHUB_ACTIONS")
+ def test_against_findmnt(self):
+ psutil_value = RootFsDeviceFinder().find()
+ findmnt_value = sh("findmnt -o SOURCE -rn /")
+ self.assertEqual(psutil_value, findmnt_value)
+
+ def test_disk_partitions_mocked(self):
+ with mock.patch(
+ 'psutil._pslinux.cext.disk_partitions',
+ return_value=[('/dev/root', '/', 'ext4', 'rw')]) as m:
+ part = psutil.disk_partitions()[0]
+ assert m.called
+ if not GITHUB_ACTIONS:
+ self.assertNotEqual(part.device, "/dev/root")
+ self.assertEqual(part.device, RootFsDeviceFinder().find())
+ else:
+ self.assertEqual(part.device, "/dev/root")
+
+
+# =====================================================================
+# --- misc
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestMisc(PsutilTestCase):
+
+ def test_boot_time(self):
+ vmstat_value = vmstat('boot time')
+ psutil_value = psutil.boot_time()
+ self.assertEqual(int(vmstat_value), int(psutil_value))
+
+ def test_no_procfs_on_import(self):
+ my_procfs = self.get_testfn()
+ os.mkdir(my_procfs)
+
+ with open(os.path.join(my_procfs, 'stat'), 'w') as f:
+ f.write('cpu 0 0 0 0 0 0 0 0 0 0\n')
+ f.write('cpu0 0 0 0 0 0 0 0 0 0 0\n')
+ f.write('cpu1 0 0 0 0 0 0 0 0 0 0\n')
+
+ try:
+ orig_open = open
+
+ def open_mock(name, *args, **kwargs):
+ if name.startswith('/proc'):
+ raise IOError(errno.ENOENT, 'rejecting access for test')
+ return orig_open(name, *args, **kwargs)
+
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock):
+ reload_module(psutil)
+
+ self.assertRaises(IOError, psutil.cpu_times)
+ self.assertRaises(IOError, psutil.cpu_times, percpu=True)
+ self.assertRaises(IOError, psutil.cpu_percent)
+ self.assertRaises(IOError, psutil.cpu_percent, percpu=True)
+ self.assertRaises(IOError, psutil.cpu_times_percent)
+ self.assertRaises(
+ IOError, psutil.cpu_times_percent, percpu=True)
+
+ psutil.PROCFS_PATH = my_procfs
+
+ self.assertEqual(psutil.cpu_percent(), 0)
+ self.assertEqual(sum(psutil.cpu_times_percent()), 0)
+
+ # since we don't know the number of CPUs at import time,
+ # we awkwardly say there are none until the second call
+ per_cpu_percent = psutil.cpu_percent(percpu=True)
+ self.assertEqual(sum(per_cpu_percent), 0)
+
+ # ditto awkward length
+ per_cpu_times_percent = psutil.cpu_times_percent(percpu=True)
+ self.assertEqual(sum(map(sum, per_cpu_times_percent)), 0)
+
+ # much user, very busy
+ with open(os.path.join(my_procfs, 'stat'), 'w') as f:
+ f.write('cpu 1 0 0 0 0 0 0 0 0 0\n')
+ f.write('cpu0 1 0 0 0 0 0 0 0 0 0\n')
+ f.write('cpu1 1 0 0 0 0 0 0 0 0 0\n')
+
+ self.assertNotEqual(psutil.cpu_percent(), 0)
+ self.assertNotEqual(
+ sum(psutil.cpu_percent(percpu=True)), 0)
+ self.assertNotEqual(sum(psutil.cpu_times_percent()), 0)
+ self.assertNotEqual(
+ sum(map(sum, psutil.cpu_times_percent(percpu=True))), 0)
+ finally:
+ shutil.rmtree(my_procfs)
+ reload_module(psutil)
+
+ self.assertEqual(psutil.PROCFS_PATH, '/proc')
+
+ def test_cpu_steal_decrease(self):
+ # Test cumulative cpu stats decrease. We should ignore this.
+ # See issue #1210.
+ with mock_open_content(
+ "/proc/stat",
+ textwrap.dedent("""\
+ cpu 0 0 0 0 0 0 0 1 0 0
+ cpu0 0 0 0 0 0 0 0 1 0 0
+ cpu1 0 0 0 0 0 0 0 1 0 0
+ """).encode()) as m:
+ # first call to "percent" functions should read the new stat file
+ # and compare to the "real" file read at import time - so the
+ # values are meaningless
+ psutil.cpu_percent()
+ assert m.called
+ psutil.cpu_percent(percpu=True)
+ psutil.cpu_times_percent()
+ psutil.cpu_times_percent(percpu=True)
+
+ with mock_open_content(
+ "/proc/stat",
+ textwrap.dedent("""\
+ cpu 1 0 0 0 0 0 0 0 0 0
+ cpu0 1 0 0 0 0 0 0 0 0 0
+ cpu1 1 0 0 0 0 0 0 0 0 0
+ """).encode()) as m:
+ # Increase "user" while steal goes "backwards" to zero.
+ cpu_percent = psutil.cpu_percent()
+ assert m.called
+ cpu_percent_percpu = psutil.cpu_percent(percpu=True)
+ cpu_times_percent = psutil.cpu_times_percent()
+ cpu_times_percent_percpu = psutil.cpu_times_percent(percpu=True)
+ self.assertNotEqual(cpu_percent, 0)
+ self.assertNotEqual(sum(cpu_percent_percpu), 0)
+ self.assertNotEqual(sum(cpu_times_percent), 0)
+ self.assertNotEqual(sum(cpu_times_percent), 100.0)
+ self.assertNotEqual(sum(map(sum, cpu_times_percent_percpu)), 0)
+ self.assertNotEqual(sum(map(sum, cpu_times_percent_percpu)), 100.0)
+ self.assertEqual(cpu_times_percent.steal, 0)
+ self.assertNotEqual(cpu_times_percent.user, 0)
+
+ def test_boot_time_mocked(self):
+ with mock.patch('psutil._common.open', create=True) as m:
+ self.assertRaises(
+ RuntimeError,
+ psutil._pslinux.boot_time)
+ assert m.called
+
+ def test_users_mocked(self):
+ # Make sure ':0' and ':0.0' (returned by C ext) are converted
+ # to 'localhost'.
+ with mock.patch('psutil._pslinux.cext.users',
+ return_value=[('giampaolo', 'pts/2', ':0',
+ 1436573184.0, True, 2)]) as m:
+ self.assertEqual(psutil.users()[0].host, 'localhost')
+ assert m.called
+ with mock.patch('psutil._pslinux.cext.users',
+ return_value=[('giampaolo', 'pts/2', ':0.0',
+ 1436573184.0, True, 2)]) as m:
+ self.assertEqual(psutil.users()[0].host, 'localhost')
+ assert m.called
+ # ...otherwise it should be returned as-is
+ with mock.patch('psutil._pslinux.cext.users',
+ return_value=[('giampaolo', 'pts/2', 'foo',
+ 1436573184.0, True, 2)]) as m:
+ self.assertEqual(psutil.users()[0].host, 'foo')
+ assert m.called
+
+ def test_procfs_path(self):
+ tdir = self.get_testfn()
+ os.mkdir(tdir)
+ try:
+ psutil.PROCFS_PATH = tdir
+ self.assertRaises(IOError, psutil.virtual_memory)
+ self.assertRaises(IOError, psutil.cpu_times)
+ self.assertRaises(IOError, psutil.cpu_times, percpu=True)
+ self.assertRaises(IOError, psutil.boot_time)
+ # self.assertRaises(IOError, psutil.pids)
+ self.assertRaises(IOError, psutil.net_connections)
+ self.assertRaises(IOError, psutil.net_io_counters)
+ self.assertRaises(IOError, psutil.net_if_stats)
+ # self.assertRaises(IOError, psutil.disk_io_counters)
+ self.assertRaises(IOError, psutil.disk_partitions)
+ self.assertRaises(psutil.NoSuchProcess, psutil.Process)
+ finally:
+ psutil.PROCFS_PATH = "/proc"
+
+ @retry_on_failure()
+ def test_issue_687(self):
+ # In case of thread ID:
+ # - pid_exists() is supposed to return False
+ # - Process(tid) is supposed to work
+ # - pids() should not return the TID
+ # See: https://github.com/giampaolo/psutil/issues/687
+ with ThreadTask():
+ p = psutil.Process()
+ threads = p.threads()
+ self.assertEqual(len(threads), 2)
+ tid = sorted(threads, key=lambda x: x.id)[1].id
+ self.assertNotEqual(p.pid, tid)
+ pt = psutil.Process(tid)
+ pt.as_dict()
+ self.assertNotIn(tid, psutil.pids())
+
+ def test_pid_exists_no_proc_status(self):
+ # Internally pid_exists relies on /proc/{pid}/status.
+ # Emulate a case where this file is empty in which case
+ # psutil is supposed to fall back on using pids().
+        with mock_open_content("/proc/%s/status" % os.getpid(), "") as m:
+ assert psutil.pid_exists(os.getpid())
+ assert m.called
+
+
+# =====================================================================
+# --- sensors
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+@unittest.skipIf(not HAS_BATTERY, "no battery")
+class TestSensorsBattery(PsutilTestCase):
+
+ @unittest.skipIf(not which("acpi"), "acpi utility not available")
+ def test_percent(self):
+ out = sh("acpi -b")
+ acpi_value = int(out.split(",")[1].strip().replace('%', ''))
+ psutil_value = psutil.sensors_battery().percent
+ self.assertAlmostEqual(acpi_value, psutil_value, delta=1)
+
+ def test_emulate_power_plugged(self):
+ # Pretend the AC power cable is connected.
+ def open_mock(name, *args, **kwargs):
+ if name.endswith("AC0/online") or name.endswith("AC/online"):
+ return io.BytesIO(b"1")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock) as m:
+ self.assertEqual(psutil.sensors_battery().power_plugged, True)
+ self.assertEqual(
+ psutil.sensors_battery().secsleft, psutil.POWER_TIME_UNLIMITED)
+ assert m.called
+
+ def test_emulate_power_plugged_2(self):
+ # Same as above but pretend /AC0/online does not exist in which
+ # case code relies on /status file.
+ def open_mock(name, *args, **kwargs):
+ if name.endswith("AC0/online") or name.endswith("AC/online"):
+ raise IOError(errno.ENOENT, "")
+ elif name.endswith("/status"):
+ return io.StringIO(u("charging"))
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock) as m:
+ self.assertEqual(psutil.sensors_battery().power_plugged, True)
+ assert m.called
+
+ def test_emulate_power_not_plugged(self):
+ # Pretend the AC power cable is not connected.
+ def open_mock(name, *args, **kwargs):
+ if name.endswith("AC0/online") or name.endswith("AC/online"):
+ return io.BytesIO(b"0")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock) as m:
+ self.assertEqual(psutil.sensors_battery().power_plugged, False)
+ assert m.called
+
+ def test_emulate_power_not_plugged_2(self):
+ # Same as above but pretend /AC0/online does not exist in which
+ # case code relies on /status file.
+ def open_mock(name, *args, **kwargs):
+ if name.endswith("AC0/online") or name.endswith("AC/online"):
+ raise IOError(errno.ENOENT, "")
+ elif name.endswith("/status"):
+ return io.StringIO(u("discharging"))
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock) as m:
+ self.assertEqual(psutil.sensors_battery().power_plugged, False)
+ assert m.called
+
+ def test_emulate_power_undetermined(self):
+        # Pretend we can't know whether the AC power cable is
+        # connected (assert fallback to None).
+ def open_mock(name, *args, **kwargs):
+ if name.startswith("/sys/class/power_supply/AC0/online") or \
+ name.startswith("/sys/class/power_supply/AC/online"):
+ raise IOError(errno.ENOENT, "")
+ elif name.startswith("/sys/class/power_supply/BAT0/status"):
+ return io.BytesIO(b"???")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock) as m:
+ self.assertIsNone(psutil.sensors_battery().power_plugged)
+ assert m.called
+
+ def test_emulate_energy_full_0(self):
+        # Emulate a case where the energy_full file returns 0.
+ with mock_open_content(
+ "/sys/class/power_supply/BAT0/energy_full", b"0") as m:
+ self.assertEqual(psutil.sensors_battery().percent, 0)
+ assert m.called
+
+ def test_emulate_energy_full_not_avail(self):
+        # Emulate a case where the energy_full file does not exist;
+        # the code is expected to fall back on /capacity.
+ with mock_open_exception(
+ "/sys/class/power_supply/BAT0/energy_full",
+ IOError(errno.ENOENT, "")):
+ with mock_open_exception(
+ "/sys/class/power_supply/BAT0/charge_full",
+ IOError(errno.ENOENT, "")):
+ with mock_open_content(
+ "/sys/class/power_supply/BAT0/capacity", b"88"):
+ self.assertEqual(psutil.sensors_battery().percent, 88)
+
+ def test_emulate_no_power(self):
+        # Emulate a case where neither /AC0/online nor /BAT0/status exists.
+ with mock_open_exception(
+ "/sys/class/power_supply/AC/online",
+ IOError(errno.ENOENT, "")):
+ with mock_open_exception(
+ "/sys/class/power_supply/AC0/online",
+ IOError(errno.ENOENT, "")):
+ with mock_open_exception(
+ "/sys/class/power_supply/BAT0/status",
+ IOError(errno.ENOENT, "")):
+ self.assertIsNone(psutil.sensors_battery().power_plugged)
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSensorsBatteryEmulated(PsutilTestCase):
+
+ def test_it(self):
+ def open_mock(name, *args, **kwargs):
+ if name.endswith("/energy_now"):
+ return io.StringIO(u("60000000"))
+ elif name.endswith("/power_now"):
+ return io.StringIO(u("0"))
+ elif name.endswith("/energy_full"):
+ return io.StringIO(u("60000001"))
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch('os.listdir', return_value=["BAT0"]) as mlistdir:
+ with mock.patch(patch_point, side_effect=open_mock) as mopen:
+ self.assertIsNotNone(psutil.sensors_battery())
+ assert mlistdir.called
+ assert mopen.called
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSensorsTemperatures(PsutilTestCase):
+
+ def test_emulate_class_hwmon(self):
+ def open_mock(name, *args, **kwargs):
+ if name.endswith('/name'):
+ return io.StringIO(u("name"))
+ elif name.endswith('/temp1_label'):
+ return io.StringIO(u("label"))
+ elif name.endswith('/temp1_input'):
+ return io.BytesIO(b"30000")
+ elif name.endswith('/temp1_max'):
+ return io.BytesIO(b"40000")
+ elif name.endswith('/temp1_crit'):
+ return io.BytesIO(b"50000")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock):
+ # Test case with /sys/class/hwmon
+ with mock.patch('glob.glob',
+ return_value=['/sys/class/hwmon/hwmon0/temp1']):
+ temp = psutil.sensors_temperatures()['name'][0]
+ self.assertEqual(temp.label, 'label')
+ self.assertEqual(temp.current, 30.0)
+ self.assertEqual(temp.high, 40.0)
+ self.assertEqual(temp.critical, 50.0)
+
+ def test_emulate_class_thermal(self):
+ def open_mock(name, *args, **kwargs):
+ if name.endswith('0_temp'):
+ return io.BytesIO(b"50000")
+ elif name.endswith('temp'):
+ return io.BytesIO(b"30000")
+ elif name.endswith('0_type'):
+ return io.StringIO(u("critical"))
+ elif name.endswith('type'):
+ return io.StringIO(u("name"))
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ def glob_mock(path):
+ if path == '/sys/class/hwmon/hwmon*/temp*_*':
+ return []
+ elif path == '/sys/class/hwmon/hwmon*/device/temp*_*':
+ return []
+ elif path == '/sys/class/thermal/thermal_zone*':
+ return ['/sys/class/thermal/thermal_zone0']
+ elif path == '/sys/class/thermal/thermal_zone0/trip_point*':
+ return ['/sys/class/thermal/thermal_zone1/trip_point_0_type',
+ '/sys/class/thermal/thermal_zone1/trip_point_0_temp']
+ return []
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock):
+ with mock.patch('glob.glob', create=True, side_effect=glob_mock):
+ temp = psutil.sensors_temperatures()['name'][0]
+ self.assertEqual(temp.label, '')
+ self.assertEqual(temp.current, 30.0)
+ self.assertEqual(temp.high, 50.0)
+ self.assertEqual(temp.critical, 50.0)
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestSensorsFans(PsutilTestCase):
+
+ def test_emulate_data(self):
+ def open_mock(name, *args, **kwargs):
+ if name.endswith('/name'):
+ return io.StringIO(u("name"))
+ elif name.endswith('/fan1_label'):
+ return io.StringIO(u("label"))
+ elif name.endswith('/fan1_input'):
+ return io.StringIO(u("2000"))
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock):
+ with mock.patch('glob.glob',
+ return_value=['/sys/class/hwmon/hwmon2/fan1']):
+ fan = psutil.sensors_fans()['name'][0]
+ self.assertEqual(fan.label, 'label')
+ self.assertEqual(fan.current, 2000)
+
+
+# =====================================================================
+# --- test process
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestProcess(PsutilTestCase):
+
+ @retry_on_failure()
+ def test_parse_smaps_vs_memory_maps(self):
+ sproc = self.spawn_testproc()
+ uss, pss, swap = psutil._pslinux.Process(sproc.pid)._parse_smaps()
+ maps = psutil.Process(sproc.pid).memory_maps(grouped=False)
+ self.assertAlmostEqual(
+ uss, sum([x.private_dirty + x.private_clean for x in maps]),
+ delta=4096)
+ self.assertAlmostEqual(
+ pss, sum([x.pss for x in maps]), delta=4096)
+ self.assertAlmostEqual(
+ swap, sum([x.swap for x in maps]), delta=4096)
+
+ def test_parse_smaps_mocked(self):
+ # See: https://github.com/giampaolo/psutil/issues/1222
+ with mock_open_content(
+ "/proc/%s/smaps" % os.getpid(),
+ textwrap.dedent("""\
+ fffff0 r-xp 00000000 00:00 0 [vsyscall]
+ Size: 1 kB
+ Rss: 2 kB
+ Pss: 3 kB
+ Shared_Clean: 4 kB
+ Shared_Dirty: 5 kB
+ Private_Clean: 6 kB
+ Private_Dirty: 7 kB
+ Referenced: 8 kB
+ Anonymous: 9 kB
+ LazyFree: 10 kB
+ AnonHugePages: 11 kB
+ ShmemPmdMapped: 12 kB
+ Shared_Hugetlb: 13 kB
+ Private_Hugetlb: 14 kB
+ Swap: 15 kB
+ SwapPss: 16 kB
+ KernelPageSize: 17 kB
+ MMUPageSize: 18 kB
+ Locked: 19 kB
+ VmFlags: rd ex
+ """).encode()) as m:
+ p = psutil._pslinux.Process(os.getpid())
+ uss, pss, swap = p._parse_smaps()
+ assert m.called
+ self.assertEqual(uss, (6 + 7 + 14) * 1024)
+ self.assertEqual(pss, 3 * 1024)
+ self.assertEqual(swap, 15 * 1024)
+
+ # On PYPY file descriptors are not closed fast enough.
+ @unittest.skipIf(PYPY, "unreliable on PYPY")
+ def test_open_files_mode(self):
+ def get_test_file(fname):
+ p = psutil.Process()
+ giveup_at = time.time() + GLOBAL_TIMEOUT
+ while True:
+ for file in p.open_files():
+ if file.path == os.path.abspath(fname):
+ return file
+ elif time.time() > giveup_at:
+ break
+ raise RuntimeError("timeout looking for test file")
+
+ #
+ testfn = self.get_testfn()
+ with open(testfn, "w"):
+ self.assertEqual(get_test_file(testfn).mode, "w")
+ with open(testfn, "r"):
+ self.assertEqual(get_test_file(testfn).mode, "r")
+ with open(testfn, "a"):
+ self.assertEqual(get_test_file(testfn).mode, "a")
+ #
+ with open(testfn, "r+"):
+ self.assertEqual(get_test_file(testfn).mode, "r+")
+ with open(testfn, "w+"):
+ self.assertEqual(get_test_file(testfn).mode, "r+")
+ with open(testfn, "a+"):
+ self.assertEqual(get_test_file(testfn).mode, "a+")
+ # note: "x" bit is not supported
+ if PY3:
+ safe_rmpath(testfn)
+ with open(testfn, "x"):
+ self.assertEqual(get_test_file(testfn).mode, "w")
+ safe_rmpath(testfn)
+ with open(testfn, "x+"):
+ self.assertEqual(get_test_file(testfn).mode, "r+")
+
+ def test_open_files_file_gone(self):
+ # simulates a file which gets deleted during open_files()
+ # execution
+ p = psutil.Process()
+ files = p.open_files()
+ with open(self.get_testfn(), 'w'):
+ # give the kernel some time to see the new file
+ call_until(p.open_files, "len(ret) != %i" % len(files))
+ with mock.patch('psutil._pslinux.os.readlink',
+ side_effect=OSError(errno.ENOENT, "")) as m:
+ files = p.open_files()
+ assert not files
+ assert m.called
+ # also simulate the case where os.readlink() returns EINVAL
+ # in which case psutil is supposed to 'continue'
+ with mock.patch('psutil._pslinux.os.readlink',
+ side_effect=OSError(errno.EINVAL, "")) as m:
+ self.assertEqual(p.open_files(), [])
+ assert m.called
+
+ def test_open_files_fd_gone(self):
+ # Simulate a case where /proc/{pid}/fdinfo/{fd} disappears
+ # while iterating through fds.
+ # https://travis-ci.org/giampaolo/psutil/jobs/225694530
+ p = psutil.Process()
+ files = p.open_files()
+ with open(self.get_testfn(), 'w'):
+ # give the kernel some time to see the new file
+ call_until(p.open_files, "len(ret) != %i" % len(files))
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point,
+ side_effect=IOError(errno.ENOENT, "")) as m:
+ files = p.open_files()
+ assert not files
+ assert m.called
+
+ def test_open_files_enametoolong(self):
+ # Simulate a case where /proc/{pid}/fd/{fd} symlink
+ # points to a file with full path longer than PATH_MAX, see:
+ # https://github.com/giampaolo/psutil/issues/1940
+ p = psutil.Process()
+ files = p.open_files()
+ with open(self.get_testfn(), 'w'):
+ # give the kernel some time to see the new file
+ call_until(p.open_files, "len(ret) != %i" % len(files))
+ patch_point = 'psutil._pslinux.os.readlink'
+ with mock.patch(patch_point,
+ side_effect=OSError(errno.ENAMETOOLONG, "")) as m:
+ with mock.patch("psutil._pslinux.debug"):
+ files = p.open_files()
+ assert not files
+ assert m.called
+
+ # --- mocked tests
+
+ def test_terminal_mocked(self):
+ with mock.patch('psutil._pslinux._psposix.get_terminal_map',
+ return_value={}) as m:
+ self.assertIsNone(psutil._pslinux.Process(os.getpid()).terminal())
+ assert m.called
+
+ # TODO: re-enable this test.
+ # def test_num_ctx_switches_mocked(self):
+ # with mock.patch('psutil._common.open', create=True) as m:
+ # self.assertRaises(
+ # NotImplementedError,
+ # psutil._pslinux.Process(os.getpid()).num_ctx_switches)
+ # assert m.called
+
+ def test_cmdline_mocked(self):
+ # see: https://github.com/giampaolo/psutil/issues/639
+ p = psutil.Process()
+ fake_file = io.StringIO(u('foo\x00bar\x00'))
+ with mock.patch('psutil._common.open',
+ return_value=fake_file, create=True) as m:
+ self.assertEqual(p.cmdline(), ['foo', 'bar'])
+ assert m.called
+ fake_file = io.StringIO(u('foo\x00bar\x00\x00'))
+ with mock.patch('psutil._common.open',
+ return_value=fake_file, create=True) as m:
+ self.assertEqual(p.cmdline(), ['foo', 'bar', ''])
+ assert m.called
+
+ def test_cmdline_spaces_mocked(self):
+ # see: https://github.com/giampaolo/psutil/issues/1179
+ p = psutil.Process()
+ fake_file = io.StringIO(u('foo bar '))
+ with mock.patch('psutil._common.open',
+ return_value=fake_file, create=True) as m:
+ self.assertEqual(p.cmdline(), ['foo', 'bar'])
+ assert m.called
+ fake_file = io.StringIO(u('foo bar '))
+ with mock.patch('psutil._common.open',
+ return_value=fake_file, create=True) as m:
+ self.assertEqual(p.cmdline(), ['foo', 'bar', ''])
+ assert m.called
+
+ def test_cmdline_mixed_separators(self):
+ # https://github.com/giampaolo/psutil/issues/
+ # 1179#issuecomment-552984549
+ p = psutil.Process()
+ fake_file = io.StringIO(u('foo\x20bar\x00'))
+ with mock.patch('psutil._common.open',
+ return_value=fake_file, create=True) as m:
+ self.assertEqual(p.cmdline(), ['foo', 'bar'])
+ assert m.called
+
+ def test_readlink_path_deleted_mocked(self):
+ with mock.patch('psutil._pslinux.os.readlink',
+ return_value='/home/foo (deleted)'):
+ self.assertEqual(psutil.Process().exe(), "/home/foo")
+ self.assertEqual(psutil.Process().cwd(), "/home/foo")
+
+ def test_threads_mocked(self):
+ # Test the case where os.listdir() returns a file (thread)
+ # which no longer exists by the time we open() it (race
+ # condition). threads() is supposed to ignore that instead
+ # of raising NSP.
+ def open_mock(name, *args, **kwargs):
+ if name.startswith('/proc/%s/task' % os.getpid()):
+ raise IOError(errno.ENOENT, "")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ orig_open = open
+ patch_point = 'builtins.open' if PY3 else '__builtin__.open'
+ with mock.patch(patch_point, side_effect=open_mock) as m:
+ ret = psutil.Process().threads()
+ assert m.called
+ self.assertEqual(ret, [])
+
+ # ...but if it bumps into something != ENOENT we want an
+ # exception.
+ def open_mock(name, *args, **kwargs):
+ if name.startswith('/proc/%s/task' % os.getpid()):
+ raise IOError(errno.EPERM, "")
+ else:
+ return orig_open(name, *args, **kwargs)
+
+ with mock.patch(patch_point, side_effect=open_mock):
+ self.assertRaises(psutil.AccessDenied, psutil.Process().threads)
+
+ def test_exe_mocked(self):
+ with mock.patch('psutil._pslinux.readlink',
+ side_effect=OSError(errno.ENOENT, "")) as m1:
+ with mock.patch('psutil.Process.cmdline',
+ side_effect=psutil.AccessDenied(0, "")) as m2:
+ # No such file error; might be raised also if /proc/pid/exe
+ # path actually exists for system processes with low pids
+ # (about 0-20). In this case psutil is supposed to return
+ # an empty string.
+ ret = psutil.Process().exe()
+ assert m1.called
+ assert m2.called
+ self.assertEqual(ret, "")
+
+ # ...but if /proc/pid no longer exists we're supposed to treat
+ # it as an alias for a zombie process
+ with mock.patch('psutil._pslinux.os.path.lexists',
+ return_value=False):
+ self.assertRaises(
+ psutil.ZombieProcess, psutil.Process().exe)
+
+ def test_issue_1014(self):
+ # Emulates a case where the smaps file does not exist. In this case
+ # the wrap_exception decorator should not raise NoSuchProcess.
+ with mock_open_exception(
+ '/proc/%s/smaps' % os.getpid(),
+ IOError(errno.ENOENT, "")) as m:
+ p = psutil.Process()
+ with self.assertRaises(FileNotFoundError):
+ p.memory_maps()
+ assert m.called
+
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit_zombie(self):
+ # Emulate a case where rlimit() raises ENOSYS, which may
+ # happen in case of a zombie process:
+ # https://travis-ci.org/giampaolo/psutil/jobs/51368273
+ with mock.patch("psutil._pslinux.prlimit",
+ side_effect=OSError(errno.ENOSYS, "")) as m:
+ p = psutil.Process()
+ p.name()
+ with self.assertRaises(psutil.ZombieProcess) as exc:
+ p.rlimit(psutil.RLIMIT_NOFILE)
+ assert m.called
+ self.assertEqual(exc.exception.pid, p.pid)
+ self.assertEqual(exc.exception.name, p.name())
+
+ def test_cwd_zombie(self):
+ with mock.patch("psutil._pslinux.os.readlink",
+ side_effect=OSError(errno.ENOENT, "")) as m:
+ p = psutil.Process()
+ p.name()
+ with self.assertRaises(psutil.ZombieProcess) as exc:
+ p.cwd()
+ assert m.called
+ self.assertEqual(exc.exception.pid, p.pid)
+ self.assertEqual(exc.exception.name, p.name())
+
+ def test_stat_file_parsing(self):
+ args = [
+ "0", # pid
+ "(cat)", # name
+ "Z", # status
+ "1", # ppid
+ "0", # pgrp
+ "0", # session
+ "0", # tty
+ "0", # tpgid
+ "0", # flags
+ "0", # minflt
+ "0", # cminflt
+ "0", # majflt
+ "0", # cmajflt
+ "2", # utime
+ "3", # stime
+ "4", # cutime
+ "5", # cstime
+ "0", # priority
+ "0", # nice
+ "0", # num_threads
+ "0", # itrealvalue
+ "6", # starttime
+ "0", # vsize
+ "0", # rss
+ "0", # rsslim
+ "0", # startcode
+ "0", # endcode
+ "0", # startstack
+ "0", # kstkesp
+ "0", # kstkeip
+ "0", # signal
+ "0", # blocked
+ "0", # sigignore
+ "0", # sigcatch
+ "0", # wchan
+ "0", # nswap
+ "0", # cnswap
+ "0", # exit_signal
+ "6", # processor
+ "0", # rt priority
+ "0", # policy
+ "7", # delayacct_blkio_ticks
+ ]
+ content = " ".join(args).encode()
+ with mock_open_content('/proc/%s/stat' % os.getpid(), content):
+ p = psutil.Process()
+ self.assertEqual(p.name(), 'cat')
+ self.assertEqual(p.status(), psutil.STATUS_ZOMBIE)
+ self.assertEqual(p.ppid(), 1)
+ self.assertEqual(
+ p.create_time(), 6 / CLOCK_TICKS + psutil.boot_time())
+ cpu = p.cpu_times()
+ self.assertEqual(cpu.user, 2 / CLOCK_TICKS)
+ self.assertEqual(cpu.system, 3 / CLOCK_TICKS)
+ self.assertEqual(cpu.children_user, 4 / CLOCK_TICKS)
+ self.assertEqual(cpu.children_system, 5 / CLOCK_TICKS)
+ self.assertEqual(cpu.iowait, 7 / CLOCK_TICKS)
+ self.assertEqual(p.cpu_num(), 6)
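+ # Worked example of the arithmetic above, assuming the typical Linux
+ # value CLOCK_TICKS == 100: utime 2 maps to 0.02s of user CPU time,
+ # stime 3 to 0.03s, and create_time to psutil.boot_time() + 0.06s.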
+
+ def test_status_file_parsing(self):
+ with mock_open_content(
+ '/proc/%s/status' % os.getpid(),
+ textwrap.dedent("""\
+ Uid:\t1000\t1001\t1002\t1003
+ Gid:\t1004\t1005\t1006\t1007
+ Threads:\t66
+ Cpus_allowed:\tf
+ Cpus_allowed_list:\t0-7
+ voluntary_ctxt_switches:\t12
+ nonvoluntary_ctxt_switches:\t13""").encode()):
+ p = psutil.Process()
+ self.assertEqual(p.num_ctx_switches().voluntary, 12)
+ self.assertEqual(p.num_ctx_switches().involuntary, 13)
+ self.assertEqual(p.num_threads(), 66)
+ uids = p.uids()
+ self.assertEqual(uids.real, 1000)
+ self.assertEqual(uids.effective, 1001)
+ self.assertEqual(uids.saved, 1002)
+ gids = p.gids()
+ self.assertEqual(gids.real, 1004)
+ self.assertEqual(gids.effective, 1005)
+ self.assertEqual(gids.saved, 1006)
+ self.assertEqual(p._proc._get_eligible_cpus(), list(range(0, 8)))
+
+ def test_connections_enametoolong(self):
+ # Simulate a case where /proc/{pid}/fd/{fd} symlink points to
+ # a file with full path longer than PATH_MAX, see:
+ # https://github.com/giampaolo/psutil/issues/1940
+ with mock.patch('psutil._pslinux.os.readlink',
+ side_effect=OSError(errno.ENAMETOOLONG, "")) as m:
+ p = psutil.Process()
+ with mock.patch("psutil._pslinux.debug"):
+ assert not p.connections()
+ assert m.called
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestProcessAgainstStatus(PsutilTestCase):
+ """/proc/pid/stat and /proc/pid/status have many values in common.
+ Whenever possible, psutil uses /proc/pid/stat (it's faster).
+ For all those cases we check that the value found in
+ /proc/pid/stat (by psutil) matches the one found in
+ /proc/pid/status.
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ cls.proc = psutil.Process()
+
+ def read_status_file(self, linestart):
+ with psutil._psplatform.open_text(
+ '/proc/%s/status' % self.proc.pid) as f:
+ for line in f:
+ line = line.strip()
+ if line.startswith(linestart):
+ value = line.partition('\t')[2]
+ try:
+ return int(value)
+ except ValueError:
+ return value
+ raise ValueError("can't find %r" % linestart)
+
+ def test_name(self):
+ value = self.read_status_file("Name:")
+ self.assertEqual(self.proc.name(), value)
+
+ def test_status(self):
+ value = self.read_status_file("State:")
+ value = value[value.find('(') + 1:value.rfind(')')]
+ value = value.replace(' ', '-')
+ self.assertEqual(self.proc.status(), value)
+
+ def test_ppid(self):
+ value = self.read_status_file("PPid:")
+ self.assertEqual(self.proc.ppid(), value)
+
+ def test_num_threads(self):
+ value = self.read_status_file("Threads:")
+ self.assertEqual(self.proc.num_threads(), value)
+
+ def test_uids(self):
+ value = self.read_status_file("Uid:")
+ value = tuple(map(int, value.split()[1:4]))
+ self.assertEqual(self.proc.uids(), value)
+
+ def test_gids(self):
+ value = self.read_status_file("Gid:")
+ value = tuple(map(int, value.split()[1:4]))
+ self.assertEqual(self.proc.gids(), value)
+
+ @retry_on_failure()
+ def test_num_ctx_switches(self):
+ value = self.read_status_file("voluntary_ctxt_switches:")
+ self.assertEqual(self.proc.num_ctx_switches().voluntary, value)
+ value = self.read_status_file("nonvoluntary_ctxt_switches:")
+ self.assertEqual(self.proc.num_ctx_switches().involuntary, value)
+
+ def test_cpu_affinity(self):
+ value = self.read_status_file("Cpus_allowed_list:")
+ if '-' in str(value):
+ min_, max_ = map(int, value.split('-'))
+ self.assertEqual(
+ self.proc.cpu_affinity(), list(range(min_, max_ + 1)))
+
+ def test_cpu_affinity_eligible_cpus(self):
+ value = self.read_status_file("Cpus_allowed_list:")
+ with mock.patch("psutil._pslinux.per_cpu_times") as m:
+ self.proc._proc._get_eligible_cpus()
+ if '-' in str(value):
+ assert not m.called
+ else:
+ assert m.called
+
+
+# =====================================================================
+# --- test utils
+# =====================================================================
+
+
+@unittest.skipIf(not LINUX, "LINUX only")
+class TestUtils(PsutilTestCase):
+
+ def test_readlink(self):
+ with mock.patch("os.readlink", return_value="foo (deleted)") as m:
+ self.assertEqual(psutil._psplatform.readlink("bar"), "foo")
+ assert m.called
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_memleaks.py b/lib/psutil/tests/test_memleaks.py
new file mode 100644
index 0000000..dbd1588
--- /dev/null
+++ b/lib/psutil/tests/test_memleaks.py
@@ -0,0 +1,492 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Tests for detecting function memory leaks (typically the ones
+implemented in C). It does so by calling a function many times and
+checking whether process memory usage keeps increasing between
+calls or over time.
+Note that this may produce false positives (especially on Windows
+for some reason).
+PyPy appears to be completely unstable for this framework, probably
+because of how its JIT handles memory, so tests are skipped.
+"""
+
+from __future__ import print_function
+
+import functools
+import os
+import platform
+import unittest
+
+import psutil
+import psutil._common
+from psutil import LINUX
+from psutil import MACOS
+from psutil import OPENBSD
+from psutil import POSIX
+from psutil import SUNOS
+from psutil import WINDOWS
+from psutil._compat import ProcessLookupError
+from psutil._compat import super
+from psutil.tests import HAS_CPU_AFFINITY
+from psutil.tests import HAS_CPU_FREQ
+from psutil.tests import HAS_ENVIRON
+from psutil.tests import HAS_IONICE
+from psutil.tests import HAS_MEMORY_MAPS
+from psutil.tests import HAS_NET_IO_COUNTERS
+from psutil.tests import HAS_PROC_CPU_NUM
+from psutil.tests import HAS_PROC_IO_COUNTERS
+from psutil.tests import HAS_RLIMIT
+from psutil.tests import HAS_SENSORS_BATTERY
+from psutil.tests import HAS_SENSORS_FANS
+from psutil.tests import HAS_SENSORS_TEMPERATURES
+from psutil.tests import TestMemoryLeak
+from psutil.tests import create_sockets
+from psutil.tests import get_testfn
+from psutil.tests import process_namespace
+from psutil.tests import skip_on_access_denied
+from psutil.tests import spawn_testproc
+from psutil.tests import system_namespace
+from psutil.tests import terminate
+
+
+cext = psutil._psplatform.cext
+thisproc = psutil.Process()
+FEW_TIMES = 5
+
+
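+# Illustrative sketch (not part of the psutil test suite): the strategy
+# described in the module docstring above boils down to calling a
+# function many times and comparing process memory before and after.
+# The helper below is a simplified, hypothetical version of that check;
+# its name and the `times` / `tolerance` values are made up for
+# illustration, and the real logic lives in psutil.tests.TestMemoryLeak.
+def _leak_check_sketch(fun, times=1000, tolerance=4096):
+    """Return True if `fun` appears to leak memory over `times` calls."""
+    proc = psutil.Process()
+    fun()  # warm-up call so one-off allocations are not counted as a leak
+    before = proc.memory_info().rss
+    for _ in range(times):
+        fun()
+    after = proc.memory_info().rss
+    return (after - before) > tolerance
+
+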
+def fewtimes_if_linux():
+ """Decorator for those Linux functions which are implemented in pure
+ Python, and which we want to run faster.
+ """
+ def decorator(fun):
+ @functools.wraps(fun)
+ def wrapper(self, *args, **kwargs):
+ if LINUX:
+ before = self.__class__.times
+ try:
+ self.__class__.times = FEW_TIMES
+ return fun(self, *args, **kwargs)
+ finally:
+ self.__class__.times = before
+ else:
+ return fun(self, *args, **kwargs)
+ return wrapper
+ return decorator
+
+
+# ===================================================================
+# Process class
+# ===================================================================
+
+
+class TestProcessObjectLeaks(TestMemoryLeak):
+ """Test leaks of Process class methods."""
+
+ proc = thisproc
+
+ def test_coverage(self):
+ ns = process_namespace(None)
+ ns.test_class_coverage(self, ns.getters + ns.setters)
+
+ @fewtimes_if_linux()
+ def test_name(self):
+ self.execute(self.proc.name)
+
+ @fewtimes_if_linux()
+ def test_cmdline(self):
+ self.execute(self.proc.cmdline)
+
+ @fewtimes_if_linux()
+ def test_exe(self):
+ self.execute(self.proc.exe)
+
+ @fewtimes_if_linux()
+ def test_ppid(self):
+ self.execute(self.proc.ppid)
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ @fewtimes_if_linux()
+ def test_uids(self):
+ self.execute(self.proc.uids)
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ @fewtimes_if_linux()
+ def test_gids(self):
+ self.execute(self.proc.gids)
+
+ @fewtimes_if_linux()
+ def test_status(self):
+ self.execute(self.proc.status)
+
+ def test_nice(self):
+ self.execute(self.proc.nice)
+
+ def test_nice_set(self):
+ niceness = thisproc.nice()
+ self.execute(lambda: self.proc.nice(niceness))
+
+ @unittest.skipIf(not HAS_IONICE, "not supported")
+ def test_ionice(self):
+ self.execute(self.proc.ionice)
+
+ @unittest.skipIf(not HAS_IONICE, "not supported")
+ def test_ionice_set(self):
+ if WINDOWS:
+ value = thisproc.ionice()
+ self.execute(lambda: self.proc.ionice(value))
+ else:
+ self.execute(lambda: self.proc.ionice(psutil.IOPRIO_CLASS_NONE))
+ fun = functools.partial(cext.proc_ioprio_set, os.getpid(), -1, 0)
+ self.execute_w_exc(OSError, fun)
+
+ @unittest.skipIf(not HAS_PROC_IO_COUNTERS, "not supported")
+ @fewtimes_if_linux()
+ def test_io_counters(self):
+ self.execute(self.proc.io_counters)
+
+ @unittest.skipIf(POSIX, "worthless on POSIX")
+ def test_username(self):
+ # always opens 1 handle on Windows (only once)
+ psutil.Process().username()
+ self.execute(self.proc.username)
+
+ @fewtimes_if_linux()
+ def test_create_time(self):
+ self.execute(self.proc.create_time)
+
+ @fewtimes_if_linux()
+ @skip_on_access_denied(only_if=OPENBSD)
+ def test_num_threads(self):
+ self.execute(self.proc.num_threads)
+
+ @unittest.skipIf(not WINDOWS, "WINDOWS only")
+ def test_num_handles(self):
+ self.execute(self.proc.num_handles)
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ @fewtimes_if_linux()
+ def test_num_fds(self):
+ self.execute(self.proc.num_fds)
+
+ @fewtimes_if_linux()
+ def test_num_ctx_switches(self):
+ self.execute(self.proc.num_ctx_switches)
+
+ @fewtimes_if_linux()
+ @skip_on_access_denied(only_if=OPENBSD)
+ def test_threads(self):
+ self.execute(self.proc.threads)
+
+ @fewtimes_if_linux()
+ def test_cpu_times(self):
+ self.execute(self.proc.cpu_times)
+
+ @fewtimes_if_linux()
+ @unittest.skipIf(not HAS_PROC_CPU_NUM, "not supported")
+ def test_cpu_num(self):
+ self.execute(self.proc.cpu_num)
+
+ @fewtimes_if_linux()
+ def test_memory_info(self):
+ self.execute(self.proc.memory_info)
+
+ @fewtimes_if_linux()
+ def test_memory_full_info(self):
+ self.execute(self.proc.memory_full_info)
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ @fewtimes_if_linux()
+ def test_terminal(self):
+ self.execute(self.proc.terminal)
+
+ def test_resume(self):
+ times = FEW_TIMES if POSIX else self.times
+ self.execute(self.proc.resume, times=times)
+
+ @fewtimes_if_linux()
+ def test_cwd(self):
+ self.execute(self.proc.cwd)
+
+ @unittest.skipIf(not HAS_CPU_AFFINITY, "not supported")
+ def test_cpu_affinity(self):
+ self.execute(self.proc.cpu_affinity)
+
+ @unittest.skipIf(not HAS_CPU_AFFINITY, "not supported")
+ def test_cpu_affinity_set(self):
+ affinity = thisproc.cpu_affinity()
+ self.execute(lambda: self.proc.cpu_affinity(affinity))
+ self.execute_w_exc(
+ ValueError, lambda: self.proc.cpu_affinity([-1]))
+
+ @fewtimes_if_linux()
+ def test_open_files(self):
+ with open(get_testfn(), 'w'):
+ self.execute(self.proc.open_files)
+
+ @unittest.skipIf(not HAS_MEMORY_MAPS, "not supported")
+ @fewtimes_if_linux()
+ def test_memory_maps(self):
+ self.execute(self.proc.memory_maps)
+
+ @unittest.skipIf(not LINUX, "LINUX only")
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit(self):
+ self.execute(lambda: self.proc.rlimit(psutil.RLIMIT_NOFILE))
+
+ @unittest.skipIf(not LINUX, "LINUX only")
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit_set(self):
+ limit = thisproc.rlimit(psutil.RLIMIT_NOFILE)
+ self.execute(lambda: self.proc.rlimit(psutil.RLIMIT_NOFILE, limit))
+ self.execute_w_exc((OSError, ValueError), lambda: self.proc.rlimit(-1))
+
+ @fewtimes_if_linux()
+ # Windows implementation is based on a single system-wide
+ # function (tested later).
+ @unittest.skipIf(WINDOWS, "worthless on WINDOWS")
+ def test_connections(self):
+ # TODO: UNIX sockets are temporarily implemented by parsing
+ # 'pfiles' cmd output; we don't want that part of the code to
+ # be executed.
+ with create_sockets():
+ kind = 'inet' if SUNOS else 'all'
+ self.execute(lambda: self.proc.connections(kind))
+
+ @unittest.skipIf(not HAS_ENVIRON, "not supported")
+ def test_environ(self):
+ self.execute(self.proc.environ)
+
+ @unittest.skipIf(not WINDOWS, "WINDOWS only")
+ def test_proc_info(self):
+ self.execute(lambda: cext.proc_info(os.getpid()))
+
+
+class TestTerminatedProcessLeaks(TestProcessObjectLeaks):
+ """Repeat the tests above looking for leaks occurring when dealing
+ with terminated processes raising NoSuchProcess exception.
+ The C functions are still invoked but will follow different code
+ paths. We'll check those code paths.
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls.subp = spawn_testproc()
+ cls.proc = psutil.Process(cls.subp.pid)
+ cls.proc.kill()
+ cls.proc.wait()
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ terminate(cls.subp)
+
+ def call(self, fun):
+ try:
+ fun()
+ except psutil.NoSuchProcess:
+ pass
+
+ if WINDOWS:
+
+ def test_kill(self):
+ self.execute(self.proc.kill)
+
+ def test_terminate(self):
+ self.execute(self.proc.terminate)
+
+ def test_suspend(self):
+ self.execute(self.proc.suspend)
+
+ def test_resume(self):
+ self.execute(self.proc.resume)
+
+ def test_wait(self):
+ self.execute(self.proc.wait)
+
+ def test_proc_info(self):
+ # test dual implementation
+ def call():
+ try:
+ return cext.proc_info(self.proc.pid)
+ except ProcessLookupError:
+ pass
+
+ self.execute(call)
+
+
+@unittest.skipIf(not WINDOWS, "WINDOWS only")
+class TestProcessDualImplementation(TestMemoryLeak):
+
+ def test_cmdline_peb_true(self):
+ self.execute(lambda: cext.proc_cmdline(os.getpid(), use_peb=True))
+
+ def test_cmdline_peb_false(self):
+ self.execute(lambda: cext.proc_cmdline(os.getpid(), use_peb=False))
+
+
+# ===================================================================
+# system APIs
+# ===================================================================
+
+
+class TestModuleFunctionsLeaks(TestMemoryLeak):
+ """Test leaks of psutil module functions."""
+
+ def test_coverage(self):
+ ns = system_namespace()
+ ns.test_class_coverage(self, ns.all)
+
+ # --- cpu
+
+ @fewtimes_if_linux()
+ def test_cpu_count(self): # logical
+ self.execute(lambda: psutil.cpu_count(logical=True))
+
+ @fewtimes_if_linux()
+ def test_cpu_count_cores(self):
+ self.execute(lambda: psutil.cpu_count(logical=False))
+
+ @fewtimes_if_linux()
+ def test_cpu_times(self):
+ self.execute(psutil.cpu_times)
+
+ @fewtimes_if_linux()
+ def test_per_cpu_times(self):
+ self.execute(lambda: psutil.cpu_times(percpu=True))
+
+ @fewtimes_if_linux()
+ def test_cpu_stats(self):
+ self.execute(psutil.cpu_stats)
+
+ @fewtimes_if_linux()
+ # TODO: remove this once 1892 is fixed
+ @unittest.skipIf(MACOS and platform.machine() == 'arm64',
+ "skipped due to #1892")
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_cpu_freq(self):
+ self.execute(psutil.cpu_freq)
+
+ @unittest.skipIf(not WINDOWS, "WINDOWS only")
+ def test_getloadavg(self):
+ psutil.getloadavg()
+ self.execute(psutil.getloadavg)
+
+ # --- mem
+
+ def test_virtual_memory(self):
+ self.execute(psutil.virtual_memory)
+
+ # TODO: remove this skip when this gets fixed
+ @unittest.skipIf(SUNOS, "worthless on SUNOS (uses a subprocess)")
+ def test_swap_memory(self):
+ self.execute(psutil.swap_memory)
+
+ def test_pid_exists(self):
+ times = FEW_TIMES if POSIX else self.times
+ self.execute(lambda: psutil.pid_exists(os.getpid()), times=times)
+
+ # --- disk
+
+ def test_disk_usage(self):
+ times = FEW_TIMES if POSIX else self.times
+ self.execute(lambda: psutil.disk_usage('.'), times=times)
+
+ def test_disk_partitions(self):
+ self.execute(psutil.disk_partitions)
+
+ @unittest.skipIf(LINUX and not os.path.exists('/proc/diskstats'),
+ '/proc/diskstats not available on this Linux version')
+ @fewtimes_if_linux()
+ def test_disk_io_counters(self):
+ self.execute(lambda: psutil.disk_io_counters(nowrap=False))
+
+ # --- proc
+
+ @fewtimes_if_linux()
+ def test_pids(self):
+ self.execute(psutil.pids)
+
+ # --- net
+
+ @fewtimes_if_linux()
+ @unittest.skipIf(not HAS_NET_IO_COUNTERS, 'not supported')
+ def test_net_io_counters(self):
+ self.execute(lambda: psutil.net_io_counters(nowrap=False))
+
+ @fewtimes_if_linux()
+ @unittest.skipIf(MACOS and os.getuid() != 0, "need root access")
+ def test_net_connections(self):
+ # always opens a handle on Windows (once)
+ psutil.net_connections(kind='all')
+ with create_sockets():
+ self.execute(lambda: psutil.net_connections(kind='all'))
+
+ def test_net_if_addrs(self):
+ # Note: verified that on Windows this was a false positive.
+ tolerance = 80 * 1024 if WINDOWS else self.tolerance
+ self.execute(psutil.net_if_addrs, tolerance=tolerance)
+
+ def test_net_if_stats(self):
+ self.execute(psutil.net_if_stats)
+
+ # --- sensors
+
+ @fewtimes_if_linux()
+ @unittest.skipIf(not HAS_SENSORS_BATTERY, "not supported")
+ def test_sensors_battery(self):
+ self.execute(psutil.sensors_battery)
+
+ @fewtimes_if_linux()
+ @unittest.skipIf(not HAS_SENSORS_TEMPERATURES, "not supported")
+ def test_sensors_temperatures(self):
+ self.execute(psutil.sensors_temperatures)
+
+ @fewtimes_if_linux()
+ @unittest.skipIf(not HAS_SENSORS_FANS, "not supported")
+ def test_sensors_fans(self):
+ self.execute(psutil.sensors_fans)
+
+ # --- others
+
+ @fewtimes_if_linux()
+ def test_boot_time(self):
+ self.execute(psutil.boot_time)
+
+ def test_users(self):
+ self.execute(psutil.users)
+
+ def test_set_debug(self):
+ self.execute(lambda: psutil._set_debug(False))
+
+ if WINDOWS:
+
+ # --- win services
+
+ def test_win_service_iter(self):
+ self.execute(cext.winservice_enumerate)
+
+ def test_win_service_get(self):
+ pass
+
+ def test_win_service_get_config(self):
+ name = next(psutil.win_service_iter()).name()
+ self.execute(lambda: cext.winservice_query_config(name))
+
+ def test_win_service_get_status(self):
+ name = next(psutil.win_service_iter()).name()
+ self.execute(lambda: cext.winservice_query_status(name))
+
+ def test_win_service_get_description(self):
+ name = next(psutil.win_service_iter()).name()
+ self.execute(lambda: cext.winservice_query_descr(name))
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_misc.py b/lib/psutil/tests/test_misc.py
new file mode 100644
index 0000000..e22789c
--- /dev/null
+++ b/lib/psutil/tests/test_misc.py
@@ -0,0 +1,852 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Miscellaneous tests.
+"""
+
+import ast
+import collections
+import errno
+import json
+import os
+import pickle
+import socket
+import stat
+import unittest
+
+import psutil
+import psutil.tests
+from psutil import LINUX
+from psutil import POSIX
+from psutil import WINDOWS
+from psutil._common import bcat
+from psutil._common import cat
+from psutil._common import debug
+from psutil._common import isfile_strict
+from psutil._common import memoize
+from psutil._common import memoize_when_activated
+from psutil._common import parse_environ_block
+from psutil._common import supports_ipv6
+from psutil._common import wrap_numbers
+from psutil._compat import PY3
+from psutil._compat import FileNotFoundError
+from psutil._compat import redirect_stderr
+from psutil.tests import APPVEYOR
+from psutil.tests import CI_TESTING
+from psutil.tests import HAS_BATTERY
+from psutil.tests import HAS_MEMORY_MAPS
+from psutil.tests import HAS_NET_IO_COUNTERS
+from psutil.tests import HAS_SENSORS_BATTERY
+from psutil.tests import HAS_SENSORS_FANS
+from psutil.tests import HAS_SENSORS_TEMPERATURES
+from psutil.tests import PYTHON_EXE
+from psutil.tests import ROOT_DIR
+from psutil.tests import SCRIPTS_DIR
+from psutil.tests import PsutilTestCase
+from psutil.tests import import_module_by_path
+from psutil.tests import mock
+from psutil.tests import reload_module
+from psutil.tests import sh
+
+
+# ===================================================================
+# --- Test classes' repr(), str(), ...
+# ===================================================================
+
+
+class TestSpecialMethods(PsutilTestCase):
+
+ def test_process__repr__(self, func=repr):
+ p = psutil.Process(self.spawn_testproc().pid)
+ r = func(p)
+ self.assertIn("psutil.Process", r)
+ self.assertIn("pid=%s" % p.pid, r)
+ self.assertIn("name='%s'" % str(p.name()),
+ r.replace("name=u'", "name='"))
+ self.assertIn("status=", r)
+ self.assertNotIn("exitcode=", r)
+ p.terminate()
+ p.wait()
+ r = func(p)
+ self.assertIn("status='terminated'", r)
+ self.assertIn("exitcode=", r)
+
+ with mock.patch.object(psutil.Process, "name",
+ side_effect=psutil.ZombieProcess(os.getpid())):
+ p = psutil.Process()
+ r = func(p)
+ self.assertIn("pid=%s" % p.pid, r)
+ self.assertIn("status='zombie'", r)
+ self.assertNotIn("name=", r)
+ with mock.patch.object(psutil.Process, "name",
+ side_effect=psutil.NoSuchProcess(os.getpid())):
+ p = psutil.Process()
+ r = func(p)
+ self.assertIn("pid=%s" % p.pid, r)
+ self.assertIn("terminated", r)
+ self.assertNotIn("name=", r)
+ with mock.patch.object(psutil.Process, "name",
+ side_effect=psutil.AccessDenied(os.getpid())):
+ p = psutil.Process()
+ r = func(p)
+ self.assertIn("pid=%s" % p.pid, r)
+ self.assertNotIn("name=", r)
+
+ def test_process__str__(self):
+ self.test_process__repr__(func=str)
+
+ def test_error__repr__(self):
+ self.assertEqual(repr(psutil.Error()), "psutil.Error()")
+
+ def test_error__str__(self):
+ self.assertEqual(str(psutil.Error()), "")
+
+ def test_no_such_process__repr__(self):
+ self.assertEqual(
+ repr(psutil.NoSuchProcess(321)),
+ "psutil.NoSuchProcess(pid=321, msg='process no longer exists')")
+ self.assertEqual(
+ repr(psutil.NoSuchProcess(321, name="name", msg="msg")),
+ "psutil.NoSuchProcess(pid=321, name='name', msg='msg')")
+
+ def test_no_such_process__str__(self):
+ self.assertEqual(
+ str(psutil.NoSuchProcess(321)),
+ "process no longer exists (pid=321)")
+ self.assertEqual(
+ str(psutil.NoSuchProcess(321, name="name", msg="msg")),
+ "msg (pid=321, name='name')")
+
+ def test_zombie_process__repr__(self):
+ self.assertEqual(
+ repr(psutil.ZombieProcess(321)),
+ 'psutil.ZombieProcess(pid=321, msg="PID still '
+ 'exists but it\'s a zombie")')
+ self.assertEqual(
+ repr(psutil.ZombieProcess(321, name="name", ppid=320, msg="foo")),
+ "psutil.ZombieProcess(pid=321, ppid=320, name='name', msg='foo')")
+
+ def test_zombie_process__str__(self):
+ self.assertEqual(
+ str(psutil.ZombieProcess(321)),
+ "PID still exists but it's a zombie (pid=321)")
+ self.assertEqual(
+ str(psutil.ZombieProcess(321, name="name", ppid=320, msg="foo")),
+ "foo (pid=321, ppid=320, name='name')")
+
+ def test_access_denied__repr__(self):
+ self.assertEqual(
+ repr(psutil.AccessDenied(321)),
+ "psutil.AccessDenied(pid=321)")
+ self.assertEqual(
+ repr(psutil.AccessDenied(321, name="name", msg="msg")),
+ "psutil.AccessDenied(pid=321, name='name', msg='msg')")
+
+ def test_access_denied__str__(self):
+ self.assertEqual(
+ str(psutil.AccessDenied(321)),
+ "(pid=321)")
+ self.assertEqual(
+ str(psutil.AccessDenied(321, name="name", msg="msg")),
+ "msg (pid=321, name='name')")
+
+ def test_timeout_expired__repr__(self):
+ self.assertEqual(
+ repr(psutil.TimeoutExpired(5)),
+ "psutil.TimeoutExpired(seconds=5, msg='timeout after 5 seconds')")
+ self.assertEqual(
+ repr(psutil.TimeoutExpired(5, pid=321, name="name")),
+ "psutil.TimeoutExpired(pid=321, name='name', seconds=5, "
+ "msg='timeout after 5 seconds')")
+
+ def test_timeout_expired__str__(self):
+ self.assertEqual(
+ str(psutil.TimeoutExpired(5)),
+ "timeout after 5 seconds")
+ self.assertEqual(
+ str(psutil.TimeoutExpired(5, pid=321, name="name")),
+ "timeout after 5 seconds (pid=321, name='name')")
+
+ def test_process__eq__(self):
+ p1 = psutil.Process()
+ p2 = psutil.Process()
+ self.assertEqual(p1, p2)
+ p2._ident = (0, 0)
+ self.assertNotEqual(p1, p2)
+ self.assertNotEqual(p1, 'foo')
+
+ def test_process__hash__(self):
+ s = set([psutil.Process(), psutil.Process()])
+ self.assertEqual(len(s), 1)
+
+
+# ===================================================================
+# --- Misc, generic, corner cases
+# ===================================================================
+
+
+class TestMisc(PsutilTestCase):
+
+ def test__all__(self):
+ dir_psutil = dir(psutil)
+ for name in dir_psutil:
+ if name in ('long', 'tests', 'test', 'PermissionError',
+ 'ProcessLookupError'):
+ continue
+ if not name.startswith('_'):
+ try:
+ __import__(name)
+ except ImportError:
+ if name not in psutil.__all__:
+ fun = getattr(psutil, name)
+ if fun is None:
+ continue
+ if (fun.__doc__ is not None and
+ 'deprecated' not in fun.__doc__.lower()):
+ raise self.fail('%r not in psutil.__all__' % name)
+
+ # Import 'star' will break if __all__ is inconsistent, see:
+ # https://github.com/giampaolo/psutil/issues/656
+ # Can't do `from psutil import *` as it won't work on python 3
+ # so we simply iterate over __all__.
+ for name in psutil.__all__:
+ self.assertIn(name, dir_psutil)
+
+ def test_version(self):
+ self.assertEqual('.'.join([str(x) for x in psutil.version_info]),
+ psutil.__version__)
+
+ def test_process_as_dict_no_new_names(self):
+ # See https://github.com/giampaolo/psutil/issues/813
+ p = psutil.Process()
+ p.foo = '1'
+ self.assertNotIn('foo', p.as_dict())
+
+ def test_serialization(self):
+ def check(ret):
+ if json is not None:
+ json.loads(json.dumps(ret))
+ a = pickle.dumps(ret)
+ b = pickle.loads(a)
+ self.assertEqual(ret, b)
+
+ check(psutil.Process().as_dict())
+ check(psutil.virtual_memory())
+ check(psutil.swap_memory())
+ check(psutil.cpu_times())
+ check(psutil.cpu_times_percent(interval=0))
+ check(psutil.net_io_counters())
+ if LINUX and not os.path.exists('/proc/diskstats'):
+ pass
+ else:
+ if not APPVEYOR:
+ check(psutil.disk_io_counters())
+ check(psutil.disk_partitions())
+ check(psutil.disk_usage(os.getcwd()))
+ check(psutil.users())
+
+ # XXX: https://github.com/pypa/setuptools/pull/2896
+ @unittest.skipIf(APPVEYOR, "temporarily disabled due to setuptools bug")
+ def test_setup_script(self):
+ setup_py = os.path.join(ROOT_DIR, 'setup.py')
+ if CI_TESTING and not os.path.exists(setup_py):
+ return self.skipTest("can't find setup.py")
+ module = import_module_by_path(setup_py)
+ self.assertRaises(SystemExit, module.setup)
+ self.assertEqual(module.get_version(), psutil.__version__)
+
+ def test_ad_on_process_creation(self):
+ # We are supposed to be able to instantiate Process even for
+ # zombie processes or when access is denied.
+ with mock.patch.object(psutil.Process, 'create_time',
+ side_effect=psutil.AccessDenied) as meth:
+ psutil.Process()
+ assert meth.called
+ with mock.patch.object(psutil.Process, 'create_time',
+ side_effect=psutil.ZombieProcess(1)) as meth:
+ psutil.Process()
+ assert meth.called
+ with mock.patch.object(psutil.Process, 'create_time',
+ side_effect=ValueError) as meth:
+ with self.assertRaises(ValueError):
+ psutil.Process()
+ assert meth.called
+
+ def test_sanity_version_check(self):
+ # see: https://github.com/giampaolo/psutil/issues/564
+ with mock.patch(
+ "psutil._psplatform.cext.version", return_value="0.0.0"):
+ with self.assertRaises(ImportError) as cm:
+ reload_module(psutil)
+ self.assertIn("version conflict", str(cm.exception).lower())
+
+
+# ===================================================================
+# --- psutil/_common.py utils
+# ===================================================================
+
+
+class TestCommonModule(PsutilTestCase):
+
+ def test_memoize(self):
+ @memoize
+ def foo(*args, **kwargs):
+ """foo docstring"""
+ calls.append(None)
+ return (args, kwargs)
+
+ calls = []
+ # no args
+ for x in range(2):
+ ret = foo()
+ expected = ((), {})
+ self.assertEqual(ret, expected)
+ self.assertEqual(len(calls), 1)
+ # with args
+ for x in range(2):
+ ret = foo(1)
+ expected = ((1, ), {})
+ self.assertEqual(ret, expected)
+ self.assertEqual(len(calls), 2)
+ # with args + kwargs
+ for x in range(2):
+ ret = foo(1, bar=2)
+ expected = ((1, ), {'bar': 2})
+ self.assertEqual(ret, expected)
+ self.assertEqual(len(calls), 3)
+ # clear cache
+ foo.cache_clear()
+ ret = foo()
+ expected = ((), {})
+ self.assertEqual(ret, expected)
+ self.assertEqual(len(calls), 4)
+ # docstring
+ self.assertEqual(foo.__doc__, "foo docstring")
+
+ def test_memoize_when_activated(self):
+ class Foo:
+
+ @memoize_when_activated
+ def foo(self):
+ calls.append(None)
+
+ f = Foo()
+ calls = []
+ f.foo()
+ f.foo()
+ self.assertEqual(len(calls), 2)
+
+ # activate
+ calls = []
+ f.foo.cache_activate(f)
+ f.foo()
+ f.foo()
+ self.assertEqual(len(calls), 1)
+
+ # deactivate
+ calls = []
+ f.foo.cache_deactivate(f)
+ f.foo()
+ f.foo()
+ self.assertEqual(len(calls), 2)
+
+ def test_parse_environ_block(self):
+ def k(s):
+ return s.upper() if WINDOWS else s
+
+ self.assertEqual(parse_environ_block("a=1\0"),
+ {k("a"): "1"})
+ self.assertEqual(parse_environ_block("a=1\0b=2\0\0"),
+ {k("a"): "1", k("b"): "2"})
+ self.assertEqual(parse_environ_block("a=1\0b=\0\0"),
+ {k("a"): "1", k("b"): ""})
+ # ignore everything after \0\0
+ self.assertEqual(parse_environ_block("a=1\0b=2\0\0c=3\0"),
+ {k("a"): "1", k("b"): "2"})
+ # ignore everything that is not an assignment
+ self.assertEqual(parse_environ_block("xxx\0a=1\0"), {k("a"): "1"})
+ self.assertEqual(parse_environ_block("a=1\0=b=2\0"), {k("a"): "1"})
+ # do not fail if the block is incomplete
+ self.assertEqual(parse_environ_block("a=1\0b=2"), {k("a"): "1"})
+
+ def test_supports_ipv6(self):
+ self.addCleanup(supports_ipv6.cache_clear)
+ if supports_ipv6():
+ with mock.patch('psutil._common.socket') as s:
+ s.has_ipv6 = False
+ supports_ipv6.cache_clear()
+ assert not supports_ipv6()
+
+ supports_ipv6.cache_clear()
+ with mock.patch('psutil._common.socket.socket',
+ side_effect=socket.error) as s:
+ assert not supports_ipv6()
+ assert s.called
+
+ supports_ipv6.cache_clear()
+ with mock.patch('psutil._common.socket.socket',
+ side_effect=socket.gaierror) as s:
+ assert not supports_ipv6()
+ supports_ipv6.cache_clear()
+ assert s.called
+
+ supports_ipv6.cache_clear()
+ with mock.patch('psutil._common.socket.socket.bind',
+ side_effect=socket.gaierror) as s:
+ assert not supports_ipv6()
+ supports_ipv6.cache_clear()
+ assert s.called
+ else:
+ with self.assertRaises(socket.error):
+ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ try:
+ sock.bind(("::1", 0))
+ finally:
+ sock.close()
+
+ def test_isfile_strict(self):
+ this_file = os.path.abspath(__file__)
+ assert isfile_strict(this_file)
+ assert not isfile_strict(os.path.dirname(this_file))
+ with mock.patch('psutil._common.os.stat',
+ side_effect=OSError(errno.EPERM, "foo")):
+ self.assertRaises(OSError, isfile_strict, this_file)
+ with mock.patch('psutil._common.os.stat',
+ side_effect=OSError(errno.EACCES, "foo")):
+ self.assertRaises(OSError, isfile_strict, this_file)
+ with mock.patch('psutil._common.os.stat',
+ side_effect=OSError(errno.ENOENT, "foo")):
+ assert not isfile_strict(this_file)
+ with mock.patch('psutil._common.stat.S_ISREG', return_value=False):
+ assert not isfile_strict(this_file)
+
+ def test_debug(self):
+ if PY3:
+ from io import StringIO
+ else:
+ from StringIO import StringIO
+
+ with redirect_stderr(StringIO()) as f:
+ debug("hello")
+ msg = f.getvalue()
+ assert msg.startswith("psutil-debug"), msg
+ self.assertIn("hello", msg)
+ self.assertIn(__file__.replace('.pyc', '.py'), msg)
+
+ # supposed to use repr(exc)
+ with redirect_stderr(StringIO()) as f:
+ debug(ValueError("this is an error"))
+ msg = f.getvalue()
+ self.assertIn("ignoring ValueError", msg)
+ self.assertIn("'this is an error'", msg)
+
+ # supposed to use str(exc), because of extra info about file name
+ with redirect_stderr(StringIO()) as f:
+ exc = OSError(2, "no such file")
+ exc.filename = "/foo"
+ debug(exc)
+ msg = f.getvalue()
+ self.assertIn("no such file", msg)
+ self.assertIn("/foo", msg)
+
+ def test_cat_bcat(self):
+ testfn = self.get_testfn()
+ with open(testfn, "wt") as f:
+ f.write("foo")
+ self.assertEqual(cat(testfn), "foo")
+ self.assertEqual(bcat(testfn), b"foo")
+ self.assertRaises(FileNotFoundError, cat, testfn + '-invalid')
+ self.assertRaises(FileNotFoundError, bcat, testfn + '-invalid')
+ self.assertEqual(cat(testfn + '-invalid', fallback="bar"), "bar")
+ self.assertEqual(bcat(testfn + '-invalid', fallback="bar"), "bar")
+
+
+# ===================================================================
+# --- Tests for wrap_numbers() function.
+# ===================================================================
+
+
+nt = collections.namedtuple('foo', 'a b c')
+
+
+class TestWrapNumbers(PsutilTestCase):
+
+ def setUp(self):
+ wrap_numbers.cache_clear()
+
+ tearDown = setUp
+
+ def test_first_call(self):
+ input = {'disk1': nt(5, 5, 5)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+
+ def test_input_hasnt_changed(self):
+ input = {'disk1': nt(5, 5, 5)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+
+ def test_increase_but_no_wrap(self):
+ input = {'disk1': nt(5, 5, 5)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ input = {'disk1': nt(10, 15, 20)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ input = {'disk1': nt(20, 25, 30)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ input = {'disk1': nt(20, 25, 30)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+
+ def test_wrap(self):
+ # let's say 100 is the threshold
+ input = {'disk1': nt(100, 100, 100)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ # first wrap restarts from 10
+ input = {'disk1': nt(100, 100, 10)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(100, 100, 110)})
+ # then it remains the same
+ input = {'disk1': nt(100, 100, 10)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(100, 100, 110)})
+ # then it goes up
+ input = {'disk1': nt(100, 100, 90)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(100, 100, 190)})
+ # then it wraps again
+ input = {'disk1': nt(100, 100, 20)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(100, 100, 210)})
+ # and remains the same
+ input = {'disk1': nt(100, 100, 20)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(100, 100, 210)})
+ # now wrap another num
+ input = {'disk1': nt(50, 100, 20)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(150, 100, 210)})
+ # and again
+ input = {'disk1': nt(40, 100, 20)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(190, 100, 210)})
+ # keep it the same
+ input = {'disk1': nt(40, 100, 20)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(190, 100, 210)})
+
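+ # Worked example of the arithmetic behind the assertions above (an
+ # illustration inferred from the expected values, not a formal spec):
+ # when a counter decreases, wrap_numbers() keeps the last adjusted
+ # value as an offset and adds it to the new raw readings. After the
+ # 100 -> 10 drop the offset is 100 (10 becomes 110, 90 becomes 190);
+ # after the next drop to 20 the offset becomes 190, so 20 + 190 = 210.
+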
+ def test_changing_keys(self):
+ # Emulate a case where the second call to disk_io()
+ # (or whatever) provides a new disk, then the new disk
+ # disappears on the third call.
+ input = {'disk1': nt(5, 5, 5)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ input = {'disk1': nt(5, 5, 5),
+ 'disk2': nt(7, 7, 7)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ input = {'disk1': nt(8, 8, 8)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+
+ def test_changing_keys_w_wrap(self):
+ input = {'disk1': nt(50, 50, 50),
+ 'disk2': nt(100, 100, 100)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ # disk 2 wraps
+ input = {'disk1': nt(50, 50, 50),
+ 'disk2': nt(100, 100, 10)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(50, 50, 50),
+ 'disk2': nt(100, 100, 110)})
+ # disk 2 disappears
+ input = {'disk1': nt(50, 50, 50)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+
+ # then it appears again; the old wrap is supposed to be
+ # gone.
+ input = {'disk1': nt(50, 50, 50),
+ 'disk2': nt(100, 100, 100)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ # remains the same
+ input = {'disk1': nt(50, 50, 50),
+ 'disk2': nt(100, 100, 100)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'), input)
+ # and then wraps again
+ input = {'disk1': nt(50, 50, 50),
+ 'disk2': nt(100, 100, 10)}
+ self.assertEqual(wrap_numbers(input, 'disk_io'),
+ {'disk1': nt(50, 50, 50),
+ 'disk2': nt(100, 100, 110)})
+
+ def test_real_data(self):
+ d = {'nvme0n1': (300, 508, 640, 1571, 5970, 1987, 2049, 451751, 47048),
+ 'nvme0n1p1': (1171, 2, 5600256, 1024, 516, 0, 0, 0, 8),
+ 'nvme0n1p2': (54, 54, 2396160, 5165056, 4, 24, 30, 1207, 28),
+ 'nvme0n1p3': (2389, 4539, 5154, 150, 4828, 1844, 2019, 398, 348)}
+ self.assertEqual(wrap_numbers(d, 'disk_io'), d)
+ self.assertEqual(wrap_numbers(d, 'disk_io'), d)
+ # decrease this ↓
+ d = {'nvme0n1': (100, 508, 640, 1571, 5970, 1987, 2049, 451751, 47048),
+ 'nvme0n1p1': (1171, 2, 5600256, 1024, 516, 0, 0, 0, 8),
+ 'nvme0n1p2': (54, 54, 2396160, 5165056, 4, 24, 30, 1207, 28),
+ 'nvme0n1p3': (2389, 4539, 5154, 150, 4828, 1844, 2019, 398, 348)}
+ out = wrap_numbers(d, 'disk_io')
+ self.assertEqual(out['nvme0n1'][0], 400)
+
+ # --- cache tests
+
+ def test_cache_first_call(self):
+ input = {'disk1': nt(5, 5, 5)}
+ wrap_numbers(input, 'disk_io')
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(cache[0], {'disk_io': input})
+ self.assertEqual(cache[1], {'disk_io': {}})
+ self.assertEqual(cache[2], {'disk_io': {}})
+
+ def test_cache_call_twice(self):
+ input = {'disk1': nt(5, 5, 5)}
+ wrap_numbers(input, 'disk_io')
+ input = {'disk1': nt(10, 10, 10)}
+ wrap_numbers(input, 'disk_io')
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(cache[0], {'disk_io': input})
+ self.assertEqual(
+ cache[1],
+ {'disk_io': {('disk1', 0): 0, ('disk1', 1): 0, ('disk1', 2): 0}})
+ self.assertEqual(cache[2], {'disk_io': {}})
+
+ def test_cache_wrap(self):
+ # let's say 100 is the threshold
+ input = {'disk1': nt(100, 100, 100)}
+ wrap_numbers(input, 'disk_io')
+
+ # first wrap restarts from 10
+ input = {'disk1': nt(100, 100, 10)}
+ wrap_numbers(input, 'disk_io')
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(cache[0], {'disk_io': input})
+ self.assertEqual(
+ cache[1],
+ {'disk_io': {('disk1', 0): 0, ('disk1', 1): 0, ('disk1', 2): 100}})
+ self.assertEqual(cache[2], {'disk_io': {'disk1': set([('disk1', 2)])}})
+
+ def assert_():
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(
+ cache[1],
+ {'disk_io': {('disk1', 0): 0, ('disk1', 1): 0,
+ ('disk1', 2): 100}})
+ self.assertEqual(cache[2],
+ {'disk_io': {'disk1': set([('disk1', 2)])}})
+
+ # then it remains the same
+ input = {'disk1': nt(100, 100, 10)}
+ wrap_numbers(input, 'disk_io')
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(cache[0], {'disk_io': input})
+ assert_()
+
+ # then it goes up
+ input = {'disk1': nt(100, 100, 90)}
+ wrap_numbers(input, 'disk_io')
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(cache[0], {'disk_io': input})
+ assert_()
+
+ # then it wraps again
+ input = {'disk1': nt(100, 100, 20)}
+ wrap_numbers(input, 'disk_io')
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(cache[0], {'disk_io': input})
+ self.assertEqual(
+ cache[1],
+ {'disk_io': {('disk1', 0): 0, ('disk1', 1): 0, ('disk1', 2): 190}})
+ self.assertEqual(cache[2], {'disk_io': {'disk1': set([('disk1', 2)])}})
+
+ def test_cache_changing_keys(self):
+ input = {'disk1': nt(5, 5, 5)}
+ wrap_numbers(input, 'disk_io')
+ input = {'disk1': nt(5, 5, 5),
+ 'disk2': nt(7, 7, 7)}
+ wrap_numbers(input, 'disk_io')
+ cache = wrap_numbers.cache_info()
+ self.assertEqual(cache[0], {'disk_io': input})
+ self.assertEqual(
+ cache[1],
+ {'disk_io': {('disk1', 0): 0, ('disk1', 1): 0, ('disk1', 2): 0}})
+ self.assertEqual(cache[2], {'disk_io': {}})
+
+ def test_cache_clear(self):
+ input = {'disk1': nt(5, 5, 5)}
+ wrap_numbers(input, 'disk_io')
+ wrap_numbers(input, 'disk_io')
+ wrap_numbers.cache_clear('disk_io')
+ self.assertEqual(wrap_numbers.cache_info(), ({}, {}, {}))
+ wrap_numbers.cache_clear('disk_io')
+ wrap_numbers.cache_clear('?!?')
+
+ @unittest.skipIf(not HAS_NET_IO_COUNTERS, 'not supported')
+ def test_cache_clear_public_apis(self):
+ if not psutil.disk_io_counters() or not psutil.net_io_counters():
+ return self.skipTest("no disks or NICs available")
+ psutil.disk_io_counters()
+ psutil.net_io_counters()
+ caches = wrap_numbers.cache_info()
+ for cache in caches:
+ self.assertIn('psutil.disk_io_counters', cache)
+ self.assertIn('psutil.net_io_counters', cache)
+
+ psutil.disk_io_counters.cache_clear()
+ caches = wrap_numbers.cache_info()
+ for cache in caches:
+ self.assertIn('psutil.net_io_counters', cache)
+ self.assertNotIn('psutil.disk_io_counters', cache)
+
+ psutil.net_io_counters.cache_clear()
+ caches = wrap_numbers.cache_info()
+ self.assertEqual(caches, ({}, {}, {}))
+
+
+# ===================================================================
+# --- Example script tests
+# ===================================================================
+
+
+@unittest.skipIf(not os.path.exists(SCRIPTS_DIR),
+ "can't locate scripts directory")
+class TestScripts(PsutilTestCase):
+ """Tests for scripts in the "scripts" directory."""
+
+ @staticmethod
+ def assert_stdout(exe, *args, **kwargs):
+ exe = '%s' % os.path.join(SCRIPTS_DIR, exe)
+ cmd = [PYTHON_EXE, exe]
+ for arg in args:
+ cmd.append(arg)
+ try:
+ out = sh(cmd, **kwargs).strip()
+ except RuntimeError as err:
+ if 'AccessDenied' in str(err):
+ return str(err)
+ else:
+ raise
+ assert out, out
+ return out
+
+ @staticmethod
+ def assert_syntax(exe, args=None):
+ exe = os.path.join(SCRIPTS_DIR, exe)
+ if PY3:
+ f = open(exe, 'rt', encoding='utf8')
+ else:
+ f = open(exe, 'rt')
+ with f:
+ src = f.read()
+ ast.parse(src)
+
+ def test_coverage(self):
+ # make sure all example scripts have a test method defined
+ meths = dir(self)
+ for name in os.listdir(SCRIPTS_DIR):
+ if name.endswith('.py'):
+ if 'test_' + os.path.splitext(name)[0] not in meths:
+ # self.assert_stdout(name)
+ raise self.fail('no test defined for %r script'
+ % os.path.join(SCRIPTS_DIR, name))
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ def test_executable(self):
+ for root, dirs, files in os.walk(SCRIPTS_DIR):
+ for file in files:
+ if file.endswith('.py'):
+ path = os.path.join(root, file)
+ if not stat.S_IXUSR & os.stat(path)[stat.ST_MODE]:
+ raise self.fail('%r is not executable' % path)
+
+ def test_disk_usage(self):
+ self.assert_stdout('disk_usage.py')
+
+ def test_free(self):
+ self.assert_stdout('free.py')
+
+ def test_meminfo(self):
+ self.assert_stdout('meminfo.py')
+
+ def test_procinfo(self):
+ self.assert_stdout('procinfo.py', str(os.getpid()))
+
+ @unittest.skipIf(CI_TESTING and not psutil.users(), "no users")
+ def test_who(self):
+ self.assert_stdout('who.py')
+
+ def test_ps(self):
+ self.assert_stdout('ps.py')
+
+ def test_pstree(self):
+ self.assert_stdout('pstree.py')
+
+ def test_netstat(self):
+ self.assert_stdout('netstat.py')
+
+ def test_ifconfig(self):
+ self.assert_stdout('ifconfig.py')
+
+ @unittest.skipIf(not HAS_MEMORY_MAPS, "not supported")
+ def test_pmap(self):
+ self.assert_stdout('pmap.py', str(os.getpid()))
+
+ def test_procsmem(self):
+ if 'uss' not in psutil.Process().memory_full_info()._fields:
+ raise self.skipTest("not supported")
+ self.assert_stdout('procsmem.py')
+
+ def test_killall(self):
+ self.assert_syntax('killall.py')
+
+ def test_nettop(self):
+ self.assert_syntax('nettop.py')
+
+ def test_top(self):
+ self.assert_syntax('top.py')
+
+ def test_iotop(self):
+ self.assert_syntax('iotop.py')
+
+ def test_pidof(self):
+ output = self.assert_stdout('pidof.py', psutil.Process().name())
+ self.assertIn(str(os.getpid()), output)
+
+ @unittest.skipIf(not WINDOWS, "WINDOWS only")
+ def test_winservices(self):
+ self.assert_stdout('winservices.py')
+
+ def test_cpu_distribution(self):
+ self.assert_syntax('cpu_distribution.py')
+
+ @unittest.skipIf(not HAS_SENSORS_TEMPERATURES, "not supported")
+ def test_temperatures(self):
+ if not psutil.sensors_temperatures():
+ self.skipTest("no temperatures")
+ self.assert_stdout('temperatures.py')
+
+ @unittest.skipIf(not HAS_SENSORS_FANS, "not supported")
+ def test_fans(self):
+ if not psutil.sensors_fans():
+ self.skipTest("no fans")
+ self.assert_stdout('fans.py')
+
+ @unittest.skipIf(not HAS_SENSORS_BATTERY, "not supported")
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_battery(self):
+ self.assert_stdout('battery.py')
+
+ @unittest.skipIf(not HAS_SENSORS_BATTERY, "not supported")
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_sensors(self):
+ self.assert_stdout('sensors.py')
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_osx.py b/lib/psutil/tests/test_osx.py
new file mode 100644
index 0000000..af12648
--- /dev/null
+++ b/lib/psutil/tests/test_osx.py
@@ -0,0 +1,241 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""macOS specific tests."""
+
+import platform
+import re
+import time
+import unittest
+
+import psutil
+from psutil import MACOS
+from psutil import POSIX
+from psutil.tests import HAS_BATTERY
+from psutil.tests import TOLERANCE_DISK_USAGE
+from psutil.tests import TOLERANCE_SYS_MEM
+from psutil.tests import PsutilTestCase
+from psutil.tests import retry_on_failure
+from psutil.tests import sh
+from psutil.tests import spawn_testproc
+from psutil.tests import terminate
+
+
+if POSIX:
+ from psutil._psutil_posix import getpagesize
+
+
+def sysctl(cmdline):
+ """Expects a sysctl command with an argument and parse the result
+ returning only the value of interest.
+ """
+ out = sh(cmdline)
+ result = out.split()[1]
+ try:
+ return int(result)
+ except ValueError:
+ return result
+
+
+def vm_stat(field):
+ """Wrapper around 'vm_stat' cmdline utility."""
+ out = sh('vm_stat')
+ for line in out.split('\n'):
+ if field in line:
+ break
+ else:
+ raise ValueError("line not found")
+ return int(re.search(r'\d+', line).group(0)) * getpagesize()
+
+
+# http://code.activestate.com/recipes/578019/
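+# Illustrative usage: human2bytes("10K") == 10240 and human2bytes("1 M") ==
+# 1048576 (binary multiples); currently only referenced by the commented-out
+# swap-usage test below.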
+def human2bytes(s):
+ SYMBOLS = {
+ 'customary': ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'),
+ }
+ init = s
+ num = ""
+ while s and s[0:1].isdigit() or s[0:1] == '.':
+ num += s[0]
+ s = s[1:]
+ num = float(num)
+ letter = s.strip()
+ for name, sset in SYMBOLS.items():
+ if letter in sset:
+ break
+ else:
+ if letter == 'k':
+ sset = SYMBOLS['customary']
+ letter = letter.upper()
+ else:
+ raise ValueError("can't interpret %r" % init)
+ prefix = {sset[0]: 1}
+ for i, s in enumerate(sset[1:]):
+ prefix[s] = 1 << (i + 1) * 10
+ return int(num * prefix[letter])
+
+
+@unittest.skipIf(not MACOS, "MACOS only")
+class TestProcess(PsutilTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.pid = spawn_testproc().pid
+
+ @classmethod
+ def tearDownClass(cls):
+ terminate(cls.pid)
+
+ def test_process_create_time(self):
+ output = sh("ps -o lstart -p %s" % self.pid)
+ start_ps = output.replace('STARTED', '').strip()
+ hhmmss = start_ps.split(' ')[-2]
+ year = start_ps.split(' ')[-1]
+ start_psutil = psutil.Process(self.pid).create_time()
+ self.assertEqual(
+ hhmmss,
+ time.strftime("%H:%M:%S", time.localtime(start_psutil)))
+ self.assertEqual(
+ year,
+ time.strftime("%Y", time.localtime(start_psutil)))
+
+
+@unittest.skipIf(not MACOS, "MACOS only")
+class TestSystemAPIs(PsutilTestCase):
+
+ # --- disk
+
+ @retry_on_failure()
+ def test_disks(self):
+ # test psutil.disk_usage() and psutil.disk_partitions()
+ # against "df -a"
+ def df(path):
+ out = sh('df -k "%s"' % path).strip()
+ lines = out.split('\n')
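+            # The first line is the header ("Filesystem 1024-blocks Used
+            # Available ..."); the second one carries the actual values,
+            # expressed in 1K blocks (hence the * 1024 below).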
+ lines.pop(0)
+ line = lines.pop(0)
+ dev, total, used, free = line.split()[:4]
+ if dev == 'none':
+ dev = ''
+ total = int(total) * 1024
+ used = int(used) * 1024
+ free = int(free) * 1024
+ return dev, total, used, free
+
+ for part in psutil.disk_partitions(all=False):
+ usage = psutil.disk_usage(part.mountpoint)
+ dev, total, used, free = df(part.mountpoint)
+ self.assertEqual(part.device, dev)
+ self.assertEqual(usage.total, total)
+ self.assertAlmostEqual(usage.free, free,
+ delta=TOLERANCE_DISK_USAGE)
+ self.assertAlmostEqual(usage.used, used,
+ delta=TOLERANCE_DISK_USAGE)
+
+ # --- cpu
+
+ def test_cpu_count_logical(self):
+ num = sysctl("sysctl hw.logicalcpu")
+ self.assertEqual(num, psutil.cpu_count(logical=True))
+
+ def test_cpu_count_cores(self):
+ num = sysctl("sysctl hw.physicalcpu")
+ self.assertEqual(num, psutil.cpu_count(logical=False))
+
+ # TODO: remove this once 1892 is fixed
+ @unittest.skipIf(platform.machine() == 'arm64', "skipped due to #1892")
+ def test_cpu_freq(self):
+ freq = psutil.cpu_freq()
+ self.assertEqual(
+ freq.current * 1000 * 1000, sysctl("sysctl hw.cpufrequency"))
+ self.assertEqual(
+ freq.min * 1000 * 1000, sysctl("sysctl hw.cpufrequency_min"))
+ self.assertEqual(
+ freq.max * 1000 * 1000, sysctl("sysctl hw.cpufrequency_max"))
+
+ # --- virtual mem
+
+ def test_vmem_total(self):
+ sysctl_hwphymem = sysctl('sysctl hw.memsize')
+ self.assertEqual(sysctl_hwphymem, psutil.virtual_memory().total)
+
+ @retry_on_failure()
+ def test_vmem_free(self):
+ vmstat_val = vm_stat("free")
+ psutil_val = psutil.virtual_memory().free
+ self.assertAlmostEqual(psutil_val, vmstat_val, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_active(self):
+ vmstat_val = vm_stat("active")
+ psutil_val = psutil.virtual_memory().active
+ self.assertAlmostEqual(psutil_val, vmstat_val, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_inactive(self):
+ vmstat_val = vm_stat("inactive")
+ psutil_val = psutil.virtual_memory().inactive
+ self.assertAlmostEqual(psutil_val, vmstat_val, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_vmem_wired(self):
+ vmstat_val = vm_stat("wired")
+ psutil_val = psutil.virtual_memory().wired
+ self.assertAlmostEqual(psutil_val, vmstat_val, delta=TOLERANCE_SYS_MEM)
+
+ # --- swap mem
+
+ @retry_on_failure()
+ def test_swapmem_sin(self):
+ vmstat_val = vm_stat("Pageins")
+ psutil_val = psutil.swap_memory().sin
+ self.assertAlmostEqual(psutil_val, vmstat_val, delta=TOLERANCE_SYS_MEM)
+
+ @retry_on_failure()
+ def test_swapmem_sout(self):
+ vmstat_val = vm_stat("Pageout")
+ psutil_val = psutil.swap_memory().sout
+ self.assertAlmostEqual(psutil_val, vmstat_val, delta=TOLERANCE_SYS_MEM)
+
+ # Not very reliable.
+ # def test_swapmem_total(self):
+ # out = sh('sysctl vm.swapusage')
+ # out = out.replace('vm.swapusage: ', '')
+ # total, used, free = re.findall('\d+.\d+\w', out)
+ # psutil_smem = psutil.swap_memory()
+ # self.assertEqual(psutil_smem.total, human2bytes(total))
+ # self.assertEqual(psutil_smem.used, human2bytes(used))
+ # self.assertEqual(psutil_smem.free, human2bytes(free))
+
+ # --- network
+
+ def test_net_if_stats(self):
+ for name, stats in psutil.net_if_stats().items():
+ try:
+ out = sh("ifconfig %s" % name)
+ except RuntimeError:
+ pass
+ else:
+ self.assertEqual(stats.isup, 'RUNNING' in out, msg=out)
+ self.assertEqual(stats.mtu,
+ int(re.findall(r'mtu (\d+)', out)[0]))
+
+ # --- sensors_battery
+
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_sensors_battery(self):
+ out = sh("pmset -g batt")
+ percent = re.search(r"(\d+)%", out).group(1)
+ drawing_from = re.search("Now drawing from '([^']+)'", out).group(1)
+ power_plugged = drawing_from == "AC Power"
+ psutil_result = psutil.sensors_battery()
+ self.assertEqual(psutil_result.power_plugged, power_plugged)
+ self.assertEqual(psutil_result.percent, int(percent))
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_posix.py b/lib/psutil/tests/test_posix.py
new file mode 100644
index 0000000..d873223
--- /dev/null
+++ b/lib/psutil/tests/test_posix.py
@@ -0,0 +1,432 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""POSIX specific tests."""
+
+import datetime
+import errno
+import os
+import re
+import subprocess
+import time
+import unittest
+
+import psutil
+from psutil import AIX
+from psutil import BSD
+from psutil import LINUX
+from psutil import MACOS
+from psutil import OPENBSD
+from psutil import POSIX
+from psutil import SUNOS
+from psutil.tests import CI_TESTING
+from psutil.tests import HAS_NET_IO_COUNTERS
+from psutil.tests import PYTHON_EXE
+from psutil.tests import PsutilTestCase
+from psutil.tests import mock
+from psutil.tests import retry_on_failure
+from psutil.tests import sh
+from psutil.tests import skip_on_access_denied
+from psutil.tests import spawn_testproc
+from psutil.tests import terminate
+from psutil.tests import which
+
+
+if POSIX:
+ import mmap
+ import resource
+
+ from psutil._psutil_posix import getpagesize
+
+
+def ps(fmt, pid=None):
+ """
+ Wrapper for calling the ps command with a little bit of cross-platform
+ support for a narrow range of features.
+ """
+
+ cmd = ['ps']
+
+ if LINUX:
+ cmd.append('--no-headers')
+
+ if pid is not None:
+ cmd.extend(['-p', str(pid)])
+ else:
+ if SUNOS or AIX:
+ cmd.append('-A')
+ else:
+ cmd.append('ax')
+
+ if SUNOS:
+        fmt_map = {'command': 'comm', 'start': 'stime'}
+        fmt = fmt_map.get(fmt, fmt)
+
+ cmd.extend(['-o', fmt])
+
+ output = sh(cmd)
+
+ if LINUX:
+ output = output.splitlines()
+ else:
+ output = output.splitlines()[1:]
+
+ all_output = []
+ for line in output:
+ line = line.strip()
+
+ try:
+ line = int(line)
+ except ValueError:
+ pass
+
+ all_output.append(line)
+
+ if pid is None:
+ return all_output
+ else:
+ return all_output[0]
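+
+# Example usage (illustrative): ps('pid') returns every PID as a list of
+# ints, while ps('ppid', os.getpid()) returns a single value for that PID.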
+
+# ps "-o" field names differ wildly between platforms.
+# "comm" means "only executable name" but is not available on BSD platforms.
+# "args" means "command with all its arguments", and is also not available
+# on BSD platforms.
+# "command" is like "args" on most platforms, but like "comm" on AIX,
+# and not available on SUNOS.
+# So for the executable name we can use "comm" on Solaris and split "command"
+# on other platforms.
+# To get the cmdline (with args) we have to use "args" on AIX and
+# Solaris, and can use "command" on all others.
+
+
+def ps_name(pid):
+ field = "command"
+ if SUNOS:
+ field = "comm"
+ return ps(field, pid).split()[0]
+
+
+def ps_args(pid):
+ field = "command"
+ if AIX or SUNOS:
+ field = "args"
+ return ps(field, pid)
+
+
+def ps_rss(pid):
+ field = "rss"
+ if AIX:
+ field = "rssize"
+ return ps(field, pid)
+
+
+def ps_vsz(pid):
+ field = "vsz"
+ if AIX:
+ field = "vsize"
+ return ps(field, pid)
+
+
+@unittest.skipIf(not POSIX, "POSIX only")
+class TestProcess(PsutilTestCase):
+ """Compare psutil results against 'ps' command line utility (mainly)."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.pid = spawn_testproc([PYTHON_EXE, "-E", "-O"],
+ stdin=subprocess.PIPE).pid
+
+ @classmethod
+ def tearDownClass(cls):
+ terminate(cls.pid)
+
+ def test_ppid(self):
+ ppid_ps = ps('ppid', self.pid)
+ ppid_psutil = psutil.Process(self.pid).ppid()
+ self.assertEqual(ppid_ps, ppid_psutil)
+
+ def test_uid(self):
+ uid_ps = ps('uid', self.pid)
+ uid_psutil = psutil.Process(self.pid).uids().real
+ self.assertEqual(uid_ps, uid_psutil)
+
+ def test_gid(self):
+ gid_ps = ps('rgid', self.pid)
+ gid_psutil = psutil.Process(self.pid).gids().real
+ self.assertEqual(gid_ps, gid_psutil)
+
+ def test_username(self):
+ username_ps = ps('user', self.pid)
+ username_psutil = psutil.Process(self.pid).username()
+ self.assertEqual(username_ps, username_psutil)
+
+ def test_username_no_resolution(self):
+ # Emulate a case where the system can't resolve the uid to
+ # a username in which case psutil is supposed to return
+ # the stringified uid.
+ p = psutil.Process()
+ with mock.patch("psutil.pwd.getpwuid", side_effect=KeyError) as fun:
+ self.assertEqual(p.username(), str(p.uids().real))
+ assert fun.called
+
+ @skip_on_access_denied()
+ @retry_on_failure()
+ def test_rss_memory(self):
+ # give python interpreter some time to properly initialize
+ # so that the results are the same
+ time.sleep(0.1)
+ rss_ps = ps_rss(self.pid)
+ rss_psutil = psutil.Process(self.pid).memory_info()[0] / 1024
+ self.assertEqual(rss_ps, rss_psutil)
+
+ @skip_on_access_denied()
+ @retry_on_failure()
+ def test_vsz_memory(self):
+ # give python interpreter some time to properly initialize
+ # so that the results are the same
+ time.sleep(0.1)
+ vsz_ps = ps_vsz(self.pid)
+ vsz_psutil = psutil.Process(self.pid).memory_info()[1] / 1024
+ self.assertEqual(vsz_ps, vsz_psutil)
+
+ def test_name(self):
+ name_ps = ps_name(self.pid)
+ # remove path if there is any, from the command
+ name_ps = os.path.basename(name_ps).lower()
+ name_psutil = psutil.Process(self.pid).name().lower()
+ # ...because of how we calculate PYTHON_EXE; on MACOS this may
+ # be "pythonX.Y".
+ name_ps = re.sub(r"\d.\d", "", name_ps)
+ name_psutil = re.sub(r"\d.\d", "", name_psutil)
+ # ...may also be "python.X"
+ name_ps = re.sub(r"\d", "", name_ps)
+ name_psutil = re.sub(r"\d", "", name_psutil)
+ self.assertEqual(name_ps, name_psutil)
+
+ def test_name_long(self):
+ # On UNIX the kernel truncates the name to the first 15
+ # characters. In such a case psutil tries to determine the
+ # full name from the cmdline.
+ name = "long-program-name"
+ cmdline = ["long-program-name-extended", "foo", "bar"]
+ with mock.patch("psutil._psplatform.Process.name",
+ return_value=name):
+ with mock.patch("psutil._psplatform.Process.cmdline",
+ return_value=cmdline):
+ p = psutil.Process()
+ self.assertEqual(p.name(), "long-program-name-extended")
+
+ def test_name_long_cmdline_ad_exc(self):
+ # Same as above but emulates a case where cmdline() raises
+ # AccessDenied in which case psutil is supposed to return
+ # the truncated name instead of crashing.
+ name = "long-program-name"
+ with mock.patch("psutil._psplatform.Process.name",
+ return_value=name):
+ with mock.patch("psutil._psplatform.Process.cmdline",
+ side_effect=psutil.AccessDenied(0, "")):
+ p = psutil.Process()
+ self.assertEqual(p.name(), "long-program-name")
+
+ def test_name_long_cmdline_nsp_exc(self):
+ # Same as above but emulates a case where cmdline() raises NSP
+ # which is supposed to propagate.
+ name = "long-program-name"
+ with mock.patch("psutil._psplatform.Process.name",
+ return_value=name):
+ with mock.patch("psutil._psplatform.Process.cmdline",
+ side_effect=psutil.NoSuchProcess(0, "")):
+ p = psutil.Process()
+ self.assertRaises(psutil.NoSuchProcess, p.name)
+
+ @unittest.skipIf(MACOS or BSD, 'ps -o start not available')
+ def test_create_time(self):
+ time_ps = ps('start', self.pid)
+ time_psutil = psutil.Process(self.pid).create_time()
+ time_psutil_tstamp = datetime.datetime.fromtimestamp(
+ time_psutil).strftime("%H:%M:%S")
+ # sometimes ps shows the time rounded up instead of down, so we check
+ # for both possible values
+ round_time_psutil = round(time_psutil)
+ round_time_psutil_tstamp = datetime.datetime.fromtimestamp(
+ round_time_psutil).strftime("%H:%M:%S")
+ self.assertIn(time_ps, [time_psutil_tstamp, round_time_psutil_tstamp])
+
+ def test_exe(self):
+ ps_pathname = ps_name(self.pid)
+ psutil_pathname = psutil.Process(self.pid).exe()
+ try:
+ self.assertEqual(ps_pathname, psutil_pathname)
+ except AssertionError:
+ # certain platforms such as BSD are more accurate returning:
+ # "/usr/local/bin/python2.7"
+ # ...instead of:
+ # "/usr/local/bin/python"
+ # We do not want to consider this difference in accuracy
+ # an error.
+            adjusted_ps_pathname = psutil_pathname[:len(ps_pathname)]
+ self.assertEqual(ps_pathname, adjusted_ps_pathname)
+
+ # On macOS the official python installer exposes a python wrapper that
+ # executes a python executable hidden inside an application bundle inside
+ # the Python framework.
+ # There's a race condition between the ps call & the psutil call below
+ # depending on the completion of the execve call so let's retry on failure
+ @retry_on_failure()
+ def test_cmdline(self):
+ ps_cmdline = ps_args(self.pid)
+ psutil_cmdline = " ".join(psutil.Process(self.pid).cmdline())
+ self.assertEqual(ps_cmdline, psutil_cmdline)
+
+ # On SUNOS "ps" reads niceness /proc/pid/psinfo which returns an
+ # incorrect value (20); the real deal is getpriority(2) which
+ # returns 0; psutil relies on it, see:
+ # https://github.com/giampaolo/psutil/issues/1082
+ # AIX has the same issue
+ @unittest.skipIf(SUNOS, "not reliable on SUNOS")
+ @unittest.skipIf(AIX, "not reliable on AIX")
+ def test_nice(self):
+ ps_nice = ps('nice', self.pid)
+ psutil_nice = psutil.Process().nice()
+ self.assertEqual(ps_nice, psutil_nice)
+
+
+@unittest.skipIf(not POSIX, "POSIX only")
+class TestSystemAPIs(PsutilTestCase):
+ """Test some system APIs."""
+
+ @retry_on_failure()
+ def test_pids(self):
+ # Note: this test might fail if the OS is starting/killing
+ # other processes in the meantime
+ pids_ps = sorted(ps("pid"))
+ pids_psutil = psutil.pids()
+
+ # on MACOS and OPENBSD ps doesn't show pid 0
+        if (MACOS or OPENBSD) and 0 not in pids_ps:
+ pids_ps.insert(0, 0)
+
+ # There will often be one more process in pids_ps for ps itself
+ if len(pids_ps) - len(pids_psutil) > 1:
+ difference = [x for x in pids_psutil if x not in pids_ps] + \
+ [x for x in pids_ps if x not in pids_psutil]
+ raise self.fail("difference: " + str(difference))
+
+ # for some reason ifconfig -a does not report all interfaces
+ # returned by psutil
+ @unittest.skipIf(SUNOS, "unreliable on SUNOS")
+ @unittest.skipIf(not which('ifconfig'), "no ifconfig cmd")
+ @unittest.skipIf(not HAS_NET_IO_COUNTERS, "not supported")
+ def test_nic_names(self):
+ output = sh("ifconfig -a")
+ for nic in psutil.net_io_counters(pernic=True).keys():
+ for line in output.split():
+ if line.startswith(nic):
+ break
+ else:
+ raise self.fail(
+ "couldn't find %s nic in 'ifconfig -a' output\n%s" % (
+ nic, output))
+
+ @unittest.skipIf(CI_TESTING and not psutil.users(), "unreliable on CI")
+ @retry_on_failure()
+ def test_users(self):
+ out = sh("who")
+ if not out.strip():
+ raise self.skipTest("no users on this system")
+ lines = out.split('\n')
+ users = [x.split()[0] for x in lines]
+ terminals = [x.split()[1] for x in lines]
+ self.assertEqual(len(users), len(psutil.users()))
+ for u in psutil.users():
+ self.assertIn(u.name, users)
+ self.assertIn(u.terminal, terminals)
+
+ def test_pid_exists_let_raise(self):
+ # According to "man 2 kill" possible error values for kill
+ # are (EINVAL, EPERM, ESRCH). Test that any other errno
+ # results in an exception.
+ with mock.patch("psutil._psposix.os.kill",
+ side_effect=OSError(errno.EBADF, "")) as m:
+ self.assertRaises(OSError, psutil._psposix.pid_exists, os.getpid())
+ assert m.called
+
+ def test_os_waitpid_let_raise(self):
+ # os.waitpid() is supposed to catch EINTR and ECHILD only.
+ # Test that any other errno results in an exception.
+ with mock.patch("psutil._psposix.os.waitpid",
+ side_effect=OSError(errno.EBADF, "")) as m:
+ self.assertRaises(OSError, psutil._psposix.wait_pid, os.getpid())
+ assert m.called
+
+ def test_os_waitpid_eintr(self):
+ # os.waitpid() is supposed to "retry" on EINTR.
+ with mock.patch("psutil._psposix.os.waitpid",
+ side_effect=OSError(errno.EINTR, "")) as m:
+ self.assertRaises(
+ psutil._psposix.TimeoutExpired,
+ psutil._psposix.wait_pid, os.getpid(), timeout=0.01)
+ assert m.called
+
+ def test_os_waitpid_bad_ret_status(self):
+ # Simulate os.waitpid() returning a bad status.
+ with mock.patch("psutil._psposix.os.waitpid",
+ return_value=(1, -1)) as m:
+ self.assertRaises(ValueError,
+ psutil._psposix.wait_pid, os.getpid())
+ assert m.called
+
+ # AIX can return '-' in df output instead of numbers, e.g. for /proc
+ @unittest.skipIf(AIX, "unreliable on AIX")
+ @retry_on_failure()
+ def test_disk_usage(self):
+ def df(device):
+ out = sh("df -k %s" % device).strip()
+ line = out.split('\n')[1]
+ fields = line.split()
+ total = int(fields[1]) * 1024
+ used = int(fields[2]) * 1024
+ free = int(fields[3]) * 1024
+ percent = float(fields[4].replace('%', ''))
+ return (total, used, free, percent)
+
+ tolerance = 4 * 1024 * 1024 # 4MB
+ for part in psutil.disk_partitions(all=False):
+ usage = psutil.disk_usage(part.mountpoint)
+ try:
+ total, used, free, percent = df(part.device)
+ except RuntimeError as err:
+ # see:
+ # https://travis-ci.org/giampaolo/psutil/jobs/138338464
+ # https://travis-ci.org/giampaolo/psutil/jobs/138343361
+ err = str(err).lower()
+ if "no such file or directory" in err or \
+ "raw devices not supported" in err or \
+ "permission denied" in err:
+ continue
+ else:
+ raise
+ else:
+ self.assertAlmostEqual(usage.total, total, delta=tolerance)
+ self.assertAlmostEqual(usage.used, used, delta=tolerance)
+ self.assertAlmostEqual(usage.free, free, delta=tolerance)
+ self.assertAlmostEqual(usage.percent, percent, delta=1)
+
+
+@unittest.skipIf(not POSIX, "POSIX only")
+class TestMisc(PsutilTestCase):
+
+ def test_getpagesize(self):
+ pagesize = getpagesize()
+ self.assertGreater(pagesize, 0)
+ self.assertEqual(pagesize, resource.getpagesize())
+ self.assertEqual(pagesize, mmap.PAGESIZE)
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_process.py b/lib/psutil/tests/test_process.py
new file mode 100644
index 0000000..26869e9
--- /dev/null
+++ b/lib/psutil/tests/test_process.py
@@ -0,0 +1,1591 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Tests for psutil.Process class."""
+
+import collections
+import errno
+import getpass
+import itertools
+import os
+import signal
+import socket
+import stat
+import subprocess
+import sys
+import textwrap
+import time
+import types
+import unittest
+
+import psutil
+from psutil import AIX
+from psutil import BSD
+from psutil import LINUX
+from psutil import MACOS
+from psutil import NETBSD
+from psutil import OPENBSD
+from psutil import OSX
+from psutil import POSIX
+from psutil import SUNOS
+from psutil import WINDOWS
+from psutil._common import open_text
+from psutil._compat import PY3
+from psutil._compat import FileNotFoundError
+from psutil._compat import long
+from psutil._compat import super
+from psutil.tests import APPVEYOR
+from psutil.tests import CI_TESTING
+from psutil.tests import GITHUB_ACTIONS
+from psutil.tests import GLOBAL_TIMEOUT
+from psutil.tests import HAS_CPU_AFFINITY
+from psutil.tests import HAS_ENVIRON
+from psutil.tests import HAS_IONICE
+from psutil.tests import HAS_MEMORY_MAPS
+from psutil.tests import HAS_PROC_CPU_NUM
+from psutil.tests import HAS_PROC_IO_COUNTERS
+from psutil.tests import HAS_RLIMIT
+from psutil.tests import HAS_THREADS
+from psutil.tests import MACOS_11PLUS
+from psutil.tests import PYPY
+from psutil.tests import PYTHON_EXE
+from psutil.tests import PsutilTestCase
+from psutil.tests import ThreadTask
+from psutil.tests import call_until
+from psutil.tests import copyload_shared_lib
+from psutil.tests import create_exe
+from psutil.tests import mock
+from psutil.tests import process_namespace
+from psutil.tests import reap_children
+from psutil.tests import retry_on_failure
+from psutil.tests import sh
+from psutil.tests import skip_on_access_denied
+from psutil.tests import skip_on_not_implemented
+from psutil.tests import wait_for_pid
+
+
+# ===================================================================
+# --- psutil.Process class tests
+# ===================================================================
+
+
+class TestProcess(PsutilTestCase):
+ """Tests for psutil.Process class."""
+
+ def spawn_psproc(self, *args, **kwargs):
+ sproc = self.spawn_testproc(*args, **kwargs)
+ return psutil.Process(sproc.pid)
+
+ # ---
+
+ def test_pid(self):
+ p = psutil.Process()
+ self.assertEqual(p.pid, os.getpid())
+ with self.assertRaises(AttributeError):
+ p.pid = 33
+
+ def test_kill(self):
+ p = self.spawn_psproc()
+ p.kill()
+ code = p.wait()
+ if WINDOWS:
+ self.assertEqual(code, signal.SIGTERM)
+ else:
+ self.assertEqual(code, -signal.SIGKILL)
+ self.assertProcessGone(p)
+
+ def test_terminate(self):
+ p = self.spawn_psproc()
+ p.terminate()
+ code = p.wait()
+ if WINDOWS:
+ self.assertEqual(code, signal.SIGTERM)
+ else:
+ self.assertEqual(code, -signal.SIGTERM)
+ self.assertProcessGone(p)
+
+ def test_send_signal(self):
+ sig = signal.SIGKILL if POSIX else signal.SIGTERM
+ p = self.spawn_psproc()
+ p.send_signal(sig)
+ code = p.wait()
+ if WINDOWS:
+ self.assertEqual(code, sig)
+ else:
+ self.assertEqual(code, -sig)
+ self.assertProcessGone(p)
+
+ @unittest.skipIf(not POSIX, "not POSIX")
+ def test_send_signal_mocked(self):
+ sig = signal.SIGTERM
+ p = self.spawn_psproc()
+ with mock.patch('psutil.os.kill',
+ side_effect=OSError(errno.ESRCH, "")):
+ self.assertRaises(psutil.NoSuchProcess, p.send_signal, sig)
+
+ p = self.spawn_psproc()
+ with mock.patch('psutil.os.kill',
+ side_effect=OSError(errno.EPERM, "")):
+ self.assertRaises(psutil.AccessDenied, p.send_signal, sig)
+
+ def test_wait_exited(self):
+ # Test waitpid() + WIFEXITED -> WEXITSTATUS.
+ # normal return, same as exit(0)
+ cmd = [PYTHON_EXE, "-c", "pass"]
+ p = self.spawn_psproc(cmd)
+ code = p.wait()
+ self.assertEqual(code, 0)
+ self.assertProcessGone(p)
+ # exit(1), implicit in case of error
+ cmd = [PYTHON_EXE, "-c", "1 / 0"]
+ p = self.spawn_psproc(cmd, stderr=subprocess.PIPE)
+ code = p.wait()
+ self.assertEqual(code, 1)
+ self.assertProcessGone(p)
+ # via sys.exit()
+ cmd = [PYTHON_EXE, "-c", "import sys; sys.exit(5);"]
+ p = self.spawn_psproc(cmd)
+ code = p.wait()
+ self.assertEqual(code, 5)
+ self.assertProcessGone(p)
+ # via os._exit()
+ cmd = [PYTHON_EXE, "-c", "import os; os._exit(5);"]
+ p = self.spawn_psproc(cmd)
+ code = p.wait()
+ self.assertEqual(code, 5)
+ self.assertProcessGone(p)
+
+ def test_wait_stopped(self):
+ p = self.spawn_psproc()
+ if POSIX:
+ # Test waitpid() + WIFSTOPPED and WIFCONTINUED.
+ # Note: if a process is stopped it ignores SIGTERM.
+ p.send_signal(signal.SIGSTOP)
+ self.assertRaises(psutil.TimeoutExpired, p.wait, timeout=0.001)
+ p.send_signal(signal.SIGCONT)
+ self.assertRaises(psutil.TimeoutExpired, p.wait, timeout=0.001)
+ p.send_signal(signal.SIGTERM)
+ self.assertEqual(p.wait(), -signal.SIGTERM)
+ self.assertEqual(p.wait(), -signal.SIGTERM)
+ else:
+ p.suspend()
+ self.assertRaises(psutil.TimeoutExpired, p.wait, timeout=0.001)
+ p.resume()
+ self.assertRaises(psutil.TimeoutExpired, p.wait, timeout=0.001)
+ p.terminate()
+ self.assertEqual(p.wait(), signal.SIGTERM)
+ self.assertEqual(p.wait(), signal.SIGTERM)
+
+ def test_wait_non_children(self):
+ # Test wait() against a process which is not our direct
+ # child.
+ child, grandchild = self.spawn_children_pair()
+ self.assertRaises(psutil.TimeoutExpired, child.wait, 0.01)
+ self.assertRaises(psutil.TimeoutExpired, grandchild.wait, 0.01)
+ # We also terminate the direct child otherwise the
+ # grandchild will hang until the parent is gone.
+ child.terminate()
+ grandchild.terminate()
+ child_ret = child.wait()
+ grandchild_ret = grandchild.wait()
+ if POSIX:
+ self.assertEqual(child_ret, -signal.SIGTERM)
+ # For processes which are not our children we're supposed
+ # to get None.
+ self.assertEqual(grandchild_ret, None)
+ else:
+ self.assertEqual(child_ret, signal.SIGTERM)
+            self.assertEqual(grandchild_ret, signal.SIGTERM)
+
+ def test_wait_timeout(self):
+ p = self.spawn_psproc()
+ p.name()
+ self.assertRaises(psutil.TimeoutExpired, p.wait, 0.01)
+ self.assertRaises(psutil.TimeoutExpired, p.wait, 0)
+ self.assertRaises(ValueError, p.wait, -1)
+
+ def test_wait_timeout_nonblocking(self):
+ p = self.spawn_psproc()
+ self.assertRaises(psutil.TimeoutExpired, p.wait, 0)
+ p.kill()
+ stop_at = time.time() + GLOBAL_TIMEOUT
+ while time.time() < stop_at:
+ try:
+ code = p.wait(0)
+ break
+ except psutil.TimeoutExpired:
+ pass
+ else:
+ raise self.fail('timeout')
+ if POSIX:
+ self.assertEqual(code, -signal.SIGKILL)
+ else:
+ self.assertEqual(code, signal.SIGTERM)
+ self.assertProcessGone(p)
+
+ def test_cpu_percent(self):
+ p = psutil.Process()
+ p.cpu_percent(interval=0.001)
+ p.cpu_percent(interval=0.001)
+ for x in range(100):
+ percent = p.cpu_percent(interval=None)
+ self.assertIsInstance(percent, float)
+ self.assertGreaterEqual(percent, 0.0)
+ with self.assertRaises(ValueError):
+ p.cpu_percent(interval=-1)
+
+ def test_cpu_percent_numcpus_none(self):
+ # See: https://github.com/giampaolo/psutil/issues/1087
+ with mock.patch('psutil.cpu_count', return_value=None) as m:
+ psutil.Process().cpu_percent()
+ assert m.called
+
+ def test_cpu_times(self):
+ times = psutil.Process().cpu_times()
+ assert (times.user > 0.0) or (times.system > 0.0), times
+ assert (times.children_user >= 0.0), times
+ assert (times.children_system >= 0.0), times
+ if LINUX:
+ assert times.iowait >= 0.0, times
+ # make sure returned values can be pretty printed with strftime
+ for name in times._fields:
+ time.strftime("%H:%M:%S", time.localtime(getattr(times, name)))
+
+ def test_cpu_times_2(self):
+ user_time, kernel_time = psutil.Process().cpu_times()[:2]
+ utime, ktime = os.times()[:2]
+
+ # Use os.times()[:2] as base values to compare our results
+ # using a tolerance of +/- 0.1 seconds.
+ # It will fail if the difference between the values is > 0.1s.
+ if (max([user_time, utime]) - min([user_time, utime])) > 0.1:
+ raise self.fail("expected: %s, found: %s" % (utime, user_time))
+
+ if (max([kernel_time, ktime]) - min([kernel_time, ktime])) > 0.1:
+ raise self.fail("expected: %s, found: %s" % (ktime, kernel_time))
+
+ @unittest.skipIf(not HAS_PROC_CPU_NUM, "not supported")
+ def test_cpu_num(self):
+ p = psutil.Process()
+ num = p.cpu_num()
+ self.assertGreaterEqual(num, 0)
+ if psutil.cpu_count() == 1:
+ self.assertEqual(num, 0)
+ self.assertIn(p.cpu_num(), range(psutil.cpu_count()))
+
+ def test_create_time(self):
+ p = self.spawn_psproc()
+ now = time.time()
+ create_time = p.create_time()
+
+ # Use time.time() as base value to compare our result using a
+        # tolerance of +/- 2 seconds.
+ # It will fail if the difference between the values is > 2s.
+ difference = abs(create_time - now)
+ if difference > 2:
+ raise self.fail("expected: %s, found: %s, difference: %s"
+ % (now, create_time, difference))
+
+ # make sure returned value can be pretty printed with strftime
+ time.strftime("%Y %m %d %H:%M:%S", time.localtime(p.create_time()))
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_terminal(self):
+ terminal = psutil.Process().terminal()
+ if terminal is not None:
+ tty = os.path.realpath(sh('tty'))
+ self.assertEqual(terminal, tty)
+
+ @unittest.skipIf(not HAS_PROC_IO_COUNTERS, 'not supported')
+ @skip_on_not_implemented(only_if=LINUX)
+ def test_io_counters(self):
+ p = psutil.Process()
+ # test reads
+ io1 = p.io_counters()
+ with open(PYTHON_EXE, 'rb') as f:
+ f.read()
+ io2 = p.io_counters()
+ if not BSD and not AIX:
+ self.assertGreater(io2.read_count, io1.read_count)
+ self.assertEqual(io2.write_count, io1.write_count)
+ if LINUX:
+ self.assertGreater(io2.read_chars, io1.read_chars)
+ self.assertEqual(io2.write_chars, io1.write_chars)
+ else:
+ self.assertGreaterEqual(io2.read_bytes, io1.read_bytes)
+ self.assertGreaterEqual(io2.write_bytes, io1.write_bytes)
+
+ # test writes
+ io1 = p.io_counters()
+ with open(self.get_testfn(), 'wb') as f:
+ if PY3:
+ f.write(bytes("x" * 1000000, 'ascii'))
+ else:
+ f.write("x" * 1000000)
+ io2 = p.io_counters()
+ self.assertGreaterEqual(io2.write_count, io1.write_count)
+ self.assertGreaterEqual(io2.write_bytes, io1.write_bytes)
+ self.assertGreaterEqual(io2.read_count, io1.read_count)
+ self.assertGreaterEqual(io2.read_bytes, io1.read_bytes)
+ if LINUX:
+ self.assertGreater(io2.write_chars, io1.write_chars)
+ self.assertGreaterEqual(io2.read_chars, io1.read_chars)
+
+ # sanity check
+ for i in range(len(io2)):
+ if BSD and i >= 2:
+ # On BSD read_bytes and write_bytes are always set to -1.
+ continue
+ self.assertGreaterEqual(io2[i], 0)
+            self.assertGreaterEqual(io1[i], 0)
+
+ @unittest.skipIf(not HAS_IONICE, "not supported")
+ @unittest.skipIf(not LINUX, "linux only")
+ def test_ionice_linux(self):
+ p = psutil.Process()
+ if not CI_TESTING:
+ self.assertEqual(p.ionice()[0], psutil.IOPRIO_CLASS_NONE)
+ self.assertEqual(psutil.IOPRIO_CLASS_NONE, 0)
+ self.assertEqual(psutil.IOPRIO_CLASS_RT, 1) # high
+ self.assertEqual(psutil.IOPRIO_CLASS_BE, 2) # normal
+ self.assertEqual(psutil.IOPRIO_CLASS_IDLE, 3) # low
+ init = p.ionice()
+ try:
+ # low
+ p.ionice(psutil.IOPRIO_CLASS_IDLE)
+ self.assertEqual(tuple(p.ionice()), (psutil.IOPRIO_CLASS_IDLE, 0))
+ with self.assertRaises(ValueError): # accepts no value
+ p.ionice(psutil.IOPRIO_CLASS_IDLE, value=7)
+ # normal
+ p.ionice(psutil.IOPRIO_CLASS_BE)
+ self.assertEqual(tuple(p.ionice()), (psutil.IOPRIO_CLASS_BE, 0))
+ p.ionice(psutil.IOPRIO_CLASS_BE, value=7)
+ self.assertEqual(tuple(p.ionice()), (psutil.IOPRIO_CLASS_BE, 7))
+ with self.assertRaises(ValueError):
+ p.ionice(psutil.IOPRIO_CLASS_BE, value=8)
+ try:
+ p.ionice(psutil.IOPRIO_CLASS_RT, value=7)
+ except psutil.AccessDenied:
+ pass
+ # errs
+ self.assertRaisesRegex(
+ ValueError, "ioclass accepts no value",
+ p.ionice, psutil.IOPRIO_CLASS_NONE, 1)
+ self.assertRaisesRegex(
+ ValueError, "ioclass accepts no value",
+ p.ionice, psutil.IOPRIO_CLASS_IDLE, 1)
+ self.assertRaisesRegex(
+ ValueError, "'ioclass' argument must be specified",
+ p.ionice, value=1)
+ finally:
+ ioclass, value = init
+ if ioclass == psutil.IOPRIO_CLASS_NONE:
+ value = 0
+ p.ionice(ioclass, value)
+
+ @unittest.skipIf(not HAS_IONICE, "not supported")
+ @unittest.skipIf(not WINDOWS, 'not supported on this win version')
+ def test_ionice_win(self):
+ p = psutil.Process()
+ if not CI_TESTING:
+ self.assertEqual(p.ionice(), psutil.IOPRIO_NORMAL)
+ init = p.ionice()
+ try:
+ # base
+ p.ionice(psutil.IOPRIO_VERYLOW)
+ self.assertEqual(p.ionice(), psutil.IOPRIO_VERYLOW)
+ p.ionice(psutil.IOPRIO_LOW)
+ self.assertEqual(p.ionice(), psutil.IOPRIO_LOW)
+ try:
+ p.ionice(psutil.IOPRIO_HIGH)
+ except psutil.AccessDenied:
+ pass
+ else:
+ self.assertEqual(p.ionice(), psutil.IOPRIO_HIGH)
+ # errs
+ self.assertRaisesRegex(
+ TypeError, "value argument not accepted on Windows",
+ p.ionice, psutil.IOPRIO_NORMAL, value=1)
+ self.assertRaisesRegex(
+ ValueError, "is not a valid priority",
+ p.ionice, psutil.IOPRIO_HIGH + 1)
+ finally:
+ p.ionice(init)
+
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit_get(self):
+ import resource
+ p = psutil.Process(os.getpid())
+ names = [x for x in dir(psutil) if x.startswith('RLIMIT')]
+ assert names, names
+ for name in names:
+ value = getattr(psutil, name)
+ self.assertGreaterEqual(value, 0)
+ if name in dir(resource):
+ self.assertEqual(value, getattr(resource, name))
+ # XXX - On PyPy RLIMIT_INFINITY returned by
+ # resource.getrlimit() is reported as a very big long
+ # number instead of -1. It looks like a bug with PyPy.
+ if PYPY:
+ continue
+ self.assertEqual(p.rlimit(value), resource.getrlimit(value))
+ else:
+ ret = p.rlimit(value)
+ self.assertEqual(len(ret), 2)
+ self.assertGreaterEqual(ret[0], -1)
+ self.assertGreaterEqual(ret[1], -1)
+
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit_set(self):
+ p = self.spawn_psproc()
+ p.rlimit(psutil.RLIMIT_NOFILE, (5, 5))
+ self.assertEqual(p.rlimit(psutil.RLIMIT_NOFILE), (5, 5))
+ # If pid is 0 prlimit() applies to the calling process and
+ # we don't want that.
+ if LINUX:
+ with self.assertRaisesRegex(ValueError, "can't use prlimit"):
+ psutil._psplatform.Process(0).rlimit(0)
+ with self.assertRaises(ValueError):
+ p.rlimit(psutil.RLIMIT_NOFILE, (5, 5, 5))
+
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit(self):
+ p = psutil.Process()
+ testfn = self.get_testfn()
+ soft, hard = p.rlimit(psutil.RLIMIT_FSIZE)
+ try:
+ p.rlimit(psutil.RLIMIT_FSIZE, (1024, hard))
+ with open(testfn, "wb") as f:
+ f.write(b"X" * 1024)
+ # write() or flush() doesn't always cause the exception
+ # but close() will.
+ with self.assertRaises(IOError) as exc:
+ with open(testfn, "wb") as f:
+ f.write(b"X" * 1025)
+ self.assertEqual(exc.exception.errno if PY3 else exc.exception[0],
+ errno.EFBIG)
+ finally:
+ p.rlimit(psutil.RLIMIT_FSIZE, (soft, hard))
+ self.assertEqual(p.rlimit(psutil.RLIMIT_FSIZE), (soft, hard))
+
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit_infinity(self):
+ # First set a limit, then re-set it by specifying INFINITY
+        # and assume we've overridden the previous limit.
+ p = psutil.Process()
+ soft, hard = p.rlimit(psutil.RLIMIT_FSIZE)
+ try:
+ p.rlimit(psutil.RLIMIT_FSIZE, (1024, hard))
+ p.rlimit(psutil.RLIMIT_FSIZE, (psutil.RLIM_INFINITY, hard))
+ with open(self.get_testfn(), "wb") as f:
+ f.write(b"X" * 2048)
+ finally:
+ p.rlimit(psutil.RLIMIT_FSIZE, (soft, hard))
+ self.assertEqual(p.rlimit(psutil.RLIMIT_FSIZE), (soft, hard))
+
+ @unittest.skipIf(not HAS_RLIMIT, "not supported")
+ def test_rlimit_infinity_value(self):
+ # RLIMIT_FSIZE should be RLIM_INFINITY, which will be a really
+ # big number on a platform with large file support. On these
+ # platforms we need to test that the get/setrlimit functions
+ # properly convert the number to a C long long and that the
+ # conversion doesn't raise an error.
+ p = psutil.Process()
+ soft, hard = p.rlimit(psutil.RLIMIT_FSIZE)
+ self.assertEqual(psutil.RLIM_INFINITY, hard)
+ p.rlimit(psutil.RLIMIT_FSIZE, (soft, hard))
+
+ def test_num_threads(self):
+ # on certain platforms such as Linux we might test for exact
+        # thread number, since we always have 1 thread per process,
+ # but this does not apply across all platforms (MACOS, Windows)
+ p = psutil.Process()
+ if OPENBSD:
+ try:
+ step1 = p.num_threads()
+ except psutil.AccessDenied:
+ raise unittest.SkipTest("on OpenBSD this requires root access")
+ else:
+ step1 = p.num_threads()
+
+ with ThreadTask():
+ step2 = p.num_threads()
+ self.assertEqual(step2, step1 + 1)
+
+ @unittest.skipIf(not WINDOWS, 'WINDOWS only')
+ def test_num_handles(self):
+ # a better test is done later into test/_windows.py
+ p = psutil.Process()
+ self.assertGreater(p.num_handles(), 0)
+
+ @unittest.skipIf(not HAS_THREADS, 'not supported')
+ def test_threads(self):
+ p = psutil.Process()
+ if OPENBSD:
+ try:
+ step1 = p.threads()
+ except psutil.AccessDenied:
+ raise unittest.SkipTest("on OpenBSD this requires root access")
+ else:
+ step1 = p.threads()
+
+ with ThreadTask():
+ step2 = p.threads()
+ self.assertEqual(len(step2), len(step1) + 1)
+ athread = step2[0]
+ # test named tuple
+ self.assertEqual(athread.id, athread[0])
+ self.assertEqual(athread.user_time, athread[1])
+ self.assertEqual(athread.system_time, athread[2])
+
+ @retry_on_failure()
+ @skip_on_access_denied(only_if=MACOS)
+ @unittest.skipIf(not HAS_THREADS, 'not supported')
+ def test_threads_2(self):
+ p = self.spawn_psproc()
+ if OPENBSD:
+ try:
+ p.threads()
+ except psutil.AccessDenied:
+ raise unittest.SkipTest(
+ "on OpenBSD this requires root access")
+ self.assertAlmostEqual(
+ p.cpu_times().user,
+ sum([x.user_time for x in p.threads()]), delta=0.1)
+ self.assertAlmostEqual(
+ p.cpu_times().system,
+ sum([x.system_time for x in p.threads()]), delta=0.1)
+
+ @retry_on_failure()
+ def test_memory_info(self):
+ p = psutil.Process()
+
+ # step 1 - get a base value to compare our results
+ rss1, vms1 = p.memory_info()[:2]
+ percent1 = p.memory_percent()
+ self.assertGreater(rss1, 0)
+ self.assertGreater(vms1, 0)
+
+ # step 2 - allocate some memory
+ memarr = [None] * 1500000
+
+ rss2, vms2 = p.memory_info()[:2]
+ percent2 = p.memory_percent()
+
+ # step 3 - make sure that the memory usage bumped up
+ self.assertGreater(rss2, rss1)
+ self.assertGreaterEqual(vms2, vms1) # vms might be equal
+ self.assertGreater(percent2, percent1)
+ del memarr
+
+ if WINDOWS:
+ mem = p.memory_info()
+ self.assertEqual(mem.rss, mem.wset)
+ self.assertEqual(mem.vms, mem.pagefile)
+
+ mem = p.memory_info()
+ for name in mem._fields:
+ self.assertGreaterEqual(getattr(mem, name), 0)
+
+ def test_memory_full_info(self):
+ p = psutil.Process()
+ total = psutil.virtual_memory().total
+ mem = p.memory_full_info()
+ for name in mem._fields:
+ value = getattr(mem, name)
+ self.assertGreaterEqual(value, 0, msg=(name, value))
+            if name == 'vms' and (OSX or LINUX):
+ continue
+ self.assertLessEqual(value, total, msg=(name, value, total))
+ if LINUX or WINDOWS or MACOS:
+ self.assertGreaterEqual(mem.uss, 0)
+ if LINUX:
+ self.assertGreaterEqual(mem.pss, 0)
+ self.assertGreaterEqual(mem.swap, 0)
+
+ @unittest.skipIf(not HAS_MEMORY_MAPS, "not supported")
+ def test_memory_maps(self):
+ p = psutil.Process()
+ maps = p.memory_maps()
+ paths = [x for x in maps]
+ self.assertEqual(len(paths), len(set(paths)))
+ ext_maps = p.memory_maps(grouped=False)
+
+ for nt in maps:
+ if not nt.path.startswith('['):
+ assert os.path.isabs(nt.path), nt.path
+ if POSIX:
+ try:
+ assert os.path.exists(nt.path) or \
+ os.path.islink(nt.path), nt.path
+ except AssertionError:
+ if not LINUX:
+ raise
+ else:
+ # https://github.com/giampaolo/psutil/issues/759
+ with open_text('/proc/self/smaps') as f:
+ data = f.read()
+ if "%s (deleted)" % nt.path not in data:
+ raise
+ else:
+ # XXX - On Windows we have this strange behavior with
+ # 64 bit dlls: they are visible via explorer but cannot
+ # be accessed via os.stat() (wtf?).
+ if '64' not in os.path.basename(nt.path):
+ try:
+ st = os.stat(nt.path)
+ except FileNotFoundError:
+ pass
+ else:
+ assert stat.S_ISREG(st.st_mode), nt.path
+ for nt in ext_maps:
+ for fname in nt._fields:
+ value = getattr(nt, fname)
+ if fname == 'path':
+ continue
+ elif fname in ('addr', 'perms'):
+ assert value, value
+ else:
+ self.assertIsInstance(value, (int, long))
+ assert value >= 0, value
+
+ @unittest.skipIf(not HAS_MEMORY_MAPS, "not supported")
+ def test_memory_maps_lists_lib(self):
+ # Make sure a newly loaded shared lib is listed.
+ p = psutil.Process()
+ with copyload_shared_lib() as path:
+ def normpath(p):
+ return os.path.realpath(os.path.normcase(p))
+ libpaths = [normpath(x.path)
+ for x in p.memory_maps()]
+ self.assertIn(normpath(path), libpaths)
+
+ def test_memory_percent(self):
+ p = psutil.Process()
+ p.memory_percent()
+ self.assertRaises(ValueError, p.memory_percent, memtype="?!?")
+ if LINUX or MACOS or WINDOWS:
+ p.memory_percent(memtype='uss')
+
+ def test_is_running(self):
+ p = self.spawn_psproc()
+ assert p.is_running()
+ assert p.is_running()
+ p.kill()
+ p.wait()
+ assert not p.is_running()
+ assert not p.is_running()
+
+ def test_exe(self):
+ p = self.spawn_psproc()
+ exe = p.exe()
+ try:
+ self.assertEqual(exe, PYTHON_EXE)
+ except AssertionError:
+ if WINDOWS and len(exe) == len(PYTHON_EXE):
+ # on Windows we don't care about case sensitivity
+ normcase = os.path.normcase
+ self.assertEqual(normcase(exe), normcase(PYTHON_EXE))
+ else:
+ # certain platforms such as BSD are more accurate returning:
+ # "/usr/local/bin/python2.7"
+ # ...instead of:
+ # "/usr/local/bin/python"
+ # We do not want to consider this difference in accuracy
+ # an error.
+ ver = "%s.%s" % (sys.version_info[0], sys.version_info[1])
+ try:
+ self.assertEqual(exe.replace(ver, ''),
+ PYTHON_EXE.replace(ver, ''))
+ except AssertionError:
+ # Typically MACOS. Really not sure what to do here.
+ pass
+
+ out = sh([exe, "-c", "import os; print('hey')"])
+ self.assertEqual(out, 'hey')
+
+ def test_cmdline(self):
+ cmdline = [PYTHON_EXE, "-c", "import time; time.sleep(60)"]
+ p = self.spawn_psproc(cmdline)
+        # XXX - most of the time the underlying sysctl() call on Net
+ # and Open BSD returns a truncated string.
+ # Also /proc/pid/cmdline behaves the same so it looks
+ # like this is a kernel bug.
+ # XXX - AIX truncates long arguments in /proc/pid/cmdline
+ if NETBSD or OPENBSD or AIX:
+ self.assertEqual(p.cmdline()[0], PYTHON_EXE)
+ else:
+ if MACOS and CI_TESTING:
+ pyexe = p.cmdline()[0]
+ if pyexe != PYTHON_EXE:
+ self.assertEqual(' '.join(p.cmdline()[1:]),
+ ' '.join(cmdline[1:]))
+ return
+ self.assertEqual(' '.join(p.cmdline()), ' '.join(cmdline))
+
+ @unittest.skipIf(PYPY, "broken on PYPY")
+ def test_long_cmdline(self):
+ testfn = self.get_testfn()
+ create_exe(testfn)
+ cmdline = [testfn] + (["0123456789"] * 20)
+ p = self.spawn_psproc(cmdline)
+ self.assertEqual(p.cmdline(), cmdline)
+
+ def test_name(self):
+ p = self.spawn_psproc(PYTHON_EXE)
+ name = p.name().lower()
+ pyexe = os.path.basename(os.path.realpath(sys.executable)).lower()
+ assert pyexe.startswith(name), (pyexe, name)
+
+ @unittest.skipIf(PYPY, "unreliable on PYPY")
+ def test_long_name(self):
+ testfn = self.get_testfn(suffix="0123456789" * 2)
+ create_exe(testfn)
+ p = self.spawn_psproc(testfn)
+ self.assertEqual(p.name(), os.path.basename(testfn))
+
+ # XXX
+ @unittest.skipIf(SUNOS, "broken on SUNOS")
+ @unittest.skipIf(AIX, "broken on AIX")
+ @unittest.skipIf(PYPY, "broken on PYPY")
+ def test_prog_w_funky_name(self):
+ # Test that name(), exe() and cmdline() correctly handle programs
+ # with funky chars such as spaces and ")", see:
+ # https://github.com/giampaolo/psutil/issues/628
+ funky_path = self.get_testfn(suffix='foo bar )')
+ create_exe(funky_path)
+ cmdline = [funky_path, "-c",
+ "import time; [time.sleep(0.01) for x in range(3000)];"
+ "arg1", "arg2", "", "arg3", ""]
+ p = self.spawn_psproc(cmdline)
+ self.assertEqual(p.cmdline(), cmdline)
+ self.assertEqual(p.name(), os.path.basename(funky_path))
+ self.assertEqual(os.path.normcase(p.exe()),
+ os.path.normcase(funky_path))
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_uids(self):
+ p = psutil.Process()
+ real, effective, saved = p.uids()
+ # os.getuid() refers to "real" uid
+ self.assertEqual(real, os.getuid())
+ # os.geteuid() refers to "effective" uid
+ self.assertEqual(effective, os.geteuid())
+ # No such thing as os.getsuid() ("saved" uid), but starting
+ # from python 2.7 we have os.getresuid() which returns all
+ # of them.
+ if hasattr(os, "getresuid"):
+ self.assertEqual(os.getresuid(), p.uids())
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_gids(self):
+ p = psutil.Process()
+ real, effective, saved = p.gids()
+ # os.getuid() refers to "real" uid
+ self.assertEqual(real, os.getgid())
+ # os.geteuid() refers to "effective" uid
+ self.assertEqual(effective, os.getegid())
+ # No such thing as os.getsgid() ("saved" gid), but starting
+ # from python 2.7 we have os.getresgid() which returns all
+ # of them.
+ if hasattr(os, "getresuid"):
+ self.assertEqual(os.getresgid(), p.gids())
+
+ def test_nice(self):
+ p = psutil.Process()
+ self.assertRaises(TypeError, p.nice, "str")
+ init = p.nice()
+ try:
+ if WINDOWS:
+ # A CI runner may limit our maximum priority, which will break
+ # this test. Instead, we test in order of increasing priority,
+ # and match either the expected value or the highest so far.
+ highest_prio = None
+ for prio in [psutil.IDLE_PRIORITY_CLASS,
+ psutil.BELOW_NORMAL_PRIORITY_CLASS,
+ psutil.NORMAL_PRIORITY_CLASS,
+ psutil.ABOVE_NORMAL_PRIORITY_CLASS,
+ psutil.HIGH_PRIORITY_CLASS,
+ psutil.REALTIME_PRIORITY_CLASS]:
+ with self.subTest(prio=prio):
+ try:
+ p.nice(prio)
+ except psutil.AccessDenied:
+ pass
+ else:
+ new_prio = p.nice()
+ if CI_TESTING:
+ if new_prio == prio or highest_prio is None:
+ highest_prio = prio
+ self.assertEqual(new_prio, highest_prio)
+ else:
+ self.assertEqual(new_prio, prio)
+ else:
+ try:
+ if hasattr(os, "getpriority"):
+ self.assertEqual(
+ os.getpriority(os.PRIO_PROCESS, os.getpid()),
+ p.nice())
+ p.nice(1)
+ self.assertEqual(p.nice(), 1)
+ if hasattr(os, "getpriority"):
+ self.assertEqual(
+ os.getpriority(os.PRIO_PROCESS, os.getpid()),
+ p.nice())
+ # XXX - going back to previous nice value raises
+ # AccessDenied on MACOS
+ if not MACOS:
+ p.nice(0)
+ self.assertEqual(p.nice(), 0)
+ except psutil.AccessDenied:
+ pass
+ finally:
+ try:
+ p.nice(init)
+ except psutil.AccessDenied:
+ pass
+
+ def test_status(self):
+ p = psutil.Process()
+ self.assertEqual(p.status(), psutil.STATUS_RUNNING)
+
+ def test_username(self):
+ p = self.spawn_psproc()
+ username = p.username()
+ if WINDOWS:
+ domain, username = username.split('\\')
+ getpass_user = getpass.getuser()
+ if getpass_user.endswith('$'):
+ # When running as a service account (most likely to be
+ # NetworkService), these user name calculations don't produce
+ # the same result, causing the test to fail.
+ raise unittest.SkipTest('running as service account')
+ self.assertEqual(username, getpass_user)
+ if 'USERDOMAIN' in os.environ:
+ self.assertEqual(domain, os.environ['USERDOMAIN'])
+ else:
+ self.assertEqual(username, getpass.getuser())
+
+ def test_cwd(self):
+ p = self.spawn_psproc()
+ self.assertEqual(p.cwd(), os.getcwd())
+
+ def test_cwd_2(self):
+ cmd = [PYTHON_EXE, "-c",
+ "import os, time; os.chdir('..'); time.sleep(60)"]
+ p = self.spawn_psproc(cmd)
+ call_until(p.cwd, "ret == os.path.dirname(os.getcwd())")
+
+ @unittest.skipIf(not HAS_CPU_AFFINITY, 'not supported')
+ def test_cpu_affinity(self):
+ p = psutil.Process()
+ initial = p.cpu_affinity()
+ assert initial, initial
+ self.addCleanup(p.cpu_affinity, initial)
+
+ if hasattr(os, "sched_getaffinity"):
+ self.assertEqual(initial, list(os.sched_getaffinity(p.pid)))
+ self.assertEqual(len(initial), len(set(initial)))
+
+ all_cpus = list(range(len(psutil.cpu_percent(percpu=True))))
+ for n in all_cpus:
+ p.cpu_affinity([n])
+ self.assertEqual(p.cpu_affinity(), [n])
+ if hasattr(os, "sched_getaffinity"):
+ self.assertEqual(p.cpu_affinity(),
+ list(os.sched_getaffinity(p.pid)))
+ # also test num_cpu()
+ if hasattr(p, "num_cpu"):
+ self.assertEqual(p.cpu_affinity()[0], p.num_cpu())
+
+ # [] is an alias for "all eligible CPUs"; on Linux this may
+ # not be equal to all available CPUs, see:
+ # https://github.com/giampaolo/psutil/issues/956
+ p.cpu_affinity([])
+ if LINUX:
+ self.assertEqual(p.cpu_affinity(), p._proc._get_eligible_cpus())
+ else:
+ self.assertEqual(p.cpu_affinity(), all_cpus)
+ if hasattr(os, "sched_getaffinity"):
+ self.assertEqual(p.cpu_affinity(),
+ list(os.sched_getaffinity(p.pid)))
+ #
+ self.assertRaises(TypeError, p.cpu_affinity, 1)
+ p.cpu_affinity(initial)
+ # it should work with all iterables, not only lists
+ p.cpu_affinity(set(all_cpus))
+ p.cpu_affinity(tuple(all_cpus))
+
+ @unittest.skipIf(not HAS_CPU_AFFINITY, 'not supported')
+ def test_cpu_affinity_errs(self):
+ p = self.spawn_psproc()
+ invalid_cpu = [len(psutil.cpu_times(percpu=True)) + 10]
+ self.assertRaises(ValueError, p.cpu_affinity, invalid_cpu)
+ self.assertRaises(ValueError, p.cpu_affinity, range(10000, 11000))
+ self.assertRaises(TypeError, p.cpu_affinity, [0, "1"])
+ self.assertRaises(ValueError, p.cpu_affinity, [0, -1])
+
+ @unittest.skipIf(not HAS_CPU_AFFINITY, 'not supported')
+ def test_cpu_affinity_all_combinations(self):
+ p = psutil.Process()
+ initial = p.cpu_affinity()
+ assert initial, initial
+ self.addCleanup(p.cpu_affinity, initial)
+
+ # All possible CPU set combinations.
+ if len(initial) > 12:
+ initial = initial[:12] # ...otherwise it will take forever
+ combos = []
+ for i in range(0, len(initial) + 1):
+ for subset in itertools.combinations(initial, i):
+ if subset:
+ combos.append(list(subset))
+
+ for combo in combos:
+ p.cpu_affinity(combo)
+ self.assertEqual(sorted(p.cpu_affinity()), sorted(combo))
+
+ # TODO: #595
+ @unittest.skipIf(BSD, "broken on BSD")
+ # can't find any process file on Appveyor
+ @unittest.skipIf(APPVEYOR, "unreliable on APPVEYOR")
+ def test_open_files(self):
+ p = psutil.Process()
+ testfn = self.get_testfn()
+ files = p.open_files()
+ self.assertNotIn(testfn, files)
+ with open(testfn, 'wb') as f:
+ f.write(b'x' * 1024)
+ f.flush()
+ # give the kernel some time to see the new file
+ files = call_until(p.open_files, "len(ret) != %i" % len(files))
+ filenames = [os.path.normcase(x.path) for x in files]
+ self.assertIn(os.path.normcase(testfn), filenames)
+ if LINUX:
+ for file in files:
+ if file.path == testfn:
+ self.assertEqual(file.position, 1024)
+ for file in files:
+ assert os.path.isfile(file.path), file
+
+ # another process
+ cmdline = "import time; f = open(r'%s', 'r'); time.sleep(60);" % testfn
+ p = self.spawn_psproc([PYTHON_EXE, "-c", cmdline])
+
+ for x in range(100):
+ filenames = [os.path.normcase(x.path) for x in p.open_files()]
+ if testfn in filenames:
+ break
+ time.sleep(.01)
+ else:
+ self.assertIn(os.path.normcase(testfn), filenames)
+ for file in filenames:
+ assert os.path.isfile(file), file
+
+ # TODO: #595
+ @unittest.skipIf(BSD, "broken on BSD")
+ # can't find any process file on Appveyor
+ @unittest.skipIf(APPVEYOR, "unreliable on APPVEYOR")
+ def test_open_files_2(self):
+ # test fd and path fields
+ p = psutil.Process()
+ normcase = os.path.normcase
+ testfn = self.get_testfn()
+ with open(testfn, 'w') as fileobj:
+ for file in p.open_files():
+ if normcase(file.path) == normcase(fileobj.name) or \
+ file.fd == fileobj.fileno():
+ break
+ else:
+ raise self.fail("no file found; files=%s" % (
+ repr(p.open_files())))
+ self.assertEqual(normcase(file.path), normcase(fileobj.name))
+ if WINDOWS:
+ self.assertEqual(file.fd, -1)
+ else:
+ self.assertEqual(file.fd, fileobj.fileno())
+ # test positions
+ ntuple = p.open_files()[0]
+ self.assertEqual(ntuple[0], ntuple.path)
+ self.assertEqual(ntuple[1], ntuple.fd)
+ # test file is gone
+ self.assertNotIn(fileobj.name, p.open_files())
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_num_fds(self):
+ p = psutil.Process()
+ testfn = self.get_testfn()
+ start = p.num_fds()
+ file = open(testfn, 'w')
+ self.addCleanup(file.close)
+ self.assertEqual(p.num_fds(), start + 1)
+ sock = socket.socket()
+ self.addCleanup(sock.close)
+ self.assertEqual(p.num_fds(), start + 2)
+ file.close()
+ sock.close()
+ self.assertEqual(p.num_fds(), start)
+
+ @skip_on_not_implemented(only_if=LINUX)
+ @unittest.skipIf(OPENBSD or NETBSD, "not reliable on OPENBSD & NETBSD")
+ def test_num_ctx_switches(self):
+ p = psutil.Process()
+ before = sum(p.num_ctx_switches())
+ for x in range(500000):
+ after = sum(p.num_ctx_switches())
+ if after > before:
+ return
+ raise self.fail(
+ "num ctx switches still the same after 50.000 iterations")
+
+ def test_ppid(self):
+ p = psutil.Process()
+ if hasattr(os, 'getppid'):
+ self.assertEqual(p.ppid(), os.getppid())
+ p = self.spawn_psproc()
+ self.assertEqual(p.ppid(), os.getpid())
+ if APPVEYOR:
+ # Occasional failures, see:
+ # https://ci.appveyor.com/project/giampaolo/psutil/build/
+ # job/0hs623nenj7w4m33
+ return
+
+ def test_parent(self):
+ p = self.spawn_psproc()
+ self.assertEqual(p.parent().pid, os.getpid())
+
+ lowest_pid = psutil.pids()[0]
+ self.assertIsNone(psutil.Process(lowest_pid).parent())
+
+ def test_parent_multi(self):
+ parent = psutil.Process()
+ child, grandchild = self.spawn_children_pair()
+ self.assertEqual(grandchild.parent(), child)
+ self.assertEqual(child.parent(), parent)
+
+ def test_parent_disappeared(self):
+ # Emulate a case where the parent process disappeared.
+ p = self.spawn_psproc()
+ with mock.patch("psutil.Process",
+ side_effect=psutil.NoSuchProcess(0, 'foo')):
+ self.assertIsNone(p.parent())
+
+ @retry_on_failure()
+ def test_parents(self):
+ parent = psutil.Process()
+ assert parent.parents()
+ child, grandchild = self.spawn_children_pair()
+ self.assertEqual(child.parents()[0], parent)
+ self.assertEqual(grandchild.parents()[0], child)
+ self.assertEqual(grandchild.parents()[1], parent)
+
+ def test_children(self):
+ parent = psutil.Process()
+ self.assertEqual(parent.children(), [])
+ self.assertEqual(parent.children(recursive=True), [])
+ # On Windows we set the flag to 0 in order to cancel out the
+ # CREATE_NO_WINDOW flag (enabled by default) which creates
+ # an extra "conhost.exe" child.
+ child = self.spawn_psproc(creationflags=0)
+ children1 = parent.children()
+ children2 = parent.children(recursive=True)
+ for children in (children1, children2):
+ self.assertEqual(len(children), 1)
+ self.assertEqual(children[0].pid, child.pid)
+ self.assertEqual(children[0].ppid(), parent.pid)
+
+ def test_children_recursive(self):
+ # Test children() against two sub processes, p1 and p2, where
+ # p1 (our child) spawned p2 (our grandchild).
+ parent = psutil.Process()
+ child, grandchild = self.spawn_children_pair()
+ self.assertEqual(parent.children(), [child])
+ self.assertEqual(parent.children(recursive=True), [child, grandchild])
+ # If the intermediate process is gone there's no way for
+ # children() to recursively find it.
+ child.terminate()
+ child.wait()
+ self.assertEqual(parent.children(recursive=True), [])
+
+ def test_children_duplicates(self):
+ # find the process which has the highest number of children
+ table = collections.defaultdict(int)
+ for p in psutil.process_iter():
+ try:
+ table[p.ppid()] += 1
+ except psutil.Error:
+ pass
+ # this is the one, now let's make sure there are no duplicates
+ pid = sorted(table.items(), key=lambda x: x[1])[-1][0]
+ if LINUX and pid == 0:
+ raise self.skipTest("PID 0")
+ p = psutil.Process(pid)
+ try:
+ c = p.children(recursive=True)
+ except psutil.AccessDenied: # windows
+ pass
+ else:
+ self.assertEqual(len(c), len(set(c)))
+
+ def test_parents_and_children(self):
+ parent = psutil.Process()
+ child, grandchild = self.spawn_children_pair()
+ # forward
+ children = parent.children(recursive=True)
+ self.assertEqual(len(children), 2)
+ self.assertEqual(children[0], child)
+ self.assertEqual(children[1], grandchild)
+ # backward
+ parents = grandchild.parents()
+ self.assertEqual(parents[0], child)
+ self.assertEqual(parents[1], parent)
+
+ def test_suspend_resume(self):
+ p = self.spawn_psproc()
+ p.suspend()
+ for x in range(100):
+ if p.status() == psutil.STATUS_STOPPED:
+ break
+ time.sleep(0.01)
+ p.resume()
+ self.assertNotEqual(p.status(), psutil.STATUS_STOPPED)
+
+ def test_invalid_pid(self):
+ self.assertRaises(TypeError, psutil.Process, "1")
+ self.assertRaises(ValueError, psutil.Process, -1)
+
+ def test_as_dict(self):
+ p = psutil.Process()
+ d = p.as_dict(attrs=['exe', 'name'])
+ self.assertEqual(sorted(d.keys()), ['exe', 'name'])
+
+ p = psutil.Process(min(psutil.pids()))
+ d = p.as_dict(attrs=['connections'], ad_value='foo')
+ if not isinstance(d['connections'], list):
+ self.assertEqual(d['connections'], 'foo')
+
+ # Test ad_value is set on AccessDenied.
+ with mock.patch('psutil.Process.nice', create=True,
+ side_effect=psutil.AccessDenied):
+ self.assertEqual(
+ p.as_dict(attrs=["nice"], ad_value=1), {"nice": 1})
+
+ # Test that NoSuchProcess bubbles up.
+ with mock.patch('psutil.Process.nice', create=True,
+ side_effect=psutil.NoSuchProcess(p.pid, "name")):
+ self.assertRaises(
+ psutil.NoSuchProcess, p.as_dict, attrs=["nice"])
+
+ # Test that ZombieProcess is swallowed.
+ with mock.patch('psutil.Process.nice', create=True,
+ side_effect=psutil.ZombieProcess(p.pid, "name")):
+ self.assertEqual(
+ p.as_dict(attrs=["nice"], ad_value="foo"), {"nice": "foo"})
+
+ # By default APIs raising NotImplementedError are
+ # supposed to be skipped.
+ with mock.patch('psutil.Process.nice', create=True,
+ side_effect=NotImplementedError):
+ d = p.as_dict()
+ self.assertNotIn('nice', list(d.keys()))
+ # ...unless the user explicitly asked for some attr.
+ with self.assertRaises(NotImplementedError):
+ p.as_dict(attrs=["nice"])
+
+ # errors
+ with self.assertRaises(TypeError):
+ p.as_dict('name')
+ with self.assertRaises(ValueError):
+ p.as_dict(['foo'])
+ with self.assertRaises(ValueError):
+ p.as_dict(['foo', 'bar'])
+
+ def test_oneshot(self):
+ p = psutil.Process()
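+        # within oneshot() the second cpu_times() call should be served from
+        # the per-instance cache, so the platform implementation runs once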
+ with mock.patch("psutil._psplatform.Process.cpu_times") as m:
+ with p.oneshot():
+ p.cpu_times()
+ p.cpu_times()
+ self.assertEqual(m.call_count, 1)
+
+ with mock.patch("psutil._psplatform.Process.cpu_times") as m:
+ p.cpu_times()
+ p.cpu_times()
+ self.assertEqual(m.call_count, 2)
+
+ def test_oneshot_twice(self):
+ # Test the case where the ctx manager is __enter__ed twice.
+        # The second __enter__ is supposed to result in a NOOP.
+ p = psutil.Process()
+ with mock.patch("psutil._psplatform.Process.cpu_times") as m1:
+ with mock.patch("psutil._psplatform.Process.oneshot_enter") as m2:
+ with p.oneshot():
+ p.cpu_times()
+ p.cpu_times()
+ with p.oneshot():
+ p.cpu_times()
+ p.cpu_times()
+ self.assertEqual(m1.call_count, 1)
+ self.assertEqual(m2.call_count, 1)
+
+ with mock.patch("psutil._psplatform.Process.cpu_times") as m:
+ p.cpu_times()
+ p.cpu_times()
+ self.assertEqual(m.call_count, 2)
+
+ def test_oneshot_cache(self):
+ # Make sure oneshot() cache is nonglobal. Instead it's
+ # supposed to be bound to the Process instance, see:
+ # https://github.com/giampaolo/psutil/issues/1373
+ p1, p2 = self.spawn_children_pair()
+ p1_ppid = p1.ppid()
+ p2_ppid = p2.ppid()
+ self.assertNotEqual(p1_ppid, p2_ppid)
+ with p1.oneshot():
+ self.assertEqual(p1.ppid(), p1_ppid)
+ self.assertEqual(p2.ppid(), p2_ppid)
+ with p2.oneshot():
+ self.assertEqual(p1.ppid(), p1_ppid)
+ self.assertEqual(p2.ppid(), p2_ppid)
+
+ def test_halfway_terminated_process(self):
+ # Test that NoSuchProcess exception gets raised in case the
+ # process dies after we create the Process object.
+ # Example:
+ # >>> proc = Process(1234)
+ # >>> time.sleep(2) # time-consuming task, process dies in meantime
+ # >>> proc.name()
+ # Refers to Issue #15
+ def assert_raises_nsp(fun, fun_name):
+ try:
+ ret = fun()
+ except psutil.ZombieProcess: # differentiate from NSP
+ raise
+ except psutil.NoSuchProcess:
+ pass
+ except psutil.AccessDenied:
+ if OPENBSD and fun_name in ('threads', 'num_threads'):
+ return
+ raise
+ else:
+ # NtQuerySystemInformation succeeds even if process is gone.
+ if WINDOWS and fun_name in ('exe', 'name'):
+ return
+ raise self.fail("%r didn't raise NSP and returned %r "
+ "instead" % (fun, ret))
+
+ p = self.spawn_psproc()
+ p.terminate()
+ p.wait()
+ if WINDOWS: # XXX
+ call_until(psutil.pids, "%s not in ret" % p.pid)
+ self.assertProcessGone(p)
+
+ ns = process_namespace(p)
+ for fun, name in ns.iter(ns.all):
+ assert_raises_nsp(fun, name)
+
+ # NtQuerySystemInformation succeeds even if process is gone.
+ if WINDOWS and not GITHUB_ACTIONS:
+ normcase = os.path.normcase
+ self.assertEqual(normcase(p.exe()), normcase(PYTHON_EXE))
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_zombie_process(self):
+ def succeed_or_zombie_p_exc(fun):
+ try:
+ return fun()
+ except (psutil.ZombieProcess, psutil.AccessDenied):
+ pass
+
+ parent, zombie = self.spawn_zombie()
+ # A zombie process should always be instantiable
+ zproc = psutil.Process(zombie.pid)
+        # ...and at least its status should always be queryable
+ self.assertEqual(zproc.status(), psutil.STATUS_ZOMBIE)
+ # ...and it should be considered 'running'
+ assert zproc.is_running()
+ # ...and as_dict() shouldn't crash
+ zproc.as_dict()
+        # ...its parent should 'see' it (edit: not true on BSD and MACOS)
+ # descendants = [x.pid for x in psutil.Process().children(
+ # recursive=True)]
+ # self.assertIn(zpid, descendants)
+        # XXX should we also assume ppid to be usable?  Note: this
+ # would be an important use case as the only way to get
+ # rid of a zombie is to kill its parent.
+ # self.assertEqual(zpid.ppid(), os.getpid())
+ # ...and all other APIs should be able to deal with it
+
+ ns = process_namespace(zproc)
+ for fun, name in ns.iter(ns.all):
+ succeed_or_zombie_p_exc(fun)
+
+ assert psutil.pid_exists(zproc.pid)
+ self.assertIn(zproc.pid, psutil.pids())
+ self.assertIn(zproc.pid, [x.pid for x in psutil.process_iter()])
+ psutil._pmap = {}
+ self.assertIn(zproc.pid, [x.pid for x in psutil.process_iter()])
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_zombie_process_is_running_w_exc(self):
+ # Emulate a case where internally is_running() raises
+ # ZombieProcess.
+ p = psutil.Process()
+ with mock.patch("psutil.Process",
+ side_effect=psutil.ZombieProcess(0)) as m:
+ assert p.is_running()
+ assert m.called
+
+ @unittest.skipIf(not POSIX, 'POSIX only')
+ def test_zombie_process_status_w_exc(self):
+ # Emulate a case where internally status() raises
+ # ZombieProcess.
+ p = psutil.Process()
+ with mock.patch("psutil._psplatform.Process.status",
+ side_effect=psutil.ZombieProcess(0)) as m:
+ self.assertEqual(p.status(), psutil.STATUS_ZOMBIE)
+ assert m.called
+
+ def test_reused_pid(self):
+ # Emulate a case where PID has been reused by another process.
+ subp = self.spawn_testproc()
+ p = psutil.Process(subp.pid)
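+        # psutil identifies a process by its (pid, create_time) pair; faking
+        # a later creation time makes "p" look like a stale handle whose PID
+        # has since been recycled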
+ p._ident = (p.pid, p.create_time() + 100)
+ assert not p.is_running()
+ assert p != psutil.Process(subp.pid)
+ msg = "process no longer exists and its PID has been reused"
+ self.assertRaisesRegex(psutil.NoSuchProcess, msg, p.suspend)
+ self.assertRaisesRegex(psutil.NoSuchProcess, msg, p.resume)
+ self.assertRaisesRegex(psutil.NoSuchProcess, msg, p.terminate)
+ self.assertRaisesRegex(psutil.NoSuchProcess, msg, p.kill)
+ self.assertRaisesRegex(psutil.NoSuchProcess, msg, p.children)
+
+ def test_pid_0(self):
+ # Process(0) is supposed to work on all platforms except Linux
+ if 0 not in psutil.pids():
+ self.assertRaises(psutil.NoSuchProcess, psutil.Process, 0)
+ # These 2 are a contradiction, but "ps" says PID 1's parent
+ # is PID 0.
+ assert not psutil.pid_exists(0)
+ self.assertEqual(psutil.Process(1).ppid(), 0)
+ return
+
+ p = psutil.Process(0)
+ exc = psutil.AccessDenied if WINDOWS else ValueError
+ self.assertRaises(exc, p.wait)
+ self.assertRaises(exc, p.terminate)
+ self.assertRaises(exc, p.suspend)
+ self.assertRaises(exc, p.resume)
+ self.assertRaises(exc, p.kill)
+ self.assertRaises(exc, p.send_signal, signal.SIGTERM)
+
+ # test all methods
+ ns = process_namespace(p)
+ for fun, name in ns.iter(ns.getters + ns.setters):
+ try:
+ ret = fun()
+ except psutil.AccessDenied:
+ pass
+ else:
+ if name in ("uids", "gids"):
+ self.assertEqual(ret.real, 0)
+ elif name == "username":
+ user = 'NT AUTHORITY\\SYSTEM' if WINDOWS else 'root'
+ self.assertEqual(p.username(), user)
+ elif name == "name":
+                    assert ret, ret
+
+ if not OPENBSD:
+ self.assertIn(0, psutil.pids())
+ assert psutil.pid_exists(0)
+
+ @unittest.skipIf(not HAS_ENVIRON, "not supported")
+ def test_environ(self):
+ def clean_dict(d):
+ # Most of these are problematic on Travis.
+ d.pop("PLAT", None)
+ d.pop("HOME", None)
+ if MACOS:
+ d.pop("__CF_USER_TEXT_ENCODING", None)
+ d.pop("VERSIONER_PYTHON_PREFER_32_BIT", None)
+ d.pop("VERSIONER_PYTHON_VERSION", None)
+ return dict(
+ [(k.replace("\r", "").replace("\n", ""),
+ v.replace("\r", "").replace("\n", ""))
+ for k, v in d.items()])
+
+ self.maxDiff = None
+ p = psutil.Process()
+ d1 = clean_dict(p.environ())
+ d2 = clean_dict(os.environ.copy())
+        if not MACOS and GITHUB_ACTIONS:
+ self.assertEqual(d1, d2)
+
+ @unittest.skipIf(not HAS_ENVIRON, "not supported")
+ @unittest.skipIf(not POSIX, "POSIX only")
+ @unittest.skipIf(
+ MACOS_11PLUS,
+ "macOS 11+ can't get another process environment, issue #2084"
+ )
+ def test_weird_environ(self):
+ # environment variables can contain values without an equals sign
+ code = textwrap.dedent("""
+ #include <unistd.h>
+ #include <fcntl.h>
+
+ char * const argv[] = {"cat", 0};
+ char * const envp[] = {"A=1", "X", "C=3", 0};
+
+ int main(void) {
+ // Close stderr on exec so parent can wait for the
+ // execve to finish.
+ if (fcntl(2, F_SETFD, FD_CLOEXEC) != 0)
+ return 0;
+ return execve("/bin/cat", argv, envp);
+ }
+ """)
+ path = self.get_testfn()
+ create_exe(path, c_code=code)
+ sproc = self.spawn_testproc(
+ [path], stdin=subprocess.PIPE, stderr=subprocess.PIPE)
+ p = psutil.Process(sproc.pid)
+ wait_for_pid(p.pid)
+ assert p.is_running()
+ # Wait for process to exec or exit.
+ self.assertEqual(sproc.stderr.read(), b"")
+ if MACOS and CI_TESTING:
+ try:
+ env = p.environ()
+ except psutil.AccessDenied:
+ # XXX: fails sometimes with:
+ # PermissionError from 'sysctl(KERN_PROCARGS2) -> EIO'
+ return
+ else:
+ env = p.environ()
+ self.assertEqual(env, {"A": "1", "C": "3"})
+ sproc.communicate()
+ self.assertEqual(sproc.returncode, 0)
+
+
+# ===================================================================
+# --- Limited user tests
+# ===================================================================
+
+
+if POSIX and os.getuid() == 0:
+
+ class LimitedUserTestCase(TestProcess):
+ """Repeat the previous tests by using a limited user.
+        Executed only on UNIX and only if the user who runs the test script
+ is root.
+ """
+ # the uid/gid the test suite runs under
+ if hasattr(os, 'getuid'):
+ PROCESS_UID = os.getuid()
+ PROCESS_GID = os.getgid()
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # re-define all existent test methods in order to
+ # ignore AccessDenied exceptions
+ for attr in [x for x in dir(self) if x.startswith('test')]:
+ meth = getattr(self, attr)
+
+                def test_(self, meth=meth):
+                    # bind "meth" at definition time; a plain closure would be
+                    # late-bound and every wrapper would call the last method
+                    try:
+                        meth()
+                    except psutil.AccessDenied:
+                        pass
+                setattr(self, attr, types.MethodType(test_, self))
+
+ def setUp(self):
+ super().setUp()
+ os.setegid(1000)
+ os.seteuid(1000)
+
+ def tearDown(self):
+            os.setegid(self.PROCESS_GID)
+            os.seteuid(self.PROCESS_UID)
+ super().tearDown()
+
+ def test_nice(self):
+ try:
+ psutil.Process().nice(-1)
+ except psutil.AccessDenied:
+ pass
+ else:
+ raise self.fail("exception not raised")
+
+        @unittest.skipIf(1, "causes problems as root")
+ def test_zombie_process(self):
+ pass
+
+
+# ===================================================================
+# --- psutil.Popen tests
+# ===================================================================
+
+
+class TestPopen(PsutilTestCase):
+ """Tests for psutil.Popen class."""
+
+ @classmethod
+ def tearDownClass(cls):
+ reap_children()
+
+ def test_misc(self):
+ # XXX this test causes a ResourceWarning on Python 3 because
+ # psutil.__subproc instance doesn't get properly freed.
+ # Not sure what to do though.
+ cmd = [PYTHON_EXE, "-c", "import time; time.sleep(60);"]
+ with psutil.Popen(cmd, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE) as proc:
+ proc.name()
+ proc.cpu_times()
+ proc.stdin
+ self.assertTrue(dir(proc))
+ self.assertRaises(AttributeError, getattr, proc, 'foo')
+ proc.terminate()
+ if POSIX:
+ self.assertEqual(proc.wait(5), -signal.SIGTERM)
+ else:
+ self.assertEqual(proc.wait(5), signal.SIGTERM)
+
+ def test_ctx_manager(self):
+ with psutil.Popen([PYTHON_EXE, "-V"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ stdin=subprocess.PIPE) as proc:
+ proc.communicate()
+ assert proc.stdout.closed
+ assert proc.stderr.closed
+ assert proc.stdin.closed
+ self.assertEqual(proc.returncode, 0)
+
+ def test_kill_terminate(self):
+ # subprocess.Popen()'s terminate(), kill() and send_signal() do
+ # not raise exception after the process is gone. psutil.Popen
+ # diverges from that.
+ cmd = [PYTHON_EXE, "-c", "import time; time.sleep(60);"]
+ with psutil.Popen(cmd, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE) as proc:
+ proc.terminate()
+ proc.wait()
+ self.assertRaises(psutil.NoSuchProcess, proc.terminate)
+ self.assertRaises(psutil.NoSuchProcess, proc.kill)
+ self.assertRaises(psutil.NoSuchProcess, proc.send_signal,
+ signal.SIGTERM)
+ if WINDOWS:
+ self.assertRaises(psutil.NoSuchProcess, proc.send_signal,
+ signal.CTRL_C_EVENT)
+ self.assertRaises(psutil.NoSuchProcess, proc.send_signal,
+ signal.CTRL_BREAK_EVENT)
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_sunos.py b/lib/psutil/tests/test_sunos.py
new file mode 100644
index 0000000..dd74a49
--- /dev/null
+++ b/lib/psutil/tests/test_sunos.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Sun OS specific tests."""
+
+import os
+import unittest
+
+import psutil
+from psutil import SUNOS
+from psutil.tests import PsutilTestCase
+from psutil.tests import sh
+
+
+@unittest.skipIf(not SUNOS, "SUNOS only")
+class SunOSSpecificTestCase(PsutilTestCase):
+
+ def test_swap_memory(self):
+ out = sh('env PATH=/usr/sbin:/sbin:%s swap -l' % os.environ['PATH'])
+ lines = out.strip().split('\n')[1:]
+ if not lines:
+ raise ValueError('no swap device(s) configured')
+ total = free = 0
+ for line in lines:
+ line = line.split()
+ t, f = line[-2:]
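+            # "swap -l" reports sizes in 512-byte blocks; convert to bytes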
+ total += int(int(t) * 512)
+ free += int(int(f) * 512)
+ used = total - free
+
+ psutil_swap = psutil.swap_memory()
+ self.assertEqual(psutil_swap.total, total)
+ self.assertEqual(psutil_swap.used, used)
+ self.assertEqual(psutil_swap.free, free)
+
+ def test_cpu_count(self):
+ out = sh("/usr/sbin/psrinfo")
+ self.assertEqual(psutil.cpu_count(), len(out.split('\n')))
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_system.py b/lib/psutil/tests/test_system.py
new file mode 100644
index 0000000..1722b51
--- /dev/null
+++ b/lib/psutil/tests/test_system.py
@@ -0,0 +1,892 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Tests for system APIS."""
+
+import contextlib
+import datetime
+import errno
+import os
+import platform
+import pprint
+import shutil
+import signal
+import socket
+import sys
+import time
+import unittest
+
+import psutil
+from psutil import AIX
+from psutil import BSD
+from psutil import FREEBSD
+from psutil import LINUX
+from psutil import MACOS
+from psutil import NETBSD
+from psutil import OPENBSD
+from psutil import POSIX
+from psutil import SUNOS
+from psutil import WINDOWS
+from psutil._compat import FileNotFoundError
+from psutil._compat import long
+from psutil.tests import ASCII_FS
+from psutil.tests import CI_TESTING
+from psutil.tests import DEVNULL
+from psutil.tests import GITHUB_ACTIONS
+from psutil.tests import GLOBAL_TIMEOUT
+from psutil.tests import HAS_BATTERY
+from psutil.tests import HAS_CPU_FREQ
+from psutil.tests import HAS_GETLOADAVG
+from psutil.tests import HAS_NET_IO_COUNTERS
+from psutil.tests import HAS_SENSORS_BATTERY
+from psutil.tests import HAS_SENSORS_FANS
+from psutil.tests import HAS_SENSORS_TEMPERATURES
+from psutil.tests import IS_64BIT
+from psutil.tests import MACOS_12PLUS
+from psutil.tests import PYPY
+from psutil.tests import UNICODE_SUFFIX
+from psutil.tests import PsutilTestCase
+from psutil.tests import check_net_address
+from psutil.tests import enum
+from psutil.tests import mock
+from psutil.tests import retry_on_failure
+
+
+# ===================================================================
+# --- System-related API tests
+# ===================================================================
+
+
+class TestProcessAPIs(PsutilTestCase):
+
+ def test_process_iter(self):
+ self.assertIn(os.getpid(), [x.pid for x in psutil.process_iter()])
+ sproc = self.spawn_testproc()
+ self.assertIn(sproc.pid, [x.pid for x in psutil.process_iter()])
+ p = psutil.Process(sproc.pid)
+ p.kill()
+ p.wait()
+ self.assertNotIn(sproc.pid, [x.pid for x in psutil.process_iter()])
+
+ with mock.patch('psutil.Process',
+ side_effect=psutil.NoSuchProcess(os.getpid())):
+ self.assertEqual(list(psutil.process_iter()), [])
+ with mock.patch('psutil.Process',
+ side_effect=psutil.AccessDenied(os.getpid())):
+ with self.assertRaises(psutil.AccessDenied):
+ list(psutil.process_iter())
+
+    def test_process_iter_w_attrs(self):
+ for p in psutil.process_iter(attrs=['pid']):
+ self.assertEqual(list(p.info.keys()), ['pid'])
+ with self.assertRaises(ValueError):
+ list(psutil.process_iter(attrs=['foo']))
+ with mock.patch("psutil._psplatform.Process.cpu_times",
+ side_effect=psutil.AccessDenied(0, "")) as m:
+ for p in psutil.process_iter(attrs=["pid", "cpu_times"]):
+ self.assertIsNone(p.info['cpu_times'])
+ self.assertGreaterEqual(p.info['pid'], 0)
+ assert m.called
+ with mock.patch("psutil._psplatform.Process.cpu_times",
+ side_effect=psutil.AccessDenied(0, "")) as m:
+ flag = object()
+ for p in psutil.process_iter(
+ attrs=["pid", "cpu_times"], ad_value=flag):
+ self.assertIs(p.info['cpu_times'], flag)
+ self.assertGreaterEqual(p.info['pid'], 0)
+ assert m.called
+
+ @unittest.skipIf(PYPY and WINDOWS,
+ "spawn_testproc() unreliable on PYPY + WINDOWS")
+ def test_wait_procs(self):
+ def callback(p):
+ pids.append(p.pid)
+
+ pids = []
+ sproc1 = self.spawn_testproc()
+ sproc2 = self.spawn_testproc()
+ sproc3 = self.spawn_testproc()
+ procs = [psutil.Process(x.pid) for x in (sproc1, sproc2, sproc3)]
+ self.assertRaises(ValueError, psutil.wait_procs, procs, timeout=-1)
+ self.assertRaises(TypeError, psutil.wait_procs, procs, callback=1)
+ t = time.time()
+ gone, alive = psutil.wait_procs(procs, timeout=0.01, callback=callback)
+
+ self.assertLess(time.time() - t, 0.5)
+ self.assertEqual(gone, [])
+ self.assertEqual(len(alive), 3)
+ self.assertEqual(pids, [])
+ for p in alive:
+ self.assertFalse(hasattr(p, 'returncode'))
+
+ @retry_on_failure(30)
+ def test(procs, callback):
+ gone, alive = psutil.wait_procs(procs, timeout=0.03,
+ callback=callback)
+ self.assertEqual(len(gone), 1)
+ self.assertEqual(len(alive), 2)
+ return gone, alive
+
+ sproc3.terminate()
+ gone, alive = test(procs, callback)
+ self.assertIn(sproc3.pid, [x.pid for x in gone])
+ if POSIX:
+ self.assertEqual(gone.pop().returncode, -signal.SIGTERM)
+ else:
+ self.assertEqual(gone.pop().returncode, 1)
+ self.assertEqual(pids, [sproc3.pid])
+ for p in alive:
+ self.assertFalse(hasattr(p, 'returncode'))
+
+ @retry_on_failure(30)
+ def test(procs, callback):
+ gone, alive = psutil.wait_procs(procs, timeout=0.03,
+ callback=callback)
+ self.assertEqual(len(gone), 3)
+ self.assertEqual(len(alive), 0)
+ return gone, alive
+
+ sproc1.terminate()
+ sproc2.terminate()
+ gone, alive = test(procs, callback)
+ self.assertEqual(set(pids), set([sproc1.pid, sproc2.pid, sproc3.pid]))
+ for p in gone:
+ self.assertTrue(hasattr(p, 'returncode'))
+
+ @unittest.skipIf(PYPY and WINDOWS,
+ "spawn_testproc() unreliable on PYPY + WINDOWS")
+ def test_wait_procs_no_timeout(self):
+ sproc1 = self.spawn_testproc()
+ sproc2 = self.spawn_testproc()
+ sproc3 = self.spawn_testproc()
+ procs = [psutil.Process(x.pid) for x in (sproc1, sproc2, sproc3)]
+ for p in procs:
+ p.terminate()
+ gone, alive = psutil.wait_procs(procs)
+
+ def test_pid_exists(self):
+ sproc = self.spawn_testproc()
+ self.assertTrue(psutil.pid_exists(sproc.pid))
+ p = psutil.Process(sproc.pid)
+ p.kill()
+ p.wait()
+ self.assertFalse(psutil.pid_exists(sproc.pid))
+ self.assertFalse(psutil.pid_exists(-1))
+ self.assertEqual(psutil.pid_exists(0), 0 in psutil.pids())
+
+ def test_pid_exists_2(self):
+ pids = psutil.pids()
+ for pid in pids:
+ try:
+ assert psutil.pid_exists(pid)
+ except AssertionError:
+                # in case the process disappeared in the meantime, fail only
+ # if it is no longer in psutil.pids()
+ time.sleep(.1)
+ self.assertNotIn(pid, psutil.pids())
+ pids = range(max(pids) + 5000, max(pids) + 6000)
+ for pid in pids:
+ self.assertFalse(psutil.pid_exists(pid), msg=pid)
+
+
+class TestMiscAPIs(PsutilTestCase):
+
+ def test_boot_time(self):
+ bt = psutil.boot_time()
+ self.assertIsInstance(bt, float)
+ self.assertGreater(bt, 0)
+ self.assertLess(bt, time.time())
+
+ @unittest.skipIf(CI_TESTING and not psutil.users(), "unreliable on CI")
+ def test_users(self):
+ users = psutil.users()
+ self.assertNotEqual(users, [])
+ for user in users:
+ assert user.name, user
+ self.assertIsInstance(user.name, str)
+ self.assertIsInstance(user.terminal, (str, type(None)))
+ if user.host is not None:
+ self.assertIsInstance(user.host, (str, type(None)))
+ user.terminal
+ user.host
+ assert user.started > 0.0, user
+ datetime.datetime.fromtimestamp(user.started)
+ if WINDOWS or OPENBSD:
+ self.assertIsNone(user.pid)
+ else:
+ psutil.Process(user.pid)
+
+ def test_test(self):
+ # test for psutil.test() function
+ stdout = sys.stdout
+ sys.stdout = DEVNULL
+ try:
+ psutil.test()
+ finally:
+ sys.stdout = stdout
+
+ def test_os_constants(self):
+ names = ["POSIX", "WINDOWS", "LINUX", "MACOS", "FREEBSD", "OPENBSD",
+ "NETBSD", "BSD", "SUNOS"]
+ for name in names:
+ self.assertIsInstance(getattr(psutil, name), bool, msg=name)
+
+ if os.name == 'posix':
+ assert psutil.POSIX
+ assert not psutil.WINDOWS
+ names.remove("POSIX")
+ if "linux" in sys.platform.lower():
+ assert psutil.LINUX
+ names.remove("LINUX")
+ elif "bsd" in sys.platform.lower():
+ assert psutil.BSD
+ self.assertEqual([psutil.FREEBSD, psutil.OPENBSD,
+ psutil.NETBSD].count(True), 1)
+ names.remove("BSD")
+ names.remove("FREEBSD")
+ names.remove("OPENBSD")
+ names.remove("NETBSD")
+ elif "sunos" in sys.platform.lower() or \
+ "solaris" in sys.platform.lower():
+ assert psutil.SUNOS
+ names.remove("SUNOS")
+ elif "darwin" in sys.platform.lower():
+ assert psutil.MACOS
+ names.remove("MACOS")
+ else:
+ assert psutil.WINDOWS
+ assert not psutil.POSIX
+ names.remove("WINDOWS")
+
+ # assert all other constants are set to False
+ for name in names:
+ self.assertIs(getattr(psutil, name), False, msg=name)
+
+
+class TestMemoryAPIs(PsutilTestCase):
+
+ def test_virtual_memory(self):
+ mem = psutil.virtual_memory()
+ assert mem.total > 0, mem
+ assert mem.available > 0, mem
+ assert 0 <= mem.percent <= 100, mem
+ assert mem.used > 0, mem
+ assert mem.free >= 0, mem
+ for name in mem._fields:
+ value = getattr(mem, name)
+ if name != 'percent':
+ self.assertIsInstance(value, (int, long))
+ if name != 'total':
+ if not value >= 0:
+ raise self.fail("%r < 0 (%s)" % (name, value))
+ if value > mem.total:
+ raise self.fail("%r > total (total=%s, %s=%s)"
+ % (name, mem.total, name, value))
+
+ def test_swap_memory(self):
+ mem = psutil.swap_memory()
+ self.assertEqual(
+ mem._fields, ('total', 'used', 'free', 'percent', 'sin', 'sout'))
+
+ assert mem.total >= 0, mem
+ assert mem.used >= 0, mem
+        if mem.total > 0:
+            assert mem.free > 0, mem
+        else:
+            # total == 0 likely means a system with no swap partition
+            assert mem.free == 0, mem
+ assert 0 <= mem.percent <= 100, mem
+ assert mem.sin >= 0, mem
+ assert mem.sout >= 0, mem
+
+
+class TestCpuAPIs(PsutilTestCase):
+
+ def test_cpu_count_logical(self):
+ logical = psutil.cpu_count()
+ self.assertIsNotNone(logical)
+ self.assertEqual(logical, len(psutil.cpu_times(percpu=True)))
+ self.assertGreaterEqual(logical, 1)
+ #
+ if os.path.exists("/proc/cpuinfo"):
+ with open("/proc/cpuinfo") as fd:
+ cpuinfo_data = fd.read()
+ if "physical id" not in cpuinfo_data:
+ raise unittest.SkipTest("cpuinfo doesn't include physical id")
+
+ def test_cpu_count_cores(self):
+ logical = psutil.cpu_count()
+ cores = psutil.cpu_count(logical=False)
+ if cores is None:
+ raise self.skipTest("cpu_count_cores() is None")
+        if WINDOWS and sys.getwindowsversion()[:2] <= (6, 1):  # <= Windows 7
+ self.assertIsNone(cores)
+ else:
+ self.assertGreaterEqual(cores, 1)
+ self.assertGreaterEqual(logical, cores)
+
+ def test_cpu_count_none(self):
+ # https://github.com/giampaolo/psutil/issues/1085
+ for val in (-1, 0, None):
+ with mock.patch('psutil._psplatform.cpu_count_logical',
+ return_value=val) as m:
+ self.assertIsNone(psutil.cpu_count())
+ assert m.called
+ with mock.patch('psutil._psplatform.cpu_count_cores',
+ return_value=val) as m:
+ self.assertIsNone(psutil.cpu_count(logical=False))
+ assert m.called
+
+ def test_cpu_times(self):
+ # Check type, value >= 0, str().
+ total = 0
+ times = psutil.cpu_times()
+ sum(times)
+ for cp_time in times:
+ self.assertIsInstance(cp_time, float)
+ self.assertGreaterEqual(cp_time, 0.0)
+ total += cp_time
+ self.assertEqual(total, sum(times))
+ str(times)
+ # CPU times are always supposed to increase over time
+ # or at least remain the same and that's because time
+ # cannot go backwards.
+ # Surprisingly sometimes this might not be the case (at
+ # least on Windows and Linux), see:
+ # https://github.com/giampaolo/psutil/issues/392
+ # https://github.com/giampaolo/psutil/issues/645
+ # if not WINDOWS:
+ # last = psutil.cpu_times()
+ # for x in range(100):
+ # new = psutil.cpu_times()
+ # for field in new._fields:
+ # new_t = getattr(new, field)
+ # last_t = getattr(last, field)
+ # self.assertGreaterEqual(new_t, last_t,
+ # msg="%s %s" % (new_t, last_t))
+ # last = new
+
+ def test_cpu_times_time_increases(self):
+ # Make sure time increases between calls.
+ t1 = sum(psutil.cpu_times())
+ stop_at = time.time() + GLOBAL_TIMEOUT
+ while time.time() < stop_at:
+ t2 = sum(psutil.cpu_times())
+ if t2 > t1:
+ return
+ raise self.fail("time remained the same")
+
+ def test_per_cpu_times(self):
+ # Check type, value >= 0, str().
+ for times in psutil.cpu_times(percpu=True):
+ total = 0
+ sum(times)
+ for cp_time in times:
+ self.assertIsInstance(cp_time, float)
+ self.assertGreaterEqual(cp_time, 0.0)
+ total += cp_time
+ self.assertEqual(total, sum(times))
+ str(times)
+ self.assertEqual(len(psutil.cpu_times(percpu=True)[0]),
+ len(psutil.cpu_times(percpu=False)))
+
+ # Note: in theory CPU times are always supposed to increase over
+ # time or remain the same but never go backwards. In practice
+ # sometimes this is not the case.
+        # This issue seemed to afflict Windows:
+ # https://github.com/giampaolo/psutil/issues/392
+ # ...but it turns out also Linux (rarely) behaves the same.
+ # last = psutil.cpu_times(percpu=True)
+ # for x in range(100):
+ # new = psutil.cpu_times(percpu=True)
+ # for index in range(len(new)):
+ # newcpu = new[index]
+ # lastcpu = last[index]
+ # for field in newcpu._fields:
+ # new_t = getattr(newcpu, field)
+ # last_t = getattr(lastcpu, field)
+ # self.assertGreaterEqual(
+ # new_t, last_t, msg="%s %s" % (lastcpu, newcpu))
+ # last = new
+
+ def test_per_cpu_times_2(self):
+        # Simulate some workload, then make sure CPU times have
+        # increased between calls.
+ tot1 = psutil.cpu_times(percpu=True)
+ giveup_at = time.time() + GLOBAL_TIMEOUT
+ while True:
+ if time.time() >= giveup_at:
+ return self.fail("timeout")
+ tot2 = psutil.cpu_times(percpu=True)
+ for t1, t2 in zip(tot1, tot2):
+ t1, t2 = psutil._cpu_busy_time(t1), psutil._cpu_busy_time(t2)
+ difference = t2 - t1
+ if difference >= 0.05:
+ return
+
+ def test_cpu_times_comparison(self):
+ # Make sure the sum of all per cpu times is almost equal to
+ # base "one cpu" times.
+ base = psutil.cpu_times()
+ per_cpu = psutil.cpu_times(percpu=True)
+ summed_values = base._make([sum(num) for num in zip(*per_cpu)])
+ for field in base._fields:
+ self.assertAlmostEqual(
+ getattr(base, field), getattr(summed_values, field), delta=1)
+
+ def _test_cpu_percent(self, percent, last_ret, new_ret):
+ try:
+ self.assertIsInstance(percent, float)
+ self.assertGreaterEqual(percent, 0.0)
+ self.assertIsNot(percent, -0.0)
+ self.assertLessEqual(percent, 100.0 * psutil.cpu_count())
+ except AssertionError as err:
+ raise AssertionError("\n%s\nlast=%s\nnew=%s" % (
+ err, pprint.pformat(last_ret), pprint.pformat(new_ret)))
+
+ def test_cpu_percent(self):
+ last = psutil.cpu_percent(interval=0.001)
+ for x in range(100):
+ new = psutil.cpu_percent(interval=None)
+ self._test_cpu_percent(new, last, new)
+ last = new
+ with self.assertRaises(ValueError):
+ psutil.cpu_percent(interval=-1)
+
+ def test_per_cpu_percent(self):
+ last = psutil.cpu_percent(interval=0.001, percpu=True)
+ self.assertEqual(len(last), psutil.cpu_count())
+ for x in range(100):
+ new = psutil.cpu_percent(interval=None, percpu=True)
+ for percent in new:
+ self._test_cpu_percent(percent, last, new)
+ last = new
+ with self.assertRaises(ValueError):
+ psutil.cpu_percent(interval=-1, percpu=True)
+
+ def test_cpu_times_percent(self):
+ last = psutil.cpu_times_percent(interval=0.001)
+ for x in range(100):
+ new = psutil.cpu_times_percent(interval=None)
+ for percent in new:
+ self._test_cpu_percent(percent, last, new)
+ self._test_cpu_percent(sum(new), last, new)
+ last = new
+ with self.assertRaises(ValueError):
+ psutil.cpu_times_percent(interval=-1)
+
+ def test_per_cpu_times_percent(self):
+ last = psutil.cpu_times_percent(interval=0.001, percpu=True)
+ self.assertEqual(len(last), psutil.cpu_count())
+ for x in range(100):
+ new = psutil.cpu_times_percent(interval=None, percpu=True)
+ for cpu in new:
+ for percent in cpu:
+ self._test_cpu_percent(percent, last, new)
+ self._test_cpu_percent(sum(cpu), last, new)
+ last = new
+
+ def test_per_cpu_times_percent_negative(self):
+ # see: https://github.com/giampaolo/psutil/issues/645
+ psutil.cpu_times_percent(percpu=True)
+        zero_times = [x._make([0] * len(x._fields))
+                      for x in psutil.cpu_times(percpu=True)]
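+        # an all-zero sample must produce 0.0 percentages, never negative ones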
+ with mock.patch('psutil.cpu_times', return_value=zero_times):
+ for cpu in psutil.cpu_times_percent(percpu=True):
+ for percent in cpu:
+ self._test_cpu_percent(percent, None, None)
+
+ def test_cpu_stats(self):
+ # Tested more extensively in per-platform test modules.
+ infos = psutil.cpu_stats()
+ self.assertEqual(
+ infos._fields,
+ ('ctx_switches', 'interrupts', 'soft_interrupts', 'syscalls'))
+ for name in infos._fields:
+ value = getattr(infos, name)
+ self.assertGreaterEqual(value, 0)
+ # on AIX, ctx_switches is always 0
+ if not AIX and name in ('ctx_switches', 'interrupts'):
+ self.assertGreater(value, 0)
+
+ # TODO: remove this once 1892 is fixed
+ @unittest.skipIf(MACOS and platform.machine() == 'arm64',
+ "skipped due to #1892")
+ @unittest.skipIf(not HAS_CPU_FREQ, "not supported")
+ def test_cpu_freq(self):
+ def check_ls(ls):
+ for nt in ls:
+ self.assertEqual(nt._fields, ('current', 'min', 'max'))
+ if nt.max != 0.0:
+ self.assertLessEqual(nt.current, nt.max)
+ for name in nt._fields:
+ value = getattr(nt, name)
+ self.assertIsInstance(value, (int, long, float))
+ self.assertGreaterEqual(value, 0)
+
+ ls = psutil.cpu_freq(percpu=True)
+ if FREEBSD and not ls:
+ raise self.skipTest("returns empty list on FreeBSD")
+
+ assert ls, ls
+ check_ls([psutil.cpu_freq(percpu=False)])
+
+ if LINUX:
+ self.assertEqual(len(ls), psutil.cpu_count())
+
+ @unittest.skipIf(not HAS_GETLOADAVG, "not supported")
+ def test_getloadavg(self):
+ loadavg = psutil.getloadavg()
+ self.assertEqual(len(loadavg), 3)
+ for load in loadavg:
+ self.assertIsInstance(load, float)
+ self.assertGreaterEqual(load, 0.0)
+
+
+class TestDiskAPIs(PsutilTestCase):
+
+ @unittest.skipIf(PYPY and not IS_64BIT, "unreliable on PYPY32 + 32BIT")
+ def test_disk_usage(self):
+ usage = psutil.disk_usage(os.getcwd())
+ self.assertEqual(usage._fields, ('total', 'used', 'free', 'percent'))
+
+ assert usage.total > 0, usage
+ assert usage.used > 0, usage
+ assert usage.free > 0, usage
+ assert usage.total > usage.used, usage
+ assert usage.total > usage.free, usage
+ assert 0 <= usage.percent <= 100, usage.percent
+ if hasattr(shutil, 'disk_usage'):
+ # py >= 3.3, see: http://bugs.python.org/issue12442
+ shutil_usage = shutil.disk_usage(os.getcwd())
+ tolerance = 5 * 1024 * 1024 # 5MB
+ self.assertEqual(usage.total, shutil_usage.total)
+ self.assertAlmostEqual(usage.free, shutil_usage.free,
+ delta=tolerance)
+ if not MACOS_12PLUS:
+ # see https://github.com/giampaolo/psutil/issues/2147
+ self.assertAlmostEqual(usage.used, shutil_usage.used,
+ delta=tolerance)
+
+ # if path does not exist OSError ENOENT is expected across
+ # all platforms
+ fname = self.get_testfn()
+ with self.assertRaises(FileNotFoundError):
+ psutil.disk_usage(fname)
+
+ @unittest.skipIf(not ASCII_FS, "not an ASCII fs")
+ def test_disk_usage_unicode(self):
+ # See: https://github.com/giampaolo/psutil/issues/416
+ with self.assertRaises(UnicodeEncodeError):
+ psutil.disk_usage(UNICODE_SUFFIX)
+
+ def test_disk_usage_bytes(self):
+ psutil.disk_usage(b'.')
+
+ def test_disk_partitions(self):
+ def check_ntuple(nt):
+ self.assertIsInstance(nt.device, str)
+ self.assertIsInstance(nt.mountpoint, str)
+ self.assertIsInstance(nt.fstype, str)
+ self.assertIsInstance(nt.opts, str)
+ self.assertIsInstance(nt.maxfile, (int, type(None)))
+ self.assertIsInstance(nt.maxpath, (int, type(None)))
+ if nt.maxfile is not None and not GITHUB_ACTIONS:
+ self.assertGreater(nt.maxfile, 0)
+ if nt.maxpath is not None:
+ self.assertGreater(nt.maxpath, 0)
+
+ # all = False
+ ls = psutil.disk_partitions(all=False)
+ self.assertTrue(ls, msg=ls)
+ for disk in ls:
+ check_ntuple(disk)
+ if WINDOWS and 'cdrom' in disk.opts:
+ continue
+ if not POSIX:
+ assert os.path.exists(disk.device), disk
+ else:
+ # we cannot make any assumption about this, see:
+ # http://goo.gl/p9c43
+ disk.device
+ # on modern systems mount points can also be files
+ assert os.path.exists(disk.mountpoint), disk
+ assert disk.fstype, disk
+
+ # all = True
+ ls = psutil.disk_partitions(all=True)
+ self.assertTrue(ls, msg=ls)
+ for disk in psutil.disk_partitions(all=True):
+ check_ntuple(disk)
+ if not WINDOWS and disk.mountpoint:
+ try:
+ os.stat(disk.mountpoint)
+ except OSError as err:
+ if GITHUB_ACTIONS and MACOS and err.errno == errno.EIO:
+ continue
+ # http://mail.python.org/pipermail/python-dev/
+ # 2012-June/120787.html
+ if err.errno not in (errno.EPERM, errno.EACCES):
+ raise
+ else:
+ assert os.path.exists(disk.mountpoint), disk
+
+ # ---
+
+ def find_mount_point(path):
+ path = os.path.abspath(path)
+ while not os.path.ismount(path):
+ path = os.path.dirname(path)
+ return path.lower()
+
+ mount = find_mount_point(__file__)
+ mounts = [x.mountpoint.lower() for x in
+ psutil.disk_partitions(all=True) if x.mountpoint]
+ self.assertIn(mount, mounts)
+
+ @unittest.skipIf(LINUX and not os.path.exists('/proc/diskstats'),
+ '/proc/diskstats not available on this linux version')
+ @unittest.skipIf(CI_TESTING and not psutil.disk_io_counters(),
+ "unreliable on CI") # no visible disks
+ def test_disk_io_counters(self):
+ def check_ntuple(nt):
+ self.assertEqual(nt[0], nt.read_count)
+ self.assertEqual(nt[1], nt.write_count)
+ self.assertEqual(nt[2], nt.read_bytes)
+ self.assertEqual(nt[3], nt.write_bytes)
+ if not (OPENBSD or NETBSD):
+ self.assertEqual(nt[4], nt.read_time)
+ self.assertEqual(nt[5], nt.write_time)
+ if LINUX:
+ self.assertEqual(nt[6], nt.read_merged_count)
+ self.assertEqual(nt[7], nt.write_merged_count)
+ self.assertEqual(nt[8], nt.busy_time)
+ elif FREEBSD:
+ self.assertEqual(nt[6], nt.busy_time)
+ for name in nt._fields:
+ assert getattr(nt, name) >= 0, nt
+
+ ret = psutil.disk_io_counters(perdisk=False)
+ assert ret is not None, "no disks on this system?"
+ check_ntuple(ret)
+ ret = psutil.disk_io_counters(perdisk=True)
+ # make sure there are no duplicates
+ self.assertEqual(len(ret), len(set(ret)))
+ for key in ret:
+ assert key, key
+ check_ntuple(ret[key])
+
+ def test_disk_io_counters_no_disks(self):
+ # Emulate a case where no disks are installed, see:
+ # https://github.com/giampaolo/psutil/issues/1062
+ with mock.patch('psutil._psplatform.disk_io_counters',
+ return_value={}) as m:
+ self.assertIsNone(psutil.disk_io_counters(perdisk=False))
+ self.assertEqual(psutil.disk_io_counters(perdisk=True), {})
+ assert m.called
+
+
+class TestNetAPIs(PsutilTestCase):
+
+ @unittest.skipIf(not HAS_NET_IO_COUNTERS, 'not supported')
+ def test_net_io_counters(self):
+ def check_ntuple(nt):
+ self.assertEqual(nt[0], nt.bytes_sent)
+ self.assertEqual(nt[1], nt.bytes_recv)
+ self.assertEqual(nt[2], nt.packets_sent)
+ self.assertEqual(nt[3], nt.packets_recv)
+ self.assertEqual(nt[4], nt.errin)
+ self.assertEqual(nt[5], nt.errout)
+ self.assertEqual(nt[6], nt.dropin)
+ self.assertEqual(nt[7], nt.dropout)
+ assert nt.bytes_sent >= 0, nt
+ assert nt.bytes_recv >= 0, nt
+ assert nt.packets_sent >= 0, nt
+ assert nt.packets_recv >= 0, nt
+ assert nt.errin >= 0, nt
+ assert nt.errout >= 0, nt
+ assert nt.dropin >= 0, nt
+ assert nt.dropout >= 0, nt
+
+ ret = psutil.net_io_counters(pernic=False)
+ check_ntuple(ret)
+ ret = psutil.net_io_counters(pernic=True)
+ self.assertNotEqual(ret, [])
+ for key in ret:
+ self.assertTrue(key)
+ self.assertIsInstance(key, str)
+ check_ntuple(ret[key])
+
+ @unittest.skipIf(not HAS_NET_IO_COUNTERS, 'not supported')
+ def test_net_io_counters_no_nics(self):
+ # Emulate a case where no NICs are installed, see:
+ # https://github.com/giampaolo/psutil/issues/1062
+ with mock.patch('psutil._psplatform.net_io_counters',
+ return_value={}) as m:
+ self.assertIsNone(psutil.net_io_counters(pernic=False))
+ self.assertEqual(psutil.net_io_counters(pernic=True), {})
+ assert m.called
+
+ def test_net_if_addrs(self):
+ nics = psutil.net_if_addrs()
+ assert nics, nics
+
+ nic_stats = psutil.net_if_stats()
+
+ # Not reliable on all platforms (net_if_addrs() reports more
+ # interfaces).
+ # self.assertEqual(sorted(nics.keys()),
+ # sorted(psutil.net_io_counters(pernic=True).keys()))
+
+ families = set([socket.AF_INET, socket.AF_INET6, psutil.AF_LINK])
+ for nic, addrs in nics.items():
+ self.assertIsInstance(nic, str)
+ self.assertEqual(len(set(addrs)), len(addrs))
+ for addr in addrs:
+ self.assertIsInstance(addr.family, int)
+ self.assertIsInstance(addr.address, str)
+ self.assertIsInstance(addr.netmask, (str, type(None)))
+ self.assertIsInstance(addr.broadcast, (str, type(None)))
+ self.assertIn(addr.family, families)
+ if sys.version_info >= (3, 4) and not PYPY:
+ self.assertIsInstance(addr.family, enum.IntEnum)
+ if nic_stats[nic].isup:
+ # Do not test binding to addresses of interfaces
+ # that are down
+ if addr.family == socket.AF_INET:
+ s = socket.socket(addr.family)
+ with contextlib.closing(s):
+ s.bind((addr.address, 0))
+ elif addr.family == socket.AF_INET6:
+ info = socket.getaddrinfo(
+ addr.address, 0, socket.AF_INET6,
+ socket.SOCK_STREAM, 0, socket.AI_PASSIVE)[0]
+ af, socktype, proto, canonname, sa = info
+ s = socket.socket(af, socktype, proto)
+ with contextlib.closing(s):
+ s.bind(sa)
+ for ip in (addr.address, addr.netmask, addr.broadcast,
+ addr.ptp):
+ if ip is not None:
+ # TODO: skip AF_INET6 for now because I get:
+ # AddressValueError: Only hex digits permitted in
+ # u'c6f3%lxcbr0' in u'fe80::c8e0:fff:fe54:c6f3%lxcbr0'
+ if addr.family != socket.AF_INET6:
+ check_net_address(ip, addr.family)
+ # broadcast and ptp addresses are mutually exclusive
+ if addr.broadcast:
+ self.assertIsNone(addr.ptp)
+ elif addr.ptp:
+ self.assertIsNone(addr.broadcast)
+
+ if BSD or MACOS or SUNOS:
+ if hasattr(socket, "AF_LINK"):
+ self.assertEqual(psutil.AF_LINK, socket.AF_LINK)
+ elif LINUX:
+ self.assertEqual(psutil.AF_LINK, socket.AF_PACKET)
+ elif WINDOWS:
+ self.assertEqual(psutil.AF_LINK, -1)
+
+ def test_net_if_addrs_mac_null_bytes(self):
+ # Simulate that the underlying C function returns an incomplete
+ # MAC address. psutil is supposed to fill it with null bytes.
+ # https://github.com/giampaolo/psutil/issues/786
+ if POSIX:
+ ret = [('em1', psutil.AF_LINK, '06:3d:29', None, None, None)]
+ else:
+ ret = [('em1', -1, '06-3d-29', None, None, None)]
+ with mock.patch('psutil._psplatform.net_if_addrs',
+ return_value=ret) as m:
+ addr = psutil.net_if_addrs()['em1'][0]
+ assert m.called
+ if POSIX:
+ self.assertEqual(addr.address, '06:3d:29:00:00:00')
+ else:
+ self.assertEqual(addr.address, '06-3d-29-00-00-00')
+
+ def test_net_if_stats(self):
+ nics = psutil.net_if_stats()
+ assert nics, nics
+ all_duplexes = (psutil.NIC_DUPLEX_FULL,
+ psutil.NIC_DUPLEX_HALF,
+ psutil.NIC_DUPLEX_UNKNOWN)
+ for name, stats in nics.items():
+ self.assertIsInstance(name, str)
+ isup, duplex, speed, mtu, flags = stats
+ self.assertIsInstance(isup, bool)
+            self.assertIn(duplex, all_duplexes)
+ self.assertGreaterEqual(speed, 0)
+ self.assertGreaterEqual(mtu, 0)
+ self.assertIsInstance(flags, str)
+
+ @unittest.skipIf(not (LINUX or BSD or MACOS),
+ "LINUX or BSD or MACOS specific")
+ def test_net_if_stats_enodev(self):
+ # See: https://github.com/giampaolo/psutil/issues/1279
+ with mock.patch('psutil._psutil_posix.net_if_mtu',
+ side_effect=OSError(errno.ENODEV, "")) as m:
+ ret = psutil.net_if_stats()
+ self.assertEqual(ret, {})
+ assert m.called
+
+
+class TestSensorsAPIs(PsutilTestCase):
+
+ @unittest.skipIf(not HAS_SENSORS_TEMPERATURES, "not supported")
+ def test_sensors_temperatures(self):
+ temps = psutil.sensors_temperatures()
+ for name, entries in temps.items():
+ self.assertIsInstance(name, str)
+ for entry in entries:
+ self.assertIsInstance(entry.label, str)
+ if entry.current is not None:
+ self.assertGreaterEqual(entry.current, 0)
+ if entry.high is not None:
+ self.assertGreaterEqual(entry.high, 0)
+ if entry.critical is not None:
+ self.assertGreaterEqual(entry.critical, 0)
+
+ @unittest.skipIf(not HAS_SENSORS_TEMPERATURES, "not supported")
+    def test_sensors_temperatures_fahrenheit(self):
+ d = {'coretemp': [('label', 50.0, 60.0, 70.0)]}
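+        # F = C * 9 / 5 + 32, hence 50.0 -> 122.0, 60.0 -> 140.0, 70.0 -> 158.0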
+ with mock.patch("psutil._psplatform.sensors_temperatures",
+ return_value=d) as m:
+ temps = psutil.sensors_temperatures(
+ fahrenheit=True)['coretemp'][0]
+ assert m.called
+ self.assertEqual(temps.current, 122.0)
+ self.assertEqual(temps.high, 140.0)
+ self.assertEqual(temps.critical, 158.0)
+
+ @unittest.skipIf(not HAS_SENSORS_BATTERY, "not supported")
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_sensors_battery(self):
+ ret = psutil.sensors_battery()
+ self.assertGreaterEqual(ret.percent, 0)
+ self.assertLessEqual(ret.percent, 100)
+ if ret.secsleft not in (psutil.POWER_TIME_UNKNOWN,
+ psutil.POWER_TIME_UNLIMITED):
+ self.assertGreaterEqual(ret.secsleft, 0)
+ else:
+ if ret.secsleft == psutil.POWER_TIME_UNLIMITED:
+ self.assertTrue(ret.power_plugged)
+ self.assertIsInstance(ret.power_plugged, bool)
+
+ @unittest.skipIf(not HAS_SENSORS_FANS, "not supported")
+ def test_sensors_fans(self):
+ fans = psutil.sensors_fans()
+ for name, entries in fans.items():
+ self.assertIsInstance(name, str)
+ for entry in entries:
+ self.assertIsInstance(entry.label, str)
+ self.assertIsInstance(entry.current, (int, long))
+ self.assertGreaterEqual(entry.current, 0)
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_testutils.py b/lib/psutil/tests/test_testutils.py
new file mode 100644
index 0000000..dd98538
--- /dev/null
+++ b/lib/psutil/tests/test_testutils.py
@@ -0,0 +1,441 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Tests for testing utils (psutil.tests namespace).
+"""
+
+import collections
+import contextlib
+import errno
+import os
+import socket
+import stat
+import subprocess
+import unittest
+
+import psutil
+import psutil.tests
+from psutil import FREEBSD
+from psutil import NETBSD
+from psutil import POSIX
+from psutil._common import open_binary
+from psutil._common import open_text
+from psutil._common import supports_ipv6
+from psutil.tests import CI_TESTING
+from psutil.tests import HAS_CONNECTIONS_UNIX
+from psutil.tests import PYTHON_EXE
+from psutil.tests import PsutilTestCase
+from psutil.tests import TestMemoryLeak
+from psutil.tests import bind_socket
+from psutil.tests import bind_unix_socket
+from psutil.tests import call_until
+from psutil.tests import chdir
+from psutil.tests import create_sockets
+from psutil.tests import get_free_port
+from psutil.tests import is_namedtuple
+from psutil.tests import mock
+from psutil.tests import process_namespace
+from psutil.tests import reap_children
+from psutil.tests import retry
+from psutil.tests import retry_on_failure
+from psutil.tests import safe_mkdir
+from psutil.tests import safe_rmpath
+from psutil.tests import serialrun
+from psutil.tests import system_namespace
+from psutil.tests import tcp_socketpair
+from psutil.tests import terminate
+from psutil.tests import unix_socketpair
+from psutil.tests import wait_for_file
+from psutil.tests import wait_for_pid
+
+
+# ===================================================================
+# --- Unit tests for test utilities.
+# ===================================================================
+
+
+class TestRetryDecorator(PsutilTestCase):
+
+ @mock.patch('time.sleep')
+ def test_retry_success(self, sleep):
+ # Fail 3 times out of 5; make sure the decorated fun returns.
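+        # (each call pops one sentinel from "queue" and raises; once the
+        # queue is empty foo() finally returns 1)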
+
+ @retry(retries=5, interval=1, logfun=None)
+ def foo():
+ while queue:
+ queue.pop()
+ 1 / 0
+ return 1
+
+ queue = list(range(3))
+ self.assertEqual(foo(), 1)
+ self.assertEqual(sleep.call_count, 3)
+
+ @mock.patch('time.sleep')
+ def test_retry_failure(self, sleep):
+        # Fail 6 times out of 5; the function is supposed to raise the exc.
+ @retry(retries=5, interval=1, logfun=None)
+ def foo():
+ while queue:
+ queue.pop()
+ 1 / 0
+ return 1
+
+ queue = list(range(6))
+ self.assertRaises(ZeroDivisionError, foo)
+ self.assertEqual(sleep.call_count, 5)
+
+ @mock.patch('time.sleep')
+ def test_exception_arg(self, sleep):
+ @retry(exception=ValueError, interval=1)
+ def foo():
+ raise TypeError
+
+ self.assertRaises(TypeError, foo)
+ self.assertEqual(sleep.call_count, 0)
+
+ @mock.patch('time.sleep')
+ def test_no_interval_arg(self, sleep):
+ # if interval is not specified sleep is not supposed to be called
+
+ @retry(retries=5, interval=None, logfun=None)
+ def foo():
+ 1 / 0
+
+ self.assertRaises(ZeroDivisionError, foo)
+ self.assertEqual(sleep.call_count, 0)
+
+ @mock.patch('time.sleep')
+ def test_retries_arg(self, sleep):
+
+ @retry(retries=5, interval=1, logfun=None)
+ def foo():
+ 1 / 0
+
+ self.assertRaises(ZeroDivisionError, foo)
+ self.assertEqual(sleep.call_count, 5)
+
+ @mock.patch('time.sleep')
+ def test_retries_and_timeout_args(self, sleep):
+ self.assertRaises(ValueError, retry, retries=5, timeout=1)
+
+
+class TestSyncTestUtils(PsutilTestCase):
+
+ def test_wait_for_pid(self):
+ wait_for_pid(os.getpid())
+ nopid = max(psutil.pids()) + 99999
+ with mock.patch('psutil.tests.retry.__iter__', return_value=iter([0])):
+ self.assertRaises(psutil.NoSuchProcess, wait_for_pid, nopid)
+
+ def test_wait_for_file(self):
+ testfn = self.get_testfn()
+ with open(testfn, 'w') as f:
+ f.write('foo')
+ wait_for_file(testfn)
+ assert not os.path.exists(testfn)
+
+ def test_wait_for_file_empty(self):
+ testfn = self.get_testfn()
+ with open(testfn, 'w'):
+ pass
+ wait_for_file(testfn, empty=True)
+ assert not os.path.exists(testfn)
+
+ def test_wait_for_file_no_file(self):
+ testfn = self.get_testfn()
+ with mock.patch('psutil.tests.retry.__iter__', return_value=iter([0])):
+ self.assertRaises(IOError, wait_for_file, testfn)
+
+ def test_wait_for_file_no_delete(self):
+ testfn = self.get_testfn()
+ with open(testfn, 'w') as f:
+ f.write('foo')
+ wait_for_file(testfn, delete=False)
+ assert os.path.exists(testfn)
+
+ def test_call_until(self):
+ ret = call_until(lambda: 1, "ret == 1")
+ self.assertEqual(ret, 1)
+
+
+class TestFSTestUtils(PsutilTestCase):
+
+ def test_open_text(self):
+ with open_text(__file__) as f:
+ self.assertEqual(f.mode, 'rt')
+
+ def test_open_binary(self):
+ with open_binary(__file__) as f:
+ self.assertEqual(f.mode, 'rb')
+
+ def test_safe_mkdir(self):
+ testfn = self.get_testfn()
+ safe_mkdir(testfn)
+ assert os.path.isdir(testfn)
+ safe_mkdir(testfn)
+ assert os.path.isdir(testfn)
+
+ def test_safe_rmpath(self):
+ # test file is removed
+ testfn = self.get_testfn()
+ open(testfn, 'w').close()
+ safe_rmpath(testfn)
+ assert not os.path.exists(testfn)
+ # test no exception if path does not exist
+ safe_rmpath(testfn)
+ # test dir is removed
+ os.mkdir(testfn)
+ safe_rmpath(testfn)
+ assert not os.path.exists(testfn)
+ # test other exceptions are raised
+ with mock.patch('psutil.tests.os.stat',
+ side_effect=OSError(errno.EINVAL, "")) as m:
+ with self.assertRaises(OSError):
+ safe_rmpath(testfn)
+ assert m.called
+
+ def test_chdir(self):
+ testfn = self.get_testfn()
+ base = os.getcwd()
+ os.mkdir(testfn)
+ with chdir(testfn):
+ self.assertEqual(os.getcwd(), os.path.join(base, testfn))
+ self.assertEqual(os.getcwd(), base)
+
+
+class TestProcessUtils(PsutilTestCase):
+
+ def test_reap_children(self):
+ subp = self.spawn_testproc()
+ p = psutil.Process(subp.pid)
+ assert p.is_running()
+ reap_children()
+ assert not p.is_running()
+ assert not psutil.tests._pids_started
+ assert not psutil.tests._subprocesses_started
+
+ def test_spawn_children_pair(self):
+ child, grandchild = self.spawn_children_pair()
+ self.assertNotEqual(child.pid, grandchild.pid)
+ assert child.is_running()
+ assert grandchild.is_running()
+ children = psutil.Process().children()
+ self.assertEqual(children, [child])
+ children = psutil.Process().children(recursive=True)
+ self.assertEqual(len(children), 2)
+ self.assertIn(child, children)
+ self.assertIn(grandchild, children)
+ self.assertEqual(child.ppid(), os.getpid())
+ self.assertEqual(grandchild.ppid(), child.pid)
+
+ terminate(child)
+ assert not child.is_running()
+ assert grandchild.is_running()
+
+ terminate(grandchild)
+ assert not grandchild.is_running()
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ def test_spawn_zombie(self):
+ parent, zombie = self.spawn_zombie()
+ self.assertEqual(zombie.status(), psutil.STATUS_ZOMBIE)
+
+ def test_terminate(self):
+ # by subprocess.Popen
+ p = self.spawn_testproc()
+ terminate(p)
+ self.assertProcessGone(p)
+ terminate(p)
+ # by psutil.Process
+ p = psutil.Process(self.spawn_testproc().pid)
+ terminate(p)
+ self.assertProcessGone(p)
+ terminate(p)
+ # by psutil.Popen
+ cmd = [PYTHON_EXE, "-c", "import time; time.sleep(60);"]
+ p = psutil.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ terminate(p)
+ self.assertProcessGone(p)
+ terminate(p)
+ # by PID
+ pid = self.spawn_testproc().pid
+ terminate(pid)
+ self.assertProcessGone(p)
+ terminate(pid)
+ # zombie
+ if POSIX:
+ parent, zombie = self.spawn_zombie()
+ terminate(parent)
+ terminate(zombie)
+ self.assertProcessGone(parent)
+ self.assertProcessGone(zombie)
+
+
+class TestNetUtils(PsutilTestCase):
+
+    def test_bind_socket(self):
+ port = get_free_port()
+ with contextlib.closing(bind_socket(addr=('', port))) as s:
+ self.assertEqual(s.getsockname()[1], port)
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ def test_bind_unix_socket(self):
+ name = self.get_testfn()
+ sock = bind_unix_socket(name)
+ with contextlib.closing(sock):
+ self.assertEqual(sock.family, socket.AF_UNIX)
+ self.assertEqual(sock.type, socket.SOCK_STREAM)
+ self.assertEqual(sock.getsockname(), name)
+ assert os.path.exists(name)
+ assert stat.S_ISSOCK(os.stat(name).st_mode)
+ # UDP
+ name = self.get_testfn()
+ sock = bind_unix_socket(name, type=socket.SOCK_DGRAM)
+ with contextlib.closing(sock):
+ self.assertEqual(sock.type, socket.SOCK_DGRAM)
+
+    def test_tcp_socketpair(self):
+ addr = ("127.0.0.1", get_free_port())
+ server, client = tcp_socketpair(socket.AF_INET, addr=addr)
+ with contextlib.closing(server):
+ with contextlib.closing(client):
+ # Ensure they are connected and the positions are
+ # correct.
+ self.assertEqual(server.getsockname(), addr)
+ self.assertEqual(client.getpeername(), addr)
+ self.assertNotEqual(client.getsockname(), addr)
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ @unittest.skipIf(NETBSD or FREEBSD,
+ "/var/run/log UNIX socket opened by default")
+ def test_unix_socketpair(self):
+ p = psutil.Process()
+ num_fds = p.num_fds()
+ assert not p.connections(kind='unix')
+ name = self.get_testfn()
+ server, client = unix_socketpair(name)
+ try:
+ assert os.path.exists(name)
+ assert stat.S_ISSOCK(os.stat(name).st_mode)
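+            # the two extra fds are the server and client ends of the pair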
+ self.assertEqual(p.num_fds() - num_fds, 2)
+ self.assertEqual(len(p.connections(kind='unix')), 2)
+ self.assertEqual(server.getsockname(), name)
+ self.assertEqual(client.getpeername(), name)
+ finally:
+ client.close()
+ server.close()
+
+ def test_create_sockets(self):
+ with create_sockets() as socks:
+ fams = collections.defaultdict(int)
+ types = collections.defaultdict(int)
+ for s in socks:
+ fams[s.family] += 1
+ # work around http://bugs.python.org/issue30204
+ types[s.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE)] += 1
+ self.assertGreaterEqual(fams[socket.AF_INET], 2)
+ if supports_ipv6():
+ self.assertGreaterEqual(fams[socket.AF_INET6], 2)
+ if POSIX and HAS_CONNECTIONS_UNIX:
+ self.assertGreaterEqual(fams[socket.AF_UNIX], 2)
+ self.assertGreaterEqual(types[socket.SOCK_STREAM], 2)
+ self.assertGreaterEqual(types[socket.SOCK_DGRAM], 2)
+
+
+@serialrun
+class TestMemLeakClass(TestMemoryLeak):
+
+ @retry_on_failure()
+ def test_times(self):
+ def fun():
+ cnt['cnt'] += 1
+ cnt = {'cnt': 0}
+ self.execute(fun, times=10, warmup_times=15)
+ self.assertEqual(cnt['cnt'], 26)
+
+ def test_param_err(self):
+ self.assertRaises(ValueError, self.execute, lambda: 0, times=0)
+ self.assertRaises(ValueError, self.execute, lambda: 0, times=-1)
+ self.assertRaises(ValueError, self.execute, lambda: 0, warmup_times=-1)
+ self.assertRaises(ValueError, self.execute, lambda: 0, tolerance=-1)
+ self.assertRaises(ValueError, self.execute, lambda: 0, retries=-1)
+
+ @retry_on_failure()
+ @unittest.skipIf(CI_TESTING, "skipped on CI")
+ def test_leak_mem(self):
+ ls = []
+
+ def fun(ls=ls):
+ ls.append("x" * 24 * 1024)
+
+ try:
+ # will consume around 3M in total
+ self.assertRaisesRegex(AssertionError, "extra-mem",
+ self.execute, fun, times=50)
+ finally:
+ del ls
+
+ def test_unclosed_files(self):
+ def fun():
+ f = open(__file__)
+ self.addCleanup(f.close)
+ box.append(f)
+
+ box = []
+ kind = "fd" if POSIX else "handle"
+ self.assertRaisesRegex(AssertionError, "unclosed " + kind,
+ self.execute, fun)
+
+ def test_tolerance(self):
+ def fun():
+ ls.append("x" * 24 * 1024)
+ ls = []
+ times = 100
+ self.execute(fun, times=times, warmup_times=0,
+ tolerance=200 * 1024 * 1024)
+ self.assertEqual(len(ls), times + 1)
+
+ def test_execute_w_exc(self):
+ def fun():
+ 1 / 0
+ self.execute_w_exc(ZeroDivisionError, fun)
+ with self.assertRaises(ZeroDivisionError):
+ self.execute_w_exc(OSError, fun)
+
+ def fun():
+ pass
+ with self.assertRaises(AssertionError):
+ self.execute_w_exc(ZeroDivisionError, fun)
+
+
+class TestTestingUtils(PsutilTestCase):
+
+ def test_process_namespace(self):
+ p = psutil.Process()
+ ns = process_namespace(p)
+ ns.test()
+ fun = [x for x in ns.iter(ns.getters) if x[1] == 'ppid'][0][0]
+ self.assertEqual(fun(), p.ppid())
+
+ def test_system_namespace(self):
+ ns = system_namespace()
+ fun = [x for x in ns.iter(ns.getters) if x[1] == 'net_if_addrs'][0][0]
+ self.assertEqual(fun(), psutil.net_if_addrs())
+
+
+class TestOtherUtils(PsutilTestCase):
+
+ def test_is_namedtuple(self):
+ assert is_namedtuple(collections.namedtuple('foo', 'a b c')(1, 2, 3))
+ assert not is_namedtuple(tuple())
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_unicode.py b/lib/psutil/tests/test_unicode.py
new file mode 100644
index 0000000..3fa3f01
--- /dev/null
+++ b/lib/psutil/tests/test_unicode.py
@@ -0,0 +1,355 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""
+Notes about unicode handling in psutil
+======================================
+
+Starting from version 5.3.0 psutil adds unicode support, see:
+https://github.com/giampaolo/psutil/issues/1040
+The notes below apply to *any* API returning a string such as
+process exe(), cwd() or username():
+
+* all strings are encoded by using the OS filesystem encoding
+ (sys.getfilesystemencoding()) which varies depending on the platform
+ (e.g. "UTF-8" on macOS, "mbcs" on Win)
+* no API call is supposed to crash with UnicodeDecodeError
+* instead, in case of badly encoded data returned by the OS, the
+ following error handlers are used to replace the corrupted characters in
+ the string:
+ * Python 3: sys.getfilesystemencodeerrors() (PY 3.6+) or
+ "surrogatescape" on POSIX and "replace" on Windows
+ * Python 2: "replace"
+* on Python 2 all APIs return bytes (str type), never unicode
+* on Python 2, you can go back to unicode by doing:
+
+ >>> unicode(p.exe(), sys.getdefaultencoding(), errors="replace")
+
+For a detailed explanation of how psutil handles unicode see #1040.
+
+Tests
+=====
+
+List of APIs returning or dealing with a string:
+('not tested' means they are not tested for non-ASCII string handling):
+
+* Process.cmdline()
+* Process.connections('unix')
+* Process.cwd()
+* Process.environ()
+* Process.exe()
+* Process.memory_maps()
+* Process.name()
+* Process.open_files()
+* Process.username() (not tested)
+
+* disk_io_counters() (not tested)
+* disk_partitions() (not tested)
+* disk_usage(str)
+* net_connections('unix')
+* net_if_addrs() (not tested)
+* net_if_stats() (not tested)
+* net_io_counters() (not tested)
+* sensors_fans() (not tested)
+* sensors_temperatures() (not tested)
+* users() (not tested)
+
+* WindowsService.binpath() (not tested)
+* WindowsService.description() (not tested)
+* WindowsService.display_name() (not tested)
+* WindowsService.name() (not tested)
+* WindowsService.status() (not tested)
+* WindowsService.username() (not tested)
+
+Here we create a unicode path with a funky non-ASCII name and (where
+possible) make psutil return it back (e.g. from name(), exe(), open_files(),
+etc.), and make sure that:
+
+* psutil never crashes with UnicodeDecodeError
+* the returned path matches
+"""
+
+import os
+import shutil
+import traceback
+import unittest
+import warnings
+from contextlib import closing
+
+import psutil
+from psutil import BSD
+from psutil import OPENBSD
+from psutil import POSIX
+from psutil import WINDOWS
+from psutil._compat import PY3
+from psutil._compat import u
+from psutil.tests import APPVEYOR
+from psutil.tests import ASCII_FS
+from psutil.tests import CI_TESTING
+from psutil.tests import HAS_CONNECTIONS_UNIX
+from psutil.tests import HAS_ENVIRON
+from psutil.tests import HAS_MEMORY_MAPS
+from psutil.tests import INVALID_UNICODE_SUFFIX
+from psutil.tests import PYPY
+from psutil.tests import TESTFN_PREFIX
+from psutil.tests import UNICODE_SUFFIX
+from psutil.tests import PsutilTestCase
+from psutil.tests import bind_unix_socket
+from psutil.tests import chdir
+from psutil.tests import copyload_shared_lib
+from psutil.tests import create_exe
+from psutil.tests import get_testfn
+from psutil.tests import safe_mkdir
+from psutil.tests import safe_rmpath
+from psutil.tests import serialrun
+from psutil.tests import skip_on_access_denied
+from psutil.tests import spawn_testproc
+from psutil.tests import terminate
+
+
+if APPVEYOR:
+ def safe_rmpath(path): # NOQA
+ # TODO - this is quite random and I'm not sure why it happens,
+ # nor can I reproduce it locally:
+ # https://ci.appveyor.com/project/giampaolo/psutil/build/job/
+ # jiq2cgd6stsbtn60
+ # safe_rmpath() happens after reap_children() so this is weird
+ # Perhaps wait_procs() on Windows is broken? Maybe because
+ # of STILL_ACTIVE?
+ # https://github.com/giampaolo/psutil/blob/
+ # 68c7a70728a31d8b8b58f4be6c4c0baa2f449eda/psutil/arch/
+ # windows/process_info.c#L146
+ from psutil.tests import safe_rmpath as rm
+ try:
+ return rm(path)
+ except WindowsError:
+ traceback.print_exc()
+
+
+def try_unicode(suffix):
+ """Return True if both the fs and the subprocess module can
+ deal with a unicode file name.
+ """
+ sproc = None
+ testfn = get_testfn(suffix=suffix)
+ try:
+ safe_rmpath(testfn)
+ create_exe(testfn)
+ sproc = spawn_testproc(cmd=[testfn])
+ shutil.copyfile(testfn, testfn + '-2')
+ safe_rmpath(testfn + '-2')
+ except (UnicodeEncodeError, IOError):
+ return False
+ else:
+ return True
+ finally:
+ if sproc is not None:
+ terminate(sproc)
+ safe_rmpath(testfn)
+
+
+# ===================================================================
+# FS APIs
+# ===================================================================
+
+
+class BaseUnicodeTest(PsutilTestCase):
+ funky_suffix = None
+
+ def setUp(self):
+ if self.funky_suffix is not None:
+ if not try_unicode(self.funky_suffix):
+ raise self.skipTest("can't handle unicode str")
+
+
+@serialrun
+@unittest.skipIf(ASCII_FS, "ASCII fs")
+@unittest.skipIf(PYPY and not PY3, "too much trouble on PYPY2")
+class TestFSAPIs(BaseUnicodeTest):
+ """Test FS APIs with a funky, valid, UTF8 path name."""
+
+ funky_suffix = UNICODE_SUFFIX
+
+ @classmethod
+ def setUpClass(cls):
+ cls.funky_name = get_testfn(suffix=cls.funky_suffix)
+ create_exe(cls.funky_name)
+
+ @classmethod
+ def tearDownClass(cls):
+ safe_rmpath(cls.funky_name)
+
+ def expect_exact_path_match(self):
+ # Do not expect psutil to correctly handle unicode paths on
+ # Python 2 if os.listdir() cannot either.
+ here = '.' if isinstance(self.funky_name, str) else u('.')
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ return self.funky_name in os.listdir(here)
+
+ # ---
+
+ def test_proc_exe(self):
+ subp = self.spawn_testproc(cmd=[self.funky_name])
+ p = psutil.Process(subp.pid)
+ exe = p.exe()
+ self.assertIsInstance(exe, str)
+ if self.expect_exact_path_match():
+ self.assertEqual(os.path.normcase(exe),
+ os.path.normcase(self.funky_name))
+
+ def test_proc_name(self):
+ subp = self.spawn_testproc(cmd=[self.funky_name])
+ name = psutil.Process(subp.pid).name()
+ self.assertIsInstance(name, str)
+ if self.expect_exact_path_match():
+ self.assertEqual(name, os.path.basename(self.funky_name))
+
+ def test_proc_cmdline(self):
+ subp = self.spawn_testproc(cmd=[self.funky_name])
+ p = psutil.Process(subp.pid)
+ cmdline = p.cmdline()
+ for part in cmdline:
+ self.assertIsInstance(part, str)
+ if self.expect_exact_path_match():
+ self.assertEqual(cmdline, [self.funky_name])
+
+ def test_proc_cwd(self):
+ dname = self.funky_name + "2"
+ self.addCleanup(safe_rmpath, dname)
+ safe_mkdir(dname)
+ with chdir(dname):
+ p = psutil.Process()
+ cwd = p.cwd()
+ self.assertIsInstance(p.cwd(), str)
+ if self.expect_exact_path_match():
+ self.assertEqual(cwd, dname)
+
+ @unittest.skipIf(PYPY and WINDOWS, "fails on PYPY + WINDOWS")
+ def test_proc_open_files(self):
+ p = psutil.Process()
+ start = set(p.open_files())
+ with open(self.funky_name, 'rb'):
+ new = set(p.open_files())
+ path = (new - start).pop().path
+ self.assertIsInstance(path, str)
+ if BSD and not path:
+ # XXX - see https://github.com/giampaolo/psutil/issues/595
+ return self.skipTest("open_files on BSD is broken")
+ if self.expect_exact_path_match():
+ self.assertEqual(os.path.normcase(path),
+ os.path.normcase(self.funky_name))
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ def test_proc_connections(self):
+ name = self.get_testfn(suffix=self.funky_suffix)
+ try:
+ sock = bind_unix_socket(name)
+ except UnicodeEncodeError:
+ if PY3:
+ raise
+ else:
+ raise unittest.SkipTest("not supported")
+ with closing(sock):
+ conn = psutil.Process().connections('unix')[0]
+ self.assertIsInstance(conn.laddr, str)
+ # AF_UNIX addr not set on OpenBSD
+ if not OPENBSD: # XXX
+ self.assertEqual(conn.laddr, name)
+
+ @unittest.skipIf(not POSIX, "POSIX only")
+ @unittest.skipIf(not HAS_CONNECTIONS_UNIX, "can't list UNIX sockets")
+ @skip_on_access_denied()
+ def test_net_connections(self):
+ def find_sock(cons):
+ for conn in cons:
+ if os.path.basename(conn.laddr).startswith(TESTFN_PREFIX):
+ return conn
+ raise ValueError("connection not found")
+
+ name = self.get_testfn(suffix=self.funky_suffix)
+ try:
+ sock = bind_unix_socket(name)
+ except UnicodeEncodeError:
+ if PY3:
+ raise
+ else:
+ raise unittest.SkipTest("not supported")
+ with closing(sock):
+ cons = psutil.net_connections(kind='unix')
+ # AF_UNIX addr not set on OpenBSD
+ if not OPENBSD:
+ conn = find_sock(cons)
+ self.assertIsInstance(conn.laddr, str)
+ self.assertEqual(conn.laddr, name)
+
+ def test_disk_usage(self):
+ dname = self.funky_name + "2"
+ self.addCleanup(safe_rmpath, dname)
+ safe_mkdir(dname)
+ psutil.disk_usage(dname)
+
+ @unittest.skipIf(not HAS_MEMORY_MAPS, "not supported")
+ @unittest.skipIf(not PY3, "ctypes does not support unicode on PY2")
+ @unittest.skipIf(PYPY, "unstable on PYPY")
+ def test_memory_maps(self):
+ # XXX: on Python 2, using ctypes.CDLL with a unicode path
+ # opens a message box which blocks the test run.
+ with copyload_shared_lib(suffix=self.funky_suffix) as funky_path:
+ def normpath(p):
+ return os.path.realpath(os.path.normcase(p))
+ libpaths = [normpath(x.path)
+ for x in psutil.Process().memory_maps()]
+ # ...just to have a clearer msg in case of failure
+ libpaths = [x for x in libpaths if TESTFN_PREFIX in x]
+ self.assertIn(normpath(funky_path), libpaths)
+ for path in libpaths:
+ self.assertIsInstance(path, str)
+
+
+@unittest.skipIf(CI_TESTING, "unreliable on CI")
+class TestFSAPIsWithInvalidPath(TestFSAPIs):
+ """Test FS APIs with a funky, invalid path name."""
+ funky_suffix = INVALID_UNICODE_SUFFIX
+
+ @classmethod
+ def expect_exact_path_match(cls):
+ # Invalid unicode names are supposed to work on Python 2.
+ return True
+
+
+# ===================================================================
+# Non fs APIs
+# ===================================================================
+
+
+class TestNonFSAPIS(BaseUnicodeTest):
+ """Unicode tests for non fs-related APIs."""
+ funky_suffix = UNICODE_SUFFIX if PY3 else 'è'
+
+ @unittest.skipIf(not HAS_ENVIRON, "not supported")
+ @unittest.skipIf(PYPY and WINDOWS, "segfaults on PYPY + WINDOWS")
+ def test_proc_environ(self):
+ # Note: unlike the others, this test does not deal
+ # with fs paths. On Python 2 the subprocess module is broken as
+ # it's not able to handle non-ASCII env vars, so
+ # we use "è", which is part of the extended ASCII table
+ # (unicode point <= 255).
+ env = os.environ.copy()
+ env['FUNNY_ARG'] = self.funky_suffix
+ sproc = self.spawn_testproc(env=env)
+ p = psutil.Process(sproc.pid)
+ env = p.environ()
+ for k, v in env.items():
+ self.assertIsInstance(k, str)
+ self.assertIsInstance(v, str)
+ self.assertEqual(env['FUNNY_ARG'], self.funky_suffix)
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/psutil/tests/test_windows.py b/lib/psutil/tests/test_windows.py
new file mode 100644
index 0000000..55e6731
--- /dev/null
+++ b/lib/psutil/tests/test_windows.py
@@ -0,0 +1,898 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2009, Giampaolo Rodola'. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+"""Windows specific tests."""
+
+import datetime
+import errno
+import glob
+import os
+import platform
+import re
+import signal
+import subprocess
+import sys
+import time
+import unittest
+import warnings
+
+import psutil
+from psutil import WINDOWS
+from psutil._compat import FileNotFoundError
+from psutil._compat import which
+from psutil._compat import super
+from psutil.tests import APPVEYOR
+from psutil.tests import GITHUB_ACTIONS
+from psutil.tests import HAS_BATTERY
+from psutil.tests import IS_64BIT
+from psutil.tests import PY3
+from psutil.tests import PYPY
+from psutil.tests import TOLERANCE_DISK_USAGE
+from psutil.tests import TOLERANCE_SYS_MEM
+from psutil.tests import PsutilTestCase
+from psutil.tests import mock
+from psutil.tests import retry_on_failure
+from psutil.tests import sh
+from psutil.tests import spawn_testproc
+from psutil.tests import terminate
+
+
+if WINDOWS and not PYPY:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ import win32api # requires "pip install pywin32"
+ import win32con
+ import win32process
+ import wmi # requires "pip install wmi" / "make setup-dev-env"
+
+if WINDOWS:
+ from psutil._pswindows import convert_oserror
+
+
+cext = psutil._psplatform.cext
+
+
+@unittest.skipIf(not WINDOWS, "WINDOWS only")
+@unittest.skipIf(PYPY, "pywin32 not available on PYPY")
+# https://github.com/giampaolo/psutil/pull/1762#issuecomment-632892692
+@unittest.skipIf(GITHUB_ACTIONS and not PY3, "pywin32 broken on GITHUB + PY2")
+class WindowsTestCase(PsutilTestCase):
+ pass
+
+
+def powershell(cmd):
+ """Currently not used, but avalable just in case. Usage:
+
+ >>> powershell(
+ "Get-CIMInstance Win32_PageFileUsage | Select AllocatedBaseSize")
+ """
+ if not which("powershell.exe"):
+ raise unittest.SkipTest("powershell.exe not available")
+ cmdline = \
+ 'powershell.exe -ExecutionPolicy Bypass -NoLogo -NonInteractive ' + \
+ '-NoProfile -WindowStyle Hidden -Command "%s"' % cmd
+ return sh(cmdline)
+
+
+def wmic(path, what, converter=int):
+ """Currently not used, but avalable just in case. Usage:
+
+ >>> wmic("Win32_OperatingSystem", "FreePhysicalMemory")
+ 2134124534
+ """
+ out = sh("wmic path %s get %s" % (path, what)).strip()
+ data = "".join(out.splitlines()[1:]).strip() # get rid of the header
+ if converter is not None:
+ if "," in what:
+ return tuple([converter(x) for x in data.split()])
+ else:
+ return converter(data)
+ else:
+ return data
+
+
+# ===================================================================
+# System APIs
+# ===================================================================
+
+
+class TestCpuAPIs(WindowsTestCase):
+
+ @unittest.skipIf('NUMBER_OF_PROCESSORS' not in os.environ,
+ 'NUMBER_OF_PROCESSORS env var is not available')
+ def test_cpu_count_vs_NUMBER_OF_PROCESSORS(self):
+ # Will likely fail on many-core systems:
+ # https://stackoverflow.com/questions/31209256
+ num_cpus = int(os.environ['NUMBER_OF_PROCESSORS'])
+ self.assertEqual(num_cpus, psutil.cpu_count())
+
+ def test_cpu_count_vs_GetSystemInfo(self):
+ # Will likely fail on many-core systems:
+ # https://stackoverflow.com/questions/31209256
+ sys_value = win32api.GetSystemInfo()[5]
+ psutil_value = psutil.cpu_count()
+ self.assertEqual(sys_value, psutil_value)
+
+ def test_cpu_count_logical_vs_wmi(self):
+ w = wmi.WMI()
+ procs = sum(proc.NumberOfLogicalProcessors
+ for proc in w.Win32_Processor())
+ self.assertEqual(psutil.cpu_count(), procs)
+
+ def test_cpu_count_cores_vs_wmi(self):
+ w = wmi.WMI()
+ cores = sum(proc.NumberOfCores for proc in w.Win32_Processor())
+ self.assertEqual(psutil.cpu_count(logical=False), cores)
+
+ def test_cpu_count_vs_cpu_times(self):
+ self.assertEqual(psutil.cpu_count(),
+ len(psutil.cpu_times(percpu=True)))
+
+ def test_cpu_freq(self):
+ w = wmi.WMI()
+ proc = w.Win32_Processor()[0]
+ self.assertEqual(proc.CurrentClockSpeed, psutil.cpu_freq().current)
+ self.assertEqual(proc.MaxClockSpeed, psutil.cpu_freq().max)
+
+
+class TestSystemAPIs(WindowsTestCase):
+
+ def test_nic_names(self):
+ out = sh('ipconfig /all')
+ nics = psutil.net_io_counters(pernic=True).keys()
+ for nic in nics:
+ if "pseudo-interface" in nic.replace(' ', '-').lower():
+ continue
+ if nic not in out:
+ raise self.fail(
+ "%r nic wasn't found in 'ipconfig /all' output" % nic)
+
+ def test_total_phymem(self):
+ w = wmi.WMI().Win32_ComputerSystem()[0]
+ self.assertEqual(int(w.TotalPhysicalMemory),
+ psutil.virtual_memory().total)
+
+ def test_free_phymem(self):
+ w = wmi.WMI().Win32_PerfRawData_PerfOS_Memory()[0]
+ self.assertAlmostEqual(
+ int(w.AvailableBytes), psutil.virtual_memory().free,
+ delta=TOLERANCE_SYS_MEM)
+
+ def test_total_swapmem(self):
+ w = wmi.WMI().Win32_PerfRawData_PerfOS_Memory()[0]
+ self.assertEqual(int(w.CommitLimit) - psutil.virtual_memory().total,
+ psutil.swap_memory().total)
+ if (psutil.swap_memory().total == 0):
+ self.assertEqual(0, psutil.swap_memory().free)
+ self.assertEqual(0, psutil.swap_memory().used)
+
+ def test_percent_swapmem(self):
+ if (psutil.swap_memory().total > 0):
+ w = wmi.WMI().Win32_PerfRawData_PerfOS_PagingFile(
+ Name="_Total")[0]
+ # calculate swap usage as a percentage
+ percentSwap = int(w.PercentUsage) * 100 / int(w.PercentUsage_Base)
+ # exact percent may change but should be reasonable
+ # assert within +/- 5% and between 0 and 100%
+ self.assertGreaterEqual(psutil.swap_memory().percent, 0)
+ self.assertAlmostEqual(psutil.swap_memory().percent, percentSwap,
+ delta=5)
+ self.assertLessEqual(psutil.swap_memory().percent, 100)
+
+ # @unittest.skipIf(wmi is None, "wmi module is not installed")
+ # def test__UPTIME(self):
+ # # _UPTIME constant is not public but it is used internally
+ # # as value to return for pid 0 creation time.
+ # # WMI behaves the same.
+ # w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ # p = psutil.Process(0)
+ # wmic_create = str(w.CreationDate.split('.')[0])
+ # psutil_create = time.strftime("%Y%m%d%H%M%S",
+ # time.localtime(p.create_time()))
+
+ # Note: this test is not very reliable
+ @unittest.skipIf(APPVEYOR, "test not relieable on appveyor")
+ @retry_on_failure()
+ def test_pids(self):
+ # Note: this test might fail if the OS is starting/killing
+ # other processes in the meantime
+ w = wmi.WMI().Win32_Process()
+ wmi_pids = set([x.ProcessId for x in w])
+ psutil_pids = set(psutil.pids())
+ self.assertEqual(wmi_pids, psutil_pids)
+
+ @retry_on_failure()
+ def test_disks(self):
+ ps_parts = psutil.disk_partitions(all=True)
+ wmi_parts = wmi.WMI().Win32_LogicalDisk()
+ for ps_part in ps_parts:
+ for wmi_part in wmi_parts:
+ if ps_part.device.replace('\\', '') == wmi_part.DeviceID:
+ if not ps_part.mountpoint:
+ # this is usually a CD-ROM with no disk inserted
+ break
+ if 'cdrom' in ps_part.opts:
+ break
+ if ps_part.mountpoint.startswith('A:'):
+ break # floppy
+ try:
+ usage = psutil.disk_usage(ps_part.mountpoint)
+ except FileNotFoundError:
+ # usually this is the floppy
+ break
+ self.assertEqual(usage.total, int(wmi_part.Size))
+ wmi_free = int(wmi_part.FreeSpace)
+ self.assertEqual(usage.free, wmi_free)
+ # 10 MB tolerance
+ if abs(usage.free - wmi_free) > 10 * 1024 * 1024:
+ raise self.fail("psutil=%s, wmi=%s" % (
+ usage.free, wmi_free))
+ break
+ else:
+ raise self.fail("can't find partition %s" % repr(ps_part))
+
+ @retry_on_failure()
+ def test_disk_usage(self):
+ for disk in psutil.disk_partitions():
+ if 'cdrom' in disk.opts:
+ continue
+ sys_value = win32api.GetDiskFreeSpaceEx(disk.mountpoint)
+ psutil_value = psutil.disk_usage(disk.mountpoint)
+ self.assertAlmostEqual(sys_value[0], psutil_value.free,
+ delta=TOLERANCE_DISK_USAGE)
+ self.assertAlmostEqual(sys_value[1], psutil_value.total,
+ delta=TOLERANCE_DISK_USAGE)
+ self.assertEqual(psutil_value.used,
+ psutil_value.total - psutil_value.free)
+
+ def test_disk_partitions(self):
+ sys_value = [
+ x + '\\' for x in win32api.GetLogicalDriveStrings().split("\\\x00")
+ if x and not x.startswith('A:')]
+ psutil_value = [x.mountpoint for x in psutil.disk_partitions(all=True)
+ if not x.mountpoint.startswith('A:')]
+ self.assertEqual(sys_value, psutil_value)
+
+ def test_net_if_stats(self):
+ ps_names = set(cext.net_if_stats())
+ wmi_adapters = wmi.WMI().Win32_NetworkAdapter()
+ wmi_names = set()
+ for wmi_adapter in wmi_adapters:
+ wmi_names.add(wmi_adapter.Name)
+ wmi_names.add(wmi_adapter.NetConnectionID)
+ self.assertTrue(ps_names & wmi_names,
+ "no common entries in %s, %s" % (ps_names, wmi_names))
+
+ def test_boot_time(self):
+ wmi_os = wmi.WMI().Win32_OperatingSystem()
+ wmi_btime_str = wmi_os[0].LastBootUpTime.split('.')[0]
+ wmi_btime_dt = datetime.datetime.strptime(
+ wmi_btime_str, "%Y%m%d%H%M%S")
+ psutil_dt = datetime.datetime.fromtimestamp(psutil.boot_time())
+ diff = abs((wmi_btime_dt - psutil_dt).total_seconds())
+ self.assertLessEqual(diff, 5)
+
+ def test_boot_time_fluctuation(self):
+ # https://github.com/giampaolo/psutil/issues/1007
+ with mock.patch('psutil._pswindows.cext.boot_time', return_value=5):
+ self.assertEqual(psutil.boot_time(), 5)
+ with mock.patch('psutil._pswindows.cext.boot_time', return_value=4):
+ self.assertEqual(psutil.boot_time(), 5)
+ with mock.patch('psutil._pswindows.cext.boot_time', return_value=6):
+ self.assertEqual(psutil.boot_time(), 5)
+ with mock.patch('psutil._pswindows.cext.boot_time', return_value=333):
+ self.assertEqual(psutil.boot_time(), 333)
+
+
+# ===================================================================
+# sensors_battery()
+# ===================================================================
+
+
+class TestSensorsBattery(WindowsTestCase):
+
+ def test_has_battery(self):
+ if win32api.GetPwrCapabilities()['SystemBatteriesPresent']:
+ self.assertIsNotNone(psutil.sensors_battery())
+ else:
+ self.assertIsNone(psutil.sensors_battery())
+
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_percent(self):
+ w = wmi.WMI()
+ battery_wmi = w.query('select * from Win32_Battery')[0]
+ battery_psutil = psutil.sensors_battery()
+ self.assertAlmostEqual(
+ battery_psutil.percent, battery_wmi.EstimatedChargeRemaining,
+ delta=1)
+
+ @unittest.skipIf(not HAS_BATTERY, "no battery")
+ def test_power_plugged(self):
+ w = wmi.WMI()
+ battery_wmi = w.query('select * from Win32_Battery')[0]
+ battery_psutil = psutil.sensors_battery()
+ # Status codes:
+ # https://msdn.microsoft.com/en-us/library/aa394074(v=vs.85).aspx
+ self.assertEqual(battery_psutil.power_plugged,
+ battery_wmi.BatteryStatus == 2)
+
+ def test_emulate_no_battery(self):
+ with mock.patch("psutil._pswindows.cext.sensors_battery",
+ return_value=(0, 128, 0, 0)) as m:
+ self.assertIsNone(psutil.sensors_battery())
+ assert m.called
+
+ def test_emulate_power_connected(self):
+ with mock.patch("psutil._pswindows.cext.sensors_battery",
+ return_value=(1, 0, 0, 0)) as m:
+ self.assertEqual(psutil.sensors_battery().secsleft,
+ psutil.POWER_TIME_UNLIMITED)
+ assert m.called
+
+ def test_emulate_power_charging(self):
+ with mock.patch("psutil._pswindows.cext.sensors_battery",
+ return_value=(0, 8, 0, 0)) as m:
+ self.assertEqual(psutil.sensors_battery().secsleft,
+ psutil.POWER_TIME_UNLIMITED)
+ assert m.called
+
+ def test_emulate_secs_left_unknown(self):
+ with mock.patch("psutil._pswindows.cext.sensors_battery",
+ return_value=(0, 0, 0, -1)) as m:
+ self.assertEqual(psutil.sensors_battery().secsleft,
+ psutil.POWER_TIME_UNKNOWN)
+ assert m.called
+
+
+# ===================================================================
+# Process APIs
+# ===================================================================
+
+
+class TestProcess(WindowsTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.pid = spawn_testproc().pid
+
+ @classmethod
+ def tearDownClass(cls):
+ terminate(cls.pid)
+
+ def test_issue_24(self):
+ p = psutil.Process(0)
+ self.assertRaises(psutil.AccessDenied, p.kill)
+
+ def test_special_pid(self):
+ p = psutil.Process(4)
+ self.assertEqual(p.name(), 'System')
+ # use __str__ to access all common Process properties to check
+ # that nothing strange happens
+ str(p)
+ p.username()
+ self.assertTrue(p.create_time() >= 0.0)
+ try:
+ rss, vms = p.memory_info()[:2]
+ except psutil.AccessDenied:
+ # expected on Windows Vista and Windows 7
+ if not platform.uname()[1] in ('vista', 'win-7', 'win7'):
+ raise
+ else:
+ self.assertTrue(rss > 0)
+
+ def test_send_signal(self):
+ p = psutil.Process(self.pid)
+ self.assertRaises(ValueError, p.send_signal, signal.SIGINT)
+
+ def test_num_handles_increment(self):
+ p = psutil.Process(os.getpid())
+ before = p.num_handles()
+ handle = win32api.OpenProcess(win32con.PROCESS_QUERY_INFORMATION,
+ win32con.FALSE, os.getpid())
+ after = p.num_handles()
+ self.assertEqual(after, before + 1)
+ win32api.CloseHandle(handle)
+ self.assertEqual(p.num_handles(), before)
+
+ def test_ctrl_signals(self):
+ p = psutil.Process(self.spawn_testproc().pid)
+ p.send_signal(signal.CTRL_C_EVENT)
+ p.send_signal(signal.CTRL_BREAK_EVENT)
+ p.kill()
+ p.wait()
+ self.assertRaises(psutil.NoSuchProcess,
+ p.send_signal, signal.CTRL_C_EVENT)
+ self.assertRaises(psutil.NoSuchProcess,
+ p.send_signal, signal.CTRL_BREAK_EVENT)
+
+ def test_username(self):
+ name = win32api.GetUserNameEx(win32con.NameSamCompatible)
+ if name.endswith('$'):
+ # When running as a service account (most likely to be
+ # NetworkService), these user name calculations don't produce the
+ # same result, causing the test to fail.
+ raise unittest.SkipTest('running as service account')
+ self.assertEqual(psutil.Process().username(), name)
+
+ def test_cmdline(self):
+ sys_value = re.sub('[ ]+', ' ', win32api.GetCommandLine()).strip()
+ psutil_value = ' '.join(psutil.Process().cmdline())
+ if sys_value[0] == '"' != psutil_value[0]:
+ # The PyWin32 command line may retain quotes around argv[0] if they
+ # were used unnecessarily, while psutil will omit them. So remove
+ # the first 2 quotes from sys_value if not in psutil_value.
+ # A path to an executable will not contain quotes, so this is safe.
+ sys_value = sys_value.replace('"', '', 2)
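+ # e.g. (hypothetical) PyWin32 may report '"C:\Python39\python.exe" -u foo.py'
+ # while psutil reports 'C:\Python39\python.exe -u foo.py'.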
+ self.assertEqual(sys_value, psutil_value)
+
+ # XXX - occasional failures
+
+ # def test_cpu_times(self):
+ # handle = win32api.OpenProcess(win32con.PROCESS_QUERY_INFORMATION,
+ # win32con.FALSE, os.getpid())
+ # self.addCleanup(win32api.CloseHandle, handle)
+ # sys_value = win32process.GetProcessTimes(handle)
+ # psutil_value = psutil.Process().cpu_times()
+ # self.assertAlmostEqual(
+ # psutil_value.user, sys_value['UserTime'] / 10000000.0,
+ # delta=0.2)
+ # self.assertAlmostEqual(
+ # psutil_value.user, sys_value['KernelTime'] / 10000000.0,
+ # delta=0.2)
+
+ def test_nice(self):
+ handle = win32api.OpenProcess(win32con.PROCESS_QUERY_INFORMATION,
+ win32con.FALSE, os.getpid())
+ self.addCleanup(win32api.CloseHandle, handle)
+ sys_value = win32process.GetPriorityClass(handle)
+ psutil_value = psutil.Process().nice()
+ self.assertEqual(psutil_value, sys_value)
+
+ def test_memory_info(self):
+ handle = win32api.OpenProcess(win32con.PROCESS_QUERY_INFORMATION,
+ win32con.FALSE, self.pid)
+ self.addCleanup(win32api.CloseHandle, handle)
+ sys_value = win32process.GetProcessMemoryInfo(handle)
+ psutil_value = psutil.Process(self.pid).memory_info()
+ self.assertEqual(
+ sys_value['PeakWorkingSetSize'], psutil_value.peak_wset)
+ self.assertEqual(
+ sys_value['WorkingSetSize'], psutil_value.wset)
+ self.assertEqual(
+ sys_value['QuotaPeakPagedPoolUsage'], psutil_value.peak_paged_pool)
+ self.assertEqual(
+ sys_value['QuotaPagedPoolUsage'], psutil_value.paged_pool)
+ self.assertEqual(
+ sys_value['QuotaPeakNonPagedPoolUsage'],
+ psutil_value.peak_nonpaged_pool)
+ self.assertEqual(
+ sys_value['QuotaNonPagedPoolUsage'], psutil_value.nonpaged_pool)
+ self.assertEqual(
+ sys_value['PagefileUsage'], psutil_value.pagefile)
+ self.assertEqual(
+ sys_value['PeakPagefileUsage'], psutil_value.peak_pagefile)
+
+ self.assertEqual(psutil_value.rss, psutil_value.wset)
+ self.assertEqual(psutil_value.vms, psutil_value.pagefile)
+
+ def test_wait(self):
+ handle = win32api.OpenProcess(win32con.PROCESS_QUERY_INFORMATION,
+ win32con.FALSE, self.pid)
+ self.addCleanup(win32api.CloseHandle, handle)
+ p = psutil.Process(self.pid)
+ p.terminate()
+ psutil_value = p.wait()
+ sys_value = win32process.GetExitCodeProcess(handle)
+ self.assertEqual(psutil_value, sys_value)
+
+ def test_cpu_affinity(self):
+ def from_bitmask(x):
+ return [i for i in range(64) if (1 << i) & x]
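+ # e.g. from_bitmask(0b1011) == [0, 1, 3]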
+
+ handle = win32api.OpenProcess(win32con.PROCESS_QUERY_INFORMATION,
+ win32con.FALSE, self.pid)
+ self.addCleanup(win32api.CloseHandle, handle)
+ sys_value = from_bitmask(
+ win32process.GetProcessAffinityMask(handle)[0])
+ psutil_value = psutil.Process(self.pid).cpu_affinity()
+ self.assertEqual(psutil_value, sys_value)
+
+ def test_io_counters(self):
+ handle = win32api.OpenProcess(win32con.PROCESS_QUERY_INFORMATION,
+ win32con.FALSE, os.getpid())
+ self.addCleanup(win32api.CloseHandle, handle)
+ sys_value = win32process.GetProcessIoCounters(handle)
+ psutil_value = psutil.Process().io_counters()
+ self.assertEqual(
+ psutil_value.read_count, sys_value['ReadOperationCount'])
+ self.assertEqual(
+ psutil_value.write_count, sys_value['WriteOperationCount'])
+ self.assertEqual(
+ psutil_value.read_bytes, sys_value['ReadTransferCount'])
+ self.assertEqual(
+ psutil_value.write_bytes, sys_value['WriteTransferCount'])
+ self.assertEqual(
+ psutil_value.other_count, sys_value['OtherOperationCount'])
+ self.assertEqual(
+ psutil_value.other_bytes, sys_value['OtherTransferCount'])
+
+ def test_num_handles(self):
+ import ctypes
+ import ctypes.wintypes
+ PROCESS_QUERY_INFORMATION = 0x400
+ handle = ctypes.windll.kernel32.OpenProcess(
+ PROCESS_QUERY_INFORMATION, 0, self.pid)
+ self.addCleanup(ctypes.windll.kernel32.CloseHandle, handle)
+
+ hndcnt = ctypes.wintypes.DWORD()
+ ctypes.windll.kernel32.GetProcessHandleCount(
+ handle, ctypes.byref(hndcnt))
+ sys_value = hndcnt.value
+ psutil_value = psutil.Process(self.pid).num_handles()
+ self.assertEqual(psutil_value, sys_value)
+
+ def test_error_partial_copy(self):
+ # https://github.com/giampaolo/psutil/issues/875
+ exc = WindowsError()
+ exc.winerror = 299
+ with mock.patch("psutil._psplatform.cext.proc_cwd", side_effect=exc):
+ with mock.patch("time.sleep") as m:
+ p = psutil.Process()
+ self.assertRaises(psutil.AccessDenied, p.cwd)
+ self.assertGreaterEqual(m.call_count, 5)
+
+ def test_exe(self):
+ # NtQuerySystemInformation succeeds even if the process is gone. Make
+ # sure it raises NSP for a non-existent pid.
+ pid = psutil.pids()[-1] + 99999
+ proc = psutil._psplatform.Process(pid)
+ self.assertRaises(psutil.NoSuchProcess, proc.exe)
+
+
+class TestProcessWMI(WindowsTestCase):
+ """Compare Process API results with WMI."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.pid = spawn_testproc().pid
+
+ @classmethod
+ def tearDownClass(cls):
+ terminate(cls.pid)
+
+ def test_name(self):
+ w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ p = psutil.Process(self.pid)
+ self.assertEqual(p.name(), w.Caption)
+
+ # This fails on GitHub Actions because the test environment uses a virtualenv
+ @unittest.skipIf(GITHUB_ACTIONS, "unreliable path on GITHUB_ACTIONS")
+ def test_exe(self):
+ w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ p = psutil.Process(self.pid)
+ # Note: wmi reports the exe as a lower case string.
+ # Since Windows paths are case-insensitive we ignore that.
+ self.assertEqual(p.exe().lower(), w.ExecutablePath.lower())
+
+ def test_cmdline(self):
+ w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ p = psutil.Process(self.pid)
+ self.assertEqual(' '.join(p.cmdline()),
+ w.CommandLine.replace('"', ''))
+
+ def test_username(self):
+ w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ p = psutil.Process(self.pid)
+ domain, _, username = w.GetOwner()
+ username = "%s\\%s" % (domain, username)
+ self.assertEqual(p.username(), username)
+
+ @retry_on_failure()
+ def test_memory_rss(self):
+ w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ p = psutil.Process(self.pid)
+ rss = p.memory_info().rss
+ self.assertEqual(rss, int(w.WorkingSetSize))
+
+ @retry_on_failure()
+ def test_memory_vms(self):
+ w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ p = psutil.Process(self.pid)
+ vms = p.memory_info().vms
+ # http://msdn.microsoft.com/en-us/library/aa394372(VS.85).aspx
+ # ...claims that PageFileUsage is expressed in kilobytes,
+ # but funnily enough on certain platforms bytes are
+ # returned instead.
+ wmi_usage = int(w.PageFileUsage)
+ if (vms != wmi_usage) and (vms != wmi_usage * 1024):
+ raise self.fail("wmi=%s, psutil=%s" % (wmi_usage, vms))
+
+ def test_create_time(self):
+ w = wmi.WMI().Win32_Process(ProcessId=self.pid)[0]
+ p = psutil.Process(self.pid)
+ wmic_create = str(w.CreationDate.split('.')[0])
+ psutil_create = time.strftime("%Y%m%d%H%M%S",
+ time.localtime(p.create_time()))
+ self.assertEqual(wmic_create, psutil_create)
+
+
+# ---
+
+
+@unittest.skipIf(not WINDOWS, "WINDOWS only")
+class TestDualProcessImplementation(PsutilTestCase):
+ """
+ Certain APIs on Windows have 2 internal implementations, one
+ based on documented Windows APIs, the other based on
+ NtQuerySystemInformation(), which gets called as a fallback in
+ case the first fails because of a permission error.
+ Here we test that the two methods return the exact same value,
+ see:
+ https://github.com/giampaolo/psutil/issues/304
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ cls.pid = spawn_testproc().pid
+
+ @classmethod
+ def tearDownClass(cls):
+ terminate(cls.pid)
+
+ def test_memory_info(self):
+ mem_1 = psutil.Process(self.pid).memory_info()
+ with mock.patch("psutil._psplatform.cext.proc_memory_info",
+ side_effect=OSError(errno.EPERM, "msg")) as fun:
+ mem_2 = psutil.Process(self.pid).memory_info()
+ self.assertEqual(len(mem_1), len(mem_2))
+ for i in range(len(mem_1)):
+ self.assertGreaterEqual(mem_1[i], 0)
+ self.assertGreaterEqual(mem_2[i], 0)
+ self.assertAlmostEqual(mem_1[i], mem_2[i], delta=512)
+ assert fun.called
+
+ def test_create_time(self):
+ ctime = psutil.Process(self.pid).create_time()
+ with mock.patch("psutil._psplatform.cext.proc_times",
+ side_effect=OSError(errno.EPERM, "msg")) as fun:
+ self.assertEqual(psutil.Process(self.pid).create_time(), ctime)
+ assert fun.called
+
+ def test_cpu_times(self):
+ cpu_times_1 = psutil.Process(self.pid).cpu_times()
+ with mock.patch("psutil._psplatform.cext.proc_times",
+ side_effect=OSError(errno.EPERM, "msg")) as fun:
+ cpu_times_2 = psutil.Process(self.pid).cpu_times()
+ assert fun.called
+ self.assertAlmostEqual(
+ cpu_times_1.user, cpu_times_2.user, delta=0.01)
+ self.assertAlmostEqual(
+ cpu_times_1.system, cpu_times_2.system, delta=0.01)
+
+ def test_io_counters(self):
+ io_counters_1 = psutil.Process(self.pid).io_counters()
+ with mock.patch("psutil._psplatform.cext.proc_io_counters",
+ side_effect=OSError(errno.EPERM, "msg")) as fun:
+ io_counters_2 = psutil.Process(self.pid).io_counters()
+ for i in range(len(io_counters_1)):
+ self.assertAlmostEqual(
+ io_counters_1[i], io_counters_2[i], delta=5)
+ assert fun.called
+
+ def test_num_handles(self):
+ num_handles = psutil.Process(self.pid).num_handles()
+ with mock.patch("psutil._psplatform.cext.proc_num_handles",
+ side_effect=OSError(errno.EPERM, "msg")) as fun:
+ self.assertEqual(psutil.Process(self.pid).num_handles(),
+ num_handles)
+ assert fun.called
+
+ def test_cmdline(self):
+ for pid in psutil.pids():
+ try:
+ a = cext.proc_cmdline(pid, use_peb=True)
+ b = cext.proc_cmdline(pid, use_peb=False)
+ except OSError as err:
+ err = convert_oserror(err)
+ if not isinstance(err, (psutil.AccessDenied,
+ psutil.NoSuchProcess)):
+ raise
+ else:
+ self.assertEqual(a, b)
+
+
+@unittest.skipIf(not WINDOWS, "WINDOWS only")
+class RemoteProcessTestCase(PsutilTestCase):
+ """Certain functions require calling ReadProcessMemory.
+ This trivially works when called on the current process.
+ Check that this works on other processes, especially when they
+ have a different bitness.
+ """
+
+ @staticmethod
+ def find_other_interpreter():
+ # find a python interpreter that is of the opposite bitness from us
+ code = "import sys; sys.stdout.write(str(sys.maxsize > 2**32))"
+
+ # XXX: a different and probably more stable approach might be to access
+ # the registry, but accessing 64-bit paths from a 32-bit process
+ # is not straightforward.
+ for filename in glob.glob(r"C:\Python*\python.exe"):
+ proc = subprocess.Popen(args=[filename, "-c", code],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT)
+ output, _ = proc.communicate()
+ proc.wait()
+ if output == str(not IS_64BIT):
+ return filename
+
+ test_args = ["-c", "import sys; sys.stdin.read()"]
+
+ def setUp(self):
+ super().setUp()
+
+ other_python = self.find_other_interpreter()
+ if other_python is None:
+ raise unittest.SkipTest(
+ "could not find interpreter with opposite bitness")
+ if IS_64BIT:
+ self.python64 = sys.executable
+ self.python32 = other_python
+ else:
+ self.python64 = other_python
+ self.python32 = sys.executable
+
+ env = os.environ.copy()
+ env["THINK_OF_A_NUMBER"] = str(os.getpid())
+ self.proc32 = self.spawn_testproc(
+ [self.python32] + self.test_args,
+ env=env,
+ stdin=subprocess.PIPE)
+ self.proc64 = self.spawn_testproc(
+ [self.python64] + self.test_args,
+ env=env,
+ stdin=subprocess.PIPE)
+
+ def tearDown(self):
+ super().tearDown()
+ self.proc32.communicate()
+ self.proc64.communicate()
+
+ def test_cmdline_32(self):
+ p = psutil.Process(self.proc32.pid)
+ self.assertEqual(len(p.cmdline()), 3)
+ self.assertEqual(p.cmdline()[1:], self.test_args)
+
+ def test_cmdline_64(self):
+ p = psutil.Process(self.proc64.pid)
+ self.assertEqual(len(p.cmdline()), 3)
+ self.assertEqual(p.cmdline()[1:], self.test_args)
+
+ def test_cwd_32(self):
+ p = psutil.Process(self.proc32.pid)
+ self.assertEqual(p.cwd(), os.getcwd())
+
+ def test_cwd_64(self):
+ p = psutil.Process(self.proc64.pid)
+ self.assertEqual(p.cwd(), os.getcwd())
+
+ def test_environ_32(self):
+ p = psutil.Process(self.proc32.pid)
+ e = p.environ()
+ self.assertIn("THINK_OF_A_NUMBER", e)
+ self.assertEqual(e["THINK_OF_A_NUMBER"], str(os.getpid()))
+
+ def test_environ_64(self):
+ p = psutil.Process(self.proc64.pid)
+ try:
+ p.environ()
+ except psutil.AccessDenied:
+ pass
+
+
+# ===================================================================
+# Windows services
+# ===================================================================
+
+
+@unittest.skipIf(not WINDOWS, "WINDOWS only")
+class TestServices(PsutilTestCase):
+
+ def test_win_service_iter(self):
+ valid_statuses = set([
+ "running",
+ "paused",
+ "start_pending",
+ "pause_pending",
+ "continue_pending",
+ "stop_pending",
+ "stopped",
+ ])
+ valid_start_types = set([
+ "automatic",
+ "manual",
+ "disabled",
+ ])
+ for serv in psutil.win_service_iter():
+ data = serv.as_dict()
+ self.assertIsInstance(data['name'], str)
+ self.assertNotEqual(data['name'].strip(), "")
+ self.assertIsInstance(data['display_name'], str)
+ self.assertIsInstance(data['username'], str)
+ self.assertIn(data['status'], valid_statuses)
+ if data['pid'] is not None:
+ psutil.Process(data['pid'])
+ self.assertIsInstance(data['binpath'], str)
+ self.assertIsInstance(data['username'], str)
+ self.assertIsInstance(data['start_type'], str)
+ self.assertIn(data['start_type'], valid_start_types)
+ self.assertIn(data['status'], valid_statuses)
+ self.assertIsInstance(data['description'], str)
+ pid = serv.pid()
+ if pid is not None:
+ p = psutil.Process(pid)
+ self.assertTrue(p.is_running())
+ # win_service_get
+ s = psutil.win_service_get(serv.name())
+ # test __eq__
+ self.assertEqual(serv, s)
+
+ def test_win_service_get(self):
+ ERROR_SERVICE_DOES_NOT_EXIST = \
+ psutil._psplatform.cext.ERROR_SERVICE_DOES_NOT_EXIST
+ ERROR_ACCESS_DENIED = psutil._psplatform.cext.ERROR_ACCESS_DENIED
+
+ name = next(psutil.win_service_iter()).name()
+ with self.assertRaises(psutil.NoSuchProcess) as cm:
+ psutil.win_service_get(name + '???')
+ self.assertEqual(cm.exception.name, name + '???')
+
+ # test NoSuchProcess
+ service = psutil.win_service_get(name)
+ if PY3:
+ args = (0, "msg", 0, ERROR_SERVICE_DOES_NOT_EXIST)
+ else:
+ args = (ERROR_SERVICE_DOES_NOT_EXIST, "msg")
+ exc = WindowsError(*args)
+ with mock.patch("psutil._psplatform.cext.winservice_query_status",
+ side_effect=exc):
+ self.assertRaises(psutil.NoSuchProcess, service.status)
+ with mock.patch("psutil._psplatform.cext.winservice_query_config",
+ side_effect=exc):
+ self.assertRaises(psutil.NoSuchProcess, service.username)
+
+ # test AccessDenied
+ if PY3:
+ args = (0, "msg", 0, ERROR_ACCESS_DENIED)
+ else:
+ args = (ERROR_ACCESS_DENIED, "msg")
+ exc = WindowsError(*args)
+ with mock.patch("psutil._psplatform.cext.winservice_query_status",
+ side_effect=exc):
+ self.assertRaises(psutil.AccessDenied, service.status)
+ with mock.patch("psutil._psplatform.cext.winservice_query_config",
+ side_effect=exc):
+ self.assertRaises(psutil.AccessDenied, service.username)
+
+ # test __str__ and __repr__
+ self.assertIn(service.name(), str(service))
+ self.assertIn(service.display_name(), str(service))
+ self.assertIn(service.name(), repr(service))
+ self.assertIn(service.display_name(), repr(service))
+
+
+if __name__ == '__main__':
+ from psutil.tests.runner import run_from_name
+ run_from_name(__file__)
diff --git a/lib/ptyprocess-0.7.0.dist-info/INSTALLER b/lib/ptyprocess-0.7.0.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/ptyprocess-0.7.0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/ptyprocess-0.7.0.dist-info/LICENSE b/lib/ptyprocess-0.7.0.dist-info/LICENSE
new file mode 100644
index 0000000..9c77274
--- /dev/null
+++ b/lib/ptyprocess-0.7.0.dist-info/LICENSE
@@ -0,0 +1,16 @@
+Ptyprocess is under the ISC license, as code derived from Pexpect.
+ http://opensource.org/licenses/ISC
+
+Copyright (c) 2013-2014, Pexpect development team
+Copyright (c) 2012, Noah Spurrier <noah@noah.org>
+
+PERMISSION TO USE, COPY, MODIFY, AND/OR DISTRIBUTE THIS SOFTWARE FOR ANY PURPOSE
+WITH OR WITHOUT FEE IS HEREBY GRANTED, PROVIDED THAT THE ABOVE COPYRIGHT NOTICE
+AND THIS PERMISSION NOTICE APPEAR IN ALL COPIES. THE SOFTWARE IS PROVIDED
+"AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE
+INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT
+SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL
+DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
+WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
+OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
diff --git a/lib/ptyprocess-0.7.0.dist-info/METADATA b/lib/ptyprocess-0.7.0.dist-info/METADATA
new file mode 100644
index 0000000..ab1d4e0
--- /dev/null
+++ b/lib/ptyprocess-0.7.0.dist-info/METADATA
@@ -0,0 +1,37 @@
+Metadata-Version: 2.1
+Name: ptyprocess
+Version: 0.7.0
+Summary: Run a subprocess in a pseudo terminal
+Home-page: https://github.com/pexpect/ptyprocess
+License: UNKNOWN
+Author: Thomas Kluyver
+Author-email: thomas@kluyver.me.uk
+Description-Content-Type: text/x-rst
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Environment :: Console
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: System Administrators
+Classifier: License :: OSI Approved :: ISC License (ISCL)
+Classifier: Operating System :: POSIX
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Topic :: Terminals
+
+Launch a subprocess in a pseudo terminal (pty), and interact with both the
+process and its pty.
+
+Sometimes, piping stdin and stdout is not enough. There might be a password
+prompt that doesn't read from stdin, output that changes when it's going to a
+pipe rather than a terminal, or curses-style interfaces that rely on a terminal.
+If you need to automate these things, running the process in a pseudo terminal
+(pty) is the answer.
+
+Interface::
+
+ p = PtyProcessUnicode.spawn(['python'])
+ p.read(20)
+ p.write('6+6\n')
+ p.read(20)
+
diff --git a/lib/ptyprocess-0.7.0.dist-info/RECORD b/lib/ptyprocess-0.7.0.dist-info/RECORD
new file mode 100644
index 0000000..dd78069
--- /dev/null
+++ b/lib/ptyprocess-0.7.0.dist-info/RECORD
@@ -0,0 +1,13 @@
+ptyprocess-0.7.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+ptyprocess-0.7.0.dist-info/LICENSE,sha256=yCLThbGnMymEYkF5m-zxhpC11Edkwb7WkwC1NqQFAwo,905
+ptyprocess-0.7.0.dist-info/METADATA,sha256=w8K5a12aVdpZWMNNCMGKCEn1ZgkCbMRtXJW4t3_PPgw,1312
+ptyprocess-0.7.0.dist-info/RECORD,,
+ptyprocess-0.7.0.dist-info/WHEEL,sha256=NLqmsx-ZFZ6gDavYgh2oH0ZSN-KRmpcdEXIZDnYy9Pg,99
+ptyprocess/__init__.py,sha256=sn-W_1nNRTuIOi2aCEHVL06wCVJcR-LOZdgpXzwFuTU,138
+ptyprocess/__pycache__/__init__.cpython-39.pyc,,
+ptyprocess/__pycache__/_fork_pty.cpython-39.pyc,,
+ptyprocess/__pycache__/ptyprocess.cpython-39.pyc,,
+ptyprocess/__pycache__/util.cpython-39.pyc,,
+ptyprocess/_fork_pty.py,sha256=VVvMy8c4ZpjDMiIMSg8T1BQ1g3SBexDpey_cxi0n5aw,2362
+ptyprocess/ptyprocess.py,sha256=sk2sU2I22Yyl1gU3FjFmpWL3B43o0KqG3d3CI8r0Nq8,31686
+ptyprocess/util.py,sha256=rQAdDRZfoOiOn6vykWth0wI6FFKAp7aJtBSdt-KBWdU,2785
diff --git a/lib/ptyprocess-0.7.0.dist-info/WHEEL b/lib/ptyprocess-0.7.0.dist-info/WHEEL
new file mode 100644
index 0000000..3825653
--- /dev/null
+++ b/lib/ptyprocess-0.7.0.dist-info/WHEEL
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: flit 3.0.0
+Root-Is-Purelib: true
+Tag: py2-none-any
+Tag: py3-none-any
diff --git a/lib/ptyprocess/__init__.py b/lib/ptyprocess/__init__.py
new file mode 100644
index 0000000..3a6268e
--- /dev/null
+++ b/lib/ptyprocess/__init__.py
@@ -0,0 +1,4 @@
+"""Run a subprocess in a pseudo terminal"""
+from .ptyprocess import PtyProcess, PtyProcessUnicode, PtyProcessError
+
+__version__ = '0.7.0'
diff --git a/lib/ptyprocess/_fork_pty.py b/lib/ptyprocess/_fork_pty.py
new file mode 100644
index 0000000..a8d05fe
--- /dev/null
+++ b/lib/ptyprocess/_fork_pty.py
@@ -0,0 +1,78 @@
+"""Substitute for the forkpty system call, to support Solaris.
+"""
+import os
+import errno
+
+from pty import (STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO, CHILD)
+from .util import PtyProcessError
+
+def fork_pty():
+ '''This implements a substitute for the forkpty system call. This
+ should be more portable than the pty.fork() function. Specifically,
+ this should work on Solaris.
+
+ Modified 10.06.05 by Geoff Marshall: Implemented __fork_pty() method to
+ resolve the issue with Python's pty.fork() not supporting Solaris,
+ particularly ssh. Based on patch to posixmodule.c authored by Noah
+ Spurrier::
+
+ http://mail.python.org/pipermail/python-dev/2003-May/035281.html
+
+ '''
+
+ parent_fd, child_fd = os.openpty()
+ if parent_fd < 0 or child_fd < 0:
+ raise OSError("os.openpty() failed")
+
+ pid = os.fork()
+ if pid == CHILD:
+ # Child.
+ os.close(parent_fd)
+ pty_make_controlling_tty(child_fd)
+
+ os.dup2(child_fd, STDIN_FILENO)
+ os.dup2(child_fd, STDOUT_FILENO)
+ os.dup2(child_fd, STDERR_FILENO)
+
+ else:
+ # Parent.
+ os.close(child_fd)
+
+ return pid, parent_fd
+
+def pty_make_controlling_tty(tty_fd):
+ '''This makes the pseudo-terminal the controlling tty. This should be
+ more portable than the pty.fork() function. Specifically, this should
+ work on Solaris. '''
+
+ child_name = os.ttyname(tty_fd)
+
+ # Disconnect from controlling tty, if any. Raises OSError of ENXIO
+ # if there was no controlling tty to begin with, such as when
+ # executed by a cron(1) job.
+ try:
+ fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
+ os.close(fd)
+ except OSError as err:
+ if err.errno != errno.ENXIO:
+ raise
+
+ os.setsid()
+
+ # Verify we are disconnected from controlling tty by attempting to open
+ # it again. We expect that OSError of ENXIO should always be raised.
+ try:
+ fd = os.open("/dev/tty", os.O_RDWR | os.O_NOCTTY)
+ os.close(fd)
+ raise PtyProcessError("OSError of errno.ENXIO should be raised.")
+ except OSError as err:
+ if err.errno != errno.ENXIO:
+ raise
+
+ # Verify we can open child pty.
+ fd = os.open(child_name, os.O_RDWR)
+ os.close(fd)
+
+ # Verify we now have a controlling tty.
+ fd = os.open("/dev/tty", os.O_WRONLY)
+ os.close(fd)
diff --git a/lib/ptyprocess/ptyprocess.py b/lib/ptyprocess/ptyprocess.py
new file mode 100644
index 0000000..78d19fd
--- /dev/null
+++ b/lib/ptyprocess/ptyprocess.py
@@ -0,0 +1,842 @@
+import codecs
+import errno
+import fcntl
+import io
+import os
+import pty
+import resource
+import signal
+import struct
+import sys
+import termios
+import time
+
+try:
+ import builtins # Python 3
+except ImportError:
+ import __builtin__ as builtins # Python 2
+
+# Constants
+from pty import (STDIN_FILENO, CHILD)
+
+from .util import which, PtyProcessError
+
+_platform = sys.platform.lower()
+
+# Solaris uses internal __fork_pty(). All others use pty.fork().
+_is_solaris = (
+ _platform.startswith('solaris') or
+ _platform.startswith('sunos'))
+
+if _is_solaris:
+ use_native_pty_fork = False
+ from . import _fork_pty
+else:
+ use_native_pty_fork = True
+
+PY3 = sys.version_info[0] >= 3
+
+if PY3:
+ def _byte(i):
+ return bytes([i])
+else:
+ def _byte(i):
+ return chr(i)
+
+ class FileNotFoundError(OSError): pass
+ class TimeoutError(OSError): pass
+
+_EOF, _INTR = None, None
+
+def _make_eof_intr():
+ """Set constants _EOF and _INTR.
+
+ This avoids doing potentially costly operations on module load.
+ """
+ global _EOF, _INTR
+ if (_EOF is not None) and (_INTR is not None):
+ return
+
+ # inherit EOF and INTR definitions from controlling process.
+ try:
+ from termios import VEOF, VINTR
+ fd = None
+ for name in 'stdin', 'stdout':
+ stream = getattr(sys, '__%s__' % name, None)
+ if stream is None or not hasattr(stream, 'fileno'):
+ continue
+ try:
+ fd = stream.fileno()
+ except ValueError:
+ continue
+ if fd is None:
+ # no fd, raise ValueError to fallback on CEOF, CINTR
+ raise ValueError("No stream has a fileno")
+ intr = ord(termios.tcgetattr(fd)[6][VINTR])
+ eof = ord(termios.tcgetattr(fd)[6][VEOF])
+ except (ImportError, OSError, IOError, ValueError, termios.error):
+ # unless the controlling process is also not a terminal,
+ # such as cron(1), or when stdin and stdout are both closed.
+ # Fall back to using CEOF and CINTR.
+ try:
+ from termios import CEOF, CINTR
+ (intr, eof) = (CINTR, CEOF)
+ except ImportError:
+ # ^C, ^D
+ (intr, eof) = (3, 4)
+
+ _INTR = _byte(intr)
+ _EOF = _byte(eof)
+
+# setecho and setwinsize are pulled out here because on some platforms, we need
+# to do this from the child before we exec()
+
+def _setecho(fd, state):
+ errmsg = 'setecho() may not be called on this platform (it may still be possible to enable/disable echo when spawning the child process)'
+
+ try:
+ attr = termios.tcgetattr(fd)
+ except termios.error as err:
+ if err.args[0] == errno.EINVAL:
+ raise IOError(err.args[0], '%s: %s.' % (err.args[1], errmsg))
+ raise
+
+ if state:
+ attr[3] = attr[3] | termios.ECHO
+ else:
+ attr[3] = attr[3] & ~termios.ECHO
+
+ try:
+ # I tried TCSADRAIN and TCSAFLUSH, but these were inconsistent and
+ # blocked on some platforms. TCSADRAIN would probably be ideal.
+ termios.tcsetattr(fd, termios.TCSANOW, attr)
+ except IOError as err:
+ if err.args[0] == errno.EINVAL:
+ raise IOError(err.args[0], '%s: %s.' % (err.args[1], errmsg))
+ raise
+
+def _setwinsize(fd, rows, cols):
+ # Some very old platforms have a bug that causes the value for
+ # termios.TIOCSWINSZ to be truncated. There was a hack here to work
+ # around this, but it caused problems with newer platforms so has been
+ # removed. For details see https://github.com/pexpect/pexpect/issues/39
+ TIOCSWINSZ = getattr(termios, 'TIOCSWINSZ', -2146929561)
+ # Note, assume ws_xpixel and ws_ypixel are zero.
+ s = struct.pack('HHHH', rows, cols, 0, 0)
+ fcntl.ioctl(fd, TIOCSWINSZ, s)
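+# For illustration only (a sketch, not part of this module's helpers): reading
+# the window size back uses the same 'HHHH' struct layout, with TIOCGWINSZ
+# instead of TIOCSWINSZ; the numeric fallback below is the BSD-style value and
+# is an assumption, like the one used above.
+#
+#     def _getwinsize(fd):
+#         TIOCGWINSZ = getattr(termios, 'TIOCGWINSZ', 1074295912)
+#         s = struct.pack('HHHH', 0, 0, 0, 0)
+#         rows, cols, _, _ = struct.unpack(
+#             'HHHH', fcntl.ioctl(fd, TIOCGWINSZ, s))
+#         return rows, cols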
+
+class PtyProcess(object):
+ '''This class represents a process running in a pseudoterminal.
+
+ The main constructor is the :meth:`spawn` classmethod.
+ '''
+ string_type = bytes
+ if PY3:
+ linesep = os.linesep.encode('ascii')
+ crlf = '\r\n'.encode('ascii')
+
+ @staticmethod
+ def write_to_stdout(b):
+ try:
+ return sys.stdout.buffer.write(b)
+ except AttributeError:
+ # If stdout has been replaced, it may not have .buffer
+ return sys.stdout.write(b.decode('ascii', 'replace'))
+ else:
+ linesep = os.linesep
+ crlf = '\r\n'
+ write_to_stdout = sys.stdout.write
+
+ encoding = None
+
+ argv = None
+ env = None
+ launch_dir = None
+
+ def __init__(self, pid, fd):
+ _make_eof_intr() # Ensure _EOF and _INTR are calculated
+ self.pid = pid
+ self.fd = fd
+ readf = io.open(fd, 'rb', buffering=0)
+ writef = io.open(fd, 'wb', buffering=0, closefd=False)
+ self.fileobj = io.BufferedRWPair(readf, writef)
+
+ self.terminated = False
+ self.closed = False
+ self.exitstatus = None
+ self.signalstatus = None
+ # status returned by os.waitpid
+ self.status = None
+ self.flag_eof = False
+ # Used by close() to give kernel time to update process status.
+ # Time in seconds.
+ self.delayafterclose = 0.1
+ # Used by terminate() to give kernel time to update process status.
+ # Time in seconds.
+ self.delayafterterminate = 0.1
+
+ @classmethod
+ def spawn(
+ cls, argv, cwd=None, env=None, echo=True, preexec_fn=None,
+ dimensions=(24, 80), pass_fds=()):
+ '''Start the given command in a child process in a pseudo terminal.
+
+ This does all the fork/exec type of stuff for a pty, and returns an
+ instance of PtyProcess.
+
+ If preexec_fn is supplied, it will be called with no arguments in the
+ child process before exec-ing the specified command.
+ It may, for instance, set signal handlers to SIG_DFL or SIG_IGN.
+
+ Dimensions of the pseudoterminal used for the subprocess can be
+ specified as a tuple (rows, cols), or the default (24, 80) will be used.
+
+ By default, all file descriptors except 0, 1 and 2 are closed. This
+ behavior can be overridden with pass_fds, a list of file descriptors to
+ keep open between the parent and the child.
+ '''
+ # Note that it is difficult for this method to fail.
+ # You cannot detect if the child process cannot start.
+ # So the only way you can tell if the child process started
+ # or not is to try to read from the file descriptor. If you get
+ # EOF immediately then it means that the child is already dead.
+ # That may not necessarily be bad because you may have spawned a child
+ # that performs some task; creates no stdout output; and then dies.
+
+ if not isinstance(argv, (list, tuple)):
+ raise TypeError("Expected a list or tuple for argv, got %r" % argv)
+
+ # Shallow copy of argv so we can modify it
+ argv = argv[:]
+ command = argv[0]
+
+ command_with_path = which(command)
+ if command_with_path is None:
+ raise FileNotFoundError('The command was not found or was not ' +
+ 'executable: %s.' % command)
+ command = command_with_path
+ argv[0] = command
+
+ # [issue #119] To prevent the case where exec fails and the user is
+ # stuck interacting with a python child process instead of whatever
+ # was expected, we implement the solution from
+ # http://stackoverflow.com/a/3703179 to pass the exception to the
+ # parent process
+
+ # [issue #119] 1. Before forking, open a pipe in the parent process.
+ exec_err_pipe_read, exec_err_pipe_write = os.pipe()
+
+ if use_native_pty_fork:
+ pid, fd = pty.fork()
+ else:
+ # Use internal fork_pty, for Solaris
+ pid, fd = _fork_pty.fork_pty()
+
+ # Some platforms must call setwinsize() and setecho() from the
+ # child process, and others from the master process. We do both,
+ # allowing IOError for either.
+
+ if pid == CHILD:
+ # set window size
+ try:
+ _setwinsize(STDIN_FILENO, *dimensions)
+ except IOError as err:
+ if err.args[0] not in (errno.EINVAL, errno.ENOTTY):
+ raise
+
+ # disable echo if spawn argument echo was unset
+ if not echo:
+ try:
+ _setecho(STDIN_FILENO, False)
+ except (IOError, termios.error) as err:
+ if err.args[0] not in (errno.EINVAL, errno.ENOTTY):
+ raise
+
+ # [issue #119] 3. The child closes the reading end and sets the
+ # close-on-exec flag for the writing end.
+ os.close(exec_err_pipe_read)
+ fcntl.fcntl(exec_err_pipe_write, fcntl.F_SETFD, fcntl.FD_CLOEXEC)
+
+ # Do not allow child to inherit open file descriptors from parent,
+ # with the exception of the exec_err_pipe_write of the pipe
+ # and pass_fds.
+ # Impose ceiling on max_fd: AIX bugfix for users with unlimited
+ # nofiles where resource.RLIMIT_NOFILE is 2^63-1 and os.closerange()
+ # occasionally raises out of range error
+ max_fd = min(1048576, resource.getrlimit(resource.RLIMIT_NOFILE)[0])
+ spass_fds = sorted(set(pass_fds) | {exec_err_pipe_write})
+ for pair in zip([2] + spass_fds, spass_fds + [max_fd]):
+ os.closerange(pair[0]+1, pair[1])
+
+ if cwd is not None:
+ os.chdir(cwd)
+
+ if preexec_fn is not None:
+ try:
+ preexec_fn()
+ except Exception as e:
+ ename = type(e).__name__
+ tosend = '{}:0:{}'.format(ename, str(e))
+ if PY3:
+ tosend = tosend.encode('utf-8')
+
+ os.write(exec_err_pipe_write, tosend)
+ os.close(exec_err_pipe_write)
+ os._exit(1)
+
+ try:
+ if env is None:
+ os.execv(command, argv)
+ else:
+ os.execvpe(command, argv, env)
+ except OSError as err:
+ # [issue #119] 5. If exec fails, the child writes the error
+ # code back to the parent using the pipe, then exits.
+ tosend = 'OSError:{}:{}'.format(err.errno, str(err))
+ if PY3:
+ tosend = tosend.encode('utf-8')
+ os.write(exec_err_pipe_write, tosend)
+ os.close(exec_err_pipe_write)
+ os._exit(os.EX_OSERR)
+
+ # Parent
+ inst = cls(pid, fd)
+
+ # Set some informational attributes
+ inst.argv = argv
+ if env is not None:
+ inst.env = env
+ if cwd is not None:
+ inst.launch_dir = cwd
+
+ # [issue #119] 2. After forking, the parent closes the writing end
+ # of the pipe and reads from the reading end.
+ os.close(exec_err_pipe_write)
+ exec_err_data = os.read(exec_err_pipe_read, 4096)
+ os.close(exec_err_pipe_read)
+
+ # [issue #119] 6. The parent reads eof (a zero-length read) if the
+ # child successfully performed exec, since close-on-exec made
+ # successful exec close the writing end of the pipe. Or, if exec
+ # failed, the parent reads the error code and can proceed
+ # accordingly. Either way, the parent blocks until the child calls
+ # exec.
+ if len(exec_err_data) != 0:
+ try:
+ errclass, errno_s, errmsg = exec_err_data.split(b':', 2)
+ exctype = getattr(builtins, errclass.decode('ascii'), Exception)
+
+ exception = exctype(errmsg.decode('utf-8', 'replace'))
+ if exctype is OSError:
+ exception.errno = int(errno_s)
+ except:
+ raise Exception('Subprocess failed, got bad error data: %r'
+ % exec_err_data)
+ else:
+ raise exception
+
+ try:
+ inst.setwinsize(*dimensions)
+ except IOError as err:
+ if err.args[0] not in (errno.EINVAL, errno.ENOTTY, errno.ENXIO):
+ raise
+
+ return inst
+
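As a usage sketch of the spawn()/read()/wait() flow above (illustrative only, not part of this commit; it assumes the package is importable as ptyprocess and that an 'echo' binary is on PATH):

    from ptyprocess import PtyProcess

    # Spawn a child on a fresh pty; dimensions is (rows, cols).
    p = PtyProcess.spawn(['echo', 'hello'], dimensions=(24, 80))
    try:
        while True:
            print(p.read(1024))   # raw bytes from the pty; blocks until data arrives
    except EOFError:
        pass                      # raised once the child side of the pty closes
    status = p.wait()             # reap the child and collect its exit status
    p.close()
    print(status)                 # 0 on success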
+ def __repr__(self):
+ clsname = type(self).__name__
+ if self.argv is not None:
+ args = [repr(self.argv)]
+ if self.env is not None:
+ args.append("env=%r" % self.env)
+ if self.launch_dir is not None:
+ args.append("cwd=%r" % self.launch_dir)
+
+ return "{}.spawn({})".format(clsname, ", ".join(args))
+
+ else:
+ return "{}(pid={}, fd={})".format(clsname, self.pid, self.fd)
+
+ @staticmethod
+ def _coerce_send_string(s):
+ if not isinstance(s, bytes):
+ return s.encode('utf-8')
+ return s
+
+ @staticmethod
+ def _coerce_read_string(s):
+ return s
+
+ def __del__(self):
+ '''This makes sure that no system resources are left open. Python only
+ garbage collects Python objects. OS file descriptors are not Python
+ objects, so they must be handled explicitly. If the child file
+ descriptor was opened outside of this class (passed to the constructor)
+ then this does not close it. '''
+
+ if not self.closed:
+ # It is possible for __del__ methods to execute during the
+ # teardown of the Python VM itself. Thus self.close() may
+ # trigger an exception because os.close may be None.
+ try:
+ self.close()
+ # which exception, shouldn't we catch explicitly .. ?
+ except:
+ pass
+
+
+ def fileno(self):
+ '''This returns the file descriptor of the pty for the child.
+ '''
+ return self.fd
+
+ def close(self, force=True):
+ '''This closes the connection with the child application. Note that
+ calling close() more than once is valid. This emulates standard Python
+ behavior with files. Set force to True if you want to make sure that
+ the child is terminated (SIGKILL is sent if the child ignores SIGHUP
+ and SIGINT). '''
+ if not self.closed:
+ self.flush()
+ self.fileobj.close() # Closes the file descriptor
+ # Give kernel time to update process status.
+ time.sleep(self.delayafterclose)
+ if self.isalive():
+ if not self.terminate(force):
+ raise PtyProcessError('Could not terminate the child.')
+ self.fd = -1
+ self.closed = True
+ #self.pid = None
+
+ def flush(self):
+ '''This does nothing. It is here to support the interface for a
+ File-like object. '''
+
+ pass
+
+ def isatty(self):
+ '''This returns True if the file descriptor is open and connected to a
+ tty(-like) device, else False.
+
+ On SVR4-style platforms implementing streams, such as SunOS and HP-UX,
+ the child pty may not appear as a terminal device. This means
+ methods such as setecho(), setwinsize(), getwinsize() may raise an
+ IOError. '''
+
+ return os.isatty(self.fd)
+
+ def waitnoecho(self, timeout=None):
+ '''This waits until the terminal ECHO flag is set False. This returns
+ True if the echo mode is off. This returns False if the ECHO flag was
+ not set False before the timeout. This can be used to detect when the
+ child is waiting for a password. Usually a child application will turn
+ off echo mode when it is waiting for the user to enter a password. For
+ example, instead of expecting the "password:" prompt you can wait for
+ the child to set ECHO off::
+
+ p = pexpect.spawn('ssh user@example.com')
+ p.waitnoecho()
+ p.sendline(mypassword)
+
+ If timeout is None then this method blocks until the ECHO flag is False.
+ '''
+
+ if timeout is not None:
+ end_time = time.time() + timeout
+ while True:
+ if not self.getecho():
+ return True
+ if timeout is not None and timeout < 0:
+ return False
+ if timeout is not None:
+ timeout = end_time - time.time()
+ time.sleep(0.1)
+
+ def getecho(self):
+ '''This returns the terminal echo mode. This returns True if echo is
+ on or False if echo is off. Child applications that are expecting you
+ to enter a password often set ECHO False. See waitnoecho().
+
+ Not supported on platforms where ``isatty()`` returns False. '''
+
+ try:
+ attr = termios.tcgetattr(self.fd)
+ except termios.error as err:
+ errmsg = 'getecho() may not be called on this platform'
+ if err.args[0] == errno.EINVAL:
+ raise IOError(err.args[0], '%s: %s.' % (err.args[1], errmsg))
+ raise
+
+ self.echo = bool(attr[3] & termios.ECHO)
+ return self.echo
+
+ def setecho(self, state):
+ '''This sets the terminal echo mode on or off. Note that anything the
+ child sent before the echo will be lost, so you should be sure that
+ your input buffer is empty before you call setecho(). For example, the
+ following will work as expected::
+
+ p = pexpect.spawn('cat') # Echo is on by default.
+ p.sendline('1234') # We expect to see this twice from the child...
+ p.expect(['1234']) # ... once from the tty echo...
+ p.expect(['1234']) # ... and again from cat itself.
+ p.setecho(False) # Turn off tty echo
+ p.sendline('abcd') # We will see this only once (echoed by cat).
+ p.sendline('wxyz') # We will see this only once (echoed by cat).
+ p.expect(['abcd'])
+ p.expect(['wxyz'])
+
+ The following WILL NOT WORK because the lines sent before the setecho
+ will be lost::
+
+ p = pexpect.spawn('cat')
+ p.sendline('1234')
+ p.setecho(False) # Turn off tty echo
+ p.sendline('abcd') # We will see this only once (echoed by cat).
+ p.sendline('wxyz') # We will see this only once (echoed by cat).
+ p.expect(['1234'])
+ p.expect(['1234'])
+ p.expect(['abcd'])
+ p.expect(['wxyz'])
+
+
+ Not supported on platforms where ``isatty()`` returns False.
+ '''
+ _setecho(self.fd, state)
+
+ self.echo = state
+
+ def read(self, size=1024):
+ """Read and return at most ``size`` bytes from the pty.
+
+ Can block if there is nothing to read. Raises :exc:`EOFError` if the
+ terminal was closed.
+
+ Unlike Pexpect's ``read_nonblocking`` method, this doesn't try to deal
+ with the vagaries of EOF on platforms that do strange things, like IRIX
+ or older Solaris systems. It handles the errno=EIO pattern used on
+ Linux, and the empty-string return used on BSD platforms and (seemingly)
+ on recent Solaris.
+ """
+ try:
+ s = self.fileobj.read1(size)
+ except (OSError, IOError) as err:
+ if err.args[0] == errno.EIO:
+ # Linux-style EOF
+ self.flag_eof = True
+ raise EOFError('End Of File (EOF). Exception style platform.')
+ raise
+ if s == b'':
+ # BSD-style EOF (also appears to work on recent Solaris (OpenIndiana))
+ self.flag_eof = True
+ raise EOFError('End Of File (EOF). Empty string style platform.')
+
+ return s
+
+ def readline(self):
+ """Read one line from the pseudoterminal, and return it as unicode.
+
+ Can block if there is nothing to read. Raises :exc:`EOFError` if the
+ terminal was closed.
+ """
+ try:
+ s = self.fileobj.readline()
+ except (OSError, IOError) as err:
+ if err.args[0] == errno.EIO:
+ # Linux-style EOF
+ self.flag_eof = True
+ raise EOFError('End Of File (EOF). Exception style platform.')
+ raise
+ if s == b'':
+ # BSD-style EOF (also appears to work on recent Solaris (OpenIndiana))
+ self.flag_eof = True
+ raise EOFError('End Of File (EOF). Empty string style platform.')
+
+ return s
+
+ def _writeb(self, b, flush=True):
+ n = self.fileobj.write(b)
+ if flush:
+ self.fileobj.flush()
+ return n
+
+ def write(self, s, flush=True):
+ """Write bytes to the pseudoterminal.
+
+ Returns the number of bytes written.
+ """
+ return self._writeb(s, flush=flush)
+
+ def sendcontrol(self, char):
+ '''Helper method that wraps send() with mnemonic access for sending control
+ character to the child (such as Ctrl-C or Ctrl-D). For example, to send
+ Ctrl-G (ASCII 7, bell, '\a')::
+
+ child.sendcontrol('g')
+
+ See also, sendintr() and sendeof().
+ '''
+ char = char.lower()
+ a = ord(char)
+ if 97 <= a <= 122:
+ a = a - ord('a') + 1
+ byte = _byte(a)
+ return self._writeb(byte), byte
+ d = {'@': 0, '`': 0,
+ '[': 27, '{': 27,
+ '\\': 28, '|': 28,
+ ']': 29, '}': 29,
+ '^': 30, '~': 30,
+ '_': 31,
+ '?': 127}
+ if char not in d:
+ return 0, b''
+
+ byte = _byte(d[char])
+ return self._writeb(byte), byte
+
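The control-character arithmetic in sendcontrol() above can be illustrated with a small, self-contained sketch (the helper name is hypothetical): letters map to ord(ch) - ord('a') + 1, and the punctuation table covers the remaining ASCII control codes.

    # Hypothetical helper mirroring sendcontrol()'s mapping; 'c' -> b'\x03' (Ctrl-C).
    def control_byte(char):
        char = char.lower()
        a = ord(char)
        if ord('a') <= a <= ord('z'):
            return bytes([a - ord('a') + 1])
        specials = {'@': 0, '`': 0, '[': 27, '{': 27, '\\': 28, '|': 28,
                    ']': 29, '}': 29, '^': 30, '~': 30, '_': 31, '?': 127}
        return bytes([specials[char]]) if char in specials else b''

    assert control_byte('c') == b'\x03'   # ETX, the usual SIGINT character
    assert control_byte('d') == b'\x04'   # EOT, the usual EOF character
    assert control_byte('[') == b'\x1b'   # ESC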
+ def sendeof(self):
+ '''This sends an EOF to the child. This sends a character which causes
+ the pending parent output buffer to be sent to the waiting child
+ program without waiting for end-of-line. If it is the first character
+ of the line, the read() in the user program returns 0, which signifies
+ end-of-file. This means that, to work as expected, sendeof() has to be
+ called at the beginning of a line. This method does not send a newline.
+ It is the responsibility of the caller to ensure the eof is sent at the
+ beginning of a line. '''
+
+ return self._writeb(_EOF), _EOF
+
+ def sendintr(self):
+ '''This sends a SIGINT to the child. It does not require
+ the SIGINT to be the first character on a line. '''
+
+ return self._writeb(_INTR), _INTR
+
+ def eof(self):
+ '''This returns True if the EOF exception was ever raised.
+ '''
+
+ return self.flag_eof
+
+ def terminate(self, force=False):
+ '''This forces a child process to terminate. It starts nicely with
+ SIGHUP and SIGINT. If "force" is True then moves onto SIGKILL. This
+ returns True if the child was terminated. This returns False if the
+ child could not be terminated. '''
+
+ if not self.isalive():
+ return True
+ try:
+ self.kill(signal.SIGHUP)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ self.kill(signal.SIGCONT)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ self.kill(signal.SIGINT)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ if force:
+ self.kill(signal.SIGKILL)
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ else:
+ return False
+ return False
+ except OSError:
+ # I think there are kernel timing issues that sometimes cause
+ # this to happen. I think isalive() reports True, but the
+ # process is dead to the kernel.
+ # Make one last attempt to see if the kernel is up to date.
+ time.sleep(self.delayafterterminate)
+ if not self.isalive():
+ return True
+ else:
+ return False
+
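A hedged usage sketch of the escalation terminate() implements (SIGHUP, SIGCONT, SIGINT, then SIGKILL when force=True); the child command and expectations are illustrative only:

    from ptyprocess import PtyProcess

    p = PtyProcess.spawn(['sleep', '60'])
    # Try the polite signals first; fall back to SIGKILL if they are ignored.
    if not p.terminate():
        p.terminate(force=True)
    print(p.isalive())   # expected: False once the child has been reaped
    p.close()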
+ def wait(self):
+ '''This waits until the child exits. This is a blocking call. This will
+ not read any data from the child, so this will block forever if the
+ child has unread output and has terminated. In other words, the child
+ may have printed output then called exit(), but, the child is
+ technically still alive until its output is read by the parent. '''
+
+ if self.isalive():
+ pid, status = os.waitpid(self.pid, 0)
+ else:
+ return self.exitstatus
+ self.exitstatus = os.WEXITSTATUS(status)
+ if os.WIFEXITED(status):
+ self.status = status
+ self.exitstatus = os.WEXITSTATUS(status)
+ self.signalstatus = None
+ self.terminated = True
+ elif os.WIFSIGNALED(status):
+ self.status = status
+ self.exitstatus = None
+ self.signalstatus = os.WTERMSIG(status)
+ self.terminated = True
+ elif os.WIFSTOPPED(status): # pragma: no cover
+ # You can't call wait() on a child process in the stopped state.
+ raise PtyProcessError('Called wait() on a stopped child ' +
+ 'process. This is not supported. Is some other ' +
+ 'process attempting job control with our child pid?')
+ return self.exitstatus
+
+ def isalive(self):
+ '''This tests if the child process is running or not. This is
+ non-blocking. If the child was terminated then this will read the
+ exitstatus or signalstatus of the child. This returns True if the child
+ process appears to be running or False if not. It can take literally
+ SECONDS for Solaris to return the right status. '''
+
+ if self.terminated:
+ return False
+
+ if self.flag_eof:
+ # This is for Linux, which requires the blocking form
+ # of waitpid to get the status of a defunct process.
+ # This is super-lame. The flag_eof would have been set
+ # in read_nonblocking(), so this should be safe.
+ waitpid_options = 0
+ else:
+ waitpid_options = os.WNOHANG
+
+ try:
+ pid, status = os.waitpid(self.pid, waitpid_options)
+ except OSError as e:
+ # No child processes
+ if e.errno == errno.ECHILD:
+ raise PtyProcessError('isalive() encountered condition ' +
+ 'where "terminated" is 0, but there was no child ' +
+ 'process. Did someone else call waitpid() ' +
+ 'on our process?')
+ else:
+ raise
+
+ # I have to do this twice for Solaris.
+ # I can't even believe that I figured this out...
+ # If waitpid() returns 0 it means that no child process
+ # wishes to report, and the value of status is undefined.
+ if pid == 0:
+ try:
+ ### os.WNOHANG) # Solaris!
+ pid, status = os.waitpid(self.pid, waitpid_options)
+ except OSError as e: # pragma: no cover
+ # This should never happen...
+ if e.errno == errno.ECHILD:
+ raise PtyProcessError('isalive() encountered condition ' +
+ 'that should never happen. There was no child ' +
+ 'process. Did someone else call waitpid() ' +
+ 'on our process?')
+ else:
+ raise
+
+ # If pid is still 0 after two calls to waitpid() then the process
+ # really is alive. This seems to work on all platforms, except for
+ # Irix which seems to require a blocking call on waitpid or select,
+ # so I let read_nonblocking take care of this situation
+ # (unfortunately, this requires waiting through the timeout).
+ if pid == 0:
+ return True
+
+ if os.WIFEXITED(status):
+ self.status = status
+ self.exitstatus = os.WEXITSTATUS(status)
+ self.signalstatus = None
+ self.terminated = True
+ elif os.WIFSIGNALED(status):
+ self.status = status
+ self.exitstatus = None
+ self.signalstatus = os.WTERMSIG(status)
+ self.terminated = True
+ elif os.WIFSTOPPED(status):
+ raise PtyProcessError('isalive() encountered condition ' +
+ 'where child process is stopped. This is not ' +
+ 'supported. Is some other process attempting ' +
+ 'job control with our child pid?')
+ return False
+
+ def kill(self, sig):
+ """Send the given signal to the child application.
+
+ In keeping with UNIX tradition it has a misleading name. It does not
+ necessarily kill the child unless you send the right signal. See the
+ :mod:`signal` module for constants representing signal numbers.
+ """
+
+ # Same as os.kill, but the pid is given for you.
+ if self.isalive():
+ os.kill(self.pid, sig)
+
+ def getwinsize(self):
+ """Return the window size of the pseudoterminal as a tuple (rows, cols).
+ """
+ TIOCGWINSZ = getattr(termios, 'TIOCGWINSZ', 1074295912)
+ s = struct.pack('HHHH', 0, 0, 0, 0)
+ x = fcntl.ioctl(self.fd, TIOCGWINSZ, s)
+ return struct.unpack('HHHH', x)[0:2]
+
+ def setwinsize(self, rows, cols):
+ """Set the terminal window size of the child tty.
+
+ This will cause a SIGWINCH signal to be sent to the child. This does not
+ change the physical window size. It changes the size reported to
+ TTY-aware applications like vi or curses -- applications that respond to
+ the SIGWINCH signal.
+ """
+ return _setwinsize(self.fd, rows, cols)
+
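A short round-trip sketch of the TIOCSWINSZ/TIOCGWINSZ ioctls behind setwinsize()/getwinsize() (illustrative only):

    from ptyprocess import PtyProcess

    p = PtyProcess.spawn(['cat'])
    p.setwinsize(40, 132)     # delivers SIGWINCH to the child's process group
    print(p.getwinsize())     # expected: (40, 132)
    p.close(force=True)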
+
+class PtyProcessUnicode(PtyProcess):
+ """Unicode wrapper around a process running in a pseudoterminal.
+
+ This class exposes a similar interface to :class:`PtyProcess`, but its read
+ methods return unicode, and its :meth:`write` accepts unicode.
+ """
+ if PY3:
+ string_type = str
+ else:
+ string_type = unicode # analysis:ignore
+
+ def __init__(self, pid, fd, encoding='utf-8', codec_errors='strict'):
+ super(PtyProcessUnicode, self).__init__(pid, fd)
+ self.encoding = encoding
+ self.codec_errors = codec_errors
+ self.decoder = codecs.getincrementaldecoder(encoding)(errors=codec_errors)
+
+ def read(self, size=1024):
+ """Read at most ``size`` bytes from the pty, return them as unicode.
+
+ Can block if there is nothing to read. Raises :exc:`EOFError` if the
+ terminal was closed.
+
+ The size argument still refers to bytes, not unicode code points.
+ """
+ b = super(PtyProcessUnicode, self).read(size)
+ return self.decoder.decode(b, final=False)
+
+ def readline(self):
+ """Read one line from the pseudoterminal, and return it as unicode.
+
+ Can block if there is nothing to read. Raises :exc:`EOFError` if the
+ terminal was closed.
+ """
+ b = super(PtyProcessUnicode, self).readline()
+ return self.decoder.decode(b, final=False)
+
+ def write(self, s):
+ """Write the unicode string ``s`` to the pseudoterminal.
+
+ Returns the number of bytes written.
+ """
+ b = s.encode(self.encoding)
+ return super(PtyProcessUnicode, self).write(b)
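To contrast with the byte-oriented base class, a minimal PtyProcessUnicode sketch (illustrative only; assumes a 'cat' binary on PATH): reads come back as str, decoded incrementally with the chosen codec, and write() accepts str.

    from ptyprocess import PtyProcessUnicode

    p = PtyProcessUnicode.spawn(['cat'], echo=False)
    p.write(u'h\u00e9llo\n')   # str in; encoded with the instance encoding (UTF-8)
    print(p.read(1024))        # str out; typically 'h\u00e9llo\r\n' (the tty adds \r)
    p.sendeof()                # Ctrl-D at line start, so cat exits
    p.close()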
diff --git a/lib/ptyprocess/util.py b/lib/ptyprocess/util.py
new file mode 100644
index 0000000..aadbd62
--- /dev/null
+++ b/lib/ptyprocess/util.py
@@ -0,0 +1,71 @@
+try:
+ from shutil import which # Python >= 3.3
+except ImportError:
+ import os, sys
+
+ # This is copied from Python 3.4.1
+ def which(cmd, mode=os.F_OK | os.X_OK, path=None):
+ """Given a command, mode, and a PATH string, return the path which
+ conforms to the given mode on the PATH, or None if there is no such
+ file.
+
+ `mode` defaults to os.F_OK | os.X_OK. `path` defaults to the result
+ of os.environ.get("PATH"), or can be overridden with a custom search
+ path.
+
+ """
+ # Check that a given file can be accessed with the correct mode.
+ # Additionally check that `file` is not a directory, as on Windows
+ # directories pass the os.access check.
+ def _access_check(fn, mode):
+ return (os.path.exists(fn) and os.access(fn, mode)
+ and not os.path.isdir(fn))
+
+ # If we're given a path with a directory part, look it up directly rather
+ # than referring to PATH directories. This includes checking relative to the
+ # current directory, e.g. ./script
+ if os.path.dirname(cmd):
+ if _access_check(cmd, mode):
+ return cmd
+ return None
+
+ if path is None:
+ path = os.environ.get("PATH", os.defpath)
+ if not path:
+ return None
+ path = path.split(os.pathsep)
+
+ if sys.platform == "win32":
+ # The current directory takes precedence on Windows.
+ if not os.curdir in path:
+ path.insert(0, os.curdir)
+
+ # PATHEXT is necessary to check on Windows.
+ pathext = os.environ.get("PATHEXT", "").split(os.pathsep)
+ # See if the given file matches any of the expected path extensions.
+ # This will allow us to short circuit when given "python.exe".
+ # If it does match, only test that one, otherwise we have to try
+ # others.
+ if any(cmd.lower().endswith(ext.lower()) for ext in pathext):
+ files = [cmd]
+ else:
+ files = [cmd + ext for ext in pathext]
+ else:
+ # On other platforms you don't have things like PATHEXT to tell you
+ # what file suffixes are executable, so just pass on cmd as-is.
+ files = [cmd]
+
+ seen = set()
+ for dir in path:
+ normdir = os.path.normcase(dir)
+ if not normdir in seen:
+ seen.add(normdir)
+ for thefile in files:
+ name = os.path.join(dir, thefile)
+ if _access_check(name, mode):
+ return name
+ return None
+
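A tiny sketch of the which() fallback above (illustrative): a bare command name is resolved against PATH, a name containing a directory separator is checked directly, and None is returned when nothing executable matches.

    from ptyprocess.util import which

    print(which('sh'))                    # e.g. '/bin/sh' on most Unix systems
    print(which('no-such-command-xyz'))   # None
    print(which('./local-script'))        # checked directly, without a PATH search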
+
+class PtyProcessError(Exception):
+ """Generic error class for this package."""
diff --git a/lib/pycparser-2.21.dist-info/INSTALLER b/lib/pycparser-2.21.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/pycparser-2.21.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/pycparser-2.21.dist-info/LICENSE b/lib/pycparser-2.21.dist-info/LICENSE
new file mode 100644
index 0000000..ea215f2
--- /dev/null
+++ b/lib/pycparser-2.21.dist-info/LICENSE
@@ -0,0 +1,27 @@
+pycparser -- A C parser in Python
+
+Copyright (c) 2008-2020, Eli Bendersky
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+* Neither the name of Eli Bendersky nor the names of its contributors may
+ be used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/lib/pycparser-2.21.dist-info/METADATA b/lib/pycparser-2.21.dist-info/METADATA
new file mode 100644
index 0000000..1d0fbd6
--- /dev/null
+++ b/lib/pycparser-2.21.dist-info/METADATA
@@ -0,0 +1,31 @@
+Metadata-Version: 2.1
+Name: pycparser
+Version: 2.21
+Summary: C parser in Python
+Home-page: https://github.com/eliben/pycparser
+Author: Eli Bendersky
+Author-email: eliben@gmail.com
+Maintainer: Eli Bendersky
+License: BSD
+Platform: Cross Platform
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.4
+Classifier: Programming Language :: Python :: 3.5
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Requires-Python: >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*
+
+
+pycparser is a complete parser of the C language, written in
+pure Python using the PLY parsing library.
+It parses C code into an AST and can serve as a front-end for
+C compilers or analysis tools.
+
+
diff --git a/lib/pycparser-2.21.dist-info/RECORD b/lib/pycparser-2.21.dist-info/RECORD
new file mode 100644
index 0000000..e230498
--- /dev/null
+++ b/lib/pycparser-2.21.dist-info/RECORD
@@ -0,0 +1,41 @@
+pycparser-2.21.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+pycparser-2.21.dist-info/LICENSE,sha256=Pn3yW437ZYyakVAZMNTZQ7BQh6g0fH4rQyVhavU1BHs,1536
+pycparser-2.21.dist-info/METADATA,sha256=GvTEQA9yKj0nvP4mknfoGpMvjaJXCQjQANcQHrRrAxc,1108
+pycparser-2.21.dist-info/RECORD,,
+pycparser-2.21.dist-info/WHEEL,sha256=kGT74LWyRUZrL4VgLh6_g12IeVl_9u9ZVhadrgXZUEY,110
+pycparser-2.21.dist-info/top_level.txt,sha256=c-lPcS74L_8KoH7IE6PQF5ofyirRQNV4VhkbSFIPeWM,10
+pycparser/__init__.py,sha256=WUEp5D0fuHBH9Q8c1fYvR2eKWfj-CNghLf2MMlQLI1I,2815
+pycparser/__pycache__/__init__.cpython-39.pyc,,
+pycparser/__pycache__/_ast_gen.cpython-39.pyc,,
+pycparser/__pycache__/_build_tables.cpython-39.pyc,,
+pycparser/__pycache__/ast_transforms.cpython-39.pyc,,
+pycparser/__pycache__/c_ast.cpython-39.pyc,,
+pycparser/__pycache__/c_generator.cpython-39.pyc,,
+pycparser/__pycache__/c_lexer.cpython-39.pyc,,
+pycparser/__pycache__/c_parser.cpython-39.pyc,,
+pycparser/__pycache__/lextab.cpython-39.pyc,,
+pycparser/__pycache__/plyparser.cpython-39.pyc,,
+pycparser/__pycache__/yacctab.cpython-39.pyc,,
+pycparser/_ast_gen.py,sha256=0JRVnDW-Jw-3IjVlo8je9rbAcp6Ko7toHAnB5zi7h0Q,10555
+pycparser/_build_tables.py,sha256=oZCd3Plhq-vkV-QuEsaahcf-jUI6-HgKsrAL9gvFzuU,1039
+pycparser/_c_ast.cfg,sha256=ld5ezE9yzIJFIVAUfw7ezJSlMi4nXKNCzfmqjOyQTNo,4255
+pycparser/ast_transforms.py,sha256=GTMYlUgWmXd5wJVyovXY1qzzAqjxzCpVVg0664dKGBs,5691
+pycparser/c_ast.py,sha256=HWeOrfYdCY0u5XaYhE1i60uVyE3yMWdcxzECUX-DqJw,31445
+pycparser/c_generator.py,sha256=yi6Mcqxv88J5ue8k5-mVGxh3iJ37iD4QyF-sWcGjC-8,17772
+pycparser/c_lexer.py,sha256=xCpjIb6vOUebBJpdifidb08y7XgAsO3T1gNGXJT93-w,17167
+pycparser/c_parser.py,sha256=_8y3i52bL6SUK21KmEEl0qzHxe-0eZRzjZGkWg8gQ4A,73680
+pycparser/lextab.py,sha256=fIxBAHYRC418oKF52M7xb8_KMj3K-tHx0TzZiKwxjPM,8504
+pycparser/ply/__init__.py,sha256=q4s86QwRsYRa20L9ueSxfh-hPihpftBjDOvYa2_SS2Y,102
+pycparser/ply/__pycache__/__init__.cpython-39.pyc,,
+pycparser/ply/__pycache__/cpp.cpython-39.pyc,,
+pycparser/ply/__pycache__/ctokens.cpython-39.pyc,,
+pycparser/ply/__pycache__/lex.cpython-39.pyc,,
+pycparser/ply/__pycache__/yacc.cpython-39.pyc,,
+pycparser/ply/__pycache__/ygen.cpython-39.pyc,,
+pycparser/ply/cpp.py,sha256=UtC3ylTWp5_1MKA-PLCuwKQR8zSOnlGuGGIdzj8xS98,33282
+pycparser/ply/ctokens.py,sha256=MKksnN40TehPhgVfxCJhjj_BjL943apreABKYz-bl0Y,3177
+pycparser/ply/lex.py,sha256=7Qol57x702HZwjA3ZLp-84CUEWq1EehW-N67Wzghi-M,42918
+pycparser/ply/yacc.py,sha256=eatSDkRLgRr6X3-hoDk_SQQv065R0BdL2K7fQ54CgVM,137323
+pycparser/ply/ygen.py,sha256=2JYNeYtrPz1JzLSLO3d4GsS8zJU8jY_I_CR1VI9gWrA,2251
+pycparser/plyparser.py,sha256=8tLOoEytcapvWrr1JfCf7Dog-wulBtS1YrDs8S7JfMo,4875
+pycparser/yacctab.py,sha256=j_fVNIyDWDRVk7eWMqQtlBw2AwUSV5JTrtT58l7zis0,205652
diff --git a/lib/pycparser-2.21.dist-info/WHEEL b/lib/pycparser-2.21.dist-info/WHEEL
new file mode 100644
index 0000000..ef99c6c
--- /dev/null
+++ b/lib/pycparser-2.21.dist-info/WHEEL
@@ -0,0 +1,6 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.34.2)
+Root-Is-Purelib: true
+Tag: py2-none-any
+Tag: py3-none-any
+
diff --git a/lib/pycparser-2.21.dist-info/top_level.txt b/lib/pycparser-2.21.dist-info/top_level.txt
new file mode 100644
index 0000000..dc1c9e1
--- /dev/null
+++ b/lib/pycparser-2.21.dist-info/top_level.txt
@@ -0,0 +1 @@
+pycparser
diff --git a/lib/pycparser/__init__.py b/lib/pycparser/__init__.py
new file mode 100644
index 0000000..d82eb2d
--- /dev/null
+++ b/lib/pycparser/__init__.py
@@ -0,0 +1,90 @@
+#-----------------------------------------------------------------
+# pycparser: __init__.py
+#
+# This package file exports some convenience functions for
+# interacting with pycparser
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#-----------------------------------------------------------------
+__all__ = ['c_lexer', 'c_parser', 'c_ast']
+__version__ = '2.21'
+
+import io
+from subprocess import check_output
+from .c_parser import CParser
+
+
+def preprocess_file(filename, cpp_path='cpp', cpp_args=''):
+ """ Preprocess a file using cpp.
+
+ filename:
+ Name of the file you want to preprocess.
+
+ cpp_path:
+ cpp_args:
+ Refer to the documentation of parse_file for the meaning of these
+ arguments.
+
+ When successful, returns the preprocessed file's contents.
+ Errors from cpp will be printed out.
+ """
+ path_list = [cpp_path]
+ if isinstance(cpp_args, list):
+ path_list += cpp_args
+ elif cpp_args != '':
+ path_list += [cpp_args]
+ path_list += [filename]
+
+ try:
+ # Note the use of universal_newlines to treat all newlines
+ # as \n for Python's purpose
+ text = check_output(path_list, universal_newlines=True)
+ except OSError as e:
+ raise RuntimeError("Unable to invoke 'cpp'. " +
+ 'Make sure its path was passed correctly\n' +
+ ('Original error: %s' % e))
+
+ return text
+
+
+def parse_file(filename, use_cpp=False, cpp_path='cpp', cpp_args='',
+ parser=None):
+ """ Parse a C file using pycparser.
+
+ filename:
+ Name of the file you want to parse.
+
+ use_cpp:
+ Set to True if you want to execute the C pre-processor
+ on the file prior to parsing it.
+
+ cpp_path:
+ If use_cpp is True, this is the path to 'cpp' on your
+ system. If no path is provided, it attempts to just
+ execute 'cpp', so it must be in your PATH.
+
+ cpp_args:
+ If use_cpp is True, set this to the command line arguments strings
+ to cpp. Be careful with quotes - it's best to pass a raw string
+ (r'') here. For example:
+ r'-I../utils/fake_libc_include'
+ If several arguments are required, pass a list of strings.
+
+ parser:
+ Optional parser object to be used instead of the default CParser
+
+ When successful, an AST is returned. ParseError can be
+ thrown if the file doesn't parse successfully.
+
+ Errors from cpp will be printed out.
+ """
+ if use_cpp:
+ text = preprocess_file(filename, cpp_path, cpp_args)
+ else:
+ with io.open(filename) as f:
+ text = f.read()
+
+ if parser is None:
+ parser = CParser()
+ return parser.parse(text, filename)
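For context, a minimal sketch that drives parse_file() together with a NodeVisitor subclass (illustrative only; 'example.c' and the visitor are hypothetical). With use_cpp=True the file is first run through the 'cpp' found on PATH:

    from pycparser import c_ast, parse_file

    class FuncDefVisitor(c_ast.NodeVisitor):
        # Report every function definition encountered in the translation unit.
        def visit_FuncDef(self, node):
            print('%s defined at %s' % (node.decl.name, node.coord))

    ast = parse_file('example.c', use_cpp=True,
                     cpp_args=r'-I../utils/fake_libc_include')
    FuncDefVisitor().visit(ast)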
diff --git a/lib/pycparser/_ast_gen.py b/lib/pycparser/_ast_gen.py
new file mode 100644
index 0000000..0f7d330
--- /dev/null
+++ b/lib/pycparser/_ast_gen.py
@@ -0,0 +1,336 @@
+#-----------------------------------------------------------------
+# _ast_gen.py
+#
+# Generates the AST Node classes from a specification given in
+# a configuration file
+#
+# The design of this module was inspired by astgen.py from the
+# Python 2.5 code-base.
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#-----------------------------------------------------------------
+from string import Template
+
+
+class ASTCodeGenerator(object):
+ def __init__(self, cfg_filename='_c_ast.cfg'):
+ """ Initialize the code generator from a configuration
+ file.
+ """
+ self.cfg_filename = cfg_filename
+ self.node_cfg = [NodeCfg(name, contents)
+ for (name, contents) in self.parse_cfgfile(cfg_filename)]
+
+ def generate(self, file=None):
+ """ Generates the code into file, an open file buffer.
+ """
+ src = Template(_PROLOGUE_COMMENT).substitute(
+ cfg_filename=self.cfg_filename)
+
+ src += _PROLOGUE_CODE
+ for node_cfg in self.node_cfg:
+ src += node_cfg.generate_source() + '\n\n'
+
+ file.write(src)
+
+ def parse_cfgfile(self, filename):
+ """ Parse the configuration file and yield pairs of
+ (name, contents) for each node.
+ """
+ with open(filename, "r") as f:
+ for line in f:
+ line = line.strip()
+ if not line or line.startswith('#'):
+ continue
+ colon_i = line.find(':')
+ lbracket_i = line.find('[')
+ rbracket_i = line.find(']')
+ if colon_i < 1 or lbracket_i <= colon_i or rbracket_i <= lbracket_i:
+ raise RuntimeError("Invalid line in %s:\n%s\n" % (filename, line))
+
+ name = line[:colon_i]
+ val = line[lbracket_i + 1:rbracket_i]
+ vallist = [v.strip() for v in val.split(',')] if val else []
+ yield name, vallist
+
+
+class NodeCfg(object):
+ """ Node configuration.
+
+ name: node name
+ contents: a list of contents - attributes and child nodes
+ See comment at the top of the configuration file for details.
+ """
+
+ def __init__(self, name, contents):
+ self.name = name
+ self.all_entries = []
+ self.attr = []
+ self.child = []
+ self.seq_child = []
+
+ for entry in contents:
+ clean_entry = entry.rstrip('*')
+ self.all_entries.append(clean_entry)
+
+ if entry.endswith('**'):
+ self.seq_child.append(clean_entry)
+ elif entry.endswith('*'):
+ self.child.append(clean_entry)
+ else:
+ self.attr.append(entry)
+
+ def generate_source(self):
+ src = self._gen_init()
+ src += '\n' + self._gen_children()
+ src += '\n' + self._gen_iter()
+ src += '\n' + self._gen_attr_names()
+ return src
+
+ def _gen_init(self):
+ src = "class %s(Node):\n" % self.name
+
+ if self.all_entries:
+ args = ', '.join(self.all_entries)
+ slots = ', '.join("'{0}'".format(e) for e in self.all_entries)
+ slots += ", 'coord', '__weakref__'"
+ arglist = '(self, %s, coord=None)' % args
+ else:
+ slots = "'coord', '__weakref__'"
+ arglist = '(self, coord=None)'
+
+ src += " __slots__ = (%s)\n" % slots
+ src += " def __init__%s:\n" % arglist
+
+ for name in self.all_entries + ['coord']:
+ src += " self.%s = %s\n" % (name, name)
+
+ return src
+
+ def _gen_children(self):
+ src = ' def children(self):\n'
+
+ if self.all_entries:
+ src += ' nodelist = []\n'
+
+ for child in self.child:
+ src += (
+ ' if self.%(child)s is not None:' +
+ ' nodelist.append(("%(child)s", self.%(child)s))\n') % (
+ dict(child=child))
+
+ for seq_child in self.seq_child:
+ src += (
+ ' for i, child in enumerate(self.%(child)s or []):\n'
+ ' nodelist.append(("%(child)s[%%d]" %% i, child))\n') % (
+ dict(child=seq_child))
+
+ src += ' return tuple(nodelist)\n'
+ else:
+ src += ' return ()\n'
+
+ return src
+
+ def _gen_iter(self):
+ src = ' def __iter__(self):\n'
+
+ if self.all_entries:
+ for child in self.child:
+ src += (
+ ' if self.%(child)s is not None:\n' +
+ ' yield self.%(child)s\n') % (dict(child=child))
+
+ for seq_child in self.seq_child:
+ src += (
+ ' for child in (self.%(child)s or []):\n'
+ ' yield child\n') % (dict(child=seq_child))
+
+ if not (self.child or self.seq_child):
+ # Empty generator
+ src += (
+ ' return\n' +
+ ' yield\n')
+ else:
+ # Empty generator
+ src += (
+ ' return\n' +
+ ' yield\n')
+
+ return src
+
+ def _gen_attr_names(self):
+ src = " attr_names = (" + ''.join("%r, " % nm for nm in self.attr) + ')'
+ return src
+
+
+_PROLOGUE_COMMENT = \
+r'''#-----------------------------------------------------------------
+# ** ATTENTION **
+# This code was automatically generated from the file:
+# $cfg_filename
+#
+# Do not modify it directly. Modify the configuration file and
+# run the generator again.
+# ** ** *** ** **
+#
+# pycparser: c_ast.py
+#
+# AST Node classes.
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#-----------------------------------------------------------------
+
+'''
+
+_PROLOGUE_CODE = r'''
+import sys
+
+def _repr(obj):
+ """
+ Get the representation of an object, with dedicated pprint-like format for lists.
+ """
+ if isinstance(obj, list):
+ return '[' + (',\n '.join((_repr(e).replace('\n', '\n ') for e in obj))) + '\n]'
+ else:
+ return repr(obj)
+
+class Node(object):
+ __slots__ = ()
+ """ Abstract base class for AST nodes.
+ """
+ def __repr__(self):
+ """ Generates a python representation of the current node
+ """
+ result = self.__class__.__name__ + '('
+
+ indent = ''
+ separator = ''
+ for name in self.__slots__[:-2]:
+ result += separator
+ result += indent
+ result += name + '=' + (_repr(getattr(self, name)).replace('\n', '\n ' + (' ' * (len(name) + len(self.__class__.__name__)))))
+
+ separator = ','
+ indent = '\n ' + (' ' * len(self.__class__.__name__))
+
+ result += indent + ')'
+
+ return result
+
+ def children(self):
+ """ A sequence of all children that are Nodes
+ """
+ pass
+
+ def show(self, buf=sys.stdout, offset=0, attrnames=False, nodenames=False, showcoord=False, _my_node_name=None):
+ """ Pretty print the Node and all its attributes and
+ children (recursively) to a buffer.
+
+ buf:
+ Open IO buffer into which the Node is printed.
+
+ offset:
+ Initial offset (amount of leading spaces)
+
+ attrnames:
+ True if you want to see the attribute names in
+ name=value pairs. False to only see the values.
+
+ nodenames:
+ True if you want to see the actual node names
+ within their parents.
+
+ showcoord:
+ Do you want the coordinates of each Node to be
+ displayed.
+ """
+ lead = ' ' * offset
+ if nodenames and _my_node_name is not None:
+ buf.write(lead + self.__class__.__name__+ ' <' + _my_node_name + '>: ')
+ else:
+ buf.write(lead + self.__class__.__name__+ ': ')
+
+ if self.attr_names:
+ if attrnames:
+ nvlist = [(n, getattr(self,n)) for n in self.attr_names]
+ attrstr = ', '.join('%s=%s' % nv for nv in nvlist)
+ else:
+ vlist = [getattr(self, n) for n in self.attr_names]
+ attrstr = ', '.join('%s' % v for v in vlist)
+ buf.write(attrstr)
+
+ if showcoord:
+ buf.write(' (at %s)' % self.coord)
+ buf.write('\n')
+
+ for (child_name, child) in self.children():
+ child.show(
+ buf,
+ offset=offset + 2,
+ attrnames=attrnames,
+ nodenames=nodenames,
+ showcoord=showcoord,
+ _my_node_name=child_name)
+
+
+class NodeVisitor(object):
+ """ A base NodeVisitor class for visiting c_ast nodes.
+ Subclass it and define your own visit_XXX methods, where
+ XXX is the class name you want to visit with these
+ methods.
+
+ For example:
+
+ class ConstantVisitor(NodeVisitor):
+ def __init__(self):
+ self.values = []
+
+ def visit_Constant(self, node):
+ self.values.append(node.value)
+
+ Creates a list of values of all the constant nodes
+ encountered below the given node. To use it:
+
+ cv = ConstantVisitor()
+ cv.visit(node)
+
+ Notes:
+
+ * generic_visit() will be called for AST nodes for which
+ no visit_XXX method was defined.
+ * The children of nodes for which a visit_XXX was
+ defined will not be visited - if you need this, call
+ generic_visit() on the node.
+ You can use:
+ NodeVisitor.generic_visit(self, node)
+ * Modeled after Python's own AST visiting facilities
+ (the ast module of Python 3.0)
+ """
+
+ _method_cache = None
+
+ def visit(self, node):
+ """ Visit a node.
+ """
+
+ if self._method_cache is None:
+ self._method_cache = {}
+
+ visitor = self._method_cache.get(node.__class__.__name__, None)
+ if visitor is None:
+ method = 'visit_' + node.__class__.__name__
+ visitor = getattr(self, method, self.generic_visit)
+ self._method_cache[node.__class__.__name__] = visitor
+
+ return visitor(node)
+
+ def generic_visit(self, node):
+ """ Called if no explicit visitor function exists for a
+ node. Implements preorder visiting of the node.
+ """
+ for c in node:
+ self.visit(c)
+
+'''
diff --git a/lib/pycparser/_build_tables.py b/lib/pycparser/_build_tables.py
new file mode 100644
index 0000000..958381a
--- /dev/null
+++ b/lib/pycparser/_build_tables.py
@@ -0,0 +1,37 @@
+#-----------------------------------------------------------------
+# pycparser: _build_tables.py
+#
+# A dummy for generating the lexing/parsing tables and
+# compiling them into .pyc for faster execution in optimized mode.
+# Also generates AST code from the configuration file.
+# Should be called from the pycparser directory.
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#-----------------------------------------------------------------
+
+# Insert '.' and '..' as first entries to the search path for modules.
+# Restricted environments like embeddable python do not include the
+# current working directory on startup.
+import sys
+sys.path[0:0] = ['.', '..']
+
+# Generate c_ast.py
+from _ast_gen import ASTCodeGenerator
+ast_gen = ASTCodeGenerator('_c_ast.cfg')
+ast_gen.generate(open('c_ast.py', 'w'))
+
+from pycparser import c_parser
+
+# Generates the tables
+#
+c_parser.CParser(
+ lex_optimize=True,
+ yacc_debug=False,
+ yacc_optimize=True)
+
+# Load to compile into .pyc
+#
+import lextab
+import yacctab
+import c_ast
diff --git a/lib/pycparser/_c_ast.cfg b/lib/pycparser/_c_ast.cfg
new file mode 100644
index 0000000..0626533
--- /dev/null
+++ b/lib/pycparser/_c_ast.cfg
@@ -0,0 +1,195 @@
+#-----------------------------------------------------------------
+# pycparser: _c_ast.cfg
+#
+# Defines the AST Node classes used in pycparser.
+#
+# Each entry is a Node sub-class name, listing the attributes
+# and child nodes of the class:
+# <name>* - a child node
+# <name>** - a sequence of child nodes
+# <name> - an attribute
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#-----------------------------------------------------------------
+
+# ArrayDecl is a nested declaration of an array with the given type.
+# dim: the dimension (for example, constant 42)
+# dim_quals: list of dimension qualifiers, to support C99's allowing 'const'
+# and 'static' within the array dimension in function declarations.
+ArrayDecl: [type*, dim*, dim_quals]
+
+ArrayRef: [name*, subscript*]
+
+# op: =, +=, /= etc.
+#
+Assignment: [op, lvalue*, rvalue*]
+
+Alignas: [alignment*]
+
+BinaryOp: [op, left*, right*]
+
+Break: []
+
+Case: [expr*, stmts**]
+
+Cast: [to_type*, expr*]
+
+# Compound statement in C99 is a list of block items (declarations or
+# statements).
+#
+Compound: [block_items**]
+
+# Compound literal (anonymous aggregate) for C99.
+# (type-name) {initializer_list}
+# type: the typename
+# init: InitList for the initializer list
+#
+CompoundLiteral: [type*, init*]
+
+# type: int, char, float, string, etc.
+#
+Constant: [type, value]
+
+Continue: []
+
+# name: the variable being declared
+# quals: list of qualifiers (const, volatile)
+# funcspec: list of function specifiers (e.g. inline in C99)
+# storage: list of storage specifiers (extern, register, etc.)
+# type: declaration type (probably nested with all the modifiers)
+# init: initialization value, or None
+# bitsize: bit field size, or None
+#
+Decl: [name, quals, align, storage, funcspec, type*, init*, bitsize*]
+
+DeclList: [decls**]
+
+Default: [stmts**]
+
+DoWhile: [cond*, stmt*]
+
+# Represents the ellipsis (...) parameter in a function
+# declaration
+#
+EllipsisParam: []
+
+# An empty statement (a semicolon ';' on its own)
+#
+EmptyStatement: []
+
+# Enumeration type specifier
+# name: an optional ID
+# values: an EnumeratorList
+#
+Enum: [name, values*]
+
+# A name/value pair for enumeration values
+#
+Enumerator: [name, value*]
+
+# A list of enumerators
+#
+EnumeratorList: [enumerators**]
+
+# A list of expressions separated by the comma operator.
+#
+ExprList: [exprs**]
+
+# This is the top of the AST, representing a single C file (a
+# translation unit in K&R jargon). It contains a list of
+# "external-declaration"s, each of which is either a declaration (Decl),
+# a Typedef, or a function definition (FuncDef).
+#
+FileAST: [ext**]
+
+# for (init; cond; next) stmt
+#
+For: [init*, cond*, next*, stmt*]
+
+# name: Id
+# args: ExprList
+#
+FuncCall: [name*, args*]
+
+# type <decl>(args)
+#
+FuncDecl: [args*, type*]
+
+# Function definition: a declarator for the function name and
+# a body, which is a compound statement.
+# There's an optional list of parameter declarations for old
+# K&R-style definitions
+#
+FuncDef: [decl*, param_decls**, body*]
+
+Goto: [name]
+
+ID: [name]
+
+# Holder for types that are a simple identifier (e.g. the built-in
+# types void, char etc. and typedef-defined types)
+#
+IdentifierType: [names]
+
+If: [cond*, iftrue*, iffalse*]
+
+# An initialization list used for compound literals.
+#
+InitList: [exprs**]
+
+Label: [name, stmt*]
+
+# A named initializer for C99.
+# The name of a NamedInitializer is a sequence of Nodes, because
+# names can be hierarchical and contain constant expressions.
+#
+NamedInitializer: [name**, expr*]
+
+# a list of comma separated function parameter declarations
+#
+ParamList: [params**]
+
+PtrDecl: [quals, type*]
+
+Return: [expr*]
+
+StaticAssert: [cond*, message*]
+
+# name: struct tag name
+# decls: declaration of members
+#
+Struct: [name, decls**]
+
+# type: . or ->
+# name.field or name->field
+#
+StructRef: [name*, type, field*]
+
+Switch: [cond*, stmt*]
+
+# cond ? iftrue : iffalse
+#
+TernaryOp: [cond*, iftrue*, iffalse*]
+
+# A base type declaration
+#
+TypeDecl: [declname, quals, align, type*]
+
+# A typedef declaration.
+# Very similar to Decl, but without some attributes
+#
+Typedef: [name, quals, storage, type*]
+
+Typename: [name, quals, align, type*]
+
+UnaryOp: [op, expr*]
+
+# name: union tag name
+# decls: declaration of members
+#
+Union: [name, decls**]
+
+While: [cond*, stmt*]
+
+Pragma: [string]
diff --git a/lib/pycparser/ast_transforms.py b/lib/pycparser/ast_transforms.py
new file mode 100644
index 0000000..367dcf5
--- /dev/null
+++ b/lib/pycparser/ast_transforms.py
@@ -0,0 +1,164 @@
+#------------------------------------------------------------------------------
+# pycparser: ast_transforms.py
+#
+# Some utilities used by the parser to create a friendlier AST.
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#------------------------------------------------------------------------------
+
+from . import c_ast
+
+
+def fix_switch_cases(switch_node):
+ """ The 'case' statements in a 'switch' come out of parsing with one
+ child node, so subsequent statements are just tucked to the parent
+ Compound. Additionally, consecutive (fall-through) case statements
+ come out messy. This is a peculiarity of the C grammar. The following:
+
+ switch (myvar) {
+ case 10:
+ k = 10;
+ p = k + 1;
+ return 10;
+ case 20:
+ case 30:
+ return 20;
+ default:
+ break;
+ }
+
+ Creates this tree (pseudo-dump):
+
+ Switch
+ ID: myvar
+ Compound:
+ Case 10:
+ k = 10
+ p = k + 1
+ return 10
+ Case 20:
+ Case 30:
+ return 20
+ Default:
+ break
+
+ The goal of this transform is to fix this mess, turning it into the
+ following:
+
+ Switch
+ ID: myvar
+ Compound:
+ Case 10:
+ k = 10
+ p = k + 1
+ return 10
+ Case 20:
+ Case 30:
+ return 20
+ Default:
+ break
+
+ A fixed AST node is returned. The argument may be modified.
+ """
+ assert isinstance(switch_node, c_ast.Switch)
+ if not isinstance(switch_node.stmt, c_ast.Compound):
+ return switch_node
+
+ # The new Compound child for the Switch, which will collect children in the
+ # correct order
+ new_compound = c_ast.Compound([], switch_node.stmt.coord)
+
+ # The last Case/Default node
+ last_case = None
+
+ # Goes over the children of the Compound below the Switch, adding them
+ # either directly below new_compound or below the last Case as appropriate
+ # (for `switch(cond) {}`, block_items would have been None)
+ for child in (switch_node.stmt.block_items or []):
+ if isinstance(child, (c_ast.Case, c_ast.Default)):
+ # If it's a Case/Default:
+ # 1. Add it to the Compound and mark as "last case"
+ # 2. If its immediate child is also a Case or Default, promote it
+ # to a sibling.
+ new_compound.block_items.append(child)
+ _extract_nested_case(child, new_compound.block_items)
+ last_case = new_compound.block_items[-1]
+ else:
+ # Other statements are added as children to the last case, if it
+ # exists.
+ if last_case is None:
+ new_compound.block_items.append(child)
+ else:
+ last_case.stmts.append(child)
+
+ switch_node.stmt = new_compound
+ return switch_node
+
+
+def _extract_nested_case(case_node, stmts_list):
+ """ Recursively extract consecutive Case statements that are made nested
+ by the parser and add them to the stmts_list.
+ """
+ if isinstance(case_node.stmts[0], (c_ast.Case, c_ast.Default)):
+ stmts_list.append(case_node.stmts.pop())
+ _extract_nested_case(stmts_list[-1], stmts_list)
+
+
+def fix_atomic_specifiers(decl):
+ """ Atomic specifiers like _Atomic(type) are unusually structured,
+ conferring a qualifier upon the contained type.
+
+ This function fixes a decl with atomic specifiers to have a sane AST
+ structure, by removing spurious Typename->TypeDecl pairs and attaching
+ the _Atomic qualifier in the right place.
+ """
+ # There can be multiple levels of _Atomic in a decl; fix them until a
+ # fixed point is reached.
+ while True:
+ decl, found = _fix_atomic_specifiers_once(decl)
+ if not found:
+ break
+
+ # Make sure to add an _Atomic qual on the topmost decl if needed. Also
+ # restore the declname on the innermost TypeDecl (it gets placed in the
+ # wrong place during construction).
+ typ = decl
+ while not isinstance(typ, c_ast.TypeDecl):
+ try:
+ typ = typ.type
+ except AttributeError:
+ return decl
+ if '_Atomic' in typ.quals and '_Atomic' not in decl.quals:
+ decl.quals.append('_Atomic')
+ if typ.declname is None:
+ typ.declname = decl.name
+
+ return decl
+
+
+def _fix_atomic_specifiers_once(decl):
+ """ Performs one 'fix' round of atomic specifiers.
+ Returns (modified_decl, found) where found is True iff a fix was made.
+ """
+ parent = decl
+ grandparent = None
+ node = decl.type
+ while node is not None:
+ if isinstance(node, c_ast.Typename) and '_Atomic' in node.quals:
+ break
+ try:
+ grandparent = parent
+ parent = node
+ node = node.type
+ except AttributeError:
+ # If we've reached a node without a `type` field, it means we won't
+ # find what we're looking for at this point; give up the search
+ # and return the original decl unmodified.
+ return decl, False
+
+ assert isinstance(parent, c_ast.TypeDecl)
+ grandparent.type = node.type
+ if '_Atomic' not in node.type.quals:
+ node.type.quals.append('_Atomic')
+ return decl, True
diff --git a/lib/pycparser/c_ast.py b/lib/pycparser/c_ast.py
new file mode 100644
index 0000000..6575a2a
--- /dev/null
+++ b/lib/pycparser/c_ast.py
@@ -0,0 +1,1125 @@
+#-----------------------------------------------------------------
+# ** ATTENTION **
+# This code was automatically generated from the file:
+# _c_ast.cfg
+#
+# Do not modify it directly. Modify the configuration file and
+# run the generator again.
+# ** ** *** ** **
+#
+# pycparser: c_ast.py
+#
+# AST Node classes.
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#-----------------------------------------------------------------
+
+
+import sys
+
+def _repr(obj):
+ """
+ Get the representation of an object, with dedicated pprint-like format for lists.
+ """
+ if isinstance(obj, list):
+ return '[' + (',\n '.join((_repr(e).replace('\n', '\n ') for e in obj))) + '\n]'
+ else:
+ return repr(obj)
+
+class Node(object):
+ __slots__ = ()
+ """ Abstract base class for AST nodes.
+ """
+ def __repr__(self):
+ """ Generates a python representation of the current node
+ """
+ result = self.__class__.__name__ + '('
+
+ indent = ''
+ separator = ''
+ for name in self.__slots__[:-2]:
+ result += separator
+ result += indent
+ result += name + '=' + (_repr(getattr(self, name)).replace('\n', '\n ' + (' ' * (len(name) + len(self.__class__.__name__)))))
+
+ separator = ','
+ indent = '\n ' + (' ' * len(self.__class__.__name__))
+
+ result += indent + ')'
+
+ return result
+
+ def children(self):
+ """ A sequence of all children that are Nodes
+ """
+ pass
+
+ def show(self, buf=sys.stdout, offset=0, attrnames=False, nodenames=False, showcoord=False, _my_node_name=None):
+ """ Pretty print the Node and all its attributes and
+ children (recursively) to a buffer.
+
+ buf:
+ Open IO buffer into which the Node is printed.
+
+ offset:
+ Initial offset (amount of leading spaces)
+
+ attrnames:
+ True if you want to see the attribute names in
+ name=value pairs. False to only see the values.
+
+ nodenames:
+ True if you want to see the actual node names
+ within their parents.
+
+ showcoord:
+ Do you want the coordinates of each Node to be
+ displayed.
+ """
+ lead = ' ' * offset
+ if nodenames and _my_node_name is not None:
+ buf.write(lead + self.__class__.__name__+ ' <' + _my_node_name + '>: ')
+ else:
+ buf.write(lead + self.__class__.__name__+ ': ')
+
+ if self.attr_names:
+ if attrnames:
+ nvlist = [(n, getattr(self,n)) for n in self.attr_names]
+ attrstr = ', '.join('%s=%s' % nv for nv in nvlist)
+ else:
+ vlist = [getattr(self, n) for n in self.attr_names]
+ attrstr = ', '.join('%s' % v for v in vlist)
+ buf.write(attrstr)
+
+ if showcoord:
+ buf.write(' (at %s)' % self.coord)
+ buf.write('\n')
+
+ for (child_name, child) in self.children():
+ child.show(
+ buf,
+ offset=offset + 2,
+ attrnames=attrnames,
+ nodenames=nodenames,
+ showcoord=showcoord,
+ _my_node_name=child_name)
+
+
+class NodeVisitor(object):
+ """ A base NodeVisitor class for visiting c_ast nodes.
+ Subclass it and define your own visit_XXX methods, where
+ XXX is the class name you want to visit with these
+ methods.
+
+ For example:
+
+ class ConstantVisitor(NodeVisitor):
+ def __init__(self):
+ self.values = []
+
+ def visit_Constant(self, node):
+ self.values.append(node.value)
+
+ Creates a list of values of all the constant nodes
+ encountered below the given node. To use it:
+
+ cv = ConstantVisitor()
+ cv.visit(node)
+
+ Notes:
+
+ * generic_visit() will be called for AST nodes for which
+ no visit_XXX method was defined.
+ * The children of nodes for which a visit_XXX was
+ defined will not be visited - if you need this, call
+ generic_visit() on the node.
+ You can use:
+ NodeVisitor.generic_visit(self, node)
+ * Modeled after Python's own AST visiting facilities
+ (the ast module of Python 3.0)
+ """
+
+ _method_cache = None
+
+ def visit(self, node):
+ """ Visit a node.
+ """
+
+ if self._method_cache is None:
+ self._method_cache = {}
+
+ visitor = self._method_cache.get(node.__class__.__name__, None)
+ if visitor is None:
+ method = 'visit_' + node.__class__.__name__
+ visitor = getattr(self, method, self.generic_visit)
+ self._method_cache[node.__class__.__name__] = visitor
+
+ return visitor(node)
+
+ def generic_visit(self, node):
+ """ Called if no explicit visitor function exists for a
+ node. Implements preorder visiting of the node.
+ """
+ for c in node:
+ self.visit(c)
+
+class ArrayDecl(Node):
+ __slots__ = ('type', 'dim', 'dim_quals', 'coord', '__weakref__')
+ def __init__(self, type, dim, dim_quals, coord=None):
+ self.type = type
+ self.dim = dim
+ self.dim_quals = dim_quals
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.type is not None: nodelist.append(("type", self.type))
+ if self.dim is not None: nodelist.append(("dim", self.dim))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.type is not None:
+ yield self.type
+ if self.dim is not None:
+ yield self.dim
+
+ attr_names = ('dim_quals', )
+
+class ArrayRef(Node):
+ __slots__ = ('name', 'subscript', 'coord', '__weakref__')
+ def __init__(self, name, subscript, coord=None):
+ self.name = name
+ self.subscript = subscript
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.name is not None: nodelist.append(("name", self.name))
+ if self.subscript is not None: nodelist.append(("subscript", self.subscript))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.name is not None:
+ yield self.name
+ if self.subscript is not None:
+ yield self.subscript
+
+ attr_names = ()
+
+class Assignment(Node):
+ __slots__ = ('op', 'lvalue', 'rvalue', 'coord', '__weakref__')
+ def __init__(self, op, lvalue, rvalue, coord=None):
+ self.op = op
+ self.lvalue = lvalue
+ self.rvalue = rvalue
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.lvalue is not None: nodelist.append(("lvalue", self.lvalue))
+ if self.rvalue is not None: nodelist.append(("rvalue", self.rvalue))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.lvalue is not None:
+ yield self.lvalue
+ if self.rvalue is not None:
+ yield self.rvalue
+
+ attr_names = ('op', )
+
+class Alignas(Node):
+ __slots__ = ('alignment', 'coord', '__weakref__')
+ def __init__(self, alignment, coord=None):
+ self.alignment = alignment
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.alignment is not None: nodelist.append(("alignment", self.alignment))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.alignment is not None:
+ yield self.alignment
+
+ attr_names = ()
+
+class BinaryOp(Node):
+ __slots__ = ('op', 'left', 'right', 'coord', '__weakref__')
+ def __init__(self, op, left, right, coord=None):
+ self.op = op
+ self.left = left
+ self.right = right
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.left is not None: nodelist.append(("left", self.left))
+ if self.right is not None: nodelist.append(("right", self.right))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.left is not None:
+ yield self.left
+ if self.right is not None:
+ yield self.right
+
+ attr_names = ('op', )
+
+class Break(Node):
+ __slots__ = ('coord', '__weakref__')
+ def __init__(self, coord=None):
+ self.coord = coord
+
+ def children(self):
+ return ()
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ()
+
+class Case(Node):
+ __slots__ = ('expr', 'stmts', 'coord', '__weakref__')
+ def __init__(self, expr, stmts, coord=None):
+ self.expr = expr
+ self.stmts = stmts
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.expr is not None: nodelist.append(("expr", self.expr))
+ for i, child in enumerate(self.stmts or []):
+ nodelist.append(("stmts[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.expr is not None:
+ yield self.expr
+ for child in (self.stmts or []):
+ yield child
+
+ attr_names = ()
+
+class Cast(Node):
+ __slots__ = ('to_type', 'expr', 'coord', '__weakref__')
+ def __init__(self, to_type, expr, coord=None):
+ self.to_type = to_type
+ self.expr = expr
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.to_type is not None: nodelist.append(("to_type", self.to_type))
+ if self.expr is not None: nodelist.append(("expr", self.expr))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.to_type is not None:
+ yield self.to_type
+ if self.expr is not None:
+ yield self.expr
+
+ attr_names = ()
+
+class Compound(Node):
+ __slots__ = ('block_items', 'coord', '__weakref__')
+ def __init__(self, block_items, coord=None):
+ self.block_items = block_items
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.block_items or []):
+ nodelist.append(("block_items[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.block_items or []):
+ yield child
+
+ attr_names = ()
+
+class CompoundLiteral(Node):
+ __slots__ = ('type', 'init', 'coord', '__weakref__')
+ def __init__(self, type, init, coord=None):
+ self.type = type
+ self.init = init
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.type is not None: nodelist.append(("type", self.type))
+ if self.init is not None: nodelist.append(("init", self.init))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.type is not None:
+ yield self.type
+ if self.init is not None:
+ yield self.init
+
+ attr_names = ()
+
+class Constant(Node):
+ __slots__ = ('type', 'value', 'coord', '__weakref__')
+ def __init__(self, type, value, coord=None):
+ self.type = type
+ self.value = value
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ return tuple(nodelist)
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ('type', 'value', )
+
+class Continue(Node):
+ __slots__ = ('coord', '__weakref__')
+ def __init__(self, coord=None):
+ self.coord = coord
+
+ def children(self):
+ return ()
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ()
+
+class Decl(Node):
+ __slots__ = ('name', 'quals', 'align', 'storage', 'funcspec', 'type', 'init', 'bitsize', 'coord', '__weakref__')
+ def __init__(self, name, quals, align, storage, funcspec, type, init, bitsize, coord=None):
+ self.name = name
+ self.quals = quals
+ self.align = align
+ self.storage = storage
+ self.funcspec = funcspec
+ self.type = type
+ self.init = init
+ self.bitsize = bitsize
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.type is not None: nodelist.append(("type", self.type))
+ if self.init is not None: nodelist.append(("init", self.init))
+ if self.bitsize is not None: nodelist.append(("bitsize", self.bitsize))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.type is not None:
+ yield self.type
+ if self.init is not None:
+ yield self.init
+ if self.bitsize is not None:
+ yield self.bitsize
+
+ attr_names = ('name', 'quals', 'align', 'storage', 'funcspec', )
+
+class DeclList(Node):
+ __slots__ = ('decls', 'coord', '__weakref__')
+ def __init__(self, decls, coord=None):
+ self.decls = decls
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.decls or []):
+ nodelist.append(("decls[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.decls or []):
+ yield child
+
+ attr_names = ()
+
+class Default(Node):
+ __slots__ = ('stmts', 'coord', '__weakref__')
+ def __init__(self, stmts, coord=None):
+ self.stmts = stmts
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.stmts or []):
+ nodelist.append(("stmts[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.stmts or []):
+ yield child
+
+ attr_names = ()
+
+class DoWhile(Node):
+ __slots__ = ('cond', 'stmt', 'coord', '__weakref__')
+ def __init__(self, cond, stmt, coord=None):
+ self.cond = cond
+ self.stmt = stmt
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.cond is not None: nodelist.append(("cond", self.cond))
+ if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.cond is not None:
+ yield self.cond
+ if self.stmt is not None:
+ yield self.stmt
+
+ attr_names = ()
+
+class EllipsisParam(Node):
+ __slots__ = ('coord', '__weakref__')
+ def __init__(self, coord=None):
+ self.coord = coord
+
+ def children(self):
+ return ()
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ()
+
+class EmptyStatement(Node):
+ __slots__ = ('coord', '__weakref__')
+ def __init__(self, coord=None):
+ self.coord = coord
+
+ def children(self):
+ return ()
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ()
+
+class Enum(Node):
+ __slots__ = ('name', 'values', 'coord', '__weakref__')
+ def __init__(self, name, values, coord=None):
+ self.name = name
+ self.values = values
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.values is not None: nodelist.append(("values", self.values))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.values is not None:
+ yield self.values
+
+ attr_names = ('name', )
+
+class Enumerator(Node):
+ __slots__ = ('name', 'value', 'coord', '__weakref__')
+ def __init__(self, name, value, coord=None):
+ self.name = name
+ self.value = value
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.value is not None: nodelist.append(("value", self.value))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.value is not None:
+ yield self.value
+
+ attr_names = ('name', )
+
+class EnumeratorList(Node):
+ __slots__ = ('enumerators', 'coord', '__weakref__')
+ def __init__(self, enumerators, coord=None):
+ self.enumerators = enumerators
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.enumerators or []):
+ nodelist.append(("enumerators[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.enumerators or []):
+ yield child
+
+ attr_names = ()
+
+class ExprList(Node):
+ __slots__ = ('exprs', 'coord', '__weakref__')
+ def __init__(self, exprs, coord=None):
+ self.exprs = exprs
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.exprs or []):
+ nodelist.append(("exprs[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.exprs or []):
+ yield child
+
+ attr_names = ()
+
+class FileAST(Node):
+ __slots__ = ('ext', 'coord', '__weakref__')
+ def __init__(self, ext, coord=None):
+ self.ext = ext
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.ext or []):
+ nodelist.append(("ext[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.ext or []):
+ yield child
+
+ attr_names = ()
+
+class For(Node):
+ __slots__ = ('init', 'cond', 'next', 'stmt', 'coord', '__weakref__')
+ def __init__(self, init, cond, next, stmt, coord=None):
+ self.init = init
+ self.cond = cond
+ self.next = next
+ self.stmt = stmt
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.init is not None: nodelist.append(("init", self.init))
+ if self.cond is not None: nodelist.append(("cond", self.cond))
+ if self.next is not None: nodelist.append(("next", self.next))
+ if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.init is not None:
+ yield self.init
+ if self.cond is not None:
+ yield self.cond
+ if self.next is not None:
+ yield self.next
+ if self.stmt is not None:
+ yield self.stmt
+
+ attr_names = ()
+
+class FuncCall(Node):
+ __slots__ = ('name', 'args', 'coord', '__weakref__')
+ def __init__(self, name, args, coord=None):
+ self.name = name
+ self.args = args
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.name is not None: nodelist.append(("name", self.name))
+ if self.args is not None: nodelist.append(("args", self.args))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.name is not None:
+ yield self.name
+ if self.args is not None:
+ yield self.args
+
+ attr_names = ()
+
+class FuncDecl(Node):
+ __slots__ = ('args', 'type', 'coord', '__weakref__')
+ def __init__(self, args, type, coord=None):
+ self.args = args
+ self.type = type
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.args is not None: nodelist.append(("args", self.args))
+ if self.type is not None: nodelist.append(("type", self.type))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.args is not None:
+ yield self.args
+ if self.type is not None:
+ yield self.type
+
+ attr_names = ()
+
+class FuncDef(Node):
+ __slots__ = ('decl', 'param_decls', 'body', 'coord', '__weakref__')
+ def __init__(self, decl, param_decls, body, coord=None):
+ self.decl = decl
+ self.param_decls = param_decls
+ self.body = body
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.decl is not None: nodelist.append(("decl", self.decl))
+ if self.body is not None: nodelist.append(("body", self.body))
+ for i, child in enumerate(self.param_decls or []):
+ nodelist.append(("param_decls[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.decl is not None:
+ yield self.decl
+ if self.body is not None:
+ yield self.body
+ for child in (self.param_decls or []):
+ yield child
+
+ attr_names = ()
+
+class Goto(Node):
+ __slots__ = ('name', 'coord', '__weakref__')
+ def __init__(self, name, coord=None):
+ self.name = name
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ return tuple(nodelist)
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ('name', )
+
+class ID(Node):
+ __slots__ = ('name', 'coord', '__weakref__')
+ def __init__(self, name, coord=None):
+ self.name = name
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ return tuple(nodelist)
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ('name', )
+
+class IdentifierType(Node):
+ __slots__ = ('names', 'coord', '__weakref__')
+ def __init__(self, names, coord=None):
+ self.names = names
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ return tuple(nodelist)
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ('names', )
+
+class If(Node):
+ __slots__ = ('cond', 'iftrue', 'iffalse', 'coord', '__weakref__')
+ def __init__(self, cond, iftrue, iffalse, coord=None):
+ self.cond = cond
+ self.iftrue = iftrue
+ self.iffalse = iffalse
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.cond is not None: nodelist.append(("cond", self.cond))
+ if self.iftrue is not None: nodelist.append(("iftrue", self.iftrue))
+ if self.iffalse is not None: nodelist.append(("iffalse", self.iffalse))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.cond is not None:
+ yield self.cond
+ if self.iftrue is not None:
+ yield self.iftrue
+ if self.iffalse is not None:
+ yield self.iffalse
+
+ attr_names = ()
+
+class InitList(Node):
+ __slots__ = ('exprs', 'coord', '__weakref__')
+ def __init__(self, exprs, coord=None):
+ self.exprs = exprs
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.exprs or []):
+ nodelist.append(("exprs[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.exprs or []):
+ yield child
+
+ attr_names = ()
+
+class Label(Node):
+ __slots__ = ('name', 'stmt', 'coord', '__weakref__')
+ def __init__(self, name, stmt, coord=None):
+ self.name = name
+ self.stmt = stmt
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.stmt is not None:
+ yield self.stmt
+
+ attr_names = ('name', )
+
+class NamedInitializer(Node):
+ __slots__ = ('name', 'expr', 'coord', '__weakref__')
+ def __init__(self, name, expr, coord=None):
+ self.name = name
+ self.expr = expr
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.expr is not None: nodelist.append(("expr", self.expr))
+ for i, child in enumerate(self.name or []):
+ nodelist.append(("name[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.expr is not None:
+ yield self.expr
+ for child in (self.name or []):
+ yield child
+
+ attr_names = ()
+
+class ParamList(Node):
+ __slots__ = ('params', 'coord', '__weakref__')
+ def __init__(self, params, coord=None):
+ self.params = params
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.params or []):
+ nodelist.append(("params[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.params or []):
+ yield child
+
+ attr_names = ()
+
+class PtrDecl(Node):
+ __slots__ = ('quals', 'type', 'coord', '__weakref__')
+ def __init__(self, quals, type, coord=None):
+ self.quals = quals
+ self.type = type
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.type is not None: nodelist.append(("type", self.type))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.type is not None:
+ yield self.type
+
+ attr_names = ('quals', )
+
+class Return(Node):
+ __slots__ = ('expr', 'coord', '__weakref__')
+ def __init__(self, expr, coord=None):
+ self.expr = expr
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.expr is not None: nodelist.append(("expr", self.expr))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.expr is not None:
+ yield self.expr
+
+ attr_names = ()
+
+class StaticAssert(Node):
+ __slots__ = ('cond', 'message', 'coord', '__weakref__')
+ def __init__(self, cond, message, coord=None):
+ self.cond = cond
+ self.message = message
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.cond is not None: nodelist.append(("cond", self.cond))
+ if self.message is not None: nodelist.append(("message", self.message))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.cond is not None:
+ yield self.cond
+ if self.message is not None:
+ yield self.message
+
+ attr_names = ()
+
+class Struct(Node):
+ __slots__ = ('name', 'decls', 'coord', '__weakref__')
+ def __init__(self, name, decls, coord=None):
+ self.name = name
+ self.decls = decls
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.decls or []):
+ nodelist.append(("decls[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.decls or []):
+ yield child
+
+ attr_names = ('name', )
+
+class StructRef(Node):
+ __slots__ = ('name', 'type', 'field', 'coord', '__weakref__')
+ def __init__(self, name, type, field, coord=None):
+ self.name = name
+ self.type = type
+ self.field = field
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.name is not None: nodelist.append(("name", self.name))
+ if self.field is not None: nodelist.append(("field", self.field))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.name is not None:
+ yield self.name
+ if self.field is not None:
+ yield self.field
+
+ attr_names = ('type', )
+
+class Switch(Node):
+ __slots__ = ('cond', 'stmt', 'coord', '__weakref__')
+ def __init__(self, cond, stmt, coord=None):
+ self.cond = cond
+ self.stmt = stmt
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.cond is not None: nodelist.append(("cond", self.cond))
+ if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.cond is not None:
+ yield self.cond
+ if self.stmt is not None:
+ yield self.stmt
+
+ attr_names = ()
+
+class TernaryOp(Node):
+ __slots__ = ('cond', 'iftrue', 'iffalse', 'coord', '__weakref__')
+ def __init__(self, cond, iftrue, iffalse, coord=None):
+ self.cond = cond
+ self.iftrue = iftrue
+ self.iffalse = iffalse
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.cond is not None: nodelist.append(("cond", self.cond))
+ if self.iftrue is not None: nodelist.append(("iftrue", self.iftrue))
+ if self.iffalse is not None: nodelist.append(("iffalse", self.iffalse))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.cond is not None:
+ yield self.cond
+ if self.iftrue is not None:
+ yield self.iftrue
+ if self.iffalse is not None:
+ yield self.iffalse
+
+ attr_names = ()
+
+class TypeDecl(Node):
+ __slots__ = ('declname', 'quals', 'align', 'type', 'coord', '__weakref__')
+ def __init__(self, declname, quals, align, type, coord=None):
+ self.declname = declname
+ self.quals = quals
+ self.align = align
+ self.type = type
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.type is not None: nodelist.append(("type", self.type))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.type is not None:
+ yield self.type
+
+ attr_names = ('declname', 'quals', 'align', )
+
+class Typedef(Node):
+ __slots__ = ('name', 'quals', 'storage', 'type', 'coord', '__weakref__')
+ def __init__(self, name, quals, storage, type, coord=None):
+ self.name = name
+ self.quals = quals
+ self.storage = storage
+ self.type = type
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.type is not None: nodelist.append(("type", self.type))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.type is not None:
+ yield self.type
+
+ attr_names = ('name', 'quals', 'storage', )
+
+class Typename(Node):
+ __slots__ = ('name', 'quals', 'align', 'type', 'coord', '__weakref__')
+ def __init__(self, name, quals, align, type, coord=None):
+ self.name = name
+ self.quals = quals
+ self.align = align
+ self.type = type
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.type is not None: nodelist.append(("type", self.type))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.type is not None:
+ yield self.type
+
+ attr_names = ('name', 'quals', 'align', )
+
+class UnaryOp(Node):
+ __slots__ = ('op', 'expr', 'coord', '__weakref__')
+ def __init__(self, op, expr, coord=None):
+ self.op = op
+ self.expr = expr
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.expr is not None: nodelist.append(("expr", self.expr))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.expr is not None:
+ yield self.expr
+
+ attr_names = ('op', )
+
+class Union(Node):
+ __slots__ = ('name', 'decls', 'coord', '__weakref__')
+ def __init__(self, name, decls, coord=None):
+ self.name = name
+ self.decls = decls
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ for i, child in enumerate(self.decls or []):
+ nodelist.append(("decls[%d]" % i, child))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ for child in (self.decls or []):
+ yield child
+
+ attr_names = ('name', )
+
+class While(Node):
+ __slots__ = ('cond', 'stmt', 'coord', '__weakref__')
+ def __init__(self, cond, stmt, coord=None):
+ self.cond = cond
+ self.stmt = stmt
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ if self.cond is not None: nodelist.append(("cond", self.cond))
+ if self.stmt is not None: nodelist.append(("stmt", self.stmt))
+ return tuple(nodelist)
+
+ def __iter__(self):
+ if self.cond is not None:
+ yield self.cond
+ if self.stmt is not None:
+ yield self.stmt
+
+ attr_names = ()
+
+class Pragma(Node):
+ __slots__ = ('string', 'coord', '__weakref__')
+ def __init__(self, string, coord=None):
+ self.string = string
+ self.coord = coord
+
+ def children(self):
+ nodelist = []
+ return tuple(nodelist)
+
+ def __iter__(self):
+ return
+ yield
+
+ attr_names = ('string', )
+
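The generated node classes and the NodeVisitor above are easiest to follow on a small, concrete tree. The sketch below is an editor's illustration, not part of the committed file; it assumes the vendored lib/ directory is on sys.path so the package imports as pycparser.

from pycparser import c_ast

class ConstantCollector(c_ast.NodeVisitor):
    """Collects the values of all Constant nodes below the visited node."""
    def __init__(self):
        self.values = []

    def visit_Constant(self, node):
        self.values.append(node.value)

# Hand-built AST for the expression `x + 2 * 3`.
tree = c_ast.BinaryOp(
    '+',
    c_ast.ID('x'),
    c_ast.BinaryOp('*', c_ast.Constant('int', '2'), c_ast.Constant('int', '3')))

cv = ConstantCollector()
cv.visit(tree)      # generic_visit() walks BinaryOp/ID; visit_Constant fires twice
print(cv.values)    # expected: ['2', '3']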
diff --git a/lib/pycparser/c_generator.py b/lib/pycparser/c_generator.py
new file mode 100644
index 0000000..1057b2c
--- /dev/null
+++ b/lib/pycparser/c_generator.py
@@ -0,0 +1,502 @@
+#------------------------------------------------------------------------------
+# pycparser: c_generator.py
+#
+# C code generator from pycparser AST nodes.
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#------------------------------------------------------------------------------
+from . import c_ast
+
+
+class CGenerator(object):
+ """ Uses the same visitor pattern as c_ast.NodeVisitor, but modified to
+ return a value from each visit method, using string accumulation in
+ generic_visit.
+ """
+ def __init__(self, reduce_parentheses=False):
+ """ Constructs C-code generator
+
+ reduce_parentheses:
+ if True, eliminates needless parentheses on binary operators
+ """
+ # Statements start with indentation of self.indent_level spaces, using
+ # the _make_indent method.
+ self.indent_level = 0
+ self.reduce_parentheses = reduce_parentheses
+
+ def _make_indent(self):
+ return ' ' * self.indent_level
+
+ def visit(self, node):
+ method = 'visit_' + node.__class__.__name__
+ return getattr(self, method, self.generic_visit)(node)
+
+ def generic_visit(self, node):
+ if node is None:
+ return ''
+ else:
+ return ''.join(self.visit(c) for c_name, c in node.children())
+
+ def visit_Constant(self, n):
+ return n.value
+
+ def visit_ID(self, n):
+ return n.name
+
+ def visit_Pragma(self, n):
+ ret = '#pragma'
+ if n.string:
+ ret += ' ' + n.string
+ return ret
+
+ def visit_ArrayRef(self, n):
+ arrref = self._parenthesize_unless_simple(n.name)
+ return arrref + '[' + self.visit(n.subscript) + ']'
+
+ def visit_StructRef(self, n):
+ sref = self._parenthesize_unless_simple(n.name)
+ return sref + n.type + self.visit(n.field)
+
+ def visit_FuncCall(self, n):
+ fref = self._parenthesize_unless_simple(n.name)
+ return fref + '(' + self.visit(n.args) + ')'
+
+ def visit_UnaryOp(self, n):
+ if n.op == 'sizeof':
+ # Always parenthesize the argument of sizeof since it can be
+ # a name.
+ return 'sizeof(%s)' % self.visit(n.expr)
+ else:
+ operand = self._parenthesize_unless_simple(n.expr)
+ if n.op == 'p++':
+ return '%s++' % operand
+ elif n.op == 'p--':
+ return '%s--' % operand
+ else:
+ return '%s%s' % (n.op, operand)
+
+ # Precedence map of binary operators:
+ precedence_map = {
+ # Should be in sync with c_parser.CParser.precedence
+ # Higher numbers are stronger binding
+ '||': 0, # weakest binding
+ '&&': 1,
+ '|': 2,
+ '^': 3,
+ '&': 4,
+ '==': 5, '!=': 5,
+ '>': 6, '>=': 6, '<': 6, '<=': 6,
+ '>>': 7, '<<': 7,
+ '+': 8, '-': 8,
+ '*': 9, '/': 9, '%': 9 # strongest binding
+ }
+
+ def visit_BinaryOp(self, n):
+ # Note: all binary operators are left-to-right associative
+ #
+ # If `n.left.op` has a stronger or equally binding precedence in
+ # comparison to `n.op`, no parentheses are needed for the left:
+ # e.g., `(a*b) + c` is equivalent to `a*b + c`, as well as
+ # `(a+b) - c` is equivalent to `a+b - c` (same precedence).
+ # If the left operator is weaker binding than the current, then
+ # parentheses are necessary:
+ # e.g., `(a+b) * c` is NOT equivalent to `a+b * c`.
+ lval_str = self._parenthesize_if(
+ n.left,
+ lambda d: not (self._is_simple_node(d) or
+ self.reduce_parentheses and isinstance(d, c_ast.BinaryOp) and
+ self.precedence_map[d.op] >= self.precedence_map[n.op]))
+ # If `n.right.op` has a stronger -but not equal- binding precedence,
+ # parentheses can be omitted on the right:
+ # e.g., `a + (b*c)` is equivalent to `a + b*c`.
+ # If the right operator is weaker or equally binding, then parentheses
+ # are necessary:
+ # e.g., `a * (b+c)` is NOT equivalent to `a * b+c` and
+ # `a - (b+c)` is NOT equivalent to `a - b+c` (same precedence).
+ rval_str = self._parenthesize_if(
+ n.right,
+ lambda d: not (self._is_simple_node(d) or
+ self.reduce_parentheses and isinstance(d, c_ast.BinaryOp) and
+ self.precedence_map[d.op] > self.precedence_map[n.op]))
+ return '%s %s %s' % (lval_str, n.op, rval_str)
+
+ def visit_Assignment(self, n):
+ rval_str = self._parenthesize_if(
+ n.rvalue,
+ lambda n: isinstance(n, c_ast.Assignment))
+ return '%s %s %s' % (self.visit(n.lvalue), n.op, rval_str)
+
+ def visit_IdentifierType(self, n):
+ return ' '.join(n.names)
+
+ def _visit_expr(self, n):
+ if isinstance(n, c_ast.InitList):
+ return '{' + self.visit(n) + '}'
+ elif isinstance(n, c_ast.ExprList):
+ return '(' + self.visit(n) + ')'
+ else:
+ return self.visit(n)
+
+ def visit_Decl(self, n, no_type=False):
+ # no_type is used when a Decl is part of a DeclList, where the type is
+ # explicitly only for the first declaration in a list.
+ #
+ s = n.name if no_type else self._generate_decl(n)
+ if n.bitsize: s += ' : ' + self.visit(n.bitsize)
+ if n.init:
+ s += ' = ' + self._visit_expr(n.init)
+ return s
+
+ def visit_DeclList(self, n):
+ s = self.visit(n.decls[0])
+ if len(n.decls) > 1:
+ s += ', ' + ', '.join(self.visit_Decl(decl, no_type=True)
+ for decl in n.decls[1:])
+ return s
+
+ def visit_Typedef(self, n):
+ s = ''
+ if n.storage: s += ' '.join(n.storage) + ' '
+ s += self._generate_type(n.type)
+ return s
+
+ def visit_Cast(self, n):
+ s = '(' + self._generate_type(n.to_type, emit_declname=False) + ')'
+ return s + ' ' + self._parenthesize_unless_simple(n.expr)
+
+ def visit_ExprList(self, n):
+ visited_subexprs = []
+ for expr in n.exprs:
+ visited_subexprs.append(self._visit_expr(expr))
+ return ', '.join(visited_subexprs)
+
+ def visit_InitList(self, n):
+ visited_subexprs = []
+ for expr in n.exprs:
+ visited_subexprs.append(self._visit_expr(expr))
+ return ', '.join(visited_subexprs)
+
+ def visit_Enum(self, n):
+ return self._generate_struct_union_enum(n, name='enum')
+
+ def visit_Alignas(self, n):
+ return '_Alignas({})'.format(self.visit(n.alignment))
+
+ def visit_Enumerator(self, n):
+ if not n.value:
+ return '{indent}{name},\n'.format(
+ indent=self._make_indent(),
+ name=n.name,
+ )
+ else:
+ return '{indent}{name} = {value},\n'.format(
+ indent=self._make_indent(),
+ name=n.name,
+ value=self.visit(n.value),
+ )
+
+ def visit_FuncDef(self, n):
+ decl = self.visit(n.decl)
+ self.indent_level = 0
+ body = self.visit(n.body)
+ if n.param_decls:
+ knrdecls = ';\n'.join(self.visit(p) for p in n.param_decls)
+ return decl + '\n' + knrdecls + ';\n' + body + '\n'
+ else:
+ return decl + '\n' + body + '\n'
+
+ def visit_FileAST(self, n):
+ s = ''
+ for ext in n.ext:
+ if isinstance(ext, c_ast.FuncDef):
+ s += self.visit(ext)
+ elif isinstance(ext, c_ast.Pragma):
+ s += self.visit(ext) + '\n'
+ else:
+ s += self.visit(ext) + ';\n'
+ return s
+
+ def visit_Compound(self, n):
+ s = self._make_indent() + '{\n'
+ self.indent_level += 2
+ if n.block_items:
+ s += ''.join(self._generate_stmt(stmt) for stmt in n.block_items)
+ self.indent_level -= 2
+ s += self._make_indent() + '}\n'
+ return s
+
+ def visit_CompoundLiteral(self, n):
+ return '(' + self.visit(n.type) + '){' + self.visit(n.init) + '}'
+
+
+ def visit_EmptyStatement(self, n):
+ return ';'
+
+ def visit_ParamList(self, n):
+ return ', '.join(self.visit(param) for param in n.params)
+
+ def visit_Return(self, n):
+ s = 'return'
+ if n.expr: s += ' ' + self.visit(n.expr)
+ return s + ';'
+
+ def visit_Break(self, n):
+ return 'break;'
+
+ def visit_Continue(self, n):
+ return 'continue;'
+
+ def visit_TernaryOp(self, n):
+ s = '(' + self._visit_expr(n.cond) + ') ? '
+ s += '(' + self._visit_expr(n.iftrue) + ') : '
+ s += '(' + self._visit_expr(n.iffalse) + ')'
+ return s
+
+ def visit_If(self, n):
+ s = 'if ('
+ if n.cond: s += self.visit(n.cond)
+ s += ')\n'
+ s += self._generate_stmt(n.iftrue, add_indent=True)
+ if n.iffalse:
+ s += self._make_indent() + 'else\n'
+ s += self._generate_stmt(n.iffalse, add_indent=True)
+ return s
+
+ def visit_For(self, n):
+ s = 'for ('
+ if n.init: s += self.visit(n.init)
+ s += ';'
+ if n.cond: s += ' ' + self.visit(n.cond)
+ s += ';'
+ if n.next: s += ' ' + self.visit(n.next)
+ s += ')\n'
+ s += self._generate_stmt(n.stmt, add_indent=True)
+ return s
+
+ def visit_While(self, n):
+ s = 'while ('
+ if n.cond: s += self.visit(n.cond)
+ s += ')\n'
+ s += self._generate_stmt(n.stmt, add_indent=True)
+ return s
+
+ def visit_DoWhile(self, n):
+ s = 'do\n'
+ s += self._generate_stmt(n.stmt, add_indent=True)
+ s += self._make_indent() + 'while ('
+ if n.cond: s += self.visit(n.cond)
+ s += ');'
+ return s
+
+ def visit_StaticAssert(self, n):
+ s = '_Static_assert('
+ s += self.visit(n.cond)
+ if n.message:
+ s += ','
+ s += self.visit(n.message)
+ s += ')'
+ return s
+
+ def visit_Switch(self, n):
+ s = 'switch (' + self.visit(n.cond) + ')\n'
+ s += self._generate_stmt(n.stmt, add_indent=True)
+ return s
+
+ def visit_Case(self, n):
+ s = 'case ' + self.visit(n.expr) + ':\n'
+ for stmt in n.stmts:
+ s += self._generate_stmt(stmt, add_indent=True)
+ return s
+
+ def visit_Default(self, n):
+ s = 'default:\n'
+ for stmt in n.stmts:
+ s += self._generate_stmt(stmt, add_indent=True)
+ return s
+
+ def visit_Label(self, n):
+ return n.name + ':\n' + self._generate_stmt(n.stmt)
+
+ def visit_Goto(self, n):
+ return 'goto ' + n.name + ';'
+
+ def visit_EllipsisParam(self, n):
+ return '...'
+
+ def visit_Struct(self, n):
+ return self._generate_struct_union_enum(n, 'struct')
+
+ def visit_Typename(self, n):
+ return self._generate_type(n.type)
+
+ def visit_Union(self, n):
+ return self._generate_struct_union_enum(n, 'union')
+
+ def visit_NamedInitializer(self, n):
+ s = ''
+ for name in n.name:
+ if isinstance(name, c_ast.ID):
+ s += '.' + name.name
+ else:
+ s += '[' + self.visit(name) + ']'
+ s += ' = ' + self._visit_expr(n.expr)
+ return s
+
+ def visit_FuncDecl(self, n):
+ return self._generate_type(n)
+
+ def visit_ArrayDecl(self, n):
+ return self._generate_type(n, emit_declname=False)
+
+ def visit_TypeDecl(self, n):
+ return self._generate_type(n, emit_declname=False)
+
+ def visit_PtrDecl(self, n):
+ return self._generate_type(n, emit_declname=False)
+
+ def _generate_struct_union_enum(self, n, name):
+ """ Generates code for structs, unions, and enums. name should be
+ 'struct', 'union', or 'enum'.
+ """
+ if name in ('struct', 'union'):
+ members = n.decls
+ body_function = self._generate_struct_union_body
+ else:
+ assert name == 'enum'
+ members = None if n.values is None else n.values.enumerators
+ body_function = self._generate_enum_body
+ s = name + ' ' + (n.name or '')
+ if members is not None:
+ # None means no members
+ # Empty sequence means an empty list of members
+ s += '\n'
+ s += self._make_indent()
+ self.indent_level += 2
+ s += '{\n'
+ s += body_function(members)
+ self.indent_level -= 2
+ s += self._make_indent() + '}'
+ return s
+
+ def _generate_struct_union_body(self, members):
+ return ''.join(self._generate_stmt(decl) for decl in members)
+
+ def _generate_enum_body(self, members):
+ # `[:-2] + '\n'` removes the final `,` from the enumerator list
+ return ''.join(self.visit(value) for value in members)[:-2] + '\n'
+
+ def _generate_stmt(self, n, add_indent=False):
+ """ Generation from a statement node. This method exists as a wrapper
+ for individual visit_* methods to handle different treatment of
+ some statements in this context.
+ """
+ typ = type(n)
+ if add_indent: self.indent_level += 2
+ indent = self._make_indent()
+ if add_indent: self.indent_level -= 2
+
+ if typ in (
+ c_ast.Decl, c_ast.Assignment, c_ast.Cast, c_ast.UnaryOp,
+ c_ast.BinaryOp, c_ast.TernaryOp, c_ast.FuncCall, c_ast.ArrayRef,
+ c_ast.StructRef, c_ast.Constant, c_ast.ID, c_ast.Typedef,
+ c_ast.ExprList):
+ # These can also appear in an expression context so no semicolon
+ # is added to them automatically
+ #
+ return indent + self.visit(n) + ';\n'
+ elif typ in (c_ast.Compound,):
+ # No extra indentation required before the opening brace of a
+ # compound - because it consists of multiple lines it has to
+ # compute its own indentation.
+ #
+ return self.visit(n)
+ elif typ in (c_ast.If,):
+ return indent + self.visit(n)
+ else:
+ return indent + self.visit(n) + '\n'
+
+ def _generate_decl(self, n):
+ """ Generation from a Decl node.
+ """
+ s = ''
+ if n.funcspec: s = ' '.join(n.funcspec) + ' '
+ if n.storage: s += ' '.join(n.storage) + ' '
+ if n.align: s += self.visit(n.align[0]) + ' '
+ s += self._generate_type(n.type)
+ return s
+
+ def _generate_type(self, n, modifiers=[], emit_declname = True):
+ """ Recursive generation from a type node. n is the type node.
+ modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers
+ encountered on the way down to a TypeDecl, to allow proper
+ generation from it.
+ """
+ typ = type(n)
+ #~ print(n, modifiers)
+
+ if typ == c_ast.TypeDecl:
+ s = ''
+ if n.quals: s += ' '.join(n.quals) + ' '
+ s += self.visit(n.type)
+
+ nstr = n.declname if n.declname and emit_declname else ''
+ # Resolve modifiers.
+ # Wrap in parens to distinguish pointer to array and pointer to
+ # function syntax.
+ #
+ for i, modifier in enumerate(modifiers):
+ if isinstance(modifier, c_ast.ArrayDecl):
+ if (i != 0 and
+ isinstance(modifiers[i - 1], c_ast.PtrDecl)):
+ nstr = '(' + nstr + ')'
+ nstr += '['
+ if modifier.dim_quals:
+ nstr += ' '.join(modifier.dim_quals) + ' '
+ nstr += self.visit(modifier.dim) + ']'
+ elif isinstance(modifier, c_ast.FuncDecl):
+ if (i != 0 and
+ isinstance(modifiers[i - 1], c_ast.PtrDecl)):
+ nstr = '(' + nstr + ')'
+ nstr += '(' + self.visit(modifier.args) + ')'
+ elif isinstance(modifier, c_ast.PtrDecl):
+ if modifier.quals:
+ nstr = '* %s%s' % (' '.join(modifier.quals),
+ ' ' + nstr if nstr else '')
+ else:
+ nstr = '*' + nstr
+ if nstr: s += ' ' + nstr
+ return s
+ elif typ == c_ast.Decl:
+ return self._generate_decl(n.type)
+ elif typ == c_ast.Typename:
+ return self._generate_type(n.type, emit_declname = emit_declname)
+ elif typ == c_ast.IdentifierType:
+ return ' '.join(n.names) + ' '
+ elif typ in (c_ast.ArrayDecl, c_ast.PtrDecl, c_ast.FuncDecl):
+ return self._generate_type(n.type, modifiers + [n],
+ emit_declname = emit_declname)
+ else:
+ return self.visit(n)
+
+ def _parenthesize_if(self, n, condition):
+ """ Visits 'n' and returns its string representation, parenthesized
+ if the condition function applied to the node returns True.
+ """
+ s = self._visit_expr(n)
+ if condition(n):
+ return '(' + s + ')'
+ else:
+ return s
+
+ def _parenthesize_unless_simple(self, n):
+ """ Common use case for _parenthesize_if
+ """
+ return self._parenthesize_if(n, lambda d: not self._is_simple_node(d))
+
+ def _is_simple_node(self, n):
+ """ Returns True for nodes that are "simple" - i.e. nodes that always
+ have higher precedence than operators.
+ """
+ return isinstance(n, (c_ast.Constant, c_ast.ID, c_ast.ArrayRef,
+ c_ast.StructRef, c_ast.FuncCall))
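To make the precedence handling in visit_BinaryOp concrete, here is a short editor's sketch (not part of the patch) showing how the reduce_parentheses flag affects output; it again assumes the vendored lib/ directory is on sys.path so the package imports as pycparser.

from pycparser import c_ast
from pycparser.c_generator import CGenerator

# Hand-built AST equivalent to the left-associative expression a + b + c.
expr = c_ast.BinaryOp(
    '+',
    c_ast.BinaryOp('+', c_ast.ID('a'), c_ast.ID('b')),
    c_ast.ID('c'))

print(CGenerator().visit(expr))                         # (a + b) + c
print(CGenerator(reduce_parentheses=True).visit(expr))  # a + b + c  (equal precedence, parens dropped)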
diff --git a/lib/pycparser/c_lexer.py b/lib/pycparser/c_lexer.py
new file mode 100644
index 0000000..d68d8eb
--- /dev/null
+++ b/lib/pycparser/c_lexer.py
@@ -0,0 +1,554 @@
+#------------------------------------------------------------------------------
+# pycparser: c_lexer.py
+#
+# CLexer class: lexer for the C language
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#------------------------------------------------------------------------------
+import re
+
+from .ply import lex
+from .ply.lex import TOKEN
+
+
+class CLexer(object):
+ """ A lexer for the C language. After building it, set the
+ input text with input(), and call token() to get new
+ tokens.
+
+ The public attribute filename can be set to an initial
+ filename, but the lexer will update it upon #line
+ directives.
+ """
+ def __init__(self, error_func, on_lbrace_func, on_rbrace_func,
+ type_lookup_func):
+ """ Create a new Lexer.
+
+ error_func:
+ An error function. Will be called with an error
+ message, line and column as arguments, in case of
+ an error during lexing.
+
+ on_lbrace_func, on_rbrace_func:
+ Called when an LBRACE or RBRACE is encountered
+ (likely to push/pop type_lookup_func's scope)
+
+ type_lookup_func:
+ A type lookup function. Given a string, it must
+ return True IFF this string is a name of a type
+ that was defined with a typedef earlier.
+ """
+ self.error_func = error_func
+ self.on_lbrace_func = on_lbrace_func
+ self.on_rbrace_func = on_rbrace_func
+ self.type_lookup_func = type_lookup_func
+ self.filename = ''
+
+ # Keeps track of the last token returned from self.token()
+ self.last_token = None
+
+ # Allow either "# line" or "# <num>" to support GCC's
+ # cpp output
+ #
+ self.line_pattern = re.compile(r'([ \t]*line\W)|([ \t]*\d+)')
+ self.pragma_pattern = re.compile(r'[ \t]*pragma\W')
+
+ def build(self, **kwargs):
+ """ Builds the lexer from the specification. Must be
+ called after the lexer object is created.
+
+ This method exists separately, because the PLY
+ manual warns against calling lex.lex inside
+ __init__
+ """
+ self.lexer = lex.lex(object=self, **kwargs)
+
+ def reset_lineno(self):
+ """ Resets the internal line number counter of the lexer.
+ """
+ self.lexer.lineno = 1
+
+ def input(self, text):
+ self.lexer.input(text)
+
+ def token(self):
+ self.last_token = self.lexer.token()
+ return self.last_token
+
+ def find_tok_column(self, token):
+ """ Find the column of the token in its line.
+ """
+ last_cr = self.lexer.lexdata.rfind('\n', 0, token.lexpos)
+ return token.lexpos - last_cr
+
+ ######################-- PRIVATE --######################
+
+ ##
+ ## Internal auxiliary methods
+ ##
+ def _error(self, msg, token):
+ location = self._make_tok_location(token)
+ self.error_func(msg, location[0], location[1])
+ self.lexer.skip(1)
+
+ def _make_tok_location(self, token):
+ return (token.lineno, self.find_tok_column(token))
+
+ ##
+ ## Reserved keywords
+ ##
+ keywords = (
+ 'AUTO', 'BREAK', 'CASE', 'CHAR', 'CONST',
+ 'CONTINUE', 'DEFAULT', 'DO', 'DOUBLE', 'ELSE', 'ENUM', 'EXTERN',
+ 'FLOAT', 'FOR', 'GOTO', 'IF', 'INLINE', 'INT', 'LONG',
+ 'REGISTER', 'OFFSETOF',
+ 'RESTRICT', 'RETURN', 'SHORT', 'SIGNED', 'SIZEOF', 'STATIC', 'STRUCT',
+ 'SWITCH', 'TYPEDEF', 'UNION', 'UNSIGNED', 'VOID',
+ 'VOLATILE', 'WHILE', '__INT128',
+ )
+
+ keywords_new = (
+ '_BOOL', '_COMPLEX',
+ '_NORETURN', '_THREAD_LOCAL', '_STATIC_ASSERT',
+ '_ATOMIC', '_ALIGNOF', '_ALIGNAS',
+ )
+
+ keyword_map = {}
+
+ for keyword in keywords:
+ keyword_map[keyword.lower()] = keyword
+
+ for keyword in keywords_new:
+ keyword_map[keyword[:2].upper() + keyword[2:].lower()] = keyword
+
+ ##
+ ## All the tokens recognized by the lexer
+ ##
+ tokens = keywords + keywords_new + (
+ # Identifiers
+ 'ID',
+
+ # Type identifiers (identifiers previously defined as
+ # types with typedef)
+ 'TYPEID',
+
+ # constants
+ 'INT_CONST_DEC', 'INT_CONST_OCT', 'INT_CONST_HEX', 'INT_CONST_BIN', 'INT_CONST_CHAR',
+ 'FLOAT_CONST', 'HEX_FLOAT_CONST',
+ 'CHAR_CONST',
+ 'WCHAR_CONST',
+ 'U8CHAR_CONST',
+ 'U16CHAR_CONST',
+ 'U32CHAR_CONST',
+
+ # String literals
+ 'STRING_LITERAL',
+ 'WSTRING_LITERAL',
+ 'U8STRING_LITERAL',
+ 'U16STRING_LITERAL',
+ 'U32STRING_LITERAL',
+
+ # Operators
+ 'PLUS', 'MINUS', 'TIMES', 'DIVIDE', 'MOD',
+ 'OR', 'AND', 'NOT', 'XOR', 'LSHIFT', 'RSHIFT',
+ 'LOR', 'LAND', 'LNOT',
+ 'LT', 'LE', 'GT', 'GE', 'EQ', 'NE',
+
+ # Assignment
+ 'EQUALS', 'TIMESEQUAL', 'DIVEQUAL', 'MODEQUAL',
+ 'PLUSEQUAL', 'MINUSEQUAL',
+ 'LSHIFTEQUAL','RSHIFTEQUAL', 'ANDEQUAL', 'XOREQUAL',
+ 'OREQUAL',
+
+ # Increment/decrement
+ 'PLUSPLUS', 'MINUSMINUS',
+
+ # Structure dereference (->)
+ 'ARROW',
+
+ # Conditional operator (?)
+ 'CONDOP',
+
+ # Delimiters
+ 'LPAREN', 'RPAREN', # ( )
+ 'LBRACKET', 'RBRACKET', # [ ]
+ 'LBRACE', 'RBRACE', # { }
+ 'COMMA', 'PERIOD', # . ,
+ 'SEMI', 'COLON', # ; :
+
+ # Ellipsis (...)
+ 'ELLIPSIS',
+
+ # pre-processor
+ 'PPHASH', # '#'
+ 'PPPRAGMA', # 'pragma'
+ 'PPPRAGMASTR',
+ )
+
+ ##
+ ## Regexes for use in tokens
+ ##
+ ##
+
+ # valid C identifiers (K&R2: A.2.3), plus '$' (supported by some compilers)
+ identifier = r'[a-zA-Z_$][0-9a-zA-Z_$]*'
+
+ hex_prefix = '0[xX]'
+ hex_digits = '[0-9a-fA-F]+'
+ bin_prefix = '0[bB]'
+ bin_digits = '[01]+'
+
+ # integer constants (K&R2: A.2.5.1)
+ integer_suffix_opt = r'(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?'
+ decimal_constant = '(0'+integer_suffix_opt+')|([1-9][0-9]*'+integer_suffix_opt+')'
+ octal_constant = '0[0-7]*'+integer_suffix_opt
+ hex_constant = hex_prefix+hex_digits+integer_suffix_opt
+ bin_constant = bin_prefix+bin_digits+integer_suffix_opt
+
+ bad_octal_constant = '0[0-7]*[89]'
+
+ # character constants (K&R2: A.2.5.2)
+ # Note: a-zA-Z and '.-~^_!=&;,' are allowed as escape chars to support #line
+ # directives with Windows paths as filenames (..\..\dir\file)
+ # For the same reason, decimal_escape allows all digit sequences. We want to
+ # parse all correct code, even if it means sometimes parsing incorrect
+ # code.
+ #
+ # The original regexes were taken verbatim from the C syntax definition,
+ # and were later modified to avoid worst-case exponential running time.
+ #
+ # simple_escape = r"""([a-zA-Z._~!=&\^\-\\?'"])"""
+ # decimal_escape = r"""(\d+)"""
+ # hex_escape = r"""(x[0-9a-fA-F]+)"""
+ # bad_escape = r"""([\\][^a-zA-Z._~^!=&\^\-\\?'"x0-7])"""
+ #
+ # The following modifications were made to avoid the ambiguity that allowed backtracking:
+ # (https://github.com/eliben/pycparser/issues/61)
+ #
+ # - \x was removed from simple_escape, unless it was not followed by a hex digit, to avoid ambiguity with hex_escape.
+ # - hex_escape allows one or more hex characters, but requires that the next character (if any) is not hex
+ # - decimal_escape allows one or more decimal characters, but requires that the next character (if any) is not a decimal
+ # - bad_escape does not allow any decimals (8-9), to avoid conflicting with the permissive decimal_escape.
+ #
+ # Without this change, python's `re` module would recursively try parsing each ambiguous escape sequence in multiple ways.
+ # e.g. `\123` could be parsed as `\1`+`23`, `\12`+`3`, and `\123`.
+
+ simple_escape = r"""([a-wyzA-Z._~!=&\^\-\\?'"]|x(?![0-9a-fA-F]))"""
+ decimal_escape = r"""(\d+)(?!\d)"""
+ hex_escape = r"""(x[0-9a-fA-F]+)(?![0-9a-fA-F])"""
+ bad_escape = r"""([\\][^a-zA-Z._~^!=&\^\-\\?'"x0-9])"""
+
+ escape_sequence = r"""(\\("""+simple_escape+'|'+decimal_escape+'|'+hex_escape+'))'
+
+ # The lookahead in this regex might be slow for strings; because all of the valid escapes (including \x) allow
+ # 0 or more non-escaped characters after the first character, simple_escape+decimal_escape+hex_escape is simplified for strings to
+
+ escape_sequence_start_in_string = r"""(\\[0-9a-zA-Z._~!=&\^\-\\?'"])"""
+
+ cconst_char = r"""([^'\\\n]|"""+escape_sequence+')'
+ char_const = "'"+cconst_char+"'"
+ wchar_const = 'L'+char_const
+ u8char_const = 'u8'+char_const
+ u16char_const = 'u'+char_const
+ u32char_const = 'U'+char_const
+ multicharacter_constant = "'"+cconst_char+"{2,4}'"
+ unmatched_quote = "('"+cconst_char+"*\\n)|('"+cconst_char+"*$)"
+ bad_char_const = r"""('"""+cconst_char+"""[^'\n]+')|('')|('"""+bad_escape+r"""[^'\n]*')"""
+
+ # string literals (K&R2: A.2.6)
+ string_char = r"""([^"\\\n]|"""+escape_sequence_start_in_string+')'
+ string_literal = '"'+string_char+'*"'
+ wstring_literal = 'L'+string_literal
+ u8string_literal = 'u8'+string_literal
+ u16string_literal = 'u'+string_literal
+ u32string_literal = 'U'+string_literal
+ bad_string_literal = '"'+string_char+'*'+bad_escape+string_char+'*"'
+
+ # floating constants (K&R2: A.2.5.3)
+ exponent_part = r"""([eE][-+]?[0-9]+)"""
+ fractional_constant = r"""([0-9]*\.[0-9]+)|([0-9]+\.)"""
+ floating_constant = '(((('+fractional_constant+')'+exponent_part+'?)|([0-9]+'+exponent_part+'))[FfLl]?)'
+ binary_exponent_part = r'''([pP][+-]?[0-9]+)'''
+ hex_fractional_constant = '((('+hex_digits+r""")?\."""+hex_digits+')|('+hex_digits+r"""\.))"""
+ hex_floating_constant = '('+hex_prefix+'('+hex_digits+'|'+hex_fractional_constant+')'+binary_exponent_part+'[FfLl]?)'
+
+ ##
+ ## Lexer states: used for preprocessor \n-terminated directives
+ ##
+ states = (
+ # ppline: preprocessor line directives
+ #
+ ('ppline', 'exclusive'),
+
+ # pppragma: pragma
+ #
+ ('pppragma', 'exclusive'),
+ )
+
+ def t_PPHASH(self, t):
+ r'[ \t]*\#'
+ if self.line_pattern.match(t.lexer.lexdata, pos=t.lexer.lexpos):
+ t.lexer.begin('ppline')
+ self.pp_line = self.pp_filename = None
+ elif self.pragma_pattern.match(t.lexer.lexdata, pos=t.lexer.lexpos):
+ t.lexer.begin('pppragma')
+ else:
+ t.type = 'PPHASH'
+ return t
+
+ ##
+ ## Rules for the ppline state
+ ##
+ @TOKEN(string_literal)
+ def t_ppline_FILENAME(self, t):
+ if self.pp_line is None:
+ self._error('filename before line number in #line', t)
+ else:
+ self.pp_filename = t.value.lstrip('"').rstrip('"')
+
+ @TOKEN(decimal_constant)
+ def t_ppline_LINE_NUMBER(self, t):
+ if self.pp_line is None:
+ self.pp_line = t.value
+ else:
+ # Ignore: GCC's cpp sometimes inserts a numeric flag
+ # after the file name
+ pass
+
+ def t_ppline_NEWLINE(self, t):
+ r'\n'
+ if self.pp_line is None:
+ self._error('line number missing in #line', t)
+ else:
+ self.lexer.lineno = int(self.pp_line)
+
+ if self.pp_filename is not None:
+ self.filename = self.pp_filename
+
+ t.lexer.begin('INITIAL')
+
+ def t_ppline_PPLINE(self, t):
+ r'line'
+ pass
+
+ t_ppline_ignore = ' \t'
+
+ def t_ppline_error(self, t):
+ self._error('invalid #line directive', t)
+
+ ##
+ ## Rules for the pppragma state
+ ##
+ def t_pppragma_NEWLINE(self, t):
+ r'\n'
+ t.lexer.lineno += 1
+ t.lexer.begin('INITIAL')
+
+ def t_pppragma_PPPRAGMA(self, t):
+ r'pragma'
+ return t
+
+ t_pppragma_ignore = ' \t'
+
+ def t_pppragma_STR(self, t):
+ '.+'
+ t.type = 'PPPRAGMASTR'
+ return t
+
+ def t_pppragma_error(self, t):
+ self._error('invalid #pragma directive', t)
+
+ ##
+ ## Rules for the normal state
+ ##
+ t_ignore = ' \t'
+
+ # Newlines
+ def t_NEWLINE(self, t):
+ r'\n+'
+ t.lexer.lineno += t.value.count("\n")
+
+ # Operators
+ t_PLUS = r'\+'
+ t_MINUS = r'-'
+ t_TIMES = r'\*'
+ t_DIVIDE = r'/'
+ t_MOD = r'%'
+ t_OR = r'\|'
+ t_AND = r'&'
+ t_NOT = r'~'
+ t_XOR = r'\^'
+ t_LSHIFT = r'<<'
+ t_RSHIFT = r'>>'
+ t_LOR = r'\|\|'
+ t_LAND = r'&&'
+ t_LNOT = r'!'
+ t_LT = r'<'
+ t_GT = r'>'
+ t_LE = r'<='
+ t_GE = r'>='
+ t_EQ = r'=='
+ t_NE = r'!='
+
+ # Assignment operators
+ t_EQUALS = r'='
+ t_TIMESEQUAL = r'\*='
+ t_DIVEQUAL = r'/='
+ t_MODEQUAL = r'%='
+ t_PLUSEQUAL = r'\+='
+ t_MINUSEQUAL = r'-='
+ t_LSHIFTEQUAL = r'<<='
+ t_RSHIFTEQUAL = r'>>='
+ t_ANDEQUAL = r'&='
+ t_OREQUAL = r'\|='
+ t_XOREQUAL = r'\^='
+
+ # Increment/decrement
+ t_PLUSPLUS = r'\+\+'
+ t_MINUSMINUS = r'--'
+
+ # ->
+ t_ARROW = r'->'
+
+ # ?
+ t_CONDOP = r'\?'
+
+ # Delimiters
+ t_LPAREN = r'\('
+ t_RPAREN = r'\)'
+ t_LBRACKET = r'\['
+ t_RBRACKET = r'\]'
+ t_COMMA = r','
+ t_PERIOD = r'\.'
+ t_SEMI = r';'
+ t_COLON = r':'
+ t_ELLIPSIS = r'\.\.\.'
+
+ # Scope delimiters
+ # To see why on_lbrace_func is needed, consider:
+ # typedef char TT;
+ # void foo(int TT) { TT = 10; }
+ # TT x = 5;
+ # Outside the function, TT is a typedef, but inside (starting and ending
+ # with the braces) it's a parameter. The trouble begins with yacc's
+ # lookahead token. If we open a new scope in brace_open, then TT has
+ # already been read and incorrectly interpreted as TYPEID. So, we need
+ # to open and close scopes from within the lexer.
+ # Similar for the TT immediately outside the end of the function.
+ #
+ @TOKEN(r'\{')
+ def t_LBRACE(self, t):
+ self.on_lbrace_func()
+ return t
+ @TOKEN(r'\}')
+ def t_RBRACE(self, t):
+ self.on_rbrace_func()
+ return t
+
+ t_STRING_LITERAL = string_literal
+
+ # The following floating and integer constants are defined as
+ # functions to impose a strict order (otherwise, decimal
+ # is placed before the others because its regex is longer,
+ # and this is bad)
+ #
+ @TOKEN(floating_constant)
+ def t_FLOAT_CONST(self, t):
+ return t
+
+ @TOKEN(hex_floating_constant)
+ def t_HEX_FLOAT_CONST(self, t):
+ return t
+
+ @TOKEN(hex_constant)
+ def t_INT_CONST_HEX(self, t):
+ return t
+
+ @TOKEN(bin_constant)
+ def t_INT_CONST_BIN(self, t):
+ return t
+
+ @TOKEN(bad_octal_constant)
+ def t_BAD_CONST_OCT(self, t):
+ msg = "Invalid octal constant"
+ self._error(msg, t)
+
+ @TOKEN(octal_constant)
+ def t_INT_CONST_OCT(self, t):
+ return t
+
+ @TOKEN(decimal_constant)
+ def t_INT_CONST_DEC(self, t):
+ return t
+
+ # Must come before bad_char_const, to prevent it from
+ # catching valid char constants as invalid
+ #
+ @TOKEN(multicharacter_constant)
+ def t_INT_CONST_CHAR(self, t):
+ return t
+
+ @TOKEN(char_const)
+ def t_CHAR_CONST(self, t):
+ return t
+
+ @TOKEN(wchar_const)
+ def t_WCHAR_CONST(self, t):
+ return t
+
+ @TOKEN(u8char_const)
+ def t_U8CHAR_CONST(self, t):
+ return t
+
+ @TOKEN(u16char_const)
+ def t_U16CHAR_CONST(self, t):
+ return t
+
+ @TOKEN(u32char_const)
+ def t_U32CHAR_CONST(self, t):
+ return t
+
+ @TOKEN(unmatched_quote)
+ def t_UNMATCHED_QUOTE(self, t):
+ msg = "Unmatched '"
+ self._error(msg, t)
+
+ @TOKEN(bad_char_const)
+ def t_BAD_CHAR_CONST(self, t):
+ msg = "Invalid char constant %s" % t.value
+ self._error(msg, t)
+
+ @TOKEN(wstring_literal)
+ def t_WSTRING_LITERAL(self, t):
+ return t
+
+ @TOKEN(u8string_literal)
+ def t_U8STRING_LITERAL(self, t):
+ return t
+
+ @TOKEN(u16string_literal)
+ def t_U16STRING_LITERAL(self, t):
+ return t
+
+ @TOKEN(u32string_literal)
+ def t_U32STRING_LITERAL(self, t):
+ return t
+
+ # unmatched string literals are caught by the preprocessor
+
+ @TOKEN(bad_string_literal)
+ def t_BAD_STRING_LITERAL(self, t):
+ msg = "String contains invalid escape code"
+ self._error(msg, t)
+
+ @TOKEN(identifier)
+ def t_ID(self, t):
+ t.type = self.keyword_map.get(t.value, "ID")
+ if t.type == 'ID' and self.type_lookup_func(t.value):
+ t.type = "TYPEID"
+ return t
+
+ def t_error(self, t):
+ msg = 'Illegal character %s' % repr(t.value[0])
+ self._error(msg, t)
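The lexer can be driven on its own, which makes the typedef handling described above easy to observe. This is an editor's sketch with stub callbacks (CParser normally wires them to its scope stack), not part of the committed file; it assumes the vendored lib/ directory is on sys.path.

from pycparser.c_lexer import CLexer

typedef_names = {'uint32_t'}   # pretend this name was typedef'd earlier

def error_func(msg, line, column):
    raise RuntimeError('%s at %s:%s' % (msg, line, column))

clex = CLexer(error_func=error_func,
              on_lbrace_func=lambda: None,
              on_rbrace_func=lambda: None,
              type_lookup_func=lambda name: name in typedef_names)
clex.build()

clex.input('uint32_t x = 0x2A;')
tok = clex.token()
while tok is not None:
    print(tok.type, tok.value)   # TYPEID, ID, EQUALS, INT_CONST_HEX, SEMI
    tok = clex.token()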
diff --git a/lib/pycparser/c_parser.py b/lib/pycparser/c_parser.py
new file mode 100644
index 0000000..640a759
--- /dev/null
+++ b/lib/pycparser/c_parser.py
@@ -0,0 +1,1936 @@
+#------------------------------------------------------------------------------
+# pycparser: c_parser.py
+#
+# CParser class: Parser and AST builder for the C language
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#------------------------------------------------------------------------------
+from .ply import yacc
+
+from . import c_ast
+from .c_lexer import CLexer
+from .plyparser import PLYParser, ParseError, parameterized, template
+from .ast_transforms import fix_switch_cases, fix_atomic_specifiers
+
+
+@template
+class CParser(PLYParser):
+ def __init__(
+ self,
+ lex_optimize=True,
+ lexer=CLexer,
+ lextab='pycparser.lextab',
+ yacc_optimize=True,
+ yacctab='pycparser.yacctab',
+ yacc_debug=False,
+ taboutputdir=''):
+ """ Create a new CParser.
+
+ Some arguments for controlling the debug/optimization
+ level of the parser are provided. The defaults are
+ tuned for release/performance mode.
+ The simple rules for using them are:
+ *) When tweaking CParser/CLexer, set these to False
+ *) When releasing a stable parser, set to True
+
+ lex_optimize:
+ Set to False when you're modifying the lexer.
+ Otherwise, changes in the lexer won't be used, if
+ some lextab.py file exists.
+ When releasing with a stable lexer, set to True
+ to save the re-generation of the lexer table on
+ each run.
+
+ lexer:
+ Set this parameter to define the lexer to use if
+ you're not using the default CLexer.
+
+ lextab:
+ Points to the lex table that's used for optimized
+ mode. Only if you're modifying the lexer and want
+ some tests to avoid re-generating the table, make
+ this point to a local lex table file (that's been
+ earlier generated with lex_optimize=True)
+
+ yacc_optimize:
+ Set to False when you're modifying the parser.
+ Otherwise, changes in the parser won't be used, if
+ some parsetab.py file exists.
+ When releasing with a stable parser, set to True
+ to save the re-generation of the parser table on
+ each run.
+
+ yacctab:
+ Points to the yacc table that's used for optimized
+ mode. Only if you're modifying the parser, make
+ this point to a local yacc table file
+
+ yacc_debug:
+ Generate a parser.out file that explains how yacc
+ built the parsing table from the grammar.
+
+ taboutputdir:
+ Set this parameter to control the location of generated
+ lextab and yacctab files.
+ """
+ self.clex = lexer(
+ error_func=self._lex_error_func,
+ on_lbrace_func=self._lex_on_lbrace_func,
+ on_rbrace_func=self._lex_on_rbrace_func,
+ type_lookup_func=self._lex_type_lookup_func)
+
+ self.clex.build(
+ optimize=lex_optimize,
+ lextab=lextab,
+ outputdir=taboutputdir)
+ self.tokens = self.clex.tokens
+
+ rules_with_opt = [
+ 'abstract_declarator',
+ 'assignment_expression',
+ 'declaration_list',
+ 'declaration_specifiers_no_type',
+ 'designation',
+ 'expression',
+ 'identifier_list',
+ 'init_declarator_list',
+ 'id_init_declarator_list',
+ 'initializer_list',
+ 'parameter_type_list',
+ 'block_item_list',
+ 'type_qualifier_list',
+ 'struct_declarator_list'
+ ]
+
+ for rule in rules_with_opt:
+ self._create_opt_rule(rule)
+
+ self.cparser = yacc.yacc(
+ module=self,
+ start='translation_unit_or_empty',
+ debug=yacc_debug,
+ optimize=yacc_optimize,
+ tabmodule=yacctab,
+ outputdir=taboutputdir)
+
+ # Stack of scopes for keeping track of symbols. _scope_stack[-1] is
+ # the current (topmost) scope. Each scope is a dictionary that
+ # specifies whether a name is a type. If _scope_stack[n][name] is
+ # True, 'name' is currently a type in the scope. If it's False,
+ # 'name' is used in the scope but not as a type (for instance, if we
+ # saw: int name;
+ # If 'name' is not a key in _scope_stack[n] then 'name' was not defined
+ # in this scope at all.
+ self._scope_stack = [dict()]
+
+ # Keeps track of the last token given to yacc (the lookahead token)
+ self._last_yielded_token = None
+
+ def parse(self, text, filename='', debug=False):
+ """ Parses C code and returns an AST.
+
+ text:
+ A string containing the C source code
+
+ filename:
+ Name of the file being parsed (for meaningful
+ error messages)
+
+ debug:
+ Debug flag to YACC
+ """
+ self.clex.filename = filename
+ self.clex.reset_lineno()
+ self._scope_stack = [dict()]
+ self._last_yielded_token = None
+ return self.cparser.parse(
+ input=text,
+ lexer=self.clex,
+ debug=debug)
+
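A quick end-to-end sketch of the parser defined in this file (an editor's illustration, not part of the patch); it assumes the vendored lib/ directory is on sys.path so the package imports as pycparser.

from pycparser.c_parser import CParser

# Regenerate the PLY tables while experimenting, as the docstring above advises.
parser = CParser(lex_optimize=False, yacc_optimize=False)
ast = parser.parse('int max(int a, int b) { return a > b ? a : b; }',
                   filename='<demo>')
ast.show(attrnames=True, nodenames=True)   # dumps the FileAST using Node.show()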
+ ######################-- PRIVATE --######################
+
+ def _push_scope(self):
+ self._scope_stack.append(dict())
+
+ def _pop_scope(self):
+ assert len(self._scope_stack) > 1
+ self._scope_stack.pop()
+
+ def _add_typedef_name(self, name, coord):
+ """ Add a new typedef name (ie a TYPEID) to the current scope
+ """
+ if not self._scope_stack[-1].get(name, True):
+ self._parse_error(
+ "Typedef %r previously declared as non-typedef "
+ "in this scope" % name, coord)
+ self._scope_stack[-1][name] = True
+
+ def _add_identifier(self, name, coord):
+ """ Add a new object, function, or enum member name (ie an ID) to the
+ current scope
+ """
+ if self._scope_stack[-1].get(name, False):
+ self._parse_error(
+ "Non-typedef %r previously declared as typedef "
+ "in this scope" % name, coord)
+ self._scope_stack[-1][name] = False
+
+ def _is_type_in_scope(self, name):
+ """ Is *name* a typedef-name in the current scope?
+ """
+ for scope in reversed(self._scope_stack):
+ # If name is an identifier in this scope it shadows typedefs in
+ # higher scopes.
+ in_scope = scope.get(name)
+ if in_scope is not None: return in_scope
+ return False
+
+ def _lex_error_func(self, msg, line, column):
+ self._parse_error(msg, self._coord(line, column))
+
+ def _lex_on_lbrace_func(self):
+ self._push_scope()
+
+ def _lex_on_rbrace_func(self):
+ self._pop_scope()
+
+ def _lex_type_lookup_func(self, name):
+ """ Looks up types that were previously defined with
+ typedef.
+ Passed to the lexer for recognizing identifiers that
+ are types.
+ """
+ is_type = self._is_type_in_scope(name)
+ return is_type
+
+ def _get_yacc_lookahead_token(self):
+ """ We need access to yacc's lookahead token in certain cases.
+ This is the last token yacc requested from the lexer, so we
+ ask the lexer.
+ """
+ return self.clex.last_token
+
+ # To understand what's going on here, read sections A.8.5 and
+ # A.8.6 of K&R2 very carefully.
+ #
+ # A C type consists of a basic type declaration, with a list
+ # of modifiers. For example:
+ #
+ # int *c[5];
+ #
+ # The basic declaration here is 'int c', and the pointer and
+ # the array are the modifiers.
+ #
+ # Basic declarations are represented by TypeDecl (from module c_ast) and the
+ # modifiers are FuncDecl, PtrDecl and ArrayDecl.
+ #
+ # The standard states that whenever a new modifier is parsed, it should be
+ # added to the end of the list of modifiers. For example:
+ #
+ # K&R2 A.8.6.2: Array Declarators
+ #
+ # In a declaration T D where D has the form
+ # D1 [constant-expression-opt]
+ # and the type of the identifier in the declaration T D1 is
+ # "type-modifier T", the type of the
+ # identifier of D is "type-modifier array of T"
+ #
+ # This is what this method does. The declarator it receives
+ # can be a list of declarators ending with TypeDecl. It
+ # tacks the modifier to the end of this list, just before
+ # the TypeDecl.
+ #
+ # Additionally, the modifier may be a list itself. This is
+ # useful for pointers, that can come as a chain from the rule
+ # p_pointer. In this case, the whole modifier list is spliced
+ # into the new location.
+ def _type_modify_decl(self, decl, modifier):
+ """ Tacks a type modifier on a declarator, and returns
+ the modified declarator.
+
+ Note: the declarator and modifier may be modified
+ """
+ #~ print '****'
+ #~ decl.show(offset=3)
+ #~ modifier.show(offset=3)
+ #~ print '****'
+
+ modifier_head = modifier
+ modifier_tail = modifier
+
+ # The modifier may be a nested list. Reach its tail.
+ while modifier_tail.type:
+ modifier_tail = modifier_tail.type
+
+ # If the decl is a basic type, just tack the modifier onto it.
+ if isinstance(decl, c_ast.TypeDecl):
+ modifier_tail.type = decl
+ return modifier
+ else:
+ # Otherwise, the decl is a list of modifiers. Reach
+ # its tail and splice the modifier onto the tail,
+ # pointing to the underlying basic type.
+ decl_tail = decl
+
+ while not isinstance(decl_tail.type, c_ast.TypeDecl):
+ decl_tail = decl_tail.type
+
+ modifier_tail.type = decl_tail.type
+ decl_tail.type = modifier_head
+ return decl
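+
+ # Worked example (illustrative commentary, not upstream code): for
+ # "int *c[5]" the declarator arrives here as ArrayDecl(type=TypeDecl('c'))
+ # with a PtrDecl modifier; splicing yields
+ # ArrayDecl(type=PtrDecl(type=TypeDecl('c'))), i.e. the pointer lands just
+ # before the innermost TypeDecl, matching "array of pointers to int".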
+
+ # Due to the order in which declarators are constructed,
+ # they have to be fixed in order to look like a normal AST.
+ #
+ # When a declaration arrives from syntax construction, it has
+ # these problems:
+ # * The innermost TypeDecl has no type (because the basic
+ # type is only known at the uppermost declaration level)
+ # * The declaration has no variable name, since that is saved
+ # in the innermost TypeDecl
+ # * The typename of the declaration is a list of type
+ # specifiers, and not a node. Here, basic identifier types
+ # should be separated from more complex types like enums
+ # and structs.
+ #
+ # This method fixes these problems.
+ def _fix_decl_name_type(self, decl, typename):
+ """ Fixes a declaration. Modifies decl.
+ """
+ # Reach the underlying basic type
+ #
+ type = decl
+ while not isinstance(type, c_ast.TypeDecl):
+ type = type.type
+
+ decl.name = type.declname
+ type.quals = decl.quals[:]
+
+ # The typename is a list of types. If any type in this
+ # list isn't an IdentifierType, it must be the only
+ # type in the list (it's illegal to declare "int enum ..")
+ # If all the types are basic, they're collected in the
+ # IdentifierType holder.
+ for tn in typename:
+ if not isinstance(tn, c_ast.IdentifierType):
+ if len(typename) > 1:
+ self._parse_error(
+ "Invalid multiple types specified", tn.coord)
+ else:
+ type.type = tn
+ return decl
+
+ if not typename:
+ # Functions default to returning int
+ #
+ if not isinstance(decl.type, c_ast.FuncDecl):
+ self._parse_error(
+ "Missing type in declaration", decl.coord)
+ type.type = c_ast.IdentifierType(
+ ['int'],
+ coord=decl.coord)
+ else:
+ # At this point, we know that typename is a list of IdentifierType
+ # nodes. Concatenate all the names into a single list.
+ #
+ type.type = c_ast.IdentifierType(
+ [name for id in typename for name in id.names],
+ coord=typename[0].coord)
+ return decl
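+
+ # Illustrative note (editorial addition): for "unsigned long int x;" the
+ # typename list arrives as three one-name IdentifierType nodes and is
+ # collapsed here into a single IdentifierType(['unsigned', 'long', 'int'])
+ # attached to the innermost TypeDecl.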
+
+ def _add_declaration_specifier(self, declspec, newspec, kind, append=False):
+ """ Declaration specifiers are represented by a dictionary
+ with the entries:
+ * qual: a list of type qualifiers
+ * storage: a list of storage type qualifiers
+ * type: a list of type specifiers
+ * function: a list of function specifiers
+ * alignment: a list of alignment specifiers
+
+ This method is given a declaration specifier, and a
+ new specifier of a given kind.
+ If `append` is True, the new specifier is added to the end of
+ the specifiers list, otherwise it's added at the beginning.
+ Returns the declaration specifier, with the new
+ specifier incorporated.
+ """
+ spec = declspec or dict(qual=[], storage=[], type=[], function=[], alignment=[])
+
+ if append:
+ spec[kind].append(newspec)
+ else:
+ spec[kind].insert(0, newspec)
+
+ return spec
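+
+ # Hypothetical example (added, not upstream): starting from declspec=None,
+ # adding 'const' with kind='qual' yields
+ #   {'qual': ['const'], 'storage': [], 'type': [], 'function': [], 'alignment': []}
+ # and a later call with an IdentifierType(['int']) and kind='type' fills
+ # the 'type' list of the same dictionary.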
+
+ def _build_declarations(self, spec, decls, typedef_namespace=False):
+ """ Builds a list of declarations all sharing the given specifiers.
+ If typedef_namespace is true, each declared name is added
+ to the "typedef namespace", which also includes objects,
+ functions, and enum constants.
+ """
+ is_typedef = 'typedef' in spec['storage']
+ declarations = []
+
+ # Bit-fields are allowed to be unnamed.
+ if decls[0].get('bitsize') is not None:
+ pass
+
+ # When redeclaring typedef names as identifiers in inner scopes, a
+ # problem can occur where the identifier gets grouped into
+ # spec['type'], leaving decl as None. This can only occur for the
+ # first declarator.
+ elif decls[0]['decl'] is None:
+ if len(spec['type']) < 2 or len(spec['type'][-1].names) != 1 or \
+ not self._is_type_in_scope(spec['type'][-1].names[0]):
+ coord = '?'
+ for t in spec['type']:
+ if hasattr(t, 'coord'):
+ coord = t.coord
+ break
+ self._parse_error('Invalid declaration', coord)
+
+ # Make this look as if it came from "direct_declarator:ID"
+ decls[0]['decl'] = c_ast.TypeDecl(
+ declname=spec['type'][-1].names[0],
+ type=None,
+ quals=None,
+ align=spec['alignment'],
+ coord=spec['type'][-1].coord)
+ # Remove the "new" type's name from the end of spec['type']
+ del spec['type'][-1]
+
+ # A similar problem can occur where the declaration ends up looking
+ # like an abstract declarator. Give it a name if this is the case.
+ elif not isinstance(decls[0]['decl'], (
+ c_ast.Enum, c_ast.Struct, c_ast.Union, c_ast.IdentifierType)):
+ decls_0_tail = decls[0]['decl']
+ while not isinstance(decls_0_tail, c_ast.TypeDecl):
+ decls_0_tail = decls_0_tail.type
+ if decls_0_tail.declname is None:
+ decls_0_tail.declname = spec['type'][-1].names[0]
+ del spec['type'][-1]
+
+ for decl in decls:
+ assert decl['decl'] is not None
+ if is_typedef:
+ declaration = c_ast.Typedef(
+ name=None,
+ quals=spec['qual'],
+ storage=spec['storage'],
+ type=decl['decl'],
+ coord=decl['decl'].coord)
+ else:
+ declaration = c_ast.Decl(
+ name=None,
+ quals=spec['qual'],
+ align=spec['alignment'],
+ storage=spec['storage'],
+ funcspec=spec['function'],
+ type=decl['decl'],
+ init=decl.get('init'),
+ bitsize=decl.get('bitsize'),
+ coord=decl['decl'].coord)
+
+ if isinstance(declaration.type, (
+ c_ast.Enum, c_ast.Struct, c_ast.Union,
+ c_ast.IdentifierType)):
+ fixed_decl = declaration
+ else:
+ fixed_decl = self._fix_decl_name_type(declaration, spec['type'])
+
+ # Add the type name defined by typedef to a
+ # symbol table (for usage in the lexer)
+ if typedef_namespace:
+ if is_typedef:
+ self._add_typedef_name(fixed_decl.name, fixed_decl.coord)
+ else:
+ self._add_identifier(fixed_decl.name, fixed_decl.coord)
+
+ fixed_decl = fix_atomic_specifiers(fixed_decl)
+ declarations.append(fixed_decl)
+
+ return declarations
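+
+ # Illustrative behaviour (added note): for "int x, *y;" this method is
+ # called with two entries in decls and returns two separate Decl nodes,
+ # both built from the shared IdentifierType(['int']) specifier.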
+
+ def _build_function_definition(self, spec, decl, param_decls, body):
+ """ Builds a function definition.
+ """
+ if 'typedef' in spec['storage']:
+ self._parse_error("Invalid typedef", decl.coord)
+
+ declaration = self._build_declarations(
+ spec=spec,
+ decls=[dict(decl=decl, init=None)],
+ typedef_namespace=True)[0]
+
+ return c_ast.FuncDef(
+ decl=declaration,
+ param_decls=param_decls,
+ body=body,
+ coord=decl.coord)
+
+ def _select_struct_union_class(self, token):
+ """ Given a token (either STRUCT or UNION), selects the
+ appropriate AST class.
+ """
+ if token == 'struct':
+ return c_ast.Struct
+ else:
+ return c_ast.Union
+
+ ##
+ ## Precedence and associativity of operators
+ ##
+ # If this changes, c_generator.CGenerator.precedence_map needs to change as
+ # well
+ precedence = (
+ ('left', 'LOR'),
+ ('left', 'LAND'),
+ ('left', 'OR'),
+ ('left', 'XOR'),
+ ('left', 'AND'),
+ ('left', 'EQ', 'NE'),
+ ('left', 'GT', 'GE', 'LT', 'LE'),
+ ('left', 'RSHIFT', 'LSHIFT'),
+ ('left', 'PLUS', 'MINUS'),
+ ('left', 'TIMES', 'DIVIDE', 'MOD')
+ )
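+
+ # Effect of the table above (illustrative, added): "a + b * c" reduces
+ # TIMES before PLUS, so the expression parses as
+ # BinaryOp('+', ID('a'), BinaryOp('*', ID('b'), ID('c'))).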
+
+ ##
+ ## Grammar productions
+ ## Implementation of the BNF defined in K&R2 A.13
+ ##
+
+ # Wrapper around a translation unit, to allow for empty input.
+ # Not strictly part of the C99 Grammar, but useful in practice.
+ def p_translation_unit_or_empty(self, p):
+ """ translation_unit_or_empty : translation_unit
+ | empty
+ """
+ if p[1] is None:
+ p[0] = c_ast.FileAST([])
+ else:
+ p[0] = c_ast.FileAST(p[1])
+
+ def p_translation_unit_1(self, p):
+ """ translation_unit : external_declaration
+ """
+ # Note: external_declaration is already a list
+ p[0] = p[1]
+
+ def p_translation_unit_2(self, p):
+ """ translation_unit : translation_unit external_declaration
+ """
+ p[1].extend(p[2])
+ p[0] = p[1]
+
+ # Declarations always come as lists (because they can be
+ # several in one line), so we wrap the function definition
+ # into a list as well, to make the return value of
+ # external_declaration homogeneous.
+ def p_external_declaration_1(self, p):
+ """ external_declaration : function_definition
+ """
+ p[0] = [p[1]]
+
+ def p_external_declaration_2(self, p):
+ """ external_declaration : declaration
+ """
+ p[0] = p[1]
+
+ def p_external_declaration_3(self, p):
+ """ external_declaration : pp_directive
+ | pppragma_directive
+ """
+ p[0] = [p[1]]
+
+ def p_external_declaration_4(self, p):
+ """ external_declaration : SEMI
+ """
+ p[0] = []
+
+ def p_external_declaration_5(self, p):
+ """ external_declaration : static_assert
+ """
+ p[0] = p[1]
+
+ def p_static_assert_declaration(self, p):
+ """ static_assert : _STATIC_ASSERT LPAREN constant_expression COMMA unified_string_literal RPAREN
+ | _STATIC_ASSERT LPAREN constant_expression RPAREN
+ """
+ if len(p) == 5:
+ p[0] = [c_ast.StaticAssert(p[3], None, self._token_coord(p, 1))]
+ else:
+ p[0] = [c_ast.StaticAssert(p[3], p[5], self._token_coord(p, 1))]
+
+ def p_pp_directive(self, p):
+ """ pp_directive : PPHASH
+ """
+ self._parse_error('Directives not supported yet',
+ self._token_coord(p, 1))
+
+ def p_pppragma_directive(self, p):
+ """ pppragma_directive : PPPRAGMA
+ | PPPRAGMA PPPRAGMASTR
+ """
+ if len(p) == 3:
+ p[0] = c_ast.Pragma(p[2], self._token_coord(p, 2))
+ else:
+ p[0] = c_ast.Pragma("", self._token_coord(p, 1))
+
+ # In function definitions, the declarator can be followed by
+ # a declaration list, for old "K&R style" function definitions.
+ def p_function_definition_1(self, p):
+ """ function_definition : id_declarator declaration_list_opt compound_statement
+ """
+ # no declaration specifiers - 'int' becomes the default type
+ spec = dict(
+ qual=[],
+ alignment=[],
+ storage=[],
+ type=[c_ast.IdentifierType(['int'],
+ coord=self._token_coord(p, 1))],
+ function=[])
+
+ p[0] = self._build_function_definition(
+ spec=spec,
+ decl=p[1],
+ param_decls=p[2],
+ body=p[3])
+
+ def p_function_definition_2(self, p):
+ """ function_definition : declaration_specifiers id_declarator declaration_list_opt compound_statement
+ """
+ spec = p[1]
+
+ p[0] = self._build_function_definition(
+ spec=spec,
+ decl=p[2],
+ param_decls=p[3],
+ body=p[4])
+
+ # Note: according to C18 A.2.2 6.7.10 (static_assert-declaration), _Static_assert
+ # is a declaration, not a statement. We additionally recognise it as a statement
+ # to fix parsing of _Static_assert inside functions.
+ #
+ def p_statement(self, p):
+ """ statement : labeled_statement
+ | expression_statement
+ | compound_statement
+ | selection_statement
+ | iteration_statement
+ | jump_statement
+ | pppragma_directive
+ | static_assert
+ """
+ p[0] = p[1]
+
+ # A pragma is generally considered a decorator rather than an actual
+ # statement. Still, for the purposes of analyzing an abstract syntax tree of
+ # C code, pragmas should not be ignored and were previously treated as a
+ # statement. This presents a problem for constructs that take a statement
+ # such as labeled_statements, selection_statements, and
+ # iteration_statements, causing a misleading structure in the AST. For
+ # example, consider the following C code.
+ #
+ # for (int i = 0; i < 3; i++)
+ # #pragma omp critical
+ # sum += 1;
+ #
+ # This code will compile and execute "sum += 1;" as the body of the for
+ # loop. Previous implementations of PyCParser would render the AST for this
+ # block of code as follows:
+ #
+ # For:
+ # DeclList:
+ # Decl: i, [], [], []
+ # TypeDecl: i, []
+ # IdentifierType: ['int']
+ # Constant: int, 0
+ # BinaryOp: <
+ # ID: i
+ # Constant: int, 3
+ # UnaryOp: p++
+ # ID: i
+ # Pragma: omp critical
+ # Assignment: +=
+ # ID: sum
+ # Constant: int, 1
+ #
+ # This AST misleadingly takes the Pragma as the body of the loop and the
+ # assignment then becomes a sibling of the loop.
+ #
+ # To solve edge cases like these, the pragmacomp_or_statement rule groups
+ # a pragma and its following statement (which would otherwise be orphaned)
+ # using a compound block, effectively turning the above code into:
+ #
+ # for (int i = 0; i < 3; i++) {
+ # #pragma omp critical
+ # sum += 1;
+ # }
+ def p_pragmacomp_or_statement(self, p):
+ """ pragmacomp_or_statement : pppragma_directive statement
+ | statement
+ """
+ if isinstance(p[1], c_ast.Pragma) and len(p) == 3:
+ p[0] = c_ast.Compound(
+ block_items=[p[1], p[2]],
+ coord=self._token_coord(p, 1))
+ else:
+ p[0] = p[1]
+
+ # In C, declarations can come several in a line:
+ # int x, *px, romulo = 5;
+ #
+ # However, for the AST, we will split them to separate Decl
+ # nodes.
+ #
+ # This rule splits its declarations and always returns a list
+ # of Decl nodes, even if it's one element long.
+ #
+ def p_decl_body(self, p):
+ """ decl_body : declaration_specifiers init_declarator_list_opt
+ | declaration_specifiers_no_type id_init_declarator_list_opt
+ """
+ spec = p[1]
+
+ # p[2] (init_declarator_list_opt) is either a list or None
+ #
+ if p[2] is None:
+ # By the standard, you must have at least one declarator unless
+ # declaring a structure tag, a union tag, or the members of an
+ # enumeration.
+ #
+ ty = spec['type']
+ s_u_or_e = (c_ast.Struct, c_ast.Union, c_ast.Enum)
+ if len(ty) == 1 and isinstance(ty[0], s_u_or_e):
+ decls = [c_ast.Decl(
+ name=None,
+ quals=spec['qual'],
+ align=spec['alignment'],
+ storage=spec['storage'],
+ funcspec=spec['function'],
+ type=ty[0],
+ init=None,
+ bitsize=None,
+ coord=ty[0].coord)]
+
+ # However, this case can also occur on redeclared identifiers in
+ # an inner scope. The trouble is that the redeclared type's name
+ # gets grouped into declaration_specifiers; _build_declarations
+ # compensates for this.
+ #
+ else:
+ decls = self._build_declarations(
+ spec=spec,
+ decls=[dict(decl=None, init=None)],
+ typedef_namespace=True)
+
+ else:
+ decls = self._build_declarations(
+ spec=spec,
+ decls=p[2],
+ typedef_namespace=True)
+
+ p[0] = decls
+
+ # The declaration has been split to a decl_body sub-rule and
+ # SEMI, because having them in a single rule created a problem
+ # for defining typedefs.
+ #
+ # If a typedef line was directly followed by a line using the
+ # type defined with the typedef, the type would not be
+ # recognized. This is because to reduce the declaration rule,
+ # the parser's lookahead asked for the token after SEMI, which
+ # was the type from the next line, and the lexer had no chance
+ # to see the updated type symbol table.
+ #
+ # Splitting solves this problem, because after seeing SEMI,
+ # the parser reduces decl_body, which actually adds the new
+ # type into the table to be seen by the lexer before the next
+ # line is reached.
+ def p_declaration(self, p):
+ """ declaration : decl_body SEMI
+ """
+ p[0] = p[1]
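+
+ # Example of the problem being solved (illustrative, added): in
+ #   typedef int T;
+ #   T x;
+ # the lexer must already see T as a TYPEID when it tokenizes the second
+ # line; reducing decl_body at the first SEMI updates the scope table just
+ # in time.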
+
+ # Since each declaration is a list of declarations, this
+ # rule will combine all the declarations and return a single
+ # list
+ #
+ def p_declaration_list(self, p):
+ """ declaration_list : declaration
+ | declaration_list declaration
+ """
+ p[0] = p[1] if len(p) == 2 else p[1] + p[2]
+
+ # To know when declaration-specifiers end and declarators begin,
+ # we require declaration-specifiers to have at least one
+ # type-specifier, and disallow typedef-names after we've seen any
+ # type-specifier. These are both required by the spec.
+ #
+ def p_declaration_specifiers_no_type_1(self, p):
+ """ declaration_specifiers_no_type : type_qualifier declaration_specifiers_no_type_opt
+ """
+ p[0] = self._add_declaration_specifier(p[2], p[1], 'qual')
+
+ def p_declaration_specifiers_no_type_2(self, p):
+ """ declaration_specifiers_no_type : storage_class_specifier declaration_specifiers_no_type_opt
+ """
+ p[0] = self._add_declaration_specifier(p[2], p[1], 'storage')
+
+ def p_declaration_specifiers_no_type_3(self, p):
+ """ declaration_specifiers_no_type : function_specifier declaration_specifiers_no_type_opt
+ """
+ p[0] = self._add_declaration_specifier(p[2], p[1], 'function')
+
+ # Without this, `typedef _Atomic(T) U` will parse incorrectly because the
+ # _Atomic qualifier will match, instead of the specifier.
+ def p_declaration_specifiers_no_type_4(self, p):
+ """ declaration_specifiers_no_type : atomic_specifier declaration_specifiers_no_type_opt
+ """
+ p[0] = self._add_declaration_specifier(p[2], p[1], 'type')
+
+ def p_declaration_specifiers_no_type_5(self, p):
+ """ declaration_specifiers_no_type : alignment_specifier declaration_specifiers_no_type_opt
+ """
+ p[0] = self._add_declaration_specifier(p[2], p[1], 'alignment')
+
+ def p_declaration_specifiers_1(self, p):
+ """ declaration_specifiers : declaration_specifiers type_qualifier
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'qual', append=True)
+
+ def p_declaration_specifiers_2(self, p):
+ """ declaration_specifiers : declaration_specifiers storage_class_specifier
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'storage', append=True)
+
+ def p_declaration_specifiers_3(self, p):
+ """ declaration_specifiers : declaration_specifiers function_specifier
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'function', append=True)
+
+ def p_declaration_specifiers_4(self, p):
+ """ declaration_specifiers : declaration_specifiers type_specifier_no_typeid
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'type', append=True)
+
+ def p_declaration_specifiers_5(self, p):
+ """ declaration_specifiers : type_specifier
+ """
+ p[0] = self._add_declaration_specifier(None, p[1], 'type')
+
+ def p_declaration_specifiers_6(self, p):
+ """ declaration_specifiers : declaration_specifiers_no_type type_specifier
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'type', append=True)
+
+ def p_declaration_specifiers_7(self, p):
+ """ declaration_specifiers : declaration_specifiers alignment_specifier
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'alignment', append=True)
+
+ def p_storage_class_specifier(self, p):
+ """ storage_class_specifier : AUTO
+ | REGISTER
+ | STATIC
+ | EXTERN
+ | TYPEDEF
+ | _THREAD_LOCAL
+ """
+ p[0] = p[1]
+
+ def p_function_specifier(self, p):
+ """ function_specifier : INLINE
+ | _NORETURN
+ """
+ p[0] = p[1]
+
+ def p_type_specifier_no_typeid(self, p):
+ """ type_specifier_no_typeid : VOID
+ | _BOOL
+ | CHAR
+ | SHORT
+ | INT
+ | LONG
+ | FLOAT
+ | DOUBLE
+ | _COMPLEX
+ | SIGNED
+ | UNSIGNED
+ | __INT128
+ """
+ p[0] = c_ast.IdentifierType([p[1]], coord=self._token_coord(p, 1))
+
+ def p_type_specifier(self, p):
+ """ type_specifier : typedef_name
+ | enum_specifier
+ | struct_or_union_specifier
+ | type_specifier_no_typeid
+ | atomic_specifier
+ """
+ p[0] = p[1]
+
+ # See section 6.7.2.4 of the C11 standard.
+ def p_atomic_specifier(self, p):
+ """ atomic_specifier : _ATOMIC LPAREN type_name RPAREN
+ """
+ typ = p[3]
+ typ.quals.append('_Atomic')
+ p[0] = typ
+
+ def p_type_qualifier(self, p):
+ """ type_qualifier : CONST
+ | RESTRICT
+ | VOLATILE
+ | _ATOMIC
+ """
+ p[0] = p[1]
+
+ def p_init_declarator_list(self, p):
+ """ init_declarator_list : init_declarator
+ | init_declarator_list COMMA init_declarator
+ """
+ p[0] = p[1] + [p[3]] if len(p) == 4 else [p[1]]
+
+ # Returns a {decl=<declarator> : init=<initializer>} dictionary
+ # If there's no initializer, uses None
+ #
+ def p_init_declarator(self, p):
+ """ init_declarator : declarator
+ | declarator EQUALS initializer
+ """
+ p[0] = dict(decl=p[1], init=(p[3] if len(p) > 2 else None))
+
+ def p_id_init_declarator_list(self, p):
+ """ id_init_declarator_list : id_init_declarator
+ | id_init_declarator_list COMMA init_declarator
+ """
+ p[0] = p[1] + [p[3]] if len(p) == 4 else [p[1]]
+
+ def p_id_init_declarator(self, p):
+ """ id_init_declarator : id_declarator
+ | id_declarator EQUALS initializer
+ """
+ p[0] = dict(decl=p[1], init=(p[3] if len(p) > 2 else None))
+
+ # Require at least one type specifier in a specifier-qualifier-list
+ #
+ def p_specifier_qualifier_list_1(self, p):
+ """ specifier_qualifier_list : specifier_qualifier_list type_specifier_no_typeid
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'type', append=True)
+
+ def p_specifier_qualifier_list_2(self, p):
+ """ specifier_qualifier_list : specifier_qualifier_list type_qualifier
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'qual', append=True)
+
+ def p_specifier_qualifier_list_3(self, p):
+ """ specifier_qualifier_list : type_specifier
+ """
+ p[0] = self._add_declaration_specifier(None, p[1], 'type')
+
+ def p_specifier_qualifier_list_4(self, p):
+ """ specifier_qualifier_list : type_qualifier_list type_specifier
+ """
+ p[0] = dict(qual=p[1], alignment=[], storage=[], type=[p[2]], function=[])
+
+ def p_specifier_qualifier_list_5(self, p):
+ """ specifier_qualifier_list : alignment_specifier
+ """
+ p[0] = dict(qual=[], alignment=[p[1]], storage=[], type=[], function=[])
+
+ def p_specifier_qualifier_list_6(self, p):
+ """ specifier_qualifier_list : specifier_qualifier_list alignment_specifier
+ """
+ p[0] = self._add_declaration_specifier(p[1], p[2], 'alignment')
+
+ # TYPEID is allowed here (and in other struct/enum related tag names), because
+ # struct/enum tags reside in their own namespace and can be named the same as types
+ #
+ def p_struct_or_union_specifier_1(self, p):
+ """ struct_or_union_specifier : struct_or_union ID
+ | struct_or_union TYPEID
+ """
+ klass = self._select_struct_union_class(p[1])
+ # None means no list of members
+ p[0] = klass(
+ name=p[2],
+ decls=None,
+ coord=self._token_coord(p, 2))
+
+ def p_struct_or_union_specifier_2(self, p):
+ """ struct_or_union_specifier : struct_or_union brace_open struct_declaration_list brace_close
+ | struct_or_union brace_open brace_close
+ """
+ klass = self._select_struct_union_class(p[1])
+ if len(p) == 4:
+ # Empty sequence means an empty list of members
+ p[0] = klass(
+ name=None,
+ decls=[],
+ coord=self._token_coord(p, 2))
+ else:
+ p[0] = klass(
+ name=None,
+ decls=p[3],
+ coord=self._token_coord(p, 2))
+
+
+ def p_struct_or_union_specifier_3(self, p):
+ """ struct_or_union_specifier : struct_or_union ID brace_open struct_declaration_list brace_close
+ | struct_or_union ID brace_open brace_close
+ | struct_or_union TYPEID brace_open struct_declaration_list brace_close
+ | struct_or_union TYPEID brace_open brace_close
+ """
+ klass = self._select_struct_union_class(p[1])
+ if len(p) == 5:
+ # Empty sequence means an empty list of members
+ p[0] = klass(
+ name=p[2],
+ decls=[],
+ coord=self._token_coord(p, 2))
+ else:
+ p[0] = klass(
+ name=p[2],
+ decls=p[4],
+ coord=self._token_coord(p, 2))
+
+ def p_struct_or_union(self, p):
+ """ struct_or_union : STRUCT
+ | UNION
+ """
+ p[0] = p[1]
+
+ # Combine all declarations into a single list
+ #
+ def p_struct_declaration_list(self, p):
+ """ struct_declaration_list : struct_declaration
+ | struct_declaration_list struct_declaration
+ """
+ if len(p) == 2:
+ p[0] = p[1] or []
+ else:
+ p[0] = p[1] + (p[2] or [])
+
+ def p_struct_declaration_1(self, p):
+ """ struct_declaration : specifier_qualifier_list struct_declarator_list_opt SEMI
+ """
+ spec = p[1]
+ assert 'typedef' not in spec['storage']
+
+ if p[2] is not None:
+ decls = self._build_declarations(
+ spec=spec,
+ decls=p[2])
+
+ elif len(spec['type']) == 1:
+ # Anonymous struct/union, gcc extension, C1x feature.
+ # Although the standard only allows structs/unions here, I see no
+ # reason to disallow other types since some compilers have typedefs
+ # here, and pycparser isn't about rejecting all invalid code.
+ #
+ node = spec['type'][0]
+ if isinstance(node, c_ast.Node):
+ decl_type = node
+ else:
+ decl_type = c_ast.IdentifierType(node)
+
+ decls = self._build_declarations(
+ spec=spec,
+ decls=[dict(decl=decl_type)])
+
+ else:
+ # Structure/union members can have the same names as typedefs.
+ # The trouble is that the member's name gets grouped into
+ # specifier_qualifier_list; _build_declarations compensates.
+ #
+ decls = self._build_declarations(
+ spec=spec,
+ decls=[dict(decl=None, init=None)])
+
+ p[0] = decls
+
+ def p_struct_declaration_2(self, p):
+ """ struct_declaration : SEMI
+ """
+ p[0] = None
+
+ def p_struct_declaration_3(self, p):
+ """ struct_declaration : pppragma_directive
+ """
+ p[0] = [p[1]]
+
+ def p_struct_declarator_list(self, p):
+ """ struct_declarator_list : struct_declarator
+ | struct_declarator_list COMMA struct_declarator
+ """
+ p[0] = p[1] + [p[3]] if len(p) == 4 else [p[1]]
+
+ # struct_declarator passes up a dict with the keys: decl (for
+ # the underlying declarator) and bitsize (for the bitsize)
+ #
+ def p_struct_declarator_1(self, p):
+ """ struct_declarator : declarator
+ """
+ p[0] = {'decl': p[1], 'bitsize': None}
+
+ def p_struct_declarator_2(self, p):
+ """ struct_declarator : declarator COLON constant_expression
+ | COLON constant_expression
+ """
+ if len(p) > 3:
+ p[0] = {'decl': p[1], 'bitsize': p[3]}
+ else:
+ p[0] = {'decl': c_ast.TypeDecl(None, None, None, None), 'bitsize': p[2]}
+
+ def p_enum_specifier_1(self, p):
+ """ enum_specifier : ENUM ID
+ | ENUM TYPEID
+ """
+ p[0] = c_ast.Enum(p[2], None, self._token_coord(p, 1))
+
+ def p_enum_specifier_2(self, p):
+ """ enum_specifier : ENUM brace_open enumerator_list brace_close
+ """
+ p[0] = c_ast.Enum(None, p[3], self._token_coord(p, 1))
+
+ def p_enum_specifier_3(self, p):
+ """ enum_specifier : ENUM ID brace_open enumerator_list brace_close
+ | ENUM TYPEID brace_open enumerator_list brace_close
+ """
+ p[0] = c_ast.Enum(p[2], p[4], self._token_coord(p, 1))
+
+ def p_enumerator_list(self, p):
+ """ enumerator_list : enumerator
+ | enumerator_list COMMA
+ | enumerator_list COMMA enumerator
+ """
+ if len(p) == 2:
+ p[0] = c_ast.EnumeratorList([p[1]], p[1].coord)
+ elif len(p) == 3:
+ p[0] = p[1]
+ else:
+ p[1].enumerators.append(p[3])
+ p[0] = p[1]
+
+ def p_alignment_specifier(self, p):
+ """ alignment_specifier : _ALIGNAS LPAREN type_name RPAREN
+ | _ALIGNAS LPAREN constant_expression RPAREN
+ """
+ p[0] = c_ast.Alignas(p[3], self._token_coord(p, 1))
+
+ def p_enumerator(self, p):
+ """ enumerator : ID
+ | ID EQUALS constant_expression
+ """
+ if len(p) == 2:
+ enumerator = c_ast.Enumerator(
+ p[1], None,
+ self._token_coord(p, 1))
+ else:
+ enumerator = c_ast.Enumerator(
+ p[1], p[3],
+ self._token_coord(p, 1))
+ self._add_identifier(enumerator.name, enumerator.coord)
+
+ p[0] = enumerator
+
+ def p_declarator(self, p):
+ """ declarator : id_declarator
+ | typeid_declarator
+ """
+ p[0] = p[1]
+
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID'))
+ def p_xxx_declarator_1(self, p):
+ """ xxx_declarator : direct_xxx_declarator
+ """
+ p[0] = p[1]
+
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID'))
+ def p_xxx_declarator_2(self, p):
+ """ xxx_declarator : pointer direct_xxx_declarator
+ """
+ p[0] = self._type_modify_decl(p[2], p[1])
+
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID'))
+ def p_direct_xxx_declarator_1(self, p):
+ """ direct_xxx_declarator : yyy
+ """
+ p[0] = c_ast.TypeDecl(
+ declname=p[1],
+ type=None,
+ quals=None,
+ align=None,
+ coord=self._token_coord(p, 1))
+
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'))
+ def p_direct_xxx_declarator_2(self, p):
+ """ direct_xxx_declarator : LPAREN xxx_declarator RPAREN
+ """
+ p[0] = p[2]
+
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID'))
+ def p_direct_xxx_declarator_3(self, p):
+ """ direct_xxx_declarator : direct_xxx_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET
+ """
+ quals = (p[3] if len(p) > 5 else []) or []
+ # Accept dimension qualifiers
+ # Per C99 6.7.5.3 p7
+ arr = c_ast.ArrayDecl(
+ type=None,
+ dim=p[4] if len(p) > 5 else p[3],
+ dim_quals=quals,
+ coord=p[1].coord)
+
+ p[0] = self._type_modify_decl(decl=p[1], modifier=arr)
+
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID'))
+ def p_direct_xxx_declarator_4(self, p):
+ """ direct_xxx_declarator : direct_xxx_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET
+ | direct_xxx_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET
+ """
+ # Using slice notation for PLY objects doesn't work in Python 3 for the
+ # version of PLY embedded with pycparser; see PLY Google Code issue 30.
+ # Work around that here by listing the two elements separately.
+ listed_quals = [item if isinstance(item, list) else [item]
+ for item in [p[3],p[4]]]
+ dim_quals = [qual for sublist in listed_quals for qual in sublist
+ if qual is not None]
+ arr = c_ast.ArrayDecl(
+ type=None,
+ dim=p[5],
+ dim_quals=dim_quals,
+ coord=p[1].coord)
+
+ p[0] = self._type_modify_decl(decl=p[1], modifier=arr)
+
+ # Special for VLAs
+ #
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID'))
+ def p_direct_xxx_declarator_5(self, p):
+ """ direct_xxx_declarator : direct_xxx_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET
+ """
+ arr = c_ast.ArrayDecl(
+ type=None,
+ dim=c_ast.ID(p[4], self._token_coord(p, 4)),
+ dim_quals=p[3] if p[3] is not None else [],
+ coord=p[1].coord)
+
+ p[0] = self._type_modify_decl(decl=p[1], modifier=arr)
+
+ @parameterized(('id', 'ID'), ('typeid', 'TYPEID'), ('typeid_noparen', 'TYPEID'))
+ def p_direct_xxx_declarator_6(self, p):
+ """ direct_xxx_declarator : direct_xxx_declarator LPAREN parameter_type_list RPAREN
+ | direct_xxx_declarator LPAREN identifier_list_opt RPAREN
+ """
+ func = c_ast.FuncDecl(
+ args=p[3],
+ type=None,
+ coord=p[1].coord)
+
+ # To see why _get_yacc_lookahead_token is needed, consider:
+ # typedef char TT;
+ # void foo(int TT) { TT = 10; }
+ # Outside the function, TT is a typedef, but inside (starting and
+ # ending with the braces) it's a parameter. The trouble begins with
+ # yacc's lookahead token. We don't know if we're declaring or
+ # defining a function until we see LBRACE, but if we wait for yacc to
+ # trigger a rule on that token, then TT will have already been read
+ # and incorrectly interpreted as TYPEID. We need to add the
+ # parameters to the scope the moment the lexer sees LBRACE.
+ #
+ if self._get_yacc_lookahead_token().type == "LBRACE":
+ if func.args is not None:
+ for param in func.args.params:
+ if isinstance(param, c_ast.EllipsisParam): break
+ self._add_identifier(param.name, param.coord)
+
+ p[0] = self._type_modify_decl(decl=p[1], modifier=func)
+
+ def p_pointer(self, p):
+ """ pointer : TIMES type_qualifier_list_opt
+ | TIMES type_qualifier_list_opt pointer
+ """
+ coord = self._token_coord(p, 1)
+ # Pointer decls nest from inside out. This is important when different
+ # levels have different qualifiers. For example:
+ #
+ # char * const * p;
+ #
+ # Means "pointer to const pointer to char"
+ #
+ # While:
+ #
+ # char ** const p;
+ #
+ # Means "const pointer to pointer to char"
+ #
+ # So when we construct PtrDecl nestings, the leftmost pointer goes in
+ # as the most nested type.
+ nested_type = c_ast.PtrDecl(quals=p[2] or [], type=None, coord=coord)
+ if len(p) > 3:
+ tail_type = p[3]
+ while tail_type.type is not None:
+ tail_type = tail_type.type
+ tail_type.type = nested_type
+ p[0] = p[3]
+ else:
+ p[0] = nested_type
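+
+ # Resulting nesting for "char * const * p" (illustrative, added note):
+ #   PtrDecl(quals=[], type=PtrDecl(quals=['const'], type=TypeDecl('p')))
+ # i.e. p is a (non-const) pointer to a const pointer to char.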
+
+ def p_type_qualifier_list(self, p):
+ """ type_qualifier_list : type_qualifier
+ | type_qualifier_list type_qualifier
+ """
+ p[0] = [p[1]] if len(p) == 2 else p[1] + [p[2]]
+
+ def p_parameter_type_list(self, p):
+ """ parameter_type_list : parameter_list
+ | parameter_list COMMA ELLIPSIS
+ """
+ if len(p) > 2:
+ p[1].params.append(c_ast.EllipsisParam(self._token_coord(p, 3)))
+
+ p[0] = p[1]
+
+ def p_parameter_list(self, p):
+ """ parameter_list : parameter_declaration
+ | parameter_list COMMA parameter_declaration
+ """
+ if len(p) == 2: # single parameter
+ p[0] = c_ast.ParamList([p[1]], p[1].coord)
+ else:
+ p[1].params.append(p[3])
+ p[0] = p[1]
+
+ # From ISO/IEC 9899:TC2, 6.7.5.3.11:
+ # "If, in a parameter declaration, an identifier can be treated either
+ # as a typedef name or as a parameter name, it shall be taken as a
+ # typedef name."
+ #
+ # Inside a parameter declaration, once we've reduced declaration specifiers,
+ # if we shift in an LPAREN and see a TYPEID, it could be either an abstract
+ # declarator or a declarator nested inside parens. This rule tells us to
+ # always treat it as an abstract declarator. Therefore, we only accept
+ # `id_declarator`s and `typeid_noparen_declarator`s.
+ def p_parameter_declaration_1(self, p):
+ """ parameter_declaration : declaration_specifiers id_declarator
+ | declaration_specifiers typeid_noparen_declarator
+ """
+ spec = p[1]
+ if not spec['type']:
+ spec['type'] = [c_ast.IdentifierType(['int'],
+ coord=self._token_coord(p, 1))]
+ p[0] = self._build_declarations(
+ spec=spec,
+ decls=[dict(decl=p[2])])[0]
+
+ def p_parameter_declaration_2(self, p):
+ """ parameter_declaration : declaration_specifiers abstract_declarator_opt
+ """
+ spec = p[1]
+ if not spec['type']:
+ spec['type'] = [c_ast.IdentifierType(['int'],
+ coord=self._token_coord(p, 1))]
+
+ # Parameters can have the same names as typedefs. The trouble is that
+ # the parameter's name gets grouped into declaration_specifiers, making
+ # it look like an old-style declaration; compensate.
+ #
+ if len(spec['type']) > 1 and len(spec['type'][-1].names) == 1 and \
+ self._is_type_in_scope(spec['type'][-1].names[0]):
+ decl = self._build_declarations(
+ spec=spec,
+ decls=[dict(decl=p[2], init=None)])[0]
+
+ # This truly is an old-style parameter declaration
+ #
+ else:
+ decl = c_ast.Typename(
+ name='',
+ quals=spec['qual'],
+ align=None,
+ type=p[2] or c_ast.TypeDecl(None, None, None, None),
+ coord=self._token_coord(p, 2))
+ typename = spec['type']
+ decl = self._fix_decl_name_type(decl, typename)
+
+ p[0] = decl
+
+ def p_identifier_list(self, p):
+ """ identifier_list : identifier
+ | identifier_list COMMA identifier
+ """
+ if len(p) == 2: # single parameter
+ p[0] = c_ast.ParamList([p[1]], p[1].coord)
+ else:
+ p[1].params.append(p[3])
+ p[0] = p[1]
+
+ def p_initializer_1(self, p):
+ """ initializer : assignment_expression
+ """
+ p[0] = p[1]
+
+ def p_initializer_2(self, p):
+ """ initializer : brace_open initializer_list_opt brace_close
+ | brace_open initializer_list COMMA brace_close
+ """
+ if p[2] is None:
+ p[0] = c_ast.InitList([], self._token_coord(p, 1))
+ else:
+ p[0] = p[2]
+
+ def p_initializer_list(self, p):
+ """ initializer_list : designation_opt initializer
+ | initializer_list COMMA designation_opt initializer
+ """
+ if len(p) == 3: # single initializer
+ init = p[2] if p[1] is None else c_ast.NamedInitializer(p[1], p[2])
+ p[0] = c_ast.InitList([init], p[2].coord)
+ else:
+ init = p[4] if p[3] is None else c_ast.NamedInitializer(p[3], p[4])
+ p[1].exprs.append(init)
+ p[0] = p[1]
+
+ def p_designation(self, p):
+ """ designation : designator_list EQUALS
+ """
+ p[0] = p[1]
+
+ # Designators are represented as a list of nodes, in the order in which
+ # they're written in the code.
+ #
+ def p_designator_list(self, p):
+ """ designator_list : designator
+ | designator_list designator
+ """
+ p[0] = [p[1]] if len(p) == 2 else p[1] + [p[2]]
+
+ def p_designator(self, p):
+ """ designator : LBRACKET constant_expression RBRACKET
+ | PERIOD identifier
+ """
+ p[0] = p[2]
+
+ def p_type_name(self, p):
+ """ type_name : specifier_qualifier_list abstract_declarator_opt
+ """
+ typename = c_ast.Typename(
+ name='',
+ quals=p[1]['qual'][:],
+ align=None,
+ type=p[2] or c_ast.TypeDecl(None, None, None, None),
+ coord=self._token_coord(p, 2))
+
+ p[0] = self._fix_decl_name_type(typename, p[1]['type'])
+
+ def p_abstract_declarator_1(self, p):
+ """ abstract_declarator : pointer
+ """
+ dummytype = c_ast.TypeDecl(None, None, None, None)
+ p[0] = self._type_modify_decl(
+ decl=dummytype,
+ modifier=p[1])
+
+ def p_abstract_declarator_2(self, p):
+ """ abstract_declarator : pointer direct_abstract_declarator
+ """
+ p[0] = self._type_modify_decl(p[2], p[1])
+
+ def p_abstract_declarator_3(self, p):
+ """ abstract_declarator : direct_abstract_declarator
+ """
+ p[0] = p[1]
+
+ # Creating and using direct_abstract_declarator_opt here
+ # instead of listing both direct_abstract_declarator and the
+ # lack of it in the beginning of _1 and _2 caused two
+ # shift/reduce errors.
+ #
+ def p_direct_abstract_declarator_1(self, p):
+ """ direct_abstract_declarator : LPAREN abstract_declarator RPAREN """
+ p[0] = p[2]
+
+ def p_direct_abstract_declarator_2(self, p):
+ """ direct_abstract_declarator : direct_abstract_declarator LBRACKET assignment_expression_opt RBRACKET
+ """
+ arr = c_ast.ArrayDecl(
+ type=None,
+ dim=p[3],
+ dim_quals=[],
+ coord=p[1].coord)
+
+ p[0] = self._type_modify_decl(decl=p[1], modifier=arr)
+
+ def p_direct_abstract_declarator_3(self, p):
+ """ direct_abstract_declarator : LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET
+ """
+ quals = (p[2] if len(p) > 4 else []) or []
+ p[0] = c_ast.ArrayDecl(
+ type=c_ast.TypeDecl(None, None, None, None),
+ dim=p[3] if len(p) > 4 else p[2],
+ dim_quals=quals,
+ coord=self._token_coord(p, 1))
+
+ def p_direct_abstract_declarator_4(self, p):
+ """ direct_abstract_declarator : direct_abstract_declarator LBRACKET TIMES RBRACKET
+ """
+ arr = c_ast.ArrayDecl(
+ type=None,
+ dim=c_ast.ID(p[3], self._token_coord(p, 3)),
+ dim_quals=[],
+ coord=p[1].coord)
+
+ p[0] = self._type_modify_decl(decl=p[1], modifier=arr)
+
+ def p_direct_abstract_declarator_5(self, p):
+ """ direct_abstract_declarator : LBRACKET TIMES RBRACKET
+ """
+ p[0] = c_ast.ArrayDecl(
+ type=c_ast.TypeDecl(None, None, None, None),
+ dim=c_ast.ID(p[3], self._token_coord(p, 3)),
+ dim_quals=[],
+ coord=self._token_coord(p, 1))
+
+ def p_direct_abstract_declarator_6(self, p):
+ """ direct_abstract_declarator : direct_abstract_declarator LPAREN parameter_type_list_opt RPAREN
+ """
+ func = c_ast.FuncDecl(
+ args=p[3],
+ type=None,
+ coord=p[1].coord)
+
+ p[0] = self._type_modify_decl(decl=p[1], modifier=func)
+
+ def p_direct_abstract_declarator_7(self, p):
+ """ direct_abstract_declarator : LPAREN parameter_type_list_opt RPAREN
+ """
+ p[0] = c_ast.FuncDecl(
+ args=p[2],
+ type=c_ast.TypeDecl(None, None, None, None),
+ coord=self._token_coord(p, 1))
+
+ # declaration is a list, statement isn't. To make it consistent, block_item
+ # will always be a list
+ #
+ def p_block_item(self, p):
+ """ block_item : declaration
+ | statement
+ """
+ p[0] = p[1] if isinstance(p[1], list) else [p[1]]
+
+ # Since we made block_item a list, this just combines lists
+ #
+ def p_block_item_list(self, p):
+ """ block_item_list : block_item
+ | block_item_list block_item
+ """
+ # Empty block items (plain ';') produce [None], so ignore them
+ p[0] = p[1] if (len(p) == 2 or p[2] == [None]) else p[1] + p[2]
+
+ def p_compound_statement_1(self, p):
+ """ compound_statement : brace_open block_item_list_opt brace_close """
+ p[0] = c_ast.Compound(
+ block_items=p[2],
+ coord=self._token_coord(p, 1))
+
+ def p_labeled_statement_1(self, p):
+ """ labeled_statement : ID COLON pragmacomp_or_statement """
+ p[0] = c_ast.Label(p[1], p[3], self._token_coord(p, 1))
+
+ def p_labeled_statement_2(self, p):
+ """ labeled_statement : CASE constant_expression COLON pragmacomp_or_statement """
+ p[0] = c_ast.Case(p[2], [p[4]], self._token_coord(p, 1))
+
+ def p_labeled_statement_3(self, p):
+ """ labeled_statement : DEFAULT COLON pragmacomp_or_statement """
+ p[0] = c_ast.Default([p[3]], self._token_coord(p, 1))
+
+ def p_selection_statement_1(self, p):
+ """ selection_statement : IF LPAREN expression RPAREN pragmacomp_or_statement """
+ p[0] = c_ast.If(p[3], p[5], None, self._token_coord(p, 1))
+
+ def p_selection_statement_2(self, p):
+ """ selection_statement : IF LPAREN expression RPAREN statement ELSE pragmacomp_or_statement """
+ p[0] = c_ast.If(p[3], p[5], p[7], self._token_coord(p, 1))
+
+ def p_selection_statement_3(self, p):
+ """ selection_statement : SWITCH LPAREN expression RPAREN pragmacomp_or_statement """
+ p[0] = fix_switch_cases(
+ c_ast.Switch(p[3], p[5], self._token_coord(p, 1)))
+
+ def p_iteration_statement_1(self, p):
+ """ iteration_statement : WHILE LPAREN expression RPAREN pragmacomp_or_statement """
+ p[0] = c_ast.While(p[3], p[5], self._token_coord(p, 1))
+
+ def p_iteration_statement_2(self, p):
+ """ iteration_statement : DO pragmacomp_or_statement WHILE LPAREN expression RPAREN SEMI """
+ p[0] = c_ast.DoWhile(p[5], p[2], self._token_coord(p, 1))
+
+ def p_iteration_statement_3(self, p):
+ """ iteration_statement : FOR LPAREN expression_opt SEMI expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement """
+ p[0] = c_ast.For(p[3], p[5], p[7], p[9], self._token_coord(p, 1))
+
+ def p_iteration_statement_4(self, p):
+ """ iteration_statement : FOR LPAREN declaration expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement """
+ p[0] = c_ast.For(c_ast.DeclList(p[3], self._token_coord(p, 1)),
+ p[4], p[6], p[8], self._token_coord(p, 1))
+
+ def p_jump_statement_1(self, p):
+ """ jump_statement : GOTO ID SEMI """
+ p[0] = c_ast.Goto(p[2], self._token_coord(p, 1))
+
+ def p_jump_statement_2(self, p):
+ """ jump_statement : BREAK SEMI """
+ p[0] = c_ast.Break(self._token_coord(p, 1))
+
+ def p_jump_statement_3(self, p):
+ """ jump_statement : CONTINUE SEMI """
+ p[0] = c_ast.Continue(self._token_coord(p, 1))
+
+ def p_jump_statement_4(self, p):
+ """ jump_statement : RETURN expression SEMI
+ | RETURN SEMI
+ """
+ p[0] = c_ast.Return(p[2] if len(p) == 4 else None, self._token_coord(p, 1))
+
+ def p_expression_statement(self, p):
+ """ expression_statement : expression_opt SEMI """
+ if p[1] is None:
+ p[0] = c_ast.EmptyStatement(self._token_coord(p, 2))
+ else:
+ p[0] = p[1]
+
+ def p_expression(self, p):
+ """ expression : assignment_expression
+ | expression COMMA assignment_expression
+ """
+ if len(p) == 2:
+ p[0] = p[1]
+ else:
+ if not isinstance(p[1], c_ast.ExprList):
+ p[1] = c_ast.ExprList([p[1]], p[1].coord)
+
+ p[1].exprs.append(p[3])
+ p[0] = p[1]
+
+ def p_parenthesized_compound_expression(self, p):
+ """ assignment_expression : LPAREN compound_statement RPAREN """
+ p[0] = p[2]
+
+ def p_typedef_name(self, p):
+ """ typedef_name : TYPEID """
+ p[0] = c_ast.IdentifierType([p[1]], coord=self._token_coord(p, 1))
+
+ def p_assignment_expression(self, p):
+ """ assignment_expression : conditional_expression
+ | unary_expression assignment_operator assignment_expression
+ """
+ if len(p) == 2:
+ p[0] = p[1]
+ else:
+ p[0] = c_ast.Assignment(p[2], p[1], p[3], p[1].coord)
+
+ # K&R2 defines these as many separate rules, to encode
+ # precedence and associativity. Why work hard? I'll just use
+ # the built-in precedence/associativity specification feature
+ # of PLY. (see precedence declaration above)
+ #
+ def p_assignment_operator(self, p):
+ """ assignment_operator : EQUALS
+ | XOREQUAL
+ | TIMESEQUAL
+ | DIVEQUAL
+ | MODEQUAL
+ | PLUSEQUAL
+ | MINUSEQUAL
+ | LSHIFTEQUAL
+ | RSHIFTEQUAL
+ | ANDEQUAL
+ | OREQUAL
+ """
+ p[0] = p[1]
+
+ def p_constant_expression(self, p):
+ """ constant_expression : conditional_expression """
+ p[0] = p[1]
+
+ def p_conditional_expression(self, p):
+ """ conditional_expression : binary_expression
+ | binary_expression CONDOP expression COLON conditional_expression
+ """
+ if len(p) == 2:
+ p[0] = p[1]
+ else:
+ p[0] = c_ast.TernaryOp(p[1], p[3], p[5], p[1].coord)
+
+ def p_binary_expression(self, p):
+ """ binary_expression : cast_expression
+ | binary_expression TIMES binary_expression
+ | binary_expression DIVIDE binary_expression
+ | binary_expression MOD binary_expression
+ | binary_expression PLUS binary_expression
+ | binary_expression MINUS binary_expression
+ | binary_expression RSHIFT binary_expression
+ | binary_expression LSHIFT binary_expression
+ | binary_expression LT binary_expression
+ | binary_expression LE binary_expression
+ | binary_expression GE binary_expression
+ | binary_expression GT binary_expression
+ | binary_expression EQ binary_expression
+ | binary_expression NE binary_expression
+ | binary_expression AND binary_expression
+ | binary_expression OR binary_expression
+ | binary_expression XOR binary_expression
+ | binary_expression LAND binary_expression
+ | binary_expression LOR binary_expression
+ """
+ if len(p) == 2:
+ p[0] = p[1]
+ else:
+ p[0] = c_ast.BinaryOp(p[2], p[1], p[3], p[1].coord)
+
+ def p_cast_expression_1(self, p):
+ """ cast_expression : unary_expression """
+ p[0] = p[1]
+
+ def p_cast_expression_2(self, p):
+ """ cast_expression : LPAREN type_name RPAREN cast_expression """
+ p[0] = c_ast.Cast(p[2], p[4], self._token_coord(p, 1))
+
+ def p_unary_expression_1(self, p):
+ """ unary_expression : postfix_expression """
+ p[0] = p[1]
+
+ def p_unary_expression_2(self, p):
+ """ unary_expression : PLUSPLUS unary_expression
+ | MINUSMINUS unary_expression
+ | unary_operator cast_expression
+ """
+ p[0] = c_ast.UnaryOp(p[1], p[2], p[2].coord)
+
+ def p_unary_expression_3(self, p):
+ """ unary_expression : SIZEOF unary_expression
+ | SIZEOF LPAREN type_name RPAREN
+ | _ALIGNOF LPAREN type_name RPAREN
+ """
+ p[0] = c_ast.UnaryOp(
+ p[1],
+ p[2] if len(p) == 3 else p[3],
+ self._token_coord(p, 1))
+
+ def p_unary_operator(self, p):
+ """ unary_operator : AND
+ | TIMES
+ | PLUS
+ | MINUS
+ | NOT
+ | LNOT
+ """
+ p[0] = p[1]
+
+ def p_postfix_expression_1(self, p):
+ """ postfix_expression : primary_expression """
+ p[0] = p[1]
+
+ def p_postfix_expression_2(self, p):
+ """ postfix_expression : postfix_expression LBRACKET expression RBRACKET """
+ p[0] = c_ast.ArrayRef(p[1], p[3], p[1].coord)
+
+ def p_postfix_expression_3(self, p):
+ """ postfix_expression : postfix_expression LPAREN argument_expression_list RPAREN
+ | postfix_expression LPAREN RPAREN
+ """
+ p[0] = c_ast.FuncCall(p[1], p[3] if len(p) == 5 else None, p[1].coord)
+
+ def p_postfix_expression_4(self, p):
+ """ postfix_expression : postfix_expression PERIOD ID
+ | postfix_expression PERIOD TYPEID
+ | postfix_expression ARROW ID
+ | postfix_expression ARROW TYPEID
+ """
+ field = c_ast.ID(p[3], self._token_coord(p, 3))
+ p[0] = c_ast.StructRef(p[1], p[2], field, p[1].coord)
+
+ def p_postfix_expression_5(self, p):
+ """ postfix_expression : postfix_expression PLUSPLUS
+ | postfix_expression MINUSMINUS
+ """
+ p[0] = c_ast.UnaryOp('p' + p[2], p[1], p[1].coord)
+
+ def p_postfix_expression_6(self, p):
+ """ postfix_expression : LPAREN type_name RPAREN brace_open initializer_list brace_close
+ | LPAREN type_name RPAREN brace_open initializer_list COMMA brace_close
+ """
+ p[0] = c_ast.CompoundLiteral(p[2], p[5])
+
+ def p_primary_expression_1(self, p):
+ """ primary_expression : identifier """
+ p[0] = p[1]
+
+ def p_primary_expression_2(self, p):
+ """ primary_expression : constant """
+ p[0] = p[1]
+
+ def p_primary_expression_3(self, p):
+ """ primary_expression : unified_string_literal
+ | unified_wstring_literal
+ """
+ p[0] = p[1]
+
+ def p_primary_expression_4(self, p):
+ """ primary_expression : LPAREN expression RPAREN """
+ p[0] = p[2]
+
+ def p_primary_expression_5(self, p):
+ """ primary_expression : OFFSETOF LPAREN type_name COMMA offsetof_member_designator RPAREN
+ """
+ coord = self._token_coord(p, 1)
+ p[0] = c_ast.FuncCall(c_ast.ID(p[1], coord),
+ c_ast.ExprList([p[3], p[5]], coord),
+ coord)
+
+ def p_offsetof_member_designator(self, p):
+ """ offsetof_member_designator : identifier
+ | offsetof_member_designator PERIOD identifier
+ | offsetof_member_designator LBRACKET expression RBRACKET
+ """
+ if len(p) == 2:
+ p[0] = p[1]
+ elif len(p) == 4:
+ p[0] = c_ast.StructRef(p[1], p[2], p[3], p[1].coord)
+ elif len(p) == 5:
+ p[0] = c_ast.ArrayRef(p[1], p[3], p[1].coord)
+ else:
+ raise NotImplementedError("Unexpected parsing state. len(p): %u" % len(p))
+
+ def p_argument_expression_list(self, p):
+ """ argument_expression_list : assignment_expression
+ | argument_expression_list COMMA assignment_expression
+ """
+ if len(p) == 2: # single expr
+ p[0] = c_ast.ExprList([p[1]], p[1].coord)
+ else:
+ p[1].exprs.append(p[3])
+ p[0] = p[1]
+
+ def p_identifier(self, p):
+ """ identifier : ID """
+ p[0] = c_ast.ID(p[1], self._token_coord(p, 1))
+
+ def p_constant_1(self, p):
+ """ constant : INT_CONST_DEC
+ | INT_CONST_OCT
+ | INT_CONST_HEX
+ | INT_CONST_BIN
+ | INT_CONST_CHAR
+ """
+ uCount = 0
+ lCount = 0
+ for x in p[1][-3:]:
+ if x in ('l', 'L'):
+ lCount += 1
+ elif x in ('u', 'U'):
+ uCount += 1
+ t = ''
+ if uCount > 1:
+ raise ValueError('Constant cannot have more than one u/U suffix.')
+ elif lCount > 2:
+ raise ValueError('Constant cannot have more than two l/L suffixes.')
+ prefix = 'unsigned ' * uCount + 'long ' * lCount
+ p[0] = c_ast.Constant(
+ prefix + 'int', p[1], self._token_coord(p, 1))
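+
+ # Illustrative (added): for the literal "10UL" the trailing characters
+ # contain one u/U and one l/L, so the resulting Constant type is
+ # "unsigned long int".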
+
+ def p_constant_2(self, p):
+ """ constant : FLOAT_CONST
+ | HEX_FLOAT_CONST
+ """
+ if 'x' in p[1].lower():
+ t = 'float'
+ else:
+ if p[1][-1] in ('f', 'F'):
+ t = 'float'
+ elif p[1][-1] in ('l', 'L'):
+ t = 'long double'
+ else:
+ t = 'double'
+
+ p[0] = c_ast.Constant(
+ t, p[1], self._token_coord(p, 1))
+
+ def p_constant_3(self, p):
+ """ constant : CHAR_CONST
+ | WCHAR_CONST
+ | U8CHAR_CONST
+ | U16CHAR_CONST
+ | U32CHAR_CONST
+ """
+ p[0] = c_ast.Constant(
+ 'char', p[1], self._token_coord(p, 1))
+
+ # The "unified" string and wstring literal rules are for supporting
+ # concatenation of adjacent string literals.
+ # I.e. "hello " "world" is seen by the C compiler as a single string literal
+ # with the value "hello world"
+ #
+ def p_unified_string_literal(self, p):
+ """ unified_string_literal : STRING_LITERAL
+ | unified_string_literal STRING_LITERAL
+ """
+ if len(p) == 2: # single literal
+ p[0] = c_ast.Constant(
+ 'string', p[1], self._token_coord(p, 1))
+ else:
+ p[1].value = p[1].value[:-1] + p[2][1:]
+ p[0] = p[1]
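+
+ # Example of the splice above (illustrative, added): given
+ # p[1].value == '"hello "' and p[2] == '"world"', dropping the inner
+ # quotes produces the single Constant value '"hello world"'.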
+
+ def p_unified_wstring_literal(self, p):
+ """ unified_wstring_literal : WSTRING_LITERAL
+ | U8STRING_LITERAL
+ | U16STRING_LITERAL
+ | U32STRING_LITERAL
+ | unified_wstring_literal WSTRING_LITERAL
+ | unified_wstring_literal U8STRING_LITERAL
+ | unified_wstring_literal U16STRING_LITERAL
+ | unified_wstring_literal U32STRING_LITERAL
+ """
+ if len(p) == 2: # single literal
+ p[0] = c_ast.Constant(
+ 'string', p[1], self._token_coord(p, 1))
+ else:
+ p[1].value = p[1].value.rstrip()[:-1] + p[2][2:]
+ p[0] = p[1]
+
+ def p_brace_open(self, p):
+ """ brace_open : LBRACE
+ """
+ p[0] = p[1]
+ p.set_lineno(0, p.lineno(1))
+
+ def p_brace_close(self, p):
+ """ brace_close : RBRACE
+ """
+ p[0] = p[1]
+ p.set_lineno(0, p.lineno(1))
+
+ def p_empty(self, p):
+ 'empty : '
+ p[0] = None
+
+ def p_error(self, p):
+ # If error recovery is added here in the future, make sure
+ # _get_yacc_lookahead_token still works!
+ #
+ if p:
+ self._parse_error(
+ 'before: %s' % p.value,
+ self._coord(lineno=p.lineno,
+ column=self.clex.find_tok_column(p)))
+ else:
+ self._parse_error('At end of input', self.clex.filename)
diff --git a/lib/pycparser/lextab.py b/lib/pycparser/lextab.py
new file mode 100644
index 0000000..444b465
--- /dev/null
+++ b/lib/pycparser/lextab.py
@@ -0,0 +1,10 @@
+# lextab.py. This file automatically created by PLY (version 3.10). Don't edit!
+_tabversion = '3.10'
+_lextokens = set(('INT_CONST_CHAR', 'VOID', 'LBRACKET', 'WCHAR_CONST', 'FLOAT_CONST', 'MINUS', 'RPAREN', 'STRUCT', 'LONG', 'PLUS', 'ELLIPSIS', 'U32STRING_LITERAL', 'GT', 'GOTO', 'ENUM', 'PERIOD', 'GE', 'INT_CONST_DEC', 'ARROW', '_STATIC_ASSERT', '__INT128', 'HEX_FLOAT_CONST', 'DOUBLE', 'MINUSEQUAL', 'INT_CONST_OCT', 'TIMESEQUAL', 'OR', 'SHORT', 'RETURN', 'RSHIFTEQUAL', '_ALIGNAS', 'RESTRICT', 'STATIC', 'SIZEOF', 'UNSIGNED', 'PLUSPLUS', 'COLON', 'WSTRING_LITERAL', 'DIVIDE', 'FOR', 'UNION', 'EQUALS', 'ELSE', 'ANDEQUAL', 'EQ', 'AND', 'TYPEID', 'LBRACE', 'PPHASH', 'INT', 'SIGNED', 'CONTINUE', 'NOT', 'OREQUAL', 'MOD', 'RSHIFT', 'DEFAULT', '_NORETURN', 'CHAR', 'WHILE', 'DIVEQUAL', '_ALIGNOF', 'EXTERN', 'LNOT', 'CASE', 'LAND', 'REGISTER', 'MODEQUAL', 'NE', 'SWITCH', 'INT_CONST_HEX', '_COMPLEX', 'PPPRAGMASTR', 'PLUSEQUAL', 'U32CHAR_CONST', 'CONDOP', 'U8STRING_LITERAL', 'BREAK', 'VOLATILE', 'PPPRAGMA', 'INLINE', 'INT_CONST_BIN', 'DO', 'U8CHAR_CONST', 'CONST', 'U16STRING_LITERAL', 'LOR', 'CHAR_CONST', 'LSHIFT', 'RBRACE', '_BOOL', 'LE', 'SEMI', '_THREAD_LOCAL', 'LT', 'COMMA', 'U16CHAR_CONST', 'OFFSETOF', '_ATOMIC', 'TYPEDEF', 'XOR', 'AUTO', 'TIMES', 'LPAREN', 'MINUSMINUS', 'ID', 'IF', 'STRING_LITERAL', 'FLOAT', 'XOREQUAL', 'LSHIFTEQUAL', 'RBRACKET'))
+_lexreflags = 64
+_lexliterals = ''
+_lexstateinfo = {'ppline': 'exclusive', 'pppragma': 'exclusive', 'INITIAL': 'inclusive'}
+_lexstatere = {'ppline': [('(?P<t_ppline_FILENAME>"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P<t_ppline_LINE_NUMBER>(0(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|([1-9][0-9]*(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?))|(?P<t_ppline_NEWLINE>\\n)|(?P<t_ppline_PPLINE>line)', [None, ('t_ppline_FILENAME', 'FILENAME'), None, None, ('t_ppline_LINE_NUMBER', 'LINE_NUMBER'), None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ('t_ppline_NEWLINE', 'NEWLINE'), ('t_ppline_PPLINE', 'PPLINE')])], 'pppragma': [('(?P<t_pppragma_NEWLINE>\\n)|(?P<t_pppragma_PPPRAGMA>pragma)|(?P<t_pppragma_STR>.+)', [None, ('t_pppragma_NEWLINE', 'NEWLINE'), ('t_pppragma_PPPRAGMA', 'PPPRAGMA'), ('t_pppragma_STR', 'STR')])], 'INITIAL': [('(?P<t_PPHASH>[ \\t]*\\#)|(?P<t_NEWLINE>\\n+)|(?P<t_LBRACE>\\{)|(?P<t_RBRACE>\\})|(?P<t_FLOAT_CONST>((((([0-9]*\\.[0-9]+)|([0-9]+\\.))([eE][-+]?[0-9]+)?)|([0-9]+([eE][-+]?[0-9]+)))[FfLl]?))|(?P<t_HEX_FLOAT_CONST>(0[xX]([0-9a-fA-F]+|((([0-9a-fA-F]+)?\\.[0-9a-fA-F]+)|([0-9a-fA-F]+\\.)))([pP][+-]?[0-9]+)[FfLl]?))|(?P<t_INT_CONST_HEX>0[xX][0-9a-fA-F]+(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|(?P<t_INT_CONST_BIN>0[bB][01]+(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)', [None, ('t_PPHASH', 'PPHASH'), ('t_NEWLINE', 'NEWLINE'), ('t_LBRACE', 'LBRACE'), ('t_RBRACE', 'RBRACE'), ('t_FLOAT_CONST', 'FLOAT_CONST'), None, None, None, None, None, None, None, None, None, ('t_HEX_FLOAT_CONST', 'HEX_FLOAT_CONST'), None, None, None, None, None, None, None, ('t_INT_CONST_HEX', 'INT_CONST_HEX'), None, None, None, None, None, None, None, ('t_INT_CONST_BIN', 'INT_CONST_BIN')]), ('(?P<t_BAD_CONST_OCT>0[0-7]*[89])|(?P<t_INT_CONST_OCT>0[0-7]*(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|(?P<t_INT_CONST_DEC>(0(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?)|([1-9][0-9]*(([uU]ll)|([uU]LL)|(ll[uU]?)|(LL[uU]?)|([uU][lL])|([lL][uU]?)|[uU])?))|(?P<t_INT_CONST_CHAR>\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F])))){2,4}\')|(?P<t_CHAR_CONST>\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?P<t_WCHAR_CONST>L\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?P<t_U8CHAR_CONST>u8\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?P<t_U16CHAR_CONST>u\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')|(?P<t_U32CHAR_CONST>U\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))\')', [None, ('t_BAD_CONST_OCT', 'BAD_CONST_OCT'), ('t_INT_CONST_OCT', 'INT_CONST_OCT'), None, None, None, None, None, None, None, ('t_INT_CONST_DEC', 'INT_CONST_DEC'), None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ('t_INT_CONST_CHAR', 'INT_CONST_CHAR'), None, None, None, None, None, None, ('t_CHAR_CONST', 'CHAR_CONST'), None, None, None, None, None, None, ('t_WCHAR_CONST', 'WCHAR_CONST'), None, None, None, None, None, None, ('t_U8CHAR_CONST', 'U8CHAR_CONST'), None, None, None, None, None, None, ('t_U16CHAR_CONST', 'U16CHAR_CONST'), None, None, None, None, None, None, ('t_U32CHAR_CONST', 
'U32CHAR_CONST')]), ('(?P<t_UNMATCHED_QUOTE>(\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))*\\n)|(\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))*$))|(?P<t_BAD_CHAR_CONST>(\'([^\'\\\\\\n]|(\\\\(([a-wyzA-Z._~!=&\\^\\-\\\\?\'"]|x(?![0-9a-fA-F]))|(\\d+)(?!\\d)|(x[0-9a-fA-F]+)(?![0-9a-fA-F]))))[^\'\n]+\')|(\'\')|(\'([\\\\][^a-zA-Z._~^!=&\\^\\-\\\\?\'"x0-9])[^\'\\n]*\'))|(?P<t_WSTRING_LITERAL>L"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P<t_U8STRING_LITERAL>u8"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P<t_U16STRING_LITERAL>u"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P<t_U32STRING_LITERAL>U"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P<t_BAD_STRING_LITERAL>"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*([\\\\][^a-zA-Z._~^!=&\\^\\-\\\\?\'"x0-9])([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P<t_ID>[a-zA-Z_$][0-9a-zA-Z_$]*)|(?P<t_STRING_LITERAL>"([^"\\\\\\n]|(\\\\[0-9a-zA-Z._~!=&\\^\\-\\\\?\'"]))*")|(?P<t_ELLIPSIS>\\.\\.\\.)|(?P<t_PLUSPLUS>\\+\\+)|(?P<t_LOR>\\|\\|)|(?P<t_XOREQUAL>\\^=)|(?P<t_OREQUAL>\\|=)|(?P<t_LSHIFTEQUAL><<=)|(?P<t_RSHIFTEQUAL>>>=)|(?P<t_PLUSEQUAL>\\+=)|(?P<t_TIMESEQUAL>\\*=)', [None, ('t_UNMATCHED_QUOTE', 'UNMATCHED_QUOTE'), None, None, None, None, None, None, None, None, None, None, None, None, None, None, ('t_BAD_CHAR_CONST', 'BAD_CHAR_CONST'), None, None, None, None, None, None, None, None, None, None, ('t_WSTRING_LITERAL', 'WSTRING_LITERAL'), None, None, ('t_U8STRING_LITERAL', 'U8STRING_LITERAL'), None, None, ('t_U16STRING_LITERAL', 'U16STRING_LITERAL'), None, None, ('t_U32STRING_LITERAL', 'U32STRING_LITERAL'), None, None, ('t_BAD_STRING_LITERAL', 'BAD_STRING_LITERAL'), None, None, None, None, None, ('t_ID', 'ID'), (None, 'STRING_LITERAL'), None, None, (None, 'ELLIPSIS'), (None, 'PLUSPLUS'), (None, 'LOR'), (None, 'XOREQUAL'), (None, 'OREQUAL'), (None, 'LSHIFTEQUAL'), (None, 'RSHIFTEQUAL'), (None, 'PLUSEQUAL'), (None, 'TIMESEQUAL')]), ('(?P<t_PLUS>\\+)|(?P<t_MODEQUAL>%=)|(?P<t_DIVEQUAL>/=)|(?P<t_RBRACKET>\\])|(?P<t_CONDOP>\\?)|(?P<t_XOR>\\^)|(?P<t_LSHIFT><<)|(?P<t_LE><=)|(?P<t_LPAREN>\\()|(?P<t_ARROW>->)|(?P<t_EQ>==)|(?P<t_NE>!=)|(?P<t_MINUSMINUS>--)|(?P<t_OR>\\|)|(?P<t_TIMES>\\*)|(?P<t_LBRACKET>\\[)|(?P<t_GE>>=)|(?P<t_RPAREN>\\))|(?P<t_LAND>&&)|(?P<t_RSHIFT>>>)|(?P<t_MINUSEQUAL>-=)|(?P<t_PERIOD>\\.)|(?P<t_ANDEQUAL>&=)|(?P<t_EQUALS>=)|(?P<t_LT><)|(?P<t_COMMA>,)|(?P<t_DIVIDE>/)|(?P<t_AND>&)|(?P<t_MOD>%)|(?P<t_SEMI>;)|(?P<t_MINUS>-)|(?P<t_GT>>)|(?P<t_COLON>:)|(?P<t_NOT>~)|(?P<t_LNOT>!)', [None, (None, 'PLUS'), (None, 'MODEQUAL'), (None, 'DIVEQUAL'), (None, 'RBRACKET'), (None, 'CONDOP'), (None, 'XOR'), (None, 'LSHIFT'), (None, 'LE'), (None, 'LPAREN'), (None, 'ARROW'), (None, 'EQ'), (None, 'NE'), (None, 'MINUSMINUS'), (None, 'OR'), (None, 'TIMES'), (None, 'LBRACKET'), (None, 'GE'), (None, 'RPAREN'), (None, 'LAND'), (None, 'RSHIFT'), (None, 'MINUSEQUAL'), (None, 'PERIOD'), (None, 'ANDEQUAL'), (None, 'EQUALS'), (None, 'LT'), (None, 'COMMA'), (None, 'DIVIDE'), (None, 'AND'), (None, 'MOD'), (None, 'SEMI'), (None, 'MINUS'), (None, 'GT'), (None, 'COLON'), (None, 'NOT'), (None, 'LNOT')])]}
+_lexstateignore = {'ppline': ' \t', 'pppragma': ' \t', 'INITIAL': ' \t'}
+_lexstateerrorf = {'ppline': 't_ppline_error', 'pppragma': 't_pppragma_error', 'INITIAL': 't_error'}
+_lexstateeoff = {}
diff --git a/lib/pycparser/ply/__init__.py b/lib/pycparser/ply/__init__.py
new file mode 100644
index 0000000..6e53cdd
--- /dev/null
+++ b/lib/pycparser/ply/__init__.py
@@ -0,0 +1,5 @@
+# PLY package
+# Author: David Beazley (dave@dabeaz.com)
+
+__version__ = '3.9'
+__all__ = ['lex','yacc']
diff --git a/lib/pycparser/ply/cpp.py b/lib/pycparser/ply/cpp.py
new file mode 100644
index 0000000..86273ea
--- /dev/null
+++ b/lib/pycparser/ply/cpp.py
@@ -0,0 +1,905 @@
+# -----------------------------------------------------------------------------
+# cpp.py
+#
+# Author: David Beazley (http://www.dabeaz.com)
+# Copyright (C) 2017
+# All rights reserved
+#
+# This module implements an ANSI-C style lexical preprocessor for PLY.
+# -----------------------------------------------------------------------------
+import sys
+
+# Some Python 3 compatibility shims
+if sys.version_info.major < 3:
+ STRING_TYPES = (str, unicode)
+else:
+ STRING_TYPES = str
+ xrange = range
+
+# -----------------------------------------------------------------------------
+# Default preprocessor lexer definitions. These tokens are enough to get
+# a basic preprocessor working. Other modules may import these if they want
+# -----------------------------------------------------------------------------
+
+tokens = (
+ 'CPP_ID','CPP_INTEGER', 'CPP_FLOAT', 'CPP_STRING', 'CPP_CHAR', 'CPP_WS', 'CPP_COMMENT1', 'CPP_COMMENT2', 'CPP_POUND','CPP_DPOUND'
+)
+
+literals = "+-*/%|&~^<>=!?()[]{}.,;:\\\'\""
+
+# Whitespace
+def t_CPP_WS(t):
+ r'\s+'
+ t.lexer.lineno += t.value.count("\n")
+ return t
+
+t_CPP_POUND = r'\#'
+t_CPP_DPOUND = r'\#\#'
+
+# Identifier
+t_CPP_ID = r'[A-Za-z_][\w_]*'
+
+# Integer literal
+def CPP_INTEGER(t):
+ r'(((((0x)|(0X))[0-9a-fA-F]+)|(\d+))([uU][lL]|[lL][uU]|[uU]|[lL])?)'
+ return t
+
+t_CPP_INTEGER = CPP_INTEGER
+
+# Floating literal
+t_CPP_FLOAT = r'((\d+)(\.\d+)(e(\+|-)?(\d+))? | (\d+)e(\+|-)?(\d+))([lL]|[fF])?'
+
+# String literal
+def t_CPP_STRING(t):
+ r'\"([^\\\n]|(\\(.|\n)))*?\"'
+ t.lexer.lineno += t.value.count("\n")
+ return t
+
+# Character constant 'c' or L'c'
+def t_CPP_CHAR(t):
+ r'(L)?\'([^\\\n]|(\\(.|\n)))*?\''
+ t.lexer.lineno += t.value.count("\n")
+ return t
+
+# Comment
+def t_CPP_COMMENT1(t):
+ r'(/\*(.|\n)*?\*/)'
+ ncr = t.value.count("\n")
+ t.lexer.lineno += ncr
+ # replace with one space or a number of '\n'
+ t.type = 'CPP_WS'; t.value = '\n' * ncr if ncr else ' '
+ return t
+
+# Line comment
+def t_CPP_COMMENT2(t):
+ r'(//.*?(\n|$))'
+ # replace with '\n'
+ t.type = 'CPP_WS'; t.value = '\n'
+ return t
+
+def t_error(t):
+ t.type = t.value[0]
+ t.value = t.value[0]
+ t.lexer.skip(1)
+ return t
+
+import re
+import copy
+import time
+import os.path
+
+# -----------------------------------------------------------------------------
+# trigraph()
+#
+# Given an input string, this function replaces all trigraph sequences.
+# The following mapping is used:
+#
+# ??= #
+# ??/ \
+# ??' ^
+# ??( [
+# ??) ]
+# ??! |
+# ??< {
+# ??> }
+# ??- ~
+# -----------------------------------------------------------------------------
+
+_trigraph_pat = re.compile(r'''\?\?[=/\'\(\)\!<>\-]''')
+_trigraph_rep = {
+ '=':'#',
+ '/':'\\',
+ "'":'^',
+ '(':'[',
+ ')':']',
+ '!':'|',
+ '<':'{',
+ '>':'}',
+ '-':'~'
+}
+
+def trigraph(input):
+ return _trigraph_pat.sub(lambda g: _trigraph_rep[g.group()[-1]],input)
+
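+# For illustration, the substitution is purely textual, e.g.
+#
+#     trigraph("??=define ARR(x) x??(0??)")  ->  "#define ARR(x) x[0]"
+#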
+# ------------------------------------------------------------------
+# Macro object
+#
+# This object holds information about preprocessor macros
+#
+# .name - Macro name (string)
+# .value - Macro value (a list of tokens)
+# .arglist - List of argument names
+# .variadic - Boolean indicating whether or not the macro is variadic
+# .vararg - Name of the variadic parameter
+#
+# When a macro is created, the macro replacement token sequence is
+# pre-scanned and used to create patch lists that are later used
+# during macro expansion
+# ------------------------------------------------------------------
+
+class Macro(object):
+ def __init__(self,name,value,arglist=None,variadic=False):
+ self.name = name
+ self.value = value
+ self.arglist = arglist
+ self.variadic = variadic
+ if variadic:
+ self.vararg = arglist[-1]
+ self.source = None
+
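+# For illustration, "#define PI 3.14159" produces Macro('PI', [<3.14159 token>])
+# with no arglist, while a variadic definition such as "#define DBG(fmt, ...)"
+# is stored with variadic=True and vararg set to the last argument name
+# ('__VA_ARGS__' after the rewriting performed in Preprocessor.define()).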
+# ------------------------------------------------------------------
+# Preprocessor object
+#
+# Object representing a preprocessor. Contains macro definitions,
+# include directories, and other information
+# ------------------------------------------------------------------
+
+class Preprocessor(object):
+ def __init__(self,lexer=None):
+ if lexer is None:
+ lexer = lex.lexer
+ self.lexer = lexer
+ self.macros = { }
+ self.path = []
+ self.temp_path = []
+
+ # Probe the lexer for selected tokens
+ self.lexprobe()
+
+ tm = time.localtime()
+ self.define("__DATE__ \"%s\"" % time.strftime("%b %d %Y",tm))
+ self.define("__TIME__ \"%s\"" % time.strftime("%H:%M:%S",tm))
+ self.parser = None
+
+ # -----------------------------------------------------------------------------
+ # tokenize()
+ #
+ # Utility function. Given a string of text, tokenize into a list of tokens
+ # -----------------------------------------------------------------------------
+
+ def tokenize(self,text):
+ tokens = []
+ self.lexer.input(text)
+ while True:
+ tok = self.lexer.token()
+ if not tok: break
+ tokens.append(tok)
+ return tokens
+
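+ # For illustration, with the default rules above, self.tokenize("a + 1")
+ # returns the identifier, whitespace, '+', whitespace and integer tokens in
+ # that order; nothing is filtered out at this level.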
+ # ---------------------------------------------------------------------
+ # error()
+ #
+ # Report a preprocessor error/warning of some kind
+ # ----------------------------------------------------------------------
+
+ def error(self,file,line,msg):
+ print("%s:%d %s" % (file,line,msg))
+
+ # ----------------------------------------------------------------------
+ # lexprobe()
+ #
+ # This method probes the preprocessor lexer object to discover
+ # the token types of symbols that are important to the preprocessor.
+ # If this works right, the preprocessor will simply "work"
+ # with any suitable lexer regardless of how tokens have been named.
+ # ----------------------------------------------------------------------
+
+ def lexprobe(self):
+
+ # Determine the token type for identifiers
+ self.lexer.input("identifier")
+ tok = self.lexer.token()
+ if not tok or tok.value != "identifier":
+ print("Couldn't determine identifier type")
+ else:
+ self.t_ID = tok.type
+
+ # Determine the token type for integers
+ self.lexer.input("12345")
+ tok = self.lexer.token()
+ if not tok or int(tok.value) != 12345:
+ print("Couldn't determine integer type")
+ else:
+ self.t_INTEGER = tok.type
+ self.t_INTEGER_TYPE = type(tok.value)
+
+ # Determine the token type for strings enclosed in double quotes
+ self.lexer.input("\"filename\"")
+ tok = self.lexer.token()
+ if not tok or tok.value != "\"filename\"":
+ print("Couldn't determine string type")
+ else:
+ self.t_STRING = tok.type
+
+ # Determine the token type for whitespace--if any
+ self.lexer.input(" ")
+ tok = self.lexer.token()
+ if not tok or tok.value != " ":
+ self.t_SPACE = None
+ else:
+ self.t_SPACE = tok.type
+
+ # Determine the token type for newlines
+ self.lexer.input("\n")
+ tok = self.lexer.token()
+ if not tok or tok.value != "\n":
+ self.t_NEWLINE = None
+ print("Couldn't determine token for newlines")
+ else:
+ self.t_NEWLINE = tok.type
+
+ self.t_WS = (self.t_SPACE, self.t_NEWLINE)
+
+ # Check for other characters used by the preprocessor
+ chars = [ '<','>','#','##','\\','(',')',',','.']
+ for c in chars:
+ self.lexer.input(c)
+ tok = self.lexer.token()
+ if not tok or tok.value != c:
+ print("Unable to lex '%s' required for preprocessor" % c)
+
+ # ----------------------------------------------------------------------
+ # add_path()
+ #
+ # Adds a search path to the preprocessor.
+ # ----------------------------------------------------------------------
+
+ def add_path(self,path):
+ self.path.append(path)
+
+ # ----------------------------------------------------------------------
+ # group_lines()
+ #
+ # Given an input string, this function splits it into lines. Trailing whitespace
+ # is removed. Any line ending with \ is grouped with the next line. This
+ # function forms the lowest level of the preprocessor---grouping the input text
+ # into a line-by-line format.
+ # ----------------------------------------------------------------------
+
+ def group_lines(self,input):
+ lex = self.lexer.clone()
+ lines = [x.rstrip() for x in input.splitlines()]
+ for i in xrange(len(lines)):
+ j = i+1
+ while lines[i].endswith('\\') and (j < len(lines)):
+ lines[i] = lines[i][:-1]+lines[j]
+ lines[j] = ""
+ j += 1
+
+ input = "\n".join(lines)
+ lex.input(input)
+ lex.lineno = 1
+
+ current_line = []
+ while True:
+ tok = lex.token()
+ if not tok:
+ break
+ current_line.append(tok)
+ if tok.type in self.t_WS and '\n' in tok.value:
+ yield current_line
+ current_line = []
+
+ if current_line:
+ yield current_line
+
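+ # For illustration, the two physical lines "#define X \" and "1 + 2" are
+ # delivered as one logical token line "#define X 1 + 2"; the emptied
+ # continuation line keeps the newline count, so later line numbers still
+ # agree with the original source.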
+ # ----------------------------------------------------------------------
+ # tokenstrip()
+ #
+ # Remove leading/trailing whitespace tokens from a token list
+ # ----------------------------------------------------------------------
+
+ def tokenstrip(self,tokens):
+ i = 0
+ while i < len(tokens) and tokens[i].type in self.t_WS:
+ i += 1
+ del tokens[:i]
+ i = len(tokens)-1
+ while i >= 0 and tokens[i].type in self.t_WS:
+ i -= 1
+ del tokens[i+1:]
+ return tokens
+
+
+ # ----------------------------------------------------------------------
+ # collect_args()
+ #
+ # Collects comma separated arguments from a list of tokens. The arguments
+ # must be enclosed in parentheses. Returns a tuple (tokencount,args,positions)
+ # where tokencount is the number of tokens consumed, args is a list of arguments,
+ # and positions is a list of integers containing the starting index of each
+ # argument. Each argument is represented by a list of tokens.
+ #
+ # When collecting arguments, leading and trailing whitespace is removed
+ # from each argument.
+ #
+ # This function properly handles nested parentheses and commas---these do not
+ # define new arguments.
+ # ----------------------------------------------------------------------
+
+ def collect_args(self,tokenlist):
+ args = []
+ positions = []
+ current_arg = []
+ nesting = 1
+ tokenlen = len(tokenlist)
+
+ # Search for the opening '('.
+ i = 0
+ while (i < tokenlen) and (tokenlist[i].type in self.t_WS):
+ i += 1
+
+ if (i < tokenlen) and (tokenlist[i].value == '('):
+ positions.append(i+1)
+ else:
+ self.error(self.source,tokenlist[0].lineno,"Missing '(' in macro arguments")
+ return 0, [], []
+
+ i += 1
+
+ while i < tokenlen:
+ t = tokenlist[i]
+ if t.value == '(':
+ current_arg.append(t)
+ nesting += 1
+ elif t.value == ')':
+ nesting -= 1
+ if nesting == 0:
+ if current_arg:
+ args.append(self.tokenstrip(current_arg))
+ positions.append(i)
+ return i+1,args,positions
+ current_arg.append(t)
+ elif t.value == ',' and nesting == 1:
+ args.append(self.tokenstrip(current_arg))
+ positions.append(i+1)
+ current_arg = []
+ else:
+ current_arg.append(t)
+ i += 1
+
+ # Missing end argument
+ self.error(self.source,tokenlist[-1].lineno,"Missing ')' in macro arguments")
+ return 0, [],[]
+
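+ # For illustration, given the tokens of "(x, f(a, b), z) tail", this returns
+ # the three argument token lists for x, f(a, b) and z; the count points just
+ # past the closing ')', so the "tail" tokens are left for the caller.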
+ # ----------------------------------------------------------------------
+ # macro_prescan()
+ #
+ # Examine the macro value (token sequence) and identify patch points
+ # This is used to speed up macro expansion later on---we'll know
+ # right away where to apply patches to the value to form the expansion
+ # ----------------------------------------------------------------------
+
+ def macro_prescan(self,macro):
+ macro.patch = [] # Standard macro arguments
+ macro.str_patch = [] # String conversion expansion
+ macro.var_comma_patch = [] # Variadic macro comma patch
+ i = 0
+ while i < len(macro.value):
+ if macro.value[i].type == self.t_ID and macro.value[i].value in macro.arglist:
+ argnum = macro.arglist.index(macro.value[i].value)
+ # Conversion of argument to a string
+ if i > 0 and macro.value[i-1].value == '#':
+ macro.value[i] = copy.copy(macro.value[i])
+ macro.value[i].type = self.t_STRING
+ del macro.value[i-1]
+ macro.str_patch.append((argnum,i-1))
+ continue
+ # Concatenation
+ elif (i > 0 and macro.value[i-1].value == '##'):
+ macro.patch.append(('c',argnum,i-1))
+ del macro.value[i-1]
+ continue
+ elif ((i+1) < len(macro.value) and macro.value[i+1].value == '##'):
+ macro.patch.append(('c',argnum,i))
+ i += 1
+ continue
+ # Standard expansion
+ else:
+ macro.patch.append(('e',argnum,i))
+ elif macro.value[i].value == '##':
+ if macro.variadic and (i > 0) and (macro.value[i-1].value == ',') and \
+ ((i+1) < len(macro.value)) and (macro.value[i+1].type == self.t_ID) and \
+ (macro.value[i+1].value == macro.vararg):
+ macro.var_comma_patch.append(i-1)
+ i += 1
+ macro.patch.sort(key=lambda x: x[2],reverse=True)
+
+ # ----------------------------------------------------------------------
+ # macro_expand_args()
+ #
+ # Given a Macro and list of arguments (each a token list), this method
+ # returns an expanded version of a macro. The return value is a token sequence
+ # representing the replacement macro tokens
+ # ----------------------------------------------------------------------
+
+ def macro_expand_args(self,macro,args):
+ # Make a copy of the macro token sequence
+ rep = [copy.copy(_x) for _x in macro.value]
+
+ # Make string expansion patches. These do not alter the length of the replacement sequence
+
+ str_expansion = {}
+ for argnum, i in macro.str_patch:
+ if argnum not in str_expansion:
+ str_expansion[argnum] = ('"%s"' % "".join([x.value for x in args[argnum]])).replace("\\","\\\\")
+ rep[i] = copy.copy(rep[i])
+ rep[i].value = str_expansion[argnum]
+
+ # Make the variadic macro comma patch. If the variadic macro argument is empty, we get rid of the comma before it.
+ comma_patch = False
+ if macro.variadic and not args[-1]:
+ for i in macro.var_comma_patch:
+ rep[i] = None
+ comma_patch = True
+
+ # Make all other patches. The order of these matters. It is assumed that the patch list
+ # has been sorted in reverse order of patch location since replacements will cause the
+ # size of the replacement sequence to expand from the patch point.
+
+ expanded = { }
+ for ptype, argnum, i in macro.patch:
+ # Concatenation. Argument is left unexpanded
+ if ptype == 'c':
+ rep[i:i+1] = args[argnum]
+ # Normal expansion. Argument is macro expanded first
+ elif ptype == 'e':
+ if argnum not in expanded:
+ expanded[argnum] = self.expand_macros(args[argnum])
+ rep[i:i+1] = expanded[argnum]
+
+ # Get rid of removed comma if necessary
+ if comma_patch:
+ rep = [_i for _i in rep if _i]
+
+ return rep
+
+
+ # ----------------------------------------------------------------------
+ # expand_macros()
+ #
+ # Given a list of tokens, this function performs macro expansion.
+ # The expanded argument is a dictionary that contains macros already
+ # expanded. This is used to prevent infinite recursion.
+ # ----------------------------------------------------------------------
+
+ def expand_macros(self,tokens,expanded=None):
+ if expanded is None:
+ expanded = {}
+ i = 0
+ while i < len(tokens):
+ t = tokens[i]
+ if t.type == self.t_ID:
+ if t.value in self.macros and t.value not in expanded:
+ # Yes, we found a macro match
+ expanded[t.value] = True
+
+ m = self.macros[t.value]
+ if not m.arglist:
+ # A simple macro
+ ex = self.expand_macros([copy.copy(_x) for _x in m.value],expanded)
+ for e in ex:
+ e.lineno = t.lineno
+ tokens[i:i+1] = ex
+ i += len(ex)
+ else:
+ # A macro with arguments
+ j = i + 1
+ while j < len(tokens) and tokens[j].type in self.t_WS:
+ j += 1
+ if tokens[j].value == '(':
+ tokcount,args,positions = self.collect_args(tokens[j:])
+ if not m.variadic and len(args) != len(m.arglist):
+ self.error(self.source,t.lineno,"Macro %s requires %d arguments" % (t.value,len(m.arglist)))
+ i = j + tokcount
+ elif m.variadic and len(args) < len(m.arglist)-1:
+ if len(m.arglist) > 2:
+ self.error(self.source,t.lineno,"Macro %s must have at least %d arguments" % (t.value, len(m.arglist)-1))
+ else:
+ self.error(self.source,t.lineno,"Macro %s must have at least %d argument" % (t.value, len(m.arglist)-1))
+ i = j + tokcount
+ else:
+ if m.variadic:
+ if len(args) == len(m.arglist)-1:
+ args.append([])
+ else:
+ args[len(m.arglist)-1] = tokens[j+positions[len(m.arglist)-1]:j+tokcount-1]
+ del args[len(m.arglist):]
+
+ # Get macro replacement text
+ rep = self.macro_expand_args(m,args)
+ rep = self.expand_macros(rep,expanded)
+ for r in rep:
+ r.lineno = t.lineno
+ tokens[i:j+tokcount] = rep
+ i += len(rep)
+ del expanded[t.value]
+ continue
+ elif t.value == '__LINE__':
+ t.type = self.t_INTEGER
+ t.value = self.t_INTEGER_TYPE(t.lineno)
+
+ i += 1
+ return tokens
+
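+ # For illustration, once "#define INC(x) ((x)+1)" has been processed,
+ # expand_macros() applied to the tokens of "INC(2)" yields the tokens of
+ # "((2)+1)"; the expanded dictionary is what stops a self-referential
+ # macro from recursing forever.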
+ # ----------------------------------------------------------------------
+ # evalexpr()
+ #
+ # Evaluate an expression token sequence as an integral constant expression,
+ # as needed for the #if and #elif directives.
+ # ----------------------------------------------------------------------
+
+ def evalexpr(self,tokens):
+ # tokens = tokenize(line)
+ # Search for defined macros
+ i = 0
+ while i < len(tokens):
+ if tokens[i].type == self.t_ID and tokens[i].value == 'defined':
+ j = i + 1
+ needparen = False
+ result = "0L"
+ while j < len(tokens):
+ if tokens[j].type in self.t_WS:
+ j += 1
+ continue
+ elif tokens[j].type == self.t_ID:
+ if tokens[j].value in self.macros:
+ result = "1L"
+ else:
+ result = "0L"
+ if not needparen: break
+ elif tokens[j].value == '(':
+ needparen = True
+ elif tokens[j].value == ')':
+ break
+ else:
+ self.error(self.source,tokens[i].lineno,"Malformed defined()")
+ j += 1
+ tokens[i].type = self.t_INTEGER
+ tokens[i].value = self.t_INTEGER_TYPE(result)
+ del tokens[i+1:j+1]
+ i += 1
+ tokens = self.expand_macros(tokens)
+ for i,t in enumerate(tokens):
+ if t.type == self.t_ID:
+ tokens[i] = copy.copy(t)
+ tokens[i].type = self.t_INTEGER
+ tokens[i].value = self.t_INTEGER_TYPE("0L")
+ elif t.type == self.t_INTEGER:
+ tokens[i] = copy.copy(t)
+ # Strip off any trailing suffixes
+ tokens[i].value = str(tokens[i].value)
+ while tokens[i].value[-1] not in "0123456789abcdefABCDEF":
+ tokens[i].value = tokens[i].value[:-1]
+
+ expr = "".join([str(x.value) for x in tokens])
+ expr = expr.replace("&&"," and ")
+ expr = expr.replace("||"," or ")
+ expr = expr.replace("!"," not ")
+ try:
+ result = eval(expr)
+ except Exception:
+ self.error(self.source,tokens[0].lineno,"Couldn't evaluate expression")
+ result = 0
+ return result
+
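+ # For illustration, for "#if defined(FOO) && VALUE > 1" with FOO defined and
+ # VALUE undefined, the argument tokens are rewritten to the string
+ # "1 and 0 > 1" before being handed to eval(), so the block is skipped.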
+ # ----------------------------------------------------------------------
+ # parsegen()
+ #
+ # Parse an input string.
+ # ----------------------------------------------------------------------
+ def parsegen(self,input,source=None):
+
+ # Replace trigraph sequences
+ t = trigraph(input)
+ lines = self.group_lines(t)
+
+ if not source:
+ source = ""
+
+ self.define("__FILE__ \"%s\"" % source)
+
+ self.source = source
+ chunk = []
+ enable = True
+ iftrigger = False
+ ifstack = []
+
+ for x in lines:
+ for i,tok in enumerate(x):
+ if tok.type not in self.t_WS: break
+ if tok.value == '#':
+ # Preprocessor directive
+
+ # insert the necessary whitespace in place of the consumed directive tokens
+ for tok in x:
+ if tok.type in self.t_WS and '\n' in tok.value:
+ chunk.append(tok)
+
+ dirtokens = self.tokenstrip(x[i+1:])
+ if dirtokens:
+ name = dirtokens[0].value
+ args = self.tokenstrip(dirtokens[1:])
+ else:
+ name = ""
+ args = []
+
+ if name == 'define':
+ if enable:
+ for tok in self.expand_macros(chunk):
+ yield tok
+ chunk = []
+ self.define(args)
+ elif name == 'include':
+ if enable:
+ for tok in self.expand_macros(chunk):
+ yield tok
+ chunk = []
+ oldfile = self.macros['__FILE__']
+ for tok in self.include(args):
+ yield tok
+ self.macros['__FILE__'] = oldfile
+ self.source = source
+ elif name == 'undef':
+ if enable:
+ for tok in self.expand_macros(chunk):
+ yield tok
+ chunk = []
+ self.undef(args)
+ elif name == 'ifdef':
+ ifstack.append((enable,iftrigger))
+ if enable:
+ if not args[0].value in self.macros:
+ enable = False
+ iftrigger = False
+ else:
+ iftrigger = True
+ elif name == 'ifndef':
+ ifstack.append((enable,iftrigger))
+ if enable:
+ if args[0].value in self.macros:
+ enable = False
+ iftrigger = False
+ else:
+ iftrigger = True
+ elif name == 'if':
+ ifstack.append((enable,iftrigger))
+ if enable:
+ result = self.evalexpr(args)
+ if not result:
+ enable = False
+ iftrigger = False
+ else:
+ iftrigger = True
+ elif name == 'elif':
+ if ifstack:
+ if ifstack[-1][0]: # We only pay attention if outer "if" allows this
+ if enable: # If already true, we flip enable False
+ enable = False
+ elif not iftrigger: # If False, but not triggered yet, we'll check expression
+ result = self.evalexpr(args)
+ if result:
+ enable = True
+ iftrigger = True
+ else:
+ self.error(self.source,dirtokens[0].lineno,"Misplaced #elif")
+
+ elif name == 'else':
+ if ifstack:
+ if ifstack[-1][0]:
+ if enable:
+ enable = False
+ elif not iftrigger:
+ enable = True
+ iftrigger = True
+ else:
+ self.error(self.source,dirtokens[0].lineno,"Misplaced #else")
+
+ elif name == 'endif':
+ if ifstack:
+ enable,iftrigger = ifstack.pop()
+ else:
+ self.error(self.source,dirtokens[0].lineno,"Misplaced #endif")
+ else:
+ # Unknown preprocessor directive
+ pass
+
+ else:
+ # Normal text
+ if enable:
+ chunk.extend(x)
+
+ for tok in self.expand_macros(chunk):
+ yield tok
+ chunk = []
+
+ # ----------------------------------------------------------------------
+ # include()
+ #
+ # Implementation of file-inclusion
+ # ----------------------------------------------------------------------
+
+ def include(self,tokens):
+ # Try to extract the filename and then process an include file
+ if not tokens:
+ return
+ if tokens:
+ if tokens[0].value != '<' and tokens[0].type != self.t_STRING:
+ tokens = self.expand_macros(tokens)
+
+ if tokens[0].value == '<':
+ # Include <...>
+ i = 1
+ while i < len(tokens):
+ if tokens[i].value == '>':
+ break
+ i += 1
+ else:
+ print("Malformed #include <...>")
+ return
+ filename = "".join([x.value for x in tokens[1:i]])
+ path = self.path + [""] + self.temp_path
+ elif tokens[0].type == self.t_STRING:
+ filename = tokens[0].value[1:-1]
+ path = self.temp_path + [""] + self.path
+ else:
+ print("Malformed #include statement")
+ return
+ for p in path:
+ iname = os.path.join(p,filename)
+ try:
+ data = open(iname,"r").read()
+ dname = os.path.dirname(iname)
+ if dname:
+ self.temp_path.insert(0,dname)
+ for tok in self.parsegen(data,filename):
+ yield tok
+ if dname:
+ del self.temp_path[0]
+ break
+ except IOError:
+ pass
+ else:
+ print("Couldn't find '%s'" % filename)
+
+ # ----------------------------------------------------------------------
+ # define()
+ #
+ # Define a new macro
+ # ----------------------------------------------------------------------
+
+ def define(self,tokens):
+ if isinstance(tokens,STRING_TYPES):
+ tokens = self.tokenize(tokens)
+
+ linetok = tokens
+ try:
+ name = linetok[0]
+ if len(linetok) > 1:
+ mtype = linetok[1]
+ else:
+ mtype = None
+ if not mtype:
+ m = Macro(name.value,[])
+ self.macros[name.value] = m
+ elif mtype.type in self.t_WS:
+ # A normal macro
+ m = Macro(name.value,self.tokenstrip(linetok[2:]))
+ self.macros[name.value] = m
+ elif mtype.value == '(':
+ # A macro with arguments
+ tokcount, args, positions = self.collect_args(linetok[1:])
+ variadic = False
+ for a in args:
+ if variadic:
+ print("No more arguments may follow a variadic argument")
+ break
+ astr = "".join([str(_i.value) for _i in a])
+ if astr == "...":
+ variadic = True
+ a[0].type = self.t_ID
+ a[0].value = '__VA_ARGS__'
+ variadic = True
+ del a[1:]
+ continue
+ elif astr[-3:] == "..." and a[0].type == self.t_ID:
+ variadic = True
+ del a[1:]
+ # If, for some reason, "..." is attached to the identifier, strip it from the
+ # name for the purposes of macro expansion
+ if a[0].value[-3:] == '...':
+ a[0].value = a[0].value[:-3]
+ continue
+ if len(a) > 1 or a[0].type != self.t_ID:
+ print("Invalid macro argument")
+ break
+ else:
+ mvalue = self.tokenstrip(linetok[1+tokcount:])
+ i = 0
+ while i < len(mvalue):
+ if i+1 < len(mvalue):
+ if mvalue[i].type in self.t_WS and mvalue[i+1].value == '##':
+ del mvalue[i]
+ continue
+ elif mvalue[i].value == '##' and mvalue[i+1].type in self.t_WS:
+ del mvalue[i+1]
+ i += 1
+ m = Macro(name.value,mvalue,[x[0].value for x in args],variadic)
+ self.macro_prescan(m)
+ self.macros[name.value] = m
+ else:
+ print("Bad macro definition")
+ except LookupError:
+ print("Bad macro definition")
+
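+ # For illustration, define() accepts either a raw string (as used for the
+ # built-in __DATE__ and __TIME__ definitions in __init__) or the token list
+ # of a #define body, e.g. self.define("SQUARE(x) ((x)*(x))") registers a
+ # function-like macro under self.macros['SQUARE'].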
+ # ----------------------------------------------------------------------
+ # undef()
+ #
+ # Undefine a macro
+ # ----------------------------------------------------------------------
+
+ def undef(self,tokens):
+ id = tokens[0].value
+ try:
+ del self.macros[id]
+ except LookupError:
+ pass
+
+ # ----------------------------------------------------------------------
+ # parse()
+ #
+ # Parse input text.
+ # ----------------------------------------------------------------------
+ def parse(self,input,source=None,ignore={}):
+ self.ignore = ignore
+ self.parser = self.parsegen(input,source)
+
+ # ----------------------------------------------------------------------
+ # token()
+ #
+ # Method to return individual tokens
+ # ----------------------------------------------------------------------
+ def token(self):
+ try:
+ while True:
+ tok = next(self.parser)
+ if tok.type not in self.ignore: return tok
+ except StopIteration:
+ self.parser = None
+ return None
+
+if __name__ == '__main__':
+ import ply.lex as lex
+ lexer = lex.lex()
+
+ # Run a preprocessor
+ import sys
+ f = open(sys.argv[1])
+ input = f.read()
+
+ p = Preprocessor(lexer)
+ p.parse(input,sys.argv[1])
+ while True:
+ tok = p.token()
+ if not tok: break
+ print(p.source, tok)
diff --git a/lib/pycparser/ply/ctokens.py b/lib/pycparser/ply/ctokens.py
new file mode 100644
index 0000000..f6f6952
--- /dev/null
+++ b/lib/pycparser/ply/ctokens.py
@@ -0,0 +1,133 @@
+# ----------------------------------------------------------------------
+# ctokens.py
+#
+# Token specifications for symbols in ANSI C and C++. This file is
+# meant to be used as a library in other tokenizers.
+# ----------------------------------------------------------------------
+
+# Reserved words
+
+tokens = [
+ # Literals (identifier, integer constant, float constant, string constant, char const)
+ 'ID', 'TYPEID', 'INTEGER', 'FLOAT', 'STRING', 'CHARACTER',
+
+ # Operators (+,-,*,/,%,|,&,~,^,<<,>>, ||, &&, !, <, <=, >, >=, ==, !=)
+ 'PLUS', 'MINUS', 'TIMES', 'DIVIDE', 'MODULO',
+ 'OR', 'AND', 'NOT', 'XOR', 'LSHIFT', 'RSHIFT',
+ 'LOR', 'LAND', 'LNOT',
+ 'LT', 'LE', 'GT', 'GE', 'EQ', 'NE',
+
+ # Assignment (=, *=, /=, %=, +=, -=, <<=, >>=, &=, ^=, |=)
+ 'EQUALS', 'TIMESEQUAL', 'DIVEQUAL', 'MODEQUAL', 'PLUSEQUAL', 'MINUSEQUAL',
+ 'LSHIFTEQUAL','RSHIFTEQUAL', 'ANDEQUAL', 'XOREQUAL', 'OREQUAL',
+
+ # Increment/decrement (++,--)
+ 'INCREMENT', 'DECREMENT',
+
+ # Structure dereference (->)
+ 'ARROW',
+
+ # Ternary operator (?)
+ 'TERNARY',
+
+ # Delimiters ( ) [ ] { } , . ; :
+ 'LPAREN', 'RPAREN',
+ 'LBRACKET', 'RBRACKET',
+ 'LBRACE', 'RBRACE',
+ 'COMMA', 'PERIOD', 'SEMI', 'COLON',
+
+ # Ellipsis (...)
+ 'ELLIPSIS',
+]
+
+# Operators
+t_PLUS = r'\+'
+t_MINUS = r'-'
+t_TIMES = r'\*'
+t_DIVIDE = r'/'
+t_MODULO = r'%'
+t_OR = r'\|'
+t_AND = r'&'
+t_NOT = r'~'
+t_XOR = r'\^'
+t_LSHIFT = r'<<'
+t_RSHIFT = r'>>'
+t_LOR = r'\|\|'
+t_LAND = r'&&'
+t_LNOT = r'!'
+t_LT = r'<'
+t_GT = r'>'
+t_LE = r'<='
+t_GE = r'>='
+t_EQ = r'=='
+t_NE = r'!='
+
+# Assignment operators
+
+t_EQUALS = r'='
+t_TIMESEQUAL = r'\*='
+t_DIVEQUAL = r'/='
+t_MODEQUAL = r'%='
+t_PLUSEQUAL = r'\+='
+t_MINUSEQUAL = r'-='
+t_LSHIFTEQUAL = r'<<='
+t_RSHIFTEQUAL = r'>>='
+t_ANDEQUAL = r'&='
+t_OREQUAL = r'\|='
+t_XOREQUAL = r'\^='
+
+# Increment/decrement
+t_INCREMENT = r'\+\+'
+t_DECREMENT = r'--'
+
+# ->
+t_ARROW = r'->'
+
+# ?
+t_TERNARY = r'\?'
+
+# Delimiters
+t_LPAREN = r'\('
+t_RPAREN = r'\)'
+t_LBRACKET = r'\['
+t_RBRACKET = r'\]'
+t_LBRACE = r'\{'
+t_RBRACE = r'\}'
+t_COMMA = r','
+t_PERIOD = r'\.'
+t_SEMI = r';'
+t_COLON = r':'
+t_ELLIPSIS = r'\.\.\.'
+
+# Identifiers
+t_ID = r'[A-Za-z_][A-Za-z0-9_]*'
+
+# Integer literal
+t_INTEGER = r'\d+([uU]|[lL]|[uU][lL]|[lL][uU])?'
+
+# Floating literal
+t_FLOAT = r'((\d+)(\.\d+)(e(\+|-)?(\d+))? | (\d+)e(\+|-)?(\d+))([lL]|[fF])?'
+
+# String literal
+t_STRING = r'\"([^\\\n]|(\\.))*?\"'
+
+# Character constant 'c' or L'c'
+t_CHARACTER = r'(L)?\'([^\\\n]|(\\.))*?\''
+
+# Comment (C-Style)
+def t_COMMENT(t):
+ r'/\*(.|\n)*?\*/'
+ t.lexer.lineno += t.value.count('\n')
+ return t
+
+# Comment (C++-Style)
+def t_CPPCOMMENT(t):
+ r'//.*\n'
+ t.lexer.lineno += 1
+ return t
+
+
+
+
+
+
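+# For illustration, one possible pattern is for another tokenizer module to do
+# "from ctokens import *", add its own project-specific rules and token names,
+# and then call lex.lex(); the t_* definitions above are picked up through the
+# calling module's namespace.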
diff --git a/lib/pycparser/ply/lex.py b/lib/pycparser/ply/lex.py
new file mode 100644
index 0000000..4bdd76c
--- /dev/null
+++ b/lib/pycparser/ply/lex.py
@@ -0,0 +1,1099 @@
+# -----------------------------------------------------------------------------
+# ply: lex.py
+#
+# Copyright (C) 2001-2017
+# David M. Beazley (Dabeaz LLC)
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * Neither the name of the David Beazley or Dabeaz LLC may be used to
+# endorse or promote products derived from this software without
+# specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# -----------------------------------------------------------------------------
+
+__version__ = '3.10'
+__tabversion__ = '3.10'
+
+import re
+import sys
+import types
+import copy
+import os
+import inspect
+
+# This tuple contains known string types
+try:
+ # Python 2.6
+ StringTypes = (types.StringType, types.UnicodeType)
+except AttributeError:
+ # Python 3.0
+ StringTypes = (str, bytes)
+
+# This regular expression is used to match valid token names
+_is_identifier = re.compile(r'^[a-zA-Z0-9_]+$')
+
+# Exception thrown when invalid token encountered and no default error
+# handler is defined.
+class LexError(Exception):
+ def __init__(self, message, s):
+ self.args = (message,)
+ self.text = s
+
+
+# Token class. This class is used to represent the tokens produced.
+class LexToken(object):
+ def __str__(self):
+ return 'LexToken(%s,%r,%d,%d)' % (self.type, self.value, self.lineno, self.lexpos)
+
+ def __repr__(self):
+ return str(self)
+
+
+# This object is a stand-in for a logging object created by the
+# logging module.
+
+class PlyLogger(object):
+ def __init__(self, f):
+ self.f = f
+
+ def critical(self, msg, *args, **kwargs):
+ self.f.write((msg % args) + '\n')
+
+ def warning(self, msg, *args, **kwargs):
+ self.f.write('WARNING: ' + (msg % args) + '\n')
+
+ def error(self, msg, *args, **kwargs):
+ self.f.write('ERROR: ' + (msg % args) + '\n')
+
+ info = critical
+ debug = critical
+
+
+# Null logger is used when no output is generated. Does nothing.
+class NullLogger(object):
+ def __getattribute__(self, name):
+ return self
+
+ def __call__(self, *args, **kwargs):
+ return self
+
+
+# -----------------------------------------------------------------------------
+# === Lexing Engine ===
+#
+# The following Lexer class implements the lexer runtime. There are only
+# a few public methods and attributes:
+#
+# input() - Store a new string in the lexer
+# token() - Get the next token
+# clone() - Clone the lexer
+#
+# lineno - Current line number
+# lexpos - Current position in the input string
+# -----------------------------------------------------------------------------
+
+class Lexer:
+ def __init__(self):
+ self.lexre = None # Master regular expression. This is a list of
+ # tuples (re, findex) where re is a compiled
+ # regular expression and findex is a list
+ # mapping regex group numbers to rules
+ self.lexretext = None # Current regular expression strings
+ self.lexstatere = {} # Dictionary mapping lexer states to master regexs
+ self.lexstateretext = {} # Dictionary mapping lexer states to regex strings
+ self.lexstaterenames = {} # Dictionary mapping lexer states to symbol names
+ self.lexstate = 'INITIAL' # Current lexer state
+ self.lexstatestack = [] # Stack of lexer states
+ self.lexstateinfo = None # State information
+ self.lexstateignore = {} # Dictionary of ignored characters for each state
+ self.lexstateerrorf = {} # Dictionary of error functions for each state
+ self.lexstateeoff = {} # Dictionary of eof functions for each state
+ self.lexreflags = 0 # Optional re compile flags
+ self.lexdata = None # Actual input data (as a string)
+ self.lexpos = 0 # Current position in input text
+ self.lexlen = 0 # Length of the input text
+ self.lexerrorf = None # Error rule (if any)
+ self.lexeoff = None # EOF rule (if any)
+ self.lextokens = None # List of valid tokens
+ self.lexignore = '' # Ignored characters
+ self.lexliterals = '' # Literal characters that can be passed through
+ self.lexmodule = None # Module
+ self.lineno = 1 # Current line number
+ self.lexoptimize = False # Optimized mode
+
+ def clone(self, object=None):
+ c = copy.copy(self)
+
+ # If the object parameter has been supplied, it means we are attaching the
+ # lexer to a new object. In this case, we have to rebind all methods in
+ # the lexstatere and lexstateerrorf tables.
+
+ if object:
+ newtab = {}
+ for key, ritem in self.lexstatere.items():
+ newre = []
+ for cre, findex in ritem:
+ newfindex = []
+ for f in findex:
+ if not f or not f[0]:
+ newfindex.append(f)
+ continue
+ newfindex.append((getattr(object, f[0].__name__), f[1]))
+ newre.append((cre, newfindex))
+ newtab[key] = newre
+ c.lexstatere = newtab
+ c.lexstateerrorf = {}
+ for key, ef in self.lexstateerrorf.items():
+ c.lexstateerrorf[key] = getattr(object, ef.__name__)
+ c.lexmodule = object
+ return c
+
+ # ------------------------------------------------------------
+ # writetab() - Write lexer information to a table file
+ # ------------------------------------------------------------
+ def writetab(self, lextab, outputdir=''):
+ if isinstance(lextab, types.ModuleType):
+ raise IOError("Won't overwrite existing lextab module")
+ basetabmodule = lextab.split('.')[-1]
+ filename = os.path.join(outputdir, basetabmodule) + '.py'
+ with open(filename, 'w') as tf:
+ tf.write('# %s.py. This file automatically created by PLY (version %s). Don\'t edit!\n' % (basetabmodule, __version__))
+ tf.write('_tabversion = %s\n' % repr(__tabversion__))
+ tf.write('_lextokens = set(%s)\n' % repr(tuple(self.lextokens)))
+ tf.write('_lexreflags = %s\n' % repr(self.lexreflags))
+ tf.write('_lexliterals = %s\n' % repr(self.lexliterals))
+ tf.write('_lexstateinfo = %s\n' % repr(self.lexstateinfo))
+
+ # Rewrite the lexstatere table, replacing function objects with function names
+ tabre = {}
+ for statename, lre in self.lexstatere.items():
+ titem = []
+ for (pat, func), retext, renames in zip(lre, self.lexstateretext[statename], self.lexstaterenames[statename]):
+ titem.append((retext, _funcs_to_names(func, renames)))
+ tabre[statename] = titem
+
+ tf.write('_lexstatere = %s\n' % repr(tabre))
+ tf.write('_lexstateignore = %s\n' % repr(self.lexstateignore))
+
+ taberr = {}
+ for statename, ef in self.lexstateerrorf.items():
+ taberr[statename] = ef.__name__ if ef else None
+ tf.write('_lexstateerrorf = %s\n' % repr(taberr))
+
+ tabeof = {}
+ for statename, ef in self.lexstateeoff.items():
+ tabeof[statename] = ef.__name__ if ef else None
+ tf.write('_lexstateeoff = %s\n' % repr(tabeof))
+
+ # ------------------------------------------------------------
+ # readtab() - Read lexer information from a tab file
+ # ------------------------------------------------------------
+ def readtab(self, tabfile, fdict):
+ if isinstance(tabfile, types.ModuleType):
+ lextab = tabfile
+ else:
+ exec('import %s' % tabfile)
+ lextab = sys.modules[tabfile]
+
+ if getattr(lextab, '_tabversion', '0.0') != __tabversion__:
+ raise ImportError('Inconsistent PLY version')
+
+ self.lextokens = lextab._lextokens
+ self.lexreflags = lextab._lexreflags
+ self.lexliterals = lextab._lexliterals
+ self.lextokens_all = self.lextokens | set(self.lexliterals)
+ self.lexstateinfo = lextab._lexstateinfo
+ self.lexstateignore = lextab._lexstateignore
+ self.lexstatere = {}
+ self.lexstateretext = {}
+ for statename, lre in lextab._lexstatere.items():
+ titem = []
+ txtitem = []
+ for pat, func_name in lre:
+ titem.append((re.compile(pat, lextab._lexreflags), _names_to_funcs(func_name, fdict)))
+
+ self.lexstatere[statename] = titem
+ self.lexstateretext[statename] = txtitem
+
+ self.lexstateerrorf = {}
+ for statename, ef in lextab._lexstateerrorf.items():
+ self.lexstateerrorf[statename] = fdict[ef]
+
+ self.lexstateeoff = {}
+ for statename, ef in lextab._lexstateeoff.items():
+ self.lexstateeoff[statename] = fdict[ef]
+
+ self.begin('INITIAL')
+
+ # ------------------------------------------------------------
+ # input() - Push a new string into the lexer
+ # ------------------------------------------------------------
+ def input(self, s):
+ # Pull off the first character to see if s looks like a string
+ c = s[:1]
+ if not isinstance(c, StringTypes):
+ raise ValueError('Expected a string')
+ self.lexdata = s
+ self.lexpos = 0
+ self.lexlen = len(s)
+
+ # ------------------------------------------------------------
+ # begin() - Changes the lexing state
+ # ------------------------------------------------------------
+ def begin(self, state):
+ if state not in self.lexstatere:
+ raise ValueError('Undefined state')
+ self.lexre = self.lexstatere[state]
+ self.lexretext = self.lexstateretext[state]
+ self.lexignore = self.lexstateignore.get(state, '')
+ self.lexerrorf = self.lexstateerrorf.get(state, None)
+ self.lexeoff = self.lexstateeoff.get(state, None)
+ self.lexstate = state
+
+ # ------------------------------------------------------------
+ # push_state() - Changes the lexing state and saves old on stack
+ # ------------------------------------------------------------
+ def push_state(self, state):
+ self.lexstatestack.append(self.lexstate)
+ self.begin(state)
+
+ # ------------------------------------------------------------
+ # pop_state() - Restores the previous state
+ # ------------------------------------------------------------
+ def pop_state(self):
+ self.begin(self.lexstatestack.pop())
+
+ # ------------------------------------------------------------
+ # current_state() - Returns the current lexing state
+ # ------------------------------------------------------------
+ def current_state(self):
+ return self.lexstate
+
+ # ------------------------------------------------------------
+ # skip() - Skip ahead n characters
+ # ------------------------------------------------------------
+ def skip(self, n):
+ self.lexpos += n
+
+ # ------------------------------------------------------------
+ # token() - Return the next token from the Lexer
+ #
+ # Note: This function has been carefully implemented to be as fast
+ # as possible. Don't make changes unless you really know what
+ # you are doing
+ # ------------------------------------------------------------
+ def token(self):
+ # Make local copies of frequently referenced attributes
+ lexpos = self.lexpos
+ lexlen = self.lexlen
+ lexignore = self.lexignore
+ lexdata = self.lexdata
+
+ while lexpos < lexlen:
+ # This provides a short-circuit path for whitespace, tabs, and other ignored characters
+ if lexdata[lexpos] in lexignore:
+ lexpos += 1
+ continue
+
+ # Look for a regular expression match
+ for lexre, lexindexfunc in self.lexre:
+ m = lexre.match(lexdata, lexpos)
+ if not m:
+ continue
+
+ # Create a token for return
+ tok = LexToken()
+ tok.value = m.group()
+ tok.lineno = self.lineno
+ tok.lexpos = lexpos
+
+ i = m.lastindex
+ func, tok.type = lexindexfunc[i]
+
+ if not func:
+ # If no token type was set, it's an ignored token
+ if tok.type:
+ self.lexpos = m.end()
+ return tok
+ else:
+ lexpos = m.end()
+ break
+
+ lexpos = m.end()
+
+ # If token is processed by a function, call it
+
+ tok.lexer = self # Set additional attributes useful in token rules
+ self.lexmatch = m
+ self.lexpos = lexpos
+
+ newtok = func(tok)
+
+ # Every function must return a token; if it returns nothing, we just move on to the next token
+ if not newtok:
+ lexpos = self.lexpos # This is here in case user has updated lexpos.
+ lexignore = self.lexignore # This is here in case there was a state change
+ break
+
+ # Verify type of the token. If not in the token map, raise an error
+ if not self.lexoptimize:
+ if newtok.type not in self.lextokens_all:
+ raise LexError("%s:%d: Rule '%s' returned an unknown token type '%s'" % (
+ func.__code__.co_filename, func.__code__.co_firstlineno,
+ func.__name__, newtok.type), lexdata[lexpos:])
+
+ return newtok
+ else:
+ # No match, see if in literals
+ if lexdata[lexpos] in self.lexliterals:
+ tok = LexToken()
+ tok.value = lexdata[lexpos]
+ tok.lineno = self.lineno
+ tok.type = tok.value
+ tok.lexpos = lexpos
+ self.lexpos = lexpos + 1
+ return tok
+
+ # No match. Call t_error() if defined.
+ if self.lexerrorf:
+ tok = LexToken()
+ tok.value = self.lexdata[lexpos:]
+ tok.lineno = self.lineno
+ tok.type = 'error'
+ tok.lexer = self
+ tok.lexpos = lexpos
+ self.lexpos = lexpos
+ newtok = self.lexerrorf(tok)
+ if lexpos == self.lexpos:
+ # Error method didn't change text position at all. This is an error.
+ raise LexError("Scanning error. Illegal character '%s'" % (lexdata[lexpos]), lexdata[lexpos:])
+ lexpos = self.lexpos
+ if not newtok:
+ continue
+ return newtok
+
+ self.lexpos = lexpos
+ raise LexError("Illegal character '%s' at index %d" % (lexdata[lexpos], lexpos), lexdata[lexpos:])
+
+ if self.lexeoff:
+ tok = LexToken()
+ tok.type = 'eof'
+ tok.value = ''
+ tok.lineno = self.lineno
+ tok.lexpos = lexpos
+ tok.lexer = self
+ self.lexpos = lexpos
+ newtok = self.lexeoff(tok)
+ return newtok
+
+ self.lexpos = lexpos + 1
+ if self.lexdata is None:
+ raise RuntimeError('No input string given with input()')
+ return None
+
+ # Iterator interface
+ def __iter__(self):
+ return self
+
+ def next(self):
+ t = self.token()
+ if t is None:
+ raise StopIteration
+ return t
+
+ __next__ = next
+
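+# For illustration, the runtime interface is intentionally small; typical use
+# of a built lexer is
+#
+#     lexer.input(data)
+#     for tok in lexer:
+#         print(tok.type, tok.value, tok.lineno, tok.lexpos)
+#
+# where iteration simply calls token() until it returns None.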
+# -----------------------------------------------------------------------------
+# ==== Lex Builder ===
+#
+# The functions and classes below are used to collect lexing information
+# and build a Lexer object from it.
+# -----------------------------------------------------------------------------
+
+# -----------------------------------------------------------------------------
+# _get_regex(func)
+#
+# Returns the regular expression assigned to a function either as a doc string
+# or as a .regex attribute attached by the @TOKEN decorator.
+# -----------------------------------------------------------------------------
+def _get_regex(func):
+ return getattr(func, 'regex', func.__doc__)
+
+# -----------------------------------------------------------------------------
+# get_caller_module_dict()
+#
+# This function returns a dictionary containing all of the symbols defined within
+# a caller further down the call stack. This is used to get the environment
+# associated with the lex() call if none was provided.
+# -----------------------------------------------------------------------------
+def get_caller_module_dict(levels):
+ f = sys._getframe(levels)
+ ldict = f.f_globals.copy()
+ if f.f_globals != f.f_locals:
+ ldict.update(f.f_locals)
+ return ldict
+
+# -----------------------------------------------------------------------------
+# _funcs_to_names()
+#
+# Given a list of regular expression functions, this converts it to a list
+# suitable for output to a table file
+# -----------------------------------------------------------------------------
+def _funcs_to_names(funclist, namelist):
+ result = []
+ for f, name in zip(funclist, namelist):
+ if f and f[0]:
+ result.append((name, f[1]))
+ else:
+ result.append(f)
+ return result
+
+# -----------------------------------------------------------------------------
+# _names_to_funcs()
+#
+# Given a list of regular expression function names, this converts it back to
+# functions.
+# -----------------------------------------------------------------------------
+def _names_to_funcs(namelist, fdict):
+ result = []
+ for n in namelist:
+ if n and n[0]:
+ result.append((fdict[n[0]], n[1]))
+ else:
+ result.append(n)
+ return result
+
+# -----------------------------------------------------------------------------
+# _form_master_re()
+#
+# This function takes a list of all of the regex components and attempts to
+# form the master regular expression. Given limitations in the Python re
+# module, it may be necessary to break the master regex into separate expressions.
+# -----------------------------------------------------------------------------
+def _form_master_re(relist, reflags, ldict, toknames):
+ if not relist:
+ return []
+ regex = '|'.join(relist)
+ try:
+ lexre = re.compile(regex, reflags)
+
+ # Build the index to function map for the matching engine
+ lexindexfunc = [None] * (max(lexre.groupindex.values()) + 1)
+ lexindexnames = lexindexfunc[:]
+
+ for f, i in lexre.groupindex.items():
+ handle = ldict.get(f, None)
+ if type(handle) in (types.FunctionType, types.MethodType):
+ lexindexfunc[i] = (handle, toknames[f])
+ lexindexnames[i] = f
+ elif handle is not None:
+ lexindexnames[i] = f
+ if f.find('ignore_') > 0:
+ lexindexfunc[i] = (None, None)
+ else:
+ lexindexfunc[i] = (None, toknames[f])
+
+ return [(lexre, lexindexfunc)], [regex], [lexindexnames]
+ except Exception:
+ m = int(len(relist)/2)
+ if m == 0:
+ m = 1
+ llist, lre, lnames = _form_master_re(relist[:m], reflags, ldict, toknames)
+ rlist, rre, rnames = _form_master_re(relist[m:], reflags, ldict, toknames)
+ return (llist+rlist), (lre+rre), (lnames+rnames)
+
+# -----------------------------------------------------------------------------
+# def _statetoken(s,names)
+#
+# Given a declaration name s of the form "t_" and a dictionary whose keys are
+# state names, this function returns a tuple (states,tokenname) where states
+# is a tuple of state names and tokenname is the name of the token. For example,
+# calling this with s = "t_foo_bar_SPAM" might return (('foo','bar'),'SPAM')
+# -----------------------------------------------------------------------------
+def _statetoken(s, names):
+ nonstate = 1
+ parts = s.split('_')
+ for i, part in enumerate(parts[1:], 1):
+ if part not in names and part != 'ANY':
+ break
+
+ if i > 1:
+ states = tuple(parts[1:i])
+ else:
+ states = ('INITIAL',)
+
+ if 'ANY' in states:
+ states = tuple(names)
+
+ tokenname = '_'.join(parts[i:])
+ return (states, tokenname)
+
+
+# -----------------------------------------------------------------------------
+# LexerReflect()
+#
+# This class represents information needed to build a lexer as extracted from a
+# user's input file.
+# -----------------------------------------------------------------------------
+class LexerReflect(object):
+ def __init__(self, ldict, log=None, reflags=0):
+ self.ldict = ldict
+ self.error_func = None
+ self.tokens = []
+ self.reflags = reflags
+ self.stateinfo = {'INITIAL': 'inclusive'}
+ self.modules = set()
+ self.error = False
+ self.log = PlyLogger(sys.stderr) if log is None else log
+
+ # Get all of the basic information
+ def get_all(self):
+ self.get_tokens()
+ self.get_literals()
+ self.get_states()
+ self.get_rules()
+
+ # Validate all of the information
+ def validate_all(self):
+ self.validate_tokens()
+ self.validate_literals()
+ self.validate_rules()
+ return self.error
+
+ # Get the tokens map
+ def get_tokens(self):
+ tokens = self.ldict.get('tokens', None)
+ if not tokens:
+ self.log.error('No token list is defined')
+ self.error = True
+ return
+
+ if not isinstance(tokens, (list, tuple)):
+ self.log.error('tokens must be a list or tuple')
+ self.error = True
+ return
+
+ if not tokens:
+ self.log.error('tokens is empty')
+ self.error = True
+ return
+
+ self.tokens = tokens
+
+ # Validate the tokens
+ def validate_tokens(self):
+ terminals = {}
+ for n in self.tokens:
+ if not _is_identifier.match(n):
+ self.log.error("Bad token name '%s'", n)
+ self.error = True
+ if n in terminals:
+ self.log.warning("Token '%s' multiply defined", n)
+ terminals[n] = 1
+
+ # Get the literals specifier
+ def get_literals(self):
+ self.literals = self.ldict.get('literals', '')
+ if not self.literals:
+ self.literals = ''
+
+ # Validate literals
+ def validate_literals(self):
+ try:
+ for c in self.literals:
+ if not isinstance(c, StringTypes) or len(c) > 1:
+ self.log.error('Invalid literal %s. Must be a single character', repr(c))
+ self.error = True
+
+ except TypeError:
+ self.log.error('Invalid literals specification. literals must be a sequence of characters')
+ self.error = True
+
+ def get_states(self):
+ self.states = self.ldict.get('states', None)
+ # Build statemap
+ if self.states:
+ if not isinstance(self.states, (tuple, list)):
+ self.log.error('states must be defined as a tuple or list')
+ self.error = True
+ else:
+ for s in self.states:
+ if not isinstance(s, tuple) or len(s) != 2:
+ self.log.error("Invalid state specifier %s. Must be a tuple (statename,'exclusive|inclusive')", repr(s))
+ self.error = True
+ continue
+ name, statetype = s
+ if not isinstance(name, StringTypes):
+ self.log.error('State name %s must be a string', repr(name))
+ self.error = True
+ continue
+ if not (statetype == 'inclusive' or statetype == 'exclusive'):
+ self.log.error("State type for state %s must be 'inclusive' or 'exclusive'", name)
+ self.error = True
+ continue
+ if name in self.stateinfo:
+ self.log.error("State '%s' already defined", name)
+ self.error = True
+ continue
+ self.stateinfo[name] = statetype
+
+ # Get all of the symbols with a t_ prefix and sort them into various
+ # categories (functions, strings, error functions, and ignore characters)
+
+ def get_rules(self):
+ tsymbols = [f for f in self.ldict if f[:2] == 't_']
+
+ # Now build up a list of functions and a list of strings
+ self.toknames = {} # Mapping of symbols to token names
+ self.funcsym = {} # Symbols defined as functions
+ self.strsym = {} # Symbols defined as strings
+ self.ignore = {} # Ignore strings by state
+ self.errorf = {} # Error functions by state
+ self.eoff = {} # EOF functions by state
+
+ for s in self.stateinfo:
+ self.funcsym[s] = []
+ self.strsym[s] = []
+
+ if len(tsymbols) == 0:
+ self.log.error('No rules of the form t_rulename are defined')
+ self.error = True
+ return
+
+ for f in tsymbols:
+ t = self.ldict[f]
+ states, tokname = _statetoken(f, self.stateinfo)
+ self.toknames[f] = tokname
+
+ if hasattr(t, '__call__'):
+ if tokname == 'error':
+ for s in states:
+ self.errorf[s] = t
+ elif tokname == 'eof':
+ for s in states:
+ self.eoff[s] = t
+ elif tokname == 'ignore':
+ line = t.__code__.co_firstlineno
+ file = t.__code__.co_filename
+ self.log.error("%s:%d: Rule '%s' must be defined as a string", file, line, t.__name__)
+ self.error = True
+ else:
+ for s in states:
+ self.funcsym[s].append((f, t))
+ elif isinstance(t, StringTypes):
+ if tokname == 'ignore':
+ for s in states:
+ self.ignore[s] = t
+ if '\\' in t:
+ self.log.warning("%s contains a literal backslash '\\'", f)
+
+ elif tokname == 'error':
+ self.log.error("Rule '%s' must be defined as a function", f)
+ self.error = True
+ else:
+ for s in states:
+ self.strsym[s].append((f, t))
+ else:
+ self.log.error('%s not defined as a function or string', f)
+ self.error = True
+
+ # Sort the functions by line number
+ for f in self.funcsym.values():
+ f.sort(key=lambda x: x[1].__code__.co_firstlineno)
+
+ # Sort the strings by regular expression length
+ for s in self.strsym.values():
+ s.sort(key=lambda x: len(x[1]), reverse=True)
+
+ # Validate all of the t_rules collected
+ def validate_rules(self):
+ for state in self.stateinfo:
+ # Validate all rules defined by functions
+
+ for fname, f in self.funcsym[state]:
+ line = f.__code__.co_firstlineno
+ file = f.__code__.co_filename
+ module = inspect.getmodule(f)
+ self.modules.add(module)
+
+ tokname = self.toknames[fname]
+ if isinstance(f, types.MethodType):
+ reqargs = 2
+ else:
+ reqargs = 1
+ nargs = f.__code__.co_argcount
+ if nargs > reqargs:
+ self.log.error("%s:%d: Rule '%s' has too many arguments", file, line, f.__name__)
+ self.error = True
+ continue
+
+ if nargs < reqargs:
+ self.log.error("%s:%d: Rule '%s' requires an argument", file, line, f.__name__)
+ self.error = True
+ continue
+
+ if not _get_regex(f):
+ self.log.error("%s:%d: No regular expression defined for rule '%s'", file, line, f.__name__)
+ self.error = True
+ continue
+
+ try:
+ c = re.compile('(?P<%s>%s)' % (fname, _get_regex(f)), self.reflags)
+ if c.match(''):
+ self.log.error("%s:%d: Regular expression for rule '%s' matches empty string", file, line, f.__name__)
+ self.error = True
+ except re.error as e:
+ self.log.error("%s:%d: Invalid regular expression for rule '%s'. %s", file, line, f.__name__, e)
+ if '#' in _get_regex(f):
+ self.log.error("%s:%d. Make sure '#' in rule '%s' is escaped with '\\#'", file, line, f.__name__)
+ self.error = True
+
+ # Validate all rules defined by strings
+ for name, r in self.strsym[state]:
+ tokname = self.toknames[name]
+ if tokname == 'error':
+ self.log.error("Rule '%s' must be defined as a function", name)
+ self.error = True
+ continue
+
+ if tokname not in self.tokens and tokname.find('ignore_') < 0:
+ self.log.error("Rule '%s' defined for an unspecified token %s", name, tokname)
+ self.error = True
+ continue
+
+ try:
+ c = re.compile('(?P<%s>%s)' % (name, r), self.reflags)
+ if (c.match('')):
+ self.log.error("Regular expression for rule '%s' matches empty string", name)
+ self.error = True
+ except re.error as e:
+ self.log.error("Invalid regular expression for rule '%s'. %s", name, e)
+ if '#' in r:
+ self.log.error("Make sure '#' in rule '%s' is escaped with '\\#'", name)
+ self.error = True
+
+ if not self.funcsym[state] and not self.strsym[state]:
+ self.log.error("No rules defined for state '%s'", state)
+ self.error = True
+
+ # Validate the error function
+ efunc = self.errorf.get(state, None)
+ if efunc:
+ f = efunc
+ line = f.__code__.co_firstlineno
+ file = f.__code__.co_filename
+ module = inspect.getmodule(f)
+ self.modules.add(module)
+
+ if isinstance(f, types.MethodType):
+ reqargs = 2
+ else:
+ reqargs = 1
+ nargs = f.__code__.co_argcount
+ if nargs > reqargs:
+ self.log.error("%s:%d: Rule '%s' has too many arguments", file, line, f.__name__)
+ self.error = True
+
+ if nargs < reqargs:
+ self.log.error("%s:%d: Rule '%s' requires an argument", file, line, f.__name__)
+ self.error = True
+
+ for module in self.modules:
+ self.validate_module(module)
+
+ # -----------------------------------------------------------------------------
+ # validate_module()
+ #
+ # This checks to see if there are duplicated t_rulename() functions or strings
+    # in the lexer specification. This is done using a simple regular expression
+ # match on each line in the source code of the given module.
+ # -----------------------------------------------------------------------------
+
+ def validate_module(self, module):
+ try:
+ lines, linen = inspect.getsourcelines(module)
+ except IOError:
+ return
+
+ fre = re.compile(r'\s*def\s+(t_[a-zA-Z_0-9]*)\(')
+ sre = re.compile(r'\s*(t_[a-zA-Z_0-9]*)\s*=')
+
+ counthash = {}
+ linen += 1
+ for line in lines:
+ m = fre.match(line)
+ if not m:
+ m = sre.match(line)
+ if m:
+ name = m.group(1)
+ prev = counthash.get(name)
+ if not prev:
+ counthash[name] = linen
+ else:
+ filename = inspect.getsourcefile(module)
+ self.log.error('%s:%d: Rule %s redefined. Previously defined on line %d', filename, linen, name, prev)
+ self.error = True
+ linen += 1
+
+# -----------------------------------------------------------------------------
+# lex(module)
+#
+# Build all of the regular expression rules from definitions in the supplied module
+# -----------------------------------------------------------------------------
+def lex(module=None, object=None, debug=False, optimize=False, lextab='lextab',
+ reflags=int(re.VERBOSE), nowarn=False, outputdir=None, debuglog=None, errorlog=None):
+
+ if lextab is None:
+ lextab = 'lextab'
+
+ global lexer
+
+ ldict = None
+ stateinfo = {'INITIAL': 'inclusive'}
+ lexobj = Lexer()
+ lexobj.lexoptimize = optimize
+ global token, input
+
+ if errorlog is None:
+ errorlog = PlyLogger(sys.stderr)
+
+ if debug:
+ if debuglog is None:
+ debuglog = PlyLogger(sys.stderr)
+
+ # Get the module dictionary used for the lexer
+ if object:
+ module = object
+
+    # Get the module dictionary used for the lexer rules
+ if module:
+ _items = [(k, getattr(module, k)) for k in dir(module)]
+ ldict = dict(_items)
+ # If no __file__ attribute is available, try to obtain it from the __module__ instead
+ if '__file__' not in ldict:
+ ldict['__file__'] = sys.modules[ldict['__module__']].__file__
+ else:
+ ldict = get_caller_module_dict(2)
+
+    # Determine if the module defining the lexer is part of a package.
+    # If so, fix the lextab setting so that the table module loads correctly
+ pkg = ldict.get('__package__')
+ if pkg and isinstance(lextab, str):
+ if '.' not in lextab:
+ lextab = pkg + '.' + lextab
+
+ # Collect parser information from the dictionary
+ linfo = LexerReflect(ldict, log=errorlog, reflags=reflags)
+ linfo.get_all()
+ if not optimize:
+ if linfo.validate_all():
+ raise SyntaxError("Can't build lexer")
+
+ if optimize and lextab:
+ try:
+ lexobj.readtab(lextab, ldict)
+ token = lexobj.token
+ input = lexobj.input
+ lexer = lexobj
+ return lexobj
+
+ except ImportError:
+ pass
+
+ # Dump some basic debugging information
+ if debug:
+ debuglog.info('lex: tokens = %r', linfo.tokens)
+ debuglog.info('lex: literals = %r', linfo.literals)
+ debuglog.info('lex: states = %r', linfo.stateinfo)
+
+ # Build a dictionary of valid token names
+ lexobj.lextokens = set()
+ for n in linfo.tokens:
+ lexobj.lextokens.add(n)
+
+ # Get literals specification
+ if isinstance(linfo.literals, (list, tuple)):
+ lexobj.lexliterals = type(linfo.literals[0])().join(linfo.literals)
+ else:
+ lexobj.lexliterals = linfo.literals
+
+ lexobj.lextokens_all = lexobj.lextokens | set(lexobj.lexliterals)
+
+ # Get the stateinfo dictionary
+ stateinfo = linfo.stateinfo
+
+ regexs = {}
+ # Build the master regular expressions
+ for state in stateinfo:
+ regex_list = []
+
+ # Add rules defined by functions first
+ for fname, f in linfo.funcsym[state]:
+ line = f.__code__.co_firstlineno
+ file = f.__code__.co_filename
+ regex_list.append('(?P<%s>%s)' % (fname, _get_regex(f)))
+ if debug:
+ debuglog.info("lex: Adding rule %s -> '%s' (state '%s')", fname, _get_regex(f), state)
+
+ # Now add all of the simple rules
+ for name, r in linfo.strsym[state]:
+ regex_list.append('(?P<%s>%s)' % (name, r))
+ if debug:
+ debuglog.info("lex: Adding rule %s -> '%s' (state '%s')", name, r, state)
+
+ regexs[state] = regex_list
+
+ # Build the master regular expressions
+
+ if debug:
+ debuglog.info('lex: ==== MASTER REGEXS FOLLOW ====')
+
+ for state in regexs:
+ lexre, re_text, re_names = _form_master_re(regexs[state], reflags, ldict, linfo.toknames)
+ lexobj.lexstatere[state] = lexre
+ lexobj.lexstateretext[state] = re_text
+ lexobj.lexstaterenames[state] = re_names
+ if debug:
+ for i, text in enumerate(re_text):
+ debuglog.info("lex: state '%s' : regex[%d] = '%s'", state, i, text)
+
+ # For inclusive states, we need to add the regular expressions from the INITIAL state
+ for state, stype in stateinfo.items():
+ if state != 'INITIAL' and stype == 'inclusive':
+ lexobj.lexstatere[state].extend(lexobj.lexstatere['INITIAL'])
+ lexobj.lexstateretext[state].extend(lexobj.lexstateretext['INITIAL'])
+ lexobj.lexstaterenames[state].extend(lexobj.lexstaterenames['INITIAL'])
+
+ lexobj.lexstateinfo = stateinfo
+ lexobj.lexre = lexobj.lexstatere['INITIAL']
+ lexobj.lexretext = lexobj.lexstateretext['INITIAL']
+ lexobj.lexreflags = reflags
+
+ # Set up ignore variables
+ lexobj.lexstateignore = linfo.ignore
+ lexobj.lexignore = lexobj.lexstateignore.get('INITIAL', '')
+
+ # Set up error functions
+ lexobj.lexstateerrorf = linfo.errorf
+ lexobj.lexerrorf = linfo.errorf.get('INITIAL', None)
+ if not lexobj.lexerrorf:
+ errorlog.warning('No t_error rule is defined')
+
+ # Set up eof functions
+ lexobj.lexstateeoff = linfo.eoff
+ lexobj.lexeoff = linfo.eoff.get('INITIAL', None)
+
+ # Check state information for ignore and error rules
+ for s, stype in stateinfo.items():
+ if stype == 'exclusive':
+ if s not in linfo.errorf:
+ errorlog.warning("No error rule is defined for exclusive state '%s'", s)
+ if s not in linfo.ignore and lexobj.lexignore:
+ errorlog.warning("No ignore rule is defined for exclusive state '%s'", s)
+ elif stype == 'inclusive':
+ if s not in linfo.errorf:
+ linfo.errorf[s] = linfo.errorf.get('INITIAL', None)
+ if s not in linfo.ignore:
+ linfo.ignore[s] = linfo.ignore.get('INITIAL', '')
+
+ # Create global versions of the token() and input() functions
+ token = lexobj.token
+ input = lexobj.input
+ lexer = lexobj
+
+ # If in optimize mode, we write the lextab
+ if lextab and optimize:
+ if outputdir is None:
+ # If no output directory is set, the location of the output files
+ # is determined according to the following rules:
+ # - If lextab specifies a package, files go into that package directory
+ # - Otherwise, files go in the same directory as the specifying module
+ if isinstance(lextab, types.ModuleType):
+ srcfile = lextab.__file__
+ else:
+ if '.' not in lextab:
+ srcfile = ldict['__file__']
+ else:
+ parts = lextab.split('.')
+ pkgname = '.'.join(parts[:-1])
+ exec('import %s' % pkgname)
+ srcfile = getattr(sys.modules[pkgname], '__file__', '')
+ outputdir = os.path.dirname(srcfile)
+ try:
+ lexobj.writetab(lextab, outputdir)
+ except IOError as e:
+ errorlog.warning("Couldn't write lextab module %r. %s" % (lextab, e))
+
+ return lexobj
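+
+# Example (illustrative only, not part of the original sources): a minimal
+# lexer specification of the kind lex() reflects over.  The token names and
+# the sample input below are hypothetical, and the import path depends on how
+# this copy of PLY is packaged.
+#
+#     import ply.lex as lex
+#
+#     tokens = ('NUMBER', 'PLUS')
+#
+#     t_PLUS   = r'\+'
+#     t_ignore = ' \t'
+#
+#     def t_NUMBER(t):
+#         r'\d+'
+#         t.value = int(t.value)
+#         return t
+#
+#     def t_error(t):
+#         print("Illegal character %r" % t.value[0])
+#         t.lexer.skip(1)
+#
+#     lexer = lex.lex()
+#     lexer.input('3 + 4')
+#     for tok in iter(lexer.token, None):
+#         print(tok.type, tok.value)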
+
+# -----------------------------------------------------------------------------
+# runmain()
+#
+# This runs the lexer as a main program
+# -----------------------------------------------------------------------------
+
+def runmain(lexer=None, data=None):
+ if not data:
+ try:
+ filename = sys.argv[1]
+ f = open(filename)
+ data = f.read()
+ f.close()
+ except IndexError:
+ sys.stdout.write('Reading from standard input (type EOF to end):\n')
+ data = sys.stdin.read()
+
+ if lexer:
+ _input = lexer.input
+ else:
+ _input = input
+ _input(data)
+ if lexer:
+ _token = lexer.token
+ else:
+ _token = token
+
+ while True:
+ tok = _token()
+ if not tok:
+ break
+ sys.stdout.write('(%s,%r,%d,%d)\n' % (tok.type, tok.value, tok.lineno, tok.lexpos))
+
+# -----------------------------------------------------------------------------
+# @TOKEN(regex)
+#
+# This decorator can be used to attach a regular expression to a rule function
+# when its docstring cannot be used for that purpose (for example, when the
+# regex is built at runtime)
+# -----------------------------------------------------------------------------
+
+def TOKEN(r):
+ def set_regex(f):
+ if hasattr(r, '__call__'):
+ f.regex = _get_regex(r)
+ else:
+ f.regex = r
+ return f
+ return set_regex
+
+# Alternative spelling of the TOKEN decorator
+Token = TOKEN
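+
+# Example (illustrative only, not part of the original sources): @TOKEN lets a
+# rule use a regular expression that is assembled at runtime, where a literal
+# docstring would not work.  The 'identifier' pattern and 'reserved' dict
+# below are hypothetical.
+#
+#     identifier = r'[a-zA-Z_][a-zA-Z_0-9]*'
+#
+#     @TOKEN(identifier)
+#     def t_ID(t):
+#         t.type = reserved.get(t.value, 'ID')
+#         return t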
diff --git a/lib/pycparser/ply/yacc.py b/lib/pycparser/ply/yacc.py
new file mode 100644
index 0000000..20b4f28
--- /dev/null
+++ b/lib/pycparser/ply/yacc.py
@@ -0,0 +1,3494 @@
+# -----------------------------------------------------------------------------
+# ply: yacc.py
+#
+# Copyright (C) 2001-2017
+# David M. Beazley (Dabeaz LLC)
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * Neither the name of the David Beazley or Dabeaz LLC may be used to
+# endorse or promote products derived from this software without
+# specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# -----------------------------------------------------------------------------
+#
+# This implements an LR parser that is constructed from grammar rules defined
+# as Python functions. The grammar is specified by supplying the BNF inside
+# Python documentation strings. The inspiration for this technique was borrowed
+# from John Aycock's Spark parsing system. PLY might be viewed as a cross between
+# Spark and the GNU bison utility.
+#
+# The current implementation is only somewhat object-oriented. The
+# LR parser itself is defined in terms of an object (which allows multiple
+# parsers to co-exist). However, most of the variables used during table
+# construction are defined in terms of global variables. Users shouldn't
+# notice unless they are trying to define multiple parsers at the same
+# time using threads (in which case they should have their head examined).
+#
+# This implementation supports both SLR and LALR(1) parsing. LALR(1)
+# support was originally implemented by Elias Ioup (ezioup@alumni.uchicago.edu),
+# using the algorithm found in Aho, Sethi, and Ullman "Compilers: Principles,
+# Techniques, and Tools" (The Dragon Book). LALR(1) has since been replaced
+# by the more efficient DeRemer and Pennello algorithm.
+#
+# :::::::: WARNING :::::::
+#
+# Construction of LR parsing tables is fairly complicated and expensive.
+# To make this module run fast, a *LOT* of work has been put into
+# optimization---often at the expense of readability and what might be
+# considered good Python "coding style." Modify the code at your
+# own risk!
+# ----------------------------------------------------------------------------
+
+import re
+import types
+import sys
+import os.path
+import inspect
+import base64
+import warnings
+
+__version__ = '3.10'
+__tabversion__ = '3.10'
+
+#-----------------------------------------------------------------------------
+# === User configurable parameters ===
+#
+# Change these to modify the default behavior of yacc (if you wish)
+#-----------------------------------------------------------------------------
+
+yaccdebug = True # Debugging mode. If set, yacc generates a
+                             # 'parser.out' file in the current directory
+
+debug_file = 'parser.out' # Default name of the debugging file
+tab_module = 'parsetab' # Default name of the table module
+default_lr = 'LALR' # Default LR table generation method
+
+error_count = 3 # Number of symbols that must be shifted to leave recovery mode
+
+yaccdevel = False # Set to True if developing yacc. This turns off optimized
+ # implementations of certain functions.
+
+resultlimit = 40 # Size limit of results when running in debug mode.
+
+pickle_protocol = 0 # Protocol to use when writing pickle files
+
+# String type-checking compatibility
+if sys.version_info[0] < 3:
+ string_types = basestring
+else:
+ string_types = str
+
+MAXINT = sys.maxsize
+
+# This object is a stand-in for a logging object created by the
+# logging module. PLY will use this by default to create things
+# such as the parser.out file. If a user wants more detailed
+# information, they can create their own logging object and pass
+# it into PLY.
+
+class PlyLogger(object):
+ def __init__(self, f):
+ self.f = f
+
+ def debug(self, msg, *args, **kwargs):
+ self.f.write((msg % args) + '\n')
+
+ info = debug
+
+ def warning(self, msg, *args, **kwargs):
+ self.f.write('WARNING: ' + (msg % args) + '\n')
+
+ def error(self, msg, *args, **kwargs):
+ self.f.write('ERROR: ' + (msg % args) + '\n')
+
+ critical = debug
+
+# Null logger is used when no output is generated. Does nothing.
+class NullLogger(object):
+ def __getattribute__(self, name):
+ return self
+
+ def __call__(self, *args, **kwargs):
+ return self
+
+# Exception raised for yacc-related errors
+class YaccError(Exception):
+ pass
+
+# Format the result message that the parser produces when running in debug mode.
+def format_result(r):
+ repr_str = repr(r)
+ if '\n' in repr_str:
+ repr_str = repr(repr_str)
+ if len(repr_str) > resultlimit:
+ repr_str = repr_str[:resultlimit] + ' ...'
+ result = '<%s @ 0x%x> (%s)' % (type(r).__name__, id(r), repr_str)
+ return result
+
+# Format stack entries when the parser is running in debug mode
+def format_stack_entry(r):
+ repr_str = repr(r)
+ if '\n' in repr_str:
+ repr_str = repr(repr_str)
+ if len(repr_str) < 16:
+ return repr_str
+ else:
+ return '<%s @ 0x%x>' % (type(r).__name__, id(r))
+
+# Panic mode error recovery support. This feature is being reworked--much of the
+# code here is to offer a deprecation/backwards compatible transition
+
+_errok = None
+_token = None
+_restart = None
+_warnmsg = '''PLY: Don't use global functions errok(), token(), and restart() in p_error().
+Instead, invoke the methods on the associated parser instance:
+
+ def p_error(p):
+ ...
+ # Use parser.errok(), parser.token(), parser.restart()
+ ...
+
+ parser = yacc.yacc()
+'''
+
+def errok():
+ warnings.warn(_warnmsg)
+ return _errok()
+
+def restart():
+ warnings.warn(_warnmsg)
+ return _restart()
+
+def token():
+ warnings.warn(_warnmsg)
+ return _token()
+
+# Utility function to call the p_error() function with some deprecation hacks
+def call_errorfunc(errorfunc, token, parser):
+ global _errok, _token, _restart
+ _errok = parser.errok
+ _token = parser.token
+ _restart = parser.restart
+ r = errorfunc(token)
+ try:
+ del _errok, _token, _restart
+ except NameError:
+ pass
+ return r
+
+#-----------------------------------------------------------------------------
+# === LR Parsing Engine ===
+#
+# The following classes are used for the LR parser itself. These are not
+# used during table construction and are independent of the actual LR
+# table generation algorithm
+#-----------------------------------------------------------------------------
+
+# This class is used to hold non-terminal grammar symbols during parsing.
+# It normally has the following attributes set:
+# .type = Grammar symbol type
+# .value = Symbol value
+# .lineno = Starting line number
+# .endlineno = Ending line number (optional, set automatically)
+# .lexpos = Starting lex position
+# .endlexpos = Ending lex position (optional, set automatically)
+
+class YaccSymbol:
+ def __str__(self):
+ return self.type
+
+ def __repr__(self):
+ return str(self)
+
+# This class is a wrapper around the objects actually passed to each
+# grammar rule. Index lookup returns (and index assignment sets) the
+# .value attribute of the underlying YaccSymbol object.
+# The lineno() method returns the line number of a given
+# item (or 0 if not defined). The linespan() method returns
+# a tuple of (startline,endline) representing the range of lines
+# for a symbol. The lexspan() method returns a tuple (lexpos,endlexpos)
+# representing the range of positional information for a symbol.
+
+class YaccProduction:
+ def __init__(self, s, stack=None):
+ self.slice = s
+ self.stack = stack
+ self.lexer = None
+ self.parser = None
+
+ def __getitem__(self, n):
+ if isinstance(n, slice):
+ return [s.value for s in self.slice[n]]
+ elif n >= 0:
+ return self.slice[n].value
+ else:
+ return self.stack[n].value
+
+ def __setitem__(self, n, v):
+ self.slice[n].value = v
+
+ def __getslice__(self, i, j):
+ return [s.value for s in self.slice[i:j]]
+
+ def __len__(self):
+ return len(self.slice)
+
+ def lineno(self, n):
+ return getattr(self.slice[n], 'lineno', 0)
+
+ def set_lineno(self, n, lineno):
+ self.slice[n].lineno = lineno
+
+ def linespan(self, n):
+ startline = getattr(self.slice[n], 'lineno', 0)
+ endline = getattr(self.slice[n], 'endlineno', startline)
+ return startline, endline
+
+ def lexpos(self, n):
+ return getattr(self.slice[n], 'lexpos', 0)
+
+ def lexspan(self, n):
+ startpos = getattr(self.slice[n], 'lexpos', 0)
+ endpos = getattr(self.slice[n], 'endlexpos', startpos)
+ return startpos, endpos
+
+ def error(self):
+ raise SyntaxError
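+
+# Example (illustrative only): how a grammar rule typically uses the
+# YaccProduction object described above.  The rule itself is hypothetical.
+#
+#     def p_expr_plus(p):
+#         'expr : expr PLUS term'
+#         p[0] = p[1] + p[3]            # read/write the .value of each symbol
+#         p.set_lineno(0, p.lineno(1))  # propagate position information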
+
+# -----------------------------------------------------------------------------
+# == LRParser ==
+#
+# The LR Parsing engine.
+# -----------------------------------------------------------------------------
+
+class LRParser:
+ def __init__(self, lrtab, errorf):
+ self.productions = lrtab.lr_productions
+ self.action = lrtab.lr_action
+ self.goto = lrtab.lr_goto
+ self.errorfunc = errorf
+ self.set_defaulted_states()
+ self.errorok = True
+
+ def errok(self):
+ self.errorok = True
+
+ def restart(self):
+ del self.statestack[:]
+ del self.symstack[:]
+ sym = YaccSymbol()
+ sym.type = '$end'
+ self.symstack.append(sym)
+ self.statestack.append(0)
+
+ # Defaulted state support.
+ # This method identifies parser states where there is only one possible reduction action.
+    # For such states, the parser can choose to make a rule reduction without consuming
+ # the next look-ahead token. This delayed invocation of the tokenizer can be useful in
+ # certain kinds of advanced parsing situations where the lexer and parser interact with
+ # each other or change states (i.e., manipulation of scope, lexer states, etc.).
+ #
+ # See: https://www.gnu.org/software/bison/manual/html_node/Default-Reductions.html#Default-Reductions
+ def set_defaulted_states(self):
+ self.defaulted_states = {}
+ for state, actions in self.action.items():
+ rules = list(actions.values())
+ if len(rules) == 1 and rules[0] < 0:
+ self.defaulted_states[state] = rules[0]
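+
+    # For example (illustrative only): if state 7 had the single action -3 in
+    # the action table, defaulted_states would map 7 -> -3, meaning "always
+    # reduce by rule 3 in state 7 without reading the next token".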
+
+ def disable_defaulted_states(self):
+ self.defaulted_states = {}
+
+ def parse(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None):
+ if debug or yaccdevel:
+ if isinstance(debug, int):
+ debug = PlyLogger(sys.stderr)
+ return self.parsedebug(input, lexer, debug, tracking, tokenfunc)
+ elif tracking:
+ return self.parseopt(input, lexer, debug, tracking, tokenfunc)
+ else:
+ return self.parseopt_notrack(input, lexer, debug, tracking, tokenfunc)
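+
+    # Illustrative call (hypothetical parser, lexer, and data objects):
+    #
+    #     result = parser.parse(data, lexer=lexer, debug=True, tracking=True)
+    #
+    # A true debug flag (or yaccdevel) routes the call through parsedebug(),
+    # tracking alone selects parseopt(), and the default path uses the fastest
+    # parseopt_notrack() variant.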
+
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # parsedebug().
+ #
+ # This is the debugging enabled version of parse(). All changes made to the
+ # parsing engine should be made here. Optimized versions of this function
+ # are automatically created by the ply/ygen.py script. This script cuts out
+ # sections enclosed in markers such as this:
+ #
+ # #--! DEBUG
+ # statements
+ # #--! DEBUG
+ #
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ def parsedebug(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None):
+ #--! parsedebug-start
+ lookahead = None # Current lookahead symbol
+ lookaheadstack = [] # Stack of lookahead symbols
+ actions = self.action # Local reference to action table (to avoid lookup on self.)
+ goto = self.goto # Local reference to goto table (to avoid lookup on self.)
+ prod = self.productions # Local reference to production list (to avoid lookup on self.)
+ defaulted_states = self.defaulted_states # Local reference to defaulted states
+ pslice = YaccProduction(None) # Production object passed to grammar rules
+ errorcount = 0 # Used during error recovery
+
+ #--! DEBUG
+ debug.info('PLY: PARSE DEBUG START')
+ #--! DEBUG
+
+ # If no lexer was given, we will try to use the lex module
+ if not lexer:
+ from . import lex
+ lexer = lex.lexer
+
+ # Set up the lexer and parser objects on pslice
+ pslice.lexer = lexer
+ pslice.parser = self
+
+ # If input was supplied, pass to lexer
+ if input is not None:
+ lexer.input(input)
+
+ if tokenfunc is None:
+ # Tokenize function
+ get_token = lexer.token
+ else:
+ get_token = tokenfunc
+
+ # Set the parser() token method (sometimes used in error recovery)
+ self.token = get_token
+
+ # Set up the state and symbol stacks
+
+ statestack = [] # Stack of parsing states
+ self.statestack = statestack
+ symstack = [] # Stack of grammar symbols
+ self.symstack = symstack
+
+ pslice.stack = symstack # Put in the production
+ errtoken = None # Err token
+
+ # The start state is assumed to be (0,$end)
+
+ statestack.append(0)
+ sym = YaccSymbol()
+ sym.type = '$end'
+ symstack.append(sym)
+ state = 0
+ while True:
+ # Get the next symbol on the input. If a lookahead symbol
+ # is already set, we just use that. Otherwise, we'll pull
+ # the next token off of the lookaheadstack or from the lexer
+
+ #--! DEBUG
+ debug.debug('')
+ debug.debug('State : %s', state)
+ #--! DEBUG
+
+ if state not in defaulted_states:
+ if not lookahead:
+ if not lookaheadstack:
+ lookahead = get_token() # Get the next token
+ else:
+ lookahead = lookaheadstack.pop()
+ if not lookahead:
+ lookahead = YaccSymbol()
+ lookahead.type = '$end'
+
+ # Check the action table
+ ltype = lookahead.type
+ t = actions[state].get(ltype)
+ else:
+ t = defaulted_states[state]
+ #--! DEBUG
+ debug.debug('Defaulted state %s: Reduce using %d', state, -t)
+ #--! DEBUG
+
+ #--! DEBUG
+ debug.debug('Stack : %s',
+ ('%s . %s' % (' '.join([xx.type for xx in symstack][1:]), str(lookahead))).lstrip())
+ #--! DEBUG
+
+ if t is not None:
+ if t > 0:
+ # shift a symbol on the stack
+ statestack.append(t)
+ state = t
+
+ #--! DEBUG
+ debug.debug('Action : Shift and goto state %s', t)
+ #--! DEBUG
+
+ symstack.append(lookahead)
+ lookahead = None
+
+ # Decrease error count on successful shift
+ if errorcount:
+ errorcount -= 1
+ continue
+
+ if t < 0:
+ # reduce a symbol on the stack, emit a production
+ p = prod[-t]
+ pname = p.name
+ plen = p.len
+
+ # Get production function
+ sym = YaccSymbol()
+ sym.type = pname # Production name
+ sym.value = None
+
+ #--! DEBUG
+ if plen:
+ debug.info('Action : Reduce rule [%s] with %s and goto state %d', p.str,
+ '['+','.join([format_stack_entry(_v.value) for _v in symstack[-plen:]])+']',
+ goto[statestack[-1-plen]][pname])
+ else:
+ debug.info('Action : Reduce rule [%s] with %s and goto state %d', p.str, [],
+ goto[statestack[-1]][pname])
+
+ #--! DEBUG
+
+ if plen:
+ targ = symstack[-plen-1:]
+ targ[0] = sym
+
+ #--! TRACKING
+ if tracking:
+ t1 = targ[1]
+ sym.lineno = t1.lineno
+ sym.lexpos = t1.lexpos
+ t1 = targ[-1]
+ sym.endlineno = getattr(t1, 'endlineno', t1.lineno)
+ sym.endlexpos = getattr(t1, 'endlexpos', t1.lexpos)
+ #--! TRACKING
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # The code enclosed in this section is duplicated
+ # below as a performance optimization. Make sure
+ # changes get made in both locations.
+
+ pslice.slice = targ
+
+ try:
+ # Call the grammar rule with our special slice object
+ del symstack[-plen:]
+ self.state = state
+ p.callable(pslice)
+ del statestack[-plen:]
+ #--! DEBUG
+ debug.info('Result : %s', format_result(pslice[0]))
+ #--! DEBUG
+ symstack.append(sym)
+ state = goto[statestack[-1]][pname]
+ statestack.append(state)
+ except SyntaxError:
+ # If an error was set. Enter error recovery state
+ lookaheadstack.append(lookahead) # Save the current lookahead token
+ symstack.extend(targ[1:-1]) # Put the production slice back on the stack
+ statestack.pop() # Pop back one state (before the reduce)
+ state = statestack[-1]
+ sym.type = 'error'
+ sym.value = 'error'
+ lookahead = sym
+ errorcount = error_count
+ self.errorok = False
+
+ continue
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ else:
+
+ #--! TRACKING
+ if tracking:
+ sym.lineno = lexer.lineno
+ sym.lexpos = lexer.lexpos
+ #--! TRACKING
+
+ targ = [sym]
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # The code enclosed in this section is duplicated
+ # above as a performance optimization. Make sure
+ # changes get made in both locations.
+
+ pslice.slice = targ
+
+ try:
+ # Call the grammar rule with our special slice object
+ self.state = state
+ p.callable(pslice)
+ #--! DEBUG
+ debug.info('Result : %s', format_result(pslice[0]))
+ #--! DEBUG
+ symstack.append(sym)
+ state = goto[statestack[-1]][pname]
+ statestack.append(state)
+ except SyntaxError:
+ # If an error was set. Enter error recovery state
+ lookaheadstack.append(lookahead) # Save the current lookahead token
+ statestack.pop() # Pop back one state (before the reduce)
+ state = statestack[-1]
+ sym.type = 'error'
+ sym.value = 'error'
+ lookahead = sym
+ errorcount = error_count
+ self.errorok = False
+
+ continue
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ if t == 0:
+ n = symstack[-1]
+ result = getattr(n, 'value', None)
+ #--! DEBUG
+ debug.info('Done : Returning %s', format_result(result))
+ debug.info('PLY: PARSE DEBUG END')
+ #--! DEBUG
+ return result
+
+ if t is None:
+
+ #--! DEBUG
+ debug.error('Error : %s',
+ ('%s . %s' % (' '.join([xx.type for xx in symstack][1:]), str(lookahead))).lstrip())
+ #--! DEBUG
+
+ # We have some kind of parsing error here. To handle
+ # this, we are going to push the current token onto
+ # the tokenstack and replace it with an 'error' token.
+ # If there are any synchronization rules, they may
+ # catch it.
+ #
+                # In addition to pushing the error token, we call
+ # the user defined p_error() function if this is the
+ # first syntax error. This function is only called if
+ # errorcount == 0.
+ if errorcount == 0 or self.errorok:
+ errorcount = error_count
+ self.errorok = False
+ errtoken = lookahead
+ if errtoken.type == '$end':
+ errtoken = None # End of file!
+ if self.errorfunc:
+ if errtoken and not hasattr(errtoken, 'lexer'):
+ errtoken.lexer = lexer
+ self.state = state
+ tok = call_errorfunc(self.errorfunc, errtoken, self)
+ if self.errorok:
+ # User must have done some kind of panic
+ # mode recovery on their own. The
+ # returned token is the next lookahead
+ lookahead = tok
+ errtoken = None
+ continue
+ else:
+ if errtoken:
+ if hasattr(errtoken, 'lineno'):
+ lineno = lookahead.lineno
+ else:
+ lineno = 0
+ if lineno:
+ sys.stderr.write('yacc: Syntax error at line %d, token=%s\n' % (lineno, errtoken.type))
+ else:
+ sys.stderr.write('yacc: Syntax error, token=%s' % errtoken.type)
+ else:
+ sys.stderr.write('yacc: Parse error in input. EOF\n')
+ return
+
+ else:
+ errorcount = error_count
+
+ # case 1: the statestack only has 1 entry on it. If we're in this state, the
+ # entire parse has been rolled back and we're completely hosed. The token is
+ # discarded and we just keep going.
+
+ if len(statestack) <= 1 and lookahead.type != '$end':
+ lookahead = None
+ errtoken = None
+ state = 0
+ # Nuke the pushback stack
+ del lookaheadstack[:]
+ continue
+
+ # case 2: the statestack has a couple of entries on it, but we're
+ # at the end of the file. nuke the top entry and generate an error token
+
+ # Start nuking entries on the stack
+ if lookahead.type == '$end':
+ # Whoa. We're really hosed here. Bail out
+ return
+
+ if lookahead.type != 'error':
+ sym = symstack[-1]
+ if sym.type == 'error':
+ # Hmmm. Error is on top of stack, we'll just nuke input
+ # symbol and continue
+ #--! TRACKING
+ if tracking:
+ sym.endlineno = getattr(lookahead, 'lineno', sym.lineno)
+ sym.endlexpos = getattr(lookahead, 'lexpos', sym.lexpos)
+ #--! TRACKING
+ lookahead = None
+ continue
+
+ # Create the error symbol for the first time and make it the new lookahead symbol
+ t = YaccSymbol()
+ t.type = 'error'
+
+ if hasattr(lookahead, 'lineno'):
+ t.lineno = t.endlineno = lookahead.lineno
+ if hasattr(lookahead, 'lexpos'):
+ t.lexpos = t.endlexpos = lookahead.lexpos
+ t.value = lookahead
+ lookaheadstack.append(lookahead)
+ lookahead = t
+ else:
+ sym = symstack.pop()
+ #--! TRACKING
+ if tracking:
+ lookahead.lineno = sym.lineno
+ lookahead.lexpos = sym.lexpos
+ #--! TRACKING
+ statestack.pop()
+ state = statestack[-1]
+
+ continue
+
+ # Call an error function here
+ raise RuntimeError('yacc: internal parser error!!!\n')
+
+ #--! parsedebug-end
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # parseopt().
+ #
+ # Optimized version of parse() method. DO NOT EDIT THIS CODE DIRECTLY!
+ # This code is automatically generated by the ply/ygen.py script. Make
+ # changes to the parsedebug() method instead.
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ def parseopt(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None):
+ #--! parseopt-start
+ lookahead = None # Current lookahead symbol
+ lookaheadstack = [] # Stack of lookahead symbols
+ actions = self.action # Local reference to action table (to avoid lookup on self.)
+ goto = self.goto # Local reference to goto table (to avoid lookup on self.)
+ prod = self.productions # Local reference to production list (to avoid lookup on self.)
+ defaulted_states = self.defaulted_states # Local reference to defaulted states
+ pslice = YaccProduction(None) # Production object passed to grammar rules
+ errorcount = 0 # Used during error recovery
+
+
+ # If no lexer was given, we will try to use the lex module
+ if not lexer:
+ from . import lex
+ lexer = lex.lexer
+
+ # Set up the lexer and parser objects on pslice
+ pslice.lexer = lexer
+ pslice.parser = self
+
+ # If input was supplied, pass to lexer
+ if input is not None:
+ lexer.input(input)
+
+ if tokenfunc is None:
+ # Tokenize function
+ get_token = lexer.token
+ else:
+ get_token = tokenfunc
+
+ # Set the parser() token method (sometimes used in error recovery)
+ self.token = get_token
+
+ # Set up the state and symbol stacks
+
+ statestack = [] # Stack of parsing states
+ self.statestack = statestack
+ symstack = [] # Stack of grammar symbols
+ self.symstack = symstack
+
+ pslice.stack = symstack # Put in the production
+ errtoken = None # Err token
+
+ # The start state is assumed to be (0,$end)
+
+ statestack.append(0)
+ sym = YaccSymbol()
+ sym.type = '$end'
+ symstack.append(sym)
+ state = 0
+ while True:
+ # Get the next symbol on the input. If a lookahead symbol
+ # is already set, we just use that. Otherwise, we'll pull
+ # the next token off of the lookaheadstack or from the lexer
+
+
+ if state not in defaulted_states:
+ if not lookahead:
+ if not lookaheadstack:
+ lookahead = get_token() # Get the next token
+ else:
+ lookahead = lookaheadstack.pop()
+ if not lookahead:
+ lookahead = YaccSymbol()
+ lookahead.type = '$end'
+
+ # Check the action table
+ ltype = lookahead.type
+ t = actions[state].get(ltype)
+ else:
+ t = defaulted_states[state]
+
+
+ if t is not None:
+ if t > 0:
+ # shift a symbol on the stack
+ statestack.append(t)
+ state = t
+
+
+ symstack.append(lookahead)
+ lookahead = None
+
+ # Decrease error count on successful shift
+ if errorcount:
+ errorcount -= 1
+ continue
+
+ if t < 0:
+ # reduce a symbol on the stack, emit a production
+ p = prod[-t]
+ pname = p.name
+ plen = p.len
+
+ # Get production function
+ sym = YaccSymbol()
+ sym.type = pname # Production name
+ sym.value = None
+
+
+ if plen:
+ targ = symstack[-plen-1:]
+ targ[0] = sym
+
+ #--! TRACKING
+ if tracking:
+ t1 = targ[1]
+ sym.lineno = t1.lineno
+ sym.lexpos = t1.lexpos
+ t1 = targ[-1]
+ sym.endlineno = getattr(t1, 'endlineno', t1.lineno)
+ sym.endlexpos = getattr(t1, 'endlexpos', t1.lexpos)
+ #--! TRACKING
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # The code enclosed in this section is duplicated
+ # below as a performance optimization. Make sure
+ # changes get made in both locations.
+
+ pslice.slice = targ
+
+ try:
+ # Call the grammar rule with our special slice object
+ del symstack[-plen:]
+ self.state = state
+ p.callable(pslice)
+ del statestack[-plen:]
+ symstack.append(sym)
+ state = goto[statestack[-1]][pname]
+ statestack.append(state)
+ except SyntaxError:
+ # If an error was set. Enter error recovery state
+ lookaheadstack.append(lookahead) # Save the current lookahead token
+ symstack.extend(targ[1:-1]) # Put the production slice back on the stack
+ statestack.pop() # Pop back one state (before the reduce)
+ state = statestack[-1]
+ sym.type = 'error'
+ sym.value = 'error'
+ lookahead = sym
+ errorcount = error_count
+ self.errorok = False
+
+ continue
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ else:
+
+ #--! TRACKING
+ if tracking:
+ sym.lineno = lexer.lineno
+ sym.lexpos = lexer.lexpos
+ #--! TRACKING
+
+ targ = [sym]
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # The code enclosed in this section is duplicated
+ # above as a performance optimization. Make sure
+ # changes get made in both locations.
+
+ pslice.slice = targ
+
+ try:
+ # Call the grammar rule with our special slice object
+ self.state = state
+ p.callable(pslice)
+ symstack.append(sym)
+ state = goto[statestack[-1]][pname]
+ statestack.append(state)
+ except SyntaxError:
+ # If an error was set. Enter error recovery state
+ lookaheadstack.append(lookahead) # Save the current lookahead token
+ statestack.pop() # Pop back one state (before the reduce)
+ state = statestack[-1]
+ sym.type = 'error'
+ sym.value = 'error'
+ lookahead = sym
+ errorcount = error_count
+ self.errorok = False
+
+ continue
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ if t == 0:
+ n = symstack[-1]
+ result = getattr(n, 'value', None)
+ return result
+
+ if t is None:
+
+
+ # We have some kind of parsing error here. To handle
+ # this, we are going to push the current token onto
+ # the tokenstack and replace it with an 'error' token.
+ # If there are any synchronization rules, they may
+ # catch it.
+ #
+                # In addition to pushing the error token, we call
+ # the user defined p_error() function if this is the
+ # first syntax error. This function is only called if
+ # errorcount == 0.
+ if errorcount == 0 or self.errorok:
+ errorcount = error_count
+ self.errorok = False
+ errtoken = lookahead
+ if errtoken.type == '$end':
+ errtoken = None # End of file!
+ if self.errorfunc:
+ if errtoken and not hasattr(errtoken, 'lexer'):
+ errtoken.lexer = lexer
+ self.state = state
+ tok = call_errorfunc(self.errorfunc, errtoken, self)
+ if self.errorok:
+ # User must have done some kind of panic
+ # mode recovery on their own. The
+ # returned token is the next lookahead
+ lookahead = tok
+ errtoken = None
+ continue
+ else:
+ if errtoken:
+ if hasattr(errtoken, 'lineno'):
+ lineno = lookahead.lineno
+ else:
+ lineno = 0
+ if lineno:
+ sys.stderr.write('yacc: Syntax error at line %d, token=%s\n' % (lineno, errtoken.type))
+ else:
+ sys.stderr.write('yacc: Syntax error, token=%s' % errtoken.type)
+ else:
+ sys.stderr.write('yacc: Parse error in input. EOF\n')
+ return
+
+ else:
+ errorcount = error_count
+
+ # case 1: the statestack only has 1 entry on it. If we're in this state, the
+ # entire parse has been rolled back and we're completely hosed. The token is
+ # discarded and we just keep going.
+
+ if len(statestack) <= 1 and lookahead.type != '$end':
+ lookahead = None
+ errtoken = None
+ state = 0
+ # Nuke the pushback stack
+ del lookaheadstack[:]
+ continue
+
+ # case 2: the statestack has a couple of entries on it, but we're
+ # at the end of the file. nuke the top entry and generate an error token
+
+ # Start nuking entries on the stack
+ if lookahead.type == '$end':
+ # Whoa. We're really hosed here. Bail out
+ return
+
+ if lookahead.type != 'error':
+ sym = symstack[-1]
+ if sym.type == 'error':
+ # Hmmm. Error is on top of stack, we'll just nuke input
+ # symbol and continue
+ #--! TRACKING
+ if tracking:
+ sym.endlineno = getattr(lookahead, 'lineno', sym.lineno)
+ sym.endlexpos = getattr(lookahead, 'lexpos', sym.lexpos)
+ #--! TRACKING
+ lookahead = None
+ continue
+
+ # Create the error symbol for the first time and make it the new lookahead symbol
+ t = YaccSymbol()
+ t.type = 'error'
+
+ if hasattr(lookahead, 'lineno'):
+ t.lineno = t.endlineno = lookahead.lineno
+ if hasattr(lookahead, 'lexpos'):
+ t.lexpos = t.endlexpos = lookahead.lexpos
+ t.value = lookahead
+ lookaheadstack.append(lookahead)
+ lookahead = t
+ else:
+ sym = symstack.pop()
+ #--! TRACKING
+ if tracking:
+ lookahead.lineno = sym.lineno
+ lookahead.lexpos = sym.lexpos
+ #--! TRACKING
+ statestack.pop()
+ state = statestack[-1]
+
+ continue
+
+ # Call an error function here
+ raise RuntimeError('yacc: internal parser error!!!\n')
+
+ #--! parseopt-end
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # parseopt_notrack().
+ #
+ # Optimized version of parseopt() with line number tracking removed.
+ # DO NOT EDIT THIS CODE DIRECTLY. This code is automatically generated
+ # by the ply/ygen.py script. Make changes to the parsedebug() method instead.
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ def parseopt_notrack(self, input=None, lexer=None, debug=False, tracking=False, tokenfunc=None):
+ #--! parseopt-notrack-start
+ lookahead = None # Current lookahead symbol
+ lookaheadstack = [] # Stack of lookahead symbols
+ actions = self.action # Local reference to action table (to avoid lookup on self.)
+ goto = self.goto # Local reference to goto table (to avoid lookup on self.)
+ prod = self.productions # Local reference to production list (to avoid lookup on self.)
+ defaulted_states = self.defaulted_states # Local reference to defaulted states
+ pslice = YaccProduction(None) # Production object passed to grammar rules
+ errorcount = 0 # Used during error recovery
+
+
+ # If no lexer was given, we will try to use the lex module
+ if not lexer:
+ from . import lex
+ lexer = lex.lexer
+
+ # Set up the lexer and parser objects on pslice
+ pslice.lexer = lexer
+ pslice.parser = self
+
+ # If input was supplied, pass to lexer
+ if input is not None:
+ lexer.input(input)
+
+ if tokenfunc is None:
+ # Tokenize function
+ get_token = lexer.token
+ else:
+ get_token = tokenfunc
+
+ # Set the parser() token method (sometimes used in error recovery)
+ self.token = get_token
+
+ # Set up the state and symbol stacks
+
+ statestack = [] # Stack of parsing states
+ self.statestack = statestack
+ symstack = [] # Stack of grammar symbols
+ self.symstack = symstack
+
+ pslice.stack = symstack # Put in the production
+ errtoken = None # Err token
+
+ # The start state is assumed to be (0,$end)
+
+ statestack.append(0)
+ sym = YaccSymbol()
+ sym.type = '$end'
+ symstack.append(sym)
+ state = 0
+ while True:
+ # Get the next symbol on the input. If a lookahead symbol
+ # is already set, we just use that. Otherwise, we'll pull
+ # the next token off of the lookaheadstack or from the lexer
+
+
+ if state not in defaulted_states:
+ if not lookahead:
+ if not lookaheadstack:
+ lookahead = get_token() # Get the next token
+ else:
+ lookahead = lookaheadstack.pop()
+ if not lookahead:
+ lookahead = YaccSymbol()
+ lookahead.type = '$end'
+
+ # Check the action table
+ ltype = lookahead.type
+ t = actions[state].get(ltype)
+ else:
+ t = defaulted_states[state]
+
+
+ if t is not None:
+ if t > 0:
+ # shift a symbol on the stack
+ statestack.append(t)
+ state = t
+
+
+ symstack.append(lookahead)
+ lookahead = None
+
+ # Decrease error count on successful shift
+ if errorcount:
+ errorcount -= 1
+ continue
+
+ if t < 0:
+ # reduce a symbol on the stack, emit a production
+ p = prod[-t]
+ pname = p.name
+ plen = p.len
+
+ # Get production function
+ sym = YaccSymbol()
+ sym.type = pname # Production name
+ sym.value = None
+
+
+ if plen:
+ targ = symstack[-plen-1:]
+ targ[0] = sym
+
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # The code enclosed in this section is duplicated
+ # below as a performance optimization. Make sure
+ # changes get made in both locations.
+
+ pslice.slice = targ
+
+ try:
+ # Call the grammar rule with our special slice object
+ del symstack[-plen:]
+ self.state = state
+ p.callable(pslice)
+ del statestack[-plen:]
+ symstack.append(sym)
+ state = goto[statestack[-1]][pname]
+ statestack.append(state)
+ except SyntaxError:
+ # If an error was set. Enter error recovery state
+ lookaheadstack.append(lookahead) # Save the current lookahead token
+ symstack.extend(targ[1:-1]) # Put the production slice back on the stack
+ statestack.pop() # Pop back one state (before the reduce)
+ state = statestack[-1]
+ sym.type = 'error'
+ sym.value = 'error'
+ lookahead = sym
+ errorcount = error_count
+ self.errorok = False
+
+ continue
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ else:
+
+
+ targ = [sym]
+
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ # The code enclosed in this section is duplicated
+ # above as a performance optimization. Make sure
+ # changes get made in both locations.
+
+ pslice.slice = targ
+
+ try:
+ # Call the grammar rule with our special slice object
+ self.state = state
+ p.callable(pslice)
+ symstack.append(sym)
+ state = goto[statestack[-1]][pname]
+ statestack.append(state)
+ except SyntaxError:
+ # If an error was set. Enter error recovery state
+ lookaheadstack.append(lookahead) # Save the current lookahead token
+ statestack.pop() # Pop back one state (before the reduce)
+ state = statestack[-1]
+ sym.type = 'error'
+ sym.value = 'error'
+ lookahead = sym
+ errorcount = error_count
+ self.errorok = False
+
+ continue
+ # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+ if t == 0:
+ n = symstack[-1]
+ result = getattr(n, 'value', None)
+ return result
+
+ if t is None:
+
+
+ # We have some kind of parsing error here. To handle
+ # this, we are going to push the current token onto
+ # the tokenstack and replace it with an 'error' token.
+ # If there are any synchronization rules, they may
+ # catch it.
+ #
+                # In addition to pushing the error token, we call
+ # the user defined p_error() function if this is the
+ # first syntax error. This function is only called if
+ # errorcount == 0.
+ if errorcount == 0 or self.errorok:
+ errorcount = error_count
+ self.errorok = False
+ errtoken = lookahead
+ if errtoken.type == '$end':
+ errtoken = None # End of file!
+ if self.errorfunc:
+ if errtoken and not hasattr(errtoken, 'lexer'):
+ errtoken.lexer = lexer
+ self.state = state
+ tok = call_errorfunc(self.errorfunc, errtoken, self)
+ if self.errorok:
+ # User must have done some kind of panic
+ # mode recovery on their own. The
+ # returned token is the next lookahead
+ lookahead = tok
+ errtoken = None
+ continue
+ else:
+ if errtoken:
+ if hasattr(errtoken, 'lineno'):
+ lineno = lookahead.lineno
+ else:
+ lineno = 0
+ if lineno:
+ sys.stderr.write('yacc: Syntax error at line %d, token=%s\n' % (lineno, errtoken.type))
+ else:
+ sys.stderr.write('yacc: Syntax error, token=%s' % errtoken.type)
+ else:
+ sys.stderr.write('yacc: Parse error in input. EOF\n')
+ return
+
+ else:
+ errorcount = error_count
+
+ # case 1: the statestack only has 1 entry on it. If we're in this state, the
+ # entire parse has been rolled back and we're completely hosed. The token is
+ # discarded and we just keep going.
+
+ if len(statestack) <= 1 and lookahead.type != '$end':
+ lookahead = None
+ errtoken = None
+ state = 0
+ # Nuke the pushback stack
+ del lookaheadstack[:]
+ continue
+
+ # case 2: the statestack has a couple of entries on it, but we're
+ # at the end of the file. nuke the top entry and generate an error token
+
+ # Start nuking entries on the stack
+ if lookahead.type == '$end':
+ # Whoa. We're really hosed here. Bail out
+ return
+
+ if lookahead.type != 'error':
+ sym = symstack[-1]
+ if sym.type == 'error':
+ # Hmmm. Error is on top of stack, we'll just nuke input
+ # symbol and continue
+ lookahead = None
+ continue
+
+ # Create the error symbol for the first time and make it the new lookahead symbol
+ t = YaccSymbol()
+ t.type = 'error'
+
+ if hasattr(lookahead, 'lineno'):
+ t.lineno = t.endlineno = lookahead.lineno
+ if hasattr(lookahead, 'lexpos'):
+ t.lexpos = t.endlexpos = lookahead.lexpos
+ t.value = lookahead
+ lookaheadstack.append(lookahead)
+ lookahead = t
+ else:
+ sym = symstack.pop()
+ statestack.pop()
+ state = statestack[-1]
+
+ continue
+
+ # Call an error function here
+ raise RuntimeError('yacc: internal parser error!!!\n')
+
+ #--! parseopt-notrack-end
+
+# -----------------------------------------------------------------------------
+# === Grammar Representation ===
+#
+# The following functions, classes, and variables are used to represent and
+# manipulate the rules that make up a grammar.
+# -----------------------------------------------------------------------------
+
+# regex matching identifiers
+_is_identifier = re.compile(r'^[a-zA-Z0-9_-]+$')
+
+# -----------------------------------------------------------------------------
+# class Production:
+#
+# This class stores the raw information about a single production or grammar rule.
+# A grammar rule refers to a specification such as this:
+#
+# expr : expr PLUS term
+#
+# Here are the basic attributes defined on all productions
+#
+# name - Name of the production. For example 'expr'
+# prod - A list of symbols on the right side ['expr','PLUS','term']
+# prec - Production precedence level
+# number - Production number.
+# func - Function that executes on reduce
+# file - File where production function is defined
+# lineno - Line number where production function is defined
+#
+# The following attributes are defined or optional.
+#
+# len - Length of the production (number of symbols on right hand side)
+# usyms - Set of unique symbols found in the production
+# -----------------------------------------------------------------------------
+
+class Production(object):
+ reduced = 0
+ def __init__(self, number, name, prod, precedence=('right', 0), func=None, file='', line=0):
+ self.name = name
+ self.prod = tuple(prod)
+ self.number = number
+ self.func = func
+ self.callable = None
+ self.file = file
+ self.line = line
+ self.prec = precedence
+
+ # Internal settings used during table construction
+
+ self.len = len(self.prod) # Length of the production
+
+ # Create a list of unique production symbols used in the production
+ self.usyms = []
+ for s in self.prod:
+ if s not in self.usyms:
+ self.usyms.append(s)
+
+ # List of all LR items for the production
+ self.lr_items = []
+ self.lr_next = None
+
+ # Create a string representation
+ if self.prod:
+ self.str = '%s -> %s' % (self.name, ' '.join(self.prod))
+ else:
+ self.str = '%s -> <empty>' % self.name
+
+ def __str__(self):
+ return self.str
+
+ def __repr__(self):
+ return 'Production(' + str(self) + ')'
+
+ def __len__(self):
+ return len(self.prod)
+
+ def __nonzero__(self):
+ return 1
+
+ def __getitem__(self, index):
+ return self.prod[index]
+
+ # Return the nth lr_item from the production (or None if at the end)
+ def lr_item(self, n):
+ if n > len(self.prod):
+ return None
+ p = LRItem(self, n)
+ # Precompute the list of productions immediately following.
+ try:
+ p.lr_after = Prodnames[p.prod[n+1]]
+ except (IndexError, KeyError):
+ p.lr_after = []
+ try:
+ p.lr_before = p.prod[n-1]
+ except IndexError:
+ p.lr_before = None
+ return p
+
+ # Bind the production function name to a callable
+ def bind(self, pdict):
+ if self.func:
+ self.callable = pdict[self.func]
+
+# This class serves as a minimal stand-in for Production objects when
+# reading table data from files. It only contains information
+# actually used by the LR parsing engine, plus some additional
+# debugging information.
+class MiniProduction(object):
+ def __init__(self, str, name, len, func, file, line):
+ self.name = name
+ self.len = len
+ self.func = func
+ self.callable = None
+ self.file = file
+ self.line = line
+ self.str = str
+
+ def __str__(self):
+ return self.str
+
+ def __repr__(self):
+ return 'MiniProduction(%s)' % self.str
+
+ # Bind the production function name to a callable
+ def bind(self, pdict):
+ if self.func:
+ self.callable = pdict[self.func]
+
+
+# -----------------------------------------------------------------------------
+# class LRItem
+#
+# This class represents a specific stage of parsing a production rule. For
+# example:
+#
+# expr : expr . PLUS term
+#
+# In the above, the "." represents the current location of the parse. Here are the
+# basic attributes:
+#
+# name - Name of the production. For example 'expr'
+# prod - A list of symbols on the right side ['expr','.', 'PLUS','term']
+# number - Production number.
+#
+#    lr_next    - Next LR item. For example, if the item is 'expr -> expr . PLUS term',
+# then lr_next refers to 'expr -> expr PLUS . term'
+# lr_index - LR item index (location of the ".") in the prod list.
+# lookaheads - LALR lookahead symbols for this item
+# len - Length of the production (number of symbols on right hand side)
+# lr_after - List of all productions that immediately follow
+# lr_before - Grammar symbol immediately before
+# -----------------------------------------------------------------------------
+
+class LRItem(object):
+ def __init__(self, p, n):
+ self.name = p.name
+ self.prod = list(p.prod)
+ self.number = p.number
+ self.lr_index = n
+ self.lookaheads = {}
+ self.prod.insert(n, '.')
+ self.prod = tuple(self.prod)
+ self.len = len(self.prod)
+ self.usyms = p.usyms
+
+ def __str__(self):
+ if self.prod:
+ s = '%s -> %s' % (self.name, ' '.join(self.prod))
+ else:
+ s = '%s -> <empty>' % self.name
+ return s
+
+ def __repr__(self):
+ return 'LRItem(' + str(self) + ')'
+
+# -----------------------------------------------------------------------------
+# rightmost_terminal()
+#
+# Return the rightmost terminal from a list of symbols. Used in add_production()
+# -----------------------------------------------------------------------------
+def rightmost_terminal(symbols, terminals):
+ i = len(symbols) - 1
+ while i >= 0:
+ if symbols[i] in terminals:
+ return symbols[i]
+ i -= 1
+ return None
+
+# -----------------------------------------------------------------------------
+# === GRAMMAR CLASS ===
+#
+# The following class represents the contents of the specified grammar along
+# with various computed properties such as first sets, follow sets, LR items, etc.
+# This data is used for critical parts of the table generation process later.
+# -----------------------------------------------------------------------------
+
+class GrammarError(YaccError):
+ pass
+
+class Grammar(object):
+ def __init__(self, terminals):
+ self.Productions = [None] # A list of all of the productions. The first
+ # entry is always reserved for the purpose of
+ # building an augmented grammar
+
+ self.Prodnames = {} # A dictionary mapping the names of nonterminals to a list of all
+ # productions of that nonterminal.
+
+ self.Prodmap = {} # A dictionary that is only used to detect duplicate
+ # productions.
+
+ self.Terminals = {} # A dictionary mapping the names of terminal symbols to a
+ # list of the rules where they are used.
+
+ for term in terminals:
+ self.Terminals[term] = []
+
+ self.Terminals['error'] = []
+
+ self.Nonterminals = {} # A dictionary mapping names of nonterminals to a list
+ # of rule numbers where they are used.
+
+ self.First = {} # A dictionary of precomputed FIRST(x) symbols
+
+ self.Follow = {} # A dictionary of precomputed FOLLOW(x) symbols
+
+ self.Precedence = {} # Precedence rules for each terminal. Contains tuples of the
+ # form ('right',level) or ('nonassoc', level) or ('left',level)
+
+        self.UsedPrecedence = set() # Precedence rules that were actually used by the grammar.
+ # This is only used to provide error checking and to generate
+ # a warning about unused precedence rules.
+
+ self.Start = None # Starting symbol for the grammar
+
+
+ def __len__(self):
+ return len(self.Productions)
+
+ def __getitem__(self, index):
+ return self.Productions[index]
+
+ # -----------------------------------------------------------------------------
+ # set_precedence()
+ #
+ # Sets the precedence for a given terminal. assoc is the associativity such as
+ # 'left','right', or 'nonassoc'. level is a numeric level.
+ #
+ # -----------------------------------------------------------------------------
+
+ def set_precedence(self, term, assoc, level):
+ assert self.Productions == [None], 'Must call set_precedence() before add_production()'
+ if term in self.Precedence:
+ raise GrammarError('Precedence already specified for terminal %r' % term)
+ if assoc not in ['left', 'right', 'nonassoc']:
+ raise GrammarError("Associativity must be one of 'left','right', or 'nonassoc'")
+ self.Precedence[term] = (assoc, level)
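+
+    # Example (illustrative only, hypothetical terminal names): precedence must
+    # be established before any productions are added.
+    #
+    #     g = Grammar(['NUMBER', 'PLUS', 'MINUS', 'TIMES'])
+    #     g.set_precedence('PLUS',  'left', 1)
+    #     g.set_precedence('MINUS', 'left', 1)
+    #     g.set_precedence('TIMES', 'left', 2)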
+
+ # -----------------------------------------------------------------------------
+ # add_production()
+ #
+ # Given an action function, this function assembles a production rule and
+ # computes its precedence level.
+ #
+ # The production rule is supplied as a list of symbols. For example,
+ # a rule such as 'expr : expr PLUS term' has a production name of 'expr' and
+ # symbols ['expr','PLUS','term'].
+ #
+ # Precedence is determined by the precedence of the rightmost terminal symbol
+ # in the rule, or by the precedence of a terminal explicitly specified with %prec.
+ #
+ # A variety of error checks are performed to make sure production symbols
+ # are valid and that %prec is used correctly.
+ # -----------------------------------------------------------------------------
+
+ def add_production(self, prodname, syms, func=None, file='', line=0):
+
+ if prodname in self.Terminals:
+ raise GrammarError('%s:%d: Illegal rule name %r. Already defined as a token' % (file, line, prodname))
+ if prodname == 'error':
+ raise GrammarError('%s:%d: Illegal rule name %r. error is a reserved word' % (file, line, prodname))
+ if not _is_identifier.match(prodname):
+ raise GrammarError('%s:%d: Illegal rule name %r' % (file, line, prodname))
+
+ # Look for literal tokens
+ for n, s in enumerate(syms):
+ if s[0] in "'\"":
+ try:
+ c = eval(s)
+ if (len(c) > 1):
+ raise GrammarError('%s:%d: Literal token %s in rule %r may only be a single character' %
+ (file, line, s, prodname))
+ if c not in self.Terminals:
+ self.Terminals[c] = []
+ syms[n] = c
+ continue
+ except SyntaxError:
+ pass
+ if not _is_identifier.match(s) and s != '%prec':
+ raise GrammarError('%s:%d: Illegal name %r in rule %r' % (file, line, s, prodname))
+
+ # Determine the precedence level
+ if '%prec' in syms:
+ if syms[-1] == '%prec':
+ raise GrammarError('%s:%d: Syntax error. Nothing follows %%prec' % (file, line))
+ if syms[-2] != '%prec':
+ raise GrammarError('%s:%d: Syntax error. %%prec can only appear at the end of a grammar rule' %
+ (file, line))
+ precname = syms[-1]
+ prodprec = self.Precedence.get(precname)
+ if not prodprec:
+ raise GrammarError('%s:%d: Nothing known about the precedence of %r' % (file, line, precname))
+ else:
+ self.UsedPrecedence.add(precname)
+ del syms[-2:] # Drop %prec from the rule
+ else:
+ # If no %prec, precedence is determined by the rightmost terminal symbol
+ precname = rightmost_terminal(syms, self.Terminals)
+ prodprec = self.Precedence.get(precname, ('right', 0))
+
+ # See if the rule is already in the rulemap
+ map = '%s -> %s' % (prodname, syms)
+ if map in self.Prodmap:
+ m = self.Prodmap[map]
+ raise GrammarError('%s:%d: Duplicate rule %s. ' % (file, line, m) +
+ 'Previous definition at %s:%d' % (m.file, m.line))
+
+ # From this point on, everything is valid. Create a new Production instance
+ pnumber = len(self.Productions)
+ if prodname not in self.Nonterminals:
+ self.Nonterminals[prodname] = []
+
+ # Add the production number to Terminals and Nonterminals
+ for t in syms:
+ if t in self.Terminals:
+ self.Terminals[t].append(pnumber)
+ else:
+ if t not in self.Nonterminals:
+ self.Nonterminals[t] = []
+ self.Nonterminals[t].append(pnumber)
+
+ # Create a production and add it to the list of productions
+ p = Production(pnumber, prodname, syms, prodprec, func, file, line)
+ self.Productions.append(p)
+ self.Prodmap[map] = p
+
+ # Add to the global productions list
+ try:
+ self.Prodnames[prodname].append(p)
+ except KeyError:
+ self.Prodnames[prodname] = [p]
+
+ # -----------------------------------------------------------------------------
+ # set_start()
+ #
+ # Sets the starting symbol and creates the augmented grammar. Production
+ # rule 0 is S' -> start where start is the start symbol.
+ # -----------------------------------------------------------------------------
+
+ def set_start(self, start=None):
+ if not start:
+ start = self.Productions[1].name
+ if start not in self.Nonterminals:
+ raise GrammarError('start symbol %s undefined' % start)
+ self.Productions[0] = Production(0, "S'", [start])
+ self.Nonterminals[start].append(0)
+ self.Start = start
+
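+ # Illustrative sketch of how this class is typically driven (the token and
+ # rule names below are hypothetical; yacc() further down performs these same
+ # steps using information reflected out of the caller's module):
+ #
+ #   >>> g = Grammar(['NUMBER', 'PLUS'])
+ #   >>> g.set_precedence('PLUS', 'left', 1)    # must precede add_production()
+ #   >>> g.add_production('expr', ['expr', 'PLUS', 'expr'])
+ #   >>> g.add_production('expr', ['NUMBER'])
+ #   >>> g.set_start('expr')                    # production 0 becomes S' -> expr
+ #   >>> len(g)                                 # augmented rule + 2 user rules
+ #   3
+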
+ # -----------------------------------------------------------------------------
+ # find_unreachable()
+ #
+ # Find all of the nonterminal symbols that can't be reached from the starting
+ # symbol. Returns a list of nonterminals that can't be reached.
+ # -----------------------------------------------------------------------------
+
+ def find_unreachable(self):
+
+ # Mark all symbols that are reachable from a symbol s
+ def mark_reachable_from(s):
+ if s in reachable:
+ return
+ reachable.add(s)
+ for p in self.Prodnames.get(s, []):
+ for r in p.prod:
+ mark_reachable_from(r)
+
+ reachable = set()
+ mark_reachable_from(self.Productions[0].prod[0])
+ return [s for s in self.Nonterminals if s not in reachable]
+
+ # -----------------------------------------------------------------------------
+ # infinite_cycles()
+ #
+ # This function looks at the various parsing rules and tries to detect
+ # infinite recursion cycles (grammar rules where there is no possible way
+ # to derive a string of only terminals).
+ # -----------------------------------------------------------------------------
+
+ def infinite_cycles(self):
+ terminates = {}
+
+ # Terminals:
+ for t in self.Terminals:
+ terminates[t] = True
+
+ terminates['$end'] = True
+
+ # Nonterminals:
+
+ # Initialize to false:
+ for n in self.Nonterminals:
+ terminates[n] = False
+
+ # Then propagate termination until no change:
+ while True:
+ some_change = False
+ for (n, pl) in self.Prodnames.items():
+ # Nonterminal n terminates iff any of its productions terminates.
+ for p in pl:
+ # Production p terminates iff all of its rhs symbols terminate.
+ for s in p.prod:
+ if not terminates[s]:
+ # The symbol s does not terminate,
+ # so production p does not terminate.
+ p_terminates = False
+ break
+ else:
+ # didn't break from the loop,
+ # so every symbol s terminates
+ # so production p terminates.
+ p_terminates = True
+
+ if p_terminates:
+ # symbol n terminates!
+ if not terminates[n]:
+ terminates[n] = True
+ some_change = True
+ # Don't need to consider any more productions for this n.
+ break
+
+ if not some_change:
+ break
+
+ infinite = []
+ for (s, term) in terminates.items():
+ if not term:
+ if s not in self.Prodnames and s not in self.Terminals and s != 'error':
+ # s is used-but-not-defined, and we've already warned of that,
+ # so it would be overkill to say that it's also non-terminating.
+ pass
+ else:
+ infinite.append(s)
+
+ return infinite
+
+ # -----------------------------------------------------------------------------
+ # undefined_symbols()
+ #
+ # Find all symbols that were used in the grammar, but not defined as tokens or
+ # grammar rules. Returns a list of tuples (sym, prod) where sym is the symbol
+ # and prod is the production where the symbol was used.
+ # -----------------------------------------------------------------------------
+ def undefined_symbols(self):
+ result = []
+ for p in self.Productions:
+ if not p:
+ continue
+
+ for s in p.prod:
+ if s not in self.Prodnames and s not in self.Terminals and s != 'error':
+ result.append((s, p))
+ return result
+
+ # -----------------------------------------------------------------------------
+ # unused_terminals()
+ #
+ # Find all terminals that were defined, but not used by the grammar. Returns
+ # a list of the unused terminal symbols.
+ # -----------------------------------------------------------------------------
+ def unused_terminals(self):
+ unused_tok = []
+ for s, v in self.Terminals.items():
+ if s != 'error' and not v:
+ unused_tok.append(s)
+
+ return unused_tok
+
+ # ------------------------------------------------------------------------------
+ # unused_rules()
+ #
+ # Find all grammar rules that were defined, but not used (maybe not reachable)
+ # Returns a list of productions.
+ # ------------------------------------------------------------------------------
+
+ def unused_rules(self):
+ unused_prod = []
+ for s, v in self.Nonterminals.items():
+ if not v:
+ p = self.Prodnames[s][0]
+ unused_prod.append(p)
+ return unused_prod
+
+ # -----------------------------------------------------------------------------
+ # unused_precedence()
+ #
+ # Returns a list of tuples (term,precedence) corresponding to precedence
+ # rules that were never used by the grammar. term is the name of the terminal
+ # on which precedence was applied and precedence is a string such as 'left' or
+ # 'right' corresponding to the type of precedence.
+ # -----------------------------------------------------------------------------
+
+ def unused_precedence(self):
+ unused = []
+ for termname in self.Precedence:
+ if not (termname in self.Terminals or termname in self.UsedPrecedence):
+ unused.append((termname, self.Precedence[termname][0]))
+
+ return unused
+
+ # -------------------------------------------------------------------------
+ # _first()
+ #
+ # Compute the value of FIRST1(beta) where beta is a tuple of symbols.
+ #
+ # During execution of compute_first(), the result may be incomplete.
+ # Afterward (e.g., when called from compute_follow()), it will be complete.
+ # -------------------------------------------------------------------------
+ def _first(self, beta):
+
+ # We are computing First(x1,x2,x3,...,xn)
+ result = []
+ for x in beta:
+ x_produces_empty = False
+
+ # Add all the non-<empty> symbols of First[x] to the result.
+ for f in self.First[x]:
+ if f == '<empty>':
+ x_produces_empty = True
+ else:
+ if f not in result:
+ result.append(f)
+
+ if x_produces_empty:
+ # We have to consider the next x in beta,
+ # i.e. stay in the loop.
+ pass
+ else:
+ # We don't have to consider any further symbols in beta.
+ break
+ else:
+ # There was no 'break' from the loop,
+ # so x_produces_empty was true for all x in beta,
+ # so beta produces empty as well.
+ result.append('<empty>')
+
+ return result
+
+ # -------------------------------------------------------------------------
+ # compute_first()
+ #
+ # Compute the value of FIRST1(X) for all symbols
+ # -------------------------------------------------------------------------
+ def compute_first(self):
+ if self.First:
+ return self.First
+
+ # Terminals:
+ for t in self.Terminals:
+ self.First[t] = [t]
+
+ self.First['$end'] = ['$end']
+
+ # Nonterminals:
+
+ # Initialize to the empty set:
+ for n in self.Nonterminals:
+ self.First[n] = []
+
+ # Then propagate symbols until no change:
+ while True:
+ some_change = False
+ for n in self.Nonterminals:
+ for p in self.Prodnames[n]:
+ for f in self._first(p.prod):
+ if f not in self.First[n]:
+ self.First[n].append(f)
+ some_change = True
+ if not some_change:
+ break
+
+ return self.First
+
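+ # Worked example of the fixed point computed above (hedged sketch with
+ # made-up symbol names). For the grammar
+ #
+ #     S -> A B        A -> a  |  <empty>        B -> b
+ #
+ # the result is:
+ #
+ #     FIRST(A) = { a, <empty> }   (A can derive the empty string)
+ #     FIRST(B) = { b }
+ #     FIRST(S) = { a, b }         (when A vanishes, S can start with FIRST(B))
+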
+ # ---------------------------------------------------------------------
+ # compute_follow()
+ #
+ # Computes all of the follow sets for every non-terminal symbol. The
+ # follow set is the set of all symbols that might follow a given
+ # non-terminal. See the Dragon book, 2nd Ed. p. 189.
+ # ---------------------------------------------------------------------
+ def compute_follow(self, start=None):
+ # If already computed, return the result
+ if self.Follow:
+ return self.Follow
+
+ # If first sets not computed yet, do that first.
+ if not self.First:
+ self.compute_first()
+
+ # Add '$end' to the follow list of the start symbol
+ for k in self.Nonterminals:
+ self.Follow[k] = []
+
+ if not start:
+ start = self.Productions[1].name
+
+ self.Follow[start] = ['$end']
+
+ while True:
+ didadd = False
+ for p in self.Productions[1:]:
+ # Here is the production set
+ for i, B in enumerate(p.prod):
+ if B in self.Nonterminals:
+ # Okay. We got a non-terminal in a production
+ fst = self._first(p.prod[i+1:])
+ hasempty = False
+ for f in fst:
+ if f != '<empty>' and f not in self.Follow[B]:
+ self.Follow[B].append(f)
+ didadd = True
+ if f == '<empty>':
+ hasempty = True
+ if hasempty or i == (len(p.prod)-1):
+ # Add elements of follow(a) to follow(b)
+ for f in self.Follow[p.name]:
+ if f not in self.Follow[B]:
+ self.Follow[B].append(f)
+ didadd = True
+ if not didadd:
+ break
+ return self.Follow
+
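+ # Continuing the small FIRST example above with start symbol S:
+ #
+ #     FOLLOW(S) = { $end }
+ #     FOLLOW(A) = { b }           (anything that can begin B may follow A)
+ #     FOLLOW(B) = { $end }        (B ends S, so it inherits FOLLOW(S))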
+
+ # -----------------------------------------------------------------------------
+ # build_lritems()
+ #
+ # This function walks the list of productions and builds a complete set of the
+ # LR items. The LR items are stored in two ways: First, they are uniquely
+ # numbered and placed in the list _lritems. Second, a linked list of LR items
+ # is built for each production. For example:
+ #
+ # E -> E PLUS E
+ #
+ # Creates the list
+ #
+ # [E -> . E PLUS E, E -> E . PLUS E, E -> E PLUS . E, E -> E PLUS E . ]
+ # -----------------------------------------------------------------------------
+
+ def build_lritems(self):
+ for p in self.Productions:
+ lastlri = p
+ i = 0
+ lr_items = []
+ while True:
+ if i > len(p):
+ lri = None
+ else:
+ lri = LRItem(p, i)
+ # Precompute the list of productions immediately following
+ try:
+ lri.lr_after = self.Prodnames[lri.prod[i+1]]
+ except (IndexError, KeyError):
+ lri.lr_after = []
+ try:
+ lri.lr_before = lri.prod[i-1]
+ except IndexError:
+ lri.lr_before = None
+
+ lastlri.lr_next = lri
+ if not lri:
+ break
+ lr_items.append(lri)
+ lastlri = lri
+ i += 1
+ p.lr_items = lr_items
+
+# -----------------------------------------------------------------------------
+# == Class LRTable ==
+#
+# This class represents a basic table of LR parsing information.
+# Methods for generating the tables are not defined here. They are defined
+# in the derived class LRGeneratedTable.
+# -----------------------------------------------------------------------------
+
+class VersionError(YaccError):
+ pass
+
+class LRTable(object):
+ def __init__(self):
+ self.lr_action = None
+ self.lr_goto = None
+ self.lr_productions = None
+ self.lr_method = None
+
+ def read_table(self, module):
+ if isinstance(module, types.ModuleType):
+ parsetab = module
+ else:
+ exec('import %s' % module)
+ parsetab = sys.modules[module]
+
+ if parsetab._tabversion != __tabversion__:
+ raise VersionError('yacc table file version is out of date')
+
+ self.lr_action = parsetab._lr_action
+ self.lr_goto = parsetab._lr_goto
+
+ self.lr_productions = []
+ for p in parsetab._lr_productions:
+ self.lr_productions.append(MiniProduction(*p))
+
+ self.lr_method = parsetab._lr_method
+ return parsetab._lr_signature
+
+ def read_pickle(self, filename):
+ try:
+ import cPickle as pickle
+ except ImportError:
+ import pickle
+
+ if not os.path.exists(filename):
+ raise ImportError
+
+ in_f = open(filename, 'rb')
+
+ tabversion = pickle.load(in_f)
+ if tabversion != __tabversion__:
+ raise VersionError('yacc table file version is out of date')
+ self.lr_method = pickle.load(in_f)
+ signature = pickle.load(in_f)
+ self.lr_action = pickle.load(in_f)
+ self.lr_goto = pickle.load(in_f)
+ productions = pickle.load(in_f)
+
+ self.lr_productions = []
+ for p in productions:
+ self.lr_productions.append(MiniProduction(*p))
+
+ in_f.close()
+ return signature
+
+ # Bind all production function names to callable objects in pdict
+ def bind_callables(self, pdict):
+ for p in self.lr_productions:
+ p.bind(pdict)
+
+
+# -----------------------------------------------------------------------------
+# === LR Generator ===
+#
+# The following classes and functions are used to generate LR parsing tables on
+# a grammar.
+# -----------------------------------------------------------------------------
+
+# -----------------------------------------------------------------------------
+# digraph()
+# traverse()
+#
+# The following two functions are used to compute set valued functions
+# of the form:
+#
+# F(x) = F'(x) U U{F(y) | x R y}
+#
+# This is used to compute the values of Read() sets as well as FOLLOW sets
+# in LALR(1) generation.
+#
+# Inputs: X - An input set
+# R - A relation
+# FP - Set-valued function
+# ------------------------------------------------------------------------------
+
+def digraph(X, R, FP):
+ N = {}
+ for x in X:
+ N[x] = 0
+ stack = []
+ F = {}
+ for x in X:
+ if N[x] == 0:
+ traverse(x, N, stack, F, X, R, FP)
+ return F
+
+def traverse(x, N, stack, F, X, R, FP):
+ stack.append(x)
+ d = len(stack)
+ N[x] = d
+ F[x] = FP(x) # F(X) <- F'(x)
+
+ rel = R(x) # Get y's related to x
+ for y in rel:
+ if N[y] == 0:
+ traverse(y, N, stack, F, X, R, FP)
+ N[x] = min(N[x], N[y])
+ for a in F.get(y, []):
+ if a not in F[x]:
+ F[x].append(a)
+ if N[x] == d:
+ N[stack[-1]] = MAXINT
+ F[stack[-1]] = F[x]
+ element = stack.pop()
+ while element != x:
+ N[stack[-1]] = MAXINT
+ F[stack[-1]] = F[x]
+ element = stack.pop()
+
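+# Hedged usage sketch of digraph() with made-up values. The recurrence above
+# is solved even in the presence of cycles (all members of a cycle end up
+# sharing one value):
+#
+#   >>> X = ['a', 'b', 'c']
+#   >>> R = {'a': ['b'], 'b': ['c'], 'c': []}.get
+#   >>> FP = {'a': [1], 'b': [2], 'c': [3]}.get
+#   >>> digraph(X, R, FP)
+#   {'a': [1, 2, 3], 'b': [2, 3], 'c': [3]}
+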
+class LALRError(YaccError):
+ pass
+
+# -----------------------------------------------------------------------------
+# == LRGeneratedTable ==
+#
+# This class implements the LR table generation algorithm. There are no
+# public methods except for write_table() and pickle_table().
+# -----------------------------------------------------------------------------
+
+class LRGeneratedTable(LRTable):
+ def __init__(self, grammar, method='LALR', log=None):
+ if method not in ['SLR', 'LALR']:
+ raise LALRError('Unsupported method %s' % method)
+
+ self.grammar = grammar
+ self.lr_method = method
+
+ # Set up the logger
+ if not log:
+ log = NullLogger()
+ self.log = log
+
+ # Internal attributes
+ self.lr_action = {} # Action table
+ self.lr_goto = {} # Goto table
+ self.lr_productions = grammar.Productions # Copy of grammar Production array
+ self.lr_goto_cache = {} # Cache of computed gotos
+ self.lr0_cidhash = {} # Cache of closures
+
+ self._add_count = 0 # Internal counter used to detect cycles
+
+ # Diagnostic information filled in by the table generator
+ self.sr_conflict = 0
+ self.rr_conflict = 0
+ self.conflicts = [] # List of conflicts
+
+ self.sr_conflicts = []
+ self.rr_conflicts = []
+
+ # Build the tables
+ self.grammar.build_lritems()
+ self.grammar.compute_first()
+ self.grammar.compute_follow()
+ self.lr_parse_table()
+
+ # Compute the LR(0) closure operation on I, where I is a set of LR(0) items.
+
+ def lr0_closure(self, I):
+ self._add_count += 1
+
+ # Add everything in I to J
+ J = I[:]
+ didadd = True
+ while didadd:
+ didadd = False
+ for j in J:
+ for x in j.lr_after:
+ if getattr(x, 'lr0_added', 0) == self._add_count:
+ continue
+ # Add B --> .G to J
+ J.append(x.lr_next)
+ x.lr0_added = self._add_count
+ didadd = True
+
+ return J
+
+ # Compute the LR(0) goto function goto(I,X) where I is a set
+ # of LR(0) items and X is a grammar symbol. This function is written
+ # in a way that guarantees uniqueness of the generated goto sets
+ # (i.e. the same goto set will never be returned as two different Python
+ # objects). With uniqueness, we can later do fast set comparisons using
+ # id(obj) instead of element-wise comparison.
+
+ def lr0_goto(self, I, x):
+ # First we look for a previously cached entry
+ g = self.lr_goto_cache.get((id(I), x))
+ if g:
+ return g
+
+ # Now we generate the goto set in a way that guarantees uniqueness
+ # of the result
+
+ s = self.lr_goto_cache.get(x)
+ if not s:
+ s = {}
+ self.lr_goto_cache[x] = s
+
+ gs = []
+ for p in I:
+ n = p.lr_next
+ if n and n.lr_before == x:
+ s1 = s.get(id(n))
+ if not s1:
+ s1 = {}
+ s[id(n)] = s1
+ gs.append(n)
+ s = s1
+ g = s.get('$end')
+ if not g:
+ if gs:
+ g = self.lr0_closure(gs)
+ s['$end'] = g
+ else:
+ s['$end'] = gs
+ self.lr_goto_cache[(id(I), x)] = g
+ return g
+
+ # Compute the LR(0) sets of item function
+ def lr0_items(self):
+ C = [self.lr0_closure([self.grammar.Productions[0].lr_next])]
+ i = 0
+ for I in C:
+ self.lr0_cidhash[id(I)] = i
+ i += 1
+
+ # Loop over the items in C and each grammar symbol
+ i = 0
+ while i < len(C):
+ I = C[i]
+ i += 1
+
+ # Collect all of the symbols that could possibly be in the goto(I,X) sets
+ asyms = {}
+ for ii in I:
+ for s in ii.usyms:
+ asyms[s] = None
+
+ for x in asyms:
+ g = self.lr0_goto(I, x)
+ if not g or id(g) in self.lr0_cidhash:
+ continue
+ self.lr0_cidhash[id(g)] = len(C)
+ C.append(g)
+
+ return C
+
+ # -----------------------------------------------------------------------------
+ # ==== LALR(1) Parsing ====
+ #
+ # LALR(1) parsing is almost exactly the same as SLR except that instead of
+ # relying upon Follow() sets when performing reductions, a more selective
+ # lookahead set that incorporates the state of the LR(0) machine is utilized.
+ # Thus, we mainly just have to focus on calculating the lookahead sets.
+ #
+ # The method used here is due to DeRemer and Pennello (1982).
+ #
+ # DeRemer, F. L., and T. J. Pennello: "Efficient Computation of LALR(1)
+ # Lookahead Sets", ACM Transactions on Programming Languages and Systems,
+ # Vol. 4, No. 4, Oct. 1982, pp. 615-649
+ #
+ # Further details can also be found in:
+ #
+ # J. Tremblay and P. Sorenson, "The Theory and Practice of Compiler Writing",
+ # McGraw-Hill Book Company, (1985).
+ #
+ # -----------------------------------------------------------------------------
+
+ # -----------------------------------------------------------------------------
+ # compute_nullable_nonterminals()
+ #
+ # Creates a set containing all of the non-terminals that might produce
+ # an empty production.
+ # -----------------------------------------------------------------------------
+
+ def compute_nullable_nonterminals(self):
+ nullable = set()
+ num_nullable = 0
+ while True:
+ for p in self.grammar.Productions[1:]:
+ if p.len == 0:
+ nullable.add(p.name)
+ continue
+ for t in p.prod:
+ if t not in nullable:
+ break
+ else:
+ nullable.add(p.name)
+ if len(nullable) == num_nullable:
+ break
+ num_nullable = len(nullable)
+ return nullable
+
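+ # Illustrative sketch with hypothetical rules: given
+ #
+ #     a : <empty>        b : a a        c : a X
+ #
+ # the fixed point above yields nullable = {'a', 'b'}. 'c' is excluded because
+ # X is a terminal, and terminals are never nullable.
+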
+ # -----------------------------------------------------------------------------
+ # find_nonterminal_transitions(C)
+ #
+ # Given a set of LR(0) items, this function finds all of the non-terminal
+ # transitions. These are transitions in which a dot appears immediately before
+ # a non-terminal. Returns a list of tuples of the form (state,N) where state
+ # is the state number and N is the nonterminal symbol.
+ #
+ # The input C is the set of LR(0) items.
+ # -----------------------------------------------------------------------------
+
+ def find_nonterminal_transitions(self, C):
+ trans = []
+ for stateno, state in enumerate(C):
+ for p in state:
+ if p.lr_index < p.len - 1:
+ t = (stateno, p.prod[p.lr_index+1])
+ if t[1] in self.grammar.Nonterminals:
+ if t not in trans:
+ trans.append(t)
+ return trans
+
+ # -----------------------------------------------------------------------------
+ # dr_relation()
+ #
+ # Computes the DR(p,A) relationships for non-terminal transitions. The input
+ # is a tuple (state,N) where state is a number and N is a nonterminal symbol.
+ #
+ # Returns a list of terminals.
+ # -----------------------------------------------------------------------------
+
+ def dr_relation(self, C, trans, nullable):
+ dr_set = {}
+ state, N = trans
+ terms = []
+
+ g = self.lr0_goto(C[state], N)
+ for p in g:
+ if p.lr_index < p.len - 1:
+ a = p.prod[p.lr_index+1]
+ if a in self.grammar.Terminals:
+ if a not in terms:
+ terms.append(a)
+
+ # This extra bit is to handle the start state
+ if state == 0 and N == self.grammar.Productions[0].prod[0]:
+ terms.append('$end')
+
+ return terms
+
+ # -----------------------------------------------------------------------------
+ # reads_relation()
+ #
+ # Computes the READS() relation (p,A) READS (t,C).
+ # -----------------------------------------------------------------------------
+
+ def reads_relation(self, C, trans, empty):
+ # Look for empty transitions
+ rel = []
+ state, N = trans
+
+ g = self.lr0_goto(C[state], N)
+ j = self.lr0_cidhash.get(id(g), -1)
+ for p in g:
+ if p.lr_index < p.len - 1:
+ a = p.prod[p.lr_index + 1]
+ if a in empty:
+ rel.append((j, a))
+
+ return rel
+
+ # -----------------------------------------------------------------------------
+ # compute_lookback_includes()
+ #
+ # Determines the lookback and includes relations
+ #
+ # LOOKBACK:
+ #
+ # This relation is determined by running the LR(0) state machine forward.
+ # For example, starting with a production "N : . A B C", we run it forward
+ # to obtain "N : A B C ." We then build a relationship between this final
+ # state and the starting state. These relationships are stored in a dictionary
+ # lookdict.
+ #
+ # INCLUDES:
+ #
+ # Computes the INCLUDE() relation (p,A) INCLUDES (p',B).
+ #
+ # This relation is used to determine non-terminal transitions that occur
+ # inside of other non-terminal transition states. (p,A) INCLUDES (p', B)
+ # if the following holds:
+ #
+ # B -> LAT, where T -> epsilon and p' -L-> p
+ #
+ # L is essentially a prefix (which may be empty), T is a suffix that must be
+ # able to derive an empty string. State p' must lead to state p with the string L.
+ #
+ # -----------------------------------------------------------------------------
+
+ def compute_lookback_includes(self, C, trans, nullable):
+ lookdict = {} # Dictionary of lookback relations
+ includedict = {} # Dictionary of include relations
+
+ # Make a dictionary of non-terminal transitions
+ dtrans = {}
+ for t in trans:
+ dtrans[t] = 1
+
+ # Loop over all transitions and compute lookbacks and includes
+ for state, N in trans:
+ lookb = []
+ includes = []
+ for p in C[state]:
+ if p.name != N:
+ continue
+
+ # Okay, we have a name match. We now follow the production all the way
+ # through the state machine until we get the . on the right hand side
+
+ lr_index = p.lr_index
+ j = state
+ while lr_index < p.len - 1:
+ lr_index = lr_index + 1
+ t = p.prod[lr_index]
+
+ # Check to see if this symbol and state are a non-terminal transition
+ if (j, t) in dtrans:
+ # Yes. Okay, there is some chance that this is an includes relation
+ # the only way to know for certain is whether the rest of the
+ # production derives empty
+
+ li = lr_index + 1
+ while li < p.len:
+ if p.prod[li] in self.grammar.Terminals:
+ break # No forget it
+ if p.prod[li] not in nullable:
+ break
+ li = li + 1
+ else:
+ # Appears to be a relation between (j,t) and (state,N)
+ includes.append((j, t))
+
+ g = self.lr0_goto(C[j], t) # Go to next set
+ j = self.lr0_cidhash.get(id(g), -1) # Go to next state
+
+ # When we get here, j is the final state, now we have to locate the production
+ for r in C[j]:
+ if r.name != p.name:
+ continue
+ if r.len != p.len:
+ continue
+ i = 0
+ # This loop is comparing a production ". A B C" with "A B C ."
+ while i < r.lr_index:
+ if r.prod[i] != p.prod[i+1]:
+ break
+ i = i + 1
+ else:
+ lookb.append((j, r))
+ for i in includes:
+ if i not in includedict:
+ includedict[i] = []
+ includedict[i].append((state, N))
+ lookdict[(state, N)] = lookb
+
+ return lookdict, includedict
+
+ # -----------------------------------------------------------------------------
+ # compute_read_sets()
+ #
+ # Given a set of LR(0) items, this function computes the read sets.
+ #
+ # Inputs: C = Set of LR(0) items
+ # ntrans = Set of nonterminal transitions
+ # nullable = Set of empty transitions
+ #
+ # Returns a set containing the read sets
+ # -----------------------------------------------------------------------------
+
+ def compute_read_sets(self, C, ntrans, nullable):
+ FP = lambda x: self.dr_relation(C, x, nullable)
+ R = lambda x: self.reads_relation(C, x, nullable)
+ F = digraph(ntrans, R, FP)
+ return F
+
+ # -----------------------------------------------------------------------------
+ # compute_follow_sets()
+ #
+ # Given a set of LR(0) items, a set of non-terminal transitions, a readset,
+ # and an include set, this function computes the follow sets
+ #
+ # Follow(p,A) = Read(p,A) U U {Follow(p',B) | (p,A) INCLUDES (p',B)}
+ #
+ # Inputs:
+ # ntrans = Set of nonterminal transitions
+ # readsets = Readset (previously computed)
+ # inclsets = Include sets (previously computed)
+ #
+ # Returns a set containing the follow sets
+ # -----------------------------------------------------------------------------
+
+ def compute_follow_sets(self, ntrans, readsets, inclsets):
+ FP = lambda x: readsets[x]
+ R = lambda x: inclsets.get(x, [])
+ F = digraph(ntrans, R, FP)
+ return F
+
+ # -----------------------------------------------------------------------------
+ # add_lookaheads()
+ #
+ # Attaches the lookahead symbols to grammar rules.
+ #
+ # Inputs: lookbacks - Set of lookback relations
+ # followset - Computed follow set
+ #
+ # This function directly attaches the lookaheads to productions contained
+ # in the lookbacks set
+ # -----------------------------------------------------------------------------
+
+ def add_lookaheads(self, lookbacks, followset):
+ for trans, lb in lookbacks.items():
+ # Loop over productions in lookback
+ for state, p in lb:
+ if state not in p.lookaheads:
+ p.lookaheads[state] = []
+ f = followset.get(trans, [])
+ for a in f:
+ if a not in p.lookaheads[state]:
+ p.lookaheads[state].append(a)
+
+ # -----------------------------------------------------------------------------
+ # add_lalr_lookaheads()
+ #
+ # This function does all of the work of adding lookahead information for use
+ # with LALR parsing
+ # -----------------------------------------------------------------------------
+
+ def add_lalr_lookaheads(self, C):
+ # Determine all of the nullable nonterminals
+ nullable = self.compute_nullable_nonterminals()
+
+ # Find all non-terminal transitions
+ trans = self.find_nonterminal_transitions(C)
+
+ # Compute read sets
+ readsets = self.compute_read_sets(C, trans, nullable)
+
+ # Compute lookback/includes relations
+ lookd, included = self.compute_lookback_includes(C, trans, nullable)
+
+ # Compute LALR FOLLOW sets
+ followsets = self.compute_follow_sets(trans, readsets, included)
+
+ # Add all of the lookaheads
+ self.add_lookaheads(lookd, followsets)
+
+ # -----------------------------------------------------------------------------
+ # lr_parse_table()
+ #
+ # This function constructs the parse tables for SLR or LALR
+ # -----------------------------------------------------------------------------
+ def lr_parse_table(self):
+ Productions = self.grammar.Productions
+ Precedence = self.grammar.Precedence
+ goto = self.lr_goto # Goto array
+ action = self.lr_action # Action array
+ log = self.log # Logger for output
+
+ actionp = {} # Action production array (temporary)
+
+ log.info('Parsing method: %s', self.lr_method)
+
+ # Step 1: Construct C = { I0, I1, ... IN}, collection of LR(0) items
+ # This determines the number of states
+
+ C = self.lr0_items()
+
+ if self.lr_method == 'LALR':
+ self.add_lalr_lookaheads(C)
+
+ # Build the parser table, state by state
+ st = 0
+ for I in C:
+ # Loop over each production in I
+ actlist = [] # List of actions
+ st_action = {}
+ st_actionp = {}
+ st_goto = {}
+ log.info('')
+ log.info('state %d', st)
+ log.info('')
+ for p in I:
+ log.info(' (%d) %s', p.number, p)
+ log.info('')
+
+ for p in I:
+ if p.len == p.lr_index + 1:
+ if p.name == "S'":
+ # Start symbol. Accept!
+ st_action['$end'] = 0
+ st_actionp['$end'] = p
+ else:
+ # We are at the end of a production. Reduce!
+ if self.lr_method == 'LALR':
+ laheads = p.lookaheads[st]
+ else:
+ laheads = self.grammar.Follow[p.name]
+ for a in laheads:
+ actlist.append((a, p, 'reduce using rule %d (%s)' % (p.number, p)))
+ r = st_action.get(a)
+ if r is not None:
+ # Whoa. Have a shift/reduce or reduce/reduce conflict
+ if r > 0:
+ # Need to decide on shift or reduce here
+ # By default we favor shifting. Need to add
+ # some precedence rules here.
+
+ # Shift precedence comes from the token
+ sprec, slevel = Precedence.get(a, ('right', 0))
+
+ # Reduce precedence comes from rule being reduced (p)
+ rprec, rlevel = Productions[p.number].prec
+
+ if (slevel < rlevel) or ((slevel == rlevel) and (rprec == 'left')):
+ # We really need to reduce here.
+ st_action[a] = -p.number
+ st_actionp[a] = p
+ if not slevel and not rlevel:
+ log.info(' ! shift/reduce conflict for %s resolved as reduce', a)
+ self.sr_conflicts.append((st, a, 'reduce'))
+ Productions[p.number].reduced += 1
+ elif (slevel == rlevel) and (rprec == 'nonassoc'):
+ st_action[a] = None
+ else:
+ # Hmmm. Guess we'll keep the shift
+ if not rlevel:
+ log.info(' ! shift/reduce conflict for %s resolved as shift', a)
+ self.sr_conflicts.append((st, a, 'shift'))
+ elif r < 0:
+ # Reduce/reduce conflict. In this case, we favor the rule
+ # that was defined first in the grammar file
+ oldp = Productions[-r]
+ pp = Productions[p.number]
+ if oldp.line > pp.line:
+ st_action[a] = -p.number
+ st_actionp[a] = p
+ chosenp, rejectp = pp, oldp
+ Productions[p.number].reduced += 1
+ Productions[oldp.number].reduced -= 1
+ else:
+ chosenp, rejectp = oldp, pp
+ self.rr_conflicts.append((st, chosenp, rejectp))
+ log.info(' ! reduce/reduce conflict for %s resolved using rule %d (%s)',
+ a, st_actionp[a].number, st_actionp[a])
+ else:
+ raise LALRError('Unknown conflict in state %d' % st)
+ else:
+ st_action[a] = -p.number
+ st_actionp[a] = p
+ Productions[p.number].reduced += 1
+ else:
+ i = p.lr_index
+ a = p.prod[i+1] # Get symbol right after the "."
+ if a in self.grammar.Terminals:
+ g = self.lr0_goto(I, a)
+ j = self.lr0_cidhash.get(id(g), -1)
+ if j >= 0:
+ # We are in a shift state
+ actlist.append((a, p, 'shift and go to state %d' % j))
+ r = st_action.get(a)
+ if r is not None:
+ # Whoa. Have a shift/reduce or shift/shift conflict
+ if r > 0:
+ if r != j:
+ raise LALRError('Shift/shift conflict in state %d' % st)
+ elif r < 0:
+ # Do a precedence check.
+ # - if precedence of reduce rule is higher, we reduce.
+ # - if precedence of reduce is same and left assoc, we reduce.
+ # - otherwise we shift
+
+ # Shift precedence comes from the token
+ sprec, slevel = Precedence.get(a, ('right', 0))
+
+ # Reduce precedence comes from the rule that could have been reduced
+ rprec, rlevel = Productions[st_actionp[a].number].prec
+
+ if (slevel > rlevel) or ((slevel == rlevel) and (rprec == 'right')):
+ # We decide to shift here... highest precedence to shift
+ Productions[st_actionp[a].number].reduced -= 1
+ st_action[a] = j
+ st_actionp[a] = p
+ if not rlevel:
+ log.info(' ! shift/reduce conflict for %s resolved as shift', a)
+ self.sr_conflicts.append((st, a, 'shift'))
+ elif (slevel == rlevel) and (rprec == 'nonassoc'):
+ st_action[a] = None
+ else:
+ # Hmmm. Guess we'll keep the reduce
+ if not slevel and not rlevel:
+ log.info(' ! shift/reduce conflict for %s resolved as reduce', a)
+ self.sr_conflicts.append((st, a, 'reduce'))
+
+ else:
+ raise LALRError('Unknown conflict in state %d' % st)
+ else:
+ st_action[a] = j
+ st_actionp[a] = p
+
+ # Print the actions associated with each terminal
+ _actprint = {}
+ for a, p, m in actlist:
+ if a in st_action:
+ if p is st_actionp[a]:
+ log.info(' %-15s %s', a, m)
+ _actprint[(a, m)] = 1
+ log.info('')
+ # Print the actions that were not used. (debugging)
+ not_used = 0
+ for a, p, m in actlist:
+ if a in st_action:
+ if p is not st_actionp[a]:
+ if not (a, m) in _actprint:
+ log.debug(' ! %-15s [ %s ]', a, m)
+ not_used = 1
+ _actprint[(a, m)] = 1
+ if not_used:
+ log.debug('')
+
+ # Construct the goto table for this state
+
+ nkeys = {}
+ for ii in I:
+ for s in ii.usyms:
+ if s in self.grammar.Nonterminals:
+ nkeys[s] = None
+ for n in nkeys:
+ g = self.lr0_goto(I, n)
+ j = self.lr0_cidhash.get(id(g), -1)
+ if j >= 0:
+ st_goto[n] = j
+ log.info(' %-30s shift and go to state %d', n, j)
+
+ action[st] = st_action
+ actionp[st] = st_actionp
+ goto[st] = st_goto
+ st += 1
+
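+ # Worked illustration of the precedence rules applied above (hypothetical
+ # grammar): with precedence = (('left', 'PLUS'), ('left', 'TIMES')), the
+ # conflict between shifting TIMES and reducing by 'expr : expr PLUS expr'
+ # is resolved as a shift, because TIMES (level 2) binds tighter than the
+ # rule's rightmost terminal PLUS (level 1). With equal levels and 'left'
+ # associativity, the reduce wins instead.
+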
+ # -----------------------------------------------------------------------------
+ # write_table()
+ #
+ # This function writes the LR parsing tables to a file
+ # -----------------------------------------------------------------------------
+
+ def write_table(self, tabmodule, outputdir='', signature=''):
+ if isinstance(tabmodule, types.ModuleType):
+ raise IOError("Won't overwrite existing tabmodule")
+
+ basemodulename = tabmodule.split('.')[-1]
+ filename = os.path.join(outputdir, basemodulename) + '.py'
+ try:
+ f = open(filename, 'w')
+
+ f.write('''
+# %s
+# This file is automatically generated. Do not edit.
+_tabversion = %r
+
+_lr_method = %r
+
+_lr_signature = %r
+ ''' % (os.path.basename(filename), __tabversion__, self.lr_method, signature))
+
+ # Change smaller to 0 to go back to original tables
+ smaller = 1
+
+ # Factor out names to try and make smaller
+ if smaller:
+ items = {}
+
+ for s, nd in self.lr_action.items():
+ for name, v in nd.items():
+ i = items.get(name)
+ if not i:
+ i = ([], [])
+ items[name] = i
+ i[0].append(s)
+ i[1].append(v)
+
+ f.write('\n_lr_action_items = {')
+ for k, v in items.items():
+ f.write('%r:([' % k)
+ for i in v[0]:
+ f.write('%r,' % i)
+ f.write('],[')
+ for i in v[1]:
+ f.write('%r,' % i)
+
+ f.write(']),')
+ f.write('}\n')
+
+ f.write('''
+_lr_action = {}
+for _k, _v in _lr_action_items.items():
+ for _x,_y in zip(_v[0],_v[1]):
+ if not _x in _lr_action: _lr_action[_x] = {}
+ _lr_action[_x][_k] = _y
+del _lr_action_items
+''')
+
+ else:
+ f.write('\n_lr_action = { ')
+ for k, v in self.lr_action.items():
+ f.write('(%r,%r):%r,' % (k[0], k[1], v))
+ f.write('}\n')
+
+ if smaller:
+ # Factor out names to try and make smaller
+ items = {}
+
+ for s, nd in self.lr_goto.items():
+ for name, v in nd.items():
+ i = items.get(name)
+ if not i:
+ i = ([], [])
+ items[name] = i
+ i[0].append(s)
+ i[1].append(v)
+
+ f.write('\n_lr_goto_items = {')
+ for k, v in items.items():
+ f.write('%r:([' % k)
+ for i in v[0]:
+ f.write('%r,' % i)
+ f.write('],[')
+ for i in v[1]:
+ f.write('%r,' % i)
+
+ f.write(']),')
+ f.write('}\n')
+
+ f.write('''
+_lr_goto = {}
+for _k, _v in _lr_goto_items.items():
+ for _x, _y in zip(_v[0], _v[1]):
+ if not _x in _lr_goto: _lr_goto[_x] = {}
+ _lr_goto[_x][_k] = _y
+del _lr_goto_items
+''')
+ else:
+ f.write('\n_lr_goto = { ')
+ for k, v in self.lr_goto.items():
+ f.write('(%r,%r):%r,' % (k[0], k[1], v))
+ f.write('}\n')
+
+ # Write production table
+ f.write('_lr_productions = [\n')
+ for p in self.lr_productions:
+ if p.func:
+ f.write(' (%r,%r,%d,%r,%r,%d),\n' % (p.str, p.name, p.len,
+ p.func, os.path.basename(p.file), p.line))
+ else:
+ f.write(' (%r,%r,%d,None,None,None),\n' % (str(p), p.name, p.len))
+ f.write(']\n')
+ f.close()
+
+ except IOError as e:
+ raise
+
+
+ # -----------------------------------------------------------------------------
+ # pickle_table()
+ #
+ # This function pickles the LR parsing tables to a supplied file object
+ # -----------------------------------------------------------------------------
+
+ def pickle_table(self, filename, signature=''):
+ try:
+ import cPickle as pickle
+ except ImportError:
+ import pickle
+ with open(filename, 'wb') as outf:
+ pickle.dump(__tabversion__, outf, pickle_protocol)
+ pickle.dump(self.lr_method, outf, pickle_protocol)
+ pickle.dump(signature, outf, pickle_protocol)
+ pickle.dump(self.lr_action, outf, pickle_protocol)
+ pickle.dump(self.lr_goto, outf, pickle_protocol)
+
+ outp = []
+ for p in self.lr_productions:
+ if p.func:
+ outp.append((p.str, p.name, p.len, p.func, os.path.basename(p.file), p.line))
+ else:
+ outp.append((str(p), p.name, p.len, None, None, None))
+ pickle.dump(outp, outf, pickle_protocol)
+
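+ # Hedged usage sketch (the file name is hypothetical, and a fully built
+ # Grammar object named grammar is assumed): a generated table can be pickled
+ # once and reloaded later without rebuilding the grammar:
+ #
+ #   >>> table = LRGeneratedTable(grammar, method='LALR')
+ #   >>> table.pickle_table('parser.out.p', signature='sig')
+ #   >>> cached = LRTable()
+ #   >>> cached.read_pickle('parser.out.p')
+ #   'sig'
+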
+# -----------------------------------------------------------------------------
+# === INTROSPECTION ===
+#
+# The following functions and classes are used to implement the PLY
+# introspection features followed by the yacc() function itself.
+# -----------------------------------------------------------------------------
+
+# -----------------------------------------------------------------------------
+# get_caller_module_dict()
+#
+# This function returns a dictionary containing all of the symbols defined within
+# a caller further down the call stack. This is used to get the environment
+# associated with the yacc() call if none was provided.
+# -----------------------------------------------------------------------------
+
+def get_caller_module_dict(levels):
+ f = sys._getframe(levels)
+ ldict = f.f_globals.copy()
+ if f.f_globals != f.f_locals:
+ ldict.update(f.f_locals)
+ return ldict
+
+# -----------------------------------------------------------------------------
+# parse_grammar()
+#
+# This takes a raw grammar rule string and parses it into production data
+# -----------------------------------------------------------------------------
+def parse_grammar(doc, file, line):
+ grammar = []
+ # Split the doc string into lines
+ pstrings = doc.splitlines()
+ lastp = None
+ dline = line
+ for ps in pstrings:
+ dline += 1
+ p = ps.split()
+ if not p:
+ continue
+ try:
+ if p[0] == '|':
+ # This is a continuation of a previous rule
+ if not lastp:
+ raise SyntaxError("%s:%d: Misplaced '|'" % (file, dline))
+ prodname = lastp
+ syms = p[1:]
+ else:
+ prodname = p[0]
+ lastp = prodname
+ syms = p[2:]
+ assign = p[1]
+ if assign != ':' and assign != '::=':
+ raise SyntaxError("%s:%d: Syntax error. Expected ':'" % (file, dline))
+
+ grammar.append((file, dline, prodname, syms))
+ except SyntaxError:
+ raise
+ except Exception:
+ raise SyntaxError('%s:%d: Syntax error in rule %r' % (file, dline, ps.strip()))
+
+ return grammar
+
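+# Hedged sketch of the transformation performed by parse_grammar() (the file
+# name and line numbers are illustrative). The docstring
+#
+#     expr : expr PLUS term
+#          | term
+#
+# becomes
+#
+#     [('calc.py', 12, 'expr', ['expr', 'PLUS', 'term']),
+#      ('calc.py', 13, 'expr', ['term'])]
+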
+# -----------------------------------------------------------------------------
+# ParserReflect()
+#
+# This class represents information extracted for building a parser including
+# start symbol, error function, tokens, precedence list, action functions,
+# etc.
+# -----------------------------------------------------------------------------
+class ParserReflect(object):
+ def __init__(self, pdict, log=None):
+ self.pdict = pdict
+ self.start = None
+ self.error_func = None
+ self.tokens = None
+ self.modules = set()
+ self.grammar = []
+ self.error = False
+
+ if log is None:
+ self.log = PlyLogger(sys.stderr)
+ else:
+ self.log = log
+
+ # Get all of the basic information
+ def get_all(self):
+ self.get_start()
+ self.get_error_func()
+ self.get_tokens()
+ self.get_precedence()
+ self.get_pfunctions()
+
+ # Validate all of the information
+ def validate_all(self):
+ self.validate_start()
+ self.validate_error_func()
+ self.validate_tokens()
+ self.validate_precedence()
+ self.validate_pfunctions()
+ self.validate_modules()
+ return self.error
+
+ # Compute a signature over the grammar
+ def signature(self):
+ parts = []
+ try:
+ if self.start:
+ parts.append(self.start)
+ if self.prec:
+ parts.append(''.join([''.join(p) for p in self.prec]))
+ if self.tokens:
+ parts.append(' '.join(self.tokens))
+ for f in self.pfuncs:
+ if f[3]:
+ parts.append(f[3])
+ except (TypeError, ValueError):
+ pass
+ return ''.join(parts)
+
+ # -----------------------------------------------------------------------------
+ # validate_modules()
+ #
+ # This method checks to see if there are duplicated p_rulename() functions
+ # in the parser module file. Without this function, it is really easy for
+ # users to make mistakes by cutting and pasting code fragments (and it's a real
+ # bugger to try and figure out why the resulting parser doesn't work). Therefore,
+ # we just do a little regular expression pattern matching of def statements
+ # to try and detect duplicates.
+ # -----------------------------------------------------------------------------
+
+ def validate_modules(self):
+ # Match def p_funcname(
+ fre = re.compile(r'\s*def\s+(p_[a-zA-Z_0-9]*)\(')
+
+ for module in self.modules:
+ try:
+ lines, linen = inspect.getsourcelines(module)
+ except IOError:
+ continue
+
+ counthash = {}
+ for linen, line in enumerate(lines):
+ linen += 1
+ m = fre.match(line)
+ if m:
+ name = m.group(1)
+ prev = counthash.get(name)
+ if not prev:
+ counthash[name] = linen
+ else:
+ filename = inspect.getsourcefile(module)
+ self.log.warning('%s:%d: Function %s redefined. Previously defined on line %d',
+ filename, linen, name, prev)
+
+ # Get the start symbol
+ def get_start(self):
+ self.start = self.pdict.get('start')
+
+ # Validate the start symbol
+ def validate_start(self):
+ if self.start is not None:
+ if not isinstance(self.start, string_types):
+ self.log.error("'start' must be a string")
+
+ # Look for error handler
+ def get_error_func(self):
+ self.error_func = self.pdict.get('p_error')
+
+ # Validate the error function
+ def validate_error_func(self):
+ if self.error_func:
+ if isinstance(self.error_func, types.FunctionType):
+ ismethod = 0
+ elif isinstance(self.error_func, types.MethodType):
+ ismethod = 1
+ else:
+ self.log.error("'p_error' defined, but is not a function or method")
+ self.error = True
+ return
+
+ eline = self.error_func.__code__.co_firstlineno
+ efile = self.error_func.__code__.co_filename
+ module = inspect.getmodule(self.error_func)
+ self.modules.add(module)
+
+ argcount = self.error_func.__code__.co_argcount - ismethod
+ if argcount != 1:
+ self.log.error('%s:%d: p_error() requires 1 argument', efile, eline)
+ self.error = True
+
+ # Get the tokens map
+ def get_tokens(self):
+ tokens = self.pdict.get('tokens')
+ if not tokens:
+ self.log.error('No token list is defined')
+ self.error = True
+ return
+
+ if not isinstance(tokens, (list, tuple)):
+ self.log.error('tokens must be a list or tuple')
+ self.error = True
+ return
+
+ if not tokens:
+ self.log.error('tokens is empty')
+ self.error = True
+ return
+
+ self.tokens = tokens
+
+ # Validate the tokens
+ def validate_tokens(self):
+ # Validate the tokens.
+ if 'error' in self.tokens:
+ self.log.error("Illegal token name 'error'. Is a reserved word")
+ self.error = True
+ return
+
+ terminals = set()
+ for n in self.tokens:
+ if n in terminals:
+ self.log.warning('Token %r multiply defined', n)
+ terminals.add(n)
+
+ # Get the precedence map (if any)
+ def get_precedence(self):
+ self.prec = self.pdict.get('precedence')
+
+ # Validate and parse the precedence map
+ def validate_precedence(self):
+ preclist = []
+ if self.prec:
+ if not isinstance(self.prec, (list, tuple)):
+ self.log.error('precedence must be a list or tuple')
+ self.error = True
+ return
+ for level, p in enumerate(self.prec):
+ if not isinstance(p, (list, tuple)):
+ self.log.error('Bad precedence table')
+ self.error = True
+ return
+
+ if len(p) < 2:
+ self.log.error('Malformed precedence entry %s. Must be (assoc, term, ..., term)', p)
+ self.error = True
+ return
+ assoc = p[0]
+ if not isinstance(assoc, string_types):
+ self.log.error('precedence associativity must be a string')
+ self.error = True
+ return
+ for term in p[1:]:
+ if not isinstance(term, string_types):
+ self.log.error('precedence items must be strings')
+ self.error = True
+ return
+ preclist.append((term, assoc, level+1))
+ self.preclist = preclist
+
+ # Get all p_functions from the grammar
+ def get_pfunctions(self):
+ p_functions = []
+ for name, item in self.pdict.items():
+ if not name.startswith('p_') or name == 'p_error':
+ continue
+ if isinstance(item, (types.FunctionType, types.MethodType)):
+ line = getattr(item, 'co_firstlineno', item.__code__.co_firstlineno)
+ module = inspect.getmodule(item)
+ p_functions.append((line, module, name, item.__doc__))
+
+ # Sort all of the actions by line number; make sure to stringify
+ # modules to make them sortable, since `line` may not uniquely sort all
+ # p functions
+ p_functions.sort(key=lambda p_function: (
+ p_function[0],
+ str(p_function[1]),
+ p_function[2],
+ p_function[3]))
+ self.pfuncs = p_functions
+
+ # Validate all of the p_functions
+ def validate_pfunctions(self):
+ grammar = []
+ # Check for non-empty symbols
+ if len(self.pfuncs) == 0:
+ self.log.error('no rules of the form p_rulename are defined')
+ self.error = True
+ return
+
+ for line, module, name, doc in self.pfuncs:
+ file = inspect.getsourcefile(module)
+ func = self.pdict[name]
+ if isinstance(func, types.MethodType):
+ reqargs = 2
+ else:
+ reqargs = 1
+ if func.__code__.co_argcount > reqargs:
+ self.log.error('%s:%d: Rule %r has too many arguments', file, line, func.__name__)
+ self.error = True
+ elif func.__code__.co_argcount < reqargs:
+ self.log.error('%s:%d: Rule %r requires an argument', file, line, func.__name__)
+ self.error = True
+ elif not func.__doc__:
+ self.log.warning('%s:%d: No documentation string specified in function %r (ignored)',
+ file, line, func.__name__)
+ else:
+ try:
+ parsed_g = parse_grammar(doc, file, line)
+ for g in parsed_g:
+ grammar.append((name, g))
+ except SyntaxError as e:
+ self.log.error(str(e))
+ self.error = True
+
+ # Looks like a valid grammar rule
+ # Mark the file in which defined.
+ self.modules.add(module)
+
+ # Secondary validation step that looks for p_ definitions that are not functions
+ # or functions that look like they might be grammar rules.
+
+ for n, v in self.pdict.items():
+ if n.startswith('p_') and isinstance(v, (types.FunctionType, types.MethodType)):
+ continue
+ if n.startswith('t_'):
+ continue
+ if n.startswith('p_') and n != 'p_error':
+ self.log.warning('%r not defined as a function', n)
+ if ((isinstance(v, types.FunctionType) and v.__code__.co_argcount == 1) or
+ (isinstance(v, types.MethodType) and v.__func__.__code__.co_argcount == 2)):
+ if v.__doc__:
+ try:
+ doc = v.__doc__.split(' ')
+ if doc[1] == ':':
+ self.log.warning('%s:%d: Possible grammar rule %r defined without p_ prefix',
+ v.__code__.co_filename, v.__code__.co_firstlineno, n)
+ except IndexError:
+ pass
+
+ self.grammar = grammar
+
+# -----------------------------------------------------------------------------
+# yacc(module)
+#
+# Build a parser
+# -----------------------------------------------------------------------------
+
+def yacc(method='LALR', debug=yaccdebug, module=None, tabmodule=tab_module, start=None,
+ check_recursion=True, optimize=False, write_tables=True, debugfile=debug_file,
+ outputdir=None, debuglog=None, errorlog=None, picklefile=None):
+
+ if tabmodule is None:
+ tabmodule = tab_module
+
+ # Reference to the parsing method of the last built parser
+ global parse
+
+ # If pickling is enabled, table files are not created
+ if picklefile:
+ write_tables = 0
+
+ if errorlog is None:
+ errorlog = PlyLogger(sys.stderr)
+
+ # Get the module dictionary used for the parser
+ if module:
+ _items = [(k, getattr(module, k)) for k in dir(module)]
+ pdict = dict(_items)
+ # If no __file__ attribute is available, try to obtain it from the __module__ instead
+ if '__file__' not in pdict:
+ pdict['__file__'] = sys.modules[pdict['__module__']].__file__
+ else:
+ pdict = get_caller_module_dict(2)
+
+ if outputdir is None:
+ # If no output directory is set, the location of the output files
+ # is determined according to the following rules:
+ # - If tabmodule specifies a package, files go into that package directory
+ # - Otherwise, files go in the same directory as the specifying module
+ if isinstance(tabmodule, types.ModuleType):
+ srcfile = tabmodule.__file__
+ else:
+ if '.' not in tabmodule:
+ srcfile = pdict['__file__']
+ else:
+ parts = tabmodule.split('.')
+ pkgname = '.'.join(parts[:-1])
+ exec('import %s' % pkgname)
+ srcfile = getattr(sys.modules[pkgname], '__file__', '')
+ outputdir = os.path.dirname(srcfile)
+
+ # Determine if the module providing the parser is part of a package.
+ # If so, fix the tabmodule setting so that tables load correctly
+ pkg = pdict.get('__package__')
+ if pkg and isinstance(tabmodule, str):
+ if '.' not in tabmodule:
+ tabmodule = pkg + '.' + tabmodule
+
+
+
+ # Set start symbol if it's specified directly using an argument
+ if start is not None:
+ pdict['start'] = start
+
+ # Collect parser information from the dictionary
+ pinfo = ParserReflect(pdict, log=errorlog)
+ pinfo.get_all()
+
+ if pinfo.error:
+ raise YaccError('Unable to build parser')
+
+ # Check signature against table files (if any)
+ signature = pinfo.signature()
+
+ # Read the tables
+ try:
+ lr = LRTable()
+ if picklefile:
+ read_signature = lr.read_pickle(picklefile)
+ else:
+ read_signature = lr.read_table(tabmodule)
+ if optimize or (read_signature == signature):
+ try:
+ lr.bind_callables(pinfo.pdict)
+ parser = LRParser(lr, pinfo.error_func)
+ parse = parser.parse
+ return parser
+ except Exception as e:
+ errorlog.warning('There was a problem loading the table file: %r', e)
+ except VersionError as e:
+ errorlog.warning(str(e))
+ except ImportError:
+ pass
+
+ if debuglog is None:
+ if debug:
+ try:
+ debuglog = PlyLogger(open(os.path.join(outputdir, debugfile), 'w'))
+ except IOError as e:
+ errorlog.warning("Couldn't open %r. %s" % (debugfile, e))
+ debuglog = NullLogger()
+ else:
+ debuglog = NullLogger()
+
+ debuglog.info('Created by PLY version %s (http://www.dabeaz.com/ply)', __version__)
+
+ errors = False
+
+ # Validate the parser information
+ if pinfo.validate_all():
+ raise YaccError('Unable to build parser')
+
+ if not pinfo.error_func:
+ errorlog.warning('no p_error() function is defined')
+
+ # Create a grammar object
+ grammar = Grammar(pinfo.tokens)
+
+ # Set precedence level for terminals
+ for term, assoc, level in pinfo.preclist:
+ try:
+ grammar.set_precedence(term, assoc, level)
+ except GrammarError as e:
+ errorlog.warning('%s', e)
+
+ # Add productions to the grammar
+ for funcname, gram in pinfo.grammar:
+ file, line, prodname, syms = gram
+ try:
+ grammar.add_production(prodname, syms, funcname, file, line)
+ except GrammarError as e:
+ errorlog.error('%s', e)
+ errors = True
+
+ # Set the grammar start symbols
+ try:
+ if start is None:
+ grammar.set_start(pinfo.start)
+ else:
+ grammar.set_start(start)
+ except GrammarError as e:
+ errorlog.error(str(e))
+ errors = True
+
+ if errors:
+ raise YaccError('Unable to build parser')
+
+ # Verify the grammar structure
+ undefined_symbols = grammar.undefined_symbols()
+ for sym, prod in undefined_symbols:
+ errorlog.error('%s:%d: Symbol %r used, but not defined as a token or a rule', prod.file, prod.line, sym)
+ errors = True
+
+ unused_terminals = grammar.unused_terminals()
+ if unused_terminals:
+ debuglog.info('')
+ debuglog.info('Unused terminals:')
+ debuglog.info('')
+ for term in unused_terminals:
+ errorlog.warning('Token %r defined, but not used', term)
+ debuglog.info(' %s', term)
+
+ # Print out all productions to the debug log
+ if debug:
+ debuglog.info('')
+ debuglog.info('Grammar')
+ debuglog.info('')
+ for n, p in enumerate(grammar.Productions):
+ debuglog.info('Rule %-5d %s', n, p)
+
+ # Find unused non-terminals
+ unused_rules = grammar.unused_rules()
+ for prod in unused_rules:
+ errorlog.warning('%s:%d: Rule %r defined, but not used', prod.file, prod.line, prod.name)
+
+ if len(unused_terminals) == 1:
+ errorlog.warning('There is 1 unused token')
+ if len(unused_terminals) > 1:
+ errorlog.warning('There are %d unused tokens', len(unused_terminals))
+
+ if len(unused_rules) == 1:
+ errorlog.warning('There is 1 unused rule')
+ if len(unused_rules) > 1:
+ errorlog.warning('There are %d unused rules', len(unused_rules))
+
+ if debug:
+ debuglog.info('')
+ debuglog.info('Terminals, with rules where they appear')
+ debuglog.info('')
+ terms = list(grammar.Terminals)
+ terms.sort()
+ for term in terms:
+ debuglog.info('%-20s : %s', term, ' '.join([str(s) for s in grammar.Terminals[term]]))
+
+ debuglog.info('')
+ debuglog.info('Nonterminals, with rules where they appear')
+ debuglog.info('')
+ nonterms = list(grammar.Nonterminals)
+ nonterms.sort()
+ for nonterm in nonterms:
+ debuglog.info('%-20s : %s', nonterm, ' '.join([str(s) for s in grammar.Nonterminals[nonterm]]))
+ debuglog.info('')
+
+ if check_recursion:
+ unreachable = grammar.find_unreachable()
+ for u in unreachable:
+ errorlog.warning('Symbol %r is unreachable', u)
+
+ infinite = grammar.infinite_cycles()
+ for inf in infinite:
+ errorlog.error('Infinite recursion detected for symbol %r', inf)
+ errors = True
+
+ unused_prec = grammar.unused_precedence()
+ for term, assoc in unused_prec:
+ errorlog.error('Precedence rule %r defined for unknown symbol %r', assoc, term)
+ errors = True
+
+ if errors:
+ raise YaccError('Unable to build parser')
+
+ # Run the LRGeneratedTable on the grammar
+ if debug:
+ errorlog.debug('Generating %s tables', method)
+
+ lr = LRGeneratedTable(grammar, method, debuglog)
+
+ if debug:
+ num_sr = len(lr.sr_conflicts)
+
+ # Report shift/reduce and reduce/reduce conflicts
+ if num_sr == 1:
+ errorlog.warning('1 shift/reduce conflict')
+ elif num_sr > 1:
+ errorlog.warning('%d shift/reduce conflicts', num_sr)
+
+ num_rr = len(lr.rr_conflicts)
+ if num_rr == 1:
+ errorlog.warning('1 reduce/reduce conflict')
+ elif num_rr > 1:
+ errorlog.warning('%d reduce/reduce conflicts', num_rr)
+
+ # Write out conflicts to the output file
+ if debug and (lr.sr_conflicts or lr.rr_conflicts):
+ debuglog.warning('')
+ debuglog.warning('Conflicts:')
+ debuglog.warning('')
+
+ for state, tok, resolution in lr.sr_conflicts:
+ debuglog.warning('shift/reduce conflict for %s in state %d resolved as %s', tok, state, resolution)
+
+ already_reported = set()
+ for state, rule, rejected in lr.rr_conflicts:
+ if (state, id(rule), id(rejected)) in already_reported:
+ continue
+ debuglog.warning('reduce/reduce conflict in state %d resolved using rule (%s)', state, rule)
+ debuglog.warning('rejected rule (%s) in state %d', rejected, state)
+ errorlog.warning('reduce/reduce conflict in state %d resolved using rule (%s)', state, rule)
+ errorlog.warning('rejected rule (%s) in state %d', rejected, state)
+ already_reported.add((state, id(rule), id(rejected)))
+
+ warned_never = []
+ for state, rule, rejected in lr.rr_conflicts:
+ if not rejected.reduced and (rejected not in warned_never):
+ debuglog.warning('Rule (%s) is never reduced', rejected)
+ errorlog.warning('Rule (%s) is never reduced', rejected)
+ warned_never.append(rejected)
+
+ # Write the table file if requested
+ if write_tables:
+ try:
+ lr.write_table(tabmodule, outputdir, signature)
+ except IOError as e:
+ errorlog.warning("Couldn't create %r. %s" % (tabmodule, e))
+
+ # Write a pickled version of the tables
+ if picklefile:
+ try:
+ lr.pickle_table(picklefile, signature)
+ except IOError as e:
+ errorlog.warning("Couldn't create %r. %s" % (picklefile, e))
+
+ # Build the parser
+ lr.bind_callables(pinfo.pdict)
+ parser = LRParser(lr, pinfo.error_func)
+
+ parse = parser.parse
+ return parser
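
For context, the function ending above is PLY's public yacc() entry point. A minimal sketch of how a grammar module typically drives it (the token and rule names below are invented for illustration; only the ply.lex / ply.yacc calls are the real API):

import ply.lex as lex
import ply.yacc as yacc

# Hypothetical one-operator grammar, just to exercise the builder above.
tokens = ('NUMBER', 'PLUS')

t_PLUS = r'\+'
t_ignore = ' \t'

def t_NUMBER(t):
    r'\d+'
    t.value = int(t.value)
    return t

def t_error(t):
    t.lexer.skip(1)

def p_expr_plus(p):
    'expr : expr PLUS NUMBER'
    p[0] = p[1] + p[3]

def p_expr_number(p):
    'expr : NUMBER'
    p[0] = p[1]

def p_error(p):
    print('Syntax error at %r' % (p,))

lexer = lex.lex()
# debug=True routes the grammar/table diagnostics shown above through debuglog
# and writes parser.out; write_tables=False skips emitting parsetab.py.
parser = yacc.yacc(debug=True, write_tables=False)
print(parser.parse('1+2+3', lexer=lexer))   # -> 6
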
diff --git a/lib/pycparser/ply/ygen.py b/lib/pycparser/ply/ygen.py
new file mode 100644
index 0000000..acf5ca1
--- /dev/null
+++ b/lib/pycparser/ply/ygen.py
@@ -0,0 +1,74 @@
+# ply: ygen.py
+#
+# This is a support program that auto-generates different versions of the YACC parsing
+# function with different features removed for the purposes of performance.
+#
+# Users should edit the method LRParser.parsedebug() in yacc.py. The source code
+# for that method is then used to create the other methods. See the comments in
+# yacc.py for further details.
+
+import os.path
+import shutil
+
+def get_source_range(lines, tag):
+ srclines = enumerate(lines)
+ start_tag = '#--! %s-start' % tag
+ end_tag = '#--! %s-end' % tag
+
+ for start_index, line in srclines:
+ if line.strip().startswith(start_tag):
+ break
+
+ for end_index, line in srclines:
+ if line.strip().endswith(end_tag):
+ break
+
+ return (start_index + 1, end_index)
+
+def filter_section(lines, tag):
+ filtered_lines = []
+ include = True
+ tag_text = '#--! %s' % tag
+ for line in lines:
+ if line.strip().startswith(tag_text):
+ include = not include
+ elif include:
+ filtered_lines.append(line)
+ return filtered_lines
+
+def main():
+ dirname = os.path.dirname(__file__)
+ shutil.copy2(os.path.join(dirname, 'yacc.py'), os.path.join(dirname, 'yacc.py.bak'))
+ with open(os.path.join(dirname, 'yacc.py'), 'r') as f:
+ lines = f.readlines()
+
+ parse_start, parse_end = get_source_range(lines, 'parsedebug')
+ parseopt_start, parseopt_end = get_source_range(lines, 'parseopt')
+ parseopt_notrack_start, parseopt_notrack_end = get_source_range(lines, 'parseopt-notrack')
+
+ # Get the original source
+ orig_lines = lines[parse_start:parse_end]
+
+ # Filter the DEBUG sections out
+ parseopt_lines = filter_section(orig_lines, 'DEBUG')
+
+ # Filter the TRACKING sections out
+ parseopt_notrack_lines = filter_section(parseopt_lines, 'TRACKING')
+
+ # Replace the parser source sections with updated versions
+ lines[parseopt_notrack_start:parseopt_notrack_end] = parseopt_notrack_lines
+ lines[parseopt_start:parseopt_end] = parseopt_lines
+
+ lines = [line.rstrip()+'\n' for line in lines]
+ with open(os.path.join(dirname, 'yacc.py'), 'w') as f:
+ f.writelines(lines)
+
+ print('Updated yacc.py')
+
+if __name__ == '__main__':
+ main()
+
+
+
+
+
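
To make the tag convention that get_source_range() and filter_section() rely on concrete, here is a small self-contained sketch (the tagged lines are invented; the import path assumes the vendored package location and may differ):

from pycparser.ply.ygen import get_source_range, filter_section

# Miniature stand-in for the tagged regions ygen looks for in yacc.py.
lines = [
    "def parsedebug(self):\n",
    "    #--! parsedebug-start\n",
    "    x = 1\n",
    "    #--! DEBUG\n",
    "    log('only present in the debug variant')\n",
    "    #--! DEBUG\n",
    "    y = 2\n",
    "    #--! parsedebug-end\n",
]

start, end = get_source_range(lines, 'parsedebug')   # -> (2, 7): body between the markers
body = lines[start:end]
# Dropping the DEBUG-tagged region yields the stripped-down variant of the body.
print(filter_section(body, 'DEBUG'))                 # ['    x = 1\n', '    y = 2\n']
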
diff --git a/lib/pycparser/plyparser.py b/lib/pycparser/plyparser.py
new file mode 100644
index 0000000..b8f4c43
--- /dev/null
+++ b/lib/pycparser/plyparser.py
@@ -0,0 +1,133 @@
+#-----------------------------------------------------------------
+# plyparser.py
+#
+# PLYParser class and other utilities for simplifying programming
+# parsers with PLY
+#
+# Eli Bendersky [https://eli.thegreenplace.net/]
+# License: BSD
+#-----------------------------------------------------------------
+
+import warnings
+
+class Coord(object):
+ """ Coordinates of a syntactic element. Consists of:
+ - File name
+ - Line number
+ - (optional) column number, for the Lexer
+ """
+ __slots__ = ('file', 'line', 'column', '__weakref__')
+ def __init__(self, file, line, column=None):
+ self.file = file
+ self.line = line
+ self.column = column
+
+ def __str__(self):
+ str = "%s:%s" % (self.file, self.line)
+ if self.column: str += ":%s" % self.column
+ return str
+
+
+class ParseError(Exception): pass
+
+
+class PLYParser(object):
+ def _create_opt_rule(self, rulename):
+ """ Given a rule name, creates an optional ply.yacc rule
+ for it. The name of the optional rule is
+ <rulename>_opt
+ """
+ optname = rulename + '_opt'
+
+ def optrule(self, p):
+ p[0] = p[1]
+
+ optrule.__doc__ = '%s : empty\n| %s' % (optname, rulename)
+ optrule.__name__ = 'p_%s' % optname
+ setattr(self.__class__, optrule.__name__, optrule)
+
+ def _coord(self, lineno, column=None):
+ return Coord(
+ file=self.clex.filename,
+ line=lineno,
+ column=column)
+
+ def _token_coord(self, p, token_idx):
+ """ Returns the coordinates for the YaccProduction object 'p' indexed
+ with 'token_idx'. The coordinate includes the 'lineno' and
+ 'column'. Both follow the lex semantic, starting from 1.
+ """
+ last_cr = p.lexer.lexer.lexdata.rfind('\n', 0, p.lexpos(token_idx))
+ if last_cr < 0:
+ last_cr = -1
+ column = (p.lexpos(token_idx) - (last_cr))
+ return self._coord(p.lineno(token_idx), column)
+
+ def _parse_error(self, msg, coord):
+ raise ParseError("%s: %s" % (coord, msg))
+
+
+def parameterized(*params):
+ """ Decorator to create parameterized rules.
+
+ Parameterized rule methods must be named starting with 'p_' and contain
+ 'xxx', and their docstrings may contain 'xxx' and 'yyy'. These will be
+ replaced by the given parameter tuples. For example, ``p_xxx_rule()`` with
+ docstring 'xxx_rule : yyy' when decorated with
+ ``@parameterized(('id', 'ID'))`` produces ``p_id_rule()`` with the docstring
+ 'id_rule : ID'. Using multiple tuples produces multiple rules.
+ """
+ def decorate(rule_func):
+ rule_func._params = params
+ return rule_func
+ return decorate
+
+
+def template(cls):
+ """ Class decorator to generate rules from parameterized rule templates.
+
+ See `parameterized` for more information on parameterized rules.
+ """
+ issued_nodoc_warning = False
+ for attr_name in dir(cls):
+ if attr_name.startswith('p_'):
+ method = getattr(cls, attr_name)
+ if hasattr(method, '_params'):
+ # Remove the template method
+ delattr(cls, attr_name)
+ # Create parameterized rules from this method; only run this if
+ # the method has a docstring. This is to address an issue when
+ # pycparser's users are installed in -OO mode which strips
+ # docstrings away.
+ # See: https://github.com/eliben/pycparser/pull/198/ and
+ # https://github.com/eliben/pycparser/issues/197
+ # for discussion.
+ if method.__doc__ is not None:
+ _create_param_rules(cls, method)
+ elif not issued_nodoc_warning:
+ warnings.warn(
+ 'parsing methods must have __doc__ for pycparser to work properly',
+ RuntimeWarning,
+ stacklevel=2)
+ issued_nodoc_warning = True
+ return cls
+
+
+def _create_param_rules(cls, func):
+ """ Create ply.yacc rules based on a parameterized rule function
+
+ Generates new methods (one per each pair of parameters) based on the
+ template rule function `func`, and attaches them to `cls`. The rule
+ function's parameters must be accessible via its `_params` attribute.
+ """
+ for xxx, yyy in func._params:
+ # Use the template method's body for each new method
+ def param_rule(self, p):
+ func(self, p)
+
+ # Substitute in the params for the grammar rule and function name
+ param_rule.__doc__ = func.__doc__.replace('xxx', xxx).replace('yyy', yyy)
+ param_rule.__name__ = func.__name__.replace('xxx', xxx)
+
+ # Attach the new method to the class
+ setattr(cls, param_rule.__name__, param_rule)
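
As a usage sketch of the two decorators above (the parser class and rule names are invented; only the decorator mechanics come from the code shown):

from pycparser.plyparser import PLYParser, parameterized, template

@template
class DemoParser(PLYParser):
    # One template method; @template expands it into one concrete rule per
    # parameter tuple and then removes the template itself.
    @parameterized(('id', 'ID'), ('typeid', 'TYPEID'))
    def p_xxx_name(self, p):
        'xxx_name : yyy'
        p[0] = p[1]

# After class creation the template is gone and two real rules exist:
print(DemoParser.p_id_name.__doc__)       # -> id_name : ID
print(DemoParser.p_typeid_name.__doc__)   # -> typeid_name : TYPEID
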
diff --git a/lib/pycparser/yacctab.py b/lib/pycparser/yacctab.py
new file mode 100644
index 0000000..0622c36
--- /dev/null
+++ b/lib/pycparser/yacctab.py
@@ -0,0 +1,366 @@
+
+# yacctab.py
+# This file is automatically generated. Do not edit.
+_tabversion = '3.10'
+
+_lr_method = 'LALR'
+
+_lr_signature = 'translation_unit_or_emptyleftLORleftLANDleftORleftXORleftANDleftEQNEleftGTGELTLEleftRSHIFTLSHIFTleftPLUSMINUSleftTIMESDIVIDEMODAUTO BREAK CASE CHAR CONST CONTINUE DEFAULT DO DOUBLE ELSE ENUM EXTERN FLOAT FOR GOTO IF INLINE INT LONG REGISTER OFFSETOF RESTRICT RETURN SHORT SIGNED SIZEOF STATIC STRUCT SWITCH TYPEDEF UNION UNSIGNED VOID VOLATILE WHILE __INT128 _BOOL _COMPLEX _NORETURN _THREAD_LOCAL _STATIC_ASSERT _ATOMIC _ALIGNOF _ALIGNAS ID TYPEID INT_CONST_DEC INT_CONST_OCT INT_CONST_HEX INT_CONST_BIN INT_CONST_CHAR FLOAT_CONST HEX_FLOAT_CONST CHAR_CONST WCHAR_CONST U8CHAR_CONST U16CHAR_CONST U32CHAR_CONST STRING_LITERAL WSTRING_LITERAL U8STRING_LITERAL U16STRING_LITERAL U32STRING_LITERAL PLUS MINUS TIMES DIVIDE MOD OR AND NOT XOR LSHIFT RSHIFT LOR LAND LNOT LT LE GT GE EQ NE EQUALS TIMESEQUAL DIVEQUAL MODEQUAL PLUSEQUAL MINUSEQUAL LSHIFTEQUAL RSHIFTEQUAL ANDEQUAL XOREQUAL OREQUAL PLUSPLUS MINUSMINUS ARROW CONDOP LPAREN RPAREN LBRACKET RBRACKET LBRACE RBRACE COMMA PERIOD SEMI COLON ELLIPSIS PPHASH PPPRAGMA PPPRAGMASTRabstract_declarator_opt : empty\n| abstract_declaratorassignment_expression_opt : empty\n| assignment_expressionblock_item_list_opt : empty\n| block_item_listdeclaration_list_opt : empty\n| declaration_listdeclaration_specifiers_no_type_opt : empty\n| declaration_specifiers_no_typedesignation_opt : empty\n| designationexpression_opt : empty\n| expressionid_init_declarator_list_opt : empty\n| id_init_declarator_listidentifier_list_opt : empty\n| identifier_listinit_declarator_list_opt : empty\n| init_declarator_listinitializer_list_opt : empty\n| initializer_listparameter_type_list_opt : empty\n| parameter_type_liststruct_declarator_list_opt : empty\n| struct_declarator_listtype_qualifier_list_opt : empty\n| type_qualifier_list direct_id_declarator : ID\n direct_id_declarator : LPAREN id_declarator RPAREN\n direct_id_declarator : direct_id_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_id_declarator : direct_id_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET\n | direct_id_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET\n direct_id_declarator : direct_id_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET\n direct_id_declarator : direct_id_declarator LPAREN parameter_type_list RPAREN\n | direct_id_declarator LPAREN identifier_list_opt RPAREN\n direct_typeid_declarator : TYPEID\n direct_typeid_declarator : LPAREN typeid_declarator RPAREN\n direct_typeid_declarator : direct_typeid_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_typeid_declarator : direct_typeid_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET\n | direct_typeid_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET\n direct_typeid_declarator : direct_typeid_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET\n direct_typeid_declarator : direct_typeid_declarator LPAREN parameter_type_list RPAREN\n | direct_typeid_declarator LPAREN identifier_list_opt RPAREN\n direct_typeid_noparen_declarator : TYPEID\n direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET\n | direct_typeid_noparen_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET\n 
direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET\n direct_typeid_noparen_declarator : direct_typeid_noparen_declarator LPAREN parameter_type_list RPAREN\n | direct_typeid_noparen_declarator LPAREN identifier_list_opt RPAREN\n id_declarator : direct_id_declarator\n id_declarator : pointer direct_id_declarator\n typeid_declarator : direct_typeid_declarator\n typeid_declarator : pointer direct_typeid_declarator\n typeid_noparen_declarator : direct_typeid_noparen_declarator\n typeid_noparen_declarator : pointer direct_typeid_noparen_declarator\n translation_unit_or_empty : translation_unit\n | empty\n translation_unit : external_declaration\n translation_unit : translation_unit external_declaration\n external_declaration : function_definition\n external_declaration : declaration\n external_declaration : pp_directive\n | pppragma_directive\n external_declaration : SEMI\n external_declaration : static_assert\n static_assert : _STATIC_ASSERT LPAREN constant_expression COMMA unified_string_literal RPAREN\n | _STATIC_ASSERT LPAREN constant_expression RPAREN\n pp_directive : PPHASH\n pppragma_directive : PPPRAGMA\n | PPPRAGMA PPPRAGMASTR\n function_definition : id_declarator declaration_list_opt compound_statement\n function_definition : declaration_specifiers id_declarator declaration_list_opt compound_statement\n statement : labeled_statement\n | expression_statement\n | compound_statement\n | selection_statement\n | iteration_statement\n | jump_statement\n | pppragma_directive\n | static_assert\n pragmacomp_or_statement : pppragma_directive statement\n | statement\n decl_body : declaration_specifiers init_declarator_list_opt\n | declaration_specifiers_no_type id_init_declarator_list_opt\n declaration : decl_body SEMI\n declaration_list : declaration\n | declaration_list declaration\n declaration_specifiers_no_type : type_qualifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : storage_class_specifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : function_specifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : atomic_specifier declaration_specifiers_no_type_opt\n declaration_specifiers_no_type : alignment_specifier declaration_specifiers_no_type_opt\n declaration_specifiers : declaration_specifiers type_qualifier\n declaration_specifiers : declaration_specifiers storage_class_specifier\n declaration_specifiers : declaration_specifiers function_specifier\n declaration_specifiers : declaration_specifiers type_specifier_no_typeid\n declaration_specifiers : type_specifier\n declaration_specifiers : declaration_specifiers_no_type type_specifier\n declaration_specifiers : declaration_specifiers alignment_specifier\n storage_class_specifier : AUTO\n | REGISTER\n | STATIC\n | EXTERN\n | TYPEDEF\n | _THREAD_LOCAL\n function_specifier : INLINE\n | _NORETURN\n type_specifier_no_typeid : VOID\n | _BOOL\n | CHAR\n | SHORT\n | INT\n | LONG\n | FLOAT\n | DOUBLE\n | _COMPLEX\n | SIGNED\n | UNSIGNED\n | __INT128\n type_specifier : typedef_name\n | enum_specifier\n | struct_or_union_specifier\n | type_specifier_no_typeid\n | atomic_specifier\n atomic_specifier : _ATOMIC LPAREN type_name RPAREN\n type_qualifier : CONST\n | RESTRICT\n | VOLATILE\n | _ATOMIC\n init_declarator_list : init_declarator\n | init_declarator_list COMMA init_declarator\n init_declarator : declarator\n | declarator EQUALS initializer\n id_init_declarator_list : id_init_declarator\n | id_init_declarator_list 
COMMA init_declarator\n id_init_declarator : id_declarator\n | id_declarator EQUALS initializer\n specifier_qualifier_list : specifier_qualifier_list type_specifier_no_typeid\n specifier_qualifier_list : specifier_qualifier_list type_qualifier\n specifier_qualifier_list : type_specifier\n specifier_qualifier_list : type_qualifier_list type_specifier\n specifier_qualifier_list : alignment_specifier\n specifier_qualifier_list : specifier_qualifier_list alignment_specifier\n struct_or_union_specifier : struct_or_union ID\n | struct_or_union TYPEID\n struct_or_union_specifier : struct_or_union brace_open struct_declaration_list brace_close\n | struct_or_union brace_open brace_close\n struct_or_union_specifier : struct_or_union ID brace_open struct_declaration_list brace_close\n | struct_or_union ID brace_open brace_close\n | struct_or_union TYPEID brace_open struct_declaration_list brace_close\n | struct_or_union TYPEID brace_open brace_close\n struct_or_union : STRUCT\n | UNION\n struct_declaration_list : struct_declaration\n | struct_declaration_list struct_declaration\n struct_declaration : specifier_qualifier_list struct_declarator_list_opt SEMI\n struct_declaration : SEMI\n struct_declaration : pppragma_directive\n struct_declarator_list : struct_declarator\n | struct_declarator_list COMMA struct_declarator\n struct_declarator : declarator\n struct_declarator : declarator COLON constant_expression\n | COLON constant_expression\n enum_specifier : ENUM ID\n | ENUM TYPEID\n enum_specifier : ENUM brace_open enumerator_list brace_close\n enum_specifier : ENUM ID brace_open enumerator_list brace_close\n | ENUM TYPEID brace_open enumerator_list brace_close\n enumerator_list : enumerator\n | enumerator_list COMMA\n | enumerator_list COMMA enumerator\n alignment_specifier : _ALIGNAS LPAREN type_name RPAREN\n | _ALIGNAS LPAREN constant_expression RPAREN\n enumerator : ID\n | ID EQUALS constant_expression\n declarator : id_declarator\n | typeid_declarator\n pointer : TIMES type_qualifier_list_opt\n | TIMES type_qualifier_list_opt pointer\n type_qualifier_list : type_qualifier\n | type_qualifier_list type_qualifier\n parameter_type_list : parameter_list\n | parameter_list COMMA ELLIPSIS\n parameter_list : parameter_declaration\n | parameter_list COMMA parameter_declaration\n parameter_declaration : declaration_specifiers id_declarator\n | declaration_specifiers typeid_noparen_declarator\n parameter_declaration : declaration_specifiers abstract_declarator_opt\n identifier_list : identifier\n | identifier_list COMMA identifier\n initializer : assignment_expression\n initializer : brace_open initializer_list_opt brace_close\n | brace_open initializer_list COMMA brace_close\n initializer_list : designation_opt initializer\n | initializer_list COMMA designation_opt initializer\n designation : designator_list EQUALS\n designator_list : designator\n | designator_list designator\n designator : LBRACKET constant_expression RBRACKET\n | PERIOD identifier\n type_name : specifier_qualifier_list abstract_declarator_opt\n abstract_declarator : pointer\n abstract_declarator : pointer direct_abstract_declarator\n abstract_declarator : direct_abstract_declarator\n direct_abstract_declarator : LPAREN abstract_declarator RPAREN direct_abstract_declarator : direct_abstract_declarator LBRACKET assignment_expression_opt RBRACKET\n direct_abstract_declarator : LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET\n direct_abstract_declarator : direct_abstract_declarator LBRACKET TIMES RBRACKET\n 
direct_abstract_declarator : LBRACKET TIMES RBRACKET\n direct_abstract_declarator : direct_abstract_declarator LPAREN parameter_type_list_opt RPAREN\n direct_abstract_declarator : LPAREN parameter_type_list_opt RPAREN\n block_item : declaration\n | statement\n block_item_list : block_item\n | block_item_list block_item\n compound_statement : brace_open block_item_list_opt brace_close labeled_statement : ID COLON pragmacomp_or_statement labeled_statement : CASE constant_expression COLON pragmacomp_or_statement labeled_statement : DEFAULT COLON pragmacomp_or_statement selection_statement : IF LPAREN expression RPAREN pragmacomp_or_statement selection_statement : IF LPAREN expression RPAREN statement ELSE pragmacomp_or_statement selection_statement : SWITCH LPAREN expression RPAREN pragmacomp_or_statement iteration_statement : WHILE LPAREN expression RPAREN pragmacomp_or_statement iteration_statement : DO pragmacomp_or_statement WHILE LPAREN expression RPAREN SEMI iteration_statement : FOR LPAREN expression_opt SEMI expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement iteration_statement : FOR LPAREN declaration expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement jump_statement : GOTO ID SEMI jump_statement : BREAK SEMI jump_statement : CONTINUE SEMI jump_statement : RETURN expression SEMI\n | RETURN SEMI\n expression_statement : expression_opt SEMI expression : assignment_expression\n | expression COMMA assignment_expression\n assignment_expression : LPAREN compound_statement RPAREN typedef_name : TYPEID assignment_expression : conditional_expression\n | unary_expression assignment_operator assignment_expression\n assignment_operator : EQUALS\n | XOREQUAL\n | TIMESEQUAL\n | DIVEQUAL\n | MODEQUAL\n | PLUSEQUAL\n | MINUSEQUAL\n | LSHIFTEQUAL\n | RSHIFTEQUAL\n | ANDEQUAL\n | OREQUAL\n constant_expression : conditional_expression conditional_expression : binary_expression\n | binary_expression CONDOP expression COLON conditional_expression\n binary_expression : cast_expression\n | binary_expression TIMES binary_expression\n | binary_expression DIVIDE binary_expression\n | binary_expression MOD binary_expression\n | binary_expression PLUS binary_expression\n | binary_expression MINUS binary_expression\n | binary_expression RSHIFT binary_expression\n | binary_expression LSHIFT binary_expression\n | binary_expression LT binary_expression\n | binary_expression LE binary_expression\n | binary_expression GE binary_expression\n | binary_expression GT binary_expression\n | binary_expression EQ binary_expression\n | binary_expression NE binary_expression\n | binary_expression AND binary_expression\n | binary_expression OR binary_expression\n | binary_expression XOR binary_expression\n | binary_expression LAND binary_expression\n | binary_expression LOR binary_expression\n cast_expression : unary_expression cast_expression : LPAREN type_name RPAREN cast_expression unary_expression : postfix_expression unary_expression : PLUSPLUS unary_expression\n | MINUSMINUS unary_expression\n | unary_operator cast_expression\n unary_expression : SIZEOF unary_expression\n | SIZEOF LPAREN type_name RPAREN\n | _ALIGNOF LPAREN type_name RPAREN\n unary_operator : AND\n | TIMES\n | PLUS\n | MINUS\n | NOT\n | LNOT\n postfix_expression : primary_expression postfix_expression : postfix_expression LBRACKET expression RBRACKET postfix_expression : postfix_expression LPAREN argument_expression_list RPAREN\n | postfix_expression LPAREN RPAREN\n postfix_expression : postfix_expression PERIOD ID\n | 
postfix_expression PERIOD TYPEID\n | postfix_expression ARROW ID\n | postfix_expression ARROW TYPEID\n postfix_expression : postfix_expression PLUSPLUS\n | postfix_expression MINUSMINUS\n postfix_expression : LPAREN type_name RPAREN brace_open initializer_list brace_close\n | LPAREN type_name RPAREN brace_open initializer_list COMMA brace_close\n primary_expression : identifier primary_expression : constant primary_expression : unified_string_literal\n | unified_wstring_literal\n primary_expression : LPAREN expression RPAREN primary_expression : OFFSETOF LPAREN type_name COMMA offsetof_member_designator RPAREN\n offsetof_member_designator : identifier\n | offsetof_member_designator PERIOD identifier\n | offsetof_member_designator LBRACKET expression RBRACKET\n argument_expression_list : assignment_expression\n | argument_expression_list COMMA assignment_expression\n identifier : ID constant : INT_CONST_DEC\n | INT_CONST_OCT\n | INT_CONST_HEX\n | INT_CONST_BIN\n | INT_CONST_CHAR\n constant : FLOAT_CONST\n | HEX_FLOAT_CONST\n constant : CHAR_CONST\n | WCHAR_CONST\n | U8CHAR_CONST\n | U16CHAR_CONST\n | U32CHAR_CONST\n unified_string_literal : STRING_LITERAL\n | unified_string_literal STRING_LITERAL\n unified_wstring_literal : WSTRING_LITERAL\n | U8STRING_LITERAL\n | U16STRING_LITERAL\n | U32STRING_LITERAL\n | unified_wstring_literal WSTRING_LITERAL\n | unified_wstring_literal U8STRING_LITERAL\n | unified_wstring_literal U16STRING_LITERAL\n | unified_wstring_literal U32STRING_LITERAL\n brace_open : LBRACE\n brace_close : RBRACE\n empty : '
+
+_lr_action_items = {'INT_CONST_CHAR':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,132,-335,-28,-182,-27,132,-337,-87,-72,-337,132,-286,-285,132,132,-283,-287,-288,132,-284,132,132,132,-336,-183,132,132,-28,-337,132,-28,-337,-337,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,132,-337,-76,-79,-82,-75,132,-77,132,132,-81,-215,-214,-80,-216,132,-78,132,132,-69,-284,132,132,-284,132,132,-244,-247,-245,-241,-242,-246,-248,132,-250,-251,-243,-249,-12,132,132,-11,132,132,132,132,-234,-233,132,-231,132,132,-217,132,-230,132,-84,-218,132,132,132,-337,-337,-198,132,132,132,-337,-284,-229,-232,132,-221,132,-83,-219,-68,132,-28,-337,132,-11,132,132,-220,132,132,132,-284,132,132,132,-337,132,-225,-224,-222,-84,132,132,132,-226,-223,132,-228,-227,]),'VOID':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[6,-337,-113,-128,6,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,6,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,6,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,6,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,6,-131,-95,-101,-97,6,-53,-126,6,-88,6,6,-93,6,-147,-335,-146,6,-167,-166,-182,-100,-126,6,-87,-90,-94,-92,-61,-72,6,-144,-142,6,6,6,-73,6,-89,6,6,6,-149,-159,-160,-156,-336,6,-183,-30,6,6,-74,6,6,6,6,-174,-175,6,-143,-140,6,-141,-145,-76,-79,-82,-75,-77,6,-81,-215,-214,-80,-216,-78,-127,6,-153,6,-151,-148,-157,-168,-69,-36,-35,6,6,6,-234,-233,6,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,6,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LBRACKET':([2,3,5,6,7,10,11,12,13,18,20,22,23,26,27,30,33,34,35,36,39,42,43,44,46,48,49,50,54,56,58,60,62,68,71,73,76,77,80,81,82,86,96,97,98,100,101,103,104,105,106,109,111,127,132,133,134,136,138,139,140,141,142,143,145,147,148,152,153,154,156,160,161,163,164,166,167,168,169,176,177,187,191,198,199,200,211,216,227,230,235,236,237,238,240,241,261,263,269,275,276,278,279,280,283,310,312,314,316,317,328,340,341,342,344,345,347,355,356,371,376,402,403,404,405,407,411,414,442,443,448,449,453,454,457,458,464,465,470,472,474,482,483,488,489,490,492,511,512,518,519,520,526,527,529,530,531,532,544,545,547,550,551,559,560,563,565,570,571,572,],[-113,-128,-124,-110,-106,-104,-107,-125,-105,-99,-109,-120,-11
5,-102,-126,-108,-238,-111,-337,-122,-129,-29,-121,-116,-112,117,-123,-117,-119,-114,-130,-118,-103,-96,-98,128,-131,-37,-95,-101,-97,117,-147,-335,-146,-167,-166,-28,-180,-182,-27,-100,-126,128,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-314,-142,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,265,-323,-312,282,-149,-336,-183,-181,-30,282,-38,373,-326,-334,-332,-331,-333,-174,-175,-298,-297,-143,-140,282,282,-141,-145,421,-312,-127,-153,-151,-148,-168,-36,-35,282,282,459,-45,-44,-43,-199,373,-296,-295,-294,-293,-292,-305,421,-152,-150,-170,-169,-31,-34,282,459,-39,-42,-202,373,-200,-290,-291,373,-213,-207,-211,-33,-32,-41,-40,-201,549,-307,-209,-208,-210,-212,-51,-50,-306,373,-299,-46,-49,-308,-300,-48,-47,-309,]),'WCHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,133,-335,-28,-182,-27,133,-337,-87,-72,-337,133,-286,-285,133,133,-283,-287,-288,133,-284,133,133,133,-336,-183,133,133,-28,-337,133,-28,-337,-337,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,133,-337,-76,-79,-82,-75,133,-77,133,133,-81,-215,-214,-80,-216,133,-78,133,133,-69,-284,133,133,-284,133,133,-244,-247,-245,-241,-242,-246,-248,133,-250,-251,-243,-249,-12,133,133,-11,133,133,133,133,-234,-233,133,-231,133,133,-217,133,-230,133,-84,-218,133,133,133,-337,-337,-198,133,133,133,-337,-284,-229,-232,133,-221,133,-83,-219,-68,133,-28,-337,133,-11,133,133,-220,133,133,133,-284,133,133,133,-337,133,-225,-224,-222,-84,133,133,133,-226,-223,133,-228,-227,]),'FLOAT_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,134,-335,-28,-182,-27,134,-337,-87,-72,-337,134,-286,-285,134,134,-283,-287,-288,134,-284,134,134,134,-336,-183,134,134,-28,-337,134,-28,-337,-337,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,134,-337,-76,-79,-82,-75,134,-77,134,134,-81,-215,-214,-80,-216,134,-78,134,134,-69,-284,134,134,-284,134,134,-244,-247,-245,-241,-242,-246,-248,134,-250,-251,-243,-249,-12,134,134,-11,134,134,134,134,-234,-233,134,-231,134,134,-217,134,-230,134,-84,-218,134,134,134,-337,-337,-198,134,134,134,-337,-284,-229,-232,134,-221,134,-83,-219,-68,134,-28,-337,134,-11,134,134,-220,134,134,134,-284,134,134,134,-337,134,-225,-224,-222,-84,134,134,134,-226,-223,134,-228,
-227,]),'MINUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,144,145,146,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,224,227,229,230,231,232,233,234,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,268,273,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,135,-335,-28,-182,-27,135,-337,-87,-72,-337,135,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-274,-314,135,-327,135,-283,-287,-325,-304,-322,-302,-255,-315,-289,245,-328,-316,-288,-329,-320,-276,-323,135,-284,135,135,-312,135,-336,-183,135,135,-28,-337,135,-28,-337,-274,-337,135,-326,135,-280,135,-277,-334,-332,-331,-333,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,135,-298,-297,135,135,-279,-278,-337,-76,-79,-82,-75,135,-77,135,135,-81,-215,-214,-80,-216,135,-78,-312,135,135,-69,-284,135,135,-284,135,135,-244,-247,-245,-241,-242,-246,-248,135,-250,-251,-243,-249,-12,135,135,-11,245,245,245,-260,245,245,245,-259,245,245,-257,-256,245,245,245,245,245,-258,-296,-295,-294,-293,-292,-305,135,135,135,135,-234,-233,135,-231,135,135,-217,135,-230,135,-84,-218,135,135,135,-337,-337,-198,135,-281,-282,135,-290,-291,135,-275,-337,-284,-229,-232,135,-221,135,-83,-219,-68,135,-28,-337,135,-11,135,135,-220,135,135,135,-284,135,135,-306,135,-337,-299,135,-225,-224,-222,-84,-300,135,135,135,-226,-223,135,-228,-227,]),'RPAREN':([2,3,5,6,7,10,11,12,13,18,20,22,23,26,27,30,33,34,35,36,39,42,43,44,46,48,49,50,54,56,58,60,62,68,71,73,76,77,80,81,82,86,96,98,100,101,103,104,105,106,107,109,111,118,125,127,129,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,157,158,159,160,161,162,163,164,166,167,168,169,176,177,178,183,187,191,198,199,200,203,207,208,209,210,211,212,213,215,216,221,222,224,225,230,232,234,235,236,237,238,240,241,261,263,266,268,269,270,271,272,273,274,275,276,277,278,279,280,281,283,294,312,314,316,317,328,340,341,342,343,344,345,346,347,348,355,356,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,408,409,411,414,415,416,417,418,422,433,439,442,443,448,449,452,453,454,457,458,460,461,462,463,464,465,468,476,478,480,482,483,486,487,489,490,492,495,501,503,507,511,512,516,517,518,519,524,525,526,527,529,530,531,532,544,545,547,551,553,556,559,560,563,565,566,567,570,571,572,573,],[-113,-128,-124,-110,-106,-104,-107,-125,-105,-99,-109,-120,-115,-102,-126,-108,-238,-111,-337,-122,-129,-29,-121,-116,-112,-52,-123,-117,-119,-114,-130,-118,-103,-96,-98,-54,-131,-37,-95,-101,-97,-53,-147,-146,-167,-166,-28,-180,-182,-27,200,-100,-126,-337,216,-55,-337,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,240,-255,241,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-337,-252,312,-149,-336,-183,-181,-30,332,340,-17,341,-186,-337,-18,-184,-191,-38,355,356,-274,-239,-326,-2
80,-277,-334,-332,-331,-333,-174,-175,-298,-297,407,-279,-143,411,413,-235,-278,-203,-140,-204,-1,-337,-141,-145,-2,-206,-14,-127,-153,-151,-148,-168,-36,-35,-337,-190,-204,-56,-188,-45,-189,-44,-43,476,477,478,479,480,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-310,483,-305,-205,-23,-24,489,490,-337,-13,-218,-152,-150,-170,-169,510,-31,-34,-204,-57,-337,-192,-185,-187,-39,-42,-240,-237,-281,-282,-290,-291,-236,-275,-213,-207,-211,532,535,537,539,-33,-32,544,545,-41,-40,-254,-311,547,-307,-209,-208,-210,-212,-51,-50,-306,-299,-337,568,-46,-49,-308,-300,-337,574,-48,-47,-309,577,]),'STRUCT':([0,1,3,7,10,11,13,14,16,17,19,20,21,25,26,27,29,30,38,39,40,42,45,47,48,52,53,55,58,59,61,62,63,64,65,66,67,75,85,86,87,90,91,93,94,95,97,99,105,118,119,120,121,122,123,124,129,172,174,180,181,182,184,185,186,188,189,190,191,198,200,214,223,229,231,233,239,240,241,267,278,284,285,286,289,291,298,300,301,302,303,305,308,312,313,315,318,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,446,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[24,-337,-128,-106,-104,-107,-105,-64,-60,-67,-66,-109,24,-65,-102,-337,-131,-108,-63,-129,24,-29,-62,-70,-52,-337,-337,-337,-130,24,-71,-103,-337,-9,-131,-91,-10,24,24,-53,-337,-88,24,24,-93,24,-335,24,-182,24,-87,-90,-94,-92,-61,-72,24,24,24,-73,24,-89,24,24,24,-159,-160,-156,-336,-183,-30,24,-74,24,24,24,24,-174,-175,24,24,-76,-79,-82,-75,-77,24,-81,-215,-214,-80,-216,-78,-127,24,24,-157,-69,-36,-35,24,24,24,-234,-233,24,-231,-217,-230,-81,-84,-218,-158,-31,-34,24,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LONG':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[23,-337,-113,-128,23,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,23,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,23,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,23,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,23,-131,-95,-101,-97,23,-53,-126,23,-88,23,23,-93,23,-147,-335,-146,23,-167,-166,-182,-100,-126,23,-87,-90,-94,-92,-61,-72,23,-144,-142,23,23,23,-73,23,-89,23,23,23,-149,-159,-160,-156,-336,23,-183,-30,23,23,-74,23,23,23,23,-174,-175,23,-143,-140,23,-141,-145,-76,-79,-82,-75,-77,23,-81,-215,-214,-80,-216,-78,-127,23,-153,23,-151,-148,-157,-168,-69,-36,-35,23,23,23,-234,-233,23,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,23,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PLUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,144,145,146,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,224,227,229,230,231,232,233,234,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,2
68,273,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,137,-335,-28,-182,-27,137,-337,-87,-72,-337,137,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-274,-314,137,-327,137,-283,-287,-325,-304,-322,-302,-255,-315,-289,249,-328,-316,-288,-329,-320,-276,-323,137,-284,137,137,-312,137,-336,-183,137,137,-28,-337,137,-28,-337,-274,-337,137,-326,137,-280,137,-277,-334,-332,-331,-333,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,137,-298,-297,137,137,-279,-278,-337,-76,-79,-82,-75,137,-77,137,137,-81,-215,-214,-80,-216,137,-78,-312,137,137,-69,-284,137,137,-284,137,137,-244,-247,-245,-241,-242,-246,-248,137,-250,-251,-243,-249,-12,137,137,-11,249,249,249,-260,249,249,249,-259,249,249,-257,-256,249,249,249,249,249,-258,-296,-295,-294,-293,-292,-305,137,137,137,137,-234,-233,137,-231,137,137,-217,137,-230,137,-84,-218,137,137,137,-337,-337,-198,137,-281,-282,137,-290,-291,137,-275,-337,-284,-229,-232,137,-221,137,-83,-219,-68,137,-28,-337,137,-11,137,137,-220,137,137,137,-284,137,137,-306,137,-337,-299,137,-225,-224,-222,-84,-300,137,137,137,-226,-223,137,-228,-227,]),'ELLIPSIS':([350,],[462,]),'U32STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,139,-335,-28,-182,-27,139,-337,-87,-72,-337,139,-286,-285,-330,139,-327,139,-283,-287,235,-328,-288,-329,139,-284,139,139,139,-336,-183,139,139,-28,-337,139,-28,-337,-337,139,139,139,-334,-332,-331,-333,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,139,-337,-76,-79,-82,-75,139,-77,139,139,-81,-215,-214,-80,-216,139,-78,139,139,-69,-284,139,139,-284,139,139,-244,-247,-245,-241,-242,-246,-248,139,-250,-251,-243,-249,-12,139,139,-11,139,139,139,139,-234,-233,139,-231,139,139,-217,139,-230,139,-84,-218,139,139,139,-337,-337,-198,139,139,139,-337,-284,-229,-232,139,-221,139,-83,-219,-68,139,-28,-337,139,-11,139,139,-220,139,139,139,-284,139,139,139,-337,139,-225,-224,-222,-84,139,139,139,-226,-223,139,-228,-227,]),'GT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-3
15,-289,250,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,250,-262,-260,-264,250,-263,-259,-266,250,-257,-256,-265,250,250,250,250,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'GOTO':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,287,-336,-76,-79,-82,-75,-77,287,-81,-215,-214,-80,-216,287,-78,-69,-234,-233,-231,287,-217,-230,287,-84,-218,287,-229,-232,-221,287,-83,-219,-68,287,-220,287,287,-225,-224,-222,-84,287,287,-226,-223,287,-228,-227,]),'ENUM':([0,1,3,7,10,11,13,14,16,17,19,20,21,25,26,27,29,30,38,39,40,42,45,47,48,52,53,55,58,59,61,62,63,64,65,66,67,75,85,86,87,90,91,93,94,95,97,99,105,118,119,120,121,122,123,124,129,172,174,180,181,182,184,185,186,188,189,190,191,198,200,214,223,229,231,233,239,240,241,267,278,284,285,286,289,291,298,300,301,302,303,305,308,312,313,315,318,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,446,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[32,-337,-128,-106,-104,-107,-105,-64,-60,-67,-66,-109,32,-65,-102,-337,-131,-108,-63,-129,32,-29,-62,-70,-52,-337,-337,-337,-130,32,-71,-103,-337,-9,-131,-91,-10,32,32,-53,-337,-88,32,32,-93,32,-335,32,-182,32,-87,-90,-94,-92,-61,-72,32,32,32,-73,32,-89,32,32,32,-159,-160,-156,-336,-183,-30,32,-74,32,32,32,32,-174,-175,32,32,-76,-79,-82,-75,-77,32,-81,-215,-214,-80,-216,-78,-127,32,32,-157,-69,-36,-35,32,32,32,-234,-233,32,-231,-217,-230,-81,-84,-218,-158,-31,-34,32,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PERIOD':([97,132,133,134,136,138,139,140,141,143,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,227,230,235,236,237,238,261,263,310,371,376,402,403,404,405,407,411,470,472,474,482,483,488,520,526,527,547,550,551,563,565,572,],[-335,-317,-321,-318,-303,-324,-330,-313,-319,-301,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,264,-323,-312,-336,372,-326,-334,-332,-331,-333,-298,-297,-312,-199,372,-296,-295,-294,-293,-292,-305,-202,372,-200,-290,-291,372,-201,548,-307,-306,372,-299,-308,-300,-309,]),'GE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,254,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,254,-262,-260,-264,254,-263,-259,-266,254,-257,-256,-265,254,254,254,254,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'INT_CONST_DEC':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,5
36,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,140,-335,-28,-182,-27,140,-337,-87,-72,-337,140,-286,-285,140,140,-283,-287,-288,140,-284,140,140,140,-336,-183,140,140,-28,-337,140,-28,-337,-337,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,140,-337,-76,-79,-82,-75,140,-77,140,140,-81,-215,-214,-80,-216,140,-78,140,140,-69,-284,140,140,-284,140,140,-244,-247,-245,-241,-242,-246,-248,140,-250,-251,-243,-249,-12,140,140,-11,140,140,140,140,-234,-233,140,-231,140,140,-217,140,-230,140,-84,-218,140,140,140,-337,-337,-198,140,140,140,-337,-284,-229,-232,140,-221,140,-83,-219,-68,140,-28,-337,140,-11,140,140,-220,140,140,140,-284,140,140,140,-337,140,-225,-224,-222,-84,140,140,140,-226,-223,140,-228,-227,]),'ARROW':([132,133,134,136,138,139,140,141,143,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,230,235,236,237,238,261,263,310,402,403,404,405,407,411,482,483,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,262,-323,-312,-336,-326,-334,-332,-331,-333,-298,-297,-312,-296,-295,-294,-293,-292,-305,-290,-291,-306,-299,-300,]),'_STATIC_ASSERT':([0,14,16,17,19,25,38,45,47,59,61,97,119,123,124,180,181,191,223,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[41,-64,-60,-67,-66,-65,-63,-62,-70,41,-71,-335,-87,-61,-72,-73,41,-336,-74,-76,-79,-82,-75,-77,41,-81,-215,-214,-80,-216,41,-78,-69,-234,-233,-231,41,-217,-230,41,-84,-218,41,-229,-232,-221,41,-83,-219,-68,41,-220,41,41,-225,-224,-222,-84,41,41,-226,-223,41,-228,-227,]),'CHAR':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[46,-337,-113,-128,46,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,46,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,46,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,46,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,46,-131,-95,-101,-97,46,-53,-126,46,-88,46,46,-93,46,-147,-335,-146,46,-167,-166,-182,-100,-126,46,-87,-90,-94,-92,-61,-72,46,-144,-142,46,46,46,-73,46,-89,46,46,46,-149,-159,-160,-156,-336,46,-183,-30,46,46,-74,46,46,46,46,-174,-175,46,-143,-140,46,-141,-145,-76,-79,-82,-75,-77,46,-81,-215,-214,-80,-216,-78,-127,46,-153,46,-151,-148,-157,-168,-69,-36,-35,46,46,46,-234,-233,46,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,46,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'HEX_FLOAT_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,
332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,141,-335,-28,-182,-27,141,-337,-87,-72,-337,141,-286,-285,141,141,-283,-287,-288,141,-284,141,141,141,-336,-183,141,141,-28,-337,141,-28,-337,-337,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,141,-337,-76,-79,-82,-75,141,-77,141,141,-81,-215,-214,-80,-216,141,-78,141,141,-69,-284,141,141,-284,141,141,-244,-247,-245,-241,-242,-246,-248,141,-250,-251,-243,-249,-12,141,141,-11,141,141,141,141,-234,-233,141,-231,141,141,-217,141,-230,141,-84,-218,141,141,141,-337,-337,-198,141,141,141,-337,-284,-229,-232,141,-221,141,-83,-219,-68,141,-28,-337,141,-11,141,141,-220,141,141,141,-284,141,141,141,-337,141,-225,-224,-222,-84,141,141,141,-226,-223,141,-228,-227,]),'DOUBLE':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[50,-337,-113,-128,50,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,50,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,50,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,50,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,50,-131,-95,-101,-97,50,-53,-126,50,-88,50,50,-93,50,-147,-335,-146,50,-167,-166,-182,-100,-126,50,-87,-90,-94,-92,-61,-72,50,-144,-142,50,50,50,-73,50,-89,50,50,50,-149,-159,-160,-156,-336,50,-183,-30,50,50,-74,50,50,50,50,-174,-175,50,-143,-140,50,-141,-145,-76,-79,-82,-75,-77,50,-81,-215,-214,-80,-216,-78,-127,50,-153,50,-151,-148,-157,-168,-69,-36,-35,50,50,50,-234,-233,50,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,50,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'MINUSEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,358,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'INT_CONST_OCT':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,4
88,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,145,-335,-28,-182,-27,145,-337,-87,-72,-337,145,-286,-285,145,145,-283,-287,-288,145,-284,145,145,145,-336,-183,145,145,-28,-337,145,-28,-337,-337,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,145,-337,-76,-79,-82,-75,145,-77,145,145,-81,-215,-214,-80,-216,145,-78,145,145,-69,-284,145,145,-284,145,145,-244,-247,-245,-241,-242,-246,-248,145,-250,-251,-243,-249,-12,145,145,-11,145,145,145,145,-234,-233,145,-231,145,145,-217,145,-230,145,-84,-218,145,145,145,-337,-337,-198,145,145,145,-337,-284,-229,-232,145,-221,145,-83,-219,-68,145,-28,-337,145,-11,145,145,-220,145,145,145,-284,145,145,145,-337,145,-225,-224,-222,-84,145,145,145,-226,-223,145,-228,-227,]),'TIMESEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,367,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'OR':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,259,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,259,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,259,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'SHORT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[2,-337,-113,-128,2,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,2,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,2,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,2,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,2,-131,-95,-101,-97,2,-53,-126,2,-88,2,2,-93,2,-147,-335,-146,2,-167,-166,-182,-100,-126,2,-87,-90,-94,-92,-61,-72,2,-144,-142,2,2,2,-73,2,-89,2,2,2,-149,-159,-160,-156,-336,2,-183,-30,2,2,-74,2,2,2,2,-174,-175,2,-143,-140,2,-141,-145,-76,-79,-82,-75,-77,2,-81,-215,-214,-80,-216,-78,-127,2,-153,2,-151,-148,-157,-168,-69,-36,-35,2,2,2,-234,-233,2,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,2,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'RETURN':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332
,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,290,-336,-76,-79,-82,-75,-77,290,-81,-215,-214,-80,-216,290,-78,-69,-234,-233,-231,290,-217,-230,290,-84,-218,290,-229,-232,-221,290,-83,-219,-68,290,-220,290,290,-225,-224,-222,-84,290,290,-226,-223,290,-228,-227,]),'RSHIFTEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,368,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'_ALIGNAS':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,109,111,118,119,123,124,129,142,147,174,177,180,181,182,184,185,186,187,188,189,190,191,192,200,211,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[8,8,-113,-128,8,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,8,-120,-115,-65,-102,8,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,8,8,-119,8,-114,-130,8,-118,-71,-103,8,-131,-96,-98,8,-131,-95,-101,-97,8,-53,8,8,-88,8,8,-147,-335,-146,8,-167,-166,-100,-126,8,-87,-61,-72,8,-144,-142,8,8,-73,8,-89,8,8,8,-149,-159,-160,-156,-336,8,-30,8,-74,8,8,8,8,-174,-175,8,-143,-140,8,-141,-145,-76,-79,-82,-75,-77,8,-81,-215,-214,-80,-216,-78,-127,8,-153,8,-151,-148,-157,-168,-69,-36,-35,8,8,8,-234,-233,8,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,8,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'RESTRICT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,103,105,109,111,117,118,119,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[39,39,-113,-128,39,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,39,-120,-115,-65,-102,39,-131,-108,-238,-111,39,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,39,39,-119,39,-114,-130,39,-118,-71,-103,39,-131,-96,-98,39,-131,-95,-101,-97,39,-53,39,39,-88,39,39,-147,-335,-146,39,-167,-166,39,-182,-100,-126,39,39,-87,-61,-72,39,39,-144,-142,39,39,39,-73,39,-89,39,39,39,-149,-159,-160,-156,-336,39,-183,-30,39,39,39,39,39,-74,39,39,39,39,-174,-175,39,-143,-140,39,-141,-145,39,-76,-79,-82,-75,-77,39,-81,-215,-214,-80,-216,-78,-127,39,-153,39,-151,-148,-157,-168,-69,-36,-35,39,39,39,-234,-233,39,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,39,39,-229,-232,-221,-83,-219,-
68,-33,-32,39,39,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'STATIC':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,105,109,111,117,118,119,123,124,128,129,180,181,182,187,191,198,200,205,211,219,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,536,554,555,557,558,575,576,578,579,],[10,10,-113,-128,10,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,10,-120,-115,-65,-102,10,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,10,10,-119,10,-114,-130,10,-118,-71,-103,10,-131,-96,-98,10,-131,-95,-101,-97,-53,10,10,-88,10,-147,-335,-146,-167,-166,-182,-100,-126,206,10,-87,-61,-72,220,10,-73,10,-89,-149,-336,-183,-30,338,10,353,-74,-174,-175,10,-76,-79,-82,-75,-77,10,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,10,10,10,-234,-233,10,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,515,10,-229,-232,-221,-83,-219,-68,-33,-32,542,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'SIZEOF':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,146,-335,-28,-182,-27,146,-337,-87,-72,-337,146,-286,-285,146,146,-283,-287,-288,146,-284,146,146,146,-336,-183,146,146,-28,-337,146,-28,-337,-337,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,146,-337,-76,-79,-82,-75,146,-77,146,146,-81,-215,-214,-80,-216,146,-78,146,146,-69,-284,146,146,-284,146,146,-244,-247,-245,-241,-242,-246,-248,146,-250,-251,-243,-249,-12,146,146,-11,146,146,146,146,-234,-233,146,-231,146,146,-217,146,-230,146,-84,-218,146,146,146,-337,-337,-198,146,146,146,-337,-284,-229,-232,146,-221,146,-83,-219,-68,146,-28,-337,146,-11,146,146,-220,146,146,146,-284,146,146,146,-337,146,-225,-224,-222,-84,146,146,146,-226,-223,146,-228,-227,]),'UNSIGNED':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[22,-337,-113,-128,22,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,22,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,22,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,22,-
118,-71,-103,-337,-9,-131,-91,-10,-96,-98,22,-131,-95,-101,-97,22,-53,-126,22,-88,22,22,-93,22,-147,-335,-146,22,-167,-166,-182,-100,-126,22,-87,-90,-94,-92,-61,-72,22,-144,-142,22,22,22,-73,22,-89,22,22,22,-149,-159,-160,-156,-336,22,-183,-30,22,22,-74,22,22,22,22,-174,-175,22,-143,-140,22,-141,-145,-76,-79,-82,-75,-77,22,-81,-215,-214,-80,-216,-78,-127,22,-153,22,-151,-148,-157,-168,-69,-36,-35,22,22,22,-234,-233,22,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,22,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'UNION':([0,1,3,7,10,11,13,14,16,17,19,20,21,25,26,27,29,30,38,39,40,42,45,47,48,52,53,55,58,59,61,62,63,64,65,66,67,75,85,86,87,90,91,93,94,95,97,99,105,118,119,120,121,122,123,124,129,172,174,180,181,182,184,185,186,188,189,190,191,198,200,214,223,229,231,233,239,240,241,267,278,284,285,286,289,291,298,300,301,302,303,305,308,312,313,315,318,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,446,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[28,-337,-128,-106,-104,-107,-105,-64,-60,-67,-66,-109,28,-65,-102,-337,-131,-108,-63,-129,28,-29,-62,-70,-52,-337,-337,-337,-130,28,-71,-103,-337,-9,-131,-91,-10,28,28,-53,-337,-88,28,28,-93,28,-335,28,-182,28,-87,-90,-94,-92,-61,-72,28,28,28,-73,28,-89,28,28,28,-159,-160,-156,-336,-183,-30,28,-74,28,28,28,28,-174,-175,28,28,-76,-79,-82,-75,-77,28,-81,-215,-214,-80,-216,-78,-127,28,28,-157,-69,-36,-35,28,28,28,-234,-233,28,-231,-217,-230,-81,-84,-218,-158,-31,-34,28,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'COLON':([2,3,5,6,12,22,23,33,34,36,39,42,43,44,46,48,49,50,54,56,58,60,73,74,76,77,86,96,98,100,101,111,127,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,179,187,191,192,200,216,224,225,230,232,234,235,236,237,238,240,241,261,263,268,269,272,273,275,279,280,295,310,312,314,316,317,324,328,340,341,355,356,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,407,411,431,442,443,445,448,449,453,454,464,465,468,476,478,480,482,483,486,487,511,512,518,519,524,547,551,565,],[-113,-128,-124,-110,-125,-120,-115,-238,-111,-122,-129,-29,-121,-116,-112,-52,-123,-117,-119,-114,-130,-118,-54,-179,-131,-37,-53,-147,-146,-167,-166,-126,-55,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-178,-149,-336,319,-30,-38,-274,-239,-326,-280,-277,-334,-332,-331,-333,-174,-175,-298,-297,-279,-143,-235,-278,-140,-141,-145,429,440,-127,-153,-151,-148,447,-168,-36,-35,-44,-43,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,481,-270,-258,-296,-295,-294,-293,-292,-305,502,-152,-150,319,-170,-169,-31,-34,-39,-42,-240,-237,-281,-282,-290,-291,-236,-275,-33,-32,-41,-40,-254,-306,-299,-300,]),'$end':([0,9,14,16,17,19,25,38,45,47,57,59,61,119,123,124,180,191,223,332,439,510,],[-337,0,-64,-60,-67,-66,-65,-63,-62,-70,-59,-58,-71,-87,-61,-72,-73,-336,-74,-69,-218,-68,]),'WSTRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,36
7,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,148,-335,-28,-182,-27,148,-337,-87,-72,-337,148,-286,-285,-330,148,-327,148,-283,-287,237,-328,-288,-329,148,-284,148,148,148,-336,-183,148,148,-28,-337,148,-28,-337,-337,148,148,148,-334,-332,-331,-333,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,148,-337,-76,-79,-82,-75,148,-77,148,148,-81,-215,-214,-80,-216,148,-78,148,148,-69,-284,148,148,-284,148,148,-244,-247,-245,-241,-242,-246,-248,148,-250,-251,-243,-249,-12,148,148,-11,148,148,148,148,-234,-233,148,-231,148,148,-217,148,-230,148,-84,-218,148,148,148,-337,-337,-198,148,148,148,-337,-284,-229,-232,148,-221,148,-83,-219,-68,148,-28,-337,148,-11,148,148,-220,148,148,148,-284,148,148,148,-337,148,-225,-224,-222,-84,148,148,148,-226,-223,148,-228,-227,]),'DIVIDE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,252,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,252,252,252,252,252,252,252,252,252,252,-257,-256,252,252,252,252,252,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'FOR':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,292,-336,-76,-79,-82,-75,-77,292,-81,-215,-214,-80,-216,292,-78,-69,-234,-233,-231,292,-217,-230,292,-84,-218,292,-229,-232,-221,292,-83,-219,-68,292,-220,292,292,-225,-224,-222,-84,292,292,-226,-223,292,-228,-227,]),'PLUSPLUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,145,146,148,149,150,151,152,153,154,156,160,161,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,227,229,230,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,482,483,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,149,-335,-28,-182,-27,149,-337,-87,-72,-337,149,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-314,149,-327,149,-283,-287,-325,-304,-322,-302,-315,-289,-328,-316,-288,-329,-320,263,-323,149,-284,149,149,-312,149,-336,-183,149,149,-28,-337,149,-28,-337,-337,149,-326,149,149,-334,-332,-331,-333,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,149,-298,-297,149,149,-337,-76,-79,-82,-75,149,-77,149,149,-81,-215,-214,-80,-216,149,-78,-312,149,149,-69,-284,149,149,-284,149,149,-244,-24
7,-245,-241,-242,-246,-248,149,-250,-251,-243,-249,-12,149,149,-11,-296,-295,-294,-293,-292,-305,149,149,149,149,-234,-233,149,-231,149,149,-217,149,-230,149,-84,-218,149,149,149,-337,-337,-198,149,149,-290,-291,149,-337,-284,-229,-232,149,-221,149,-83,-219,-68,149,-28,-337,149,-11,149,149,-220,149,149,149,-284,149,149,-306,149,-337,-299,149,-225,-224,-222,-84,-300,149,149,149,-226,-223,149,-228,-227,]),'EQUALS':([42,48,73,74,75,77,78,86,110,127,132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,179,191,197,200,216,224,230,232,234,235,236,237,238,261,263,268,273,310,340,341,355,356,371,376,402,403,404,405,407,411,453,454,464,465,470,474,478,480,482,483,487,511,512,518,519,520,547,551,565,],[-29,-52,-54,-179,-178,-37,131,-53,201,-55,-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-178,-336,329,-30,-38,360,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-36,-35,-44,-43,-199,475,-296,-295,-294,-293,-292,-305,-31,-34,-39,-42,-202,-200,-281,-282,-290,-291,-275,-33,-32,-41,-40,-201,-306,-299,-300,]),'ELSE':([61,124,191,284,285,286,289,291,300,303,308,332,424,425,428,435,437,438,439,496,497,500,505,506,510,536,554,555,557,558,575,576,578,579,],[-71,-72,-336,-76,-79,-82,-75,-77,-81,-80,-78,-69,-234,-233,-231,-230,-81,-84,-218,-229,-232,-221,-83,-219,-68,-220,-225,-224,-222,569,-226,-223,-228,-227,]),'ANDEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,365,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'EQ':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,256,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,256,-262,-260,-264,-268,-263,-259,-266,256,-257,-256,-265,256,-267,256,256,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'AND':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,144,145,146,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,224,227,229,230,231,232,233,234,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,268,273,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,5
77,578,579,],[-128,-129,-130,-71,-131,150,-335,-28,-182,-27,150,-337,-87,-72,-337,150,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-274,-314,150,-327,150,-283,-287,-325,-304,-322,-302,-255,-315,-289,257,-328,-316,-288,-329,-320,-276,-323,150,-284,150,150,-312,150,-336,-183,150,150,-28,-337,150,-28,-337,-274,-337,150,-326,150,-280,150,-277,-334,-332,-331,-333,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,150,-298,-297,150,150,-279,-278,-337,-76,-79,-82,-75,150,-77,150,150,-81,-215,-214,-80,-216,150,-78,-312,150,150,-69,-284,150,150,-284,150,150,-244,-247,-245,-241,-242,-246,-248,150,-250,-251,-243,-249,-12,150,150,-11,-261,257,-262,-260,-264,-268,-263,-259,-266,257,-257,-256,-265,257,-267,-269,257,-258,-296,-295,-294,-293,-292,-305,150,150,150,150,-234,-233,150,-231,150,150,-217,150,-230,150,-84,-218,150,150,150,-337,-337,-198,150,-281,-282,150,-290,-291,150,-275,-337,-284,-229,-232,150,-221,150,-83,-219,-68,150,-28,-337,150,-11,150,150,-220,150,150,150,-284,150,150,-306,150,-337,-299,150,-225,-224,-222,-84,-300,150,150,150,-226,-223,150,-228,-227,]),'TYPEID':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,72,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,103,104,105,106,109,111,118,119,120,121,122,123,124,126,129,142,147,172,174,180,181,182,184,185,186,187,188,189,190,191,192,198,199,200,202,211,214,223,229,231,233,239,240,241,262,264,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,344,350,422,424,425,427,428,432,435,437,438,439,442,443,445,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[33,-337,-113,-128,77,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,33,-120,-115,-154,-65,-102,-126,-155,-131,-108,96,100,-238,-111,-337,-122,-63,-129,33,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,33,-118,-71,-103,-337,-9,-131,-91,-10,-96,77,-98,77,33,-131,-95,-101,-97,33,-53,-126,77,-88,33,33,-93,33,-147,-335,-146,33,-167,-166,-28,-180,-182,-27,-100,-126,33,-87,-90,-94,-92,-61,-72,77,33,-144,-142,33,33,-73,33,-89,33,33,33,-149,-159,-160,-156,-336,77,-183,-181,-30,77,347,33,-74,33,33,33,33,-174,-175,402,404,33,-143,-140,33,-141,-145,-76,-79,-82,-75,-77,33,-81,-215,-214,-80,-216,-78,-127,33,-153,33,-151,-148,-157,-168,-69,-36,-35,33,347,33,33,-234,-233,33,-231,-217,-230,-81,-84,-218,-152,-150,77,-158,-170,-169,-31,-34,33,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LBRACE':([21,24,28,31,32,42,48,61,75,86,88,90,92,93,96,97,98,100,101,119,124,130,131,181,182,191,200,201,227,229,284,285,286,289,291,298,300,301,302,303,305,307,308,332,340,341,369,375,377,413,424,425,428,429,432,435,437,438,439,440,453,454,472,475,477,478,479,488,496,497,500,502,505,506,510,511,512,521,522,535,536,537,539,550,554,555,557,558,569,574,575,576,577,578,579,],[-337,-154,-155,97,97,-29,-52,-71,-337,-53,-7,-88,97,-8,97,-335,97,97,97,-87,-72,97,97,97,-89,-336,-30,97,-337,97,-76,-79,-82,-75,-77,97,-81,-215,-214,-80,-216,97,-78,-69,-36,-35,-12,97,-11,97,-234,-233,-231,97,-217,-230,97,-84,-218,97,-31,-34,-337,-198,97,97,97,-337,-229,-232,-221,97,-83,-219,-68,-33,-32,97,-11,97,-220,97,97,-337,-225,-224,-222,-84,97,97,-226,-223,97,-228,-227,]),'PPHASH':([0,14,16,17,19,25,38,45,47,59,61,119,123,124,180,191,223,332,439,510,],[47,-64,-60,-67,-66,-65,-63,-62,-70,47,
-71,-87,-61,-72,-73,-336,-74,-69,-218,-68,]),'INT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[56,-337,-113,-128,56,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,56,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,56,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,56,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,56,-131,-95,-101,-97,56,-53,-126,56,-88,56,56,-93,56,-147,-335,-146,56,-167,-166,-182,-100,-126,56,-87,-90,-94,-92,-61,-72,56,-144,-142,56,56,56,-73,56,-89,56,56,56,-149,-159,-160,-156,-336,56,-183,-30,56,56,-74,56,56,56,56,-174,-175,56,-143,-140,56,-141,-145,-76,-79,-82,-75,-77,56,-81,-215,-214,-80,-216,-78,-127,56,-153,56,-151,-148,-157,-168,-69,-36,-35,56,56,56,-234,-233,56,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,56,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'SIGNED':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[54,-337,-113,-128,54,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,54,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,54,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,54,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,54,-131,-95,-101,-97,54,-53,-126,54,-88,54,54,-93,54,-147,-335,-146,54,-167,-166,-182,-100,-126,54,-87,-90,-94,-92,-61,-72,54,-144,-142,54,54,54,-73,54,-89,54,54,54,-149,-159,-160,-156,-336,54,-183,-30,54,54,-74,54,54,54,54,-174,-175,54,-143,-140,54,-141,-145,-76,-79,-82,-75,-77,54,-81,-215,-214,-80,-216,-78,-127,54,-153,54,-151,-148,-157,-168,-69,-36,-35,54,54,54,-234,-233,54,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,54,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'CONTINUE':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,293,-336,-76,-79,-82,-75,-77,293,-81,-215,-214,-80,-216,293,-78,-69,-234,-233,-231,293,-217,-230,293,-84,-218,293,-229,-232,-221,293,-83,-219,-68,293,-220,293,293,-225,-224,-222,-84,293,293,-226,-223,293,-228,-227,]),'NOT':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229
,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,151,-335,-28,-182,-27,151,-337,-87,-72,-337,151,-286,-285,151,151,-283,-287,-288,151,-284,151,151,151,-336,-183,151,151,-28,-337,151,-28,-337,-337,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,151,-337,-76,-79,-82,-75,151,-77,151,151,-81,-215,-214,-80,-216,151,-78,151,151,-69,-284,151,151,-284,151,151,-244,-247,-245,-241,-242,-246,-248,151,-250,-251,-243,-249,-12,151,151,-11,151,151,151,151,-234,-233,151,-231,151,151,-217,151,-230,151,-84,-218,151,151,151,-337,-337,-198,151,151,151,-337,-284,-229,-232,151,-221,151,-83,-219,-68,151,-28,-337,151,-11,151,151,-220,151,151,151,-284,151,151,151,-337,151,-225,-224,-222,-84,151,151,151,-226,-223,151,-228,-227,]),'OREQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,366,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'MOD':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,260,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,260,260,260,260,260,260,260,260,260,260,-257,-256,260,260,260,260,260,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'RSHIFT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,242,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,242,-262,-260,242,242,242,-259,242,242,-257,-256,242,242,242,242,242,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'DEFAULT':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,295,-336,-76,-79,-82,-75,-77,295,-81,-215,-214,-80,-216,295,-78,-69,-234,-233,-231,295,-217,-230,295,-84,-218,295,-229,-232,-221,295,-83,-219,-68,295,-220,295,295,-225,-224,-222,-84,295,295,-226,-223,295,-228,-227,]),'_NORETURN':([0,1,2,3,4,5,6,7,10,11,12,1
3,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[20,20,-113,-128,20,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,20,-120,-115,-65,-102,20,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,20,20,-119,20,-114,-130,20,-118,-71,-103,20,-131,-96,-98,20,-131,-95,-101,-97,-53,20,20,-88,20,-147,-335,-146,-167,-166,-100,-126,20,-87,-61,-72,20,-73,20,-89,-149,-336,-30,20,-74,-174,-175,20,-76,-79,-82,-75,-77,20,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,20,20,20,-234,-233,20,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,20,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'__INT128':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[43,-337,-113,-128,43,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,43,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,43,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,43,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,43,-131,-95,-101,-97,43,-53,-126,43,-88,43,43,-93,43,-147,-335,-146,43,-167,-166,-182,-100,-126,43,-87,-90,-94,-92,-61,-72,43,-144,-142,43,43,43,-73,43,-89,43,43,43,-149,-159,-160,-156,-336,43,-183,-30,43,43,-74,43,43,43,43,-174,-175,43,-143,-140,43,-141,-145,-76,-79,-82,-75,-77,43,-81,-215,-214,-80,-216,-78,-127,43,-153,43,-151,-148,-157,-168,-69,-36,-35,43,43,43,-234,-233,43,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,43,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'WHILE':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,436,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,296,-336,-76,-79,-82,-75,-77,296,-81,-215,-214,-80,-216,296,-78,-69,-234,-233,-231,296,-217,-230,504,296,-84,-218,296,-229,-232,-221,296,-83,-219,-68,296,-220,296,296,-225,-224,-222,-84,296,296,-226,-223,296,-228,-227,]),'U8CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500
,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,154,-335,-28,-182,-27,154,-337,-87,-72,-337,154,-286,-285,154,154,-283,-287,-288,154,-284,154,154,154,-336,-183,154,154,-28,-337,154,-28,-337,-337,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,154,-337,-76,-79,-82,-75,154,-77,154,154,-81,-215,-214,-80,-216,154,-78,154,154,-69,-284,154,154,-284,154,154,-244,-247,-245,-241,-242,-246,-248,154,-250,-251,-243,-249,-12,154,154,-11,154,154,154,154,-234,-233,154,-231,154,154,-217,154,-230,154,-84,-218,154,154,154,-337,-337,-198,154,154,154,-337,-284,-229,-232,154,-221,154,-83,-219,-68,154,-28,-337,154,-11,154,154,-220,154,154,154,-284,154,154,154,-337,154,-225,-224,-222,-84,154,154,154,-226,-223,154,-228,-227,]),'_ALIGNOF':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,155,-335,-28,-182,-27,155,-337,-87,-72,-337,155,-286,-285,155,155,-283,-287,-288,155,-284,155,155,155,-336,-183,155,155,-28,-337,155,-28,-337,-337,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,155,-337,-76,-79,-82,-75,155,-77,155,155,-81,-215,-214,-80,-216,155,-78,155,155,-69,-284,155,155,-284,155,155,-244,-247,-245,-241,-242,-246,-248,155,-250,-251,-243,-249,-12,155,155,-11,155,155,155,155,-234,-233,155,-231,155,155,-217,155,-230,155,-84,-218,155,155,155,-337,-337,-198,155,155,155,-337,-284,-229,-232,155,-221,155,-83,-219,-68,155,-28,-337,155,-11,155,155,-220,155,155,155,-284,155,155,155,-337,155,-225,-224,-222,-84,155,155,155,-226,-223,155,-228,-227,]),'EXTERN':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[13,13,-113,-128,13,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,13,-120,-115,-65,-102,13,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,13,13,-119,13,-114,-130,13,-118,-71,-103,13,-131,-96,-98,13,-131,-95,-101,-97,-53,13,13,-88,13,-147,-335,-146,-167,-166,-100,-126,13,-87,-61,-72,13,-73,13,-89,-149,-336,-30,13,-74,-174,-175,13,-76,-79,-82,-75,-77,13,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,13,13,13,-234,-233,13,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,13,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'CASE':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,
505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,297,-336,-76,-79,-82,-75,-77,297,-81,-215,-214,-80,-216,297,-78,-69,-234,-233,-231,297,-217,-230,297,-84,-218,297,-229,-232,-221,297,-83,-219,-68,297,-220,297,297,-225,-224,-222,-84,297,297,-226,-223,297,-228,-227,]),'LAND':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,255,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,255,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'REGISTER':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[62,62,-113,-128,62,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,62,-120,-115,-65,-102,62,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,62,62,-119,62,-114,-130,62,-118,-71,-103,62,-131,-96,-98,62,-131,-95,-101,-97,-53,62,62,-88,62,-147,-335,-146,-167,-166,-100,-126,62,-87,-61,-72,62,-73,62,-89,-149,-336,-30,62,-74,-174,-175,62,-76,-79,-82,-75,-77,62,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,62,62,62,-234,-233,62,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,62,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'MODEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,359,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'NE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,247,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,247,-262,-260,-264,-268,-263,-259,-266,247,-257,-256,-265,247,-267,247,247,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'SWITCH':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,299,-336,-76,-79,-82,-75,-77,299,-81,-215,-2
14,-80,-216,299,-78,-69,-234,-233,-231,299,-217,-230,299,-84,-218,299,-229,-232,-221,299,-83,-219,-68,299,-220,299,299,-225,-224,-222,-84,299,299,-226,-223,299,-228,-227,]),'INT_CONST_HEX':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,160,-335,-28,-182,-27,160,-337,-87,-72,-337,160,-286,-285,160,160,-283,-287,-288,160,-284,160,160,160,-336,-183,160,160,-28,-337,160,-28,-337,-337,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,160,-337,-76,-79,-82,-75,160,-77,160,160,-81,-215,-214,-80,-216,160,-78,160,160,-69,-284,160,160,-284,160,160,-244,-247,-245,-241,-242,-246,-248,160,-250,-251,-243,-249,-12,160,160,-11,160,160,160,160,-234,-233,160,-231,160,160,-217,160,-230,160,-84,-218,160,160,160,-337,-337,-198,160,160,160,-337,-284,-229,-232,160,-221,160,-83,-219,-68,160,-28,-337,160,-11,160,160,-220,160,160,160,-284,160,160,160,-337,160,-225,-224,-222,-84,160,160,160,-226,-223,160,-228,-227,]),'_COMPLEX':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[60,-337,-113,-128,60,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,60,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,60,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,60,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,60,-131,-95,-101,-97,60,-53,-126,60,-88,60,60,-93,60,-147,-335,-146,60,-167,-166,-182,-100,-126,60,-87,-90,-94,-92,-61,-72,60,-144,-142,60,60,60,-73,60,-89,60,60,60,-149,-159,-160,-156,-336,60,-183,-30,60,60,-74,60,60,60,60,-174,-175,60,-143,-140,60,-141,-145,-76,-79,-82,-75,-77,60,-81,-215,-214,-80,-216,-78,-127,60,-153,60,-151,-148,-157,-168,-69,-36,-35,60,60,60,-234,-233,60,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,60,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PPPRAGMASTR':([61,],[124,]),'PLUSEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,362,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-29
9,-300,]),'U32CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,138,-335,-28,-182,-27,138,-337,-87,-72,-337,138,-286,-285,138,138,-283,-287,-288,138,-284,138,138,138,-336,-183,138,138,-28,-337,138,-28,-337,-337,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,138,-337,-76,-79,-82,-75,138,-77,138,138,-81,-215,-214,-80,-216,138,-78,138,138,-69,-284,138,138,-284,138,138,-244,-247,-245,-241,-242,-246,-248,138,-250,-251,-243,-249,-12,138,138,-11,138,138,138,138,-234,-233,138,-231,138,138,-217,138,-230,138,-84,-218,138,138,138,-337,-337,-198,138,138,138,-337,-284,-229,-232,138,-221,138,-83,-219,-68,138,-28,-337,138,-11,138,138,-220,138,138,138,-284,138,138,138,-337,138,-225,-224,-222,-84,138,138,138,-226,-223,138,-228,-227,]),'CONDOP':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,258,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'U8STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,163,-335,-28,-182,-27,163,-337,-87,-72,-337,163,-286,-285,-330,163,-327,163,-283,-287,236,-328,-288,-329,163,-284,163,163,163,-336,-183,163,163,-28,-337,163,-28,-337,-337,163,163,163,-334,-332,-331,-333,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,163,-337,-76,-79,-82,-75,163,-77,163,163,-81,-215,-214,-80,-216,163,-78,163,163,-69,-284,163,163,-284,163,163,-244,-247,-245,-241,-242,-246,-248,163,-250,-251,-243,-249,-12,163,163,-11,163,163,163,163,-234,-233,163,-231,163,163,-217,163,-230,163,-84,-218,163,163,163,-337,-337,-198,163,163,163,-337,-284,-229,-232,163,-221,163,-83,-219,-68,163,-28,-337,163,-11,163,163,-220,163,163,163,-284,163,163,163,-337,163,-225,-224,-222,-84,
163,163,163,-226,-223,163,-228,-227,]),'BREAK':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,304,-336,-76,-79,-82,-75,-77,304,-81,-215,-214,-80,-216,304,-78,-69,-234,-233,-231,304,-217,-230,304,-84,-218,304,-229,-232,-221,304,-83,-219,-68,304,-220,304,304,-225,-224,-222,-84,304,304,-226,-223,304,-228,-227,]),'VOLATILE':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,103,105,109,111,117,118,119,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[58,58,-113,-128,58,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,58,-120,-115,-65,-102,58,-131,-108,-238,-111,58,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,58,58,-119,58,-114,-130,58,-118,-71,-103,58,-131,-96,-98,58,-131,-95,-101,-97,58,-53,58,58,-88,58,58,-147,-335,-146,58,-167,-166,58,-182,-100,-126,58,58,-87,-61,-72,58,58,-144,-142,58,58,58,-73,58,-89,58,58,58,-149,-159,-160,-156,-336,58,-183,-30,58,58,58,58,58,-74,58,58,58,58,-174,-175,58,-143,-140,58,-141,-145,58,-76,-79,-82,-75,-77,58,-81,-215,-214,-80,-216,-78,-127,58,-153,58,-151,-148,-157,-168,-69,-36,-35,58,58,58,-234,-233,58,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,58,58,-229,-232,-221,-83,-219,-68,-33,-32,58,58,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'PPPRAGMA':([0,14,16,17,19,25,38,45,47,59,61,97,99,119,123,124,180,181,184,185,186,188,189,190,191,223,284,285,286,289,291,298,300,301,302,303,305,307,308,313,315,318,332,424,425,428,429,432,435,437,438,439,440,446,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[61,-64,-60,-67,-66,-65,-63,-62,-70,61,-71,-335,61,-87,-61,-72,-73,61,61,61,61,-159,-160,-156,-336,-74,-76,-79,-82,-75,-77,61,-81,-215,-214,-80,-216,61,-78,61,61,-157,-69,-234,-233,-231,61,-217,-230,61,-84,-218,61,-158,-229,-232,-221,61,-83,-219,-68,61,-220,61,61,-225,-224,-222,-84,61,61,-226,-223,61,-228,-227,]),'INLINE':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[30,30,-113,-128,30,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,30,-120,-115,-65,-102,30,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,30,30,-119,30,-114,-130,30,-118,-71,-103,30,-131,-96,-98,30,-131,-95,-101,-97,-53,30,30,-88,30,-147,-335,-146,-167,-166,-100,-126,30,-87,-61,-72,30,-73,30,-89,-149,-336,-30,30,-74,-174,-175,30,-76,-79,-82,-75,-77,30,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,30,30,30,-234,-233,30
,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,30,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'INT_CONST_BIN':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,164,-335,-28,-182,-27,164,-337,-87,-72,-337,164,-286,-285,164,164,-283,-287,-288,164,-284,164,164,164,-336,-183,164,164,-28,-337,164,-28,-337,-337,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,164,-337,-76,-79,-82,-75,164,-77,164,164,-81,-215,-214,-80,-216,164,-78,164,164,-69,-284,164,164,-284,164,164,-244,-247,-245,-241,-242,-246,-248,164,-250,-251,-243,-249,-12,164,164,-11,164,164,164,164,-234,-233,164,-231,164,164,-217,164,-230,164,-84,-218,164,164,164,-337,-337,-198,164,164,164,-337,-284,-229,-232,164,-221,164,-83,-219,-68,164,-28,-337,164,-11,164,164,-220,164,164,164,-284,164,164,164,-337,164,-225,-224,-222,-84,164,164,164,-226,-223,164,-228,-227,]),'DO':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,307,-336,-76,-79,-82,-75,-77,307,-81,-215,-214,-80,-216,307,-78,-69,-234,-233,-231,307,-217,-230,307,-84,-218,307,-229,-232,-221,307,-83,-219,-68,307,-220,307,307,-225,-224,-222,-84,307,307,-226,-223,307,-228,-227,]),'LNOT':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,165,-335,-28,-182,-27,165,-337,-87,-72,-337,165,-286,-285,165,165,-283,-287,-288,165,-284,165,165,165,-336,-183,165,165,-28,-337,165,-28,-337,-337,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,165,-337,-76,-79,-82,-75,165,-77,165,165,-81,-215,-214,-80,-216,165,-78,165,165,-69,-284,165,165,-284,165,165,-244,-247,-245,-241,-242,-246,-248,165,-250,-251,-243,-249,-12,165,165,-11,165,165,165,165,-234,-233,165,-231,165,165,-217,165,-230,165,-84,-218,165,165,165,-337,-337,-198,165,165,165,-337,-284,-229,-232,165,-221,165,-83,-219,-68,165,-28,-337,165,-11,165,165,-220,165,165,165,-284,165,165,165,-337,165,-225,-224,-222,-84,165,165,165,-226,-223,165,-228,-227,]),'CONST':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,
59,60,61,62,63,65,68,71,75,76,80,81,82,85,86,87,89,90,93,95,96,97,98,99,100,101,103,105,109,111,117,118,119,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[3,3,-113,-128,3,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,3,-120,-115,-65,-102,3,-131,-108,-238,-111,3,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,3,3,-119,3,-114,-130,3,-118,-71,-103,3,-131,-96,-98,3,-131,-95,-101,-97,3,-53,3,3,-88,3,3,-147,-335,-146,3,-167,-166,3,-182,-100,-126,3,3,-87,-61,-72,3,3,-144,-142,3,3,3,-73,3,-89,3,3,3,-149,-159,-160,-156,-336,3,-183,-30,3,3,3,3,3,-74,3,3,3,3,-174,-175,3,-143,-140,3,-141,-145,3,-76,-79,-82,-75,-77,3,-81,-215,-214,-80,-216,-78,-127,3,-153,3,-151,-148,-157,-168,-69,-36,-35,3,3,3,-234,-233,3,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,3,3,-229,-232,-221,-83,-219,-68,-33,-32,3,3,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LSHIFT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,244,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,244,-262,-260,244,244,244,-259,244,244,-257,-256,244,244,244,244,244,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'LOR':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,243,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,167,-335,-28,-182,-27,167,-337,-87,-72,-337,167,-286,-285,167,167,-283,-287,-288,167,-284,167,167,167,-336,-183,167,167,-28,-337,167,-28,-337,-337,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,167,-337,-76,-79,-82,-75,167,-
77,167,167,-81,-215,-214,-80,-216,167,-78,167,167,-69,-284,167,167,-284,167,167,-244,-247,-245,-241,-242,-246,-248,167,-250,-251,-243,-249,-12,167,167,-11,167,167,167,167,-234,-233,167,-231,167,167,-217,167,-230,167,-84,-218,167,167,167,-337,-337,-198,167,167,167,-337,-284,-229,-232,167,-221,167,-83,-219,-68,167,-28,-337,167,-11,167,167,-220,167,167,167,-284,167,167,167,-337,167,-225,-224,-222,-84,167,167,167,-226,-223,167,-228,-227,]),'U16STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,139,146,148,149,150,151,153,163,165,166,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,166,-335,-28,-182,-27,166,-337,-87,-72,-337,166,-286,-285,-330,166,-327,166,-283,-287,238,-328,-288,-329,166,-284,166,166,166,-336,-183,166,166,-28,-337,166,-28,-337,-337,166,166,166,-334,-332,-331,-333,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,166,-337,-76,-79,-82,-75,166,-77,166,166,-81,-215,-214,-80,-216,166,-78,166,166,-69,-284,166,166,-284,166,166,-244,-247,-245,-241,-242,-246,-248,166,-250,-251,-243,-249,-12,166,166,-11,166,166,166,166,-234,-233,166,-231,166,166,-217,166,-230,166,-84,-218,166,166,166,-337,-337,-198,166,166,166,-337,-284,-229,-232,166,-221,166,-83,-219,-68,166,-28,-337,166,-11,166,166,-220,166,166,166,-284,166,166,166,-337,166,-225,-224,-222,-84,166,166,166,-226,-223,166,-228,-227,]),'RBRACE':([61,97,99,119,124,132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,181,184,185,186,188,189,190,191,195,196,197,224,225,227,228,230,232,234,235,236,237,238,261,263,268,273,284,285,286,289,291,298,300,301,302,303,305,306,308,309,313,315,318,325,326,327,332,370,374,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,424,425,428,432,435,437,438,439,446,450,451,468,469,472,473,476,478,480,482,483,487,496,497,500,505,506,510,523,524,528,536,546,547,550,551,554,555,557,558,565,575,576,578,579,],[-71,-335,191,-87,-72,-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-337,191,191,191,-159,-160,-156,-336,-171,191,-176,-274,-239,-337,-193,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-76,-79,-82,-75,-77,-6,-81,-215,-214,-80,-216,-5,-78,191,191,191,-157,191,191,-172,-69,191,-22,-21,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,-234,-233,-231,-217,-230,-81,-84,-218,-158,-173,-177,-240,-194,191,-196,-237,-281,-282,-290,-291,-275,-229,-232,-221,-83,-219,-68,-195,-254,191,-220,-197,-306,191,-299,-225,-224,-222,-84,-300,-226,-223,-228,-227,]),'_BOOL':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129
,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[34,-337,-113,-128,34,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,34,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,34,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,34,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,34,-131,-95,-101,-97,34,-53,-126,34,-88,34,34,-93,34,-147,-335,-146,34,-167,-166,-182,-100,-126,34,-87,-90,-94,-92,-61,-72,34,-144,-142,34,34,34,-73,34,-89,34,34,34,-149,-159,-160,-156,-336,34,-183,-30,34,34,-74,34,34,34,34,-174,-175,34,-143,-140,34,-141,-145,-76,-79,-82,-75,-77,34,-81,-215,-214,-80,-216,-78,-127,34,-153,34,-151,-148,-157,-168,-69,-36,-35,34,34,34,-234,-233,34,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,34,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LE':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,246,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,246,-262,-260,-264,246,-263,-259,-266,246,-257,-256,-265,246,246,246,246,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'SEMI':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,70,71,73,74,75,76,77,78,79,80,81,82,83,84,86,87,89,91,94,96,97,98,99,100,101,108,109,110,111,113,114,115,119,120,121,122,123,124,127,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,179,180,181,184,185,186,187,188,189,190,191,192,200,216,217,223,224,225,226,228,230,232,234,235,236,237,238,240,241,261,263,268,269,272,273,275,279,280,284,285,286,288,289,290,291,293,294,298,300,301,302,303,304,305,306,307,308,310,312,313,314,315,316,317,318,320,321,322,323,324,328,330,331,332,340,341,355,356,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,423,424,425,426,427,428,429,432,433,435,437,438,439,440,442,443,444,446,448,449,453,454,464,465,468,469,476,478,480,482,483,486,487,496,497,498,499,500,502,505,506,508,509,510,511,512,518,519,523,524,533,534,535,536,537,539,547,551,552,554,555,557,558,565,568,569,574,575,576,577,578,579,],[19,-337,-113,-128,-337,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,-337,-29,-121,-116,-62,-112,-70,-52,-123,-117,119,-337,-337,-119,-337,-114,-130,19,-118,-71,-103,-337,-9,-131,-91,-10,-96,-20,-98,-54,-179,-178,-131,-37,-134,-85,-95,-101,-97,-19,-132,-53,-126,-337,-337,-93,-147,-335,-146,188,-167,-166,-136,-100,-138,-126,-16,-86,-15,-87,-90,-94,-92,-61,-72,-55,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-178,-73,-337
,188,188,188,-149,-159,-160,-156,-336,-337,-30,-38,-133,-74,-274,-239,-135,-193,-326,-280,-277,-334,-332,-331,-333,-174,-175,-298,-297,-279,-143,-235,-278,-140,-141,-145,-76,-79,-82,424,-75,425,-77,428,-14,-337,-81,-215,-214,-80,435,-216,-13,-337,-78,-312,-127,188,-153,188,-151,-148,-157,-26,-25,446,-161,-163,-168,-139,-137,-69,-36,-35,-44,-43,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,-292,-305,496,-234,-233,497,-337,-231,-337,-217,-13,-230,-81,-84,-218,-337,-152,-150,-165,-158,-170,-169,-31,-34,-39,-42,-240,-194,-237,-281,-282,-290,-291,-236,-275,-229,-232,533,-337,-221,-337,-83,-219,-162,-164,-68,-33,-32,-41,-40,-195,-254,-337,553,-337,-220,-337,-337,-306,-299,566,-225,-224,-222,-84,-300,575,-337,-337,-226,-223,-337,-228,-227,]),'_THREAD_LOCAL':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[11,11,-113,-128,11,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,11,-120,-115,-65,-102,11,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,11,11,-119,11,-114,-130,11,-118,-71,-103,11,-131,-96,-98,11,-131,-95,-101,-97,-53,11,11,-88,11,-147,-335,-146,-167,-166,-100,-126,11,-87,-61,-72,11,-73,11,-89,-149,-336,-30,11,-74,-174,-175,11,-76,-79,-82,-75,-77,11,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,11,11,11,-234,-233,11,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,11,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'LT':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,248,-328,-316,-329,-320,-276,-323,-312,-336,-274,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,248,-262,-260,-264,248,-263,-259,-266,248,-257,-256,-265,248,248,248,248,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'COMMA':([2,3,5,6,7,10,11,12,13,18,20,22,23,26,27,30,33,34,35,36,39,42,43,44,46,48,49,50,54,56,58,60,62,68,70,71,73,74,75,76,77,78,80,81,82,84,86,96,98,100,101,103,104,105,106,108,109,110,111,113,127,132,133,134,136,138,139,140,141,142,143,144,145,147,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,177,178,179,187,191,195,196,197,198,199,200,203,210,211,212,213,215,216,217,224,225,226,228,230,232,234,235,236,237,238,240,241,261,263,268,269,270,272,273,274,275,276,277,279,280,281,283,294,310,312,314,316,317,320,323,324,325,326,327,328,330,331,340,341,343,344,345,346,347,348,355,356,374,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,414,426,442,443,444,448,449,450,451,453,454,458,461,463,464,465,468,469,473,476,478,480,482,483,486,487,489,490,492,501,503,507,508,509,511,512,518,519,523,524,525,528,529,530,531,532,544,545,546,547,551,556,559,560,564
,565,570,571,],[-113,-128,-124,-110,-106,-104,-107,-125,-105,-99,-109,-120,-115,-102,-126,-108,-238,-111,-337,-122,-129,-29,-121,-116,-112,-52,-123,-117,-119,-114,-130,-118,-103,-96,126,-98,-54,-179,-178,-131,-37,-134,-95,-101,-97,-132,-53,-147,-146,-167,-166,-28,-180,-182,-27,-136,-100,-138,-126,202,-55,-317,-321,-318,-303,-324,-330,-313,-319,-144,-301,-274,-314,-142,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-337,-252,-178,-149,-336,-171,327,-176,-183,-181,-30,333,-186,-337,349,350,-191,-38,-133,-274,-239,-135,-193,-326,-280,-277,-334,-332,-331,-333,-174,-175,-298,-297,-279,-143,412,-235,-278,-203,-140,-204,-1,-141,-145,-2,-206,412,-312,-127,-153,-151,-148,445,-161,-163,327,327,-172,-168,-139,-137,-36,-35,-190,-204,-56,-188,-45,-189,-44,-43,472,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,412,-270,-258,-296,-295,-294,-293,412,-292,-310,484,485,-305,-205,412,-152,-150,-165,-170,-169,-173,-177,-31,-34,-57,-192,-187,-39,-42,-240,-194,-196,-237,-281,-282,-290,-291,-236,-275,-213,-207,-211,412,412,412,-162,-164,-33,-32,-41,-40,-195,-254,-311,550,-209,-208,-210,-212,-51,-50,-197,-306,-299,412,-46,-49,412,-300,-48,-47,]),'U16CHAR_CONST':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,169,-335,-28,-182,-27,169,-337,-87,-72,-337,169,-286,-285,169,169,-283,-287,-288,169,-284,169,169,169,-336,-183,169,169,-28,-337,169,-28,-337,-337,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,169,-337,-76,-79,-82,-75,169,-77,169,169,-81,-215,-214,-80,-216,169,-78,169,169,-69,-284,169,169,-284,169,169,-244,-247,-245,-241,-242,-246,-248,169,-250,-251,-243,-249,-12,169,169,-11,169,169,169,169,-234,-233,169,-231,169,169,-217,169,-230,169,-84,-218,169,169,169,-337,-337,-198,169,169,169,-337,-284,-229,-232,169,-221,169,-83,-219,-68,169,-28,-337,169,-11,169,169,-220,169,169,169,-284,169,169,169,-337,169,-225,-224,-222,-84,169,169,169,-226,-223,169,-228,-227,]),'OFFSETOF':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,137,146,149,150,151,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,170,-335,-28,-182,-27,170,-337,-87,-72,-337,170,-286,-285,170,170,-283,-287,-288,170,-284,170,170,170,-336,-183,170,170,-28,-337,170,-28,-337,-337,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,170,
170,170,-337,-76,-79,-82,-75,170,-77,170,170,-81,-215,-214,-80,-216,170,-78,170,170,-69,-284,170,170,-284,170,170,-244,-247,-245,-241,-242,-246,-248,170,-250,-251,-243,-249,-12,170,170,-11,170,170,170,170,-234,-233,170,-231,170,170,-217,170,-230,170,-84,-218,170,170,170,-337,-337,-198,170,170,170,-337,-284,-229,-232,170,-221,170,-83,-219,-68,170,-28,-337,170,-11,170,170,-220,170,170,170,-284,170,170,170,-337,170,-225,-224,-222,-84,170,170,170,-226,-223,170,-228,-227,]),'_ATOMIC':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,35,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,103,105,109,111,117,118,119,120,121,122,123,124,128,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,205,206,211,214,219,220,223,229,231,233,239,240,241,267,269,275,278,279,280,282,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,446,448,449,453,454,459,460,496,497,500,505,506,510,511,512,514,515,536,554,555,557,558,575,576,578,579,],[29,65,-113,-128,76,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,65,-120,-115,-65,-102,65,-131,-108,-238,-111,76,-122,-63,-129,112,-29,-121,-116,-62,-112,-70,-52,-123,-117,65,65,-119,65,-114,-130,29,-118,-71,-103,65,-9,-131,-91,-10,-96,-98,65,-131,-95,-101,-97,29,-53,65,76,-88,112,65,-93,29,-147,-335,-146,29,-167,-166,76,-182,-100,-126,76,29,-87,-90,-94,-92,-61,-72,76,29,-144,-142,65,29,76,-73,65,-89,29,29,29,-149,-159,-160,-156,-336,76,-183,-30,76,76,76,112,76,76,-74,29,29,29,29,-174,-175,29,-143,-140,29,-141,-145,76,-76,-79,-82,-75,-77,65,-81,-215,-214,-80,-216,-78,-127,29,-153,29,-151,-148,-157,-168,-69,-36,-35,29,29,29,-234,-233,65,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,76,29,-229,-232,-221,-83,-219,-68,-33,-32,76,76,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'TYPEDEF':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[7,7,-113,-128,7,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,7,-120,-115,-65,-102,7,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,7,7,-119,7,-114,-130,7,-118,-71,-103,7,-131,-96,-98,7,-131,-95,-101,-97,-53,7,7,-88,7,-147,-335,-146,-167,-166,-100,-126,7,-87,-61,-72,7,-73,7,-89,-149,-336,-30,7,-74,-174,-175,7,-76,-79,-82,-75,-77,7,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,7,7,7,-234,-233,7,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,7,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'XOR':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,251,-328,-316,-329,-320,-276,-323,-312,-336,-274,
-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-261,251,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,251,-267,-269,251,-258,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'AUTO':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,65,68,71,75,76,80,81,82,86,87,89,90,93,96,97,98,100,101,109,111,118,119,123,124,129,180,181,182,187,191,200,211,223,240,241,278,284,285,286,289,291,298,300,301,302,303,305,308,312,314,316,317,328,332,340,341,342,350,422,424,425,427,428,432,435,437,438,439,442,443,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[26,26,-113,-128,26,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,26,-120,-115,-65,-102,26,-131,-108,-238,-111,-122,-63,-129,-29,-121,-116,-62,-112,-70,-52,-123,-117,26,26,-119,26,-114,-130,26,-118,-71,-103,26,-131,-96,-98,26,-131,-95,-101,-97,-53,26,26,-88,26,-147,-335,-146,-167,-166,-100,-126,26,-87,-61,-72,26,-73,26,-89,-149,-336,-30,26,-74,-174,-175,26,-76,-79,-82,-75,-77,26,-81,-215,-214,-80,-216,-78,-127,-153,-151,-148,-168,-69,-36,-35,26,26,26,-234,-233,26,-231,-217,-230,-81,-84,-218,-152,-150,-170,-169,-31,-34,26,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'DIVEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,357,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'TIMES':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,22,23,25,26,27,29,30,33,34,35,36,37,38,39,40,43,44,45,46,47,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,76,80,81,82,85,87,89,91,94,96,97,98,100,101,103,104,105,106,109,111,116,117,119,120,121,122,123,124,126,128,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,156,158,160,161,162,163,164,165,166,167,168,169,171,173,174,175,176,177,180,181,187,191,192,198,201,202,204,205,206,211,218,219,220,223,224,227,229,230,231,232,233,234,235,236,237,238,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,268,269,273,275,278,279,280,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,312,314,316,317,319,328,329,332,336,338,339,342,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,442,443,445,447,448,449,459,472,475,477,478,480,481,482,483,484,487,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[35,-337,-113,-128,35,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,-120,-115,-65,-102,-126,-131,-108,-238,-111,-337,-122,35,-63,-129,35,-121,-116,-62,-112,-70,-123,-117,-337,-337,-119,-337,-114,-130,35,-118,-71,-103,-337,-9,-131,-91,-10,-96,35,-98,-131,-95,-101,-97,173,-126,35,35,-93,-147,-335,-146,-167,-166,-28,35,-182,-27,-100,-126,173,-337,-87,-90,-94,-92,-61,-72,35,-337,173,-317,-321,-318,-286,-3
03,-285,-324,-330,-313,-319,-144,-301,-274,-314,173,-142,-327,173,-283,-287,-325,-304,-322,-302,-255,-315,-289,253,-328,-316,-288,-329,-320,-276,-323,173,-284,173,173,-312,35,-73,173,-149,-336,35,-183,173,35,336,-28,-337,35,352,-28,-337,-74,-274,-337,173,-326,173,-280,173,-277,-334,-332,-331,-333,-174,-175,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,173,-298,-297,173,173,-279,-143,-278,-140,35,-141,-145,420,-76,-79,-82,-75,173,-77,173,173,-81,-215,-214,-80,-216,173,-78,-312,-127,-153,-151,-148,173,-168,173,-69,-284,173,173,35,-284,173,173,-244,-247,-245,-241,-242,-246,-248,173,-250,-251,-243,-249,-12,173,173,-11,253,253,253,253,253,253,253,253,253,253,-257,-256,253,253,253,253,253,-258,-296,-295,-294,-293,-292,-305,173,173,173,494,-234,-233,173,-231,173,173,-217,173,-230,173,-84,-218,173,173,-152,-150,35,173,-170,-169,-337,-337,-198,173,-281,-282,173,-290,-291,173,-275,-337,-284,-229,-232,173,-221,173,-83,-219,-68,541,-28,-337,173,-11,173,173,-220,173,173,173,-284,173,173,-306,173,-337,-299,173,-225,-224,-222,-84,-300,173,173,173,-226,-223,173,-228,-227,]),'LPAREN':([0,1,2,3,4,5,6,7,8,10,11,12,13,14,15,16,17,18,19,20,22,23,25,26,27,29,30,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,72,73,76,77,80,81,82,85,86,87,89,91,94,96,97,98,100,101,103,104,105,106,109,111,112,116,117,119,120,121,122,123,124,126,127,128,131,132,133,134,135,136,137,138,139,140,141,142,143,145,146,147,148,149,150,151,152,153,154,155,156,160,161,163,164,165,166,167,168,169,170,171,173,174,175,176,177,180,181,187,191,192,198,199,200,201,202,204,205,206,211,216,218,219,220,223,227,229,230,231,233,235,236,237,238,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,269,275,276,278,279,280,282,283,284,285,286,289,290,291,292,296,297,298,299,300,301,302,303,305,307,308,310,311,312,314,316,317,319,328,329,332,336,338,339,340,341,342,344,345,347,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,402,403,404,405,407,411,412,413,414,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,442,443,445,447,448,449,453,454,457,458,459,464,465,472,475,477,481,482,483,484,488,489,490,492,494,496,497,499,500,502,504,505,506,510,511,512,513,514,515,518,519,521,522,529,530,531,532,533,535,536,537,538,539,541,542,543,544,545,547,549,550,551,553,554,555,557,558,559,560,565,566,569,570,571,574,575,576,577,578,579,],[37,-337,-113,-128,69,-124,-110,-106,85,-104,-107,-125,-105,-64,37,-60,-67,-99,-66,-109,-120,-115,-65,-102,-126,95,-108,-238,-111,-337,-122,37,-63,-129,37,116,-29,-121,-116,-62,-112,-70,118,-123,-117,-337,-337,-119,-337,-114,-130,37,-118,-71,-103,-337,-9,95,-91,-10,-96,69,-98,69,129,-131,-37,-95,-101,-97,174,118,-126,69,37,-93,-147,-335,-146,-167,-166,-28,-180,-182,-27,-100,-126,95,174,-337,-87,-90,-94,-92,-61,-72,69,129,-337,229,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-144,-301,-314,231,-142,-327,233,-283,-287,-325,-304,-322,239,-302,-315,-289,-328,-316,-288,-329,-320,266,-323,267,174,-284,229,233,-312,278,-73,229,-149,-336,69,-183,-181,-30,229,69,229,-28,-337,342,-38,229,-28,-337,-74,-337,229,-326,229,229,-334,-332,-331,-333,-174,-175,174,174,174,174,174,174,174,174,174,174,174,174,174,174,174,174,229,174,174,-298,-297,229,229,-143,-140,278,278,-141,-145,-337,422,-76,-79,-82,-75,229,-77,427,430,174,229,434,-81,-215,-214,-80,-216,229,-78,-312,441,-127,-153,-151,-148,174,-168,174,-69,-284,229,229,-36,-35,342,342,460,-45,-284,229,229,-44,-43,-244,-
247,-245,-241,-242,-246,-248,229,-250,-251,-243,-249,-12,174,229,-11,-296,-295,-294,-293,-292,-305,229,174,422,229,229,-234,-233,229,-231,229,229,-217,229,-230,229,-84,-218,229,229,-152,-150,69,174,-170,-169,-31,-34,342,460,-337,-39,-42,-337,-198,174,174,-290,-291,229,-337,-213,-207,-211,-284,-229,-232,229,-221,229,538,-83,-219,-68,-33,-32,229,-28,-337,-41,-40,229,-11,-209,-208,-210,-212,229,229,-220,229,229,229,-284,229,229,-51,-50,-306,229,-337,-299,229,-225,-224,-222,-84,-46,-49,-300,229,229,-48,-47,229,-226,-223,229,-228,-227,]),'MINUSMINUS':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,132,133,134,135,136,137,138,139,140,141,143,145,146,148,149,150,151,152,153,154,156,160,161,163,164,165,166,167,168,169,171,173,174,175,176,181,191,198,201,204,205,206,218,219,220,227,229,230,231,233,235,236,237,238,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,263,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,310,319,329,332,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,402,403,404,405,407,411,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,459,472,475,477,481,482,483,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,547,549,550,551,553,554,555,557,558,565,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,175,-335,-28,-182,-27,175,-337,-87,-72,-337,175,-317,-321,-318,-286,-303,-285,-324,-330,-313,-319,-301,-314,175,-327,175,-283,-287,-325,-304,-322,-302,-315,-289,-328,-316,-288,-329,-320,261,-323,175,-284,175,175,-312,175,-336,-183,175,175,-28,-337,175,-28,-337,-337,175,-326,175,175,-334,-332,-331,-333,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,175,-298,-297,175,175,-337,-76,-79,-82,-75,175,-77,175,175,-81,-215,-214,-80,-216,175,-78,-312,175,175,-69,-284,175,175,-284,175,175,-244,-247,-245,-241,-242,-246,-248,175,-250,-251,-243,-249,-12,175,175,-11,-296,-295,-294,-293,-292,-305,175,175,175,175,-234,-233,175,-231,175,175,-217,175,-230,175,-84,-218,175,175,175,-337,-337,-198,175,175,-290,-291,175,-337,-284,-229,-232,175,-221,175,-83,-219,-68,175,-28,-337,175,-11,175,175,-220,175,175,175,-284,175,175,-306,175,-337,-299,175,-225,-224,-222,-84,-300,175,175,175,-226,-223,175,-228,-227,]),'ID':([0,1,2,3,4,5,6,7,10,11,12,13,14,15,16,17,18,19,20,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,43,44,45,46,47,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,69,71,72,76,80,81,82,85,87,89,91,94,96,97,98,100,101,102,103,104,105,106,109,111,116,117,118,119,120,121,122,123,124,126,128,129,131,135,137,142,146,147,149,150,151,165,171,173,174,175,180,181,187,191,192,193,194,198,199,201,202,204,205,206,211,218,219,220,223,227,229,231,233,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,262,264,265,266,269,275,279,280,282,284,285,286,287,289,290,291,297,298,300,301,302,303,305,307,308,312,314,316,317,319,327,328,329,332,336,338,339,342,344,349,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,372,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,442,443,445,447,448,449,457,459,460,472,475,477,481,484,485,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,548,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[42,-337,-113,-128,42,-124,-110,-106,-104,-107,-125,-105,-64,42,-60,-67,-99,-66,-109,-120,-115,-154,-65,-102,-126,-155,-131,-108,98,101,-238,-
111,-337,-122,42,-63,-129,42,-121,-116,-62,-112,-70,-123,-117,-337,-337,-119,-337,-114,-130,42,-118,-71,-103,-337,-9,-131,-91,-10,-96,42,-98,42,-131,-95,-101,-97,176,-126,42,42,-93,-147,-335,-146,-167,-166,197,-28,-180,-182,-27,-100,-126,176,-337,176,-87,-90,-94,-92,-61,-72,42,-337,176,176,-286,-285,-144,176,-142,176,-283,-287,-288,176,-284,176,176,-73,310,-149,-336,42,197,197,-183,-181,176,42,176,-28,-337,42,176,-28,-337,-74,-337,176,176,176,-174,-175,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,403,405,176,176,-143,-140,-141,-145,-337,-76,-79,-82,423,-75,176,-77,176,310,-81,-215,-214,-80,-216,310,-78,-127,-153,-151,-148,176,197,-168,176,-69,-284,176,176,42,42,176,-284,176,176,-244,-247,-245,-241,-242,-246,-248,176,-250,-251,-243,-249,-12,176,176,176,-11,176,176,176,176,-234,-233,176,-231,310,176,-217,176,-230,310,-84,-218,310,176,-152,-150,42,176,-170,-169,42,-337,176,-337,-198,176,176,176,176,-337,-284,-229,-232,176,-221,310,-83,-219,-68,176,-28,-337,176,-11,176,310,-220,310,176,310,-284,176,176,176,176,-337,176,-225,-224,-222,-84,176,310,310,-226,-223,310,-228,-227,]),'IF':([61,97,119,124,181,191,284,285,286,289,291,298,300,301,302,303,305,307,308,332,424,425,428,429,432,435,437,438,439,440,496,497,500,502,505,506,510,535,536,537,539,554,555,557,558,569,574,575,576,577,578,579,],[-71,-335,-87,-72,311,-336,-76,-79,-82,-75,-77,311,-81,-215,-214,-80,-216,311,-78,-69,-234,-233,-231,311,-217,-230,311,-84,-218,311,-229,-232,-221,311,-83,-219,-68,311,-220,311,311,-225,-224,-222,-84,311,311,-226,-223,311,-228,-227,]),'STRING_LITERAL':([3,39,58,61,76,85,97,103,105,106,116,117,119,124,128,131,135,136,137,146,149,150,151,152,165,171,173,174,175,181,191,198,201,204,205,206,218,219,220,227,229,230,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,282,284,285,286,289,290,291,297,298,300,301,302,303,305,307,308,319,329,332,333,336,338,339,352,353,354,357,358,359,360,361,362,363,364,365,366,367,368,369,373,375,377,412,413,419,421,424,425,427,428,429,430,432,434,435,437,438,439,440,441,447,452,459,472,475,477,481,484,488,494,496,497,499,500,502,505,506,510,513,514,515,521,522,533,535,536,537,538,539,541,542,543,549,550,553,554,555,557,558,566,569,574,575,576,577,578,579,],[-128,-129,-130,-71,-131,152,-335,-28,-182,-27,152,-337,-87,-72,-337,152,-286,230,-285,152,152,-283,-287,-325,-288,152,-284,152,152,152,-336,-183,152,152,-28,-337,152,-28,-337,-337,152,-326,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,152,-337,-76,-79,-82,-75,152,-77,152,152,-81,-215,-214,-80,-216,152,-78,152,152,-69,152,-284,152,152,-284,152,152,-244,-247,-245,-241,-242,-246,-248,152,-250,-251,-243,-249,-12,152,152,-11,152,152,152,152,-234,-233,152,-231,152,152,-217,152,-230,152,-84,-218,152,152,152,230,-337,-337,-198,152,152,152,-337,-284,-229,-232,152,-221,152,-83,-219,-68,152,-28,-337,152,-11,152,152,-220,152,152,152,-284,152,152,152,-337,152,-225,-224,-222,-84,152,152,152,-226,-223,152,-228,-227,]),'FLOAT':([0,1,2,3,4,5,6,7,10,11,12,13,14,16,17,18,19,20,21,22,23,25,26,27,29,30,33,34,36,38,39,40,42,43,44,45,46,47,48,49,50,52,53,54,55,56,58,59,60,61,62,63,64,65,66,67,68,71,75,76,80,81,82,85,86,87,89,90,91,93,94,95,96,97,98,99,100,101,105,109,111,118,119,120,121,122,123,124,129,142,147,172,174,177,180,181,182,184,185,186,187,188,189,190,191,192,198,200,211,214,223,229,231,233,239,240,241,267,269,275,278,279,280,284,285,286,289,291,298,300,301,302,303,305,308,312,313,314,315,316,317,318,328,332,340,341,342,350,422,424,42
5,427,428,432,435,437,438,439,442,443,446,448,449,453,454,460,496,497,500,505,506,510,511,512,536,554,555,557,558,575,576,578,579,],[44,-337,-113,-128,44,-124,-110,-106,-104,-107,-125,-105,-64,-60,-67,-99,-66,-109,44,-120,-115,-65,-102,-126,-131,-108,-238,-111,-122,-63,-129,44,-29,-121,-116,-62,-112,-70,-52,-123,-117,-337,-337,-119,-337,-114,-130,44,-118,-71,-103,-337,-9,-131,-91,-10,-96,-98,44,-131,-95,-101,-97,44,-53,-126,44,-88,44,44,-93,44,-147,-335,-146,44,-167,-166,-182,-100,-126,44,-87,-90,-94,-92,-61,-72,44,-144,-142,44,44,44,-73,44,-89,44,44,44,-149,-159,-160,-156,-336,44,-183,-30,44,44,-74,44,44,44,44,-174,-175,44,-143,-140,44,-141,-145,-76,-79,-82,-75,-77,44,-81,-215,-214,-80,-216,-78,-127,44,-153,44,-151,-148,-157,-168,-69,-36,-35,44,44,44,-234,-233,44,-231,-217,-230,-81,-84,-218,-152,-150,-158,-170,-169,-31,-34,44,-229,-232,-221,-83,-219,-68,-33,-32,-220,-225,-224,-222,-84,-226,-223,-228,-227,]),'XOREQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,361,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'LSHIFTEQUAL':([132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,160,161,163,164,166,167,168,169,176,191,224,230,232,234,235,236,237,238,261,263,268,273,310,402,403,404,405,407,411,478,480,482,483,487,547,551,565,],[-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-315,-289,-328,-316,-329,-320,-276,-323,-312,-336,363,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-278,-312,-296,-295,-294,-293,-292,-305,-281,-282,-290,-291,-275,-306,-299,-300,]),'RBRACKET':([3,39,58,76,103,105,106,117,128,132,133,134,136,138,139,140,141,143,144,145,148,152,153,154,156,158,160,161,162,163,164,166,167,168,169,176,178,191,198,204,205,218,219,224,225,230,232,234,235,236,237,238,261,263,268,272,273,282,334,335,336,337,351,352,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,400,401,402,403,404,405,406,407,411,419,420,421,455,456,459,466,467,468,471,476,478,480,482,483,486,487,491,493,494,513,514,524,540,541,547,551,561,562,564,565,],[-128,-129,-130,-131,-28,-182,-27,-337,-337,-317,-321,-318,-303,-324,-330,-313,-319,-301,-274,-314,-327,-325,-304,-322,-302,-255,-315,-289,-253,-328,-316,-329,-320,-276,-323,-312,-252,-336,-183,-337,-28,-337,-28,-274,-239,-326,-280,-277,-334,-332,-331,-333,-298,-297,-279,-235,-278,-337,453,-4,454,-3,464,465,-261,-273,-262,-260,-264,-268,-263,-259,-266,-271,-257,-256,-265,-272,-267,-269,-270,-258,-296,-295,-294,-293,482,-292,-305,-337,492,-337,511,512,-337,518,519,-240,520,-237,-281,-282,-290,-291,-236,-275,529,530,531,-337,-28,-254,559,560,-306,-299,570,571,572,-300,]),}
+
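+# The dictionary above is the terminal-action half of a PLY-generated LALR
+# table (this file appears to be pycparser's auto-generated yacctab module):
+# each key is a terminal token name, and each value is a pair of parallel
+# lists mapping parser-state numbers to encoded actions.  The loop below
+# inverts that layout into _lr_action[state][token] = action, the shape the
+# parser consults at run time.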
+_lr_action = {}
+for _k, _v in _lr_action_items.items():
+   for _x,_y in zip(_v[0],_v[1]):
+      if not _x in _lr_action: _lr_action[_x] = {}
+      _lr_action[_x][_k] = _y
+del _lr_action_items
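+# A rough sketch of how a PLY-style driver consults the rebuilt action table
+# (the real driver lives in ply/yacc.py; the encoding below is the convention
+# that generated tables of this form follow):
+#
+#     action = _lr_action[state].get(token_type)
+#     if action is None:   # no entry for this token: syntax error
+#         ...
+#     elif action > 0:     # shift and move to state `action`
+#         ...
+#     elif action < 0:     # reduce by production number -action
+#         ...
+#     else:                # action == 0: accept
+#         ...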
+
+_lr_goto_items = {'expression_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[284,284,284,284,284,284,284,284,284,284,284,284,284,]),'struct_or_union_specifier':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,]),'init_declarator_list':([4,89,],[70,70,]),'init_declarator_list_opt':([4,89,],[79,79,]),'iteration_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[285,285,285,285,285,285,285,285,285,285,285,285,285,]),'static_assert':([0,59,181,298,307,429,437,440,502,535,537,539,569,574,577,],[17,17,286,286,286,286,286,286,286,286,286,286,286,286,286,]),'unified_string_literal':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,333,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,452,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,136,]),'assignment_expression_opt':([204,218,419,421,513,],[334,351,491,493,540,]),'brace_open':([31,32,92,96,98,100,101,130,131,181,201,229,298,307,375,413,429,437,440,477,478,479,502,521,535,537,539,569,574,577,],[99,102,181,184,185,193,194,181,227,181,227,181,181,181,227,488,181,181,181,488,488,488,181,227,181,181,181,181,181,181,]),'enumerator':([102,193,194,327,],[195,195,195,450,]),'typeid_noparen_declarator':([211,],[348,]),'type_qualifier_list_opt':([35,117,128,206,220,282,459,515,],[104,204,218,339,354,419,513,543,]),'declaration_specifiers_no_type_opt':([1,27,52,53,55,63,87,],[66,94,120,121,122,94,94,]),'expression_opt':([181,298,307,427,429,437,440,499,502,533,535,537,539,553,566,569,574,577,],[288,288,288,498,288,288,288,534,288,552,288,288,288,567,573,288,288,288,]),'designation':([227,472,488,550,],[369,369,369,369,]),'parameter_list':([118,129,278,342,422,460,],[213,213,213,213,213,213,]),'alignment_specifier':([0,1,4,21,27,52,53,55,59,63,75,85,87,89,93,95,99,118,129,174,177,181,184,185,186,192,211,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[53,53,81,53,53,53,53,53,53,53,53,142,53,81,53,142,142,53,53,142,280,53,142,142,142,280,81,142,142,142,142,142,53,53,142,142,53,53,53,53,53,]),'labeled_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[289,289,289,289,289,289,289,289,289,289,289,289,289,]),'abstract_declarator':([177,211,278,342,],[281,281,418,418,]),'translation_unit':([0,],[59,]),'init_declarator':([4,89,126,202,],[84,84,217,331,]),'direct_abstract_declarator':([177,211,276,278,342,344,457,],[283,283,414,283,283,414,414,]),'designator_list':([227,472,488,550,],[376,376,376,376,]),'identifier':([85,116,118,129,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,349,353,354,364,372,373,375,412,413,419,421,427,429,430,434,437,440,441,447,460,477,481,484,485,499,502,513,521,533,535,537,538,539,542,543,548,549,553,566,569,574,577,],[143,143,215,215,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143
,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,143,461,143,143,143,470,143,143,143,143,143,143,143,143,143,143,143,143,143,143,215,143,143,143,527,143,143,143,143,143,143,143,143,143,143,143,563,143,143,143,143,143,143,]),'offsetof_member_designator':([485,],[526,]),'unary_expression':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[144,144,224,232,234,144,224,273,224,224,224,224,224,224,224,144,144,144,144,144,144,144,144,144,144,144,144,144,144,144,144,224,144,144,224,224,224,144,224,224,144,144,224,224,224,224,224,144,224,224,144,224,224,224,224,224,224,224,224,224,144,144,144,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,224,]),'abstract_declarator_opt':([177,211,],[274,343,]),'initializer':([131,201,375,521,],[226,330,473,546,]),'direct_id_declarator':([0,4,15,37,40,59,69,72,89,91,126,192,202,211,342,344,445,457,],[48,48,86,48,48,48,48,86,48,48,48,48,48,48,48,86,48,86,]),'struct_declaration_list':([99,184,185,],[186,313,315,]),'pp_directive':([0,59,],[14,14,]),'declaration_list':([21,75,],[93,93,]),'id_init_declarator':([40,91,],[108,108,]),'type_specifier':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[18,18,109,18,18,147,109,18,147,147,18,18,269,147,18,147,147,147,109,147,147,147,147,147,18,18,147,147,18,18,18,18,18,]),'compound_statement':([92,130,181,229,298,307,429,437,440,502,535,537,539,569,574,577,],[180,223,291,378,291,291,291,291,291,291,291,291,291,291,291,291,]),'pointer':([0,4,37,40,59,69,89,91,104,126,177,192,202,211,278,342,445,],[15,72,15,15,15,72,72,15,199,72,276,72,72,344,276,457,72,]),'typeid_declarator':([4,69,89,126,192,202,445,],[74,125,74,74,74,74,74,]),'id_init_declarator_list':([40,91,],[113,113,]),'declarator':([4,89,126,192,202,445,],[78,78,78,324,78,324,]),'argument_expression_list':([266,],[409,]),'struct_declarator_list_opt':([192,],[322,]),'block_item_list':([181,],[298,]),'parameter_type_list_opt':([278,342,422,],[417,417,495,]),'struct_declarator':([192,445,],[323,508,]),'type_qualifier':([0,1,4,21,27,35,52,53,55,59,63,75,85,87,89,93,95,99,103,117,118,128,129,172,174,177,181,184,185,186,192,205,206,211,219,220,229,231,233,239,267,278,282,298,313,315,342,350,422,427,459,460,514,515,],[52,52,80,52,52,105,52,52,52,52,52,52,105,52,80,52,105,105,198,105,52,105,52,198,105,279,52,105,105,105,279,198,105,80,198,105,105,105,105,105,105,52,105,52,105,105,52,52,52,52,105,52,198,105,]),'assignment_operator':([224,],[364,]),'expression':([174,181,229,231,233,258,265,290,298,307,427,429,430,434,437,440,441,499,502,533,535,537,538,539,549,553,566,569,574,577,],[270,294,270,270,270,399,406,426,294,294,294,294,501,503,294,294,507,294,294,294,294,294,556,294,564,294,294,294,294,294,]),'storage_class_specifier':([0,1,4,21,27,52,53,55,59,63,75,87,89,93,118,129,181,211,278,298,342,350,422,427,460,],[1,1,68,1,1,1,1,1,1,1,1,1,68,1,1,1,1,68,1,1,1,1,1,1,1,]),'unified_wstring_literal':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,
549,553,566,569,574,577,],[153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,153,]),'translation_unit_or_empty':([0,],[9,]),'initializer_list_opt':([227,],[370,]),'brace_close':([99,184,185,186,196,309,313,315,325,326,370,472,528,550,],[187,314,316,317,328,439,442,443,448,449,469,523,551,565,]),'direct_typeid_declarator':([4,69,72,89,126,192,202,445,],[73,73,127,73,73,73,73,73,]),'external_declaration':([0,59,],[16,123,]),'pragmacomp_or_statement':([307,429,440,502,535,537,539,569,574,577,],[436,500,506,536,554,555,557,576,578,579,]),'type_name':([85,95,174,229,231,233,239,267,],[157,183,271,379,380,381,382,410,]),'typedef_name':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,36,]),'pppragma_directive':([0,59,99,181,184,185,186,298,307,313,315,429,437,440,502,535,537,539,569,574,577,],[25,25,189,300,189,189,189,300,437,189,189,437,300,437,437,437,437,437,437,437,437,]),'statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[301,301,438,438,505,438,438,438,438,558,438,438,438,]),'cast_expression':([85,116,131,171,174,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[158,158,158,268,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,487,158,158,158,158,158,158,158,158,158,158,487,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,158,]),'atomic_specifier':([0,1,21,27,40,52,53,55,59,63,75,85,87,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[27,63,87,63,111,63,63,63,27,63,87,111,63,111,87,111,111,27,27,111,111,87,111,111,111,111,111,111,111,111,111,27,87,111,111,27,27,27,87,27,]),'struct_declarator_list':([192,],[320,]),'empty':([0,1,4,21,27,35,40,52,53,55,63,75,87,89,91,117,118,128,129,177,181,192,204,206,211,218,220,227,278,282,298,307,342,419,421,422,427,429,437,440,459,460,472,488,499,502,513,515,533,535,537,539,550,553,566,569,574,577,],[57,64,83,88,64,106,115,64,64,64,64,88,64,83,115,106,208,106,208,277,306,321,337,106,277,337,106,377,415,106,433,433,415,337,337,415,433,433,433,433,106,208,522,522,433,433,337,106,433,433,433,433,522,433,433,433,433,433,]),'parameter_declaration':([118,129,278,342,350,422,460,],[210,210,210,210,463,210,210,]),'primary_expression':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,16
1,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,161,]),'declaration':([0,21,59,75,93,181,298,427,],[38,90,38,90,182,302,302,499,]),'declaration_specifiers_no_type':([0,1,21,27,52,53,55,59,63,75,87,93,118,129,181,278,298,342,350,422,427,460,],[40,67,91,67,67,67,67,40,67,91,67,91,214,214,91,214,91,214,214,214,91,214,]),'jump_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[303,303,303,303,303,303,303,303,303,303,303,303,303,]),'enumerator_list':([102,193,194,],[196,325,326,]),'block_item':([181,298,],[305,432,]),'constant_expression':([85,116,297,319,329,373,447,],[159,203,431,444,451,471,509,]),'identifier_list_opt':([118,129,460,],[207,221,516,]),'constant':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,156,]),'type_specifier_no_typeid':([0,4,21,40,59,75,85,89,91,93,95,99,118,129,172,174,177,181,184,185,186,192,211,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[12,71,12,12,12,12,12,71,12,12,12,12,12,12,12,12,275,12,12,12,12,275,71,12,12,12,12,12,12,12,12,12,12,12,12,12,12,12,]),'struct_declaration':([99,184,185,186,313,315,],[190,190,190,318,318,318,]),'direct_typeid_noparen_declarator':([211,344,],[345,458,]),'id_declarator':([0,4,37,40,59,69,89,91,126,192,202,211,342,445,],[21,75,107,110,21,107,179,110,179,179,179,346,107,179,]),'selection_statement':([181,298,307,429,437,440,502,535,537,539,569,574,577,],[308,308,308,308,308,308,308,308,308,308,308,308,308,]),'postfix_expression':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,]),'initializer_list':([227,488,],[374,528,]),'unary_operator':([85,116,131,146,149,171,174,175,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,413,419,421,427,429,430,434,437,440,441,447,477,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,171,]),'struct_or_union':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,
233,239,267,278,298,313,315,342,350,422,427,460,],[31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,31,]),'block_item_list_opt':([181,],[309,]),'assignment_expression':([131,174,181,201,204,218,229,231,233,258,265,266,290,298,307,338,339,353,354,364,375,412,419,421,427,429,430,434,437,440,441,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[228,272,272,228,335,335,272,272,272,272,272,408,272,272,272,455,456,466,467,468,228,486,335,335,272,272,272,272,272,272,272,525,272,272,335,228,272,272,272,272,272,561,562,272,272,272,272,272,272,]),'designation_opt':([227,472,488,550,],[375,521,375,521,]),'parameter_type_list':([118,129,278,342,422,460,],[209,222,416,416,416,517,]),'type_qualifier_list':([35,85,95,99,117,128,174,184,185,186,206,220,229,231,233,239,267,282,313,315,459,515,],[103,172,172,172,205,219,172,172,172,172,103,103,172,172,172,172,172,103,172,172,514,103,]),'designator':([227,376,472,488,550,],[371,474,371,371,371,]),'id_init_declarator_list_opt':([40,91,],[114,114,]),'declaration_specifiers':([0,21,59,75,93,118,129,181,278,298,342,350,422,427,460,],[4,89,4,89,89,211,211,89,211,89,211,211,211,89,211,]),'identifier_list':([118,129,460,],[212,212,212,]),'declaration_list_opt':([21,75,],[92,130,]),'function_definition':([0,59,],[45,45,]),'binary_expression':([85,116,131,174,181,201,204,218,229,231,233,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,419,421,427,429,430,434,437,440,441,447,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[162,162,162,162,162,162,162,162,162,162,162,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,162,400,401,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,162,]),'enum_specifier':([0,21,40,59,75,85,91,93,95,99,118,129,172,174,181,184,185,186,214,229,231,233,239,267,278,298,313,315,342,350,422,427,460,],[49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,49,]),'decl_body':([0,21,59,75,93,181,298,427,],[51,51,51,51,51,51,51,51,]),'function_specifier':([0,1,4,21,27,52,53,55,59,63,75,87,89,93,118,129,181,211,278,298,342,350,422,427,460,],[55,55,82,55,55,55,55,55,55,55,55,55,82,55,55,55,55,82,55,55,55,55,55,55,55,]),'specifier_qualifier_list':([85,95,99,174,184,185,186,229,231,233,239,267,313,315,],[177,177,192,177,192,192,192,177,177,177,177,177,192,192,]),'conditional_expression':([85,116,131,174,181,201,204,218,229,231,233,258,265,266,290,297,298,307,319,329,338,339,353,354,364,373,375,412,419,421,427,429,430,434,437,440,441,447,481,484,499,502,513,521,533,535,537,538,539,542,543,549,553,566,569,574,577,],[178,178,225,225,225,225,225,225,225,225,225,225,225,225,225,178,225,225,178,178,225,225,225,225,225,178,225,225,225,225,225,225,225,225,225,225,225,178,524,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,225,]),}
+
+_lr_goto = {}
+for _k, _v in _lr_goto_items.items():
+   for _x, _y in zip(_v[0], _v[1]):
+      if not _x in _lr_goto: _lr_goto[_x] = {}
+      _lr_goto[_x][_k] = _y
+del _lr_goto_items
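+# _lr_goto[state][nonterminal] names the state to enter after a reduction has
+# produced that nonterminal.  Roughly, after reducing by production -action:
+#
+#     _, lhs_name, rhs_len, _, _, _ = _lr_productions[-action]
+#     next_state = _lr_goto[state_stack[-rhs_len - 1]][lhs_name]
+#
+# Each entry in the _lr_productions list that follows records the rule text,
+# its left-hand-side name, its right-hand-side length, the handler function
+# name, and the source file and line the grammar rule came from.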
+_lr_productions = [
+ ("S' -> translation_unit_or_empty","S'",1,None,None,None),
+ ('abstract_declarator_opt -> empty','abstract_declarator_opt',1,'p_abstract_declarator_opt','plyparser.py',43),
+ ('abstract_declarator_opt -> abstract_declarator','abstract_declarator_opt',1,'p_abstract_declarator_opt','plyparser.py',44),
+ ('assignment_expression_opt -> empty','assignment_expression_opt',1,'p_assignment_expression_opt','plyparser.py',43),
+ ('assignment_expression_opt -> assignment_expression','assignment_expression_opt',1,'p_assignment_expression_opt','plyparser.py',44),
+ ('block_item_list_opt -> empty','block_item_list_opt',1,'p_block_item_list_opt','plyparser.py',43),
+ ('block_item_list_opt -> block_item_list','block_item_list_opt',1,'p_block_item_list_opt','plyparser.py',44),
+ ('declaration_list_opt -> empty','declaration_list_opt',1,'p_declaration_list_opt','plyparser.py',43),
+ ('declaration_list_opt -> declaration_list','declaration_list_opt',1,'p_declaration_list_opt','plyparser.py',44),
+ ('declaration_specifiers_no_type_opt -> empty','declaration_specifiers_no_type_opt',1,'p_declaration_specifiers_no_type_opt','plyparser.py',43),
+ ('declaration_specifiers_no_type_opt -> declaration_specifiers_no_type','declaration_specifiers_no_type_opt',1,'p_declaration_specifiers_no_type_opt','plyparser.py',44),
+ ('designation_opt -> empty','designation_opt',1,'p_designation_opt','plyparser.py',43),
+ ('designation_opt -> designation','designation_opt',1,'p_designation_opt','plyparser.py',44),
+ ('expression_opt -> empty','expression_opt',1,'p_expression_opt','plyparser.py',43),
+ ('expression_opt -> expression','expression_opt',1,'p_expression_opt','plyparser.py',44),
+ ('id_init_declarator_list_opt -> empty','id_init_declarator_list_opt',1,'p_id_init_declarator_list_opt','plyparser.py',43),
+ ('id_init_declarator_list_opt -> id_init_declarator_list','id_init_declarator_list_opt',1,'p_id_init_declarator_list_opt','plyparser.py',44),
+ ('identifier_list_opt -> empty','identifier_list_opt',1,'p_identifier_list_opt','plyparser.py',43),
+ ('identifier_list_opt -> identifier_list','identifier_list_opt',1,'p_identifier_list_opt','plyparser.py',44),
+ ('init_declarator_list_opt -> empty','init_declarator_list_opt',1,'p_init_declarator_list_opt','plyparser.py',43),
+ ('init_declarator_list_opt -> init_declarator_list','init_declarator_list_opt',1,'p_init_declarator_list_opt','plyparser.py',44),
+ ('initializer_list_opt -> empty','initializer_list_opt',1,'p_initializer_list_opt','plyparser.py',43),
+ ('initializer_list_opt -> initializer_list','initializer_list_opt',1,'p_initializer_list_opt','plyparser.py',44),
+ ('parameter_type_list_opt -> empty','parameter_type_list_opt',1,'p_parameter_type_list_opt','plyparser.py',43),
+ ('parameter_type_list_opt -> parameter_type_list','parameter_type_list_opt',1,'p_parameter_type_list_opt','plyparser.py',44),
+ ('struct_declarator_list_opt -> empty','struct_declarator_list_opt',1,'p_struct_declarator_list_opt','plyparser.py',43),
+ ('struct_declarator_list_opt -> struct_declarator_list','struct_declarator_list_opt',1,'p_struct_declarator_list_opt','plyparser.py',44),
+ ('type_qualifier_list_opt -> empty','type_qualifier_list_opt',1,'p_type_qualifier_list_opt','plyparser.py',43),
+ ('type_qualifier_list_opt -> type_qualifier_list','type_qualifier_list_opt',1,'p_type_qualifier_list_opt','plyparser.py',44),
+ ('direct_id_declarator -> ID','direct_id_declarator',1,'p_direct_id_declarator_1','plyparser.py',126),
+ ('direct_id_declarator -> LPAREN id_declarator RPAREN','direct_id_declarator',3,'p_direct_id_declarator_2','plyparser.py',126),
+ ('direct_id_declarator -> direct_id_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_id_declarator',5,'p_direct_id_declarator_3','plyparser.py',126),
+ ('direct_id_declarator -> direct_id_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET','direct_id_declarator',6,'p_direct_id_declarator_4','plyparser.py',126),
+ ('direct_id_declarator -> direct_id_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET','direct_id_declarator',6,'p_direct_id_declarator_4','plyparser.py',127),
+ ('direct_id_declarator -> direct_id_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET','direct_id_declarator',5,'p_direct_id_declarator_5','plyparser.py',126),
+ ('direct_id_declarator -> direct_id_declarator LPAREN parameter_type_list RPAREN','direct_id_declarator',4,'p_direct_id_declarator_6','plyparser.py',126),
+ ('direct_id_declarator -> direct_id_declarator LPAREN identifier_list_opt RPAREN','direct_id_declarator',4,'p_direct_id_declarator_6','plyparser.py',127),
+ ('direct_typeid_declarator -> TYPEID','direct_typeid_declarator',1,'p_direct_typeid_declarator_1','plyparser.py',126),
+ ('direct_typeid_declarator -> LPAREN typeid_declarator RPAREN','direct_typeid_declarator',3,'p_direct_typeid_declarator_2','plyparser.py',126),
+ ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_typeid_declarator',5,'p_direct_typeid_declarator_3','plyparser.py',126),
+ ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET','direct_typeid_declarator',6,'p_direct_typeid_declarator_4','plyparser.py',126),
+ ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET','direct_typeid_declarator',6,'p_direct_typeid_declarator_4','plyparser.py',127),
+ ('direct_typeid_declarator -> direct_typeid_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET','direct_typeid_declarator',5,'p_direct_typeid_declarator_5','plyparser.py',126),
+ ('direct_typeid_declarator -> direct_typeid_declarator LPAREN parameter_type_list RPAREN','direct_typeid_declarator',4,'p_direct_typeid_declarator_6','plyparser.py',126),
+ ('direct_typeid_declarator -> direct_typeid_declarator LPAREN identifier_list_opt RPAREN','direct_typeid_declarator',4,'p_direct_typeid_declarator_6','plyparser.py',127),
+ ('direct_typeid_noparen_declarator -> TYPEID','direct_typeid_noparen_declarator',1,'p_direct_typeid_noparen_declarator_1','plyparser.py',126),
+ ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_typeid_noparen_declarator',5,'p_direct_typeid_noparen_declarator_3','plyparser.py',126),
+ ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET STATIC type_qualifier_list_opt assignment_expression RBRACKET','direct_typeid_noparen_declarator',6,'p_direct_typeid_noparen_declarator_4','plyparser.py',126),
+ ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET type_qualifier_list STATIC assignment_expression RBRACKET','direct_typeid_noparen_declarator',6,'p_direct_typeid_noparen_declarator_4','plyparser.py',127),
+ ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LBRACKET type_qualifier_list_opt TIMES RBRACKET','direct_typeid_noparen_declarator',5,'p_direct_typeid_noparen_declarator_5','plyparser.py',126),
+ ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LPAREN parameter_type_list RPAREN','direct_typeid_noparen_declarator',4,'p_direct_typeid_noparen_declarator_6','plyparser.py',126),
+ ('direct_typeid_noparen_declarator -> direct_typeid_noparen_declarator LPAREN identifier_list_opt RPAREN','direct_typeid_noparen_declarator',4,'p_direct_typeid_noparen_declarator_6','plyparser.py',127),
+ ('id_declarator -> direct_id_declarator','id_declarator',1,'p_id_declarator_1','plyparser.py',126),
+ ('id_declarator -> pointer direct_id_declarator','id_declarator',2,'p_id_declarator_2','plyparser.py',126),
+ ('typeid_declarator -> direct_typeid_declarator','typeid_declarator',1,'p_typeid_declarator_1','plyparser.py',126),
+ ('typeid_declarator -> pointer direct_typeid_declarator','typeid_declarator',2,'p_typeid_declarator_2','plyparser.py',126),
+ ('typeid_noparen_declarator -> direct_typeid_noparen_declarator','typeid_noparen_declarator',1,'p_typeid_noparen_declarator_1','plyparser.py',126),
+ ('typeid_noparen_declarator -> pointer direct_typeid_noparen_declarator','typeid_noparen_declarator',2,'p_typeid_noparen_declarator_2','plyparser.py',126),
+ ('translation_unit_or_empty -> translation_unit','translation_unit_or_empty',1,'p_translation_unit_or_empty','c_parser.py',509),
+ ('translation_unit_or_empty -> empty','translation_unit_or_empty',1,'p_translation_unit_or_empty','c_parser.py',510),
+ ('translation_unit -> external_declaration','translation_unit',1,'p_translation_unit_1','c_parser.py',518),
+ ('translation_unit -> translation_unit external_declaration','translation_unit',2,'p_translation_unit_2','c_parser.py',524),
+ ('external_declaration -> function_definition','external_declaration',1,'p_external_declaration_1','c_parser.py',534),
+ ('external_declaration -> declaration','external_declaration',1,'p_external_declaration_2','c_parser.py',539),
+ ('external_declaration -> pp_directive','external_declaration',1,'p_external_declaration_3','c_parser.py',544),
+ ('external_declaration -> pppragma_directive','external_declaration',1,'p_external_declaration_3','c_parser.py',545),
+ ('external_declaration -> SEMI','external_declaration',1,'p_external_declaration_4','c_parser.py',550),
+ ('external_declaration -> static_assert','external_declaration',1,'p_external_declaration_5','c_parser.py',555),
+ ('static_assert -> _STATIC_ASSERT LPAREN constant_expression COMMA unified_string_literal RPAREN','static_assert',6,'p_static_assert_declaration','c_parser.py',560),
+ ('static_assert -> _STATIC_ASSERT LPAREN constant_expression RPAREN','static_assert',4,'p_static_assert_declaration','c_parser.py',561),
+ ('pp_directive -> PPHASH','pp_directive',1,'p_pp_directive','c_parser.py',569),
+ ('pppragma_directive -> PPPRAGMA','pppragma_directive',1,'p_pppragma_directive','c_parser.py',575),
+ ('pppragma_directive -> PPPRAGMA PPPRAGMASTR','pppragma_directive',2,'p_pppragma_directive','c_parser.py',576),
+ ('function_definition -> id_declarator declaration_list_opt compound_statement','function_definition',3,'p_function_definition_1','c_parser.py',586),
+ ('function_definition -> declaration_specifiers id_declarator declaration_list_opt compound_statement','function_definition',4,'p_function_definition_2','c_parser.py',604),
+ ('statement -> labeled_statement','statement',1,'p_statement','c_parser.py',619),
+ ('statement -> expression_statement','statement',1,'p_statement','c_parser.py',620),
+ ('statement -> compound_statement','statement',1,'p_statement','c_parser.py',621),
+ ('statement -> selection_statement','statement',1,'p_statement','c_parser.py',622),
+ ('statement -> iteration_statement','statement',1,'p_statement','c_parser.py',623),
+ ('statement -> jump_statement','statement',1,'p_statement','c_parser.py',624),
+ ('statement -> pppragma_directive','statement',1,'p_statement','c_parser.py',625),
+ ('statement -> static_assert','statement',1,'p_statement','c_parser.py',626),
+ ('pragmacomp_or_statement -> pppragma_directive statement','pragmacomp_or_statement',2,'p_pragmacomp_or_statement','c_parser.py',674),
+ ('pragmacomp_or_statement -> statement','pragmacomp_or_statement',1,'p_pragmacomp_or_statement','c_parser.py',675),
+ ('decl_body -> declaration_specifiers init_declarator_list_opt','decl_body',2,'p_decl_body','c_parser.py',694),
+ ('decl_body -> declaration_specifiers_no_type id_init_declarator_list_opt','decl_body',2,'p_decl_body','c_parser.py',695),
+ ('declaration -> decl_body SEMI','declaration',2,'p_declaration','c_parser.py',755),
+ ('declaration_list -> declaration','declaration_list',1,'p_declaration_list','c_parser.py',764),
+ ('declaration_list -> declaration_list declaration','declaration_list',2,'p_declaration_list','c_parser.py',765),
+ ('declaration_specifiers_no_type -> type_qualifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_1','c_parser.py',775),
+ ('declaration_specifiers_no_type -> storage_class_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_2','c_parser.py',780),
+ ('declaration_specifiers_no_type -> function_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_3','c_parser.py',785),
+ ('declaration_specifiers_no_type -> atomic_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_4','c_parser.py',792),
+ ('declaration_specifiers_no_type -> alignment_specifier declaration_specifiers_no_type_opt','declaration_specifiers_no_type',2,'p_declaration_specifiers_no_type_5','c_parser.py',797),
+ ('declaration_specifiers -> declaration_specifiers type_qualifier','declaration_specifiers',2,'p_declaration_specifiers_1','c_parser.py',802),
+ ('declaration_specifiers -> declaration_specifiers storage_class_specifier','declaration_specifiers',2,'p_declaration_specifiers_2','c_parser.py',807),
+ ('declaration_specifiers -> declaration_specifiers function_specifier','declaration_specifiers',2,'p_declaration_specifiers_3','c_parser.py',812),
+ ('declaration_specifiers -> declaration_specifiers type_specifier_no_typeid','declaration_specifiers',2,'p_declaration_specifiers_4','c_parser.py',817),
+ ('declaration_specifiers -> type_specifier','declaration_specifiers',1,'p_declaration_specifiers_5','c_parser.py',822),
+ ('declaration_specifiers -> declaration_specifiers_no_type type_specifier','declaration_specifiers',2,'p_declaration_specifiers_6','c_parser.py',827),
+ ('declaration_specifiers -> declaration_specifiers alignment_specifier','declaration_specifiers',2,'p_declaration_specifiers_7','c_parser.py',832),
+ ('storage_class_specifier -> AUTO','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',837),
+ ('storage_class_specifier -> REGISTER','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',838),
+ ('storage_class_specifier -> STATIC','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',839),
+ ('storage_class_specifier -> EXTERN','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',840),
+ ('storage_class_specifier -> TYPEDEF','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',841),
+ ('storage_class_specifier -> _THREAD_LOCAL','storage_class_specifier',1,'p_storage_class_specifier','c_parser.py',842),
+ ('function_specifier -> INLINE','function_specifier',1,'p_function_specifier','c_parser.py',847),
+ ('function_specifier -> _NORETURN','function_specifier',1,'p_function_specifier','c_parser.py',848),
+ ('type_specifier_no_typeid -> VOID','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',853),
+ ('type_specifier_no_typeid -> _BOOL','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',854),
+ ('type_specifier_no_typeid -> CHAR','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',855),
+ ('type_specifier_no_typeid -> SHORT','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',856),
+ ('type_specifier_no_typeid -> INT','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',857),
+ ('type_specifier_no_typeid -> LONG','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',858),
+ ('type_specifier_no_typeid -> FLOAT','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',859),
+ ('type_specifier_no_typeid -> DOUBLE','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',860),
+ ('type_specifier_no_typeid -> _COMPLEX','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',861),
+ ('type_specifier_no_typeid -> SIGNED','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',862),
+ ('type_specifier_no_typeid -> UNSIGNED','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',863),
+ ('type_specifier_no_typeid -> __INT128','type_specifier_no_typeid',1,'p_type_specifier_no_typeid','c_parser.py',864),
+ ('type_specifier -> typedef_name','type_specifier',1,'p_type_specifier','c_parser.py',869),
+ ('type_specifier -> enum_specifier','type_specifier',1,'p_type_specifier','c_parser.py',870),
+ ('type_specifier -> struct_or_union_specifier','type_specifier',1,'p_type_specifier','c_parser.py',871),
+ ('type_specifier -> type_specifier_no_typeid','type_specifier',1,'p_type_specifier','c_parser.py',872),
+ ('type_specifier -> atomic_specifier','type_specifier',1,'p_type_specifier','c_parser.py',873),
+ ('atomic_specifier -> _ATOMIC LPAREN type_name RPAREN','atomic_specifier',4,'p_atomic_specifier','c_parser.py',879),
+ ('type_qualifier -> CONST','type_qualifier',1,'p_type_qualifier','c_parser.py',886),
+ ('type_qualifier -> RESTRICT','type_qualifier',1,'p_type_qualifier','c_parser.py',887),
+ ('type_qualifier -> VOLATILE','type_qualifier',1,'p_type_qualifier','c_parser.py',888),
+ ('type_qualifier -> _ATOMIC','type_qualifier',1,'p_type_qualifier','c_parser.py',889),
+ ('init_declarator_list -> init_declarator','init_declarator_list',1,'p_init_declarator_list','c_parser.py',894),
+ ('init_declarator_list -> init_declarator_list COMMA init_declarator','init_declarator_list',3,'p_init_declarator_list','c_parser.py',895),
+ ('init_declarator -> declarator','init_declarator',1,'p_init_declarator','c_parser.py',903),
+ ('init_declarator -> declarator EQUALS initializer','init_declarator',3,'p_init_declarator','c_parser.py',904),
+ ('id_init_declarator_list -> id_init_declarator','id_init_declarator_list',1,'p_id_init_declarator_list','c_parser.py',909),
+ ('id_init_declarator_list -> id_init_declarator_list COMMA init_declarator','id_init_declarator_list',3,'p_id_init_declarator_list','c_parser.py',910),
+ ('id_init_declarator -> id_declarator','id_init_declarator',1,'p_id_init_declarator','c_parser.py',915),
+ ('id_init_declarator -> id_declarator EQUALS initializer','id_init_declarator',3,'p_id_init_declarator','c_parser.py',916),
+ ('specifier_qualifier_list -> specifier_qualifier_list type_specifier_no_typeid','specifier_qualifier_list',2,'p_specifier_qualifier_list_1','c_parser.py',923),
+ ('specifier_qualifier_list -> specifier_qualifier_list type_qualifier','specifier_qualifier_list',2,'p_specifier_qualifier_list_2','c_parser.py',928),
+ ('specifier_qualifier_list -> type_specifier','specifier_qualifier_list',1,'p_specifier_qualifier_list_3','c_parser.py',933),
+ ('specifier_qualifier_list -> type_qualifier_list type_specifier','specifier_qualifier_list',2,'p_specifier_qualifier_list_4','c_parser.py',938),
+ ('specifier_qualifier_list -> alignment_specifier','specifier_qualifier_list',1,'p_specifier_qualifier_list_5','c_parser.py',943),
+ ('specifier_qualifier_list -> specifier_qualifier_list alignment_specifier','specifier_qualifier_list',2,'p_specifier_qualifier_list_6','c_parser.py',948),
+ ('struct_or_union_specifier -> struct_or_union ID','struct_or_union_specifier',2,'p_struct_or_union_specifier_1','c_parser.py',956),
+ ('struct_or_union_specifier -> struct_or_union TYPEID','struct_or_union_specifier',2,'p_struct_or_union_specifier_1','c_parser.py',957),
+ ('struct_or_union_specifier -> struct_or_union brace_open struct_declaration_list brace_close','struct_or_union_specifier',4,'p_struct_or_union_specifier_2','c_parser.py',967),
+ ('struct_or_union_specifier -> struct_or_union brace_open brace_close','struct_or_union_specifier',3,'p_struct_or_union_specifier_2','c_parser.py',968),
+ ('struct_or_union_specifier -> struct_or_union ID brace_open struct_declaration_list brace_close','struct_or_union_specifier',5,'p_struct_or_union_specifier_3','c_parser.py',985),
+ ('struct_or_union_specifier -> struct_or_union ID brace_open brace_close','struct_or_union_specifier',4,'p_struct_or_union_specifier_3','c_parser.py',986),
+ ('struct_or_union_specifier -> struct_or_union TYPEID brace_open struct_declaration_list brace_close','struct_or_union_specifier',5,'p_struct_or_union_specifier_3','c_parser.py',987),
+ ('struct_or_union_specifier -> struct_or_union TYPEID brace_open brace_close','struct_or_union_specifier',4,'p_struct_or_union_specifier_3','c_parser.py',988),
+ ('struct_or_union -> STRUCT','struct_or_union',1,'p_struct_or_union','c_parser.py',1004),
+ ('struct_or_union -> UNION','struct_or_union',1,'p_struct_or_union','c_parser.py',1005),
+ ('struct_declaration_list -> struct_declaration','struct_declaration_list',1,'p_struct_declaration_list','c_parser.py',1012),
+ ('struct_declaration_list -> struct_declaration_list struct_declaration','struct_declaration_list',2,'p_struct_declaration_list','c_parser.py',1013),
+ ('struct_declaration -> specifier_qualifier_list struct_declarator_list_opt SEMI','struct_declaration',3,'p_struct_declaration_1','c_parser.py',1021),
+ ('struct_declaration -> SEMI','struct_declaration',1,'p_struct_declaration_2','c_parser.py',1059),
+ ('struct_declaration -> pppragma_directive','struct_declaration',1,'p_struct_declaration_3','c_parser.py',1064),
+ ('struct_declarator_list -> struct_declarator','struct_declarator_list',1,'p_struct_declarator_list','c_parser.py',1069),
+ ('struct_declarator_list -> struct_declarator_list COMMA struct_declarator','struct_declarator_list',3,'p_struct_declarator_list','c_parser.py',1070),
+ ('struct_declarator -> declarator','struct_declarator',1,'p_struct_declarator_1','c_parser.py',1078),
+ ('struct_declarator -> declarator COLON constant_expression','struct_declarator',3,'p_struct_declarator_2','c_parser.py',1083),
+ ('struct_declarator -> COLON constant_expression','struct_declarator',2,'p_struct_declarator_2','c_parser.py',1084),
+ ('enum_specifier -> ENUM ID','enum_specifier',2,'p_enum_specifier_1','c_parser.py',1092),
+ ('enum_specifier -> ENUM TYPEID','enum_specifier',2,'p_enum_specifier_1','c_parser.py',1093),
+ ('enum_specifier -> ENUM brace_open enumerator_list brace_close','enum_specifier',4,'p_enum_specifier_2','c_parser.py',1098),
+ ('enum_specifier -> ENUM ID brace_open enumerator_list brace_close','enum_specifier',5,'p_enum_specifier_3','c_parser.py',1103),
+ ('enum_specifier -> ENUM TYPEID brace_open enumerator_list brace_close','enum_specifier',5,'p_enum_specifier_3','c_parser.py',1104),
+ ('enumerator_list -> enumerator','enumerator_list',1,'p_enumerator_list','c_parser.py',1109),
+ ('enumerator_list -> enumerator_list COMMA','enumerator_list',2,'p_enumerator_list','c_parser.py',1110),
+ ('enumerator_list -> enumerator_list COMMA enumerator','enumerator_list',3,'p_enumerator_list','c_parser.py',1111),
+ ('alignment_specifier -> _ALIGNAS LPAREN type_name RPAREN','alignment_specifier',4,'p_alignment_specifier','c_parser.py',1122),
+ ('alignment_specifier -> _ALIGNAS LPAREN constant_expression RPAREN','alignment_specifier',4,'p_alignment_specifier','c_parser.py',1123),
+ ('enumerator -> ID','enumerator',1,'p_enumerator','c_parser.py',1128),
+ ('enumerator -> ID EQUALS constant_expression','enumerator',3,'p_enumerator','c_parser.py',1129),
+ ('declarator -> id_declarator','declarator',1,'p_declarator','c_parser.py',1144),
+ ('declarator -> typeid_declarator','declarator',1,'p_declarator','c_parser.py',1145),
+ ('pointer -> TIMES type_qualifier_list_opt','pointer',2,'p_pointer','c_parser.py',1257),
+ ('pointer -> TIMES type_qualifier_list_opt pointer','pointer',3,'p_pointer','c_parser.py',1258),
+ ('type_qualifier_list -> type_qualifier','type_qualifier_list',1,'p_type_qualifier_list','c_parser.py',1287),
+ ('type_qualifier_list -> type_qualifier_list type_qualifier','type_qualifier_list',2,'p_type_qualifier_list','c_parser.py',1288),
+ ('parameter_type_list -> parameter_list','parameter_type_list',1,'p_parameter_type_list','c_parser.py',1293),
+ ('parameter_type_list -> parameter_list COMMA ELLIPSIS','parameter_type_list',3,'p_parameter_type_list','c_parser.py',1294),
+ ('parameter_list -> parameter_declaration','parameter_list',1,'p_parameter_list','c_parser.py',1302),
+ ('parameter_list -> parameter_list COMMA parameter_declaration','parameter_list',3,'p_parameter_list','c_parser.py',1303),
+ ('parameter_declaration -> declaration_specifiers id_declarator','parameter_declaration',2,'p_parameter_declaration_1','c_parser.py',1322),
+ ('parameter_declaration -> declaration_specifiers typeid_noparen_declarator','parameter_declaration',2,'p_parameter_declaration_1','c_parser.py',1323),
+ ('parameter_declaration -> declaration_specifiers abstract_declarator_opt','parameter_declaration',2,'p_parameter_declaration_2','c_parser.py',1334),
+ ('identifier_list -> identifier','identifier_list',1,'p_identifier_list','c_parser.py',1366),
+ ('identifier_list -> identifier_list COMMA identifier','identifier_list',3,'p_identifier_list','c_parser.py',1367),
+ ('initializer -> assignment_expression','initializer',1,'p_initializer_1','c_parser.py',1376),
+ ('initializer -> brace_open initializer_list_opt brace_close','initializer',3,'p_initializer_2','c_parser.py',1381),
+ ('initializer -> brace_open initializer_list COMMA brace_close','initializer',4,'p_initializer_2','c_parser.py',1382),
+ ('initializer_list -> designation_opt initializer','initializer_list',2,'p_initializer_list','c_parser.py',1390),
+ ('initializer_list -> initializer_list COMMA designation_opt initializer','initializer_list',4,'p_initializer_list','c_parser.py',1391),
+ ('designation -> designator_list EQUALS','designation',2,'p_designation','c_parser.py',1402),
+ ('designator_list -> designator','designator_list',1,'p_designator_list','c_parser.py',1410),
+ ('designator_list -> designator_list designator','designator_list',2,'p_designator_list','c_parser.py',1411),
+ ('designator -> LBRACKET constant_expression RBRACKET','designator',3,'p_designator','c_parser.py',1416),
+ ('designator -> PERIOD identifier','designator',2,'p_designator','c_parser.py',1417),
+ ('type_name -> specifier_qualifier_list abstract_declarator_opt','type_name',2,'p_type_name','c_parser.py',1422),
+ ('abstract_declarator -> pointer','abstract_declarator',1,'p_abstract_declarator_1','c_parser.py',1434),
+ ('abstract_declarator -> pointer direct_abstract_declarator','abstract_declarator',2,'p_abstract_declarator_2','c_parser.py',1442),
+ ('abstract_declarator -> direct_abstract_declarator','abstract_declarator',1,'p_abstract_declarator_3','c_parser.py',1447),
+ ('direct_abstract_declarator -> LPAREN abstract_declarator RPAREN','direct_abstract_declarator',3,'p_direct_abstract_declarator_1','c_parser.py',1457),
+ ('direct_abstract_declarator -> direct_abstract_declarator LBRACKET assignment_expression_opt RBRACKET','direct_abstract_declarator',4,'p_direct_abstract_declarator_2','c_parser.py',1461),
+ ('direct_abstract_declarator -> LBRACKET type_qualifier_list_opt assignment_expression_opt RBRACKET','direct_abstract_declarator',4,'p_direct_abstract_declarator_3','c_parser.py',1472),
+ ('direct_abstract_declarator -> direct_abstract_declarator LBRACKET TIMES RBRACKET','direct_abstract_declarator',4,'p_direct_abstract_declarator_4','c_parser.py',1482),
+ ('direct_abstract_declarator -> LBRACKET TIMES RBRACKET','direct_abstract_declarator',3,'p_direct_abstract_declarator_5','c_parser.py',1493),
+ ('direct_abstract_declarator -> direct_abstract_declarator LPAREN parameter_type_list_opt RPAREN','direct_abstract_declarator',4,'p_direct_abstract_declarator_6','c_parser.py',1502),
+ ('direct_abstract_declarator -> LPAREN parameter_type_list_opt RPAREN','direct_abstract_declarator',3,'p_direct_abstract_declarator_7','c_parser.py',1512),
+ ('block_item -> declaration','block_item',1,'p_block_item','c_parser.py',1523),
+ ('block_item -> statement','block_item',1,'p_block_item','c_parser.py',1524),
+ ('block_item_list -> block_item','block_item_list',1,'p_block_item_list','c_parser.py',1531),
+ ('block_item_list -> block_item_list block_item','block_item_list',2,'p_block_item_list','c_parser.py',1532),
+ ('compound_statement -> brace_open block_item_list_opt brace_close','compound_statement',3,'p_compound_statement_1','c_parser.py',1538),
+ ('labeled_statement -> ID COLON pragmacomp_or_statement','labeled_statement',3,'p_labeled_statement_1','c_parser.py',1544),
+ ('labeled_statement -> CASE constant_expression COLON pragmacomp_or_statement','labeled_statement',4,'p_labeled_statement_2','c_parser.py',1548),
+ ('labeled_statement -> DEFAULT COLON pragmacomp_or_statement','labeled_statement',3,'p_labeled_statement_3','c_parser.py',1552),
+ ('selection_statement -> IF LPAREN expression RPAREN pragmacomp_or_statement','selection_statement',5,'p_selection_statement_1','c_parser.py',1556),
+ ('selection_statement -> IF LPAREN expression RPAREN statement ELSE pragmacomp_or_statement','selection_statement',7,'p_selection_statement_2','c_parser.py',1560),
+ ('selection_statement -> SWITCH LPAREN expression RPAREN pragmacomp_or_statement','selection_statement',5,'p_selection_statement_3','c_parser.py',1564),
+ ('iteration_statement -> WHILE LPAREN expression RPAREN pragmacomp_or_statement','iteration_statement',5,'p_iteration_statement_1','c_parser.py',1569),
+ ('iteration_statement -> DO pragmacomp_or_statement WHILE LPAREN expression RPAREN SEMI','iteration_statement',7,'p_iteration_statement_2','c_parser.py',1573),
+ ('iteration_statement -> FOR LPAREN expression_opt SEMI expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement','iteration_statement',9,'p_iteration_statement_3','c_parser.py',1577),
+ ('iteration_statement -> FOR LPAREN declaration expression_opt SEMI expression_opt RPAREN pragmacomp_or_statement','iteration_statement',8,'p_iteration_statement_4','c_parser.py',1581),
+ ('jump_statement -> GOTO ID SEMI','jump_statement',3,'p_jump_statement_1','c_parser.py',1586),
+ ('jump_statement -> BREAK SEMI','jump_statement',2,'p_jump_statement_2','c_parser.py',1590),
+ ('jump_statement -> CONTINUE SEMI','jump_statement',2,'p_jump_statement_3','c_parser.py',1594),
+ ('jump_statement -> RETURN expression SEMI','jump_statement',3,'p_jump_statement_4','c_parser.py',1598),
+ ('jump_statement -> RETURN SEMI','jump_statement',2,'p_jump_statement_4','c_parser.py',1599),
+ ('expression_statement -> expression_opt SEMI','expression_statement',2,'p_expression_statement','c_parser.py',1604),
+ ('expression -> assignment_expression','expression',1,'p_expression','c_parser.py',1611),
+ ('expression -> expression COMMA assignment_expression','expression',3,'p_expression','c_parser.py',1612),
+ ('assignment_expression -> LPAREN compound_statement RPAREN','assignment_expression',3,'p_parenthesized_compound_expression','c_parser.py',1624),
+ ('typedef_name -> TYPEID','typedef_name',1,'p_typedef_name','c_parser.py',1628),
+ ('assignment_expression -> conditional_expression','assignment_expression',1,'p_assignment_expression','c_parser.py',1632),
+ ('assignment_expression -> unary_expression assignment_operator assignment_expression','assignment_expression',3,'p_assignment_expression','c_parser.py',1633),
+ ('assignment_operator -> EQUALS','assignment_operator',1,'p_assignment_operator','c_parser.py',1646),
+ ('assignment_operator -> XOREQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1647),
+ ('assignment_operator -> TIMESEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1648),
+ ('assignment_operator -> DIVEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1649),
+ ('assignment_operator -> MODEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1650),
+ ('assignment_operator -> PLUSEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1651),
+ ('assignment_operator -> MINUSEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1652),
+ ('assignment_operator -> LSHIFTEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1653),
+ ('assignment_operator -> RSHIFTEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1654),
+ ('assignment_operator -> ANDEQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1655),
+ ('assignment_operator -> OREQUAL','assignment_operator',1,'p_assignment_operator','c_parser.py',1656),
+ ('constant_expression -> conditional_expression','constant_expression',1,'p_constant_expression','c_parser.py',1661),
+ ('conditional_expression -> binary_expression','conditional_expression',1,'p_conditional_expression','c_parser.py',1665),
+ ('conditional_expression -> binary_expression CONDOP expression COLON conditional_expression','conditional_expression',5,'p_conditional_expression','c_parser.py',1666),
+ ('binary_expression -> cast_expression','binary_expression',1,'p_binary_expression','c_parser.py',1674),
+ ('binary_expression -> binary_expression TIMES binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1675),
+ ('binary_expression -> binary_expression DIVIDE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1676),
+ ('binary_expression -> binary_expression MOD binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1677),
+ ('binary_expression -> binary_expression PLUS binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1678),
+ ('binary_expression -> binary_expression MINUS binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1679),
+ ('binary_expression -> binary_expression RSHIFT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1680),
+ ('binary_expression -> binary_expression LSHIFT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1681),
+ ('binary_expression -> binary_expression LT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1682),
+ ('binary_expression -> binary_expression LE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1683),
+ ('binary_expression -> binary_expression GE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1684),
+ ('binary_expression -> binary_expression GT binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1685),
+ ('binary_expression -> binary_expression EQ binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1686),
+ ('binary_expression -> binary_expression NE binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1687),
+ ('binary_expression -> binary_expression AND binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1688),
+ ('binary_expression -> binary_expression OR binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1689),
+ ('binary_expression -> binary_expression XOR binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1690),
+ ('binary_expression -> binary_expression LAND binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1691),
+ ('binary_expression -> binary_expression LOR binary_expression','binary_expression',3,'p_binary_expression','c_parser.py',1692),
+ ('cast_expression -> unary_expression','cast_expression',1,'p_cast_expression_1','c_parser.py',1700),
+ ('cast_expression -> LPAREN type_name RPAREN cast_expression','cast_expression',4,'p_cast_expression_2','c_parser.py',1704),
+ ('unary_expression -> postfix_expression','unary_expression',1,'p_unary_expression_1','c_parser.py',1708),
+ ('unary_expression -> PLUSPLUS unary_expression','unary_expression',2,'p_unary_expression_2','c_parser.py',1712),
+ ('unary_expression -> MINUSMINUS unary_expression','unary_expression',2,'p_unary_expression_2','c_parser.py',1713),
+ ('unary_expression -> unary_operator cast_expression','unary_expression',2,'p_unary_expression_2','c_parser.py',1714),
+ ('unary_expression -> SIZEOF unary_expression','unary_expression',2,'p_unary_expression_3','c_parser.py',1719),
+ ('unary_expression -> SIZEOF LPAREN type_name RPAREN','unary_expression',4,'p_unary_expression_3','c_parser.py',1720),
+ ('unary_expression -> _ALIGNOF LPAREN type_name RPAREN','unary_expression',4,'p_unary_expression_3','c_parser.py',1721),
+ ('unary_operator -> AND','unary_operator',1,'p_unary_operator','c_parser.py',1729),
+ ('unary_operator -> TIMES','unary_operator',1,'p_unary_operator','c_parser.py',1730),
+ ('unary_operator -> PLUS','unary_operator',1,'p_unary_operator','c_parser.py',1731),
+ ('unary_operator -> MINUS','unary_operator',1,'p_unary_operator','c_parser.py',1732),
+ ('unary_operator -> NOT','unary_operator',1,'p_unary_operator','c_parser.py',1733),
+ ('unary_operator -> LNOT','unary_operator',1,'p_unary_operator','c_parser.py',1734),
+ ('postfix_expression -> primary_expression','postfix_expression',1,'p_postfix_expression_1','c_parser.py',1739),
+ ('postfix_expression -> postfix_expression LBRACKET expression RBRACKET','postfix_expression',4,'p_postfix_expression_2','c_parser.py',1743),
+ ('postfix_expression -> postfix_expression LPAREN argument_expression_list RPAREN','postfix_expression',4,'p_postfix_expression_3','c_parser.py',1747),
+ ('postfix_expression -> postfix_expression LPAREN RPAREN','postfix_expression',3,'p_postfix_expression_3','c_parser.py',1748),
+ ('postfix_expression -> postfix_expression PERIOD ID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1753),
+ ('postfix_expression -> postfix_expression PERIOD TYPEID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1754),
+ ('postfix_expression -> postfix_expression ARROW ID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1755),
+ ('postfix_expression -> postfix_expression ARROW TYPEID','postfix_expression',3,'p_postfix_expression_4','c_parser.py',1756),
+ ('postfix_expression -> postfix_expression PLUSPLUS','postfix_expression',2,'p_postfix_expression_5','c_parser.py',1762),
+ ('postfix_expression -> postfix_expression MINUSMINUS','postfix_expression',2,'p_postfix_expression_5','c_parser.py',1763),
+ ('postfix_expression -> LPAREN type_name RPAREN brace_open initializer_list brace_close','postfix_expression',6,'p_postfix_expression_6','c_parser.py',1768),
+ ('postfix_expression -> LPAREN type_name RPAREN brace_open initializer_list COMMA brace_close','postfix_expression',7,'p_postfix_expression_6','c_parser.py',1769),
+ ('primary_expression -> identifier','primary_expression',1,'p_primary_expression_1','c_parser.py',1774),
+ ('primary_expression -> constant','primary_expression',1,'p_primary_expression_2','c_parser.py',1778),
+ ('primary_expression -> unified_string_literal','primary_expression',1,'p_primary_expression_3','c_parser.py',1782),
+ ('primary_expression -> unified_wstring_literal','primary_expression',1,'p_primary_expression_3','c_parser.py',1783),
+ ('primary_expression -> LPAREN expression RPAREN','primary_expression',3,'p_primary_expression_4','c_parser.py',1788),
+ ('primary_expression -> OFFSETOF LPAREN type_name COMMA offsetof_member_designator RPAREN','primary_expression',6,'p_primary_expression_5','c_parser.py',1792),
+ ('offsetof_member_designator -> identifier','offsetof_member_designator',1,'p_offsetof_member_designator','c_parser.py',1800),
+ ('offsetof_member_designator -> offsetof_member_designator PERIOD identifier','offsetof_member_designator',3,'p_offsetof_member_designator','c_parser.py',1801),
+ ('offsetof_member_designator -> offsetof_member_designator LBRACKET expression RBRACKET','offsetof_member_designator',4,'p_offsetof_member_designator','c_parser.py',1802),
+ ('argument_expression_list -> assignment_expression','argument_expression_list',1,'p_argument_expression_list','c_parser.py',1814),
+ ('argument_expression_list -> argument_expression_list COMMA assignment_expression','argument_expression_list',3,'p_argument_expression_list','c_parser.py',1815),
+ ('identifier -> ID','identifier',1,'p_identifier','c_parser.py',1824),
+ ('constant -> INT_CONST_DEC','constant',1,'p_constant_1','c_parser.py',1828),
+ ('constant -> INT_CONST_OCT','constant',1,'p_constant_1','c_parser.py',1829),
+ ('constant -> INT_CONST_HEX','constant',1,'p_constant_1','c_parser.py',1830),
+ ('constant -> INT_CONST_BIN','constant',1,'p_constant_1','c_parser.py',1831),
+ ('constant -> INT_CONST_CHAR','constant',1,'p_constant_1','c_parser.py',1832),
+ ('constant -> FLOAT_CONST','constant',1,'p_constant_2','c_parser.py',1851),
+ ('constant -> HEX_FLOAT_CONST','constant',1,'p_constant_2','c_parser.py',1852),
+ ('constant -> CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1868),
+ ('constant -> WCHAR_CONST','constant',1,'p_constant_3','c_parser.py',1869),
+ ('constant -> U8CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1870),
+ ('constant -> U16CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1871),
+ ('constant -> U32CHAR_CONST','constant',1,'p_constant_3','c_parser.py',1872),
+ ('unified_string_literal -> STRING_LITERAL','unified_string_literal',1,'p_unified_string_literal','c_parser.py',1883),
+ ('unified_string_literal -> unified_string_literal STRING_LITERAL','unified_string_literal',2,'p_unified_string_literal','c_parser.py',1884),
+ ('unified_wstring_literal -> WSTRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1894),
+ ('unified_wstring_literal -> U8STRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1895),
+ ('unified_wstring_literal -> U16STRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1896),
+ ('unified_wstring_literal -> U32STRING_LITERAL','unified_wstring_literal',1,'p_unified_wstring_literal','c_parser.py',1897),
+ ('unified_wstring_literal -> unified_wstring_literal WSTRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1898),
+ ('unified_wstring_literal -> unified_wstring_literal U8STRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1899),
+ ('unified_wstring_literal -> unified_wstring_literal U16STRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1900),
+ ('unified_wstring_literal -> unified_wstring_literal U32STRING_LITERAL','unified_wstring_literal',2,'p_unified_wstring_literal','c_parser.py',1901),
+ ('brace_open -> LBRACE','brace_open',1,'p_brace_open','c_parser.py',1911),
+ ('brace_close -> RBRACE','brace_close',1,'p_brace_close','c_parser.py',1917),
+ ('empty -> <empty>','empty',0,'p_empty','c_parser.py',1923),
+]
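
Editorial note: the production list that the hunk above adds is the table PLY generates for pycparser, one tuple per grammar rule, recording the rule text, its left-hand nonterminal, the rule length, and the ``p_*`` handler (with source file and line) that implements it. A minimal, self-contained sketch of how PLY derives such tuples from handler docstrings is shown below; the token and rule names are illustrative only and are not taken from pycparser's real C grammar::

    # Sketch only (not part of the patch): PLY reads the docstring of each
    # p_* function as a grammar production, which is what ends up in table
    # entries like ('expr -> expr PLUS term', 'expr', 3, 'p_expr_plus', ...).
    import ply.lex as lex
    import ply.yacc as yacc

    tokens = ('NUMBER', 'PLUS')
    t_PLUS = r'\+'
    t_NUMBER = r'\d+'
    t_ignore = ' '

    def t_error(t):
        t.lexer.skip(1)

    def p_expr_plus(p):
        'expr : expr PLUS term'      # docstring = the production stored in the table
        p[0] = p[1] + p[3]

    def p_expr_term(p):
        'expr : term'
        p[0] = p[1]

    def p_term_number(p):
        'term : NUMBER'
        p[0] = int(p[1])

    def p_error(p):
        pass

    lexer = lex.lex()
    # write_tables=False keeps PLY from emitting a parsetab module like the one above
    parser = yacc.yacc(write_tables=False, debug=False)
    print(parser.parse('1 + 2'))     # -> 3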
diff --git a/lib/pycryptodome-3.15.0.dist-info/AUTHORS.rst b/lib/pycryptodome-3.15.0.dist-info/AUTHORS.rst
new file mode 100644
index 0000000..79adf3c
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/AUTHORS.rst
@@ -0,0 +1,50 @@
+Simon Arneaud
+Nevins Bartolomeo
+Thorsten E. Behrens
+Tim Berners-Lee
+Frédéric Bertolus
+Ian Bicking
+Joris Bontje
+Antoon Bosselaers
+Andrea Bottoni
+Jean-Paul Calderone
+Sergey Chernov
+Geremy Condra
+Jan Dittberner
+Andrew Eland
+Philippe Frycia
+Peter Gutmann
+Hirendra Hindocha
+Nikhil Jhingan
+Sebastian Kayser
+Ryan Kelly
+Andrew M. Kuchling
+Piers Lauder
+Legrandin
+M.-A. Lemburg
+Wim Lewis
+Darsey C. Litzenberger
+Richard Mitchell
+Mark Moraes
+Lim Chee Siang
+Bryan Olson
+Wallace Owen
+Colin Plumb
+Robey Pointer
+Lorenz Quack
+Sebastian Ramacher
+Jeethu Rao
+James P. Rutledge
+Matt Schreiner
+Peter Simmons
+Janne Snabb
+Tom St. Denis
+Anders Sundman
+Paul Swartz
+Fabrizio Tarizzo
+Kevin M. Turner
+Barry A. Warsaw
+Eric Young
+Hannes van Niekerk
+Stefan Seering
+Koki Takahashi
diff --git a/lib/pycryptodome-3.15.0.dist-info/INSTALLER b/lib/pycryptodome-3.15.0.dist-info/INSTALLER
new file mode 100644
index 0000000..a1b589e
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/lib/pycryptodome-3.15.0.dist-info/LICENSE.rst b/lib/pycryptodome-3.15.0.dist-info/LICENSE.rst
new file mode 100644
index 0000000..3008ff7
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/LICENSE.rst
@@ -0,0 +1,61 @@
+The source code in PyCryptodome is partially in the public domain
+and partially released under the BSD 2-Clause license.
+
+In either case, there are minimal if no restrictions on the redistribution,
+modification and usage of the software.
+
+Public domain
+=============
+
+All code originating from PyCrypto is free and unencumbered software
+released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or
+distribute this software, either in source code form or as a compiled
+binary, for any purpose, commercial or non-commercial, and by any
+means.
+
+In jurisdictions that recognize copyright laws, the author or authors
+of this software dedicate any and all copyright interest in the
+software to the public domain. We make this dedication for the benefit
+of the public at large and to the detriment of our heirs and
+successors. We intend this dedication to be an overt act of
+relinquishment in perpetuity of all present and future rights to this
+software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to <http://unlicense.org>
+
+BSD license
+===========
+
+All direct contributions to PyCryptodome are released under the following
+license. The copyright of each piece belongs to the respective author.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/lib/pycryptodome-3.15.0.dist-info/METADATA b/lib/pycryptodome-3.15.0.dist-info/METADATA
new file mode 100644
index 0000000..6c8eba6
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/METADATA
@@ -0,0 +1,84 @@
+Metadata-Version: 2.1
+Name: pycryptodome
+Version: 3.15.0
+Summary: Cryptographic library for Python
+Home-page: https://www.pycryptodome.org
+Author: Helder Eijs
+Author-email: helderijs@gmail.com
+License: BSD, Public Domain
+Project-URL: Source, https://github.com/Legrandin/pycryptodome/
+Platform: Posix; MacOS X; Windows
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: License :: OSI Approved :: BSD License
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: License :: Public Domain
+Classifier: Intended Audience :: Developers
+Classifier: Operating System :: Unix
+Classifier: Operating System :: Microsoft :: Windows
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Topic :: Security :: Cryptography
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.5
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Requires-Python: >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*
+
+
+PyCryptodome
+============
+
+PyCryptodome is a self-contained Python package of low-level
+cryptographic primitives.
+
+It supports Python 2.7, Python 3.5 and newer, and PyPy.
+
+You can install it with::
+
+ pip install pycryptodome
+
+All modules are installed under the ``Crypto`` package.
+
+Check the pycryptodomex_ project for the equivalent library that
+works under the ``Cryptodome`` package.
+
+PyCryptodome is a fork of PyCrypto. It brings several enhancements
+with respect to the last official version of PyCrypto (2.6.1),
+for instance:
+
+* Authenticated encryption modes (GCM, CCM, EAX, SIV, OCB)
+* Accelerated AES on Intel platforms via AES-NI
+* First class support for PyPy
+* Elliptic curves cryptography (NIST P-curves; Ed25519, Ed448)
+* Better and more compact API (`nonce` and `iv` attributes for ciphers,
+ automatic generation of random nonces and IVs, simplified CTR cipher mode,
+ and more)
+* SHA-3 (including SHAKE XOFs) and BLAKE2 hash algorithms
+* Salsa20 and ChaCha20 stream ciphers
+* scrypt and HKDF
+* Deterministic (EC)DSA and EdDSA
+* Password-protected PKCS#8 key containers
+* Shamir's Secret Sharing scheme
+* Random numbers get sourced directly from the OS (and not from a CSPRNG in userspace)
+* Simplified install process, including better support for Windows
+* Cleaner RSA and DSA key generation (largely based on FIPS 186-4)
+* Major clean ups and simplification of the code base
+
+PyCryptodome is not a wrapper to a separate C library like *OpenSSL*.
+To the largest possible extent, algorithms are implemented in pure Python.
+Only the pieces that are extremely critical to performance (e.g. block ciphers)
+are implemented as C extensions.
+
+For more information, see the `homepage`_.
+
+All the code can be downloaded from `GitHub`_.
+
+.. _pycryptodomex: https://pypi.python.org/pypi/pycryptodomex
+.. _`homepage`: http://www.pycryptodome.org
+.. _GitHub: https://github.com/Legrandin/pycryptodome
+
+
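Editorial note: the METADATA file added above explains that the package installs under the ``Crypto`` namespace and highlights the authenticated cipher modes (GCM, CCM, EAX, SIV, OCB) and automatic nonce handling. A minimal usage sketch (not part of the patch) for AES-GCM follows, assuming the vendored copy under ``lib/`` behaves like a standard pycryptodome install::

    from Crypto.Cipher import AES
    from Crypto.Random import get_random_bytes

    key = get_random_bytes(16)                  # AES-128 key
    cipher = AES.new(key, AES.MODE_GCM)         # nonce is generated automatically
    ciphertext, tag = cipher.encrypt_and_digest(b'attack at dawn')

    # Decrypt and authenticate with the same key and the nonce chosen above
    decipher = AES.new(key, AES.MODE_GCM, nonce=cipher.nonce)
    plaintext = decipher.decrypt_and_verify(ciphertext, tag)
    assert plaintext == b'attack at dawn'
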
diff --git a/lib/pycryptodome-3.15.0.dist-info/RECORD b/lib/pycryptodome-3.15.0.dist-info/RECORD
new file mode 100644
index 0000000..3920f9c
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/RECORD
@@ -0,0 +1,513 @@
+Crypto/Cipher/AES.py,sha256=xTYZap5JY4PMfurJr9--yZhbvYogFSY-7XIzOiDTg-c,9525
+Crypto/Cipher/AES.pyi,sha256=KFRI5Bc7OMN7APFJ48wrcH0sQISsFfiY91HtLNpJqgA,1343
+Crypto/Cipher/ARC2.py,sha256=L2-Nj1XT8zBjZUTykGq8N0QONXfuiNLJNQJZzSztI_M,7010
+Crypto/Cipher/ARC2.pyi,sha256=FrGIcMtmPnJ08OoLnYhCJCYx0RxYvSgHimNKeTkn0qQ,954
+Crypto/Cipher/ARC4.py,sha256=MjMVoqqU232XJsU_sZlPrCAjUMpGGAT3fDS8PpC-XZ8,5152
+Crypto/Cipher/ARC4.pyi,sha256=sMw73yZHeonmGx9BhiyA7__4PQJocU04SMRcDjnyJ2Y,431
+Crypto/Cipher/Blowfish.py,sha256=UbpLnRlNWaclr9cjZuvgJoPXrTI9fYJThqG2X5NIFH4,5964
+Crypto/Cipher/Blowfish.pyi,sha256=C7lgc7tn4IqV_7jwP-QGUTTgGsik6ZPY4cBH8q7LSII,990
+Crypto/Cipher/CAST.py,sha256=DIoa9CsmGqI-h1h4U3tBF3PzYFVdWbfHqTi_TeWEVck,6071
+Crypto/Cipher/CAST.pyi,sha256=S_Z4-goG9AWDj6QwE3pjPLJKpHbQ2ssq98vIAWTOqVU,955
+Crypto/Cipher/ChaCha20.py,sha256=OyovzJUl-VhzeXOsDc4ZP3lH7AaRT7agdDXDRVXK0zo,10760
+Crypto/Cipher/ChaCha20.pyi,sha256=_l1xhtOyBmYEHP7Ftmk8EQZpKegX9p3N5tckC_PPve0,762
+Crypto/Cipher/ChaCha20_Poly1305.py,sha256=kxGCaUw_3Q-OpyZ-zuEwmf_nvj5uz6Oaed3NiFMMyM4,11529
+Crypto/Cipher/ChaCha20_Poly1305.pyi,sha256=h1U5ixODzM9NwLpX9oaIJdeQ0ubYeDeY9m6ur05dKCc,1068
+Crypto/Cipher/DES.py,sha256=IAX3Aeuy5jPNMJQ1nTMTwndzdYHZoXhprogxYwBDhu8,5947
+Crypto/Cipher/DES.pyi,sha256=7M1zxQlI3aeyymRzJIRayALnv-m9WWgzQ4bgtmuGmFY,935
+Crypto/Cipher/DES3.py,sha256=HEGzkcXrlpVg6ouCjRSYpLr7gsNTQnOCnwaxveotHqA,6925
+Crypto/Cipher/DES3.pyi,sha256=KLzWGTtEC15IRMroCk4dwNNRJnr99if-U30nWXQEQnw,1005
+Crypto/Cipher/PKCS1_OAEP.py,sha256=Wq8QJLp8EW5owNdAm6uLkJDGtFLRNDz_PVnzaMJUT2g,8827
+Crypto/Cipher/PKCS1_OAEP.pyi,sha256=J2ZwIrEKfFQ4Geilq0v5aGtbO74lIh_dqNUECPpN8So,1179
+Crypto/Cipher/PKCS1_v1_5.py,sha256=LXW0FRnU82b-ge4GkgApNg0-BRG3Cf-ABjf6WCnKsIo,8141
+Crypto/Cipher/PKCS1_v1_5.pyi,sha256=KTIQpRCb5mkRLvI9EdPEUeBxLozRS-TOmdeDaJP8Ad0,686
+Crypto/Cipher/Salsa20.py,sha256=sTjms0rve3YTSEql53KmTeAKRvzvPwQ-CVMem06JMzs,6349
+Crypto/Cipher/Salsa20.pyi,sha256=4vjq_HN8NK7U9VdaaHIgs17-fyW8SRPDZaHy3jKVkto,744
+Crypto/Cipher/_ARC4.abi3.so,sha256=9tij7SXY2szh3jxEUesoveaWK69ZzdpHyuXBjoNOAiI,13768
+Crypto/Cipher/_EKSBlowfish.py,sha256=lI0pa8XOG3GTRfXTQtpW8BzTs6AHw4YSXSw6A2zKvJ8,5205
+Crypto/Cipher/_EKSBlowfish.pyi,sha256=L-GHYtoL61P9rdPcZAKmFJZxNMU9N4sfXtQNM0u8nIc,266
+Crypto/Cipher/_Salsa20.abi3.so,sha256=rO__EPhnlaYPjmqcDsbwhdlPAE4bsUHEV8Ya7RniIzk,26784
+Crypto/Cipher/__init__.py,sha256=j2fT-hweEnhbklqWv85cWwPinlM30uepKarrpZSUTSc,2844
+Crypto/Cipher/__init__.pyi,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+Crypto/Cipher/__pycache__/AES.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/ARC2.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/ARC4.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/Blowfish.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/CAST.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/ChaCha20.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/ChaCha20_Poly1305.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/DES.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/DES3.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/PKCS1_OAEP.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/PKCS1_v1_5.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/Salsa20.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_EKSBlowfish.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_cbc.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_ccm.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_cfb.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_ctr.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_eax.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_ecb.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_gcm.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_ocb.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_ofb.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_openpgp.cpython-39.pyc,,
+Crypto/Cipher/__pycache__/_mode_siv.cpython-39.pyc,,
+Crypto/Cipher/_chacha20.abi3.so,sha256=YyOToJ7m_aReh5TmGbs1yyvkzzgpxKbe0g-PI91Oepg,28224
+Crypto/Cipher/_mode_cbc.py,sha256=H7jeSQnH8dbmPqkAekW9I2XoRW1WbTIm71VUVOQp2tQ,10951
+Crypto/Cipher/_mode_cbc.pyi,sha256=8T9G8dP1R7oUDlsNEdX3ywl6aZ0W_3v87iL8HjqeeVs,687
+Crypto/Cipher/_mode_ccm.py,sha256=ppEHk8CI5eDyB_FQePcdyXuEAeMIdBkN5lfnKf-K7zY,24444
+Crypto/Cipher/_mode_ccm.pyi,sha256=ZSs4SOlivIG_JUxknDjQKs4ZYGmkwAO7K2DKcbz_14M,1600
+Crypto/Cipher/_mode_cfb.py,sha256=wv4H-0_8eBHE6lOuLcvl97q4bfybqzhlnjycduqHRrU,10801
+Crypto/Cipher/_mode_cfb.pyi,sha256=lQ2LvikXb0CrHqR72-j7Q8ygbMyJhb0OwqnhVEmkyR8,727
+Crypto/Cipher/_mode_ctr.py,sha256=UZnAZHw_hdhY9TEkmUMglXMx71NkQR9NLukt_G-Kvlc,15830
+Crypto/Cipher/_mode_ctr.pyi,sha256=JEZHTO88y3L1uDLyySl8ARPSvoaalB6_D2YQ78tD_k4,800
+Crypto/Cipher/_mode_eax.py,sha256=vsQKIaoyny8GHHVP7rPt4GMucgZTEr03r13s2JK7s9E,14511
+Crypto/Cipher/_mode_eax.pyi,sha256=VHPtTdA-2btCvRE-4npRtGCrApg7rBNWpHSZV1po8J0,1545
+Crypto/Cipher/_mode_ecb.py,sha256=xy3XV-5T5JWQFT80P0z2W7DOzvwUgAot6jcV4nKintw,8309
+Crypto/Cipher/_mode_ecb.pyi,sha256=xe_OlSwhFAJwcI5JizfR_zkBtm983n5UV6Zg2JGcKcA,592
+Crypto/Cipher/_mode_gcm.py,sha256=oZwOEUTf4N9fVaHzjI2KWXXoDoezblYG_w3Fv0S8Vzk,21358
+Crypto/Cipher/_mode_gcm.pyi,sha256=5t72QHQS0gDq6wtzYfaVqTxmjBzpUvsQvDaP2DqNvLE,1541
+Crypto/Cipher/_mode_ocb.py,sha256=BpJNunBkTpSLQccSTsslxJlrjxmosqk-gxepaDOFpAM,19794
+Crypto/Cipher/_mode_ocb.pyi,sha256=SXMUa1s1dY-272lktxSOtyOoqLdtPvfNkRXqmXjBE4o,1231
+Crypto/Cipher/_mode_ofb.py,sha256=JjLNn-8eDN6js8am8HBmCUrNj3hgUIrJync-y_-Xl8I,10281
+Crypto/Cipher/_mode_ofb.pyi,sha256=T-SVUS0N52GpvFu1tAsWNX1mNMXIq4N1vo0wQV2uV8I,691
+Crypto/Cipher/_mode_openpgp.py,sha256=KFmMsSXSQHa_i8iaK06DJWjAZ4C954EB-91kFfqZ_uw,7053
+Crypto/Cipher/_mode_openpgp.pyi,sha256=FoLrFqnvxJf0F_npHOgPURfUyGSt6DxyIp2ikoXi-CI,556
+Crypto/Cipher/_mode_siv.py,sha256=VXEUCbGEpM92GGDhiTdS2ntPmRSZGOPKFyHLwNNWqes,14062
+Crypto/Cipher/_mode_siv.pyi,sha256=syb3kXnyuhoQV6FXvozIjudWCQBCadOb1I2BuV-6Ai0,1261
+Crypto/Cipher/_pkcs1_decode.abi3.so,sha256=Sj2GwQlSOqApxtTnOnLzmkgEnevQMi0C0SQ6X4J87LI,28096
+Crypto/Cipher/_raw_aes.abi3.so,sha256=-0h1F5j7S2o5Fse9E8vJsEJdyssYyXbOw7xFqUkOFXE,66256
+Crypto/Cipher/_raw_aesni.abi3.so,sha256=ETlT13NLNrK1rJqzWd7OSiqxc5l6kleKwHr-5Jz0kI8,101136
+Crypto/Cipher/_raw_arc2.abi3.so,sha256=2bCrwjUZRpxPBqXkGaJGoG1EJeRK8mfifG7Fu4M-zjo,43776
+Crypto/Cipher/_raw_blowfish.abi3.so,sha256=Oz2tFZGrXe6v7fzoglsjPDQcaGWt6Fqmda3TESGvz-k,69976
+Crypto/Cipher/_raw_cast.abi3.so,sha256=2R0xZQKE4JBMpUSGMGzI3YPiahnAV9jWBVsmr7zW7gI,42976
+Crypto/Cipher/_raw_cbc.abi3.so,sha256=m7nSugq5uISQgK7YfE3RWcfOKnOXgzkgZACCfvWg5Ro,20736
+Crypto/Cipher/_raw_cfb.abi3.so,sha256=BEWBBHHHE76yWWvvXj3CrXPnEYRyZbMOX-IvtIVGTWQ,25440
+Crypto/Cipher/_raw_ctr.abi3.so,sha256=rgXToNC7BZQVA5b32sLXnkfC87kiWeaN-_bUtxtN-8s,28600
+Crypto/Cipher/_raw_des.abi3.so,sha256=zJa6DcUSLh9r8-LnFZEqnkC2YiXjaLKuYslFJo7KWQI,75672
+Crypto/Cipher/_raw_des3.abi3.so,sha256=89CvgYf8OmRRRcLonn3X2SEZmdBThuld_6vAvnBWYFE,76480
+Crypto/Cipher/_raw_ecb.abi3.so,sha256=qDblihBcqT8qVF4o85cqM8WFiYq8vZyld4mWT2Eps9E,12440
+Crypto/Cipher/_raw_eksblowfish.abi3.so,sha256=a6zP_4DvjNpkk_tsa7LqqHN5t0HSlhC5jJ6O16SGanA,166264
+Crypto/Cipher/_raw_ocb.abi3.so,sha256=nwcAK1LXd-owImaz1ktOozZ0QCg2pniWEKvkOvN_M9I,37344
+Crypto/Cipher/_raw_ofb.abi3.so,sha256=1QHGLlfLcqQ-YYQUtcHqQZtJu2KR_cU3vOsY9H1F_v8,15368
+Crypto/Hash/BLAKE2b.py,sha256=PHzRVThR_iAlAWz3Upe-nZTOlbsbVt39yZcDFFWhqFQ,9424
+Crypto/Hash/BLAKE2b.pyi,sha256=U4K3mapdYeHVHrlIEgffKV9IfALVbqkOrVbJRujns10,906
+Crypto/Hash/BLAKE2s.py,sha256=qHUh6FVqW-A_JLxramDoeJXYJZg3kQGzwi_zinnn3-8,9430
+Crypto/Hash/BLAKE2s.pyi,sha256=9jsL4jLQq5_Mb8WM99LPurH1D-FL-gLAeZyBf8QiWt0,739
+Crypto/Hash/CMAC.py,sha256=KQzO2VD4n-2QPqIB4yVDJOAgajP5bVuNqGFSS56aavM,10351
+Crypto/Hash/CMAC.pyi,sha256=kZXAeUzxQ38nY-aYbIPrZZmROxgja2HnvUz7xuAXuoE,822
+Crypto/Hash/HMAC.py,sha256=EGzjRKpSZnYSjowXylTIp1xaeM3mXIBSeUwYKLtU6MQ,7024
+Crypto/Hash/HMAC.pyi,sha256=fAyHBEf5Ee6LoiYYTZ9PZpmIRvitU6OriKGfjFUM_4c,624
+Crypto/Hash/KMAC128.py,sha256=RBZgbWsD7j7Nb06r8-JuZee_wuF2biZ6ChKvxLpFERI,5949
+Crypto/Hash/KMAC128.pyi,sha256=CHcjiaNKjvWQgLXsawb3Vxttxmt_hVK-Dv-5RVs6oOE,903
+Crypto/Hash/KMAC256.py,sha256=X4wqWcF7_JHwhnSvtN-tF4Wr3HdDnkRXNCb1O9Ar_VA,2906
+Crypto/Hash/KMAC256.pyi,sha256=oAeKgyta2iqjV9Yv818xoW5eJ2ixeLP4joUP8XUi2e0,226
+Crypto/Hash/KangarooTwelve.py,sha256=ayqFWsE-xL7gzAfJrd0gKbRxlZk9I59UPPLsjzWj9Xw,9029
+Crypto/Hash/KangarooTwelve.pyi,sha256=shf_g18EQxoJ8O9Kzuah17jw6J-vzmMsuqz1mAUY5WE,572
+Crypto/Hash/MD2.py,sha256=3Bbq-tKUklD80oy_02XEcpLqzxPkWeuE4tK4Y7SkLTk,6111
+Crypto/Hash/MD2.pyi,sha256=wa7GSYUpzL27so4YgiJEZo0dO201KyHK87Q20Pvm-bM,492
+Crypto/Hash/MD4.py,sha256=2MxXHclhh1xJeA94hVJLDKCqDWEhhjpVWrzxNfDCYnI,6582
+Crypto/Hash/MD4.pyi,sha256=7ZtZQEgJCwIswneb0NBov_uL0_Toglh9EPMnLVFGqwo,532
+Crypto/Hash/MD5.py,sha256=iC2xwz0OhQvDHn7eZUrttErd5sNv_ZXlmu32iLFwpm4,6618
+Crypto/Hash/MD5.pyi,sha256=c4MCJHvYTi2YL4hmqEu9ivbSvkBJdR-S2ldUqEpzK8s,492
+Crypto/Hash/Poly1305.py,sha256=t4AxiGiYAszL40GplTxERY4h-aa_U_4Rt-2MfD95ii8,8074
+Crypto/Hash/Poly1305.pyi,sha256=TSGottirLPIRyivSjZucQB7aPaYfhrUkn8oot6OrmmU,665
+Crypto/Hash/RIPEMD.py,sha256=KlCkJgE97zU4c32OZkAzmBcOSwOYQUkuL4zGZ4fS8ZI,1199
+Crypto/Hash/RIPEMD.pyi,sha256=TEOz1O-5v3DudzYkKTLQDjS6s0c_MIsXnqVoW-y9350,94
+Crypto/Hash/RIPEMD160.py,sha256=tlIzXUa6GZ-PDKi3mriguNxWVshwFUrx4Nd5WSbGEvo,6398
+Crypto/Hash/RIPEMD160.pyi,sha256=RQ9yXxjH1BSaU3mwhsCn9-67C0a_Bcv3MDdafQCiuPs,516
+Crypto/Hash/SHA.py,sha256=nSCZz4cjbDBJwoRt0m_a-3lj2xWjxsy7gkcctJTgf6E,1148
+Crypto/Hash/SHA.pyi,sha256=LhpqURwJrNcxMtrfpSulU16nMqarhc8iEoZpurbtXIk,161
+Crypto/Hash/SHA1.py,sha256=rUvRvZ6cLdaNa9vROZn4Ju45tyTqADH470XGLm9isM8,6690
+Crypto/Hash/SHA1.pyi,sha256=vNtB_b4MytJq8Io1xufdOO6VL-nMBcCnDPIgJQuNPCM,536
+Crypto/Hash/SHA224.py,sha256=WB6hinPQ5pSBN5vsBypa0eFTonua31jMAhj5OyWBzSw,6901
+Crypto/Hash/SHA224.pyi,sha256=8RsbyIwIfO8Fc_fpWw1MnFw04Z4n-qL0G01qCQZwvx8,544
+Crypto/Hash/SHA256.py,sha256=6Hngh_GGgVP2MiLam-EMXgAuF2-1Upd1JjgPpo1db18,6897
+Crypto/Hash/SHA256.pyi,sha256=zndNEjv6DZOWaOpuoUKsA2hTi2J7-oJFgOQ10sSRnXE,612
+Crypto/Hash/SHA384.py,sha256=cPy9NZfORg5I_Zeey4kryYgP-SW9RADZEn-vJhzcb4o,6899
+Crypto/Hash/SHA384.pyi,sha256=KIWbD-lBbd7lvWgFquIqUAMaisovey0HV0Nmmq-pvOY,544
+Crypto/Hash/SHA3_224.py,sha256=DYyrrCo-weZxOw7TIupdCPVjwOv_x7r0ltxm5fUcNS0,6179
+Crypto/Hash/SHA3_224.pyi,sha256=YNvN-GxVpPK6_-ee0_n-7wgAhq7JzBaBaeGiNdVoQdk,605
+Crypto/Hash/SHA3_256.py,sha256=ajt5tbMSjsWYWAehZwpNy7fB4SzzkNxnBAvRrRZvess,6179
+Crypto/Hash/SHA3_256.pyi,sha256=JlPOiVtEVNJerGWRuBDunXBT19WX_6ObpUuMaX7QdEs,605
+Crypto/Hash/SHA3_384.py,sha256=ue24yxpMjmUCU2OktIbZDN_3AssUEhQVy6X_ajEcgJc,6274
+Crypto/Hash/SHA3_384.pyi,sha256=c3wb0c6RjlcMcK22mV4RZWsDmUA2NszaghLAXIrN8T8,605
+Crypto/Hash/SHA3_512.py,sha256=L0GMJfARBSg9CfIA5t4YAawyeA0-N4cEzpzfJdSvX04,6131
+Crypto/Hash/SHA3_512.pyi,sha256=bZ0WozTD_mQ_5t_z4SWCpCn61YhCVamF501jsQdUjps,605
+Crypto/Hash/SHA512.py,sha256=uc0Ss7qRcUf0cbQvXDZOocgZKArnO441LV19t0C--38,7720
+Crypto/Hash/SHA512.pyi,sha256=VfMzHx-0U4efCyZCrgs_aOz17W8t0ZHL_3uR8zaYzCU,622
+Crypto/Hash/SHAKE128.py,sha256=O3OptmB3Zz-o6ZabxdGyGrsPMblMFbd-S5_wlARyOW8,4761
+Crypto/Hash/SHAKE128.pyi,sha256=wLhV8lh8YYWzi7PkhAB3_JQn_hOZNvkiZYg-JjiPpfs,437
+Crypto/Hash/SHAKE256.py,sha256=0UD_mU8mEEZ2f-q1BL4GJXBzpP--4mB3an93mrFZFss,4762
+Crypto/Hash/SHAKE256.pyi,sha256=9Uq_FaeYwDx_6dLv331Wv1snnGxA2UhFcUdELHkwU9U,437
+Crypto/Hash/TupleHash128.py,sha256=vl5pac01g93xOI0ma_nL9EgOsvXkWZgwLjUIUPo9b6U,4720
+Crypto/Hash/TupleHash128.pyi,sha256=fXkKNiV-HubXeCBcBTlAtHlYXwWeQ3iDcOWCs11iIfk,652
+Crypto/Hash/TupleHash256.py,sha256=mV6ygJKOY6FUUnYVUFQr3qzxlVT2GC0Wu_4zQjReXwg,2910
+Crypto/Hash/TupleHash256.pyi,sha256=esuouWh2HmCu3M4kLjCgu5jrQ87NrBQU5h9o5x21kl0,144
+Crypto/Hash/_BLAKE2b.abi3.so,sha256=oDj6Vazfth92Y3ObyffmOxYrcF-xxdq3zpSHW8yNf8g,21888
+Crypto/Hash/_BLAKE2s.abi3.so,sha256=yCsBrlATlGDEDeGS3RskkrW30d-AC7flIVhBtDdzOBU,21712
+Crypto/Hash/_MD2.abi3.so,sha256=CKT2_OL7CFttdSoz_1xZuTgxXPwwHx9aK9YNj884hA8,20128
+Crypto/Hash/_MD4.abi3.so,sha256=0bj8Jcde8wwcjgtpyiAgnjGhdlgb_Gt8mJMpzG_bLKo,25576
+Crypto/Hash/_MD5.abi3.so,sha256=l35uHqYH6RGKjAQfT7VtcGnW1fTgeYOZTGT9wuQdktw,31704
+Crypto/Hash/_RIPEMD160.abi3.so,sha256=rEIK5YUk7ggE_SmopBQO25AdxV8347TNMa9cbcRja9M,55608
+Crypto/Hash/_SHA1.abi3.so,sha256=kZwBL7s7caKFBb2cjNBY2l6w1AJ3t8UoTBRYewzTWpY,74416
+Crypto/Hash/_SHA224.abi3.so,sha256=ynkUo7XzbPY8UY92aDFp_Cj7RNRA9ADqyGaPef_ooSA,43792
+Crypto/Hash/_SHA256.abi3.so,sha256=XGKbLOdmeb71ddO1ECbZOGCwLhUFYeOxpJnFUQQPydE,43872
+Crypto/Hash/_SHA384.abi3.so,sha256=ckyhd4K5wPj7CDNBt0iSvYlLGMSu5MCAQrSmf3wdwgA,50520
+Crypto/Hash/_SHA512.abi3.so,sha256=ZAT7ECwNhSXS4H0Ido671z0sn3aeDfnjeOh9QOy5mbE,50624
+Crypto/Hash/__init__.py,sha256=vaCd9NlWJbYWuDUu8QZAt_6DdW0mk1OAmVDvyG7-Izg,1239
+Crypto/Hash/__init__.pyi,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+Crypto/Hash/__pycache__/BLAKE2b.cpython-39.pyc,,
+Crypto/Hash/__pycache__/BLAKE2s.cpython-39.pyc,,
+Crypto/Hash/__pycache__/CMAC.cpython-39.pyc,,
+Crypto/Hash/__pycache__/HMAC.cpython-39.pyc,,
+Crypto/Hash/__pycache__/KMAC128.cpython-39.pyc,,
+Crypto/Hash/__pycache__/KMAC256.cpython-39.pyc,,
+Crypto/Hash/__pycache__/KangarooTwelve.cpython-39.pyc,,
+Crypto/Hash/__pycache__/MD2.cpython-39.pyc,,
+Crypto/Hash/__pycache__/MD4.cpython-39.pyc,,
+Crypto/Hash/__pycache__/MD5.cpython-39.pyc,,
+Crypto/Hash/__pycache__/Poly1305.cpython-39.pyc,,
+Crypto/Hash/__pycache__/RIPEMD.cpython-39.pyc,,
+Crypto/Hash/__pycache__/RIPEMD160.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA1.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA224.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA256.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA384.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA3_224.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA3_256.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA3_384.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA3_512.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHA512.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHAKE128.cpython-39.pyc,,
+Crypto/Hash/__pycache__/SHAKE256.cpython-39.pyc,,
+Crypto/Hash/__pycache__/TupleHash128.cpython-39.pyc,,
+Crypto/Hash/__pycache__/TupleHash256.cpython-39.pyc,,
+Crypto/Hash/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Hash/__pycache__/cSHAKE128.cpython-39.pyc,,
+Crypto/Hash/__pycache__/cSHAKE256.cpython-39.pyc,,
+Crypto/Hash/__pycache__/keccak.cpython-39.pyc,,
+Crypto/Hash/_ghash_clmul.abi3.so,sha256=WfUR7T-C4cNXdyE_NssjdArjLQ4N6zJO2pCVdtUE7fQ,50160
+Crypto/Hash/_ghash_portable.abi3.so,sha256=z0OI9LNT5qr9mh8wSOWDVBCDqvS_jTAISu2BZG3dhJQ,17432
+Crypto/Hash/_keccak.abi3.so,sha256=sQe8cRuL_dcphf4TSxDio8DNc_1rdDbi_kdUc3VMJm4,35064
+Crypto/Hash/_poly1305.abi3.so,sha256=fq4M2r4BoJm0xEiUk506jaVMKoiFLv1dqq7GO3VOqMQ,33360
+Crypto/Hash/cSHAKE128.py,sha256=fgY6LLEiYb31LCPt900KTSj5HkJDny0-HoqswVMIUHs,6317
+Crypto/Hash/cSHAKE128.pyi,sha256=ILenFDiznj1cjIXbjIBvzEqfFOjObKXWyMsOr63awUo,499
+Crypto/Hash/cSHAKE256.py,sha256=tQlRaYEzNlhWMLjNAgGygmcIk1aRCndVGykC51Nbbw8,2202
+Crypto/Hash/cSHAKE256.pyi,sha256=d_bNhBzSdYSBJ5_6-PqIcJ6ZDKDbXbVqAjJrDTqDMao,231
+Crypto/Hash/keccak.py,sha256=95jyfmSqW9Pr2dvqKTy6STDePCbWV5oOgFuMKrrGokQ,7543
+Crypto/Hash/keccak.pyi,sha256=pXAZaNfayZCXMxB7IDFr2F8Hi06_hwFB3GXjNzY7sBM,741
+Crypto/IO/PEM.py,sha256=z0WG8YHl_5JgkTv3NiqsaF0ri2IhLrbJfBrBVA_H6H4,6948
+Crypto/IO/PEM.pyi,sha256=a1G07RQtZvEtXHlybxdDcoTPM3nqMbdONNjzcz5HGtE,303
+Crypto/IO/PKCS8.py,sha256=tM4XbY4DIesa42RBfavYO0XCKmJdXIIoup9FisoD9Us,9070
+Crypto/IO/PKCS8.pyi,sha256=dZ2LEDFXAhPZbF8ZsXMxXlluV_XPlVsQLC8HPYF7Mro,480
+Crypto/IO/_PBES.py,sha256=grUpKRRsDlXPuW69z82edameYoYecNyAoZbPlbHbj98,16324
+Crypto/IO/_PBES.pyi,sha256=QWJLbYh7ywy2wlRWnbUQG_hqlv6zfobF5o6FKh7reWA,489
+Crypto/IO/__init__.py,sha256=QUvnoDWlmuOGEjxXh_uXHMoSmoPi_nSeh-Et7MSofeg,1540
+Crypto/IO/__pycache__/PEM.cpython-39.pyc,,
+Crypto/IO/__pycache__/PKCS8.cpython-39.pyc,,
+Crypto/IO/__pycache__/_PBES.cpython-39.pyc,,
+Crypto/IO/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Math/Numbers.py,sha256=RUPDs7rIPkBwGoH617_95IKI3FTuVwN_MFsnTGEoCyI,2022
+Crypto/Math/Numbers.pyi,sha256=8BNTgE22Kr-ZKca290R_dMtK0fGmdwbWSt6r02ggUXY,84
+Crypto/Math/Primality.py,sha256=LUw55aVWupHDHZvNrXFGAyigP_ZN6HXBV8jwbExUuok,11371
+Crypto/Math/Primality.pyi,sha256=iXAY0gUmciIS_FvH5VJwhQfK-0tDmaH2vcDLHHFyxIE,823
+Crypto/Math/_IntegerBase.py,sha256=iiL9pnDVju7U_eytYEoDkYNuizsVTLe-1bY9U9PX3qc,10508
+Crypto/Math/_IntegerBase.pyi,sha256=Ovo8qweHd6TR7h5506CiwaN9v0L-isgOfUu9YJb6Fr4,3470
+Crypto/Math/_IntegerCustom.py,sha256=skFIo_AB8GmoTgk5mI2STJB2UYYyZFtR_aCJdScA0_A,4250
+Crypto/Math/_IntegerCustom.pyi,sha256=s9UZigBEgUvHS4IOdt8jXhsZ33O9j19p7lieob1R-EY,135
+Crypto/Math/_IntegerGMP.py,sha256=uk0uu7P_7YjETw1Ga_b2PFLUQ_spnXV--Bml9y79EDw,26743
+Crypto/Math/_IntegerGMP.pyi,sha256=UcJOGMYT1d-G0PjbC5ByShFl5oyorFR8h38fFt0uY9s,78
+Crypto/Math/_IntegerNative.py,sha256=5adc6zDveHN0JF9No0zOCqekm4YRdDiWdJ5baSiq4_A,11591
+Crypto/Math/_IntegerNative.pyi,sha256=pZaN1xXnB8u7VfrMgp6jqi_jCaJ4x4t0Ecs7qZ_2x-4,81
+Crypto/Math/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+Crypto/Math/__pycache__/Numbers.cpython-39.pyc,,
+Crypto/Math/__pycache__/Primality.cpython-39.pyc,,
+Crypto/Math/__pycache__/_IntegerBase.cpython-39.pyc,,
+Crypto/Math/__pycache__/_IntegerCustom.cpython-39.pyc,,
+Crypto/Math/__pycache__/_IntegerGMP.cpython-39.pyc,,
+Crypto/Math/__pycache__/_IntegerNative.cpython-39.pyc,,
+Crypto/Math/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Math/_modexp.abi3.so,sha256=s43NhDm4d5yz5DMku63cGyis7sPbzPOmPHUy-kAW3-U,294464
+Crypto/Protocol/KDF.py,sha256=4Yj5T2oUx2d7IQiZp9pVNDDl2drlK6RjgIQCe84v4RM,19829
+Crypto/Protocol/KDF.pyi,sha256=OfuAajDDJIDIny-zMuGsfhqCLZr4x8bZnV5Tonbg00E,1383
+Crypto/Protocol/SecretSharing.py,sha256=wpZ1bvl7LGyISM7kq1XyDQuy7B8IR2AuRH2spBmnxus,8778
+Crypto/Protocol/SecretSharing.pyi,sha256=-lErV2RvaNPuOA0z4c44WmNSu9irCw_DDb7wPgCS2BY,798
+Crypto/Protocol/__init__.py,sha256=eXlh5nJVd6NoXfUjJ-mNGgm5oE8r6MYDBOIHXWdzTPw,1548
+Crypto/Protocol/__init__.pyi,sha256=RNdrwMgjt9b9LmckdRkaYYC4PCzNV-1Hi2T3B2MHgds,43
+Crypto/Protocol/__pycache__/KDF.cpython-39.pyc,,
+Crypto/Protocol/__pycache__/SecretSharing.cpython-39.pyc,,
+Crypto/Protocol/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Protocol/_scrypt.abi3.so,sha256=93tpUVttH0xGBgOn645S-CdDeMrZEdS5DFPaB1vuieU,25024
+Crypto/PublicKey/DSA.py,sha256=dLECSm9lYt_qoilF3AXizkJQsR6nd_QcaeKZ0soGU8o,22378
+Crypto/PublicKey/DSA.pyi,sha256=t6y3t_w_odo5exLTS4K3_d76ObdqfN6R1QHKOff_LA4,1381
+Crypto/PublicKey/ECC.py,sha256=8VFzOPnhnYeRMB--B6ZYNpMlRdrwzTEM7meJeusiGv4,64596
+Crypto/PublicKey/ECC.pyi,sha256=if_pcA6iKA_YL5ei1iVdoSRwPn8zMdh0KJ6eMwaoDqY,2563
+Crypto/PublicKey/ElGamal.py,sha256=zWz52BILZv9YUzBnNJbj43G_Q5NtUV2oxZK6sjfQeJ4,8615
+Crypto/PublicKey/ElGamal.pyi,sha256=-s3ty0v_o-8Rq8_nrYh32Vo6ihr8OaSWdc_H7_CVGCo,674
+Crypto/PublicKey/RSA.py,sha256=MNykS-cPAdIw6nx5oqt6pFt6UntUl4zxPP0lPMgJ6O8,29038
+Crypto/PublicKey/RSA.pyi,sha256=UpGjHMfe8UvuQPVX6NaMMezyAQw8TmltVxrLeUzr2tc,1865
+Crypto/PublicKey/__init__.py,sha256=Cu1oi4-0Qe8QXpsmacruHj54iixROrCPwWuc7jm01ho,3142
+Crypto/PublicKey/__init__.pyi,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+Crypto/PublicKey/__pycache__/DSA.cpython-39.pyc,,
+Crypto/PublicKey/__pycache__/ECC.cpython-39.pyc,,
+Crypto/PublicKey/__pycache__/ElGamal.cpython-39.pyc,,
+Crypto/PublicKey/__pycache__/RSA.cpython-39.pyc,,
+Crypto/PublicKey/__pycache__/__init__.cpython-39.pyc,,
+Crypto/PublicKey/__pycache__/_openssh.cpython-39.pyc,,
+Crypto/PublicKey/_ec_ws.abi3.so,sha256=zv8VYVbmb8qYr6618D8HJ7plUQRbqP3LTV_SYbpIAPE,1068008
+Crypto/PublicKey/_ed25519.abi3.so,sha256=QvZZUTB65m7xVhUx3IRUlaLphMNMNSypiJMcHk7c-mw,578280
+Crypto/PublicKey/_ed448.abi3.so,sha256=QyaQBiuku52ceIN-JJpYV5JyVuRswMgWejfW8qQjVjg,329424
+Crypto/PublicKey/_openssh.py,sha256=1v8v-0XbVt6TuCLky4GLZGDq45JZtnX9ExM6CGi_K_c,5126
+Crypto/PublicKey/_openssh.pyi,sha256=ywCy9UDu2_AQI60ChWxGxyqHiZoYwMKC3TVXJn_ZVIM,324
+Crypto/PublicKey/_x25519.abi3.so,sha256=fxNz69J4rA1_RC5CzJhi7Yrluat2fTbFvTOBLLaO3IY,124632
+Crypto/Random/__init__.py,sha256=Dm-guhBXdf4LEr9zqk-IPnwwplHsQFpVrJlr11-kBGA,1809
+Crypto/Random/__init__.pyi,sha256=ieifhoMB2veKusRRBZWQp6igPri5027VrqfddO5b-WU,367
+Crypto/Random/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Random/__pycache__/random.cpython-39.pyc,,
+Crypto/Random/random.py,sha256=HCDPWzIZ0_HOsLfi7iK5SWTGgHzBezcltkoEV6m-0G4,5234
+Crypto/Random/random.pyi,sha256=Lgo1h6wtyUDhEuroDRyt-eYvPFEgQOo0fxfAE68S2cM,807
+Crypto/SelfTest/Cipher/__init__.py,sha256=8T6OmvduOpqSmp1rA7ZPfIH63NorYynUoLpH4JkC2jU,3620
+Crypto/SelfTest/Cipher/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/common.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_AES.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_ARC2.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_ARC4.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_Blowfish.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_CAST.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_CBC.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_CCM.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_CFB.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_CTR.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_ChaCha20.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_ChaCha20_Poly1305.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_DES.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_DES3.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_EAX.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_GCM.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_OCB.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_OFB.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_OpenPGP.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_SIV.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_Salsa20.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_pkcs1_15.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/__pycache__/test_pkcs1_oaep.cpython-39.pyc,,
+Crypto/SelfTest/Cipher/common.py,sha256=nvzRPuvwN7GFt5m8wQzSJd1mOVda_N0GcMzC7wHF-Mg,17316
+Crypto/SelfTest/Cipher/test_AES.py,sha256=bgLc-5lq7BuEmkRxyW7e7AfXM_Q9OaY9HYV3BrDmiSo,71731
+Crypto/SelfTest/Cipher/test_ARC2.py,sha256=kaY-lc899PWYT4FOKQNu-xDoGrlZYCH226HN-mtmyF8,6454
+Crypto/SelfTest/Cipher/test_ARC4.py,sha256=ZrRn7HJh8_MELniyyFP0xVOxYSfMhngMsugXb-RrAmc,24730
+Crypto/SelfTest/Cipher/test_Blowfish.py,sha256=-WiY2HDEqaeACbivv4hYHyBNjksv2r7tf22-skgY44A,7230
+Crypto/SelfTest/Cipher/test_CAST.py,sha256=u98fHuuzzXE_eZsEmcoWPF-nnfOTWR-RXJsZuqmiGWE,3279
+Crypto/SelfTest/Cipher/test_CBC.py,sha256=dXDU3Y57bOzfEz62ogWpD7yfnkrYWJulkXR0yO9RB9g,20202
+Crypto/SelfTest/Cipher/test_CCM.py,sha256=8pTtOASf0nqb6-ZhAIDh62gLgwbAiAQZxP6Sgue6xM0,37304
+Crypto/SelfTest/Cipher/test_CFB.py,sha256=BoX_CuHGJLmBk-inWiOIf5JnYDFqQvLfcEcEFTLwnmg,16061
+Crypto/SelfTest/Cipher/test_CTR.py,sha256=f9PKGo5FgQXP880Fn0rxr1USoZqoJf7S61-DApSUTK4,21314
+Crypto/SelfTest/Cipher/test_ChaCha20.py,sha256=PDV0Xx5OSprm0Wkp4jJPIdDM8nZlz5ByX9FIttsLrTE,20316
+Crypto/SelfTest/Cipher/test_ChaCha20_Poly1305.py,sha256=KMhJf9pV6VAQdCLPotm4n4PkbscM0edkY0l6EzeivE0,30499
+Crypto/SelfTest/Cipher/test_DES.py,sha256=4Jz2XNFXzYPzoqVR6xzIy-_DxOS8RRQDPAFzhhFndDk,15943
+Crypto/SelfTest/Cipher/test_DES3.py,sha256=Kpq-_7cpG8CVsDQv_B5GVbLC3Z-5wkbsO_JWzXGAZ7w,6561
+Crypto/SelfTest/Cipher/test_EAX.py,sha256=4v3XRexRoLUo0DWOk0ZHtc9TO359unRL-xYwrIKaagA,28819
+Crypto/SelfTest/Cipher/test_GCM.py,sha256=pMlNYw9RjjikWnMHXUYLZugJeX2JPaa9Lp1dtjurZmg,37276
+Crypto/SelfTest/Cipher/test_OCB.py,sha256=AfQ734lS0VcBMF8snwFzQ1cDWfWgBWFRJLL7tz4S0h0,28053
+Crypto/SelfTest/Cipher/test_OFB.py,sha256=WPPo-ByC93O2Jej_UBQ3SHRKmFflO0zo94C0u8-rVHY,9367
+Crypto/SelfTest/Cipher/test_OpenPGP.py,sha256=ZQsjls14yvQsP9c3Ifc4F8c4xYKpkZ6Cynr46bHef9E,8477
+Crypto/SelfTest/Cipher/test_SIV.py,sha256=0I7_sVuqtaLjCjiEPkA1l1Se7w5DE_ZI2gQ5aXbdDxA,19939
+Crypto/SelfTest/Cipher/test_Salsa20.py,sha256=rAt0Gy3DUo9VsbzL3_nfbNu9sMrsh3mAj1Xk0CfI5tQ,16591
+Crypto/SelfTest/Cipher/test_pkcs1_15.py,sha256=n3rik7vt_a8yvYksm4Fnmh0P6CpIBiTXeCfZj1YCyQY,10944
+Crypto/SelfTest/Cipher/test_pkcs1_oaep.py,sha256=MPznWwzqvbw4FIkLcylC5Ur6syMwqudohBEsrsWuS9M,22290
+Crypto/SelfTest/Hash/__init__.py,sha256=-TeN101NniS-6DJGKfcgqgFmZA6rx_kJnyUXmpFaR2k,3713
+Crypto/SelfTest/Hash/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/common.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_BLAKE2.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_CMAC.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_HMAC.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_KMAC.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_KangarooTwelve.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_MD2.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_MD4.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_MD5.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_Poly1305.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_RIPEMD160.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA1.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA224.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA256.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA384.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA3_224.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA3_256.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA3_384.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA3_512.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHA512.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_SHAKE.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_TupleHash.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_cSHAKE.cpython-39.pyc,,
+Crypto/SelfTest/Hash/__pycache__/test_keccak.cpython-39.pyc,,
+Crypto/SelfTest/Hash/common.py,sha256=rREos5jC82-QtGOqvx59rLYne2scZFMjRD3DVBNv_EI,9878
+Crypto/SelfTest/Hash/test_BLAKE2.py,sha256=Xcn8el7VEoPTRFGglLOxr-Ha85vZB8S8SmVp2HEX2B4,16314
+Crypto/SelfTest/Hash/test_CMAC.py,sha256=-YAfT4wPkUpdGoCtvfROSctmj56zBS5uD9mhVtLdG4Y,13360
+Crypto/SelfTest/Hash/test_HMAC.py,sha256=F7YPjYL0FkaBT2lcOAFvORabgyv50PRtJuIcjLXwLL4,19941
+Crypto/SelfTest/Hash/test_KMAC.py,sha256=6dDbkV9wCQldaQNe1xHzd9ZLTakIroZaaPZItDaWPuU,11704
+Crypto/SelfTest/Hash/test_KangarooTwelve.py,sha256=GLgXZXT4pvuLyUCwmmHL2uDR_BUJ-4H2r7rj63pXgtM,10455
+Crypto/SelfTest/Hash/test_MD2.py,sha256=DmKwE5sOOiZzTaG8mZsr4GYsC-lHQJhdDPwDwKayowk,2324
+Crypto/SelfTest/Hash/test_MD4.py,sha256=7MLR7Y4n4sR-n9YoNRxLREg7tA0RD6oCHtusaUv9ihc,2347
+Crypto/SelfTest/Hash/test_MD5.py,sha256=d_vO_rWOvFVbOU2KGEIuzVjUe44d3AEoLGY4GvmvE0U,3284
+Crypto/SelfTest/Hash/test_Poly1305.py,sha256=HUrIeWXrtUo5pAAcJnGacNJ6z2O0PbxWtO9bTF65cIo,18297
+Crypto/SelfTest/Hash/test_RIPEMD160.py,sha256=xQDplk8VlmbMXEmA1380gG78aA5l3z8gS3CYSJC4Yyo,2663
+Crypto/SelfTest/Hash/test_SHA1.py,sha256=5JNKqw9ZBLFsIOqKLgVGJTaMog-1hy_Tn7Bs8moNOJQ,2926
+Crypto/SelfTest/Hash/test_SHA224.py,sha256=e7Hv0MgLyvUZ7k6Sm-QikeNX5MT4zb92FBDQPdl3Pn4,2533
+Crypto/SelfTest/Hash/test_SHA256.py,sha256=Ja952iqnKsoE8aLO3HxBnMZPkZ1-xcAAfwa81kpViaw,3617
+Crypto/SelfTest/Hash/test_SHA384.py,sha256=OQQWXtHz2k0giN-mo00KaOssHoEV6sabUgi9SSUBPBU,2714
+Crypto/SelfTest/Hash/test_SHA3_224.py,sha256=InARVpF5ijFHA0caOxGa_2UMH_FcIY21lQH45ZfU6eE,2830
+Crypto/SelfTest/Hash/test_SHA3_256.py,sha256=CjGeeQDZIbFEjZrC1rcVMEm5J3m-yL7QZCDV-HysQe4,2831
+Crypto/SelfTest/Hash/test_SHA3_384.py,sha256=CMPpLzqpjv90Euraxwnm1Use2kGziN9jn4WgYGkK6U0,2830
+Crypto/SelfTest/Hash/test_SHA3_512.py,sha256=Xivd0fO6eJ5Dc-2cYPy6HbouWEplfFRPuVaMet-ATwo,2831
+Crypto/SelfTest/Hash/test_SHA512.py,sha256=tglTx0VokAkiNY2y-7jMlzgckj4h1KY67jkNmATiqPQ,5198
+Crypto/SelfTest/Hash/test_SHAKE.py,sha256=RpIAaMTUZCToxuJO33--jmAC-Iebh2qLWC2nFVul05w,4715
+Crypto/SelfTest/Hash/test_TupleHash.py,sha256=BOHlqoTQ26GJEfaBKGpA_j99CAB3r-5fQLHsQvnvWxI,8135
+Crypto/SelfTest/Hash/test_cSHAKE.py,sha256=biDAriI4PDfNZ5ElySNxCAP8f1PH0BhmBLMXupSbibU,6792
+Crypto/SelfTest/Hash/test_keccak.py,sha256=cfZQDBcEjNN8-EsfxcBsT4dGYFR0KT9ICZ1KCPDDVLM,8889
+Crypto/SelfTest/IO/__init__.py,sha256=_4mxkjHrp7TRKg4Q-95fEQtfERUFm8QpvTQL_sKNaj0,1994
+Crypto/SelfTest/IO/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/IO/__pycache__/test_PBES.cpython-39.pyc,,
+Crypto/SelfTest/IO/__pycache__/test_PKCS8.cpython-39.pyc,,
+Crypto/SelfTest/IO/test_PBES.py,sha256=c8Xxg_Nz1Ir1v5nKgahOElsZMjxuEHnIHGJOAivK838,3453
+Crypto/SelfTest/IO/test_PKCS8.py,sha256=ZgkXWzhey9uUoHaGalWZMbI1wSTcSC-tOTVzpoiVe-w,17601
+Crypto/SelfTest/Math/__init__.py,sha256=_fYQ6JUgK4cuv62SBwItAdhyUF2P7l7s4LOmOVyGvyw,2101
+Crypto/SelfTest/Math/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/Math/__pycache__/test_Numbers.cpython-39.pyc,,
+Crypto/SelfTest/Math/__pycache__/test_Primality.cpython-39.pyc,,
+Crypto/SelfTest/Math/__pycache__/test_modexp.cpython-39.pyc,,
+Crypto/SelfTest/Math/test_Numbers.py,sha256=SpdC-vRjVWRGFX5ZaV-GJh318Fs5xh8ImT8yqNJlB5M,30782
+Crypto/SelfTest/Math/test_Primality.py,sha256=fuBicKSRd5bmCh6jyw_ySCkK9Dgifo2A9QIzpMhfm_Y,4881
+Crypto/SelfTest/Math/test_modexp.py,sha256=oXkGt_dWwpYy3rWJAkIAEAQfft1q5ubZCdIw5kXc-ps,8103
+Crypto/SelfTest/Protocol/__init__.py,sha256=RK1layikOWyUcwhwmDEnfZGLVWAzSIDgUkBCCWD_2ag,1743
+Crypto/SelfTest/Protocol/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/Protocol/__pycache__/test_KDF.cpython-39.pyc,,
+Crypto/SelfTest/Protocol/__pycache__/test_SecretSharing.cpython-39.pyc,,
+Crypto/SelfTest/Protocol/__pycache__/test_rfc1751.cpython-39.pyc,,
+Crypto/SelfTest/Protocol/test_KDF.py,sha256=rgaRPHG2lJt16acdLXWZOv--LuXBYTF9k53n6O9NhHc,34013
+Crypto/SelfTest/Protocol/test_SecretSharing.py,sha256=xTY_D3HsP86OSt0I0vZv6Q9XzQXjCFtVFf0-Dno4xBY,9685
+Crypto/SelfTest/Protocol/test_rfc1751.py,sha256=LR3M9XLk_sxOyapPq32PEf93SUMwErFwwzlHNKhUazg,2208
+Crypto/SelfTest/PublicKey/__init__.py,sha256=rowj1oLXbL_oOVbs4Egdf5zb89ZRVhIdQHbatmjdNCM,2118
+Crypto/SelfTest/PublicKey/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_DSA.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_ECC_25519.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_ECC_448.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_ECC_NIST.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_ElGamal.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_RSA.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_import_DSA.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_import_ECC.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/__pycache__/test_import_RSA.cpython-39.pyc,,
+Crypto/SelfTest/PublicKey/test_DSA.py,sha256=M-3F4O801yYWHDqTLNmBmm5k-OMfkOHBi3eBNg4U6zg,9600
+Crypto/SelfTest/PublicKey/test_ECC_25519.py,sha256=Exn10DIMzvrdN4die5iqneztpqAJDkFZWmoXS42Y_DY,13440
+Crypto/SelfTest/PublicKey/test_ECC_448.py,sha256=yubrMZK0hyLXuBuh2F7jd-BZnQ5hJJs8LEsupQV_yrk,14669
+Crypto/SelfTest/PublicKey/test_ECC_NIST.py,sha256=CVo_E4Izo4_GBoHa9muoOQ38LLbGfCt8cBE4zR0TMnY,49972
+Crypto/SelfTest/PublicKey/test_ElGamal.py,sha256=5Ys8fz4gb3wAUC5Aqa3pVfO31nBVVE-jBGnI5CllhUc,8648
+Crypto/SelfTest/PublicKey/test_RSA.py,sha256=4X2ZYNWWNWe4oxclrvnqaTW0mKZAHglMDEHo98FDJEo,12261
+Crypto/SelfTest/PublicKey/test_import_DSA.py,sha256=2aVhMJe2f4pNItCDZsp63IBdwfxqjzBCQA4axvqWheg,25509
+Crypto/SelfTest/PublicKey/test_import_ECC.py,sha256=iLYJHwH1CdQ8KfRr-HDEe3iAQhgmB_XTdxkKtqfqr-g,102822
+Crypto/SelfTest/PublicKey/test_import_RSA.py,sha256=xcb-7D8ERYIoocJCF_hnU-Eex9Amp7vmhttDv1tjprk,25095
+Crypto/SelfTest/Random/__init__.py,sha256=EE2uqy3wAuwiGtGTuBsRC4vK4-Yo0Tdl8_O94p9ZNlw,1542
+Crypto/SelfTest/Random/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/Random/__pycache__/test_random.cpython-39.pyc,,
+Crypto/SelfTest/Random/test_random.py,sha256=yif-POW3svFZ_GWOlCwuRUU-leCJw3cYCawXfkbGWCA,6990
+Crypto/SelfTest/Signature/__init__.py,sha256=4BxIOB9IMey1d1Xixyl7mRYJYnuzNpXk83KlzGBiEEM,1558
+Crypto/SelfTest/Signature/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/Signature/__pycache__/test_dss.cpython-39.pyc,,
+Crypto/SelfTest/Signature/__pycache__/test_eddsa.cpython-39.pyc,,
+Crypto/SelfTest/Signature/__pycache__/test_pkcs1_15.cpython-39.pyc,,
+Crypto/SelfTest/Signature/__pycache__/test_pss.cpython-39.pyc,,
+Crypto/SelfTest/Signature/test_dss.py,sha256=lKR1iRX2CPHqmt-yeFPCwErlnKJ-dSS_kajzlPAe8Fs,57090
+Crypto/SelfTest/Signature/test_eddsa.py,sha256=25t9_GD8_kswEptcc2NidQ3ZuvdMW8p-JrDDGVCzq_Y,24130
+Crypto/SelfTest/Signature/test_pkcs1_15.py,sha256=vbeQ-UbGVVBhaI4JUi2Nz2m913dHFNnSoPRDSpcYkaw,13541
+Crypto/SelfTest/Signature/test_pss.py,sha256=AS1CE55W9b63VBo7MjCPeQ80sj2iSMYBsjFZYx4kkks,15811
+Crypto/SelfTest/Util/__init__.py,sha256=3jO6ijPtoTAxotErcTp0Tulcc8SC4shYLgbtp5kFKd4,1997
+Crypto/SelfTest/Util/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/Util/__pycache__/test_Counter.cpython-39.pyc,,
+Crypto/SelfTest/Util/__pycache__/test_Padding.cpython-39.pyc,,
+Crypto/SelfTest/Util/__pycache__/test_asn1.cpython-39.pyc,,
+Crypto/SelfTest/Util/__pycache__/test_number.cpython-39.pyc,,
+Crypto/SelfTest/Util/__pycache__/test_rfc1751.cpython-39.pyc,,
+Crypto/SelfTest/Util/__pycache__/test_strxor.cpython-39.pyc,,
+Crypto/SelfTest/Util/test_Counter.py,sha256=PndKG-bx13FaDtvZpvaaCIa0T_sLsAE9bI0AuhHNN7g,2272
+Crypto/SelfTest/Util/test_Padding.py,sha256=1h2KKM1Zf7lB8da5lzcoz5T-CJJ6A_Rcwm9OcIiUqSY,5814
+Crypto/SelfTest/Util/test_asn1.py,sha256=jlXbKnTZJxj1eYTIY0c4yU9UN4iRoA3baEbIBYAAzKI,29327
+Crypto/SelfTest/Util/test_number.py,sha256=DqmhOnSmc417Qp28l8SQDJGS5XqcXDRW1vV9Z6rboO0,8518
+Crypto/SelfTest/Util/test_rfc1751.py,sha256=jkKFEA6oEHOB79Q6vnbdS5wZAetp5Lyi7bzyCsLXrQo,1113
+Crypto/SelfTest/Util/test_strxor.py,sha256=GZLG2BIY4Db5xxj2HXW3iyS7_C-KuX9CwsfjZ3WB2sY,10215
+Crypto/SelfTest/__init__.py,sha256=TkQaS45sI54v4yWSjKs9VlTUFuEb06-YbmQVwp1zI1I,3640
+Crypto/SelfTest/__main__.py,sha256=Bta_95qOi_iqjdqkYJO_z5pTD0fxTfaFrQXe8RewIjk,1502
+Crypto/SelfTest/__pycache__/__init__.cpython-39.pyc,,
+Crypto/SelfTest/__pycache__/__main__.cpython-39.pyc,,
+Crypto/SelfTest/__pycache__/loader.cpython-39.pyc,,
+Crypto/SelfTest/__pycache__/st_common.cpython-39.pyc,,
+Crypto/SelfTest/loader.py,sha256=q0uN-aywiWKL16--tl32je91pfPn3RT-qD0mKNbGK3A,6739
+Crypto/SelfTest/st_common.py,sha256=0u4r-Ue-O1u7m9CFKfHAtRfPamSftS_EoEOfUrIvMQE,1945
+Crypto/Signature/DSS.py,sha256=lLW3A5lp1fD8ziXiYDF2nxl-0Gk9rdAiltRQ8aczsbY,15300
+Crypto/Signature/DSS.pyi,sha256=Kf-grv7Krf1ZfTTCLwTB21XOvhBRlV54TTv3YAkVKkc,1094
+Crypto/Signature/PKCS1_PSS.py,sha256=RDd18tjg0DmzZgQDBgUeMO7xGmffs2srM8BgPLgAKiw,2099
+Crypto/Signature/PKCS1_PSS.pyi,sha256=F83y0Q-fB12GiXlzV5g64zCkdFIGubDokyLVQ_DXYeQ,272
+Crypto/Signature/PKCS1_v1_5.py,sha256=qEtvcBicSdmoL0gb5B2H_o-9XD1s-sOS65VxRkJHqk8,1989
+Crypto/Signature/PKCS1_v1_5.pyi,sha256=ycl2Lk1x4f1qDuTueSn4KDq1C3vN5a3nhDUzzn7S2Wo,149
+Crypto/Signature/__init__.py,sha256=nkUODHAHwqmFvemdRLKTFXCY0lh6WdxmFlJuLD0IBfw,1695
+Crypto/Signature/__pycache__/DSS.cpython-39.pyc,,
+Crypto/Signature/__pycache__/PKCS1_PSS.cpython-39.pyc,,
+Crypto/Signature/__pycache__/PKCS1_v1_5.cpython-39.pyc,,
+Crypto/Signature/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Signature/__pycache__/eddsa.cpython-39.pyc,,
+Crypto/Signature/__pycache__/pkcs1_15.cpython-39.pyc,,
+Crypto/Signature/__pycache__/pss.cpython-39.pyc,,
+Crypto/Signature/eddsa.py,sha256=wlE1-RWYGAOhbAAm4syGn0eNgz_b7hZnZhHYqMatmvo,12314
+Crypto/Signature/eddsa.pyi,sha256=FflMZC-uUM0QmbpZ_i_whnNflNsIIHU_QWbxEapQwbw,728
+Crypto/Signature/pkcs1_15.py,sha256=bbDDecPn4VtrCZCABS4QHhpV_xgzXg8xPX__-NLjmOU,8714
+Crypto/Signature/pkcs1_15.pyi,sha256=FjQTiNADwk8fCCmbGmxW912KCyhR9UINtSy-9JLjWgE,564
+Crypto/Signature/pss.py,sha256=-LZT7_0ciG8vvECN9JS8pWH5kjwRBDMw81XvxsGODr0,13434
+Crypto/Signature/pss.pyi,sha256=7KedQdcpuE-iBxRDmB9qsdI7STbZmiIo9I-4ZR6wiag,1040
+Crypto/Util/Counter.py,sha256=8qQaeEZGHtAhwGj0A9gJQ9e0ZihtjWpeCSeuS1h9KF8,3110
+Crypto/Util/Counter.pyi,sha256=2JrTHJYq263XosQSC_NIP0TufUsTlG7WUr-lRqjJCuA,290
+Crypto/Util/Padding.py,sha256=70zHE3iL6jsAIJDWtp3c79C6G-QN8_H0Qk_XSjB8Hpw,4313
+Crypto/Util/Padding.pyi,sha256=47R3H2kE66PtKO82eT_Vc5eCSgNe4qOFgqOIPRdlp9c,238
+Crypto/Util/RFC1751.py,sha256=vXf_GR5ezf6QYIedqNOJIiiAUygkjUfNNtGp1zdatVI,21192
+Crypto/Util/RFC1751.pyi,sha256=B42LvsE6G786rNEsrhta_BANazgrpb0WoSBPqKyjt5g,159
+Crypto/Util/__init__.py,sha256=PkdJwchrSxd4ty71sd0edXst9DSinI5gBgzeRfq6Nko,1927
+Crypto/Util/__pycache__/Counter.cpython-39.pyc,,
+Crypto/Util/__pycache__/Padding.cpython-39.pyc,,
+Crypto/Util/__pycache__/RFC1751.cpython-39.pyc,,
+Crypto/Util/__pycache__/__init__.cpython-39.pyc,,
+Crypto/Util/__pycache__/_cpu_features.cpython-39.pyc,,
+Crypto/Util/__pycache__/_file_system.cpython-39.pyc,,
+Crypto/Util/__pycache__/_raw_api.cpython-39.pyc,,
+Crypto/Util/__pycache__/asn1.cpython-39.pyc,,
+Crypto/Util/__pycache__/number.cpython-39.pyc,,
+Crypto/Util/__pycache__/py3compat.cpython-39.pyc,,
+Crypto/Util/__pycache__/strxor.cpython-39.pyc,,
+Crypto/Util/_cpu_features.py,sha256=hdmpjcnxJjkz2dyfsPBNnyaykqa3zXGj9gUGuWk1A2g,1989
+Crypto/Util/_cpu_features.pyi,sha256=3wKXZ0Z8llc2uxADvbhz3dHV6YLyRrDujOsabXlffCQ,59
+Crypto/Util/_cpuid_c.abi3.so,sha256=EdPZEMHSW3XmxJ0hnMPKGTLoxtD_0h5IH2G6g_QeWXo,12776
+Crypto/Util/_file_system.py,sha256=ORfvmlwKJtlGA5PCNr646U8huzTabs2b43d_FDauG-I,2171
+Crypto/Util/_file_system.pyi,sha256=5QruEWPE4urPtlCT5Eg8tBQyhV9ffBfZIAjmMo727dM,100
+Crypto/Util/_raw_api.py,sha256=h0NsvGrLM4gGjbe5p21TzKDqz3TXpDZPYa8j2SGgch8,10226
+Crypto/Util/_raw_api.pyi,sha256=Ohc2rr6RS-nhs6T5AL1YyQtaqsx6BVrJa092CiwAvNM,906
+Crypto/Util/_strxor.abi3.so,sha256=b_g2QAepTu-tquCPD4uE6Ocp_eex6_y0KIsNmUW_bjw,14960
+Crypto/Util/asn1.py,sha256=X1VsGWGCYn6iGs4i5Gildy4qQOSpohKc2TZW7PZEz0o,31657
+Crypto/Util/asn1.pyi,sha256=xR4oQKBf4SXiz0IQ_K0lw427jvvgX9SiEXejIu9fdV8,3579
+Crypto/Util/number.py,sha256=-Ey8sJKYzwIdK_ua3-tg8WAcE9XWAgK-iiqs2BweN5I,95709
+Crypto/Util/number.pyi,sha256=ixX1BS8EvvuPXN1_8aosdYHKmtXGB9NlRNVI9T9MAA8,975
+Crypto/Util/py3compat.py,sha256=XvM77kr9oBlD2WHBcgGPiETx3ekH-Yu_b4e3n2X6MhE,5526
+Crypto/Util/py3compat.pyi,sha256=lcLAXVV6t4d_y_EsUZOYEYgrOUczczMl_3IawItxYpw,837
+Crypto/Util/strxor.py,sha256=IdkE9I5J5v9ScJmhjbVfrGAKIUnukBLtdYIPzh1eR5Y,5441
+Crypto/Util/strxor.pyi,sha256=OuBvuuK_ezq3eaHY10J89xpER9IQ9wcYzFI7j1tpll0,243
+Crypto/__init__.py,sha256=t-JUcMSSnZhIFP8wNWKMkQah8dXuyhttfHa0Qda5GQo,185
+Crypto/__init__.pyi,sha256=e5Ea45Jy2RdOr6bmLF9jiS2Bw65WnYTD1NMLJlbGAaw,99
+Crypto/__pycache__/__init__.cpython-39.pyc,,
+Crypto/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pycryptodome-3.15.0.dist-info/AUTHORS.rst,sha256=TL3mA8NQgoAcSU73HDEKSMnuix-f4GiSTxtFMy_CS9M,750
+pycryptodome-3.15.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+pycryptodome-3.15.0.dist-info/LICENSE.rst,sha256=TgRmDXfBxk6J15U3kZ-4JA-iFISn49sp81iyx_hOoHM,2926
+pycryptodome-3.15.0.dist-info/METADATA,sha256=BEW1YR7gWjSmLOwPs56jegpFahCcNAfO8GBg7lAqdoA,3184
+pycryptodome-3.15.0.dist-info/RECORD,,
+pycryptodome-3.15.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+pycryptodome-3.15.0.dist-info/WHEEL,sha256=NHVYCelbjQYaRHcLsSKkvd1e6nvZj0EwYHq59C6NexI,111
+pycryptodome-3.15.0.dist-info/top_level.txt,sha256=-W2wTtkxc1QnPUPRqBZ0bMwrhD8xRD13HIobFX-wDOs,7
diff --git a/lib/pycryptodome-3.15.0.dist-info/REQUESTED b/lib/pycryptodome-3.15.0.dist-info/REQUESTED
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/REQUESTED
diff --git a/lib/pycryptodome-3.15.0.dist-info/WHEEL b/lib/pycryptodome-3.15.0.dist-info/WHEEL
new file mode 100644
index 0000000..59c3897
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/WHEEL
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.36.2)
+Root-Is-Purelib: false
+Tag: cp35-abi3-manylinux2010_x86_64
+
diff --git a/lib/pycryptodome-3.15.0.dist-info/top_level.txt b/lib/pycryptodome-3.15.0.dist-info/top_level.txt
new file mode 100644
index 0000000..e6645e7
--- /dev/null
+++ b/lib/pycryptodome-3.15.0.dist-info/top_level.txt
@@ -0,0 +1 @@
+Crypto
diff --git a/lib/snack.py b/lib/snack.py
new file mode 100644
index 0000000..a9d2b44
--- /dev/null
+++ b/lib/snack.py
@@ -0,0 +1,998 @@
+# snack.py: maps C extension module _snack to proper python types in module
+# snack.
+# The first section is a very literal mapping.
+# The second section contains convenience classes that amalgamate
+# the literal classes and make them more object-oriented.
+
+"""
+This module provides the NEWT windowing toolkit API for Python.
+This is a lightweight text-mode windowing library, based on S-Lang.
+
+Classes:
+
+ - Widget
+ - Button
+ - CompactButton
+ - Checkbox
+ - SingleRadioButton
+ - Listbox
+ - Textbox
+ - TextboxReflowed
+ - Label
+ - Scale
+ - Entry
+ - Form
+ - Grid
+ - SnackScreen
+ - RadioGroup
+ - RadioBar
+ - ButtonBar
+ - GridFormHelp
+ - GridForm
+ - CheckboxTree
+ - CListbox
+
+Functions:
+
+ - ListboxChoiceWindow
+ - ButtonChoiceWindow
+ - EntryWindow
+"""
+
+
+from __future__ import absolute_import, print_function, unicode_literals
+import _snack
+import string
+import sys
+
+from _snack import FLAG_DISABLED, FLAGS_SET, FLAGS_RESET, FLAGS_TOGGLE, FD_READ, FD_WRITE, FD_EXCEPT
+
+LEFT = (-1, 0)
+DOWN = (-1, -1)
+CENTER = (0, 0)
+UP = (1, 1)
+RIGHT = (1, 0)
+
+snackArgs = {"append":-1}
+
+class Widget:
+ """Base class for NEWT toolkit - Do not use directly
+
+ methods:
+
+ - Widget(self)
+    - setCallback(self, obj, data = None) :
+        set the callback invoked when the object is activated;
+        data is passed to obj.
+ """
+ def setCallback(self, obj, data = None):
+ if data:
+ self.w.setCallback(obj, data)
+ else:
+ self.w.setCallback(obj)
+
+ def __init__(self):
+ raise NotImplementedError
+
+class Button(Widget):
+ """Basic button class, takes button text as parameter
+
+ method:
+
+ - Button(self, text): returns a button
+ """
+ def __init__(self, text):
+ self.w = _snack.button(text)
+
+class CompactButton(Widget):
+ """Compact Button class (less frilly button decoration).
+
+ methods:
+
+ - CompactButton(self,text) : create button, with text.
+ """
+ def __init__(self, text):
+ self.w = _snack.compactbutton(text)
+
+class Checkbox(Widget):
+ """A checkbox.
+
+ methods:
+
+ - Checkbox(self, text, isOn = 0) : text, and boolean as to default value
+    - setValue(self, value) : set value
+    - value(self) : return checkbox value
+ - selected(self) : returns boolean
+ - setFlags(self, flag, sense) : set flags
+
+ flags: FLAG_DISABLED, FLAGS_SET, FLAGS_RESET
+ """
+ def value(self):
+ return self.w.checkboxValue
+
+ def selected(self):
+ return self.w.checkboxValue != 0
+
+    def setFlags (self, flag, sense):
+        return self.w.checkboxSetFlags(flag, sense)
+
+ def setValue (self, value):
+ return self.w.checkboxSetValue(value)
+
+ def __init__(self, text, isOn = 0):
+ self.w = _snack.checkbox(text, isOn)
+
+class SingleRadioButton(Widget):
+ """Single Radio Button.
+
+ methods:
+
+ - SingleRadioButton(text, group, isOn = 0) : create button
+ - selected(self) : returns bool, whether or not is selected.
+ """
+
+ def selected(self):
+        return self.w.key == self.w.radioValue
+
+ def __init__(self, text, group, isOn = 0):
+ if group:
+ self.w = _snack.radiobutton(text, group.w, isOn)
+ else:
+ self.w = _snack.radiobutton(text, None, isOn)
+
+class Listbox(Widget):
+ """Listbox class.
+
+ methods:
+
+ - Listbox(self, height, scroll = 0, returnExit = 0, width = 0, showCursor = 0, multiple = 0, border = 0)
+    - insert(self, text, item, before) : insert element; before = item to insert before, or None.
+ - delete(self, item) : delete item from list.
+ - replace(self, text,item) : Replace a given item's text
+ - current(self) : returns currently selected item
+ - getSelection(self) : returns a list of selected items
+    - setCurrent(self, item) : select the current item.
+ - clear(self) : clear listbox
+ """
+
+ def append(self, text, item):
+ key = self.w.listboxAddItem(text)
+ self.key2item[key] = item
+ self.item2key[item] = key
+
+ def insert(self, text, item, before):
+ if (not before):
+ key = self.w.listboxInsertItem(text, 0)
+ else:
+ key = self.w.listboxInsertItem(text, self.item2key[before])
+ self.key2item[key] = item
+ self.item2key[item] = key
+
+ def delete(self, item):
+ self.w.listboxDeleteItem(self.item2key[item])
+ del self.key2item[self.item2key[item]]
+ del self.item2key[item]
+
+ def replace(self, text, item):
+ key = self.w.listboxInsertItem(text, self.item2key[item])
+ self.w.listboxDeleteItem(self.item2key[item])
+ del self.key2item[self.item2key[item]]
+ self.item2key[item] = key
+ self.key2item[key] = item
+
+ def current(self):
+ return self.key2item[self.w.listboxGetCurrent()]
+
+ def getSelection(self):
+ selection = []
+ list = self.w.listboxGetSelection()
+ for key in list:
+ selection.append(self.key2item[key])
+ return selection
+
+ def setCurrent(self, item):
+ self.w.listboxSetCurrent(self.item2key[item])
+
+ def clear(self):
+ self.key2item = {}
+ self.item2key = {}
+ self.w.listboxClear()
+
+ def __init__(self, height, scroll = 0, returnExit = 0, width = 0, showCursor = 0, multiple = 0, border = 0):
+ self.w = _snack.listbox(height, scroll, returnExit, showCursor, multiple, border)
+ self.key2item = {}
+ self.item2key = {}
+ if (width):
+ self.w.listboxSetWidth(width)
+
+class Textbox(Widget):
+ """Textbox, container for text.
+
+ methods:
+
+    - Textbox(self, width, height, text, scroll = 0, wrap = 0): scroll and wrap are flags to
+      include scroll bars or wrap text.
+ - setText(text) : set text.
+ - setHeight(height): set height.
+ """
+
+ def setText(self, text):
+ self.w.textboxText(text)
+
+ def setHeight(self, height):
+ self.w.textboxHeight(height)
+
+ def __init__(self, width, height, text, scroll = 0, wrap = 0):
+ self.w = _snack.textbox(width, height, text, scroll, wrap)
+
+class TextboxReflowed(Textbox):
+
+ def __init__(self, width, text, flexDown = 5, flexUp = 10, maxHeight = -1):
+ (newtext, width, height) = reflow(text, width, flexDown, flexUp)
+ if maxHeight != -1 and height > maxHeight:
+ Textbox.__init__(self, width, maxHeight, newtext, 1)
+ else:
+ Textbox.__init__(self, width, height, newtext, 0)
+
+class Label(Widget):
+ """A Label (simple text).
+
+ methods:
+
+ - Label(self,text) : create label
+ - setText(self,text) : change text.
+ - setColors(self, colorset) : change individual colors
+ """
+ def setText(self, text):
+ self.w.labelText(text)
+
+ def __init__(self, text):
+ self.w = _snack.label(text)
+
+ def setColors(self, colorset):
+ self.w.labelSetColors(colorset)
+
+class Scale(Widget):
+ """A Scale (progress bar).
+
+ methods:
+
+    - Scale(self, width, total) : create scale; width: size on screen, total: integer amount when full.
+ - set(self,amount) : set amount to integer.
+ """
+ def set(self, amount):
+ self.w.scaleSet(amount)
+
+ def __init__(self, width, total):
+ self.w = _snack.scale(width, total)
+
+class Entry(Widget):
+ """Entry widget.
+
+ methods:
+
+ - Entry(self, width, text = "", hidden = 0, password = 0, scroll = 1, returnExit = 0)
+ constructor. hidden doesn't show text, password stars it out,
+ scroll includes scroll bars;
+ if returnExit is set, return from Form when exiting this element, else
+ proceed to next entry widget.
+ - value(self): return value.
+ - set(text, cursorAtEnd = 1) : set the text
+ - setFlags (flag, sense) : flags can be FLAG_DISABLED, FLAGS_SET, FLAGS_RESET, FLAGS_TOGGLE
+ """
+ def value(self):
+ return self.w.entryValue
+
+ def set(self, text, cursorAtEnd = 1):
+ return self.w.entrySetValue(text, cursorAtEnd)
+
+ def setFlags (self, flag, sense):
+ return self.w.entrySetFlags(flag, sense)
+
+ def __init__(self, width, text = "", hidden = 0, password = 0, scroll = 1,
+ returnExit = 0):
+ self.w = _snack.entry(width, text, hidden, password, scroll, returnExit)
+
+
+# Form uses hotkeys
+hotkeys = { "F1" : _snack.KEY_F1, "F2" : _snack.KEY_F2, "F3" : _snack.KEY_F3,
+ "F4" : _snack.KEY_F4, "F5" : _snack.KEY_F5, "F6" : _snack.KEY_F6,
+ "F7" : _snack.KEY_F7, "F8" : _snack.KEY_F8, "F9" : _snack.KEY_F9,
+ "F10" : _snack.KEY_F10, "F11" : _snack.KEY_F11,
+ "F12" : _snack.KEY_F12, "ESC" : _snack.KEY_ESC,
+ "ENTER": _snack.KEY_ENTER, "SUSPEND" : _snack.KEY_SUSPEND,
+ "BACKSPACE": _snack.KEY_BACKSPACE, "DELETE": _snack.KEY_DELETE,
+ "INSERT": _snack.KEY_INSERT, "RESIZE": _snack.KEY_RESIZE,
+ " " : ord(" ") }
+
+for n in list(hotkeys.keys()):
+ hotkeys[hotkeys[n]] = n
+for o,c in [ (ord(c),c) for c in string.ascii_letters+string.digits ]:
+ hotkeys[c] = o
+ hotkeys[o] = c
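+
+# The two loops above make the table bidirectional: a key name maps to its
+# NEWT key code and the code maps back to the name. Illustrative check (not
+# part of the original module):
+#
+#   assert hotkeys["F1"] == _snack.KEY_F1 and hotkeys[_snack.KEY_F1] == "F1"
+#   assert hotkeys["a"] == ord("a") and hotkeys[ord("a")] == "a"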
+
+class Form:
+ """ Base Form class, from which Grid, etc. inherit
+
+ methods:
+
+ - Form(self, helpArg = None) : constructor.
+ - addHotKey(self, keyname) : keynames of form "F1" through "F12", "ESC"
+ - add(self, widget) : Add a widget
+ - run(self): run a form, expecting input
+ - draw(self): draw form.
+ - setTimer(self, timer) : add a timer
+    - watchFile(self, file, flags) : watch a file object's descriptor for activity
+ - setCurrent (self, co): Set a given widget as the current focus
+ """
+ def addHotKey(self, keyname):
+ self.w.addhotkey(hotkeys[keyname])
+
+ def add(self, widget):
+ if 'hotkeys' in widget.__dict__:
+ for key in widget.hotkeys.keys():
+ self.addHotKey(key)
+
+ if 'gridmembers' in widget.__dict__:
+ for w in widget.gridmembers:
+ self.add(w)
+ elif 'w' in widget.__dict__:
+ self.trans[widget.w.key] = widget
+ return self.w.add(widget.w)
+ return None
+
+ def run(self):
+ (what, which) = self.w.run()
+ if (what == _snack.FORM_EXIT_WIDGET):
+ return self.trans[which]
+ elif (what == _snack.FORM_EXIT_TIMER):
+ return "TIMER"
+ elif (what == _snack.FORM_EXIT_FDREADY):
+ return self.filemap[which]
+ elif (what == _snack.FORM_EXIT_HOTKEY):
+ return hotkeys[which]
+ raise RuntimeError("EOF or IO error")
+
+ def draw(self):
+ self.w.draw()
+ return None
+
+ def __init__(self, helpArg = None):
+ self.trans = {}
+ self.filemap = {}
+ self.w = _snack.form(helpArg)
+ # we do the reference count for the helpArg in python! gross
+ self.helpArg = helpArg
+
+ def setCurrent (self, co):
+ self.w.setcurrent (co.w)
+
+ def setTimer (self, timer):
+ self.w.settimer (timer)
+
+ def watchFile (self, file, flags):
+ self.filemap[file.fileno()] = file
+ self.w.watchfd (file.fileno(), flags)
+
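+# Illustrative sketch of dispatching on the value returned by Form.run()
+# (not part of the original module); "form", "ok_button" and "sock" are
+# assumed to have been created and added by the caller:
+#
+#   result = form.run()
+#   if result is ok_button:      # a widget added with add() caused the exit
+#       ...
+#   elif result == "TIMER":      # the timer set with setTimer() expired
+#       ...
+#   elif result is sock:         # a file passed to watchFile() became ready
+#       ...
+#   else:                        # a hotkey name such as "F12" or "ESC"
+#       ...
+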
+class Grid:
+ """Grid class.
+
+ methods:
+
+ - place(self,x,y): Return what is placed at (x,y)
+ - setField(self, what, col, row, padding = (0, 0, 0, 0),
+ anchorLeft = 0, anchorTop = 0, anchorRight = 0,
+ anchorBottom = 0, growx = 0, growy = 0):
+ used to add widget 'what' to grid.
+    - Grid(self, *args): e.g. g = Grid(2, 3) for a 2x3 grid
+ """
+ def place(self, x, y):
+ return self.g.place(x, y)
+
+ def setField(self, what, col, row, padding = (0, 0, 0, 0),
+ anchorLeft = 0, anchorTop = 0, anchorRight = 0,
+ anchorBottom = 0, growx = 0, growy = 0):
+ self.gridmembers.append(what)
+ anchorFlags = 0
+ if (anchorLeft):
+ anchorFlags = _snack.ANCHOR_LEFT
+ elif (anchorRight):
+ anchorFlags = _snack.ANCHOR_RIGHT
+
+ if (anchorTop):
+ anchorFlags = anchorFlags | _snack.ANCHOR_TOP
+ elif (anchorBottom):
+ anchorFlags = anchorFlags | _snack.ANCHOR_BOTTOM
+
+ gridFlags = 0
+ if (growx):
+ gridFlags = _snack.GRID_GROWX
+ if (growy):
+ gridFlags = gridFlags | _snack.GRID_GROWY
+
+ if 'g' in what.__dict__:
+ return self.g.setfield(col, row, what.g, padding, anchorFlags,
+ gridFlags)
+ else:
+ return self.g.setfield(col, row, what.w, padding, anchorFlags)
+
+ def __init__(self, *args):
+ self.g = _snack.grid(*args)
+ self.gridmembers = []
+
+colorsets = { "ROOT" : _snack.COLORSET_ROOT,
+ "BORDER" : _snack.COLORSET_BORDER,
+ "WINDOW" : _snack.COLORSET_WINDOW,
+ "SHADOW" : _snack.COLORSET_SHADOW,
+ "TITLE" : _snack.COLORSET_TITLE,
+ "BUTTON" : _snack.COLORSET_BUTTON,
+ "ACTBUTTON" : _snack.COLORSET_ACTBUTTON,
+ "CHECKBOX" : _snack.COLORSET_CHECKBOX,
+ "ACTCHECKBOX" : _snack.COLORSET_ACTCHECKBOX,
+ "ENTRY" : _snack.COLORSET_ENTRY,
+ "LABEL" : _snack.COLORSET_LABEL,
+ "LISTBOX" : _snack.COLORSET_LISTBOX,
+ "ACTLISTBOX" : _snack.COLORSET_ACTLISTBOX,
+ "TEXTBOX" : _snack.COLORSET_TEXTBOX,
+ "ACTTEXTBOX" : _snack.COLORSET_ACTTEXTBOX,
+ "HELPLINE" : _snack.COLORSET_HELPLINE,
+ "ROOTTEXT" : _snack.COLORSET_ROOTTEXT,
+ "EMPTYSCALE" : _snack.COLORSET_EMPTYSCALE,
+ "FULLSCALE" : _snack.COLORSET_FULLSCALE,
+ "DISENTRY" : _snack.COLORSET_DISENTRY,
+ "COMPACTBUTTON" : _snack.COLORSET_COMPACTBUTTON,
+ "ACTSELLISTBOX" : _snack.COLORSET_ACTSELLISTBOX,
+ "SELLISTBOX" : _snack.COLORSET_SELLISTBOX }
+
+class SnackScreen:
+ """A Screen;
+
+ methods:
+
+ - Screen(self) : constructor
+ - finish(self)
+ - resume(self)
+ - suspend(self)
+    - doHelpCallback(self, arg) : call the help callback with arg
+    - helpCallback(self, cb) : set the help callback
+    - suspendCallback(self, cb, data = None) : set the suspend callback; data is passed to cb.
+ - openWindow(self,left, top, width, height, title): Open a window.
+    - pushHelpLine(self, text): push a help line onto the screen; pushes the default help line if text is None
+ - setColor(self, colorset, fg, bg): Set foreground and background colors;
+ colorset = key from snack.colorsets,
+ fg & bg = english color names defined by S-Lang
+ (ref: S-Lang Library C Programmer's Guide section:
+ 8.4.4. Setting Character Attributes)
+ """
+ def __init__(self):
+ _snack.init()
+ (self.width, self.height) = _snack.size()
+ self.pushHelpLine(None)
+
+ def finish(self):
+ return _snack.finish()
+
+ def resume(self):
+ _snack.resume()
+
+ def suspend(self):
+ _snack.suspend()
+
+ def doHelpCallback(self, arg):
+ self.helpCb(self, arg)
+
+ def helpCallback(self, cb):
+ self.helpCb = cb
+ return _snack.helpcallback(self.doHelpCallback)
+
+ def suspendCallback(self, cb, data = None):
+ if data:
+ return _snack.suspendcallback(cb, data)
+ return _snack.suspendcallback(cb)
+
+ def openWindow(self, left, top, width, height, title):
+ return _snack.openwindow(left, top, width, height, title)
+
+ def pushHelpLine(self, text):
+ if (not text):
+ return _snack.pushhelpline("*default*")
+ else:
+ return _snack.pushhelpline(text)
+
+ def popHelpLine(self):
+ return _snack.pophelpline()
+
+ def drawRootText(self, left, top, text):
+ return _snack.drawroottext(left, top, text)
+
+ def centeredWindow(self, width, height, title):
+ return _snack.centeredwindow(width, height, title)
+
+ def gridWrappedWindow(self, grid, title, x = None, y = None):
+ if x and y:
+ return _snack.gridwrappedwindow(grid.g, title, x, y)
+
+ return _snack.gridwrappedwindow(grid.g, title)
+
+ def popWindow(self, refresh = True):
+ if refresh:
+ return _snack.popwindow()
+ return _snack.popwindownorefresh()
+
+ def refresh(self):
+ return _snack.refresh()
+
+ def setColor(self, colorset, fg, bg):
+ if colorset in colorsets:
+ return _snack.setcolor(colorsets[colorset], fg, bg)
+ else:
+ # assume colorset is an integer for the custom color set
+ return _snack.setcolor(colorset, fg, bg)
+
+def reflow(text, width, flexDown = 5, flexUp = 5):
+ """ returns a tuple of the wrapped text, the actual width, and the actual height
+ """
+ return _snack.reflow(text, width, flexDown, flexUp)
+
+# combo widgets
+
+class RadioGroup(Widget):
+ """ Combo widget: Group of Radio buttons
+
+ methods:
+
+ - RadioGroup(self): constructor.
+ - add(self,title, value, default = None): add a button. Returns button.
+ - getSelection(self) : returns value of selected button | None
+ """
+ def __init__(self):
+ self.prev = None
+ self.buttonlist = []
+
+ def add(self, title, value, default = None):
+ if not self.prev and default == None:
+ # If the first element is not explicitly set to
+ # not be the default, make it be the default
+ default = 1
+ b = SingleRadioButton(title, self.prev, default)
+ self.prev = b
+ self.buttonlist.append((b, value))
+ return b
+
+ def getSelection(self):
+ for (b, value) in self.buttonlist:
+ if b.selected(): return value
+ return None
+
+
+class RadioBar(Grid):
+ """ Bar of Radio buttons, based on Grid.
+
+ methods:
+
+ - RadioBar(self, screen, buttonlist) : constructor.
+ - getSelection(self): return value of selected button
+ """
+
+ def __init__(self, screen, buttonlist):
+ self.list = []
+ self.item = 0
+ self.group = RadioGroup()
+ Grid.__init__(self, 1, len(buttonlist))
+ for (title, value, default) in buttonlist:
+ b = self.group.add(title, value, default)
+ self.list.append((b, value))
+ self.setField(b, 0, self.item, anchorLeft = 1)
+ self.item = self.item + 1
+
+ def getSelection(self):
+ return self.group.getSelection()
+
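+# Illustrative RadioBar usage sketch (not part of the original module); the
+# titles, values and the layout below are assumptions (GridForm and ButtonBar
+# are defined further down in this module):
+#
+#   rb = RadioBar(screen, [("Red", "red", 1), ("Blue", "blue", 0)])
+#   g = GridForm(screen, "Colour", 1, 2)
+#   g.add(rb, 0, 0)
+#   g.add(ButtonBar(screen, ["Ok"]), 0, 1, growx = 1)
+#   g.runOnce()
+#   chosen = rb.getSelection()   # "red" or "blue"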
+
+# you normally want to pack a ButtonBar with growx = 1
+
+class ButtonBar(Grid):
+ """ Bar of buttons, based on grid.
+
+ methods:
+
+    - ButtonBar(self, screen, buttonlist, compact = 0): constructor.
+ - buttonPressed(self, result): Takes the widget returned by Form.run and looks to see
+ if it was one of the widgets in the ButtonBar.
+ """
+ def __init__(self, screen, buttonlist, compact = 0):
+ self.list = []
+ self.hotkeys = {}
+ self.item = 0
+ Grid.__init__(self, len(buttonlist), 1)
+ for blist in buttonlist:
+ if isinstance(blist, str if sys.version >= '3' else basestring):
+ title = blist
+ value = blist.lower()
+ elif len(blist) == 2:
+ (title, value) = blist
+ else:
+ (title, value, hotkey) = blist
+ self.hotkeys[hotkey] = value
+
+ if compact:
+ b = CompactButton(title)
+ else:
+ b = Button(title)
+ self.list.append((b, value))
+ self.setField(b, self.item, 0, (1, 0, 1, 0))
+ self.item = self.item + 1
+
+ def buttonPressed(self, result):
+ if result in self.hotkeys:
+ return self.hotkeys[result]
+
+ for (button, value) in self.list:
+ if result == button:
+ return value
+ return None
+
+
+class GridFormHelp(Grid):
+ """ Subclass of Grid, for the help form text.
+
+ methods:
+
+ - GridFormHelp(self, screen, title, help, *args) :
+ - add (self, widget, col, row, padding = (0, 0, 0, 0),
+ anchorLeft = 0, anchorTop = 0, anchorRight = 0,
+ anchorBottom = 0, growx = 0, growy = 0):
+    - runOnce(self, x = None, y = None): create the window, run the form once, then pop the window
+ - addHotKey(self, keyname):
+    - setTimer(self, timer):
+ - create(self, x = None, y = None):
+ - run(self, x = None, y = None):
+ - draw(self):
+ - runPopup(self):
+ - setCurrent (self, co):
+ """
+ def __init__(self, screen, title, help, *args):
+ self.screen = screen
+ self.title = title
+ self.form = Form(help)
+ self.childList = []
+ self.form_created = 0
+ args = list(args)
+ args[:0] = [self]
+ Grid.__init__(*tuple(args))
+
+ def add(self, widget, col, row, padding = (0, 0, 0, 0),
+ anchorLeft = 0, anchorTop = 0, anchorRight = 0,
+ anchorBottom = 0, growx = 0, growy = 0):
+ self.setField(widget, col, row, padding, anchorLeft,
+ anchorTop, anchorRight, anchorBottom,
+                      growx, growy)
+ self.childList.append(widget)
+
+ def runOnce(self, x = None, y = None):
+ result = self.run(x, y)
+ self.screen.popWindow()
+ return result
+
+ def addHotKey(self, keyname):
+ self.form.addHotKey(keyname)
+
+ def setTimer(self, keyname):
+ self.form.setTimer(keyname)
+
+ def create(self, x = None, y = None):
+ if not self.form_created:
+ self.place(1,1)
+ for child in self.childList:
+ self.form.add(child)
+ self.screen.gridWrappedWindow(self, self.title, x, y)
+ self.form_created = 1
+
+ def run(self, x = None, y = None):
+ self.create(x, y)
+ return self.form.run()
+
+ def draw(self):
+ self.create()
+ return self.form.draw()
+
+ def runPopup(self):
+ self.create()
+ self.screen.gridWrappedWindow(self, self.title)
+ result = self.form.run()
+ self.screen.popWindow()
+ return result
+
+ def setCurrent (self, co):
+ self.form.setCurrent (co)
+
+class GridForm(GridFormHelp):
+ """ GridForm class (extends GridFormHelp):
+
+ methods:
+
+ - GridForm(self, screen, title, *args):
+ """
+ def __init__(self, screen, title, *args):
+ myargs = (self, screen, title, None) + args
+ GridFormHelp.__init__(*myargs)
+
+class CheckboxTree(Widget):
+ """ CheckboxTree combo widget,
+
+ methods:
+
+ - CheckboxTree(self, height, scroll = 0, width = None, hide_checkbox = 0, unselectable = 0)
+ constructor.
+ - append(self, text, item = None, selected = 0):
+ - addItem(self, text, path, item = None, selected = 0):
+ - getCurrent(self):
+ - getSelection(self):
+ - setEntry(self, item, text):
+ - setCurrent(self, item):
+ - setEntryValue(self, item, selected = 1):
+ - getEntryValue(self, item):
+ """
+ def append(self, text, item = None, selected = 0):
+ self.addItem(text, (snackArgs['append'], ), item, selected)
+
+ def addItem(self, text, path, item = None, selected = 0):
+ if item is None:
+ item = text
+ key = self.w.checkboxtreeAddItem(text, path, selected)
+ self.key2item[key] = item
+ self.item2key[item] = key
+
+ def getCurrent(self):
+ curr = self.w.checkboxtreeGetCurrent()
+ return self.key2item[curr]
+
+ def __init__(self, height, scroll = 0, width = None, hide_checkbox = 0, unselectable = 0):
+ self.w = _snack.checkboxtree(height, scroll, hide_checkbox, unselectable)
+ self.key2item = {}
+ self.item2key = {}
+ if (width):
+ self.w.checkboxtreeSetWidth(width)
+
+ def getSelection(self):
+ selection = []
+ list = self.w.checkboxtreeGetSelection()
+ for key in list:
+ selection.append(self.key2item[key])
+ return selection
+
+ def setEntry(self, item, text):
+ self.w.checkboxtreeSetEntry(self.item2key[item], text)
+
+ def setCurrent(self, item):
+ self.w.checkboxtreeSetCurrent(self.item2key[item])
+
+ def setEntryValue(self, item, selected = 1):
+ self.w.checkboxtreeSetEntryValue(self.item2key[item], selected)
+
+ def getEntryValue(self, item):
+ return self.w.checkboxtreeGetEntryValue(self.item2key[item])
+
+def ListboxChoiceWindow(screen, title, text, items,
+ buttons = ('Ok', 'Cancel'),
+ width = 40, scroll = 0, height = -1, default = None,
+ help = None):
+ """
+ - ListboxChoiceWindow(screen, title, text, items,
+ buttons = ('Ok', 'Cancel'),
+ width = 40, scroll = 0, height = -1, default = None,
+ help = None):
+ """
+ if (height == -1): height = len(items)
+
+ bb = ButtonBar(screen, buttons)
+ t = TextboxReflowed(width, text)
+ l = Listbox(height, scroll = scroll, returnExit = 1)
+ count = 0
+ for item in items:
+ if type(item) == tuple:
+ (text, key) = item
+ else:
+ text = item
+ key = count
+
+ if (default == count):
+ default = key
+ elif (default == item):
+ default = key
+
+ l.append(text, key)
+ count = count + 1
+
+ if (default != None):
+ l.setCurrent (default)
+
+ g = GridFormHelp(screen, title, help, 1, 3)
+ g.add(t, 0, 0)
+ g.add(l, 0, 1, padding = (0, 1, 0, 1))
+ g.add(bb, 0, 2, growx = 1)
+
+ rc = g.runOnce()
+
+ return (bb.buttonPressed(rc), l.current())
+
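+# Illustrative ListboxChoiceWindow usage sketch (not part of the original
+# module); the title, text and items below are assumptions:
+#
+#   button, choice = ListboxChoiceWindow(screen, "Fruit", "Pick one:",
+#                                        [("Apple", "apple"),
+#                                         ("Pear", "pear")])
+#   # button is the ButtonBar value ("ok", "cancel" or None);
+#   # choice is the key of the selected item ("apple" or "pear")
+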
+def ButtonChoiceWindow(screen, title, text,
+ buttons = [ 'Ok', 'Cancel' ],
+ width = 40, x = None, y = None, help = None):
+ """
+ - ButtonChoiceWindow(screen, title, text,
+ buttons = [ 'Ok', 'Cancel' ],
+ width = 40, x = None, y = None, help = None):
+ """
+ bb = ButtonBar(screen, buttons)
+ t = TextboxReflowed(width, text, maxHeight = screen.height - 12)
+
+ g = GridFormHelp(screen, title, help, 1, 2)
+ g.add(t, 0, 0, padding = (0, 0, 0, 1))
+ g.add(bb, 0, 1, growx = 1)
+ return bb.buttonPressed(g.runOnce(x, y))
+
+def EntryWindow(screen, title, text, prompts, allowCancel = 1, width = 40,
+ entryWidth = 20, buttons = [ 'Ok', 'Cancel' ], help = None):
+ """
+ EntryWindow(screen, title, text, prompts, allowCancel = 1, width = 40,
+ entryWidth = 20, buttons = [ 'Ok', 'Cancel' ], help = None):
+ """
+    bb = ButtonBar(screen, buttons)
+ t = TextboxReflowed(width, text)
+
+ count = 0
+ for n in prompts:
+ count = count + 1
+
+ sg = Grid(2, count)
+
+ count = 0
+ entryList = []
+ for n in prompts:
+ if type(n) == tuple:
+ (n, e) = n
+ if isinstance(e, str if sys.version >= '3' else basestring):
+ e = Entry(entryWidth, e)
+ else:
+ e = Entry(entryWidth)
+
+ sg.setField(Label(n), 0, count, padding = (0, 0, 1, 0), anchorLeft = 1)
+ sg.setField(e, 1, count, anchorLeft = 1)
+ count = count + 1
+ entryList.append(e)
+
+ g = GridFormHelp(screen, title, help, 1, 3)
+
+ g.add(t, 0, 0, padding = (0, 0, 0, 1))
+ g.add(sg, 0, 1, padding = (0, 0, 0, 1))
+ g.add(bb, 0, 2, growx = 1)
+
+ result = g.runOnce()
+
+ entryValues = []
+ count = 0
+ for n in prompts:
+ entryValues.append(entryList[count].value())
+ count = count + 1
+
+ return (bb.buttonPressed(result), tuple(entryValues))
+
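+# Illustrative EntryWindow usage sketch (not part of the original module);
+# the title, text and prompts below are assumptions:
+#
+#   button, (user, host) = EntryWindow(screen, "Login", "Enter details:",
+#                                      ["Username", ("Host", "localhost")])
+#   # a (prompt, default) tuple pre-fills its entry; a plain string starts empty
+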
+class CListbox(Grid):
+ """Clistbox convenience class.
+
+ methods:
+
+    - CListbox(self, height, cols, col_widths, scroll = 0) : constructor
+ - colFormText(self, col_text, align = None, adjust_width = 0) : column text.
+ - append(self, col_text, item, col_text_align = None) :
+ - insert(self, col_text, item, before, col_text_align = None)
+ - delete(self, item)
+ - replace(self, col_text, item, col_text_align = None)
+ - current(self) : returns current item
+ - setCurrent(self, item): sets an item as current
+ - clear(self): clear the listbox
+
+ Alignments may be LEFT, RIGHT, CENTER, None
+ """
+ def __init__(self, height, cols, col_widths, scroll = 0,
+ returnExit = 0, width = 0, col_pad = 1,
+ col_text_align = None, col_labels = None,
+ col_label_align = None, adjust_width=0):
+
+ self.cols = cols
+ self.col_widths = col_widths[:]
+ self.col_pad = col_pad
+ self.col_text_align = col_text_align
+
+ if col_labels != None:
+ Grid.__init__(self, 1, 2)
+ box_y = 1
+
+ lstr = self.colFormText(col_labels, col_label_align,
+ adjust_width=adjust_width)
+ self.label = Label(lstr)
+ self.setField(self.label, 0, 0, anchorLeft=1)
+
+ else:
+ Grid.__init__(self, 1, 1)
+ box_y = 0
+
+
+ self.listbox = Listbox(height, scroll, returnExit, width)
+ self.setField(self.listbox, 0, box_y, anchorRight=1)
+
+ def colFormText(self, col_text, align = None, adjust_width=0):
+ i = 0
+ str = ""
+ c_len = len(col_text)
+ while (i < self.cols) and (i < c_len):
+
+ cstr = col_text[i]
+ cstrlen = _snack.wstrlen(cstr)
+ if self.col_widths[i] < cstrlen:
+ if adjust_width:
+ self.col_widths[i] = cstrlen
+ else:
+ cstr = cstr[:self.col_widths[i]]
+
+ delta = self.col_widths[i] - _snack.wstrlen(cstr)
+
+ if delta > 0:
+ if align == None:
+ a = LEFT
+ else:
+ a = align[i]
+
+ if a == LEFT:
+ cstr = cstr + (" " * delta)
+ if a == CENTER:
+ cstr = (" " * (delta / 2)) + cstr + \
+ (" " * ((delta + 1) / 2))
+ if a == RIGHT:
+ cstr = (" " * delta) + cstr
+
+ if i != c_len - 1:
+ pstr = (" " * self.col_pad)
+ else:
+ pstr = ""
+
+ str = str + cstr + pstr
+
+ i = i + 1
+
+ return str
+
+ def append(self, col_text, item, col_text_align = None):
+ if col_text_align == None:
+ col_text_align = self.col_text_align
+ text = self.colFormText(col_text, col_text_align)
+ self.listbox.append(text, item)
+
+ def insert(self, col_text, item, before, col_text_align = None):
+ if col_text_align == None:
+ col_text_align = self.col_text_align
+ text = self.colFormText(col_text, col_text_align)
+ self.listbox.insert(text, item, before)
+
+ def delete(self, item):
+ self.listbox.delete(item)
+
+ def replace(self, col_text, item, col_text_align = None):
+ if col_text_align == None:
+ col_text_align = self.col_text_align
+ text = self.colFormText(col_text, col_text_align)
+ self.listbox.replace(text, item)
+
+ def current(self):
+ return self.listbox.current()
+
+ def setCurrent(self, item):
+ self.listbox.setCurrent(item)
+
+ def clear(self):
+ self.listbox.clear()
+
+def customColorset(x):
+ return 30 + x
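+
+# Minimal end-to-end sketch of the convenience classes above (not part of the
+# original module); the window title, text and button label are assumptions:
+#
+#   screen = SnackScreen()
+#   try:
+#       g = GridForm(screen, "Hello", 1, 2)
+#       g.add(Textbox(30, 2, "Hello from snack", scroll = 0), 0, 0)
+#       g.add(ButtonBar(screen, ["Ok"]), 0, 1, growx = 1)
+#       g.runOnce()
+#   finally:
+#       screen.finish()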
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
new file mode 100644
index 0000000..3cae9f5
--- /dev/null
+++ b/lib/sqlalchemy/__init__.py
@@ -0,0 +1,158 @@
+# sqlalchemy/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import util as _util
+from .engine import create_engine
+from .engine import create_mock_engine
+from .engine import engine_from_config
+from .inspection import inspect
+from .schema import BLANK_SCHEMA
+from .schema import CheckConstraint
+from .schema import Column
+from .schema import ColumnDefault
+from .schema import Computed
+from .schema import Constraint
+from .schema import DDL
+from .schema import DefaultClause
+from .schema import FetchedValue
+from .schema import ForeignKey
+from .schema import ForeignKeyConstraint
+from .schema import Identity
+from .schema import Index
+from .schema import MetaData
+from .schema import PrimaryKeyConstraint
+from .schema import Sequence
+from .schema import Table
+from .schema import ThreadLocalMetaData
+from .schema import UniqueConstraint
+from .sql import alias
+from .sql import all_
+from .sql import and_
+from .sql import any_
+from .sql import asc
+from .sql import between
+from .sql import bindparam
+from .sql import case
+from .sql import cast
+from .sql import collate
+from .sql import column
+from .sql import delete
+from .sql import desc
+from .sql import distinct
+from .sql import except_
+from .sql import except_all
+from .sql import exists
+from .sql import extract
+from .sql import false
+from .sql import func
+from .sql import funcfilter
+from .sql import insert
+from .sql import intersect
+from .sql import intersect_all
+from .sql import join
+from .sql import LABEL_STYLE_DEFAULT
+from .sql import LABEL_STYLE_DISAMBIGUATE_ONLY
+from .sql import LABEL_STYLE_NONE
+from .sql import LABEL_STYLE_TABLENAME_PLUS_COL
+from .sql import lambda_stmt
+from .sql import lateral
+from .sql import literal
+from .sql import literal_column
+from .sql import modifier
+from .sql import not_
+from .sql import null
+from .sql import nulls_first
+from .sql import nulls_last
+from .sql import nullsfirst
+from .sql import nullslast
+from .sql import or_
+from .sql import outerjoin
+from .sql import outparam
+from .sql import over
+from .sql import select
+from .sql import subquery
+from .sql import table
+from .sql import tablesample
+from .sql import text
+from .sql import true
+from .sql import tuple_
+from .sql import type_coerce
+from .sql import union
+from .sql import union_all
+from .sql import update
+from .sql import values
+from .sql import within_group
+from .types import ARRAY
+from .types import BIGINT
+from .types import BigInteger
+from .types import BINARY
+from .types import BLOB
+from .types import BOOLEAN
+from .types import Boolean
+from .types import CHAR
+from .types import CLOB
+from .types import DATE
+from .types import Date
+from .types import DATETIME
+from .types import DateTime
+from .types import DECIMAL
+from .types import Enum
+from .types import FLOAT
+from .types import Float
+from .types import INT
+from .types import INTEGER
+from .types import Integer
+from .types import Interval
+from .types import JSON
+from .types import LargeBinary
+from .types import NCHAR
+from .types import NUMERIC
+from .types import Numeric
+from .types import NVARCHAR
+from .types import PickleType
+from .types import REAL
+from .types import SMALLINT
+from .types import SmallInteger
+from .types import String
+from .types import TEXT
+from .types import Text
+from .types import TIME
+from .types import Time
+from .types import TIMESTAMP
+from .types import TupleType
+from .types import TypeDecorator
+from .types import Unicode
+from .types import UnicodeText
+from .types import VARBINARY
+from .types import VARCHAR
+
+
+__version__ = "1.4.40"
+
+
+def __go(lcls):
+ global __all__
+
+ from . import events
+ from . import util as _sa_util
+
+ import inspect as _inspect
+
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
+
+ _sa_util.preloaded.import_prefix("sqlalchemy")
+
+ from . import exc
+
+ exc._version_token = "".join(__version__.split(".")[0:2])
+
+
+__go(locals())
diff --git a/lib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so b/lib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..58e90e7
--- /dev/null
+++ b/lib/sqlalchemy/cimmutabledict.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/sqlalchemy/connectors/__init__.py b/lib/sqlalchemy/connectors/__init__.py
new file mode 100644
index 0000000..e738086
--- /dev/null
+++ b/lib/sqlalchemy/connectors/__init__.py
@@ -0,0 +1,10 @@
+# connectors/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+class Connector(object):
+ pass
diff --git a/lib/sqlalchemy/connectors/mxodbc.py b/lib/sqlalchemy/connectors/mxodbc.py
new file mode 100644
index 0000000..89b3484
--- /dev/null
+++ b/lib/sqlalchemy/connectors/mxodbc.py
@@ -0,0 +1,166 @@
+# connectors/mxodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+Provide a SQLAlchemy connector for the eGenix mxODBC commercial
+Python adapter for ODBC. This is not a free product, but eGenix
+provides SQLAlchemy with a license for use in continuous integration
+testing.
+
+This has been tested for use with mxODBC 3.1.2 on SQL Server 2005
+and 2008, using the SQL Server Native driver. However, it is
+possible for this to be used on other database platforms.
+
+For more info on mxODBC, see https://www.egenix.com/
+
+.. deprecated:: 1.4 The mxODBC DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to mssql.
+
+"""
+
+import re
+import sys
+import warnings
+
+from . import Connector
+from ..util import warn_deprecated
+
+
+class MxODBCConnector(Connector):
+ driver = "mxodbc"
+
+ supports_sane_multi_rowcount = False
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ supports_native_decimal = True
+
+ @classmethod
+ def dbapi(cls):
+ # this classmethod will normally be replaced by an instance
+ # attribute of the same name, so this is normally only called once.
+ cls._load_mx_exceptions()
+ platform = sys.platform
+ if platform == "win32":
+ from mx.ODBC import Windows as Module
+ # this can be the string "linux2", and possibly others
+ elif "linux" in platform:
+ from mx.ODBC import unixODBC as Module
+ elif platform == "darwin":
+ from mx.ODBC import iODBC as Module
+ else:
+ raise ImportError("Unrecognized platform for mxODBC import")
+
+ warn_deprecated(
+            "The mxODBC DBAPI is deprecated and will be removed "
+            "in a future version. Please use one of the supported DBAPIs to "
+            "connect to mssql.",
+ version="1.4",
+ )
+ return Module
+
+ @classmethod
+ def _load_mx_exceptions(cls):
+ """Import mxODBC exception classes into the module namespace,
+ as if they had been imported normally. This is done here
+ to avoid requiring all SQLAlchemy users to install mxODBC.
+ """
+ global InterfaceError, ProgrammingError
+ from mx.ODBC import InterfaceError
+ from mx.ODBC import ProgrammingError
+
+ def on_connect(self):
+ def connect(conn):
+ conn.stringformat = self.dbapi.MIXED_STRINGFORMAT
+ conn.datetimeformat = self.dbapi.PYDATETIME_DATETIMEFORMAT
+ conn.decimalformat = self.dbapi.DECIMAL_DECIMALFORMAT
+ conn.errorhandler = self._error_handler()
+
+ return connect
+
+ def _error_handler(self):
+ """Return a handler that adjusts mxODBC's raised Warnings to
+ emit Python standard warnings.
+ """
+ from mx.ODBC.Error import Warning as MxOdbcWarning
+
+ def error_handler(connection, cursor, errorclass, errorvalue):
+ if issubclass(errorclass, MxOdbcWarning):
+ errorclass.__bases__ = (Warning,)
+ warnings.warn(
+ message=str(errorvalue), category=errorclass, stacklevel=2
+ )
+ else:
+ raise errorclass(errorvalue)
+
+ return error_handler
+
+ def create_connect_args(self, url):
+ r"""Return a tuple of \*args, \**kwargs for creating a connection.
+
+ The mxODBC 3.x connection constructor looks like this:
+
+ connect(dsn, user='', password='',
+ clear_auto_commit=1, errorhandler=None)
+
+ This method translates the values in the provided URI
+ into args and kwargs needed to instantiate an mxODBC Connection.
+
+ The arg 'errorhandler' is not used by SQLAlchemy and will
+ not be populated.
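+
+        For example (an illustrative sketch), a URL such as
+        ``mssql+mxodbc://user:pass@mydsn`` would yield roughly::
+
+            ('mydsn',), {'user': 'user', 'password': 'pass'}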
+
+ """
+ opts = url.translate_connect_args(username="user")
+ opts.update(url.query)
+ args = opts.pop("host")
+ opts.pop("port", None)
+ opts.pop("database", None)
+ return (args,), opts
+
+ def is_disconnect(self, e, connection, cursor):
+        # TODO: eGenix recommends checking connection.closed here;
+        # does that detect dropped connections?
+ if isinstance(e, self.dbapi.ProgrammingError):
+ return "connection already closed" in str(e)
+ elif isinstance(e, self.dbapi.Error):
+ return "[08S01]" in str(e)
+ else:
+ return False
+
+ def _get_server_version_info(self, connection):
+ # eGenix suggests using conn.dbms_version instead
+ # of what we're doing here
+ dbapi_con = connection.connection
+ version = []
+ r = re.compile(r"[.\-]")
+ # 18 == pyodbc.SQL_DBMS_VER
+ for n in r.split(dbapi_con.getinfo(18)[1]):
+ try:
+ version.append(int(n))
+ except ValueError:
+ version.append(n)
+ return tuple(version)
+
+ def _get_direct(self, context):
+ if context:
+ native_odbc_execute = context.execution_options.get(
+ "native_odbc_execute", "auto"
+ )
+            # default to direct=True in all cases, which is more generally
+            # compatible, especially with SQL Server
+ return False if native_odbc_execute is True else True
+ else:
+ return True
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ cursor.executemany(
+ statement, parameters, direct=self._get_direct(context)
+ )
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ cursor.execute(statement, parameters, direct=self._get_direct(context))
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
new file mode 100644
index 0000000..9bb67b5
--- /dev/null
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -0,0 +1,193 @@
+# connectors/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from . import Connector
+from .. import util
+
+
+class PyODBCConnector(Connector):
+ driver = "pyodbc"
+
+ # this is no longer False for pyodbc in general
+ supports_sane_rowcount_returning = True
+ supports_sane_multi_rowcount = False
+
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ supports_native_decimal = True
+ default_paramstyle = "named"
+
+ use_setinputsizes = False
+
+ # for non-DSN connections, this *may* be used to
+ # hold the desired driver name
+ pyodbc_driver_name = None
+
+ def __init__(
+ self, supports_unicode_binds=None, use_setinputsizes=False, **kw
+ ):
+ super(PyODBCConnector, self).__init__(**kw)
+ if supports_unicode_binds is not None:
+ self.supports_unicode_binds = supports_unicode_binds
+ self.use_setinputsizes = use_setinputsizes
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("pyodbc")
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ opts.update(url.query)
+
+ keys = opts
+
+ query = url.query
+
+ connect_args = {}
+ for param in ("ansi", "unicode_results", "autocommit"):
+ if param in keys:
+ connect_args[param] = util.asbool(keys.pop(param))
+
+ if "odbc_connect" in keys:
+ connectors = [util.unquote_plus(keys.pop("odbc_connect"))]
+ else:
+
+ def check_quote(token):
+ if ";" in str(token) or str(token).startswith("{"):
+ token = "{%s}" % token.replace("}", "}}")
+ return token
+
+ keys = dict((k, check_quote(v)) for k, v in keys.items())
+
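+            # illustrative examples of what the logic below renders:
+            # a DSN-style URL such as mssql+pyodbc://user:pass@some_dsn
+            # becomes "dsn=some_dsn;UID=user;PWD=pass", while a hostname-style
+            # URL such as
+            # mssql+pyodbc://user:pass@host:1433/mydb?driver=SQL+Server
+            # becomes roughly
+            # "DRIVER={SQL Server};Server=host,1433;Database=mydb;UID=user;PWD=pass"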
+ dsn_connection = "dsn" in keys or (
+ "host" in keys and "database" not in keys
+ )
+ if dsn_connection:
+ connectors = [
+ "dsn=%s" % (keys.pop("host", "") or keys.pop("dsn", ""))
+ ]
+ else:
+ port = ""
+ if "port" in keys and "port" not in query:
+ port = ",%d" % int(keys.pop("port"))
+
+ connectors = []
+ driver = keys.pop("driver", self.pyodbc_driver_name)
+ if driver is None and keys:
+ # note if keys is empty, this is a totally blank URL
+ util.warn(
+ "No driver name specified; "
+ "this is expected by PyODBC when using "
+ "DSN-less connections"
+ )
+ else:
+ connectors.append("DRIVER={%s}" % driver)
+
+ connectors.extend(
+ [
+ "Server=%s%s" % (keys.pop("host", ""), port),
+ "Database=%s" % keys.pop("database", ""),
+ ]
+ )
+
+ user = keys.pop("user", None)
+ if user:
+ connectors.append("UID=%s" % user)
+ pwd = keys.pop("password", "")
+ if pwd:
+ connectors.append("PWD=%s" % pwd)
+ else:
+ authentication = keys.pop("authentication", None)
+ if authentication:
+ connectors.append("Authentication=%s" % authentication)
+ else:
+ connectors.append("Trusted_Connection=Yes")
+
+ # if set to 'Yes', the ODBC layer will try to automagically
+ # convert textual data from your database encoding to your
+ # client encoding. This should obviously be set to 'No' if
+ # you query a cp1253 encoded database from a latin1 client...
+ if "odbc_autotranslate" in keys:
+ connectors.append(
+ "AutoTranslate=%s" % keys.pop("odbc_autotranslate")
+ )
+
+ connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()])
+
+ return [[";".join(connectors)], connect_args]
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.ProgrammingError):
+ return "The cursor's connection has been closed." in str(
+ e
+ ) or "Attempt to use a closed connection." in str(e)
+ else:
+ return False
+
+ def _dbapi_version(self):
+ if not self.dbapi:
+ return ()
+ return self._parse_dbapi_version(self.dbapi.version)
+
+ def _parse_dbapi_version(self, vers):
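+        # e.g. "4.0.34" parses to (4, 0, 34); an illustrative string such as
+        # "py3-3.0.1-beta" would parse to (3, 0, 1, "beta")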
+ m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers)
+ if not m:
+ return ()
+ vers = tuple([int(x) for x in m.group(1).split(".")])
+ if m.group(2):
+ vers += (m.group(2),)
+ return vers
+
+ def _get_server_version_info(self, connection, allow_chars=True):
+ # NOTE: this function is not reliable, particularly when
+ # freetds is in use. Implement database-specific server version
+ # queries.
+ dbapi_con = connection.connection
+ version = []
+ r = re.compile(r"[.\-]")
+ for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
+ try:
+ version.append(int(n))
+ except ValueError:
+ if allow_chars:
+ version.append(n)
+ return tuple(version)
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+        # the rules for these types seem a little strange, as you can pass
+        # non-tuples as well as tuples; however, it seems to assume "0"
+        # for the subsequent values if you don't pass a tuple, which fails
+        # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
+        # that ticket #5649 is targeting.
+
+ # NOTE: as of #6058, this won't be called if the use_setinputsizes flag
+ # is False, or if no types were specified in list_of_tuples
+
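+        # illustrative example: an entry whose dbtype is pyodbc.SQL_WLONGVARCHAR
+        # is passed along as (pyodbc.SQL_WLONGVARCHAR, None, None); entries whose
+        # dbtype is already a tuple are passed through unchanged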
+ cursor.setinputsizes(
+ [
+ (dbtype, None, None)
+ if not isinstance(dbtype, tuple)
+ else dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ ]
+ )
+
+ def set_isolation_level(self, connection, level):
+ # adjust for ConnectionFairy being present
+ # allows attribute set e.g. "connection.autocommit = True"
+ # to work properly
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+ super(PyODBCConnector, self).set_isolation_level(connection, level)
diff --git a/lib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so b/lib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..f2b7b00
--- /dev/null
+++ b/lib/sqlalchemy/cprocessors.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so b/lib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so
new file mode 100755
index 0000000..0d851bd
--- /dev/null
+++ b/lib/sqlalchemy/cresultproxy.cpython-39-x86_64-linux-gnu.so
Binary files differ
diff --git a/lib/sqlalchemy/databases/__init__.py b/lib/sqlalchemy/databases/__init__.py
new file mode 100644
index 0000000..fa83229
--- /dev/null
+++ b/lib/sqlalchemy/databases/__init__.py
@@ -0,0 +1,38 @@
+# databases/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Include imports from the sqlalchemy.dialects package for backwards
+compatibility with pre 0.6 versions.
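+
+For example (illustrative)::
+
+    from sqlalchemy.databases import postgres  # same as dialects.postgresql.base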
+
+"""
+from ..dialects.firebird import base as firebird
+from ..dialects.mssql import base as mssql
+from ..dialects.mysql import base as mysql
+from ..dialects.oracle import base as oracle
+from ..dialects.postgresql import base as postgresql
+from ..dialects.sqlite import base as sqlite
+from ..dialects.sybase import base as sybase
+from ..util import warn_deprecated_20
+
+postgres = postgresql
+
+
+__all__ = (
+ "firebird",
+ "mssql",
+ "mysql",
+ "postgresql",
+ "sqlite",
+ "oracle",
+ "sybase",
+)
+
+
+warn_deprecated_20(
+    "The `databases` package is deprecated and will be removed in v2.0 "
+ "of sqlalchemy. Use the `dialects` package instead."
+)
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py
new file mode 100644
index 0000000..84a9ad8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/__init__.py
@@ -0,0 +1,72 @@
+# dialects/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+__all__ = (
+ "firebird",
+ "mssql",
+ "mysql",
+ "oracle",
+ "postgresql",
+ "sqlite",
+ "sybase",
+)
+
+
+from .. import util
+
+
+def _auto_fn(name):
+ """default dialect importer.
+
+ plugs into the :class:`.PluginLoader`
+ as a first-hit system.
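+
+    For example (illustrative)::
+
+        loader = _auto_fn("postgresql.psycopg2")
+        dialect_cls = loader()  # -> sqlalchemy.dialects.postgresql.psycopg2.dialect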
+
+ """
+ if "." in name:
+ dialect, driver = name.split(".")
+ else:
+ dialect = name
+ driver = "base"
+
+ try:
+ if dialect == "firebird":
+ try:
+ module = __import__("sqlalchemy_firebird")
+ except ImportError:
+ module = __import__("sqlalchemy.dialects.firebird").dialects
+ module = getattr(module, dialect)
+ elif dialect == "sybase":
+ try:
+ module = __import__("sqlalchemy_sybase")
+ except ImportError:
+ module = __import__("sqlalchemy.dialects.sybase").dialects
+ module = getattr(module, dialect)
+ elif dialect == "mariadb":
+ # it's "OK" for us to hardcode here since _auto_fn is already
+ # hardcoded. if mysql / mariadb etc were third party dialects
+ # they would just publish all the entrypoints, which would actually
+ # look much nicer.
+ module = __import__(
+ "sqlalchemy.dialects.mysql.mariadb"
+ ).dialects.mysql.mariadb
+ return module.loader(driver)
+ else:
+ module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects
+ module = getattr(module, dialect)
+ except ImportError:
+ return None
+
+ if hasattr(module, driver):
+ module = getattr(module, driver)
+ return lambda: module.dialect
+ else:
+ return None
+
+
+registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
+
+plugins = util.PluginLoader("sqlalchemy.plugins")
diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py
new file mode 100644
index 0000000..a34eecf
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/__init__.py
@@ -0,0 +1,41 @@
+# firebird/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from sqlalchemy.dialects.firebird.base import BIGINT
+from sqlalchemy.dialects.firebird.base import BLOB
+from sqlalchemy.dialects.firebird.base import CHAR
+from sqlalchemy.dialects.firebird.base import DATE
+from sqlalchemy.dialects.firebird.base import FLOAT
+from sqlalchemy.dialects.firebird.base import NUMERIC
+from sqlalchemy.dialects.firebird.base import SMALLINT
+from sqlalchemy.dialects.firebird.base import TEXT
+from sqlalchemy.dialects.firebird.base import TIME
+from sqlalchemy.dialects.firebird.base import TIMESTAMP
+from sqlalchemy.dialects.firebird.base import VARCHAR
+from . import base # noqa
+from . import fdb # noqa
+from . import kinterbasdb # noqa
+
+
+base.dialect = dialect = fdb.dialect
+
+__all__ = (
+ "SMALLINT",
+ "BIGINT",
+    "FLOAT",
+    "DATE",
+    "TIME",
+    "TEXT",
+    "NUMERIC",
+ "TIMESTAMP",
+ "VARCHAR",
+ "CHAR",
+ "BLOB",
+ "dialect",
+)
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
new file mode 100644
index 0000000..e2698b1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -0,0 +1,989 @@
+# firebird/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: firebird
+ :name: Firebird
+
+.. note::
+
+ The Firebird dialect within SQLAlchemy **is not currently supported**.
+ It is not tested within continuous integration and is likely to have
+ many issues and caveats not currently handled. Consider using the
+ `external dialect <https://github.com/pauldex/sqlalchemy-firebird>`_
+ instead.
+
+.. deprecated:: 1.4 The internal Firebird dialect is deprecated and will be
+ removed in a future version. Use the external dialect.
+
+Firebird Dialects
+-----------------
+
+Firebird offers two distinct dialects_ (not to be confused with a
+SQLAlchemy ``Dialect``):
+
+dialect 1
+ This is the old syntax and behaviour, inherited from Interbase pre-6.0.
+
+dialect 3
+ This is the newer and supported syntax, introduced in Interbase 6.0.
+
+The SQLAlchemy Firebird dialect detects these versions and
+adjusts its representation of SQL accordingly. However,
+support for dialect 1 is not well tested and probably has
+incompatibilities.
+
+Locking Behavior
+----------------
+
+Firebird locks tables aggressively. For this reason, a DROP TABLE may
+hang until other transactions are released. SQLAlchemy does its best
+to release transactions as quickly as possible. The most common cause
+of hanging transactions is a non-fully consumed result set, i.e.::
+
+ result = engine.execute(text("select * from table"))
+ row = result.fetchone()
+ return
+
+Where above, the ``CursorResult`` has not been fully consumed. The
+connection will be returned to the pool and the transactional state
+rolled back once the Python garbage collector reclaims the objects
+which hold onto the connection, which often occurs asynchronously.
+The above use case can be alleviated by calling ``first()`` on the
+``CursorResult`` which will fetch the first row and immediately close
+all remaining cursor/connection resources.
+
+RETURNING support
+-----------------
+
+Firebird 2.0 supports returning a result set from inserts, and 2.1
+extends that to deletes and updates. This is generically exposed by
+the SQLAlchemy ``returning()`` method, such as::
+
+ # INSERT..RETURNING
+ result = table.insert().returning(table.c.col1, table.c.col2).\
+ values(name='foo')
+ print(result.fetchall())
+
+ # UPDATE..RETURNING
+ raises = empl.update().returning(empl.c.id, empl.c.salary).\
+ where(empl.c.sales>100).\
+ values(dict(salary=empl.c.salary * 1.1))
+ print(raises.fetchall())
+
+
+.. _dialects: https://mc-computing.com/Databases/Firebird/SQL_Dialect.html
+"""
+
+import datetime
+
+from sqlalchemy import exc
+from sqlalchemy import sql
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.engine import default
+from sqlalchemy.engine import reflection
+from sqlalchemy.sql import compiler
+from sqlalchemy.sql import expression
+from sqlalchemy.types import BIGINT
+from sqlalchemy.types import BLOB
+from sqlalchemy.types import DATE
+from sqlalchemy.types import FLOAT
+from sqlalchemy.types import INTEGER
+from sqlalchemy.types import Integer
+from sqlalchemy.types import NUMERIC
+from sqlalchemy.types import SMALLINT
+from sqlalchemy.types import TEXT
+from sqlalchemy.types import TIME
+from sqlalchemy.types import TIMESTAMP
+
+
+RESERVED_WORDS = set(
+ [
+ "active",
+ "add",
+ "admin",
+ "after",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "ascending",
+ "at",
+ "auto",
+ "avg",
+ "before",
+ "begin",
+ "between",
+ "bigint",
+ "bit_length",
+ "blob",
+ "both",
+ "by",
+ "case",
+ "cast",
+ "char",
+ "character",
+ "character_length",
+ "char_length",
+ "check",
+ "close",
+ "collate",
+ "column",
+ "commit",
+ "committed",
+ "computed",
+ "conditional",
+ "connect",
+ "constraint",
+ "containing",
+ "count",
+ "create",
+ "cross",
+ "cstring",
+ "current",
+ "current_connection",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_transaction",
+ "current_user",
+ "cursor",
+ "database",
+ "date",
+ "day",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delete",
+ "desc",
+ "descending",
+ "disconnect",
+ "distinct",
+ "do",
+ "domain",
+ "double",
+ "drop",
+ "else",
+ "end",
+ "entry_point",
+ "escape",
+ "exception",
+ "execute",
+ "exists",
+ "exit",
+ "external",
+ "extract",
+ "fetch",
+ "file",
+ "filter",
+ "float",
+ "for",
+ "foreign",
+ "from",
+ "full",
+ "function",
+ "gdscode",
+ "generator",
+ "gen_id",
+ "global",
+ "grant",
+ "group",
+ "having",
+ "hour",
+ "if",
+ "in",
+ "inactive",
+ "index",
+ "inner",
+ "input_type",
+ "insensitive",
+ "insert",
+ "int",
+ "integer",
+ "into",
+ "is",
+ "isolation",
+ "join",
+ "key",
+ "leading",
+ "left",
+ "length",
+ "level",
+ "like",
+ "long",
+ "lower",
+ "manual",
+ "max",
+ "maximum_segment",
+ "merge",
+ "min",
+ "minute",
+ "module_name",
+ "month",
+ "names",
+ "national",
+ "natural",
+ "nchar",
+ "no",
+ "not",
+ "null",
+ "numeric",
+ "octet_length",
+ "of",
+ "on",
+ "only",
+ "open",
+ "option",
+ "or",
+ "order",
+ "outer",
+ "output_type",
+ "overflow",
+ "page",
+ "pages",
+ "page_size",
+ "parameter",
+ "password",
+ "plan",
+ "position",
+ "post_event",
+ "precision",
+ "primary",
+ "privileges",
+ "procedure",
+ "protected",
+ "rdb$db_key",
+ "read",
+ "real",
+ "record_version",
+ "recreate",
+ "recursive",
+ "references",
+ "release",
+ "reserv",
+ "reserving",
+ "retain",
+ "returning_values",
+ "returns",
+ "revoke",
+ "right",
+ "rollback",
+ "rows",
+ "row_count",
+ "savepoint",
+ "schema",
+ "second",
+ "segment",
+ "select",
+ "sensitive",
+ "set",
+ "shadow",
+ "shared",
+ "singular",
+ "size",
+ "smallint",
+ "snapshot",
+ "some",
+ "sort",
+ "sqlcode",
+ "stability",
+ "start",
+ "starting",
+ "starts",
+ "statistics",
+ "sub_type",
+ "sum",
+ "suspend",
+ "table",
+ "then",
+ "time",
+ "timestamp",
+ "to",
+ "trailing",
+ "transaction",
+ "trigger",
+ "trim",
+ "uncommitted",
+ "union",
+ "unique",
+ "update",
+ "upper",
+ "user",
+ "using",
+ "value",
+ "values",
+ "varchar",
+ "variable",
+ "varying",
+ "view",
+ "wait",
+ "when",
+ "where",
+ "while",
+ "with",
+ "work",
+ "write",
+ "year",
+ ]
+)
+
+
+class _StringType(sqltypes.String):
+ """Base for Firebird string types."""
+
+ def __init__(self, charset=None, **kw):
+ self.charset = charset
+ super(_StringType, self).__init__(**kw)
+
+
+class VARCHAR(_StringType, sqltypes.VARCHAR):
+ """Firebird VARCHAR type"""
+
+ __visit_name__ = "VARCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ super(VARCHAR, self).__init__(length=length, **kwargs)
+
+
+class CHAR(_StringType, sqltypes.CHAR):
+ """Firebird CHAR type"""
+
+ __visit_name__ = "CHAR"
+
+ def __init__(self, length=None, **kwargs):
+ super(CHAR, self).__init__(length=length, **kwargs)
+
+
+class _FBDateTime(sqltypes.DateTime):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+
+ return process
+
+
+colspecs = {sqltypes.DateTime: _FBDateTime}
+
+ischema_names = {
+ "SHORT": SMALLINT,
+ "LONG": INTEGER,
+ "QUAD": FLOAT,
+ "FLOAT": FLOAT,
+ "DATE": DATE,
+ "TIME": TIME,
+ "TEXT": TEXT,
+ "INT64": BIGINT,
+ "DOUBLE": FLOAT,
+ "TIMESTAMP": TIMESTAMP,
+ "VARYING": VARCHAR,
+ "CSTRING": CHAR,
+ "BLOB": BLOB,
+}
+
+
+# TODO: date conversion types (should be implemented as _FBDateTime,
+# _FBDate, etc. as bind/result functionality is required)
+
+
+class FBTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_boolean(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_TIMESTAMP(type_, **kw)
+
+ def visit_TEXT(self, type_, **kw):
+ return "BLOB SUB_TYPE 1"
+
+ def visit_BLOB(self, type_, **kw):
+ return "BLOB SUB_TYPE 0"
+
+ def _extend_string(self, type_, basic):
+ charset = getattr(type_, "charset", None)
+ if charset is None:
+ return basic
+ else:
+ return "%s CHARACTER SET %s" % (basic, charset)
+
+ def visit_CHAR(self, type_, **kw):
+ basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
+ return self._extend_string(type_, basic)
+
+ def visit_VARCHAR(self, type_, **kw):
+ if not type_.length:
+ raise exc.CompileError(
+ "VARCHAR requires a length on dialect %s" % self.dialect.name
+ )
+ basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
+ return self._extend_string(type_, basic)
+
+
+class FBCompiler(sql.compiler.SQLCompiler):
+ """Firebird specific idiosyncrasies"""
+
+ ansi_bind_rules = True
+
+    # def visit_contains_op_binary(self, binary, operator, **kw):
+    # can't use CONTAINING because it's case insensitive.
+
+    # def visit_not_contains_op_binary(self, binary, operator, **kw):
+    # can't use NOT CONTAINING because it's case insensitive.
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_startswith_op_binary(self, binary, operator, **kw):
+ return "%s STARTING WITH %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ )
+
+ def visit_not_startswith_op_binary(self, binary, operator, **kw):
+ return "%s NOT STARTING WITH %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ )
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ return "mod(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_alias(self, alias, asfrom=False, **kwargs):
+ if self.dialect._version_two:
+ return super(FBCompiler, self).visit_alias(
+ alias, asfrom=asfrom, **kwargs
+ )
+ else:
+ # Override to not use the AS keyword which FB 1.5 does not like
+ if asfrom:
+ alias_name = (
+ isinstance(alias.name, expression._truncated_label)
+ and self._truncated_identifier("alias", alias.name)
+ or alias.name
+ )
+
+ return (
+ self.process(alias.element, asfrom=asfrom, **kwargs)
+ + " "
+ + self.preparer.format_alias(alias, alias_name)
+ )
+ else:
+ return self.process(alias.element, **kwargs)
+
+ def visit_substring_func(self, func, **kw):
+ s = self.process(func.clauses.clauses[0])
+ start = self.process(func.clauses.clauses[1])
+ if len(func.clauses.clauses) > 2:
+ length = self.process(func.clauses.clauses[2])
+ return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+ else:
+ return "SUBSTRING(%s FROM %s)" % (s, start)
+
+ def visit_length_func(self, function, **kw):
+ if self.dialect._version_two:
+ return "char_length" + self.function_argspec(function)
+ else:
+ return "strlen" + self.function_argspec(function)
+
+ visit_char_length_func = visit_length_func
+
+ def function_argspec(self, func, **kw):
+ # TODO: this probably will need to be
+ # narrowed to a fixed list, some no-arg functions
+ # may require parens - see similar example in the oracle
+ # dialect
+ if func.clauses is not None and len(func.clauses):
+ return self.process(func.clause_expr, **kw)
+ else:
+ return ""
+
+ def default_from(self):
+ return " FROM rdb$database"
+
+ def visit_sequence(self, seq, **kw):
+ return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
+
+ def get_select_precolumns(self, select, **kw):
+        """Called when building a ``SELECT`` statement; position is just
+        before the column list.  Firebird puts the limit and offset right
+        after the ``SELECT`` keyword.
+ """
+
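+        # e.g. a limit of 10 with an offset of 5 renders roughly as
+        # "SELECT FIRST 10 SKIP 5 <columns> FROM ..."  (illustrative)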
+ result = ""
+ if select._limit_clause is not None:
+ result += "FIRST %s " % self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ result += "SKIP %s " % self.process(select._offset_clause, **kw)
+ result += super(FBCompiler, self).get_select_precolumns(select, **kw)
+ return result
+
+ def limit_clause(self, select, **kw):
+ """Already taken care of in the `get_select_precolumns` method."""
+
+ return ""
+
+ def returning_clause(self, stmt, returning_cols):
+ columns = [
+ self._label_returning_column(stmt, c)
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return "RETURNING " + ", ".join(columns)
+
+
+class FBDDLCompiler(sql.compiler.DDLCompiler):
+ """Firebird syntactic idiosyncrasies"""
+
+ def visit_create_sequence(self, create):
+ """Generate a ``CREATE GENERATOR`` statement for the sequence."""
+
+ # no syntax for these
+ # https://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
+ if create.element.start is not None:
+ raise NotImplementedError(
+ "Firebird SEQUENCE doesn't support START WITH"
+ )
+ if create.element.increment is not None:
+ raise NotImplementedError(
+ "Firebird SEQUENCE doesn't support INCREMENT BY"
+ )
+
+ if self.dialect._version_two:
+ return "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
+ else:
+ return "CREATE GENERATOR %s" % self.preparer.format_sequence(
+ create.element
+ )
+
+ def visit_drop_sequence(self, drop):
+ """Generate a ``DROP GENERATOR`` statement for the sequence."""
+
+ if self.dialect._version_two:
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(
+ drop.element
+ )
+ else:
+ return "DROP GENERATOR %s" % self.preparer.format_sequence(
+ drop.element
+ )
+
+ def visit_computed_column(self, generated):
+ if generated.persisted is not None:
+ raise exc.CompileError(
+ "Firebird computed columns do not support a persistence "
+ "method setting; set the 'persisted' flag to None for "
+ "Firebird support."
+ )
+ return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+
+
+class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
+ """Install Firebird specific reserved words."""
+
+ reserved_words = RESERVED_WORDS
+ illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(
+ ["_"]
+ )
+
+ def __init__(self, dialect):
+ super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
+
+
+class FBExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq, type_):
+ """Get the next value from the sequence using ``gen_id()``."""
+
+ return self._execute_scalar(
+ "SELECT gen_id(%s, 1) FROM rdb$database"
+ % self.identifier_preparer.format_sequence(seq),
+ type_,
+ )
+
+
+class FBDialect(default.DefaultDialect):
+ """Firebird dialect"""
+
+ name = "firebird"
+ supports_statement_cache = True
+
+ max_identifier_length = 31
+
+ supports_sequences = True
+ sequences_optional = False
+ supports_default_values = True
+ postfetch_lastrowid = False
+
+ supports_native_boolean = False
+
+ requires_name_normalize = True
+ supports_empty_insert = False
+
+ statement_compiler = FBCompiler
+ ddl_compiler = FBDDLCompiler
+ preparer = FBIdentifierPreparer
+ type_compiler = FBTypeCompiler
+ execution_ctx_cls = FBExecutionContext
+
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ construct_arguments = []
+
+    # defaults to dialect ver. 3 behavior;
+    # will be autodetected upon first connect
+ _version_two = True
+
+ def __init__(self, *args, **kwargs):
+ util.warn_deprecated(
+ "The firebird dialect is deprecated and will be removed "
+ "in a future version. This dialect is superseded by the external "
+ "dialect https://github.com/pauldex/sqlalchemy-firebird.",
+ version="1.4",
+ )
+ super(FBDialect, self).__init__(*args, **kwargs)
+
+ def initialize(self, connection):
+ super(FBDialect, self).initialize(connection)
+ self._version_two = (
+ "firebird" in self.server_version_info
+ and self.server_version_info >= (2,)
+ ) or (
+ "interbase" in self.server_version_info
+ and self.server_version_info >= (6,)
+ )
+
+ if not self._version_two:
+ # TODO: whatever other pre < 2.0 stuff goes here
+ self.ischema_names = ischema_names.copy()
+ self.ischema_names["TIMESTAMP"] = sqltypes.DATE
+ self.colspecs = {sqltypes.DateTime: sqltypes.DATE}
+
+ self.implicit_returning = self._version_two and self.__dict__.get(
+ "implicit_returning", True
+ )
+
+ def has_table(self, connection, table_name, schema=None):
+ """Return ``True`` if the given table exists, ignoring
+ the `schema`."""
+ self._ensure_has_table_connection(connection)
+
+ tblqry = """
+ SELECT 1 AS has_table FROM rdb$database
+ WHERE EXISTS (SELECT rdb$relation_name
+ FROM rdb$relations
+ WHERE rdb$relation_name=?)
+ """
+ c = connection.exec_driver_sql(
+ tblqry, [self.denormalize_name(table_name)]
+ )
+ return c.first() is not None
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ """Return ``True`` if the given sequence (generator) exists."""
+
+ genqry = """
+ SELECT 1 AS has_sequence FROM rdb$database
+ WHERE EXISTS (SELECT rdb$generator_name
+ FROM rdb$generators
+ WHERE rdb$generator_name=?)
+ """
+ c = connection.exec_driver_sql(
+ genqry, [self.denormalize_name(sequence_name)]
+ )
+ return c.first() is not None
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ # there are two queries commonly mentioned for this.
+ # this one, using view_blr, is at the Firebird FAQ among other places:
+ # https://www.firebirdfaq.org/faq174/
+ s = """
+ select rdb$relation_name
+ from rdb$relations
+ where rdb$view_blr is null
+ and (rdb$system_flag is null or rdb$system_flag = 0);
+ """
+
+ # the other query is this one. It's not clear if there's really
+ # any difference between these two. This link:
+ # https://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
+ # states them as interchangeable. Some discussion at [ticket:2898]
+ # SELECT DISTINCT rdb$relation_name
+ # FROM rdb$relation_fields
+ # WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
+
+ return [
+ self.normalize_name(row[0])
+ for row in connection.exec_driver_sql(s)
+ ]
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ # see https://www.firebirdfaq.org/faq174/
+ s = """
+ select rdb$relation_name
+ from rdb$relations
+ where rdb$view_blr is not null
+ and (rdb$system_flag is null or rdb$system_flag = 0);
+ """
+ return [
+ self.normalize_name(row[0])
+ for row in connection.exec_driver_sql(s)
+ ]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ qry = """
+ SELECT rdb$view_source AS view_source
+ FROM rdb$relations
+ WHERE rdb$relation_name=?
+ """
+ rp = connection.exec_driver_sql(
+ qry, [self.denormalize_name(view_name)]
+ )
+ row = rp.first()
+ if row:
+ return row["view_source"]
+ else:
+ return None
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ # Query to extract the PK/FK constrained fields of the given table
+ keyqry = """
+ SELECT se.rdb$field_name AS fname
+ FROM rdb$relation_constraints rc
+ JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
+ WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+ """
+ tablename = self.denormalize_name(table_name)
+ # get primary key fields
+ c = connection.exec_driver_sql(keyqry, ["PRIMARY KEY", tablename])
+ pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()]
+ return {"constrained_columns": pkfields, "name": None}
+
+ @reflection.cache
+ def get_column_sequence(
+ self, connection, table_name, column_name, schema=None, **kw
+ ):
+ tablename = self.denormalize_name(table_name)
+ colname = self.denormalize_name(column_name)
+ # Heuristic-query to determine the generator associated to a PK field
+ genqry = """
+ SELECT trigdep.rdb$depended_on_name AS fgenerator
+ FROM rdb$dependencies tabdep
+ JOIN rdb$dependencies trigdep
+ ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
+ AND trigdep.rdb$depended_on_type=14
+ AND trigdep.rdb$dependent_type=2
+ JOIN rdb$triggers trig ON
+ trig.rdb$trigger_name=tabdep.rdb$dependent_name
+ WHERE tabdep.rdb$depended_on_name=?
+ AND tabdep.rdb$depended_on_type=0
+ AND trig.rdb$trigger_type=1
+ AND tabdep.rdb$field_name=?
+ AND (SELECT count(*)
+ FROM rdb$dependencies trigdep2
+ WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
+ """
+ genr = connection.exec_driver_sql(genqry, [tablename, colname]).first()
+ if genr is not None:
+ return dict(name=self.normalize_name(genr["fgenerator"]))
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ # Query to extract the details of all the fields of the given table
+ tblqry = """
+ SELECT r.rdb$field_name AS fname,
+ r.rdb$null_flag AS null_flag,
+ t.rdb$type_name AS ftype,
+ f.rdb$field_sub_type AS stype,
+ f.rdb$field_length/
+ COALESCE(cs.rdb$bytes_per_character,1) AS flen,
+ f.rdb$field_precision AS fprec,
+ f.rdb$field_scale AS fscale,
+ COALESCE(r.rdb$default_source,
+ f.rdb$default_source) AS fdefault
+ FROM rdb$relation_fields r
+ JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
+ JOIN rdb$types t
+ ON t.rdb$type=f.rdb$field_type AND
+ t.rdb$field_name='RDB$FIELD_TYPE'
+ LEFT JOIN rdb$character_sets cs ON
+ f.rdb$character_set_id=cs.rdb$character_set_id
+ WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
+ ORDER BY r.rdb$field_position
+ """
+ # get the PK, used to determine the eventual associated sequence
+ pk_constraint = self.get_pk_constraint(connection, table_name)
+ pkey_cols = pk_constraint["constrained_columns"]
+
+ tablename = self.denormalize_name(table_name)
+ # get all of the fields for this table
+ c = connection.exec_driver_sql(tblqry, [tablename])
+ cols = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ name = self.normalize_name(row["fname"])
+ orig_colname = row["fname"]
+
+ # get the data type
+ colspec = row["ftype"].rstrip()
+ coltype = self.ischema_names.get(colspec)
+ if coltype is None:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (colspec, name)
+ )
+ coltype = sqltypes.NULLTYPE
+ elif issubclass(coltype, Integer) and row["fprec"] != 0:
+ coltype = NUMERIC(
+ precision=row["fprec"], scale=row["fscale"] * -1
+ )
+ elif colspec in ("VARYING", "CSTRING"):
+ coltype = coltype(row["flen"])
+ elif colspec == "TEXT":
+ coltype = TEXT(row["flen"])
+ elif colspec == "BLOB":
+ if row["stype"] == 1:
+ coltype = TEXT()
+ else:
+ coltype = BLOB()
+ else:
+ coltype = coltype()
+
+ # does it have a default value?
+ defvalue = None
+ if row["fdefault"] is not None:
+ # the value comes down as "DEFAULT 'value'": there may be
+ # more than one whitespace around the "DEFAULT" keyword
+ # and it may also be lower case
+ # (see also https://tracker.firebirdsql.org/browse/CORE-356)
+ defexpr = row["fdefault"].lstrip()
+ assert defexpr[:8].rstrip().upper() == "DEFAULT", (
+ "Unrecognized default value: %s" % defexpr
+ )
+ defvalue = defexpr[8:].strip()
+ if defvalue == "NULL":
+ # Redundant
+ defvalue = None
+ col_d = {
+ "name": name,
+ "type": coltype,
+ "nullable": not bool(row["null_flag"]),
+ "default": defvalue,
+ "autoincrement": "auto",
+ }
+
+ if orig_colname.lower() == orig_colname:
+ col_d["quote"] = True
+
+            # if the PK is a single field, try to see if it's linked to
+            # a sequence through a trigger
+ if len(pkey_cols) == 1 and name == pkey_cols[0]:
+ seq_d = self.get_column_sequence(connection, tablename, name)
+ if seq_d is not None:
+ col_d["sequence"] = seq_d
+
+ cols.append(col_d)
+ return cols
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # Query to extract the details of each UK/FK of the given table
+ fkqry = """
+ SELECT rc.rdb$constraint_name AS cname,
+ cse.rdb$field_name AS fname,
+ ix2.rdb$relation_name AS targetrname,
+ se.rdb$field_name AS targetfname
+ FROM rdb$relation_constraints rc
+ JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
+ JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
+ JOIN rdb$index_segments cse ON
+ cse.rdb$index_name=ix1.rdb$index_name
+ JOIN rdb$index_segments se
+ ON se.rdb$index_name=ix2.rdb$index_name
+ AND se.rdb$field_position=cse.rdb$field_position
+ WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
+ ORDER BY se.rdb$index_name, se.rdb$field_position
+ """
+ tablename = self.denormalize_name(table_name)
+
+ c = connection.exec_driver_sql(fkqry, ["FOREIGN KEY", tablename])
+ fks = util.defaultdict(
+ lambda: {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ }
+ )
+
+ for row in c:
+ cname = self.normalize_name(row["cname"])
+ fk = fks[cname]
+ if not fk["name"]:
+ fk["name"] = cname
+ fk["referred_table"] = self.normalize_name(row["targetrname"])
+ fk["constrained_columns"].append(self.normalize_name(row["fname"]))
+ fk["referred_columns"].append(
+ self.normalize_name(row["targetfname"])
+ )
+ return list(fks.values())
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ qry = """
+ SELECT ix.rdb$index_name AS index_name,
+ ix.rdb$unique_flag AS unique_flag,
+ ic.rdb$field_name AS field_name
+ FROM rdb$indices ix
+ JOIN rdb$index_segments ic
+ ON ix.rdb$index_name=ic.rdb$index_name
+ LEFT OUTER JOIN rdb$relation_constraints
+ ON rdb$relation_constraints.rdb$index_name =
+ ic.rdb$index_name
+ WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
+ AND rdb$relation_constraints.rdb$constraint_type IS NULL
+ ORDER BY index_name, ic.rdb$field_position
+ """
+ c = connection.exec_driver_sql(
+ qry, [self.denormalize_name(table_name)]
+ )
+
+ indexes = util.defaultdict(dict)
+ for row in c:
+ indexrec = indexes[row["index_name"]]
+ if "name" not in indexrec:
+ indexrec["name"] = self.normalize_name(row["index_name"])
+ indexrec["column_names"] = []
+ indexrec["unique"] = bool(row["unique_flag"])
+
+ indexrec["column_names"].append(
+ self.normalize_name(row["field_name"])
+ )
+
+ return list(indexes.values())
diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py
new file mode 100644
index 0000000..38f4432
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/fdb.py
@@ -0,0 +1,112 @@
+# firebird/fdb.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: firebird+fdb
+ :name: fdb
+    :dbapi: fdb
+ :connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...]
+ :url: https://pypi.org/project/fdb/
+
+ fdb is a kinterbasdb compatible DBAPI for Firebird.
+
+ .. versionchanged:: 0.9 - The fdb dialect is now the default dialect
+ under the ``firebird://`` URL space, as ``fdb`` is now the official
+ Python driver for Firebird.
+
+Arguments
+----------
+
+The ``fdb`` dialect is based on the
+:mod:`sqlalchemy.dialects.firebird.kinterbasdb` dialect; however, it does not
+accept every argument that Kinterbasdb does.
+
+* ``enable_rowcount`` - True by default, setting this to False disables
+ the usage of "cursor.rowcount" with the
+ Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
+ after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
+ CursorResult will return -1 for result.rowcount. The rationale here is
+ that Kinterbasdb requires a second round trip to the database when
+ .rowcount is called - since SQLA's resultproxy automatically closes
+ the cursor after a non-result-returning statement, rowcount must be
+ called, if at all, before the result object is returned. Additionally,
+ cursor.rowcount may not return correct results with older versions
+ of Firebird, and setting this flag to False will also cause the
+ SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
+ per-execution basis using the ``enable_rowcount`` option with
+ :meth:`_engine.Connection.execution_options`::
+
+ conn = engine.connect().execution_options(enable_rowcount=True)
+ r = conn.execute(stmt)
+ print(r.rowcount)
+
+* ``retaining`` - False by default. Setting this to True will pass the
+ ``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
+ methods of the DBAPI connection, which can improve performance in some
+ situations, but apparently with significant caveats.
+ Please read the fdb and/or kinterbasdb DBAPI documentation in order to
+ understand the implications of this flag.
+
+ .. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
+ In 0.8 it defaulted to ``True``.
+
+ .. seealso::
+
+ https://pythonhosted.org/fdb/usage-guide.html#retaining-transactions
+ - information on the "retaining" flag.
+
+""" # noqa
+
+from .kinterbasdb import FBDialect_kinterbasdb
+from ... import util
+
+
+class FBDialect_fdb(FBDialect_kinterbasdb):
+ supports_statement_cache = True
+
+ def __init__(self, enable_rowcount=True, retaining=False, **kwargs):
+ super(FBDialect_fdb, self).__init__(
+ enable_rowcount=enable_rowcount, retaining=retaining, **kwargs
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("fdb")
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if opts.get("port"):
+ opts["host"] = "%s/%s" % (opts["host"], opts["port"])
+ del opts["port"]
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "type_conv", int)
+
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ """Get the version of the Firebird server used by a connection.
+
+ Returns a tuple of (`major`, `minor`, `build`), three integers
+ representing the version of the attached server.
+ """
+
+ # This is the simpler approach (the other uses the services api),
+ # that for backward compatibility reasons returns a string like
+ # LI-V6.3.3.12981 Firebird 2.0
+ # where the first version is a fake one resembling the old
+ # Interbase signature.
+
+ isc_info_firebird_version = 103
+ fbconn = connection.connection
+
+ version = fbconn.db_info(isc_info_firebird_version)
+
+ return self._parse_version_info(version)
+
+
+dialect = FBDialect_fdb
diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
new file mode 100644
index 0000000..b999404
--- /dev/null
+++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py
@@ -0,0 +1,202 @@
+# firebird/kinterbasdb.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: firebird+kinterbasdb
+ :name: kinterbasdb
+ :dbapi: kinterbasdb
+ :connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...]
+ :url: https://firebirdsql.org/index.php?op=devel&sub=python
+
+Arguments
+----------
+
+The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
+arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect.
+In addition, it also accepts the following:
+
+* ``type_conv`` - select the kind of mapping done on the types: by default
+ SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
+ the linked documents below for further information.
+
+* ``concurrency_level`` - set the backend policy with regards to threading
+ issues: by default SQLAlchemy uses policy 1. See the linked documents
+ below for further information.
+
+.. seealso::
+
+ https://sourceforge.net/projects/kinterbasdb
+
+ https://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
+
+ https://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
+
+""" # noqa
+
+import decimal
+from re import match
+
+from .base import FBDialect
+from .base import FBExecutionContext
+from ... import types as sqltypes
+from ... import util
+
+
+class _kinterbasdb_numeric(object):
+ def bind_processor(self, dialect):
+ def process(value):
+ if isinstance(value, decimal.Decimal):
+ return str(value)
+ else:
+ return value
+
+ return process
+
+
+class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
+ pass
+
+
+class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
+ pass
+
+
+class FBExecutionContext_kinterbasdb(FBExecutionContext):
+ @property
+ def rowcount(self):
+ if self.execution_options.get(
+ "enable_rowcount", self.dialect.enable_rowcount
+ ):
+ return self.cursor.rowcount
+ else:
+ return -1
+
+
+class FBDialect_kinterbasdb(FBDialect):
+ driver = "kinterbasdb"
+ supports_statement_cache = True
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = FBExecutionContext_kinterbasdb
+
+ supports_native_decimal = True
+
+ colspecs = util.update_copy(
+ FBDialect.colspecs,
+ {
+ sqltypes.Numeric: _FBNumeric_kinterbasdb,
+ sqltypes.Float: _FBFloat_kinterbasdb,
+ },
+ )
+
+ def __init__(
+ self,
+ type_conv=200,
+ concurrency_level=1,
+ enable_rowcount=True,
+ retaining=False,
+ **kwargs
+ ):
+ super(FBDialect_kinterbasdb, self).__init__(**kwargs)
+ self.enable_rowcount = enable_rowcount
+ self.type_conv = type_conv
+ self.concurrency_level = concurrency_level
+ self.retaining = retaining
+ if enable_rowcount:
+ self.supports_sane_rowcount = True
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("kinterbasdb")
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+        # kinterbasdb does not accept None, but wants an empty list
+ # when there are no arguments.
+ cursor.execute(statement, parameters or [])
+
+ def do_rollback(self, dbapi_connection):
+ dbapi_connection.rollback(self.retaining)
+
+ def do_commit(self, dbapi_connection):
+ dbapi_connection.commit(self.retaining)
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if opts.get("port"):
+ opts["host"] = "%s/%s" % (opts["host"], opts["port"])
+ del opts["port"]
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "type_conv", int)
+
+ type_conv = opts.pop("type_conv", self.type_conv)
+ concurrency_level = opts.pop(
+ "concurrency_level", self.concurrency_level
+ )
+
+ if self.dbapi is not None:
+ initialized = getattr(self.dbapi, "initialized", None)
+ if initialized is None:
+ # CVS rev 1.96 changed the name of the attribute:
+ # https://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
+ # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
+ initialized = getattr(self.dbapi, "_initialized", False)
+ if not initialized:
+ self.dbapi.init(
+ type_conv=type_conv, concurrency_level=concurrency_level
+ )
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ """Get the version of the Firebird server used by a connection.
+
+ Returns a tuple of (`major`, `minor`, `build`), three integers
+ representing the version of the attached server.
+ """
+
+ # This is the simpler approach (the other uses the services api),
+ # that for backward compatibility reasons returns a string like
+ # LI-V6.3.3.12981 Firebird 2.0
+ # where the first version is a fake one resembling the old
+ # Interbase signature.
+
+ fbconn = connection.connection
+ version = fbconn.server_version
+
+ return self._parse_version_info(version)
+
+ def _parse_version_info(self, version):
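+        # e.g. "LI-V6.3.3.12981 Firebird 2.0" parses to (2, 0, 12981, 'firebird');
+        # a plain Interbase-style string like "WI-V6.3.3.12981" would parse to
+        # (6, 3, 3, 'interbase')  (illustrative)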
+ m = match(
+ r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version
+ )
+ if not m:
+ raise AssertionError(
+ "Could not determine version from string '%s'" % version
+ )
+
+        if m.group(5) is not None:
+ return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"])
+ else:
+ return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"])
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
+ msg = str(e)
+ return (
+ "Error writing data to the connection" in msg
+ or "Unable to complete network request to host" in msg
+ or "Invalid connection state" in msg
+ or "Invalid cursor state" in msg
+ or "connection shutdown" in msg
+ )
+ else:
+ return False
+
+
+dialect = FBDialect_kinterbasdb
diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py
new file mode 100644
index 0000000..cae0168
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/__init__.py
@@ -0,0 +1,85 @@
+# mssql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import mxodbc # noqa
+from . import pymssql # noqa
+from . import pyodbc # noqa
+from .base import BIGINT
+from .base import BINARY
+from .base import BIT
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import DATETIME2
+from .base import DATETIMEOFFSET
+from .base import DECIMAL
+from .base import FLOAT
+from .base import IMAGE
+from .base import INTEGER
+from .base import JSON
+from .base import MONEY
+from .base import NCHAR
+from .base import NTEXT
+from .base import NUMERIC
+from .base import NVARCHAR
+from .base import REAL
+from .base import ROWVERSION
+from .base import SMALLDATETIME
+from .base import SMALLINT
+from .base import SMALLMONEY
+from .base import SQL_VARIANT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import TINYINT
+from .base import try_cast
+from .base import UNIQUEIDENTIFIER
+from .base import VARBINARY
+from .base import VARCHAR
+from .base import XML
+
+
+base.dialect = dialect = pyodbc.dialect
+
+
+__all__ = (
+ "JSON",
+ "INTEGER",
+ "BIGINT",
+ "SMALLINT",
+ "TINYINT",
+ "VARCHAR",
+ "NVARCHAR",
+ "CHAR",
+ "NCHAR",
+ "TEXT",
+ "NTEXT",
+ "DECIMAL",
+ "NUMERIC",
+ "FLOAT",
+ "DATETIME",
+ "DATETIME2",
+ "DATETIMEOFFSET",
+ "DATE",
+ "TIME",
+ "SMALLDATETIME",
+ "BINARY",
+ "VARBINARY",
+ "BIT",
+ "REAL",
+ "IMAGE",
+ "TIMESTAMP",
+ "ROWVERSION",
+ "MONEY",
+ "SMALLMONEY",
+ "UNIQUEIDENTIFIER",
+ "SQL_VARIANT",
+ "XML",
+ "dialect",
+ "try_cast",
+)
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
new file mode 100644
index 0000000..ee6ce87
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -0,0 +1,3545 @@
+# mssql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+.. dialect:: mssql
+ :name: Microsoft SQL Server
+ :full_support: 2017
+ :normal_support: 2012+
+ :best_effort: 2005+
+
+.. _mssql_external_dialects:
+
+External Dialects
+-----------------
+
+In addition to the above DBAPI layers with native SQLAlchemy support, there
+are third-party dialects for other DBAPI layers that are compatible
+with SQL Server. See the "External Dialects" list on the
+:ref:`dialect_toplevel` page.
+
+.. _mssql_identity:
+
+Auto Increment Behavior / IDENTITY Columns
+------------------------------------------
+
+SQL Server provides so-called "auto incrementing" behavior using the
+``IDENTITY`` construct, which can be placed on any single integer column in a
+table. SQLAlchemy considers ``IDENTITY`` within its default "autoincrement"
+behavior for an integer primary key column, described at
+:paramref:`_schema.Column.autoincrement`. This means that by default,
+the first integer primary key column in a :class:`_schema.Table` will be
+considered to be the identity column - unless it is associated with a
+:class:`.Sequence` - and will generate DDL as such::
+
+ from sqlalchemy import Table, MetaData, Column, Integer
+
+ m = MetaData()
+ t = Table('t', m,
+ Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ m.create_all(engine)
+
+The above example will generate DDL as:
+
+.. sourcecode:: sql
+
+ CREATE TABLE t (
+ id INTEGER NOT NULL IDENTITY,
+ x INTEGER NULL,
+ PRIMARY KEY (id)
+ )
+
+For the case where this default generation of ``IDENTITY`` is not desired,
+specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag,
+on the first integer primary key column::
+
+ m = MetaData()
+ t = Table('t', m,
+ Column('id', Integer, primary_key=True, autoincrement=False),
+ Column('x', Integer))
+ m.create_all(engine)
+
+To add the ``IDENTITY`` keyword to a non-primary key column, specify
+``True`` for the :paramref:`_schema.Column.autoincrement` flag on the desired
+:class:`_schema.Column` object, and ensure that
+:paramref:`_schema.Column.autoincrement`
+is set to ``False`` on any integer primary key column::
+
+ m = MetaData()
+ t = Table('t', m,
+ Column('id', Integer, primary_key=True, autoincrement=False),
+ Column('x', Integer, autoincrement=True))
+ m.create_all(engine)
+
+.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
+ in a :class:`_schema.Column` to specify the start and increment
+ parameters of an IDENTITY. These replace
+ the use of the :class:`.Sequence` object in order to specify these values.
+
+.. deprecated:: 1.4
+
+ The ``mssql_identity_start`` and ``mssql_identity_increment`` parameters
+   to :class:`_schema.Column` are deprecated and should be replaced by
+ an :class:`_schema.Identity` object. Specifying both ways of configuring
+ an IDENTITY will result in a compile error.
+ These options are also no longer returned as part of the
+ ``dialect_options`` key in :meth:`_reflection.Inspector.get_columns`.
+ Use the information in the ``identity`` key instead.
+
+.. deprecated:: 1.3
+
+ The use of :class:`.Sequence` to specify IDENTITY characteristics is
+ deprecated and will be removed in a future release. Please use
+ the :class:`_schema.Identity` object parameters
+ :paramref:`_schema.Identity.start` and
+ :paramref:`_schema.Identity.increment`.
+
+.. versionchanged:: 1.4 Removed the ability to use a :class:`.Sequence`
+ object to modify IDENTITY characteristics. :class:`.Sequence` objects
+ now only manipulate true T-SQL SEQUENCE types.
+
+.. note::
+
+ There can only be one IDENTITY column on the table. When using
+ ``autoincrement=True`` to enable the IDENTITY keyword, SQLAlchemy does not
+ guard against multiple columns specifying the option simultaneously. The
+ SQL Server database will instead reject the ``CREATE TABLE`` statement.
+
+.. note::
+
+ An INSERT statement which attempts to provide a value for a column that is
+ marked with IDENTITY will be rejected by SQL Server. In order for the
+ value to be accepted, a session-level option "SET IDENTITY_INSERT" must be
+ enabled. The SQLAlchemy SQL Server dialect will perform this operation
+ automatically when using a core :class:`_expression.Insert`
+ construct; if the
+ execution specifies a value for the IDENTITY column, the "IDENTITY_INSERT"
+    option will be enabled for the span of that statement's invocation. However,
+ this scenario is not high performing and should not be relied upon for
+ normal use. If a table doesn't actually require IDENTITY behavior in its
+ integer primary key column, the keyword should be disabled when creating
+ the table by ensuring that ``autoincrement=False`` is set.
+
+Controlling "Start" and "Increment"
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Specific control over the "start" and "increment" values for
+the ``IDENTITY`` generator are provided using the
+:paramref:`_schema.Identity.start` and :paramref:`_schema.Identity.increment`
+parameters passed to the :class:`_schema.Identity` object::
+
+ from sqlalchemy import Table, Integer, Column, Identity
+
+ test = Table(
+ 'test', metadata,
+ Column(
+ 'id',
+ Integer,
+            Identity(start=100, increment=10),
+            primary_key=True
+ ),
+ Column('name', String(20))
+ )
+
+The CREATE TABLE for the above :class:`_schema.Table` object would be:
+
+.. sourcecode:: sql
+
+ CREATE TABLE test (
+ id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
+     name VARCHAR(20) NULL
+ )
+
+.. note::
+
+   The :class:`_schema.Identity` object supports many other parameters in
+   addition to ``start`` and ``increment``. These are not supported by
+   SQL Server and will be ignored when generating the CREATE TABLE DDL.
+
+.. versionchanged:: 1.4 The :class:`_schema.Identity` object is
+ now used to affect the
+ ``IDENTITY`` generator for a :class:`_schema.Column` under SQL Server.
+ Previously, the :class:`.Sequence` object was used. As SQL Server now
+ supports real sequences as a separate construct, :class:`.Sequence` will be
+ functional in the normal way starting from SQLAlchemy version 1.4.
+
+
+Using IDENTITY with Non-Integer numeric types
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+SQL Server also allows ``IDENTITY`` to be used with ``NUMERIC`` columns. To
+implement this pattern smoothly in SQLAlchemy, the primary datatype of the
+column should remain as ``Integer``, however the underlying implementation
+type deployed to the SQL Server database can be specified as ``Numeric`` using
+:meth:`.TypeEngine.with_variant`::
+
+ from sqlalchemy import Column
+ from sqlalchemy import Integer
+ from sqlalchemy import Numeric
+ from sqlalchemy import String
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class TestTable(Base):
+ __tablename__ = "test"
+ id = Column(
+ Integer().with_variant(Numeric(10, 0), "mssql"),
+ primary_key=True,
+ autoincrement=True,
+ )
+ name = Column(String)
+
+In the above example, ``Integer().with_variant()`` provides clear usage
+information that accurately describes the intent of the code. The general
+restriction that ``autoincrement`` only applies to ``Integer`` is established
+at the metadata level and not at the per-dialect level.
+
+When using the above pattern, the primary key identifier that comes back from
+the insertion of a row, which is also the value that would be assigned to an
+ORM object such as ``TestTable`` above, will be an instance of ``Decimal()``
+and not ``int`` when using SQL Server. The numeric return type of the
+:class:`_types.Numeric` type can be changed to return floats by passing False
+to :paramref:`_types.Numeric.asdecimal`. To normalize the return type of the
+above ``Numeric(10, 0)`` to return Python ints (which also support "long"
+integer values in Python 3), use :class:`_types.TypeDecorator` as follows::
+
+ from sqlalchemy import TypeDecorator
+
+ class NumericAsInteger(TypeDecorator):
+ '''normalize floating point return values into ints'''
+
+ impl = Numeric(10, 0, asdecimal=False)
+ cache_ok = True
+
+ def process_result_value(self, value, dialect):
+ if value is not None:
+ value = int(value)
+ return value
+
+ class TestTable(Base):
+ __tablename__ = "test"
+ id = Column(
+ Integer().with_variant(NumericAsInteger, "mssql"),
+ primary_key=True,
+ autoincrement=True,
+ )
+ name = Column(String)
+
+
+INSERT behavior
+^^^^^^^^^^^^^^^^
+
+Handling of the ``IDENTITY`` column at INSERT time involves two key
+techniques. The most common is being able to fetch the "last inserted value"
+for a given ``IDENTITY`` column, a process which SQLAlchemy performs
+implicitly in many cases, most importantly within the ORM.
+
+The process for fetching this value has several variants:
+
+* In the vast majority of cases, RETURNING is used in conjunction with INSERT
+ statements on SQL Server in order to get newly generated primary key values:
+
+ .. sourcecode:: sql
+
+ INSERT INTO t (x) OUTPUT inserted.id VALUES (?)
+
+* When RETURNING is not available or has been disabled via
+ ``implicit_returning=False``, either the ``scope_identity()`` function or
+ the ``@@identity`` variable is used; behavior varies by backend:
+
+ * when using PyODBC, the phrase ``; select scope_identity()`` will be
+ appended to the end of the INSERT statement; a second result set will be
+ fetched in order to receive the value. Given a table as::
+
+ t = Table('t', m, Column('id', Integer, primary_key=True),
+ Column('x', Integer),
+ implicit_returning=False)
+
+ an INSERT will look like:
+
+ .. sourcecode:: sql
+
+ INSERT INTO t (x) VALUES (?); select scope_identity()
+
+ * Other dialects such as pymssql will call upon
+ ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT
+ statement. If the flag ``use_scope_identity=False`` is passed to
+ :func:`_sa.create_engine`,
+ the statement ``SELECT @@identity AS lastrowid``
+ is used instead.
+
+A table that contains an ``IDENTITY`` column will prohibit an INSERT statement
+that refers to the identity column explicitly. The SQLAlchemy dialect will
+detect when an INSERT construct, created using a core
+:func:`_expression.insert`
+construct (not a plain string SQL), refers to the identity column, and
+in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert
+statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the
+execution. Given this example::
+
+ m = MetaData()
+ t = Table('t', m, Column('id', Integer, primary_key=True),
+ Column('x', Integer))
+ m.create_all(engine)
+
+ with engine.begin() as conn:
+ conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2})
+
+The above column will be created with IDENTITY, however the INSERT statement
+we emit is specifying explicit values. In the echo output we can see
+how SQLAlchemy handles this:
+
+.. sourcecode:: sql
+
+ CREATE TABLE t (
+ id INTEGER NOT NULL IDENTITY(1,1),
+ x INTEGER NULL,
+ PRIMARY KEY (id)
+ )
+
+ COMMIT
+ SET IDENTITY_INSERT t ON
+ INSERT INTO t (id, x) VALUES (?, ?)
+ ((1, 1), (2, 2))
+ SET IDENTITY_INSERT t OFF
+ COMMIT
+
+
+
+This is an auxiliary use case suitable for testing and bulk insert scenarios.
+
+SEQUENCE support
+----------------
+
+The :class:`.Sequence` object now creates "real" sequences, i.e.,
+``CREATE SEQUENCE``. To provide compatibility with other dialects,
+:class:`.Sequence` defaults to a start value of 1, even though the
+T-SQL default is -9223372036854775808.
+
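+As a minimal sketch (the table and sequence names are illustrative), a
+:class:`.Sequence` can be associated with an integer column in the usual
+way, and the dialect will emit ``SELECT NEXT VALUE FOR`` to retrieve new
+values::
+
+    from sqlalchemy import Column, Integer, Sequence, Table
+
+    my_table = Table(
+        'my_table', metadata,
+        Column(
+            'id', Integer,
+            Sequence('my_table_id_seq', start=1),
+            primary_key=True
+        )
+    )
+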
+.. versionadded:: 1.4.0
+
+MAX on VARCHAR / NVARCHAR
+-------------------------
+
+SQL Server supports the special string "MAX" within the
+:class:`_types.VARCHAR` and :class:`_types.NVARCHAR` datatypes,
+to indicate "maximum length possible". The dialect currently handles this as
+a length of "None" in the base type, rather than supplying a
+dialect-specific version of these types, so that a base type
+specified such as ``VARCHAR(None)`` can assume "unlengthed" behavior on
+more than one backend without using dialect-specific types.
+
+To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None::
+
+ my_table = Table(
+ 'my_table', metadata,
+ Column('my_data', VARCHAR(None)),
+ Column('my_n_data', NVARCHAR(None))
+ )
+
+
+Collation Support
+-----------------
+
+Character collations are supported by the base string types,
+specified by the string argument "collation"::
+
+ from sqlalchemy import VARCHAR
+ Column('login', VARCHAR(32, collation='Latin1_General_CI_AS'))
+
+When such a column is associated with a :class:`_schema.Table`, the
+CREATE TABLE statement for this column will yield::
+
+ login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL
+
+LIMIT/OFFSET Support
+--------------------
+
+MSSQL has added support for LIMIT / OFFSET as of SQL Server 2012, via the
+"OFFSET n ROWS" and "FETCH NEXT n ROWS" clauses. SQLAlchemy supports these
+syntaxes automatically if SQL Server 2012 or greater is detected.
+
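+As an approximate illustration, on SQL Server 2012 or greater a statement
+that uses both LIMIT and OFFSET renders along the lines of:
+
+.. sourcecode:: sql
+
+    SELECT col1, col2 FROM table ORDER BY col3
+    OFFSET :param_1 ROWS FETCH FIRST :param_2 ROWS ONLY
+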
+.. versionchanged:: 1.4 support added for SQL Server "OFFSET n ROWS" and
+ "FETCH NEXT n ROWS" syntax.
+
+For statements that specify only LIMIT and no OFFSET, all versions of SQL
+Server support the TOP keyword. This syntax is used for all SQL Server
+versions when no OFFSET clause is present. A statement such as::
+
+ select(some_table).limit(5)
+
+will render similarly to::
+
+ SELECT TOP 5 col1, col2.. FROM table
+
+For versions of SQL Server prior to SQL Server 2012, a statement that uses
+LIMIT and OFFSET, or just OFFSET alone, will be rendered using the
+``ROW_NUMBER()`` window function. A statement such as::
+
+ select(some_table).order_by(some_table.c.col3).limit(5).offset(10)
+
+will render similarly to::
+
+ SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2,
+ ROW_NUMBER() OVER (ORDER BY col3) AS
+ mssql_rn FROM table WHERE t.x = :x_1) AS
+ anon_1 WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1
+
+Note that when using LIMIT and/or OFFSET, whether using the older
+or newer SQL Server syntaxes, the statement must have an ORDER BY as well,
+else a :class:`.CompileError` is raised.
+
+.. _mssql_isolation_level:
+
+Transaction Isolation Level
+---------------------------
+
+All SQL Server dialects support setting of transaction isolation level
+both via a dialect-specific parameter
+:paramref:`_sa.create_engine.isolation_level`
+accepted by :func:`_sa.create_engine`,
+as well as the :paramref:`.Connection.execution_options.isolation_level`
+argument as passed to
+:meth:`_engine.Connection.execution_options`.
+This feature works by issuing the
+command ``SET TRANSACTION ISOLATION LEVEL <level>`` for
+each new connection.
+
+To set isolation level using :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "mssql+pyodbc://scott:tiger@ms_2008",
+ isolation_level="REPEATABLE READ"
+ )
+
+To set using per-connection execution options::
+
+ connection = engine.connect()
+ connection = connection.execution_options(
+ isolation_level="READ COMMITTED"
+ )
+
+Valid values for ``isolation_level`` include:
+
+* ``AUTOCOMMIT`` - pyodbc / pymssql-specific
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``SNAPSHOT`` - specific to SQL Server
+
+There are also more options for isolation level configurations, such as
+"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply
+different isolation level settings. See the discussion at
+:ref:`dbapi_autocommit` for background.
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+Nullability
+-----------
+MSSQL has support for three levels of column nullability. The default
+nullability allows nulls and is explicit in the CREATE TABLE
+construct::
+
+ name VARCHAR(20) NULL
+
+If ``nullable=None`` is specified then no specification is made. In
+other words the database's configured default is used. This will
+render::
+
+ name VARCHAR(20)
+
+If ``nullable`` is ``True`` or ``False`` then the column will be
+``NULL`` or ``NOT NULL`` respectively.
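+
+For example, a minimal sketch (the column name is illustrative)::
+
+    from sqlalchemy import Column, String
+
+    Column('name', String(20), nullable=None)   # name VARCHAR(20)
+    Column('name', String(20), nullable=True)   # name VARCHAR(20) NULL
+    Column('name', String(20), nullable=False)  # name VARCHAR(20) NOT NULL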
+
+Date / Time Handling
+--------------------
+DATE and TIME are supported. Bind parameters are converted
+to datetime.datetime() objects as required by most MSSQL drivers,
+and results are processed from strings if needed.
+The DATE and TIME types are not available for MSSQL 2005 and
+previous - if a server version below 2008 is detected, DDL
+for these types will be issued as DATETIME.
+
+.. _mssql_large_type_deprecation:
+
+Large Text/Binary Type Deprecation
+----------------------------------
+
+Per
+`SQL Server 2012/2014 Documentation <https://technet.microsoft.com/en-us/library/ms187993.aspx>`_,
+the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL
+Server in a future release. SQLAlchemy normally relates these types to the
+:class:`.UnicodeText`, :class:`_expression.TextClause` and
+:class:`.LargeBinary` datatypes.
+
+In order to accommodate this change, a new flag ``deprecate_large_types``
+is added to the dialect, which will be automatically set based on detection
+of the server version in use, if not otherwise set by the user. The
+behavior of this flag is as follows:
+
+* When this flag is ``True``, the :class:`.UnicodeText`,
+ :class:`_expression.TextClause` and
+ :class:`.LargeBinary` datatypes, when used to render DDL, will render the
+ types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``,
+ respectively. This is a new behavior as of the addition of this flag.
+
+* When this flag is ``False``, the :class:`.UnicodeText`,
+ :class:`_expression.TextClause` and
+ :class:`.LargeBinary` datatypes, when used to render DDL, will render the
+ types ``NTEXT``, ``TEXT``, and ``IMAGE``,
+ respectively. This is the long-standing behavior of these types.
+
+* The flag begins with the value ``None``, before a database connection is
+ established. If the dialect is used to render DDL without the flag being
+ set, it is interpreted the same as ``False``.
+
+* On first connection, the dialect detects if SQL Server version 2012 or
+ greater is in use; if the flag is still at ``None``, it sets it to ``True``
+ or ``False`` based on whether 2012 or greater is detected.
+
+* The flag can be set to either ``True`` or ``False`` when the dialect
+ is created, typically via :func:`_sa.create_engine`::
+
+ eng = create_engine("mssql+pymssql://user:pass@host/db",
+ deprecate_large_types=True)
+
+* Complete control over whether the "old" or "new" types are rendered is
+ available in all SQLAlchemy versions by using the UPPERCASE type objects
+ instead: :class:`_types.NVARCHAR`, :class:`_types.VARCHAR`,
+ :class:`_types.VARBINARY`, :class:`_types.TEXT`, :class:`_mssql.NTEXT`,
+ :class:`_mssql.IMAGE`
+ will always remain fixed and always output exactly that
+ type.
+
+.. versionadded:: 1.0.0
+
+.. _multipart_schema_names:
+
+Multipart Schema Names
+----------------------
+
+SQL Server schemas sometimes require multiple parts to their "schema"
+qualifier, that is, including the database name and owner name as separate
+tokens, such as ``mydatabase.dbo.some_table``. These multipart names can be set
+at once using the :paramref:`_schema.Table.schema` argument of
+:class:`_schema.Table`::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="mydatabase.dbo"
+ )
+
+When performing operations such as table or component reflection, a schema
+argument that contains a dot will be split into separate
+"database" and "owner" components in order to correctly query the SQL
+Server information schema tables, as these two values are stored separately.
+Additionally, when rendering the schema name for DDL or SQL, the two
+components will be quoted separately for case sensitive names and other
+special characters. Given an argument as below::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="MyDataBase.dbo"
+ )
+
+The above schema would be rendered as ``[MyDataBase].dbo``, and also in
+reflection, would be reflected using "dbo" as the owner and "MyDataBase"
+as the database name.
+
+To control how the schema name is broken into database / owner,
+specify brackets (which in SQL Server are quoting characters) in the name.
+Below, the "owner" will be considered as ``MyDataBase.dbo`` and the
+"database" will be None::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="[MyDataBase.dbo]"
+ )
+
+To individually specify both database and owner name with special characters
+or embedded dots, use two sets of brackets::
+
+ Table(
+ "some_table", metadata,
+ Column("q", String(50)),
+ schema="[MyDataBase.Period].[MyOwner.Dot]"
+ )
+
+
+.. versionchanged:: 1.2 the SQL Server dialect now treats brackets as
+ identifier delimiters splitting the schema into separate database
+ and owner tokens, to allow dots within either name itself.
+
+.. _legacy_schema_rendering:
+
+Legacy Schema Mode
+------------------
+
+Very old versions of the MSSQL dialect introduced the behavior such that a
+schema-qualified table would be auto-aliased when used in a
+SELECT statement; given a table::
+
+ account_table = Table(
+ 'account', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('info', String(100)),
+ schema="customer_schema"
+ )
+
+this legacy mode of rendering would assume that "customer_schema.account"
+would not be accepted by all parts of the SQL statement, as illustrated
+below::
+
+ >>> eng = create_engine("mssql+pymssql://mydsn", legacy_schema_aliasing=True)
+ >>> print(account_table.select().compile(eng))
+ SELECT account_1.id, account_1.info
+ FROM customer_schema.account AS account_1
+
+This mode of behavior is now off by default, as it appears to have served
+no purpose; however in the case that legacy applications rely upon it,
+it is available using the ``legacy_schema_aliasing`` argument to
+:func:`_sa.create_engine` as illustrated above.
+
+.. versionchanged:: 1.1 the ``legacy_schema_aliasing`` flag introduced
+ in version 1.0.5 to allow disabling of legacy mode for schemas now
+ defaults to False.
+
+.. deprecated:: 1.4
+
+ The ``legacy_schema_aliasing`` flag is now
+ deprecated and will be removed in a future release.
+
+.. _mssql_indexes:
+
+Clustered Index Support
+-----------------------
+
+The MSSQL dialect supports clustered indexes (and primary keys) via the
+``mssql_clustered`` option. This option is available to :class:`.Index`,
+:class:`.UniqueConstraint`, and :class:`.PrimaryKeyConstraint`.
+
+To generate a clustered index::
+
+ Index("my_index", table.c.x, mssql_clustered=True)
+
+which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``.
+
+To generate a clustered primary key use::
+
+ Table('my_table', metadata,
+ Column('x', ...),
+ Column('y', ...),
+ PrimaryKeyConstraint("x", "y", mssql_clustered=True))
+
+which will render the table, for example, as::
+
+ CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
+ PRIMARY KEY CLUSTERED (x, y))
+
+Similarly, we can generate a clustered unique constraint using::
+
+ Table('my_table', metadata,
+ Column('x', ...),
+ Column('y', ...),
+ PrimaryKeyConstraint("x"),
+ UniqueConstraint("y", mssql_clustered=True),
+ )
+
+To explicitly request a non-clustered primary key (for example, when
+a separate clustered index is desired), use::
+
+ Table('my_table', metadata,
+ Column('x', ...),
+ Column('y', ...),
+ PrimaryKeyConstraint("x", "y", mssql_clustered=False))
+
+which will render the table, for example, as::
+
+ CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL,
+ PRIMARY KEY NONCLUSTERED (x, y))
+
+.. versionchanged:: 1.1 the ``mssql_clustered`` option now defaults
+ to None, rather than False. ``mssql_clustered=False`` now explicitly
+ renders the NONCLUSTERED clause, whereas None omits the CLUSTERED
+ clause entirely, allowing SQL Server defaults to take effect.
+
+
+MSSQL-Specific Index Options
+-----------------------------
+
+In addition to clustering, the MSSQL dialect supports other special options
+for :class:`.Index`.
+
+INCLUDE
+^^^^^^^
+
+The ``mssql_include`` option renders INCLUDE(colname) for the given string
+names::
+
+ Index("my_index", table.c.x, mssql_include=['y'])
+
+would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``.
+
+.. _mssql_index_where:
+
+Filtered Indexes
+^^^^^^^^^^^^^^^^
+
+The ``mssql_where`` option renders a WHERE clause for the given
+criterion::
+
+ Index("my_index", table.c.x, mssql_where=table.c.x > 10)
+
+would render the index as ``CREATE INDEX my_index ON table (x) WHERE x > 10``.
+
+.. versionadded:: 1.3.4
+
+Index ordering
+^^^^^^^^^^^^^^
+
+Index ordering is available via functional expressions, such as::
+
+ Index("my_index", table.c.x.desc())
+
+would render the index as ``CREATE INDEX my_index ON table (x DESC)``.
+
+.. seealso::
+
+ :ref:`schema_indexes_functional`
+
+Compatibility Levels
+--------------------
+MSSQL supports the notion of setting compatibility levels at the
+database level. This allows, for instance, a database that is
+compatible with SQL2000 to run on a SQL2005 database server.
+``server_version_info`` will always return the database
+server version information (in this case SQL2005) and not the
+compatibility level information. Because of this, if running under
+a backwards compatibility mode SQLAlchemy may attempt to use T-SQL
+statements that are unable to be parsed by the database server.
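+
+As a T-SQL sketch (the database name and level shown are illustrative), the
+compatibility level can be inspected and changed at the SQL prompt:
+
+.. sourcecode:: sql
+
+    SELECT compatibility_level FROM sys.databases WHERE name = 'MyDatabase'
+
+    ALTER DATABASE MyDatabase SET COMPATIBILITY_LEVEL = 110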
+
+Triggers
+--------
+
+SQLAlchemy by default uses OUTPUT INSERTED to get at newly
+generated primary key values via IDENTITY columns or other
+server side defaults. MS-SQL does not
+allow the usage of OUTPUT INSERTED on tables that have triggers.
+To disable the usage of OUTPUT INSERTED on a per-table basis,
+specify ``implicit_returning=False`` for each :class:`_schema.Table`
+which has triggers::
+
+ Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ # ...,
+ implicit_returning=False
+ )
+
+Declarative form::
+
+ class MyClass(Base):
+ # ...
+ __table_args__ = {'implicit_returning':False}
+
+
+This option can also be specified engine-wide using the
+``implicit_returning=False`` argument on :func:`_sa.create_engine`.
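+
+For example, a minimal sketch (the connection URL is illustrative)::
+
+    engine = create_engine(
+        "mssql+pyodbc://scott:tiger@mydsn",
+        implicit_returning=False
+    )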
+
+.. _mssql_rowcount_versioning:
+
+Rowcount Support / ORM Versioning
+---------------------------------
+
+The SQL Server drivers may have limited ability to return the number
+of rows updated from an UPDATE or DELETE statement.
+
+As of this writing, the PyODBC driver is not able to return a rowcount when
+OUTPUT INSERTED is used. This impacts the SQLAlchemy ORM's versioning feature
+in many cases where server-side value generators are in use in that while the
+versioning operations can succeed, the ORM cannot always check that an UPDATE
+or DELETE statement matched the number of rows expected, which is how it
+verifies that the version identifier matched. When this condition occurs, a
+warning will be emitted but the operation will proceed.
+
+The use of OUTPUT INSERTED can be disabled by setting the
+:paramref:`_schema.Table.implicit_returning` flag to ``False`` on a particular
+:class:`_schema.Table`, which in declarative looks like::
+
+ class MyTable(Base):
+ __tablename__ = 'mytable'
+ id = Column(Integer, primary_key=True)
+ stuff = Column(String(10))
+ timestamp = Column(TIMESTAMP(), default=text('DEFAULT'))
+ __mapper_args__ = {
+ 'version_id_col': timestamp,
+ 'version_id_generator': False,
+ }
+ __table_args__ = {
+ 'implicit_returning': False
+ }
+
+Enabling Snapshot Isolation
+---------------------------
+
+SQL Server has a default transaction
+isolation mode that locks entire tables, and causes even mildly concurrent
+applications to have long held locks and frequent deadlocks.
+Enabling snapshot isolation for the database as a whole is recommended
+for modern levels of concurrency support. This is accomplished via the
+following ALTER DATABASE commands executed at the SQL prompt::
+
+ ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON
+
+ ALTER DATABASE MyDatabase SET READ_COMMITTED_SNAPSHOT ON
+
+Background on SQL Server snapshot isolation is available at
+https://msdn.microsoft.com/en-us/library/ms175095.aspx.
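+
+Once snapshot isolation is enabled for the database, individual connections
+may opt in via the ``isolation_level`` setting described at
+:ref:`mssql_isolation_level`; a minimal sketch (the connection URL is
+illustrative)::
+
+    engine = create_engine(
+        "mssql+pyodbc://scott:tiger@mydsn",
+        isolation_level="SNAPSHOT"
+    )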
+
+""" # noqa
+
+import codecs
+import datetime
+import operator
+import re
+
+from . import information_schema as ischema
+from .json import JSON
+from .json import JSONIndexType
+from .json import JSONPathType
+from ... import exc
+from ... import Identity
+from ... import schema as sa_schema
+from ... import Sequence
+from ... import sql
+from ... import text
+from ... import types as sqltypes
+from ... import util
+from ...engine import cursor as _cursor
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import compiler
+from ...sql import elements
+from ...sql import expression
+from ...sql import func
+from ...sql import quoted_name
+from ...sql import roles
+from ...sql import util as sql_util
+from ...types import BIGINT
+from ...types import BINARY
+from ...types import CHAR
+from ...types import DATE
+from ...types import DATETIME
+from ...types import DECIMAL
+from ...types import FLOAT
+from ...types import INTEGER
+from ...types import NCHAR
+from ...types import NUMERIC
+from ...types import NVARCHAR
+from ...types import SMALLINT
+from ...types import TEXT
+from ...types import VARCHAR
+from ...util import compat
+from ...util import update_wrapper
+from ...util.langhelpers import public_factory
+
+
+# https://sqlserverbuilds.blogspot.com/
+MS_2017_VERSION = (14,)
+MS_2016_VERSION = (13,)
+MS_2014_VERSION = (12,)
+MS_2012_VERSION = (11,)
+MS_2008_VERSION = (10,)
+MS_2005_VERSION = (9,)
+MS_2000_VERSION = (8,)
+
+RESERVED_WORDS = set(
+ [
+ "add",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "authorization",
+ "backup",
+ "begin",
+ "between",
+ "break",
+ "browse",
+ "bulk",
+ "by",
+ "cascade",
+ "case",
+ "check",
+ "checkpoint",
+ "close",
+ "clustered",
+ "coalesce",
+ "collate",
+ "column",
+ "commit",
+ "compute",
+ "constraint",
+ "contains",
+ "containstable",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "dbcc",
+ "deallocate",
+ "declare",
+ "default",
+ "delete",
+ "deny",
+ "desc",
+ "disk",
+ "distinct",
+ "distributed",
+ "double",
+ "drop",
+ "dump",
+ "else",
+ "end",
+ "errlvl",
+ "escape",
+ "except",
+ "exec",
+ "execute",
+ "exists",
+ "exit",
+ "external",
+ "fetch",
+ "file",
+ "fillfactor",
+ "for",
+ "foreign",
+ "freetext",
+ "freetexttable",
+ "from",
+ "full",
+ "function",
+ "goto",
+ "grant",
+ "group",
+ "having",
+ "holdlock",
+ "identity",
+ "identity_insert",
+ "identitycol",
+ "if",
+ "in",
+ "index",
+ "inner",
+ "insert",
+ "intersect",
+ "into",
+ "is",
+ "join",
+ "key",
+ "kill",
+ "left",
+ "like",
+ "lineno",
+ "load",
+ "merge",
+ "national",
+ "nocheck",
+ "nonclustered",
+ "not",
+ "null",
+ "nullif",
+ "of",
+ "off",
+ "offsets",
+ "on",
+ "open",
+ "opendatasource",
+ "openquery",
+ "openrowset",
+ "openxml",
+ "option",
+ "or",
+ "order",
+ "outer",
+ "over",
+ "percent",
+ "pivot",
+ "plan",
+ "precision",
+ "primary",
+ "print",
+ "proc",
+ "procedure",
+ "public",
+ "raiserror",
+ "read",
+ "readtext",
+ "reconfigure",
+ "references",
+ "replication",
+ "restore",
+ "restrict",
+ "return",
+ "revert",
+ "revoke",
+ "right",
+ "rollback",
+ "rowcount",
+ "rowguidcol",
+ "rule",
+ "save",
+ "schema",
+ "securityaudit",
+ "select",
+ "session_user",
+ "set",
+ "setuser",
+ "shutdown",
+ "some",
+ "statistics",
+ "system_user",
+ "table",
+ "tablesample",
+ "textsize",
+ "then",
+ "to",
+ "top",
+ "tran",
+ "transaction",
+ "trigger",
+ "truncate",
+ "tsequal",
+ "union",
+ "unique",
+ "unpivot",
+ "update",
+ "updatetext",
+ "use",
+ "user",
+ "values",
+ "varying",
+ "view",
+ "waitfor",
+ "when",
+ "where",
+ "while",
+ "with",
+ "writetext",
+ ]
+)
+
+
+class REAL(sqltypes.REAL):
+ __visit_name__ = "REAL"
+
+ def __init__(self, **kw):
+ # REAL is a synonym for FLOAT(24) on SQL server.
+ # it is only accepted as the word "REAL" in DDL, the numeric
+ # precision value is not allowed to be present
+ kw.setdefault("precision", 24)
+ super(REAL, self).__init__(**kw)
+
+
+class TINYINT(sqltypes.Integer):
+ __visit_name__ = "TINYINT"
+
+
+# MSSQL DATE/TIME types have varied behavior, sometimes returning
+# strings. MSDate/TIME check for everything, and always
+# filter bind parameters into datetime objects (required by pyodbc,
+# not sure about other dialects).
+
+
+class _MSDate(sqltypes.Date):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+
+ return process
+
+ _reg = re.compile(r"(\d+)-(\d+)-(\d+)")
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.date()
+ elif isinstance(value, util.string_types):
+ m = self._reg.match(value)
+ if not m:
+ raise ValueError(
+ "could not parse %r as a date value" % (value,)
+ )
+ return datetime.date(*[int(x or 0) for x in m.groups()])
+ else:
+ return value
+
+ return process
+
+
+class TIME(sqltypes.TIME):
+ def __init__(self, precision=None, **kwargs):
+ self.precision = precision
+ super(TIME, self).__init__()
+
+ __zero_date = datetime.date(1900, 1, 1)
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ value = datetime.datetime.combine(
+ self.__zero_date, value.time()
+ )
+ elif isinstance(value, datetime.time):
+ """issue #5339
+ per: https://github.com/mkleehammer/pyodbc/wiki/Tips-and-Tricks-by-Database-Platform#time-columns
+ pass TIME value as string
+ """ # noqa
+ value = str(value)
+ return value
+
+ return process
+
+ _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?")
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if isinstance(value, datetime.datetime):
+ return value.time()
+ elif isinstance(value, util.string_types):
+ m = self._reg.match(value)
+ if not m:
+ raise ValueError(
+ "could not parse %r as a time value" % (value,)
+ )
+ return datetime.time(*[int(x or 0) for x in m.groups()])
+ else:
+ return value
+
+ return process
+
+
+_MSTime = TIME
+
+
+class _BASETIMEIMPL(TIME):
+ __visit_name__ = "_BASETIMEIMPL"
+
+
+class _DateTimeBase(object):
+ def bind_processor(self, dialect):
+ def process(value):
+ if type(value) == datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ else:
+ return value
+
+ return process
+
+
+class _MSDateTime(_DateTimeBase, sqltypes.DateTime):
+ pass
+
+
+class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = "SMALLDATETIME"
+
+
+class DATETIME2(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = "DATETIME2"
+
+ def __init__(self, precision=None, **kw):
+ super(DATETIME2, self).__init__(**kw)
+ self.precision = precision
+
+
+class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime):
+ __visit_name__ = "DATETIMEOFFSET"
+
+ def __init__(self, precision=None, **kw):
+ super(DATETIMEOFFSET, self).__init__(**kw)
+ self.precision = precision
+
+
+class _UnicodeLiteral(object):
+ def literal_processor(self, dialect):
+ def process(value):
+
+ value = value.replace("'", "''")
+
+ if dialect.identifier_preparer._double_percents:
+ value = value.replace("%", "%%")
+
+ return "N'%s'" % value
+
+ return process
+
+
+class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode):
+ pass
+
+
+class _MSUnicodeText(_UnicodeLiteral, sqltypes.UnicodeText):
+ pass
+
+
+class TIMESTAMP(sqltypes._Binary):
+ """Implement the SQL Server TIMESTAMP type.
+
+ Note this is **completely different** than the SQL Standard
+ TIMESTAMP type, which is not supported by SQL Server. It
+ is a read-only datatype that does not support INSERT of values.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :class:`_mssql.ROWVERSION`
+
+ """
+
+ __visit_name__ = "TIMESTAMP"
+
+ # expected by _Binary to be present
+ length = None
+
+ def __init__(self, convert_int=False):
+ """Construct a TIMESTAMP or ROWVERSION type.
+
+ :param convert_int: if True, binary integer values will
+ be converted to integers on read.
+
+ .. versionadded:: 1.2
+
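+        A usage sketch (the column name is illustrative)::
+
+            Column("rv", TIMESTAMP(convert_int=True))
+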
+ """
+ self.convert_int = convert_int
+
+ def result_processor(self, dialect, coltype):
+ super_ = super(TIMESTAMP, self).result_processor(dialect, coltype)
+ if self.convert_int:
+
+ def process(value):
+ value = super_(value)
+ if value is not None:
+ # https://stackoverflow.com/a/30403242/34549
+ value = int(codecs.encode(value, "hex"), 16)
+ return value
+
+ return process
+ else:
+ return super_
+
+
+class ROWVERSION(TIMESTAMP):
+ """Implement the SQL Server ROWVERSION type.
+
+ The ROWVERSION datatype is a SQL Server synonym for the TIMESTAMP
+ datatype, however current SQL Server documentation suggests using
+ ROWVERSION for new datatypes going forward.
+
+ The ROWVERSION datatype does **not** reflect (e.g. introspect) from the
+ database as itself; the returned datatype will be
+ :class:`_mssql.TIMESTAMP`.
+
+ This is a read-only datatype that does not support INSERT of values.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :class:`_mssql.TIMESTAMP`
+
+ """
+
+ __visit_name__ = "ROWVERSION"
+
+
+class NTEXT(sqltypes.UnicodeText):
+
+ """MSSQL NTEXT type, for variable-length unicode text up to 2^30
+ characters."""
+
+ __visit_name__ = "NTEXT"
+
+
+class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary):
+ """The MSSQL VARBINARY type.
+
+ This type adds additional features to the core :class:`_types.VARBINARY`
+ type, including "deprecate_large_types" mode where
+ either ``VARBINARY(max)`` or IMAGE is rendered, as well as the SQL
+ Server ``FILESTREAM`` option.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :ref:`mssql_large_type_deprecation`
+
+ """
+
+ __visit_name__ = "VARBINARY"
+
+ def __init__(self, length=None, filestream=False):
+ """
+ Construct a VARBINARY type.
+
+ :param length: optional, a length for the column for use in
+ DDL statements, for those binary types that accept a length,
+ such as the MySQL BLOB type.
+
+ :param filestream=False: if True, renders the ``FILESTREAM`` keyword
+ in the table definition. In this case ``length`` must be ``None``
+ or ``'max'``.
+
+ .. versionadded:: 1.4.31
+
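+        A usage sketch (the column name is illustrative)::
+
+            Column("photo", VARBINARY("max", filestream=True))
+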
+ """
+
+ self.filestream = filestream
+ if self.filestream and length not in (None, "max"):
+ raise ValueError(
+ "length must be None or 'max' when setting filestream"
+ )
+ super(VARBINARY, self).__init__(length=length)
+
+
+class IMAGE(sqltypes.LargeBinary):
+ __visit_name__ = "IMAGE"
+
+
+class XML(sqltypes.Text):
+ """MSSQL XML type.
+
+ This is a placeholder type for reflection purposes that does not include
+ any Python-side datatype support. It also does not currently support
+ additional arguments, such as "CONTENT", "DOCUMENT",
+ "xml_schema_collection".
+
+ .. versionadded:: 1.1.11
+
+ """
+
+ __visit_name__ = "XML"
+
+
+class BIT(sqltypes.Boolean):
+ """MSSQL BIT type.
+
+ Both pyodbc and pymssql return values from BIT columns as
+ Python <class 'bool'> so just subclass Boolean.
+
+ """
+
+ __visit_name__ = "BIT"
+
+
+class MONEY(sqltypes.TypeEngine):
+ __visit_name__ = "MONEY"
+
+
+class SMALLMONEY(sqltypes.TypeEngine):
+ __visit_name__ = "SMALLMONEY"
+
+
+class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
+
+
+class SQL_VARIANT(sqltypes.TypeEngine):
+ __visit_name__ = "SQL_VARIANT"
+
+
+class TryCast(sql.elements.Cast):
+ """Represent a SQL Server TRY_CAST expression."""
+
+ __visit_name__ = "try_cast"
+
+ stringify_dialect = "mssql"
+ inherit_cache = True
+
+ def __init__(self, *arg, **kw):
+ """Create a TRY_CAST expression.
+
+ :class:`.TryCast` is a subclass of SQLAlchemy's :class:`.Cast`
+ construct, and works in the same way, except that the SQL expression
+ rendered is "TRY_CAST" rather than "CAST"::
+
+ from sqlalchemy import select
+ from sqlalchemy import Numeric
+ from sqlalchemy.dialects.mssql import try_cast
+
+ stmt = select(
+ try_cast(product_table.c.unit_price, Numeric(10, 4))
+ )
+
+ The above would render::
+
+ SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4))
+ FROM product_table
+
+ .. versionadded:: 1.3.7
+
+ """
+ super(TryCast, self).__init__(*arg, **kw)
+
+
+try_cast = public_factory(TryCast, ".dialects.mssql.try_cast")
+
+# old names.
+MSDateTime = _MSDateTime
+MSDate = _MSDate
+MSReal = REAL
+MSTinyInteger = TINYINT
+MSTime = TIME
+MSSmallDateTime = SMALLDATETIME
+MSDateTime2 = DATETIME2
+MSDateTimeOffset = DATETIMEOFFSET
+MSText = TEXT
+MSNText = NTEXT
+MSString = VARCHAR
+MSNVarchar = NVARCHAR
+MSChar = CHAR
+MSNChar = NCHAR
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSImage = IMAGE
+MSBit = BIT
+MSMoney = MONEY
+MSSmallMoney = SMALLMONEY
+MSUniqueIdentifier = UNIQUEIDENTIFIER
+MSVariant = SQL_VARIANT
+
+ischema_names = {
+ "int": INTEGER,
+ "bigint": BIGINT,
+ "smallint": SMALLINT,
+ "tinyint": TINYINT,
+ "varchar": VARCHAR,
+ "nvarchar": NVARCHAR,
+ "char": CHAR,
+ "nchar": NCHAR,
+ "text": TEXT,
+ "ntext": NTEXT,
+ "decimal": DECIMAL,
+ "numeric": NUMERIC,
+ "float": FLOAT,
+ "datetime": DATETIME,
+ "datetime2": DATETIME2,
+ "datetimeoffset": DATETIMEOFFSET,
+ "date": DATE,
+ "time": TIME,
+ "smalldatetime": SMALLDATETIME,
+ "binary": BINARY,
+ "varbinary": VARBINARY,
+ "bit": BIT,
+ "real": REAL,
+ "image": IMAGE,
+ "xml": XML,
+ "timestamp": TIMESTAMP,
+ "money": MONEY,
+ "smallmoney": SMALLMONEY,
+ "uniqueidentifier": UNIQUEIDENTIFIER,
+ "sql_variant": SQL_VARIANT,
+}
+
+
+class MSTypeCompiler(compiler.GenericTypeCompiler):
+ def _extend(self, spec, type_, length=None):
+ """Extend a string-type declaration with standard SQL
+ COLLATE annotations.
+
+ """
+
+ if getattr(type_, "collation", None):
+ collation = "COLLATE %s" % type_.collation
+ else:
+ collation = None
+
+ if not length:
+ length = type_.length
+
+ if length:
+ spec = spec + "(%s)" % length
+
+ return " ".join([c for c in (spec, collation) if c is not None])
+
+ def visit_FLOAT(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is None:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {"precision": precision}
+
+ def visit_TINYINT(self, type_, **kw):
+ return "TINYINT"
+
+ def visit_TIME(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is not None:
+ return "TIME(%s)" % precision
+ else:
+ return "TIME"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ return "TIMESTAMP"
+
+ def visit_ROWVERSION(self, type_, **kw):
+ return "ROWVERSION"
+
+ def visit_datetime(self, type_, **kw):
+ if type_.timezone:
+ return self.visit_DATETIMEOFFSET(type_, **kw)
+ else:
+ return self.visit_DATETIME(type_, **kw)
+
+ def visit_DATETIMEOFFSET(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is not None:
+ return "DATETIMEOFFSET(%s)" % type_.precision
+ else:
+ return "DATETIMEOFFSET"
+
+ def visit_DATETIME2(self, type_, **kw):
+ precision = getattr(type_, "precision", None)
+ if precision is not None:
+ return "DATETIME2(%s)" % precision
+ else:
+ return "DATETIME2"
+
+ def visit_SMALLDATETIME(self, type_, **kw):
+ return "SMALLDATETIME"
+
+ def visit_unicode(self, type_, **kw):
+ return self.visit_NVARCHAR(type_, **kw)
+
+ def visit_text(self, type_, **kw):
+ if self.dialect.deprecate_large_types:
+ return self.visit_VARCHAR(type_, **kw)
+ else:
+ return self.visit_TEXT(type_, **kw)
+
+ def visit_unicode_text(self, type_, **kw):
+ if self.dialect.deprecate_large_types:
+ return self.visit_NVARCHAR(type_, **kw)
+ else:
+ return self.visit_NTEXT(type_, **kw)
+
+ def visit_NTEXT(self, type_, **kw):
+ return self._extend("NTEXT", type_)
+
+ def visit_TEXT(self, type_, **kw):
+ return self._extend("TEXT", type_)
+
+ def visit_VARCHAR(self, type_, **kw):
+ return self._extend("VARCHAR", type_, length=type_.length or "max")
+
+ def visit_CHAR(self, type_, **kw):
+ return self._extend("CHAR", type_)
+
+ def visit_NCHAR(self, type_, **kw):
+ return self._extend("NCHAR", type_)
+
+ def visit_NVARCHAR(self, type_, **kw):
+ return self._extend("NVARCHAR", type_, length=type_.length or "max")
+
+ def visit_date(self, type_, **kw):
+ if self.dialect.server_version_info < MS_2008_VERSION:
+ return self.visit_DATETIME(type_, **kw)
+ else:
+ return self.visit_DATE(type_, **kw)
+
+ def visit__BASETIMEIMPL(self, type_, **kw):
+ return self.visit_time(type_, **kw)
+
+ def visit_time(self, type_, **kw):
+ if self.dialect.server_version_info < MS_2008_VERSION:
+ return self.visit_DATETIME(type_, **kw)
+ else:
+ return self.visit_TIME(type_, **kw)
+
+ def visit_large_binary(self, type_, **kw):
+ if self.dialect.deprecate_large_types:
+ return self.visit_VARBINARY(type_, **kw)
+ else:
+ return self.visit_IMAGE(type_, **kw)
+
+ def visit_IMAGE(self, type_, **kw):
+ return "IMAGE"
+
+ def visit_XML(self, type_, **kw):
+ return "XML"
+
+ def visit_VARBINARY(self, type_, **kw):
+ text = self._extend("VARBINARY", type_, length=type_.length or "max")
+ if getattr(type_, "filestream", False):
+ text += " FILESTREAM"
+ return text
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BIT(type_)
+
+ def visit_BIT(self, type_, **kw):
+ return "BIT"
+
+ def visit_JSON(self, type_, **kw):
+ # this is a bit of a break with SQLAlchemy's convention of
+ # "UPPERCASE name goes to UPPERCASE type name with no modification"
+ return self._extend("NVARCHAR", type_, length="max")
+
+ def visit_MONEY(self, type_, **kw):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_, **kw):
+ return "SMALLMONEY"
+
+ def visit_UNIQUEIDENTIFIER(self, type_, **kw):
+ return "UNIQUEIDENTIFIER"
+
+ def visit_SQL_VARIANT(self, type_, **kw):
+ return "SQL_VARIANT"
+
+
+class MSExecutionContext(default.DefaultExecutionContext):
+ _enable_identity_insert = False
+ _select_lastrowid = False
+ _lastrowid = None
+ _rowcount = None
+
+ def _opt_encode(self, statement):
+
+ if not self.dialect.supports_unicode_statements:
+ encoded = self.dialect._encoder(statement)[0]
+ else:
+ encoded = statement
+
+ if self.compiled and self.compiled.schema_translate_map:
+
+ rst = self.compiled.preparer._render_schema_translates
+ encoded = rst(encoded, self.compiled.schema_translate_map)
+
+ return encoded
+
+ def pre_exec(self):
+ """Activate IDENTITY_INSERT if needed."""
+
+ if self.isinsert:
+ tbl = self.compiled.compile_state.dml_table
+ id_column = tbl._autoincrement_column
+ insert_has_identity = (id_column is not None) and (
+ not isinstance(id_column.default, Sequence)
+ )
+
+ if insert_has_identity:
+ compile_state = self.compiled.dml_compile_state
+ self._enable_identity_insert = (
+ id_column.key in self.compiled_parameters[0]
+ ) or (
+ compile_state._dict_parameters
+ and (id_column.key in compile_state._insert_col_keys)
+ )
+
+ else:
+ self._enable_identity_insert = False
+
+ self._select_lastrowid = (
+ not self.compiled.inline
+ and insert_has_identity
+ and not self.compiled.returning
+ and not self._enable_identity_insert
+ and not self.executemany
+ )
+
+ if self._enable_identity_insert:
+ self.root_connection._cursor_execute(
+ self.cursor,
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s ON"
+ % self.identifier_preparer.format_table(tbl)
+ ),
+ (),
+ self,
+ )
+
+ def post_exec(self):
+ """Disable IDENTITY_INSERT if enabled."""
+
+ conn = self.root_connection
+
+ if self.isinsert or self.isupdate or self.isdelete:
+ self._rowcount = self.cursor.rowcount
+
+ if self._select_lastrowid:
+ if self.dialect.use_scope_identity:
+ conn._cursor_execute(
+ self.cursor,
+ "SELECT scope_identity() AS lastrowid",
+ (),
+ self,
+ )
+ else:
+ conn._cursor_execute(
+ self.cursor, "SELECT @@identity AS lastrowid", (), self
+ )
+ # fetchall() ensures the cursor is consumed without closing it
+ row = self.cursor.fetchall()[0]
+ self._lastrowid = int(row[0])
+
+ elif (
+ self.isinsert or self.isupdate or self.isdelete
+ ) and self.compiled.returning:
+ self.cursor_fetch_strategy = (
+ _cursor.FullyBufferedCursorFetchStrategy(
+ self.cursor,
+ self.cursor.description,
+ self.cursor.fetchall(),
+ )
+ )
+
+ if self._enable_identity_insert:
+ conn._cursor_execute(
+ self.cursor,
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s OFF"
+ % self.identifier_preparer.format_table(
+ self.compiled.compile_state.dml_table
+ )
+ ),
+ (),
+ self,
+ )
+
+ def get_lastrowid(self):
+ return self._lastrowid
+
+ @property
+ def rowcount(self):
+ if self._rowcount is not None:
+ return self._rowcount
+ else:
+ return self.cursor.rowcount
+
+ def handle_dbapi_exception(self, e):
+ if self._enable_identity_insert:
+ try:
+ self.cursor.execute(
+ self._opt_encode(
+ "SET IDENTITY_INSERT %s OFF"
+ % self.identifier_preparer.format_table(
+ self.compiled.compile_state.dml_table
+ )
+ )
+ )
+ except Exception:
+ pass
+
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ (
+ "SELECT NEXT VALUE FOR %s"
+ % self.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
+
+ def get_insert_default(self, column):
+ if (
+ isinstance(column, sa_schema.Column)
+ and column is column.table._autoincrement_column
+ and isinstance(column.default, sa_schema.Sequence)
+ and column.default.optional
+ ):
+ return None
+ return super(MSExecutionContext, self).get_insert_default(column)
+
+
+class MSSQLCompiler(compiler.SQLCompiler):
+ returning_precedes_values = True
+
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {
+ "doy": "dayofyear",
+ "dow": "weekday",
+ "milliseconds": "millisecond",
+ "microseconds": "microsecond",
+ },
+ )
+
+ def __init__(self, *args, **kwargs):
+ self.tablealiases = {}
+ super(MSSQLCompiler, self).__init__(*args, **kwargs)
+
+ def _with_legacy_schema_aliasing(fn):
+ def decorate(self, *arg, **kw):
+ if self.dialect.legacy_schema_aliasing:
+ return fn(self, *arg, **kw)
+ else:
+ super_ = getattr(super(MSSQLCompiler, self), fn.__name__)
+ return super_(*arg, **kw)
+
+ return decorate
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_current_date_func(self, fn, **kw):
+ return "GETDATE()"
+
+ def visit_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_char_length_func(self, fn, **kw):
+ return "LEN%s" % self.function_argspec(fn, **kw)
+
+ def visit_concat_op_binary(self, binary, operator, **kw):
+ return "%s + %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_true(self, expr, **kw):
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ return "0"
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ return "CONTAINS (%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def get_select_precolumns(self, select, **kw):
+        """MS-SQL puts TOP, its version of LIMIT, here."""
+
+ s = super(MSSQLCompiler, self).get_select_precolumns(select, **kw)
+
+ if select._has_row_limiting_clause and self._use_top(select):
+ # ODBC drivers and possibly others
+ # don't support bind params in the SELECT clause on SQL Server.
+ # so have to use literal here.
+ kw["literal_execute"] = True
+ s += "TOP %s " % self.process(
+ self._get_limit_or_fetch(select), **kw
+ )
+ if select._fetch_clause is not None:
+ if select._fetch_clause_options["percent"]:
+ s += "PERCENT "
+ if select._fetch_clause_options["with_ties"]:
+ s += "WITH TIES "
+
+ return s
+
+ def get_from_hint_text(self, table, text):
+ return text
+
+ def get_crud_hint_text(self, table, text):
+ return text
+
+ def _get_limit_or_fetch(self, select):
+ if select._fetch_clause is None:
+ return select._limit_clause
+ else:
+ return select._fetch_clause
+
+ def _use_top(self, select):
+ return (select._offset_clause is None) and (
+ select._simple_int_clause(select._limit_clause)
+ or (
+                # limit can use TOP when it is by itself. fetch only uses TOP
+ # when it needs to because of PERCENT and/or WITH TIES
+ select._simple_int_clause(select._fetch_clause)
+ and (
+ select._fetch_clause_options["percent"]
+ or select._fetch_clause_options["with_ties"]
+ )
+ )
+ )
+
+ def fetch_clause(self, cs, **kwargs):
+ return ""
+
+ def limit_clause(self, cs, **kwargs):
+ return ""
+
+ def _check_can_use_fetch_limit(self, select):
+ # to use ROW_NUMBER(), an ORDER BY is required.
+        # OFFSET and FETCH are options of the ORDER BY clause
+ if not select._order_by_clause.clauses:
+ raise exc.CompileError(
+ "MSSQL requires an order_by when "
+ "using an OFFSET or a non-simple "
+ "LIMIT clause"
+ )
+
+ if select._fetch_clause_options is not None and (
+ select._fetch_clause_options["percent"]
+ or select._fetch_clause_options["with_ties"]
+ ):
+ raise exc.CompileError(
+ "MSSQL needs TOP to use PERCENT and/or WITH TIES. "
+ "Only simple fetch without offset can be used."
+ )
+
+ def _row_limit_clause(self, select, **kw):
+        """MSSQL 2012 supports OFFSET/FETCH operators.
+        Use them instead of a subquery with row_number.
+
+ """
+
+ if self.dialect._supports_offset_fetch and not self._use_top(select):
+ self._check_can_use_fetch_limit(select)
+
+ text = ""
+
+ if select._offset_clause is not None:
+ offset_str = self.process(select._offset_clause, **kw)
+ else:
+ offset_str = "0"
+ text += "\n OFFSET %s ROWS" % offset_str
+
+ limit = self._get_limit_or_fetch(select)
+
+ if limit is not None:
+ text += "\n FETCH FIRST %s ROWS ONLY" % self.process(
+ limit, **kw
+ )
+ return text
+ else:
+ return ""
+
+ def visit_try_cast(self, element, **kw):
+ return "TRY_CAST (%s AS %s)" % (
+ self.process(element.clause, **kw),
+ self.process(element.typeclause, **kw),
+ )
+
+ def translate_select_structure(self, select_stmt, **kwargs):
+        """Look for ``LIMIT`` and OFFSET in a select statement, and if
+        so, try to wrap it in a subquery with a ``row_number()`` criterion.
+        MSSQL 2012 and above are excluded.
+
+ """
+ select = select_stmt
+
+ if (
+ select._has_row_limiting_clause
+ and not self.dialect._supports_offset_fetch
+ and not self._use_top(select)
+ and not getattr(select, "_mssql_visit", None)
+ ):
+ self._check_can_use_fetch_limit(select)
+
+ _order_by_clauses = [
+ sql_util.unwrap_label_reference(elem)
+ for elem in select._order_by_clause.clauses
+ ]
+
+ limit_clause = self._get_limit_or_fetch(select)
+ offset_clause = select._offset_clause
+
+ select = select._generate()
+ select._mssql_visit = True
+ select = (
+ select.add_columns(
+ sql.func.ROW_NUMBER()
+ .over(order_by=_order_by_clauses)
+ .label("mssql_rn")
+ )
+ .order_by(None)
+ .alias()
+ )
+
+ mssql_rn = sql.column("mssql_rn")
+ limitselect = sql.select(
+ *[c for c in select.c if c.key != "mssql_rn"]
+ )
+ if offset_clause is not None:
+ limitselect = limitselect.where(mssql_rn > offset_clause)
+ if limit_clause is not None:
+ limitselect = limitselect.where(
+ mssql_rn <= (limit_clause + offset_clause)
+ )
+ else:
+ limitselect = limitselect.where(mssql_rn <= (limit_clause))
+ return limitselect
+ else:
+ return select
+
+ @_with_legacy_schema_aliasing
+ def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs):
+ if mssql_aliased is table or iscrud:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ # alias schema-qualified tables
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=table, **kwargs)
+ else:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ @_with_legacy_schema_aliasing
+ def visit_alias(self, alias, **kw):
+ # translate for schema-qualified table aliases
+ kw["mssql_aliased"] = alias.element
+ return super(MSSQLCompiler, self).visit_alias(alias, **kw)
+
+ @_with_legacy_schema_aliasing
+ def visit_column(self, column, add_to_result_map=None, **kw):
+ if (
+ column.table is not None
+ and (not self.isupdate and not self.isdelete)
+ or self.is_subquery()
+ ):
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ converted = elements._corresponding_column_or_error(t, column)
+ if add_to_result_map is not None:
+ add_to_result_map(
+ column.name,
+ column.name,
+ (column, column.name, column.key),
+ column.type,
+ )
+
+ return super(MSSQLCompiler, self).visit_column(converted, **kw)
+
+ return super(MSSQLCompiler, self).visit_column(
+ column, add_to_result_map=add_to_result_map, **kw
+ )
+
+ def _schema_aliased_table(self, table):
+ if getattr(table, "schema", None) is not None:
+ if table not in self.tablealiases:
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_extract(self, extract, **kw):
+ field = self.extract_map.get(extract.field, extract.field)
+ return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw))
+
+ def visit_savepoint(self, savepoint_stmt):
+ return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+ def visit_binary(self, binary, **kwargs):
+ """Move bind parameters to the right-hand side of an operator, where
+ possible.
+
+ """
+ if (
+ isinstance(binary.left, expression.BindParameter)
+ and binary.operator == operator.eq
+ and not isinstance(binary.right, expression.BindParameter)
+ ):
+ return self.process(
+ expression.BinaryExpression(
+ binary.right, binary.left, binary.operator
+ ),
+ **kwargs
+ )
+ return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
+
+ def returning_clause(self, stmt, returning_cols):
+ # SQL server returning clause requires that the columns refer to
+ # the virtual table names "inserted" or "deleted". Here, we make
+ # a simple alias of our table with that name, and then adapt the
+ # columns we have from the list of RETURNING columns to that new name
+ # so that they render as "inserted.<colname>" / "deleted.<colname>".
+
+ if self.isinsert or self.isupdate:
+ target = stmt.table.alias("inserted")
+ else:
+ target = stmt.table.alias("deleted")
+
+ adapter = sql_util.ClauseAdapter(target)
+
+ # adapter.traverse() takes a column from our target table and returns
+ # the one that is linked to the "inserted" / "deleted" tables. So in
+ # order to retrieve these values back from the result (e.g. like
+ # row[column]), tell the compiler to also add the original unadapted
+ # column to the result map. Before #4877, these were (unknowingly)
+ # falling back using string name matching in the result set which
+ # necessarily used an expensive KeyError in order to match.
+
+ columns = [
+ self._label_returning_column(
+ stmt,
+ adapter.traverse(c),
+ {"result_map_targets": (c,)},
+ )
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return "OUTPUT " + ", ".join(columns)
+
+ def get_cte_preamble(self, recursive):
+ # SQL Server finds it too inconvenient to accept
+ # an entirely optional, SQL standard specified,
+ # "RECURSIVE" word with their "WITH",
+ # so here we go
+ return "WITH"
+
+ def label_select_column(self, select, column, asfrom):
+ if isinstance(column, expression.Function):
+ return column.label(None)
+ else:
+ return super(MSSQLCompiler, self).label_select_column(
+ select, column, asfrom
+ )
+
+ def for_update_clause(self, select, **kw):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which
+ # SQLAlchemy doesn't use
+ return ""
+
+ def order_by_clause(self, select, **kw):
+ # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if (
+ self.is_subquery()
+ and not select._limit
+ and (
+ select._offset is None
+ or not self.dialect._supports_offset_fetch
+ )
+ ):
+ # avoid processing the order by clause if we won't end up
+ # using it, because we don't want all the bind params tacked
+ # onto the positional list if that is what the dbapi requires
+ return ""
+
+ order_by = self.process(select._order_by_clause, **kw)
+
+ if order_by:
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the UPDATE..FROM clause specific to MSSQL.
+
+ In MSSQL, if the UPDATE statement involves an alias of the table to
+ be updated, then the table itself must be added to the FROM list as
+ well. Otherwise, it is optional. Here, we add it regardless.
+
+ """
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ """If we have extra froms make sure we render any alias as hint."""
+ ashint = False
+ if extra_froms:
+ ashint = True
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, ashint=ashint
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. FROM clause specific to MSSQL.
+
+ Yes, it has the FROM keyword twice.
+
+ """
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+ def visit_empty_set_expr(self, type_):
+ return "SELECT 1 WHERE 1!=1"
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "NOT EXISTS (SELECT %s INTERSECT SELECT %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "EXISTS (SELECT %s INTERSECT SELECT %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def _render_json_extract_from_binary(self, binary, operator, **kw):
+ # note we are intentionally calling upon the process() calls in the
+ # order in which they appear in the SQL String as this is used
+ # by positional parameter rendering
+
+ if binary.type._type_affinity is sqltypes.JSON:
+ return "JSON_QUERY(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ # as with other dialects, start with an explicit test for NULL
+ case_expression = "CASE JSON_VALUE(%s, %s) WHEN NULL THEN NULL" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ if binary.type._type_affinity is sqltypes.Integer:
+ type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS INTEGER)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ elif binary.type._type_affinity is sqltypes.Numeric:
+ type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ "FLOAT"
+ if isinstance(binary.type, sqltypes.Float)
+ else "NUMERIC(%s, %s)"
+ % (binary.type.precision, binary.type.scale),
+ )
+ elif binary.type._type_affinity is sqltypes.Boolean:
+ # the NULL handling is particularly weird with boolean, so
+ # explicitly return numeric (BIT) constants
+ type_expression = (
+ "WHEN 'true' THEN 1 WHEN 'false' THEN 0 ELSE NULL"
+ )
+ elif binary.type._type_affinity is sqltypes.String:
+ # TODO: does this comment (from mysql) apply to here, too?
+ # this fails with a JSON value that's a four byte unicode
+ # string. SQLite has the same problem at the moment
+ type_expression = "ELSE JSON_VALUE(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ else:
+ # other affinity....this is not expected right now
+ type_expression = "ELSE JSON_QUERY(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ return case_expression + " " + type_expression + " END"
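+    # e.g. an element access typed as Integer, such as
+    # data_table.c.data["some key"].as_integer(), renders roughly as:
+    #   CASE JSON_VALUE(data, :p) WHEN NULL THEN NULL
+    #   ELSE CAST(JSON_VALUE(data, :p) AS INTEGER) END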
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_sequence(self, seq, **kw):
+ return "NEXT VALUE FOR %s" % self.preparer.format_sequence(seq)
+
+
+class MSSQLStrictCompiler(MSSQLCompiler):
+
+ """A subclass of MSSQLCompiler which disables the usage of bind
+ parameters where not allowed natively by MS-SQL.
+
+ A dialect may use this compiler on a platform where native
+ binds are used.
+
+ """
+
+ ansi_bind_rules = True
+
+ def visit_in_op_binary(self, binary, operator, **kw):
+ kw["literal_execute"] = True
+ return "%s IN %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_not_in_op_binary(self, binary, operator, **kw):
+ kw["literal_execute"] = True
+ return "%s NOT IN %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def render_literal_value(self, value, type_):
+ """
+ For date and datetime values, convert to a string
+ format acceptable to MSSQL. That seems to be the
+ so-called ODBC canonical date format which looks
+ like this:
+
+ yyyy-mm-dd hh:mi:ss.mmm(24h)
+
+ For other data types, call the base class implementation.
+ """
+ # datetime and date are both subclasses of datetime.date
+ if issubclass(type(value), datetime.date):
+ # SQL Server wants single quotes around the date string.
+ return "'" + str(value) + "'"
+ else:
+ return super(MSSQLStrictCompiler, self).render_literal_value(
+ value, type_
+ )
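+    # e.g. datetime.date(2021, 3, 5) would be rendered inline as
+    # '2021-03-05' rather than being passed as a bind parameter.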
+
+
+class MSDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column)
+
+ # type is not accepted in a computed column
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+ else:
+ colspec += " " + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+
+ if column.nullable is not None:
+ if (
+ not column.nullable
+ or column.primary_key
+ or isinstance(column.default, sa_schema.Sequence)
+ or column.autoincrement is True
+ or column.identity
+ ):
+ colspec += " NOT NULL"
+ elif column.computed is None:
+ # don't specify "NULL" for computed columns
+ colspec += " NULL"
+
+ if column.table is None:
+ raise exc.CompileError(
+ "mssql requires Table-bound columns "
+ "in order to generate DDL"
+ )
+
+ d_opt = column.dialect_options["mssql"]
+ start = d_opt["identity_start"]
+ increment = d_opt["identity_increment"]
+ if start is not None or increment is not None:
+ if column.identity:
+ raise exc.CompileError(
+ "Cannot specify options 'mssql_identity_start' and/or "
+ "'mssql_identity_increment' while also using the "
+ "'Identity' construct."
+ )
+ util.warn_deprecated(
+ "The dialect options 'mssql_identity_start' and "
+ "'mssql_identity_increment' are deprecated. "
+ "Use the 'Identity' object instead.",
+ "1.4",
+ )
+
+ if column.identity:
+ colspec += self.process(column.identity, **kwargs)
+ elif (
+ column is column.table._autoincrement_column
+ or column.autoincrement is True
+ ) and (
+ not isinstance(column.default, Sequence) or column.default.optional
+ ):
+ colspec += self.process(Identity(start=start, increment=increment))
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ return colspec
+
+ def visit_create_index(self, create, include_schema=False):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+
+ # handle clustering option
+ clustered = index.dialect_options["mssql"]["clustered"]
+ if clustered is not None:
+ if clustered:
+ text += "CLUSTERED "
+ else:
+ text += "NONCLUSTERED "
+
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(index.table),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+
+ # handle other included columns
+ if index.dialect_options["mssql"]["include"]:
+ inclusions = [
+ index.table.c[col]
+ if isinstance(col, util.string_types)
+ else col
+ for col in index.dialect_options["mssql"]["include"]
+ ]
+
+ text += " INCLUDE (%s)" % ", ".join(
+ [preparer.quote(c.name) for c in inclusions]
+ )
+
+ whereclause = index.dialect_options["mssql"]["where"]
+
+ if whereclause is not None:
+ whereclause = coercions.expect(
+ roles.DDLExpressionRole, whereclause
+ )
+
+ where_compiled = self.sql_compiler.process(
+ whereclause, include_table=False, literal_binds=True
+ )
+ text += " WHERE " + where_compiled
+
+ return text
+
+ def visit_drop_index(self, drop):
+ return "\nDROP INDEX %s ON %s" % (
+ self._prepared_index_name(drop.element, include_schema=False),
+ self.preparer.format_table(drop.element.table),
+ )
+
+ def visit_primary_key_constraint(self, constraint):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
+ text += "PRIMARY KEY "
+
+ clustered = constraint.dialect_options["mssql"]["clustered"]
+ if clustered is not None:
+ if clustered:
+ text += "CLUSTERED "
+ else:
+ text += "NONCLUSTERED "
+
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_unique_constraint(self, constraint):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "UNIQUE "
+
+ clustered = constraint.dialect_options["mssql"]["clustered"]
+ if clustered is not None:
+ if clustered:
+ text += "CLUSTERED "
+ else:
+ text += "NONCLUSTERED "
+
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name) for c in constraint
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_computed_column(self, generated):
+ text = "AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+ # explicitly check for True|False since None means server default
+ if generated.persisted is True:
+ text += " PERSISTED"
+ return text
+
+ def visit_create_sequence(self, create, **kw):
+ prefix = None
+ if create.element.data_type is not None:
+ data_type = create.element.data_type
+ prefix = " AS %s" % self.type_compiler.process(data_type)
+ return super(MSDDLCompiler, self).visit_create_sequence(
+ create, prefix=prefix, **kw
+ )
+
+ def visit_identity_column(self, identity, **kw):
+ text = " IDENTITY"
+ if identity.start is not None or identity.increment is not None:
+ start = 1 if identity.start is None else identity.start
+ increment = 1 if identity.increment is None else identity.increment
+ text += "(%s,%s)" % (start, increment)
+ return text
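+    # e.g. Identity(start=100, increment=10) renders " IDENTITY(100,10)",
+    # while a bare Identity() renders just " IDENTITY" and leaves the
+    # server-side (1,1) defaults in place.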
+
+
+class MSIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS
+
+ def __init__(self, dialect):
+ super(MSIdentifierPreparer, self).__init__(
+ dialect,
+ initial_quote="[",
+ final_quote="]",
+ quote_case_sensitive_collations=False,
+ )
+
+ def _escape_identifier(self, value):
+ return value.replace("]", "]]")
+
+ def _unescape_identifier(self, value):
+ return value.replace("]]", "]")
+
+ def quote_schema(self, schema, force=None):
+ """Prepare a quoted table and schema name."""
+
+ # need to re-implement the deprecation warning entirely
+ if force is not None:
+ # not using the util.deprecated_params() decorator in this
+ # case because of the additional function call overhead on this
+ # very performance-critical spot.
+ util.warn_deprecated(
+ "The IdentifierPreparer.quote_schema.force parameter is "
+ "deprecated and will be removed in a future release. This "
+ "flag has no effect on the behavior of the "
+ "IdentifierPreparer.quote method; please refer to "
+ "quoted_name().",
+ version="1.3",
+ )
+
+ dbname, owner = _schema_elements(schema)
+ if dbname:
+ result = "%s.%s" % (self.quote(dbname), self.quote(owner))
+ elif owner:
+ result = self.quote(owner)
+ else:
+ result = ""
+ return result
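+    # e.g. a schema passed as "mydb.dbo" is split into database and owner
+    # parts and each is quoted individually when quoting is required,
+    # rather than quoting the dotted name as one token.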
+
+
+def _db_plus_owner_listing(fn):
+ def wrap(dialect, connection, schema=None, **kw):
+ dbname, owner = _owner_plus_db(dialect, schema)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
+ return update_wrapper(wrap, fn)
+
+
+def _db_plus_owner(fn):
+ def wrap(dialect, connection, tablename, schema=None, **kw):
+ dbname, owner = _owner_plus_db(dialect, schema)
+ return _switch_db(
+ dbname,
+ connection,
+ fn,
+ dialect,
+ connection,
+ tablename,
+ dbname,
+ owner,
+ schema,
+ **kw
+ )
+
+ return update_wrapper(wrap, fn)
+
+
+def _switch_db(dbname, connection, fn, *arg, **kw):
+ if dbname:
+ current_db = connection.exec_driver_sql("select db_name()").scalar()
+ if current_db != dbname:
+ connection.exec_driver_sql(
+ "use %s" % connection.dialect.identifier_preparer.quote(dbname)
+ )
+ try:
+ return fn(*arg, **kw)
+ finally:
+ if dbname and current_db != dbname:
+ connection.exec_driver_sql(
+ "use %s"
+ % connection.dialect.identifier_preparer.quote(current_db)
+ )
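+# e.g. a reflection call made with schema="otherdb.dbo" roughly runs
+# "use otherdb" before the wrapped inspection query and switches back to the
+# original database afterwards; with no database component, nothing is
+# switched.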
+
+
+def _owner_plus_db(dialect, schema):
+ if not schema:
+ return None, dialect.default_schema_name
+ elif "." in schema:
+ return _schema_elements(schema)
+ else:
+ return None, schema
+
+
+_memoized_schema = util.LRUCache()
+
+
+def _schema_elements(schema):
+ if isinstance(schema, quoted_name) and schema.quote:
+ return None, schema
+
+ if schema in _memoized_schema:
+ return _memoized_schema[schema]
+
+ # tests for this function are in:
+ # test/dialect/mssql/test_reflection.py ->
+ # OwnerPlusDBTest.test_owner_database_pairs
+ # test/dialect/mssql/test_compiler.py -> test_force_schema_*
+ # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_*
+ #
+
+ if schema.startswith("__[SCHEMA_"):
+ return None, schema
+
+ push = []
+ symbol = ""
+ bracket = False
+ has_brackets = False
+ for token in re.split(r"(\[|\]|\.)", schema):
+ if not token:
+ continue
+ if token == "[":
+ bracket = True
+ has_brackets = True
+ elif token == "]":
+ bracket = False
+ elif not bracket and token == ".":
+ if has_brackets:
+ push.append("[%s]" % symbol)
+ else:
+ push.append(symbol)
+ symbol = ""
+ has_brackets = False
+ else:
+ symbol += token
+ if symbol:
+ push.append(symbol)
+ if len(push) > 1:
+ dbname, owner = ".".join(push[0:-1]), push[-1]
+
+ # test for internal brackets
+ if re.match(r".*\].*\[.*", dbname[1:-1]):
+ dbname = quoted_name(dbname, quote=False)
+ else:
+ dbname = dbname.lstrip("[").rstrip("]")
+
+ elif len(push):
+ dbname, owner = None, push[0]
+ else:
+ dbname, owner = None, None
+
+ _memoized_schema[schema] = dbname, owner
+ return dbname, owner
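+# e.g. _schema_elements("mydb.dbo") returns ("mydb", "dbo"), while
+# "[my.db].dbo" treats the dot inside the brackets as part of the database
+# name and returns roughly ("my.db", "dbo"); a plain "dbo" yields
+# (None, "dbo").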
+
+
+class MSDialect(default.DefaultDialect):
+ # will assume it's at least mssql2005
+ name = "mssql"
+ supports_statement_cache = True
+ supports_default_values = True
+ supports_empty_insert = False
+ execution_ctx_cls = MSExecutionContext
+ use_scope_identity = True
+ max_identifier_length = 128
+ schema_name = "dbo"
+
+ implicit_returning = True
+ full_returning = True
+
+ colspecs = {
+ sqltypes.DateTime: _MSDateTime,
+ sqltypes.Date: _MSDate,
+ sqltypes.JSON: JSON,
+ sqltypes.JSON.JSONIndexType: JSONIndexType,
+ sqltypes.JSON.JSONPathType: JSONPathType,
+ sqltypes.Time: _BASETIMEIMPL,
+ sqltypes.Unicode: _MSUnicode,
+ sqltypes.UnicodeText: _MSUnicodeText,
+ DATETIMEOFFSET: DATETIMEOFFSET,
+ DATETIME2: DATETIME2,
+ SMALLDATETIME: SMALLDATETIME,
+ DATETIME: DATETIME,
+ }
+
+ engine_config_types = default.DefaultDialect.engine_config_types.union(
+ {"legacy_schema_aliasing": util.asbool}
+ )
+
+ ischema_names = ischema_names
+
+ supports_sequences = True
+ sequences_optional = True
+ # T-SQL's actual default is -9223372036854775808
+ default_sequence_base = 1
+
+ supports_native_boolean = False
+ non_native_boolean_check_constraint = False
+ supports_unicode_binds = True
+ postfetch_lastrowid = True
+ _supports_offset_fetch = False
+ _supports_nvarchar_max = False
+
+ legacy_schema_aliasing = False
+
+ server_version_info = ()
+
+ statement_compiler = MSSQLCompiler
+ ddl_compiler = MSDDLCompiler
+ type_compiler = MSTypeCompiler
+ preparer = MSIdentifierPreparer
+
+ construct_arguments = [
+ (sa_schema.PrimaryKeyConstraint, {"clustered": None}),
+ (sa_schema.UniqueConstraint, {"clustered": None}),
+ (sa_schema.Index, {"clustered": None, "include": None, "where": None}),
+ (
+ sa_schema.Column,
+ {"identity_start": None, "identity_increment": None},
+ ),
+ ]
+
+ def __init__(
+ self,
+ query_timeout=None,
+ use_scope_identity=True,
+ schema_name="dbo",
+ isolation_level=None,
+ deprecate_large_types=None,
+ json_serializer=None,
+ json_deserializer=None,
+ legacy_schema_aliasing=None,
+ ignore_no_transaction_on_rollback=False,
+ **opts
+ ):
+ self.query_timeout = int(query_timeout or 0)
+ self.schema_name = schema_name
+
+ self.use_scope_identity = use_scope_identity
+ self.deprecate_large_types = deprecate_large_types
+ self.ignore_no_transaction_on_rollback = (
+ ignore_no_transaction_on_rollback
+ )
+
+ if legacy_schema_aliasing is not None:
+ util.warn_deprecated(
+ "The legacy_schema_aliasing parameter is "
+ "deprecated and will be removed in a future release.",
+ "1.4",
+ )
+ self.legacy_schema_aliasing = legacy_schema_aliasing
+
+ super(MSDialect, self).__init__(**opts)
+
+ self.isolation_level = isolation_level
+ self._json_serializer = json_serializer
+ self._json_deserializer = json_deserializer
+
+ def do_savepoint(self, connection, name):
+ # give the DBAPI a push
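+        # i.e. open a transaction if none is in progress, so the following
+        # SAVE TRANSACTION emitted by do_savepoint() has something to
+        # attach to.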
+ connection.exec_driver_sql("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
+ super(MSDialect, self).do_savepoint(connection, name)
+
+ def do_release_savepoint(self, connection, name):
+ # SQL Server does not support RELEASE SAVEPOINT
+ pass
+
+ def do_rollback(self, dbapi_connection):
+ try:
+ super(MSDialect, self).do_rollback(dbapi_connection)
+ except self.dbapi.ProgrammingError as e:
+ if self.ignore_no_transaction_on_rollback and re.match(
+ r".*\b111214\b", str(e)
+ ):
+ util.warn(
+ "ProgrammingError 111214 "
+ "'No corresponding transaction found.' "
+ "has been suppressed via "
+ "ignore_no_transaction_on_rollback=True"
+ )
+ else:
+ raise
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "SNAPSHOT",
+ ]
+ )
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+ if level not in self._isolation_lookup:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+ cursor = connection.cursor()
+ cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level)
+ cursor.close()
+ if level == "SNAPSHOT":
+ connection.commit()
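+    # e.g. an engine created with isolation_level="SNAPSHOT" (see
+    # on_connect() below) would emit
+    # "SET TRANSACTION ISOLATION LEVEL SNAPSHOT" on each new connection,
+    # followed by a commit as handled above.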
+
+ def get_isolation_level(self, dbapi_connection):
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute(
+ "SELECT name FROM sys.system_views WHERE name IN "
+ "('dm_exec_sessions', 'dm_pdw_nodes_exec_sessions')"
+ )
+ row = cursor.fetchone()
+ if not row:
+ raise NotImplementedError(
+ "Can't fetch isolation level on this particular "
+ "SQL Server version."
+ )
+
+ view_name = "sys.{}".format(row[0])
+ cursor.execute(
+ """
+ SELECT CASE transaction_isolation_level
+ WHEN 0 THEN NULL
+ WHEN 1 THEN 'READ UNCOMMITTED'
+ WHEN 2 THEN 'READ COMMITTED'
+ WHEN 3 THEN 'REPEATABLE READ'
+ WHEN 4 THEN 'SERIALIZABLE'
+ WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL
+ FROM {}
+ where session_id = @@SPID
+ """.format(
+ view_name
+ )
+ )
+ row = cursor.fetchone()
+ assert row is not None
+ val = row[0]
+ finally:
+ cursor.close()
+ return val.upper()
+
+ def initialize(self, connection):
+ super(MSDialect, self).initialize(connection)
+ self._setup_version_attributes()
+ self._setup_supports_nvarchar_max(connection)
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ def _setup_version_attributes(self):
+ if self.server_version_info[0] not in list(range(8, 17)):
+ util.warn(
+ "Unrecognized server version info '%s'. Some SQL Server "
+ "features may not function properly."
+ % ".".join(str(x) for x in self.server_version_info)
+ )
+
+ if self.server_version_info >= MS_2008_VERSION:
+ self.supports_multivalues_insert = True
+ if self.deprecate_large_types is None:
+ self.deprecate_large_types = (
+ self.server_version_info >= MS_2012_VERSION
+ )
+
+ self._supports_offset_fetch = (
+ self.server_version_info and self.server_version_info[0] >= 11
+ )
+
+ def _setup_supports_nvarchar_max(self, connection):
+ try:
+ connection.scalar(
+ sql.text("SELECT CAST('test max support' AS NVARCHAR(max))")
+ )
+ except exc.DBAPIError:
+ self._supports_nvarchar_max = False
+ else:
+ self._supports_nvarchar_max = True
+
+ def _get_default_schema_name(self, connection):
+ query = sql.text("SELECT schema_name()")
+ default_schema_name = connection.scalar(query)
+ if default_schema_name is not None:
+ # guard against the case where the default_schema_name is being
+ # fed back into a table reflection function.
+ return quoted_name(default_schema_name, quote=True)
+ else:
+ return self.schema_name
+
+ @_db_plus_owner
+ def has_table(self, connection, tablename, dbname, owner, schema):
+ self._ensure_has_table_connection(connection)
+ if tablename.startswith("#"): # temporary table
+ tables = ischema.mssql_temp_table_columns
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
+ )
+ )
+
+ # #7168: fetch all (not just first match) in case some other #temp
+ # table with the same name happens to appear first
+ table_names = connection.execute(s).scalars().fetchall()
+ # #6910: verify it's not a temp table from another session
+ for table_name in table_names:
+ if bool(
+ connection.scalar(
+ text("SELECT object_id(:table_name)"),
+ {"table_name": "tempdb.dbo.[{}]".format(table_name)},
+ )
+ ):
+ return True
+ else:
+ return False
+ else:
+ tables = ischema.tables
+
+ s = sql.select(tables.c.table_name).where(
+ sql.and_(
+ tables.c.table_type == "BASE TABLE",
+ tables.c.table_name == tablename,
+ )
+ )
+
+ if owner:
+ s = s.where(tables.c.table_schema == owner)
+
+ c = connection.execute(s)
+
+ return c.first() is not None
+
+ @_db_plus_owner
+ def has_sequence(self, connection, sequencename, dbname, owner, schema):
+ sequences = ischema.sequences
+
+ s = sql.select(sequences.c.sequence_name).where(
+ sequences.c.sequence_name == sequencename
+ )
+
+ if owner:
+ s = s.where(sequences.c.sequence_schema == owner)
+
+ c = connection.execute(s)
+
+ return c.first() is not None
+
+ @reflection.cache
+ @_db_plus_owner_listing
+ def get_sequence_names(self, connection, dbname, owner, schema, **kw):
+ sequences = ischema.sequences
+
+ s = sql.select(sequences.c.sequence_name)
+ if owner:
+ s = s.where(sequences.c.sequence_schema == owner)
+
+ c = connection.execute(s)
+
+ return [row[0] for row in c]
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = sql.select(ischema.schemata.c.schema_name).order_by(
+ ischema.schemata.c.schema_name
+ )
+ schema_names = [r[0] for r in connection.execute(s)]
+ return schema_names
+
+ @reflection.cache
+ @_db_plus_owner_listing
+ def get_table_names(self, connection, dbname, owner, schema, **kw):
+ tables = ischema.tables
+ s = (
+ sql.select(tables.c.table_name)
+ .where(
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == "BASE TABLE",
+ )
+ )
+ .order_by(tables.c.table_name)
+ )
+ table_names = [r[0] for r in connection.execute(s)]
+ return table_names
+
+ @reflection.cache
+ @_db_plus_owner_listing
+ def get_view_names(self, connection, dbname, owner, schema, **kw):
+ tables = ischema.tables
+ s = (
+ sql.select(tables.c.table_name)
+ .where(
+ sql.and_(
+ tables.c.table_schema == owner,
+ tables.c.table_type == "VIEW",
+ )
+ )
+ .order_by(tables.c.table_name)
+ )
+ view_names = [r[0] for r in connection.execute(s)]
+ return view_names
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_indexes(self, connection, tablename, dbname, owner, schema, **kw):
+ filter_definition = (
+ "ind.filter_definition"
+ if self.server_version_info >= MS_2008_VERSION
+ else "NULL as filter_definition"
+ )
+ rp = connection.execution_options(future_result=True).execute(
+ sql.text(
+ "select ind.index_id, ind.is_unique, ind.name, "
+ "%s "
+ "from sys.indexes as ind join sys.tables as tab on "
+ "ind.object_id=tab.object_id "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name = :tabname "
+ "and sch.name=:schname "
+ "and ind.is_primary_key=0 and ind.type != 0"
+ % filter_definition
+ )
+ .bindparams(
+ sql.bindparam("tabname", tablename, ischema.CoerceUnicode()),
+ sql.bindparam("schname", owner, ischema.CoerceUnicode()),
+ )
+ .columns(name=sqltypes.Unicode())
+ )
+ indexes = {}
+ for row in rp.mappings():
+ indexes[row["index_id"]] = {
+ "name": row["name"],
+ "unique": row["is_unique"] == 1,
+ "column_names": [],
+ "include_columns": [],
+ }
+
+ if row["filter_definition"] is not None:
+ indexes[row["index_id"]].setdefault("dialect_options", {})[
+ "mssql_where"
+ ] = row["filter_definition"]
+
+ rp = connection.execution_options(future_result=True).execute(
+ sql.text(
+ "select ind_col.index_id, ind_col.object_id, col.name, "
+ "ind_col.is_included_column "
+ "from sys.columns as col "
+ "join sys.tables as tab on tab.object_id=col.object_id "
+ "join sys.index_columns as ind_col on "
+ "(ind_col.column_id=col.column_id and "
+ "ind_col.object_id=tab.object_id) "
+ "join sys.schemas as sch on sch.schema_id=tab.schema_id "
+ "where tab.name=:tabname "
+ "and sch.name=:schname"
+ )
+ .bindparams(
+ sql.bindparam("tabname", tablename, ischema.CoerceUnicode()),
+ sql.bindparam("schname", owner, ischema.CoerceUnicode()),
+ )
+ .columns(name=sqltypes.Unicode())
+ )
+ for row in rp.mappings():
+ if row["index_id"] in indexes:
+ if row["is_included_column"]:
+ indexes[row["index_id"]]["include_columns"].append(
+ row["name"]
+ )
+ else:
+ indexes[row["index_id"]]["column_names"].append(
+ row["name"]
+ )
+ for index_info in indexes.values():
+ # NOTE: "root level" include_columns is legacy, now part of
+ # dialect_options (issue #7382)
+ index_info.setdefault("dialect_options", {})[
+ "mssql_include"
+ ] = index_info["include_columns"]
+
+ return list(indexes.values())
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_view_definition(
+ self, connection, viewname, dbname, owner, schema, **kw
+ ):
+ rp = connection.execute(
+ sql.text(
+ "select definition from sys.sql_modules as mod, "
+ "sys.views as views, "
+ "sys.schemas as sch"
+ " where "
+ "mod.object_id=views.object_id and "
+ "views.schema_id=sch.schema_id and "
+ "views.name=:viewname and sch.name=:schname"
+ ).bindparams(
+ sql.bindparam("viewname", viewname, ischema.CoerceUnicode()),
+ sql.bindparam("schname", owner, ischema.CoerceUnicode()),
+ )
+ )
+
+ if rp:
+ view_def = rp.scalar()
+ return view_def
+
+ def _temp_table_name_like_pattern(self, tablename):
+ # LIKE uses '%' to match zero or more characters and '_' to match any
+ # single character. We want to match literal underscores, so T-SQL
+ # requires that we enclose them in square brackets.
+ return tablename + (
+ ("[_][_][_]%") if not tablename.startswith("##") else ""
+ )
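+    # e.g. "#some_table" yields roughly "#some_table[_][_][_]%", matching
+    # the underscore-padded name SQL Server stores for session-scoped temp
+    # tables in tempdb; global temp tables ("##...") are matched as-is.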
+
+ def _get_internal_temp_table_name(self, connection, tablename):
+ # it's likely that schema is always "dbo", but since we can
+ # get it here, let's get it.
+ # see https://stackoverflow.com/questions/8311959/
+ # specifying-schema-for-temporary-tables
+
+ try:
+ return connection.execute(
+ sql.text(
+ "select table_schema, table_name "
+ "from tempdb.information_schema.tables "
+ "where table_name like :p1"
+ ),
+ {"p1": self._temp_table_name_like_pattern(tablename)},
+ ).one()
+ except exc.MultipleResultsFound as me:
+ util.raise_(
+ exc.UnreflectableTableError(
+ "Found more than one temporary table named '%s' in tempdb "
+ "at this time. Cannot reliably resolve that name to its "
+ "internal table name." % tablename
+ ),
+ replace_context=me,
+ )
+ except exc.NoResultFound as ne:
+ util.raise_(
+ exc.NoSuchTableError(
+ "Unable to find a temporary table named '%s' in tempdb."
+ % tablename
+ ),
+ replace_context=ne,
+ )
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
+ is_temp_table = tablename.startswith("#")
+ if is_temp_table:
+ owner, tablename = self._get_internal_temp_table_name(
+ connection, tablename
+ )
+
+ columns = ischema.mssql_temp_table_columns
+ else:
+ columns = ischema.columns
+
+ computed_cols = ischema.computed_columns
+ identity_cols = ischema.identity_columns
+ if owner:
+ whereclause = sql.and_(
+ columns.c.table_name == tablename,
+ columns.c.table_schema == owner,
+ )
+ full_name = columns.c.table_schema + "." + columns.c.table_name
+ else:
+ whereclause = columns.c.table_name == tablename
+ full_name = columns.c.table_name
+
+ join = columns.join(
+ computed_cols,
+ onclause=sql.and_(
+ computed_cols.c.object_id == func.object_id(full_name),
+ computed_cols.c.name
+ == columns.c.column_name.collate("DATABASE_DEFAULT"),
+ ),
+ isouter=True,
+ ).join(
+ identity_cols,
+ onclause=sql.and_(
+ identity_cols.c.object_id == func.object_id(full_name),
+ identity_cols.c.name
+ == columns.c.column_name.collate("DATABASE_DEFAULT"),
+ ),
+ isouter=True,
+ )
+
+ if self._supports_nvarchar_max:
+ computed_definition = computed_cols.c.definition
+ else:
+ # tds_version 4.2 does not support NVARCHAR(MAX)
+ computed_definition = sql.cast(
+ computed_cols.c.definition, NVARCHAR(4000)
+ )
+
+ s = (
+ sql.select(
+ columns,
+ computed_definition,
+ computed_cols.c.is_persisted,
+ identity_cols.c.is_identity,
+ identity_cols.c.seed_value,
+ identity_cols.c.increment_value,
+ )
+ .where(whereclause)
+ .select_from(join)
+ .order_by(columns.c.ordinal_position)
+ )
+
+ c = connection.execution_options(future_result=True).execute(s)
+
+ cols = []
+ for row in c.mappings():
+ name = row[columns.c.column_name]
+ type_ = row[columns.c.data_type]
+ nullable = row[columns.c.is_nullable] == "YES"
+ charlen = row[columns.c.character_maximum_length]
+ numericprec = row[columns.c.numeric_precision]
+ numericscale = row[columns.c.numeric_scale]
+ default = row[columns.c.column_default]
+ collation = row[columns.c.collation_name]
+ definition = row[computed_definition]
+ is_persisted = row[computed_cols.c.is_persisted]
+ is_identity = row[identity_cols.c.is_identity]
+ identity_start = row[identity_cols.c.seed_value]
+ identity_increment = row[identity_cols.c.increment_value]
+
+ coltype = self.ischema_names.get(type_, None)
+
+ kwargs = {}
+ if coltype in (
+ MSString,
+ MSChar,
+ MSNVarchar,
+ MSNChar,
+ MSText,
+ MSNText,
+ MSBinary,
+ MSVarBinary,
+ sqltypes.LargeBinary,
+ ):
+ if charlen == -1:
+ charlen = None
+ kwargs["length"] = charlen
+ if collation:
+ kwargs["collation"] = collation
+
+ if coltype is None:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (type_, name)
+ )
+ coltype = sqltypes.NULLTYPE
+ else:
+ if issubclass(coltype, sqltypes.Numeric):
+ kwargs["precision"] = numericprec
+
+ if not issubclass(coltype, sqltypes.Float):
+ kwargs["scale"] = numericscale
+
+ coltype = coltype(**kwargs)
+ cdict = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": is_identity is not None,
+ }
+
+ if definition is not None and is_persisted is not None:
+ cdict["computed"] = {
+ "sqltext": definition,
+ "persisted": is_persisted,
+ }
+
+ if is_identity is not None:
+ # identity_start and identity_increment are Decimal or None
+ if identity_start is None or identity_increment is None:
+ cdict["identity"] = {}
+ else:
+ if isinstance(coltype, sqltypes.BigInteger):
+ start = compat.long_type(identity_start)
+ increment = compat.long_type(identity_increment)
+ elif isinstance(coltype, sqltypes.Integer):
+ start = int(identity_start)
+ increment = int(identity_increment)
+ else:
+ start = identity_start
+ increment = identity_increment
+
+ cdict["identity"] = {
+ "start": start,
+ "increment": increment,
+ }
+
+ cols.append(cdict)
+
+ return cols
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_pk_constraint(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
+ pkeys = []
+ TC = ischema.constraints
+ C = ischema.key_constraints.alias("C")
+
+ # Primary key constraints
+ s = (
+ sql.select(
+ C.c.column_name, TC.c.constraint_type, C.c.constraint_name
+ )
+ .where(
+ sql.and_(
+ TC.c.constraint_name == C.c.constraint_name,
+ TC.c.table_schema == C.c.table_schema,
+ C.c.table_name == tablename,
+ C.c.table_schema == owner,
+ ),
+ )
+ .order_by(TC.c.constraint_name, C.c.ordinal_position)
+ )
+ c = connection.execution_options(future_result=True).execute(s)
+ constraint_name = None
+ for row in c.mappings():
+ if "PRIMARY" in row[TC.c.constraint_type.name]:
+ pkeys.append(row["COLUMN_NAME"])
+ if constraint_name is None:
+ constraint_name = row[C.c.constraint_name.name]
+ return {"constrained_columns": pkeys, "name": constraint_name}
+
+ @reflection.cache
+ @_db_plus_owner
+ def get_foreign_keys(
+ self, connection, tablename, dbname, owner, schema, **kw
+ ):
+ # Foreign key constraints
+ s = (
+ text(
+ """\
+WITH fk_info AS (
+ SELECT
+ ischema_ref_con.constraint_schema,
+ ischema_ref_con.constraint_name,
+ ischema_key_col.ordinal_position,
+ ischema_key_col.table_schema,
+ ischema_key_col.table_name,
+ ischema_ref_con.unique_constraint_schema,
+ ischema_ref_con.unique_constraint_name,
+ ischema_ref_con.match_option,
+ ischema_ref_con.update_rule,
+ ischema_ref_con.delete_rule,
+ ischema_key_col.column_name AS constrained_column
+ FROM
+ INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS ischema_ref_con
+ INNER JOIN
+ INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col ON
+ ischema_key_col.table_schema = ischema_ref_con.constraint_schema
+ AND ischema_key_col.constraint_name =
+ ischema_ref_con.constraint_name
+ WHERE ischema_key_col.table_name = :tablename
+ AND ischema_key_col.table_schema = :owner
+),
+constraint_info AS (
+ SELECT
+ ischema_key_col.constraint_schema,
+ ischema_key_col.constraint_name,
+ ischema_key_col.ordinal_position,
+ ischema_key_col.table_schema,
+ ischema_key_col.table_name,
+ ischema_key_col.column_name
+ FROM
+ INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col
+),
+index_info AS (
+ SELECT
+ sys.schemas.name AS index_schema,
+ sys.indexes.name AS index_name,
+ sys.index_columns.key_ordinal AS ordinal_position,
+ sys.schemas.name AS table_schema,
+ sys.objects.name AS table_name,
+ sys.columns.name AS column_name
+ FROM
+ sys.indexes
+ INNER JOIN
+ sys.objects ON
+ sys.objects.object_id = sys.indexes.object_id
+ INNER JOIN
+ sys.schemas ON
+ sys.schemas.schema_id = sys.objects.schema_id
+ INNER JOIN
+ sys.index_columns ON
+ sys.index_columns.object_id = sys.objects.object_id
+ AND sys.index_columns.index_id = sys.indexes.index_id
+ INNER JOIN
+ sys.columns ON
+ sys.columns.object_id = sys.indexes.object_id
+ AND sys.columns.column_id = sys.index_columns.column_id
+)
+ SELECT
+ fk_info.constraint_schema,
+ fk_info.constraint_name,
+ fk_info.ordinal_position,
+ fk_info.constrained_column,
+ constraint_info.table_schema AS referred_table_schema,
+ constraint_info.table_name AS referred_table_name,
+ constraint_info.column_name AS referred_column,
+ fk_info.match_option,
+ fk_info.update_rule,
+ fk_info.delete_rule
+ FROM
+ fk_info INNER JOIN constraint_info ON
+ constraint_info.constraint_schema =
+ fk_info.unique_constraint_schema
+ AND constraint_info.constraint_name =
+ fk_info.unique_constraint_name
+ AND constraint_info.ordinal_position = fk_info.ordinal_position
+ UNION
+ SELECT
+ fk_info.constraint_schema,
+ fk_info.constraint_name,
+ fk_info.ordinal_position,
+ fk_info.constrained_column,
+ index_info.table_schema AS referred_table_schema,
+ index_info.table_name AS referred_table_name,
+ index_info.column_name AS referred_column,
+ fk_info.match_option,
+ fk_info.update_rule,
+ fk_info.delete_rule
+ FROM
+ fk_info INNER JOIN index_info ON
+ index_info.index_schema = fk_info.unique_constraint_schema
+ AND index_info.index_name = fk_info.unique_constraint_name
+ AND index_info.ordinal_position = fk_info.ordinal_position
+
+ ORDER BY fk_info.constraint_schema, fk_info.constraint_name,
+ fk_info.ordinal_position
+"""
+ )
+ .bindparams(
+ sql.bindparam("tablename", tablename, ischema.CoerceUnicode()),
+ sql.bindparam("owner", owner, ischema.CoerceUnicode()),
+ )
+ .columns(
+ constraint_schema=sqltypes.Unicode(),
+ constraint_name=sqltypes.Unicode(),
+ table_schema=sqltypes.Unicode(),
+ table_name=sqltypes.Unicode(),
+ constrained_column=sqltypes.Unicode(),
+ referred_table_schema=sqltypes.Unicode(),
+ referred_table_name=sqltypes.Unicode(),
+ referred_column=sqltypes.Unicode(),
+ )
+ )
+
+ # group rows by constraint ID, to handle multi-column FKs
+ fkeys = []
+
+ def fkey_rec():
+ return {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ "options": {},
+ }
+
+ fkeys = util.defaultdict(fkey_rec)
+
+ for r in connection.execute(s).fetchall():
+ (
+ _, # constraint schema
+ rfknm,
+ _, # ordinal position
+ scol,
+ rschema,
+ rtbl,
+ rcol,
+                # TODO: SQLAlchemy supports match=<keyword> for foreign keys,
+                # so we could support it here as well; PG has match=FULL, for
+                # example, but that does not appear to be a valid value for
+                # SQL Server
+ _, # match rule
+ fkuprule,
+ fkdelrule,
+ ) = r
+
+ rec = fkeys[rfknm]
+ rec["name"] = rfknm
+
+ if fkuprule != "NO ACTION":
+ rec["options"]["onupdate"] = fkuprule
+
+ if fkdelrule != "NO ACTION":
+ rec["options"]["ondelete"] = fkdelrule
+
+ if not rec["referred_table"]:
+ rec["referred_table"] = rtbl
+ if schema is not None or owner != rschema:
+ if dbname:
+ rschema = dbname + "." + rschema
+ rec["referred_schema"] = rschema
+
+ local_cols, remote_cols = (
+ rec["constrained_columns"],
+ rec["referred_columns"],
+ )
+
+ local_cols.append(scol)
+ remote_cols.append(rcol)
+
+ return list(fkeys.values())
diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py
new file mode 100644
index 0000000..df91493
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/information_schema.py
@@ -0,0 +1,232 @@
+# mssql/information_schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import cast
+from ... import Column
+from ... import MetaData
+from ... import Table
+from ... import util
+from ...ext.compiler import compiles
+from ...sql import expression
+from ...types import Boolean
+from ...types import Integer
+from ...types import Numeric
+from ...types import String
+from ...types import TypeDecorator
+from ...types import Unicode
+
+
+ischema = MetaData()
+
+
+class CoerceUnicode(TypeDecorator):
+ impl = Unicode
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect):
+ if util.py2k and isinstance(value, util.binary_type):
+ value = value.decode(dialect.encoding)
+ return value
+
+ def bind_expression(self, bindvalue):
+ return _cast_on_2005(bindvalue)
+
+
+class _cast_on_2005(expression.ColumnElement):
+ def __init__(self, bindvalue):
+ self.bindvalue = bindvalue
+
+
+@compiles(_cast_on_2005)
+def _compile(element, compiler, **kw):
+ from . import base
+
+ if (
+ compiler.dialect.server_version_info is None
+ or compiler.dialect.server_version_info < base.MS_2005_VERSION
+ ):
+ return compiler.process(element.bindvalue, **kw)
+ else:
+ return compiler.process(cast(element.bindvalue, Unicode), **kw)
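+# e.g. on SQL Server 2005 and later a bound catalog name is wrapped in a
+# CAST to the Unicode type before being compared against the
+# INFORMATION_SCHEMA / sys views; on older versions the bind parameter is
+# passed through unchanged.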
+
+
+schemata = Table(
+ "SCHEMATA",
+ ischema,
+ Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
+ Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
+ Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
+ schema="INFORMATION_SCHEMA",
+)
+
+tables = Table(
+ "TABLES",
+ ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("TABLE_TYPE", CoerceUnicode, key="table_type"),
+ schema="INFORMATION_SCHEMA",
+)
+
+columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+mssql_temp_table_columns = Table(
+ "COLUMNS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("IS_NULLABLE", Integer, key="is_nullable"),
+ Column("DATA_TYPE", String, key="data_type"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ Column(
+ "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+ ),
+ Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+ Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+ Column("COLUMN_DEFAULT", Integer, key="column_default"),
+ Column("COLLATION_NAME", String, key="collation_name"),
+ schema="tempdb.INFORMATION_SCHEMA",
+)
+
+constraints = Table(
+ "TABLE_CONSTRAINTS",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"),
+ schema="INFORMATION_SCHEMA",
+)
+
+column_constraints = Table(
+ "CONSTRAINT_COLUMN_USAGE",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+key_constraints = Table(
+ "KEY_COLUMN_USAGE",
+ ischema,
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+ Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+ schema="INFORMATION_SCHEMA",
+)
+
+ref_constraints = Table(
+ "REFERENTIAL_CONSTRAINTS",
+ ischema,
+ Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
+ Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
+ Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
+ # TODO: is CATLOG misspelled ?
+ Column(
+ "UNIQUE_CONSTRAINT_CATLOG",
+ CoerceUnicode,
+ key="unique_constraint_catalog",
+ ),
+ Column(
+ "UNIQUE_CONSTRAINT_SCHEMA",
+ CoerceUnicode,
+ key="unique_constraint_schema",
+ ),
+ Column(
+ "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"
+ ),
+ Column("MATCH_OPTION", String, key="match_option"),
+ Column("UPDATE_RULE", String, key="update_rule"),
+ Column("DELETE_RULE", String, key="delete_rule"),
+ schema="INFORMATION_SCHEMA",
+)
+
+views = Table(
+ "VIEWS",
+ ischema,
+ Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
+ Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+ Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+ Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
+ Column("CHECK_OPTION", String, key="check_option"),
+ Column("IS_UPDATABLE", String, key="is_updatable"),
+ schema="INFORMATION_SCHEMA",
+)
+
+computed_columns = Table(
+ "computed_columns",
+ ischema,
+ Column("object_id", Integer),
+ Column("name", CoerceUnicode),
+ Column("is_computed", Boolean),
+ Column("is_persisted", Boolean),
+ Column("definition", CoerceUnicode),
+ schema="sys",
+)
+
+sequences = Table(
+ "SEQUENCES",
+ ischema,
+ Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"),
+ Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"),
+ Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"),
+ schema="INFORMATION_SCHEMA",
+)
+
+
+class IdentitySqlVariant(TypeDecorator):
+ r"""This type casts sql_variant columns in the identity_columns view
+ to numeric. This is required because:
+
+ * pyodbc does not support sql_variant
+    * pymssql under python 2 returns the byte representation of the number;
+      int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the
+      correct value as a string.
+ """
+ impl = Unicode
+ cache_ok = True
+
+ def column_expression(self, colexpr):
+ return cast(colexpr, Numeric)
+
+
+identity_columns = Table(
+ "identity_columns",
+ ischema,
+ Column("object_id", Integer),
+ Column("name", CoerceUnicode),
+ Column("is_identity", Boolean),
+ Column("seed_value", IdentitySqlVariant),
+ Column("increment_value", IdentitySqlVariant),
+ Column("last_value", IdentitySqlVariant),
+ Column("is_not_for_replication", Boolean),
+ schema="sys",
+)
diff --git a/lib/sqlalchemy/dialects/mssql/json.py b/lib/sqlalchemy/dialects/mssql/json.py
new file mode 100644
index 0000000..d515731
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/json.py
@@ -0,0 +1,125 @@
+from ... import types as sqltypes
+
+# technically, all the dialect-specific datatypes that don't have any special
+# behaviors would be private with names like _MSJson. However, we haven't been
+# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType
+# / JSONPathType in their json.py files, so keep consistent with that
+# sub-convention for now. A future change can update them all to be
+# package-private at once.
+
+
+class JSON(sqltypes.JSON):
+ """MSSQL JSON type.
+
+ MSSQL supports JSON-formatted data as of SQL Server 2016.
+
+ The :class:`_mssql.JSON` datatype at the DDL level will represent the
+ datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison
+ functions as well as Python coercion behavior.
+
+ :class:`_mssql.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a SQL Server backend.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The :class:`_mssql.JSON` type supports persistence of JSON values
+ as well as the core index operations provided by :class:`_types.JSON`
+ datatype, by adapting the operations to render the ``JSON_VALUE``
+ or ``JSON_QUERY`` functions at the database level.
+
+ The SQL Server :class:`_mssql.JSON` type necessarily makes use of the
+ ``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements
+ of a JSON object. These two functions have a major restriction in that
+ they are **mutually exclusive** based on the type of object to be returned.
+ The ``JSON_QUERY`` function **only** returns a JSON dictionary or list,
+ but not an individual string, numeric, or boolean element; the
+ ``JSON_VALUE`` function **only** returns an individual string, numeric,
+    or boolean element. **Both functions either return NULL or raise
+ an error if they are not used against the correct expected value**.
+
+ To handle this awkward requirement, indexed access rules are as follows:
+
+ 1. When extracting a sub element from a JSON that is itself a JSON
+ dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor
+ should be used::
+
+ stmt = select(
+ data_table.c.data["some key"].as_json()
+ ).where(
+ data_table.c.data["some key"].as_json() == {"sub": "structure"}
+ )
+
+ 2. When extracting a sub element from a JSON that is a plain boolean,
+ string, integer, or float, use the appropriate method among
+ :meth:`_types.JSON.Comparator.as_boolean`,
+ :meth:`_types.JSON.Comparator.as_string`,
+ :meth:`_types.JSON.Comparator.as_integer`,
+ :meth:`_types.JSON.Comparator.as_float`::
+
+ stmt = select(
+ data_table.c.data["some key"].as_string()
+ ).where(
+ data_table.c.data["some key"].as_string() == "some string"
+ )
+
+ .. versionadded:: 1.4
+
+
+ """
+
+ # note there was a result processor here that was looking for "number",
+ # but none of the tests seem to exercise it.
+
+
+# Note: these objects currently match exactly those of MySQL, however since
+# these are not generalizable to all JSON implementations, remain separately
+# implemented for each dialect.
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+ def _format_value(self, value):
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
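+    # e.g. an integer index 1 is formatted as "$[1]", while a string key
+    # such as "some key" becomes '$."some key"'.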
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
+ return "$%s" % (
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
+ )
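+    # e.g. a path of ["data", 5, "key"] is formatted as '$."data"[5]."key"'.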
diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py
new file mode 100644
index 0000000..95c32d4
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py
@@ -0,0 +1,150 @@
+# mssql/mxodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: mssql+mxodbc
+ :name: mxODBC
+ :dbapi: mxodbc
+ :connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
+ :url: https://www.egenix.com/
+
+.. deprecated:: 1.4 The mxODBC DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to mssql.
+
+Execution Modes
+---------------
+
+mxODBC features two styles of statement execution, using the
+``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
+an extension to the DBAPI specification). The former makes use of a particular
+API call specific to the SQL Server Native Client ODBC driver known as
+SQLDescribeParam, while the latter does not.
+
+mxODBC apparently only makes repeated use of a single prepared statement
+when SQLDescribeParam is used. The advantage to prepared statement reuse is
+one of performance. The disadvantage is that SQLDescribeParam has a limited
+set of scenarios in which bind parameters are understood, including that they
+cannot be placed within the argument lists of function calls, anywhere outside
+the FROM clause, or even within subqueries within the FROM clause - making the
+usage of bind parameters within SELECT statements impossible for all but the
+most simplistic statements.
+
+For this reason, the mxODBC dialect uses the "native" mode by default only for
+INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
+all other statements.
+
+This behavior can be controlled via
+:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
+``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
+value of ``True`` will unconditionally use native bind parameters and a value
+of ``False`` will unconditionally use string-escaped parameters.
+
+"""
+
+
+from .base import _MSDate
+from .base import _MSDateTime
+from .base import _MSTime
+from .base import MSDialect
+from .base import VARBINARY
+from .pyodbc import _MSNumeric_pyodbc
+from .pyodbc import MSExecutionContext_pyodbc
+from ... import types as sqltypes
+from ...connectors.mxodbc import MxODBCConnector
+
+
+class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
+ """Include pyodbc's numeric processor."""
+
+
+class _MSDate_mxodbc(_MSDate):
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is not None:
+ return "%s-%s-%s" % (value.year, value.month, value.day)
+ else:
+ return None
+
+ return process
+
+
+class _MSTime_mxodbc(_MSTime):
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is not None:
+ return "%s:%s:%s" % (value.hour, value.minute, value.second)
+ else:
+ return None
+
+ return process
+
+
+class _VARBINARY_mxodbc(VARBINARY):
+
+ """
+ mxODBC Support for VARBINARY column types.
+
+ This handles the special case for null VARBINARY values,
+ which maps None values to the mx.ODBC.Manager.BinaryNull symbol.
+ """
+
+ def bind_processor(self, dialect):
+ if dialect.dbapi is None:
+ return None
+
+ DBAPIBinary = dialect.dbapi.Binary
+
+ def process(value):
+ if value is not None:
+ return DBAPIBinary(value)
+ else:
+ # should pull from mx.ODBC.Manager.BinaryNull
+ return dialect.dbapi.BinaryNull
+
+ return process
+
+
+class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
+ """
+ The pyodbc execution context is useful for enabling
+ SELECT SCOPE_IDENTITY in cases where OUTPUT clause
+ does not work (tables with insert triggers).
+ """
+
+ # todo - investigate whether the pyodbc execution context
+ # is really only being used in cases where OUTPUT
+ # won't work.
+
+
+class MSDialect_mxodbc(MxODBCConnector, MSDialect):
+
+ # this is only needed if "native ODBC" mode is used,
+ # which is now disabled by default.
+ # statement_compiler = MSSQLStrictCompiler
+ supports_statement_cache = True
+
+ execution_ctx_cls = MSExecutionContext_mxodbc
+
+ # flag used by _MSNumeric_mxodbc
+ _need_decimal_fix = True
+
+ colspecs = {
+ sqltypes.Numeric: _MSNumeric_mxodbc,
+ sqltypes.DateTime: _MSDateTime,
+ sqltypes.Date: _MSDate_mxodbc,
+ sqltypes.Time: _MSTime_mxodbc,
+ VARBINARY: _VARBINARY_mxodbc,
+ sqltypes.LargeBinary: _VARBINARY_mxodbc,
+ }
+
+ def __init__(self, description_encoding=None, **params):
+ super(MSDialect_mxodbc, self).__init__(**params)
+ self.description_encoding = description_encoding
+
+
+dialect = MSDialect_mxodbc
diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py
new file mode 100644
index 0000000..56f3305
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/provision.py
@@ -0,0 +1,116 @@
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from ... import create_engine
+from ... import exc
+from ...schema import Column
+from ...schema import DropConstraint
+from ...schema import ForeignKeyConstraint
+from ...schema import MetaData
+from ...schema import Table
+from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_pre_tables
+from ...testing.provision import drop_db
+from ...testing.provision import get_temp_table_name
+from ...testing.provision import log
+from ...testing.provision import run_reap_dbs
+from ...testing.provision import temp_table_keyword_args
+
+
+@create_db.for_db("mssql")
+def _mssql_create_db(cfg, eng, ident):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ conn.exec_driver_sql("create database %s" % ident)
+ conn.exec_driver_sql(
+ "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
+ )
+ conn.exec_driver_sql(
+ "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
+ )
+ conn.exec_driver_sql("use %s" % ident)
+ conn.exec_driver_sql("create schema test_schema")
+ conn.exec_driver_sql("create schema test_schema_2")
+
+
+@drop_db.for_db("mssql")
+def _mssql_drop_db(cfg, eng, ident):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ _mssql_drop_ignore(conn, ident)
+
+
+def _mssql_drop_ignore(conn, ident):
+ try:
+ # typically when this happens, we can't KILL the session anyway,
+ # so let the cleanup process drop the DBs
+ # for row in conn.exec_driver_sql(
+ # "select session_id from sys.dm_exec_sessions "
+ # "where database_id=db_id('%s')" % ident):
+ # log.info("killing SQL server session %s", row['session_id'])
+ # conn.exec_driver_sql("kill %s" % row['session_id'])
+ conn.exec_driver_sql("drop database %s" % ident)
+ log.info("Reaped db: %s", ident)
+ return True
+ except exc.DatabaseError as err:
+ log.warning("couldn't drop db: %s", err)
+ return False
+
+
+@run_reap_dbs.for_db("mssql")
+def _reap_mssql_dbs(url, idents):
+ log.info("db reaper connecting to %r", url)
+ eng = create_engine(url)
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+
+ log.info("identifiers in file: %s", ", ".join(idents))
+
+ to_reap = conn.exec_driver_sql(
+ "select d.name from sys.databases as d where name "
+ "like 'TEST_%' and not exists (select session_id "
+ "from sys.dm_exec_sessions "
+ "where database_id=d.database_id)"
+ )
+ all_names = {dbname.lower() for (dbname,) in to_reap}
+ to_drop = set()
+ for name in all_names:
+ if name in idents:
+ to_drop.add(name)
+
+ dropped = total = 0
+ for total, dbname in enumerate(to_drop, 1):
+ if _mssql_drop_ignore(conn, dbname):
+ dropped += 1
+ log.info(
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
+
+
+@temp_table_keyword_args.for_db("mssql")
+def _mssql_temp_table_keyword_args(cfg, eng):
+ return {}
+
+
+@get_temp_table_name.for_db("mssql")
+def _mssql_get_temp_table_name(cfg, eng, base_name):
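+ # the "##" prefix denotes a global temporary table in SQL Server
+ # (visible to other sessions), as opposed to a session-local "#" table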
+ return "##" + base_name
+
+
+@drop_all_schema_objects_pre_tables.for_db("mssql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ inspector = inspect(conn)
+ for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
+ for tname in inspector.get_table_names(schema=schema):
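+ # build a throwaway Table / ForeignKeyConstraint so that DropConstraint
+ # can render "ALTER TABLE ... DROP CONSTRAINT <name>" for each reflected
+ # foreign key; the placeholder "x" / "y" columns are never emitted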
+ tb = Table(
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint(
+ [tb.c.x], [tb.c.y], name=fk["name"]
+ )
+ )
+ )
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
new file mode 100644
index 0000000..84c5fed
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/pymssql.py
@@ -0,0 +1,138 @@
+# mssql/pymssql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: mssql+pymssql
+ :name: pymssql
+ :dbapi: pymssql
+ :connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?charset=utf8
+
+pymssql is a Python module that provides a Python DBAPI interface around
+`FreeTDS <https://www.freetds.org/>`_.
+
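+A basic connection following the connectstring form above (the hostname and
+database name here are placeholders) looks like::
+
+    engine = create_engine(
+        "mssql+pymssql://scott:tiger@some_freetds_host:1433/mydatabase"
+    )
+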
+.. note::
+
+ pymssql is currently not included in SQLAlchemy's continuous integration
+ (CI) testing.
+
+Modern versions of this driver worked very well with SQL Server and FreeTDS
+from Linux and were highly recommended. However, pymssql is currently
+unmaintained and has fallen behind the progress of the Microsoft ODBC driver in
+its support for newer features of SQL Server. The latest official release of
+pymssql at the time of this document is version 2.1.4 (August, 2018) and it
+lacks support for:
+
+1. table-valued parameters (TVPs),
+2. ``datetimeoffset`` columns using timezone-aware ``datetime`` objects
+ (values are sent and retrieved as strings), and
+3. encrypted connections (e.g., to Azure SQL), when pymssql is installed from
+ the pre-built wheels. Support for encrypted connections requires building
+ pymssql from source, which can be a nuisance, especially under Windows.
+
+The above features are all supported by mssql+pyodbc when using Microsoft's
+ODBC Driver for SQL Server (msodbcsql), which is now available for Windows,
+(several flavors of) Linux, and macOS.
+
+
+""" # noqa
+import re
+
+from .base import MSDialect
+from .base import MSIdentifierPreparer
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+
+
+class _MSNumeric_pymssql(sqltypes.Numeric):
+ def result_processor(self, dialect, type_):
+ if not self.asdecimal:
+ return processors.to_float
+ else:
+ return sqltypes.Numeric.result_processor(self, dialect, type_)
+
+
+class MSIdentifierPreparer_pymssql(MSIdentifierPreparer):
+ def __init__(self, dialect):
+ super(MSIdentifierPreparer_pymssql, self).__init__(dialect)
+ # pymssql has the very unusual behavior that it uses pyformat
+ # yet does not require that percent signs be doubled
+ self._double_percents = False
+
+
+class MSDialect_pymssql(MSDialect):
+ supports_statement_cache = True
+ supports_native_decimal = True
+ driver = "pymssql"
+
+ preparer = MSIdentifierPreparer_pymssql
+
+ colspecs = util.update_copy(
+ MSDialect.colspecs,
+ {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float},
+ )
+
+ @classmethod
+ def dbapi(cls):
+ module = __import__("pymssql")
+ # pymssql < 2.1.1 doesn't have a Binary method; we use a string fallback
+ client_ver = tuple(int(x) for x in module.__version__.split("."))
+ if client_ver < (2, 1, 1):
+ # TODO: monkeypatching here is less than ideal
+ module.Binary = lambda x: x if hasattr(x, "decode") else str(x)
+
+ if client_ver < (1,):
+ util.warn(
+ "The pymssql dialect expects at least "
+ "the 1.0 series of the pymssql DBAPI."
+ )
+ return module
+
+ def _get_server_version_info(self, connection):
+ vers = connection.exec_driver_sql("select @@version").scalar()
+ m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3, 4))
+ else:
+ return None
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ opts.update(url.query)
+ port = opts.pop("port", None)
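+ # pymssql expects the port to be appended to the host as "host:port"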
+ if port and "host" in opts:
+ opts["host"] = "%s:%s" % (opts["host"], port)
+ return [[], opts]
+
+ def is_disconnect(self, e, connection, cursor):
+ for msg in (
+ "Adaptive Server connection timed out",
+ "Net-Lib error during Connection reset by peer",
+ "message 20003", # connection timeout
+ "Error 10054",
+ "Not connected to any MS SQL server",
+ "Connection is closed",
+ "message 20006", # Write to the server failed
+ "message 20017", # Unexpected EOF from the server
+ "message 20047", # DBPROCESS is dead or not enabled
+ ):
+ if msg in str(e):
+ return True
+ else:
+ return False
+
+ def set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit(True)
+ else:
+ connection.autocommit(False)
+ super(MSDialect_pymssql, self).set_isolation_level(
+ connection, level
+ )
+
+
+dialect = MSDialect_pymssql
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
new file mode 100644
index 0000000..edb76f2
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py
@@ -0,0 +1,673 @@
+# mssql/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mssql+pyodbc
+ :name: PyODBC
+ :dbapi: pyodbc
+ :connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
+ :url: https://pypi.org/project/pyodbc/
+
+Connecting to PyODBC
+--------------------
+
+The URL here is to be translated to PyODBC connection strings, as
+detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_.
+
+DSN Connections
+^^^^^^^^^^^^^^^
+
+A DSN connection in ODBC means that a pre-existing ODBC datasource is
+configured on the client machine. The application then specifies the name
+of this datasource, which encompasses details such as the specific ODBC driver
+in use as well as the network address of the database. Assuming a datasource
+is configured on the client, a basic DSN-based connection looks like::
+
+ engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")
+
+The above URL will pass the following connection string to PyODBC::
+
+ DSN=some_dsn;UID=scott;PWD=tiger
+
+If the username and password are omitted, the DSN form will also add
+the ``Trusted_Connection=yes`` directive to the ODBC string.
+
+Hostname Connections
+^^^^^^^^^^^^^^^^^^^^
+
+Hostname-based connections are also supported by pyodbc. These are often
+easier to use than a DSN and have the additional advantage that the specific
+database name to connect towards may be specified locally in the URL, rather
+than it being fixed as part of a datasource configuration.
+
+When using a hostname connection, the driver name must also be specified in the
+query parameters of the URL. As these names usually have spaces in them, the
+name must be URL encoded, which means using plus signs for spaces::
+
+ engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server")
+
+Other keywords interpreted by the Pyodbc dialect to be passed to
+``pyodbc.connect()`` in both the DSN and hostname cases include:
+``odbc_autotranslate``, ``ansi``, ``unicode_results``, ``autocommit``,
+``authentication``.
+Note that in order for the dialect to recognize these keywords
+(including the ``driver`` keyword above) they must be all lowercase.
+Multiple additional keyword arguments must be separated by an
+ampersand (``&``), not a semicolon::
+
+ engine = create_engine(
+ "mssql+pyodbc://scott:tiger@myhost:49242/databasename"
+ "?driver=ODBC+Driver+17+for+SQL+Server"
+ "&authentication=ActiveDirectoryIntegrated"
+ )
+
+The equivalent URL can be constructed using :class:`_sa.engine.URL`::
+
+ from sqlalchemy.engine import URL
+ connection_url = URL.create(
+ "mssql+pyodbc",
+ username="scott",
+ password="tiger",
+ host="myhost",
+ port=49242,
+ database="databasename",
+ query={
+ "driver": "ODBC Driver 17 for SQL Server",
+ "authentication": "ActiveDirectoryIntegrated",
+ },
+ )
+
+
+Pass through exact Pyodbc string
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+A PyODBC connection string can also be sent in pyodbc's format directly, as
+specified in `the PyODBC documentation
+<https://github.com/mkleehammer/pyodbc/wiki/Connecting-to-databases>`_,
+using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object
+can help make this easier::
+
+ from sqlalchemy.engine import URL
+ connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password"
+ connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string})
+
+ engine = create_engine(connection_url)
+
+.. _mssql_pyodbc_access_tokens:
+
+Connecting to databases with access tokens
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Some database servers are set up to only accept access tokens for login. For
+example, SQL Server allows the use of Azure Active Directory tokens to connect
+to databases. This requires creating a credential object using the
+``azure-identity`` library. More information about the authentication step can be
+found in `Microsoft's documentation
+<https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=bash>`_.
+
+After getting an engine, the credentials need to be sent to ``pyodbc.connect``
+each time a connection is requested. One way to do this is to set up an event
+listener on the engine that adds the credential token to the dialect's connect
+call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For
+SQL Server in particular, this is passed as an ODBC connection attribute with
+a data structure `described by Microsoft
+<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_.
+
+The following code snippet will create an engine that connects to an Azure SQL
+database using Azure credentials::
+
+ import struct
+ from sqlalchemy import create_engine, event
+ from sqlalchemy.engine.url import URL
+ from azure import identity
+
+ SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
+ TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database
+
+ connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
+
+ engine = create_engine(connection_string)
+
+ azure_credentials = identity.DefaultAzureCredential()
+
+ @event.listens_for(engine, "do_connect")
+ def provide_token(dialect, conn_rec, cargs, cparams):
+ # remove the "Trusted_Connection" parameter that SQLAlchemy adds
+ cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
+
+ # create token credential
+ raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
+ token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
+
+ # apply it to keyword arguments
+ cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
+
+.. tip::
+
+ The ``Trusted_Connection`` token is currently added by the SQLAlchemy
+ pyodbc dialect when no username or password is present. This needs
+ to be removed per Microsoft's
+ `documentation for Azure access tokens
+ <https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_,
+ stating that a connection string when using an access token must not contain
+ ``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters.
+
+.. _azure_synapse_ignore_no_transaction_on_rollback:
+
+Avoiding transaction-related exceptions on Azure Synapse Analytics
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Azure Synapse Analytics has a significant difference in its transaction
+handling compared to plain SQL Server; in some cases an error within a Synapse
+transaction can cause it to be arbitrarily terminated on the server side, which
+then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to
+fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()``
+to pass silently if no transaction is present, as the driver does not expect
+this condition. The symptom of this failure is an exception with a message
+resembling 'No corresponding transaction found. (111214)' when attempting to
+emit a ``.rollback()`` after an operation had a failure of some kind.
+
+This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to
+the SQL Server dialect via the :func:`_sa.create_engine` function as follows::
+
+ engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True)
+
+Using the above parameter, the dialect will catch ``ProgrammingError``
+exceptions raised during ``connection.rollback()`` and emit a warning
+if the error message contains code ``111214``, however will not raise
+an exception.
+
+.. versionadded:: 1.4.40 Added the
+ ``ignore_no_transaction_on_rollback=True`` parameter.
+
+Enable autocommit for Azure SQL Data Warehouse (DW) connections
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Azure SQL Data Warehouse does not support transactions,
+and that can cause problems with SQLAlchemy's "autobegin" (and implicit
+commit/rollback) behavior. We can avoid these problems by enabling autocommit
+at both the pyodbc and engine levels::
+
+ connection_url = sa.engine.URL.create(
+ "mssql+pyodbc",
+ username="scott",
+ password="tiger",
+ host="dw.azure.example.com",
+ database="mydb",
+ query={
+ "driver": "ODBC Driver 17 for SQL Server",
+ "autocommit": "True",
+ },
+ )
+
+ engine = create_engine(connection_url).execution_options(
+ isolation_level="AUTOCOMMIT"
+ )
+
+Avoiding sending large string parameters as TEXT/NTEXT
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+By default, for historical reasons, Microsoft's ODBC drivers for SQL Server
+send long string parameters (greater than 4000 SBCS characters or 2000 Unicode
+characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many
+years and are starting to cause compatibility issues with newer versions of
+SQL Server / Azure. For example, see `this
+issue <https://github.com/mkleehammer/pyodbc/issues/835>`_.
+
+Starting with ODBC Driver 18 for SQL Server we can override the legacy
+behavior and pass long strings as varchar(max)/nvarchar(max) using the
+``LongAsMax=Yes`` connection string parameter::
+
+ connection_url = sa.engine.URL.create(
+ "mssql+pyodbc",
+ username="scott",
+ password="tiger",
+ host="mssqlserver.example.com",
+ database="mydb",
+ query={
+ "driver": "ODBC Driver 18 for SQL Server",
+ "LongAsMax": "Yes",
+ },
+ )
+
+
+Pyodbc Pooling / connection close behavior
+------------------------------------------
+
+PyODBC uses internal `pooling
+<https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ by
+default, which means connections will be longer lived than they are within
+SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often
+preferable to disable this behavior. This behavior can only be disabled
+globally at the PyODBC module level, **before** any connections are made::
+
+ import pyodbc
+
+ pyodbc.pooling = False
+
+ # don't use the engine before pooling is set to False
+ engine = create_engine("mssql+pyodbc://user:pass@dsn")
+
+If this variable is left at its default value of ``True``, **the application
+will continue to maintain active database connections**, even when the
+SQLAlchemy engine itself fully discards a connection or if the engine is
+disposed.
+
+.. seealso::
+
+ `pooling <https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ -
+ in the PyODBC documentation.
+
+Driver / Unicode Support
+-------------------------
+
+PyODBC works best with Microsoft ODBC drivers, particularly in the area
+of Unicode support on both Python 2 and Python 3.
+
+Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not**
+recommended; there have been historically many Unicode-related issues
+in this area, including before Microsoft offered ODBC drivers for Linux
+and OSX. Now that Microsoft offers drivers for all platforms, these are
+recommended for PyODBC use. FreeTDS remains relevant for
+non-ODBC drivers such as pymssql where it works very well.
+
+
+Rowcount Support
+----------------
+
+Pyodbc only has partial support for rowcount. See the notes at
+:ref:`mssql_rowcount_versioning` for important notes when using ORM
+versioning.
+
+.. _mssql_pyodbc_fastexecutemany:
+
+Fast Executemany Mode
+---------------------
+
+The Pyodbc driver has added support for a "fast executemany" mode of execution
+which greatly reduces round trips for a DBAPI ``executemany()`` call when using
+Microsoft ODBC drivers, for **limited size batches that fit in memory**. The
+feature is enabled by setting the flag ``.fast_executemany`` on the DBAPI
+cursor when an executemany call is to be used. The SQLAlchemy pyodbc SQL
+Server dialect supports setting this flag automatically when the
+``.fast_executemany`` flag is passed to
+:func:`_sa.create_engine` ; note that the ODBC driver must be the Microsoft
+driver in order to use this flag::
+
+ engine = create_engine(
+ "mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server",
+ fast_executemany=True)
+
+.. warning:: The pyodbc fast_executemany mode **buffers all rows in memory** and is
+ not compatible with very large batches of data. A future version of SQLAlchemy
+ may support this flag as a per-execution option instead.
+
+.. versionadded:: 1.3
+
+.. seealso::
+
+ `fast executemany <https://github.com/mkleehammer/pyodbc/wiki/Features-beyond-the-DB-API#fast_executemany>`_
+ - on github
+
+.. _mssql_pyodbc_setinputsizes:
+
+Setinputsizes Support
+-----------------------
+
+The pyodbc ``cursor.setinputsizes()`` method can be used if necessary. To
+enable this hook, pass ``use_setinputsizes=True`` to :func:`_sa.create_engine`::
+
+ engine = create_engine("mssql+pyodbc://...", use_setinputsizes=True)
+
+The behavior of the hook can then be customized, as may be necessary
+particularly if fast_executemany is in use, via the
+:meth:`.DialectEvents.do_setinputsizes` hook. See that method for usage
+examples.
+
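+As a minimal sketch, a listener that simply logs the computed input sizes
+(entries in the ``inputsizes`` dictionary may also be modified or removed
+in this hook before the DBAPI call is made) might look like::
+
+    import logging
+
+    from sqlalchemy import create_engine, event
+
+    log = logging.getLogger(__name__)
+
+    engine = create_engine("mssql+pyodbc://...", use_setinputsizes=True)
+
+    @event.listens_for(engine, "do_setinputsizes")
+    def _log_inputsizes(inputsizes, cursor, statement, parameters, context):
+        # inputsizes maps bound parameter objects to DBAPI type descriptors
+        log.debug("setinputsizes: %r", inputsizes)
+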
+.. versionchanged:: 1.4.1 The pyodbc dialects will not use setinputsizes
+ unless ``use_setinputsizes=True`` is passed.
+
+""" # noqa
+
+
+import datetime
+import decimal
+import re
+import struct
+
+from .base import BINARY
+from .base import DATETIMEOFFSET
+from .base import MSDialect
+from .base import MSExecutionContext
+from .base import VARBINARY
+from ... import exc
+from ... import types as sqltypes
+from ... import util
+from ...connectors.pyodbc import PyODBCConnector
+
+
+class _ms_numeric_pyodbc(object):
+
+ """Turns Decimals with adjusted() < 0 or > 7 into strings.
+
+ The routines here are needed for older pyodbc versions
+ as well as current mxODBC versions.
+
+ """
+
+ def bind_processor(self, dialect):
+
+ super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect)
+
+ if not dialect._need_decimal_fix:
+ return super_process
+
+ def process(value):
+ if self.asdecimal and isinstance(value, decimal.Decimal):
+ adjusted = value.adjusted()
+ if adjusted < 0:
+ return self._small_dec_to_string(value)
+ elif adjusted > 7:
+ return self._large_dec_to_string(value)
+
+ if super_process:
+ return super_process(value)
+ else:
+ return value
+
+ return process
+
+ # these routines needed for older versions of pyodbc.
+ # as of 2.1.8 this logic is integrated.
+
+ def _small_dec_to_string(self, value):
+ return "%s0.%s%s" % (
+ (value < 0 and "-" or ""),
+ "0" * (abs(value.adjusted()) - 1),
+ "".join([str(nint) for nint in value.as_tuple()[1]]),
+ )
+
+ def _large_dec_to_string(self, value):
+ _int = value.as_tuple()[1]
+ if "E" in str(value):
+ result = "%s%s%s" % (
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int]),
+ "0" * (value.adjusted() - (len(_int) - 1)),
+ )
+ else:
+ if (len(_int) - 1) > value.adjusted():
+ result = "%s%s.%s" % (
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int][0 : value.adjusted() + 1]),
+ "".join([str(s) for s in _int][value.adjusted() + 1 :]),
+ )
+ else:
+ result = "%s%s" % (
+ (value < 0 and "-" or ""),
+ "".join([str(s) for s in _int][0 : value.adjusted() + 1]),
+ )
+ return result
+
+
+class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
+ pass
+
+
+class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
+ pass
+
+
+class _ms_binary_pyodbc(object):
+ """Wraps binary values in dialect-specific Binary wrapper.
+ If the value is null, return a pyodbc-specific BinaryNull
+ object to prevent pyODBC [and FreeTDS] from defaulting binary
+ NULL types to SQLWCHAR and causing implicit conversion errors.
+ """
+
+ def bind_processor(self, dialect):
+ if dialect.dbapi is None:
+ return None
+
+ DBAPIBinary = dialect.dbapi.Binary
+
+ def process(value):
+ if value is not None:
+ return DBAPIBinary(value)
+ else:
+ # pyodbc-specific
+ return dialect.dbapi.BinaryNull
+
+ return process
+
+
+class _ODBCDateTimeBindProcessor(object):
+ """Add bind processors to handle datetimeoffset behaviors"""
+
+ has_tz = False
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, util.string_types):
+ # if a string was passed directly, allow it through
+ return value
+ elif not value.tzinfo or (not self.timezone and not self.has_tz):
+ # for DateTime(timezone=False)
+ return value
+ else:
+ # for DATETIMEOFFSET or DateTime(timezone=True)
+ #
+ # Convert to string format required by T-SQL
+ dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z")
+ # offset needs a colon, e.g., -0700 -> -07:00
+ # "UTC offset in the form (+-)HHMM[SS[.ffffff]]"
+ # backend currently rejects seconds / fractional seconds
+ dto_string = re.sub(
+ r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string
+ )
+ return dto_string
+
+ return process
+
+
+class _ODBCDateTime(_ODBCDateTimeBindProcessor, sqltypes.DateTime):
+ pass
+
+
+class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET):
+ has_tz = True
+
+
+class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY):
+ pass
+
+
+class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY):
+ pass
+
+
+class MSExecutionContext_pyodbc(MSExecutionContext):
+ _embedded_scope_identity = False
+
+ def pre_exec(self):
+ """where appropriate, issue "select scope_identity()" in the same
+ statement.
+
+ Background on why "scope_identity()" is preferable to "@@identity":
+ https://msdn.microsoft.com/en-us/library/ms190315.aspx
+
+ Background on why we attempt to embed "scope_identity()" into the same
+ statement as the INSERT:
+ https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
+
+ """
+
+ super(MSExecutionContext_pyodbc, self).pre_exec()
+
+ # don't embed the scope_identity select into an
+ # "INSERT .. DEFAULT VALUES"
+ if (
+ self._select_lastrowid
+ and self.dialect.use_scope_identity
+ and len(self.parameters[0])
+ ):
+ self._embedded_scope_identity = True
+
+ self.statement += "; select scope_identity()"
+
+ def post_exec(self):
+ if self._embedded_scope_identity:
+ # Fetch the last inserted id from the manipulated statement
+ # We may have to skip over a number of result sets with
+ # no data (due to triggers, etc.)
+ while True:
+ try:
+ # fetchall() ensures the cursor is consumed
+ # without closing it (FreeTDS particularly)
+ row = self.cursor.fetchall()[0]
+ break
+ except self.dialect.dbapi.Error:
+ # no way around this - nextset() consumes the previous set
+ # so we need to just keep flipping
+ self.cursor.nextset()
+
+ self._lastrowid = int(row[0])
+ else:
+ super(MSExecutionContext_pyodbc, self).post_exec()
+
+
+class MSDialect_pyodbc(PyODBCConnector, MSDialect):
+ supports_statement_cache = True
+
+ # mssql still has problems with this on Linux
+ supports_sane_rowcount_returning = False
+
+ execution_ctx_cls = MSExecutionContext_pyodbc
+
+ colspecs = util.update_copy(
+ MSDialect.colspecs,
+ {
+ sqltypes.Numeric: _MSNumeric_pyodbc,
+ sqltypes.Float: _MSFloat_pyodbc,
+ BINARY: _BINARY_pyodbc,
+ # support DateTime(timezone=True)
+ sqltypes.DateTime: _ODBCDateTime,
+ DATETIMEOFFSET: _ODBCDATETIMEOFFSET,
+ # SQL Server dialect has a VARBINARY that is just to support
+ # "deprecate_large_types" w/ VARBINARY(max), but also we must
+ # handle the usual SQL standard VARBINARY
+ VARBINARY: _VARBINARY_pyodbc,
+ sqltypes.VARBINARY: _VARBINARY_pyodbc,
+ sqltypes.LargeBinary: _VARBINARY_pyodbc,
+ },
+ )
+
+ def __init__(
+ self, description_encoding=None, fast_executemany=False, **params
+ ):
+ if "description_encoding" in params:
+ self.description_encoding = params.pop("description_encoding")
+ super(MSDialect_pyodbc, self).__init__(**params)
+ self.use_scope_identity = (
+ self.use_scope_identity
+ and self.dbapi
+ and hasattr(self.dbapi.Cursor, "nextset")
+ )
+ self._need_decimal_fix = self.dbapi and self._dbapi_version() < (
+ 2,
+ 1,
+ 8,
+ )
+ self.fast_executemany = fast_executemany
+
+ def _get_server_version_info(self, connection):
+ try:
+ # "Version of the instance of SQL Server, in the form
+ # of 'major.minor.build.revision'"
+ raw = connection.exec_driver_sql(
+ "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)"
+ ).scalar()
+ except exc.DBAPIError:
+ # SQL Server docs indicate this function isn't present prior to
+ # 2008. Before we had the VARCHAR cast above, pyodbc would also
+ # fail on this query.
+ return super(MSDialect_pyodbc, self)._get_server_version_info(
+ connection, allow_chars=False
+ )
+ else:
+ version = []
+ r = re.compile(r"[.\-]")
+ for n in r.split(raw):
+ try:
+ version.append(int(n))
+ except ValueError:
+ pass
+ return tuple(version)
+
+ def on_connect(self):
+ super_ = super(MSDialect_pyodbc, self).on_connect()
+
+ def on_connect(conn):
+ if super_ is not None:
+ super_(conn)
+
+ self._setup_timestampoffset_type(conn)
+
+ return on_connect
+
+ def _setup_timestampoffset_type(self, connection):
+ # output converter function for datetimeoffset
+ def _handle_datetimeoffset(dto_value):
+ tup = struct.unpack("<6hI2h", dto_value)
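+ # layout follows the ODBC SQL_SS_TIMESTAMPOFFSET_STRUCT: six 16-bit
+ # fields (year, month, day, hour, minute, second), a 32-bit fraction
+ # in nanoseconds, then 16-bit timezone hour / minute offsets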
+ return datetime.datetime(
+ tup[0],
+ tup[1],
+ tup[2],
+ tup[3],
+ tup[4],
+ tup[5],
+ tup[6] // 1000,
+ util.timezone(
+ datetime.timedelta(hours=tup[7], minutes=tup[8])
+ ),
+ )
+
+ odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h
+ connection.add_output_converter(
+ odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset
+ )
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ if self.fast_executemany:
+ cursor.fast_executemany = True
+ super(MSDialect_pyodbc, self).do_executemany(
+ cursor, statement, parameters, context=context
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ code = e.args[0]
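+ # SQLSTATE / driver codes that indicate the connection is no longer
+ # usable, e.g. 08S01 "communication link failure", HYT00 "timeout expired"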
+ if code in {
+ "08S01",
+ "01000",
+ "01002",
+ "08003",
+ "08007",
+ "08S02",
+ "08001",
+ "HYT00",
+ "HY010",
+ "10054",
+ }:
+ return True
+ return super(MSDialect_pyodbc, self).is_disconnect(
+ e, connection, cursor
+ )
+
+
+dialect = MSDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
new file mode 100644
index 0000000..04c83d1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/__init__.py
@@ -0,0 +1,103 @@
+# mysql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import cymysql # noqa
+from . import mariadbconnector # noqa
+from . import mysqlconnector # noqa
+from . import mysqldb # noqa
+from . import oursql # noqa
+from . import pymysql # noqa
+from . import pyodbc # noqa
+from .base import BIGINT
+from .base import BINARY
+from .base import BIT
+from .base import BLOB
+from .base import BOOLEAN
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import DECIMAL
+from .base import DOUBLE
+from .base import ENUM
+from .base import FLOAT
+from .base import INTEGER
+from .base import JSON
+from .base import LONGBLOB
+from .base import LONGTEXT
+from .base import MEDIUMBLOB
+from .base import MEDIUMINT
+from .base import MEDIUMTEXT
+from .base import NCHAR
+from .base import NUMERIC
+from .base import NVARCHAR
+from .base import REAL
+from .base import SET
+from .base import SMALLINT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import TINYBLOB
+from .base import TINYINT
+from .base import TINYTEXT
+from .base import VARBINARY
+from .base import VARCHAR
+from .base import YEAR
+from .dml import Insert
+from .dml import insert
+from .expression import match
+from ...util import compat
+
+if compat.py3k:
+ from . import aiomysql # noqa
+ from . import asyncmy # noqa
+
+# default dialect
+base.dialect = dialect = mysqldb.dialect
+
+__all__ = (
+ "BIGINT",
+ "BINARY",
+ "BIT",
+ "BLOB",
+ "BOOLEAN",
+ "CHAR",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "DOUBLE",
+ "ENUM",
+ "DECIMAL",
+ "FLOAT",
+ "INTEGER",
+ "INTEGER",
+ "JSON",
+ "LONGBLOB",
+ "LONGTEXT",
+ "MEDIUMBLOB",
+ "MEDIUMINT",
+ "MEDIUMTEXT",
+ "NCHAR",
+ "NVARCHAR",
+ "NUMERIC",
+ "SET",
+ "SMALLINT",
+ "REAL",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "TINYBLOB",
+ "TINYINT",
+ "TINYTEXT",
+ "VARBINARY",
+ "VARCHAR",
+ "YEAR",
+ "dialect",
+ "insert",
+ "Insert",
+ "match",
+)
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py
new file mode 100644
index 0000000..975467c
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py
@@ -0,0 +1,317 @@
+# mysql/aiomysql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+aiomysql
+ :name: aiomysql
+ :dbapi: aiomysql
+ :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://github.com/aio-libs/aiomysql
+
+.. warning:: The aiomysql dialect is not currently tested as part of
+ SQLAlchemy’s continuous integration. As of September, 2021 the driver
+ appears to be unmaintained and no longer functions for Python version 3.10,
+ and additionally depends on a significantly outdated version of PyMySQL.
+ Please refer to the :ref:`asyncmy` dialect for current MySQL/MariaDB asyncio
+ functionality.
+
+The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
+
+Using a special asyncio mediation layer, the aiomysql dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
+
+
+""" # noqa
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import asyncio
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiomysql_cursor:
+ server_side = False
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "await_",
+ "_cursor",
+ "_rows",
+ )
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor()
+
+ # see https://github.com/aio-libs/aiomysql/issues/543
+ self._cursor = self.await_(cursor.__aenter__())
+ self._rows = []
+
+ @property
+ def description(self):
+ return self._cursor.description
+
+ @property
+ def rowcount(self):
+ return self._cursor.rowcount
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ @property
+ def lastrowid(self):
+ return self._cursor.lastrowid
+
+ def close(self):
+ # note we aren't actually closing the cursor here,
+ # we are just letting GC do it. to allow this to be async
+ # we would need the Result to change how it does "Safe close cursor".
+ # MySQL "cursors" don't actually have state to be "closed" besides
+ # exhausting rows, which we already have done for sync cursor.
+ # another option would be to emulate aiosqlite dialect and assign
+ # cursor only if we are doing server side cursor operation.
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ return self.await_(self._execute_async(operation, parameters))
+
+ def executemany(self, operation, seq_of_parameters):
+ return self.await_(
+ self._executemany_async(operation, seq_of_parameters)
+ )
+
+ async def _execute_async(self, operation, parameters):
+ async with self._adapt_connection._execute_mutex:
+ if parameters is None:
+ result = await self._cursor.execute(operation)
+ else:
+ result = await self._cursor.execute(operation, parameters)
+
+ if not self.server_side:
+ # aiomysql has a "fake" async result, so we have to pull it out
+ # of that here since our default result is not async.
+ # we could just as easily grab "_rows" here and be done with it
+ # but this is safer.
+ self._rows = list(await self._cursor.fetchall())
+ return result
+
+ async def _executemany_async(self, operation, seq_of_parameters):
+ async with self._adapt_connection._execute_mutex:
+ return await self._cursor.executemany(operation, seq_of_parameters)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
+ __slots__ = ()
+ server_side = True
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor(
+ adapt_connection.dbapi.aiomysql.SSCursor
+ )
+
+ self._cursor = self.await_(cursor.__aenter__())
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiomysql_connection(AdaptedConnection):
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection", "_execute_mutex")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+ self._execute_mutex = asyncio.Lock()
+
+ def ping(self, reconnect):
+ return self.await_(self._connection.ping(reconnect))
+
+ def character_set_name(self):
+ return self._connection.character_set_name()
+
+ def autocommit(self, value):
+ self.await_(self._connection.autocommit(value))
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_aiomysql_ss_cursor(self)
+ else:
+ return AsyncAdapt_aiomysql_cursor(self)
+
+ def rollback(self):
+ self.await_(self._connection.rollback())
+
+ def commit(self):
+ self.await_(self._connection.commit())
+
+ def close(self):
+ # it's not awaitable.
+ self._connection.close()
+
+
+class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiomysql_dbapi:
+ def __init__(self, aiomysql, pymysql):
+ self.aiomysql = aiomysql
+ self.pymysql = pymysql
+ self.paramstyle = "format"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DataError",
+ "DatabaseError",
+ "OperationalError",
+ "InterfaceError",
+ "IntegrityError",
+ "ProgrammingError",
+ "InternalError",
+ "NotSupportedError",
+ ):
+ setattr(self, name, getattr(self.aiomysql, name))
+
+ for name in (
+ "NUMBER",
+ "STRING",
+ "DATETIME",
+ "BINARY",
+ "TIMESTAMP",
+ "Binary",
+ ):
+ setattr(self, name, getattr(self.pymysql, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_aiomysql_connection(
+ self,
+ await_fallback(self.aiomysql.connect(*arg, **kw)),
+ )
+ else:
+ return AsyncAdapt_aiomysql_connection(
+ self,
+ await_only(self.aiomysql.connect(*arg, **kw)),
+ )
+
+
+class MySQLDialect_aiomysql(MySQLDialect_pymysql):
+ driver = "aiomysql"
+ supports_statement_cache = True
+
+ supports_server_side_cursors = True
+ _sscursor = AsyncAdapt_aiomysql_ss_cursor
+
+ is_async = True
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_aiomysql_dbapi(
+ __import__("aiomysql"), __import__("pymysql")
+ )
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def create_connect_args(self, url):
+ return super(MySQLDialect_aiomysql, self).create_connect_args(
+ url, _translate_args=dict(username="user", database="db")
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_aiomysql, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ else:
+ str_e = str(e).lower()
+ return "not connected" in str_e
+
+ def _found_rows_client_flag(self):
+ from pymysql.constants import CLIENT
+
+ return CLIENT.FOUND_ROWS
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = MySQLDialect_aiomysql
diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py
new file mode 100644
index 0000000..521918a
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py
@@ -0,0 +1,328 @@
+# mysql/asyncmy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+asyncmy
+ :name: asyncmy
+ :dbapi: asyncmy
+ :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://github.com/long2ice/asyncmy
+
+.. note:: The asyncmy dialect as of September, 2021 was added to provide
+ MySQL/MariaDB asyncio compatibility given that the :ref:`aiomysql` database
+ driver has become unmaintained, however asyncmy is itself very new.
+
+Using a special asyncio mediation layer, the asyncmy dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
+
+
+""" # noqa
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import asynccontextmanager
+from ...util.concurrency import asyncio
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_asyncmy_cursor:
+ server_side = False
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "await_",
+ "_cursor",
+ "_rows",
+ )
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor()
+
+ self._cursor = self.await_(cursor.__aenter__())
+ self._rows = []
+
+ @property
+ def description(self):
+ return self._cursor.description
+
+ @property
+ def rowcount(self):
+ return self._cursor.rowcount
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ @property
+ def lastrowid(self):
+ return self._cursor.lastrowid
+
+ def close(self):
+ # note we aren't actually closing the cursor here,
+ # we are just letting GC do it. to allow this to be async
+ # we would need the Result to change how it does "Safe close cursor".
+ # MySQL "cursors" don't actually have state to be "closed" besides
+ # exhausting rows, which we already have done for sync cursor.
+ # another option would be to emulate aiosqlite dialect and assign
+ # cursor only if we are doing server side cursor operation.
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ return self.await_(self._execute_async(operation, parameters))
+
+ def executemany(self, operation, seq_of_parameters):
+ return self.await_(
+ self._executemany_async(operation, seq_of_parameters)
+ )
+
+ async def _execute_async(self, operation, parameters):
+ async with self._adapt_connection._mutex_and_adapt_errors():
+ if parameters is None:
+ result = await self._cursor.execute(operation)
+ else:
+ result = await self._cursor.execute(operation, parameters)
+
+ if not self.server_side:
+ # asyncmy has a "fake" async result, so we have to pull it out
+ # of that here since our default result is not async.
+ # we could just as easily grab "_rows" here and be done with it
+ # but this is safer.
+ self._rows = list(await self._cursor.fetchall())
+ return result
+
+ async def _executemany_async(self, operation, seq_of_parameters):
+ async with self._adapt_connection._mutex_and_adapt_errors():
+ return await self._cursor.executemany(operation, seq_of_parameters)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
+ __slots__ = ()
+ server_side = True
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor(
+ adapt_connection.dbapi.asyncmy.cursors.SSCursor
+ )
+
+ self._cursor = self.await_(cursor.__aenter__())
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_asyncmy_connection(AdaptedConnection):
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection", "_execute_mutex")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+ self._execute_mutex = asyncio.Lock()
+
+ @asynccontextmanager
+ async def _mutex_and_adapt_errors(self):
+ async with self._execute_mutex:
+ try:
+ yield
+ except AttributeError:
+ raise self.dbapi.InternalError(
+ "network operation failed due to asyncmy attribute error"
+ )
+
+ def ping(self, reconnect):
+ assert not reconnect
+ return self.await_(self._do_ping())
+
+ async def _do_ping(self):
+ async with self._mutex_and_adapt_errors():
+ return await self._connection.ping(False)
+
+ def character_set_name(self):
+ return self._connection.character_set_name()
+
+ def autocommit(self, value):
+ self.await_(self._connection.autocommit(value))
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_asyncmy_ss_cursor(self)
+ else:
+ return AsyncAdapt_asyncmy_cursor(self)
+
+ def rollback(self):
+ self.await_(self._connection.rollback())
+
+ def commit(self):
+ self.await_(self._connection.commit())
+
+ def close(self):
+ # it's not awaitable.
+ self._connection.close()
+
+
+class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+def _Binary(x):
+ """Return x as a binary type."""
+ return bytes(x)
+
+
+class AsyncAdapt_asyncmy_dbapi:
+ def __init__(self, asyncmy):
+ self.asyncmy = asyncmy
+ self.paramstyle = "format"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DataError",
+ "DatabaseError",
+ "OperationalError",
+ "InterfaceError",
+ "IntegrityError",
+ "ProgrammingError",
+ "InternalError",
+ "NotSupportedError",
+ ):
+ setattr(self, name, getattr(self.asyncmy.errors, name))
+
+ STRING = util.symbol("STRING")
+ NUMBER = util.symbol("NUMBER")
+ BINARY = util.symbol("BINARY")
+ DATETIME = util.symbol("DATETIME")
+ TIMESTAMP = util.symbol("TIMESTAMP")
+ Binary = staticmethod(_Binary)
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_asyncmy_connection(
+ self,
+ await_fallback(self.asyncmy.connect(*arg, **kw)),
+ )
+ else:
+ return AsyncAdapt_asyncmy_connection(
+ self,
+ await_only(self.asyncmy.connect(*arg, **kw)),
+ )
+
+
+class MySQLDialect_asyncmy(MySQLDialect_pymysql):
+ driver = "asyncmy"
+ supports_statement_cache = True
+
+ supports_server_side_cursors = True
+ _sscursor = AsyncAdapt_asyncmy_ss_cursor
+
+ is_async = True
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def create_connect_args(self, url):
+ return super(MySQLDialect_asyncmy, self).create_connect_args(
+ url, _translate_args=dict(username="user", database="db")
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_asyncmy, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ else:
+ str_e = str(e).lower()
+ return (
+ "not connected" in str_e or "network operation failed" in str_e
+ )
+
+ def _found_rows_client_flag(self):
+ from asyncmy.constants import CLIENT
+
+ return CLIENT.FOUND_ROWS
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = MySQLDialect_asyncmy
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
new file mode 100644
index 0000000..111c63b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -0,0 +1,3306 @@
+# mysql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: mysql
+ :name: MySQL / MariaDB
+ :full_support: 5.6, 5.7, 8.0 / 10.4, 10.5
+ :normal_support: 5.6+ / 10+
+ :best_effort: 5.0.2+ / 5.0.2+
+
+Supported Versions and Features
+-------------------------------
+
+SQLAlchemy supports MySQL starting with version 5.0.2 through modern releases,
+as well as all modern versions of MariaDB. See the official MySQL
+documentation for detailed information about features supported in any given
+server release.
+
+.. versionchanged:: 1.4 minimum MySQL version supported is now 5.0.2.
+
+MariaDB Support
+~~~~~~~~~~~~~~~
+
+The MariaDB variant of MySQL retains fundamental compatibility with MySQL's
+protocols; however, the development of these two products continues to diverge.
+Within the realm of SQLAlchemy, the two databases have a small number of
+syntactical and behavioral differences that SQLAlchemy accommodates automatically.
+To connect to a MariaDB database, no changes to the database URL are required::
+
+
+ engine = create_engine("mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4")
+
+Upon first connect, the SQLAlchemy dialect employs a
+server version detection scheme that determines if the
+backing database reports as MariaDB. Based on this flag, the dialect
+can make different choices in those areas where its behavior
+must be different.
+
+.. _mysql_mariadb_only_mode:
+
+MariaDB-Only Mode
+~~~~~~~~~~~~~~~~~
+
+The dialect also supports an **optional** "MariaDB-only" mode of connection, which may be
+useful for the case where an application makes use of MariaDB-specific features
+and is not compatible with a MySQL database. To use this mode of operation,
+replace the "mysql" token in the above URL with "mariadb"::
+
+ engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4")
+
+The above engine, upon first connect, will raise an error if the server version
+detection detects that the backing database is not MariaDB.
+
+When using an engine with ``"mariadb"`` as the dialect name, **all mysql-specific options
+that include the name "mysql" in them are now named with "mariadb"**. This means
+options like ``mysql_engine`` should be named ``mariadb_engine``, etc. Both
+"mysql" and "mariadb" options can be used simultaneously for applications that
+use URLs with both "mysql" and "mariadb" dialects::
+
+ my_table = Table(
+ "mytable",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("textdata", String(50)),
+ mariadb_engine="InnoDB",
+ mysql_engine="InnoDB",
+ )
+
+ Index(
+ "textdata_ix",
+ my_table.c.textdata,
+ mysql_prefix="FULLTEXT",
+ mariadb_prefix="FULLTEXT",
+ )
+
+Similar behavior will occur when the above structures are reflected, i.e. the
+"mariadb" prefix will be present in the option names when the database URL
+is based on the "mariadb" name.
+
+.. versionadded:: 1.4 Added "mariadb" dialect name supporting "MariaDB-only mode"
+ for the MySQL dialect.
+
+.. _mysql_connection_timeouts:
+
+Connection Timeouts and Disconnects
+-----------------------------------
+
+MySQL / MariaDB feature an automatic connection close behavior, for connections that
+have been idle for a fixed period of time, defaulting to eight hours.
+To circumvent this issue, use
+the :paramref:`_sa.create_engine.pool_recycle` option which ensures that
+a connection will be discarded and replaced with a new one if it has been
+present in the pool for a fixed number of seconds::
+
+ engine = create_engine('mysql+mysqldb://...', pool_recycle=3600)
+
+For more comprehensive disconnect detection of pooled connections, including
+accommodation of server restarts and network issues, a pre-ping approach may
+be employed. See :ref:`pool_disconnects` for current approaches.
+
+.. seealso::
+
+ :ref:`pool_disconnects` - Background on several techniques for dealing
+ with timed out connections as well as database restarts.
+
+.. _mysql_storage_engines:
+
+CREATE TABLE arguments including Storage Engines
+------------------------------------------------
+
+Both MySQL's and MariaDB's CREATE TABLE syntax includes a wide array of special options,
+including ``ENGINE``, ``CHARSET``, ``MAX_ROWS``, ``ROW_FORMAT``,
+``INSERT_METHOD``, and many more.
+To accommodate the rendering of these arguments, specify the form
+``mysql_argument_name="value"``. For example, to specify a table with
+``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8mb4``, and ``KEY_BLOCK_SIZE``
+of ``1024``::
+
+ Table('mytable', metadata,
+ Column('data', String(32)),
+ mysql_engine='InnoDB',
+ mysql_charset='utf8mb4',
+ mysql_key_block_size="1024"
+ )
+
+When supporting :ref:`mysql_mariadb_only_mode` mode, similar keys against
+the "mariadb" prefix must be included as well. The values can of course
+vary independently so that different settings on MySQL vs. MariaDB may
+be maintained::
+
+ # support both "mysql" and "mariadb-only" engine URLs
+
+ Table('mytable', metadata,
+ Column('data', String(32)),
+
+ mysql_engine='InnoDB',
+ mariadb_engine='InnoDB',
+
+ mysql_charset='utf8mb4',
+ mariadb_charset='utf8',
+
+ mysql_key_block_size="1024",
+ mariadb_key_block_size="1024"
+
+ )
+
+The MySQL / MariaDB dialects will normally transfer any keyword specified as
+``mysql_keyword_name`` to be rendered as ``KEYWORD_NAME`` in the
+``CREATE TABLE`` statement. A handful of these names will render with a space
+instead of an underscore; to support this, the MySQL dialect has awareness of
+these particular names, which include ``DATA DIRECTORY``
+(e.g. ``mysql_data_directory``), ``CHARACTER SET`` (e.g.
+``mysql_character_set``) and ``INDEX DIRECTORY`` (e.g.
+``mysql_index_directory``).
+
+The most common argument is ``mysql_engine``, which refers to the storage
+engine for the table. Historically, MySQL server installations would default
+to ``MyISAM`` for this value, although newer versions may be defaulting
+to ``InnoDB``. The ``InnoDB`` engine is typically preferred for its support
+of transactions and foreign keys.
+
+A :class:`_schema.Table`
+that is created in a MySQL / MariaDB database with a storage engine
+of ``MyISAM`` will be essentially non-transactional, meaning any
+INSERT/UPDATE/DELETE statement referring to this table will be invoked as
+autocommit. It also will have no support for foreign key constraints; while
+the ``CREATE TABLE`` statement accepts foreign key options, when using the
+``MyISAM`` storage engine these arguments are discarded. Reflecting such a
+table will also produce no foreign key constraint information.
+
+For fully atomic transactions as well as support for foreign key
+constraints, all participating ``CREATE TABLE`` statements must specify a
+transactional engine, which in the vast majority of cases is ``InnoDB``.
+
+
+Case Sensitivity and Table Reflection
+-------------------------------------
+
+Both MySQL and MariaDB have inconsistent support for case-sensitive identifier
+names, basing support on specific details of the underlying
+operating system. However, it has been observed that no matter
+what case sensitivity behavior is present, the names of tables in
+foreign key declarations are *always* received from the database
+as all-lower case, making it impossible to accurately reflect a
+schema where inter-related tables use mixed-case identifier names.
+
+Therefore it is strongly advised that table names be declared as
+all lower case both within SQLAlchemy as well as on the MySQL / MariaDB
+database itself, especially if database reflection features are
+to be used.
+
+.. _mysql_isolation_level:
+
+Transaction Isolation Level
+---------------------------
+
+All MySQL / MariaDB dialects support setting of transaction isolation level both via a
+dialect-specific parameter :paramref:`_sa.create_engine.isolation_level`
+accepted
+by :func:`_sa.create_engine`, as well as the
+:paramref:`.Connection.execution_options.isolation_level` argument as passed to
+:meth:`_engine.Connection.execution_options`.
+This feature works by issuing the
+command ``SET SESSION TRANSACTION ISOLATION LEVEL <level>`` for each new
+connection. For the special AUTOCOMMIT isolation level, DBAPI-specific
+techniques are used.
+
+To set isolation level using :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "mysql://scott:tiger@localhost/test",
+ isolation_level="READ UNCOMMITTED"
+ )
+
+To set using per-connection execution options::
+
+ connection = engine.connect()
+ connection = connection.execution_options(
+ isolation_level="READ COMMITTED"
+ )
+
+Valid values for ``isolation_level`` include:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
+The special ``AUTOCOMMIT`` value makes use of the various "autocommit"
+attributes provided by specific DBAPIs, and is currently supported by
+MySQLdb, MySQL-Client, MySQL-Connector Python, and PyMySQL. Using it,
+the database connection will return true for the value of
+``SELECT @@autocommit;``.
+
+There are also more options for isolation level configurations, such as
+"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply
+different isolation level settings. See the discussion at
+:ref:`dbapi_autocommit` for background.
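+
+As a minimal sketch of that approach, a "sub-engine" which shares the
+connection pool of the main engine may be derived using
+:meth:`_engine.Engine.execution_options`::
+
+    eng = create_engine("mysql://scott:tiger@localhost/test")
+
+    # shares the pool of "eng", but applies AUTOCOMMIT to each
+    # connection checked out from it
+    autocommit_engine = eng.execution_options(isolation_level="AUTOCOMMIT")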
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+AUTO_INCREMENT Behavior
+-----------------------
+
+When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on
+the first :class:`.Integer` primary key column which is not marked as a
+foreign key::
+
+ >>> t = Table('mytable', metadata,
+ ... Column('mytable_id', Integer, primary_key=True)
+ ... )
+ >>> t.create()
+ CREATE TABLE mytable (
+ mytable_id INTEGER NOT NULL AUTO_INCREMENT,
+ PRIMARY KEY (mytable_id)
+ )
+
+You can disable this behavior by passing ``False`` to the
+:paramref:`_schema.Column.autoincrement` argument of :class:`_schema.Column`.
+This flag
+can also be used to enable auto-increment on a secondary column in a
+multi-column key for some storage engines::
+
+ Table('mytable', metadata,
+ Column('gid', Integer, primary_key=True, autoincrement=False),
+ Column('id', Integer, primary_key=True)
+ )
+
+.. _mysql_ss_cursors:
+
+Server Side Cursors
+-------------------
+
+Server-side cursor support is available for the mysqlclient, PyMySQL, and
+mariadbconnector dialects and may also be available in others. This makes use
+of either the "buffered=True/False" flag where available, or a class such
+as ``MySQLdb.cursors.SSCursor`` or ``pymysql.cursors.SSCursor`` internally.
+
+
+Server side cursors are enabled on a per-statement basis by using the
+:paramref:`.Connection.execution_options.stream_results` connection execution
+option::
+
+ with engine.connect() as conn:
+ result = conn.execution_options(stream_results=True).execute(text("select * from table"))
+
+Note that some kinds of SQL statements may not be supported with
+server side cursors; generally, only SQL statements that return rows should be
+used with this option.
+
+.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated
+ and will be removed in a future release. Please use the
+ :paramref:`_engine.Connection.execution_options.stream_results` execution option for
+ unbuffered cursor support.
+
+.. seealso::
+
+ :ref:`engine_stream_results`
+
+.. _mysql_unicode:
+
+Unicode
+-------
+
+Charset Selection
+~~~~~~~~~~~~~~~~~
+
+Most MySQL / MariaDB DBAPIs offer the option to set the client character set for
+a connection. This is typically delivered using the ``charset`` parameter
+in the URL, such as::
+
+ e = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4")
+
+This charset is the **client character set** for the connection. Some
+MySQL DBAPIs will default this to a value such as ``latin1``, and some
+will make use of the ``default-character-set`` setting in the ``my.cnf``
+file as well. Documentation for the DBAPI in use should be consulted
+for specific behavior.
+
+The encoding used for Unicode has traditionally been ``'utf8'``. However, for
+MySQL versions 5.5.3 and MariaDB 5.5 on forward, a new MySQL-specific encoding
+``'utf8mb4'`` has been introduced, and as of MySQL 8.0 a warning is emitted by
+the server if plain ``utf8`` is specified within any server-side directives,
+replaced with ``utf8mb3``. The rationale for this new encoding is due to the
+fact that MySQL's legacy utf-8 encoding only supports codepoints up to three
+bytes instead of four. Therefore, when communicating with a MySQL or MariaDB
+database that includes codepoints more than three bytes in size, this new
+charset is preferred, if supported by both the database as well as the client
+DBAPI, as in::
+
+ e = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4")
+
+All modern DBAPIs should support the ``utf8mb4`` charset.
+
+In order to use ``utf8mb4`` encoding for a schema that was created with legacy
+``utf8``, changes to the MySQL/MariaDB schema and/or server configuration may be
+required.
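+
+As a sketch of one such change (the exact steps depend on the schema and
+server version), an individual table may be converted by emitting an
+``ALTER TABLE ... CONVERT TO CHARACTER SET`` statement::
+
+    from sqlalchemy import text
+
+    with engine.begin() as conn:
+        conn.execute(text("ALTER TABLE mytable CONVERT TO CHARACTER SET utf8mb4"))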
+
+.. seealso::
+
+ `The utf8mb4 Character Set \
+ <https://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html>`_ - \
+ in the MySQL documentation
+
+.. _mysql_binary_introducer:
+
+Dealing with Binary Data Warnings and Unicode
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MySQL versions 5.6, 5.7 and later (not MariaDB at the time of this writing)
+emit a warning when binary data is passed to the database while a character
+set encoding is also in place, and the binary data itself is not valid for
+that encoding::
+
+ default.py:509: Warning: (1300, "Invalid utf8mb4 character string:
+ 'F9876A'")
+ cursor.execute(statement, parameters)
+
+This warning is due to the fact that the MySQL client library is attempting to
+interpret the binary string as a unicode object even if a datatype such
+as :class:`.LargeBinary` is in use. To resolve this, the SQL statement requires
+that a binary "character set introducer" be present before any non-NULL value,
+rendering like this::
+
+ INSERT INTO table (data) VALUES (_binary %s)
+
+These character set introducers are provided by the DBAPI driver, assuming the
+use of mysqlclient or PyMySQL (both of which are recommended). Add the query
+string parameter ``binary_prefix=true`` to the URL to repair this warning::
+
+ # mysqlclient
+ engine = create_engine(
+ "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true")
+
+ # PyMySQL
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true")
+
+
+The ``binary_prefix`` flag may or may not be supported by other MySQL drivers.
+
+SQLAlchemy itself cannot render this ``_binary`` prefix reliably, as it does
+not work with the NULL value, which is valid to be sent as a bound parameter.
+As the MySQL driver renders parameters directly into the SQL string, it's the
+most efficient place for this additional keyword to be passed.
+
+.. seealso::
+
+ `Character set introducers <https://dev.mysql.com/doc/refman/5.7/en/charset-introducer.html>`_ - on the MySQL website
+
+
+ANSI Quoting Style
+------------------
+
+MySQL / MariaDB feature two varieties of identifier "quoting style", one using
+backticks and the other using quotes, e.g. ```some_identifier``` vs.
+``"some_identifier"``. All MySQL dialects detect which version
+is in use by checking the value of :ref:`sql_mode<mysql_sql_mode>` when a connection is first
+established with a particular :class:`_engine.Engine`.
+This quoting style comes
+into play when rendering table and column names as well as when reflecting
+existing database structures. The detection is entirely automatic and
+no special configuration is needed to use either quoting style.
+
+
+.. _mysql_sql_mode:
+
+Changing the sql_mode
+---------------------
+
+MySQL supports operating in multiple
+`Server SQL Modes <https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html>`_ for
+both Servers and Clients. To change the ``sql_mode`` for a given application, a
+developer can leverage SQLAlchemy's Events system.
+
+In the following example, the event system is used to set the ``sql_mode``
+within the ``connect`` event, so that it takes effect for each newly created
+DBAPI connection::
+
+ from sqlalchemy import create_engine, event
+
+ eng = create_engine("mysql://scott:tiger@localhost/test", echo='debug')
+
+ # `insert=True` will ensure this is the very first listener to run
+ @event.listens_for(eng, "connect", insert=True)
+ def connect(dbapi_connection, connection_record):
+ cursor = dbapi_connection.cursor()
+ cursor.execute("SET sql_mode = 'STRICT_ALL_TABLES'")
+
+ conn = eng.connect()
+
+In the example illustrated above, the "connect" event will invoke the "SET"
+statement on the connection at the moment a particular DBAPI connection is
+first created for a given Pool, before the connection is made available to the
+connection pool. Additionally, because the function was registered with
+``insert=True``, it will be prepended to the internal list of registered
+functions.
+
+
+MySQL / MariaDB SQL Extensions
+------------------------------
+
+Many of the MySQL / MariaDB SQL extensions are handled through SQLAlchemy's generic
+function and operator support::
+
+ table.select(table.c.password==func.md5('plaintext'))
+ table.select(table.c.username.op('regexp')('^[a-d]'))
+
+And of course any valid SQL statement can be executed as a string as well.
+
+Some limited direct support for MySQL / MariaDB extensions to SQL is currently
+available.
+
+* INSERT..ON DUPLICATE KEY UPDATE: See
+ :ref:`mysql_insert_on_duplicate_key_update`
+
+* SELECT pragma, use :meth:`_expression.Select.prefix_with` and
+ :meth:`_query.Query.prefix_with`::
+
+ select(...).prefix_with(['HIGH_PRIORITY', 'SQL_SMALL_RESULT'])
+
+* UPDATE with LIMIT::
+
+ update(..., mysql_limit=10, mariadb_limit=10)
+
+* optimizer hints, use :meth:`_expression.Select.prefix_with` and
+ :meth:`_query.Query.prefix_with`::
+
+ select(...).prefix_with("/*+ NO_RANGE_OPTIMIZATION(t4 PRIMARY) */")
+
+* index hints, use :meth:`_expression.Select.with_hint` and
+ :meth:`_query.Query.with_hint`::
+
+ select(...).with_hint(some_table, "USE INDEX xyz")
+
+* MATCH operator support::
+
+ from sqlalchemy.dialects.mysql import match
+ select(...).where(match(col1, col2, against="some expr").in_boolean_mode())
+
+ .. seealso::
+
+ :class:`_mysql.match`
+
+.. _mysql_insert_on_duplicate_key_update:
+
+INSERT...ON DUPLICATE KEY UPDATE (Upsert)
+------------------------------------------
+
+MySQL / MariaDB allow "upserts" (update or insert)
+of rows into a table via the ``ON DUPLICATE KEY UPDATE`` clause of the
+``INSERT`` statement. A candidate row will only be inserted if that row does
+not match an existing primary or unique key in the table; otherwise, an UPDATE
+will be performed. The statement allows for separate specification of the
+values to INSERT versus the values for UPDATE.
+
+SQLAlchemy provides ``ON DUPLICATE KEY UPDATE`` support via the MySQL-specific
+:func:`.mysql.insert()` function, which provides
+the generative method :meth:`~.mysql.Insert.on_duplicate_key_update`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.dialects.mysql import insert
+
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... data=insert_stmt.inserted.data,
+ ... status='U'
+ ... )
+ >>> print(on_duplicate_key_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+ ON DUPLICATE KEY UPDATE data = VALUES(data), status = %s
+
+
+Unlike PostgreSQL's "ON CONFLICT" phrase, the "ON DUPLICATE KEY UPDATE"
+phrase will always match on any primary key or unique key, and will always
+perform an UPDATE if there's a match; there are no options for it to raise
+an error or to skip performing an UPDATE.
+
+``ON DUPLICATE KEY UPDATE`` is used to perform an update of the already
+existing row, using any combination of new values as well as values
+from the proposed insertion. These values are normally specified using
+keyword arguments passed to
+:meth:`_mysql.Insert.on_duplicate_key_update`, using column key values
+(usually the name of the column, unless it specifies
+:paramref:`_schema.Column.key`) as keys and literal or SQL expressions
+as values:
+
+.. sourcecode:: pycon+sql
+
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... data="some data",
+ ... updated_at=func.current_timestamp(),
+ ... )
+
+ >>> print(on_duplicate_key_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+ ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP
+
+In a manner similar to that of :meth:`.UpdateBase.values`, other parameter
+forms are accepted, including a single dictionary:
+
+.. sourcecode:: pycon+sql
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... {"data": "some data", "updated_at": func.current_timestamp()},
+ ... )
+
+as well as a list of 2-tuples, which will automatically provide
+a parameter-ordered UPDATE statement in a manner similar to that described
+at :ref:`tutorial_parameter_ordered_updates`. Unlike the :class:`_expression.Update`
+object,
+no special flag is needed to specify the intent since the argument form in
+this context is unambiguous:
+
+.. sourcecode:: pycon+sql
+
+ >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update(
+ ... [
+ ... ("data", "some data"),
+ ... ("updated_at", func.current_timestamp()),
+ ... ]
+ ... )
+
+ >>> print(on_duplicate_key_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s)
+ ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP
+
+.. versionchanged:: 1.3 support for parameter-ordered UPDATE clause within
+ MySQL ON DUPLICATE KEY UPDATE
+
+.. warning::
+
+ The :meth:`_mysql.Insert.on_duplicate_key_update`
+ method does **not** take into
+ account Python-side default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON DUPLICATE KEY style of UPDATE,
+ unless they are manually specified explicitly in the parameters.
+
+
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`_mysql.Insert.inserted` is available as an attribute on
+the :class:`_mysql.Insert` object; this object is a
+:class:`_expression.ColumnCollection` which contains all columns of the target
+table:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh')
+
+ >>> do_update_stmt = stmt.on_duplicate_key_update(
+ ... data="updated value",
+ ... author=stmt.inserted.author
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author) VALUES (%s, %s, %s)
+ ON DUPLICATE KEY UPDATE data = %s, author = VALUES(author)
+
+When rendered, the "inserted" namespace will produce the expression
+``VALUES(<columnname>)``.
+
+.. versionadded:: 1.2 Added support for MySQL ON DUPLICATE KEY UPDATE clause
+
+
+
+rowcount Support
+----------------
+
+SQLAlchemy standardizes the DBAPI ``cursor.rowcount`` attribute to be the
+usual definition of "number of rows matched by an UPDATE or DELETE statement".
+This is in contradiction to the default setting on most MySQL DBAPI drivers,
+which is "number of rows actually modified/deleted". For this reason, the
+SQLAlchemy MySQL dialects always add the ``constants.CLIENT.FOUND_ROWS``
+flag, or whatever is equivalent for the target dialect, upon connection.
+This setting is currently hardcoded.
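+
+For example, with FOUND_ROWS in effect, an UPDATE that sets a row to values it
+already has still reports that row as matched (a sketch; table and criteria
+are illustrative)::
+
+    with engine.begin() as conn:
+        result = conn.execute(
+            my_table.update().where(my_table.c.id == 5).values(data="same value")
+        )
+
+        # rows matched by the WHERE criteria, even if nothing actually changed
+        print(result.rowcount)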
+
+.. seealso::
+
+ :attr:`_engine.CursorResult.rowcount`
+
+
+.. _mysql_indexes:
+
+MySQL / MariaDB-Specific Index Options
+-----------------------------------------
+
+MySQL and MariaDB-specific extensions to the :class:`.Index` construct are available.
+
+Index Length
+~~~~~~~~~~~~~
+
+MySQL and MariaDB both provide an option to create index entries with a certain length, where
+"length" refers to the number of characters or bytes in each value which will
+become part of the index. SQLAlchemy provides this feature via the
+``mysql_length`` and/or ``mariadb_length`` parameters::
+
+ Index('my_index', my_table.c.data, mysql_length=10, mariadb_length=10)
+
+ Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4,
+ 'b': 9})
+
+ Index('a_b_idx', my_table.c.a, my_table.c.b, mariadb_length={'a': 4,
+ 'b': 9})
+
+Prefix lengths are given in characters for nonbinary string types and in bytes
+for binary string types. The value passed to the keyword argument *must* be
+either an integer (and, thus, specify the same prefix length value for all
+columns of the index) or a dict in which keys are column names and values are
+prefix length values for corresponding columns. MySQL and MariaDB only allow
+an index length for columns of type CHAR, VARCHAR, TEXT, BINARY, VARBINARY,
+or BLOB.
+
+Index Prefixes
+~~~~~~~~~~~~~~
+
+MySQL storage engines permit you to specify an index prefix when creating
+an index. SQLAlchemy provides this feature via the
+``mysql_prefix`` parameter on :class:`.Index`::
+
+ Index('my_index', my_table.c.data, mysql_prefix='FULLTEXT')
+
+The value passed to the keyword argument will be simply passed through to the
+underlying CREATE INDEX, so it *must* be a valid index prefix for your MySQL
+storage engine.
+
+.. versionadded:: 1.1.5
+
+.. seealso::
+
+ `CREATE INDEX <https://dev.mysql.com/doc/refman/5.0/en/create-index.html>`_ - MySQL documentation
+
+Index Types
+~~~~~~~~~~~~~
+
+Some MySQL storage engines permit you to specify an index type when creating
+an index or primary key constraint. SQLAlchemy provides this feature via the
+``mysql_using`` parameter on :class:`.Index`::
+
+ Index('my_index', my_table.c.data, mysql_using='hash', mariadb_using='hash')
+
+As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`::
+
+ PrimaryKeyConstraint("data", mysql_using='hash', mariadb_using='hash')
+
+The value passed to the keyword argument will be simply passed through to the
+underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index
+type for your MySQL storage engine.
+
+More information can be found at:
+
+https://dev.mysql.com/doc/refman/5.0/en/create-index.html
+
+https://dev.mysql.com/doc/refman/5.0/en/create-table.html
+
+Index Parsers
+~~~~~~~~~~~~~
+
+CREATE FULLTEXT INDEX in MySQL also supports a "WITH PARSER" option. This
+is available using the keyword argument ``mysql_with_parser``::
+
+ Index(
+ 'my_index', my_table.c.data,
+ mysql_prefix='FULLTEXT', mysql_with_parser="ngram",
+ mariadb_prefix='FULLTEXT', mariadb_with_parser="ngram",
+ )
+
+.. versionadded:: 1.3
+
+
+.. _mysql_foreign_keys:
+
+MySQL / MariaDB Foreign Keys
+-----------------------------
+
+MySQL and MariaDB's behavior regarding foreign keys has some important caveats.
+
+Foreign Key Arguments to Avoid
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Neither MySQL nor MariaDB support the foreign key arguments "DEFERRABLE", "INITIALLY",
+or "MATCH". Using the ``deferrable`` or ``initially`` keyword argument with
+:class:`_schema.ForeignKeyConstraint` or :class:`_schema.ForeignKey`
+will have the effect of
+these keywords being rendered in a DDL expression, which will then raise an
+error on MySQL or MariaDB. In order to use these keywords on a foreign key while having
+them ignored on a MySQL / MariaDB backend, use a custom compile rule::
+
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.schema import ForeignKeyConstraint
+
+ @compiles(ForeignKeyConstraint, "mysql", "mariadb")
+ def process(element, compiler, **kw):
+ element.deferrable = element.initially = None
+ return compiler.visit_foreign_key_constraint(element, **kw)
+
+The "MATCH" keyword is in fact more insidious, and is explicitly disallowed
+by SQLAlchemy in conjunction with the MySQL or MariaDB backends. This argument is
+silently ignored by MySQL / MariaDB, but in addition has the effect of ON UPDATE and ON
+DELETE options also being ignored by the backend. Therefore MATCH should
+never be used with the MySQL / MariaDB backends; as is the case with DEFERRABLE and
+INITIALLY, custom compilation rules can be used to correct a
+ForeignKeyConstraint at DDL definition time.
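+
+As a sketch, the compilation rule shown previously can be extended to also
+clear out a ``match`` setting before the constraint is rendered::
+
+    from sqlalchemy.ext.compiler import compiles
+    from sqlalchemy.schema import ForeignKeyConstraint
+
+    @compiles(ForeignKeyConstraint, "mysql", "mariadb")
+    def process(element, compiler, **kw):
+        # discard options that MySQL / MariaDB would reject or ignore
+        element.deferrable = element.initially = element.match = None
+        return compiler.visit_foreign_key_constraint(element, **kw)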
+
+Reflection of Foreign Key Constraints
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Not all MySQL / MariaDB storage engines support foreign keys. When using the
+very common ``MyISAM`` MySQL storage engine, the information loaded by table
+reflection will not include foreign keys. For these tables, you may supply a
+:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time::
+
+ Table('mytable', metadata,
+ ForeignKeyConstraint(['other_id'], ['othertable.other_id']),
+ autoload_with=engine
+ )
+
+.. seealso::
+
+ :ref:`mysql_storage_engines`
+
+.. _mysql_unique_constraints:
+
+MySQL / MariaDB Unique Constraints and Reflection
+----------------------------------------------------
+
+SQLAlchemy supports both the :class:`.Index` construct with the
+flag ``unique=True``, indicating a UNIQUE index, as well as the
+:class:`.UniqueConstraint` construct, representing a UNIQUE constraint.
+Both objects/syntaxes are supported by MySQL / MariaDB when emitting DDL to create
+these constraints. However, MySQL / MariaDB does not have a unique constraint
+construct that is separate from a unique index; that is, the "UNIQUE"
+constraint on MySQL / MariaDB is equivalent to creating a "UNIQUE INDEX".
+
+When reflecting these constructs, the
+:meth:`_reflection.Inspector.get_indexes`
+and the :meth:`_reflection.Inspector.get_unique_constraints`
+methods will **both**
+return an entry for a UNIQUE index in MySQL / MariaDB. However, when performing
+full table reflection using ``Table(..., autoload_with=engine)``,
+the :class:`.UniqueConstraint` construct is
+**not** part of the fully reflected :class:`_schema.Table` construct under any
+circumstances; this construct is always represented by a :class:`.Index`
+with the ``unique=True`` setting present in the :attr:`_schema.Table.indexes`
+collection.
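+
+For example, a sketch of inspecting a table which carries a UNIQUE index (the
+table name is illustrative)::
+
+    from sqlalchemy import inspect
+
+    insp = inspect(engine)
+
+    # both methods report an entry for the UNIQUE index
+    insp.get_indexes("mytable")
+    insp.get_unique_constraints("mytable")
+
+    # full reflection represents it only as an Index with unique=True
+    reflected = Table("mytable", MetaData(), autoload_with=engine)
+    unique_indexes = [ix for ix in reflected.indexes if ix.unique]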
+
+
+TIMESTAMP / DATETIME issues
+---------------------------
+
+.. _mysql_timestamp_onupdate:
+
+Rendering ON UPDATE CURRENT TIMESTAMP for MySQL / MariaDB's explicit_defaults_for_timestamp
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MySQL / MariaDB have historically expanded the DDL for the :class:`_types.TIMESTAMP`
+datatype into the phrase "TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE
+CURRENT_TIMESTAMP", which includes non-standard SQL that automatically updates
+the column with the current timestamp when an UPDATE occurs, eliminating the
+usual need to use a trigger when such server-side update changes are desired.
+
+MySQL 5.6 introduced a new flag `explicit_defaults_for_timestamp
+<https://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html
+#sysvar_explicit_defaults_for_timestamp>`_ which disables the above behavior,
+and in MySQL 8 this flag defaults to true, meaning in order to get a MySQL
+"on update timestamp" without changing this flag, the above DDL must be
+rendered explicitly. Additionally, the same DDL is valid for use of the
+``DATETIME`` datatype as well.
+
+SQLAlchemy's MySQL dialect does not yet have an option to generate
+MySQL's "ON UPDATE CURRENT_TIMESTAMP" clause, noting that this is not a general
+purpose "ON UPDATE" as there is no such syntax in standard SQL. SQLAlchemy's
+:paramref:`_schema.Column.server_onupdate` parameter is currently not related
+to this special MySQL behavior.
+
+To generate this DDL, make use of the :paramref:`_schema.Column.server_default`
+parameter and pass a textual clause that also includes the ON UPDATE clause::
+
+ from sqlalchemy import Table, MetaData, Column, Integer, String, TIMESTAMP
+ from sqlalchemy import text
+
+ metadata = MetaData()
+
+ mytable = Table(
+ "mytable",
+ metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column(
+ 'last_updated',
+ TIMESTAMP,
+ server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
+ )
+ )
+
+The same instructions apply to use of the :class:`_types.DateTime` and
+:class:`_types.DATETIME` datatypes::
+
+ from sqlalchemy import DateTime
+
+ mytable = Table(
+ "mytable",
+ metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column(
+ 'last_updated',
+ DateTime,
+ server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
+ )
+ )
+
+
+Even though the :paramref:`_schema.Column.server_onupdate` feature does not
+generate this DDL, it still may be desirable to signal to the ORM that this
+updated value should be fetched. This syntax looks like the following::
+
+ from sqlalchemy.schema import FetchedValue
+
+ class MyClass(Base):
+ __tablename__ = 'mytable'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(String(50))
+ last_updated = Column(
+ TIMESTAMP,
+ server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
+ server_onupdate=FetchedValue()
+ )
+
+
+.. _mysql_timestamp_null:
+
+TIMESTAMP Columns and NULL
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+MySQL historically enforces that a column which specifies the
+TIMESTAMP datatype implicitly includes a default value of
+CURRENT_TIMESTAMP, even though this is not stated, and additionally
+sets the column as NOT NULL, the opposite behavior vs. that of all
+other datatypes::
+
+ mysql> CREATE TABLE ts_test (
+ -> a INTEGER,
+ -> b INTEGER NOT NULL,
+ -> c TIMESTAMP,
+ -> d TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ -> e TIMESTAMP NULL);
+ Query OK, 0 rows affected (0.03 sec)
+
+ mysql> SHOW CREATE TABLE ts_test;
+ +---------+-----------------------------------------------------
+ | Table | Create Table
+ +---------+-----------------------------------------------------
+ | ts_test | CREATE TABLE `ts_test` (
+ `a` int(11) DEFAULT NULL,
+ `b` int(11) NOT NULL,
+ `c` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ `d` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ `e` timestamp NULL DEFAULT NULL
+ ) ENGINE=MyISAM DEFAULT CHARSET=latin1
+
+Above, we see that an INTEGER column defaults to NULL, unless it is specified
+with NOT NULL. But when the column is of type TIMESTAMP, an implicit
+default of CURRENT_TIMESTAMP is generated which also coerces the column
+to be a NOT NULL, even though we did not specify it as such.
+
+This behavior of MySQL can be changed on the MySQL side using the
+`explicit_defaults_for_timestamp
+<https://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html
+#sysvar_explicit_defaults_for_timestamp>`_ configuration flag introduced in
+MySQL 5.6. With this server setting enabled, TIMESTAMP columns behave like
+any other datatype on the MySQL side with regards to defaults and nullability.
+
+However, to accommodate the vast majority of MySQL databases that do not
+specify this new flag, SQLAlchemy emits the "NULL" specifier explicitly with
+any TIMESTAMP column that does not specify ``nullable=False``. In order to
+accommodate newer databases that specify ``explicit_defaults_for_timestamp``,
+SQLAlchemy also emits NOT NULL for TIMESTAMP columns that do specify
+``nullable=False``. The following example illustrates::
+
+ from sqlalchemy import MetaData, Integer, Table, Column, text
+ from sqlalchemy.dialects.mysql import TIMESTAMP
+
+ m = MetaData()
+ t = Table('ts_test', m,
+ Column('a', Integer),
+ Column('b', Integer, nullable=False),
+ Column('c', TIMESTAMP),
+ Column('d', TIMESTAMP, nullable=False)
+ )
+
+
+ from sqlalchemy import create_engine
+ e = create_engine("mysql://scott:tiger@localhost/test", echo=True)
+ m.create_all(e)
+
+output::
+
+ CREATE TABLE ts_test (
+ a INTEGER,
+ b INTEGER NOT NULL,
+ c TIMESTAMP NULL,
+ d TIMESTAMP NOT NULL
+ )
+
+.. versionchanged:: 1.0.0 - SQLAlchemy now renders NULL or NOT NULL in all
+ cases for TIMESTAMP columns, to accommodate
+ ``explicit_defaults_for_timestamp``. Prior to this version, it would
+ not render "NOT NULL" for a TIMESTAMP column that is ``nullable=False``.
+
+""" # noqa
+
+from array import array as _array
+from collections import defaultdict
+from itertools import compress
+import re
+
+from sqlalchemy import literal_column
+from sqlalchemy import text
+from sqlalchemy.sql import visitors
+from . import reflection as _reflection
+from .enumerated import ENUM
+from .enumerated import SET
+from .json import JSON
+from .json import JSONIndexType
+from .json import JSONPathType
+from .reserved_words import RESERVED_WORDS_MARIADB
+from .reserved_words import RESERVED_WORDS_MYSQL
+from .types import _FloatType
+from .types import _IntegerType
+from .types import _MatchType
+from .types import _NumericType
+from .types import _StringType
+from .types import BIGINT
+from .types import BIT
+from .types import CHAR
+from .types import DATETIME
+from .types import DECIMAL
+from .types import DOUBLE
+from .types import FLOAT
+from .types import INTEGER
+from .types import LONGBLOB
+from .types import LONGTEXT
+from .types import MEDIUMBLOB
+from .types import MEDIUMINT
+from .types import MEDIUMTEXT
+from .types import NCHAR
+from .types import NUMERIC
+from .types import NVARCHAR
+from .types import REAL
+from .types import SMALLINT
+from .types import TEXT
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TINYBLOB
+from .types import TINYINT
+from .types import TINYTEXT
+from .types import VARCHAR
+from .types import YEAR
+from ... import exc
+from ... import log
+from ... import schema as sa_schema
+from ... import sql
+from ... import types as sqltypes
+from ... import util
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import compiler
+from ...sql import elements
+from ...sql import functions
+from ...sql import operators
+from ...sql import roles
+from ...sql import util as sql_util
+from ...sql.sqltypes import Unicode
+from ...types import BINARY
+from ...types import BLOB
+from ...types import BOOLEAN
+from ...types import DATE
+from ...types import VARBINARY
+from ...util import topological
+
+AUTOCOMMIT_RE = re.compile(
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)",
+ re.I | re.UNICODE,
+)
+SET_RE = re.compile(
+ r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
+)
+
+
+# old names
+MSTime = TIME
+MSSet = SET
+MSEnum = ENUM
+MSLongBlob = LONGBLOB
+MSMediumBlob = MEDIUMBLOB
+MSTinyBlob = TINYBLOB
+MSBlob = BLOB
+MSBinary = BINARY
+MSVarBinary = VARBINARY
+MSNChar = NCHAR
+MSNVarChar = NVARCHAR
+MSChar = CHAR
+MSString = VARCHAR
+MSLongText = LONGTEXT
+MSMediumText = MEDIUMTEXT
+MSTinyText = TINYTEXT
+MSText = TEXT
+MSYear = YEAR
+MSTimeStamp = TIMESTAMP
+MSBit = BIT
+MSSmallInteger = SMALLINT
+MSTinyInteger = TINYINT
+MSMediumInteger = MEDIUMINT
+MSBigInteger = BIGINT
+MSNumeric = NUMERIC
+MSDecimal = DECIMAL
+MSDouble = DOUBLE
+MSReal = REAL
+MSFloat = FLOAT
+MSInteger = INTEGER
+
+colspecs = {
+ _IntegerType: _IntegerType,
+ _NumericType: _NumericType,
+ _FloatType: _FloatType,
+ sqltypes.Numeric: NUMERIC,
+ sqltypes.Float: FLOAT,
+ sqltypes.Time: TIME,
+ sqltypes.Enum: ENUM,
+ sqltypes.MatchType: _MatchType,
+ sqltypes.JSON: JSON,
+ sqltypes.JSON.JSONIndexType: JSONIndexType,
+ sqltypes.JSON.JSONPathType: JSONPathType,
+}
+
+# Everything 3.23 through 5.1 excepting OpenGIS types.
+ischema_names = {
+ "bigint": BIGINT,
+ "binary": BINARY,
+ "bit": BIT,
+ "blob": BLOB,
+ "boolean": BOOLEAN,
+ "char": CHAR,
+ "date": DATE,
+ "datetime": DATETIME,
+ "decimal": DECIMAL,
+ "double": DOUBLE,
+ "enum": ENUM,
+ "fixed": DECIMAL,
+ "float": FLOAT,
+ "int": INTEGER,
+ "integer": INTEGER,
+ "json": JSON,
+ "longblob": LONGBLOB,
+ "longtext": LONGTEXT,
+ "mediumblob": MEDIUMBLOB,
+ "mediumint": MEDIUMINT,
+ "mediumtext": MEDIUMTEXT,
+ "nchar": NCHAR,
+ "nvarchar": NVARCHAR,
+ "numeric": NUMERIC,
+ "set": SET,
+ "smallint": SMALLINT,
+ "text": TEXT,
+ "time": TIME,
+ "timestamp": TIMESTAMP,
+ "tinyblob": TINYBLOB,
+ "tinyint": TINYINT,
+ "tinytext": TINYTEXT,
+ "varbinary": VARBINARY,
+ "varchar": VARCHAR,
+ "year": YEAR,
+}
+
+
+class MySQLExecutionContext(default.DefaultExecutionContext):
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_RE.match(statement)
+
+ def create_server_side_cursor(self):
+ if self.dialect.supports_server_side_cursors:
+ return self._dbapi_connection.cursor(self.dialect._sscursor)
+ else:
+ raise NotImplementedError()
+
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ (
+ "select nextval(%s)"
+ % self.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
+
+
+class MySQLCompiler(compiler.SQLCompiler):
+
+ render_table_with_column_in_update_from = True
+ """Overridden from base SQLCompiler value"""
+
+ extract_map = compiler.SQLCompiler.extract_map.copy()
+ extract_map.update({"milliseconds": "millisecond"})
+
+ def default_from(self):
+ """Called when a ``SELECT`` statement has no froms,
+ and no ``FROM`` clause is to be appended.
+
+ """
+ if self.stack:
+ stmt = self.stack[-1]["selectable"]
+ if stmt._where_criteria:
+ return " FROM DUAL"
+
+ return ""
+
+ def visit_random_func(self, fn, **kw):
+ return "rand%s" % self.function_argspec(fn)
+
+ def visit_sequence(self, seq, **kw):
+ return "nextval(%s)" % self.preparer.format_sequence(seq)
+
+ def visit_sysdate_func(self, fn, **kw):
+ return "SYSDATE()"
+
+ def _render_json_extract_from_binary(self, binary, operator, **kw):
+ # note we are intentionally calling upon the process() calls in the
+ # order in which they appear in the SQL String as this is used
+ # by positional parameter rendering
+
+ if binary.type._type_affinity is sqltypes.JSON:
+ return "JSON_EXTRACT(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ # for non-JSON, MySQL doesn't handle JSON null at all so it has to
+ # be explicit
+ case_expression = "CASE JSON_EXTRACT(%s, %s) WHEN 'null' THEN NULL" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ if binary.type._type_affinity is sqltypes.Integer:
+ type_expression = (
+ "ELSE CAST(JSON_EXTRACT(%s, %s) AS SIGNED INTEGER)"
+ % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ )
+ elif binary.type._type_affinity is sqltypes.Numeric:
+ if (
+ binary.type.scale is not None
+ and binary.type.precision is not None
+ ):
+ # using DECIMAL here because MySQL does not recognize NUMERIC
+ type_expression = (
+ "ELSE CAST(JSON_EXTRACT(%s, %s) AS DECIMAL(%s, %s))"
+ % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ binary.type.precision,
+ binary.type.scale,
+ )
+ )
+ else:
+ # FLOAT / REAL not added in MySQL til 8.0.17
+ type_expression = (
+ "ELSE JSON_EXTRACT(%s, %s)+0.0000000000000000000000"
+ % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ )
+ elif binary.type._type_affinity is sqltypes.Boolean:
+ # the NULL handling is particularly weird with boolean, so
+ # explicitly return true/false constants
+ type_expression = "WHEN true THEN true ELSE false"
+ elif binary.type._type_affinity is sqltypes.String:
+ # (gord): this fails with a JSON value that's a four byte unicode
+ # string. SQLite has the same problem at the moment
+ # (zzzeek): I'm not really sure. let's take a look at a test case
+ # that hits each backend and maybe make a requires rule for it?
+ type_expression = "ELSE JSON_UNQUOTE(JSON_EXTRACT(%s, %s))" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+ else:
+ # other affinity....this is not expected right now
+ type_expression = "ELSE JSON_EXTRACT(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ return case_expression + " " + type_expression + " END"
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ return self._render_json_extract_from_binary(binary, operator, **kw)
+
+ def visit_on_duplicate_key_update(self, on_duplicate, **kw):
+ statement = self.current_executable
+
+ if on_duplicate._parameter_ordering:
+ parameter_ordering = [
+ coercions.expect(roles.DMLColumnRole, key)
+ for key in on_duplicate._parameter_ordering
+ ]
+ ordered_keys = set(parameter_ordering)
+ cols = [
+ statement.table.c[key]
+ for key in parameter_ordering
+ if key in statement.table.c
+ ] + [c for c in statement.table.c if c.key not in ordered_keys]
+ else:
+ cols = statement.table.c
+
+ clauses = []
+ # traverses through all table columns to preserve table column order
+ for column in (col for col in cols if col.key in on_duplicate.update):
+
+ val = on_duplicate.update[column.key]
+
+ if coercions._is_literal(val):
+ val = elements.BindParameter(None, val, type_=column.type)
+ value_text = self.process(val.self_group(), use_schema=False)
+ else:
+
+ def replace(obj):
+ if (
+ isinstance(obj, elements.BindParameter)
+ and obj.type._isnull
+ ):
+ obj = obj._clone()
+ obj.type = column.type
+ return obj
+ elif (
+ isinstance(obj, elements.ColumnClause)
+ and obj.table is on_duplicate.inserted_alias
+ ):
+ obj = literal_column(
+ "VALUES(" + self.preparer.quote(obj.name) + ")"
+ )
+ return obj
+ else:
+ # element is not replaced
+ return None
+
+ val = visitors.replacement_traverse(val, {}, replace)
+ value_text = self.process(val.self_group(), use_schema=False)
+
+ name_text = self.preparer.quote(column.name)
+ clauses.append("%s = %s" % (name_text, value_text))
+
+ non_matching = set(on_duplicate.update) - set(c.key for c in cols)
+ if non_matching:
+ util.warn(
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
+ self.statement.table.name,
+ (", ".join("'%s'" % c for c in non_matching)),
+ )
+ )
+
+ return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses)
+
+ def visit_concat_op_binary(self, binary, operator, **kw):
+ return "concat(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ _match_valid_flag_combinations = frozenset(
+ (
+ # (boolean_mode, natural_language, query_expansion)
+ (False, False, False),
+ (True, False, False),
+ (False, True, False),
+ (False, False, True),
+ (False, True, True),
+ )
+ )
+
+ _match_flag_expressions = (
+ "IN BOOLEAN MODE",
+ "IN NATURAL LANGUAGE MODE",
+ "WITH QUERY EXPANSION",
+ )
+
+ def visit_mysql_match(self, element, **kw):
+ return self.visit_match_op_binary(element, element.operator, **kw)
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ """
+ Note that `mysql_boolean_mode` is enabled by default because of
+ backward compatibility
+ """
+
+ modifiers = binary.modifiers
+
+ boolean_mode = modifiers.get("mysql_boolean_mode", True)
+ natural_language = modifiers.get("mysql_natural_language", False)
+ query_expansion = modifiers.get("mysql_query_expansion", False)
+
+ flag_combination = (boolean_mode, natural_language, query_expansion)
+
+ if flag_combination not in self._match_valid_flag_combinations:
+ flags = (
+ "in_boolean_mode=%s" % boolean_mode,
+ "in_natural_language_mode=%s" % natural_language,
+ "with_query_expansion=%s" % query_expansion,
+ )
+
+ flags = ", ".join(flags)
+
+ raise exc.CompileError("Invalid MySQL match flags: %s" % flags)
+
+ match_clause = binary.left
+ match_clause = self.process(match_clause, **kw)
+ against_clause = self.process(binary.right, **kw)
+
+ if any(flag_combination):
+ flag_expressions = compress(
+ self._match_flag_expressions,
+ flag_combination,
+ )
+
+ against_clause = [against_clause]
+ against_clause.extend(flag_expressions)
+
+ against_clause = " ".join(against_clause)
+
+ return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause)
+
+ def get_from_hint_text(self, table, text):
+ return text
+
+ def visit_typeclause(self, typeclause, type_=None, **kw):
+ if type_ is None:
+ type_ = typeclause.type.dialect_impl(self.dialect)
+ if isinstance(type_, sqltypes.TypeDecorator):
+ return self.visit_typeclause(typeclause, type_.impl, **kw)
+ elif isinstance(type_, sqltypes.Integer):
+ if getattr(type_, "unsigned", False):
+ return "UNSIGNED INTEGER"
+ else:
+ return "SIGNED INTEGER"
+ elif isinstance(type_, sqltypes.TIMESTAMP):
+ return "DATETIME"
+ elif isinstance(
+ type_,
+ (
+ sqltypes.DECIMAL,
+ sqltypes.DateTime,
+ sqltypes.Date,
+ sqltypes.Time,
+ ),
+ ):
+ return self.dialect.type_compiler.process(type_)
+ elif isinstance(type_, sqltypes.String) and not isinstance(
+ type_, (ENUM, SET)
+ ):
+ adapted = CHAR._adapt_string_for_cast(type_)
+ return self.dialect.type_compiler.process(adapted)
+ elif isinstance(type_, sqltypes._Binary):
+ return "BINARY"
+ elif isinstance(type_, sqltypes.JSON):
+ return "JSON"
+ elif isinstance(type_, sqltypes.NUMERIC):
+ return self.dialect.type_compiler.process(type_).replace(
+ "NUMERIC", "DECIMAL"
+ )
+ elif (
+ isinstance(type_, sqltypes.Float)
+ and self.dialect._support_float_cast
+ ):
+ return self.dialect.type_compiler.process(type_)
+ else:
+ return None
+
+ def visit_cast(self, cast, **kw):
+ type_ = self.process(cast.typeclause)
+ if type_ is None:
+ util.warn(
+ "Datatype %s does not support CAST on MySQL/MariaDb; "
+ "the CAST will be skipped."
+ % self.dialect.type_compiler.process(cast.typeclause.type)
+ )
+ return self.process(cast.clause.self_group(), **kw)
+
+ return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_)
+
+ def render_literal_value(self, value, type_):
+ value = super(MySQLCompiler, self).render_literal_value(value, type_)
+ if self.dialect._backslash_escapes:
+ value = value.replace("\\", "\\\\")
+ return value
+
+ # override native_boolean=False behavior here, as
+ # MySQL still supports native boolean
+ def visit_true(self, element, **kw):
+ return "true"
+
+ def visit_false(self, element, **kw):
+ return "false"
+
+ def get_select_precolumns(self, select, **kw):
+ """Add special MySQL keywords in place of DISTINCT.
+
+ .. deprecated:: 1.4 this usage is deprecated.
+ :meth:`_expression.Select.prefix_with` should be used for special
+ keywords at the start of a SELECT.
+
+ """
+ if isinstance(select._distinct, util.string_types):
+ util.warn_deprecated(
+ "Sending string values for 'distinct' is deprecated in the "
+ "MySQL dialect and will be removed in a future release. "
+ "Please use :meth:`.Select.prefix_with` for special keywords "
+ "at the start of a SELECT statement",
+ version="1.4",
+ )
+ return select._distinct.upper() + " "
+
+ return super(MySQLCompiler, self).get_select_precolumns(select, **kw)
+
+ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
+ if join.full:
+ join_type = " FULL OUTER JOIN "
+ elif join.isouter:
+ join_type = " LEFT OUTER JOIN "
+ else:
+ join_type = " INNER JOIN "
+
+ return "".join(
+ (
+ self.process(
+ join.left, asfrom=True, from_linter=from_linter, **kwargs
+ ),
+ join_type,
+ self.process(
+ join.right, asfrom=True, from_linter=from_linter, **kwargs
+ ),
+ " ON ",
+ self.process(join.onclause, from_linter=from_linter, **kwargs),
+ )
+ )
+
+ def for_update_clause(self, select, **kw):
+ if select._for_update_arg.read:
+ tmp = " LOCK IN SHARE MODE"
+ else:
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of and self.dialect.supports_for_update_of:
+
+ tables = util.OrderedSet()
+ for c in select._for_update_arg.of:
+ tables.update(sql_util.surface_selectables_only(c))
+
+ tmp += " OF " + ", ".join(
+ self.process(table, ashint=True, use_schema=False, **kw)
+ for table in tables
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+
+ if select._for_update_arg.skip_locked:
+ tmp += " SKIP LOCKED"
+
+ return tmp
+
+ def limit_clause(self, select, **kw):
+ # MySQL supports:
+ # LIMIT <limit>
+ # LIMIT <offset>, <limit>
+ # and in server versions > 3.3:
+ # LIMIT <limit> OFFSET <offset>
+ # The latter is more readable for offsets but we're stuck with the
+ # former until we can refine dialects by server revision.
+
+ limit_clause, offset_clause = (
+ select._limit_clause,
+ select._offset_clause,
+ )
+
+ if limit_clause is None and offset_clause is None:
+ return ""
+ elif offset_clause is not None:
+ # As suggested by the MySQL docs, need to apply an
+ # artificial limit if one wasn't provided
+ # https://dev.mysql.com/doc/refman/5.0/en/select.html
+ if limit_clause is None:
+ # hardwire the upper limit. Currently
+ # needed by OurSQL with Python 3
+ # (https://bugs.launchpad.net/oursql/+bug/686232),
+ # but also is consistent with the usage of the upper
+ # bound as part of MySQL's "syntax" for OFFSET with
+ # no LIMIT
+ return " \n LIMIT %s, %s" % (
+ self.process(offset_clause, **kw),
+ "18446744073709551615",
+ )
+ else:
+ return " \n LIMIT %s, %s" % (
+ self.process(offset_clause, **kw),
+ self.process(limit_clause, **kw),
+ )
+ else:
+ # No offset provided, so just use the limit
+ return " \n LIMIT %s" % (self.process(limit_clause, **kw),)
+
+ def update_limit_clause(self, update_stmt):
+ limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None)
+ if limit:
+ return "LIMIT %s" % limit
+ else:
+ return None
+
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ kw["asfrom"] = True
+ return ", ".join(
+ t._compiler_dispatch(self, **kw)
+ for t in [from_table] + list(extra_froms)
+ )
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return None
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ """If we have extra froms make sure we render any alias as hint."""
+ ashint = False
+ if extra_froms:
+ ashint = True
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, ashint=ashint
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. USING clause specific to MySQL."""
+ kw["asfrom"] = True
+ return "USING " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+ def visit_empty_set_expr(self, element_types):
+ return (
+ "SELECT %(outer)s FROM (SELECT %(inner)s) "
+ "as _empty_set WHERE 1!=1"
+ % {
+ "inner": ", ".join(
+ "1 AS _in_%s" % idx
+ for idx, type_ in enumerate(element_types)
+ ),
+ "outer": ", ".join(
+ "_in_%s" % idx for idx, type_ in enumerate(element_types)
+ ),
+ }
+ )
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "NOT (%s <=> %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "%s <=> %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def _mariadb_regexp_flags(self, flags, pattern, **kw):
+ return "CONCAT('(?', %s, ')', %s)" % (
+ self.process(flags, **kw),
+ self.process(pattern, **kw),
+ )
+
+ def _regexp_match(self, op_string, binary, operator, **kw):
+ flags = binary.modifiers["flags"]
+ if flags is None:
+ return self._generate_generic_binary(binary, op_string, **kw)
+ elif self.dialect.is_mariadb:
+ return "%s%s%s" % (
+ self.process(binary.left, **kw),
+ op_string,
+ self._mariadb_regexp_flags(flags, binary.right),
+ )
+ else:
+ text = "REGEXP_LIKE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ self.process(flags, **kw),
+ )
+ if op_string == " NOT REGEXP ":
+ return "NOT %s" % text
+ else:
+ return text
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match(" REGEXP ", binary, operator, **kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match(" NOT REGEXP ", binary, operator, **kw)
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ flags = binary.modifiers["flags"]
+ replacement = binary.modifiers["replacement"]
+ if flags is None:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ self.process(replacement, **kw),
+ )
+ elif self.dialect.is_mariadb:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self._mariadb_regexp_flags(flags, binary.right),
+ self.process(replacement, **kw),
+ )
+ else:
+ return "REGEXP_REPLACE(%s, %s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ self.process(replacement, **kw),
+ self.process(flags, **kw),
+ )
+
+
+class MySQLDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kw):
+ """Builds column DDL."""
+
+ colspec = [
+ self.preparer.format_column(column),
+ self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ ),
+ ]
+
+ if column.computed is not None:
+ colspec.append(self.process(column.computed))
+
+ is_timestamp = isinstance(
+ column.type._unwrapped_dialect_impl(self.dialect),
+ sqltypes.TIMESTAMP,
+ )
+
+ if not column.nullable:
+ colspec.append("NOT NULL")
+
+ # see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa
+ elif column.nullable and is_timestamp:
+ colspec.append("NULL")
+
+ comment = column.comment
+ if comment is not None:
+ literal = self.sql_compiler.render_literal_value(
+ comment, sqltypes.String()
+ )
+ colspec.append("COMMENT " + literal)
+
+ if (
+ column.table is not None
+ and column is column.table._autoincrement_column
+ and (
+ column.server_default is None
+ or isinstance(column.server_default, sa_schema.Identity)
+ )
+ and not (
+ self.dialect.supports_sequences
+ and isinstance(column.default, sa_schema.Sequence)
+ and not column.default.optional
+ )
+ ):
+ colspec.append("AUTO_INCREMENT")
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec.append("DEFAULT " + default)
+ return " ".join(colspec)
+
+ def post_create_table(self, table):
+ """Build table-level CREATE options like ENGINE and COLLATE."""
+
+ table_opts = []
+
+ opts = dict(
+ (k[len(self.dialect.name) + 1 :].upper(), v)
+ for k, v in table.kwargs.items()
+ if k.startswith("%s_" % self.dialect.name)
+ )
+
+ if table.comment is not None:
+ opts["COMMENT"] = table.comment
+
+ partition_options = [
+ "PARTITION_BY",
+ "PARTITIONS",
+ "SUBPARTITIONS",
+ "SUBPARTITION_BY",
+ ]
+
+ nonpart_options = set(opts).difference(partition_options)
+ part_options = set(opts).intersection(partition_options)
+
+ for opt in topological.sort(
+ [
+ ("DEFAULT_CHARSET", "COLLATE"),
+ ("DEFAULT_CHARACTER_SET", "COLLATE"),
+ ("CHARSET", "COLLATE"),
+ ("CHARACTER_SET", "COLLATE"),
+ ],
+ nonpart_options,
+ ):
+ arg = opts[opt]
+ if opt in _reflection._options_of_type_string:
+
+ arg = self.sql_compiler.render_literal_value(
+ arg, sqltypes.String()
+ )
+
+ if opt in (
+ "DATA_DIRECTORY",
+ "INDEX_DIRECTORY",
+ "DEFAULT_CHARACTER_SET",
+ "CHARACTER_SET",
+ "DEFAULT_CHARSET",
+ "DEFAULT_COLLATE",
+ ):
+ opt = opt.replace("_", " ")
+
+ joiner = "="
+ if opt in (
+ "TABLESPACE",
+ "DEFAULT CHARACTER SET",
+ "CHARACTER SET",
+ "COLLATE",
+ ):
+ joiner = " "
+
+ table_opts.append(joiner.join((opt, arg)))
+
+ for opt in topological.sort(
+ [
+ ("PARTITION_BY", "PARTITIONS"),
+ ("PARTITION_BY", "SUBPARTITION_BY"),
+ ("PARTITION_BY", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITIONS"),
+ ("PARTITIONS", "SUBPARTITION_BY"),
+ ("SUBPARTITION_BY", "SUBPARTITIONS"),
+ ],
+ part_options,
+ ):
+ arg = opts[opt]
+ if opt in _reflection._options_of_type_string:
+ arg = self.sql_compiler.render_literal_value(
+ arg, sqltypes.String()
+ )
+
+ opt = opt.replace("_", " ")
+ joiner = " "
+
+ table_opts.append(joiner.join((opt, arg)))
+
+ return " ".join(table_opts)
+
+ def visit_create_index(self, create, **kw):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ table = preparer.format_table(index.table)
+
+ columns = [
+ self.sql_compiler.process(
+ elements.Grouping(expr)
+ if (
+ isinstance(expr, elements.BinaryExpression)
+ or (
+ isinstance(expr, elements.UnaryExpression)
+ and expr.modifier
+ not in (operators.desc_op, operators.asc_op)
+ )
+ or isinstance(expr, functions.FunctionElement)
+ )
+ else expr,
+ include_table=False,
+ literal_binds=True,
+ )
+ for expr in index.expressions
+ ]
+
+ name = self._prepared_index_name(index)
+
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+
+ index_prefix = index.kwargs.get("%s_prefix" % self.dialect.name, None)
+ if index_prefix:
+ text += index_prefix + " "
+
+ text += "INDEX "
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+ text += "%s ON %s " % (name, table)
+
+ length = index.dialect_options[self.dialect.name]["length"]
+ if length is not None:
+
+ if isinstance(length, dict):
+ # length value can be a (column_name --> integer value)
+ # mapping specifying the prefix length for each column of the
+ # index
+ columns = ", ".join(
+ "%s(%d)" % (expr, length[col.name])
+ if col.name in length
+ else (
+ "%s(%d)" % (expr, length[expr])
+ if expr in length
+ else "%s" % expr
+ )
+ for col, expr in zip(index.expressions, columns)
+ )
+ else:
+ # or can be an integer value specifying the same
+ # prefix length for all columns of the index
+ columns = ", ".join(
+ "%s(%d)" % (col, length) for col in columns
+ )
+ else:
+ columns = ", ".join(columns)
+ text += "(%s)" % columns
+
+ parser = index.dialect_options["mysql"]["with_parser"]
+ if parser is not None:
+ text += " WITH PARSER %s" % (parser,)
+
+ using = index.dialect_options["mysql"]["using"]
+ if using is not None:
+ text += " USING %s" % (preparer.quote(using))
+
+ return text
+
+ def visit_primary_key_constraint(self, constraint):
+ text = super(MySQLDDLCompiler, self).visit_primary_key_constraint(
+ constraint
+ )
+ using = constraint.dialect_options["mysql"]["using"]
+ if using:
+ text += " USING %s" % (self.preparer.quote(using))
+ return text
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+ text = "\nDROP INDEX "
+ if drop.if_exists:
+ text += "IF EXISTS "
+
+ return text + "%s ON %s" % (
+ self._prepared_index_name(index, include_schema=False),
+ self.preparer.format_table(index.table),
+ )
+
+ def visit_drop_constraint(self, drop):
+ constraint = drop.element
+ if isinstance(constraint, sa_schema.ForeignKeyConstraint):
+ qual = "FOREIGN KEY "
+ const = self.preparer.format_constraint(constraint)
+ elif isinstance(constraint, sa_schema.PrimaryKeyConstraint):
+ qual = "PRIMARY KEY "
+ const = ""
+ elif isinstance(constraint, sa_schema.UniqueConstraint):
+ qual = "INDEX "
+ const = self.preparer.format_constraint(constraint)
+ elif isinstance(constraint, sa_schema.CheckConstraint):
+ if self.dialect.is_mariadb:
+ qual = "CONSTRAINT "
+ else:
+ qual = "CHECK "
+ const = self.preparer.format_constraint(constraint)
+ else:
+ qual = ""
+ const = self.preparer.format_constraint(constraint)
+ return "ALTER TABLE %s DROP %s%s" % (
+ self.preparer.format_table(constraint.table),
+ qual,
+ const,
+ )
+
+ def define_constraint_match(self, constraint):
+ if constraint.match is not None:
+ raise exc.CompileError(
+ "MySQL ignores the 'MATCH' keyword while at the same time "
+ "causes ON UPDATE/ON DELETE clauses to be ignored."
+ )
+ return ""
+
+ def visit_set_table_comment(self, create):
+ return "ALTER TABLE %s COMMENT %s" % (
+ self.preparer.format_table(create.element),
+ self.sql_compiler.render_literal_value(
+ create.element.comment, sqltypes.String()
+ ),
+ )
+
+ def visit_drop_table_comment(self, create):
+ return "ALTER TABLE %s COMMENT ''" % (
+ self.preparer.format_table(create.element)
+ )
+
+ def visit_set_column_comment(self, create):
+ return "ALTER TABLE %s CHANGE %s %s" % (
+ self.preparer.format_table(create.element.table),
+ self.preparer.format_column(create.element),
+ self.get_column_specification(create.element),
+ )
+
+
+class MySQLTypeCompiler(compiler.GenericTypeCompiler):
+ def _extend_numeric(self, type_, spec):
+ "Extend a numeric-type declaration with MySQL specific extensions."
+
+ if not self._mysql_type(type_):
+ return spec
+
+ if type_.unsigned:
+ spec += " UNSIGNED"
+ if type_.zerofill:
+ spec += " ZEROFILL"
+ return spec
+
+ def _extend_string(self, type_, defaults, spec):
+ """Extend a string-type declaration with standard SQL CHARACTER SET /
+ COLLATE annotations and MySQL specific extensions.
+
+ """
+
+ def attr(name):
+ return getattr(type_, name, defaults.get(name))
+
+ if attr("charset"):
+ charset = "CHARACTER SET %s" % attr("charset")
+ elif attr("ascii"):
+ charset = "ASCII"
+ elif attr("unicode"):
+ charset = "UNICODE"
+ else:
+ charset = None
+
+ if attr("collation"):
+ collation = "COLLATE %s" % type_.collation
+ elif attr("binary"):
+ collation = "BINARY"
+ else:
+ collation = None
+
+ if attr("national"):
+ # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets.
+ return " ".join(
+ [c for c in ("NATIONAL", spec, collation) if c is not None]
+ )
+ return " ".join(
+ [c for c in (spec, charset, collation) if c is not None]
+ )
+
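+    # For example (hypothetical column type), a VARCHAR(30) declared with
+    # charset="utf8mb4" and collation="utf8mb4_bin" comes back from
+    # _extend_string above as
+    #     "VARCHAR(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin"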
+ def _mysql_type(self, type_):
+ return isinstance(type_, (_StringType, _NumericType))
+
+ def visit_NUMERIC(self, type_, **kw):
+ if type_.precision is None:
+ return self._extend_numeric(type_, "NUMERIC")
+ elif type_.scale is None:
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s)" % {"precision": type_.precision},
+ )
+ else:
+ return self._extend_numeric(
+ type_,
+ "NUMERIC(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+
+ def visit_DECIMAL(self, type_, **kw):
+ if type_.precision is None:
+ return self._extend_numeric(type_, "DECIMAL")
+ elif type_.scale is None:
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s)" % {"precision": type_.precision},
+ )
+ else:
+ return self._extend_numeric(
+ type_,
+ "DECIMAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+
+ def visit_DOUBLE(self, type_, **kw):
+ if type_.precision is not None and type_.scale is not None:
+ return self._extend_numeric(
+ type_,
+ "DOUBLE(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+ else:
+ return self._extend_numeric(type_, "DOUBLE")
+
+ def visit_REAL(self, type_, **kw):
+ if type_.precision is not None and type_.scale is not None:
+ return self._extend_numeric(
+ type_,
+ "REAL(%(precision)s, %(scale)s)"
+ % {"precision": type_.precision, "scale": type_.scale},
+ )
+ else:
+ return self._extend_numeric(type_, "REAL")
+
+ def visit_FLOAT(self, type_, **kw):
+ if (
+ self._mysql_type(type_)
+ and type_.scale is not None
+ and type_.precision is not None
+ ):
+ return self._extend_numeric(
+ type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)
+ )
+ elif type_.precision is not None:
+ return self._extend_numeric(
+ type_, "FLOAT(%s)" % (type_.precision,)
+ )
+ else:
+ return self._extend_numeric(type_, "FLOAT")
+
+ def visit_INTEGER(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "INTEGER(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "INTEGER")
+
+ def visit_BIGINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "BIGINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "BIGINT")
+
+ def visit_MEDIUMINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "MEDIUMINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "MEDIUMINT")
+
+ def visit_TINYINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_, "TINYINT(%s)" % type_.display_width
+ )
+ else:
+ return self._extend_numeric(type_, "TINYINT")
+
+ def visit_SMALLINT(self, type_, **kw):
+ if self._mysql_type(type_) and type_.display_width is not None:
+ return self._extend_numeric(
+ type_,
+ "SMALLINT(%(display_width)s)"
+ % {"display_width": type_.display_width},
+ )
+ else:
+ return self._extend_numeric(type_, "SMALLINT")
+
+ def visit_BIT(self, type_, **kw):
+ if type_.length is not None:
+ return "BIT(%s)" % type_.length
+ else:
+ return "BIT"
+
+ def visit_DATETIME(self, type_, **kw):
+ if getattr(type_, "fsp", None):
+ return "DATETIME(%d)" % type_.fsp
+ else:
+ return "DATETIME"
+
+ def visit_DATE(self, type_, **kw):
+ return "DATE"
+
+ def visit_TIME(self, type_, **kw):
+ if getattr(type_, "fsp", None):
+ return "TIME(%d)" % type_.fsp
+ else:
+ return "TIME"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ if getattr(type_, "fsp", None):
+ return "TIMESTAMP(%d)" % type_.fsp
+ else:
+ return "TIMESTAMP"
+
+ def visit_YEAR(self, type_, **kw):
+ if type_.display_width is None:
+ return "YEAR"
+ else:
+ return "YEAR(%s)" % type_.display_width
+
+ def visit_TEXT(self, type_, **kw):
+ if type_.length:
+ return self._extend_string(type_, {}, "TEXT(%d)" % type_.length)
+ else:
+ return self._extend_string(type_, {}, "TEXT")
+
+ def visit_TINYTEXT(self, type_, **kw):
+ return self._extend_string(type_, {}, "TINYTEXT")
+
+ def visit_MEDIUMTEXT(self, type_, **kw):
+ return self._extend_string(type_, {}, "MEDIUMTEXT")
+
+ def visit_LONGTEXT(self, type_, **kw):
+ return self._extend_string(type_, {}, "LONGTEXT")
+
+ def visit_VARCHAR(self, type_, **kw):
+ if type_.length:
+ return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length)
+ else:
+ raise exc.CompileError(
+ "VARCHAR requires a length on dialect %s" % self.dialect.name
+ )
+
+ def visit_CHAR(self, type_, **kw):
+ if type_.length:
+ return self._extend_string(
+ type_, {}, "CHAR(%(length)s)" % {"length": type_.length}
+ )
+ else:
+ return self._extend_string(type_, {}, "CHAR")
+
+ def visit_NVARCHAR(self, type_, **kw):
+ # We'll actually generate the equiv. "NATIONAL VARCHAR" instead
+ # of "NVARCHAR".
+ if type_.length:
+ return self._extend_string(
+ type_,
+ {"national": True},
+ "VARCHAR(%(length)s)" % {"length": type_.length},
+ )
+ else:
+ raise exc.CompileError(
+ "NVARCHAR requires a length on dialect %s" % self.dialect.name
+ )
+
+ def visit_NCHAR(self, type_, **kw):
+ # We'll actually generate the equiv.
+ # "NATIONAL CHAR" instead of "NCHAR".
+ if type_.length:
+ return self._extend_string(
+ type_,
+ {"national": True},
+ "CHAR(%(length)s)" % {"length": type_.length},
+ )
+ else:
+ return self._extend_string(type_, {"national": True}, "CHAR")
+
+ def visit_VARBINARY(self, type_, **kw):
+ return "VARBINARY(%d)" % type_.length
+
+ def visit_JSON(self, type_, **kw):
+ return "JSON"
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_)
+
+ def visit_enum(self, type_, **kw):
+ if not type_.native_enum:
+ return super(MySQLTypeCompiler, self).visit_enum(type_)
+ else:
+ return self._visit_enumerated_values("ENUM", type_, type_.enums)
+
+ def visit_BLOB(self, type_, **kw):
+ if type_.length:
+ return "BLOB(%d)" % type_.length
+ else:
+ return "BLOB"
+
+ def visit_TINYBLOB(self, type_, **kw):
+ return "TINYBLOB"
+
+ def visit_MEDIUMBLOB(self, type_, **kw):
+ return "MEDIUMBLOB"
+
+ def visit_LONGBLOB(self, type_, **kw):
+ return "LONGBLOB"
+
+ def _visit_enumerated_values(self, name, type_, enumerated_values):
+ quoted_enums = []
+ for e in enumerated_values:
+ quoted_enums.append("'%s'" % e.replace("'", "''"))
+ return self._extend_string(
+ type_, {}, "%s(%s)" % (name, ",".join(quoted_enums))
+ )
+
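+    # For example, an ENUM("small", "don't") column is rendered by
+    # _visit_enumerated_values above as ENUM('small','don''t'), with any
+    # embedded single quotes doubled.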
+ def visit_ENUM(self, type_, **kw):
+ return self._visit_enumerated_values("ENUM", type_, type_.enums)
+
+ def visit_SET(self, type_, **kw):
+ return self._visit_enumerated_values("SET", type_, type_.values)
+
+ def visit_BOOLEAN(self, type_, **kw):
+ return "BOOL"
+
+
+class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS_MYSQL
+
+ def __init__(self, dialect, server_ansiquotes=False, **kw):
+ if not server_ansiquotes:
+ quote = "`"
+ else:
+ quote = '"'
+
+ super(MySQLIdentifierPreparer, self).__init__(
+ dialect, initial_quote=quote, escape_quote=quote
+ )
+
+ def _quote_free_identifiers(self, *ids):
+ """Unilaterally identifier-quote any number of strings."""
+
+ return tuple([self.quote_identifier(i) for i in ids if i is not None])
+
+
+class MariaDBIdentifierPreparer(MySQLIdentifierPreparer):
+ reserved_words = RESERVED_WORDS_MARIADB
+
+
+@log.class_logger
+class MySQLDialect(default.DefaultDialect):
+ """Details of the MySQL dialect.
+ Not used directly in application code.
+ """
+
+ name = "mysql"
+ supports_statement_cache = True
+
+ supports_alter = True
+
+ # MySQL has no true "boolean" type; we
+ # allow for the "true" and "false" keywords, however
+ supports_native_boolean = False
+
+    # identifiers are limited to 64 characters, however aliases can be 255...
+ max_identifier_length = 255
+ max_index_name_length = 64
+ max_constraint_name_length = 64
+
+ supports_native_enum = True
+
+ supports_sequences = False # default for MySQL ...
+ # ... may be updated to True for MariaDB 10.3+ in initialize()
+
+ sequences_optional = False
+
+ supports_for_update_of = False # default for MySQL ...
+ # ... may be updated to True for MySQL 8+ in initialize()
+
+ # MySQL doesn't support "DEFAULT VALUES" but *does* support
+ # "VALUES (DEFAULT)"
+ supports_default_values = False
+ supports_default_metavalue = True
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+ supports_multivalues_insert = True
+
+ supports_comments = True
+ inline_comments = True
+ default_paramstyle = "format"
+ colspecs = colspecs
+
+ cte_follows_insert = True
+
+ statement_compiler = MySQLCompiler
+ ddl_compiler = MySQLDDLCompiler
+ type_compiler = MySQLTypeCompiler
+ ischema_names = ischema_names
+ preparer = MySQLIdentifierPreparer
+
+ is_mariadb = False
+ _mariadb_normalized_version_info = None
+
+ # default SQL compilation settings -
+ # these are modified upon initialize(),
+ # i.e. first connect
+ _backslash_escapes = True
+ _server_ansiquotes = False
+
+ construct_arguments = [
+ (sa_schema.Table, {"*": None}),
+ (sql.Update, {"limit": None}),
+ (sa_schema.PrimaryKeyConstraint, {"using": None}),
+ (
+ sa_schema.Index,
+ {
+ "using": None,
+ "length": None,
+ "prefix": None,
+ "with_parser": None,
+ },
+ ),
+ ]
+
+ def __init__(
+ self,
+ isolation_level=None,
+ json_serializer=None,
+ json_deserializer=None,
+ is_mariadb=None,
+ **kwargs
+ ):
+ kwargs.pop("use_ansiquotes", None) # legacy
+ default.DefaultDialect.__init__(self, **kwargs)
+ self.isolation_level = isolation_level
+ self._json_serializer = json_serializer
+ self._json_deserializer = json_deserializer
+ self._set_mariadb(is_mariadb, None)
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ ]
+ )
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+
+ # adjust for ConnectionFairy being present
+ # allows attribute set e.g. "connection.autocommit = True"
+ # to work properly
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ self._set_isolation_level(connection, level)
+
+ def _set_isolation_level(self, connection, level):
+ if level not in self._isolation_lookup:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+ cursor = connection.cursor()
+ cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level)
+ cursor.execute("COMMIT")
+ cursor.close()
+
+ def get_isolation_level(self, connection):
+ cursor = connection.cursor()
+ if self._is_mysql and self.server_version_info >= (5, 7, 20):
+ cursor.execute("SELECT @@transaction_isolation")
+ else:
+ cursor.execute("SELECT @@tx_isolation")
+ row = cursor.fetchone()
+ if row is None:
+ util.warn(
+ "Could not retrieve transaction isolation level for MySQL "
+ "connection."
+ )
+ raise NotImplementedError()
+ val = row[0]
+ cursor.close()
+ if util.py3k and isinstance(val, bytes):
+ val = val.decode()
+ return val.upper().replace("-", " ")
+
+ @classmethod
+ def _is_mariadb_from_url(cls, url):
+ dbapi = cls.dbapi()
+ dialect = cls(dbapi=dbapi)
+
+ cargs, cparams = dialect.create_connect_args(url)
+ conn = dialect.connect(*cargs, **cparams)
+ try:
+ cursor = conn.cursor()
+ cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
+ val = cursor.fetchone()[0]
+ except:
+ raise
+ else:
+ return bool(val)
+ finally:
+ conn.close()
+
+ def _get_server_version_info(self, connection):
+ # get database server version info explicitly over the wire
+ # to avoid proxy servers like MaxScale getting in the
+ # way with their own values, see #4205
+ dbapi_con = connection.connection
+ cursor = dbapi_con.cursor()
+ cursor.execute("SELECT VERSION()")
+ val = cursor.fetchone()[0]
+ cursor.close()
+ if util.py3k and isinstance(val, bytes):
+ val = val.decode()
+
+ return self._parse_server_version(val)
+
+ def _parse_server_version(self, val):
+ version = []
+ is_mariadb = False
+
+ r = re.compile(r"[.\-+]")
+ tokens = r.split(val)
+ for token in tokens:
+ parsed_token = re.match(
+ r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token
+ )
+ if not parsed_token:
+ continue
+ elif parsed_token.group(2):
+ self._mariadb_normalized_version_info = tuple(version[-3:])
+ is_mariadb = True
+ else:
+ digit = int(parsed_token.group(1))
+ version.append(digit)
+
+ server_version_info = tuple(version)
+
+ self._set_mariadb(server_version_info and is_mariadb, val)
+
+ if not is_mariadb:
+ self._mariadb_normalized_version_info = server_version_info
+
+ if server_version_info < (5, 0, 2):
+ raise NotImplementedError(
+ "the MySQL/MariaDB dialect supports server "
+ "version info 5.0.2 and above."
+ )
+
+ # setting it here to help w the test suite
+ self.server_version_info = server_version_info
+ return server_version_info
+
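+    # Illustrative parses (hypothetical version strings): "5.7.20-log" yields
+    # a server_version_info of (5, 7, 20) with is_mariadb False, while
+    # "10.4.12-MariaDB" yields (10, 4, 12) and marks the dialect as MariaDB.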
+ def _set_mariadb(self, is_mariadb, server_version_info):
+ if is_mariadb is None:
+ return
+
+ if not is_mariadb and self.is_mariadb:
+ raise exc.InvalidRequestError(
+ "MySQL version %s is not a MariaDB variant."
+ % (server_version_info,)
+ )
+ if is_mariadb:
+ self.preparer = MariaDBIdentifierPreparer
+ # this would have been set by the default dialect already,
+ # so set it again
+ self.identifier_preparer = self.preparer(self)
+ self.is_mariadb = is_mariadb
+
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(sql.text("XA END :xid"), dict(xid=xid))
+ connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid))
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ connection.execute(sql.text("XA END :xid"), dict(xid=xid))
+ connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid))
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid))
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.exec_driver_sql("XA RECOVER")
+ return [row["data"][0 : row["gtrid_length"]] for row in resultset]
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e,
+ (
+ self.dbapi.OperationalError,
+ self.dbapi.ProgrammingError,
+ self.dbapi.InterfaceError,
+ ),
+ ) and self._extract_error_code(e) in (
+ 1927,
+ 2006,
+ 2013,
+ 2014,
+ 2045,
+ 2055,
+ 4031,
+ ):
+ return True
+ elif isinstance(
+ e, (self.dbapi.InterfaceError, self.dbapi.InternalError)
+ ):
+ # if underlying connection is closed,
+ # this is the error you get
+ return "(0, '')" in str(e)
+ else:
+ return False
+
+ def _compat_fetchall(self, rp, charset=None):
+ """Proxy result rows to smooth over MySQL-Python driver
+ inconsistencies."""
+
+ return [_DecodingRow(row, charset) for row in rp.fetchall()]
+
+ def _compat_fetchone(self, rp, charset=None):
+ """Proxy a result row to smooth over MySQL-Python driver
+ inconsistencies."""
+
+ row = rp.fetchone()
+ if row:
+ return _DecodingRow(row, charset)
+ else:
+ return None
+
+ def _compat_first(self, rp, charset=None):
+ """Proxy a result row to smooth over MySQL-Python driver
+ inconsistencies."""
+
+ row = rp.first()
+ if row:
+ return _DecodingRow(row, charset)
+ else:
+ return None
+
+ def _extract_error_code(self, exception):
+ raise NotImplementedError()
+
+ def _get_default_schema_name(self, connection):
+ return connection.exec_driver_sql("SELECT DATABASE()").scalar()
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ if schema is None:
+ schema = self.default_schema_name
+
+ rs = connection.execute(
+ text(
+ "SELECT COUNT(*) FROM information_schema.tables WHERE "
+ "table_schema = :table_schema AND "
+ "table_name = :table_name"
+ ).bindparams(
+ sql.bindparam("table_schema", type_=Unicode),
+ sql.bindparam("table_name", type_=Unicode),
+ ),
+ {
+ "table_schema": util.text_type(schema),
+ "table_name": util.text_type(table_name),
+ },
+ )
+ return bool(rs.scalar())
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ if not self.supports_sequences:
+ self._sequences_not_supported()
+ if not schema:
+ schema = self.default_schema_name
+ # MariaDB implements sequences as a special type of table
+ #
+ cursor = connection.execute(
+ sql.text(
+ "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
+ "WHERE TABLE_TYPE='SEQUENCE' and TABLE_NAME=:name AND "
+ "TABLE_SCHEMA=:schema_name"
+ ),
+ dict(
+ name=util.text_type(sequence_name),
+ schema_name=util.text_type(schema),
+ ),
+ )
+ return cursor.first() is not None
+
+ def _sequences_not_supported(self):
+ raise NotImplementedError(
+ "Sequences are supported only by the "
+ "MariaDB series 10.3 or greater"
+ )
+
+ @reflection.cache
+ def get_sequence_names(self, connection, schema=None, **kw):
+ if not self.supports_sequences:
+ self._sequences_not_supported()
+ if not schema:
+ schema = self.default_schema_name
+ # MariaDB implements sequences as a special type of table
+ cursor = connection.execute(
+ sql.text(
+ "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
+ "WHERE TABLE_TYPE='SEQUENCE' and TABLE_SCHEMA=:schema_name"
+ ),
+ dict(schema_name=schema),
+ )
+ return [
+ row[0]
+ for row in self._compat_fetchall(
+ cursor, charset=self._connection_charset
+ )
+ ]
+
+ def initialize(self, connection):
+ # this is driver-based, does not need server version info
+ # and is fairly critical for even basic SQL operations
+ self._connection_charset = self._detect_charset(connection)
+
+ # call super().initialize() because we need to have
+ # server_version_info set up. in 1.4 under python 2 only this does the
+ # "check unicode returns" thing, which is the one area that some
+ # SQL gets compiled within initialize() currently
+ default.DefaultDialect.initialize(self, connection)
+
+ self._detect_sql_mode(connection)
+ self._detect_ansiquotes(connection) # depends on sql mode
+ self._detect_casing(connection)
+ if self._server_ansiquotes:
+ # if ansiquotes == True, build a new IdentifierPreparer
+ # with the new setting
+ self.identifier_preparer = self.preparer(
+ self, server_ansiquotes=self._server_ansiquotes
+ )
+
+ self.supports_sequences = (
+ self.is_mariadb and self.server_version_info >= (10, 3)
+ )
+
+ self.supports_for_update_of = (
+ self._is_mysql and self.server_version_info >= (8,)
+ )
+
+ self._needs_correct_for_88718_96365 = (
+ not self.is_mariadb and self.server_version_info >= (8,)
+ )
+
+ self._warn_for_known_db_issues()
+
+ def _warn_for_known_db_issues(self):
+ if self.is_mariadb:
+ mdb_version = self._mariadb_normalized_version_info
+ if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
+ util.warn(
+ "MariaDB %r before 10.2.9 has known issues regarding "
+ "CHECK constraints, which impact handling of NULL values "
+ "with SQLAlchemy's boolean datatype (MDEV-13596). An "
+ "additional issue prevents proper migrations of columns "
+ "with CHECK constraints (MDEV-11114). Please upgrade to "
+ "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
+ "series, to avoid these issues." % (mdb_version,)
+ )
+
+ @property
+ def _support_float_cast(self):
+ if not self.server_version_info:
+ return False
+ elif self.is_mariadb:
+ # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+ return self.server_version_info >= (10, 4, 5)
+ else:
+ # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa
+ return self.server_version_info >= (8, 0, 17)
+
+ @property
+ def _is_mariadb(self):
+ return self.is_mariadb
+
+ @property
+ def _is_mysql(self):
+ return not self.is_mariadb
+
+ @property
+ def _is_mariadb_102(self):
+ return self.is_mariadb and self._mariadb_normalized_version_info > (
+ 10,
+ 2,
+ )
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ rp = connection.exec_driver_sql("SHOW schemas")
+ return [r[0] for r in rp]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ """Return a Unicode SHOW TABLES from a given schema."""
+ if schema is not None:
+ current_schema = schema
+ else:
+ current_schema = self.default_schema_name
+
+ charset = self._connection_charset
+
+ rp = connection.exec_driver_sql(
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(current_schema)
+ )
+
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] == "BASE TABLE"
+ ]
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+ charset = self._connection_charset
+ rp = connection.exec_driver_sql(
+ "SHOW FULL TABLES FROM %s"
+ % self.identifier_preparer.quote_identifier(schema)
+ )
+ return [
+ row[0]
+ for row in self._compat_fetchall(rp, charset=charset)
+ if row[1] in ("VIEW", "SYSTEM VIEW")
+ ]
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ return parsed_state.table_options
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ return parsed_state.columns
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ for key in parsed_state.keys:
+ if key["type"] == "PRIMARY":
+ # There can be only one.
+ cols = [s[0] for s in key["columns"]]
+ return {"constrained_columns": cols, "name": None}
+ return {"constrained_columns": [], "name": None}
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ default_schema = None
+
+ fkeys = []
+
+ for spec in parsed_state.fk_constraints:
+ ref_name = spec["table"][-1]
+ ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema
+
+ if not ref_schema:
+ if default_schema is None:
+ default_schema = connection.dialect.default_schema_name
+ if schema == default_schema:
+ ref_schema = schema
+
+ loc_names = spec["local"]
+ ref_names = spec["foreign"]
+
+ con_kw = {}
+ for opt in ("onupdate", "ondelete"):
+ if spec.get(opt, False) not in ("NO ACTION", None):
+ con_kw[opt] = spec[opt]
+
+ fkey_d = {
+ "name": spec["name"],
+ "constrained_columns": loc_names,
+ "referred_schema": ref_schema,
+ "referred_table": ref_name,
+ "referred_columns": ref_names,
+ "options": con_kw,
+ }
+ fkeys.append(fkey_d)
+
+ if self._needs_correct_for_88718_96365:
+ self._correct_for_mysql_bugs_88718_96365(fkeys, connection)
+
+ return fkeys
+
+ def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection):
+ # Foreign key is always in lower case (MySQL 8.0)
+ # https://bugs.mysql.com/bug.php?id=88718
+ # issue #4344 for SQLAlchemy
+
+ # table name also for MySQL 8.0
+ # https://bugs.mysql.com/bug.php?id=96365
+ # issue #4751 for SQLAlchemy
+
+ # for lower_case_table_names=2, information_schema.columns
+ # preserves the original table/schema casing, but SHOW CREATE
+        # TABLE does not. this problem does not occur with
+        # lower_case_table_names=1, but case-insensitive matching is used
+        # for both of these modes in any case.
+
+ if self._casing in (1, 2):
+
+ def lower(s):
+ return s.lower()
+
+ else:
+ # if on case sensitive, there can be two tables referenced
+ # with the same name different casing, so we need to use
+ # case-sensitive matching.
+ def lower(s):
+ return s
+
+ default_schema_name = connection.dialect.default_schema_name
+ col_tuples = [
+ (
+ lower(rec["referred_schema"] or default_schema_name),
+ lower(rec["referred_table"]),
+ col_name,
+ )
+ for rec in fkeys
+ for col_name in rec["referred_columns"]
+ ]
+
+ if col_tuples:
+
+ correct_for_wrong_fk_case = connection.execute(
+ sql.text(
+ """
+ select table_schema, table_name, column_name
+ from information_schema.columns
+ where (table_schema, table_name, lower(column_name)) in
+ :table_data;
+ """
+ ).bindparams(sql.bindparam("table_data", expanding=True)),
+ dict(table_data=col_tuples),
+ )
+
+ # in casing=0, table name and schema name come back in their
+ # exact case.
+ # in casing=1, table name and schema name come back in lower
+ # case.
+ # in casing=2, table name and schema name come back from the
+ # information_schema.columns view in the case
+ # that was used in CREATE DATABASE and CREATE TABLE, but
+ # SHOW CREATE TABLE converts them to *lower case*, therefore
+ # not matching. So for this case, case-insensitive lookup
+ # is necessary
+ d = defaultdict(dict)
+ for schema, tname, cname in correct_for_wrong_fk_case:
+ d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema
+ d[(lower(schema), lower(tname))]["TABLENAME"] = tname
+ d[(lower(schema), lower(tname))][cname.lower()] = cname
+
+ for fkey in fkeys:
+ rec = d[
+ (
+ lower(fkey["referred_schema"] or default_schema_name),
+ lower(fkey["referred_table"]),
+ )
+ ]
+
+ fkey["referred_table"] = rec["TABLENAME"]
+ if fkey["referred_schema"] is not None:
+ fkey["referred_schema"] = rec["SCHEMANAME"]
+
+ fkey["referred_columns"] = [
+ rec[col.lower()] for col in fkey["referred_columns"]
+ ]
+
+ @reflection.cache
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+
+ return [
+ {"name": spec["name"], "sqltext": spec["sqltext"]}
+ for spec in parsed_state.ck_constraints
+ ]
+
+ @reflection.cache
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+ return {
+ "text": parsed_state.table_options.get(
+ "%s_comment" % self.name, None
+ )
+ }
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+
+ indexes = []
+
+ for spec in parsed_state.keys:
+ dialect_options = {}
+ unique = False
+ flavor = spec["type"]
+ if flavor == "PRIMARY":
+ continue
+ if flavor == "UNIQUE":
+ unique = True
+ elif flavor in ("FULLTEXT", "SPATIAL"):
+ dialect_options["%s_prefix" % self.name] = flavor
+ elif flavor is None:
+ pass
+ else:
+ self.logger.info(
+ "Converting unknown KEY type %s to a plain KEY", flavor
+ )
+ pass
+
+ if spec["parser"]:
+ dialect_options["%s_with_parser" % (self.name)] = spec[
+ "parser"
+ ]
+
+ index_d = {}
+ if dialect_options:
+ index_d["dialect_options"] = dialect_options
+
+ index_d["name"] = spec["name"]
+ index_d["column_names"] = [s[0] for s in spec["columns"]]
+ index_d["unique"] = unique
+ if flavor:
+ index_d["type"] = flavor
+ indexes.append(index_d)
+ return indexes
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ parsed_state = self._parsed_state_or_create(
+ connection, table_name, schema, **kw
+ )
+
+ return [
+ {
+ "name": key["name"],
+ "column_names": [col[0] for col in key["columns"]],
+ "duplicates_index": key["name"],
+ }
+ for key in parsed_state.keys
+ if key["type"] == "UNIQUE"
+ ]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+
+ charset = self._connection_charset
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(schema, view_name)
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
+ return sql
+
+ def _parsed_state_or_create(
+ self, connection, table_name, schema=None, **kw
+ ):
+ return self._setup_parser(
+ connection,
+ table_name,
+ schema,
+ info_cache=kw.get("info_cache", None),
+ )
+
+ @util.memoized_property
+ def _tabledef_parser(self):
+ """return the MySQLTableDefinitionParser, generate if needed.
+
+ The deferred creation ensures that the dialect has
+ retrieved server version information first.
+
+ """
+ preparer = self.identifier_preparer
+ return _reflection.MySQLTableDefinitionParser(self, preparer)
+
+ @reflection.cache
+ def _setup_parser(self, connection, table_name, schema=None, **kw):
+ charset = self._connection_charset
+ parser = self._tabledef_parser
+ full_name = ".".join(
+ self.identifier_preparer._quote_free_identifiers(
+ schema, table_name
+ )
+ )
+ sql = self._show_create_table(
+ connection, None, charset, full_name=full_name
+ )
+ if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql):
+ # Adapt views to something table-like.
+ columns = self._describe_table(
+ connection, None, charset, full_name=full_name
+ )
+ sql = parser._describe_to_create(table_name, columns)
+ return parser.parse(sql, charset)
+
+ def _fetch_setting(self, connection, setting_name):
+ charset = self._connection_charset
+
+ if self.server_version_info and self.server_version_info < (5, 6):
+ sql = "SHOW VARIABLES LIKE '%s'" % setting_name
+ fetch_col = 1
+ else:
+ sql = "SELECT @@%s" % setting_name
+ fetch_col = 0
+
+ show_var = connection.exec_driver_sql(sql)
+ row = self._compat_first(show_var, charset=charset)
+ if not row:
+ return None
+ else:
+ return row[fetch_col]
+
+ def _detect_charset(self, connection):
+ raise NotImplementedError()
+
+ def _detect_casing(self, connection):
+ """Sniff out identifier case sensitivity.
+
+ Cached per-connection. This value can not change without a server
+ restart.
+
+ """
+ # https://dev.mysql.com/doc/refman/en/identifier-case-sensitivity.html
+
+ setting = self._fetch_setting(connection, "lower_case_table_names")
+ if setting is None:
+ cs = 0
+ else:
+ # 4.0.15 returns OFF or ON according to [ticket:489]
+ # 3.23 doesn't, 4.0.27 doesn't..
+ if setting == "OFF":
+ cs = 0
+ elif setting == "ON":
+ cs = 1
+ else:
+ cs = int(setting)
+ self._casing = cs
+ return cs
+
+ def _detect_collations(self, connection):
+ """Pull the active COLLATIONS list from the server.
+
+ Cached per-connection.
+ """
+
+ collations = {}
+ charset = self._connection_charset
+ rs = connection.exec_driver_sql("SHOW COLLATION")
+ for row in self._compat_fetchall(rs, charset):
+ collations[row[0]] = row[1]
+ return collations
+
+ def _detect_sql_mode(self, connection):
+ setting = self._fetch_setting(connection, "sql_mode")
+
+ if setting is None:
+ util.warn(
+ "Could not retrieve SQL_MODE; please ensure the "
+ "MySQL user has permissions to SHOW VARIABLES"
+ )
+ self._sql_mode = ""
+ else:
+ self._sql_mode = setting or ""
+
+ def _detect_ansiquotes(self, connection):
+ """Detect and adjust for the ANSI_QUOTES sql mode."""
+
+ mode = self._sql_mode
+ if not mode:
+ mode = ""
+ elif mode.isdigit():
+ mode_no = int(mode)
+ mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or ""
+
+ self._server_ansiquotes = "ANSI_QUOTES" in mode
+
+ # as of MySQL 5.0.1
+ self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode
+
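+    # For example, a sql_mode of "ANSI_QUOTES,NO_BACKSLASH_ESCAPES" results
+    # in _server_ansiquotes = True (identifiers quoted with ") and
+    # _backslash_escapes = False.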
+ def _show_create_table(
+ self, connection, table, charset=None, full_name=None
+ ):
+ """Run SHOW CREATE TABLE for a ``Table``."""
+
+ if full_name is None:
+ full_name = self.identifier_preparer.format_table(table)
+ st = "SHOW CREATE TABLE %s" % full_name
+
+ rp = None
+ try:
+ rp = connection.execution_options(
+ skip_user_error_events=True
+ ).exec_driver_sql(st)
+ except exc.DBAPIError as e:
+ if self._extract_error_code(e.orig) == 1146:
+ util.raise_(exc.NoSuchTableError(full_name), replace_context=e)
+ else:
+ raise
+ row = self._compat_first(rp, charset=charset)
+ if not row:
+ raise exc.NoSuchTableError(full_name)
+ return row[1].strip()
+
+ def _describe_table(self, connection, table, charset=None, full_name=None):
+ """Run DESCRIBE for a ``Table`` and return processed rows."""
+
+ if full_name is None:
+ full_name = self.identifier_preparer.format_table(table)
+ st = "DESCRIBE %s" % full_name
+
+ rp, rows = None, None
+ try:
+ try:
+ rp = connection.execution_options(
+ skip_user_error_events=True
+ ).exec_driver_sql(st)
+ except exc.DBAPIError as e:
+ code = self._extract_error_code(e.orig)
+ if code == 1146:
+ util.raise_(
+ exc.NoSuchTableError(full_name), replace_context=e
+ )
+ elif code == 1356:
+ util.raise_(
+ exc.UnreflectableTableError(
+ "Table or view named %s could not be "
+ "reflected: %s" % (full_name, e)
+ ),
+ replace_context=e,
+ )
+ else:
+ raise
+ rows = self._compat_fetchall(rp, charset=charset)
+ finally:
+ if rp:
+ rp.close()
+ return rows
+
+
+class _DecodingRow(object):
+ """Return unicode-decoded values based on type inspection.
+
+ Smooth over data type issues (esp. with alpha driver versions) and
+ normalize strings as Unicode regardless of user-configured driver
+ encoding settings.
+
+ """
+
+ # Some MySQL-python versions can return some columns as
+ # sets.Set(['value']) (seriously) but thankfully that doesn't
+ # seem to come up in DDL queries.
+
+ _encoding_compat = {
+ "koi8r": "koi8_r",
+ "koi8u": "koi8_u",
+ "utf16": "utf-16-be", # MySQL's uft16 is always bigendian
+ "utf8mb4": "utf8", # real utf8
+ "utf8mb3": "utf8", # real utf8; saw this happen on CI but I cannot
+ # reproduce, possibly mariadb10.6 related
+ "eucjpms": "ujis",
+ }
+
+ def __init__(self, rowproxy, charset):
+ self.rowproxy = rowproxy
+ self.charset = self._encoding_compat.get(charset, charset)
+
+ def __getitem__(self, index):
+ item = self.rowproxy[index]
+ if isinstance(item, _array):
+ item = item.tostring()
+
+ if self.charset and isinstance(item, util.binary_type):
+ return item.decode(self.charset)
+ else:
+ return item
+
+ def __getattr__(self, attr):
+ item = getattr(self.rowproxy, attr)
+ if isinstance(item, _array):
+ item = item.tostring()
+ if self.charset and isinstance(item, util.binary_type):
+ return item.decode(self.charset)
+ else:
+ return item
diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py
new file mode 100644
index 0000000..a67a194
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/cymysql.py
@@ -0,0 +1,82 @@
+# mysql/cymysql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+
+.. dialect:: mysql+cymysql
+ :name: CyMySQL
+ :dbapi: cymysql
+ :connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
+ :url: https://github.com/nakagami/CyMySQL
+
+.. note::
+
+ The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous
+ integration** and may have unresolved issues. The recommended MySQL
+ dialects are mysqlclient and PyMySQL.
+
+""" # noqa
+
+from .base import BIT
+from .base import MySQLDialect
+from .mysqldb import MySQLDialect_mysqldb
+from ... import util
+
+
+class _cymysqlBIT(BIT):
+ def result_processor(self, dialect, coltype):
+ """Convert MySQL's 64 bit, variable length binary string to a long."""
+
+ def process(value):
+ if value is not None:
+ v = 0
+ for i in util.iterbytes(value):
+ v = v << 8 | i
+ return v
+ return value
+
+ return process
+
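+# For example, the processor above converts the raw BIT value b"\x02\x01"
+# (received as a big-endian byte string) into the integer 513.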
+
+class MySQLDialect_cymysql(MySQLDialect_mysqldb):
+ driver = "cymysql"
+ supports_statement_cache = True
+
+ description_encoding = None
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+ supports_unicode_statements = True
+
+ colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("cymysql")
+
+ def _detect_charset(self, connection):
+ return connection.connection.charset
+
+ def _extract_error_code(self, exception):
+ return exception.errno
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.OperationalError):
+ return self._extract_error_code(e) in (
+ 2006,
+ 2013,
+ 2014,
+ 2045,
+ 2055,
+ )
+ elif isinstance(e, self.dbapi.InterfaceError):
+ # if underlying connection is closed,
+ # this is the error you get
+ return True
+ else:
+ return False
+
+
+dialect = MySQLDialect_cymysql
diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py
new file mode 100644
index 0000000..0c8791a
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/dml.py
@@ -0,0 +1,175 @@
+from ... import exc
+from ... import util
+from ...sql.base import _exclusive_against
+from ...sql.base import _generative
+from ...sql.base import ColumnCollection
+from ...sql.dml import Insert as StandardInsert
+from ...sql.elements import ClauseElement
+from ...sql.expression import alias
+from ...util.langhelpers import public_factory
+
+
+__all__ = ("Insert", "insert")
+
+
+class Insert(StandardInsert):
+ """MySQL-specific implementation of INSERT.
+
+ Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE.
+
+ The :class:`~.mysql.Insert` object is created using the
+ :func:`sqlalchemy.dialects.mysql.insert` function.
+
+ .. versionadded:: 1.2
+
+ """
+
+ stringify_dialect = "mysql"
+ inherit_cache = False
+
+ @property
+ def inserted(self):
+ """Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE
+ statement
+
+ MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row
+ that would be inserted, via a special function called ``VALUES()``.
+ This attribute provides all columns in this row to be referenceable
+ such that they will render within a ``VALUES()`` function inside the
+ ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted``
+ so as not to conflict with the existing
+ :meth:`_expression.Insert.values` method.
+
+ .. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance
+ of :class:`_expression.ColumnCollection`, which provides an
+ interface the same as that of the :attr:`_schema.Table.c`
+ collection described at :ref:`metadata_tables_and_columns`.
+ With this collection, ordinary names are accessible like attributes
+ (e.g. ``stmt.inserted.some_column``), but special names and
+ dictionary method names should be accessed using indexed access,
+ such as ``stmt.inserted["column name"]`` or
+ ``stmt.inserted["values"]``. See the docstring for
+ :class:`_expression.ColumnCollection` for further examples.
+
+ .. seealso::
+
+ :ref:`mysql_insert_on_duplicate_key_update` - example of how
+ to use :attr:`_expression.Insert.inserted`
+
+ """
+ return self.inserted_alias.columns
+
+ @util.memoized_property
+ def inserted_alias(self):
+ return alias(self.table, name="inserted")
+
+ @_generative
+ @_exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already "
+ "has an ON DUPLICATE KEY clause present"
+ },
+ )
+ def on_duplicate_key_update(self, *args, **kw):
+ r"""
+ Specifies the ON DUPLICATE KEY UPDATE clause.
+
+ :param \**kw: Column keys linked to UPDATE values. The
+ values may be any SQL expression or supported literal Python
+ values.
+
+ .. warning:: This dictionary does **not** take into account
+ Python-specified default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON DUPLICATE KEY UPDATE
+ style of UPDATE, unless values are manually specified here.
+
+ :param \*args: As an alternative to passing key/value parameters,
+ a dictionary or list of 2-tuples can be passed as a single positional
+ argument.
+
+ Passing a single dictionary is equivalent to the keyword argument
+ form::
+
+ insert().on_duplicate_key_update({"name": "some name"})
+
+ Passing a list of 2-tuples indicates that the parameter assignments
+ in the UPDATE clause should be ordered as sent, in a manner similar
+ to that described for the :class:`_expression.Update`
+ construct overall
+ in :ref:`tutorial_parameter_ordered_updates`::
+
+ insert().on_duplicate_key_update(
+ [("name", "some name"), ("value", "some value")])
+
+ .. versionchanged:: 1.3 parameters can be specified as a dictionary
+ or list of 2-tuples; the latter form provides for parameter
+ ordering.
+
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`mysql_insert_on_duplicate_key_update`
+
+ """
+ if args and kw:
+ raise exc.ArgumentError(
+ "Can't pass kwargs and positional arguments simultaneously"
+ )
+
+ if args:
+ if len(args) > 1:
+ raise exc.ArgumentError(
+ "Only a single dictionary or list of tuples "
+ "is accepted positionally."
+ )
+ values = args[0]
+ else:
+ values = kw
+
+ inserted_alias = getattr(self, "inserted_alias", None)
+ self._post_values_clause = OnDuplicateClause(inserted_alias, values)
+
+
+insert = public_factory(
+ Insert, ".dialects.mysql.insert", ".dialects.mysql.Insert"
+)
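+# Illustrative usage (hypothetical "users" table):
+#     stmt = insert(users).values(id=1, name="x")
+#     stmt = stmt.on_duplicate_key_update(name=stmt.inserted.name)
+# compiles on MySQL roughly to:
+#     INSERT INTO users (id, name) VALUES (%s, %s)
+#     ON DUPLICATE KEY UPDATE name = VALUES(name)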
+
+
+class OnDuplicateClause(ClauseElement):
+ __visit_name__ = "on_duplicate_key_update"
+
+ _parameter_ordering = None
+
+ stringify_dialect = "mysql"
+
+ def __init__(self, inserted_alias, update):
+ self.inserted_alias = inserted_alias
+
+        # auto-detect that parameters should be ordered. This is copied from
+        # Update._process_colparams(); however, we don't look for a special
+        # flag in this case since we are not disambiguating from other use
+        # cases as we are in Update.values().
+ if isinstance(update, list) and (
+ update and isinstance(update[0], tuple)
+ ):
+ self._parameter_ordering = [key for key, value in update]
+ update = dict(update)
+
+ if isinstance(update, dict):
+ if not update:
+ raise ValueError(
+ "update parameter dictionary must not be empty"
+ )
+ elif isinstance(update, ColumnCollection):
+ update = dict(update)
+ else:
+ raise ValueError(
+ "update parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
+ self.update = update
diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py
new file mode 100644
index 0000000..6c9ef28
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/enumerated.py
@@ -0,0 +1,263 @@
+# mysql/enumerated.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from .types import _StringType
+from ... import exc
+from ... import sql
+from ... import util
+from ...sql import sqltypes
+from ...sql.base import NO_ARG
+
+
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
+ """MySQL ENUM type."""
+
+ __visit_name__ = "ENUM"
+
+ native_enum = True
+
+ def __init__(self, *enums, **kw):
+ """Construct an ENUM.
+
+ E.g.::
+
+ Column('myenum', ENUM("foo", "bar", "baz"))
+
+ :param enums: The range of valid values for this ENUM. Values in
+ enums are not quoted, they will be escaped and surrounded by single
+ quotes when generating the schema. This object may also be a
+ PEP-435-compliant enumerated type.
+
+        .. versionadded:: 1.1 added support for PEP-435-compliant enumerated
+ types.
+
+ :param strict: This flag has no effect.
+
+ .. versionchanged:: The MySQL ENUM type as well as the base Enum
+ type now validates all Python data values.
+
+ :param charset: Optional, a column-level character set for this string
+          value. Takes precedence over the 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+          value. Takes precedence over the 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ :param quoting: Not used. A warning will be raised if provided.
+
+ """
+ if kw.pop("quoting", NO_ARG) is not NO_ARG:
+ util.warn_deprecated_20(
+ "The 'quoting' parameter to :class:`.mysql.ENUM` is deprecated"
+ " and will be removed in a future release. "
+ "This parameter now has no effect."
+ )
+ kw.pop("strict", None)
+ self._enum_init(enums, kw)
+ _StringType.__init__(self, length=self.length, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a MySQL native :class:`.mysql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ kw.setdefault("values_callable", impl.values_callable)
+ kw.setdefault("omit_aliases", impl._omit_aliases)
+ return cls(**kw)
+
+ def _object_value_for_elem(self, elem):
+ # mysql sends back a blank string for any value that
+ # was persisted that was not in the enums; that is, it does no
+ # validation on the incoming data, it "truncates" it to be
+ # the blank string. Return it straight.
+ if elem == "":
+ return elem
+ else:
+ return super(ENUM, self)._object_value_for_elem(elem)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
+ )
+
+
+class SET(_StringType):
+ """MySQL SET type."""
+
+ __visit_name__ = "SET"
+
+ def __init__(self, *values, **kw):
+ """Construct a SET.
+
+ E.g.::
+
+ Column('myset', SET("foo", "bar", "baz"))
+
+
+ The list of potential values is required in the case that this
+ set will be used to generate DDL for a table, or if the
+ :paramref:`.SET.retrieve_as_bitwise` flag is set to True.
+
+ :param values: The range of valid values for this SET. The values
+ are not quoted, they will be escaped and surrounded by single
+ quotes when generating the schema.
+
+ :param convert_unicode: Same flag as that of
+ :paramref:`.String.convert_unicode`.
+
+ :param collation: same as that of :paramref:`.String.collation`
+
+ :param charset: same as that of :paramref:`.VARCHAR.charset`.
+
+ :param ascii: same as that of :paramref:`.VARCHAR.ascii`.
+
+ :param unicode: same as that of :paramref:`.VARCHAR.unicode`.
+
+ :param binary: same as that of :paramref:`.VARCHAR.binary`.
+
+ :param retrieve_as_bitwise: if True, the data for the set type will be
+ persisted and selected using an integer value, where a set is coerced
+ into a bitwise mask for persistence. MySQL allows this mode which
+ has the advantage of being able to store values unambiguously,
+ such as the blank string ``''``. The datatype will appear
+ as the expression ``col + 0`` in a SELECT statement, so that the
+ value is coerced into an integer value in result sets.
+ This flag is required if one wishes
+ to persist a set that can store the blank string ``''`` as a value.
+
+ .. warning::
+
+ When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
+ essential that the list of set values is expressed in the
+ **exact same order** as exists on the MySQL database.
+
+ .. versionadded:: 1.0.0
+
+ :param quoting: Not used. A warning will be raised if passed.
+
+ """
+ if kw.pop("quoting", NO_ARG) is not NO_ARG:
+ util.warn_deprecated_20(
+ "The 'quoting' parameter to :class:`.mysql.SET` is deprecated"
+ " and will be removed in a future release. "
+ "This parameter now has no effect."
+ )
+ self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False)
+ self.values = tuple(values)
+ if not self.retrieve_as_bitwise and "" in values:
+ raise exc.ArgumentError(
+ "Can't use the blank value '' in a SET without "
+ "setting retrieve_as_bitwise=True"
+ )
+ if self.retrieve_as_bitwise:
+ self._bitmap = dict(
+ (value, 2 ** idx) for idx, value in enumerate(self.values)
+ )
+ self._bitmap.update(
+ (2 ** idx, value) for idx, value in enumerate(self.values)
+ )
+ length = max([len(v) for v in values] + [0])
+ kw.setdefault("length", length)
+ super(SET, self).__init__(**kw)
+
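+    # For illustration, SET("a", "b", "c", retrieve_as_bitwise=True) builds a
+    # bitmap of {"a": 1, "b": 2, "c": 4, 1: "a", 2: "b", 4: "c"}, so a stored
+    # integer value of 5 round-trips to the Python set {"a", "c"}.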
+ def column_expression(self, colexpr):
+ if self.retrieve_as_bitwise:
+ return sql.type_coerce(
+ sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
+ )
+ else:
+ return colexpr
+
+ def result_processor(self, dialect, coltype):
+ if self.retrieve_as_bitwise:
+
+ def process(value):
+ if value is not None:
+ value = int(value)
+
+ return set(util.map_bits(self._bitmap.__getitem__, value))
+ else:
+ return None
+
+ else:
+ super_convert = super(SET, self).result_processor(dialect, coltype)
+
+ def process(value):
+ if isinstance(value, util.string_types):
+ # MySQLdb returns a string, let's parse
+ if super_convert:
+ value = super_convert(value)
+ return set(re.findall(r"[^,]+", value))
+ else:
+ # mysql-connector-python does a naive
+ # split(",") which throws in an empty string
+ if value is not None:
+ value.discard("")
+ return value
+
+ return process
+
+ def bind_processor(self, dialect):
+ super_convert = super(SET, self).bind_processor(dialect)
+ if self.retrieve_as_bitwise:
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, util.int_types + util.string_types):
+ if super_convert:
+ return super_convert(value)
+ else:
+ return value
+ else:
+ int_value = 0
+ for v in value:
+ int_value |= self._bitmap[v]
+ return int_value
+
+ else:
+
+ def process(value):
+ # accept strings and int (actually bitflag) values directly
+ if value is not None and not isinstance(
+ value, util.int_types + util.string_types
+ ):
+ value = ",".join(value)
+
+ if super_convert:
+ return super_convert(value)
+ else:
+ return value
+
+ return process
+
+ def adapt(self, impltype, **kw):
+ kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
+ return util.constructor_copy(self, impltype, *self.values, **kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self,
+ to_inspect=[SET, _StringType],
+ additional_kw=[
+ ("retrieve_as_bitwise", False),
+ ],
+ )
diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py
new file mode 100644
index 0000000..7a66e9b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/expression.py
@@ -0,0 +1,130 @@
+from ... import exc
+from ... import util
+from ...sql import coercions
+from ...sql import elements
+from ...sql import operators
+from ...sql import roles
+from ...sql.base import _generative
+from ...sql.base import Generative
+
+
+class match(Generative, elements.BinaryExpression):
+ """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
+
+ E.g.::
+
+ from sqlalchemy import desc
+ from sqlalchemy.dialects.mysql import match
+
+ match_expr = match(
+ users_table.c.firstname,
+ users_table.c.lastname,
+ against="Firstname Lastname",
+ )
+
+ stmt = (
+ select(users_table)
+ .where(match_expr.in_boolean_mode())
+ .order_by(desc(match_expr))
+ )
+
+ Would produce SQL resembling::
+
+ SELECT id, firstname, lastname
+ FROM user
+ WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE)
+ ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC
+
+ The :func:`_mysql.match` function is a standalone version of the
+ :meth:`_sql.ColumnElement.match` method available on all
+ SQL expressions, as when :meth:`_expression.ColumnElement.match` is
+    used, but allows passing multiple columns.
+
+ :param cols: column expressions to match against
+
+    :param against: the expression to be compared against
+
+ :param in_boolean_mode: boolean, set "boolean mode" to true
+
+    :param in_natural_language_mode: boolean, set "natural language" to true
+
+ :param with_query_expansion: boolean, set "query expansion" to true
+
+ .. versionadded:: 1.4.19
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.match`
+
+ """
+
+ __visit_name__ = "mysql_match"
+
+ inherit_cache = True
+
+ def __init__(self, *cols, **kw):
+ if not cols:
+ raise exc.ArgumentError("columns are required")
+
+ against = kw.pop("against", None)
+
+ if against is None:
+ raise exc.ArgumentError("against is required")
+ against = coercions.expect(
+ roles.ExpressionElementRole,
+ against,
+ )
+
+ left = elements.BooleanClauseList._construct_raw(
+ operators.comma_op,
+ clauses=cols,
+ )
+ left.group = False
+
+ flags = util.immutabledict(
+ {
+ "mysql_boolean_mode": kw.pop("in_boolean_mode", False),
+ "mysql_natural_language": kw.pop(
+ "in_natural_language_mode", False
+ ),
+ "mysql_query_expansion": kw.pop("with_query_expansion", False),
+ }
+ )
+
+ if kw:
+ raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw)))
+
+ super(match, self).__init__(
+ left, against, operators.match_op, modifiers=flags
+ )
+
+ @_generative
+ def in_boolean_mode(self):
+ """Apply the "IN BOOLEAN MODE" modifier to the MATCH expression.
+
+ :return: a new :class:`_mysql.match` instance with modifications
+ applied.
+ """
+
+ self.modifiers = self.modifiers.union({"mysql_boolean_mode": True})
+
+ @_generative
+ def in_natural_language_mode(self):
+ """Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH
+ expression.
+
+ :return: a new :class:`_mysql.match` instance with modifications
+ applied.
+ """
+
+ self.modifiers = self.modifiers.union({"mysql_natural_language": True})
+
+ @_generative
+ def with_query_expansion(self):
+ """Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression.
+
+ :return: a new :class:`_mysql.match` instance with modifications
+ applied.
+ """
+
+ self.modifiers = self.modifiers.union({"mysql_query_expansion": True})
diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py
new file mode 100644
index 0000000..857fcce
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/json.py
@@ -0,0 +1,84 @@
+# mysql/json.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+from ... import types as sqltypes
+
+
+class JSON(sqltypes.JSON):
+ """MySQL JSON type.
+
+ MySQL supports JSON as of version 5.7.
+ MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2.
+
+ :class:`_mysql.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a MySQL or MariaDB backend.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The :class:`.mysql.JSON` type supports persistence of JSON values
+ as well as the core index operations provided by :class:`_types.JSON`
+ datatype, by adapting the operations to render the ``JSON_EXTRACT``
+ function at the database level.
+
+ .. versionadded:: 1.1
+
+ """
+
+ pass
+
+
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+ def _format_value(self, value):
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
+ return "$%s" % (
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
+ )
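+    # For example, a JSON path of ["data", 0, "name"] is formatted by
+    # JSONPathType above as the MySQL path string '$."data"[0]."name"'.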
diff --git a/lib/sqlalchemy/dialects/mysql/mariadb.py b/lib/sqlalchemy/dialects/mysql/mariadb.py
new file mode 100644
index 0000000..568c3f0
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mariadb.py
@@ -0,0 +1,25 @@
+from .base import MariaDBIdentifierPreparer
+from .base import MySQLDialect
+
+
+class MariaDBDialect(MySQLDialect):
+ is_mariadb = True
+ supports_statement_cache = True
+ name = "mariadb"
+ preparer = MariaDBIdentifierPreparer
+
+
+def loader(driver):
+ driver_mod = __import__(
+ "sqlalchemy.dialects.mysql.%s" % driver
+ ).dialects.mysql
+ driver_cls = getattr(driver_mod, driver).dialect
+
+ return type(
+ "MariaDBDialect_%s" % driver,
+ (
+ MariaDBDialect,
+ driver_cls,
+ ),
+ {"supports_statement_cache": True},
+ )
diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py
new file mode 100644
index 0000000..c8b2ead
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py
@@ -0,0 +1,240 @@
+# mysql/mariadbconnector.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: mysql+mariadbconnector
+ :name: MariaDB Connector/Python
+ :dbapi: mariadb
+ :connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://pypi.org/project/mariadb/
+
+Driver Status
+-------------
+
+MariaDB Connector/Python enables Python programs to access MariaDB and MySQL
+databases using an API which is compliant with the Python DB API 2.0 (PEP-249).
+It is written in C and uses the MariaDB Connector/C client library for
+client/server communication.
+
+Note that the default driver for a ``mariadb://`` connection URI continues to
+be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
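+
+For example, an engine might be created as follows (host and credentials
+are illustrative)::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "mariadb+mariadbconnector://scott:tiger@192.168.0.134/test"
+ )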
+
+.. _mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
+
+""" # noqa
+import re
+
+from .base import MySQLCompiler
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from ... import sql
+from ... import util
+
+mariadb_cpy_minimum_version = (1, 0, 1)
+
+
+class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
+ _lastrowid = None
+
+ def create_server_side_cursor(self):
+ return self._dbapi_connection.cursor(buffered=False)
+
+ def create_default_cursor(self):
+ return self._dbapi_connection.cursor(buffered=True)
+
+ def post_exec(self):
+ if self.isinsert and self.compiled.postfetch_lastrowid:
+ self._lastrowid = self.cursor.lastrowid
+
+ def get_lastrowid(self):
+ return self._lastrowid
+
+
+class MySQLCompiler_mariadbconnector(MySQLCompiler):
+ pass
+
+
+class MySQLDialect_mariadbconnector(MySQLDialect):
+ driver = "mariadbconnector"
+ supports_statement_cache = True
+
+ # set this to True at the module level to prevent the driver from running
+ # against a backend that the server reports as MySQL. currently this
+ # appears to be unnecessary, as MariaDB client libraries have always
+ # worked against MySQL databases. However, if this changes at some point,
+ # this can be adjusted; PLEASE ADD A TEST in
+ # test/dialect/mysql/test_dialect.py if that change is made, to ensure the
+ # correct exception is raised at the correct point when running the driver
+ # against a MySQL backend.
+ # is_mariadb = True
+
+ supports_unicode_statements = True
+ encoding = "utf8mb4"
+ convert_unicode = True
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+ supports_native_decimal = True
+ default_paramstyle = "qmark"
+ execution_ctx_cls = MySQLExecutionContext_mariadbconnector
+ statement_compiler = MySQLCompiler_mariadbconnector
+
+ supports_server_side_cursors = True
+
+ @util.memoized_property
+ def _dbapi_version(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ return tuple(
+ [
+ int(x)
+ for x in re.findall(
+ r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+ )
+ ]
+ )
+ else:
+ return (99, 99, 99)
+
+ def __init__(self, **kwargs):
+ super(MySQLDialect_mariadbconnector, self).__init__(**kwargs)
+ self.paramstyle = "qmark"
+ if self.dbapi is not None:
+ if self._dbapi_version < mariadb_cpy_minimum_version:
+ raise NotImplementedError(
+ "The minimum required version for MariaDB "
+ "Connector/Python is %s"
+ % ".".join(str(x) for x in mariadb_cpy_minimum_version)
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("mariadb")
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_mariadbconnector, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ elif isinstance(e, self.dbapi.Error):
+ str_e = str(e).lower()
+ return "not connected" in str_e or "isn't valid" in str_e
+ else:
+ return False
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args()
+
+ int_params = [
+ "connect_timeout",
+ "read_timeout",
+ "write_timeout",
+ "client_flag",
+ "port",
+ "pool_size",
+ ]
+ bool_params = [
+ "local_infile",
+ "ssl_verify_cert",
+ "ssl",
+ "pool_reset_connection",
+ ]
+
+ for key in int_params:
+ util.coerce_kw_type(opts, key, int)
+ for key in bool_params:
+ util.coerce_kw_type(opts, key, bool)
+
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+ # supports_sane_rowcount.
+ client_flag = opts.get("client_flag", 0)
+ if self.dbapi is not None:
+ try:
+ CLIENT_FLAGS = __import__(
+ self.dbapi.__name__ + ".constants.CLIENT"
+ ).constants.CLIENT
+ client_flag |= CLIENT_FLAGS.FOUND_ROWS
+ except (AttributeError, ImportError):
+ self.supports_sane_rowcount = False
+ opts["client_flag"] = client_flag
+ return [[], opts]
+
+ def _extract_error_code(self, exception):
+ try:
+ rc = exception.errno
+ except AttributeError:
+ rc = -1
+ return rc
+
+ def _detect_charset(self, connection):
+ return "utf8mb4"
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
+
+ def _set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+ super(MySQLDialect_mariadbconnector, self)._set_isolation_level(
+ connection, level
+ )
+
+ def do_begin_twophase(self, connection, xid):
+ connection.execute(
+ sql.text("XA BEGIN :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.execute(
+ sql.text("XA END :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+ connection.execute(
+ sql.text("XA PREPARE :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ connection.execute(
+ sql.text("XA END :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+ connection.execute(
+ sql.text("XA ROLLBACK :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ connection.execute(
+ sql.text("XA COMMIT :xid").bindparams(
+ sql.bindparam("xid", xid, literal_execute=True)
+ )
+ )
+
+
+dialect = MySQLDialect_mariadbconnector
diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
new file mode 100644
index 0000000..356babe
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py
@@ -0,0 +1,240 @@
+# mysql/mysqlconnector.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: mysql+mysqlconnector
+ :name: MySQL Connector/Python
+ :dbapi: myconnpy
+ :connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://pypi.org/project/mysql-connector-python/
+
+.. note::
+
+ The MySQL Connector/Python DBAPI has had many issues since its release,
+ some of which may remain unresolved, and the mysqlconnector dialect is
+ **not tested as part of SQLAlchemy's continuous integration**.
+ The recommended MySQL dialects are mysqlclient and PyMySQL.
+
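+A minimal connection example (host and credentials are illustrative)::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "mysql+mysqlconnector://scott:tiger@192.168.0.134/test"
+ )
+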
+""" # noqa
+
+import re
+
+from .base import BIT
+from .base import MySQLCompiler
+from .base import MySQLDialect
+from .base import MySQLIdentifierPreparer
+from ... import processors
+from ... import util
+
+
+class MySQLCompiler_mysqlconnector(MySQLCompiler):
+ def visit_mod_binary(self, binary, operator, **kw):
+ if self.dialect._mysqlconnector_double_percents:
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+ else:
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
+
+ def post_process_text(self, text):
+ if self.dialect._mysqlconnector_double_percents:
+ return text.replace("%", "%%")
+ else:
+ return text
+
+ def escape_literal_column(self, text):
+ if self.dialect._mysqlconnector_double_percents:
+ return text.replace("%", "%%")
+ else:
+ return text
+
+
+class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
+ @property
+ def _double_percents(self):
+ return self.dialect._mysqlconnector_double_percents
+
+ @_double_percents.setter
+ def _double_percents(self, value):
+ pass
+
+ def _escape_identifier(self, value):
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ if self.dialect._mysqlconnector_double_percents:
+ return value.replace("%", "%%")
+ else:
+ return value
+
+
+class _myconnpyBIT(BIT):
+ def result_processor(self, dialect, coltype):
+ """MySQL-connector already converts mysql bits, so."""
+
+ return None
+
+
+class MySQLDialect_mysqlconnector(MySQLDialect):
+ driver = "mysqlconnector"
+ supports_statement_cache = True
+
+ supports_unicode_binds = True
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+
+ supports_native_decimal = True
+
+ default_paramstyle = "format"
+ statement_compiler = MySQLCompiler_mysqlconnector
+
+ preparer = MySQLIdentifierPreparer_mysqlconnector
+
+ colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
+
+ def __init__(self, *arg, **kw):
+ super(MySQLDialect_mysqlconnector, self).__init__(*arg, **kw)
+
+ # hack description encoding since mysqlconnector randomly
+ # returns bytes or not
+ self._description_decoder = (
+ processors.to_conditional_unicode_processor_factory
+ )(self.description_encoding)
+
+ def _check_unicode_description(self, connection):
+ # hack description encoding since mysqlconnector randomly
+ # returns bytes or not
+ return False
+
+ @property
+ def description_encoding(self):
+ # total guess
+ return "latin-1"
+
+ @util.memoized_property
+ def supports_unicode_statements(self):
+ return util.py3k or self._mysqlconnector_version_info > (2, 0)
+
+ @classmethod
+ def dbapi(cls):
+ from mysql import connector
+
+ return connector
+
+ def do_ping(self, dbapi_connection):
+ try:
+ dbapi_connection.ping(False)
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, None):
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "allow_local_infile", bool)
+ util.coerce_kw_type(opts, "autocommit", bool)
+ util.coerce_kw_type(opts, "buffered", bool)
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "connection_timeout", int)
+ util.coerce_kw_type(opts, "connect_timeout", int)
+ util.coerce_kw_type(opts, "consume_results", bool)
+ util.coerce_kw_type(opts, "force_ipv6", bool)
+ util.coerce_kw_type(opts, "get_warnings", bool)
+ util.coerce_kw_type(opts, "pool_reset_session", bool)
+ util.coerce_kw_type(opts, "pool_size", int)
+ util.coerce_kw_type(opts, "raise_on_warnings", bool)
+ util.coerce_kw_type(opts, "raw", bool)
+ util.coerce_kw_type(opts, "ssl_verify_cert", bool)
+ util.coerce_kw_type(opts, "use_pure", bool)
+ util.coerce_kw_type(opts, "use_unicode", bool)
+
+ # unfortunately, MySQL Connector/Python refuses to release a
+ # cursor without reading it fully, so non-buffered isn't an option
+ opts.setdefault("buffered", True)
+
+ # FOUND_ROWS must be set in ClientFlag to enable
+ # supports_sane_rowcount.
+ if self.dbapi is not None:
+ try:
+ from mysql.connector.constants import ClientFlag
+
+ client_flags = opts.get(
+ "client_flags", ClientFlag.get_default()
+ )
+ client_flags |= ClientFlag.FOUND_ROWS
+ opts["client_flags"] = client_flags
+ except Exception:
+ pass
+ return [[], opts]
+
+ @util.memoized_property
+ def _mysqlconnector_version_info(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+
+ @util.memoized_property
+ def _mysqlconnector_double_percents(self):
+ return not util.py3k and self._mysqlconnector_version_info < (2, 0)
+
+ def _detect_charset(self, connection):
+ return connection.connection.charset
+
+ def _extract_error_code(self, exception):
+ return exception.errno
+
+ def is_disconnect(self, e, connection, cursor):
+ errnos = (2006, 2013, 2014, 2045, 2055, 2048)
+ exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
+ if isinstance(e, exceptions):
+ return (
+ e.errno in errnos
+ or "MySQL Connection not available." in str(e)
+ or "Connection to MySQL is not available" in str(e)
+ )
+ else:
+ return False
+
+ def _compat_fetchall(self, rp, charset=None):
+ return rp.fetchall()
+
+ def _compat_fetchone(self, rp, charset=None):
+ return rp.fetchone()
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
+
+ def _set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+ super(MySQLDialect_mysqlconnector, self)._set_isolation_level(
+ connection, level
+ )
+
+
+dialect = MySQLDialect_mysqlconnector
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
new file mode 100644
index 0000000..7a721e8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py
@@ -0,0 +1,331 @@
+# mysql/mysqldb.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: mysql+mysqldb
+ :name: mysqlclient (maintained fork of MySQL-Python)
+ :dbapi: mysqldb
+ :connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://pypi.org/project/mysqlclient/
+
+Driver Status
+-------------
+
+The mysqlclient DBAPI is a maintained fork of the
+`MySQL-Python <https://sourceforge.net/projects/mysql-python>`_ DBAPI,
+which is itself no longer maintained. `mysqlclient`_ supports Python 2 and
+Python 3 and is very stable.
+
+.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python
+
+.. _mysqldb_unicode:
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+.. _mysqldb_ssl:
+
+SSL Connections
+----------------
+
+The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the
+key "ssl", which may be specified using the
+:paramref:`_sa.create_engine.connect_args` dictionary::
+
+ engine = create_engine(
+ "mysql+mysqldb://scott:tiger@192.168.0.134/test",
+ connect_args={
+ "ssl": {
+ "ssl_ca": "/home/gord/client-ssl/ca.pem",
+ "ssl_cert": "/home/gord/client-ssl/client-cert.pem",
+ "ssl_key": "/home/gord/client-ssl/client-key.pem"
+ }
+ }
+ )
+
+For convenience, the following keys may also be specified inline within the
+URL, where they will be interpreted into the "ssl" dictionary automatically:
+"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher",
+"ssl_check_hostname". An example is as follows::
+
+ connection_uri = (
+ "mysql+mysqldb://scott:tiger@192.168.0.134/test"
+ "?ssl_ca=/home/gord/client-ssl/ca.pem"
+ "&ssl_cert=/home/gord/client-ssl/client-cert.pem"
+ "&ssl_key=/home/gord/client-ssl/client-key.pem"
+ )
+
+If the server uses an automatically-generated certificate that is self-signed
+or does not match the host name (as seen from the client), it may also be
+necessary to indicate ``ssl_check_hostname=false``::
+
+ connection_uri = (
+ "mysql+pymysql://scott:tiger@192.168.0.134/test"
+ "?ssl_ca=/home/gord/client-ssl/ca.pem"
+ "&ssl_cert=/home/gord/client-ssl/client-cert.pem"
+ "&ssl_key=/home/gord/client-ssl/client-key.pem"
+ "&ssl_check_hostname=false"
+ )
+
+
+.. seealso::
+
+ :ref:`pymysql_ssl` in the PyMySQL dialect
+
+
+Using MySQLdb with Google Cloud SQL
+-----------------------------------
+
+Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
+using a URL like the following::
+
+ mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
+
+Server Side Cursors
+-------------------
+
+The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
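+
+As a sketch (assumes an existing ``Engine`` named ``engine``), streaming
+results may be requested per execution::
+
+ from sqlalchemy import text
+
+ with engine.connect() as conn:
+ result = conn.execution_options(stream_results=True).execute(
+ text("SELECT * FROM some_large_table")
+ )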
+
+"""
+
+import re
+
+from .base import MySQLCompiler
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from .base import MySQLIdentifierPreparer
+from .base import TEXT
+from ... import sql
+from ... import util
+
+
+class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
+ @property
+ def rowcount(self):
+ if hasattr(self, "_rowcount"):
+ return self._rowcount
+ else:
+ return self.cursor.rowcount
+
+
+class MySQLCompiler_mysqldb(MySQLCompiler):
+ pass
+
+
+class MySQLDialect_mysqldb(MySQLDialect):
+ driver = "mysqldb"
+ supports_statement_cache = True
+ supports_unicode_statements = True
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+
+ supports_native_decimal = True
+
+ default_paramstyle = "format"
+ execution_ctx_cls = MySQLExecutionContext_mysqldb
+ statement_compiler = MySQLCompiler_mysqldb
+ preparer = MySQLIdentifierPreparer
+
+ def __init__(self, **kwargs):
+ super(MySQLDialect_mysqldb, self).__init__(**kwargs)
+ self._mysql_dbapi_version = (
+ self._parse_dbapi_version(self.dbapi.__version__)
+ if self.dbapi is not None and hasattr(self.dbapi, "__version__")
+ else (0, 0, 0)
+ )
+
+ def _parse_dbapi_version(self, version):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+ else:
+ return (0, 0, 0)
+
+ @util.langhelpers.memoized_property
+ def supports_server_side_cursors(self):
+ try:
+ cursors = __import__("MySQLdb.cursors").cursors
+ self._sscursor = cursors.SSCursor
+ return True
+ except (ImportError, AttributeError):
+ return False
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("MySQLdb")
+
+ def on_connect(self):
+ super_ = super(MySQLDialect_mysqldb, self).on_connect()
+
+ def on_connect(conn):
+ if super_ is not None:
+ super_(conn)
+
+ charset_name = conn.character_set_name()
+
+ if charset_name is not None:
+ cursor = conn.cursor()
+ cursor.execute("SET NAMES %s" % charset_name)
+ cursor.close()
+
+ return on_connect
+
+ def do_ping(self, dbapi_connection):
+ try:
+ dbapi_connection.ping(False)
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, None):
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ rowcount = cursor.executemany(statement, parameters)
+ if context is not None:
+ context._rowcount = rowcount
+
+ def _check_unicode_returns(self, connection):
+ # work around issue fixed in
+ # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
+ # specific issue w/ the utf8mb4_bin collation and unicode returns
+
+ collation = connection.exec_driver_sql(
+ "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
+ % (
+ self.identifier_preparer.quote("Charset"),
+ self.identifier_preparer.quote("Collation"),
+ )
+ ).scalar()
+ has_utf8mb4_bin = self.server_version_info > (5,) and collation
+ if has_utf8mb4_bin:
+ additional_tests = [
+ sql.collate(
+ sql.cast(
+ sql.literal_column("'test collated returns'"),
+ TEXT(charset="utf8mb4"),
+ ),
+ "utf8mb4_bin",
+ )
+ ]
+ else:
+ additional_tests = []
+ return super(MySQLDialect_mysqldb, self)._check_unicode_returns(
+ connection, additional_tests
+ )
+
+ def create_connect_args(self, url, _translate_args=None):
+ if _translate_args is None:
+ _translate_args = dict(
+ database="db", username="user", password="passwd"
+ )
+
+ opts = url.translate_connect_args(**_translate_args)
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "connect_timeout", int)
+ util.coerce_kw_type(opts, "read_timeout", int)
+ util.coerce_kw_type(opts, "write_timeout", int)
+ util.coerce_kw_type(opts, "client_flag", int)
+ util.coerce_kw_type(opts, "local_infile", int)
+ # Note: using either of the below will cause all strings to be
+ # returned as Unicode, both in raw SQL operations and with column
+ # types like String and MSString.
+ util.coerce_kw_type(opts, "use_unicode", bool)
+ util.coerce_kw_type(opts, "charset", str)
+
+ # Rich values 'cursorclass' and 'conv' are not supported via
+ # query string.
+
+ ssl = {}
+ keys = [
+ ("ssl_ca", str),
+ ("ssl_key", str),
+ ("ssl_cert", str),
+ ("ssl_capath", str),
+ ("ssl_cipher", str),
+ ("ssl_check_hostname", bool),
+ ]
+ for key, kw_type in keys:
+ if key in opts:
+ ssl[key[4:]] = opts[key]
+ util.coerce_kw_type(ssl, key[4:], kw_type)
+ del opts[key]
+ if ssl:
+ opts["ssl"] = ssl
+
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+ # supports_sane_rowcount.
+ client_flag = opts.get("client_flag", 0)
+
+ client_flag_found_rows = self._found_rows_client_flag()
+ if client_flag_found_rows is not None:
+ client_flag |= client_flag_found_rows
+ opts["client_flag"] = client_flag
+ return [[], opts]
+
+ def _found_rows_client_flag(self):
+ if self.dbapi is not None:
+ try:
+ CLIENT_FLAGS = __import__(
+ self.dbapi.__name__ + ".constants.CLIENT"
+ ).constants.CLIENT
+ except (AttributeError, ImportError):
+ return None
+ else:
+ return CLIENT_FLAGS.FOUND_ROWS
+ else:
+ return None
+
+ def _extract_error_code(self, exception):
+ return exception.args[0]
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ try:
+ # note: the SQL here would be
+ # "SHOW VARIABLES LIKE 'character_set%%'"
+ cset_name = connection.connection.character_set_name
+ except AttributeError:
+ util.warn(
+ "No 'character_set_name' can be detected with "
+ "this MySQL-Python version; "
+ "please upgrade to a recent version of MySQL-Python. "
+ "Assuming latin1."
+ )
+ return "latin1"
+ else:
+ return cset_name()
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ "AUTOCOMMIT",
+ ]
+ )
+
+ def _set_isolation_level(self, connection, level):
+ if level == "AUTOCOMMIT":
+ connection.autocommit(True)
+ else:
+ connection.autocommit(False)
+ super(MySQLDialect_mysqldb, self)._set_isolation_level(
+ connection, level
+ )
+
+
+dialect = MySQLDialect_mysqldb
diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py
new file mode 100644
index 0000000..f6287dc
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/oursql.py
@@ -0,0 +1,273 @@
+# mysql/oursql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: mysql+oursql
+ :name: OurSQL
+ :dbapi: oursql
+ :connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
+ :url: https://packages.python.org/oursql/
+
+.. note::
+
+ The OurSQL MySQL dialect is legacy and is no longer supported upstream,
+ and is **not tested as part of SQLAlchemy's continuous integration**.
+ The recommended MySQL dialects are mysqlclient and PyMySQL.
+
+.. deprecated:: 1.4 The OurSQL DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to mysql.
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+
+"""
+
+
+from .base import BIT
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from ... import types as sqltypes
+from ... import util
+
+
+class _oursqlBIT(BIT):
+ def result_processor(self, dialect, coltype):
+ """oursql already converts mysql bits, so."""
+
+ return None
+
+
+class MySQLExecutionContext_oursql(MySQLExecutionContext):
+ @property
+ def plain_query(self):
+ return self.execution_options.get("_oursql_plain_query", False)
+
+
+class MySQLDialect_oursql(MySQLDialect):
+ driver = "oursql"
+ supports_statement_cache = True
+
+ if util.py2k:
+ supports_unicode_binds = True
+ supports_unicode_statements = True
+
+ supports_native_decimal = True
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+ execution_ctx_cls = MySQLExecutionContext_oursql
+
+ colspecs = util.update_copy(
+ MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _oursqlBIT}
+ )
+
+ @classmethod
+ def dbapi(cls):
+ util.warn_deprecated(
+ "The OurSQL DBAPI is deprecated and will be removed "
+ "in a future version. Please use one of the supported DBAPIs to "
+ "connect to mysql.",
+ version="1.4",
+ )
+ return __import__("oursql")
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ """Provide an implementation of
+ *cursor.execute(statement, parameters)*."""
+
+ if context and context.plain_query:
+ cursor.execute(statement, plain_query=True)
+ else:
+ cursor.execute(statement, parameters)
+
+ def do_begin(self, connection):
+ connection.cursor().execute("BEGIN", plain_query=True)
+
+ def _xa_query(self, connection, query, xid):
+ if util.py2k:
+ arg = connection.connection._escape_string(xid)
+ else:
+ charset = self._connection_charset
+ arg = connection.connection._escape_string(
+ xid.encode(charset)
+ ).decode(charset)
+ arg = "'%s'" % arg
+ connection.execution_options(_oursql_plain_query=True).exec_driver_sql(
+ query % arg
+ )
+
+ # Because mysql is bad, these methods have to be
+ # reimplemented to use _PlainQuery. Basically, some queries
+ # refuse to return any data if they're run through
+ # the parameterized query API, or refuse to be parameterized
+ # in the first place.
+ def do_begin_twophase(self, connection, xid):
+ self._xa_query(connection, "XA BEGIN %s", xid)
+
+ def do_prepare_twophase(self, connection, xid):
+ self._xa_query(connection, "XA END %s", xid)
+ self._xa_query(connection, "XA PREPARE %s", xid)
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self._xa_query(connection, "XA END %s", xid)
+ self._xa_query(connection, "XA ROLLBACK %s", xid)
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if not is_prepared:
+ self.do_prepare_twophase(connection, xid)
+ self._xa_query(connection, "XA COMMIT %s", xid)
+
+ # Q: why didn't we need all these "plain_query" overrides earlier ?
+ # am i on a newer/older version of OurSQL ?
+ def has_table(self, connection, table_name, schema=None):
+ return MySQLDialect.has_table(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ table_name,
+ schema,
+ )
+
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+ return MySQLDialect.get_table_options(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ table_name,
+ schema=schema,
+ **kw
+ )
+
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ return MySQLDialect.get_columns(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ table_name,
+ schema=schema,
+ **kw
+ )
+
+ def get_view_names(self, connection, schema=None, **kw):
+ return MySQLDialect.get_view_names(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ schema=schema,
+ **kw
+ )
+
+ def get_table_names(self, connection, schema=None, **kw):
+ return MySQLDialect.get_table_names(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ schema,
+ )
+
+ def get_schema_names(self, connection, **kw):
+ return MySQLDialect.get_schema_names(
+ self,
+ connection.connect().execution_options(_oursql_plain_query=True),
+ **kw
+ )
+
+ def initialize(self, connection):
+ return MySQLDialect.initialize(
+ self, connection.execution_options(_oursql_plain_query=True)
+ )
+
+ def _show_create_table(
+ self, connection, table, charset=None, full_name=None
+ ):
+ return MySQLDialect._show_create_table(
+ self,
+ connection.connect(close_with_result=True).execution_options(
+ _oursql_plain_query=True
+ ),
+ table,
+ charset,
+ full_name,
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.ProgrammingError):
+ return (
+ e.errno is None
+ and "cursor" not in e.args[1]
+ and e.args[1].endswith("closed")
+ )
+ else:
+ return e.errno in (2006, 2013, 2014, 2045, 2055)
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(
+ database="db", username="user", password="passwd"
+ )
+ opts.update(url.query)
+
+ util.coerce_kw_type(opts, "port", int)
+ util.coerce_kw_type(opts, "compress", bool)
+ util.coerce_kw_type(opts, "autoping", bool)
+ util.coerce_kw_type(opts, "raise_on_warnings", bool)
+
+ util.coerce_kw_type(opts, "default_charset", bool)
+ if opts.pop("default_charset", False):
+ opts["charset"] = None
+ else:
+ util.coerce_kw_type(opts, "charset", str)
+ opts["use_unicode"] = opts.get("use_unicode", True)
+ util.coerce_kw_type(opts, "use_unicode", bool)
+
+ # FOUND_ROWS must be set in CLIENT_FLAGS to enable
+ # supports_sane_rowcount.
+ opts.setdefault("found_rows", True)
+
+ ssl = {}
+ for key in [
+ "ssl_ca",
+ "ssl_key",
+ "ssl_cert",
+ "ssl_capath",
+ "ssl_cipher",
+ ]:
+ if key in opts:
+ ssl[key[4:]] = opts[key]
+ util.coerce_kw_type(ssl, key[4:], str)
+ del opts[key]
+ if ssl:
+ opts["ssl"] = ssl
+
+ return [[], opts]
+
+ def _extract_error_code(self, exception):
+ return exception.errno
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ return connection.connection.charset
+
+ def _compat_fetchall(self, rp, charset=None):
+ """oursql isn't super-broken like MySQLdb, yaaay."""
+ return rp.fetchall()
+
+ def _compat_fetchone(self, rp, charset=None):
+ """oursql isn't super-broken like MySQLdb, yaaay."""
+ return rp.fetchone()
+
+ def _compat_first(self, rp, charset=None):
+ return rp.first()
+
+
+dialect = MySQLDialect_oursql
diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py
new file mode 100644
index 0000000..86aaa94
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/provision.py
@@ -0,0 +1,78 @@
+from ... import exc
+from ...testing.provision import configure_follower
+from ...testing.provision import create_db
+from ...testing.provision import drop_db
+from ...testing.provision import generate_driver_url
+from ...testing.provision import temp_table_keyword_args
+
+
+@generate_driver_url.for_db("mysql", "mariadb")
+def generate_driver_url(url, driver, query_str):
+ backend = url.get_backend_name()
+
+ # NOTE: at the moment, tests run mariadbconnector against both mariadb
+ # and mysql backends. If we want this to be limited, do the decision
+ # making here to reject a "mysql+mariadbconnector" URL. Optionally,
+ # also re-enable the module level
+ # MySQLDialect_mariadbconnector.is_mysql flag; that change must include
+ # a unit and/or functional test.
+
+ # the Jenkins tests have for years been running the mysqlclient Python
+ # library, built against the mariadb client drivers, against all MySQL /
+ # MariaDB versions going back to MySQL 5.6; currently they can talk to
+ # MySQL databases without problems.
+
+ if backend == "mysql":
+ dialect_cls = url.get_dialect()
+ if dialect_cls._is_mariadb_from_url(url):
+ backend = "mariadb"
+
+ new_url = url.set(
+ drivername="%s+%s" % (backend, driver)
+ ).update_query_string(query_str)
+
+ try:
+ new_url.get_dialect()
+ except exc.NoSuchModuleError:
+ return None
+ else:
+ return new_url
+
+
+@create_db.for_db("mysql", "mariadb")
+def _mysql_create_db(cfg, eng, ident):
+ with eng.begin() as conn:
+ try:
+ _mysql_drop_db(cfg, conn, ident)
+ except Exception:
+ pass
+
+ with eng.begin() as conn:
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s CHARACTER SET utf8mb4" % ident
+ )
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident
+ )
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident
+ )
+
+
+@configure_follower.for_db("mysql", "mariadb")
+def _mysql_configure_follower(config, ident):
+ config.test_schema = "%s_test_schema" % ident
+ config.test_schema_2 = "%s_test_schema_2" % ident
+
+
+@drop_db.for_db("mysql", "mariadb")
+def _mysql_drop_db(cfg, eng, ident):
+ with eng.begin() as conn:
+ conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident)
+ conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident)
+ conn.exec_driver_sql("DROP DATABASE %s" % ident)
+
+
+@temp_table_keyword_args.for_db("mysql", "mariadb")
+def _mysql_temp_table_keyword_args(cfg, eng):
+ return {"prefixes": ["TEMPORARY"]}
diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py
new file mode 100644
index 0000000..f620133
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/pymysql.py
@@ -0,0 +1,98 @@
+# mysql/pymysql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: mysql+pymysql
+ :name: PyMySQL
+ :dbapi: pymysql
+ :connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
+ :url: https://pymysql.readthedocs.io/
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+.. _pymysql_ssl:
+
+SSL Connections
+------------------
+
+The PyMySQL DBAPI accepts the same SSL arguments as MySQLdb,
+described at :ref:`mysqldb_ssl`. See that section for examples.
+
+
+MySQL-Python Compatibility
+--------------------------
+
+The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
+and targets 100% compatibility. Most behavioral notes for MySQL-python apply
+to the pymysql driver as well.
+
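+For example (host and credentials are illustrative)::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@192.168.0.134/test?charset=utf8mb4"
+ )
+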
+""" # noqa
+
+from .mysqldb import MySQLDialect_mysqldb
+from ...util import langhelpers
+from ...util import py3k
+
+
+class MySQLDialect_pymysql(MySQLDialect_mysqldb):
+ driver = "pymysql"
+ supports_statement_cache = True
+
+ description_encoding = None
+
+ # generally, these two values should be both True
+ # or both False. PyMySQL unicode tests pass all the way back
+ # to 0.4 either way. See [ticket:3337]
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ @langhelpers.memoized_property
+ def supports_server_side_cursors(self):
+ try:
+ cursors = __import__("pymysql.cursors").cursors
+ self._sscursor = cursors.SSCursor
+ return True
+ except (ImportError, AttributeError):
+ return False
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("pymysql")
+
+ def create_connect_args(self, url, _translate_args=None):
+ if _translate_args is None:
+ _translate_args = dict(username="user")
+ return super(MySQLDialect_pymysql, self).create_connect_args(
+ url, _translate_args=_translate_args
+ )
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_pymysql, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ elif isinstance(e, self.dbapi.Error):
+ str_e = str(e).lower()
+ return (
+ "already closed" in str_e or "connection was killed" in str_e
+ )
+ else:
+ return False
+
+ if py3k:
+
+ def _extract_error_code(self, exception):
+ if isinstance(exception.args[0], Exception):
+ exception = exception.args[0]
+ return exception.args[0]
+
+
+dialect = MySQLDialect_pymysql
diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py
new file mode 100644
index 0000000..bfa61f6
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py
@@ -0,0 +1,136 @@
+# mysql/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+
+.. dialect:: mysql+pyodbc
+ :name: PyODBC
+ :dbapi: pyodbc
+ :connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
+ :url: https://pypi.org/project/pyodbc/
+
+.. note::
+
+ The PyODBC for MySQL dialect is **not tested as part of
+ SQLAlchemy's continuous integration**.
+ The recommended MySQL dialects are mysqlclient and PyMySQL.
+ However, if you want to use the mysql+pyodbc dialect and require
+ full support for ``utf8mb4`` characters (including supplementary
+ characters like emoji), be sure to use a current release of
+ MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode")
+ version of the driver in your DSN or connection string.
+
+Pass through exact pyodbc connection string::
+
+ import urllib.parse
+ connection_string = (
+ 'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
+ 'SERVER=localhost;'
+ 'PORT=3307;'
+ 'DATABASE=mydb;'
+ 'UID=root;'
+ 'PWD=(whatever);'
+ 'charset=utf8mb4;'
+ )
+ params = urllib.parse.quote_plus(connection_string)
+ connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
+
+""" # noqa
+
+import re
+
+from .base import MySQLDialect
+from .base import MySQLExecutionContext
+from .types import TIME
+from ... import exc
+from ... import util
+from ...connectors.pyodbc import PyODBCConnector
+from ...sql.sqltypes import Time
+
+
+class _pyodbcTIME(TIME):
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ # pyodbc returns a datetime.time object; no need to convert
+ return value
+
+ return process
+
+
+class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
+ def get_lastrowid(self):
+ cursor = self.create_cursor()
+ cursor.execute("SELECT LAST_INSERT_ID()")
+ lastrowid = cursor.fetchone()[0]
+ cursor.close()
+ return lastrowid
+
+
+class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
+ supports_statement_cache = True
+ colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME})
+ supports_unicode_statements = True
+ execution_ctx_cls = MySQLExecutionContext_pyodbc
+
+ pyodbc_driver_name = "MySQL"
+
+ def _detect_charset(self, connection):
+ """Sniff out the character set in use for connection results."""
+
+ # Prefer 'character_set_results' for the current connection over the
+ # value in the driver. SET NAMES or individual variable SETs will
+ # change the charset without updating the driver's view of the world.
+ #
+ # If it's decided that issuing that sort of SQL leaves you SOL, then
+ # this can prefer the driver value.
+
+ # set this to None as _fetch_setting attempts to use it (None is OK)
+ self._connection_charset = None
+ try:
+ value = self._fetch_setting(connection, "character_set_client")
+ if value:
+ return value
+ except exc.DBAPIError:
+ pass
+
+ util.warn(
+ "Could not detect the connection character set. "
+ "Assuming latin1."
+ )
+ return "latin1"
+
+ def _get_server_version_info(self, connection):
+ return MySQLDialect._get_server_version_info(self, connection)
+
+ def _extract_error_code(self, exception):
+ # the error message may not contain a numeric code at all
+ m = re.compile(r"\((\d+)\)").search(str(exception.args))
+ if m:
+ return int(m.group(1))
+ else:
+ return None
+
+ def on_connect(self):
+ super_ = super(MySQLDialect_pyodbc, self).on_connect()
+
+ def on_connect(conn):
+ if super_ is not None:
+ super_(conn)
+
+ # declare Unicode encoding for pyodbc as per
+ # https://github.com/mkleehammer/pyodbc/wiki/Unicode
+ pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR
+ pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR
+ conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8")
+ conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8")
+ conn.setencoding(encoding="utf-8")
+
+ return on_connect
+
+
+dialect = MySQLDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py
new file mode 100644
index 0000000..27394bb
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/reflection.py
@@ -0,0 +1,558 @@
+# mysql/reflection.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from .enumerated import ENUM
+from .enumerated import SET
+from .types import DATETIME
+from .types import TIME
+from .types import TIMESTAMP
+from ... import log
+from ... import types as sqltypes
+from ... import util
+
+
+class ReflectedState(object):
+ """Stores raw information about a SHOW CREATE TABLE statement."""
+
+ def __init__(self):
+ self.columns = []
+ self.table_options = {}
+ self.table_name = None
+ self.keys = []
+ self.fk_constraints = []
+ self.ck_constraints = []
+
+
+@log.class_logger
+class MySQLTableDefinitionParser(object):
+ """Parses the results of a SHOW CREATE TABLE statement."""
+
+ def __init__(self, dialect, preparer):
+ self.dialect = dialect
+ self.preparer = preparer
+ self._prep_regexes()
+
+ def parse(self, show_create, charset):
+ state = ReflectedState()
+ state.charset = charset
+ for line in re.split(r"\r?\n", show_create):
+ if line.startswith(" " + self.preparer.initial_quote):
+ self._parse_column(line, state)
+ # a regular table options line
+ elif line.startswith(") "):
+ self._parse_table_options(line, state)
+ # an ANSI-mode table options line
+ elif line == ")":
+ pass
+ elif line.startswith("CREATE "):
+ self._parse_table_name(line, state)
+ # Not present in real reflection, but may be if
+ # loading from a file.
+ elif not line:
+ pass
+ else:
+ type_, spec = self._parse_constraints(line)
+ if type_ is None:
+ util.warn("Unknown schema content: %r" % line)
+ elif type_ == "key":
+ state.keys.append(spec)
+ elif type_ == "fk_constraint":
+ state.fk_constraints.append(spec)
+ elif type_ == "ck_constraint":
+ state.ck_constraints.append(spec)
+ else:
+ pass
+ return state
+
+ def _parse_constraints(self, line):
+ """Parse a KEY or CONSTRAINT line.
+
+ :param line: A line of SHOW CREATE TABLE output
+ """
+
+ # KEY
+ m = self._re_key.match(line)
+ if m:
+ spec = m.groupdict()
+ # convert columns into name, length pairs
+ # NOTE: we may want to consider SHOW INDEX as the
+ # format of indexes in MySQL becomes more complex
+ spec["columns"] = self._parse_keyexprs(spec["columns"])
+ if spec["version_sql"]:
+ m2 = self._re_key_version_sql.match(spec["version_sql"])
+ if m2 and m2.groupdict()["parser"]:
+ spec["parser"] = m2.groupdict()["parser"]
+ if spec["parser"]:
+ spec["parser"] = self.preparer.unformat_identifiers(
+ spec["parser"]
+ )[0]
+ return "key", spec
+
+ # FOREIGN KEY CONSTRAINT
+ m = self._re_fk_constraint.match(line)
+ if m:
+ spec = m.groupdict()
+ spec["table"] = self.preparer.unformat_identifiers(spec["table"])
+ spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])]
+ spec["foreign"] = [
+ c[0] for c in self._parse_keyexprs(spec["foreign"])
+ ]
+ return "fk_constraint", spec
+
+ # CHECK constraint
+ m = self._re_ck_constraint.match(line)
+ if m:
+ spec = m.groupdict()
+ return "ck_constraint", spec
+
+ # PARTITION and SUBPARTITION
+ m = self._re_partition.match(line)
+ if m:
+ # Punt!
+ return "partition", line
+
+ # No match.
+ return (None, line)
+
+ def _parse_table_name(self, line, state):
+ """Extract the table name.
+
+ :param line: The first line of SHOW CREATE TABLE
+ """
+
+ regex, cleanup = self._pr_name
+ m = regex.match(line)
+ if m:
+ state.table_name = cleanup(m.group("name"))
+
+ def _parse_table_options(self, line, state):
+ """Build a dictionary of all reflected table-level options.
+
+ :param line: The final line of SHOW CREATE TABLE output.
+ """
+
+ options = {}
+
+ if not line or line == ")":
+ pass
+
+ else:
+ rest_of_line = line[:]
+ for regex, cleanup in self._pr_options:
+ m = regex.search(rest_of_line)
+ if not m:
+ continue
+ directive, value = m.group("directive"), m.group("val")
+ if cleanup:
+ value = cleanup(value)
+ options[directive.lower()] = value
+ rest_of_line = regex.sub("", rest_of_line)
+
+ for nope in ("auto_increment", "data directory", "index directory"):
+ options.pop(nope, None)
+
+ for opt, val in options.items():
+ state.table_options["%s_%s" % (self.dialect.name, opt)] = val
+
+ def _parse_column(self, line, state):
+ """Extract column details.
+
+ Falls back to a 'minimal support' variant if full parse fails.
+
+ :param line: Any column-bearing line from SHOW CREATE TABLE
+ """
+
+ spec = None
+ m = self._re_column.match(line)
+ if m:
+ spec = m.groupdict()
+ spec["full"] = True
+ else:
+ m = self._re_column_loose.match(line)
+ if m:
+ spec = m.groupdict()
+ spec["full"] = False
+ if not spec:
+ util.warn("Unknown column definition %r" % line)
+ return
+ if not spec["full"]:
+ util.warn("Incomplete reflection of column definition %r" % line)
+
+ name, type_, args = spec["name"], spec["coltype"], spec["arg"]
+
+ try:
+ col_type = self.dialect.ischema_names[type_]
+ except KeyError:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (type_, name)
+ )
+ col_type = sqltypes.NullType
+
+ # Column type positional arguments eg. varchar(32)
+ if args is None or args == "":
+ type_args = []
+ elif args[0] == "'" and args[-1] == "'":
+ type_args = self._re_csv_str.findall(args)
+ else:
+ type_args = [int(v) for v in self._re_csv_int.findall(args)]
+
+ # Column type keyword options
+ type_kw = {}
+
+ if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
+ if type_args:
+ type_kw["fsp"] = type_args.pop(0)
+
+ for kw in ("unsigned", "zerofill"):
+ if spec.get(kw, False):
+ type_kw[kw] = True
+ for kw in ("charset", "collate"):
+ if spec.get(kw, False):
+ type_kw[kw] = spec[kw]
+ if issubclass(col_type, (ENUM, SET)):
+ type_args = _strip_values(type_args)
+
+ if issubclass(col_type, SET) and "" in type_args:
+ type_kw["retrieve_as_bitwise"] = True
+
+ type_instance = col_type(*type_args, **type_kw)
+
+ col_kw = {}
+
+ # NOT NULL
+ col_kw["nullable"] = True
+ # this can be "NULL" in the case of TIMESTAMP
+ if spec.get("notnull", False) == "NOT NULL":
+ col_kw["nullable"] = False
+
+ # AUTO_INCREMENT
+ if spec.get("autoincr", False):
+ col_kw["autoincrement"] = True
+ elif issubclass(col_type, sqltypes.Integer):
+ col_kw["autoincrement"] = False
+
+ # DEFAULT
+ default = spec.get("default", None)
+
+ if default == "NULL":
+ # eliminates the need to deal with this later.
+ default = None
+
+ comment = spec.get("comment", None)
+
+ if comment is not None:
+ comment = comment.replace("\\\\", "\\").replace("''", "'")
+
+ sqltext = spec.get("generated")
+ if sqltext is not None:
+ computed = dict(sqltext=sqltext)
+ persisted = spec.get("persistence")
+ if persisted is not None:
+ computed["persisted"] = persisted == "STORED"
+ col_kw["computed"] = computed
+
+ col_d = dict(
+ name=name, type=type_instance, default=default, comment=comment
+ )
+ col_d.update(col_kw)
+ state.columns.append(col_d)
+
+ def _describe_to_create(self, table_name, columns):
+ """Re-format DESCRIBE output as a SHOW CREATE TABLE string.
+
+ DESCRIBE is a much simpler reflection and is sufficient for
+ reflecting views for runtime use. This method formats DDL
+ for columns only; keys are omitted.
+
+ :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
+ SHOW FULL COLUMNS FROM rows must be rearranged for use with
+ this function.
+ """
+
+ buffer = []
+ for row in columns:
+ (name, col_type, nullable, default, extra) = [
+ row[i] for i in (0, 1, 2, 4, 5)
+ ]
+
+ line = [" "]
+ line.append(self.preparer.quote_identifier(name))
+ line.append(col_type)
+ if not nullable:
+ line.append("NOT NULL")
+ if default:
+ if "auto_increment" in default:
+ pass
+ elif col_type.startswith("timestamp") and default.startswith(
+ "C"
+ ):
+ line.append("DEFAULT")
+ line.append(default)
+ elif default == "NULL":
+ line.append("DEFAULT")
+ line.append(default)
+ else:
+ line.append("DEFAULT")
+ line.append("'%s'" % default.replace("'", "''"))
+ if extra:
+ line.append(extra)
+
+ buffer.append(" ".join(line))
+
+ return "".join(
+ [
+ (
+ "CREATE TABLE %s (\n"
+ % self.preparer.quote_identifier(table_name)
+ ),
+ ",\n".join(buffer),
+ "\n) ",
+ ]
+ )
+
+ def _parse_keyexprs(self, identifiers):
+ """Unpack '"col"(2),"col" ASC'-ish strings into components."""
+
+ return self._re_keyexprs.findall(identifiers)
+
+ def _prep_regexes(self):
+ """Pre-compile regular expressions."""
+
+ self._re_columns = []
+ self._pr_options = []
+
+ _final = self.preparer.final_quote
+
+ quotes = dict(
+ zip(
+ ("iq", "fq", "esc_fq"),
+ [
+ re.escape(s)
+ for s in (
+ self.preparer.initial_quote,
+ _final,
+ self.preparer._escape_identifier(_final),
+ )
+ ],
+ )
+ )
+
+ self._pr_name = _pr_compile(
+ r"^CREATE (?:\w+ +)?TABLE +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes,
+ self.preparer._unescape_identifier,
+ )
+
+ # `col`,`col2`(32),`col3`(15) DESC
+ #
+ self._re_keyexprs = _re_compile(
+ r"(?:"
+ r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)"
+ r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes
+ )
+
+ # 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
+ self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27")
+
+ # 123 or 123,456
+ self._re_csv_int = _re_compile(r"\d+")
+
+ # `colname` <type> [type opts]
+ # (NOT NULL | NULL)
+ # DEFAULT ('value' | CURRENT_TIMESTAMP...)
+ # COMMENT 'comment'
+ # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
+ # STORAGE (DISK|MEMORY)
+ self._re_column = _re_compile(
+ r" "
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"(?P<coltype>\w+)"
+ r"(?:\((?P<arg>(?:\d+|\d+,\d+|"
+ r"(?:'(?:''|[^'])*',?)+))\))?"
+ r"(?: +(?P<unsigned>UNSIGNED))?"
+ r"(?: +(?P<zerofill>ZEROFILL))?"
+ r"(?: +CHARACTER SET +(?P<charset>[\w_]+))?"
+ r"(?: +COLLATE +(?P<collate>[\w_]+))?"
+ r"(?: +(?P<notnull>(?:NOT )?NULL))?"
+ r"(?: +DEFAULT +(?P<default>"
+ r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
+ r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
+ r"))?"
+ r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
+ r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?"
+ r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
+ r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
+ r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
+ r"(?: +STORAGE +(?P<storage>\w+))?"
+ r"(?: +(?P<extra>.*))?"
+ r",?$" % quotes
+ )
+
+ # Fallback, try to parse as little as possible
+ self._re_column_loose = _re_compile(
+ r" "
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"(?P<coltype>\w+)"
+ r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?"
+ r".*?(?P<notnull>(?:NOT )NULL)?" % quotes
+ )
+
+ # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
+ # (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
+ # KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */
+ self._re_key = _re_compile(
+ r" "
+ r"(?:(?P<type>\S+) )?KEY"
+ r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?"
+ r"(?: +USING +(?P<using_pre>\S+))?"
+ r" +\((?P<columns>.+?)\)"
+ r"(?: +USING +(?P<using_post>\S+))?"
+ r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?"
+ r"(?: +WITH PARSER +(?P<parser>\S+))?"
+ r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
+ r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
+ r",?$" % quotes
+ )
+
+ # https://forums.mysql.com/read.php?20,567102,567111#msg-567111
+ # It means if the MySQL version >= \d+, execute what's in the comment
+ self._re_key_version_sql = _re_compile(
+ r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?"
+ )
+
+ # CONSTRAINT `name` FOREIGN KEY (`local_col`)
+ # REFERENCES `remote` (`remote_col`)
+ # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
+ # ON DELETE CASCADE ON UPDATE RESTRICT
+ #
+ # unique constraints come back as KEYs
+ kw = quotes.copy()
+ kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
+ self._re_fk_constraint = _re_compile(
+ r" "
+ r"CONSTRAINT +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"FOREIGN KEY +"
+ r"\((?P<local>[^\)]+?)\) REFERENCES +"
+ r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s"
+ r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +"
+ r"\((?P<foreign>[^\)]+?)\)"
+ r"(?: +(?P<match>MATCH \w+))?"
+ r"(?: +ON DELETE (?P<ondelete>%(on)s))?"
+ r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw
+ )
+
+ # CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)'
+ # testing on MariaDB 10.2 shows that the CHECK constraint
+ # is returned on a line by itself, so to match without worrying
+ # about parenthesis in the expression we go to the end of the line
+ self._re_ck_constraint = _re_compile(
+ r" "
+ r"CONSTRAINT +"
+ r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
+ r"CHECK +"
+ r"\((?P<sqltext>.+)\),?" % kw
+ )
+
+ # PARTITION
+ #
+ # punt!
+ self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)")
+
+ # Table-level options (COLLATE, ENGINE, etc.)
+ # Do the string options first, since they have quoted
+ # strings we need to get rid of.
+ for option in _options_of_type_string:
+ self._add_option_string(option)
+
+ for option in (
+ "ENGINE",
+ "TYPE",
+ "AUTO_INCREMENT",
+ "AVG_ROW_LENGTH",
+ "CHARACTER SET",
+ "DEFAULT CHARSET",
+ "CHECKSUM",
+ "COLLATE",
+ "DELAY_KEY_WRITE",
+ "INSERT_METHOD",
+ "MAX_ROWS",
+ "MIN_ROWS",
+ "PACK_KEYS",
+ "ROW_FORMAT",
+ "KEY_BLOCK_SIZE",
+ ):
+ self._add_option_word(option)
+
+ self._add_option_regex("UNION", r"\([^\)]+\)")
+ self._add_option_regex("TABLESPACE", r".*? STORAGE DISK")
+ self._add_option_regex(
+ "RAID_TYPE",
+ r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+",
+ )
+
+ _optional_equals = r"(?:\s*(?:=\s*)|\s+)"
+
+ def _add_option_string(self, directive):
+ regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
+ re.escape(directive),
+ self._optional_equals,
+ )
+ self._pr_options.append(
+ _pr_compile(
+ regex, lambda v: v.replace("\\\\", "\\").replace("''", "'")
+ )
+ )
+
+ def _add_option_word(self, directive):
+ regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
+ re.escape(directive),
+ self._optional_equals,
+ )
+ self._pr_options.append(_pr_compile(regex))
+
+ def _add_option_regex(self, directive, regex):
+ regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
+ re.escape(directive),
+ self._optional_equals,
+ regex,
+ )
+ self._pr_options.append(_pr_compile(regex))
+
+
+_options_of_type_string = (
+ "COMMENT",
+ "DATA DIRECTORY",
+ "INDEX DIRECTORY",
+ "PASSWORD",
+ "CONNECTION",
+)
+
+
+def _pr_compile(regex, cleanup=None):
+ """Prepare a 2-tuple of compiled regex and callable."""
+
+ return (_re_compile(regex), cleanup)
+
+
+def _re_compile(regex):
+ """Compile a string to regex, I and UNICODE."""
+
+ return re.compile(regex, re.I | re.UNICODE)
+
+
+def _strip_values(values):
+ "Strip reflected values quotes"
+ strip_values = []
+ for a in values:
+ if a[0:1] == '"' or a[0:1] == "'":
+ # strip enclosing quotes and unquote interior
+ a = a[1:-1].replace(a[0] * 2, a[0])
+ strip_values.append(a)
+ return strip_values
diff --git a/lib/sqlalchemy/dialects/mysql/reserved_words.py b/lib/sqlalchemy/dialects/mysql/reserved_words.py
new file mode 100644
index 0000000..995168b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/reserved_words.py
@@ -0,0 +1,564 @@
+# mysql/reserved_words.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+# generated using:
+# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9
+#
+# https://mariadb.com/kb/en/reserved-words/
+# includes: Reserved Words, Oracle Mode (separate set unioned)
+# excludes: Exceptions, Function Names
+RESERVED_WORDS_MARIADB = {
+ "accessible",
+ "add",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "as",
+ "asc",
+ "asensitive",
+ "before",
+ "between",
+ "bigint",
+ "binary",
+ "blob",
+ "both",
+ "by",
+ "call",
+ "cascade",
+ "case",
+ "change",
+ "char",
+ "character",
+ "check",
+ "collate",
+ "column",
+ "condition",
+ "constraint",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "databases",
+ "day_hour",
+ "day_microsecond",
+ "day_minute",
+ "day_second",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delayed",
+ "delete",
+ "desc",
+ "describe",
+ "deterministic",
+ "distinct",
+ "distinctrow",
+ "div",
+ "do_domain_ids",
+ "double",
+ "drop",
+ "dual",
+ "each",
+ "else",
+ "elseif",
+ "enclosed",
+ "escaped",
+ "except",
+ "exists",
+ "exit",
+ "explain",
+ "false",
+ "fetch",
+ "float",
+ "float4",
+ "float8",
+ "for",
+ "force",
+ "foreign",
+ "from",
+ "fulltext",
+ "general",
+ "grant",
+ "group",
+ "having",
+ "high_priority",
+ "hour_microsecond",
+ "hour_minute",
+ "hour_second",
+ "if",
+ "ignore",
+ "ignore_domain_ids",
+ "ignore_server_ids",
+ "in",
+ "index",
+ "infile",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "int",
+ "int1",
+ "int2",
+ "int3",
+ "int4",
+ "int8",
+ "integer",
+ "intersect",
+ "interval",
+ "into",
+ "is",
+ "iterate",
+ "join",
+ "key",
+ "keys",
+ "kill",
+ "leading",
+ "leave",
+ "left",
+ "like",
+ "limit",
+ "linear",
+ "lines",
+ "load",
+ "localtime",
+ "localtimestamp",
+ "lock",
+ "long",
+ "longblob",
+ "longtext",
+ "loop",
+ "low_priority",
+ "master_heartbeat_period",
+ "master_ssl_verify_server_cert",
+ "match",
+ "maxvalue",
+ "mediumblob",
+ "mediumint",
+ "mediumtext",
+ "middleint",
+ "minute_microsecond",
+ "minute_second",
+ "mod",
+ "modifies",
+ "natural",
+ "no_write_to_binlog",
+ "not",
+ "null",
+ "numeric",
+ "offset",
+ "on",
+ "optimize",
+ "option",
+ "optionally",
+ "or",
+ "order",
+ "out",
+ "outer",
+ "outfile",
+ "over",
+ "page_checksum",
+ "parse_vcol_expr",
+ "partition",
+ "position",
+ "precision",
+ "primary",
+ "procedure",
+ "purge",
+ "range",
+ "read",
+ "read_write",
+ "reads",
+ "real",
+ "recursive",
+ "ref_system_id",
+ "references",
+ "regexp",
+ "release",
+ "rename",
+ "repeat",
+ "replace",
+ "require",
+ "resignal",
+ "restrict",
+ "return",
+ "returning",
+ "revoke",
+ "right",
+ "rlike",
+ "rows",
+ "schema",
+ "schemas",
+ "second_microsecond",
+ "select",
+ "sensitive",
+ "separator",
+ "set",
+ "show",
+ "signal",
+ "slow",
+ "smallint",
+ "spatial",
+ "specific",
+ "sql",
+ "sql_big_result",
+ "sql_calc_found_rows",
+ "sql_small_result",
+ "sqlexception",
+ "sqlstate",
+ "sqlwarning",
+ "ssl",
+ "starting",
+ "stats_auto_recalc",
+ "stats_persistent",
+ "stats_sample_pages",
+ "straight_join",
+ "table",
+ "terminated",
+ "then",
+ "tinyblob",
+ "tinyint",
+ "tinytext",
+ "to",
+ "trailing",
+ "trigger",
+ "true",
+ "undo",
+ "union",
+ "unique",
+ "unlock",
+ "unsigned",
+ "update",
+ "usage",
+ "use",
+ "using",
+ "utc_date",
+ "utc_time",
+ "utc_timestamp",
+ "values",
+ "varbinary",
+ "varchar",
+ "varcharacter",
+ "varying",
+ "when",
+ "where",
+ "while",
+ "window",
+ "with",
+ "write",
+ "xor",
+ "year_month",
+ "zerofill",
+}.union(
+ {
+ "body",
+ "elsif",
+ "goto",
+ "history",
+ "others",
+ "package",
+ "period",
+ "raise",
+ "rowtype",
+ "system",
+ "system_time",
+ "versioning",
+ "without",
+ }
+)
+
+# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
+# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
+# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
+# includes: MySQL x.0 Keywords and Reserved Words
+# excludes: MySQL x.0 New Keywords and Reserved Words,
+# MySQL x.0 Removed Keywords and Reserved Words
+RESERVED_WORDS_MYSQL = {
+ "accessible",
+ "add",
+ "admin",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "array",
+ "as",
+ "asc",
+ "asensitive",
+ "before",
+ "between",
+ "bigint",
+ "binary",
+ "blob",
+ "both",
+ "by",
+ "call",
+ "cascade",
+ "case",
+ "change",
+ "char",
+ "character",
+ "check",
+ "collate",
+ "column",
+ "condition",
+ "constraint",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "cube",
+ "cume_dist",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "database",
+ "databases",
+ "day_hour",
+ "day_microsecond",
+ "day_minute",
+ "day_second",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delayed",
+ "delete",
+ "dense_rank",
+ "desc",
+ "describe",
+ "deterministic",
+ "distinct",
+ "distinctrow",
+ "div",
+ "double",
+ "drop",
+ "dual",
+ "each",
+ "else",
+ "elseif",
+ "empty",
+ "enclosed",
+ "escaped",
+ "except",
+ "exists",
+ "exit",
+ "explain",
+ "false",
+ "fetch",
+ "first_value",
+ "float",
+ "float4",
+ "float8",
+ "for",
+ "force",
+ "foreign",
+ "from",
+ "fulltext",
+ "function",
+ "general",
+ "generated",
+ "get",
+ "get_master_public_key",
+ "grant",
+ "group",
+ "grouping",
+ "groups",
+ "having",
+ "high_priority",
+ "hour_microsecond",
+ "hour_minute",
+ "hour_second",
+ "if",
+ "ignore",
+ "ignore_server_ids",
+ "in",
+ "index",
+ "infile",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "int",
+ "int1",
+ "int2",
+ "int3",
+ "int4",
+ "int8",
+ "integer",
+ "interval",
+ "into",
+ "io_after_gtids",
+ "io_before_gtids",
+ "is",
+ "iterate",
+ "join",
+ "json_table",
+ "key",
+ "keys",
+ "kill",
+ "lag",
+ "last_value",
+ "lateral",
+ "lead",
+ "leading",
+ "leave",
+ "left",
+ "like",
+ "limit",
+ "linear",
+ "lines",
+ "load",
+ "localtime",
+ "localtimestamp",
+ "lock",
+ "long",
+ "longblob",
+ "longtext",
+ "loop",
+ "low_priority",
+ "master_bind",
+ "master_heartbeat_period",
+ "master_ssl_verify_server_cert",
+ "match",
+ "maxvalue",
+ "mediumblob",
+ "mediumint",
+ "mediumtext",
+ "member",
+ "middleint",
+ "minute_microsecond",
+ "minute_second",
+ "mod",
+ "modifies",
+ "natural",
+ "no_write_to_binlog",
+ "not",
+ "nth_value",
+ "ntile",
+ "null",
+ "numeric",
+ "of",
+ "on",
+ "optimize",
+ "optimizer_costs",
+ "option",
+ "optionally",
+ "or",
+ "order",
+ "out",
+ "outer",
+ "outfile",
+ "over",
+ "parse_gcol_expr",
+ "partition",
+ "percent_rank",
+ "persist",
+ "persist_only",
+ "precision",
+ "primary",
+ "procedure",
+ "purge",
+ "range",
+ "rank",
+ "read",
+ "read_write",
+ "reads",
+ "real",
+ "recursive",
+ "references",
+ "regexp",
+ "release",
+ "rename",
+ "repeat",
+ "replace",
+ "require",
+ "resignal",
+ "restrict",
+ "return",
+ "revoke",
+ "right",
+ "rlike",
+ "role",
+ "row",
+ "row_number",
+ "rows",
+ "schema",
+ "schemas",
+ "second_microsecond",
+ "select",
+ "sensitive",
+ "separator",
+ "set",
+ "show",
+ "signal",
+ "slow",
+ "smallint",
+ "spatial",
+ "specific",
+ "sql",
+ "sql_after_gtids",
+ "sql_before_gtids",
+ "sql_big_result",
+ "sql_calc_found_rows",
+ "sql_small_result",
+ "sqlexception",
+ "sqlstate",
+ "sqlwarning",
+ "ssl",
+ "starting",
+ "stored",
+ "straight_join",
+ "system",
+ "table",
+ "terminated",
+ "then",
+ "tinyblob",
+ "tinyint",
+ "tinytext",
+ "to",
+ "trailing",
+ "trigger",
+ "true",
+ "undo",
+ "union",
+ "unique",
+ "unlock",
+ "unsigned",
+ "update",
+ "usage",
+ "use",
+ "using",
+ "utc_date",
+ "utc_time",
+ "utc_timestamp",
+ "values",
+ "varbinary",
+ "varchar",
+ "varcharacter",
+ "varying",
+ "virtual",
+ "when",
+ "where",
+ "while",
+ "window",
+ "with",
+ "write",
+ "xor",
+ "year_month",
+ "zerofill",
+}
diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py
new file mode 100644
index 0000000..b81ee95
--- /dev/null
+++ b/lib/sqlalchemy/dialects/mysql/types.py
@@ -0,0 +1,773 @@
+# mysql/types.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import datetime
+
+from ... import exc
+from ... import types as sqltypes
+from ... import util
+
+
+class _NumericType(object):
+ """Base for MySQL numeric types.
+
+ This is the base both for NUMERIC as well as INTEGER, hence
+ it's a mixin.
+
+ """
+
+ def __init__(self, unsigned=False, zerofill=False, **kw):
+ self.unsigned = unsigned
+ self.zerofill = zerofill
+ super(_NumericType, self).__init__(**kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_NumericType, sqltypes.Numeric]
+ )
+
+
+class _FloatType(_NumericType, sqltypes.Float):
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ if isinstance(self, (REAL, DOUBLE)) and (
+ (precision is None and scale is not None)
+ or (precision is not None and scale is None)
+ ):
+ raise exc.ArgumentError(
+ "You must specify both precision and scale or omit "
+ "both altogether."
+ )
+ super(_FloatType, self).__init__(
+ precision=precision, asdecimal=asdecimal, **kw
+ )
+ self.scale = scale
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
+ )
+
+
+class _IntegerType(_NumericType, sqltypes.Integer):
+ def __init__(self, display_width=None, **kw):
+ self.display_width = display_width
+ super(_IntegerType, self).__init__(**kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
+ )
+
+
+class _StringType(sqltypes.String):
+ """Base for MySQL string types."""
+
+ def __init__(
+ self,
+ charset=None,
+ collation=None,
+ ascii=False, # noqa
+ binary=False,
+ unicode=False,
+ national=False,
+ **kw
+ ):
+ self.charset = charset
+
+ # allow collate= or collation=
+ kw.setdefault("collation", kw.pop("collate", collation))
+
+ self.ascii = ascii
+ self.unicode = unicode
+ self.binary = binary
+ self.national = national
+ super(_StringType, self).__init__(**kw)
+
+ def __repr__(self):
+ return util.generic_repr(
+ self, to_inspect=[_StringType, sqltypes.String]
+ )
+
+
+class _MatchType(sqltypes.Float, sqltypes.MatchType):
+ def __init__(self, **kw):
+ # TODO: float arguments?
+ sqltypes.Float.__init__(self)
+ sqltypes.MatchType.__init__(self)
+
+
+class NUMERIC(_NumericType, sqltypes.NUMERIC):
+ """MySQL NUMERIC type."""
+
+ __visit_name__ = "NUMERIC"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a NUMERIC.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(NUMERIC, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class DECIMAL(_NumericType, sqltypes.DECIMAL):
+ """MySQL DECIMAL type."""
+
+ __visit_name__ = "DECIMAL"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a DECIMAL.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(DECIMAL, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class DOUBLE(_FloatType):
+ """MySQL DOUBLE type."""
+
+ __visit_name__ = "DOUBLE"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a DOUBLE.
+
+ .. note::
+
+ The :class:`.DOUBLE` type by default converts from float
+ to Decimal, using a truncation that defaults to 10 digits.
+ Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
+ to change this scale, or ``asdecimal=False`` to return values
+ directly as Python floating points.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
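+        A usage sketch (illustrative only, not from the upstream docs)::
+
+            from sqlalchemy import Column
+            from sqlalchemy.dialects import mysql
+
+            # values come back as Python floats rather than Decimal objects
+            Column("ratio", mysql.DOUBLE(asdecimal=False))
+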
+ """
+ super(DOUBLE, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class REAL(_FloatType, sqltypes.REAL):
+ """MySQL REAL type."""
+
+ __visit_name__ = "REAL"
+
+ def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
+ """Construct a REAL.
+
+ .. note::
+
+ The :class:`.REAL` type by default converts from float
+ to Decimal, using a truncation that defaults to 10 digits.
+ Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
+ to change this scale, or ``asdecimal=False`` to return values
+ directly as Python floating points.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(REAL, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+
+class FLOAT(_FloatType, sqltypes.FLOAT):
+ """MySQL FLOAT type."""
+
+ __visit_name__ = "FLOAT"
+
+ def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
+ """Construct a FLOAT.
+
+ :param precision: Total digits in this number. If scale and precision
+ are both None, values are stored to limits allowed by the server.
+
+ :param scale: The number of digits after the decimal point.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(FLOAT, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal, **kw
+ )
+
+ def bind_processor(self, dialect):
+ return None
+
+
+class INTEGER(_IntegerType, sqltypes.INTEGER):
+ """MySQL INTEGER type."""
+
+ __visit_name__ = "INTEGER"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct an INTEGER.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
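+        A usage sketch (illustrative only, not from the upstream docs)::
+
+            from sqlalchemy import Column
+            from sqlalchemy.dialects import mysql
+
+            # renders "INTEGER(11) UNSIGNED ZEROFILL" in MySQL DDL
+            Column("counter", mysql.INTEGER(display_width=11, unsigned=True, zerofill=True))
+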
+ """
+ super(INTEGER, self).__init__(display_width=display_width, **kw)
+
+
+class BIGINT(_IntegerType, sqltypes.BIGINT):
+ """MySQL BIGINTEGER type."""
+
+ __visit_name__ = "BIGINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a BIGINTEGER.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(BIGINT, self).__init__(display_width=display_width, **kw)
+
+
+class MEDIUMINT(_IntegerType):
+ """MySQL MEDIUMINTEGER type."""
+
+ __visit_name__ = "MEDIUMINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a MEDIUMINTEGER
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(MEDIUMINT, self).__init__(display_width=display_width, **kw)
+
+
+class TINYINT(_IntegerType):
+ """MySQL TINYINT type."""
+
+ __visit_name__ = "TINYINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a TINYINT.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(TINYINT, self).__init__(display_width=display_width, **kw)
+
+
+class SMALLINT(_IntegerType, sqltypes.SMALLINT):
+ """MySQL SMALLINTEGER type."""
+
+ __visit_name__ = "SMALLINT"
+
+ def __init__(self, display_width=None, **kw):
+ """Construct a SMALLINTEGER.
+
+ :param display_width: Optional, maximum display width for this number.
+
+ :param unsigned: a boolean, optional.
+
+ :param zerofill: Optional. If true, values will be stored as strings
+          left-padded with zeros. Note that this does not affect the values
+ returned by the underlying database API, which continue to be
+ numeric.
+
+ """
+ super(SMALLINT, self).__init__(display_width=display_width, **kw)
+
+
+class BIT(sqltypes.TypeEngine):
+ """MySQL BIT type.
+
+    This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
+    for MEMORY, InnoDB and BDB. For older versions, use a
+ MSTinyInteger() type.
+
+ """
+
+ __visit_name__ = "BIT"
+
+ def __init__(self, length=None):
+ """Construct a BIT.
+
+ :param length: Optional, number of bits.
+
+ """
+ self.length = length
+
+ def result_processor(self, dialect, coltype):
+ """Convert a MySQL's 64 bit, variable length binary string to a long.
+
+ TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
+ already do this, so this logic should be moved to those dialects.
+
+ """
+
+ def process(value):
+ if value is not None:
+ v = 0
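+                # accumulate bytes big-endian, e.g. b'\x02\x01' -> 513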
+ for i in value:
+ if not isinstance(i, int):
+ i = ord(i) # convert byte to int on Python 2
+ v = v << 8 | i
+ return v
+ return value
+
+ return process
+
+
+class TIME(sqltypes.TIME):
+ """MySQL TIME type."""
+
+ __visit_name__ = "TIME"
+
+ def __init__(self, timezone=False, fsp=None):
+ """Construct a MySQL TIME type.
+
+ :param timezone: not used by the MySQL dialect.
+ :param fsp: fractional seconds precision value.
+ MySQL 5.6 supports storage of fractional seconds;
+ this parameter will be used when emitting DDL
+ for the TIME type.
+
+ .. note::
+
+ DBAPI driver support for fractional seconds may
+ be limited; current support includes
+ MySQL Connector/Python.
+
+ """
+ super(TIME, self).__init__(timezone=timezone)
+ self.fsp = fsp
+
+ def result_processor(self, dialect, coltype):
+ time = datetime.time
+
+ def process(value):
+ # convert from a timedelta value
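+            # e.g. timedelta(hours=5, minutes=30, seconds=12)
+            # becomes datetime.time(5, 30, 12)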
+ if value is not None:
+ microseconds = value.microseconds
+ seconds = value.seconds
+ minutes = seconds // 60
+ return time(
+ minutes // 60,
+ minutes % 60,
+ seconds - minutes * 60,
+ microsecond=microseconds,
+ )
+ else:
+ return None
+
+ return process
+
+
+class TIMESTAMP(sqltypes.TIMESTAMP):
+ """MySQL TIMESTAMP type."""
+
+ __visit_name__ = "TIMESTAMP"
+
+ def __init__(self, timezone=False, fsp=None):
+ """Construct a MySQL TIMESTAMP type.
+
+ :param timezone: not used by the MySQL dialect.
+ :param fsp: fractional seconds precision value.
+ MySQL 5.6.4 supports storage of fractional seconds;
+ this parameter will be used when emitting DDL
+ for the TIMESTAMP type.
+
+ .. note::
+
+ DBAPI driver support for fractional seconds may
+ be limited; current support includes
+ MySQL Connector/Python.
+
+ """
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.fsp = fsp
+
+
+class DATETIME(sqltypes.DATETIME):
+ """MySQL DATETIME type."""
+
+ __visit_name__ = "DATETIME"
+
+ def __init__(self, timezone=False, fsp=None):
+ """Construct a MySQL DATETIME type.
+
+ :param timezone: not used by the MySQL dialect.
+ :param fsp: fractional seconds precision value.
+ MySQL 5.6.4 supports storage of fractional seconds;
+ this parameter will be used when emitting DDL
+ for the DATETIME type.
+
+ .. note::
+
+ DBAPI driver support for fractional seconds may
+ be limited; current support includes
+ MySQL Connector/Python.
+
+ """
+ super(DATETIME, self).__init__(timezone=timezone)
+ self.fsp = fsp
+
+
+class YEAR(sqltypes.TypeEngine):
+ """MySQL YEAR type, for single byte storage of years 1901-2155."""
+
+ __visit_name__ = "YEAR"
+
+ def __init__(self, display_width=None):
+ self.display_width = display_width
+
+
+class TEXT(_StringType, sqltypes.TEXT):
+ """MySQL TEXT type, for text up to 2^16 characters."""
+
+ __visit_name__ = "TEXT"
+
+ def __init__(self, length=None, **kw):
+ """Construct a TEXT.
+
+ :param length: Optional, if provided the server may optimize storage
+ by substituting the smallest TEXT type sufficient to store
+ ``length`` characters.
+
+ :param charset: Optional, a column-level character set for this string
+          value. Takes precedence over the 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+          value. Takes precedence over the 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(TEXT, self).__init__(length=length, **kw)
+
+
+class TINYTEXT(_StringType):
+ """MySQL TINYTEXT type, for text up to 2^8 characters."""
+
+ __visit_name__ = "TINYTEXT"
+
+ def __init__(self, **kwargs):
+ """Construct a TINYTEXT.
+
+ :param charset: Optional, a column-level character set for this string
+          value. Takes precedence over the 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+          value. Takes precedence over the 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(TINYTEXT, self).__init__(**kwargs)
+
+
+class MEDIUMTEXT(_StringType):
+ """MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
+
+ __visit_name__ = "MEDIUMTEXT"
+
+ def __init__(self, **kwargs):
+ """Construct a MEDIUMTEXT.
+
+ :param charset: Optional, a column-level character set for this string
+          value. Takes precedence over the 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+          value. Takes precedence over the 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(MEDIUMTEXT, self).__init__(**kwargs)
+
+
+class LONGTEXT(_StringType):
+ """MySQL LONGTEXT type, for text up to 2^32 characters."""
+
+ __visit_name__ = "LONGTEXT"
+
+ def __init__(self, **kwargs):
+ """Construct a LONGTEXT.
+
+ :param charset: Optional, a column-level character set for this string
+          value. Takes precedence over the 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+          value. Takes precedence over the 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(LONGTEXT, self).__init__(**kwargs)
+
+
+class VARCHAR(_StringType, sqltypes.VARCHAR):
+ """MySQL VARCHAR type, for variable-length character data."""
+
+ __visit_name__ = "VARCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct a VARCHAR.
+
+ :param charset: Optional, a column-level character set for this string
+          value. Takes precedence over the 'ascii' or 'unicode' short-hand.
+
+ :param collation: Optional, a column-level collation for this string
+          value. Takes precedence over the 'binary' short-hand.
+
+ :param ascii: Defaults to False: short-hand for the ``latin1``
+ character set, generates ASCII in schema.
+
+ :param unicode: Defaults to False: short-hand for the ``ucs2``
+ character set, generates UNICODE in schema.
+
+ :param national: Optional. If true, use the server's configured
+ national character set.
+
+ :param binary: Defaults to False: short-hand, pick the binary
+ collation type that matches the column's character set. Generates
+ BINARY in schema. This does not affect the type of data stored,
+ only the collation of character data.
+
+ """
+ super(VARCHAR, self).__init__(length=length, **kwargs)
+
+
+class CHAR(_StringType, sqltypes.CHAR):
+ """MySQL CHAR type, for fixed-length character data."""
+
+ __visit_name__ = "CHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct a CHAR.
+
+ :param length: Maximum data length, in characters.
+
+ :param binary: Optional, use the default binary collation for the
+ national character set. This does not affect the type of data
+ stored, use a BINARY type for binary data.
+
+ :param collation: Optional, request a particular collation. Must be
+ compatible with the national character set.
+
+ """
+ super(CHAR, self).__init__(length=length, **kwargs)
+
+ @classmethod
+ def _adapt_string_for_cast(self, type_):
+ # copy the given string type into a CHAR
+ # for the purposes of rendering a CAST expression
+ type_ = sqltypes.to_instance(type_)
+ if isinstance(type_, sqltypes.CHAR):
+ return type_
+ elif isinstance(type_, _StringType):
+ return CHAR(
+ length=type_.length,
+ charset=type_.charset,
+ collation=type_.collation,
+ ascii=type_.ascii,
+ binary=type_.binary,
+ unicode=type_.unicode,
+ national=False, # not supported in CAST
+ )
+ else:
+ return CHAR(length=type_.length)
+
+
+class NVARCHAR(_StringType, sqltypes.NVARCHAR):
+ """MySQL NVARCHAR type.
+
+ For variable-length character data in the server's configured national
+ character set.
+ """
+
+ __visit_name__ = "NVARCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct an NVARCHAR.
+
+ :param length: Maximum data length, in characters.
+
+ :param binary: Optional, use the default binary collation for the
+ national character set. This does not affect the type of data
+ stored, use a BINARY type for binary data.
+
+ :param collation: Optional, request a particular collation. Must be
+ compatible with the national character set.
+
+ """
+ kwargs["national"] = True
+ super(NVARCHAR, self).__init__(length=length, **kwargs)
+
+
+class NCHAR(_StringType, sqltypes.NCHAR):
+ """MySQL NCHAR type.
+
+ For fixed-length character data in the server's configured national
+ character set.
+ """
+
+ __visit_name__ = "NCHAR"
+
+ def __init__(self, length=None, **kwargs):
+ """Construct an NCHAR.
+
+ :param length: Maximum data length, in characters.
+
+ :param binary: Optional, use the default binary collation for the
+ national character set. This does not affect the type of data
+ stored, use a BINARY type for binary data.
+
+ :param collation: Optional, request a particular collation. Must be
+ compatible with the national character set.
+
+ """
+ kwargs["national"] = True
+ super(NCHAR, self).__init__(length=length, **kwargs)
+
+
+class TINYBLOB(sqltypes._Binary):
+ """MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
+
+ __visit_name__ = "TINYBLOB"
+
+
+class MEDIUMBLOB(sqltypes._Binary):
+ """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
+
+ __visit_name__ = "MEDIUMBLOB"
+
+
+class LONGBLOB(sqltypes._Binary):
+ """MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
+
+ __visit_name__ = "LONGBLOB"
diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py
new file mode 100644
index 0000000..c83e057
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/__init__.py
@@ -0,0 +1,58 @@
+# oracle/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import cx_oracle # noqa
+from .base import BFILE
+from .base import BINARY_DOUBLE
+from .base import BINARY_FLOAT
+from .base import BLOB
+from .base import CHAR
+from .base import CLOB
+from .base import DATE
+from .base import DOUBLE_PRECISION
+from .base import FLOAT
+from .base import INTERVAL
+from .base import LONG
+from .base import NCHAR
+from .base import NCLOB
+from .base import NUMBER
+from .base import NVARCHAR
+from .base import NVARCHAR2
+from .base import RAW
+from .base import ROWID
+from .base import TIMESTAMP
+from .base import VARCHAR
+from .base import VARCHAR2
+
+
+base.dialect = dialect = cx_oracle.dialect
+
+__all__ = (
+ "VARCHAR",
+ "NVARCHAR",
+ "CHAR",
+ "NCHAR",
+ "DATE",
+ "NUMBER",
+ "BLOB",
+ "BFILE",
+ "CLOB",
+ "NCLOB",
+ "TIMESTAMP",
+ "RAW",
+ "FLOAT",
+ "DOUBLE_PRECISION",
+ "BINARY_DOUBLE",
+ "BINARY_FLOAT",
+ "LONG",
+ "dialect",
+ "INTERVAL",
+ "VARCHAR2",
+ "NVARCHAR2",
+ "ROWID",
+)
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
new file mode 100644
index 0000000..77f0dbd
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -0,0 +1,2522 @@
+# oracle/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: oracle
+ :name: Oracle
+ :full_support: 11.2, 18c
+ :normal_support: 11+
+ :best_effort: 8+
+
+
+Auto Increment Behavior
+-----------------------
+
+SQLAlchemy Table objects which include integer primary keys are usually
+assumed to have "autoincrementing" behavior, meaning they can generate their
+own primary key values upon INSERT. For use within Oracle, two options are
+available, which are the use of IDENTITY columns (Oracle 12 and above only)
+or the association of a SEQUENCE with the column.
+
+Specifying GENERATED AS IDENTITY (Oracle 12 and above)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Starting from version 12 Oracle can make use of identity columns using
+the :class:`_sql.Identity` to specify the autoincrementing behavior::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Identity(start=3), primary_key=True),
+ Column(...), ...
+ )
+
+The CREATE TABLE for the above :class:`_schema.Table` object would be:
+
+.. sourcecode:: sql
+
+ CREATE TABLE mytable (
+ id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 3),
+ ...,
+ PRIMARY KEY (id)
+ )
+
+The :class:`_schema.Identity` object supports many options to control the
+"autoincrementing" behavior of the column, like the starting value, the
+incrementing value, etc.
+In addition to the standard options, Oracle supports setting
+:paramref:`_schema.Identity.always` to ``None`` to use the default
+generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports
+setting :paramref:`_schema.Identity.on_null` to ``True`` to specify ON NULL
+in conjunction with a 'BY DEFAULT' identity column.
+
+Using a SEQUENCE (all Oracle versions)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Older versions of Oracle had no "autoincrement"
+feature; SQLAlchemy relies upon sequences to produce these values. With the
+older Oracle versions, *a sequence must always be explicitly specified to
+enable autoincrement*. This diverges from the majority of documentation
+examples which assume the usage of an autoincrement-capable database. To
+specify sequences, use the sqlalchemy.schema.Sequence object which is passed
+to a Column construct::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Sequence('id_seq'), primary_key=True),
+ Column(...), ...
+ )
+
+This step is also required when using table reflection, i.e. autoload_with=engine::
+
+ t = Table('mytable', metadata,
+ Column('id', Integer, Sequence('id_seq'), primary_key=True),
+ autoload_with=engine
+ )
+
+.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
+ in a :class:`_schema.Column` to specify the option of an autoincrementing
+ column.
+
+.. _oracle_isolation_level:
+
+Transaction Isolation Level / Autocommit
+----------------------------------------
+
+The Oracle database supports "READ COMMITTED" and "SERIALIZABLE" modes of
+isolation. The AUTOCOMMIT isolation level is also supported by the cx_Oracle
+dialect.
+
+To set using per-connection execution options::
+
+ connection = engine.connect()
+ connection = connection.execution_options(
+ isolation_level="AUTOCOMMIT"
+ )
+
+For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle dialect sets the
+level at the session level using ``ALTER SESSION``, which is reverted back
+to its default setting when the connection is returned to the connection
+pool.
+
+Valid values for ``isolation_level`` include:
+
+* ``READ COMMITTED``
+* ``AUTOCOMMIT``
+* ``SERIALIZABLE``
+
+.. note:: The implementation for the
+ :meth:`_engine.Connection.get_isolation_level` method as implemented by the
+ Oracle dialect necessarily forces the start of a transaction using the
+ Oracle LOCAL_TRANSACTION_ID function; otherwise no level is normally
+ readable.
+
+ Additionally, the :meth:`_engine.Connection.get_isolation_level` method will
+ raise an exception if the ``v$transaction`` view is not available due to
+ permissions or other reasons, which is a common occurrence in Oracle
+ installations.
+
+ The cx_Oracle dialect attempts to call the
+ :meth:`_engine.Connection.get_isolation_level` method when the dialect makes
+ its first connection to the database in order to acquire the
+  "default" isolation level. This default level is necessary so that the level
+ can be reset on a connection after it has been temporarily modified using
+ :meth:`_engine.Connection.execution_options` method. In the common event
+ that the :meth:`_engine.Connection.get_isolation_level` method raises an
+ exception due to ``v$transaction`` not being readable as well as any other
+ database-related failure, the level is assumed to be "READ COMMITTED". No
+ warning is emitted for this initial first-connect condition as it is
+ expected to be a common restriction on Oracle databases.
+
+.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_oracle dialect
+ as well as the notion of a default isolation level
+
+.. versionadded:: 1.3.21 Added support for SERIALIZABLE as well as live
+ reading of the isolation level.
+
+.. versionchanged:: 1.3.22 In the event that the default isolation
+ level cannot be read due to permissions on the v$transaction view as
+ is common in Oracle installations, the default isolation level is hardcoded
+ to "READ COMMITTED" which was the behavior prior to 1.3.21.
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+Identifier Casing
+-----------------
+
+In Oracle, the data dictionary represents all case insensitive identifier
+names using UPPERCASE text. SQLAlchemy on the other hand considers an
+all-lower case identifier name to be case insensitive. The Oracle dialect
+converts all case insensitive identifiers to and from those two formats during
+schema level communication, such as reflection of tables and indexes. Using
+an UPPERCASE name on the SQLAlchemy side indicates a case sensitive
+identifier, and SQLAlchemy will quote the name - this will cause mismatches
+against data dictionary data received from Oracle, so unless identifier names
+have been truly created as case sensitive (i.e. using quoted names), all
+lowercase names should be used on the SQLAlchemy side.
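+
+For example, a minimal sketch of the two behaviors (illustrative only)::
+
+    from sqlalchemy import Column, Integer, MetaData, Table
+
+    metadata = MetaData()
+
+    # case insensitive; matches the uppercase ACCOUNTS entry in the
+    # Oracle data dictionary
+    accounts = Table("accounts", metadata, Column("id", Integer, primary_key=True))
+
+    # treated as case sensitive; SQLAlchemy quotes it as "ACCOUNTS_AUDIT"
+    audit = Table("ACCOUNTS_AUDIT", metadata, Column("id", Integer, primary_key=True))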
+
+.. _oracle_max_identifier_lengths:
+
+Max Identifier Lengths
+----------------------
+
+Oracle has changed the default max identifier length as of Oracle Server
+version 12.2. Prior to this version, the length was 30, and for 12.2 and
+greater it is now 128. This change impacts SQLAlchemy in the area of
+generated SQL label names as well as the generation of constraint names,
+particularly in the case where the constraint naming convention feature
+described at :ref:`constraint_naming_conventions` is being used.
+
+To assist with this change and others, Oracle includes the concept of a
+"compatibility" version, which is a version number that is independent of the
+actual server version in order to assist with migration of Oracle databases,
+and may be configured within the Oracle server itself. This compatibility
+version is retrieved using the query ``SELECT value FROM v$parameter WHERE
+name = 'compatible';``. The SQLAlchemy Oracle dialect, when tasked with
+determining the default max identifier length, will attempt to use this query
+upon first connect in order to determine the effective compatibility version of
+the server, which determines what the maximum allowed identifier length is for
+the server. If the ``v$parameter`` view is not available, the server version
+information is used instead.
+
+As of SQLAlchemy 1.4, the default max identifier length for the Oracle dialect
+is 128 characters. Upon first connect, the compatibility version is detected
+and if it is less than Oracle version 12.2, the max identifier length is
+changed to be 30 characters. In all cases, setting the
+:paramref:`_sa.create_engine.max_identifier_length` parameter will bypass this
+change and the value given will be used as is::
+
+ engine = create_engine(
+ "oracle+cx_oracle://scott:tiger@oracle122",
+ max_identifier_length=30)
+
+The maximum identifier length comes into play both when generating anonymized
+SQL labels in SELECT statements and, more crucially, when generating constraint
+names from a naming convention. It is this area that has created the need for
+SQLAlchemy to change this default conservatively. For example, the following
+naming convention produces two very different constraint names based on the
+identifier length::
+
+ from sqlalchemy import Column
+ from sqlalchemy import Index
+ from sqlalchemy import Integer
+ from sqlalchemy import MetaData
+ from sqlalchemy import Table
+ from sqlalchemy.dialects import oracle
+ from sqlalchemy.schema import CreateIndex
+
+ m = MetaData(naming_convention={"ix": "ix_%(column_0N_name)s"})
+
+ t = Table(
+ "t",
+ m,
+ Column("some_column_name_1", Integer),
+ Column("some_column_name_2", Integer),
+ Column("some_column_name_3", Integer),
+ )
+
+ ix = Index(
+ None,
+ t.c.some_column_name_1,
+ t.c.some_column_name_2,
+ t.c.some_column_name_3,
+ )
+
+ oracle_dialect = oracle.dialect(max_identifier_length=30)
+ print(CreateIndex(ix).compile(dialect=oracle_dialect))
+
+With an identifier length of 30, the above CREATE INDEX looks like::
+
+ CREATE INDEX ix_some_column_name_1s_70cd ON t
+ (some_column_name_1, some_column_name_2, some_column_name_3)
+
+However with length=128, it becomes::
+
+ CREATE INDEX ix_some_column_name_1some_column_name_2some_column_name_3 ON t
+ (some_column_name_1, some_column_name_2, some_column_name_3)
+
+Applications which have run versions of SQLAlchemy prior to 1.4 on an Oracle
+server version 12.2 or greater are therefore subject to the scenario of a
+database migration that wishes to "DROP CONSTRAINT" on a name that was
+previously generated with the shorter length. This migration will fail when
+the identifier length is changed without the name of the index or constraint
+first being adjusted. Such applications are strongly advised to make use of
+:paramref:`_sa.create_engine.max_identifier_length`
+in order to maintain control
+of the generation of truncated names, and to fully review and test all database
+migrations in a staging environment when changing this value to ensure that the
+impact of this change has been mitigated.
+
+.. versionchanged:: 1.4 the default max_identifier_length for Oracle is 128
+ characters, which is adjusted down to 30 upon first connect if an older
+ version of Oracle server (compatibility version < 12.2) is detected.
+
+
+LIMIT/OFFSET/FETCH Support
+--------------------------
+
+Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` currently
+use an emulated approach for LIMIT / OFFSET based on window functions, which
+involves creation of a subquery using ``ROW_NUMBER`` that is prone to
+performance issues as well as SQL construction issues for complex statements.
+However, this approach is supported by all Oracle versions. See notes below.
+
+When using Oracle 12c and above, use the :meth:`_sql.Select.fetch` method
+instead; this will render the more modern
+``FETCH FIRST N ROW / OFFSET N ROWS`` syntax.
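+
+For example, a minimal sketch assuming an existing ``user_table``
+:class:`_schema.Table` (illustrative only)::
+
+    from sqlalchemy import select
+
+    # renders OFFSET ... ROWS FETCH FIRST ... ROWS ONLY rather than the
+    # ROWNUM subquery emulation
+    stmt = select(user_table).order_by(user_table.c.id).offset(10).fetch(20)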
+
+Notes on LIMIT / OFFSET emulation (when fetch() method cannot be used)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If using :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset`,
+or with the ORM the :meth:`_orm.Query.limit` and :meth:`_orm.Query.offset` methods,
+and the :meth:`_sql.Select.fetch` method **cannot** be used instead, the following
+notes apply:
+
+* SQLAlchemy currently makes use of ROWNUM to achieve
+ LIMIT/OFFSET; the exact methodology is taken from
+ https://blogs.oracle.com/oraclemagazine/on-rownum-and-limiting-results .
+
+* the "FIRST_ROWS()" optimization keyword is not used by default. To enable
+ the usage of this optimization directive, specify ``optimize_limits=True``
+ to :func:`_sa.create_engine`.
+
+ .. versionchanged:: 1.4
+ The Oracle dialect renders limit/offset integer values using a "post
+ compile" scheme which renders the integer directly before passing the
+ statement to the cursor for execution. The ``use_binds_for_limits`` flag
+ no longer has an effect.
+
+ .. seealso::
+
+ :ref:`change_4808`.
+
+* A future release may use ``FETCH FIRST N ROW / OFFSET N ROWS`` automatically
+ when :meth:`_sql.Select.limit`, :meth:`_sql.Select.offset`, :meth:`_orm.Query.limit`,
+ :meth:`_orm.Query.offset` are used.
+
+.. _oracle_returning:
+
+RETURNING Support
+-----------------
+
+The Oracle database supports a limited form of RETURNING, in order to retrieve
+result sets of matched rows from INSERT, UPDATE and DELETE statements.
+Oracle's RETURNING..INTO syntax only supports one row being returned, as it
+relies upon OUT parameters in order to function. In addition, supported
+DBAPIs have further limitations (see :ref:`cx_oracle_returning`).
+
+SQLAlchemy's "implicit returning" feature, which employs RETURNING within an
+INSERT and sometimes an UPDATE statement in order to fetch newly generated
+primary key values and other SQL defaults and expressions, is normally enabled
+on the Oracle backend. By default, "implicit returning" typically only
+fetches the value of a single ``nextval(some_seq)`` expression embedded into
+an INSERT in order to increment a sequence within an INSERT statement and get
+the value back at the same time. To disable this feature across the board,
+specify ``implicit_returning=False`` to :func:`_sa.create_engine`::
+
+ engine = create_engine("oracle://scott:tiger@dsn",
+ implicit_returning=False)
+
+Implicit returning can also be disabled on a table-by-table basis as a table
+option::
+
+ # Core Table
+ my_table = Table("my_table", metadata, ..., implicit_returning=False)
+
+
+ # declarative
+ class MyClass(Base):
+ __tablename__ = 'my_table'
+ __table_args__ = {"implicit_returning": False}
+
+.. seealso::
+
+ :ref:`cx_oracle_returning` - additional cx_oracle-specific restrictions on
+ implicit returning.
+
+ON UPDATE CASCADE
+-----------------
+
+Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based
+solution is available at
+https://asktom.oracle.com/tkyte/update_cascade/index.html .
+
+When using the SQLAlchemy ORM, the ORM has limited ability to manually issue
+cascading updates - specify ForeignKey objects using the
+"deferrable=True, initially='deferred'" keyword arguments,
+and specify "passive_updates=False" on each relationship().
+
+Oracle 8 Compatibility
+----------------------
+
+When Oracle 8 is detected, the dialect internally configures itself to the
+following behaviors:
+
+* the use_ansi flag is set to False. This has the effect of converting all
+ JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN
+ makes use of Oracle's (+) operator.
+
+* the NVARCHAR2 and NCLOB datatypes are no longer generated as DDL when
+ the :class:`~sqlalchemy.types.Unicode` is used - VARCHAR2 and CLOB are
+  issued instead. This is because these types don't seem to work correctly on
+ Oracle 8 even though they are available. The
+ :class:`~sqlalchemy.types.NVARCHAR` and
+ :class:`~sqlalchemy.dialects.oracle.NCLOB` types will always generate
+ NVARCHAR2 and NCLOB.
+
+* the "native unicode" mode is disabled when using cx_oracle, i.e. SQLAlchemy
+ encodes all Python unicode objects to "string" before passing in as bind
+ parameters.
+
+Synonym/DBLINK Reflection
+-------------------------
+
+When using reflection with Table objects, the dialect can optionally search
+for tables indicated by synonyms, either in local or remote schemas or
+accessed over DBLINK, by passing the flag ``oracle_resolve_synonyms=True`` as
+a keyword argument to the :class:`_schema.Table` construct::
+
+ some_table = Table('some_table', autoload_with=some_engine,
+ oracle_resolve_synonyms=True)
+
+When this flag is set, the given name (such as ``some_table`` above) will
+be searched not just in the ``ALL_TABLES`` view, but also within the
+``ALL_SYNONYMS`` view to see if this name is actually a synonym to another
+name. If the synonym is located and refers to a DBLINK, the oracle dialect
+knows how to locate the table's information using DBLINK syntax (e.g.
+``@dblink``).
+
+``oracle_resolve_synonyms`` is accepted wherever reflection arguments are
+accepted, including methods such as :meth:`_schema.MetaData.reflect` and
+:meth:`_reflection.Inspector.get_columns`.
+
+If synonyms are not in use, this flag should be left disabled.
+
+.. _oracle_constraint_reflection:
+
+Constraint Reflection
+---------------------
+
+The Oracle dialect can return information about foreign key, unique, and
+CHECK constraints, as well as indexes on tables.
+
+Raw information regarding these constraints can be acquired using
+:meth:`_reflection.Inspector.get_foreign_keys`,
+:meth:`_reflection.Inspector.get_unique_constraints`,
+:meth:`_reflection.Inspector.get_check_constraints`, and
+:meth:`_reflection.Inspector.get_indexes`.
+
+.. versionchanged:: 1.2 The Oracle dialect can now reflect UNIQUE and
+ CHECK constraints.
+
+When using reflection at the :class:`_schema.Table` level, the
+:class:`_schema.Table`
+will also include these constraints.
+
+Note the following caveats:
+
+* When using the :meth:`_reflection.Inspector.get_check_constraints` method,
+ Oracle
+ builds a special "IS NOT NULL" constraint for columns that specify
+ "NOT NULL". This constraint is **not** returned by default; to include
+ the "IS NOT NULL" constraints, pass the flag ``include_all=True``::
+
+ from sqlalchemy import create_engine, inspect
+
+ engine = create_engine("oracle+cx_oracle://s:t@dsn")
+ inspector = inspect(engine)
+ all_check_constraints = inspector.get_check_constraints(
+ "some_table", include_all=True)
+
+* in most cases, when reflecting a :class:`_schema.Table`,
+ a UNIQUE constraint will
+ **not** be available as a :class:`.UniqueConstraint` object, as Oracle
+ mirrors unique constraints with a UNIQUE index in most cases (the exception
+ seems to be when two or more unique constraints represent the same columns);
+ the :class:`_schema.Table` will instead represent these using
+ :class:`.Index`
+ with the ``unique=True`` flag set.
+
+* Oracle creates an implicit index for the primary key of a table; this index
+ is **excluded** from all index results.
+
+* the list of columns reflected for an index will not include column names
+ that start with SYS_NC.
+
+Table names with SYSTEM/SYSAUX tablespaces
+-------------------------------------------
+
+The :meth:`_reflection.Inspector.get_table_names` and
+:meth:`_reflection.Inspector.get_temp_table_names`
+methods each return a list of table names for the current engine. These methods
+are also part of the reflection which occurs within an operation such as
+:meth:`_schema.MetaData.reflect`. By default,
+these operations exclude the ``SYSTEM``
+and ``SYSAUX`` tablespaces from the operation. In order to change this, the
+default list of tablespaces excluded can be changed at the engine level using
+the ``exclude_tablespaces`` parameter::
+
+ # exclude SYSAUX and SOME_TABLESPACE, but not SYSTEM
+ e = create_engine(
+ "oracle://scott:tiger@xe",
+ exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"])
+
+.. versionadded:: 1.1
+
+DateTime Compatibility
+----------------------
+
+Oracle has no datatype known as ``DATETIME``; it instead has only ``DATE``,
+which can actually store a date and time value. For this reason, the Oracle
+dialect provides a type :class:`_oracle.DATE` which is a subclass of
+:class:`.DateTime`. This type has no special behavior, and is only
+present as a "marker" for this type; additionally, when a database column
+is reflected and the type is reported as ``DATE``, the time-supporting
+:class:`_oracle.DATE` type is used.
+
+.. versionchanged:: 0.9.4 Added :class:`_oracle.DATE` to subclass
+ :class:`.DateTime`. This is a change as previous versions
+ would reflect a ``DATE`` column as :class:`_types.DATE`, which subclasses
+ :class:`.Date`. The only significance here is for schemes that are
+ examining the type of column for use in special Python translations or
+ for migrating schemas to other database backends.
+
+.. _oracle_table_options:
+
+Oracle Table Options
+-------------------------
+
+The CREATE TABLE phrase supports the following options with Oracle
+in conjunction with the :class:`_schema.Table` construct:
+
+
+* ``ON COMMIT``::
+
+ Table(
+ "some_table", metadata, ...,
+ prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS')
+
+.. versionadded:: 1.0.0
+
+* ``COMPRESS``::
+
+ Table('mytable', metadata, Column('data', String(32)),
+ oracle_compress=True)
+
+ Table('mytable', metadata, Column('data', String(32)),
+ oracle_compress=6)
+
+ The ``oracle_compress`` parameter accepts either an integer compression
+ level, or ``True`` to use the default compression level.
+
+.. versionadded:: 1.0.0
+
+.. _oracle_index_options:
+
+Oracle Specific Index Options
+-----------------------------
+
+Bitmap Indexes
+~~~~~~~~~~~~~~
+
+You can specify the ``oracle_bitmap`` parameter to create a bitmap index
+instead of a B-tree index::
+
+ Index('my_index', my_table.c.data, oracle_bitmap=True)
+
+Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not
+check for such limitations, only the database will.
+
+.. versionadded:: 1.0.0
+
+Index compression
+~~~~~~~~~~~~~~~~~
+
+Oracle has a more efficient storage mode for indexes containing lots of
+repeated values. Use the ``oracle_compress`` parameter to turn on key
+compression::
+
+ Index('my_index', my_table.c.data, oracle_compress=True)
+
+ Index('my_index', my_table.c.data1, my_table.c.data2, unique=True,
+ oracle_compress=1)
+
+The ``oracle_compress`` parameter accepts either an integer specifying the
+number of prefix columns to compress, or ``True`` to use the default (all
+columns for non-unique indexes, all but the last column for unique indexes).
+
+.. versionadded:: 1.0.0
+
+""" # noqa
+
+from itertools import groupby
+import re
+
+from ... import Computed
+from ... import exc
+from ... import schema as sa_schema
+from ... import sql
+from ... import util
+from ...engine import default
+from ...engine import reflection
+from ...sql import compiler
+from ...sql import expression
+from ...sql import sqltypes
+from ...sql import util as sql_util
+from ...sql import visitors
+from ...types import BLOB
+from ...types import CHAR
+from ...types import CLOB
+from ...types import FLOAT
+from ...types import INTEGER
+from ...types import NCHAR
+from ...types import NVARCHAR
+from ...types import TIMESTAMP
+from ...types import VARCHAR
+from ...util import compat
+
+RESERVED_WORDS = set(
+ "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN "
+ "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED "
+ "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE "
+ "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE "
+ "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES "
+ "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS "
+ "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER "
+ "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR "
+ "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split()
+)
+
+NO_ARG_FNS = set(
+ "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split()
+)
+
+
+class RAW(sqltypes._Binary):
+ __visit_name__ = "RAW"
+
+
+OracleRaw = RAW
+
+
+class NCLOB(sqltypes.Text):
+ __visit_name__ = "NCLOB"
+
+
+class VARCHAR2(VARCHAR):
+ __visit_name__ = "VARCHAR2"
+
+
+NVARCHAR2 = NVARCHAR
+
+
+class NUMBER(sqltypes.Numeric, sqltypes.Integer):
+ __visit_name__ = "NUMBER"
+
+ def __init__(self, precision=None, scale=None, asdecimal=None):
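+        # by default, return Decimal objects only when a positive scale is
+        # present; NUMBER with scale 0 or None behaves as an integer type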
+ if asdecimal is None:
+ asdecimal = bool(scale and scale > 0)
+
+ super(NUMBER, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal
+ )
+
+ def adapt(self, impltype):
+ ret = super(NUMBER, self).adapt(impltype)
+ # leave a hint for the DBAPI handler
+ ret._is_oracle_number = True
+ return ret
+
+ @property
+ def _type_affinity(self):
+ if bool(self.scale and self.scale > 0):
+ return sqltypes.Numeric
+ else:
+ return sqltypes.Integer
+
+
+class DOUBLE_PRECISION(sqltypes.Float):
+ __visit_name__ = "DOUBLE_PRECISION"
+
+
+class BINARY_DOUBLE(sqltypes.Float):
+ __visit_name__ = "BINARY_DOUBLE"
+
+
+class BINARY_FLOAT(sqltypes.Float):
+ __visit_name__ = "BINARY_FLOAT"
+
+
+class BFILE(sqltypes.LargeBinary):
+ __visit_name__ = "BFILE"
+
+
+class LONG(sqltypes.Text):
+ __visit_name__ = "LONG"
+
+
+class DATE(sqltypes.DateTime):
+ """Provide the oracle DATE type.
+
+ This type has no special Python behavior, except that it subclasses
+ :class:`_types.DateTime`; this is to suit the fact that the Oracle
+ ``DATE`` type supports a time value.
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ __visit_name__ = "DATE"
+
+ def _compare_type_affinity(self, other):
+ return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+ __visit_name__ = "INTERVAL"
+
+ def __init__(self, day_precision=None, second_precision=None):
+ """Construct an INTERVAL.
+
+ Note that only DAY TO SECOND intervals are currently supported.
+ This is due to a lack of support for YEAR TO MONTH intervals
+ within available DBAPIs.
+
+ :param day_precision: the day precision value. this is the number of
+ digits to store for the day field. Defaults to "2"
+ :param second_precision: the second precision value. this is the
+ number of digits to store for the fractional seconds field.
+ Defaults to "6".
+
+ """
+ self.day_precision = day_precision
+ self.second_precision = second_precision
+
+ @classmethod
+ def _adapt_from_generic_interval(cls, interval):
+ return INTERVAL(
+ day_precision=interval.day_precision,
+ second_precision=interval.second_precision,
+ )
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(
+ native=True,
+ second_precision=self.second_precision,
+ day_precision=self.day_precision,
+ )
+
+ def coerce_compared_value(self, op, value):
+ return self
+
+
+class ROWID(sqltypes.TypeEngine):
+ """Oracle ROWID type.
+
+ When used in a cast() or similar, generates ROWID.
+
+ """
+
+ __visit_name__ = "ROWID"
+
+
+class _OracleBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+
+colspecs = {
+ sqltypes.Boolean: _OracleBoolean,
+ sqltypes.Interval: INTERVAL,
+ sqltypes.DateTime: DATE,
+}
+
+ischema_names = {
+ "VARCHAR2": VARCHAR,
+ "NVARCHAR2": NVARCHAR,
+ "CHAR": CHAR,
+ "NCHAR": NCHAR,
+ "DATE": DATE,
+ "NUMBER": NUMBER,
+ "BLOB": BLOB,
+ "BFILE": BFILE,
+ "CLOB": CLOB,
+ "NCLOB": NCLOB,
+ "TIMESTAMP": TIMESTAMP,
+ "TIMESTAMP WITH TIME ZONE": TIMESTAMP,
+ "INTERVAL DAY TO SECOND": INTERVAL,
+ "RAW": RAW,
+ "FLOAT": FLOAT,
+ "DOUBLE PRECISION": DOUBLE_PRECISION,
+ "LONG": LONG,
+ "BINARY_DOUBLE": BINARY_DOUBLE,
+ "BINARY_FLOAT": BINARY_FLOAT,
+}
+
+
+class OracleTypeCompiler(compiler.GenericTypeCompiler):
+ # Note:
+ # Oracle DATE == DATETIME
+ # Oracle does not allow milliseconds in DATE
+ # Oracle does not support TIME columns
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_DATE(type_, **kw)
+
+ def visit_float(self, type_, **kw):
+ return self.visit_FLOAT(type_, **kw)
+
+ def visit_unicode(self, type_, **kw):
+ if self.dialect._use_nchar_for_unicode:
+ return self.visit_NVARCHAR2(type_, **kw)
+ else:
+ return self.visit_VARCHAR2(type_, **kw)
+
+ def visit_INTERVAL(self, type_, **kw):
+ return "INTERVAL DAY%s TO SECOND%s" % (
+ type_.day_precision is not None
+ and "(%d)" % type_.day_precision
+ or "",
+ type_.second_precision is not None
+ and "(%d)" % type_.second_precision
+ or "",
+ )
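+ # e.g. INTERVAL(day_precision=2, second_precision=6) renders as
+ # "INTERVAL DAY(2) TO SECOND(6)"; with no precision arguments it
+ # renders "INTERVAL DAY TO SECOND".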
+
+ def visit_LONG(self, type_, **kw):
+ return "LONG"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ if type_.timezone:
+ return "TIMESTAMP WITH TIME ZONE"
+ else:
+ return "TIMESTAMP"
+
+ def visit_DOUBLE_PRECISION(self, type_, **kw):
+ return self._generate_numeric(type_, "DOUBLE PRECISION", **kw)
+
+ def visit_BINARY_DOUBLE(self, type_, **kw):
+ return self._generate_numeric(type_, "BINARY_DOUBLE", **kw)
+
+ def visit_BINARY_FLOAT(self, type_, **kw):
+ return self._generate_numeric(type_, "BINARY_FLOAT", **kw)
+
+ def visit_FLOAT(self, type_, **kw):
+ # don't support conversion between decimal/binary
+ # precision yet
+ kw["no_precision"] = True
+ return self._generate_numeric(type_, "FLOAT", **kw)
+
+ def visit_NUMBER(self, type_, **kw):
+ return self._generate_numeric(type_, "NUMBER", **kw)
+
+ def _generate_numeric(
+ self, type_, name, precision=None, scale=None, no_precision=False, **kw
+ ):
+ if precision is None:
+ precision = type_.precision
+
+ if scale is None:
+ scale = getattr(type_, "scale", None)
+
+ if no_precision or precision is None:
+ return name
+ elif scale is None:
+ n = "%(name)s(%(precision)s)"
+ return n % {"name": name, "precision": precision}
+ else:
+ n = "%(name)s(%(precision)s, %(scale)s)"
+ return n % {"name": name, "precision": precision, "scale": scale}
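+ # e.g. NUMBER(10, 2) renders "NUMBER(10, 2)", NUMBER(10) renders
+ # "NUMBER(10)", and FLOAT (which passes no_precision=True) renders
+ # just "FLOAT".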
+
+ def visit_string(self, type_, **kw):
+ return self.visit_VARCHAR2(type_, **kw)
+
+ def visit_VARCHAR2(self, type_, **kw):
+ return self._visit_varchar(type_, "", "2")
+
+ def visit_NVARCHAR2(self, type_, **kw):
+ return self._visit_varchar(type_, "N", "2")
+
+ visit_NVARCHAR = visit_NVARCHAR2
+
+ def visit_VARCHAR(self, type_, **kw):
+ return self._visit_varchar(type_, "", "")
+
+ def _visit_varchar(self, type_, n, num):
+ if not type_.length:
+ return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n}
+ elif not n and self.dialect._supports_char_length:
+ varchar = "VARCHAR%(two)s(%(length)s CHAR)"
+ return varchar % {"length": type_.length, "two": num}
+ else:
+ varchar = "%(n)sVARCHAR%(two)s(%(length)s)"
+ return varchar % {"length": type_.length, "two": num, "n": n}
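+ # e.g. VARCHAR2(50) renders "VARCHAR2(50 CHAR)" when the server
+ # supports character-length semantics, while NVARCHAR2(50) renders
+ # "NVARCHAR2(50)".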
+
+ def visit_text(self, type_, **kw):
+ return self.visit_CLOB(type_, **kw)
+
+ def visit_unicode_text(self, type_, **kw):
+ if self.dialect._use_nchar_for_unicode:
+ return self.visit_NCLOB(type_, **kw)
+ else:
+ return self.visit_CLOB(type_, **kw)
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_, **kw)
+
+ def visit_big_integer(self, type_, **kw):
+ return self.visit_NUMBER(type_, precision=19, **kw)
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
+
+ def visit_RAW(self, type_, **kw):
+ if type_.length:
+ return "RAW(%(length)s)" % {"length": type_.length}
+ else:
+ return "RAW"
+
+ def visit_ROWID(self, type_, **kw):
+ return "ROWID"
+
+
+class OracleCompiler(compiler.SQLCompiler):
+ """Oracle compiler modifies the lexical structure of Select
+ statements to work under non-ANSI configured Oracle databases, if
+ the use_ansi flag is False.
+ """
+
+ compound_keywords = util.update_copy(
+ compiler.SQLCompiler.compound_keywords,
+ {expression.CompoundSelect.EXCEPT: "MINUS"},
+ )
+
+ def __init__(self, *args, **kwargs):
+ self.__wheres = {}
+ super(OracleCompiler, self).__init__(*args, **kwargs)
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ return "mod(%s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_char_length_func(self, fn, **kw):
+ return "LENGTH" + self.function_argspec(fn, **kw)
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ return "CONTAINS (%s, %s)" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_true(self, expr, **kw):
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ return "0"
+
+ def get_cte_preamble(self, recursive):
+ return "WITH"
+
+ def get_select_hint_text(self, byfroms):
+ return " ".join("/*+ %s */" % text for table, text in byfroms.items())
+
+ def function_argspec(self, fn, **kw):
+ if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS:
+ return compiler.SQLCompiler.function_argspec(self, fn, **kw)
+ else:
+ return ""
+
+ def visit_function(self, func, **kw):
+ text = super(OracleCompiler, self).visit_function(func, **kw)
+ if kw.get("asfrom", False):
+ text = "TABLE (%s)" % func
+ return text
+
+ def visit_table_valued_column(self, element, **kw):
+ text = super(OracleCompiler, self).visit_table_valued_column(
+ element, **kw
+ )
+ text = "COLUMN_VALUE " + text
+ return text
+
+ def default_from(self):
+ """Called when a ``SELECT`` statement has no froms,
+ and no ``FROM`` clause is to be appended.
+
+ The Oracle compiler tacks a "FROM DUAL" onto the statement.
+ """
+
+ return " FROM DUAL"
+
+ def visit_join(self, join, from_linter=None, **kwargs):
+ if self.dialect.use_ansi:
+ return compiler.SQLCompiler.visit_join(
+ self, join, from_linter=from_linter, **kwargs
+ )
+ else:
+ if from_linter:
+ from_linter.edges.add((join.left, join.right))
+
+ kwargs["asfrom"] = True
+ if isinstance(join.right, expression.FromGrouping):
+ right = join.right.element
+ else:
+ right = join.right
+ return (
+ self.process(join.left, from_linter=from_linter, **kwargs)
+ + ", "
+ + self.process(right, from_linter=from_linter, **kwargs)
+ )
+
+ def _get_nonansi_join_whereclause(self, froms):
+ clauses = []
+
+ def visit_join(join):
+ if join.isouter:
+ # https://docs.oracle.com/database/121/SQLRF/queries006.htm#SQLRF52354
+ # "apply the outer join operator (+) to all columns of B in
+ # the join condition in the WHERE clause" - that is,
+ # unconditionally regardless of operator or the other side
+ def visit_binary(binary):
+ if isinstance(
+ binary.left, expression.ColumnClause
+ ) and join.right.is_derived_from(binary.left.table):
+ binary.left = _OuterJoinColumn(binary.left)
+ elif isinstance(
+ binary.right, expression.ColumnClause
+ ) and join.right.is_derived_from(binary.right.table):
+ binary.right = _OuterJoinColumn(binary.right)
+
+ clauses.append(
+ visitors.cloned_traverse(
+ join.onclause, {}, {"binary": visit_binary}
+ )
+ )
+ else:
+ clauses.append(join.onclause)
+
+ for j in join.left, join.right:
+ if isinstance(j, expression.Join):
+ visit_join(j)
+ elif isinstance(j, expression.FromGrouping):
+ visit_join(j.element)
+
+ for f in froms:
+ if isinstance(f, expression.Join):
+ visit_join(f)
+
+ if not clauses:
+ return None
+ else:
+ return sql.and_(*clauses)
+
+ def visit_outer_join_column(self, vc, **kw):
+ return self.process(vc.column, **kw) + "(+)"
+
+ def visit_sequence(self, seq, **kw):
+ return self.preparer.format_sequence(seq) + ".nextval"
+
+ def get_render_as_alias_suffix(self, alias_name_text):
+ """Oracle doesn't like ``FROM table AS alias``"""
+
+ return " " + alias_name_text
+
+ def returning_clause(self, stmt, returning_cols):
+ columns = []
+ binds = []
+
+ for i, column in enumerate(
+ expression._select_iterables(returning_cols)
+ ):
+ if (
+ self.isupdate
+ and isinstance(column, sa_schema.Column)
+ and isinstance(column.server_default, Computed)
+ and not self.dialect._supports_update_returning_computed_cols
+ ):
+ util.warn(
+ "Computed columns don't work with Oracle UPDATE "
+ "statements that use RETURNING; the value of the column "
+ "*before* the UPDATE takes place is returned. It is "
+ "advised to not use RETURNING with an Oracle computed "
+ "column. Consider setting implicit_returning to False on "
+ "the Table object in order to avoid implicit RETURNING "
+ "clauses from being generated for this Table."
+ )
+ if column.type._has_column_expression:
+ col_expr = column.type.column_expression(column)
+ else:
+ col_expr = column
+
+ outparam = sql.outparam("ret_%d" % i, type_=column.type)
+ self.binds[outparam.key] = outparam
+ binds.append(
+ self.bindparam_string(self._truncate_bindparam(outparam))
+ )
+
+ # ensure the ExecutionContext.get_out_parameters() method is
+ # *not* called; the cx_Oracle dialect wants to handle these
+ # parameters separately
+ self.has_out_parameters = False
+
+ columns.append(self.process(col_expr, within_columns_clause=False))
+
+ self._add_to_result_map(
+ getattr(col_expr, "name", col_expr._anon_name_label),
+ getattr(col_expr, "name", col_expr._anon_name_label),
+ (
+ column,
+ getattr(column, "name", None),
+ getattr(column, "key", None),
+ ),
+ column.type,
+ )
+
+ return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds)
+
+ def translate_select_structure(self, select_stmt, **kwargs):
+ select = select_stmt
+
+ if not getattr(select, "_oracle_visit", None):
+ if not self.dialect.use_ansi:
+ froms = self._display_froms_for_select(
+ select, kwargs.get("asfrom", False)
+ )
+ whereclause = self._get_nonansi_join_whereclause(froms)
+ if whereclause is not None:
+ select = select.where(whereclause)
+ select._oracle_visit = True
+
+ # if fetch is used this is not needed
+ if (
+ select._has_row_limiting_clause
+ and select._fetch_clause is None
+ ):
+ limit_clause = select._limit_clause
+ offset_clause = select._offset_clause
+
+ if select._simple_int_clause(limit_clause):
+ limit_clause = limit_clause.render_literal_execute()
+
+ if select._simple_int_clause(offset_clause):
+ offset_clause = offset_clause.render_literal_execute()
+
+ # currently using form at:
+ # https://blogs.oracle.com/oraclemagazine/\
+ # on-rownum-and-limiting-results
+
+ orig_select = select
+ select = select._generate()
+ select._oracle_visit = True
+
+ # add expressions to accommodate FOR UPDATE OF
+ for_update = select._for_update_arg
+ if for_update is not None and for_update.of:
+ for_update = for_update._clone()
+ for_update._copy_internals()
+
+ for elem in for_update.of:
+ if not select.selected_columns.contains_column(elem):
+ select = select.add_columns(elem)
+
+ # Wrap the middle select and add the hint
+ inner_subquery = select.alias()
+ limitselect = sql.select(
+ *[
+ c
+ for c in inner_subquery.c
+ if orig_select.selected_columns.corresponding_column(c)
+ is not None
+ ]
+ )
+
+ if (
+ limit_clause is not None
+ and self.dialect.optimize_limits
+ and select._simple_int_clause(limit_clause)
+ ):
+ limitselect = limitselect.prefix_with(
+ expression.text(
+ "/*+ FIRST_ROWS(%s) */"
+ % self.process(limit_clause, **kwargs)
+ )
+ )
+
+ limitselect._oracle_visit = True
+ limitselect._is_wrapper = True
+
+ # add expressions to accommodate FOR UPDATE OF
+ if for_update is not None and for_update.of:
+
+ adapter = sql_util.ClauseAdapter(inner_subquery)
+ for_update.of = [
+ adapter.traverse(elem) for elem in for_update.of
+ ]
+
+ # If needed, add the limiting clause
+ if limit_clause is not None:
+ if select._simple_int_clause(limit_clause) and (
+ offset_clause is None
+ or select._simple_int_clause(offset_clause)
+ ):
+ max_row = limit_clause
+
+ if offset_clause is not None:
+ max_row = max_row + offset_clause
+
+ else:
+ max_row = limit_clause
+
+ if offset_clause is not None:
+ max_row = max_row + offset_clause
+ limitselect = limitselect.where(
+ sql.literal_column("ROWNUM") <= max_row
+ )
+
+ # If needed, add the ora_rn, and wrap again with offset.
+ if offset_clause is None:
+ limitselect._for_update_arg = for_update
+ select = limitselect
+ else:
+ limitselect = limitselect.add_columns(
+ sql.literal_column("ROWNUM").label("ora_rn")
+ )
+ limitselect._oracle_visit = True
+ limitselect._is_wrapper = True
+
+ if for_update is not None and for_update.of:
+ limitselect_cols = limitselect.selected_columns
+ for elem in for_update.of:
+ if (
+ limitselect_cols.corresponding_column(elem)
+ is None
+ ):
+ limitselect = limitselect.add_columns(elem)
+
+ limit_subquery = limitselect.alias()
+ origselect_cols = orig_select.selected_columns
+ offsetselect = sql.select(
+ *[
+ c
+ for c in limit_subquery.c
+ if origselect_cols.corresponding_column(c)
+ is not None
+ ]
+ )
+
+ offsetselect._oracle_visit = True
+ offsetselect._is_wrapper = True
+
+ if for_update is not None and for_update.of:
+ adapter = sql_util.ClauseAdapter(limit_subquery)
+ for_update.of = [
+ adapter.traverse(elem) for elem in for_update.of
+ ]
+
+ offsetselect = offsetselect.where(
+ sql.literal_column("ora_rn") > offset_clause
+ )
+
+ offsetselect._for_update_arg = for_update
+ select = offsetselect
+
+ return select
+
+ def limit_clause(self, select, **kw):
+ return ""
+
+ def visit_empty_set_expr(self, type_):
+ return "SELECT 1 FROM DUAL WHERE 1!=1"
+
+ def for_update_clause(self, select, **kw):
+ if self.is_subquery():
+ return ""
+
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of:
+ tmp += " OF " + ", ".join(
+ self.process(elem, **kw) for elem in select._for_update_arg.of
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+ if select._for_update_arg.skip_locked:
+ tmp += " SKIP LOCKED"
+
+ return tmp
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "DECODE(%s, %s, 0, 1) = 1" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "DECODE(%s, %s, 0, 1) = 0" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
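+ # e.g. column("x").is_distinct_from(column("y")) compiles to
+ # "DECODE(x, y, 0, 1) = 1"; DECODE treats two NULLs as equal, which
+ # provides the NULL-safe comparison.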
+
+ def _get_regexp_args(self, binary, kw):
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
+ flags = binary.modifiers["flags"]
+ if flags is not None:
+ flags = self.process(flags, **kw)
+ return string, pattern, flags
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ string, pattern, flags = self._get_regexp_args(binary, kw)
+ if flags is None:
+ return "REGEXP_LIKE(%s, %s)" % (string, pattern)
+ else:
+ return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return "NOT %s" % self.visit_regexp_match_op_binary(
+ binary, operator, **kw
+ )
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ string, pattern, flags = self._get_regexp_args(binary, kw)
+ replacement = self.process(binary.modifiers["replacement"], **kw)
+ if flags is None:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ )
+ else:
+ return "REGEXP_REPLACE(%s, %s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ flags,
+ )
+
+
+class OracleDDLCompiler(compiler.DDLCompiler):
+ def define_constraint_cascades(self, constraint):
+ text = ""
+ if constraint.ondelete is not None:
+ text += " ON DELETE %s" % constraint.ondelete
+
+ # Oracle has no ON UPDATE CASCADE -
+ # it's only available via triggers
+ # https://asktom.oracle.com/tkyte/update_cascade/index.html
+ if constraint.onupdate is not None:
+ util.warn(
+ "Oracle does not contain native UPDATE CASCADE "
+ "functionality - onupdates will not be rendered for foreign "
+ "keys. Consider using deferrable=True, initially='deferred' "
+ "or triggers."
+ )
+
+ return text
+
+ def visit_drop_table_comment(self, drop):
+ return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table(
+ drop.element
+ )
+
+ def visit_create_index(self, create):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ if index.dialect_options["oracle"]["bitmap"]:
+ text += "BITMAP "
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=True),
+ preparer.format_table(index.table, use_schema=True),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+ if index.dialect_options["oracle"]["compress"] is not False:
+ if index.dialect_options["oracle"]["compress"] is True:
+ text += " COMPRESS"
+ else:
+ text += " COMPRESS %d" % (
+ index.dialect_options["oracle"]["compress"]
+ )
+ return text
+
+ def post_create_table(self, table):
+ table_opts = []
+ opts = table.dialect_options["oracle"]
+
+ if opts["on_commit"]:
+ on_commit_options = opts["on_commit"].replace("_", " ").upper()
+ table_opts.append("\n ON COMMIT %s" % on_commit_options)
+
+ if opts["compress"]:
+ if opts["compress"] is True:
+ table_opts.append("\n COMPRESS")
+ else:
+ table_opts.append("\n COMPRESS FOR %s" % (opts["compress"]))
+
+ return "".join(table_opts)
+
+ def get_identity_options(self, identity_options):
+ text = super(OracleDDLCompiler, self).get_identity_options(
+ identity_options
+ )
+ text = text.replace("NO MINVALUE", "NOMINVALUE")
+ text = text.replace("NO MAXVALUE", "NOMAXVALUE")
+ text = text.replace("NO CYCLE", "NOCYCLE")
+ text = text.replace("NO ORDER", "NOORDER")
+ return text
+
+ def visit_computed_column(self, generated):
+ text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+ if generated.persisted is True:
+ raise exc.CompileError(
+ "Oracle computed columns do not support 'stored' persistence; "
+ "set the 'persisted' flag to None or False for Oracle support."
+ )
+ elif generated.persisted is False:
+ text += " VIRTUAL"
+ return text
+
+ def visit_identity_column(self, identity, **kw):
+ if identity.always is None:
+ kind = ""
+ else:
+ kind = "ALWAYS" if identity.always else "BY DEFAULT"
+ text = "GENERATED %s" % kind
+ if identity.on_null:
+ text += " ON NULL"
+ text += " AS IDENTITY"
+ options = self.get_identity_options(identity)
+ if options:
+ text += " (%s)" % options
+ return text
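+ # e.g. Identity(always=True) with no further options renders
+ # "GENERATED ALWAYS AS IDENTITY".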
+
+
+class OracleIdentifierPreparer(compiler.IdentifierPreparer):
+
+ reserved_words = {x.lower() for x in RESERVED_WORDS}
+ illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union(
+ ["_", "$"]
+ )
+
+ def _bindparam_requires_quotes(self, value):
+ """Return True if the given identifier requires quoting."""
+ lc_value = value.lower()
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ )
+
+ def format_savepoint(self, savepoint):
+ name = savepoint.ident.lstrip("_")
+ return super(OracleIdentifierPreparer, self).format_savepoint(
+ savepoint, name
+ )
+
+
+class OracleExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ "SELECT "
+ + self.identifier_preparer.format_sequence(seq)
+ + ".nextval FROM DUAL",
+ type_,
+ )
+
+
+class OracleDialect(default.DefaultDialect):
+ name = "oracle"
+ supports_statement_cache = True
+ supports_alter = True
+ supports_unicode_statements = False
+ supports_unicode_binds = False
+ max_identifier_length = 128
+
+ supports_simple_order_by_label = False
+ cte_follows_insert = True
+
+ supports_sequences = True
+ sequences_optional = False
+ postfetch_lastrowid = False
+
+ default_paramstyle = "named"
+ colspecs = colspecs
+ ischema_names = ischema_names
+ requires_name_normalize = True
+
+ supports_comments = True
+
+ supports_default_values = False
+ supports_default_metavalue = True
+ supports_empty_insert = False
+ supports_identity_columns = True
+
+ statement_compiler = OracleCompiler
+ ddl_compiler = OracleDDLCompiler
+ type_compiler = OracleTypeCompiler
+ preparer = OracleIdentifierPreparer
+ execution_ctx_cls = OracleExecutionContext
+
+ reflection_options = ("oracle_resolve_synonyms",)
+
+ _use_nchar_for_unicode = False
+
+ construct_arguments = [
+ (
+ sa_schema.Table,
+ {"resolve_synonyms": False, "on_commit": None, "compress": False},
+ ),
+ (sa_schema.Index, {"bitmap": False, "compress": False}),
+ ]
+
+ @util.deprecated_params(
+ use_binds_for_limits=(
+ "1.4",
+ "The ``use_binds_for_limits`` Oracle dialect parameter is "
+ "deprecated. The dialect now renders LIMIT /OFFSET integers "
+ "inline in all cases using a post-compilation hook, so that the "
+ "value is still represented by a 'bound parameter' on the Core "
+ "Expression side.",
+ )
+ )
+ def __init__(
+ self,
+ use_ansi=True,
+ optimize_limits=False,
+ use_binds_for_limits=None,
+ use_nchar_for_unicode=False,
+ exclude_tablespaces=("SYSTEM", "SYSAUX"),
+ **kwargs
+ ):
+ default.DefaultDialect.__init__(self, **kwargs)
+ self._use_nchar_for_unicode = use_nchar_for_unicode
+ self.use_ansi = use_ansi
+ self.optimize_limits = optimize_limits
+ self.exclude_tablespaces = exclude_tablespaces
+
+ def initialize(self, connection):
+ super(OracleDialect, self).initialize(connection)
+
+ self.implicit_returning = self.__dict__.get(
+ "implicit_returning", self.server_version_info > (10,)
+ )
+
+ if self._is_oracle_8:
+ self.colspecs = self.colspecs.copy()
+ self.colspecs.pop(sqltypes.Interval)
+ self.use_ansi = False
+
+ self.supports_identity_columns = self.server_version_info >= (12,)
+
+ def _get_effective_compat_server_version_info(self, connection):
+ # dialect does not need compat levels below 12.2, so don't query
+ # in those cases
+
+ if self.server_version_info < (12, 2):
+ return self.server_version_info
+ try:
+ compat = connection.exec_driver_sql(
+ "SELECT value FROM v$parameter WHERE name = 'compatible'"
+ ).scalar()
+ except exc.DBAPIError:
+ compat = None
+
+ if compat:
+ try:
+ return tuple(int(x) for x in compat.split("."))
+ except:
+ return self.server_version_info
+ else:
+ return self.server_version_info
+
+ @property
+ def _is_oracle_8(self):
+ return self.server_version_info and self.server_version_info < (9,)
+
+ @property
+ def _supports_table_compression(self):
+ return self.server_version_info and self.server_version_info >= (10, 1)
+
+ @property
+ def _supports_table_compress_for(self):
+ return self.server_version_info and self.server_version_info >= (11,)
+
+ @property
+ def _supports_char_length(self):
+ return not self._is_oracle_8
+
+ @property
+ def _supports_update_returning_computed_cols(self):
+ # on version 18 this error is no longer present, while it still occurs
+ # on 11; it may also work on some versions before 18
+ return self.server_version_info and self.server_version_info >= (18,)
+
+ def do_release_savepoint(self, connection, name):
+ # Oracle does not support RELEASE SAVEPOINT
+ pass
+
+ def _check_max_identifier_length(self, connection):
+ if self._get_effective_compat_server_version_info(connection) < (
+ 12,
+ 2,
+ ):
+ return 30
+ else:
+ # use the default
+ return None
+
+ def _check_unicode_returns(self, connection):
+ additional_tests = [
+ expression.cast(
+ expression.literal_column("'test nvarchar2 returns'"),
+ sqltypes.NVARCHAR(60),
+ )
+ ]
+ return super(OracleDialect, self)._check_unicode_returns(
+ connection, additional_tests
+ )
+
+ _isolation_lookup = ["READ COMMITTED", "SERIALIZABLE"]
+
+ def get_isolation_level(self, connection):
+ raise NotImplementedError("implemented by cx_Oracle dialect")
+
+ def get_default_isolation_level(self, dbapi_conn):
+ try:
+ return self.get_isolation_level(dbapi_conn)
+ except NotImplementedError:
+ raise
+ except:
+ return "READ COMMITTED"
+
+ def set_isolation_level(self, connection, level):
+ raise NotImplementedError("implemented by cx_Oracle dialect")
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ if not schema:
+ schema = self.default_schema_name
+
+ cursor = connection.execute(
+ sql.text(
+ "SELECT table_name FROM all_tables "
+ "WHERE table_name = CAST(:name AS VARCHAR2(128)) "
+ "AND owner = CAST(:schema_name AS VARCHAR2(128))"
+ ),
+ dict(
+ name=self.denormalize_name(table_name),
+ schema_name=self.denormalize_name(schema),
+ ),
+ )
+ return cursor.first() is not None
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ if not schema:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT sequence_name FROM all_sequences "
+ "WHERE sequence_name = :name AND "
+ "sequence_owner = :schema_name"
+ ),
+ dict(
+ name=self.denormalize_name(sequence_name),
+ schema_name=self.denormalize_name(schema),
+ ),
+ )
+ return cursor.first() is not None
+
+ def _get_default_schema_name(self, connection):
+ return self.normalize_name(
+ connection.exec_driver_sql(
+ "select sys_context( 'userenv', 'current_schema' ) from dual"
+ ).scalar()
+ )
+
+ def _resolve_synonym(
+ self,
+ connection,
+ desired_owner=None,
+ desired_synonym=None,
+ desired_table=None,
+ ):
+ """search for a local synonym matching the given desired owner/name.
+
+ if desired_owner is None, attempts to locate a distinct owner.
+
+ returns the actual name, owner, dblink name, and synonym name if
+ found.
+ """
+
+ q = (
+ "SELECT owner, table_owner, table_name, db_link, "
+ "synonym_name FROM all_synonyms WHERE "
+ )
+ clauses = []
+ params = {}
+ if desired_synonym:
+ clauses.append(
+ "synonym_name = CAST(:synonym_name AS VARCHAR2(128))"
+ )
+ params["synonym_name"] = desired_synonym
+ if desired_owner:
+ clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))")
+ params["desired_owner"] = desired_owner
+ if desired_table:
+ clauses.append("table_name = CAST(:tname AS VARCHAR2(128))")
+ params["tname"] = desired_table
+
+ q += " AND ".join(clauses)
+
+ result = connection.execution_options(future_result=True).execute(
+ sql.text(q), params
+ )
+ if desired_owner:
+ row = result.mappings().first()
+ if row:
+ return (
+ row["table_name"],
+ row["table_owner"],
+ row["db_link"],
+ row["synonym_name"],
+ )
+ else:
+ return None, None, None, None
+ else:
+ rows = result.mappings().all()
+ if len(rows) > 1:
+ raise AssertionError(
+ "There are multiple tables visible to the schema, you "
+ "must specify owner"
+ )
+ elif len(rows) == 1:
+ row = rows[0]
+ return (
+ row["table_name"],
+ row["table_owner"],
+ row["db_link"],
+ row["synonym_name"],
+ )
+ else:
+ return None, None, None, None
+
+ @reflection.cache
+ def _prepare_reflection_args(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ if resolve_synonyms:
+ actual_name, owner, dblink, synonym = self._resolve_synonym(
+ connection,
+ desired_owner=self.denormalize_name(schema),
+ desired_synonym=self.denormalize_name(table_name),
+ )
+ else:
+ actual_name, owner, dblink, synonym = None, None, None, None
+ if not actual_name:
+ actual_name = self.denormalize_name(table_name)
+
+ if dblink:
+ # using user_db_links here since all_db_links appears
+ # to have more restricted permissions.
+ # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
+ # will need to hear from more users if we are doing
+ # the right thing here. See [ticket:2619]
+ owner = connection.scalar(
+ sql.text(
+ "SELECT username FROM user_db_links " "WHERE db_link=:link"
+ ),
+ dict(link=dblink),
+ )
+ dblink = "@" + dblink
+ elif not owner:
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ return (actual_name, owner, dblink or "", synonym)
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = "SELECT username FROM all_users ORDER BY username"
+ cursor = connection.exec_driver_sql(s)
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ schema = self.denormalize_name(schema or self.default_schema_name)
+
+ # note that table_names() isn't loading DBLINKed or synonym'ed tables
+ if schema is None:
+ schema = self.default_schema_name
+
+ sql_str = "SELECT table_name FROM all_tables WHERE "
+ if self.exclude_tablespaces:
+ sql_str += (
+ "nvl(tablespace_name, 'no tablespace') "
+ "NOT IN (%s) AND "
+ % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ )
+ sql_str += (
+ "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL"
+ )
+
+ cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ schema = self.denormalize_name(self.default_schema_name)
+
+ sql_str = "SELECT table_name FROM all_tables WHERE "
+ if self.exclude_tablespaces:
+ sql_str += (
+ "nvl(tablespace_name, 'no tablespace') "
+ "NOT IN (%s) AND "
+ % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ )
+ sql_str += (
+ "OWNER = :owner "
+ "AND IOT_NAME IS NULL "
+ "AND DURATION IS NOT NULL"
+ )
+
+ cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ schema = self.denormalize_name(schema or self.default_schema_name)
+ s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
+ cursor = connection.execute(
+ s, dict(owner=self.denormalize_name(schema))
+ )
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_sequence_names(self, connection, schema=None, **kw):
+ if not schema:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT sequence_name FROM all_sequences "
+ "WHERE sequence_owner = :schema_name"
+ ),
+ dict(schema_name=self.denormalize_name(schema)),
+ )
+ return [self.normalize_name(row[0]) for row in cursor]
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+ options = {}
+
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ params = {"table_name": table_name}
+
+ columns = ["table_name"]
+ if self._supports_table_compression:
+ columns.append("compression")
+ if self._supports_table_compress_for:
+ columns.append("compress_for")
+
+ text = (
+ "SELECT %(columns)s "
+ "FROM ALL_TABLES%(dblink)s "
+ "WHERE table_name = CAST(:table_name AS VARCHAR(128))"
+ )
+
+ if schema is not None:
+ params["owner"] = schema
+ text += " AND owner = CAST(:owner AS VARCHAR(128)) "
+ text = text % {"dblink": dblink, "columns": ", ".join(columns)}
+
+ result = connection.execute(sql.text(text), params)
+
+ enabled = dict(DISABLED=False, ENABLED=True)
+
+ row = result.first()
+ if row:
+ if "compression" in row._fields and enabled.get(
+ row.compression, False
+ ):
+ if "compress_for" in row._fields:
+ options["oracle_compress"] = row.compress_for
+ else:
+ options["oracle_compress"] = True
+
+ return options
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ """
+
+ kw arguments can be:
+
+ oracle_resolve_synonyms
+
+ dblink
+
+ """
+
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+ columns = []
+ if self._supports_char_length:
+ char_length_col = "char_length"
+ else:
+ char_length_col = "data_length"
+
+ if self.server_version_info >= (12,):
+ identity_cols = """\
+ col.default_on_null,
+ (
+ SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS
+ FROM ALL_TAB_IDENTITY_COLS%(dblink)s id
+ WHERE col.table_name = id.table_name
+ AND col.column_name = id.column_name
+ AND col.owner = id.owner
+ ) AS identity_options""" % {
+ "dblink": dblink
+ }
+ else:
+ identity_cols = "NULL as default_on_null, NULL as identity_options"
+
+ params = {"table_name": table_name}
+
+ text = """
+ SELECT
+ col.column_name,
+ col.data_type,
+ col.%(char_length_col)s,
+ col.data_precision,
+ col.data_scale,
+ col.nullable,
+ col.data_default,
+ com.comments,
+ col.virtual_column,
+ %(identity_cols)s
+ FROM all_tab_cols%(dblink)s col
+ LEFT JOIN all_col_comments%(dblink)s com
+ ON col.table_name = com.table_name
+ AND col.column_name = com.column_name
+ AND col.owner = com.owner
+ WHERE col.table_name = CAST(:table_name AS VARCHAR2(128))
+ AND col.hidden_column = 'NO'
+ """
+ if schema is not None:
+ params["owner"] = schema
+ text += " AND col.owner = :owner "
+ text += " ORDER BY col.column_id"
+ text = text % {
+ "dblink": dblink,
+ "char_length_col": char_length_col,
+ "identity_cols": identity_cols,
+ }
+
+ c = connection.execute(sql.text(text), params)
+
+ for row in c:
+ colname = self.normalize_name(row[0])
+ orig_colname = row[0]
+ coltype = row[1]
+ length = row[2]
+ precision = row[3]
+ scale = row[4]
+ nullable = row[5] == "Y"
+ default = row[6]
+ comment = row[7]
+ generated = row[8]
+ default_on_nul = row[9]
+ identity_options = row[10]
+
+ if coltype == "NUMBER":
+ if precision is None and scale == 0:
+ coltype = INTEGER()
+ else:
+ coltype = NUMBER(precision, scale)
+ elif coltype == "FLOAT":
+ # TODO: support "precision" here as "binary_precision"
+ coltype = FLOAT()
+ elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"):
+ coltype = self.ischema_names.get(coltype)(length)
+ elif "WITH TIME ZONE" in coltype:
+ coltype = TIMESTAMP(timezone=True)
+ else:
+ coltype = re.sub(r"\(\d+\)", "", coltype)
+ try:
+ coltype = self.ischema_names[coltype]
+ except KeyError:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (coltype, colname)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ if generated == "YES":
+ computed = dict(sqltext=default)
+ default = None
+ else:
+ computed = None
+
+ if identity_options is not None:
+ identity = self._parse_identity_options(
+ identity_options, default_on_nul
+ )
+ default = None
+ else:
+ identity = None
+
+ cdict = {
+ "name": colname,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": "auto",
+ "comment": comment,
+ }
+ if orig_colname.lower() == orig_colname:
+ cdict["quote"] = True
+ if computed is not None:
+ cdict["computed"] = computed
+ if identity is not None:
+ cdict["identity"] = identity
+
+ columns.append(cdict)
+ return columns
+
+ def _parse_identity_options(self, identity_options, default_on_nul):
+ # identity_options is a string that starts with 'ALWAYS,' or
+ # 'BY DEFAULT,' and continues with
+ # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1,
+ # CYCLE_FLAG: N, CACHE_SIZE: 1, ORDER_FLAG: N, SCALE_FLAG: N,
+ # EXTEND_FLAG: N, SESSION_FLAG: N, KEEP_VALUE: N
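+ # e.g. "ALWAYS, START WITH: 1, INCREMENT BY: 1, CYCLE_FLAG: N" with
+ # default_on_nul="NO" yields {"always": True, "on_null": False,
+ # "start": 1, "increment": 1, "cycle": False} (illustrative values).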
+ parts = [p.strip() for p in identity_options.split(",")]
+ identity = {
+ "always": parts[0] == "ALWAYS",
+ "on_null": default_on_nul == "YES",
+ }
+
+ for part in parts[1:]:
+ option, value = part.split(":")
+ value = value.strip()
+
+ if "START WITH" in option:
+ identity["start"] = compat.long_type(value)
+ elif "INCREMENT BY" in option:
+ identity["increment"] = compat.long_type(value)
+ elif "MAX_VALUE" in option:
+ identity["maxvalue"] = compat.long_type(value)
+ elif "MIN_VALUE" in option:
+ identity["minvalue"] = compat.long_type(value)
+ elif "CYCLE_FLAG" in option:
+ identity["cycle"] = value == "Y"
+ elif "CACHE_SIZE" in option:
+ identity["cache"] = compat.long_type(value)
+ elif "ORDER_FLAG" in option:
+ identity["order"] = value == "Y"
+ return identity
+
+ @reflection.cache
+ def get_table_comment(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ info_cache = kw.get("info_cache")
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ if not schema:
+ schema = self.default_schema_name
+
+ COMMENT_SQL = """
+ SELECT comments
+ FROM all_tab_comments
+ WHERE table_name = CAST(:table_name AS VARCHAR(128))
+ AND owner = CAST(:schema_name AS VARCHAR(128))
+ """
+
+ c = connection.execute(
+ sql.text(COMMENT_SQL),
+ dict(table_name=table_name, schema_name=schema),
+ )
+ return {"text": c.scalar()}
+
+ @reflection.cache
+ def get_indexes(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+
+ info_cache = kw.get("info_cache")
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+ indexes = []
+
+ params = {"table_name": table_name}
+ text = (
+ "SELECT a.index_name, a.column_name, "
+ "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "
+ "\nFROM ALL_IND_COLUMNS%(dblink)s a, "
+ "\nALL_INDEXES%(dblink)s b "
+ "\nWHERE "
+ "\na.index_name = b.index_name "
+ "\nAND a.table_owner = b.table_owner "
+ "\nAND a.table_name = b.table_name "
+ "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))"
+ )
+
+ if schema is not None:
+ params["schema"] = schema
+ text += "AND a.table_owner = :schema "
+
+ text += "ORDER BY a.index_name, a.column_position"
+
+ text = text % {"dblink": dblink}
+
+ q = sql.text(text)
+ rp = connection.execute(q, params)
+ indexes = []
+ last_index_name = None
+ pk_constraint = self.get_pk_constraint(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms=resolve_synonyms,
+ dblink=dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
+ enabled = dict(DISABLED=False, ENABLED=True)
+
+ oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
+
+ index = None
+ for rset in rp:
+ index_name_normalized = self.normalize_name(rset.index_name)
+
+ # skip primary key index. This is refined as of
+ # [ticket:5421]. Note that ALL_INDEXES.GENERATED will be "Y"
+ # if the name of this index was generated by Oracle; however,
+ # if a named primary key constraint was created then this flag
+ # is false.
+ if (
+ pk_constraint
+ and index_name_normalized == pk_constraint["name"]
+ ):
+ continue
+
+ if rset.index_name != last_index_name:
+ index = dict(
+ name=index_name_normalized,
+ column_names=[],
+ dialect_options={},
+ )
+ indexes.append(index)
+ index["unique"] = uniqueness.get(rset.uniqueness, False)
+
+ if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"):
+ index["dialect_options"]["oracle_bitmap"] = True
+ if enabled.get(rset.compression, False):
+ index["dialect_options"][
+ "oracle_compress"
+ ] = rset.prefix_length
+
+ # filter out Oracle SYS_NC names. could also do an outer join
+ # to the all_tab_columns table and check for real col names there.
+ if not oracle_sys_col.match(rset.column_name):
+ index["column_names"].append(
+ self.normalize_name(rset.column_name)
+ )
+ last_index_name = rset.index_name
+
+ return indexes
+
+ @reflection.cache
+ def _get_constraint_data(
+ self, connection, table_name, schema=None, dblink="", **kw
+ ):
+
+ params = {"table_name": table_name}
+
+ text = (
+ "SELECT"
+ "\nac.constraint_name," # 0
+ "\nac.constraint_type," # 1
+ "\nloc.column_name AS local_column," # 2
+ "\nrem.table_name AS remote_table," # 3
+ "\nrem.column_name AS remote_column," # 4
+ "\nrem.owner AS remote_owner," # 5
+ "\nloc.position as loc_pos," # 6
+ "\nrem.position as rem_pos," # 7
+ "\nac.search_condition," # 8
+ "\nac.delete_rule" # 9
+ "\nFROM all_constraints%(dblink)s ac,"
+ "\nall_cons_columns%(dblink)s loc,"
+ "\nall_cons_columns%(dblink)s rem"
+ "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))"
+ "\nAND ac.constraint_type IN ('R','P', 'U', 'C')"
+ )
+
+ if schema is not None:
+ params["owner"] = schema
+ text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))"
+
+ text += (
+ "\nAND ac.owner = loc.owner"
+ "\nAND ac.constraint_name = loc.constraint_name"
+ "\nAND ac.r_owner = rem.owner(+)"
+ "\nAND ac.r_constraint_name = rem.constraint_name(+)"
+ "\nAND (rem.position IS NULL or loc.position=rem.position)"
+ "\nORDER BY ac.constraint_name, loc.position"
+ )
+
+ text = text % {"dblink": dblink}
+ rp = connection.execute(sql.text(text), params)
+ constraint_data = rp.fetchall()
+ return constraint_data
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+ pkeys = []
+ constraint_name = None
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ for row in constraint_data:
+ (
+ cons_name,
+ cons_type,
+ local_column,
+ remote_table,
+ remote_column,
+ remote_owner,
+ ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+ if cons_type == "P":
+ if constraint_name is None:
+ constraint_name = self.normalize_name(cons_name)
+ pkeys.append(local_column)
+ return {"constrained_columns": pkeys, "name": constraint_name}
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ """
+
+ kw arguments can be:
+
+ oracle_resolve_synonyms
+
+ dblink
+
+ """
+ requested_schema = schema # to check later on
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ def fkey_rec():
+ return {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": None,
+ "referred_columns": [],
+ "options": {},
+ }
+
+ fkeys = util.defaultdict(fkey_rec)
+
+ for row in constraint_data:
+ (
+ cons_name,
+ cons_type,
+ local_column,
+ remote_table,
+ remote_column,
+ remote_owner,
+ ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
+
+ cons_name = self.normalize_name(cons_name)
+
+ if cons_type == "R":
+ if remote_table is None:
+ # ticket 363
+ util.warn(
+ (
+ "Got 'None' querying 'table_name' from "
+ "all_cons_columns%(dblink)s - does the user have "
+ "proper rights to the table?"
+ )
+ % {"dblink": dblink}
+ )
+ continue
+
+ rec = fkeys[cons_name]
+ rec["name"] = cons_name
+ local_cols, remote_cols = (
+ rec["constrained_columns"],
+ rec["referred_columns"],
+ )
+
+ if not rec["referred_table"]:
+ if resolve_synonyms:
+ (
+ ref_remote_name,
+ ref_remote_owner,
+ ref_dblink,
+ ref_synonym,
+ ) = self._resolve_synonym(
+ connection,
+ desired_owner=self.denormalize_name(remote_owner),
+ desired_table=self.denormalize_name(remote_table),
+ )
+ if ref_synonym:
+ remote_table = self.normalize_name(ref_synonym)
+ remote_owner = self.normalize_name(
+ ref_remote_owner
+ )
+
+ rec["referred_table"] = remote_table
+
+ if (
+ requested_schema is not None
+ or self.denormalize_name(remote_owner) != schema
+ ):
+ rec["referred_schema"] = remote_owner
+
+ if row[9] != "NO ACTION":
+ rec["options"]["ondelete"] = row[9]
+
+ local_cols.append(local_column)
+ remote_cols.append(remote_column)
+
+ return list(fkeys.values())
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ unique_keys = filter(lambda x: x[1] == "U", constraint_data)
+ uniques_group = groupby(unique_keys, lambda x: x[0])
+
+ index_names = {
+ ix["name"]
+ for ix in self.get_indexes(connection, table_name, schema=schema)
+ }
+ return [
+ {
+ "name": name,
+ "column_names": cols,
+ "duplicates_index": name if name in index_names else None,
+ }
+ for name, cols in [
+ [
+ self.normalize_name(i[0]),
+ [self.normalize_name(x[2]) for x in i[1]],
+ ]
+ for i in uniques_group
+ ]
+ ]
+
+ @reflection.cache
+ def get_view_definition(
+ self,
+ connection,
+ view_name,
+ schema=None,
+ resolve_synonyms=False,
+ dblink="",
+ **kw
+ ):
+ info_cache = kw.get("info_cache")
+ (view_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ view_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ params = {"view_name": view_name}
+ text = "SELECT text FROM all_views WHERE view_name=:view_name"
+
+ if schema is not None:
+ text += " AND owner = :schema"
+ params["schema"] = schema
+
+ rp = connection.execute(sql.text(text), params).scalar()
+ if rp:
+ if util.py2k:
+ rp = rp.decode(self.encoding)
+ return rp
+ else:
+ return None
+
+ @reflection.cache
+ def get_check_constraints(
+ self, connection, table_name, schema=None, include_all=False, **kw
+ ):
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
+ dblink = kw.get("dblink", "")
+ info_cache = kw.get("info_cache")
+
+ (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ connection,
+ table_name,
+ schema,
+ resolve_synonyms,
+ dblink,
+ info_cache=info_cache,
+ )
+
+ constraint_data = self._get_constraint_data(
+ connection,
+ table_name,
+ schema,
+ dblink,
+ info_cache=kw.get("info_cache"),
+ )
+
+ check_constraints = filter(lambda x: x[1] == "C", constraint_data)
+
+ return [
+ {"name": self.normalize_name(cons[0]), "sqltext": cons[8]}
+ for cons in check_constraints
+ if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8])
+ ]
+
+
+class _OuterJoinColumn(sql.ClauseElement):
+ __visit_name__ = "outer_join_column"
+
+ def __init__(self, column):
+ self.column = column
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
new file mode 100644
index 0000000..64029a4
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -0,0 +1,1424 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: oracle+cx_oracle
+ :name: cx-Oracle
+ :dbapi: cx_oracle
+ :connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
+ :url: https://oracle.github.io/python-cx_Oracle/
+
+DSN vs. Hostname connections
+-----------------------------
+
+cx_Oracle provides several methods of indicating the target database. The
+dialect translates from a series of different URL forms.
+
+Hostname Connections with Easy Connect Syntax
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Given a hostname, port and service name of the target Oracle Database, for
+example from Oracle's `Easy Connect syntax
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#easy-connect-syntax-for-connection-strings>`_,
+then connect in SQLAlchemy using the ``service_name`` query string parameter::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8")
+
+The `full Easy Connect syntax
+<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B0437826-43C1-49EC-A94D-B650B6A4A6EE>`_
+is not supported. Instead, use a ``tnsnames.ora`` file and connect using a
+DSN.
+
+Connections with tnsnames.ora or Oracle Cloud
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Alternatively, if no port, database name, or ``service_name`` is provided, the
+dialect will use an Oracle DSN "connection string". This takes the "hostname"
+portion of the URL as the data source name. For example, if the
+``tnsnames.ora`` file contains a `Net Service Name
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#net-service-names-for-connection-strings>`_
+of ``myalias`` as below::
+
+ myalias =
+ (DESCRIPTION =
+ (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521))
+ (CONNECT_DATA =
+ (SERVER = DEDICATED)
+ (SERVICE_NAME = orclpdb1)
+ )
+ )
+
+The cx_Oracle dialect connects to this database service when ``myalias`` is the
+hostname portion of the URL, without specifying a port, database name or
+``service_name``::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8")
+
+Users of Oracle Cloud should use this syntax and also configure the cloud
+wallet as shown in the cx_Oracle documentation `Connecting to Autonomous Databases
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-autononmous-databases>`_.
+
+SID Connections
+^^^^^^^^^^^^^^^
+
+To use Oracle's obsolete SID connection syntax, the SID can be passed in a
+"database name" portion of the URL as below::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8")
+
+Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as
+follows::
+
+ >>> import cx_Oracle
+ >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname")
+ '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))'
+
+Passing cx_Oracle connect arguments
+-----------------------------------
+
+Additional connection arguments can usually be passed via the URL
+query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted
+and converted to the correct symbol::
+
+ e = create_engine(
+ "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true")
+
+.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names
+ within the URL string itself, to be passed to the cx_Oracle DBAPI. As
+ was the case earlier but not correctly documented, the
+ :paramref:`_sa.create_engine.connect_args` parameter also accepts all
+ cx_Oracle DBAPI connect arguments.
+
+To pass arguments directly to ``.connect()`` without using the query
+string, use the :paramref:`_sa.create_engine.connect_args` dictionary.
+Any cx_Oracle parameter value and/or constant may be passed, such as::
+
+ import cx_Oracle
+ e = create_engine(
+ "oracle+cx_oracle://user:pass@dsn",
+ connect_args={
+ "encoding": "UTF-8",
+ "nencoding": "UTF-8",
+ "mode": cx_Oracle.SYSDBA,
+ "events": True
+ }
+ )
+
+Note that the default value for ``encoding`` and ``nencoding`` was changed to
+"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that
+version, or later.
+
+Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver
+--------------------------------------------------------------------------
+
+There are also options that are consumed by the SQLAlchemy cx_oracle dialect
+itself. These options are always passed directly to :func:`_sa.create_engine`
+, such as::
+
+ e = create_engine(
+ "oracle+cx_oracle://user:pass@dsn", coerce_to_unicode=False)
+
+The parameters accepted by the cx_oracle dialect are as follows:
+
+* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulting
+ to 50. This setting is significant with cx_Oracle as the contents of LOB
+ objects are only readable within a "live" row (e.g. within a batch of
+ 50 rows).
+
+* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`.
+
+* ``coerce_to_unicode`` - see :ref:`cx_oracle_unicode` for detail.
+
+* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail.
+
+* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail.
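+
+For example (the values shown are purely illustrative), several of these
+options can be combined in a single :func:`_sa.create_engine` call::
+
+ e = create_engine(
+ "oracle+cx_oracle://user:pass@dsn",
+ arraysize=500,
+ coerce_to_decimal=False,
+ )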
+
+.. _cx_oracle_sessionpool:
+
+Using cx_Oracle SessionPool
+---------------------------
+
+The cx_Oracle library provides its own connection pool implementation that may
+be used in place of SQLAlchemy's pooling functionality. This can be achieved
+by using the :paramref:`_sa.create_engine.creator` parameter to provide a
+function that returns a new connection, along with setting
+:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable
+SQLAlchemy's pooling::
+
+ import cx_Oracle
+ from sqlalchemy import create_engine
+ from sqlalchemy.pool import NullPool
+
+ pool = cx_Oracle.SessionPool(
+ user="scott", password="tiger", dsn="orclpdb",
+ min=2, max=5, increment=1, threaded=True,
+ encoding="UTF-8", nencoding="UTF-8"
+ )
+
+ engine = create_engine("oracle://", creator=pool.acquire, poolclass=NullPool)
+
+The above engine may then be used normally where cx_Oracle's pool handles
+connection pooling::
+
+ with engine.connect() as conn:
+ print(conn.scalar("select 1 FROM dual"))
+
+
+As well as providing a scalable solution for multi-user applications, the
+cx_Oracle session pool supports some Oracle features such as DRCP and
+`Application Continuity
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/ha.html#application-continuity-ac>`_.
+
+Using Oracle Database Resident Connection Pooling (DRCP)
+--------------------------------------------------------
+
+When using Oracle's `DRCP
+<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-015CA8C1-2386-4626-855D-CC546DDC1086>`_,
+the best practice is to pass a connection class and "purity" when acquiring a
+connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation
+<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_.
+
+This can be achieved by wrapping ``pool.acquire()``::
+
+ import cx_Oracle
+ from sqlalchemy import create_engine
+ from sqlalchemy.pool import NullPool
+
+ pool = cx_Oracle.SessionPool(
+ user="scott", password="tiger", dsn="orclpdb",
+ min=2, max=5, increment=1, threaded=True,
+ encoding="UTF-8", nencoding="UTF-8"
+ )
+
+ def creator():
+ return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF)
+
+ engine = create_engine("oracle://", creator=creator, poolclass=NullPool)
+
+The above engine may then be used normally where cx_Oracle handles session
+pooling and Oracle Database additionally uses DRCP::
+
+ with engine.connect() as conn:
+ print(conn.scalar("select 1 FROM dual"))
+
+.. _cx_oracle_unicode:
+
+Unicode
+-------
+
+As is the case for all DBAPIs under Python 3, all strings are inherently
+Unicode strings. Under Python 2, cx_Oracle also supports Python Unicode
+objects directly. In all cases, however, the driver requires an explicit
+encoding configuration.
+
+Ensuring the Correct Client Encoding
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The long accepted standard for establishing client encoding for nearly all
+Oracle related software is via the `NLS_LANG <https://www.oracle.com/database/technologies/faq-nls-lang.html>`_
+environment variable. cx_Oracle, like most other Oracle drivers, will use
+this environment variable as the source of its encoding configuration. The
+format of this variable is idiosyncratic; a typical value would be
+``AMERICAN_AMERICA.AL32UTF8``.
+
+The cx_Oracle driver also supports a programmatic alternative which is to
+pass the ``encoding`` and ``nencoding`` parameters directly to its
+``.connect()`` function. These can be present in the URL as follows::
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8")
+
+For the meaning of the ``encoding`` and ``nencoding`` parameters, please
+consult
+`Characters Sets and National Language Support (NLS) <https://cx-oracle.readthedocs.io/en/latest/user_guide/globalization.html#globalization>`_.
+
+.. seealso::
+
+ `Characters Sets and National Language Support (NLS) <https://cx-oracle.readthedocs.io/en/latest/user_guide/globalization.html#globalization>`_
+ - in the cx_Oracle documentation.
+
+
+Unicode-specific Column datatypes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The Core expression language handles unicode data by use of the :class:`.Unicode`
+and :class:`.UnicodeText`
+datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by
+default. When using these datatypes with Unicode data, it is expected that
+the Oracle database is configured with a Unicode-aware character set, as well
+as that the ``NLS_LANG`` environment variable is set appropriately, so that
+the VARCHAR2 and CLOB datatypes can accommodate the data.
+
+In the case that the Oracle database is not configured with a Unicode character
+set, the two options are to use the :class:`_types.NCHAR` and
+:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag
+``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`,
+which will cause the
+SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` /
+:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB.
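+
+A minimal sketch of the latter approach (the connection string is
+illustrative only)::
+
+    from sqlalchemy import create_engine
+
+    # Unicode / UnicodeText will now make use of NCHAR / NCLOB on the
+    # Oracle side rather than VARCHAR2 / CLOB
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@orclpdb",
+        use_nchar_for_unicode=True,
+    )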
+
+.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText`
+ datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes
+ unless the ``use_nchar_for_unicode=True`` is passed to the dialect
+ when :func:`_sa.create_engine` is called.
+
+Unicode Coercion of result rows under Python 2
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When result sets are fetched that include strings, under Python 3 the cx_Oracle
+DBAPI returns all strings as Python Unicode objects, since Python 3 only has a
+Unicode string type. This occurs for data fetched from datatypes such as
+VARCHAR2, CHAR, CLOB, NCHAR, NCLOB, etc. In order to provide
+cross-compatibility under Python 2, the SQLAlchemy cx_Oracle dialect will add
+Unicode-conversion to string data under Python 2 as well. Historically, this
+made use of converters that were supplied by cx_Oracle but were found to be
+non-performant; SQLAlchemy's own converters are used for the string to Unicode
+conversion under Python 2. To disable the Python 2 Unicode conversion for
+VARCHAR2, CHAR, and CLOB, the flag ``coerce_to_unicode=False`` can be passed to
+:func:`_sa.create_engine`.
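+
+A minimal sketch of disabling the coercion (the connection string is
+illustrative only)::
+
+    from sqlalchemy import create_engine
+
+    # under Python 2, VARCHAR2 / CHAR / CLOB values will be returned as
+    # plain (byte) strings rather than Unicode objects
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@orclpdb",
+        coerce_to_unicode=False,
+    )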
+
+.. versionchanged:: 1.3 Unicode conversion is applied to all string values
+ by default under python 2. The ``coerce_to_unicode`` now defaults to True
+ and can be set to False to disable the Unicode coercion of strings that are
+ delivered as VARCHAR2/CHAR/CLOB data.
+
+.. _cx_oracle_unicode_encoding_errors:
+
+Encoding Errors
+^^^^^^^^^^^^^^^
+
+For the unusual case that data in the Oracle database is present with a broken
+encoding, the dialect accepts a parameter ``encoding_errors`` which will be
+passed to Unicode decoding functions in order to affect how decoding errors are
+handled. The value is ultimately consumed by the Python `decode
+<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`_ function, and
+is passed both via cx_Oracle's ``encodingErrors`` parameter consumed by
+``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the
+cx_Oracle dialect makes use of both under different circumstances.
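+
+As a brief illustration (the connection string is illustrative; ``"replace"``
+is one of the standard Python decode error handlers)::
+
+    from sqlalchemy import create_engine
+
+    # undecodable bytes will be replaced with U+FFFD rather than raising
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@orclpdb",
+        encoding_errors="replace",
+    )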
+
+.. versionadded:: 1.3.11
+
+
+.. _cx_oracle_setinputsizes:
+
+Fine grained control over cx_Oracle data binding performance with setinputsizes
+-------------------------------------------------------------------------------
+
+The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the
+DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the
+datatypes that are bound to a SQL statement for Python values being passed as
+parameters. While virtually no other DBAPI assigns any use to the
+``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its
+interactions with the Oracle client interface, and in some scenarios it is not
+possible for SQLAlchemy to know exactly how data should be bound, as some
+settings can cause profoundly different performance characteristics, while
+altering the type coercion behavior at the same time.
+
+Users of the cx_Oracle dialect are **strongly encouraged** to read through
+cx_Oracle's list of built-in datatype symbols at
+https://cx-oracle.readthedocs.io/en/latest/api_manual/module.html#database-types.
+Note that in some cases, significant performance degradation can occur when
+using these types versus not using them, in particular when specifying
+``cx_Oracle.CLOB``.
+
+On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can
+be used both for runtime visibility (e.g. logging) of the setinputsizes step as
+well as to fully control how ``setinputsizes()`` is used on a per-statement
+basis.
+
+.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.do_setinputsizes`
+
+
+Example 1 - logging all setinputsizes calls
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following example illustrates how to log the intermediary values from a
+SQLAlchemy perspective before they are converted to the raw ``setinputsizes()``
+parameter dictionary. The keys of the dictionary are :class:`.BindParameter`
+objects which have a ``.key`` and a ``.type`` attribute::
+
+ from sqlalchemy import create_engine, event
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")
+
+ @event.listens_for(engine, "do_setinputsizes")
+ def _log_setinputsizes(inputsizes, cursor, statement, parameters, context):
+ for bindparam, dbapitype in inputsizes.items():
+ log.info(
+ "Bound parameter name: %s SQLAlchemy type: %r "
+ "DBAPI object: %s",
+ bindparam.key, bindparam.type, dbapitype)
+
+Example 2 - remove all bindings to CLOB
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The ``CLOB`` datatype in cx_Oracle incurs a significant performance overhead;
+however, it is set by default for the ``Text`` type within the SQLAlchemy 1.2
+series. This setting can be modified as follows::
+
+ from sqlalchemy import create_engine, event
+ from cx_Oracle import CLOB
+
+ engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe")
+
+ @event.listens_for(engine, "do_setinputsizes")
+ def _remove_clob(inputsizes, cursor, statement, parameters, context):
+ for bindparam, dbapitype in list(inputsizes.items()):
+ if dbapitype is CLOB:
+ del inputsizes[bindparam]
+
+.. _cx_oracle_returning:
+
+RETURNING Support
+-----------------
+
+The cx_Oracle dialect implements RETURNING using OUT parameters.
+The dialect fully supports RETURNING; however, cx_Oracle 6 or above is
+recommended for complete support.
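+
+As a brief illustration (assuming a ``connection`` and a hypothetical
+``some_table`` Table object are already present), RETURNING is used with the
+usual Core constructs::
+
+    result = connection.execute(
+        some_table.insert().returning(some_table.c.id),
+        {"data": "some data"},
+    )
+    new_id = result.scalar()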
+
+.. _cx_oracle_lob:
+
+LOB Objects
+-----------
+
+cx_Oracle returns Oracle LOBs using the ``cx_Oracle.LOB`` object. SQLAlchemy
+converts these to strings so that the interface of the Binary type is
+consistent with that of other backends; this conversion takes place within a
+cx_Oracle outputtypehandler.
+
+cx_Oracle prior to version 6 would require that LOB objects be read before
+a new batch of rows would be read, as determined by the ``cursor.arraysize``.
+As of the 6 series, this limitation has been lifted. Nevertheless, because
+SQLAlchemy pre-reads these LOBs up front, this issue is avoided in any case.
+
+To disable the auto "read()" feature of the dialect, the flag
+``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`. Under
+the cx_Oracle 5 series, having this flag turned off means there is the chance
+of reading from a stale LOB object if not read as it is fetched. With
+cx_Oracle 6, this issue is resolved.
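+
+A minimal sketch of disabling the conversion (the connection string is
+illustrative only)::
+
+    from sqlalchemy import create_engine
+
+    # LOB columns will be returned as cx_Oracle.LOB objects rather than
+    # being read into strings / bytes up front
+    engine = create_engine(
+        "oracle+cx_oracle://scott:tiger@orclpdb",
+        auto_convert_lobs=False,
+    )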
+
+.. versionchanged:: 1.2 the LOB handling system has been greatly simplified
+ internally to make use of outputtypehandlers, and no longer makes use
+ of alternate "buffered" result set objects.
+
+Two Phase Transactions Not Supported
+-------------------------------------
+
+Two phase transactions are **not supported** under cx_Oracle due to poor
+driver support. As of cx_Oracle 6.0b1, the interface for
+two phase transactions has been changed to be more of a direct pass-through
+to the underlying OCI layer with less automation. The additional logic
+to support this system is not implemented in SQLAlchemy.
+
+.. _cx_oracle_numeric:
+
+Precision Numerics
+------------------
+
+SQLAlchemy's numeric types can handle receiving and returning values as Python
+``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a
+subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in
+use, the :paramref:`.Numeric.asdecimal` flag determines if values should be
+coerced to ``Decimal`` upon return, or returned as float objects. To make
+matters more complicated under Oracle, Oracle's ``NUMBER`` type can also
+represent integer values if the "scale" is zero, so the Oracle-specific
+:class:`_oracle.NUMBER` type takes this into account as well.
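+
+As a brief illustration (table and column names are hypothetical), the
+``asdecimal`` flag determines the Python type delivered for a given column::
+
+    from sqlalchemy import Column, MetaData, Numeric, Table
+
+    metadata = MetaData()
+    t = Table(
+        "some_table",
+        metadata,
+        # returned as decimal.Decimal objects (the default)
+        Column("price", Numeric(10, 2)),
+        # returned as Python floats
+        Column("weight", Numeric(10, 2, asdecimal=False)),
+    )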
+
+The cx_Oracle dialect makes extensive use of connection- and cursor-level
+"outputtypehandler" callables in order to coerce numeric values as requested.
+These callables are specific to the flavor of :class:`.Numeric` in use, as
+well as to the case where no SQLAlchemy typing objects are present. There are
+observed scenarios where Oracle may send incomplete or ambiguous information
+about the numeric types being returned, such as a query where the numeric types
+are buried under multiple levels of subquery. The type handlers do their best
+to make the right decision in all cases, deferring to the underlying cx_Oracle
+DBAPI for all those cases where the driver can make the best decision.
+
+When no typing objects are present, as when executing plain SQL strings, a
+default "outputtypehandler" is present which will generally return numeric
+values which specify precision and scale as Python ``Decimal`` objects. To
+disable this coercion to decimal for performance reasons, pass the flag
+``coerce_to_decimal=False`` to :func:`_sa.create_engine`::
+
+ engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False)
+
+The ``coerce_to_decimal`` flag only impacts the results of plain string
+SQL statements that are not otherwise associated with a :class:`.Numeric`
+SQLAlchemy type (or a subclass of such).
+
+.. versionchanged:: 1.2 The numeric handling system for cx_Oracle has been
+ reworked to take advantage of newer cx_Oracle features as well
+ as better integration of outputtypehandlers.
+
+""" # noqa
+
+from __future__ import absolute_import
+
+import decimal
+import random
+import re
+
+from . import base as oracle
+from .base import OracleCompiler
+from .base import OracleDialect
+from .base import OracleExecutionContext
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+from ...engine import cursor as _cursor
+from ...util import compat
+
+
+class _OracleInteger(sqltypes.Integer):
+ def get_dbapi_type(self, dbapi):
+ # see https://github.com/oracle/python-cx_Oracle/issues/
+ # 208#issuecomment-409715955
+ return int
+
+ def _cx_oracle_var(self, dialect, cursor):
+ cx_Oracle = dialect.dbapi
+ return cursor.var(
+ cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int
+ )
+
+ def _cx_oracle_outputtypehandler(self, dialect):
+ def handler(cursor, name, default_type, size, precision, scale):
+ return self._cx_oracle_var(dialect, cursor)
+
+ return handler
+
+
+class _OracleNumeric(sqltypes.Numeric):
+ is_number = False
+
+ def bind_processor(self, dialect):
+ if self.scale == 0:
+ return None
+ elif self.asdecimal:
+ processor = processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+
+ def process(value):
+ if isinstance(value, (int, float)):
+ return processor(value)
+ elif value is not None and value.is_infinite():
+ return float(value)
+ else:
+ return value
+
+ return process
+ else:
+ return processors.to_float
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+ def _cx_oracle_outputtypehandler(self, dialect):
+ cx_Oracle = dialect.dbapi
+
+ is_cx_oracle_6 = dialect._is_cx_oracle_6
+
+ def handler(cursor, name, default_type, size, precision, scale):
+ outconverter = None
+
+ if precision:
+ if self.asdecimal:
+ if default_type == cx_Oracle.NATIVE_FLOAT:
+ # receiving float and doing Decimal after the fact
+ # allows for float("inf") to be handled
+ type_ = default_type
+ outconverter = decimal.Decimal
+ elif is_cx_oracle_6:
+ type_ = decimal.Decimal
+ else:
+ type_ = cx_Oracle.STRING
+ outconverter = dialect._to_decimal
+ else:
+ if self.is_number and scale == 0:
+ # integer. cx_Oracle is observed to handle the widest
+ # variety of ints when no directives are passed,
+ # from 5.2 to 7.0. See [ticket:4457]
+ return None
+ else:
+ type_ = cx_Oracle.NATIVE_FLOAT
+
+ else:
+ if self.asdecimal:
+ if default_type == cx_Oracle.NATIVE_FLOAT:
+ type_ = default_type
+ outconverter = decimal.Decimal
+ elif is_cx_oracle_6:
+ type_ = decimal.Decimal
+ else:
+ type_ = cx_Oracle.STRING
+ outconverter = dialect._to_decimal
+ else:
+ if self.is_number and scale == 0:
+ # integer. cx_Oracle is observed to handle the widest
+ # variety of ints when no directives are passed,
+ # from 5.2 to 7.0. See [ticket:4457]
+ return None
+ else:
+ type_ = cx_Oracle.NATIVE_FLOAT
+
+ return cursor.var(
+ type_,
+ 255,
+ arraysize=cursor.arraysize,
+ outconverter=outconverter,
+ )
+
+ return handler
+
+
+class _OracleBinaryFloat(_OracleNumeric):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NATIVE_FLOAT
+
+
+class _OracleBINARY_FLOAT(_OracleBinaryFloat, oracle.BINARY_FLOAT):
+ pass
+
+
+class _OracleBINARY_DOUBLE(_OracleBinaryFloat, oracle.BINARY_DOUBLE):
+ pass
+
+
+class _OracleNUMBER(_OracleNumeric):
+ is_number = True
+
+
+class _OracleDate(sqltypes.Date):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is not None:
+ return value.date()
+ else:
+ return value
+
+ return process
+
+
+# TODO: the names used across CHAR / VARCHAR / NCHAR / NVARCHAR
+# here are inconsistent and not very good
+class _OracleChar(sqltypes.CHAR):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.FIXED_CHAR
+
+
+class _OracleNChar(sqltypes.NCHAR):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.FIXED_NCHAR
+
+
+class _OracleUnicodeStringNCHAR(oracle.NVARCHAR2):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NCHAR
+
+
+class _OracleUnicodeStringCHAR(sqltypes.Unicode):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.LONG_STRING
+
+
+class _OracleUnicodeTextNCLOB(oracle.NCLOB):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NCLOB
+
+
+class _OracleUnicodeTextCLOB(sqltypes.UnicodeText):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.CLOB
+
+
+class _OracleText(sqltypes.Text):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.CLOB
+
+
+class _OracleLong(oracle.LONG):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.LONG_STRING
+
+
+class _OracleString(sqltypes.String):
+ pass
+
+
+class _OracleEnum(sqltypes.Enum):
+ def bind_processor(self, dialect):
+ enum_proc = sqltypes.Enum.bind_processor(self, dialect)
+
+ def process(value):
+ raw_str = enum_proc(value)
+ return raw_str
+
+ return process
+
+
+class _OracleBinary(sqltypes.LargeBinary):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BLOB
+
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.auto_convert_lobs:
+ return None
+ else:
+ return super(_OracleBinary, self).result_processor(
+ dialect, coltype
+ )
+
+
+class _OracleInterval(oracle.INTERVAL):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTERVAL
+
+
+class _OracleRaw(oracle.RAW):
+ pass
+
+
+class _OracleRowid(oracle.ROWID):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.ROWID
+
+
+class OracleCompiler_cx_oracle(OracleCompiler):
+ _oracle_cx_sql_compiler = True
+
+ def bindparam_string(self, name, **kw):
+ quote = getattr(name, "quote", None)
+ if (
+ quote is True
+ or quote is not False
+ and self.preparer._bindparam_requires_quotes(name)
+ and not kw.get("post_compile", False)
+ ):
+ # interesting to note about expanding parameters - since the
+ # new parameters take the form <paramname>_<int>, at least if
+ # they are originally formed from reserved words, they no longer
+ # need quoting :). names that include illegal characters
+ # won't work however.
+ quoted_name = '"%s"' % name
+ kw["escaped_from"] = name
+ name = quoted_name
+
+ return OracleCompiler.bindparam_string(self, name, **kw)
+
+
+class OracleExecutionContext_cx_oracle(OracleExecutionContext):
+ out_parameters = None
+
+ def _generate_out_parameter_vars(self):
+ # check for has_out_parameters or RETURNING, create cx_Oracle.var
+ # objects if so
+ if self.compiled.returning or self.compiled.has_out_parameters:
+ quoted_bind_names = self.compiled.escaped_bind_names
+ for bindparam in self.compiled.binds.values():
+ if bindparam.isoutparam:
+ name = self.compiled.bind_names[bindparam]
+ type_impl = bindparam.type.dialect_impl(self.dialect)
+
+ if hasattr(type_impl, "_cx_oracle_var"):
+ self.out_parameters[name] = type_impl._cx_oracle_var(
+ self.dialect, self.cursor
+ )
+ else:
+ dbtype = type_impl.get_dbapi_type(self.dialect.dbapi)
+
+ cx_Oracle = self.dialect.dbapi
+
+ if dbtype is None:
+ raise exc.InvalidRequestError(
+ "Cannot create out parameter for "
+ "parameter "
+ "%r - its type %r is not supported by"
+ " cx_oracle" % (bindparam.key, bindparam.type)
+ )
+
+ if compat.py2k and dbtype in (
+ cx_Oracle.CLOB,
+ cx_Oracle.NCLOB,
+ ):
+ outconverter = (
+ processors.to_unicode_processor_factory(
+ self.dialect.encoding,
+ errors=self.dialect.encoding_errors,
+ )
+ )
+ self.out_parameters[name] = self.cursor.var(
+ dbtype,
+ outconverter=lambda value: outconverter(
+ value.read()
+ ),
+ )
+
+ elif dbtype in (
+ cx_Oracle.BLOB,
+ cx_Oracle.CLOB,
+ cx_Oracle.NCLOB,
+ ):
+ self.out_parameters[name] = self.cursor.var(
+ dbtype, outconverter=lambda value: value.read()
+ )
+ elif compat.py2k and isinstance(
+ type_impl, sqltypes.Unicode
+ ):
+ outconverter = (
+ processors.to_unicode_processor_factory(
+ self.dialect.encoding,
+ errors=self.dialect.encoding_errors,
+ )
+ )
+ self.out_parameters[name] = self.cursor.var(
+ dbtype, outconverter=outconverter
+ )
+ else:
+ self.out_parameters[name] = self.cursor.var(dbtype)
+ self.parameters[0][
+ quoted_bind_names.get(name, name)
+ ] = self.out_parameters[name]
+
+ def _generate_cursor_outputtype_handler(self):
+ output_handlers = {}
+
+ for (keyname, name, objects, type_) in self.compiled._result_columns:
+ handler = type_._cached_custom_processor(
+ self.dialect,
+ "cx_oracle_outputtypehandler",
+ self._get_cx_oracle_type_handler,
+ )
+
+ if handler:
+ denormalized_name = self.dialect.denormalize_name(keyname)
+ output_handlers[denormalized_name] = handler
+
+ if output_handlers:
+ default_handler = self._dbapi_connection.outputtypehandler
+
+ def output_type_handler(
+ cursor, name, default_type, size, precision, scale
+ ):
+ if name in output_handlers:
+ return output_handlers[name](
+ cursor, name, default_type, size, precision, scale
+ )
+ else:
+ return default_handler(
+ cursor, name, default_type, size, precision, scale
+ )
+
+ self.cursor.outputtypehandler = output_type_handler
+
+ def _get_cx_oracle_type_handler(self, impl):
+ if hasattr(impl, "_cx_oracle_outputtypehandler"):
+ return impl._cx_oracle_outputtypehandler(self.dialect)
+ else:
+ return None
+
+ def pre_exec(self):
+ if not getattr(self.compiled, "_oracle_cx_sql_compiler", False):
+ return
+
+ self.out_parameters = {}
+
+ self._generate_out_parameter_vars()
+
+ self._generate_cursor_outputtype_handler()
+
+ self.include_set_input_sizes = self.dialect._include_setinputsizes
+
+ def post_exec(self):
+ if self.compiled and self.out_parameters and self.compiled.returning:
+ # create a fake cursor result from the out parameters. unlike
+ # get_out_parameter_values(), the result-row handlers here will be
+ # applied at the Result level
+ returning_params = [
+ self.dialect._returningval(self.out_parameters["ret_%d" % i])
+ for i in range(len(self.out_parameters))
+ ]
+
+ fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy(
+ self.cursor,
+ [
+ (getattr(col, "name", col._anon_name_label), None)
+ for col in self.compiled.returning
+ ],
+ initial_buffer=[tuple(returning_params)],
+ )
+
+ self.cursor_fetch_strategy = fetch_strategy
+
+ def create_cursor(self):
+ c = self._dbapi_connection.cursor()
+ if self.dialect.arraysize:
+ c.arraysize = self.dialect.arraysize
+
+ return c
+
+ def get_out_parameter_values(self, out_param_names):
+ # this method should not be called when the compiler has
+ # RETURNING, as we've set the has_out_parameters flag to
+ # False.
+ assert not self.compiled.returning
+
+ return [
+ self.dialect._paramval(self.out_parameters[name])
+ for name in out_param_names
+ ]
+
+
+class OracleDialect_cx_oracle(OracleDialect):
+ supports_statement_cache = True
+ execution_ctx_cls = OracleExecutionContext_cx_oracle
+ statement_compiler = OracleCompiler_cx_oracle
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ use_setinputsizes = True
+
+ driver = "cx_oracle"
+
+ colspecs = {
+ sqltypes.Numeric: _OracleNumeric,
+ sqltypes.Float: _OracleNumeric,
+ oracle.BINARY_FLOAT: _OracleBINARY_FLOAT,
+ oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE,
+ sqltypes.Integer: _OracleInteger,
+ oracle.NUMBER: _OracleNUMBER,
+ sqltypes.Date: _OracleDate,
+ sqltypes.LargeBinary: _OracleBinary,
+ sqltypes.Boolean: oracle._OracleBoolean,
+ sqltypes.Interval: _OracleInterval,
+ oracle.INTERVAL: _OracleInterval,
+ sqltypes.Text: _OracleText,
+ sqltypes.String: _OracleString,
+ sqltypes.UnicodeText: _OracleUnicodeTextCLOB,
+ sqltypes.CHAR: _OracleChar,
+ sqltypes.NCHAR: _OracleNChar,
+ sqltypes.Enum: _OracleEnum,
+ oracle.LONG: _OracleLong,
+ oracle.RAW: _OracleRaw,
+ sqltypes.Unicode: _OracleUnicodeStringCHAR,
+ sqltypes.NVARCHAR: _OracleUnicodeStringNCHAR,
+ oracle.NCLOB: _OracleUnicodeTextNCLOB,
+ oracle.ROWID: _OracleRowid,
+ }
+
+ execute_sequence_format = list
+
+ _cx_oracle_threaded = None
+
+ @util.deprecated_params(
+ threaded=(
+ "1.3",
+ "The 'threaded' parameter to the cx_oracle dialect "
+ "is deprecated as a dialect-level argument, and will be removed "
+ "in a future release. As of version 1.3, it defaults to False "
+ "rather than True. The 'threaded' option can be passed to "
+ "cx_Oracle directly in the URL query string passed to "
+ ":func:`_sa.create_engine`.",
+ )
+ )
+ def __init__(
+ self,
+ auto_convert_lobs=True,
+ coerce_to_unicode=True,
+ coerce_to_decimal=True,
+ arraysize=50,
+ encoding_errors=None,
+ threaded=None,
+ **kwargs
+ ):
+
+ OracleDialect.__init__(self, **kwargs)
+ self.arraysize = arraysize
+ self.encoding_errors = encoding_errors
+ if threaded is not None:
+ self._cx_oracle_threaded = threaded
+ self.auto_convert_lobs = auto_convert_lobs
+ self.coerce_to_unicode = coerce_to_unicode
+ self.coerce_to_decimal = coerce_to_decimal
+ if self._use_nchar_for_unicode:
+ self.colspecs = self.colspecs.copy()
+ self.colspecs[sqltypes.Unicode] = _OracleUnicodeStringNCHAR
+ self.colspecs[sqltypes.UnicodeText] = _OracleUnicodeTextNCLOB
+
+ cx_Oracle = self.dbapi
+
+ if cx_Oracle is None:
+ self._include_setinputsizes = {}
+ self.cx_oracle_ver = (0, 0, 0)
+ else:
+ self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version)
+ if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0):
+ raise exc.InvalidRequestError(
+ "cx_Oracle version 5.2 and above are supported"
+ )
+
+ self._include_setinputsizes = {
+ cx_Oracle.DATETIME,
+ cx_Oracle.NCLOB,
+ cx_Oracle.CLOB,
+ cx_Oracle.LOB,
+ cx_Oracle.NCHAR,
+ cx_Oracle.FIXED_NCHAR,
+ cx_Oracle.BLOB,
+ cx_Oracle.FIXED_CHAR,
+ cx_Oracle.TIMESTAMP,
+ _OracleInteger,
+ _OracleBINARY_FLOAT,
+ _OracleBINARY_DOUBLE,
+ }
+
+ self._paramval = lambda value: value.getvalue()
+
+ # https://github.com/oracle/python-cx_Oracle/issues/176#issuecomment-386821291
+ # https://github.com/oracle/python-cx_Oracle/issues/224
+ self._values_are_lists = self.cx_oracle_ver >= (6, 3)
+ if self._values_are_lists:
+ cx_Oracle.__future__.dml_ret_array_val = True
+
+ def _returningval(value):
+ try:
+ return value.values[0][0]
+ except IndexError:
+ return None
+
+ self._returningval = _returningval
+ else:
+ self._returningval = self._paramval
+
+ self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,)
+
+ @property
+ def _cursor_var_unicode_kwargs(self):
+ if self.encoding_errors:
+ if self.cx_oracle_ver >= (6, 4):
+ return {"encodingErrors": self.encoding_errors}
+ else:
+ util.warn(
+ "cx_oracle version %r does not support encodingErrors"
+ % (self.cx_oracle_ver,)
+ )
+
+ return {}
+
+ def _parse_cx_oracle_ver(self, version):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
+ if m:
+ return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
+ else:
+ return (0, 0, 0)
+
+ @classmethod
+ def dbapi(cls):
+ import cx_Oracle
+
+ return cx_Oracle
+
+ def initialize(self, connection):
+ super(OracleDialect_cx_oracle, self).initialize(connection)
+ if self._is_oracle_8:
+ self.supports_unicode_binds = False
+
+ self._detect_decimal_char(connection)
+
+ def get_isolation_level(self, connection):
+ # sources:
+
+ # general idea of transaction id, have to start one, etc.
+ # https://stackoverflow.com/questions/10711204/how-to-check-isoloation-level
+
+ # how to decode xid cols from v$transaction to match
+ # https://asktom.oracle.com/pls/apex/f?p=100:11:0::::P11_QUESTION_ID:9532779900346079444
+
+ # Oracle tuple comparison without using IN:
+ # https://www.sql-workbench.eu/comparison/tuple_comparison.html
+
+ with connection.cursor() as cursor:
+ # this is the only way to ensure a transaction is started without
+ # actually running DML. There's no way to see the configured
+ # isolation level without getting it from v$transaction, which
+ # means a transaction has to be started.
+ outval = cursor.var(str)
+ cursor.execute(
+ """
+ begin
+ :trans_id := dbms_transaction.local_transaction_id( TRUE );
+ end;
+ """,
+ {"trans_id": outval},
+ )
+ trans_id = outval.getvalue()
+ xidusn, xidslot, xidsqn = trans_id.split(".", 2)
+
+ cursor.execute(
+ "SELECT CASE BITAND(t.flag, POWER(2, 28)) "
+ "WHEN 0 THEN 'READ COMMITTED' "
+ "ELSE 'SERIALIZABLE' END AS isolation_level "
+ "FROM v$transaction t WHERE "
+ "(t.xidusn, t.xidslot, t.xidsqn) = "
+ "((:xidusn, :xidslot, :xidsqn))",
+ {"xidusn": xidusn, "xidslot": xidslot, "xidsqn": xidsqn},
+ )
+ row = cursor.fetchone()
+ if row is None:
+ raise exc.InvalidRequestError(
+ "could not retrieve isolation level"
+ )
+ result = row[0]
+
+ return result
+
+ def set_isolation_level(self, connection, level):
+ if hasattr(connection, "dbapi_connection"):
+ dbapi_connection = connection.dbapi_connection
+ else:
+ dbapi_connection = connection
+ if level == "AUTOCOMMIT":
+ dbapi_connection.autocommit = True
+ else:
+ dbapi_connection.autocommit = False
+ connection.rollback()
+ with connection.cursor() as cursor:
+ cursor.execute("ALTER SESSION SET ISOLATION_LEVEL=%s" % level)
+
+ def _detect_decimal_char(self, connection):
+ # we have the option to change this setting upon connect,
+ # or just look at what it is upon connect and convert.
+ # to minimize the chance of interference with changes to
+ # NLS_TERRITORY or formatting behavior of the DB, we opt
+ # to just look at it
+
+ self._decimal_char = connection.exec_driver_sql(
+ "select value from nls_session_parameters "
+ "where parameter = 'NLS_NUMERIC_CHARACTERS'"
+ ).scalar()[0]
+ if self._decimal_char != ".":
+ _detect_decimal = self._detect_decimal
+ _to_decimal = self._to_decimal
+
+ self._detect_decimal = lambda value: _detect_decimal(
+ value.replace(self._decimal_char, ".")
+ )
+ self._to_decimal = lambda value: _to_decimal(
+ value.replace(self._decimal_char, ".")
+ )
+
+ def _detect_decimal(self, value):
+ if "." in value:
+ return self._to_decimal(value)
+ else:
+ return int(value)
+
+ _to_decimal = decimal.Decimal
+
+ def _generate_connection_outputtype_handler(self):
+ """establish the default outputtypehandler established at the
+ connection level.
+
+ """
+
+ dialect = self
+ cx_Oracle = dialect.dbapi
+
+ number_handler = _OracleNUMBER(
+ asdecimal=True
+ )._cx_oracle_outputtypehandler(dialect)
+ float_handler = _OracleNUMBER(
+ asdecimal=False
+ )._cx_oracle_outputtypehandler(dialect)
+
+ def output_type_handler(
+ cursor, name, default_type, size, precision, scale
+ ):
+
+ if (
+ default_type == cx_Oracle.NUMBER
+ and default_type is not cx_Oracle.NATIVE_FLOAT
+ ):
+ if not dialect.coerce_to_decimal:
+ return None
+ elif precision == 0 and scale in (0, -127):
+ # ambiguous type, this occurs when selecting
+ # numbers from deep subqueries
+ return cursor.var(
+ cx_Oracle.STRING,
+ 255,
+ outconverter=dialect._detect_decimal,
+ arraysize=cursor.arraysize,
+ )
+ elif precision and scale > 0:
+ return number_handler(
+ cursor, name, default_type, size, precision, scale
+ )
+ else:
+ return float_handler(
+ cursor, name, default_type, size, precision, scale
+ )
+
+ # allow all strings to come back natively as Unicode
+ elif (
+ dialect.coerce_to_unicode
+ and default_type
+ in (
+ cx_Oracle.STRING,
+ cx_Oracle.FIXED_CHAR,
+ )
+ and default_type is not cx_Oracle.CLOB
+ and default_type is not cx_Oracle.NCLOB
+ ):
+ if compat.py2k:
+ outconverter = processors.to_unicode_processor_factory(
+ dialect.encoding, errors=dialect.encoding_errors
+ )
+ return cursor.var(
+ cx_Oracle.STRING,
+ size,
+ cursor.arraysize,
+ outconverter=outconverter,
+ )
+ else:
+ return cursor.var(
+ util.text_type,
+ size,
+ cursor.arraysize,
+ **dialect._cursor_var_unicode_kwargs
+ )
+
+ elif dialect.auto_convert_lobs and default_type in (
+ cx_Oracle.CLOB,
+ cx_Oracle.NCLOB,
+ ):
+ if compat.py2k:
+ outconverter = processors.to_unicode_processor_factory(
+ dialect.encoding, errors=dialect.encoding_errors
+ )
+ return cursor.var(
+ cx_Oracle.LONG_STRING,
+ size,
+ cursor.arraysize,
+ outconverter=outconverter,
+ )
+ else:
+ return cursor.var(
+ cx_Oracle.LONG_STRING,
+ size,
+ cursor.arraysize,
+ **dialect._cursor_var_unicode_kwargs
+ )
+
+ elif dialect.auto_convert_lobs and default_type in (
+ cx_Oracle.BLOB,
+ ):
+ return cursor.var(
+ cx_Oracle.LONG_BINARY,
+ size,
+ cursor.arraysize,
+ )
+
+ return output_type_handler
+
+ def on_connect(self):
+
+ output_type_handler = self._generate_connection_outputtype_handler()
+
+ def on_connect(conn):
+ conn.outputtypehandler = output_type_handler
+
+ return on_connect
+
+ def create_connect_args(self, url):
+ opts = dict(url.query)
+
+ for opt in ("use_ansi", "auto_convert_lobs"):
+ if opt in opts:
+ util.warn_deprecated(
+ "cx_oracle dialect option %r should only be passed to "
+ "create_engine directly, not within the URL string" % opt,
+ version="1.3",
+ )
+ util.coerce_kw_type(opts, opt, bool)
+ setattr(self, opt, opts.pop(opt))
+
+ database = url.database
+ service_name = opts.pop("service_name", None)
+ if database or service_name:
+ # if we have a database, then we have a remote host
+ port = url.port
+ if port:
+ port = int(port)
+ else:
+ port = 1521
+
+ if database and service_name:
+ raise exc.InvalidRequestError(
+ '"service_name" option shouldn\'t '
+ 'be used with a "database" part of the url'
+ )
+ if database:
+ makedsn_kwargs = {"sid": database}
+ if service_name:
+ makedsn_kwargs = {"service_name": service_name}
+
+ dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs)
+ else:
+ # we have a local tnsname
+ dsn = url.host
+
+ if dsn is not None:
+ opts["dsn"] = dsn
+ if url.password is not None:
+ opts["password"] = url.password
+ if url.username is not None:
+ opts["user"] = url.username
+
+ if self._cx_oracle_threaded is not None:
+ opts.setdefault("threaded", self._cx_oracle_threaded)
+
+ def convert_cx_oracle_constant(value):
+ if isinstance(value, util.string_types):
+ try:
+ int_val = int(value)
+ except ValueError:
+ value = value.upper()
+ return getattr(self.dbapi, value)
+ else:
+ return int_val
+ else:
+ return value
+
+ util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant)
+ util.coerce_kw_type(opts, "threaded", bool)
+ util.coerce_kw_type(opts, "events", bool)
+ util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant)
+ return ([], opts)
+
+ def _get_server_version_info(self, connection):
+ return tuple(int(x) for x in connection.connection.version.split("."))
+
+ def is_disconnect(self, e, connection, cursor):
+ (error,) = e.args
+ if isinstance(
+ e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError)
+ ) and "not connected" in str(e):
+ return True
+
+ if hasattr(error, "code") and error.code in {
+ 28,
+ 3114,
+ 3113,
+ 3135,
+ 1033,
+ 2396,
+ }:
+ # ORA-00028: your session has been killed
+ # ORA-03114: not connected to ORACLE
+ # ORA-03113: end-of-file on communication channel
+ # ORA-03135: connection lost contact
+ # ORA-01033: ORACLE initialization or shutdown in progress
+ # ORA-02396: exceeded maximum idle time, please connect again
+ # TODO: Others ?
+ return True
+
+ if re.match(r"^(?:DPI-1010|DPI-1080|DPY-1001|DPY-4011)", str(e)):
+ # DPI-1010: not connected
+ # DPI-1080: connection was closed by ORA-3113
+ # python-oracledb's DPY-1001: not connected to database
+ # python-oracledb's DPY-4011: the database or network closed the
+ # connection
+ # TODO: others?
+ return True
+
+ return False
+
+ def create_xid(self):
+ """create a two-phase transaction ID.
+
+ this id will be passed to do_begin_twophase(), do_rollback_twophase(),
+ do_commit_twophase(). its format is unspecified.
+
+ """
+
+ id_ = random.randint(0, 2 ** 128)
+ return (0x1234, "%032x" % id_, "%032x" % 9)
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ if isinstance(parameters, tuple):
+ parameters = list(parameters)
+ cursor.executemany(statement, parameters)
+
+ def do_begin_twophase(self, connection, xid):
+ connection.connection.begin(*xid)
+ connection.connection.info["cx_oracle_xid"] = xid
+
+ def do_prepare_twophase(self, connection, xid):
+ result = connection.connection.prepare()
+ connection.info["cx_oracle_prepared"] = result
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ self.do_rollback(connection.connection)
+ # TODO: need to end XA state here
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+
+ if not is_prepared:
+ self.do_commit(connection.connection)
+ else:
+ if recover:
+ raise NotImplementedError(
+ "2pc recovery not implemented for cx_Oracle"
+ )
+ oci_prepared = connection.info["cx_oracle_prepared"]
+ if oci_prepared:
+ self.do_commit(connection.connection)
+ # TODO: need to end XA state here
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ # not usually used, here to support if someone is modifying
+ # the dialect to use positional style
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ collection = (
+ (key, dbtype)
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ )
+
+ if not self.supports_unicode_binds:
+ # oracle 8 only
+ collection = (
+ (self.dialect._encoder(key)[0], dbtype)
+ for key, dbtype in collection
+ )
+
+ cursor.setinputsizes(**{key: dbtype for key, dbtype in collection})
+
+ def do_recover_twophase(self, connection):
+ raise NotImplementedError(
+ "recover two phase query for cx_Oracle not implemented"
+ )
+
+
+dialect = OracleDialect_cx_oracle
diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py
new file mode 100644
index 0000000..74ad1f2
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/provision.py
@@ -0,0 +1,160 @@
+from ... import create_engine
+from ... import exc
+from ...engine import url as sa_url
+from ...testing.provision import configure_follower
+from ...testing.provision import create_db
+from ...testing.provision import drop_db
+from ...testing.provision import follower_url_from_main
+from ...testing.provision import log
+from ...testing.provision import post_configure_engine
+from ...testing.provision import run_reap_dbs
+from ...testing.provision import set_default_schema_on_connection
+from ...testing.provision import stop_test_class_outside_fixtures
+from ...testing.provision import temp_table_keyword_args
+
+
+@create_db.for_db("oracle")
+def _oracle_create_db(cfg, eng, ident):
+ # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
+ # similar, so that the default tablespace is not "system"; reflection will
+ # fail otherwise
+ with eng.begin() as conn:
+ conn.exec_driver_sql("create user %s identified by xe" % ident)
+ conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident)
+ conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident)
+ conn.exec_driver_sql("grant dba to %s" % (ident,))
+ conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
+ conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
+ conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
+
+
+@configure_follower.for_db("oracle")
+def _oracle_configure_follower(config, ident):
+ config.test_schema = "%s_ts1" % ident
+ config.test_schema_2 = "%s_ts2" % ident
+
+
+def _ora_drop_ignore(conn, dbname):
+ try:
+ conn.exec_driver_sql("drop user %s cascade" % dbname)
+ log.info("Reaped db: %s", dbname)
+ return True
+ except exc.DatabaseError as err:
+ log.warning("couldn't drop db: %s", err)
+ return False
+
+
+@drop_db.for_db("oracle")
+def _oracle_drop_db(cfg, eng, ident):
+ with eng.begin() as conn:
+ # cx_Oracle seems to occasionally leak open connections when a large
+ # suite is run, even if we confirm we have zero references to
+ # connection objects.
+ # while there is a "kill session" command in Oracle,
+ # it unfortunately does not release the connection sufficiently.
+ _ora_drop_ignore(conn, ident)
+ _ora_drop_ignore(conn, "%s_ts1" % ident)
+ _ora_drop_ignore(conn, "%s_ts2" % ident)
+
+
+@stop_test_class_outside_fixtures.for_db("oracle")
+def stop_test_class_outside_fixtures(config, db, cls):
+
+ try:
+ with db.begin() as conn:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
+ conn.exec_driver_sql("purge recyclebin")
+ except exc.DatabaseError as err:
+ log.warning("purge recyclebin command failed: %s", err)
+
+ # clear statement cache on all connections that were used
+ # https://github.com/oracle/python-cx_Oracle/issues/519
+
+ for cx_oracle_conn in _all_conns:
+ try:
+ sc = cx_oracle_conn.stmtcachesize
+ except db.dialect.dbapi.InterfaceError:
+ # connection closed
+ pass
+ else:
+ cx_oracle_conn.stmtcachesize = 0
+ cx_oracle_conn.stmtcachesize = sc
+ _all_conns.clear()
+
+
+_all_conns = set()
+
+
+@post_configure_engine.for_db("oracle")
+def _oracle_post_configure_engine(url, engine, follower_ident):
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "checkout")
+ def checkout(dbapi_con, con_record, con_proxy):
+ _all_conns.add(dbapi_con)
+
+ @event.listens_for(engine, "checkin")
+ def checkin(dbapi_connection, connection_record):
+ # work around cx_Oracle issue:
+ # https://github.com/oracle/python-cx_Oracle/issues/530
+ # invalidate oracle connections that had 2pc set up
+ if "cx_oracle_xid" in connection_record.info:
+ connection_record.invalidate()
+
+
+@run_reap_dbs.for_db("oracle")
+def _reap_oracle_dbs(url, idents):
+ log.info("db reaper connecting to %r", url)
+ eng = create_engine(url)
+ with eng.begin() as conn:
+
+ log.info("identifiers in file: %s", ", ".join(idents))
+
+ to_reap = conn.exec_driver_sql(
+ "select u.username from all_users u where username "
+ "like 'TEST_%' and not exists (select username "
+ "from v$session where username=u.username)"
+ )
+ all_names = {username.lower() for (username,) in to_reap}
+ to_drop = set()
+ for name in all_names:
+ if name.endswith("_ts1") or name.endswith("_ts2"):
+ continue
+ elif name in idents:
+ to_drop.add(name)
+ if "%s_ts1" % name in all_names:
+ to_drop.add("%s_ts1" % name)
+ if "%s_ts2" % name in all_names:
+ to_drop.add("%s_ts2" % name)
+
+ dropped = total = 0
+ for total, username in enumerate(to_drop, 1):
+ if _ora_drop_ignore(conn, username):
+ dropped += 1
+ log.info(
+ "Dropped %d out of %d stale databases detected", dropped, total
+ )
+
+
+@follower_url_from_main.for_db("oracle")
+def _oracle_follower_url_from_main(url, ident):
+ url = sa_url.make_url(url)
+ return url.set(username=ident, password="xe")
+
+
+@temp_table_keyword_args.for_db("oracle")
+def _oracle_temp_table_keyword_args(cfg, eng):
+ return {
+ "prefixes": ["GLOBAL TEMPORARY"],
+ "oracle_on_commit": "PRESERVE ROWS",
+ }
+
+
+@set_default_schema_on_connection.for_db("oracle")
+def _oracle_set_default_schema_on_connection(
+ cfg, dbapi_connection, schema_name
+):
+ cursor = dbapi_connection.cursor()
+ cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name)
+ cursor.close()
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
new file mode 100644
index 0000000..12d9e94
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -0,0 +1,117 @@
+# postgresql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from . import base
+from . import pg8000 # noqa
+from . import psycopg2 # noqa
+from . import psycopg2cffi # noqa
+from . import pygresql # noqa
+from . import pypostgresql # noqa
+from .array import All
+from .array import Any
+from .array import ARRAY
+from .array import array
+from .base import BIGINT
+from .base import BIT
+from .base import BOOLEAN
+from .base import BYTEA
+from .base import CHAR
+from .base import CIDR
+from .base import CreateEnumType
+from .base import DATE
+from .base import DOUBLE_PRECISION
+from .base import DropEnumType
+from .base import ENUM
+from .base import FLOAT
+from .base import INET
+from .base import INTEGER
+from .base import INTERVAL
+from .base import MACADDR
+from .base import MONEY
+from .base import NUMERIC
+from .base import OID
+from .base import REAL
+from .base import REGCLASS
+from .base import SMALLINT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import TSVECTOR
+from .base import UUID
+from .base import VARCHAR
+from .dml import Insert
+from .dml import insert
+from .ext import aggregate_order_by
+from .ext import array_agg
+from .ext import ExcludeConstraint
+from .hstore import HSTORE
+from .hstore import hstore
+from .json import JSON
+from .json import JSONB
+from .ranges import DATERANGE
+from .ranges import INT4RANGE
+from .ranges import INT8RANGE
+from .ranges import NUMRANGE
+from .ranges import TSRANGE
+from .ranges import TSTZRANGE
+from ...util import compat
+
+if compat.py3k:
+ from . import asyncpg # noqa
+
+base.dialect = dialect = psycopg2.dialect
+
+
+__all__ = (
+ "INTEGER",
+ "BIGINT",
+ "SMALLINT",
+ "VARCHAR",
+ "CHAR",
+ "TEXT",
+ "NUMERIC",
+ "FLOAT",
+ "REAL",
+ "INET",
+ "CIDR",
+ "UUID",
+ "BIT",
+ "MACADDR",
+ "MONEY",
+ "OID",
+ "REGCLASS",
+ "DOUBLE_PRECISION",
+ "TIMESTAMP",
+ "TIME",
+ "DATE",
+ "BYTEA",
+ "BOOLEAN",
+ "INTERVAL",
+ "ARRAY",
+ "ENUM",
+ "dialect",
+ "array",
+ "HSTORE",
+ "hstore",
+ "INT4RANGE",
+ "INT8RANGE",
+ "NUMRANGE",
+ "DATERANGE",
+ "TSVECTOR",
+ "TSRANGE",
+ "TSTZRANGE",
+ "JSON",
+ "JSONB",
+ "Any",
+ "All",
+ "DropEnumType",
+ "CreateEnumType",
+ "ExcludeConstraint",
+ "aggregate_order_by",
+ "array_agg",
+ "insert",
+ "Insert",
+)
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
new file mode 100644
index 0000000..daf7c5d
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -0,0 +1,413 @@
+# postgresql/array.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from ... import types as sqltypes
+from ... import util
+from ...sql import coercions
+from ...sql import expression
+from ...sql import operators
+from ...sql import roles
+
+
+def Any(other, arrexpr, operator=operators.eq):
+ """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
+ See that method for details.
+
+ """
+
+ return arrexpr.any(other, operator)
+
+
+def All(other, arrexpr, operator=operators.eq):
+ """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
+ See that method for details.
+
+ """
+
+ return arrexpr.all(other, operator)
+
+
+class array(expression.ClauseList, expression.ColumnElement):
+
+ """A PostgreSQL ARRAY literal.
+
+ This is used to produce ARRAY literals in SQL expressions, e.g.::
+
+ from sqlalchemy.dialects.postgresql import array
+ from sqlalchemy.dialects import postgresql
+ from sqlalchemy import select, func
+
+ stmt = select(array([1,2]) + array([3,4,5]))
+
+ print(stmt.compile(dialect=postgresql.dialect()))
+
+ Produces the SQL::
+
+ SELECT ARRAY[%(param_1)s, %(param_2)s] ||
+ ARRAY[%(param_3)s, %(param_4)s, %(param_5)s] AS anon_1
+
+ An instance of :class:`.array` will always have the datatype
+ :class:`_types.ARRAY`. The "inner" type of the array is inferred from
+ the values present, unless the ``type_`` keyword argument is passed::
+
+ array(['foo', 'bar'], type_=CHAR)
+
+ Multidimensional arrays are produced by nesting :class:`.array` constructs.
+ The dimensionality of the final :class:`_types.ARRAY`
+ type is calculated by
+ recursively adding the dimensions of the inner :class:`_types.ARRAY`
+ type::
+
+ stmt = select(
+ array([
+ array([1, 2]), array([3, 4]), array([column('q'), column('x')])
+ ])
+ )
+ print(stmt.compile(dialect=postgresql.dialect()))
+
+ Produces::
+
+ SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
+ ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
+
+ .. versionadded:: 1.3.6 added support for multidimensional array literals
+
+ .. seealso::
+
+ :class:`_postgresql.ARRAY`
+
+ """
+
+ __visit_name__ = "array"
+
+ stringify_dialect = "postgresql"
+ inherit_cache = True
+
+ def __init__(self, clauses, **kw):
+ clauses = [
+ coercions.expect(roles.ExpressionElementRole, c) for c in clauses
+ ]
+
+ super(array, self).__init__(*clauses, **kw)
+
+ self._type_tuple = [arg.type for arg in clauses]
+ main_type = kw.pop(
+ "type_",
+ self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE,
+ )
+
+ if isinstance(main_type, ARRAY):
+ self.type = ARRAY(
+ main_type.item_type,
+ dimensions=main_type.dimensions + 1
+ if main_type.dimensions is not None
+ else 2,
+ )
+ else:
+ self.type = ARRAY(main_type)
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
+ if _assume_scalar or operator is operators.getitem:
+ return expression.BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ )
+
+ else:
+ return array(
+ [
+ self._bind_param(
+ operator, o, _assume_scalar=True, type_=type_
+ )
+ for o in obj
+ ]
+ )
+
+ def self_group(self, against=None):
+ if against in (operators.any_op, operators.all_op, operators.getitem):
+ return expression.Grouping(self)
+ else:
+ return self
+
+
+CONTAINS = operators.custom_op("@>", precedence=5, is_comparison=True)
+
+CONTAINED_BY = operators.custom_op("<@", precedence=5, is_comparison=True)
+
+OVERLAP = operators.custom_op("&&", precedence=5, is_comparison=True)
+
+
+class ARRAY(sqltypes.ARRAY):
+
+ """PostgreSQL ARRAY type.
+
+ .. versionchanged:: 1.1 The :class:`_postgresql.ARRAY` type is now
+ a subclass of the core :class:`_types.ARRAY` type.
+
+ The :class:`_postgresql.ARRAY` type is constructed in the same way
+ as the core :class:`_types.ARRAY` type; a member type is required, and a
+ number of dimensions is recommended if the type is to be used for more
+ than one dimension::
+
+ from sqlalchemy.dialects import postgresql
+
+ mytable = Table("mytable", metadata,
+ Column("data", postgresql.ARRAY(Integer, dimensions=2))
+ )
+
+ The :class:`_postgresql.ARRAY` type provides all operations defined on the
+ core :class:`_types.ARRAY` type, including support for "dimensions",
+ indexed access, and simple matching such as
+ :meth:`.types.ARRAY.Comparator.any` and
+ :meth:`.types.ARRAY.Comparator.all`. :class:`_postgresql.ARRAY`
+ class also
+ provides PostgreSQL-specific methods for containment operations, including
+ :meth:`.postgresql.ARRAY.Comparator.contains`
+ :meth:`.postgresql.ARRAY.Comparator.contained_by`, and
+ :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.::
+
+ mytable.c.data.contains([1, 2])
+
+ The :class:`_postgresql.ARRAY` type may not be supported on all
+ PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
+
+ Additionally, the :class:`_postgresql.ARRAY`
+ type does not work directly in
+ conjunction with the :class:`.ENUM` type. For a workaround, see the
+ special type at :ref:`postgresql_array_of_enum`.
+
+ .. seealso::
+
+ :class:`_types.ARRAY` - base array type
+
+ :class:`_postgresql.array` - produces a literal array value.
+
+ """
+
+ class Comparator(sqltypes.ARRAY.Comparator):
+
+ """Define comparison operations for :class:`_types.ARRAY`.
+
+ Note that these operations are in addition to those provided
+ by the base :class:`.types.ARRAY.Comparator` class, including
+ :meth:`.types.ARRAY.Comparator.any` and
+ :meth:`.types.ARRAY.Comparator.all`.
+
+ """
+
+ def contains(self, other, **kwargs):
+ """Boolean expression. Test if elements are a superset of the
+ elements of the argument array expression.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
+
+ def contained_by(self, other):
+ """Boolean expression. Test if elements are a proper subset of the
+ elements of the argument array expression.
+ """
+ return self.operate(
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
+
+ def overlap(self, other):
+ """Boolean expression. Test if array has elements in common with
+ an argument array expression.
+ """
+ return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
+
+ comparator_factory = Comparator
+
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
+ """Construct an ARRAY.
+
+ E.g.::
+
+ Column('myarray', ARRAY(Integer))
+
+ Arguments are:
+
+ :param item_type: The data type of items of this array. Note that
+ dimensionality is irrelevant here, so multi-dimensional arrays like
+ ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as
+ ``ARRAY(ARRAY(Integer))`` or such.
+
+ :param as_tuple=False: Specify whether return results
+ should be converted to tuples from lists. DBAPIs such
+ as psycopg2 return lists by default. When tuples are
+ returned, the results are hashable.
+
+ :param dimensions: if non-None, the ARRAY will assume a fixed
+ number of dimensions. This will cause the DDL emitted for this
+ ARRAY to include the exact number of bracket clauses ``[]``,
+ and will also optimize the performance of the type overall.
+ Note that PG arrays are always implicitly "non-dimensioned",
+ meaning they can store any number of dimensions no matter how
+ they were declared.
+
+ :param zero_indexes=False: when True, index values will be converted
+ between Python zero-based and PostgreSQL one-based indexes, e.g.
+ a value of one will be added to all index values before passing
+ to the database.
+
+ .. versionadded:: 0.9.5
+
+
+ """
+ if isinstance(item_type, ARRAY):
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+ self.as_tuple = as_tuple
+ self.dimensions = dimensions
+ self.zero_indexes = zero_indexes
+
+ @property
+ def hashable(self):
+ return self.as_tuple
+
+ @property
+ def python_type(self):
+ return list
+
+ def compare_values(self, x, y):
+ return x == y
+
+ def _proc_array(self, arr, itemproc, dim, collection):
+ if dim is None:
+ arr = list(arr)
+ if (
+ dim == 1
+ or dim is None
+ and (
+ # this has to be (list, tuple), or at least
+ # not hasattr('__iter__'), since Py3K strings
+ # etc. have __iter__
+ not arr
+ or not isinstance(arr[0], (list, tuple))
+ )
+ ):
+ if itemproc:
+ return collection(itemproc(x) for x in arr)
+ else:
+ return collection(arr)
+ else:
+ return collection(
+ self._proc_array(
+ x,
+ itemproc,
+ dim - 1 if dim is not None else None,
+ collection,
+ )
+ for x in arr
+ )
+
+ @util.memoized_property
+ def _against_native_enum(self):
+ return (
+ isinstance(self.item_type, sqltypes.Enum)
+ and self.item_type.native_enum
+ )
+
+ def bind_expression(self, bindvalue):
+ return bindvalue
+
+ def bind_processor(self, dialect):
+ item_proc = self.item_type.dialect_impl(dialect).bind_processor(
+ dialect
+ )
+
+ def process(value):
+ if value is None:
+ return value
+ else:
+ return self._proc_array(
+ value, item_proc, self.dimensions, list
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ item_proc = self.item_type.dialect_impl(dialect).result_processor(
+ dialect, coltype
+ )
+
+ def process(value):
+ if value is None:
+ return value
+ else:
+ return self._proc_array(
+ value,
+ item_proc,
+ self.dimensions,
+ tuple if self.as_tuple else list,
+ )
+
+ if self._against_native_enum:
+ super_rp = process
+ pattern = re.compile(r"^{(.*)}$")
+
+ def handle_raw_string(value):
+ inner = pattern.match(value).group(1)
+ return _split_enum_values(inner)
+
+ def process(value):
+ if value is None:
+ return value
+ # isinstance(value, util.string_types) is required to handle
+ # the case where a TypeDecorator for an Array of Enum is
+ # used, as was required in sa < 1.3.17
+ return super_rp(
+ handle_raw_string(value)
+ if isinstance(value, util.string_types)
+ else value
+ )
+
+ return process
+
+
+def _split_enum_values(array_string):
+
+ if '"' not in array_string:
+ # no escape char is present so it can just split on the comma
+ return array_string.split(",") if array_string else []
+
+ # handles quoted strings from:
+ # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr'
+ # returns
+ # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr']
+ text = array_string.replace(r"\"", "_$ESC_QUOTE$_")
+ text = text.replace(r"\\", "\\")
+ result = []
+ on_quotes = re.split(r'(")', text)
+ in_quotes = False
+ for tok in on_quotes:
+ if tok == '"':
+ in_quotes = not in_quotes
+ elif in_quotes:
+ result.append(tok.replace("_$ESC_QUOTE$_", '"'))
+ else:
+ result.extend(re.findall(r"([^\s,]+),?", tok))
+ return result
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
new file mode 100644
index 0000000..305ad46
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -0,0 +1,1112 @@
+# postgresql/asyncpg.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+asyncpg
+ :name: asyncpg
+ :dbapi: asyncpg
+ :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://magicstack.github.io/asyncpg/
+
+The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.
+
+Using a special asyncio mediation layer, the asyncpg dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")
+
+The dialect can also be run as a "synchronous" dialect within the
+:func:`_sa.create_engine` function, which will pass "await" calls into
+an ad-hoc event loop. This mode of operation is of **limited use**
+and is for special testing scenarios only. The mode can be enabled by
+adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
+in conjunction with :func:`_sa.create_engine`::
+
+ # for testing purposes only; do not use in production!
+ engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")
+
+
+.. versionadded:: 1.4
+
+.. note::
+
+ By default asyncpg does not decode the ``json`` and ``jsonb`` types and
+ returns them as strings. SQLAlchemy sets default type decoder for ``json``
+ and ``jsonb`` types using the python builtin ``json.loads`` function.
+ The json implementation used can be changed by setting the attribute
+ ``json_deserializer`` when creating the engine with
+ :func:`create_engine` or :func:`create_async_engine`.
+
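+A minimal sketch of overriding the deserializer (the function shown is purely
+illustrative)::
+
+    import decimal
+    import json
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+
+    def as_decimal_json(value):
+        # parse JSON numbers with a fractional part as decimal.Decimal
+        return json.loads(value, parse_float=decimal.Decimal)
+
+    engine = create_async_engine(
+        "postgresql+asyncpg://user:pass@hostname/dbname",
+        json_deserializer=as_decimal_json,
+    )
+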
+
+.. _asyncpg_prepared_statement_cache:
+
+Prepared Statement Cache
+--------------------------
+
+The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
+for all statements. The prepared statement objects are cached after
+construction which appears to grant a 10% or more performance improvement for
+statement invocation. The cache is on a per-DBAPI connection basis, which
+means that the primary storage for prepared statements is within DBAPI
+connections pooled within the connection pool. The size of this cache
+defaults to 100 statements per DBAPI connection and may be adjusted using the
+``prepared_statement_cache_size`` DBAPI argument (note that while this argument
+is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the
+asyncpg dialect, and is therefore handled as a DBAPI argument, not a dialect
+argument)::
+
+
+ engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")
+
+To disable the prepared statement cache, use a value of zero::
+
+ engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")
+
+.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
+
+
+.. warning:: The ``asyncpg`` database driver necessarily uses caches for
+ PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
+ such as ``ENUM`` objects are changed via DDL operations. Additionally,
+ prepared statements themselves which are optionally cached by SQLAlchemy's
+ driver as described above may also become "stale" when DDL has been emitted
+ to the PostgreSQL database which modifies the tables or other objects
+ involved in a particular prepared statement.
+
+ The SQLAlchemy asyncpg dialect will invalidate these caches within its local
+ process when statements that represent DDL are emitted on a local
+ connection, but this is only controllable within a single Python process /
+ database engine. If DDL changes are made from other database engines
+ and/or processes, a running application may encounter asyncpg exceptions
+ ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
+ failed for type <oid>")`` if it refers to pooled database connections which
+ operated upon the previous structures. The SQLAlchemy asyncpg dialect will
+ recover from these error cases when the driver raises these exceptions by
+ clearing its internal caches as well as those of the asyncpg driver in
+ response to them, but cannot prevent them from being raised in the first
+ place if the cached prepared statement or asyncpg type caches have gone
+ stale, nor can it retry the statement as the PostgreSQL transaction is
+ invalidated when these errors occur.
+
+Disabling the PostgreSQL JIT to improve ENUM datatype handling
+---------------------------------------------------------------
+
+Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when
+using PostgreSQL ENUM datatypes, where upon the creation of new database
+connections, an expensive query may be emitted in order to retrieve metadata
+regarding custom types, which has been shown to negatively affect performance.
+To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the
+client using this setting passed to :func:`_asyncio.create_async_engine`::
+
+ engine = create_async_engine(
+ "postgresql+asyncpg://user:password@localhost/tmp",
+ connect_args={"server_settings": {"jit": "off"}},
+ )
+
+.. seealso::
+
+ https://github.com/MagicStack/asyncpg/issues/727
+
+""" # noqa
+
+import collections
+import decimal
+import json as _py_json
+import re
+import time
+
+from . import json
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import ENUM
+from .base import INTERVAL
+from .base import OID
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import PGIdentifierPreparer
+from .base import REGCLASS
+from .base import UUID
+from ... import exc
+from ... import pool
+from ... import processors
+from ... import util
+from ...engine import AdaptedConnection
+from ...sql import sqltypes
+from ...util.concurrency import asyncio
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+try:
+ from uuid import UUID as _python_UUID # noqa
+except ImportError:
+ _python_UUID = None
+
+
+class AsyncpgTime(sqltypes.Time):
+ def get_dbapi_type(self, dbapi):
+ if self.timezone:
+ return dbapi.TIME_W_TZ
+ else:
+ return dbapi.TIME
+
+
+class AsyncpgDate(sqltypes.Date):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATE
+
+
+class AsyncpgDateTime(sqltypes.DateTime):
+ def get_dbapi_type(self, dbapi):
+ if self.timezone:
+ return dbapi.TIMESTAMP_W_TZ
+ else:
+ return dbapi.TIMESTAMP
+
+
+class AsyncpgBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BOOLEAN
+
+
+class AsyncPgInterval(INTERVAL):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTERVAL
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+
+ return AsyncPgInterval(precision=interval.second_precision)
+
+
+class AsyncPgEnum(ENUM):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.ENUM
+
+
+class AsyncpgInteger(sqltypes.Integer):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class AsyncpgBigInteger(sqltypes.BigInteger):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BIGINTEGER
+
+
+class AsyncpgJSON(json.JSON):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSON
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class AsyncpgJSONB(json.JSONB):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSONB
+
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
+ def get_dbapi_type(self, dbapi):
+ raise NotImplementedError("should not be here")
+
+
+class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+
+class AsyncpgJSONPathType(json.JSONPathType):
+ def bind_processor(self, dialect):
+ def process(value):
+ assert isinstance(value, util.collections_abc.Sequence)
+ tokens = [util.text_type(elem) for elem in value]
+ return tokens
+
+ return process
+
+
+class AsyncpgUUID(UUID):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.UUID
+
+ def bind_processor(self, dialect):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = str(value)
+ return value
+
+ return process
+
+
+class AsyncpgNumeric(sqltypes.Numeric):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # pg8000 returns Decimal natively for 1700
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # pg8000 returns float natively for 701
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class AsyncpgFloat(AsyncpgNumeric):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.FLOAT
+
+
+class AsyncpgREGCLASS(REGCLASS):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+
+class AsyncpgOID(OID):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class PGExecutionContext_asyncpg(PGExecutionContext):
+ def handle_dbapi_exception(self, e):
+ if isinstance(
+ e,
+ (
+ self.dialect.dbapi.InvalidCachedStatementError,
+ self.dialect.dbapi.InternalServerError,
+ ),
+ ):
+ self.dialect._invalidate_schema_cache()
+
+ def pre_exec(self):
+ if self.isddl:
+ self.dialect._invalidate_schema_cache()
+
+ self.cursor._invalidate_schema_cache_asof = (
+ self.dialect._invalidate_schema_cache_asof
+ )
+
+ if not self.compiled:
+ return
+
+ # we have to exclude ENUM because "enum" not really a "type"
+ # we can cast to, it has to be the name of the type itself.
+ # for now we just omit it from casting
+ self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
+
+ def create_server_side_cursor(self):
+ return self._dbapi_connection.cursor(server_side=True)
+
+
+class PGCompiler_asyncpg(PGCompiler):
+ pass
+
+
+class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
+ pass
+
+
+class AsyncAdapt_asyncpg_cursor:
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "_rows",
+ "description",
+ "arraysize",
+ "rowcount",
+ "_inputsizes",
+ "_cursor",
+ "_invalidate_schema_cache_asof",
+ )
+
+ server_side = False
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self._rows = []
+ self._cursor = None
+ self.description = None
+ self.arraysize = 1
+ self.rowcount = -1
+ self._inputsizes = None
+ self._invalidate_schema_cache_asof = 0
+
+ def close(self):
+ self._rows[:] = []
+
+ def _handle_exception(self, error):
+ self._adapt_connection._handle_exception(error)
+
+ def _parameter_placeholders(self, params):
+ if not self._inputsizes:
+ return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
+ else:
+ return tuple(
+ "$%d::%s" % (idx, typ) if typ else "$%d" % idx
+ for idx, typ in enumerate(
+ (_pg_types.get(typ) for typ in self._inputsizes), 1
+ )
+ )
+
+ async def _prepare_and_execute(self, operation, parameters):
+ adapt_connection = self._adapt_connection
+
+ async with adapt_connection._execute_mutex:
+
+ if not adapt_connection._started:
+ await adapt_connection._start_transaction()
+
+ if parameters is not None:
+ operation = operation % self._parameter_placeholders(
+ parameters
+ )
+ else:
+ parameters = ()
+
+ try:
+ prepared_stmt, attributes = await adapt_connection._prepare(
+ operation, self._invalidate_schema_cache_asof
+ )
+
+ if attributes:
+ self.description = [
+ (
+ attr.name,
+ attr.type.oid,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+ for attr in attributes
+ ]
+ else:
+ self.description = None
+
+ if self.server_side:
+ self._cursor = await prepared_stmt.cursor(*parameters)
+ self.rowcount = -1
+ else:
+ self._rows = await prepared_stmt.fetch(*parameters)
+ status = prepared_stmt.get_statusmsg()
+
+ reg = re.match(
+ r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status
+ )
+ if reg:
+ self.rowcount = int(reg.group(1))
+ else:
+ self.rowcount = -1
+
+ except Exception as error:
+ self._handle_exception(error)
+
+ async def _executemany(self, operation, seq_of_parameters):
+ adapt_connection = self._adapt_connection
+
+ async with adapt_connection._execute_mutex:
+ await adapt_connection._check_type_cache_invalidation(
+ self._invalidate_schema_cache_asof
+ )
+
+ if not adapt_connection._started:
+ await adapt_connection._start_transaction()
+
+ operation = operation % self._parameter_placeholders(
+ seq_of_parameters[0]
+ )
+
+ try:
+ return await self._connection.executemany(
+ operation, seq_of_parameters
+ )
+ except Exception as error:
+ self._handle_exception(error)
+
+ def execute(self, operation, parameters=None):
+ self._adapt_connection.await_(
+ self._prepare_and_execute(operation, parameters)
+ )
+
+ def executemany(self, operation, seq_of_parameters):
+ return self._adapt_connection.await_(
+ self._executemany(operation, seq_of_parameters)
+ )
+
+ def setinputsizes(self, *inputsizes):
+ self._inputsizes = inputsizes
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
+
+ server_side = True
+ __slots__ = ("_rowbuffer",)
+
+ def __init__(self, adapt_connection):
+ super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection)
+ self._rowbuffer = None
+
+ def close(self):
+ self._cursor = None
+ self._rowbuffer = None
+
+ def _buffer_rows(self):
+ new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
+ self._rowbuffer = collections.deque(new_rows)
+
+    def __iter__(self):
+        # iterate synchronously, refilling the buffer from the server side
+        # cursor in batches as it is drained (matches the base class's
+        # synchronous iteration protocol)
+        while True:
+            while self._rowbuffer:
+                yield self._rowbuffer.popleft()
+
+            self._buffer_rows()
+            if not self._rowbuffer:
+                break
+
+ def fetchone(self):
+ if not self._rowbuffer:
+ self._buffer_rows()
+ if not self._rowbuffer:
+ return None
+ return self._rowbuffer.popleft()
+
+ def fetchmany(self, size=None):
+ if size is None:
+ return self.fetchall()
+
+ if not self._rowbuffer:
+ self._buffer_rows()
+
+ buf = list(self._rowbuffer)
+ lb = len(buf)
+ if size > lb:
+ buf.extend(
+ self._adapt_connection.await_(self._cursor.fetch(size - lb))
+ )
+
+ result = buf[0:size]
+ self._rowbuffer = collections.deque(buf[size:])
+ return result
+
+ def fetchall(self):
+ ret = list(self._rowbuffer) + list(
+ self._adapt_connection.await_(self._all())
+ )
+ self._rowbuffer.clear()
+ return ret
+
+ async def _all(self):
+ rows = []
+
+ # TODO: looks like we have to hand-roll some kind of batching here.
+ # hardcoding for the moment but this should be improved.
+ while True:
+ batch = await self._cursor.fetch(1000)
+ if batch:
+ rows.extend(batch)
+ continue
+ else:
+ break
+ return rows
+
+ def executemany(self, operation, seq_of_parameters):
+ raise NotImplementedError(
+ "server side cursor doesn't support executemany yet"
+ )
+
+
+class AsyncAdapt_asyncpg_connection(AdaptedConnection):
+ __slots__ = (
+ "dbapi",
+ "_connection",
+ "isolation_level",
+ "_isolation_setting",
+ "readonly",
+ "deferrable",
+ "_transaction",
+ "_started",
+ "_prepared_statement_cache",
+ "_invalidate_schema_cache_asof",
+ "_execute_mutex",
+ )
+
+ await_ = staticmethod(await_only)
+
+ def __init__(self, dbapi, connection, prepared_statement_cache_size=100):
+ self.dbapi = dbapi
+ self._connection = connection
+ self.isolation_level = self._isolation_setting = "read_committed"
+ self.readonly = False
+ self.deferrable = False
+ self._transaction = None
+ self._started = False
+ self._invalidate_schema_cache_asof = time.time()
+ self._execute_mutex = asyncio.Lock()
+
+ if prepared_statement_cache_size:
+ self._prepared_statement_cache = util.LRUCache(
+ prepared_statement_cache_size
+ )
+ else:
+ self._prepared_statement_cache = None
+
+ async def _check_type_cache_invalidation(self, invalidate_timestamp):
+ if invalidate_timestamp > self._invalidate_schema_cache_asof:
+ await self._connection.reload_schema_state()
+ self._invalidate_schema_cache_asof = invalidate_timestamp
+
+ async def _prepare(self, operation, invalidate_timestamp):
+ await self._check_type_cache_invalidation(invalidate_timestamp)
+
+ cache = self._prepared_statement_cache
+ if cache is None:
+ prepared_stmt = await self._connection.prepare(operation)
+ attributes = prepared_stmt.get_attributes()
+ return prepared_stmt, attributes
+
+ # asyncpg uses a type cache for the "attributes" which seems to go
+ # stale independently of the PreparedStatement itself, so place that
+ # collection in the cache as well.
+ if operation in cache:
+ prepared_stmt, attributes, cached_timestamp = cache[operation]
+
+ # preparedstatements themselves also go stale for certain DDL
+ # changes such as size of a VARCHAR changing, so there is also
+ # a cross-connection invalidation timestamp
+ if cached_timestamp > invalidate_timestamp:
+ return prepared_stmt, attributes
+
+ prepared_stmt = await self._connection.prepare(operation)
+ attributes = prepared_stmt.get_attributes()
+ cache[operation] = (prepared_stmt, attributes, time.time())
+
+ return prepared_stmt, attributes
+
+ def _handle_exception(self, error):
+ if self._connection.is_closed():
+ self._transaction = None
+ self._started = False
+
+ if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
+ exception_mapping = self.dbapi._asyncpg_error_translate
+
+ for super_ in type(error).__mro__:
+ if super_ in exception_mapping:
+ translated_error = exception_mapping[super_](
+ "%s: %s" % (type(error), error)
+ )
+ translated_error.pgcode = (
+ translated_error.sqlstate
+ ) = getattr(error, "sqlstate", None)
+ raise translated_error from error
+ else:
+ raise error
+ else:
+ raise error
+
+ @property
+ def autocommit(self):
+ return self.isolation_level == "autocommit"
+
+ @autocommit.setter
+ def autocommit(self, value):
+ if value:
+ self.isolation_level = "autocommit"
+ else:
+ self.isolation_level = self._isolation_setting
+
+ def set_isolation_level(self, level):
+ if self._started:
+ self.rollback()
+ self.isolation_level = self._isolation_setting = level
+
+ async def _start_transaction(self):
+ if self.isolation_level == "autocommit":
+ return
+
+ try:
+ self._transaction = self._connection.transaction(
+ isolation=self.isolation_level,
+ readonly=self.readonly,
+ deferrable=self.deferrable,
+ )
+ await self._transaction.start()
+ except Exception as error:
+ self._handle_exception(error)
+ else:
+ self._started = True
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_asyncpg_ss_cursor(self)
+ else:
+ return AsyncAdapt_asyncpg_cursor(self)
+
+ def rollback(self):
+ if self._started:
+ try:
+ self.await_(self._transaction.rollback())
+ except Exception as error:
+ self._handle_exception(error)
+ finally:
+ self._transaction = None
+ self._started = False
+
+ def commit(self):
+ if self._started:
+ try:
+ self.await_(self._transaction.commit())
+ except Exception as error:
+ self._handle_exception(error)
+ finally:
+ self._transaction = None
+ self._started = False
+
+ def close(self):
+ self.rollback()
+
+ self.await_(self._connection.close())
+
+
+class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_asyncpg_dbapi:
+ def __init__(self, asyncpg):
+ self.asyncpg = asyncpg
+ self.paramstyle = "format"
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+ prepared_statement_cache_size = kw.pop(
+ "prepared_statement_cache_size", 100
+ )
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_asyncpg_connection(
+ self,
+ await_fallback(self.asyncpg.connect(*arg, **kw)),
+ prepared_statement_cache_size=prepared_statement_cache_size,
+ )
+ else:
+ return AsyncAdapt_asyncpg_connection(
+ self,
+ await_only(self.asyncpg.connect(*arg, **kw)),
+ prepared_statement_cache_size=prepared_statement_cache_size,
+ )
+
+ class Error(Exception):
+ pass
+
+ class Warning(Exception): # noqa
+ pass
+
+ class InterfaceError(Error):
+ pass
+
+ class DatabaseError(Error):
+ pass
+
+ class InternalError(DatabaseError):
+ pass
+
+ class OperationalError(DatabaseError):
+ pass
+
+ class ProgrammingError(DatabaseError):
+ pass
+
+ class IntegrityError(DatabaseError):
+ pass
+
+ class DataError(DatabaseError):
+ pass
+
+ class NotSupportedError(DatabaseError):
+ pass
+
+ class InternalServerError(InternalError):
+ pass
+
+ class InvalidCachedStatementError(NotSupportedError):
+ def __init__(self, message):
+ super(
+ AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self
+ ).__init__(
+ message + " (SQLAlchemy asyncpg dialect will now invalidate "
+ "all prepared caches in response to this exception)",
+ )
+
+ @util.memoized_property
+ def _asyncpg_error_translate(self):
+ import asyncpg
+
+ return {
+ asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa: E501
+ asyncpg.exceptions.PostgresError: self.Error,
+ asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError,
+ asyncpg.exceptions.InterfaceError: self.InterfaceError,
+ asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501
+ asyncpg.exceptions.InternalServerError: self.InternalServerError,
+ }
+
+ def Binary(self, value):
+ return value
+
+ STRING = util.symbol("STRING")
+ TIMESTAMP = util.symbol("TIMESTAMP")
+ TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
+ TIME = util.symbol("TIME")
+ TIME_W_TZ = util.symbol("TIME_W_TZ")
+ DATE = util.symbol("DATE")
+ INTERVAL = util.symbol("INTERVAL")
+ NUMBER = util.symbol("NUMBER")
+ FLOAT = util.symbol("FLOAT")
+ BOOLEAN = util.symbol("BOOLEAN")
+ INTEGER = util.symbol("INTEGER")
+ BIGINTEGER = util.symbol("BIGINTEGER")
+ BYTES = util.symbol("BYTES")
+ DECIMAL = util.symbol("DECIMAL")
+ JSON = util.symbol("JSON")
+ JSONB = util.symbol("JSONB")
+ ENUM = util.symbol("ENUM")
+ UUID = util.symbol("UUID")
+ BYTEA = util.symbol("BYTEA")
+
+ DATETIME = TIMESTAMP
+ BINARY = BYTEA
+
+
+_pg_types = {
+ AsyncAdapt_asyncpg_dbapi.STRING: "varchar",
+ AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp",
+ AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
+ AsyncAdapt_asyncpg_dbapi.DATE: "date",
+ AsyncAdapt_asyncpg_dbapi.TIME: "time",
+ AsyncAdapt_asyncpg_dbapi.TIME_W_TZ: "time with time zone",
+ AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
+ AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
+ AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
+ AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool",
+ AsyncAdapt_asyncpg_dbapi.INTEGER: "integer",
+ AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint",
+ AsyncAdapt_asyncpg_dbapi.BYTES: "bytes",
+ AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal",
+ AsyncAdapt_asyncpg_dbapi.JSON: "json",
+ AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb",
+ AsyncAdapt_asyncpg_dbapi.ENUM: "enum",
+ AsyncAdapt_asyncpg_dbapi.UUID: "uuid",
+ AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea",
+}
+
+
+class PGDialect_asyncpg(PGDialect):
+ driver = "asyncpg"
+ supports_statement_cache = True
+
+ supports_unicode_statements = True
+ supports_server_side_cursors = True
+
+ supports_unicode_binds = True
+
+ default_paramstyle = "format"
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = PGExecutionContext_asyncpg
+ statement_compiler = PGCompiler_asyncpg
+ preparer = PGIdentifierPreparer_asyncpg
+
+ use_setinputsizes = True
+
+ use_native_uuid = True
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Time: AsyncpgTime,
+ sqltypes.Date: AsyncpgDate,
+ sqltypes.DateTime: AsyncpgDateTime,
+ sqltypes.Interval: AsyncPgInterval,
+ INTERVAL: AsyncPgInterval,
+ UUID: AsyncpgUUID,
+ sqltypes.Boolean: AsyncpgBoolean,
+ sqltypes.Integer: AsyncpgInteger,
+ sqltypes.BigInteger: AsyncpgBigInteger,
+ sqltypes.Numeric: AsyncpgNumeric,
+ sqltypes.Float: AsyncpgFloat,
+ sqltypes.JSON: AsyncpgJSON,
+ json.JSONB: AsyncpgJSONB,
+ sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
+ sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
+ sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
+ sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
+ sqltypes.Enum: AsyncPgEnum,
+ OID: AsyncpgOID,
+ REGCLASS: AsyncpgREGCLASS,
+ },
+ )
+ is_async = True
+ _invalidate_schema_cache_asof = 0
+
+ def _invalidate_schema_cache(self):
+ self._invalidate_schema_cache_asof = time.time()
+
+ @util.memoized_property
+ def _dbapi_version(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ return tuple(
+ [
+ int(x)
+ for x in re.findall(
+ r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+ )
+ ]
+ )
+ else:
+ return (99, 99, 99)
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))
+
+ @util.memoized_property
+ def _isolation_lookup(self):
+ return {
+ "AUTOCOMMIT": "autocommit",
+ "READ COMMITTED": "read_committed",
+ "REPEATABLE READ": "repeatable_read",
+ "SERIALIZABLE": "serializable",
+ }
+
+ def set_isolation_level(self, connection, level):
+ try:
+ level = self._isolation_lookup[level.replace("_", " ")]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ ),
+ replace_context=err,
+ )
+
+ connection.set_isolation_level(level)
+
+ def set_readonly(self, connection, value):
+ connection.readonly = value
+
+ def get_readonly(self, connection):
+ return connection.readonly
+
+ def set_deferrable(self, connection, value):
+ connection.deferrable = value
+
+ def get_deferrable(self, connection):
+ return connection.deferrable
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+
+ opts.update(url.query)
+ util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
+ util.coerce_kw_type(opts, "port", int)
+ return ([], opts)
+
+ @classmethod
+ def get_pool_class(cls, url):
+
+ async_fallback = url.query.get("async_fallback", False)
+
+ if util.asbool(async_fallback):
+ return pool.FallbackAsyncAdaptedQueuePool
+ else:
+ return pool.AsyncAdaptedQueuePool
+
+ def is_disconnect(self, e, connection, cursor):
+ if connection:
+ return connection._connection.is_closed()
+ else:
+ return isinstance(
+ e, self.dbapi.InterfaceError
+ ) and "connection is closed" in str(e)
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
+ async def setup_asyncpg_json_codec(self, conn):
+ """set up JSON codec for asyncpg.
+
+ This occurs for all new connections and
+ can be overridden by third party dialects.
+
+ .. versionadded:: 1.4.27
+
+ """
+
+ asyncpg_connection = conn._connection
+ deserializer = self._json_deserializer or _py_json.loads
+
+ def _json_decoder(bin_value):
+ return deserializer(bin_value.decode())
+
+ await asyncpg_connection.set_type_codec(
+ "json",
+ encoder=str.encode,
+ decoder=_json_decoder,
+ schema="pg_catalog",
+ format="binary",
+ )
+
+ async def setup_asyncpg_jsonb_codec(self, conn):
+ """set up JSONB codec for asyncpg.
+
+ This occurs for all new connections and
+ can be overridden by third party dialects.
+
+ .. versionadded:: 1.4.27
+
+ """
+
+ asyncpg_connection = conn._connection
+ deserializer = self._json_deserializer or _py_json.loads
+
+ def _jsonb_encoder(str_value):
+ # \x01 is the prefix for jsonb used by PostgreSQL.
+ # asyncpg requires it when format='binary'
+ return b"\x01" + str_value.encode()
+
+ def _jsonb_decoder(bin_value):
+ # the byte is the \x01 prefix for jsonb used by PostgreSQL.
+ # asyncpg returns it when format='binary'
+ return deserializer(bin_value[1:].decode())
+
+ await asyncpg_connection.set_type_codec(
+ "jsonb",
+ encoder=_jsonb_encoder,
+ decoder=_jsonb_decoder,
+ schema="pg_catalog",
+ format="binary",
+ )
+
+ def on_connect(self):
+ """on_connect for asyncpg
+
+ A major component of this for asyncpg is to set up type decoders at the
+ asyncpg level.
+
+ See https://github.com/MagicStack/asyncpg/issues/623 for
+ notes on JSON/JSONB implementation.
+
+ """
+
+ super_connect = super(PGDialect_asyncpg, self).on_connect()
+
+ def connect(conn):
+ conn.await_(self.setup_asyncpg_json_codec(conn))
+ conn.await_(self.setup_asyncpg_jsonb_codec(conn))
+ if super_connect is not None:
+ super_connect(conn)
+
+ return connect
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = PGDialect_asyncpg
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
new file mode 100644
index 0000000..eb84170
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -0,0 +1,4651 @@
+# postgresql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: postgresql
+ :name: PostgreSQL
+ :full_support: 9.6, 10, 11, 12, 13, 14
+ :normal_support: 9.6+
+ :best_effort: 8+
+
+.. _postgresql_sequences:
+
+Sequences/SERIAL/IDENTITY
+-------------------------
+
+PostgreSQL supports sequences, and SQLAlchemy uses these as the default means
+of creating new primary key values for integer-based primary key columns. When
+creating tables, SQLAlchemy will issue the ``SERIAL`` datatype for
+integer-based primary key columns, which generates a sequence and server side
+default corresponding to the column.
+
+To specify a specific named sequence to be used for primary key generation,
+use the :func:`~sqlalchemy.schema.Sequence` construct::
+
+ Table('sometable', metadata,
+ Column('id', Integer, Sequence('some_id_seq'), primary_key=True)
+ )
+
+When SQLAlchemy issues a single INSERT statement, to fulfill the contract of
+having the "last insert identifier" available, a RETURNING clause is added to
+the INSERT statement which specifies the primary key columns should be
+returned after the statement completes. The RETURNING functionality only takes
+place if PostgreSQL 8.2 or later is in use. As a fallback approach, the
+sequence, whether specified explicitly or implicitly via ``SERIAL``, is
+executed independently beforehand, the returned value to be used in the
+subsequent insert. Note that when an
+:func:`~sqlalchemy.sql.expression.insert()` construct is executed using
+"executemany" semantics, the "last inserted identifier" functionality does not
+apply; no RETURNING clause is emitted nor is the sequence pre-executed in this
+case.
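+
+For example, a short sketch (an ``engine`` and a ``data`` column on the table
+above are assumed for illustration) of retrieving the newly generated primary
+key after a single-row INSERT::
+
+    with engine.begin() as conn:
+        result = conn.execute(sometable.insert().values(data='some data'))
+        print(result.inserted_primary_key)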
+
+To disable the use of RETURNING by default, specify the flag
+``implicit_returning=False`` to :func:`_sa.create_engine`.
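+
+A sketch of this, with an illustrative connection URL::
+
+    engine = create_engine(
+        "postgresql+psycopg2://scott:tiger@localhost/test",
+        implicit_returning=False,
+    )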
+
+PostgreSQL 10 and above IDENTITY columns
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PostgreSQL 10 and above have a new IDENTITY feature that supersedes the use
+of SERIAL. The :class:`_schema.Identity` construct in a
+:class:`_schema.Column` can be used to control its behavior::
+
+    from sqlalchemy import Table, Column, MetaData, Integer, Identity, String
+
+ metadata = MetaData()
+
+ data = Table(
+ "data",
+ metadata,
+ Column(
+ 'id', Integer, Identity(start=42, cycle=True), primary_key=True
+ ),
+ Column('data', String)
+ )
+
+The CREATE TABLE for the above :class:`_schema.Table` object would be:
+
+.. sourcecode:: sql
+
+ CREATE TABLE data (
+ id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 42 CYCLE),
+ data VARCHAR,
+ PRIMARY KEY (id)
+ )
+
+.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct
+ in a :class:`_schema.Column` to specify the option of an autoincrementing
+ column.
+
+.. note::
+
+ Previous versions of SQLAlchemy did not have built-in support for rendering
+ of IDENTITY, and could use the following compilation hook to replace
+ occurrences of SERIAL with IDENTITY::
+
+ from sqlalchemy.schema import CreateColumn
+ from sqlalchemy.ext.compiler import compiles
+
+
+ @compiles(CreateColumn, 'postgresql')
+ def use_identity(element, compiler, **kw):
+ text = compiler.visit_create_column(element, **kw)
+ text = text.replace(
+ "SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY"
+ )
+ return text
+
+ Using the above, a table such as::
+
+ t = Table(
+ 't', m,
+ Column('id', Integer, primary_key=True),
+ Column('data', String)
+ )
+
+ Will generate on the backing database as::
+
+ CREATE TABLE t (
+ id INT GENERATED BY DEFAULT AS IDENTITY,
+ data VARCHAR,
+ PRIMARY KEY (id)
+ )
+
+.. _postgresql_ss_cursors:
+
+Server Side Cursors
+-------------------
+
+Server-side cursor support is available for the psycopg2 and asyncpg
+dialects, and may also be available in others.
+
+Server side cursors are enabled on a per-statement basis by using the
+:paramref:`.Connection.execution_options.stream_results` connection execution
+option::
+
+ with engine.connect() as conn:
+ result = conn.execution_options(stream_results=True).execute(text("select * from table"))
+
+Note that some kinds of SQL statements may not be supported with
+server side cursors; generally, only SQL statements that return rows should be
+used with this option.
+
+.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated
+ and will be removed in a future release. Please use the
+ :paramref:`_engine.Connection.stream_results` execution option for
+ unbuffered cursor support.
+
+.. seealso::
+
+ :ref:`engine_stream_results`
+
+.. _postgresql_isolation_level:
+
+Transaction Isolation Level
+---------------------------
+
+Most SQLAlchemy dialects support setting of transaction isolation level
+using the :paramref:`_sa.create_engine.isolation_level` parameter
+at the :func:`_sa.create_engine` level, and at the :class:`_engine.Connection`
+level via the :paramref:`.Connection.execution_options.isolation_level`
+parameter.
+
+For PostgreSQL dialects, this feature works either by making use of the
+DBAPI-specific features, such as psycopg2's isolation level flags which will
+embed the isolation level setting inline with the ``"BEGIN"`` statement, or for
+DBAPIs with no direct support by emitting ``SET SESSION CHARACTERISTICS AS
+TRANSACTION ISOLATION LEVEL <level>`` ahead of the ``"BEGIN"`` statement
+emitted by the DBAPI. For the special AUTOCOMMIT isolation level,
+DBAPI-specific techniques are used, which typically involve an ``.autocommit``
+flag on the DBAPI connection object.
+
+To set isolation level using :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "postgresql+pg8000://scott:tiger@localhost/test",
+ isolation_level = "REPEATABLE READ"
+ )
+
+To set using per-connection execution options::
+
+ with engine.connect() as conn:
+ conn = conn.execution_options(
+ isolation_level="REPEATABLE READ"
+ )
+ with conn.begin():
+ # ... work with transaction
+
+There are also more options for isolation level configurations, such as
+"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply
+different isolation level settings. See the discussion at
+:ref:`dbapi_autocommit` for background.
+
+Valid values for ``isolation_level`` on most PostgreSQL dialects include:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+ :ref:`postgresql_readonly_deferrable`
+
+ :ref:`psycopg2_isolation_level`
+
+ :ref:`pg8000_isolation_level`
+
+.. _postgresql_readonly_deferrable:
+
+Setting READ ONLY / DEFERRABLE
+------------------------------
+
+Most PostgreSQL dialects support setting the "READ ONLY" and "DEFERRABLE"
+characteristics of the transaction, which is in addition to the isolation level
+setting. These two attributes can be established either in conjunction with or
+independently of the isolation level by passing the ``postgresql_readonly`` and
+``postgresql_deferrable`` flags with
+:meth:`_engine.Connection.execution_options`. The example below illustrates
+passing the ``"SERIALIZABLE"`` isolation level at the same time as setting
+"READ ONLY" and "DEFERRABLE"::
+
+ with engine.connect() as conn:
+ conn = conn.execution_options(
+ isolation_level="SERIALIZABLE",
+ postgresql_readonly=True,
+ postgresql_deferrable=True
+ )
+ with conn.begin():
+ # ... work with transaction
+
+Note that some DBAPIs such as asyncpg only support "readonly" with
+SERIALIZABLE isolation.
+
+.. versionadded:: 1.4 added support for the ``postgresql_readonly``
+ and ``postgresql_deferrable`` execution options.
+
+.. _postgresql_alternate_search_path:
+
+Setting Alternate Search Paths on Connect
+------------------------------------------
+
+The PostgreSQL ``search_path`` variable refers to the list of schema names
+that will be implicitly referred towards when a particular table or other
+object is referenced in a SQL statement. As detailed in the next section
+:ref:`postgresql_schema_reflection`, SQLAlchemy is generally organized around
+the concept of keeping this variable at its default value of ``public``,
+however, in order to have it set to any arbitrary name or names when connections
+are used automatically, the "SET SESSION search_path" command may be invoked
+for all connections in a pool using the following event handler, as discussed
+at :ref:`schema_set_default_connections`::
+
+ from sqlalchemy import event
+ from sqlalchemy import create_engine
+
+ engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname")
+
+ @event.listens_for(engine, "connect", insert=True)
+ def set_search_path(dbapi_connection, connection_record):
+ existing_autocommit = dbapi_connection.autocommit
+ dbapi_connection.autocommit = True
+ cursor = dbapi_connection.cursor()
+ cursor.execute("SET SESSION search_path='%s'" % schema_name)
+ cursor.close()
+ dbapi_connection.autocommit = existing_autocommit
+
+The reason the recipe is complicated by use of the ``.autocommit`` DBAPI
+attribute is so that when the ``SET SESSION search_path`` directive is invoked,
+it is invoked outside of the scope of any transaction and therefore will not
+be reverted when the DBAPI connection has a rollback.
+
+.. seealso::
+
+ :ref:`schema_set_default_connections` - in the :ref:`metadata_toplevel` documentation
+
+
+
+
+.. _postgresql_schema_reflection:
+
+Remote-Schema Table Introspection and PostgreSQL search_path
+------------------------------------------------------------
+
+.. admonition:: Section Best Practices Summarized
+
+ keep the ``search_path`` variable set to its default of ``public``, without
+ any other schema names. For other schema names, name these explicitly
+ within :class:`_schema.Table` definitions. Alternatively, the
+ ``postgresql_ignore_search_path`` option will cause all reflected
+ :class:`_schema.Table` objects to have a :attr:`_schema.Table.schema`
+ attribute set up.
+
+The PostgreSQL dialect can reflect tables from any schema, as outlined in
+:ref:`metadata_reflection_schemas`.
+
+With regards to tables which these :class:`_schema.Table`
+objects refer to via foreign key constraint, a decision must be made as to how
+the ``.schema`` is represented in those remote tables, in the case where that
+remote schema name is also a member of the current
+`PostgreSQL search path
+<https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_.
+
+By default, the PostgreSQL dialect mimics the behavior encouraged by
+PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function
+returns a sample definition for a particular foreign key constraint,
+omitting the referenced schema name from that definition when the name is
+also in the PostgreSQL schema search path. The interaction below
+illustrates this behavior::
+
+ test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY);
+ CREATE TABLE
+ test=> CREATE TABLE referring(
+ test(> id INTEGER PRIMARY KEY,
+ test(> referred_id INTEGER REFERENCES test_schema.referred(id));
+ CREATE TABLE
+ test=> SET search_path TO public, test_schema;
+ test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM
+ test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n
+ test-> ON n.oid = c.relnamespace
+ test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid
+ test-> WHERE c.relname='referring' AND r.contype = 'f'
+ test-> ;
+ pg_get_constraintdef
+ ---------------------------------------------------
+ FOREIGN KEY (referred_id) REFERENCES referred(id)
+ (1 row)
+
+Above, we created a table ``referred`` as a member of the remote schema
+``test_schema``, however when we added ``test_schema`` to the
+PG ``search_path`` and then asked ``pg_get_constraintdef()`` for the
+``FOREIGN KEY`` syntax, ``test_schema`` was not included in the output of
+the function.
+
+On the other hand, if we set the search path back to the typical default
+of ``public``::
+
+ test=> SET search_path TO public;
+ SET
+
+The same query against ``pg_get_constraintdef()`` now returns the fully
+schema-qualified name for us::
+
+ test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM
+ test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n
+ test-> ON n.oid = c.relnamespace
+ test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid
+ test-> WHERE c.relname='referring' AND r.contype = 'f';
+ pg_get_constraintdef
+ ---------------------------------------------------------------
+ FOREIGN KEY (referred_id) REFERENCES test_schema.referred(id)
+ (1 row)
+
+SQLAlchemy will by default use the return value of ``pg_get_constraintdef()``
+in order to determine the remote schema name. That is, if our ``search_path``
+were set to include ``test_schema``, and we invoked a table
+reflection process as follows::
+
+ >>> from sqlalchemy import Table, MetaData, create_engine, text
+ >>> engine = create_engine("postgresql://scott:tiger@localhost/test")
+ >>> with engine.connect() as conn:
+ ... conn.execute(text("SET search_path TO test_schema, public"))
+ ... metadata_obj = MetaData()
+ ... referring = Table('referring', metadata_obj,
+ ... autoload_with=conn)
+ ...
+ <sqlalchemy.engine.result.CursorResult object at 0x101612ed0>
+
+The above process would deliver to the :attr:`_schema.MetaData.tables`
+collection the ``referred`` table named **without** the schema::
+
+ >>> metadata_obj.tables['referred'].schema is None
+ True
+
+To alter the behavior of reflection such that the referred schema is
+maintained regardless of the ``search_path`` setting, use the
+``postgresql_ignore_search_path`` option, which can be specified as a
+dialect-specific argument to both :class:`_schema.Table` as well as
+:meth:`_schema.MetaData.reflect`::
+
+ >>> with engine.connect() as conn:
+ ... conn.execute(text("SET search_path TO test_schema, public"))
+ ... metadata_obj = MetaData()
+ ... referring = Table('referring', metadata_obj,
+ ... autoload_with=conn,
+ ... postgresql_ignore_search_path=True)
+ ...
+ <sqlalchemy.engine.result.CursorResult object at 0x1016126d0>
+
+We will now have ``test_schema.referred`` stored as schema-qualified::
+
+ >>> metadata_obj.tables['test_schema.referred'].schema
+ 'test_schema'
+
+.. sidebar:: Best Practices for PostgreSQL Schema reflection
+
+ The description of PostgreSQL schema reflection behavior is complex, and
+ is the product of many years of dealing with widely varied use cases and
+ user preferences. But in fact, there's no need to understand any of it if
+ you just stick to the simplest use pattern: leave the ``search_path`` set
+ to its default of ``public`` only, never refer to the name ``public`` as
+ an explicit schema name otherwise, and refer to all other schema names
+ explicitly when building up a :class:`_schema.Table` object. The options
+ described here are only for those users who can't, or prefer not to, stay
+ within these guidelines.
+
+Note that **in all cases**, the "default" schema is always reflected as
+``None``. The "default" schema on PostgreSQL is that which is returned by the
+PostgreSQL ``current_schema()`` function. On a typical PostgreSQL
+installation, this is the name ``public``. So a table that refers to another
+which is in the ``public`` (i.e. default) schema will always have the
+``.schema`` attribute set to ``None``.
+
+.. seealso::
+
+ :ref:`reflection_schema_qualified_interaction` - discussion of the issue
+ from a backend-agnostic perspective
+
+ `The Schema Search Path
+ <https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_
+ - on the PostgreSQL website.
+
+INSERT/UPDATE...RETURNING
+-------------------------
+
+The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and
+``DELETE..RETURNING`` syntaxes. ``INSERT..RETURNING`` is used by default
+for single-row INSERT statements in order to fetch newly generated
+primary key identifiers. To specify an explicit ``RETURNING`` clause,
+use the :meth:`._UpdateBase.returning` method on a per-statement basis::
+
+ # INSERT..RETURNING
+ result = table.insert().returning(table.c.col1, table.c.col2).\
+ values(name='foo')
+ print(result.fetchall())
+
+ # UPDATE..RETURNING
+ result = table.update().returning(table.c.col1, table.c.col2).\
+ where(table.c.name=='foo').values(name='bar')
+ print(result.fetchall())
+
+ # DELETE..RETURNING
+ result = table.delete().returning(table.c.col1, table.c.col2).\
+ where(table.c.name=='foo')
+ print(result.fetchall())
+
+.. _postgresql_insert_on_conflict:
+
+INSERT...ON CONFLICT (Upsert)
+------------------------------
+
+Starting with version 9.5, PostgreSQL allows "upserts" (update or insert) of
+rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` statement. A
+candidate row will only be inserted if that row does not violate any unique
+constraints. In the case of a unique constraint violation, a secondary action
+can occur which can be either "DO UPDATE", indicating that the data in the
+target row should be updated, or "DO NOTHING", which indicates to silently skip
+this row.
+
+Conflicts are determined using existing unique constraints and indexes. These
+constraints may be identified either using their name as stated in DDL,
+or they may be inferred by stating the columns and conditions that comprise
+the indexes.
+
+SQLAlchemy provides ``ON CONFLICT`` support via the PostgreSQL-specific
+:func:`_postgresql.insert()` function, which provides
+the generative methods :meth:`_postgresql.Insert.on_conflict_do_update`
+and :meth:`~.postgresql.Insert.on_conflict_do_nothing`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.dialects.postgresql import insert
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+ >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(
+ ... index_elements=['id']
+ ... )
+ >>> print(do_nothing_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO NOTHING
+ {stop}
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint='pk_my_table',
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT ON CONSTRAINT pk_my_table DO UPDATE SET data = %(param_1)s
+
+.. versionadded:: 1.1
+
+.. seealso::
+
+ `INSERT .. ON CONFLICT
+ <https://www.postgresql.org/docs/current/static/sql-insert.html#SQL-ON-CONFLICT>`_
+ - in the PostgreSQL documentation.
+
+Specifying the Target
+^^^^^^^^^^^^^^^^^^^^^
+
+Both methods supply the "target" of the conflict using either the
+named constraint or by column inference:
+
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` argument
+ specifies a sequence containing string column names, :class:`_schema.Column`
+ objects, and/or SQL expression elements, which would identify a unique
+ index:
+
+ .. sourcecode:: pycon+sql
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+ {stop}
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... index_elements=[my_table.c.id],
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+
+* When using :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` to
+  infer an index, a partial index can be inferred by also specifying the
+  :paramref:`_postgresql.Insert.on_conflict_do_update.index_where` parameter:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data')
+ >>> stmt = stmt.on_conflict_do_update(
+ ... index_elements=[my_table.c.user_email],
+ ... index_where=my_table.c.user_email.like('%@gmail.com'),
+ ... set_=dict(data=stmt.excluded.data)
+ ... )
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (data, user_email)
+ VALUES (%(data)s, %(user_email)s) ON CONFLICT (user_email)
+ WHERE user_email LIKE %(user_email_1)s DO UPDATE SET data = excluded.data
+
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument is
+ used to specify an index directly rather than inferring it. This can be
+ the name of a UNIQUE constraint, a PRIMARY KEY constraint, or an INDEX:
+
+ .. sourcecode:: pycon+sql
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint='my_table_idx_1',
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT ON CONSTRAINT my_table_idx_1 DO UPDATE SET data = %(param_1)s
+ {stop}
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint='my_table_pk',
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT ON CONSTRAINT my_table_pk DO UPDATE SET data = %(param_1)s
+ {stop}
+
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument may
+ also refer to a SQLAlchemy construct representing a constraint,
+ e.g. :class:`.UniqueConstraint`, :class:`.PrimaryKeyConstraint`,
+ :class:`.Index`, or :class:`.ExcludeConstraint`. In this use,
+ if the constraint has a name, it is used directly. Otherwise, if the
+ constraint is unnamed, then inference will be used, where the expressions
+ and optional WHERE clause of the constraint will be spelled out in the
+ construct. This use is especially convenient
+ to refer to the named or unnamed primary key of a :class:`_schema.Table`
+ using the
+ :attr:`_schema.Table.primary_key` attribute:
+
+ .. sourcecode:: pycon+sql
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... constraint=my_table.primary_key,
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+
+The SET Clause
+^^^^^^^^^^^^^^^
+
+``ON CONFLICT...DO UPDATE`` is used to perform an update of the already
+existing row, using any combination of new values as well as values
+from the proposed insertion. These values are specified using the
+:paramref:`_postgresql.Insert.on_conflict_do_update.set_` parameter. This
+parameter accepts a dictionary which consists of direct values
+for UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
+
+.. warning::
+
+ The :meth:`_expression.Insert.on_conflict_do_update`
+ method does **not** take into
+ account Python-side default UPDATE values or generation functions, e.g.
+ those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON CONFLICT style of UPDATE,
+ unless they are manually specified in the
+ :paramref:`_postgresql.Insert.on_conflict_do_update.set_` dictionary.
+
+Updating using the Excluded INSERT Values
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`~.postgresql.Insert.excluded` is available as an attribute on
+the :class:`_postgresql.Insert` object; this object is a
+:class:`_expression.ColumnCollection` alias that contains all columns of the
+target table:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author)
+ ... )
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author)
+ VALUES (%(id)s, %(data)s, %(author)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author
+
+Additional WHERE Criteria
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :meth:`_expression.Insert.on_conflict_do_update` method also accepts
+a WHERE clause using the :paramref:`_postgresql.Insert.on_conflict_do_update.where`
+parameter, which will limit those rows which receive an UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+ >>> on_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author),
+ ... where=(my_table.c.status == 2)
+ ... )
+ >>> print(on_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author)
+ VALUES (%(id)s, %(data)s, %(author)s)
+ ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author
+ WHERE my_table.status = %(status_1)s
+
+Skipping Rows with DO NOTHING
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``ON CONFLICT`` may be used to skip inserting a row entirely
+if any conflict with a unique or exclusion constraint occurs; below
+this is illustrated using the
+:meth:`~.postgresql.Insert.on_conflict_do_nothing` method:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id'])
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT (id) DO NOTHING
+
+If ``DO NOTHING`` is used without specifying any columns or constraint,
+it has the effect of skipping the INSERT for any unique or exclusion
+constraint violation which occurs:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing()
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
+ ON CONFLICT DO NOTHING
+
+.. _postgresql_match:
+
+Full Text Search
+----------------
+
+SQLAlchemy makes available the PostgreSQL ``@@`` operator via the
+:meth:`_expression.ColumnElement.match` method on any textual column expression.
+
+On the PostgreSQL dialect, an expression like the following::
+
+ select(sometable.c.text.match("search string"))
+
+will emit to the database::
+
+ SELECT text @@ to_tsquery('search string') FROM table
+
+Various other PostgreSQL text search functions such as ``to_tsquery()``,
+``to_tsvector()``, and ``plainto_tsquery()`` are available by explicitly using
+the standard SQLAlchemy :data:`.func` construct.
+
+For example::
+
+ select(func.to_tsvector('fat cats ate rats').match('cat & rat'))
+
+Emits the equivalent of::
+
+ SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat')
+
+The :class:`_postgresql.TSVECTOR` type can provide for explicit CAST::
+
+ from sqlalchemy.dialects.postgresql import TSVECTOR
+ from sqlalchemy import select, cast
+ select(cast("some text", TSVECTOR))
+
+produces a statement equivalent to::
+
+ SELECT CAST('some text' AS TSVECTOR) AS anon_1
+
+.. tip::
+
+ It's important to remember that text searching in PostgreSQL is powerful but complicated,
+ and SQLAlchemy users are advised to reference the PostgreSQL documentation
+ regarding
+ `Full Text Search <https://www.postgresql.org/docs/current/textsearch-controls.html>`_.
+
+ There are important differences between ``to_tsquery`` and
+ ``plainto_tsquery``, the most significant of which is that ``to_tsquery``
+ expects specially formatted "querytext" that is written to PostgreSQL's own
+ specification, while ``plainto_tsquery`` expects unformatted text that is
+ transformed into ``to_tsquery`` compatible querytext. This means the input to
+ ``.match()`` under PostgreSQL may be incompatible with the input to
+ ``.match()`` under another database backend. SQLAlchemy users who support
+ multiple backends are advised to carefully implement their usage of
+ ``.match()`` to work around these constraints.
+
+Full Text Searches in PostgreSQL are influenced by a combination of: the
+PostgreSQL setting of ``default_text_search_config``, the ``regconfig`` used
+to build the GIN/GiST indexes, and the ``regconfig`` optionally passed in
+during a query.
+
+When performing a Full Text Search against a column that has a GIN or
+GiST index that is already pre-computed (which is common on full text
+searches) one may need to explicitly pass in a particular PostgreSQL
+``regconfig`` value to ensure the query-planner utilizes the index and does
+not re-compute the column on demand.
+
+In order to provide for this explicit query planning, or to use different
+search strategies, the ``match`` method accepts a ``postgresql_regconfig``
+keyword argument::
+
+ select(mytable.c.id).where(
+ mytable.c.title.match('somestring', postgresql_regconfig='english')
+ )
+
+Emits the equivalent of::
+
+ SELECT mytable.id FROM mytable
+ WHERE mytable.title @@ to_tsquery('english', 'somestring')
+
+One can also specifically pass in a `'regconfig'` value to the
+``to_tsvector()`` command as the initial argument::
+
+    select(mytable.c.id).where(
+        func.to_tsvector('english', mytable.c.title)
+        .match('somestring', postgresql_regconfig='english')
+    )
+
+produces a statement equivalent to::
+
+ SELECT mytable.id FROM mytable
+ WHERE to_tsvector('english', mytable.title) @@
+ to_tsquery('english', 'somestring')
+
+It is recommended that you use the ``EXPLAIN ANALYZE...`` tool from
+PostgreSQL to ensure that you are generating queries with SQLAlchemy that
+take full advantage of any indexes you may have created for full text search.
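+
+For example, a minimal sketch of viewing the plan for the query above
+(assuming an ``engine`` already connected to a database containing
+``mytable``)::
+
+    with engine.connect() as conn:
+        rows = conn.exec_driver_sql(
+            "EXPLAIN ANALYZE SELECT mytable.id FROM mytable "
+            "WHERE mytable.title @@ to_tsquery('english', 'somestring')"
+        )
+        for row in rows:
+            print(row[0])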
+
+.. seealso::
+
+ `Full Text Search <https://www.postgresql.org/docs/current/textsearch-controls.html>`_ - in the PostgreSQL documentation
+
+
+FROM ONLY ...
+-------------
+
+The dialect supports PostgreSQL's ONLY keyword for targeting only a particular
+table in an inheritance hierarchy. This can be used to produce the
+``SELECT ... FROM ONLY``, ``UPDATE ONLY ...``, and ``DELETE FROM ONLY ...``
+syntaxes. It uses SQLAlchemy's hints mechanism::
+
+ # SELECT ... FROM ONLY ...
+ result = table.select().with_hint(table, 'ONLY', 'postgresql')
+ print(result.fetchall())
+
+ # UPDATE ONLY ...
+ table.update(values=dict(foo='bar')).with_hint('ONLY',
+ dialect_name='postgresql')
+
+ # DELETE FROM ONLY ...
+ table.delete().with_hint('ONLY', dialect_name='postgresql')
+
+
+.. _postgresql_indexes:
+
+PostgreSQL-Specific Index Options
+---------------------------------
+
+Several extensions to the :class:`.Index` construct are available, specific
+to the PostgreSQL dialect.
+
+Covering Indexes
+^^^^^^^^^^^^^^^^
+
+The ``postgresql_include`` option renders INCLUDE(colname) for the given
+string names::
+
+ Index("my_index", table.c.x, postgresql_include=['y'])
+
+would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``.
+
+Note that this feature requires PostgreSQL 11 or later.
+
+.. versionadded:: 1.4
+
+.. _postgresql_partial_indexes:
+
+Partial Indexes
+^^^^^^^^^^^^^^^
+
+Partial indexes add criteria to the index definition so that the index is
+applied to a subset of rows. These can be specified on :class:`.Index`
+using the ``postgresql_where`` keyword argument::
+
+ Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10)
+
+.. _postgresql_operator_classes:
+
+Operator Classes
+^^^^^^^^^^^^^^^^
+
+PostgreSQL allows the specification of an *operator class* for each column of
+an index (see
+https://www.postgresql.org/docs/current/interactive/indexes-opclass.html).
+The :class:`.Index` construct allows these to be specified via the
+``postgresql_ops`` keyword argument::
+
+ Index(
+ 'my_index', my_table.c.id, my_table.c.data,
+ postgresql_ops={
+ 'data': 'text_pattern_ops',
+ 'id': 'int4_ops'
+ })
+
+Note that the keys in the ``postgresql_ops`` dictionaries are the
+"key" name of the :class:`_schema.Column`, i.e. the name used to access it from
+the ``.c`` collection of :class:`_schema.Table`, which can be configured to be
+different than the actual name of the column as expressed in the database.
+
+If ``postgresql_ops`` is to be used against a complex SQL expression such
+as a function call, then to apply to the column it must be given a label
+that is identified in the dictionary by name, e.g.::
+
+ Index(
+ 'my_index', my_table.c.id,
+ func.lower(my_table.c.data).label('data_lower'),
+ postgresql_ops={
+ 'data_lower': 'text_pattern_ops',
+ 'id': 'int4_ops'
+ })
+
+Operator classes are also supported by the
+:class:`_postgresql.ExcludeConstraint` construct using the
+:paramref:`_postgresql.ExcludeConstraint.ops` parameter. See that parameter for
+details.
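+
+As a brief sketch (the column name and operator are chosen purely for
+illustration), the ``ops`` parameter maps column keys to operator class
+names::
+
+    from sqlalchemy.dialects.postgresql import ExcludeConstraint
+
+    ExcludeConstraint(
+        (my_table.c.period, "&&"),
+        ops={"period": "range_ops"},
+    )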
+
+.. versionadded:: 1.3.21 added support for operator classes with
+ :class:`_postgresql.ExcludeConstraint`.
+
+
+Index Types
+^^^^^^^^^^^
+
+PostgreSQL provides several index types: B-Tree, Hash, GiST, and GIN, as well
+as the ability for users to create their own (see
+https://www.postgresql.org/docs/current/static/indexes-types.html). These can be
+specified on :class:`.Index` using the ``postgresql_using`` keyword argument::
+
+ Index('my_index', my_table.c.data, postgresql_using='gin')
+
+The value passed to the keyword argument will simply be passed through to the
+underlying CREATE INDEX command, so it *must* be a valid index type for your
+version of PostgreSQL.
+
+.. _postgresql_index_storage:
+
+Index Storage Parameters
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+PostgreSQL allows storage parameters to be set on indexes. The storage
+parameters available depend on the index method used by the index. Storage
+parameters can be specified on :class:`.Index` using the ``postgresql_with``
+keyword argument::
+
+ Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50})
+
+.. versionadded:: 1.0.6
+
+PostgreSQL allows defining the tablespace in which to create the index.
+The tablespace can be specified on :class:`.Index` using the
+``postgresql_tablespace`` keyword argument::
+
+ Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace')
+
+.. versionadded:: 1.1
+
+Note that the same option is available on :class:`_schema.Table` as well.
+
+.. _postgresql_index_concurrently:
+
+Indexes with CONCURRENTLY
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The PostgreSQL index option CONCURRENTLY is supported by passing the
+flag ``postgresql_concurrently`` to the :class:`.Index` construct::
+
+ tbl = Table('testtbl', m, Column('data', Integer))
+
+ idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True)
+
+The above index construct will render DDL for CREATE INDEX, assuming
+PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as::
+
+ CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data)
+
+For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for
+a connection-less dialect, it will emit::
+
+ DROP INDEX CONCURRENTLY test_idx1
+
+.. versionadded:: 1.1 support for CONCURRENTLY on DROP INDEX. The
+ CONCURRENTLY keyword is now only emitted if a high enough version
+ of PostgreSQL is detected on the connection (or for a connection-less
+ dialect).
+
+When using CONCURRENTLY, the PostgreSQL database requires that the statement
+be invoked outside of a transaction block. The Python DBAPI enforces that a
+transaction is present even for a single statement, so to use this
+construct, the DBAPI's "autocommit" mode must be used::
+
+ metadata = MetaData()
+ table = Table(
+ "foo", metadata,
+ Column("id", String))
+ index = Index(
+ "foo_idx", table.c.id, postgresql_concurrently=True)
+
+ with engine.connect() as conn:
+ with conn.execution_options(isolation_level='AUTOCOMMIT'):
+ table.create(conn)
+
+.. seealso::
+
+ :ref:`postgresql_isolation_level`
+
+.. _postgresql_index_reflection:
+
+PostgreSQL Index Reflection
+---------------------------
+
+The PostgreSQL database creates a UNIQUE INDEX implicitly whenever the
+UNIQUE CONSTRAINT construct is used. When inspecting a table using
+:class:`_reflection.Inspector`, the :meth:`_reflection.Inspector.get_indexes`
+and the :meth:`_reflection.Inspector.get_unique_constraints`
+will report on these
+two constructs distinctly; in the case of the index, the key
+``duplicates_constraint`` will be present in the index entry if it is
+detected as mirroring a constraint. When performing reflection using
+``Table(..., autoload_with=engine)``, the UNIQUE INDEX is **not** returned
+in :attr:`_schema.Table.indexes` when it is detected as mirroring a
+:class:`.UniqueConstraint` in the :attr:`_schema.Table.constraints`
+collection.
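+
+As a brief illustrative sketch (the table name is hypothetical, and an
+``engine`` is assumed to be available), the ``duplicates_constraint`` key
+may be checked using the inspector::
+
+    from sqlalchemy import inspect
+
+    insp = inspect(engine)
+    for index in insp.get_indexes("some_table"):
+        if "duplicates_constraint" in index:
+            print(
+                "%s mirrors constraint %s"
+                % (index["name"], index["duplicates_constraint"])
+            )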
+
+.. versionchanged:: 1.0.0 - :class:`_schema.Table` reflection now includes
+ :class:`.UniqueConstraint` objects present in the
+ :attr:`_schema.Table.constraints`
+ collection; the PostgreSQL backend will no longer include a "mirrored"
+ :class:`.Index` construct in :attr:`_schema.Table.indexes`
+ if it is detected
+ as corresponding to a unique constraint.
+
+Special Reflection Options
+--------------------------
+
+The :class:`_reflection.Inspector`
+used for the PostgreSQL backend is an instance
+of :class:`.PGInspector`, which offers additional methods::
+
+ from sqlalchemy import create_engine, inspect
+
+ engine = create_engine("postgresql+psycopg2://localhost/test")
+ insp = inspect(engine) # will be a PGInspector
+
+ print(insp.get_enums())
+
+.. autoclass:: PGInspector
+ :members:
+
+.. _postgresql_table_options:
+
+PostgreSQL Table Options
+------------------------
+
+Several options for CREATE TABLE are supported directly by the PostgreSQL
+dialect in conjunction with the :class:`_schema.Table` construct:
+
+* ``TABLESPACE``::
+
+ Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace')
+
+ The above option is also available on the :class:`.Index` construct.
+
+* ``ON COMMIT``::
+
+ Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS')
+
+* ``WITH OIDS``::
+
+ Table("some_table", metadata, ..., postgresql_with_oids=True)
+
+* ``WITHOUT OIDS``::
+
+ Table("some_table", metadata, ..., postgresql_with_oids=False)
+
+* ``INHERITS``::
+
+ Table("some_table", metadata, ..., postgresql_inherits="some_supertable")
+
+ Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...))
+
+ .. versionadded:: 1.0.0
+
+* ``PARTITION BY``::
+
+ Table("some_table", metadata, ...,
+ postgresql_partition_by='LIST (part_column)')
+
+ .. versionadded:: 1.2.6
+
+.. seealso::
+
+ `PostgreSQL CREATE TABLE options
+ <https://www.postgresql.org/docs/current/static/sql-createtable.html>`_ -
+ in the PostgreSQL documentation.
+
+.. _postgresql_constraint_options:
+
+PostgreSQL Constraint Options
+-----------------------------
+
+The following option(s) are supported by the PostgreSQL dialect in conjunction
+with selected constraint constructs:
+
+* ``NOT VALID``: This option applies towards CHECK and FOREIGN KEY constraints
+ when the constraint is being added to an existing table via ALTER TABLE,
+ and has the effect that existing rows are not scanned during the ALTER
+ operation against the constraint being added.
+
+ When using a SQL migration tool such as `Alembic <https://alembic.sqlalchemy.org>`_
+ that renders ALTER TABLE constructs, the ``postgresql_not_valid`` argument
+ may be specified as an additional keyword argument within the operation
+ that creates the constraint, as in the following Alembic example::
+
+ def update():
+ op.create_foreign_key(
+ "fk_user_address",
+ "address",
+ "user",
+ ["user_id"],
+ ["id"],
+ postgresql_not_valid=True
+ )
+
+ The keyword is ultimately accepted directly by the
+ :class:`_schema.CheckConstraint`, :class:`_schema.ForeignKeyConstraint`
+ and :class:`_schema.ForeignKey` constructs; when using a tool like
+ Alembic, dialect-specific keyword arguments are passed through to
+ these constructs from the migration operation directives::
+
+ CheckConstraint("some_field IS NOT NULL", postgresql_not_valid=True)
+
+ ForeignKeyConstraint(["some_id"], ["some_table.some_id"], postgresql_not_valid=True)
+
+ .. versionadded:: 1.4.32
+
+ .. seealso::
+
+ `PostgreSQL ALTER TABLE options
+ <https://www.postgresql.org/docs/current/static/sql-altertable.html>`_ -
+ in the PostgreSQL documentation.
+
+.. _postgresql_table_valued_overview:
+
+Table values, Table and Column valued functions, Row and Tuple objects
+-----------------------------------------------------------------------
+
+PostgreSQL makes great use of modern SQL forms such as table-valued functions,
+tables and rows as values. These constructs are commonly used as part
+of PostgreSQL's support for complex datatypes such as JSON, ARRAY, and other
+datatypes. SQLAlchemy's SQL expression language has native support for
+most table-valued and row-valued forms.
+
+.. _postgresql_table_valued:
+
+Table-Valued Functions
+^^^^^^^^^^^^^^^^^^^^^^^
+
+Many PostgreSQL built-in functions are intended to be used in the FROM clause
+of a SELECT statement, and are capable of returning table rows or sets of table
+rows. A large portion of PostgreSQL's JSON functions, for example
+``json_array_elements()``, ``json_object_keys()``, ``json_each_text()``,
+``json_each()``, ``json_to_record()``, and ``json_populate_recordset()``, use
+such forms. These classes of SQL function calling forms in SQLAlchemy are available
+using the :meth:`_functions.FunctionElement.table_valued` method in conjunction
+with :class:`_functions.Function` objects generated from the :data:`_sql.func`
+namespace.
+
+Examples from PostgreSQL's reference documentation follow below:
+
+* ``json_each()``::
+
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value"))
+ >>> print(stmt)
+ SELECT anon_1.key, anon_1.value
+ FROM json_each(:json_each_1) AS anon_1
+
+* ``json_populate_record()``::
+
+ >>> from sqlalchemy import select, func, literal_column
+ >>> stmt = select(
+ ... func.json_populate_record(
+ ... literal_column("null::myrowtype"),
+ ... '{"a":1,"b":2}'
+ ... ).table_valued("a", "b", name="x")
+ ... )
+ >>> print(stmt)
+ SELECT x.a, x.b
+ FROM json_populate_record(null::myrowtype, :json_populate_record_1) AS x
+
+* ``json_to_record()`` - this form uses a PostgreSQL-specific form of derived
+  columns in the alias, where we may make use of :func:`_sql.column` elements with
+  types to produce them. The :meth:`_functions.FunctionElement.table_valued`
+  method produces a :class:`_sql.TableValuedAlias` construct, and the
+  :meth:`_sql.TableValuedAlias.render_derived` method sets up the derived
+  columns specification::
+
+ >>> from sqlalchemy import select, func, column, Integer, Text
+ >>> stmt = select(
+ ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}').table_valued(
+ ... column("a", Integer), column("b", Text), column("d", Text),
+ ... ).render_derived(name="x", with_types=True)
+ ... )
+ >>> print(stmt)
+ SELECT x.a, x.b, x.d
+ FROM json_to_record(:json_to_record_1) AS x(a INTEGER, b TEXT, d TEXT)
+
+* ``WITH ORDINALITY`` - part of the SQL standard, ``WITH ORDINALITY`` adds an
+ ordinal counter to the output of a function and is accepted by a limited set
+ of PostgreSQL functions including ``unnest()`` and ``generate_series()``. The
+ :meth:`_functions.FunctionElement.table_valued` method accepts a keyword
+ parameter ``with_ordinality`` for this purpose, which accepts the string name
+ that will be applied to the "ordinality" column::
+
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(
+ ... func.generate_series(4, 1, -1).
+ ... table_valued("value", with_ordinality="ordinality").
+ ... render_derived()
+ ... )
+ >>> print(stmt)
+ SELECT anon_1.value, anon_1.ordinality
+ FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3)
+ WITH ORDINALITY AS anon_1(value, ordinality)
+
+.. versionadded:: 1.4.0b2
+
+.. seealso::
+
+ :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial`
+
+.. _postgresql_column_valued:
+
+Column Valued Functions
+^^^^^^^^^^^^^^^^^^^^^^^
+
+Similar to the table-valued function, a column-valued function is present
+in the FROM clause, but delivers itself to the columns clause as a single
+scalar value. PostgreSQL functions such as ``json_array_elements()``,
+``unnest()`` and ``generate_series()`` may use this form. Column-valued
+functions are available using the
+:meth:`_functions.FunctionElement.column_valued` method of
+:class:`_functions.FunctionElement`:
+
+* ``json_array_elements()``::
+
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(func.json_array_elements('["one", "two"]').column_valued("x"))
+ >>> print(stmt)
+ SELECT x
+ FROM json_array_elements(:json_array_elements_1) AS x
+
+* ``unnest()`` - in order to generate a PostgreSQL ARRAY literal, the
+ :func:`_postgresql.array` construct may be used::
+
+
+ >>> from sqlalchemy.dialects.postgresql import array
+ >>> from sqlalchemy import select, func
+ >>> stmt = select(func.unnest(array([1, 2])).column_valued())
+ >>> print(stmt)
+ SELECT anon_1
+ FROM unnest(ARRAY[%(param_1)s, %(param_2)s]) AS anon_1
+
+ The function can of course be used against an existing table-bound column
+ that's of type :class:`_types.ARRAY`::
+
+ >>> from sqlalchemy import table, column, ARRAY, Integer
+ >>> from sqlalchemy import select, func
+ >>> t = table("t", column('value', ARRAY(Integer)))
+ >>> stmt = select(func.unnest(t.c.value).column_valued("unnested_value"))
+ >>> print(stmt)
+ SELECT unnested_value
+ FROM unnest(t.value) AS unnested_value
+
+.. seealso::
+
+ :ref:`tutorial_functions_column_valued` - in the :ref:`unified_tutorial`
+
+
+Row Types
+^^^^^^^^^
+
+Built-in support for rendering a ``ROW`` may be approximated using
+``func.ROW`` with the :attr:`_sa.func` namespace, or by using the
+:func:`_sql.tuple_` construct::
+
+ >>> from sqlalchemy import table, column, func, tuple_
+ >>> t = table("t", column("id"), column("fk"))
+ >>> stmt = t.select().where(
+ ... tuple_(t.c.id, t.c.fk) > (1,2)
+ ... ).where(
+ ... func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7)
+ ... )
+ >>> print(stmt)
+ SELECT t.id, t.fk
+ FROM t
+ WHERE (t.id, t.fk) > (:param_1, :param_2) AND ROW(t.id, t.fk) < ROW(:ROW_1, :ROW_2)
+
+.. seealso::
+
+ `PostgreSQL Row Constructors
+ <https://www.postgresql.org/docs/current/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS>`_
+
+ `PostgreSQL Row Constructor Comparison
+ <https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON>`_
+
+Table Types passed to Functions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+PostgreSQL supports passing a table as an argument to a function, which it
+refers to as a "record" type. SQLAlchemy :class:`_sql.FromClause` objects
+such as :class:`_schema.Table` support this special form using the
+:meth:`_sql.FromClause.table_valued` method, which is comparable to the
+:meth:`_functions.FunctionElement.table_valued` method except that the collection
+of columns is already established by that of the :class:`_sql.FromClause`
+itself::
+
+
+ >>> from sqlalchemy import table, column, func, select
+    >>> a = table("a", column("id"), column("x"), column("y"))
+ >>> stmt = select(func.row_to_json(a.table_valued()))
+ >>> print(stmt)
+ SELECT row_to_json(a) AS row_to_json_1
+ FROM a
+
+.. versionadded:: 1.4.0b2
+
+
+ARRAY Types
+-----------
+
+The PostgreSQL dialect supports arrays, both as multidimensional column types
+as well as array literals:
+
+* :class:`_postgresql.ARRAY` - ARRAY datatype
+
+* :class:`_postgresql.array` - array literal
+
+* :func:`_postgresql.array_agg` - ARRAY_AGG SQL function
+
+* :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate
+ function syntax.
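+
+A brief sketch combining these (the table and column names here are
+illustrative only)::
+
+    from sqlalchemy import Column, Integer, MetaData, Table, select
+    from sqlalchemy.dialects.postgresql import ARRAY, array
+
+    metadata = MetaData()
+    t = Table("example", metadata, Column("scores", ARRAY(Integer)))
+
+    # compare an ARRAY column against an array literal
+    stmt = select(t).where(t.c.scores == array([1, 2, 3]))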
+
+JSON Types
+----------
+
+The PostgreSQL dialect supports both JSON and JSONB datatypes, including
+psycopg2's native support and support for all of PostgreSQL's special
+operators:
+
+* :class:`_postgresql.JSON`
+
+* :class:`_postgresql.JSONB`
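+
+As a short sketch (names are illustrative), indexed access combined with
+``astext`` renders the PostgreSQL ``->>`` operator::
+
+    from sqlalchemy import Column, MetaData, Table, select
+    from sqlalchemy.dialects.postgresql import JSONB
+
+    metadata = MetaData()
+    t = Table("example", metadata, Column("doc", JSONB))
+
+    # doc ->> 'status' = 'active'
+    stmt = select(t).where(t.c.doc["status"].astext == "active")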
+
+HSTORE Type
+-----------
+
+The PostgreSQL HSTORE type as well as hstore literals are supported:
+
+* :class:`_postgresql.HSTORE` - HSTORE datatype
+
+* :class:`_postgresql.hstore` - hstore literal
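+
+A minimal sketch (names are illustrative) of the datatype in use::
+
+    from sqlalchemy import Column, MetaData, Table, select
+    from sqlalchemy.dialects.postgresql import HSTORE
+
+    metadata = MetaData()
+    t = Table("example", metadata, Column("attrs", HSTORE))
+
+    # attrs -> 'color', filtered on the presence of the 'color' key
+    stmt = select(t.c.attrs["color"]).where(t.c.attrs.has_key("color"))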
+
+ENUM Types
+----------
+
+PostgreSQL has an independently creatable TYPE structure which is used
+to implement an enumerated type. This approach introduces significant
+complexity on the SQLAlchemy side in terms of when this type should be
+CREATED and DROPPED. The type object is also an independently reflectable
+entity. The following sections should be consulted:
+
+* :class:`_postgresql.ENUM` - DDL and typing support for ENUM.
+
+* :meth:`.PGInspector.get_enums` - retrieve a listing of current ENUM types
+
+* :meth:`.postgresql.ENUM.create` , :meth:`.postgresql.ENUM.drop` - individual
+ CREATE and DROP commands for ENUM.
+
+.. _postgresql_array_of_enum:
+
+Using ENUM with ARRAY
+^^^^^^^^^^^^^^^^^^^^^
+
+The combination of ENUM and ARRAY is not directly supported by backend
+DBAPIs at this time. Prior to SQLAlchemy 1.3.17, a special workaround
+was needed in order to allow this combination to work, described below.
+
+.. versionchanged:: 1.3.17 The combination of ENUM and ARRAY is now directly
+ handled by SQLAlchemy's implementation without any workarounds needed.
+
+.. sourcecode:: python
+
+    import re
+
+    import sqlalchemy as sa
+    from sqlalchemy import TypeDecorator
+    from sqlalchemy.dialects.postgresql import ARRAY
+
+ class ArrayOfEnum(TypeDecorator):
+ impl = ARRAY
+
+ def bind_expression(self, bindvalue):
+ return sa.cast(bindvalue, self)
+
+ def result_processor(self, dialect, coltype):
+ super_rp = super(ArrayOfEnum, self).result_processor(
+ dialect, coltype)
+
+ def handle_raw_string(value):
+ inner = re.match(r"^{(.*)}$", value).group(1)
+ return inner.split(",") if inner else []
+
+ def process(value):
+ if value is None:
+ return None
+ return super_rp(handle_raw_string(value))
+ return process
+
+E.g.::
+
+    Table(
+        'mydata', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('data', ArrayOfEnum(ENUM('a', 'b', 'c', name='myenum')))
+    )
+
+This type is not included as a built-in type as it would be incompatible
+with a DBAPI that suddenly decides to support ARRAY of ENUM directly in
+a new version.
+
+.. _postgresql_array_of_json:
+
+Using JSON/JSONB with ARRAY
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Similar to using ENUM, prior to SQLAlchemy 1.3.17, for an ARRAY of JSON/JSONB
+we needed to render the appropriate CAST. Current psycopg2 drivers accommodate
+the result set correctly without any special steps.
+
+.. versionchanged:: 1.3.17 The combination of JSON/JSONB and ARRAY is now
+ directly handled by SQLAlchemy's implementation without any workarounds
+ needed.
+
+.. sourcecode:: python
+
+    import sqlalchemy as sa
+    from sqlalchemy.dialects.postgresql import ARRAY
+
+    class CastingArray(ARRAY):
+        def bind_expression(self, bindvalue):
+            return sa.cast(bindvalue, self)
+
+E.g.::
+
+ Table(
+ 'mydata', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', CastingArray(JSONB))
+ )
+
+
+""" # noqa: E501
+
+from collections import defaultdict
+import datetime as dt
+import re
+from uuid import UUID as _python_UUID
+
+from . import array as _array
+from . import dml
+from . import hstore as _hstore
+from . import json as _json
+from . import ranges as _ranges
+from ... import exc
+from ... import schema
+from ... import sql
+from ... import util
+from ...engine import characteristics
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import compiler
+from ...sql import elements
+from ...sql import expression
+from ...sql import roles
+from ...sql import sqltypes
+from ...sql import util as sql_util
+from ...sql.ddl import DDLBase
+from ...types import BIGINT
+from ...types import BOOLEAN
+from ...types import CHAR
+from ...types import DATE
+from ...types import FLOAT
+from ...types import INTEGER
+from ...types import NUMERIC
+from ...types import REAL
+from ...types import SMALLINT
+from ...types import TEXT
+from ...types import VARCHAR
+
+IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I)
+
+AUTOCOMMIT_REGEXP = re.compile(
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|"
+ "IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)",
+ re.I | re.UNICODE,
+)
+
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "current_catalog",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "fetch",
+ "for",
+ "foreign",
+ "from",
+ "grant",
+ "group",
+ "having",
+ "in",
+ "initially",
+ "intersect",
+ "into",
+ "leading",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "new",
+ "not",
+ "null",
+ "of",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "placing",
+ "primary",
+ "references",
+ "returning",
+ "select",
+ "session_user",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "variadic",
+ "when",
+ "where",
+ "window",
+ "with",
+ "authorization",
+ "between",
+ "binary",
+ "cross",
+ "current_schema",
+ "freeze",
+ "full",
+ "ilike",
+ "inner",
+ "is",
+ "isnull",
+ "join",
+ "left",
+ "like",
+ "natural",
+ "notnull",
+ "outer",
+ "over",
+ "overlaps",
+ "right",
+ "similar",
+ "verbose",
+ ]
+)
+
+_DECIMAL_TYPES = (1231, 1700)
+_FLOAT_TYPES = (700, 701, 1021, 1022)
+_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
+
+
+class BYTEA(sqltypes.LargeBinary):
+ __visit_name__ = "BYTEA"
+
+
+class DOUBLE_PRECISION(sqltypes.Float):
+ __visit_name__ = "DOUBLE_PRECISION"
+
+
+class INET(sqltypes.TypeEngine):
+ __visit_name__ = "INET"
+
+
+PGInet = INET
+
+
+class CIDR(sqltypes.TypeEngine):
+ __visit_name__ = "CIDR"
+
+
+PGCidr = CIDR
+
+
+class MACADDR(sqltypes.TypeEngine):
+ __visit_name__ = "MACADDR"
+
+
+PGMacAddr = MACADDR
+
+
+class MONEY(sqltypes.TypeEngine):
+
+ r"""Provide the PostgreSQL MONEY type.
+
+    Depending on the driver in use, result rows using this type may return a
+    string value which includes currency symbols.
+
+ For this reason, it may be preferable to provide conversion to a
+ numerically-based currency datatype using :class:`_types.TypeDecorator`::
+
+        import re
+        import decimal
+        from typing import Any
+
+        from sqlalchemy import TypeDecorator
+
+        class NumericMoney(TypeDecorator):
+            impl = MONEY
+
+            def process_result_value(self, value: Any, dialect: Any) -> Any:
+ if value is not None:
+ # adjust this for the currency and numeric
+ m = re.match(r"\$([\d.]+)", value)
+ if m:
+ value = decimal.Decimal(m.group(1))
+ return value
+
+ Alternatively, the conversion may be applied as a CAST using
+ the :meth:`_types.TypeDecorator.column_expression` method as follows::
+
+        from typing import Any
+
+        from sqlalchemy import cast
+        from sqlalchemy import Numeric
+        from sqlalchemy import TypeDecorator
+
+        class NumericMoney(TypeDecorator):
+            impl = MONEY
+
+            def column_expression(self, column: Any) -> Any:
+                return cast(column, Numeric())
+
+ .. versionadded:: 1.2
+
+ """
+
+ __visit_name__ = "MONEY"
+
+
+class OID(sqltypes.TypeEngine):
+
+ """Provide the PostgreSQL OID type.
+
+ .. versionadded:: 0.9.5
+
+ """
+
+ __visit_name__ = "OID"
+
+
+class REGCLASS(sqltypes.TypeEngine):
+
+ """Provide the PostgreSQL REGCLASS type.
+
+ .. versionadded:: 1.2.7
+
+ """
+
+ __visit_name__ = "REGCLASS"
+
+
+class TIMESTAMP(sqltypes.TIMESTAMP):
+
+ """Provide the PostgreSQL TIMESTAMP type."""
+
+ __visit_name__ = "TIMESTAMP"
+
+ def __init__(self, timezone=False, precision=None):
+ """Construct a TIMESTAMP.
+
+ :param timezone: boolean value if timezone present, default False
+ :param precision: optional integer precision value
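+
+        E.g., a brief sketch (the column name is illustrative only)::
+
+            Column("created_at", TIMESTAMP(timezone=True, precision=6))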
+
+ .. versionadded:: 1.4
+
+ """
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class TIME(sqltypes.TIME):
+
+ """PostgreSQL TIME type."""
+
+ __visit_name__ = "TIME"
+
+ def __init__(self, timezone=False, precision=None):
+ """Construct a TIME.
+
+ :param timezone: boolean value if timezone present, default False
+ :param precision: optional integer precision value
+
+ .. versionadded:: 1.4
+
+ """
+ super(TIME, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+
+ """PostgreSQL INTERVAL type."""
+
+ __visit_name__ = "INTERVAL"
+ native = True
+
+ def __init__(self, precision=None, fields=None):
+ """Construct an INTERVAL.
+
+ :param precision: optional integer precision value
+ :param fields: string fields specifier. allows storage of fields
+ to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
+ etc.
+
+ .. versionadded:: 1.2
+
+ """
+ self.precision = precision
+ self.fields = fields
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+ return INTERVAL(precision=interval.second_precision)
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(native=True, second_precision=self.precision)
+
+ @property
+ def python_type(self):
+ return dt.timedelta
+
+ def coerce_compared_value(self, op, value):
+ return self
+
+
+PGInterval = INTERVAL
+
+
+class BIT(sqltypes.TypeEngine):
+ __visit_name__ = "BIT"
+
+ def __init__(self, length=None, varying=False):
+ if not varying:
+ # BIT without VARYING defaults to length 1
+ self.length = length or 1
+ else:
+ # but BIT VARYING can be unlimited-length, so no default
+ self.length = length
+ self.varying = varying
+
+
+PGBit = BIT
+
+
+class UUID(sqltypes.TypeEngine):
+
+ """PostgreSQL UUID type.
+
+ Represents the UUID column type, interpreting
+ data either as natively returned by the DBAPI
+ or as Python uuid objects.
+
+    The UUID type is currently known to work within the prominent DBAPI
+    drivers supported by SQLAlchemy including psycopg2, pg8000 and
+    asyncpg. Support for other DBAPI drivers may be incomplete or missing.
+
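+    A brief usage sketch (the table and column names are purely
+    illustrative)::
+
+        from sqlalchemy import Column, MetaData, Table
+
+        metadata = MetaData()
+        user_table = Table(
+            "user_account",
+            metadata,
+            Column("id", UUID(as_uuid=True), primary_key=True),
+        )
+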
+ """
+
+ __visit_name__ = "UUID"
+
+ def __init__(self, as_uuid=False):
+ """Construct a UUID type.
+
+
+ :param as_uuid=False: if True, values will be interpreted
+ as Python uuid objects, converting to/from string via the
+ DBAPI.
+
+ """
+ self.as_uuid = as_uuid
+
+ def coerce_compared_value(self, op, value):
+ """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+
+ if isinstance(value, util.string_types):
+ return self
+ else:
+ return super(UUID, self).coerce_compared_value(op, value)
+
+ def bind_processor(self, dialect):
+ if self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = util.text_type(value)
+ return value
+
+ return process
+ else:
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+ else:
+ return None
+
+ def literal_processor(self, dialect):
+ if self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = "'%s'::UUID" % value
+ return value
+
+ return process
+ else:
+
+ def process(value):
+ if value is not None:
+ value = "'%s'" % value
+ return value
+
+ return process
+
+ @property
+ def python_type(self):
+ return _python_UUID if self.as_uuid else str
+
+
+PGUuid = UUID
+
+
+class TSVECTOR(sqltypes.TypeEngine):
+
+ """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
+ text search type TSVECTOR.
+
+ It can be used to do full text queries on natural language
+ documents.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`postgresql_match`
+
+ """
+
+ __visit_name__ = "TSVECTOR"
+
+
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
+
+ """PostgreSQL ENUM type.
+
+ This is a subclass of :class:`_types.Enum` which includes
+ support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
+
+ When the builtin type :class:`_types.Enum` is used and the
+ :paramref:`.Enum.native_enum` flag is left at its default of
+ True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
+ type as the implementation, so the special create/drop rules
+ will be used.
+
+    The create/drop behavior of ENUM is necessarily intricate, due to the
+    awkward relationship the ENUM type has with the
+    parent table, in that it may be "owned" by just a single table, or
+    may be shared among many tables.
+
+ When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
+ in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
+ corresponding to when the :meth:`_schema.Table.create` and
+ :meth:`_schema.Table.drop`
+ methods are called::
+
+ table = Table('sometable', metadata,
+ Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
+ )
+
+ table.create(engine) # will emit CREATE ENUM and CREATE TABLE
+ table.drop(engine) # will emit DROP TABLE and DROP ENUM
+
+ To use a common enumerated type between multiple tables, the best
+ practice is to declare the :class:`_types.Enum` or
+ :class:`_postgresql.ENUM` independently, and associate it with the
+ :class:`_schema.MetaData` object itself::
+
+ my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
+
+        t1 = Table('sometable_one', metadata,
+            Column('some_enum', my_enum)
+        )
+
+        t2 = Table('sometable_two', metadata,
+            Column('some_enum', my_enum)
+        )
+
+ When this pattern is used, care must still be taken at the level
+ of individual table creates. Emitting CREATE TABLE without also
+ specifying ``checkfirst=True`` will still cause issues::
+
+ t1.create(engine) # will fail: no such type 'myenum'
+
+ If we specify ``checkfirst=True``, the individual table-level create
+ operation will check for the ``ENUM`` and create if not exists::
+
+ # will check if enum exists, and emit CREATE TYPE if not
+ t1.create(engine, checkfirst=True)
+
+    When using a metadata-level ENUM type, the type will always be created
+    and dropped when the metadata-wide create or drop is called::
+
+ metadata.create_all(engine) # will emit CREATE TYPE
+ metadata.drop_all(engine) # will emit DROP TYPE
+
+ The type can also be created and dropped directly::
+
+ my_enum.create(engine)
+ my_enum.drop(engine)
+
+ .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
+ now behaves more strictly with regards to CREATE/DROP. A metadata-level
+ ENUM type will only be created and dropped at the metadata level,
+ not the table level, with the exception of
+ ``table.create(checkfirst=True)``.
+ The ``table.drop()`` call will now emit a DROP TYPE for a table-level
+ enumerated type.
+
+ """
+
+ native_enum = True
+
+ def __init__(self, *enums, **kw):
+ """Construct an :class:`_postgresql.ENUM`.
+
+ Arguments are the same as that of
+ :class:`_types.Enum`, but also including
+ the following parameters.
+
+ :param create_type: Defaults to True.
+ Indicates that ``CREATE TYPE`` should be
+ emitted, after optionally checking for the
+ presence of the type, when the parent
+ table is being created; and additionally
+ that ``DROP TYPE`` is called when the table
+ is dropped. When ``False``, no check
+ will be performed and no ``CREATE TYPE``
+ or ``DROP TYPE`` is emitted, unless
+ :meth:`~.postgresql.ENUM.create`
+ or :meth:`~.postgresql.ENUM.drop`
+ are called directly.
+ Setting to ``False`` is helpful
+ when invoking a creation scheme to a SQL file
+ without access to the actual database -
+ the :meth:`~.postgresql.ENUM.create` and
+ :meth:`~.postgresql.ENUM.drop` methods can
+ be used to emit SQL to a target bind.
+
+ """
+ native_enum = kw.pop("native_enum", None)
+ if native_enum is False:
+ util.warn(
+ "the native_enum flag does not apply to the "
+ "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
+ "always refers to ENUM. Use sqlalchemy.types.Enum for "
+ "non-native enum."
+ )
+ self.create_type = kw.pop("create_type", True)
+ super(ENUM, self).__init__(*enums, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ kw.setdefault("name", impl.name)
+ kw.setdefault("schema", impl.schema)
+ kw.setdefault("inherit_schema", impl.inherit_schema)
+ kw.setdefault("metadata", impl.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("values_callable", impl.values_callable)
+ kw.setdefault("omit_aliases", impl._omit_aliases)
+ return cls(**kw)
+
+ def create(self, bind=None, checkfirst=True):
+ """Emit ``CREATE TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL CREATE TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will first be performed to see
+         if the type already exists before
+         creating it.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=True):
+ """Emit ``DROP TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL DROP TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will first be performed to see
+         if the type actually exists before dropping.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+ class EnumGenerator(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_create_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return not self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_create_enum(enum):
+ return
+
+ self.connection.execute(CreateEnumType(enum))
+
+ class EnumDropper(DDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_drop_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_drop_enum(enum):
+ return
+
+ self.connection.execute(DropEnumType(enum))
+
+ def _check_for_name_in_memos(self, checkfirst, kw):
+ """Look in the 'ddl runner' for 'memos', then
+ note our name in that collection.
+
+        This is to ensure that a particular named enum is operated
+ upon only once within any kind of create/drop
+ sequence without relying upon "checkfirst".
+
+ """
+ if not self.create_type:
+ return True
+ if "_ddl_runner" in kw:
+ ddl_runner = kw["_ddl_runner"]
+ if "_pg_enums" in ddl_runner.memo:
+ pg_enums = ddl_runner.memo["_pg_enums"]
+ else:
+ pg_enums = ddl_runner.memo["_pg_enums"] = set()
+ present = (self.schema, self.name) in pg_enums
+ pg_enums.add((self.schema, self.name))
+ return present
+ else:
+ return False
+
+ def _on_table_create(self, target, bind, checkfirst=False, **kw):
+ if (
+ checkfirst
+ or (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ )
+ ) and not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+ if (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ and not self._check_for_name_in_memos(checkfirst, kw)
+ ):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+
+class _ColonCast(elements.Cast):
+ __visit_name__ = "colon_cast"
+
+ def __init__(self, expression, type_):
+ self.type = type_
+ self.clause = expression
+ self.typeclause = elements.TypeClause(type_)
+
+
+colspecs = {
+ sqltypes.ARRAY: _array.ARRAY,
+ sqltypes.Interval: INTERVAL,
+ sqltypes.Enum: ENUM,
+ sqltypes.JSON.JSONPathType: _json.JSONPathType,
+ sqltypes.JSON: _json.JSON,
+}
+
+ischema_names = {
+ "_array": _array.ARRAY,
+ "hstore": _hstore.HSTORE,
+ "json": _json.JSON,
+ "jsonb": _json.JSONB,
+ "int4range": _ranges.INT4RANGE,
+ "int8range": _ranges.INT8RANGE,
+ "numrange": _ranges.NUMRANGE,
+ "daterange": _ranges.DATERANGE,
+ "tsrange": _ranges.TSRANGE,
+ "tstzrange": _ranges.TSTZRANGE,
+ "integer": INTEGER,
+ "bigint": BIGINT,
+ "smallint": SMALLINT,
+ "character varying": VARCHAR,
+ "character": CHAR,
+ '"char"': sqltypes.String,
+ "name": sqltypes.String,
+ "text": TEXT,
+ "numeric": NUMERIC,
+ "float": FLOAT,
+ "real": REAL,
+ "inet": INET,
+ "cidr": CIDR,
+ "uuid": UUID,
+ "bit": BIT,
+ "bit varying": BIT,
+ "macaddr": MACADDR,
+ "money": MONEY,
+ "oid": OID,
+ "regclass": REGCLASS,
+ "double precision": DOUBLE_PRECISION,
+ "timestamp": TIMESTAMP,
+ "timestamp with time zone": TIMESTAMP,
+ "timestamp without time zone": TIMESTAMP,
+ "time with time zone": TIME,
+ "time without time zone": TIME,
+ "date": DATE,
+ "time": TIME,
+ "bytea": BYTEA,
+ "boolean": BOOLEAN,
+ "interval": INTERVAL,
+ "tsvector": TSVECTOR,
+}
+
+
+class PGCompiler(compiler.SQLCompiler):
+ def visit_colon_cast(self, element, **kw):
+ return "%s::%s" % (
+ element.clause._compiler_dispatch(self, **kw),
+ element.typeclause._compiler_dispatch(self, **kw),
+ )
+
+ def visit_array(self, element, **kw):
+ return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
+
+ def visit_slice(self, element, **kw):
+ return "%s:%s" % (
+ self.process(element.start, **kw),
+ self.process(element.stop, **kw),
+ )
+
+ def visit_json_getitem_op_binary(
+ self, binary, operator, _cast_applied=False, **kw
+ ):
+ if (
+ not _cast_applied
+ and binary.type._type_affinity is not sqltypes.JSON
+ ):
+ kw["_cast_applied"] = True
+ return self.process(sql.cast(binary, binary.type), **kw)
+
+ kw["eager_grouping"] = True
+
+ return self._generate_generic_binary(
+ binary, " -> " if not _cast_applied else " ->> ", **kw
+ )
+
+ def visit_json_path_getitem_op_binary(
+ self, binary, operator, _cast_applied=False, **kw
+ ):
+ if (
+ not _cast_applied
+ and binary.type._type_affinity is not sqltypes.JSON
+ ):
+ kw["_cast_applied"] = True
+ return self.process(sql.cast(binary, binary.type), **kw)
+
+ kw["eager_grouping"] = True
+ return self._generate_generic_binary(
+ binary, " #> " if not _cast_applied else " #>> ", **kw
+ )
+
+ def visit_getitem_binary(self, binary, operator, **kw):
+ return "%s[%s]" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_aggregate_order_by(self, element, **kw):
+ return "%s ORDER BY %s" % (
+ self.process(element.target, **kw),
+ self.process(element.order_by, **kw),
+ )
+
+ def visit_match_op_binary(self, binary, operator, **kw):
+ if "postgresql_regconfig" in binary.modifiers:
+ regconfig = self.render_literal_value(
+ binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE
+ )
+ if regconfig:
+ return "%s @@ to_tsquery(%s, %s)" % (
+ self.process(binary.left, **kw),
+ regconfig,
+ self.process(binary.right, **kw),
+ )
+ return "%s @@ to_tsquery(%s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+
+ return "%s ILIKE %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_not_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "%s NOT ILIKE %s" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def _regexp_match(self, base_op, binary, operator, kw):
+ flags = binary.modifiers["flags"]
+ if flags is None:
+ return self._generate_generic_binary(
+ binary, " %s " % base_op, **kw
+ )
+ if isinstance(flags, elements.BindParameter) and flags.value == "i":
+ return self._generate_generic_binary(
+ binary, " %s* " % base_op, **kw
+ )
+ flags = self.process(flags, **kw)
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
+ return "%s %s CONCAT('(?', %s, ')', %s)" % (
+ string,
+ base_op,
+ flags,
+ pattern,
+ )
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match("~", binary, operator, kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._regexp_match("!~", binary, operator, kw)
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
+ flags = binary.modifiers["flags"]
+ if flags is not None:
+ flags = self.process(flags, **kw)
+ replacement = self.process(binary.modifiers["replacement"], **kw)
+ if flags is None:
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ )
+ else:
+ return "REGEXP_REPLACE(%s, %s, %s, %s)" % (
+ string,
+ pattern,
+ replacement,
+ flags,
+ )
+
+ def visit_empty_set_expr(self, element_types):
+ # cast the empty set to the type we are comparing against. if
+ # we are comparing against the null type, pick an arbitrary
+ # datatype for the empty set
+ return "SELECT %s WHERE 1!=1" % (
+ ", ".join(
+ "CAST(NULL AS %s)"
+ % self.dialect.type_compiler.process(
+ INTEGER() if type_._isnull else type_
+ )
+ for type_ in element_types or [INTEGER()]
+ ),
+ )
+
+ def render_literal_value(self, value, type_):
+ value = super(PGCompiler, self).render_literal_value(value, type_)
+
+ if self.dialect._backslash_escapes:
+ value = value.replace("\\", "\\\\")
+ return value
+
+ def visit_sequence(self, seq, **kw):
+ return "nextval('%s')" % self.preparer.format_sequence(seq)
+
+ def limit_clause(self, select, **kw):
+ text = ""
+ if select._limit_clause is not None:
+ text += " \n LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
+ text += "\n LIMIT ALL"
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ return text
+
+ def format_from_hint_text(self, sqltext, table, hint, iscrud):
+ if hint.upper() != "ONLY":
+ raise exc.CompileError("Unrecognized hint: %r" % hint)
+ return "ONLY " + sqltext
+
+ def get_select_precolumns(self, select, **kw):
+ # Do not call super().get_select_precolumns because
+ # it will warn/raise when distinct on is present
+ if select._distinct or select._distinct_on:
+ if select._distinct_on:
+ return (
+ "DISTINCT ON ("
+ + ", ".join(
+ [
+ self.process(col, **kw)
+ for col in select._distinct_on
+ ]
+ )
+ + ") "
+ )
+ else:
+ return "DISTINCT "
+ else:
+ return ""
+
+ def for_update_clause(self, select, **kw):
+
+ if select._for_update_arg.read:
+ if select._for_update_arg.key_share:
+ tmp = " FOR KEY SHARE"
+ else:
+ tmp = " FOR SHARE"
+ elif select._for_update_arg.key_share:
+ tmp = " FOR NO KEY UPDATE"
+ else:
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of:
+
+ tables = util.OrderedSet()
+ for c in select._for_update_arg.of:
+ tables.update(sql_util.surface_selectables_only(c))
+
+ tmp += " OF " + ", ".join(
+ self.process(table, ashint=True, use_schema=False, **kw)
+ for table in tables
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+ if select._for_update_arg.skip_locked:
+ tmp += " SKIP LOCKED"
+
+ return tmp
+
+ def returning_clause(self, stmt, returning_cols):
+
+ columns = [
+ self._label_returning_column(stmt, c)
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return "RETURNING " + ", ".join(columns)
+
+ def visit_substring_func(self, func, **kw):
+ s = self.process(func.clauses.clauses[0], **kw)
+ start = self.process(func.clauses.clauses[1], **kw)
+ if len(func.clauses.clauses) > 2:
+ length = self.process(func.clauses.clauses[2], **kw)
+ return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
+ else:
+ return "SUBSTRING(%s FROM %s)" % (s, start)
+
+ def _on_conflict_target(self, clause, **kw):
+
+ if clause.constraint_target is not None:
+ # target may be a name of an Index, UniqueConstraint or
+ # ExcludeConstraint. While there is a separate
+ # "max_identifier_length" for indexes, PostgreSQL uses the same
+ # length for all objects so we can use
+ # truncate_and_render_constraint_name
+ target_text = (
+ "ON CONSTRAINT %s"
+ % self.preparer.truncate_and_render_constraint_name(
+ clause.constraint_target
+ )
+ )
+ elif clause.inferred_target_elements is not None:
+ target_text = "(%s)" % ", ".join(
+ (
+ self.preparer.quote(c)
+ if isinstance(c, util.string_types)
+ else self.process(c, include_table=False, use_schema=False)
+ )
+ for c in clause.inferred_target_elements
+ )
+ if clause.inferred_target_whereclause is not None:
+ target_text += " WHERE %s" % self.process(
+ clause.inferred_target_whereclause,
+ include_table=False,
+ use_schema=False,
+ )
+ else:
+ target_text = ""
+
+ return target_text
+
+ @util.memoized_property
+ def _is_safe_for_fast_insert_values_helper(self):
+ # don't allow fast executemany if _post_values_clause is
+ # present and is not an OnConflictDoNothing. what this means
+ # concretely is that the
+ # "fast insert executemany helper" won't be used, in other
+ # words we won't convert "executemany()" of many parameter
+ # sets into a single INSERT with many elements in VALUES.
+ # We can't apply that optimization safely if for example the
+ # statement includes a clause like "ON CONFLICT DO UPDATE"
+
+ return self.insert_single_values_expr is not None and (
+ self.statement._post_values_clause is None
+ or isinstance(
+ self.statement._post_values_clause, dml.OnConflictDoNothing
+ )
+ )
+
+ def visit_on_conflict_do_nothing(self, on_conflict, **kw):
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ if target_text:
+ return "ON CONFLICT %s DO NOTHING" % target_text
+ else:
+ return "ON CONFLICT DO NOTHING"
+
+ def visit_on_conflict_do_update(self, on_conflict, **kw):
+
+ clause = on_conflict
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ action_set_ops = []
+
+ set_parameters = dict(clause.update_values_to_set)
+ # create a list of column assignment clauses as tuples
+
+ insert_statement = self.stack[-1]["selectable"]
+ cols = insert_statement.table.c
+ for c in cols:
+ col_key = c.key
+
+ if col_key in set_parameters:
+ value = set_parameters.pop(col_key)
+ elif c in set_parameters:
+ value = set_parameters.pop(c)
+ else:
+ continue
+
+ if coercions._is_literal(value):
+ value = elements.BindParameter(None, value, type_=c.type)
+
+ else:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = self.preparer.quote(c.name)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ # check for names that don't match columns
+ if set_parameters:
+ util.warn(
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
+ self.current_executable.table.name,
+ (", ".join("'%s'" % c for c in set_parameters)),
+ )
+ )
+ for k, v in set_parameters.items():
+ key_text = (
+ self.preparer.quote(k)
+ if isinstance(k, util.string_types)
+ else self.process(k, use_schema=False)
+ )
+ value_text = self.process(
+ coercions.expect(roles.ExpressionElementRole, v),
+ use_schema=False,
+ )
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ action_text = ", ".join(action_set_ops)
+ if clause.update_whereclause is not None:
+ action_text += " WHERE %s" % self.process(
+ clause.update_whereclause, include_table=True, use_schema=False
+ )
+
+ return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text)
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ kw["asfrom"] = True
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. USING clause specific to PostgreSQL."""
+ kw["asfrom"] = True
+ return "USING " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def fetch_clause(self, select, **kw):
+ # pg requires parens for non literal clauses. It's also required for
+        # bind parameters if a ::type cast is used by the driver (asyncpg),
+ # so it's easiest to just always add it
+ text = ""
+ if select._offset_clause is not None:
+ text += "\n OFFSET (%s) ROWS" % self.process(
+ select._offset_clause, **kw
+ )
+ if select._fetch_clause is not None:
+ text += "\n FETCH FIRST (%s)%s ROWS %s" % (
+ self.process(select._fetch_clause, **kw),
+ " PERCENT" if select._fetch_clause_options["percent"] else "",
+ "WITH TIES"
+ if select._fetch_clause_options["with_ties"]
+ else "ONLY",
+ )
+ return text
+
+
+class PGDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+
+ colspec = self.preparer.format_column(column)
+ impl_type = column.type.dialect_impl(self.dialect)
+ if isinstance(impl_type, sqltypes.TypeDecorator):
+ impl_type = impl_type.impl
+
+ has_identity = (
+ column.identity is not None
+ and self.dialect.supports_identity_columns
+ )
+
+ if (
+ column.primary_key
+ and column is column.table._autoincrement_column
+ and (
+ self.dialect.supports_smallserial
+ or not isinstance(impl_type, sqltypes.SmallInteger)
+ )
+ and not has_identity
+ and (
+ column.default is None
+ or (
+ isinstance(column.default, schema.Sequence)
+ and column.default.optional
+ )
+ )
+ ):
+ if isinstance(impl_type, sqltypes.BigInteger):
+ colspec += " BIGSERIAL"
+ elif isinstance(impl_type, sqltypes.SmallInteger):
+ colspec += " SMALLSERIAL"
+ else:
+ colspec += " SERIAL"
+ else:
+ colspec += " " + self.dialect.type_compiler.process(
+ column.type,
+ type_expression=column,
+ identifier_preparer=self.preparer,
+ )
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+ if has_identity:
+ colspec += " " + self.process(column.identity)
+
+ if not column.nullable and not has_identity:
+ colspec += " NOT NULL"
+ elif column.nullable and has_identity:
+ colspec += " NULL"
+ return colspec
+
+ def _define_constraint_validity(self, constraint):
+ not_valid = constraint.dialect_options["postgresql"]["not_valid"]
+ return " NOT VALID" if not_valid else ""
+
+ def visit_check_constraint(self, constraint):
+ if constraint._type_bound:
+ typ = list(constraint.columns)[0].type
+ if (
+ isinstance(typ, sqltypes.ARRAY)
+ and isinstance(typ.item_type, sqltypes.Enum)
+ and not typ.item_type.native_enum
+ ):
+ raise exc.CompileError(
+ "PostgreSQL dialect cannot produce the CHECK constraint "
+ "for ARRAY of non-native ENUM; please specify "
+ "create_constraint=False on this Enum datatype."
+ )
+
+ text = super(PGDDLCompiler, self).visit_check_constraint(constraint)
+ text += self._define_constraint_validity(constraint)
+ return text
+
+ def visit_foreign_key_constraint(self, constraint):
+ text = super(PGDDLCompiler, self).visit_foreign_key_constraint(
+ constraint
+ )
+ text += self._define_constraint_validity(constraint)
+ return text
+
+ def visit_drop_table_comment(self, drop):
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
+
+ def visit_create_enum_type(self, create):
+ type_ = create.element
+
+ return "CREATE TYPE %s AS ENUM (%s)" % (
+ self.preparer.format_type(type_),
+ ", ".join(
+ self.sql_compiler.process(sql.literal(e), literal_binds=True)
+ for e in type_.enums
+ ),
+ )
+
+ def visit_drop_enum_type(self, drop):
+ type_ = drop.element
+
+ return "DROP TYPE %s" % (self.preparer.format_type(type_))
+
+ def visit_create_index(self, create):
+ preparer = self.preparer
+ index = create.element
+ self._verify_index_table(index)
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ text += "INDEX "
+
+ if self.dialect._supports_create_index_concurrently:
+ concurrently = index.dialect_options["postgresql"]["concurrently"]
+ if concurrently:
+ text += "CONCURRENTLY "
+
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += "%s ON %s " % (
+ self._prepared_index_name(index, include_schema=False),
+ preparer.format_table(index.table),
+ )
+
+ using = index.dialect_options["postgresql"]["using"]
+ if using:
+ text += (
+ "USING %s "
+ % self.preparer.validate_sql_phrase(using, IDX_USING).lower()
+ )
+
+ ops = index.dialect_options["postgresql"]["ops"]
+ text += "(%s)" % (
+ ", ".join(
+ [
+ self.sql_compiler.process(
+ expr.self_group()
+ if not isinstance(expr, expression.ColumnClause)
+ else expr,
+ include_table=False,
+ literal_binds=True,
+ )
+ + (
+ (" " + ops[expr.key])
+ if hasattr(expr, "key") and expr.key in ops
+ else ""
+ )
+ for expr in index.expressions
+ ]
+ )
+ )
+
+ includeclause = index.dialect_options["postgresql"]["include"]
+ if includeclause:
+ inclusions = [
+ index.table.c[col]
+ if isinstance(col, util.string_types)
+ else col
+ for col in includeclause
+ ]
+ text += " INCLUDE (%s)" % ", ".join(
+ [preparer.quote(c.name) for c in inclusions]
+ )
+
+ withclause = index.dialect_options["postgresql"]["with"]
+ if withclause:
+ text += " WITH (%s)" % (
+ ", ".join(
+ [
+ "%s = %s" % storage_parameter
+ for storage_parameter in withclause.items()
+ ]
+ )
+ )
+
+ tablespace_name = index.dialect_options["postgresql"]["tablespace"]
+ if tablespace_name:
+ text += " TABLESPACE %s" % preparer.quote(tablespace_name)
+
+ whereclause = index.dialect_options["postgresql"]["where"]
+ if whereclause is not None:
+ whereclause = coercions.expect(
+ roles.DDLExpressionRole, whereclause
+ )
+
+ where_compiled = self.sql_compiler.process(
+ whereclause, include_table=False, literal_binds=True
+ )
+ text += " WHERE " + where_compiled
+
+ return text
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+
+ text = "\nDROP INDEX "
+
+ if self.dialect._supports_drop_index_concurrently:
+ concurrently = index.dialect_options["postgresql"]["concurrently"]
+ if concurrently:
+ text += "CONCURRENTLY "
+
+ if drop.if_exists:
+ text += "IF EXISTS "
+
+ text += self._prepared_index_name(index, include_schema=True)
+ return text
+
+ def visit_exclude_constraint(self, constraint, **kw):
+ text = ""
+ if constraint.name is not None:
+ text += "CONSTRAINT %s " % self.preparer.format_constraint(
+ constraint
+ )
+ elements = []
+ for expr, name, op in constraint._render_exprs:
+ kw["include_table"] = False
+ exclude_element = self.sql_compiler.process(expr, **kw) + (
+ (" " + constraint.ops[expr.key])
+ if hasattr(expr, "key") and expr.key in constraint.ops
+ else ""
+ )
+
+ elements.append("%s WITH %s" % (exclude_element, op))
+ text += "EXCLUDE USING %s (%s)" % (
+ self.preparer.validate_sql_phrase(
+ constraint.using, IDX_USING
+ ).lower(),
+ ", ".join(elements),
+ )
+ if constraint.where is not None:
+ text += " WHERE (%s)" % self.sql_compiler.process(
+ constraint.where, literal_binds=True
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def post_create_table(self, table):
+ table_opts = []
+ pg_opts = table.dialect_options["postgresql"]
+
+ inherits = pg_opts.get("inherits")
+ if inherits is not None:
+ if not isinstance(inherits, (list, tuple)):
+ inherits = (inherits,)
+ table_opts.append(
+ "\n INHERITS ( "
+ + ", ".join(self.preparer.quote(name) for name in inherits)
+ + " )"
+ )
+
+ if pg_opts["partition_by"]:
+ table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"])
+
+ if pg_opts["with_oids"] is True:
+ table_opts.append("\n WITH OIDS")
+ elif pg_opts["with_oids"] is False:
+ table_opts.append("\n WITHOUT OIDS")
+
+ if pg_opts["on_commit"]:
+ on_commit_options = pg_opts["on_commit"].replace("_", " ").upper()
+ table_opts.append("\n ON COMMIT %s" % on_commit_options)
+
+ if pg_opts["tablespace"]:
+ tablespace_name = pg_opts["tablespace"]
+ table_opts.append(
+ "\n TABLESPACE %s" % self.preparer.quote(tablespace_name)
+ )
+
+ return "".join(table_opts)
+
+ def visit_computed_column(self, generated):
+ if generated.persisted is False:
+ raise exc.CompileError(
+                "PostgreSQL computed columns do not support 'virtual' "
+ "persistence; set the 'persisted' flag to None or True for "
+ "PostgreSQL support."
+ )
+
+ return "GENERATED ALWAYS AS (%s) STORED" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+
+ def visit_create_sequence(self, create, **kw):
+ prefix = None
+ if create.element.data_type is not None:
+ prefix = " AS %s" % self.type_compiler.process(
+ create.element.data_type
+ )
+
+ return super(PGDDLCompiler, self).visit_create_sequence(
+ create, prefix=prefix, **kw
+ )
+
+
+class PGTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_TSVECTOR(self, type_, **kw):
+ return "TSVECTOR"
+
+ def visit_INET(self, type_, **kw):
+ return "INET"
+
+ def visit_CIDR(self, type_, **kw):
+ return "CIDR"
+
+ def visit_MACADDR(self, type_, **kw):
+ return "MACADDR"
+
+ def visit_MONEY(self, type_, **kw):
+ return "MONEY"
+
+ def visit_OID(self, type_, **kw):
+ return "OID"
+
+ def visit_REGCLASS(self, type_, **kw):
+ return "REGCLASS"
+
+ def visit_FLOAT(self, type_, **kw):
+ if not type_.precision:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {"precision": type_.precision}
+
+ def visit_DOUBLE_PRECISION(self, type_, **kw):
+ return "DOUBLE PRECISION"
+
+ def visit_BIGINT(self, type_, **kw):
+ return "BIGINT"
+
+ def visit_HSTORE(self, type_, **kw):
+ return "HSTORE"
+
+ def visit_JSON(self, type_, **kw):
+ return "JSON"
+
+ def visit_JSONB(self, type_, **kw):
+ return "JSONB"
+
+ def visit_INT4RANGE(self, type_, **kw):
+ return "INT4RANGE"
+
+ def visit_INT8RANGE(self, type_, **kw):
+ return "INT8RANGE"
+
+ def visit_NUMRANGE(self, type_, **kw):
+ return "NUMRANGE"
+
+ def visit_DATERANGE(self, type_, **kw):
+ return "DATERANGE"
+
+ def visit_TSRANGE(self, type_, **kw):
+ return "TSRANGE"
+
+ def visit_TSTZRANGE(self, type_, **kw):
+ return "TSTZRANGE"
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_TIMESTAMP(type_, **kw)
+
+ def visit_enum(self, type_, **kw):
+ if not type_.native_enum or not self.dialect.supports_native_enum:
+ return super(PGTypeCompiler, self).visit_enum(type_, **kw)
+ else:
+ return self.visit_ENUM(type_, **kw)
+
+ def visit_ENUM(self, type_, identifier_preparer=None, **kw):
+ if identifier_preparer is None:
+ identifier_preparer = self.dialect.identifier_preparer
+ return identifier_preparer.format_type(type_)
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ return "TIMESTAMP%s %s" % (
+ "(%d)" % type_.precision
+ if getattr(type_, "precision", None) is not None
+ else "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
+ )
+
+ def visit_TIME(self, type_, **kw):
+ return "TIME%s %s" % (
+ "(%d)" % type_.precision
+ if getattr(type_, "precision", None) is not None
+ else "",
+ (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",
+ )
+
+ def visit_INTERVAL(self, type_, **kw):
+ text = "INTERVAL"
+ if type_.fields is not None:
+ text += " " + type_.fields
+ if type_.precision is not None:
+ text += " (%d)" % type_.precision
+ return text
+
+ def visit_BIT(self, type_, **kw):
+ if type_.varying:
+ compiled = "BIT VARYING"
+ if type_.length is not None:
+ compiled += "(%d)" % type_.length
+ else:
+ compiled = "BIT(%d)" % type_.length
+ return compiled
+
+ def visit_UUID(self, type_, **kw):
+ return "UUID"
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BYTEA(type_, **kw)
+
+ def visit_BYTEA(self, type_, **kw):
+ return "BYTEA"
+
+ def visit_ARRAY(self, type_, **kw):
+
+ inner = self.process(type_.item_type, **kw)
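+        # append one "[]" per dimension to the compiled inner type,
+        # placing the suffix before any trailing COLLATE clause, e.g.
+        # VARCHAR(30) COLLATE "en_US" -> VARCHAR(30)[] COLLATE "en_US"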
+ return re.sub(
+ r"((?: COLLATE.*)?)$",
+ (
+ r"%s\1"
+ % (
+ "[]"
+ * (type_.dimensions if type_.dimensions is not None else 1)
+ )
+ ),
+ inner,
+ count=1,
+ )
+
+
+class PGIdentifierPreparer(compiler.IdentifierPreparer):
+
+ reserved_words = RESERVED_WORDS
+
+ def _unquote_identifier(self, value):
+ if value[0] == self.initial_quote:
+ value = value[1:-1].replace(
+ self.escape_to_quote, self.escape_quote
+ )
+ return value
+
+ def format_type(self, type_, use_schema=True):
+ if not type_.name:
+ raise exc.CompileError("PostgreSQL ENUM type requires a name.")
+
+ name = self.quote(type_.name)
+ effective_schema = self.schema_for_object(type_)
+
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
+ name = self.quote_schema(effective_schema) + "." + name
+ return name
+
+
+class PGInspector(reflection.Inspector):
+ def get_table_oid(self, table_name, schema=None):
+ """Return the OID for the given table name."""
+
+ with self._operation_context() as conn:
+ return self.dialect.get_table_oid(
+ conn, table_name, schema, info_cache=self.info_cache
+ )
+
+ def get_enums(self, schema=None):
+ """Return a list of ENUM objects.
+
+ Each member is a dictionary containing these fields:
+
+ * name - name of the enum
+ * schema - the schema name for the enum.
+ * visible - boolean, whether or not this enum is visible
+ in the default search path.
+ * labels - a list of string labels that apply to the enum.
+
+ :param schema: schema name. If None, the default schema
+ (typically 'public') is used. May also be set to '*' to
+         indicate that enums should be loaded for all schemas.
+
+ .. versionadded:: 1.0.0
+
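+        A minimal usage sketch (the engine URL and schema name are
+        illustrative only)::
+
+            from sqlalchemy import create_engine, inspect
+
+            engine = create_engine("postgresql://scott:tiger@localhost/test")
+            insp = inspect(engine)
+            for enum in insp.get_enums(schema="public"):
+                print(enum["name"], enum["labels"])
+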
+ """
+ schema = schema or self.default_schema_name
+ with self._operation_context() as conn:
+ return self.dialect._load_enums(conn, schema)
+
+ def get_foreign_table_names(self, schema=None):
+ """Return a list of FOREIGN TABLE names.
+
+ Behavior is similar to that of
+ :meth:`_reflection.Inspector.get_table_names`,
+ except that the list is limited to those tables that report a
+ ``relkind`` value of ``f``.
+
+ .. versionadded:: 1.0.0
+
+ """
+ schema = schema or self.default_schema_name
+ with self._operation_context() as conn:
+ return self.dialect._get_foreign_table_names(conn, schema)
+
+ def get_view_names(self, schema=None, include=("plain", "materialized")):
+ """Return all view names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ :param include: specify which types of views to return. Passed
+ as a string value (for a single type) or a tuple (for any number
+ of types). Defaults to ``('plain', 'materialized')``.
+
+ .. versionadded:: 1.1
+
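+        A minimal usage sketch (``engine`` is assumed to be an existing
+        :class:`_engine.Engine` bound to a PostgreSQL database)::
+
+            from sqlalchemy import inspect
+
+            insp = inspect(engine)
+            all_views = insp.get_view_names()
+            materialized_only = insp.get_view_names(include="materialized")
+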
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_view_names(
+ conn, schema, info_cache=self.info_cache, include=include
+ )
+
+
+class CreateEnumType(schema._CreateDropBase):
+ __visit_name__ = "create_enum_type"
+
+
+class DropEnumType(schema._CreateDropBase):
+ __visit_name__ = "drop_enum_type"
+
+
+class PGExecutionContext(default.DefaultExecutionContext):
+ def fire_sequence(self, seq, type_):
+ return self._execute_scalar(
+ (
+ "select nextval('%s')"
+ % self.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
+
+ def get_insert_default(self, column):
+ if column.primary_key and column is column.table._autoincrement_column:
+ if column.server_default and column.server_default.has_argument:
+
+ # pre-execute passive defaults on primary key columns
+ return self._execute_scalar(
+ "select %s" % column.server_default.arg, column.type
+ )
+
+ elif column.default is None or (
+ column.default.is_sequence and column.default.optional
+ ):
+                # execute the sequence associated with a SERIAL primary
+                # key column. for a non-primary-key SERIAL column, the ID
+                # is simply generated server side.
+
+ try:
+ seq_name = column._postgresql_seq_name
+ except AttributeError:
+ tab = column.table.name
+ col = column.name
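+                    # truncate so that the guessed "<table>_<column>_seq"
+                    # name stays within PostgreSQL's 63-character
+                    # identifier limit, mirroring how the server names
+                    # SERIAL sequences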
+ tab = tab[0 : 29 + max(0, (29 - len(col)))]
+ col = col[0 : 29 + max(0, (29 - len(tab)))]
+ name = "%s_%s_seq" % (tab, col)
+ column._postgresql_seq_name = seq_name = name
+
+ if column.table is not None:
+ effective_schema = self.connection.schema_for_object(
+ column.table
+ )
+ else:
+ effective_schema = None
+
+ if effective_schema is not None:
+ exc = 'select nextval(\'"%s"."%s"\')' % (
+ effective_schema,
+ seq_name,
+ )
+ else:
+ exc = "select nextval('\"%s\"')" % (seq_name,)
+
+ return self._execute_scalar(exc, column.type)
+
+ return super(PGExecutionContext, self).get_insert_default(column)
+
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_REGEXP.match(statement)
+
+
+class PGReadOnlyConnectionCharacteristic(
+ characteristics.ConnectionCharacteristic
+):
+ transactional = True
+
+ def reset_characteristic(self, dialect, dbapi_conn):
+ dialect.set_readonly(dbapi_conn, False)
+
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ dialect.set_readonly(dbapi_conn, value)
+
+ def get_characteristic(self, dialect, dbapi_conn):
+ return dialect.get_readonly(dbapi_conn)
+
+
+class PGDeferrableConnectionCharacteristic(
+ characteristics.ConnectionCharacteristic
+):
+ transactional = True
+
+ def reset_characteristic(self, dialect, dbapi_conn):
+ dialect.set_deferrable(dbapi_conn, False)
+
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ dialect.set_deferrable(dbapi_conn, value)
+
+ def get_characteristic(self, dialect, dbapi_conn):
+ return dialect.get_deferrable(dbapi_conn)
+
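+# Both characteristics above are selected through
+# Connection.execution_options(), e.g. (illustrative):
+#   conn = engine.connect().execution_options(postgresql_readonly=True)
+# The set_readonly() / set_deferrable() hooks they invoke are implemented
+# by the DBAPI-specific dialects (e.g. psycopg2, asyncpg).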
+
+class PGDialect(default.DefaultDialect):
+ name = "postgresql"
+ supports_statement_cache = True
+ supports_alter = True
+ max_identifier_length = 63
+ supports_sane_rowcount = True
+
+ supports_native_enum = True
+ supports_native_boolean = True
+ supports_smallserial = True
+
+ supports_sequences = True
+ sequences_optional = True
+ preexecute_autoincrement_sequences = True
+ postfetch_lastrowid = False
+
+ supports_comments = True
+ supports_default_values = True
+
+ supports_default_metavalue = True
+
+ supports_empty_insert = False
+ supports_multivalues_insert = True
+ supports_identity_columns = True
+
+ default_paramstyle = "pyformat"
+ ischema_names = ischema_names
+ colspecs = colspecs
+
+ statement_compiler = PGCompiler
+ ddl_compiler = PGDDLCompiler
+ type_compiler = PGTypeCompiler
+ preparer = PGIdentifierPreparer
+ execution_ctx_cls = PGExecutionContext
+ inspector = PGInspector
+ isolation_level = None
+
+ implicit_returning = True
+ full_returning = True
+
+ connection_characteristics = (
+ default.DefaultDialect.connection_characteristics
+ )
+ connection_characteristics = connection_characteristics.union(
+ {
+ "postgresql_readonly": PGReadOnlyConnectionCharacteristic(),
+ "postgresql_deferrable": PGDeferrableConnectionCharacteristic(),
+ }
+ )
+
+ construct_arguments = [
+ (
+ schema.Index,
+ {
+ "using": False,
+ "include": None,
+ "where": None,
+ "ops": {},
+ "concurrently": False,
+ "with": {},
+ "tablespace": None,
+ },
+ ),
+ (
+ schema.Table,
+ {
+ "ignore_search_path": False,
+ "tablespace": None,
+ "partition_by": None,
+ "with_oids": None,
+ "on_commit": None,
+ "inherits": None,
+ },
+ ),
+ (
+ schema.CheckConstraint,
+ {
+ "not_valid": False,
+ },
+ ),
+ (
+ schema.ForeignKeyConstraint,
+ {
+ "not_valid": False,
+ },
+ ),
+ ]
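+
+    # the keys above surface as keyword arguments prefixed with
+    # "postgresql_" on the corresponding constructs, e.g. (illustrative):
+    #   Index("ix_data", tbl.c.data, postgresql_using="gin",
+    #         postgresql_concurrently=True)
+    #   Table("t", metadata, ..., postgresql_partition_by="RANGE (id)")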
+
+ reflection_options = ("postgresql_ignore_search_path",)
+
+ _backslash_escapes = True
+ _supports_create_index_concurrently = True
+ _supports_drop_index_concurrently = True
+
+ def __init__(
+ self,
+ isolation_level=None,
+ json_serializer=None,
+ json_deserializer=None,
+ **kwargs
+ ):
+ default.DefaultDialect.__init__(self, **kwargs)
+
+        # the isolation_level parameter to the PGDialect itself is legacy;
+        # it still works, but the execution_options() method is the
+        # documented approach.
+ self.isolation_level = isolation_level
+ self._json_deserializer = json_deserializer
+ self._json_serializer = json_serializer
+
+ def initialize(self, connection):
+ super(PGDialect, self).initialize(connection)
+
+ if self.server_version_info <= (8, 2):
+ self.full_returning = self.implicit_returning = False
+
+ self.supports_native_enum = self.server_version_info >= (8, 3)
+ if not self.supports_native_enum:
+ self.colspecs = self.colspecs.copy()
+ # pop base Enum type
+ self.colspecs.pop(sqltypes.Enum, None)
+ # psycopg2, others may have placed ENUM here as well
+ self.colspecs.pop(ENUM, None)
+
+ # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
+ self.supports_smallserial = self.server_version_info >= (9, 2)
+
+ if self.server_version_info < (8, 2):
+ self._backslash_escapes = False
+ else:
+ # ensure this query is not emitted on server version < 8.2
+ # as it will fail
+ std_string = connection.exec_driver_sql(
+ "show standard_conforming_strings"
+ ).scalar()
+ self._backslash_escapes = std_string == "off"
+
+ self._supports_create_index_concurrently = (
+ self.server_version_info >= (8, 2)
+ )
+ self._supports_drop_index_concurrently = self.server_version_info >= (
+ 9,
+ 2,
+ )
+ self.supports_identity_columns = self.server_version_info >= (10,)
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ _isolation_lookup = set(
+ [
+ "SERIALIZABLE",
+ "READ UNCOMMITTED",
+ "READ COMMITTED",
+ "REPEATABLE READ",
+ ]
+ )
+
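+    # the isolation level is normally chosen per-connection via
+    # execution_options(), e.g. (illustrative):
+    #   conn = engine.connect().execution_options(
+    #       isolation_level="SERIALIZABLE")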
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+ if level not in self._isolation_lookup:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+ cursor = connection.cursor()
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION "
+ "ISOLATION LEVEL %s" % level
+ )
+ cursor.execute("COMMIT")
+ cursor.close()
+
+ def get_isolation_level(self, connection):
+ cursor = connection.cursor()
+ cursor.execute("show transaction isolation level")
+ val = cursor.fetchone()[0]
+ cursor.close()
+ return val.upper()
+
+ def set_readonly(self, connection, value):
+ raise NotImplementedError()
+
+ def get_readonly(self, connection):
+ raise NotImplementedError()
+
+ def set_deferrable(self, connection, value):
+ raise NotImplementedError()
+
+ def get_deferrable(self, connection):
+ raise NotImplementedError()
+
+ def do_begin_twophase(self, connection, xid):
+ self.do_begin(connection.connection)
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.exec_driver_sql("PREPARE TRANSACTION '%s'" % xid)
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if is_prepared:
+ if recover:
+ # FIXME: ugly hack to get out of transaction
+ # context when committing recoverable transactions
+ # Must find out a way how to make the dbapi not
+ # open a transaction.
+ connection.exec_driver_sql("ROLLBACK")
+ connection.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid)
+ connection.exec_driver_sql("BEGIN")
+ self.do_rollback(connection.connection)
+ else:
+ self.do_rollback(connection.connection)
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ if is_prepared:
+ if recover:
+ connection.exec_driver_sql("ROLLBACK")
+ connection.exec_driver_sql("COMMIT PREPARED '%s'" % xid)
+ connection.exec_driver_sql("BEGIN")
+ self.do_rollback(connection.connection)
+ else:
+ self.do_commit(connection.connection)
+
+ def do_recover_twophase(self, connection):
+ resultset = connection.execute(
+ sql.text("SELECT gid FROM pg_prepared_xacts")
+ )
+ return [row[0] for row in resultset]
+
+ def _get_default_schema_name(self, connection):
+ return connection.exec_driver_sql("select current_schema()").scalar()
+
+ def has_schema(self, connection, schema):
+ query = (
+ "select nspname from pg_namespace " "where lower(nspname)=:schema"
+ )
+ cursor = connection.execute(
+ sql.text(query).bindparams(
+ sql.bindparam(
+ "schema",
+ util.text_type(schema.lower()),
+ type_=sqltypes.Unicode,
+ )
+ )
+ )
+
+ return bool(cursor.first())
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+ # seems like case gets folded in pg_class...
+ if schema is None:
+ cursor = connection.execute(
+ sql.text(
+ "select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where "
+ "pg_catalog.pg_table_is_visible(c.oid) "
+ "and relname=:name"
+ ).bindparams(
+ sql.bindparam(
+ "name",
+ util.text_type(table_name),
+ type_=sqltypes.Unicode,
+ )
+ )
+ )
+ else:
+ cursor = connection.execute(
+ sql.text(
+ "select relname from pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where n.nspname=:schema and "
+ "relname=:name"
+ ).bindparams(
+ sql.bindparam(
+ "name",
+ util.text_type(table_name),
+ type_=sqltypes.Unicode,
+ ),
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ )
+ )
+ return bool(cursor.first())
+
+ def has_sequence(self, connection, sequence_name, schema=None):
+ if schema is None:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT relname FROM pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where relkind='S' and "
+ "n.nspname=:schema and relname=:name"
+ ).bindparams(
+ sql.bindparam(
+ "name",
+ util.text_type(sequence_name),
+ type_=sqltypes.Unicode,
+ ),
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ )
+ )
+
+ return bool(cursor.first())
+
+ def has_type(self, connection, type_name, schema=None):
+ if schema is not None:
+ query = """
+ SELECT EXISTS (
+ SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
+ WHERE t.typnamespace = n.oid
+ AND t.typname = :typname
+ AND n.nspname = :nspname
+ )
+ """
+ query = sql.text(query)
+ else:
+ query = """
+ SELECT EXISTS (
+ SELECT * FROM pg_catalog.pg_type t
+ WHERE t.typname = :typname
+ AND pg_type_is_visible(t.oid)
+ )
+ """
+ query = sql.text(query)
+ query = query.bindparams(
+ sql.bindparam(
+ "typname", util.text_type(type_name), type_=sqltypes.Unicode
+ )
+ )
+ if schema is not None:
+ query = query.bindparams(
+ sql.bindparam(
+ "nspname", util.text_type(schema), type_=sqltypes.Unicode
+ )
+ )
+ cursor = connection.execute(query)
+ return bool(cursor.scalar())
+
+ def _get_server_version_info(self, connection):
+ v = connection.exec_driver_sql("select pg_catalog.version()").scalar()
+ m = re.match(
+ r".*(?:PostgreSQL|EnterpriseDB) "
+ r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?",
+ v,
+ )
+ if not m:
+ raise AssertionError(
+ "Could not determine version from string '%s'" % v
+ )
+ return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])
+
+ @reflection.cache
+ def get_table_oid(self, connection, table_name, schema=None, **kw):
+ """Fetch the oid for schema.table_name.
+
+ Several reflection methods require the table oid. The idea for using
+ this method is that it can be fetched one time and cached for
+ subsequent calls.
+
+ """
+ table_oid = None
+ if schema is not None:
+ schema_where_clause = "n.nspname = :schema"
+ else:
+ schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
+ query = (
+ """
+ SELECT c.oid
+ FROM pg_catalog.pg_class c
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+ WHERE (%s)
+ AND c.relname = :table_name AND c.relkind in
+ ('r', 'v', 'm', 'f', 'p')
+ """
+ % schema_where_clause
+ )
+ # Since we're binding to unicode, table_name and schema_name must be
+ # unicode.
+ table_name = util.text_type(table_name)
+ if schema is not None:
+ schema = util.text_type(schema)
+ s = sql.text(query).bindparams(table_name=sqltypes.Unicode)
+ s = s.columns(oid=sqltypes.Integer)
+ if schema:
+ s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode))
+ c = connection.execute(s, dict(table_name=table_name, schema=schema))
+ table_oid = c.scalar()
+ if table_oid is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_oid
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ result = connection.execute(
+ sql.text(
+ "SELECT nspname FROM pg_namespace "
+ "WHERE nspname NOT LIKE 'pg_%' "
+ "ORDER BY nspname"
+ ).columns(nspname=sqltypes.Unicode)
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ result = connection.execute(
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
+ ).columns(relname=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def _get_foreign_table_names(self, connection, schema=None, **kw):
+ result = connection.execute(
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind = 'f'"
+ ).columns(relname=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def get_view_names(
+ self, connection, schema=None, include=("plain", "materialized"), **kw
+ ):
+
+ include_kind = {"plain": "v", "materialized": "m"}
+ try:
+ kinds = [include_kind[i] for i in util.to_list(include)]
+ except KeyError:
+ raise ValueError(
+ "include %r unknown, needs to be a sequence containing "
+ "one or both of 'plain' and 'materialized'" % (include,)
+ )
+ if not kinds:
+ raise ValueError(
+ "empty include, needs to be a sequence containing "
+ "one or both of 'plain' and 'materialized'"
+ )
+
+ result = connection.execute(
+ sql.text(
+ "SELECT c.relname FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relkind IN (%s)"
+ % (", ".join("'%s'" % elem for elem in kinds))
+ ).columns(relname=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name
+ ),
+ )
+ return [name for name, in result]
+
+ @reflection.cache
+ def get_sequence_names(self, connection, schema=None, **kw):
+ if not schema:
+ schema = self.default_schema_name
+ cursor = connection.execute(
+ sql.text(
+ "SELECT relname FROM pg_class c join pg_namespace n on "
+ "n.oid=c.relnamespace where relkind='S' and "
+ "n.nspname=:schema"
+ ).bindparams(
+ sql.bindparam(
+ "schema",
+ util.text_type(schema),
+ type_=sqltypes.Unicode,
+ ),
+ )
+ )
+ return [row[0] for row in cursor]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ view_def = connection.scalar(
+ sql.text(
+ "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
+ "JOIN pg_namespace n ON n.oid = c.relnamespace "
+ "WHERE n.nspname = :schema AND c.relname = :view_name "
+ "AND c.relkind IN ('v', 'm')"
+ ).columns(view_def=sqltypes.Unicode),
+ dict(
+ schema=schema
+ if schema is not None
+ else self.default_schema_name,
+ view_name=view_name,
+ ),
+ )
+ return view_def
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ generated = (
+ "a.attgenerated as generated"
+ if self.server_version_info >= (12,)
+ else "NULL as generated"
+ )
+ if self.server_version_info >= (10,):
+            # a.attidentity != '' is required here, otherwise serial
+            # columns would also be reflected as identity columns.
+ identity = """\
+ (SELECT json_build_object(
+ 'always', a.attidentity = 'a',
+ 'start', s.seqstart,
+ 'increment', s.seqincrement,
+ 'minvalue', s.seqmin,
+ 'maxvalue', s.seqmax,
+ 'cache', s.seqcache,
+ 'cycle', s.seqcycle)
+ FROM pg_catalog.pg_sequence s
+ JOIN pg_catalog.pg_class c on s.seqrelid = c."oid"
+ WHERE c.relkind = 'S'
+ AND a.attidentity != ''
+ AND s.seqrelid = pg_catalog.pg_get_serial_sequence(
+ a.attrelid::regclass::text, a.attname
+ )::regclass::oid
+ ) as identity_options\
+ """
+ else:
+ identity = "NULL as identity_options"
+
+ SQL_COLS = """
+ SELECT a.attname,
+ pg_catalog.format_type(a.atttypid, a.atttypmod),
+ (
+ SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid)
+ FROM pg_catalog.pg_attrdef d
+ WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum
+ AND a.atthasdef
+ ) AS DEFAULT,
+ a.attnotnull,
+ a.attrelid as table_oid,
+ pgd.description as comment,
+ %s,
+ %s
+ FROM pg_catalog.pg_attribute a
+ LEFT JOIN pg_catalog.pg_description pgd ON (
+ pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum)
+ WHERE a.attrelid = :table_oid
+ AND a.attnum > 0 AND NOT a.attisdropped
+ ORDER BY a.attnum
+ """ % (
+ generated,
+ identity,
+ )
+ s = (
+ sql.text(SQL_COLS)
+ .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
+ .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
+ )
+ c = connection.execute(s, dict(table_oid=table_oid))
+ rows = c.fetchall()
+
+ # dictionary with (name, ) if default search path or (schema, name)
+ # as keys
+ domains = self._load_domains(connection)
+
+ # dictionary with (name, ) if default search path or (schema, name)
+ # as keys
+ enums = dict(
+ ((rec["name"],), rec)
+ if rec["visible"]
+ else ((rec["schema"], rec["name"]), rec)
+ for rec in self._load_enums(connection, schema="*")
+ )
+
+ # format columns
+ columns = []
+
+ for (
+ name,
+ format_type,
+ default_,
+ notnull,
+ table_oid,
+ comment,
+ generated,
+ identity,
+ ) in rows:
+ column_info = self._get_column_info(
+ name,
+ format_type,
+ default_,
+ notnull,
+ domains,
+ enums,
+ schema,
+ comment,
+ generated,
+ identity,
+ )
+ columns.append(column_info)
+ return columns
+
+ def _get_column_info(
+ self,
+ name,
+ format_type,
+ default,
+ notnull,
+ domains,
+ enums,
+ schema,
+ comment,
+ generated,
+ identity,
+ ):
+ def _handle_array_type(attype):
+ return (
+ # strip '[]' from integer[], etc.
+ re.sub(r"\[\]$", "", attype),
+ attype.endswith("[]"),
+ )
+
+ # strip (*) from character varying(5), timestamp(5)
+ # with time zone, geometry(POLYGON), etc.
+ attype = re.sub(r"\(.*\)", "", format_type)
+
+ # strip '[]' from integer[], etc. and check if an array
+ attype, is_array = _handle_array_type(attype)
+
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+ nullable = not notnull
+
+ charlen = re.search(r"\(([\d,]+)\)", format_type)
+ if charlen:
+ charlen = charlen.group(1)
+ args = re.search(r"\((.*)\)", format_type)
+ if args and args.group(1):
+ args = tuple(re.split(r"\s*,\s*", args.group(1)))
+ else:
+ args = ()
+ kwargs = {}
+
+ if attype == "numeric":
+ if charlen:
+ prec, scale = charlen.split(",")
+ args = (int(prec), int(scale))
+ else:
+ args = ()
+ elif attype == "double precision":
+ args = (53,)
+ elif attype == "integer":
+ args = ()
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype == "bit varying":
+ kwargs["varying"] = True
+ if charlen:
+ args = (int(charlen),)
+ else:
+ args = ()
+ elif attype.startswith("interval"):
+ field_match = re.match(r"interval (.+)", attype, re.I)
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ if field_match:
+ kwargs["fields"] = field_match.group(1)
+ attype = "interval"
+ args = ()
+ elif charlen:
+ args = (int(charlen),)
+
+ while True:
+ # looping here to suit nested domains
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
+ break
+ elif enum_or_domain_key in enums:
+ enum = enums[enum_or_domain_key]
+ coltype = ENUM
+ kwargs["name"] = enum["name"]
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
+ break
+ elif enum_or_domain_key in domains:
+ domain = domains[enum_or_domain_key]
+ attype = domain["attype"]
+ attype, is_array = _handle_array_type(attype)
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+ # A table can't override a not null on the domain,
+ # but can override nullable
+ nullable = nullable and domain["nullable"]
+ if domain["default"] and not default:
+ # It can, however, override the default
+ # value, but can't set it to null.
+ default = domain["default"]
+ continue
+ else:
+ coltype = None
+ break
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = self.ischema_names["_array"](coltype)
+ else:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (attype, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+        # attgenerated is a zero byte or a blank string, depending on the
+        # driver (and is absent entirely on older PG versions), when the
+        # column is not a generated column; a value of 's' means stored.
+        # (Other values might be added in the future.)
+ if generated not in (None, "", b"\x00"):
+ computed = dict(
+ sqltext=default, persisted=generated in ("s", b"s")
+ )
+ default = None
+ else:
+ computed = None
+
+ # adjust the default value
+ autoincrement = False
+ if default is not None:
+ match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+ if match is not None:
+ if issubclass(coltype._type_affinity, sqltypes.Integer):
+ autoincrement = True
+ # the default is related to a Sequence
+ sch = schema
+ if "." not in match.group(2) and sch is not None:
+ # unconditionally quote the schema name. this could
+ # later be enhanced to obey quoting rules /
+ # "quote schema"
+ default = (
+ match.group(1)
+ + ('"%s"' % sch)
+ + "."
+ + match.group(2)
+ + match.group(3)
+ )
+
+ column_info = dict(
+ name=name,
+ type=coltype,
+ nullable=nullable,
+ default=default,
+ autoincrement=autoincrement or identity is not None,
+ comment=comment,
+ )
+ if computed is not None:
+ column_info["computed"] = computed
+ if identity is not None:
+ column_info["identity"] = identity
+ return column_info
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ if self.server_version_info < (8, 4):
+ PK_SQL = """
+ SELECT a.attname
+ FROM
+ pg_class t
+ join pg_index ix on t.oid = ix.indrelid
+ join pg_attribute a
+ on t.oid=a.attrelid AND %s
+ WHERE
+ t.oid = :table_oid and ix.indisprimary = 't'
+ ORDER BY a.attnum
+ """ % self._pg_index_any(
+ "a.attnum", "ix.indkey"
+ )
+
+ else:
+ # unnest() and generate_subscripts() both introduced in
+ # version 8.4
+ PK_SQL = """
+ SELECT a.attname
+ FROM pg_attribute a JOIN (
+ SELECT unnest(ix.indkey) attnum,
+ generate_subscripts(ix.indkey, 1) ord
+ FROM pg_index ix
+ WHERE ix.indrelid = :table_oid AND ix.indisprimary
+ ) k ON a.attnum=k.attnum
+ WHERE a.attrelid = :table_oid
+ ORDER BY k.ord
+ """
+ t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
+ c = connection.execute(t, dict(table_oid=table_oid))
+ cols = [r[0] for r in c.fetchall()]
+
+ PK_CONS_SQL = """
+ SELECT conname
+ FROM pg_catalog.pg_constraint r
+ WHERE r.conrelid = :table_oid AND r.contype = 'p'
+ ORDER BY 1
+ """
+ t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
+ c = connection.execute(t, dict(table_oid=table_oid))
+ name = c.scalar()
+
+ return {"constrained_columns": cols, "name": name}
+
+ @reflection.cache
+ def get_foreign_keys(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ postgresql_ignore_search_path=False,
+ **kw
+ ):
+ preparer = self.identifier_preparer
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ FK_SQL = """
+ SELECT r.conname,
+ pg_catalog.pg_get_constraintdef(r.oid, true) as condef,
+ n.nspname as conschema
+ FROM pg_catalog.pg_constraint r,
+ pg_namespace n,
+ pg_class c
+
+ WHERE r.conrelid = :table AND
+ r.contype = 'f' AND
+ c.oid = confrelid AND
+ n.oid = c.relnamespace
+ ORDER BY 1
+ """
+ # https://www.postgresql.org/docs/9.0/static/sql-createtable.html
+ FK_REGEX = re.compile(
+ r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
+ r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
+ r"[\s]?(ON UPDATE "
+ r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
+ r"[\s]?(ON DELETE "
+ r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
+ r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?"
+ r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
+ )
+
+ t = sql.text(FK_SQL).columns(
+ conname=sqltypes.Unicode, condef=sqltypes.Unicode
+ )
+ c = connection.execute(t, dict(table=table_oid))
+ fkeys = []
+ for conname, condef, conschema in c.fetchall():
+ m = re.search(FK_REGEX, condef).groups()
+
+ (
+ constrained_columns,
+ referred_schema,
+ referred_table,
+ referred_columns,
+ _,
+ match,
+ _,
+ onupdate,
+ _,
+ ondelete,
+ deferrable,
+ _,
+ initially,
+ ) = m
+
+ if deferrable is not None:
+ deferrable = True if deferrable == "DEFERRABLE" else False
+ constrained_columns = [
+ preparer._unquote_identifier(x)
+ for x in re.split(r"\s*,\s*", constrained_columns)
+ ]
+
+ if postgresql_ignore_search_path:
+ # when ignoring search path, we use the actual schema
+ # provided it isn't the "default" schema
+ if conschema != self.default_schema_name:
+ referred_schema = conschema
+ else:
+ referred_schema = schema
+ elif referred_schema:
+ # referred_schema is the schema that we regexp'ed from
+ # pg_get_constraintdef(). If the schema is in the search
+ # path, pg_get_constraintdef() will give us None.
+ referred_schema = preparer._unquote_identifier(referred_schema)
+ elif schema is not None and schema == conschema:
+ # If the actual schema matches the schema of the table
+ # we're reflecting, then we will use that.
+ referred_schema = schema
+
+ referred_table = preparer._unquote_identifier(referred_table)
+ referred_columns = [
+ preparer._unquote_identifier(x)
+ for x in re.split(r"\s*,\s", referred_columns)
+ ]
+ options = {
+ k: v
+ for k, v in [
+ ("onupdate", onupdate),
+ ("ondelete", ondelete),
+ ("initially", initially),
+ ("deferrable", deferrable),
+ ("match", match),
+ ]
+ if v is not None and v != "NO ACTION"
+ }
+ fkey_d = {
+ "name": conname,
+ "constrained_columns": constrained_columns,
+ "referred_schema": referred_schema,
+ "referred_table": referred_table,
+ "referred_columns": referred_columns,
+ "options": options,
+ }
+ fkeys.append(fkey_d)
+ return fkeys
+
+ def _pg_index_any(self, col, compare_to):
+ if self.server_version_info < (8, 1):
+ # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us
+ # "In CVS tip you could replace this with "attnum = ANY (indkey)".
+ # Unfortunately, most array support doesn't work on int2vector in
+ # pre-8.1 releases, so I think you're kinda stuck with the above
+ # for now.
+ # regards, tom lane"
+ return "(%s)" % " OR ".join(
+ "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10)
+ )
+ else:
+ return "%s = ANY(%s)" % (col, compare_to)
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ # cast indkey as varchar since it's an int2vector,
+ # returned as a list by some drivers such as pypostgresql
+
+ if self.server_version_info < (8, 5):
+ IDX_SQL = """
+ SELECT
+ i.relname as relname,
+ ix.indisunique, ix.indexprs, ix.indpred,
+ a.attname, a.attnum, NULL, ix.indkey%s,
+ %s, %s, am.amname,
+ NULL as indnkeyatts
+ FROM
+ pg_class t
+ join pg_index ix on t.oid = ix.indrelid
+ join pg_class i on i.oid = ix.indexrelid
+ left outer join
+ pg_attribute a
+ on t.oid = a.attrelid and %s
+ left outer join
+ pg_am am
+ on i.relam = am.oid
+ WHERE
+ t.relkind IN ('r', 'v', 'f', 'm')
+ and t.oid = :table_oid
+ and ix.indisprimary = 'f'
+ ORDER BY
+ t.relname,
+ i.relname
+ """ % (
+                # the 8.3 cutoff here was based on observing that the
+                # cast does not work in PG 8.2.4 but does work in 8.3.0;
+                # nothing in the PG changelogs mentions this.
+ "::varchar" if self.server_version_info >= (8, 3) else "",
+ "ix.indoption::varchar"
+ if self.server_version_info >= (8, 3)
+ else "NULL",
+ "i.reloptions"
+ if self.server_version_info >= (8, 2)
+ else "NULL",
+ self._pg_index_any("a.attnum", "ix.indkey"),
+ )
+ else:
+ IDX_SQL = """
+ SELECT
+ i.relname as relname,
+ ix.indisunique, ix.indexprs,
+ a.attname, a.attnum, c.conrelid, ix.indkey::varchar,
+ ix.indoption::varchar, i.reloptions, am.amname,
+ pg_get_expr(ix.indpred, ix.indrelid),
+ %s as indnkeyatts
+ FROM
+ pg_class t
+ join pg_index ix on t.oid = ix.indrelid
+ join pg_class i on i.oid = ix.indexrelid
+ left outer join
+ pg_attribute a
+ on t.oid = a.attrelid and a.attnum = ANY(ix.indkey)
+ left outer join
+ pg_constraint c
+ on (ix.indrelid = c.conrelid and
+ ix.indexrelid = c.conindid and
+ c.contype in ('p', 'u', 'x'))
+ left outer join
+ pg_am am
+ on i.relam = am.oid
+ WHERE
+ t.relkind IN ('r', 'v', 'f', 'm', 'p')
+ and t.oid = :table_oid
+ and ix.indisprimary = 'f'
+ ORDER BY
+ t.relname,
+ i.relname
+ """ % (
+ "ix.indnkeyatts"
+ if self.server_version_info >= (11, 0)
+ else "NULL",
+ )
+
+ t = sql.text(IDX_SQL).columns(
+ relname=sqltypes.Unicode, attname=sqltypes.Unicode
+ )
+ c = connection.execute(t, dict(table_oid=table_oid))
+
+ indexes = defaultdict(lambda: defaultdict(dict))
+
+ sv_idx_name = None
+ for row in c.fetchall():
+ (
+ idx_name,
+ unique,
+ expr,
+ col,
+ col_num,
+ conrelid,
+ idx_key,
+ idx_option,
+ options,
+ amname,
+ filter_definition,
+ indnkeyatts,
+ ) = row
+
+ if expr:
+ if idx_name != sv_idx_name:
+ util.warn(
+ "Skipped unsupported reflection of "
+ "expression-based index %s" % idx_name
+ )
+ sv_idx_name = idx_name
+ continue
+
+ has_idx = idx_name in indexes
+ index = indexes[idx_name]
+ if col is not None:
+ index["cols"][col_num] = col
+ if not has_idx:
+ idx_keys = idx_key.split()
+ # "The number of key columns in the index, not counting any
+ # included columns, which are merely stored and do not
+ # participate in the index semantics"
+ if indnkeyatts and idx_keys[indnkeyatts:]:
+ # this is a "covering index" which has INCLUDE columns
+ # as well as regular index columns
+ inc_keys = idx_keys[indnkeyatts:]
+ idx_keys = idx_keys[:indnkeyatts]
+ else:
+ inc_keys = []
+
+ index["key"] = [int(k.strip()) for k in idx_keys]
+ index["inc"] = [int(k.strip()) for k in inc_keys]
+
+ # (new in pg 8.3)
+ # "pg_index.indoption" is list of ints, one per column/expr.
+ # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST
+ sorting = {}
+ for col_idx, col_flags in enumerate(
+ (idx_option or "").split()
+ ):
+ col_flags = int(col_flags.strip())
+ col_sorting = ()
+ # try to set flags only if they differ from PG defaults...
+ if col_flags & 0x01:
+ col_sorting += ("desc",)
+ if not (col_flags & 0x02):
+ col_sorting += ("nulls_last",)
+ else:
+ if col_flags & 0x02:
+ col_sorting += ("nulls_first",)
+ if col_sorting:
+ sorting[col_idx] = col_sorting
+ if sorting:
+ index["sorting"] = sorting
+
+ index["unique"] = unique
+ if conrelid is not None:
+ index["duplicates_constraint"] = idx_name
+ if options:
+ index["options"] = dict(
+ [option.split("=") for option in options]
+ )
+
+ # it *might* be nice to include that this is 'btree' in the
+ # reflection info. But we don't want an Index object
+ # to have a ``postgresql_using`` in it that is just the
+ # default, so for the moment leaving this out.
+ if amname and amname != "btree":
+ index["amname"] = amname
+
+ if filter_definition:
+ index["postgresql_where"] = filter_definition
+
+ result = []
+ for name, idx in indexes.items():
+ entry = {
+ "name": name,
+ "unique": idx["unique"],
+ "column_names": [idx["cols"][i] for i in idx["key"]],
+ }
+ if self.server_version_info >= (11, 0):
+ # NOTE: this is legacy, this is part of dialect_options now
+ # as of #7382
+ entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]]
+ if "duplicates_constraint" in idx:
+ entry["duplicates_constraint"] = idx["duplicates_constraint"]
+ if "sorting" in idx:
+ entry["column_sorting"] = dict(
+ (idx["cols"][idx["key"][i]], value)
+ for i, value in idx["sorting"].items()
+ )
+ if "include_columns" in entry:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_include"
+ ] = entry["include_columns"]
+ if "options" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_with"
+ ] = idx["options"]
+ if "amname" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_using"
+ ] = idx["amname"]
+ if "postgresql_where" in idx:
+ entry.setdefault("dialect_options", {})[
+ "postgresql_where"
+ ] = idx["postgresql_where"]
+ result.append(entry)
+ return result
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ UNIQUE_SQL = """
+ SELECT
+ cons.conname as name,
+ cons.conkey as key,
+ a.attnum as col_num,
+ a.attname as col_name
+ FROM
+ pg_catalog.pg_constraint cons
+ join pg_attribute a
+ on cons.conrelid = a.attrelid AND
+ a.attnum = ANY(cons.conkey)
+ WHERE
+ cons.conrelid = :table_oid AND
+ cons.contype = 'u'
+ """
+
+ t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
+ c = connection.execute(t, dict(table_oid=table_oid))
+
+ uniques = defaultdict(lambda: defaultdict(dict))
+ for row in c.fetchall():
+ uc = uniques[row.name]
+ uc["key"] = row.key
+ uc["cols"][row.col_num] = row.col_name
+
+ return [
+ {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
+ for name, uc in uniques.items()
+ ]
+
+ @reflection.cache
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ COMMENT_SQL = """
+ SELECT
+ pgd.description as table_comment
+ FROM
+ pg_catalog.pg_description pgd
+ WHERE
+ pgd.objsubid = 0 AND
+ pgd.objoid = :table_oid
+ """
+
+ c = connection.execute(
+ sql.text(COMMENT_SQL), dict(table_oid=table_oid)
+ )
+ return {"text": c.scalar()}
+
+ @reflection.cache
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ table_oid = self.get_table_oid(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ CHECK_SQL = """
+ SELECT
+ cons.conname as name,
+ pg_get_constraintdef(cons.oid) as src
+ FROM
+ pg_catalog.pg_constraint cons
+ WHERE
+ cons.conrelid = :table_oid AND
+ cons.contype = 'c'
+ """
+
+ c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
+
+ ret = []
+ for name, src in c:
+ # samples:
+ # "CHECK (((a > 1) AND (a < 5)))"
+ # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))"
+ # "CHECK (((a > 1) AND (a < 5))) NOT VALID"
+ # "CHECK (some_boolean_function(a))"
+ # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)"
+
+ m = re.match(
+ r"^CHECK *\((.+)\)( NOT VALID)?$", src, flags=re.DOTALL
+ )
+ if not m:
+ util.warn("Could not parse CHECK constraint text: %r" % src)
+ sqltext = ""
+ else:
+ sqltext = re.compile(
+ r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL
+ ).sub(r"\1", m.group(1))
+ entry = {"name": name, "sqltext": sqltext}
+ if m and m.group(2):
+ entry["dialect_options"] = {"not_valid": True}
+
+ ret.append(entry)
+ return ret
+
+ def _load_enums(self, connection, schema=None):
+ schema = schema or self.default_schema_name
+ if not self.supports_native_enum:
+ return {}
+
+ # Load data types for enums:
+ SQL_ENUMS = """
+ SELECT t.typname as "name",
+ -- no enum defaults in 8.4 at least
+ -- t.typdefault as "default",
+ pg_catalog.pg_type_is_visible(t.oid) as "visible",
+ n.nspname as "schema",
+ e.enumlabel as "label"
+ FROM pg_catalog.pg_type t
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+ LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
+ WHERE t.typtype = 'e'
+ """
+
+ if schema != "*":
+ SQL_ENUMS += "AND n.nspname = :schema "
+
+ # e.oid gives us label order within an enum
+ SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
+
+ s = sql.text(SQL_ENUMS).columns(
+ attname=sqltypes.Unicode, label=sqltypes.Unicode
+ )
+
+ if schema != "*":
+ s = s.bindparams(schema=schema)
+
+ c = connection.execute(s)
+
+ enums = []
+ enum_by_name = {}
+ for enum in c.fetchall():
+ key = (enum.schema, enum.name)
+ if key in enum_by_name:
+ enum_by_name[key]["labels"].append(enum.label)
+ else:
+ enum_by_name[key] = enum_rec = {
+ "name": enum.name,
+ "schema": enum.schema,
+ "visible": enum.visible,
+ "labels": [],
+ }
+ if enum.label is not None:
+ enum_rec["labels"].append(enum.label)
+ enums.append(enum_rec)
+ return enums
+
+ def _load_domains(self, connection):
+ # Load data types for domains:
+ SQL_DOMAINS = """
+ SELECT t.typname as "name",
+ pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
+ not t.typnotnull as "nullable",
+ t.typdefault as "default",
+ pg_catalog.pg_type_is_visible(t.oid) as "visible",
+ n.nspname as "schema"
+ FROM pg_catalog.pg_type t
+ LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
+ WHERE t.typtype = 'd'
+ """
+
+ s = sql.text(SQL_DOMAINS)
+ c = connection.execution_options(future_result=True).execute(s)
+
+ domains = {}
+ for domain in c.mappings():
+ # strip (30) from character varying(30)
+ attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
+ # 'visible' just means whether or not the domain is in a
+ # schema that's on the search path -- or not overridden by
+ # a schema with higher precedence. If it's not visible,
+ # it will be prefixed with the schema-name when it's used.
+ if domain["visible"]:
+ key = (domain["name"],)
+ else:
+ key = (domain["schema"], domain["name"])
+
+ domains[key] = {
+ "attype": attype,
+ "nullable": domain["nullable"],
+ "default": domain["default"],
+ }
+
+ return domains
diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py
new file mode 100644
index 0000000..b483774
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/dml.py
@@ -0,0 +1,274 @@
+# postgresql/dml.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import ext
+from ... import util
+from ...sql import coercions
+from ...sql import roles
+from ...sql import schema
+from ...sql.base import _exclusive_against
+from ...sql.base import _generative
+from ...sql.base import ColumnCollection
+from ...sql.dml import Insert as StandardInsert
+from ...sql.elements import ClauseElement
+from ...sql.expression import alias
+from ...util.langhelpers import public_factory
+
+
+__all__ = ("Insert", "insert")
+
+
+class Insert(StandardInsert):
+ """PostgreSQL-specific implementation of INSERT.
+
+ Adds methods for PG-specific syntaxes such as ON CONFLICT.
+
+ The :class:`_postgresql.Insert` object is created using the
+ :func:`sqlalchemy.dialects.postgresql.insert` function.
+
+ .. versionadded:: 1.1
+
+ """
+
+ stringify_dialect = "postgresql"
+ inherit_cache = False
+
+ @util.memoized_property
+ def excluded(self):
+        """Provide the ``excluded`` namespace for an ON CONFLICT statement.
+
+        PG's ON CONFLICT clause allows reference to the row that would
+        be inserted, known as ``excluded``. This attribute makes all
+        columns of that row available for reference.
+
+ .. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an
+ instance of :class:`_expression.ColumnCollection`, which provides
+ an interface the same as that of the :attr:`_schema.Table.c`
+ collection described at :ref:`metadata_tables_and_columns`.
+ With this collection, ordinary names are accessible like attributes
+ (e.g. ``stmt.excluded.some_column``), but special names and
+ dictionary method names should be accessed using indexed access,
+ such as ``stmt.excluded["column name"]`` or
+ ``stmt.excluded["values"]``. See the docstring for
+ :class:`_expression.ColumnCollection` for further examples.
+
+ .. seealso::
+
+ :ref:`postgresql_insert_on_conflict` - example of how
+ to use :attr:`_expression.Insert.excluded`
+
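+        A minimal sketch (``my_table`` and ``conn`` are illustrative)::
+
+            from sqlalchemy.dialects.postgresql import insert
+
+            stmt = insert(my_table).values(id=1, data="new value")
+            stmt = stmt.on_conflict_do_update(
+                index_elements=[my_table.c.id],
+                set_=dict(data=stmt.excluded.data),
+            )
+            conn.execute(stmt)
+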
+ """
+ return alias(self.table, name="excluded").columns
+
+ _on_conflict_exclusive = _exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already has "
+ "an ON CONFLICT clause established"
+ },
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_update(
+ self,
+ constraint=None,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ r"""
+        Specifies a DO UPDATE SET action for the ON CONFLICT clause.
+
+ Either the ``constraint`` or ``index_elements`` argument is
+ required, but only one of these can be specified.
+
+ :param constraint:
+ The name of a unique or exclusion constraint on the table,
+ or the constraint object itself if it has a .name attribute.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
+ :param set\_:
+ A dictionary or other mapping object
+ where the keys are either names of columns in the target table,
+ or :class:`_schema.Column` objects or other ORM-mapped columns
+ matching that of the target table, and expressions or literals
+ as values, specifying the ``SET`` actions to take.
+
+ .. versionadded:: 1.4 The
+ :paramref:`_postgresql.Insert.on_conflict_do_update.set_`
+ parameter supports :class:`_schema.Column` objects from the target
+ :class:`_schema.Table` as keys.
+
+ .. warning:: This dictionary does **not** take into account
+ Python-specified default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON CONFLICT style of
+ UPDATE, unless they are manually specified in the
+ :paramref:`.Insert.on_conflict_do_update.set_` dictionary.
+
+ :param where:
+ Optional argument. If present, can be a literal SQL
+ string or an acceptable expression for a ``WHERE`` clause
+ that restricts the rows affected by ``DO UPDATE SET``. Rows
+ not meeting the ``WHERE`` condition will not be updated
+ (effectively a ``DO NOTHING`` for those rows).
+
+ .. versionadded:: 1.1
+
+
+ .. seealso::
+
+ :ref:`postgresql_insert_on_conflict`
+
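+        A minimal sketch (the table, column and connection names are
+        illustrative)::
+
+            from sqlalchemy.dialects.postgresql import insert
+
+            stmt = insert(my_table).values(id=1, data="inserted")
+            stmt = stmt.on_conflict_do_update(
+                index_elements=[my_table.c.id],
+                set_=dict(data="updated"),
+            )
+            conn.execute(stmt)
+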
+ """
+ self._post_values_clause = OnConflictDoUpdate(
+ constraint, index_elements, index_where, set_, where
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_nothing(
+ self, constraint=None, index_elements=None, index_where=None
+ ):
+ """
+        Specifies a DO NOTHING action for the ON CONFLICT clause.
+
+ The ``constraint`` and ``index_elements`` arguments
+ are optional, but only one of these can be specified.
+
+ :param constraint:
+ The name of a unique or exclusion constraint on the table,
+ or the constraint object itself if it has a .name attribute.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`postgresql_insert_on_conflict`
+
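+        A minimal sketch (names are illustrative)::
+
+            from sqlalchemy.dialects.postgresql import insert
+
+            stmt = insert(my_table).values(id=1, data="inserted")
+            stmt = stmt.on_conflict_do_nothing(
+                index_elements=[my_table.c.id]
+            )
+            conn.execute(stmt)
+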
+ """
+ self._post_values_clause = OnConflictDoNothing(
+ constraint, index_elements, index_where
+ )
+
+
+insert = public_factory(
+ Insert, ".dialects.postgresql.insert", ".dialects.postgresql.Insert"
+)
+
+
+class OnConflictClause(ClauseElement):
+ stringify_dialect = "postgresql"
+
+ def __init__(self, constraint=None, index_elements=None, index_where=None):
+
+ if constraint is not None:
+ if not isinstance(constraint, util.string_types) and isinstance(
+ constraint,
+ (schema.Index, schema.Constraint, ext.ExcludeConstraint),
+ ):
+ constraint = getattr(constraint, "name") or constraint
+
+ if constraint is not None:
+ if index_elements is not None:
+ raise ValueError(
+ "'constraint' and 'index_elements' are mutually exclusive"
+ )
+
+ if isinstance(constraint, util.string_types):
+ self.constraint_target = constraint
+ self.inferred_target_elements = None
+ self.inferred_target_whereclause = None
+ elif isinstance(constraint, schema.Index):
+ index_elements = constraint.expressions
+ index_where = constraint.dialect_options["postgresql"].get(
+ "where"
+ )
+ elif isinstance(constraint, ext.ExcludeConstraint):
+ index_elements = constraint.columns
+ index_where = constraint.where
+ else:
+ index_elements = constraint.columns
+ index_where = constraint.dialect_options["postgresql"].get(
+ "where"
+ )
+
+ if index_elements is not None:
+ self.constraint_target = None
+ self.inferred_target_elements = index_elements
+ self.inferred_target_whereclause = index_where
+ elif constraint is None:
+ self.constraint_target = (
+ self.inferred_target_elements
+ ) = self.inferred_target_whereclause = None
+
+
+class OnConflictDoNothing(OnConflictClause):
+ __visit_name__ = "on_conflict_do_nothing"
+
+
+class OnConflictDoUpdate(OnConflictClause):
+ __visit_name__ = "on_conflict_do_update"
+
+ def __init__(
+ self,
+ constraint=None,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ super(OnConflictDoUpdate, self).__init__(
+ constraint=constraint,
+ index_elements=index_elements,
+ index_where=index_where,
+ )
+
+ if (
+ self.inferred_target_elements is None
+ and self.constraint_target is None
+ ):
+ raise ValueError(
+ "Either constraint or index_elements, "
+ "but not both, must be specified unless DO NOTHING"
+ )
+
+ if isinstance(set_, dict):
+ if not set_:
+ raise ValueError("set parameter dictionary must not be empty")
+ elif isinstance(set_, ColumnCollection):
+ set_ = dict(set_)
+ else:
+ raise ValueError(
+ "set parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
+ self.update_values_to_set = [
+ (coercions.expect(roles.DMLColumnRole, key), value)
+ for key, value in set_.items()
+ ]
+ self.update_whereclause = where
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
new file mode 100644
index 0000000..9e52ee1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -0,0 +1,277 @@
+# postgresql/ext.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .array import ARRAY
+from ... import util
+from ...sql import coercions
+from ...sql import elements
+from ...sql import expression
+from ...sql import functions
+from ...sql import roles
+from ...sql import schema
+from ...sql.schema import ColumnCollectionConstraint
+
+
+class aggregate_order_by(expression.ColumnElement):
+ """Represent a PostgreSQL aggregate order by expression.
+
+ E.g.::
+
+ from sqlalchemy.dialects.postgresql import aggregate_order_by
+ expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
+ stmt = select(expr)
+
+ would represent the expression::
+
+ SELECT array_agg(a ORDER BY b DESC) FROM table;
+
+ Similarly::
+
+ expr = func.string_agg(
+ table.c.a,
+ aggregate_order_by(literal_column("','"), table.c.a)
+ )
+ stmt = select(expr)
+
+ Would represent::
+
+ SELECT string_agg(a, ',' ORDER BY a) FROM table;
+
+ .. versionadded:: 1.1
+
+ .. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms
+
+ .. seealso::
+
+ :class:`_functions.array_agg`
+
+ """
+
+ __visit_name__ = "aggregate_order_by"
+
+ stringify_dialect = "postgresql"
+ inherit_cache = False
+
+ def __init__(self, target, *order_by):
+ self.target = coercions.expect(roles.ExpressionElementRole, target)
+ self.type = self.target.type
+
+ _lob = len(order_by)
+ if _lob == 0:
+ raise TypeError("at least one ORDER BY element is required")
+ elif _lob == 1:
+ self.order_by = coercions.expect(
+ roles.ExpressionElementRole, order_by[0]
+ )
+ else:
+ self.order_by = elements.ClauseList(
+ *order_by, _literal_as_text_role=roles.ExpressionElementRole
+ )
+
+ def self_group(self, against=None):
+ return self
+
+ def get_children(self, **kwargs):
+ return self.target, self.order_by
+
+ def _copy_internals(self, clone=elements._clone, **kw):
+ self.target = clone(self.target, **kw)
+ self.order_by = clone(self.order_by, **kw)
+
+ @property
+ def _from_objects(self):
+ return self.target._from_objects + self.order_by._from_objects
+
+
+class ExcludeConstraint(ColumnCollectionConstraint):
+ """A table-level EXCLUDE constraint.
+
+ Defines an EXCLUDE constraint as described in the `PostgreSQL
+ documentation`__.
+
+ __ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
+
+ """ # noqa
+
+ __visit_name__ = "exclude_constraint"
+
+ where = None
+ inherit_cache = False
+
+ create_drop_stringify_dialect = "postgresql"
+
+ @elements._document_text_coercion(
+ "where",
+ ":class:`.ExcludeConstraint`",
+ ":paramref:`.ExcludeConstraint.where`",
+ )
+ def __init__(self, *elements, **kw):
+ r"""
+ Create an :class:`.ExcludeConstraint` object.
+
+ E.g.::
+
+ const = ExcludeConstraint(
+ (Column('period'), '&&'),
+ (Column('group'), '='),
+ where=(Column('group') != 'some group'),
+ ops={'group': 'my_operator_class'}
+ )
+
+ The constraint is normally embedded into the :class:`_schema.Table`
+ construct
+ directly, or added later using :meth:`.append_constraint`::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('period', TSRANGE()),
+ Column('group', String)
+ )
+
+ some_table.append_constraint(
+ ExcludeConstraint(
+ (some_table.c.period, '&&'),
+ (some_table.c.group, '='),
+ where=some_table.c.group != 'some group',
+ name='some_table_excl_const',
+ ops={'group': 'my_operator_class'}
+ )
+ )
+
+ :param \*elements:
+
+ A sequence of two tuples of the form ``(column, operator)`` where
+ "column" is a SQL expression element or a raw SQL string, most
+ typically a :class:`_schema.Column` object,
+ and "operator" is a string
+ containing the operator to use. In order to specify a column name
+ when a :class:`_schema.Column` object is not available,
+ while ensuring
+ that any necessary quoting rules take effect, an ad-hoc
+ :class:`_schema.Column` or :func:`_expression.column`
+ object should be
+ used.
+
+ :param name:
+ Optional, the in-database name of this constraint.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param using:
+ Optional string. If set, emit USING <index_method> when issuing DDL
+ for this constraint. Defaults to 'gist'.
+
+ :param where:
+ Optional SQL expression construct or literal SQL string.
+ If set, emit WHERE <predicate> when issuing DDL
+ for this constraint.
+
+ :param ops:
+ Optional dictionary. Used to define operator classes for the
+ elements; works the same way as that of the
+ :ref:`postgresql_ops <postgresql_operator_classes>`
+ parameter specified to the :class:`_schema.Index` construct.
+
+ .. versionadded:: 1.3.21
+
+ .. seealso::
+
+ :ref:`postgresql_operator_classes` - general description of how
+ PostgreSQL operator classes are specified.
+
+ """
+ columns = []
+ render_exprs = []
+ self.operators = {}
+
+ expressions, operators = zip(*elements)
+
+ for (expr, column, strname, add_element), operator in zip(
+ coercions.expect_col_expression_collection(
+ roles.DDLConstraintColumnRole, expressions
+ ),
+ operators,
+ ):
+ if add_element is not None:
+ columns.append(add_element)
+
+ name = column.name if column is not None else strname
+
+ if name is not None:
+ # backwards compat
+ self.operators[name] = operator
+
+ render_exprs.append((expr, name, operator))
+
+ self._render_exprs = render_exprs
+
+ ColumnCollectionConstraint.__init__(
+ self,
+ *columns,
+ name=kw.get("name"),
+ deferrable=kw.get("deferrable"),
+ initially=kw.get("initially")
+ )
+ self.using = kw.get("using", "gist")
+ where = kw.get("where")
+ if where is not None:
+ self.where = coercions.expect(roles.StatementOptionRole, where)
+
+ self.ops = kw.get("ops", {})
+
+ def _set_parent(self, table, **kw):
+ super(ExcludeConstraint, self)._set_parent(table)
+
+ self._render_exprs = [
+ (
+ expr if isinstance(expr, elements.ClauseElement) else colexpr,
+ name,
+ operator,
+ )
+ for (expr, name, operator), colexpr in util.zip_longest(
+ self._render_exprs, self.columns
+ )
+ ]
+
+ def _copy(self, target_table=None, **kw):
+ elements = [
+ (
+ schema._copy_expression(expr, self.parent, target_table),
+ self.operators[expr.name],
+ )
+ for expr in self.columns
+ ]
+ c = self.__class__(
+ *elements,
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ where=self.where,
+ using=self.using
+ )
+ c.dispatch._update(self.dispatch)
+ return c
+
+
+def array_agg(*arg, **kw):
+ """PostgreSQL-specific form of :class:`_functions.array_agg`, ensures
+ return type is :class:`_postgresql.ARRAY` and not
+ the plain :class:`_types.ARRAY`, unless an explicit ``type_``
+ is passed.
+
+ .. versionadded:: 1.1
+
+ """
+ kw["_default_array_type"] = ARRAY
+ return functions.func.array_agg(*arg, **kw)
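A brief sketch of what this helper changes, assuming a hypothetical table ``t`` with columns ``id`` and ``x``: because the return type is the PostgreSQL-specific :class:`_postgresql.ARRAY`, PG-only array operators become available on the aggregate expression::

    from sqlalchemy import select
    from sqlalchemy.dialects.postgresql import array_agg

    agg = array_agg(t.c.x)            # type is postgresql.ARRAY, not types.ARRAY
    stmt = (
        select(t.c.id, agg)
        .group_by(t.c.id)
        .having(agg.contains([1, 2]))  # the "@>" containment operator
    )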
diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py
new file mode 100644
index 0000000..29800d2
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/hstore.py
@@ -0,0 +1,455 @@
+# postgresql/hstore.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import re
+
+from .array import ARRAY
+from ... import types as sqltypes
+from ... import util
+from ...sql import functions as sqlfunc
+from ...sql import operators
+
+
+__all__ = ("HSTORE", "hstore")
+
+idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
+
+GETITEM = operators.custom_op(
+ "->",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_KEY = operators.custom_op(
+ "?",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ALL = operators.custom_op(
+ "?&",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ANY = operators.custom_op(
+ "?|",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINS = operators.custom_op(
+ "@>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINED_BY = operators.custom_op(
+ "<@",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+
+class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
+ """Represent the PostgreSQL HSTORE type.
+
+ The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
+
+ data_table = Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', HSTORE)
+ )
+
+ with engine.connect() as conn:
+ conn.execute(
+ data_table.insert(),
+ data = {"key1": "value1", "key2": "value2"}
+ )
+
+ :class:`.HSTORE` provides for a wide range of operations, including:
+
+ * Index operations::
+
+ data_table.c.data['some key'] == 'some value'
+
+ * Containment operations::
+
+ data_table.c.data.has_key('some key')
+
+ data_table.c.data.has_all(['one', 'two', 'three'])
+
+ * Concatenation::
+
+ data_table.c.data + {"k1": "v1"}
+
+ For a full list of special methods see
+ :class:`.HSTORE.comparator_factory`.
+
+ For usage with the SQLAlchemy ORM, it may be desirable to combine
+ the usage of :class:`.HSTORE` with the :class:`.MutableDict` dictionary
+ provided by the :mod:`sqlalchemy.ext.mutable`
+ extension. This extension will allow "in-place" changes to the
+ dictionary, e.g. addition of new keys or replacement/removal of existing
+ keys to/from the current dictionary, to produce events which will be
+ detected by the unit of work::
+
+ from sqlalchemy.ext.mutable import MutableDict
+
+ class MyClass(Base):
+ __tablename__ = 'data_table'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(HSTORE))
+
+ my_object = session.query(MyClass).one()
+
+ # in-place mutation, requires Mutable extension
+ # in order for the ORM to detect
+ my_object.data['some_key'] = 'some value'
+
+ session.commit()
+
+ When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM
+ will not be alerted to any changes to the contents of an existing
+ dictionary, unless that dictionary value is re-assigned to the
+ HSTORE-attribute itself, thus generating a change event.
+
+ .. seealso::
+
+ :class:`.hstore` - render the PostgreSQL ``hstore()`` function.
+
+
+ """
+
+ __visit_name__ = "HSTORE"
+ hashable = False
+ text_type = sqltypes.Text()
+
+ def __init__(self, text_type=None):
+ """Construct a new :class:`.HSTORE`.
+
+ :param text_type: the type that should be used for indexed values.
+ Defaults to :class:`_types.Text`.
+
+ .. versionadded:: 1.1.0
+
+ """
+ if text_type is not None:
+ self.text_type = text_type
+
+ class Comparator(
+ sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
+ ):
+ """Define comparison operations for :class:`.HSTORE`."""
+
+ def has_key(self, other):
+ """Boolean expression. Test for presence of a key. Note that the
+ key may be a SQLA expression.
+ """
+ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
+
+ def has_all(self, other):
+ """Boolean expression. Test for presence of all keys in jsonb"""
+ return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
+
+ def has_any(self, other):
+ """Boolean expression. Test for presence of any key in jsonb"""
+ return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
+
+ def contains(self, other, **kwargs):
+ """Boolean expression. Test if keys (or array) are a superset
+ of/contained the keys of the argument jsonb expression.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
+
+ def contained_by(self, other):
+ """Boolean expression. Test if keys are a proper subset of the
+ keys of the argument hstore expression.
+ """
+ return self.operate(
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
+
+ def _setup_getitem(self, index):
+ return GETITEM, index, self.type.text_type
+
+ def defined(self, key):
+ """Boolean expression. Test for presence of a non-NULL value for
+ the key. Note that the key may be a SQLA expression.
+ """
+ return _HStoreDefinedFunction(self.expr, key)
+
+ def delete(self, key):
+ """HStore expression. Returns the contents of this hstore with the
+ given key deleted. Note that the key may be a SQLA expression.
+ """
+ if isinstance(key, dict):
+ key = _serialize_hstore(key)
+ return _HStoreDeleteFunction(self.expr, key)
+
+ def slice(self, array):
+ """HStore expression. Returns a subset of an hstore defined by
+ array of keys.
+ """
+ return _HStoreSliceFunction(self.expr, array)
+
+ def keys(self):
+ """Text array expression. Returns array of keys."""
+ return _HStoreKeysFunction(self.expr)
+
+ def vals(self):
+ """Text array expression. Returns array of values."""
+ return _HStoreValsFunction(self.expr)
+
+ def array(self):
+ """Text array expression. Returns array of alternating keys and
+ values.
+ """
+ return _HStoreArrayFunction(self.expr)
+
+ def matrix(self):
+ """Text array expression. Returns array of [key, value] pairs."""
+ return _HStoreMatrixFunction(self.expr)
+
+ comparator_factory = Comparator
+
+ def bind_processor(self, dialect):
+ if util.py2k:
+ encoding = dialect.encoding
+
+ def process(value):
+ if isinstance(value, dict):
+ return _serialize_hstore(value).encode(encoding)
+ else:
+ return value
+
+ else:
+
+ def process(value):
+ if isinstance(value, dict):
+ return _serialize_hstore(value)
+ else:
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if util.py2k:
+ encoding = dialect.encoding
+
+ def process(value):
+ if value is not None:
+ return _parse_hstore(value.decode(encoding))
+ else:
+ return value
+
+ else:
+
+ def process(value):
+ if value is not None:
+ return _parse_hstore(value)
+ else:
+ return value
+
+ return process
+
+
+class hstore(sqlfunc.GenericFunction):
+ """Construct an hstore value within a SQL expression using the
+ PostgreSQL ``hstore()`` function.
+
+ The :class:`.hstore` function accepts one or two arguments as described
+ in the PostgreSQL documentation.
+
+ E.g.::
+
+ from sqlalchemy.dialects.postgresql import array, hstore
+
+ select(hstore('key1', 'value1'))
+
+ select(
+ hstore(
+ array(['key1', 'key2', 'key3']),
+ array(['value1', 'value2', 'value3'])
+ )
+ )
+
+ .. seealso::
+
+ :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
+
+ """
+
+ type = HSTORE
+ name = "hstore"
+ inherit_cache = True
+
+
+class _HStoreDefinedFunction(sqlfunc.GenericFunction):
+ type = sqltypes.Boolean
+ name = "defined"
+ inherit_cache = True
+
+
+class _HStoreDeleteFunction(sqlfunc.GenericFunction):
+ type = HSTORE
+ name = "delete"
+ inherit_cache = True
+
+
+class _HStoreSliceFunction(sqlfunc.GenericFunction):
+ type = HSTORE
+ name = "slice"
+ inherit_cache = True
+
+
+class _HStoreKeysFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "akeys"
+ inherit_cache = True
+
+
+class _HStoreValsFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "avals"
+ inherit_cache = True
+
+
+class _HStoreArrayFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "hstore_to_array"
+ inherit_cache = True
+
+
+class _HStoreMatrixFunction(sqlfunc.GenericFunction):
+ type = ARRAY(sqltypes.Text)
+ name = "hstore_to_matrix"
+ inherit_cache = True
+
+
+#
+# parsing. note that none of this is used with the psycopg2 backend,
+# which provides its own native extensions.
+#
+
+# My best guess at the parsing rules of hstore literals, since no formal
+# grammar is given. This is mostly reverse engineered from PG's input parser
+# behavior.
+HSTORE_PAIR_RE = re.compile(
+ r"""
+(
+ "(?P<key> (\\ . | [^"])* )" # Quoted key
+)
+[ ]* => [ ]* # Pair operator, optional adjoining whitespace
+(
+ (?P<value_null> NULL ) # NULL value
+ | "(?P<value> (\\ . | [^"])* )" # Quoted value
+)
+""",
+ re.VERBOSE,
+)
+
+HSTORE_DELIMITER_RE = re.compile(
+ r"""
+[ ]* , [ ]*
+""",
+ re.VERBOSE,
+)
+
+
+def _parse_error(hstore_str, pos):
+ """format an unmarshalling error."""
+
+ ctx = 20
+ hslen = len(hstore_str)
+
+ parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)]
+ residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)]
+
+ if len(parsed_tail) > ctx:
+ parsed_tail = "[...]" + parsed_tail[1:]
+ if len(residual) > ctx:
+ residual = residual[:-1] + "[...]"
+
+ return "After %r, could not parse residual at position %d: %r" % (
+ parsed_tail,
+ pos,
+ residual,
+ )
+
+
+def _parse_hstore(hstore_str):
+ """Parse an hstore from its literal string representation.
+
+ Attempts to approximate PG's hstore input parsing rules as closely as
+ possible. Although currently this is not strictly necessary, since the
+ current implementation of hstore's output syntax is stricter than what it
+ accepts as input, the documentation makes no guarantees that will always
+ be the case.
+
+ """
+ result = {}
+ pos = 0
+ pair_match = HSTORE_PAIR_RE.match(hstore_str)
+
+ while pair_match is not None:
+ key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\")
+ if pair_match.group("value_null"):
+ value = None
+ else:
+ value = (
+ pair_match.group("value")
+ .replace(r"\"", '"')
+ .replace("\\\\", "\\")
+ )
+ result[key] = value
+
+ pos += pair_match.end()
+
+ delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
+ if delim_match is not None:
+ pos += delim_match.end()
+
+ pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
+
+ if pos != len(hstore_str):
+ raise ValueError(_parse_error(hstore_str, pos))
+
+ return result
+
+
+def _serialize_hstore(val):
+ """Serialize a dictionary into an hstore literal. Keys and values must
+ both be strings (except None for values).
+
+ """
+
+ def esc(s, position):
+ if position == "value" and s is None:
+ return "NULL"
+ elif isinstance(s, util.string_types):
+ return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"")
+ else:
+ raise ValueError(
+ "%r in %s position is not a string." % (s, position)
+ )
+
+ return ", ".join(
+ "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items()
+ )
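As a rough illustration of the literal format handled by the module-private helpers above (an application would not normally call these directly)::

    from sqlalchemy.dialects.postgresql.hstore import (
        _parse_hstore,
        _serialize_hstore,
    )

    literal = '"key1"=>"value1", "key2"=>NULL'
    parsed = _parse_hstore(literal)   # {'key1': 'value1', 'key2': None}
    assert _serialize_hstore(parsed) == literal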
diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py
new file mode 100644
index 0000000..daaaeac
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/json.py
@@ -0,0 +1,327 @@
+# postgresql/json.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import absolute_import
+
+from ... import types as sqltypes
+from ... import util
+from ...sql import operators
+
+
+__all__ = ("JSON", "JSONB")
+
+idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
+
+ASTEXT = operators.custom_op(
+ "->>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+JSONPATH_ASTEXT = operators.custom_op(
+ "#>>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+
+HAS_KEY = operators.custom_op(
+ "?",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ALL = operators.custom_op(
+ "?&",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+HAS_ANY = operators.custom_op(
+ "?|",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINS = operators.custom_op(
+ "@>",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+CONTAINED_BY = operators.custom_op(
+ "<@",
+ precedence=idx_precedence,
+ natural_self_precedent=True,
+ eager_grouping=True,
+)
+
+
+class JSONPathType(sqltypes.JSON.JSONPathType):
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ assert isinstance(value, util.collections_abc.Sequence)
+ tokens = [util.text_type(elem) for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ assert isinstance(value, util.collections_abc.Sequence)
+ tokens = [util.text_type(elem) for elem in value]
+ value = "{%s}" % (", ".join(tokens))
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSON(sqltypes.JSON):
+ """Represent the PostgreSQL JSON type.
+
+ :class:`_postgresql.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a PostgreSQL backend,
+ however the base :class:`_types.JSON` datatype does not provide Python
+ accessors for PostgreSQL-specific comparison methods such as
+ :meth:`_postgresql.JSON.Comparator.astext`; additionally, to use
+ PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should
+ be used explicitly.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The operators provided by the PostgreSQL version of :class:`_types.JSON`
+ include:
+
+ * Index operations (the ``->`` operator)::
+
+ data_table.c.data['some key']
+
+ data_table.c.data[5]
+
+
+ * Index operations returning text (the ``->>`` operator)::
+
+ data_table.c.data['some key'].astext == 'some value'
+
+ Note that equivalent functionality is available via the
+ :attr:`.JSON.Comparator.as_string` accessor.
+
+ * Index operations with CAST
+ (equivalent to ``CAST(col ->> 'some key' AS <type>)``)::
+
+ data_table.c.data['some key'].astext.cast(Integer) == 5
+
+ Note that equivalent functionality is available via the
+ :attr:`.JSON.Comparator.as_integer` and similar accessors.
+
+ * Path index operations (the ``#>`` operator)::
+
+ data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
+
+ * Path index operations returning text (the ``#>>`` operator)::
+
+ data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'
+
+ .. versionchanged:: 1.1 The :meth:`_expression.ColumnElement.cast`
+ operator on
+ JSON objects now requires that the :attr:`.JSON.Comparator.astext`
+ modifier be called explicitly, if the cast works only from a textual
+ string.
+
+ Index operations return an expression object whose type defaults to
+ :class:`_types.JSON`,
+ so that further JSON-oriented instructions
+ may be called upon the result type.
+
+ Custom serializers and deserializers are specified at the dialect level,
+ that is using :func:`_sa.create_engine`. The reason for this is that when
+ using psycopg2, the DBAPI only allows serializers at the per-cursor
+ or per-connection level. E.g.::
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test",
+ json_serializer=my_serialize_fn,
+ json_deserializer=my_deserialize_fn
+ )
+
+ When using the psycopg2 dialect, the json_deserializer is registered
+ against the database using ``psycopg2.extras.register_default_json``.
+
+ .. seealso::
+
+ :class:`_types.JSON` - Core level JSON type
+
+ :class:`_postgresql.JSONB`
+
+ .. versionchanged:: 1.1 :class:`_postgresql.JSON` is now a PostgreSQL-
+ specific specialization of the new :class:`_types.JSON` type.
+
+ """ # noqa
+
+ astext_type = sqltypes.Text()
+
+ def __init__(self, none_as_null=False, astext_type=None):
+ """Construct a :class:`_types.JSON` type.
+
+ :param none_as_null: if True, persist the value ``None`` as a
+ SQL NULL value, not the JSON encoding of ``null``. Note that
+ when this flag is False, the :func:`.null` construct can still
+ be used to persist a NULL value::
+
+ from sqlalchemy import null
+ conn.execute(table.insert(), data=null())
+
+ .. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null`
+ is now supported in order to persist a NULL value.
+
+ .. seealso::
+
+ :attr:`_types.JSON.NULL`
+
+ :param astext_type: the type to use for the
+ :attr:`.JSON.Comparator.astext`
+ accessor on indexed attributes. Defaults to :class:`_types.Text`.
+
+ .. versionadded:: 1.1
+
+ """
+ super(JSON, self).__init__(none_as_null=none_as_null)
+ if astext_type is not None:
+ self.astext_type = astext_type
+
+ class Comparator(sqltypes.JSON.Comparator):
+ """Define comparison operations for :class:`_types.JSON`."""
+
+ @property
+ def astext(self):
+ """On an indexed expression, use the "astext" (e.g. "->>")
+ conversion when rendered in SQL.
+
+ E.g.::
+
+ select(data_table.c.data['some key'].astext)
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.cast`
+
+ """
+ if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
+ return self.expr.left.operate(
+ JSONPATH_ASTEXT,
+ self.expr.right,
+ result_type=self.type.astext_type,
+ )
+ else:
+ return self.expr.left.operate(
+ ASTEXT, self.expr.right, result_type=self.type.astext_type
+ )
+
+ comparator_factory = Comparator
+
+
+class JSONB(JSON):
+ """Represent the PostgreSQL JSONB type.
+
+ The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
+ e.g.::
+
+ data_table = Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', JSONB)
+ )
+
+ with engine.connect() as conn:
+ conn.execute(
+ data_table.insert(),
+ data = {"key1": "value1", "key2": "value2"}
+ )
+
+ The :class:`_postgresql.JSONB` type includes all operations provided by
+ :class:`_types.JSON`, including the same behaviors for indexing
+ operations.
+ It also adds additional operators specific to JSONB, including
+ :meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`,
+ :meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`,
+ and :meth:`.JSONB.Comparator.contained_by`.
+
+ Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB`
+ type does not detect
+ in-place changes when used with the ORM, unless the
+ :mod:`sqlalchemy.ext.mutable` extension is used.
+
+ Custom serializers and deserializers
+ are shared with the :class:`_types.JSON` class,
+ using the ``json_serializer``
+ and ``json_deserializer`` keyword arguments. These must be specified
+ at the dialect level using :func:`_sa.create_engine`. When using
+ psycopg2, the serializers are associated with the jsonb type using
+ ``psycopg2.extras.register_default_jsonb`` on a per-connection basis,
+ in the same way that ``psycopg2.extras.register_default_json`` is used
+ to register these handlers with the json type.
+
+ .. versionadded:: 0.9.7
+
+ .. seealso::
+
+ :class:`_types.JSON`
+
+ """
+
+ __visit_name__ = "JSONB"
+
+ class Comparator(JSON.Comparator):
+ """Define comparison operations for :class:`_types.JSON`."""
+
+ def has_key(self, other):
+ """Boolean expression. Test for presence of a key. Note that the
+ key may be a SQLA expression.
+ """
+ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
+
+ def has_all(self, other):
+ """Boolean expression. Test for presence of all keys in jsonb"""
+ return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
+
+ def has_any(self, other):
+ """Boolean expression. Test for presence of any key in jsonb"""
+ return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
+
+ def contains(self, other, **kwargs):
+ """Boolean expression. Test if keys (or array) are a superset
+ of/contained the keys of the argument jsonb expression.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
+
+ def contained_by(self, other):
+ """Boolean expression. Test if keys are a proper subset of the
+ keys of the argument jsonb expression.
+ """
+ return self.operate(
+ CONTAINED_BY, other, result_type=sqltypes.Boolean
+ )
+
+ comparator_factory = Comparator
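A short query sketch, assuming a hypothetical ``data_table`` with a ``JSONB`` column named ``data``, showing how the operators defined above are reached from column expressions::

    from sqlalchemy import Integer, select

    # "->>" via .astext, combined with a CAST
    stmt = select(data_table).where(
        data_table.c.data["counter"].astext.cast(Integer) > 5
    )

    # "@>" containment via the JSONB comparator
    stmt = select(data_table).where(
        data_table.c.data.contains({"key1": "value1"})
    )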
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
new file mode 100644
index 0000000..98561a9
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -0,0 +1,594 @@
+# postgresql/pg8000.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+pg8000
+ :name: pg8000
+ :dbapi: pg8000
+ :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://pypi.org/project/pg8000/
+
+.. versionchanged:: 1.4 The pg8000 dialect has been updated for version
+ 1.16.6 and higher, and is again part of SQLAlchemy's continuous integration
+ with full feature support.
+
+.. _pg8000_unicode:
+
+Unicode
+-------
+
+pg8000 will encode / decode string values between it and the server using the
+PostgreSQL ``client_encoding`` parameter; by default this is the value in
+the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
+Typically, this can be changed to ``utf-8``, as a more useful default::
+
+ #client_encoding = sql_ascii # actually, defaults to database
+ # encoding
+ client_encoding = utf8
+
+The ``client_encoding`` can be overridden for a session by executing the SQL::
+
+    SET CLIENT_ENCODING TO 'utf8';
+
+SQLAlchemy will execute this SQL on all new connections based on the value
+passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
+
+ engine = create_engine(
+ "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
+
+.. _pg8000_ssl:
+
+SSL Connections
+---------------
+
+pg8000 accepts a Python ``SSLContext`` object which may be specified using the
+:paramref:`_sa.create_engine.connect_args` dictionary::
+
+ import ssl
+ ssl_context = ssl.create_default_context()
+ engine = sa.create_engine(
+ "postgresql+pg8000://scott:tiger@192.168.0.199/test",
+ connect_args={"ssl_context": ssl_context},
+ )
+
+If the server uses an automatically-generated certificate that is self-signed
+or does not match the host name (as seen from the client), it may also be
+necessary to disable hostname checking::
+
+ import ssl
+ ssl_context = ssl.create_default_context()
+ ssl_context.check_hostname = False
+ ssl_context.verify_mode = ssl.CERT_NONE
+ engine = sa.create_engine(
+ "postgresql+pg8000://scott:tiger@192.168.0.199/test",
+ connect_args={"ssl_context": ssl_context},
+ )
+
+.. _pg8000_isolation_level:
+
+pg8000 Transaction Isolation Level
+-------------------------------------
+
+The pg8000 dialect offers the same isolation level settings as those
+of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
+.. seealso::
+
+ :ref:`postgresql_isolation_level`
+
+ :ref:`psycopg2_isolation_level`
+
+
+""" # noqa
+import decimal
+import re
+from uuid import UUID as _python_UUID
+
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import ENUM
+from .base import INTERVAL
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .json import JSON
+from .json import JSONB
+from .json import JSONPathType
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+from ...sql.elements import quoted_name
+
+
+class _PGNumeric(sqltypes.Numeric):
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # pg8000 returns Decimal natively for 1700
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # pg8000 returns float natively for 701
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class _PGNumericNoBind(_PGNumeric):
+ def bind_processor(self, dialect):
+ return None
+
+
+class _PGJSON(JSON):
+ def result_processor(self, dialect, coltype):
+ return None
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSON
+
+
+class _PGJSONB(JSONB):
+ def result_processor(self, dialect, coltype):
+ return None
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.JSONB
+
+
+class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
+ def get_dbapi_type(self, dbapi):
+ raise NotImplementedError("should not be here")
+
+
+class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+
+class _PGJSONPathType(JSONPathType):
+ def get_dbapi_type(self, dbapi):
+ return 1009
+
+
+class _PGUUID(UUID):
+ def bind_processor(self, dialect):
+ if not self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ value = str(value)
+ return value
+
+ return process
+
+
+class _PGEnum(ENUM):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.UNKNOWN
+
+
+class _PGInterval(INTERVAL):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTERVAL
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+ return _PGInterval(precision=interval.second_precision)
+
+
+class _PGTimeStamp(sqltypes.DateTime):
+ def get_dbapi_type(self, dbapi):
+ if self.timezone:
+ # TIMESTAMPTZOID
+ return 1184
+ else:
+ # TIMESTAMPOID
+ return 1114
+
+
+class _PGTime(sqltypes.Time):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.TIME
+
+
+class _PGInteger(sqltypes.Integer):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class _PGSmallInteger(sqltypes.SmallInteger):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.INTEGER
+
+
+class _PGNullType(sqltypes.NullType):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NULLTYPE
+
+
+class _PGBigInteger(sqltypes.BigInteger):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BIGINTEGER
+
+
+class _PGBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BOOLEAN
+
+
+class _PGARRAY(PGARRAY):
+ def bind_expression(self, bindvalue):
+ return _ColonCast(bindvalue, self)
+
+
+_server_side_id = util.counter()
+
+
+class PGExecutionContext_pg8000(PGExecutionContext):
+ def create_server_side_cursor(self):
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
+ return ServerSideCursor(self._dbapi_connection.cursor(), ident)
+
+ def pre_exec(self):
+ if not self.compiled:
+ return
+
+
+class ServerSideCursor:
+ server_side = True
+
+ def __init__(self, cursor, ident):
+ self.ident = ident
+ self.cursor = cursor
+
+ @property
+ def connection(self):
+ return self.cursor.connection
+
+ @property
+ def rowcount(self):
+ return self.cursor.rowcount
+
+ @property
+ def description(self):
+ return self.cursor.description
+
+ def execute(self, operation, args=(), stream=None):
+ op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation
+ self.cursor.execute(op, args, stream=stream)
+ return self
+
+ def executemany(self, operation, param_sets):
+ self.cursor.executemany(operation, param_sets)
+ return self
+
+ def fetchone(self):
+ self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident)
+ return self.cursor.fetchone()
+
+ def fetchmany(self, num=None):
+ if num is None:
+ return self.fetchall()
+ else:
+ self.cursor.execute(
+ "FETCH FORWARD " + str(int(num)) + " FROM " + self.ident
+ )
+ return self.cursor.fetchall()
+
+ def fetchall(self):
+ self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident)
+ return self.cursor.fetchall()
+
+ def close(self):
+ self.cursor.execute("CLOSE " + self.ident)
+ self.cursor.close()
+
+ def setinputsizes(self, *sizes):
+ self.cursor.setinputsizes(*sizes)
+
+ def setoutputsize(self, size, column=None):
+ pass
+
+
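A sketch of how the DECLARE / FETCH emulation above is reached from user code (engine URL and table name are assumed)::

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+pg8000://scott:tiger@localhost/test")

    with engine.connect() as conn:
        result = conn.execution_options(stream_results=True).execute(
            text("SELECT * FROM big_table")
        )
        # rows arrive in batches via "FETCH FORWARD <n> FROM <cursor name>"
        for partition in result.partitions(1000):
            for row in partition:
                ...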
+class PGCompiler_pg8000(PGCompiler):
+ def visit_mod_binary(self, binary, operator, **kw):
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+
+
+class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
+ def __init__(self, *args, **kwargs):
+ PGIdentifierPreparer.__init__(self, *args, **kwargs)
+ self._double_percents = False
+
+
+class PGDialect_pg8000(PGDialect):
+ driver = "pg8000"
+ supports_statement_cache = True
+
+ supports_unicode_statements = True
+
+ supports_unicode_binds = True
+
+ default_paramstyle = "format"
+ supports_sane_multi_rowcount = True
+ execution_ctx_cls = PGExecutionContext_pg8000
+ statement_compiler = PGCompiler_pg8000
+ preparer = PGIdentifierPreparer_pg8000
+ supports_server_side_cursors = True
+
+ use_setinputsizes = True
+
+ # reversed as of pg8000 1.16.6. 1.16.5 and lower
+ # are no longer compatible
+ description_encoding = None
+ # description_encoding = "use_encoding"
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric: _PGNumericNoBind,
+ sqltypes.Float: _PGNumeric,
+ sqltypes.JSON: _PGJSON,
+ sqltypes.Boolean: _PGBoolean,
+ sqltypes.NullType: _PGNullType,
+ JSONB: _PGJSONB,
+ sqltypes.JSON.JSONPathType: _PGJSONPathType,
+ sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
+ sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
+ sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
+ UUID: _PGUUID,
+ sqltypes.Interval: _PGInterval,
+ INTERVAL: _PGInterval,
+ sqltypes.DateTime: _PGTimeStamp,
+ sqltypes.Time: _PGTime,
+ sqltypes.Integer: _PGInteger,
+ sqltypes.SmallInteger: _PGSmallInteger,
+ sqltypes.BigInteger: _PGBigInteger,
+ sqltypes.Enum: _PGEnum,
+ sqltypes.ARRAY: _PGARRAY,
+ },
+ )
+
+ def __init__(self, client_encoding=None, **kwargs):
+ PGDialect.__init__(self, **kwargs)
+ self.client_encoding = client_encoding
+
+ if self._dbapi_version < (1, 16, 6):
+ raise NotImplementedError("pg8000 1.16.6 or greater is required")
+
+ @util.memoized_property
+ def _dbapi_version(self):
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ return tuple(
+ [
+ int(x)
+ for x in re.findall(
+ r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
+ )
+ ]
+ )
+ else:
+ return (99, 99, 99)
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("pg8000")
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.InterfaceError) and "network error" in str(
+ e
+ ):
+ # new as of pg8000 1.19.0 for broken connections
+ return True
+
+ # connection was closed normally
+ return "connection is closed" in str(e)
+
+ def set_isolation_level(self, connection, level):
+ level = level.replace("_", " ")
+
+ # adjust for ConnectionFairy possibly being present
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ if level == "AUTOCOMMIT":
+ connection.autocommit = True
+ elif level in self._isolation_lookup:
+ connection.autocommit = False
+ cursor = connection.cursor()
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION "
+ "ISOLATION LEVEL %s" % level
+ )
+ cursor.execute("COMMIT")
+ cursor.close()
+ else:
+ raise exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s or AUTOCOMMIT"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ )
+
+ def set_readonly(self, connection, value):
+ cursor = connection.cursor()
+ try:
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
+ % ("READ ONLY" if value else "READ WRITE")
+ )
+ cursor.execute("COMMIT")
+ finally:
+ cursor.close()
+
+ def get_readonly(self, connection):
+ cursor = connection.cursor()
+ try:
+ cursor.execute("show transaction_read_only")
+ val = cursor.fetchone()[0]
+ finally:
+ cursor.close()
+
+ return val == "on"
+
+ def set_deferrable(self, connection, value):
+ cursor = connection.cursor()
+ try:
+ cursor.execute(
+ "SET SESSION CHARACTERISTICS AS TRANSACTION %s"
+ % ("DEFERRABLE" if value else "NOT DEFERRABLE")
+ )
+ cursor.execute("COMMIT")
+ finally:
+ cursor.close()
+
+ def get_deferrable(self, connection):
+ cursor = connection.cursor()
+ try:
+ cursor.execute("show transaction_deferrable")
+ val = cursor.fetchone()[0]
+ finally:
+ cursor.close()
+
+ return val == "on"
+
+ def set_client_encoding(self, connection, client_encoding):
+ # adjust for ConnectionFairy possibly being present
+ if hasattr(connection, "dbapi_connection"):
+ connection = connection.dbapi_connection
+
+ cursor = connection.cursor()
+ cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'")
+ cursor.execute("COMMIT")
+ cursor.close()
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
+ def do_begin_twophase(self, connection, xid):
+ connection.connection.tpc_begin((0, xid, ""))
+
+ def do_prepare_twophase(self, connection, xid):
+ connection.connection.tpc_prepare()
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ connection.connection.tpc_rollback((0, xid, ""))
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ connection.connection.tpc_commit((0, xid, ""))
+
+ def do_recover_twophase(self, connection):
+ return [row[1] for row in connection.connection.tpc_recover()]
+
+ def on_connect(self):
+ fns = []
+
+ def on_connect(conn):
+ conn.py_types[quoted_name] = conn.py_types[util.text_type]
+
+ fns.append(on_connect)
+
+ if self.client_encoding is not None:
+
+ def on_connect(conn):
+ self.set_client_encoding(conn, self.client_encoding)
+
+ fns.append(on_connect)
+
+ if self.isolation_level is not None:
+
+ def on_connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ fns.append(on_connect)
+
+ if self._json_deserializer:
+
+ def on_connect(conn):
+ # json
+ conn.register_in_adapter(114, self._json_deserializer)
+
+ # jsonb
+ conn.register_in_adapter(3802, self._json_deserializer)
+
+ fns.append(on_connect)
+
+ if len(fns) > 0:
+
+ def on_connect(conn):
+ for fn in fns:
+ fn(conn)
+
+ return on_connect
+ else:
+ return None
+
+
+dialect = PGDialect_pg8000
diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py
new file mode 100644
index 0000000..98470f3
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/provision.py
@@ -0,0 +1,124 @@
+import time
+
+from ... import exc
+from ... import inspect
+from ... import text
+from ...testing import warn_test_suite
+from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_post_tables
+from ...testing.provision import drop_all_schema_objects_pre_tables
+from ...testing.provision import drop_db
+from ...testing.provision import log
+from ...testing.provision import prepare_for_drop_tables
+from ...testing.provision import set_default_schema_on_connection
+from ...testing.provision import temp_table_keyword_args
+
+
+@create_db.for_db("postgresql")
+def _pg_create_db(cfg, eng, ident):
+ template_db = cfg.options.postgresql_templatedb
+
+ with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
+
+ if not template_db:
+ template_db = conn.exec_driver_sql(
+ "select current_database()"
+ ).scalar()
+
+ attempt = 0
+ while True:
+ try:
+ conn.exec_driver_sql(
+ "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
+ )
+ except exc.OperationalError as err:
+ attempt += 1
+ if attempt >= 3:
+ raise
+ if "accessed by other users" in str(err):
+ log.info(
+ "Waiting to create %s, URI %r, "
+ "template DB %s is in use sleeping for .5",
+ ident,
+ eng.url,
+ template_db,
+ )
+ time.sleep(0.5)
+ except:
+ raise
+ else:
+ break
+
+
+@drop_db.for_db("postgresql")
+def _pg_drop_db(cfg, eng, ident):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ with conn.begin():
+ conn.execute(
+ text(
+ "select pg_terminate_backend(pid) from pg_stat_activity "
+ "where usename=current_user and pid != pg_backend_pid() "
+ "and datname=:dname"
+ ),
+ dict(dname=ident),
+ )
+ conn.exec_driver_sql("DROP DATABASE %s" % ident)
+
+
+@temp_table_keyword_args.for_db("postgresql")
+def _postgresql_temp_table_keyword_args(cfg, eng):
+ return {"prefixes": ["TEMPORARY"]}
+
+
+@set_default_schema_on_connection.for_db("postgresql")
+def _postgresql_set_default_schema_on_connection(
+ cfg, dbapi_connection, schema_name
+):
+ existing_autocommit = dbapi_connection.autocommit
+ dbapi_connection.autocommit = True
+ cursor = dbapi_connection.cursor()
+ cursor.execute("SET SESSION search_path='%s'" % schema_name)
+ cursor.close()
+ dbapi_connection.autocommit = existing_autocommit
+
+
+@drop_all_schema_objects_pre_tables.for_db("postgresql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ for xid in conn.execute("select gid from pg_prepared_xacts").scalars():
+ conn.execute("ROLLBACK PREPARED '%s'" % xid)
+
+
+@drop_all_schema_objects_post_tables.for_db("postgresql")
+def drop_all_schema_objects_post_tables(cfg, eng):
+ from sqlalchemy.dialects import postgresql
+
+ inspector = inspect(eng)
+ with eng.begin() as conn:
+ for enum in inspector.get_enums("*"):
+ conn.execute(
+ postgresql.DropEnumType(
+ postgresql.ENUM(name=enum["name"], schema=enum["schema"])
+ )
+ )
+
+
+@prepare_for_drop_tables.for_db("postgresql")
+def prepare_for_drop_tables(config, connection):
+ """Ensure there are no locks on the current username/database."""
+
+ result = connection.exec_driver_sql(
+ "select pid, state, wait_event_type, query "
+ # "select pg_terminate_backend(pid), state, wait_event_type "
+ "from pg_stat_activity where "
+ "usename=current_user "
+ "and datname=current_database() and state='idle in transaction' "
+ "and pid != pg_backend_pid()"
+ )
+ rows = result.all() # noqa
+ if rows:
+ warn_test_suite(
+ "PostgreSQL may not be able to DROP tables due to "
+ "idle in transaction: %s"
+ % ("; ".join(row._mapping["query"] for row in rows))
+ )
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
new file mode 100644
index 0000000..6747427
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -0,0 +1,1088 @@
+# postgresql/psycopg2.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+psycopg2
+ :name: psycopg2
+ :dbapi: psycopg2
+ :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://pypi.org/project/psycopg2/
+
+psycopg2 Connect Arguments
+--------------------------
+
+Keyword arguments that are specific to the SQLAlchemy psycopg2 dialect
+may be passed to :func:`_sa.create_engine()`, and include the following:
+
+
+* ``isolation_level``: This option, available for all PostgreSQL dialects,
+ includes the ``AUTOCOMMIT`` isolation level when using the psycopg2
+ dialect. This option sets the **default** isolation level for the
+ connection that is set immediately upon connection to the database before
+ the connection is pooled. This option is generally superseded by the more
+ modern :paramref:`_engine.Connection.execution_options.isolation_level`
+ execution option, detailed at :ref:`dbapi_autocommit`.
+
+ .. seealso::
+
+ :ref:`psycopg2_isolation_level`
+
+ :ref:`dbapi_autocommit`
+
+
+* ``client_encoding``: sets the client encoding in a libpq-agnostic way,
+ using psycopg2's ``set_client_encoding()`` method.
+
+ .. seealso::
+
+ :ref:`psycopg2_unicode`
+
+* ``use_native_unicode``: Under Python 2 only, this can be set to False to
+ disable the use of psycopg2's native Unicode support.
+
+ .. seealso::
+
+ :ref:`psycopg2_disable_native_unicode`
+
+
+* ``executemany_mode``, ``executemany_batch_page_size``,
+ ``executemany_values_page_size``: Allows use of psycopg2
+ extensions for optimizing "executemany"-style queries. See the referenced
+ section below for details.
+
+ .. seealso::
+
+ :ref:`psycopg2_executemany_mode`
+
+.. tip::
+
+ The above keyword arguments are **dialect** keyword arguments, meaning
+ that they are passed as explicit keyword arguments to :func:`_sa.create_engine()`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://scott:tiger@localhost/test",
+ isolation_level="SERIALIZABLE",
+ )
+
+ These should not be confused with **DBAPI** connect arguments, which
+ are passed as part of the :paramref:`_sa.create_engine.connect_args`
+ dictionary and/or are passed in the URL query string, as detailed in
+ the section :ref:`custom_dbapi_args`.
+
+.. _psycopg2_ssl:
+
+SSL Connections
+---------------
+
+The psycopg2 module has a connection argument named ``sslmode`` for
+controlling its behavior regarding secure (SSL) connections. The default is
+``sslmode=prefer``; it will attempt an SSL connection and if that fails it
+will fall back to an unencrypted connection. ``sslmode=require`` may be used
+to ensure that only secure connections are established. Consult the
+psycopg2 / libpq documentation for further options that are available.
+
+Note that ``sslmode`` is specific to psycopg2 so it is included in the
+connection URI::
+
+ engine = sa.create_engine(
+ "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require"
+ )
+
+Unix Domain Connections
+------------------------
+
+psycopg2 supports connecting via Unix domain connections. When the ``host``
+portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
+which specifies Unix-domain communication rather than TCP/IP communication::
+
+ create_engine("postgresql+psycopg2://user:password@/dbname")
+
+By default, the connection is made to a Unix-domain socket
+in ``/tmp``, or to whatever socket directory was specified when PostgreSQL
+was built. This value can be overridden by passing a pathname to psycopg2,
+using ``host`` as an additional keyword argument::
+
+ create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
+
+.. seealso::
+
+ `PQconnectdbParams \
+ <https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_
+
+.. _psycopg2_multi_host:
+
+Specifying multiple fallback hosts
+-----------------------------------
+
+psycopg2 supports multiple connection points in the connection string.
+When the ``host`` parameter is used multiple times in the query section of
+the URL, SQLAlchemy will create a single string of the host and port
+information provided to make the connections. Tokens may consist of
+``host:port`` or just ``host``; in the latter case, the default port
+is selected by libpq. In the example below, three host connections
+are specified, for ``HostA:PortA``, ``HostB`` connecting to the default port,
+and ``HostC:PortC``::
+
+ create_engine(
+ "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC"
+ )
+
+As an alternative, libpq query string format also may be used; this specifies
+``host`` and ``port`` as single query string arguments with comma-separated
+lists - the default port can be chosen by indicating an empty value
+in the comma separated list::
+
+ create_engine(
+ "postgresql+psycopg2://user:password@/dbname?host=HostA,HostB,HostC&port=PortA,,PortC"
+ )
+
+With either URL style, connections to each host are attempted based on a
+configurable strategy, which may be selected using the libpq
+``target_session_attrs`` parameter. Per libpq this defaults to ``any``,
+which indicates that each host is tried in turn until a connection is successful.
+Other strategies include ``primary``, ``prefer-standby``, etc. The complete
+list is documented by PostgreSQL at
+`libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_.
+
+For example, to indicate two hosts using the ``primary`` strategy::
+
+ create_engine(
+ "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC&target_session_attrs=primary"
+ )
+
+.. versionchanged:: 1.4.40 Port specification in psycopg2 multiple host format
+ is repaired, previously ports were not correctly interpreted in this context.
+ libpq comma-separated format is also now supported.
+
+.. versionadded:: 1.3.20 Support for multiple hosts in PostgreSQL connection
+ string.
+
+.. seealso::
+
+ `libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_ - please refer
+ to this section in the libpq documentation for complete background on multiple host support.
+
+
+Empty DSN Connections / Environment Variable Connections
+---------------------------------------------------------
+
+The psycopg2 DBAPI can connect to PostgreSQL by passing an empty DSN to the
+libpq client library, which by default indicates to connect to a localhost
+PostgreSQL database that is open for "trust" connections. This behavior can be
+further tailored using a particular set of environment variables which are
+prefixed with ``PG`` (such as ``PGHOST``, ``PGUSER`` and ``PGPASSWORD``),
+which are consumed by ``libpq`` to take the place of
+any or all elements of the connection string.
+
+For this form, the URL can be passed without any elements other than the
+initial scheme::
+
+ engine = create_engine('postgresql+psycopg2://')
+
+In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()``
+function which in turn represents an empty DSN passed to libpq.
+
+.. versionadded:: 1.3.2 support for parameter-less connections with psycopg2.
+
+.. seealso::
+
+ `Environment Variables\
+ <https://www.postgresql.org/docs/current/libpq-envars.html>`_ -
+ PostgreSQL documentation on how to use ``PG``-prefixed
+ environment variables for connections.
+
+.. _psycopg2_execution_options:
+
+Per-Statement/Connection Execution Options
+-------------------------------------------
+
+The following DBAPI-specific options are respected when used with
+:meth:`_engine.Connection.execution_options`,
+:meth:`.Executable.execution_options`,
+:meth:`_query.Query.execution_options`,
+in addition to those not specific to DBAPIs:
+
+* ``isolation_level`` - Set the transaction isolation level for the lifespan
+ of a :class:`_engine.Connection` (can only be set on a connection,
+ not a statement
+ or query). See :ref:`psycopg2_isolation_level`.
+
+* ``stream_results`` - Enable or disable usage of psycopg2 server side
+ cursors - this feature makes use of "named" cursors in combination with
+ special result handling methods so that result rows are not fully buffered.
+ Defaults to False, meaning cursors are buffered.
+
+* ``max_row_buffer`` - when using ``stream_results``, an integer value that
+ specifies the maximum number of rows to buffer at a time. This is
+ interpreted by the :class:`.BufferedRowCursorResult`, and if omitted the
+ buffer will grow to ultimately store 1000 rows at a time.
+
+ .. versionchanged:: 1.4 The ``max_row_buffer`` size can now be greater than
+ 1000, and the buffer will grow to that size.
+
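For example, a hedged sketch of the ``isolation_level`` execution option listed above (connection URL assumed)::

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    with engine.connect() as conn:
        # applies for the lifespan of this Connection only
        autocommit_conn = conn.execution_options(isolation_level="AUTOCOMMIT")
        autocommit_conn.execute(text("VACUUM"))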
+.. _psycopg2_batch_mode:
+
+.. _psycopg2_executemany_mode:
+
+Psycopg2 Fast Execution Helpers
+-------------------------------
+
+Modern versions of psycopg2 include a feature known as
+`Fast Execution Helpers \
+<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which
+have been shown in benchmarking to improve psycopg2's executemany()
+performance, primarily with INSERT statements, by multiple orders of magnitude.
+SQLAlchemy internally makes use of these extensions for ``executemany()`` style
+calls, which correspond to lists of parameters being passed to
+:meth:`_engine.Connection.execute` as detailed in :ref:`multiple parameter
+sets <tutorial_multiple_parameters>`. The ORM also uses this mode internally whenever
+possible.
+
+The two available extensions on the psycopg2 side are the ``execute_values()``
+and ``execute_batch()`` functions. The psycopg2 dialect defaults to using the
+``execute_values()`` extension for all qualifying INSERT statements.
+
+.. versionchanged:: 1.4 The psycopg2 dialect now defaults to a new mode
+ ``"values_only"`` for ``executemany_mode``, which allows an order of
+ magnitude performance improvement for INSERT statements, but does not
+ include "batch" mode for UPDATE and DELETE statements, since that mode
+ prevents ``cursor.rowcount`` from functioning correctly.
+
+The use of these extensions is controlled by the ``executemany_mode`` flag
+which may be passed to :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://scott:tiger@host/dbname",
+ executemany_mode='values_plus_batch')
+
+
+Possible options for ``executemany_mode`` include:
+
+* ``values_only`` - this is the default value. The psycopg2 execute_values()
+ extension is used for qualifying INSERT statements, which rewrites the INSERT
+ to include multiple VALUES clauses so that many parameter sets can be
+ inserted with one statement.
+
+ .. versionadded:: 1.4 Added ``"values_only"`` setting for ``executemany_mode``
+ which is also now the default.
+
+* ``None`` - No psycopg2 extensions are used, and the usual
+ ``cursor.executemany()`` method is used when invoking statements with
+ multiple parameter sets.
+
+* ``'batch'`` - Uses ``psycopg2.extras.execute_batch`` for all qualifying
+ INSERT, UPDATE and DELETE statements, so that multiple copies
+ of a SQL query, each one corresponding to a parameter set passed to
+ ``executemany()``, are joined into a single SQL string separated by a
+ semicolon. When using this mode, the :attr:`_engine.CursorResult.rowcount`
+ attribute will not contain a value for executemany-style executions.
+
+* ``'values_plus_batch'`` - ``execute_values`` is used for qualifying INSERT
+ statements, ``execute_batch`` is used for UPDATE and DELETE.
+ When using this mode, the :attr:`_engine.CursorResult.rowcount`
+ attribute will not contain a value for executemany-style executions against
+ UPDATE and DELETE statements.
+
+By "qualifying statements", we mean that the statement being executed
+must be a Core :func:`_expression.insert`, :func:`_expression.update`
+or :func:`_expression.delete` construct, and not a plain textual SQL
+string or one constructed using :func:`_expression.text`. When using the
+ORM, all insert/update/delete statements used by the ORM flush process
+are qualifying.
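+
+As an illustrative sketch (table and column names are hypothetical, and an
+``engine`` is assumed to be present), a Core :func:`_expression.insert`
+invoked with a list of parameter dictionaries is a qualifying statement,
+whereas the equivalent :func:`_expression.text` construct is not::
+
+    from sqlalchemy import column, insert, table, text
+
+    account = table("account", column("name"))
+
+    with engine.begin() as conn:
+        # qualifying: handled by execute_values() under "values_only"
+        conn.execute(
+            insert(account),
+            [{"name": "spongebob"}, {"name": "sandy"}, {"name": "patrick"}],
+        )
+
+        # not qualifying: falls back to cursor.executemany()
+        conn.execute(
+            text("INSERT INTO account (name) VALUES (:name)"),
+            [{"name": "squidward"}],
+        )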
+
+The "page size" for the "values" and "batch" strategies can be affected
+by using the ``executemany_batch_page_size`` and
+``executemany_values_page_size`` engine parameters. These
+control how many parameter sets
+should be represented in each execution. The "values" page size defaults
+to 1000, which is different that psycopg2's default. The "batch" page
+size defaults to 100. These can be affected by passing new values to
+:func:`_engine.create_engine`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://scott:tiger@host/dbname",
+ executemany_mode='values_plus_batch',
+ executemany_values_page_size=10000, executemany_batch_page_size=500)
+
+.. versionchanged:: 1.4
+
+ The default for ``executemany_values_page_size`` is now 1000, up from
+ 100.
+
+.. seealso::
+
+ :ref:`tutorial_multiple_parameters` - General information on using the
+ :class:`_engine.Connection`
+ object to execute statements in such a way as to make
+ use of the DBAPI ``.executemany()`` method.
+
+
+.. _psycopg2_unicode:
+
+Unicode with Psycopg2
+----------------------
+
+The psycopg2 DBAPI driver supports Unicode data transparently. Under Python 2
+only, the SQLAlchemy psycopg2 dialect will enable the
+``psycopg2.extensions.UNICODE`` extension by default to ensure Unicode is
+handled properly; under Python 3, this is psycopg2's default behavior.
+
+The client character encoding can be controlled for the psycopg2 dialect
+in the following ways:
+
+* For PostgreSQL 9.1 and above, the ``client_encoding`` parameter may be
+ passed in the database URL; this parameter is consumed by the underlying
+ ``libpq`` PostgreSQL client library::
+
+ engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8")
+
+ Alternatively, the above ``client_encoding`` value may be passed using
+ :paramref:`_sa.create_engine.connect_args` for programmatic establishment with
+ ``libpq``::
+
+ engine = create_engine(
+ "postgresql+psycopg2://user:pass@host/dbname",
+ connect_args={'client_encoding': 'utf8'}
+ )
+
+* For all PostgreSQL versions, psycopg2 supports a client-side encoding
+ value that will be passed to database connections when they are first
+ established. The SQLAlchemy psycopg2 dialect supports this using the
+ ``client_encoding`` parameter passed to :func:`_sa.create_engine`::
+
+ engine = create_engine(
+ "postgresql+psycopg2://user:pass@host/dbname",
+ client_encoding="utf8"
+ )
+
+ .. tip:: The above ``client_encoding`` parameter admittedly is very similar
+ in appearance to usage of the parameter within the
+ :paramref:`_sa.create_engine.connect_args` dictionary; the difference
+ above is that the parameter is consumed by psycopg2 and is
+ passed to the database connection using ``SET client_encoding TO
+ 'utf8'``; in the previously mentioned style, the parameter is instead
+ passed through psycopg2 and consumed by the ``libpq`` library.
+
+* A common way to set up client encoding with PostgreSQL databases is to
+ ensure it is configured within the server-side postgresql.conf file;
+ this is the recommended way to set encoding for a server that is
+ consistently of one encoding in all databases::
+
+ # postgresql.conf file
+
+ # client_encoding = sql_ascii # actually, defaults to database
+ # encoding
+ client_encoding = utf8
+
+.. _psycopg2_disable_native_unicode:
+
+Disabling Native Unicode
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Under Python 2 only, SQLAlchemy can also be instructed to skip the usage of the
+psycopg2 ``UNICODE`` extension and to instead utilize its own unicode
+encode/decode services, which are normally reserved only for those DBAPIs that
+don't fully support unicode directly. Passing ``use_native_unicode=False`` to
+:func:`_sa.create_engine` will disable usage of ``psycopg2.extensions.
+UNICODE``. SQLAlchemy will instead encode data itself into Python bytestrings
+on the way in and coerce from bytes on the way back, using the value of the
+:func:`_sa.create_engine` ``encoding`` parameter, which defaults to ``utf-8``.
+SQLAlchemy's own unicode encode/decode functionality is steadily becoming
+obsolete as most DBAPIs now support unicode fully.
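+
+A minimal sketch of this Python 2 only configuration::
+
+    engine = create_engine(
+        "postgresql+psycopg2://scott:tiger@host/dbname",
+        use_native_unicode=False,
+        encoding="utf-8",
+    )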
+
+
+Transactions
+------------
+
+The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
+
+.. _psycopg2_isolation_level:
+
+Psycopg2 Transaction Isolation Level
+-------------------------------------
+
+As discussed in :ref:`postgresql_isolation_level`,
+all PostgreSQL dialects support setting of transaction isolation level
+both via the ``isolation_level`` parameter passed to
+:func:`_sa.create_engine`, as well as the ``isolation_level`` argument used
+by :meth:`_engine.Connection.execution_options`. When using the psycopg2
+dialect, these options make use of psycopg2's ``set_isolation_level()``
+connection method,
+rather than emitting a PostgreSQL directive; this is because psycopg2's
+API-level setting is always emitted at the start of each transaction in any
+case.
+
+The psycopg2 dialect supports these constants for isolation level:
+
+* ``READ COMMITTED``
+* ``READ UNCOMMITTED``
+* ``REPEATABLE READ``
+* ``SERIALIZABLE``
+* ``AUTOCOMMIT``
+
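+For example, one of the above constants may be applied to a connection
+for its lifespan (credentials and database name are illustrative)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname")
+
+    with engine.connect() as conn:
+        conn = conn.execution_options(isolation_level="SERIALIZABLE")
+        # subsequent work on this connection runs under SERIALIZABLE
+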
+.. seealso::
+
+ :ref:`postgresql_isolation_level`
+
+ :ref:`pg8000_isolation_level`
+
+
+NOTICE logging
+---------------
+
+The psycopg2 dialect will log PostgreSQL NOTICE messages
+via the ``sqlalchemy.dialects.postgresql`` logger. When this logger
+is set to the ``logging.INFO`` level, notice messages will be logged::
+
+ import logging
+
+ logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
+
+Above, it is assumed that logging is configured externally. If this is not
+the case, configuration such as ``logging.basicConfig()`` must be utilized::
+
+ import logging
+
+ logging.basicConfig() # log messages to stdout
+ logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
+
+.. seealso::
+
+ `Logging HOWTO <https://docs.python.org/3/howto/logging.html>`_ - on the python.org website
+
+.. _psycopg2_hstore:
+
+HSTORE type
+------------
+
+The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of
+the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension
+by default when psycopg2 version 2.4 or greater is used, and
+it is detected that the target database has the HSTORE type set up for use.
+In other words, when the dialect makes the first
+connection, a sequence like the following is performed:
+
+1. Request the available HSTORE oids using
+ ``psycopg2.extras.HstoreAdapter.get_oids()``.
+ If this function returns a list of HSTORE identifiers, we then determine
+ that the ``HSTORE`` extension is present.
+ This function is **skipped** if the version of psycopg2 installed is
+ less than version 2.4.
+
+2. If the ``use_native_hstore`` flag is at its default of ``True``, and
+ we've detected that ``HSTORE`` oids are available, the
+ ``psycopg2.extensions.register_hstore()`` extension is invoked for all
+ connections.
+
+The ``register_hstore()`` extension has the effect of **all Python
+dictionaries being accepted as parameters regardless of the type of target
+column in SQL**. The dictionaries are converted by this extension into a
+textual HSTORE expression. If this behavior is not desired, disable the
+use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
+follows::
+
+ engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
+ use_native_hstore=False)
+
+The ``HSTORE`` type is **still supported** when the
+``psycopg2.extensions.register_hstore()`` extension is not used. It merely
+means that the coercion between Python dictionaries and the HSTORE
+string format, on both the parameter side and the result side, will take
+place within SQLAlchemy's own marshalling logic rather than that of
+``psycopg2``, which may be more performant.
+
+""" # noqa
+from __future__ import absolute_import
+
+import decimal
+import logging
+import re
+from uuid import UUID as _python_UUID
+
+from .array import ARRAY as PGARRAY
+from .base import _ColonCast
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import ENUM
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGExecutionContext
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .hstore import HSTORE
+from .json import JSON
+from .json import JSONB
+from ... import exc
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+from ...engine import cursor as _cursor
+from ...util import collections_abc
+
+
+logger = logging.getLogger("sqlalchemy.dialects.postgresql")
+
+
+class _PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # psycopg2 returns Decimal natively for 1700 (numeric)
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # psycopg2 returns float natively for 701 (float8)
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class _PGEnum(ENUM):
+ def result_processor(self, dialect, coltype):
+ if util.py2k and self._expect_unicode is True:
+ # for py2k, if the enum type needs unicode data (which is set up as
+ # part of the Enum() constructor based on values passed as py2k
+ # unicode objects) we have to use our own converters since
+ # psycopg2's don't work, a rare exception to the "modern DBAPIs
+ # support unicode everywhere" theme of deprecating
+ # convert_unicode=True. Use the special "force_nocheck" directive
+ # which forces unicode conversion to happen on the Python side
+ # without an isinstance() check. in py3k psycopg2 does the right
+ # thing automatically.
+ self._expect_unicode = "force_nocheck"
+ return super(_PGEnum, self).result_processor(dialect, coltype)
+
+
+class _PGHStore(HSTORE):
+ def bind_processor(self, dialect):
+ if dialect._has_native_hstore:
+ return None
+ else:
+ return super(_PGHStore, self).bind_processor(dialect)
+
+ def result_processor(self, dialect, coltype):
+ if dialect._has_native_hstore:
+ return None
+ else:
+ return super(_PGHStore, self).result_processor(dialect, coltype)
+
+
+class _PGARRAY(PGARRAY):
+ def bind_expression(self, bindvalue):
+ return _ColonCast(bindvalue, self)
+
+
+class _PGJSON(JSON):
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class _PGJSONB(JSONB):
+ def result_processor(self, dialect, coltype):
+ return None
+
+
+class _PGUUID(UUID):
+ def bind_processor(self, dialect):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = _python_UUID(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not self.as_uuid and dialect.use_native_uuid:
+
+ def process(value):
+ if value is not None:
+ value = str(value)
+ return value
+
+ return process
+
+
+_server_side_id = util.counter()
+
+
+class PGExecutionContext_psycopg2(PGExecutionContext):
+ _psycopg2_fetched_rows = None
+
+ def create_server_side_cursor(self):
+ # use server-side cursors:
+ # https://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+ ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
+ return self._dbapi_connection.cursor(ident)
+
+ def post_exec(self):
+ if (
+ self._psycopg2_fetched_rows
+ and self.compiled
+ and self.compiled.returning
+ ):
+ # psycopg2 execute_values will provide for a real cursor where
+ # cursor.description works correctly. however, it executes the
+ # INSERT statement multiple times for multiple pages of rows, so
+ # while this cursor also supports calling .fetchall() directly, in
+ # order to get the list of all rows inserted across multiple pages,
+ # we have to retrieve the aggregated list from the execute_values()
+ # function directly.
+ strat_cls = _cursor.FullyBufferedCursorFetchStrategy
+ self.cursor_fetch_strategy = strat_cls(
+ self.cursor, initial_buffer=self._psycopg2_fetched_rows
+ )
+ self._log_notices(self.cursor)
+
+ def _log_notices(self, cursor):
+ # check also that notices is an iterable, after it's already
+ # established that we will be iterating through it. This is to get
+ # around test suites such as SQLAlchemy's using a Mock object for
+ # cursor
+ if not cursor.connection.notices or not isinstance(
+ cursor.connection.notices, collections_abc.Iterable
+ ):
+ return
+
+ for notice in cursor.connection.notices:
+ # NOTICE messages have a
+ # newline character at the end
+ logger.info(notice.rstrip())
+
+ cursor.connection.notices[:] = []
+
+
+class PGCompiler_psycopg2(PGCompiler):
+ pass
+
+
+class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
+ pass
+
+
+EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0)
+EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1)
+EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2)
+EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol(
+ "executemany_values_plus_batch",
+ canonical=EXECUTEMANY_BATCH | EXECUTEMANY_VALUES,
+)
+
+
+class PGDialect_psycopg2(PGDialect):
+ driver = "psycopg2"
+
+ supports_statement_cache = True
+
+ if util.py2k:
+ # turn off supports_unicode_statements for Python 2. psycopg2 supports
+ # unicode statements in Py2K. But! it does not support unicode *bound
+ # parameter names* because it uses the Python "%" operator to
+ # interpolate these into the string, and this fails. So for Py2K, we
+ # have to use full-on encoding for statements and parameters before
+ # passing to cursor.execute().
+ supports_unicode_statements = False
+
+ supports_server_side_cursors = True
+
+ default_paramstyle = "pyformat"
+ # set to true based on psycopg2 version
+ supports_sane_multi_rowcount = False
+ execution_ctx_cls = PGExecutionContext_psycopg2
+ statement_compiler = PGCompiler_psycopg2
+ preparer = PGIdentifierPreparer_psycopg2
+ psycopg2_version = (0, 0)
+
+ _has_native_hstore = True
+
+ engine_config_types = PGDialect.engine_config_types.union(
+ {"use_native_unicode": util.asbool}
+ )
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric: _PGNumeric,
+ ENUM: _PGEnum, # needs force_unicode
+ sqltypes.Enum: _PGEnum, # needs force_unicode
+ HSTORE: _PGHStore,
+ JSON: _PGJSON,
+ sqltypes.JSON: _PGJSON,
+ JSONB: _PGJSONB,
+ UUID: _PGUUID,
+ sqltypes.ARRAY: _PGARRAY,
+ },
+ )
+
+ def __init__(
+ self,
+ use_native_unicode=True,
+ client_encoding=None,
+ use_native_hstore=True,
+ use_native_uuid=True,
+ executemany_mode="values_only",
+ executemany_batch_page_size=100,
+ executemany_values_page_size=1000,
+ **kwargs
+ ):
+ PGDialect.__init__(self, **kwargs)
+ self.use_native_unicode = use_native_unicode
+ if not use_native_unicode and not util.py2k:
+ raise exc.ArgumentError(
+ "psycopg2 native_unicode mode is required under Python 3"
+ )
+ if not use_native_hstore:
+ self._has_native_hstore = False
+ self.use_native_hstore = use_native_hstore
+ self.use_native_uuid = use_native_uuid
+ self.supports_unicode_binds = use_native_unicode
+ self.client_encoding = client_encoding
+
+ # Parse executemany_mode argument, allowing it to be only one of the
+ # symbol names
+ self.executemany_mode = util.symbol.parse_user_argument(
+ executemany_mode,
+ {
+ EXECUTEMANY_PLAIN: [None],
+ EXECUTEMANY_BATCH: ["batch"],
+ EXECUTEMANY_VALUES: ["values_only"],
+ EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch", "values"],
+ },
+ "executemany_mode",
+ )
+
+ if self.executemany_mode & EXECUTEMANY_VALUES:
+ self.insert_executemany_returning = True
+
+ self.executemany_batch_page_size = executemany_batch_page_size
+ self.executemany_values_page_size = executemany_values_page_size
+
+ if self.dbapi and hasattr(self.dbapi, "__version__"):
+ m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
+ if m:
+ self.psycopg2_version = tuple(
+ int(x) for x in m.group(1, 2, 3) if x is not None
+ )
+
+ if self.psycopg2_version < (2, 7):
+ raise ImportError(
+ "psycopg2 version 2.7 or higher is required."
+ )
+
+ def initialize(self, connection):
+ super(PGDialect_psycopg2, self).initialize(connection)
+ self._has_native_hstore = (
+ self.use_native_hstore
+ and self._hstore_oids(connection.connection) is not None
+ )
+
+ # PGDialect.initialize() checks server version for <= 8.2 and sets
+ # this flag to False if so
+ if not self.full_returning:
+ self.insert_executemany_returning = False
+ self.executemany_mode = EXECUTEMANY_PLAIN
+
+ self.supports_sane_multi_rowcount = not (
+ self.executemany_mode & EXECUTEMANY_BATCH
+ )
+
+ @classmethod
+ def dbapi(cls):
+ import psycopg2
+
+ return psycopg2
+
+ @classmethod
+ def _psycopg2_extensions(cls):
+ from psycopg2 import extensions
+
+ return extensions
+
+ @classmethod
+ def _psycopg2_extras(cls):
+ from psycopg2 import extras
+
+ return extras
+
+ @util.memoized_property
+ def _isolation_lookup(self):
+ extensions = self._psycopg2_extensions()
+ return {
+ "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT,
+ "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED,
+ "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
+ "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ,
+ "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE,
+ }
+
+ def set_isolation_level(self, connection, level):
+ try:
+ level = self._isolation_lookup[level.replace("_", " ")]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (level, self.name, ", ".join(self._isolation_lookup))
+ ),
+ replace_context=err,
+ )
+
+ connection.set_isolation_level(level)
+
+ def set_readonly(self, connection, value):
+ connection.readonly = value
+
+ def get_readonly(self, connection):
+ return connection.readonly
+
+ def set_deferrable(self, connection, value):
+ connection.deferrable = value
+
+ def get_deferrable(self, connection):
+ return connection.deferrable
+
+ def do_ping(self, dbapi_connection):
+ cursor = None
+ before_autocommit = dbapi_connection.autocommit
+ try:
+ if not before_autocommit:
+ dbapi_connection.autocommit = True
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute(self._dialect_specific_select_one)
+ finally:
+ cursor.close()
+ if not before_autocommit and not dbapi_connection.closed:
+ dbapi_connection.autocommit = before_autocommit
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, cursor):
+ return False
+ else:
+ raise
+ else:
+ return True
+
+ def on_connect(self):
+ extras = self._psycopg2_extras()
+ extensions = self._psycopg2_extensions()
+
+ fns = []
+ if self.client_encoding is not None:
+
+ def on_connect(conn):
+ conn.set_client_encoding(self.client_encoding)
+
+ fns.append(on_connect)
+
+ if self.isolation_level is not None:
+
+ def on_connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ fns.append(on_connect)
+
+ if self.dbapi and self.use_native_uuid:
+
+ def on_connect(conn):
+ extras.register_uuid(None, conn)
+
+ fns.append(on_connect)
+
+ if util.py2k and self.dbapi and self.use_native_unicode:
+
+ def on_connect(conn):
+ extensions.register_type(extensions.UNICODE, conn)
+ extensions.register_type(extensions.UNICODEARRAY, conn)
+
+ fns.append(on_connect)
+
+ if self.dbapi and self.use_native_hstore:
+
+ def on_connect(conn):
+ hstore_oids = self._hstore_oids(conn)
+ if hstore_oids is not None:
+ oid, array_oid = hstore_oids
+ kw = {"oid": oid}
+ if util.py2k:
+ kw["unicode"] = True
+ kw["array_oid"] = array_oid
+ extras.register_hstore(conn, **kw)
+
+ fns.append(on_connect)
+
+ if self.dbapi and self._json_deserializer:
+
+ def on_connect(conn):
+ extras.register_default_json(
+ conn, loads=self._json_deserializer
+ )
+ extras.register_default_jsonb(
+ conn, loads=self._json_deserializer
+ )
+
+ fns.append(on_connect)
+
+ if fns:
+
+ def on_connect(conn):
+ for fn in fns:
+ fn(conn)
+
+ return on_connect
+ else:
+ return None
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ if (
+ self.executemany_mode & EXECUTEMANY_VALUES
+ and context
+ and context.isinsert
+ and context.compiled._is_safe_for_fast_insert_values_helper
+ ):
+ executemany_values = (
+ "(%s)" % context.compiled.insert_single_values_expr
+ )
+ if not self.supports_unicode_statements:
+ executemany_values = executemany_values.encode(self.encoding)
+
+ # guard for statement that was altered via event hook or similar
+ if executemany_values not in statement:
+ executemany_values = None
+ else:
+ executemany_values = None
+
+ if executemany_values:
+ statement = statement.replace(executemany_values, "%s")
+ if self.executemany_values_page_size:
+ kwargs = {"page_size": self.executemany_values_page_size}
+ else:
+ kwargs = {}
+ xtras = self._psycopg2_extras()
+ context._psycopg2_fetched_rows = xtras.execute_values(
+ cursor,
+ statement,
+ parameters,
+ template=executemany_values,
+ fetch=bool(context.compiled.returning),
+ **kwargs
+ )
+
+ elif self.executemany_mode & EXECUTEMANY_BATCH:
+ if self.executemany_batch_page_size:
+ kwargs = {"page_size": self.executemany_batch_page_size}
+ else:
+ kwargs = {}
+ self._psycopg2_extras().execute_batch(
+ cursor, statement, parameters, **kwargs
+ )
+ else:
+ cursor.executemany(statement, parameters)
+
+ @util.memoized_instancemethod
+ def _hstore_oids(self, conn):
+ extras = self._psycopg2_extras()
+ if hasattr(conn, "dbapi_connection"):
+ conn = conn.dbapi_connection
+ oids = extras.HstoreAdapter.get_oids(conn)
+ if oids is not None and oids[0]:
+ return oids[0:2]
+ else:
+ return None
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+
+ is_multihost = False
+ if "host" in url.query:
+ is_multihost = isinstance(url.query["host"], (list, tuple))
+
+ if opts or url.query:
+ if not opts:
+ opts = {}
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
+ opts.update(url.query)
+ if is_multihost:
+ hosts, ports = zip(
+ *[
+ token.split(":") if ":" in token else (token, "")
+ for token in url.query["host"]
+ ]
+ )
+ opts["host"] = ",".join(hosts)
+ if "port" in opts:
+ raise exc.ArgumentError(
+ "Can't mix 'multihost' formats together; use "
+ '"host=h1,h2,h3&port=p1,p2,p3" or '
+ '"host=h1:p1&host=h2:p2&host=h3:p3" separately'
+ )
+ opts["port"] = ",".join(ports)
+ return ([], opts)
+ else:
+ # no connection arguments whatsoever; psycopg2.connect()
+ # requires that "dsn" be present as a blank string.
+ return ([""], opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ # check the "closed" flag. this might not be
+ # present on old psycopg2 versions. Also,
+ # this flag doesn't actually help in a lot of disconnect
+ # situations, so don't rely on it.
+ if getattr(connection, "closed", False):
+ return True
+
+ # checks based on strings. in the case that .closed
+ # didn't cut it, fall back onto these.
+ str_e = str(e).partition("\n")[0]
+ for msg in [
+ # these error messages from libpq: interfaces/libpq/fe-misc.c
+ # and interfaces/libpq/fe-secure.c.
+ "terminating connection",
+ "closed the connection",
+ "connection not open",
+ "could not receive data from server",
+ "could not send data to server",
+ # psycopg2 client errors, psycopg2/connection.h,
+ # psycopg2/cursor.h
+ "connection already closed",
+ "cursor already closed",
+ # not sure where this path is originally from, it may
+ # be obsolete. It really says "losed", not "closed".
+ "losed the connection unexpectedly",
+ # these can occur in newer SSL
+ "connection has been closed unexpectedly",
+ "SSL error: decryption failed or bad record mac",
+ "SSL SYSCALL error: Bad file descriptor",
+ "SSL SYSCALL error: EOF detected",
+ "SSL SYSCALL error: Operation timed out",
+ "SSL SYSCALL error: Bad address",
+ ]:
+ idx = str_e.find(msg)
+ if idx >= 0 and '"' not in str_e[:idx]:
+ return True
+ return False
+
+
+dialect = PGDialect_psycopg2
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
new file mode 100644
index 0000000..10d1aae
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py
@@ -0,0 +1,60 @@
+# postgresql/psycopg2cffi.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: postgresql+psycopg2cffi
+ :name: psycopg2cffi
+ :dbapi: psycopg2cffi
+ :connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://pypi.org/project/psycopg2cffi/
+
+``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C
+layer. This makes it suitable for use in e.g. PyPy. Documentation
+is as per ``psycopg2``.
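+
+For example, an engine is created in the same way as with ``psycopg2``
+(the credentials below are placeholders)::
+
+    from sqlalchemy import create_engine
+
+    engine = create_engine("postgresql+psycopg2cffi://scott:tiger@localhost/test")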
+
+.. versionadded:: 1.0.0
+
+.. seealso::
+
+ :mod:`sqlalchemy.dialects.postgresql.psycopg2`
+
+""" # noqa
+from .psycopg2 import PGDialect_psycopg2
+
+
+class PGDialect_psycopg2cffi(PGDialect_psycopg2):
+ driver = "psycopg2cffi"
+ supports_unicode_statements = True
+ supports_statement_cache = True
+
+ # psycopg2cffi's first release is 2.5.0, but reports
+ # __version__ as 2.4.4. Subsequent releases seem to have
+ # fixed this.
+
+ FEATURE_VERSION_MAP = dict(
+ native_json=(2, 4, 4),
+ native_jsonb=(2, 7, 1),
+ sane_multi_rowcount=(2, 4, 4),
+ array_oid=(2, 4, 4),
+ hstore_adapter=(2, 4, 4),
+ )
+
+ @classmethod
+ def dbapi(cls):
+ return __import__("psycopg2cffi")
+
+ @classmethod
+ def _psycopg2_extensions(cls):
+ root = __import__("psycopg2cffi", fromlist=["extensions"])
+ return root.extensions
+
+ @classmethod
+ def _psycopg2_extras(cls):
+ root = __import__("psycopg2cffi", fromlist=["extras"])
+ return root.extras
+
+
+dialect = PGDialect_psycopg2cffi
diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py
new file mode 100644
index 0000000..d273b8c
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py
@@ -0,0 +1,278 @@
+# postgresql/pygresql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+.. dialect:: postgresql+pygresql
+ :name: pygresql
+ :dbapi: pgdb
+ :connectstring: postgresql+pygresql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://www.pygresql.org/
+
+.. note::
+
+ The pygresql dialect is **not tested as part of SQLAlchemy's continuous
+ integration** and may have unresolved issues. The recommended PostgreSQL
+ dialect is psycopg2.
+
+.. deprecated:: 1.4 The pygresql DBAPI is deprecated and will be removed
+ in a future version. Please use one of the supported DBAPIs to
+ connect to PostgreSQL.
+
+""" # noqa
+
+import decimal
+import re
+
+from .base import _DECIMAL_TYPES
+from .base import _FLOAT_TYPES
+from .base import _INT_TYPES
+from .base import PGCompiler
+from .base import PGDialect
+from .base import PGIdentifierPreparer
+from .base import UUID
+from .hstore import HSTORE
+from .json import JSON
+from .json import JSONB
+from ... import exc
+from ... import processors
+from ... import util
+from ...sql.elements import Null
+from ...types import JSON as Json
+from ...types import Numeric
+
+
+class _PGNumeric(Numeric):
+ def bind_processor(self, dialect):
+ return None
+
+ def result_processor(self, dialect, coltype):
+ if not isinstance(coltype, int):
+ coltype = coltype.oid
+ if self.asdecimal:
+ if coltype in _FLOAT_TYPES:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ # PyGreSQL returns Decimal natively for 1700 (numeric)
+ return None
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+ else:
+ if coltype in _FLOAT_TYPES:
+ # PyGreSQL returns float natively for 701 (float8)
+ return None
+ elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
+ return processors.to_float
+ else:
+ raise exc.InvalidRequestError(
+ "Unknown PG numeric type: %d" % coltype
+ )
+
+
+class _PGHStore(HSTORE):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_hstore:
+ return super(_PGHStore, self).bind_processor(dialect)
+ hstore = dialect.dbapi.Hstore
+
+ def process(value):
+ if isinstance(value, dict):
+ return hstore(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_hstore:
+ return super(_PGHStore, self).result_processor(dialect, coltype)
+
+
+class _PGJSON(JSON):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_json:
+ return super(_PGJSON, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_json:
+ return super(_PGJSON, self).result_processor(dialect, coltype)
+
+
+class _PGJSONB(JSONB):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_json:
+ return super(_PGJSONB, self).bind_processor(dialect)
+ json = dialect.dbapi.Json
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
+ if value is None or isinstance(value, (dict, list)):
+ return json(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_json:
+ return super(_PGJSONB, self).result_processor(dialect, coltype)
+
+
+class _PGUUID(UUID):
+ def bind_processor(self, dialect):
+ if not dialect.has_native_uuid:
+ return super(_PGUUID, self).bind_processor(dialect)
+ uuid = dialect.dbapi.Uuid
+
+ def process(value):
+ if value is None:
+ return None
+ if isinstance(value, (str, bytes)):
+ if len(value) == 16:
+ return uuid(bytes=value)
+ return uuid(value)
+ if isinstance(value, int):
+ return uuid(int=value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if not dialect.has_native_uuid:
+ return super(_PGUUID, self).result_processor(dialect, coltype)
+ if not self.as_uuid:
+
+ def process(value):
+ if value is not None:
+ return str(value)
+
+ return process
+
+
+class _PGCompiler(PGCompiler):
+ def visit_mod_binary(self, binary, operator, **kw):
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+
+ def post_process_text(self, text):
+ return text.replace("%", "%%")
+
+
+class _PGIdentifierPreparer(PGIdentifierPreparer):
+ def _escape_identifier(self, value):
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ return value.replace("%", "%%")
+
+
+class PGDialect_pygresql(PGDialect):
+
+ driver = "pygresql"
+ supports_statement_cache = True
+
+ statement_compiler = _PGCompiler
+ preparer = _PGIdentifierPreparer
+
+ @classmethod
+ def dbapi(cls):
+ import pgdb
+
+ util.warn_deprecated(
+ "The pygresql DBAPI is deprecated and will be removed "
+ "in a future version. Please use one of the supported DBAPIs to "
+ "connect to PostgreSQL.",
+ version="1.4",
+ )
+
+ return pgdb
+
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ Numeric: _PGNumeric,
+ HSTORE: _PGHStore,
+ Json: _PGJSON,
+ JSON: _PGJSON,
+ JSONB: _PGJSONB,
+ UUID: _PGUUID,
+ },
+ )
+
+ def __init__(self, **kwargs):
+ super(PGDialect_pygresql, self).__init__(**kwargs)
+ try:
+ version = self.dbapi.version
+ m = re.match(r"(\d+)\.(\d+)", version)
+ version = (int(m.group(1)), int(m.group(2)))
+ except (AttributeError, ValueError, TypeError):
+ version = (0, 0)
+ self.dbapi_version = version
+ if version < (5, 0):
+ has_native_hstore = has_native_json = has_native_uuid = False
+ if version != (0, 0):
+ util.warn(
+ "PyGreSQL is only fully supported by SQLAlchemy"
+ " since version 5.0."
+ )
+ else:
+ self.supports_unicode_statements = True
+ self.supports_unicode_binds = True
+ has_native_hstore = has_native_json = has_native_uuid = True
+ self.has_native_hstore = has_native_hstore
+ self.has_native_json = has_native_json
+ self.has_native_uuid = has_native_uuid
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["host"] = "%s:%s" % (
+ opts.get("host", "").rsplit(":", 1)[0],
+ opts.pop("port"),
+ )
+ opts.update(url.query)
+ return [], opts
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(e, self.dbapi.Error):
+ if not connection:
+ return False
+ try:
+ connection = connection.connection
+ except AttributeError:
+ pass
+ else:
+ if not connection:
+ return False
+ try:
+ return connection.closed
+ except AttributeError: # PyGreSQL < 5.0
+ return connection._cnx is None
+ return False
+
+
+dialect = PGDialect_pygresql
diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
new file mode 100644
index 0000000..886e368
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py
@@ -0,0 +1,126 @@
+# postgresql/pypostgresql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+.. dialect:: postgresql+pypostgresql
+ :name: py-postgresql
+ :dbapi: pypostgresql
+ :connectstring: postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://python.projects.pgfoundry.org/
+
+.. note::
+
+ The pypostgresql dialect is **not tested as part of SQLAlchemy's continuous
+ integration** and may have unresolved issues. The recommended PostgreSQL
+ driver is psycopg2.
+
+.. deprecated:: 1.4 The py-postgresql DBAPI is deprecated and will be removed
+ in a future version. This DBAPI is superseded by the external
+ version available at external-dialect_. Please use the external version or
+ one of the supported DBAPIs to connect to PostgreSQL.
+
+.. TODO update link
+.. _external-dialect: https://github.com/PyGreSQL
+
+""" # noqa
+
+from .base import PGDialect
+from .base import PGExecutionContext
+from ... import processors
+from ... import types as sqltypes
+from ... import util
+
+
+class PGNumeric(sqltypes.Numeric):
+ def bind_processor(self, dialect):
+ return processors.to_str
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ return None
+ else:
+ return processors.to_float
+
+
+class PGExecutionContext_pypostgresql(PGExecutionContext):
+ pass
+
+
+class PGDialect_pypostgresql(PGDialect):
+ driver = "pypostgresql"
+
+ supports_statement_cache = True
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+ description_encoding = None
+ default_paramstyle = "pyformat"
+
+ # requires trunk version to support sane rowcounts
+ # TODO: use dbapi version information to set this flag appropriately
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = False
+
+ execution_ctx_cls = PGExecutionContext_pypostgresql
+ colspecs = util.update_copy(
+ PGDialect.colspecs,
+ {
+ sqltypes.Numeric: PGNumeric,
+ # prevents PGNumeric from being used
+ sqltypes.Float: sqltypes.Float,
+ },
+ )
+
+ @classmethod
+ def dbapi(cls):
+ from postgresql.driver import dbapi20
+
+ # TODO update link
+ util.warn_deprecated(
+ "The py-postgresql DBAPI is deprecated and will be removed "
+ "in a future version. This DBAPI is superseded by the external"
+ "version available at https://github.com/PyGreSQL. Please "
+ "use one of the supported DBAPIs to connect to PostgreSQL.",
+ version="1.4",
+ )
+
+ return dbapi20
+
+ _DBAPI_ERROR_NAMES = [
+ "Error",
+ "InterfaceError",
+ "DatabaseError",
+ "DataError",
+ "OperationalError",
+ "IntegrityError",
+ "InternalError",
+ "ProgrammingError",
+ "NotSupportedError",
+ ]
+
+ @util.memoized_property
+ def dbapi_exception_translation_map(self):
+ if self.dbapi is None:
+ return {}
+
+ return dict(
+ (getattr(self.dbapi, name).__name__, name)
+ for name in self._DBAPI_ERROR_NAMES
+ )
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user")
+ if "port" in opts:
+ opts["port"] = int(opts["port"])
+ else:
+ opts["port"] = 5432
+ opts.update(url.query)
+ return ([], opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ return "connection is closed" in str(e)
+
+
+dialect = PGDialect_pypostgresql
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
new file mode 100644
index 0000000..51f3b04
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/ranges.py
@@ -0,0 +1,138 @@
+# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import types as sqltypes
+
+
+__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE")
+
+
+class RangeOperators(object):
+ """
+ This mixin provides functionality for the Range Operators
+ listed in the Range Operators table of the `PostgreSQL documentation`__
+ for Range Functions and Operators. It is used by all the range types
+ provided in the ``postgresql`` dialect and can likely be used for
+ any range types you create yourself.
+
+ __ https://www.postgresql.org/docs/current/static/functions-range.html
+
+ No extra support is provided for the Range Functions listed in the Range
+ Functions table of the PostgreSQL documentation. For these, the normal
+ :func:`~sqlalchemy.sql.expression.func` object should be used.
+
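+    E.g., a brief sketch against a hypothetical table with an ``INT4RANGE``
+    column (all names are illustrative)::
+
+        from sqlalchemy import Column, Integer, MetaData, Table, select
+        from sqlalchemy.dialects.postgresql import INT4RANGE
+
+        metadata = MetaData()
+        bookings = Table(
+            "bookings", metadata,
+            Column("id", Integer, primary_key=True),
+            Column("room_numbers", INT4RANGE),
+        )
+
+        # rows whose range of room numbers contains room number 5
+        stmt = select(bookings).where(bookings.c.room_numbers.contains(5))
+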
+ """
+
+ class comparator_factory(sqltypes.Concatenable.Comparator):
+ """Define comparison operations for range types."""
+
+ def __ne__(self, other):
+ "Boolean expression. Returns true if two ranges are not equal"
+ if other is None:
+ return super(RangeOperators.comparator_factory, self).__ne__(
+ other
+ )
+ else:
+ return self.expr.op("<>", is_comparison=True)(other)
+
+ def contains(self, other, **kw):
+ """Boolean expression. Returns true if the right hand operand,
+ which can be an element or a range, is contained within the
+ column.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ return self.expr.op("@>", is_comparison=True)(other)
+
+ def contained_by(self, other):
+ """Boolean expression. Returns true if the column is contained
+ within the right hand operand.
+ """
+ return self.expr.op("<@", is_comparison=True)(other)
+
+ def overlaps(self, other):
+ """Boolean expression. Returns true if the column overlaps
+ (has points in common with) the right hand operand.
+ """
+ return self.expr.op("&&", is_comparison=True)(other)
+
+ def strictly_left_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ left of the right hand operand.
+ """
+ return self.expr.op("<<", is_comparison=True)(other)
+
+ __lshift__ = strictly_left_of
+
+ def strictly_right_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ right of the right hand operand.
+ """
+ return self.expr.op(">>", is_comparison=True)(other)
+
+ __rshift__ = strictly_right_of
+
+ def not_extend_right_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend right of the range in the operand.
+ """
+ return self.expr.op("&<", is_comparison=True)(other)
+
+ def not_extend_left_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend left of the range in the operand.
+ """
+ return self.expr.op("&>", is_comparison=True)(other)
+
+ def adjacent_to(self, other):
+ """Boolean expression. Returns true if the range in the column
+ is adjacent to the range in the operand.
+ """
+ return self.expr.op("-|-", is_comparison=True)(other)
+
+ def __add__(self, other):
+ """Range expression. Returns the union of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ return self.expr.op("+")(other)
+
+
+class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL INT4RANGE type."""
+
+ __visit_name__ = "INT4RANGE"
+
+
+class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL INT8RANGE type."""
+
+ __visit_name__ = "INT8RANGE"
+
+
+class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL NUMRANGE type."""
+
+ __visit_name__ = "NUMRANGE"
+
+
+class DATERANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL DATERANGE type."""
+
+ __visit_name__ = "DATERANGE"
+
+
+class TSRANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL TSRANGE type."""
+
+ __visit_name__ = "TSRANGE"
+
+
+class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
+ """Represent the PostgreSQL TSTZRANGE type."""
+
+ __visit_name__ = "TSTZRANGE"
diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py
new file mode 100644
index 0000000..8d8d933
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/__init__.py
@@ -0,0 +1,58 @@
+# sqlite/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import pysqlcipher # noqa
+from . import pysqlite # noqa
+from .base import BLOB
+from .base import BOOLEAN
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import DECIMAL
+from .base import FLOAT
+from .base import INTEGER
+from .base import JSON
+from .base import NUMERIC
+from .base import REAL
+from .base import SMALLINT
+from .base import TEXT
+from .base import TIME
+from .base import TIMESTAMP
+from .base import VARCHAR
+from .dml import Insert
+from .dml import insert
+from ...util import compat
+
+if compat.py3k:
+ from . import aiosqlite # noqa
+
+# default dialect
+base.dialect = dialect = pysqlite.dialect
+
+
+__all__ = (
+ "BLOB",
+ "BOOLEAN",
+ "CHAR",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "FLOAT",
+ "INTEGER",
+ "JSON",
+ "NUMERIC",
+ "SMALLINT",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "VARCHAR",
+ "REAL",
+ "Insert",
+ "insert",
+ "dialect",
+)
diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
new file mode 100644
index 0000000..9fc6d35
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py
@@ -0,0 +1,335 @@
+# sqlite/aiosqlite.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+
+.. dialect:: sqlite+aiosqlite
+ :name: aiosqlite
+ :dbapi: aiosqlite
+ :connectstring: sqlite+aiosqlite:///file_path
+ :url: https://pypi.org/project/aiosqlite/
+
+The aiosqlite dialect provides support for the SQLAlchemy asyncio interface
+running on top of pysqlite.
+
+aiosqlite is a wrapper around pysqlite that uses a background thread for
+each connection. It does not actually use non-blocking IO, as SQLite
+databases are not socket-based. However it does provide a working asyncio
+interface that's useful for testing and prototyping purposes.
+
+Using a special asyncio mediation layer, the aiosqlite dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("sqlite+aiosqlite:///filename")
+
+The URL passes through all arguments to the ``pysqlite`` driver, so all
+connection arguments are the same as they are for that of :ref:`pysqlite`.
+
+
+""" # noqa
+
+from .base import SQLiteExecutionContext
+from .pysqlite import SQLiteDialect_pysqlite
+from ... import pool
+from ... import util
+from ...engine import AdaptedConnection
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiosqlite_cursor:
+ __slots__ = (
+ "_adapt_connection",
+ "_connection",
+ "description",
+ "await_",
+ "_rows",
+ "arraysize",
+ "rowcount",
+ "lastrowid",
+ )
+
+ server_side = False
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+ self.arraysize = 1
+ self.rowcount = -1
+ self.description = None
+ self._rows = []
+
+ def close(self):
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ try:
+ _cursor = self.await_(self._connection.cursor())
+
+ if parameters is None:
+ self.await_(_cursor.execute(operation))
+ else:
+ self.await_(_cursor.execute(operation, parameters))
+
+ if _cursor.description:
+ self.description = _cursor.description
+ self.lastrowid = self.rowcount = -1
+
+ if not self.server_side:
+ self._rows = self.await_(_cursor.fetchall())
+ else:
+ self.description = None
+ self.lastrowid = _cursor.lastrowid
+ self.rowcount = _cursor.rowcount
+
+ if not self.server_side:
+ self.await_(_cursor.close())
+ else:
+ self._cursor = _cursor
+ except Exception as error:
+ self._adapt_connection._handle_exception(error)
+
+ def executemany(self, operation, seq_of_parameters):
+ try:
+ _cursor = self.await_(self._connection.cursor())
+ self.await_(_cursor.executemany(operation, seq_of_parameters))
+ self.description = None
+ self.lastrowid = _cursor.lastrowid
+ self.rowcount = _cursor.rowcount
+ self.await_(_cursor.close())
+ except Exception as error:
+ self._adapt_connection._handle_exception(error)
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
+ __slots__ = "_cursor"
+
+ server_side = True
+
+ def __init__(self, *arg, **kw):
+ super().__init__(*arg, **kw)
+ self._cursor = None
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
+ await_ = staticmethod(await_only)
+ __slots__ = ("dbapi", "_connection")
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+
+ @property
+ def isolation_level(self):
+ return self._connection.isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value):
+ try:
+ self._connection.isolation_level = value
+ except Exception as error:
+ self._handle_exception(error)
+
+ def create_function(self, *args, **kw):
+ try:
+ self.await_(self._connection.create_function(*args, **kw))
+ except Exception as error:
+ self._handle_exception(error)
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_aiosqlite_ss_cursor(self)
+ else:
+ return AsyncAdapt_aiosqlite_cursor(self)
+
+ def execute(self, *args, **kw):
+ return self.await_(self._connection.execute(*args, **kw))
+
+ def rollback(self):
+ try:
+ self.await_(self._connection.rollback())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def commit(self):
+ try:
+ self.await_(self._connection.commit())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def close(self):
+ try:
+ self.await_(self._connection.close())
+ except Exception as error:
+ self._handle_exception(error)
+
+ def _handle_exception(self, error):
+ if (
+ isinstance(error, ValueError)
+ and error.args[0] == "no active connection"
+ ):
+ util.raise_(
+ self.dbapi.sqlite.OperationalError("no active connection"),
+ from_=error,
+ )
+ else:
+ raise error
+
+
+class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiosqlite_dbapi:
+ def __init__(self, aiosqlite, sqlite):
+ self.aiosqlite = aiosqlite
+ self.sqlite = sqlite
+ self.paramstyle = "qmark"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "DatabaseError",
+ "Error",
+ "IntegrityError",
+ "NotSupportedError",
+ "OperationalError",
+ "ProgrammingError",
+ "sqlite_version",
+ "sqlite_version_info",
+ ):
+ setattr(self, name, getattr(self.aiosqlite, name))
+
+ for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"):
+ setattr(self, name, getattr(self.sqlite, name))
+
+ for name in ("Binary",):
+ setattr(self, name, getattr(self.sqlite, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ # Q. WHY do we need this?
+ # A. Because there is no way to set connection.isolation_level
+ # otherwise
+ # Q. BUT HOW do you know it is SAFE ?????
+ # A. The only operation that isn't safe is the isolation level set
+ # operation which aiosqlite appears to have let slip through even
+ # though pysqlite appears to do check_same_thread for this.
+ # All execute operations etc. should be safe because they all
+ # go through the single executor thread.
+
+ kw["check_same_thread"] = False
+
+ connection = self.aiosqlite.connect(*arg, **kw)
+
+ # it's a Thread. you'll thank us later
+ connection.daemon = True
+
+ if util.asbool(async_fallback):
+ return AsyncAdaptFallback_aiosqlite_connection(
+ self,
+ await_fallback(connection),
+ )
+ else:
+ return AsyncAdapt_aiosqlite_connection(
+ self,
+ await_only(connection),
+ )
+
+
+class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
+ def create_server_side_cursor(self):
+ return self._dbapi_connection.cursor(server_side=True)
+
+
+class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
+ driver = "aiosqlite"
+ supports_statement_cache = True
+
+ is_async = True
+
+ supports_server_side_cursors = True
+
+ execution_ctx_cls = SQLiteExecutionContext_aiosqlite
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_aiosqlite_dbapi(
+ __import__("aiosqlite"), __import__("sqlite3")
+ )
+
+ @classmethod
+ def get_pool_class(cls, url):
+ if cls._is_url_file_db(url):
+ return pool.NullPool
+ else:
+ return pool.StaticPool
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e, self.dbapi.OperationalError
+ ) and "no active connection" in str(e):
+ return True
+
+ return super().is_disconnect(e, connection, cursor)
+
+ def get_driver_connection(self, connection):
+ return connection._connection
+
+
+dialect = SQLiteDialect_aiosqlite
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
new file mode 100644
index 0000000..0959d04
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -0,0 +1,2556 @@
+# sqlite/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: sqlite
+ :name: SQLite
+ :full_support: 3.21, 3.28+
+ :normal_support: 3.12+
+ :best_effort: 3.7.16+
+
+.. _sqlite_datetime:
+
+Date and Time Types
+-------------------
+
+SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does
+not provide out of the box functionality for translating values between Python
+`datetime` objects and a SQLite-supported format. SQLAlchemy's own
+:class:`~sqlalchemy.types.DateTime` and related types provide date formatting
+and parsing functionality when SQLite is used. The implementation classes are
+:class:`_sqlite.DATETIME`, :class:`_sqlite.DATE` and :class:`_sqlite.TIME`.
+These types represent dates and times as ISO formatted strings, which also
+nicely support ordering. There's no reliance on typical "libc" internals for
+these functions so historical dates are fully supported.
+
+Ensuring Text affinity
+^^^^^^^^^^^^^^^^^^^^^^
+
+The DDL rendered for these types is the standard ``DATE``, ``TIME``
+and ``DATETIME`` indicators. However, custom storage formats can also be
+applied to these types. When the
+storage format is detected as containing no alpha characters, the DDL for
+these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``,
+so that the column continues to have textual affinity.
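+
+As a sketch (assuming the ``storage_format`` and ``regexp`` parameters
+accepted by the SQLite date and time types), a purely numeric storage
+format causes the ``_CHAR`` variant of the DDL name to be rendered::
+
+    import re
+
+    from sqlalchemy import Column, Integer, MetaData, Table
+    from sqlalchemy.dialects.sqlite import DATE
+
+    metadata = MetaData()
+    events = Table(
+        "events", metadata,
+        Column("id", Integer, primary_key=True),
+        Column(
+            "event_date",
+            # stored as e.g. "20230517"; DDL renders the type as DATE_CHAR
+            DATE(
+                storage_format="%(year)04d%(month)02d%(day)02d",
+                regexp=re.compile(r"(\d{4})(\d{2})(\d{2})"),
+            ),
+        ),
+    )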
+
+.. seealso::
+
+ `Type Affinity <https://www.sqlite.org/datatype3.html#affinity>`_ -
+ in the SQLite documentation
+
+.. _sqlite_autoincrement:
+
+SQLite Auto Incrementing Behavior
+----------------------------------
+
+Background on SQLite's autoincrement is at: https://sqlite.org/autoinc.html
+
+Key concepts:
+
+* SQLite has an implicit "auto increment" feature that takes place for any
+ non-composite primary-key column that is specifically created using
+ "INTEGER PRIMARY KEY" for the type + primary key.
+
+* SQLite also has an explicit "AUTOINCREMENT" keyword, that is **not**
+ equivalent to the implicit autoincrement feature; this keyword is not
+ recommended for general use. SQLAlchemy does not render this keyword
+ unless a special SQLite-specific directive is used (see below). However,
+ it still requires that the column's type is named "INTEGER".
+
+Using the AUTOINCREMENT Keyword
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To specifically render the AUTOINCREMENT keyword on the primary key column
+when rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table
+construct::
+
+ Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ sqlite_autoincrement=True)
+
+Allowing autoincrement behavior for SQLAlchemy types other than Integer/INTEGER
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+SQLite's typing model is based on naming conventions. Among other things, this
+means that any type name which contains the substring ``"INT"`` will be
+determined to be of "integer affinity". A type named ``"BIGINT"``,
+``"SPECIAL_INT"`` or even ``"XYZINTQPR"``, will be considered by SQLite to be
+of "integer" affinity. However, **the SQLite autoincrement feature, whether
+implicitly or explicitly enabled, requires that the name of the column's type
+is exactly the string "INTEGER"**. Therefore, if an application uses a type
+like :class:`.BigInteger` for a primary key, on SQLite this type will need to
+be rendered as the name ``"INTEGER"`` when emitting the initial ``CREATE
+TABLE`` statement in order for the autoincrement behavior to be available.
+
+One approach to achieve this is to use :class:`.Integer` on SQLite
+only using :meth:`.TypeEngine.with_variant`::
+
+ table = Table(
+ "my_table", metadata,
+ Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True)
+ )
+
+Another is to use a subclass of :class:`.BigInteger` that overrides its DDL
+name to be ``INTEGER`` when compiled against SQLite::
+
+ from sqlalchemy import BigInteger
+ from sqlalchemy.ext.compiler import compiles
+
+ class SLBigInteger(BigInteger):
+ pass
+
+ @compiles(SLBigInteger, 'sqlite')
+ def bi_c(element, compiler, **kw):
+ return "INTEGER"
+
+ @compiles(SLBigInteger)
+ def bi_c(element, compiler, **kw):
+ return compiler.visit_BIGINT(element, **kw)
+
+
+ table = Table(
+ "my_table", metadata,
+ Column("id", SLBigInteger(), primary_key=True)
+ )
+
+.. seealso::
+
+ :meth:`.TypeEngine.with_variant`
+
+ :ref:`sqlalchemy.ext.compiler_toplevel`
+
+ `Datatypes In SQLite Version 3 <https://sqlite.org/datatype3.html>`_
+
+.. _sqlite_concurrency:
+
+Database Locking Behavior / Concurrency
+---------------------------------------
+
+SQLite is not designed for a high level of write concurrency. The database
+itself, being a file, is locked completely during write operations within
+transactions, meaning exactly one "connection" (in reality a file handle)
+has exclusive access to the database during this period - all other
+"connections" will be blocked during this time.
+
+The Python DBAPI specification also calls for a connection model that is
+always in a transaction; there is no ``connection.begin()`` method,
+only ``connection.commit()`` and ``connection.rollback()``, upon which a
+new transaction is to be begun immediately. This may seem to imply
+that the SQLite driver would in theory allow only a single filehandle on a
+particular database file at any time; however, there are several
+factors both within SQLite itself as well as within the pysqlite driver
+which loosen this restriction significantly.
+
+However, no matter what locking modes are used, SQLite will still always
+lock the database file once a transaction is started and at least one DML
+statement (e.g. INSERT, UPDATE, DELETE) has been emitted; this will block
+other transactions at least at the point that they also attempt to emit DML.
+By default, the length of time on this block is very short before it times out
+with an error.
+
+This behavior becomes more critical when used in conjunction with the
+SQLAlchemy ORM. SQLAlchemy's :class:`.Session` object by default runs
+within a transaction, and with its autoflush model, may emit DML preceding
+any SELECT statement. This may lead to a SQLite database that locks
+more quickly than is expected. The locking mode of SQLite and the pysqlite
+driver can be manipulated to some degree, however it should be noted that
+achieving a high degree of write-concurrency with SQLite is a losing battle.
+
+For more information on SQLite's lack of write concurrency by design, please
+see
+`Situations Where Another RDBMS May Work Better - High Concurrency
+<https://www.sqlite.org/whentouse.html>`_ near the bottom of the page.
+
+The following subsections introduce areas that are impacted by SQLite's
+file-based architecture and that additionally usually require workarounds
+when using the pysqlite driver.
+
+.. _sqlite_isolation_level:
+
+Transaction Isolation Level / Autocommit
+----------------------------------------
+
+SQLite supports "transaction isolation" in a non-standard way, along two
+axes. One is that of the
+`PRAGMA read_uncommitted <https://www.sqlite.org/pragma.html#pragma_read_uncommitted>`_
+instruction. This setting can essentially switch SQLite between its
+default mode of ``SERIALIZABLE`` isolation, and a "dirty read" isolation
+mode normally referred to as ``READ UNCOMMITTED``.
+
+SQLAlchemy ties into this PRAGMA statement using the
+:paramref:`_sa.create_engine.isolation_level` parameter of
+:func:`_sa.create_engine`.
+Valid values for this parameter when used with SQLite are ``"SERIALIZABLE"``
+and ``"READ UNCOMMITTED"`` corresponding to a value of 0 and 1, respectively.
+SQLite defaults to ``SERIALIZABLE``, however its behavior is impacted by
+the pysqlite driver's default behavior.
+
+When using the pysqlite driver, the ``"AUTOCOMMIT"`` isolation level is also
+available, which will alter the pysqlite connection using the ``.isolation_level``
+attribute on the DBAPI connection and set it to None for the duration
+of the setting.
+
+.. versionadded:: 1.3.16 added support for SQLite AUTOCOMMIT isolation level
+ when using the pysqlite / sqlite3 SQLite driver.
+
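+As a sketch, the isolation level may be set engine-wide or per connection
+(the database path here is illustrative)::
+
+ from sqlalchemy import create_engine
+
+ # engine-wide; emitted as PRAGMA read_uncommitted on each new connection
+ eng = create_engine("sqlite:///some.db", isolation_level="READ UNCOMMITTED")
+
+ # per-connection override; "AUTOCOMMIT" is available with the pysqlite driver
+ with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+ ...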
+
+The other axis along which SQLite's transactional locking is impacted is
+via the nature of the ``BEGIN`` statement used. The three varieties
+are "deferred", "immediate", and "exclusive", as described at
+`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_. A straight
+``BEGIN`` statement uses the "deferred" mode, where the database file is
+not locked until the first read or write operation, and read access remains
+open to other transactions until the first write operation. But again,
+it is critical to note that the pysqlite driver interferes with this behavior
+by *not even emitting BEGIN* until the first write operation.
+
+.. warning::
+
+ SQLite's transactional scope is impacted by unresolved
+ issues in the pysqlite driver, which defers BEGIN statements to a greater
+ degree than is often feasible. See the section :ref:`pysqlite_serializable`
+ for techniques to work around this behavior.
+
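+As a sketch of the commonly used workaround, described more fully in
+:ref:`pysqlite_serializable`, pysqlite's own transaction handling can be
+disabled and ``BEGIN`` emitted explicitly::
+
+ from sqlalchemy import create_engine, event
+
+ engine = create_engine("sqlite:///some.db")
+
+ @event.listens_for(engine, "connect")
+ def do_connect(dbapi_connection, connection_record):
+ # disable pysqlite's emitting of the BEGIN statement entirely
+ dbapi_connection.isolation_level = None
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ # emit our own BEGIN
+ conn.exec_driver_sql("BEGIN")
+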
+.. seealso::
+
+ :ref:`dbapi_autocommit`
+
+SAVEPOINT Support
+----------------------------
+
+SQLite supports SAVEPOINTs, which only function once a transaction is
+begun. SQLAlchemy's SAVEPOINT support is available using the
+:meth:`_engine.Connection.begin_nested` method at the Core level, and
+:meth:`.Session.begin_nested` at the ORM level. However, SAVEPOINTs
+won't work at all with pysqlite unless workarounds are taken.
+
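+A minimal Core-level sketch, assuming an ``engine`` and a ``some_table``
+:class:`_schema.Table` are already defined and that the pysqlite workarounds
+referenced below are in place::
+
+ with engine.connect() as conn:
+ with conn.begin():
+ savepoint = conn.begin_nested() # emits SAVEPOINT
+ conn.execute(some_table.insert(), {"data": "value"})
+ savepoint.rollback() # emits ROLLBACK TO SAVEPOINT
+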
+.. warning::
+
+ SQLite's SAVEPOINT feature is impacted by unresolved
+ issues in the pysqlite driver, which defers BEGIN statements to a greater
+ degree than is often feasible. See the section :ref:`pysqlite_serializable`
+ for techniques to work around this behavior.
+
+Transactional DDL
+----------------------------
+
+The SQLite database supports transactional :term:`DDL` as well.
+In this case, the pysqlite driver not only fails to start transactions,
+it also ends any existing transaction when DDL is detected, so again,
+workarounds are required.
+
+.. warning::
+
+ SQLite's transactional DDL is impacted by unresolved issues
+ in the pysqlite driver, which fails to emit BEGIN and additionally
+ forces a COMMIT to cancel any transaction when DDL is encountered.
+ See the section :ref:`pysqlite_serializable`
+ for techniques to work around this behavior.
+
+.. _sqlite_foreign_keys:
+
+Foreign Key Support
+-------------------
+
+SQLite supports FOREIGN KEY syntax when emitting CREATE statements for tables,
+however by default these constraints have no effect on the operation of the
+table.
+
+Constraint checking on SQLite has three prerequisites:
+
+* At least version 3.6.19 of SQLite must be in use
+* The SQLite library must be compiled *without* the SQLITE_OMIT_FOREIGN_KEY
+ or SQLITE_OMIT_TRIGGER symbols enabled.
+* The ``PRAGMA foreign_keys = ON`` statement must be emitted on all
+ connections before use -- including the initial call to
+ :meth:`sqlalchemy.schema.MetaData.create_all`.
+
+SQLAlchemy allows for the ``PRAGMA`` statement to be emitted automatically for
+new connections through the usage of events::
+
+ from sqlalchemy.engine import Engine
+ from sqlalchemy import event
+
+ @event.listens_for(Engine, "connect")
+ def set_sqlite_pragma(dbapi_connection, connection_record):
+ cursor = dbapi_connection.cursor()
+ cursor.execute("PRAGMA foreign_keys=ON")
+ cursor.close()
+
+.. warning::
+
+ When SQLite foreign keys are enabled, it is **not possible**
+ to emit CREATE or DROP statements for tables that contain
+ mutually-dependent foreign key constraints;
+ to emit the DDL for these tables requires that ALTER TABLE be used to
+ create or drop these constraints separately, for which SQLite has
+ no support.
+
+.. seealso::
+
+ `SQLite Foreign Key Support <https://www.sqlite.org/foreignkeys.html>`_
+ - on the SQLite web site.
+
+ :ref:`event_toplevel` - SQLAlchemy event API.
+
+ :ref:`use_alter` - more information on SQLAlchemy's facilities for handling
+ mutually-dependent foreign key constraints.
+
+.. _sqlite_on_conflict_ddl:
+
+ON CONFLICT support for constraints
+-----------------------------------
+
+.. seealso:: This section describes the :term:`DDL` version of "ON CONFLICT" for
+ SQLite, which occurs within a CREATE TABLE statement. For "ON CONFLICT" as
+ applied to an INSERT statement, see :ref:`sqlite_on_conflict_insert`.
+
+SQLite supports a non-standard DDL clause known as ON CONFLICT which can be applied
+to primary key, unique, check, and not null constraints. In DDL, it is
+rendered either within the "CONSTRAINT" clause or within the column definition
+itself depending on the location of the target constraint. To render this
+clause within DDL, the extension parameter ``sqlite_on_conflict`` can be
+specified with a string conflict resolution algorithm within the
+:class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`,
+:class:`.CheckConstraint` objects. Within the :class:`_schema.Column` object,
+there
+are individual parameters ``sqlite_on_conflict_not_null``,
+``sqlite_on_conflict_primary_key``, ``sqlite_on_conflict_unique`` which each
+correspond to the three types of relevant constraint types that can be
+indicated from a :class:`_schema.Column` object.
+
+.. seealso::
+
+ `ON CONFLICT <https://www.sqlite.org/lang_conflict.html>`_ - in the SQLite
+ documentation
+
+.. versionadded:: 1.3
+
+
+The ``sqlite_on_conflict`` parameters accept a string argument which is just
+the resolution name to be chosen, which on SQLite can be one of ROLLBACK,
+ABORT, FAIL, IGNORE, and REPLACE. For example, to add a UNIQUE constraint
+that specifies the IGNORE algorithm::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', Integer),
+ UniqueConstraint('id', 'data', sqlite_on_conflict='IGNORE')
+ )
+
+The above renders CREATE TABLE DDL as::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ data INTEGER,
+ PRIMARY KEY (id),
+ UNIQUE (id, data) ON CONFLICT IGNORE
+ )
+
+
+When using the :paramref:`_schema.Column.unique`
+flag to add a UNIQUE constraint
+to a single column, the ``sqlite_on_conflict_unique`` parameter can
+be added to the :class:`_schema.Column` as well, which will be added to the
+UNIQUE constraint in the DDL::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', Integer, unique=True,
+ sqlite_on_conflict_unique='IGNORE')
+ )
+
+rendering::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ data INTEGER,
+ PRIMARY KEY (id),
+ UNIQUE (data) ON CONFLICT IGNORE
+ )
+
+To apply the FAIL algorithm for a NOT NULL constraint,
+``sqlite_on_conflict_not_null`` is used::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', Integer, nullable=False,
+ sqlite_on_conflict_not_null='FAIL')
+ )
+
+this renders the column inline ON CONFLICT phrase::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ data INTEGER NOT NULL ON CONFLICT FAIL,
+ PRIMARY KEY (id)
+ )
+
+
+Similarly, for an inline primary key, use ``sqlite_on_conflict_primary_key``::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, primary_key=True,
+ sqlite_on_conflict_primary_key='FAIL')
+ )
+
+SQLAlchemy renders the PRIMARY KEY constraint separately, so the conflict
+resolution algorithm is applied to the constraint itself::
+
+ CREATE TABLE some_table (
+ id INTEGER NOT NULL,
+ PRIMARY KEY (id) ON CONFLICT FAIL
+ )
+
+.. _sqlite_on_conflict_insert:
+
+INSERT...ON CONFLICT (Upsert)
+-----------------------------------
+
+.. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for
+ SQLite, which occurs within an INSERT statement. For "ON CONFLICT" as
+ applied to a CREATE TABLE statement, see :ref:`sqlite_on_conflict_ddl`.
+
+From version 3.24.0 onwards, SQLite supports "upserts" (update or insert)
+of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT``
+statement. A candidate row will only be inserted if that row does not violate
+any unique or primary key constraints. In the case of a unique constraint violation, a
+secondary action can occur which can be either "DO UPDATE", indicating that
+the data in the target row should be updated, or "DO NOTHING", which indicates
+to silently skip this row.
+
+Conflicts are determined using columns that are part of existing unique
+constraints and indexes. These constraints are identified by stating the
+columns and conditions that comprise the indexes.
+
+SQLAlchemy provides ``ON CONFLICT`` support via the SQLite-specific
+:func:`_sqlite.insert()` function, which provides
+the generative methods :meth:`_sqlite.Insert.on_conflict_do_update`
+and :meth:`_sqlite.Insert.on_conflict_do_nothing`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.dialects.sqlite import insert
+
+ >>> insert_stmt = insert(my_table).values(
+ ... id='some_existing_id',
+ ... data='inserted value')
+
+ >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?{stop}
+
+ >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(
+ ... index_elements=['id']
+ ... )
+
+ >>> print(do_nothing_stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+ ON CONFLICT (id) DO NOTHING
+
+.. versionadded:: 1.4
+
+.. seealso::
+
+ `Upsert
+ <https://sqlite.org/lang_UPSERT.html>`_
+ - in the SQLite documentation.
+
+
+Specifying the Target
+^^^^^^^^^^^^^^^^^^^^^
+
+Both methods supply the "target" of the conflict using column inference:
+
+* The :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` argument
+ specifies a sequence containing string column names, :class:`_schema.Column`
+ objects, and/or SQL expression elements, which would identify a unique index
+ or unique constraint.
+
+* When using :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements`
+ to infer an index, a partial index can be inferred by also specifying the
+ :paramref:`_sqlite.Insert.on_conflict_do_update.index_where` parameter:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data')
+
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=[my_table.c.user_email],
+ ... index_where=my_table.c.user_email.like('%@gmail.com'),
+ ... set_=dict(data=stmt.excluded.data)
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (data, user_email) VALUES (?, ?)
+ ON CONFLICT (user_email)
+ WHERE user_email LIKE '%@gmail.com'
+ DO UPDATE SET data = excluded.data
+ >>>
+
+The SET Clause
+^^^^^^^^^^^^^^^
+
+``ON CONFLICT...DO UPDATE`` is used to perform an update of the already
+existing row, using any combination of new values as well as values
+from the proposed insertion. These values are specified using the
+:paramref:`_sqlite.Insert.on_conflict_do_update.set_` parameter. This
+parameter accepts a dictionary which consists of direct values
+for UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value')
+ ... )
+
+ >>> print(do_update_stmt)
+
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?
+
+.. warning::
+
+ The :meth:`_sqlite.Insert.on_conflict_do_update` method does **not** take
+ into account Python-side default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`. These
+ values will not be exercised for an ON CONFLICT style of UPDATE, unless
+ they are manually specified in the
+ :paramref:`_sqlite.Insert.on_conflict_do_update.set_` dictionary.
+
+Updating using the Excluded INSERT Values
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`~.sqlite.Insert.excluded` is available as an attribute on
+the :class:`_sqlite.Insert` object; this object creates an "excluded." prefix
+on a column, which informs the DO UPDATE to update the row with the value that
+would have been inserted had the constraint not failed:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+
+ >>> do_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author)
+ ... )
+
+ >>> print(do_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author
+
+Additional WHERE Criteria
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :meth:`_sqlite.Insert.on_conflict_do_update` method also accepts
+a WHERE clause using the :paramref:`_sqlite.Insert.on_conflict_do_update.where`
+parameter, which will limit those rows which receive an UPDATE:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(
+ ... id='some_id',
+ ... data='inserted value',
+ ... author='jlh'
+ ... )
+
+ >>> on_update_stmt = stmt.on_conflict_do_update(
+ ... index_elements=['id'],
+ ... set_=dict(data='updated value', author=stmt.excluded.author),
+ ... where=(my_table.c.status == 2)
+ ... )
+ >>> print(on_update_stmt)
+ {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?)
+ ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author
+ WHERE my_table.status = ?
+
+
+Skipping Rows with DO NOTHING
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``ON CONFLICT`` may be used to skip inserting a row entirely
+if any conflict with a unique constraint occurs; below this is illustrated
+using the :meth:`_sqlite.Insert.on_conflict_do_nothing` method:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id'])
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING
+
+
+If ``DO NOTHING`` is used without specifying any columns or constraint,
+it has the effect of skipping the INSERT for any unique violation which
+occurs:
+
+.. sourcecode:: pycon+sql
+
+ >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+ >>> stmt = stmt.on_conflict_do_nothing()
+ >>> print(stmt)
+ {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING
+
+.. _sqlite_type_reflection:
+
+Type Reflection
+---------------
+
+SQLite types are unlike those of most other database backends, in that
+the string name of the type usually does not correspond to a "type" in a
+one-to-one fashion. Instead, SQLite links per-column typing behavior
+to one of five so-called "type affinities" based on a string matching
+pattern for the type.
+
+SQLAlchemy's reflection process, when inspecting types, uses a simple
+lookup table to link the keywords returned to provided SQLAlchemy types.
+This lookup table is present within the SQLite dialect as it is for all
+other dialects. However, the SQLite dialect has a different "fallback"
+routine for when a particular type name is not located in the lookup map;
+it instead implements the SQLite "type affinity" scheme located at
+https://www.sqlite.org/datatype3.html section 2.1.
+
+The provided typemap will make direct associations from an exact string
+name match for the following types:
+
+:class:`_types.BIGINT`, :class:`_types.BLOB`,
+:class:`_types.BOOLEAN`, :class:`_types.CHAR`,
+:class:`_types.DATE`, :class:`_types.DATETIME`,
+:class:`_types.DECIMAL`, :class:`_types.FLOAT`,
+:class:`_types.INTEGER`, :class:`_types.NUMERIC`,
+:class:`_types.REAL`, :class:`_types.SMALLINT`,
+:class:`_types.TEXT`, :class:`_types.TIME`,
+:class:`_types.TIMESTAMP`, :class:`_types.VARCHAR`,
+:class:`_types.NVARCHAR`, :class:`_types.NCHAR`
+
+When a type name does not match one of the above types, the "type affinity"
+lookup is used instead:
+
+* :class:`_types.INTEGER` is returned if the type name includes the
+ string ``INT``
+* :class:`_types.TEXT` is returned if the type name includes the
+ string ``CHAR``, ``CLOB`` or ``TEXT``
+* :class:`_types.NullType` is returned if the type name includes the
+ string ``BLOB``
+* :class:`_types.REAL` is returned if the type name includes the string
+ ``REAL``, ``FLOA`` or ``DOUB``.
+* Otherwise, the :class:`_types.NUMERIC` type is used.
+
+.. versionadded:: 0.9.3 Support for SQLite type affinity rules when reflecting
+ columns.
+
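+As a sketch of these rules in practice (the table and type names here are
+illustrative)::
+
+ from sqlalchemy import create_engine, inspect
+
+ engine = create_engine("sqlite://")
+ with engine.begin() as conn:
+ conn.exec_driver_sql("CREATE TABLE t (a SPECIAL_INT, b MYCLOB, c WIDGET)")
+
+ columns = inspect(engine).get_columns("t")
+ # "SPECIAL_INT" contains "INT" -> INTEGER
+ # "MYCLOB" contains "CLOB" -> TEXT
+ # "WIDGET" matches nothing -> NUMERIC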
+
+.. _sqlite_partial_index:
+
+Partial Indexes
+---------------
+
+A partial index, e.g. one which uses a WHERE clause, can be specified
+with the DDL system using the argument ``sqlite_where``::
+
+ tbl = Table('testtbl', m, Column('data', Integer))
+ idx = Index('test_idx1', tbl.c.data,
+ sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10))
+
+The index will be rendered at create time as::
+
+ CREATE INDEX test_idx1 ON testtbl (data)
+ WHERE data > 5 AND data < 10
+
+.. versionadded:: 0.9.9
+
+.. _sqlite_dotted_column_names:
+
+Dotted Column Names
+-------------------
+
+Using table or column names that explicitly have periods in them is
+**not recommended**. While this is generally a bad idea for relational
+databases, as the dot is a syntactically significant character,
+SQLite versions up until **3.10.0** additionally have a bug which
+requires that SQLAlchemy filter out these dots in result sets.
+
+.. versionchanged:: 1.1
+
+ The following SQLite issue has been resolved as of version 3.10.0
+ of SQLite. SQLAlchemy as of **1.1** automatically disables its internal
+ workarounds based on detection of this version.
+
+The bug, entirely outside of SQLAlchemy, can be illustrated thusly::
+
+ import sqlite3
+
+ assert sqlite3.sqlite_version_info < (3, 10, 0), "bug is fixed in this version"
+
+ conn = sqlite3.connect(":memory:")
+ cursor = conn.cursor()
+
+ cursor.execute("create table x (a integer, b integer)")
+ cursor.execute("insert into x (a, b) values (1, 1)")
+ cursor.execute("insert into x (a, b) values (2, 2)")
+
+ cursor.execute("select x.a, x.b from x")
+ assert [c[0] for c in cursor.description] == ['a', 'b']
+
+ cursor.execute('''
+ select x.a, x.b from x where a=1
+ union
+ select x.a, x.b from x where a=2
+ ''')
+ assert [c[0] for c in cursor.description] == ['a', 'b'], \
+ [c[0] for c in cursor.description]
+
+The second assertion fails::
+
+ Traceback (most recent call last):
+ File "test.py", line 19, in <module>
+ [c[0] for c in cursor.description]
+ AssertionError: ['x.a', 'x.b']
+
+Where above, the driver incorrectly reports the names of the columns
+including the name of the table, which is entirely inconsistent with its
+behavior when the UNION is not present.
+
+SQLAlchemy relies upon column names being predictable in how they match
+to the original statement, so the SQLAlchemy dialect has no choice but
+to filter these out::
+
+
+ from sqlalchemy import create_engine
+
+ eng = create_engine("sqlite://")
+ conn = eng.connect()
+
+ conn.exec_driver_sql("create table x (a integer, b integer)")
+ conn.exec_driver_sql("insert into x (a, b) values (1, 1)")
+ conn.exec_driver_sql("insert into x (a, b) values (2, 2)")
+
+ result = conn.exec_driver_sql("select x.a, x.b from x")
+ assert result.keys() == ["a", "b"]
+
+ result = conn.exec_driver_sql('''
+ select x.a, x.b from x where a=1
+ union
+ select x.a, x.b from x where a=2
+ ''')
+ assert result.keys() == ["a", "b"]
+
+Note that above, even though SQLAlchemy filters out the dots, *both
+names are still addressable*::
+
+ >>> row = result.first()
+ >>> row["a"]
+ 1
+ >>> row["x.a"]
+ 1
+ >>> row["b"]
+ 1
+ >>> row["x.b"]
+ 1
+
+Therefore, the workaround applied by SQLAlchemy only impacts
+:meth:`_engine.CursorResult.keys` and :meth:`.Row.keys()` in the public API. In
+the very specific case where an application is forced to use column names that
+contain dots, and the functionality of :meth:`_engine.CursorResult.keys` and
+:meth:`.Row.keys()` is required to return these dotted names unmodified,
+the ``sqlite_raw_colnames`` execution option may be provided, either on a
+per-:class:`_engine.Connection` basis::
+
+ result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql('''
+ select x.a, x.b from x where a=1
+ union
+ select x.a, x.b from x where a=2
+ ''')
+ assert result.keys() == ["x.a", "x.b"]
+
+or on a per-:class:`_engine.Engine` basis::
+
+ engine = create_engine("sqlite://", execution_options={"sqlite_raw_colnames": True})
+
+When using the per-:class:`_engine.Engine` execution option, note that
+**Core and ORM queries that use UNION may not function properly**.
+
+SQLite-specific table options
+-----------------------------
+
+One option for CREATE TABLE is supported directly by the SQLite
+dialect in conjunction with the :class:`_schema.Table` construct:
+
+* ``WITHOUT ROWID``::
+
+ Table("some_table", metadata, ..., sqlite_with_rowid=False)
+
+.. seealso::
+
+ `SQLite CREATE TABLE options
+ <https://www.sqlite.org/lang_createtable.html>`_
+
+""" # noqa
+
+import datetime
+import numbers
+import re
+
+from .json import JSON
+from .json import JSONIndexType
+from .json import JSONPathType
+from ... import exc
+from ... import processors
+from ... import schema as sa_schema
+from ... import sql
+from ... import types as sqltypes
+from ... import util
+from ...engine import default
+from ...engine import reflection
+from ...sql import coercions
+from ...sql import ColumnElement
+from ...sql import compiler
+from ...sql import elements
+from ...sql import roles
+from ...sql import schema
+from ...types import BLOB # noqa
+from ...types import BOOLEAN # noqa
+from ...types import CHAR # noqa
+from ...types import DECIMAL # noqa
+from ...types import FLOAT # noqa
+from ...types import INTEGER # noqa
+from ...types import NUMERIC # noqa
+from ...types import REAL # noqa
+from ...types import SMALLINT # noqa
+from ...types import TEXT # noqa
+from ...types import TIMESTAMP # noqa
+from ...types import VARCHAR # noqa
+
+
+class _SQliteJson(JSON):
+ def result_processor(self, dialect, coltype):
+ default_processor = super(_SQliteJson, self).result_processor(
+ dialect, coltype
+ )
+
+ def process(value):
+ try:
+ return default_processor(value)
+ except TypeError:
+ if isinstance(value, numbers.Number):
+ return value
+ else:
+ raise
+
+ return process
+
+
+class _DateTimeMixin(object):
+ _reg = None
+ _storage_format = None
+
+ def __init__(self, storage_format=None, regexp=None, **kw):
+ super(_DateTimeMixin, self).__init__(**kw)
+ if regexp is not None:
+ self._reg = re.compile(regexp)
+ if storage_format is not None:
+ self._storage_format = storage_format
+
+ @property
+ def format_is_text_affinity(self):
+ """return True if the storage format will automatically imply
+ a TEXT affinity.
+
+ If the storage format contains no non-numeric characters,
+ it will imply a NUMERIC storage format on SQLite; in this case,
+ the type will generate its DDL as DATE_CHAR, DATETIME_CHAR,
+ TIME_CHAR.
+
+ .. versionadded:: 1.0.0
+
+ """
+ spec = self._storage_format % {
+ "year": 0,
+ "month": 0,
+ "day": 0,
+ "hour": 0,
+ "minute": 0,
+ "second": 0,
+ "microsecond": 0,
+ }
+ return bool(re.search(r"[^0-9]", spec))
+
+ def adapt(self, cls, **kw):
+ if issubclass(cls, _DateTimeMixin):
+ if self._storage_format:
+ kw["storage_format"] = self._storage_format
+ if self._reg:
+ kw["regexp"] = self._reg
+ return super(_DateTimeMixin, self).adapt(cls, **kw)
+
+ def literal_processor(self, dialect):
+ bp = self.bind_processor(dialect)
+
+ def process(value):
+ return "'%s'" % bp(value)
+
+ return process
+
+
+class DATETIME(_DateTimeMixin, sqltypes.DateTime):
+ r"""Represent a Python datetime object in SQLite using a string.
+
+ The default string storage format is::
+
+ "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+
+ e.g.::
+
+ 2021-03-15 12:05:57.105542
+
+ The storage format can be customized to some degree using the
+ ``storage_format`` and ``regexp`` parameters, such as::
+
+ import re
+ from sqlalchemy.dialects.sqlite import DATETIME
+
+ dt = DATETIME(storage_format="%(year)04d/%(month)02d/%(day)02d "
+ "%(hour)02d:%(minute)02d:%(second)02d",
+ regexp=r"(\d+)/(\d+)/(\d+) (\d+):(\d+):(\d+)"
+ )
+
+ :param storage_format: format string which will be applied to the dict
+ with keys year, month, day, hour, minute, second, and microsecond.
+
+ :param regexp: regular expression which will be applied to incoming result
+ rows. If the regexp contains named groups, the resulting match dict is
+ applied to the Python datetime() constructor as keyword arguments.
+ Otherwise, if positional groups are used, the datetime() constructor
+ is called with positional arguments via
+ ``*map(int, match_obj.groups(0))``.
+
+ """ # noqa
+
+ _storage_format = (
+ "%(year)04d-%(month)02d-%(day)02d "
+ "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+ )
+
+ def __init__(self, *args, **kwargs):
+ truncate_microseconds = kwargs.pop("truncate_microseconds", False)
+ super(DATETIME, self).__init__(*args, **kwargs)
+ if truncate_microseconds:
+ assert "storage_format" not in kwargs, (
+ "You can specify only "
+ "one of truncate_microseconds or storage_format."
+ )
+ assert "regexp" not in kwargs, (
+ "You can specify only one of "
+ "truncate_microseconds or regexp."
+ )
+ self._storage_format = (
+ "%(year)04d-%(month)02d-%(day)02d "
+ "%(hour)02d:%(minute)02d:%(second)02d"
+ )
+
+ def bind_processor(self, dialect):
+ datetime_datetime = datetime.datetime
+ datetime_date = datetime.date
+ format_ = self._storage_format
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, datetime_datetime):
+ return format_ % {
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ "hour": value.hour,
+ "minute": value.minute,
+ "second": value.second,
+ "microsecond": value.microsecond,
+ }
+ elif isinstance(value, datetime_date):
+ return format_ % {
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ "hour": 0,
+ "minute": 0,
+ "second": 0,
+ "microsecond": 0,
+ }
+ else:
+ raise TypeError(
+ "SQLite DateTime type only accepts Python "
+ "datetime and date objects as input."
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if self._reg:
+ return processors.str_to_datetime_processor_factory(
+ self._reg, datetime.datetime
+ )
+ else:
+ return processors.str_to_datetime
+
+
+class DATE(_DateTimeMixin, sqltypes.Date):
+ r"""Represent a Python date object in SQLite using a string.
+
+ The default string storage format is::
+
+ "%(year)04d-%(month)02d-%(day)02d"
+
+ e.g.::
+
+ 2011-03-15
+
+ The storage format can be customized to some degree using the
+ ``storage_format`` and ``regexp`` parameters, such as::
+
+ import re
+ from sqlalchemy.dialects.sqlite import DATE
+
+ d = DATE(
+ storage_format="%(month)02d/%(day)02d/%(year)04d",
+ regexp=re.compile(r"(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+)")
+ )
+
+ :param storage_format: format string which will be applied to the
+ dict with keys year, month, and day.
+
+ :param regexp: regular expression which will be applied to
+ incoming result rows. If the regexp contains named groups, the
+ resulting match dict is applied to the Python date() constructor
+ as keyword arguments. Otherwise, if positional groups are used, the
+ date() constructor is called with positional arguments via
+ ``*map(int, match_obj.groups(0))``.
+ """
+
+ _storage_format = "%(year)04d-%(month)02d-%(day)02d"
+
+ def bind_processor(self, dialect):
+ datetime_date = datetime.date
+ format_ = self._storage_format
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, datetime_date):
+ return format_ % {
+ "year": value.year,
+ "month": value.month,
+ "day": value.day,
+ }
+ else:
+ raise TypeError(
+ "SQLite Date type only accepts Python "
+ "date objects as input."
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if self._reg:
+ return processors.str_to_datetime_processor_factory(
+ self._reg, datetime.date
+ )
+ else:
+ return processors.str_to_date
+
+
+class TIME(_DateTimeMixin, sqltypes.Time):
+ r"""Represent a Python time object in SQLite using a string.
+
+ The default string storage format is::
+
+ "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+
+ e.g.::
+
+ 12:05:57.105542
+
+ The storage format can be customized to some degree using the
+ ``storage_format`` and ``regexp`` parameters, such as::
+
+ import re
+ from sqlalchemy.dialects.sqlite import TIME
+
+ t = TIME(storage_format="%(hour)02d-%(minute)02d-"
+ "%(second)02d-%(microsecond)06d",
+ regexp=re.compile(r"(\d+)-(\d+)-(\d+)(?:-(\d+))?")
+ )
+
+ :param storage_format: format string which will be applied to the dict
+ with keys hour, minute, second, and microsecond.
+
+ :param regexp: regular expression which will be applied to incoming result
+ rows. If the regexp contains named groups, the resulting match dict is
+ applied to the Python time() constructor as keyword arguments. Otherwise,
+ if positional groups are used, the time() constructor is called with
+ positional arguments via ``*map(int, match_obj.groups(0))``.
+ """
+
+ _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d"
+
+ def __init__(self, *args, **kwargs):
+ truncate_microseconds = kwargs.pop("truncate_microseconds", False)
+ super(TIME, self).__init__(*args, **kwargs)
+ if truncate_microseconds:
+ assert "storage_format" not in kwargs, (
+ "You can specify only "
+ "one of truncate_microseconds or storage_format."
+ )
+ assert "regexp" not in kwargs, (
+ "You can specify only one of "
+ "truncate_microseconds or regexp."
+ )
+ self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d"
+
+ def bind_processor(self, dialect):
+ datetime_time = datetime.time
+ format_ = self._storage_format
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, datetime_time):
+ return format_ % {
+ "hour": value.hour,
+ "minute": value.minute,
+ "second": value.second,
+ "microsecond": value.microsecond,
+ }
+ else:
+ raise TypeError(
+ "SQLite Time type only accepts Python "
+ "time objects as input."
+ )
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if self._reg:
+ return processors.str_to_datetime_processor_factory(
+ self._reg, datetime.time
+ )
+ else:
+ return processors.str_to_time
+
+
+colspecs = {
+ sqltypes.Date: DATE,
+ sqltypes.DateTime: DATETIME,
+ sqltypes.JSON: _SQliteJson,
+ sqltypes.JSON.JSONIndexType: JSONIndexType,
+ sqltypes.JSON.JSONPathType: JSONPathType,
+ sqltypes.Time: TIME,
+}
+
+ischema_names = {
+ "BIGINT": sqltypes.BIGINT,
+ "BLOB": sqltypes.BLOB,
+ "BOOL": sqltypes.BOOLEAN,
+ "BOOLEAN": sqltypes.BOOLEAN,
+ "CHAR": sqltypes.CHAR,
+ "DATE": sqltypes.DATE,
+ "DATE_CHAR": sqltypes.DATE,
+ "DATETIME": sqltypes.DATETIME,
+ "DATETIME_CHAR": sqltypes.DATETIME,
+ "DOUBLE": sqltypes.FLOAT,
+ "DECIMAL": sqltypes.DECIMAL,
+ "FLOAT": sqltypes.FLOAT,
+ "INT": sqltypes.INTEGER,
+ "INTEGER": sqltypes.INTEGER,
+ "JSON": JSON,
+ "NUMERIC": sqltypes.NUMERIC,
+ "REAL": sqltypes.REAL,
+ "SMALLINT": sqltypes.SMALLINT,
+ "TEXT": sqltypes.TEXT,
+ "TIME": sqltypes.TIME,
+ "TIME_CHAR": sqltypes.TIME,
+ "TIMESTAMP": sqltypes.TIMESTAMP,
+ "VARCHAR": sqltypes.VARCHAR,
+ "NVARCHAR": sqltypes.NVARCHAR,
+ "NCHAR": sqltypes.NCHAR,
+}
+
+
+class SQLiteCompiler(compiler.SQLCompiler):
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {
+ "month": "%m",
+ "day": "%d",
+ "year": "%Y",
+ "second": "%S",
+ "hour": "%H",
+ "doy": "%j",
+ "minute": "%M",
+ "epoch": "%s",
+ "dow": "%w",
+ "week": "%W",
+ },
+ )
+
+ def visit_now_func(self, fn, **kw):
+ return "CURRENT_TIMESTAMP"
+
+ def visit_localtimestamp_func(self, func, **kw):
+ return 'DATETIME(CURRENT_TIMESTAMP, "localtime")'
+
+ def visit_true(self, expr, **kw):
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ return "0"
+
+ def visit_char_length_func(self, fn, **kw):
+ return "length%s" % self.function_argspec(fn)
+
+ def visit_cast(self, cast, **kwargs):
+ if self.dialect.supports_cast:
+ return super(SQLiteCompiler, self).visit_cast(cast, **kwargs)
+ else:
+ return self.process(cast.clause, **kwargs)
+
+ def visit_extract(self, extract, **kw):
+ try:
+ return "CAST(STRFTIME('%s', %s) AS INTEGER)" % (
+ self.extract_map[extract.field],
+ self.process(extract.expr, **kw),
+ )
+ except KeyError as err:
+ util.raise_(
+ exc.CompileError(
+ "%s is not a valid extract argument." % extract.field
+ ),
+ replace_context=err,
+ )
+
+ def limit_clause(self, select, **kw):
+ text = ""
+ if select._limit_clause is not None:
+ text += "\n LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
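+ # SQLite requires a LIMIT when an OFFSET is present; LIMIT -1 means "no limit"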
+ text += "\n LIMIT " + self.process(sql.literal(-1))
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ else:
+ text += " OFFSET " + self.process(sql.literal(0), **kw)
+ return text
+
+ def for_update_clause(self, select, **kw):
+ # sqlite has no "FOR UPDATE" AFAICT
+ return ""
+
+ def visit_is_distinct_from_binary(self, binary, operator, **kw):
+ return "%s IS NOT %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_is_not_distinct_from_binary(self, binary, operator, **kw):
+ return "%s IS %s" % (
+ self.process(binary.left),
+ self.process(binary.right),
+ )
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ if binary.type._type_affinity is sqltypes.JSON:
+ expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
+ else:
+ expr = "JSON_EXTRACT(%s, %s)"
+
+ return expr % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ if binary.type._type_affinity is sqltypes.JSON:
+ expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))"
+ else:
+ expr = "JSON_EXTRACT(%s, %s)"
+
+ return expr % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_empty_set_op_expr(self, type_, expand_op):
+ # slightly old SQLite versions don't seem to be able to handle
+ # the empty set impl
+ return self.visit_empty_set_expr(type_)
+
+ def visit_empty_set_expr(self, element_types):
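+ # render a SELECT with the correct number of columns that returns no
+ # rows, e.g. "SELECT 1 FROM (SELECT 1) WHERE 1!=1"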
+ return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % (
+ ", ".join("1" for type_ in element_types or [INTEGER()]),
+ ", ".join("1" for type_ in element_types or [INTEGER()]),
+ )
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " REGEXP ", **kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " NOT REGEXP ", **kw)
+
+ def _on_conflict_target(self, clause, **kw):
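+ # the conflict target is either an explicitly named constraint, or an
+ # inferred list of index columns with an optional partial-index WHERE clause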
+ if clause.constraint_target is not None:
+ target_text = "(%s)" % clause.constraint_target
+ elif clause.inferred_target_elements is not None:
+ target_text = "(%s)" % ", ".join(
+ (
+ self.preparer.quote(c)
+ if isinstance(c, util.string_types)
+ else self.process(c, include_table=False, use_schema=False)
+ )
+ for c in clause.inferred_target_elements
+ )
+ if clause.inferred_target_whereclause is not None:
+ target_text += " WHERE %s" % self.process(
+ clause.inferred_target_whereclause,
+ include_table=False,
+ use_schema=False,
+ literal_binds=True,
+ )
+
+ else:
+ target_text = ""
+
+ return target_text
+
+ def visit_on_conflict_do_nothing(self, on_conflict, **kw):
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ if target_text:
+ return "ON CONFLICT %s DO NOTHING" % target_text
+ else:
+ return "ON CONFLICT DO NOTHING"
+
+ def visit_on_conflict_do_update(self, on_conflict, **kw):
+ clause = on_conflict
+
+ target_text = self._on_conflict_target(on_conflict, **kw)
+
+ action_set_ops = []
+
+ set_parameters = dict(clause.update_values_to_set)
+ # create a list of column assignment clauses as tuples
+
+ insert_statement = self.stack[-1]["selectable"]
+ cols = insert_statement.table.c
+ for c in cols:
+ col_key = c.key
+
+ if col_key in set_parameters:
+ value = set_parameters.pop(col_key)
+ elif c in set_parameters:
+ value = set_parameters.pop(c)
+ else:
+ continue
+
+ if coercions._is_literal(value):
+ value = elements.BindParameter(None, value, type_=c.type)
+
+ else:
+ if (
+ isinstance(value, elements.BindParameter)
+ and value.type._isnull
+ ):
+ value = value._clone()
+ value.type = c.type
+ value_text = self.process(value.self_group(), use_schema=False)
+
+ key_text = self.preparer.quote(c.name)
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ # check for names that don't match columns
+ if set_parameters:
+ util.warn(
+ "Additional column names not matching "
+ "any column keys in table '%s': %s"
+ % (
+ self.current_executable.table.name,
+ (", ".join("'%s'" % c for c in set_parameters)),
+ )
+ )
+ for k, v in set_parameters.items():
+ key_text = (
+ self.preparer.quote(k)
+ if isinstance(k, util.string_types)
+ else self.process(k, use_schema=False)
+ )
+ value_text = self.process(
+ coercions.expect(roles.ExpressionElementRole, v),
+ use_schema=False,
+ )
+ action_set_ops.append("%s = %s" % (key_text, value_text))
+
+ action_text = ", ".join(action_set_ops)
+ if clause.update_whereclause is not None:
+ action_text += " WHERE %s" % self.process(
+ clause.update_whereclause, include_table=True, use_schema=False
+ )
+
+ return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text)
+
+
+class SQLiteDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+
+ coltype = self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ colspec = self.preparer.format_column(column) + " " + coltype
+ default = self.get_column_default_string(column)
+ if default is not None:
+ if isinstance(column.server_default.arg, ColumnElement):
+ default = "(" + default + ")"
+ colspec += " DEFAULT " + default
+
+ if not column.nullable:
+ colspec += " NOT NULL"
+
+ on_conflict_clause = column.dialect_options["sqlite"][
+ "on_conflict_not_null"
+ ]
+ if on_conflict_clause is not None:
+ colspec += " ON CONFLICT " + on_conflict_clause
+
+ if column.primary_key:
+ if (
+ column.autoincrement is True
+ and len(column.table.primary_key.columns) != 1
+ ):
+ raise exc.CompileError(
+ "SQLite does not support autoincrement for "
+ "composite primary keys"
+ )
+
+ if (
+ column.table.dialect_options["sqlite"]["autoincrement"]
+ and len(column.table.primary_key.columns) == 1
+ and issubclass(column.type._type_affinity, sqltypes.Integer)
+ and not column.foreign_keys
+ ):
+ colspec += " PRIMARY KEY"
+
+ on_conflict_clause = column.dialect_options["sqlite"][
+ "on_conflict_primary_key"
+ ]
+ if on_conflict_clause is not None:
+ colspec += " ON CONFLICT " + on_conflict_clause
+
+ colspec += " AUTOINCREMENT"
+
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+
+ return colspec
+
+ def visit_primary_key_constraint(self, constraint):
+ # for columns with sqlite_autoincrement=True,
+ # the PRIMARY KEY constraint can only be inline
+ # with the column itself.
+ if len(constraint.columns) == 1:
+ c = list(constraint)[0]
+ if (
+ c.primary_key
+ and c.table.dialect_options["sqlite"]["autoincrement"]
+ and issubclass(c.type._type_affinity, sqltypes.Integer)
+ and not c.foreign_keys
+ ):
+ return None
+
+ text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint(
+ constraint
+ )
+
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
+ if on_conflict_clause is None and len(constraint.columns) == 1:
+ on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][
+ "on_conflict_primary_key"
+ ]
+
+ if on_conflict_clause is not None:
+ text += " ON CONFLICT " + on_conflict_clause
+
+ return text
+
+ def visit_unique_constraint(self, constraint):
+ text = super(SQLiteDDLCompiler, self).visit_unique_constraint(
+ constraint
+ )
+
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
+ if on_conflict_clause is None and len(constraint.columns) == 1:
+ col1 = list(constraint)[0]
+ if isinstance(col1, schema.SchemaItem):
+ on_conflict_clause = list(constraint)[0].dialect_options[
+ "sqlite"
+ ]["on_conflict_unique"]
+
+ if on_conflict_clause is not None:
+ text += " ON CONFLICT " + on_conflict_clause
+
+ return text
+
+ def visit_check_constraint(self, constraint):
+ text = super(SQLiteDDLCompiler, self).visit_check_constraint(
+ constraint
+ )
+
+ on_conflict_clause = constraint.dialect_options["sqlite"][
+ "on_conflict"
+ ]
+
+ if on_conflict_clause is not None:
+ text += " ON CONFLICT " + on_conflict_clause
+
+ return text
+
+ def visit_column_check_constraint(self, constraint):
+ text = super(SQLiteDDLCompiler, self).visit_column_check_constraint(
+ constraint
+ )
+
+ if constraint.dialect_options["sqlite"]["on_conflict"] is not None:
+ raise exc.CompileError(
+ "SQLite does not support on conflict clause for "
+ "column check constraint"
+ )
+
+ return text
+
+ def visit_foreign_key_constraint(self, constraint):
+
+ local_table = constraint.elements[0].parent.table
+ remote_table = constraint.elements[0].column.table
+
+ if local_table.schema != remote_table.schema:
+ return None
+ else:
+ return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint(
+ constraint
+ )
+
+ def define_constraint_remote_table(self, constraint, table, preparer):
+ """Format the remote table clause of a CREATE CONSTRAINT clause."""
+
+ return preparer.format_table(table, use_schema=False)
+
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True
+ ):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+
+ text += "INDEX "
+
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += "%s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=True),
+ preparer.format_table(index.table, use_schema=False),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+
+ whereclause = index.dialect_options["sqlite"]["where"]
+ if whereclause is not None:
+ where_compiled = self.sql_compiler.process(
+ whereclause, include_table=False, literal_binds=True
+ )
+ text += " WHERE " + where_compiled
+
+ return text
+
+ def post_create_table(self, table):
+ if table.dialect_options["sqlite"]["with_rowid"] is False:
+ return "\n WITHOUT ROWID"
+ return ""
+
+
+class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_)
+
+ def visit_DATETIME(self, type_, **kw):
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
+ return super(SQLiteTypeCompiler, self).visit_DATETIME(type_)
+ else:
+ return "DATETIME_CHAR"
+
+ def visit_DATE(self, type_, **kw):
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
+ return super(SQLiteTypeCompiler, self).visit_DATE(type_)
+ else:
+ return "DATE_CHAR"
+
+ def visit_TIME(self, type_, **kw):
+ if (
+ not isinstance(type_, _DateTimeMixin)
+ or type_.format_is_text_affinity
+ ):
+ return super(SQLiteTypeCompiler, self).visit_TIME(type_)
+ else:
+ return "TIME_CHAR"
+
+ def visit_JSON(self, type_, **kw):
+ # note this name provides NUMERIC affinity, not TEXT.
+ # should not be an issue unless the JSON value consists of a single
+ # numeric value. JSONTEXT can be used if this case is required.
+ return "JSON"
+
+
+class SQLiteIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = set(
+ [
+ "add",
+ "after",
+ "all",
+ "alter",
+ "analyze",
+ "and",
+ "as",
+ "asc",
+ "attach",
+ "autoincrement",
+ "before",
+ "begin",
+ "between",
+ "by",
+ "cascade",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "commit",
+ "conflict",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_time",
+ "current_timestamp",
+ "database",
+ "default",
+ "deferrable",
+ "deferred",
+ "delete",
+ "desc",
+ "detach",
+ "distinct",
+ "drop",
+ "each",
+ "else",
+ "end",
+ "escape",
+ "except",
+ "exclusive",
+ "exists",
+ "explain",
+ "false",
+ "fail",
+ "for",
+ "foreign",
+ "from",
+ "full",
+ "glob",
+ "group",
+ "having",
+ "if",
+ "ignore",
+ "immediate",
+ "in",
+ "index",
+ "indexed",
+ "initially",
+ "inner",
+ "insert",
+ "instead",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "key",
+ "left",
+ "like",
+ "limit",
+ "match",
+ "natural",
+ "not",
+ "notnull",
+ "null",
+ "of",
+ "offset",
+ "on",
+ "or",
+ "order",
+ "outer",
+ "plan",
+ "pragma",
+ "primary",
+ "query",
+ "raise",
+ "references",
+ "reindex",
+ "rename",
+ "replace",
+ "restrict",
+ "right",
+ "rollback",
+ "row",
+ "select",
+ "set",
+ "table",
+ "temp",
+ "temporary",
+ "then",
+ "to",
+ "transaction",
+ "trigger",
+ "true",
+ "union",
+ "unique",
+ "update",
+ "using",
+ "vacuum",
+ "values",
+ "view",
+ "virtual",
+ "when",
+ "where",
+ ]
+ )
+
+
+class SQLiteExecutionContext(default.DefaultExecutionContext):
+ @util.memoized_property
+ def _preserve_raw_colnames(self):
+ return (
+ not self.dialect._broken_dotted_colnames
+ or self.execution_options.get("sqlite_raw_colnames", False)
+ )
+
+ def _translate_colname(self, colname):
+ # TODO: detect SQLite version 3.10.0 or greater;
+ # see [ticket:3633]
+
+ # adjust for dotted column names. SQLite
+ # in the case of UNION may store col names as
+ # "tablename.colname", or if using an attached database,
+ # "database.tablename.colname", in cursor.description
+ if not self._preserve_raw_colnames and "." in colname:
+ return colname.split(".")[-1], colname
+ else:
+ return colname, None
+
+
+class SQLiteDialect(default.DefaultDialect):
+ name = "sqlite"
+ supports_alter = False
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+
+ # SQLite supports "DEFAULT VALUES" but *does not* support
+ # "VALUES (DEFAULT)"
+ supports_default_values = True
+ supports_default_metavalue = False
+
+ supports_empty_insert = False
+ supports_cast = True
+ supports_multivalues_insert = True
+ tuple_in_values = True
+ supports_statement_cache = True
+
+ default_paramstyle = "qmark"
+ execution_ctx_cls = SQLiteExecutionContext
+ statement_compiler = SQLiteCompiler
+ ddl_compiler = SQLiteDDLCompiler
+ type_compiler = SQLiteTypeCompiler
+ preparer = SQLiteIdentifierPreparer
+ ischema_names = ischema_names
+ colspecs = colspecs
+ isolation_level = None
+
+ construct_arguments = [
+ (
+ sa_schema.Table,
+ {
+ "autoincrement": False,
+ "with_rowid": True,
+ },
+ ),
+ (sa_schema.Index, {"where": None}),
+ (
+ sa_schema.Column,
+ {
+ "on_conflict_primary_key": None,
+ "on_conflict_not_null": None,
+ "on_conflict_unique": None,
+ },
+ ),
+ (sa_schema.Constraint, {"on_conflict": None}),
+ ]
+
+ _broken_fk_pragma_quotes = False
+ _broken_dotted_colnames = False
+
+ @util.deprecated_params(
+ _json_serializer=(
+ "1.3.7",
+ "The _json_serializer argument to the SQLite dialect has "
+ "been renamed to the correct name of json_serializer. The old "
+ "argument name will be removed in a future release.",
+ ),
+ _json_deserializer=(
+ "1.3.7",
+ "The _json_deserializer argument to the SQLite dialect has "
+ "been renamed to the correct name of json_deserializer. The old "
+ "argument name will be removed in a future release.",
+ ),
+ )
+ def __init__(
+ self,
+ isolation_level=None,
+ native_datetime=False,
+ json_serializer=None,
+ json_deserializer=None,
+ _json_serializer=None,
+ _json_deserializer=None,
+ **kwargs
+ ):
+ default.DefaultDialect.__init__(self, **kwargs)
+ self.isolation_level = isolation_level
+
+ if _json_serializer:
+ json_serializer = _json_serializer
+ if _json_deserializer:
+ json_deserializer = _json_deserializer
+ self._json_serializer = json_serializer
+ self._json_deserializer = json_deserializer
+
+ # this flag used by pysqlite dialect, and perhaps others in the
+ # future, to indicate the driver is handling date/timestamp
+ # conversions (and perhaps datetime/time as well on some hypothetical
+ # driver ?)
+ self.native_datetime = native_datetime
+
+ if self.dbapi is not None:
+ if self.dbapi.sqlite_version_info < (3, 7, 16):
+ util.warn(
+ "SQLite version %s is older than 3.7.16, and will not "
+ "support right nested joins, as are sometimes used in "
+ "more complex ORM scenarios. SQLAlchemy 1.4 and above "
+ "no longer tries to rewrite these joins."
+ % (self.dbapi.sqlite_version_info,)
+ )
+
+ self._broken_dotted_colnames = self.dbapi.sqlite_version_info < (
+ 3,
+ 10,
+ 0,
+ )
+ self.supports_default_values = self.dbapi.sqlite_version_info >= (
+ 3,
+ 3,
+ 8,
+ )
+ self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3)
+ self.supports_multivalues_insert = (
+ # https://www.sqlite.org/releaselog/3_7_11.html
+ self.dbapi.sqlite_version_info
+ >= (3, 7, 11)
+ )
+ # see https://www.sqlalchemy.org/trac/ticket/2568
+ # as well as https://www.sqlite.org/src/info/600482d161
+ self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < (
+ 3,
+ 6,
+ 14,
+ )
+
+ _isolation_lookup = util.immutabledict(
+ {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0}
+ )
+
+ def set_isolation_level(self, connection, level):
+ try:
+ isolation_level = self._isolation_lookup[level.replace("_", " ")]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Invalid value '%s' for isolation_level. "
+ "Valid isolation levels for %s are %s"
+ % (
+ level,
+ self.name,
+ ", ".join(self._isolation_lookup),
+ )
+ ),
+ replace_context=err,
+ )
+ cursor = connection.cursor()
+ cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level)
+ cursor.close()
+
+ def get_isolation_level(self, connection):
+ cursor = connection.cursor()
+ cursor.execute("PRAGMA read_uncommitted")
+ res = cursor.fetchone()
+ if res:
+ value = res[0]
+ else:
+ # https://www.sqlite.org/changes.html#version_3_3_3
+ # "Optional READ UNCOMMITTED isolation (instead of the
+ # default isolation level of SERIALIZABLE) and
+ # table level locking when database connections
+ # share a common cache.""
+ # pre-SQLite 3.3.0 default to 0
+ value = 0
+ cursor.close()
+ if value == 0:
+ return "SERIALIZABLE"
+ elif value == 1:
+ return "READ UNCOMMITTED"
+ else:
+ assert False, "Unknown isolation level %s" % value
+
+ def on_connect(self):
+ if self.isolation_level is not None:
+
+ def connect(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ return connect
+ else:
+ return None
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+ s = "PRAGMA database_list"
+ dl = connection.exec_driver_sql(s)
+
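+ # PRAGMA database_list rows are (seq, name, file); skip the built-in "temp" schema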
+ return [db[1] for db in dl if db[1] != "temp"]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = "%s.sqlite_master" % qschema
+ else:
+ master = "sqlite_master"
+ s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % (
+ master,
+ )
+ rs = connection.exec_driver_sql(s)
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ s = (
+ "SELECT name FROM sqlite_temp_master "
+ "WHERE type='table' ORDER BY name "
+ )
+ rs = connection.exec_driver_sql(s)
+
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_temp_view_names(self, connection, **kw):
+ s = (
+ "SELECT name FROM sqlite_temp_master "
+ "WHERE type='view' ORDER BY name "
+ )
+ rs = connection.exec_driver_sql(s)
+
+ return [row[0] for row in rs]
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ info = self._get_table_pragma(
+ connection, "table_info", table_name, schema=schema
+ )
+ return bool(info)
+
+ def _get_default_schema_name(self, connection):
+ return "main"
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = "%s.sqlite_master" % qschema
+ else:
+ master = "sqlite_master"
+ s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % (
+ master,
+ )
+ rs = connection.exec_driver_sql(s)
+
+ return [row[0] for row in rs]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is not None:
+ qschema = self.identifier_preparer.quote_identifier(schema)
+ master = "%s.sqlite_master" % qschema
+ s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % (
+ master,
+ )
+ rs = connection.exec_driver_sql(s, (view_name,))
+ else:
+ try:
+ s = (
+ "SELECT sql FROM "
+ " (SELECT * FROM sqlite_master UNION ALL "
+ " SELECT * FROM sqlite_temp_master) "
+ "WHERE name = ? "
+ "AND type='view'"
+ )
+ rs = connection.exec_driver_sql(s, (view_name,))
+ except exc.DBAPIError:
+ s = (
+ "SELECT sql FROM sqlite_master WHERE name = ? "
+ "AND type='view'"
+ )
+ rs = connection.exec_driver_sql(s, (view_name,))
+
+ result = rs.fetchall()
+ if result:
+ return result[0].sql
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ pragma = "table_info"
+        # computed columns are treated as hidden; they require table_xinfo
+ if self.server_version_info >= (3, 31):
+ pragma = "table_xinfo"
+ info = self._get_table_pragma(
+ connection, pragma, table_name, schema=schema
+ )
+ columns = []
+ tablesql = None
+ for row in info:
+ name = row[1]
+ type_ = row[2].upper()
+ nullable = not row[3]
+ default = row[4]
+ primary_key = row[5]
+ hidden = row[6] if pragma == "table_xinfo" else 0
+
+ # hidden has value 0 for normal columns, 1 for hidden columns,
+ # 2 for computed virtual columns and 3 for computed stored columns
+ # https://www.sqlite.org/src/info/069351b85f9a706f60d3e98fbc8aaf40c374356b967c0464aede30ead3d9d18b
+ if hidden == 1:
+ continue
+
+ generated = bool(hidden)
+ persisted = hidden == 3
+
+ if tablesql is None and generated:
+ tablesql = self._get_table_sql(
+ connection, table_name, schema, **kw
+ )
+
+ columns.append(
+ self._get_column_info(
+ name,
+ type_,
+ nullable,
+ default,
+ primary_key,
+ generated,
+ persisted,
+ tablesql,
+ )
+ )
+ return columns
+
+ def _get_column_info(
+ self,
+ name,
+ type_,
+ nullable,
+ default,
+ primary_key,
+ generated,
+ persisted,
+ tablesql,
+ ):
+
+ if generated:
+ # the type of a column "cc INTEGER GENERATED ALWAYS AS (1 + 42)"
+ # somehow is "INTEGER GENERATED ALWAYS"
+ type_ = re.sub("generated", "", type_, flags=re.IGNORECASE)
+ type_ = re.sub("always", "", type_, flags=re.IGNORECASE).strip()
+
+ coltype = self._resolve_type_affinity(type_)
+
+ if default is not None:
+ default = util.text_type(default)
+
+ colspec = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": "auto",
+ "primary_key": primary_key,
+ }
+ if generated:
+ sqltext = ""
+ if tablesql:
+ pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?"
+ match = re.search(
+ re.escape(name) + pattern, tablesql, re.IGNORECASE
+ )
+ if match:
+ sqltext = match.group(1)
+ colspec["computed"] = {"sqltext": sqltext, "persisted": persisted}
+ return colspec
+
+ def _resolve_type_affinity(self, type_):
+ """Return a data type from a reflected column, using affinity rules.
+
+ SQLite's goal for universal compatibility introduces some complexity
+ during reflection, as a column's defined type might not actually be a
+        type that SQLite understands - or indeed, may not be defined *at all*.
+ Internally, SQLite handles this with a 'data type affinity' for each
+ column definition, mapping to one of 'TEXT', 'NUMERIC', 'INTEGER',
+ 'REAL', or 'NONE' (raw bits). The algorithm that determines this is
+ listed in https://www.sqlite.org/datatype3.html section 2.1.
+
+ This method allows SQLAlchemy to support that algorithm, while still
+ providing access to smarter reflection utilities by recognizing
+ column definitions that SQLite only supports through affinity (like
+ DATE and DOUBLE).
+
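+        For example (an illustrative sketch of the rules above)::
+
+            self._resolve_type_affinity("VARCHAR(30)")
+            # -> VARCHAR(length=30), matched via ischema_names
+
+            self._resolve_type_affinity("FLOATING POINT")
+            # -> INTEGER, since the declared type contains "INT"
+
+            self._resolve_type_affinity("")
+            # -> NullType, for a column with no declared type
+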
+ """
+ match = re.match(r"([\w ]+)(\(.*?\))?", type_)
+ if match:
+ coltype = match.group(1)
+ args = match.group(2)
+ else:
+ coltype = ""
+ args = ""
+
+ if coltype in self.ischema_names:
+ coltype = self.ischema_names[coltype]
+ elif "INT" in coltype:
+ coltype = sqltypes.INTEGER
+ elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype:
+ coltype = sqltypes.TEXT
+ elif "BLOB" in coltype or not coltype:
+ coltype = sqltypes.NullType
+ elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype:
+ coltype = sqltypes.REAL
+ else:
+ coltype = sqltypes.NUMERIC
+
+ if args is not None:
+ args = re.findall(r"(\d+)", args)
+ try:
+ coltype = coltype(*[int(a) for a in args])
+ except TypeError:
+ util.warn(
+ "Could not instantiate type %s with "
+ "reflected arguments %s; using no arguments."
+ % (coltype, args)
+ )
+ coltype = coltype()
+ else:
+ coltype = coltype()
+
+ return coltype
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ constraint_name = None
+ table_data = self._get_table_sql(connection, table_name, schema=schema)
+ if table_data:
+ PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY"
+ result = re.search(PK_PATTERN, table_data, re.I)
+ constraint_name = result.group(1) if result else None
+
+ cols = self.get_columns(connection, table_name, schema, **kw)
+ cols.sort(key=lambda col: col.get("primary_key"))
+ pkeys = []
+ for col in cols:
+ if col["primary_key"]:
+ pkeys.append(col["name"])
+
+ return {"constrained_columns": pkeys, "name": constraint_name}
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # sqlite makes this *extremely difficult*.
+ # First, use the pragma to get the actual FKs.
+ pragma_fks = self._get_table_pragma(
+ connection, "foreign_key_list", table_name, schema=schema
+ )
+
+ fks = {}
+
+ for row in pragma_fks:
+ (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4])
+
+ if not rcol:
+ # no referred column, which means it was not named in the
+ # original DDL. The referred columns of the foreign key
+ # constraint are therefore the primary key of the referred
+ # table.
+ referred_pk = self.get_pk_constraint(
+ connection, rtbl, schema=schema, **kw
+ )
+                # note that if the table doesn't exist, we still get back a
+                # record, it just has no columns in it
+ referred_columns = referred_pk["constrained_columns"]
+ else:
+ # note we use this list only if this is the first column
+ # in the constraint. for subsequent columns we ignore the
+ # list and append "rcol" if present.
+ referred_columns = []
+
+ if self._broken_fk_pragma_quotes:
+ rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl)
+
+ if numerical_id in fks:
+ fk = fks[numerical_id]
+ else:
+ fk = fks[numerical_id] = {
+ "name": None,
+ "constrained_columns": [],
+ "referred_schema": schema,
+ "referred_table": rtbl,
+ "referred_columns": referred_columns,
+ "options": {},
+ }
+ fks[numerical_id] = fk
+
+ fk["constrained_columns"].append(lcol)
+
+ if rcol:
+ fk["referred_columns"].append(rcol)
+
+ def fk_sig(constrained_columns, referred_table, referred_columns):
+ return (
+ tuple(constrained_columns)
+ + (referred_table,)
+ + tuple(referred_columns)
+ )
+
+ # then, parse the actual SQL and attempt to find DDL that matches
+ # the names as well. SQLite saves the DDL in whatever format
+        # it was typed in as, so we need to be liberal here.
+
+ keys_by_signature = dict(
+ (
+ fk_sig(
+ fk["constrained_columns"],
+ fk["referred_table"],
+ fk["referred_columns"],
+ ),
+ fk,
+ )
+ for fk in fks.values()
+ )
+
+ table_data = self._get_table_sql(connection, table_name, schema=schema)
+ if table_data is None:
+ # system tables, etc.
+ return []
+
+ def parse_fks():
+ FK_PATTERN = (
+ r"(?:CONSTRAINT (\w+) +)?"
+ r"FOREIGN KEY *\( *(.+?) *\) +"
+ r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *'
+ r"((?:ON (?:DELETE|UPDATE) "
+ r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)"
+ )
+ for match in re.finditer(FK_PATTERN, table_data, re.I):
+ (
+ constraint_name,
+ constrained_columns,
+ referred_quoted_name,
+ referred_name,
+ referred_columns,
+ onupdatedelete,
+ ) = match.group(1, 2, 3, 4, 5, 6)
+ constrained_columns = list(
+ self._find_cols_in_sig(constrained_columns)
+ )
+ if not referred_columns:
+ referred_columns = constrained_columns
+ else:
+ referred_columns = list(
+ self._find_cols_in_sig(referred_columns)
+ )
+ referred_name = referred_quoted_name or referred_name
+ options = {}
+
+ for token in re.split(r" *\bON\b *", onupdatedelete.upper()):
+ if token.startswith("DELETE"):
+ ondelete = token[6:].strip()
+ if ondelete and ondelete != "NO ACTION":
+ options["ondelete"] = ondelete
+ elif token.startswith("UPDATE"):
+ onupdate = token[6:].strip()
+ if onupdate and onupdate != "NO ACTION":
+ options["onupdate"] = onupdate
+ yield (
+ constraint_name,
+ constrained_columns,
+ referred_name,
+ referred_columns,
+ options,
+ )
+
+ fkeys = []
+
+ for (
+ constraint_name,
+ constrained_columns,
+ referred_name,
+ referred_columns,
+ options,
+ ) in parse_fks():
+ sig = fk_sig(constrained_columns, referred_name, referred_columns)
+ if sig not in keys_by_signature:
+ util.warn(
+ "WARNING: SQL-parsed foreign key constraint "
+ "'%s' could not be located in PRAGMA "
+ "foreign_keys for table %s" % (sig, table_name)
+ )
+ continue
+ key = keys_by_signature.pop(sig)
+ key["name"] = constraint_name
+ key["options"] = options
+ fkeys.append(key)
+ # assume the remainders are the unnamed, inline constraints, just
+ # use them as is as it's extremely difficult to parse inline
+ # constraints
+ fkeys.extend(keys_by_signature.values())
+ return fkeys
+
+ def _find_cols_in_sig(self, sig):
+ for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I):
+ yield match.group(1) or match.group(2)
+
+ @reflection.cache
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+
+ auto_index_by_sig = {}
+ for idx in self.get_indexes(
+ connection,
+ table_name,
+ schema=schema,
+ include_auto_indexes=True,
+ **kw
+ ):
+ if not idx["name"].startswith("sqlite_autoindex"):
+ continue
+ sig = tuple(idx["column_names"])
+ auto_index_by_sig[sig] = idx
+
+ table_data = self._get_table_sql(
+ connection, table_name, schema=schema, **kw
+ )
+ if not table_data:
+ return []
+
+ unique_constraints = []
+
+ def parse_uqs():
+ UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
+ INLINE_UNIQUE_PATTERN = (
+ r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) '
+ r"+[a-z0-9_ ]+? +UNIQUE"
+ )
+
+ for match in re.finditer(UNIQUE_PATTERN, table_data, re.I):
+ name, cols = match.group(1, 2)
+ yield name, list(self._find_cols_in_sig(cols))
+
+ # we need to match inlines as well, as we seek to differentiate
+ # a UNIQUE constraint from a UNIQUE INDEX, even though these
+ # are kind of the same thing :)
+ for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I):
+ cols = list(
+ self._find_cols_in_sig(match.group(1) or match.group(2))
+ )
+ yield None, cols
+
+ for name, cols in parse_uqs():
+ sig = tuple(cols)
+ if sig in auto_index_by_sig:
+ auto_index_by_sig.pop(sig)
+ parsed_constraint = {"name": name, "column_names": cols}
+ unique_constraints.append(parsed_constraint)
+ # NOTE: auto_index_by_sig might not be empty here,
+ # the PRIMARY KEY may have an entry.
+ return unique_constraints
+
+ @reflection.cache
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ table_data = self._get_table_sql(
+ connection, table_name, schema=schema, **kw
+ )
+ if not table_data:
+ return []
+
+ CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *"
+ check_constraints = []
+ # NOTE: we aren't using re.S here because we actually are
+ # taking advantage of each CHECK constraint being all on one
+ # line in the table definition in order to delineate. This
+ # necessarily makes assumptions as to how the CREATE TABLE
+ # was emitted.
+
+ for match in re.finditer(CHECK_PATTERN, table_data, re.I):
+ name = match.group(1)
+
+ if name:
+ name = re.sub(r'^"|"$', "", name)
+
+ check_constraints.append({"sqltext": match.group(2), "name": name})
+
+ return check_constraints
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ pragma_indexes = self._get_table_pragma(
+ connection, "index_list", table_name, schema=schema
+ )
+ indexes = []
+
+ include_auto_indexes = kw.pop("include_auto_indexes", False)
+ for row in pragma_indexes:
+ # ignore implicit primary key index.
+ # https://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html
+ if not include_auto_indexes and row[1].startswith(
+ "sqlite_autoindex"
+ ):
+ continue
+ indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
+
+        # loop through the indexes to get the column names.
+ for idx in list(indexes):
+ pragma_index = self._get_table_pragma(
+ connection, "index_info", idx["name"]
+ )
+
+ for row in pragma_index:
+ if row[2] is None:
+ util.warn(
+ "Skipped unsupported reflection of "
+ "expression-based index %s" % idx["name"]
+ )
+ indexes.remove(idx)
+ break
+ else:
+ idx["column_names"].append(row[2])
+ return indexes
+
+ @reflection.cache
+ def _get_table_sql(self, connection, table_name, schema=None, **kw):
+ if schema:
+ schema_expr = "%s." % (
+ self.identifier_preparer.quote_identifier(schema)
+ )
+ else:
+ schema_expr = ""
+ try:
+ s = (
+ "SELECT sql FROM "
+ " (SELECT * FROM %(schema)ssqlite_master UNION ALL "
+ " SELECT * FROM %(schema)ssqlite_temp_master) "
+ "WHERE name = ? "
+ "AND type = 'table'" % {"schema": schema_expr}
+ )
+ rs = connection.exec_driver_sql(s, (table_name,))
+ except exc.DBAPIError:
+ s = (
+ "SELECT sql FROM %(schema)ssqlite_master "
+ "WHERE name = ? "
+ "AND type = 'table'" % {"schema": schema_expr}
+ )
+ rs = connection.exec_driver_sql(s, (table_name,))
+ return rs.scalar()
+
+ def _get_table_pragma(self, connection, pragma, table_name, schema=None):
+ quote = self.identifier_preparer.quote_identifier
+ if schema is not None:
+ statements = ["PRAGMA %s." % quote(schema)]
+ else:
+            # because PRAGMA looks in all attached databases if no schema is
+            # given, we need to specify the "main" schema; however, since we
+            # want 'temp' tables in the same namespace as 'main', we need to
+            # run the PRAGMA twice
+ statements = ["PRAGMA main.", "PRAGMA temp."]
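+            # e.g. for pragma="table_info" and table "foo" (illustrative
+            # values), the loop below emits PRAGMA main.table_info("foo")
+            # and, only if that returns no rows, PRAGMA temp.table_info("foo")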
+
+ qtable = quote(table_name)
+ for statement in statements:
+ statement = "%s%s(%s)" % (statement, pragma, qtable)
+ cursor = connection.exec_driver_sql(statement)
+ if not cursor._soft_closed:
+ # work around SQLite issue whereby cursor.description
+ # is blank when PRAGMA returns no rows:
+ # https://www.sqlite.org/cvstrac/tktview?tn=1884
+ result = cursor.fetchall()
+ else:
+ result = []
+ if result:
+ return result
+ else:
+ return []
diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py
new file mode 100644
index 0000000..b04a5e6
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/dml.py
@@ -0,0 +1,200 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import util
+from ...sql import coercions
+from ...sql import roles
+from ...sql.base import _exclusive_against
+from ...sql.base import _generative
+from ...sql.base import ColumnCollection
+from ...sql.dml import Insert as StandardInsert
+from ...sql.elements import ClauseElement
+from ...sql.expression import alias
+from ...util.langhelpers import public_factory
+
+
+__all__ = ("Insert", "insert")
+
+
+class Insert(StandardInsert):
+ """SQLite-specific implementation of INSERT.
+
+ Adds methods for SQLite-specific syntaxes such as ON CONFLICT.
+
+ The :class:`_sqlite.Insert` object is created using the
+ :func:`sqlalchemy.dialects.sqlite.insert` function.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`sqlite_on_conflict_insert`
+
+ """
+
+ stringify_dialect = "sqlite"
+ inherit_cache = False
+
+ @util.memoized_property
+ def excluded(self):
+ """Provide the ``excluded`` namespace for an ON CONFLICT statement
+
+ SQLite's ON CONFLICT clause allows reference to the row that would
+ be inserted, known as ``excluded``. This attribute provides
+ all columns in this row to be referenceable.
+
+ .. tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance
+ of :class:`_expression.ColumnCollection`, which provides an
+ interface the same as that of the :attr:`_schema.Table.c`
+ collection described at :ref:`metadata_tables_and_columns`.
+ With this collection, ordinary names are accessible like attributes
+ (e.g. ``stmt.excluded.some_column``), but special names and
+ dictionary method names should be accessed using indexed access,
+ such as ``stmt.excluded["column name"]`` or
+ ``stmt.excluded["values"]``. See the docstring for
+ :class:`_expression.ColumnCollection` for further examples.
+
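+        E.g., within an ON CONFLICT DO UPDATE (an illustrative sketch;
+        ``my_table`` is a hypothetical table, not defined in this module)::
+
+            from sqlalchemy.dialects.sqlite import insert
+
+            stmt = insert(my_table).values(id=1, data="new value")
+            stmt = stmt.on_conflict_do_update(
+                index_elements=[my_table.c.id],
+                set_=dict(data=stmt.excluded.data),
+            )
+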
+ """
+ return alias(self.table, name="excluded").columns
+
+ _on_conflict_exclusive = _exclusive_against(
+ "_post_values_clause",
+ msgs={
+ "_post_values_clause": "This Insert construct already has "
+ "an ON CONFLICT clause established"
+ },
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_update(
+ self,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ r"""
+ Specifies a DO UPDATE SET action for ON CONFLICT clause.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index or unique constraint.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
+ :param set\_:
+ A dictionary or other mapping object
+ where the keys are either names of columns in the target table,
+ or :class:`_schema.Column` objects or other ORM-mapped columns
+ matching that of the target table, and expressions or literals
+ as values, specifying the ``SET`` actions to take.
+
+ .. versionadded:: 1.4 The
+ :paramref:`_sqlite.Insert.on_conflict_do_update.set_`
+ parameter supports :class:`_schema.Column` objects from the target
+ :class:`_schema.Table` as keys.
+
+ .. warning:: This dictionary does **not** take into account
+ Python-specified default UPDATE values or generation functions,
+ e.g. those specified using :paramref:`_schema.Column.onupdate`.
+ These values will not be exercised for an ON CONFLICT style of
+ UPDATE, unless they are manually specified in the
+ :paramref:`.Insert.on_conflict_do_update.set_` dictionary.
+
+ :param where:
+ Optional argument. If present, can be a literal SQL
+ string or an acceptable expression for a ``WHERE`` clause
+ that restricts the rows affected by ``DO UPDATE SET``. Rows
+ not meeting the ``WHERE`` condition will not be updated
+ (effectively a ``DO NOTHING`` for those rows).
+
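+        E.g. (a minimal, illustrative sketch; ``my_table`` and ``conn`` are
+        hypothetical placeholders, not part of this module)::
+
+            from sqlalchemy.dialects.sqlite import insert
+
+            stmt = insert(my_table).values(id=1, data="inserted value")
+            stmt = stmt.on_conflict_do_update(
+                index_elements=[my_table.c.id],
+                set_=dict(data="updated value"),
+            )
+            conn.execute(stmt)
+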
+ """
+
+ self._post_values_clause = OnConflictDoUpdate(
+ index_elements, index_where, set_, where
+ )
+
+ @_generative
+ @_on_conflict_exclusive
+ def on_conflict_do_nothing(self, index_elements=None, index_where=None):
+ """
+ Specifies a DO NOTHING action for ON CONFLICT clause.
+
+ :param index_elements:
+ A sequence consisting of string column names, :class:`_schema.Column`
+ objects, or other column expression objects that will be used
+ to infer a target index or unique constraint.
+
+ :param index_where:
+ Additional WHERE criterion that can be used to infer a
+ conditional target index.
+
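+        E.g. (a minimal, illustrative sketch; ``my_table`` and ``conn`` are
+        hypothetical placeholders)::
+
+            from sqlalchemy.dialects.sqlite import insert
+
+            stmt = insert(my_table).values(id=1, data="some value")
+            stmt = stmt.on_conflict_do_nothing(
+                index_elements=[my_table.c.id]
+            )
+            conn.execute(stmt)
+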
+ """
+
+ self._post_values_clause = OnConflictDoNothing(
+ index_elements, index_where
+ )
+
+
+insert = public_factory(
+ Insert, ".dialects.sqlite.insert", ".dialects.sqlite.Insert"
+)
+
+
+class OnConflictClause(ClauseElement):
+ stringify_dialect = "sqlite"
+
+ def __init__(self, index_elements=None, index_where=None):
+
+ if index_elements is not None:
+ self.constraint_target = None
+ self.inferred_target_elements = index_elements
+ self.inferred_target_whereclause = index_where
+ else:
+ self.constraint_target = (
+ self.inferred_target_elements
+ ) = self.inferred_target_whereclause = None
+
+
+class OnConflictDoNothing(OnConflictClause):
+ __visit_name__ = "on_conflict_do_nothing"
+
+
+class OnConflictDoUpdate(OnConflictClause):
+ __visit_name__ = "on_conflict_do_update"
+
+ def __init__(
+ self,
+ index_elements=None,
+ index_where=None,
+ set_=None,
+ where=None,
+ ):
+ super(OnConflictDoUpdate, self).__init__(
+ index_elements=index_elements,
+ index_where=index_where,
+ )
+
+ if isinstance(set_, dict):
+ if not set_:
+ raise ValueError("set parameter dictionary must not be empty")
+ elif isinstance(set_, ColumnCollection):
+ set_ = dict(set_)
+ else:
+ raise ValueError(
+ "set parameter must be a non-empty dictionary "
+ "or a ColumnCollection such as the `.c.` collection "
+ "of a Table object"
+ )
+ self.update_values_to_set = [
+ (coercions.expect(roles.DMLColumnRole, key), value)
+ for key, value in set_.items()
+ ]
+ self.update_whereclause = where
diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py
new file mode 100644
index 0000000..614f954
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/json.py
@@ -0,0 +1,84 @@
+from ... import types as sqltypes
+
+
+class JSON(sqltypes.JSON):
+ """SQLite JSON type.
+
+ SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note
+ that JSON1_ is a
+ `loadable extension <https://www.sqlite.org/loadext.html>`_ and as such
+ may not be available, or may require run-time loading.
+
+ :class:`_sqlite.JSON` is used automatically whenever the base
+ :class:`_types.JSON` datatype is used against a SQLite backend.
+
+ .. seealso::
+
+ :class:`_types.JSON` - main documentation for the generic
+ cross-platform JSON datatype.
+
+ The :class:`_sqlite.JSON` type supports persistence of JSON values
+ as well as the core index operations provided by :class:`_types.JSON`
+ datatype, by adapting the operations to render the ``JSON_EXTRACT``
+ function wrapped in the ``JSON_QUOTE`` function at the database level.
+ Extracted values are quoted in order to ensure that the results are
+ always JSON string values.
+
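+    For example (an illustrative sketch; ``data_table`` is a hypothetical
+    table, not part of this module)::
+
+        from sqlalchemy import JSON, Column, Integer, MetaData, Table, select
+
+        data_table = Table(
+            "data_table",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("data", JSON),
+        )
+
+        # on SQLite the indexed access renders
+        # JSON_QUOTE(JSON_EXTRACT(data_table.data, '$."key"'))
+        stmt = select(data_table.c.data["key"])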
+
+ .. versionadded:: 1.3
+
+
+ .. _JSON1: https://www.sqlite.org/json1.html
+
+ """
+
+
+# Note: these objects currently match exactly those of MySQL; however, since
+# these are not generalizable to all JSON implementations, they remain
+# separately implemented for each dialect.
+class _FormatTypeMixin(object):
+ def _format_value(self, value):
+ raise NotImplementedError()
+
+ def bind_processor(self, dialect):
+ super_proc = self.string_bind_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ super_proc = self.string_literal_processor(dialect)
+
+ def process(value):
+ value = self._format_value(value)
+ if super_proc:
+ value = super_proc(value)
+ return value
+
+ return process
+
+
+class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
+ def _format_value(self, value):
+ if isinstance(value, int):
+ value = "$[%s]" % value
+ else:
+ value = '$."%s"' % value
+ return value
+
+
+class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
+ def _format_value(self, value):
+ return "$%s" % (
+ "".join(
+ [
+ "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
+ for elem in value
+ ]
+ )
+ )
diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py
new file mode 100644
index 0000000..e5b17e8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/provision.py
@@ -0,0 +1,142 @@
+import os
+import re
+
+from ... import exc
+from ...engine import url as sa_url
+from ...testing.provision import create_db
+from ...testing.provision import drop_db
+from ...testing.provision import follower_url_from_main
+from ...testing.provision import generate_driver_url
+from ...testing.provision import log
+from ...testing.provision import post_configure_engine
+from ...testing.provision import run_reap_dbs
+from ...testing.provision import stop_test_class_outside_fixtures
+from ...testing.provision import temp_table_keyword_args
+
+
+# TODO: I can't get this to build dynamically with pytest-xdist procs
+_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"}
+
+
+@generate_driver_url.for_db("sqlite")
+def generate_driver_url(url, driver, query_str):
+ if driver == "pysqlcipher" and url.get_driver_name() != "pysqlcipher":
+ if url.database:
+ url = url.set(database=url.database + ".enc")
+ url = url.set(password="test")
+ url = url.set(drivername="sqlite+%s" % (driver,))
+ try:
+ url.get_dialect()
+ except exc.NoSuchModuleError:
+ return None
+ else:
+ return url
+
+
+@follower_url_from_main.for_db("sqlite")
+def _sqlite_follower_url_from_main(url, ident):
+ url = sa_url.make_url(url)
+
+ if not url.database or url.database == ":memory:":
+ return url
+ else:
+
+ m = re.match(r"(.+?)\.(.+)$", url.database)
+ name, ext = m.group(1, 2)
+ drivername = url.get_driver_name()
+ return sa_url.make_url(
+ "sqlite+%s:///%s_%s.%s" % (drivername, drivername, ident, ext)
+ )
+
+
+@post_configure_engine.for_db("sqlite")
+def _sqlite_post_configure_engine(url, engine, follower_ident):
+ from sqlalchemy import event
+
+ @event.listens_for(engine, "connect")
+ def connect(dbapi_connection, connection_record):
+        # use file DBs in all cases; memory acts kind of strangely
+        # as an attached database
+ if not follower_ident:
+ # note this test_schema.db gets created for all test runs.
+ # there's not any dedicated cleanup step for it. it in some
+ # ways corresponds to the "test.test_schema" schema that's
+ # expected to be already present, so for now it just stays
+ # in a given checkout directory.
+ dbapi_connection.execute(
+ 'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
+ % (engine.driver,)
+ )
+ else:
+ dbapi_connection.execute(
+ 'ATTACH DATABASE "%s_%s_test_schema.db" AS test_schema'
+ % (follower_ident, engine.driver)
+ )
+
+
+@create_db.for_db("sqlite")
+def _sqlite_create_db(cfg, eng, ident):
+ pass
+
+
+@drop_db.for_db("sqlite")
+def _sqlite_drop_db(cfg, eng, ident):
+ for path in [
+ "%s.db" % ident,
+ "%s_%s_test_schema.db" % (ident, eng.driver),
+ ]:
+ if os.path.exists(path):
+ log.info("deleting SQLite database file: %s" % path)
+ os.remove(path)
+
+
+@stop_test_class_outside_fixtures.for_db("sqlite")
+def stop_test_class_outside_fixtures(config, db, cls):
+ with db.connect() as conn:
+ files = [
+ row.file
+ for row in conn.exec_driver_sql("PRAGMA database_list")
+ if row.file
+ ]
+
+ if files:
+ db.dispose()
+ # some sqlite file tests are not cleaning up well yet, so do this
+ # just to make things simple for now
+ for file_ in files:
+ if file_ and os.path.exists(file_):
+ os.remove(file_)
+
+
+@temp_table_keyword_args.for_db("sqlite")
+def _sqlite_temp_table_keyword_args(cfg, eng):
+ return {"prefixes": ["TEMPORARY"]}
+
+
+@run_reap_dbs.for_db("sqlite")
+def _reap_sqlite_dbs(url, idents):
+ log.info("db reaper connecting to %r", url)
+
+ log.info("identifiers in file: %s", ", ".join(idents))
+ for ident in idents:
+ # we don't have a config so we can't call _sqlite_drop_db due to the
+ # decorator
+ for ext in ("db", "db.enc"):
+ for path in (
+ ["%s.%s" % (ident, ext)]
+ + [
+ "%s_%s.%s" % (drivername, ident, ext)
+ for drivername in _drivernames
+ ]
+ + [
+ "%s_test_schema.%s" % (drivername, ext)
+ for drivername in _drivernames
+ ]
+ + [
+ "%s_%s_test_schema.%s" % (ident, drivername, ext)
+ for drivername in _drivernames
+ ]
+ ):
+ if os.path.exists(path):
+ log.info("deleting SQLite database file: %s" % path)
+ os.remove(path)
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
new file mode 100644
index 0000000..65f94c8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py
@@ -0,0 +1,164 @@
+# sqlite/pysqlcipher.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: sqlite+pysqlcipher
+ :name: pysqlcipher
+ :dbapi: sqlcipher 3 or pysqlcipher
+ :connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=<iter>]
+
+ Dialect for support of DBAPIs that make use of the
+ `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.
+
+
+Driver
+------
+
+Current dialect selection logic is:
+
+* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module,
+ that module is used.
+* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/
+* If not available, fall back to https://pypi.org/project/pysqlcipher3/
+* For Python 2, https://pypi.org/project/pysqlcipher/ is used.
+
+.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no
+ longer maintained; the ``sqlcipher3`` driver as of this writing appears
+ to be current. For future compatibility, any pysqlcipher-compatible DBAPI
+ may be used as follows::
+
+ import sqlcipher_compatible_driver
+
+ from sqlalchemy import create_engine
+
+ e = create_engine(
+ "sqlite+pysqlcipher://:password@/dbname.db",
+ module=sqlcipher_compatible_driver
+ )
+
+These drivers make use of the SQLCipher engine. This system essentially
+introduces new PRAGMA commands to SQLite which allows the setting of a
+passphrase and other encryption parameters, allowing the database file to be
+encrypted.
+
+
+Connect Strings
+---------------
+
+The format of the connect string is in every way the same as that
+of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
+"password" field is now accepted, which should contain a passphrase::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
+
+For an absolute file path, two leading slashes should be used for the
+database name::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
+
+A selection of additional encryption-related pragmas supported by SQLCipher
+as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
+in the query string, and will result in that PRAGMA being called for each
+new connection. Currently, ``cipher``, ``kdf_iter``,
+``cipher_page_size`` and ``cipher_use_hmac`` are supported::
+
+ e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
+
+.. warning:: Previous versions of SQLAlchemy did not take into consideration
+   the encryption-related pragmas passed in the url string; these were silently
+   ignored. This may cause errors when opening files saved by a previous
+   SQLAlchemy version if the encryption options do not match.
+
+
+Pooling Behavior
+----------------
+
+The driver makes a change to the default pool behavior of pysqlite
+as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver
+has been observed to be significantly slower on connection than the
+pysqlite driver, most likely due to the encryption overhead, so the
+dialect here defaults to using the :class:`.SingletonThreadPool`
+implementation,
+instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool
+implementation is entirely configurable using the
+:paramref:`_sa.create_engine.poolclass` parameter; the :class:`.StaticPool`
+may be more feasible for single-threaded use, or :class:`.NullPool` may be used
+to prevent unencrypted connections from being held open for long periods of
+time, at the expense of slower startup time for new connections.
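+
+For example (an illustrative override; the file name is a placeholder)::
+
+    from sqlalchemy import create_engine
+    from sqlalchemy.pool import NullPool
+
+    e = create_engine(
+        "sqlite+pysqlcipher://:testing@/foo.db",
+        poolclass=NullPool,
+    )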
+
+
+""" # noqa
+
+from __future__ import absolute_import
+
+from .pysqlite import SQLiteDialect_pysqlite
+from ... import pool
+from ... import util
+
+
+class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
+ driver = "pysqlcipher"
+ supports_statement_cache = True
+
+ pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac")
+
+ @classmethod
+ def dbapi(cls):
+ if util.py3k:
+ try:
+ import sqlcipher3 as sqlcipher
+ except ImportError:
+ pass
+ else:
+ return sqlcipher
+
+ from pysqlcipher3 import dbapi2 as sqlcipher
+
+ else:
+ from pysqlcipher import dbapi2 as sqlcipher
+
+ return sqlcipher
+
+ @classmethod
+ def get_pool_class(cls, url):
+ return pool.SingletonThreadPool
+
+ def on_connect_url(self, url):
+ super_on_connect = super(
+ SQLiteDialect_pysqlcipher, self
+ ).on_connect_url(url)
+
+ # pull the info we need from the URL early. Even though URL
+ # is immutable, we don't want any in-place changes to the URL
+ # to affect things
+ passphrase = url.password or ""
+ url_query = dict(url.query)
+
+ def on_connect(conn):
+ cursor = conn.cursor()
+ cursor.execute('pragma key="%s"' % passphrase)
+ for prag in self.pragmas:
+ value = url_query.get(prag, None)
+ if value is not None:
+ cursor.execute('pragma %s="%s"' % (prag, value))
+ cursor.close()
+
+ if super_on_connect:
+ super_on_connect(conn)
+
+ return on_connect
+
+ def create_connect_args(self, url):
+ plain_url = url._replace(password=None)
+ plain_url = plain_url.difference_update_query(self.pragmas)
+ return super(SQLiteDialect_pysqlcipher, self).create_connect_args(
+ plain_url
+ )
+
+
+dialect = SQLiteDialect_pysqlcipher
diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
new file mode 100644
index 0000000..1aae561
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py
@@ -0,0 +1,613 @@
+# sqlite/pysqlite.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""
+.. dialect:: sqlite+pysqlite
+ :name: pysqlite
+ :dbapi: sqlite3
+ :connectstring: sqlite+pysqlite:///file_path
+ :url: https://docs.python.org/library/sqlite3.html
+
+ Note that ``pysqlite`` is the same driver as the ``sqlite3``
+ module included with the Python distribution.
+
+Driver
+------
+
+The ``sqlite3`` Python DBAPI is standard on all modern Python versions;
+for cPython and Pypy, no additional installation is necessary.
+
+
+Connect Strings
+---------------
+
+The file specification for the SQLite database is taken as the "database"
+portion of the URL. Note that the format of a SQLAlchemy url is::
+
+ driver://user:pass@host/database
+
+This means that the actual filename to be used starts with the characters to
+the **right** of the third slash. So connecting to a relative filepath
+looks like::
+
+ # relative path
+ e = create_engine('sqlite:///path/to/database.db')
+
+An absolute path, which is denoted by starting with a slash, means you
+need **four** slashes::
+
+ # absolute path
+ e = create_engine('sqlite:////path/to/database.db')
+
+To use a Windows path, regular drive specifications and backslashes can be
+used. Double backslashes are probably needed::
+
+ # absolute path on Windows
+ e = create_engine('sqlite:///C:\\path\\to\\database.db')
+
+The sqlite ``:memory:`` identifier is the default if no filepath is
+present. Specify ``sqlite://`` and nothing else::
+
+ # in-memory database
+ e = create_engine('sqlite://')
+
+.. _pysqlite_uri_connections:
+
+URI Connections
+^^^^^^^^^^^^^^^
+
+Modern versions of SQLite support an alternative system of connecting using a
+`driver level URI <https://www.sqlite.org/uri.html>`_, which has the advantage
+that additional driver-level arguments can be passed including options such as
+"read only". The Python sqlite3 driver supports this mode under modern Python
+3 versions. The SQLAlchemy pysqlite driver supports this mode of use by
+specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept
+as the "database" portion of the SQLAlchemy url (that is, following a slash)::
+
+ e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true")
+
+.. note:: The "uri=true" parameter must appear in the **query string**
+ of the URL. It will not currently work as expected if it is only
+ present in the :paramref:`_sa.create_engine.connect_args`
+ parameter dictionary.
+
+The logic reconciles the simultaneous presence of SQLAlchemy's query string and
+SQLite's query string by separating out the parameters that belong to the
+Python sqlite3 driver vs. those that belong to the SQLite URI. This is
+achieved through the use of a fixed list of parameters known to be accepted by
+the Python side of the driver. For example, to include a URL that indicates
+the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the
+SQLite "mode" and "nolock" parameters, they can all be passed together on the
+query string::
+
+ e = create_engine(
+ "sqlite:///file:path/to/database?"
+ "check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true"
+ )
+
+Above, the pysqlite / sqlite3 DBAPI would be passed arguments as::
+
+ sqlite3.connect(
+ "file:path/to/database?mode=ro&nolock=1",
+ check_same_thread=True, timeout=10, uri=True
+ )
+
+Regarding future parameters added to either the Python or native drivers, new
+parameter names added to the SQLite URI scheme should be automatically
+accommodated by this scheme. New parameter names added to the Python driver
+side can be accommodated by specifying them in the
+:paramref:`_sa.create_engine.connect_args` dictionary,
+until dialect support is
+added by SQLAlchemy. For the less likely case that the native SQLite driver
+adds a new parameter name that overlaps with one of the existing, known Python
+driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would
+require adjustment for the URL scheme to continue to support this.
+
+As is always the case for all SQLAlchemy dialects, the entire "URL" process
+can be bypassed in :func:`_sa.create_engine` through the use of the
+:paramref:`_sa.create_engine.creator`
+parameter which allows for a custom callable
+that creates a Python sqlite3 driver level connection directly.
+
+.. versionadded:: 1.3.9
+
+.. seealso::
+
+ `Uniform Resource Identifiers <https://www.sqlite.org/uri.html>`_ - in
+ the SQLite documentation
+
+.. _pysqlite_regexp:
+
+Regular Expression Support
+---------------------------
+
+.. versionadded:: 1.4
+
+Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided
+using Python's re.search_ function. SQLite itself does not include a working
+regular expression operator; instead, it includes a non-implemented placeholder
+operator ``REGEXP`` that calls a user-defined function that must be provided.
+
+SQLAlchemy's implementation makes use of the pysqlite create_function_ hook
+as follows::
+
+
+ def regexp(a, b):
+ return re.search(a, b) is not None
+
+ sqlite_connection.create_function(
+ "regexp", 2, regexp,
+ )
+
+There is currently no support for regular expression flags as a separate
+argument, as these are not supported by SQLite's REGEXP operator; however,
+these may be included inline within the regular expression string. See
+`Python regular expressions`_ for details.
+
+.. seealso::
+
+ `Python regular expressions`_: Documentation for Python's regular expression syntax.
+
+.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
+
+.. _re.search: https://docs.python.org/3/library/re.html#re.search
+
+.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search
+
+
+
+Compatibility with sqlite3 "native" date and datetime types
+-----------------------------------------------------------
+
+The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
+sqlite3.PARSE_COLNAMES options, which have the effect that any column
+or expression explicitly cast as "date" or "timestamp" will be converted
+to a Python date or datetime object. The date and datetime types provided
+with the pysqlite dialect are not currently compatible with these options,
+since they render the ISO date/datetime including microseconds, which
+pysqlite's driver does not. Additionally, SQLAlchemy does not at
+this time automatically render the "cast" syntax required for the
+freestanding functions "current_timestamp" and "current_date" to return
+datetime/date types natively. Unfortunately, pysqlite
+does not provide the standard DBAPI types in ``cursor.description``,
+leaving SQLAlchemy with no way to detect these types on the fly
+without expensive per-row type checks.
+
+Keeping in mind that pysqlite's parsing option is not recommended,
+nor should it be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
+can be forced if one configures "native_datetime=True" on create_engine()::
+
+ engine = create_engine('sqlite://',
+ connect_args={'detect_types':
+ sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
+ native_datetime=True
+ )
+
+With this flag enabled, the DATE and TIMESTAMP types (but note - not the
+DATETIME or TIME types...confused yet?) will not perform any bind parameter
+or result processing. Execution of "func.current_date()" will return a string.
+"func.current_timestamp()" is registered as returning a DATETIME type in
+SQLAlchemy, so this function still receives SQLAlchemy-level result
+processing.
+
+.. _pysqlite_threading_pooling:
+
+Threading/Pooling Behavior
+---------------------------
+
+Pysqlite's default behavior is to prohibit the usage of a single connection
+in more than one thread. This was originally intended to work with older
+versions of SQLite that did not support multithreaded operation under
+various circumstances. In particular, older SQLite versions
+did not allow a ``:memory:`` database to be used in multiple threads
+under any circumstances.
+
+Pysqlite does include a now-undocumented flag known as
+``check_same_thread`` which will disable this check; however, note that
+pysqlite connections are still not safe to use concurrently in multiple
+threads. In particular, any statement execution calls would need to be
+externally mutexed, as Pysqlite does not provide for thread-safe propagation
+of error messages among other things. So while even ``:memory:`` databases
+can be shared among threads in modern SQLite, Pysqlite doesn't provide enough
+thread-safety to make this usage worth it.
+
+SQLAlchemy sets up pooling to work with Pysqlite's default behavior:
+
+* When a ``:memory:`` SQLite database is specified, the dialect by default
+ will use :class:`.SingletonThreadPool`. This pool maintains a single
+ connection per thread, so that all access to the engine within the current
+  thread uses the same ``:memory:`` database - other threads would access a
+ different ``:memory:`` database.
+* When a file-based database is specified, the dialect will use
+ :class:`.NullPool` as the source of connections. This pool closes and
+ discards connections which are returned to the pool immediately. SQLite
+ file-based connections have extremely low overhead, so pooling is not
+ necessary. The scheme also prevents a connection from being used again in
+ a different thread and works best with SQLite's coarse-grained file locking.
+
+Using a Memory Database in Multiple Threads
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To use a ``:memory:`` database in a multithreaded scenario, the same
+connection object must be shared among threads, since the database exists
+only within the scope of that connection. The
+:class:`.StaticPool` implementation will maintain a single connection
+globally, and the ``check_same_thread`` flag can be passed to Pysqlite
+as ``False``::
+
+ from sqlalchemy.pool import StaticPool
+ engine = create_engine('sqlite://',
+ connect_args={'check_same_thread':False},
+ poolclass=StaticPool)
+
+Note that using a ``:memory:`` database in multiple threads requires a recent
+version of SQLite.
+
+Using Temporary Tables with SQLite
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Due to the way SQLite deals with temporary tables, if you wish to use a
+temporary table in a file-based SQLite database across multiple checkouts
+from the connection pool, such as when using an ORM :class:`.Session` where
+the temporary table should continue to remain after :meth:`.Session.commit` or
+:meth:`.Session.rollback` is called, a pool which maintains a single
+connection must be used. Use :class:`.SingletonThreadPool` if the scope is
+only needed within the current thread, or :class:`.StaticPool` if scope is
+needed within multiple threads for this case::
+
+ # maintain the same connection per thread
+ from sqlalchemy.pool import SingletonThreadPool
+ engine = create_engine('sqlite:///mydb.db',
+ poolclass=SingletonThreadPool)
+
+
+ # maintain the same connection across all threads
+ from sqlalchemy.pool import StaticPool
+ engine = create_engine('sqlite:///mydb.db',
+ poolclass=StaticPool)
+
+Note that :class:`.SingletonThreadPool` should be configured for the number
+of threads that are to be used; beyond that number, connections will be
+closed out in a non-deterministic way.
+
+Unicode
+-------
+
+The pysqlite driver only returns Python ``unicode`` objects in result sets,
+never plain strings, and accommodates ``unicode`` objects within bound
+parameter values in all cases. Regardless of the SQLAlchemy string type in
+use, string-based result values will be Python ``unicode`` in Python 2.
+The :class:`.Unicode` type should still be used to indicate those columns that
+require unicode, however, so that non-``unicode`` values passed inadvertently
+will emit a warning. Pysqlite will emit an error if a non-``unicode`` string
+is passed containing non-ASCII characters.
+
+Dealing with Mixed String / Binary Columns in Python 3
+------------------------------------------------------
+
+The SQLite database is weakly typed, and as such it is possible when using
+binary values, which in Python 3 are represented as ``b'some string'``, that a
+particular SQLite database can have data values within different rows where
+some of them will be returned as a ``b''`` value by the Pysqlite driver, and
+others will be returned as Python strings, e.g. ``''`` values. This situation
+is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used
+consistently, however if a particular SQLite database has data that was
+inserted using the Pysqlite driver directly, or when using the SQLAlchemy
+:class:`.String` type which was later changed to :class:`.LargeBinary`, the
+table will not be consistently readable because SQLAlchemy's
+:class:`.LargeBinary` datatype does not handle strings so it has no way of
+"encoding" a value that is in string format.
+
+To deal with a SQLite table that has mixed string / binary data in the
+same column, use a custom type that will check each row individually::
+
+ # note this is Python 3 only
+
+ from sqlalchemy import String
+ from sqlalchemy import TypeDecorator
+
+ class MixedBinary(TypeDecorator):
+ impl = String
+ cache_ok = True
+
+ def process_result_value(self, value, dialect):
+ if isinstance(value, str):
+ value = bytes(value, 'utf-8')
+ elif value is not None:
+ value = bytes(value)
+
+ return value
+
+Then use the above ``MixedBinary`` datatype in the place where
+:class:`.LargeBinary` would normally be used.
+
+.. _pysqlite_serializable:
+
+Serializable isolation / Savepoints / Transactional DDL
+-------------------------------------------------------
+
+In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
+driver's assortment of issues that prevent several features of SQLite
+from working correctly. The pysqlite DBAPI driver has several
+long-standing bugs which impact the correctness of its transactional
+behavior. In its default mode of operation, SQLite features such as
+SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
+non-functional, and in order to use these features, workarounds must
+be taken.
+
+The issue is essentially that the driver attempts to second-guess the user's
+intent, failing to start transactions and sometimes ending them prematurely, in
+an effort to minimize the SQLite database's file locking behavior, even
+though SQLite itself uses "shared" locks for read-only activities.
+
+SQLAlchemy chooses to not alter this behavior by default, as it is the
+long-expected behavior of the pysqlite driver; if and when the pysqlite
+driver attempts to repair these issues, that would be a stronger motivation to
+change SQLAlchemy's defaults.
+
+The good news is that with a few events, we can implement transactional
+support fully, by disabling pysqlite's feature entirely and emitting BEGIN
+ourselves. This is achieved using two event listeners::
+
+ from sqlalchemy import create_engine, event
+
+ engine = create_engine("sqlite:///myfile.db")
+
+ @event.listens_for(engine, "connect")
+ def do_connect(dbapi_connection, connection_record):
+ # disable pysqlite's emitting of the BEGIN statement entirely.
+ # also stops it from emitting COMMIT before any DDL.
+ dbapi_connection.isolation_level = None
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ # emit our own BEGIN
+ conn.exec_driver_sql("BEGIN")
+
+.. warning:: When using the above recipe, it is advised to not use the
+ :paramref:`.Connection.execution_options.isolation_level` setting on
+ :class:`_engine.Connection` and :func:`_sa.create_engine`
+ with the SQLite driver,
+ as this function necessarily will also alter the ".isolation_level" setting.
+
+
+Above, we intercept a new pysqlite connection and disable any transactional
+integration. Then, at the point at which SQLAlchemy knows that transaction
+scope is to begin, we emit ``"BEGIN"`` ourselves.
+
+When we take control of ``"BEGIN"``, we can also control directly SQLite's
+locking modes, introduced at
+`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_,
+by adding the desired locking mode to our ``"BEGIN"``::
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ conn.exec_driver_sql("BEGIN EXCLUSIVE")
+
+.. seealso::
+
+ `BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ -
+ on the SQLite site
+
+ `sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ -
+ on the Python bug tracker
+
+ `sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ -
+ on the Python bug tracker
+
+
+""" # noqa
+
+import os
+import re
+
+from .base import DATE
+from .base import DATETIME
+from .base import SQLiteDialect
+from ... import exc
+from ... import pool
+from ... import types as sqltypes
+from ... import util
+
+
+class _SQLite_pysqliteTimeStamp(DATETIME):
+ def bind_processor(self, dialect):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATETIME.bind_processor(self, dialect)
+
+ def result_processor(self, dialect, coltype):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATETIME.result_processor(self, dialect, coltype)
+
+
+class _SQLite_pysqliteDate(DATE):
+ def bind_processor(self, dialect):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATE.bind_processor(self, dialect)
+
+ def result_processor(self, dialect, coltype):
+ if dialect.native_datetime:
+ return None
+ else:
+ return DATE.result_processor(self, dialect, coltype)
+
+
+class SQLiteDialect_pysqlite(SQLiteDialect):
+ default_paramstyle = "qmark"
+ supports_statement_cache = True
+
+ colspecs = util.update_copy(
+ SQLiteDialect.colspecs,
+ {
+ sqltypes.Date: _SQLite_pysqliteDate,
+ sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
+ },
+ )
+
+ if not util.py2k:
+ description_encoding = None
+
+ driver = "pysqlite"
+
+ @classmethod
+ def dbapi(cls):
+ if util.py2k:
+ try:
+ from pysqlite2 import dbapi2 as sqlite
+ except ImportError:
+ try:
+ from sqlite3 import dbapi2 as sqlite
+ except ImportError as e:
+ raise e
+ else:
+ from sqlite3 import dbapi2 as sqlite
+ return sqlite
+
+ @classmethod
+ def _is_url_file_db(cls, url):
+ if (url.database and url.database != ":memory:") and (
+ url.query.get("mode", None) != "memory"
+ ):
+ return True
+ else:
+ return False
+
+ @classmethod
+ def get_pool_class(cls, url):
+ if cls._is_url_file_db(url):
+ return pool.NullPool
+ else:
+ return pool.SingletonThreadPool
+
+ def _get_server_version_info(self, connection):
+ return self.dbapi.sqlite_version_info
+
+ _isolation_lookup = SQLiteDialect._isolation_lookup.union(
+ {
+ "AUTOCOMMIT": None,
+ }
+ )
+
+ def set_isolation_level(self, connection, level):
+ if hasattr(connection, "dbapi_connection"):
+ dbapi_connection = connection.dbapi_connection
+ else:
+ dbapi_connection = connection
+
+ if level == "AUTOCOMMIT":
+ dbapi_connection.isolation_level = None
+ else:
+ dbapi_connection.isolation_level = ""
+ return super(SQLiteDialect_pysqlite, self).set_isolation_level(
+ connection, level
+ )
+
+ def on_connect(self):
+ connect = super(SQLiteDialect_pysqlite, self).on_connect()
+
+ def regexp(a, b):
+ if b is None:
+ return None
+ return re.search(a, b) is not None
+
+ def set_regexp(connection):
+ if hasattr(connection, "dbapi_connection"):
+ dbapi_connection = connection.dbapi_connection
+ else:
+ dbapi_connection = connection
+ dbapi_connection.create_function(
+ "regexp",
+ 2,
+ regexp,
+ )
+
+ fns = [set_regexp]
+
+ if self.isolation_level is not None:
+
+ def iso_level(conn):
+ self.set_isolation_level(conn, self.isolation_level)
+
+ fns.append(iso_level)
+
+ def connect(conn):
+ for fn in fns:
+ fn(conn)
+
+ return connect
+
+ def create_connect_args(self, url):
+ if url.username or url.password or url.host or url.port:
+ raise exc.ArgumentError(
+ "Invalid SQLite URL: %s\n"
+ "Valid SQLite URL forms are:\n"
+ " sqlite:///:memory: (or, sqlite://)\n"
+ " sqlite:///relative/path/to/file.db\n"
+ " sqlite:////absolute/path/to/file.db" % (url,)
+ )
+
+ # theoretically, this list can be augmented, at least as far as
+ # parameter names accepted by sqlite3/pysqlite, using
+ # inspect.getfullargspec(). for the moment this seems like overkill
+ # as these parameters don't change very often, and as always,
+ # parameters passed to connect_args will always go to the
+ # sqlite3/pysqlite driver.
+ pysqlite_args = [
+ ("uri", bool),
+ ("timeout", float),
+ ("isolation_level", str),
+ ("detect_types", int),
+ ("check_same_thread", bool),
+ ("cached_statements", int),
+ ]
+ opts = url.query
+ pysqlite_opts = {}
+ for key, type_ in pysqlite_args:
+ util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
+
+ if pysqlite_opts.get("uri", False):
+ uri_opts = dict(opts)
+ # here, we are actually separating the parameters that go to
+ # sqlite3/pysqlite vs. those that go the SQLite URI. What if
+ # two names conflict? again, this seems to be not the case right
+ # now, and in the case that new names are added to
+ # either side which overlap, again the sqlite3/pysqlite parameters
+ # can be passed through connect_args instead of in the URL.
+ # If SQLite native URIs add a parameter like "timeout" that
+ # we already have listed here for the python driver, then we need
+ # to adjust for that here.
+ for key, type_ in pysqlite_args:
+ uri_opts.pop(key, None)
+ filename = url.database
+ if uri_opts:
+ # sorting of keys is for unit test support
+ filename += "?" + (
+ "&".join(
+ "%s=%s" % (key, uri_opts[key])
+ for key in sorted(uri_opts)
+ )
+ )
+ else:
+ filename = url.database or ":memory:"
+ if filename != ":memory:":
+ filename = os.path.abspath(filename)
+
+ return ([filename], pysqlite_opts)
+
+ def is_disconnect(self, e, connection, cursor):
+ return isinstance(
+ e, self.dbapi.ProgrammingError
+ ) and "Cannot operate on a closed database." in str(e)
+
+
+dialect = SQLiteDialect_pysqlite
diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py
new file mode 100644
index 0000000..c7755c8
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/__init__.py
@@ -0,0 +1,67 @@
+# sybase/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import base # noqa
+from . import pyodbc # noqa
+from . import pysybase # noqa
+from .base import BIGINT
+from .base import BINARY
+from .base import BIT
+from .base import CHAR
+from .base import DATE
+from .base import DATETIME
+from .base import FLOAT
+from .base import IMAGE
+from .base import INT
+from .base import INTEGER
+from .base import MONEY
+from .base import NCHAR
+from .base import NUMERIC
+from .base import NVARCHAR
+from .base import SMALLINT
+from .base import SMALLMONEY
+from .base import TEXT
+from .base import TIME
+from .base import TINYINT
+from .base import UNICHAR
+from .base import UNITEXT
+from .base import UNIVARCHAR
+from .base import VARBINARY
+from .base import VARCHAR
+
+
+# default dialect
+base.dialect = dialect = pyodbc.dialect
+
+
+__all__ = (
+ "CHAR",
+ "VARCHAR",
+ "TIME",
+ "NCHAR",
+ "NVARCHAR",
+ "TEXT",
+ "DATE",
+ "DATETIME",
+ "FLOAT",
+ "NUMERIC",
+ "BIGINT",
+ "INT",
+ "INTEGER",
+ "SMALLINT",
+ "BINARY",
+ "VARBINARY",
+ "UNITEXT",
+ "UNICHAR",
+ "UNIVARCHAR",
+ "IMAGE",
+ "BIT",
+ "MONEY",
+ "SMALLMONEY",
+ "TINYINT",
+ "dialect",
+)
diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py
new file mode 100644
index 0000000..83248d1
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/base.py
@@ -0,0 +1,1100 @@
+# sybase/base.py
+# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+# get_select_precolumns(), limit_clause() implementation
+# copyright (C) 2007 Fisch Asset Management
+# AG https://www.fam.ch, with coding by Alexander Houben
+# alexander.houben@thor-solutions.ch
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+.. dialect:: sybase
+ :name: Sybase
+
+.. note::
+
+ The Sybase dialect within SQLAlchemy **is not currently supported**.
+ It is not tested within continuous integration and is likely to have
+ many issues and caveats not currently handled. Consider using the
+ `external dialect <https://github.com/gordthompson/sqlalchemy-sybase>`_
+ instead.
+
+.. deprecated:: 1.4 The internal Sybase dialect is deprecated and will be
+ removed in a future version. Use the external dialect.
+
+"""
+
+import re
+
+from sqlalchemy import exc
+from sqlalchemy import schema as sa_schema
+from sqlalchemy import types as sqltypes
+from sqlalchemy import util
+from sqlalchemy.engine import default
+from sqlalchemy.engine import reflection
+from sqlalchemy.sql import compiler
+from sqlalchemy.sql import text
+from sqlalchemy.types import BIGINT
+from sqlalchemy.types import BINARY
+from sqlalchemy.types import CHAR
+from sqlalchemy.types import DATE
+from sqlalchemy.types import DATETIME
+from sqlalchemy.types import DECIMAL
+from sqlalchemy.types import FLOAT
+from sqlalchemy.types import INT # noqa
+from sqlalchemy.types import INTEGER
+from sqlalchemy.types import NCHAR
+from sqlalchemy.types import NUMERIC
+from sqlalchemy.types import NVARCHAR
+from sqlalchemy.types import REAL
+from sqlalchemy.types import SMALLINT
+from sqlalchemy.types import TEXT
+from sqlalchemy.types import TIME
+from sqlalchemy.types import TIMESTAMP
+from sqlalchemy.types import Unicode
+from sqlalchemy.types import VARBINARY
+from sqlalchemy.types import VARCHAR
+
+
+RESERVED_WORDS = set(
+ [
+ "add",
+ "all",
+ "alter",
+ "and",
+ "any",
+ "as",
+ "asc",
+ "backup",
+ "begin",
+ "between",
+ "bigint",
+ "binary",
+ "bit",
+ "bottom",
+ "break",
+ "by",
+ "call",
+ "capability",
+ "cascade",
+ "case",
+ "cast",
+ "char",
+ "char_convert",
+ "character",
+ "check",
+ "checkpoint",
+ "close",
+ "comment",
+ "commit",
+ "connect",
+ "constraint",
+ "contains",
+ "continue",
+ "convert",
+ "create",
+ "cross",
+ "cube",
+ "current",
+ "current_timestamp",
+ "current_user",
+ "cursor",
+ "date",
+ "dbspace",
+ "deallocate",
+ "dec",
+ "decimal",
+ "declare",
+ "default",
+ "delete",
+ "deleting",
+ "desc",
+ "distinct",
+ "do",
+ "double",
+ "drop",
+ "dynamic",
+ "else",
+ "elseif",
+ "encrypted",
+ "end",
+ "endif",
+ "escape",
+ "except",
+ "exception",
+ "exec",
+ "execute",
+ "existing",
+ "exists",
+ "externlogin",
+ "fetch",
+ "first",
+ "float",
+ "for",
+ "force",
+ "foreign",
+ "forward",
+ "from",
+ "full",
+ "goto",
+ "grant",
+ "group",
+ "having",
+ "holdlock",
+ "identified",
+ "if",
+ "in",
+ "index",
+ "index_lparen",
+ "inner",
+ "inout",
+ "insensitive",
+ "insert",
+ "inserting",
+ "install",
+ "instead",
+ "int",
+ "integer",
+ "integrated",
+ "intersect",
+ "into",
+ "iq",
+ "is",
+ "isolation",
+ "join",
+ "key",
+ "lateral",
+ "left",
+ "like",
+ "lock",
+ "login",
+ "long",
+ "match",
+ "membership",
+ "message",
+ "mode",
+ "modify",
+ "natural",
+ "new",
+ "no",
+ "noholdlock",
+ "not",
+ "notify",
+ "null",
+ "numeric",
+ "of",
+ "off",
+ "on",
+ "open",
+ "option",
+ "options",
+ "or",
+ "order",
+ "others",
+ "out",
+ "outer",
+ "over",
+ "passthrough",
+ "precision",
+ "prepare",
+ "primary",
+ "print",
+ "privileges",
+ "proc",
+ "procedure",
+ "publication",
+ "raiserror",
+ "readtext",
+ "real",
+ "reference",
+ "references",
+ "release",
+ "remote",
+ "remove",
+ "rename",
+ "reorganize",
+ "resource",
+ "restore",
+ "restrict",
+ "return",
+ "revoke",
+ "right",
+ "rollback",
+ "rollup",
+ "save",
+ "savepoint",
+ "scroll",
+ "select",
+ "sensitive",
+ "session",
+ "set",
+ "setuser",
+ "share",
+ "smallint",
+ "some",
+ "sqlcode",
+ "sqlstate",
+ "start",
+ "stop",
+ "subtrans",
+ "subtransaction",
+ "synchronize",
+ "syntax_error",
+ "table",
+ "temporary",
+ "then",
+ "time",
+ "timestamp",
+ "tinyint",
+ "to",
+ "top",
+ "tran",
+ "trigger",
+ "truncate",
+ "tsequal",
+ "unbounded",
+ "union",
+ "unique",
+ "unknown",
+ "unsigned",
+ "update",
+ "updating",
+ "user",
+ "using",
+ "validate",
+ "values",
+ "varbinary",
+ "varchar",
+ "variable",
+ "varying",
+ "view",
+ "wait",
+ "waitfor",
+ "when",
+ "where",
+ "while",
+ "window",
+ "with",
+ "with_cube",
+ "with_lparen",
+ "with_rollup",
+ "within",
+ "work",
+ "writetext",
+ ]
+)
+
+
+class _SybaseUnitypeMixin(object):
+ """these types appear to return a buffer object."""
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is not None:
+ return str(value) # decode("ucs-2")
+ else:
+ return None
+
+ return process
+
+
+class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
+ __visit_name__ = "UNICHAR"
+
+
+class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode):
+ __visit_name__ = "UNIVARCHAR"
+
+
+class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText):
+ __visit_name__ = "UNITEXT"
+
+
+class TINYINT(sqltypes.Integer):
+ __visit_name__ = "TINYINT"
+
+
+class BIT(sqltypes.TypeEngine):
+ __visit_name__ = "BIT"
+
+
+class MONEY(sqltypes.TypeEngine):
+ __visit_name__ = "MONEY"
+
+
+class SMALLMONEY(sqltypes.TypeEngine):
+ __visit_name__ = "SMALLMONEY"
+
+
+class UNIQUEIDENTIFIER(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
+
+
+class IMAGE(sqltypes.LargeBinary):
+ __visit_name__ = "IMAGE"
+
+
+class SybaseTypeCompiler(compiler.GenericTypeCompiler):
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_IMAGE(type_)
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BIT(type_)
+
+ def visit_unicode(self, type_, **kw):
+ return self.visit_NVARCHAR(type_)
+
+ def visit_UNICHAR(self, type_, **kw):
+ return "UNICHAR(%d)" % type_.length
+
+ def visit_UNIVARCHAR(self, type_, **kw):
+ return "UNIVARCHAR(%d)" % type_.length
+
+ def visit_UNITEXT(self, type_, **kw):
+ return "UNITEXT"
+
+ def visit_TINYINT(self, type_, **kw):
+ return "TINYINT"
+
+ def visit_IMAGE(self, type_, **kw):
+ return "IMAGE"
+
+ def visit_BIT(self, type_, **kw):
+ return "BIT"
+
+ def visit_MONEY(self, type_, **kw):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_, **kw):
+ return "SMALLMONEY"
+
+ def visit_UNIQUEIDENTIFIER(self, type_, **kw):
+ return "UNIQUEIDENTIFIER"
+
+
+ischema_names = {
+ "bigint": BIGINT,
+ "int": INTEGER,
+ "integer": INTEGER,
+ "smallint": SMALLINT,
+ "tinyint": TINYINT,
+ "unsigned bigint": BIGINT, # TODO: unsigned flags
+ "unsigned int": INTEGER, # TODO: unsigned flags
+ "unsigned smallint": SMALLINT, # TODO: unsigned flags
+ "numeric": NUMERIC,
+ "decimal": DECIMAL,
+ "dec": DECIMAL,
+ "float": FLOAT,
+ "double": NUMERIC, # TODO
+ "double precision": NUMERIC, # TODO
+ "real": REAL,
+ "smallmoney": SMALLMONEY,
+ "money": MONEY,
+ "smalldatetime": DATETIME,
+ "datetime": DATETIME,
+ "date": DATE,
+ "time": TIME,
+ "char": CHAR,
+ "character": CHAR,
+ "varchar": VARCHAR,
+ "character varying": VARCHAR,
+ "char varying": VARCHAR,
+ "unichar": UNICHAR,
+ "unicode character": UNIVARCHAR,
+ "nchar": NCHAR,
+ "national char": NCHAR,
+ "national character": NCHAR,
+ "nvarchar": NVARCHAR,
+ "nchar varying": NVARCHAR,
+ "national char varying": NVARCHAR,
+ "national character varying": NVARCHAR,
+ "text": TEXT,
+ "unitext": UNITEXT,
+ "binary": BINARY,
+ "varbinary": VARBINARY,
+ "image": IMAGE,
+ "bit": BIT,
+ # not in documentation for ASE 15.7
+ "long varchar": TEXT, # TODO
+ "timestamp": TIMESTAMP,
+ "uniqueidentifier": UNIQUEIDENTIFIER,
+}
+
+
+class SybaseInspector(reflection.Inspector):
+ def __init__(self, conn):
+ reflection.Inspector.__init__(self, conn)
+
+ def get_table_id(self, table_name, schema=None):
+ """Return the table id from `table_name` and `schema`."""
+
+ return self.dialect.get_table_id(
+ self.bind, table_name, schema, info_cache=self.info_cache
+ )
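+
+ # A minimal usage sketch (illustrative only; connection URL and table
+ # names are hypothetical):
+ #
+ #   from sqlalchemy import create_engine, inspect
+ #   engine = create_engine("sybase+pyodbc://user:pass@mydsn")
+ #   insp = inspect(engine)  # a SybaseInspector for this dialect
+ #   table_id = insp.get_table_id("some_table", schema="dbo")
+ #
+ # The id is stored in the shared info_cache so that subsequent
+ # reflection calls for the same table can reuse it.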
+
+
+class SybaseExecutionContext(default.DefaultExecutionContext):
+ _enable_identity_insert = False
+
+ def set_ddl_autocommit(self, connection, value):
+ """Must be implemented by subclasses to accommodate DDL executions.
+
+ "connection" is the raw unwrapped DBAPI connection. "value"
+ is True or False. when True, the connection should be configured
+ such that a DDL can take place subsequently. when False,
+ a DDL has taken place and the connection should be resumed
+ into non-autocommit mode.
+
+ """
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ if self.isinsert:
+ tbl = self.compiled.statement.table
+ seq_column = tbl._autoincrement_column
+ insert_has_sequence = seq_column is not None
+
+ if insert_has_sequence:
+ self._enable_identity_insert = (
+ seq_column.key in self.compiled_parameters[0]
+ )
+ else:
+ self._enable_identity_insert = False
+
+ if self._enable_identity_insert:
+ self.cursor.execute(
+ "SET IDENTITY_INSERT %s ON"
+ % self.dialect.identifier_preparer.format_table(tbl)
+ )
+
+ if self.isddl:
+ # TODO: to enhance this, we could detect "ddl in tran" in the
+ # database settings. This error message should be improved to
+ # include a note about that.
+ if not self.should_autocommit:
+ raise exc.InvalidRequestError(
+ "The Sybase dialect only supports "
+ "DDL in 'autocommit' mode at this time."
+ )
+
+ self.root_connection.engine.logger.info(
+ "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')"
+ )
+
+ self.set_ddl_autocommit(
+ self.root_connection.connection.connection, True
+ )
+
+ def post_exec(self):
+ if self.isddl:
+ self.set_ddl_autocommit(self.root_connection, False)
+
+ if self._enable_identity_insert:
+ self.cursor.execute(
+ "SET IDENTITY_INSERT %s OFF"
+ % self.dialect.identifier_preparer.format_table(
+ self.compiled.statement.table
+ )
+ )
+
+ def get_lastrowid(self):
+ cursor = self.create_cursor()
+ cursor.execute("SELECT @@identity AS lastrowid")
+ lastrowid = cursor.fetchone()[0]
+ cursor.close()
+ return lastrowid
+
+
+class SybaseSQLCompiler(compiler.SQLCompiler):
+ ansi_bind_rules = True
+
+ extract_map = util.update_copy(
+ compiler.SQLCompiler.extract_map,
+ {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"},
+ )
+
+ def get_from_hint_text(self, table, text):
+ return text
+
+ def limit_clause(self, select, **kw):
+ text = ""
+ if select._limit_clause is not None:
+ text += " ROWS LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
+ text += " ROWS"
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ return text
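+
+ # Rough sketch of what this produces (illustrative, per the logic above):
+ #
+ #   select(tbl).limit(10)            -> "... ROWS LIMIT 10"
+ #   select(tbl).limit(10).offset(5)  -> "... ROWS LIMIT 10 OFFSET 5"
+ #   select(tbl).offset(5)            -> "... ROWS OFFSET 5"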
+
+ def visit_extract(self, extract, **kw):
+ field = self.extract_map.get(extract.field, extract.field)
+ return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw))
+
+ def visit_now_func(self, fn, **kw):
+ return "GETDATE()"
+
+ def for_update_clause(self, select):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR"
+ # which SQLAlchemy doesn't use
+ return ""
+
+ def order_by_clause(self, select, **kw):
+ kw["literal_binds"] = True
+ order_by = self.process(select._order_by_clause, **kw)
+
+ # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if order_by and (not self.is_subquery() or select._limit):
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ """If we have extra froms make sure we render any alias as hint."""
+ ashint = False
+ if extra_froms:
+ ashint = True
+ return from_table._compiler_dispatch(
+ self, asfrom=True, iscrud=True, ashint=ashint
+ )
+
+ def delete_extra_from_clause(
+ self, delete_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Render the DELETE .. FROM clause specific to Sybase."""
+ kw["asfrom"] = True
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in [from_table] + extra_froms
+ )
+
+
+class SybaseDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
+
+ if column.table is None:
+ raise exc.CompileError(
+ "The Sybase dialect requires Table-bound "
+ "columns in order to generate DDL"
+ )
+ seq_col = column.table._autoincrement_column
+
+ # install an IDENTITY Sequence if we have an implicit IDENTITY column
+ if seq_col is column:
+ sequence = (
+ isinstance(column.default, sa_schema.Sequence)
+ and column.default
+ )
+ if sequence:
+ start, increment = sequence.start or 1, sequence.increment or 1
+ else:
+ start, increment = 1, 1
+ if (start, increment) == (1, 1):
+ colspec += " IDENTITY"
+ else:
+ # TODO: need correct syntax for this
+ colspec += " IDENTITY(%s,%s)" % (start, increment)
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if column.nullable is not None:
+ if not column.nullable or column.primary_key:
+ colspec += " NOT NULL"
+ else:
+ colspec += " NULL"
+
+ return colspec
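+
+ # Illustrative column DDL produced by these rules (a sketch, not an
+ # exhaustive statement of Sybase syntax):
+ #
+ #   Column("id", Integer, primary_key=True)   -> "id INTEGER IDENTITY NOT NULL"
+ #   Column("name", String(30), nullable=True) -> "name VARCHAR(30) NULL"
+ #
+ # The implicit autoincrement column receives IDENTITY; other columns get
+ # any DEFAULT plus an explicit NULL / NOT NULL qualifier.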
+
+ def visit_drop_index(self, drop):
+ index = drop.element
+ return "\nDROP INDEX %s.%s" % (
+ self.preparer.quote_identifier(index.table.name),
+ self._prepared_index_name(drop.element, include_schema=False),
+ )
+
+
+class SybaseIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = RESERVED_WORDS
+
+
+class SybaseDialect(default.DefaultDialect):
+ name = "sybase"
+ supports_unicode_statements = False
+ supports_sane_rowcount = False
+ supports_sane_multi_rowcount = False
+ supports_statement_cache = True
+
+ supports_native_boolean = False
+ supports_unicode_binds = False
+ postfetch_lastrowid = True
+
+ colspecs = {}
+ ischema_names = ischema_names
+
+ type_compiler = SybaseTypeCompiler
+ statement_compiler = SybaseSQLCompiler
+ ddl_compiler = SybaseDDLCompiler
+ preparer = SybaseIdentifierPreparer
+ inspector = SybaseInspector
+
+ construct_arguments = []
+
+ def __init__(self, *args, **kwargs):
+ util.warn_deprecated(
+ "The Sybase dialect is deprecated and will be removed "
+ "in a future version. This dialect is superseded by the external "
+ "dialect https://github.com/gordthompson/sqlalchemy-sybase.",
+ version="1.4",
+ )
+ super(SybaseDialect, self).__init__(*args, **kwargs)
+
+ def _get_default_schema_name(self, connection):
+ return connection.scalar(
+ text("SELECT user_name() as user_name").columns(username=Unicode)
+ )
+
+ def initialize(self, connection):
+ super(SybaseDialect, self).initialize(connection)
+ if (
+ self.server_version_info is not None
+ and self.server_version_info < (15,)
+ ):
+ self.max_identifier_length = 30
+ else:
+ self.max_identifier_length = 255
+
+ def get_table_id(self, connection, table_name, schema=None, **kw):
+ """Fetch the id for schema.table_name.
+
+ Several reflection methods require the table id. The idea behind
+ this method is that the id can be fetched once and cached for
+ subsequent calls.
+
+ """
+
+ table_id = None
+ if schema is None:
+ schema = self.default_schema_name
+
+ TABLEID_SQL = text(
+ """
+ SELECT o.id AS id
+ FROM sysobjects o JOIN sysusers u ON o.uid=u.uid
+ WHERE u.name = :schema_name
+ AND o.name = :table_name
+ AND o.type in ('U', 'V')
+ """
+ )
+
+ if util.py2k:
+ if isinstance(schema, unicode): # noqa
+ schema = schema.encode("ascii")
+ if isinstance(table_name, unicode): # noqa
+ table_name = table_name.encode("ascii")
+ result = connection.execute(
+ TABLEID_SQL, schema_name=schema, table_name=table_name
+ )
+ table_id = result.scalar()
+ if table_id is None:
+ raise exc.NoSuchTableError(table_name)
+ return table_id
+
+ @reflection.cache
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ COLUMN_SQL = text(
+ """
+ SELECT col.name AS name,
+ t.name AS type,
+ (col.status & 8) AS nullable,
+ (col.status & 128) AS autoincrement,
+ com.text AS 'default',
+ col.prec AS precision,
+ col.scale AS scale,
+ col.length AS length
+ FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON
+ col.cdefault = com.id
+ WHERE col.usertype = t.usertype
+ AND col.id = :table_id
+ ORDER BY col.colid
+ """
+ )
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+
+ columns = []
+ for (
+ name,
+ type_,
+ nullable,
+ autoincrement,
+ default_,
+ precision,
+ scale,
+ length,
+ ) in results:
+ col_info = self._get_column_info(
+ name,
+ type_,
+ bool(nullable),
+ bool(autoincrement),
+ default_,
+ precision,
+ scale,
+ length,
+ )
+ columns.append(col_info)
+
+ return columns
+
+ def _get_column_info(
+ self,
+ name,
+ type_,
+ nullable,
+ autoincrement,
+ default,
+ precision,
+ scale,
+ length,
+ ):
+
+ coltype = self.ischema_names.get(type_, None)
+
+ kwargs = {}
+
+ if coltype in (NUMERIC, DECIMAL):
+ args = (precision, scale)
+ elif coltype == FLOAT:
+ args = (precision,)
+ elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR):
+ args = (length,)
+ else:
+ args = ()
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ # is this necessary
+ # if is_array:
+ # coltype = ARRAY(coltype)
+ else:
+ util.warn(
+ "Did not recognize type '%s' of column '%s'" % (type_, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ if default:
+ default = default.replace("DEFAULT", "").strip()
+ default = re.sub("^'(.*)'$", lambda m: m.group(1), default)
+ else:
+ default = None
+
+ column_info = dict(
+ name=name,
+ type=coltype,
+ nullable=nullable,
+ default=default,
+ autoincrement=autoincrement,
+ )
+ return column_info
+
+ @reflection.cache
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ table_cache = {}
+ column_cache = {}
+ foreign_keys = []
+
+ table_cache[table_id] = {"name": table_name, "schema": schema}
+
+ COLUMN_SQL = text(
+ """
+ SELECT c.colid AS id, c.name AS name
+ FROM syscolumns c
+ WHERE c.id = :table_id
+ """
+ )
+
+ results = connection.execute(COLUMN_SQL, table_id=table_id)
+ columns = {}
+ for col in results:
+ columns[col["id"]] = col["name"]
+ column_cache[table_id] = columns
+
+ REFCONSTRAINT_SQL = text(
+ """
+ SELECT o.name AS name, r.reftabid AS reftable_id,
+ r.keycnt AS 'count',
+ r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3,
+ r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6,
+ r.fokey7 AS fokey7, r.fokey8 AS fokey8, r.fokey9 AS fokey9,
+ r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12,
+ r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15,
+ r.fokey16 AS fokey16,
+ r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3,
+ r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6,
+ r.refkey7 AS refkey7, r.refkey8 AS refkey8, r.refkey9 AS refkey9,
+ r.refkey10 AS refkey10, r.refkey11 AS refkey11,
+ r.refkey12 AS refkey12, r.refkey13 AS refkey13,
+ r.refkey14 AS refkey14, r.refkey15 AS refkey15,
+ r.refkey16 AS refkey16
+ FROM sysreferences r JOIN sysobjects o on r.tableid = o.id
+ WHERE r.tableid = :table_id
+ """
+ )
+ referential_constraints = connection.execute(
+ REFCONSTRAINT_SQL, table_id=table_id
+ ).fetchall()
+
+ REFTABLE_SQL = text(
+ """
+ SELECT o.name AS name, u.name AS 'schema'
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE o.id = :table_id
+ """
+ )
+
+ for r in referential_constraints:
+ reftable_id = r["reftable_id"]
+
+ if reftable_id not in table_cache:
+ c = connection.execute(REFTABLE_SQL, table_id=reftable_id)
+ reftable = c.fetchone()
+ c.close()
+ table_info = {"name": reftable["name"], "schema": None}
+ if (
+ schema is not None
+ or reftable["schema"] != self.default_schema_name
+ ):
+ table_info["schema"] = reftable["schema"]
+
+ table_cache[reftable_id] = table_info
+ results = connection.execute(COLUMN_SQL, table_id=reftable_id)
+ reftable_columns = {}
+ for col in results:
+ reftable_columns[col["id"]] = col["name"]
+ column_cache[reftable_id] = reftable_columns
+
+ reftable = table_cache[reftable_id]
+ reftable_columns = column_cache[reftable_id]
+
+ constrained_columns = []
+ referred_columns = []
+ for i in range(1, r["count"] + 1):
+ constrained_columns.append(columns[r["fokey%i" % i]])
+ referred_columns.append(reftable_columns[r["refkey%i" % i]])
+
+ fk_info = {
+ "constrained_columns": constrained_columns,
+ "referred_schema": reftable["schema"],
+ "referred_table": reftable["name"],
+ "referred_columns": referred_columns,
+ "name": r["name"],
+ }
+
+ foreign_keys.append(fk_info)
+
+ return foreign_keys
+
+ @reflection.cache
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ INDEX_SQL = text(
+ """
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ (i.status & 0x2) AS 'unique',
+ index_col(object_name(i.id), i.indid, 1) AS col_1,
+ index_col(object_name(i.id), i.indid, 2) AS col_2,
+ index_col(object_name(i.id), i.indid, 3) AS col_3,
+ index_col(object_name(i.id), i.indid, 4) AS col_4,
+ index_col(object_name(i.id), i.indid, 5) AS col_5,
+ index_col(object_name(i.id), i.indid, 6) AS col_6,
+ index_col(object_name(i.id), i.indid, 7) AS col_7,
+ index_col(object_name(i.id), i.indid, 8) AS col_8,
+ index_col(object_name(i.id), i.indid, 9) AS col_9,
+ index_col(object_name(i.id), i.indid, 10) AS col_10,
+ index_col(object_name(i.id), i.indid, 11) AS col_11,
+ index_col(object_name(i.id), i.indid, 12) AS col_12,
+ index_col(object_name(i.id), i.indid, 13) AS col_13,
+ index_col(object_name(i.id), i.indid, 14) AS col_14,
+ index_col(object_name(i.id), i.indid, 15) AS col_15,
+ index_col(object_name(i.id), i.indid, 16) AS col_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 0
+ AND i.indid BETWEEN 1 AND 254
+ """
+ )
+
+ results = connection.execute(INDEX_SQL, table_id=table_id)
+ indexes = []
+ for r in results:
+ column_names = []
+ for i in range(1, r["count"]):
+ column_names.append(r["col_%i" % (i,)])
+ index_info = {
+ "name": r["name"],
+ "unique": bool(r["unique"]),
+ "column_names": column_names,
+ }
+ indexes.append(index_info)
+
+ return indexes
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ table_id = self.get_table_id(
+ connection, table_name, schema, info_cache=kw.get("info_cache")
+ )
+
+ PK_SQL = text(
+ """
+ SELECT object_name(i.id) AS table_name,
+ i.keycnt AS 'count',
+ i.name AS name,
+ index_col(object_name(i.id), i.indid, 1) AS pk_1,
+ index_col(object_name(i.id), i.indid, 2) AS pk_2,
+ index_col(object_name(i.id), i.indid, 3) AS pk_3,
+ index_col(object_name(i.id), i.indid, 4) AS pk_4,
+ index_col(object_name(i.id), i.indid, 5) AS pk_5,
+ index_col(object_name(i.id), i.indid, 6) AS pk_6,
+ index_col(object_name(i.id), i.indid, 7) AS pk_7,
+ index_col(object_name(i.id), i.indid, 8) AS pk_8,
+ index_col(object_name(i.id), i.indid, 9) AS pk_9,
+ index_col(object_name(i.id), i.indid, 10) AS pk_10,
+ index_col(object_name(i.id), i.indid, 11) AS pk_11,
+ index_col(object_name(i.id), i.indid, 12) AS pk_12,
+ index_col(object_name(i.id), i.indid, 13) AS pk_13,
+ index_col(object_name(i.id), i.indid, 14) AS pk_14,
+ index_col(object_name(i.id), i.indid, 15) AS pk_15,
+ index_col(object_name(i.id), i.indid, 16) AS pk_16
+ FROM sysindexes i, sysobjects o
+ WHERE o.id = i.id
+ AND o.id = :table_id
+ AND (i.status & 2048) = 2048
+ AND i.indid BETWEEN 1 AND 254
+ """
+ )
+
+ results = connection.execute(PK_SQL, table_id=table_id)
+ pks = results.fetchone()
+ results.close()
+
+ constrained_columns = []
+ if pks:
+ for i in range(1, pks["count"] + 1):
+ constrained_columns.append(pks["pk_%i" % (i,)])
+ return {
+ "constrained_columns": constrained_columns,
+ "name": pks["name"],
+ }
+ else:
+ return {"constrained_columns": [], "name": None}
+
+ @reflection.cache
+ def get_schema_names(self, connection, **kw):
+
+ SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u")
+
+ schemas = connection.execute(SCHEMA_SQL)
+
+ return [s["name"] for s in schemas]
+
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ TABLE_SQL = text(
+ """
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'U'
+ """
+ )
+
+ if util.py2k:
+ if isinstance(schema, unicode): # noqa
+ schema = schema.encode("ascii")
+
+ tables = connection.execute(TABLE_SQL, schema_name=schema)
+
+ return [t["name"] for t in tables]
+
+ @reflection.cache
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ VIEW_DEF_SQL = text(
+ """
+ SELECT c.text
+ FROM syscomments c JOIN sysobjects o ON c.id = o.id
+ WHERE o.name = :view_name
+ AND o.type = 'V'
+ """
+ )
+
+ if util.py2k:
+ if isinstance(view_name, unicode): # noqa
+ view_name = view_name.encode("ascii")
+
+ view = connection.execute(VIEW_DEF_SQL, view_name=view_name)
+
+ return view.scalar()
+
+ @reflection.cache
+ def get_view_names(self, connection, schema=None, **kw):
+ if schema is None:
+ schema = self.default_schema_name
+
+ VIEW_SQL = text(
+ """
+ SELECT o.name AS name
+ FROM sysobjects o JOIN sysusers u ON o.uid = u.uid
+ WHERE u.name = :schema_name
+ AND o.type = 'V'
+ """
+ )
+
+ if util.py2k:
+ if isinstance(schema, unicode): # noqa
+ schema = schema.encode("ascii")
+ views = connection.execute(VIEW_SQL, schema_name=schema)
+
+ return [v["name"] for v in views]
+
+ def has_table(self, connection, table_name, schema=None):
+ self._ensure_has_table_connection(connection)
+
+ try:
+ self.get_table_id(connection, table_name, schema)
+ except exc.NoSuchTableError:
+ return False
+ else:
+ return True
diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py
new file mode 100644
index 0000000..fe5a614
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py
@@ -0,0 +1,34 @@
+# sybase/mxodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+
+.. dialect:: sybase+mxodbc
+ :name: mxODBC
+ :dbapi: mxodbc
+ :connectstring: sybase+mxodbc://<username>:<password>@<dsnname>
+ :url: https://www.egenix.com/
+
+.. note::
+
+ This dialect is a stub only and is likely non-functional at this time.
+
+"""
+from sqlalchemy.connectors.mxodbc import MxODBCConnector
+from sqlalchemy.dialects.sybase.base import SybaseDialect
+from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
+
+
+class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
+ pass
+
+
+class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_mxodbc
+ supports_statement_cache = True
+
+
+dialect = SybaseDialect_mxodbc
diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py
new file mode 100644
index 0000000..f408e8f
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py
@@ -0,0 +1,89 @@
+# sybase/pyodbc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: sybase+pyodbc
+ :name: PyODBC
+ :dbapi: pyodbc
+ :connectstring: sybase+pyodbc://<username>:<password>@<dsnname>[/<database>]
+ :url: https://pypi.org/project/pyodbc/
+
+Unicode Support
+---------------
+
+The pyodbc driver currently supports usage of these Sybase types with
+Unicode or multibyte strings::
+
+ CHAR
+ NCHAR
+ NVARCHAR
+ TEXT
+ VARCHAR
+
+Currently *not* supported are::
+
+ UNICHAR
+ UNITEXT
+ UNIVARCHAR
+
+""" # noqa
+
+import decimal
+
+from sqlalchemy import processors
+from sqlalchemy import types as sqltypes
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy.dialects.sybase.base import SybaseDialect
+from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
+
+
+class _SybNumeric_pyodbc(sqltypes.Numeric):
+ """Turns Decimals with adjusted() < -6 into floats.
+
+ It's not yet known how to get decimals with many
+ significant digits or very large adjusted() into Sybase
+ via pyodbc.
+
+ """
+
+ def bind_processor(self, dialect):
+ super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)
+
+ def process(value):
+ if self.asdecimal and isinstance(value, decimal.Decimal):
+
+ if value.adjusted() < -6:
+ return processors.to_float(value)
+
+ if super_process:
+ return super_process(value)
+ else:
+ return value
+
+ return process
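+
+ # Worked example (sketch): Decimal("0.0000001").adjusted() == -7, which
+ # is < -6, so the value is bound as the float 1e-07 rather than a
+ # Decimal; Decimal("123.45").adjusted() == 2, so it goes through the
+ # normal Numeric bind processing.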
+
+
+class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
+ def set_ddl_autocommit(self, connection, value):
+ if value:
+ connection.autocommit = True
+ else:
+ connection.autocommit = False
+
+
+class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
+ execution_ctx_cls = SybaseExecutionContext_pyodbc
+ supports_statement_cache = True
+
+ colspecs = {sqltypes.Numeric: _SybNumeric_pyodbc}
+
+ @classmethod
+ def dbapi(cls):
+ return PyODBCConnector.dbapi()
+
+
+dialect = SybaseDialect_pyodbc
diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py
new file mode 100644
index 0000000..4c96aac
--- /dev/null
+++ b/lib/sqlalchemy/dialects/sybase/pysybase.py
@@ -0,0 +1,106 @@
+# sybase/pysybase.py
+# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+.. dialect:: sybase+pysybase
+ :name: Python-Sybase
+ :dbapi: Sybase
+ :connectstring: sybase+pysybase://<username>:<password>@<dsn>/[database name]
+ :url: https://python-sybase.sourceforge.net/
+
+Unicode Support
+---------------
+
+The python-sybase driver does not appear to support non-ASCII strings of any
+kind at this time.
+
+""" # noqa
+
+from sqlalchemy import processors
+from sqlalchemy import types as sqltypes
+from sqlalchemy.dialects.sybase.base import SybaseDialect
+from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
+from sqlalchemy.dialects.sybase.base import SybaseSQLCompiler
+
+
+class _SybNumeric(sqltypes.Numeric):
+ def result_processor(self, dialect, type_):
+ if not self.asdecimal:
+ return processors.to_float
+ else:
+ return sqltypes.Numeric.result_processor(self, dialect, type_)
+
+
+class SybaseExecutionContext_pysybase(SybaseExecutionContext):
+ def set_ddl_autocommit(self, dbapi_connection, value):
+ if value:
+ # call commit() on the Sybase connection directly,
+ # to avoid any side effects of calling a Connection
+ # transactional method inside of pre_exec()
+ dbapi_connection.commit()
+
+ def pre_exec(self):
+ SybaseExecutionContext.pre_exec(self)
+
+ for param in self.parameters:
+ for key in list(param):
+ param["@" + key] = param[key]
+ del param[key]
+
+
+class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
+ def bindparam_string(self, name, **kw):
+ return "@" + name
+
+
+class SybaseDialect_pysybase(SybaseDialect):
+ driver = "pysybase"
+ execution_ctx_cls = SybaseExecutionContext_pysybase
+ statement_compiler = SybaseSQLCompiler_pysybase
+
+ supports_statement_cache = True
+
+ colspecs = {sqltypes.Numeric: _SybNumeric, sqltypes.Float: sqltypes.Float}
+
+ @classmethod
+ def dbapi(cls):
+ import Sybase
+
+ return Sybase
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args(username="user", password="passwd")
+
+ return ([opts.pop("host")], opts)
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ # calling python-sybase executemany yields:
+ # TypeError: string too long for buffer
+ for param in parameters:
+ cursor.execute(statement, param)
+
+ def _get_server_version_info(self, connection):
+ vers = connection.exec_driver_sql("select @@version_number").scalar()
+ # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0),
+ # (12, 5, 0, 0)
+ return (vers // 1000, vers % 1000 // 100, vers % 100 // 10, vers % 10)
+
+ def is_disconnect(self, e, connection, cursor):
+ if isinstance(
+ e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
+ ):
+ msg = str(e)
+ return (
+ "Unable to complete network request to host" in msg
+ or "Invalid connection state" in msg
+ or "Invalid cursor state" in msg
+ )
+ else:
+ return False
+
+
+dialect = SybaseDialect_pysybase
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
new file mode 100644
index 0000000..2437e17
--- /dev/null
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -0,0 +1,62 @@
+# engine/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQL connections, SQL execution and high-level DB-API interface.
+
+The engine package defines the basic components used to interface
+DB-API modules with higher-level statement construction,
+connection-management, execution and result contexts. The primary
+"entry point" class into this package is the Engine and its public
+constructor ``create_engine()``.
+
+"""
+
+from . import events
+from . import util
+from .base import Connection
+from .base import Engine
+from .base import NestedTransaction
+from .base import RootTransaction
+from .base import Transaction
+from .base import TwoPhaseTransaction
+from .create import create_engine
+from .create import engine_from_config
+from .cursor import BaseCursorResult
+from .cursor import BufferedColumnResultProxy
+from .cursor import BufferedColumnRow
+from .cursor import BufferedRowResultProxy
+from .cursor import CursorResult
+from .cursor import FullyBufferedResultProxy
+from .cursor import LegacyCursorResult
+from .cursor import ResultProxy
+from .interfaces import AdaptedConnection
+from .interfaces import Compiled
+from .interfaces import Connectable
+from .interfaces import CreateEnginePlugin
+from .interfaces import Dialect
+from .interfaces import ExceptionContext
+from .interfaces import ExecutionContext
+from .interfaces import TypeCompiler
+from .mock import create_mock_engine
+from .reflection import Inspector
+from .result import ChunkedIteratorResult
+from .result import FilterResult
+from .result import FrozenResult
+from .result import IteratorResult
+from .result import MappingResult
+from .result import MergedResult
+from .result import Result
+from .result import result_tuple
+from .result import ScalarResult
+from .row import BaseRow
+from .row import LegacyRow
+from .row import Row
+from .row import RowMapping
+from .url import make_url
+from .url import URL
+from .util import connection_memoize
+from ..sql import ddl
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
new file mode 100644
index 0000000..f126eb0
--- /dev/null
+++ b/lib/sqlalchemy/engine/base.py
@@ -0,0 +1,3450 @@
+# engine/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from __future__ import with_statement
+
+import contextlib
+import sys
+
+from .interfaces import Connectable
+from .interfaces import ExceptionContext
+from .util import _distill_params
+from .util import _distill_params_20
+from .util import TransactionalContext
+from .. import exc
+from .. import inspection
+from .. import log
+from .. import util
+from ..sql import compiler
+from ..sql import util as sql_util
+
+
+"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
+
+"""
+
+_EMPTY_EXECUTION_OPTS = util.immutabledict()
+
+
+class Connection(Connectable):
+ """Provides high-level functionality for a wrapped DB-API connection.
+
+ **This is the SQLAlchemy 1.x.x version** of the :class:`_engine.Connection`
+ class. For the :term:`2.0 style` version, which features some API
+ differences, see :class:`_future.Connection`.
+
+ The :class:`_engine.Connection` object is procured by calling
+ the :meth:`_engine.Engine.connect` method of the :class:`_engine.Engine`
+ object, and provides services for execution of SQL statements as well
+ as transaction control.
+
+ The Connection object is **not** thread-safe. While a Connection can be
+ shared among threads using properly synchronized access, it is still
+ possible that the underlying DBAPI connection may not support shared
+ access between threads. Check the DBAPI documentation for details.
+
+ The Connection object represents a single DBAPI connection checked out
+ from the connection pool. In this state, the connection pool has no effect
+ upon the connection, including its expiration or timeout state. For the
+ connection pool to properly manage connections, connections should be
+ returned to the connection pool (i.e. ``connection.close()``) whenever the
+ connection is not in use.
+
+ .. index::
+ single: thread safety; Connection
+
+ """
+
+ _is_future = False
+ _sqla_logger_namespace = "sqlalchemy.engine.Connection"
+
+ # used by sqlalchemy.engine.util.TransactionalContext
+ _trans_context_manager = None
+
+ def __init__(
+ self,
+ engine,
+ connection=None,
+ close_with_result=False,
+ _branch_from=None,
+ _execution_options=None,
+ _dispatch=None,
+ _has_events=None,
+ _allow_revalidate=True,
+ ):
+ """Construct a new Connection."""
+ self.engine = engine
+ self.dialect = engine.dialect
+ self.__branch_from = _branch_from
+
+ if _branch_from:
+ # branching is always "from" the root connection
+ assert _branch_from.__branch_from is None
+ self._dbapi_connection = connection
+ self._execution_options = _execution_options
+ self._echo = _branch_from._echo
+ self.should_close_with_result = False
+ self.dispatch = _dispatch
+ self._has_events = _branch_from._has_events
+ else:
+ self._dbapi_connection = (
+ connection
+ if connection is not None
+ else engine.raw_connection()
+ )
+
+ self._transaction = self._nested_transaction = None
+ self.__savepoint_seq = 0
+ self.__in_begin = False
+ self.should_close_with_result = close_with_result
+
+ self.__can_reconnect = _allow_revalidate
+ self._echo = self.engine._should_log_info()
+
+ if _has_events is None:
+ # if _has_events is sent explicitly as False,
+ # then don't join the dispatch of the engine; we don't
+ # want to handle any of the engine's events in that case.
+ self.dispatch = self.dispatch._join(engine.dispatch)
+ self._has_events = _has_events or (
+ _has_events is None and engine._has_events
+ )
+
+ assert not _execution_options
+ self._execution_options = engine._execution_options
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.engine_connect(self, _branch_from is not None)
+
+ @util.memoized_property
+ def _message_formatter(self):
+ if "logging_token" in self._execution_options:
+ token = self._execution_options["logging_token"]
+ return lambda msg: "[%s] %s" % (token, msg)
+ else:
+ return None
+
+ def _log_info(self, message, *arg, **kw):
+ fmt = self._message_formatter
+
+ if fmt:
+ message = fmt(message)
+
+ if log.STACKLEVEL:
+ kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET
+
+ self.engine.logger.info(message, *arg, **kw)
+
+ def _log_debug(self, message, *arg, **kw):
+ fmt = self._message_formatter
+
+ if fmt:
+ message = fmt(message)
+
+ if log.STACKLEVEL:
+ kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET
+
+ self.engine.logger.debug(message, *arg, **kw)
+
+ @property
+ def _schema_translate_map(self):
+ return self._execution_options.get("schema_translate_map", None)
+
+ def schema_for_object(self, obj):
+ """Return the schema name for the given schema item taking into
+ account the current schema translate map.
+
+ """
+
+ name = obj.schema
+ schema_translate_map = self._execution_options.get(
+ "schema_translate_map", None
+ )
+
+ if (
+ schema_translate_map
+ and name in schema_translate_map
+ and obj._use_schema_map
+ ):
+ return schema_translate_map[name]
+ else:
+ return name
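+
+ # Sketch of the effect (schema names are hypothetical): given
+ #
+ #   conn = engine.connect().execution_options(
+ #       schema_translate_map={None: "tenant_1", "archive": "archive_2022"})
+ #
+ # a Table with schema=None compiles against "tenant_1", a Table with
+ # schema="archive" compiles against "archive_2022", and any other schema
+ # passes through unchanged.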
+
+ def _branch(self):
+ """Return a new Connection which references this Connection's
+ engine and connection; but does not have close_with_result enabled,
+ and also whose close() method does nothing.
+
+ .. deprecated:: 1.4 the "branching" concept will be removed in
+ SQLAlchemy 2.0 as well as the "Connection.connect()" method which
+ is the only consumer for this.
+
+ The Core uses this very sparingly, only in the case of
+ custom SQL default functions that are to be INSERTed as the
+ primary key of a row where we need to get the value back, so we have
+ to invoke it distinctly - this is a very uncommon case.
+
+ Userland code accesses _branch() when the connect()
+ method is called. The branched connection
+ acts as much as possible like the parent, except that it stays
+ connected when a close() event occurs.
+
+ """
+ return self.engine._connection_cls(
+ self.engine,
+ self._dbapi_connection,
+ _branch_from=self.__branch_from if self.__branch_from else self,
+ _execution_options=self._execution_options,
+ _has_events=self._has_events,
+ _dispatch=self.dispatch,
+ )
+
+ def _generate_for_options(self):
+ """define connection method chaining behavior for execution_options"""
+
+ if self._is_future:
+ return self
+ else:
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = self.__dict__.copy()
+ return c
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ self.close()
+
+ def execution_options(self, **opt):
+ r""" Set non-SQL options for the connection which take effect
+ during execution.
+
+ For a "future" style connection, this method returns this same
+ :class:`_future.Connection` object with the new options added.
+
+ For a legacy connection, this method returns a copy of this
+ :class:`_engine.Connection` which references the same underlying DBAPI
+ connection, but also defines the given execution options which will
+ take effect for a call to
+ :meth:`execute`. As the new :class:`_engine.Connection` references the
+ same underlying resource, it's usually a good idea to ensure that
+ the copies will be discarded immediately, which is implicit if used
+ as in::
+
+ result = connection.execution_options(stream_results=True).\
+ execute(stmt)
+
+ Note that any key/value can be passed to
+ :meth:`_engine.Connection.execution_options`,
+ and it will be stored in the
+ ``_execution_options`` dictionary of the :class:`_engine.Connection`.
+ It
+ is suitable for usage by end-user schemes to communicate with
+ event listeners, for example.
+
+ The keywords that are currently recognized by SQLAlchemy itself
+ include all those listed under :meth:`.Executable.execution_options`,
+ as well as others that are specific to :class:`_engine.Connection`.
+
+ :param autocommit: Available on: Connection, statement.
+ When True, a COMMIT will be invoked after execution
+ when executed in 'autocommit' mode, i.e. when an explicit
+ transaction is not begun on the connection. Note that this
+ is **library level, not DBAPI level autocommit**. The DBAPI
+ connection will remain in a real transaction unless the
+ "AUTOCOMMIT" isolation level is used.
+
+ .. deprecated:: 1.4 The "autocommit" execution option is deprecated
+ and will be removed in SQLAlchemy 2.0. See
+ :ref:`migration_20_autocommit` for discussion.
+
+ :param compiled_cache: Available on: Connection.
+ A dictionary where :class:`.Compiled` objects
+ will be cached when the :class:`_engine.Connection`
+ compiles a clause
+ expression into a :class:`.Compiled` object. This dictionary will
+ supersede the statement cache that may be configured on the
+ :class:`_engine.Engine` itself. If set to None, caching
+ is disabled, even if the engine has a configured cache size.
+
+ Note that the ORM makes use of its own "compiled" caches for
+ some operations, including flush operations. The caching
+ used by the ORM internally supersedes a cache dictionary
+ specified here.
+
+ :param logging_token: Available on: :class:`_engine.Connection`,
+ :class:`_engine.Engine`.
+
+ Adds the specified string token surrounded by brackets in log
+ messages logged by the connection, i.e. the logging that's enabled
+ either via the :paramref:`_sa.create_engine.echo` flag or via the
+ ``logging.getLogger("sqlalchemy.engine")`` logger. This allows a
+ per-connection or per-sub-engine token to be available which is
+ useful for debugging concurrent connection scenarios.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`dbengine_logging_tokens` - usage example
+
+ :paramref:`_sa.create_engine.logging_name` - adds a name to the
+ name used by the Python logger object itself.
+
+ :param isolation_level: Available on: :class:`_engine.Connection`.
+
+ Set the transaction isolation level for the lifespan of this
+ :class:`_engine.Connection` object.
+ Valid values include those string
+ values accepted by the :paramref:`_sa.create_engine.isolation_level`
+ parameter passed to :func:`_sa.create_engine`. These levels are
+ semi-database specific; see individual dialect documentation for
+ valid levels.
+
+ The isolation level option applies the isolation level by emitting
+ statements on the DBAPI connection, and **necessarily affects the
+ original Connection object overall**, not just the copy that is
+ returned by the call to :meth:`_engine.Connection.execution_options`
+ method. The isolation level will remain at the given setting until
+ the DBAPI connection itself is returned to the connection pool, i.e.
+ the :meth:`_engine.Connection.close` method on the original
+ :class:`_engine.Connection` is called,
+ where an event handler will emit
+ additional statements on the DBAPI connection in order to revert the
+ isolation level change.
+
+ .. warning:: The ``isolation_level`` execution option should
+ **not** be used when a transaction is already established, that
+ is, the :meth:`_engine.Connection.begin`
+ method or similar has been
+ called. A database cannot change the isolation level on a
+ transaction in progress, and different DBAPIs and/or
+ SQLAlchemy dialects may implicitly roll back or commit
+ the transaction, or not affect the connection at all.
+
+ .. note:: The ``isolation_level`` execution option is implicitly
+ reset if the :class:`_engine.Connection` is invalidated, e.g. via
+ the :meth:`_engine.Connection.invalidate` method, or if a
+ disconnection error occurs. The new connection produced after
+ the invalidation will not have the isolation level re-applied
+ to it automatically.
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.isolation_level`
+ - set per :class:`_engine.Engine` isolation level
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :ref:`SQLite Transaction Isolation <sqlite_isolation_level>`
+
+ :ref:`PostgreSQL Transaction Isolation <postgresql_isolation_level>`
+
+ :ref:`MySQL Transaction Isolation <mysql_isolation_level>`
+
+ :ref:`SQL Server Transaction Isolation <mssql_isolation_level>`
+
+ :ref:`session_transaction_isolation` - for the ORM
+
+ :param no_parameters: When ``True``, if the final parameter
+ list or dictionary is totally empty, will invoke the
+ statement on the cursor as ``cursor.execute(statement)``,
+ not passing the parameter collection at all.
+ Some DBAPIs such as psycopg2 and mysql-python consider
+ percent signs as significant only when parameters are
+ present; this option allows code to generate SQL
+ containing percent signs (and possibly other characters)
+ that is neutral regarding whether it's executed by the DBAPI
+ or piped into a script that's later invoked by
+ command line tools.
+
+ :param stream_results: Available on: Connection, statement.
+ Indicate to the dialect that results should be
+ "streamed" and not pre-buffered, if possible. For backends
+ such as PostgreSQL, MySQL and MariaDB, this indicates the use of
+ a "server side cursor" as opposed to a client side cursor.
+ Other backends such as that of Oracle may already use server
+ side cursors by default.
+
+ The usage of
+ :paramref:`_engine.Connection.execution_options.stream_results` is
+ usually combined with setting a fixed number of rows to be fetched
+ in batches, to allow for efficient iteration of database rows while
+ at the same time not loading all result rows into memory at once;
+ this can be configured on a :class:`_engine.Result` object using the
+ :meth:`_engine.Result.yield_per` method, after execution has
+ returned a new :class:`_engine.Result`. If
+ :meth:`_engine.Result.yield_per` is not used,
+ the :paramref:`_engine.Connection.execution_options.stream_results`
+ mode of operation will instead use a dynamically sized buffer
+ which buffers sets of rows at a time, growing on each batch
+ based on a fixed growth size up until a limit which may
+ be configured using the
+ :paramref:`_engine.Connection.execution_options.max_row_buffer`
+ parameter.
+
+ When using the ORM to fetch ORM mapped objects from a result,
+ :meth:`_engine.Result.yield_per` should always be used with
+ :paramref:`_engine.Connection.execution_options.stream_results`,
+ so that the ORM does not fetch all rows into new ORM objects at once.
+
+ For typical use, the
+ :paramref:`_engine.Connection.execution_options.yield_per` execution
+ option should be preferred, which sets up both
+ :paramref:`_engine.Connection.execution_options.stream_results` and
+ :meth:`_engine.Result.yield_per` at once. This option is supported
+ both at a core level by :class:`_engine.Connection` as well as by the
+ ORM :class:`_engine.Session`; the latter is described at
+ :ref:`orm_queryguide_yield_per`.
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - background on
+ :paramref:`_engine.Connection.execution_options.stream_results`
+
+ :paramref:`_engine.Connection.execution_options.max_row_buffer`
+
+ :paramref:`_engine.Connection.execution_options.yield_per`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+ describing the ORM version of ``yield_per``
+
+ :param max_row_buffer: Available on: :class:`_engine.Connection`,
+ :class:`_sql.Executable`. Sets a maximum
+ buffer size to use when the
+ :paramref:`_engine.Connection.execution_options.stream_results`
+ execution option is used on a backend that supports server side
+ cursors. The default value if not specified is 1000.
+
+ .. seealso::
+
+ :paramref:`_engine.Connection.execution_options.stream_results`
+
+ :ref:`engine_stream_results`
+
+
+ :param yield_per: Available on: :class:`_engine.Connection`,
+ :class:`_sql.Executable`. Integer value applied which will
+ set the :paramref:`_engine.Connection.execution_options.stream_results`
+ execution option and invoke :meth:`_engine.Result.yield_per`
+ automatically at once. Allows equivalent functionality as
+ is present when using this parameter with the ORM.
+
+ .. versionadded:: 1.4.40
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - background and examples
+ on using server side cursors with Core.
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+ describing the ORM version of ``yield_per``
+
+ :param schema_translate_map: Available on: :class:`_engine.Connection`,
+ :class:`_engine.Engine`, :class:`_sql.Executable`.
+
+ A dictionary mapping schema names to schema names, that will be
+ applied to the :paramref:`_schema.Table.schema` element of each
+ :class:`_schema.Table`
+ encountered when SQL or DDL expression elements
+ are compiled into strings; the resulting schema name will be
+ converted based on presence in the map of the original name.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`schema_translating`
+
+ .. seealso::
+
+ :meth:`_engine.Engine.execution_options`
+
+ :meth:`.Executable.execution_options`
+
+ :meth:`_engine.Connection.get_execution_options`
+
+
+ """ # noqa
+ c = self._generate_for_options()
+ c._execution_options = c._execution_options.union(opt)
+ if self._has_events or self.engine._has_events:
+ self.dispatch.set_connection_execution_options(c, opt)
+ self.dialect.set_connection_execution_options(c, opt)
+ return c
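+
+ # Usage sketch (illustrative): layering per-connection options on top of
+ # the engine-wide defaults.
+ #
+ #   with engine.connect() as conn:
+ #       reporting_conn = conn.execution_options(
+ #           isolation_level="AUTOCOMMIT", logging_token="reporting")
+ #       reporting_conn.execute(some_table.select())
+ #
+ # For a legacy Connection this returns a copy that shares the same DBAPI
+ # connection; as noted above, isolation_level still affects the original
+ # connection overall.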
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+ """
+ return self._execution_options
+
+ @property
+ def closed(self):
+ """Return True if this connection is closed."""
+
+ # note this is independent for a "branched" connection vs.
+ # the base
+
+ return self._dbapi_connection is None and not self.__can_reconnect
+
+ @property
+ def invalidated(self):
+ """Return True if this connection was invalidated."""
+
+ # prior to 1.4, "invalid" was stored as a state independent of
+ # "closed", meaning an invalidated connection could be "closed",
+ # the _dbapi_connection would be None and closed=True, yet the
+ # "invalid" flag would stay True. This meant that there were
+ # three separate states (open/valid, closed/valid, closed/invalid)
+ # when there is really no reason for that; a connection that's
+ # "closed" does not need to be "invalid". So the state is now
+ # represented by the two facts alone.
+
+ if self.__branch_from:
+ return self.__branch_from.invalidated
+
+ return self._dbapi_connection is None and not self.closed
+
+ @property
+ def connection(self):
+ """The underlying DB-API connection managed by this Connection.
+
+ This is a SQLAlchemy connection-pool proxied connection
+ which then has the attribute
+ :attr:`_pool._ConnectionFairy.dbapi_connection` that refers to the
+ actual driver connection.
+
+ .. seealso::
+
+
+ :ref:`dbapi_connections`
+
+ """
+
+ if self._dbapi_connection is None:
+ try:
+ return self._revalidate_connection()
+ except (exc.PendingRollbackError, exc.ResourceClosedError):
+ raise
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+ else:
+ return self._dbapi_connection
+
+ def get_isolation_level(self):
+ """Return the current isolation level assigned to this
+ :class:`_engine.Connection`.
+
+ This will typically be the default isolation level as determined
+ by the dialect, unless if the
+ :paramref:`.Connection.execution_options.isolation_level`
+ feature has been used to alter the isolation level on a
+ per-:class:`_engine.Connection` basis.
+
+ This attribute will typically perform a live SQL operation in order
+ to procure the current isolation level, so the value returned is the
+ actual level on the underlying DBAPI connection regardless of how
+ this state was set. Compare to the
+ :attr:`_engine.Connection.default_isolation_level` accessor
+ which returns the dialect-level setting without performing a SQL
+ query.
+
+ .. versionadded:: 0.9.9
+
+ .. seealso::
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`_sa.create_engine.isolation_level`
+ - set per :class:`_engine.Engine` isolation level
+
+ :paramref:`.Connection.execution_options.isolation_level`
+ - set per :class:`_engine.Connection` isolation level
+
+ """
+ try:
+ return self.dialect.get_isolation_level(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ @property
+ def default_isolation_level(self):
+ """The default isolation level assigned to this
+ :class:`_engine.Connection`.
+
+ This is the isolation level setting that the
+ :class:`_engine.Connection`
+ has when first procured via the :meth:`_engine.Engine.connect` method.
+ This level stays in place until the
+ :paramref:`.Connection.execution_options.isolation_level` is used
+ to change the setting on a per-:class:`_engine.Connection` basis.
+
+ Unlike :meth:`_engine.Connection.get_isolation_level`,
+ this attribute is set
+ ahead of time from the first connection procured by the dialect,
+ so a SQL query is not invoked when this accessor is called.
+
+ .. versionadded:: 0.9.9
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :paramref:`_sa.create_engine.isolation_level`
+ - set per :class:`_engine.Engine` isolation level
+
+ :paramref:`.Connection.execution_options.isolation_level`
+ - set per :class:`_engine.Connection` isolation level
+
+ """
+ return self.dialect.default_isolation_level
+
+ def _invalid_transaction(self):
+ if self.invalidated:
+ raise exc.PendingRollbackError(
+ "Can't reconnect until invalid %stransaction is rolled "
+ "back."
+ % (
+ "savepoint "
+ if self._nested_transaction is not None
+ else ""
+ ),
+ code="8s2b",
+ )
+ else:
+ assert not self._is_future
+ raise exc.PendingRollbackError(
+ "This connection is on an inactive %stransaction. "
+ "Please rollback() fully before proceeding."
+ % (
+ "savepoint "
+ if self._nested_transaction is not None
+ else ""
+ ),
+ code="8s2a",
+ )
+
+ def _revalidate_connection(self):
+ if self.__branch_from:
+ return self.__branch_from._revalidate_connection()
+ if self.__can_reconnect and self.invalidated:
+ if self._transaction is not None:
+ self._invalid_transaction()
+ self._dbapi_connection = self.engine.raw_connection(
+ _connection=self
+ )
+ return self._dbapi_connection
+ raise exc.ResourceClosedError("This Connection is closed")
+
+ @property
+ def _still_open_and_dbapi_connection_is_valid(self):
+ return self._dbapi_connection is not None and getattr(
+ self._dbapi_connection, "is_valid", False
+ )
+
+ @property
+ def info(self):
+ """Info dictionary associated with the underlying DBAPI connection
+ referred to by this :class:`_engine.Connection`, allowing user-defined
+ data to be associated with the connection.
+
+ The data here will follow along with the DBAPI connection including
+ after it is returned to the connection pool and used again
+ in subsequent instances of :class:`_engine.Connection`.
+
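+ For example, an illustrative sketch (``engine`` and the
+ ``"request_id"`` key are assumed names for this example)::
+
+     with engine.connect() as conn:
+         conn.info["request_id"] = 42
+
+     # a later Connection that checks out the same pooled DBAPI
+     # connection will see the same entry
+     with engine.connect() as conn:
+         request_id = conn.info.get("request_id")
+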
+ """
+
+ return self.connection.info
+
+ @util.deprecated_20(":meth:`.Connection.connect`")
+ def connect(self, close_with_result=False):
+ """Returns a branched version of this :class:`_engine.Connection`.
+
+ The :meth:`_engine.Connection.close` method on the returned
+ :class:`_engine.Connection` can be called and this
+ :class:`_engine.Connection` will remain open.
+
+ This method provides usage symmetry with
+ :meth:`_engine.Engine.connect`, including for usage
+ with context managers.
+
+ """
+
+ return self._branch()
+
+ def invalidate(self, exception=None):
+ """Invalidate the underlying DBAPI connection associated with
+ this :class:`_engine.Connection`.
+
+ An attempt will be made to close the underlying DBAPI connection
+ immediately; however if this operation fails, the error is logged
+ but not raised. The connection is then discarded whether or not
+ close() succeeded.
+
+ Upon the next use (where "use" typically means using the
+ :meth:`_engine.Connection.execute` method or similar),
+ this :class:`_engine.Connection` will attempt to
+ procure a new DBAPI connection using the services of the
+ :class:`_pool.Pool` as a source of connectivity (e.g.
+ a "reconnection").
+
+ If a transaction was in progress (e.g. the
+ :meth:`_engine.Connection.begin` method has been called) when
+ :meth:`_engine.Connection.invalidate` method is called, at the DBAPI
+ level all state associated with this transaction is lost, as
+ the DBAPI connection is closed. The :class:`_engine.Connection`
+ will not allow a reconnection to proceed until the
+ :class:`.Transaction` object is ended, by calling the
+ :meth:`.Transaction.rollback` method; until that point, any attempt at
+ continuing to use the :class:`_engine.Connection` will raise an
+ :class:`~sqlalchemy.exc.InvalidRequestError`.
+ This is to prevent applications from accidentally
+ continuing ongoing transactional operations despite the
+ fact that the transaction has been lost due to an
+ invalidation.
+
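+ A minimal sketch of the reconnect-on-next-use behavior (``engine`` is
+ an assumed :class:`_engine.Engine`, ``text`` is ``sqlalchemy.text``,
+ and no transaction is in progress here)::
+
+     conn = engine.connect()
+     conn.invalidate()
+
+     # the next use transparently procures a new DBAPI connection
+     # from the pool
+     conn.execute(text("SELECT 1"))
+     conn.close()
+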
+ The :meth:`_engine.Connection.invalidate` method,
+ just like auto-invalidation,
+ will at the connection pool level invoke the
+ :meth:`_events.PoolEvents.invalidate` event.
+
+ :param exception: an optional ``Exception`` instance that's the
+ reason for the invalidation.  It is passed along to event handlers
+ and logging functions.
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
+ """
+
+ if self.__branch_from:
+ return self.__branch_from.invalidate(exception=exception)
+
+ if self.invalidated:
+ return
+
+ if self.closed:
+ raise exc.ResourceClosedError("This Connection is closed")
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self._dbapi_connection.invalidate(exception)
+ self._dbapi_connection = None
+
+ def detach(self):
+ """Detach the underlying DB-API connection from its connection pool.
+
+ E.g.::
+
+ with engine.connect() as conn:
+ conn.detach()
+ conn.execute(text("SET search_path TO schema1, schema2"))
+
+ # work with connection
+
+ # connection is fully closed (since we used "with:", can
+ # also call .close())
+
+ This :class:`_engine.Connection` instance will remain usable.
+ When closed
+ (or exited from a context manager context as above),
+ the DB-API connection will be literally closed and not
+ returned to its originating pool.
+
+ This method can be used to insulate the rest of an application
+ from a modified state on a connection (such as a transaction
+ isolation level or similar).
+
+ """
+
+ self._dbapi_connection.detach()
+
+ def _autobegin(self):
+ self.begin()
+
+ def begin(self):
+ """Begin a transaction and return a transaction handle.
+
+ The returned object is an instance of :class:`.Transaction`.
+ This object represents the "scope" of the transaction,
+ which completes when either the :meth:`.Transaction.rollback`
+ or :meth:`.Transaction.commit` method is called.
+
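+ A typical usage sketch, relying on the context manager support of the
+ returned :class:`.Transaction` (``engine`` and ``some_table`` are
+ assumed names for this example)::
+
+     with engine.connect() as conn:
+         with conn.begin():
+             conn.execute(some_table.insert(), {"x": 1})
+         # committed here if no exception was raised,
+         # rolled back otherwise
+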
+ .. tip::
+
+ The :meth:`_engine.Connection.begin` method is invoked when using
+ the :meth:`_engine.Engine.begin` context manager method as well.
+ All documentation that refers to behaviors specific to the
+ :meth:`_engine.Connection.begin` method also apply to use of the
+ :meth:`_engine.Engine.begin` method.
+
+ Legacy use: nested calls to :meth:`.begin` on the same
+ :class:`_engine.Connection` will return new :class:`.Transaction`
+ objects that represent an emulated transaction within the scope of the
+ enclosing transaction, that is::
+
+ trans = conn.begin() # outermost transaction
+ trans2 = conn.begin() # "nested"
+ trans2.commit() # does nothing
+ trans.commit() # actually commits
+
+ Calls to :meth:`.Transaction.commit` only have an effect
+ when invoked via the outermost :class:`.Transaction` object, though the
+ :meth:`.Transaction.rollback` method of any of the
+ :class:`.Transaction` objects will roll back the
+ transaction.
+
+ .. tip::
+
+ The above "nesting" behavior is a legacy behavior specific to
+ :term:`1.x style` use and will be removed in SQLAlchemy 2.0. For
+ notes on :term:`2.0 style` use, see
+ :meth:`_future.Connection.begin`.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.begin_nested` - use a SAVEPOINT
+
+ :meth:`_engine.Connection.begin_twophase` -
+ use a two phase /XID transaction
+
+ :meth:`_engine.Engine.begin` - context manager available from
+ :class:`_engine.Engine`
+
+ """
+ if self._is_future:
+ assert not self.__branch_from
+ elif self.__branch_from:
+ return self.__branch_from.begin()
+
+ if self.__in_begin:
+ # for dialects that emit SQL within the process of
+ # dialect.do_begin() or dialect.do_begin_twophase(), this
+ # flag prevents "autobegin" from being emitted within that
+ # process, while allowing self._transaction to remain at None
+ # until it's complete.
+ return
+ elif self._transaction is None:
+ self._transaction = RootTransaction(self)
+ return self._transaction
+ else:
+ if self._is_future:
+ raise exc.InvalidRequestError(
+ "This connection has already initialized a SQLAlchemy "
+ "Transaction() object via begin() or autobegin; can't "
+ "call begin() here unless rollback() or commit() "
+ "is called first."
+ )
+ else:
+ return MarkerTransaction(self)
+
+ def begin_nested(self):
+ """Begin a nested transaction (i.e. SAVEPOINT) and return a
+ transaction handle, assuming an outer transaction is already
+ established.
+
+ Nested transactions require SAVEPOINT support in the
+ underlying database. Any transaction in the hierarchy may
+ ``commit`` and ``rollback``, however the outermost transaction
+ still controls the overall ``commit`` or ``rollback`` of the
+ transaction as a whole.
+
+ The legacy form of :meth:`_engine.Connection.begin_nested` method has
+ alternate behaviors based on whether or not the
+ :meth:`_engine.Connection.begin` method was called previously. If
+ :meth:`_engine.Connection.begin` was not called, then this method will
+ behave the same as the :meth:`_engine.Connection.begin` method and
+ return a :class:`.RootTransaction` object that begins and commits a
+ real transaction - **no savepoint is invoked**. If
+ :meth:`_engine.Connection.begin` **has** been called, and a
+ :class:`.RootTransaction` is already established, then this method
+ returns an instance of :class:`.NestedTransaction` which will invoke
+ and manage the scope of a SAVEPOINT.
+
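+ An illustrative sketch of the SAVEPOINT form (``engine`` and
+ ``some_table`` are assumed names; :meth:`_engine.Connection.begin` is
+ called first so that a real SAVEPOINT is used)::
+
+     with engine.connect() as conn:
+         with conn.begin():
+             conn.execute(some_table.insert(), {"x": 1})
+
+             savepoint = conn.begin_nested()
+             conn.execute(some_table.insert(), {"x": 2})
+             savepoint.rollback()  # ROLLBACK TO SAVEPOINT; x=2 discarded
+
+         # the enclosing transaction commits with x=1 only
+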
+ .. tip::
+
+ The above mentioned behavior of
+ :meth:`_engine.Connection.begin_nested` is a legacy behavior
+ specific to :term:`1.x style` use. In :term:`2.0 style` use, the
+ :meth:`_future.Connection.begin_nested` method instead autobegins
+ the outer transaction that can be committed using
+ "commit-as-you-go" style; see
+ :meth:`_future.Connection.begin_nested` for migration details.
+
+ .. versionchanged:: 1.4.13 The behavior of
+ :meth:`_engine.Connection.begin_nested`
+ as returning a :class:`.RootTransaction` if
+ :meth:`_engine.Connection.begin` were not called has been restored
+ as was the case in 1.3.x versions; in previous 1.4.x versions, an
+ outer transaction would be "autobegun" but would not be committed.
+
+
+ .. seealso::
+
+ :meth:`_engine.Connection.begin`
+
+ :ref:`session_begin_nested` - ORM support for SAVEPOINT
+
+ """
+ if self._is_future:
+ assert not self.__branch_from
+ elif self.__branch_from:
+ return self.__branch_from.begin_nested()
+
+ if self._transaction is None:
+ if not self._is_future:
+ util.warn_deprecated_20(
+ "Calling Connection.begin_nested() in 2.0 style use will "
+ "return a NestedTransaction (SAVEPOINT) in all cases, "
+ "that will not commit the outer transaction. For code "
+ "that is cross-compatible between 1.x and 2.0 style use, "
+ "ensure Connection.begin() is called before calling "
+ "Connection.begin_nested()."
+ )
+ return self.begin()
+ else:
+ self._autobegin()
+
+ return NestedTransaction(self)
+
+ def begin_twophase(self, xid=None):
+ """Begin a two-phase or XA transaction and return a transaction
+ handle.
+
+ The returned object is an instance of :class:`.TwoPhaseTransaction`,
+ which in addition to the methods provided by
+ :class:`.Transaction`, also provides a
+ :meth:`~.TwoPhaseTransaction.prepare` method.
+
+ :param xid: the two phase transaction id. If not supplied, a
+ random id will be generated.
+
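+ A minimal sketch of the prepare/commit flow (``engine`` is assumed to
+ refer to an :class:`_engine.Engine` whose dialect supports two-phase
+ transactions; ``some_table`` is likewise assumed)::
+
+     with engine.connect() as conn:
+         xa = conn.begin_twophase()
+         conn.execute(some_table.insert(), {"x": 1})
+         xa.prepare()
+         xa.commit()
+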
+ .. seealso::
+
+ :meth:`_engine.Connection.begin`
+
+ :meth:`_engine.Connection.begin_nested`
+
+ """
+
+ if self.__branch_from:
+ return self.__branch_from.begin_twophase(xid=xid)
+
+ if self._transaction is not None:
+ raise exc.InvalidRequestError(
+ "Cannot start a two phase transaction when a transaction "
+ "is already in progress."
+ )
+ if xid is None:
+ xid = self.engine.dialect.create_xid()
+ return TwoPhaseTransaction(self, xid)
+
+ def recover_twophase(self):
+ return self.engine.dialect.do_recover_twophase(self)
+
+ def rollback_prepared(self, xid, recover=False):
+ self.engine.dialect.do_rollback_twophase(self, xid, recover=recover)
+
+ def commit_prepared(self, xid, recover=False):
+ self.engine.dialect.do_commit_twophase(self, xid, recover=recover)
+
+ def in_transaction(self):
+ """Return True if a transaction is in progress."""
+ if self.__branch_from is not None:
+ return self.__branch_from.in_transaction()
+
+ return self._transaction is not None and self._transaction.is_active
+
+ def in_nested_transaction(self):
+ """Return True if a transaction is in progress."""
+ if self.__branch_from is not None:
+ return self.__branch_from.in_nested_transaction()
+
+ return (
+ self._nested_transaction is not None
+ and self._nested_transaction.is_active
+ )
+
+ def _is_autocommit_isolation(self):
+ opt_iso = self._execution_options.get("isolation_level", None)
+ return bool(
+ opt_iso == "AUTOCOMMIT"
+ or (
+ opt_iso is None
+ and getattr(self.engine.dialect, "isolation_level", None)
+ == "AUTOCOMMIT"
+ )
+ )
+
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+
+ if self.__branch_from is not None:
+ return self.__branch_from.get_transaction()
+
+ return self._transaction
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+ if self.__branch_from is not None:
+ return self.__branch_from.get_nested_transaction()
+
+ return self._nested_transaction
+
+ def _begin_impl(self, transaction):
+ assert not self.__branch_from
+
+ if self._echo:
+ if self._is_autocommit_isolation():
+ self._log_info(
+ "BEGIN (implicit; DBAPI should not BEGIN due to "
+ "autocommit mode)"
+ )
+ else:
+ self._log_info("BEGIN (implicit)")
+
+ self.__in_begin = True
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.begin(self)
+
+ try:
+ self.engine.dialect.do_begin(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+ finally:
+ self.__in_begin = False
+
+ def _rollback_impl(self):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.rollback(self)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ if self._echo:
+ if self._is_autocommit_isolation():
+ self._log_info(
+ "ROLLBACK using DBAPI connection.rollback(), "
+ "DBAPI should ignore due to autocommit mode"
+ )
+ else:
+ self._log_info("ROLLBACK")
+ try:
+ self.engine.dialect.do_rollback(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _commit_impl(self, autocommit=False):
+ assert not self.__branch_from
+
+ # AUTOCOMMIT isolation-level is a dialect-specific concept, however
+ # if a connection has this set as the isolation level, we can skip
+ # the "autocommit" warning as the operation will do "autocommit"
+ # in any case
+ if autocommit and not self._is_autocommit_isolation():
+ util.warn_deprecated_20(
+ "The current statement is being autocommitted using "
+ "implicit autocommit, which will be removed in "
+ "SQLAlchemy 2.0. "
+ "Use the .begin() method of Engine or Connection in order to "
+ "use an explicit transaction for DML and DDL statements."
+ )
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.commit(self)
+
+ if self._echo:
+ if self._is_autocommit_isolation():
+ self._log_info(
+ "COMMIT using DBAPI connection.commit(), "
+ "DBAPI should ignore due to autocommit mode"
+ )
+ else:
+ self._log_info("COMMIT")
+ try:
+ self.engine.dialect.do_commit(self.connection)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _savepoint_impl(self, name=None):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.savepoint(self, name)
+
+ if name is None:
+ self.__savepoint_seq += 1
+ name = "sa_savepoint_%s" % self.__savepoint_seq
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.engine.dialect.do_savepoint(self, name)
+ return name
+
+ def _rollback_to_savepoint_impl(self, name):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.rollback_savepoint(self, name, None)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.engine.dialect.do_rollback_to_savepoint(self, name)
+
+ def _release_savepoint_impl(self, name):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.release_savepoint(self, name, None)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.engine.dialect.do_release_savepoint(self, name)
+
+ def _begin_twophase_impl(self, transaction):
+ assert not self.__branch_from
+
+ if self._echo:
+ self._log_info("BEGIN TWOPHASE (implicit)")
+ if self._has_events or self.engine._has_events:
+ self.dispatch.begin_twophase(self, transaction.xid)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ self.__in_begin = True
+ try:
+ self.engine.dialect.do_begin_twophase(self, transaction.xid)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+ finally:
+ self.__in_begin = False
+
+ def _prepare_twophase_impl(self, xid):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.prepare_twophase(self, xid)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_prepare_twophase(self, xid)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _rollback_twophase_impl(self, xid, is_prepared):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.rollback_twophase(self, xid, is_prepared)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_rollback_twophase(
+ self, xid, is_prepared
+ )
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _commit_twophase_impl(self, xid, is_prepared):
+ assert not self.__branch_from
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.commit_twophase(self, xid, is_prepared)
+
+ if self._still_open_and_dbapi_connection_is_valid:
+ assert isinstance(self._transaction, TwoPhaseTransaction)
+ try:
+ self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ def _autorollback(self):
+ if self.__branch_from:
+ self.__branch_from._autorollback()
+
+ if not self.in_transaction():
+ self._rollback_impl()
+
+ def _warn_for_legacy_exec_format(self):
+ util.warn_deprecated_20(
+ "The connection.execute() method in "
+ "SQLAlchemy 2.0 will accept parameters as a single "
+ "dictionary or a "
+ "single sequence of dictionaries only. "
+ "Parameters passed as keyword arguments, tuples or positionally "
+ "oriented dictionaries and/or tuples "
+ "will no longer be accepted."
+ )
+
+ def close(self):
+ """Close this :class:`_engine.Connection`.
+
+ This results in a release of the underlying database
+ resources, that is, the DBAPI connection referenced
+ internally. The DBAPI connection is typically restored
+ back to the connection-holding :class:`_pool.Pool` referenced
+ by the :class:`_engine.Engine` that produced this
+ :class:`_engine.Connection`. Any transactional state present on
+ the DBAPI connection is also unconditionally released via
+ the DBAPI connection's ``rollback()`` method, regardless
+ of any :class:`.Transaction` object that may be
+ outstanding with regards to this :class:`_engine.Connection`.
+
+ After :meth:`_engine.Connection.close` is called, the
+ :class:`_engine.Connection` is permanently in a closed state,
+ and will allow no further operations.
+
+ """
+
+ if self.__branch_from:
+ assert not self._is_future
+ util.warn_deprecated_20(
+ "The .close() method on a so-called 'branched' connection is "
+ "deprecated as of 1.4, as are 'branched' connections overall, "
+ "and will be removed in a future release. If this is a "
+ "default-handling function, don't close the connection."
+ )
+ self._dbapi_connection = None
+ self.__can_reconnect = False
+ return
+
+ if self._transaction:
+ self._transaction.close()
+ skip_reset = True
+ else:
+ skip_reset = False
+
+ if self._dbapi_connection is not None:
+ conn = self._dbapi_connection
+
+ # as we just closed the transaction, close the connection
+ # pool connection without doing an additional reset
+ if skip_reset:
+ conn._close_no_reset()
+ else:
+ conn.close()
+
+ # There is a slight chance that conn.close() may have
+ # triggered an invalidation here in which case
+ # _dbapi_connection would already be None, however usually
+ # it will be non-None here and in a "closed" state.
+ self._dbapi_connection = None
+ self.__can_reconnect = False
+
+ def scalar(self, object_, *multiparams, **params):
+ """Executes and returns the first column of the first row.
+
+ The underlying result/cursor is closed after execution.
+
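+ For example, an illustrative sketch (``some_table`` is an assumed
+ :class:`_schema.Table`; ``conn`` is this :class:`_engine.Connection`)::
+
+     from sqlalchemy import func, select
+
+     row_count = conn.scalar(
+         select(func.count()).select_from(some_table)
+     )
+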
+ """
+
+ return self.execute(object_, *multiparams, **params).scalar()
+
+ def scalars(self, object_, *multiparams, **params):
+ """Executes and returns a scalar result set, which yields scalar values
+ from the first column of each row.
+
+ This method is equivalent to calling :meth:`_engine.Connection.execute`
+ to receive a :class:`_result.Result` object, then invoking the
+ :meth:`_result.Result.scalars` method to produce a
+ :class:`_result.ScalarResult` instance.
+
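+ For example, an illustrative sketch (``some_table`` is an assumed
+ :class:`_schema.Table` with an ``id`` column; ``conn`` is this
+ :class:`_engine.Connection`)::
+
+     from sqlalchemy import select
+
+     ids = conn.scalars(select(some_table.c.id)).all()
+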
+ :return: a :class:`_result.ScalarResult`
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ return self.execute(object_, *multiparams, **params).scalars()
+
+ def execute(self, statement, *multiparams, **params):
+ r"""Executes a SQL statement construct and returns a
+ :class:`_engine.CursorResult`.
+
+ :param statement: The statement to be executed. May be
+ one of:
+
+ * a plain string (deprecated)
+ * any :class:`_expression.ClauseElement` construct that is also
+ a subclass of :class:`.Executable`, such as a
+ :func:`_expression.select` construct
+ * a :class:`.FunctionElement`, such as that generated
+ by :data:`.func`, will be automatically wrapped in
+ a SELECT statement, which is then executed.
+ * a :class:`.DDLElement` object
+ * a :class:`.DefaultGenerator` object
+ * a :class:`.Compiled` object
+
+ .. deprecated:: 2.0 passing a string to
+ :meth:`_engine.Connection.execute` is
+ deprecated and will be removed in version 2.0. Use the
+ :func:`_expression.text` construct with
+ :meth:`_engine.Connection.execute`, or the
+ :meth:`_engine.Connection.exec_driver_sql`
+ method to invoke a driver-level
+ SQL string.
+
+ :param \*multiparams/\**params: represent bound parameter
+ values to be used in the execution. Typically,
+ the format is either a collection of one or more
+ dictionaries passed to \*multiparams::
+
+ conn.execute(
+ table.insert(),
+ {"id":1, "value":"v1"},
+ {"id":2, "value":"v2"}
+ )
+
+ ...or individual key/values interpreted by \**params::
+
+ conn.execute(
+ table.insert(), id=1, value="v1"
+ )
+
+ In the case that a plain SQL string is passed, and the underlying
+ DBAPI accepts positional bind parameters, a collection of tuples
+ or individual values in \*multiparams may be passed::
+
+ conn.execute(
+ "INSERT INTO table (id, value) VALUES (?, ?)",
+ (1, "v1"), (2, "v2")
+ )
+
+ conn.execute(
+ "INSERT INTO table (id, value) VALUES (?, ?)",
+ 1, "v1"
+ )
+
+ Note above, the usage of a question mark "?" or other
+ symbol is contingent upon the "paramstyle" accepted by the DBAPI
+ in use, which may be any of "qmark", "named", "pyformat", "format",
+ "numeric". See `pep-249
+ <https://www.python.org/dev/peps/pep-0249/>`_ for details on
+ paramstyle.
+
+ To execute a textual SQL statement which uses bound parameters in a
+ DBAPI-agnostic way, use the :func:`_expression.text` construct.
+
+ .. deprecated:: 2.0 use of tuple or scalar positional parameters
+ is deprecated. All params should be dicts or sequences of dicts.
+ Use :meth:`.exec_driver_sql` to execute a plain string with
+ tuple or scalar positional parameters.
+
+ """
+
+ if isinstance(statement, util.string_types):
+ util.warn_deprecated_20(
+ "Passing a string to Connection.execute() is "
+ "deprecated and will be removed in version 2.0. Use the "
+ "text() construct, "
+ "or the Connection.exec_driver_sql() method to invoke a "
+ "driver-level SQL string."
+ )
+
+ return self._exec_driver_sql(
+ statement,
+ multiparams,
+ params,
+ _EMPTY_EXECUTION_OPTS,
+ future=False,
+ )
+
+ try:
+ meth = statement._execute_on_connection
+ except AttributeError as err:
+ util.raise_(
+ exc.ObjectNotExecutableError(statement), replace_context=err
+ )
+ else:
+ return meth(self, multiparams, params, _EMPTY_EXECUTION_OPTS)
+
+ def _execute_function(self, func, multiparams, params, execution_options):
+ """Execute a sql.FunctionElement object."""
+
+ return self._execute_clauseelement(
+ func.select(), multiparams, params, execution_options
+ )
+
+ def _execute_default(
+ self,
+ default,
+ multiparams,
+ params,
+ # migrate is calling this directly :(
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ """Execute a schema.ColumnDefault object."""
+
+ execution_options = self._execution_options.merge_with(
+ execution_options
+ )
+
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if self._has_events or self.engine._has_events:
+ (
+ default,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ default, distilled_parameters, execution_options
+ )
+
+ try:
+ conn = self._dbapi_connection
+ if conn is None:
+ conn = self._revalidate_connection()
+
+ dialect = self.dialect
+ ctx = dialect.execution_ctx_cls._init_default(
+ dialect, self, conn, execution_options
+ )
+ except (exc.PendingRollbackError, exc.ResourceClosedError):
+ raise
+ except BaseException as e:
+ self._handle_dbapi_exception(e, None, None, None, None)
+
+ ret = ctx._exec_default(None, default, None)
+ if self.should_close_with_result:
+ self.close()
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ default,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+
+ return ret
+
+ def _execute_ddl(self, ddl, multiparams, params, execution_options):
+ """Execute a schema.DDL object."""
+
+ execution_options = ddl._execution_options.merge_with(
+ self._execution_options, execution_options
+ )
+
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if self._has_events or self.engine._has_events:
+ (
+ ddl,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ ddl, distilled_parameters, execution_options
+ )
+
+ exec_opts = self._execution_options.merge_with(execution_options)
+ schema_translate_map = exec_opts.get("schema_translate_map", None)
+
+ dialect = self.dialect
+
+ compiled = ddl.compile(
+ dialect=dialect, schema_translate_map=schema_translate_map
+ )
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_ddl,
+ compiled,
+ None,
+ execution_options,
+ compiled,
+ )
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ ddl,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _invoke_before_exec_event(
+ self, elem, distilled_params, execution_options
+ ):
+
+ if len(distilled_params) == 1:
+ event_multiparams, event_params = [], distilled_params[0]
+ else:
+ event_multiparams, event_params = distilled_params, {}
+
+ for fn in self.dispatch.before_execute:
+ elem, event_multiparams, event_params = fn(
+ self,
+ elem,
+ event_multiparams,
+ event_params,
+ execution_options,
+ )
+
+ if event_multiparams:
+ distilled_params = list(event_multiparams)
+ if event_params:
+ raise exc.InvalidRequestError(
+ "Event handler can't return non-empty multiparams "
+ "and params at the same time"
+ )
+ elif event_params:
+ distilled_params = [event_params]
+ else:
+ distilled_params = []
+
+ return elem, distilled_params, event_multiparams, event_params
+
+ def _execute_clauseelement(
+ self, elem, multiparams, params, execution_options
+ ):
+ """Execute a sql.ClauseElement object."""
+
+ execution_options = elem._execution_options.merge_with(
+ self._execution_options, execution_options
+ )
+
+ distilled_params = _distill_params(self, multiparams, params)
+
+ has_events = self._has_events or self.engine._has_events
+ if has_events:
+ (
+ elem,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ elem, distilled_params, execution_options
+ )
+
+ if distilled_params:
+ # ensure we don't retain a link to the view object for keys()
+ # which links to the values, which we don't want to cache
+ keys = sorted(distilled_params[0])
+ for_executemany = len(distilled_params) > 1
+ else:
+ keys = []
+ for_executemany = False
+
+ dialect = self.dialect
+
+ schema_translate_map = execution_options.get(
+ "schema_translate_map", None
+ )
+
+ compiled_cache = execution_options.get(
+ "compiled_cache", self.engine._compiled_cache
+ )
+
+ compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
+ dialect=dialect,
+ compiled_cache=compiled_cache,
+ column_keys=keys,
+ for_executemany=for_executemany,
+ schema_translate_map=schema_translate_map,
+ linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+ )
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_compiled,
+ compiled_sql,
+ distilled_params,
+ execution_options,
+ compiled_sql,
+ distilled_params,
+ elem,
+ extracted_params,
+ cache_hit=cache_hit,
+ )
+ if has_events:
+ self.dispatch.after_execute(
+ self,
+ elem,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _execute_compiled(
+ self,
+ compiled,
+ multiparams,
+ params,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ """Execute a sql.Compiled object.
+
+ TODO: why do we have this? likely deprecate or remove
+
+ """
+
+ execution_options = compiled.execution_options.merge_with(
+ self._execution_options, execution_options
+ )
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if self._has_events or self.engine._has_events:
+ (
+ compiled,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ compiled, distilled_parameters, execution_options
+ )
+
+ dialect = self.dialect
+
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_compiled,
+ compiled,
+ distilled_parameters,
+ execution_options,
+ compiled,
+ distilled_parameters,
+ None,
+ None,
+ )
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ compiled,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _exec_driver_sql(
+ self, statement, multiparams, params, execution_options, future
+ ):
+
+ execution_options = self._execution_options.merge_with(
+ execution_options
+ )
+
+ distilled_parameters = _distill_params(self, multiparams, params)
+
+ if not future:
+ if self._has_events or self.engine._has_events:
+ (
+ statement,
+ distilled_params,
+ event_multiparams,
+ event_params,
+ ) = self._invoke_before_exec_event(
+ statement, distilled_parameters, execution_options
+ )
+
+ dialect = self.dialect
+ ret = self._execute_context(
+ dialect,
+ dialect.execution_ctx_cls._init_statement,
+ statement,
+ distilled_parameters,
+ execution_options,
+ statement,
+ distilled_parameters,
+ )
+
+ if not future:
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_execute(
+ self,
+ statement,
+ event_multiparams,
+ event_params,
+ execution_options,
+ ret,
+ )
+ return ret
+
+ def _execute_20(
+ self,
+ statement,
+ parameters=None,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ args_10style, kwargs_10style = _distill_params_20(parameters)
+ try:
+ meth = statement._execute_on_connection
+ except AttributeError as err:
+ util.raise_(
+ exc.ObjectNotExecutableError(statement), replace_context=err
+ )
+ else:
+ return meth(self, args_10style, kwargs_10style, execution_options)
+
+ def exec_driver_sql(
+ self, statement, parameters=None, execution_options=None
+ ):
+ r"""Executes a SQL statement construct and returns a
+ :class:`_engine.CursorResult`.
+
+ :param statement: The statement str to be executed. Bound parameters
+ must use the underlying DBAPI's paramstyle, such as "qmark",
+ "pyformat", "format", etc.
+
+ :param parameters: represent bound parameter values to be used in the
+ execution. The format is one of: a dictionary of named parameters,
+ a tuple of positional parameters, or a list containing either
+ dictionaries or tuples for multiple-execute support.
+
+ E.g. multiple dictionaries::
+
+ conn.exec_driver_sql(
+ "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)",
+ [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}]
+ )
+
+ Single dictionary::
+
+ conn.exec_driver_sql(
+ "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)",
+ dict(id=1, value="v1")
+ )
+
+ Single tuple::
+
+ conn.exec_driver_sql(
+ "INSERT INTO table (id, value) VALUES (?, ?)",
+ (1, 'v1')
+ )
+
+ .. note:: The :meth:`_engine.Connection.exec_driver_sql` method does
+ not participate in the
+ :meth:`_events.ConnectionEvents.before_execute` and
+ :meth:`_events.ConnectionEvents.after_execute` events. To
+ intercept calls to :meth:`_engine.Connection.exec_driver_sql`, use
+ :meth:`_events.ConnectionEvents.before_cursor_execute` and
+ :meth:`_events.ConnectionEvents.after_cursor_execute`.
+
+ .. seealso::
+
+ :pep:`249`
+
+ """
+
+ args_10style, kwargs_10style = _distill_params_20(parameters)
+
+ return self._exec_driver_sql(
+ statement,
+ args_10style,
+ kwargs_10style,
+ execution_options,
+ future=True,
+ )
+
+ def _execute_context(
+ self,
+ dialect,
+ constructor,
+ statement,
+ parameters,
+ execution_options,
+ *args,
+ **kw
+ ):
+ """Create an :class:`.ExecutionContext` and execute, returning
+ a :class:`_engine.CursorResult`."""
+
+ branched = self
+ if self.__branch_from:
+ # if this is a "branched" connection, do everything in terms
+ # of the "root" connection, *except* for .close(), which is
+ # the only feature that branching provides
+ self = self.__branch_from
+
+ if execution_options:
+ yp = execution_options.get("yield_per", None)
+ if yp:
+ execution_options = execution_options.union(
+ {"stream_results": True, "max_row_buffer": yp}
+ )
+
+ try:
+ conn = self._dbapi_connection
+ if conn is None:
+ conn = self._revalidate_connection()
+
+ context = constructor(
+ dialect, self, conn, execution_options, *args, **kw
+ )
+ except (exc.PendingRollbackError, exc.ResourceClosedError):
+ raise
+ except BaseException as e:
+ self._handle_dbapi_exception(
+ e, util.text_type(statement), parameters, None, None
+ )
+
+ if (
+ self._transaction
+ and not self._transaction.is_active
+ or (
+ self._nested_transaction
+ and not self._nested_transaction.is_active
+ )
+ ):
+ self._invalid_transaction()
+
+ elif self._trans_context_manager:
+ TransactionalContext._trans_ctx_check(self)
+
+ if self._is_future and self._transaction is None:
+ self._autobegin()
+
+ context.pre_exec()
+
+ if dialect.use_setinputsizes:
+ context._set_input_sizes()
+
+ cursor, statement, parameters = (
+ context.cursor,
+ context.statement,
+ context.parameters,
+ )
+
+ if not context.executemany:
+ parameters = parameters[0]
+
+ if self._has_events or self.engine._has_events:
+ for fn in self.dispatch.before_cursor_execute:
+ statement, parameters = fn(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
+
+ if self._echo:
+
+ self._log_info(statement)
+
+ stats = context._get_cache_stats()
+
+ if not self.engine.hide_parameters:
+ self._log_info(
+ "[%s] %r",
+ stats,
+ sql_util._repr_params(
+ parameters, batches=10, ismulti=context.executemany
+ ),
+ )
+ else:
+ self._log_info(
+ "[%s] [SQL parameters hidden due to hide_parameters=True]"
+ % (stats,)
+ )
+
+ evt_handled = False
+ try:
+ if context.executemany:
+ if self.dialect._has_events:
+ for fn in self.dialect.dispatch.do_executemany:
+ if fn(cursor, statement, parameters, context):
+ evt_handled = True
+ break
+ if not evt_handled:
+ self.dialect.do_executemany(
+ cursor, statement, parameters, context
+ )
+ elif not parameters and context.no_parameters:
+ if self.dialect._has_events:
+ for fn in self.dialect.dispatch.do_execute_no_params:
+ if fn(cursor, statement, context):
+ evt_handled = True
+ break
+ if not evt_handled:
+ self.dialect.do_execute_no_params(
+ cursor, statement, context
+ )
+ else:
+ if self.dialect._has_events:
+ for fn in self.dialect.dispatch.do_execute:
+ if fn(cursor, statement, parameters, context):
+ evt_handled = True
+ break
+ if not evt_handled:
+ self.dialect.do_execute(
+ cursor, statement, parameters, context
+ )
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_cursor_execute(
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ context.executemany,
+ )
+
+ context.post_exec()
+
+ result = context._setup_result_proxy()
+
+ if not self._is_future:
+ should_close_with_result = branched.should_close_with_result
+
+ if not result._soft_closed and should_close_with_result:
+ result._autoclose_connection = True
+
+ if (
+ # usually we're in a transaction so avoid relatively
+ # expensive / legacy should_autocommit call
+ self._transaction is None
+ and context.should_autocommit
+ ):
+ self._commit_impl(autocommit=True)
+
+ # for "connectionless" execution, we have to close this
+ # Connection after the statement is complete.
+ # legacy stuff.
+ if should_close_with_result and context._soft_closed:
+ assert not self._is_future
+
+ # CursorResult already exhausted rows / has no rows.
+ # close us now
+ branched.close()
+
+ except BaseException as e:
+ self._handle_dbapi_exception(
+ e, statement, parameters, cursor, context
+ )
+
+ return result
+
+ def _cursor_execute(self, cursor, statement, parameters, context=None):
+ """Execute a statement + params on the given cursor.
+
+ Adds appropriate logging and exception handling.
+
+ This method is used by DefaultDialect for special-case
+ executions, such as for sequences and column defaults.
+ The path of statement execution in the majority of cases
+ terminates at _execute_context().
+
+ """
+ if self._has_events or self.engine._has_events:
+ for fn in self.dispatch.before_cursor_execute:
+ statement, parameters = fn(
+ self, cursor, statement, parameters, context, False
+ )
+
+ if self._echo:
+ self._log_info(statement)
+ self._log_info("[raw sql] %r", parameters)
+ try:
+ for fn in (
+ ()
+ if not self.dialect._has_events
+ else self.dialect.dispatch.do_execute
+ ):
+ if fn(cursor, statement, parameters, context):
+ break
+ else:
+ self.dialect.do_execute(cursor, statement, parameters, context)
+ except BaseException as e:
+ self._handle_dbapi_exception(
+ e, statement, parameters, cursor, context
+ )
+
+ if self._has_events or self.engine._has_events:
+ self.dispatch.after_cursor_execute(
+ self, cursor, statement, parameters, context, False
+ )
+
+ def _safe_close_cursor(self, cursor):
+ """Close the given cursor, catching exceptions
+ and turning into log warnings.
+
+ """
+ try:
+ cursor.close()
+ except Exception:
+ # log the error through the connection pool's logger.
+ self.engine.pool.logger.error(
+ "Error closing cursor", exc_info=True
+ )
+
+ _reentrant_error = False
+ _is_disconnect = False
+
+ def _handle_dbapi_exception(
+ self, e, statement, parameters, cursor, context
+ ):
+ exc_info = sys.exc_info()
+
+ is_exit_exception = util.is_exit_exception(e)
+
+ if not self._is_disconnect:
+ self._is_disconnect = (
+ isinstance(e, self.dialect.dbapi.Error)
+ and not self.closed
+ and self.dialect.is_disconnect(
+ e,
+ self._dbapi_connection if not self.invalidated else None,
+ cursor,
+ )
+ ) or (is_exit_exception and not self.closed)
+
+ invalidate_pool_on_disconnect = not is_exit_exception
+
+ if self._reentrant_error:
+ util.raise_(
+ exc.DBAPIError.instance(
+ statement,
+ parameters,
+ e,
+ self.dialect.dbapi.Error,
+ hide_parameters=self.engine.hide_parameters,
+ dialect=self.dialect,
+ ismulti=context.executemany
+ if context is not None
+ else None,
+ ),
+ with_traceback=exc_info[2],
+ from_=e,
+ )
+ self._reentrant_error = True
+ try:
+ # non-DBAPI error - if we already got a context,
+ # or there's no string statement, don't wrap it
+ should_wrap = isinstance(e, self.dialect.dbapi.Error) or (
+ statement is not None
+ and context is None
+ and not is_exit_exception
+ )
+
+ if should_wrap:
+ sqlalchemy_exception = exc.DBAPIError.instance(
+ statement,
+ parameters,
+ e,
+ self.dialect.dbapi.Error,
+ hide_parameters=self.engine.hide_parameters,
+ connection_invalidated=self._is_disconnect,
+ dialect=self.dialect,
+ ismulti=context.executemany
+ if context is not None
+ else None,
+ )
+ else:
+ sqlalchemy_exception = None
+
+ newraise = None
+
+ if (
+ self._has_events or self.engine._has_events
+ ) and not self._execution_options.get(
+ "skip_user_error_events", False
+ ):
+ ctx = ExceptionContextImpl(
+ e,
+ sqlalchemy_exception,
+ self.engine,
+ self,
+ cursor,
+ statement,
+ parameters,
+ context,
+ self._is_disconnect,
+ invalidate_pool_on_disconnect,
+ )
+
+ for fn in self.dispatch.handle_error:
+ try:
+ # handler returns an exception;
+ # call next handler in a chain
+ per_fn = fn(ctx)
+ if per_fn is not None:
+ ctx.chained_exception = newraise = per_fn
+ except Exception as _raised:
+ # handler raises an exception - stop processing
+ newraise = _raised
+ break
+
+ if self._is_disconnect != ctx.is_disconnect:
+ self._is_disconnect = ctx.is_disconnect
+ if sqlalchemy_exception:
+ sqlalchemy_exception.connection_invalidated = (
+ ctx.is_disconnect
+ )
+
+ # set up potentially user-defined value for
+ # invalidate pool.
+ invalidate_pool_on_disconnect = (
+ ctx.invalidate_pool_on_disconnect
+ )
+
+ if should_wrap and context:
+ context.handle_dbapi_exception(e)
+
+ if not self._is_disconnect:
+ if cursor:
+ self._safe_close_cursor(cursor)
+ with util.safe_reraise(warn_only=True):
+ self._autorollback()
+
+ if newraise:
+ util.raise_(newraise, with_traceback=exc_info[2], from_=e)
+ elif should_wrap:
+ util.raise_(
+ sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+ )
+ else:
+ util.raise_(exc_info[1], with_traceback=exc_info[2])
+
+ finally:
+ del self._reentrant_error
+ if self._is_disconnect:
+ del self._is_disconnect
+ if not self.invalidated:
+ dbapi_conn_wrapper = self._dbapi_connection
+ if invalidate_pool_on_disconnect:
+ self.engine.pool._invalidate(dbapi_conn_wrapper, e)
+ self.invalidate(e)
+ if self.should_close_with_result:
+ assert not self._is_future
+ self.close()
+
+ @classmethod
+ def _handle_dbapi_exception_noconnection(cls, e, dialect, engine):
+ exc_info = sys.exc_info()
+
+ is_disconnect = dialect.is_disconnect(e, None, None)
+
+ should_wrap = isinstance(e, dialect.dbapi.Error)
+
+ if should_wrap:
+ sqlalchemy_exception = exc.DBAPIError.instance(
+ None,
+ None,
+ e,
+ dialect.dbapi.Error,
+ hide_parameters=engine.hide_parameters,
+ connection_invalidated=is_disconnect,
+ )
+ else:
+ sqlalchemy_exception = None
+
+ newraise = None
+
+ if engine._has_events:
+ ctx = ExceptionContextImpl(
+ e,
+ sqlalchemy_exception,
+ engine,
+ None,
+ None,
+ None,
+ None,
+ None,
+ is_disconnect,
+ True,
+ )
+ for fn in engine.dispatch.handle_error:
+ try:
+ # handler returns an exception;
+ # call next handler in a chain
+ per_fn = fn(ctx)
+ if per_fn is not None:
+ ctx.chained_exception = newraise = per_fn
+ except Exception as _raised:
+ # handler raises an exception - stop processing
+ newraise = _raised
+ break
+
+ if sqlalchemy_exception and is_disconnect != ctx.is_disconnect:
+ sqlalchemy_exception.connection_invalidated = (
+ is_disconnect
+ ) = ctx.is_disconnect
+
+ if newraise:
+ util.raise_(newraise, with_traceback=exc_info[2], from_=e)
+ elif should_wrap:
+ util.raise_(
+ sqlalchemy_exception, with_traceback=exc_info[2], from_=e
+ )
+ else:
+ util.raise_(exc_info[1], with_traceback=exc_info[2])
+
+ def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ """run a DDL visitor.
+
+ This method is only here so that the MockConnection can change the
+ options given to the visitor so that "checkfirst" is skipped.
+
+ """
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Connection.transaction` "
+ "method is deprecated and will be "
+ "removed in a future release. Use the :meth:`_engine.Engine.begin` "
+ "context manager instead.",
+ )
+ def transaction(self, callable_, *args, **kwargs):
+ r"""Execute the given function within a transaction boundary.
+
+ The function is passed this :class:`_engine.Connection`
+ as the first argument, followed by the given \*args and \**kwargs,
+ e.g.::
+
+ def do_something(conn, x, y):
+ conn.execute(text("some statement"), {'x':x, 'y':y})
+
+ conn.transaction(do_something, 5, 10)
+
+ The operations inside the function are all invoked within the
+ context of a single :class:`.Transaction`.
+ Upon success, the transaction is committed. If an
+ exception is raised, the transaction is rolled back
+ before propagating the exception.
+
+ .. note::
+
+ The :meth:`.transaction` method is superseded by
+ the usage of the Python ``with:`` statement, which can
+ be used with :meth:`_engine.Connection.begin`::
+
+ with conn.begin():
+ conn.execute(text("some statement"), {'x':5, 'y':10})
+
+ As well as with :meth:`_engine.Engine.begin`::
+
+ with engine.begin() as conn:
+ conn.execute(text("some statement"), {'x':5, 'y':10})
+
+ .. seealso::
+
+ :meth:`_engine.Engine.begin` - engine-level transactional
+ context
+
+ :meth:`_engine.Engine.transaction` - engine-level version of
+ :meth:`_engine.Connection.transaction`
+
+ """
+
+ kwargs["_sa_skip_warning"] = True
+ trans = self.begin()
+ try:
+ ret = self.run_callable(callable_, *args, **kwargs)
+ trans.commit()
+ return ret
+ except:
+ with util.safe_reraise():
+ trans.rollback()
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Connection.run_callable` "
+ "method is deprecated and will "
+ "be removed in a future release. Invoke the callable function "
+ "directly, passing the Connection.",
+ )
+ def run_callable(self, callable_, *args, **kwargs):
+ r"""Given a callable object or function, execute it, passing
+ a :class:`_engine.Connection` as the first argument.
+
+ The given \*args and \**kwargs are passed subsequent
+ to the :class:`_engine.Connection` argument.
+
+ This function, along with :meth:`_engine.Engine.run_callable`,
+ allows a function to be run with a :class:`_engine.Connection`
+ or :class:`_engine.Engine` object without the need to know
+ which one is being dealt with.
+
+ """
+ return callable_(self, *args, **kwargs)
+
+
+class ExceptionContextImpl(ExceptionContext):
+ """Implement the :class:`.ExceptionContext` interface."""
+
+ def __init__(
+ self,
+ exception,
+ sqlalchemy_exception,
+ engine,
+ connection,
+ cursor,
+ statement,
+ parameters,
+ context,
+ is_disconnect,
+ invalidate_pool_on_disconnect,
+ ):
+ self.engine = engine
+ self.connection = connection
+ self.sqlalchemy_exception = sqlalchemy_exception
+ self.original_exception = exception
+ self.execution_context = context
+ self.statement = statement
+ self.parameters = parameters
+ self.is_disconnect = is_disconnect
+ self.invalidate_pool_on_disconnect = invalidate_pool_on_disconnect
+
+
+class Transaction(TransactionalContext):
+ """Represent a database transaction in progress.
+
+ The :class:`.Transaction` object is procured by
+ calling the :meth:`_engine.Connection.begin` method of
+ :class:`_engine.Connection`::
+
+ from sqlalchemy import create_engine
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+ connection = engine.connect()
+ trans = connection.begin()
+ connection.execute(text("insert into x (a, b) values (1, 2)"))
+ trans.commit()
+
+ The object provides :meth:`.rollback` and :meth:`.commit`
+ methods in order to control transaction boundaries. It
+ also implements a context manager interface so that
+ the Python ``with`` statement can be used with the
+ :meth:`_engine.Connection.begin` method::
+
+ with connection.begin():
+ connection.execute(text("insert into x (a, b) values (1, 2)"))
+
+ The Transaction object is **not** threadsafe.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.begin`
+
+ :meth:`_engine.Connection.begin_twophase`
+
+ :meth:`_engine.Connection.begin_nested`
+
+ .. index::
+ single: thread safety; Transaction
+ """
+
+ __slots__ = ()
+
+ _is_root = False
+
+ def __init__(self, connection):
+ raise NotImplementedError()
+
+ def _do_deactivate(self):
+ """do whatever steps are necessary to set this transaction as
+ "deactive", however leave this transaction object in place as far
+ as the connection's state is concerned.
+
+ for a "real" transaction this should roll back the transaction
+ and ensure this transaction is no longer a reset agent.
+
+ this is used for nesting of marker transactions where the marker
+ can set the "real" transaction as rolled back, however it stays
+ in place.
+
+ for 2.0 we hope to remove this nesting feature.
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def _deactivated_from_connection(self):
+ """True if this transaction is totally deactivated from the connection
+ and therefore can no longer affect its state.
+
+ """
+ raise NotImplementedError()
+
+ def _do_close(self):
+ raise NotImplementedError()
+
+ def _do_rollback(self):
+ raise NotImplementedError()
+
+ def _do_commit(self):
+ raise NotImplementedError()
+
+ @property
+ def is_valid(self):
+ return self.is_active and not self.connection.invalidated
+
+ def close(self):
+ """Close this :class:`.Transaction`.
+
+ If this transaction is the base transaction in a begin/commit
+ nesting, the transaction will rollback(). Otherwise, the
+ method returns.
+
+ This is used to cancel a Transaction without affecting the scope of
+ an enclosing transaction.
+
+ """
+ try:
+ self._do_close()
+ finally:
+ assert not self.is_active
+
+ def rollback(self):
+ """Roll back this :class:`.Transaction`.
+
+ The implementation of this may vary based on the type of transaction in
+ use:
+
+ * For a simple database transaction (e.g. :class:`.RootTransaction`),
+ it corresponds to a ROLLBACK.
+
+ * For a :class:`.NestedTransaction`, it corresponds to a
+ "ROLLBACK TO SAVEPOINT" operation.
+
+ * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two
+ phase transactions may be used.
+
+
+ """
+ try:
+ self._do_rollback()
+ finally:
+ assert not self.is_active
+
+ def commit(self):
+ """Commit this :class:`.Transaction`.
+
+ The implementation of this may vary based on the type of transaction in
+ use:
+
+ * For a simple database transaction (e.g. :class:`.RootTransaction`),
+ it corresponds to a COMMIT.
+
+ * For a :class:`.NestedTransaction`, it corresponds to a
+ "RELEASE SAVEPOINT" operation.
+
+ * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two
+ phase transactions may be used.
+
+ """
+ try:
+ self._do_commit()
+ finally:
+ assert not self.is_active
+
+ def _get_subject(self):
+ return self.connection
+
+ def _transaction_is_active(self):
+ return self.is_active
+
+ def _transaction_is_closed(self):
+ return not self._deactivated_from_connection
+
+ def _rollback_can_be_called(self):
+ # for RootTransaction / NestedTransaction, it's safe to call
+ # rollback() even if the transaction is deactive and no warnings
+ # will be emitted. tested in
+ # test_transaction.py -> test_no_rollback_in_deactive(?:_savepoint)?
+ return True
+
+
+class MarkerTransaction(Transaction):
+ """A 'marker' transaction that is used for nested begin() calls.
+
+ .. deprecated:: 1.4 future connection for 2.0 won't support this pattern.
+
+ """
+
+ __slots__ = ("connection", "_is_active", "_transaction")
+
+ def __init__(self, connection):
+ assert connection._transaction is not None
+ if not connection._transaction.is_active:
+ raise exc.InvalidRequestError(
+ "the current transaction on this connection is inactive. "
+ "Please issue a rollback first."
+ )
+
+ assert not connection._is_future
+ util.warn_deprecated_20(
+ "Calling .begin() when a transaction is already begun, creating "
+ "a 'sub' transaction, is deprecated "
+ "and will be removed in 2.0. See the documentation section "
+ "'Migrating from the nesting pattern' for background on how "
+ "to migrate from this pattern."
+ )
+
+ self.connection = connection
+
+ if connection._trans_context_manager:
+ TransactionalContext._trans_ctx_check(connection)
+
+ if connection._nested_transaction is not None:
+ self._transaction = connection._nested_transaction
+ else:
+ self._transaction = connection._transaction
+ self._is_active = True
+
+ @property
+ def _deactivated_from_connection(self):
+ return not self.is_active
+
+ @property
+ def is_active(self):
+ return self._is_active and self._transaction.is_active
+
+ def _deactivate(self):
+ self._is_active = False
+
+ def _do_close(self):
+ # does not actually roll back the root
+ self._deactivate()
+
+ def _do_rollback(self):
+ # does roll back the root
+ if self._is_active:
+ try:
+ self._transaction._do_deactivate()
+ finally:
+ self._deactivate()
+
+ def _do_commit(self):
+ self._deactivate()
+
+
+class RootTransaction(Transaction):
+ """Represent the "root" transaction on a :class:`_engine.Connection`.
+
+ This corresponds to the current "BEGIN/COMMIT/ROLLBACK" that's occurring
+ for the :class:`_engine.Connection`. The :class:`_engine.RootTransaction`
+ is created by calling upon the :meth:`_engine.Connection.begin` method, and
+ remains associated with the :class:`_engine.Connection` throughout its
+ active span. The current :class:`_engine.RootTransaction` in use is
+ accessible via the :meth:`_engine.Connection.get_transaction` method of
+ :class:`_engine.Connection`.
+
+ In :term:`2.0 style` use, the :class:`_future.Connection` also employs
+ "autobegin" behavior that will create a new
+ :class:`_engine.RootTransaction` whenever a connection in a
+ non-transactional state is used to emit commands on the DBAPI connection.
+ The scope of the :class:`_engine.RootTransaction` in 2.0 style
+ use can be controlled using the :meth:`_future.Connection.commit` and
+ :meth:`_future.Connection.rollback` methods.
+
+
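+ For example, a sketch of 2.0-style "commit as you go" use (assumes an
+ :class:`_engine.Engine` created with ``create_engine(..., future=True)``
+ and an assumed ``some_table``)::
+
+     with engine.connect() as conn:
+         # autobegin: a RootTransaction is created here
+         conn.execute(some_table.insert(), {"x": 1})
+
+         # ends the RootTransaction; a new one is autobegun on next use
+         conn.commit()
+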
+ """
+
+ _is_root = True
+
+ __slots__ = ("connection", "is_active")
+
+ def __init__(self, connection):
+ assert connection._transaction is None
+ if connection._trans_context_manager:
+ TransactionalContext._trans_ctx_check(connection)
+ self.connection = connection
+ self._connection_begin_impl()
+ connection._transaction = self
+
+ self.is_active = True
+
+ def _deactivate_from_connection(self):
+ if self.is_active:
+ assert self.connection._transaction is self
+ self.is_active = False
+
+ elif self.connection._transaction is not self:
+ util.warn("transaction already deassociated from connection")
+
+ @property
+ def _deactivated_from_connection(self):
+ return self.connection._transaction is not self
+
+ def _do_deactivate(self):
+ # called from a MarkerTransaction to cancel this root transaction.
+ # the transaction stays in place as connection._transaction, but
+ # is no longer active and is no longer the reset agent for the
+ # pooled connection. the connection won't support a new begin()
+ # until this transaction is explicitly closed, rolled back,
+ # or committed.
+
+ assert self.connection._transaction is self
+
+ if self.is_active:
+ self._connection_rollback_impl()
+
+ # handle case where a savepoint was created inside of a marker
+ # transaction that refers to a root. nested has to be cancelled
+ # also.
+ if self.connection._nested_transaction:
+ self.connection._nested_transaction._cancel()
+
+ self._deactivate_from_connection()
+
+ def _connection_begin_impl(self):
+ self.connection._begin_impl(self)
+
+ def _connection_rollback_impl(self):
+ self.connection._rollback_impl()
+
+ def _connection_commit_impl(self):
+ self.connection._commit_impl()
+
+ def _close_impl(self, try_deactivate=False):
+ try:
+ if self.is_active:
+ self._connection_rollback_impl()
+
+ if self.connection._nested_transaction:
+ self.connection._nested_transaction._cancel()
+ finally:
+ if self.is_active or try_deactivate:
+ self._deactivate_from_connection()
+ if self.connection._transaction is self:
+ self.connection._transaction = None
+
+ assert not self.is_active
+ assert self.connection._transaction is not self
+
+ def _do_close(self):
+ self._close_impl()
+
+ def _do_rollback(self):
+ self._close_impl(try_deactivate=True)
+
+ def _do_commit(self):
+ if self.is_active:
+ assert self.connection._transaction is self
+
+ try:
+ self._connection_commit_impl()
+ finally:
+ # whether or not commit succeeds, cancel any
+ # nested transactions, make this transaction "inactive"
+ # and remove it as a reset agent
+ if self.connection._nested_transaction:
+ self.connection._nested_transaction._cancel()
+
+ self._deactivate_from_connection()
+
+ # ...however only remove as the connection's current transaction
+ # if commit succeeded. otherwise it stays on so that a rollback
+ # needs to occur.
+ self.connection._transaction = None
+ else:
+ if self.connection._transaction is self:
+ self.connection._invalid_transaction()
+ else:
+ raise exc.InvalidRequestError("This transaction is inactive")
+
+ assert not self.is_active
+ assert self.connection._transaction is not self
+
+
+class NestedTransaction(Transaction):
+ """Represent a 'nested', or SAVEPOINT transaction.
+
+ The :class:`.NestedTransaction` object is created by calling the
+ :meth:`_engine.Connection.begin_nested` method of
+ :class:`_engine.Connection`.
+
+ When using :class:`.NestedTransaction`, the semantics of "begin" /
+ "commit" / "rollback" are as follows:
+
+ * the "begin" operation corresponds to the "BEGIN SAVEPOINT" command, where
+ the savepoint is given an explicit name that is part of the state
+ of this object.
+
+ * The :meth:`.NestedTransaction.commit` method corresponds to a
+ "RELEASE SAVEPOINT" operation, using the savepoint identifier associated
+ with this :class:`.NestedTransaction`.
+
+ * The :meth:`.NestedTransaction.rollback` method corresponds to a
+ "ROLLBACK TO SAVEPOINT" operation, using the savepoint identifier
+ associated with this :class:`.NestedTransaction`.
+
+    The rationale for mimicking the semantics of an outer transaction in
+    terms of savepoints is so that code may deal with a "savepoint"
+    transaction and an "outer" transaction in an agnostic way.
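+
+    For example, a minimal sketch (``engine`` and ``some_table`` here are
+    illustrative placeholders)::
+
+        with engine.connect() as conn:
+            with conn.begin():                 # outer transaction
+                nested = conn.begin_nested()   # emits SAVEPOINT
+                try:
+                    conn.execute(some_table.insert(), {"x": 1})
+                except Exception:
+                    nested.rollback()          # ROLLBACK TO SAVEPOINT
+                else:
+                    nested.commit()            # RELEASE SAVEPOINT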
+
+ .. seealso::
+
+ :ref:`session_begin_nested` - ORM version of the SAVEPOINT API.
+
+ """
+
+ __slots__ = ("connection", "is_active", "_savepoint", "_previous_nested")
+
+ def __init__(self, connection):
+ assert connection._transaction is not None
+ if connection._trans_context_manager:
+ TransactionalContext._trans_ctx_check(connection)
+ self.connection = connection
+ self._savepoint = self.connection._savepoint_impl()
+ self.is_active = True
+ self._previous_nested = connection._nested_transaction
+ connection._nested_transaction = self
+
+ def _deactivate_from_connection(self, warn=True):
+ if self.connection._nested_transaction is self:
+ self.connection._nested_transaction = self._previous_nested
+ elif warn:
+ util.warn(
+ "nested transaction already deassociated from connection"
+ )
+
+ @property
+ def _deactivated_from_connection(self):
+ return self.connection._nested_transaction is not self
+
+ def _cancel(self):
+ # called by RootTransaction when the outer transaction is
+ # committed, rolled back, or closed to cancel all savepoints
+ # without any action being taken
+ self.is_active = False
+ self._deactivate_from_connection()
+ if self._previous_nested:
+ self._previous_nested._cancel()
+
+ def _close_impl(self, deactivate_from_connection, warn_already_deactive):
+ try:
+ if self.is_active and self.connection._transaction.is_active:
+ self.connection._rollback_to_savepoint_impl(self._savepoint)
+ finally:
+ self.is_active = False
+
+ if deactivate_from_connection:
+ self._deactivate_from_connection(warn=warn_already_deactive)
+
+ assert not self.is_active
+ if deactivate_from_connection:
+ assert self.connection._nested_transaction is not self
+
+ def _do_deactivate(self):
+ self._close_impl(False, False)
+
+ def _do_close(self):
+ self._close_impl(True, False)
+
+ def _do_rollback(self):
+ self._close_impl(True, True)
+
+ def _do_commit(self):
+ if self.is_active:
+ try:
+ self.connection._release_savepoint_impl(self._savepoint)
+ finally:
+ # nested trans becomes inactive on failed release
+ # unconditionally. this prevents it from trying to
+ # emit SQL when it rolls back.
+ self.is_active = False
+
+ # but only de-associate from connection if it succeeded
+ self._deactivate_from_connection()
+ else:
+ if self.connection._nested_transaction is self:
+ self.connection._invalid_transaction()
+ else:
+ raise exc.InvalidRequestError(
+ "This nested transaction is inactive"
+ )
+
+
+class TwoPhaseTransaction(RootTransaction):
+ """Represent a two-phase transaction.
+
+ A new :class:`.TwoPhaseTransaction` object may be procured
+ using the :meth:`_engine.Connection.begin_twophase` method.
+
+ The interface is the same as that of :class:`.Transaction`
+ with the addition of the :meth:`prepare` method.
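+
+    A rough sketch of the two-phase flow (``engine`` and ``some_table`` are
+    illustrative, and the backend must support two-phase commit)::
+
+        with engine.connect() as conn:
+            xact = conn.begin_twophase()
+            conn.execute(some_table.insert(), {"x": 1})
+            xact.prepare()   # phase one
+            xact.commit()    # phase two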
+
+ """
+
+ __slots__ = ("connection", "is_active", "xid", "_is_prepared")
+
+ def __init__(self, connection, xid):
+ self._is_prepared = False
+ self.xid = xid
+ super(TwoPhaseTransaction, self).__init__(connection)
+
+ def prepare(self):
+ """Prepare this :class:`.TwoPhaseTransaction`.
+
+ After a PREPARE, the transaction can be committed.
+
+ """
+ if not self.is_active:
+ raise exc.InvalidRequestError("This transaction is inactive")
+ self.connection._prepare_twophase_impl(self.xid)
+ self._is_prepared = True
+
+ def _connection_begin_impl(self):
+ self.connection._begin_twophase_impl(self)
+
+ def _connection_rollback_impl(self):
+ self.connection._rollback_twophase_impl(self.xid, self._is_prepared)
+
+ def _connection_commit_impl(self):
+ self.connection._commit_twophase_impl(self.xid, self._is_prepared)
+
+
+class Engine(Connectable, log.Identified):
+ """
+ Connects a :class:`~sqlalchemy.pool.Pool` and
+ :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a
+ source of database connectivity and behavior.
+
+ This is the **SQLAlchemy 1.x version** of :class:`_engine.Engine`. For
+ the :term:`2.0 style` version, which includes some API differences,
+ see :class:`_future.Engine`.
+
+ An :class:`_engine.Engine` object is instantiated publicly using the
+ :func:`~sqlalchemy.create_engine` function.
+
+ .. seealso::
+
+ :doc:`/core/engines`
+
+ :ref:`connections_toplevel`
+
+ """
+
+ _execution_options = _EMPTY_EXECUTION_OPTS
+ _has_events = False
+ _connection_cls = Connection
+ _sqla_logger_namespace = "sqlalchemy.engine.Engine"
+ _is_future = False
+
+ _schema_translate_map = None
+
+ def __init__(
+ self,
+ pool,
+ dialect,
+ url,
+ logging_name=None,
+ echo=None,
+ query_cache_size=500,
+ execution_options=None,
+ hide_parameters=False,
+ ):
+ self.pool = pool
+ self.url = url
+ self.dialect = dialect
+ if logging_name:
+ self.logging_name = logging_name
+ self.echo = echo
+ self.hide_parameters = hide_parameters
+ if query_cache_size != 0:
+ self._compiled_cache = util.LRUCache(
+ query_cache_size, size_alert=self._lru_size_alert
+ )
+ else:
+ self._compiled_cache = None
+ log.instance_logger(self, echoflag=echo)
+ if execution_options:
+ self.update_execution_options(**execution_options)
+
+ def _lru_size_alert(self, cache):
+ if self._should_log_info:
+ self.logger.info(
+ "Compiled cache size pruning from %d items to %d. "
+ "Increase cache size to reduce the frequency of pruning.",
+ len(cache),
+ cache.capacity,
+ )
+
+ @property
+ def engine(self):
+ return self
+
+ def clear_compiled_cache(self):
+ """Clear the compiled cache associated with the dialect.
+
+ This applies **only** to the built-in cache that is established
+ via the :paramref:`_engine.create_engine.query_cache_size` parameter.
+ It will not impact any dictionary caches that were passed via the
+ :paramref:`.Connection.execution_options.query_cache` parameter.
+
+ .. versionadded:: 1.4
+
+ """
+ if self._compiled_cache:
+ self._compiled_cache.clear()
+
+ def update_execution_options(self, **opt):
+ r"""Update the default execution_options dictionary
+ of this :class:`_engine.Engine`.
+
+ The given keys/values in \**opt are added to the
+ default execution options that will be used for
+ all connections. The initial contents of this dictionary
+ can be sent via the ``execution_options`` parameter
+ to :func:`_sa.create_engine`.
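+
+        E.g. (the particular option shown is illustrative)::
+
+            engine.update_execution_options(isolation_level="READ COMMITTED")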
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+
+ :meth:`_engine.Engine.execution_options`
+
+ """
+ self._execution_options = self._execution_options.union(opt)
+ self.dispatch.set_engine_execution_options(self, opt)
+ self.dialect.set_engine_execution_options(self, opt)
+
+ def execution_options(self, **opt):
+ """Return a new :class:`_engine.Engine` that will provide
+ :class:`_engine.Connection` objects with the given execution options.
+
+ The returned :class:`_engine.Engine` remains related to the original
+ :class:`_engine.Engine` in that it shares the same connection pool and
+ other state:
+
+ * The :class:`_pool.Pool` used by the new :class:`_engine.Engine`
+ is the
+ same instance. The :meth:`_engine.Engine.dispose`
+ method will replace
+ the connection pool instance for the parent engine as well
+ as this one.
+ * Event listeners are "cascaded" - meaning, the new
+ :class:`_engine.Engine`
+ inherits the events of the parent, and new events can be associated
+ with the new :class:`_engine.Engine` individually.
+ * The logging configuration and logging_name is copied from the parent
+ :class:`_engine.Engine`.
+
+ The intent of the :meth:`_engine.Engine.execution_options` method is
+ to implement "sharding" schemes where multiple :class:`_engine.Engine`
+ objects refer to the same connection pool, but are differentiated
+ by options that would be consumed by a custom event::
+
+ primary_engine = create_engine("mysql://")
+ shard1 = primary_engine.execution_options(shard_id="shard1")
+ shard2 = primary_engine.execution_options(shard_id="shard2")
+
+ Above, the ``shard1`` engine serves as a factory for
+ :class:`_engine.Connection`
+ objects that will contain the execution option
+ ``shard_id=shard1``, and ``shard2`` will produce
+ :class:`_engine.Connection`
+ objects that contain the execution option ``shard_id=shard2``.
+
+ An event handler can consume the above execution option to perform
+ a schema switch or other operation, given a connection. Below
+ we emit a MySQL ``use`` statement to switch databases, at the same
+ time keeping track of which database we've established using the
+ :attr:`_engine.Connection.info` dictionary,
+ which gives us a persistent
+ storage space that follows the DBAPI connection::
+
+ from sqlalchemy import event
+ from sqlalchemy.engine import Engine
+
+            shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"}
+
+ @event.listens_for(Engine, "before_cursor_execute")
+ def _switch_shard(conn, cursor, stmt,
+ params, context, executemany):
+ shard_id = conn._execution_options.get('shard_id', "default")
+ current_shard = conn.info.get("current_shard", None)
+
+ if current_shard != shard_id:
+ cursor.execute("use %s" % shards[shard_id])
+ conn.info["current_shard"] = shard_id
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+ - update execution options
+ on a :class:`_engine.Connection` object.
+
+ :meth:`_engine.Engine.update_execution_options`
+ - update the execution
+ options for a given :class:`_engine.Engine` in place.
+
+ :meth:`_engine.Engine.get_execution_options`
+
+
+ """
+ return self._option_cls(self, opt)
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+        .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`_engine.Engine.execution_options`
+ """
+ return self._execution_options
+
+ @property
+ def name(self):
+ """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
+ in use by this :class:`Engine`."""
+
+ return self.dialect.name
+
+ @property
+ def driver(self):
+ """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
+ in use by this :class:`Engine`."""
+
+ return self.dialect.driver
+
+ echo = log.echo_property()
+
+ def __repr__(self):
+ return "Engine(%r)" % (self.url,)
+
+ def dispose(self, close=True):
+ """Dispose of the connection pool used by this
+ :class:`_engine.Engine`.
+
+ A new connection pool is created immediately after the old one has been
+ disposed. The previous connection pool is disposed either actively, by
+ closing out all currently checked-in connections in that pool, or
+ passively, by losing references to it but otherwise not closing any
+ connections. The latter strategy is more appropriate for an initializer
+ in a forked Python process.
+
+ :param close: if left at its default of ``True``, has the
+ effect of fully closing all **currently checked in**
+ database connections. Connections that are still checked out
+ will **not** be closed, however they will no longer be associated
+ with this :class:`_engine.Engine`,
+ so when they are closed individually, eventually the
+ :class:`_pool.Pool` which they are associated with will
+ be garbage collected and they will be closed out fully, if
+ not already closed on checkin.
+
+ If set to ``False``, the previous connection pool is de-referenced,
+ and otherwise not touched in any way.
+
+ .. versionadded:: 1.4.33 Added the :paramref:`.Engine.dispose.close`
+ parameter to allow the replacement of a connection pool in a child
+ process without interfering with the connections used by the parent
+ process.
+
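+        A rough sketch of the forked-child pattern (``os.fork`` usage is
+        illustrative; ``engine`` is assumed to already exist)::
+
+            import os
+
+            pid = os.fork()
+            if pid == 0:
+                # child: replace the inherited pool without closing
+                # connections that belong to the parent process
+                engine.dispose(close=False)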
+
+ .. seealso::
+
+ :ref:`engine_disposal`
+
+ :ref:`pooling_multiprocessing`
+
+ """
+ if close:
+ self.pool.dispose()
+ self.pool = self.pool.recreate()
+ self.dispatch.engine_disposed(self)
+
+ def _execute_default(
+ self, default, multiparams=(), params=util.EMPTY_DICT
+ ):
+ with self.connect() as conn:
+ return conn._execute_default(default, multiparams, params)
+
+ @contextlib.contextmanager
+ def _optional_conn_ctx_manager(self, connection=None):
+ if connection is None:
+ with self.connect() as conn:
+ yield conn
+ else:
+ yield connection
+
+ class _trans_ctx(object):
+ def __init__(self, conn, transaction, close_with_result):
+ self.conn = conn
+ self.transaction = transaction
+ self.close_with_result = close_with_result
+
+ def __enter__(self):
+ self.transaction.__enter__()
+ return self.conn
+
+ def __exit__(self, type_, value, traceback):
+ try:
+ self.transaction.__exit__(type_, value, traceback)
+ finally:
+ if not self.close_with_result:
+ self.conn.close()
+
+ def begin(self, close_with_result=False):
+ """Return a context manager delivering a :class:`_engine.Connection`
+ with a :class:`.Transaction` established.
+
+ E.g.::
+
+ with engine.begin() as conn:
+ conn.execute(
+ text("insert into table (x, y, z) values (1, 2, 3)")
+ )
+ conn.execute(text("my_special_procedure(5)"))
+
+ Upon successful operation, the :class:`.Transaction`
+ is committed. If an error is raised, the :class:`.Transaction`
+ is rolled back.
+
+ Legacy use only: the ``close_with_result`` flag is normally ``False``,
+ and indicates that the :class:`_engine.Connection` will be closed when
+ the operation is complete. When set to ``True``, it indicates the
+ :class:`_engine.Connection` is in "single use" mode, where the
+ :class:`_engine.CursorResult` returned by the first call to
+ :meth:`_engine.Connection.execute` will close the
+ :class:`_engine.Connection` when that :class:`_engine.CursorResult` has
+ exhausted all result rows.
+
+ .. seealso::
+
+ :meth:`_engine.Engine.connect` - procure a
+ :class:`_engine.Connection` from
+ an :class:`_engine.Engine`.
+
+ :meth:`_engine.Connection.begin` - start a :class:`.Transaction`
+ for a particular :class:`_engine.Connection`.
+
+ """
+ if self._connection_cls._is_future:
+ conn = self.connect()
+ else:
+ conn = self.connect(close_with_result=close_with_result)
+ try:
+ trans = conn.begin()
+ except:
+ with util.safe_reraise():
+ conn.close()
+ return Engine._trans_ctx(conn, trans, close_with_result)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.transaction` "
+ "method is deprecated and will be "
+ "removed in a future release. Use the :meth:`_engine.Engine.begin` "
+ "context "
+ "manager instead.",
+ )
+ def transaction(self, callable_, *args, **kwargs):
+ r"""Execute the given function within a transaction boundary.
+
+ The function is passed a :class:`_engine.Connection` newly procured
+ from :meth:`_engine.Engine.connect` as the first argument,
+ followed by the given \*args and \**kwargs.
+
+ e.g.::
+
+ def do_something(conn, x, y):
+ conn.execute(text("some statement"), {'x':x, 'y':y})
+
+ engine.transaction(do_something, 5, 10)
+
+ The operations inside the function are all invoked within the
+ context of a single :class:`.Transaction`.
+ Upon success, the transaction is committed. If an
+ exception is raised, the transaction is rolled back
+ before propagating the exception.
+
+ .. note::
+
+ The :meth:`.transaction` method is superseded by
+ the usage of the Python ``with:`` statement, which can
+ be used with :meth:`_engine.Engine.begin`::
+
+ with engine.begin() as conn:
+ conn.execute(text("some statement"), {'x':5, 'y':10})
+
+ .. seealso::
+
+ :meth:`_engine.Engine.begin` - engine-level transactional
+ context
+
+ :meth:`_engine.Connection.transaction`
+ - connection-level version of
+ :meth:`_engine.Engine.transaction`
+
+ """
+ kwargs["_sa_skip_warning"] = True
+ with self.connect() as conn:
+ return conn.transaction(callable_, *args, **kwargs)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.run_callable` "
+ "method is deprecated and will be "
+ "removed in a future release. Use the :meth:`_engine.Engine.begin` "
+ "context manager instead.",
+ )
+ def run_callable(self, callable_, *args, **kwargs):
+ r"""Given a callable object or function, execute it, passing
+ a :class:`_engine.Connection` as the first argument.
+
+ The given \*args and \**kwargs are passed subsequent
+ to the :class:`_engine.Connection` argument.
+
+ This function, along with :meth:`_engine.Connection.run_callable`,
+ allows a function to be run with a :class:`_engine.Connection`
+ or :class:`_engine.Engine` object without the need to know
+ which one is being dealt with.
+
+ """
+ kwargs["_sa_skip_warning"] = True
+ with self.connect() as conn:
+ return conn.run_callable(callable_, *args, **kwargs)
+
+ def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ with self.begin() as conn:
+ conn._run_ddl_visitor(visitorcallable, element, **kwargs)
+
+ @util.deprecated_20(
+ ":meth:`_engine.Engine.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, statement, *multiparams, **params):
+ """Executes the given construct and returns a
+ :class:`_engine.CursorResult`.
+
+ The arguments are the same as those used by
+ :meth:`_engine.Connection.execute`.
+
+ Here, a :class:`_engine.Connection` is acquired using the
+ :meth:`_engine.Engine.connect` method, and the statement executed
+ with that connection. The returned :class:`_engine.CursorResult`
+ is flagged
+ such that when the :class:`_engine.CursorResult` is exhausted and its
+ underlying cursor is closed, the :class:`_engine.Connection`
+ created here
+ will also be closed, which allows its associated DBAPI connection
+ resource to be returned to the connection pool.
+
+ """
+ connection = self.connect(close_with_result=True)
+ return connection.execute(statement, *multiparams, **params)
+
+ @util.deprecated_20(
+ ":meth:`_engine.Engine.scalar`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`; the :meth:`_future.Result.scalar` "
+ "method can then be "
+ "used to return a scalar result.",
+ )
+ def scalar(self, statement, *multiparams, **params):
+ """Executes and returns the first column of the first row.
+
+ The underlying result/cursor is closed after execution.
+ """
+ return self.execute(statement, *multiparams, **params).scalar()
+
+ def _execute_clauseelement(
+ self,
+ elem,
+ multiparams=None,
+ params=None,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ connection = self.connect(close_with_result=True)
+ return connection._execute_clauseelement(
+ elem, multiparams, params, execution_options
+ )
+
+ def _execute_compiled(
+ self,
+ compiled,
+ multiparams,
+ params,
+ execution_options=_EMPTY_EXECUTION_OPTS,
+ ):
+ connection = self.connect(close_with_result=True)
+ return connection._execute_compiled(
+ compiled, multiparams, params, execution_options
+ )
+
+ def connect(self, close_with_result=False):
+ """Return a new :class:`_engine.Connection` object.
+
+ The :class:`_engine.Connection` object is a facade that uses a DBAPI
+ connection internally in order to communicate with the database. This
+ connection is procured from the connection-holding :class:`_pool.Pool`
+ referenced by this :class:`_engine.Engine`. When the
+ :meth:`_engine.Connection.close` method of the
+ :class:`_engine.Connection` object
+ is called, the underlying DBAPI connection is then returned to the
+ connection pool, where it may be used again in a subsequent call to
+ :meth:`_engine.Engine.connect`.
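+
+        E.g., a minimal sketch (assumes ``text`` is imported from
+        ``sqlalchemy``)::
+
+            with engine.connect() as connection:
+                result = connection.execute(text("select 1"))
+                print(result.scalar())
+
+            # exiting the block returns the DBAPI connection to the pool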
+
+ """
+
+ return self._connection_cls(self, close_with_result=close_with_result)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.table_names` "
+ "method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_reflection.Inspector.get_table_names`.",
+ )
+ def table_names(self, schema=None, connection=None):
+ """Return a list of all table names available in the database.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+
+ :param connection: Optional, use a specified connection.
+ """
+ with self._optional_conn_ctx_manager(connection) as conn:
+ insp = inspection.inspect(conn)
+ return insp.get_table_names(schema)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.Engine.has_table` "
+ "method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_reflection.Inspector.has_table`.",
+ )
+ def has_table(self, table_name, schema=None):
+ """Return True if the given backend has a table of the given name.
+
+ .. seealso::
+
+ :ref:`metadata_reflection_inspector` - detailed schema inspection
+ using the :class:`_reflection.Inspector` interface.
+
+ :class:`.quoted_name` - used to pass quoting information along
+ with a schema identifier.
+
+ """
+ with self._optional_conn_ctx_manager(None) as conn:
+ insp = inspection.inspect(conn)
+ return insp.has_table(table_name, schema=schema)
+
+ def _wrap_pool_connect(self, fn, connection):
+ dialect = self.dialect
+ try:
+ return fn()
+ except dialect.dbapi.Error as e:
+ if connection is None:
+ Connection._handle_dbapi_exception_noconnection(
+ e, dialect, self
+ )
+ else:
+ util.raise_(
+ sys.exc_info()[1], with_traceback=sys.exc_info()[2]
+ )
+
+ def raw_connection(self, _connection=None):
+ """Return a "raw" DBAPI connection from the connection pool.
+
+ The returned object is a proxied version of the DBAPI
+ connection object used by the underlying driver in use.
+ The object will have all the same behavior as the real DBAPI
+ connection, except that its ``close()`` method will result in the
+ connection being returned to the pool, rather than being closed
+ for real.
+
+ This method provides direct DBAPI connection access for
+ special situations when the API provided by
+ :class:`_engine.Connection`
+ is not needed. When a :class:`_engine.Connection` object is already
+ present, the DBAPI connection is available using
+ the :attr:`_engine.Connection.connection` accessor.
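+
+        E.g., a rough sketch of checkout and return (the statement shown is
+        illustrative)::
+
+            dbapi_conn = engine.raw_connection()
+            try:
+                cursor = dbapi_conn.cursor()
+                cursor.execute("select 1")
+                rows = cursor.fetchall()
+                cursor.close()
+            finally:
+                dbapi_conn.close()  # returns the connection to the pool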
+
+ .. seealso::
+
+ :ref:`dbapi_connections`
+
+ """
+ return self._wrap_pool_connect(self.pool.connect, _connection)
+
+
+class OptionEngineMixin(object):
+ _sa_propagate_class_events = False
+
+ def __init__(self, proxied, execution_options):
+ self._proxied = proxied
+ self.url = proxied.url
+ self.dialect = proxied.dialect
+ self.logging_name = proxied.logging_name
+ self.echo = proxied.echo
+ self._compiled_cache = proxied._compiled_cache
+ self.hide_parameters = proxied.hide_parameters
+ log.instance_logger(self, echoflag=self.echo)
+
+ # note: this will propagate events that are assigned to the parent
+ # engine after this OptionEngine is created. Since we share
+ # the events of the parent we also disallow class-level events
+ # to apply to the OptionEngine class directly.
+ #
+ # the other way this can work would be to transfer existing
+ # events only, using:
+ # self.dispatch._update(proxied.dispatch)
+ #
+ # that might be more appropriate however it would be a behavioral
+ # change for logic that assigns events to the parent engine and
+ # would like it to take effect for the already-created sub-engine.
+ self.dispatch = self.dispatch._join(proxied.dispatch)
+
+ self._execution_options = proxied._execution_options
+ self.update_execution_options(**execution_options)
+
+ def _get_pool(self):
+ return self._proxied.pool
+
+ def _set_pool(self, pool):
+ self._proxied.pool = pool
+
+ pool = property(_get_pool, _set_pool)
+
+ def _get_has_events(self):
+ return self._proxied._has_events or self.__dict__.get(
+ "_has_events", False
+ )
+
+ def _set_has_events(self, value):
+ self.__dict__["_has_events"] = value
+
+ _has_events = property(_get_has_events, _set_has_events)
+
+
+class OptionEngine(OptionEngineMixin, Engine):
+ pass
+
+
+Engine._option_cls = OptionEngine
diff --git a/lib/sqlalchemy/engine/characteristics.py b/lib/sqlalchemy/engine/characteristics.py
new file mode 100644
index 0000000..c00bff4
--- /dev/null
+++ b/lib/sqlalchemy/engine/characteristics.py
@@ -0,0 +1,56 @@
+import abc
+
+from ..util import ABC
+
+
+class ConnectionCharacteristic(ABC):
+ """An abstract base for an object that can set, get and reset a
+ per-connection characteristic, typically one that gets reset when the
+ connection is returned to the connection pool.
+
+    Transaction isolation is the canonical example, and the
+ ``IsolationLevelCharacteristic`` implementation provides this for the
+ ``DefaultDialect``.
+
+ The ``ConnectionCharacteristic`` class should call upon the ``Dialect`` for
+ the implementation of each method. The object exists strictly to serve as
+ a dialect visitor that can be placed into the
+ ``DefaultDialect.connection_characteristics`` dictionary where it will take
+ effect for calls to :meth:`_engine.Connection.execution_options` and
+ related APIs.
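+
+    A dialect makes a characteristic take effect by registering it in a
+    mapping keyed by the corresponding execution option name, roughly as in
+    the following sketch (the exact attribute layout is an assumption for
+    illustration)::
+
+        from sqlalchemy import util
+        from sqlalchemy.engine.default import DefaultDialect
+
+        class MyDialect(DefaultDialect):
+            connection_characteristics = util.immutabledict(
+                {"isolation_level": IsolationLevelCharacteristic()}
+            )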
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ()
+
+ transactional = False
+
+ @abc.abstractmethod
+ def reset_characteristic(self, dialect, dbapi_conn):
+ """Reset the characteristic on the connection to its default value."""
+
+ @abc.abstractmethod
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ """set characteristic on the connection to a given value."""
+
+ @abc.abstractmethod
+ def get_characteristic(self, dialect, dbapi_conn):
+ """Given a DBAPI connection, get the current value of the
+ characteristic.
+
+ """
+
+
+class IsolationLevelCharacteristic(ConnectionCharacteristic):
+ transactional = True
+
+ def reset_characteristic(self, dialect, dbapi_conn):
+ dialect.reset_isolation_level(dbapi_conn)
+
+ def set_characteristic(self, dialect, dbapi_conn, value):
+ dialect.set_isolation_level(dbapi_conn, value)
+
+ def get_characteristic(self, dialect, dbapi_conn):
+ return dialect.get_isolation_level(dbapi_conn)
diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py
new file mode 100644
index 0000000..b9886b7
--- /dev/null
+++ b/lib/sqlalchemy/engine/create.py
@@ -0,0 +1,743 @@
+# engine/create.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from . import base
+from . import url as _url
+from .mock import create_mock_engine
+from .. import event
+from .. import exc
+from .. import pool as poollib
+from .. import util
+from ..sql import compiler
+
+
+@util.deprecated_params(
+ strategy=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.strategy` keyword is deprecated, "
+ "and the only argument accepted is 'mock'; please use "
+ ":func:`.create_mock_engine` going forward. For general "
+ "customization of create_engine which may have been accomplished "
+ "using strategies, see :class:`.CreateEnginePlugin`.",
+ ),
+ empty_in_strategy=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is "
+ "deprecated, and no longer has any effect. All IN expressions "
+ "are now rendered using "
+ 'the "expanding parameter" strategy which renders a set of bound'
+ 'expressions, or an "empty set" SELECT, at statement execution'
+ "time.",
+ ),
+ case_sensitive=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.case_sensitive` parameter "
+ "is deprecated and will be removed in a future release. "
+ "Applications should work with result column names in a case "
+ "sensitive fashion.",
+ ),
+)
+def create_engine(url, **kwargs):
+ """Create a new :class:`_engine.Engine` instance.
+
+ The standard calling form is to send the :ref:`URL <database_urls>` as the
+ first positional argument, usually a string
+ that indicates database dialect and connection arguments::
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+
+ .. note::
+
+ Please review :ref:`database_urls` for general guidelines in composing
+ URL strings. In particular, special characters, such as those often
+ part of passwords, must be URL encoded to be properly parsed.
+
+ Additional keyword arguments may then follow it which
+ establish various options on the resulting :class:`_engine.Engine`
+ and its underlying :class:`.Dialect` and :class:`_pool.Pool`
+ constructs::
+
+ engine = create_engine("mysql://scott:tiger@hostname/dbname",
+ encoding='latin1', echo=True)
+
+ The string form of the URL is
+ ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where
+ ``dialect`` is a database name such as ``mysql``, ``oracle``,
+ ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as
+ ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively,
+ the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`.
+
+ ``**kwargs`` takes a wide variety of options which are routed
+ towards their appropriate components. Arguments may be specific to
+ the :class:`_engine.Engine`, the underlying :class:`.Dialect`,
+ as well as the
+ :class:`_pool.Pool`. Specific dialects also accept keyword arguments that
+ are unique to that dialect. Here, we describe the parameters
+ that are common to most :func:`_sa.create_engine()` usage.
+
+ Once established, the newly resulting :class:`_engine.Engine` will
+ request a connection from the underlying :class:`_pool.Pool` once
+ :meth:`_engine.Engine.connect` is called, or a method which depends on it
+ such as :meth:`_engine.Engine.execute` is invoked. The
+ :class:`_pool.Pool` in turn
+ will establish the first actual DBAPI connection when this request
+ is received. The :func:`_sa.create_engine` call itself does **not**
+ establish any actual DBAPI connections directly.
+
+ .. seealso::
+
+ :doc:`/core/engines`
+
+ :doc:`/dialects/index`
+
+ :ref:`connections_toplevel`
+
+ :param case_sensitive: if False, result column names
+ will match in a case-insensitive fashion, that is,
+ ``row['SomeColumn']``.
+
+ :param connect_args: a dictionary of options which will be
+ passed directly to the DBAPI's ``connect()`` method as
+ additional keyword arguments. See the example
+ at :ref:`custom_dbapi_args`.
+
+ :param convert_unicode=False: if set to True, causes
+ all :class:`.String` datatypes to act as though the
+ :paramref:`.String.convert_unicode` flag has been set to ``True``,
+ regardless of a setting of ``False`` on an individual :class:`.String`
+ type. This has the effect of causing all :class:`.String` -based
+ columns to accommodate Python Unicode objects directly as though the
+ datatype were the :class:`.Unicode` type.
+
+ .. deprecated:: 1.3
+
+ The :paramref:`_sa.create_engine.convert_unicode` parameter
+ is deprecated and will be removed in a future release.
+ All modern DBAPIs now support Python Unicode directly and this
+ parameter is unnecessary.
+
+ :param creator: a callable which returns a DBAPI connection.
+ This creation function will be passed to the underlying
+ connection pool and will be used to create all new database
+ connections. Usage of this function causes connection
+ parameters specified in the URL argument to be bypassed.
+
+ This hook is not as flexible as the newer
+ :meth:`_events.DialectEvents.do_connect` hook which allows complete
+ control over how a connection is made to the database, given the full
+ set of URL arguments and state beforehand.
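+
+      E.g., a rough sketch using the stdlib ``sqlite3`` driver (the database
+      path is illustrative)::
+
+          import sqlite3
+
+          def connect():
+              return sqlite3.connect("/tmp/illustrative.db")
+
+          engine = create_engine("sqlite://", creator=connect)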
+
+ .. seealso::
+
+ :meth:`_events.DialectEvents.do_connect` - event hook that allows
+ full control over DBAPI connection mechanics.
+
+ :ref:`custom_dbapi_args`
+
+ :param echo=False: if True, the Engine will log all statements
+ as well as a ``repr()`` of their parameter lists to the default log
+ handler, which defaults to ``sys.stdout`` for output. If set to the
+ string ``"debug"``, result rows will be printed to the standard output
+ as well. The ``echo`` attribute of ``Engine`` can be modified at any
+ time to turn logging on and off; direct control of logging is also
+ available using the standard Python ``logging`` module.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+
+ :param echo_pool=False: if True, the connection pool will log
+ informational output such as when connections are invalidated
+ as well as when connections are recycled to the default log handler,
+ which defaults to ``sys.stdout`` for output. If set to the string
+ ``"debug"``, the logging will include pool checkouts and checkins.
+ Direct control of logging is also available using the standard Python
+ ``logging`` module.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+
+ :param empty_in_strategy: No longer used; SQLAlchemy now uses
+ "empty set" behavior for IN in all cases.
+
+ :param enable_from_linting: defaults to True. Will emit a warning
+ if a given SELECT statement is found to have un-linked FROM elements
+ which would cause a cartesian product.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`change_4737`
+
+ :param encoding: **legacy Python 2 value only, where it only applies to
+ specific DBAPIs, not used in Python 3 for any modern DBAPI driver.
+ Please refer to individual dialect documentation for client encoding
+ behaviors.** Defaults to the string value ``utf-8``. This value
+ refers **only** to the character encoding that is used when SQLAlchemy
+ sends or receives data from a :term:`DBAPI` that does not support
+ Python Unicode and **is only used under Python 2**, only for certain
+ DBAPI drivers, and only in certain circumstances. **Python 3 users
+ please DISREGARD this parameter and refer to the documentation for the
+ specific dialect in use in order to configure character encoding
+ behavior.**
+
+ .. note:: The ``encoding`` parameter deals only with in-Python
+ encoding issues that were prevalent with **some DBAPIS only**
+ under **Python 2 only**. Under Python 3 it is not used by
+ any modern dialect. For DBAPIs that require
+ client encoding configurations, which are most of those outside
+ of SQLite, please consult specific :ref:`dialect documentation
+ <dialect_toplevel>` for details.
+
+ All modern DBAPIs that work in Python 3 necessarily feature direct
+ support for Python unicode strings. Under Python 2, this was not
+ always the case. For those scenarios where the DBAPI is detected as
+ not supporting a Python ``unicode`` object under Python 2, this
+ encoding is used to determine the source/destination encoding. It is
+ **not used** for those cases where the DBAPI handles unicode directly.
+
+ To properly configure a system to accommodate Python ``unicode``
+ objects, the DBAPI should be configured to handle unicode to the
+ greatest degree as is appropriate - see the notes on unicode pertaining
+ to the specific target database in use at :ref:`dialect_toplevel`.
+
+ Areas where string encoding may need to be accommodated
+ outside of the DBAPI, nearly always under **Python 2 only**,
+ include zero or more of:
+
+ * the values passed to bound parameters, corresponding to
+ the :class:`.Unicode` type or the :class:`.String` type
+ when ``convert_unicode`` is ``True``;
+ * the values returned in result set columns corresponding
+ to the :class:`.Unicode` type or the :class:`.String`
+ type when ``convert_unicode`` is ``True``;
+ * the string SQL statement passed to the DBAPI's
+ ``cursor.execute()`` method;
+ * the string names of the keys in the bound parameter
+ dictionary passed to the DBAPI's ``cursor.execute()``
+ as well as ``cursor.setinputsizes()`` methods;
+ * the string column names retrieved from the DBAPI's
+ ``cursor.description`` attribute.
+
+ When using Python 3, the DBAPI is required to support all of the above
+ values as Python ``unicode`` objects, which in Python 3 are just known
+ as ``str``. In Python 2, the DBAPI does not specify unicode behavior
+ at all, so SQLAlchemy must make decisions for each of the above values
+ on a per-DBAPI basis - implementations are completely inconsistent in
+ their behavior.
+
+ :param execution_options: Dictionary execution options which will
+ be applied to all connections. See
+ :meth:`~sqlalchemy.engine.Connection.execution_options`
+
+ :param future: Use the 2.0 style :class:`_future.Engine` and
+ :class:`_future.Connection` API.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`migration_20_toplevel`
+
+ :param hide_parameters: Boolean, when set to True, SQL statement parameters
+ will not be displayed in INFO logging nor will they be formatted into
+ the string representation of :class:`.StatementError` objects.
+
+ .. versionadded:: 1.3.8
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+ :param implicit_returning=True: Legacy flag that when set to ``False``
+ will disable the use of ``RETURNING`` on supporting backends where it
+ would normally be used to fetch newly generated primary key values for
+ single-row INSERT statements that do not otherwise specify a RETURNING
+ clause. This behavior applies primarily to the PostgreSQL, Oracle,
+ SQL Server backends.
+
+ .. warning:: this flag originally allowed the "implicit returning"
+ feature to be *enabled* back when it was very new and there was not
+ well-established database support. In modern SQLAlchemy, this flag
+ should **always be set to True**. Some SQLAlchemy features will
+ fail to function properly if this flag is set to ``False``.
+
+ :param isolation_level: this string parameter is interpreted by various
+ dialects in order to affect the transaction isolation level of the
+ database connection. The parameter essentially accepts some subset of
+ these string arguments: ``"SERIALIZABLE"``, ``"REPEATABLE READ"``,
+ ``"READ COMMITTED"``, ``"READ UNCOMMITTED"`` and ``"AUTOCOMMIT"``.
+ Behavior here varies per backend, and
+ individual dialects should be consulted directly.
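+
+      E.g. (the URL shown is illustrative)::
+
+          engine = create_engine(
+              "postgresql://scott:tiger@localhost/test",
+              isolation_level="READ COMMITTED",
+          )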
+
+ Note that the isolation level can also be set on a
+ per-:class:`_engine.Connection` basis as well, using the
+ :paramref:`.Connection.execution_options.isolation_level`
+ feature.
+
+ .. seealso::
+
+ :ref:`dbapi_autocommit`
+
+ :param json_deserializer: for dialects that support the
+ :class:`_types.JSON`
+ datatype, this is a Python callable that will convert a JSON string
+ to a Python object. By default, the Python ``json.loads`` function is
+ used.
+
+ .. versionchanged:: 1.3.7 The SQLite dialect renamed this from
+ ``_json_deserializer``.
+
+ :param json_serializer: for dialects that support the :class:`_types.JSON`
+ datatype, this is a Python callable that will render a given object
+ as JSON. By default, the Python ``json.dumps`` function is used.
+
+ .. versionchanged:: 1.3.7 The SQLite dialect renamed this from
+ ``_json_serializer``.
+
+
+ :param label_length=None: optional integer value which limits
+ the size of dynamically generated column labels to that many
+ characters. If less than 6, labels are generated as
+ "_(counter)". If ``None``, the value of
+ ``dialect.max_identifier_length``, which may be affected via the
+ :paramref:`_sa.create_engine.max_identifier_length` parameter,
+ is used instead. The value of
+ :paramref:`_sa.create_engine.label_length`
+ may not be larger than that of
+      :paramref:`_sa.create_engine.max_identifier_length`.
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.max_identifier_length`
+
+ :param listeners: A list of one or more
+ :class:`~sqlalchemy.interfaces.PoolListener` objects which will
+ receive connection pool events.
+
+ :param logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.engine" logger. Defaults to a hexstring of the
+ object's id.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+ :paramref:`_engine.Connection.execution_options.logging_token`
+
+
+
+ :param max_identifier_length: integer; override the max_identifier_length
+      determined by the dialect. If ``None`` or zero, has no effect. This
+ is the database's configured maximum number of characters that may be
+ used in a SQL identifier such as a table name, column name, or label
+ name. All dialects determine this value automatically, however in the
+ case of a new database version for which this value has changed but
+ SQLAlchemy's dialect has not been adjusted, the value may be passed
+ here.
+
+ .. versionadded:: 1.3.9
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.label_length`
+
+ :param max_overflow=10: the number of connections to allow in
+ connection pool "overflow", that is connections that can be
+ opened above and beyond the pool_size setting, which defaults
+      to five. This is only used with :class:`~sqlalchemy.pool.QueuePool`.
+
+ :param module=None: reference to a Python module object (the module
+ itself, not its string name). Specifies an alternate DBAPI module to
+ be used by the engine's dialect. Each sub-dialect references a
+ specific DBAPI which will be imported before first connect. This
+ parameter causes the import to be bypassed, and the given module to
+ be used instead. Can be used for testing of DBAPIs as well as to
+ inject "mock" DBAPI implementations into the :class:`_engine.Engine`.
+
+ :param paramstyle=None: The `paramstyle <https://legacy.python.org/dev/peps/pep-0249/#paramstyle>`_
+ to use when rendering bound parameters. This style defaults to the
+ one recommended by the DBAPI itself, which is retrieved from the
+ ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept
+ more than one paramstyle, and in particular it may be desirable
+ to change a "named" paramstyle into a "positional" one, or vice versa.
+ When this attribute is passed, it should be one of the values
+ ``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or
+ ``"pyformat"``, and should correspond to a parameter style known
+ to be supported by the DBAPI in use.
+
+ :param pool=None: an already-constructed instance of
+ :class:`~sqlalchemy.pool.Pool`, such as a
+ :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this
+ pool will be used directly as the underlying connection pool
+ for the engine, bypassing whatever connection parameters are
+ present in the URL argument. For information on constructing
+ connection pools manually, see :ref:`pooling_toplevel`.
+
+ :param poolclass=None: a :class:`~sqlalchemy.pool.Pool`
+ subclass, which will be used to create a connection pool
+ instance using the connection parameters given in the URL. Note
+ this differs from ``pool`` in that you don't actually
+ instantiate the pool in this case, you just indicate what type
+ of pool to be used.
+
+ :param pool_logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
+ id.
+
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+
+ :param pool_pre_ping: boolean, if True will enable the connection pool
+ "pre-ping" feature that tests connections for liveness upon
+ each checkout.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`pool_disconnects_pessimistic`
+
+ :param pool_size=5: the number of connections to keep open
+      inside the connection pool. This is used with
+ :class:`~sqlalchemy.pool.QueuePool` as
+ well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With
+ :class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting
+ of 0 indicates no limit; to disable pooling, set ``poolclass`` to
+ :class:`~sqlalchemy.pool.NullPool` instead.
+
+ :param pool_recycle=-1: this setting causes the pool to recycle
+ connections after the given number of seconds has passed. It
+ defaults to -1, or no timeout. For example, setting to 3600
+ means connections will be recycled after one hour. Note that
+ MySQL in particular will disconnect automatically if no
+ activity is detected on a connection for eight hours (although
+ this is configurable with the MySQLDB connection itself and the
+ server configuration as well).
+
+ .. seealso::
+
+ :ref:`pool_setting_recycle`
+
+ :param pool_reset_on_return='rollback': set the
+ :paramref:`_pool.Pool.reset_on_return` parameter of the underlying
+ :class:`_pool.Pool` object, which can be set to the values
+ ``"rollback"``, ``"commit"``, or ``None``.
+
+ .. seealso::
+
+ :paramref:`_pool.Pool.reset_on_return`
+
+ :param pool_timeout=30: number of seconds to wait before giving
+ up on getting a connection from the pool. This is only used
+ with :class:`~sqlalchemy.pool.QueuePool`. This can be a float but is
+ subject to the limitations of Python time functions which may not be
+ reliable in the tens of milliseconds.
+
+ .. note: don't use 30.0 above, it seems to break with the :param tag
+
+ :param pool_use_lifo=False: use LIFO (last-in-first-out) when retrieving
+ connections from :class:`.QueuePool` instead of FIFO
+ (first-in-first-out). Using LIFO, a server-side timeout scheme can
+      reduce the number of connections used during non-peak periods of
+ use. When planning for server-side timeouts, ensure that a recycle or
+ pre-ping strategy is in use to gracefully handle stale connections.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`pool_use_lifo`
+
+ :ref:`pool_disconnects`
+
+ :param plugins: string list of plugin names to load. See
+ :class:`.CreateEnginePlugin` for background.
+
+ .. versionadded:: 1.2.3
+
+ :param query_cache_size: size of the cache used to cache the SQL string
+ form of queries. Set to zero to disable caching.
+
+ The cache is pruned of its least recently used items when its size reaches
+ N * 1.5. Defaults to 500, meaning the cache will always store at least
+ 500 SQL statements when filled, and will grow up to 750 items at which
+ point it is pruned back down to 500 by removing the 250 least recently
+ used items.
+
+ Caching is accomplished on a per-statement basis by generating a
+ cache key that represents the statement's structure, then generating
+ string SQL for the current dialect only if that key is not present
+ in the cache. All statements support caching, however some features
+ such as an INSERT with a large set of parameters will intentionally
+ bypass the cache. SQL logging will indicate statistics for each
+      statement whether or not it was pulled from the cache.
+
+ .. note:: some ORM functions related to unit-of-work persistence as well
+ as some attribute loading strategies will make use of individual
+ per-mapper caches outside of the main cache.
+
+
+ .. seealso::
+
+ :ref:`sql_caching`
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ if "strategy" in kwargs:
+ strat = kwargs.pop("strategy")
+ if strat == "mock":
+ return create_mock_engine(url, **kwargs)
+ else:
+ raise exc.ArgumentError("unknown strategy: %r" % strat)
+
+ kwargs.pop("empty_in_strategy", None)
+
+ # create url.URL object
+ u = _url.make_url(url)
+
+ u, plugins, kwargs = u._instantiate_plugins(kwargs)
+
+ entrypoint = u._get_entrypoint()
+ dialect_cls = entrypoint.get_dialect_cls(u)
+
+ if kwargs.pop("_coerce_config", False):
+
+ def pop_kwarg(key, default=None):
+ value = kwargs.pop(key, default)
+ if key in dialect_cls.engine_config_types:
+ value = dialect_cls.engine_config_types[key](value)
+ return value
+
+ else:
+ pop_kwarg = kwargs.pop
+
+ dialect_args = {}
+ # consume dialect arguments from kwargs
+ for k in util.get_cls_kwargs(dialect_cls):
+ if k in kwargs:
+ dialect_args[k] = pop_kwarg(k)
+
+ dbapi = kwargs.pop("module", None)
+ if dbapi is None:
+ dbapi_args = {}
+ for k in util.get_func_kwargs(dialect_cls.dbapi):
+ if k in kwargs:
+ dbapi_args[k] = pop_kwarg(k)
+ dbapi = dialect_cls.dbapi(**dbapi_args)
+
+ dialect_args["dbapi"] = dbapi
+
+ dialect_args.setdefault("compiler_linting", compiler.NO_LINTING)
+ enable_from_linting = kwargs.pop("enable_from_linting", True)
+ if enable_from_linting:
+ dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS
+
+ for plugin in plugins:
+ plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
+
+ # create dialect
+ dialect = dialect_cls(**dialect_args)
+
+ # assemble connection arguments
+ (cargs, cparams) = dialect.create_connect_args(u)
+ cparams.update(pop_kwarg("connect_args", {}))
+ cargs = list(cargs) # allow mutability
+
+ # look for existing pool or create
+ pool = pop_kwarg("pool", None)
+ if pool is None:
+
+ def connect(connection_record=None):
+ if dialect._has_events:
+ for fn in dialect.dispatch.do_connect:
+ connection = fn(dialect, connection_record, cargs, cparams)
+ if connection is not None:
+ return connection
+ return dialect.connect(*cargs, **cparams)
+
+ creator = pop_kwarg("creator", connect)
+
+ poolclass = pop_kwarg("poolclass", None)
+ if poolclass is None:
+ poolclass = dialect.get_dialect_pool_class(u)
+ pool_args = {"dialect": dialect}
+
+ # consume pool arguments from kwargs, translating a few of
+ # the arguments
+ translate = {
+ "logging_name": "pool_logging_name",
+ "echo": "echo_pool",
+ "timeout": "pool_timeout",
+ "recycle": "pool_recycle",
+ "events": "pool_events",
+ "reset_on_return": "pool_reset_on_return",
+ "pre_ping": "pool_pre_ping",
+ "use_lifo": "pool_use_lifo",
+ }
+ for k in util.get_cls_kwargs(poolclass):
+ tk = translate.get(k, k)
+ if tk in kwargs:
+ pool_args[k] = pop_kwarg(tk)
+
+ for plugin in plugins:
+ plugin.handle_pool_kwargs(poolclass, pool_args)
+
+ pool = poolclass(creator, **pool_args)
+ else:
+ if isinstance(pool, poollib.dbapi_proxy._DBProxy):
+ pool = pool.get_pool(*cargs, **cparams)
+
+ pool._dialect = dialect
+
+ # create engine.
+ if pop_kwarg("future", False):
+ from sqlalchemy import future
+
+ default_engine_class = future.Engine
+ else:
+ default_engine_class = base.Engine
+
+ engineclass = kwargs.pop("_future_engine_class", default_engine_class)
+
+ engine_args = {}
+ for k in util.get_cls_kwargs(engineclass):
+ if k in kwargs:
+ engine_args[k] = pop_kwarg(k)
+
+ # internal flags used by the test suite for instrumenting / proxying
+ # engines with mocks etc.
+ _initialize = kwargs.pop("_initialize", True)
+ _wrap_do_on_connect = kwargs.pop("_wrap_do_on_connect", None)
+
+ # all kwargs should be consumed
+ if kwargs:
+ raise TypeError(
+ "Invalid argument(s) %s sent to create_engine(), "
+ "using configuration %s/%s/%s. Please check that the "
+ "keyword arguments are appropriate for this combination "
+ "of components."
+ % (
+ ",".join("'%s'" % k for k in kwargs),
+ dialect.__class__.__name__,
+ pool.__class__.__name__,
+ engineclass.__name__,
+ )
+ )
+
+ engine = engineclass(pool, dialect, u, **engine_args)
+
+ if _initialize:
+ do_on_connect = dialect.on_connect_url(u)
+ if do_on_connect:
+ if _wrap_do_on_connect:
+ do_on_connect = _wrap_do_on_connect(do_on_connect)
+
+ def on_connect(dbapi_connection, connection_record):
+ do_on_connect(dbapi_connection)
+
+ event.listen(pool, "connect", on_connect)
+
+ def first_connect(dbapi_connection, connection_record):
+ c = base.Connection(
+ engine,
+ connection=dbapi_connection,
+ _has_events=False,
+ # reconnecting will be a reentrant condition, so if the
+ # connection goes away, Connection is then closed
+ _allow_revalidate=False,
+ )
+ c._execution_options = util.EMPTY_DICT
+
+ try:
+ dialect.initialize(c)
+ finally:
+ # note that "invalidated" and "closed" are mutually
+ # exclusive in 1.4 Connection.
+ if not c.invalidated and not c.closed:
+ # transaction is rolled back otherwise, tested by
+ # test/dialect/postgresql/test_dialect.py
+ # ::MiscBackendTest::test_initial_transaction_state
+ dialect.do_rollback(c.connection)
+
+ # previously, the "first_connect" event was used here, which was then
+ # scaled back if the "on_connect" handler were present. now,
+ # since "on_connect" is virtually always present, just use
+ # "connect" event with once_unless_exception in all cases so that
+ # the connection event flow is consistent in all cases.
+ event.listen(
+ pool, "connect", first_connect, _once_unless_exception=True
+ )
+
+ dialect_cls.engine_created(engine)
+ if entrypoint is not dialect_cls:
+ entrypoint.engine_created(engine)
+
+ for plugin in plugins:
+ plugin.engine_created(engine)
+
+ return engine
+
+
+def engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+ """Create a new Engine instance using a configuration dictionary.
+
+ The dictionary is typically produced from a config file.
+
+ The keys of interest to ``engine_from_config()`` should be prefixed, e.g.
+ ``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument
+ indicates the prefix to be searched for. Each matching key (after the
+ prefix is stripped) is treated as though it were the corresponding keyword
+ argument to a :func:`_sa.create_engine` call.
+
+ The only required key is (assuming the default prefix) ``sqlalchemy.url``,
+ which provides the :ref:`database URL <database_urls>`.
+
+ A select set of keyword arguments will be "coerced" to their
+ expected type based on string values. The set of arguments
+ is extensible per-dialect using the ``engine_config_types`` accessor.
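+
+    E.g., a minimal sketch (the configuration values shown are
+    illustrative)::
+
+        config = {
+            "sqlalchemy.url": "postgresql://scott:tiger@localhost/test",
+            "sqlalchemy.echo": "true",
+            "sqlalchemy.pool_recycle": "3600",
+        }
+
+        engine = engine_from_config(config)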
+
+ :param configuration: A dictionary (typically produced from a config file,
+ but this is not a requirement). Items whose keys start with the value
+ of 'prefix' will have that prefix stripped, and will then be passed to
+ :func:`_sa.create_engine`.
+
+ :param prefix: Prefix to match and then strip from keys
+ in 'configuration'.
+
+ :param kwargs: Each keyword argument to ``engine_from_config()`` itself
+ overrides the corresponding item taken from the 'configuration'
+ dictionary. Keyword arguments should *not* be prefixed.
+
+ """
+
+ options = dict(
+ (key[len(prefix) :], configuration[key])
+ for key in configuration
+ if key.startswith(prefix)
+ )
+ options["_coerce_config"] = True
+ options.update(kwargs)
+ url = options.pop("url")
+ return create_engine(url, **options)
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
new file mode 100644
index 0000000..774916d
--- /dev/null
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -0,0 +1,1942 @@
+# engine/cursor.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define cursor-specific result set constructs including
+:class:`.BaseCursorResult`, :class:`.CursorResult`."""
+
+
+import collections
+import functools
+
+from .result import Result
+from .result import ResultMetaData
+from .result import SimpleResultMetaData
+from .result import tuplegetter
+from .row import LegacyRow
+from .. import exc
+from .. import util
+from ..sql import expression
+from ..sql import sqltypes
+from ..sql import util as sql_util
+from ..sql.base import _generative
+from ..sql.compiler import RM_NAME
+from ..sql.compiler import RM_OBJECTS
+from ..sql.compiler import RM_RENDERED_NAME
+from ..sql.compiler import RM_TYPE
+
+_UNPICKLED = util.symbol("unpickled")
+
+
+# metadata entry tuple indexes.
+# using raw tuple is faster than namedtuple.
+MD_INDEX = 0 # integer index in cursor.description
+MD_RESULT_MAP_INDEX = 1 # integer index in compiled._result_columns
+MD_OBJECTS = 2 # other string keys and ColumnElement obj that can match
+MD_LOOKUP_KEY = 3 # string key we usually expect for key-based lookup
+MD_RENDERED_NAME = 4 # name that is usually in cursor.description
+MD_PROCESSOR = 5 # callable to process a result value into a row
+MD_UNTRANSLATED = 6 # raw name from cursor.description
+
+
+class CursorResultMetaData(ResultMetaData):
+ """Result metadata for DBAPI cursors."""
+
+ __slots__ = (
+ "_keymap",
+ "case_sensitive",
+ "_processors",
+ "_keys",
+ "_keymap_by_result_column_idx",
+ "_tuplefilter",
+ "_translated_indexes",
+ "_safe_for_cache"
+ # don't need _unique_filters support here for now. Can be added
+ # if a need arises.
+ )
+
+ returns_rows = True
+
+ def _has_key(self, key):
+ return key in self._keymap
+
+ def _for_freeze(self):
+ return SimpleResultMetaData(
+ self._keys,
+ extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
+ )
+
+ def _reduce(self, keys):
+ recs = list(self._metadata_for_keys(keys))
+
+ indexes = [rec[MD_INDEX] for rec in recs]
+ new_keys = [rec[MD_LOOKUP_KEY] for rec in recs]
+
+ if self._translated_indexes:
+ indexes = [self._translated_indexes[idx] for idx in indexes]
+
+ tup = tuplegetter(*indexes)
+
+ new_metadata = self.__class__.__new__(self.__class__)
+ new_metadata.case_sensitive = self.case_sensitive
+ new_metadata._processors = self._processors
+ new_metadata._keys = new_keys
+ new_metadata._tuplefilter = tup
+ new_metadata._translated_indexes = indexes
+
+ new_recs = [
+ (index,) + rec[1:]
+ for index, rec in enumerate(self._metadata_for_keys(keys))
+ ]
+ new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
+
+ # TODO: need unit test for:
+ # result = connection.execute("raw sql, no columns").scalars()
+ # without the "or ()" it's failing because MD_OBJECTS is None
+ new_metadata._keymap.update(
+ {
+ e: new_rec
+ for new_rec in new_recs
+ for e in new_rec[MD_OBJECTS] or ()
+ }
+ )
+
+ return new_metadata
+
+ def _adapt_to_context(self, context):
+ """When using a cached Compiled construct that has a _result_map,
+ for a new statement that used the cached Compiled, we need to ensure
+ the keymap has the Column objects from our new statement as keys.
+ So here we rewrite keymap with new entries for the new columns
+ as matched to those of the cached statement.
+
+ """
+
+ if not context.compiled._result_columns:
+ return self
+
+ compiled_statement = context.compiled.statement
+ invoked_statement = context.invoked_statement
+
+ if compiled_statement is invoked_statement:
+ return self
+
+ # make a copy and add the columns from the invoked statement
+ # to the result map.
+ md = self.__class__.__new__(self.__class__)
+
+ md._keymap = dict(self._keymap)
+
+ keymap_by_position = self._keymap_by_result_column_idx
+
+ for idx, new in enumerate(invoked_statement._all_selected_columns):
+ try:
+ rec = keymap_by_position[idx]
+ except KeyError:
+ # this can happen when there are bogus column entries
+ # in a TextualSelect
+ pass
+ else:
+ md._keymap[new] = rec
+
+ md.case_sensitive = self.case_sensitive
+ md._processors = self._processors
+ assert not self._tuplefilter
+ md._tuplefilter = None
+ md._translated_indexes = None
+ md._keys = self._keys
+ md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
+ md._safe_for_cache = self._safe_for_cache
+ return md
+
+ def __init__(self, parent, cursor_description):
+ context = parent.context
+ dialect = context.dialect
+ self._tuplefilter = None
+ self._translated_indexes = None
+ self.case_sensitive = dialect.case_sensitive
+ self._safe_for_cache = False
+
+ if context.result_column_struct:
+ (
+ result_columns,
+ cols_are_ordered,
+ textual_ordered,
+ loose_column_name_matching,
+ ) = context.result_column_struct
+ num_ctx_cols = len(result_columns)
+ else:
+ result_columns = (
+ cols_are_ordered
+ ) = (
+ num_ctx_cols
+ ) = loose_column_name_matching = textual_ordered = False
+
+ # merge cursor.description with the column info
+ # present in the compiled structure, if any
+ raw = self._merge_cursor_description(
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ loose_column_name_matching,
+ )
+
+ self._keymap = {}
+
+ # processors in key order for certain per-row
+ # views like __iter__ and slices
+ self._processors = [
+ metadata_entry[MD_PROCESSOR] for metadata_entry in raw
+ ]
+
+ if context.compiled:
+ self._keymap_by_result_column_idx = {
+ metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
+ for metadata_entry in raw
+ }
+
+ # keymap by primary string...
+ by_key = dict(
+ [
+ (metadata_entry[MD_LOOKUP_KEY], metadata_entry)
+ for metadata_entry in raw
+ ]
+ )
+
+ # for compiled SQL constructs, copy additional lookup keys into
+ # the key lookup map, such as Column objects, labels,
+ # column keys and other names
+ if num_ctx_cols:
+
+ # if by-primary-string dictionary smaller (or bigger?!) than
+ # number of columns, assume we have dupes, rewrite
+ # dupe records with "None" for index which results in
+ # ambiguous column exception when accessed.
+ if len(by_key) != num_ctx_cols:
+ # new in 1.4: get the complete set of all possible keys,
+ # strings, objects, whatever, that are dupes across two
+ # different records, first.
+ index_by_key = {}
+ dupes = set()
+ for metadata_entry in raw:
+ for key in (metadata_entry[MD_RENDERED_NAME],) + (
+ metadata_entry[MD_OBJECTS] or ()
+ ):
+ if not self.case_sensitive and isinstance(
+ key, util.string_types
+ ):
+ key = key.lower()
+ idx = metadata_entry[MD_INDEX]
+ # if this key has been associated with more than one
+ # positional index, it's a dupe
+ if index_by_key.setdefault(key, idx) != idx:
+ dupes.add(key)
+
+ # then put everything we have into the keymap excluding only
+ # those keys that are dupes.
+ self._keymap.update(
+ [
+ (obj_elem, metadata_entry)
+ for metadata_entry in raw
+ if metadata_entry[MD_OBJECTS]
+ for obj_elem in metadata_entry[MD_OBJECTS]
+ if obj_elem not in dupes
+ ]
+ )
+
+ # then for the dupe keys, put the "ambiguous column"
+ # record into by_key.
+ by_key.update({key: (None, None, (), key) for key in dupes})
+
+ else:
+ # no dupes - copy secondary elements from compiled
+ # columns into self._keymap
+ self._keymap.update(
+ [
+ (obj_elem, metadata_entry)
+ for metadata_entry in raw
+ if metadata_entry[MD_OBJECTS]
+ for obj_elem in metadata_entry[MD_OBJECTS]
+ ]
+ )
+
+ # update keymap with primary string names taking
+ # precedence
+ self._keymap.update(by_key)
+
+ # update keymap with "translated" names (sqlite-only thing)
+ if not num_ctx_cols and context._translate_colname:
+ self._keymap.update(
+ [
+ (
+ metadata_entry[MD_UNTRANSLATED],
+ self._keymap[metadata_entry[MD_LOOKUP_KEY]],
+ )
+ for metadata_entry in raw
+ if metadata_entry[MD_UNTRANSLATED]
+ ]
+ )
+
+ def _merge_cursor_description(
+ self,
+ context,
+ cursor_description,
+ result_columns,
+ num_ctx_cols,
+ cols_are_ordered,
+ textual_ordered,
+ loose_column_name_matching,
+ ):
+ """Merge a cursor.description with compiled result column information.
+
+ There are at least four separate strategies used here, selected
+ depending on the type of SQL construct used to start with.
+
+ The most common case is that of the compiled SQL expression construct,
+ which generated the column names present in the raw SQL string and
+ which has the identical number of columns as were reported by
+ cursor.description. In this case, we assume a 1-1 positional mapping
+ between the entries in cursor.description and the compiled object.
+ This is also the most performant case as we disregard extracting /
+ decoding the column names present in cursor.description since we
+ already have the desired name we generated in the compiled SQL
+ construct.
+
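+ For illustration (the column names here are arbitrary), in this case a
+ DBAPI cursor.description, which per PEP 249 is a sequence of 7-item
+ sequences beginning with the column name, such as::
+
+ [('id', <type_code>, None, None, None, None, None),
+ ('name', <type_code>, None, None, None, None, None)]
+
+ lines up index-for-index with the two entries in the compiled
+ construct's result columns, so the compiler-generated names are used
+ directly and the description names need not be decoded.
+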
+ The next common case is that of the completely raw string SQL,
+ such as passed to connection.execute(). In this case we have no
+ compiled construct to work with, so we extract and decode the
+ names from cursor.description and index those as the primary
+ result row target keys.
+
+ The remaining fairly common case is that of the textual SQL
+ that includes at least partial column information; this is when
+ we use a :class:`_expression.TextualSelect` construct.
+ This construct may have
+ unordered or ordered column information. In the ordered case, we
+ merge the cursor.description and the compiled construct's information
+ positionally, and warn if there are additional description names
+ present, however we still decode the names in cursor.description
+ as we don't have a guarantee that the names in the columns match
+ on these. In the unordered case, we match names in cursor.description
+ to that of the compiled construct based on name matching.
+ In both of these cases, the cursor.description names and the column
+ expression objects and names are indexed as result row target keys.
+
+ The final case is much less common, where we have a compiled
+ non-textual SQL expression construct, but the number of columns
+ in cursor.description doesn't match what's in the compiled
+ construct. We make the guess here that there might be textual
+ column expressions in the compiled construct that themselves include
+ a comma in them causing them to split. We do the same name-matching
+ as with textual non-ordered columns.
+
+ The name-matched system of merging is the same as that used by
+ SQLAlchemy for all cases up through the 0.9 series. Positional
+ matching for compiled SQL expressions was introduced in 1.0 as a
+ major performance feature, and positional matching for textual
+ :class:`_expression.TextualSelect` objects in 1.1.
+ As name matching is no longer
+ a common case, it was acceptable to factor it into smaller generator-
+ oriented methods that are easier to understand, but incur slightly
+ more performance overhead.
+
+ """
+
+ case_sensitive = context.dialect.case_sensitive
+
+ if (
+ num_ctx_cols
+ and cols_are_ordered
+ and not textual_ordered
+ and num_ctx_cols == len(cursor_description)
+ ):
+ self._keys = [elem[0] for elem in result_columns]
+ # pure positional 1-1 case; doesn't need to read
+ # the names from cursor.description
+
+ # this metadata is safe to cache because we are guaranteed
+ # to have the columns in the same order for new executions
+ self._safe_for_cache = True
+ return [
+ (
+ idx,
+ idx,
+ rmap_entry[RM_OBJECTS],
+ rmap_entry[RM_NAME].lower()
+ if not case_sensitive
+ else rmap_entry[RM_NAME],
+ rmap_entry[RM_RENDERED_NAME],
+ context.get_result_processor(
+ rmap_entry[RM_TYPE],
+ rmap_entry[RM_RENDERED_NAME],
+ cursor_description[idx][1],
+ ),
+ None,
+ )
+ for idx, rmap_entry in enumerate(result_columns)
+ ]
+ else:
+
+ # name-based or text-positional cases, where we need
+ # to read cursor.description names
+
+ if textual_ordered:
+ self._safe_for_cache = True
+ # textual positional case
+ raw_iterator = self._merge_textual_cols_by_position(
+ context, cursor_description, result_columns
+ )
+ elif num_ctx_cols:
+ # compiled SQL with a mismatch of description cols
+ # vs. compiled cols, or textual w/ unordered columns
+ # the order of columns can change if the query is
+ # against a "select *", so not safe to cache
+ self._safe_for_cache = False
+ raw_iterator = self._merge_cols_by_name(
+ context,
+ cursor_description,
+ result_columns,
+ loose_column_name_matching,
+ )
+ else:
+ # no compiled SQL, just a raw string, order of columns
+ # can change for "select *"
+ self._safe_for_cache = False
+ raw_iterator = self._merge_cols_by_none(
+ context, cursor_description
+ )
+
+ return [
+ (
+ idx,
+ ridx,
+ obj,
+ cursor_colname,
+ cursor_colname,
+ context.get_result_processor(
+ mapped_type, cursor_colname, coltype
+ ),
+ untranslated,
+ )
+ for (
+ idx,
+ ridx,
+ cursor_colname,
+ mapped_type,
+ coltype,
+ obj,
+ untranslated,
+ ) in raw_iterator
+ ]
+
+ def _colnames_from_description(self, context, cursor_description):
+ """Extract column names and data types from a cursor.description.
+
+ Applies unicode decoding, column translation, "normalization",
+ and case sensitivity rules to the names based on the dialect.
+
+ """
+
+ dialect = context.dialect
+ case_sensitive = dialect.case_sensitive
+ translate_colname = context._translate_colname
+ description_decoder = (
+ dialect._description_decoder
+ if dialect.description_encoding
+ else None
+ )
+ normalize_name = (
+ dialect.normalize_name if dialect.requires_name_normalize else None
+ )
+ untranslated = None
+
+ self._keys = []
+
+ for idx, rec in enumerate(cursor_description):
+ colname = rec[0]
+ coltype = rec[1]
+
+ if description_decoder:
+ colname = description_decoder(colname)
+
+ if translate_colname:
+ colname, untranslated = translate_colname(colname)
+
+ if normalize_name:
+ colname = normalize_name(colname)
+
+ self._keys.append(colname)
+ if not case_sensitive:
+ colname = colname.lower()
+
+ yield idx, colname, untranslated, coltype
+
+ def _merge_textual_cols_by_position(
+ self, context, cursor_description, result_columns
+ ):
+ num_ctx_cols = len(result_columns) if result_columns else None
+
+ if num_ctx_cols > len(cursor_description):
+ util.warn(
+ "Number of columns in textual SQL (%d) is "
+ "smaller than number of columns requested (%d)"
+ % (len(cursor_description), num_ctx_cols)
+ )
+ seen = set()
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
+ if idx < num_ctx_cols:
+ ctx_rec = result_columns[idx]
+ obj = ctx_rec[RM_OBJECTS]
+ ridx = idx
+ mapped_type = ctx_rec[RM_TYPE]
+ if obj[0] in seen:
+ raise exc.InvalidRequestError(
+ "Duplicate column expression requested "
+ "in textual SQL: %r" % obj[0]
+ )
+ seen.add(obj[0])
+ else:
+ mapped_type = sqltypes.NULLTYPE
+ obj = None
+ ridx = None
+ yield idx, ridx, colname, mapped_type, coltype, obj, untranslated
+
+ def _merge_cols_by_name(
+ self,
+ context,
+ cursor_description,
+ result_columns,
+ loose_column_name_matching,
+ ):
+ dialect = context.dialect
+ case_sensitive = dialect.case_sensitive
+ match_map = self._create_description_match_map(
+ result_columns, case_sensitive, loose_column_name_matching
+ )
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
+ try:
+ ctx_rec = match_map[colname]
+ except KeyError:
+ mapped_type = sqltypes.NULLTYPE
+ obj = None
+ result_columns_idx = None
+ else:
+ obj = ctx_rec[1]
+ mapped_type = ctx_rec[2]
+ result_columns_idx = ctx_rec[3]
+ yield (
+ idx,
+ result_columns_idx,
+ colname,
+ mapped_type,
+ coltype,
+ obj,
+ untranslated,
+ )
+
+ @classmethod
+ def _create_description_match_map(
+ cls,
+ result_columns,
+ case_sensitive=True,
+ loose_column_name_matching=False,
+ ):
+ """when matching cursor.description to a set of names that are present
+ in a Compiled object, as is the case with TextualSelect, get all the
+ names we expect might match those in cursor.description.
+ """
+
+ d = {}
+ for ridx, elem in enumerate(result_columns):
+ key = elem[RM_RENDERED_NAME]
+
+ if not case_sensitive:
+ key = key.lower()
+ if key in d:
+ # conflicting keyname - just add the column-linked objects
+ # to the existing record. if there is a duplicate column
+ # name in the cursor description, this will allow all of those
+ # objects to raise an ambiguous column error
+ e_name, e_obj, e_type, e_ridx = d[key]
+ d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type, ridx
+ else:
+ d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx)
+
+ if loose_column_name_matching:
+ # when using a textual statement with an unordered set
+ # of columns that line up, we are expecting the user
+ # to be using label names in the SQL that match to the column
+ # expressions. Enable more liberal matching for this case;
+ # duplicate keys that are ambiguous will be fixed later.
+ for r_key in elem[RM_OBJECTS]:
+ d.setdefault(
+ r_key,
+ (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx),
+ )
+
+ return d
+
+ def _merge_cols_by_none(self, context, cursor_description):
+ for (
+ idx,
+ colname,
+ untranslated,
+ coltype,
+ ) in self._colnames_from_description(context, cursor_description):
+ yield (
+ idx,
+ None,
+ colname,
+ sqltypes.NULLTYPE,
+ coltype,
+ None,
+ untranslated,
+ )
+
+ def _key_fallback(self, key, err, raiseerr=True):
+ if raiseerr:
+ util.raise_(
+ exc.NoSuchColumnError(
+ "Could not locate column in row for column '%s'"
+ % util.string_or_unprintable(key)
+ ),
+ replace_context=err,
+ )
+ else:
+ return None
+
+ def _raise_for_ambiguous_column_name(self, rec):
+ raise exc.InvalidRequestError(
+ "Ambiguous column name '%s' in "
+ "result set column descriptions" % rec[MD_LOOKUP_KEY]
+ )
+
+ def _index_for_key(self, key, raiseerr=True):
+ # TODO: can consider pre-loading ints and negative ints
+ # into _keymap - also no coverage here
+ if isinstance(key, int):
+ key = self._keys[key]
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._key_fallback(key, ke, raiseerr)
+ if rec is None:
+ return None
+
+ index = rec[0]
+
+ if index is None:
+ self._raise_for_ambiguous_column_name(rec)
+ return index
+
+ def _indexes_for_keys(self, keys):
+
+ try:
+ return [self._keymap[key][0] for key in keys]
+ except KeyError as ke:
+ # ensure it raises
+ CursorResultMetaData._key_fallback(self, ke.args[0], ke)
+
+ def _metadata_for_keys(self, keys):
+ for key in keys:
+ if int in key.__class__.__mro__:
+ key = self._keys[key]
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ # ensure it raises
+ CursorResultMetaData._key_fallback(self, ke.args[0], ke)
+
+ index = rec[0]
+
+ if index is None:
+ self._raise_for_ambiguous_column_name(rec)
+
+ yield rec
+
+ def __getstate__(self):
+ return {
+ "_keymap": {
+ key: (rec[MD_INDEX], rec[MD_RESULT_MAP_INDEX], _UNPICKLED, key)
+ for key, rec in self._keymap.items()
+ if isinstance(key, util.string_types + util.int_types)
+ },
+ "_keys": self._keys,
+ "case_sensitive": self.case_sensitive,
+ "_translated_indexes": self._translated_indexes,
+ "_tuplefilter": self._tuplefilter,
+ }
+
+ def __setstate__(self, state):
+ self._processors = [None for _ in range(len(state["_keys"]))]
+ self._keymap = state["_keymap"]
+
+ self._keymap_by_result_column_idx = {
+ rec[MD_RESULT_MAP_INDEX]: rec for rec in self._keymap.values()
+ }
+ self._keys = state["_keys"]
+ self.case_sensitive = state["case_sensitive"]
+
+ if state["_translated_indexes"]:
+ self._translated_indexes = state["_translated_indexes"]
+ self._tuplefilter = tuplegetter(*self._translated_indexes)
+ else:
+ self._translated_indexes = self._tuplefilter = None
+
+
+class LegacyCursorResultMetaData(CursorResultMetaData):
+ __slots__ = ()
+
+ def _contains(self, value, row):
+ key = value
+ if key in self._keymap:
+ util.warn_deprecated_20(
+ "Using the 'in' operator to test for string or column "
+ "keys, or integer indexes, in a :class:`.Row` object is "
+ "deprecated and will "
+ "be removed in a future release. "
+ "Use the `Row._fields` or `Row._mapping` attribute, i.e. "
+ "'key in row._fields'",
+ )
+ return True
+ else:
+ return self._key_fallback(key, None, False) is not None
+
+ def _key_fallback(self, key, err, raiseerr=True):
+ map_ = self._keymap
+ result = None
+
+ if isinstance(key, util.string_types):
+ result = map_.get(key if self.case_sensitive else key.lower())
+ elif isinstance(key, expression.ColumnElement):
+ if (
+ key._tq_label
+ and (
+ key._tq_label
+ if self.case_sensitive
+ else key._tq_label.lower()
+ )
+ in map_
+ ):
+ result = map_[
+ key._tq_label
+ if self.case_sensitive
+ else key._tq_label.lower()
+ ]
+ elif (
+ hasattr(key, "name")
+ and (key.name if self.case_sensitive else key.name.lower())
+ in map_
+ ):
+ # match is only on name.
+ result = map_[
+ key.name if self.case_sensitive else key.name.lower()
+ ]
+
+ # search extra hard to make sure this
+ # isn't a column/label name overlap.
+ # this check isn't currently available if the row
+ # was unpickled.
+ if result is not None and result[MD_OBJECTS] not in (
+ None,
+ _UNPICKLED,
+ ):
+ for obj in result[MD_OBJECTS]:
+ if key._compare_name_for_result(obj):
+ break
+ else:
+ result = None
+ if result is not None:
+ if result[MD_OBJECTS] is _UNPICKLED:
+ util.warn_deprecated(
+ "Retrieving row values using Column objects from a "
+ "row that was unpickled is deprecated; adequate "
+ "state cannot be pickled for this to be efficient. "
+ "This usage will raise KeyError in a future release.",
+ version="1.4",
+ )
+ else:
+ util.warn_deprecated(
+ "Retrieving row values using Column objects with only "
+ "matching names as keys is deprecated, and will raise "
+ "KeyError in a future release; only Column "
+ "objects that are explicitly part of the statement "
+ "object should be used.",
+ version="1.4",
+ )
+ if result is None:
+ if raiseerr:
+ util.raise_(
+ exc.NoSuchColumnError(
+ "Could not locate column in row for column '%s'"
+ % util.string_or_unprintable(key)
+ ),
+ replace_context=err,
+ )
+ else:
+ return None
+ else:
+ map_[key] = result
+ return result
+
+ def _warn_for_nonint(self, key):
+ util.warn_deprecated_20(
+ "Using non-integer/slice indices on Row is deprecated and will "
+ "be removed in version 2.0; please use row._mapping[<key>], or "
+ "the mappings() accessor on the Result object.",
+ stacklevel=4,
+ )
+
+ def _has_key(self, key):
+ if key in self._keymap:
+ return True
+ else:
+ return self._key_fallback(key, None, False) is not None
+
+
+class ResultFetchStrategy(object):
+ """Define a fetching strategy for a result object.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ()
+
+ alternate_cursor_description = None
+
+ def soft_close(self, result, dbapi_cursor):
+ raise NotImplementedError()
+
+ def hard_close(self, result, dbapi_cursor):
+ raise NotImplementedError()
+
+ def yield_per(self, result, dbapi_cursor, num):
+ return
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ raise NotImplementedError()
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ raise NotImplementedError()
+
+ def fetchall(self, result, dbapi_cursor):
+ raise NotImplementedError()
+
+ def handle_exception(self, result, dbapi_cursor, err):
+ raise err
+
+
+class NoCursorFetchStrategy(ResultFetchStrategy):
+ """Cursor strategy for a result that has no open cursor.
+
+ There are two varieties of this strategy, one for DQL and one for
+ DML (and also DDL), each of which represents a result that had a cursor
+ but no longer has one.
+
+ """
+
+ __slots__ = ()
+
+ def soft_close(self, result, dbapi_cursor):
+ pass
+
+ def hard_close(self, result, dbapi_cursor):
+ pass
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ return self._non_result(result, None)
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ return self._non_result(result, [])
+
+ def fetchall(self, result, dbapi_cursor):
+ return self._non_result(result, [])
+
+ def _non_result(self, result, default, err=None):
+ raise NotImplementedError()
+
+
+class NoCursorDQLFetchStrategy(NoCursorFetchStrategy):
+ """Cursor strategy for a DQL result that has no open cursor.
+
+ This is a result set that can return rows, i.e. for a SELECT, or for an
+ INSERT, UPDATE, DELETE that includes RETURNING. However it is in the state
+ where the cursor is closed and no rows remain available. The owning result
+ object may or may not be "hard closed", which determines if the fetch
+ methods send empty results or raise for closed result.
+
+ """
+
+ __slots__ = ()
+
+ def _non_result(self, result, default, err=None):
+ if result.closed:
+ util.raise_(
+ exc.ResourceClosedError("This result object is closed."),
+ replace_context=err,
+ )
+ else:
+ return default
+
+
+_NO_CURSOR_DQL = NoCursorDQLFetchStrategy()
+
+
+class NoCursorDMLFetchStrategy(NoCursorFetchStrategy):
+ """Cursor strategy for a DML result that has no open cursor.
+
+ This is a result set that does not return rows, i.e. for an INSERT,
+ UPDATE, DELETE that does not include RETURNING.
+
+ """
+
+ __slots__ = ()
+
+ def _non_result(self, result, default, err=None):
+ # we only expect to have a _NoResultMetaData() here right now.
+ assert not result._metadata.returns_rows
+ result._metadata._we_dont_return_rows(err)
+
+
+_NO_CURSOR_DML = NoCursorDMLFetchStrategy()
+
+
+class CursorFetchStrategy(ResultFetchStrategy):
+ """Call fetch methods from a DBAPI cursor.
+
+ Alternate versions of this class may instead buffer the rows from
+ cursors or not use cursors at all.
+
+ """
+
+ __slots__ = ()
+
+ def soft_close(self, result, dbapi_cursor):
+ result.cursor_strategy = _NO_CURSOR_DQL
+
+ def hard_close(self, result, dbapi_cursor):
+ result.cursor_strategy = _NO_CURSOR_DQL
+
+ def handle_exception(self, result, dbapi_cursor, err):
+ result.connection._handle_dbapi_exception(
+ err, None, None, dbapi_cursor, result.context
+ )
+
+ def yield_per(self, result, dbapi_cursor, num):
+ result.cursor_strategy = BufferedRowCursorFetchStrategy(
+ dbapi_cursor,
+ {"max_row_buffer": num},
+ initial_buffer=collections.deque(),
+ growth_factor=0,
+ )
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ try:
+ row = dbapi_cursor.fetchone()
+ if row is None:
+ result._soft_close(hard=hard_close)
+ return row
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ try:
+ if size is None:
+ l = dbapi_cursor.fetchmany()
+ else:
+ l = dbapi_cursor.fetchmany(size)
+
+ if not l:
+ result._soft_close()
+ return l
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+ def fetchall(self, result, dbapi_cursor):
+ try:
+ rows = dbapi_cursor.fetchall()
+ result._soft_close()
+ return rows
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+
+_DEFAULT_FETCH = CursorFetchStrategy()
+
+
+class BufferedRowCursorFetchStrategy(CursorFetchStrategy):
+ """A cursor fetch strategy with row buffering behavior.
+
+ This strategy buffers the contents of a selection of rows
+ before ``fetchone()`` is called. This is to allow the results of
+ ``cursor.description`` to be available immediately, when
+ interfacing with a DB-API that requires rows to be consumed before
+ this information is available (currently psycopg2, when used with
+ server-side cursors).
+
+ The pre-fetching behavior fetches only one row initially, and then
+ grows its buffer size by a fixed amount with each successive need
+ for additional rows up to the ``max_row_buffer`` size, which defaults
+ to 1000::
+
+ with psycopg2_engine.connect() as conn:
+
+ result = conn.execution_options(
+ stream_results=True, max_row_buffer=50
+ ).execute(text("select * from table"))
+
+ .. versionchanged:: 1.4 ``max_row_buffer`` may now exceed 1000 rows.
+
+ .. seealso::
+
+ :ref:`psycopg2_execution_options`
+ """
+
+ __slots__ = ("_max_row_buffer", "_rowbuffer", "_bufsize", "_growth_factor")
+
+ def __init__(
+ self,
+ dbapi_cursor,
+ execution_options,
+ growth_factor=5,
+ initial_buffer=None,
+ ):
+ self._max_row_buffer = execution_options.get("max_row_buffer", 1000)
+
+ if initial_buffer is not None:
+ self._rowbuffer = initial_buffer
+ else:
+ self._rowbuffer = collections.deque(dbapi_cursor.fetchmany(1))
+ self._growth_factor = growth_factor
+
+ if growth_factor:
+ self._bufsize = min(self._max_row_buffer, self._growth_factor)
+ else:
+ self._bufsize = self._max_row_buffer
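+
+ # illustrative note: with the defaults (growth_factor=5,
+ # max_row_buffer=1000), the refill sizes requested by _buffer_rows()
+ # progress 5, 25, 125, 625 and are then capped at 1000, following the
+ # initial single-row fetch performed above.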
+
+ @classmethod
+ def create(cls, result):
+ return BufferedRowCursorFetchStrategy(
+ result.cursor,
+ result.context.execution_options,
+ )
+
+ def _buffer_rows(self, result, dbapi_cursor):
+ """this is currently used only by fetchone()."""
+
+ size = self._bufsize
+ try:
+ if size < 1:
+ new_rows = dbapi_cursor.fetchall()
+ else:
+ new_rows = dbapi_cursor.fetchmany(size)
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+ if not new_rows:
+ return
+ self._rowbuffer = collections.deque(new_rows)
+ if self._growth_factor and size < self._max_row_buffer:
+ self._bufsize = min(
+ self._max_row_buffer, size * self._growth_factor
+ )
+
+ def yield_per(self, result, dbapi_cursor, num):
+ self._growth_factor = 0
+ self._max_row_buffer = self._bufsize = num
+
+ def soft_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(BufferedRowCursorFetchStrategy, self).soft_close(
+ result, dbapi_cursor
+ )
+
+ def hard_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(BufferedRowCursorFetchStrategy, self).hard_close(
+ result, dbapi_cursor
+ )
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ if not self._rowbuffer:
+ self._buffer_rows(result, dbapi_cursor)
+ if not self._rowbuffer:
+ try:
+ result._soft_close(hard=hard_close)
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+ return None
+ return self._rowbuffer.popleft()
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ if size is None:
+ return self.fetchall(result, dbapi_cursor)
+
+ buf = list(self._rowbuffer)
+ lb = len(buf)
+ if size > lb:
+ try:
+ new = dbapi_cursor.fetchmany(size - lb)
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+ else:
+ if not new:
+ result._soft_close()
+ else:
+ buf.extend(new)
+
+ result = buf[0:size]
+ self._rowbuffer = collections.deque(buf[size:])
+ return result
+
+ def fetchall(self, result, dbapi_cursor):
+ try:
+ ret = list(self._rowbuffer) + list(dbapi_cursor.fetchall())
+ self._rowbuffer.clear()
+ result._soft_close()
+ return ret
+ except BaseException as e:
+ self.handle_exception(result, dbapi_cursor, e)
+
+
+class FullyBufferedCursorFetchStrategy(CursorFetchStrategy):
+ """A cursor strategy that buffers rows fully upon creation.
+
+ Used for operations where a result is to be delivered
+ after the database conversation can not be continued,
+ such as MSSQL INSERT...OUTPUT after an autocommit.
+
+ """
+
+ __slots__ = ("_rowbuffer", "alternate_cursor_description")
+
+ def __init__(
+ self, dbapi_cursor, alternate_description=None, initial_buffer=None
+ ):
+ self.alternate_cursor_description = alternate_description
+ if initial_buffer is not None:
+ self._rowbuffer = collections.deque(initial_buffer)
+ else:
+ self._rowbuffer = collections.deque(dbapi_cursor.fetchall())
+
+ def yield_per(self, result, dbapi_cursor, num):
+ pass
+
+ def soft_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(FullyBufferedCursorFetchStrategy, self).soft_close(
+ result, dbapi_cursor
+ )
+
+ def hard_close(self, result, dbapi_cursor):
+ self._rowbuffer.clear()
+ super(FullyBufferedCursorFetchStrategy, self).hard_close(
+ result, dbapi_cursor
+ )
+
+ def fetchone(self, result, dbapi_cursor, hard_close=False):
+ if self._rowbuffer:
+ return self._rowbuffer.popleft()
+ else:
+ result._soft_close(hard=hard_close)
+ return None
+
+ def fetchmany(self, result, dbapi_cursor, size=None):
+ if size is None:
+ return self.fetchall(result, dbapi_cursor)
+
+ buf = list(self._rowbuffer)
+ rows = buf[0:size]
+ self._rowbuffer = collections.deque(buf[size:])
+ if not rows:
+ result._soft_close()
+ return rows
+
+ def fetchall(self, result, dbapi_cursor):
+ ret = self._rowbuffer
+ self._rowbuffer = collections.deque()
+ result._soft_close()
+ return ret
+
+
+class _NoResultMetaData(ResultMetaData):
+ __slots__ = ()
+
+ returns_rows = False
+
+ def _we_dont_return_rows(self, err=None):
+ util.raise_(
+ exc.ResourceClosedError(
+ "This result object does not return rows. "
+ "It has been closed automatically."
+ ),
+ replace_context=err,
+ )
+
+ def _index_for_key(self, keys, raiseerr):
+ self._we_dont_return_rows()
+
+ def _metadata_for_keys(self, key):
+ self._we_dont_return_rows()
+
+ def _reduce(self, keys):
+ self._we_dont_return_rows()
+
+ @property
+ def _keymap(self):
+ self._we_dont_return_rows()
+
+ @property
+ def keys(self):
+ self._we_dont_return_rows()
+
+
+class _LegacyNoResultMetaData(_NoResultMetaData):
+ @property
+ def keys(self):
+ util.warn_deprecated_20(
+ "Calling the .keys() method on a result set that does not return "
+ "rows is deprecated and will raise ResourceClosedError in "
+ "SQLAlchemy 2.0.",
+ )
+ return []
+
+
+_NO_RESULT_METADATA = _NoResultMetaData()
+_LEGACY_NO_RESULT_METADATA = _LegacyNoResultMetaData()
+
+
+class BaseCursorResult(object):
+ """Base class for database result objects."""
+
+ out_parameters = None
+ _metadata = None
+ _soft_closed = False
+ closed = False
+
+ def __init__(self, context, cursor_strategy, cursor_description):
+ self.context = context
+ self.dialect = context.dialect
+ self.cursor = context.cursor
+ self.cursor_strategy = cursor_strategy
+ self.connection = context.root_connection
+ self._echo = echo = (
+ self.connection._echo and context.engine._should_log_debug()
+ )
+
+ if cursor_description is not None:
+ # inline of Result._row_getter(), set up an initial row
+ # getter assuming no transformations will be called as this
+ # is the most common case
+
+ if echo:
+ log = self.context.connection._log_debug
+
+ def log_row(row):
+ log("Row %r", sql_util._repr_row(row))
+ return row
+
+ self._row_logging_fn = log_row
+ else:
+ log_row = None
+
+ metadata = self._init_metadata(context, cursor_description)
+
+ keymap = metadata._keymap
+ processors = metadata._processors
+ process_row = self._process_row
+ key_style = process_row._default_key_style
+ _make_row = functools.partial(
+ process_row, metadata, processors, keymap, key_style
+ )
+ if log_row:
+
+ def make_row(row):
+ made_row = _make_row(row)
+ log_row(made_row)
+ return made_row
+
+ else:
+ make_row = _make_row
+ self._set_memoized_attribute("_row_getter", make_row)
+
+ else:
+ self._metadata = self._no_result_metadata
+
+ def _init_metadata(self, context, cursor_description):
+
+ if context.compiled:
+ if context.compiled._cached_metadata:
+ metadata = self.context.compiled._cached_metadata
+ else:
+ metadata = self._cursor_metadata(self, cursor_description)
+ if metadata._safe_for_cache:
+ context.compiled._cached_metadata = metadata
+
+ # result rewrite/ adapt step. this is to suit the case
+ # when we are invoked against a cached Compiled object, we want
+ # to rewrite the ResultMetaData to reflect the Column objects
+ # that are in our current SQL statement object, not the one
+ # that is associated with the cached Compiled object.
+ # the Compiled object may also tell us to not
+ # actually do this step; this is to support the ORM where
+ # it is to produce a new Result object in any case, and will
+ # be using the cached Column objects against this database result
+ # so we don't want to rewrite them.
+ #
+ # Basically this step suits the use case where the end user
+ # is using Core SQL expressions and is accessing columns in the
+ # result row using row._mapping[table.c.column].
+ compiled = context.compiled
+ if (
+ compiled
+ and compiled._result_columns
+ and context.cache_hit is context.dialect.CACHE_HIT
+ and not context.execution_options.get(
+ "_result_disable_adapt_to_context", False
+ )
+ and compiled.statement is not context.invoked_statement
+ ):
+ metadata = metadata._adapt_to_context(context)
+
+ self._metadata = metadata
+
+ else:
+ self._metadata = metadata = self._cursor_metadata(
+ self, cursor_description
+ )
+ if self._echo:
+ context.connection._log_debug(
+ "Col %r", tuple(x[0] for x in cursor_description)
+ )
+ return metadata
+
+ def _soft_close(self, hard=False):
+ """Soft close this :class:`_engine.CursorResult`.
+
+ This releases all DBAPI cursor resources, but leaves the
+ CursorResult "open" from a semantic perspective, meaning the
+ fetchXXX() methods will continue to return empty results.
+
+ This method is called automatically when:
+
+ * all result rows are exhausted using the fetchXXX() methods.
+ * cursor.description is None.
+
+ This method is **not public**, but is documented in order to clarify
+ the "autoclose" process used.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_engine.CursorResult.close`
+
+
+ """
+ if (not hard and self._soft_closed) or (hard and self.closed):
+ return
+
+ if hard:
+ self.closed = True
+ self.cursor_strategy.hard_close(self, self.cursor)
+ else:
+ self.cursor_strategy.soft_close(self, self.cursor)
+
+ if not self._soft_closed:
+ cursor = self.cursor
+ self.cursor = None
+ self.connection._safe_close_cursor(cursor)
+ self._soft_closed = True
+
+ @property
+ def inserted_primary_key_rows(self):
+ """Return the value of
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ as a row contained within a list; some dialects may support a
+ multiple row form as well.
+
+ .. note:: As indicated below, in current SQLAlchemy versions this
+ accessor is only useful beyond what's already supplied by
+ :attr:`_engine.CursorResult.inserted_primary_key` when using the
+ :ref:`postgresql_psycopg2` dialect. Future versions hope to
+ generalize this feature to more dialects.
+
+ This accessor is added to support dialects that offer the feature
+ that is currently implemented by the :ref:`psycopg2_executemany_mode`
+ feature, currently **only the psycopg2 dialect**, which provides
+ for many rows to be INSERTed at once while still retaining the
+ behavior of being able to return server-generated primary key values.
+
+ * **When using the psycopg2 dialect, or other dialects that may support
+ "fast executemany" style inserts in upcoming releases** : When
+ invoking an INSERT statement while passing a list of rows as the
+ second argument to :meth:`_engine.Connection.execute`, this accessor
+ will then provide a list of rows, where each row contains the primary
+ key value for each row that was INSERTed.
+
+ * **When using all other dialects / backends that don't yet support
+ this feature**: This accessor is only useful for **single row INSERT
+ statements**, and returns the same information as that of the
+ :attr:`_engine.CursorResult.inserted_primary_key` within a
+ single-element list. When an INSERT statement is executed in
+ conjunction with a list of rows to be INSERTed, the list will contain
+ one row per row inserted in the statement, however it will contain
+ ``None`` for any server-generated values.
+
+ Future releases of SQLAlchemy will further generalize the
+ "fast execution helper" feature of psycopg2 to suit other dialects,
+ thus allowing this accessor to be of more general use.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.CursorResult.inserted_primary_key`
+
+ """
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() " "expression construct."
+ )
+ elif self.context._is_explicit_returning:
+ raise exc.InvalidRequestError(
+ "Can't call inserted_primary_key "
+ "when returning() "
+ "is used."
+ )
+ return self.context.inserted_primary_key_rows
+
+ @property
+ def inserted_primary_key(self):
+ """Return the primary key for the row just inserted.
+
+ The return value is a :class:`_result.Row` object representing
+ a named tuple of primary key values in the order in which the
+ primary key columns are configured in the source
+ :class:`_schema.Table`.
+
+ .. versionchanged:: 1.4.8 - the
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ value is now a named tuple via the :class:`_result.Row` class,
+ rather than a plain tuple.
+
+ This accessor only applies to single row :func:`_expression.insert`
+ constructs which did not explicitly specify
+ :meth:`_expression.Insert.returning`. Support for multirow inserts,
+ while not yet available for most backends, would be accessed using
+ the :attr:`_engine.CursorResult.inserted_primary_key_rows` accessor.
+
+ Note that primary key columns which specify a server_default clause, or
+ otherwise do not qualify as "autoincrement" columns (see the notes at
+ :class:`_schema.Column`), and were generated using the database-side
+ default, will appear in this list as ``None`` unless the backend
+ supports "returning" and the insert statement executed with the
+ "implicit returning" enabled.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() construct.
+
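+ A minimal usage sketch (the table and values are illustrative)::
+
+ result = connection.execute(user_table.insert().values(name="some name"))
+ primary_key = result.inserted_primary_key
+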
+ """
+
+ if self.context.executemany:
+ raise exc.InvalidRequestError(
+ "This statement was an executemany call; if primary key "
+ "returning is supported, please "
+ "use .inserted_primary_key_rows."
+ )
+
+ ikp = self.inserted_primary_key_rows
+ if ikp:
+ return ikp[0]
+ else:
+ return None
+
+ def last_updated_params(self):
+ """Return the collection of updated parameters from this
+ execution.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an update() construct.
+
+ """
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isupdate:
+ raise exc.InvalidRequestError(
+ "Statement is not an update() " "expression construct."
+ )
+ elif self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
+
+ def last_inserted_params(self):
+ """Return the collection of inserted parameters from this
+ execution.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() construct.
+
+ """
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() " "expression construct."
+ )
+ elif self.context.executemany:
+ return self.context.compiled_parameters
+ else:
+ return self.context.compiled_parameters[0]
+
+ @property
+ def returned_defaults_rows(self):
+ """Return a list of rows each containing the values of default
+ columns that were fetched using
+ the :meth:`.ValuesBase.return_defaults` feature.
+
+ The return value is a list of :class:`.Row` objects.
+
+ .. versionadded:: 1.4
+
+ """
+ return self.context.returned_default_rows
+
+ @property
+ def returned_defaults(self):
+ """Return the values of default columns that were fetched using
+ the :meth:`.ValuesBase.return_defaults` feature.
+
+ The value is an instance of :class:`.Row`, or ``None``
+ if :meth:`.ValuesBase.return_defaults` was not used or if the
+ backend does not support RETURNING.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`.ValuesBase.return_defaults`
+
+ """
+
+ if self.context.executemany:
+ raise exc.InvalidRequestError(
+ "This statement was an executemany call; if return defaults "
+ "is supported, please use .returned_defaults_rows."
+ )
+
+ rows = self.context.returned_default_rows
+ if rows:
+ return rows[0]
+ else:
+ return None
+
+ def lastrow_has_defaults(self):
+ """Return ``lastrow_has_defaults()`` from the underlying
+ :class:`.ExecutionContext`.
+
+ See :class:`.ExecutionContext` for details.
+
+ """
+
+ return self.context.lastrow_has_defaults()
+
+ def postfetch_cols(self):
+ """Return ``postfetch_cols()`` from the underlying
+ :class:`.ExecutionContext`.
+
+ See :class:`.ExecutionContext` for details.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() or update() construct.
+
+ """
+
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert and not self.context.isupdate:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() or update() "
+ "expression construct."
+ )
+ return self.context.postfetch_cols
+
+ def prefetch_cols(self):
+ """Return ``prefetch_cols()`` from the underlying
+ :class:`.ExecutionContext`.
+
+ See :class:`.ExecutionContext` for details.
+
+ Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed
+ statement is not a compiled expression construct
+ or is not an insert() or update() construct.
+
+ """
+
+ if not self.context.compiled:
+ raise exc.InvalidRequestError(
+ "Statement is not a compiled " "expression construct."
+ )
+ elif not self.context.isinsert and not self.context.isupdate:
+ raise exc.InvalidRequestError(
+ "Statement is not an insert() or update() "
+ "expression construct."
+ )
+ return self.context.prefetch_cols
+
+ def supports_sane_rowcount(self):
+ """Return ``supports_sane_rowcount`` from the dialect.
+
+ See :attr:`_engine.CursorResult.rowcount` for background.
+
+ """
+
+ return self.dialect.supports_sane_rowcount
+
+ def supports_sane_multi_rowcount(self):
+ """Return ``supports_sane_multi_rowcount`` from the dialect.
+
+ See :attr:`_engine.CursorResult.rowcount` for background.
+
+ """
+
+ return self.dialect.supports_sane_multi_rowcount
+
+ @util.memoized_property
+ def rowcount(self):
+ """Return the 'rowcount' for this result.
+
+ The 'rowcount' reports the number of rows *matched*
+ by the WHERE criterion of an UPDATE or DELETE statement.
+
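+ A minimal usage sketch (the table and criterion are illustrative)::
+
+ result = connection.execute(
+ user_table.update().where(user_table.c.id == 5).values(flag=True)
+ )
+ rows_matched = result.rowcount
+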
+ .. note::
+
+ Notes regarding :attr:`_engine.CursorResult.rowcount`:
+
+
+ * This attribute returns the number of rows *matched*,
+ which is not necessarily the same as the number of rows
+ that were actually *modified* - an UPDATE statement, for example,
+ may have no net change on a given row if the SET values
+ given are the same as those present in the row already.
+ Such a row would be matched but not modified.
+ On backends that feature both styles, such as MySQL,
+ rowcount is configured by default to return the match
+ count in all cases.
+
+ * :attr:`_engine.CursorResult.rowcount`
+ is *only* useful in conjunction
+ with an UPDATE or DELETE statement. Contrary to what the Python
+ DBAPI says, it does *not* return the
+ number of rows available from the results of a SELECT statement
+ as DBAPIs cannot support this functionality when rows are
+ unbuffered.
+
+ * :attr:`_engine.CursorResult.rowcount`
+ may not be fully implemented by
+ all dialects. In particular, most DBAPIs do not support an
+ aggregate rowcount result from an executemany call.
+ The :meth:`_engine.CursorResult.supports_sane_rowcount` and
+ :meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods
+ will report from the dialect if each usage is known to be
+ supported.
+
+ * Statements that use RETURNING may not return a correct
+ rowcount.
+
+ .. seealso::
+
+ :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial`
+
+ """ # noqa: E501
+
+ try:
+ return self.context.rowcount
+ except BaseException as e:
+ self.cursor_strategy.handle_exception(self, self.cursor, e)
+
+ @property
+ def lastrowid(self):
+ """Return the 'lastrowid' accessor on the DBAPI cursor.
+
+ This is a DBAPI specific method and is only functional
+ for those backends which support it, for statements
+ where it is appropriate. Its behavior is not
+ consistent across backends.
+
+ Usage of this method is normally unnecessary when
+ using insert() expression constructs; the
+ :attr:`~CursorResult.inserted_primary_key` attribute provides a
+ tuple of primary key values for a newly inserted row,
+ regardless of database backend.
+
+ """
+ try:
+ return self.context.get_lastrowid()
+ except BaseException as e:
+ self.cursor_strategy.handle_exception(self, self.cursor, e)
+
+ @property
+ def returns_rows(self):
+ """True if this :class:`_engine.CursorResult` returns zero or more
+ rows.
+
+ I.e. if it is legal to call the methods
+ :meth:`_engine.CursorResult.fetchone`,
+ :meth:`_engine.CursorResult.fetchmany`, and
+ :meth:`_engine.CursorResult.fetchall`.
+
+ Overall, the value of :attr:`_engine.CursorResult.returns_rows` should
+ always be synonymous with whether or not the DBAPI cursor had a
+ ``.description`` attribute, indicating the presence of result columns,
+ noting that a cursor that returns zero rows still has a
+ ``.description`` if a row-returning statement was emitted.
+
+ This attribute should be True for all results that are against
+ SELECT statements, as well as for DML statements INSERT/UPDATE/DELETE
+ that use RETURNING. For INSERT/UPDATE/DELETE statements that were
+ not using RETURNING, the value will usually be False, however
+ there are some dialect-specific exceptions to this, such as when
+ using the MSSQL / pyodbc dialect a SELECT is emitted inline in
+ order to retrieve an inserted primary key value.
+
+
+ """
+ return self._metadata.returns_rows
+
+ @property
+ def is_insert(self):
+ """True if this :class:`_engine.CursorResult` is the result
+ of executing an expression language compiled
+ :func:`_expression.insert` construct.
+
+ When True, this implies that the
+ :attr:`inserted_primary_key` attribute is accessible,
+ assuming the statement did not include
+ a user defined "returning" construct.
+
+ """
+ return self.context.isinsert
+
+
+class CursorResult(BaseCursorResult, Result):
+ """A Result that is representing state from a DBAPI cursor.
+
+ .. versionchanged:: 1.4 The :class:`.CursorResult` and
+ :class:`.LegacyCursorResult`
+ classes replace the previous :class:`.ResultProxy` interface.
+ These classes are based on the :class:`.Result` calling API
+ which provides an updated usage model and calling facade for
+ SQLAlchemy Core and SQLAlchemy ORM.
+
+ Returns database rows via the :class:`.Row` class, which provides
+ additional API features and behaviors on top of the raw data returned by
+ the DBAPI. Through the use of filters such as the :meth:`.Result.scalars`
+ method, other kinds of objects may also be returned.
+
+ Within the scope of the 1.x series of SQLAlchemy, Core SQL results in
+ version 1.4 return an instance of :class:`._engine.LegacyCursorResult`
+ which takes the place of the ``ResultProxy`` class used for the 1.3 series
+ and previously. This object returns rows as :class:`.LegacyRow` objects,
+ which maintains Python mapping (i.e. dictionary) like behaviors upon the
+ object itself. Going forward, the :attr:`.Row._mapping` attribute should
+ be used for dictionary behaviors.
+
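+ A minimal usage sketch (the statement and column names are
+ illustrative; ``engine`` is assumed to be an :class:`_engine.Engine`)::
+
+ conn = engine.connect()
+ result = conn.execute(text("SELECT id, name FROM some_table"))
+ names = [row._mapping["name"] for row in result]
+ conn.close()
+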
+ .. seealso::
+
+ :ref:`coretutorial_selecting` - introductory material for accessing
+ :class:`_engine.CursorResult` and :class:`.Row` objects.
+
+ """
+
+ _cursor_metadata = CursorResultMetaData
+ _cursor_strategy_cls = CursorFetchStrategy
+ _no_result_metadata = _NO_RESULT_METADATA
+ _is_cursor = True
+
+ def _fetchiter_impl(self):
+ fetchone = self.cursor_strategy.fetchone
+
+ while True:
+ row = fetchone(self, self.cursor)
+ if row is None:
+ break
+ yield row
+
+ def _fetchone_impl(self, hard_close=False):
+ return self.cursor_strategy.fetchone(self, self.cursor, hard_close)
+
+ def _fetchall_impl(self):
+ return self.cursor_strategy.fetchall(self, self.cursor)
+
+ def _fetchmany_impl(self, size=None):
+ return self.cursor_strategy.fetchmany(self, self.cursor, size)
+
+ def _raw_row_iterator(self):
+ return self._fetchiter_impl()
+
+ def merge(self, *others):
+ merged_result = super(CursorResult, self).merge(*others)
+ setup_rowcounts = not self._metadata.returns_rows
+ if setup_rowcounts:
+ merged_result.rowcount = sum(
+ result.rowcount for result in (self,) + others
+ )
+ return merged_result
+
+ def close(self):
+ """Close this :class:`_engine.CursorResult`.
+
+ This closes out the underlying DBAPI cursor corresponding to the
+ statement execution, if one is still present. Note that the DBAPI
+ cursor is automatically released when the :class:`_engine.CursorResult`
+ exhausts all available rows. :meth:`_engine.CursorResult.close` is
+ generally an optional method except in the case when discarding a
+ :class:`_engine.CursorResult` that still has additional rows pending
+ for fetch.
+
+ After this method is called, it is no longer valid to call upon
+ the fetch methods, which will raise a :class:`.ResourceClosedError`
+ on subsequent use.
+
+ .. seealso::
+
+ :ref:`connections_toplevel`
+
+ """
+ self._soft_close(hard=True)
+
+ @_generative
+ def yield_per(self, num):
+ self._yield_per = num
+ self.cursor_strategy.yield_per(self, self.cursor, num)
+
+
+class LegacyCursorResult(CursorResult):
+ """Legacy version of :class:`.CursorResult`.
+
+ This class includes "connection autoclose" behavior for use with
+ "connectionless" execution, as well as delivers rows using the
+ :class:`.LegacyRow` row implementation.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _autoclose_connection = False
+ _process_row = LegacyRow
+ _cursor_metadata = LegacyCursorResultMetaData
+ _cursor_strategy_cls = CursorFetchStrategy
+
+ _no_result_metadata = _LEGACY_NO_RESULT_METADATA
+
+ def close(self):
+ """Close this :class:`_engine.LegacyCursorResult`.
+
+ This method has the same behavior as that of
+ :meth:`._engine.CursorResult`, but it also may close
+ the underlying :class:`.Connection` for the case of "connectionless"
+ execution.
+
+ .. deprecated:: 2.0 "connectionless" execution is deprecated and will
+ be removed in version 2.0. Version 2.0 will feature the
+ :class:`_future.Result`
+ object that will no longer affect the status
+ of the originating connection in any case.
+
+ After this method is called, it is no longer valid to call upon
+ the fetch methods, which will raise a :class:`.ResourceClosedError`
+ on subsequent use.
+
+ .. seealso::
+
+ :ref:`connections_toplevel`
+
+ :ref:`dbengine_implicit`
+ """
+ self._soft_close(hard=True)
+
+ def _soft_close(self, hard=False):
+ soft_closed = self._soft_closed
+ super(LegacyCursorResult, self)._soft_close(hard=hard)
+ if (
+ not soft_closed
+ and self._soft_closed
+ and self._autoclose_connection
+ ):
+ self.connection.close()
+
+
+ResultProxy = LegacyCursorResult
+
+
+class BufferedRowResultProxy(ResultProxy):
+ """A ResultProxy with row buffering behavior.
+
+ .. deprecated:: 1.4 this class is now supplied using a strategy object.
+ See :class:`.BufferedRowCursorFetchStrategy`.
+
+ """
+
+ _cursor_strategy_cls = BufferedRowCursorFetchStrategy
+
+
+class FullyBufferedResultProxy(ResultProxy):
+ """A result proxy that buffers rows fully upon creation.
+
+ .. deprecated:: 1.4 this class is now supplied using a strategy object.
+ See :class:`.FullyBufferedCursorFetchStrategy`.
+
+ """
+
+ _cursor_strategy_cls = FullyBufferedCursorFetchStrategy
+
+
+class BufferedColumnRow(LegacyRow):
+ """Row is now BufferedColumn in all cases"""
+
+
+class BufferedColumnResultProxy(ResultProxy):
+ """A ResultProxy with column buffering behavior.
+
+ .. versionchanged:: 1.4 This is now the default behavior of the Row
+ and this class does not change behavior in any way.
+
+ """
+
+ _process_row = BufferedColumnRow
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
new file mode 100644
index 0000000..268a2d6
--- /dev/null
+++ b/lib/sqlalchemy/engine/default.py
@@ -0,0 +1,1936 @@
+# engine/default.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Default implementations of per-dialect sqlalchemy.engine classes.
+
+These are semi-private implementation classes which are only of importance
+to database dialect authors; dialects will usually use the classes here
+as the base class for their own corresponding classes.
+
+"""
+
+import codecs
+import functools
+import random
+import re
+import weakref
+
+from . import characteristics
+from . import cursor as _cursor
+from . import interfaces
+from .base import Connection
+from .. import event
+from .. import exc
+from .. import pool
+from .. import processors
+from .. import types as sqltypes
+from .. import util
+from ..sql import compiler
+from ..sql import expression
+from ..sql.elements import quoted_name
+
+AUTOCOMMIT_REGEXP = re.compile(
+ r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
+)
+
+# When we're handed literal SQL, ensure it's a SELECT query
+SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE)
+
+
+CACHE_HIT = util.symbol("CACHE_HIT")
+CACHE_MISS = util.symbol("CACHE_MISS")
+CACHING_DISABLED = util.symbol("CACHING_DISABLED")
+NO_CACHE_KEY = util.symbol("NO_CACHE_KEY")
+NO_DIALECT_SUPPORT = util.symbol("NO_DIALECT_SUPPORT")
+
+
+class DefaultDialect(interfaces.Dialect):
+ """Default implementation of Dialect"""
+
+ statement_compiler = compiler.SQLCompiler
+ ddl_compiler = compiler.DDLCompiler
+ type_compiler = compiler.GenericTypeCompiler
+ preparer = compiler.IdentifierPreparer
+ supports_alter = True
+ supports_comments = False
+ inline_comments = False
+ use_setinputsizes = False
+ supports_statement_cache = True
+
+ # the first value we'd get for an autoincrement
+ # column.
+ default_sequence_base = 1
+
+ # most DBAPIs happy with this for execute().
+ # not cx_oracle.
+ execute_sequence_format = tuple
+
+ supports_schemas = True
+ supports_views = True
+ supports_sequences = False
+ sequences_optional = False
+ preexecute_autoincrement_sequences = False
+ supports_identity_columns = False
+ postfetch_lastrowid = True
+ implicit_returning = False
+ full_returning = False
+ insert_executemany_returning = False
+
+ cte_follows_insert = False
+
+ supports_native_enum = False
+ supports_native_boolean = False
+ non_native_boolean_check_constraint = True
+
+ supports_simple_order_by_label = True
+
+ tuple_in_values = False
+
+ connection_characteristics = util.immutabledict(
+ {"isolation_level": characteristics.IsolationLevelCharacteristic()}
+ )
+
+ engine_config_types = util.immutabledict(
+ [
+ ("convert_unicode", util.bool_or_str("force")),
+ ("pool_timeout", util.asint),
+ ("echo", util.bool_or_str("debug")),
+ ("echo_pool", util.bool_or_str("debug")),
+ ("pool_recycle", util.asint),
+ ("pool_size", util.asint),
+ ("max_overflow", util.asint),
+ ("future", util.asbool),
+ ]
+ )
+
+ # whether the NUMERIC type
+ # returns decimal.Decimal;
+ # this does *not* apply to the FLOAT type.
+ supports_native_decimal = False
+
+ if util.py3k:
+ supports_unicode_statements = True
+ supports_unicode_binds = True
+ returns_unicode_strings = sqltypes.String.RETURNS_UNICODE
+ description_encoding = None
+ else:
+ supports_unicode_statements = False
+ supports_unicode_binds = False
+ returns_unicode_strings = sqltypes.String.RETURNS_UNKNOWN
+ description_encoding = "use_encoding"
+
+ name = "default"
+
+ # length at which to truncate
+ # any identifier.
+ max_identifier_length = 9999
+ _user_defined_max_identifier_length = None
+
+ isolation_level = None
+
+ # sub-categories of max_identifier_length.
+ # currently these accommodate MySQL, which allows alias names
+ # of 255 but DDL names of only 64.
+ max_index_name_length = None
+ max_constraint_name_length = None
+
+ supports_sane_rowcount = True
+ supports_sane_multi_rowcount = True
+ colspecs = {}
+ default_paramstyle = "named"
+
+ supports_default_values = False
+ """dialect supports INSERT... DEFAULT VALUES syntax"""
+
+ supports_default_metavalue = False
+ """dialect supports INSERT... VALUES (DEFAULT) syntax"""
+
+ # not sure if this is a real thing but the compiler will deliver it
+ # if this is the only flag enabled.
+ supports_empty_insert = True
+ """dialect supports INSERT () VALUES ()"""
+
+ supports_multivalues_insert = False
+
+ supports_is_distinct_from = True
+
+ supports_server_side_cursors = False
+
+ server_side_cursors = False
+
+ # extra record-level locking features (#4860)
+ supports_for_update_of = False
+
+ server_version_info = None
+
+ default_schema_name = None
+
+ construct_arguments = None
+ """Optional set of argument specifiers for various SQLAlchemy
+ constructs, typically schema items.
+
+ To implement, establish as a series of tuples, as in::
+
+ construct_arguments = [
+ (schema.Index, {
+ "using": False,
+ "where": None,
+ "ops": None
+ })
+ ]
+
+ If the above construct is established on the PostgreSQL dialect,
+ the :class:`.Index` construct will now accept the keyword arguments
+ ``postgresql_using``, ``postgresql_where``, and ``postgresql_ops``.
+ Any other argument specified to the constructor of :class:`.Index`
+ which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`.
+
+ A dialect which does not include a ``construct_arguments`` member will
+ not participate in the argument validation system. For such a dialect,
+ any argument name is accepted by all participating constructs, within
+ the namespace of arguments prefixed with that dialect name. The rationale
+ here is so that third-party dialects that haven't yet implemented this
+ feature continue to function in the old way.
+
+ .. versionadded:: 0.9.2
+
+ .. seealso::
+
+ :class:`.DialectKWArgs` - implementing base class which consumes
+ :attr:`.DefaultDialect.construct_arguments`
+
+
+ """
+
+ # indicates symbol names are
+ # UPPERCASEd if they are case insensitive
+ # within the database.
+ # if this is True, the methods normalize_name()
+ # and denormalize_name() must be provided.
+ requires_name_normalize = False
+
+ reflection_options = ()
+
+ dbapi_exception_translation_map = util.immutabledict()
+ """mapping used in the extremely unusual case that a DBAPI's
+ published exceptions don't actually have the ``__name__`` that they
+ are linked towards.
+
+ .. versionadded:: 1.0.5
+
+ """
+
+ is_async = False
+
+ CACHE_HIT = CACHE_HIT
+ CACHE_MISS = CACHE_MISS
+ CACHING_DISABLED = CACHING_DISABLED
+ NO_CACHE_KEY = NO_CACHE_KEY
+ NO_DIALECT_SUPPORT = NO_DIALECT_SUPPORT
+
+ @util.deprecated_params(
+ convert_unicode=(
+ "1.3",
+ "The :paramref:`_sa.create_engine.convert_unicode` parameter "
+ "and corresponding dialect-level parameters are deprecated, "
+ "and will be removed in a future release. Modern DBAPIs support "
+ "Python Unicode natively and this parameter is unnecessary.",
+ ),
+ empty_in_strategy=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is "
+ "deprecated, and no longer has any effect. All IN expressions "
+ "are now rendered using "
+ 'the "expanding parameter" strategy which renders a set of bound'
+ 'expressions, or an "empty set" SELECT, at statement execution'
+ "time.",
+ ),
+ case_sensitive=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.case_sensitive` parameter "
+ "is deprecated and will be removed in a future release. "
+ "Applications should work with result column names in a case "
+ "sensitive fashion.",
+ ),
+ server_side_cursors=(
+ "1.4",
+ "The :paramref:`_sa.create_engine.server_side_cursors` parameter "
+ "is deprecated and will be removed in a future release. Please "
+ "use the "
+ ":paramref:`_engine.Connection.execution_options.stream_results` "
+ "parameter.",
+ ),
+ )
+ def __init__(
+ self,
+ convert_unicode=False,
+ encoding="utf-8",
+ paramstyle=None,
+ dbapi=None,
+ implicit_returning=None,
+ case_sensitive=True,
+ supports_native_boolean=None,
+ max_identifier_length=None,
+ label_length=None,
+ # int() is because the @deprecated_params decorator cannot accommodate
+ # the direct reference to the "NO_LINTING" object
+ compiler_linting=int(compiler.NO_LINTING),
+ server_side_cursors=False,
+ **kwargs
+ ):
+
+ if not getattr(self, "ported_sqla_06", True):
+ util.warn(
+ "The %s dialect is not yet ported to the 0.6 format"
+ % self.name
+ )
+
+ if server_side_cursors:
+ if not self.supports_server_side_cursors:
+ raise exc.ArgumentError(
+ "Dialect %s does not support server side cursors" % self
+ )
+ else:
+ self.server_side_cursors = True
+
+ self.convert_unicode = convert_unicode
+ self.encoding = encoding
+ self.positional = False
+ self._ischema = None
+ self.dbapi = dbapi
+ if paramstyle is not None:
+ self.paramstyle = paramstyle
+ elif self.dbapi is not None:
+ self.paramstyle = self.dbapi.paramstyle
+ else:
+ self.paramstyle = self.default_paramstyle
+ if implicit_returning is not None:
+ self.implicit_returning = implicit_returning
+ self.positional = self.paramstyle in ("qmark", "format", "numeric")
+ self.identifier_preparer = self.preparer(self)
+ self.type_compiler = self.type_compiler(self)
+ if supports_native_boolean is not None:
+ self.supports_native_boolean = supports_native_boolean
+ self.case_sensitive = case_sensitive
+
+ self._user_defined_max_identifier_length = max_identifier_length
+ if self._user_defined_max_identifier_length:
+ self.max_identifier_length = (
+ self._user_defined_max_identifier_length
+ )
+ self.label_length = label_length
+ self.compiler_linting = compiler_linting
+ if self.description_encoding == "use_encoding":
+ self._description_decoder = (
+ processors.to_unicode_processor_factory
+ )(encoding)
+ elif self.description_encoding is not None:
+ self._description_decoder = (
+ processors.to_unicode_processor_factory
+ )(self.description_encoding)
+ self._encoder = codecs.getencoder(self.encoding)
+ self._decoder = processors.to_unicode_processor_factory(self.encoding)
+
+ def _ensure_has_table_connection(self, arg):
+
+ if not isinstance(arg, Connection):
+ raise exc.ArgumentError(
+ "The argument passed to Dialect.has_table() should be a "
+ "%s, got %s. "
+ "Additionally, the Dialect.has_table() method is for "
+ "internal dialect "
+ "use only; please use "
+ "``inspect(some_engine).has_table(<tablename>>)`` "
+ "for public API use." % (Connection, type(arg))
+ )
+
+ @util.memoized_property
+ def _supports_statement_cache(self):
+ ssc = self.__class__.__dict__.get("supports_statement_cache", None)
+ if ssc is None:
+ util.warn(
+ "Dialect %s:%s will not make use of SQL compilation caching "
+ "as it does not set the 'supports_statement_cache' attribute "
+ "to ``True``. This can have "
+ "significant performance implications including some "
+ "performance degradations in comparison to prior SQLAlchemy "
+ "versions. Dialect maintainers should seek to set this "
+ "attribute to True after appropriate development and testing "
+ "for SQLAlchemy 1.4 caching support. Alternatively, this "
+ "attribute may be set to False which will disable this "
+ "warning." % (self.name, self.driver),
+ code="cprf",
+ )
+
+ return bool(ssc)
+
+ @util.memoized_property
+ def _type_memos(self):
+ return weakref.WeakKeyDictionary()
+
+ @property
+ def dialect_description(self):
+ return self.name + "+" + self.driver
+
+ @property
+ def supports_sane_rowcount_returning(self):
+ """True if this dialect supports sane rowcount even if RETURNING is
+ in use.
+
+ For dialects that don't support RETURNING, this is synonymous with
+ ``supports_sane_rowcount``.
+
+ """
+ return self.supports_sane_rowcount
+
+ @classmethod
+ def get_pool_class(cls, url):
+ return getattr(cls, "poolclass", pool.QueuePool)
+
+ def get_dialect_pool_class(self, url):
+ return self.get_pool_class(url)
+
+ @classmethod
+ def load_provisioning(cls):
+ package = ".".join(cls.__module__.split(".")[0:-1])
+ try:
+ __import__(package + ".provision")
+ except ImportError:
+ pass
+
+ def initialize(self, connection):
+ try:
+ self.server_version_info = self._get_server_version_info(
+ connection
+ )
+ except NotImplementedError:
+ self.server_version_info = None
+ try:
+ self.default_schema_name = self._get_default_schema_name(
+ connection
+ )
+ except NotImplementedError:
+ self.default_schema_name = None
+
+ try:
+ self.default_isolation_level = self.get_default_isolation_level(
+ connection.connection
+ )
+ except NotImplementedError:
+ self.default_isolation_level = None
+
+ if self.returns_unicode_strings is sqltypes.String.RETURNS_UNKNOWN:
+ if util.py3k:
+ raise exc.InvalidRequestError(
+ "RETURNS_UNKNOWN is unsupported in Python 3"
+ )
+ self.returns_unicode_strings = self._check_unicode_returns(
+ connection
+ )
+
+ if (
+ self.description_encoding is not None
+ and self._check_unicode_description(connection)
+ ):
+ self._description_decoder = self.description_encoding = None
+
+ if not self._user_defined_max_identifier_length:
+ max_ident_length = self._check_max_identifier_length(connection)
+ if max_ident_length:
+ self.max_identifier_length = max_ident_length
+
+ if (
+ self.label_length
+ and self.label_length > self.max_identifier_length
+ ):
+ raise exc.ArgumentError(
+ "Label length of %d is greater than this dialect's"
+ " maximum identifier length of %d"
+ % (self.label_length, self.max_identifier_length)
+ )
+
+ def on_connect(self):
+ # inherits the docstring from interfaces.Dialect.on_connect
+ return None
+
+ def _check_max_identifier_length(self, connection):
+ """Perform a connection / server version specific check to determine
+ the max_identifier_length.
+
+ If the dialect's class level max_identifier_length should be used,
+ can return None.
+
+ .. versionadded:: 1.3.9
+
+ """
+ return None
+
+ def get_default_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, return its isolation level, or
+ a default isolation level if one cannot be retrieved.
+
+ May be overridden by subclasses in order to provide a
+ "fallback" isolation level for databases that cannot reliably
+ retrieve the actual isolation level.
+
+ By default, calls the :meth:`_engine.Dialect.get_isolation_level`
+ method, propagating any exceptions raised.
+
+ .. versionadded:: 1.3.22
+
+ """
+ return self.get_isolation_level(dbapi_conn)
+
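+ # Illustrative sketch, not part of this module: a dialect for a backend
+ # that cannot reliably report its isolation level might override
+ # get_default_isolation_level() to supply a fallback value (the class
+ # name below is hypothetical)::
+ #
+ #     class FallbackDialect(DefaultDialect):
+ #         def get_default_isolation_level(self, dbapi_conn):
+ #             try:
+ #                 return self.get_isolation_level(dbapi_conn)
+ #             except Exception:
+ #                 return "READ COMMITTED"
+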
+ def _check_unicode_returns(self, connection, additional_tests=None):
+ # this now runs in py2k only and will be removed in 2.0; disabled for
+ # Python 3 in all cases under #5315
+ if util.py2k and not self.supports_unicode_statements:
+ cast_to = util.binary_type
+ else:
+ cast_to = util.text_type
+
+ if self.positional:
+ parameters = self.execute_sequence_format()
+ else:
+ parameters = {}
+
+ def check_unicode(test):
+ statement = cast_to(expression.select(test).compile(dialect=self))
+ try:
+ cursor = connection.connection.cursor()
+ connection._cursor_execute(cursor, statement, parameters)
+ row = cursor.fetchone()
+ cursor.close()
+ except exc.DBAPIError as de:
+ # note that _cursor_execute() will have closed the cursor
+ # if an exception is thrown.
+ util.warn(
+ "Exception attempting to "
+ "detect unicode returns: %r" % de
+ )
+ return False
+ else:
+ return isinstance(row[0], util.text_type)
+
+ tests = [
+ # detect plain VARCHAR
+ expression.cast(
+ expression.literal_column("'test plain returns'"),
+ sqltypes.VARCHAR(60),
+ ),
+ # detect if there's an NVARCHAR type with different behavior
+ # available
+ expression.cast(
+ expression.literal_column("'test unicode returns'"),
+ sqltypes.Unicode(60),
+ ),
+ ]
+
+ if additional_tests:
+ tests += additional_tests
+
+ results = {check_unicode(test) for test in tests}
+
+ if results.issuperset([True, False]):
+ return sqltypes.String.RETURNS_CONDITIONAL
+ else:
+ return (
+ sqltypes.String.RETURNS_UNICODE
+ if results == {True}
+ else sqltypes.String.RETURNS_BYTES
+ )
+
+ def _check_unicode_description(self, connection):
+ # all DBAPIs on Py2K return cursor.description as encoded
+
+ if util.py2k and not self.supports_unicode_statements:
+ cast_to = util.binary_type
+ else:
+ cast_to = util.text_type
+
+ cursor = connection.connection.cursor()
+ try:
+ cursor.execute(
+ cast_to(
+ expression.select(
+ expression.literal_column("'x'").label("some_label")
+ ).compile(dialect=self)
+ )
+ )
+ return isinstance(cursor.description[0][0], util.text_type)
+ finally:
+ cursor.close()
+
+ def type_descriptor(self, typeobj):
+ """Provide a database-specific :class:`.TypeEngine` object, given
+ the generic object which comes from the types module.
+
+ This method looks for a dictionary called
+ ``colspecs`` as a class or instance-level variable,
+ and passes on to :func:`_types.adapt_type`.
+
+ """
+ return sqltypes.adapt_type(typeobj, self.colspecs)
+
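+ # Illustrative sketch, not part of this module: a dialect that wants its
+ # own NUMERIC handling would list it in ``colspecs`` and let
+ # type_descriptor() perform the adaptation (names below are hypothetical)::
+ #
+ #     class MyNumeric(sqltypes.Numeric):
+ #         pass
+ #
+ #     class MyDialect(DefaultDialect):
+ #         colspecs = {sqltypes.Numeric: MyNumeric}
+ #
+ # MyDialect().type_descriptor(sqltypes.Numeric()) would then return a
+ # MyNumeric instance via sqltypes.adapt_type().
+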
+ def has_index(self, connection, table_name, index_name, schema=None):
+ if not self.has_table(connection, table_name, schema=schema):
+ return False
+ for idx in self.get_indexes(connection, table_name, schema=schema):
+ if idx["name"] == index_name:
+ return True
+ else:
+ return False
+
+ def validate_identifier(self, ident):
+ if len(ident) > self.max_identifier_length:
+ raise exc.IdentifierError(
+ "Identifier '%s' exceeds maximum length of %d characters"
+ % (ident, self.max_identifier_length)
+ )
+
+ def connect(self, *cargs, **cparams):
+ # inherits the docstring from interfaces.Dialect.connect
+ return self.dbapi.connect(*cargs, **cparams)
+
+ def create_connect_args(self, url):
+ # inherits the docstring from interfaces.Dialect.create_connect_args
+ opts = url.translate_connect_args()
+ opts.update(url.query)
+ return [[], opts]
+
+ def set_engine_execution_options(self, engine, opts):
+ supported_names = set(self.connection_characteristics).intersection(
+ opts
+ )
+ if supported_names:
+ characteristics = util.immutabledict(
+ (name, opts[name]) for name in supported_names
+ )
+
+ @event.listens_for(engine, "engine_connect")
+ def set_connection_characteristics(connection, branch):
+ if not branch:
+ self._set_connection_characteristics(
+ connection, characteristics
+ )
+
+ def set_connection_execution_options(self, connection, opts):
+ supported_names = set(self.connection_characteristics).intersection(
+ opts
+ )
+ if supported_names:
+ characteristics = util.immutabledict(
+ (name, opts[name]) for name in supported_names
+ )
+ self._set_connection_characteristics(connection, characteristics)
+
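+ # For illustration only (not part of this module): the "isolation_level"
+ # characteristic declared in connection_characteristics is what lets an
+ # option such as the following route through the two methods above::
+ #
+ #     with engine.connect() as conn:
+ #         conn = conn.execution_options(isolation_level="SERIALIZABLE")
+ #         ...
+ #
+ # Supported level names vary per backend.
+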
+ def _set_connection_characteristics(self, connection, characteristics):
+
+ characteristic_values = [
+ (name, self.connection_characteristics[name], value)
+ for name, value in characteristics.items()
+ ]
+
+ if connection.in_transaction():
+ trans_objs = [
+ (name, obj)
+ for name, obj, value in characteristic_values
+ if obj.transactional
+ ]
+ if trans_objs:
+ if connection._is_future:
+ raise exc.InvalidRequestError(
+ "This connection has already initialized a SQLAlchemy "
+ "Transaction() object via begin() or autobegin; "
+ "%s may not be altered unless rollback() or commit() "
+ "is called first."
+ % (", ".join(name for name, obj in trans_objs))
+ )
+ else:
+ util.warn(
+ "Connection is already established with a "
+ "Transaction; "
+ "setting %s may implicitly rollback or "
+ "commit "
+ "the existing transaction, or have no effect until "
+ "next transaction"
+ % (", ".join(name for name, obj in trans_objs))
+ )
+
+ dbapi_connection = connection.connection.dbapi_connection
+ for name, characteristic, value in characteristic_values:
+ characteristic.set_characteristic(self, dbapi_connection, value)
+ connection.connection._connection_record.finalize_callback.append(
+ functools.partial(self._reset_characteristics, characteristics)
+ )
+
+ def _reset_characteristics(self, characteristics, dbapi_connection):
+ for characteristic_name in characteristics:
+ characteristic = self.connection_characteristics[
+ characteristic_name
+ ]
+ characteristic.reset_characteristic(self, dbapi_connection)
+
+ def do_begin(self, dbapi_connection):
+ pass
+
+ def do_rollback(self, dbapi_connection):
+ dbapi_connection.rollback()
+
+ def do_commit(self, dbapi_connection):
+ dbapi_connection.commit()
+
+ def do_close(self, dbapi_connection):
+ dbapi_connection.close()
+
+ @util.memoized_property
+ def _dialect_specific_select_one(self):
+ return str(expression.select(1).compile(dialect=self))
+
+ def do_ping(self, dbapi_connection):
+ cursor = None
+ try:
+ cursor = dbapi_connection.cursor()
+ try:
+ cursor.execute(self._dialect_specific_select_one)
+ finally:
+ cursor.close()
+ except self.dbapi.Error as err:
+ if self.is_disconnect(err, dbapi_connection, cursor):
+ return False
+ else:
+ raise
+ else:
+ return True
+
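+ # For illustration only: do_ping() is the hook behind connection-pool
+ # "pre ping" behavior, e.g.::
+ #
+ #     engine = create_engine(url, pool_pre_ping=True)
+ #
+ # where the pool pings a pooled DBAPI connection before handing it out
+ # and obtains a fresh connection if the ping fails.
+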
+ def create_xid(self):
+ """Create a random two-phase transaction ID.
+
+ This id will be passed to do_begin_twophase(), do_rollback_twophase(),
+ do_commit_twophase(). Its format is unspecified.
+ """
+
+ return "_sa_%032x" % random.randint(0, 2 ** 128)
+
+ def do_savepoint(self, connection, name):
+ connection.execute(expression.SavepointClause(name))
+
+ def do_rollback_to_savepoint(self, connection, name):
+ connection.execute(expression.RollbackToSavepointClause(name))
+
+ def do_release_savepoint(self, connection, name):
+ connection.execute(expression.ReleaseSavepointClause(name))
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ cursor.executemany(statement, parameters)
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ cursor.execute(statement, parameters)
+
+ def do_execute_no_params(self, cursor, statement, context=None):
+ cursor.execute(statement)
+
+ def is_disconnect(self, e, connection, cursor):
+ return False
+
+ def reset_isolation_level(self, dbapi_conn):
+ # default_isolation_level is read from the first connection,
+ # after any initial 'isolation_level' setting has been applied,
+ # so it reflects the configured default of this dialect.
+ self.set_isolation_level(dbapi_conn, self.default_isolation_level)
+
+ def normalize_name(self, name):
+ if name is None:
+ return None
+ if util.py2k:
+ if isinstance(name, str):
+ name = name.decode(self.encoding)
+
+ name_lower = name.lower()
+ name_upper = name.upper()
+
+ if name_upper == name_lower:
+ # name has no upper/lower conversion, e.g. non-European characters.
+ # return unchanged
+ return name
+ elif name_upper == name and not (
+ self.identifier_preparer._requires_quotes
+ )(name_lower):
+ # name is all uppercase and doesn't require quoting; normalize
+ # to all lower case
+ return name_lower
+ elif name_lower == name:
+ # name is all lower case, which if denormalized means we need to
+ # force quoting on it
+ return quoted_name(name, quote=True)
+ else:
+ # name is mixed case, which means it will be quoted in SQL when
+ # used later; no normalization is applied
+ return name
+
+ def denormalize_name(self, name):
+ if name is None:
+ return None
+
+ name_lower = name.lower()
+ name_upper = name.upper()
+
+ if name_upper == name_lower:
+ # name has no upper/lower conversion, e.g. non-European characters.
+ # return unchanged
+ return name
+ elif name_lower == name and not (
+ self.identifier_preparer._requires_quotes
+ )(name_lower):
+ name = name_upper
+ if util.py2k:
+ if not self.supports_unicode_binds:
+ name = name.encode(self.encoding)
+ else:
+ name = unicode(name) # noqa
+ return name
+
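+ # Rough illustration, assuming a backend that case-folds unquoted names to
+ # upper case (Oracle-style) and a dialect with requires_name_normalize=True:
+ #
+ #     normalize_name("MYTABLE")   -> "mytable"
+ #     normalize_name("mytable")   -> quoted_name("mytable", quote=True)
+ #     normalize_name("MyTable")   -> "MyTable" (left as-is, quoted later)
+ #     denormalize_name("mytable") -> "MYTABLE"
+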
+ def get_driver_connection(self, connection):
+ return connection
+
+
+class _RendersLiteral(object):
+ def literal_processor(self, dialect):
+ def process(value):
+ return "'%s'" % value
+
+ return process
+
+
+class _StrDateTime(_RendersLiteral, sqltypes.DateTime):
+ pass
+
+
+class _StrDate(_RendersLiteral, sqltypes.Date):
+ pass
+
+
+class _StrTime(_RendersLiteral, sqltypes.Time):
+ pass
+
+
+class StrCompileDialect(DefaultDialect):
+
+ statement_compiler = compiler.StrSQLCompiler
+ ddl_compiler = compiler.DDLCompiler
+ type_compiler = compiler.StrSQLTypeCompiler
+ preparer = compiler.IdentifierPreparer
+
+ supports_statement_cache = True
+
+ supports_identity_columns = True
+
+ supports_sequences = True
+ sequences_optional = True
+ preexecute_autoincrement_sequences = False
+ implicit_returning = False
+
+ supports_native_boolean = True
+
+ supports_multivalues_insert = True
+ supports_simple_order_by_label = True
+
+ colspecs = {
+ sqltypes.DateTime: _StrDateTime,
+ sqltypes.Date: _StrDate,
+ sqltypes.Time: _StrTime,
+ }
+
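+ # Side note, stated loosely: StrCompileDialect is the "string only" dialect
+ # used when a SQL construct is stringified without a bound engine, e.g.
+ # ``print(select(my_table))``; the _Str* colspecs above let date/time
+ # literals render as plain quoted strings in that case.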
+
+class DefaultExecutionContext(interfaces.ExecutionContext):
+ isinsert = False
+ isupdate = False
+ isdelete = False
+ is_crud = False
+ is_text = False
+ isddl = False
+ executemany = False
+ compiled = None
+ statement = None
+ result_column_struct = None
+ returned_default_rows = None
+ execution_options = util.immutabledict()
+
+ include_set_input_sizes = None
+ exclude_set_input_sizes = None
+
+ cursor_fetch_strategy = _cursor._DEFAULT_FETCH
+
+ cache_stats = None
+ invoked_statement = None
+
+ _is_implicit_returning = False
+ _is_explicit_returning = False
+ _is_future_result = False
+ _is_server_side = False
+
+ _soft_closed = False
+
+ # a hook for SQLite's translation of
+ # result column names
+ # NOTE: pyhive is using this hook, can't remove it :(
+ _translate_colname = None
+
+ _expanded_parameters = util.immutabledict()
+
+ cache_hit = NO_CACHE_KEY
+
+ @classmethod
+ def _init_ddl(
+ cls,
+ dialect,
+ connection,
+ dbapi_connection,
+ execution_options,
+ compiled_ddl,
+ ):
+ """Initialize execution context for a DDLElement construct."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+
+ self.compiled = compiled = compiled_ddl
+ self.isddl = True
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ self.unicode_statement = util.text_type(compiled)
+ if compiled.schema_translate_map:
+ schema_translate_map = self.execution_options.get(
+ "schema_translate_map", {}
+ )
+
+ rst = compiled.preparer._render_schema_translates
+ self.unicode_statement = rst(
+ self.unicode_statement, schema_translate_map
+ )
+
+ if not dialect.supports_unicode_statements:
+ self.statement = dialect._encoder(self.unicode_statement)[0]
+ else:
+ self.statement = self.unicode_statement
+
+ self.cursor = self.create_cursor()
+ self.compiled_parameters = []
+
+ if dialect.positional:
+ self.parameters = [dialect.execute_sequence_format()]
+ else:
+ self.parameters = [{}]
+
+ return self
+
+ @classmethod
+ def _init_compiled(
+ cls,
+ dialect,
+ connection,
+ dbapi_connection,
+ execution_options,
+ compiled,
+ parameters,
+ invoked_statement,
+ extracted_parameters,
+ cache_hit=CACHING_DISABLED,
+ ):
+ """Initialize execution context for a Compiled construct."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+ self.extracted_parameters = extracted_parameters
+ self.invoked_statement = invoked_statement
+ self.compiled = compiled
+ self.cache_hit = cache_hit
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ self.result_column_struct = (
+ compiled._result_columns,
+ compiled._ordered_columns,
+ compiled._textual_ordered_columns,
+ compiled._loose_column_name_matching,
+ )
+ self.isinsert = compiled.isinsert
+ self.isupdate = compiled.isupdate
+ self.isdelete = compiled.isdelete
+ self.is_text = compiled.isplaintext
+
+ if self.isinsert or self.isupdate or self.isdelete:
+ self.is_crud = True
+ self._is_explicit_returning = bool(compiled.statement._returning)
+ self._is_implicit_returning = bool(
+ compiled.returning and not compiled.statement._returning
+ )
+
+ if not parameters:
+ self.compiled_parameters = [
+ compiled.construct_params(
+ extracted_parameters=extracted_parameters,
+ escape_names=False,
+ )
+ ]
+ else:
+ self.compiled_parameters = [
+ compiled.construct_params(
+ m,
+ escape_names=False,
+ _group_number=grp,
+ extracted_parameters=extracted_parameters,
+ )
+ for grp, m in enumerate(parameters)
+ ]
+
+ self.executemany = len(parameters) > 1
+
+ # this must occur before create_cursor() since the statement
+ # has to be regexed in some cases for server side cursor
+ if util.py2k:
+ self.unicode_statement = util.text_type(compiled.string)
+ else:
+ self.unicode_statement = compiled.string
+
+ self.cursor = self.create_cursor()
+
+ if self.compiled.insert_prefetch or self.compiled.update_prefetch:
+ if self.executemany:
+ self._process_executemany_defaults()
+ else:
+ self._process_executesingle_defaults()
+
+ processors = compiled._bind_processors
+
+ if compiled.literal_execute_params or compiled.post_compile_params:
+ if self.executemany:
+ raise exc.InvalidRequestError(
+ "'literal_execute' or 'expanding' parameters can't be "
+ "used with executemany()"
+ )
+
+ expanded_state = compiled._process_parameters_for_postcompile(
+ self.compiled_parameters[0]
+ )
+
+ # re-assign self.unicode_statement
+ self.unicode_statement = expanded_state.statement
+
+ # used by set_input_sizes() which is needed for Oracle
+ self._expanded_parameters = expanded_state.parameter_expansion
+
+ processors = dict(processors)
+ processors.update(expanded_state.processors)
+ positiontup = expanded_state.positiontup
+ elif compiled.positional:
+ positiontup = self.compiled.positiontup
+
+ if compiled.schema_translate_map:
+ schema_translate_map = self.execution_options.get(
+ "schema_translate_map", {}
+ )
+ rst = compiled.preparer._render_schema_translates
+ self.unicode_statement = rst(
+ self.unicode_statement, schema_translate_map
+ )
+
+ # final self.unicode_statement is now assigned, encode if needed
+ # by dialect
+ if not dialect.supports_unicode_statements:
+ self.statement = self.unicode_statement.encode(
+ self.dialect.encoding
+ )
+ else:
+ self.statement = self.unicode_statement
+
+ # Convert the dictionary of bind parameter values
+ # into a dict or list to be sent to the DBAPI's
+ # execute() or executemany() method.
+ parameters = []
+ if compiled.positional:
+ for compiled_params in self.compiled_parameters:
+ param = [
+ processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in positiontup
+ ]
+ parameters.append(dialect.execute_sequence_format(param))
+ else:
+ encode = not dialect.supports_unicode_statements
+ if encode:
+ encoder = dialect._encoder
+ for compiled_params in self.compiled_parameters:
+ escaped_bind_names = compiled.escaped_bind_names
+
+ if encode:
+ if escaped_bind_names:
+ param = {
+ encoder(escaped_bind_names.get(key, key))[
+ 0
+ ]: processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+ else:
+ param = {
+ encoder(key)[0]: processors[key](
+ compiled_params[key]
+ )
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+ else:
+ if escaped_bind_names:
+ param = {
+ escaped_bind_names.get(key, key): processors[key](
+ compiled_params[key]
+ )
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+ else:
+ param = {
+ key: processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in compiled_params
+ }
+
+ parameters.append(param)
+
+ self.parameters = dialect.execute_sequence_format(parameters)
+
+ return self
+
+ @classmethod
+ def _init_statement(
+ cls,
+ dialect,
+ connection,
+ dbapi_connection,
+ execution_options,
+ statement,
+ parameters,
+ ):
+ """Initialize execution context for a string SQL statement."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+ self.is_text = True
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ if not parameters:
+ if self.dialect.positional:
+ self.parameters = [dialect.execute_sequence_format()]
+ else:
+ self.parameters = [{}]
+ elif isinstance(parameters[0], dialect.execute_sequence_format):
+ self.parameters = parameters
+ elif isinstance(parameters[0], dict):
+ if dialect.supports_unicode_statements:
+ self.parameters = parameters
+ else:
+ self.parameters = [
+ {dialect._encoder(k)[0]: d[k] for k in d}
+ for d in parameters
+ ] or [{}]
+ else:
+ self.parameters = [
+ dialect.execute_sequence_format(p) for p in parameters
+ ]
+
+ self.executemany = len(parameters) > 1
+
+ if not dialect.supports_unicode_statements and isinstance(
+ statement, util.text_type
+ ):
+ self.unicode_statement = statement
+ self.statement = dialect._encoder(statement)[0]
+ else:
+ self.statement = self.unicode_statement = statement
+
+ self.cursor = self.create_cursor()
+ return self
+
+ @classmethod
+ def _init_default(
+ cls, dialect, connection, dbapi_connection, execution_options
+ ):
+ """Initialize execution context for a ColumnDefault construct."""
+
+ self = cls.__new__(cls)
+ self.root_connection = connection
+ self._dbapi_connection = dbapi_connection
+ self.dialect = connection.dialect
+
+ self.execution_options = execution_options
+
+ self._is_future_result = (
+ connection._is_future
+ or self.execution_options.get("future_result", False)
+ )
+
+ self.cursor = self.create_cursor()
+ return self
+
+ def _get_cache_stats(self):
+ if self.compiled is None:
+ return "raw sql"
+
+ now = util.perf_counter()
+
+ ch = self.cache_hit
+
+ if ch is NO_CACHE_KEY:
+ return "no key %.5fs" % (now - self.compiled._gen_time,)
+ elif ch is CACHE_HIT:
+ return "cached since %.4gs ago" % (now - self.compiled._gen_time,)
+ elif ch is CACHE_MISS:
+ return "generated in %.5fs" % (now - self.compiled._gen_time,)
+ elif ch is CACHING_DISABLED:
+ return "caching disabled %.5fs" % (now - self.compiled._gen_time,)
+ elif ch is NO_DIALECT_SUPPORT:
+ return "dialect %s+%s does not support caching %.5fs" % (
+ self.dialect.name,
+ self.dialect.driver,
+ now - self.compiled._gen_time,
+ )
+ else:
+ return "unknown"
+
+ @util.memoized_property
+ def identifier_preparer(self):
+ if self.compiled:
+ return self.compiled.preparer
+ elif "schema_translate_map" in self.execution_options:
+ return self.dialect.identifier_preparer._with_schema_translate(
+ self.execution_options["schema_translate_map"]
+ )
+ else:
+ return self.dialect.identifier_preparer
+
+ @util.memoized_property
+ def engine(self):
+ return self.root_connection.engine
+
+ @util.memoized_property
+ def postfetch_cols(self):
+ return self.compiled.postfetch
+
+ @util.memoized_property
+ def prefetch_cols(self):
+ if self.isinsert:
+ return self.compiled.insert_prefetch
+ elif self.isupdate:
+ return self.compiled.update_prefetch
+ else:
+ return ()
+
+ @util.memoized_property
+ def returning_cols(self):
+ return self.compiled.returning
+
+ @util.memoized_property
+ def no_parameters(self):
+ return self.execution_options.get("no_parameters", False)
+
+ @util.memoized_property
+ def should_autocommit(self):
+ autocommit = self.execution_options.get(
+ "autocommit",
+ not self.compiled
+ and self.statement
+ and expression.PARSE_AUTOCOMMIT
+ or False,
+ )
+
+ if autocommit is expression.PARSE_AUTOCOMMIT:
+ return self.should_autocommit_text(self.unicode_statement)
+ else:
+ return autocommit
+
+ def _execute_scalar(self, stmt, type_, parameters=None):
+ """Execute a string statement on the current cursor, returning a
+ scalar result.
+
+ Used to fire off sequences, default phrases, and "select lastrowid"
+ types of statements individually or in the context of a parent INSERT
+ or UPDATE statement.
+
+ """
+
+ conn = self.root_connection
+ if (
+ isinstance(stmt, util.text_type)
+ and not self.dialect.supports_unicode_statements
+ ):
+ stmt = self.dialect._encoder(stmt)[0]
+
+ if "schema_translate_map" in self.execution_options:
+ schema_translate_map = self.execution_options.get(
+ "schema_translate_map", {}
+ )
+
+ rst = self.identifier_preparer._render_schema_translates
+ stmt = rst(stmt, schema_translate_map)
+
+ if not parameters:
+ if self.dialect.positional:
+ parameters = self.dialect.execute_sequence_format()
+ else:
+ parameters = {}
+
+ conn._cursor_execute(self.cursor, stmt, parameters, context=self)
+ r = self.cursor.fetchone()[0]
+ if type_ is not None:
+ # apply type post processors to the result
+ proc = type_._cached_result_processor(
+ self.dialect, self.cursor.description[0][1]
+ )
+ if proc:
+ return proc(r)
+ return r
+
+ @property
+ def connection(self):
+ conn = self.root_connection
+ if conn._is_future:
+ return conn
+ else:
+ return conn._branch()
+
+ def should_autocommit_text(self, statement):
+ return AUTOCOMMIT_REGEXP.match(statement)
+
+ def _use_server_side_cursor(self):
+ if not self.dialect.supports_server_side_cursors:
+ return False
+
+ if self.dialect.server_side_cursors:
+ # this is deprecated
+ use_server_side = self.execution_options.get(
+ "stream_results", True
+ ) and (
+ (
+ self.compiled
+ and isinstance(
+ self.compiled.statement, expression.Selectable
+ )
+ or (
+ (
+ not self.compiled
+ or isinstance(
+ self.compiled.statement, expression.TextClause
+ )
+ )
+ and self.unicode_statement
+ and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement)
+ )
+ )
+ )
+ else:
+ use_server_side = self.execution_options.get(
+ "stream_results", False
+ )
+
+ return use_server_side
+
+ def create_cursor(self):
+ if (
+ # inlining initial preference checks for SS cursors
+ self.dialect.supports_server_side_cursors
+ and (
+ self.execution_options.get("stream_results", False)
+ or (
+ self.dialect.server_side_cursors
+ and self._use_server_side_cursor()
+ )
+ )
+ ):
+ self._is_server_side = True
+ return self.create_server_side_cursor()
+ else:
+ self._is_server_side = False
+ return self.create_default_cursor()
+
+ def create_default_cursor(self):
+ return self._dbapi_connection.cursor()
+
+ def create_server_side_cursor(self):
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ pass
+
+ def get_out_parameter_values(self, names):
+ raise NotImplementedError(
+ "This dialect does not support OUT parameters"
+ )
+
+ def post_exec(self):
+ pass
+
+ def get_result_processor(self, type_, colname, coltype):
+ """Return a 'result processor' for a given type as present in
+ cursor.description.
+
+ This has a default implementation that dialects can override
+ for context-sensitive result type handling.
+
+ """
+ return type_._cached_result_processor(self.dialect, coltype)
+
+ def get_lastrowid(self):
+ """return self.cursor.lastrowid, or equivalent, after an INSERT.
+
+ This may involve calling special cursor functions, issuing a new SELECT
+ on the cursor (or a new one), or returning a stored value that was
+ calculated within post_exec().
+
+ This function will only be called for dialects which support "implicit"
+ primary key generation and keep preexecute_autoincrement_sequences set to
+ False, and only when no explicit id value was bound to the statement.
+
+ The function is called once for an INSERT statement that would need to
+ return the last inserted primary key for those dialects that make use
+ of the lastrowid concept. In these cases, it is called directly after
+ :meth:`.ExecutionContext.post_exec`.
+
+ """
+ return self.cursor.lastrowid
+
+ def handle_dbapi_exception(self, e):
+ pass
+
+ @property
+ def rowcount(self):
+ return self.cursor.rowcount
+
+ def supports_sane_rowcount(self):
+ return self.dialect.supports_sane_rowcount
+
+ def supports_sane_multi_rowcount(self):
+ return self.dialect.supports_sane_multi_rowcount
+
+ def _setup_result_proxy(self):
+ exec_opt = self.execution_options
+
+ if self.is_crud or self.is_text:
+ result = self._setup_dml_or_text_result()
+ yp = sr = False
+ else:
+ yp = exec_opt.get("yield_per", None)
+ sr = self._is_server_side or exec_opt.get("stream_results", False)
+ strategy = self.cursor_fetch_strategy
+ if sr and strategy is _cursor._DEFAULT_FETCH:
+ strategy = _cursor.BufferedRowCursorFetchStrategy(
+ self.cursor, self.execution_options
+ )
+ cursor_description = (
+ strategy.alternate_cursor_description
+ or self.cursor.description
+ )
+ if cursor_description is None:
+ strategy = _cursor._NO_CURSOR_DQL
+
+ if self._is_future_result:
+ if self.root_connection.should_close_with_result:
+ raise exc.InvalidRequestError(
+ "can't use future_result=True with close_with_result"
+ )
+ result = _cursor.CursorResult(
+ self, strategy, cursor_description
+ )
+ else:
+ result = _cursor.LegacyCursorResult(
+ self, strategy, cursor_description
+ )
+
+ if (
+ self.compiled
+ and not self.isddl
+ and self.compiled.has_out_parameters
+ ):
+ self._setup_out_parameters(result)
+
+ self._soft_closed = result._soft_closed
+
+ if yp:
+ result = result.yield_per(yp)
+
+ return result
+
+ def _setup_out_parameters(self, result):
+
+ out_bindparams = [
+ (param, name)
+ for param, name in self.compiled.bind_names.items()
+ if param.isoutparam
+ ]
+ out_parameters = {}
+
+ for bindparam, raw_value in zip(
+ [param for param, name in out_bindparams],
+ self.get_out_parameter_values(
+ [name for param, name in out_bindparams]
+ ),
+ ):
+
+ type_ = bindparam.type
+ impl_type = type_.dialect_impl(self.dialect)
+ dbapi_type = impl_type.get_dbapi_type(self.dialect.dbapi)
+ result_processor = impl_type.result_processor(
+ self.dialect, dbapi_type
+ )
+ if result_processor is not None:
+ raw_value = result_processor(raw_value)
+ out_parameters[bindparam.key] = raw_value
+
+ result.out_parameters = out_parameters
+
+ def _setup_dml_or_text_result(self):
+ if self.isinsert:
+ if self.compiled.postfetch_lastrowid:
+ self.inserted_primary_key_rows = (
+ self._setup_ins_pk_from_lastrowid()
+ )
+ # else if not self._is_implicit_returning,
+ # the default inserted_primary_key_rows accessor will
+ # return an "empty" primary key collection when accessed.
+
+ strategy = self.cursor_fetch_strategy
+ if self._is_server_side and strategy is _cursor._DEFAULT_FETCH:
+ strategy = _cursor.BufferedRowCursorFetchStrategy(
+ self.cursor, self.execution_options
+ )
+ cursor_description = (
+ strategy.alternate_cursor_description or self.cursor.description
+ )
+ if cursor_description is None:
+ strategy = _cursor._NO_CURSOR_DML
+
+ if self._is_future_result:
+ result = _cursor.CursorResult(self, strategy, cursor_description)
+ else:
+ result = _cursor.LegacyCursorResult(
+ self, strategy, cursor_description
+ )
+
+ if self.isinsert:
+ if self._is_implicit_returning:
+ rows = result.all()
+
+ self.returned_default_rows = rows
+
+ self.inserted_primary_key_rows = (
+ self._setup_ins_pk_from_implicit_returning(result, rows)
+ )
+
+ # test that it has a cursor metadata that is accurate. the
+ # first row will have been fetched and current assumptions
+ # are that the result has only one row, until executemany()
+ # support is added here.
+ assert result._metadata.returns_rows
+ result._soft_close()
+ elif not self._is_explicit_returning:
+ result._soft_close()
+
+ # we assume here the result does not return any rows.
+ # *usually*, this will be true. However, some dialects
+ # such as that of MSSQL/pyodbc need to SELECT a post fetch
+ # function so this is not necessarily true.
+ # assert not result.returns_rows
+
+ elif self.isupdate and self._is_implicit_returning:
+ row = result.fetchone()
+ self.returned_default_rows = [row]
+ result._soft_close()
+
+ # test that it has a cursor metadata that is accurate.
+ # the rows have all been fetched however.
+ assert result._metadata.returns_rows
+
+ elif not result._metadata.returns_rows:
+ # no results, get rowcount
+ # (which requires open cursor on some drivers
+ # such as kintersbasdb, mxodbc)
+ result.rowcount
+ result._soft_close()
+ return result
+
+ @util.memoized_property
+ def inserted_primary_key_rows(self):
+ # if no specific "get primary key" strategy was set up
+ # during execution, return a "default" primary key based
+ # on what's in the compiled_parameters and nothing else.
+ return self._setup_ins_pk_from_empty()
+
+ def _setup_ins_pk_from_lastrowid(self):
+ getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+
+ lastrowid = self.get_lastrowid()
+ return [getter(lastrowid, self.compiled_parameters[0])]
+
+ def _setup_ins_pk_from_empty(self):
+ getter = self.compiled._inserted_primary_key_from_lastrowid_getter
+ return [getter(None, param) for param in self.compiled_parameters]
+
+ def _setup_ins_pk_from_implicit_returning(self, result, rows):
+
+ if not rows:
+ return []
+
+ getter = self.compiled._inserted_primary_key_from_returning_getter
+ compiled_params = self.compiled_parameters
+
+ return [
+ getter(row, param) for row, param in zip(rows, compiled_params)
+ ]
+
+ def lastrow_has_defaults(self):
+ return (self.isinsert or self.isupdate) and bool(
+ self.compiled.postfetch
+ )
+
+ def _set_input_sizes(self):
+ """Given a cursor and ClauseParameters, call the appropriate
+ style of ``setinputsizes()`` on the cursor, using DB-API types
+ from the bind parameter's ``TypeEngine`` objects.
+
+ This method is only called by those dialects which require it,
+ currently cx_oracle, asyncpg and pg8000.
+
+ """
+ if self.isddl or self.is_text:
+ return
+
+ inputsizes = self.compiled._get_set_input_sizes_lookup(
+ include_types=self.include_set_input_sizes,
+ exclude_types=self.exclude_set_input_sizes,
+ )
+
+ if inputsizes is None:
+ return
+
+ if self.dialect._has_events:
+ inputsizes = dict(inputsizes)
+ self.dialect.dispatch.do_setinputsizes(
+ inputsizes, self.cursor, self.statement, self.parameters, self
+ )
+
+ has_escaped_names = bool(self.compiled.escaped_bind_names)
+ if has_escaped_names:
+ escaped_bind_names = self.compiled.escaped_bind_names
+
+ if self.dialect.positional:
+ items = [
+ (key, self.compiled.binds[key])
+ for key in self.compiled.positiontup
+ ]
+ else:
+ items = [
+ (key, bindparam)
+ for bindparam, key in self.compiled.bind_names.items()
+ ]
+
+ generic_inputsizes = []
+ for key, bindparam in items:
+ if bindparam in self.compiled.literal_execute_params:
+ continue
+
+ if key in self._expanded_parameters:
+ if bindparam.type._is_tuple_type:
+ num = len(bindparam.type.types)
+ dbtypes = inputsizes[bindparam]
+ generic_inputsizes.extend(
+ (
+ (
+ escaped_bind_names.get(paramname, paramname)
+ if has_escaped_names
+ else paramname
+ ),
+ dbtypes[idx % num],
+ bindparam.type.types[idx % num],
+ )
+ for idx, paramname in enumerate(
+ self._expanded_parameters[key]
+ )
+ )
+ else:
+ dbtype = inputsizes.get(bindparam, None)
+ generic_inputsizes.extend(
+ (
+ (
+ escaped_bind_names.get(paramname, paramname)
+ if has_escaped_names
+ else paramname
+ ),
+ dbtype,
+ bindparam.type,
+ )
+ for paramname in self._expanded_parameters[key]
+ )
+ else:
+ dbtype = inputsizes.get(bindparam, None)
+
+ escaped_name = (
+ escaped_bind_names.get(key, key)
+ if has_escaped_names
+ else key
+ )
+
+ generic_inputsizes.append(
+ (escaped_name, dbtype, bindparam.type)
+ )
+ try:
+ self.dialect.do_set_input_sizes(
+ self.cursor, generic_inputsizes, self
+ )
+ except BaseException as e:
+ self.root_connection._handle_dbapi_exception(
+ e, None, None, None, self
+ )
+
+ def _exec_default(self, column, default, type_):
+ if default.is_sequence:
+ return self.fire_sequence(default, type_)
+ elif default.is_callable:
+ self.current_column = column
+ return default.arg(self)
+ elif default.is_clause_element:
+ return self._exec_default_clause_element(column, default, type_)
+ else:
+ return default.arg
+
+ def _exec_default_clause_element(self, column, default, type_):
+ # execute a default that's a complete clause element. Here, we have
+ # to re-implement a miniature version of the compile->parameters->
+ # cursor.execute() sequence, since we don't want to modify the state
+ # of the connection / result in progress or create new connection/
+ # result objects etc.
+ # .. versionchanged:: 1.4
+
+ if not default._arg_is_typed:
+ default_arg = expression.type_coerce(default.arg, type_)
+ else:
+ default_arg = default.arg
+ compiled = expression.select(default_arg).compile(dialect=self.dialect)
+ compiled_params = compiled.construct_params()
+ processors = compiled._bind_processors
+ if compiled.positional:
+ positiontup = compiled.positiontup
+ parameters = self.dialect.execute_sequence_format(
+ [
+ processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key]
+ for key in positiontup
+ ]
+ )
+ else:
+ parameters = dict(
+ (
+ key,
+ processors[key](compiled_params[key])
+ if key in processors
+ else compiled_params[key],
+ )
+ for key in compiled_params
+ )
+ return self._execute_scalar(
+ util.text_type(compiled), type_, parameters=parameters
+ )
+
+ current_parameters = None
+ """A dictionary of parameters applied to the current row.
+
+ This attribute is only available in the context of a user-defined default
+ generation function, e.g. as described at :ref:`context_default_functions`.
+ It consists of a dictionary which includes entries for each column/value
+ pair that is to be part of the INSERT or UPDATE statement. The keys of the
+ dictionary will be the key value of each :class:`_schema.Column`,
+ which is usually
+ synonymous with the name.
+
+ Note that the :attr:`.DefaultExecutionContext.current_parameters` attribute
+ does not accommodate for the "multi-values" feature of the
+ :meth:`_expression.Insert.values` method. The
+ :meth:`.DefaultExecutionContext.get_current_parameters` method should be
+ preferred.
+
+ .. seealso::
+
+ :meth:`.DefaultExecutionContext.get_current_parameters`
+
+ :ref:`context_default_functions`
+
+ """
+
+ def get_current_parameters(self, isolate_multiinsert_groups=True):
+ """Return a dictionary of parameters applied to the current row.
+
+ This method can only be used in the context of a user-defined default
+ generation function, e.g. as described at
+ :ref:`context_default_functions`. When invoked, a dictionary is
+ returned which includes entries for each column/value pair that is part
+ of the INSERT or UPDATE statement. The keys of the dictionary will be
+ the key value of each :class:`_schema.Column`,
+ which is usually synonymous
+ with the name.
+
+ :param isolate_multiinsert_groups=True: indicates that multi-valued
+ INSERT constructs created using :meth:`_expression.Insert.values`
+ should be
+ handled by returning only the subset of parameters that are local
+ to the current column default invocation. When ``False``, the
+ raw parameters of the statement are returned including the
+ naming convention used in the case of multi-valued INSERT.
+
+ .. versionadded:: 1.2 added
+ :meth:`.DefaultExecutionContext.get_current_parameters`
+ which provides more functionality over the existing
+ :attr:`.DefaultExecutionContext.current_parameters`
+ attribute.
+
+ .. seealso::
+
+ :attr:`.DefaultExecutionContext.current_parameters`
+
+ :ref:`context_default_functions`
+
+ """
+ try:
+ parameters = self.current_parameters
+ column = self.current_column
+ except AttributeError:
+ raise exc.InvalidRequestError(
+ "get_current_parameters() can only be invoked in the "
+ "context of a Python side column default function"
+ )
+
+ compile_state = self.compiled.compile_state
+ if (
+ isolate_multiinsert_groups
+ and self.isinsert
+ and compile_state._has_multi_parameters
+ ):
+ if column._is_multiparam_column:
+ index = column.index + 1
+ d = {column.original.key: parameters[column.key]}
+ else:
+ d = {column.key: parameters[column.key]}
+ index = 0
+ keys = compile_state._dict_parameters.keys()
+ d.update(
+ (key, parameters["%s_m%d" % (key, index)]) for key in keys
+ )
+ return d
+ else:
+ return parameters
+
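+ # Illustrative sketch, not part of this module: a Python-side column
+ # default consuming the context, as described in the docstring above
+ # (the column names are hypothetical)::
+ #
+ #     def timestamp_default(context):
+ #         params = context.get_current_parameters()
+ #         # derive the value from another column in the same row
+ #         return params["created_at"]
+ #
+ #     Column("updated_at", DateTime, default=timestamp_default)
+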
+ def get_insert_default(self, column):
+ if column.default is None:
+ return None
+ else:
+ return self._exec_default(column, column.default, column.type)
+
+ def get_update_default(self, column):
+ if column.onupdate is None:
+ return None
+ else:
+ return self._exec_default(column, column.onupdate, column.type)
+
+ def _process_executemany_defaults(self):
+ key_getter = self.compiled._within_exec_param_key_getter
+
+ scalar_defaults = {}
+
+ insert_prefetch = self.compiled.insert_prefetch
+ update_prefetch = self.compiled.update_prefetch
+
+ # pre-determine scalar Python-side defaults
+ # to avoid many calls of get_insert_default()/
+ # get_update_default()
+ for c in insert_prefetch:
+ if c.default and not c.default.is_sequence and c.default.is_scalar:
+ scalar_defaults[c] = c.default.arg
+
+ for c in update_prefetch:
+ if c.onupdate and c.onupdate.is_scalar:
+ scalar_defaults[c] = c.onupdate.arg
+
+ for param in self.compiled_parameters:
+ self.current_parameters = param
+ for c in insert_prefetch:
+ if c in scalar_defaults:
+ val = scalar_defaults[c]
+ else:
+ val = self.get_insert_default(c)
+ if val is not None:
+ param[key_getter(c)] = val
+ for c in update_prefetch:
+ if c in scalar_defaults:
+ val = scalar_defaults[c]
+ else:
+ val = self.get_update_default(c)
+ if val is not None:
+ param[key_getter(c)] = val
+
+ del self.current_parameters
+
+ def _process_executesingle_defaults(self):
+ key_getter = self.compiled._within_exec_param_key_getter
+ self.current_parameters = (
+ compiled_parameters
+ ) = self.compiled_parameters[0]
+
+ for c in self.compiled.insert_prefetch:
+ if c.default and not c.default.is_sequence and c.default.is_scalar:
+ val = c.default.arg
+ else:
+ val = self.get_insert_default(c)
+
+ if val is not None:
+ compiled_parameters[key_getter(c)] = val
+
+ for c in self.compiled.update_prefetch:
+ val = self.get_update_default(c)
+
+ if val is not None:
+ compiled_parameters[key_getter(c)] = val
+ del self.current_parameters
+
+
+DefaultDialect.execution_ctx_cls = DefaultExecutionContext
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
new file mode 100644
index 0000000..286c4d4
--- /dev/null
+++ b/lib/sqlalchemy/engine/events.py
@@ -0,0 +1,835 @@
+# sqlalchemy/engine/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from .base import Engine
+from .interfaces import Connectable
+from .interfaces import Dialect
+from .. import event
+from .. import exc
+
+
+class ConnectionEvents(event.Events):
+ """Available events for :class:`.Connectable`, which includes
+ :class:`_engine.Connection` and :class:`_engine.Engine`.
+
+ The methods here define the name of an event as well as the names of
+ members that are passed to listener functions.
+
+ An event listener can be associated with any :class:`.Connectable`
+ class or instance, such as an :class:`_engine.Engine`, e.g.::
+
+ from sqlalchemy import event, create_engine
+
+ def before_cursor_execute(conn, cursor, statement, parameters, context,
+ executemany):
+ log.info("Received statement: %s", statement)
+
+ engine = create_engine('postgresql://scott:tiger@localhost/test')
+ event.listen(engine, "before_cursor_execute", before_cursor_execute)
+
+ or with a specific :class:`_engine.Connection`::
+
+ with engine.begin() as conn:
+ @event.listens_for(conn, 'before_cursor_execute')
+ def before_cursor_execute(conn, cursor, statement, parameters,
+ context, executemany):
+ log.info("Received statement: %s", statement)
+
+ When the methods are called with a `statement` parameter, such as in
+ :meth:`.after_cursor_execute` or :meth:`.before_cursor_execute`,
+ the statement is the exact SQL string that was prepared for transmission
+ to the DBAPI ``cursor`` in the connection's :class:`.Dialect`.
+
+ The :meth:`.before_execute` and :meth:`.before_cursor_execute`
+ events can also be established with the ``retval=True`` flag, which
+ allows modification of the statement and parameters to be sent
+ to the database. The :meth:`.before_cursor_execute` event is
+ particularly useful here to add ad-hoc string transformations, such
+ as comments, to all executions::
+
+ from sqlalchemy.engine import Engine
+ from sqlalchemy import event
+
+ @event.listens_for(Engine, "before_cursor_execute", retval=True)
+ def comment_sql_calls(conn, cursor, statement, parameters,
+ context, executemany):
+ statement = statement + " -- some comment"
+ return statement, parameters
+
+ .. note:: :class:`_events.ConnectionEvents` can be established on any
+ combination of :class:`_engine.Engine`, :class:`_engine.Connection`,
+ as well
+ as instances of each of those classes. Events across all
+ four scopes will fire off for a given instance of
+ :class:`_engine.Connection`. However, for performance reasons, the
+ :class:`_engine.Connection` object determines at instantiation time
+ whether or not its parent :class:`_engine.Engine` has event listeners
+ established. Event listeners added to the :class:`_engine.Engine`
+ class or to an instance of :class:`_engine.Engine`
+ *after* the instantiation
+ of a dependent :class:`_engine.Connection` instance will usually
+ *not* be available on that :class:`_engine.Connection` instance.
+ The newly
+ added listeners will instead take effect for
+ :class:`_engine.Connection`
+ instances created subsequent to those event listeners being
+ established on the parent :class:`_engine.Engine` class or instance.
+
+ :param retval=False: Applies to the :meth:`.before_execute` and
+ :meth:`.before_cursor_execute` events only. When True, the
+ user-defined event function must have a return value, which
+ is a tuple of parameters that replace the given statement
+ and parameters. See those methods for a description of
+ specific return arguments.
+
+ """
+
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = Connectable
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
+
+ target._has_events = True
+
+ if not retval:
+ if identifier == "before_execute":
+ orig_fn = fn
+
+ def wrap_before_execute(
+ conn, clauseelement, multiparams, params, execution_options
+ ):
+ orig_fn(
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ execution_options,
+ )
+ return clauseelement, multiparams, params
+
+ fn = wrap_before_execute
+ elif identifier == "before_cursor_execute":
+ orig_fn = fn
+
+ def wrap_before_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ orig_fn(
+ conn,
+ cursor,
+ statement,
+ parameters,
+ context,
+ executemany,
+ )
+ return statement, parameters
+
+ fn = wrap_before_cursor_execute
+ elif retval and identifier not in (
+ "before_execute",
+ "before_cursor_execute",
+ "handle_error",
+ ):
+ raise exc.ArgumentError(
+ "Only the 'before_execute', "
+ "'before_cursor_execute' and 'handle_error' engine "
+ "event listeners accept the 'retval=True' "
+ "argument."
+ )
+ event_key.with_wrapper(fn).base_listen()
+
+ @event._legacy_signature(
+ "1.4",
+ ["conn", "clauseelement", "multiparams", "params"],
+ lambda conn, clauseelement, multiparams, params, execution_options: (
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ ),
+ )
+ def before_execute(
+ self, conn, clauseelement, multiparams, params, execution_options
+ ):
+ """Intercept high level execute() events, receiving uncompiled
+ SQL constructs and other objects prior to rendering into SQL.
+
+ This event is good for debugging SQL compilation issues as well
+ as early manipulation of the parameters being sent to the database,
+ as the parameter lists will be in a consistent format here.
+
+ This event can be optionally established with the ``retval=True``
+ flag. The ``clauseelement``, ``multiparams``, and ``params``
+ arguments should be returned as a three-tuple in this case::
+
+ @event.listens_for(Engine, "before_execute", retval=True)
+ def before_execute(conn, clauseelement, multiparams, params):
+ # do something with clauseelement, multiparams, params
+ return clauseelement, multiparams, params
+
+ :param conn: :class:`_engine.Connection` object
+ :param clauseelement: SQL expression construct, :class:`.Compiled`
+ instance, or string statement passed to
+ :meth:`_engine.Connection.execute`.
+ :param multiparams: Multiple parameter sets, a list of dictionaries.
+ :param params: Single parameter set, a single dictionary.
+ :param execution_options: dictionary of execution
+ options passed along with the statement, if any. This is a merge
+ of all options that will be used, including those of the statement,
+ the connection, and those passed in to the method itself for
+ the 2.0 style of execution.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`.before_cursor_execute`
+
+ """
+
+ @event._legacy_signature(
+ "1.4",
+ ["conn", "clauseelement", "multiparams", "params", "result"],
+ lambda conn, clauseelement, multiparams, params, execution_options, result: ( # noqa
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ result,
+ ),
+ )
+ def after_execute(
+ self,
+ conn,
+ clauseelement,
+ multiparams,
+ params,
+ execution_options,
+ result,
+ ):
+ """Intercept high level execute() events after execute.
+
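+ For example, a listener established without ``retval`` might simply log
+ each completed statement; the following is an illustrative sketch only,
+ assuming a pre-configured ``log`` object::
+
+ @event.listens_for(Engine, "after_execute")
+ def receive_after_execute(
+ conn, clauseelement, multiparams, params, execution_options, result
+ ):
+ log.info("execution complete: %r", clauseelement)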
+
+ :param conn: :class:`_engine.Connection` object
+ :param clauseelement: SQL expression construct, :class:`.Compiled`
+ instance, or string statement passed to
+ :meth:`_engine.Connection.execute`.
+ :param multiparams: Multiple parameter sets, a list of dictionaries.
+ :param params: Single parameter set, a single dictionary.
+ :param execution_options: dictionary of execution
+ options passed along with the statement, if any. This is a merge
+ of all options that will be used, including those of the statement,
+ the connection, and those passed in to the method itself for
+ the 2.0 style of execution.
+
+ .. versionadded:: 1.4
+
+ :param result: :class:`_engine.CursorResult` generated by the
+ execution.
+
+ """
+
+ def before_cursor_execute(
+ self, conn, cursor, statement, parameters, context, executemany
+ ):
+ """Intercept low-level cursor execute() events before execution,
+ receiving the string SQL statement and DBAPI-specific parameter list to
+ be invoked against a cursor.
+
+ This event is a good choice for logging as well as late modifications
+ to the SQL string. It's less ideal for parameter modifications except
+ for those which are specific to a target backend.
+
+ This event can be optionally established with the ``retval=True``
+ flag. The ``statement`` and ``parameters`` arguments should be
+ returned as a two-tuple in this case::
+
+ @event.listens_for(Engine, "before_cursor_execute", retval=True)
+ def before_cursor_execute(conn, cursor, statement,
+ parameters, context, executemany):
+ # do something with statement, parameters
+ return statement, parameters
+
+ See the example at :class:`_events.ConnectionEvents`.
+
+ :param conn: :class:`_engine.Connection` object
+ :param cursor: DBAPI cursor object
+ :param statement: string SQL statement, as to be passed to the DBAPI
+ :param parameters: Dictionary, tuple, or list of parameters being
+ passed to the ``execute()`` or ``executemany()`` method of the
+ DBAPI ``cursor``. In some cases may be ``None``.
+ :param context: :class:`.ExecutionContext` object in use. May
+ be ``None``.
+ :param executemany: boolean, if ``True``, this is an ``executemany()``
+ call, if ``False``, this is an ``execute()`` call.
+
+ .. seealso::
+
+ :meth:`.before_execute`
+
+ :meth:`.after_cursor_execute`
+
+ """
+
+ def after_cursor_execute(
+ self, conn, cursor, statement, parameters, context, executemany
+ ):
+ """Intercept low-level cursor execute() events after execution.
+
+ :param conn: :class:`_engine.Connection` object
+ :param cursor: DBAPI cursor object. Will have results pending
+ if the statement was a SELECT, but these should not be consumed
+ as they will be needed by the :class:`_engine.CursorResult`.
+ :param statement: string SQL statement, as passed to the DBAPI
+ :param parameters: Dictionary, tuple, or list of parameters being
+ passed to the ``execute()`` or ``executemany()`` method of the
+ DBAPI ``cursor``. In some cases may be ``None``.
+ :param context: :class:`.ExecutionContext` object in use. May
+ be ``None``.
+ :param executemany: boolean, if ``True``, this is an ``executemany()``
+ call, if ``False``, this is an ``execute()`` call.
+
+ """
+
+ def handle_error(self, exception_context):
+ r"""Intercept all exceptions processed by the
+ :class:`_engine.Connection`.
+
+ This includes all exceptions emitted by the DBAPI as well as
+ within SQLAlchemy's statement invocation process, including
+ encoding errors and other statement validation errors. Other areas
+ in which the event is invoked include transaction begin and end,
+ result row fetching, cursor creation.
+
+ Note that :meth:`.handle_error` may support new kinds of exceptions
+ and new calling scenarios at *any time*. Code which uses this
+ event must expect new calling patterns to be present in minor
+ releases.
+
+ To support the wide variety of members that correspond to an exception,
+ as well as to allow extensibility of the event without backwards
+ incompatibility, the sole argument received is an instance of
+ :class:`.ExceptionContext`. This object contains data members
+ representing detail about the exception.
+
+ Use cases supported by this hook include:
+
+ * read-only, low-level exception handling for logging and
+ debugging purposes
+ * exception re-writing
+ * Establishing or disabling whether a connection or the owning
+ connection pool is invalidated or expired in response to a
+ specific exception [1]_.
+
+ The hook is called while the cursor from the failed operation
+ (if any) is still open and accessible. Special cleanup operations
+ can be called on this cursor; SQLAlchemy will attempt to close
+ this cursor subsequent to this hook being invoked. If the connection
+ is in "autocommit" mode, the transaction also remains open within
+ the scope of this hook; the rollback of the per-statement transaction
+ also occurs after the hook is called.
+
+ .. note::
+
+ .. [1] The pool "pre_ping" handler enabled using the
+ :paramref:`_sa.create_engine.pool_pre_ping` parameter does
+ **not** consult this event before deciding if the "ping"
+ returned false, as opposed to receiving an unhandled error.
+ For this use case, the :ref:`legacy recipe based on
+ engine_connect() may be used
+ <pool_disconnects_pessimistic_custom>`. A future API will allow
+ more comprehensive customization of the "disconnect"
+ detection mechanism across all functions.
+
+ A handler function has two options for replacing
+ the SQLAlchemy-constructed exception with one that is user
+ defined. It can either raise this new exception directly, in
+ which case all further event listeners are bypassed and the
+ exception will be raised, after appropriate cleanup has taken
+ place::
+
+ @event.listens_for(Engine, "handle_error")
+ def handle_exception(context):
+ if isinstance(context.original_exception,
+ psycopg2.OperationalError) and \
+ "failed" in str(context.original_exception):
+ raise MySpecialException("failed operation")
+
+ .. warning:: Because the
+ :meth:`_events.ConnectionEvents.handle_error`
+ event specifically provides for exceptions to be re-thrown as
+ the ultimate exception raised by the failed statement,
+ **stack traces will be misleading** if the user-defined event
+ handler itself fails and throws an unexpected exception;
+ the stack trace may not illustrate the actual code line that
+ failed! It is advised to code carefully here and use
+ logging and/or inline debugging if unexpected exceptions are
+ occurring.
+
+ Alternatively, a "chained" style of event handling can be
+ used, by configuring the handler with the ``retval=True``
+ modifier and returning the new exception instance from the
+ function. In this case, event handling will continue onto the
+ next handler. The "chained" exception is available using
+ :attr:`.ExceptionContext.chained_exception`::
+
+ @event.listens_for(Engine, "handle_error", retval=True)
+ def handle_exception(context):
+ if context.chained_exception is not None and \
+ "special" in context.chained_exception.message:
+ return MySpecialException("failed",
+ cause=context.chained_exception)
+
+ Handlers that return ``None`` may be used within the chain; when
+ a handler returns ``None``, the previous exception instance,
+ if any, is maintained as the current exception that is passed onto the
+ next handler.
+
+ When a custom exception is raised or returned, SQLAlchemy raises
+ this new exception as-is; it is not wrapped by any SQLAlchemy
+ object. If the exception is not a subclass of
+ :class:`sqlalchemy.exc.StatementError`,
+ certain features may not be available; currently this includes
+ the ORM's feature of adding a detail hint about "autoflush" to
+ exceptions raised within the autoflush process.
+
+ :param context: an :class:`.ExceptionContext` object. See this
+ class for details on all available members.
+
+ .. versionadded:: 0.9.7 Added the
+ :meth:`_events.ConnectionEvents.handle_error` hook.
+
+ .. versionchanged:: 1.1 The :meth:`.handle_error` event will now
+ receive all exceptions that inherit from ``BaseException``,
+ including ``SystemExit`` and ``KeyboardInterrupt``. The setting for
+ :attr:`.ExceptionContext.is_disconnect` is ``True`` in this case and
+ the default for
+ :attr:`.ExceptionContext.invalidate_pool_on_disconnect` is
+ ``False``.
+
+ .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is now
+ invoked when an :class:`_engine.Engine` fails during the initial
+ call to :meth:`_engine.Engine.connect`, as well as when a
+ :class:`_engine.Connection` object encounters an error during a
+ reconnect operation.
+
+ .. versionchanged:: 1.0.0 The :meth:`.handle_error` event is
+ not fired off when a dialect makes use of the
+ ``skip_user_error_events`` execution option. This is used
+ by dialects which intend to catch SQLAlchemy-specific exceptions
+ within specific operations, such as when the MySQL dialect detects
+ a table not present within the ``has_table()`` dialect method.
+ Prior to 1.0.0, code which implements :meth:`.handle_error` needs
+ to ensure that exceptions thrown in these scenarios are re-raised
+ without modification.
+
+ """
+
+ def engine_connect(self, conn, branch):
+ """Intercept the creation of a new :class:`_engine.Connection`.
+
+ This event is typically called as the direct result of calling
+ the :meth:`_engine.Engine.connect` method.
+
+ It differs from the :meth:`_events.PoolEvents.connect` method, which
+ refers to the actual connection to a database at the DBAPI level;
+ a DBAPI connection may be pooled and reused for many operations.
+ In contrast, this event refers only to the production of a higher level
+ :class:`_engine.Connection` wrapper around such a DBAPI connection.
+
+ It also differs from the :meth:`_events.PoolEvents.checkout` event
+ in that it is specific to the :class:`_engine.Connection` object,
+ not the
+ DBAPI connection that :meth:`_events.PoolEvents.checkout` deals with,
+ although
+ this DBAPI connection is available here via the
+ :attr:`_engine.Connection.connection` attribute.
+ But note there can in fact
+ be multiple :meth:`_events.PoolEvents.checkout`
+ events within the lifespan
+ of a single :class:`_engine.Connection` object, if that
+ :class:`_engine.Connection`
+ is invalidated and re-established. There can also be multiple
+ :class:`_engine.Connection`
+ objects generated for the same already-checked-out
+ DBAPI connection, in the case that a "branch" of a
+ :class:`_engine.Connection`
+ is produced.
+
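+ A listener might, for example, distinguish branched connections from
+ newly procured ones; the following is an illustrative sketch only::
+
+ @event.listens_for(Engine, "engine_connect")
+ def receive_engine_connect(conn, branch):
+ if branch:
+ # a "branch" shares the DBAPI connection of its parent
+ return
+ # the underlying DBAPI connection is available here
+ dbapi_conn = conn.connection
+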
+ :param conn: :class:`_engine.Connection` object.
+ :param branch: if True, this is a "branch" of an existing
+ :class:`_engine.Connection`. A branch is generated within the course
+ of a statement execution to invoke supplemental statements, most
+ typically to pre-execute a SELECT of a default value for the purposes
+ of an INSERT statement.
+
+ .. seealso::
+
+ :meth:`_events.PoolEvents.checkout`
+ the lower-level pool checkout event
+ for an individual DBAPI connection
+
+ """
+
+ def set_connection_execution_options(self, conn, opts):
+ """Intercept when the :meth:`_engine.Connection.execution_options`
+ method is called.
+
+ This method is called after the new :class:`_engine.Connection`
+ has been
+ produced, with the newly updated execution options collection, but
+ before the :class:`.Dialect` has acted upon any of those new options.
+
+ Note that this method is not called when a new
+ :class:`_engine.Connection`
+ is produced which is inheriting execution options from its parent
+ :class:`_engine.Engine`; to intercept this condition, use the
+ :meth:`_events.ConnectionEvents.engine_connect` event.
+
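+ For example, a listener might watch for a particular option being set;
+ an illustrative sketch only, assuming a pre-configured ``log`` object::
+
+ @event.listens_for(Engine, "set_connection_execution_options")
+ def receive_options(conn, opts):
+ if "isolation_level" in opts:
+ log.info(
+ "isolation level set to %s", opts["isolation_level"]
+ )
+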
+ :param conn: The newly copied :class:`_engine.Connection` object
+
+ :param opts: dictionary of options that were passed to the
+ :meth:`_engine.Connection.execution_options` method.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.set_engine_execution_options`
+ - event
+ which is called when :meth:`_engine.Engine.execution_options`
+ is called.
+
+
+ """
+
+ def set_engine_execution_options(self, engine, opts):
+ """Intercept when the :meth:`_engine.Engine.execution_options`
+ method is called.
+
+ The :meth:`_engine.Engine.execution_options` method produces a shallow
+ copy of the :class:`_engine.Engine` which stores the new options.
+ That new
+ :class:`_engine.Engine` is passed here.
+ A particular application of this
+ method is to add a :meth:`_events.ConnectionEvents.engine_connect`
+ event
+ handler to the given :class:`_engine.Engine`
+ which will perform some per-
+ :class:`_engine.Connection` task specific to these execution options.
+
+ :param engine: The newly copied :class:`_engine.Engine` object
+
+ :param opts: dictionary of options that were passed to the
+ :meth:`_engine.Engine.execution_options` method.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.set_connection_execution_options`
+ - event
+ which is called when :meth:`_engine.Connection.execution_options`
+ is
+ called.
+
+ """
+
+ def engine_disposed(self, engine):
+ """Intercept when the :meth:`_engine.Engine.dispose` method is called.
+
+ The :meth:`_engine.Engine.dispose` method instructs the engine to
+ "dispose" of it's connection pool (e.g. :class:`_pool.Pool`), and
+ replaces it with a new one. Disposing of the old pool has the
+ effect that existing checked-in connections are closed. The new
+ pool does not establish any new connections until it is first used.
+
+ This event can be used to indicate that resources related to the
+ :class:`_engine.Engine` should also be cleaned up,
+ keeping in mind that the
+ :class:`_engine.Engine`
+ can still be used for new requests in which case
+ it re-acquires connection resources.
+
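+ For example, an application that tracks resources per engine in a
+ hypothetical ``engine_registry`` dictionary might drop the entry here
+ (sketch only)::
+
+ @event.listens_for(Engine, "engine_disposed")
+ def receive_engine_disposed(engine):
+ engine_registry.pop(engine, None)
+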
+ .. versionadded:: 1.0.5
+
+ """
+
+ def begin(self, conn):
+ """Intercept begin() events.
+
+ :param conn: :class:`_engine.Connection` object
+
+ """
+
+ def rollback(self, conn):
+ """Intercept rollback() events, as initiated by a
+ :class:`.Transaction`.
+
+ Note that the :class:`_pool.Pool` also "auto-rolls back"
+ a DBAPI connection upon checkin, if the ``reset_on_return``
+ flag is set to its default value of ``'rollback'``.
+ To intercept this
+ rollback, use the :meth:`_events.PoolEvents.reset` hook.
+
+ :param conn: :class:`_engine.Connection` object
+
+ .. seealso::
+
+ :meth:`_events.PoolEvents.reset`
+
+ """
+
+ def commit(self, conn):
+ """Intercept commit() events, as initiated by a
+ :class:`.Transaction`.
+
+ Note that the :class:`_pool.Pool` may also "auto-commit"
+ a DBAPI connection upon checkin, if the ``reset_on_return``
+ flag is set to the value ``'commit'``. To intercept this
+ commit, use the :meth:`_events.PoolEvents.reset` hook.
+
+ :param conn: :class:`_engine.Connection` object
+ """
+
+ def savepoint(self, conn, name):
+ """Intercept savepoint() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param name: specified name used for the savepoint.
+
+ """
+
+ def rollback_savepoint(self, conn, name, context):
+ """Intercept rollback_savepoint() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param name: specified name used for the savepoint.
+ :param context: not used
+
+ """
+ # TODO: deprecate "context"
+
+ def release_savepoint(self, conn, name, context):
+ """Intercept release_savepoint() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param name: specified name used for the savepoint.
+ :param context: not used
+
+ """
+ # TODO: deprecate "context"
+
+ def begin_twophase(self, conn, xid):
+ """Intercept begin_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+
+ """
+
+ def prepare_twophase(self, conn, xid):
+ """Intercept prepare_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+ """
+
+ def rollback_twophase(self, conn, xid, is_prepared):
+ """Intercept rollback_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+ :param is_prepared: boolean, indicates if
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+
+ """
+
+ def commit_twophase(self, conn, xid, is_prepared):
+ """Intercept commit_twophase() events.
+
+ :param conn: :class:`_engine.Connection` object
+ :param xid: two-phase XID identifier
+ :param is_prepared: boolean, indicates if
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+
+ """
+
+
+class DialectEvents(event.Events):
+ """event interface for execution-replacement functions.
+
+ These events allow direct instrumentation and replacement
+ of key dialect functions which interact with the DBAPI.
+
+ .. note::
+
+ :class:`.DialectEvents` hooks should be considered **semi-public**
+ and experimental.
+ These hooks are not for general use and are only for those situations
+ where intricate re-statement of DBAPI mechanics must be injected onto
+ an existing dialect. For general-use statement-interception events,
+ please use the :class:`_events.ConnectionEvents` interface.
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.before_cursor_execute`
+
+ :meth:`_events.ConnectionEvents.before_execute`
+
+ :meth:`_events.ConnectionEvents.after_cursor_execute`
+
+ :meth:`_events.ConnectionEvents.after_execute`
+
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = Dialect
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ target = event_key.dispatch_target
+
+ target._has_events = True
+ event_key.base_listen()
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, type):
+ if issubclass(target, Engine):
+ return Dialect
+ elif issubclass(target, Dialect):
+ return target
+ elif isinstance(target, Engine):
+ return target.dialect
+ elif isinstance(target, Dialect):
+ return target
+ elif hasattr(target, "dispatch") and hasattr(
+ target.dispatch._events, "_no_async_engine_events"
+ ):
+ target.dispatch._events._no_async_engine_events()
+ else:
+ return None
+
+ def do_connect(self, dialect, conn_rec, cargs, cparams):
+ """Receive connection arguments before a connection is made.
+
+ This event is useful in that it allows the handler to manipulate the
+ cargs and/or cparams collections that control how the DBAPI
+ ``connect()`` function will be called. ``cargs`` will always be a
+ Python list that can be mutated in-place, and ``cparams`` a Python
+ dictionary that may also be mutated::
+
+ e = create_engine("postgresql+psycopg2://user@host/dbname")
+
+ @event.listens_for(e, 'do_connect')
+ def receive_do_connect(dialect, conn_rec, cargs, cparams):
+ cparams["password"] = "some_password"
+
+ The event hook may also be used to override the call to ``connect()``
+ entirely, by returning a non-``None`` DBAPI connection object::
+
+ e = create_engine("postgresql+psycopg2://user@host/dbname")
+
+ @event.listens_for(e, 'do_connect')
+ def receive_do_connect(dialect, conn_rec, cargs, cparams):
+ return psycopg2.connect(*cargs, **cparams)
+
+
+ .. versionadded:: 1.0.3
+
+ .. seealso::
+
+ :ref:`custom_dbapi_args`
+
+ """
+
+ def do_executemany(self, cursor, statement, parameters, context):
+ """Receive a cursor to have executemany() called.
+
+ Return the value True to halt further events from invoking,
+ and to indicate that the cursor execution has already taken
+ place within the event handler.
+
+ """
+
+ def do_execute_no_params(self, cursor, statement, context):
+ """Receive a cursor to have execute() with no parameters called.
+
+ Return the value True to halt further events from invoking,
+ and to indicate that the cursor execution has already taken
+ place within the event handler.
+
+ """
+
+ def do_execute(self, cursor, statement, parameters, context):
+ """Receive a cursor to have execute() called.
+
+ Return the value True to halt further events from invoking,
+ and to indicate that the cursor execution has already taken
+ place within the event handler.
+
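+ A listener that takes over the execution entirely must itself invoke
+ the DBAPI cursor; a minimal sketch, assuming ``some_engine`` is an
+ existing :class:`_engine.Engine`::
+
+ @event.listens_for(some_engine, "do_execute")
+ def receive_do_execute(cursor, statement, parameters, context):
+ cursor.execute(statement, parameters)
+ return True
+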
+ """
+
+ def do_setinputsizes(
+ self, inputsizes, cursor, statement, parameters, context
+ ):
+ """Receive the setinputsizes dictionary for possible modification.
+
+ This event is emitted in the case where the dialect makes use of the
+ DBAPI ``cursor.setinputsizes()`` method which passes information about
+ parameter binding for a particular statement. The given
+ ``inputsizes`` dictionary will contain :class:`.BindParameter` objects
+ as keys, linked to DBAPI-specific type objects as values; for
+ parameters that are not bound, they are added to the dictionary with
+ ``None`` as the value, which means the parameter will not be included
+ in the ultimate setinputsizes call. The event may be used to inspect
+ and/or log the datatypes that are being bound, as well as to modify the
+ dictionary in place. Parameters can be added, modified, or removed
+ from this dictionary. Callers will typically want to inspect the
+ :attr:`.BindParameter.type` attribute of the given bind objects in
+ order to make decisions about the DBAPI object.
+
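+ For example, a listener might log the datatypes that will be passed to
+ ``setinputsizes()``; an illustrative sketch, assuming ``some_engine``
+ and a pre-configured ``log`` object::
+
+ @event.listens_for(some_engine, "do_setinputsizes")
+ def receive_do_setinputsizes(
+ inputsizes, cursor, statement, parameters, context
+ ):
+ for bindparam, dbapitype in inputsizes.items():
+ log.info(
+ "bound parameter %s of type %s uses DBAPI type %s",
+ bindparam.key, bindparam.type, dbapitype,
+ )
+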
+ After the event, the ``inputsizes`` dictionary is converted into
+ an appropriate datastructure to be passed to ``cursor.setinputsizes``;
+ either a list for a positional bound parameter execution style,
+ or a dictionary of string parameter keys to DBAPI type objects for
+ a named bound parameter execution style.
+
+ The setinputsizes hook overall is only used for dialects which include
+ the flag ``use_setinputsizes=True``. Dialects which use this
+ include cx_Oracle, pg8000, asyncpg, and pyodbc dialects.
+
+ .. note::
+
+ For use with pyodbc, the ``use_setinputsizes`` flag
+ must be passed to the dialect, e.g.::
+
+ create_engine("mssql+pyodbc://...", use_setinputsizes=True)
+
+ .. seealso::
+
+ :ref:`mssql_pyodbc_setinputsizes`
+
+ .. versionadded:: 1.2.9
+
+ .. seealso::
+
+ :ref:`cx_oracle_setinputsizes`
+
+ """
+ pass
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
new file mode 100644
index 0000000..4f2524a
--- /dev/null
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -0,0 +1,1719 @@
+# engine/interfaces.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define core interfaces used by the engine system."""
+
+from .. import util
+from ..sql.compiler import Compiled # noqa
+from ..sql.compiler import TypeCompiler # noqa
+from ..util.concurrency import await_only
+
+
+class Dialect(object):
+ """Define the behavior of a specific database and DB-API combination.
+
+ Any aspect of metadata definition, SQL query generation,
+ execution, result-set handling, or anything else which varies
+ between databases is defined under the general category of the
+ Dialect. The Dialect acts as a factory for other
+ database-specific object implementations including
+ ExecutionContext, Compiled, DefaultGenerator, and TypeEngine.
+
+ .. note:: Third party dialects should not subclass :class:`.Dialect`
+ directly. Instead, subclass :class:`.default.DefaultDialect` or
+ descendant class.
+
+ All dialects include the following attributes. There are many other
+ attributes that may be supported as well:
+
+ ``name``
+ identifying name for the dialect from a DBAPI-neutral point of view
+ (i.e. 'sqlite')
+
+ ``driver``
+ identifying name for the dialect's DBAPI
+
+ ``positional``
+ True if the paramstyle for this Dialect is positional.
+
+ ``paramstyle``
+ the paramstyle to be used (some DB-APIs support multiple
+ paramstyles).
+
+ ``encoding``
+ type of encoding to use for unicode, usually defaults to
+ 'utf-8'.
+
+ ``statement_compiler``
+ a :class:`.Compiled` class used to compile SQL statements
+
+ ``ddl_compiler``
+ a :class:`.Compiled` class used to compile DDL statements
+
+ ``server_version_info``
+ a tuple containing a version number for the DB backend in use.
+ This value is only available for supporting dialects, and is
+ typically populated during the initial connection to the database.
+
+ ``default_schema_name``
+ the name of the default schema. This value is only available for
+ supporting dialects, and is typically populated during the
+ initial connection to the database.
+
+ ``execution_ctx_cls``
+ a :class:`.ExecutionContext` class used to handle statement execution
+
+ ``execute_sequence_format``
+ either the 'tuple' or 'list' type, depending on what cursor.execute()
+ accepts for the second argument (they vary).
+
+ ``preparer``
+ a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to
+ quote identifiers.
+
+ ``supports_alter``
+ ``True`` if the database supports ``ALTER TABLE`` - used only for
+ generating foreign key constraints in certain circumstances
+
+ ``max_identifier_length``
+ The maximum length of identifier names.
+
+ ``supports_sane_rowcount``
+ Indicate whether the dialect properly implements rowcount for
+ ``UPDATE`` and ``DELETE`` statements.
+
+ ``supports_sane_multi_rowcount``
+ Indicate whether the dialect properly implements rowcount for
+ ``UPDATE`` and ``DELETE`` statements when executed via
+ executemany.
+
+ ``preexecute_autoincrement_sequences``
+ True if 'implicit' primary key functions must be executed separately
+ in order to get their value. This is currently oriented towards
+ PostgreSQL.
+
+ ``implicit_returning``
+ use RETURNING or equivalent during INSERT execution in order to load
+ newly generated primary keys and other column defaults in one execution,
+ which are then available via inserted_primary_key.
+ If an insert statement has returning() specified explicitly,
+ the "implicit" functionality is not used and inserted_primary_key
+ will not be available.
+
+ ``colspecs``
+ A dictionary of TypeEngine classes from sqlalchemy.types mapped
+ to subclasses that are specific to the dialect class. This
+ dictionary is class-level only and is not accessed from the
+ dialect instance itself.
+
+ ``supports_default_values``
+ Indicates if the construct ``INSERT INTO tablename DEFAULT
+ VALUES`` is supported
+
+ ``supports_sequences``
+ Indicates if the dialect supports CREATE SEQUENCE or similar.
+
+ ``sequences_optional``
+ If True, indicates if the "optional" flag on the Sequence() construct
+ should signal to not generate a CREATE SEQUENCE. Applies only to
+ dialects that support sequences. Currently used only to allow PostgreSQL
+ SERIAL to be used on a column that specifies Sequence() for usage on
+ other backends.
+
+ ``supports_native_enum``
+ Indicates if the dialect supports a native ENUM construct.
+ This will prevent types.Enum from generating a CHECK
+ constraint when that type is used.
+
+ ``supports_native_boolean``
+ Indicates if the dialect supports a native boolean construct.
+ This will prevent types.Boolean from generating a CHECK
+ constraint when that type is used.
+
+ ``dbapi_exception_translation_map``
+ A dictionary of names that will contain as values the names of
+ pep-249 exceptions ("IntegrityError", "OperationalError", etc)
+ keyed to alternate class names, to support the case where a
+ DBAPI has exception classes that aren't named as they are
+ referred to (e.g. IntegrityError = MyException). In the vast
+ majority of cases this dictionary is empty.
+
+ .. versionadded:: 1.0.5
+
+ """
+
+ _has_events = False
+
+ supports_statement_cache = True
+ """indicates if this dialect supports caching.
+
+ All dialects that are compatible with statement caching should set this
+ flag to True directly on each dialect class and subclass that supports
+ it. SQLAlchemy tests that this flag is locally present on each dialect
+ subclass before it will use statement caching. This is to provide
+ safety for legacy or new dialects that are not yet fully tested to be
+ compliant with SQL statement caching.
+
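+ A third-party dialect opts in by setting the flag directly on its own
+ class; for example (sketch only)::
+
+ class MyDialect(DefaultDialect):
+ supports_statement_cache = True
+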
+ .. versionadded:: 1.4.5
+
+ .. seealso::
+
+ :ref:`engine_thirdparty_caching`
+
+ """
+
+ def create_connect_args(self, url):
+ """Build DB-API compatible connection arguments.
+
+ Given a :class:`.URL` object, returns a tuple
+ consisting of a ``(*args, **kwargs)`` suitable to send directly
+ to the dbapi's connect function. The arguments are sent to the
+ :meth:`.Dialect.connect` method which then runs the DBAPI-level
+ ``connect()`` function.
+
+ The method typically makes use of the
+ :meth:`.URL.translate_connect_args`
+ method in order to generate a dictionary of options.
+
+ The default implementation is::
+
+ def create_connect_args(self, url):
+ opts = url.translate_connect_args()
+ opts.update(url.query)
+ return [[], opts]
+
+ :param url: a :class:`.URL` object
+
+ :return: a tuple of ``(*args, **kwargs)`` which will be passed to the
+ :meth:`.Dialect.connect` method.
+
+ .. seealso::
+
+ :meth:`.URL.translate_connect_args`
+
+ """
+
+ raise NotImplementedError()
+
+ @classmethod
+ def type_descriptor(cls, typeobj):
+ """Transform a generic type to a dialect-specific type.
+
+ Dialect classes will usually use the
+ :func:`_types.adapt_type` function in the types module to
+ accomplish this.
+
+ The returned result is cached *per dialect class* so can
+ contain no dialect-instance state.
+
+ """
+
+ raise NotImplementedError()
+
+ def initialize(self, connection):
+ """Called during strategized creation of the dialect with a
+ connection.
+
+ Allows dialects to configure options based on server version info or
+ other properties.
+
+ The connection passed here is a SQLAlchemy Connection object,
+ with full capabilities.
+
+ The initialize() method of the base dialect should be called via
+ super().
+
+ .. note:: as of SQLAlchemy 1.4, this method is called **before**
+ any :meth:`_engine.Dialect.on_connect` hooks are called.
+
+ """
+
+ pass
+
+ def get_columns(self, connection, table_name, schema=None, **kw):
+ """Return information about columns in `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return column
+ information as a list of dictionaries with these keys:
+
+ name
+ the column's name
+
+ type
+ [sqlalchemy.types#TypeEngine]
+
+ nullable
+ boolean
+
+ default
+ the column's default value
+
+ autoincrement
+ boolean
+
+ sequence
+ a dictionary of the form
+ {'name' : str, 'start' :int, 'increment': int, 'minvalue': int,
+ 'maxvalue': int, 'nominvalue': bool, 'nomaxvalue': bool,
+ 'cycle': bool, 'cache': int, 'order': bool}
+
+ Additional column attributes may be present.
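+
+ An illustrative return value for a single-column table; the names and
+ types shown are examples only::
+
+ [
+ {
+ "name": "id",
+ "type": INTEGER(),
+ "nullable": False,
+ "default": None,
+ "autoincrement": True,
+ },
+ ]
+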
+ """
+
+ raise NotImplementedError()
+
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ """Return information about the primary key constraint on
+ `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return primary
+ key information as a dictionary with these keys:
+
+ constrained_columns
+ a list of column names that make up the primary key
+
+ name
+ optional name of the primary key constraint.
+
+ """
+ raise NotImplementedError()
+
+ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ """Return information about foreign_keys in `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name`, and an optional string `schema`, return foreign
+ key information as a list of dicts with these keys:
+
+ name
+ the constraint's name
+
+ constrained_columns
+ a list of column names that make up the foreign key
+
+ referred_schema
+ the name of the referred schema
+
+ referred_table
+ the name of the referred table
+
+ referred_columns
+ a list of column names in the referred table that correspond to
+ constrained_columns
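+
+ An illustrative return value for a single foreign key constraint,
+ using example names only::
+
+ [
+ {
+ "name": "fk_address_user_id",
+ "constrained_columns": ["user_id"],
+ "referred_schema": None,
+ "referred_table": "user",
+ "referred_columns": ["id"],
+ }
+ ]
+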
+ """
+
+ raise NotImplementedError()
+
+ def get_table_names(self, connection, schema=None, **kw):
+ """Return a list of table names for `schema`."""
+
+ raise NotImplementedError()
+
+ def get_temp_table_names(self, connection, schema=None, **kw):
+ """Return a list of temporary table names on the given connection,
+ if supported by the underlying backend.
+
+ """
+
+ raise NotImplementedError()
+
+ def get_view_names(self, connection, schema=None, **kw):
+ """Return a list of all view names available in the database.
+
+ :param schema: schema name to query, if not the default schema.
+ """
+
+ raise NotImplementedError()
+
+ def get_sequence_names(self, connection, schema=None, **kw):
+ """Return a list of all sequence names available in the database.
+
+ :param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 1.4
+ """
+
+ raise NotImplementedError()
+
+ def get_temp_view_names(self, connection, schema=None, **kw):
+ """Return a list of temporary view names on the given connection,
+ if supported by the underlying backend.
+
+ """
+
+ raise NotImplementedError()
+
+ def get_view_definition(self, connection, view_name, schema=None, **kw):
+ """Return view definition.
+
+ Given a :class:`_engine.Connection`, a string
+ `view_name`, and an optional string `schema`, return the view
+ definition.
+ """
+
+ raise NotImplementedError()
+
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ """Return information about indexes in `table_name`.
+
+ Given a :class:`_engine.Connection`, a string
+ `table_name` and an optional string `schema`, return index
+ information as a list of dictionaries with these keys:
+
+ name
+ the index's name
+
+ column_names
+ list of column names in order
+
+ unique
+ boolean
+ """
+
+ raise NotImplementedError()
+
+ def get_unique_constraints(
+ self, connection, table_name, schema=None, **kw
+ ):
+ r"""Return information about unique constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ unique constraint information as a list of dicts with these keys:
+
+ name
+ the unique constraint's name
+
+ column_names
+ list of column names in order
+
+ \**kw
+ other options passed to the dialect's get_unique_constraints()
+ method.
+
+ .. versionadded:: 0.9.0
+
+ """
+
+ raise NotImplementedError()
+
+ def get_check_constraints(self, connection, table_name, schema=None, **kw):
+ r"""Return information about check constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ check constraint information as a list of dicts with these keys:
+
+ * ``name`` -
+ the check constraint's name
+
+ * ``sqltext`` -
+ the check constraint's SQL expression
+
+ * ``**kw`` -
+ other options passed to the dialect's get_check_constraints()
+ method.
+
+ .. versionadded:: 1.1.0
+
+ """
+
+ raise NotImplementedError()
+
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ r"""Return the "comment" for the table identified by `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ table comment information as a dictionary with this key:
+
+ text
+ text of the comment
+
+ Raises ``NotImplementedError`` for dialects that don't support
+ comments.
+
+ .. versionadded:: 1.2
+
+ """
+
+ raise NotImplementedError()
+
+ def normalize_name(self, name):
+ """convert the given name to lowercase if it is detected as
+ case insensitive.
+
+ This method is only used if the dialect defines
+ requires_name_normalize=True.
+
+ """
+ raise NotImplementedError()
+
+ def denormalize_name(self, name):
+ """convert the given name to a case insensitive identifier
+ for the backend if it is an all-lowercase name.
+
+ This method is only used if the dialect defines
+ requires_name_normalize=True.
+
+ """
+ raise NotImplementedError()
+
+ def has_table(self, connection, table_name, schema=None, **kw):
+ """For internal dialect use, check the existence of a particular table
+ in the database.
+
+ Given a :class:`_engine.Connection` object, a string table_name and
+ optional schema name, return True if the given table exists in the
+ database, False otherwise.
+
+ This method serves as the underlying implementation of the
+ public facing :meth:`.Inspector.has_table` method, and is also used
+ internally to implement the "checkfirst" behavior for methods like
+ :meth:`_schema.Table.create` and :meth:`_schema.MetaData.create_all`.
+
+ .. note:: This method is used internally by SQLAlchemy, and is
+ published so that third-party dialects may provide an
+ implementation. It is **not** the public API for checking for table
+ presence. Please use the :meth:`.Inspector.has_table` method.
+ Alternatively, for legacy cross-compatibility, the
+ :meth:`_engine.Engine.has_table` method may be used.
+
+ """
+
+ raise NotImplementedError()
+
+ def has_index(self, connection, table_name, index_name, schema=None):
+ """Check the existence of a particular index name in the database.
+
+ Given a :class:`_engine.Connection` object, a string
+ `table_name` and string index name, return True if an index of the
+ given name on the given table exists, False otherwise.
+
+ The :class:`.DefaultDialect` implements this in terms of the
+ :meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods,
+ however dialects can implement a more performant version.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ raise NotImplementedError()
+
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
+ """Check the existence of a particular sequence in the database.
+
+ Given a :class:`_engine.Connection` object and a string
+ `sequence_name`, return True if the given sequence exists in
+ the database, False otherwise.
+ """
+
+ raise NotImplementedError()
+
+ def _get_server_version_info(self, connection):
+ """Retrieve the server version info from the given connection.
+
+ This is used by the default implementation to populate the
+ "server_version_info" attribute and is called exactly
+ once upon first connect.
+
+ """
+
+ raise NotImplementedError()
+
+ def _get_default_schema_name(self, connection):
+ """Return the string name of the currently selected schema from
+ the given connection.
+
+ This is used by the default implementation to populate the
+ "default_schema_name" attribute and is called exactly
+ once upon first connect.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_begin(self, dbapi_connection):
+ """Provide an implementation of ``connection.begin()``, given a
+ DB-API connection.
+
+ The DBAPI has no dedicated "begin" method and it is expected
+ that transactions are implicit. This hook is provided for those
+ DBAPIs that might need additional help in this area.
+
+ Note that :meth:`.Dialect.do_begin` is not called unless a
+ :class:`.Transaction` object is in use. The
+ :meth:`.Dialect.do_autocommit`
+ hook is provided for DBAPIs that need some extra commands emitted
+ after a commit in order to enter the next transaction, when the
+ SQLAlchemy :class:`_engine.Connection`
+ is used in its default "autocommit"
+ mode.
+
+ :param dbapi_connection: a DBAPI connection, typically
+ proxied within a :class:`.ConnectionFairy`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_rollback(self, dbapi_connection):
+ """Provide an implementation of ``connection.rollback()``, given
+ a DB-API connection.
+
+ :param dbapi_connection: a DBAPI connection, typically
+ proxied within a :class:`.ConnectionFairy`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_commit(self, dbapi_connection):
+ """Provide an implementation of ``connection.commit()``, given a
+ DB-API connection.
+
+ :param dbapi_connection: a DBAPI connection, typically
+ proxied within a :class:`.ConnectionFairy`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_close(self, dbapi_connection):
+ """Provide an implementation of ``connection.close()``, given a DBAPI
+ connection.
+
+ This hook is called by the :class:`_pool.Pool`
+ when a connection has been
+ detached from the pool, or is being returned beyond the normal
+ capacity of the pool.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ """invoke the cursor.setinputsizes() method with appropriate arguments
+
+ This hook is called if the dialect.use_setinputsizes flag is set to True.
+ Parameter data is passed in a list of tuples (paramname, dbtype,
+ sqltype), where ``paramname`` is the key of the parameter in the
+ statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the
+ SQLAlchemy type. The order of tuples is in the correct parameter order.
+
+ .. versionadded:: 1.4
+
+
+ """
+ raise NotImplementedError()
+
+ def create_xid(self):
+ """Create a two-phase transaction ID.
+
+ This id will be passed to do_begin_twophase(),
+ do_rollback_twophase(), do_commit_twophase(). Its format is
+ unspecified.
+ """
+
+ raise NotImplementedError()
+
+ def do_savepoint(self, connection, name):
+ """Create a savepoint with the given name.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param name: savepoint name.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_rollback_to_savepoint(self, connection, name):
+ """Rollback a connection to the named savepoint.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param name: savepoint name.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_release_savepoint(self, connection, name):
+ """Release the named savepoint on a connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param name: savepoint name.
+ """
+
+ raise NotImplementedError()
+
+ def do_begin_twophase(self, connection, xid):
+ """Begin a two phase transaction on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+
+ """
+
+ raise NotImplementedError()
+
+ def do_prepare_twophase(self, connection, xid):
+ """Prepare a two phase transaction on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+
+ """
+
+ raise NotImplementedError()
+
+ def do_rollback_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ """Rollback a two phase transaction on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+ :param is_prepared: whether or not
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+ :param recover: if the recover flag was passed.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_commit_twophase(
+ self, connection, xid, is_prepared=True, recover=False
+ ):
+ """Commit a two phase transaction on the given connection.
+
+
+ :param connection: a :class:`_engine.Connection`.
+ :param xid: xid
+ :param is_prepared: whether or not
+ :meth:`.TwoPhaseTransaction.prepare` was called.
+ :param recover: if the recover flag was passed.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_recover_twophase(self, connection):
+ """Recover list of uncommitted prepared two phase transaction
+ identifiers on the given connection.
+
+ :param connection: a :class:`_engine.Connection`.
+
+ """
+
+ raise NotImplementedError()
+
+ def do_executemany(self, cursor, statement, parameters, context=None):
+ """Provide an implementation of ``cursor.executemany(statement,
+ parameters)``."""
+
+ raise NotImplementedError()
+
+ def do_execute(self, cursor, statement, parameters, context=None):
+ """Provide an implementation of ``cursor.execute(statement,
+ parameters)``."""
+
+ raise NotImplementedError()
+
+ def do_execute_no_params(
+ self, cursor, statement, parameters, context=None
+ ):
+ """Provide an implementation of ``cursor.execute(statement)``.
+
+ The parameter collection should not be sent.
+
+ """
+
+ raise NotImplementedError()
+
+ def is_disconnect(self, e, connection, cursor):
+ """Return True if the given DB-API error indicates an invalid
+ connection"""
+
+ raise NotImplementedError()
+
+ def connect(self, *cargs, **cparams):
+ r"""Establish a connection using this dialect's DBAPI.
+
+ The default implementation of this method is::
+
+ def connect(self, *cargs, **cparams):
+ return self.dbapi.connect(*cargs, **cparams)
+
+ The ``*cargs, **cparams`` parameters are generated directly
+ from this dialect's :meth:`.Dialect.create_connect_args` method.
+
+ This method may be used for dialects that need to perform programmatic
+ per-connection steps when a new connection is procured from the
+ DBAPI.
+
+
+ :param \*cargs: positional parameters returned from the
+ :meth:`.Dialect.create_connect_args` method
+
+ :param \*\*cparams: keyword parameters returned from the
+ :meth:`.Dialect.create_connect_args` method.
+
+ :return: a DBAPI connection, typically from the :pep:`249` module
+ level ``.connect()`` function.
+
+ .. seealso::
+
+ :meth:`.Dialect.create_connect_args`
+
+ :meth:`.Dialect.on_connect`
+
+ """
+
+ def on_connect_url(self, url):
+ """return a callable which sets up a newly created DBAPI connection.
+
+ This method is a new hook that supersedes the
+ :meth:`_engine.Dialect.on_connect` method when implemented by a
+ dialect. When not implemented by a dialect, it invokes the
+ :meth:`_engine.Dialect.on_connect` method directly to maintain
+ compatibility with existing dialects. There is no deprecation
+ for :meth:`_engine.Dialect.on_connect` expected.
+
+ The callable should accept a single argument "conn" which is the
+ DBAPI connection itself. The inner callable has no
+ return value.
+
+ E.g.::
+
+ class MyDialect(default.DefaultDialect):
+ # ...
+
+ def on_connect_url(self, url):
+ def do_on_connect(connection):
+ connection.execute("SET SPECIAL FLAGS etc")
+
+ return do_on_connect
+
+ This is used to set dialect-wide per-connection options such as
+ isolation modes, Unicode modes, etc.
+
+ This method differs from :meth:`_engine.Dialect.on_connect` in that
+ it is passed the :class:`_engine.URL` object that's relevant to the
+ connect args. Normally the only way to get this from the
+ :meth:`_engine.Dialect.on_connect` hook is to look on the
+ :class:`_engine.Engine` itself; however, this URL object may have been
+ replaced by plugins.
+
+ .. note::
+
+ The default implementation of
+ :meth:`_engine.Dialect.on_connect_url` is to invoke the
+ :meth:`_engine.Dialect.on_connect` method. Therefore if a dialect
+ implements this method, the :meth:`_engine.Dialect.on_connect`
+ method **will not be called** unless the overriding dialect calls
+ it directly from here.
+
+ .. versionadded:: 1.4.3 added :meth:`_engine.Dialect.on_connect_url`
+ which normally calls into :meth:`_engine.Dialect.on_connect`.
+
+ :param url: a :class:`_engine.URL` object representing the
+ :class:`_engine.URL` that was passed to the
+ :meth:`_engine.Dialect.create_connect_args` method.
+
+ :return: a callable that accepts a single DBAPI connection as an
+ argument, or None.
+
+ .. seealso::
+
+ :meth:`_engine.Dialect.on_connect`
+
+ """
+ return self.on_connect()
+
+ def on_connect(self):
+ """return a callable which sets up a newly created DBAPI connection.
+
+ The callable should accept a single argument "conn" which is the
+ DBAPI connection itself. The inner callable has no
+ return value.
+
+ E.g.::
+
+ class MyDialect(default.DefaultDialect):
+ # ...
+
+ def on_connect(self):
+ def do_on_connect(connection):
+ connection.execute("SET SPECIAL FLAGS etc")
+
+ return do_on_connect
+
+ This is used to set dialect-wide per-connection options such as
+ isolation modes, Unicode modes, etc.
+
+ The "do_on_connect" callable is invoked by using the
+ :meth:`_events.PoolEvents.connect` event
+ hook, then unwrapping the DBAPI connection and passing it into the
+ callable.
+
+ .. versionchanged:: 1.4 the on_connect hook is no longer called twice
+ for the first connection of a dialect. The on_connect hook is still
+ called before the :meth:`_engine.Dialect.initialize` method however.
+
+ .. versionchanged:: 1.4.3 the on_connect hook is invoked from a new
+ method on_connect_url that passes the URL that was used to create
+ the connect args. Dialects can implement on_connect_url instead
+ of on_connect if they need the URL object that was used for the
+ connection in order to get additional context.
+
+ If None is returned, no event listener is generated.
+
+ :return: a callable that accepts a single DBAPI connection as an
+ argument, or None.
+
+ .. seealso::
+
+ :meth:`.Dialect.connect` - allows the DBAPI ``connect()`` sequence
+ itself to be controlled.
+
+ :meth:`.Dialect.on_connect_url` - supersedes
+ :meth:`.Dialect.on_connect` to also receive the
+ :class:`_engine.URL` object in context.
+
+ """
+ return None
+
+ def reset_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, revert its isolation to the default.
+
+ Note that this is a dialect-level method which is used as part
+ of the implementation of the :class:`_engine.Connection` and
+ :class:`_engine.Engine`
+ isolation level facilities; these APIs should be preferred for
+ most typical use cases.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`.Connection.execution_options.isolation_level` -
+ set per :class:`_engine.Connection` isolation level
+
+ :paramref:`_sa.create_engine.isolation_level` -
+ set per :class:`_engine.Engine` isolation level
+
+ """
+
+ raise NotImplementedError()
+
+ def set_isolation_level(self, dbapi_conn, level):
+ """Given a DBAPI connection, set its isolation level.
+
+ Note that this is a dialect-level method which is used as part
+ of the implementation of the :class:`_engine.Connection` and
+ :class:`_engine.Engine`
+ isolation level facilities; these APIs should be preferred for
+ most typical use cases.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`.Connection.execution_options.isolation_level` -
+ set per :class:`_engine.Connection` isolation level
+
+ :paramref:`_sa.create_engine.isolation_level` -
+ set per :class:`_engine.Engine` isolation level
+
+ """
+
+ raise NotImplementedError()
+
+ def get_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, return its isolation level.
+
+ When working with a :class:`_engine.Connection` object,
+ the corresponding
+ DBAPI connection may be procured using the
+ :attr:`_engine.Connection.connection` accessor.
+
+ Note that this is a dialect-level method which is used as part
+ of the implementation of the :class:`_engine.Connection` and
+ :class:`_engine.Engine` isolation level facilities;
+ these APIs should be preferred for most typical use cases.
+
+
+ .. seealso::
+
+ :meth:`_engine.Connection.get_isolation_level`
+ - view current level
+
+ :attr:`_engine.Connection.default_isolation_level`
+ - view default level
+
+ :paramref:`.Connection.execution_options.isolation_level` -
+ set per :class:`_engine.Connection` isolation level
+
+ :paramref:`_sa.create_engine.isolation_level` -
+ set per :class:`_engine.Engine` isolation level
+
+
+ """
+
+ raise NotImplementedError()
+
+ def get_default_isolation_level(self, dbapi_conn):
+ """Given a DBAPI connection, return its isolation level, or
+ a default isolation level if one cannot be retrieved.
+
+ This method may only raise NotImplementedError and
+ **must not raise any other exception**, as it is used implicitly upon
+ first connect.
+
+ The method **must return a value** for a dialect that supports
+ isolation level settings, as this level is what will be reverted
+ towards when a per-connection isolation level change is made.
+
+ The method defaults to using the :meth:`.Dialect.get_isolation_level`
+ method unless overridden by a dialect.
+
+ .. versionadded:: 1.3.22
+
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def get_dialect_cls(cls, url):
+ """Given a URL, return the :class:`.Dialect` that will be used.
+
+ This is a hook that allows an external plugin to provide functionality
+ around an existing dialect, by allowing the plugin to be loaded
+ from the url based on an entrypoint, and then the plugin returns
+ the actual dialect to be used.
+
+ By default this just returns the cls.
+
+ .. versionadded:: 1.0.3
+
+ """
+ return cls
+
+ @classmethod
+ def load_provisioning(cls):
+ """set up the provision.py module for this dialect.
+
+ For dialects that include a provision.py module that sets up
+ provisioning followers, this method should initiate that process.
+
+ A typical implementation would be::
+
+ @classmethod
+ def load_provisioning(cls):
+ __import__("mydialect.provision")
+
+ The default method assumes a module named ``provision.py`` inside
+ the owning package of the current dialect, based on the ``__module__``
+ attribute::
+
+ @classmethod
+ def load_provisioning(cls):
+ package = ".".join(cls.__module__.split(".")[0:-1])
+ try:
+ __import__(package + ".provision")
+ except ImportError:
+ pass
+
+ .. versionadded:: 1.3.14
+
+ """
+
+ @classmethod
+ def engine_created(cls, engine):
+ """A convenience hook called before returning the final
+ :class:`_engine.Engine`.
+
+ If the dialect returned a different class from the
+ :meth:`.get_dialect_cls`
+ method, then the hook is called on both classes, first on
+ the dialect class returned by the :meth:`.get_dialect_cls` method and
+ then on the class on which the method was called.
+
+ The hook should be used by dialects and/or wrappers to apply special
+ events to the engine or its components. In particular, it allows
+ a dialect-wrapping class to apply dialect-level events.
+
+ .. versionadded:: 1.0.3
+
+ """
+
+ def get_driver_connection(self, connection):
+ """Returns the connection object as returned by the external driver
+ package.
+
+ For normal dialects that use a DBAPI compliant driver this call
+ will just return the ``connection`` passed as argument.
+ For dialects that instead adapt a non DBAPI compliant driver, like
+ when adapting an asyncio driver, this call will return the
+ connection-like object as returned by the driver.
+
+ .. versionadded:: 1.4.24
+
+ """
+ raise NotImplementedError()
+
+
+class CreateEnginePlugin(object):
+ """A set of hooks intended to augment the construction of an
+ :class:`_engine.Engine` object based on entrypoint names in a URL.
+
+ The purpose of :class:`_engine.CreateEnginePlugin` is to allow third-party
+ systems to apply engine, pool and dialect level event listeners without
+ the need for the target application to be modified; instead, the plugin
+ names can be added to the database URL. Target applications for
+ :class:`_engine.CreateEnginePlugin` include:
+
+ * connection and SQL performance tools, e.g. which use events to track
+ number of checkouts and/or time spent with statements
+
+ * connectivity plugins such as proxies
+
+ A rudimentary :class:`_engine.CreateEnginePlugin` that attaches a logger
+ to an :class:`_engine.Engine` object might look like::
+
+
+ import logging
+
+ from sqlalchemy.engine import CreateEnginePlugin
+ from sqlalchemy import event
+
+ class LogCursorEventsPlugin(CreateEnginePlugin):
+ def __init__(self, url, kwargs):
+ # consume the parameter "log_cursor_logging_name" from the
+ # URL query
+ logging_name = url.query.get("log_cursor_logging_name", "log_cursor")
+
+ self.log = logging.getLogger(logging_name)
+
+ def update_url(self, url):
+ "update the URL to one that no longer includes our parameters"
+ return url.difference_update_query(["log_cursor_logging_name"])
+
+ def engine_created(self, engine):
+ "attach an event listener after the new Engine is constructed"
+ event.listen(engine, "before_cursor_execute", self._log_event)
+
+
+ def _log_event(
+ self,
+ conn,
+ cursor,
+ statement,
+ parameters,
+ context,
+ executemany):
+
+ self.log.info("Plugin logged cursor event: %s", statement)
+
+
+
+ Plugins are registered using entry points in a similar way as that
+ of dialects::
+
+ entry_points={
+ 'sqlalchemy.plugins': [
+ 'log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin'
+ ]
+ }
+
+ A plugin that uses the above names would be invoked from a database
+ URL as in::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?"
+ "plugin=log_cursor_plugin&log_cursor_logging_name=mylogger"
+ )
+
+ The ``plugin`` URL parameter supports multiple instances, so that a URL
+ may specify multiple plugins; they are loaded in the order stated
+ in the URL::
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?"
+ "plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three")
+
+ The plugin names may also be passed directly to :func:`_sa.create_engine`
+ using the :paramref:`_sa.create_engine.plugins` argument::
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test",
+ plugins=["myplugin"])
+
+ .. versionadded:: 1.2.3 plugin names can also be specified
+ to :func:`_sa.create_engine` as a list
+
+ A plugin may consume plugin-specific arguments from the
+ :class:`_engine.URL` object as well as the ``kwargs`` dictionary, which is
+ the dictionary of arguments passed to the :func:`_sa.create_engine`
+ call. "Consuming" these arguments includes that they must be removed
+ when the plugin initializes, so that the arguments are not passed along
+ to the :class:`_engine.Dialect` constructor, where they will raise an
+ :class:`_exc.ArgumentError` because they are not known by the dialect.
+
+ As of version 1.4 of SQLAlchemy, arguments should continue to be consumed
+ from the ``kwargs`` dictionary directly, by removing the values with a
+ method such as ``dict.pop``. Arguments from the :class:`_engine.URL` object
+ should be consumed by implementing the
+ :meth:`_engine.CreateEnginePlugin.update_url` method, returning a new copy
+ of the :class:`_engine.URL` with plugin-specific parameters removed::
+
+ class MyPlugin(CreateEnginePlugin):
+ def __init__(self, url, kwargs):
+ self.my_argument_one = url.query['my_argument_one']
+ self.my_argument_two = url.query['my_argument_two']
+ self.my_argument_three = kwargs.pop('my_argument_three', None)
+
+ def update_url(self, url):
+ return url.difference_update_query(
+ ["my_argument_one", "my_argument_two"]
+ )
+
+ Arguments like those illustrated above would be consumed from a
+ :func:`_sa.create_engine` call such as::
+
+ from sqlalchemy import create_engine
+
+ engine = create_engine(
+ "mysql+pymysql://scott:tiger@localhost/test?"
+ "plugin=myplugin&my_argument_one=foo&my_argument_two=bar",
+ my_argument_three='bat'
+ )
+
+ .. versionchanged:: 1.4
+
+ The :class:`_engine.URL` object is now immutable; a
+ :class:`_engine.CreateEnginePlugin` that needs to alter the
+ :class:`_engine.URL` should implement the newly added
+ :meth:`_engine.CreateEnginePlugin.update_url` method, which
+ is invoked after the plugin is constructed.
+
+ For migration, construct the plugin in the following way, checking
+ for the existence of the :meth:`_engine.CreateEnginePlugin.update_url`
+ method to detect which version is running::
+
+ class MyPlugin(CreateEnginePlugin):
+ def __init__(self, url, kwargs):
+ if hasattr(CreateEnginePlugin, "update_url"):
+ # detect the 1.4 API
+ self.my_argument_one = url.query['my_argument_one']
+ self.my_argument_two = url.query['my_argument_two']
+ else:
+ # detect the 1.3 and earlier API - mutate the
+ # URL directly
+ self.my_argument_one = url.query.pop('my_argument_one')
+ self.my_argument_two = url.query.pop('my_argument_two')
+
+ self.my_argument_three = kwargs.pop('my_argument_three', None)
+
+ def update_url(self, url):
+ # this method is only called in the 1.4 version
+ return url.difference_update_query(
+ ["my_argument_one", "my_argument_two"]
+ )
+
+ .. seealso::
+
+ :ref:`change_5526` - overview of the :class:`_engine.URL` change which
+ also includes notes regarding :class:`_engine.CreateEnginePlugin`.
+
+
+ When the engine creation process completes and produces the
+ :class:`_engine.Engine` object, it is again passed to the plugin via the
+ :meth:`_engine.CreateEnginePlugin.engine_created` hook. In this hook, additional
+ changes can be made to the engine, most typically involving setup of
+ events (e.g. those defined in :ref:`core_event_toplevel`).
+
+ .. versionadded:: 1.1
+
+ """ # noqa: E501
+
+ def __init__(self, url, kwargs):
+ """Construct a new :class:`.CreateEnginePlugin`.
+
+ The plugin object is instantiated individually for each call
+ to :func:`_sa.create_engine`. A single :class:`_engine.Engine` will be
+ passed to the :meth:`.CreateEnginePlugin.engine_created` method
+ corresponding to this URL.
+
+ :param url: the :class:`_engine.URL` object. The plugin may inspect
+ the :class:`_engine.URL` for arguments. Arguments used by the
+ plugin should be removed, by returning an updated :class:`_engine.URL`
+ from the :meth:`_engine.CreateEnginePlugin.update_url` method.
+
+ .. versionchanged:: 1.4
+
+ The :class:`_engine.URL` object is now immutable, so a
+ :class:`_engine.CreateEnginePlugin` that needs to alter the
+ :class:`_engine.URL` object should implement the
+ :meth:`_engine.CreateEnginePlugin.update_url` method.
+
+ :param kwargs: The keyword arguments passed to
+ :func:`_sa.create_engine`.
+
+ """
+ self.url = url
+
+ def update_url(self, url):
+ """Update the :class:`_engine.URL`.
+
+ A new :class:`_engine.URL` should be returned. This method is
+ typically used to consume configuration arguments from the
+ :class:`_engine.URL` which must be removed, as they will not be
+ recognized by the dialect. The
+ :meth:`_engine.URL.difference_update_query` method is available
+ to remove these arguments. See the docstring at
+ :class:`_engine.CreateEnginePlugin` for an example.
+
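+ A minimal sketch, using an illustrative parameter name::
+
+     def update_url(self, url):
+         return url.difference_update_query(["my_plugin_argument"])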
+
+ .. versionadded:: 1.4
+
+ """
+
+ def handle_dialect_kwargs(self, dialect_cls, dialect_args):
+ """parse and modify dialect kwargs"""
+
+ def handle_pool_kwargs(self, pool_cls, pool_args):
+ """parse and modify pool kwargs"""
+
+ def engine_created(self, engine):
+ """Receive the :class:`_engine.Engine`
+ object when it is fully constructed.
+
+ The plugin may make additional changes to the engine, such as
+ registering engine or connection pool events.
+
+ """
+
+
+class ExecutionContext(object):
+ """A messenger object for a Dialect that corresponds to a single
+ execution.
+
+ ExecutionContext should have these data members:
+
+ connection
+ Connection object which can be freely used by default value
+ generators to execute SQL. This Connection should reference the
+ same underlying connection/transactional resources as
+ root_connection.
+
+ root_connection
+ Connection object which is the source of this ExecutionContext. This
+ Connection may have close_with_result=True set, in which case it can
+ only be used once.
+
+ dialect
+ dialect which created this ExecutionContext.
+
+ cursor
+ DB-API cursor procured from the connection,
+
+ compiled
+ if passed to constructor, sqlalchemy.engine.base.Compiled object
+ being executed,
+
+ statement
+ string version of the statement to be executed. Is either
+ passed to the constructor, or must be created from the
+ sql.Compiled object by the time pre_exec() has completed.
+
+ parameters
+ bind parameters passed to the execute() method. For compiled
+ statements, this is a dictionary or list of dictionaries. For
+ textual statements, it should be in a format suitable for the
+ dialect's paramstyle (i.e. dict or list of dicts for non
+ positional, list or list of lists/tuples for positional).
+
+ isinsert
+ True if the statement is an INSERT.
+
+ isupdate
+ True if the statement is an UPDATE.
+
+ should_autocommit
+ True if the statement is a "committable" statement.
+
+ prefetch_cols
+ a list of Column objects for which a client-side default
+ was fired off. Applies to inserts and updates.
+
+ postfetch_cols
+ a list of Column objects for which a server-side default or
+ inline SQL expression value was fired off. Applies to inserts
+ and updates.
+ """
+
+ def create_cursor(self):
+ """Return a new cursor generated from this ExecutionContext's
+ connection.
+
+ Some dialects may wish to change the behavior of
+ connection.cursor(), such as postgresql which may return a PG
+ "server side" cursor.
+ """
+
+ raise NotImplementedError()
+
+ def pre_exec(self):
+ """Called before an execution of a compiled statement.
+
+ If a compiled statement was passed to this ExecutionContext,
+ the `statement` and `parameters` data members must be
+ initialized by the time this method completes.
+ """
+
+ raise NotImplementedError()
+
+ def get_out_parameter_values(self, out_param_names):
+ """Return a sequence of OUT parameter values from a cursor.
+
+ For dialects that support OUT parameters, this method will be called
+ when there is a :class:`.SQLCompiler` object which has the
+ :attr:`.SQLCompiler.has_out_parameters` flag set. This flag in turn
+ will be set to True if the statement itself has :class:`.BindParameter`
+ objects that have the ``.isoutparam`` flag set which are consumed by
+ the :meth:`.SQLCompiler.visit_bindparam` method. If the dialect
+ compiler produces :class:`.BindParameter` objects with ``.isoutparam``
+ set which are not handled by :meth:`.SQLCompiler.visit_bindparam`, it
+ should set this flag explicitly.
+
+ The list of names that were rendered for each bound parameter
+ is passed to the method. The method should then return a sequence of
+ values corresponding to the list of parameter objects. Unlike in
+ previous SQLAlchemy versions, the values can be the **raw values** from
+ the DBAPI; the execution context will apply the appropriate type
+ handler based on what's present in self.compiled.binds and update the
+ values. The processed dictionary will then be made available via the
+ ``.out_parameters`` collection on the result object. Note that
+ SQLAlchemy 1.4 has multiple kinds of result object as part of the 2.0
+ transition.
+
+ .. versionadded:: 1.4 - added
+ :meth:`.ExecutionContext.get_out_parameter_values`, which is invoked
+ automatically by the :class:`.DefaultExecutionContext` when there
+ are :class:`.BindParameter` objects with the ``.isoutparam`` flag
+ set. This replaces the practice of setting out parameters within
+ the now-removed ``get_result_proxy()`` method.
+
+ """
+ raise NotImplementedError()
+
+ def post_exec(self):
+ """Called after the execution of a compiled statement.
+
+ If a compiled statement was passed to this ExecutionContext,
+ the `last_insert_ids`, `last_inserted_params`, etc.
+ datamembers should be available after this method completes.
+ """
+
+ raise NotImplementedError()
+
+ def handle_dbapi_exception(self, e):
+ """Receive a DBAPI exception which occurred upon execute, result
+ fetch, etc."""
+
+ raise NotImplementedError()
+
+ def should_autocommit_text(self, statement):
+ """Parse the given textual statement and return True if it refers to
+ a "committable" statement"""
+
+ raise NotImplementedError()
+
+ def lastrow_has_defaults(self):
+ """Return True if the last INSERT or UPDATE row contained
+ inlined or database-side defaults.
+ """
+
+ raise NotImplementedError()
+
+ def get_rowcount(self):
+ """Return the DBAPI ``cursor.rowcount`` value, or in some
+ cases an interpreted value.
+
+ See :attr:`_engine.CursorResult.rowcount` for details on this.
+
+ """
+
+ raise NotImplementedError()
+
+
+@util.deprecated_20_cls(
+ ":class:`.Connectable`",
+ alternative=(
+ "The :class:`_engine.Engine` will be the only Core "
+ "object that features a .connect() method, and the "
+ ":class:`_engine.Connection` will be the only object that features "
+ "an .execute() method."
+ ),
+ constructor=None,
+)
+class Connectable(object):
+ """Interface for an object which supports execution of SQL constructs.
+
+ The two implementations of :class:`.Connectable` are
+ :class:`_engine.Connection` and :class:`_engine.Engine`.
+
+ Connectable must also implement the 'dialect' member which references a
+ :class:`.Dialect` instance.
+
+ """
+
+ def connect(self, **kwargs):
+ """Return a :class:`_engine.Connection` object.
+
+ Depending on context, this may be ``self`` if this object
+ is already an instance of :class:`_engine.Connection`, or a newly
+ procured :class:`_engine.Connection` if this object is an instance
+ of :class:`_engine.Engine`.
+
+ """
+
+ engine = None
+ """The :class:`_engine.Engine` instance referred to by this
+ :class:`.Connectable`.
+
+ May be ``self`` if this is already an :class:`_engine.Engine`.
+
+ """
+
+ def execute(self, object_, *multiparams, **params):
+ """Executes the given construct and returns a
+ :class:`_engine.CursorResult`.
+ """
+ raise NotImplementedError()
+
+ def scalar(self, object_, *multiparams, **params):
+ """Executes and returns the first column of the first row.
+
+ The underlying cursor is closed after execution.
+ """
+ raise NotImplementedError()
+
+ def _run_visitor(self, visitorcallable, element, **kwargs):
+ raise NotImplementedError()
+
+ def _execute_clauseelement(self, elem, multiparams=None, params=None):
+ raise NotImplementedError()
+
+
+class ExceptionContext(object):
+ """Encapsulate information about an error condition in progress.
+
+ This object exists solely to be passed to the
+ :meth:`_events.ConnectionEvents.handle_error` event,
+ supporting an interface that
+ can be extended without backwards-incompatibility.
+
+ .. versionadded:: 0.9.7
+
+ """
+
+ connection = None
+ """The :class:`_engine.Connection` in use during the exception.
+
+ This member is present, except in the case of a failure when
+ first connecting.
+
+ .. seealso::
+
+ :attr:`.ExceptionContext.engine`
+
+
+ """
+
+ engine = None
+ """The :class:`_engine.Engine` in use during the exception.
+
+ This member should always be present, even in the case of a failure
+ when first connecting.
+
+ .. versionadded:: 1.0.0
+
+ """
+
+ cursor = None
+ """The DBAPI cursor object.
+
+ May be None.
+
+ """
+
+ statement = None
+ """String SQL statement that was emitted directly to the DBAPI.
+
+ May be None.
+
+ """
+
+ parameters = None
+ """Parameter collection that was emitted directly to the DBAPI.
+
+ May be None.
+
+ """
+
+ original_exception = None
+ """The exception object which was caught.
+
+ This member is always present.
+
+ """
+
+ sqlalchemy_exception = None
+ """The :class:`sqlalchemy.exc.StatementError` which wraps the original,
+ and will be raised if exception handling is not circumvented by the event.
+
+ May be None, as not all exception types are wrapped by SQLAlchemy.
+ For DBAPI-level exceptions that subclass the dbapi's Error class, this
+ field will always be present.
+
+ """
+
+ chained_exception = None
+ """The exception that was returned by the previous handler in the
+ exception chain, if any.
+
+ If present, this exception will be the one ultimately raised by
+ SQLAlchemy unless a subsequent handler replaces it.
+
+ May be None.
+
+ """
+
+ execution_context = None
+ """The :class:`.ExecutionContext` corresponding to the execution
+ operation in progress.
+
+ This is present for statement execution operations, but not for
+ operations such as transaction begin/end. It also is not present when
+ the exception was raised before the :class:`.ExecutionContext`
+ could be constructed.
+
+ Note that the :attr:`.ExceptionContext.statement` and
+ :attr:`.ExceptionContext.parameters` members may represent a
+ different value than that of the :class:`.ExecutionContext`,
+ potentially in the case where a
+ :meth:`_events.ConnectionEvents.before_cursor_execute` event or similar
+ modified the statement/parameters to be sent.
+
+ May be None.
+
+ """
+
+ is_disconnect = None
+ """Represent whether the exception as occurred represents a "disconnect"
+ condition.
+
+ This flag will always be True or False within the scope of the
+ :meth:`_events.ConnectionEvents.handle_error` handler.
+
+ SQLAlchemy will defer to this flag in order to determine whether or not
+ the connection should be invalidated subsequently. That is, by
+ assigning to this flag, a "disconnect" event which then results in
+ a connection and pool invalidation can be invoked or prevented by
+ changing this flag.
+
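+ E.g., a sketch of a :meth:`_events.ConnectionEvents.handle_error`
+ listener that flags an error as a disconnect based on an illustrative
+ driver-specific message::
+
+     from sqlalchemy import event
+
+     @event.listens_for(engine, "handle_error")
+     def detect_disconnect(context):
+         if "server has gone away" in str(context.original_exception):
+             context.is_disconnect = True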
+
+ .. note:: The pool "pre_ping" handler enabled using the
+ :paramref:`_sa.create_engine.pool_pre_ping` parameter does **not**
+ consult this event before deciding if the "ping" returned false,
+ as opposed to receiving an unhandled error. For this use case, the
+ :ref:`legacy recipe based on engine_connect() may be used
+ <pool_disconnects_pessimistic_custom>`. A future API allow more
+ comprehensive customization of the "disconnect" detection mechanism
+ across all functions.
+
+ """
+
+ invalidate_pool_on_disconnect = True
+ """Represent whether all connections in the pool should be invalidated
+ when a "disconnect" condition is in effect.
+
+ Setting this flag to False within the scope of the
+ :meth:`_events.ConnectionEvents.handle_error`
+ event will have the effect such
+ that the full collection of connections in the pool will not be
+ invalidated during a disconnect; only the current connection that is the
+ subject of the error will actually be invalidated.
+
+ The purpose of this flag is for custom disconnect-handling schemes where
+ the invalidation of other connections in the pool is to be performed
+ based on other conditions, or even on a per-connection basis.
+
+ .. versionadded:: 1.0.3
+
+ """
+
+
+class AdaptedConnection(object):
+ """Interface of an adapted connection object to support the DBAPI protocol.
+
+ Used by asyncio dialects to provide a sync-style pep-249 facade on top
+ of the asyncio connection/cursor API provided by the driver.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ __slots__ = ("_connection",)
+
+ @property
+ def driver_connection(self):
+ """The connection object as returned by the driver after a connect."""
+ return self._connection
+
+ def run_async(self, fn):
+ """Run the awaitable returned by the given function, which is passed
+ the raw asyncio driver connection.
+
+ This is used to invoke awaitable-only methods on the driver connection
+ within the context of a "synchronous" method, like a connection
+ pool event handler.
+
+ E.g.::
+
+ engine = create_async_engine(...)
+
+ @event.listens_for(engine.sync_engine, "connect")
+ def register_custom_types(dbapi_connection, ...):
+ dbapi_connection.run_async(
+ lambda connection: connection.set_type_codec(
+ 'MyCustomType', encoder, decoder, ...
+ )
+ )
+
+ .. versionadded:: 1.4.30
+
+ .. seealso::
+
+ :ref:`asyncio_events_run_async`
+
+ """
+ return await_only(fn(self._connection))
+
+ def __repr__(self):
+ return "<AdaptedConnection %s>" % self._connection
diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py
new file mode 100644
index 0000000..6fcb09f
--- /dev/null
+++ b/lib/sqlalchemy/engine/mock.py
@@ -0,0 +1,118 @@
+# engine/mock.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from operator import attrgetter
+
+from . import base
+from . import url as _url
+from .. import util
+from ..sql import ddl
+
+
+class MockConnection(base.Connectable):
+ def __init__(self, dialect, execute):
+ self._dialect = dialect
+ self.execute = execute
+
+ engine = property(lambda s: s)
+ dialect = property(attrgetter("_dialect"))
+ name = property(lambda s: s._dialect.name)
+
+ def schema_for_object(self, obj):
+ return obj.schema
+
+ def connect(self, **kwargs):
+ return self
+
+ def execution_options(self, **kw):
+ return self
+
+ def compiler(self, statement, parameters, **kwargs):
+ return self._dialect.compiler(
+ statement, parameters, engine=self, **kwargs
+ )
+
+ def create(self, entity, **kwargs):
+ kwargs["checkfirst"] = False
+
+ ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single(
+ entity
+ )
+
+ def drop(self, entity, **kwargs):
+ kwargs["checkfirst"] = False
+
+ ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single(entity)
+
+ def _run_ddl_visitor(
+ self, visitorcallable, element, connection=None, **kwargs
+ ):
+ kwargs["checkfirst"] = False
+ visitorcallable(self.dialect, self, **kwargs).traverse_single(element)
+
+ def execute(self, object_, *multiparams, **params):
+ raise NotImplementedError()
+
+
+def create_mock_engine(url, executor, **kw):
+ """Create a "mock" engine used for echoing DDL.
+
+ This is a utility function used for debugging or storing the output of DDL
+ sequences as generated by :meth:`_schema.MetaData.create_all`
+ and related methods.
+
+ The function accepts a URL which is used only to determine the kind of
+ dialect to be used, as well as an "executor" callable function which
+ will receive a SQL expression object and parameters, which can then be
+ echoed or otherwise printed. The executor's return value is not handled,
+ nor does the engine allow regular string statements to be invoked; the
+ mock engine is therefore only useful for DDL that is sent to the database
+ without needing to receive any results.
+
+ E.g.::
+
+ from sqlalchemy import create_mock_engine
+
+ def dump(sql, *multiparams, **params):
+ print(sql.compile(dialect=engine.dialect))
+
+ engine = create_mock_engine('postgresql://', dump)
+ metadata.create_all(engine, checkfirst=False)
+
+ :param url: A string URL which typically needs to contain only the
+ database backend name.
+
+ :param executor: a callable which receives the arguments ``sql``,
+ ``*multiparams`` and ``**params``. The ``sql`` parameter is typically
+ an instance of :class:`.DDLElement`, which can then be compiled into a
+ string using :meth:`.DDLElement.compile`.
+
+ .. versionadded:: 1.4 - the :func:`.create_mock_engine` function replaces
+ the previous "mock" engine strategy used with
+ :func:`_sa.create_engine`.
+
+ .. seealso::
+
+ :ref:`faq_ddl_as_string`
+
+ """
+
+ # create url.URL object
+ u = _url.make_url(url)
+
+ dialect_cls = u.get_dialect()
+
+ dialect_args = {}
+ # consume dialect arguments from kwargs
+ for k in util.get_cls_kwargs(dialect_cls):
+ if k in kw:
+ dialect_args[k] = kw.pop(k)
+
+ # create dialect
+ dialect = dialect_cls(**dialect_args)
+
+ return MockConnection(dialect, executor)
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
new file mode 100644
index 0000000..b475228
--- /dev/null
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -0,0 +1,1160 @@
+# engine/reflection.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Provides an abstraction for obtaining database schema information.
+
+Usage Notes:
+
+Here are some general conventions when accessing the low level inspector
+methods such as get_table_names, get_columns, etc.
+
+1. Inspector methods return lists of dicts in most cases for the following
+ reasons:
+
+ * They're both standard types that can be serialized.
+ * Using a dict instead of a tuple allows easy expansion of attributes.
+ * Using a list for the outer structure maintains order and is easy to work
+ with (e.g. list comprehension [d['name'] for d in cols]).
+
+2. Records that contain a name, such as the column name in a column record,
+ use the key 'name'. So for most return values, each record will have a
+ 'name' attribute.
+"""
+
+import contextlib
+
+from .base import Connectable
+from .base import Connection
+from .base import Engine
+from .. import exc
+from .. import inspection
+from .. import sql
+from .. import util
+from ..sql import operators
+from ..sql import schema as sa_schema
+from ..sql.type_api import TypeEngine
+from ..util import topological
+
+
+@util.decorator
+def cache(fn, self, con, *args, **kw):
+ info_cache = kw.get("info_cache", None)
+ if info_cache is None:
+ return fn(self, con, *args, **kw)
+ key = (
+ fn.__name__,
+ tuple(a for a in args if isinstance(a, util.string_types)),
+ tuple((k, v) for k, v in kw.items() if k != "info_cache"),
+ )
+ ret = info_cache.get(key)
+ if ret is None:
+ ret = fn(self, con, *args, **kw)
+ info_cache[key] = ret
+ return ret
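+
+# Usage sketch: dialect-level reflection methods are typically decorated
+# with ``@reflection.cache`` so that repeated calls with the same arguments
+# within a single Inspector operation are served from the shared
+# ``info_cache`` dictionary, e.g. (illustrative dialect method):
+#
+#     @reflection.cache
+#     def get_table_names(self, connection, schema=None, **kw):
+#         ...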
+
+
+@inspection._self_inspects
+class Inspector(object):
+ """Performs database schema inspection.
+
+ The Inspector acts as a proxy to the reflection methods of the
+ :class:`~sqlalchemy.engine.interfaces.Dialect`, providing a
+ consistent interface as well as caching support for previously
+ fetched metadata.
+
+ A :class:`_reflection.Inspector` object is usually created via the
+ :func:`_sa.inspect` function, which may be passed an
+ :class:`_engine.Engine`
+ or a :class:`_engine.Connection`::
+
+ from sqlalchemy import inspect, create_engine
+ engine = create_engine('...')
+ insp = inspect(engine)
+
+ Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` associated
+ with the engine may opt to return an :class:`_reflection.Inspector`
+ subclass that
+ provides additional methods specific to the dialect's target database.
+
+ """
+
+ @util.deprecated(
+ "1.4",
+ "The __init__() method on :class:`_reflection.Inspector` "
+ "is deprecated and "
+ "will be removed in a future release. Please use the "
+ ":func:`.sqlalchemy.inspect` "
+ "function on an :class:`_engine.Engine` or "
+ ":class:`_engine.Connection` "
+ "in order to "
+ "acquire an :class:`_reflection.Inspector`.",
+ )
+ def __init__(self, bind):
+ """Initialize a new :class:`_reflection.Inspector`.
+
+ :param bind: a :class:`~sqlalchemy.engine.Connectable`,
+ which is typically an instance of
+ :class:`~sqlalchemy.engine.Engine` or
+ :class:`~sqlalchemy.engine.Connection`.
+
+ For a dialect-specific instance of :class:`_reflection.Inspector`, see
+ :meth:`_reflection.Inspector.from_engine`
+
+ """
+ return self._init_legacy(bind)
+
+ @classmethod
+ def _construct(cls, init, bind):
+
+ if hasattr(bind.dialect, "inspector"):
+ cls = bind.dialect.inspector
+
+ self = cls.__new__(cls)
+ init(self, bind)
+ return self
+
+ def _init_legacy(self, bind):
+ if hasattr(bind, "exec_driver_sql"):
+ self._init_connection(bind)
+ else:
+ self._init_engine(bind)
+
+ def _init_engine(self, engine):
+ self.bind = self.engine = engine
+ engine.connect().close()
+ self._op_context_requires_connect = True
+ self.dialect = self.engine.dialect
+ self.info_cache = {}
+
+ def _init_connection(self, connection):
+ self.bind = connection
+ self.engine = connection.engine
+ self._op_context_requires_connect = False
+ self.dialect = self.engine.dialect
+ self.info_cache = {}
+
+ @classmethod
+ @util.deprecated(
+ "1.4",
+ "The from_engine() method on :class:`_reflection.Inspector` "
+ "is deprecated and "
+ "will be removed in a future release. Please use the "
+ ":func:`.sqlalchemy.inspect` "
+ "function on an :class:`_engine.Engine` or "
+ ":class:`_engine.Connection` "
+ "in order to "
+ "acquire an :class:`_reflection.Inspector`.",
+ )
+ def from_engine(cls, bind):
+ """Construct a new dialect-specific Inspector object from the given
+ engine or connection.
+
+ :param bind: a :class:`~sqlalchemy.engine.Connectable`,
+ which is typically an instance of
+ :class:`~sqlalchemy.engine.Engine` or
+ :class:`~sqlalchemy.engine.Connection`.
+
+ This method differs from a direct constructor call of
+ :class:`_reflection.Inspector` in that the
+ :class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to
+ provide a dialect-specific :class:`_reflection.Inspector` instance,
+ which may
+ provide additional methods.
+
+ See the example at :class:`_reflection.Inspector`.
+
+ """
+ return cls._construct(cls._init_legacy, bind)
+
+ @inspection._inspects(Connectable)
+ def _connectable_insp(bind):
+ # this method should not be used unless some unusual case
+ # has subclassed "Connectable"
+
+ return Inspector._construct(Inspector._init_legacy, bind)
+
+ @inspection._inspects(Engine)
+ def _engine_insp(bind):
+ return Inspector._construct(Inspector._init_engine, bind)
+
+ @inspection._inspects(Connection)
+ def _connection_insp(bind):
+ return Inspector._construct(Inspector._init_connection, bind)
+
+ @contextlib.contextmanager
+ def _operation_context(self):
+ """Return a context that optimizes for multiple operations on a single
+ transaction.
+
+ This essentially allows connect()/close() to be called if we detected
+ that we're against an :class:`_engine.Engine` and not a
+ :class:`_engine.Connection`.
+
+ """
+ if self._op_context_requires_connect:
+ conn = self.bind.connect()
+ else:
+ conn = self.bind
+ try:
+ yield conn
+ finally:
+ if self._op_context_requires_connect:
+ conn.close()
+
+ @contextlib.contextmanager
+ def _inspection_context(self):
+ """Return an :class:`_reflection.Inspector`
+ from this one that will run all
+ operations on a single connection.
+
+ """
+
+ with self._operation_context() as conn:
+ sub_insp = self._construct(self.__class__._init_connection, conn)
+ sub_insp.info_cache = self.info_cache
+ yield sub_insp
+
+ @property
+ def default_schema_name(self):
+ """Return the default schema name presented by the dialect
+ for the current engine's database user.
+
+ E.g. this is typically ``public`` for PostgreSQL and ``dbo``
+ for SQL Server.
+
+ """
+ return self.dialect.default_schema_name
+
+ def get_schema_names(self):
+ """Return all schema names."""
+
+ if hasattr(self.dialect, "get_schema_names"):
+ with self._operation_context() as conn:
+ return self.dialect.get_schema_names(
+ conn, info_cache=self.info_cache
+ )
+ return []
+
+ def get_table_names(self, schema=None):
+ """Return all table names in referred to within a particular schema.
+
+ The names are expected to be real tables only, not views.
+ Views are instead returned using the
+ :meth:`_reflection.Inspector.get_view_names`
+ method.
+
+
+ :param schema: Schema name. If ``schema`` is left at ``None``, the
+ database's default schema is
+ used, else the named schema is searched. If the database does not
+ support named schemas, behavior is undefined if ``schema`` is not
+ passed as ``None``. For special quoting, use :class:`.quoted_name`.
+
+ .. seealso::
+
+ :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names`
+
+ :attr:`_schema.MetaData.sorted_tables`
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ def has_table(self, table_name, schema=None):
+ """Return True if the backend has a table of the given name.
+
+
+ :param table_name: name of the table to check
+ :param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method
+ replaces the :meth:`_engine.Engine.has_table` method.
+
+ """
+ # TODO: info_cache?
+ with self._operation_context() as conn:
+ return self.dialect.has_table(conn, table_name, schema)
+
+ def has_sequence(self, sequence_name, schema=None):
+ """Return True if the backend has a table of the given name.
+
+ :param sequence_name: name of the table to check
+ :param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 1.4
+
+ """
+ # TODO: info_cache?
+ with self._operation_context() as conn:
+ return self.dialect.has_sequence(conn, sequence_name, schema)
+
+ def get_sorted_table_and_fkc_names(self, schema=None):
+ """Return dependency-sorted table and foreign key constraint names in
+ referred to within a particular schema.
+
+ This will yield 2-tuples of
+ ``(tablename, [(tname, fkname), (tname, fkname), ...])``
+ consisting of table names in CREATE order grouped with the foreign key
+ constraint names that are not detected as belonging to a cycle.
+ The final element
+ will be ``(None, [(tname, fkname), (tname, fkname), ..])``
+ which will consist of remaining
+ foreign key constraint names that would require a separate CREATE
+ step after-the-fact, based on dependencies between tables.
+
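+ A brief usage sketch, assuming ``insp`` is an existing
+ :class:`_reflection.Inspector`::
+
+     for tname, fkcs in insp.get_sorted_table_and_fkc_names():
+         if tname is not None:
+             print("create %s, then constraints %s" % (tname, fkcs))
+         else:
+             print("remaining constraints: %s" % (fkcs,))
+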
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_reflection.Inspector.get_table_names`
+
+ :func:`.sort_tables_and_constraints` - similar method which works
+ with an already-given :class:`_schema.MetaData`.
+
+ """
+
+ with self._operation_context() as conn:
+ tnames = self.dialect.get_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ tuples = set()
+ remaining_fkcs = set()
+
+ fknames_for_table = {}
+ for tname in tnames:
+ fkeys = self.get_foreign_keys(tname, schema)
+ fknames_for_table[tname] = set([fk["name"] for fk in fkeys])
+ for fkey in fkeys:
+ if tname != fkey["referred_table"]:
+ tuples.add((fkey["referred_table"], tname))
+ try:
+ candidate_sort = list(topological.sort(tuples, tnames))
+ except exc.CircularDependencyError as err:
+ for edge in err.edges:
+ tuples.remove(edge)
+ remaining_fkcs.update(
+ (edge[1], fkc) for fkc in fknames_for_table[edge[1]]
+ )
+
+ candidate_sort = list(topological.sort(tuples, tnames))
+ return [
+ (tname, fknames_for_table[tname].difference(remaining_fkcs))
+ for tname in candidate_sort
+ ] + [(None, list(remaining_fkcs))]
+
+ def get_temp_table_names(self):
+ """Return a list of temporary table names for the current bind.
+
+ This method is unsupported by most dialects; currently
+ only SQLite implements it.
+
+ .. versionadded:: 1.0.0
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_temp_table_names(
+ conn, info_cache=self.info_cache
+ )
+
+ def get_temp_view_names(self):
+ """Return a list of temporary view names for the current bind.
+
+ This method is unsupported by most dialects; currently
+ only SQLite implements it.
+
+ .. versionadded:: 1.0.0
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.get_temp_view_names(
+ conn, info_cache=self.info_cache
+ )
+
+ def get_table_options(self, table_name, schema=None, **kw):
+ """Return a dictionary of options specified when the table of the
+ given name was created.
+
+ This currently includes some options that apply to MySQL tables.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+ if hasattr(self.dialect, "get_table_options"):
+ with self._operation_context() as conn:
+ return self.dialect.get_table_options(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+ return {}
+
+ def get_view_names(self, schema=None):
+ """Return all view names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_view_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ def get_sequence_names(self, schema=None):
+ """Return all sequence names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_sequence_names(
+ conn, schema, info_cache=self.info_cache
+ )
+
+ def get_view_definition(self, view_name, schema=None):
+ """Return definition for `view_name`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_view_definition(
+ conn, view_name, schema, info_cache=self.info_cache
+ )
+
+ def get_columns(self, table_name, schema=None, **kw):
+ """Return information about columns in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ column information as a list of dicts with these keys:
+
+ * ``name`` - the column's name
+
+ * ``type`` - the type of this column; an instance of
+ :class:`~sqlalchemy.types.TypeEngine`
+
+ * ``nullable`` - boolean flag if the column is NULL or NOT NULL
+
+ * ``default`` - the column's server default value - this is returned
+ as a string SQL expression.
+
+ * ``autoincrement`` - indicates that the column is auto incremented -
+ this is returned as a boolean or 'auto'
+
+ * ``comment`` - (optional) the comment on the column. Only some
+ dialects return this key
+
+ * ``computed`` - (optional) when present it indicates that this column
+ is computed by the database. Only some dialects return this key.
+ Returned as a dict with the keys:
+
+ * ``sqltext`` - the expression used to generate this column returned
+ as a string SQL expression
+
+ * ``persisted`` - (optional) boolean that indicates if the column is
+ stored in the table
+
+ .. versionadded:: 1.3.16 - added support for computed reflection.
+
+ * ``identity`` - (optional) when present it indicates that this column
+ is a generated always column. Only some dialects return this key.
+ For a list of keywords on this dict see :class:`_schema.Identity`.
+
+ .. versionadded:: 1.4 - added support for identity column reflection.
+
+ * ``dialect_options`` - (optional) a dict with dialect specific options
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :return: list of dictionaries, each representing the definition of
+ a database column.
+
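+ A minimal usage sketch, assuming ``insp`` is an existing
+ :class:`_reflection.Inspector` and ``"user_account"`` is an illustrative
+ table name::
+
+     for col in insp.get_columns("user_account"):
+         print(col["name"], col["type"], col["nullable"])
+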
+ """
+
+ with self._operation_context() as conn:
+ col_defs = self.dialect.get_columns(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+ for col_def in col_defs:
+ # make this easy and only return instances for coltype
+ coltype = col_def["type"]
+ if not isinstance(coltype, TypeEngine):
+ col_def["type"] = coltype()
+ return col_defs
+
+ def get_pk_constraint(self, table_name, schema=None, **kw):
+ """Return information about primary key constraint on `table_name`.
+
+ Given a string `table_name`, and an optional string `schema`, return
+ primary key information as a dictionary with these keys:
+
+ * ``constrained_columns`` -
+ a list of column names that make up the primary key
+
+ * ``name`` -
+ optional name of the primary key constraint.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.get_pk_constraint(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_foreign_keys(self, table_name, schema=None, **kw):
+ """Return information about foreign_keys in `table_name`.
+
+ Given a string `table_name`, and an optional string `schema`, return
+ foreign key information as a list of dicts with these keys:
+
+ * ``constrained_columns`` -
+ a list of column names that make up the foreign key
+
+ * ``referred_schema`` -
+ the name of the referred schema
+
+ * ``referred_table`` -
+ the name of the referred table
+
+ * ``referred_columns`` -
+ a list of column names in the referred table that correspond to
+ constrained_columns
+
+ * ``name`` -
+ optional name of the foreign key constraint.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
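+ A brief usage sketch, assuming ``insp`` is an existing
+ :class:`_reflection.Inspector` and ``"address"`` is an illustrative
+ table name::
+
+     for fk in insp.get_foreign_keys("address"):
+         print(fk["constrained_columns"], "->",
+               fk["referred_table"], fk["referred_columns"])
+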
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_foreign_keys(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_indexes(self, table_name, schema=None, **kw):
+ """Return information about indexes in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ index information as a list of dicts with these keys:
+
+ * ``name`` -
+ the index's name
+
+ * ``column_names`` -
+ list of column names in order
+
+ * ``unique`` -
+ boolean
+
+ * ``column_sorting`` -
+ optional dict mapping column names to tuple of sort keywords,
+ which may include ``asc``, ``desc``, ``nulls_first``, ``nulls_last``.
+
+ .. versionadded:: 1.3.5
+
+ * ``dialect_options`` -
+ dict of dialect-specific index options. May not be present
+ for all dialects.
+
+ .. versionadded:: 1.0.0
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_indexes(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_unique_constraints(self, table_name, schema=None, **kw):
+ """Return information about unique constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ unique constraint information as a list of dicts with these keys:
+
+ * ``name`` -
+ the unique constraint's name
+
+ * ``column_names`` -
+ list of column names in order
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_unique_constraints(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_table_comment(self, table_name, schema=None, **kw):
+ """Return information about the table comment for ``table_name``.
+
+ Given a string ``table_name`` and an optional string ``schema``,
+ return table comment information as a dictionary with these keys:
+
+ * ``text`` -
+ text of the comment.
+
+ Raises ``NotImplementedError`` for a dialect that does not support
+ comments.
+
+ .. versionadded:: 1.2
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_table_comment(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_check_constraints(self, table_name, schema=None, **kw):
+ """Return information about check constraints in `table_name`.
+
+ Given a string `table_name` and an optional string `schema`, return
+ check constraint information as a list of dicts with these keys:
+
+ * ``name`` -
+ the check constraint's name
+
+ * ``sqltext`` -
+ the check constraint's SQL expression
+
+ * ``dialect_options`` -
+ may or may not be present; a dictionary with additional
+ dialect-specific options for this CHECK constraint
+
+ .. versionadded:: 1.3.8
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ .. versionadded:: 1.1.0
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_check_constraints(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ @util.deprecated_20(
+ ":meth:`_reflection.Inspector.reflecttable`",
+ "The :meth:`_reflection.Inspector.reflecttable` "
+ "method was renamed to "
+ ":meth:`_reflection.Inspector.reflect_table`. This deprecated alias "
+ "will be removed in a future release.",
+ )
+ def reflecttable(self, *args, **kwargs):
+ "See reflect_table. This method name is deprecated"
+ return self.reflect_table(*args, **kwargs)
+
+ def reflect_table(
+ self,
+ table,
+ include_columns,
+ exclude_columns=(),
+ resolve_fks=True,
+ _extend_on=None,
+ ):
+ """Given a :class:`_schema.Table` object, load its internal
+ constructs based on introspection.
+
+ This is the underlying method used by most dialects to produce
+ table reflection. Direct usage is like::
+
+ from sqlalchemy import create_engine, MetaData, Table
+ from sqlalchemy import inspect
+
+ engine = create_engine('...')
+ meta = MetaData()
+ user_table = Table('user', meta)
+ insp = inspect(engine)
+ insp.reflect_table(user_table, None)
+
+ .. versionchanged:: 1.4 Renamed from ``reflecttable`` to
+ ``reflect_table``
+
+ :param table: a :class:`~sqlalchemy.schema.Table` instance.
+ :param include_columns: a list of string column names to include
+ in the reflection process. If ``None``, all columns are reflected.
+
+ """
+
+ if _extend_on is not None:
+ if table in _extend_on:
+ return
+ else:
+ _extend_on.add(table)
+
+ dialect = self.bind.dialect
+
+ with self._operation_context() as conn:
+ schema = conn.schema_for_object(table)
+
+ table_name = table.name
+
+ # get table-level arguments that are specifically
+ # intended for reflection, e.g. oracle_resolve_synonyms.
+ # these are unconditionally passed to related Table
+ # objects
+ reflection_options = dict(
+ (k, table.dialect_kwargs.get(k))
+ for k in dialect.reflection_options
+ if k in table.dialect_kwargs
+ )
+
+ # reflect table options, like mysql_engine
+ tbl_opts = self.get_table_options(
+ table_name, schema, **table.dialect_kwargs
+ )
+ if tbl_opts:
+ # add additional kwargs to the Table if the dialect
+ # returned them
+ table._validate_dialect_kwargs(tbl_opts)
+
+ if util.py2k:
+ if isinstance(schema, str):
+ schema = schema.decode(dialect.encoding)
+ if isinstance(table_name, str):
+ table_name = table_name.decode(dialect.encoding)
+
+ found_table = False
+ cols_by_orig_name = {}
+
+ for col_d in self.get_columns(
+ table_name, schema, **table.dialect_kwargs
+ ):
+ found_table = True
+
+ self._reflect_column(
+ table,
+ col_d,
+ include_columns,
+ exclude_columns,
+ cols_by_orig_name,
+ )
+
+ # NOTE: support tables/views with no columns
+ if not found_table and not self.has_table(table_name, schema):
+ raise exc.NoSuchTableError(table_name)
+
+ self._reflect_pk(
+ table_name, schema, table, cols_by_orig_name, exclude_columns
+ )
+
+ self._reflect_fk(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on,
+ reflection_options,
+ )
+
+ self._reflect_indexes(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
+
+ self._reflect_unique_constraints(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
+
+ self._reflect_check_constraints(
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ )
+
+ self._reflect_table_comment(
+ table_name, schema, table, reflection_options
+ )
+
+ def _reflect_column(
+ self, table, col_d, include_columns, exclude_columns, cols_by_orig_name
+ ):
+
+ orig_name = col_d["name"]
+
+ table.metadata.dispatch.column_reflect(self, table, col_d)
+ table.dispatch.column_reflect(self, table, col_d)
+
+ # fetch name again as column_reflect is allowed to
+ # change it
+ name = col_d["name"]
+ if (include_columns and name not in include_columns) or (
+ exclude_columns and name in exclude_columns
+ ):
+ return
+
+ coltype = col_d["type"]
+
+ col_kw = dict(
+ (k, col_d[k])
+ for k in [
+ "nullable",
+ "autoincrement",
+ "quote",
+ "info",
+ "key",
+ "comment",
+ ]
+ if k in col_d
+ )
+
+ if "dialect_options" in col_d:
+ col_kw.update(col_d["dialect_options"])
+
+ colargs = []
+ if col_d.get("default") is not None:
+ default = col_d["default"]
+ if isinstance(default, sql.elements.TextClause):
+ default = sa_schema.DefaultClause(default, _reflected=True)
+ elif not isinstance(default, sa_schema.FetchedValue):
+ default = sa_schema.DefaultClause(
+ sql.text(col_d["default"]), _reflected=True
+ )
+
+ colargs.append(default)
+
+ if "computed" in col_d:
+ computed = sa_schema.Computed(**col_d["computed"])
+ colargs.append(computed)
+
+ if "identity" in col_d:
+ computed = sa_schema.Identity(**col_d["identity"])
+ colargs.append(computed)
+
+ if "sequence" in col_d:
+ self._reflect_col_sequence(col_d, colargs)
+
+ cols_by_orig_name[orig_name] = col = sa_schema.Column(
+ name, coltype, *colargs, **col_kw
+ )
+
+ if col.key in table.primary_key:
+ col.primary_key = True
+ table.append_column(col, replace_existing=True)
+
+ def _reflect_col_sequence(self, col_d, colargs):
+ if "sequence" in col_d:
+ # TODO: mssql and sybase are using this.
+ seq = col_d["sequence"]
+ sequence = sa_schema.Sequence(seq["name"], 1, 1)
+ if "start" in seq:
+ sequence.start = seq["start"]
+ if "increment" in seq:
+ sequence.increment = seq["increment"]
+ colargs.append(sequence)
+
+ def _reflect_pk(
+ self, table_name, schema, table, cols_by_orig_name, exclude_columns
+ ):
+ pk_cons = self.get_pk_constraint(
+ table_name, schema, **table.dialect_kwargs
+ )
+ if pk_cons:
+ pk_cols = [
+ cols_by_orig_name[pk]
+ for pk in pk_cons["constrained_columns"]
+ if pk in cols_by_orig_name and pk not in exclude_columns
+ ]
+
+ # update pk constraint name
+ table.primary_key.name = pk_cons.get("name")
+
+ # tell the PKConstraint to re-initialize
+ # its column collection
+ table.primary_key._reload(pk_cols)
+
+ def _reflect_fk(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on,
+ reflection_options,
+ ):
+ fkeys = self.get_foreign_keys(
+ table_name, schema, **table.dialect_kwargs
+ )
+ for fkey_d in fkeys:
+ conname = fkey_d["name"]
+ # look for columns by orig name in cols_by_orig_name,
+ # but support columns that are in-Python only as fallback
+ constrained_columns = [
+ cols_by_orig_name[c].key if c in cols_by_orig_name else c
+ for c in fkey_d["constrained_columns"]
+ ]
+
+ if (
+ exclude_columns
+ and set(constrained_columns).intersection(exclude_columns)
+ or (
+ include_columns
+ and set(constrained_columns).difference(include_columns)
+ )
+ ):
+ continue
+
+ referred_schema = fkey_d["referred_schema"]
+ referred_table = fkey_d["referred_table"]
+ referred_columns = fkey_d["referred_columns"]
+ refspec = []
+ if referred_schema is not None:
+ if resolve_fks:
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ schema=referred_schema,
+ autoload_with=self.bind,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
+ for column in referred_columns:
+ refspec.append(
+ ".".join([referred_schema, referred_table, column])
+ )
+ else:
+ if resolve_fks:
+ sa_schema.Table(
+ referred_table,
+ table.metadata,
+ autoload_with=self.bind,
+ schema=sa_schema.BLANK_SCHEMA,
+ _extend_on=_extend_on,
+ **reflection_options
+ )
+ for column in referred_columns:
+ refspec.append(".".join([referred_table, column]))
+ if "options" in fkey_d:
+ options = fkey_d["options"]
+ else:
+ options = {}
+
+ table.append_constraint(
+ sa_schema.ForeignKeyConstraint(
+ constrained_columns,
+ refspec,
+ conname,
+ link_to_name=True,
+ **options
+ )
+ )
+
+ _index_sort_exprs = [
+ ("asc", operators.asc_op),
+ ("desc", operators.desc_op),
+ ("nulls_first", operators.nulls_first_op),
+ ("nulls_last", operators.nulls_last_op),
+ ]
+
+ def _reflect_indexes(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
+ # Indexes
+ indexes = self.get_indexes(table_name, schema)
+ for index_d in indexes:
+ name = index_d["name"]
+ columns = index_d["column_names"]
+ column_sorting = index_d.get("column_sorting", {})
+ unique = index_d["unique"]
+ flavor = index_d.get("type", "index")
+ dialect_options = index_d.get("dialect_options", {})
+
+ duplicates = index_d.get("duplicates_constraint")
+ if include_columns and not set(columns).issubset(include_columns):
+ util.warn(
+ "Omitting %s key for (%s), key covers omitted columns."
+ % (flavor, ", ".join(columns))
+ )
+ continue
+ if duplicates:
+ continue
+ # look for columns by orig name in cols_by_orig_name,
+ # but support columns that are in-Python only as fallback
+ idx_cols = []
+ for c in columns:
+ try:
+ idx_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
+ except KeyError:
+ util.warn(
+ "%s key '%s' was not located in "
+ "columns for table '%s'" % (flavor, c, table_name)
+ )
+ continue
+ c_sorting = column_sorting.get(c, ())
+ for k, op in self._index_sort_exprs:
+ if k in c_sorting:
+ idx_col = op(idx_col)
+ idx_cols.append(idx_col)
+
+ sa_schema.Index(
+ name,
+ *idx_cols,
+ _table=table,
+ **dict(list(dialect_options.items()) + [("unique", unique)])
+ )
+
+ def _reflect_unique_constraints(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
+
+ # Unique Constraints
+ try:
+ constraints = self.get_unique_constraints(table_name, schema)
+ except NotImplementedError:
+ # optional dialect feature
+ return
+
+ for const_d in constraints:
+ conname = const_d["name"]
+ columns = const_d["column_names"]
+ duplicates = const_d.get("duplicates_index")
+ if include_columns and not set(columns).issubset(include_columns):
+ util.warn(
+ "Omitting unique constraint key for (%s), "
+ "key covers omitted columns." % ", ".join(columns)
+ )
+ continue
+ if duplicates:
+ continue
+ # look for columns by orig name in cols_by_orig_name,
+ # but support columns that are in-Python only as fallback
+ constrained_cols = []
+ for c in columns:
+ try:
+ constrained_col = (
+ cols_by_orig_name[c]
+ if c in cols_by_orig_name
+ else table.c[c]
+ )
+ except KeyError:
+ util.warn(
+ "unique constraint key '%s' was not located in "
+ "columns for table '%s'" % (c, table_name)
+ )
+ else:
+ constrained_cols.append(constrained_col)
+ table.append_constraint(
+ sa_schema.UniqueConstraint(*constrained_cols, name=conname)
+ )
+
+ def _reflect_check_constraints(
+ self,
+ table_name,
+ schema,
+ table,
+ cols_by_orig_name,
+ include_columns,
+ exclude_columns,
+ reflection_options,
+ ):
+ try:
+ constraints = self.get_check_constraints(table_name, schema)
+ except NotImplementedError:
+ # optional dialect feature
+ return
+
+ for const_d in constraints:
+ table.append_constraint(sa_schema.CheckConstraint(**const_d))
+
+ def _reflect_table_comment(
+ self, table_name, schema, table, reflection_options
+ ):
+ try:
+ comment_dict = self.get_table_comment(table_name, schema)
+ except NotImplementedError:
+ return
+ else:
+ table.comment = comment_dict.get("text", None)
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
new file mode 100644
index 0000000..1fd4e1c
--- /dev/null
+++ b/lib/sqlalchemy/engine/result.py
@@ -0,0 +1,1857 @@
+# engine/result.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define generic result set constructs."""
+
+
+import functools
+import itertools
+import operator
+
+from .row import _baserow_usecext
+from .row import Row
+from .. import exc
+from .. import util
+from ..sql.base import _generative
+from ..sql.base import HasMemoized
+from ..sql.base import InPlaceGenerative
+from ..util import collections_abc
+from ..util import py2k
+
+
+if _baserow_usecext:
+ from sqlalchemy.cresultproxy import tuplegetter
+
+ _row_as_tuple = tuplegetter
+else:
+
+ def tuplegetter(*indexes):
+ it = operator.itemgetter(*indexes)
+
+ if len(indexes) > 1:
+ return it
+ else:
+ return lambda row: (it(row),)
+
+ def _row_as_tuple(*indexes):
+ # circumvent LegacyRow.__getitem__ pointing to
+ # _get_by_key_impl_mapping for now. otherwise we could
+ # use itemgetter
+ getters = [
+ operator.methodcaller("_get_by_int_impl", index)
+ for index in indexes
+ ]
+ return lambda rec: tuple([getter(rec) for getter in getters])
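+
+    # Behavior sketch for the pure-Python fallbacks above:
+    # tuplegetter(1, 3)(row) -> (row[1], row[3]); a single index, e.g.
+    # tuplegetter(2)(row), still yields a one-element tuple (row[2],).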
+
+
+class ResultMetaData(object):
+ """Base for metadata about result rows."""
+
+ __slots__ = ()
+
+ _tuplefilter = None
+ _translated_indexes = None
+ _unique_filters = None
+
+ @property
+ def keys(self):
+ return RMKeyView(self)
+
+ def _has_key(self, key):
+ raise NotImplementedError()
+
+ def _for_freeze(self):
+ raise NotImplementedError()
+
+ def _key_fallback(self, key, err, raiseerr=True):
+ assert raiseerr
+ util.raise_(KeyError(key), replace_context=err)
+
+ def _warn_for_nonint(self, key):
+ util.warn_deprecated_20(
+ "Retrieving row members using strings or other non-integers is "
+ "deprecated; use row._mapping for a dictionary interface "
+ "to the row"
+ )
+
+ def _raise_for_nonint(self, key):
+ raise TypeError(
+ "TypeError: tuple indices must be integers or slices, not %s"
+ % type(key).__name__
+ )
+
+ def _index_for_key(self, keys, raiseerr):
+ raise NotImplementedError()
+
+ def _metadata_for_keys(self, key):
+ raise NotImplementedError()
+
+ def _reduce(self, keys):
+ raise NotImplementedError()
+
+ def _getter(self, key, raiseerr=True):
+
+ index = self._index_for_key(key, raiseerr)
+
+ if index is not None:
+ return operator.itemgetter(index)
+ else:
+ return None
+
+ def _row_as_tuple_getter(self, keys):
+ indexes = self._indexes_for_keys(keys)
+ return _row_as_tuple(*indexes)
+
+
+class RMKeyView(collections_abc.KeysView):
+ __slots__ = ("_parent", "_keys")
+
+ def __init__(self, parent):
+ self._parent = parent
+ self._keys = [k for k in parent._keys if k is not None]
+
+ def __len__(self):
+ return len(self._keys)
+
+ def __repr__(self):
+ return "{0.__class__.__name__}({0._keys!r})".format(self)
+
+ def __iter__(self):
+ return iter(self._keys)
+
+ def __contains__(self, item):
+ if not _baserow_usecext and isinstance(item, int):
+ return False
+
+ # note this also includes special key fallback behaviors
+ # which also don't seem to be tested in test_resultset right now
+ return self._parent._has_key(item)
+
+ def __eq__(self, other):
+ return list(other) == list(self)
+
+ def __ne__(self, other):
+ return list(other) != list(self)
+
+
+class SimpleResultMetaData(ResultMetaData):
+ """result metadata for in-memory collections."""
+
+ __slots__ = (
+ "_keys",
+ "_keymap",
+ "_processors",
+ "_tuplefilter",
+ "_translated_indexes",
+ "_unique_filters",
+ )
+
+ def __init__(
+ self,
+ keys,
+ extra=None,
+ _processors=None,
+ _tuplefilter=None,
+ _translated_indexes=None,
+ _unique_filters=None,
+ ):
+ self._keys = list(keys)
+ self._tuplefilter = _tuplefilter
+ self._translated_indexes = _translated_indexes
+ self._unique_filters = _unique_filters
+
+ if extra:
+ recs_names = [
+ (
+ (name,) + extras,
+ (index, name, extras),
+ )
+ for index, (name, extras) in enumerate(zip(self._keys, extra))
+ ]
+ else:
+ recs_names = [
+ ((name,), (index, name, ()))
+ for index, name in enumerate(self._keys)
+ ]
+
+ self._keymap = {key: rec for keys, rec in recs_names for key in keys}
+
+ self._processors = _processors
+
+ def _has_key(self, key):
+ return key in self._keymap
+
+ def _for_freeze(self):
+ unique_filters = self._unique_filters
+ if unique_filters and self._tuplefilter:
+ unique_filters = self._tuplefilter(unique_filters)
+
+ # TODO: are we freezing the result with or without uniqueness
+ # applied?
+ return SimpleResultMetaData(
+ self._keys,
+ extra=[self._keymap[key][2] for key in self._keys],
+ _unique_filters=unique_filters,
+ )
+
+ def __getstate__(self):
+ return {
+ "_keys": self._keys,
+ "_translated_indexes": self._translated_indexes,
+ }
+
+ def __setstate__(self, state):
+ if state["_translated_indexes"]:
+ _translated_indexes = state["_translated_indexes"]
+ _tuplefilter = tuplegetter(*_translated_indexes)
+ else:
+ _translated_indexes = _tuplefilter = None
+ self.__init__(
+ state["_keys"],
+ _translated_indexes=_translated_indexes,
+ _tuplefilter=_tuplefilter,
+ )
+
+ def _contains(self, value, row):
+ return value in row._data
+
+ def _index_for_key(self, key, raiseerr=True):
+ if int in key.__class__.__mro__:
+ key = self._keys[key]
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._key_fallback(key, ke, raiseerr)
+
+ return rec[0]
+
+ def _indexes_for_keys(self, keys):
+ return [self._keymap[key][0] for key in keys]
+
+ def _metadata_for_keys(self, keys):
+ for key in keys:
+ if int in key.__class__.__mro__:
+ key = self._keys[key]
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._key_fallback(key, ke, True)
+
+ yield rec
+
+ def _reduce(self, keys):
+ try:
+ metadata_for_keys = [
+ self._keymap[
+ self._keys[key] if int in key.__class__.__mro__ else key
+ ]
+ for key in keys
+ ]
+ except KeyError as ke:
+ self._key_fallback(ke.args[0], ke, True)
+
+ indexes, new_keys, extra = zip(*metadata_for_keys)
+
+ if self._translated_indexes:
+ indexes = [self._translated_indexes[idx] for idx in indexes]
+
+ tup = tuplegetter(*indexes)
+
+ new_metadata = SimpleResultMetaData(
+ new_keys,
+ extra=extra,
+ _tuplefilter=tup,
+ _translated_indexes=indexes,
+ _processors=self._processors,
+ _unique_filters=self._unique_filters,
+ )
+
+ return new_metadata
+
+
+def result_tuple(fields, extra=None):
+ parent = SimpleResultMetaData(fields, extra)
+ return functools.partial(
+ Row, parent, parent._processors, parent._keymap, Row._default_key_style
+ )
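
A rough sketch of how the internal ``result_tuple()`` helper builds lightweight :class:`.Row` objects around in-memory data; this is not a public API, and the field names and values below are invented for illustration::

    from sqlalchemy.engine.result import result_tuple

    make_row = result_tuple(["id", "name"])
    row = make_row((1, "spongebob"))

    print(row.id, row.name)   # 1 spongebob
    print(row._asdict())      # {'id': 1, 'name': 'spongebob'}
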
+
+
+# a symbol that indicates to internal Result methods that
+# "no row is returned". We can't use None for those cases where a scalar
+# filter is applied to rows.
+_NO_ROW = util.symbol("NO_ROW")
+
+
+class ResultInternal(InPlaceGenerative):
+ _real_result = None
+ _generate_rows = True
+ _unique_filter_state = None
+ _post_creational_filter = None
+ _is_cursor = False
+
+ @HasMemoized.memoized_attribute
+ def _row_getter(self):
+ real_result = self._real_result if self._real_result else self
+
+ if real_result._source_supports_scalars:
+ if not self._generate_rows:
+ return None
+ else:
+ _proc = real_result._process_row
+
+ def process_row(
+ metadata, processors, keymap, key_style, scalar_obj
+ ):
+ return _proc(
+ metadata, processors, keymap, key_style, (scalar_obj,)
+ )
+
+ else:
+ process_row = real_result._process_row
+
+ key_style = real_result._process_row._default_key_style
+ metadata = self._metadata
+
+ keymap = metadata._keymap
+ processors = metadata._processors
+ tf = metadata._tuplefilter
+
+ if tf and not real_result._source_supports_scalars:
+ if processors:
+ processors = tf(processors)
+
+ _make_row_orig = functools.partial(
+ process_row, metadata, processors, keymap, key_style
+ )
+
+ def make_row(row):
+ return _make_row_orig(tf(row))
+
+ else:
+ make_row = functools.partial(
+ process_row, metadata, processors, keymap, key_style
+ )
+
+ fns = ()
+
+ if real_result._row_logging_fn:
+ fns = (real_result._row_logging_fn,)
+ else:
+ fns = ()
+
+ if fns:
+ _make_row = make_row
+
+ def make_row(row):
+ row = _make_row(row)
+ for fn in fns:
+ row = fn(row)
+ return row
+
+ return make_row
+
+ @HasMemoized.memoized_attribute
+ def _iterator_getter(self):
+
+ make_row = self._row_getter
+
+ post_creational_filter = self._post_creational_filter
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
+ def iterrows(self):
+ for row in self._fetchiter_impl():
+ obj = make_row(row) if make_row else row
+ hashed = strategy(obj) if strategy else obj
+ if hashed in uniques:
+ continue
+ uniques.add(hashed)
+ if post_creational_filter:
+ obj = post_creational_filter(obj)
+ yield obj
+
+ else:
+
+ def iterrows(self):
+ for row in self._fetchiter_impl():
+ row = make_row(row) if make_row else row
+ if post_creational_filter:
+ row = post_creational_filter(row)
+ yield row
+
+ return iterrows
+
+ def _raw_all_rows(self):
+ make_row = self._row_getter
+ rows = self._fetchall_impl()
+ return [make_row(row) for row in rows]
+
+ def _allrows(self):
+
+ post_creational_filter = self._post_creational_filter
+
+ make_row = self._row_getter
+
+ rows = self._fetchall_impl()
+ if make_row:
+ made_rows = [make_row(row) for row in rows]
+ else:
+ made_rows = rows
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
+ rows = [
+ made_row
+ for made_row, sig_row in [
+ (
+ made_row,
+ strategy(made_row) if strategy else made_row,
+ )
+ for made_row in made_rows
+ ]
+ if sig_row not in uniques and not uniques.add(sig_row)
+ ]
+ else:
+ rows = made_rows
+
+ if post_creational_filter:
+ rows = [post_creational_filter(row) for row in rows]
+ return rows
+
+ @HasMemoized.memoized_attribute
+ def _onerow_getter(self):
+ make_row = self._row_getter
+
+ post_creational_filter = self._post_creational_filter
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
+ def onerow(self):
+ _onerow = self._fetchone_impl
+ while True:
+ row = _onerow()
+ if row is None:
+ return _NO_ROW
+ else:
+ obj = make_row(row) if make_row else row
+ hashed = strategy(obj) if strategy else obj
+ if hashed in uniques:
+ continue
+ else:
+ uniques.add(hashed)
+ if post_creational_filter:
+ obj = post_creational_filter(obj)
+ return obj
+
+ else:
+
+ def onerow(self):
+ row = self._fetchone_impl()
+ if row is None:
+ return _NO_ROW
+ else:
+ row = make_row(row) if make_row else row
+ if post_creational_filter:
+ row = post_creational_filter(row)
+ return row
+
+ return onerow
+
+ @HasMemoized.memoized_attribute
+ def _manyrow_getter(self):
+ make_row = self._row_getter
+
+ post_creational_filter = self._post_creational_filter
+
+ if self._unique_filter_state:
+ uniques, strategy = self._unique_strategy
+
+ def filterrows(make_row, rows, strategy, uniques):
+ if make_row:
+ rows = [make_row(row) for row in rows]
+
+ if strategy:
+ made_rows = (
+ (made_row, strategy(made_row)) for made_row in rows
+ )
+ else:
+ made_rows = ((made_row, made_row) for made_row in rows)
+ return [
+ made_row
+ for made_row, sig_row in made_rows
+ if sig_row not in uniques and not uniques.add(sig_row)
+ ]
+
+ def manyrows(self, num):
+ collect = []
+
+ _manyrows = self._fetchmany_impl
+
+ if num is None:
+                    # if None is passed, we don't know the default
+                    # manyrows number; the DBAPI has this as cursor.arraysize,
+                    # and different DBAPIs / fetch strategies may differ.
+                    # do a fetch to find what the number is. if there are
+                    # only fewer rows left, then it doesn't matter.
+ real_result = (
+ self._real_result if self._real_result else self
+ )
+ if real_result._yield_per:
+ num_required = num = real_result._yield_per
+ else:
+ rows = _manyrows(num)
+ num = len(rows)
+ collect.extend(
+ filterrows(make_row, rows, strategy, uniques)
+ )
+ num_required = num - len(collect)
+ else:
+ num_required = num
+
+ while num_required:
+ rows = _manyrows(num_required)
+ if not rows:
+ break
+
+ collect.extend(
+ filterrows(make_row, rows, strategy, uniques)
+ )
+ num_required = num - len(collect)
+
+ if post_creational_filter:
+ collect = [post_creational_filter(row) for row in collect]
+ return collect
+
+ else:
+
+ def manyrows(self, num):
+ if num is None:
+ real_result = (
+ self._real_result if self._real_result else self
+ )
+ num = real_result._yield_per
+
+ rows = self._fetchmany_impl(num)
+ if make_row:
+ rows = [make_row(row) for row in rows]
+ if post_creational_filter:
+ rows = [post_creational_filter(row) for row in rows]
+ return rows
+
+ return manyrows
+
+ def _only_one_row(
+ self,
+ raise_for_second_row,
+ raise_for_none,
+ scalar,
+ ):
+ onerow = self._fetchone_impl
+
+ row = onerow(hard_close=True)
+ if row is None:
+ if raise_for_none:
+ raise exc.NoResultFound(
+ "No row was found when one was required"
+ )
+ else:
+ return None
+
+ if scalar and self._source_supports_scalars:
+ self._generate_rows = False
+ make_row = None
+ else:
+ make_row = self._row_getter
+
+ try:
+ row = make_row(row) if make_row else row
+ except:
+ self._soft_close(hard=True)
+ raise
+
+ if raise_for_second_row:
+ if self._unique_filter_state:
+ # for no second row but uniqueness, need to essentially
+ # consume the entire result :(
+ uniques, strategy = self._unique_strategy
+
+ existing_row_hash = strategy(row) if strategy else row
+
+ while True:
+ next_row = onerow(hard_close=True)
+ if next_row is None:
+ next_row = _NO_ROW
+ break
+
+ try:
+ next_row = make_row(next_row) if make_row else next_row
+
+ if strategy:
+ if existing_row_hash == strategy(next_row):
+ continue
+ elif row == next_row:
+ continue
+ # here, we have a row and it's different
+ break
+ except:
+ self._soft_close(hard=True)
+ raise
+ else:
+ next_row = onerow(hard_close=True)
+ if next_row is None:
+ next_row = _NO_ROW
+
+ if next_row is not _NO_ROW:
+ self._soft_close(hard=True)
+ raise exc.MultipleResultsFound(
+ "Multiple rows were found when exactly one was required"
+ if raise_for_none
+ else "Multiple rows were found when one or none "
+ "was required"
+ )
+ else:
+ next_row = _NO_ROW
+ # if we checked for second row then that would have
+ # closed us :)
+ self._soft_close(hard=True)
+
+ if not scalar:
+ post_creational_filter = self._post_creational_filter
+ if post_creational_filter:
+ row = post_creational_filter(row)
+
+ if scalar and make_row:
+ return row[0]
+ else:
+ return row
+
+ def _iter_impl(self):
+ return self._iterator_getter(self)
+
+ def _next_impl(self):
+ row = self._onerow_getter(self)
+ if row is _NO_ROW:
+ raise StopIteration()
+ else:
+ return row
+
+ @_generative
+ def _column_slices(self, indexes):
+ real_result = self._real_result if self._real_result else self
+
+ if real_result._source_supports_scalars and len(indexes) == 1:
+ util.warn_deprecated(
+ "The Result.columns() method has a bug in SQLAlchemy 1.4 that "
+ "is causing it to yield scalar values, rather than Row "
+ "objects, in the case where a single index is passed and the "
+ "result is against ORM mapped objects. In SQLAlchemy 2.0, "
+                "Result will continue to yield Row objects in this scenario. "
+ "Use the Result.scalars() method to yield scalar values.",
+ "2.0",
+ )
+ self._generate_rows = False
+ else:
+ self._generate_rows = True
+ self._metadata = self._metadata._reduce(indexes)
+
+ @HasMemoized.memoized_attribute
+ def _unique_strategy(self):
+ uniques, strategy = self._unique_filter_state
+
+ real_result = (
+ self._real_result if self._real_result is not None else self
+ )
+
+ if not strategy and self._metadata._unique_filters:
+ if (
+ real_result._source_supports_scalars
+ and not self._generate_rows
+ ):
+ strategy = self._metadata._unique_filters[0]
+ else:
+ filters = self._metadata._unique_filters
+ if self._metadata._tuplefilter:
+ filters = self._metadata._tuplefilter(filters)
+
+ strategy = operator.methodcaller("_filter_on_values", filters)
+ return uniques, strategy
+
+
+class _WithKeys(object):
+ # used mainly to share documentation on the keys method.
+ # py2k does not allow overriding the __doc__ attribute.
+ def keys(self):
+ """Return an iterable view which yields the string keys that would
+ be represented by each :class:`.Row`.
+
+ The keys can represent the labels of the columns returned by a core
+ statement or the names of the orm classes returned by an orm
+ execution.
+
+        The view can also be tested for key containment using the Python
+        ``in`` operator, which will test both for the string keys represented
+        in the view and for alternate keys such as column objects.
+
+ .. versionchanged:: 1.4 a key view object is returned rather than a
+ plain list.
+
+
+ """
+ return self._metadata.keys
+
+
+class Result(_WithKeys, ResultInternal):
+ """Represent a set of database results.
+
+ .. versionadded:: 1.4 The :class:`.Result` object provides a completely
+ updated usage model and calling facade for SQLAlchemy Core and
+ SQLAlchemy ORM. In Core, it forms the basis of the
+ :class:`.CursorResult` object which replaces the previous
+ :class:`.ResultProxy` interface. When using the ORM, a higher level
+ object called :class:`.ChunkedIteratorResult` is normally used.
+
+ .. note:: In SQLAlchemy 1.4 and above, this object is
+ used for ORM results returned by :meth:`_orm.Session.execute`, which can
+ yield instances of ORM mapped objects either individually or within
+ tuple-like rows. Note that the :class:`_result.Result` object does not
+ deduplicate instances or rows automatically as is the case with the
+ legacy :class:`_orm.Query` object. For in-Python de-duplication of
+ instances or rows, use the :meth:`_result.Result.unique` modifier
+ method.
+
+ .. seealso::
+
+ :ref:`tutorial_fetching_rows` - in the :doc:`/tutorial/index`
+
+ """
+
+ _process_row = Row
+
+ _row_logging_fn = None
+
+ _source_supports_scalars = False
+
+ _yield_per = None
+
+ _attributes = util.immutabledict()
+
+ def __init__(self, cursor_metadata):
+ self._metadata = cursor_metadata
+
+ def _soft_close(self, hard=False):
+ raise NotImplementedError()
+
+ def close(self):
+ """close this :class:`_result.Result`.
+
+ The behavior of this method is implementation specific, and is
+        not implemented by default. The method should generally release
+        the resources in use by the result object and also cause any
+ subsequent iteration or row fetching to raise
+ :class:`.ResourceClosedError`.
+
+ .. versionadded:: 1.4.27 - ``.close()`` was previously not generally
+ available for all :class:`_result.Result` classes, instead only
+ being available on the :class:`_engine.CursorResult` returned for
+ Core statement executions. As most other result objects, namely the
+ ones used by the ORM, are proxying a :class:`_engine.CursorResult`
+ in any case, this allows the underlying cursor result to be closed
+ from the outside facade for the case when the ORM query is using
+ the ``yield_per`` execution option where it does not immediately
+ exhaust and autoclose the database cursor.
+
+ """
+ self._soft_close(hard=True)
+
+ @_generative
+ def yield_per(self, num):
+ """Configure the row-fetching strategy to fetch ``num`` rows at a time.
+
+ This impacts the underlying behavior of the result when iterating over
+ the result object, or otherwise making use of methods such as
+        :meth:`_engine.Result.fetchone` that return one row at a time. Data
+        from the underlying cursor or other data source will be buffered up to
+        this many rows in memory, and the buffered collection will then be
+        yielded out one row at a time or as many rows as are requested. Each
+        time the buffer clears, it will be refreshed to this many rows, or to
+        as many rows as remain if fewer are left.
+
+ The :meth:`_engine.Result.yield_per` method is generally used in
+ conjunction with the
+ :paramref:`_engine.Connection.execution_options.stream_results`
+ execution option, which will allow the database dialect in use to make
+ use of a server side cursor, if the DBAPI supports a specific "server
+ side cursor" mode separate from its default mode of operation.
+
+ .. tip::
+
+ Consider using the
+ :paramref:`_engine.Connection.execution_options.yield_per`
+ execution option, which will simultaneously set
+ :paramref:`_engine.Connection.execution_options.stream_results`
+ to ensure the use of server side cursors, as well as automatically
+ invoke the :meth:`_engine.Result.yield_per` method to establish
+ a fixed row buffer size at once.
+
+ The :paramref:`_engine.Connection.execution_options.yield_per`
+ execution option is available for ORM operations, with
+ :class:`_orm.Session`-oriented use described at
+ :ref:`orm_queryguide_yield_per`. The Core-only version which works
+ with :class:`_engine.Connection` is new as of SQLAlchemy 1.4.40.
+
+ .. versionadded:: 1.4
+
+ :param num: number of rows to fetch each time the buffer is refilled.
+ If set to a value below 1, fetches all rows for the next buffer.
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - describes Core behavior for
+ :meth:`_engine.Result.yield_per`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+
+ """
+ self._yield_per = num
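
A minimal usage sketch for :meth:`_engine.Result.yield_per` together with ``stream_results``; the in-memory SQLite database is invented purely for illustration, and the buffering benefit only materializes with drivers that support server side cursors::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE big_table (x INTEGER)"))
        conn.execute(text("INSERT INTO big_table VALUES (1), (2), (3)"))

        result = conn.execution_options(stream_results=True).execute(
            text("SELECT x FROM big_table")
        )
        # buffer up to 500 rows at a time while iterating
        for row in result.yield_per(500):
            print(row.x)
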
+
+ @_generative
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_engine.Result`.
+
+        When this filter is applied with no arguments, the rows or objects
+        returned will be filtered such that each row is returned uniquely. The
+        algorithm used to determine this uniqueness is by default the Python
+        hashing identity of the whole tuple. In some cases a specialized
+        per-entity hashing scheme may be used; for example, when using the
+        ORM, a scheme is applied which works against the primary key identity
+        of returned objects.
+
+ The unique filter is applied **after all other filters**, which means
+ if the columns returned have been refined using a method such as the
+ :meth:`_engine.Result.columns` or :meth:`_engine.Result.scalars`
+ method, the uniquing is applied to **only the column or columns
+ returned**. This occurs regardless of the order in which these
+ methods have been called upon the :class:`_engine.Result` object.
+
+ The unique filter also changes the calculus used for methods like
+ :meth:`_engine.Result.fetchmany` and :meth:`_engine.Result.partitions`.
+ When using :meth:`_engine.Result.unique`, these methods will continue
+ to yield the number of rows or objects requested, after uniquing
+        has been applied. However, this necessarily impacts the buffering
+        behavior of the underlying cursor or datasource, such that multiple
+        underlying calls to ``cursor.fetchmany()`` may be necessary in order
+        to accumulate enough objects to provide a unique collection of the
+        requested size.
+
+ :param strategy: a callable that will be applied to rows or objects
+ being iterated, which should return an object that represents the
+ unique value of the row. A Python ``set()`` is used to store
+ these identities. If not passed, a default uniqueness strategy
+ is used which may have been assembled by the source of this
+ :class:`_engine.Result` object.
+
+ """
+ self._unique_filter_state = (set(), strategy)
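
A small sketch of :meth:`_engine.Result.unique` against an in-memory result assembled from the internal constructs defined earlier in this module; the data is illustrative only::

    from sqlalchemy.engine.result import IteratorResult, SimpleResultMetaData

    result = IteratorResult(
        SimpleResultMetaData(["color"]),
        iter([("red",), ("red",), ("blue",)]),
    )
    print(result.unique().scalars().all())   # ['red', 'blue']
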
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row.
+
+ This method may be used to limit the columns returned as well
+ as to reorder them. The given list of expressions are normally
+ a series of integers or string key names. They may also be
+ appropriate :class:`.ColumnElement` objects which correspond to
+ a given statement construct.
+
+ E.g.::
+
+ statement = select(table.c.x, table.c.y, table.c.z)
+ result = connection.execute(statement)
+
+ for z, y in result.columns('z', 'y'):
+ # ...
+
+
+ Example of using the column objects from the statement itself::
+
+ for z, y in result.columns(
+ statement.selected_columns.c.z,
+ statement.selected_columns.c.y
+ ):
+ # ...
+
+ .. versionadded:: 1.4
+
+ :param \*col_expressions: indicates columns to be returned. Elements
+ may be integer row indexes, string column names, or appropriate
+ :class:`.ColumnElement` objects corresponding to a select construct.
+
+ :return: this :class:`_engine.Result` object with the modifications
+ given.
+
+ """
+ return self._column_slices(col_expressions)
+
+ def scalars(self, index=0):
+ """Return a :class:`_result.ScalarResult` filtering object which
+ will return single elements rather than :class:`_row.Row` objects.
+
+ E.g.::
+
+ >>> result = conn.execute(text("select int_id from table"))
+ >>> result.scalars().all()
+ [1, 2, 3]
+
+ When results are fetched from the :class:`_result.ScalarResult`
+ filtering object, the single column-row that would be returned by the
+ :class:`_result.Result` is instead returned as the column's value.
+
+ .. versionadded:: 1.4
+
+ :param index: integer or row key indicating the column to be fetched
+ from each row, defaults to ``0`` indicating the first column.
+
+ :return: a new :class:`_result.ScalarResult` filtering object referring
+ to this :class:`_result.Result` object.
+
+ """
+ return ScalarResult(self, index)
+
+ def _getter(self, key, raiseerr=True):
+ """return a callable that will retrieve the given key from a
+ :class:`.Row`.
+
+ """
+ if self._source_supports_scalars:
+ raise NotImplementedError(
+ "can't use this function in 'only scalars' mode"
+ )
+ return self._metadata._getter(key, raiseerr)
+
+ def _tuple_getter(self, keys):
+ """return a callable that will retrieve the given keys from a
+ :class:`.Row`.
+
+ """
+ if self._source_supports_scalars:
+ raise NotImplementedError(
+ "can't use this function in 'only scalars' mode"
+ )
+ return self._metadata._row_as_tuple_getter(keys)
+
+ def mappings(self):
+ """Apply a mappings filter to returned rows, returning an instance of
+ :class:`_result.MappingResult`.
+
+ When this filter is applied, fetching rows will return
+ :class:`.RowMapping` objects instead of :class:`.Row` objects.
+
+ .. versionadded:: 1.4
+
+ :return: a new :class:`_result.MappingResult` filtering object
+ referring to this :class:`_result.Result` object.
+
+ """
+
+ return MappingResult(self)
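
A short sketch of :meth:`_engine.Result.mappings`; the SQLite table and rows are invented for illustration::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE users (id INTEGER, name TEXT)"))
        conn.execute(text("INSERT INTO users VALUES (1, 'ed'), (2, 'wendy')"))

        for mapping in conn.execute(text("SELECT id, name FROM users")).mappings():
            print(mapping["id"], mapping["name"])   # dictionary-style access
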
+
+ def _raw_row_iterator(self):
+ """Return a safe iterator that yields raw row data.
+
+ This is used by the :meth:`._engine.Result.merge` method
+ to merge multiple compatible results together.
+
+ """
+ raise NotImplementedError()
+
+ def _fetchiter_impl(self):
+ raise NotImplementedError()
+
+ def _fetchone_impl(self, hard_close=False):
+ raise NotImplementedError()
+
+ def _fetchall_impl(self):
+ raise NotImplementedError()
+
+ def _fetchmany_impl(self, size=None):
+ raise NotImplementedError()
+
+ def __iter__(self):
+ return self._iter_impl()
+
+ def __next__(self):
+ return self._next_impl()
+
+ if py2k:
+
+ def next(self): # noqa
+ return self._next_impl()
+
+ def partitions(self, size=None):
+ """Iterate through sub-lists of rows of the size given.
+
+        Each list will be of the size given, excluding the last list to
+        be yielded, which may have a smaller number of rows. No empty
+        lists will be yielded.
+
+ The result object is automatically closed when the iterator
+ is fully consumed.
+
+ Note that the backend driver will usually buffer the entire result
+ ahead of time unless the
+ :paramref:`.Connection.execution_options.stream_results` execution
+ option is used indicating that the driver should not pre-buffer
+ results, if possible. Not all drivers support this option and
+        the option is silently ignored for those that do not.
+
+ When using the ORM, the :meth:`_engine.Result.partitions` method
+ is typically more effective from a memory perspective when it is
+ combined with use of the
+ :ref:`yield_per execution option <orm_queryguide_yield_per>`,
+ which instructs both the DBAPI driver to use server side cursors,
+ if available, as well as instructs the ORM loading internals to only
+ build a certain amount of ORM objects from a result at a time before
+ yielding them out.
+
+ .. versionadded:: 1.4
+
+ :param size: indicate the maximum number of rows to be present
+ in each list yielded. If None, makes use of the value set by
+         the :meth:`_engine.Result.yield_per` method, if it was called,
+         or the :paramref:`_engine.Connection.execution_options.yield_per`
+         execution option, which is equivalent in this regard. If
+         yield_per was not set, it makes use of the
+ :meth:`_engine.Result.fetchmany` default, which may be backend
+ specific and not well defined.
+
+ :return: iterator of lists
+
+ .. seealso::
+
+ :ref:`engine_stream_results`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = getter(self, size)
+ if partition:
+ yield partition
+ else:
+ break
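
A brief sketch of :meth:`_engine.Result.partitions` with a batch size of two; the table and data are invented for illustration::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE nums (n INTEGER)"))
        conn.execute(text("INSERT INTO nums VALUES (1), (2), (3), (4), (5)"))

        result = conn.execute(text("SELECT n FROM nums"))
        for batch in result.partitions(2):
            print([row.n for row in batch])   # [1, 2], [3, 4], [5]
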
+
+ def fetchall(self):
+ """A synonym for the :meth:`_engine.Result.all` method."""
+
+ return self._allrows()
+
+ def fetchone(self):
+ """Fetch one row.
+
+ When all rows are exhausted, returns None.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch the first row of a result only, use the
+ :meth:`_engine.Result.first` method. To iterate through all
+ rows, iterate the :class:`_engine.Result` object directly.
+
+ :return: a :class:`.Row` object if no filters are applied, or None
+ if no rows remain.
+
+ """
+ row = self._onerow_getter(self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ def fetchmany(self, size=None):
+ """Fetch many rows.
+
+ When all rows are exhausted, returns an empty list.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch rows in groups, use the :meth:`._result.Result.partitions`
+ method.
+
+ :return: a list of :class:`.Row` objects.
+
+ """
+
+ return self._manyrow_getter(self, size)
+
+ def all(self):
+ """Return all rows in a list.
+
+ Closes the result set after invocation. Subsequent invocations
+ will return an empty list.
+
+ .. versionadded:: 1.4
+
+ :return: a list of :class:`.Row` objects.
+
+ """
+
+ return self._allrows()
+
+ def first(self):
+ """Fetch the first row or None if no row is present.
+
+ Closes the result set and discards remaining rows.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.first`.
+
+ Additionally, in contrast to the behavior of the legacy ORM
+ :meth:`_orm.Query.first` method, **no limit is applied** to the
+ SQL query which was invoked to produce this :class:`_engine.Result`;
+ for a DBAPI driver that buffers results in memory before yielding
+ rows, all rows will be sent to the Python process and all but
+ the first row will be discarded.
+
+ .. seealso::
+
+ :ref:`migration_20_unify_select`
+
+ :return: a :class:`.Row` object, or None
+ if no rows remain.
+
+ .. seealso::
+
+ :meth:`_result.Result.scalar`
+
+ :meth:`_result.Result.one`
+
+ """
+
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=False
+ )
+
+ def one_or_none(self):
+ """Return at most one result or raise an exception.
+
+ Returns ``None`` if the result has no rows.
+ Raises :class:`.MultipleResultsFound`
+ if multiple rows are returned.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row` or None if no row is available.
+
+ :raises: :class:`.MultipleResultsFound`
+
+ .. seealso::
+
+ :meth:`_result.Result.first`
+
+ :meth:`_result.Result.one`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=False
+ )
+
+ def scalar_one(self):
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=True
+ )
+
+ def scalar_one_or_none(self):
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=True
+ )
+
+ def one(self):
+ """Return exactly one row or raise an exception.
+
+ Raises :class:`.NoResultFound` if the result returns no
+ rows, or :class:`.MultipleResultsFound` if multiple rows
+ would be returned.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar_one` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.one`.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row`.
+
+ :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
+
+ .. seealso::
+
+ :meth:`_result.Result.first`
+
+ :meth:`_result.Result.one_or_none`
+
+ :meth:`_result.Result.scalar_one`
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=False
+ )
+
+ def scalar(self):
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+        :return: a Python scalar value, or None if no rows remain.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=True
+ )
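
A compact sketch contrasting :meth:`_engine.Result.first`, :meth:`_engine.Result.one`, :meth:`_engine.Result.scalar` and :meth:`_engine.Result.scalar_one`; the single-row table is invented for illustration::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE t (x INTEGER)"))
        conn.execute(text("INSERT INTO t VALUES (7)"))

        print(conn.execute(text("SELECT x FROM t")).first())       # (7,) -- a Row
        print(conn.execute(text("SELECT x FROM t")).one())         # (7,) -- raises if 0 or >1 rows
        print(conn.execute(text("SELECT x FROM t")).scalar())      # 7 -- first column of first row
        print(conn.execute(text("SELECT x FROM t")).scalar_one())  # 7 -- raises if 0 or >1 rows
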
+
+ def freeze(self):
+ """Return a callable object that will produce copies of this
+ :class:`.Result` when invoked.
+
+ The callable object returned is an instance of
+ :class:`_engine.FrozenResult`.
+
+        This is used for result set caching. The method must be called
+        on the result while it is still unconsumed, and calling the method
+        will consume the result fully. When the :class:`_engine.FrozenResult`
+ is retrieved from a cache, it can be called any number of times where
+ it will produce a new :class:`_engine.Result` object each time
+ against its stored set of rows.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - example usage within the
+ ORM to implement a result-set cache.
+
+ """
+
+ return FrozenResult(self)
+
+ def merge(self, *others):
+ """Merge this :class:`.Result` with other compatible result
+ objects.
+
+ The object returned is an instance of :class:`_engine.MergedResult`,
+ which will be composed of iterators from the given result
+ objects.
+
+ The new result will use the metadata from this result object.
+ The subsequent result objects must be against an identical
+ set of result / cursor metadata, otherwise the behavior is
+ undefined.
+
+ """
+ return MergedResult(self._metadata, (self,) + others)
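
A minimal sketch of :meth:`_engine.Result.merge` using two in-memory results that share identical metadata (internal constructs shown above, illustrative data only)::

    from sqlalchemy.engine.result import IteratorResult, SimpleResultMetaData

    meta = SimpleResultMetaData(["x"])
    r1 = IteratorResult(meta, iter([(1,), (2,)]))
    r2 = IteratorResult(meta, iter([(3,)]))

    print([row.x for row in r1.merge(r2)])   # [1, 2, 3]
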
+
+
+class FilterResult(ResultInternal):
+ """A wrapper for a :class:`_engine.Result` that returns objects other than
+ :class:`_result.Row` objects, such as dictionaries or scalar objects.
+
+ :class:`.FilterResult` is the common base for additional result
+ APIs including :class:`.MappingResult`, :class:`.ScalarResult`
+ and :class:`.AsyncResult`.
+
+ """
+
+ _post_creational_filter = None
+
+ @_generative
+ def yield_per(self, num):
+ """Configure the row-fetching strategy to fetch ``num`` rows at a time.
+
+ The :meth:`_engine.FilterResult.yield_per` method is a pass through
+ to the :meth:`_engine.Result.yield_per` method. See that method's
+ documentation for usage notes.
+
+ .. versionadded:: 1.4.40 - added :meth:`_engine.FilterResult.yield_per`
+ so that the method is available on all result set implementations
+
+ .. seealso::
+
+ :ref:`engine_stream_results` - describes Core behavior for
+ :meth:`_engine.Result.yield_per`
+
+ :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel`
+
+ """
+ self._real_result = self._real_result.yield_per(num)
+
+ def _soft_close(self, hard=False):
+ self._real_result._soft_close(hard=hard)
+
+ @property
+ def _attributes(self):
+ return self._real_result._attributes
+
+ def _fetchiter_impl(self):
+ return self._real_result._fetchiter_impl()
+
+ def _fetchone_impl(self, hard_close=False):
+ return self._real_result._fetchone_impl(hard_close=hard_close)
+
+ def _fetchall_impl(self):
+ return self._real_result._fetchall_impl()
+
+ def _fetchmany_impl(self, size=None):
+ return self._real_result._fetchmany_impl(size=size)
+
+
+class ScalarResult(FilterResult):
+ """A wrapper for a :class:`_result.Result` that returns scalar values
+ rather than :class:`_row.Row` values.
+
+ The :class:`_result.ScalarResult` object is acquired by calling the
+ :meth:`_result.Result.scalars` method.
+
+ A special limitation of :class:`_result.ScalarResult` is that it has
+ no ``fetchone()`` method; since the semantics of ``fetchone()`` are that
+ the ``None`` value indicates no more results, this is not compatible
+ with :class:`_result.ScalarResult` since there is no way to distinguish
+ between ``None`` as a row value versus ``None`` as an indicator. Use
+ ``next(result)`` to receive values individually.
+
+ """
+
+ _generate_rows = False
+
+ def __init__(self, real_result, index):
+ self._real_result = real_result
+
+ if real_result._source_supports_scalars:
+ self._metadata = real_result._metadata
+ self._post_creational_filter = None
+ else:
+ self._metadata = real_result._metadata._reduce([index])
+ self._post_creational_filter = operator.itemgetter(0)
+
+ self._unique_filter_state = real_result._unique_filter_state
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_engine.ScalarResult`.
+
+ See :meth:`_engine.Result.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = getter(self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ def fetchall(self):
+ """A synonym for the :meth:`_engine.ScalarResult.all` method."""
+
+ return self._allrows()
+
+ def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._manyrow_getter(self, size)
+
+ def all(self):
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._allrows()
+
+ def __iter__(self):
+ return self._iter_impl()
+
+ def __next__(self):
+ return self._next_impl()
+
+ if py2k:
+
+ def next(self): # noqa
+ return self._next_impl()
+
+ def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=False
+ )
+
+ def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=False
+ )
+
+ def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=False
+ )
+
+
+class MappingResult(_WithKeys, FilterResult):
+ """A wrapper for a :class:`_engine.Result` that returns dictionary values
+ rather than :class:`_engine.Row` values.
+
+ The :class:`_engine.MappingResult` object is acquired by calling the
+ :meth:`_engine.Result.mappings` method.
+
+ """
+
+ _generate_rows = True
+
+ _post_creational_filter = operator.attrgetter("_mapping")
+
+ def __init__(self, result):
+ self._real_result = result
+ self._unique_filter_state = result._unique_filter_state
+ self._metadata = result._metadata
+ if result._source_supports_scalars:
+ self._metadata = self._metadata._reduce([0])
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_engine.MappingResult`.
+
+ See :meth:`_engine.Result.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row."""
+ return self._column_slices(col_expressions)
+
+ def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_result.Result.partitions` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = getter(self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ def fetchall(self):
+ """A synonym for the :meth:`_engine.MappingResult.all` method."""
+
+ return self._allrows()
+
+ def fetchone(self):
+ """Fetch one object.
+
+ Equivalent to :meth:`_result.Result.fetchone` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ row = self._onerow_getter(self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_result.Result.fetchmany` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return self._manyrow_getter(self, size)
+
+ def all(self):
+        """Return all mapping values in a list.
+
+ Equivalent to :meth:`_result.Result.all` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return self._allrows()
+
+ def __iter__(self):
+ return self._iter_impl()
+
+ def __next__(self):
+ return self._next_impl()
+
+ if py2k:
+
+ def next(self): # noqa
+ return self._next_impl()
+
+ def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_result.Result.first` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=False, raise_for_none=False, scalar=False
+ )
+
+ def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one_or_none` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=False, scalar=False
+ )
+
+ def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_result.Result.one` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return self._only_one_row(
+ raise_for_second_row=True, raise_for_none=True, scalar=False
+ )
+
+
+class FrozenResult(object):
+ """Represents a :class:`.Result` object in a "frozen" state suitable
+ for caching.
+
+ The :class:`_engine.FrozenResult` object is returned from the
+ :meth:`_engine.Result.freeze` method of any :class:`_engine.Result`
+ object.
+
+ A new iterable :class:`.Result` object is generated from a fixed
+ set of data each time the :class:`.FrozenResult` is invoked as
+ a callable::
+
+
+ result = connection.execute(query)
+
+ frozen = result.freeze()
+
+ unfrozen_result_one = frozen()
+
+ for row in unfrozen_result_one:
+ print(row)
+
+ unfrozen_result_two = frozen()
+ rows = unfrozen_result_two.all()
+
+ # ... etc
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - example usage within the
+ ORM to implement a result-set cache.
+
+ :func:`_orm.loading.merge_frozen_result` - ORM function to merge
+ a frozen result back into a :class:`_orm.Session`.
+
+ """
+
+ def __init__(self, result):
+ self.metadata = result._metadata._for_freeze()
+ self._source_supports_scalars = result._source_supports_scalars
+ self._attributes = result._attributes
+
+ if self._source_supports_scalars:
+ self.data = list(result._raw_row_iterator())
+ else:
+ self.data = result.fetchall()
+
+ def rewrite_rows(self):
+ if self._source_supports_scalars:
+ return [[elem] for elem in self.data]
+ else:
+ return [list(row) for row in self.data]
+
+ def with_new_rows(self, tuple_data):
+ fr = FrozenResult.__new__(FrozenResult)
+ fr.metadata = self.metadata
+ fr._attributes = self._attributes
+ fr._source_supports_scalars = self._source_supports_scalars
+
+ if self._source_supports_scalars:
+ fr.data = [d[0] for d in tuple_data]
+ else:
+ fr.data = tuple_data
+ return fr
+
+ def __call__(self):
+ result = IteratorResult(self.metadata, iter(self.data))
+ result._attributes = self._attributes
+ result._source_supports_scalars = self._source_supports_scalars
+ return result
+
+
+class IteratorResult(Result):
+ """A :class:`.Result` that gets data from a Python iterator of
+ :class:`.Row` objects.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _hard_closed = False
+
+ def __init__(
+ self,
+ cursor_metadata,
+ iterator,
+ raw=None,
+ _source_supports_scalars=False,
+ ):
+ self._metadata = cursor_metadata
+ self.iterator = iterator
+ self.raw = raw
+ self._source_supports_scalars = _source_supports_scalars
+
+ def _soft_close(self, hard=False, **kw):
+ if hard:
+ self._hard_closed = True
+ if self.raw is not None:
+ self.raw._soft_close(hard=hard, **kw)
+ self.iterator = iter([])
+ self._reset_memoizations()
+
+ def _raise_hard_closed(self):
+ raise exc.ResourceClosedError("This result object is closed.")
+
+ def _raw_row_iterator(self):
+ return self.iterator
+
+ def _fetchiter_impl(self):
+ if self._hard_closed:
+ self._raise_hard_closed()
+ return self.iterator
+
+ def _fetchone_impl(self, hard_close=False):
+ if self._hard_closed:
+ self._raise_hard_closed()
+
+ row = next(self.iterator, _NO_ROW)
+ if row is _NO_ROW:
+ self._soft_close(hard=hard_close)
+ return None
+ else:
+ return row
+
+ def _fetchall_impl(self):
+ if self._hard_closed:
+ self._raise_hard_closed()
+
+ try:
+ return list(self.iterator)
+ finally:
+ self._soft_close()
+
+ def _fetchmany_impl(self, size=None):
+ if self._hard_closed:
+ self._raise_hard_closed()
+
+ return list(itertools.islice(self.iterator, 0, size))
+
+
+def null_result():
+ return IteratorResult(SimpleResultMetaData([]), iter([]))
+
+
+class ChunkedIteratorResult(IteratorResult):
+ """An :class:`.IteratorResult` that works from an iterator-producing
+ callable.
+
+ The given ``chunks`` argument is a function that is given a number of rows
+ to return in each chunk, or ``None`` for all rows. The function should
+ then return an un-consumed iterator of lists, each list of the requested
+ size.
+
+    The function can be called again at any time, in which case it should
+    continue from the same result set but adjust the chunk size as given.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def __init__(
+ self,
+ cursor_metadata,
+ chunks,
+ source_supports_scalars=False,
+ raw=None,
+ dynamic_yield_per=False,
+ ):
+ self._metadata = cursor_metadata
+ self.chunks = chunks
+ self._source_supports_scalars = source_supports_scalars
+ self.raw = raw
+ self.iterator = itertools.chain.from_iterable(self.chunks(None))
+ self.dynamic_yield_per = dynamic_yield_per
+
+ @_generative
+ def yield_per(self, num):
+ # TODO: this throws away the iterator which may be holding
+ # onto a chunk. the yield_per cannot be changed once any
+ # rows have been fetched. either find a way to enforce this,
+ # or we can't use itertools.chain and will instead have to
+ # keep track.
+
+ self._yield_per = num
+ self.iterator = itertools.chain.from_iterable(self.chunks(num))
+
+ def _soft_close(self, **kw):
+ super(ChunkedIteratorResult, self)._soft_close(**kw)
+ self.chunks = lambda size: []
+
+ def _fetchmany_impl(self, size=None):
+ if self.dynamic_yield_per:
+ self.iterator = itertools.chain.from_iterable(self.chunks(size))
+ return super(ChunkedIteratorResult, self)._fetchmany_impl(size=size)
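
A rough sketch of driving :class:`.ChunkedIteratorResult` with a chunk-producing callable; the ``chunks`` function below is simplified (it restarts from the beginning if invoked a second time, which a real implementation should avoid) and the data is invented::

    from sqlalchemy.engine.result import ChunkedIteratorResult, SimpleResultMetaData

    def chunks(size):
        data = [(1,), (2,), (3,), (4,), (5,)]
        if size is None:
            yield data
        else:
            for i in range(0, len(data), size):
                yield data[i : i + size]

    result = ChunkedIteratorResult(SimpleResultMetaData(["n"]), chunks)
    print([row.n for row in result])   # [1, 2, 3, 4, 5]
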
+
+
+class MergedResult(IteratorResult):
+ """A :class:`_engine.Result` that is merged from any number of
+ :class:`_engine.Result` objects.
+
+ Returned by the :meth:`_engine.Result.merge` method.
+
+ .. versionadded:: 1.4
+
+ """
+
+ closed = False
+
+ def __init__(self, cursor_metadata, results):
+ self._results = results
+ super(MergedResult, self).__init__(
+ cursor_metadata,
+ itertools.chain.from_iterable(
+ r._raw_row_iterator() for r in results
+ ),
+ )
+
+ self._unique_filter_state = results[0]._unique_filter_state
+ self._yield_per = results[0]._yield_per
+
+ # going to try something w/ this in next rev
+ self._source_supports_scalars = results[0]._source_supports_scalars
+
+ self._attributes = self._attributes.merge_with(
+ *[r._attributes for r in results]
+ )
+
+ def _soft_close(self, hard=False, **kw):
+ for r in self._results:
+ r._soft_close(hard=hard, **kw)
+ if hard:
+ self.closed = True
diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py
new file mode 100644
index 0000000..e80e8c6
--- /dev/null
+++ b/lib/sqlalchemy/engine/row.py
@@ -0,0 +1,621 @@
+# engine/row.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define row constructs including :class:`.Row`."""
+
+
+import operator
+
+from .. import util
+from ..sql import util as sql_util
+from ..util.compat import collections_abc
+
+MD_INDEX = 0 # integer index in cursor.description
+
+# This reconstructor is necessary so that pickles with the C extension or
+# without use the same Binary format.
+try:
+ # We need a different reconstructor on the C extension so that we can
+ # add extra checks that fields have correctly been initialized by
+ # __setstate__.
+ from sqlalchemy.cresultproxy import safe_rowproxy_reconstructor
+
+ # The extra function embedding is needed so that the
+ # reconstructor function has the same signature whether or not
+ # the extension is present.
+ def rowproxy_reconstructor(cls, state):
+ return safe_rowproxy_reconstructor(cls, state)
+
+
+except ImportError:
+
+ def rowproxy_reconstructor(cls, state):
+ obj = cls.__new__(cls)
+ obj.__setstate__(state)
+ return obj
+
+
+KEY_INTEGER_ONLY = 0
+"""__getitem__ only allows integer values, raises TypeError otherwise"""
+
+KEY_OBJECTS_ONLY = 1
+"""__getitem__ only allows string/object values, raises TypeError otherwise"""
+
+KEY_OBJECTS_BUT_WARN = 2
+"""__getitem__ allows integer or string/object values, but emits a 2.0
+deprecation warning if string/object is passed"""
+
+KEY_OBJECTS_NO_WARN = 3
+"""__getitem__ allows integer or string/object values with no warnings
+or errors."""
+
+try:
+ from sqlalchemy.cresultproxy import BaseRow
+
+ _baserow_usecext = True
+except ImportError:
+ _baserow_usecext = False
+
+ class BaseRow(object):
+ __slots__ = ("_parent", "_data", "_keymap", "_key_style")
+
+ def __init__(self, parent, processors, keymap, key_style, data):
+ """Row objects are constructed by CursorResult objects."""
+
+ object.__setattr__(self, "_parent", parent)
+
+ if processors:
+ object.__setattr__(
+ self,
+ "_data",
+ tuple(
+ [
+ proc(value) if proc else value
+ for proc, value in zip(processors, data)
+ ]
+ ),
+ )
+ else:
+ object.__setattr__(self, "_data", tuple(data))
+
+ object.__setattr__(self, "_keymap", keymap)
+
+ object.__setattr__(self, "_key_style", key_style)
+
+ def __reduce__(self):
+ return (
+ rowproxy_reconstructor,
+ (self.__class__, self.__getstate__()),
+ )
+
+ def _filter_on_values(self, filters):
+ return Row(
+ self._parent,
+ filters,
+ self._keymap,
+ self._key_style,
+ self._data,
+ )
+
+ def _values_impl(self):
+ return list(self)
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __len__(self):
+ return len(self._data)
+
+ def __hash__(self):
+ return hash(self._data)
+
+ def _get_by_int_impl(self, key):
+ return self._data[key]
+
+ def _get_by_key_impl(self, key):
+ if int in key.__class__.__mro__:
+ return self._data[key]
+
+ if self._key_style == KEY_INTEGER_ONLY:
+ self._parent._raise_for_nonint(key)
+
+ # the following is all LegacyRow support. none of this
+ # should be called if not LegacyRow
+ # assert isinstance(self, LegacyRow)
+
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._parent._key_fallback(key, ke)
+ except TypeError:
+ if isinstance(key, slice):
+ return tuple(self._data[key])
+ else:
+ raise
+
+ mdindex = rec[MD_INDEX]
+ if mdindex is None:
+ self._parent._raise_for_ambiguous_column_name(rec)
+
+ elif self._key_style == KEY_OBJECTS_BUT_WARN and mdindex != key:
+ self._parent._warn_for_nonint(key)
+
+ return self._data[mdindex]
+
+ # The original 1.4 plan was that Row would not allow row["str"]
+ # access, however as the C extensions were inadvertently allowing
+ # this coupled with the fact that orm Session sets future=True,
+ # this allows a softer upgrade path. see #6218
+ __getitem__ = _get_by_key_impl
+
+ def _get_by_key_impl_mapping(self, key):
+ try:
+ rec = self._keymap[key]
+ except KeyError as ke:
+ rec = self._parent._key_fallback(key, ke)
+
+ mdindex = rec[MD_INDEX]
+ if mdindex is None:
+ self._parent._raise_for_ambiguous_column_name(rec)
+ elif (
+ self._key_style == KEY_OBJECTS_ONLY
+ and int in key.__class__.__mro__
+ ):
+ raise KeyError(key)
+
+ return self._data[mdindex]
+
+ def __getattr__(self, name):
+ try:
+ return self._get_by_key_impl_mapping(name)
+ except KeyError as e:
+ util.raise_(AttributeError(e.args[0]), replace_context=e)
+
+
+class Row(BaseRow, collections_abc.Sequence):
+ """Represent a single result row.
+
+ The :class:`.Row` object represents a row of a database result. It is
+ typically associated in the 1.x series of SQLAlchemy with the
+    :class:`_engine.CursorResult` object; however, it is also used by the ORM
+    for tuple-like results as of SQLAlchemy 1.4.
+
+ The :class:`.Row` object seeks to act as much like a Python named
+ tuple as possible. For mapping (i.e. dictionary) behavior on a row,
+ such as testing for containment of keys, refer to the :attr:`.Row._mapping`
+ attribute.
+
+ .. seealso::
+
+ :ref:`tutorial_selecting_data` - includes examples of selecting
+ rows from SELECT statements.
+
+ :class:`.LegacyRow` - Compatibility interface introduced in SQLAlchemy
+ 1.4.
+
+ .. versionchanged:: 1.4
+
+ Renamed ``RowProxy`` to :class:`.Row`. :class:`.Row` is no longer a
+ "proxy" object in that it contains the final form of data within it,
+ and now acts mostly like a named tuple. Mapping-like functionality is
+ moved to the :attr:`.Row._mapping` attribute, but will remain available
+ in SQLAlchemy 1.x series via the :class:`.LegacyRow` class that is used
+ by :class:`_engine.LegacyCursorResult`.
+ See :ref:`change_4710_core` for background
+ on this change.
+
+ """
+
+ __slots__ = ()
+
+ # in 2.0, this should be KEY_INTEGER_ONLY
+ _default_key_style = KEY_OBJECTS_BUT_WARN
+
+ def __setattr__(self, name, value):
+ raise AttributeError("can't set attribute")
+
+ def __delattr__(self, name):
+ raise AttributeError("can't delete attribute")
+
+ @property
+ def _mapping(self):
+ """Return a :class:`.RowMapping` for this :class:`.Row`.
+
+ This object provides a consistent Python mapping (i.e. dictionary)
+ interface for the data contained within the row. The :class:`.Row`
+ by itself behaves like a named tuple, however in the 1.4 series of
+ SQLAlchemy, the :class:`.LegacyRow` class is still used by Core which
+ continues to have mapping-like behaviors against the row object
+ itself.
+
+ .. seealso::
+
+ :attr:`.Row._fields`
+
+ .. versionadded:: 1.4
+
+ """
+ return RowMapping(
+ self._parent,
+ None,
+ self._keymap,
+ RowMapping._default_key_style,
+ self._data,
+ )
+
+ def _special_name_accessor(name):
+ """Handle ambiguous names such as "count" and "index" """
+
+ @property
+ def go(self):
+ if self._parent._has_key(name):
+ return self.__getattr__(name)
+ else:
+
+ def meth(*arg, **kw):
+ return getattr(collections_abc.Sequence, name)(
+ self, *arg, **kw
+ )
+
+ return meth
+
+ return go
+
+ count = _special_name_accessor("count")
+ index = _special_name_accessor("index")
+
+ def __contains__(self, key):
+ return key in self._data
+
+ def __getstate__(self):
+ return {
+ "_parent": self._parent,
+ "_data": self._data,
+ "_key_style": self._key_style,
+ }
+
+ def __setstate__(self, state):
+ parent = state["_parent"]
+ object.__setattr__(self, "_parent", parent)
+ object.__setattr__(self, "_data", state["_data"])
+ object.__setattr__(self, "_keymap", parent._keymap)
+ object.__setattr__(self, "_key_style", state["_key_style"])
+
+ def _op(self, other, op):
+ return (
+ op(tuple(self), tuple(other))
+ if isinstance(other, Row)
+ else op(tuple(self), other)
+ )
+
+ __hash__ = BaseRow.__hash__
+
+ def __lt__(self, other):
+ return self._op(other, operator.lt)
+
+ def __le__(self, other):
+ return self._op(other, operator.le)
+
+ def __ge__(self, other):
+ return self._op(other, operator.ge)
+
+ def __gt__(self, other):
+ return self._op(other, operator.gt)
+
+ def __eq__(self, other):
+ return self._op(other, operator.eq)
+
+ def __ne__(self, other):
+ return self._op(other, operator.ne)
+
+ def __repr__(self):
+ return repr(sql_util._repr_row(self))
+
+ @util.deprecated_20(
+ ":meth:`.Row.keys`",
+ alternative="Use the namedtuple standard accessor "
+ ":attr:`.Row._fields`, or for full mapping behavior use "
+ "row._mapping.keys() ",
+ )
+ def keys(self):
+ """Return the list of keys as strings represented by this
+ :class:`.Row`.
+
+ The keys can represent the labels of the columns returned by a core
+ statement or the names of the orm classes returned by an orm
+ execution.
+
+ This method is analogous to the Python dictionary ``.keys()`` method,
+ except that it returns a list, not an iterator.
+
+ .. seealso::
+
+ :attr:`.Row._fields`
+
+ :attr:`.Row._mapping`
+
+ """
+ return self._parent.keys
+
+ @property
+ def _fields(self):
+ """Return a tuple of string keys as represented by this
+ :class:`.Row`.
+
+ The keys can represent the labels of the columns returned by a core
+ statement or the names of the orm classes returned by an orm
+ execution.
+
+ This attribute is analogous to the Python named tuple ``._fields``
+ attribute.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`.Row._mapping`
+
+ """
+ return tuple([k for k in self._parent.keys if k is not None])
+
+ def _asdict(self):
+ """Return a new dict which maps field names to their corresponding
+ values.
+
+ This method is analogous to the Python named tuple ``._asdict()``
+ method, and works by applying the ``dict()`` constructor to the
+ :attr:`.Row._mapping` attribute.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`.Row._mapping`
+
+ """
+ return dict(self._mapping)
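
A short sketch of the named-tuple style accessors described above; the table and row are invented for illustration::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE users (id INTEGER, name TEXT)"))
        conn.execute(text("INSERT INTO users VALUES (1, 'ed')"))

        row = conn.execute(text("SELECT id, name FROM users")).first()
        print(row.id, row[1])        # attribute and positional access
        print(row._fields)           # ('id', 'name')
        print(row._asdict())         # {'id': 1, 'name': 'ed'}
        print(row._mapping["name"])  # 'ed'
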
+
+ def _replace(self):
+ raise NotImplementedError()
+
+ @property
+ def _field_defaults(self):
+ raise NotImplementedError()
+
+
+class LegacyRow(Row):
+ """A subclass of :class:`.Row` that delivers 1.x SQLAlchemy behaviors
+ for Core.
+
+ The :class:`.LegacyRow` class is where most of the Python mapping
+ (i.e. dictionary-like)
+ behaviors are implemented for the row object. The mapping behavior
+ of :class:`.Row` going forward is accessible via the :class:`.Row._mapping`
+ attribute.
+
+ .. versionadded:: 1.4 - added :class:`.LegacyRow` which encapsulates most
+ of the deprecated behaviors of :class:`.Row`.
+
+ """
+
+ __slots__ = ()
+
+ if util.SQLALCHEMY_WARN_20:
+ _default_key_style = KEY_OBJECTS_BUT_WARN
+ else:
+ _default_key_style = KEY_OBJECTS_NO_WARN
+
+ def __contains__(self, key):
+ return self._parent._contains(key, self)
+
+ # prior to #6218, LegacyRow would redirect the behavior of __getitem__
+ # for the non C version of BaseRow. This is now set up by Python BaseRow
+ # in all cases
+ # if not _baserow_usecext:
+ # __getitem__ = BaseRow._get_by_key_impl
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.has_key` method is deprecated and will be "
+ "removed in a future release. To test for key membership, use "
+ "the :attr:`Row._mapping` attribute, i.e. 'key in row._mapping`.",
+ )
+ def has_key(self, key):
+ """Return True if this :class:`.LegacyRow` contains the given key.
+
+ Through the SQLAlchemy 1.x series, the ``__contains__()`` method of
+ :class:`.Row` (or :class:`.LegacyRow` as of SQLAlchemy 1.4) also links
+ to :meth:`.Row.has_key`, in that an expression such as ::
+
+ "some_col" in row
+
+ will return True if the row contains a column named ``"some_col"``,
+ in the way that a Python mapping works.
+
+ However, it is planned that the 2.0 series of SQLAlchemy will reverse
+ this behavior so that ``__contains__()`` will refer to a value being
+ present in the row, in the way that a Python tuple works.
+
+ .. seealso::
+
+ :ref:`change_4710_core`
+
+ """
+
+ return self._parent._has_key(key)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.items` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.items()'.",
+ )
+ def items(self):
+ """Return a list of tuples, each tuple containing a key/value pair.
+
+ This method is analogous to the Python dictionary ``.items()`` method,
+ except that it returns a list, not an iterator.
+
+ """
+
+ return [(key, self[key]) for key in self.keys()]
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.iterkeys` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.keys()'.",
+ )
+ def iterkeys(self):
+ """Return a an iterator against the :meth:`.Row.keys` method.
+
+ This method is analogous to the Python-2-only dictionary
+ ``.iterkeys()`` method.
+
+ """
+ return iter(self._parent.keys)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.itervalues` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.values()'.",
+ )
+ def itervalues(self):
+ """Return a an iterator against the :meth:`.Row.values` method.
+
+ This method is analogous to the Python-2-only dictionary
+ ``.itervalues()`` method.
+
+ """
+ return iter(self)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.LegacyRow.values` method is deprecated and will be "
+ "removed in a future release. Use the :attr:`Row._mapping` "
+ "attribute, i.e., 'row._mapping.values()'.",
+ )
+ def values(self):
+ """Return the values represented by this :class:`.Row` as a list.
+
+ This method is analogous to the Python dictionary ``.values()`` method,
+ except that it returns a list, not an iterator.
+
+ """
+
+ return self._values_impl()
+
+
+BaseRowProxy = BaseRow
+RowProxy = Row
+
+
+class ROMappingView(
+ collections_abc.KeysView,
+ collections_abc.ValuesView,
+ collections_abc.ItemsView,
+):
+ __slots__ = (
+ "_mapping",
+ "_items",
+ )
+
+ def __init__(self, mapping, items):
+ self._mapping = mapping
+ self._items = items
+
+ def __len__(self):
+ return len(self._items)
+
+ def __repr__(self):
+ return "{0.__class__.__name__}({0._mapping!r})".format(self)
+
+ def __iter__(self):
+ return iter(self._items)
+
+ def __contains__(self, item):
+ return item in self._items
+
+ def __eq__(self, other):
+ return list(other) == list(self)
+
+ def __ne__(self, other):
+ return list(other) != list(self)
+
+
+class RowMapping(BaseRow, collections_abc.Mapping):
+ """A ``Mapping`` that maps column names and objects to :class:`.Row`
+ values.
+
+ The :class:`.RowMapping` is available from a :class:`.Row` via the
+ :attr:`.Row._mapping` attribute, as well as from the iterable interface
+ provided by the :class:`.MappingResult` object returned by the
+ :meth:`_engine.Result.mappings` method.
+
+ :class:`.RowMapping` supplies Python mapping (i.e. dictionary) access to
+ the contents of the row. This includes support for testing of
+ containment of specific keys (string column names or objects), as well
+ as iteration of keys, values, and items::
+
+ for row in result:
+ if 'a' in row._mapping:
+ print("Column 'a': %s" % row._mapping['a'])
+
+ print("Column b: %s" % row._mapping[table.c.b])
+
+
+ .. versionadded:: 1.4 The :class:`.RowMapping` object replaces the
+ mapping-like access previously provided by a database result row,
+ which now seeks to behave mostly like a named tuple.
+
+ """
+
+ __slots__ = ()
+
+ _default_key_style = KEY_OBJECTS_ONLY
+
+ if not _baserow_usecext:
+
+ __getitem__ = BaseRow._get_by_key_impl_mapping
+
+ def _values_impl(self):
+ return list(self._data)
+
+ def __iter__(self):
+ return (k for k in self._parent.keys if k is not None)
+
+ def __len__(self):
+ return len(self._data)
+
+ def __contains__(self, key):
+ return self._parent._has_key(key)
+
+ def __repr__(self):
+ return repr(dict(self))
+
+ def items(self):
+ """Return a view of key/value tuples for the elements in the
+ underlying :class:`.Row`.
+
+ """
+ return ROMappingView(self, [(key, self[key]) for key in self.keys()])
+
+ def keys(self):
+ """Return a view of 'keys' for string column names represented
+ by the underlying :class:`.Row`.
+
+ """
+
+ return self._parent.keys
+
+ def values(self):
+ """Return a view of values for the values represented in the
+ underlying :class:`.Row`.
+
+ """
+ return ROMappingView(self, self._values_impl())
diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py
new file mode 100644
index 0000000..54a5e51
--- /dev/null
+++ b/lib/sqlalchemy/engine/strategies.py
@@ -0,0 +1,17 @@
+# engine/strategies.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Deprecated mock engine strategy used by Alembic.
+
+
+"""
+
+from .mock import MockConnection # noqa
+
+
+class MockEngineStrategy(object):
+ MockConnection = MockConnection
diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py
new file mode 100644
index 0000000..db971c2
--- /dev/null
+++ b/lib/sqlalchemy/engine/url.py
@@ -0,0 +1,806 @@
+# engine/url.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates
+information about a database connection specification.
+
+The URL object is created automatically when
+:func:`~sqlalchemy.engine.create_engine` is called with a string
+argument; alternatively, the URL is a public-facing construct which can
+be used directly and is also accepted directly by ``create_engine()``.
+"""
+
+import re
+
+from .interfaces import Dialect
+from .. import exc
+from .. import util
+from ..dialects import plugins
+from ..dialects import registry
+from ..util import collections_abc
+from ..util import compat
+
+
+class URL(
+ util.namedtuple(
+ "URL",
+ [
+ "drivername",
+ "username",
+ "password",
+ "host",
+ "port",
+ "database",
+ "query",
+ ],
+ )
+):
+ """
+ Represent the components of a URL used to connect to a database.
+
+ This object is suitable to be passed directly to a
+ :func:`_sa.create_engine` call. The fields of the URL are parsed
+ from a string by the :func:`.make_url` function. The string
+ format of the URL is an RFC-1738-style string.
+
+ To create a new :class:`_engine.URL` object, use the
+ :func:`_engine.url.make_url` function. To construct a :class:`_engine.URL`
+ programmatically, use the :meth:`_engine.URL.create` constructor.
+
+ .. versionchanged:: 1.4
+
+ The :class:`_engine.URL` object is now an immutable object. To
+ create a URL, use the :func:`_engine.make_url` or
+ :meth:`_engine.URL.create` function / method. To modify
+ a :class:`_engine.URL`, use methods like
+ :meth:`_engine.URL.set` and
+ :meth:`_engine.URL.update_query_dict` to return a new
+ :class:`_engine.URL` object with modifications. See notes for this
+ change at :ref:`change_5526`.
+
+ :class:`_engine.URL` contains the following attributes:
+
+ * :attr:`_engine.URL.drivername`: database backend and driver name, such as
+ ``postgresql+psycopg2``
+ * :attr:`_engine.URL.username`: username string
+ * :attr:`_engine.URL.password`: password string
+ * :attr:`_engine.URL.host`: string hostname
+ * :attr:`_engine.URL.port`: integer port number
+ * :attr:`_engine.URL.database`: string database name
+ * :attr:`_engine.URL.query`: an immutable mapping representing the query
+ string, containing strings for keys and either strings or tuples of
+ strings for values.
+
+
+ """
+
+ def __new__(self, *arg, **kw):
+ if kw.pop("_new_ok", False):
+ return super(URL, self).__new__(self, *arg, **kw)
+ else:
+ util.warn_deprecated(
+ "Calling URL() directly is deprecated and will be disabled "
+ "in a future release. The public constructor for URL is "
+ "now the URL.create() method.",
+ "1.4",
+ )
+ return URL.create(*arg, **kw)
+
+ @classmethod
+ def create(
+ cls,
+ drivername,
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ database=None,
+ query=util.EMPTY_DICT,
+ ):
+ """Create a new :class:`_engine.URL` object.
+
+ :param drivername: the name of the database backend. This name will
+ correspond to a module in sqlalchemy/databases or a third party
+ plug-in.
+ :param username: The user name.
+ :param password: database password. Is typically a string, but may
+ also be an object that can be stringified with ``str()``.
+
+ .. note:: A password-producing object will be stringified only
+ **once** per :class:`_engine.Engine` object. For dynamic password
+ generation per connect, see :ref:`engines_dynamic_tokens`.
+
+ :param host: The name of the host.
+ :param port: The port number.
+ :param database: The database name.
+ :param query: A dictionary of string keys to string values to be passed
+ to the dialect and/or the DBAPI upon connect. To specify non-string
+ parameters to a Python DBAPI directly, use the
+ :paramref:`_sa.create_engine.connect_args` parameter to
+ :func:`_sa.create_engine`. See also
+ :attr:`_engine.URL.normalized_query` for a dictionary that is
+ consistently string->list of string.
+ :return: new :class:`_engine.URL` object.
+
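+ E.g., as a minimal illustrative sketch (the driver name, credentials
+ and database shown here are placeholders)::
+
+ >>> from sqlalchemy.engine import URL
+ >>> url = URL.create(
+ ... "postgresql+psycopg2",
+ ... username="scott",
+ ... password="tiger",
+ ... host="localhost",
+ ... database="mydb",
+ ... )
+ >>> str(url)
+ 'postgresql+psycopg2://scott:tiger@localhost/mydb'
+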
+ .. versionadded:: 1.4
+
+ The :class:`_engine.URL` object is now an **immutable named
+ tuple**. In addition, the ``query`` dictionary is also immutable.
+ To create a URL, use the :func:`_engine.url.make_url` or
+ :meth:`_engine.URL.create` function / method. To modify a
+ :class:`_engine.URL`, use the :meth:`_engine.URL.set` and
+ :meth:`_engine.URL.update_query` methods.
+
+ """
+
+ return cls(
+ cls._assert_str(drivername, "drivername"),
+ cls._assert_none_str(username, "username"),
+ password,
+ cls._assert_none_str(host, "host"),
+ cls._assert_port(port),
+ cls._assert_none_str(database, "database"),
+ cls._str_dict(query),
+ _new_ok=True,
+ )
+
+ @classmethod
+ def _assert_port(cls, port):
+ if port is None:
+ return None
+ try:
+ return int(port)
+ except TypeError:
+ raise TypeError("Port argument must be an integer or None")
+
+ @classmethod
+ def _assert_str(cls, v, paramname):
+ if not isinstance(v, compat.string_types):
+ raise TypeError("%s must be a string" % paramname)
+ return v
+
+ @classmethod
+ def _assert_none_str(cls, v, paramname):
+ if v is None:
+ return v
+
+ return cls._assert_str(v, paramname)
+
+ @classmethod
+ def _str_dict(cls, dict_):
+ if dict_ is None:
+ return util.EMPTY_DICT
+
+ def _assert_value(val):
+ if isinstance(val, compat.string_types):
+ return val
+ elif isinstance(val, collections_abc.Sequence):
+ return tuple(_assert_value(elem) for elem in val)
+ else:
+ raise TypeError(
+ "Query dictionary values must be strings or "
+ "sequences of strings"
+ )
+
+ def _assert_str(v):
+ if not isinstance(v, compat.string_types):
+ raise TypeError("Query dictionary keys must be strings")
+ return v
+
+ if isinstance(dict_, collections_abc.Sequence):
+ dict_items = dict_
+ else:
+ dict_items = dict_.items()
+
+ return util.immutabledict(
+ {
+ _assert_str(key): _assert_value(
+ value,
+ )
+ for key, value in dict_items
+ }
+ )
+
+ def set(
+ self,
+ drivername=None,
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ database=None,
+ query=None,
+ ):
+ """return a new :class:`_engine.URL` object with modifications.
+
+ Values are used if they are non-None. To set a value to ``None``
+ explicitly, use the :meth:`_engine.URL._replace` method adapted
+ from ``namedtuple``.
+
+ :param drivername: new drivername
+ :param username: new username
+ :param password: new password
+ :param host: new hostname
+ :param port: new port
+ :param query: new query parameters, passed a dict of string keys
+ referring to string or sequence of string values. Fully
+ replaces the previous list of arguments.
+
+ :return: new :class:`_engine.URL` object.
+
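+ For illustration (the hostnames and credentials are placeholders)::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname")
+ >>> str(url.set(host="other_host"))
+ 'postgresql://user:pass@other_host/dbname'
+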
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`_engine.URL.update_query_dict`
+
+ """
+
+ kw = {}
+ if drivername is not None:
+ kw["drivername"] = drivername
+ if username is not None:
+ kw["username"] = username
+ if password is not None:
+ kw["password"] = password
+ if host is not None:
+ kw["host"] = host
+ if port is not None:
+ kw["port"] = port
+ if database is not None:
+ kw["database"] = database
+ if query is not None:
+ kw["query"] = query
+
+ return self._replace(**kw)
+
+ def _replace(self, **kw):
+ """Override ``namedtuple._replace()`` to provide argument checking."""
+
+ if "drivername" in kw:
+ self._assert_str(kw["drivername"], "drivername")
+ for name in "username", "host", "database":
+ if name in kw:
+ self._assert_none_str(kw[name], name)
+ if "port" in kw:
+ self._assert_port(kw["port"])
+ if "query" in kw:
+ kw["query"] = self._str_dict(kw["query"])
+
+ return super(URL, self)._replace(**kw)
+
+ def update_query_string(self, query_string, append=False):
+ """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query`
+ parameter dictionary updated by the given query string.
+
+ E.g.::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname")
+ >>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
+ >>> str(url)
+ 'postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
+
+ :param query_string: a URL escaped query string, not including the
+ question mark.
+
+ :param append: if True, parameters in the existing query string will
+ not be removed; new parameters will be in addition to those present.
+ If left at its default of False, keys present in the given query
+ parameters will replace those of the existing query string.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.update_query_dict`
+
+ """ # noqa: E501
+ return self.update_query_pairs(
+ util.parse_qsl(query_string), append=append
+ )
+
+ def update_query_pairs(self, key_value_pairs, append=False):
+ """Return a new :class:`_engine.URL` object with the
+ :attr:`_engine.URL.query`
+ parameter dictionary updated by the given sequence of key/value pairs.
+
+ E.g.::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname")
+ >>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")])
+ >>> str(url)
+ 'postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
+
+ :param key_value_pairs: A sequence of tuples containing two strings
+ each.
+
+ :param append: if True, parameters in the existing query string will
+ not be removed; new parameters will be in addition to those present.
+ If left at its default of False, keys present in the given query
+ parameters will replace those of the existing query string.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.difference_update_query`
+
+ :meth:`_engine.URL.set`
+
+ """ # noqa: E501
+
+ existing_query = self.query
+ new_keys = {}
+
+ for key, value in key_value_pairs:
+ if key in new_keys:
+ new_keys[key] = util.to_list(new_keys[key])
+ new_keys[key].append(value)
+ else:
+ new_keys[key] = value
+
+ if append:
+ new_query = {}
+
+ for k in new_keys:
+ if k in existing_query:
+ new_query[k] = util.to_list(
+ existing_query[k]
+ ) + util.to_list(new_keys[k])
+ else:
+ new_query[k] = new_keys[k]
+
+ new_query.update(
+ {
+ k: existing_query[k]
+ for k in set(existing_query).difference(new_keys)
+ }
+ )
+ else:
+ new_query = self.query.union(new_keys)
+ return self.set(query=new_query)
+
+ def update_query_dict(self, query_parameters, append=False):
+ """Return a new :class:`_engine.URL` object with the
+ :attr:`_engine.URL.query` parameter dictionary updated by the given
+ dictionary.
+
+ The dictionary typically contains string keys and string values.
+ In order to represent a query parameter that is expressed multiple
+ times, pass a sequence of string values.
+
+ E.g.::
+
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname")
+ >>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"})
+ >>> str(url)
+ 'postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt'
+
+
+ :param query_parameters: A dictionary with string keys and values
+ that are either strings, or sequences of strings.
+
+ :param append: if True, parameters in the existing query string will
+ not be removed; new parameters will be in addition to those present.
+ If left at its default of False, keys present in the given query
+ parameters will replace those of the existing query string.
+
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.update_query_string`
+
+ :meth:`_engine.URL.update_query_pairs`
+
+ :meth:`_engine.URL.difference_update_query`
+
+ :meth:`_engine.URL.set`
+
+ """ # noqa: E501
+ return self.update_query_pairs(query_parameters.items(), append=append)
+
+ def difference_update_query(self, names):
+ """
+ Remove the given names from the :attr:`_engine.URL.query` dictionary,
+ returning the new :class:`_engine.URL`.
+
+ E.g.::
+
+ url = url.difference_update_query(['foo', 'bar'])
+
+ Equivalent to using :meth:`_engine.URL.set` as follows::
+
+ url = url.set(
+ query={
+ key: url.query[key]
+ for key in set(url.query).difference(['foo', 'bar'])
+ }
+ )
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_engine.URL.query`
+
+ :meth:`_engine.URL.update_query_dict`
+
+ :meth:`_engine.URL.set`
+
+ """
+
+ if not set(names).intersection(self.query):
+ return self
+
+ return URL(
+ self.drivername,
+ self.username,
+ self.password,
+ self.host,
+ self.port,
+ self.database,
+ util.immutabledict(
+ {
+ key: self.query[key]
+ for key in set(self.query).difference(names)
+ }
+ ),
+ _new_ok=True,
+ )
+
+ @util.memoized_property
+ def normalized_query(self):
+ """Return the :attr:`_engine.URL.query` dictionary with values normalized
+ into sequences.
+
+ As the :attr:`_engine.URL.query` dictionary may contain either
+ string values or sequences of string values to differentiate between
+ parameters that are specified multiple times in the query string,
+ code that needs to handle multiple parameters generically will wish
+ to use this attribute so that all parameters present are presented
+ as sequences. Inspiration is from Python's ``urllib.parse.parse_qs``
+ function. E.g.::
+
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt")
+ >>> url.query
+ immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'})
+ >>> url.normalized_query
+ immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': ('/path/to/crt',)})
+
+ """ # noqa: E501
+
+ return util.immutabledict(
+ {
+ k: (v,) if not isinstance(v, tuple) else v
+ for k, v in self.query.items()
+ }
+ )
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_engine.URL.__to_string__ method is deprecated and will "
+ "be removed in a future release. Please use the "
+ ":meth:`_engine.URL.render_as_string` method.",
+ )
+ def __to_string__(self, hide_password=True):
+ """Render this :class:`_engine.URL` object as a string.
+
+ :param hide_password: Defaults to True. The password is not shown
+ in the string unless this is set to False.
+
+ """
+ return self.render_as_string(hide_password=hide_password)
+
+ def render_as_string(self, hide_password=True):
+ """Render this :class:`_engine.URL` object as a string.
+
+ This method is used when the ``__str__()`` or ``__repr__()``
+ methods are used. Calling it directly allows the additional
+ ``hide_password`` option to be specified.
+
+ :param hide_password: Defaults to True. The password is not shown
+ in the string unless this is set to False.
+
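+ E.g. (illustrative only; the URL shown is a placeholder)::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://scott:tiger@localhost/test")
+ >>> url.render_as_string()
+ 'postgresql://scott:***@localhost/test'
+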
+ """
+ s = self.drivername + "://"
+ if self.username is not None:
+ s += _rfc_1738_quote(self.username)
+ if self.password is not None:
+ s += ":" + (
+ "***"
+ if hide_password
+ else _rfc_1738_quote(str(self.password))
+ )
+ s += "@"
+ if self.host is not None:
+ if ":" in self.host:
+ s += "[%s]" % self.host
+ else:
+ s += self.host
+ if self.port is not None:
+ s += ":" + str(self.port)
+ if self.database is not None:
+ s += "/" + self.database
+ if self.query:
+ keys = list(self.query)
+ keys.sort()
+ s += "?" + "&".join(
+ "%s=%s" % (util.quote_plus(k), util.quote_plus(element))
+ for k in keys
+ for element in util.to_list(self.query[k])
+ )
+ return s
+
+ def __str__(self):
+ return self.render_as_string(hide_password=False)
+
+ def __repr__(self):
+ return self.render_as_string()
+
+ def __copy__(self):
+ return self.__class__.create(
+ self.drivername,
+ self.username,
+ self.password,
+ self.host,
+ self.port,
+ self.database,
+ # note this is an immutabledict of str-> str / tuple of str,
+ # also fully immutable. does not require deepcopy
+ self.query,
+ )
+
+ def __deepcopy__(self, memo):
+ return self.__copy__()
+
+ def __hash__(self):
+ return hash(str(self))
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, URL)
+ and self.drivername == other.drivername
+ and self.username == other.username
+ and self.password == other.password
+ and self.host == other.host
+ and self.database == other.database
+ and self.query == other.query
+ and self.port == other.port
+ )
+
+ def __ne__(self, other):
+ return not self == other
+
+ def get_backend_name(self):
+ """Return the backend name.
+
+ This is the name that corresponds to the database backend in
+ use, and is the portion of the :attr:`_engine.URL.drivername`
+ that is to the left of the plus sign.
+
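+ E.g. a hypothetical drivername of ``"postgresql+psycopg2"`` returns
+ ``"postgresql"``.
+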
+ """
+ if "+" not in self.drivername:
+ return self.drivername
+ else:
+ return self.drivername.split("+")[0]
+
+ def get_driver_name(self):
+ """Return the backend name.
+
+ This is the name that corresponds to the DBAPI driver in
+ use, and is the portion of the :attr:`_engine.URL.drivername`
+ that is to the right of the plus sign.
+
+ If the :attr:`_engine.URL.drivername` does not include a plus sign,
+ then the default :class:`_engine.Dialect` for this :class:`_engine.URL`
+ is imported in order to get the driver name.
+
+ """
+
+ if "+" not in self.drivername:
+ return self.get_dialect().driver
+ else:
+ return self.drivername.split("+")[1]
+
+ def _instantiate_plugins(self, kwargs):
+ plugin_names = util.to_list(self.query.get("plugin", ()))
+ plugin_names += kwargs.get("plugins", [])
+
+ kwargs = dict(kwargs)
+
+ loaded_plugins = [
+ plugins.load(plugin_name)(self, kwargs)
+ for plugin_name in plugin_names
+ ]
+
+ u = self.difference_update_query(["plugin", "plugins"])
+
+ for plugin in loaded_plugins:
+ new_u = plugin.update_url(u)
+ if new_u is not None:
+ u = new_u
+
+ kwargs.pop("plugins", None)
+
+ return u, loaded_plugins, kwargs
+
+ def _get_entrypoint(self):
+ """Return the "entry point" dialect class.
+
+ This is normally the dialect itself except in the case when the
+ returned class implements the get_dialect_cls() method.
+
+ """
+ if "+" not in self.drivername:
+ name = self.drivername
+ else:
+ name = self.drivername.replace("+", ".")
+ cls = registry.load(name)
+ # check for legacy dialects that
+ # would return a module with 'dialect' as the
+ # actual class
+ if (
+ hasattr(cls, "dialect")
+ and isinstance(cls.dialect, type)
+ and issubclass(cls.dialect, Dialect)
+ ):
+ return cls.dialect
+ else:
+ return cls
+
+ def get_dialect(self):
+ """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding
+ to this URL's driver name.
+
+ """
+ entrypoint = self._get_entrypoint()
+ dialect_cls = entrypoint.get_dialect_cls(self)
+ return dialect_cls
+
+ def translate_connect_args(self, names=None, **kw):
+ r"""Translate url attributes into a dictionary of connection arguments.
+
+ Returns attributes of this url (`host`, `database`, `username`,
+ `password`, `port`) as a plain dictionary. The attribute names are
+ used as the keys by default. Unset or false attributes are omitted
+ from the final dictionary.
+
+ :param \**kw: Optional, alternate key names for url attributes.
+
+ :param names: Deprecated. Same purpose as the keyword-based alternate
+ names, but correlates the name to the original positionally.
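+
+ For illustration (the URL and alternate key names are placeholders;
+ dictionary ordering is not significant)::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://scott:tiger@localhost:5432/test")
+ >>> url.translate_connect_args(username="user", database="db")
+ {'host': 'localhost', 'db': 'test', 'user': 'scott', 'password': 'tiger', 'port': 5432}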
+ """
+
+ if names is not None:
+ util.warn_deprecated(
+ "The `URL.translate_connect_args.name`s parameter is "
+ "deprecated. Please pass the "
+ "alternate names as kw arguments.",
+ "1.4",
+ )
+
+ translated = {}
+ attribute_names = ["host", "database", "username", "password", "port"]
+ for sname in attribute_names:
+ if names:
+ name = names.pop(0)
+ elif sname in kw:
+ name = kw[sname]
+ else:
+ name = sname
+ if name is not None and getattr(self, sname, False):
+ if sname == "password":
+ translated[name] = str(getattr(self, sname))
+ else:
+ translated[name] = getattr(self, sname)
+
+ return translated
+
+
+def make_url(name_or_url):
+ """Given a string or unicode instance, produce a new URL instance.
+
+ The given string is parsed according to the RFC 1738 spec. If an
+ existing URL object is passed, just returns the object.
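+
+ For illustration (the DSN string is a placeholder)::
+
+ >>> from sqlalchemy.engine import make_url
+ >>> url = make_url("postgresql://scott:tiger@localhost/test")
+ >>> url.drivername, url.host, url.database
+ ('postgresql', 'localhost', 'test')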
+ """
+
+ if isinstance(name_or_url, util.string_types):
+ return _parse_rfc1738_args(name_or_url)
+ else:
+ return name_or_url
+
+
+def _parse_rfc1738_args(name):
+ pattern = re.compile(
+ r"""
+ (?P<name>[\w\+]+)://
+ (?:
+ (?P<username>[^:/]*)
+ (?::(?P<password>[^@]*))?
+ @)?
+ (?:
+ (?:
+ \[(?P<ipv6host>[^/\?]+)\] |
+ (?P<ipv4host>[^/:\?]+)
+ )?
+ (?::(?P<port>[^/\?]*))?
+ )?
+ (?:/(?P<database>[^\?]*))?
+ (?:\?(?P<query>.*))?
+ """,
+ re.X,
+ )
+
+ m = pattern.match(name)
+ if m is not None:
+ components = m.groupdict()
+ if components["query"] is not None:
+ query = {}
+
+ for key, value in util.parse_qsl(components["query"]):
+ if util.py2k:
+ key = key.encode("ascii")
+ if key in query:
+ query[key] = util.to_list(query[key])
+ query[key].append(value)
+ else:
+ query[key] = value
+ else:
+ query = None
+ components["query"] = query
+
+ if components["username"] is not None:
+ components["username"] = _rfc_1738_unquote(components["username"])
+
+ if components["password"] is not None:
+ components["password"] = _rfc_1738_unquote(components["password"])
+
+ ipv4host = components.pop("ipv4host")
+ ipv6host = components.pop("ipv6host")
+ components["host"] = ipv4host or ipv6host
+ name = components.pop("name")
+
+ if components["port"]:
+ components["port"] = int(components["port"])
+
+ return URL.create(name, **components)
+
+ else:
+ raise exc.ArgumentError(
+ "Could not parse rfc1738 URL from string '%s'" % name
+ )
+
+
+def _rfc_1738_quote(text):
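+ # percent-encode the reserved delimiters ":", "@" and "/"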
+ return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
+
+
+def _rfc_1738_unquote(text):
+ return util.unquote(text)
+
+
+def _parse_keyvalue_args(name):
+ m = re.match(r"(\w+)://(.*)", name)
+ if m is not None:
+ (name, args) = m.group(1, 2)
+ opts = dict(util.parse_qsl(args))
+ return URL(name, *opts)
+ else:
+ return None
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py
new file mode 100644
index 0000000..1b03ebb
--- /dev/null
+++ b/lib/sqlalchemy/engine/util.py
@@ -0,0 +1,253 @@
+# engine/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .. import exc
+from .. import util
+from ..util import collections_abc
+from ..util import immutabledict
+
+
+def connection_memoize(key):
+ """Decorator, memoize a function in a connection.info stash.
+
+ Only applicable to functions which take no arguments other than a
+ connection. The memo will be stored in ``connection.info[key]``.
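+
+ A usage sketch (the method and key names below are hypothetical)::
+
+ @connection_memoize("_server_version")
+ def server_version_info(self, connection):
+ return connection.exec_driver_sql("SELECT version()").scalar()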
+ """
+
+ @util.decorator
+ def decorated(fn, self, connection):
+ connection = connection.connect()
+ try:
+ return connection.info[key]
+ except KeyError:
+ connection.info[key] = val = fn(self, connection)
+ return val
+
+ return decorated
+
+
+_no_tuple = ()
+_no_kw = util.immutabledict()
+
+
+def _distill_params(connection, multiparams, params):
+ r"""Given arguments from the calling form \*multiparams, \**params,
+ return a list of bind parameter structures, usually a list of
+ dictionaries.
+
+ In the case of 'raw' execution which accepts positional parameters,
+ it may be a list of tuples or lists.
+
+ """
+
+ if not multiparams:
+ if params:
+ connection._warn_for_legacy_exec_format()
+ return [params]
+ else:
+ return []
+ elif len(multiparams) == 1:
+ zero = multiparams[0]
+ if isinstance(zero, (list, tuple)):
+ if (
+ not zero
+ or hasattr(zero[0], "__iter__")
+ and not hasattr(zero[0], "strip")
+ ):
+ # execute(stmt, [{}, {}, {}, ...])
+ # execute(stmt, [(), (), (), ...])
+ return zero
+ else:
+ # this is used by exec_driver_sql only, so a deprecation
+ # warning would already be coming from passing a plain
+ # textual statement with positional parameters to
+ # execute().
+ # execute(stmt, ("value", "value"))
+ return [zero]
+ elif hasattr(zero, "keys"):
+ # execute(stmt, {"key":"value"})
+ return [zero]
+ else:
+ connection._warn_for_legacy_exec_format()
+ # execute(stmt, "value")
+ return [[zero]]
+ else:
+ connection._warn_for_legacy_exec_format()
+ if hasattr(multiparams[0], "__iter__") and not hasattr(
+ multiparams[0], "strip"
+ ):
+ return multiparams
+ else:
+ return [multiparams]
+
+
+def _distill_cursor_params(connection, multiparams, params):
+ """_distill_params without any warnings. more appropriate for
+ "cursor" params that can include tuple arguments, lists of tuples,
+ etc.
+
+ """
+
+ if not multiparams:
+ if params:
+ return [params]
+ else:
+ return []
+ elif len(multiparams) == 1:
+ zero = multiparams[0]
+ if isinstance(zero, (list, tuple)):
+ if (
+ not zero
+ or hasattr(zero[0], "__iter__")
+ and not hasattr(zero[0], "strip")
+ ):
+ # execute(stmt, [{}, {}, {}, ...])
+ # execute(stmt, [(), (), (), ...])
+ return zero
+ else:
+ # this is used by exec_driver_sql only, so a deprecation
+ # warning would already be coming from passing a plain
+ # textual statement with positional parameters to
+ # execute().
+ # execute(stmt, ("value", "value"))
+
+ return [zero]
+ elif hasattr(zero, "keys"):
+ # execute(stmt, {"key":"value"})
+ return [zero]
+ else:
+ # execute(stmt, "value")
+ return [[zero]]
+ else:
+ if hasattr(multiparams[0], "__iter__") and not hasattr(
+ multiparams[0], "strip"
+ ):
+ return multiparams
+ else:
+ return [multiparams]
+
+
+def _distill_params_20(params):
+ if params is None:
+ return _no_tuple, _no_kw
+ elif isinstance(params, list):
+ # collections_abc.MutableSequence): # avoid abc.__instancecheck__
+ if params and not isinstance(
+ params[0], (collections_abc.Mapping, tuple)
+ ):
+ raise exc.ArgumentError(
+ "List argument must consist only of tuples or dictionaries"
+ )
+
+ return (params,), _no_kw
+ elif isinstance(
+ params,
+ (tuple, dict, immutabledict),
+ # only do abc.__instancecheck__ for Mapping after we've checked
+ # for plain dictionaries and would otherwise raise
+ ) or isinstance(params, collections_abc.Mapping):
+ return (params,), _no_kw
+ else:
+ raise exc.ArgumentError("mapping or sequence expected for parameters")
+
+
+class TransactionalContext(object):
+ """Apply Python context manager behavior to transaction objects.
+
+ Performs validation to ensure the subject of the transaction is not
+ used if the transaction were ended prematurely.
+
+ """
+
+ _trans_subject = None
+
+ def _transaction_is_active(self):
+ raise NotImplementedError()
+
+ def _transaction_is_closed(self):
+ raise NotImplementedError()
+
+ def _rollback_can_be_called(self):
+ """indicates the object is in a state that is known to be acceptable
+ for rollback() to be called.
+
+ This does not necessarily mean rollback() will succeed or not raise
+ an error, just that there is currently no state detected that indicates
+ rollback() would fail or emit warnings.
+
+ It also does not mean that there's a transaction in progress, as
+ it is usually safe to call rollback() even if no transaction is
+ present.
+
+ .. versionadded:: 1.4.28
+
+ """
+ raise NotImplementedError()
+
+ def _get_subject(self):
+ raise NotImplementedError()
+
+ @classmethod
+ def _trans_ctx_check(cls, subject):
+ trans_context = subject._trans_context_manager
+ if trans_context:
+ if not trans_context._transaction_is_active():
+ raise exc.InvalidRequestError(
+ "Can't operate on closed transaction inside context "
+ "manager. Please complete the context manager "
+ "before emitting further commands."
+ )
+
+ def __enter__(self):
+ subject = self._get_subject()
+
+ # none for outer transaction, may be non-None for nested
+ # savepoint, legacy nesting cases
+ trans_context = subject._trans_context_manager
+ self._outer_trans_ctx = trans_context
+
+ self._trans_subject = subject
+ subject._trans_context_manager = self
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ subject = self._trans_subject
+
+ # simplistically we could assume that
+ # "subject._trans_context_manager is self". However, any calling
+ # code that is manipulating __exit__ directly would break this
+ # assumption. alembic context manager
+ # is an example of partial use that just calls __exit__ and
+ # not __enter__ at the moment. it's safe to assume this is being done
+ # in the wild also
+ out_of_band_exit = (
+ subject is None or subject._trans_context_manager is not self
+ )
+
+ if type_ is None and self._transaction_is_active():
+ try:
+ self.commit()
+ except:
+ with util.safe_reraise():
+ if self._rollback_can_be_called():
+ self.rollback()
+ finally:
+ if not out_of_band_exit:
+ subject._trans_context_manager = self._outer_trans_ctx
+ self._trans_subject = self._outer_trans_ctx = None
+ else:
+ try:
+ if not self._transaction_is_active():
+ if not self._transaction_is_closed():
+ self.close()
+ else:
+ if self._rollback_can_be_called():
+ self.rollback()
+ finally:
+ if not out_of_band_exit:
+ subject._trans_context_manager = self._outer_trans_ctx
+ self._trans_subject = self._outer_trans_ctx = None
diff --git a/lib/sqlalchemy/event/__init__.py b/lib/sqlalchemy/event/__init__.py
new file mode 100644
index 0000000..a89bea8
--- /dev/null
+++ b/lib/sqlalchemy/event/__init__.py
@@ -0,0 +1,17 @@
+# event/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .api import CANCEL
+from .api import contains
+from .api import listen
+from .api import listens_for
+from .api import NO_RETVAL
+from .api import remove
+from .attr import RefCollection
+from .base import dispatcher
+from .base import Events
+from .legacy import _legacy_signature
diff --git a/lib/sqlalchemy/event/api.py b/lib/sqlalchemy/event/api.py
new file mode 100644
index 0000000..ce44f57
--- /dev/null
+++ b/lib/sqlalchemy/event/api.py
@@ -0,0 +1,219 @@
+# event/api.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Public API functions for the event system.
+
+"""
+from __future__ import absolute_import
+
+from .base import _registrars
+from .registry import _EventKey
+from .. import exc
+from .. import util
+
+
+CANCEL = util.symbol("CANCEL")
+NO_RETVAL = util.symbol("NO_RETVAL")
+
+
+def _event_key(target, identifier, fn):
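+ # locate an Events registrar for this identifier that accepts the
+ # target; the for/else raises if none matched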
+ for evt_cls in _registrars[identifier]:
+ tgt = evt_cls._accept_with(target)
+ if tgt is not None:
+ return _EventKey(target, identifier, fn, tgt)
+ else:
+ raise exc.InvalidRequestError(
+ "No such event '%s' for target '%s'" % (identifier, target)
+ )
+
+
+def listen(target, identifier, fn, *args, **kw):
+ """Register a listener function for the given target.
+
+ The :func:`.listen` function is part of the primary interface for the
+ SQLAlchemy event system, documented at :ref:`event_toplevel`.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.schema import UniqueConstraint
+
+ def unique_constraint_name(const, table):
+ const.name = "uq_%s_%s" % (
+ table.name,
+ list(const.columns)[0].name
+ )
+ event.listen(
+ UniqueConstraint,
+ "after_parent_attach",
+ unique_constraint_name)
+
+ :param bool insert: The default behavior for event handlers is to append
+ the decorated user defined function to an internal list of registered
+ event listeners upon discovery. If a user registers a function with
+ ``insert=True``, SQLAlchemy will insert (prepend) the function to the
+ internal list upon discovery. This feature is not typically used or
+ recommended by the SQLAlchemy maintainers, but is provided to ensure
+ certain user defined functions can run before others, such as when
+ :ref:`Changing the sql_mode in MySQL <mysql_sql_mode>`.
+
+ :param bool named: When using named argument passing, the names listed in
+ the function argument specification will be used as keys in the
+ dictionary.
+ See :ref:`event_named_argument_styles`.
+
+ :param bool once: Private/Internal API usage. Deprecated. This parameter
+ would provide that an event function would run only once per given
+ target. It does not however imply automatic de-registration of the
+ listener function; associating an arbitrarily high number of listeners
+ without explicitly removing them will cause memory to grow unbounded even
+ if ``once=True`` is specified.
+
+ :param bool propagate: The ``propagate`` kwarg is available when working
+ with ORM instrumentation and mapping events.
+ See :class:`_ormevent.MapperEvents` and
+ :meth:`_ormevent.MapperEvents.before_mapper_configured` for examples.
+
+ :param bool retval: This flag applies only to specific event listeners,
+ each of which includes documentation explaining when it should be used.
+ By default, no listener ever requires a return value.
+ However, some listeners do support special behaviors for return values,
+ and include in their documentation that the ``retval=True`` flag is
+ necessary for a return value to be processed.
+
+ Event listener suites that make use of :paramref:`_event.listen.retval`
+ include :class:`_events.ConnectionEvents` and
+ :class:`_ormevent.AttributeEvents`.
+
+ .. note::
+
+ The :func:`.listen` function cannot be called at the same time
+ that the target event is being run. This has implications
+ for thread safety, and also means an event cannot be added
+ from inside the listener function for itself. The list of
+ events to be run are present inside of a mutable collection
+ that can't be changed during iteration.
+
+ Event registration and removal is not intended to be a "high
+ velocity" operation; it is a configurational operation. For
+ systems that need to quickly associate and deassociate with
+ events at high scale, use a mutable structure that is handled
+ from inside of a single listener.
+
+ .. seealso::
+
+ :func:`.listens_for`
+
+ :func:`.remove`
+
+ """
+
+ _event_key(target, identifier, fn).listen(*args, **kw)
+
+
+def listens_for(target, identifier, *args, **kw):
+ """Decorate a function as a listener for the given target + identifier.
+
+ The :func:`.listens_for` decorator is part of the primary interface for the
+ SQLAlchemy event system, documented at :ref:`event_toplevel`.
+
+ This function generally shares the same kwargs as :func:`.listen`.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.schema import UniqueConstraint
+
+ @event.listens_for(UniqueConstraint, "after_parent_attach")
+ def unique_constraint_name(const, table):
+ const.name = "uq_%s_%s" % (
+ table.name,
+ list(const.columns)[0].name
+ )
+
+ A given function can also be invoked for only the first invocation
+ of the event using the ``once`` argument::
+
+ @event.listens_for(Mapper, "before_configure", once=True)
+ def on_config():
+ do_config()
+
+
+ .. warning:: The ``once`` argument does not imply automatic de-registration
+ of the listener function after it has been invoked a first time; a
+ listener entry will remain associated with the target object.
+ Associating an arbitrarily high number of listeners without explicitly
+ removing them will cause memory to grow unbounded even if ``once=True``
+ is specified.
+
+ .. seealso::
+
+ :func:`.listen` - general description of event listening
+
+ """
+
+ def decorate(fn):
+ listen(target, identifier, fn, *args, **kw)
+ return fn
+
+ return decorate
+
+
+def remove(target, identifier, fn):
+ """Remove an event listener.
+
+ The arguments here should match exactly those which were sent to
+ :func:`.listen`; all the event registration which proceeded as a result
+ of this call will be reverted by calling :func:`.remove` with the same
+ arguments.
+
+ e.g.::
+
+ # if a function was registered like this...
+ @event.listens_for(SomeMappedClass, "before_insert", propagate=True)
+ def my_listener_function(*arg):
+ pass
+
+ # ... it's removed like this
+ event.remove(SomeMappedClass, "before_insert", my_listener_function)
+
+ Above, the listener function associated with ``SomeMappedClass`` was also
+ propagated to subclasses of ``SomeMappedClass``; the :func:`.remove`
+ function will revert all of these operations.
+
+ .. note::
+
+ The :func:`.remove` function cannot be called at the same time
+ that the target event is being run. This has implications
+ for thread safety, and also means an event cannot be removed
+ from inside the listener function for itself. The list of
+ events to be run are present inside of a mutable collection
+ that can't be changed during iteration.
+
+ Event registration and removal is not intended to be a "high
+ velocity" operation; it is a configurational operation. For
+ systems that need to quickly associate and deassociate with
+ events at high scale, use a mutable structure that is handled
+ from inside of a single listener.
+
+ .. versionchanged:: 1.0.0 - a ``collections.deque()`` object is now
+ used as the container for the list of events, which explicitly
+ disallows collection mutation while the collection is being
+ iterated.
+
+ .. seealso::
+
+ :func:`.listen`
+
+ """
+ _event_key(target, identifier, fn).remove()
+
+
+def contains(target, identifier, fn):
+ """Return True if the given target/ident/fn is set up to listen."""
+
+ return _event_key(target, identifier, fn).contains()
diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py
new file mode 100644
index 0000000..0d16165
--- /dev/null
+++ b/lib/sqlalchemy/event/attr.py
@@ -0,0 +1,468 @@
+# event/attr.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Attribute implementation for _Dispatch classes.
+
+The various listener targets for a particular event class are represented
+as attributes, which refer to collections of listeners to be fired off.
+These collections can exist at the class level as well as at the instance
+level. An event is fired off using code like this::
+
+ some_object.dispatch.first_connect(arg1, arg2)
+
+Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and
+``first_connect`` is typically an instance of ``_ListenerCollection``
+if event listeners are present, or ``_EmptyListener`` if none are present.
+
+The attribute mechanics here spend effort trying to ensure listener functions
+are available with a minimum of function call overhead, that unnecessary
+objects aren't created (i.e. many empty per-instance listener collections),
+as well as that everything is garbage collectable when owning references are
+lost. Other features such as "propagation" of listener functions across
+many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances,
+as well as support for subclass propagation (e.g. events assigned to
+``Pool`` vs. ``QueuePool``) are all implemented here.
+
+"""
+
+from __future__ import absolute_import
+from __future__ import with_statement
+
+import collections
+from itertools import chain
+import weakref
+
+from . import legacy
+from . import registry
+from .. import exc
+from .. import util
+from ..util import threading
+from ..util.concurrency import AsyncAdaptedLock
+
+
+class RefCollection(util.MemoizedSlots):
+ __slots__ = ("ref",)
+
+ def _memoized_attr_ref(self):
+ return weakref.ref(self, registry._collection_gced)
+
+
+class _empty_collection(object):
+ def append(self, element):
+ pass
+
+ def extend(self, other):
+ pass
+
+ def remove(self, element):
+ pass
+
+ def __iter__(self):
+ return iter([])
+
+ def clear(self):
+ pass
+
+
+class _ClsLevelDispatch(RefCollection):
+ """Class-level events on :class:`._Dispatch` classes."""
+
+ __slots__ = (
+ "clsname",
+ "name",
+ "arg_names",
+ "has_kw",
+ "legacy_signatures",
+ "_clslevel",
+ "__weakref__",
+ )
+
+ def __init__(self, parent_dispatch_cls, fn):
+ self.name = fn.__name__
+ self.clsname = parent_dispatch_cls.__name__
+ argspec = util.inspect_getfullargspec(fn)
+ self.arg_names = argspec.args[1:]
+ self.has_kw = bool(argspec.varkw)
+ self.legacy_signatures = list(
+ reversed(
+ sorted(
+ getattr(fn, "_legacy_signatures", []), key=lambda s: s[0]
+ )
+ )
+ )
+ fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn)
+
+ self._clslevel = weakref.WeakKeyDictionary()
+
+ def _adjust_fn_spec(self, fn, named):
+ if named:
+ fn = self._wrap_fn_for_kw(fn)
+ if self.legacy_signatures:
+ try:
+ argspec = util.get_callable_argspec(fn, no_self=True)
+ except TypeError:
+ pass
+ else:
+ fn = legacy._wrap_fn_for_legacy(self, fn, argspec)
+ return fn
+
+ def _wrap_fn_for_kw(self, fn):
+ def wrap_kw(*args, **kw):
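+ # positional arguments arrive in the order of the event's declared
+ # argument names; rebuild a kwargs dict so the listener receives
+ # named arguments only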
+ argdict = dict(zip(self.arg_names, args))
+ argdict.update(kw)
+ return fn(**argdict)
+
+ return wrap_kw
+
+ def insert(self, event_key, propagate):
+ target = event_key.dispatch_target
+ assert isinstance(
+ target, type
+ ), "Class-level Event targets must be classes."
+ if not getattr(target, "_sa_propagate_class_events", True):
+ raise exc.InvalidRequestError(
+ "Can't assign an event directly to the %s class" % target
+ )
+
+ for cls in util.walk_subclasses(target):
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ if cls not in self._clslevel:
+ self._assign_cls_collection(cls)
+ self._clslevel[cls].appendleft(event_key._listen_fn)
+ registry._stored_in_collection(event_key, self)
+
+ def append(self, event_key, propagate):
+ target = event_key.dispatch_target
+ assert isinstance(
+ target, type
+ ), "Class-level Event targets must be classes."
+ if not getattr(target, "_sa_propagate_class_events", True):
+ raise exc.InvalidRequestError(
+ "Can't assign an event directly to the %s class" % target
+ )
+ for cls in util.walk_subclasses(target):
+ if cls is not target and cls not in self._clslevel:
+ self.update_subclass(cls)
+ else:
+ if cls not in self._clslevel:
+ self._assign_cls_collection(cls)
+ self._clslevel[cls].append(event_key._listen_fn)
+ registry._stored_in_collection(event_key, self)
+
+ def _assign_cls_collection(self, target):
+ if getattr(target, "_sa_propagate_class_events", True):
+ self._clslevel[target] = collections.deque()
+ else:
+ self._clslevel[target] = _empty_collection()
+
+ def update_subclass(self, target):
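+ # ensure the subclass has its own collection, then pull in listeners
+ # already established against superclasses in its MRO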
+ if target not in self._clslevel:
+ self._assign_cls_collection(target)
+ clslevel = self._clslevel[target]
+ for cls in target.__mro__[1:]:
+ if cls in self._clslevel:
+ clslevel.extend(
+ [fn for fn in self._clslevel[cls] if fn not in clslevel]
+ )
+
+ def remove(self, event_key):
+ target = event_key.dispatch_target
+ for cls in util.walk_subclasses(target):
+ if cls in self._clslevel:
+ self._clslevel[cls].remove(event_key._listen_fn)
+ registry._removed_from_collection(event_key, self)
+
+ def clear(self):
+ """Clear all class level listeners"""
+
+ to_clear = set()
+ for dispatcher in self._clslevel.values():
+ to_clear.update(dispatcher)
+ dispatcher.clear()
+ registry._clear(self, to_clear)
+
+ def for_modify(self, obj):
+ """Return an event collection which can be modified.
+
+ For _ClsLevelDispatch at the class level of
+ a dispatcher, this returns self.
+
+ """
+ return self
+
+
+class _InstanceLevelDispatch(RefCollection):
+ __slots__ = ()
+
+ def _adjust_fn_spec(self, fn, named):
+ return self.parent._adjust_fn_spec(fn, named)
+
+
+class _EmptyListener(_InstanceLevelDispatch):
+ """Serves as a proxy interface to the events
+ served by a _ClsLevelDispatch, when there are no
+ instance-level events present.
+
+ Is replaced by _ListenerCollection when instance-level
+ events are added.
+
+ """
+
+ propagate = frozenset()
+ listeners = ()
+
+ __slots__ = "parent", "parent_listeners", "name"
+
+ def __init__(self, parent, target_cls):
+ if target_cls not in parent._clslevel:
+ parent.update_subclass(target_cls)
+ self.parent = parent # _ClsLevelDispatch
+ self.parent_listeners = parent._clslevel[target_cls]
+ self.name = parent.name
+
+ def for_modify(self, obj):
+ """Return an event collection which can be modified.
+
+ For _EmptyListener at the instance level of
+ a dispatcher, this generates a new
+ _ListenerCollection, applies it to the instance,
+ and returns it.
+
+ """
+ result = _ListenerCollection(self.parent, obj._instance_cls)
+ if getattr(obj, self.name) is self:
+ setattr(obj, self.name, result)
+ else:
+ assert isinstance(getattr(obj, self.name), _JoinedListener)
+ return result
+
+ def _needs_modify(self, *args, **kw):
+ raise NotImplementedError("need to call for_modify()")
+
+ exec_once = (
+ exec_once_unless_exception
+ ) = insert = append = remove = clear = _needs_modify
+
+ def __call__(self, *args, **kw):
+ """Execute this event."""
+
+ for fn in self.parent_listeners:
+ fn(*args, **kw)
+
+ def __len__(self):
+ return len(self.parent_listeners)
+
+ def __iter__(self):
+ return iter(self.parent_listeners)
+
+ def __bool__(self):
+ return bool(self.parent_listeners)
+
+ __nonzero__ = __bool__
+
+
+class _CompoundListener(_InstanceLevelDispatch):
+ __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once"
+
+ def _set_asyncio(self):
+ self._exec_once_mutex = AsyncAdaptedLock()
+
+ def _memoized_attr__exec_once_mutex(self):
+ return threading.Lock()
+
+ def _exec_once_impl(self, retry_on_exception, *args, **kw):
+ with self._exec_once_mutex:
+ if not self._exec_once:
+ try:
+ self(*args, **kw)
+ exception = False
+ except:
+ exception = True
+ raise
+ finally:
+ if not exception or not retry_on_exception:
+ self._exec_once = True
+
+ def exec_once(self, *args, **kw):
+ """Execute this event, but only if it has not been
+ executed already for this collection."""
+
+ if not self._exec_once:
+ self._exec_once_impl(False, *args, **kw)
+
+ def exec_once_unless_exception(self, *args, **kw):
+ """Execute this event, but only if it has not been
+ executed already for this collection, or was called
+ by a previous exec_once_unless_exception call and
+ raised an exception.
+
+ If exec_once was already called, then this method will never run
+ the callable regardless of whether it raised or not.
+
+ .. versionadded:: 1.3.8
+
+ """
+ if not self._exec_once:
+ self._exec_once_impl(True, *args, **kw)
+
+ def _exec_w_sync_on_first_run(self, *args, **kw):
+ """Execute this event, and use a mutex if it has not been
+ executed already for this collection, or was called
+ by a previous _exec_w_sync_on_first_run call and
+ raised an exception.
+
+ If _exec_w_sync_on_first_run was already called and didn't raise an
+ exception, then a mutex is not used.
+
+ .. versionadded:: 1.4.11
+
+ """
+ if not self._exec_w_sync_once:
+ with self._exec_once_mutex:
+ try:
+ self(*args, **kw)
+ except:
+ raise
+ else:
+ self._exec_w_sync_once = True
+ else:
+ self(*args, **kw)
+
+ def __call__(self, *args, **kw):
+ """Execute this event."""
+
+ for fn in self.parent_listeners:
+ fn(*args, **kw)
+ for fn in self.listeners:
+ fn(*args, **kw)
+
+ def __len__(self):
+ return len(self.parent_listeners) + len(self.listeners)
+
+ def __iter__(self):
+ return chain(self.parent_listeners, self.listeners)
+
+ def __bool__(self):
+ return bool(self.listeners or self.parent_listeners)
+
+ __nonzero__ = __bool__
+
+
+class _ListenerCollection(_CompoundListener):
+ """Instance-level attributes on instances of :class:`._Dispatch`.
+
+ Represents a collection of listeners.
+
+ As of 0.7.9, _ListenerCollection is only first
+ created via the _EmptyListener.for_modify() method.
+
+ """
+
+ __slots__ = (
+ "parent_listeners",
+ "parent",
+ "name",
+ "listeners",
+ "propagate",
+ "__weakref__",
+ )
+
+ def __init__(self, parent, target_cls):
+ if target_cls not in parent._clslevel:
+ parent.update_subclass(target_cls)
+ self._exec_once = False
+ self._exec_w_sync_once = False
+ self.parent_listeners = parent._clslevel[target_cls]
+ self.parent = parent
+ self.name = parent.name
+ self.listeners = collections.deque()
+ self.propagate = set()
+
+ def for_modify(self, obj):
+ """Return an event collection which can be modified.
+
+ For _ListenerCollection at the instance level of
+ a dispatcher, this returns self.
+
+ """
+ return self
+
+ def _update(self, other, only_propagate=True):
+ """Populate from the listeners in another :class:`_Dispatch`
+ object."""
+
+ existing_listeners = self.listeners
+ existing_listener_set = set(existing_listeners)
+ self.propagate.update(other.propagate)
+ other_listeners = [
+ l
+ for l in other.listeners
+ if l not in existing_listener_set
+ and not only_propagate
+ or l in self.propagate
+ ]
+
+ existing_listeners.extend(other_listeners)
+
+ to_associate = other.propagate.union(other_listeners)
+ registry._stored_in_collection_multi(self, other, to_associate)
+
+ def insert(self, event_key, propagate):
+ if event_key.prepend_to_list(self, self.listeners):
+ if propagate:
+ self.propagate.add(event_key._listen_fn)
+
+ def append(self, event_key, propagate):
+ if event_key.append_to_list(self, self.listeners):
+ if propagate:
+ self.propagate.add(event_key._listen_fn)
+
+ def remove(self, event_key):
+ self.listeners.remove(event_key._listen_fn)
+ self.propagate.discard(event_key._listen_fn)
+ registry._removed_from_collection(event_key, self)
+
+ def clear(self):
+ registry._clear(self, self.listeners)
+ self.propagate.clear()
+ self.listeners.clear()
+
+
+class _JoinedListener(_CompoundListener):
+ __slots__ = "parent", "name", "local", "parent_listeners"
+
+ def __init__(self, parent, name, local):
+ self._exec_once = False
+ self.parent = parent
+ self.name = name
+ self.local = local
+ self.parent_listeners = self.local
+
+ @property
+ def listeners(self):
+ return getattr(self.parent, self.name)
+
+ def _adjust_fn_spec(self, fn, named):
+ return self.local._adjust_fn_spec(fn, named)
+
+ def for_modify(self, obj):
+ self.local = self.parent_listeners = self.local.for_modify(obj)
+ return self
+
+ def insert(self, event_key, propagate):
+ self.local.insert(event_key, propagate)
+
+ def append(self, event_key, propagate):
+ self.local.append(event_key, propagate)
+
+ def remove(self, event_key):
+ self.local.remove(event_key)
+
+ def clear(self):
+ raise NotImplementedError()
diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py
new file mode 100644
index 0000000..510e16b
--- /dev/null
+++ b/lib/sqlalchemy/event/base.py
@@ -0,0 +1,345 @@
+# event/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Base implementation classes.
+
+The public-facing ``Events`` serves as the base class for an event interface;
+its public attributes represent different kinds of events. These attributes
+are mirrored onto a ``_Dispatch`` class, which serves as a container for
+collections of listener functions. These collections are represented both
+at the class level of a particular ``_Dispatch`` class as well as within
+instances of ``_Dispatch``.
+
+"""
+from __future__ import absolute_import
+
+import weakref
+
+from .attr import _ClsLevelDispatch
+from .attr import _EmptyListener
+from .attr import _JoinedListener
+from .. import util
+
+
+_registrars = util.defaultdict(list)
+
+
+def _is_event_name(name):
+ # _sa_event prefix is special to support internal-only event names.
+ # most event names are just plain method names that aren't
+ # underscored.
+
+ return (
+ not name.startswith("_") and name != "dispatch"
+ ) or name.startswith("_sa_event")
+
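+# For illustration (editor's note, not part of the original source): under
+# this rule, "before_execute" and "_sa_event_foo" qualify as event names,
+# while "_listen" and "dispatch" do not.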
+
+class _UnpickleDispatch(object):
+ """Serializable callable that re-generates an instance of
+ :class:`_Dispatch` given a particular :class:`.Events` subclass.
+
+ """
+
+ def __call__(self, _instance_cls):
+ for cls in _instance_cls.__mro__:
+ if "dispatch" in cls.__dict__:
+ return cls.__dict__["dispatch"].dispatch._for_class(
+ _instance_cls
+ )
+ else:
+ raise AttributeError("No class with a 'dispatch' member present.")
+
+
+class _Dispatch(object):
+ """Mirror the event listening definitions of an Events class with
+ listener collections.
+
+ Classes which define a "dispatch" member will return a
+ non-instantiated :class:`._Dispatch` subclass when the member
+ is accessed at the class level. When the "dispatch" member is
+ accessed at the instance level of its owner, an instance
+ of the :class:`._Dispatch` class is returned.
+
+ A :class:`._Dispatch` class is generated for each :class:`.Events`
+ class defined, by the :func:`._create_dispatcher_class` function.
+ The original :class:`.Events` classes remain untouched.
+ This decouples the construction of :class:`.Events` subclasses from
+ the implementation used by the event internals, and allows
+ inspecting tools like Sphinx to work in an unsurprising
+ way against the public API.
+
+ """
+
+ # In one ORM edge case, an attribute is added to _Dispatch,
+ # so __dict__ is used in just that case and potentially others.
+ __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners"
+
+ _empty_listener_reg = weakref.WeakKeyDictionary()
+
+ def __init__(self, parent, instance_cls=None):
+ self._parent = parent
+ self._instance_cls = instance_cls
+
+ if instance_cls:
+ try:
+ self._empty_listeners = self._empty_listener_reg[instance_cls]
+ except KeyError:
+ self._empty_listeners = self._empty_listener_reg[
+ instance_cls
+ ] = {
+ ls.name: _EmptyListener(ls, instance_cls)
+ for ls in parent._event_descriptors
+ }
+ else:
+ self._empty_listeners = {}
+
+ def __getattr__(self, name):
+ # Assign EmptyListeners as attributes on demand
+ # to reduce startup time for new dispatch objects.
+ try:
+ ls = self._empty_listeners[name]
+ except KeyError:
+ raise AttributeError(name)
+ else:
+ setattr(self, ls.name, ls)
+ return ls
+
+ @property
+ def _event_descriptors(self):
+ for k in self._event_names:
+ # Yield _ClsLevelDispatch related
+ # to relevant event name.
+ yield getattr(self, k)
+
+ @property
+ def _listen(self):
+ return self._events._listen
+
+ def _for_class(self, instance_cls):
+ return self.__class__(self, instance_cls)
+
+ def _for_instance(self, instance):
+ instance_cls = instance.__class__
+ return self._for_class(instance_cls)
+
+ def _join(self, other):
+ """Create a 'join' of this :class:`._Dispatch` and another.
+
+ This new dispatcher will dispatch events to both
+ :class:`._Dispatch` objects.
+
+ """
+ if "_joined_dispatch_cls" not in self.__class__.__dict__:
+ cls = type(
+ "Joined%s" % self.__class__.__name__,
+ (_JoinedDispatcher,),
+ {"__slots__": self._event_names},
+ )
+
+ self.__class__._joined_dispatch_cls = cls
+ return self._joined_dispatch_cls(self, other)
+
+ def __reduce__(self):
+ return _UnpickleDispatch(), (self._instance_cls,)
+
+ def _update(self, other, only_propagate=True):
+ """Populate from the listeners in another :class:`_Dispatch`
+ object."""
+ for ls in other._event_descriptors:
+ if isinstance(ls, _EmptyListener):
+ continue
+ getattr(self, ls.name).for_modify(self)._update(
+ ls, only_propagate=only_propagate
+ )
+
+ def _clear(self):
+ for ls in self._event_descriptors:
+ ls.for_modify(self).clear()
+
+
+class _EventMeta(type):
+ """Intercept new Event subclasses and create
+ associated _Dispatch classes."""
+
+ def __init__(cls, classname, bases, dict_):
+ _create_dispatcher_class(cls, classname, bases, dict_)
+ type.__init__(cls, classname, bases, dict_)
+
+
+def _create_dispatcher_class(cls, classname, bases, dict_):
+ """Create a :class:`._Dispatch` class corresponding to an
+ :class:`.Events` class."""
+
+ # There are several ways to do this, e.g. make a Dispatch class
+ # that shares the '_listen' method of the Event class; this is
+ # the straight monkeypatch.
+ if hasattr(cls, "dispatch"):
+ dispatch_base = cls.dispatch.__class__
+ else:
+ dispatch_base = _Dispatch
+
+ event_names = [k for k in dict_ if _is_event_name(k)]
+ dispatch_cls = type(
+ "%sDispatch" % classname, (dispatch_base,), {"__slots__": event_names}
+ )
+
+ dispatch_cls._event_names = event_names
+
+ dispatch_inst = cls._set_dispatch(cls, dispatch_cls)
+ for k in dispatch_cls._event_names:
+ setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k]))
+ _registrars[k].append(cls)
+
+ for super_ in dispatch_cls.__bases__:
+ if issubclass(super_, _Dispatch) and super_ is not _Dispatch:
+ for ls in super_._events.dispatch._event_descriptors:
+ setattr(dispatch_inst, ls.name, ls)
+ dispatch_cls._event_names.append(ls.name)
+
+ if getattr(cls, "_dispatch_target", None):
+ the_cls = cls._dispatch_target
+ if (
+ hasattr(the_cls, "__slots__")
+ and "_slots_dispatch" in the_cls.__slots__
+ ):
+ cls._dispatch_target.dispatch = slots_dispatcher(cls)
+ else:
+ cls._dispatch_target.dispatch = dispatcher(cls)
+
+
+def _remove_dispatcher(cls):
+ for k in cls.dispatch._event_names:
+ _registrars[k].remove(cls)
+ if not _registrars[k]:
+ del _registrars[k]
+
+
+class Events(util.with_metaclass(_EventMeta, object)):
+ """Define event listening functions for a particular target type."""
+
+ @staticmethod
+ def _set_dispatch(cls, dispatch_cls):
+ # This allows an Events subclass to define additional utility
+ # methods made available to the target via
+ # "self.dispatch._events.<utilitymethod>"
+ # @staticmethod to allow easy "super" calls while in a metaclass
+ # constructor.
+ cls.dispatch = dispatch_cls(None)
+ dispatch_cls._events = cls
+ return cls.dispatch
+
+ @classmethod
+ def _accept_with(cls, target):
+ def dispatch_is(*types):
+ return all(isinstance(target.dispatch, t) for t in types)
+
+ def dispatch_parent_is(t):
+ return isinstance(target.dispatch.parent, t)
+
+ # Mapper, ClassManager, Session override this to
+ # also accept classes, scoped_sessions, sessionmakers, etc.
+ if hasattr(target, "dispatch"):
+ if (
+ dispatch_is(cls.dispatch.__class__)
+ or dispatch_is(type, cls.dispatch.__class__)
+ or (
+ dispatch_is(_JoinedDispatcher)
+ and dispatch_parent_is(cls.dispatch.__class__)
+ )
+ ):
+ return target
+
+ @classmethod
+ def _listen(
+ cls,
+ event_key,
+ propagate=False,
+ insert=False,
+ named=False,
+ asyncio=False,
+ ):
+ event_key.base_listen(
+ propagate=propagate, insert=insert, named=named, asyncio=asyncio
+ )
+
+ @classmethod
+ def _remove(cls, event_key):
+ event_key.remove()
+
+ @classmethod
+ def _clear(cls):
+ cls.dispatch._clear()
+
+
+class _JoinedDispatcher(object):
+ """Represent a connection between two _Dispatch objects."""
+
+ __slots__ = "local", "parent", "_instance_cls"
+
+ def __init__(self, local, parent):
+ self.local = local
+ self.parent = parent
+ self._instance_cls = self.local._instance_cls
+
+ def __getattr__(self, name):
+ # Assign _JoinedListeners as attributes on demand
+ # to reduce startup time for new dispatch objects.
+ ls = getattr(self.local, name)
+ jl = _JoinedListener(self.parent, ls.name, ls)
+ setattr(self, ls.name, jl)
+ return jl
+
+ @property
+ def _listen(self):
+ return self.parent._listen
+
+ @property
+ def _events(self):
+ return self.parent._events
+
+
+class dispatcher(object):
+ """Descriptor used by target classes to
+ deliver the _Dispatch class at the class level
+ and produce new _Dispatch instances for target
+ instances.
+
+ """
+
+ def __init__(self, events):
+ self.dispatch = events.dispatch
+ self.events = events
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self.dispatch
+
+ disp = self.dispatch._for_instance(obj)
+ try:
+ obj.__dict__["dispatch"] = disp
+ except AttributeError as ae:
+ util.raise_(
+ TypeError(
+ "target %r doesn't have __dict__, should it be "
+ "defining _slots_dispatch?" % (obj,)
+ ),
+ replace_context=ae,
+ )
+ return disp
+
+
+class slots_dispatcher(dispatcher):
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self.dispatch
+
+ if hasattr(obj, "_slots_dispatch"):
+ return obj._slots_dispatch
+
+ disp = self.dispatch._for_instance(obj)
+ obj._slots_dispatch = disp
+ return disp
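
To see how these pieces cooperate end to end, the following hedged sketch (not part of the patch; ``JobEvents`` and ``Job`` are hypothetical names) defines an ``Events`` subclass, attaches its dispatcher to a target class, and fires the event. ``_EventMeta`` builds the matching ``_Dispatch`` class automatically, and the ``dispatcher`` descriptor returns the class-level dispatch on the class and a per-instance dispatch on instances:

    from sqlalchemy import event
    from sqlalchemy.event.base import Events, dispatcher


    class JobEvents(Events):
        """Hypothetical event interface; each method defines one event."""

        def finished(self, job, result):
            """Intercept completion of a job."""


    class Job(object):
        dispatch = dispatcher(JobEvents)

        def run(self):
            result = 42
            # invoke every listener registered for "finished"
            self.dispatch.finished(self, result)
            return result


    @event.listens_for(Job, "finished")
    def on_finished(job, result):
        print("job finished with", result)


    Job().run()    # prints: job finished with 42
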
diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py
new file mode 100644
index 0000000..d9f6ce5
--- /dev/null
+++ b/lib/sqlalchemy/event/legacy.py
@@ -0,0 +1,185 @@
+# event/legacy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Routines to handle adaption of legacy call signatures,
+generation of deprecation notes and docstrings.
+
+"""
+
+from .. import util
+
+
+def _legacy_signature(since, argnames, converter=None):
+ def leg(fn):
+ if not hasattr(fn, "_legacy_signatures"):
+ fn._legacy_signatures = []
+ fn._legacy_signatures.append((since, argnames, converter))
+ return fn
+
+ return leg
+
+
+def _wrap_fn_for_legacy(dispatch_collection, fn, argspec):
+ for since, argnames, conv in dispatch_collection.legacy_signatures:
+ if argnames[-1] == "**kw":
+ has_kw = True
+ argnames = argnames[0:-1]
+ else:
+ has_kw = False
+
+ if len(argnames) == len(argspec.args) and has_kw is bool(
+ argspec.varkw
+ ):
+
+ formatted_def = "def %s(%s%s)" % (
+ dispatch_collection.name,
+ ", ".join(dispatch_collection.arg_names),
+ ", **kw" if has_kw else "",
+ )
+ warning_txt = (
+ 'The argument signature for the "%s.%s" event listener '
+ "has changed as of version %s, and conversion for "
+ "the old argument signature will be removed in a "
+ 'future release. The new signature is "%s"'
+ % (
+ dispatch_collection.clsname,
+ dispatch_collection.name,
+ since,
+ formatted_def,
+ )
+ )
+
+ if conv:
+ assert not has_kw
+
+ def wrap_leg(*args):
+ util.warn_deprecated(warning_txt, version=since)
+ return fn(*conv(*args))
+
+ else:
+
+ def wrap_leg(*args, **kw):
+ util.warn_deprecated(warning_txt, version=since)
+ argdict = dict(zip(dispatch_collection.arg_names, args))
+ args = [argdict[name] for name in argnames]
+ if has_kw:
+ return fn(*args, **kw)
+ else:
+ return fn(*args)
+
+ return wrap_leg
+ else:
+ return fn
+
+
+def _indent(text, indent):
+ return "\n".join(indent + line for line in text.split("\n"))
+
+
+def _standard_listen_example(dispatch_collection, sample_target, fn):
+ example_kw_arg = _indent(
+ "\n".join(
+ "%(arg)s = kw['%(arg)s']" % {"arg": arg}
+ for arg in dispatch_collection.arg_names[0:2]
+ ),
+ " ",
+ )
+ if dispatch_collection.legacy_signatures:
+ current_since = max(
+ since
+ for since, args, conv in dispatch_collection.legacy_signatures
+ )
+ else:
+ current_since = None
+ text = (
+ "from sqlalchemy import event\n\n\n"
+ "@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
+ "def receive_%(event_name)s("
+ "%(named_event_arguments)s%(has_kw_arguments)s):\n"
+ " \"listen for the '%(event_name)s' event\"\n"
+ "\n # ... (event handling logic) ...\n"
+ )
+
+ text %= {
+ "current_since": " (arguments as of %s)" % current_since
+ if current_since
+ else "",
+ "event_name": fn.__name__,
+ "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
+ "named_event_arguments": ", ".join(dispatch_collection.arg_names),
+ "example_kw_arg": example_kw_arg,
+ "sample_target": sample_target,
+ }
+ return text
+
+
+def _legacy_listen_examples(dispatch_collection, sample_target, fn):
+ text = ""
+ for since, args, conv in dispatch_collection.legacy_signatures:
+ text += (
+ "\n# DEPRECATED calling style (pre-%(since)s, "
+ "will be removed in a future release)\n"
+ "@event.listens_for(%(sample_target)s, '%(event_name)s')\n"
+ "def receive_%(event_name)s("
+ "%(named_event_arguments)s%(has_kw_arguments)s):\n"
+ " \"listen for the '%(event_name)s' event\"\n"
+ "\n # ... (event handling logic) ...\n"
+ % {
+ "since": since,
+ "event_name": fn.__name__,
+ "has_kw_arguments": " **kw"
+ if dispatch_collection.has_kw
+ else "",
+ "named_event_arguments": ", ".join(args),
+ "sample_target": sample_target,
+ }
+ )
+ return text
+
+
+def _version_signature_changes(parent_dispatch_cls, dispatch_collection):
+ since, args, conv = dispatch_collection.legacy_signatures[0]
+ return (
+ "\n.. deprecated:: %(since)s\n"
+ " The :class:`.%(clsname)s.%(event_name)s` event now accepts the \n"
+ " arguments ``%(named_event_arguments)s%(has_kw_arguments)s``.\n"
+ " Support for listener functions which accept the previous \n"
+ ' argument signature(s) listed above as "deprecated" will be \n'
+ " removed in a future release."
+ % {
+ "since": since,
+ "clsname": parent_dispatch_cls.__name__,
+ "event_name": dispatch_collection.name,
+ "named_event_arguments": ", ".join(dispatch_collection.arg_names),
+ "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "",
+ }
+ )
+
+
+def _augment_fn_docs(dispatch_collection, parent_dispatch_cls, fn):
+ header = (
+ ".. container:: event_signatures\n\n"
+ " Example argument forms::\n"
+ "\n"
+ )
+
+ sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj")
+ text = header + _indent(
+ _standard_listen_example(dispatch_collection, sample_target, fn),
+ " " * 8,
+ )
+ if dispatch_collection.legacy_signatures:
+ text += _indent(
+ _legacy_listen_examples(dispatch_collection, sample_target, fn),
+ " " * 8,
+ )
+
+ text += _version_signature_changes(
+ parent_dispatch_cls, dispatch_collection
+ )
+
+ return util.inject_docstring_text(fn.__doc__, text, 1)
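
For reference, a hedged sketch modeled on how SQLAlchemy's own event interfaces use this machinery (not code from this patch): the decorator records the old argument list on the function, and ``_wrap_fn_for_legacy`` later uses it to adapt listeners written against the pre-1.4 signature while emitting a deprecation warning.

    from sqlalchemy.event.legacy import _legacy_signature


    class SomeEvents(object):   # stand-in for an Events subclass
        @_legacy_signature(
            "1.4",
            ["conn", "clauseelement", "multiparams", "params"],
            lambda conn, clauseelement, multiparams, params, execution_options: (
                conn,
                clauseelement,
                multiparams,
                params,
            ),
        )
        def before_execute(
            self, conn, clauseelement, multiparams, params, execution_options
        ):
            """Current signature; pre-1.4 listeners are still adapted."""


    # the decorator simply records (since, argnames, converter)
    print(SomeEvents.before_execute._legacy_signatures)
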
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
new file mode 100644
index 0000000..ac143c4
--- /dev/null
+++ b/lib/sqlalchemy/event/registry.py
@@ -0,0 +1,297 @@
+# event/registry.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Provides managed registration services on behalf of :func:`.listen`
+arguments.
+
+By "managed registration", we mean that event listening functions and
+other objects can be added to various collections in such a way that their
+membership in all those collections can be revoked at once, based on
+an equivalent :class:`._EventKey`.
+
+"""
+
+from __future__ import absolute_import
+
+import collections
+import types
+import weakref
+
+from .. import exc
+from .. import util
+
+
+_key_to_collection = collections.defaultdict(dict)
+"""
+Given an original listen() argument, can locate all
+listener collections and the listener fn contained
+
+(target, identifier, fn) -> {
+ ref(listenercollection) -> ref(listener_fn)
+ ref(listenercollection) -> ref(listener_fn)
+ ref(listenercollection) -> ref(listener_fn)
+ }
+"""
+
+_collection_to_key = collections.defaultdict(dict)
+"""
+Given a _ListenerCollection or _ClsLevelListener, can locate
+all the original listen() arguments and the listener fn contained
+
+ref(listenercollection) -> {
+ ref(listener_fn) -> (target, identifier, fn),
+ ref(listener_fn) -> (target, identifier, fn),
+ ref(listener_fn) -> (target, identifier, fn),
+ }
+"""
+
+
+def _collection_gced(ref):
+ # defaultdict, so can't get a KeyError
+ if not _collection_to_key or ref not in _collection_to_key:
+ return
+ listener_to_key = _collection_to_key.pop(ref)
+ for key in listener_to_key.values():
+ if key in _key_to_collection:
+ # defaultdict, so can't get a KeyError
+ dispatch_reg = _key_to_collection[key]
+ dispatch_reg.pop(ref)
+ if not dispatch_reg:
+ _key_to_collection.pop(key)
+
+
+def _stored_in_collection(event_key, owner):
+ key = event_key._key
+
+ dispatch_reg = _key_to_collection[key]
+
+ owner_ref = owner.ref
+ listen_ref = weakref.ref(event_key._listen_fn)
+
+ if owner_ref in dispatch_reg:
+ return False
+
+ dispatch_reg[owner_ref] = listen_ref
+
+ listener_to_key = _collection_to_key[owner_ref]
+ listener_to_key[listen_ref] = key
+
+ return True
+
+
+def _removed_from_collection(event_key, owner):
+ key = event_key._key
+
+ dispatch_reg = _key_to_collection[key]
+
+ listen_ref = weakref.ref(event_key._listen_fn)
+
+ owner_ref = owner.ref
+ dispatch_reg.pop(owner_ref, None)
+ if not dispatch_reg:
+ del _key_to_collection[key]
+
+ if owner_ref in _collection_to_key:
+ listener_to_key = _collection_to_key[owner_ref]
+ listener_to_key.pop(listen_ref)
+
+
+def _stored_in_collection_multi(newowner, oldowner, elements):
+ if not elements:
+ return
+
+ oldowner = oldowner.ref
+ newowner = newowner.ref
+
+ old_listener_to_key = _collection_to_key[oldowner]
+ new_listener_to_key = _collection_to_key[newowner]
+
+ for listen_fn in elements:
+ listen_ref = weakref.ref(listen_fn)
+ try:
+ key = old_listener_to_key[listen_ref]
+ except KeyError:
+ # can occur during interpreter shutdown.
+ # see #6740
+ continue
+
+ try:
+ dispatch_reg = _key_to_collection[key]
+ except KeyError:
+ continue
+
+ if newowner in dispatch_reg:
+ assert dispatch_reg[newowner] == listen_ref
+ else:
+ dispatch_reg[newowner] = listen_ref
+
+ new_listener_to_key[listen_ref] = key
+
+
+def _clear(owner, elements):
+ if not elements:
+ return
+
+ owner = owner.ref
+ listener_to_key = _collection_to_key[owner]
+ for listen_fn in elements:
+ listen_ref = weakref.ref(listen_fn)
+ key = listener_to_key[listen_ref]
+ dispatch_reg = _key_to_collection[key]
+ dispatch_reg.pop(owner, None)
+
+ if not dispatch_reg:
+ del _key_to_collection[key]
+
+
+class _EventKey(object):
+ """Represent :func:`.listen` arguments."""
+
+ __slots__ = (
+ "target",
+ "identifier",
+ "fn",
+ "fn_key",
+ "fn_wrap",
+ "dispatch_target",
+ )
+
+ def __init__(self, target, identifier, fn, dispatch_target, _fn_wrap=None):
+ self.target = target
+ self.identifier = identifier
+ self.fn = fn
+ if isinstance(fn, types.MethodType):
+ self.fn_key = id(fn.__func__), id(fn.__self__)
+ else:
+ self.fn_key = id(fn)
+ self.fn_wrap = _fn_wrap
+ self.dispatch_target = dispatch_target
+
+ @property
+ def _key(self):
+ return (id(self.target), self.identifier, self.fn_key)
+
+ def with_wrapper(self, fn_wrap):
+ if fn_wrap is self._listen_fn:
+ return self
+ else:
+ return _EventKey(
+ self.target,
+ self.identifier,
+ self.fn,
+ self.dispatch_target,
+ _fn_wrap=fn_wrap,
+ )
+
+ def with_dispatch_target(self, dispatch_target):
+ if dispatch_target is self.dispatch_target:
+ return self
+ else:
+ return _EventKey(
+ self.target,
+ self.identifier,
+ self.fn,
+ dispatch_target,
+ _fn_wrap=self.fn_wrap,
+ )
+
+ def listen(self, *args, **kw):
+ once = kw.pop("once", False)
+ once_unless_exception = kw.pop("_once_unless_exception", False)
+ named = kw.pop("named", False)
+
+ target, identifier, fn = (
+ self.dispatch_target,
+ self.identifier,
+ self._listen_fn,
+ )
+
+ dispatch_collection = getattr(target.dispatch, identifier)
+
+ adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named)
+
+ self = self.with_wrapper(adjusted_fn)
+
+ stub_function = getattr(
+ self.dispatch_target.dispatch._events, self.identifier
+ )
+ if hasattr(stub_function, "_sa_warn"):
+ stub_function._sa_warn()
+
+ if once or once_unless_exception:
+ self.with_wrapper(
+ util.only_once(
+ self._listen_fn, retry_on_exception=once_unless_exception
+ )
+ ).listen(*args, **kw)
+ else:
+ self.dispatch_target.dispatch._listen(self, *args, **kw)
+
+ def remove(self):
+ key = self._key
+
+ if key not in _key_to_collection:
+ raise exc.InvalidRequestError(
+ "No listeners found for event %s / %r / %s "
+ % (self.target, self.identifier, self.fn)
+ )
+
+ dispatch_reg = _key_to_collection.pop(key)
+
+ for collection_ref, listener_ref in dispatch_reg.items():
+ collection = collection_ref()
+ listener_fn = listener_ref()
+ if collection is not None and listener_fn is not None:
+ collection.remove(self.with_wrapper(listener_fn))
+
+ def contains(self):
+ """Return True if this event key is registered to listen."""
+ return self._key in _key_to_collection
+
+ def base_listen(
+ self,
+ propagate=False,
+ insert=False,
+ named=False,
+ retval=None,
+ asyncio=False,
+ ):
+
+ target, identifier = self.dispatch_target, self.identifier
+
+ dispatch_collection = getattr(target.dispatch, identifier)
+
+ for_modify = dispatch_collection.for_modify(target.dispatch)
+ if asyncio:
+ for_modify._set_asyncio()
+
+ if insert:
+ for_modify.insert(self, propagate)
+ else:
+ for_modify.append(self, propagate)
+
+ @property
+ def _listen_fn(self):
+ return self.fn_wrap or self.fn
+
+ def append_to_list(self, owner, list_):
+ if _stored_in_collection(self, owner):
+ list_.append(self._listen_fn)
+ return True
+ else:
+ return False
+
+ def remove_from_list(self, owner, list_):
+ _removed_from_collection(self, owner)
+ list_.remove(self._listen_fn)
+
+ def prepend_to_list(self, owner, list_):
+ if _stored_in_collection(self, owner):
+ list_.appendleft(self._listen_fn)
+ return True
+ else:
+ return False
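
The registry is what makes listener removal symmetrical with registration. A brief sketch of the public API that sits on top of it (assuming SQLAlchemy is installed; not code from this patch): ``listen()`` builds an ``_EventKey`` and stores it in ``_key_to_collection``, ``contains()`` checks for that key, and ``remove()`` revokes the listener from every collection it was added to.

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    def on_connect(dbapi_connection, connection_record):
        print("new DBAPI connection")

    event.listen(engine, "connect", on_connect)
    print(event.contains(engine, "connect", on_connect))   # True

    event.remove(engine, "connect", on_connect)
    print(event.contains(engine, "connect", on_connect))   # False
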
diff --git a/lib/sqlalchemy/events.py b/lib/sqlalchemy/events.py
new file mode 100644
index 0000000..d17b0b1
--- /dev/null
+++ b/lib/sqlalchemy/events.py
@@ -0,0 +1,14 @@
+# sqlalchemy/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Core event interfaces."""
+
+from .engine.events import ConnectionEvents
+from .engine.events import DialectEvents
+from .pool.events import PoolEvents
+from .sql.base import SchemaEventTarget
+from .sql.events import DDLEvents
diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py
new file mode 100644
index 0000000..78bcef3
--- /dev/null
+++ b/lib/sqlalchemy/exc.py
@@ -0,0 +1,733 @@
+# sqlalchemy/exc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Exceptions used with SQLAlchemy.
+
+The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are
+raised as a result of DBAPI exceptions are all subclasses of
+:exc:`.DBAPIError`.
+
+"""
+
+from .util import _preloaded
+from .util import compat
+
+_version_token = None
+
+
+class HasDescriptionCode(object):
+ """helper which adds 'code' as an attribute and '_code_str' as a method"""
+
+ code = None
+
+ def __init__(self, *arg, **kw):
+ code = kw.pop("code", None)
+ if code is not None:
+ self.code = code
+ super(HasDescriptionCode, self).__init__(*arg, **kw)
+
+ def _code_str(self):
+ if not self.code:
+ return ""
+ else:
+ return (
+ "(Background on this error at: "
+ "https://sqlalche.me/e/%s/%s)"
+ % (
+ _version_token,
+ self.code,
+ )
+ )
+
+ def __str__(self):
+ message = super(HasDescriptionCode, self).__str__()
+ if self.code:
+ message = "%s %s" % (message, self._code_str())
+ return message
+
+
+class SQLAlchemyError(HasDescriptionCode, Exception):
+ """Generic error class."""
+
+ def _message(self, as_unicode=compat.py3k):
+ # rules:
+ #
+ # 1. under py2k, for __str__ return single string arg as it was
+ # given without converting to unicode. for __unicode__
+ # do a conversion but check that it's not unicode already just in
+ # case
+ #
+ # 2. under py3k, single arg string will usually be a unicode
+ # object, but since __str__() must return unicode, check for
+ # bytestring just in case
+ #
+ # 3. for multiple self.args, this is not a case in current
+ # SQLAlchemy though this is happening in at least one known external
+ # library, call str() which does a repr().
+ #
+ if len(self.args) == 1:
+ text = self.args[0]
+
+ if as_unicode and isinstance(text, compat.binary_types):
+ text = compat.decode_backslashreplace(text, "utf-8")
+ # This is for when the argument is not a string of any sort.
+ # Otherwise, converting this exception to string would fail for
+ # non-string arguments.
+ elif compat.py3k or not as_unicode:
+ text = str(text)
+ else:
+ text = compat.text_type(text)
+
+ return text
+ else:
+ # this is not a normal case within SQLAlchemy but is here for
+ # compatibility with Exception.args - the str() comes out as
+ # a repr() of the tuple
+ return str(self.args)
+
+ def _sql_message(self, as_unicode):
+ message = self._message(as_unicode)
+
+ if self.code:
+ message = "%s %s" % (message, self._code_str())
+
+ return message
+
+ def __str__(self):
+ return self._sql_message(compat.py3k)
+
+ def __unicode__(self):
+ return self._sql_message(as_unicode=True)
+
+
+class ArgumentError(SQLAlchemyError):
+ """Raised when an invalid or conflicting function argument is supplied.
+
+ This error generally corresponds to construction time state errors.
+
+ """
+
+
+class ObjectNotExecutableError(ArgumentError):
+ """Raised when an object is passed to .execute() that can't be
+ executed as SQL.
+
+ .. versionadded:: 1.1
+
+ """
+
+ def __init__(self, target):
+ super(ObjectNotExecutableError, self).__init__(
+ "Not an executable object: %r" % target
+ )
+ self.target = target
+
+ def __reduce__(self):
+ return self.__class__, (self.target,)
+
+
+class NoSuchModuleError(ArgumentError):
+ """Raised when a dynamically-loaded module (usually a database dialect)
+ of a particular name cannot be located."""
+
+
+class NoForeignKeysError(ArgumentError):
+ """Raised when no foreign keys can be located between two selectables
+ during a join."""
+
+
+class AmbiguousForeignKeysError(ArgumentError):
+ """Raised when more than one foreign key matching can be located
+ between two selectables during a join."""
+
+
+class CircularDependencyError(SQLAlchemyError):
+ """Raised by topological sorts when a circular dependency is detected.
+
+ There are two scenarios where this error occurs:
+
+ * In a Session flush operation, if two objects are mutually dependent
+ on each other, they can not be inserted or deleted via INSERT or
+ DELETE statements alone; an UPDATE will be needed to post-associate
+ or pre-deassociate one of the foreign key constrained values.
+ The ``post_update`` flag described at :ref:`post_update` can resolve
+ this cycle.
+ * In a :attr:`_schema.MetaData.sorted_tables` operation, two
+ :class:`_schema.ForeignKey`
+ or :class:`_schema.ForeignKeyConstraint` objects mutually refer to each
+ other. Apply the ``use_alter=True`` flag to one or both,
+ see :ref:`use_alter`.
+
+ """
+
+ def __init__(self, message, cycles, edges, msg=None, code=None):
+ if msg is None:
+ message += " (%s)" % ", ".join(repr(s) for s in cycles)
+ else:
+ message = msg
+ SQLAlchemyError.__init__(self, message, code=code)
+ self.cycles = cycles
+ self.edges = edges
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (None, self.cycles, self.edges, self.args[0]),
+ {"code": self.code} if self.code is not None else {},
+ )
+
+
+class CompileError(SQLAlchemyError):
+ """Raised when an error occurs during SQL compilation"""
+
+
+class UnsupportedCompilationError(CompileError):
+ """Raised when an operation is not supported by the given compiler.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string`
+
+ :ref:`error_l7de`
+ """
+
+ code = "l7de"
+
+ def __init__(self, compiler, element_type, message=None):
+ super(UnsupportedCompilationError, self).__init__(
+ "Compiler %r can't render element of type %s%s"
+ % (compiler, element_type, ": %s" % message if message else "")
+ )
+ self.compiler = compiler
+ self.element_type = element_type
+ self.message = message
+
+ def __reduce__(self):
+ return self.__class__, (self.compiler, self.element_type, self.message)
+
+
+class IdentifierError(SQLAlchemyError):
+ """Raised when a schema name is beyond the max character limit"""
+
+
+class DisconnectionError(SQLAlchemyError):
+ """A disconnect is detected on a raw DB-API connection.
+
+ This error is raised and consumed internally by a connection pool. It can
+ be raised by the :meth:`_events.PoolEvents.checkout`
+ event so that the host pool
+ forces a retry; the exception will be caught three times in a row before
+ the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError`
+ regarding the connection attempt.
+
+ """
+
+ invalidate_pool = False
+
+
+class InvalidatePoolError(DisconnectionError):
+ """Raised when the connection pool should invalidate all stale connections.
+
+ A subclass of :class:`_exc.DisconnectionError` that indicates that the
+ disconnect situation encountered on the connection probably means the
+ entire pool should be invalidated, as the database has been restarted.
+
+ This exception is otherwise handled the same way as
+ :class:`_exc.DisconnectionError`, allowing three attempts to reconnect
+ before giving up.
+
+ .. versionadded:: 1.2
+
+ """
+
+ invalidate_pool = True
+
+
+class TimeoutError(SQLAlchemyError): # noqa
+ """Raised when a connection pool times out on getting a connection."""
+
+
+class InvalidRequestError(SQLAlchemyError):
+ """SQLAlchemy was asked to do something it can't do.
+
+ This error generally corresponds to runtime state errors.
+
+ """
+
+
+class NoInspectionAvailable(InvalidRequestError):
+ """A subject passed to :func:`sqlalchemy.inspection.inspect` produced
+ no context for inspection."""
+
+
+class PendingRollbackError(InvalidRequestError):
+ """A transaction has failed and needs to be rolled back before
+ continuing.
+
+ .. versionadded:: 1.4
+
+ """
+
+
+class ResourceClosedError(InvalidRequestError):
+ """An operation was requested from a connection, cursor, or other
+ object that's in a closed state."""
+
+
+class NoSuchColumnError(InvalidRequestError, KeyError):
+ """A nonexistent column is requested from a ``Row``."""
+
+
+class NoResultFound(InvalidRequestError):
+ """A database result was required but none was found.
+
+
+ .. versionchanged:: 1.4 This exception is now part of the
+ ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol
+ remains importable from ``sqlalchemy.orm.exc``.
+
+
+ """
+
+
+class MultipleResultsFound(InvalidRequestError):
+ """A single database result was required but more than one were found.
+
+ .. versionchanged:: 1.4 This exception is now part of the
+ ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol
+ remains importable from ``sqlalchemy.orm.exc``.
+
+
+ """
+
+
+class NoReferenceError(InvalidRequestError):
+ """Raised by ``ForeignKey`` to indicate a reference cannot be resolved."""
+
+
+class AwaitRequired(InvalidRequestError):
+ """Error raised by the async greenlet spawn if no async operation
+ was awaited when it required one.
+
+ """
+
+ code = "xd1r"
+
+
+class MissingGreenlet(InvalidRequestError):
+ r"""Error raised by the async greenlet await\_ if called while not inside
+ the greenlet spawn context.
+
+ """
+
+ code = "xd2s"
+
+
+class NoReferencedTableError(NoReferenceError):
+ """Raised by ``ForeignKey`` when the referred ``Table`` cannot be
+ located.
+
+ """
+
+ def __init__(self, message, tname):
+ NoReferenceError.__init__(self, message)
+ self.table_name = tname
+
+ def __reduce__(self):
+ return self.__class__, (self.args[0], self.table_name)
+
+
+class NoReferencedColumnError(NoReferenceError):
+ """Raised by ``ForeignKey`` when the referred ``Column`` cannot be
+ located.
+
+ """
+
+ def __init__(self, message, tname, cname):
+ NoReferenceError.__init__(self, message)
+ self.table_name = tname
+ self.column_name = cname
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (self.args[0], self.table_name, self.column_name),
+ )
+
+
+class NoSuchTableError(InvalidRequestError):
+ """Table does not exist or is not visible to a connection."""
+
+
+class UnreflectableTableError(InvalidRequestError):
+ """Table exists but can't be reflected for some reason.
+
+ .. versionadded:: 1.2
+
+ """
+
+
+class UnboundExecutionError(InvalidRequestError):
+ """SQL was attempted without a database connection to execute it on."""
+
+
+class DontWrapMixin(object):
+ """A mixin class which, when applied to a user-defined Exception class,
+ will not be wrapped inside of :exc:`.StatementError` if the error is
+ emitted within the process of executing a statement.
+
+ E.g.::
+
+ from sqlalchemy.exc import DontWrapMixin
+
+ class MyCustomException(Exception, DontWrapMixin):
+ pass
+
+ class MySpecialType(TypeDecorator):
+ impl = String
+
+ def process_bind_param(self, value, dialect):
+ if value == 'invalid':
+ raise MyCustomException("invalid!")
+
+ """
+
+
+class StatementError(SQLAlchemyError):
+ """An error occurred during execution of a SQL statement.
+
+ :class:`StatementError` wraps the exception raised
+ during execution, and features :attr:`.statement`
+ and :attr:`.params` attributes which supply context regarding
+ the specifics of the statement which had an issue.
+
+ The wrapped exception object is available in
+ the :attr:`.orig` attribute.
+
+ """
+
+ statement = None
+ """The string SQL statement being invoked when this exception occurred."""
+
+ params = None
+ """The parameter list being used when this exception occurred."""
+
+ orig = None
+ """The DBAPI exception object."""
+
+ ismulti = None
+
+ def __init__(
+ self,
+ message,
+ statement,
+ params,
+ orig,
+ hide_parameters=False,
+ code=None,
+ ismulti=None,
+ ):
+ SQLAlchemyError.__init__(self, message, code=code)
+ self.statement = statement
+ self.params = params
+ self.orig = orig
+ self.ismulti = ismulti
+ self.hide_parameters = hide_parameters
+ self.detail = []
+
+ def add_detail(self, msg):
+ self.detail.append(msg)
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (
+ self.args[0],
+ self.statement,
+ self.params,
+ self.orig,
+ self.hide_parameters,
+ self.__dict__.get("code"),
+ self.ismulti,
+ ),
+ {"detail": self.detail},
+ )
+
+ @_preloaded.preload_module("sqlalchemy.sql.util")
+ def _sql_message(self, as_unicode):
+ util = _preloaded.preloaded.sql_util
+
+ details = [self._message(as_unicode=as_unicode)]
+ if self.statement:
+ if not as_unicode and not compat.py3k:
+ stmt_detail = "[SQL: %s]" % compat.safe_bytestring(
+ self.statement
+ )
+ else:
+ stmt_detail = "[SQL: %s]" % self.statement
+ details.append(stmt_detail)
+ if self.params:
+ if self.hide_parameters:
+ details.append(
+ "[SQL parameters hidden due to hide_parameters=True]"
+ )
+ else:
+ params_repr = util._repr_params(
+ self.params, 10, ismulti=self.ismulti
+ )
+ details.append("[parameters: %r]" % params_repr)
+ code_str = self._code_str()
+ if code_str:
+ details.append(code_str)
+ return "\n".join(["(%s)" % det for det in self.detail] + details)
+
+
+class DBAPIError(StatementError):
+ """Raised when the execution of a database operation fails.
+
+ Wraps exceptions raised by the DB-API underlying the
+ database operation. Driver-specific implementations of the standard
+ DB-API exception types are wrapped by matching sub-types of SQLAlchemy's
+ :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to
+ :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note
+ that there is no guarantee that different DB-API implementations will
+ raise the same exception type for any given error condition.
+
+ :class:`DBAPIError` features :attr:`~.StatementError.statement`
+ and :attr:`~.StatementError.params` attributes which supply context
+ regarding the specifics of the statement which had an issue, for the
+ typical case when the error was raised within the context of
+ emitting a SQL statement.
+
+ The wrapped exception object is available in the
+ :attr:`~.StatementError.orig` attribute. Its type and properties are
+ DB-API implementation specific.
+
+ """
+
+ code = "dbapi"
+
+ @classmethod
+ def instance(
+ cls,
+ statement,
+ params,
+ orig,
+ dbapi_base_err,
+ hide_parameters=False,
+ connection_invalidated=False,
+ dialect=None,
+ ismulti=None,
+ ):
+ # Don't ever wrap these, just return them directly as if
+ # DBAPIError didn't exist.
+ if (
+ isinstance(orig, BaseException) and not isinstance(orig, Exception)
+ ) or isinstance(orig, DontWrapMixin):
+ return orig
+
+ if orig is not None:
+ # not a DBAPI error, statement is present.
+ # raise a StatementError
+ if isinstance(orig, SQLAlchemyError) and statement:
+ return StatementError(
+ "(%s.%s) %s"
+ % (
+ orig.__class__.__module__,
+ orig.__class__.__name__,
+ orig.args[0],
+ ),
+ statement,
+ params,
+ orig,
+ hide_parameters=hide_parameters,
+ code=orig.code,
+ ismulti=ismulti,
+ )
+ elif not isinstance(orig, dbapi_base_err) and statement:
+ return StatementError(
+ "(%s.%s) %s"
+ % (
+ orig.__class__.__module__,
+ orig.__class__.__name__,
+ orig,
+ ),
+ statement,
+ params,
+ orig,
+ hide_parameters=hide_parameters,
+ ismulti=ismulti,
+ )
+
+ glob = globals()
+ for super_ in orig.__class__.__mro__:
+ name = super_.__name__
+ if dialect:
+ name = dialect.dbapi_exception_translation_map.get(
+ name, name
+ )
+ if name in glob and issubclass(glob[name], DBAPIError):
+ cls = glob[name]
+ break
+
+ return cls(
+ statement,
+ params,
+ orig,
+ connection_invalidated=connection_invalidated,
+ hide_parameters=hide_parameters,
+ code=cls.code,
+ ismulti=ismulti,
+ )
+
+ def __reduce__(self):
+ return (
+ self.__class__,
+ (
+ self.statement,
+ self.params,
+ self.orig,
+ self.hide_parameters,
+ self.connection_invalidated,
+ self.__dict__.get("code"),
+ self.ismulti,
+ ),
+ {"detail": self.detail},
+ )
+
+ def __init__(
+ self,
+ statement,
+ params,
+ orig,
+ hide_parameters=False,
+ connection_invalidated=False,
+ code=None,
+ ismulti=None,
+ ):
+ try:
+ text = str(orig)
+ except Exception as e:
+ text = "Error in str() of DB-API-generated exception: " + str(e)
+ StatementError.__init__(
+ self,
+ "(%s.%s) %s"
+ % (orig.__class__.__module__, orig.__class__.__name__, text),
+ statement,
+ params,
+ orig,
+ hide_parameters,
+ code=code,
+ ismulti=ismulti,
+ )
+ self.connection_invalidated = connection_invalidated
+
+
+class InterfaceError(DBAPIError):
+ """Wraps a DB-API InterfaceError."""
+
+ code = "rvf5"
+
+
+class DatabaseError(DBAPIError):
+ """Wraps a DB-API DatabaseError."""
+
+ code = "4xp6"
+
+
+class DataError(DatabaseError):
+ """Wraps a DB-API DataError."""
+
+ code = "9h9h"
+
+
+class OperationalError(DatabaseError):
+ """Wraps a DB-API OperationalError."""
+
+ code = "e3q8"
+
+
+class IntegrityError(DatabaseError):
+ """Wraps a DB-API IntegrityError."""
+
+ code = "gkpj"
+
+
+class InternalError(DatabaseError):
+ """Wraps a DB-API InternalError."""
+
+ code = "2j85"
+
+
+class ProgrammingError(DatabaseError):
+ """Wraps a DB-API ProgrammingError."""
+
+ code = "f405"
+
+
+class NotSupportedError(DatabaseError):
+ """Wraps a DB-API NotSupportedError."""
+
+ code = "tw8g"
+
+
+# Warnings
+
+
+class SADeprecationWarning(HasDescriptionCode, DeprecationWarning):
+ """Issued for usage of deprecated APIs."""
+
+ deprecated_since = None
+ "Indicates the version that started raising this deprecation warning"
+
+
+class Base20DeprecationWarning(SADeprecationWarning):
+ """Issued for usage of APIs specifically deprecated or legacy in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :ref:`error_b8d9`.
+
+ :ref:`deprecation_20_mode`
+
+ """
+
+ deprecated_since = "1.4"
+ "Indicates the version that started raising this deprecation warning"
+
+ def __str__(self):
+ return (
+ super(Base20DeprecationWarning, self).__str__()
+ + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)"
+ )
+
+
+class LegacyAPIWarning(Base20DeprecationWarning):
+ """indicates an API that is in 'legacy' status, a long term deprecation."""
+
+
+class RemovedIn20Warning(Base20DeprecationWarning):
+ """indicates an API that will be fully removed in SQLAlchemy 2.0."""
+
+
+class MovedIn20Warning(RemovedIn20Warning):
+ """Subtype of RemovedIn20Warning to indicate an API that moved only."""
+
+
+class SAPendingDeprecationWarning(PendingDeprecationWarning):
+ """A similar warning as :class:`_exc.SADeprecationWarning`, this warning
+ is not used in modern versions of SQLAlchemy.
+
+ """
+
+ deprecated_since = None
+ "Indicates the version that started raising this deprecation warning"
+
+
+class SAWarning(HasDescriptionCode, RuntimeWarning):
+ """Issued at runtime."""
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py
new file mode 100644
index 0000000..62bbbf3
--- /dev/null
+++ b/lib/sqlalchemy/ext/__init__.py
@@ -0,0 +1,11 @@
+# ext/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .. import util as _sa_util
+
+
+_sa_util.preloaded.import_prefix("sqlalchemy.ext")
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
new file mode 100644
index 0000000..fbf377a
--- /dev/null
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -0,0 +1,1627 @@
+# ext/associationproxy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Contain the ``AssociationProxy`` class.
+
+The ``AssociationProxy`` is a Python property object which provides
+transparent proxied access to the endpoint of an association object.
+
+See the example ``examples/association/proxied_association.py``.
+
+"""
+import operator
+
+from .. import exc
+from .. import inspect
+from .. import orm
+from .. import util
+from ..orm import collections
+from ..orm import interfaces
+from ..sql import or_
+from ..sql.operators import ColumnOperators
+
+
+def association_proxy(target_collection, attr, **kw):
+ r"""Return a Python property implementing a view of a target
+ attribute which references an attribute on members of the
+ target.
+
+ The returned value is an instance of :class:`.AssociationProxy`.
+
+ Implements a Python property representing a relationship as a collection
+ of simpler values, or a scalar value. The proxied property will mimic
+ the collection type of the target (list, dict or set), or, in the case of
+ a one to one relationship, a simple scalar value.
+
+ :param target_collection: Name of the attribute we'll proxy to.
+ This attribute is typically mapped by
+ :func:`~sqlalchemy.orm.relationship` to link to a target collection, but
+ can also be a many-to-one or non-scalar relationship.
+
+ :param attr: Attribute on the associated instance or instances we'll
+ proxy for.
+
+ For example, given a target collection of [obj1, obj2], a list created
+ by this proxy property would look like [getattr(obj1, *attr*),
+ getattr(obj2, *attr*)]
+
+ If the relationship is one-to-one or otherwise uselist=False, then
+ simply: getattr(obj, *attr*)
+
+ :param creator: optional.
+
+ When new items are added to this proxied collection, new instances of
+ the class collected by the target collection will be created. For list
+ and set collections, the target class constructor will be called with
+ the 'value' for the new instance. For dict types, two arguments are
+ passed: key and value.
+
+ If you want to construct instances differently, supply a *creator*
+ function that takes arguments as above and returns instances.
+
+ For scalar relationships, creator() will be called if the target is None.
+ If the target is present, set operations are proxied to setattr() on the
+ associated object.
+
+ If you have an associated object with multiple attributes, you may set
+ up multiple association proxies mapping to different attributes. See
+ the unit tests for examples, and for examples of how creator() functions
+ can be used to construct the scalar relationship on-demand in this
+ situation.
+
+ :param \*\*kw: Passes along any other keyword arguments to
+ :class:`.AssociationProxy`.
+
+ """
+ return AssociationProxy(target_collection, attr, **kw)
+
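
A condensed usage sketch of the pattern described above (an editor's illustration, not part of the patch; the ``User``/``Keyword`` mapping is hypothetical, and a ``creator`` is supplied because the default declarative constructor only accepts keyword arguments):

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.associationproxy import association_proxy
    from sqlalchemy.orm import declarative_base, relationship

    Base = declarative_base()


    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        kws = relationship("Keyword")
        # view each related Keyword object through its .keyword string
        keywords = association_proxy(
            "kws", "keyword", creator=lambda kw: Keyword(keyword=kw)
        )


    class Keyword(Base):
        __tablename__ = "keyword"
        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("user.id"))
        keyword = Column(String(64))


    user = User()
    user.keywords.append("python")      # creates Keyword(keyword="python")
    user.keywords.append("sqlalchemy")
    print(list(user.keywords))          # ['python', 'sqlalchemy']
    print(user.kws[0].keyword)          # 'python'
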
+
+ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
+"""Symbol indicating an :class:`.InspectionAttr` that's
+ of type :class:`.AssociationProxy`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+"""
+
+
+class AssociationProxy(interfaces.InspectionAttrInfo):
+ """A descriptor that presents a read/write view of an object attribute."""
+
+ is_attribute = True
+ extension_type = ASSOCIATION_PROXY
+
+ def __init__(
+ self,
+ target_collection,
+ attr,
+ creator=None,
+ getset_factory=None,
+ proxy_factory=None,
+ proxy_bulk_set=None,
+ info=None,
+ cascade_scalar_deletes=False,
+ ):
+ """Construct a new :class:`.AssociationProxy`.
+
+ The :func:`.association_proxy` function is provided as the usual
+ entrypoint here, though :class:`.AssociationProxy` can be instantiated
+ and/or subclassed directly.
+
+ :param target_collection: Name of the collection we'll proxy to,
+ usually created with :func:`_orm.relationship`.
+
+ :param attr: Attribute on the collected instances we'll proxy
+ for. For example, given a target collection of [obj1, obj2], a
+ list created by this proxy property would look like
+ [getattr(obj1, attr), getattr(obj2, attr)]
+
+ :param creator: Optional. When new items are added to this proxied
+ collection, new instances of the class collected by the target
+ collection will be created. For list and set collections, the
+ target class constructor will be called with the 'value' for the
+ new instance. For dict types, two arguments are passed:
+ key and value.
+
+ If you want to construct instances differently, supply a 'creator'
+ function that takes arguments as above and returns instances.
+
+ :param cascade_scalar_deletes: when True, indicates that setting
+ the proxied value to ``None``, or deleting it via ``del``, should
+ also remove the source object. Only applies to scalar attributes.
+ Normally, removing the proxied target will not remove the proxy
+ source, as this object may have other state that is still to be
+ kept.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`cascade_scalar_deletes` - complete usage example
+
+ :param getset_factory: Optional. Proxied attribute access is
+ automatically handled by routines that get and set values based on
+ the `attr` argument for this proxy.
+
+ If you would like to customize this behavior, you may supply a
+ `getset_factory` callable that produces a tuple of `getter` and
+ `setter` functions. The factory is called with two arguments, the
+ abstract type of the underlying collection and this proxy instance.
+
+ :param proxy_factory: Optional. The type of collection to emulate is
+ determined by sniffing the target collection. If your collection
+ type can't be determined by duck typing or you'd like to use a
+ different collection implementation, you may supply a factory
+ function to produce those collections. Only applicable to
+ non-scalar relationships.
+
+ :param proxy_bulk_set: Optional, use with proxy_factory. See
+ the _set() method for details.
+
+ :param info: optional, will be assigned to
+ :attr:`.AssociationProxy.info` if present.
+
+ .. versionadded:: 1.0.9
+
+ """
+ self.target_collection = target_collection
+ self.value_attr = attr
+ self.creator = creator
+ self.getset_factory = getset_factory
+ self.proxy_factory = proxy_factory
+ self.proxy_bulk_set = proxy_bulk_set
+ self.cascade_scalar_deletes = cascade_scalar_deletes
+
+ self.key = "_%s_%s_%s" % (
+ type(self).__name__,
+ target_collection,
+ id(self),
+ )
+ if info:
+ self.info = info
+
+ def __get__(self, obj, class_):
+ if class_ is None:
+ return self
+ inst = self._as_instance(class_, obj)
+ if inst:
+ return inst.get(obj)
+
+ # obj has to be None here
+ # assert obj is None
+
+ return self
+
+ def __set__(self, obj, values):
+ class_ = type(obj)
+ return self._as_instance(class_, obj).set(obj, values)
+
+ def __delete__(self, obj):
+ class_ = type(obj)
+ return self._as_instance(class_, obj).delete(obj)
+
+ def for_class(self, class_, obj=None):
+ r"""Return the internal state local to a specific mapped class.
+
+ E.g., given a class ``User``::
+
+ class User(Base):
+ # ...
+
+ keywords = association_proxy('kws', 'keyword')
+
+ If we access this :class:`.AssociationProxy` from
+ :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the
+ target class for this proxy as mapped by ``User``::
+
+ inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class
+
+ This returns an instance of :class:`.AssociationProxyInstance` that
+ is specific to the ``User`` class. The :class:`.AssociationProxy`
+ object remains agnostic of its parent class.
+
+ :param class\_: the class that we are returning state for.
+
+ :param obj: optional, an instance of the class that is required
+ if the attribute refers to a polymorphic target, e.g. where we have
+ to look at the type of the actual destination object to get the
+ complete path.
+
+ .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores
+ any state specific to a particular parent class; the state is now
+ stored in per-class :class:`.AssociationProxyInstance` objects.
+
+
+ """
+ return self._as_instance(class_, obj)
+
+ def _as_instance(self, class_, obj):
+ try:
+ inst = class_.__dict__[self.key + "_inst"]
+ except KeyError:
+ inst = None
+
+ # avoid exception context
+ if inst is None:
+ owner = self._calc_owner(class_)
+ if owner is not None:
+ inst = AssociationProxyInstance.for_proxy(self, owner, obj)
+ setattr(class_, self.key + "_inst", inst)
+ else:
+ inst = None
+
+ if inst is not None and not inst._is_canonical:
+ # the AssociationProxyInstance can't be generalized
+ # since the proxied attribute is not on the targeted
+ # class, only on subclasses of it, which might be
+ # different. only return for the specific
+ # object's current value
+ return inst._non_canonical_get_for_object(obj)
+ else:
+ return inst
+
+ def _calc_owner(self, target_cls):
+ # we might be getting invoked for a subclass
+ # that is not mapped yet, in some declarative situations.
+ # save until we are mapped
+ try:
+ insp = inspect(target_cls)
+ except exc.NoInspectionAvailable:
+ # can't find a mapper, don't set owner. if we are a not-yet-mapped
+ # subclass, we can also scan through __mro__ to find a mapped
+ # class, but instead just wait for us to be called again against a
+ # mapped class normally.
+ return None
+ else:
+ return insp.mapper.class_manager.class_
+
+ def _default_getset(self, collection_class):
+ attr = self.value_attr
+ _getter = operator.attrgetter(attr)
+
+ def getter(target):
+ return _getter(target) if target is not None else None
+
+ if collection_class is dict:
+
+ def setter(o, k, v):
+ setattr(o, attr, v)
+
+ else:
+
+ def setter(o, v):
+ setattr(o, attr, v)
+
+ return getter, setter
+
+ def __repr__(self):
+ return "AssociationProxy(%r, %r)" % (
+ self.target_collection,
+ self.value_attr,
+ )
+
+
+class AssociationProxyInstance(object):
+ """A per-class object that serves class- and object-specific results.
+
+ This is used by :class:`.AssociationProxy` when it is invoked
+ in terms of a specific class or instance of a class, i.e. when it is
+ used as a regular Python descriptor.
+
+ When referring to the :class:`.AssociationProxy` as a normal Python
+ descriptor, the :class:`.AssociationProxyInstance` is the object that
+ actually serves the information. Under normal circumstances, its presence
+ is transparent::
+
+ >>> User.keywords.scalar
+ False
+
+ In the special case that the :class:`.AssociationProxy` object is being
+ accessed directly, in order to get an explicit handle to the
+ :class:`.AssociationProxyInstance`, use the
+ :meth:`.AssociationProxy.for_class` method::
+
+ proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User)
+
+ # view if proxy object is scalar or not
+ >>> proxy_state.scalar
+ False
+
+ .. versionadded:: 1.3
+
+ """ # noqa
+
+ def __init__(self, parent, owning_class, target_class, value_attr):
+ self.parent = parent
+ self.key = parent.key
+ self.owning_class = owning_class
+ self.target_collection = parent.target_collection
+ self.collection_class = None
+ self.target_class = target_class
+ self.value_attr = value_attr
+
+ target_class = None
+ """The intermediary class handled by this
+ :class:`.AssociationProxyInstance`.
+
+ Intercepted append/set/assignment events will result
+ in the generation of new instances of this class.
+
+ """
+
+ @classmethod
+ def for_proxy(cls, parent, owning_class, parent_instance):
+ target_collection = parent.target_collection
+ value_attr = parent.value_attr
+ prop = orm.class_mapper(owning_class).get_property(target_collection)
+
+ # this was never asserted before but this should be made clear.
+ if not isinstance(prop, orm.RelationshipProperty):
+ util.raise_(
+ NotImplementedError(
+ "association proxy to a non-relationship "
+ "intermediary is not supported"
+ ),
+ replace_context=None,
+ )
+
+ target_class = prop.mapper.class_
+
+ try:
+ target_assoc = cls._cls_unwrap_target_assoc_proxy(
+ target_class, value_attr
+ )
+ except AttributeError:
+ # the proxied attribute doesn't exist on the target class;
+ # return an "ambiguous" instance that will work on a per-object
+ # basis
+ return AmbiguousAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ except Exception as err:
+ util.raise_(
+ exc.InvalidRequestError(
+ "Association proxy received an unexpected error when "
+ "trying to retreive attribute "
+ '"%s.%s" from '
+ 'class "%s": %s'
+ % (
+ target_class.__name__,
+ parent.value_attr,
+ target_class.__name__,
+ err,
+ )
+ ),
+ from_=err,
+ )
+ else:
+ return cls._construct_for_assoc(
+ target_assoc, parent, owning_class, target_class, value_attr
+ )
+
+ @classmethod
+ def _construct_for_assoc(
+ cls, target_assoc, parent, owning_class, target_class, value_attr
+ ):
+ if target_assoc is not None:
+ return ObjectAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+
+ attr = getattr(target_class, value_attr)
+ if not hasattr(attr, "_is_internal_proxy"):
+ return AmbiguousAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ is_object = attr._impl_uses_objects
+ if is_object:
+ return ObjectAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ else:
+ return ColumnAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+
+ def _get_property(self):
+ return orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
+
+ @property
+ def _comparator(self):
+ return self._get_property().comparator
+
+ def __clause_element__(self):
+ raise NotImplementedError(
+ "The association proxy can't be used as a plain column "
+ "expression; it only works inside of a comparison expression"
+ )
+
+ @classmethod
+ def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr):
+ attr = getattr(target_class, value_attr)
+ if isinstance(attr, (AssociationProxy, AssociationProxyInstance)):
+ return attr
+ return None
+
+ @util.memoized_property
+ def _unwrap_target_assoc_proxy(self):
+ return self._cls_unwrap_target_assoc_proxy(
+ self.target_class, self.value_attr
+ )
+
+ @property
+ def remote_attr(self):
+ """The 'remote' class attribute referenced by this
+ :class:`.AssociationProxyInstance`.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.attr`
+
+ :attr:`.AssociationProxyInstance.local_attr`
+
+ """
+ return getattr(self.target_class, self.value_attr)
+
+ @property
+ def local_attr(self):
+ """The 'local' class attribute referenced by this
+ :class:`.AssociationProxyInstance`.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.attr`
+
+ :attr:`.AssociationProxyInstance.remote_attr`
+
+ """
+ return getattr(self.owning_class, self.target_collection)
+
+ @property
+ def attr(self):
+ """Return a tuple of ``(local_attr, remote_attr)``.
+
+ This attribute was originally intended to facilitate using the
+ :meth:`_query.Query.join` method to join across the two relationships
+ at once, however this makes use of a deprecated calling style.
+
+ To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with
+ an association proxy, the current method is to make use of the
+ :attr:`.AssociationProxyInstance.local_attr` and
+ :attr:`.AssociationProxyInstance.remote_attr` attributes separately::
+
+ stmt = (
+ select(Parent).
+ join(Parent.proxied.local_attr).
+ join(Parent.proxied.remote_attr)
+ )
+
+ A future release may seek to provide a more succinct join pattern
+ for association proxy attributes.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.local_attr`
+
+ :attr:`.AssociationProxyInstance.remote_attr`
+
+ """
+ return (self.local_attr, self.remote_attr)
+
+ @util.memoized_property
+ def scalar(self):
+ """Return ``True`` if this :class:`.AssociationProxyInstance`
+ proxies a scalar relationship on the local side."""
+
+ scalar = not self._get_property().uselist
+ if scalar:
+ self._initialize_scalar_accessors()
+ return scalar
+
+ @util.memoized_property
+ def _value_is_scalar(self):
+ return (
+ not self._get_property()
+ .mapper.get_property(self.value_attr)
+ .uselist
+ )
+
+ @property
+ def _target_is_object(self):
+ raise NotImplementedError()
+
+ def _initialize_scalar_accessors(self):
+ if self.parent.getset_factory:
+ get, set_ = self.parent.getset_factory(None, self)
+ else:
+ get, set_ = self.parent._default_getset(None)
+ self._scalar_get, self._scalar_set = get, set_
+
+ def _default_getset(self, collection_class):
+ attr = self.value_attr
+ _getter = operator.attrgetter(attr)
+
+ def getter(target):
+ return _getter(target) if target is not None else None
+
+ if collection_class is dict:
+
+ def setter(o, k, v):
+ return setattr(o, attr, v)
+
+ else:
+
+ def setter(o, v):
+ return setattr(o, attr, v)
+
+ return getter, setter
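+
+ # Illustrative sketch (editor's note, not part of the library source): for
+ # value_attr == "keyword", the getter/setter pair built above behaves like:
+ #
+ #     getter(obj)         -> obj.keyword, or None if obj is None
+ #     setter(obj, value)  -> setattr(obj, "keyword", value)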
+
+ @property
+ def info(self):
+ return self.parent.info
+
+ def get(self, obj):
+ if obj is None:
+ return self
+
+ if self.scalar:
+ target = getattr(obj, self.target_collection)
+ return self._scalar_get(target)
+ else:
+ try:
+ # If the owning instance is reborn (orm session resurrect,
+ # etc.), refresh the proxy cache.
+ creator_id, self_id, proxy = getattr(obj, self.key)
+ except AttributeError:
+ pass
+ else:
+ if id(obj) == creator_id and id(self) == self_id:
+ assert self.collection_class is not None
+ return proxy
+
+ self.collection_class, proxy = self._new(
+ _lazy_collection(obj, self.target_collection)
+ )
+ setattr(obj, self.key, (id(obj), id(self), proxy))
+ return proxy
+
+ def set(self, obj, values):
+ if self.scalar:
+ creator = (
+ self.parent.creator
+ if self.parent.creator
+ else self.target_class
+ )
+ target = getattr(obj, self.target_collection)
+ if target is None:
+ if values is None:
+ return
+ setattr(obj, self.target_collection, creator(values))
+ else:
+ self._scalar_set(target, values)
+ if values is None and self.parent.cascade_scalar_deletes:
+ setattr(obj, self.target_collection, None)
+ else:
+ proxy = self.get(obj)
+ assert self.collection_class is not None
+ if proxy is not values:
+ proxy._bulk_replace(self, values)
+
+ def delete(self, obj):
+ if self.owning_class is None:
+ self._calc_owner(obj, None)
+
+ if self.scalar:
+ target = getattr(obj, self.target_collection)
+ if target is not None:
+ delattr(target, self.value_attr)
+ delattr(obj, self.target_collection)
+
+ def _new(self, lazy_collection):
+ creator = (
+ self.parent.creator if self.parent.creator else self.target_class
+ )
+ collection_class = util.duck_type_collection(lazy_collection())
+
+ if self.parent.proxy_factory:
+ return (
+ collection_class,
+ self.parent.proxy_factory(
+ lazy_collection, creator, self.value_attr, self
+ ),
+ )
+
+ if self.parent.getset_factory:
+ getter, setter = self.parent.getset_factory(collection_class, self)
+ else:
+ getter, setter = self.parent._default_getset(collection_class)
+
+ if collection_class is list:
+ return (
+ collection_class,
+ _AssociationList(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ elif collection_class is dict:
+ return (
+ collection_class,
+ _AssociationDict(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ elif collection_class is set:
+ return (
+ collection_class,
+ _AssociationSet(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ else:
+ raise exc.ArgumentError(
+ "could not guess which interface to use for "
+ 'collection_class "%s" backing "%s"; specify a '
+ "proxy_factory and proxy_bulk_set manually"
+ % (self.collection_class.__name__, self.target_collection)
+ )
+
+ def _set(self, proxy, values):
+ if self.parent.proxy_bulk_set:
+ self.parent.proxy_bulk_set(proxy, values)
+ elif self.collection_class is list:
+ proxy.extend(values)
+ elif self.collection_class is dict:
+ proxy.update(values)
+ elif self.collection_class is set:
+ proxy.update(values)
+ else:
+ raise exc.ArgumentError(
+ "no proxy_bulk_set supplied for custom "
+ "collection_class implementation"
+ )
+
+ def _inflate(self, proxy):
+ creator = (
+ self.parent.creator if self.parent.creator else self.target_class
+ )
+
+ if self.parent.getset_factory:
+ getter, setter = self.parent.getset_factory(
+ self.collection_class, self
+ )
+ else:
+ getter, setter = self.parent._default_getset(self.collection_class)
+
+ proxy.creator = creator
+ proxy.getter = getter
+ proxy.setter = setter
+
+ def _criterion_exists(self, criterion=None, **kwargs):
+ is_has = kwargs.pop("is_has", None)
+
+ target_assoc = self._unwrap_target_assoc_proxy
+ if target_assoc is not None:
+ inner = target_assoc._criterion_exists(
+ criterion=criterion, **kwargs
+ )
+ return self._comparator._criterion_exists(inner)
+
+ if self._target_is_object:
+ prop = getattr(self.target_class, self.value_attr)
+ value_expr = prop._criterion_exists(criterion, **kwargs)
+ else:
+ if kwargs:
+ raise exc.ArgumentError(
+ "Can't apply keyword arguments to column-targeted "
+ "association proxy; use =="
+ )
+ elif is_has and criterion is not None:
+ raise exc.ArgumentError(
+ "Non-empty has() not allowed for "
+ "column-targeted association proxy; use =="
+ )
+
+ value_expr = criterion
+
+ return self._comparator._criterion_exists(value_expr)
+
+ def any(self, criterion=None, **kwargs):
+ """Produce a proxied 'any' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`
+ and/or :meth:`.RelationshipProperty.Comparator.has`
+ operators of the underlying proxied attributes.
+
+ """
+ if self._unwrap_target_assoc_proxy is None and (
+ self.scalar
+ and (not self._target_is_object or self._value_is_scalar)
+ ):
+ raise exc.InvalidRequestError(
+ "'any()' not implemented for scalar " "attributes. Use has()."
+ )
+ return self._criterion_exists(
+ criterion=criterion, is_has=False, **kwargs
+ )
+
+ def has(self, criterion=None, **kwargs):
+ """Produce a proxied 'has' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`
+ and/or :meth:`.RelationshipProperty.Comparator.has`
+ operators of the underlying proxied attributes.
+
+ """
+ if self._unwrap_target_assoc_proxy is None and (
+ not self.scalar
+ or (self._target_is_object and not self._value_is_scalar)
+ ):
+ raise exc.InvalidRequestError(
+ "'has()' not implemented for collections. " "Use any()."
+ )
+ return self._criterion_exists(
+ criterion=criterion, is_has=True, **kwargs
+ )
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.parent)
+
+
+class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
+ """an :class:`.AssociationProxyInstance` where we cannot determine
+ the type of target object.
+ """
+
+ _is_canonical = False
+
+ def _ambiguous(self):
+ raise AttributeError(
+ "Association proxy %s.%s refers to an attribute '%s' that is not "
+ "directly mapped on class %s; therefore this operation cannot "
+ "proceed since we don't know what type of object is referred "
+ "towards"
+ % (
+ self.owning_class.__name__,
+ self.target_collection,
+ self.value_attr,
+ self.target_class,
+ )
+ )
+
+ def get(self, obj):
+ if obj is None:
+ return self
+ else:
+ return super(AmbiguousAssociationProxyInstance, self).get(obj)
+
+ def __eq__(self, obj):
+ self._ambiguous()
+
+ def __ne__(self, obj):
+ self._ambiguous()
+
+ def any(self, criterion=None, **kwargs):
+ self._ambiguous()
+
+ def has(self, criterion=None, **kwargs):
+ self._ambiguous()
+
+ @util.memoized_property
+ def _lookup_cache(self):
+ # mapping of <subclass>->AssociationProxyInstance.
+ # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist;
+ # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2
+ return {}
+
+ def _non_canonical_get_for_object(self, parent_instance):
+ if parent_instance is not None:
+ actual_obj = getattr(parent_instance, self.target_collection)
+ if actual_obj is not None:
+ try:
+ insp = inspect(actual_obj)
+ except exc.NoInspectionAvailable:
+ pass
+ else:
+ mapper = insp.mapper
+ instance_class = mapper.class_
+ if instance_class not in self._lookup_cache:
+ self._populate_cache(instance_class, mapper)
+
+ try:
+ return self._lookup_cache[instance_class]
+ except KeyError:
+ pass
+
+ # no object or ambiguous object given, so return "self", which
+ # is a proxy with generally only instance-level functionality
+ return self
+
+ def _populate_cache(self, instance_class, mapper):
+ prop = orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
+
+ if mapper.isa(prop.mapper):
+ target_class = instance_class
+ try:
+ target_assoc = self._cls_unwrap_target_assoc_proxy(
+ target_class, self.value_attr
+ )
+ except AttributeError:
+ pass
+ else:
+ self._lookup_cache[instance_class] = self._construct_for_assoc(
+ target_assoc,
+ self.parent,
+ self.owning_class,
+ target_class,
+ self.value_attr,
+ )
+
+
+class ObjectAssociationProxyInstance(AssociationProxyInstance):
+ """an :class:`.AssociationProxyInstance` that has an object as a target."""
+
+ _target_is_object = True
+ _is_canonical = True
+
+ def contains(self, obj):
+ """Produce a proxied 'contains' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`,
+ :meth:`.RelationshipProperty.Comparator.has`,
+ and/or :meth:`.RelationshipProperty.Comparator.contains`
+ operators of the underlying proxied attributes.
+ """
+
+ target_assoc = self._unwrap_target_assoc_proxy
+ if target_assoc is not None:
+ return self._comparator._criterion_exists(
+ target_assoc.contains(obj)
+ if not target_assoc.scalar
+ else target_assoc == obj
+ )
+ elif (
+ self._target_is_object
+ and self.scalar
+ and not self._value_is_scalar
+ ):
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr).contains(obj)
+ )
+ elif self._target_is_object and self.scalar and self._value_is_scalar:
+ raise exc.InvalidRequestError(
+ "contains() doesn't apply to a scalar object endpoint; use =="
+ )
+ else:
+
+ return self._comparator._criterion_exists(**{self.value_attr: obj})
+
+ def __eq__(self, obj):
+ # note the has() here will fail for collections; eq_()
+ # is only allowed with a scalar.
+ if obj is None:
+ return or_(
+ self._comparator.has(**{self.value_attr: obj}),
+ self._comparator == None,
+ )
+ else:
+ return self._comparator.has(**{self.value_attr: obj})
+
+ def __ne__(self, obj):
+ # note the has() here will fail for collections; eq_()
+ # is only allowed with a scalar.
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr) != obj
+ )
+
+
+class ColumnAssociationProxyInstance(
+ ColumnOperators, AssociationProxyInstance
+):
+ """an :class:`.AssociationProxyInstance` that has a database column as a
+ target.
+ """
+
+ _target_is_object = False
+ _is_canonical = True
+
+ def __eq__(self, other):
+ # special case "is None" to check for no related row as well
+ expr = self._criterion_exists(
+ self.remote_attr.operate(operator.eq, other)
+ )
+ if other is None:
+ return or_(expr, self._comparator == None)
+ else:
+ return expr
+
+ def operate(self, op, *other, **kwargs):
+ return self._criterion_exists(
+ self.remote_attr.operate(op, *other, **kwargs)
+ )
+
+
+class _lazy_collection(object):
+ def __init__(self, obj, target):
+ self.parent = obj
+ self.target = target
+
+ def __call__(self):
+ return getattr(self.parent, self.target)
+
+ def __getstate__(self):
+ return {"obj": self.parent, "target": self.target}
+
+ def __setstate__(self, state):
+ self.parent = state["obj"]
+ self.target = state["target"]
+
+
+class _AssociationCollection(object):
+ def __init__(self, lazy_collection, creator, getter, setter, parent):
+ """Constructs an _AssociationCollection.
+
+ This will always be a subclass of either _AssociationList,
+ _AssociationSet, or _AssociationDict.
+
+ lazy_collection
+ A callable returning a list-based collection of entities (usually an
+ object attribute managed by a SQLAlchemy relationship())
+
+ creator
+ A function that creates new target entities. Given one parameter:
+ value. This assertion is assumed::
+
+ obj = creator(somevalue)
+ assert getter(obj) == somevalue
+
+ getter
+ A function. Given an associated object, return the 'value'.
+
+ setter
+ A function. Given an associated object and a value, store that
+ value on the object.
+
+ """
+ self.lazy_collection = lazy_collection
+ self.creator = creator
+ self.getter = getter
+ self.setter = setter
+ self.parent = parent
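+
+ # Illustrative sketch (editor's note, not part of the library source): a
+ # creator/getter/setter trio satisfying the contract documented above,
+ # assuming a hypothetical Keyword class with a "keyword" attribute:
+ #
+ #     creator = lambda value: Keyword(keyword=value)
+ #     getter = operator.attrgetter("keyword")
+ #     setter = lambda obj, value: setattr(obj, "keyword", value)
+ #
+ #     obj = creator("sqlalchemy")
+ #     assert getter(obj) == "sqlalchemy"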
+
+ col = property(lambda self: self.lazy_collection())
+
+ def __len__(self):
+ return len(self.col)
+
+ def __bool__(self):
+ return bool(self.col)
+
+ __nonzero__ = __bool__
+
+ def __getstate__(self):
+ return {"parent": self.parent, "lazy_collection": self.lazy_collection}
+
+ def __setstate__(self, state):
+ self.parent = state["parent"]
+ self.lazy_collection = state["lazy_collection"]
+ self.parent._inflate(self)
+
+ def _bulk_replace(self, assoc_proxy, values):
+ self.clear()
+ assoc_proxy._set(self, values)
+
+
+class _AssociationList(_AssociationCollection):
+ """Generic, converting, list-to-list proxy."""
+
+ def _create(self, value):
+ return self.creator(value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def _set(self, object_, value):
+ return self.setter(object_, value)
+
+ def __getitem__(self, index):
+ if not isinstance(index, slice):
+ return self._get(self.col[index])
+ else:
+ return [self._get(member) for member in self.col[index]]
+
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ self._set(self.col[index], value)
+ else:
+ if index.stop is None:
+ stop = len(self)
+ elif index.stop < 0:
+ stop = len(self) + index.stop
+ else:
+ stop = index.stop
+ step = index.step or 1
+
+ start = index.start or 0
+ rng = list(range(start, stop, step))
+ if step == 1:
+ for i in rng:
+ del self[start]
+ i = start
+ for item in value:
+ self.insert(i, item)
+ i += 1
+ else:
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value), len(rng))
+ )
+ for i, item in zip(rng, value):
+ self._set(self.col[i], item)
+
+ def __delitem__(self, index):
+ del self.col[index]
+
+ def __contains__(self, value):
+ for member in self.col:
+ # testlib.pragma exempt:__eq__
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __getslice__(self, start, end):
+ return [self._get(member) for member in self.col[start:end]]
+
+ def __setslice__(self, start, end, values):
+ members = [self._create(v) for v in values]
+ self.col[start:end] = members
+
+ def __delslice__(self, start, end):
+ del self.col[start:end]
+
+ def __iter__(self):
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or
+ just use the underlying collection directly from its property
+ on the parent.
+ """
+
+ for member in self.col:
+ yield self._get(member)
+ return
+
+ def append(self, value):
+ col = self.col
+ item = self._create(value)
+ col.append(item)
+
+ def count(self, value):
+ return sum(
+ [
+ 1
+ for _ in util.itertools_filter(
+ lambda v: v == value, iter(self)
+ )
+ ]
+ )
+
+ def extend(self, values):
+ for v in values:
+ self.append(v)
+
+ def insert(self, index, value):
+ self.col[index:index] = [self._create(value)]
+
+ def pop(self, index=-1):
+ return self.getter(self.col.pop(index))
+
+ def remove(self, value):
+ for i, val in enumerate(self):
+ if val == value:
+ del self.col[i]
+ return
+ raise ValueError("value not in list")
+
+ def reverse(self):
+ """Not supported, use reversed(mylist)"""
+
+ raise NotImplementedError
+
+ def sort(self):
+ """Not supported, use sorted(mylist)"""
+
+ raise NotImplementedError
+
+ def clear(self):
+ del self.col[0 : len(self.col)]
+
+ def __eq__(self, other):
+ return list(self) == other
+
+ def __ne__(self, other):
+ return list(self) != other
+
+ def __lt__(self, other):
+ return list(self) < other
+
+ def __le__(self, other):
+ return list(self) <= other
+
+ def __gt__(self, other):
+ return list(self) > other
+
+ def __ge__(self, other):
+ return list(self) >= other
+
+ def __cmp__(self, other):
+ return util.cmp(list(self), other)
+
+ def __add__(self, iterable):
+ try:
+ other = list(iterable)
+ except TypeError:
+ return NotImplemented
+ return list(self) + other
+
+ def __radd__(self, iterable):
+ try:
+ other = list(iterable)
+ except TypeError:
+ return NotImplemented
+ return other + list(self)
+
+ def __mul__(self, n):
+ if not isinstance(n, int):
+ return NotImplemented
+ return list(self) * n
+
+ __rmul__ = __mul__
+
+ def __iadd__(self, iterable):
+ self.extend(iterable)
+ return self
+
+ def __imul__(self, n):
+ # unlike a regular list *=, proxied __imul__ will generate unique
+ # backing objects for each copy. *= on proxied lists is a bit of
+ # a stretch anyhow, and this interpretation of the __imul__ contract
+ # is more plausibly useful than copying the backing objects.
+ if not isinstance(n, int):
+ return NotImplemented
+ if n == 0:
+ self.clear()
+ elif n > 1:
+ self.extend(list(self) * (n - 1))
+ return self
+
+ def index(self, item, *args):
+ return list(self).index(item, *args)
+
+ def copy(self):
+ return list(self)
+
+ def __repr__(self):
+ return repr(list(self))
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
+ func.__doc__ = getattr(list, func_name).__doc__
+ del func_name, func
+
+
+_NotProvided = util.symbol("_NotProvided")
+
+
+class _AssociationDict(_AssociationCollection):
+ """Generic, converting, dict-to-dict proxy."""
+
+ def _create(self, key, value):
+ return self.creator(key, value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def _set(self, object_, key, value):
+ return self.setter(object_, key, value)
+
+ def __getitem__(self, key):
+ return self._get(self.col[key])
+
+ def __setitem__(self, key, value):
+ if key in self.col:
+ self._set(self.col[key], key, value)
+ else:
+ self.col[key] = self._create(key, value)
+
+ def __delitem__(self, key):
+ del self.col[key]
+
+ def __contains__(self, key):
+ # testlib.pragma exempt:__hash__
+ return key in self.col
+
+ def has_key(self, key):
+ # testlib.pragma exempt:__hash__
+ return key in self.col
+
+ def __iter__(self):
+ return iter(self.col.keys())
+
+ def clear(self):
+ self.col.clear()
+
+ def __eq__(self, other):
+ return dict(self) == other
+
+ def __ne__(self, other):
+ return dict(self) != other
+
+ def __lt__(self, other):
+ return dict(self) < other
+
+ def __le__(self, other):
+ return dict(self) <= other
+
+ def __gt__(self, other):
+ return dict(self) > other
+
+ def __ge__(self, other):
+ return dict(self) >= other
+
+ def __cmp__(self, other):
+ return util.cmp(dict(self), other)
+
+ def __repr__(self):
+ return repr(dict(self.items()))
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def setdefault(self, key, default=None):
+ if key not in self.col:
+ self.col[key] = self._create(key, default)
+ return default
+ else:
+ return self[key]
+
+ def keys(self):
+ return self.col.keys()
+
+ if util.py2k:
+
+ def iteritems(self):
+ return ((key, self._get(self.col[key])) for key in self.col)
+
+ def itervalues(self):
+ return (self._get(self.col[key]) for key in self.col)
+
+ def iterkeys(self):
+ return self.col.iterkeys()
+
+ def values(self):
+ return [self._get(member) for member in self.col.values()]
+
+ def items(self):
+ return [(k, self._get(self.col[k])) for k in self]
+
+ else:
+
+ def items(self):
+ return ((key, self._get(self.col[key])) for key in self.col)
+
+ def values(self):
+ return (self._get(self.col[key]) for key in self.col)
+
+ def pop(self, key, default=_NotProvided):
+ if default is _NotProvided:
+ member = self.col.pop(key)
+ else:
+ member = self.col.pop(key, default)
+ return self._get(member)
+
+ def popitem(self):
+ item = self.col.popitem()
+ return (item[0], self._get(item[1]))
+
+ def update(self, *a, **kw):
+ if len(a) > 1:
+ raise TypeError(
+ "update expected at most 1 arguments, got %i" % len(a)
+ )
+ elif len(a) == 1:
+ seq_or_map = a[0]
+ # discern dict from sequence - took the advice from
+ # https://www.voidspace.org.uk/python/articles/duck_typing.shtml
+ # still not perfect :(
+ if hasattr(seq_or_map, "keys"):
+ for item in seq_or_map:
+ self[item] = seq_or_map[item]
+ else:
+ try:
+ for k, v in seq_or_map:
+ self[k] = v
+ except ValueError as err:
+ util.raise_(
+ ValueError(
+ "dictionary update sequence "
+ "requires 2-element tuples"
+ ),
+ replace_context=err,
+ )
+
+ for key, value in kw.items():
+ self[key] = value
+
+ def _bulk_replace(self, assoc_proxy, values):
+ existing = set(self)
+ constants = existing.intersection(values or ())
+ additions = set(values or ()).difference(constants)
+ removals = existing.difference(constants)
+
+ for key, member in values.items() or ():
+ if key in additions:
+ self[key] = member
+ elif key in constants:
+ self[key] = member
+
+ for key in removals:
+ del self[key]
+
+ def copy(self):
+ return dict(self.items())
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(dict, func_name)
+ ):
+ func.__doc__ = getattr(dict, func_name).__doc__
+ del func_name, func
+
+
+class _AssociationSet(_AssociationCollection):
+ """Generic, converting, set-to-set proxy."""
+
+ def _create(self, value):
+ return self.creator(value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def __len__(self):
+ return len(self.col)
+
+ def __bool__(self):
+ if self.col:
+ return True
+ else:
+ return False
+
+ __nonzero__ = __bool__
+
+ def __contains__(self, value):
+ for member in self.col:
+ # testlib.pragma exempt:__eq__
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __iter__(self):
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or just use
+ the underlying collection directly from its property on the parent.
+
+ """
+ for member in self.col:
+ yield self._get(member)
+ return
+
+ def add(self, value):
+ if value not in self:
+ self.col.add(self._create(value))
+
+ # for discard and remove, choose a more expensive check strategy rather
+ # than calling self.creator()
+ def discard(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ break
+
+ def remove(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ return
+ raise KeyError(value)
+
+ def pop(self):
+ if not self.col:
+ raise KeyError("pop from an empty set")
+ member = self.col.pop()
+ return self._get(member)
+
+ def update(self, other):
+ for value in other:
+ self.add(value)
+
+ def _bulk_replace(self, assoc_proxy, values):
+ existing = set(self)
+ constants = existing.intersection(values or ())
+ additions = set(values or ()).difference(constants)
+ removals = existing.difference(constants)
+
+ appender = self.add
+ remover = self.remove
+
+ for member in values or ():
+ if member in additions:
+ appender(member)
+ elif member in constants:
+ appender(member)
+
+ for member in removals:
+ remover(member)
+
+ def __ior__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ for value in other:
+ self.add(value)
+ return self
+
+ def _set(self):
+ return set(iter(self))
+
+ def union(self, other):
+ return set(self).union(other)
+
+ __or__ = union
+
+ def difference(self, other):
+ return set(self).difference(other)
+
+ __sub__ = difference
+
+ def difference_update(self, other):
+ for value in other:
+ self.discard(value)
+
+ def __isub__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ for value in other:
+ self.discard(value)
+ return self
+
+ def intersection(self, other):
+ return set(self).intersection(other)
+
+ __and__ = intersection
+
+ def intersection_update(self, other):
+ want, have = self.intersection(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ def __iand__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.intersection(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
+
+ def symmetric_difference(self, other):
+ return set(self).symmetric_difference(other)
+
+ __xor__ = symmetric_difference
+
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ def __ixor__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.symmetric_difference(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
+
+ def issubset(self, other):
+ return set(self).issubset(other)
+
+ def issuperset(self, other):
+ return set(self).issuperset(other)
+
+ def clear(self):
+ self.col.clear()
+
+ def copy(self):
+ return set(self)
+
+ def __eq__(self, other):
+ return set(self) == other
+
+ def __ne__(self, other):
+ return set(self) != other
+
+ def __lt__(self, other):
+ return set(self) < other
+
+ def __le__(self, other):
+ return set(self) <= other
+
+ def __gt__(self, other):
+ return set(self) > other
+
+ def __ge__(self, other):
+ return set(self) >= other
+
+ def __repr__(self):
+ return repr(set(self))
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(set, func_name)
+ ):
+ func.__doc__ = getattr(set, func_name).__doc__
+ del func_name, func
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
new file mode 100644
index 0000000..15b2cb0
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/__init__.py
@@ -0,0 +1,22 @@
+# ext/asyncio/__init__.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .engine import async_engine_from_config
+from .engine import AsyncConnection
+from .engine import AsyncEngine
+from .engine import AsyncTransaction
+from .engine import create_async_engine
+from .events import AsyncConnectionEvents
+from .events import AsyncSessionEvents
+from .result import AsyncMappingResult
+from .result import AsyncResult
+from .result import AsyncScalarResult
+from .scoping import async_scoped_session
+from .session import async_object_session
+from .session import async_session
+from .session import AsyncSession
+from .session import AsyncSessionTransaction
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
new file mode 100644
index 0000000..3f77f55
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -0,0 +1,89 @@
+import abc
+import functools
+import weakref
+
+from . import exc as async_exc
+
+
+class ReversibleProxy:
+ # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
+ _proxy_objects = {}
+ __slots__ = ("__weakref__",)
+
+ def _assign_proxied(self, target):
+ if target is not None:
+ target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ proxy_ref = weakref.ref(
+ self,
+ functools.partial(ReversibleProxy._target_gced, target_ref),
+ )
+ ReversibleProxy._proxy_objects[target_ref] = proxy_ref
+
+ return target
+
+ @classmethod
+ def _target_gced(cls, ref, proxy_ref=None):
+ cls._proxy_objects.pop(ref, None)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ raise NotImplementedError()
+
+ @classmethod
+ def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ try:
+ proxy_ref = cls._proxy_objects[weakref.ref(target)]
+ except KeyError:
+ pass
+ else:
+ proxy = proxy_ref()
+ if proxy is not None:
+ return proxy
+
+ if regenerate:
+ return cls._regenerate_proxy_for_target(target)
+ else:
+ return None
+
+
+class StartableContext(abc.ABC):
+ __slots__ = ()
+
+ @abc.abstractmethod
+ async def start(self, is_ctxmanager=False):
+ pass
+
+ def __await__(self):
+ return self.start().__await__()
+
+ async def __aenter__(self):
+ return await self.start(is_ctxmanager=True)
+
+ @abc.abstractmethod
+ async def __aexit__(self, type_, value, traceback):
+ pass
+
+ def _raise_for_not_started(self):
+ raise async_exc.AsyncContextNotStarted(
+ "%s context has not been started and object has not been awaited."
+ % (self.__class__.__name__)
+ )
+
+
+class ProxyComparable(ReversibleProxy):
+ __slots__ = ()
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, self.__class__)
+ and self._proxied == other._proxied
+ )
+
+ def __ne__(self, other):
+ return (
+ not isinstance(other, self.__class__)
+ or self._proxied != other._proxied
+ )
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
new file mode 100644
index 0000000..4fbe4f7
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -0,0 +1,828 @@
+# ext/asyncio/engine.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+import asyncio
+
+from . import exc as async_exc
+from .base import ProxyComparable
+from .base import StartableContext
+from .result import _ensure_sync_result
+from .result import AsyncResult
+from ... import exc
+from ... import inspection
+from ... import util
+from ...engine import create_engine as _create_engine
+from ...engine.base import NestedTransaction
+from ...future import Connection
+from ...future import Engine
+from ...util.concurrency import greenlet_spawn
+
+
+def create_async_engine(*arg, **kw):
+ """Create a new async engine instance.
+
+ Arguments passed to :func:`_asyncio.create_async_engine` are mostly
+ identical to those passed to the :func:`_sa.create_engine` function.
+ The specified dialect must be an asyncio-compatible dialect
+ such as :ref:`dialect-postgresql-asyncpg`.
+
+ .. versionadded:: 1.4
+
+ """
+
+ if kw.get("server_side_cursors", False):
+ raise async_exc.AsyncMethodRequired(
+ "Can't set server_side_cursors for async engine globally; "
+ "use the connection.stream() method for an async "
+ "streaming result set"
+ )
+ kw["future"] = True
+ sync_engine = _create_engine(*arg, **kw)
+ return AsyncEngine(sync_engine)
+
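+# Illustrative usage (editor's sketch; the asyncpg URL below is an assumption):
+#
+#     from sqlalchemy.ext.asyncio import create_async_engine
+#
+#     engine = create_async_engine(
+#         "postgresql+asyncpg://user:pass@host/dbname", echo=True
+#     )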
+
+def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+ """Create a new AsyncEngine instance using a configuration dictionary.
+
+ This function is analogous to the :func:`_sa.engine_from_config` function
+ in SQLAlchemy Core, except that the requested dialect must be an
+ asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`.
+ The argument signature of the function is identical to that
+ of :func:`_sa.engine_from_config`.
+
+ .. versionadded:: 1.4.29
+
+ """
+ options = {
+ key[len(prefix) :]: value
+ for key, value in configuration.items()
+ if key.startswith(prefix)
+ }
+ options["_coerce_config"] = True
+ options.update(kwargs)
+ url = options.pop("url")
+ return create_async_engine(url, **options)
+
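+# Illustrative usage (editor's sketch; the configuration keys and URL are
+# assumptions):
+#
+#     config = {
+#         "sqlalchemy.url": "postgresql+asyncpg://user:pass@host/dbname",
+#         "sqlalchemy.echo": "true",
+#     }
+#     engine = async_engine_from_config(config)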
+
+class AsyncConnectable:
+ __slots__ = "_slots_dispatch", "__weakref__"
+
+
+@util.create_proxy_methods(
+ Connection,
+ ":class:`_future.Connection`",
+ ":class:`_asyncio.AsyncConnection`",
+ classmethods=[],
+ methods=[],
+ attributes=[
+ "closed",
+ "invalidated",
+ "dialect",
+ "default_isolation_level",
+ ],
+)
+class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
+ """An asyncio proxy for a :class:`_engine.Connection`.
+
+ :class:`_asyncio.AsyncConnection` is acquired using the
+ :meth:`_asyncio.AsyncEngine.connect`
+ method of :class:`_asyncio.AsyncEngine`::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ async with engine.connect() as conn:
+ result = await conn.execute(select(table))
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ # AsyncConnection is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncConnection that matches this one given only the
+ # "sync" elements.
+ __slots__ = (
+ "engine",
+ "sync_engine",
+ "sync_connection",
+ )
+
+ def __init__(self, async_engine, sync_connection=None):
+ self.engine = async_engine
+ self.sync_engine = async_engine.sync_engine
+ self.sync_connection = self._assign_proxied(sync_connection)
+
+ sync_connection: Connection
+ """Reference to the sync-style :class:`_engine.Connection` this
+ :class:`_asyncio.AsyncConnection` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncConnection` is associated with via its underlying
+ :class:`_engine.Connection`.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncConnection(
+ AsyncEngine._retrieve_proxy_for_target(target.engine), target
+ )
+
+ async def start(self, is_ctxmanager=False):
+ """Start this :class:`_asyncio.AsyncConnection` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+ if self.sync_connection:
+ raise exc.InvalidRequestError("connection is already started")
+ self.sync_connection = self._assign_proxied(
+ await (greenlet_spawn(self.sync_engine.connect))
+ )
+ return self
+
+ @property
+ def connection(self):
+ """Not implemented for async; call
+ :meth:`_asyncio.AsyncConnection.get_raw_connection`.
+ """
+ raise exc.InvalidRequestError(
+ "AsyncConnection.connection accessor is not implemented as the "
+ "attribute may need to reconnect on an invalidated connection. "
+ "Use the get_raw_connection() method."
+ )
+
+ async def get_raw_connection(self):
+ """Return the pooled DBAPI-level connection in use by this
+ :class:`_asyncio.AsyncConnection`.
+
+ This is a SQLAlchemy connection-pool proxied connection
+ which then has the attribute
+ :attr:`_pool._ConnectionFairy.driver_connection` that refers to the
+ actual driver connection. Its
+ :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead
+ to an :class:`_engine.AdaptedConnection` instance that
+ adapts the driver connection to the DBAPI protocol.
+
+ """
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(getattr, conn, "connection")
+
+ @property
+ def _proxied(self):
+ return self.sync_connection
+
+ @property
+ def info(self):
+ """Return the :attr:`_engine.Connection.info` dictionary of the
+ underlying :class:`_engine.Connection`.
+
+ This dictionary is freely writable for user-defined state to be
+ associated with the database connection.
+
+ This attribute is only available if the :class:`.AsyncConnection` is
+ currently connected. If the :attr:`.AsyncConnection.closed` attribute
+ is ``True``, then accessing this attribute will raise
+ :class:`.ResourceClosedError`.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ return self.sync_connection.info
+
+ def _sync_connection(self):
+ if not self.sync_connection:
+ self._raise_for_not_started()
+ return self.sync_connection
+
+ def begin(self):
+ """Begin a transaction prior to autobegin occurring."""
+ self._sync_connection()
+ return AsyncTransaction(self)
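+
+ # Illustrative usage (editor's sketch; "engine" and the table are assumed):
+ #
+ #     async with engine.connect() as conn:
+ #         async with conn.begin():
+ #             await conn.execute(text("insert into t (x) values (1)"))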
+
+ def begin_nested(self):
+ """Begin a nested transaction and return a transaction handle."""
+ self._sync_connection()
+ return AsyncTransaction(self, nested=True)
+
+ async def invalidate(self, exception=None):
+ """Invalidate the underlying DBAPI connection associated with
+ this :class:`_engine.Connection`.
+
+ See the method :meth:`_engine.Connection.invalidate` for full
+ detail on this method.
+
+ """
+
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.invalidate, exception=exception)
+
+ async def get_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ async def set_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ def in_transaction(self):
+ """Return True if a transaction is in progress.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+
+ conn = self._sync_connection()
+
+ return conn.in_transaction()
+
+ def in_nested_transaction(self):
+ """Return True if a transaction is in progress.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ return conn.in_nested_transaction()
+
+ def get_transaction(self):
+ """Return an :class:`.AsyncTransaction` representing the current
+ transaction, if any.
+
+ This makes use of the underlying synchronous connection's
+ :meth:`_engine.Connection.get_transaction` method to get the current
+ :class:`_engine.Transaction`, which is then proxied in a new
+ :class:`.AsyncTransaction` object.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ trans = conn.get_transaction()
+ if trans is not None:
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return an :class:`.AsyncTransaction` representing the current
+ nested (savepoint) transaction, if any.
+
+ This makes use of the underlying synchronous connection's
+ :meth:`_engine.Connection.get_nested_transaction` method to get the
+ current :class:`_engine.Transaction`, which is then proxied in a new
+ :class:`.AsyncTransaction` object.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ trans = conn.get_nested_transaction()
+ if trans is not None:
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ async def execution_options(self, **opt):
+ r"""Set non-SQL options for the connection which take effect
+ during execution.
+
+ This returns this :class:`_asyncio.AsyncConnection` object with
+ the new options added.
+
+ See :meth:`_future.Connection.execution_options` for full details
+ on this method.
+
+ """
+
+ conn = self._sync_connection()
+ c2 = await greenlet_spawn(conn.execution_options, **opt)
+ assert c2 is conn
+ return self
+
+ async def commit(self):
+ """Commit the transaction that is currently in progress.
+
+ This method commits the current transaction if one has been started.
+ If no transaction was started, the method has no effect, assuming
+ the connection is in a non-invalidated state.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.commit)
+
+ async def rollback(self):
+ """Roll back the transaction that is currently in progress.
+
+ This method rolls back the current transaction if one has been started.
+ If no transaction was started, the method has no effect. If a
+ transaction was started and the connection is in an invalidated state,
+ the transaction is cleared using this method.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.rollback)
+
+ async def close(self):
+ """Close this :class:`_asyncio.AsyncConnection`.
+
+ This has the effect of also rolling back the transaction if one
+ is in place.
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.close)
+
+ async def exec_driver_sql(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+ r"""Executes a driver-level SQL string and return buffered
+ :class:`_engine.Result`.
+
+ """
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn.exec_driver_sql,
+ statement,
+ parameters,
+ execution_options,
+ _require_await=True,
+ )
+
+ return await _ensure_sync_result(result, self.exec_driver_sql)
+
+ async def stream(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+ """Execute a statement and return a streaming
+ :class:`_asyncio.AsyncResult` object."""
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20,
+ statement,
+ parameters,
+ util.EMPTY_DICT.merge_with(
+ execution_options, {"stream_results": True}
+ ),
+ _require_await=True,
+ )
+ if not result.context._is_server_side:
+ # TODO: real exception here
+ assert False, "server side result expected"
+ return AsyncResult(result)
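+
+ # Illustrative usage (editor's sketch; "users_table" is an assumption):
+ #
+ #     async with engine.connect() as conn:
+ #         result = await conn.stream(select(users_table))
+ #         async for row in result:
+ #             print(row)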
+
+ async def execute(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+ r"""Executes a SQL statement construct and return a buffered
+ :class:`_engine.Result`.
+
+ :param statement: The statement to be executed. This is always
+ an object that is in both the :class:`_expression.ClauseElement` and
+ :class:`_expression.Executable` hierarchies, including:
+
+ * :class:`_expression.Select`
+ * :class:`_expression.Insert`, :class:`_expression.Update`,
+ :class:`_expression.Delete`
+ * :class:`_expression.TextClause` and
+ :class:`_expression.TextualSelect`
+ * :class:`_schema.DDL` and objects which inherit from
+ :class:`_schema.DDLElement`
+
+ :param parameters: parameters which will be bound into the statement.
+ This may be either a dictionary of parameter names to values,
+ or a mutable sequence (e.g. a list) of dictionaries. When a
+ list of dictionaries is passed, the underlying statement execution
+ will make use of the DBAPI ``cursor.executemany()`` method.
+ When a single dictionary is passed, the DBAPI ``cursor.execute()``
+ method will be used.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_future.Connection.execution_options`.
+
+ :return: a :class:`_engine.Result` object.
+
+ """
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20,
+ statement,
+ parameters,
+ execution_options,
+ _require_await=True,
+ )
+ return await _ensure_sync_result(result, self.execute)
+
+ async def scalar(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+ r"""Executes a SQL statement construct and returns a scalar object.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalar` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a scalar Python value representing the first column of the
+ first row returned.
+
+ """
+ result = await self.execute(statement, parameters, execution_options)
+ return result.scalar()
+
+ async def scalars(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+ r"""Executes a SQL statement construct and returns a scalar objects.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalars` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a :class:`_engine.ScalarResult` object.
+
+ .. versionadded:: 1.4.24
+
+ """
+ result = await self.execute(statement, parameters, execution_options)
+ return result.scalars()
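+
+ # Illustrative usage (editor's sketch; "users_table" is an assumption):
+ #
+ #     ids = (await conn.scalars(select(users_table.c.id))).all()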
+
+ async def stream_scalars(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+ r"""Executes a SQL statement and returns a streaming scalar result
+ object.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.AsyncResult.scalars` method after invoking the
+ :meth:`_future.Connection.stream` method. Parameters are equivalent.
+
+ :return: an :class:`_asyncio.AsyncScalarResult` object.
+
+ .. versionadded:: 1.4.24
+
+ """
+ result = await self.stream(statement, parameters, execution_options)
+ return result.scalars()
+
+ async def run_sync(self, fn, *arg, **kw):
+ """Invoke the given sync callable passing self as the first argument.
+
+ This method maintains the asyncio event loop all the way through
+ to the database connection by running the given callable in a
+ specially instrumented greenlet.
+
+ E.g.::
+
+ with async_engine.begin() as conn:
+ await conn.run_sync(metadata.create_all)
+
+ .. note::
+
+ The provided callable is invoked inline within the asyncio event
+ loop, and will block on traditional IO calls. IO within this
+ callable should only call into SQLAlchemy's asyncio database
+ APIs which will be properly adapted to the greenlet context.
+
+ .. seealso::
+
+ :ref:`session_run_sync`
+ """
+
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(fn, conn, *arg, **kw)
+
+ def __await__(self):
+ return self.start().__await__()
+
+ async def __aexit__(self, type_, value, traceback):
+ await asyncio.shield(self.close())
+
+
+@util.create_proxy_methods(
+ Engine,
+ ":class:`_future.Engine`",
+ ":class:`_asyncio.AsyncEngine`",
+ classmethods=[],
+ methods=[
+ "clear_compiled_cache",
+ "update_execution_options",
+ "get_execution_options",
+ ],
+ attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
+)
+class AsyncEngine(ProxyComparable, AsyncConnectable):
+ """An asyncio proxy for a :class:`_engine.Engine`.
+
+ :class:`_asyncio.AsyncEngine` is acquired using the
+ :func:`_asyncio.create_async_engine` function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ # AsyncEngine is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncEngine that matches this one given only the
+ # "sync" elements.
+ __slots__ = ("sync_engine", "_proxied")
+
+ _connection_cls = AsyncConnection
+
+ _option_cls: type
+
+ class _trans_ctx(StartableContext):
+ def __init__(self, conn):
+ self.conn = conn
+
+ async def start(self, is_ctxmanager=False):
+ await self.conn.start(is_ctxmanager=is_ctxmanager)
+ self.transaction = self.conn.begin()
+ await self.transaction.__aenter__()
+
+ return self.conn
+
+ async def __aexit__(self, type_, value, traceback):
+ async def go():
+ await self.transaction.__aexit__(type_, value, traceback)
+ await self.conn.close()
+
+ await asyncio.shield(go())
+
+ def __init__(self, sync_engine):
+ if not sync_engine.dialect.is_async:
+ raise exc.InvalidRequestError(
+ "The asyncio extension requires an async driver to be used. "
+ f"The loaded {sync_engine.dialect.driver!r} is not async."
+ )
+ self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
+
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncEngine` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncEngine(target)
+
+ def begin(self):
+ """Return a context manager which when entered will deliver an
+ :class:`_asyncio.AsyncConnection` with an
+ :class:`_asyncio.AsyncTransaction` established.
+
+ E.g.::
+
+ async with async_engine.begin() as conn:
+ await conn.execute(
+ text("insert into table (x, y, z) values (1, 2, 3)")
+ )
+ await conn.execute(text("my_special_procedure(5)"))
+
+
+ """
+ conn = self.connect()
+ return self._trans_ctx(conn)
+
+ def connect(self):
+ """Return an :class:`_asyncio.AsyncConnection` object.
+
+ The :class:`_asyncio.AsyncConnection` will procure a database
+ connection from the underlying connection pool when it is entered
+ as an async context manager::
+
+ async with async_engine.connect() as conn:
+ result = await conn.execute(select(user_table))
+
+ The :class:`_asyncio.AsyncConnection` may also be started outside of a
+ context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
+ method.
+
+ """
+
+ return self._connection_cls(self)
+
+ async def raw_connection(self):
+ """Return a "raw" DBAPI connection from the connection pool.
+
+ .. seealso::
+
+ :ref:`dbapi_connections`
+
+ """
+ return await greenlet_spawn(self.sync_engine.raw_connection)
+
+ def execution_options(self, **opt):
+ """Return a new :class:`_asyncio.AsyncEngine` that will provide
+ :class:`_asyncio.AsyncConnection` objects with the given execution
+ options.
+
+ Proxied from :meth:`_future.Engine.execution_options`. See that
+ method for details.
+
+ """
+
+ return AsyncEngine(self.sync_engine.execution_options(**opt))
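+
+ # Illustrative usage (editor's sketch; whether "isolation_level" is
+ # honored depends on the dialect in use):
+ #
+ #     autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")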
+
+ async def dispose(self):
+ """Dispose of the connection pool used by this
+ :class:`_asyncio.AsyncEngine`.
+
+ This will close all connection pool connections that are
+ **currently checked in**. See the documentation for the underlying
+ :meth:`_future.Engine.dispose` method for further notes.
+
+ .. seealso::
+
+ :meth:`_future.Engine.dispose`
+
+ """
+
+ await greenlet_spawn(self.sync_engine.dispose)
+
+
+class AsyncTransaction(ProxyComparable, StartableContext):
+ """An asyncio proxy for a :class:`_engine.Transaction`."""
+
+ __slots__ = ("connection", "sync_transaction", "nested")
+
+ def __init__(self, connection, nested=False):
+ self.connection = connection # AsyncConnection
+ self.sync_transaction = None # sqlalchemy.engine.Transaction
+ self.nested = nested
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ sync_connection = target.connection
+ sync_transaction = target
+ nested = isinstance(target, NestedTransaction)
+
+ async_connection = AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+ assert async_connection is not None
+
+ obj = cls.__new__(cls)
+ obj.connection = async_connection
+ obj.sync_transaction = obj._assign_proxied(sync_transaction)
+ obj.nested = nested
+ return obj
+
+ def _sync_transaction(self):
+ if not self.sync_transaction:
+ self._raise_for_not_started()
+ return self.sync_transaction
+
+ @property
+ def _proxied(self):
+ return self.sync_transaction
+
+ @property
+ def is_valid(self):
+ return self._sync_transaction().is_valid
+
+ @property
+ def is_active(self):
+ return self._sync_transaction().is_active
+
+ async def close(self):
+ """Close this :class:`.Transaction`.
+
+ If this transaction is the base transaction in a begin/commit
+ nesting, the transaction will rollback(). Otherwise, the
+ method returns.
+
+ This is used to cancel a Transaction without affecting the scope of
+ an enclosing transaction.
+
+ """
+ await greenlet_spawn(self._sync_transaction().close)
+
+ async def rollback(self):
+ """Roll back this :class:`.Transaction`."""
+ await greenlet_spawn(self._sync_transaction().rollback)
+
+ async def commit(self):
+ """Commit this :class:`.Transaction`."""
+
+ await greenlet_spawn(self._sync_transaction().commit)
+
+ async def start(self, is_ctxmanager=False):
+ """Start this :class:`_asyncio.AsyncTransaction` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.connection._sync_connection().begin_nested
+ if self.nested
+ else self.connection._sync_connection().begin
+ )
+ )
+ if is_ctxmanager:
+ self.sync_transaction.__enter__()
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await greenlet_spawn(
+ self._sync_transaction().__exit__, type_, value, traceback
+ )
+
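+# Illustrative usage (editor's sketch): driving an AsyncTransaction explicitly
+# rather than through "async with"; "engine" and the statement are assumptions:
+#
+#     async with engine.connect() as conn:
+#         trans = await conn.begin()
+#         try:
+#             await conn.execute(text("insert into t (x) values (1)"))
+#             await trans.commit()
+#         except Exception:
+#             await trans.rollback()
+#             raise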
+
+def _get_sync_engine_or_connection(async_engine):
+ if isinstance(async_engine, AsyncConnection):
+ return async_engine.sync_connection
+
+ try:
+ return async_engine.sync_engine
+ except AttributeError as e:
+ raise exc.ArgumentError(
+ "AsyncEngine expected, got %r" % async_engine
+ ) from e
+
+
+@inspection._inspects(AsyncConnection)
+def _no_insp_for_async_conn_yet(subject):
+ raise exc.NoInspectionAvailable(
+ "Inspection on an AsyncConnection is currently not supported. "
+ "Please use ``run_sync`` to pass a callable where it's possible "
+ "to call ``inspect`` on the passed connection.",
+ code="xd3s",
+ )
+
+
+@inspection._inspects(AsyncEngine)
+def _no_insp_for_async_engine_xyet(subject):
+ raise exc.NoInspectionAvailable(
+ "Inspection on an AsyncEngine is currently not supported. "
+ "Please obtain a connection then use ``conn.run_sync`` to pass a "
+ "callable where it's possible to call ``inspect`` on the "
+ "passed connection.",
+ code="xd3s",
+ )
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
new file mode 100644
index 0000000..c5d5e01
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/events.py
@@ -0,0 +1,44 @@
+# ext/asyncio/events.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .engine import AsyncConnectable
+from .session import AsyncSession
+from ...engine import events as engine_event
+from ...orm import events as orm_event
+
+
+class AsyncConnectionEvents(engine_event.ConnectionEvents):
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = AsyncConnectable
+
+ @classmethod
+ def _no_async_engine_events(cls):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes."
+ )
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ cls._no_async_engine_events()
+
+
+class AsyncSessionEvents(orm_event.SessionEvents):
+ _target_class_doc = "SomeSession"
+ _dispatch_target = AsyncSession
+
+ @classmethod
+ def _no_async_engine_events(cls):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ cls._no_async_engine_events()
diff --git a/lib/sqlalchemy/ext/asyncio/exc.py b/lib/sqlalchemy/ext/asyncio/exc.py
new file mode 100644
index 0000000..cf0d9a8
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/exc.py
@@ -0,0 +1,21 @@
+# ext/asyncio/exc.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import exc
+
+
+class AsyncMethodRequired(exc.InvalidRequestError):
+ """an API can't be used because its result would not be
+ compatible with async"""
+
+
+class AsyncContextNotStarted(exc.InvalidRequestError):
+ """a startable context manager has not been started."""
+
+
+class AsyncContextAlreadyStarted(exc.InvalidRequestError):
+ """a startable context manager is already started."""
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
new file mode 100644
index 0000000..a77b6a8
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -0,0 +1,671 @@
+# ext/asyncio/result.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import operator
+
+from . import exc as async_exc
+from ...engine.result import _NO_ROW
+from ...engine.result import FilterResult
+from ...engine.result import FrozenResult
+from ...engine.result import MergedResult
+from ...sql.base import _generative
+from ...util.concurrency import greenlet_spawn
+
+
+class AsyncCommon(FilterResult):
+ async def close(self):
+ """Close this result."""
+
+ await greenlet_spawn(self._real_result.close)
+
+
+class AsyncResult(AsyncCommon):
+ """An asyncio wrapper around a :class:`_result.Result` object.
+
+ The :class:`_asyncio.AsyncResult` only applies to statement executions that
+ use a server-side cursor. It is returned only from the
+ :meth:`_asyncio.AsyncConnection.stream` and
+ :meth:`_asyncio.AsyncSession.stream` methods.
+
+ .. note:: As is the case with :class:`_engine.Result`, this object is
+        used for ORM results returned by :meth:`_asyncio.AsyncSession.stream`,
+ which can yield instances of ORM mapped objects either individually or
+ within tuple-like rows. Note that these result objects do not
+ deduplicate instances or rows automatically as is the case with the
+ legacy :class:`_orm.Query` object. For in-Python de-duplication of
+ instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
+ method.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def __init__(self, real_result):
+ self._real_result = real_result
+
+ self._metadata = real_result._metadata
+ self._unique_filter_state = real_result._unique_filter_state
+
+ # BaseCursorResult pre-generates the "_row_getter". Use that
+ # if available rather than building a second one
+ if "_row_getter" in real_result.__dict__:
+ self._set_memoized_attribute(
+ "_row_getter", real_result.__dict__["_row_getter"]
+ )
+
+ def keys(self):
+ """Return the :meth:`_engine.Result.keys` collection from the
+ underlying :class:`_engine.Result`.
+
+ """
+ return self._metadata.keys
+
+ @_generative
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncResult`.
+
+ Refer to :meth:`_engine.Result.unique` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+
+ """
+ self._unique_filter_state = (set(), strategy)
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row.
+
+ Refer to :meth:`_engine.Result.columns` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+
+ """
+ return self._column_slices(col_expressions)
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of rows of the size given.
+
+ An async iterator is returned::
+
+ async def scroll_results(connection):
+ result = await connection.stream(select(users_table))
+
+ async for partition in result.partitions(100):
+ print("list of rows: %s" % partition)
+
+ .. seealso::
+
+ :meth:`_engine.Result.partitions`
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchone(self):
+ """Fetch one row.
+
+ When all rows are exhausted, returns None.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch the first row of a result only, use the
+ :meth:`_engine.Result.first` method. To iterate through all
+ rows, iterate the :class:`_engine.Result` object directly.
+
+ :return: a :class:`.Row` object if no filters are applied, or None
+ if no rows remain.
+
+ """
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ async def fetchmany(self, size=None):
+ """Fetch many rows.
+
+ When all rows are exhausted, returns an empty list.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch rows in groups, use the
+        :meth:`_asyncio.AsyncResult.partitions` method.
+
+ :return: a list of :class:`.Row` objects.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.partitions`
+
+ """
+
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+ """Return all rows in a list.
+
+ Closes the result set after invocation. Subsequent invocations
+ will return an empty list.
+
+ :return: a list of :class:`.Row` objects.
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first row or None if no row is present.
+
+ Closes the result set and discards remaining rows.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default. To
+ return exactly one single scalar value, that is, the first column of
+ the first row, use the :meth:`_asyncio.AsyncResult.scalar` method,
+ or combine :meth:`_asyncio.AsyncResult.scalars` and
+ :meth:`_asyncio.AsyncResult.first`.
+
+ :return: a :class:`.Row` object, or None
+ if no rows remain.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.scalar`
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one result or raise an exception.
+
+ Returns ``None`` if the result has no rows.
+ Raises :class:`.MultipleResultsFound`
+ if multiple rows are returned.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row` or None if no row is available.
+
+ :raises: :class:`.MultipleResultsFound`
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.first`
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def scalar_one(self):
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+ then :meth:`_asyncio.AsyncResult.one`.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ :meth:`_asyncio.AsyncResult.scalars`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, True)
+
+ async def scalar_one_or_none(self):
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+ then :meth:`_asyncio.AsyncResult.one_or_none`.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.one_or_none`
+
+ :meth:`_asyncio.AsyncResult.scalars`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, True)
+
+ async def one(self):
+ """Return exactly one row or raise an exception.
+
+ Raises :class:`.NoResultFound` if the result returns no
+ rows, or :class:`.MultipleResultsFound` if multiple rows
+ would be returned.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the
+ :meth:`_asyncio.AsyncResult.scalar_one` method, or combine
+ :meth:`_asyncio.AsyncResult.scalars` and
+ :meth:`_asyncio.AsyncResult.one`.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row`.
+
+ :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.first`
+
+ :meth:`_asyncio.AsyncResult.one_or_none`
+
+ :meth:`_asyncio.AsyncResult.scalar_one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+ async def scalar(self):
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+        :return: a Python scalar value, or None if no rows remain.
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, True)
+
+ async def freeze(self):
+ """Return a callable object that will produce copies of this
+ :class:`_asyncio.AsyncResult` when invoked.
+
+ The callable object returned is an instance of
+ :class:`_engine.FrozenResult`.
+
+ This is used for result set caching. The method must be called
+        on the result while it is still unconsumed, and calling the method
+ will consume the result fully. When the :class:`_engine.FrozenResult`
+ is retrieved from a cache, it can be called any number of times where
+ it will produce a new :class:`_engine.Result` object each time
+ against its stored set of rows.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - example usage within the
+ ORM to implement a result-set cache.
+
+ """
+
+ return await greenlet_spawn(FrozenResult, self)
+
+ def merge(self, *others):
+ """Merge this :class:`_asyncio.AsyncResult` with other compatible
+ result objects.
+
+ The object returned is an instance of :class:`_engine.MergedResult`,
+ which will be composed of iterators from the given result
+ objects.
+
+ The new result will use the metadata from this result object.
+ The subsequent result objects must be against an identical
+ set of result / cursor metadata, otherwise the behavior is
+ undefined.
+
+ """
+ return MergedResult(self._metadata, (self,) + others)
+
+ def scalars(self, index=0):
+ """Return an :class:`_asyncio.AsyncScalarResult` filtering object which
+ will return single elements rather than :class:`_row.Row` objects.
+
+ Refer to :meth:`_result.Result.scalars` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ :param index: integer or row key indicating the column to be fetched
+ from each row, defaults to ``0`` indicating the first column.
+
+ :return: a new :class:`_asyncio.AsyncScalarResult` filtering object
+ referring to this :class:`_asyncio.AsyncResult` object.
+
+ """
+ return AsyncScalarResult(self._real_result, index)
+
+ def mappings(self):
+ """Apply a mappings filter to returned rows, returning an instance of
+ :class:`_asyncio.AsyncMappingResult`.
+
+ When this filter is applied, fetching rows will return
+ :class:`.RowMapping` objects instead of :class:`.Row` objects.
+
+ Refer to :meth:`_result.Result.mappings` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
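+        E.g., a minimal sketch (``conn`` and ``user_table`` are illustrative
+        placeholders)::
+
+            result = await conn.stream(select(user_table))
+            async for row_mapping in result.mappings():
+                print(row_mapping["id"])
+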
+ :return: a new :class:`_asyncio.AsyncMappingResult` filtering object
+ referring to the underlying :class:`_result.Result` object.
+
+ """
+
+ return AsyncMappingResult(self._real_result)
+
+
+class AsyncScalarResult(AsyncCommon):
+ """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
+ rather than :class:`_row.Row` values.
+
+ The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
+ :meth:`_asyncio.AsyncResult.scalars` method.
+
+ Refer to the :class:`_result.ScalarResult` object in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _generate_rows = False
+
+ def __init__(self, real_result, index):
+ self._real_result = real_result
+
+ if real_result._source_supports_scalars:
+ self._metadata = real_result._metadata
+ self._post_creational_filter = None
+ else:
+ self._metadata = real_result._metadata._reduce([index])
+ self._post_creational_filter = operator.itemgetter(0)
+
+ self._unique_filter_state = real_result._unique_filter_state
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncScalarResult`.
+
+ See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchall(self):
+ """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+class AsyncMappingResult(AsyncCommon):
+ """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary
+ values rather than :class:`_engine.Row` values.
+
+ The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
+ :meth:`_asyncio.AsyncResult.mappings` method.
+
+ Refer to the :class:`_result.MappingResult` object in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _generate_rows = True
+
+ _post_creational_filter = operator.attrgetter("_mapping")
+
+ def __init__(self, result):
+ self._real_result = result
+ self._unique_filter_state = result._unique_filter_state
+ self._metadata = result._metadata
+ if result._source_supports_scalars:
+ self._metadata = self._metadata._reduce([0])
+
+ def keys(self):
+ """Return an iterable view which yields the string keys that would
+ be represented by each :class:`.Row`.
+
+ The view also can be tested for key containment using the Python
+ ``in`` operator, which will test both for the string keys represented
+ in the view, as well as for alternate keys such as column objects.
+
+ .. versionchanged:: 1.4 a key view object is returned rather than a
+ plain list.
+
+
+ """
+ return self._metadata.keys
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncMappingResult`.
+
+ See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row."""
+ return self._column_slices(col_expressions)
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchall(self):
+ """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchone(self):
+ """Fetch one object.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ async def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+        """Return all mapping values in a list.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+async def _ensure_sync_result(result, calling_method):
+ if not result._is_cursor:
+ cursor_result = getattr(result, "raw", None)
+ else:
+ cursor_result = result
+ if cursor_result and cursor_result.context._is_server_side:
+ await greenlet_spawn(cursor_result.close)
+ raise async_exc.AsyncMethodRequired(
+ "Can't use the %s.%s() method with a "
+ "server-side cursor. "
+ "Use the %s.stream() method for an async "
+ "streaming result set."
+ % (
+ calling_method.__self__.__class__.__name__,
+ calling_method.__name__,
+ calling_method.__self__.__class__.__name__,
+ )
+ )
+ return result
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
new file mode 100644
index 0000000..8eca8c5
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/scoping.py
@@ -0,0 +1,107 @@
+# ext/asyncio/scoping.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .session import AsyncSession
+from ...orm.scoping import ScopedSessionMixin
+from ...util import create_proxy_methods
+from ...util import ScopedRegistry
+
+
+@create_proxy_methods(
+ AsyncSession,
+ ":class:`_asyncio.AsyncSession`",
+ ":class:`_asyncio.scoping.async_scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get",
+ "get_bind",
+ "is_modified",
+ "invalidate",
+ "merge",
+ "refresh",
+ "rollback",
+ "scalar",
+ "scalars",
+ "stream",
+ "stream_scalars",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
+class async_scoped_session(ScopedSessionMixin):
+ """Provides scoped management of :class:`.AsyncSession` objects.
+
+ See the section :ref:`asyncio_scoped_session` for usage details.
+
+ .. versionadded:: 1.4.19
+
+
+ """
+
+ _support_async = True
+
+ def __init__(self, session_factory, scopefunc):
+ """Construct a new :class:`_asyncio.async_scoped_session`.
+
+ :param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
+ instances. This is usually, but not necessarily, an instance
+ of :class:`_orm.sessionmaker` which itself was passed the
+ :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
+ parameter::
+
+            async_session_factory = sessionmaker(some_async_engine, class_=AsyncSession)
+ AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task)
+
+ :param scopefunc: function which defines
+ the current scope. A function such as ``asyncio.current_task``
+ may be useful here.
+
+ """ # noqa: E501
+
+ self.session_factory = session_factory
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+
+ @property
+ def _proxied(self):
+ return self.registry()
+
+ async def remove(self):
+ """Dispose of the current :class:`.AsyncSession`, if present.
+
+ Different from scoped_session's remove method, this method would use
+ await to wait for the close method of AsyncSession.
+
+ """
+
+ if self.registry.has():
+ await self.registry().close()
+ self.registry.clear()
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
new file mode 100644
index 0000000..378cbcb
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -0,0 +1,759 @@
+# ext/asyncio/session.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import asyncio
+
+from . import engine
+from . import result as _result
+from .base import ReversibleProxy
+from .base import StartableContext
+from .result import _ensure_sync_result
+from ... import util
+from ...orm import object_session
+from ...orm import Session
+from ...orm import state as _instance_state
+from ...util.concurrency import greenlet_spawn
+
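+# execution options applied by AsyncSession.execute() and AsyncSession.stream()
+# respectively: execute() pre-buffers all rows so the returned Result requires
+# no further database IO, while stream() requests server-side cursor streaming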
+_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
+_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
+
+
+@util.create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_asyncio.AsyncSession`",
+ classmethods=["object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "is_modified",
+ "in_transaction",
+ "in_nested_transaction",
+ ],
+ attributes=[
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
+class AsyncSession(ReversibleProxy):
+ """Asyncio version of :class:`_orm.Session`.
+
+ The :class:`_asyncio.AsyncSession` is a proxy for a traditional
+ :class:`_orm.Session` instance.
+
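+    A rough usage sketch (the database URL and ``SomeObject`` are illustrative
+    placeholders)::
+
+        from sqlalchemy.ext.asyncio import AsyncSession
+        from sqlalchemy.ext.asyncio import create_async_engine
+
+        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+        async with AsyncSession(engine) as session:
+            async with session.begin():
+                session.add(SomeObject())
+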
+ .. versionadded:: 1.4
+
+ To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
+ implementations, see the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
+
+
+ """
+
+ _is_asyncio = True
+
+ dispatch = None
+
+ def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+ r"""Construct a new :class:`_asyncio.AsyncSession`.
+
+ All parameters other than ``sync_session_class`` are passed to the
+ ``sync_session_class`` callable directly to instantiate a new
+ :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
+ parameter documentation.
+
+ :param sync_session_class:
+ A :class:`_orm.Session` subclass or other callable which will be used
+ to construct the :class:`_orm.Session` which will be proxied. This
+ parameter may be used to provide custom :class:`_orm.Session`
+ subclasses. Defaults to the
+ :attr:`_asyncio.AsyncSession.sync_session_class` class-level
+ attribute.
+
+ .. versionadded:: 1.4.24
+
+ """
+ kw["future"] = True
+ if bind:
+ self.bind = bind
+ bind = engine._get_sync_engine_or_connection(bind)
+
+ if binds:
+ self.binds = binds
+ binds = {
+ key: engine._get_sync_engine_or_connection(b)
+ for key, b in binds.items()
+ }
+
+ if sync_session_class:
+ self.sync_session_class = sync_session_class
+
+ self.sync_session = self._proxied = self._assign_proxied(
+ self.sync_session_class(bind=bind, binds=binds, **kw)
+ )
+
+ sync_session_class = Session
+ """The class or callable that provides the
+ underlying :class:`_orm.Session` instance for a particular
+ :class:`_asyncio.AsyncSession`.
+
+ At the class level, this attribute is the default value for the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
+ subclasses of :class:`_asyncio.AsyncSession` can override this.
+
+ At the instance level, this attribute indicates the current class or
+ callable that was used to provide the :class:`_orm.Session` instance for
+ this :class:`_asyncio.AsyncSession` instance.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ sync_session: Session
+ """Reference to the underlying :class:`_orm.Session` this
+ :class:`_asyncio.AsyncSession` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+
+ """
+
+ async def refresh(
+ self, instance, attribute_names=None, with_for_update=None
+ ):
+ """Expire and refresh the attributes on the given instance.
+
+ A query will be issued to the database and all attributes will be
+ refreshed with their current database value.
+
+ This is the async version of the :meth:`_orm.Session.refresh` method.
+ See that method for a complete description of all options.
+
+ .. seealso::
+
+ :meth:`_orm.Session.refresh` - main documentation for refresh
+
+ """
+
+ return await greenlet_spawn(
+ self.sync_session.refresh,
+ instance,
+ attribute_names=attribute_names,
+ with_for_update=with_for_update,
+ )
+
+ async def run_sync(self, fn, *arg, **kw):
+ """Invoke the given sync callable passing sync self as the first
+ argument.
+
+ This method maintains the asyncio event loop all the way through
+ to the database connection by running the given callable in a
+ specially instrumented greenlet.
+
+ E.g.::
+
+ with AsyncSession(async_engine) as session:
+ await session.run_sync(some_business_method)
+
+ .. note::
+
+ The provided callable is invoked inline within the asyncio event
+ loop, and will block on traditional IO calls. IO within this
+ callable should only call into SQLAlchemy's asyncio database
+ APIs which will be properly adapted to the greenlet context.
+
+ .. seealso::
+
+ :ref:`session_run_sync`
+ """
+
+ return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+
+ async def execute(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a buffered
+ :class:`_engine.Result` object.
+
+ .. seealso::
+
+ :meth:`_orm.Session.execute` - main documentation for execute
+
+ """
+
+ if execution_options:
+ execution_options = util.immutabledict(execution_options).union(
+ _EXECUTE_OPTIONS
+ )
+ else:
+ execution_options = _EXECUTE_OPTIONS
+
+ result = await greenlet_spawn(
+ self.sync_session.execute,
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return await _ensure_sync_result(result, self.execute)
+
+ async def scalar(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a scalar result.
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalar` - main documentation for scalar
+
+ """
+
+ result = await self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalar()
+
+ async def scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return scalar results.
+
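+        E.g., a minimal sketch (``User`` is a placeholder mapped class)::
+
+            result = await session.scalars(select(User).where(User.id > 5))
+            users = result.all()
+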
+ :return: a :class:`_result.ScalarResult` object
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalars` - main documentation for scalars
+
+ :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version
+
+ """
+
+ result = await self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalars()
+
+ async def get(
+ self,
+ entity,
+ ident,
+ options=None,
+ populate_existing=False,
+ with_for_update=None,
+ identity_token=None,
+ ):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
+ .. seealso::
+
+ :meth:`_orm.Session.get` - main documentation for get
+
+
+ """
+ return await greenlet_spawn(
+ self.sync_session.get,
+ entity,
+ ident,
+ options=options,
+ populate_existing=populate_existing,
+ with_for_update=with_for_update,
+ identity_token=identity_token,
+ )
+
+ async def stream(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a streaming
+ :class:`_asyncio.AsyncResult` object.
+
+ """
+
+ if execution_options:
+ execution_options = util.immutabledict(execution_options).union(
+ _STREAM_OPTIONS
+ )
+ else:
+ execution_options = _STREAM_OPTIONS
+
+ result = await greenlet_spawn(
+ self.sync_session.execute,
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return _result.AsyncResult(result)
+
+ async def stream_scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a stream of scalar results.
+
+ :return: an :class:`_asyncio.AsyncScalarResult` object
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalars` - main documentation for scalars
+
+ :meth:`_asyncio.AsyncSession.scalars` - non streaming version
+
+ """
+
+ result = await self.stream(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalars()
+
+ async def delete(self, instance):
+ """Mark an instance as deleted.
+
+ The database delete operation occurs upon ``flush()``.
+
+ As this operation may need to cascade along unloaded relationships,
+ it is awaitable to allow for those queries to take place.
+
+ .. seealso::
+
+ :meth:`_orm.Session.delete` - main documentation for delete
+
+ """
+ return await greenlet_spawn(self.sync_session.delete, instance)
+
+ async def merge(self, instance, load=True, options=None):
+ """Copy the state of a given instance into a corresponding instance
+ within this :class:`_asyncio.AsyncSession`.
+
+ .. seealso::
+
+ :meth:`_orm.Session.merge` - main documentation for merge
+
+ """
+ return await greenlet_spawn(
+ self.sync_session.merge, instance, load=load, options=options
+ )
+
+ async def flush(self, objects=None):
+ """Flush all the object changes to the database.
+
+ .. seealso::
+
+ :meth:`_orm.Session.flush` - main documentation for flush
+
+ """
+ await greenlet_spawn(self.sync_session.flush, objects=objects)
+
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ trans = self.sync_session.get_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ trans = self.sync_session.get_nested_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ """Return a "bind" to which the synchronous proxied :class:`_orm.Session`
+ is bound.
+
+ Unlike the :meth:`_orm.Session.get_bind` method, this method is
+ currently **not** used by this :class:`.AsyncSession` in any way
+ in order to resolve engines for requests.
+
+ .. note::
+
+ This method proxies directly to the :meth:`_orm.Session.get_bind`
+ method, however is currently **not** useful as an override target,
+ in contrast to that of the :meth:`_orm.Session.get_bind` method.
+ The example below illustrates how to implement custom
+ :meth:`_orm.Session.get_bind` schemes that work with
+ :class:`.AsyncSession` and :class:`.AsyncEngine`.
+
+ The pattern introduced at :ref:`session_custom_partitioning`
+ illustrates how to apply a custom bind-lookup scheme to a
+ :class:`_orm.Session` given a set of :class:`_engine.Engine` objects.
+ To apply a corresponding :meth:`_orm.Session.get_bind` implementation
+            for use with :class:`.AsyncSession` and :class:`.AsyncEngine`
+ objects, continue to subclass :class:`_orm.Session` and apply it to
+ :class:`.AsyncSession` using
+ :paramref:`.AsyncSession.sync_session_class`. The inner method must
+ continue to return :class:`_engine.Engine` instances, which can be
+ acquired from a :class:`_asyncio.AsyncEngine` using the
+ :attr:`_asyncio.AsyncEngine.sync_engine` attribute::
+
+ # using example from "Custom Vertical Partitioning"
+
+
+ import random
+
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.ext.asyncio import create_async_engine
+ from sqlalchemy.orm import Session, sessionmaker
+
+ # construct async engines w/ async drivers
+ engines = {
+ 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
+ 'other':create_async_engine("sqlite+aiosqlite:///other.db"),
+ 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
+ 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
+ }
+
+ class RoutingSession(Session):
+ def get_bind(self, mapper=None, clause=None, **kw):
+ # within get_bind(), return sync engines
+ if mapper and issubclass(mapper.class_, MyOtherClass):
+ return engines['other'].sync_engine
+ elif self._flushing or isinstance(clause, (Update, Delete)):
+ return engines['leader'].sync_engine
+ else:
+ return engines[
+ random.choice(['follower1','follower2'])
+ ].sync_engine
+
+ # apply to AsyncSession using sync_session_class
+ AsyncSessionMaker = sessionmaker(
+ class_=AsyncSession,
+ sync_session_class=RoutingSession
+ )
+
+ The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
+ implicitly non-blocking context in the same manner as ORM event hooks
+ and functions that are invoked via :meth:`.AsyncSession.run_sync`, so
+ routines that wish to run SQL commands inside of
+ :meth:`_orm.Session.get_bind` can continue to do so using
+ blocking-style code, which will be translated to implicitly async calls
+ at the point of invoking IO on the database drivers.
+
+ """ # noqa: E501
+
+ return self.sync_session.get_bind(
+ mapper=mapper, clause=clause, bind=bind, **kw
+ )
+
+ async def connection(self, **kw):
+ r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
+ this :class:`.Session` object's transactional state.
+
+ This method may also be used to establish execution options for the
+ database connection used by the current transaction.
+
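+        E.g., a hypothetical sketch of applying an isolation level to the
+        connection used by the current transaction::
+
+            conn = await session.connection(
+                execution_options={"isolation_level": "SERIALIZABLE"}
+            )
+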
+ .. versionadded:: 1.4.24 Added \**kw arguments which are passed
+ through to the underlying :meth:`_orm.Session.connection` method.
+
+ .. seealso::
+
+ :meth:`_orm.Session.connection` - main documentation for
+ "connection"
+
+ """
+
+ sync_connection = await greenlet_spawn(
+ self.sync_session.connection, **kw
+ )
+ return engine.AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+
+ def begin(self, **kw):
+ """Return an :class:`_asyncio.AsyncSessionTransaction` object.
+
+ The underlying :class:`_orm.Session` will perform the
+ "begin" action when the :class:`_asyncio.AsyncSessionTransaction`
+ object is entered::
+
+ async with async_session.begin():
+ # .. ORM transaction is begun
+
+ Note that database IO will not normally occur when the session-level
+ transaction is begun, as database transactions begin on an
+        on-demand basis. However, the begin block is async in order to
+        accommodate a :meth:`_orm.SessionEvents.after_transaction_create`
+ event hook that may perform IO.
+
+ For a general description of ORM begin, see
+ :meth:`_orm.Session.begin`.
+
+ """
+
+ return AsyncSessionTransaction(self)
+
+ def begin_nested(self, **kw):
+ """Return an :class:`_asyncio.AsyncSessionTransaction` object
+ which will begin a "nested" transaction, e.g. SAVEPOINT.
+
+ Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.
+
+ For a general description of ORM begin nested, see
+ :meth:`_orm.Session.begin_nested`.
+
+ """
+
+ return AsyncSessionTransaction(self, nested=True)
+
+ async def rollback(self):
+        """Roll back the current transaction in progress."""
+ return await greenlet_spawn(self.sync_session.rollback)
+
+ async def commit(self):
+ """Commit the current transaction in progress."""
+ return await greenlet_spawn(self.sync_session.commit)
+
+ async def close(self):
+ """Close out the transactional resources and ORM objects used by this
+ :class:`_asyncio.AsyncSession`.
+
+ This expunges all ORM objects associated with this
+ :class:`_asyncio.AsyncSession`, ends any transaction in progress and
+ :term:`releases` any :class:`_asyncio.AsyncConnection` objects which
+ this :class:`_asyncio.AsyncSession` itself has checked out from
+ associated :class:`_asyncio.AsyncEngine` objects. The operation then
+        leaves the :class:`_asyncio.AsyncSession` in a state in which it may be
+ used again.
+
+ .. tip::
+
+ The :meth:`_asyncio.AsyncSession.close` method **does not prevent
+ the Session from being used again**. The
+ :class:`_asyncio.AsyncSession` itself does not actually have a
+ distinct "closed" state; it merely means the
+ :class:`_asyncio.AsyncSession` will release all database
+ connections and ORM objects.
+
+
+ .. seealso::
+
+ :ref:`session_closing` - detail on the semantics of
+ :meth:`_asyncio.AsyncSession.close`
+
+ """
+ await greenlet_spawn(self.sync_session.close)
+
+ async def invalidate(self):
+ """Close this Session, using connection invalidation.
+
+ For a complete description, see :meth:`_orm.Session.invalidate`.
+ """
+ return await greenlet_spawn(self.sync_session.invalidate)
+
+ @classmethod
+    async def close_all(cls):
+        """Close all :class:`_asyncio.AsyncSession` sessions."""
+        # invoked as a classmethod; close all sessions via the underlying
+        # synchronous Session class
+        return await greenlet_spawn(cls.sync_session_class.close_all)
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await asyncio.shield(self.close())
+
+ def _maker_context_manager(self):
+ # no @contextlib.asynccontextmanager until python3.7, gr
+ return _AsyncSessionContextManager(self)
+
+
+class _AsyncSessionContextManager:
+ def __init__(self, async_session):
+ self.async_session = async_session
+
+ async def __aenter__(self):
+ self.trans = self.async_session.begin()
+ await self.trans.__aenter__()
+ return self.async_session
+
+ async def __aexit__(self, type_, value, traceback):
+ async def go():
+ await self.trans.__aexit__(type_, value, traceback)
+ await self.async_session.__aexit__(type_, value, traceback)
+
+ await asyncio.shield(go())
+
+
+class AsyncSessionTransaction(ReversibleProxy, StartableContext):
+ """A wrapper for the ORM :class:`_orm.SessionTransaction` object.
+
+ This object is provided so that a transaction-holding object
+ for the :meth:`_asyncio.AsyncSession.begin` may be returned.
+
+ The object supports both explicit calls to
+ :meth:`_asyncio.AsyncSessionTransaction.commit` and
+ :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an
+ async context manager.
+
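+    A rough sketch of the explicit calling style (assuming, as with the other
+    startable objects in this extension, that the transaction may be awaited
+    directly to start it; ``session`` is an illustrative
+    :class:`_asyncio.AsyncSession`)::
+
+        trans = await session.begin()
+        ...  # work with the session
+        await trans.commit()  # or await trans.rollback()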
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ("session", "sync_transaction", "nested")
+
+ def __init__(self, session, nested=False):
+ self.session = session
+ self.nested = nested
+ self.sync_transaction = None
+
+ @property
+ def is_active(self):
+ return (
+ self._sync_transaction() is not None
+ and self._sync_transaction().is_active
+ )
+
+ def _sync_transaction(self):
+ if not self.sync_transaction:
+ self._raise_for_not_started()
+ return self.sync_transaction
+
+ async def rollback(self):
+        """Roll back this :class:`_asyncio.AsyncSessionTransaction`."""
+ await greenlet_spawn(self._sync_transaction().rollback)
+
+ async def commit(self):
+        """Commit this :class:`_asyncio.AsyncSessionTransaction`."""
+
+ await greenlet_spawn(self._sync_transaction().commit)
+
+ async def start(self, is_ctxmanager=False):
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.session.sync_session.begin_nested
+ if self.nested
+ else self.session.sync_session.begin
+ )
+ )
+ if is_ctxmanager:
+ self.sync_transaction.__enter__()
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await greenlet_spawn(
+ self._sync_transaction().__exit__, type_, value, traceback
+ )
+
+
+def async_object_session(instance):
+ """Return the :class:`_asyncio.AsyncSession` to which the given instance
+ belongs.
+
+ This function makes use of the sync-API function
+ :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which
+ refers to the given instance, and from there links it to the original
+ :class:`_asyncio.AsyncSession`.
+
+ If the :class:`_asyncio.AsyncSession` has been garbage collected, the
+ return value is ``None``.
+
+ This functionality is also available from the
+ :attr:`_orm.InstanceState.async_session` accessor.
+
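+    E.g., a short sketch (``some_object`` is an illustrative mapped instance)::
+
+        from sqlalchemy.ext.asyncio import async_object_session
+
+        owning_session = async_object_session(some_object)
+        if owning_session is not None:
+            await owning_session.refresh(some_object)
+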
+ :param instance: an ORM mapped instance
+ :return: an :class:`_asyncio.AsyncSession` object, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ session = object_session(instance)
+ if session is not None:
+ return async_session(session)
+ else:
+ return None
+
+
+def async_session(session):
+ """Return the :class:`_asyncio.AsyncSession` which is proxying the given
+ :class:`_orm.Session` object, if any.
+
+ :param session: a :class:`_orm.Session` instance.
+ :return: a :class:`_asyncio.AsyncSession` instance, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
+
+
+_instance_state._async_provider = async_session
diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py
new file mode 100644
index 0000000..a5d7267
--- /dev/null
+++ b/lib/sqlalchemy/ext/automap.py
@@ -0,0 +1,1234 @@
+# ext/automap.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system
+which automatically generates mapped classes and relationships from a database
+schema, typically though not necessarily one which is reflected.
+
+It is hoped that the :class:`.AutomapBase` system provides a quick
+and modernized solution to the problem that the very famous
+`SQLSoup <https://sqlsoup.readthedocs.io/en/latest/>`_
+also tries to solve, that of generating a quick and rudimentary object
+model from an existing database on the fly. By addressing the issue strictly
+at the mapper configuration level, and integrating fully with existing
+Declarative class techniques, :class:`.AutomapBase` seeks to provide
+a well-integrated approach to the issue of expediently auto-generating ad-hoc
+mappings.
+
+.. tip:: The :ref:`automap_toplevel` extension is geared towards a
+ "zero declaration" approach, where a complete ORM model including classes
+ and pre-named relationships can be generated on the fly from a database
+ schema. For applications that still want to use explicit class declarations
+ including explicit relationship definitions in conjunction with reflection
+ of tables, the :class:`.DeferredReflection` class, described at
+ :ref:`orm_declarative_reflected_deferred_reflection`, is a better choice.
+
+
+
+Basic Use
+=========
+
+The simplest usage is to reflect an existing database into a new model.
+We create a new :class:`.AutomapBase` class in a similar manner as to how
+we create a declarative base class, using :func:`.automap_base`.
+We then call :meth:`.AutomapBase.prepare` on the resulting base class,
+asking it to reflect the schema and produce mappings::
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy.orm import Session
+ from sqlalchemy import create_engine
+
+ Base = automap_base()
+
+ # engine, suppose it has two tables 'user' and 'address' set up
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ # reflect the tables
+ Base.prepare(autoload_with=engine)
+
+ # mapped classes are now created with names by default
+ # matching that of the table name.
+ User = Base.classes.user
+ Address = Base.classes.address
+
+ session = Session(engine)
+
+ # rudimentary relationships are produced
+ session.add(Address(email_address="foo@bar.com", user=User(name="foo")))
+ session.commit()
+
+ # collection-based relationships are by default named
+ # "<classname>_collection"
+    u1 = session.query(User).first()
+    print (u1.address_collection)
+
+Above, calling :meth:`.AutomapBase.prepare` while passing along the
+:paramref:`.AutomapBase.prepare.autoload_with` parameter indicates that the
+:meth:`_schema.MetaData.reflect`
+method will be called on this declarative base
+class's :class:`_schema.MetaData` collection; then, each **viable**
+:class:`_schema.Table` within the :class:`_schema.MetaData`
+will get a new mapped class
+generated automatically. The :class:`_schema.ForeignKeyConstraint`
+objects which
+link the various tables together will be used to produce new, bidirectional
+:func:`_orm.relationship` objects between classes.
+The classes and relationships
+follow along a default naming scheme that we can customize. At this point,
+our basic mapping consisting of related ``User`` and ``Address`` classes is
+ready to use in the traditional way.
+
+.. note:: By **viable**, we mean that for a table to be mapped, it must
+ specify a primary key. Additionally, if the table is detected as being
+ a pure association table between two other tables, it will not be directly
+ mapped and will instead be configured as a many-to-many table between
+ the mappings for the two referring tables.
+
+Generating Mappings from an Existing MetaData
+=============================================
+
+We can pass a pre-declared :class:`_schema.MetaData` object to
+:func:`.automap_base`.
+This object can be constructed in any way, including programmatically, from
+a serialized file, or from itself being reflected using
+:meth:`_schema.MetaData.reflect`.
+Below we illustrate a combination of reflection and
+explicit table declaration::
+
+    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, ForeignKey
+ from sqlalchemy.ext.automap import automap_base
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ # produce our own MetaData object
+ metadata = MetaData()
+
+ # we can reflect it ourselves from a database, using options
+ # such as 'only' to limit what tables we look at...
+ metadata.reflect(engine, only=['user', 'address'])
+
+ # ... or just define our own Table objects with it (or combine both)
+ Table('user_order', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('user_id', ForeignKey('user.id'))
+ )
+
+ # we can then produce a set of mappings from this MetaData.
+ Base = automap_base(metadata=metadata)
+
+ # calling prepare() just sets up mapped classes and relationships.
+ Base.prepare()
+
+ # mapped classes are ready
+ User, Address, Order = Base.classes.user, Base.classes.address,\
+ Base.classes.user_order
+
+Specifying Classes Explicitly
+=============================
+
+.. tip:: If explicit classes are expected to be prominent in an application,
+ consider using :class:`.DeferredReflection` instead.
+
+The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined
+explicitly, in a way similar to that of the :class:`.DeferredReflection` class.
+Classes that extend from :class:`.AutomapBase` act like regular declarative
+classes, but are not immediately mapped after their construction, and are
+instead mapped when we call :meth:`.AutomapBase.prepare`. The
+:meth:`.AutomapBase.prepare` method will make use of the classes we've
+established based on the table name we use. If our schema contains tables
+``user`` and ``address``, we can define one or both of the classes to be used::
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy import create_engine
+
+ # automap base
+ Base = automap_base()
+
+ # pre-declare User for the 'user' table
+ class User(Base):
+ __tablename__ = 'user'
+
+ # override schema elements like Columns
+ user_name = Column('name', String)
+
+ # override relationships too, if desired.
+ # we must use the same name that automap would use for the
+ # relationship, and also must refer to the class name that automap will
+ # generate for "address"
+ address_collection = relationship("address", collection_class=set)
+
+ # reflect
+ engine = create_engine("sqlite:///mydatabase.db")
+ Base.prepare(autoload_with=engine)
+
+ # we still have Address generated from the tablename "address",
+ # but User is the same as Base.classes.User now
+
+ Address = Base.classes.address
+
+ u1 = session.query(User).first()
+ print (u1.address_collection)
+
+ # the backref is still there:
+ a1 = session.query(Address).first()
+ print (a1.user)
+
+Above, one of the more intricate details is that we illustrated overriding
+one of the :func:`_orm.relationship` objects that automap would have created.
+To do this, we needed to make sure the names match up with what automap
+would normally generate, in that the relationship name would be
+``User.address_collection`` and the name of the class referred to, from
+automap's perspective, is called ``address``, even though we are referring to
+it as ``Address`` within our usage of this class.
+
+Overriding Naming Schemes
+=========================
+
+:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and
+relationship names based on a schema, which means it has decision points in how
+these names are determined. These three decision points are provided using
+functions which can be passed to the :meth:`.AutomapBase.prepare` method, and
+are known as :func:`.classname_for_table`,
+:func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship`. Any or all of these
+functions are provided as in the example below, where we use a "camel case"
+scheme for class names and a "pluralizer" for collection names using the
+`Inflect <https://pypi.org/project/inflect>`_ package::
+
+ import re
+ import inflect
+
+ def camelize_classname(base, tablename, table):
+ "Produce a 'camelized' class name, e.g. "
+ "'words_and_underscores' -> 'WordsAndUnderscores'"
+
+ return str(tablename[0].upper() + \
+ re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:]))
+
+ _pluralizer = inflect.engine()
+ def pluralize_collection(base, local_cls, referred_cls, constraint):
+ "Produce an 'uncamelized', 'pluralized' class name, e.g. "
+ "'SomeTerm' -> 'some_terms'"
+
+ referred_name = referred_cls.__name__
+ uncamelized = re.sub(r'[A-Z]',
+ lambda m: "_%s" % m.group(0).lower(),
+ referred_name)[1:]
+ pluralized = _pluralizer.plural(uncamelized)
+ return pluralized
+
+        Unlike the ``remove`` method of :class:`.scoped_session`, this method
+        awaits the ``close`` method of the underlying :class:`.AsyncSession`.
+ Base = automap_base()
+
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ Base.prepare(autoload_with=engine,
+ classname_for_table=camelize_classname,
+ name_for_collection_relationship=pluralize_collection
+ )
+
+From the above mapping, we would now have classes ``User`` and ``Address``,
+where the collection from ``User`` to ``Address`` is called
+``User.addresses``::
+
+ User, Address = Base.classes.User, Base.classes.Address
+
+ u1 = User(addresses=[Address(email="foo@bar.com")])
+
+Relationship Detection
+======================
+
+The vast majority of what automap accomplishes is the generation of
+:func:`_orm.relationship` structures based on foreign keys. The mechanism
+by which this works for many-to-one and one-to-many relationships is as
+follows:
+
+1. A given :class:`_schema.Table`, known to be mapped to a particular class,
+ is examined for :class:`_schema.ForeignKeyConstraint` objects.
+
+2. From each :class:`_schema.ForeignKeyConstraint`, the remote
+ :class:`_schema.Table`
+ object present is matched up to the class to which it is to be mapped,
+ if any, else it is skipped.
+
+3. As the :class:`_schema.ForeignKeyConstraint`
+ we are examining corresponds to a
+ reference from the immediate mapped class, the relationship will be set up
+ as a many-to-one referring to the referred class; a corresponding
+ one-to-many backref will be created on the referred class referring
+ to this class.
+
+4. If any of the columns that are part of the
+ :class:`_schema.ForeignKeyConstraint`
+ are not nullable (e.g. ``nullable=False``), a
+ :paramref:`_orm.relationship.cascade` keyword argument
+ of ``all, delete-orphan`` will be added to the keyword arguments to
+ be passed to the relationship or backref. If the
+ :class:`_schema.ForeignKeyConstraint` reports that
+ :paramref:`_schema.ForeignKeyConstraint.ondelete`
+ is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable
+ set of columns, the option :paramref:`_orm.relationship.passive_deletes`
+ flag is set to ``True`` in the set of relationship keyword arguments.
+ Note that not all backends support reflection of ON DELETE.
+
+ .. versionadded:: 1.0.0 - automap will detect non-nullable foreign key
+ constraints when producing a one-to-many relationship and establish
+ a default cascade of ``all, delete-orphan`` if so; additionally,
+ if the constraint specifies
+ :paramref:`_schema.ForeignKeyConstraint.ondelete`
+ of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns,
+ the ``passive_deletes=True`` option is also added.
+
+5. The names of the relationships are determined using the
+ :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and
+ :paramref:`.AutomapBase.prepare.name_for_collection_relationship`
+ callable functions. It is important to note that the default relationship
+   naming derives the name from the **actual class name**. If you've
+ given a particular class an explicit name by declaring it, or specified an
+ alternate class naming scheme, that's the name from which the relationship
+ name will be derived.
+
+6. The classes are inspected for an existing mapped property matching these
+ names. If one is detected on one side, but none on the other side,
+ :class:`.AutomapBase` attempts to create a relationship on the missing side,
+ then uses the :paramref:`_orm.relationship.back_populates`
+ parameter in order to
+ point the new relationship to the other side.
+
+7. In the usual case where no relationship is on either side,
+ :meth:`.AutomapBase.prepare` produces a :func:`_orm.relationship` on the
+ "many-to-one" side and matches it to the other using the
+ :paramref:`_orm.relationship.backref` parameter.
+
+8. Production of the :func:`_orm.relationship` and optionally the
+ :func:`.backref`
+ is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship`
+ function, which can be supplied by the end-user in order to augment
+ the arguments passed to :func:`_orm.relationship` or :func:`.backref` or to
+ make use of custom implementations of these functions.
+
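+As a rough illustration of the outcome of the steps above (a sketch only, not
+the literal generated code), given ``user`` and ``address`` tables where
+``address.user_id`` refers to ``user.id``, the prepared classes behave
+approximately as if they had been declared as::
+
+    class user(Base):
+        __tablename__ = 'user'
+
+        # one-to-many collection, named after the referred class
+        address_collection = relationship("address", back_populates="user")
+
+    class address(Base):
+        __tablename__ = 'address'
+
+        # many-to-one scalar, named after the referred class
+        user = relationship("user", back_populates="address_collection")
+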
+Custom Relationship Arguments
+-----------------------------
+
+The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used
+to add parameters to relationships. For most cases, we can make use of the
+existing :func:`.automap.generate_relationship` function to return
+the object, after augmenting the given keyword dictionary with our own
+arguments.
+
+Below is an illustration of how to send
+:paramref:`_orm.relationship.cascade` and
+:paramref:`_orm.relationship.passive_deletes`
+options along to all one-to-many relationships::
+
+ from sqlalchemy.ext.automap import generate_relationship
+
+ def _gen_relationship(base, direction, return_fn,
+ attrname, local_cls, referred_cls, **kw):
+ if direction is interfaces.ONETOMANY:
+ kw['cascade'] = 'all, delete-orphan'
+ kw['passive_deletes'] = True
+ # make use of the built-in function to actually return
+ # the result.
+ return generate_relationship(base, direction, return_fn,
+ attrname, local_cls, referred_cls, **kw)
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy import create_engine
+
+ # automap base
+ Base = automap_base()
+
+ engine = create_engine("sqlite:///mydatabase.db")
+ Base.prepare(autoload_with=engine,
+ generate_relationship=_gen_relationship)
+
+Many-to-Many relationships
+--------------------------
+
+:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g.
+those which contain a ``secondary`` argument. The process for producing these
+is as follows:
+
+1. A given :class:`_schema.Table` is examined for
+ :class:`_schema.ForeignKeyConstraint`
+ objects, before any mapped class has been assigned to it.
+
+2. If the table contains two and exactly two
+ :class:`_schema.ForeignKeyConstraint`
+ objects, and all columns within this table are members of these two
+ :class:`_schema.ForeignKeyConstraint` objects, the table is assumed to be a
+ "secondary" table, and will **not be mapped directly**.
+
+3. The two (or one, for self-referential) external tables to which the
+ :class:`_schema.Table`
+   refers are matched to the classes to which they will be
+ mapped, if any.
+
+4. If mapped classes for both sides are located, a many-to-many bi-directional
+ :func:`_orm.relationship` / :func:`.backref`
+ pair is created between the two
+ classes.
+
+5. The override logic for many-to-many works the same as that of one-to-many/
+ many-to-one; the :func:`.generate_relationship` function is called upon
+ to generate the structures and existing attributes will be maintained.
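+
+As a brief illustration, consider a hypothetical schema in which an
+association table links ``student`` and ``course`` tables:
+
+.. sourcecode:: sql
+
+    CREATE TABLE student (
+        id INTEGER PRIMARY KEY
+    );
+
+    CREATE TABLE course (
+        id INTEGER PRIMARY KEY
+    );
+
+    CREATE TABLE student_course (
+        student_id INTEGER,
+        course_id INTEGER,
+        PRIMARY KEY (student_id, course_id),
+        FOREIGN KEY(student_id) REFERENCES student(id),
+        FOREIGN KEY(course_id) REFERENCES course(id)
+    );
+
+Since every column of ``student_course`` is a member of one of its two
+foreign key constraints, automap would be expected to treat it as a
+"secondary" table: classes would be generated for ``student`` and ``course``
+only, linked by ``student.course_collection`` and
+``course.student_collection`` collections, while ``student_course`` itself
+would not be mapped directly.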
+
+Relationships with Inheritance
+------------------------------
+
+:mod:`.sqlalchemy.ext.automap` will not generate any relationships between
+two classes that are in an inheritance relationship. That is, with two
+classes given as follows::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee', 'polymorphic_on': type
+ }
+
+ class Engineer(Employee):
+ __tablename__ = 'engineer'
+ id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+ __mapper_args__ = {
+ 'polymorphic_identity':'engineer',
+ }
+
+The foreign key from ``Engineer`` to ``Employee`` is used not for a
+relationship, but to establish joined inheritance between the two classes.
+
+Note that this means automap will not generate *any* relationships
+for foreign keys that link from a subclass to a superclass. If a mapping
+has actual relationships from subclass to superclass as well, those
+need to be explicit. Below, as we have two separate foreign keys
+from ``Engineer`` to ``Employee``, we need to set up both the relationship
+we want as well as the ``inherit_condition``, as these are not things
+SQLAlchemy can guess::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee', 'polymorphic_on':type
+ }
+
+ class Engineer(Employee):
+ __tablename__ = 'engineer'
+ id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+ favorite_employee_id = Column(Integer, ForeignKey('employee.id'))
+
+ favorite_employee = relationship(Employee,
+ foreign_keys=favorite_employee_id)
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'engineer',
+ 'inherit_condition': id == Employee.id
+ }
+
+Handling Simple Naming Conflicts
+--------------------------------
+
+In the case of naming conflicts during mapping, override any of
+:func:`.classname_for_table`, :func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship` as needed. For example, if
+automap is attempting to name a many-to-one relationship the same as an
+existing column, an alternate convention can be conditionally selected. Given
+a schema:
+
+.. sourcecode:: sql
+
+ CREATE TABLE table_a (
+ id INTEGER PRIMARY KEY
+ );
+
+ CREATE TABLE table_b (
+ id INTEGER PRIMARY KEY,
+ table_a INTEGER,
+ FOREIGN KEY(table_a) REFERENCES table_a(id)
+ );
+
+The above schema will first automap the ``table_a`` table as a class named
+``table_a``; it will then automap a relationship onto the class for ``table_b``
+with the same name as this related class, i.e. ``table_a``. This
+relationship name conflicts with the mapped column ``table_b.table_a``,
+and will emit an error on mapping.
+
+We can resolve this conflict by using an underscore as follows::
+
+    import warnings
+
+    def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+ name = referred_cls.__name__.lower()
+ local_table = local_cls.__table__
+ if name in local_table.columns:
+ newname = name + "_"
+ warnings.warn(
+ "Already detected name %s present. using %s" %
+ (name, newname))
+ return newname
+ return name
+
+
+ Base.prepare(autoload_with=engine,
+ name_for_scalar_relationship=name_for_scalar_relationship)
+
+Alternatively, we can change the name on the column side. The columns
+that are mapped can be modified using the technique described at
+:ref:`mapper_column_distinct_names`, by assigning the column explicitly
+to a new name::
+
+ Base = automap_base()
+
+ class TableB(Base):
+ __tablename__ = 'table_b'
+ _table_a = Column('table_a', ForeignKey('table_a.id'))
+
+ Base.prepare(autoload_with=engine)
+
+
+Using Automap with Explicit Declarations
+========================================
+
+As noted previously, automap has no dependency on reflection, and can make
+use of any collection of :class:`_schema.Table` objects within a
+:class:`_schema.MetaData`
+collection. From this, it follows that automap can also be used to
+generate missing relationships given an otherwise complete model that fully
+defines table metadata::
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy import Column, Integer, String, ForeignKey
+
+ Base = automap_base()
+
+ class User(Base):
+ __tablename__ = 'user'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ class Address(Base):
+ __tablename__ = 'address'
+
+ id = Column(Integer, primary_key=True)
+ email = Column(String)
+ user_id = Column(ForeignKey('user.id'))
+
+ # produce relationships
+ Base.prepare()
+
+ # mapping is complete, with "address_collection" and
+ # "user" relationships
+ a1 = Address(email='u1')
+ a2 = Address(email='u2')
+ u1 = User(address_collection=[a1, a2])
+ assert a1.user is u1
+
+Above, given mostly complete ``User`` and ``Address`` mappings, the
+:class:`_schema.ForeignKey` which we defined on ``Address.user_id`` allowed a
+bidirectional relationship pair ``Address.user`` and
+``User.address_collection`` to be generated on the mapped classes.
+
+Note that when subclassing :class:`.AutomapBase`,
+the :meth:`.AutomapBase.prepare` method is required; if not called, the classes
+we've declared are in an un-mapped state.
+
+
+.. _automap_intercepting_columns:
+
+Intercepting Column Definitions
+===============================
+
+The :class:`_schema.MetaData` and :class:`_schema.Table` objects support an
+event hook :meth:`_events.DDLEvents.column_reflect` that may be used to intercept
+the information reflected about a database column before the :class:`_schema.Column`
+object is constructed. For example, if we wanted to map columns using a
+naming convention such as ``"attr_<columnname>"``, the event could
+be applied as::
+
+ @event.listens_for(Base.metadata, "column_reflect")
+ def column_reflect(inspector, table, column_info):
+ # set column.key = "attr_<lower_case_name>"
+ column_info['key'] = "attr_%s" % column_info['name'].lower()
+
+ # run reflection
+ Base.prepare(autoload_with=engine)
+
+.. versionadded:: 1.4.0b2 the :meth:`_events.DDLEvents.column_reflect` event
+ may be applied to a :class:`_schema.MetaData` object.
+
+.. seealso::
+
+ :meth:`_events.DDLEvents.column_reflect`
+
+ :ref:`mapper_automated_reflection_schemes` - in the ORM mapping documentation
+
+
+""" # noqa
+from .. import util
+from ..orm import backref
+from ..orm import declarative_base as _declarative_base
+from ..orm import exc as orm_exc
+from ..orm import interfaces
+from ..orm import relationship
+from ..orm.decl_base import _DeferredMapperConfig
+from ..orm.mapper import _CONFIGURE_MUTEX
+from ..schema import ForeignKeyConstraint
+from ..sql import and_
+
+
+def classname_for_table(base, tablename, table):
+ """Return the class name that should be used, given the name
+ of a table.
+
+ The default implementation is::
+
+ return str(tablename)
+
+ Alternate implementations can be specified using the
+ :paramref:`.AutomapBase.prepare.classname_for_table`
+ parameter.
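+
+    For example, a hypothetical implementation that produces "CamelCase"
+    class names from underscore-separated table names might look like::
+
+        def camelized_classname(base, tablename, table):
+            return "".join(
+                word.capitalize() for word in tablename.split("_")
+            )
+
+        Base.prepare(
+            autoload_with=engine,
+            classname_for_table=camelized_classname,
+        )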
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param tablename: string name of the :class:`_schema.Table`.
+
+ :param table: the :class:`_schema.Table` object itself.
+
+ :return: a string class name.
+
+ .. note::
+
+ In Python 2, the string used for the class name **must** be a
+ non-Unicode object, e.g. a ``str()`` object. The ``.name`` attribute
+ of :class:`_schema.Table` is typically a Python unicode subclass,
+ so the
+ ``str()`` function should be applied to this name, after accounting for
+ any non-ASCII characters.
+
+ """
+ return str(tablename)
+
+
+def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+ """Return the attribute name that should be used to refer from one
+ class to another, for a scalar object reference.
+
+ The default implementation is::
+
+ return referred_cls.__name__.lower()
+
+ Alternate implementations can be specified using the
+ :paramref:`.AutomapBase.prepare.name_for_scalar_relationship`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param local_cls: the class to be mapped on the local side.
+
+ :param referred_cls: the class to be mapped on the referring side.
+
+ :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being
+ inspected to produce this relationship.
+
+ """
+ return referred_cls.__name__.lower()
+
+
+def name_for_collection_relationship(
+ base, local_cls, referred_cls, constraint
+):
+ """Return the attribute name that should be used to refer from one
+ class to another, for a collection reference.
+
+ The default implementation is::
+
+ return referred_cls.__name__.lower() + "_collection"
+
+ Alternate implementations
+ can be specified using the
+ :paramref:`.AutomapBase.prepare.name_for_collection_relationship`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param local_cls: the class to be mapped on the local side.
+
+ :param referred_cls: the class to be mapped on the referring side.
+
+ :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being
+ inspected to produce this relationship.
+
+ """
+ return referred_cls.__name__.lower() + "_collection"
+
+
+def generate_relationship(
+ base, direction, return_fn, attrname, local_cls, referred_cls, **kw
+):
+ r"""Generate a :func:`_orm.relationship` or :func:`.backref`
+ on behalf of two
+ mapped classes.
+
+ An alternate implementation of this function can be specified using the
+ :paramref:`.AutomapBase.prepare.generate_relationship` parameter.
+
+ The default implementation of this function is as follows::
+
+ if return_fn is backref:
+ return return_fn(attrname, **kw)
+ elif return_fn is relationship:
+ return return_fn(referred_cls, **kw)
+ else:
+ raise TypeError("Unknown relationship function: %s" % return_fn)
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+    :param direction: indicates the "direction" of the relationship; this will
+ be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOMANY`.
+
+ :param return_fn: the function that is used by default to create the
+ relationship. This will be either :func:`_orm.relationship` or
+ :func:`.backref`. The :func:`.backref` function's result will be used to
+ produce a new :func:`_orm.relationship` in a second step,
+ so it is critical
+ that user-defined implementations correctly differentiate between the two
+ functions, if a custom relationship function is being used.
+
+ :param attrname: the attribute name to which this relationship is being
+ assigned. If the value of :paramref:`.generate_relationship.return_fn` is
+ the :func:`.backref` function, then this name is the name that is being
+ assigned to the backref.
+
+ :param local_cls: the "local" class to which this relationship or backref
+ will be locally present.
+
+ :param referred_cls: the "referred" class to which the relationship or
+ backref refers to.
+
+ :param \**kw: all additional keyword arguments are passed along to the
+ function.
+
+ :return: a :func:`_orm.relationship` or :func:`.backref` construct,
+ as dictated
+ by the :paramref:`.generate_relationship.return_fn` parameter.
+
+ """
+ if return_fn is backref:
+ return return_fn(attrname, **kw)
+ elif return_fn is relationship:
+ return return_fn(referred_cls, **kw)
+ else:
+ raise TypeError("Unknown relationship function: %s" % return_fn)
+
+
+class AutomapBase(object):
+ """Base class for an "automap" schema.
+
+ The :class:`.AutomapBase` class can be compared to the "declarative base"
+ class that is produced by the :func:`.declarative.declarative_base`
+ function. In practice, the :class:`.AutomapBase` class is always used
+ as a mixin along with an actual declarative base.
+
+ A new subclassable :class:`.AutomapBase` is typically instantiated
+ using the :func:`.automap_base` function.
+
+ .. seealso::
+
+ :ref:`automap_toplevel`
+
+ """
+
+ __abstract__ = True
+
+ classes = None
+ """An instance of :class:`.util.Properties` containing classes.
+
+ This object behaves much like the ``.c`` collection on a table. Classes
+ are present under the name they were given, e.g.::
+
+ Base = automap_base()
+ Base.prepare(autoload_with=some_engine)
+
+ User, Address = Base.classes.User, Base.classes.Address
+
+ """
+
+ @classmethod
+ @util.deprecated_params(
+ engine=(
+ "2.0",
+ "The :paramref:`_automap.AutomapBase.prepare.engine` parameter "
+ "is deprecated and will be removed in a future release. "
+ "Please use the "
+ ":paramref:`_automap.AutomapBase.prepare.autoload_with` "
+ "parameter.",
+ ),
+ reflect=(
+ "2.0",
+ "The :paramref:`_automap.AutomapBase.prepare.reflect` "
+ "parameter is deprecated and will be removed in a future "
+ "release. Reflection is enabled when "
+ ":paramref:`_automap.AutomapBase.prepare.autoload_with` "
+ "is passed.",
+ ),
+ )
+ def prepare(
+ cls,
+ autoload_with=None,
+ engine=None,
+ reflect=False,
+ schema=None,
+ classname_for_table=None,
+ collection_class=None,
+ name_for_scalar_relationship=None,
+ name_for_collection_relationship=None,
+ generate_relationship=None,
+ reflection_options=util.EMPTY_DICT,
+ ):
+ """Extract mapped classes and relationships from the
+ :class:`_schema.MetaData` and
+ perform mappings.
+
+ :param engine: an :class:`_engine.Engine` or
+ :class:`_engine.Connection` with which
+ to perform schema reflection, if specified.
+ If the :paramref:`.AutomapBase.prepare.reflect` argument is False,
+ this object is not used.
+
+ :param reflect: if True, the :meth:`_schema.MetaData.reflect`
+ method is called
+ on the :class:`_schema.MetaData` associated with this
+ :class:`.AutomapBase`.
+ The :class:`_engine.Engine` passed via
+ :paramref:`.AutomapBase.prepare.engine` will be used to perform the
+       reflection if present; otherwise, the :class:`_schema.MetaData`
+       should already be
+       bound to some engine or the operation will fail.
+
+ :param classname_for_table: callable function which will be used to
+ produce new class names, given a table name. Defaults to
+ :func:`.classname_for_table`.
+
+ :param name_for_scalar_relationship: callable function which will be
+ used to produce relationship names for scalar relationships. Defaults
+ to :func:`.name_for_scalar_relationship`.
+
+ :param name_for_collection_relationship: callable function which will
+ be used to produce relationship names for collection-oriented
+ relationships. Defaults to :func:`.name_for_collection_relationship`.
+
+ :param generate_relationship: callable function which will be used to
+ actually generate :func:`_orm.relationship` and :func:`.backref`
+ constructs. Defaults to :func:`.generate_relationship`.
+
+ :param collection_class: the Python collection class that will be used
+ when a new :func:`_orm.relationship`
+ object is created that represents a
+ collection. Defaults to ``list``.
+
+ :param schema: When present in conjunction with the
+ :paramref:`.AutomapBase.prepare.reflect` flag, is passed to
+ :meth:`_schema.MetaData.reflect`
+ to indicate the primary schema where tables
+ should be reflected from. When omitted, the default schema in use
+ by the database connection is used.
+
+ .. versionadded:: 1.1
+
+ :param reflection_options: When present, this dictionary of options
+ will be passed to :meth:`_schema.MetaData.reflect`
+ to supply general reflection-specific options like ``only`` and/or
+ dialect-specific options like ``oracle_resolve_synonyms``.
+
+ .. versionadded:: 1.4
+
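+        E.g., a minimal sketch reflecting a specific schema and using
+        ``set``-based collections (the parameter values shown are
+        illustrative only)::
+
+            Base.prepare(
+                autoload_with=engine,
+                schema="sales",
+                collection_class=set,
+            )
+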
+ """
+ glbls = globals()
+ if classname_for_table is None:
+ classname_for_table = glbls["classname_for_table"]
+ if name_for_scalar_relationship is None:
+ name_for_scalar_relationship = glbls[
+ "name_for_scalar_relationship"
+ ]
+ if name_for_collection_relationship is None:
+ name_for_collection_relationship = glbls[
+ "name_for_collection_relationship"
+ ]
+ if generate_relationship is None:
+ generate_relationship = glbls["generate_relationship"]
+ if collection_class is None:
+ collection_class = list
+
+ if autoload_with:
+ reflect = True
+
+ if engine:
+ autoload_with = engine
+
+ if reflect:
+ opts = dict(
+ schema=schema,
+ extend_existing=True,
+ autoload_replace=False,
+ )
+ if reflection_options:
+ opts.update(reflection_options)
+ cls.metadata.reflect(autoload_with, **opts)
+
+ with _CONFIGURE_MUTEX:
+ table_to_map_config = dict(
+ (m.local_table, m)
+ for m in _DeferredMapperConfig.classes_for_base(
+ cls, sort=False
+ )
+ )
+
+ many_to_many = []
+
+ for table in cls.metadata.tables.values():
+ lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table)
+ if lcl_m2m is not None:
+ many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table))
+ elif not table.primary_key:
+ continue
+ elif table not in table_to_map_config:
+ mapped_cls = type(
+ classname_for_table(cls, table.name, table),
+ (cls,),
+ {"__table__": table},
+ )
+ map_config = _DeferredMapperConfig.config_for_cls(
+ mapped_cls
+ )
+ cls.classes[map_config.cls.__name__] = mapped_cls
+ table_to_map_config[table] = map_config
+
+ for map_config in table_to_map_config.values():
+ _relationships_for_fks(
+ cls,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
+
+ for lcl_m2m, rem_m2m, m2m_const, table in many_to_many:
+ _m2m_relationship(
+ cls,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
+
+ for map_config in _DeferredMapperConfig.classes_for_base(cls):
+ map_config.map()
+
+ _sa_decl_prepare = True
+ """Indicate that the mapping of classes should be deferred.
+
+ The presence of this attribute name indicates to declarative
+ that the call to mapper() should not occur immediately; instead,
+    information about the table and attributes to be mapped is gathered
+    into an internal structure called _DeferredMapperConfig. These
+ objects can be collected later using classes_for_base(), additional
+ mapping decisions can be made, and then the map() method will actually
+ apply the mapping.
+
+    The only real reason this deferral is needed is to support primary key
+    columns that aren't reflected yet when the class is declared; everything
+    else can theoretically be added to the mapper later. However,
+    _DeferredMapperConfig is a nice interface in any case, existing at the
+    not-usually-exposed point at which declarative has the class and the
+    Table but hasn't yet called mapper().
+
+ """
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of AutomapBase. "
+ "Mappings are not produced until the .prepare() "
+ "method is called on the class hierarchy."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+
+def automap_base(declarative_base=None, **kw):
+ r"""Produce a declarative automap base.
+
+ This function produces a new base class that is a product of the
+ :class:`.AutomapBase` class as well a declarative base produced by
+ :func:`.declarative.declarative_base`.
+
+ All parameters other than ``declarative_base`` are keyword arguments
+ that are passed directly to the :func:`.declarative.declarative_base`
+ function.
+
+ :param declarative_base: an existing class produced by
+ :func:`.declarative.declarative_base`. When this is passed, the function
+ no longer invokes :func:`.declarative.declarative_base` itself, and all
+ other keyword arguments are ignored.
+
+ :param \**kw: keyword arguments are passed along to
+ :func:`.declarative.declarative_base`.
+
+ """
+ if declarative_base is None:
+ Base = _declarative_base(**kw)
+ else:
+ Base = declarative_base
+
+ return type(
+ Base.__name__,
+ (AutomapBase, Base),
+ {"__abstract__": True, "classes": util.Properties({})},
+ )
+
+
+def _is_many_to_many(automap_base, table):
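+    # Detect a many-to-many "secondary" table: if the table has exactly two
+    # ForeignKeyConstraints and every one of its columns is a member of one
+    # of them, return (left referred table, right referred table, the two
+    # constraints); otherwise return (None, None, None).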
+ fk_constraints = [
+ const
+ for const in table.constraints
+ if isinstance(const, ForeignKeyConstraint)
+ ]
+ if len(fk_constraints) != 2:
+ return None, None, None
+
+ cols = sum(
+ [
+ [fk.parent for fk in fk_constraint.elements]
+ for fk_constraint in fk_constraints
+ ],
+ [],
+ )
+
+ if set(cols) != set(table.c):
+ return None, None, None
+
+ return (
+ fk_constraints[0].elements[0].column.table,
+ fk_constraints[1].elements[0].column.table,
+ fk_constraints,
+ )
+
+
+def _relationships_for_fks(
+ automap_base,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
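+    # For each ForeignKeyConstraint on the mapped table, generate a
+    # many-to-one relationship() toward the referred class along with its
+    # one-to-many reverse (as a backref or a back_populates pair), applying
+    # the cascade / passive_deletes rules described in the module docstring.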
+ local_table = map_config.local_table
+ local_cls = map_config.cls # derived from a weakref, may be None
+
+ if local_table is None or local_cls is None:
+ return
+ for constraint in local_table.constraints:
+ if isinstance(constraint, ForeignKeyConstraint):
+ fks = constraint.elements
+ referred_table = fks[0].column.table
+ referred_cfg = table_to_map_config.get(referred_table, None)
+ if referred_cfg is None:
+ continue
+ referred_cls = referred_cfg.cls
+
+ if local_cls is not referred_cls and issubclass(
+ local_cls, referred_cls
+ ):
+ continue
+
+ relationship_name = name_for_scalar_relationship(
+ automap_base, local_cls, referred_cls, constraint
+ )
+ backref_name = name_for_collection_relationship(
+ automap_base, referred_cls, local_cls, constraint
+ )
+
+ o2m_kws = {}
+ nullable = False not in {fk.parent.nullable for fk in fks}
+ if not nullable:
+ o2m_kws["cascade"] = "all, delete-orphan"
+
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "cascade"
+ ):
+ o2m_kws["passive_deletes"] = True
+ else:
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "set null"
+ ):
+ o2m_kws["passive_deletes"] = True
+
+ create_backref = backref_name not in referred_cfg.properties
+
+ if relationship_name not in map_config.properties:
+ if create_backref:
+ backref_obj = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
+ **o2m_kws
+ )
+ else:
+ backref_obj = None
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOONE,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ backref=backref_obj,
+ remote_side=[fk.column for fk in constraint.elements],
+ )
+ if rel is not None:
+ map_config.properties[relationship_name] = rel
+ if not create_backref:
+ referred_cfg.properties[
+ backref_name
+ ].back_populates = relationship_name
+ elif create_backref:
+ rel = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ **o2m_kws
+ )
+ if rel is not None:
+ referred_cfg.properties[backref_name] = rel
+ map_config.properties[
+ relationship_name
+ ].back_populates = backref_name
+
+
+def _m2m_relationship(
+ automap_base,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
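+    # Generate the bidirectional many-to-many relationship() / backref()
+    # pair between the classes mapped to lcl_m2m and rem_m2m, using the
+    # given association table as "secondary"; attributes already present on
+    # either class are left in place.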
+
+ map_config = table_to_map_config.get(lcl_m2m, None)
+ referred_cfg = table_to_map_config.get(rem_m2m, None)
+ if map_config is None or referred_cfg is None:
+ return
+
+ local_cls = map_config.cls
+ referred_cls = referred_cfg.cls
+
+ relationship_name = name_for_collection_relationship(
+ automap_base, local_cls, referred_cls, m2m_const[0]
+ )
+ backref_name = name_for_collection_relationship(
+ automap_base, referred_cls, local_cls, m2m_const[1]
+ )
+
+ create_backref = backref_name not in referred_cfg.properties
+
+ if table in table_to_map_config:
+ overlaps = "__*"
+ else:
+ overlaps = None
+
+ if relationship_name not in map_config.properties:
+ if create_backref:
+ backref_obj = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
+ overlaps=overlaps,
+ )
+ else:
+ backref_obj = None
+
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ overlaps=overlaps,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ backref=backref_obj,
+ collection_class=collection_class,
+ )
+ if rel is not None:
+ map_config.properties[relationship_name] = rel
+
+ if not create_backref:
+ referred_cfg.properties[
+ backref_name
+ ].back_populates = relationship_name
+ elif create_backref:
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ overlaps=overlaps,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ )
+ if rel is not None:
+ referred_cfg.properties[backref_name] = rel
+ map_config.properties[
+ relationship_name
+ ].back_populates = backref_name
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
new file mode 100644
index 0000000..109e0c0
--- /dev/null
+++ b/lib/sqlalchemy/ext/baked.py
@@ -0,0 +1,648 @@
+# sqlalchemy/ext/baked.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Baked query extension.
+
+Provides a creational pattern for the :class:`.query.Query` object which
+allows the fully constructed object, Core select statement, and string
+compiled result to be fully cached.
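+
+A minimal usage sketch, assuming a mapped ``User`` class and a
+:class:`.Session` named ``session``::
+
+    from sqlalchemy import bindparam
+    from sqlalchemy.ext import baked
+
+    # a bakery holds the cache of compiled queries
+    bakery = baked.bakery()
+
+    baked_query = bakery(lambda s: s.query(User))
+    baked_query += lambda q: q.filter(User.name == bindparam("username"))
+
+    results = baked_query(session).params(username="ed").all()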
+
+
+"""
+
+import logging
+
+from .. import exc as sa_exc
+from .. import util
+from ..orm import exc as orm_exc
+from ..orm import strategy_options
+from ..orm.query import Query
+from ..orm.session import Session
+from ..sql import func
+from ..sql import literal_column
+from ..sql import util as sql_util
+from ..util import collections_abc
+
+
+log = logging.getLogger(__name__)
+
+
+class Bakery(object):
+ """Callable which returns a :class:`.BakedQuery`.
+
+ This object is returned by the class method
+ :meth:`.BakedQuery.bakery`. It exists as an object
+ so that the "cache" can be easily inspected.
+
+ .. versionadded:: 1.2
+
+
+ """
+
+ __slots__ = "cls", "cache"
+
+ def __init__(self, cls_, cache):
+ self.cls = cls_
+ self.cache = cache
+
+ def __call__(self, initial_fn, *args):
+ return self.cls(self.cache, initial_fn, args)
+
+
+class BakedQuery(object):
+ """A builder object for :class:`.query.Query` objects."""
+
+ __slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
+
+ def __init__(self, bakery, initial_fn, args=()):
+ self._cache_key = ()
+ self._update_cache_key(initial_fn, args)
+ self.steps = [initial_fn]
+ self._spoiled = False
+ self._bakery = bakery
+
+ @classmethod
+ def bakery(cls, size=200, _size_alert=None):
+ """Construct a new bakery.
+
+ :return: an instance of :class:`.Bakery`
+
+ """
+
+ return Bakery(cls, util.LRUCache(size, size_alert=_size_alert))
+
+ def _clone(self):
+ b1 = BakedQuery.__new__(BakedQuery)
+ b1._cache_key = self._cache_key
+ b1.steps = list(self.steps)
+ b1._bakery = self._bakery
+ b1._spoiled = self._spoiled
+ return b1
+
+ def _update_cache_key(self, fn, args=()):
+ self._cache_key += (fn.__code__,) + args
+
+ def __iadd__(self, other):
+ if isinstance(other, tuple):
+ self.add_criteria(*other)
+ else:
+ self.add_criteria(other)
+ return self
+
+ def __add__(self, other):
+ if isinstance(other, tuple):
+ return self.with_criteria(*other)
+ else:
+ return self.with_criteria(other)
+
+ def add_criteria(self, fn, *args):
+ """Add a criteria function to this :class:`.BakedQuery`.
+
+ This is equivalent to using the ``+=`` operator to
+ modify a :class:`.BakedQuery` in-place.
+
+ """
+ self._update_cache_key(fn, args)
+ self.steps.append(fn)
+ return self
+
+ def with_criteria(self, fn, *args):
+ """Add a criteria function to a :class:`.BakedQuery` cloned from this
+ one.
+
+ This is equivalent to using the ``+`` operator to
+ produce a new :class:`.BakedQuery` with modifications.
+
+ """
+ return self._clone().add_criteria(fn, *args)
+
+ def for_session(self, session):
+ """Return a :class:`_baked.Result` object for this
+ :class:`.BakedQuery`.
+
+ This is equivalent to calling the :class:`.BakedQuery` as a
+ Python callable, e.g. ``result = my_baked_query(session)``.
+
+ """
+ return Result(self, session)
+
+ def __call__(self, session):
+ return self.for_session(session)
+
+ def spoil(self, full=False):
+ """Cancel any query caching that will occur on this BakedQuery object.
+
+        The BakedQuery can continue to be used normally; however, additional
+ creational functions will not be cached; they will be called
+ on every invocation.
+
+ This is to support the case where a particular step in constructing
+ a baked query disqualifies the query from being cacheable, such
+ as a variant that relies upon some uncacheable value.
+
+ :param full: if False, only functions added to this
+ :class:`.BakedQuery` object subsequent to the spoil step will be
+ non-cached; the state of the :class:`.BakedQuery` up until
+ this point will be pulled from the cache. If True, then the
+ entire :class:`_query.Query` object is built from scratch each
+ time, with all creational functions being called on each
+ invocation.
+
+ """
+ if not full and not self._spoiled:
+ _spoil_point = self._clone()
+ _spoil_point._cache_key += ("_query_only",)
+ self.steps = [_spoil_point._retrieve_baked_query]
+ self._spoiled = True
+ return self
+
+ def _effective_key(self, session):
+ """Return the key that actually goes into the cache dictionary for
+ this :class:`.BakedQuery`, taking into account the given
+ :class:`.Session`.
+
+ This basically means we also will include the session's query_class,
+ as the actual :class:`_query.Query` object is part of what's cached
+ and needs to match the type of :class:`_query.Query` that a later
+ session will want to use.
+
+ """
+ return self._cache_key + (session._query_cls,)
+
+ def _with_lazyload_options(self, options, effective_path, cache_path=None):
+ """Cloning version of _add_lazyload_options."""
+ q = self._clone()
+ q._add_lazyload_options(options, effective_path, cache_path=cache_path)
+ return q
+
+ def _add_lazyload_options(self, options, effective_path, cache_path=None):
+ """Used by per-state lazy loaders to add options to the
+ "lazy load" query from a parent query.
+
+ Creates a cache key based on given load path and query options;
+ if a repeatable cache key cannot be generated, the query is
+ "spoiled" so that it won't use caching.
+
+ """
+
+ key = ()
+
+ if not cache_path:
+ cache_path = effective_path
+
+ for opt in options:
+ if opt._is_legacy_option or opt._is_compile_state:
+ ck = opt._generate_cache_key()
+ if ck is None:
+ self.spoil(full=True)
+ else:
+ assert not ck[1], (
+ "loader options with variable bound parameters "
+ "not supported with baked queries. Please "
+ "use new-style select() statements for cached "
+ "ORM queries."
+ )
+ key += ck[0]
+
+ self.add_criteria(
+ lambda q: q._with_current_path(effective_path).options(*options),
+ cache_path.path,
+ key,
+ )
+
+ def _retrieve_baked_query(self, session):
+ query = self._bakery.get(self._effective_key(session), None)
+ if query is None:
+ query = self._as_query(session)
+ self._bakery[self._effective_key(session)] = query.with_session(
+ None
+ )
+ return query.with_session(session)
+
+ def _bake(self, session):
+ query = self._as_query(session)
+ query.session = None
+
+ # in 1.4, this is where before_compile() event is
+ # invoked
+ statement = query._statement_20()
+
+ # if the query is not safe to cache, we still do everything as though
+ # we did cache it, since the receiver of _bake() assumes subqueryload
+ # context was set up, etc.
+ #
+ # note also we want to cache the statement itself because this
+ # allows the statement itself to hold onto its cache key that is
+ # used by the Connection, which in itself is more expensive to
+ # generate than what BakedQuery was able to provide in 1.3 and prior
+
+ if statement._compile_options._bake_ok:
+ self._bakery[self._effective_key(session)] = (
+ query,
+ statement,
+ )
+
+ return query, statement
+
+ def to_query(self, query_or_session):
+ """Return the :class:`_query.Query` object for use as a subquery.
+
+ This method should be used within the lambda callable being used
+ to generate a step of an enclosing :class:`.BakedQuery`. The
+ parameter should normally be the :class:`_query.Query` object that
+ is passed to the lambda::
+
+ sub_bq = self.bakery(lambda s: s.query(User.name))
+ sub_bq += lambda q: q.filter(
+ User.id == Address.user_id).correlate(Address)
+
+ main_bq = self.bakery(lambda s: s.query(Address))
+ main_bq += lambda q: q.filter(
+ sub_bq.to_query(q).exists())
+
+ In the case where the subquery is used in the first callable against
+ a :class:`.Session`, the :class:`.Session` is also accepted::
+
+ sub_bq = self.bakery(lambda s: s.query(User.name))
+ sub_bq += lambda q: q.filter(
+ User.id == Address.user_id).correlate(Address)
+
+ main_bq = self.bakery(
+ lambda s: s.query(
+                    Address.id, sub_bq.to_query(s).scalar_subquery())
+ )
+
+        :param query_or_session: a :class:`_query.Query` object or a
+ :class:`.Session` object, that is assumed to be within the context
+ of an enclosing :class:`.BakedQuery` callable.
+
+
+ .. versionadded:: 1.3
+
+
+ """
+
+ if isinstance(query_or_session, Session):
+ session = query_or_session
+ elif isinstance(query_or_session, Query):
+ session = query_or_session.session
+ if session is None:
+ raise sa_exc.ArgumentError(
+ "Given Query needs to be associated with a Session"
+ )
+ else:
+ raise TypeError(
+ "Query or Session object expected, got %r."
+ % type(query_or_session)
+ )
+ return self._as_query(session)
+
+ def _as_query(self, session):
+ query = self.steps[0](session)
+
+ for step in self.steps[1:]:
+ query = step(query)
+
+ return query
+
+
+class Result(object):
+ """Invokes a :class:`.BakedQuery` against a :class:`.Session`.
+
+ The :class:`_baked.Result` object is where the actual :class:`.query.Query`
+ object gets created, or retrieved from the cache,
+ against a target :class:`.Session`, and is then invoked for results.
+
+ """
+
+ __slots__ = "bq", "session", "_params", "_post_criteria"
+
+ def __init__(self, bq, session):
+ self.bq = bq
+ self.session = session
+ self._params = {}
+ self._post_criteria = []
+
+ def params(self, *args, **kw):
+ """Specify parameters to be replaced into the string SQL statement."""
+
+ if len(args) == 1:
+ kw.update(args[0])
+ elif len(args) > 0:
+ raise sa_exc.ArgumentError(
+ "params() takes zero or one positional argument, "
+ "which is a dictionary."
+ )
+ self._params.update(kw)
+ return self
+
+ def _using_post_criteria(self, fns):
+ if fns:
+ self._post_criteria.extend(fns)
+ return self
+
+ def with_post_criteria(self, fn):
+ """Add a criteria function that will be applied post-cache.
+
+ This adds a function that will be run against the
+ :class:`_query.Query` object after it is retrieved from the
+ cache. This currently includes **only** the
+ :meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
+ methods.
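+
+        E.g., a minimal sketch (``my_baked_query`` being a hypothetical
+        :class:`.BakedQuery`) that applies bound parameter values after the
+        cached query has been retrieved::
+
+            result = my_baked_query(session).with_post_criteria(
+                lambda q: q.params(name="ed")
+            )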
+
+ .. warning:: :meth:`_baked.Result.with_post_criteria`
+ functions are applied
+ to the :class:`_query.Query`
+ object **after** the query's SQL statement
+ object has been retrieved from the cache. Only
+ :meth:`_query.Query.params` and
+ :meth:`_query.Query.execution_options`
+ methods should be used.
+
+
+ .. versionadded:: 1.2
+
+
+ """
+ return self._using_post_criteria([fn])
+
+ def _as_query(self):
+ q = self.bq._as_query(self.session).params(self._params)
+ for fn in self._post_criteria:
+ q = fn(q)
+ return q
+
+ def __str__(self):
+ return str(self._as_query())
+
+ def __iter__(self):
+ return self._iter().__iter__()
+
+ def _iter(self):
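+        # If baked queries are disabled on the Session or this query was
+        # spoiled, fall back to a plain Query.  Otherwise retrieve the
+        # (Query, statement) pair from the bakery, baking it if not yet
+        # present, and execute the cached statement via Session.execute()
+        # with the bakery cache serving as the compiled_cache.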
+ bq = self.bq
+
+ if not self.session.enable_baked_queries or bq._spoiled:
+ return self._as_query()._iter()
+
+ query, statement = bq._bakery.get(
+ bq._effective_key(self.session), (None, None)
+ )
+ if query is None:
+ query, statement = bq._bake(self.session)
+
+ if self._params:
+ q = query.params(self._params)
+ else:
+ q = query
+ for fn in self._post_criteria:
+ q = fn(q)
+
+ params = q._params
+ execution_options = dict(q._execution_options)
+ execution_options.update(
+ {
+ "_sa_orm_load_options": q.load_options,
+ "compiled_cache": bq._bakery,
+ }
+ )
+
+ result = self.session.execute(
+ statement, params, execution_options=execution_options
+ )
+ if result._attributes.get("is_single_entity", False):
+ result = result.scalars()
+
+ if result._attributes.get("filtered", False):
+ result = result.unique()
+
+ return result
+
+ def count(self):
+ """return the 'count'.
+
+ Equivalent to :meth:`_query.Query.count`.
+
+ Note this uses a subquery to ensure an accurate count regardless
+ of the structure of the original statement.
+
+ .. versionadded:: 1.1.6
+
+ """
+
+ col = func.count(literal_column("*"))
+ bq = self.bq.with_criteria(lambda q: q._from_self(col))
+ return bq.for_session(self.session).params(self._params).scalar()
+
+ def scalar(self):
+ """Return the first element of the first result or None
+        if no rows are present.  If multiple rows are returned,
+ raises MultipleResultsFound.
+
+ Equivalent to :meth:`_query.Query.scalar`.
+
+ .. versionadded:: 1.1.6
+
+ """
+ try:
+ ret = self.one()
+ if not isinstance(ret, collections_abc.Sequence):
+ return ret
+ return ret[0]
+ except orm_exc.NoResultFound:
+ return None
+
+ def first(self):
+ """Return the first row.
+
+ Equivalent to :meth:`_query.Query.first`.
+
+ """
+
+ bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
+ return (
+ bq.for_session(self.session)
+ .params(self._params)
+ ._using_post_criteria(self._post_criteria)
+ ._iter()
+ .first()
+ )
+
+ def one(self):
+ """Return exactly one result or raise an exception.
+
+ Equivalent to :meth:`_query.Query.one`.
+
+ """
+ return self._iter().one()
+
+ def one_or_none(self):
+ """Return one or zero results, or raise an exception for multiple
+ rows.
+
+ Equivalent to :meth:`_query.Query.one_or_none`.
+
+ .. versionadded:: 1.0.9
+
+ """
+ return self._iter().one_or_none()
+
+ def all(self):
+ """Return all rows.
+
+ Equivalent to :meth:`_query.Query.all`.
+
+ """
+ return self._iter().all()
+
+ def get(self, ident):
+ """Retrieve an object based on identity.
+
+ Equivalent to :meth:`_query.Query.get`.
+
+ """
+
+ query = self.bq.steps[0](self.session)
+ return query._get_impl(ident, self._load_on_pk_identity)
+
+ def _load_on_pk_identity(self, session, query, primary_key_identity, **kw):
+ """Load the given primary key identity from the database."""
+
+ mapper = query._raw_columns[0]._annotations["parententity"]
+
+ _get_clause, _get_params = mapper._get_clause
+
+ def setup(query):
+ _lcl_get_clause = _get_clause
+ q = query._clone()
+ q._get_condition()
+ q._order_by = None
+
+ # None present in ident - turn those comparisons
+ # into "IS NULL"
+ if None in primary_key_identity:
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
+ _lcl_get_clause = sql_util.adapt_criterion_to_null(
+ _lcl_get_clause, nones
+ )
+
+ # TODO: can mapper._get_clause be pre-adapted?
+ q._where_criteria = (
+ sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}),
+ )
+
+ for fn in self._post_criteria:
+ q = fn(q)
+ return q
+
+ # cache the query against a key that includes
+ # which positions in the primary key are NULL
+ # (remember, we can map to an OUTER JOIN)
+ bq = self.bq
+
+ # add the clause we got from mapper._get_clause to the cache
+ # key so that if a race causes multiple calls to _get_clause,
+ # we've cached on ours
+ bq = bq._clone()
+ bq._cache_key += (_get_clause,)
+
+ bq = bq.with_criteria(
+ setup, tuple(elem is None for elem in primary_key_identity)
+ )
+
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
+
+ result = list(bq.for_session(self.session).params(**params))
+ l = len(result)
+ if l > 1:
+ raise orm_exc.MultipleResultsFound()
+ elif l:
+ return result[0]
+ else:
+ return None
+
+
+@util.deprecated(
+ "1.2", "Baked lazy loading is now the default implementation."
+)
+def bake_lazy_loaders():
+ """Enable the use of baked queries for all lazyloaders systemwide.
+
+ The "baked" implementation of lazy loading is now the sole implementation
+ for the base lazy loader; this method has no effect except for a warning.
+
+ """
+ pass
+
+
+@util.deprecated(
+ "1.2", "Baked lazy loading is now the default implementation."
+)
+def unbake_lazy_loaders():
+ """Disable the use of baked queries for all lazyloaders systemwide.
+
+ This method now raises NotImplementedError() as the "baked" implementation
+ is the only lazy load implementation. The
+ :paramref:`_orm.relationship.bake_queries` flag may be used to disable
+ the caching of queries on a per-relationship basis.
+
+ """
+ raise NotImplementedError(
+ "Baked lazy loading is now the default implementation"
+ )
+
+
+@strategy_options.loader_option()
+def baked_lazyload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using "lazy"
+ loading with a "baked" query used in the load.
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"})
+
+
+@baked_lazyload._add_unbound_fn
+@util.deprecated(
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
+def baked_lazyload(*keys):
+ return strategy_options._UnboundLoad._from_keys(
+ strategy_options._UnboundLoad.baked_lazyload, keys, False, {}
+ )
+
+
+@baked_lazyload._add_unbound_all_fn
+@util.deprecated(
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
+def baked_lazyload_all(*keys):
+ return strategy_options._UnboundLoad._from_keys(
+ strategy_options._UnboundLoad.baked_lazyload, keys, True, {}
+ )
+
+
+baked_lazyload = baked_lazyload._unbound_fn
+baked_lazyload_all = baked_lazyload_all._unbound_all_fn
+
+bakery = BakedQuery.bakery
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
new file mode 100644
index 0000000..76b59ea
--- /dev/null
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -0,0 +1,613 @@
+# ext/compiler.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Provides an API for creation of custom ClauseElements and compilers.
+
+Synopsis
+========
+
+Usage involves the creation of one or more
+:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
+more callables defining its compilation::
+
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.sql.expression import ColumnClause
+
+ class MyColumn(ColumnClause):
+ inherit_cache = True
+
+ @compiles(MyColumn)
+ def compile_mycolumn(element, compiler, **kw):
+ return "[%s]" % element.name
+
+Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
+the base expression element for named column objects. The ``compiles``
+decorator registers itself with the ``MyColumn`` class so that it is invoked
+when the object is compiled to a string::
+
+ from sqlalchemy import select
+
+ s = select(MyColumn('x'), MyColumn('y'))
+ print(str(s))
+
+Produces::
+
+ SELECT [x], [y]
+
+Dialect-specific compilation rules
+==================================
+
+Compilers can also be made dialect-specific. The appropriate compiler will be
+invoked for the dialect in use::
+
+ from sqlalchemy.schema import DDLElement
+
+ class AlterColumn(DDLElement):
+ inherit_cache = False
+
+ def __init__(self, column, cmd):
+ self.column = column
+ self.cmd = cmd
+
+ @compiles(AlterColumn)
+ def visit_alter_column(element, compiler, **kw):
+ return "ALTER COLUMN %s ..." % element.column.name
+
+ @compiles(AlterColumn, 'postgresql')
+ def visit_alter_column(element, compiler, **kw):
+ return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
+ element.column.name)
+
+The second ``visit_alter_column`` will be invoked when any ``postgresql``
+dialect is used.
+
+.. _compilerext_compiling_subelements:
+
+Compiling sub-elements of a custom expression construct
+=======================================================
+
+The ``compiler`` argument is the
+:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
+can be inspected for any information about the in-progress compilation,
+including ``compiler.dialect``, ``compiler.statement`` etc. The
+:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
+:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
+method which can be used for compilation of embedded attributes::
+
+ from sqlalchemy.sql.expression import Executable, ClauseElement
+
+ class InsertFromSelect(Executable, ClauseElement):
+ inherit_cache = False
+
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+ @compiles(InsertFromSelect)
+ def visit_insert_from_select(element, compiler, **kw):
+ return "INSERT INTO %s (%s)" % (
+ compiler.process(element.table, asfrom=True, **kw),
+ compiler.process(element.select, **kw)
+ )
+
+ insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5))
+ print(insert)
+
+Produces::
+
+ "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
+ FROM mytable WHERE mytable.x > :x_1)"
+
+.. note::
+
+    The above ``InsertFromSelect`` construct is only an example; this
+    functionality is already available using the
+    :meth:`_expression.Insert.from_select` method.
+
+.. note::
+
+ The above ``InsertFromSelect`` construct probably wants to have "autocommit"
+ enabled. See :ref:`enabling_compiled_autocommit` for this step.
+
+Cross Compiling between SQL and DDL compilers
+---------------------------------------------
+
+SQL and DDL constructs are each compiled using different base compilers -
+``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
+compilation rules of SQL expressions from within a DDL expression. The
+``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
+below where we generate a CHECK constraint that embeds a SQL expression::
+
+ @compiles(MyConstraint)
+ def compile_my_constraint(constraint, ddlcompiler, **kw):
+ kw['literal_binds'] = True
+ return "CONSTRAINT %s CHECK (%s)" % (
+ constraint.name,
+ ddlcompiler.sql_compiler.process(
+ constraint.expression, **kw)
+ )
+
+Above, we add an additional flag to the process step as called by
+:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
+indicates that any SQL expression which refers to a :class:`.BindParameter`
+object or other "literal" object such as those which refer to strings or
+integers should be rendered **in-place**, rather than being referred to as
+a bound parameter; when emitting DDL, bound parameters are typically not
+supported.
+
+
+.. _enabling_compiled_autocommit:
+
+Enabling Autocommit on a Construct
+==================================
+
+Recall from the section :ref:`autocommit` that the :class:`_engine.Engine`,
+when
+asked to execute a construct in the absence of a user-defined transaction,
+detects if the given construct represents DML or DDL, that is, a data
+modification or data definition statement, which requires (or may require,
+in the case of DDL) that the transaction generated by the DBAPI be committed
+(recall that DBAPI always has a transaction going on regardless of what
+SQLAlchemy does). Checking for this is actually accomplished by checking for
+the "autocommit" execution option on the construct. When building a
+construct like an INSERT derivation, a new DDL type, or perhaps a stored
+procedure that alters data, the "autocommit" option needs to be set in order
+for the statement to function with "connectionless" execution
+(as described in :ref:`dbengine_implicit`).
+
+Currently a quick way to do this is to subclass :class:`.Executable`, then
+add the "autocommit" flag to the ``_execution_options`` dictionary (note this
+is a "frozen" dictionary which supplies a generative ``union()`` method)::
+
+ from sqlalchemy.sql.expression import Executable, ClauseElement
+
+ class MyInsertThing(Executable, ClauseElement):
+ _execution_options = \
+ Executable._execution_options.union({'autocommit': True})
+
+More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
+DELETE, :class:`.UpdateBase` can be used, which already is a subclass
+of :class:`.Executable`, :class:`_expression.ClauseElement` and includes the
+``autocommit`` flag::
+
+ from sqlalchemy.sql.expression import UpdateBase
+
+ class MyInsertThing(UpdateBase):
+ def __init__(self, ...):
+ ...
+
+DDL elements that subclass :class:`.DDLElement` already have the
+"autocommit" flag turned on.
+
+
+Changing the default compilation of existing constructs
+=======================================================
+
+The compiler extension applies just as well to the existing constructs. When
+overriding the compilation of a built-in SQL construct, the @compiles
+decorator is invoked upon the appropriate class (be sure to use the class,
+i.e. ``Insert`` or ``Select``, instead of the creation function such
+as ``insert()`` or ``select()``).
+
+Within the new compilation function, to get at the "original" compilation
+routine, use the appropriate visit_XXX method - this is
+because compiler.process() will call upon the overriding routine and cause
+an endless loop. For example, to add a "prefix" to all insert statements::
+
+ from sqlalchemy.sql.expression import Insert
+
+ @compiles(Insert)
+ def prefix_inserts(insert, compiler, **kw):
+ return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
+
+The above compiler will prefix all INSERT statements with "some prefix" when
+compiled.
+
+.. _type_compilation_extension:
+
+Changing Compilation of Types
+=============================
+
+``compiler`` works for types, too, such as below where we implement the
+MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
+
+ @compiles(String, 'mssql')
+ @compiles(VARCHAR, 'mssql')
+ def compile_varchar(element, compiler, **kw):
+ if element.length == 'max':
+ return "VARCHAR('max')"
+ else:
+ return compiler.visit_VARCHAR(element, **kw)
+
+ foo = Table('foo', metadata,
+ Column('data', VARCHAR('max'))
+ )
+
+Subclassing Guidelines
+======================
+
+A big part of using the compiler extension is subclassing SQLAlchemy
+expression constructs. To make this easier, the expression and
+schema packages feature a set of "bases" intended for common tasks.
+A synopsis is as follows:
+
+* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
+ expression class. Any SQL expression can be derived from this base, and is
+ probably the best choice for longer constructs such as specialized INSERT
+ statements.
+
+* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
+ "column-like" elements. Anything that you'd place in the "columns" clause of
+ a SELECT statement (as well as order by and group by) can derive from this -
+ the object will automatically have Python "comparison" behavior.
+
+ :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
+  ``type`` member which is the expression's return type.  This can be
+  established at the instance level in the constructor, or at the class level
+  if it's generally constant::
+
+ class timestamp(ColumnElement):
+ type = TIMESTAMP()
+ inherit_cache = True
+
+* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
+ ``ColumnElement`` and a "from clause" like object, and represents a SQL
+ function or stored procedure type of call. Since most databases support
+  statements along the lines of "SELECT FROM <some function>",
+ ``FunctionElement`` adds in the ability to be used in the FROM clause of a
+ ``select()`` construct::
+
+ from sqlalchemy.sql.expression import FunctionElement
+
+ class coalesce(FunctionElement):
+ name = 'coalesce'
+ inherit_cache = True
+
+ @compiles(coalesce)
+ def compile(element, compiler, **kw):
+ return "coalesce(%s)" % compiler.process(element.clauses, **kw)
+
+ @compiles(coalesce, 'oracle')
+ def compile(element, compiler, **kw):
+ if len(element.clauses) > 2:
+ raise TypeError("coalesce only supports two arguments on Oracle")
+ return "nvl(%s)" % compiler.process(element.clauses, **kw)
+
+* :class:`.DDLElement` - The root of all DDL expressions,
+ like CREATE TABLE, ALTER TABLE, etc. Compilation of :class:`.DDLElement`
+ subclasses is issued by a :class:`.DDLCompiler` instead of a
+ :class:`.SQLCompiler`. :class:`.DDLElement` can also be used as an event hook
+ in conjunction with event hooks like :meth:`.DDLEvents.before_create` and
+ :meth:`.DDLEvents.after_create`, allowing the construct to be invoked
+ automatically during CREATE TABLE and DROP TABLE sequences.
+
+ .. seealso::
+
+ :ref:`metadata_ddl_toplevel` - contains examples of associating
+ :class:`.DDL` objects (which are themselves :class:`.DDLElement`
+ instances) with :class:`.DDLEvents` event hooks.
+
+* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
+ should be used with any expression class that represents a "standalone"
+ SQL statement that can be passed directly to an ``execute()`` method. It
+ is already implicit within ``DDLElement`` and ``FunctionElement``.
+
+Most of the above constructs also respond to SQL statement caching. A
+subclassed construct will want to define the caching behavior for the object,
+which usually means setting the flag ``inherit_cache`` to the value of
+``False`` or ``True``. See the next section :ref:`compilerext_caching`
+for background.
+
+
+.. _compilerext_caching:
+
+Enabling Caching Support for Custom Constructs
+==============================================
+
+SQLAlchemy as of version 1.4 includes a
+:ref:`SQL compilation caching facility <sql_caching>` which will allow
+equivalent SQL constructs to cache their stringified form, along with other
+structural information used to fetch results from the statement.
+
+For reasons discussed at :ref:`caching_caveats`, the implementation of this
+caching system takes a conservative approach towards including custom SQL
+constructs and/or subclasses within the caching system. This includes that
+any user-defined SQL constructs, including all the examples for this
+extension, will not participate in caching by default unless they positively
+assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache`
+attribute when set to ``True`` at the class level of a specific subclass
+will indicate that instances of this class may be safely cached, using the
+cache key generation scheme of the immediate superclass. This applies
+for example to the "synopsis" example indicated previously::
+
+ class MyColumn(ColumnClause):
+ inherit_cache = True
+
+ @compiles(MyColumn)
+ def compile_mycolumn(element, compiler, **kw):
+ return "[%s]" % element.name
+
+Above, the ``MyColumn`` class does not include any new state that
+affects its SQL compilation; the cache key of ``MyColumn`` instances will
+make use of that of the ``ColumnClause`` superclass, meaning it will take
+into account the class of the object (``MyColumn``), the string name and
+datatype of the object::
+
+ >>> MyColumn("some_name", String())._generate_cache_key()
+ CacheKey(
+ key=('0', <class '__main__.MyColumn'>,
+ 'name', 'some_name',
+ 'type', (<class 'sqlalchemy.sql.sqltypes.String'>,
+ ('length', None), ('collation', None))
+ ), bindparams=[])
+
+For objects that are likely to be **used liberally as components within many
+larger statements**, such as :class:`_schema.Column` subclasses and custom SQL
+datatypes, it's important that **caching be enabled as much as possible**, as
+this may otherwise negatively affect performance.
+
+An example of an object that **does** contain state which affects its SQL
+compilation is the one illustrated at :ref:`compilerext_compiling_subelements`;
+this is an "INSERT FROM SELECT" construct that combines together a
+:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of
+which independently affect the SQL string generation of the construct. For
+this class, the example illustrates that it simply does not participate in
+caching::
+
+ class InsertFromSelect(Executable, ClauseElement):
+ inherit_cache = False
+
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+ @compiles(InsertFromSelect)
+ def visit_insert_from_select(element, compiler, **kw):
+ return "INSERT INTO %s (%s)" % (
+ compiler.process(element.table, asfrom=True, **kw),
+ compiler.process(element.select, **kw)
+ )
+
+While it is also possible that the above ``InsertFromSelect`` could be made to
+produce a cache key that is composed of that of the :class:`_schema.Table` and
+:class:`_sql.Select` components together, the API for this is not at the moment
+fully public. However, for an "INSERT FROM SELECT" construct, which is only
+used by itself for specific operations, caching is not as critical as in the
+previous example.
+
+For objects that are **used in relative isolation and are generally
+standalone**, such as custom :term:`DML` constructs like an "INSERT FROM
+SELECT", **caching is generally less critical** as the lack of caching for such
+a construct will have only localized implications for that specific operation.
+
+
+Further Examples
+================
+
+"UTC timestamp" function
+-------------------------
+
+A function that works like "CURRENT_TIMESTAMP" except that it applies the
+appropriate conversions so that the time is in UTC.  Timestamps are best
+stored in relational databases as UTC, without time zones.  UTC, so that your
+database doesn't think time has gone backwards in the hour when daylight
+saving time ends; without time zones, because time zones are like character
+encodings - they're best applied only at the endpoints of an application
+(i.e. convert to UTC upon user input, re-apply the desired time zone upon
+display).
+
+For PostgreSQL and Microsoft SQL Server::
+
+ from sqlalchemy.sql import expression
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.types import DateTime
+
+ class utcnow(expression.FunctionElement):
+ type = DateTime()
+ inherit_cache = True
+
+ @compiles(utcnow, 'postgresql')
+ def pg_utcnow(element, compiler, **kw):
+ return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
+
+ @compiles(utcnow, 'mssql')
+ def ms_utcnow(element, compiler, **kw):
+ return "GETUTCDATE()"
+
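+Other backends can be handled in the same way; as a sketch, MySQL provides a
+``UTC_TIMESTAMP()`` function that could be rendered here (assuming its default
+precision is acceptable for the column in use)::
+
+    @compiles(utcnow, 'mysql')
+    def my_utcnow(element, compiler, **kw):
+        return "UTC_TIMESTAMP()"
+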
+Example usage::
+
+ from sqlalchemy import (
+ Table, Column, Integer, String, DateTime, MetaData
+ )
+ metadata = MetaData()
+ event = Table("event", metadata,
+ Column("id", Integer, primary_key=True),
+ Column("description", String(50), nullable=False),
+ Column("timestamp", DateTime, server_default=utcnow())
+ )
+
+"GREATEST" function
+-------------------
+
+The "GREATEST" function is given any number of arguments and returns the one
+that is of the highest value - its equivalent to Python's ``max``
+function. A SQL standard version versus a CASE based version which only
+accommodates two arguments::
+
+ from sqlalchemy.sql import expression, case
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.types import Numeric
+
+ class greatest(expression.FunctionElement):
+ type = Numeric()
+ name = 'greatest'
+ inherit_cache = True
+
+ @compiles(greatest)
+ def default_greatest(element, compiler, **kw):
+ return compiler.visit_function(element)
+
+ @compiles(greatest, 'sqlite')
+ @compiles(greatest, 'mssql')
+ @compiles(greatest, 'oracle')
+ def case_greatest(element, compiler, **kw):
+ arg1, arg2 = list(element.clauses)
+ return compiler.process(case([(arg1 > arg2, arg1)], else_=arg2), **kw)
+
+Example usage::
+
+ Session.query(Account).\
+ filter(
+ greatest(
+ Account.checking_balance,
+ Account.savings_balance) > 10000
+ )
+
+"false" expression
+------------------
+
+Render a "false" constant expression, rendering as "0" on platforms that
+don't have a "false" constant::
+
+ from sqlalchemy.sql import expression
+ from sqlalchemy.ext.compiler import compiles
+
+ class sql_false(expression.ColumnElement):
+ inherit_cache = True
+
+ @compiles(sql_false)
+ def default_false(element, compiler, **kw):
+ return "false"
+
+ @compiles(sql_false, 'mssql')
+ @compiles(sql_false, 'mysql')
+ @compiles(sql_false, 'oracle')
+ def int_false(element, compiler, **kw):
+ return "0"
+
+Example usage::
+
+ from sqlalchemy import select, union_all
+
+ exp = union_all(
+ select(users.c.name, sql_false().label("enrolled")),
+ select(customers.c.name, customers.c.enrolled)
+ )
+
+"""
+from .. import exc
+from .. import util
+from ..sql import sqltypes
+
+
+def compiles(class_, *specs):
+ """Register a function as a compiler for a
+ given :class:`_expression.ClauseElement` type."""
+
+ def decorate(fn):
+ # get an existing @compiles handler
+ existing = class_.__dict__.get("_compiler_dispatcher", None)
+
+ # get the original handler. All ClauseElement classes have one
+ # of these, but some TypeEngine classes will not.
+ existing_dispatch = getattr(class_, "_compiler_dispatch", None)
+
+ if not existing:
+ existing = _dispatcher()
+
+ if existing_dispatch:
+
+ def _wrap_existing_dispatch(element, compiler, **kw):
+ try:
+ return existing_dispatch(element, compiler, **kw)
+ except exc.UnsupportedCompilationError as uce:
+ util.raise_(
+ exc.UnsupportedCompilationError(
+ compiler,
+ type(element),
+ message="%s construct has no default "
+ "compilation handler." % type(element),
+ ),
+ from_=uce,
+ )
+
+ existing.specs["default"] = _wrap_existing_dispatch
+
+ # TODO: why is the lambda needed ?
+ setattr(
+ class_,
+ "_compiler_dispatch",
+ lambda *arg, **kw: existing(*arg, **kw),
+ )
+ setattr(class_, "_compiler_dispatcher", existing)
+
+ if specs:
+ for s in specs:
+ existing.specs[s] = fn
+
+ else:
+ existing.specs["default"] = fn
+ return fn
+
+ return decorate
+
+
+def deregister(class_):
+ """Remove all custom compilers associated with a given
+ :class:`_expression.ClauseElement` type.
+
+ """
+
+ if hasattr(class_, "_compiler_dispatcher"):
+ class_._compiler_dispatch = class_._original_compiler_dispatch
+ del class_._compiler_dispatcher
+
+
+class _dispatcher(object):
+ def __init__(self):
+ self.specs = {}
+
+ def __call__(self, element, compiler, **kw):
+ # TODO: yes, this could also switch off of DBAPI in use.
+ fn = self.specs.get(compiler.dialect.name, None)
+ if not fn:
+ try:
+ fn = self.specs["default"]
+ except KeyError as ke:
+ util.raise_(
+ exc.UnsupportedCompilationError(
+ compiler,
+ type(element),
+ message="%s construct has no default "
+ "compilation handler." % type(element),
+ ),
+ replace_context=ke,
+ )
+
+ # if compilation includes add_to_result_map, collect add_to_result_map
+ # arguments from the user-defined callable, which are probably none
+ # because this is not public API. if it wasn't called, then call it
+ # ourselves.
+ arm = kw.get("add_to_result_map", None)
+ if arm:
+ arm_collection = []
+ kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
+
+ expr = fn(element, compiler, **kw)
+
+ if arm:
+ if not arm_collection:
+ arm_collection.append(
+ (None, None, (element,), sqltypes.NULLTYPE)
+ )
+ for tup in arm_collection:
+ arm(*tup)
+ return expr
diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py
new file mode 100644
index 0000000..6215e35
--- /dev/null
+++ b/lib/sqlalchemy/ext/declarative/__init__.py
@@ -0,0 +1,64 @@
+# ext/declarative/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .extensions import AbstractConcreteBase
+from .extensions import ConcreteBase
+from .extensions import DeferredReflection
+from .extensions import instrument_declarative
+from ... import util
+from ...orm.decl_api import as_declarative as _as_declarative
+from ...orm.decl_api import declarative_base as _declarative_base
+from ...orm.decl_api import DeclarativeMeta
+from ...orm.decl_api import declared_attr
+from ...orm.decl_api import has_inherited_table as _has_inherited_table
+from ...orm.decl_api import synonym_for as _synonym_for
+
+
+@util.moved_20(
+ "The ``declarative_base()`` function is now available as "
+ ":func:`sqlalchemy.orm.declarative_base`."
+)
+def declarative_base(*arg, **kw):
+ return _declarative_base(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``as_declarative()`` function is now available as "
+ ":func:`sqlalchemy.orm.as_declarative`"
+)
+def as_declarative(*arg, **kw):
+ return _as_declarative(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``has_inherited_table()`` function is now available as "
+ ":func:`sqlalchemy.orm.has_inherited_table`."
+)
+def has_inherited_table(*arg, **kw):
+ return _has_inherited_table(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``synonym_for()`` function is now available as "
+ ":func:`sqlalchemy.orm.synonym_for`"
+)
+def synonym_for(*arg, **kw):
+ return _synonym_for(*arg, **kw)
+
+
+__all__ = [
+ "declarative_base",
+ "synonym_for",
+ "has_inherited_table",
+ "instrument_declarative",
+ "declared_attr",
+ "as_declarative",
+ "ConcreteBase",
+ "AbstractConcreteBase",
+ "DeclarativeMeta",
+ "DeferredReflection",
+]
diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py
new file mode 100644
index 0000000..7818841
--- /dev/null
+++ b/lib/sqlalchemy/ext/declarative/extensions.py
@@ -0,0 +1,463 @@
+# ext/declarative/extensions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Public API functions and helpers for declarative."""
+
+
+from ... import inspection
+from ... import util
+from ...orm import exc as orm_exc
+from ...orm import registry
+from ...orm import relationships
+from ...orm.base import _mapper_or_none
+from ...orm.clsregistry import _resolver
+from ...orm.decl_base import _DeferredMapperConfig
+from ...orm.util import polymorphic_union
+from ...schema import Table
+from ...util import OrderedDict
+
+
+@util.deprecated(
+ "2.0",
+ "the instrument_declarative function is deprecated "
+    "and will be removed in SQLAlchemy 2.0. Please use "
+    ":meth:`_orm.registry.map_declaratively`",
+)
+def instrument_declarative(cls, cls_registry, metadata):
+ """Given a class, configure the class declaratively,
+ using the given registry, which can be any dictionary, and
+ MetaData object.
+
+ """
+ registry(metadata=metadata, class_registry=cls_registry).map_declaratively(
+ cls
+ )
+
+
+class ConcreteBase(object):
+ """A helper class for 'concrete' declarative mappings.
+
+ :class:`.ConcreteBase` will use the :func:`.polymorphic_union`
+ function automatically, against all tables mapped as a subclass
+ to this class. The function is called via the
+    ``__declare_first__()`` function, which is essentially
+    a hook for the :meth:`.before_configured` event.
+
+ :class:`.ConcreteBase` produces a mapped
+ table for the class itself. Compare to :class:`.AbstractConcreteBase`,
+ which does not.
+
+ Example::
+
+ from sqlalchemy.ext.declarative import ConcreteBase
+
+ class Employee(ConcreteBase, Base):
+ __tablename__ = 'employee'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee',
+ 'concrete':True}
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ manager_data = Column(String(40))
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+
+ The name of the discriminator column used by :func:`.polymorphic_union`
+ defaults to the name ``type``. To suit the use case of a mapping where an
+ actual column in a mapped table is already named ``type``, the
+ discriminator name can be configured by setting the
+ ``_concrete_discriminator_name`` attribute::
+
+ class Employee(ConcreteBase, Base):
+ _concrete_discriminator_name = '_concrete_discriminator'
+
+ .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
+ attribute to :class:`_declarative.ConcreteBase` so that the
+ virtual discriminator column name can be customized.
+
+ .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute
+ need only be placed on the basemost class to take correct effect for
+ all subclasses. An explicit error message is now raised if the
+ mapped column names conflict with the discriminator name, whereas
+ in the 1.3.x series there would be some warnings and then a non-useful
+ query would be generated.
+
+ .. seealso::
+
+ :class:`.AbstractConcreteBase`
+
+ :ref:`concrete_inheritance`
+
+
+ """
+
+ @classmethod
+ def _create_polymorphic_union(cls, mappers, discriminator_name):
+ return polymorphic_union(
+ OrderedDict(
+ (mp.polymorphic_identity, mp.local_table) for mp in mappers
+ ),
+ discriminator_name,
+ "pjoin",
+ )
+
+ @classmethod
+ def __declare_first__(cls):
+ m = cls.__mapper__
+ if m.with_polymorphic:
+ return
+
+ discriminator_name = (
+ getattr(cls, "_concrete_discriminator_name", None) or "type"
+ )
+
+ mappers = list(m.self_and_descendants)
+ pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
+ m._set_with_polymorphic(("*", pjoin))
+ m._set_polymorphic_on(pjoin.c[discriminator_name])
+
+
+class AbstractConcreteBase(ConcreteBase):
+ """A helper class for 'concrete' declarative mappings.
+
+ :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
+ function automatically, against all tables mapped as a subclass
+ to this class. The function is called via the
+    ``__declare_first__()`` function, which is essentially
+    a hook for the :meth:`.before_configured` event.
+
+ :class:`.AbstractConcreteBase` does produce a mapped class
+ for the base class, however it is not persisted to any table; it
+    is instead mapped directly to the "polymorphic" selectable
+ and is only used for selecting. Compare to :class:`.ConcreteBase`,
+ which does create a persisted table for the base class.
+
+ .. note::
+
+ The :class:`.AbstractConcreteBase` class does not intend to set up the
+ mapping for the base class until all the subclasses have been defined,
+ as it needs to create a mapping against a selectable that will include
+ all subclass tables. In order to achieve this, it waits for the
+ **mapper configuration event** to occur, at which point it scans
+ through all the configured subclasses and sets up a mapping that will
+ query against all subclasses at once.
+
+ While this event is normally invoked automatically, in the case of
+ :class:`.AbstractConcreteBase`, it may be necessary to invoke it
+ explicitly after **all** subclass mappings are defined, if the first
+ operation is to be a query against this base class. To do so, invoke
+ :func:`.configure_mappers` once all the desired classes have been
+ configured::
+
+ from sqlalchemy.orm import configure_mappers
+
+ configure_mappers()
+
+ .. seealso::
+
+ :func:`_orm.configure_mappers`
+
+
+ Example::
+
+ from sqlalchemy.ext.declarative import AbstractConcreteBase
+
+ class Employee(AbstractConcreteBase, Base):
+ pass
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ manager_data = Column(String(40))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ configure_mappers()
+
+ The abstract base class is handled by declarative in a special way;
+ at class configuration time, it behaves like a declarative mixin
+ or an ``__abstract__`` base class. Once classes are configured
+ and mappings are produced, it then gets mapped itself, but
+    after all of its descendants.  This is a unique system of mapping
+    not found in any other part of SQLAlchemy.
+
+ Using this approach, we can specify columns and properties
+ that will take place on mapped subclasses, in the way that
+ we normally do as in :ref:`declarative_mixins`::
+
+ class Company(Base):
+ __tablename__ = 'company'
+ id = Column(Integer, primary_key=True)
+
+ class Employee(AbstractConcreteBase, Base):
+ employee_id = Column(Integer, primary_key=True)
+
+ @declared_attr
+ def company_id(cls):
+ return Column(ForeignKey('company.id'))
+
+ @declared_attr
+ def company(cls):
+ return relationship("Company")
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+
+ name = Column(String(50))
+ manager_data = Column(String(40))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ configure_mappers()
+
+ When we make use of our mappings however, both ``Manager`` and
+ ``Employee`` will have an independently usable ``.company`` attribute::
+
+ session.query(Employee).filter(Employee.company.has(id=5))
+
+ .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
+ have been reworked to support relationships established directly
+ on the abstract base, without any special configurational steps.
+
+ .. seealso::
+
+ :class:`.ConcreteBase`
+
+ :ref:`concrete_inheritance`
+
+ """
+
+ __no_table__ = True
+
+ @classmethod
+ def __declare_first__(cls):
+ cls._sa_decl_prepare_nocascade()
+
+ @classmethod
+ def _sa_decl_prepare_nocascade(cls):
+ if getattr(cls, "__mapper__", None):
+ return
+
+ to_map = _DeferredMapperConfig.config_for_cls(cls)
+
+ # can't rely on 'self_and_descendants' here
+ # since technically an immediate subclass
+ # might not be mapped, but a subclass
+ # may be.
+ mappers = []
+ stack = list(cls.__subclasses__())
+ while stack:
+ klass = stack.pop()
+ stack.extend(klass.__subclasses__())
+ mn = _mapper_or_none(klass)
+ if mn is not None:
+ mappers.append(mn)
+
+ discriminator_name = (
+ getattr(cls, "_concrete_discriminator_name", None) or "type"
+ )
+ pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
+
+ # For columns that were declared on the class, these
+ # are normally ignored with the "__no_table__" mapping,
+ # unless they have a different attribute key vs. col name
+ # and are in the properties argument.
+ # In that case, ensure we update the properties entry
+ # to the correct column from the pjoin target table.
+ declared_cols = set(to_map.declared_columns)
+ for k, v in list(to_map.properties.items()):
+ if v in declared_cols:
+ to_map.properties[k] = pjoin.c[v.key]
+
+ to_map.local_table = pjoin
+
+ m_args = to_map.mapper_args_fn or dict
+
+ def mapper_args():
+ args = m_args()
+ args["polymorphic_on"] = pjoin.c[discriminator_name]
+ return args
+
+ to_map.mapper_args_fn = mapper_args
+
+ m = to_map.map()
+
+ for scls in cls.__subclasses__():
+ sm = _mapper_or_none(scls)
+ if sm and sm.concrete and cls in scls.__bases__:
+ sm._set_concrete_base(m)
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of AbstractConcreteBase and "
+ "has a mapping pending until all subclasses are defined. "
+ "Call the sqlalchemy.orm.configure_mappers() function after "
+ "all subclasses have been defined to "
+ "complete the mapping of this class."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+
+class DeferredReflection(object):
+ """A helper class for construction of mappings based on
+ a deferred reflection step.
+
+ Normally, declarative can be used with reflection by
+ setting a :class:`_schema.Table` object using autoload_with=engine
+ as the ``__table__`` attribute on a declarative class.
+ The caveat is that the :class:`_schema.Table` must be fully
+ reflected, or at the very least have a primary key column,
+ at the point at which a normal declarative mapping is
+ constructed, meaning the :class:`_engine.Engine` must be available
+ at class declaration time.
+
+ The :class:`.DeferredReflection` mixin moves the construction
+ of mappers to be at a later point, after a specific
+ method is called which first reflects all :class:`_schema.Table`
+ objects created so far. Classes can define it as such::
+
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.declarative import DeferredReflection
+ Base = declarative_base()
+
+ class MyClass(DeferredReflection, Base):
+ __tablename__ = 'mytable'
+
+ Above, ``MyClass`` is not yet mapped. After a series of
+ classes have been defined in the above fashion, all tables
+ can be reflected and mappings created using
+ :meth:`.prepare`::
+
+ engine = create_engine("someengine://...")
+ DeferredReflection.prepare(engine)
+
+ The :class:`.DeferredReflection` mixin can be applied to individual
+ classes, used as the base for the declarative base itself,
+ or used in a custom abstract class. Using an abstract base
+    allows only a subset of classes to be prepared for a
+ particular prepare step, which is necessary for applications
+ that use more than one engine. For example, if an application
+ has two engines, you might use two bases, and prepare each
+ separately, e.g.::
+
+ class ReflectedOne(DeferredReflection, Base):
+ __abstract__ = True
+
+ class ReflectedTwo(DeferredReflection, Base):
+ __abstract__ = True
+
+ class MyClass(ReflectedOne):
+ __tablename__ = 'mytable'
+
+ class MyOtherClass(ReflectedOne):
+ __tablename__ = 'myothertable'
+
+ class YetAnotherClass(ReflectedTwo):
+ __tablename__ = 'yetanothertable'
+
+ # ... etc.
+
+ Above, the class hierarchies for ``ReflectedOne`` and
+ ``ReflectedTwo`` can be configured separately::
+
+ ReflectedOne.prepare(engine_one)
+ ReflectedTwo.prepare(engine_two)
+
+ .. seealso::
+
+ :ref:`orm_declarative_reflected_deferred_reflection` - in the
+ :ref:`orm_declarative_table_config_toplevel` section.
+
+ """
+
+ @classmethod
+ def prepare(cls, engine):
+ """Reflect all :class:`_schema.Table` objects for all current
+ :class:`.DeferredReflection` subclasses"""
+
+ to_map = _DeferredMapperConfig.classes_for_base(cls)
+
+ with inspection.inspect(engine)._inspection_context() as insp:
+ for thingy in to_map:
+ cls._sa_decl_prepare(thingy.local_table, insp)
+ thingy.map()
+ mapper = thingy.cls.__mapper__
+ metadata = mapper.class_.metadata
+ for rel in mapper._props.values():
+ if (
+ isinstance(rel, relationships.RelationshipProperty)
+ and rel.secondary is not None
+ ):
+ if isinstance(rel.secondary, Table):
+ cls._reflect_table(rel.secondary, insp)
+ elif isinstance(rel.secondary, str):
+
+ _, resolve_arg = _resolver(rel.parent.class_, rel)
+
+ rel.secondary = resolve_arg(rel.secondary)
+ rel.secondary._resolvers += (
+ cls._sa_deferred_table_resolver(
+ insp, metadata
+ ),
+ )
+
+ # controversy! do we resolve it here? or leave
+ # it deferred? I think doing it here is necessary
+ # so the connection does not leak.
+ rel.secondary = rel.secondary()
+
+ @classmethod
+ def _sa_deferred_table_resolver(cls, inspector, metadata):
+ def _resolve(key):
+ t1 = Table(key, metadata)
+ cls._reflect_table(t1, inspector)
+ return t1
+
+ return _resolve
+
+ @classmethod
+ def _sa_decl_prepare(cls, local_table, inspector):
+ # autoload Table, which is already
+ # present in the metadata. This
+ # will fill in db-loaded columns
+ # into the existing Table object.
+ if local_table is not None:
+ cls._reflect_table(local_table, inspector)
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of DeferredReflection. "
+ "Mappings are not produced until the .prepare() "
+ "method is called on the class hierarchy."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+ @classmethod
+ def _reflect_table(cls, table, inspector):
+ Table(
+ table.name,
+ table.metadata,
+ extend_existing=True,
+ autoload_replace=False,
+ autoload_with=inspector,
+ schema=table.schema,
+ )
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
new file mode 100644
index 0000000..bad076e
--- /dev/null
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -0,0 +1,256 @@
+# ext/horizontal_shard.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Horizontal sharding support.
+
+Defines a rudimentary 'horizontal sharding' system which allows a Session to
+distribute queries and persistence operations across multiple databases.
+
+For a usage example, see the :ref:`examples_sharding` example included in
+the source distribution.
+
+"""
+
+from .. import event
+from .. import exc
+from .. import inspect
+from .. import util
+from ..orm.query import Query
+from ..orm.session import Session
+
+__all__ = ["ShardedSession", "ShardedQuery"]
+
+
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self.execute_chooser = self.session.execute_chooser
+ self._shard_id = None
+
+ def set_shard(self, shard_id):
+ """Return a new query, limited to a single shard ID.
+
+ All subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+
+        For a 2.0 style execution, the shard_id can be passed to the
+ bind_arguments dictionary of :meth:`.Session.execute`::
+
+ results = session.execute(
+ stmt,
+ bind_arguments={"shard_id": "my_shard"}
+ )
+
+ """
+ return self.execution_options(_sa_shard_id=shard_id)
+
+
+class ShardedSession(Session):
+ def __init__(
+ self,
+ shard_chooser,
+ id_chooser,
+ execute_chooser=None,
+ shards=None,
+ query_cls=ShardedQuery,
+ **kwargs
+ ):
+ """Construct a ShardedSession.
+
+ :param shard_chooser: A callable which, passed a Mapper, a mapped
+ instance, and possibly a SQL clause, returns a shard ID. This id
+          may be based on the attributes present within the object, or on
+ some round-robin scheme. If the scheme is based on a selection, it
+ should set whatever state on the instance to mark it in the future as
+ participating in that shard.
+
+ :param id_chooser: A callable, passed a query and a tuple of identity
+ values, which should return a list of shard ids where the ID might
+ reside. The databases will be queried in the order of this listing.
+
+ :param execute_chooser: For a given :class:`.ORMExecuteState`,
+ returns the list of shard_ids
+ where the query should be issued. Results from all shards returned
+ will be combined together into a single listing.
+
+ .. versionchanged:: 1.4 The ``execute_chooser`` parameter
+ supersedes the ``query_chooser`` parameter.
+
+ :param shards: A dictionary of string shard names
+ to :class:`~sqlalchemy.engine.Engine` objects.
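+
+        As a sketch of how these callables fit together (the shard names,
+        engines, and routing logic below are purely illustrative)::
+
+            def shard_chooser(mapper, instance, clause=None):
+                # illustrative only: send everything to one shard
+                return "shard_a"
+
+            def id_chooser(query, ident):
+                # a primary key could be present on either shard
+                return ["shard_a", "shard_b"]
+
+            def execute_chooser(orm_context):
+                # arbitrary statements are issued against all shards
+                return ["shard_a", "shard_b"]
+
+            session = ShardedSession(
+                shard_chooser=shard_chooser,
+                id_chooser=id_chooser,
+                execute_chooser=execute_chooser,
+                shards={"shard_a": engine_a, "shard_b": engine_b},
+            )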
+
+ """
+ query_chooser = kwargs.pop("query_chooser", None)
+ super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
+
+ event.listen(
+ self, "do_orm_execute", execute_and_instances, retval=True
+ )
+ self.shard_chooser = shard_chooser
+ self.id_chooser = id_chooser
+
+ if query_chooser:
+ util.warn_deprecated(
+                "The ``query_chooser`` parameter is deprecated; "
+ "please use ``execute_chooser``.",
+ "1.4",
+ )
+ if execute_chooser:
+ raise exc.ArgumentError(
+ "Can't pass query_chooser and execute_chooser "
+ "at the same time."
+ )
+
+ def execute_chooser(orm_context):
+ return query_chooser(orm_context.statement)
+
+ self.execute_chooser = execute_chooser
+ else:
+ self.execute_chooser = execute_chooser
+ self.query_chooser = query_chooser
+ self.__binds = {}
+ if shards is not None:
+ for k in shards:
+ self.bind_shard(k, shards[k])
+
+ def _identity_lookup(
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ lazy_loaded_from=None,
+ **kw
+ ):
+ """override the default :meth:`.Session._identity_lookup` method so
+ that we search for a given non-token primary key identity across all
+ possible identity tokens (e.g. shard ids).
+
+ .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
+ the :class:`_query.Query` object to the :class:`.Session`.
+
+ """
+
+ if identity_token is not None:
+ return super(ShardedSession, self)._identity_lookup(
+ mapper,
+ primary_key_identity,
+ identity_token=identity_token,
+ **kw
+ )
+ else:
+ q = self.query(mapper)
+ if lazy_loaded_from:
+ q = q._set_lazyload_from(lazy_loaded_from)
+ for shard_id in self.id_chooser(q, primary_key_identity):
+ obj = super(ShardedSession, self)._identity_lookup(
+ mapper,
+ primary_key_identity,
+ identity_token=shard_id,
+ lazy_loaded_from=lazy_loaded_from,
+ **kw
+ )
+ if obj is not None:
+ return obj
+
+ return None
+
+ def _choose_shard_and_assign(self, mapper, instance, **kw):
+ if instance is not None:
+ state = inspect(instance)
+ if state.key:
+ token = state.key[2]
+ assert token is not None
+ return token
+ elif state.identity_token:
+ return state.identity_token
+
+ shard_id = self.shard_chooser(mapper, instance, **kw)
+ if instance is not None:
+ state.identity_token = shard_id
+ return shard_id
+
+ def connection_callable(
+ self, mapper=None, instance=None, shard_id=None, **kwargs
+ ):
+ """Provide a :class:`_engine.Connection` to use in the unit of work
+ flush process.
+
+ """
+
+ if shard_id is None:
+ shard_id = self._choose_shard_and_assign(mapper, instance)
+
+ if self.in_transaction():
+ return self.get_transaction().connection(mapper, shard_id=shard_id)
+ else:
+ return self.get_bind(
+ mapper, shard_id=shard_id, instance=instance
+ ).connect(**kwargs)
+
+ def get_bind(
+ self, mapper=None, shard_id=None, instance=None, clause=None, **kw
+ ):
+ if shard_id is None:
+ shard_id = self._choose_shard_and_assign(
+ mapper, instance, clause=clause
+ )
+ return self.__binds[shard_id]
+
+ def bind_shard(self, shard_id, bind):
+ self.__binds[shard_id] = bind
+
+
+def execute_and_instances(orm_context):
+ if orm_context.is_select:
+ load_options = active_options = orm_context.load_options
+ update_options = None
+
+ elif orm_context.is_update or orm_context.is_delete:
+ load_options = None
+ update_options = active_options = orm_context.update_delete_options
+ else:
+ load_options = update_options = active_options = None
+
+ session = orm_context.session
+
+ def iter_for_shard(shard_id, load_options, update_options):
+ execution_options = dict(orm_context.local_execution_options)
+
+ bind_arguments = dict(orm_context.bind_arguments)
+ bind_arguments["shard_id"] = shard_id
+
+ if orm_context.is_select:
+ load_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_load_options"] = load_options
+ elif orm_context.is_update or orm_context.is_delete:
+ update_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_update_options"] = update_options
+
+ return orm_context.invoke_statement(
+ bind_arguments=bind_arguments, execution_options=execution_options
+ )
+
+ if active_options and active_options._refresh_identity_token is not None:
+ shard_id = active_options._refresh_identity_token
+ elif "_sa_shard_id" in orm_context.execution_options:
+ shard_id = orm_context.execution_options["_sa_shard_id"]
+ elif "shard_id" in orm_context.bind_arguments:
+ shard_id = orm_context.bind_arguments["shard_id"]
+ else:
+ shard_id = None
+
+ if shard_id is not None:
+ return iter_for_shard(shard_id, load_options, update_options)
+ else:
+ partial = []
+ for shard_id in session.execute_chooser(orm_context):
+ result_ = iter_for_shard(shard_id, load_options, update_options)
+ partial.append(result_)
+
+ return partial[0].merge(*partial[1:])
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
new file mode 100644
index 0000000..cc0aca6
--- /dev/null
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -0,0 +1,1206 @@
+# ext/hybrid.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Define attributes on ORM-mapped classes that have "hybrid" behavior.
+
+"hybrid" means the attribute has distinct behaviors defined at the
+class level and at the instance level.
+
+The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of
+method decorator and is around 50 lines of code, with almost no
+dependencies on the rest of SQLAlchemy. It can, in theory, work with
+any descriptor-based expression system.
+
+Consider a mapping ``Interval``, representing integer ``start`` and ``end``
+values. We can define higher level functions on mapped classes that produce SQL
+expressions at the class level, and Python expression evaluation at the
+instance level. Below, each function decorated with :class:`.hybrid_method` or
+:class:`.hybrid_property` may receive ``self`` as an instance of the class, or
+as the class itself::
+
+ from sqlalchemy import Column, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import Session, aliased
+ from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
+
+ Base = declarative_base()
+
+ class Interval(Base):
+ __tablename__ = 'interval'
+
+ id = Column(Integer, primary_key=True)
+ start = Column(Integer, nullable=False)
+ end = Column(Integer, nullable=False)
+
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @hybrid_method
+ def contains(self, point):
+ return (self.start <= point) & (point <= self.end)
+
+ @hybrid_method
+ def intersects(self, other):
+ return self.contains(other.start) | self.contains(other.end)
+
+Above, the ``length`` property returns the difference between the
+``end`` and ``start`` attributes. With an instance of ``Interval``,
+this subtraction occurs in Python, using normal Python descriptor
+mechanics::
+
+ >>> i1 = Interval(5, 10)
+ >>> i1.length
+ 5
+
+When dealing with the ``Interval`` class itself, the :class:`.hybrid_property`
+descriptor evaluates the function body given the ``Interval`` class as
+the argument, which when evaluated with SQLAlchemy expression mechanics
+(here using the :attr:`.QueryableAttribute.expression` accessor)
+returns a new SQL expression::
+
+ >>> print(Interval.length.expression)
+ interval."end" - interval.start
+
+ >>> print(Session().query(Interval).filter(Interval.length > 10))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval."end" - interval.start > :param_1
+
+ORM methods such as :meth:`_query.Query.filter_by`
+generally use ``getattr()`` to
+locate attributes, so can also be used with hybrid attributes::
+
+ >>> print(Session().query(Interval).filter_by(length=5))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval."end" - interval.start = :param_1
+
+The ``Interval`` class example also illustrates two methods,
+``contains()`` and ``intersects()``, decorated with
+:class:`.hybrid_method`. This decorator applies the same idea to
+methods that :class:`.hybrid_property` applies to attributes. The
+methods return boolean values, and take advantage of the Python ``|``
+and ``&`` bitwise operators to produce equivalent instance-level and
+SQL expression-level boolean behavior::
+
+ >>> i1.contains(6)
+ True
+ >>> i1.contains(15)
+ False
+ >>> i1.intersects(Interval(7, 18))
+ True
+ >>> i1.intersects(Interval(25, 29))
+ False
+
+ >>> print(Session().query(Interval).filter(Interval.contains(15)))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval.start <= :start_1 AND interval."end" > :end_1
+
+ >>> ia = aliased(Interval)
+ >>> print(Session().query(Interval, ia).filter(Interval.intersects(ia)))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end, interval_1.id AS interval_1_id,
+ interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end
+ FROM interval, interval AS interval_1
+ WHERE interval.start <= interval_1.start
+ AND interval."end" > interval_1.start
+ OR interval.start <= interval_1."end"
+ AND interval."end" > interval_1."end"
+
+.. _hybrid_distinct_expression:
+
+Defining Expression Behavior Distinct from Attribute Behavior
+--------------------------------------------------------------
+
+Our usage of the ``&`` and ``|`` bitwise operators above was
+fortunate, considering our functions operated on two boolean values to
+return a new one. In many cases, the construction of an in-Python
+function and a SQLAlchemy SQL expression have enough differences that
+two separate Python expressions should be defined. The
+:mod:`~sqlalchemy.ext.hybrid` decorators define the
+:meth:`.hybrid_property.expression` modifier for this purpose. As an
+example we'll define the radius of the interval, which requires the
+usage of the absolute value function::
+
+ from sqlalchemy import func
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def radius(self):
+ return abs(self.length) / 2
+
+ @radius.expression
+ def radius(cls):
+ return func.abs(cls.length) / 2
+
+Above, the Python function ``abs()`` is used for instance-level
+operations, while the SQL function ``ABS()`` is used via the :data:`.func`
+object for class-level expressions::
+
+ >>> i1.radius
+ 2
+
+ >>> print(Session().query(Interval).filter(Interval.radius > 5))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1
+
+.. note:: When defining an expression for a hybrid property or method, the
+ expression method **must** retain the name of the original hybrid, else
+ the new hybrid with the additional state will be attached to the class
+ with the non-matching name. To use the example above::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def radius(self):
+ return abs(self.length) / 2
+
+ # WRONG - the non-matching name will cause this function to be
+ # ignored
+ @radius.expression
+ def radius_expression(cls):
+ return func.abs(cls.length) / 2
+
+ This is also true for other mutator methods, such as
+ :meth:`.hybrid_property.update_expression`. This is the same behavior
+ as that of the ``@property`` construct that is part of standard Python.
+
+Defining Setters
+----------------
+
+Hybrid properties can also define setter methods. If we wanted
+``length`` above, when set, to modify the endpoint value::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @length.setter
+ def length(self, value):
+ self.end = self.start + value
+
+The ``length(self, value)`` method is now called upon set::
+
+ >>> i1 = Interval(5, 10)
+ >>> i1.length
+ 5
+ >>> i1.length = 12
+ >>> i1.end
+ 17
+
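+A deleter may be assigned in the same way, using the ``@length.deleter``
+modifier; as a sketch (what deletion should mean for ``length`` is an
+arbitrary choice made here only for illustration)::
+
+    class Interval(object):
+        # ...
+
+        @length.deleter
+        def length(self):
+            # collapse the interval onto its start point
+            self.end = self.start
+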
+.. _hybrid_bulk_update:
+
+Allowing Bulk ORM Update
+------------------------
+
+A hybrid can define a custom "UPDATE" handler for when using the
+:meth:`_query.Query.update` method, allowing the hybrid to be used in the
+SET clause of the update.
+
+Normally, when using a hybrid with :meth:`_query.Query.update`, the SQL
+expression is used as the column that's the target of the SET. If our
+``Interval`` class had a hybrid ``start_point`` that linked to
+``Interval.start``, this could be substituted directly::
+
+ session.query(Interval).update({Interval.start_point: 10})
+
+However, when using a composite hybrid like ``Interval.length``, this
+hybrid represents more than one column. We can set up a handler that will
+accommodate a value passed to :meth:`_query.Query.update` which can affect
+this, using the :meth:`.hybrid_property.update_expression` decorator.
+A handler that works similarly to our setter would be::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @length.setter
+ def length(self, value):
+ self.end = self.start + value
+
+ @length.update_expression
+ def length(cls, value):
+ return [
+ (cls.end, cls.start + value)
+ ]
+
+Above, if we use ``Interval.length`` in an UPDATE expression as::
+
+ session.query(Interval).update(
+ {Interval.length: 25}, synchronize_session='fetch')
+
+We'll get an UPDATE statement along the lines of::
+
+ UPDATE interval SET end=start + :value
+
+In some cases, the default "evaluate" strategy can't perform the SET
+expression in Python; while the addition operator we're using above
+is supported, for more complex SET expressions it will usually be necessary
+to use either the "fetch" or False synchronization strategy as illustrated
+above.
+
+.. note:: For ORM bulk updates to work with hybrids, the function name
+   of the hybrid must match the name under which it is accessed. Something
+ like this wouldn't work::
+
+ class Interval(object):
+ # ...
+
+ def _get(self):
+ return self.end - self.start
+
+ def _set(self, value):
+ self.end = self.start + value
+
+ def _update_expr(cls, value):
+ return [
+ (cls.end, cls.start + value)
+ ]
+
+ length = hybrid_property(
+ fget=_get, fset=_set, update_expr=_update_expr
+ )
+
+ The Python descriptor protocol does not provide any reliable way for
+ a descriptor to know what attribute name it was accessed as, and
+ the UPDATE scheme currently relies upon being able to access the
+ attribute from an instance by name in order to perform the instance
+ synchronization step.
+
+.. versionadded:: 1.2 added support for bulk updates to hybrid properties.
+
+Working with Relationships
+--------------------------
+
+There's no essential difference when creating hybrids that work with
+related objects as opposed to column-based data. The need for distinct
+expressions tends to be greater. The two variants we'll illustrate
+are the "join-dependent" hybrid, and the "correlated subquery" hybrid.
+
+Join-Dependent Relationship Hybrid
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Consider the following declarative
+mapping which relates a ``User`` to a ``SavingsAccount``::
+
+ from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ Base = declarative_base()
+
+ class SavingsAccount(Base):
+ __tablename__ = 'account'
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
+ balance = Column(Numeric(15, 5))
+
+ class User(Base):
+ __tablename__ = 'user'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+
+ accounts = relationship("SavingsAccount", backref="owner")
+
+ @hybrid_property
+ def balance(self):
+ if self.accounts:
+ return self.accounts[0].balance
+ else:
+ return None
+
+ @balance.setter
+ def balance(self, value):
+ if not self.accounts:
+                account = SavingsAccount(owner=self)
+ else:
+ account = self.accounts[0]
+ account.balance = value
+
+ @balance.expression
+ def balance(cls):
+ return SavingsAccount.balance
+
+The above hybrid property ``balance`` works with the first
+``SavingsAccount`` entry in the list of accounts for this user. The
+in-Python getter/setter methods can treat ``accounts`` as a Python
+list available on ``self``.
+
+However, at the expression level, it's expected that the ``User`` class will
+be used in a context such that an appropriate join to
+``SavingsAccount`` will be present::
+
+ >>> print(Session().query(User, User.balance).
+ ... join(User.accounts).filter(User.balance > 5000))
+ SELECT "user".id AS user_id, "user".name AS user_name,
+ account.balance AS account_balance
+ FROM "user" JOIN account ON "user".id = account.user_id
+ WHERE account.balance > :balance_1
+
+Note however, that while the instance level accessors need to worry
+about whether ``self.accounts`` is even present, this issue expresses
+itself differently at the SQL expression level, where we basically
+would use an outer join::
+
+ >>> from sqlalchemy import or_
+ >>> print (Session().query(User, User.balance).outerjoin(User.accounts).
+ ... filter(or_(User.balance < 5000, User.balance == None)))
+ SELECT "user".id AS user_id, "user".name AS user_name,
+ account.balance AS account_balance
+ FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id
+ WHERE account.balance < :balance_1 OR account.balance IS NULL
+
+Correlated Subquery Relationship Hybrid
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We can, of course, forego being dependent on the enclosing query's usage
+of joins in favor of the correlated subquery, which can portably be packed
+into a single column expression. A correlated subquery is more portable, but
+often performs more poorly at the SQL level. Using the same technique
+illustrated at :ref:`mapper_column_property_sql_expressions`,
+we can adjust our ``SavingsAccount`` example to aggregate the balances for
+*all* accounts, and use a correlated subquery for the column expression::
+
+ from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.hybrid import hybrid_property
+ from sqlalchemy import select, func
+
+ Base = declarative_base()
+
+ class SavingsAccount(Base):
+ __tablename__ = 'account'
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
+ balance = Column(Numeric(15, 5))
+
+ class User(Base):
+ __tablename__ = 'user'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+
+ accounts = relationship("SavingsAccount", backref="owner")
+
+ @hybrid_property
+ def balance(self):
+ return sum(acc.balance for acc in self.accounts)
+
+ @balance.expression
+ def balance(cls):
+ return select(func.sum(SavingsAccount.balance)).\
+ where(SavingsAccount.user_id==cls.id).\
+ label('total_balance')
+
+The above recipe will give us the ``balance`` column which renders
+a correlated SELECT::
+
+ >>> print(s.query(User).filter(User.balance > 400))
+ SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE (SELECT sum(account.balance) AS sum_1
+ FROM account
+ WHERE account.user_id = "user".id) > :param_1
+
+.. _hybrid_custom_comparators:
+
+Building Custom Comparators
+---------------------------
+
+The hybrid property also includes a helper that allows construction of
+custom comparators. A comparator object allows one to customize the
+behavior of each SQLAlchemy expression operator individually. They
+are useful when creating custom types that have some highly
+idiosyncratic behavior on the SQL side.
+
+.. note:: The :meth:`.hybrid_property.comparator` decorator introduced
+ in this section **replaces** the use of the
+ :meth:`.hybrid_property.expression` decorator.
+ They cannot be used together.
+
+The example class below allows case-insensitive comparisons on the attribute
+named ``word_insensitive``::
+
+ from sqlalchemy.ext.hybrid import Comparator, hybrid_property
+ from sqlalchemy import func, Column, Integer, String
+ from sqlalchemy.orm import Session
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class CaseInsensitiveComparator(Comparator):
+ def __eq__(self, other):
+ return func.lower(self.__clause_element__()) == func.lower(other)
+
+ class SearchWord(Base):
+ __tablename__ = 'searchword'
+ id = Column(Integer, primary_key=True)
+ word = Column(String(255), nullable=False)
+
+ @hybrid_property
+ def word_insensitive(self):
+ return self.word.lower()
+
+ @word_insensitive.comparator
+ def word_insensitive(cls):
+ return CaseInsensitiveComparator(cls.word)
+
+Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()``
+SQL function to both sides::
+
+ >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+ SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+ FROM searchword
+ WHERE lower(searchword.word) = lower(:lower_1)
+
+The ``CaseInsensitiveComparator`` above implements part of the
+:class:`.ColumnOperators` interface. A "coercion" operation like
+lowercasing can be applied to all comparison operations (i.e. ``eq``,
+``lt``, ``gt``, etc.) using :meth:`.Operators.operate`::
+
+ class CaseInsensitiveComparator(Comparator):
+ def operate(self, op, other, **kwargs):
+ return op(
+ func.lower(self.__clause_element__()),
+ func.lower(other),
+ **kwargs,
+ )
+
+.. _hybrid_reuse_subclass:
+
+Reusing Hybrid Properties across Subclasses
+-------------------------------------------
+
+A hybrid can be referred to from a superclass, to allow modifying
+methods like :meth:`.hybrid_property.getter`, :meth:`.hybrid_property.setter`
+to be used to redefine those methods on a subclass. This is similar to
+how the standard Python ``@property`` object works::
+
+ class FirstNameOnly(Base):
+ # ...
+
+ first_name = Column(String)
+
+ @hybrid_property
+ def name(self):
+ return self.first_name
+
+ @name.setter
+ def name(self, value):
+ self.first_name = value
+
+ class FirstNameLastName(FirstNameOnly):
+ # ...
+
+ last_name = Column(String)
+
+ @FirstNameOnly.name.getter
+ def name(self):
+ return self.first_name + ' ' + self.last_name
+
+ @name.setter
+ def name(self, value):
+ self.first_name, self.last_name = value.split(' ', 1)
+
+Above, the ``FirstNameLastName`` class refers to the hybrid from
+``FirstNameOnly.name`` to repurpose its getter and setter for the subclass.
+
+When overriding :meth:`.hybrid_property.expression` and
+:meth:`.hybrid_property.comparator` alone as the first reference to the
+superclass, these names conflict with the same-named accessors on the
+class-level :class:`.QueryableAttribute` object returned at the class level.
+To override these methods when referring directly to the parent class
+descriptor, add the special qualifier :attr:`.hybrid_property.overrides`,
+which will de-reference the instrumented attribute back to the hybrid object::
+
+ class FirstNameLastName(FirstNameOnly):
+ # ...
+
+ last_name = Column(String)
+
+ @FirstNameOnly.name.overrides.expression
+ def name(cls):
+ return func.concat(cls.first_name, ' ', cls.last_name)
+
+.. versionadded:: 1.2 Added :meth:`.hybrid_property.getter` as well as the
+ ability to redefine accessors per-subclass.
+
+
+Hybrid Value Objects
+--------------------
+
+Note in our previous example, if we were to compare the ``word_insensitive``
+attribute of a ``SearchWord`` instance to a plain Python string, the plain
+Python string would not be coerced to lower case - the
+``CaseInsensitiveComparator`` we built, being returned by
+``@word_insensitive.comparator``, only applies to the SQL side.
+
+A more comprehensive form of the custom comparator is to construct a *Hybrid
+Value Object*. This technique applies the target value or expression to a value
+object which is then returned by the accessor in all cases. The value object
+allows control of all operations upon the value as well as how compared values
+are treated, both on the SQL expression side as well as the Python value side.
+Replacing the previous ``CaseInsensitiveComparator`` class with a new
+``CaseInsensitiveWord`` class::
+
+ class CaseInsensitiveWord(Comparator):
+ "Hybrid value representing a lower case representation of a word."
+
+ def __init__(self, word):
+            if isinstance(word, str):
+ self.word = word.lower()
+ elif isinstance(word, CaseInsensitiveWord):
+ self.word = word.word
+ else:
+ self.word = func.lower(word)
+
+ def operate(self, op, other, **kwargs):
+ if not isinstance(other, CaseInsensitiveWord):
+ other = CaseInsensitiveWord(other)
+ return op(self.word, other.word, **kwargs)
+
+ def __clause_element__(self):
+ return self.word
+
+ def __str__(self):
+ return self.word
+
+ key = 'word'
+ "Label to apply to Query tuple results"
+
+Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may
+be a SQL function, or may be a Python native. By overriding ``operate()`` and
+``__clause_element__()`` to work in terms of ``self.word``, all comparison
+operations will work against the "converted" form of ``word``, whether it be
+SQL side or Python side. Our ``SearchWord`` class can now deliver the
+``CaseInsensitiveWord`` object unconditionally from a single hybrid call::
+
+ class SearchWord(Base):
+ __tablename__ = 'searchword'
+ id = Column(Integer, primary_key=True)
+ word = Column(String(255), nullable=False)
+
+ @hybrid_property
+ def word_insensitive(self):
+ return CaseInsensitiveWord(self.word)
+
+The ``word_insensitive`` attribute now has case-insensitive comparison behavior
+universally, including SQL expression vs. Python expression (note the Python
+value is converted to lower case on the Python side here)::
+
+ >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+ SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+ FROM searchword
+ WHERE lower(searchword.word) = :lower_1
+
+SQL expression versus SQL expression::
+
+ >>> sw1 = aliased(SearchWord)
+ >>> sw2 = aliased(SearchWord)
+ >>> print(Session().query(
+ ... sw1.word_insensitive,
+ ... sw2.word_insensitive).\
+ ... filter(
+ ... sw1.word_insensitive > sw2.word_insensitive
+ ... ))
+ SELECT lower(searchword_1.word) AS lower_1,
+ lower(searchword_2.word) AS lower_2
+ FROM searchword AS searchword_1, searchword AS searchword_2
+ WHERE lower(searchword_1.word) > lower(searchword_2.word)
+
+Python only expression::
+
+ >>> ws1 = SearchWord(word="SomeWord")
+ >>> ws1.word_insensitive == "sOmEwOrD"
+ True
+ >>> ws1.word_insensitive == "XOmEwOrX"
+ False
+ >>> print(ws1.word_insensitive)
+ someword
+
+The Hybrid Value pattern is very useful for any kind of value that may have
+multiple representations, such as timestamps, time deltas, units of
+measurement, currencies and encrypted passwords.
+
+.. seealso::
+
+ `Hybrids and Value Agnostic Types
+ <https://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/>`_
+ - on the techspot.zzzeek.org blog
+
+ `Value Agnostic Types, Part II
+ <https://techspot.zzzeek.org/2011/10/29/value-agnostic-types-part-ii/>`_ -
+ on the techspot.zzzeek.org blog
+
+.. _hybrid_transformers:
+
+Building Transformers
+----------------------
+
+A *transformer* is an object which can receive a :class:`_query.Query`
+object and
+return a new one. The :class:`_query.Query` object includes a method
+:meth:`.with_transformation` that returns a new :class:`_query.Query`
+transformed by
+the given function.
+
+We can combine this with the :class:`.Comparator` class to produce one type
+of recipe which can both set up the FROM clause of a query as well as assign
+filtering criterion.
+
+Consider a mapped class ``Node``, which assembles using adjacency list into a
+hierarchical tree pattern::
+
+ from sqlalchemy import Column, Integer, ForeignKey
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class Node(Base):
+ __tablename__ = 'node'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('node.id'))
+ parent = relationship("Node", remote_side=id)
+
+Suppose we wanted to add an accessor ``grandparent``. This would return the
+``parent`` of ``Node.parent``. When we have an instance of ``Node``, this is
+simple::
+
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ class Node(Base):
+ # ...
+
+ @hybrid_property
+ def grandparent(self):
+ return self.parent.parent
+
+For the expression, things are not so clear. We'd need to construct a
+:class:`_query.Query` where we :meth:`_query.Query.join` twice along
+``Node.parent`` to get to the ``grandparent``. We can instead return a
+transforming callable that we'll combine with the :class:`.Comparator` class to
+receive any :class:`_query.Query` object, and return a new one that's joined to
+the ``Node.parent`` attribute and filtered based on the given criterion::
+
+ from sqlalchemy.ext.hybrid import Comparator
+
+ class GrandparentTransformer(Comparator):
+ def operate(self, op, other, **kwargs):
+ def transform(q):
+ cls = self.__clause_element__()
+ parent_alias = aliased(cls)
+ return q.join(parent_alias, cls.parent).filter(
+ op(parent_alias.parent, other, **kwargs)
+ )
+
+ return transform
+
+ Base = declarative_base()
+
+ class Node(Base):
+ __tablename__ = 'node'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('node.id'))
+ parent = relationship("Node", remote_side=id)
+
+ @hybrid_property
+ def grandparent(self):
+ return self.parent.parent
+
+ @grandparent.comparator
+ def grandparent(cls):
+ return GrandparentTransformer(cls)
+
+The ``GrandparentTransformer`` overrides the core :meth:`.Operators.operate`
+method at the base of the :class:`.Comparator` hierarchy to return a query-
+transforming callable, which then runs the given comparison operation in a
+particular context. In the example above, the ``operate`` method is
+called, given the :attr:`.Operators.eq` callable as well as the right side of
+the comparison ``Node(id=5)``. A function ``transform`` is then returned which
+will transform a :class:`_query.Query` first to join to ``Node.parent``,
+then to
+compare ``parent_alias`` using :attr:`.Operators.eq` against the left and right
+sides, passing into :meth:`_query.Query.filter`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.orm import Session
+ >>> session = Session()
+ {sql}>>> session.query(Node).\
+ ... with_transformation(Node.grandparent==Node(id=5)).\
+ ... all()
+ SELECT node.id AS node_id, node.parent_id AS node_parent_id
+ FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+ WHERE :param_1 = node_1.parent_id
+ {stop}
+
+We can modify the pattern to be more verbose but flexible by separating the
+"join" step from the "filter" step. The tricky part here is ensuring that
+successive instances of ``GrandparentTransformer`` use the same
+:class:`.AliasedClass` object against ``Node``. Below we use a simple
+memoizing approach that associates a ``GrandparentTransformer`` with each
+class::
+
+ class Node(Base):
+
+ # ...
+
+ @grandparent.comparator
+ def grandparent(cls):
+ # memoize a GrandparentTransformer
+ # per class
+ if '_gp' not in cls.__dict__:
+ cls._gp = GrandparentTransformer(cls)
+ return cls._gp
+
+ class GrandparentTransformer(Comparator):
+
+ def __init__(self, cls):
+ self.parent_alias = aliased(cls)
+
+ @property
+ def join(self):
+ def go(q):
+ return q.join(self.parent_alias, Node.parent)
+ return go
+
+ def operate(self, op, other, **kwargs):
+ return op(self.parent_alias.parent, other, **kwargs)
+
+.. sourcecode:: pycon+sql
+
+ {sql}>>> session.query(Node).\
+ ... with_transformation(Node.grandparent.join).\
+ ... filter(Node.grandparent==Node(id=5))
+ SELECT node.id AS node_id, node.parent_id AS node_parent_id
+ FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+ WHERE :param_1 = node_1.parent_id
+ {stop}
+
+The "transformer" pattern is an experimental pattern that starts to make usage
+of some functional programming paradigms. While it's only recommended for
+advanced and/or patient developers, there's probably a whole lot of amazing
+things it can be used for.
+
+""" # noqa
+from .. import util
+from ..orm import attributes
+from ..orm import interfaces
+
+HYBRID_METHOD = util.symbol("HYBRID_METHOD")
+"""Symbol indicating an :class:`InspectionAttr` that's
+ of type :class:`.hybrid_method`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ .. seealso::
+
+    :attr:`_orm.Mapper.all_orm_descriptors`
+
+"""
+
+HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY")
+"""Symbol indicating an :class:`InspectionAttr` that's
+    of type :class:`.hybrid_property`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ .. seealso::
+
+    :attr:`_orm.Mapper.all_orm_descriptors`
+
+"""
+
+
+class hybrid_method(interfaces.InspectionAttrInfo):
+ """A decorator which allows definition of a Python object method with both
+ instance-level and class-level behavior.
+
+ """
+
+ is_attribute = True
+ extension_type = HYBRID_METHOD
+
+ def __init__(self, func, expr=None):
+ """Create a new :class:`.hybrid_method`.
+
+ Usage is typically via decorator::
+
+ from sqlalchemy.ext.hybrid import hybrid_method
+
+ class SomeClass(object):
+ @hybrid_method
+ def value(self, x, y):
+ return self._value + x + y
+
+                @value.expression
+                def value(cls, x, y):
+                    return func.some_function(cls._value, x, y)
+
+ """
+ self.func = func
+ self.expression(expr or func)
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.expr.__get__(owner, owner.__class__)
+ else:
+ return self.func.__get__(instance, owner)
+
+ def expression(self, expr):
+ """Provide a modifying decorator that defines a
+ SQL-expression producing method."""
+
+ self.expr = expr
+ if not self.expr.__doc__:
+ self.expr.__doc__ = self.func.__doc__
+ return self
+
+
+class hybrid_property(interfaces.InspectionAttrInfo):
+ """A decorator which allows definition of a Python descriptor with both
+ instance-level and class-level behavior.
+
+ """
+
+ is_attribute = True
+ extension_type = HYBRID_PROPERTY
+
+ def __init__(
+ self,
+ fget,
+ fset=None,
+ fdel=None,
+ expr=None,
+ custom_comparator=None,
+ update_expr=None,
+ ):
+ """Create a new :class:`.hybrid_property`.
+
+ Usage is typically via decorator::
+
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ class SomeClass(object):
+ @hybrid_property
+ def value(self):
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ self._value = value
+
+ """
+ self.fget = fget
+ self.fset = fset
+ self.fdel = fdel
+ self.expr = expr
+ self.custom_comparator = custom_comparator
+ self.update_expr = update_expr
+ util.update_wrapper(self, fget)
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self._expr_comparator(owner)
+ else:
+ return self.fget(instance)
+
+ def __set__(self, instance, value):
+ if self.fset is None:
+ raise AttributeError("can't set attribute")
+ self.fset(instance, value)
+
+ def __delete__(self, instance):
+ if self.fdel is None:
+ raise AttributeError("can't delete attribute")
+ self.fdel(instance)
+
+ def _copy(self, **kw):
+ defaults = {
+ key: value
+ for key, value in self.__dict__.items()
+ if not key.startswith("_")
+ }
+ defaults.update(**kw)
+ return type(self)(**defaults)
+
+ @property
+ def overrides(self):
+ """Prefix for a method that is overriding an existing attribute.
+
+ The :attr:`.hybrid_property.overrides` accessor just returns
+ this hybrid object, which when called at the class level from
+ a parent class, will de-reference the "instrumented attribute"
+ normally returned at this level, and allow modifying decorators
+ like :meth:`.hybrid_property.expression` and
+ :meth:`.hybrid_property.comparator`
+ to be used without conflicting with the same-named attributes
+ normally present on the :class:`.QueryableAttribute`::
+
+ class SuperClass(object):
+ # ...
+
+ @hybrid_property
+ def foobar(self):
+ return self._foobar
+
+ class SubClass(SuperClass):
+ # ...
+
+ @SuperClass.foobar.overrides.expression
+ def foobar(cls):
+ return func.subfoobar(self._foobar)
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`hybrid_reuse_subclass`
+
+ """
+ return self
+
+ def getter(self, fget):
+ """Provide a modifying decorator that defines a getter method.
+
+ .. versionadded:: 1.2
+
+ """
+
+ return self._copy(fget=fget)
+
+ def setter(self, fset):
+ """Provide a modifying decorator that defines a setter method."""
+
+ return self._copy(fset=fset)
+
+ def deleter(self, fdel):
+ """Provide a modifying decorator that defines a deletion method."""
+
+ return self._copy(fdel=fdel)
+
+ def expression(self, expr):
+ """Provide a modifying decorator that defines a SQL-expression
+ producing method.
+
+ When a hybrid is invoked at the class level, the SQL expression given
+ here is wrapped inside of a specialized :class:`.QueryableAttribute`,
+ which is the same kind of object used by the ORM to represent other
+ mapped attributes. The reason for this is so that other class-level
+ attributes such as docstrings and a reference to the hybrid itself may
+ be maintained within the structure that's returned, without any
+ modifications to the original SQL expression passed in.
+
+ .. note::
+
+ When referring to a hybrid property from an owning class (e.g.
+ ``SomeClass.some_hybrid``), an instance of
+ :class:`.QueryableAttribute` is returned, representing the
+ expression or comparator object as well as this hybrid object.
+ However, that object itself has accessors called ``expression`` and
+ ``comparator``; so when attempting to override these decorators on a
+ subclass, it may be necessary to qualify it using the
+ :attr:`.hybrid_property.overrides` modifier first. See that
+ modifier for details.
+
+ .. seealso::
+
+ :ref:`hybrid_distinct_expression`
+
+ """
+
+ return self._copy(expr=expr)
+
+ def comparator(self, comparator):
+ """Provide a modifying decorator that defines a custom
+ comparator producing method.
+
+ The return value of the decorated method should be an instance of
+ :class:`~.hybrid.Comparator`.
+
+ .. note:: The :meth:`.hybrid_property.comparator` decorator
+ **replaces** the use of the :meth:`.hybrid_property.expression`
+ decorator. They cannot be used together.
+
+ When a hybrid is invoked at the class level, the
+ :class:`~.hybrid.Comparator` object given here is wrapped inside of a
+ specialized :class:`.QueryableAttribute`, which is the same kind of
+ object used by the ORM to represent other mapped attributes. The
+ reason for this is so that other class-level attributes such as
+ docstrings and a reference to the hybrid itself may be maintained
+ within the structure that's returned, without any modifications to the
+ original comparator object passed in.
+
+ .. note::
+
+ When referring to a hybrid property from an owning class (e.g.
+ ``SomeClass.some_hybrid``), an instance of
+ :class:`.QueryableAttribute` is returned, representing the
+           expression or comparator object as well as this hybrid object.
+           However, that object itself has accessors called ``expression`` and
+ ``comparator``; so when attempting to override these decorators on a
+ subclass, it may be necessary to qualify it using the
+ :attr:`.hybrid_property.overrides` modifier first. See that
+ modifier for details.
+
+ """
+ return self._copy(custom_comparator=comparator)
+
+ def update_expression(self, meth):
+ """Provide a modifying decorator that defines an UPDATE tuple
+ producing method.
+
+ The method accepts a single value, which is the value to be
+ rendered into the SET clause of an UPDATE statement. The method
+ should then process this value into individual column expressions
+ that fit into the ultimate SET clause, and return them as a
+ sequence of 2-tuples. Each tuple
+ contains a column expression as the key and a value to be rendered.
+
+ E.g.::
+
+ class Person(Base):
+ # ...
+
+ first_name = Column(String)
+ last_name = Column(String)
+
+ @hybrid_property
+ def fullname(self):
+                    return self.first_name + " " + self.last_name
+
+ @fullname.update_expression
+ def fullname(cls, value):
+ fname, lname = value.split(" ", 1)
+ return [
+ (cls.first_name, fname),
+ (cls.last_name, lname)
+ ]
+
+ .. versionadded:: 1.2
+
+ """
+ return self._copy(update_expr=meth)
+
+ @util.memoized_property
+ def _expr_comparator(self):
+ if self.custom_comparator is not None:
+ return self._get_comparator(self.custom_comparator)
+ elif self.expr is not None:
+ return self._get_expr(self.expr)
+ else:
+ return self._get_expr(self.fget)
+
+ def _get_expr(self, expr):
+ def _expr(cls):
+ return ExprComparator(cls, expr(cls), self)
+
+ util.update_wrapper(_expr, expr)
+
+ return self._get_comparator(_expr)
+
+ def _get_comparator(self, comparator):
+
+ proxy_attr = attributes.create_proxied_attribute(self)
+
+ def expr_comparator(owner):
+ # because this is the descriptor protocol, we don't really know
+ # what our attribute name is. so search for it through the
+ # MRO.
+ for lookup in owner.__mro__:
+ if self.__name__ in lookup.__dict__:
+ if lookup.__dict__[self.__name__] is self:
+ name = self.__name__
+ break
+ else:
+ name = attributes.NO_KEY
+
+ return proxy_attr(
+ owner,
+ name,
+ self,
+ comparator(owner),
+ doc=comparator.__doc__ or self.__doc__,
+ )
+
+ return expr_comparator
+
+
+class Comparator(interfaces.PropComparator):
+ """A helper class that allows easy construction of custom
+ :class:`~.orm.interfaces.PropComparator`
+ classes for usage with hybrids."""
+
+ property = None
+
+ def __init__(self, expression):
+ self.expression = expression
+
+ def __clause_element__(self):
+ expr = self.expression
+ if hasattr(expr, "__clause_element__"):
+ expr = expr.__clause_element__()
+ return expr
+
+ def adapt_to_entity(self, adapt_to_entity):
+ # interesting....
+ return self
+
+
+class ExprComparator(Comparator):
+ def __init__(self, cls, expression, hybrid):
+ self.cls = cls
+ self.expression = expression
+ self.hybrid = hybrid
+
+ def __getattr__(self, key):
+ return getattr(self.expression, key)
+
+ @property
+ def info(self):
+ return self.hybrid.info
+
+ def _bulk_update_tuples(self, value):
+ if isinstance(self.expression, attributes.QueryableAttribute):
+ return self.expression._bulk_update_tuples(value)
+ elif self.hybrid.update_expr is not None:
+ return self.hybrid.update_expr(self.cls, value)
+ else:
+ return [(self.expression, value)]
+
+ @property
+ def property(self):
+ return self.expression.property
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.expression, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.expression, **kwargs)
diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py
new file mode 100644
index 0000000..7cbac54
--- /dev/null
+++ b/lib/sqlalchemy/ext/indexable.py
@@ -0,0 +1,352 @@
+# ext/indexable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define attributes on ORM-mapped classes that have "index" attributes for
+columns with :class:`_types.Indexable` types.
+
+"index" means the attribute is associated with an element of an
+:class:`_types.Indexable` column with the predefined index to access it.
+The :class:`_types.Indexable` types include types such as
+:class:`_types.ARRAY`, :class:`_types.JSON` and
+:class:`_postgresql.HSTORE`.
+
+
+
+The :mod:`~sqlalchemy.ext.indexable` extension provides a
+:class:`_schema.Column`-like interface for any element of an
+:class:`_types.Indexable` typed column. In simple cases, it can be
+treated as a :class:`_schema.Column` - mapped attribute.
+
+
+.. versionadded:: 1.1
+
+Synopsis
+========
+
+Given ``Person`` as a model with a primary key and a JSON ``data`` field.
+While this field may have any number of elements encoded within it,
+we would like to refer to the element called ``name`` individually
+as a dedicated attribute which behaves like a standalone column::
+
+ from sqlalchemy import Column, JSON, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.indexable import index_property
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ name = index_property('data', 'name')
+
+
+Above, the ``name`` attribute now behaves like a mapped column. We
+can compose a new ``Person`` and set the value of ``name``::
+
+ >>> person = Person(name='Alchemist')
+
+The value is now accessible::
+
+ >>> person.name
+ 'Alchemist'
+
+Behind the scenes, the JSON field was initialized to a new blank dictionary
+and the field was set::
+
+ >>> person.data
+    {'name': 'Alchemist'}
+
+The field is mutable in place::
+
+ >>> person.name = 'Renamed'
+ >>> person.name
+ 'Renamed'
+ >>> person.data
+ {'name': 'Renamed'}
+
+When using :class:`.index_property`, the change that we make to the indexable
+structure is also automatically tracked as history; we no longer need
+to use :class:`~.mutable.MutableDict` in order to track this change
+for the unit of work.
+
+Deletions work normally as well::
+
+ >>> del person.name
+ >>> person.data
+ {}
+
+Above, deletion of ``person.name`` deletes the value from the dictionary,
+but not the dictionary itself.
+
+A missing key will produce ``AttributeError``::
+
+ >>> person = Person()
+ >>> person.name
+ ...
+ AttributeError: 'name'
+
+Unless you set a default value::
+
+    >>> class Person(Base):
+    ...     __tablename__ = 'person'
+    ...
+    ...     id = Column(Integer, primary_key=True)
+    ...     data = Column(JSON)
+    ...
+    ...     name = index_property('data', 'name', default=None)  # See default
+
+ >>> person = Person()
+ >>> print(person.name)
+ None
+
+
+The attributes are also accessible at the class level.
+Below, we illustrate ``Person.name`` used to generate
+an indexed SQL criterion::
+
+ >>> from sqlalchemy.orm import Session
+ >>> session = Session()
+ >>> query = session.query(Person).filter(Person.name == 'Alchemist')
+
+The above query is equivalent to::
+
+ >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
+
+Multiple :class:`.index_property` objects can be chained to produce
+multiple levels of indexing::
+
+ from sqlalchemy import Column, JSON, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.indexable import index_property
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ birthday = index_property('data', 'birthday')
+ year = index_property('birthday', 'year')
+ month = index_property('birthday', 'month')
+ day = index_property('birthday', 'day')
+
+With the above mapping, a query such as::
+
+ q = session.query(Person).filter(Person.year == '1980')
+
+On a PostgreSQL backend, the above query will render as::
+
+ SELECT person.id, person.data
+ FROM person
+ WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
+
+Default Values
+==============
+
+:class:`.index_property` includes special behaviors for when the indexed
+data structure does not exist, and a set operation is called:
+
+* For an :class:`.index_property` that is given an integer index value,
+  the default data structure will be a Python list of ``None`` values,
+  at least as long as the index value; the value is then set at its
+  place in the list. This means for an index value of zero, the list
+  will be initialized to ``[None]`` before setting the given value,
+  and for an index value of five, the list will be initialized to
+  ``[None, None, None, None, None, None]`` before setting the element
+  at index five to the given value. Note that an existing list is **not**
+  extended in place to receive a value; a short sketch of this behavior
+  follows the list below.
+
+* for an :class:`.index_property` that is given any other kind of index
+ value (e.g. strings usually), a Python dictionary is used as the
+ default data structure.
+
+* The default data structure can be set to any Python callable using the
+ :paramref:`.index_property.datatype` parameter, overriding the previous
+ rules.
+
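+A minimal, illustrative sketch of the integer-index default described in the
+first bullet above (the ``Article`` mapping and its ``first_tag`` attribute
+are illustrative names only)::
+
+    class Article(Base):
+        __tablename__ = 'article'
+
+        id = Column(Integer, primary_key=True)
+        tags = Column(JSON)
+
+        first_tag = index_property('tags', 0)
+
+    article = Article()
+    article.first_tag = 'python'  # initializes article.tags to [None],
+                                  # then sets element 0 -> ['python']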
+
+Subclassing
+===========
+
+:class:`.index_property` can be subclassed, in particular for the common
+use case of providing coercion of values or SQL expressions as they are
+accessed. Below is a common recipe for use with a PostgreSQL JSON type,
+where we want to also include automatic casting plus ``astext()``::
+
+ class pg_json_property(index_property):
+ def __init__(self, attr_name, index, cast_type):
+ super(pg_json_property, self).__init__(attr_name, index)
+ self.cast_type = cast_type
+
+ def expr(self, model):
+ expr = super(pg_json_property, self).expr(model)
+ return expr.astext.cast(self.cast_type)
+
+The above subclass can be used with the PostgreSQL-specific
+version of :class:`_postgresql.JSON`::
+
+ from sqlalchemy import Column, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.dialects.postgresql import JSON
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ age = pg_json_property('data', 'age', Integer)
+
+The ``age`` attribute at the instance level works as before; however,
+when rendering SQL, PostgreSQL's ``->>`` operator will be used
+for indexed access, instead of the usual index operator ``->``::
+
+ >>> query = session.query(Person).filter(Person.age < 20)
+
+The above query will render::
+
+ SELECT person.id, person.data
+ FROM person
+ WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
+
+""" # noqa
+from __future__ import absolute_import
+
+from .. import inspect
+from .. import util
+from ..ext.hybrid import hybrid_property
+from ..orm.attributes import flag_modified
+
+
+__all__ = ["index_property"]
+
+
+class index_property(hybrid_property): # noqa
+ """A property generator. The generated property describes an object
+ attribute that corresponds to an :class:`_types.Indexable`
+ column.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :mod:`sqlalchemy.ext.indexable`
+
+ """
+
+ _NO_DEFAULT_ARGUMENT = object()
+
+ def __init__(
+ self,
+ attr_name,
+ index,
+ default=_NO_DEFAULT_ARGUMENT,
+ datatype=None,
+ mutable=True,
+ onebased=True,
+ ):
+ """Create a new :class:`.index_property`.
+
+ :param attr_name:
+ An attribute name of an `Indexable` typed column, or other
+ attribute that returns an indexable structure.
+ :param index:
+ The index to be used for getting and setting this value. This
+ should be the Python-side index value for integers.
+ :param default:
+ A value which will be returned instead of `AttributeError`
+          when there is no value at the given index.
+ :param datatype: default datatype to use when the field is empty.
+ By default, this is derived from the type of index used; a
+ Python list for an integer index, or a Python dictionary for
+ any other style of index. For a list, the list will be
+ initialized to a list of None values that is at least
+ ``index`` elements long.
+ :param mutable: if False, writes and deletes to the attribute will
+ be disallowed.
+ :param onebased: assume the SQL representation of this value is
+ one-based; that is, the first index in SQL is 1, not zero.
+ """
+
+ if mutable:
+ super(index_property, self).__init__(
+ self.fget, self.fset, self.fdel, self.expr
+ )
+ else:
+ super(index_property, self).__init__(
+ self.fget, None, None, self.expr
+ )
+ self.attr_name = attr_name
+ self.index = index
+ self.default = default
+ is_numeric = isinstance(index, int)
+ onebased = is_numeric and onebased
+
+ if datatype is not None:
+ self.datatype = datatype
+ else:
+ if is_numeric:
+ self.datatype = lambda: [None for x in range(index + 1)]
+ else:
+ self.datatype = dict
+ self.onebased = onebased
+
+ def _fget_default(self, err=None):
+ if self.default == self._NO_DEFAULT_ARGUMENT:
+ util.raise_(AttributeError(self.attr_name), replace_context=err)
+ else:
+ return self.default
+
+ def fget(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ return self._fget_default()
+ try:
+ value = column_value[self.index]
+ except (KeyError, IndexError) as err:
+ return self._fget_default(err)
+ else:
+ return value
+
+ def fset(self, instance, value):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name, None)
+ if column_value is None:
+ column_value = self.datatype()
+ setattr(instance, attr_name, column_value)
+ column_value[self.index] = value
+ setattr(instance, attr_name, column_value)
+ if attr_name in inspect(instance).mapper.attrs:
+ flag_modified(instance, attr_name)
+
+ def fdel(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ raise AttributeError(self.attr_name)
+ try:
+ del column_value[self.index]
+ except KeyError as err:
+ util.raise_(AttributeError(self.attr_name), replace_context=err)
+ else:
+ setattr(instance, attr_name, column_value)
+ flag_modified(instance, attr_name)
+
+ def expr(self, model):
+ column = getattr(model, self.attr_name)
+ index = self.index
+ if self.onebased:
+ index += 1
+ return column[index]
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
new file mode 100644
index 0000000..54f3e64
--- /dev/null
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -0,0 +1,416 @@
+"""Extensible class instrumentation.
+
+The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
+systems of class instrumentation within the ORM. Class instrumentation
+refers to how the ORM places attributes on the class which maintain
+data and track changes to that data, as well as event hooks installed
+on the class.
+
+.. note::
+ The extension package is provided for the benefit of integration
+ with other object management packages, which already perform
+ their own instrumentation. It is not intended for general use.
+
+For examples of how the instrumentation extension is used,
+see the example :ref:`examples_instrumentation`.
+
+"""
+import weakref
+
+from .. import util
+from ..orm import attributes
+from ..orm import base as orm_base
+from ..orm import collections
+from ..orm import exc as orm_exc
+from ..orm import instrumentation as orm_instrumentation
+from ..orm.instrumentation import _default_dict_getter
+from ..orm.instrumentation import _default_manager_getter
+from ..orm.instrumentation import _default_state_getter
+from ..orm.instrumentation import ClassManager
+from ..orm.instrumentation import InstrumentationFactory
+
+
+INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
+"""Attribute, elects custom instrumentation when present on a mapped class.
+
+Allows a class to specify a slightly or wildly different technique for
+tracking changes made to mapped attributes and collections.
+
+Only one instrumentation implementation is allowed in a given object
+inheritance hierarchy.
+
+The value of this attribute must be a callable and will be passed a class
+object. The callable must return one of:
+
+ - An instance of an :class:`.InstrumentationManager` or subclass
+ - An object implementing all or some of InstrumentationManager (TODO)
+ - A dictionary of callables, implementing all or some of the above (TODO)
+ - An instance of a :class:`.ClassManager` or subclass
+
+This attribute is consulted by SQLAlchemy instrumentation
+resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
+has been imported. If custom finders are installed in the global
+instrumentation_finders list, they may or may not choose to honor this
+attribute.
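+
+For example, a minimal sketch (the ``MyEntity`` class and ``MyManager``
+subclass below are purely illustrative)::
+
+    from sqlalchemy.ext.instrumentation import InstrumentationManager
+
+    class MyManager(InstrumentationManager):
+        pass
+
+    class MyEntity(object):
+        # the class itself serves as the callable here; it is invoked with
+        # the mapped class and returns an InstrumentationManager instance
+        __sa_instrumentation_manager__ = MyManager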
+
+"""
+
+
+def find_native_user_instrumentation_hook(cls):
+ """Find user-specified instrumentation management for a class."""
+ return getattr(cls, INSTRUMENTATION_MANAGER, None)
+
+
+instrumentation_finders = [find_native_user_instrumentation_hook]
+"""An extensible sequence of callables which return instrumentation
+implementations.
+
+When a class is registered, each callable will be passed a class object.
+If None is returned, the
+next finder in the sequence is consulted. Otherwise the return must be an
+instrumentation factory that follows the same guidelines as
+sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.
+
+By default, the only finder is find_native_user_instrumentation_hook, which
+searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
+ClassManager instrumentation is used.
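+
+A minimal sketch of registering an additional finder (the ``find_hook``
+function and the ``__my_instrumentation__`` attribute are illustrative
+only)::
+
+    from sqlalchemy.ext import instrumentation
+
+    def find_hook(cls):
+        # return an instrumentation factory for ``cls``, or None to let
+        # the next finder in the sequence decide
+        return getattr(cls, '__my_instrumentation__', None)
+
+    instrumentation.instrumentation_finders.insert(0, find_hook)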
+
+"""
+
+
+class ExtendedInstrumentationRegistry(InstrumentationFactory):
+ """Extends :class:`.InstrumentationFactory` with additional
+ bookkeeping, to accommodate multiple types of
+ class managers.
+
+ """
+
+ _manager_finders = weakref.WeakKeyDictionary()
+ _state_finders = weakref.WeakKeyDictionary()
+ _dict_finders = weakref.WeakKeyDictionary()
+ _extended = False
+
+ def _locate_extended_factory(self, class_):
+ for finder in instrumentation_finders:
+ factory = finder(class_)
+ if factory is not None:
+ manager = self._extended_class_manager(class_, factory)
+ return manager, factory
+ else:
+ return None, None
+
+ def _check_conflicts(self, class_, factory):
+ existing_factories = self._collect_management_factories_for(
+ class_
+ ).difference([factory])
+ if existing_factories:
+ raise TypeError(
+ "multiple instrumentation implementations specified "
+ "in %s inheritance hierarchy: %r"
+ % (class_.__name__, list(existing_factories))
+ )
+
+ def _extended_class_manager(self, class_, factory):
+ manager = factory(class_)
+ if not isinstance(manager, ClassManager):
+ manager = _ClassInstrumentationAdapter(class_, manager)
+
+ if factory != ClassManager and not self._extended:
+ # somebody invoked a custom ClassManager.
+ # reinstall global "getter" functions with the more
+ # expensive ones.
+ self._extended = True
+ _install_instrumented_lookups()
+
+ self._manager_finders[class_] = manager.manager_getter()
+ self._state_finders[class_] = manager.state_getter()
+ self._dict_finders[class_] = manager.dict_getter()
+ return manager
+
+ def _collect_management_factories_for(self, cls):
+ """Return a collection of factories in play or specified for a
+ hierarchy.
+
+ Traverses the entire inheritance graph of a cls and returns a
+ collection of instrumentation factories for those classes. Factories
+ are extracted from active ClassManagers, if available, otherwise
+ instrumentation_finders is consulted.
+
+ """
+ hierarchy = util.class_hierarchy(cls)
+ factories = set()
+ for member in hierarchy:
+ manager = self.manager_of_class(member)
+ if manager is not None:
+ factories.add(manager.factory)
+ else:
+ for finder in instrumentation_finders:
+ factory = finder(member)
+ if factory is not None:
+ break
+ else:
+ factory = None
+ factories.add(factory)
+ factories.discard(None)
+ return factories
+
+ def unregister(self, class_):
+ super(ExtendedInstrumentationRegistry, self).unregister(class_)
+ if class_ in self._manager_finders:
+ del self._manager_finders[class_]
+ del self._state_finders[class_]
+ del self._dict_finders[class_]
+
+ def manager_of_class(self, cls):
+ if cls is None:
+ return None
+ try:
+ finder = self._manager_finders.get(cls, _default_manager_getter)
+ except TypeError:
+ # due to weakref lookup on invalid object
+ return None
+ else:
+ return finder(cls)
+
+ def state_of(self, instance):
+ if instance is None:
+ raise AttributeError("None has no persistent state.")
+ return self._state_finders.get(
+ instance.__class__, _default_state_getter
+ )(instance)
+
+ def dict_of(self, instance):
+ if instance is None:
+ raise AttributeError("None has no persistent state.")
+ return self._dict_finders.get(
+ instance.__class__, _default_dict_getter
+ )(instance)
+
+
+orm_instrumentation._instrumentation_factory = (
+ _instrumentation_factory
+) = ExtendedInstrumentationRegistry()
+orm_instrumentation.instrumentation_finders = instrumentation_finders
+
+
+class InstrumentationManager(object):
+ """User-defined class instrumentation extension.
+
+ :class:`.InstrumentationManager` can be subclassed in order
+ to change
+ how class instrumentation proceeds. This class exists for
+ the purposes of integration with other object management
+ frameworks which would like to entirely modify the
+ instrumentation methodology of the ORM, and is not intended
+ for regular usage. For interception of class instrumentation
+ events, see :class:`.InstrumentationEvents`.
+
+ The API for this class should be considered as semi-stable,
+ and may change slightly with new releases.
+
+ """
+
+ # r4361 added a mandatory (cls) constructor to this interface.
+ # given that, perhaps class_ should be dropped from all of these
+ # signatures.
+
+ def __init__(self, class_):
+ pass
+
+ def manage(self, class_, manager):
+ setattr(class_, "_default_class_manager", manager)
+
+ def unregister(self, class_, manager):
+ delattr(class_, "_default_class_manager")
+
+ def manager_getter(self, class_):
+ def get(cls):
+ return cls._default_class_manager
+
+ return get
+
+ def instrument_attribute(self, class_, key, inst):
+ pass
+
+ def post_configure_attribute(self, class_, key, inst):
+ pass
+
+ def install_descriptor(self, class_, key, inst):
+ setattr(class_, key, inst)
+
+ def uninstall_descriptor(self, class_, key):
+ delattr(class_, key)
+
+ def install_member(self, class_, key, implementation):
+ setattr(class_, key, implementation)
+
+ def uninstall_member(self, class_, key):
+ delattr(class_, key)
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ return collections.prepare_instrumentation(collection_class)
+
+ def get_instance_dict(self, class_, instance):
+ return instance.__dict__
+
+ def initialize_instance_dict(self, class_, instance):
+ pass
+
+ def install_state(self, class_, instance, state):
+ setattr(instance, "_default_state", state)
+
+ def remove_state(self, class_, instance):
+ delattr(instance, "_default_state")
+
+ def state_getter(self, class_):
+ return lambda instance: getattr(instance, "_default_state")
+
+ def dict_getter(self, class_):
+ return lambda inst: self.get_instance_dict(class_, inst)
+
+
+class _ClassInstrumentationAdapter(ClassManager):
+ """Adapts a user-defined InstrumentationManager to a ClassManager."""
+
+ def __init__(self, class_, override):
+ self._adapted = override
+ self._get_state = self._adapted.state_getter(class_)
+ self._get_dict = self._adapted.dict_getter(class_)
+
+ ClassManager.__init__(self, class_)
+
+ def manage(self):
+ self._adapted.manage(self.class_, self)
+
+ def unregister(self):
+ self._adapted.unregister(self.class_, self)
+
+ def manager_getter(self):
+ return self._adapted.manager_getter(self.class_)
+
+ def instrument_attribute(self, key, inst, propagated=False):
+ ClassManager.instrument_attribute(self, key, inst, propagated)
+ if not propagated:
+ self._adapted.instrument_attribute(self.class_, key, inst)
+
+ def post_configure_attribute(self, key):
+ super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
+ self._adapted.post_configure_attribute(self.class_, key, self[key])
+
+ def install_descriptor(self, key, inst):
+ self._adapted.install_descriptor(self.class_, key, inst)
+
+ def uninstall_descriptor(self, key):
+ self._adapted.uninstall_descriptor(self.class_, key)
+
+ def install_member(self, key, implementation):
+ self._adapted.install_member(self.class_, key, implementation)
+
+ def uninstall_member(self, key):
+ self._adapted.uninstall_member(self.class_, key)
+
+ def instrument_collection_class(self, key, collection_class):
+ return self._adapted.instrument_collection_class(
+ self.class_, key, collection_class
+ )
+
+ def initialize_collection(self, key, state, factory):
+ delegate = getattr(self._adapted, "initialize_collection", None)
+ if delegate:
+ return delegate(key, state, factory)
+ else:
+ return ClassManager.initialize_collection(
+ self, key, state, factory
+ )
+
+ def new_instance(self, state=None):
+ instance = self.class_.__new__(self.class_)
+ self.setup_instance(instance, state)
+ return instance
+
+ def _new_state_if_none(self, instance):
+ """Install a default InstanceState if none is present.
+
+ A private convenience method used by the __init__ decorator.
+ """
+ if self.has_state(instance):
+ return False
+ else:
+ return self.setup_instance(instance)
+
+ def setup_instance(self, instance, state=None):
+ self._adapted.initialize_instance_dict(self.class_, instance)
+
+ if state is None:
+ state = self._state_constructor(instance, self)
+
+ # the given instance is assumed to have no state
+ self._adapted.install_state(self.class_, instance, state)
+ return state
+
+ def teardown_instance(self, instance):
+ self._adapted.remove_state(self.class_, instance)
+
+ def has_state(self, instance):
+ try:
+ self._get_state(instance)
+ except orm_exc.NO_STATE:
+ return False
+ else:
+ return True
+
+ def state_getter(self):
+ return self._get_state
+
+ def dict_getter(self):
+ return self._get_dict
+
+
+def _install_instrumented_lookups():
+ """Replace global class/object management functions
+ with ExtendedInstrumentationRegistry implementations, which
+ allow multiple types of class managers to be present,
+ at the cost of performance.
+
+ This function is called only by ExtendedInstrumentationRegistry
+ and unit tests specific to this behavior.
+
+ The _reinstall_default_lookups() function can be called
+ after this one to re-establish the default functions.
+
+ """
+ _install_lookups(
+ dict(
+ instance_state=_instrumentation_factory.state_of,
+ instance_dict=_instrumentation_factory.dict_of,
+ manager_of_class=_instrumentation_factory.manager_of_class,
+ )
+ )
+
+
+def _reinstall_default_lookups():
+ """Restore simplified lookups."""
+ _install_lookups(
+ dict(
+ instance_state=_default_state_getter,
+ instance_dict=_default_dict_getter,
+ manager_of_class=_default_manager_getter,
+ )
+ )
+ _instrumentation_factory._extended = False
+
+
+def _install_lookups(lookups):
+ global instance_state, instance_dict, manager_of_class
+ instance_state = lookups["instance_state"]
+ instance_dict = lookups["instance_dict"]
+ manager_of_class = lookups["manager_of_class"]
+ orm_base.instance_state = (
+ attributes.instance_state
+ ) = orm_instrumentation.instance_state = instance_state
+ orm_base.instance_dict = (
+ attributes.instance_dict
+ ) = orm_instrumentation.instance_dict = instance_dict
+ orm_base.manager_of_class = (
+ attributes.manager_of_class
+ ) = orm_instrumentation.manager_of_class = manager_of_class
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
new file mode 100644
index 0000000..cbec06a
--- /dev/null
+++ b/lib/sqlalchemy/ext/mutable.py
@@ -0,0 +1,958 @@
+# ext/mutable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Provide support for tracking of in-place changes to scalar values,
+which are propagated into ORM change events on owning parent objects.
+
+.. _mutable_scalars:
+
+Establishing Mutability on Scalar Column Values
+===============================================
+
+A typical example of a "mutable" structure is a Python dictionary.
+Following the example introduced in :ref:`types_toplevel`, we
+begin with a custom type that marshals Python dictionaries into
+JSON strings before being persisted::
+
+ from sqlalchemy.types import TypeDecorator, VARCHAR
+ import json
+
+ class JSONEncodedDict(TypeDecorator):
+ "Represents an immutable structure as a json-encoded string."
+
+ impl = VARCHAR
+
+ def process_bind_param(self, value, dialect):
+ if value is not None:
+ value = json.dumps(value)
+ return value
+
+ def process_result_value(self, value, dialect):
+ if value is not None:
+ value = json.loads(value)
+ return value
+
+The usage of ``json`` is only for the purposes of example. The
+:mod:`sqlalchemy.ext.mutable` extension can be used
+with any type whose target Python type may be mutable, including
+:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc.
+
+When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
+tracks all parents which reference it. Below, we illustrate a simple
+version of the :class:`.MutableDict` dictionary object, which applies
+the :class:`.Mutable` mixin to a plain Python dictionary::
+
+ from sqlalchemy.ext.mutable import Mutable
+
+ class MutableDict(Mutable, dict):
+ @classmethod
+ def coerce(cls, key, value):
+ "Convert plain dictionaries to MutableDict."
+
+ if not isinstance(value, MutableDict):
+ if isinstance(value, dict):
+ return MutableDict(value)
+
+ # this call will raise ValueError
+ return Mutable.coerce(key, value)
+ else:
+ return value
+
+ def __setitem__(self, key, value):
+ "Detect dictionary set events and emit change events."
+
+ dict.__setitem__(self, key, value)
+ self.changed()
+
+ def __delitem__(self, key):
+ "Detect dictionary del events and emit change events."
+
+ dict.__delitem__(self, key)
+ self.changed()
+
+The above dictionary class takes the approach of subclassing the Python
+built-in ``dict`` to produce a dict
+subclass which routes all mutation events through ``__setitem__``. There are
+variants on this approach, such as subclassing ``UserDict.UserDict`` or
+``collections.MutableMapping``; the part that's important to this example is
+that the :meth:`.Mutable.changed` method is called whenever an in-place
+change to the datastructure takes place.
+
+We also redefine the :meth:`.Mutable.coerce` method which will be used to
+convert any values that are not instances of ``MutableDict``, such
+as the plain dictionaries returned by the ``json`` module, into the
+appropriate type. Defining this method is optional; we could just as well
+have created our ``JSONEncodedDict`` such that it always returns an instance
+of ``MutableDict``, and additionally ensured that all calling code
+uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not
+overridden, any values applied to a parent object which are not instances
+of the mutable type will raise a ``ValueError``.
+
+Our new ``MutableDict`` type offers a class method
+:meth:`~.Mutable.as_mutable` which we can use within column metadata
+to associate with types. This method grabs the given type object or
+class and associates a listener that will detect all future mappings
+of this type, applying event listening instrumentation to the mapped
+attribute. For example, with classical table metadata::
+
+ from sqlalchemy import Table, Column, Integer
+
+ my_data = Table('my_data', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', MutableDict.as_mutable(JSONEncodedDict))
+ )
+
+Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
+(if the type object was not an instance already), which will intercept any
+attributes which are mapped against this type. Below we establish a simple
+mapping against the ``my_data`` table::
+
+ from sqlalchemy import mapper
+
+ class MyDataClass(object):
+ pass
+
+ # associates mutation listeners with MyDataClass.data
+ mapper(MyDataClass, my_data)
+
+The ``MyDataClass.data`` member will now be notified of in place changes
+to its value.
+
+There's no difference in usage when using declarative::
+
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+Any in-place changes to the ``MyDataClass.data`` member
+will flag the attribute as "dirty" on the parent object::
+
+ >>> from sqlalchemy.orm import Session
+
+ >>> sess = Session()
+ >>> m1 = MyDataClass(data={'value1':'foo'})
+ >>> sess.add(m1)
+ >>> sess.commit()
+
+ >>> m1.data['value1'] = 'bar'
+    >>> m1 in sess.dirty
+    True
+
+The ``MutableDict`` can be associated with all future instances
+of ``JSONEncodedDict`` in one step, using
+:meth:`~.Mutable.associate_with`. This is similar to
+:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
+of ``MutableDict`` in all mappings unconditionally, without
+the need to declare it individually::
+
+ MutableDict.associate_with(JSONEncodedDict)
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(JSONEncodedDict)
+
+
+Supporting Pickling
+--------------------
+
+The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the
+placement of a ``weakref.WeakKeyDictionary`` upon the value object, which
+stores a mapping of parent mapped objects keyed to the attribute name under
+which they are associated with this value. ``WeakKeyDictionary`` objects are
+not picklable, due to the fact that they contain weakrefs and function
+callbacks. In our case, this is a good thing, since if this dictionary were
+picklable, it could lead to an excessively large pickle size for our value
+objects that are pickled by themselves outside of the context of the parent.
+The developer responsibility here is only to provide a ``__getstate__`` method
+that excludes the :meth:`~MutableBase._parents` collection from the pickle
+stream::
+
+ class MyMutableType(Mutable):
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d.pop('_parents', None)
+ return d
+
+With our dictionary example, we need to return the contents of the dict itself
+(and also restore them on __setstate__)::
+
+ class MutableDict(Mutable, dict):
+ # ....
+
+ def __getstate__(self):
+ return dict(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+In the case that our mutable value object is pickled as it is attached to one
+or more parent objects that are also part of the pickle, the :class:`.Mutable`
+mixin will re-establish the :attr:`.Mutable._parents` collection on each value
+object as the owning parents themselves are unpickled.
+
+Receiving Events
+----------------
+
+The :meth:`.AttributeEvents.modified` event handler may be used to receive
+an event when a mutable scalar emits a change event. This event handler
+is called when the :func:`.attributes.flag_modified` function is called
+from within the mutable extension::
+
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy import event
+
+ Base = declarative_base()
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+ @event.listens_for(MyDataClass.data, "modified")
+    def modified_json(instance, initiator):
+ print("json value modified:", instance.data)
+
+.. _mutable_composites:
+
+Establishing Mutability on Composites
+=====================================
+
+Composites are a special ORM feature which allow a single scalar attribute to
+be assigned an object value which represents information "composed" from one
+or more columns from the underlying mapped table. The usual example is that of
+a geometric "point", and is introduced in :ref:`mapper_composite`.
+
+As is the case with :class:`.Mutable`, the user-defined composite class
+subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
+change events to its parents via the :meth:`.MutableComposite.changed` method.
+In the case of a composite class, the detection is usually via the usage of
+Python descriptors (i.e. ``@property``), or alternatively via the special
+Python method ``__setattr__()``. Below we expand upon the ``Point`` class
+introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
+and to also route attribute set events via ``__setattr__`` to the
+:meth:`.MutableComposite.changed` method::
+
+ from sqlalchemy.ext.mutable import MutableComposite
+
+ class Point(MutableComposite):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __setattr__(self, key, value):
+ "Intercept set events"
+
+ # set the attribute
+ object.__setattr__(self, key, value)
+
+ # alert all parents to the change
+ self.changed()
+
+ def __composite_values__(self):
+ return self.x, self.y
+
+ def __eq__(self, other):
+ return isinstance(other, Point) and \
+ other.x == self.x and \
+ other.y == self.y
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+The :class:`.MutableComposite` class uses a Python metaclass to automatically
+establish listeners for any usage of :func:`_orm.composite` that specifies our
+``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
+listeners are established which will route change events from ``Point``
+objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
+
+ from sqlalchemy.orm import composite, mapper
+ from sqlalchemy import Table, Column
+
+ vertices = Table('vertices', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('x1', Integer),
+ Column('y1', Integer),
+ Column('x2', Integer),
+ Column('y2', Integer),
+ )
+
+ class Vertex(object):
+ pass
+
+ mapper(Vertex, vertices, properties={
+ 'start': composite(Point, vertices.c.x1, vertices.c.y1),
+ 'end': composite(Point, vertices.c.x2, vertices.c.y2)
+ })
+
+Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
+will flag the attribute as "dirty" on the parent object::
+
+ >>> from sqlalchemy.orm import Session
+
+ >>> sess = Session()
+ >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
+ >>> sess.add(v1)
+ >>> sess.commit()
+
+ >>> v1.end.x = 8
+    >>> v1 in sess.dirty
+    True
+
+Coercing Mutable Composites
+---------------------------
+
+The :meth:`.MutableBase.coerce` method is also supported on composite types.
+In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce`
+method is only called for attribute set operations, not load operations.
+Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent
+to using a :func:`.validates` validation routine for all attributes which
+make use of the custom composite type::
+
+ class Point(MutableComposite):
+ # other Point methods
+ # ...
+
+        @classmethod
+        def coerce(cls, key, value):
+ if isinstance(value, tuple):
+ value = Point(*value)
+ elif not isinstance(value, Point):
+ raise ValueError("tuple or Point expected")
+ return value
+
+Supporting Pickling
+--------------------
+
+As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper
+class uses a ``weakref.WeakKeyDictionary`` available via the
+:meth:`MutableBase._parents` attribute which isn't picklable. If we need to
+pickle instances of ``Point`` or its owning class ``Vertex``, we at least need
+to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary.
+Below we define both a ``__getstate__`` and a ``__setstate__`` that package up
+the minimal form of our ``Point`` class::
+
+ class Point(MutableComposite):
+ # ...
+
+ def __getstate__(self):
+ return self.x, self.y
+
+ def __setstate__(self, state):
+ self.x, self.y = state
+
+As with :class:`.Mutable`, the :class:`.MutableComposite` augments the
+pickling process of the parent's object-relational state so that the
+:meth:`MutableBase._parents` collection is restored to all ``Point`` objects.
+
+"""
+from collections import defaultdict
+import weakref
+
+from .. import event
+from .. import inspect
+from .. import types
+from ..orm import Mapper
+from ..orm import mapper
+from ..orm.attributes import flag_modified
+from ..sql.base import SchemaEventTarget
+from ..util import memoized_property
+
+
+class MutableBase(object):
+ """Common base class to :class:`.Mutable`
+ and :class:`.MutableComposite`.
+
+ """
+
+ @memoized_property
+ def _parents(self):
+ """Dictionary of parent object's :class:`.InstanceState`->attribute
+ name on the parent.
+
+ This attribute is a so-called "memoized" property. It initializes
+ itself with a new ``weakref.WeakKeyDictionary`` the first time
+ it is accessed, returning the same object upon subsequent access.
+
+ .. versionchanged:: 1.4 the :class:`.InstanceState` is now used
+ as the key in the weak dictionary rather than the instance
+ itself.
+
+ """
+
+ return weakref.WeakKeyDictionary()
+
+ @classmethod
+ def coerce(cls, key, value):
+ """Given a value, coerce it into the target type.
+
+ Can be overridden by custom subclasses to coerce incoming
+ data into a particular type.
+
+ By default, raises ``ValueError``.
+
+ This method is called in different scenarios depending on if
+ the parent class is of type :class:`.Mutable` or of type
+ :class:`.MutableComposite`. In the case of the former, it is called
+ for both attribute-set operations as well as during ORM loading
+ operations. For the latter, it is only called during attribute-set
+ operations; the mechanics of the :func:`.composite` construct
+ handle coercion during load operations.
+
+
+ :param key: string name of the ORM-mapped attribute being set.
+ :param value: the incoming value.
+ :return: the method should return the coerced value, or raise
+ ``ValueError`` if the coercion cannot be completed.
+
+ """
+ if value is None:
+ return None
+ msg = "Attribute '%s' does not accept objects of type %s"
+ raise ValueError(msg % (key, type(value)))
+
+ @classmethod
+ def _get_listen_keys(cls, attribute):
+ """Given a descriptor attribute, return a ``set()`` of the attribute
+ keys which indicate a change in the state of this attribute.
+
+ This is normally just ``set([attribute.key])``, but can be overridden
+ to provide for additional keys. E.g. a :class:`.MutableComposite`
+ augments this set with the attribute keys associated with the columns
+ that comprise the composite value.
+
+ This collection is consulted in the case of intercepting the
+ :meth:`.InstanceEvents.refresh` and
+ :meth:`.InstanceEvents.refresh_flush` events, which pass along a list
+ of attribute names that have been refreshed; the list is compared
+ against this set to determine if action needs to be taken.
+
+ .. versionadded:: 1.0.5
+
+ """
+ return {attribute.key}
+
+ @classmethod
+ def _listen_on_attribute(cls, attribute, coerce, parent_cls):
+ """Establish this type as a mutation listener for the given
+ mapped descriptor.
+
+ """
+ key = attribute.key
+ if parent_cls is not attribute.class_:
+ return
+
+ # rely on "propagate" here
+ parent_cls = attribute.class_
+
+ listen_keys = cls._get_listen_keys(attribute)
+
+ def load(state, *args):
+ """Listen for objects loaded or refreshed.
+
+ Wrap the target data member's value with
+ ``Mutable``.
+
+ """
+ val = state.dict.get(key, None)
+ if val is not None:
+ if coerce:
+ val = cls.coerce(key, val)
+ state.dict[key] = val
+ val._parents[state] = key
+
+ def load_attrs(state, ctx, attrs):
+ if not attrs or listen_keys.intersection(attrs):
+ load(state)
+
+ def set_(target, value, oldvalue, initiator):
+ """Listen for set/replace events on the target
+ data member.
+
+ Establish a weak reference to the parent object
+ on the incoming value, remove it for the one
+ outgoing.
+
+ """
+ if value is oldvalue:
+ return value
+
+ if not isinstance(value, cls):
+ value = cls.coerce(key, value)
+ if value is not None:
+ value._parents[target] = key
+ if isinstance(oldvalue, cls):
+ oldvalue._parents.pop(inspect(target), None)
+ return value
+
+ def pickle(state, state_dict):
+ val = state.dict.get(key, None)
+ if val is not None:
+ if "ext.mutable.values" not in state_dict:
+ state_dict["ext.mutable.values"] = defaultdict(list)
+ state_dict["ext.mutable.values"][key].append(val)
+
+ def unpickle(state, state_dict):
+ if "ext.mutable.values" in state_dict:
+ collection = state_dict["ext.mutable.values"]
+ if isinstance(collection, list):
+ # legacy format
+ for val in collection:
+ val._parents[state] = key
+ else:
+ for val in state_dict["ext.mutable.values"][key]:
+ val._parents[state] = key
+
+ event.listen(parent_cls, "load", load, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "refresh", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ attribute, "set", set_, raw=True, retval=True, propagate=True
+ )
+ event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "unpickle", unpickle, raw=True, propagate=True
+ )
+
+
+class Mutable(MutableBase):
+ """Mixin that defines transparent propagation of change
+ events to a parent object.
+
+ See the example in :ref:`mutable_scalars` for usage information.
+
+ """
+
+ def changed(self):
+ """Subclasses should call this method whenever change events occur."""
+
+ for parent, key in self._parents.items():
+ flag_modified(parent.obj(), key)
+
+ @classmethod
+ def associate_with_attribute(cls, attribute):
+ """Establish this type as a mutation listener for the given
+ mapped descriptor.
+
+ """
+ cls._listen_on_attribute(attribute, True, attribute.class_)
+
+ @classmethod
+ def associate_with(cls, sqltype):
+ """Associate this wrapper with all future mapped columns
+ of the given type.
+
+ This is a convenience method that calls
+ ``associate_with_attribute`` automatically.
+
+ .. warning::
+
+ The listeners established by this method are *global*
+ to all mappers, and are *not* garbage collected. Only use
+ :meth:`.associate_with` for types that are permanent to an
+           application, not with ad-hoc types, otherwise this will cause
+           unbounded growth in memory usage.
+
+ """
+
+ def listen_for_type(mapper, class_):
+ if mapper.non_primary:
+ return
+ for prop in mapper.column_attrs:
+ if isinstance(prop.columns[0].type, sqltype):
+ cls.associate_with_attribute(getattr(class_, prop.key))
+
+ event.listen(mapper, "mapper_configured", listen_for_type)
+
+ @classmethod
+ def as_mutable(cls, sqltype):
+ """Associate a SQL type with this mutable Python type.
+
+ This establishes listeners that will detect ORM mappings against
+ the given type, adding mutation event trackers to those mappings.
+
+ The type is returned unconditionally as an instance, so that
+ :meth:`.as_mutable` can be used inline::
+
+ Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', MyMutableType.as_mutable(PickleType))
+ )
+
+ Note that the returned type is always an instance, even if a class
+ is given, and that only columns which are declared specifically with
+ that type instance receive additional instrumentation.
+
+ To associate a particular mutable type with all occurrences of a
+ particular type, use the :meth:`.Mutable.associate_with` classmethod
+ of the particular :class:`.Mutable` subclass to establish a global
+ association.
+
+ .. warning::
+
+ The listeners established by this method are *global*
+ to all mappers, and are *not* garbage collected. Only use
+ :meth:`.as_mutable` for types that are permanent to an application,
+ not with ad-hoc types, else this will cause unbounded growth
+ in memory usage.
+
+ """
+ sqltype = types.to_instance(sqltype)
+
+ # a SchemaType will be copied when the Column is copied,
+ # and we'll lose our ability to link that type back to the original.
+ # so track our original type w/ columns
+ if isinstance(sqltype, SchemaEventTarget):
+
+ @event.listens_for(sqltype, "before_parent_attach")
+ def _add_column_memo(sqltyp, parent):
+ parent.info["_ext_mutable_orig_type"] = sqltyp
+
+ schema_event_check = True
+ else:
+ schema_event_check = False
+
+ def listen_for_type(mapper, class_):
+ if mapper.non_primary:
+ return
+ for prop in mapper.column_attrs:
+ if (
+ schema_event_check
+ and hasattr(prop.expression, "info")
+ and prop.expression.info.get("_ext_mutable_orig_type")
+ is sqltype
+ ) or (prop.columns[0].type is sqltype):
+ cls.associate_with_attribute(getattr(class_, prop.key))
+
+ event.listen(mapper, "mapper_configured", listen_for_type)
+
+ return sqltype
+
+
+class MutableComposite(MutableBase):
+ """Mixin that defines transparent propagation of change
+ events on a SQLAlchemy "composite" object to its
+ owning parent or parents.
+
+ See the example in :ref:`mutable_composites` for usage information.
+
+ """
+
+ @classmethod
+ def _get_listen_keys(cls, attribute):
+ return {attribute.key}.union(attribute.property._attribute_keys)
+
+ def changed(self):
+ """Subclasses should call this method whenever change events occur."""
+
+ for parent, key in self._parents.items():
+
+ prop = parent.mapper.get_property(key)
+ for value, attr_name in zip(
+ self.__composite_values__(), prop._attribute_keys
+ ):
+ setattr(parent.obj(), attr_name, value)
+
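+# A rough sketch of the contract expected from subclasses (names are
+# illustrative; see :ref:`mutable_composites` for the canonical example)::
+#
+#     class Point(MutableComposite):
+#         def __init__(self, x, y):
+#             self.x = x
+#             self.y = y
+#
+#         def __setattr__(self, key, value):
+#             # intercept attribute sets and propagate a change event
+#             object.__setattr__(self, key, value)
+#             self.changed()
+#
+#         def __composite_values__(self):
+#             return self.x, self.y
+#
+#         # __eq__() / __ne__() are also typically defined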
+
+def _setup_composite_listener():
+ def _listen_for_type(mapper, class_):
+ for prop in mapper.iterate_properties:
+ if (
+ hasattr(prop, "composite_class")
+ and isinstance(prop.composite_class, type)
+ and issubclass(prop.composite_class, MutableComposite)
+ ):
+ prop.composite_class._listen_on_attribute(
+ getattr(class_, prop.key), False, class_
+ )
+
+ if not event.contains(Mapper, "mapper_configured", _listen_for_type):
+ event.listen(Mapper, "mapper_configured", _listen_for_type)
+
+
+_setup_composite_listener()
+
+
+class MutableDict(Mutable, dict):
+ """A dictionary type that implements :class:`.Mutable`.
+
+ The :class:`.MutableDict` object implements a dictionary that will
+ emit change events to the underlying mapping when the contents of
+ the dictionary are altered, including when values are added or removed.
+
+ Note that :class:`.MutableDict` does **not** apply mutable tracking to the
+ *values themselves* inside the dictionary. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ dictionary structure, such as a JSON structure. To support this use case,
+ build a subclass of :class:`.MutableDict` that provides appropriate
+ coercion to the values placed in the dictionary so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. seealso::
+
+ :class:`.MutableList`
+
+ :class:`.MutableSet`
+
+ """
+
+ def __setitem__(self, key, value):
+ """Detect dictionary set events and emit change events."""
+ dict.__setitem__(self, key, value)
+ self.changed()
+
+ def setdefault(self, key, value):
+ result = dict.setdefault(self, key, value)
+ self.changed()
+ return result
+
+ def __delitem__(self, key):
+ """Detect dictionary del events and emit change events."""
+ dict.__delitem__(self, key)
+ self.changed()
+
+ def update(self, *a, **kw):
+ dict.update(self, *a, **kw)
+ self.changed()
+
+ def pop(self, *arg):
+ result = dict.pop(self, *arg)
+ self.changed()
+ return result
+
+ def popitem(self):
+ result = dict.popitem(self)
+ self.changed()
+ return result
+
+ def clear(self):
+ dict.clear(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, key, value):
+ """Convert plain dictionary to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, dict):
+ return cls(value)
+ return Mutable.coerce(key, value)
+ else:
+ return value
+
+ def __getstate__(self):
+ return dict(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
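+# Rough usage sketch (column type, class and attribute names are
+# illustrative, not part of this module): pair the class with a column via
+# ``MutableDict.as_mutable()`` so that in-place mutations are tracked::
+#
+#     from sqlalchemy import Column, Integer, JSON
+#
+#     class Config(Base):
+#         __tablename__ = "config"
+#
+#         id = Column(Integer, primary_key=True)
+#         data = Column(MutableDict.as_mutable(JSON), default=dict)
+#
+#     # an in-place change such as ``cfg.data["theme"] = "dark"`` now flags
+#     # the ``data`` attribute as modified for the next flush.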
+
+class MutableList(Mutable, list):
+ """A list type that implements :class:`.Mutable`.
+
+ The :class:`.MutableList` object implements a list that will
+ emit change events to the underlying mapping when the contents of
+ the list are altered, including when values are added or removed.
+
+ Note that :class:`.MutableList` does **not** apply mutable tracking to the
+ *values themselves* inside the list. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ mutable structure, such as a JSON structure. To support this use case,
+ build a subclass of :class:`.MutableList` that provides appropriate
+ coercion to the values placed in the list so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :class:`.MutableDict`
+
+ :class:`.MutableSet`
+
+ """
+
+ def __reduce_ex__(self, proto):
+ return (self.__class__, (list(self),))
+
+ # needed for backwards compatibility with
+ # older pickles
+ def __setstate__(self, state):
+ self[:] = state
+
+ def __setitem__(self, index, value):
+ """Detect list set events and emit change events."""
+ list.__setitem__(self, index, value)
+ self.changed()
+
+ def __setslice__(self, start, end, value):
+ """Detect list set events and emit change events."""
+ list.__setslice__(self, start, end, value)
+ self.changed()
+
+ def __delitem__(self, index):
+ """Detect list del events and emit change events."""
+ list.__delitem__(self, index)
+ self.changed()
+
+ def __delslice__(self, start, end):
+ """Detect list del events and emit change events."""
+ list.__delslice__(self, start, end)
+ self.changed()
+
+ def pop(self, *arg):
+ result = list.pop(self, *arg)
+ self.changed()
+ return result
+
+ def append(self, x):
+ list.append(self, x)
+ self.changed()
+
+ def extend(self, x):
+ list.extend(self, x)
+ self.changed()
+
+ def __iadd__(self, x):
+ self.extend(x)
+ return self
+
+ def insert(self, i, x):
+ list.insert(self, i, x)
+ self.changed()
+
+ def remove(self, i):
+ list.remove(self, i)
+ self.changed()
+
+ def clear(self):
+ list.clear(self)
+ self.changed()
+
+ def sort(self, **kw):
+ list.sort(self, **kw)
+ self.changed()
+
+ def reverse(self):
+ list.reverse(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, index, value):
+ """Convert plain list to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, list):
+ return cls(value)
+ return Mutable.coerce(index, value)
+ else:
+ return value
+
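+# Rough usage sketch (column type is illustrative): associating the class
+# with a list-like column, e.g.
+# ``Column(MutableList.as_mutable(ARRAY(Integer)))``, makes in-place calls
+# such as ``.append()`` or ``.remove()`` mark the attribute as changed.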
+
+class MutableSet(Mutable, set):
+ """A set type that implements :class:`.Mutable`.
+
+ The :class:`.MutableSet` object implements a set that will
+ emit change events to the underlying mapping when the contents of
+ the set are altered, including when values are added or removed.
+
+ Note that :class:`.MutableSet` does **not** apply mutable tracking to the
+ *values themselves* inside the set. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ mutable structure. To support this use case,
+ build a subclass of :class:`.MutableSet` that provides appropriate
+ coercion to the values placed in the set so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :class:`.MutableDict`
+
+ :class:`.MutableList`
+
+ """
+
+ def update(self, *arg):
+ set.update(self, *arg)
+ self.changed()
+
+ def intersection_update(self, *arg):
+ set.intersection_update(self, *arg)
+ self.changed()
+
+ def difference_update(self, *arg):
+ set.difference_update(self, *arg)
+ self.changed()
+
+ def symmetric_difference_update(self, *arg):
+ set.symmetric_difference_update(self, *arg)
+ self.changed()
+
+ def __ior__(self, other):
+ self.update(other)
+ return self
+
+ def __iand__(self, other):
+ self.intersection_update(other)
+ return self
+
+ def __ixor__(self, other):
+ self.symmetric_difference_update(other)
+ return self
+
+ def __isub__(self, other):
+ self.difference_update(other)
+ return self
+
+ def add(self, elem):
+ set.add(self, elem)
+ self.changed()
+
+ def remove(self, elem):
+ set.remove(self, elem)
+ self.changed()
+
+ def discard(self, elem):
+ set.discard(self, elem)
+ self.changed()
+
+ def pop(self, *arg):
+ result = set.pop(self, *arg)
+ self.changed()
+ return result
+
+ def clear(self):
+ set.clear(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, index, value):
+ """Convert plain set to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, set):
+ return cls(value)
+ return Mutable.coerce(index, value)
+ else:
+ return value
+
+ def __getstate__(self):
+ return set(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+ def __reduce_ex__(self, proto):
+ return (self.__class__, (list(self),))
diff --git a/lib/sqlalchemy/ext/mypy/__init__.py b/lib/sqlalchemy/ext/mypy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/__init__.py
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
new file mode 100644
index 0000000..99be194
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -0,0 +1,299 @@
+# ext/mypy/apply.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mypy.nodes import ARG_NAMED_OPT
+from mypy.nodes import Argument
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import MDEF
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import add_method_to_class
+from mypy.types import AnyType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneTyp
+from mypy.types import ProperType
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import infer
+from . import util
+from .names import NAMED_TYPE_SQLA_MAPPED
+
+
+def apply_mypy_mapped_attr(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ item: Union[NameExpr, StrExpr],
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ if isinstance(item, NameExpr):
+ name = item.name
+ elif isinstance(item, StrExpr):
+ name = item.value
+ else:
+ return None
+
+ for stmt in cls.defs.body:
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name == name
+ ):
+ break
+ else:
+ util.fail(api, "Can't find mapped attribute {}".format(name), cls)
+ return None
+
+ if stmt.type is None:
+ util.fail(
+ api,
+ "Statement linked from _mypy_mapped_attrs has no "
+ "typing information",
+ stmt,
+ )
+ return None
+
+ left_hand_explicit_type = get_proper_type(stmt.type)
+ assert isinstance(
+ left_hand_explicit_type, (Instance, UnionType, UnboundType)
+ )
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=item.line,
+ column=item.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+
+ apply_type_to_mapped_statement(
+ api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
+ )
+
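+# Rough illustration of the construct handled above (class, attribute and
+# decorator names are hypothetical): a class may list attributes that the
+# plugin should treat as mapped even though it cannot infer them from the
+# right-hand side::
+#
+#     @reg.mapped
+#     class User:
+#         __tablename__ = "user"
+#
+#         id: int = some_non_sqlalchemy_construct()
+#         _mypy_mapped_attrs = [id]    # bare names or strings, e.g. ["id"]
+#
+# each listed name must point to an annotated assignment in the class body,
+# otherwise the failures above are reported.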
+
+def re_apply_declarative_assignments(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """For multiple class passes, re-apply our left-hand side types as mypy
+ seems to reset them in place.
+
+ """
+ mapped_attr_lookup = {attr.name: attr for attr in attributes}
+ update_cls_metadata = False
+
+ for stmt in cls.defs.body:
+ # for a re-apply, all of our statements are AssignmentStmt;
+ # @declared_attr calls will have been converted and this
+ # currently seems to be preserved by mypy (but who knows if this
+ # will change).
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name in mapped_attr_lookup
+ and isinstance(stmt.lvalues[0].node, Var)
+ ):
+
+ left_node = stmt.lvalues[0].node
+ python_type_for_type = mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type
+
+ left_node_proper_type = get_proper_type(left_node.type)
+
+ # if we have scanned an UnboundType and now there's a more
+ # specific type than UnboundType, call the re-scan so we
+ # can get that set up correctly
+ if (
+ isinstance(python_type_for_type, UnboundType)
+ and not isinstance(left_node_proper_type, UnboundType)
+ and (
+ isinstance(stmt.rvalue, CallExpr)
+ and isinstance(stmt.rvalue.callee, MemberExpr)
+ and isinstance(stmt.rvalue.callee.expr, NameExpr)
+ and stmt.rvalue.callee.expr.node is not None
+ and stmt.rvalue.callee.expr.node.fullname
+ == NAMED_TYPE_SQLA_MAPPED
+ and stmt.rvalue.callee.name == "_empty_constructor"
+ and isinstance(stmt.rvalue.args[0], CallExpr)
+ and isinstance(stmt.rvalue.args[0].callee, RefExpr)
+ )
+ ):
+
+ python_type_for_type = (
+ infer.infer_type_from_right_hand_nameexpr(
+ api,
+ stmt,
+ left_node,
+ left_node_proper_type,
+ stmt.rvalue.args[0].callee,
+ )
+ )
+
+ if python_type_for_type is None or isinstance(
+ python_type_for_type, UnboundType
+ ):
+ continue
+
+ # update the SQLAlchemyAttribute with the better information
+ mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type = python_type_for_type
+
+ update_cls_metadata = True
+
+ if python_type_for_type is not None:
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
+ )
+
+ if update_cls_metadata:
+ util.set_mapped_attributes(cls.info, attributes)
+
+
+def apply_type_to_mapped_statement(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ lvalue: NameExpr,
+ left_hand_explicit_type: Optional[ProperType],
+ python_type_for_type: Optional[ProperType],
+) -> None:
+ """Apply the Mapped[<type>] annotation and right hand object to a
+ declarative assignment statement.
+
+ This converts a Python declarative class statement such as::
+
+ class User(Base):
+ # ...
+
+ attrname = Column(Integer)
+
+ To one that describes the final Python behavior to Mypy::
+
+ class User(Base):
+ # ...
+
+ attrname : Mapped[Optional[int]] = <meaningless temp node>
+
+ """
+ left_node = lvalue.node
+ assert isinstance(left_node, Var)
+
+ if left_hand_explicit_type is not None:
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
+ )
+ else:
+ lvalue.is_inferred_def = False
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED,
+ [] if python_type_for_type is None else [python_type_for_type],
+ )
+
+ # so to have it skip the right side totally, we can do this:
+ # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+ # however, if we instead manufacture a new node that uses the old
+ # one, then we can still get type checking for the call itself,
+ # e.g. the Column, relationship() call, etc.
+
+ # rewrite the node as:
+ # <attr> : Mapped[<typ>] =
+ # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
+ # the original right-hand side is maintained so it gets type checked
+ # internally
+ stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
+
+
+def add_additional_orm_attributes(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Apply __init__, __table__ and other attributes to the mapped class."""
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ return
+
+ is_base = util.get_is_base(info)
+
+ if "__init__" not in info.names and not is_base:
+ mapped_attr_names = {attr.name: attr.type for attr in attributes}
+
+ for base in info.mro[1:-1]:
+ if "sqlalchemy" not in info.metadata:
+ continue
+
+ base_cls_attributes = util.get_mapped_attributes(base, api)
+ if base_cls_attributes is None:
+ continue
+
+ for attr in base_cls_attributes:
+ mapped_attr_names.setdefault(attr.name, attr.type)
+
+ arguments = []
+ for name, typ in mapped_attr_names.items():
+ if typ is None:
+ typ = AnyType(TypeOfAny.special_form)
+ arguments.append(
+ Argument(
+ variable=Var(name, typ),
+ type_annotation=typ,
+ initializer=TempNode(typ),
+ kind=ARG_NAMED_OPT,
+ )
+ )
+
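+ # the synthesized __init__ takes each collected attribute as an optional
+ # keyword-only argument (ARG_NAMED_OPT), roughly equivalent to
+ # ``def __init__(self, *, id: int = ..., name: str = ...) -> None``
+ # (attribute names here are illustrative); unknown types fall back to Any.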
+ add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
+
+ if "__table__" not in info.names and util.get_has_table(info):
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.sql.schema.Table", "__table__"
+ )
+ if not is_base:
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
+ )
+
+
+def _apply_placeholder_attr_to_class(
+ api: SemanticAnalyzerPluginInterface,
+ cls: ClassDef,
+ qualified_name: str,
+ attrname: str,
+) -> None:
+ sym = api.lookup_fully_qualified_or_none(qualified_name)
+ if sym:
+ assert isinstance(sym.node, TypeInfo)
+ type_: ProperType = Instance(sym.node, [])
+ else:
+ type_ = AnyType(TypeOfAny.special_form)
+ var = Var(attrname)
+ var._fullname = cls.fullname + "." + attrname
+ var.info = cls.info
+ var.type = type_
+ cls.info.names[attrname] = SymbolTableNode(MDEF, var)
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
new file mode 100644
index 0000000..c33c30e
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -0,0 +1,516 @@
+# ext/mypy/decl_class.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import Decorator
+from mypy.nodes import LambdaExpr
+from mypy.nodes import ListExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import PlaceholderNode
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import apply
+from . import infer
+from . import names
+from . import util
+
+
+def scan_declarative_assignments_and_apply_types(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ is_mixin_scan: bool = False,
+) -> Optional[List[util.SQLAlchemyAttribute]]:
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ # this can occur during cached passes
+ return None
+ elif cls.fullname.startswith("builtins"):
+ return None
+
+ mapped_attributes: Optional[
+ List[util.SQLAlchemyAttribute]
+ ] = util.get_mapped_attributes(info, api)
+
+ # used by assign.add_additional_orm_attributes among others
+ util.establish_as_sqlalchemy(info)
+
+ if mapped_attributes is not None:
+ # ensure that a class that's mapped is always picked up by
+ # its mapped() decorator or declarative metaclass before
+ # it would be detected as an unmapped mixin class
+
+ if not is_mixin_scan:
+ # mypy can call us more than once. it then *may* have reset the
+ # left hand side of everything, but not the right hand side that we
+ # removed, which takes away our ability to re-scan. but we have the
+ # types here, so let's re-apply them, or if we have an UnboundType,
+ # we can re-scan.
+
+ apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
+
+ return mapped_attributes
+
+ mapped_attributes = []
+
+ if not cls.defs.body:
+ # when we get a mixin class from another file, the body is
+ # empty (!) but the names are in the symbol table. so use that.
+
+ for sym_name, sym in info.names.items():
+ _scan_symbol_table_entry(
+ cls, api, sym_name, sym, mapped_attributes
+ )
+ else:
+ for stmt in util.flatten_typechecking(cls.defs.body):
+ if isinstance(stmt, AssignmentStmt):
+ _scan_declarative_assignment_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ elif isinstance(stmt, Decorator):
+ _scan_declarative_decorator_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ _scan_for_mapped_bases(cls, api)
+
+ if not is_mixin_scan:
+ apply.add_additional_orm_attributes(cls, api, mapped_attributes)
+
+ util.set_mapped_attributes(info, mapped_attributes)
+
+ return mapped_attributes
+
+
+def _scan_symbol_table_entry(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ name: str,
+ value: SymbolTableNode,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from a SymbolTableNode that's in the
+ type.names dictionary.
+
+ """
+ value_type = get_proper_type(value.type)
+ if not isinstance(value_type, Instance):
+ return
+
+ left_hand_explicit_type = None
+ type_id = names.type_id_for_named_node(value_type.type)
+ # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
+
+ err = False
+
+ # TODO: this is nearly the same logic as that of
+ # _scan_declarative_decorator_stmt, likely can be merged
+ if type_id in {
+ names.MAPPED,
+ names.RELATIONSHIP,
+ names.COMPOSITE_PROPERTY,
+ names.MAPPER_PROPERTY,
+ names.SYNONYM_PROPERTY,
+ names.COLUMN_PROPERTY,
+ }:
+ if value_type.args:
+ left_hand_explicit_type = get_proper_type(value_type.args[0])
+ else:
+ err = True
+ elif type_id is names.COLUMN:
+ if not value_type.args:
+ err = True
+ else:
+ typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
+ value_type.args[0]
+ )
+ if isinstance(typeengine_arg, Instance):
+ typeengine_arg = typeengine_arg.type
+
+ if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
+ sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
+
+ left_hand_explicit_type = UnionType(
+ [
+ infer.extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ value_type,
+ )
+
+ if err:
+ msg = (
+ "Can't infer type from attribute {} on class {}. "
+ "please specify a return type from this function that is "
+ "one of: Mapped[<python type>], relationship[<target class>], "
+ "Column[<TypeEngine>], MapperProperty[<python type>]"
+ )
+ util.fail(api, msg.format(name, cls.name), cls)
+
+ left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+ if left_hand_explicit_type is not None:
+ assert value.node is not None
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=value.node.line,
+ column=value.node.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+
+
+def _scan_declarative_decorator_stmt(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ stmt: Decorator,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from a @declared_attr in a declarative
+ class.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ @declared_attr
+ def updated_at(cls) -> Column[DateTime]:
+ return Column(DateTime)
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ updated_at: Mapped[Optional[datetime.datetime]]
+
+ """
+ for dec in stmt.decorators:
+ if (
+ isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
+ and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
+ ):
+ break
+ else:
+ return
+
+ dec_index = cls.defs.body.index(stmt)
+
+ left_hand_explicit_type: Optional[ProperType] = None
+
+ if util.name_is_dunder(stmt.name):
+ # for dunder names like __table_args__, __tablename__,
+ # __mapper_args__ etc., rewrite these as simple assignment
+ # statements; otherwise mypy doesn't like if the decorated
+ # function has an annotation like ``cls: Type[Foo]`` because
+ # it isn't @classmethod
+ any_ = AnyType(TypeOfAny.special_form)
+ left_node = NameExpr(stmt.var.name)
+ left_node.node = stmt.var
+ new_stmt = AssignmentStmt([left_node], TempNode(any_))
+ new_stmt.type = left_node.node.type
+ cls.defs.body[dec_index] = new_stmt
+ return
+ elif isinstance(stmt.func.type, CallableType):
+ func_type = stmt.func.type.ret_type
+ if isinstance(func_type, UnboundType):
+ type_id = names.type_id_for_unbound_type(func_type, cls, api)
+ else:
+ # this does not seem to occur unless the type argument is
+ # incorrect
+ return
+
+ if (
+ type_id
+ in {
+ names.MAPPED,
+ names.RELATIONSHIP,
+ names.COMPOSITE_PROPERTY,
+ names.MAPPER_PROPERTY,
+ names.SYNONYM_PROPERTY,
+ names.COLUMN_PROPERTY,
+ }
+ and func_type.args
+ ):
+ left_hand_explicit_type = get_proper_type(func_type.args[0])
+ elif type_id is names.COLUMN and func_type.args:
+ typeengine_arg = func_type.args[0]
+ if isinstance(typeengine_arg, UnboundType):
+ sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
+ left_hand_explicit_type = UnionType(
+ [
+ infer.extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ func_type,
+ )
+
+ if left_hand_explicit_type is None:
+ # no type on the decorated function. our option here is to
+ # dig into the function body and get the return type, but they
+ # should just have an annotation.
+ msg = (
+ "Can't infer type from @declared_attr on function '{}'; "
+ "please specify a return type from this function that is "
+ "one of: Mapped[<python type>], relationship[<target class>], "
+ "Column[<TypeEngine>], MapperProperty[<python type>]"
+ )
+ util.fail(api, msg.format(stmt.var.name), stmt)
+
+ left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+ left_node = NameExpr(stmt.var.name)
+ left_node.node = stmt.var
+
+ # totally feeling around in the dark here as I don't totally understand
+ # the significance of UnboundType. It seems to be something that is
+ # not going to do what's expected when it is applied as the type of
+ # an AssignmentStatement. So do a feeling-around-in-the-dark version
+ # of converting it to the regular Instance/TypeInfo/UnionType structures
+ # we see everywhere else.
+ if isinstance(left_hand_explicit_type, UnboundType):
+ left_hand_explicit_type = get_proper_type(
+ util.unbound_to_instance(api, left_hand_explicit_type)
+ )
+
+ left_node.node.type = api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
+ )
+
+ # this will ignore the rvalue entirely
+ # rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+ # rewrite the node as:
+ # <attr> : Mapped[<typ>] =
+ # _sa_Mapped._empty_constructor(lambda: <function body>)
+ # the function body is maintained so it gets type checked internally
+ rvalue = util.expr_to_mapped_constructor(
+ LambdaExpr(stmt.func.arguments, stmt.func.body)
+ )
+
+ new_stmt = AssignmentStmt([left_node], rvalue)
+ new_stmt.type = left_node.node.type
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=left_node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+ cls.defs.body[dec_index] = new_stmt
+
+
+def _scan_declarative_assignment_stmt(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from an assignment statement in a
+ declarative class.
+
+ """
+ lvalue = stmt.lvalues[0]
+ if not isinstance(lvalue, NameExpr):
+ return
+
+ sym = cls.info.names.get(lvalue.name)
+
+ # this establishes that semantic analysis has taken place, which
+ # means the nodes are populated and we are called from an appropriate
+ # hook.
+ assert sym is not None
+ node = sym.node
+
+ if isinstance(node, PlaceholderNode):
+ return
+
+ assert node is lvalue.node
+ assert isinstance(node, Var)
+
+ if node.name == "__abstract__":
+ if api.parse_bool(stmt.rvalue) is True:
+ util.set_is_base(cls.info)
+ return
+ elif node.name == "__tablename__":
+ util.set_has_table(cls.info)
+ elif node.name.startswith("__"):
+ return
+ elif node.name == "_mypy_mapped_attrs":
+ if not isinstance(stmt.rvalue, ListExpr):
+ util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
+ else:
+ for item in stmt.rvalue.items:
+ if isinstance(item, (NameExpr, StrExpr)):
+ apply.apply_mypy_mapped_attr(cls, api, item, attributes)
+
+ left_hand_mapped_type: Optional[Type] = None
+ left_hand_explicit_type: Optional[ProperType] = None
+
+ if node.is_inferred or node.type is None:
+ if isinstance(stmt.type, UnboundType):
+ # look for an explicit Mapped[] type annotation on the left
+ # side with nothing on the right
+
+ # print(stmt.type)
+ # Mapped?[Optional?[A?]]
+
+ left_hand_explicit_type = stmt.type
+
+ if stmt.type.name == "Mapped":
+ mapped_sym = api.lookup_qualified("Mapped", cls)
+ if (
+ mapped_sym is not None
+ and mapped_sym.node is not None
+ and names.type_id_for_named_node(mapped_sym.node)
+ is names.MAPPED
+ ):
+ left_hand_explicit_type = get_proper_type(
+ stmt.type.args[0]
+ )
+ left_hand_mapped_type = stmt.type
+
+ # TODO: do we need to convert from unbound for this case?
+ # left_hand_explicit_type = util._unbound_to_instance(
+ # api, left_hand_explicit_type
+ # )
+ else:
+ node_type = get_proper_type(node.type)
+ if (
+ isinstance(node_type, Instance)
+ and names.type_id_for_named_node(node_type.type) is names.MAPPED
+ ):
+ # print(node.type)
+ # sqlalchemy.orm.attributes.Mapped[<python type>]
+ left_hand_explicit_type = get_proper_type(node_type.args[0])
+ left_hand_mapped_type = node_type
+ else:
+ # print(node.type)
+ # <python type>
+ left_hand_explicit_type = node_type
+ left_hand_mapped_type = None
+
+ if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
+ # annotation without assignment and Mapped is present
+ # as type annotation
+ # equivalent to using _infer_type_from_left_hand_type_only.
+
+ python_type_for_type = left_hand_explicit_type
+ elif isinstance(stmt.rvalue, CallExpr) and isinstance(
+ stmt.rvalue.callee, RefExpr
+ ):
+
+ python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
+ api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
+ )
+
+ if python_type_for_type is None:
+ return
+
+ else:
+ return
+
+ assert python_type_for_type is not None
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=python_type_for_type,
+ info=cls.info,
+ )
+ )
+
+ apply.apply_type_to_mapped_statement(
+ api,
+ stmt,
+ lvalue,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
+
+
+def _scan_for_mapped_bases(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+) -> None:
+ """Given a class, iterate through its superclass hierarchy to find
+ all other classes that are considered as ORM-significant.
+
+ Locates non-mapped mixins and scans them for mapped attributes to be
+ applied to subclasses.
+
+ """
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ return
+
+ for base_info in info.mro[1:-1]:
+ if base_info.fullname.startswith("builtins"):
+ continue
+
+ # scan each base for mapped attributes. if they are not already
+ # scanned (but have all their type info), that means they are unmapped
+ # mixins
+ scan_declarative_assignments_and_apply_types(
+ base_info.defn, api, is_mixin_scan=True
+ )
diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py
new file mode 100644
index 0000000..f88a960
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/infer.py
@@ -0,0 +1,556 @@
+# ext/mypy/infer.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import Optional
+from typing import Sequence
+
+from mypy.maptype import map_instance_to_supertype
+from mypy.messages import format_type
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import LambdaExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.subtypes import is_subtype
+from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import TypeOfAny
+from mypy.types import UnionType
+
+from . import names
+from . import util
+
+
+def infer_type_from_right_hand_nameexpr(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ infer_from_right_side: RefExpr,
+) -> Optional[ProperType]:
+
+ type_id = names.type_id_for_callee(infer_from_right_side)
+
+ if type_id is None:
+ return None
+ elif type_id is names.COLUMN:
+ python_type_for_type = _infer_type_from_decl_column(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.RELATIONSHIP:
+ python_type_for_type = _infer_type_from_relationship(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.COLUMN_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_column_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.SYNONYM_PROPERTY:
+ python_type_for_type = infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif type_id is names.COMPOSITE_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_composite_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ else:
+ return None
+
+ return python_type_for_type
+
+
+def _infer_type_from_relationship(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a relationship.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses = relationship(Address, uselist=True)
+
+ order: Mapped["Order"] = relationship("Order")
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses: Mapped[List[Address]]
+
+ order: Mapped["Order"]
+
+ """
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type: Optional[ProperType] = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ # type
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+
+ # other cases not covered - an error message directs the user
+ # to set an explicit type annotation
+ #
+ # node.type == str, it's a string
+ # if isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, Var
+ # )
+ # points to a type
+ # isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, TypeAlias
+ # )
+ # string expression
+ # isinstance(target_cls_arg, StrExpr)
+
+ uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
+ collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
+ stmt.rvalue, "collection_class"
+ )
+ type_is_a_collection = False
+
+ # this can be used to determine Optional for a many-to-one
+ # in the same way nullable=False could be used, if we start supporting
+ # that.
+ # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
+
+ if (
+ uselist_arg is not None
+ and api.parse_bool(uselist_arg) is True
+ and collection_cls_arg is None
+ ):
+ type_is_a_collection = True
+ if python_type_for_type is not None:
+ python_type_for_type = api.named_type(
+ names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
+ )
+ elif (
+ uselist_arg is None or api.parse_bool(uselist_arg) is True
+ ) and collection_cls_arg is not None:
+ type_is_a_collection = True
+ if isinstance(collection_cls_arg, CallExpr):
+ collection_cls_arg = collection_cls_arg.callee
+
+ if isinstance(collection_cls_arg, NameExpr) and isinstance(
+ collection_cls_arg.node, TypeInfo
+ ):
+ if python_type_for_type is not None:
+ # this can still be overridden by the left hand side
+ # within _infer_type_from_left_and_inferred_right
+ python_type_for_type = Instance(
+ collection_cls_arg.node, [python_type_for_type]
+ )
+ elif (
+ isinstance(collection_cls_arg, NameExpr)
+ and isinstance(collection_cls_arg.node, FuncDef)
+ and collection_cls_arg.node.type is not None
+ ):
+ if python_type_for_type is not None:
+ # this can still be overridden by the left hand side
+ # within _infer_type_from_left_and_inferred_right
+
+ # TODO: handle mypy.types.Overloaded
+ if isinstance(collection_cls_arg.node.type, CallableType):
+ rt = get_proper_type(collection_cls_arg.node.type.ret_type)
+
+ if isinstance(rt, CallableType):
+ callable_ret_type = get_proper_type(rt.ret_type)
+ if isinstance(callable_ret_type, Instance):
+ python_type_for_type = Instance(
+ callable_ret_type.type,
+ [python_type_for_type],
+ )
+ else:
+ util.fail(
+ api,
+ "Expected Python collection type for "
+ "collection_class parameter",
+ stmt.rvalue,
+ )
+ python_type_for_type = None
+ elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
+ if collection_cls_arg is not None:
+ util.fail(
+ api,
+ "Sending uselist=False and collection_class at the same time "
+ "does not make sense",
+ stmt.rvalue,
+ )
+ if python_type_for_type is not None:
+ python_type_for_type = UnionType(
+ [python_type_for_type, NoneType()]
+ )
+
+ else:
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer scalar or collection for ORM mapped expression "
+ "assigned to attribute '{}' if both 'uselist' and "
+ "'collection_class' arguments are absent from the "
+ "relationship(); please specify a "
+ "type annotation on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ if python_type_for_type is None:
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ if type_is_a_collection:
+ assert isinstance(left_hand_explicit_type, Instance)
+ assert isinstance(python_type_for_type, Instance)
+ return _infer_collection_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_composite_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a CompositeProperty."""
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+ else:
+ python_type_for_type = None
+
+ if python_type_for_type is None:
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_column_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a ColumnProperty.
+
+ This includes mappings against ``column_property()`` as well as the
+ ``deferred()`` function.
+
+ """
+ assert isinstance(stmt.rvalue, CallExpr)
+
+ if stmt.rvalue.args:
+ first_prop_arg = stmt.rvalue.args[0]
+
+ if isinstance(first_prop_arg, CallExpr):
+ type_id = names.type_id_for_callee(first_prop_arg.callee)
+
+ # look for column_property() / deferred() etc with Column as first
+ # argument
+ if type_id is names.COLUMN:
+ return _infer_type_from_decl_column(
+ api,
+ stmt,
+ node,
+ left_hand_explicit_type,
+ right_hand_expression=first_prop_arg,
+ )
+
+ if isinstance(stmt.rvalue, CallExpr):
+ type_id = names.type_id_for_callee(stmt.rvalue.callee)
+ # this is probably not strictly necessary as we have to use the left
+ # hand type for query expression in any case. any other no-arg
+ # column prop objects would go here also
+ if type_id is names.QUERY_EXPRESSION:
+ return _infer_type_from_decl_column(
+ api,
+ stmt,
+ node,
+ left_hand_explicit_type,
+ )
+
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_decl_column(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ right_hand_expression: Optional[CallExpr] = None,
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a Column.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ a = Column(Integer)
+
+ b = Column("b", String)
+
+ c: Mapped[int] = Column(Integer)
+
+ d: bool = Column(Boolean)
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ a : Mapped[int]
+
+ b : Mapped[str]
+
+ c: Mapped[int]
+
+ d: Mapped[bool]
+
+ """
+ assert isinstance(node, Var)
+
+ callee = None
+
+ if right_hand_expression is None:
+ if not isinstance(stmt.rvalue, CallExpr):
+ return None
+
+ right_hand_expression = stmt.rvalue
+
+ for column_arg in right_hand_expression.args[0:2]:
+ if isinstance(column_arg, CallExpr):
+ if isinstance(column_arg.callee, RefExpr):
+ # x = Column(String(50))
+ callee = column_arg.callee
+ type_args: Sequence[Expression] = column_arg.args
+ break
+ elif isinstance(column_arg, (NameExpr, MemberExpr)):
+ if isinstance(column_arg.node, TypeInfo):
+ # x = Column(String)
+ callee = column_arg
+ type_args = ()
+ break
+ else:
+ # x = Column(some_name, String), go to next argument
+ continue
+ elif isinstance(column_arg, (StrExpr,)):
+ # x = Column("name", String), go to next argument
+ continue
+ elif isinstance(column_arg, (LambdaExpr,)):
+ # x = Column("name", String, default=lambda: uuid.uuid4())
+ # go to next argument
+ continue
+ else:
+ assert False
+
+ if callee is None:
+ return None
+
+ if isinstance(callee.node, TypeInfo) and names.mro_has_id(
+ callee.node.mro, names.TYPEENGINE
+ ):
+ python_type_for_type = extract_python_type_from_typeengine(
+ api, callee.node, type_args
+ )
+
+ if left_hand_explicit_type is not None:
+
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+
+ else:
+ return UnionType([python_type_for_type, NoneType()])
+ else:
+ # it's not TypeEngine, it's typically implicitly typed
+ # like ForeignKey. we can't infer from the right side.
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: ProperType,
+ python_type_for_type: ProperType,
+ orig_left_hand_type: Optional[ProperType] = None,
+ orig_python_type_for_type: Optional[ProperType] = None,
+) -> Optional[ProperType]:
+ """Validate type when a left hand annotation is present and we also
+ could infer the right hand side::
+
+ attrname: SomeType = Column(SomeDBType)
+
+ """
+
+ if orig_left_hand_type is None:
+ orig_left_hand_type = left_hand_explicit_type
+ if orig_python_type_for_type is None:
+ orig_python_type_for_type = python_type_for_type
+
+ if not is_subtype(left_hand_explicit_type, python_type_for_type):
+ effective_type = api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
+ )
+
+ msg = (
+ "Left hand assignment '{}: {}' not compatible "
+ "with ORM mapped expression of type {}"
+ )
+ util.fail(
+ api,
+ msg.format(
+ node.name,
+ format_type(orig_left_hand_type),
+ format_type(effective_type),
+ ),
+ node,
+ )
+
+ return orig_left_hand_type
+
+
+def _infer_collection_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Instance,
+ python_type_for_type: Instance,
+) -> Optional[ProperType]:
+ orig_left_hand_type = left_hand_explicit_type
+ orig_python_type_for_type = python_type_for_type
+
+ if left_hand_explicit_type.args:
+ left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
+ python_type_arg = get_proper_type(python_type_for_type.args[0])
+ else:
+ left_hand_arg = left_hand_explicit_type
+ python_type_arg = python_type_for_type
+
+ assert isinstance(left_hand_arg, (Instance, UnionType))
+ assert isinstance(python_type_arg, (Instance, UnionType))
+
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_arg,
+ python_type_arg,
+ orig_left_hand_type=orig_left_hand_type,
+ orig_python_type_for_type=orig_python_type_for_type,
+ )
+
+
+def infer_type_from_left_hand_type_only(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Determine the type based on explicit annotation only.
+
+ If no annotation is present, emit an error noting that one is needed
+ to determine the type.
+
+ """
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer type from ORM mapped expression "
+ "assigned to attribute '{}'; please specify a "
+ "Python type or "
+ "Mapped[<python type>] on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ return api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
+ )
+
+ else:
+ # use type from the left hand side
+ return left_hand_explicit_type
+
+
+def extract_python_type_from_typeengine(
+ api: SemanticAnalyzerPluginInterface,
+ node: TypeInfo,
+ type_args: Sequence[Expression],
+) -> ProperType:
+ if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
+ first_arg = type_args[0]
+ if isinstance(first_arg, RefExpr) and isinstance(
+ first_arg.node, TypeInfo
+ ):
+ for base_ in first_arg.node.mro:
+ if base_.fullname == "enum.Enum":
+ return Instance(first_arg.node, [])
+ # TODO: support other pep-435 types here
+ else:
+ return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])
+
+ assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
+ "could not extract Python type from node: %s" % node
+ )
+
+ type_engine_sym = api.lookup_fully_qualified_or_none(
+ "sqlalchemy.sql.type_api.TypeEngine"
+ )
+
+ assert type_engine_sym is not None and isinstance(
+ type_engine_sym.node, TypeInfo
+ )
+ type_engine = map_instance_to_supertype(
+ Instance(node, []),
+ type_engine_sym.node,
+ )
+ return get_proper_type(type_engine.args[-1])
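+
+# Illustrative behavior (editorial sketch): for ``Column(String(50))`` the
+# TypeEngine resolves to ``str``; for ``Column(Enum(MyEnum))`` where
+# ``MyEnum`` is a pep-435 enum class it resolves to ``MyEnum``; a plain
+# string-based ``Enum("a", "b")`` falls back to ``str``. Callers such as
+# _infer_type_from_decl_column() wrap the result in Optional[...] as needed.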
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
new file mode 100644
index 0000000..8ec15a6
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -0,0 +1,253 @@
+# ext/mypy/names.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Union
+
+from mypy.nodes import ClassDef
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import TypeAlias
+from mypy.nodes import TypeInfo
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import UnboundType
+
+from ... import util
+
+COLUMN: int = util.symbol("COLUMN") # type: ignore
+RELATIONSHIP: int = util.symbol("RELATIONSHIP") # type: ignore
+REGISTRY: int = util.symbol("REGISTRY") # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore
+TYPEENGINE: int = util.symbol("TYPEENGNE") # type: ignore
+MAPPED: int = util.symbol("MAPPED") # type: ignore
+DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") # type: ignore
+MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") # type: ignore
+SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") # type: ignore
+COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") # type: ignore
+DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") # type: ignore
+MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") # type: ignore
+AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") # type: ignore
+AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") # type: ignore
+QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore
+
+# names that must succeed with the semantic analyzer api.named_type() call
+NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
+NAMED_TYPE_BUILTINS_STR = "builtins.str"
+NAMED_TYPE_BUILTINS_LIST = "builtins.list"
+NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
+
+_lookup: Dict[str, Tuple[int, Set[str]]] = {
+ "Column": (
+ COLUMN,
+ {
+ "sqlalchemy.sql.schema.Column",
+ "sqlalchemy.sql.Column",
+ },
+ ),
+ "RelationshipProperty": (
+ RELATIONSHIP,
+ {
+ "sqlalchemy.orm.relationships.RelationshipProperty",
+ "sqlalchemy.orm.RelationshipProperty",
+ },
+ ),
+ "registry": (
+ REGISTRY,
+ {
+ "sqlalchemy.orm.decl_api.registry",
+ "sqlalchemy.orm.registry",
+ },
+ ),
+ "ColumnProperty": (
+ COLUMN_PROPERTY,
+ {
+ "sqlalchemy.orm.properties.ColumnProperty",
+ "sqlalchemy.orm.ColumnProperty",
+ },
+ ),
+ "SynonymProperty": (
+ SYNONYM_PROPERTY,
+ {
+ "sqlalchemy.orm.descriptor_props.SynonymProperty",
+ "sqlalchemy.orm.SynonymProperty",
+ },
+ ),
+ "CompositeProperty": (
+ COMPOSITE_PROPERTY,
+ {
+ "sqlalchemy.orm.descriptor_props.CompositeProperty",
+ "sqlalchemy.orm.CompositeProperty",
+ },
+ ),
+ "MapperProperty": (
+ MAPPER_PROPERTY,
+ {
+ "sqlalchemy.orm.interfaces.MapperProperty",
+ "sqlalchemy.orm.MapperProperty",
+ },
+ ),
+ "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
+ "Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
+ "declarative_base": (
+ DECLARATIVE_BASE,
+ {
+ "sqlalchemy.ext.declarative.declarative_base",
+ "sqlalchemy.orm.declarative_base",
+ "sqlalchemy.orm.decl_api.declarative_base",
+ },
+ ),
+ "DeclarativeMeta": (
+ DECLARATIVE_META,
+ {
+ "sqlalchemy.ext.declarative.DeclarativeMeta",
+ "sqlalchemy.orm.DeclarativeMeta",
+ "sqlalchemy.orm.decl_api.DeclarativeMeta",
+ },
+ ),
+ "mapped": (
+ MAPPED_DECORATOR,
+ {
+ "sqlalchemy.orm.decl_api.registry.mapped",
+ "sqlalchemy.orm.registry.mapped",
+ },
+ ),
+ "as_declarative": (
+ AS_DECLARATIVE,
+ {
+ "sqlalchemy.ext.declarative.as_declarative",
+ "sqlalchemy.orm.decl_api.as_declarative",
+ "sqlalchemy.orm.as_declarative",
+ },
+ ),
+ "as_declarative_base": (
+ AS_DECLARATIVE_BASE,
+ {
+ "sqlalchemy.orm.decl_api.registry.as_declarative_base",
+ "sqlalchemy.orm.registry.as_declarative_base",
+ },
+ ),
+ "declared_attr": (
+ DECLARED_ATTR,
+ {
+ "sqlalchemy.orm.decl_api.declared_attr",
+ "sqlalchemy.orm.declared_attr",
+ },
+ ),
+ "declarative_mixin": (
+ DECLARATIVE_MIXIN,
+ {
+ "sqlalchemy.orm.decl_api.declarative_mixin",
+ "sqlalchemy.orm.declarative_mixin",
+ },
+ ),
+ "query_expression": (
+ QUERY_EXPRESSION,
+ {"sqlalchemy.orm.query_expression"},
+ ),
+}
+
+
+def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
+ for mr in info.mro:
+ check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+ if check_type_id == type_id:
+ break
+ else:
+ return False
+
+ if fullnames is None:
+ return False
+
+ return mr.fullname in fullnames
+
+
+def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
+ for mr in mro:
+ check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+ if check_type_id == type_id:
+ break
+ else:
+ return False
+
+ if fullnames is None:
+ return False
+
+ return mr.fullname in fullnames
+
+
+def type_id_for_unbound_type(
+ type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> Optional[int]:
+ sym = api.lookup_qualified(type_.name, type_)
+ if sym is not None:
+ if isinstance(sym.node, TypeAlias):
+ target_type = get_proper_type(sym.node.target)
+ if isinstance(target_type, Instance):
+ return type_id_for_named_node(target_type.type)
+ elif isinstance(sym.node, TypeInfo):
+ return type_id_for_named_node(sym.node)
+
+ return None
+
+
+def type_id_for_callee(callee: Expression) -> Optional[int]:
+ if isinstance(callee, (MemberExpr, NameExpr)):
+ if isinstance(callee.node, FuncDef):
+ if callee.node.type and isinstance(callee.node.type, CallableType):
+ ret_type = get_proper_type(callee.node.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return type_id_for_fullname(ret_type.type.fullname)
+
+ return None
+ elif isinstance(callee.node, TypeAlias):
+ target_type = get_proper_type(callee.node.target)
+ if isinstance(target_type, Instance):
+ return type_id_for_fullname(target_type.type.fullname)
+ elif isinstance(callee.node, TypeInfo):
+ return type_id_for_named_node(callee)
+ return None
+
+
+def type_id_for_named_node(
+ node: Union[NameExpr, MemberExpr, SymbolNode]
+) -> Optional[int]:
+ type_id, fullnames = _lookup.get(node.name, (None, None))
+
+ if type_id is None or fullnames is None:
+ return None
+ elif node.fullname in fullnames:
+ return type_id
+ else:
+ return None
+
+
+def type_id_for_fullname(fullname: str) -> Optional[int]:
+ tokens = fullname.split(".")
+ immediate = tokens[-1]
+
+ type_id, fullnames = _lookup.get(immediate, (None, None))
+
+ if type_id is None or fullnames is None:
+ return None
+ elif fullname in fullnames:
+ return type_id
+ else:
+ return None
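+
+# Illustrative lookups (editorial sketch): ``type_id_for_fullname(
+# "sqlalchemy.orm.relationships.RelationshipProperty")`` returns
+# RELATIONSHIP, while ``"otherlib.RelationshipProperty"`` returns None,
+# because only fullnames registered in ``_lookup`` for the terminal name
+# are accepted.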
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py
new file mode 100644
index 0000000..8687012
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/plugin.py
@@ -0,0 +1,284 @@
+# ext/mypy/plugin.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+Mypy plugin for SQLAlchemy ORM.
+
+"""
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type as TypingType
+from typing import Union
+
+from mypy import nodes
+from mypy.mro import calculate_mro
+from mypy.mro import MroError
+from mypy.nodes import Block
+from mypy.nodes import ClassDef
+from mypy.nodes import GDEF
+from mypy.nodes import MypyFile
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolTable
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import AttributeContext
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import Plugin
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import Type
+
+from . import decl_class
+from . import names
+from . import util
+
+
+class SQLAlchemyPlugin(Plugin):
+ def get_dynamic_class_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[DynamicClassDefContext], None]]:
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
+ return _dynamic_class_hook
+ return None
+
+ def get_customize_class_mro_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ return _fill_in_decorators
+
+ def get_class_decorator_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+
+ sym = self.lookup_fully_qualified(fullname)
+
+ if sym is not None and sym.node is not None:
+ type_id = names.type_id_for_named_node(sym.node)
+ if type_id is names.MAPPED_DECORATOR:
+ return _cls_decorator_hook
+ elif type_id in (
+ names.AS_DECLARATIVE,
+ names.AS_DECLARATIVE_BASE,
+ ):
+ return _base_cls_decorator_hook
+ elif type_id is names.DECLARATIVE_MIXIN:
+ return _declarative_mixin_hook
+
+ return None
+
+ def get_metaclass_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
+ # Set any classes that explicitly have metaclass=DeclarativeMeta
+ # as declarative so the check in `get_base_class_hook()` works
+ return _metaclass_cls_hook
+
+ return None
+
+ def get_base_class_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ sym = self.lookup_fully_qualified(fullname)
+
+ if (
+ sym
+ and isinstance(sym.node, TypeInfo)
+ and util.has_declarative_base(sym.node)
+ ):
+ return _base_cls_hook
+
+ return None
+
+ def get_attribute_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[AttributeContext], Type]]:
+ if fullname.startswith(
+ "sqlalchemy.orm.attributes.QueryableAttribute."
+ ):
+ return _queryable_getattr_hook
+
+ return None
+
+ def get_additional_deps(
+ self, file: MypyFile
+ ) -> List[Tuple[int, str, int]]:
+ return [
+ (10, "sqlalchemy.orm.attributes", -1),
+ (10, "sqlalchemy.orm.decl_api", -1),
+ ]
+
+
+def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
+ return SQLAlchemyPlugin
+
+
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+ """Generate a declarative Base class when the declarative_base() function
+ is encountered."""
+
+ _add_globals(ctx)
+
+ cls = ClassDef(ctx.name, Block([]))
+ cls.fullname = ctx.api.qualified_name(ctx.name)
+
+ info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+ cls.info = info
+ _set_declarative_metaclass(ctx.api, cls)
+
+ cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+ if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
+ util.set_is_base(cls_arg.node)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ cls_arg.node.defn, ctx.api, is_mixin_scan=True
+ )
+ info.bases = [Instance(cls_arg.node, [])]
+ else:
+ obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+
+ info.bases = [obj]
+
+ try:
+ calculate_mro(info)
+ except MroError:
+ util.fail(
+ ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+ )
+ obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+ info.bases = [obj]
+ info.fallback_to_any = True
+
+ ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
+ util.set_is_base(info)
+
+
+def _fill_in_decorators(ctx: ClassDefContext) -> None:
+ for decorator in ctx.cls.decorators:
+ # set the ".fullname" attribute of a class decorator
+ # that is a MemberExpr. This causes the logic in
+ # semanal.py->apply_class_plugin_hooks to invoke the
+ # get_class_decorator_hook for our "registry.map_class()"
+ # and "registry.as_declarative_base()" methods.
+ # this seems like a bug in mypy that these decorators are otherwise
+ # skipped.
+
+ if (
+ isinstance(decorator, nodes.CallExpr)
+ and isinstance(decorator.callee, nodes.MemberExpr)
+ and decorator.callee.name == "as_declarative_base"
+ ):
+ target = decorator.callee
+ elif (
+ isinstance(decorator, nodes.MemberExpr)
+ and decorator.name == "mapped"
+ ):
+ target = decorator
+ else:
+ continue
+
+ assert isinstance(target.expr, NameExpr)
+ sym = ctx.api.lookup_qualified(
+ target.expr.name, target, suppress_errors=True
+ )
+ if sym and sym.node:
+ sym_type = get_proper_type(sym.type)
+ if isinstance(sym_type, Instance):
+ target.fullname = f"{sym_type.type.fullname}.{target.name}"
+ else:
+ # if the registry is in the same file as where the
+ # decorator is used, it might not have semantic
+ # symbols applied and we can't get a fully qualified
+ # name or an inferred type, so we are actually going to
+ # flag an error in this case that they need to annotate
+ # it. The "registry" is declared just
+ # once (or few times), so they have to just not use
+ # type inference for its assignment in this one case.
+ util.fail(
+ ctx.api,
+ "Class decorator called %s(), but we can't "
+ "tell if it's from an ORM registry. Please "
+ "annotate the registry assignment, e.g. "
+ "my_registry: registry = registry()" % target.name,
+ sym.node,
+ )
+
+
+def _cls_decorator_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ assert isinstance(ctx.reason, nodes.MemberExpr)
+ expr = ctx.reason.expr
+
+ assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
+
+ node_type = get_proper_type(expr.node.type)
+
+ assert (
+ isinstance(node_type, Instance)
+ and names.type_id_for_named_node(node_type.type) is names.REGISTRY
+ )
+
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+
+ cls = ctx.cls
+
+ _set_declarative_metaclass(ctx.api, cls)
+
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ cls, ctx.api, is_mixin_scan=True
+ )
+
+
+def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ ctx.cls, ctx.api, is_mixin_scan=True
+ )
+
+
+def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
+ util.set_is_base(ctx.cls.info)
+
+
+def _base_cls_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
+ # TODO: it is not clear how to tell the checker that an attribute of a
+ # given name does not exist; no available Type seems to express that,
+ # so fall back to the default attribute type.
+ return ctx.default_attr_type
+
+
+def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
+ """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space
+ for all class defs
+
+ """
+
+ util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
+
+
+def _set_declarative_metaclass(
+ api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
+) -> None:
+ info = target_cls.info
+ sym = api.lookup_fully_qualified_or_none(
+ "sqlalchemy.orm.decl_api.DeclarativeMeta"
+ )
+ assert sym is not None and isinstance(sym.node, TypeInfo)
+ info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])
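+
+
+# A rough usage sketch (not part of the plugin code itself): the plugin is
+# enabled through mypy's standard plugin configuration; the config file name
+# below is an assumption and depends on the project layout, e.g. mypy.ini or
+# setup.cfg:
+#
+#     [mypy]
+#     plugins = sqlalchemy.ext.mypy.plugin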
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
new file mode 100644
index 0000000..16b365e
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -0,0 +1,305 @@
+import re
+from typing import Any
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type as TypingType
+from typing import TypeVar
+from typing import Union
+
+from mypy.nodes import ARG_POS
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import CLASSDEF_NO_INFO
+from mypy.nodes import Context
+from mypy.nodes import Expression
+from mypy.nodes import IfStmt
+from mypy.nodes import JsonDict
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import Statement
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.typeops import map_type_from_supertype
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import Type
+from mypy.types import TypeVarType
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+
+_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
+
+
+class SQLAlchemyAttribute:
+ def __init__(
+ self,
+ name: str,
+ line: int,
+ column: int,
+ typ: Optional[Type],
+ info: TypeInfo,
+ ) -> None:
+ self.name = name
+ self.line = line
+ self.column = column
+ self.type = typ
+ self.info = info
+
+ def serialize(self) -> JsonDict:
+ assert self.type
+ return {
+ "name": self.name,
+ "line": self.line,
+ "column": self.column,
+ "type": self.type.serialize(),
+ }
+
+ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
+ """Expands type vars in the context of a subtype when an attribute is
+ inherited from a generic super type.
+ """
+ if not isinstance(self.type, TypeVarType):
+ return
+
+ self.type = map_type_from_supertype(self.type, sub_type, self.info)
+
+ @classmethod
+ def deserialize(
+ cls,
+ info: TypeInfo,
+ data: JsonDict,
+ api: SemanticAnalyzerPluginInterface,
+ ) -> "SQLAlchemyAttribute":
+ data = data.copy()
+ typ = deserialize_and_fixup_type(data.pop("type"), api)
+ return cls(typ=typ, info=info, **data)
+
+
+def name_is_dunder(name: str) -> bool:
+ return bool(re.match(r"^__.+?__$", name))
+
+
+def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
+ info.metadata.setdefault("sqlalchemy", {})[key] = data
+
+
+def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ return info.metadata.get("sqlalchemy", {}).get(key, None)
+
+
+def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ if info.mro:
+ for base in info.mro:
+ metadata = _get_info_metadata(base, key)
+ if metadata is not None:
+ return metadata
+ return None
+
+
+def establish_as_sqlalchemy(info: TypeInfo) -> None:
+ info.metadata.setdefault("sqlalchemy", {})
+
+
+def set_is_base(info: TypeInfo) -> None:
+ _set_info_metadata(info, "is_base", True)
+
+
+def get_is_base(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "is_base")
+ return is_base is True
+
+
+def has_declarative_base(info: TypeInfo) -> bool:
+ is_base = _get_info_mro_metadata(info, "is_base")
+ return is_base is True
+
+
+def set_has_table(info: TypeInfo) -> None:
+ _set_info_metadata(info, "has_table", True)
+
+
+def get_has_table(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "has_table")
+ return is_base is True
+
+
+def get_mapped_attributes(
+ info: TypeInfo, api: SemanticAnalyzerPluginInterface
+) -> Optional[List[SQLAlchemyAttribute]]:
+ mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
+ info, "mapped_attributes"
+ )
+ if mapped_attributes is None:
+ return None
+
+ attributes: List[SQLAlchemyAttribute] = []
+
+ for data in mapped_attributes:
+ attr = SQLAlchemyAttribute.deserialize(info, data, api)
+ attr.expand_typevar_from_subtype(info)
+ attributes.append(attr)
+
+ return attributes
+
+
+def set_mapped_attributes(
+ info: TypeInfo, attributes: List[SQLAlchemyAttribute]
+) -> None:
+ _set_info_metadata(
+ info,
+ "mapped_attributes",
+ [attribute.serialize() for attribute in attributes],
+ )
+
+
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
+ msg = "[SQLAlchemy Mypy plugin] %s" % msg
+ return api.fail(msg, ctx)
+
+
+def add_global(
+ ctx: Union[ClassDefContext, DynamicClassDefContext],
+ module: str,
+ symbol_name: str,
+ asname: str,
+) -> None:
+ module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
+
+ if asname not in module_globals:
+ lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
+ symbol_name
+ ]
+
+ module_globals[asname] = lookup_sym
+
+
+@overload
+def get_callexpr_kwarg(
+ callexpr: CallExpr, name: str, *, expr_types: None = ...
+) -> Optional[Union[CallExpr, NameExpr]]:
+ ...
+
+
+@overload
+def get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Tuple[TypingType[_TArgType], ...]
+) -> Optional[_TArgType]:
+ ...
+
+
+def get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Optional[Tuple[TypingType[Any], ...]] = None
+) -> Optional[Any]:
+ try:
+ arg_idx = callexpr.arg_names.index(name)
+ except ValueError:
+ return None
+
+ kwarg = callexpr.args[arg_idx]
+ if isinstance(
+ kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
+ ):
+ return kwarg
+
+ return None
+
+
+def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
+ for stmt in stmts:
+ if (
+ isinstance(stmt, IfStmt)
+ and isinstance(stmt.expr[0], NameExpr)
+ and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
+ ):
+ for substmt in stmt.body[0].body:
+ yield substmt
+ else:
+ yield stmt
+
+
+def unbound_to_instance(
+ api: SemanticAnalyzerPluginInterface, typ: Type
+) -> Type:
+ """Take the UnboundType that we seem to get as the ret_type from a FuncDef
+ and convert it into an Instance/TypeInfo kind of structure that seems
+ to work as the left-hand type of an AssignmentStatement.
+
+ """
+
+ if not isinstance(typ, UnboundType):
+ return typ
+
+ # TODO: figure out a more robust way to check this. The node is some
+ # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
+ # but I can't figure out how to get them to match up
+ if typ.name == "Optional":
+ # convert from "Optional?" to the more familiar
+ # UnionType[..., NoneType()]
+ return unbound_to_instance(
+ api,
+ UnionType(
+ [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ + [NoneType()]
+ ),
+ )
+
+ node = api.lookup_qualified(typ.name, typ)
+
+ if (
+ node is not None
+ and isinstance(node, SymbolTableNode)
+ and isinstance(node.node, TypeInfo)
+ ):
+ bound_type = node.node
+
+ return Instance(
+ bound_type,
+ [
+ unbound_to_instance(api, arg)
+ if isinstance(arg, UnboundType)
+ else arg
+ for arg in typ.args
+ ],
+ )
+ else:
+ return typ
+
+
+def info_for_cls(
+ cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> Optional[TypeInfo]:
+ if cls.info is CLASSDEF_NO_INFO:
+ sym = api.lookup_qualified(cls.name, cls)
+ if sym is None:
+ return None
+ assert sym and isinstance(sym.node, TypeInfo)
+ return sym.node
+
+ return cls.info
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+ column_descriptor = NameExpr("__sa_Mapped")
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
+ member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+ return CallExpr(
+ member_expr,
+ [expr],
+ [ARG_POS],
+ ["arg1"],
+ )
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
new file mode 100644
index 0000000..5a327d1
--- /dev/null
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -0,0 +1,388 @@
+# ext/orderinglist.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""A custom list that manages index/position information for contained
+elements.
+
+:author: Jason Kirtland
+
+``orderinglist`` is a helper for mutable ordered relationships. It will
+intercept list operations performed on a :func:`_orm.relationship`-managed
+collection and
+automatically synchronize changes in list position onto a target scalar
+attribute.
+
+Example: A ``slide`` table, where each row refers to zero or more entries
+in a related ``bullet`` table. The bullets within a slide are
+displayed in order based on the value of the ``position`` column in the
+``bullet`` table. As entries are reordered in memory, the value of the
+``position`` attribute should be updated to reflect the new sort order::
+
+
+ Base = declarative_base()
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position")
+
+ class Bullet(Base):
+ __tablename__ = 'bullet'
+ id = Column(Integer, primary_key=True)
+ slide_id = Column(Integer, ForeignKey('slide.id'))
+ position = Column(Integer)
+ text = Column(String)
+
+The standard relationship mapping will produce a list-like attribute on each
+``Slide`` containing all related ``Bullet`` objects,
+but coping with changes in ordering is not handled automatically.
+When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
+attribute will remain unset until manually assigned. When the ``Bullet``
+is inserted into the middle of the list, the following ``Bullet`` objects
+will also need to be renumbered.
+
+The :class:`.OrderingList` object automates this task, managing the
+``position`` attribute on all ``Bullet`` objects in the collection. It is
+constructed using the :func:`.ordering_list` factory::
+
+ from sqlalchemy.ext.orderinglist import ordering_list
+
+ Base = declarative_base()
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position",
+ collection_class=ordering_list('position'))
+
+ class Bullet(Base):
+ __tablename__ = 'bullet'
+ id = Column(Integer, primary_key=True)
+ slide_id = Column(Integer, ForeignKey('slide.id'))
+ position = Column(Integer)
+ text = Column(String)
+
+With the above mapping the ``Bullet.position`` attribute is managed::
+
+ s = Slide()
+ s.bullets.append(Bullet())
+ s.bullets.append(Bullet())
+ s.bullets[1].position
+ >>> 1
+ s.bullets.insert(1, Bullet())
+ s.bullets[2].position
+ >>> 2
+
+The :class:`.OrderingList` construct only works with **changes** to a
+collection, and not the initial load from the database, and requires that the
+list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
+:func:`_orm.relationship` against the target ordering attribute, so that the
+ordering is correct when first loaded.
+
+.. warning::
+
+ :class:`.OrderingList` only provides limited functionality when a primary
+ key column or unique column is the target of the sort. Operations
+ that are unsupported or are problematic include:
+
+ * two entries must trade values. This is not supported directly in the
+ case of a primary key or unique constraint because it means at least
+ one row would need to be temporarily removed first, or changed to
+ a third, neutral value while the switch occurs.
+
+ * an entry must be deleted in order to make room for a new entry.
+ SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
+ single flush. In the case of a primary key, it will trade
+ an INSERT/DELETE of the same primary key for an UPDATE statement in order
+ to lessen the impact of this limitation, however this does not take place
+ for a UNIQUE column.
+ A future feature will allow the "DELETE before INSERT" behavior to be
+ possible, alleviating this limitation, though this feature will require
+ explicit configuration at the mapper level for sets of columns that
+ are to be handled in this way.
+
+:func:`.ordering_list` takes the name of the related object's ordering
+attribute as an argument. By default, the zero-based integer index of the
+object's position in the :func:`.ordering_list` is synchronized with the
+ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
+start numbering at 1 or some other integer, provide ``count_from=1``.
+
+
+"""
+from ..orm.collections import collection
+from ..orm.collections import collection_adapter
+
+
+__all__ = ["ordering_list"]
+
+
+def ordering_list(attr, count_from=None, **kw):
+ """Prepares an :class:`OrderingList` factory for use in mapper definitions.
+
+ Returns an object suitable for use as an argument to a Mapper
+ relationship's ``collection_class`` option. e.g.::
+
+ from sqlalchemy.ext.orderinglist import ordering_list
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position",
+ collection_class=ordering_list('position'))
+
+ :param attr:
+ Name of the mapped attribute to use for storage and retrieval of
+ ordering information
+
+ :param count_from:
+ Set up an integer-based ordering, starting at ``count_from``. For
+ example, ``ordering_list('pos', count_from=1)`` would create a 1-based
+ list in SQL, storing the value in the 'pos' column. Ignored if
+ ``ordering_func`` is supplied.
+
+ Additional arguments are passed to the :class:`.OrderingList` constructor.
+
+ """
+
+ kw = _unsugar_count_from(count_from=count_from, **kw)
+ return lambda: OrderingList(attr, **kw)
+
+
+# Ordering utility functions
+
+
+def count_from_0(index, collection):
+ """Numbering function: consecutive integers starting at 0."""
+
+ return index
+
+
+def count_from_1(index, collection):
+ """Numbering function: consecutive integers starting at 1."""
+
+ return index + 1
+
+
+def count_from_n_factory(start):
+ """Numbering function: consecutive integers starting at arbitrary start."""
+
+ def f(index, collection):
+ return index + start
+
+ try:
+ f.__name__ = "count_from_%i" % start
+ except TypeError:
+ pass
+ return f
+
+
+def _unsugar_count_from(**kw):
+ """Builds counting functions from keyword arguments.
+
+ Acts as a keyword-argument filter: prepares a simple ``ordering_func``
+ from a ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
+ """
+
+ count_from = kw.pop("count_from", None)
+ if kw.get("ordering_func", None) is None and count_from is not None:
+ if count_from == 0:
+ kw["ordering_func"] = count_from_0
+ elif count_from == 1:
+ kw["ordering_func"] = count_from_1
+ else:
+ kw["ordering_func"] = count_from_n_factory(count_from)
+ return kw
+
+
+class OrderingList(list):
+ """A custom list that manages position information for its children.
+
+ The :class:`.OrderingList` object is normally set up using the
+ :func:`.ordering_list` factory function, used in conjunction with
+ the :func:`_orm.relationship` function.
+
+ """
+
+ def __init__(
+ self, ordering_attr=None, ordering_func=None, reorder_on_append=False
+ ):
+ """A custom list that manages position information for its children.
+
+ ``OrderingList`` is a ``collection_class`` list implementation that
+ syncs position in a Python list with a position attribute on the
+ mapped objects.
+
+ This implementation relies on the list starting in the proper order,
+ so be **sure** to put an ``order_by`` on your relationship.
+
+ :param ordering_attr:
+ Name of the attribute that stores the object's order in the
+ relationship.
+
+ :param ordering_func: Optional. A function that maps the position in
+ the Python list to a value to store in the
+ ``ordering_attr``. Values returned are usually (but need not be!)
+ integers.
+
+ An ``ordering_func`` is called with two positional parameters: the
+ index of the element in the list, and the list itself.
+
+ If omitted, Python list indexes are used for the attribute values.
+ Two basic pre-built numbering functions are provided in this module:
+ ``count_from_0`` and ``count_from_1``. For more exotic examples
+ like stepped numbering, alphabetical and Fibonacci numbering, see
+ the unit tests.
+
+ :param reorder_on_append:
+ Default False. When appending an object with an existing (non-None)
+ ordering value, that value will be left untouched unless
+ ``reorder_on_append`` is true. This is an optimization to avoid a
+ variety of dangerous unexpected database writes.
+
+ SQLAlchemy will add instances to the list via append() when your
+ object loads. If for some reason the result set from the database
+ skips a step in the ordering (say, row '1' is missing but you get
+ '2', '3', and '4'), reorder_on_append=True would immediately
+ renumber the items to '1', '2', '3'. If you have multiple sessions
+ making changes, any of whom happen to load this collection even in
+ passing, all of the sessions would try to "clean up" the numbering
+ in their commits, possibly causing all but one to fail with a
+ concurrent modification error.
+
+ It is recommended to leave this at the default of False, and to call
+ ``reorder()`` explicitly if you're doing ``append()`` operations with
+ previously ordered instances or when doing some housekeeping after
+ manual SQL operations.
+
+ """
+ self.ordering_attr = ordering_attr
+ if ordering_func is None:
+ ordering_func = count_from_0
+ self.ordering_func = ordering_func
+ self.reorder_on_append = reorder_on_append
+
+ # More complex serialization schemes (multi column, e.g.) are possible by
+ # subclassing and reimplementing these two methods.
+ def _get_order_value(self, entity):
+ return getattr(entity, self.ordering_attr)
+
+ def _set_order_value(self, entity, value):
+ setattr(entity, self.ordering_attr, value)
+
+ def reorder(self):
+ """Synchronize ordering for the entire collection.
+
+ Sweeps through the list and ensures that each object has accurate
+ ordering information set.
+
+ """
+ for index, entity in enumerate(self):
+ self._order_entity(index, entity, True)
+
+ # As of 0.5, _reorder is no longer semi-private
+ _reorder = reorder
+
+ def _order_entity(self, index, entity, reorder=True):
+ have = self._get_order_value(entity)
+
+ # Don't disturb existing ordering if reorder is False
+ if have is not None and not reorder:
+ return
+
+ should_be = self.ordering_func(index, self)
+ if have != should_be:
+ self._set_order_value(entity, should_be)
+
+ def append(self, entity):
+ super(OrderingList, self).append(entity)
+ self._order_entity(len(self) - 1, entity, self.reorder_on_append)
+
+ def _raw_append(self, entity):
+ """Append without any ordering behavior."""
+
+ super(OrderingList, self).append(entity)
+
+ _raw_append = collection.adds(1)(_raw_append)
+
+ def insert(self, index, entity):
+ super(OrderingList, self).insert(index, entity)
+ self._reorder()
+
+ def remove(self, entity):
+ super(OrderingList, self).remove(entity)
+
+ adapter = collection_adapter(self)
+ if adapter and adapter._referenced_by_owner:
+ self._reorder()
+
+ def pop(self, index=-1):
+ entity = super(OrderingList, self).pop(index)
+ self._reorder()
+ return entity
+
+ def __setitem__(self, index, entity):
+ if isinstance(index, slice):
+ step = index.step or 1
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ stop = index.stop or len(self)
+ if stop < 0:
+ stop += len(self)
+
+ for i in range(start, stop, step):
+ self.__setitem__(i, entity[i])
+ else:
+ self._order_entity(index, entity, True)
+ super(OrderingList, self).__setitem__(index, entity)
+
+ def __delitem__(self, index):
+ super(OrderingList, self).__delitem__(index)
+ self._reorder()
+
+ def __setslice__(self, start, end, values):
+ super(OrderingList, self).__setslice__(start, end, values)
+ self._reorder()
+
+ def __delslice__(self, start, end):
+ super(OrderingList, self).__delslice__(start, end)
+ self._reorder()
+
+ def __reduce__(self):
+ return _reconstitute, (self.__class__, self.__dict__, list(self))
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
+ func.__doc__ = getattr(list, func_name).__doc__
+ del func_name, func
+
+
+def _reconstitute(cls, dict_, items):
+ """Reconstitute an :class:`.OrderingList`.
+
+ This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
+ unpickling :class:`.OrderingList` objects.
+
+ """
+ obj = cls.__new__(cls)
+ obj.__dict__.update(dict_)
+ list.extend(obj, items)
+ return obj
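+
+
+# A minimal sketch of a custom ``ordering_func`` (the ``stepped`` helper is
+# hypothetical, and the ``Bullet``/``Slide`` mapping is assumed from the
+# module docstring); positions are stored in steps of 10 so that manual
+# renumbering can leave gaps:
+#
+#     def stepped(index, collection):
+#         return index * 10
+#
+#     bullets = relationship(
+#         "Bullet",
+#         order_by="Bullet.position",
+#         collection_class=ordering_list("position", ordering_func=stepped),
+#     )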
diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py
new file mode 100644
index 0000000..094b71b
--- /dev/null
+++ b/lib/sqlalchemy/ext/serializer.py
@@ -0,0 +1,177 @@
+# ext/serializer.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
+allowing "contextual" deserialization.
+
+Any SQLAlchemy query structure, either based on sqlalchemy.sql.*
+or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session
+etc. which are referenced by the structure are not persisted in serialized
+form, but are instead re-associated with the query structure
+when it is deserialized.
+
+Usage is nearly the same as that of the standard Python pickle module::
+
+ from sqlalchemy.ext.serializer import loads, dumps
+ metadata = MetaData(bind=some_engine)
+ Session = scoped_session(sessionmaker())
+
+ # ... define mappers
+
+ query = Session.query(MyClass).\
+ filter(MyClass.somedata == 'foo').order_by(MyClass.sortkey)
+
+ # pickle the query
+ serialized = dumps(query)
+
+ # unpickle. Pass in metadata + scoped_session
+ query2 = loads(serialized, metadata, Session)
+
+ print(query2.all())
+
+The same restrictions that apply to raw pickle apply here; mapped classes must
+themselves be pickleable, meaning they are importable from a module-level
+namespace.
+
+The serializer module is only appropriate for query structures. It is not
+needed for:
+
+* instances of user-defined classes. These contain no references to engines,
+ sessions or expression constructs in the typical case and can be serialized
+ directly.
+
+* Table metadata that is to be loaded entirely from the serialized structure
+ (i.e. is not already declared in the application). Regular
+ pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
+ typically one which was reflected from an existing database at some previous
+ point in time. The serializer module is specifically for the opposite case,
+ where the Table metadata is already present in memory.
+
+"""
+
+import re
+
+from .. import Column
+from .. import Table
+from ..engine import Engine
+from ..orm import class_mapper
+from ..orm.interfaces import MapperProperty
+from ..orm.mapper import Mapper
+from ..orm.session import Session
+from ..util import b64decode
+from ..util import b64encode
+from ..util import byte_buffer
+from ..util import pickle
+from ..util import text_type
+
+
+__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
+
+
+def Serializer(*args, **kw):
+ pickler = pickle.Pickler(*args, **kw)
+
+ def persistent_id(obj):
+ # print "serializing:", repr(obj)
+ if isinstance(obj, Mapper) and not obj.non_primary:
+ id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
+ elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
+ id_ = (
+ "mapperprop:"
+ + b64encode(pickle.dumps(obj.parent.class_))
+ + ":"
+ + obj.key
+ )
+ elif isinstance(obj, Table):
+ if "parententity" in obj._annotations:
+ id_ = "mapper_selectable:" + b64encode(
+ pickle.dumps(obj._annotations["parententity"].class_)
+ )
+ else:
+ id_ = "table:" + text_type(obj.key)
+ elif isinstance(obj, Column) and isinstance(obj.table, Table):
+ id_ = (
+ "column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
+ )
+ elif isinstance(obj, Session):
+ id_ = "session:"
+ elif isinstance(obj, Engine):
+ id_ = "engine:"
+ else:
+ return None
+ return id_
+
+ pickler.persistent_id = persistent_id
+ return pickler
+
+
+our_ids = re.compile(
+ r"(mapperprop|mapper|mapper_selectable|table|column|"
+ r"session|attribute|engine):(.*)"
+)
+
+
+def Deserializer(file, metadata=None, scoped_session=None, engine=None):
+ unpickler = pickle.Unpickler(file)
+
+ def get_engine():
+ if engine:
+ return engine
+ elif scoped_session and scoped_session().bind:
+ return scoped_session().bind
+ elif metadata and metadata.bind:
+ return metadata.bind
+ else:
+ return None
+
+ def persistent_load(id_):
+ m = our_ids.match(text_type(id_))
+ if not m:
+ return None
+ else:
+ type_, args = m.group(1, 2)
+ if type_ == "attribute":
+ key, clsarg = args.split(":")
+ cls = pickle.loads(b64decode(clsarg))
+ return getattr(cls, key)
+ elif type_ == "mapper":
+ cls = pickle.loads(b64decode(args))
+ return class_mapper(cls)
+ elif type_ == "mapper_selectable":
+ cls = pickle.loads(b64decode(args))
+ return class_mapper(cls).__clause_element__()
+ elif type_ == "mapperprop":
+ mapper, keyname = args.split(":")
+ cls = pickle.loads(b64decode(mapper))
+ return class_mapper(cls).attrs[keyname]
+ elif type_ == "table":
+ return metadata.tables[args]
+ elif type_ == "column":
+ table, colname = args.split(":")
+ return metadata.tables[table].c[colname]
+ elif type_ == "session":
+ return scoped_session()
+ elif type_ == "engine":
+ return get_engine()
+ else:
+ raise Exception("Unknown token: %s" % type_)
+
+ unpickler.persistent_load = persistent_load
+ return unpickler
+
+
+def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
+ buf = byte_buffer()
+ pickler = Serializer(buf, protocol)
+ pickler.dump(obj)
+ return buf.getvalue()
+
+
+def loads(data, metadata=None, scoped_session=None, engine=None):
+ buf = byte_buffer(data)
+ unpickler = Deserializer(buf, metadata, scoped_session, engine)
+ return unpickler.load()
diff --git a/lib/sqlalchemy/future/__init__.py b/lib/sqlalchemy/future/__init__.py
new file mode 100644
index 0000000..a2bed07
--- /dev/null
+++ b/lib/sqlalchemy/future/__init__.py
@@ -0,0 +1,18 @@
+# sql/future/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Future 2.0 API features.
+
+"""
+from .engine import Connection
+from .engine import create_engine
+from .engine import Engine
+from ..sql.selectable import Select
+from ..util.langhelpers import public_factory
+
+
+select = public_factory(Select._create_future_select, ".future.select")
diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py
new file mode 100644
index 0000000..3235529
--- /dev/null
+++ b/lib/sqlalchemy/future/engine.py
@@ -0,0 +1,413 @@
+from .. import util
+from ..engine import Connection as _LegacyConnection
+from ..engine import create_engine as _create_engine
+from ..engine import Engine as _LegacyEngine
+from ..engine.base import OptionEngineMixin
+
+NO_OPTIONS = util.immutabledict()
+
+
+def create_engine(*arg, **kw):
+ """Create a new :class:`_future.Engine` instance.
+
+ Arguments passed to :func:`_future.create_engine` are mostly identical
+ to those passed to the 1.x :func:`_sa.create_engine` function.
+ The difference is that the object returned is the :class:`._future.Engine`
+ which has the 2.0 version of the API.
+
+ """
+
+ kw["_future_engine_class"] = Engine
+ return _create_engine(*arg, **kw)
+
+
+class Connection(_LegacyConnection):
+ """Provides high-level functionality for a wrapped DB-API connection.
+
+ The :class:`_future.Connection` object is procured by calling
+ the :meth:`_future.Engine.connect` method of the :class:`_future.Engine`
+ object, and provides services for execution of SQL statements as well
+ as transaction control.
+
+ **This is the SQLAlchemy 2.0 version** of the :class:`_engine.Connection`
+ class. The API and behavior of this object is largely the same, with the
+ following differences in behavior:
+
+ * The result object returned for statement executions is the
+ :class:`_engine.CursorResult`
+ object, which is a subclass of :class:`_engine.Result`.
+ This object has a slightly different API and behavior than the
+ :class:`_engine.LegacyCursorResult` returned for 1.x style usage.
+
+ * The object has :meth:`_future.Connection.commit` and
+ :meth:`_future.Connection.rollback` methods which commit or roll back
+ the current transaction in progress, if any.
+
+ * The object features "autobegin" behavior, such that any call to
+ :meth:`_future.Connection.execute` will
+ unconditionally start a
+ transaction which can be controlled using the above mentioned
+ :meth:`_future.Connection.commit` and
+ :meth:`_future.Connection.rollback` methods.
+
+ * The object does not have any "autocommit" functionality. Any SQL
+ statement or DDL statement will not be followed by any COMMIT until
+ the transaction is explicitly committed, either via the
+ :meth:`_future.Connection.commit` method, or if the connection is
+ being used in a context manager that commits such as the one
+ returned by :meth:`_future.Engine.begin`.
+
+ * The SAVEPOINT method :meth:`_future.Connection.begin_nested` returns
+ a :class:`_engine.NestedTransaction` as was always the case, and the
+ savepoint can be controlled by invoking
+ :meth:`_engine.NestedTransaction.commit` or
+ :meth:`_engine.NestedTransaction.rollback` as was the case before.
+ However, this savepoint "transaction" is not associated with the
+ transaction that is controlled by the connection itself; the overall
+ transaction can be committed or rolled back directly which will not emit
+ any special instructions for the SAVEPOINT (this will typically have the
+ effect that one desires).
+
+ * The :class:`_future.Connection` object does not support "branching",
+ which was a pattern by which a sub "connection" would be used that
+ refers to this connection as a parent.
+
+
+
+ """
+
+ _is_future = True
+
+ def _branch(self):
+ raise NotImplementedError(
+ "sqlalchemy.future.Connection does not support "
+ "'branching' of new connections."
+ )
+
+ def begin(self):
+ """Begin a transaction prior to autobegin occurring.
+
+ The returned object is an instance of :class:`_engine.RootTransaction`.
+ This object represents the "scope" of the transaction,
+ which completes when either the :meth:`_engine.Transaction.rollback`
+ or :meth:`_engine.Transaction.commit` method is called.
+
+ The :meth:`_future.Connection.begin` method in SQLAlchemy 2.0 begins a
+ transaction that normally will be begun in any case when the connection
+ is first used to execute a statement. The reason this method might be
+ used would be to invoke the :meth:`_events.ConnectionEvents.begin`
+ event at a specific time, or to organize code within the scope of a
+ connection checkout in terms of context managed blocks, such as::
+
+ with engine.connect() as conn:
+ with conn.begin():
+ conn.execute(...)
+ conn.execute(...)
+
+ with conn.begin():
+ conn.execute(...)
+ conn.execute(...)
+
+ The above code is not fundamentally any different in its behavior than
+ the following code, which does not use
+ :meth:`_future.Connection.begin`; the style below is referred to
+ as "commit as you go" style::
+
+ with engine.connect() as conn:
+ conn.execute(...)
+ conn.execute(...)
+ conn.commit()
+
+ conn.execute(...)
+ conn.execute(...)
+ conn.commit()
+
+ From a database point of view, the :meth:`_future.Connection.begin`
+ method does not emit any SQL or change the state of the underlying
+ DBAPI connection in any way; the Python DBAPI does not have any
+ concept of explicit transaction begin.
+
+ .. seealso::
+
+ :ref:`tutorial_working_with_transactions` - in the
+ :ref:`unified_tutorial`
+
+ :meth:`_future.Connection.begin_nested` - use a SAVEPOINT
+
+ :meth:`_engine.Connection.begin_twophase` -
+ use a two phase /XID transaction
+
+ :meth:`_future.Engine.begin` - context manager available from
+ :class:`_future.Engine`
+
+ """
+ return super(Connection, self).begin()
+
+ def begin_nested(self):
+ """Begin a nested transaction (i.e. SAVEPOINT) and return a transaction
+ handle.
+
+ The returned object is an instance of
+ :class:`_engine.NestedTransaction`.
+
+ Nested transactions require SAVEPOINT support in the
+ underlying database. Any transaction in the hierarchy may
+ ``commit`` and ``rollback``; however, the outermost transaction
+ still controls the overall ``commit`` or ``rollback`` of the
+ transaction as a whole.
+
+ If an outer :class:`.RootTransaction` is not present on this
+ :class:`_future.Connection`, a new one is created using "autobegin".
+ This outer transaction may be completed using "commit-as-you-go" style
+ usage, by calling upon :meth:`_future.Connection.commit` or
+ :meth:`_future.Connection.rollback`.
+
+ .. tip::
+
+ The "autobegin" behavior of :meth:`_future.Connection.begin_nested`
+ is specific to :term:`2.0 style` use; for legacy behaviors, see
+ :meth:`_engine.Connection.begin_nested`.
+
+ The :class:`_engine.NestedTransaction` remains independent of the
+ :class:`_future.Connection` object itself. Calling the
+ :meth:`_future.Connection.commit` or
+ :meth:`_future.Connection.rollback` will always affect the actual
+ containing database transaction itself, and not the SAVEPOINT itself.
+ When a database transaction is committed, any SAVEPOINTs that have been
+ established are cleared and the data changes within their scope are also
+ committed.
+
+ .. seealso::
+
+ :meth:`_future.Connection.begin`
+
+
+ """
+ return super(Connection, self).begin_nested()
+
+ def commit(self):
+ """Commit the transaction that is currently in progress.
+
+ This method commits the current transaction if one has been started.
+ If no transaction was started, the method has no effect, assuming
+ the connection is in a non-invalidated state.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+ .. note:: The :meth:`_future.Connection.commit` method only acts upon
+ the primary database transaction that is linked to the
+ :class:`_future.Connection` object. It does not operate upon a
+ SAVEPOINT that would have been invoked from the
+ :meth:`_future.Connection.begin_nested` method; for control of a
+ SAVEPOINT, call :meth:`_engine.NestedTransaction.commit` on the
+ :class:`_engine.NestedTransaction` that is returned by the
+ :meth:`_future.Connection.begin_nested` method itself.
+
+
+ """
+ if self._transaction:
+ self._transaction.commit()
+
+ def rollback(self):
+ """Roll back the transaction that is currently in progress.
+
+ This method rolls back the current transaction if one has been started.
+ If no transaction was started, the method has no effect. If a
+ transaction was started and the connection is in an invalidated state,
+ the transaction is cleared using this method.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+ .. note:: The :meth:`_future.Connection.rollback` method only acts
+ upon the primary database transaction that is linked to the
+ :class:`_future.Connection` object. It does not operate upon a
+ SAVEPOINT that would have been invoked from the
+ :meth:`_future.Connection.begin_nested` method; for control of a
+ SAVEPOINT, call :meth:`_engine.NestedTransaction.rollback` on the
+ :class:`_engine.NestedTransaction` that is returned by the
+ :meth:`_future.Connection.begin_nested` method itself.
+
+
+ """
+ if self._transaction:
+ self._transaction.rollback()
+
+ def close(self):
+ """Close this :class:`_future.Connection`.
+
+ This has the effect of also calling :meth:`_future.Connection.rollback`
+ if any transaction is in place.
+
+ """
+ super(Connection, self).close()
+
+ def execute(self, statement, parameters=None, execution_options=None):
+ r"""Executes a SQL statement construct and returns a
+ :class:`_engine.Result`.
+
+ :param statement: The statement to be executed. This is always
+ an object that is in both the :class:`_expression.ClauseElement` and
+ :class:`_expression.Executable` hierarchies, including:
+
+ * :class:`_expression.Select`
+ * :class:`_expression.Insert`, :class:`_expression.Update`,
+ :class:`_expression.Delete`
+ * :class:`_expression.TextClause` and
+ :class:`_expression.TextualSelect`
+ * :class:`_schema.DDL` and objects which inherit from
+ :class:`_schema.DDLElement`
+
+ :param parameters: parameters which will be bound into the statement.
+ This may be either a dictionary of parameter names to values,
+ or a mutable sequence (e.g. a list) of dictionaries. When a
+ list of dictionaries is passed, the underlying statement execution
+ will make use of the DBAPI ``cursor.executemany()`` method.
+ When a single dictionary is passed, the DBAPI ``cursor.execute()``
+ method will be used.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_future.Connection.execution_options`.
+
+ :return: a :class:`_engine.Result` object.
+
+ """
+ return self._execute_20(
+ statement, parameters, execution_options or NO_OPTIONS
+ )
+
+ def scalar(self, statement, parameters=None, execution_options=None):
+ r"""Executes a SQL statement construct and returns a scalar object.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalar` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a scalar Python value representing the first column of the
+ first row returned.
+
+ """
+ return self.execute(statement, parameters, execution_options).scalar()
+
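+ # A minimal "commit as you go" sketch of the two parameter styles described
+ # above (the table name and ``engine`` are assumptions for illustration):
+ # a single dictionary uses cursor.execute(), a list of dictionaries uses
+ # cursor.executemany():
+ #
+ #     from sqlalchemy import text
+ #
+ #     with engine.connect() as conn:
+ #         conn.execute(text("INSERT INTO t (x) VALUES (:x)"), {"x": 1})
+ #         conn.execute(
+ #             text("INSERT INTO t (x) VALUES (:x)"),
+ #             [{"x": 2}, {"x": 3}],
+ #         )
+ #         conn.commit()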
+
+class Engine(_LegacyEngine):
+ """Connects a :class:`_pool.Pool` and
+ :class:`_engine.Dialect` together to provide a
+ source of database connectivity and behavior.
+
+ **This is the SQLAlchemy 2.0 version** of the :class:`~.engine.Engine`.
+
+ An :class:`.future.Engine` object is instantiated publicly using the
+ :func:`~sqlalchemy.future.create_engine` function.
+
+ .. seealso::
+
+ :doc:`/core/engines`
+
+ :ref:`connections_toplevel`
+
+ """
+
+ _connection_cls = Connection
+ _is_future = True
+
+ def _not_implemented(self, *arg, **kw):
+ raise NotImplementedError(
+ "This method is not implemented for SQLAlchemy 2.0."
+ )
+
+ transaction = (
+ run_callable
+ ) = (
+ execute
+ ) = (
+ scalar
+ ) = (
+ _execute_clauseelement
+ ) = _execute_compiled = table_names = has_table = _not_implemented
+
+ def _run_ddl_visitor(self, visitorcallable, element, **kwargs):
+ # TODO: this is for create_all support etc. not clear if we
+ # want to provide this in 2.0, that is, a way to execute SQL where
+ # they aren't calling "engine.begin()" explicitly, however, DDL
+ # may be a special case for which we want to continue doing it this
+ # way. A big win here is that the full DDL sequence is inside of a
+ # single transaction rather than COMMIT for each statement.
+ with self.begin() as conn:
+ conn._run_ddl_visitor(visitorcallable, element, **kwargs)
+
+ @classmethod
+ def _future_facade(self, legacy_engine):
+ return Engine(
+ legacy_engine.pool,
+ legacy_engine.dialect,
+ legacy_engine.url,
+ logging_name=legacy_engine.logging_name,
+ echo=legacy_engine.echo,
+ hide_parameters=legacy_engine.hide_parameters,
+ execution_options=legacy_engine._execution_options,
+ )
+
+ @util.contextmanager
+ def begin(self):
+ """Return a :class:`_future.Connection` object with a transaction
+ begun.
+
+ Use of this method is similar to that of
+ :meth:`_future.Engine.connect`, typically as a context manager, which
+ will automatically maintain the state of the transaction when the block
+ ends, either by calling :meth:`_future.Connection.commit` when the
+ block succeeds normally, or :meth:`_future.Connection.rollback` when an
+ exception is raised, before propagating the exception outwards::
+
+ with engine.begin() as connection:
+ connection.execute(text("insert into table values ('foo')"))
+
+
+ .. seealso::
+
+ :meth:`_future.Engine.connect`
+
+ :meth:`_future.Connection.begin`
+
+ """
+ with self.connect() as conn:
+ with conn.begin():
+ yield conn
+
+ def connect(self):
+ """Return a new :class:`_future.Connection` object.
+
+ The :class:`_future.Connection` acts as a Python context manager, so
+ the typical use of this method looks like::
+
+ with engine.connect() as connection:
+ connection.execute(text("insert into table values ('foo')"))
+ connection.commit()
+
+ Where above, after the block is completed, the connection is "closed"
+ and its underlying DBAPI resources are returned to the connection pool.
+ This also has the effect of rolling back any transaction that
+ was explicitly begun or was begun via autobegin, and will
+ emit the :meth:`_events.ConnectionEvents.rollback` event if one was
+ started and is still in progress.
+
+ .. seealso::
+
+ :meth:`_future.Engine.begin`
+
+
+ """
+ return super(Engine, self).connect()
+
+
+class OptionEngine(OptionEngineMixin, Engine):
+ pass
+
+
+Engine._option_cls = OptionEngine
diff --git a/lib/sqlalchemy/future/orm/__init__.py b/lib/sqlalchemy/future/orm/__init__.py
new file mode 100644
index 0000000..629631b
--- /dev/null
+++ b/lib/sqlalchemy/future/orm/__init__.py
@@ -0,0 +1,10 @@
+# sql/future/orm/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Future 2.0 API features for Orm.
+
+"""
diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py
new file mode 100644
index 0000000..7f9822d
--- /dev/null
+++ b/lib/sqlalchemy/inspection.py
@@ -0,0 +1,93 @@
+# sqlalchemy/inspection.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The inspection module provides the :func:`_sa.inspect` function,
+which delivers runtime information about a wide variety
+of SQLAlchemy objects, both within the Core as well as the
+ORM.
+
+The :func:`_sa.inspect` function is the entry point to SQLAlchemy's
+public API for viewing the configuration and construction
+of in-memory objects. Depending on the type of object
+passed to :func:`_sa.inspect`, the return value will either be
+a related object which provides a known interface, or in many
+cases it will return the object itself.
+
+The rationale for :func:`_sa.inspect` is twofold. One is that
+it replaces the need to be aware of a large variety of "information
+getting" functions in SQLAlchemy, such as
+:meth:`_reflection.Inspector.from_engine` (deprecated in 1.4),
+:func:`.orm.attributes.instance_state`, :func:`_orm.class_mapper`,
+and others. The other is that the return value of :func:`_sa.inspect`
+is guaranteed to obey a documented API, thus allowing third party
+tools which build on top of SQLAlchemy configurations to be constructed
+in a forwards-compatible way.
+
+"""
+
+from . import exc
+from . import util
+
+
+_registrars = util.defaultdict(list)
+
+
+def inspect(subject, raiseerr=True):
+ """Produce an inspection object for the given target.
+
+ The returned value in some cases may be the
+ same object as the one given, such as if a
+ :class:`_orm.Mapper` object is passed. In other
+ cases, it will be an instance of the registered
+ inspection type for the given object, such as
+ if an :class:`_engine.Engine` is passed, an
+ :class:`_reflection.Inspector` object is returned.
+
+ :param subject: the subject to be inspected.
+ :param raiseerr: When ``True``, if the given subject
+ does not
+ correspond to a known SQLAlchemy inspected type,
+ :class:`sqlalchemy.exc.NoInspectionAvailable`
+ is raised. If ``False``, ``None`` is returned.
+
+ """
+ type_ = type(subject)
+ for cls in type_.__mro__:
+ if cls in _registrars:
+ reg = _registrars[cls]
+ if reg is True:
+ return subject
+ ret = reg(subject)
+ if ret is not None:
+ break
+ else:
+ reg = ret = None
+
+ if raiseerr and (reg is None or ret is None):
+ raise exc.NoInspectionAvailable(
+ "No inspection system is "
+ "available for object of type %s" % type_
+ )
+ return ret
+
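+# A few illustrative calls as a sketch (``some_engine``, ``User`` and
+# ``user_obj`` are assumed to be an Engine, a mapped class and a mapped
+# instance respectively):
+#
+#     from sqlalchemy import inspect
+#
+#     insp = inspect(some_engine)   # an Inspector, for schema reflection
+#     mapper = inspect(User)        # the Mapper for the mapped class
+#     state = inspect(user_obj)     # the InstanceState for the instance
+#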
+
+def _inspects(*types):
+ def decorate(fn_or_cls):
+ for type_ in types:
+ if type_ in _registrars:
+ raise AssertionError(
+ "Type %s is already " "registered" % type_
+ )
+ _registrars[type_] = fn_or_cls
+ return fn_or_cls
+
+ return decorate
+
+
+def _self_inspects(cls):
+ _inspects(cls)(True)
+ return cls
diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py
new file mode 100644
index 0000000..cc662ec
--- /dev/null
+++ b/lib/sqlalchemy/log.py
@@ -0,0 +1,241 @@
+# sqlalchemy/log.py
+# Copyright (C) 2006-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Logging control and utilities.
+
+Control of logging for SA can be performed from the regular python logging
+module. The regular dotted module namespace is used, starting at
+'sqlalchemy'. For class-level logging, the class name is appended.
+
+The "echo" keyword parameter, available on SQLA :class:`_engine.Engine`
+and :class:`_pool.Pool` objects, corresponds to a logger specific to that
+instance only.
+
+"""
+
+import logging
+import sys
+
+from .util import py311
+from .util import py38
+
+if py38:
+ STACKLEVEL = True
+ # needed as of py3.11.0b1
+ # #8019
+ STACKLEVEL_OFFSET = 2 if py311 else 1
+else:
+ STACKLEVEL = False
+ STACKLEVEL_OFFSET = 0
+
+# set initial level to WARN. This is so that
+# log statements don't occur in the absence of explicit
+# logging being enabled for 'sqlalchemy'.
+rootlogger = logging.getLogger("sqlalchemy")
+if rootlogger.level == logging.NOTSET:
+ rootlogger.setLevel(logging.WARN)
+
+
+def _add_default_handler(logger):
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(
+ logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
+ )
+ logger.addHandler(handler)
+
+
+_logged_classes = set()
+
+
+def _qual_logger_name_for_cls(cls):
+ return (
+ getattr(cls, "_sqla_logger_namespace", None)
+ or cls.__module__ + "." + cls.__name__
+ )
+
+
+def class_logger(cls):
+ logger = logging.getLogger(_qual_logger_name_for_cls(cls))
+ cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
+ cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
+ cls.logger = logger
+ _logged_classes.add(cls)
+ return cls
+
+
+class Identified(object):
+ logging_name = None
+
+ def _should_log_debug(self):
+ return self.logger.isEnabledFor(logging.DEBUG)
+
+ def _should_log_info(self):
+ return self.logger.isEnabledFor(logging.INFO)
+
+
+class InstanceLogger(object):
+ """A logger adapter (wrapper) for :class:`.Identified` subclasses.
+
+ This allows multiple instances (e.g. Engine or Pool instances)
+ to share a logger, but have its verbosity controlled on a
+ per-instance basis.
+
+ The basic functionality is to return a logging level
+ which is based on an instance's echo setting.
+
+ Default implementation is:
+
+ 'debug' -> logging.DEBUG
+ True -> logging.INFO
+ False -> Effective level of underlying logger
+ (logging.WARNING by default)
+ None -> same as False
+ """
+
+ # Map echo settings to logger levels
+ _echo_map = {
+ None: logging.NOTSET,
+ False: logging.NOTSET,
+ True: logging.INFO,
+ "debug": logging.DEBUG,
+ }
+
+ def __init__(self, echo, name):
+ self.echo = echo
+ self.logger = logging.getLogger(name)
+
+ # if echo flag is enabled and no handlers,
+ # add a handler to the list
+ if self._echo_map[echo] <= logging.INFO and not self.logger.handlers:
+ _add_default_handler(self.logger)
+
+ #
+ # Boilerplate convenience methods
+ #
+ def debug(self, msg, *args, **kwargs):
+ """Delegate a debug call to the underlying logger."""
+
+ self.log(logging.DEBUG, msg, *args, **kwargs)
+
+ def info(self, msg, *args, **kwargs):
+ """Delegate an info call to the underlying logger."""
+
+ self.log(logging.INFO, msg, *args, **kwargs)
+
+ def warning(self, msg, *args, **kwargs):
+ """Delegate a warning call to the underlying logger."""
+
+ self.log(logging.WARNING, msg, *args, **kwargs)
+
+ warn = warning
+
+ def error(self, msg, *args, **kwargs):
+ """
+ Delegate an error call to the underlying logger.
+ """
+ self.log(logging.ERROR, msg, *args, **kwargs)
+
+ def exception(self, msg, *args, **kwargs):
+ """Delegate an exception call to the underlying logger."""
+
+ kwargs["exc_info"] = 1
+ self.log(logging.ERROR, msg, *args, **kwargs)
+
+ def critical(self, msg, *args, **kwargs):
+ """Delegate a critical call to the underlying logger."""
+
+ self.log(logging.CRITICAL, msg, *args, **kwargs)
+
+ def log(self, level, msg, *args, **kwargs):
+ """Delegate a log call to the underlying logger.
+
+ The level here is determined by the echo
+ flag as well as that of the underlying logger, and
+ logger._log() is called directly.
+
+ """
+
+ # inline the logic from isEnabledFor(),
+ # getEffectiveLevel(), to avoid overhead.
+
+ if self.logger.manager.disable >= level:
+ return
+
+ selected_level = self._echo_map[self.echo]
+ if selected_level == logging.NOTSET:
+ selected_level = self.logger.getEffectiveLevel()
+
+ if level >= selected_level:
+ if STACKLEVEL:
+ kwargs["stacklevel"] = (
+ kwargs.get("stacklevel", 1) + STACKLEVEL_OFFSET
+ )
+
+ self.logger._log(level, msg, args, **kwargs)
+
+ def isEnabledFor(self, level):
+ """Is this logger enabled for level 'level'?"""
+
+ if self.logger.manager.disable >= level:
+ return False
+ return level >= self.getEffectiveLevel()
+
+ def getEffectiveLevel(self):
+ """What's the effective level for this logger?"""
+
+ level = self._echo_map[self.echo]
+ if level == logging.NOTSET:
+ level = self.logger.getEffectiveLevel()
+ return level
+
+
+def instance_logger(instance, echoflag=None):
+ """create a logger for an instance that implements :class:`.Identified`."""
+
+ if instance.logging_name:
+ name = "%s.%s" % (
+ _qual_logger_name_for_cls(instance.__class__),
+ instance.logging_name,
+ )
+ else:
+ name = _qual_logger_name_for_cls(instance.__class__)
+
+ instance._echo = echoflag
+
+ if echoflag in (False, None):
+ # if no echo setting or False, return a Logger directly,
+ # avoiding overhead of filtering
+ logger = logging.getLogger(name)
+ else:
+ # if a specified echo flag, return an EchoLogger,
+ # which checks the flag, overrides normal log
+ # levels by calling logger._log()
+ logger = InstanceLogger(echoflag, name)
+
+ instance.logger = logger
+
+
+class echo_property(object):
+ __doc__ = """\
+ When ``True``, enable log output for this element.
+
+ This has the effect of setting the Python logging level for the namespace
+ of this element's class and object reference. A value of boolean ``True``
+ indicates that the loglevel ``logging.INFO`` will be set for the logger,
+ whereas the string value ``debug`` will set the loglevel to
+ ``logging.DEBUG``.
+ """
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self
+ else:
+ return instance._echo
+
+ def __set__(self, instance, value):
+ instance_logger(instance, echoflag=value)
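+
+
+# Usage sketch (illustrative only; not part of the original module).  The
+# ``echo`` flag described above maps onto standard-library logging levels via
+# ``InstanceLogger._echo_map``, so the two approaches below are roughly
+# equivalent.  ``create_engine`` and the URL are assumptions for the example.
+#
+#     import logging
+#     from sqlalchemy import create_engine
+#
+#     # echo=True logs SQL at INFO and adds a stdout handler if none exists
+#     engine = create_engine("sqlite://", echo=True)
+#
+#     # equivalent: configure the 'sqlalchemy.engine' logger directly
+#     logging.basicConfig()
+#     logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)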
diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py
new file mode 100644
index 0000000..6e0de05
--- /dev/null
+++ b/lib/sqlalchemy/orm/__init__.py
@@ -0,0 +1,344 @@
+# orm/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+Functional constructs for ORM configuration.
+
+See the SQLAlchemy object relational tutorial and mapper configuration
+documentation for an overview of how this module is used.
+
+"""
+
+from . import exc
+from . import mapper as mapperlib
+from . import strategy_options
+from .attributes import AttributeEvent
+from .attributes import InstrumentedAttribute
+from .attributes import Mapped
+from .attributes import QueryableAttribute
+from .context import QueryContext
+from .decl_api import as_declarative
+from .decl_api import declarative_base
+from .decl_api import declarative_mixin
+from .decl_api import DeclarativeMeta
+from .decl_api import declared_attr
+from .decl_api import has_inherited_table
+from .decl_api import registry
+from .decl_api import synonym_for
+from .descriptor_props import CompositeProperty
+from .descriptor_props import SynonymProperty
+from .identity import IdentityMap
+from .instrumentation import ClassManager
+from .interfaces import EXT_CONTINUE
+from .interfaces import EXT_SKIP
+from .interfaces import EXT_STOP
+from .interfaces import InspectionAttr
+from .interfaces import InspectionAttrInfo
+from .interfaces import MANYTOMANY
+from .interfaces import MANYTOONE
+from .interfaces import MapperProperty
+from .interfaces import NOT_EXTENSION
+from .interfaces import ONETOMANY
+from .interfaces import PropComparator
+from .interfaces import UserDefinedOption
+from .loading import merge_frozen_result
+from .loading import merge_result
+from .mapper import class_mapper
+from .mapper import configure_mappers
+from .mapper import Mapper
+from .mapper import reconstructor
+from .mapper import validates
+from .properties import ColumnProperty
+from .query import AliasOption
+from .query import FromStatement
+from .query import Query
+from .relationships import foreign
+from .relationships import RelationshipProperty
+from .relationships import remote
+from .scoping import scoped_session
+from .session import close_all_sessions
+from .session import make_transient
+from .session import make_transient_to_detached
+from .session import object_session
+from .session import ORMExecuteState
+from .session import Session
+from .session import sessionmaker
+from .session import SessionTransaction
+from .state import AttributeState
+from .state import InstanceState
+from .strategy_options import Load
+from .unitofwork import UOWTransaction
+from .util import aliased
+from .util import Bundle
+from .util import CascadeOptions
+from .util import join
+from .util import LoaderCriteriaOption
+from .util import object_mapper
+from .util import outerjoin
+from .util import polymorphic_union
+from .util import was_deleted
+from .util import with_parent
+from .util import with_polymorphic
+from .. import sql as _sql
+from .. import util as _sa_util
+from ..util.langhelpers import public_factory
+
+
+def create_session(bind=None, **kwargs):
+ r"""Create a new :class:`.Session`
+ with no automation enabled by default.
+
+ This function is used primarily for testing. The usual
+ route to :class:`.Session` creation is via its constructor
+ or the :func:`.sessionmaker` function.
+
+ :param bind: optional, a single Connectable to use for all
+ database access in the created
+ :class:`~sqlalchemy.orm.session.Session`.
+
+ :param \*\*kwargs: optional, passed through to the
+ :class:`.Session` constructor.
+
+ :returns: an :class:`~sqlalchemy.orm.session.Session` instance
+
+    The defaults of create_session() are the opposite of those of
+    :func:`sessionmaker`: ``autoflush`` and ``expire_on_commit`` are
+    False, and ``autocommit`` is True. In this sense the session acts
+    more like the "classic" SQLAlchemy 0.3 session.
+
+ .. deprecated:: 1.4 The "autocommit" parameter will be removed in
+ SQLAlchemy 2.0. :func:`_orm.create_session` will return a
+    :class:`_orm.Session` that does not include "autocommit" behavior
+ in release 2.0.
+
+ Usage::
+
+ >>> from sqlalchemy.orm import create_session
+ >>> session = create_session()
+
+ It is recommended to use :func:`sessionmaker` instead of
+ create_session().
+
+ """
+
+ if kwargs.get("future", False):
+ kwargs.setdefault("autocommit", False)
+ else:
+ kwargs.setdefault("autocommit", True)
+
+ kwargs.setdefault("autoflush", False)
+ kwargs.setdefault("expire_on_commit", False)
+ return Session(bind=bind, **kwargs)
+
+
+with_loader_criteria = public_factory(LoaderCriteriaOption, ".orm")
+
+relationship = public_factory(RelationshipProperty, ".orm.relationship")
+
+
+@_sa_util.deprecated_20("relation", "Please use :func:`.relationship`.")
+def relation(*arg, **kw):
+ """A synonym for :func:`relationship`."""
+
+ return relationship(*arg, **kw)
+
+
+def dynamic_loader(argument, **kw):
+ """Construct a dynamically-loading mapper property.
+
+ This is essentially the same as
+ using the ``lazy='dynamic'`` argument with :func:`relationship`::
+
+ dynamic_loader(SomeClass)
+
+ # is the same as
+
+ relationship(SomeClass, lazy="dynamic")
+
+ See the section :ref:`dynamic_relationship` for more details
+ on dynamic loading.
+
+ """
+ kw["lazy"] = "dynamic"
+ return relationship(argument, **kw)
+
+
+column_property = public_factory(ColumnProperty, ".orm.column_property")
+composite = public_factory(CompositeProperty, ".orm.composite")
+
+
+def backref(name, **kwargs):
+ """When using the :paramref:`_orm.relationship.backref` parameter,
+ provides specific parameters to be used when the new
+ :func:`_orm.relationship` is generated.
+
+ E.g.::
+
+ 'items':relationship(
+ SomeItem, backref=backref('parent', lazy='subquery'))
+
+ The :paramref:`_orm.relationship.backref` parameter is generally
+ considered to be legacy; for modern applications, using
+ explicit :func:`_orm.relationship` constructs linked together using
+ the :paramref:`_orm.relationship.back_populates` parameter should be
+ preferred.
+
+ .. seealso::
+
+ :ref:`relationships_backref` - background on backrefs
+
+ """
+
+ return (name, kwargs)
+
+
+def deferred(*columns, **kw):
+ r"""Indicate a column-based mapped attribute that by default will
+ not load unless accessed.
+
+ :param \*columns: columns to be mapped. This is typically a single
+ :class:`_schema.Column` object,
+ however a collection is supported in order
+ to support multiple columns mapped under the same attribute.
+
+ :param raiseload: boolean, if True, indicates an exception should be raised
+ if the load operation is to take place.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`deferred_raiseload`
+
+ :param \**kw: additional keyword arguments passed to
+ :class:`.ColumnProperty`.
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ """
+ return ColumnProperty(deferred=True, *columns, **kw)
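+
+# Illustrative sketch (not part of the original source): ``deferred`` is
+# typically used inside a mapping so that a large column is only loaded on
+# first access.  ``Base`` and the column definitions are assumptions for the
+# example.
+#
+#     from sqlalchemy import Column, Integer, Text
+#     from sqlalchemy.orm import deferred
+#
+#     class Document(Base):
+#         __tablename__ = "document"
+#         id = Column(Integer, primary_key=True)
+#         body = deferred(Column(Text))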
+
+
+def query_expression(default_expr=_sql.null()):
+ """Indicate an attribute that populates from a query-time SQL expression.
+
+ :param default_expr: Optional SQL expression object that will be used in
+ all cases if not assigned later with :func:`_orm.with_expression`.
+ E.g.::
+
+ from sqlalchemy.sql import literal
+
+ class C(Base):
+ #...
+ my_expr = query_expression(literal(1))
+
+ .. versionadded:: 1.3.18
+
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`mapper_querytime_expression`
+
+ """
+ prop = ColumnProperty(default_expr)
+ prop.strategy_key = (("query_expression", True),)
+ return prop
+
+
+mapper = public_factory(Mapper, ".orm.mapper")
+
+synonym = public_factory(SynonymProperty, ".orm.synonym")
+
+
+def clear_mappers():
+ """Remove all mappers from all classes.
+
+ .. versionchanged:: 1.4 This function now locates all
+ :class:`_orm.registry` objects and calls upon the
+ :meth:`_orm.registry.dispose` method of each.
+
+ This function removes all instrumentation from classes and disposes
+ of their associated mappers. Once called, the classes are unmapped
+ and can be later re-mapped with new mappers.
+
+ :func:`.clear_mappers` is *not* for normal use, as there is literally no
+ valid usage for it outside of very specific testing scenarios. Normally,
+ mappers are permanent structural components of user-defined classes, and
+ are never discarded independently of their class. If a mapped class
+ itself is garbage collected, its mapper is automatically disposed of as
+ well. As such, :func:`.clear_mappers` is only for usage in test suites
+ that re-use the same classes with different mappings, which is itself an
+ extremely rare use case - the only such use case is in fact SQLAlchemy's
+ own test suite, and possibly the test suites of other ORM extension
+ libraries which intend to test various combinations of mapper construction
+ upon a fixed set of classes.
+
+ """
+
+ mapperlib._dispose_registries(mapperlib._all_registries(), False)
+
+
+joinedload = strategy_options.joinedload._unbound_fn
+contains_eager = strategy_options.contains_eager._unbound_fn
+defer = strategy_options.defer._unbound_fn
+undefer = strategy_options.undefer._unbound_fn
+undefer_group = strategy_options.undefer_group._unbound_fn
+with_expression = strategy_options.with_expression._unbound_fn
+load_only = strategy_options.load_only._unbound_fn
+lazyload = strategy_options.lazyload._unbound_fn
+subqueryload = strategy_options.subqueryload._unbound_fn
+selectinload = strategy_options.selectinload._unbound_fn
+immediateload = strategy_options.immediateload._unbound_fn
+noload = strategy_options.noload._unbound_fn
+raiseload = strategy_options.raiseload._unbound_fn
+defaultload = strategy_options.defaultload._unbound_fn
+selectin_polymorphic = strategy_options.selectin_polymorphic._unbound_fn
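+
+# Usage sketch (illustrative only; not part of the original module).  The
+# loader options above are passed to ``Query.options()`` or
+# ``Select.options()``; ``User`` and its attributes are hypothetical mapped
+# names.
+#
+#     from sqlalchemy.orm import joinedload, defer
+#
+#     q = session.query(User).options(
+#         joinedload(User.addresses),
+#         defer(User.created_at),
+#     )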
+
+
+@_sa_util.deprecated_20("eagerload", "Please use :func:`_orm.joinedload`.")
+def eagerload(*args, **kwargs):
+ """A synonym for :func:`joinedload()`."""
+ return joinedload(*args, **kwargs)
+
+
+contains_alias = public_factory(AliasOption, ".orm.contains_alias")
+
+if True:
+ from .events import AttributeEvents
+ from .events import MapperEvents
+ from .events import InstanceEvents
+ from .events import InstrumentationEvents
+ from .events import QueryEvents
+ from .events import SessionEvents
+
+
+def __go(lcls):
+ global __all__
+ global AppenderQuery
+ from .. import util as sa_util
+ from . import dynamic
+ from . import events
+ from . import loading
+ import inspect as _inspect
+
+ from .dynamic import AppenderQuery
+
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
+
+ _sa_util.preloaded.import_prefix("sqlalchemy.orm")
+ _sa_util.preloaded.import_prefix("sqlalchemy.ext")
+
+
+__go(locals())
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
new file mode 100644
index 0000000..efa20fb
--- /dev/null
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -0,0 +1,2331 @@
+# orm/attributes.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines instrumentation for class attributes and their interaction
+with instances.
+
+This module is usually not directly visible to user applications, but
+defines a large part of the ORM's interactivity.
+
+
+"""
+
+import operator
+
+from . import collections
+from . import exc as orm_exc
+from . import interfaces
+from .base import ATTR_EMPTY
+from .base import ATTR_WAS_SET
+from .base import CALLABLES_OK
+from .base import DEFERRED_HISTORY_LOAD
+from .base import INIT_OK
+from .base import instance_dict
+from .base import instance_state
+from .base import instance_str
+from .base import LOAD_AGAINST_COMMITTED
+from .base import manager_of_class
+from .base import NEVER_SET # noqa
+from .base import NO_AUTOFLUSH
+from .base import NO_CHANGE # noqa
+from .base import NO_RAISE
+from .base import NO_VALUE
+from .base import NON_PERSISTENT_OK # noqa
+from .base import PASSIVE_CLASS_MISMATCH # noqa
+from .base import PASSIVE_NO_FETCH
+from .base import PASSIVE_NO_FETCH_RELATED # noqa
+from .base import PASSIVE_NO_INITIALIZE
+from .base import PASSIVE_NO_RESULT
+from .base import PASSIVE_OFF
+from .base import PASSIVE_ONLY_PERSISTENT
+from .base import PASSIVE_RETURN_NO_VALUE
+from .base import RELATED_OBJECT_OK # noqa
+from .base import SQL_OK # noqa
+from .base import state_str
+from .. import event
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql import base as sql_base
+from ..sql import roles
+from ..sql import traversals
+from ..sql import visitors
+
+
+class NoKey(str):
+ pass
+
+
+NO_KEY = NoKey("no name")
+
+
+@inspection._self_inspects
+class QueryableAttribute(
+ interfaces._MappedAttribute,
+ interfaces.InspectionAttr,
+ interfaces.PropComparator,
+ traversals.HasCopyInternals,
+ roles.JoinTargetRole,
+ roles.OnClauseRole,
+ sql_base.Immutable,
+ sql_base.MemoizedHasCacheKey,
+):
+ """Base class for :term:`descriptor` objects that intercept
+ attribute events on behalf of a :class:`.MapperProperty`
+ object. The actual :class:`.MapperProperty` is accessible
+ via the :attr:`.QueryableAttribute.property`
+ attribute.
+
+
+ .. seealso::
+
+ :class:`.InstrumentedAttribute`
+
+ :class:`.MapperProperty`
+
+ :attr:`_orm.Mapper.all_orm_descriptors`
+
+ :attr:`_orm.Mapper.attrs`
+ """
+
+ is_attribute = True
+
+ # PropComparator has a __visit_name__ to participate within
+ # traversals. Disambiguate the attribute vs. a comparator.
+ __visit_name__ = "orm_instrumented_attribute"
+
+ def __init__(
+ self,
+ class_,
+ key,
+ parententity,
+ impl=None,
+ comparator=None,
+ of_type=None,
+ extra_criteria=(),
+ ):
+ self.class_ = class_
+ self.key = key
+ self._parententity = parententity
+ self.impl = impl
+ self.comparator = comparator
+ self._of_type = of_type
+ self._extra_criteria = extra_criteria
+
+ manager = manager_of_class(class_)
+ # manager is None in the case of AliasedClass
+ if manager:
+ # propagate existing event listeners from
+ # immediate superclass
+ for base in manager._bases:
+ if key in base:
+ self.dispatch._update(base[key].dispatch)
+ if base[key].dispatch._active_history:
+ self.dispatch._active_history = True
+
+ _cache_key_traversal = [
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+ ]
+
+ def __reduce__(self):
+ # this method is only used in terms of the
+ # sqlalchemy.ext.serializer extension
+ return (
+ _queryable_attribute_unreduce,
+ (
+ self.key,
+ self._parententity.mapper.class_,
+ self._parententity,
+ self._parententity.entity,
+ ),
+ )
+
+ @util.memoized_property
+ def _supports_population(self):
+ return self.impl.supports_population
+
+ @property
+ def _impl_uses_objects(self):
+ return self.impl.uses_objects
+
+ def get_history(self, instance, passive=PASSIVE_OFF):
+ return self.impl.get_history(
+ instance_state(instance), instance_dict(instance), passive
+ )
+
+ @util.memoized_property
+ def info(self):
+ """Return the 'info' dictionary for the underlying SQL element.
+
+ The behavior here is as follows:
+
+ * If the attribute is a column-mapped property, i.e.
+ :class:`.ColumnProperty`, which is mapped directly
+ to a schema-level :class:`_schema.Column` object, this attribute
+ will return the :attr:`.SchemaItem.info` dictionary associated
+ with the core-level :class:`_schema.Column` object.
+
+ * If the attribute is a :class:`.ColumnProperty` but is mapped to
+ any other kind of SQL expression other than a
+ :class:`_schema.Column`,
+ the attribute will refer to the :attr:`.MapperProperty.info`
+ dictionary associated directly with the :class:`.ColumnProperty`,
+ assuming the SQL expression itself does not have its own ``.info``
+ attribute (which should be the case, unless a user-defined SQL
+ construct has defined one).
+
+ * If the attribute refers to any other kind of
+ :class:`.MapperProperty`, including :class:`.RelationshipProperty`,
+ the attribute will refer to the :attr:`.MapperProperty.info`
+ dictionary associated with that :class:`.MapperProperty`.
+
+ * To access the :attr:`.MapperProperty.info` dictionary of the
+ :class:`.MapperProperty` unconditionally, including for a
+ :class:`.ColumnProperty` that's associated directly with a
+ :class:`_schema.Column`, the attribute can be referred to using
+ :attr:`.QueryableAttribute.property` attribute, as
+ ``MyClass.someattribute.property.info``.
+
+ .. seealso::
+
+ :attr:`.SchemaItem.info`
+
+ :attr:`.MapperProperty.info`
+
+ """
+ return self.comparator.info
+
+ @util.memoized_property
+ def parent(self):
+ """Return an inspection instance representing the parent.
+
+ This will be either an instance of :class:`_orm.Mapper`
+ or :class:`.AliasedInsp`, depending upon the nature
+ of the parent entity which this attribute is associated
+ with.
+
+ """
+ return inspection.inspect(self._parententity)
+
+ @util.memoized_property
+ def expression(self):
+ """The SQL expression object represented by this
+ :class:`.QueryableAttribute`.
+
+ This will typically be an instance of a :class:`_sql.ColumnElement`
+ subclass representing a column expression.
+
+ """
+ if self.key is NO_KEY:
+ annotations = {"entity_namespace": self._entity_namespace}
+ else:
+ annotations = {
+ "proxy_key": self.key,
+ "proxy_owner": self._parententity,
+ "entity_namespace": self._entity_namespace,
+ }
+
+ ce = self.comparator.__clause_element__()
+ try:
+ anno = ce._annotate
+ except AttributeError as ae:
+ util.raise_(
+ exc.InvalidRequestError(
+ 'When interpreting attribute "%s" as a SQL expression, '
+ "expected __clause_element__() to return "
+ "a ClauseElement object, got: %r" % (self, ce)
+ ),
+ from_=ae,
+ )
+ else:
+ return anno(annotations)
+
+ @property
+ def _entity_namespace(self):
+ return self._parententity
+
+ @property
+ def _annotations(self):
+ return self.__clause_element__()._annotations
+
+ def __clause_element__(self):
+ return self.expression
+
+ @property
+ def _from_objects(self):
+ return self.expression._from_objects
+
+ def _bulk_update_tuples(self, value):
+ """Return setter tuples for a bulk UPDATE."""
+
+ return self.comparator._bulk_update_tuples(value)
+
+ def adapt_to_entity(self, adapt_to_entity):
+ assert not self._of_type
+ return self.__class__(
+ adapt_to_entity.entity,
+ self.key,
+ impl=self.impl,
+ comparator=self.comparator.adapt_to_entity(adapt_to_entity),
+ parententity=adapt_to_entity,
+ )
+
+ def of_type(self, entity):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.of_type(entity),
+ of_type=inspection.inspect(entity),
+ extra_criteria=self._extra_criteria,
+ )
+
+ def and_(self, *other):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator.and_(*other),
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
+ )
+
+ def _clone(self, **kw):
+ return QueryableAttribute(
+ self.class_,
+ self.key,
+ self._parententity,
+ impl=self.impl,
+ comparator=self.comparator,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria,
+ )
+
+ def label(self, name):
+ return self.__clause_element__().label(name)
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.comparator, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.comparator, **kwargs)
+
+ def hasparent(self, state, optimistic=False):
+ return self.impl.hasparent(state, optimistic=optimistic) is not False
+
+ def __getattr__(self, key):
+ try:
+ return getattr(self.comparator, key)
+ except AttributeError as err:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor %r object associated with %s "
+ "has an attribute %r"
+ % (
+ type(self).__name__,
+ type(self.comparator).__name__,
+ self,
+ key,
+ )
+ ),
+ replace_context=err,
+ )
+
+ def __str__(self):
+ return "%s.%s" % (self.class_.__name__, self.key)
+
+ @util.memoized_property
+ def property(self):
+ """Return the :class:`.MapperProperty` associated with this
+ :class:`.QueryableAttribute`.
+
+
+ Return values here will commonly be instances of
+ :class:`.ColumnProperty` or :class:`.RelationshipProperty`.
+
+
+ """
+ return self.comparator.property
+
+
+def _queryable_attribute_unreduce(key, mapped_class, parententity, entity):
+ # this method is only used in terms of the
+ # sqlalchemy.ext.serializer extension
+ if parententity.is_aliased_class:
+ return entity._get_from_serialized(key, mapped_class, parententity)
+ else:
+ return getattr(entity, key)
+
+
+if util.py3k:
+ from typing import TypeVar, Generic
+
+ _T = TypeVar("_T")
+ _Generic_T = Generic[_T]
+else:
+ _Generic_T = type("_Generic_T", (), {})
+
+
+class Mapped(QueryableAttribute, _Generic_T):
+ """Represent an ORM mapped :term:`descriptor` attribute for typing
+ purposes.
+
+ This class represents the complete descriptor interface for any class
+ attribute that will have been :term:`instrumented` by the ORM
+ :class:`_orm.Mapper` class. When used with typing stubs, it is the final
+ type that would be used by a type checker such as mypy to provide the full
+ behavioral contract for the attribute.
+
+ .. tip::
+
+ The :class:`_orm.Mapped` class represents attributes that are handled
+ directly by the :class:`_orm.Mapper` class. It does not include other
+ Python descriptor classes that are provided as extensions, including
+ :ref:`hybrids_toplevel` and the :ref:`associationproxy_toplevel`.
+ While these systems still make use of ORM-specific superclasses
+ and structures, they are not :term:`instrumented` by the
+ :class:`_orm.Mapper` and instead provide their own functionality
+ when they are accessed on a class.
+
+ When using the :ref:`SQLAlchemy Mypy plugin <mypy_toplevel>`, the
+ :class:`_orm.Mapped` construct is used in typing annotations to indicate to
+ the plugin those attributes that are expected to be mapped; the plugin also
+ applies :class:`_orm.Mapped` as an annotation automatically when it scans
+ through declarative mappings in :ref:`orm_declarative_table` style. For
+ more indirect mapping styles such as
+ :ref:`imperative table <orm_imperative_table_configuration>` it is
+ typically applied explicitly to class level attributes that expect
+ to be mapped based on a given :class:`_schema.Table` configuration.
+
+ :class:`_orm.Mapped` is defined in the
+ `sqlalchemy2-stubs <https://pypi.org/project/sqlalchemy2-stubs>`_ project
+ as a :pep:`484` generic class which may subscribe to any arbitrary Python
+ type, which represents the Python type handled by the attribute::
+
+ class MyMappedClass(Base):
+            __table__ = Table(
+ "some_table", Base.metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("created_at", DateTime)
+ )
+
+ id : Mapped[int]
+ data: Mapped[str]
+ created_at: Mapped[datetime]
+
+ For complete background on how to use :class:`_orm.Mapped` with
+ pep-484 tools like Mypy, see the link below for background on SQLAlchemy's
+ Mypy plugin.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`mypy_toplevel` - complete background on Mypy integration
+
+ """
+
+ def __get__(self, instance, owner):
+ raise NotImplementedError()
+
+ def __set__(self, instance, value):
+ raise NotImplementedError()
+
+ def __delete__(self, instance):
+ raise NotImplementedError()
+
+
+class InstrumentedAttribute(Mapped):
+ """Class bound instrumented attribute which adds basic
+ :term:`descriptor` methods.
+
+ See :class:`.QueryableAttribute` for a description of most features.
+
+
+ """
+
+ inherit_cache = True
+
+ def __set__(self, instance, value):
+ self.impl.set(
+ instance_state(instance), instance_dict(instance), value, None
+ )
+
+ def __delete__(self, instance):
+ self.impl.delete(instance_state(instance), instance_dict(instance))
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self
+
+ dict_ = instance_dict(instance)
+ if self._supports_population and self.key in dict_:
+ return dict_[self.key]
+ else:
+ try:
+ state = instance_state(instance)
+ except AttributeError as err:
+ util.raise_(
+ orm_exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ return self.impl.get(state, dict_)
+
+
+HasEntityNamespace = util.namedtuple(
+ "HasEntityNamespace", ["entity_namespace"]
+)
+HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
+
+
+def create_proxied_attribute(descriptor):
+ """Create an QueryableAttribute / user descriptor hybrid.
+
+ Returns a new QueryableAttribute type that delegates descriptor
+ behavior and getattr() to the given descriptor.
+ """
+
+ # TODO: can move this to descriptor_props if the need for this
+ # function is removed from ext/hybrid.py
+
+ class Proxy(QueryableAttribute):
+ """Presents the :class:`.QueryableAttribute` interface as a
+ proxy on top of a Python descriptor / :class:`.PropComparator`
+ combination.
+
+ """
+
+ _extra_criteria = ()
+
+ def __init__(
+ self,
+ class_,
+ key,
+ descriptor,
+ comparator,
+ adapt_to_entity=None,
+ doc=None,
+ original_property=None,
+ ):
+ self.class_ = class_
+ self.key = key
+ self.descriptor = descriptor
+ self.original_property = original_property
+ self._comparator = comparator
+ self._adapt_to_entity = adapt_to_entity
+ self.__doc__ = doc
+
+ _is_internal_proxy = True
+
+ _cache_key_traversal = [
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
+ ]
+
+ @property
+ def _impl_uses_objects(self):
+ return (
+ self.original_property is not None
+ and getattr(self.class_, self.key).impl.uses_objects
+ )
+
+ @property
+ def _parententity(self):
+ return inspection.inspect(self.class_, raiseerr=False)
+
+ @property
+ def _entity_namespace(self):
+ if hasattr(self._comparator, "_parententity"):
+ return self._comparator._parententity
+ else:
+ # used by hybrid attributes which try to remain
+ # agnostic of any ORM concepts like mappers
+ return HasEntityNamespace(self.class_)
+
+ @property
+ def property(self):
+ return self.comparator.property
+
+ @util.memoized_property
+ def comparator(self):
+ if callable(self._comparator):
+ self._comparator = self._comparator()
+ if self._adapt_to_entity:
+ self._comparator = self._comparator.adapt_to_entity(
+ self._adapt_to_entity
+ )
+ return self._comparator
+
+ def adapt_to_entity(self, adapt_to_entity):
+ return self.__class__(
+ adapt_to_entity.entity,
+ self.key,
+ self.descriptor,
+ self._comparator,
+ adapt_to_entity,
+ )
+
+ def _clone(self, **kw):
+ return self.__class__(
+ self.class_,
+ self.key,
+ self.descriptor,
+ self._comparator,
+ adapt_to_entity=self._adapt_to_entity,
+ original_property=self.original_property,
+ )
+
+ def __get__(self, instance, owner):
+ retval = self.descriptor.__get__(instance, owner)
+ # detect if this is a plain Python @property, which just returns
+ # itself for class level access. If so, then return us.
+ # Otherwise, return the object returned by the descriptor.
+ if retval is self.descriptor and instance is None:
+ return self
+ else:
+ return retval
+
+ def __str__(self):
+ return "%s.%s" % (self.class_.__name__, self.key)
+
+ def __getattr__(self, attribute):
+ """Delegate __getattr__ to the original descriptor and/or
+ comparator."""
+ try:
+ return getattr(descriptor, attribute)
+ except AttributeError as err:
+ if attribute == "comparator":
+ util.raise_(
+ AttributeError("comparator"), replace_context=err
+ )
+ try:
+ # comparator itself might be unreachable
+ comparator = self.comparator
+ except AttributeError as err2:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor unconfigured comparator "
+ "object associated with %s has an attribute %r"
+ % (type(descriptor).__name__, self, attribute)
+ ),
+ replace_context=err2,
+ )
+ else:
+ try:
+ return getattr(comparator, attribute)
+ except AttributeError as err3:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor %r object "
+ "associated with %s has an attribute %r"
+ % (
+ type(descriptor).__name__,
+ type(comparator).__name__,
+ self,
+ attribute,
+ )
+ ),
+ replace_context=err3,
+ )
+
+ Proxy.__name__ = type(descriptor).__name__ + "Proxy"
+
+ util.monkeypatch_proxied_specials(
+ Proxy, type(descriptor), name="descriptor", from_instance=descriptor
+ )
+ return Proxy
+
+
+OP_REMOVE = util.symbol("REMOVE")
+OP_APPEND = util.symbol("APPEND")
+OP_REPLACE = util.symbol("REPLACE")
+OP_BULK_REPLACE = util.symbol("BULK_REPLACE")
+OP_MODIFIED = util.symbol("MODIFIED")
+
+
+class AttributeEvent(object):
+ """A token propagated throughout the course of a chain of attribute
+ events.
+
+ Serves as an indicator of the source of the event and also provides
+ a means of controlling propagation across a chain of attribute
+ operations.
+
+ The :class:`.Event` object is sent as the ``initiator`` argument
+ when dealing with events such as :meth:`.AttributeEvents.append`,
+ :meth:`.AttributeEvents.set`,
+ and :meth:`.AttributeEvents.remove`.
+
+ The :class:`.Event` object is currently interpreted by the backref
+ event handlers, and is used to control the propagation of operations
+ across two mutually-dependent attributes.
+
+ .. versionadded:: 0.9.0
+
+ :attribute impl: The :class:`.AttributeImpl` which is the current event
+ initiator.
+
+ :attribute op: The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE`,
+ :attr:`.OP_REPLACE`, or :attr:`.OP_BULK_REPLACE`, indicating the
+ source operation.
+
+ """
+
+ __slots__ = "impl", "op", "parent_token"
+
+ def __init__(self, attribute_impl, op):
+ self.impl = attribute_impl
+ self.op = op
+ self.parent_token = self.impl.parent_token
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, AttributeEvent)
+ and other.impl is self.impl
+ and other.op == self.op
+ )
+
+ @property
+ def key(self):
+ return self.impl.key
+
+ def hasparent(self, state):
+ return self.impl.hasparent(state)
+
+
+Event = AttributeEvent
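+
+# Illustrative sketch (assumption, not from the original source): attribute
+# event listeners receive an ``AttributeEvent`` token as their ``initiator``
+# argument; it identifies the attribute and operation that started the chain.
+# ``SomeClass.addresses`` is a hypothetical collection-mapped attribute.
+#
+#     from sqlalchemy import event
+#
+#     @event.listens_for(SomeClass.addresses, "append")
+#     def receive_append(target, value, initiator):
+#         if initiator.op is OP_APPEND:
+#             print("append on attribute", initiator.key)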
+
+
+class AttributeImpl(object):
+ """internal implementation for instrumented attributes."""
+
+ def __init__(
+ self,
+ class_,
+ key,
+ callable_,
+ dispatch,
+ trackparent=False,
+ compare_function=None,
+ active_history=False,
+ parent_token=None,
+ load_on_unexpire=True,
+ send_modified_events=True,
+ accepts_scalar_loader=None,
+ **kwargs
+ ):
+ r"""Construct an AttributeImpl.
+
+ :param \class_: associated class
+
+ :param key: string name of the attribute
+
+ :param \callable_:
+ optional function which generates a callable based on a parent
+ instance, which produces the "default" values for a scalar or
+ collection attribute when it's first accessed, if not present
+ already.
+
+ :param trackparent:
+ if True, attempt to track if an instance has a parent attached
+ to it via this attribute.
+
+ :param compare_function:
+ a function that compares two values which are normally
+ assignable to this attribute.
+
+ :param active_history:
+ indicates that get_history() should always return the "old" value,
+ even if it means executing a lazy callable upon attribute change.
+
+ :param parent_token:
+ Usually references the MapperProperty, used as a key for
+ the hasparent() function to identify an "owning" attribute.
+ Allows multiple AttributeImpls to all match a single
+ owner attribute.
+
+ :param load_on_unexpire:
+ if False, don't include this attribute in a load-on-expired
+ operation, i.e. the "expired_attribute_loader" process.
+ The attribute can still be in the "expired" list and be
+ considered to be "expired". Previously, this flag was called
+ "expire_missing" and is only used by a deferred column
+ attribute.
+
+ :param send_modified_events:
+ if False, the InstanceState._modified_event method will have no
+ effect; this means the attribute will never show up as changed in a
+ history entry.
+
+ """
+ self.class_ = class_
+ self.key = key
+ self.callable_ = callable_
+ self.dispatch = dispatch
+ self.trackparent = trackparent
+ self.parent_token = parent_token or self
+ self.send_modified_events = send_modified_events
+ if compare_function is None:
+ self.is_equal = operator.eq
+ else:
+ self.is_equal = compare_function
+
+ if accepts_scalar_loader is not None:
+ self.accepts_scalar_loader = accepts_scalar_loader
+ else:
+ self.accepts_scalar_loader = self.default_accepts_scalar_loader
+
+ _deferred_history = kwargs.pop("_deferred_history", False)
+ self._deferred_history = _deferred_history
+
+ if active_history:
+ self.dispatch._active_history = True
+
+ self.load_on_unexpire = load_on_unexpire
+ self._modified_token = Event(self, OP_MODIFIED)
+
+ __slots__ = (
+ "class_",
+ "key",
+ "callable_",
+ "dispatch",
+ "trackparent",
+ "parent_token",
+ "send_modified_events",
+ "is_equal",
+ "load_on_unexpire",
+ "_modified_token",
+ "accepts_scalar_loader",
+ "_deferred_history",
+ )
+
+ def __str__(self):
+ return "%s.%s" % (self.class_.__name__, self.key)
+
+ def _get_active_history(self):
+ """Backwards compat for impl.active_history"""
+
+ return self.dispatch._active_history
+
+ def _set_active_history(self, value):
+ self.dispatch._active_history = value
+
+ active_history = property(_get_active_history, _set_active_history)
+
+ def hasparent(self, state, optimistic=False):
+ """Return the boolean value of a `hasparent` flag attached to
+ the given state.
+
+ The `optimistic` flag determines what the default return value
+ should be if no `hasparent` flag can be located.
+
+ As this function is used to determine if an instance is an
+ *orphan*, instances that were loaded from storage should be
+ assumed to not be orphans, until a True/False value for this
+ flag is set.
+
+ An instance attribute that is loaded by a callable function
+ will also not have a `hasparent` flag.
+
+ """
+ msg = "This AttributeImpl is not configured to track parents."
+ assert self.trackparent, msg
+
+ return (
+ state.parents.get(id(self.parent_token), optimistic) is not False
+ )
+
+ def sethasparent(self, state, parent_state, value):
+ """Set a boolean flag on the given item corresponding to
+ whether or not it is attached to a parent object via the
+ attribute represented by this ``InstrumentedAttribute``.
+
+ """
+ msg = "This AttributeImpl is not configured to track parents."
+ assert self.trackparent, msg
+
+ id_ = id(self.parent_token)
+ if value:
+ state.parents[id_] = parent_state
+ else:
+ if id_ in state.parents:
+ last_parent = state.parents[id_]
+
+ if (
+ last_parent is not False
+ and last_parent.key != parent_state.key
+ ):
+
+ if last_parent.obj() is None:
+ raise orm_exc.StaleDataError(
+ "Removing state %s from parent "
+ "state %s along attribute '%s', "
+ "but the parent record "
+ "has gone stale, can't be sure this "
+ "is the most recent parent."
+ % (
+ state_str(state),
+ state_str(parent_state),
+ self.key,
+ )
+ )
+
+ return
+
+ state.parents[id_] = False
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ raise NotImplementedError()
+
+ def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ """Return a list of tuples of (state, obj)
+ for all objects in this attribute's current state
+ + history.
+
+ Only applies to object-based attributes.
+
+ This is an inlining of existing functionality
+ which roughly corresponds to:
+
+ get_state_history(
+ state,
+ key,
+ passive=PASSIVE_NO_INITIALIZE).sum()
+
+ """
+ raise NotImplementedError()
+
+ def _default_value(self, state, dict_):
+ """Produce an empty value for an uninitialized scalar attribute."""
+
+ assert self.key not in dict_, (
+ "_default_value should only be invoked for an "
+ "uninitialized or expired attribute"
+ )
+
+ value = None
+ for fn in self.dispatch.init_scalar:
+ ret = fn(state, value, dict_)
+ if ret is not ATTR_EMPTY:
+ value = ret
+
+ return value
+
+ def get(self, state, dict_, passive=PASSIVE_OFF):
+ """Retrieve a value from the given object.
+ If a callable is assembled on this object's attribute, and
+ passive is False, the callable will be executed and the
+ resulting value will be set as the new value for this attribute.
+ """
+ if self.key in dict_:
+ return dict_[self.key]
+ else:
+ # if history present, don't load
+ key = self.key
+ if (
+ key not in state.committed_state
+ or state.committed_state[key] is NO_VALUE
+ ):
+ if not passive & CALLABLES_OK:
+ return PASSIVE_NO_RESULT
+
+ value = self._fire_loader_callables(state, key, passive)
+
+ if value is PASSIVE_NO_RESULT or value is NO_VALUE:
+ return value
+ elif value is ATTR_WAS_SET:
+ try:
+ return dict_[key]
+ except KeyError as err:
+ # TODO: no test coverage here.
+ util.raise_(
+ KeyError(
+ "Deferred loader for attribute "
+ "%r failed to populate "
+ "correctly" % key
+ ),
+ replace_context=err,
+ )
+ elif value is not ATTR_EMPTY:
+ return self.set_committed_value(state, dict_, value)
+
+ if not passive & INIT_OK:
+ return NO_VALUE
+ else:
+ return self._default_value(state, dict_)
+
+ def _fire_loader_callables(self, state, key, passive):
+ if (
+ self.accepts_scalar_loader
+ and self.load_on_unexpire
+ and key in state.expired_attributes
+ ):
+ return state._load_expired(state, passive)
+ elif key in state.callables:
+ callable_ = state.callables[key]
+ return callable_(state, passive)
+ elif self.callable_:
+ return self.callable_(state, passive)
+ else:
+ return ATTR_EMPTY
+
+ def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(state, dict_, value, initiator, passive=passive)
+
+ def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(
+ state, dict_, None, initiator, passive=passive, check_old=value
+ )
+
+ def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ self.set(
+ state,
+ dict_,
+ None,
+ initiator,
+ passive=passive,
+ check_old=value,
+ pop=True,
+ )
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
+ raise NotImplementedError()
+
+ def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
+ """return the unchanged value of this attribute"""
+
+ if self.key in state.committed_state:
+ value = state.committed_state[self.key]
+ if value is NO_VALUE:
+ return None
+ else:
+ return value
+ else:
+ return self.get(state, dict_, passive=passive)
+
+ def set_committed_value(self, state, dict_, value):
+ """set an attribute value on the given instance and 'commit' it."""
+
+ dict_[self.key] = value
+ state._commit(dict_, [self.key])
+ return value
+
+
+class ScalarAttributeImpl(AttributeImpl):
+ """represents a scalar value-holding InstrumentedAttribute."""
+
+ default_accepts_scalar_loader = True
+ uses_objects = False
+ supports_population = True
+ collection = False
+ dynamic = False
+
+ __slots__ = "_replace_token", "_append_token", "_remove_token"
+
+ def __init__(self, *arg, **kw):
+ super(ScalarAttributeImpl, self).__init__(*arg, **kw)
+ self._replace_token = self._append_token = Event(self, OP_REPLACE)
+ self._remove_token = Event(self, OP_REMOVE)
+
+ def delete(self, state, dict_):
+ if self.dispatch._active_history:
+ old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
+ else:
+ old = dict_.get(self.key, NO_VALUE)
+
+ if self.dispatch.remove:
+ self.fire_remove_event(state, dict_, old, self._remove_token)
+ state._modified_event(dict_, self, old)
+
+ existing = dict_.pop(self.key, NO_VALUE)
+ if (
+ existing is NO_VALUE
+ and old is NO_VALUE
+ and not state.expired
+ and self.key not in state.expired_attributes
+ ):
+ raise AttributeError("%s object does not have a value" % self)
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ if self.key in dict_:
+ return History.from_scalar_attribute(self, state, dict_[self.key])
+ elif self.key in state.committed_state:
+ return History.from_scalar_attribute(self, state, NO_VALUE)
+ else:
+ if passive & INIT_OK:
+ passive ^= INIT_OK
+ current = self.get(state, dict_, passive=passive)
+ if current is PASSIVE_NO_RESULT:
+ return HISTORY_BLANK
+ else:
+ return History.from_scalar_attribute(self, state, current)
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
+ if self.dispatch._active_history:
+ old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE)
+ else:
+ old = dict_.get(self.key, NO_VALUE)
+
+ if self.dispatch.set:
+ value = self.fire_replace_event(
+ state, dict_, value, old, initiator
+ )
+ state._modified_event(dict_, self, old)
+ dict_[self.key] = value
+
+ def fire_replace_event(self, state, dict_, value, previous, initiator):
+ for fn in self.dispatch.set:
+ value = fn(
+ state, value, previous, initiator or self._replace_token
+ )
+ return value
+
+ def fire_remove_event(self, state, dict_, value, initiator):
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+    @property
+    def type(self):
+        return self.property.columns[0].type
+
+
+class ScalarObjectAttributeImpl(ScalarAttributeImpl):
+ """represents a scalar-holding InstrumentedAttribute,
+ where the target object is also instrumented.
+
+ Adds events to delete/set operations.
+
+ """
+
+ default_accepts_scalar_loader = False
+ uses_objects = True
+ supports_population = True
+ collection = False
+
+ __slots__ = ()
+
+ def delete(self, state, dict_):
+ if self.dispatch._active_history:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED,
+ )
+ else:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_NO_FETCH ^ INIT_OK
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE,
+ )
+
+ self.fire_remove_event(state, dict_, old, self._remove_token)
+
+ existing = dict_.pop(self.key, NO_VALUE)
+
+ # if the attribute is expired, we currently have no way to tell
+ # that an object-attribute was expired vs. not loaded. So
+ # for this test, we look to see if the object has a DB identity.
+ if (
+ existing is NO_VALUE
+ and old is not PASSIVE_NO_RESULT
+ and state.key is None
+ ):
+ raise AttributeError("%s object does not have a value" % self)
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ if self.key in dict_:
+ current = dict_[self.key]
+ else:
+ if passive & INIT_OK:
+ passive ^= INIT_OK
+ current = self.get(state, dict_, passive=passive)
+ if current is PASSIVE_NO_RESULT:
+ return HISTORY_BLANK
+
+ if not self._deferred_history:
+ return History.from_object_attribute(self, state, current)
+ else:
+ original = state.committed_state.get(self.key, _NO_HISTORY)
+ if original is PASSIVE_NO_RESULT:
+
+ loader_passive = passive | (
+ PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE
+ | DEFERRED_HISTORY_LOAD
+ )
+ original = self._fire_loader_callables(
+ state, self.key, loader_passive
+ )
+ return History.from_object_attribute(
+ self, state, current, original=original
+ )
+
+ def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ if self.key in dict_:
+ current = dict_[self.key]
+ elif passive & CALLABLES_OK:
+ current = self.get(state, dict_, passive=passive)
+ else:
+ return []
+
+ # can't use __hash__(), can't use __eq__() here
+ if (
+ current is not None
+ and current is not PASSIVE_NO_RESULT
+ and current is not NO_VALUE
+ ):
+ ret = [(instance_state(current), current)]
+ else:
+ ret = [(None, None)]
+
+ if self.key in state.committed_state:
+ original = state.committed_state[self.key]
+ if (
+ original is not None
+ and original is not PASSIVE_NO_RESULT
+ and original is not NO_VALUE
+ and original is not current
+ ):
+
+ ret.append((instance_state(original), original))
+ return ret
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ ):
+ """Set a value on the given InstanceState."""
+
+ if self.dispatch._active_history:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_ONLY_PERSISTENT
+ | NO_AUTOFLUSH
+ | LOAD_AGAINST_COMMITTED,
+ )
+ else:
+ old = self.get(
+ state,
+ dict_,
+ passive=PASSIVE_NO_FETCH ^ INIT_OK
+ | LOAD_AGAINST_COMMITTED
+ | NO_RAISE,
+ )
+
+ if (
+ check_old is not None
+ and old is not PASSIVE_NO_RESULT
+ and check_old is not old
+ ):
+ if pop:
+ return
+ else:
+ raise ValueError(
+ "Object %s not associated with %s on attribute '%s'"
+ % (instance_str(check_old), state_str(state), self.key)
+ )
+
+ value = self.fire_replace_event(state, dict_, value, old, initiator)
+ dict_[self.key] = value
+
+ def fire_remove_event(self, state, dict_, value, initiator):
+ if self.trackparent and value not in (
+ None,
+ PASSIVE_NO_RESULT,
+ NO_VALUE,
+ ):
+ self.sethasparent(instance_state(value), state, False)
+
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+ state._modified_event(dict_, self, value)
+
+ def fire_replace_event(self, state, dict_, value, previous, initiator):
+ if self.trackparent:
+ if previous is not value and previous not in (
+ None,
+ PASSIVE_NO_RESULT,
+ NO_VALUE,
+ ):
+ self.sethasparent(instance_state(previous), state, False)
+
+ for fn in self.dispatch.set:
+ value = fn(
+ state, value, previous, initiator or self._replace_token
+ )
+
+ state._modified_event(dict_, self, previous)
+
+ if self.trackparent:
+ if value is not None:
+ self.sethasparent(instance_state(value), state, True)
+
+ return value
+
+
+class CollectionAttributeImpl(AttributeImpl):
+ """A collection-holding attribute that instruments changes in membership.
+
+ Only handles collections of instrumented objects.
+
+ InstrumentedCollectionAttribute holds an arbitrary, user-specified
+ container object (defaulting to a list) and brokers access to the
+ CollectionAdapter, a "view" onto that object that presents consistent bag
+ semantics to the orm layer independent of the user data implementation.
+
+ """
+
+ default_accepts_scalar_loader = False
+ uses_objects = True
+ supports_population = True
+ collection = True
+ dynamic = False
+
+ __slots__ = (
+ "copy",
+ "collection_factory",
+ "_append_token",
+ "_remove_token",
+ "_bulk_replace_token",
+ "_duck_typed_as",
+ )
+
+ def __init__(
+ self,
+ class_,
+ key,
+ callable_,
+ dispatch,
+ typecallable=None,
+ trackparent=False,
+ copy_function=None,
+ compare_function=None,
+ **kwargs
+ ):
+ super(CollectionAttributeImpl, self).__init__(
+ class_,
+ key,
+ callable_,
+ dispatch,
+ trackparent=trackparent,
+ compare_function=compare_function,
+ **kwargs
+ )
+
+ if copy_function is None:
+ copy_function = self.__copy
+ self.copy = copy_function
+ self.collection_factory = typecallable
+ self._append_token = Event(self, OP_APPEND)
+ self._remove_token = Event(self, OP_REMOVE)
+ self._bulk_replace_token = Event(self, OP_BULK_REPLACE)
+ self._duck_typed_as = util.duck_type_collection(
+ self.collection_factory()
+ )
+
+ if getattr(self.collection_factory, "_sa_linker", None):
+
+ @event.listens_for(self, "init_collection")
+ def link(target, collection, collection_adapter):
+ collection._sa_linker(collection_adapter)
+
+ @event.listens_for(self, "dispose_collection")
+ def unlink(target, collection, collection_adapter):
+ collection._sa_linker(None)
+
+ def __copy(self, item):
+ return [y for y in collections.collection_adapter(item)]
+
+ def get_history(self, state, dict_, passive=PASSIVE_OFF):
+ current = self.get(state, dict_, passive=passive)
+ if current is PASSIVE_NO_RESULT:
+ return HISTORY_BLANK
+ else:
+ return History.from_collection(self, state, current)
+
+ def get_all_pending(self, state, dict_, passive=PASSIVE_NO_INITIALIZE):
+ # NOTE: passive is ignored here at the moment
+
+ if self.key not in dict_:
+ return []
+
+ current = dict_[self.key]
+ current = getattr(current, "_sa_adapter")
+
+ if self.key in state.committed_state:
+ original = state.committed_state[self.key]
+ if original is not NO_VALUE:
+ current_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in current
+ ]
+ original_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in original
+ ]
+
+ current_set = dict(current_states)
+ original_set = dict(original_states)
+
+ return (
+ [
+ (s, o)
+ for s, o in current_states
+ if s not in original_set
+ ]
+ + [(s, o) for s, o in current_states if s in original_set]
+ + [
+ (s, o)
+ for s, o in original_states
+ if s not in current_set
+ ]
+ )
+
+ return [(instance_state(o), o) for o in current]
+
+ def fire_append_event(self, state, dict_, value, initiator):
+ for fn in self.dispatch.append:
+ value = fn(state, value, initiator or self._append_token)
+
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ if self.trackparent and value is not None:
+ self.sethasparent(instance_state(value), state, True)
+
+ return value
+
+ def fire_append_wo_mutation_event(self, state, dict_, value, initiator):
+ for fn in self.dispatch.append_wo_mutation:
+ value = fn(state, value, initiator or self._append_token)
+
+ return value
+
+ def fire_pre_remove_event(self, state, dict_, initiator):
+ """A special event used for pop() operations.
+
+ The "remove" event needs to have the item to be removed passed to
+ it, which in the case of pop from a set, we don't have a way to access
+ the item before the operation. the event is used for all pop()
+ operations (even though set.pop is the one where it is really needed).
+
+ """
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ def fire_remove_event(self, state, dict_, value, initiator):
+ if self.trackparent and value is not None:
+ self.sethasparent(instance_state(value), state, False)
+
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ def delete(self, state, dict_):
+ if self.key not in dict_:
+ return
+
+ state._modified_event(dict_, self, NO_VALUE, True)
+
+ collection = self.get_collection(state, state.dict)
+ collection.clear_with_event()
+
+ # key is always present because we checked above. e.g.
+ # del is a no-op if collection not present.
+ del dict_[self.key]
+
+ def _default_value(self, state, dict_):
+ """Produce an empty collection for an un-initialized attribute"""
+
+ assert self.key not in dict_, (
+ "_default_value should only be invoked for an "
+ "uninitialized or expired attribute"
+ )
+
+ if self.key in state._empty_collections:
+ return state._empty_collections[self.key]
+
+ adapter, user_data = self._initialize_collection(state)
+ adapter._set_empty(user_data)
+ return user_data
+
+ def _initialize_collection(self, state):
+
+ adapter, collection = state.manager.initialize_collection(
+ self.key, state, self.collection_factory
+ )
+
+ self.dispatch.init_collection(state, collection, adapter)
+
+ return adapter, collection
+
+ def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ collection = self.get_collection(state, dict_, passive=passive)
+ if collection is PASSIVE_NO_RESULT:
+ value = self.fire_append_event(state, dict_, value, initiator)
+ assert (
+ self.key not in dict_
+ ), "Collection was loaded during event handling."
+ state._get_pending_mutation(self.key).append(value)
+ else:
+ collection.append_with_event(value, initiator)
+
+ def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ collection = self.get_collection(state, state.dict, passive=passive)
+ if collection is PASSIVE_NO_RESULT:
+ self.fire_remove_event(state, dict_, value, initiator)
+ assert (
+ self.key not in dict_
+ ), "Collection was loaded during event handling."
+ state._get_pending_mutation(self.key).remove(value)
+ else:
+ collection.remove_with_event(value, initiator)
+
+ def pop(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+ try:
+ # TODO: better solution here would be to add
+ # a "popper" role to collections.py to complement
+ # "remover".
+ self.remove(state, dict_, value, initiator, passive=passive)
+ except (ValueError, KeyError, IndexError):
+ pass
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator=None,
+ passive=PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ _adapt=True,
+ ):
+ iterable = orig_iterable = value
+
+ # pulling a new collection first so that an adaptation exception does
+ # not trigger a lazy load of the old collection.
+ new_collection, user_data = self._initialize_collection(state)
+ if _adapt:
+ if new_collection._converter is not None:
+ iterable = new_collection._converter(iterable)
+ else:
+ setting_type = util.duck_type_collection(iterable)
+ receiving_type = self._duck_typed_as
+
+ if setting_type is not receiving_type:
+ given = (
+ iterable is None
+ and "None"
+ or iterable.__class__.__name__
+ )
+ wanted = self._duck_typed_as.__name__
+ raise TypeError(
+ "Incompatible collection type: %s is not %s-like"
+ % (given, wanted)
+ )
+
+ # If the object is an adapted collection, return the (iterable)
+ # adapter.
+ if hasattr(iterable, "_sa_iterator"):
+ iterable = iterable._sa_iterator()
+ elif setting_type is dict:
+ if util.py3k:
+ iterable = iterable.values()
+ else:
+ iterable = getattr(
+ iterable, "itervalues", iterable.values
+ )()
+ else:
+ iterable = iter(iterable)
+ new_values = list(iterable)
+
+ evt = self._bulk_replace_token
+
+ self.dispatch.bulk_replace(state, new_values, evt)
+
+ old = self.get(state, dict_, passive=PASSIVE_ONLY_PERSISTENT)
+ if old is PASSIVE_NO_RESULT:
+ old = self._default_value(state, dict_)
+ elif old is orig_iterable:
+ # ignore re-assignment of the current collection, as happens
+ # implicitly with in-place operators (foo.collection |= other)
+ return
+
+ # place a copy of "old" in state.committed_state
+ state._modified_event(dict_, self, old, True)
+
+ old_collection = old._sa_adapter
+
+ dict_[self.key] = user_data
+
+ collections.bulk_replace(
+ new_values, old_collection, new_collection, initiator=evt
+ )
+
+ self._dispose_previous_collection(state, old, old_collection, True)
+
+ def _dispose_previous_collection(
+ self, state, collection, adapter, fire_event
+ ):
+ del collection._sa_adapter
+
+        # when discarding the old collection, make sure it is no longer
+        # referenced in state._empty_collections.
+ state._empty_collections.pop(self.key, None)
+ if fire_event:
+ self.dispatch.dispose_collection(state, collection, adapter)
+
+ def _invalidate_collection(self, collection):
+ adapter = getattr(collection, "_sa_adapter")
+ adapter.invalidated = True
+
+ def set_committed_value(self, state, dict_, value):
+ """Set an attribute value on the given instance and 'commit' it."""
+
+ collection, user_data = self._initialize_collection(state)
+
+ if value:
+ collection.append_multiple_without_event(value)
+
+ state.dict[self.key] = user_data
+
+ state._commit(dict_, [self.key])
+
+ if self.key in state._pending_mutations:
+ # pending items exist. issue a modified event,
+ # add/remove new items.
+ state._modified_event(dict_, self, user_data, True)
+
+ pending = state._pending_mutations.pop(self.key)
+ added = pending.added_items
+ removed = pending.deleted_items
+ for item in added:
+ collection.append_without_event(item)
+ for item in removed:
+ collection.remove_without_event(item)
+
+ return user_data
+
+ def get_collection(
+ self, state, dict_, user_data=None, passive=PASSIVE_OFF
+ ):
+ """Retrieve the CollectionAdapter associated with the given state.
+
+        If ``user_data`` is None, it is retrieved from the state using the
+        normal "get()" rules, which will fire lazy callables or return the
+        "empty" collection value.
+
+ """
+ if user_data is None:
+ user_data = self.get(state, dict_, passive=passive)
+ if user_data is PASSIVE_NO_RESULT:
+ return user_data
+
+ return user_data._sa_adapter
+
+
+def backref_listeners(attribute, key, uselist):
+ """Apply listeners to synchronize a two-way relationship."""
+
+ # use easily recognizable names for stack traces.
+
+ # in the sections marked "tokens to test for a recursive loop",
+ # this is somewhat brittle and very performance-sensitive logic
+ # that is specific to how we might arrive at each event. a marker
+ # that can target us directly to arguments being invoked against
+ # the impl might be simpler, but could interfere with other systems.
+
+ parent_token = attribute.impl.parent_token
+ parent_impl = attribute.impl
+
+ def _acceptable_key_err(child_state, initiator, child_impl):
+ raise ValueError(
+ "Bidirectional attribute conflict detected: "
+ 'Passing object %s to attribute "%s" '
+ 'triggers a modify event on attribute "%s" '
+ 'via the backref "%s".'
+ % (
+ state_str(child_state),
+ initiator.parent_token,
+ child_impl.parent_token,
+ attribute.impl.parent_token,
+ )
+ )
+
+ def emit_backref_from_scalar_set_event(state, child, oldchild, initiator):
+ if oldchild is child:
+ return child
+ if (
+ oldchild is not None
+ and oldchild is not PASSIVE_NO_RESULT
+ and oldchild is not NO_VALUE
+ ):
+ # With lazy=None, there's no guarantee that the full collection is
+ # present when updating via a backref.
+ old_state, old_dict = (
+ instance_state(oldchild),
+ instance_dict(oldchild),
+ )
+ impl = old_state.manager[key].impl
+
+ # tokens to test for a recursive loop.
+ if not impl.collection and not impl.dynamic:
+ check_recursive_token = impl._replace_token
+ else:
+ check_recursive_token = impl._remove_token
+
+ if initiator is not check_recursive_token:
+ impl.pop(
+ old_state,
+ old_dict,
+ state.obj(),
+ parent_impl._append_token,
+ passive=PASSIVE_NO_FETCH,
+ )
+
+ if child is not None:
+ child_state, child_dict = (
+ instance_state(child),
+ instance_dict(child),
+ )
+ child_impl = child_state.manager[key].impl
+
+ if (
+ initiator.parent_token is not parent_token
+ and initiator.parent_token is not child_impl.parent_token
+ ):
+ _acceptable_key_err(state, initiator, child_impl)
+
+ # tokens to test for a recursive loop.
+ check_append_token = child_impl._append_token
+ check_bulk_replace_token = (
+ child_impl._bulk_replace_token
+ if child_impl.collection
+ else None
+ )
+
+ if (
+ initiator is not check_append_token
+ and initiator is not check_bulk_replace_token
+ ):
+ child_impl.append(
+ child_state,
+ child_dict,
+ state.obj(),
+ initiator,
+ passive=PASSIVE_NO_FETCH,
+ )
+ return child
+
+ def emit_backref_from_collection_append_event(state, child, initiator):
+ if child is None:
+ return
+
+ child_state, child_dict = instance_state(child), instance_dict(child)
+ child_impl = child_state.manager[key].impl
+
+ if (
+ initiator.parent_token is not parent_token
+ and initiator.parent_token is not child_impl.parent_token
+ ):
+ _acceptable_key_err(state, initiator, child_impl)
+
+ # tokens to test for a recursive loop.
+ check_append_token = child_impl._append_token
+ check_bulk_replace_token = (
+ child_impl._bulk_replace_token if child_impl.collection else None
+ )
+
+ if (
+ initiator is not check_append_token
+ and initiator is not check_bulk_replace_token
+ ):
+ child_impl.append(
+ child_state,
+ child_dict,
+ state.obj(),
+ initiator,
+ passive=PASSIVE_NO_FETCH,
+ )
+ return child
+
+ def emit_backref_from_collection_remove_event(state, child, initiator):
+ if (
+ child is not None
+ and child is not PASSIVE_NO_RESULT
+ and child is not NO_VALUE
+ ):
+ child_state, child_dict = (
+ instance_state(child),
+ instance_dict(child),
+ )
+ child_impl = child_state.manager[key].impl
+
+ # tokens to test for a recursive loop.
+ if not child_impl.collection and not child_impl.dynamic:
+ check_remove_token = child_impl._remove_token
+ check_replace_token = child_impl._replace_token
+ check_for_dupes_on_remove = uselist and not parent_impl.dynamic
+ else:
+ check_remove_token = child_impl._remove_token
+ check_replace_token = (
+ child_impl._bulk_replace_token
+ if child_impl.collection
+ else None
+ )
+ check_for_dupes_on_remove = False
+
+ if (
+ initiator is not check_remove_token
+ and initiator is not check_replace_token
+ ):
+
+ if not check_for_dupes_on_remove or not util.has_dupes(
+ # when this event is called, the item is usually
+ # present in the list, except for a pop() operation.
+ state.dict[parent_impl.key],
+ child,
+ ):
+ child_impl.pop(
+ child_state,
+ child_dict,
+ state.obj(),
+ initiator,
+ passive=PASSIVE_NO_FETCH,
+ )
+
+ if uselist:
+ event.listen(
+ attribute,
+ "append",
+ emit_backref_from_collection_append_event,
+ retval=True,
+ raw=True,
+ )
+ else:
+ event.listen(
+ attribute,
+ "set",
+ emit_backref_from_scalar_set_event,
+ retval=True,
+ raw=True,
+ )
+ # TODO: need coverage in test/orm/ of remove event
+ event.listen(
+ attribute,
+ "remove",
+ emit_backref_from_collection_remove_event,
+ retval=True,
+ raw=True,
+ )
+
+
+_NO_HISTORY = util.symbol("NO_HISTORY")
+_NO_STATE_SYMBOLS = frozenset([id(PASSIVE_NO_RESULT), id(NO_VALUE)])
+
+
+class History(util.namedtuple("History", ["added", "unchanged", "deleted"])):
+ """A 3-tuple of added, unchanged and deleted values,
+ representing the changes which have occurred on an instrumented
+ attribute.
+
+ The easiest way to get a :class:`.History` object for a particular
+ attribute on an object is to use the :func:`_sa.inspect` function::
+
+ from sqlalchemy import inspect
+
+ hist = inspect(myobject).attrs.myattribute.history
+
+ Each tuple member is an iterable sequence:
+
+ * ``added`` - the collection of items added to the attribute (the first
+ tuple element).
+
+ * ``unchanged`` - the collection of items that have not changed on the
+ attribute (the second tuple element).
+
+ * ``deleted`` - the collection of items that have been removed from the
+ attribute (the third tuple element).
+
+ """
+
+ def __bool__(self):
+ return self != HISTORY_BLANK
+
+ __nonzero__ = __bool__
+
+ def empty(self):
+ """Return True if this :class:`.History` has no changes
+ and no existing, unchanged state.
+
+ """
+
+ return not bool((self.added or self.deleted) or self.unchanged)
+
+ def sum(self):
+ """Return a collection of added + unchanged + deleted."""
+
+ return (
+ (self.added or []) + (self.unchanged or []) + (self.deleted or [])
+ )
+
+ def non_deleted(self):
+ """Return a collection of added + unchanged."""
+
+ return (self.added or []) + (self.unchanged or [])
+
+ def non_added(self):
+ """Return a collection of unchanged + deleted."""
+
+ return (self.unchanged or []) + (self.deleted or [])
+
+ def has_changes(self):
+ """Return True if this :class:`.History` has changes."""
+
+ return bool(self.added or self.deleted)
+
+ def as_state(self):
+ return History(
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.added
+ ],
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.unchanged
+ ],
+ [
+ (c is not None) and instance_state(c) or None
+ for c in self.deleted
+ ],
+ )
+
+ @classmethod
+ def from_scalar_attribute(cls, attribute, state, current):
+ original = state.committed_state.get(attribute.key, _NO_HISTORY)
+
+ if original is _NO_HISTORY:
+ if current is NO_VALUE:
+ return cls((), (), ())
+ else:
+ return cls((), [current], ())
+ # don't let ClauseElement expressions here trip things up
+ elif (
+ current is not NO_VALUE
+ and attribute.is_equal(current, original) is True
+ ):
+ return cls((), [current], ())
+ else:
+ # current convention on native scalars is to not
+ # include information
+ # about missing previous value in "deleted", but
+ # we do include None, which helps in some primary
+ # key situations
+ if id(original) in _NO_STATE_SYMBOLS:
+ deleted = ()
+ # indicate a "del" operation occurred when we don't have
+ # the previous value as: ([None], (), ())
+ if id(current) in _NO_STATE_SYMBOLS:
+ current = None
+ else:
+ deleted = [original]
+ if current is NO_VALUE:
+ return cls((), (), deleted)
+ else:
+ return cls([current], (), deleted)
+
+ @classmethod
+ def from_object_attribute(
+ cls, attribute, state, current, original=_NO_HISTORY
+ ):
+ if original is _NO_HISTORY:
+ original = state.committed_state.get(attribute.key, _NO_HISTORY)
+
+ if original is _NO_HISTORY:
+ if current is NO_VALUE:
+ return cls((), (), ())
+ else:
+ return cls((), [current], ())
+ elif current is original and current is not NO_VALUE:
+ return cls((), [current], ())
+ else:
+ # current convention on related objects is to not
+ # include information
+ # about missing previous value in "deleted", and
+ # to also not include None - the dependency.py rules
+ # ignore the None in any case.
+ if id(original) in _NO_STATE_SYMBOLS or original is None:
+ deleted = ()
+ # indicate a "del" operation occurred when we don't have
+ # the previous value as: ([None], (), ())
+ if id(current) in _NO_STATE_SYMBOLS:
+ current = None
+ else:
+ deleted = [original]
+ if current is NO_VALUE:
+ return cls((), (), deleted)
+ else:
+ return cls([current], (), deleted)
+
+ @classmethod
+ def from_collection(cls, attribute, state, current):
+ original = state.committed_state.get(attribute.key, _NO_HISTORY)
+ if current is NO_VALUE:
+ return cls((), (), ())
+
+ current = getattr(current, "_sa_adapter")
+ if original is NO_VALUE:
+ return cls(list(current), (), ())
+ elif original is _NO_HISTORY:
+ return cls((), list(current), ())
+ else:
+
+ current_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in current
+ ]
+ original_states = [
+ ((c is not None) and instance_state(c) or None, c)
+ for c in original
+ ]
+
+ current_set = dict(current_states)
+ original_set = dict(original_states)
+
+ return cls(
+ [o for s, o in current_states if s not in original_set],
+ [o for s, o in current_states if s in original_set],
+ [o for s, o in original_states if s not in current_set],
+ )
+
+
+HISTORY_BLANK = History(None, None, None)
+
+
+def get_history(obj, key, passive=PASSIVE_OFF):
+ """Return a :class:`.History` record for the given object
+ and attribute key.
+
+ This is the **pre-flush** history for a given attribute, which is
+ reset each time the :class:`.Session` flushes changes to the
+ current database transaction.
+
+ .. note::
+
+ Prefer to use the :attr:`.AttributeState.history` and
+ :meth:`.AttributeState.load_history` accessors to retrieve the
+ :class:`.History` for instance attributes.
+
+
+ :param obj: an object whose class is instrumented by the
+ attributes package.
+
+ :param key: string attribute name.
+
+ :param passive: indicates loading behavior for the attribute
+ if the value is not already present. This is a
+ bitflag attribute, which defaults to the symbol
+ :attr:`.PASSIVE_OFF` indicating all necessary SQL
+ should be emitted.
+
+ .. seealso::
+
+ :attr:`.AttributeState.history`
+
+ :meth:`.AttributeState.load_history` - retrieve history
+ using loader callables if the value is not locally present.
+
+ """
+
+ return get_state_history(instance_state(obj), key, passive)
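+
+# A minimal usage sketch for get_history() (``some_user`` and ``a1`` are
+# hypothetical; any mapped class and instance would do).  The three History
+# fields reflect the pending, pre-flush changes on the attribute:
+#
+#     some_user.addresses.append(a1)       # mutate a collection attribute
+#     hist = get_history(some_user, "addresses")
+#     hist.added        # [a1]
+#     hist.unchanged    # previously loaded, unmodified members
+#     hist.deleted      # members removed since the last flush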
+
+
+def get_state_history(state, key, passive=PASSIVE_OFF):
+ return state.get_history(key, passive)
+
+
+def has_parent(cls, obj, key, optimistic=False):
+ """TODO"""
+ manager = manager_of_class(cls)
+ state = instance_state(obj)
+ return manager.has_parent(state, key, optimistic)
+
+
+def register_attribute(class_, key, **kw):
+ comparator = kw.pop("comparator", None)
+ parententity = kw.pop("parententity", None)
+ doc = kw.pop("doc", None)
+ desc = register_descriptor(class_, key, comparator, parententity, doc=doc)
+ register_attribute_impl(class_, key, **kw)
+ return desc
+
+
+def register_attribute_impl(
+ class_,
+ key,
+ uselist=False,
+ callable_=None,
+ useobject=False,
+ impl_class=None,
+ backref=None,
+ **kw
+):
+
+ manager = manager_of_class(class_)
+ if uselist:
+ factory = kw.pop("typecallable", None)
+ typecallable = manager.instrument_collection_class(
+ key, factory or list
+ )
+ else:
+ typecallable = kw.pop("typecallable", None)
+
+ dispatch = manager[key].dispatch
+
+ if impl_class:
+ impl = impl_class(class_, key, typecallable, dispatch, **kw)
+ elif uselist:
+ impl = CollectionAttributeImpl(
+ class_, key, callable_, dispatch, typecallable=typecallable, **kw
+ )
+ elif useobject:
+ impl = ScalarObjectAttributeImpl(
+ class_, key, callable_, dispatch, **kw
+ )
+ else:
+ impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw)
+
+ manager[key].impl = impl
+
+ if backref:
+ backref_listeners(manager[key], backref, uselist)
+
+ manager.post_configure_attribute(key)
+ return manager[key]
+
+
+def register_descriptor(
+ class_, key, comparator=None, parententity=None, doc=None
+):
+ manager = manager_of_class(class_)
+
+ descriptor = InstrumentedAttribute(
+ class_, key, comparator=comparator, parententity=parententity
+ )
+
+ descriptor.__doc__ = doc
+
+ manager.instrument_attribute(key, descriptor)
+ return descriptor
+
+
+def unregister_attribute(class_, key):
+ manager_of_class(class_).uninstrument_attribute(key)
+
+
+def init_collection(obj, key):
+ """Initialize a collection attribute and return the collection adapter.
+
+ This function is used to provide direct access to collection internals
+ for a previously unloaded attribute. e.g.::
+
+ collection_adapter = init_collection(someobject, 'elements')
+ for elem in values:
+ collection_adapter.append_without_event(elem)
+
+ For an easier way to do the above, see
+ :func:`~sqlalchemy.orm.attributes.set_committed_value`.
+
+ :param obj: a mapped object
+
+ :param key: string attribute name where the collection is located.
+
+ """
+ state = instance_state(obj)
+ dict_ = state.dict
+ return init_state_collection(state, dict_, key)
+
+
+def init_state_collection(state, dict_, key):
+ """Initialize a collection attribute and return the collection adapter.
+
+ Discards any existing collection which may be there.
+
+ """
+ attr = state.manager[key].impl
+
+ old = dict_.pop(key, None) # discard old collection
+ if old is not None:
+ old_collection = old._sa_adapter
+ attr._dispose_previous_collection(state, old, old_collection, False)
+
+ user_data = attr._default_value(state, dict_)
+ adapter = attr.get_collection(state, dict_, user_data)
+ adapter._reset_empty()
+
+ return adapter
+
+
+def set_committed_value(instance, key, value):
+ """Set the value of an attribute with no history events.
+
+ Cancels any previous history present. The value should be
+ a scalar value for scalar-holding attributes, or
+ an iterable for any collection-holding attribute.
+
+ This is the same underlying method used when a lazy loader
+ fires off and loads additional data from the database.
+ In particular, this method can be used by application code
+ which has loaded additional attributes or collections through
+ separate queries, which can then be attached to an instance
+ as though it were part of its original loaded state.
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.manager[key].impl.set_committed_value(state, dict_, value)
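+
+# Sketch of the use case described above (``user``, ``Address`` and
+# ``session`` are hypothetical): attach a separately-loaded collection to an
+# instance as though it were part of its original loaded state, producing no
+# history and no flush activity:
+#
+#     addresses = session.query(Address).filter_by(user_id=user.id).all()
+#     set_committed_value(user, "addresses", addresses)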
+
+
+def set_attribute(instance, key, value, initiator=None):
+ """Set the value of an attribute, firing history events.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+ Custom attribute management schemes will need to make usage
+ of this method to establish attribute state as understood
+ by SQLAlchemy.
+
+ :param instance: the object that will be modified
+
+ :param key: string name of the attribute
+
+ :param value: value to assign
+
+ :param initiator: an instance of :class:`.Event` that would have
+ been propagated from a previous event listener. This argument
+ is used when the :func:`.set_attribute` function is being used within
+ an existing event listening function where an :class:`.Event` object
+ is being supplied; the object may be used to track the origin of the
+ chain of events.
+
+ .. versionadded:: 1.2.3
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.manager[key].impl.set(state, dict_, value, initiator)
+
+
+def get_attribute(instance, key):
+ """Get the value of an attribute, firing any callables required.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+    Custom attribute management schemes will need to use this method
+    to access attribute state as understood by SQLAlchemy.
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ return state.manager[key].impl.get(state, dict_)
+
+
+def del_attribute(instance, key):
+ """Delete the value of an attribute, firing history events.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+ Custom attribute management schemes will need to make usage
+ of this method to establish attribute state as understood
+ by SQLAlchemy.
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state.manager[key].impl.delete(state, dict_)
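+
+# set_attribute(), get_attribute() and del_attribute() above form a
+# descriptor-free API for attribute access; an illustrative sketch
+# (``user`` is a hypothetical mapped instance):
+#
+#     set_attribute(user, "name", "ed")    # fires history/set events
+#     get_attribute(user, "name")          # fires loader callables if needed
+#     del_attribute(user, "name")          # fires history/remove events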
+
+
+def flag_modified(instance, key):
+ """Mark an attribute on an instance as 'modified'.
+
+ This sets the 'modified' flag on the instance and
+ establishes an unconditional change event for the given attribute.
+ The attribute must have a value present, else an
+ :class:`.InvalidRequestError` is raised.
+
+ To mark an object "dirty" without referring to any specific attribute
+ so that it is considered within a flush, use the
+ :func:`.attributes.flag_dirty` call.
+
+ .. seealso::
+
+ :func:`.attributes.flag_dirty`
+
+ """
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ impl = state.manager[key].impl
+ impl.dispatch.modified(state, impl._modified_token)
+ state._modified_event(dict_, impl, NO_VALUE, is_userland=True)
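+
+# Typical use, as a sketch: force a change event for a mutation SQLAlchemy
+# cannot detect on its own, e.g. an in-place change to a JSON value
+# (``user.preferences`` is a hypothetical JSON-typed attribute):
+#
+#     user.preferences["theme"] = "dark"   # in-place mutation, not tracked
+#     flag_modified(user, "preferences")   # attribute now part of the flush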
+
+
+def flag_dirty(instance):
+ """Mark an instance as 'dirty' without any specific attribute mentioned.
+
+ This is a special operation that will allow the object to travel through
+ the flush process for interception by events such as
+ :meth:`.SessionEvents.before_flush`. Note that no SQL will be emitted in
+ the flush process for an object that has no changes, even if marked dirty
+ via this method. However, a :meth:`.SessionEvents.before_flush` handler
+ will be able to see the object in the :attr:`.Session.dirty` collection and
+ may establish changes on it, which will then be included in the SQL
+ emitted.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :func:`.attributes.flag_modified`
+
+ """
+
+ state, dict_ = instance_state(instance), instance_dict(instance)
+ state._modified_event(dict_, None, NO_VALUE, is_userland=True)
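+
+# Sketch: mark an instance for interception during flush even though no
+# attribute has changed (``obj`` and ``session`` are hypothetical):
+#
+#     flag_dirty(obj)
+#     session.flush()   # a before_flush handler now sees obj in session.dirty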
diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py
new file mode 100644
index 0000000..8e94d7b
--- /dev/null
+++ b/lib/sqlalchemy/orm/base.py
@@ -0,0 +1,572 @@
+# orm/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Constants and rudimental functions used throughout the ORM.
+
+"""
+
+import operator
+
+from . import exc
+from .. import exc as sa_exc
+from .. import inspection
+from .. import util
+
+
+PASSIVE_NO_RESULT = util.symbol(
+ "PASSIVE_NO_RESULT",
+ """Symbol returned by a loader callable or other attribute/history
+ retrieval operation when a value could not be determined, based
+ on loader callable flags.
+ """,
+)
+
+PASSIVE_CLASS_MISMATCH = util.symbol(
+ "PASSIVE_CLASS_MISMATCH",
+ """Symbol indicating that an object is locally present for a given
+ primary key identity but it is not of the requested class. The
+ return value is therefore None and no SQL should be emitted.""",
+)
+
+ATTR_WAS_SET = util.symbol(
+ "ATTR_WAS_SET",
+ """Symbol returned by a loader callable to indicate the
+ retrieved value, or values, were assigned to their attributes
+ on the target object.
+ """,
+)
+
+ATTR_EMPTY = util.symbol(
+ "ATTR_EMPTY",
+ """Symbol used internally to indicate an attribute had no callable.""",
+)
+
+NO_VALUE = util.symbol(
+ "NO_VALUE",
+ """Symbol which may be placed as the 'previous' value of an attribute,
+ indicating no value was loaded for an attribute when it was modified,
+ and flags indicated we were not to load it.
+ """,
+)
+NEVER_SET = NO_VALUE
+"""
+Synonymous with NO_VALUE
+
+.. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE
+"""
+
+NO_CHANGE = util.symbol(
+ "NO_CHANGE",
+ """No callables or SQL should be emitted on attribute access
+ and no state should change
+ """,
+ canonical=0,
+)
+
+CALLABLES_OK = util.symbol(
+ "CALLABLES_OK",
+ """Loader callables can be fired off if a value
+ is not present.
+ """,
+ canonical=1,
+)
+
+SQL_OK = util.symbol(
+ "SQL_OK",
+ """Loader callables can emit SQL at least on scalar value attributes.""",
+ canonical=2,
+)
+
+RELATED_OBJECT_OK = util.symbol(
+ "RELATED_OBJECT_OK",
+ """Callables can use SQL to load related objects as well
+ as scalar value attributes.
+ """,
+ canonical=4,
+)
+
+INIT_OK = util.symbol(
+ "INIT_OK",
+ """Attributes should be initialized with a blank
+ value (None or an empty collection) upon get, if no other
+ value can be obtained.
+ """,
+ canonical=8,
+)
+
+NON_PERSISTENT_OK = util.symbol(
+ "NON_PERSISTENT_OK",
+ """Callables can be emitted if the parent is not persistent.""",
+ canonical=16,
+)
+
+LOAD_AGAINST_COMMITTED = util.symbol(
+ "LOAD_AGAINST_COMMITTED",
+ """Callables should use committed values as primary/foreign keys during a
+ load.
+ """,
+ canonical=32,
+)
+
+NO_AUTOFLUSH = util.symbol(
+ "NO_AUTOFLUSH",
+ """Loader callables should disable autoflush.""",
+ canonical=64,
+)
+
+NO_RAISE = util.symbol(
+ "NO_RAISE",
+ """Loader callables should not raise any assertions""",
+ canonical=128,
+)
+
+DEFERRED_HISTORY_LOAD = util.symbol(
+ "DEFERRED_HISTORY_LOAD",
+ """indicates special load of the previous value of an attribute""",
+ canonical=256,
+)
+
+# pre-packaged sets of flags used as inputs
+PASSIVE_OFF = util.symbol(
+ "PASSIVE_OFF",
+ "Callables can be emitted in all cases.",
+ canonical=(
+ RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK
+ ),
+)
+PASSIVE_RETURN_NO_VALUE = util.symbol(
+ "PASSIVE_RETURN_NO_VALUE",
+ """PASSIVE_OFF ^ INIT_OK""",
+ canonical=PASSIVE_OFF ^ INIT_OK,
+)
+PASSIVE_NO_INITIALIZE = util.symbol(
+ "PASSIVE_NO_INITIALIZE",
+ "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK",
+ canonical=PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK,
+)
+PASSIVE_NO_FETCH = util.symbol(
+ "PASSIVE_NO_FETCH", "PASSIVE_OFF ^ SQL_OK", canonical=PASSIVE_OFF ^ SQL_OK
+)
+PASSIVE_NO_FETCH_RELATED = util.symbol(
+ "PASSIVE_NO_FETCH_RELATED",
+ "PASSIVE_OFF ^ RELATED_OBJECT_OK",
+ canonical=PASSIVE_OFF ^ RELATED_OBJECT_OK,
+)
+PASSIVE_ONLY_PERSISTENT = util.symbol(
+ "PASSIVE_ONLY_PERSISTENT",
+ "PASSIVE_OFF ^ NON_PERSISTENT_OK",
+ canonical=PASSIVE_OFF ^ NON_PERSISTENT_OK,
+)
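+
+# Sketch of how these composed flags are meant to be consumed: each
+# pre-packaged symbol removes capability bits from PASSIVE_OFF via XOR, and
+# callers test individual capabilities with "&", e.g.:
+#
+#     bool(PASSIVE_NO_FETCH & SQL_OK)        # False - SQL may not be emitted
+#     bool(PASSIVE_NO_FETCH & CALLABLES_OK)  # True  - callables may still fire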
+
+DEFAULT_MANAGER_ATTR = "_sa_class_manager"
+DEFAULT_STATE_ATTR = "_sa_instance_state"
+
+EXT_CONTINUE = util.symbol("EXT_CONTINUE")
+EXT_STOP = util.symbol("EXT_STOP")
+EXT_SKIP = util.symbol("EXT_SKIP")
+
+ONETOMANY = util.symbol(
+ "ONETOMANY",
+ """Indicates the one-to-many direction for a :func:`_orm.relationship`.
+
+ This symbol is typically used by the internals but may be exposed within
+ certain API features.
+
+ """,
+)
+
+MANYTOONE = util.symbol(
+ "MANYTOONE",
+ """Indicates the many-to-one direction for a :func:`_orm.relationship`.
+
+ This symbol is typically used by the internals but may be exposed within
+ certain API features.
+
+ """,
+)
+
+MANYTOMANY = util.symbol(
+ "MANYTOMANY",
+ """Indicates the many-to-many direction for a :func:`_orm.relationship`.
+
+ This symbol is typically used by the internals but may be exposed within
+ certain API features.
+
+ """,
+)
+
+NOT_EXTENSION = util.symbol(
+ "NOT_EXTENSION",
+ """Symbol indicating an :class:`InspectionAttr` that's
+ not part of sqlalchemy.ext.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ """,
+)
+
+_never_set = frozenset([NEVER_SET])
+
+_none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT])
+
+_SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED")
+
+_DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
+
+_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
+
+
+def _assertions(*assertions):
+ @util.decorator
+ def generate(fn, *args, **kw):
+ self = args[0]
+ for assertion in assertions:
+ assertion(self, fn.__name__)
+ fn(self, *args[1:], **kw)
+
+ return generate
+
+
+# these can be replaced by sqlalchemy.ext.instrumentation
+# if augmented class instrumentation is enabled.
+def manager_of_class(cls):
+ return cls.__dict__.get(DEFAULT_MANAGER_ATTR, None)
+
+
+instance_state = operator.attrgetter(DEFAULT_STATE_ATTR)
+
+instance_dict = operator.attrgetter("__dict__")
+
+
+def instance_str(instance):
+ """Return a string describing an instance."""
+
+ return state_str(instance_state(instance))
+
+
+def state_str(state):
+ """Return a string describing an instance via its InstanceState."""
+
+ if state is None:
+ return "None"
+ else:
+ return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj()))
+
+
+def state_class_str(state):
+ """Return a string describing an instance's class via its
+ InstanceState.
+ """
+
+ if state is None:
+ return "None"
+ else:
+ return "<%s>" % (state.class_.__name__,)
+
+
+def attribute_str(instance, attribute):
+ return instance_str(instance) + "." + attribute
+
+
+def state_attribute_str(state, attribute):
+ return state_str(state) + "." + attribute
+
+
+def object_mapper(instance):
+ """Given an object, return the primary Mapper associated with the object
+ instance.
+
+ Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
+ if no mapping is configured.
+
+ This function is available via the inspection system as::
+
+ inspect(instance).mapper
+
+ Using the inspection system will raise
+ :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
+ not part of a mapping.
+
+ """
+ return object_state(instance).mapper
+
+
+def object_state(instance):
+ """Given an object, return the :class:`.InstanceState`
+ associated with the object.
+
+ Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError`
+ if no mapping is configured.
+
+ Equivalent functionality is available via the :func:`_sa.inspect`
+ function as::
+
+ inspect(instance)
+
+ Using the inspection system will raise
+ :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is
+ not part of a mapping.
+
+ """
+ state = _inspect_mapped_object(instance)
+ if state is None:
+ raise exc.UnmappedInstanceError(instance)
+ else:
+ return state
+
+
+@inspection._inspects(object)
+def _inspect_mapped_object(instance):
+ try:
+ return instance_state(instance)
+ except (exc.UnmappedClassError,) + exc.NO_STATE:
+ return None
+
+
+def _class_to_mapper(class_or_mapper):
+ insp = inspection.inspect(class_or_mapper, False)
+ if insp is not None:
+ return insp.mapper
+ else:
+ raise exc.UnmappedClassError(class_or_mapper)
+
+
+def _mapper_or_none(entity):
+ """Return the :class:`_orm.Mapper` for the given class or None if the
+ class is not mapped.
+ """
+
+ insp = inspection.inspect(entity, False)
+ if insp is not None:
+ return insp.mapper
+ else:
+ return None
+
+
+def _is_mapped_class(entity):
+ """Return True if the given object is a mapped class,
+ :class:`_orm.Mapper`, or :class:`.AliasedClass`.
+ """
+
+ insp = inspection.inspect(entity, False)
+ return (
+ insp is not None
+ and not insp.is_clause_element
+ and (insp.is_mapper or insp.is_aliased_class)
+ )
+
+
+def _orm_columns(entity):
+ insp = inspection.inspect(entity, False)
+ if hasattr(insp, "selectable") and hasattr(insp.selectable, "c"):
+ return [c for c in insp.selectable.c]
+ else:
+ return [entity]
+
+
+def _is_aliased_class(entity):
+ insp = inspection.inspect(entity, False)
+ return insp is not None and getattr(insp, "is_aliased_class", False)
+
+
+def _entity_descriptor(entity, key):
+ """Return a class attribute given an entity and string name.
+
+ May return :class:`.InstrumentedAttribute` or user-defined
+ attribute.
+
+ """
+ insp = inspection.inspect(entity)
+ if insp.is_selectable:
+ description = entity
+ entity = insp.c
+ elif insp.is_aliased_class:
+ entity = insp.entity
+ description = entity
+ elif hasattr(insp, "mapper"):
+ description = entity = insp.mapper.class_
+ else:
+ description = entity
+
+ try:
+ return getattr(entity, key)
+ except AttributeError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Entity '%s' has no property '%s'" % (description, key)
+ ),
+ replace_context=err,
+ )
+
+
+_state_mapper = util.dottedgetter("manager.mapper")
+
+
+@inspection._inspects(type)
+def _inspect_mapped_class(class_, configure=False):
+ try:
+ class_manager = manager_of_class(class_)
+ if not class_manager.is_mapped:
+ return None
+ mapper = class_manager.mapper
+ except exc.NO_STATE:
+ return None
+ else:
+ if configure:
+ mapper._check_configure()
+ return mapper
+
+
+def class_mapper(class_, configure=True):
+ """Given a class, return the primary :class:`_orm.Mapper` associated
+ with the key.
+
+ Raises :exc:`.UnmappedClassError` if no mapping is configured
+ on the given class, or :exc:`.ArgumentError` if a non-class
+ object is passed.
+
+ Equivalent functionality is available via the :func:`_sa.inspect`
+ function as::
+
+ inspect(some_mapped_class)
+
+ Using the inspection system will raise
+ :class:`sqlalchemy.exc.NoInspectionAvailable` if the class is not mapped.
+
+ """
+ mapper = _inspect_mapped_class(class_, configure=configure)
+ if mapper is None:
+ if not isinstance(class_, type):
+ raise sa_exc.ArgumentError(
+ "Class object expected, got '%r'." % (class_,)
+ )
+ raise exc.UnmappedClassError(class_)
+ else:
+ return mapper
+
+
+class InspectionAttr(object):
+ """A base class applied to all ORM objects that can be returned
+ by the :func:`_sa.inspect` function.
+
+ The attributes defined here allow the usage of simple boolean
+ checks to test basic facts about the object returned.
+
+ While the boolean checks here are basically the same as using
+ the Python isinstance() function, the flags here can be used without
+ the need to import all of these classes, and also such that
+ the SQLAlchemy class system can change while leaving the flags
+ here intact for forwards-compatibility.
+
+ """
+
+ __slots__ = ()
+
+ is_selectable = False
+ """Return True if this object is an instance of
+ :class:`_expression.Selectable`."""
+
+ is_aliased_class = False
+ """True if this object is an instance of :class:`.AliasedClass`."""
+
+ is_instance = False
+ """True if this object is an instance of :class:`.InstanceState`."""
+
+ is_mapper = False
+ """True if this object is an instance of :class:`_orm.Mapper`."""
+
+ is_bundle = False
+ """True if this object is an instance of :class:`.Bundle`."""
+
+ is_property = False
+ """True if this object is an instance of :class:`.MapperProperty`."""
+
+ is_attribute = False
+ """True if this object is a Python :term:`descriptor`.
+
+ This can refer to one of many types. Usually a
+ :class:`.QueryableAttribute` which handles attributes events on behalf
+ of a :class:`.MapperProperty`. But can also be an extension type
+ such as :class:`.AssociationProxy` or :class:`.hybrid_property`.
+ The :attr:`.InspectionAttr.extension_type` will refer to a constant
+ identifying the specific subtype.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.all_orm_descriptors`
+
+ """
+
+ _is_internal_proxy = False
+ """True if this object is an internal proxy object.
+
+ .. versionadded:: 1.2.12
+
+ """
+
+ is_clause_element = False
+ """True if this object is an instance of
+ :class:`_expression.ClauseElement`."""
+
+ extension_type = NOT_EXTENSION
+ """The extension type, if any.
+ Defaults to :data:`.interfaces.NOT_EXTENSION`
+
+ .. seealso::
+
+ :data:`.HYBRID_METHOD`
+
+ :data:`.HYBRID_PROPERTY`
+
+ :data:`.ASSOCIATION_PROXY`
+
+ """
+
+
+class InspectionAttrInfo(InspectionAttr):
+ """Adds the ``.info`` attribute to :class:`.InspectionAttr`.
+
+ The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo`
+ is that the former is compatible as a mixin for classes that specify
+ ``__slots__``; this is essentially an implementation artifact.
+
+ """
+
+ @util.memoized_property
+ def info(self):
+ """Info dictionary associated with the object, allowing user-defined
+ data to be associated with this :class:`.InspectionAttr`.
+
+ The dictionary is generated when first accessed. Alternatively,
+ it can be specified as a constructor argument to the
+ :func:`.column_property`, :func:`_orm.relationship`, or
+ :func:`.composite`
+ functions.
+
+ .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also
+ available on extension types via the
+ :attr:`.InspectionAttrInfo.info` attribute, so that it can apply
+ to a wider variety of ORM and extension constructs.
+
+ .. seealso::
+
+ :attr:`.QueryableAttribute.info`
+
+ :attr:`.SchemaItem.info`
+
+ """
+ return {}
+
+
+class _MappedAttribute(object):
+ """Mixin for attributes which should be replaced by mapper-assigned
+ attributes.
+
+ """
+
+ __slots__ = ()
diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py
new file mode 100644
index 0000000..2c21498
--- /dev/null
+++ b/lib/sqlalchemy/orm/clsregistry.py
@@ -0,0 +1,441 @@
+# ext/declarative/clsregistry.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Routines to handle the string class registry used by declarative.
+
+This system allows specification of classes and expressions used in
+:func:`_orm.relationship` using strings.
+
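+For example (a sketch; ``User`` and ``Address`` are hypothetical declarative
+classes), the strings below are resolved against this registry at mapper
+configuration time::
+
+    class User(Base):
+        __tablename__ = "user"
+        id = Column(Integer, primary_key=True)
+        addresses = relationship("Address", order_by="Address.id")
+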
+"""
+import weakref
+
+from . import attributes
+from . import interfaces
+from .descriptor_props import SynonymProperty
+from .properties import ColumnProperty
+from .util import class_mapper
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql.schema import _get_table_key
+
+# strong references to registries which we place in
+# the _decl_class_registry, which is usually weak referencing.
+# the internal registries here link to classes with weakrefs and remove
+# themselves when all references to contained classes are removed.
+_registries = set()
+
+
+def add_class(classname, cls, decl_class_registry):
+ """Add a class to the _decl_class_registry associated with the
+ given declarative class.
+
+ """
+ if classname in decl_class_registry:
+ # class already exists.
+ existing = decl_class_registry[classname]
+ if not isinstance(existing, _MultipleClassMarker):
+ existing = decl_class_registry[classname] = _MultipleClassMarker(
+ [cls, existing]
+ )
+ else:
+ decl_class_registry[classname] = cls
+
+ try:
+ root_module = decl_class_registry["_sa_module_registry"]
+ except KeyError:
+ decl_class_registry[
+ "_sa_module_registry"
+ ] = root_module = _ModuleMarker("_sa_module_registry", None)
+
+ tokens = cls.__module__.split(".")
+
+ # build up a tree like this:
+ # modulename: myapp.snacks.nuts
+ #
+ # myapp->snack->nuts->(classes)
+ # snack->nuts->(classes)
+ # nuts->(classes)
+ #
+ # this allows partial token paths to be used.
+ while tokens:
+ token = tokens.pop(0)
+ module = root_module.get_module(token)
+ for token in tokens:
+ module = module.get_module(token)
+ module.add_class(classname, cls)
+
+
+def remove_class(classname, cls, decl_class_registry):
+ if classname in decl_class_registry:
+ existing = decl_class_registry[classname]
+ if isinstance(existing, _MultipleClassMarker):
+ existing.remove_item(cls)
+ else:
+ del decl_class_registry[classname]
+
+ try:
+ root_module = decl_class_registry["_sa_module_registry"]
+ except KeyError:
+ return
+
+ tokens = cls.__module__.split(".")
+
+ while tokens:
+ token = tokens.pop(0)
+ module = root_module.get_module(token)
+ for token in tokens:
+ module = module.get_module(token)
+ module.remove_class(classname, cls)
+
+
+def _key_is_empty(key, decl_class_registry, test):
+ """test if a key is empty of a certain object.
+
+ used for unit tests against the registry to see if garbage collection
+ is working.
+
+ "test" is a callable that will be passed an object should return True
+ if the given object is the one we were looking for.
+
+ We can't pass the actual object itself b.c. this is for testing garbage
+ collection; the caller will have to have removed references to the
+ object itself.
+
+ """
+ if key not in decl_class_registry:
+ return True
+
+ thing = decl_class_registry[key]
+ if isinstance(thing, _MultipleClassMarker):
+ for sub_thing in thing.contents:
+ if test(sub_thing):
+ return False
+ else:
+ return not test(thing)
+
+
+class _MultipleClassMarker(object):
+ """refers to multiple classes of the same name
+ within _decl_class_registry.
+
+ """
+
+ __slots__ = "on_remove", "contents", "__weakref__"
+
+ def __init__(self, classes, on_remove=None):
+ self.on_remove = on_remove
+ self.contents = set(
+ [weakref.ref(item, self._remove_item) for item in classes]
+ )
+ _registries.add(self)
+
+ def remove_item(self, cls):
+ self._remove_item(weakref.ref(cls))
+
+ def __iter__(self):
+ return (ref() for ref in self.contents)
+
+ def attempt_get(self, path, key):
+ if len(self.contents) > 1:
+ raise exc.InvalidRequestError(
+ 'Multiple classes found for path "%s" '
+ "in the registry of this declarative "
+ "base. Please use a fully module-qualified path."
+ % (".".join(path + [key]))
+ )
+ else:
+ ref = list(self.contents)[0]
+ cls = ref()
+ if cls is None:
+ raise NameError(key)
+ return cls
+
+ def _remove_item(self, ref):
+ self.contents.discard(ref)
+ if not self.contents:
+ _registries.discard(self)
+ if self.on_remove:
+ self.on_remove()
+
+ def add_item(self, item):
+ # protect against class registration race condition against
+ # asynchronous garbage collection calling _remove_item,
+ # [ticket:3208]
+ modules = set(
+ [
+ cls.__module__
+ for cls in [ref() for ref in self.contents]
+ if cls is not None
+ ]
+ )
+ if item.__module__ in modules:
+ util.warn(
+ "This declarative base already contains a class with the "
+ "same class name and module name as %s.%s, and will "
+ "be replaced in the string-lookup table."
+ % (item.__module__, item.__name__)
+ )
+ self.contents.add(weakref.ref(item, self._remove_item))
+
+
+class _ModuleMarker(object):
+ """Refers to a module name within
+ _decl_class_registry.
+
+ """
+
+ __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
+
+ def __init__(self, name, parent):
+ self.parent = parent
+ self.name = name
+ self.contents = {}
+ self.mod_ns = _ModNS(self)
+ if self.parent:
+ self.path = self.parent.path + [self.name]
+ else:
+ self.path = []
+ _registries.add(self)
+
+ def __contains__(self, name):
+ return name in self.contents
+
+ def __getitem__(self, name):
+ return self.contents[name]
+
+ def _remove_item(self, name):
+ self.contents.pop(name, None)
+ if not self.contents and self.parent is not None:
+ self.parent._remove_item(self.name)
+ _registries.discard(self)
+
+ def resolve_attr(self, key):
+ return getattr(self.mod_ns, key)
+
+ def get_module(self, name):
+ if name not in self.contents:
+ marker = _ModuleMarker(name, self)
+ self.contents[name] = marker
+ else:
+ marker = self.contents[name]
+ return marker
+
+ def add_class(self, name, cls):
+ if name in self.contents:
+ existing = self.contents[name]
+ existing.add_item(cls)
+ else:
+ existing = self.contents[name] = _MultipleClassMarker(
+ [cls], on_remove=lambda: self._remove_item(name)
+ )
+
+ def remove_class(self, name, cls):
+ if name in self.contents:
+ existing = self.contents[name]
+ existing.remove_item(cls)
+
+
+class _ModNS(object):
+ __slots__ = ("__parent",)
+
+ def __init__(self, parent):
+ self.__parent = parent
+
+ def __getattr__(self, key):
+ try:
+ value = self.__parent.contents[key]
+ except KeyError:
+ pass
+ else:
+ if value is not None:
+ if isinstance(value, _ModuleMarker):
+ return value.mod_ns
+ else:
+ assert isinstance(value, _MultipleClassMarker)
+ return value.attempt_get(self.__parent.path, key)
+ raise NameError(
+ "Module %r has no mapped classes "
+ "registered under the name %r" % (self.__parent.name, key)
+ )
+
+
+class _GetColumns(object):
+ __slots__ = ("cls",)
+
+ def __init__(self, cls):
+ self.cls = cls
+
+ def __getattr__(self, key):
+ mp = class_mapper(self.cls, configure=False)
+ if mp:
+ if key not in mp.all_orm_descriptors:
+ raise AttributeError(
+ "Class %r does not have a mapped column named %r"
+ % (self.cls, key)
+ )
+
+ desc = mp.all_orm_descriptors[key]
+ if desc.extension_type is interfaces.NOT_EXTENSION:
+ prop = desc.property
+ if isinstance(prop, SynonymProperty):
+ key = prop.name
+ elif not isinstance(prop, ColumnProperty):
+ raise exc.InvalidRequestError(
+ "Property %r is not an instance of"
+ " ColumnProperty (i.e. does not correspond"
+ " directly to a Column)." % key
+ )
+ return getattr(self.cls, key)
+
+
+inspection._inspects(_GetColumns)(
+ lambda target: inspection.inspect(target.cls)
+)
+
+
+class _GetTable(object):
+ __slots__ = "key", "metadata"
+
+ def __init__(self, key, metadata):
+ self.key = key
+ self.metadata = metadata
+
+ def __getattr__(self, key):
+ return self.metadata.tables[_get_table_key(key, self.key)]
+
+
+def _determine_container(key, value):
+ if isinstance(value, _MultipleClassMarker):
+ value = value.attempt_get([], key)
+ return _GetColumns(value)
+
+
+class _class_resolver(object):
+ __slots__ = (
+ "cls",
+ "prop",
+ "arg",
+ "fallback",
+ "_dict",
+ "_resolvers",
+ "favor_tables",
+ )
+
+ def __init__(self, cls, prop, fallback, arg, favor_tables=False):
+ self.cls = cls
+ self.prop = prop
+ self.arg = arg
+ self.fallback = fallback
+ self._dict = util.PopulateDict(self._access_cls)
+ self._resolvers = ()
+ self.favor_tables = favor_tables
+
+ def _access_cls(self, key):
+ cls = self.cls
+
+ manager = attributes.manager_of_class(cls)
+ decl_base = manager.registry
+ decl_class_registry = decl_base._class_registry
+ metadata = decl_base.metadata
+
+ if self.favor_tables:
+ if key in metadata.tables:
+ return metadata.tables[key]
+ elif key in metadata._schemas:
+ return _GetTable(key, cls.metadata)
+
+ if key in decl_class_registry:
+ return _determine_container(key, decl_class_registry[key])
+
+ if not self.favor_tables:
+ if key in metadata.tables:
+ return metadata.tables[key]
+ elif key in metadata._schemas:
+ return _GetTable(key, cls.metadata)
+
+ if (
+ "_sa_module_registry" in decl_class_registry
+ and key in decl_class_registry["_sa_module_registry"]
+ ):
+ registry = decl_class_registry["_sa_module_registry"]
+ return registry.resolve_attr(key)
+ elif self._resolvers:
+ for resolv in self._resolvers:
+ value = resolv(key)
+ if value is not None:
+ return value
+
+ return self.fallback[key]
+
+ def _raise_for_name(self, name, err):
+ util.raise_(
+ exc.InvalidRequestError(
+ "When initializing mapper %s, expression %r failed to "
+ "locate a name (%r). If this is a class name, consider "
+ "adding this relationship() to the %r class after "
+ "both dependent classes have been defined."
+ % (self.prop.parent, self.arg, name, self.cls)
+ ),
+ from_=err,
+ )
+
+ def _resolve_name(self):
+ name = self.arg
+ d = self._dict
+ rval = None
+ try:
+ for token in name.split("."):
+ if rval is None:
+ rval = d[token]
+ else:
+ rval = getattr(rval, token)
+ except KeyError as err:
+ self._raise_for_name(name, err)
+ except NameError as n:
+ self._raise_for_name(n.args[0], n)
+ else:
+ if isinstance(rval, _GetColumns):
+ return rval.cls
+ else:
+ return rval
+
+ def __call__(self):
+ try:
+ x = eval(self.arg, globals(), self._dict)
+
+ if isinstance(x, _GetColumns):
+ return x.cls
+ else:
+ return x
+ except NameError as n:
+ self._raise_for_name(n.args[0], n)
+
+
+_fallback_dict = None
+
+
+def _resolver(cls, prop):
+
+ global _fallback_dict
+
+ if _fallback_dict is None:
+ import sqlalchemy
+ from sqlalchemy.orm import foreign, remote
+
+ _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
+ {"foreign": foreign, "remote": remote}
+ )
+
+ def resolve_arg(arg, favor_tables=False):
+ return _class_resolver(
+ cls, prop, _fallback_dict, arg, favor_tables=favor_tables
+ )
+
+ def resolve_name(arg):
+ return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
+
+ return resolve_name, resolve_arg
diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py
new file mode 100644
index 0000000..a189f02
--- /dev/null
+++ b/lib/sqlalchemy/orm/collections.py
@@ -0,0 +1,1706 @@
+# orm/collections.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Support for collections of mapped entities.
+
+The collections package supplies the machinery used to inform the ORM of
+collection membership changes. An instrumentation via decoration approach is
+used, allowing arbitrary types (including built-ins) to be used as entity
+collections without requiring inheritance from a base class.
+
+Instrumentation decoration relays membership change events to the
+:class:`.CollectionAttributeImpl` that is currently managing the collection.
+The decorators observe function call arguments and return values, tracking
+entities entering or leaving the collection. Two decorator approaches are
+provided. One is a bundle of generic decorators that map function arguments
+and return values to events::
+
+ from sqlalchemy.orm.collections import collection
+ class MyClass(object):
+ # ...
+
+ @collection.adds(1)
+ def store(self, item):
+ self.data.append(item)
+
+ @collection.removes_return()
+ def pop(self):
+ return self.data.pop()
+
+
+The second approach is a bundle of targeted decorators that wrap appropriate
+append and remove notifiers around the mutation methods present in the
+standard Python ``list``, ``set`` and ``dict`` interfaces. These could be
+specified in terms of generic decorator recipes, but are instead hand-tooled
+for increased efficiency. The targeted decorators occasionally implement
+adapter-like behavior, such as mapping bulk-set methods (``extend``,
+``update``, ``__setslice__``, etc.) into the series of atomic mutation events
+that the ORM requires.
+
+The targeted decorators are used internally for automatic instrumentation of
+entity collection classes. Every collection class goes through a
+transformation process roughly like so:
+
+1. If the class is a built-in, substitute a trivial sub-class
+2. Is this class already instrumented?
+3. Add in generic decorators
+4. Sniff out the collection interface through duck-typing
+5. Add targeted decoration to any undecorated interface method
+
+This process modifies the class at runtime, decorating methods and adding some
+bookkeeping properties. This isn't possible (or desirable) for built-in
+classes like ``list``, so trivial sub-classes are substituted to hold
+decoration::
+
+ class InstrumentedList(list):
+ pass
+
+Collection classes can be specified in ``relationship(collection_class=)`` as
+types or a function that returns an instance. Collection classes are
+inspected and instrumented during the mapper compilation phase. The
+collection_class callable will be executed once to produce a specimen
+instance, and the type of that specimen will be instrumented. Functions that
+return built-in types like ``lists`` will be adapted to produce instrumented
+instances.
+
+When extending a known type like ``list``, additional decorations are
+generally not needed. Odds are, the extension method will delegate to a
+method that's already instrumented. For example::
+
+ class QueueIsh(list):
+ def push(self, item):
+ self.append(item)
+ def shift(self):
+ return self.pop(0)
+
+There's no need to decorate these methods. ``append`` and ``pop`` are already
+instrumented as part of the ``list`` interface. Decorating them would fire
+duplicate events, which should be avoided.
+
+The targeted decoration tries not to rely on other methods in the underlying
+collection class, but some are unavoidable. Many depend on 'read' methods
+being present to properly instrument a 'write', for example, ``__setitem__``
+needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also
+be reimplemented in terms of atomic appends and removes, so the ``extend``
+decoration will actually perform many ``append`` operations and not call the
+underlying method at all.
+
+Tight control over bulk operation and the firing of events is also possible by
+implementing the instrumentation internally in your methods. The basic
+instrumentation package works under the general assumption that collection
+mutation will not raise unusual exceptions. If you want to closely
+orchestrate append and remove events with exception management, internal
+instrumentation may be the answer. Within your method,
+``collection_adapter(self)`` will retrieve an object that you can use for
+explicit control over triggering append and remove events.
+
+The owning object and :class:`.CollectionAttributeImpl` are also reachable
+through the adapter, allowing for some very sophisticated behavior.
+
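+For example, a minimal sketch of an internally instrumented method driving
+events through ``collection_adapter`` (the ``MyList`` class and its
+``replace`` method are hypothetical)::
+
+    from sqlalchemy.orm.collections import collection, collection_adapter
+
+    class MyList(list):
+        @collection.internally_instrumented
+        def replace(self, index, item, _sa_initiator=None):
+            adapter = collection_adapter(self)
+            item = adapter.fire_append_event(item, _sa_initiator)
+            adapter.fire_remove_event(self[index], _sa_initiator)
+            list.__setitem__(self, index, item)
+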
+"""
+
+import operator
+import weakref
+
+from sqlalchemy.util.compat import inspect_getfullargspec
+from . import base
+from .. import exc as sa_exc
+from .. import util
+from ..sql import coercions
+from ..sql import expression
+from ..sql import roles
+
+__all__ = [
+ "collection",
+ "collection_adapter",
+ "mapped_collection",
+ "column_mapped_collection",
+ "attribute_mapped_collection",
+]
+
+__instrumentation_mutex = util.threading.Lock()
+
+
+class _PlainColumnGetter(object):
+ """Plain column getter, stores collection of Column objects
+ directly.
+
+ Serializes to a :class:`._SerializableColumnGetterV2`
+ which has more expensive __call__() performance
+ and some rare caveats.
+
+ """
+
+ def __init__(self, cols):
+ self.cols = cols
+ self.composite = len(cols) > 1
+
+ def __reduce__(self):
+ return _SerializableColumnGetterV2._reduce_from_cols(self.cols)
+
+ def _cols(self, mapper):
+ return self.cols
+
+ def __call__(self, value):
+ state = base.instance_state(value)
+ m = base._state_mapper(state)
+
+ key = [
+ m._get_state_attr_by_column(state, state.dict, col)
+ for col in self._cols(m)
+ ]
+
+ if self.composite:
+ return tuple(key)
+ else:
+ return key[0]
+
+
+class _SerializableColumnGetter(object):
+ """Column-based getter used in version 0.7.6 only.
+
+ Remains here for pickle compatibility with 0.7.6.
+
+ """
+
+ def __init__(self, colkeys):
+ self.colkeys = colkeys
+ self.composite = len(colkeys) > 1
+
+ def __reduce__(self):
+ return _SerializableColumnGetter, (self.colkeys,)
+
+ def __call__(self, value):
+ state = base.instance_state(value)
+ m = base._state_mapper(state)
+ key = [
+ m._get_state_attr_by_column(
+ state, state.dict, m.mapped_table.columns[k]
+ )
+ for k in self.colkeys
+ ]
+ if self.composite:
+ return tuple(key)
+ else:
+ return key[0]
+
+
+class _SerializableColumnGetterV2(_PlainColumnGetter):
+ """Updated serializable getter which deals with
+ multi-table mapped classes.
+
+ Two extremely unusual cases are not supported.
+ Mappings which have tables across multiple metadata
+ objects, or which are mapped to non-Table selectables
+ linked across inheriting mappers may fail to function
+ here.
+
+ """
+
+ def __init__(self, colkeys):
+ self.colkeys = colkeys
+ self.composite = len(colkeys) > 1
+
+ def __reduce__(self):
+ return self.__class__, (self.colkeys,)
+
+ @classmethod
+ def _reduce_from_cols(cls, cols):
+ def _table_key(c):
+ if not isinstance(c.table, expression.TableClause):
+ return None
+ else:
+ return c.table.key
+
+ colkeys = [(c.key, _table_key(c)) for c in cols]
+ return _SerializableColumnGetterV2, (colkeys,)
+
+ def _cols(self, mapper):
+ cols = []
+ metadata = getattr(mapper.local_table, "metadata", None)
+ for (ckey, tkey) in self.colkeys:
+ if tkey is None or metadata is None or tkey not in metadata:
+ cols.append(mapper.local_table.c[ckey])
+ else:
+ cols.append(metadata.tables[tkey].c[ckey])
+ return cols
+
+
+def column_mapped_collection(mapping_spec):
+ """A dictionary-based collection type with column-based keying.
+
+ Returns a :class:`.MappedCollection` factory with a keying function
+ generated from mapping_spec, which may be a Column or a sequence
+ of Columns.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
+
+ """
+ cols = [
+ coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec")
+ for q in util.to_list(mapping_spec)
+ ]
+ keyfunc = _PlainColumnGetter(cols)
+ return lambda: MappedCollection(keyfunc)
+
+
+class _SerializableAttrGetter(object):
+ def __init__(self, name):
+ self.name = name
+ self.getter = operator.attrgetter(name)
+
+ def __call__(self, target):
+ return self.getter(target)
+
+ def __reduce__(self):
+ return _SerializableAttrGetter, (self.name,)
+
+
+def attribute_mapped_collection(attr_name):
+ """A dictionary-based collection type with attribute-based keying.
+
+ Returns a :class:`.MappedCollection` factory with a keying based on the
+ 'attr_name' attribute of entities in the collection, where ``attr_name``
+ is the string name of the attribute.
+
+ .. warning:: the key value must be assigned to its final value
+ **before** it is accessed by the attribute mapped collection.
+ Additionally, changes to the key attribute are **not tracked**
+ automatically, which means the key in the dictionary is not
+ automatically synchronized with the key value on the target object
+ itself. See the section :ref:`key_collections_mutations`
+ for an example.
+
+ """
+ getter = _SerializableAttrGetter(attr_name)
+ return lambda: MappedCollection(getter)
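+
+# Usage sketch (``Note`` is a hypothetical mapped class with a ``keyword``
+# attribute):
+#
+#     notes = relationship(
+#         "Note", collection_class=attribute_mapped_collection("keyword")
+#     )
+#     parent.notes["some-keyword"]   # dict-like access keyed on Note.keyword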
+
+
+def mapped_collection(keyfunc):
+ """A dictionary-based collection type with arbitrary keying.
+
+ Returns a :class:`.MappedCollection` factory with a keying function
+ generated from keyfunc, a callable that takes an entity and returns a
+ key value.
+
+ The key value must be immutable for the lifetime of the object. You
+ can not, for example, map on foreign key values if those key values will
+ change during the session, i.e. from None to a database-assigned integer
+ after a session flush.
+
+ """
+ return lambda: MappedCollection(keyfunc)
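+
+# Usage sketch: any callable taking the entity and returning an immutable
+# key may serve as keyfunc (``Note`` is hypothetical):
+#
+#     notes = relationship(
+#         "Note",
+#         collection_class=mapped_collection(lambda note: note.text[:10]),
+#     )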
+
+
+class collection(object):
+ """Decorators for entity collection classes.
+
+ The decorators fall into two groups: annotations and interception recipes.
+
+ The annotating decorators (appender, remover, iterator, converter,
+ internally_instrumented) indicate the method's purpose and take no
+ arguments. They are not written with parens::
+
+ @collection.appender
+ def append(self, append): ...
+
+ The recipe decorators all require parens, even those that take no
+ arguments::
+
+ @collection.adds('entity')
+ def insert(self, position, entity): ...
+
+ @collection.removes_return()
+ def popitem(self): ...
+
+ """
+
+ # Bundled as a class solely for ease of use: packaging, doc strings,
+ # importability.
+
+ @staticmethod
+ def appender(fn):
+ """Tag the method as the collection appender.
+
+ The appender method is called with one positional argument: the value
+ to append. The method will be automatically decorated with 'adds(1)'
+ if not already decorated::
+
+ @collection.appender
+ def add(self, append): ...
+
+ # or, equivalently
+ @collection.appender
+ @collection.adds(1)
+ def add(self, append): ...
+
+ # for mapping type, an 'append' may kick out a previous value
+ # that occupies that slot. consider d['a'] = 'foo'- any previous
+ # value in d['a'] is discarded.
+ @collection.appender
+ @collection.replaces(1)
+ def add(self, entity):
+ key = some_key_func(entity)
+ previous = None
+ if key in self:
+ previous = self[key]
+ self[key] = entity
+ return previous
+
+ If the value to append is not allowed in the collection, you may
+ raise an exception. Something to remember is that the appender
+ will be called for each object mapped by a database query. If the
+ database contains rows that violate your collection semantics, you
+ will need to get creative to fix the problem, as access via the
+ collection will not work.
+
+ If the appender method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+
+ """
+ fn._sa_instrument_role = "appender"
+ return fn
+
+ @staticmethod
+ def remover(fn):
+ """Tag the method as the collection remover.
+
+ The remover method is called with one positional argument: the value
+ to remove. The method will be automatically decorated with
+ :meth:`removes_return` if not already decorated::
+
+ @collection.remover
+ def zap(self, entity): ...
+
+ # or, equivalently
+ @collection.remover
+ @collection.removes_return()
+ def zap(self, entity): ...
+
+ If the value to remove is not present in the collection, you may
+ raise an exception or return None to ignore the error.
+
+ If the remove method is internally instrumented, you must also
+ receive the keyword argument '_sa_initiator' and ensure its
+ promulgation to collection events.
+
+ """
+ fn._sa_instrument_role = "remover"
+ return fn
+
+ @staticmethod
+ def iterator(fn):
+ """Tag the method as the collection remover.
+
+ The iterator method is called with no arguments. It is expected to
+ return an iterator over all collection members::
+
+ @collection.iterator
+ def __iter__(self): ...
+
+ """
+ fn._sa_instrument_role = "iterator"
+ return fn
+
+ @staticmethod
+ def internally_instrumented(fn):
+ """Tag the method as instrumented.
+
+ This tag will prevent any decoration from being applied to the
+ method. Use this if you are orchestrating your own calls to
+ :func:`.collection_adapter` in one of the basic SQLAlchemy
+ interface methods, or to prevent an automatic ABC method
+ decoration from wrapping your implementation::
+
+ # normally an 'extend' method on a list-like class would be
+ # automatically intercepted and re-implemented in terms of
+ # SQLAlchemy events and append(). your implementation will
+ # never be called, unless:
+ @collection.internally_instrumented
+ def extend(self, items): ...
+
+ """
+ fn._sa_instrumented = True
+ return fn
+
+ @staticmethod
+ @util.deprecated(
+ "1.3",
+ "The :meth:`.collection.converter` handler is deprecated and will "
+ "be removed in a future release. Please refer to the "
+ ":class:`.AttributeEvents.bulk_replace` listener interface in "
+ "conjunction with the :func:`.event.listen` function.",
+ )
+ def converter(fn):
+ """Tag the method as the collection converter.
+
+ This optional method will be called when a collection is being
+ replaced entirely, as in::
+
+ myobj.acollection = [newvalue1, newvalue2]
+
+ The converter method will receive the object being assigned and should
+ return an iterable of values suitable for use by the ``appender``
+ method. A converter must not assign values or mutate the collection;
+ its sole job is to adapt the value the user provides into an iterable
+ of values for the ORM's use.
+
+ The default converter implementation will use duck-typing to do the
+ conversion. A dict-like collection will be converted into an iterable
+ of dictionary values, and other types will simply be iterated::
+
+ @collection.converter
+ def convert(self, other): ...
+
+ If the duck-typing of the object does not match the type of this
+ collection, a TypeError is raised.
+
+ Supply an implementation of this method if you want to expand the
+ range of possible types that can be assigned in bulk or perform
+ validation on the values about to be assigned.
+
+ """
+ fn._sa_instrument_role = "converter"
+ return fn
+
+ @staticmethod
+ def adds(arg):
+ """Mark the method as adding an entity to the collection.
+
+ Adds "add to collection" handling to the method. The decorator
+ argument indicates which method argument holds the SQLAlchemy-relevant
+ value. Arguments can be specified positionally (i.e. integer) or by
+ name::
+
+ @collection.adds(1)
+ def push(self, item): ...
+
+ @collection.adds('entity')
+ def do_stuff(self, thing, entity=None): ...
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_before = ("fire_append_event", arg)
+ return fn
+
+ return decorator
+
+ @staticmethod
+ def replaces(arg):
+ """Mark the method as replacing an entity in the collection.
+
+ Adds "add to collection" and "remove from collection" handling to
+ the method. The decorator argument indicates which method argument
+ holds the SQLAlchemy-relevant value to be added, and return value, if
+ any will be considered the value to remove.
+
+ Arguments can be specified positionally (i.e. integer) or by name::
+
+ @collection.replaces(2)
+ def __setitem__(self, index, item): ...
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_before = ("fire_append_event", arg)
+ fn._sa_instrument_after = "fire_remove_event"
+ return fn
+
+ return decorator
+
+ @staticmethod
+ def removes(arg):
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The decorator
+ argument indicates which method argument holds the SQLAlchemy-relevant
+ value to be removed. Arguments can be specified positionally (i.e.
+ integer) or by name::
+
+ @collection.removes(1)
+ def zap(self, item): ...
+
+ For methods where the value to remove is not known at call-time, use
+ collection.removes_return.
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_before = ("fire_remove_event", arg)
+ return fn
+
+ return decorator
+
+ @staticmethod
+ def removes_return():
+ """Mark the method as removing an entity in the collection.
+
+ Adds "remove from collection" handling to the method. The return
+ value of the method, if any, is considered the value to remove. The
+ method arguments are not inspected::
+
+ @collection.removes_return()
+ def pop(self): ...
+
+ For methods where the value to remove is known at call-time, use
+ collection.removes.
+
+ """
+
+ def decorator(fn):
+ fn._sa_instrument_after = "fire_remove_event"
+ return fn
+
+ return decorator
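+
+# A minimal sketch of a custom collection class using the decorators above
+# (hypothetical names; the class only becomes instrumented once used as a
+# relationship's collection_class).
+class _SketchOrderedCollection(object):
+    def __init__(self):
+        self._members = []
+
+    @collection.appender
+    def append(self, member):
+        self._members.append(member)
+
+    @collection.remover
+    def remove(self, member):
+        self._members.remove(member)
+
+    @collection.iterator
+    def __iter__(self):
+        return iter(self._members)
+
+    @collection.adds("member")
+    def insert_first(self, member):
+        self._members.insert(0, member)
+
+    @collection.removes_return()
+    def pop_first(self):
+        return self._members.pop(0)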
+
+
+collection_adapter = operator.attrgetter("_sa_adapter")
+"""Fetch the :class:`.CollectionAdapter` for a collection."""
+
+
+class CollectionAdapter(object):
+ """Bridges between the ORM and arbitrary Python collections.
+
+ Proxies base-level collection operations (append, remove, iterate)
+ to the underlying Python collection, and emits add/remove events for
+ entities entering or leaving the collection.
+
+ The ORM uses :class:`.CollectionAdapter` exclusively for interaction with
+ entity collections.
+
+
+ """
+
+ __slots__ = (
+ "attr",
+ "_key",
+ "_data",
+ "owner_state",
+ "_converter",
+ "invalidated",
+ "empty",
+ )
+
+ def __init__(self, attr, owner_state, data):
+ self.attr = attr
+ self._key = attr.key
+ self._data = weakref.ref(data)
+ self.owner_state = owner_state
+ data._sa_adapter = self
+ self._converter = data._sa_converter
+ self.invalidated = False
+ self.empty = False
+
+ def _warn_invalidated(self):
+ util.warn("This collection has been invalidated.")
+
+ @property
+ def data(self):
+ "The entity collection being adapted."
+ return self._data()
+
+ @property
+ def _referenced_by_owner(self):
+ """return True if the owner state still refers to this collection.
+
+ This will return False within a bulk replace operation,
+ where this collection is the one being replaced.
+
+ """
+ return self.owner_state.dict[self._key] is self._data()
+
+ def bulk_appender(self):
+ return self._data()._sa_appender
+
+ def append_with_event(self, item, initiator=None):
+ """Add an entity to the collection, firing mutation events."""
+
+ self._data()._sa_appender(item, _sa_initiator=initiator)
+
+ def _set_empty(self, user_data):
+ assert (
+ not self.empty
+ ), "This collection adapter is already in the 'empty' state"
+ self.empty = True
+ self.owner_state._empty_collections[self._key] = user_data
+
+ def _reset_empty(self):
+ assert (
+ self.empty
+ ), "This collection adapter is not in the 'empty' state"
+ self.empty = False
+ self.owner_state.dict[
+ self._key
+ ] = self.owner_state._empty_collections.pop(self._key)
+
+ def _refuse_empty(self):
+ raise sa_exc.InvalidRequestError(
+ "This is a special 'empty' collection which cannot accommodate "
+ "internal mutation operations"
+ )
+
+ def append_without_event(self, item):
+ """Add or restore an entity to the collection, firing no events."""
+
+ if self.empty:
+ self._refuse_empty()
+ self._data()._sa_appender(item, _sa_initiator=False)
+
+ def append_multiple_without_event(self, items):
+ """Add or restore an entity to the collection, firing no events."""
+ if self.empty:
+ self._refuse_empty()
+ appender = self._data()._sa_appender
+ for item in items:
+ appender(item, _sa_initiator=False)
+
+ def bulk_remover(self):
+ return self._data()._sa_remover
+
+ def remove_with_event(self, item, initiator=None):
+ """Remove an entity from the collection, firing mutation events."""
+ self._data()._sa_remover(item, _sa_initiator=initiator)
+
+ def remove_without_event(self, item):
+ """Remove an entity from the collection, firing no events."""
+ if self.empty:
+ self._refuse_empty()
+ self._data()._sa_remover(item, _sa_initiator=False)
+
+ def clear_with_event(self, initiator=None):
+ """Empty the collection, firing a mutation event for each entity."""
+
+ if self.empty:
+ self._refuse_empty()
+ remover = self._data()._sa_remover
+ for item in list(self):
+ remover(item, _sa_initiator=initiator)
+
+ def clear_without_event(self):
+ """Empty the collection, firing no events."""
+
+ if self.empty:
+ self._refuse_empty()
+ remover = self._data()._sa_remover
+ for item in list(self):
+ remover(item, _sa_initiator=False)
+
+ def __iter__(self):
+ """Iterate over entities in the collection."""
+
+ return iter(self._data()._sa_iterator())
+
+ def __len__(self):
+ """Count entities in the collection."""
+ return len(list(self._data()._sa_iterator()))
+
+ def __bool__(self):
+ return True
+
+ __nonzero__ = __bool__
+
+ def fire_append_wo_mutation_event(self, item, initiator=None):
+ """Notify that a entity is entering the collection but is already
+ present.
+
+
+ Initiator is a token owned by the InstrumentedAttribute that
+ initiated the membership mutation, and should be left as None
+ unless you are passing along an initiator value from a chained
+ operation.
+
+ .. versionadded:: 1.4.15
+
+ """
+ if initiator is not False:
+ if self.invalidated:
+ self._warn_invalidated()
+
+ if self.empty:
+ self._reset_empty()
+
+ return self.attr.fire_append_wo_mutation_event(
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
+ else:
+ return item
+
+ def fire_append_event(self, item, initiator=None):
+ """Notify that a entity has entered the collection.
+
+ Initiator is a token owned by the InstrumentedAttribute that
+ initiated the membership mutation, and should be left as None
+ unless you are passing along an initiator value from a chained
+ operation.
+
+ """
+ if initiator is not False:
+ if self.invalidated:
+ self._warn_invalidated()
+
+ if self.empty:
+ self._reset_empty()
+
+ return self.attr.fire_append_event(
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
+ else:
+ return item
+
+ def fire_remove_event(self, item, initiator=None):
+ """Notify that a entity has been removed from the collection.
+
+ Initiator is the InstrumentedAttribute that initiated the membership
+ mutation, and should be left as None unless you are passing along
+ an initiator value from a chained operation.
+
+ """
+ if initiator is not False:
+ if self.invalidated:
+ self._warn_invalidated()
+
+ if self.empty:
+ self._reset_empty()
+
+ self.attr.fire_remove_event(
+ self.owner_state, self.owner_state.dict, item, initiator
+ )
+
+ def fire_pre_remove_event(self, initiator=None):
+ """Notify that an entity is about to be removed from the collection.
+
+ Only called if the entity cannot be removed after calling
+ fire_remove_event().
+
+ """
+ if self.invalidated:
+ self._warn_invalidated()
+ self.attr.fire_pre_remove_event(
+ self.owner_state, self.owner_state.dict, initiator=initiator
+ )
+
+ def __getstate__(self):
+ return {
+ "key": self._key,
+ "owner_state": self.owner_state,
+ "owner_cls": self.owner_state.class_,
+ "data": self.data,
+ "invalidated": self.invalidated,
+ "empty": self.empty,
+ }
+
+ def __setstate__(self, d):
+ self._key = d["key"]
+ self.owner_state = d["owner_state"]
+ self._data = weakref.ref(d["data"])
+ self._converter = d["data"]._sa_converter
+ d["data"]._sa_adapter = self
+ self.invalidated = d["invalidated"]
+ self.attr = getattr(d["owner_cls"], self._key).impl
+ self.empty = d.get("empty", False)
+
+
+def bulk_replace(values, existing_adapter, new_adapter, initiator=None):
+ """Load a new collection, firing events based on prior like membership.
+
+ Appends instances in ``values`` onto the ``new_adapter``. Events will be
+ fired for any instance not present in the ``existing_adapter``. Any
+ instances in ``existing_adapter`` not present in ``values`` will have
+ remove events fired upon them.
+
+ :param values: An iterable of collection member instances
+
+ :param existing_adapter: A :class:`.CollectionAdapter` of
+ instances to be replaced
+
+ :param new_adapter: An empty :class:`.CollectionAdapter`
+ to load with ``values``
+
+
+ """
+
+ assert isinstance(values, list)
+
+ idset = util.IdentitySet
+ existing_idset = idset(existing_adapter or ())
+ constants = existing_idset.intersection(values or ())
+ additions = idset(values or ()).difference(constants)
+ removals = existing_idset.difference(constants)
+
+ appender = new_adapter.bulk_appender()
+
+ for member in values or ():
+ if member in additions:
+ appender(member, _sa_initiator=initiator)
+ elif member in constants:
+ appender(member, _sa_initiator=False)
+
+ if existing_adapter:
+ for member in removals:
+ existing_adapter.fire_remove_event(member, initiator=initiator)
+
+
+def prepare_instrumentation(factory):
+ """Prepare a callable for future use as a collection class factory.
+
+ Given a collection class factory (either a type or no-arg callable),
+ return another factory that will produce compatible instances when
+ called.
+
+ This function is responsible for converting collection_class=list
+ into the run-time behavior of collection_class=InstrumentedList.
+
+ """
+ # Convert a builtin to 'Instrumented*'
+ if factory in __canned_instrumentation:
+ factory = __canned_instrumentation[factory]
+
+ # Create a specimen
+ cls = type(factory())
+
+ # Did factory callable return a builtin?
+ if cls in __canned_instrumentation:
+ # Wrap it so that it returns our 'Instrumented*'
+ factory = __converting_factory(cls, factory)
+ cls = factory()
+
+ # Instrument the class if needed.
+ if __instrumentation_mutex.acquire():
+ try:
+ if getattr(cls, "_sa_instrumented", None) != id(cls):
+ _instrument_class(cls)
+ finally:
+ __instrumentation_mutex.release()
+
+ return factory
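+
+# A minimal sketch of the conversion described above; wrapped in a function
+# so that nothing executes at import time.
+def _prepare_instrumentation_sketch():
+    factory = prepare_instrumentation(list)
+    instance = factory()
+    # collection_class=list is transparently upgraded to InstrumentedList
+    assert type(instance) is InstrumentedList
+    return instance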
+
+
+def __converting_factory(specimen_cls, original_factory):
+ """Return a wrapper that converts a "canned" collection like
+ set, dict, list into the Instrumented* version.
+
+ """
+
+ instrumented_cls = __canned_instrumentation[specimen_cls]
+
+ def wrapper():
+ collection = original_factory()
+ return instrumented_cls(collection)
+
+ # often flawed but better than nothing
+ wrapper.__name__ = "%sWrapper" % original_factory.__name__
+ wrapper.__doc__ = original_factory.__doc__
+
+ return wrapper
+
+
+def _instrument_class(cls):
+ """Modify methods in a class and install instrumentation."""
+
+ # In the normal call flow, a request for any of the 3 basic collection
+ # types is transformed into one of our trivial subclasses
+ # (e.g. InstrumentedList). Catch anything else that sneaks in here...
+ if cls.__module__ == "__builtin__":
+ raise sa_exc.ArgumentError(
+ "Can not instrument a built-in type. Use a "
+ "subclass, even a trivial one."
+ )
+
+ roles, methods = _locate_roles_and_methods(cls)
+
+ _setup_canned_roles(cls, roles, methods)
+
+ _assert_required_roles(cls, roles, methods)
+
+ _set_collection_attributes(cls, roles, methods)
+
+
+def _locate_roles_and_methods(cls):
+ """search for _sa_instrument_role-decorated methods in
+ method resolution order, assign to roles.
+
+ """
+
+ roles = {}
+ methods = {}
+
+ for supercls in cls.__mro__:
+ for name, method in vars(supercls).items():
+ if not callable(method):
+ continue
+
+ # note role declarations
+ if hasattr(method, "_sa_instrument_role"):
+ role = method._sa_instrument_role
+ assert role in (
+ "appender",
+ "remover",
+ "iterator",
+ "converter",
+ )
+ roles.setdefault(role, name)
+
+ # transfer instrumentation requests from decorated function
+ # to the combined queue
+ before, after = None, None
+ if hasattr(method, "_sa_instrument_before"):
+ op, argument = method._sa_instrument_before
+ assert op in ("fire_append_event", "fire_remove_event")
+ before = op, argument
+ if hasattr(method, "_sa_instrument_after"):
+ op = method._sa_instrument_after
+ assert op in ("fire_append_event", "fire_remove_event")
+ after = op
+ if before:
+ methods[name] = before + (after,)
+ elif after:
+ methods[name] = None, None, after
+ return roles, methods
+
+
+def _setup_canned_roles(cls, roles, methods):
+ """see if this class has "canned" roles based on a known
+ collection type (dict, set, list). Apply those roles
+ as needed to the "roles" dictionary, and also
+ prepare "decorator" methods
+
+ """
+ collection_type = util.duck_type_collection(cls)
+ if collection_type in __interfaces:
+ canned_roles, decorators = __interfaces[collection_type]
+ for role, name in canned_roles.items():
+ roles.setdefault(role, name)
+
+ # apply ABC auto-decoration to methods that need it
+ for method, decorator in decorators.items():
+ fn = getattr(cls, method, None)
+ if (
+ fn
+ and method not in methods
+ and not hasattr(fn, "_sa_instrumented")
+ ):
+ setattr(cls, method, decorator(fn))
+
+
+def _assert_required_roles(cls, roles, methods):
+ """ensure all roles are present, and apply implicit instrumentation if
+ needed
+
+ """
+ if "appender" not in roles or not hasattr(cls, roles["appender"]):
+ raise sa_exc.ArgumentError(
+ "Type %s must elect an appender method to be "
+ "a collection class" % cls.__name__
+ )
+ elif roles["appender"] not in methods and not hasattr(
+ getattr(cls, roles["appender"]), "_sa_instrumented"
+ ):
+ methods[roles["appender"]] = ("fire_append_event", 1, None)
+
+ if "remover" not in roles or not hasattr(cls, roles["remover"]):
+ raise sa_exc.ArgumentError(
+ "Type %s must elect a remover method to be "
+ "a collection class" % cls.__name__
+ )
+ elif roles["remover"] not in methods and not hasattr(
+ getattr(cls, roles["remover"]), "_sa_instrumented"
+ ):
+ methods[roles["remover"]] = ("fire_remove_event", 1, None)
+
+ if "iterator" not in roles or not hasattr(cls, roles["iterator"]):
+ raise sa_exc.ArgumentError(
+ "Type %s must elect an iterator method to be "
+ "a collection class" % cls.__name__
+ )
+
+
+def _set_collection_attributes(cls, roles, methods):
+ """apply ad-hoc instrumentation from decorators, class-level defaults
+ and implicit role declarations
+
+ """
+ for method_name, (before, argument, after) in methods.items():
+ setattr(
+ cls,
+ method_name,
+ _instrument_membership_mutator(
+ getattr(cls, method_name), before, argument, after
+ ),
+ )
+ # intern the role map
+ for role, method_name in roles.items():
+ setattr(cls, "_sa_%s" % role, getattr(cls, method_name))
+
+ cls._sa_adapter = None
+
+ if not hasattr(cls, "_sa_converter"):
+ cls._sa_converter = None
+ cls._sa_instrumented = id(cls)
+
+
+def _instrument_membership_mutator(method, before, argument, after):
+ """Route method args and/or return value through the collection
+ adapter."""
+ # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
+ if before:
+ fn_args = list(
+ util.flatten_iterator(inspect_getfullargspec(method)[0])
+ )
+ if isinstance(argument, int):
+ pos_arg = argument
+ named_arg = len(fn_args) > argument and fn_args[argument] or None
+ else:
+ if argument in fn_args:
+ pos_arg = fn_args.index(argument)
+ else:
+ pos_arg = None
+ named_arg = argument
+ del fn_args
+
+ def wrapper(*args, **kw):
+ if before:
+ if pos_arg is None:
+ if named_arg not in kw:
+ raise sa_exc.ArgumentError(
+ "Missing argument %s" % argument
+ )
+ value = kw[named_arg]
+ else:
+ if len(args) > pos_arg:
+ value = args[pos_arg]
+ elif named_arg in kw:
+ value = kw[named_arg]
+ else:
+ raise sa_exc.ArgumentError(
+ "Missing argument %s" % argument
+ )
+
+ initiator = kw.pop("_sa_initiator", None)
+ if initiator is False:
+ executor = None
+ else:
+ executor = args[0]._sa_adapter
+
+ if before and executor:
+ getattr(executor, before)(value, initiator)
+
+ if not after or not executor:
+ return method(*args, **kw)
+ else:
+ res = method(*args, **kw)
+ if res is not None:
+ getattr(executor, after)(res, initiator)
+ return res
+
+ wrapper._sa_instrumented = True
+ if hasattr(method, "_sa_instrument_role"):
+ wrapper._sa_instrument_role = method._sa_instrument_role
+ wrapper.__name__ = method.__name__
+ wrapper.__doc__ = method.__doc__
+ return wrapper
+
+
+def __set_wo_mutation(collection, item, _sa_initiator=None):
+ """Run set wo mutation events.
+
+ The collection is not mutated.
+
+ """
+ if _sa_initiator is not False:
+ executor = collection._sa_adapter
+ if executor:
+ executor.fire_append_wo_mutation_event(item, _sa_initiator)
+
+
+def __set(collection, item, _sa_initiator=None):
+ """Run set events.
+
+ This event always occurs before the collection is actually mutated.
+
+ """
+
+ if _sa_initiator is not False:
+ executor = collection._sa_adapter
+ if executor:
+ item = executor.fire_append_event(item, _sa_initiator)
+ return item
+
+
+def __del(collection, item, _sa_initiator=None):
+ """Run del events.
+
+ This event occurs before the collection is actually mutated, *except*
+ in the case of a pop operation, in which case it occurs afterwards.
+ For pop operations, the __before_pop hook is called before the
+ operation occurs.
+
+ """
+ if _sa_initiator is not False:
+ executor = collection._sa_adapter
+ if executor:
+ executor.fire_remove_event(item, _sa_initiator)
+
+
+def __before_pop(collection, _sa_initiator=None):
+ """An event which occurs on a before a pop() operation occurs."""
+ executor = collection._sa_adapter
+ if executor:
+ executor.fire_pre_remove_event(_sa_initiator)
+
+
+def _list_decorators():
+ """Tailored instrumentation wrappers for any list-like class."""
+
+ def _tidy(fn):
+ fn._sa_instrumented = True
+ fn.__doc__ = getattr(list, fn.__name__).__doc__
+
+ def append(fn):
+ def append(self, item, _sa_initiator=None):
+ item = __set(self, item, _sa_initiator)
+ fn(self, item)
+
+ _tidy(append)
+ return append
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ __del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__eq__
+ fn(self, value)
+
+ _tidy(remove)
+ return remove
+
+ def insert(fn):
+ def insert(self, index, value):
+ value = __set(self, value)
+ fn(self, index, value)
+
+ _tidy(insert)
+ return insert
+
+ def __setitem__(fn):
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ existing = self[index]
+ if existing is not None:
+ __del(self, existing)
+ value = __set(self, value)
+ fn(self, index, value)
+ else:
+ # slice assignment requires __delitem__, insert, __len__
+ step = index.step or 1
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ if index.stop is not None:
+ stop = index.stop
+ else:
+ stop = len(self)
+ if stop < 0:
+ stop += len(self)
+
+ if step == 1:
+ if value is self:
+ return
+ for i in range(start, stop, step):
+ if len(self) > start:
+ del self[start]
+
+ for i, item in enumerate(value):
+ self.insert(i + start, item)
+ else:
+ rng = list(range(start, stop, step))
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s"
+ % (len(value), len(rng))
+ )
+ for i, item in zip(rng, value):
+ self.__setitem__(i, item)
+
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, index):
+ if not isinstance(index, slice):
+ item = self[index]
+ __del(self, item)
+ fn(self, index)
+ else:
+ # slice deletion requires __getslice__ and a slice-groking
+ # __getitem__ for stepped deletion
+ # note: not breaking this into atomic dels
+ for item in self[index]:
+ __del(self, item)
+ fn(self, index)
+
+ _tidy(__delitem__)
+ return __delitem__
+
+ if util.py2k:
+
+ def __setslice__(fn):
+ def __setslice__(self, start, end, values):
+ for value in self[start:end]:
+ __del(self, value)
+ values = [__set(self, value) for value in values]
+ fn(self, start, end, values)
+
+ _tidy(__setslice__)
+ return __setslice__
+
+ def __delslice__(fn):
+ def __delslice__(self, start, end):
+ for value in self[start:end]:
+ __del(self, value)
+ fn(self, start, end)
+
+ _tidy(__delslice__)
+ return __delslice__
+
+ def extend(fn):
+ def extend(self, iterable):
+ for value in list(iterable):
+ self.append(value)
+
+ _tidy(extend)
+ return extend
+
+ def __iadd__(fn):
+ def __iadd__(self, iterable):
+ # list.__iadd__ takes any iterable and seems to let TypeError
+ # raise as-is instead of returning NotImplemented
+ for value in list(iterable):
+ self.append(value)
+ return self
+
+ _tidy(__iadd__)
+ return __iadd__
+
+ def pop(fn):
+ def pop(self, index=-1):
+ __before_pop(self)
+ item = fn(self, index)
+ __del(self, item)
+ return item
+
+ _tidy(pop)
+ return pop
+
+ if not util.py2k:
+
+ def clear(fn):
+ def clear(self, index=-1):
+ for item in self:
+ __del(self, item)
+ fn(self)
+
+ _tidy(clear)
+ return clear
+
+ # __imul__ : not wrapping this. all members of the collection are already
+ # present, so no need to fire appends... wrapping it with an explicit
+ # decorator is still possible, so events on *= can be had if they're
+ # desired. hard to imagine a use case for __imul__, though.
+
+ l = locals().copy()
+ l.pop("_tidy")
+ return l
+
+
+def _dict_decorators():
+ """Tailored instrumentation wrappers for any dict-like mapping class."""
+
+ def _tidy(fn):
+ fn._sa_instrumented = True
+ fn.__doc__ = getattr(dict, fn.__name__).__doc__
+
+ Unspecified = util.symbol("Unspecified")
+
+ def __setitem__(fn):
+ def __setitem__(self, key, value, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ value = __set(self, value, _sa_initiator)
+ fn(self, key, value)
+
+ _tidy(__setitem__)
+ return __setitem__
+
+ def __delitem__(fn):
+ def __delitem__(self, key, _sa_initiator=None):
+ if key in self:
+ __del(self, self[key], _sa_initiator)
+ fn(self, key)
+
+ _tidy(__delitem__)
+ return __delitem__
+
+ def clear(fn):
+ def clear(self):
+ for key in self:
+ __del(self, self[key])
+ fn(self)
+
+ _tidy(clear)
+ return clear
+
+ def pop(fn):
+ def pop(self, key, default=Unspecified):
+ __before_pop(self)
+ _to_del = key in self
+ if default is Unspecified:
+ item = fn(self, key)
+ else:
+ item = fn(self, key, default)
+ if _to_del:
+ __del(self, item)
+ return item
+
+ _tidy(pop)
+ return pop
+
+ def popitem(fn):
+ def popitem(self):
+ __before_pop(self)
+ item = fn(self)
+ __del(self, item[1])
+ return item
+
+ _tidy(popitem)
+ return popitem
+
+ def setdefault(fn):
+ def setdefault(self, key, default=None):
+ if key not in self:
+ self.__setitem__(key, default)
+ return default
+ else:
+ value = self.__getitem__(key)
+ if value is default:
+ __set_wo_mutation(self, value, None)
+
+ return value
+
+ _tidy(setdefault)
+ return setdefault
+
+ def update(fn):
+ def update(self, __other=Unspecified, **kw):
+ if __other is not Unspecified:
+ if hasattr(__other, "keys"):
+ for key in list(__other):
+ if key not in self or self[key] is not __other[key]:
+ self[key] = __other[key]
+ else:
+ __set_wo_mutation(self, __other[key], None)
+ else:
+ for key, value in __other:
+ if key not in self or self[key] is not value:
+ self[key] = value
+ else:
+ __set_wo_mutation(self, value, None)
+ for key in kw:
+ if key not in self or self[key] is not kw[key]:
+ self[key] = kw[key]
+ else:
+ __set_wo_mutation(self, kw[key], None)
+
+ _tidy(update)
+ return update
+
+ l = locals().copy()
+ l.pop("_tidy")
+ l.pop("Unspecified")
+ return l
+
+
+_set_binop_bases = (set, frozenset)
+
+
+def _set_binops_check_strict(self, obj):
+ """Allow only set, frozenset and self.__class__-derived
+ objects in binops."""
+ return isinstance(obj, _set_binop_bases + (self.__class__,))
+
+
+def _set_binops_check_loose(self, obj):
+ """Allow anything set-like to participate in set binops."""
+ return (
+ isinstance(obj, _set_binop_bases + (self.__class__,))
+ or util.duck_type_collection(obj) == set
+ )
+
+
+def _set_decorators():
+ """Tailored instrumentation wrappers for any set-like class."""
+
+ def _tidy(fn):
+ fn._sa_instrumented = True
+ fn.__doc__ = getattr(set, fn.__name__).__doc__
+
+ Unspecified = util.symbol("Unspecified")
+
+ def add(fn):
+ def add(self, value, _sa_initiator=None):
+ if value not in self:
+ value = __set(self, value, _sa_initiator)
+ else:
+ __set_wo_mutation(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
+ fn(self, value)
+
+ _tidy(add)
+ return add
+
+ def discard(fn):
+ def discard(self, value, _sa_initiator=None):
+ # testlib.pragma exempt:__hash__
+ if value in self:
+ __del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
+ fn(self, value)
+
+ _tidy(discard)
+ return discard
+
+ def remove(fn):
+ def remove(self, value, _sa_initiator=None):
+ # testlib.pragma exempt:__hash__
+ if value in self:
+ __del(self, value, _sa_initiator)
+ # testlib.pragma exempt:__hash__
+ fn(self, value)
+
+ _tidy(remove)
+ return remove
+
+ def pop(fn):
+ def pop(self):
+ __before_pop(self)
+ item = fn(self)
+ # for set in particular, we have no way to access the item
+ # that will be popped before pop is called.
+ __del(self, item)
+ return item
+
+ _tidy(pop)
+ return pop
+
+ def clear(fn):
+ def clear(self):
+ for item in list(self):
+ self.remove(item)
+
+ _tidy(clear)
+ return clear
+
+ def update(fn):
+ def update(self, value):
+ for item in value:
+ self.add(item)
+
+ _tidy(update)
+ return update
+
+ def __ior__(fn):
+ def __ior__(self, value):
+ if not _set_binops_check_strict(self, value):
+ return NotImplemented
+ for item in value:
+ self.add(item)
+ return self
+
+ _tidy(__ior__)
+ return __ior__
+
+ def difference_update(fn):
+ def difference_update(self, value):
+ for item in value:
+ self.discard(item)
+
+ _tidy(difference_update)
+ return difference_update
+
+ def __isub__(fn):
+ def __isub__(self, value):
+ if not _set_binops_check_strict(self, value):
+ return NotImplemented
+ for item in value:
+ self.discard(item)
+ return self
+
+ _tidy(__isub__)
+ return __isub__
+
+ def intersection_update(fn):
+ def intersection_update(self, other):
+ want, have = self.intersection(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+
+ _tidy(intersection_update)
+ return intersection_update
+
+ def __iand__(fn):
+ def __iand__(self, other):
+ if not _set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.intersection(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ return self
+
+ _tidy(__iand__)
+ return __iand__
+
+ def symmetric_difference_update(fn):
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+
+ _tidy(symmetric_difference_update)
+ return symmetric_difference_update
+
+ def __ixor__(fn):
+ def __ixor__(self, other):
+ if not _set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.symmetric_difference(other), set(self)
+ remove, add = have - want, want - have
+
+ for item in remove:
+ self.remove(item)
+ for item in add:
+ self.add(item)
+ return self
+
+ _tidy(__ixor__)
+ return __ixor__
+
+ l = locals().copy()
+ l.pop("_tidy")
+ l.pop("Unspecified")
+ return l
+
+
+class InstrumentedList(list):
+ """An instrumented version of the built-in list."""
+
+
+class InstrumentedSet(set):
+ """An instrumented version of the built-in set."""
+
+
+class InstrumentedDict(dict):
+ """An instrumented version of the built-in dict."""
+
+
+__canned_instrumentation = {
+ list: InstrumentedList,
+ set: InstrumentedSet,
+ dict: InstrumentedDict,
+}
+
+__interfaces = {
+ list: (
+ {"appender": "append", "remover": "remove", "iterator": "__iter__"},
+ _list_decorators(),
+ ),
+ set: (
+ {"appender": "add", "remover": "remove", "iterator": "__iter__"},
+ _set_decorators(),
+ ),
+ # decorators are required for dicts and object collections.
+ dict: ({"iterator": "values"}, _dict_decorators())
+ if util.py3k
+ else ({"iterator": "itervalues"}, _dict_decorators()),
+}
+
+
+class MappedCollection(dict):
+ """A basic dictionary-based collection class.
+
+ Extends dict with the minimal bag semantics that collection
+ classes require. ``set`` and ``remove`` are implemented in terms
+ of a keying function: any callable that takes an object and
+ returns an object for use as a dictionary key.
+
+ """
+
+ def __init__(self, keyfunc):
+ """Create a new collection with keying provided by keyfunc.
+
+ keyfunc may be any callable that takes an object and returns an object
+ for use as a dictionary key.
+
+ The keyfunc will be called every time the ORM needs to add a member by
+ value-only (such as when loading instances from the database) or
+ remove a member. The usual cautions about dictionary keying apply-
+ ``keyfunc(object)`` should return the same output for the life of the
+ collection. Keying based on mutable properties can result in
+ unreachable instances "lost" in the collection.
+
+ """
+ self.keyfunc = keyfunc
+
+ @collection.appender
+ @collection.internally_instrumented
+ def set(self, value, _sa_initiator=None):
+ """Add an item by value, consulting the keyfunc for the key."""
+
+ key = self.keyfunc(value)
+ self.__setitem__(key, value, _sa_initiator)
+
+ @collection.remover
+ @collection.internally_instrumented
+ def remove(self, value, _sa_initiator=None):
+ """Remove an item by value, consulting the keyfunc for the key."""
+
+ key = self.keyfunc(value)
+ # Let self[key] raise if key is not in this collection
+ # testlib.pragma exempt:__ne__
+ if self[key] != value:
+ raise sa_exc.InvalidRequestError(
+ "Can not remove '%s': collection holds '%s' for key '%s'. "
+ "Possible cause: is the MappedCollection key function "
+ "based on mutable properties or properties that only obtain "
+ "values after flush?" % (value, self[key], key)
+ )
+ self.__delitem__(key, _sa_initiator)
+
+
+# ensure instrumentation is associated with
+# these built-in classes; if a user-defined class
+# subclasses these and uses @internally_instrumented,
+# the superclass is otherwise not instrumented.
+# see [ticket:2406].
+_instrument_class(MappedCollection)
+_instrument_class(InstrumentedList)
+_instrument_class(InstrumentedSet)
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
new file mode 100644
index 0000000..9d4f652
--- /dev/null
+++ b/lib/sqlalchemy/orm/context.py
@@ -0,0 +1,3136 @@
+# orm/context.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+import itertools
+
+from . import attributes
+from . import interfaces
+from . import loading
+from .base import _is_aliased_class
+from .interfaces import ORMColumnsClauseRole
+from .path_registry import PathRegistry
+from .util import _entity_corresponds_to
+from .util import _ORMJoin
+from .util import aliased
+from .util import Bundle
+from .util import ORMAdapter
+from .. import exc as sa_exc
+from .. import future
+from .. import inspect
+from .. import sql
+from .. import util
+from ..sql import ClauseElement
+from ..sql import coercions
+from ..sql import expression
+from ..sql import roles
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.base import _entity_namespace_key
+from ..sql.base import _select_iterables
+from ..sql.base import CacheableOptions
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
+from ..sql.selectable import LABEL_STYLE_NONE
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectState
+from ..sql.visitors import ExtendedInternalTraversal
+from ..sql.visitors import InternalTraversal
+
+_path_registry = PathRegistry.root
+
+_EMPTY_DICT = util.immutabledict()
+
+
+LABEL_STYLE_LEGACY_ORM = util.symbol("LABEL_STYLE_LEGACY_ORM")
+
+
+class QueryContext(object):
+ __slots__ = (
+ "compile_state",
+ "query",
+ "params",
+ "load_options",
+ "bind_arguments",
+ "execution_options",
+ "session",
+ "autoflush",
+ "populate_existing",
+ "invoke_all_eagers",
+ "version_check",
+ "refresh_state",
+ "create_eager_joins",
+ "propagated_loader_options",
+ "attributes",
+ "runid",
+ "partials",
+ "post_load_paths",
+ "identity_token",
+ "yield_per",
+ "loaders_require_buffering",
+ "loaders_require_uniquing",
+ )
+
+ class default_load_options(Options):
+ _only_return_tuples = False
+ _populate_existing = False
+ _version_check = False
+ _invoke_all_eagers = True
+ _autoflush = True
+ _refresh_identity_token = None
+ _yield_per = None
+ _refresh_state = None
+ _lazy_loaded_from = None
+ _legacy_uniquing = False
+
+ def __init__(
+ self,
+ compile_state,
+ statement,
+ params,
+ session,
+ load_options,
+ execution_options=None,
+ bind_arguments=None,
+ ):
+ self.load_options = load_options
+ self.execution_options = execution_options or _EMPTY_DICT
+ self.bind_arguments = bind_arguments or _EMPTY_DICT
+ self.compile_state = compile_state
+ self.query = statement
+ self.session = session
+ self.loaders_require_buffering = False
+ self.loaders_require_uniquing = False
+ self.params = params
+
+ self.propagated_loader_options = {
+ # issue 7447.
+ # propagated loader options will be present on loaded InstanceState
+ # objects under state.load_options and are typically used by
+ # LazyLoader to apply options to the SELECT statement it emits.
+ # For compile state options (i.e. loader strategy options), these
+ # need to line up with the ".load_path" attribute which in
+ # loader.py is pulled from context.compile_state.current_path.
+ # so, this means these options have to be the ones from the
+ # *cached* statement that's travelling with compile_state, not the
+ # *current* statement which won't match up for an ad-hoc
+ # AliasedClass
+ cached_o
+ for cached_o in compile_state.select_statement._with_options
+ if cached_o.propagate_to_loaders and cached_o._is_compile_state
+ } | {
+ # for user defined loader options that are not "compile state",
+ # those just need to be present as they are
+ uncached_o
+ for uncached_o in statement._with_options
+ if uncached_o.propagate_to_loaders
+ and not uncached_o._is_compile_state
+ }
+
+ self.attributes = dict(compile_state.attributes)
+
+ self.autoflush = load_options._autoflush
+ self.populate_existing = load_options._populate_existing
+ self.invoke_all_eagers = load_options._invoke_all_eagers
+ self.version_check = load_options._version_check
+ self.refresh_state = load_options._refresh_state
+ self.yield_per = load_options._yield_per
+ self.identity_token = load_options._refresh_identity_token
+
+ if self.yield_per and compile_state._no_yield_pers:
+ raise sa_exc.InvalidRequestError(
+ "The yield_per Query option is currently not "
+ "compatible with %s eager loading. Please "
+ "specify lazyload('*') or query.enable_eagerloads(False) in "
+ "order to "
+ "proceed with query.yield_per()."
+ % ", ".join(compile_state._no_yield_pers)
+ )
+
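+
+# A minimal sketch (hypothetical "session" / "User" arguments): the load
+# options above are populated from statement- or execution-level options
+# such as "populate_existing" and "yield_per".
+def _load_options_usage_sketch(session, User):
+    from sqlalchemy import select
+
+    stmt = select(User).execution_options(
+        populate_existing=True, yield_per=100
+    )
+    # orm_pre_session_exec() feeds these through
+    # default_load_options.from_execution_options(), ending up as
+    # QueryContext.populate_existing / QueryContext.yield_per.
+    return session.execute(stmt)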
+
+_orm_load_exec_options = util.immutabledict(
+ {"_result_disable_adapt_to_context": True, "future_result": True}
+)
+
+
+class ORMCompileState(CompileState):
+ # note this is a dictionary, but the
+ # default_compile_options._with_polymorphic_adapt_map is a tuple
+ _with_polymorphic_adapt_map = _EMPTY_DICT
+
+ class default_compile_options(CacheableOptions):
+ _cache_key_traversal = [
+ ("_use_legacy_query_style", InternalTraversal.dp_boolean),
+ ("_for_statement", InternalTraversal.dp_boolean),
+ ("_bake_ok", InternalTraversal.dp_boolean),
+ (
+ "_with_polymorphic_adapt_map",
+ ExtendedInternalTraversal.dp_has_cache_key_tuples,
+ ),
+ ("_current_path", InternalTraversal.dp_has_cache_key),
+ ("_enable_single_crit", InternalTraversal.dp_boolean),
+ ("_enable_eagerloads", InternalTraversal.dp_boolean),
+ ("_orm_only_from_obj_alias", InternalTraversal.dp_boolean),
+ ("_only_load_props", InternalTraversal.dp_plain_obj),
+ ("_set_base_alias", InternalTraversal.dp_boolean),
+ ("_for_refresh_state", InternalTraversal.dp_boolean),
+ ("_render_for_subquery", InternalTraversal.dp_boolean),
+ ("_is_star", InternalTraversal.dp_boolean),
+ ]
+
+ # set to True by default from Query._statement_20(), to indicate
+ # the rendered query should look like a legacy ORM query. right
+ # now this basically indicates we should use tablename_columnname
+ # style labels. Generally indicates the statement originated
+ # from a Query object.
+ _use_legacy_query_style = False
+
+ # set *only* when we are coming from the Query.statement
+ # accessor, or a Query-level equivalent such as
+ # query.subquery(). this supersedes "toplevel".
+ _for_statement = False
+
+ _bake_ok = True
+ _with_polymorphic_adapt_map = ()
+ _current_path = _path_registry
+ _enable_single_crit = True
+ _enable_eagerloads = True
+ _orm_only_from_obj_alias = True
+ _only_load_props = None
+ _set_base_alias = False
+ _for_refresh_state = False
+ _render_for_subquery = False
+ _is_star = False
+
+ current_path = _path_registry
+
+ def __init__(self, *arg, **kw):
+ raise NotImplementedError()
+
+ def _append_dedupe_col_collection(self, obj, col_collection):
+ dedupe = self.dedupe_columns
+ if obj not in dedupe:
+ dedupe.add(obj)
+ col_collection.append(obj)
+
+ @classmethod
+ def _column_naming_convention(cls, label_style, legacy):
+
+ if legacy:
+
+ def name(col, col_name=None):
+ if col_name:
+ return col_name
+ else:
+ return getattr(col, "key")
+
+ return name
+ else:
+ return SelectState._column_naming_convention(label_style)
+
+ @classmethod
+ def create_for_statement(cls, statement_container, compiler, **kw):
+ """Create a context for a statement given a :class:`.Compiler`.
+
+ This method is always invoked in the context of SQLCompiler.process().
+
+ For a Select object, this would be invoked from
+ SQLCompiler.visit_select(). For the special FromStatement object used
+ by Query to indicate "Query.from_statement()", this is called by
+ FromStatement._compiler_dispatch() that would be called by
+ SQLCompiler.process().
+
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def get_column_descriptions(cls, statement):
+ return _column_descriptions(statement)
+
+ @classmethod
+ def orm_pre_session_exec(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_reentrant_invoke,
+ ):
+ if is_reentrant_invoke:
+ return statement, execution_options
+
+ (
+ load_options,
+ execution_options,
+ ) = QueryContext.default_load_options.from_execution_options(
+ "_sa_orm_load_options",
+ {"populate_existing", "autoflush", "yield_per"},
+ execution_options,
+ statement._execution_options,
+ )
+
+ # default execution options for ORM results:
+ # 1. _result_disable_adapt_to_context=True
+ # this will disable the ResultSetMetadata._adapt_to_context()
+ # step which we don't need, as we have result processors cached
+ # against the original SELECT statement before caching.
+ # 2. future_result=True. The ORM should **never** resolve columns
+ # in a result set based on names, only on Column objects that
+ # are correctly adapted to the context. With the legacy result
+ # it will still attempt name-based resolution and also emit a
+ # warning.
+ if not execution_options:
+ execution_options = _orm_load_exec_options
+ else:
+ execution_options = execution_options.union(_orm_load_exec_options)
+
+ if load_options._yield_per:
+ execution_options = execution_options.union(
+ {"yield_per": load_options._yield_per}
+ )
+
+ bind_arguments["clause"] = statement
+
+ # new in 1.4 - the coercions system is leveraged to allow the
+ # "subject" mapper of a statement be propagated to the top
+ # as the statement is built. "subject" mapper is the generally
+ # standard object used as an identifier for multi-database schemes.
+
+ # we are here based on the fact that _propagate_attrs contains
+ # "compile_state_plugin": "orm". The "plugin_subject"
+ # needs to be present as well.
+
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
+ else:
+ if plugin_subject:
+ bind_arguments["mapper"] = plugin_subject.mapper
+
+ if load_options._autoflush:
+ session._autoflush()
+
+ return statement, execution_options
+
+ @classmethod
+ def orm_setup_cursor_result(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+ execution_context = result.context
+ compile_state = execution_context.compiled.compile_state
+
+ # cover edge case where ORM entities used in legacy select
+ # were passed to session.execute:
+ # session.execute(legacy_select([User.id, User.name]))
+ # see test_query->test_legacy_tuple_old_select
+
+ load_options = execution_options.get(
+ "_sa_orm_load_options", QueryContext.default_load_options
+ )
+ if compile_state.compile_options._is_star:
+ return result
+
+ querycontext = QueryContext(
+ compile_state,
+ statement,
+ params,
+ session,
+ load_options,
+ execution_options,
+ bind_arguments,
+ )
+ return loading.instances(result, querycontext)
+
+ @property
+ def _lead_mapper_entities(self):
+ """return all _MapperEntity objects in the lead entities collection.
+
+ Does **not** include entities that have been replaced by
+ with_entities(), with_only_columns()
+
+ """
+ return [
+ ent for ent in self._entities if isinstance(ent, _MapperEntity)
+ ]
+
+ def _create_with_polymorphic_adapter(self, ext_info, selectable):
+ if (
+ not ext_info.is_aliased_class
+ and ext_info.mapper.persist_selectable
+ not in self._polymorphic_adapters
+ ):
+ for mp in ext_info.mapper.iterate_to_root():
+ self._mapper_loads_polymorphically_with(
+ mp,
+ sql_util.ColumnAdapter(selectable, mp._equivalent_columns),
+ )
+
+ def _mapper_loads_polymorphically_with(self, mapper, adapter):
+ for m2 in mapper._with_polymorphic_mappers or [mapper]:
+ self._polymorphic_adapters[m2] = adapter
+ for m in m2.iterate_to_root(): # TODO: redundant ?
+ self._polymorphic_adapters[m.local_table] = adapter
+
+ @classmethod
+ def _create_entities_collection(cls, query, legacy):
+ raise NotImplementedError(
+ "this method only works for ORMSelectCompileState"
+ )
+
+
+@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
+class ORMFromStatementCompileState(ORMCompileState):
+ _aliased_generations = util.immutabledict()
+ _from_obj_alias = None
+ _has_mapper_entities = False
+
+ _has_orm_entities = False
+ multi_row_eager_loaders = False
+ compound_eager_adapter = None
+
+ extra_criteria_entities = _EMPTY_DICT
+ eager_joins = _EMPTY_DICT
+
+ @classmethod
+ def create_for_statement(cls, statement_container, compiler, **kw):
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ else:
+ toplevel = True
+
+ self = cls.__new__(cls)
+ self._primary_entity = None
+
+ self.use_legacy_query_style = (
+ statement_container._compile_options._use_legacy_query_style
+ )
+ self.statement_container = self.select_statement = statement_container
+ self.requested_statement = statement = statement_container.element
+
+ if statement.is_dml:
+ self.dml_table = statement.table
+
+ self._entities = []
+ self._polymorphic_adapters = {}
+ self._no_yield_pers = set()
+
+ self.compile_options = statement_container._compile_options
+
+ if (
+ self.use_legacy_query_style
+ and isinstance(statement, expression.SelectBase)
+ and not statement._is_textual
+ and not statement.is_dml
+ and statement._label_style is LABEL_STYLE_NONE
+ ):
+ self.statement = statement.set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ else:
+ self.statement = statement
+
+ self._label_convention = self._column_naming_convention(
+ statement._label_style
+ if not statement._is_textual and not statement.is_dml
+ else LABEL_STYLE_NONE,
+ self.use_legacy_query_style,
+ )
+
+ _QueryEntity.to_compile_state(
+ self,
+ statement_container._raw_columns,
+ self._entities,
+ is_current_entities=True,
+ )
+
+ self.current_path = statement_container._compile_options._current_path
+
+ if toplevel and statement_container._with_options:
+ self.attributes = {"_unbound_load_dedupes": set()}
+ self.global_attributes = compiler._global_attributes
+
+ for opt in statement_container._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state(self)
+
+ else:
+ self.attributes = {}
+ self.global_attributes = compiler._global_attributes
+
+ if statement_container._with_context_options:
+ for fn, key in statement_container._with_context_options:
+ fn(self)
+
+ self.primary_columns = []
+ self.secondary_columns = []
+ self.dedupe_columns = set()
+ self.create_eager_joins = []
+ self._fallback_from_clauses = []
+
+ self.order_by = None
+
+ if isinstance(
+ self.statement, (expression.TextClause, expression.UpdateBase)
+ ):
+
+ self.extra_criteria_entities = {}
+
+ # setup for all entities. Currently, this is not useful
+ # for eager loaders, as the eager loaders that work are able
+ # to do their work entirely in row_processor.
+ for entity in self._entities:
+ entity.setup_compile_state(self)
+
+ # we did the setup just to get primary columns.
+ self.statement = _AdHocColumnsStatement(
+ self.statement, self.primary_columns
+ )
+ else:
+ # allow TextualSelect with implicit columns as well
+ # as select() with ad-hoc columns, see test_query::TextTest
+ self._from_obj_alias = sql.util.ColumnAdapter(
+ self.statement, adapt_on_names=True
+ )
+ # set up for eager loaders, however if we fix subqueryload
+ # it should not need to do this here. the model of eager loaders
+ # that can work entirely in row_processor might be interesting
+ # here though subqueryloader has a lot of upfront work to do
+ # see test/orm/test_query.py -> test_related_eagerload_against_text
+ # for where this part makes a difference. would rather have
+ # subqueryload figure out what it needs more intelligently.
+ # for entity in self._entities:
+ # entity.setup_compile_state(self)
+
+ return self
+
+ def _adapt_col_list(self, cols, current_adapter):
+ return cols
+
+ def _get_current_adapter(self):
+ return None
+
+
+class _AdHocColumnsStatement(ClauseElement):
+ """internal object created to somewhat act like a SELECT when we
+ are selecting columns from a DML RETURNING.
+
+
+ """
+
+ __visit_name__ = None
+
+ def __init__(self, text, columns):
+ self.element = text
+ self.column_args = [
+ coercions.expect(roles.ColumnsClauseRole, c) for c in columns
+ ]
+
+ def _generate_cache_key(self):
+ raise NotImplementedError()
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ raise NotImplementedError()
+
+ def _compiler_dispatch(
+ self, compiler, compound_index=None, asfrom=False, **kw
+ ):
+ """provide a fixed _compiler_dispatch method."""
+
+ toplevel = not compiler.stack
+ entry = (
+ compiler._default_stack_entry if toplevel else compiler.stack[-1]
+ )
+
+ populate_result_map = (
+ toplevel
+ # these two might not be needed
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
+
+ if populate_result_map:
+ compiler._ordered_columns = (
+ compiler._textual_ordered_columns
+ ) = False
+
+ # enable looser result column matching. this is shown to be
+ # needed by test_query.py::TextTest
+ compiler._loose_column_name_matching = True
+
+ for c in self.column_args:
+ compiler.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=compiler._add_to_result_map,
+ )
+ return compiler.process(self.element, **kw)
+
+
+@sql.base.CompileState.plugin_for("orm", "select")
+class ORMSelectCompileState(ORMCompileState, SelectState):
+ _joinpath = _joinpoint = _EMPTY_DICT
+
+ _memoized_entities = _EMPTY_DICT
+
+ _from_obj_alias = None
+ _has_mapper_entities = False
+
+ _has_orm_entities = False
+ multi_row_eager_loaders = False
+ compound_eager_adapter = None
+
+ correlate = None
+ correlate_except = None
+ _where_criteria = ()
+ _having_criteria = ()
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ """compiler hook, we arrive here from compiler.visit_select() only."""
+
+ self = cls.__new__(cls)
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ self.global_attributes = compiler._global_attributes
+ else:
+ toplevel = True
+ self.global_attributes = {}
+
+ select_statement = statement
+
+ # if we are a select() that was never a legacy Query, we won't
+ # have ORM level compile options.
+ statement._compile_options = cls.default_compile_options.safe_merge(
+ statement._compile_options
+ )
+
+ if select_statement._execution_options:
+ # execution options should not impact the compilation of a
+ # query, and at the moment subqueryloader is putting some things
+ # in here that we explicitly don't want stuck in a cache.
+ self.select_statement = select_statement._clone()
+ self.select_statement._execution_options = util.immutabledict()
+ else:
+ self.select_statement = select_statement
+
+ # indicates this select() came from Query.statement
+ self.for_statement = select_statement._compile_options._for_statement
+
+ # generally if we are from Query or directly from a select()
+ self.use_legacy_query_style = (
+ select_statement._compile_options._use_legacy_query_style
+ )
+
+ self._entities = []
+ self._primary_entity = None
+ self._aliased_generations = {}
+ self._polymorphic_adapters = {}
+ self._no_yield_pers = set()
+
+ # legacy: only for query.with_polymorphic()
+ if select_statement._compile_options._with_polymorphic_adapt_map:
+ self._with_polymorphic_adapt_map = dict(
+ select_statement._compile_options._with_polymorphic_adapt_map
+ )
+ self._setup_with_polymorphics()
+
+ self.compile_options = select_statement._compile_options
+
+ if not toplevel:
+ # for subqueries, turn off eagerloads and set
+ # "render_for_subquery".
+ self.compile_options += {
+ "_enable_eagerloads": False,
+ "_render_for_subquery": True,
+ }
+
+ # determine label style. we can make different decisions here.
+ # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY
+ # rather than LABEL_STYLE_NONE, and if we can use disambiguate style
+ # for new style ORM selects too.
+ if (
+ self.use_legacy_query_style
+ and self.select_statement._label_style is LABEL_STYLE_LEGACY_ORM
+ ):
+ if not self.for_statement:
+ self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+ else:
+ self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY
+ else:
+ self.label_style = self.select_statement._label_style
+
+ if select_statement._memoized_select_entities:
+ self._memoized_entities = {
+ memoized_entities: _QueryEntity.to_compile_state(
+ self,
+ memoized_entities._raw_columns,
+ [],
+ is_current_entities=False,
+ )
+ for memoized_entities in (
+ select_statement._memoized_select_entities
+ )
+ }
+
+ # label_convention is stateful and will yield deduping keys if it
+ # sees the same key twice. therefore it's important that it is not
+ # invoked for the above "memoized" entities that aren't actually
+ # in the columns clause
+ self._label_convention = self._column_naming_convention(
+ statement._label_style, self.use_legacy_query_style
+ )
+
+ _QueryEntity.to_compile_state(
+ self,
+ select_statement._raw_columns,
+ self._entities,
+ is_current_entities=True,
+ )
+
+ self.current_path = select_statement._compile_options._current_path
+
+ self.eager_order_by = ()
+
+ if toplevel and (
+ select_statement._with_options
+ or select_statement._memoized_select_entities
+ ):
+ self.attributes = {"_unbound_load_dedupes": set()}
+
+ for (
+ memoized_entities
+ ) in select_statement._memoized_select_entities:
+ for opt in memoized_entities._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state_replaced_entities(
+ self,
+ [
+ ent
+ for ent in self._memoized_entities[
+ memoized_entities
+ ]
+ if isinstance(ent, _MapperEntity)
+ ],
+ )
+
+ for opt in self.select_statement._with_options:
+ if opt._is_compile_state:
+ opt.process_compile_state(self)
+ else:
+ self.attributes = {}
+
+ if select_statement._with_context_options:
+ for fn, key in select_statement._with_context_options:
+ fn(self)
+
+ self.primary_columns = []
+ self.secondary_columns = []
+ self.dedupe_columns = set()
+ self.eager_joins = {}
+ self.extra_criteria_entities = {}
+ self.create_eager_joins = []
+ self._fallback_from_clauses = []
+
+ # normalize the FROM clauses early by themselves, as this makes
+ # it an easier job when we need to assemble a JOIN onto these,
+ # for select.join() as well as joinedload(). As of 1.4 there are now
+ # potentially more complex sets of FROM objects here as the use
+ # of lambda statements for lazyload, load_on_pk etc. uses more
+ # cloning of the select() construct. See #6495
+ self.from_clauses = self._normalize_froms(
+ info.selectable for info in select_statement._from_obj
+ )
+
+ # this is a fairly arbitrary break into a second method,
+ # so it might be nicer to break up create_for_statement()
+ # and _setup_for_generate into three or four logical sections
+ self._setup_for_generate()
+
+ SelectState.__init__(self, self.statement, compiler, **kw)
+
+ return self
+
+ def _setup_for_generate(self):
+ query = self.select_statement
+
+ self.statement = None
+ self._join_entities = ()
+
+ if self.compile_options._set_base_alias:
+ self._set_select_from_alias()
+
+ for memoized_entities in query._memoized_select_entities:
+ if memoized_entities._setup_joins:
+ self._join(
+ memoized_entities._setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+ if memoized_entities._legacy_setup_joins:
+ self._legacy_join(
+ memoized_entities._legacy_setup_joins,
+ self._memoized_entities[memoized_entities],
+ )
+
+ if query._setup_joins:
+ self._join(query._setup_joins, self._entities)
+
+ if query._legacy_setup_joins:
+ self._legacy_join(query._legacy_setup_joins, self._entities)
+
+ current_adapter = self._get_current_adapter()
+
+ if query._where_criteria:
+ self._where_criteria = query._where_criteria
+
+ if current_adapter:
+ self._where_criteria = tuple(
+ current_adapter(crit, True)
+ for crit in self._where_criteria
+ )
+
+ # TODO: some complexity with order_by here was due to mapper.order_by.
+ # now that this is removed we can hopefully make order_by /
+ # group_by act identically to how they are in Core select.
+ self.order_by = (
+ self._adapt_col_list(query._order_by_clauses, current_adapter)
+ if current_adapter and query._order_by_clauses not in (None, False)
+ else query._order_by_clauses
+ )
+
+ if query._having_criteria:
+ self._having_criteria = tuple(
+ current_adapter(crit, True) if current_adapter else crit
+ for crit in query._having_criteria
+ )
+
+ self.group_by = (
+ self._adapt_col_list(
+ util.flatten_iterator(query._group_by_clauses), current_adapter
+ )
+ if current_adapter and query._group_by_clauses not in (None, False)
+ else query._group_by_clauses or None
+ )
+
+ if self.eager_order_by:
+ adapter = self.from_clauses[0]._target_adapter
+ self.eager_order_by = adapter.copy_and_process(self.eager_order_by)
+
+ if query._distinct_on:
+ self.distinct_on = self._adapt_col_list(
+ query._distinct_on, current_adapter
+ )
+ else:
+ self.distinct_on = ()
+
+ self.distinct = query._distinct
+
+ if query._correlate:
+ # ORM mapped entities that are mapped to joins can be passed
+ # to .correlate, so here they are broken into their component
+ # tables.
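+ # e.g. a hypothetical entity mapped against a JOIN of two tables
+ # would contribute both underlying tables here, via
+ # surface_selectables().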
+ self.correlate = tuple(
+ util.flatten_iterator(
+ sql_util.surface_selectables(s) if s is not None else None
+ for s in query._correlate
+ )
+ )
+ elif query._correlate_except is not None:
+ self.correlate_except = tuple(
+ util.flatten_iterator(
+ sql_util.surface_selectables(s) if s is not None else None
+ for s in query._correlate_except
+ )
+ )
+ elif not query._auto_correlate:
+ self.correlate = (None,)
+
+ # PART II
+
+ self._for_update_arg = query._for_update_arg
+
+ if self.compile_options._is_star and (len(self._entities) != 1):
+ raise sa_exc.CompileError(
+ "Can't generate ORM query that includes multiple expressions "
+ "at the same time as '*'; query for '*' alone if present"
+ )
+ for entity in self._entities:
+ entity.setup_compile_state(self)
+
+ for rec in self.create_eager_joins:
+ strategy = rec[0]
+ strategy(self, *rec[1:])
+
+ # else "load from discrete FROMs" mode,
+ # i.e. when each _MapperEntity has its own FROM
+
+ if self.compile_options._enable_single_crit:
+ self._adjust_for_extra_criteria()
+
+ if not self.primary_columns:
+ if self.compile_options._only_load_props:
+ raise sa_exc.InvalidRequestError(
+ "No column-based properties specified for "
+ "refresh operation. Use session.expire() "
+ "to reload collections and related items."
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Query contains no columns with which to SELECT from."
+ )
+
+ if not self.from_clauses:
+ self.from_clauses = list(self._fallback_from_clauses)
+
+ if self.order_by is False:
+ self.order_by = None
+
+ if self.multi_row_eager_loaders and self._should_nest_selectable:
+ self.statement = self._compound_eager_statement()
+ else:
+ self.statement = self._simple_statement()
+
+ if self.for_statement:
+ ezero = self._mapper_zero()
+ if ezero is not None:
+ # TODO: this goes away once we get rid of the deep entity
+ # thing
+ self.statement = self.statement._annotate(
+ {"deepentity": ezero}
+ )
+
+ @classmethod
+ def _create_entities_collection(cls, query, legacy):
+ """Creates a partial ORMSelectCompileState that includes
+ the full collection of _MapperEntity and other _QueryEntity objects.
+
+ Supports a few remaining use cases that are pre-compilation
+ but still need to gather some of the column / adaption information.
+
+ """
+ self = cls.__new__(cls)
+
+ self._entities = []
+ self._primary_entity = None
+ self._aliased_generations = {}
+ self._polymorphic_adapters = {}
+
+ compile_options = cls.default_compile_options.safe_merge(
+ query._compile_options
+ )
+ # legacy: only for query.with_polymorphic()
+ if compile_options._with_polymorphic_adapt_map:
+ self._with_polymorphic_adapt_map = dict(
+ compile_options._with_polymorphic_adapt_map
+ )
+ self._setup_with_polymorphics()
+
+ self._label_convention = self._column_naming_convention(
+ query._label_style, legacy
+ )
+
+ # entities will also set up polymorphic adapters for mappers
+ # that have with_polymorphic configured
+ _QueryEntity.to_compile_state(
+ self, query._raw_columns, self._entities, is_current_entities=True
+ )
+ return self
+
+ @classmethod
+ def determine_last_joined_entity(cls, statement):
+ setup_joins = statement._setup_joins
+
+ if not setup_joins:
+ return None
+
+ (target, onclause, from_, flags) = setup_joins[-1]
+
+ if isinstance(target, interfaces.PropComparator):
+ return target.entity
+ else:
+ return target
+
+ @classmethod
+ def all_selected_columns(cls, statement):
+ for element in statement._raw_columns:
+ if (
+ element.is_selectable
+ and "entity_namespace" in element._annotations
+ ):
+ ens = element._annotations["entity_namespace"]
+ if not ens.is_mapper and not ens.is_aliased_class:
+ for elem in _select_iterables([element]):
+ yield elem
+ else:
+ for elem in _select_iterables(ens._all_column_expressions):
+ yield elem
+ else:
+ for elem in _select_iterables([element]):
+ yield elem
+
+ @classmethod
+ def get_columns_clause_froms(cls, statement):
+ return cls._normalize_froms(
+ itertools.chain.from_iterable(
+ element._from_objects
+ if "parententity" not in element._annotations
+ else [
+ element._annotations["parententity"].__clause_element__()
+ ]
+ for element in statement._raw_columns
+ )
+ )
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm.query")
+ def from_statement(cls, statement, from_statement):
+ query = util.preloaded.orm_query
+
+ from_statement = coercions.expect(
+ roles.ReturnsRowsRole,
+ from_statement,
+ apply_propagate_attrs=statement,
+ )
+
+ stmt = query.FromStatement(statement._raw_columns, from_statement)
+
+ stmt.__dict__.update(
+ _with_options=statement._with_options,
+ _with_context_options=statement._with_context_options,
+ _execution_options=statement._execution_options,
+ _propagate_attrs=statement._propagate_attrs,
+ )
+ return stmt
+
+ def _setup_with_polymorphics(self):
+ # legacy: only for query.with_polymorphic()
+ for ext_info, wp in self._with_polymorphic_adapt_map.items():
+ self._mapper_loads_polymorphically_with(ext_info, wp._adapter)
+
+ def _set_select_from_alias(self):
+
+ query = self.select_statement # query
+
+ assert self.compile_options._set_base_alias
+ assert len(query._from_obj) == 1
+
+ adapter = self._get_select_from_alias_from_obj(query._from_obj[0])
+ if adapter:
+ self.compile_options += {"_enable_single_crit": False}
+ self._from_obj_alias = adapter
+
+ def _get_select_from_alias_from_obj(self, from_obj):
+ info = from_obj
+
+ if "parententity" in info._annotations:
+ info = info._annotations["parententity"]
+
+ if hasattr(info, "mapper"):
+ if not info.is_aliased_class:
+ raise sa_exc.ArgumentError(
+ "A selectable (FromClause) instance is "
+ "expected when the base alias is being set."
+ )
+ else:
+ return info._adapter
+
+ elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows):
+ equivs = self._all_equivs()
+ return sql_util.ColumnAdapter(info, equivs)
+ else:
+ return None
+
+ def _mapper_zero(self):
+ """return the Mapper associated with the first QueryEntity."""
+ return self._entities[0].mapper
+
+ def _entity_zero(self):
+ """Return the 'entity' (mapper or AliasedClass) associated
+ with the first QueryEntity, or alternatively the 'select from'
+ entity if specified."""
+
+ for ent in self.from_clauses:
+ if "parententity" in ent._annotations:
+ return ent._annotations["parententity"]
+ for qent in self._entities:
+ if qent.entity_zero:
+ return qent.entity_zero
+
+ return None
+
+ def _only_full_mapper_zero(self, methname):
+ if self._entities != [self._primary_entity]:
+ raise sa_exc.InvalidRequestError(
+ "%s() can only be used against "
+ "a single mapped class." % methname
+ )
+ return self._primary_entity.entity_zero
+
+ def _only_entity_zero(self, rationale=None):
+ if len(self._entities) > 1:
+ raise sa_exc.InvalidRequestError(
+ rationale
+ or "This operation requires a Query "
+ "against a single mapper."
+ )
+ return self._entity_zero()
+
+ def _all_equivs(self):
+ equivs = {}
+
+ for memoized_entities in self._memoized_entities.values():
+ for ent in [
+ ent
+ for ent in memoized_entities
+ if isinstance(ent, _MapperEntity)
+ ]:
+ equivs.update(ent.mapper._equivalent_columns)
+
+ for ent in [
+ ent for ent in self._entities if isinstance(ent, _MapperEntity)
+ ]:
+ equivs.update(ent.mapper._equivalent_columns)
+ return equivs
+
+ def _compound_eager_statement(self):
+ # for eager joins present and LIMIT/OFFSET/DISTINCT,
+ # wrap the query inside a select,
+ # then append eager joins onto that
+
+ if self.order_by:
+ # the default coercion for ORDER BY is now the OrderByRole,
+ # which adds an additional post coercion to ByOfRole in that
+ # elements are converted into label references. For the
+ # eager load / subquery wrapping case, we need to un-coerce
+ # the original expressions outside of the label references
+ # in order to have them render.
+ unwrapped_order_by = [
+ elem.element
+ if isinstance(elem, sql.elements._label_reference)
+ else elem
+ for elem in self.order_by
+ ]
+
+ order_by_col_expr = sql_util.expand_column_list_from_order_by(
+ self.primary_columns, unwrapped_order_by
+ )
+ else:
+ order_by_col_expr = []
+ unwrapped_order_by = None
+
+ # put FOR UPDATE on the inner query, where MySQL will honor it,
+ # as well as if it has an OF so PostgreSQL can use it.
+ inner = self._select_statement(
+ self.primary_columns
+ + [c for c in order_by_col_expr if c not in self.dedupe_columns],
+ self.from_clauses,
+ self._where_criteria,
+ self._having_criteria,
+ self.label_style,
+ self.order_by,
+ for_update=self._for_update_arg,
+ hints=self.select_statement._hints,
+ statement_hints=self.select_statement._statement_hints,
+ correlate=self.correlate,
+ correlate_except=self.correlate_except,
+ **self._select_args
+ )
+
+ inner = inner.alias()
+
+ equivs = self._all_equivs()
+
+ self.compound_eager_adapter = sql_util.ColumnAdapter(inner, equivs)
+
+ statement = future.select(
+ *([inner] + self.secondary_columns) # use_labels=self.labels
+ )
+ statement._label_style = self.label_style
+
+ # Oracle, however, does not allow FOR UPDATE on the subquery and the
+ # Oracle dialect ignores it there; additionally, for PostgreSQL and
+ # MySQL we expect all elements of the row to be locked, so also put it
+ # on the outside (except in the case of PG when OF is used)
+ if (
+ self._for_update_arg is not None
+ and self._for_update_arg.of is None
+ ):
+ statement._for_update_arg = self._for_update_arg
+
+ from_clause = inner
+ for eager_join in self.eager_joins.values():
+ # EagerLoader places a 'stop_on' attribute on the join,
+ # giving us a marker as to where the "splice point" of
+ # the join should be
+ from_clause = sql_util.splice_joins(
+ from_clause, eager_join, eager_join.stop_on
+ )
+
+ statement.select_from.non_generative(statement, from_clause)
+
+ if unwrapped_order_by:
+ statement.order_by.non_generative(
+ statement,
+ *self.compound_eager_adapter.copy_and_process(
+ unwrapped_order_by
+ )
+ )
+
+ statement.order_by.non_generative(statement, *self.eager_order_by)
+ return statement
+
+ def _simple_statement(self):
+
+ if (
+ self.compile_options._use_legacy_query_style
+ and (self.distinct and not self.distinct_on)
+ and self.order_by
+ ):
+ to_add = sql_util.expand_column_list_from_order_by(
+ self.primary_columns, self.order_by
+ )
+ if to_add:
+ util.warn_deprecated_20(
+ "ORDER BY columns added implicitly due to "
+ "DISTINCT is deprecated and will be removed in "
+ "SQLAlchemy 2.0. SELECT statements with DISTINCT "
+ "should be written to explicitly include the appropriate "
+ "columns in the columns clause"
+ )
+ self.primary_columns += to_add
+
+ statement = self._select_statement(
+ self.primary_columns + self.secondary_columns,
+ tuple(self.from_clauses) + tuple(self.eager_joins.values()),
+ self._where_criteria,
+ self._having_criteria,
+ self.label_style,
+ self.order_by,
+ for_update=self._for_update_arg,
+ hints=self.select_statement._hints,
+ statement_hints=self.select_statement._statement_hints,
+ correlate=self.correlate,
+ correlate_except=self.correlate_except,
+ **self._select_args
+ )
+
+ if self.eager_order_by:
+ statement.order_by.non_generative(statement, *self.eager_order_by)
+ return statement
+
+ def _select_statement(
+ self,
+ raw_columns,
+ from_obj,
+ where_criteria,
+ having_criteria,
+ label_style,
+ order_by,
+ for_update,
+ hints,
+ statement_hints,
+ correlate,
+ correlate_except,
+ limit_clause,
+ offset_clause,
+ fetch_clause,
+ fetch_clause_options,
+ distinct,
+ distinct_on,
+ prefixes,
+ suffixes,
+ group_by,
+ ):
+
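+ # note: the statement is assembled directly from already-processed
+ # components; _create_raw_select() appears to skip the usual coercion
+ # steps, since these elements originate from the ORM-enabled select()
+ # being compiled.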
+ Select = future.Select
+ statement = Select._create_raw_select(
+ _raw_columns=raw_columns,
+ _from_obj=from_obj,
+ _label_style=label_style,
+ )
+
+ if where_criteria:
+ statement._where_criteria = where_criteria
+ if having_criteria:
+ statement._having_criteria = having_criteria
+
+ if order_by:
+ statement._order_by_clauses += tuple(order_by)
+
+ if distinct_on:
+ statement.distinct.non_generative(statement, *distinct_on)
+ elif distinct:
+ statement.distinct.non_generative(statement)
+
+ if group_by:
+ statement._group_by_clauses += tuple(group_by)
+
+ statement._limit_clause = limit_clause
+ statement._offset_clause = offset_clause
+ statement._fetch_clause = fetch_clause
+ statement._fetch_clause_options = fetch_clause_options
+
+ if prefixes:
+ statement._prefixes = prefixes
+
+ if suffixes:
+ statement._suffixes = suffixes
+
+ statement._for_update_arg = for_update
+
+ if hints:
+ statement._hints = hints
+ if statement_hints:
+ statement._statement_hints = statement_hints
+
+ if correlate:
+ statement.correlate.non_generative(statement, *correlate)
+
+ if correlate_except is not None:
+ statement.correlate_except.non_generative(
+ statement, *correlate_except
+ )
+
+ return statement
+
+ def _adapt_polymorphic_element(self, element):
+ if "parententity" in element._annotations:
+ search = element._annotations["parententity"]
+ alias = self._polymorphic_adapters.get(search, None)
+ if alias:
+ return alias.adapt_clause(element)
+
+ if isinstance(element, expression.FromClause):
+ search = element
+ elif hasattr(element, "table"):
+ search = element.table
+ else:
+ return None
+
+ alias = self._polymorphic_adapters.get(search, None)
+ if alias:
+ return alias.adapt_clause(element)
+
+ def _adapt_aliased_generation(self, element):
+ # this is crazy logic that I look forward to blowing away
+ # when aliased=True is gone :)
+ if "aliased_generation" in element._annotations:
+ for adapter in self._aliased_generations.get(
+ element._annotations["aliased_generation"], ()
+ ):
+ replaced_elem = adapter.replace(element)
+ if replaced_elem is not None:
+ return replaced_elem
+
+ return None
+
+ def _adapt_col_list(self, cols, current_adapter):
+ if current_adapter:
+ return [current_adapter(o, True) for o in cols]
+ else:
+ return cols
+
+ def _get_current_adapter(self):
+
+ adapters = []
+
+ if self._from_obj_alias:
+ # used for legacy going forward for query set_ops, e.g.
+ # union(), union_all(), etc.
+ # 1.4 and previously, also used for from_self(),
+ # select_entity_from()
+ #
+ # for the "from obj" alias, apply extra rule to the
+ # 'ORM only' check, if this query were generated from a
+ # subquery of itself, i.e. _from_selectable(), apply adaption
+ # to all SQL constructs.
+ adapters.append(
+ (
+ False
+ if self.compile_options._orm_only_from_obj_alias
+ else True,
+ self._from_obj_alias.replace,
+ )
+ )
+
+ # vvvvvvvvvvvvvvv legacy vvvvvvvvvvvvvvvvvv
+ # this can totally go away when we remove join(..., aliased=True)
+ if self._aliased_generations:
+ adapters.append((False, self._adapt_aliased_generation))
+ # ^^^^^^^^^^^^^ legacy ^^^^^^^^^^^^^^^^^^^^^
+
+ # this was *hopefully* the only adapter we were going to need
+ # going forward...however, we unfortunately need _from_obj_alias
+ # for query.union(), which we can't drop
+ if self._polymorphic_adapters:
+ adapters.append((False, self._adapt_polymorphic_element))
+
+ if not adapters:
+ return None
+
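+ # the callable returned below is invoked as adapter(clause, as_filter);
+ # it traverses the clause and applies each collected adapter, either to
+ # all elements or only to ORM-annotated ones depending on always_adapt.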
+ def _adapt_clause(clause, as_filter):
+ # do we adapt all expression elements or only those
+ # tagged as 'ORM' constructs ?
+
+ def replace(elem):
+ is_orm_adapt = (
+ "_orm_adapt" in elem._annotations
+ or "parententity" in elem._annotations
+ )
+ for always_adapt, adapter in adapters:
+ if is_orm_adapt or always_adapt:
+ e = adapter(elem)
+ if e is not None:
+ return e
+
+ return visitors.replacement_traverse(clause, {}, replace)
+
+ return _adapt_clause
+
+ def _join(self, args, entities_collection):
+ for (right, onclause, from_, flags) in args:
+ isouter = flags["isouter"]
+ full = flags["full"]
+ # maybe?
+ self._reset_joinpoint()
+
+ right = inspect(right)
+ if onclause is not None:
+ onclause = inspect(onclause)
+
+ if onclause is None and isinstance(
+ right, interfaces.PropComparator
+ ):
+ # determine onclause/right_entity. still need to think
+ # about how to best organize this since we are getting:
+ #
+ #
+ # q.join(Entity, Parent.property)
+ # q.join(Parent.property)
+ # q.join(Parent.property.of_type(Entity))
+ # q.join(some_table)
+ # q.join(some_table, some_parent.c.id==some_table.c.parent_id)
+ #
+ # is this still too many choices? how do we handle this
+ # when sometimes "right" is implied and sometimes not?
+ #
+ onclause = right
+ right = None
+ elif "parententity" in right._annotations:
+ right = right._annotations["parententity"]
+
+ if onclause is None:
+ if not right.is_selectable and not hasattr(right, "mapper"):
+ raise sa_exc.ArgumentError(
+ "Expected mapped entity or "
+ "selectable/table as join target"
+ )
+
+ of_type = None
+
+ if isinstance(onclause, interfaces.PropComparator):
+ # descriptor/property given (or determined); this tells us
+ # explicitly what the expected "left" side of the join is.
+
+ of_type = getattr(onclause, "_of_type", None)
+
+ if right is None:
+ if of_type:
+ right = of_type
+ else:
+ right = onclause.property
+
+ try:
+ right = right.entity
+ except AttributeError as err:
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Join target %s does not refer to a "
+ "mapped entity" % right
+ ),
+ replace_context=err,
+ )
+
+ left = onclause._parententity
+
+ alias = self._polymorphic_adapters.get(left, None)
+
+ # could be None or could be ColumnAdapter also
+ if isinstance(alias, ORMAdapter) and alias.mapper.isa(left):
+ left = alias.aliased_class
+ onclause = getattr(left, onclause.key)
+
+ prop = onclause.property
+ if not isinstance(onclause, attributes.QueryableAttribute):
+ onclause = prop
+
+ # TODO: this is where "check for path already present"
+ # would occur. see if this still applies?
+
+ if from_ is not None:
+ if (
+ from_ is not left
+ and from_._annotations.get("parententity", None)
+ is not left
+ ):
+ raise sa_exc.InvalidRequestError(
+ "explicit from clause %s does not match left side "
+ "of relationship attribute %s"
+ % (
+ from_._annotations.get("parententity", from_),
+ onclause,
+ )
+ )
+ elif from_ is not None:
+ prop = None
+ left = from_
+ else:
+ # no descriptor/property given; we will need to figure out
+ # what the effective "left" side is
+ prop = left = None
+
+ # figure out the final "left" and "right" sides and create an
+ # ORMJoin to add to our _from_obj tuple
+ self._join_left_to_right(
+ entities_collection,
+ left,
+ right,
+ onclause,
+ prop,
+ False,
+ False,
+ isouter,
+ full,
+ )
+
+ def _legacy_join(self, args, entities_collection):
+ """consumes arguments from join() or outerjoin(), places them into a
+ consistent format with which to form the actual JOIN constructs.
+
+ """
+ for (right, onclause, left, flags) in args:
+
+ outerjoin = flags["isouter"]
+ create_aliases = flags["aliased"]
+ from_joinpoint = flags["from_joinpoint"]
+ full = flags["full"]
+ aliased_generation = flags["aliased_generation"]
+
+ # do a quick inspect to accommodate for a lambda
+ if right is not None and not isinstance(right, util.string_types):
+ right = inspect(right)
+ if onclause is not None and not isinstance(
+ onclause, util.string_types
+ ):
+ onclause = inspect(onclause)
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvv
+ if not from_joinpoint:
+ self._reset_joinpoint()
+ else:
+ prev_aliased_generation = self._joinpoint.get(
+ "aliased_generation", None
+ )
+ if not aliased_generation:
+ aliased_generation = prev_aliased_generation
+ elif prev_aliased_generation:
+ self._aliased_generations[
+ aliased_generation
+ ] = self._aliased_generations.get(
+ prev_aliased_generation, ()
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ if (
+ isinstance(
+ right, (interfaces.PropComparator, util.string_types)
+ )
+ and onclause is None
+ ):
+ onclause = right
+ right = None
+ elif "parententity" in right._annotations:
+ right = right._annotations["parententity"]
+
+ if onclause is None:
+ if not right.is_selectable and not hasattr(right, "mapper"):
+ raise sa_exc.ArgumentError(
+ "Expected mapped entity or "
+ "selectable/table as join target"
+ )
+
+ if isinstance(onclause, interfaces.PropComparator):
+ of_type = getattr(onclause, "_of_type", None)
+ else:
+ of_type = None
+
+ if isinstance(onclause, util.string_types):
+ # string given, e.g. query(Foo).join("bar").
+ # we look to the left entity or what we last joined
+ # towards
+ onclause = _entity_namespace_key(
+ inspect(self._joinpoint_zero()), onclause
+ )
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
+ # check for q.join(Class.propname, from_joinpoint=True)
+ # and Class corresponds at the mapper level to the current
+ # joinpoint. this match intentionally looks for a non-aliased
+ # class-bound descriptor as the onclause and if it matches the
+ # current joinpoint at the mapper level, it's used. This
+ # is a very old use case that is intended to make it easier
+ # to work with the aliased=True flag, which is also something
+ # that probably shouldn't exist on join() due to its high
+ # complexity/usefulness ratio
+ elif from_joinpoint and isinstance(
+ onclause, interfaces.PropComparator
+ ):
+ jp0 = self._joinpoint_zero()
+ info = inspect(jp0)
+
+ if getattr(info, "mapper", None) is onclause._parententity:
+ onclause = _entity_namespace_key(info, onclause.key)
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ if isinstance(onclause, interfaces.PropComparator):
+ # descriptor/property given (or determined); this tells us
+ # explicitly what the expected "left" side of the join is.
+ if right is None:
+ if of_type:
+ right = of_type
+ else:
+ right = onclause.property
+
+ try:
+ right = right.entity
+ except AttributeError as err:
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Join target %s does not refer to a "
+ "mapped entity" % right
+ ),
+ replace_context=err,
+ )
+
+ left = onclause._parententity
+
+ alias = self._polymorphic_adapters.get(left, None)
+
+ # could be None or could be ColumnAdapter also
+ if isinstance(alias, ORMAdapter) and alias.mapper.isa(left):
+ left = alias.aliased_class
+ onclause = getattr(left, onclause.key)
+
+ prop = onclause.property
+ if not isinstance(onclause, attributes.QueryableAttribute):
+ onclause = prop
+
+ if not create_aliases:
+ # check for this path already present.
+ # don't render in that case.
+ edge = (left, right, prop.key)
+ if edge in self._joinpoint:
+ # The child's prev reference might be stale --
+ # it could point to a parent older than the
+ # current joinpoint. If this is the case,
+ # then we need to update it and then fix the
+ # tree's spine with _update_joinpoint. Copy
+ # and then mutate the child, which might be
+ # shared by a different query object.
+ jp = self._joinpoint[edge].copy()
+ jp["prev"] = (edge, self._joinpoint)
+ self._update_joinpoint(jp)
+
+ continue
+
+ else:
+ # no descriptor/property given; we will need to figure out
+ # what the effective "left" side is
+ prop = left = None
+
+ # figure out the final "left" and "right" sides and create an
+ # ORMJoin to add to our _from_obj tuple
+ self._join_left_to_right(
+ entities_collection,
+ left,
+ right,
+ onclause,
+ prop,
+ create_aliases,
+ aliased_generation,
+ outerjoin,
+ full,
+ )
+
+ def _joinpoint_zero(self):
+ return self._joinpoint.get("_joinpoint_entity", self._entity_zero())
+
+ def _join_left_to_right(
+ self,
+ entities_collection,
+ left,
+ right,
+ onclause,
+ prop,
+ create_aliases,
+ aliased_generation,
+ outerjoin,
+ full,
+ ):
+ """given raw "left", "right", "onclause" parameters consumed from
+ a particular key within _join(), add a real ORMJoin object to
+ our _from_obj list (or augment an existing one)
+
+ """
+
+ if left is None:
+ # left not given (e.g. no relationship object/name specified)
+ # figure out the best "left" side based on our existing froms /
+ # entities
+ assert prop is None
+ (
+ left,
+ replace_from_obj_index,
+ use_entity_index,
+ ) = self._join_determine_implicit_left_side(
+ entities_collection, left, right, onclause
+ )
+ else:
+ # left is given via a relationship/name, or as explicit left side.
+ # Determine where in our
+ # "froms" list it should be spliced/appended as well as what
+ # existing entity it corresponds to.
+ (
+ replace_from_obj_index,
+ use_entity_index,
+ ) = self._join_place_explicit_left_side(entities_collection, left)
+
+ if left is right and not create_aliases:
+ raise sa_exc.InvalidRequestError(
+ "Can't construct a join from %s to %s, they "
+ "are the same entity" % (left, right)
+ )
+
+ # the right side as given often needs to be adapted. additionally
+ # a lot of things can be wrong with it. handle all that and
+ # get back the new effective "right" side
+ r_info, right, onclause = self._join_check_and_adapt_right_side(
+ left, right, onclause, prop, create_aliases, aliased_generation
+ )
+
+ if not r_info.is_selectable:
+ extra_criteria = self._get_extra_criteria(r_info)
+ else:
+ extra_criteria = ()
+
+ if replace_from_obj_index is not None:
+ # splice into an existing element in the
+ # self._from_obj list
+ left_clause = self.from_clauses[replace_from_obj_index]
+
+ self.from_clauses = (
+ self.from_clauses[:replace_from_obj_index]
+ + [
+ _ORMJoin(
+ left_clause,
+ right,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ _extra_criteria=extra_criteria,
+ )
+ ]
+ + self.from_clauses[replace_from_obj_index + 1 :]
+ )
+ else:
+ # add a new element to the self._from_obj list
+ if use_entity_index is not None:
+ # make use of _MapperEntity selectable, which is usually
+ # entity_zero.selectable, but if with_polymorphic() were used
+ # might be distinct
+ assert isinstance(
+ entities_collection[use_entity_index], _MapperEntity
+ )
+ left_clause = entities_collection[use_entity_index].selectable
+ else:
+ left_clause = left
+
+ self.from_clauses = self.from_clauses + [
+ _ORMJoin(
+ left_clause,
+ r_info,
+ onclause,
+ isouter=outerjoin,
+ full=full,
+ _extra_criteria=extra_criteria,
+ )
+ ]
+
+ def _join_determine_implicit_left_side(
+ self, entities_collection, left, right, onclause
+ ):
+ """When join conditions don't express the left side explicitly,
+ determine if an existing FROM or entity in this query
+ can serve as the left hand side.
+
+ """
+
+ # when we are here, it means join() was called without an ORM-
+ # specific way of telling us what the "left" side is, e.g.:
+ #
+ # join(RightEntity)
+ #
+ # or
+ #
+ # join(RightEntity, RightEntity.foo == LeftEntity.bar)
+ #
+
+ r_info = inspect(right)
+
+ replace_from_obj_index = use_entity_index = None
+
+ if self.from_clauses:
+ # we have a list of FROMs already. So by definition this
+ # join has to connect to one of those FROMs.
+
+ indexes = sql_util.find_left_clause_to_join_from(
+ self.from_clauses, r_info.selectable, onclause
+ )
+
+ if len(indexes) == 1:
+ replace_from_obj_index = indexes[0]
+ left = self.from_clauses[replace_from_obj_index]
+ elif len(indexes) > 1:
+ raise sa_exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity."
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity." % (right,)
+ )
+
+ elif entities_collection:
+ # we have no explicit FROMs, so the implicit left has to
+ # come from our list of entities.
+
+ potential = {}
+ for entity_index, ent in enumerate(entities_collection):
+ entity = ent.entity_zero_or_selectable
+ if entity is None:
+ continue
+ ent_info = inspect(entity)
+ if ent_info is r_info: # left and right are the same, skip
+ continue
+
+ # by using a dictionary with the selectables as keys this
+ # de-duplicates those selectables as occurs when the query is
+ # against a series of columns from the same selectable
+ if isinstance(ent, _MapperEntity):
+ potential[ent.selectable] = (entity_index, entity)
+ else:
+ potential[ent_info.selectable] = (None, entity)
+
+ all_clauses = list(potential.keys())
+ indexes = sql_util.find_left_clause_to_join_from(
+ all_clauses, r_info.selectable, onclause
+ )
+
+ if len(indexes) == 1:
+ use_entity_index, left = potential[all_clauses[indexes[0]]]
+ elif len(indexes) > 1:
+ raise sa_exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity."
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already "
+ "to help resolve the ambiguity." % (right,)
+ )
+ else:
+ raise sa_exc.InvalidRequestError(
+ "No entities to join from; please use "
+ "select_from() to establish the left "
+ "entity/selectable of this join"
+ )
+
+ return left, replace_from_obj_index, use_entity_index
+
+ def _join_place_explicit_left_side(self, entities_collection, left):
+ """When join conditions express a left side explicitly, determine
+ where in our existing list of FROM clauses we should join towards,
+ or if we need to make a new join, and if so is it from one of our
+ existing entities.
+
+ """
+
+ # when we are here, it means join() was called with an indicator
+ # as to an exact left side, which means a path to a
+ # RelationshipProperty was given, e.g.:
+ #
+ # join(RightEntity, LeftEntity.right)
+ #
+ # or
+ #
+ # join(LeftEntity.right)
+ #
+ # as well as string forms:
+ #
+ # join(RightEntity, "right")
+ #
+ # etc.
+ #
+
+ replace_from_obj_index = use_entity_index = None
+
+ l_info = inspect(left)
+ if self.from_clauses:
+ indexes = sql_util.find_left_clause_that_matches_given(
+ self.from_clauses, l_info.selectable
+ )
+
+ if len(indexes) > 1:
+ raise sa_exc.InvalidRequestError(
+ "Can't identify which entity in which to assign the "
+ "left side of this join. Please use a more specific "
+ "ON clause."
+ )
+
+ # have an index, means the left side is already present in
+ # an existing FROM in the self._from_obj tuple
+ if indexes:
+ replace_from_obj_index = indexes[0]
+
+ # no index, means we need to add a new element to the
+ # self._from_obj tuple
+
+ # no from element present, so we will have to add to the
+ # self._from_obj tuple. Determine if this left side matches up
+ # with existing mapper entities, in which case we want to apply the
+ # aliasing / adaptation rules present on that entity if any
+ if (
+ replace_from_obj_index is None
+ and entities_collection
+ and hasattr(l_info, "mapper")
+ ):
+ for idx, ent in enumerate(entities_collection):
+ # TODO: should we be checking for multiple mapper entities
+ # matching?
+ if isinstance(ent, _MapperEntity) and ent.corresponds_to(left):
+ use_entity_index = idx
+ break
+
+ return replace_from_obj_index, use_entity_index
+
+ def _join_check_and_adapt_right_side(
+ self, left, right, onclause, prop, create_aliases, aliased_generation
+ ):
+ """transform the "right" side of the join as well as the onclause
+ according to polymorphic mapping translations, aliasing on the query
+ or on the join, special cases where the right and left side have
+ overlapping tables.
+
+ """
+
+ l_info = inspect(left)
+ r_info = inspect(right)
+
+ overlap = False
+ if not create_aliases:
+ right_mapper = getattr(r_info, "mapper", None)
+ # if the target is a joined inheritance mapping,
+ # be more liberal about auto-aliasing.
+ if right_mapper and (
+ right_mapper.with_polymorphic
+ or isinstance(right_mapper.persist_selectable, expression.Join)
+ ):
+ for from_obj in self.from_clauses or [l_info.selectable]:
+ if sql_util.selectables_overlap(
+ l_info.selectable, from_obj
+ ) and sql_util.selectables_overlap(
+ from_obj, r_info.selectable
+ ):
+ overlap = True
+ break
+
+ if (
+ overlap or not create_aliases
+ ) and l_info.selectable is r_info.selectable:
+ raise sa_exc.InvalidRequestError(
+ "Can't join table/selectable '%s' to itself"
+ % l_info.selectable
+ )
+
+ right_mapper, right_selectable, right_is_aliased = (
+ getattr(r_info, "mapper", None),
+ r_info.selectable,
+ getattr(r_info, "is_aliased_class", False),
+ )
+
+ if (
+ right_mapper
+ and prop
+ and not right_mapper.common_parent(prop.mapper)
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Join target %s does not correspond to "
+ "the right side of join condition %s" % (right, onclause)
+ )
+
+ # _join_entities is used as a hint for single-table inheritance
+ # purposes at the moment
+ if hasattr(r_info, "mapper"):
+ self._join_entities += (r_info,)
+
+ need_adapter = False
+
+ # test for joining to an unmapped selectable as the target
+ if r_info.is_clause_element:
+
+ if prop:
+ right_mapper = prop.mapper
+
+ if right_selectable._is_lateral:
+ # orm_only is disabled to suit the case where we have to
+ # adapt an explicit correlate(Entity) - the select() loses
+ # the ORM-ness in this case right now, ideally it would not
+ current_adapter = self._get_current_adapter()
+ if current_adapter is not None:
+ # TODO: we had orm_only=False here before, removing
+ # it didn't break things. if we identify the rationale,
+ # may need to apply "_orm_only" annotation here.
+ right = current_adapter(right, True)
+
+ elif prop:
+ # joining to selectable with a mapper property given
+ # as the ON clause
+
+ if not right_selectable.is_derived_from(
+ right_mapper.persist_selectable
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Selectable '%s' is not derived from '%s'"
+ % (
+ right_selectable.description,
+ right_mapper.persist_selectable.description,
+ )
+ )
+
+ # if the destination selectable is a plain select(),
+ # turn it into an alias().
+ if isinstance(right_selectable, expression.SelectBase):
+ right_selectable = coercions.expect(
+ roles.FromClauseRole, right_selectable
+ )
+ need_adapter = True
+
+ # make the right hand side target into an ORM entity
+ right = aliased(right_mapper, right_selectable)
+
+ util.warn_deprecated(
+ "An alias is being generated automatically against "
+ "joined entity %s for raw clauseelement, which is "
+ "deprecated and will be removed in a later release. "
+ "Use the aliased() "
+ "construct explicitly, see the linked example."
+ % right_mapper,
+ "1.4",
+ code="xaj1",
+ )
+
+ elif create_aliases:
+ # it *could* work, but it doesn't right now and I'd rather
+ # get rid of aliased=True completely
+ raise sa_exc.InvalidRequestError(
+ "The aliased=True parameter on query.join() only works "
+ "with an ORM entity, not a plain selectable, as the "
+ "target."
+ )
+
+ # test for overlap:
+ # orm/inheritance/relationships.py
+ # SelfReferentialM2MTest
+ aliased_entity = right_mapper and not right_is_aliased and overlap
+
+ if not need_adapter and (create_aliases or aliased_entity):
+ # there are a few places in the ORM that automatic aliasing
+ # is still desirable, and can't be automatic with a Core
+ # only approach. For illustrations of "overlaps" see
+ # test/orm/inheritance/test_relationships.py. There are also
+ # general overlap cases with many-to-many tables where automatic
+ # aliasing is desirable.
+ right = aliased(right, flat=True)
+ need_adapter = True
+
+ if not create_aliases:
+ util.warn(
+ "An alias is being generated automatically against "
+ "joined entity %s due to overlapping tables. This is a "
+ "legacy pattern which may be "
+ "deprecated in a later release. Use the "
+ "aliased(<entity>, flat=True) "
+ "construct explicitly, see the linked example."
+ % right_mapper,
+ code="xaj2",
+ )
+
+ if need_adapter:
+ assert right_mapper
+
+ adapter = ORMAdapter(
+ right, equivalents=right_mapper._equivalent_columns
+ )
+
+ # if an alias() on the right side was generated,
+ # which is intended to wrap the right side in a subquery,
+ # ensure that columns retrieved from this target in the result
+ # set are also adapted.
+ if not create_aliases:
+ self._mapper_loads_polymorphically_with(right_mapper, adapter)
+ elif aliased_generation:
+ adapter._debug = True
+ self._aliased_generations[aliased_generation] = (
+ adapter,
+ ) + self._aliased_generations.get(aliased_generation, ())
+ elif (
+ not r_info.is_clause_element
+ and not right_is_aliased
+ and right_mapper.with_polymorphic
+ and isinstance(
+ right_mapper._with_polymorphic_selectable,
+ expression.AliasedReturnsRows,
+ )
+ ):
+ # for the case where the target mapper has a with_polymorphic
+ # set up, ensure an adapter is set up for criteria that works
+ # against this mapper. Previously, this logic used to
+ # use the "create_aliases or aliased_entity" case to generate
+ # an aliased() object, but this creates an alias that isn't
+ # strictly necessary.
+ # see test/orm/test_core_compilation.py
+ # ::RelNaturalAliasedJoinsTest::test_straight
+ # and similar
+ self._mapper_loads_polymorphically_with(
+ right_mapper,
+ sql_util.ColumnAdapter(
+ right_mapper.selectable,
+ right_mapper._equivalent_columns,
+ ),
+ )
+ # if the onclause is a ClauseElement, adapt it with any
+ # adapters that are in place right now
+ if isinstance(onclause, expression.ClauseElement):
+ current_adapter = self._get_current_adapter()
+ if current_adapter:
+ onclause = current_adapter(onclause, True)
+
+ # if joining on a MapperProperty path,
+ # track the path to prevent redundant joins
+ if not create_aliases and prop:
+ self._update_joinpoint(
+ {
+ "_joinpoint_entity": right,
+ "prev": ((left, right, prop.key), self._joinpoint),
+ "aliased_generation": aliased_generation,
+ }
+ )
+ else:
+ self._joinpoint = {
+ "_joinpoint_entity": right,
+ "aliased_generation": aliased_generation,
+ }
+
+ return inspect(right), right, onclause
+
+ def _update_joinpoint(self, jp):
+ self._joinpoint = jp
+ # copy backwards to the root of the _joinpath
+ # dict, so that no existing dict in the path is mutated
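+ # (each joinpoint dict may carry "_joinpoint_entity",
+ # "aliased_generation" and a "prev" link, plus (left, right, prop.key)
+ # edge keys pointing at child joinpoints)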
+ while "prev" in jp:
+ f, prev = jp["prev"]
+ prev = dict(prev)
+ prev[f] = jp.copy()
+ jp["prev"] = (f, prev)
+ jp = prev
+ self._joinpath = jp
+
+ def _reset_joinpoint(self):
+ self._joinpoint = self._joinpath
+
+ @property
+ def _select_args(self):
+ return {
+ "limit_clause": self.select_statement._limit_clause,
+ "offset_clause": self.select_statement._offset_clause,
+ "distinct": self.distinct,
+ "distinct_on": self.distinct_on,
+ "prefixes": self.select_statement._prefixes,
+ "suffixes": self.select_statement._suffixes,
+ "group_by": self.group_by or None,
+ "fetch_clause": self.select_statement._fetch_clause,
+ "fetch_clause_options": (
+ self.select_statement._fetch_clause_options
+ ),
+ }
+
+ @property
+ def _should_nest_selectable(self):
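+ # when eager joins are combined with LIMIT/OFFSET/DISTINCT/GROUP BY,
+ # the base query has to be wrapped in a subquery (see
+ # _compound_eager_statement()) so that those modifiers apply to the
+ # primary rows rather than to the joined eager rows.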
+ kwargs = self._select_args
+ return (
+ kwargs.get("limit_clause") is not None
+ or kwargs.get("offset_clause") is not None
+ or kwargs.get("distinct", False)
+ or kwargs.get("distinct_on", ())
+ or kwargs.get("group_by", False)
+ )
+
+ def _get_extra_criteria(self, ext_info):
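+ # entity criteria of this kind are registered in the compiler's global
+ # attributes keyed on ("additional_entity_criteria", mapper),
+ # presumably by the with_loader_criteria() option; gather the ones
+ # that apply to this entity.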
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in self.global_attributes:
+ return tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in self.global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if (ae.include_aliases or ae.entity is ext_info)
+ and ae._should_include(self)
+ )
+ else:
+ return ()
+
+ def _adjust_for_extra_criteria(self):
+ """Apply extra criteria filtering.
+
+ For all distinct single-table-inheritance mappers represented in
+ the columns clause of this query, as well as the "select from entity",
+ add criterion to the WHERE
+ clause of the given QueryContext such that only the appropriate
+ subtypes are selected from the total results.
+
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ associated with the global context.
+
+ """
+
+ for fromclause in self.from_clauses:
+ ext_info = fromclause._annotations.get("parententity", None)
+ if (
+ ext_info
+ and (
+ ext_info.mapper._single_table_criterion is not None
+ or ("additional_entity_criteria", ext_info.mapper)
+ in self.global_attributes
+ )
+ and ext_info not in self.extra_criteria_entities
+ ):
+
+ self.extra_criteria_entities[ext_info] = (
+ ext_info,
+ ext_info._adapter if ext_info.is_aliased_class else None,
+ )
+
+ search = set(self.extra_criteria_entities.values())
+
+ for (ext_info, adapter) in search:
+ if ext_info in self._join_entities:
+ continue
+
+ single_crit = ext_info.mapper._single_table_criterion
+
+ if self.compile_options._for_refresh_state:
+ additional_entity_criteria = []
+ else:
+ additional_entity_criteria = self._get_extra_criteria(ext_info)
+
+ if single_crit is not None:
+ additional_entity_criteria += (single_crit,)
+
+ current_adapter = self._get_current_adapter()
+ for crit in additional_entity_criteria:
+ if adapter:
+ crit = adapter.traverse(crit)
+
+ if current_adapter:
+ crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
+ crit = current_adapter(crit, False)
+ self._where_criteria += (crit,)
+
+
+def _column_descriptions(
+ query_or_select_stmt, compile_state=None, legacy=False
+):
+ if compile_state is None:
+ compile_state = ORMSelectCompileState._create_entities_collection(
+ query_or_select_stmt, legacy=legacy
+ )
+ ctx = compile_state
+ return [
+ {
+ "name": ent._label_name,
+ "type": ent.type,
+ "aliased": getattr(insp_ent, "is_aliased_class", False),
+ "expr": ent.expr,
+ "entity": getattr(insp_ent, "entity", None)
+ if ent.entity_zero is not None and not insp_ent.is_clause_element
+ else None,
+ }
+ for ent, insp_ent in [
+ (
+ _ent,
+ (
+ inspect(_ent.entity_zero)
+ if _ent.entity_zero is not None
+ else None
+ ),
+ )
+ for _ent in ctx._entities
+ ]
+ ]
+
+
+def _legacy_filter_by_entity_zero(query_or_augmented_select):
+ self = query_or_augmented_select
+ if self._legacy_setup_joins:
+ _last_joined_entity = self._last_joined_entity
+ if _last_joined_entity is not None:
+ return _last_joined_entity
+
+ if self._from_obj and "parententity" in self._from_obj[0]._annotations:
+ return self._from_obj[0]._annotations["parententity"]
+
+ return _entity_from_pre_ent_zero(self)
+
+
+def _entity_from_pre_ent_zero(query_or_augmented_select):
+ self = query_or_augmented_select
+ if not self._raw_columns:
+ return None
+
+ ent = self._raw_columns[0]
+
+ if "parententity" in ent._annotations:
+ return ent._annotations["parententity"]
+ elif isinstance(ent, ORMColumnsClauseRole):
+ return ent.entity
+ elif "bundle" in ent._annotations:
+ return ent._annotations["bundle"]
+ else:
+ return ent
+
+
+def _legacy_determine_last_joined_entity(setup_joins, entity_zero):
+ """given the legacy_setup_joins collection at a point in time,
+ figure out what the "filter by entity" would be in terms
+ of those joins.
+
+ in 2.0 this logic should hopefully be much simpler as there will
+ be far fewer ways to specify joins with the ORM
+
+ """
+
+ if not setup_joins:
+ return entity_zero
+
+ # CAN BE REMOVED IN 2.0:
+ # 1. from_joinpoint
+ # 2. aliased_generation
+ # 3. aliased
+ # 4. any treating of prop as str
+ # 5. tuple madness
+ # 6. won't need recursive call anymore without #4
+ # 7. therefore can pass in just the last setup_joins record,
+ # don't need entity_zero
+
+ (right, onclause, left_, flags) = setup_joins[-1]
+
+ from_joinpoint = flags["from_joinpoint"]
+
+ if onclause is None and isinstance(
+ right, (str, interfaces.PropComparator)
+ ):
+ onclause = right
+ right = None
+
+ if right is not None and "parententity" in right._annotations:
+ right = right._annotations["parententity"].entity
+
+ if right is not None:
+ last_entity = right
+ insp = inspect(last_entity)
+ if insp.is_clause_element or insp.is_aliased_class or insp.is_mapper:
+ return insp
+
+ last_entity = onclause
+ if isinstance(last_entity, interfaces.PropComparator):
+ return last_entity.entity
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if isinstance(onclause, str):
+ if from_joinpoint:
+ prev = _legacy_determine_last_joined_entity(
+ setup_joins[0:-1], entity_zero
+ )
+ else:
+ prev = entity_zero
+
+ if prev is None:
+ return None
+
+ prev = inspect(prev)
+ attr = getattr(prev.entity, onclause, None)
+ if attr is not None:
+ return attr.property.entity
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ return None
+
+
+class _QueryEntity(object):
+ """represent an entity column returned within a Query result."""
+
+ __slots__ = ()
+
+ _non_hashable_value = False
+ _null_column_type = False
+ use_id_for_hash = False
+
+ @classmethod
+ def to_compile_state(
+ cls, compile_state, entities, entities_collection, is_current_entities
+ ):
+
+ for idx, entity in enumerate(entities):
+ if entity._is_lambda_element:
+ if entity._is_sequence:
+ cls.to_compile_state(
+ compile_state,
+ entity._resolved,
+ entities_collection,
+ is_current_entities,
+ )
+ continue
+ else:
+ entity = entity._resolved
+
+ if entity.is_clause_element:
+ if entity.is_selectable:
+ if "parententity" in entity._annotations:
+ _MapperEntity(
+ compile_state,
+ entity,
+ entities_collection,
+ is_current_entities,
+ )
+ else:
+ _ColumnEntity._for_columns(
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
+ is_current_entities,
+ )
+ else:
+ if entity._annotations.get("bundle", False):
+ _BundleEntity(
+ compile_state,
+ entity,
+ entities_collection,
+ is_current_entities,
+ )
+ elif entity._is_clause_list:
+ # this is legacy only - test_composites.py
+ # test_query_cols_legacy
+ _ColumnEntity._for_columns(
+ compile_state,
+ entity._select_iterable,
+ entities_collection,
+ idx,
+ is_current_entities,
+ )
+ else:
+ _ColumnEntity._for_columns(
+ compile_state,
+ [entity],
+ entities_collection,
+ idx,
+ is_current_entities,
+ )
+ elif entity.is_bundle:
+ _BundleEntity(compile_state, entity, entities_collection)
+
+ return entities_collection
+
+
+class _MapperEntity(_QueryEntity):
+ """mapper/class/AliasedClass entity"""
+
+ __slots__ = (
+ "expr",
+ "mapper",
+ "entity_zero",
+ "is_aliased_class",
+ "path",
+ "_extra_entities",
+ "_label_name",
+ "_with_polymorphic_mappers",
+ "selectable",
+ "_polymorphic_discriminator",
+ )
+
+ def __init__(
+ self, compile_state, entity, entities_collection, is_current_entities
+ ):
+ entities_collection.append(self)
+ if is_current_entities:
+ if compile_state._primary_entity is None:
+ compile_state._primary_entity = self
+ compile_state._has_mapper_entities = True
+ compile_state._has_orm_entities = True
+
+ entity = entity._annotations["parententity"]
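+ # bare attribute access; _post_inspect appears to exist so that the
+ # entity can run any deferred configuration before being used here.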
+ entity._post_inspect
+ ext_info = self.entity_zero = entity
+ entity = ext_info.entity
+
+ self.expr = entity
+ self.mapper = mapper = ext_info.mapper
+
+ self._extra_entities = (self.expr,)
+
+ if ext_info.is_aliased_class:
+ self._label_name = ext_info.name
+ else:
+ self._label_name = mapper.class_.__name__
+
+ self.is_aliased_class = ext_info.is_aliased_class
+ self.path = ext_info._path_registry
+
+ if ext_info in compile_state._with_polymorphic_adapt_map:
+ # this codepath occurs only if query.with_polymorphic() were
+ # used
+
+ wp = inspect(compile_state._with_polymorphic_adapt_map[ext_info])
+
+ if self.is_aliased_class:
+ # TODO: invalidrequest ?
+ raise NotImplementedError(
+ "Can't use with_polymorphic() against an Aliased object"
+ )
+
+ mappers, from_obj = mapper._with_polymorphic_args(
+ wp.with_polymorphic_mappers, wp.selectable
+ )
+
+ self._with_polymorphic_mappers = mappers
+ self.selectable = from_obj
+ self._polymorphic_discriminator = wp.polymorphic_on
+
+ else:
+ self.selectable = ext_info.selectable
+ self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers
+ self._polymorphic_discriminator = ext_info.polymorphic_on
+
+ if (
+ mapper.with_polymorphic
+ # controversy - only if inheriting mapper is also
+ # polymorphic?
+ # or (mapper.inherits and mapper.inherits.with_polymorphic)
+ or mapper.inherits
+ or mapper._requires_row_aliasing
+ ):
+ compile_state._create_with_polymorphic_adapter(
+ ext_info, self.selectable
+ )
+
+ supports_single_entity = True
+
+ _non_hashable_value = True
+ use_id_for_hash = True
+
+ @property
+ def type(self):
+ return self.mapper.class_
+
+ @property
+ def entity_zero_or_selectable(self):
+ return self.entity_zero
+
+ def corresponds_to(self, entity):
+ return _entity_corresponds_to(self.entity_zero, entity)
+
+ def _get_entity_clauses(self, compile_state):
+
+ adapter = None
+
+ if not self.is_aliased_class:
+ if compile_state._polymorphic_adapters:
+ adapter = compile_state._polymorphic_adapters.get(
+ self.mapper, None
+ )
+ else:
+ adapter = self.entity_zero._adapter
+
+ if adapter:
+ if compile_state._from_obj_alias:
+ ret = adapter.wrap(compile_state._from_obj_alias)
+ else:
+ ret = adapter
+ else:
+ ret = compile_state._from_obj_alias
+
+ return ret
+
+ def row_processor(self, context, result):
+ compile_state = context.compile_state
+ adapter = self._get_entity_clauses(compile_state)
+
+ if compile_state.compound_eager_adapter and adapter:
+ adapter = adapter.wrap(compile_state.compound_eager_adapter)
+ elif not adapter:
+ adapter = compile_state.compound_eager_adapter
+
+ if compile_state._primary_entity is self:
+ only_load_props = compile_state.compile_options._only_load_props
+ refresh_state = context.refresh_state
+ else:
+ only_load_props = refresh_state = None
+
+ _instance = loading._instance_processor(
+ self,
+ self.mapper,
+ context,
+ result,
+ self.path,
+ adapter,
+ only_load_props=only_load_props,
+ refresh_state=refresh_state,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+
+ return _instance, self._label_name, self._extra_entities
+
+ def setup_compile_state(self, compile_state):
+
+ adapter = self._get_entity_clauses(compile_state)
+
+ single_table_crit = self.mapper._single_table_criterion
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
+ ext_info = self.entity_zero
+ compile_state.extra_criteria_entities[ext_info] = (
+ ext_info,
+ ext_info._adapter if ext_info.is_aliased_class else None,
+ )
+
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ self,
+ self.path,
+ adapter,
+ compile_state.primary_columns,
+ with_polymorphic=self._with_polymorphic_mappers,
+ only_load_props=compile_state.compile_options._only_load_props,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+
+ compile_state._fallback_from_clauses.append(self.selectable)
+
+
+class _BundleEntity(_QueryEntity):
+
+ _extra_entities = ()
+
+ __slots__ = (
+ "bundle",
+ "expr",
+ "type",
+ "_label_name",
+ "_entities",
+ "supports_single_entity",
+ )
+
+ def __init__(
+ self,
+ compile_state,
+ expr,
+ entities_collection,
+ is_current_entities,
+ setup_entities=True,
+ parent_bundle=None,
+ ):
+ compile_state._has_orm_entities = True
+
+ expr = expr._annotations["bundle"]
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ entities_collection.append(self)
+
+ if isinstance(
+ expr, (attributes.QueryableAttribute, interfaces.PropComparator)
+ ):
+ bundle = expr.__clause_element__()
+ else:
+ bundle = expr
+
+ self.bundle = self.expr = bundle
+ self.type = type(bundle)
+ self._label_name = bundle.name
+ self._entities = []
+
+ if setup_entities:
+ for expr in bundle.exprs:
+ if "bundle" in expr._annotations:
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ is_current_entities,
+ parent_bundle=self,
+ )
+ elif isinstance(expr, Bundle):
+ _BundleEntity(
+ compile_state,
+ expr,
+ entities_collection,
+ is_current_entities,
+ parent_bundle=self,
+ )
+ else:
+ _ORMColumnEntity._for_columns(
+ compile_state,
+ [expr],
+ entities_collection,
+ None,
+ is_current_entities,
+ parent_bundle=self,
+ )
+
+ self.supports_single_entity = self.bundle.single_entity
+ if (
+ self.supports_single_entity
+ and not compile_state.compile_options._use_legacy_query_style
+ ):
+ util.warn_deprecated_20(
+ "The Bundle.single_entity flag has no effect when "
+ "using 2.0 style execution."
+ )
+
+ @property
+ def mapper(self):
+ ezero = self.entity_zero
+ if ezero is not None:
+ return ezero.mapper
+ else:
+ return None
+
+ @property
+ def entity_zero(self):
+ for ent in self._entities:
+ ezero = ent.entity_zero
+ if ezero is not None:
+ return ezero
+ else:
+ return None
+
+ def corresponds_to(self, entity):
+ # TODO: we might be able to implement this but for now
+ # we are working around it
+ return False
+
+ @property
+ def entity_zero_or_selectable(self):
+ for ent in self._entities:
+ ezero = ent.entity_zero_or_selectable
+ if ezero is not None:
+ return ezero
+ else:
+ return None
+
+ def setup_compile_state(self, compile_state):
+ for ent in self._entities:
+ ent.setup_compile_state(compile_state)
+
+ def row_processor(self, context, result):
+ procs, labels, extra = zip(
+ *[ent.row_processor(context, result) for ent in self._entities]
+ )
+
+ proc = self.bundle.create_row_processor(context.query, procs, labels)
+
+ return proc, self._label_name, self._extra_entities
+
+
+class _ColumnEntity(_QueryEntity):
+ __slots__ = (
+ "_fetch_column",
+ "_row_processor",
+ "raw_column_index",
+ "translate_raw_column",
+ )
+
+ @classmethod
+ def _for_columns(
+ cls,
+ compile_state,
+ columns,
+ entities_collection,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=None,
+ ):
+ for column in columns:
+ annotations = column._annotations
+ if "parententity" in annotations:
+ _entity = annotations["parententity"]
+ else:
+ _entity = sql_util.extract_first_column_annotation(
+ column, "parententity"
+ )
+
+ if _entity:
+ if "identity_token" in column._annotations:
+ _IdentityTokenEntity(
+ compile_state,
+ column,
+ entities_collection,
+ _entity,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=parent_bundle,
+ )
+ else:
+ _ORMColumnEntity(
+ compile_state,
+ column,
+ entities_collection,
+ _entity,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=parent_bundle,
+ )
+ else:
+ _RawColumnEntity(
+ compile_state,
+ column,
+ entities_collection,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=parent_bundle,
+ )
+
+ @property
+ def type(self):
+ return self.column.type
+
+ @property
+ def _non_hashable_value(self):
+ return not self.column.type.hashable
+
+ @property
+ def _null_column_type(self):
+ return self.column.type._isnull
+
+ def row_processor(self, context, result):
+ compile_state = context.compile_state
+
+ # the resulting callable is entirely cacheable so just return
+ # it if we already made one
+ if self._row_processor is not None:
+ getter, label_name, extra_entities = self._row_processor
+ if self.translate_raw_column:
+ extra_entities += (
+ result.context.invoked_statement._raw_columns[
+ self.raw_column_index
+ ],
+ )
+
+ return getter, label_name, extra_entities
+
+ # retrieve the column that would have been set up in
+ # setup_compile_state, to avoid doing redundant work
+ if self._fetch_column is not None:
+ column = self._fetch_column
+ else:
+ # fetch_column will be None when we are doing a from_statement
+ # and setup_compile_state may not have been called.
+ column = self.column
+
+ # previously, the RawColumnEntity didn't look for from_obj_alias
+ # however I can't think of a case where we would be here and
+ # we'd want to ignore it if this is the from_statement use case.
+ # it's not really a use case to have raw columns + from_statement
+ if compile_state._from_obj_alias:
+ column = compile_state._from_obj_alias.columns[column]
+
+ if column._annotations:
+ # annotated columns perform more slowly in compiler and
+ # result due to the __eq__() method, so use deannotated
+ column = column._deannotate()
+
+ if compile_state.compound_eager_adapter:
+ column = compile_state.compound_eager_adapter.columns[column]
+
+ getter = result._getter(column)
+
+ ret = getter, self._label_name, self._extra_entities
+ self._row_processor = ret
+
+ if self.translate_raw_column:
+ extra_entities = self._extra_entities + (
+ result.context.invoked_statement._raw_columns[
+ self.raw_column_index
+ ],
+ )
+ return getter, self._label_name, extra_entities
+ else:
+ return ret
+
+
+class _RawColumnEntity(_ColumnEntity):
+ entity_zero = None
+ mapper = None
+ supports_single_entity = False
+
+ __slots__ = (
+ "expr",
+ "column",
+ "_label_name",
+ "entity_zero_or_selectable",
+ "_extra_entities",
+ )
+
+ def __init__(
+ self,
+ compile_state,
+ column,
+ entities_collection,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=None,
+ ):
+ self.expr = column
+ self.raw_column_index = raw_column_index
+ self.translate_raw_column = raw_column_index is not None
+
+ if column._is_star:
+ compile_state.compile_options += {"_is_star": True}
+
+ if not is_current_entities or column._is_text_clause:
+ self._label_name = None
+ else:
+ self._label_name = compile_state._label_convention(column)
+
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ entities_collection.append(self)
+
+ self.column = column
+ self.entity_zero_or_selectable = (
+ self.column._from_objects[0] if self.column._from_objects else None
+ )
+ self._extra_entities = (self.expr, self.column)
+ self._fetch_column = self._row_processor = None
+
+ def corresponds_to(self, entity):
+ return False
+
+ def setup_compile_state(self, compile_state):
+ current_adapter = compile_state._get_current_adapter()
+ if current_adapter:
+ column = current_adapter(self.column, False)
+ else:
+ column = self.column
+
+ if column._annotations:
+ # annotated columns perform more slowly in compiler and
+ # result due to the __eq__() method, so use deannotated
+ column = column._deannotate()
+
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+ self._fetch_column = column
+
+
+class _ORMColumnEntity(_ColumnEntity):
+ """Column/expression based entity."""
+
+ supports_single_entity = False
+
+ __slots__ = (
+ "expr",
+ "mapper",
+ "column",
+ "_label_name",
+ "entity_zero_or_selectable",
+ "entity_zero",
+ "_extra_entities",
+ )
+
+ def __init__(
+ self,
+ compile_state,
+ column,
+ entities_collection,
+ parententity,
+ raw_column_index,
+ is_current_entities,
+ parent_bundle=None,
+ ):
+ annotations = column._annotations
+
+ _entity = parententity
+
+ # an AliasedClass won't have proxy_key in the annotations for
+ # a column if it was acquired using the class' adapter directly,
+ # such as using AliasedInsp._adapt_element(). this occurs
+ # within internal loaders.
+
+ orm_key = annotations.get("proxy_key", None)
+ proxy_owner = annotations.get("proxy_owner", _entity)
+ if orm_key:
+ self.expr = getattr(proxy_owner.entity, orm_key)
+ self.translate_raw_column = False
+ else:
+ # if orm_key is not present, that means this is an ad-hoc
+ # SQL ColumnElement, like a CASE() or other expression.
+ # include this column position from the invoked statement
+ # in the ORM-level ResultSetMetaData on each execute, so that
+ # it can be targeted by identity after caching
+ self.expr = column
+ self.translate_raw_column = raw_column_index is not None
+
+ self.raw_column_index = raw_column_index
+
+ if is_current_entities:
+ self._label_name = compile_state._label_convention(
+ column, col_name=orm_key
+ )
+ else:
+ self._label_name = None
+
+ _entity._post_inspect
+ self.entity_zero = self.entity_zero_or_selectable = ezero = _entity
+ self.mapper = mapper = _entity.mapper
+
+ if parent_bundle:
+ parent_bundle._entities.append(self)
+ else:
+ entities_collection.append(self)
+
+ compile_state._has_orm_entities = True
+
+ self.column = column
+
+ self._fetch_column = self._row_processor = None
+
+ self._extra_entities = (self.expr, self.column)
+
+ if (
+ mapper.with_polymorphic
+ or mapper.inherits
+ or mapper._requires_row_aliasing
+ ):
+ compile_state._create_with_polymorphic_adapter(
+ ezero, ezero.selectable
+ )
+
+ def corresponds_to(self, entity):
+ if _is_aliased_class(entity):
+ # TODO: polymorphic subclasses ?
+ return entity is self.entity_zero
+ else:
+ return not _is_aliased_class(
+ self.entity_zero
+ ) and entity.common_parent(self.entity_zero)
+
+ def setup_compile_state(self, compile_state):
+ current_adapter = compile_state._get_current_adapter()
+ if current_adapter:
+ column = current_adapter(self.column, False)
+ else:
+ column = self.column
+
+ ezero = self.entity_zero
+
+ single_table_crit = self.mapper._single_table_criterion
+ if (
+ single_table_crit is not None
+ or ("additional_entity_criteria", self.mapper)
+ in compile_state.global_attributes
+ ):
+
+ compile_state.extra_criteria_entities[ezero] = (
+ ezero,
+ ezero._adapter if ezero.is_aliased_class else None,
+ )
+
+ if column._annotations and not column._expression_label:
+ # annotated columns perform more slowly in compiler and
+ # result due to the __eq__() method, so use deannotated
+ column = column._deannotate()
+
+ # use entity_zero as the from if we have it. this is necessary
+ # for polymorphic scenarios where our FROM is based on ORM entity,
+ # not the FROM of the column. but also, don't use it if our column
+        # doesn't actually have any FROMs that line up, such as when it's
+ # a scalar subquery.
+ if set(self.column._from_objects).intersection(
+ ezero.selectable._from_objects
+ ):
+ compile_state._fallback_from_clauses.append(ezero.selectable)
+
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+ self._fetch_column = column
+
+
+class _IdentityTokenEntity(_ORMColumnEntity):
+ translate_raw_column = False
+
+ def setup_compile_state(self, compile_state):
+ pass
+
+ def row_processor(self, context, result):
+ def getter(row):
+ return context.load_options._refresh_identity_token
+
+ return getter, self._label_name, self._extra_entities
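+
+# Rough correspondence (illustrative only) between caller-side constructs and
+# the _QueryEntity subclasses above; "User" is a hypothetical mapped class:
+#
+#     select(User.id)                         -> _ORMColumnEntity
+#     select(literal_column("1"))             -> _RawColumnEntity
+#     select(Bundle("b", User.id, User.name)) -> _BundleEntity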
diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py
new file mode 100644
index 0000000..16f91c6
--- /dev/null
+++ b/lib/sqlalchemy/orm/decl_api.py
@@ -0,0 +1,1062 @@
+# ext/declarative/api.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Public API functions and helpers for declarative."""
+from __future__ import absolute_import
+
+import itertools
+import re
+import weakref
+
+from . import attributes
+from . import clsregistry
+from . import exc as orm_exc
+from . import instrumentation
+from . import interfaces
+from . import mapper as mapperlib
+from .base import _inspect_mapped_class
+from .decl_base import _add_attribute
+from .decl_base import _as_declarative
+from .decl_base import _declarative_constructor
+from .decl_base import _DeferredMapperConfig
+from .decl_base import _del_attribute
+from .decl_base import _mapper
+from .descriptor_props import SynonymProperty as _orm_synonym
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql.schema import MetaData
+from ..util import hybridmethod
+from ..util import hybridproperty
+
+
+def has_inherited_table(cls):
+ """Given a class, return True if any of the classes it inherits from has a
+ mapped table, otherwise return False.
+
+ This is used in declarative mixins to build attributes that behave
+ differently for the base class vs. a subclass in an inheritance
+ hierarchy.
+
+ .. seealso::
+
+ :ref:`decl_mixin_inheritance`
+
+ """
+ for class_ in cls.__mro__[1:]:
+ if getattr(class_, "__table__", None) is not None:
+ return True
+ return False
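+
+# A minimal sketch (hypothetical mixin) of the usage described above: a
+# __tablename__ that is emitted only for the base class of an inheritance
+# hierarchy, and suppressed for subclasses:
+#
+#     class Tablename:
+#         @declared_attr
+#         def __tablename__(cls):
+#             if has_inherited_table(cls):
+#                 return None
+#             return cls.__name__.lower()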
+
+
+class DeclarativeMeta(type):
+ def __init__(cls, classname, bases, dict_, **kw):
+ # use cls.__dict__, which can be modified by an
+ # __init_subclass__() method (#7900)
+ dict_ = cls.__dict__
+
+ # early-consume registry from the initial declarative base,
+ # assign privately to not conflict with subclass attributes named
+ # "registry"
+ reg = getattr(cls, "_sa_registry", None)
+ if reg is None:
+ reg = dict_.get("registry", None)
+ if not isinstance(reg, registry):
+ raise exc.InvalidRequestError(
+ "Declarative base class has no 'registry' attribute, "
+ "or registry is not a sqlalchemy.orm.registry() object"
+ )
+ else:
+ cls._sa_registry = reg
+
+ if not cls.__dict__.get("__abstract__", False):
+ _as_declarative(reg, cls, dict_)
+ type.__init__(cls, classname, bases, dict_)
+
+ def __setattr__(cls, key, value):
+ _add_attribute(cls, key, value)
+
+ def __delattr__(cls, key):
+ _del_attribute(cls, key)
+
+
+def synonym_for(name, map_column=False):
+ """Decorator that produces an :func:`_orm.synonym`
+ attribute in conjunction with a Python descriptor.
+
+ The function being decorated is passed to :func:`_orm.synonym` as the
+ :paramref:`.orm.synonym.descriptor` parameter::
+
+ class MyClass(Base):
+ __tablename__ = 'my_table'
+
+ id = Column(Integer, primary_key=True)
+ _job_status = Column("job_status", String(50))
+
+ @synonym_for("job_status")
+ @property
+ def job_status(self):
+ return "Status: %s" % self._job_status
+
+ The :ref:`hybrid properties <mapper_hybrids>` feature of SQLAlchemy
+ is typically preferred instead of synonyms, which is a more legacy
+ feature.
+
+ .. seealso::
+
+ :ref:`synonyms` - Overview of synonyms
+
+ :func:`_orm.synonym` - the mapper-level function
+
+ :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an
+ updated approach to augmenting attribute behavior more flexibly than
+ can be achieved with synonyms.
+
+ """
+
+ def decorate(fn):
+ return _orm_synonym(name, map_column=map_column, descriptor=fn)
+
+ return decorate
+
+
+class declared_attr(interfaces._MappedAttribute, property):
+ """Mark a class-level method as representing the definition of
+ a mapped property or special declarative member name.
+
+ :class:`_orm.declared_attr` is typically applied as a decorator to a class
+ level method, turning the attribute into a scalar-like property that can be
+ invoked from the uninstantiated class. The Declarative mapping process
+ looks for these :class:`_orm.declared_attr` callables as it scans classes,
+ and assumes any attribute marked with :class:`_orm.declared_attr` will be a
+ callable that will produce an object specific to the Declarative mapping or
+ table configuration.
+
+ :class:`_orm.declared_attr` is usually applicable to mixins, to define
+ relationships that are to be applied to different implementors of the
+ class. It is also used to define :class:`_schema.Column` objects that
+ include the :class:`_schema.ForeignKey` construct, as these cannot be
+ easily reused across different mappings. The example below illustrates
+ both::
+
+ class ProvidesUser(object):
+ "A mixin that adds a 'user' relationship to classes."
+
+ @declared_attr
+ def user_id(self):
+ return Column(ForeignKey("user_account.id"))
+
+ @declared_attr
+ def user(self):
+ return relationship("User")
+
+ :class:`_orm.declared_attr` can also be applied to mapped classes, such as
+ to provide a "polymorphic" scheme for inheritance::
+
+ class Employee(Base):
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50), nullable=False)
+
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+
+ @declared_attr
+ def __mapper_args__(cls):
+ if cls.__name__ == 'Employee':
+ return {
+ "polymorphic_on":cls.type,
+ "polymorphic_identity":"Employee"
+ }
+ else:
+ return {"polymorphic_identity":cls.__name__}
+
+ To use :class:`_orm.declared_attr` inside of a Python dataclass
+ as discussed at :ref:`orm_declarative_dataclasses_declarative_table`,
+ it may be placed directly inside the field metadata using a lambda::
+
+ @dataclass
+ class AddressMixin:
+ __sa_dataclass_metadata_key__ = "sa"
+
+ user_id: int = field(
+ init=False, metadata={"sa": declared_attr(lambda: Column(ForeignKey("user.id")))}
+ )
+ user: User = field(
+ init=False, metadata={"sa": declared_attr(lambda: relationship(User))}
+ )
+
+ :class:`_orm.declared_attr` also may be omitted from this form using a
+ lambda directly, as in::
+
+ user: User = field(
+ init=False, metadata={"sa": lambda: relationship(User)}
+ )
+
+ .. seealso::
+
+ :ref:`orm_mixins_toplevel` - illustrates how to use Declarative Mixins
+ which is the primary use case for :class:`_orm.declared_attr`
+
+ :ref:`orm_declarative_dataclasses_mixin` - illustrates special forms
+ for use with Python dataclasses
+
+ """ # noqa: E501
+
+ def __init__(self, fget, cascading=False):
+ super(declared_attr, self).__init__(fget)
+ self.__doc__ = fget.__doc__
+ self._cascading = cascading
+
+ def __get__(desc, self, cls):
+ # the declared_attr needs to make use of a cache that exists
+ # for the span of the declarative scan_attributes() phase.
+ # to achieve this we look at the class manager that's configured.
+ manager = attributes.manager_of_class(cls)
+ if manager is None:
+ if not re.match(r"^__.+__$", desc.fget.__name__):
+ # if there is no manager at all, then this class hasn't been
+ # run through declarative or mapper() at all, emit a warning.
+ util.warn(
+ "Unmanaged access of declarative attribute %s from "
+ "non-mapped class %s" % (desc.fget.__name__, cls.__name__)
+ )
+ return desc.fget(cls)
+ elif manager.is_mapped:
+ # the class is mapped, which means we're outside of the declarative
+ # scan setup, just run the function.
+ return desc.fget(cls)
+
+ # here, we are inside of the declarative scan. use the registry
+ # that is tracking the values of these attributes.
+ declarative_scan = manager.declarative_scan()
+ assert declarative_scan is not None
+ reg = declarative_scan.declared_attr_reg
+
+ if desc in reg:
+ return reg[desc]
+ else:
+ reg[desc] = obj = desc.fget(cls)
+ return obj
+
+ @hybridmethod
+ def _stateful(cls, **kw):
+ return _stateful_declared_attr(**kw)
+
+ @hybridproperty
+ def cascading(cls):
+ """Mark a :class:`.declared_attr` as cascading.
+
+ This is a special-use modifier which indicates that a column
+ or MapperProperty-based declared attribute should be configured
+ distinctly per mapped subclass, within a mapped-inheritance scenario.
+
+ .. warning::
+
+ The :attr:`.declared_attr.cascading` modifier has several
+ limitations:
+
+ * The flag **only** applies to the use of :class:`.declared_attr`
+ on declarative mixin classes and ``__abstract__`` classes; it
+ currently has no effect when used on a mapped class directly.
+
+ * The flag **only** applies to normally-named attributes, e.g.
+ not any special underscore attributes such as ``__tablename__``.
+ On these attributes it has **no** effect.
+
+ * The flag currently **does not allow further overrides** down
+ the class hierarchy; if a subclass tries to override the
+ attribute, a warning is emitted and the overridden attribute
+ is skipped. This is a limitation that it is hoped will be
+ resolved at some point.
+
+ Below, both MyClass as well as MySubClass will have a distinct
+ ``id`` Column object established::
+
+ class HasIdMixin(object):
+ @declared_attr.cascading
+ def id(cls):
+ if has_inherited_table(cls):
+ return Column(
+ ForeignKey('myclass.id'), primary_key=True
+ )
+ else:
+ return Column(Integer, primary_key=True)
+
+ class MyClass(HasIdMixin, Base):
+ __tablename__ = 'myclass'
+ # ...
+
+ class MySubClass(MyClass):
+ ""
+ # ...
+
+ The behavior of the above configuration is that ``MySubClass``
+ will refer to both its own ``id`` column as well as that of
+        ``MyClass`` underneath the attribute named ``id``.
+
+ .. seealso::
+
+ :ref:`declarative_inheritance`
+
+ :ref:`mixin_inheritance_columns`
+
+
+ """
+ return cls._stateful(cascading=True)
+
+
+class _stateful_declared_attr(declared_attr):
+ def __init__(self, **kw):
+ self.kw = kw
+
+ def _stateful(self, **kw):
+ new_kw = self.kw.copy()
+ new_kw.update(kw)
+ return _stateful_declared_attr(**new_kw)
+
+ def __call__(self, fn):
+ return declared_attr(fn, **self.kw)
+
+
+def declarative_mixin(cls):
+ """Mark a class as providing the feature of "declarative mixin".
+
+ E.g.::
+
+ from sqlalchemy.orm import declared_attr
+ from sqlalchemy.orm import declarative_mixin
+
+ @declarative_mixin
+ class MyMixin:
+
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+
+ __table_args__ = {'mysql_engine': 'InnoDB'}
+ __mapper_args__= {'always_refresh': True}
+
+ id = Column(Integer, primary_key=True)
+
+ class MyModel(MyMixin, Base):
+ name = Column(String(1000))
+
+ The :func:`_orm.declarative_mixin` decorator currently does not modify
+    the given class in any way; its current purpose is strictly to assist
+ the :ref:`Mypy plugin <mypy_toplevel>` in being able to identify
+ SQLAlchemy declarative mixin classes when no other context is present.
+
+ .. versionadded:: 1.4.6
+
+ .. seealso::
+
+ :ref:`orm_mixins_toplevel`
+
+ :ref:`mypy_declarative_mixins` - in the
+ :ref:`Mypy plugin documentation <mypy_toplevel>`
+
+ """ # noqa: E501
+
+ return cls
+
+
+def declarative_base(
+ bind=None,
+ metadata=None,
+ mapper=None,
+ cls=object,
+ name="Base",
+ constructor=_declarative_constructor,
+ class_registry=None,
+ metaclass=DeclarativeMeta,
+):
+ r"""Construct a base class for declarative class definitions.
+
+ The new base class will be given a metaclass that produces
+ appropriate :class:`~sqlalchemy.schema.Table` objects and makes
+ the appropriate :func:`~sqlalchemy.orm.mapper` calls based on the
+ information provided declaratively in the class and any subclasses
+ of the class.
+
+ The :func:`_orm.declarative_base` function is a shorthand version
+ of using the :meth:`_orm.registry.generate_base`
+ method. That is, the following::
+
+ from sqlalchemy.orm import declarative_base
+
+ Base = declarative_base()
+
+ Is equivalent to::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+ Base = mapper_registry.generate_base()
+
+ See the docstring for :class:`_orm.registry`
+ and :meth:`_orm.registry.generate_base`
+ for more details.
+
+ .. versionchanged:: 1.4 The :func:`_orm.declarative_base`
+ function is now a specialization of the more generic
+ :class:`_orm.registry` class. The function also moves to the
+        ``sqlalchemy.orm`` package from the ``sqlalchemy.ext.declarative``
+        package.
+
+
+ :param bind: An optional
+ :class:`~sqlalchemy.engine.Connectable`, will be assigned
+ the ``bind`` attribute on the :class:`~sqlalchemy.schema.MetaData`
+ instance.
+
+ .. deprecated:: 1.4 The "bind" argument to declarative_base is
+ deprecated and will be removed in SQLAlchemy 2.0.
+
+ :param metadata:
+ An optional :class:`~sqlalchemy.schema.MetaData` instance. All
+ :class:`~sqlalchemy.schema.Table` objects implicitly declared by
+ subclasses of the base will share this MetaData. A MetaData instance
+ will be created if none is provided. The
+ :class:`~sqlalchemy.schema.MetaData` instance will be available via the
+ ``metadata`` attribute of the generated declarative base class.
+
+ :param mapper:
+ An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`. Will
+ be used to map subclasses to their Tables.
+
+ :param cls:
+ Defaults to :class:`object`. A type to use as the base for the generated
+ declarative base class. May be a class or tuple of classes.
+
+ :param name:
+ Defaults to ``Base``. The display name for the generated
+ class. Customizing this is not required, but can improve clarity in
+ tracebacks and debugging.
+
+ :param constructor:
+ Specify the implementation for the ``__init__`` function on a mapped
+ class that has no ``__init__`` of its own. Defaults to an
+ implementation that assigns \**kwargs for declared
+ fields and relationships to an instance. If ``None`` is supplied,
+ no __init__ will be provided and construction will fall back to
+ cls.__init__ by way of the normal Python semantics.
+
+ :param class_registry: optional dictionary that will serve as the
+ registry of class names-> mapped classes when string names
+ are used to identify classes inside of :func:`_orm.relationship`
+ and others. Allows two or more declarative base classes
+ to share the same registry of class names for simplified
+ inter-base relationships.
+
+ :param metaclass:
+ Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__
+ compatible callable to use as the meta type of the generated
+ declarative base class.
+
+ .. seealso::
+
+ :class:`_orm.registry`
+
+ """
+
+ if bind is not None:
+ # util.deprecated_params does not work
+ util.warn_deprecated_20(
+ "The ``bind`` argument to declarative_base is "
+ "deprecated and will be removed in SQLAlchemy 2.0.",
+ )
+
+ return registry(
+ _bind=bind,
+ metadata=metadata,
+ class_registry=class_registry,
+ constructor=constructor,
+ ).generate_base(
+ mapper=mapper,
+ cls=cls,
+ name=name,
+ metaclass=metaclass,
+ )
+
+
+class registry(object):
+ """Generalized registry for mapping classes.
+
+ The :class:`_orm.registry` serves as the basis for maintaining a collection
+ of mappings, and provides configurational hooks used to map classes.
+
+ The three general kinds of mappings supported are Declarative Base,
+ Declarative Decorator, and Imperative Mapping. All of these mapping
+ styles may be used interchangeably:
+
+ * :meth:`_orm.registry.generate_base` returns a new declarative base
+ class, and is the underlying implementation of the
+ :func:`_orm.declarative_base` function.
+
+ * :meth:`_orm.registry.mapped` provides a class decorator that will
+ apply declarative mapping to a class without the use of a declarative
+ base class.
+
+ * :meth:`_orm.registry.map_imperatively` will produce a
+ :class:`_orm.Mapper` for a class without scanning the class for
+ declarative class attributes. This method suits the use case historically
+ provided by the
+ :func:`_orm.mapper` classical mapping function.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`orm_mapping_classes_toplevel` - overview of class mapping
+ styles.
+
+ """
+
+ def __init__(
+ self,
+ metadata=None,
+ class_registry=None,
+ constructor=_declarative_constructor,
+ _bind=None,
+ ):
+ r"""Construct a new :class:`_orm.registry`
+
+ :param metadata:
+ An optional :class:`_schema.MetaData` instance. All
+ :class:`_schema.Table` objects generated using declarative
+ table mapping will make use of this :class:`_schema.MetaData`
+ collection. If this argument is left at its default of ``None``,
+ a blank :class:`_schema.MetaData` collection is created.
+
+ :param constructor:
+ Specify the implementation for the ``__init__`` function on a mapped
+ class that has no ``__init__`` of its own. Defaults to an
+ implementation that assigns \**kwargs for declared
+ fields and relationships to an instance. If ``None`` is supplied,
+ no __init__ will be provided and construction will fall back to
+ cls.__init__ by way of the normal Python semantics.
+
+ :param class_registry: optional dictionary that will serve as the
+ registry of class names-> mapped classes when string names
+ are used to identify classes inside of :func:`_orm.relationship`
+ and others. Allows two or more declarative base classes
+ to share the same registry of class names for simplified
+ inter-base relationships.
+
+ """
+ lcl_metadata = metadata or MetaData()
+ if _bind:
+ lcl_metadata.bind = _bind
+
+ if class_registry is None:
+ class_registry = weakref.WeakValueDictionary()
+
+ self._class_registry = class_registry
+ self._managers = weakref.WeakKeyDictionary()
+ self._non_primary_mappers = weakref.WeakKeyDictionary()
+ self.metadata = lcl_metadata
+ self.constructor = constructor
+
+ self._dependents = set()
+ self._dependencies = set()
+
+ self._new_mappers = False
+
+ with mapperlib._CONFIGURE_MUTEX:
+ mapperlib._mapper_registries[self] = True
+
+ @property
+ def mappers(self):
+ """read only collection of all :class:`_orm.Mapper` objects."""
+
+ return frozenset(manager.mapper for manager in self._managers).union(
+ self._non_primary_mappers
+ )
+
+ def _set_depends_on(self, registry):
+ if registry is self:
+ return
+ registry._dependents.add(self)
+ self._dependencies.add(registry)
+
+ def _flag_new_mapper(self, mapper):
+ mapper._ready_for_configure = True
+ if self._new_mappers:
+ return
+
+ for reg in self._recurse_with_dependents({self}):
+ reg._new_mappers = True
+
+ @classmethod
+ def _recurse_with_dependents(cls, registries):
+ todo = registries
+ done = set()
+ while todo:
+ reg = todo.pop()
+ done.add(reg)
+
+ # if yielding would remove dependents, make sure we have
+ # them before
+ todo.update(reg._dependents.difference(done))
+ yield reg
+
+ # if yielding would add dependents, make sure we have them
+ # after
+ todo.update(reg._dependents.difference(done))
+
+ @classmethod
+ def _recurse_with_dependencies(cls, registries):
+ todo = registries
+ done = set()
+ while todo:
+ reg = todo.pop()
+ done.add(reg)
+
+ # if yielding would remove dependencies, make sure we have
+ # them before
+ todo.update(reg._dependencies.difference(done))
+
+ yield reg
+
+            # if yielding would add dependencies, make sure we have
+            # them after
+ todo.update(reg._dependencies.difference(done))
+
+ def _mappers_to_configure(self):
+ return itertools.chain(
+ (
+ manager.mapper
+ for manager in list(self._managers)
+ if manager.is_mapped
+ and not manager.mapper.configured
+ and manager.mapper._ready_for_configure
+ ),
+ (
+ npm
+ for npm in list(self._non_primary_mappers)
+ if not npm.configured and npm._ready_for_configure
+ ),
+ )
+
+ def _add_non_primary_mapper(self, np_mapper):
+ self._non_primary_mappers[np_mapper] = True
+
+ def _dispose_cls(self, cls):
+ clsregistry.remove_class(cls.__name__, cls, self._class_registry)
+
+ def _add_manager(self, manager):
+ self._managers[manager] = True
+ if manager.registry is not None and manager.is_mapped:
+ raise exc.ArgumentError(
+ "Class '%s' already has a primary mapper defined. "
+ % manager.class_
+ )
+ manager.registry = self
+
+ def configure(self, cascade=False):
+ """Configure all as-yet unconfigured mappers in this
+ :class:`_orm.registry`.
+
+ The configure step is used to reconcile and initialize the
+ :func:`_orm.relationship` linkages between mapped classes, as well as
+ to invoke configuration events such as the
+ :meth:`_orm.MapperEvents.before_configured` and
+ :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM
+ extensions or user-defined extension hooks.
+
+ If one or more mappers in this registry contain
+ :func:`_orm.relationship` constructs that refer to mapped classes in
+ other registries, this registry is said to be *dependent* on those
+ registries. In order to configure those dependent registries
+ automatically, the :paramref:`_orm.registry.configure.cascade` flag
+ should be set to ``True``. Otherwise, if they are not configured, an
+ exception will be raised. The rationale behind this behavior is to
+ allow an application to programmatically invoke configuration of
+ registries while controlling whether or not the process implicitly
+ reaches other registries.
+
+ As an alternative to invoking :meth:`_orm.registry.configure`, the ORM
+        function :func:`_orm.configure_mappers` may be used to ensure
+ configuration is complete for all :class:`_orm.registry` objects in
+ memory. This is generally simpler to use and also predates the usage of
+ :class:`_orm.registry` objects overall. However, this function will
+ impact all mappings throughout the running Python process and may be
+ more memory/time consuming for an application that has many registries
+ in use for different purposes that may not be needed immediately.
+
+ .. seealso::
+
+ :func:`_orm.configure_mappers`
+
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ mapperlib._configure_registries({self}, cascade=cascade)
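+
+        # A hedged usage sketch: registries whose mappings reference each
+        # other via relationship() can be configured together by cascading
+        # from one of them (the registry name is hypothetical):
+        #
+        #     some_registry.configure(cascade=True)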
+
+ def dispose(self, cascade=False):
+ """Dispose of all mappers in this :class:`_orm.registry`.
+
+ After invocation, all the classes that were mapped within this registry
+ will no longer have class instrumentation associated with them. This
+ method is the per-:class:`_orm.registry` analogue to the
+ application-wide :func:`_orm.clear_mappers` function.
+
+ If this registry contains mappers that are dependencies of other
+ registries, typically via :func:`_orm.relationship` links, then those
+ registries must be disposed as well. When such registries exist in
+ relation to this one, their :meth:`_orm.registry.dispose` method will
+ also be called, if the :paramref:`_orm.registry.dispose.cascade` flag
+ is set to ``True``; otherwise, an error is raised if those registries
+ were not already disposed.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :func:`_orm.clear_mappers`
+
+ """
+
+ mapperlib._dispose_registries({self}, cascade=cascade)
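+
+        # Illustrative counterpart to configure() above: a registry that
+        # other registries depend on can be torn down along with its
+        # dependents (the registry name is hypothetical):
+        #
+        #     some_registry.dispose(cascade=True)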
+
+ def _dispose_manager_and_mapper(self, manager):
+ if "mapper" in manager.__dict__:
+ mapper = manager.mapper
+
+ mapper._set_dispose_flags()
+
+ class_ = manager.class_
+ self._dispose_cls(class_)
+ instrumentation._instrumentation_factory.unregister(class_)
+
+ def generate_base(
+ self,
+ mapper=None,
+ cls=object,
+ name="Base",
+ metaclass=DeclarativeMeta,
+ ):
+ """Generate a declarative base class.
+
+ Classes that inherit from the returned class object will be
+ automatically mapped using declarative mapping.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ Base = mapper_registry.generate_base()
+
+ class MyClass(Base):
+ __tablename__ = "my_table"
+ id = Column(Integer, primary_key=True)
+
+ The above dynamically generated class is equivalent to the
+ non-dynamic example below::
+
+ from sqlalchemy.orm import registry
+ from sqlalchemy.orm.decl_api import DeclarativeMeta
+
+ mapper_registry = registry()
+
+ class Base(metaclass=DeclarativeMeta):
+ __abstract__ = True
+ registry = mapper_registry
+ metadata = mapper_registry.metadata
+
+ __init__ = mapper_registry.constructor
+
+ The :meth:`_orm.registry.generate_base` method provides the
+ implementation for the :func:`_orm.declarative_base` function, which
+ creates the :class:`_orm.registry` and base class all at once.
+
+ See the section :ref:`orm_declarative_mapping` for background and
+ examples.
+
+ :param mapper:
+ An optional callable, defaults to :func:`~sqlalchemy.orm.mapper`.
+ This function is used to generate new :class:`_orm.Mapper` objects.
+
+ :param cls:
+ Defaults to :class:`object`. A type to use as the base for the
+ generated declarative base class. May be a class or tuple of classes.
+
+ :param name:
+ Defaults to ``Base``. The display name for the generated
+ class. Customizing this is not required, but can improve clarity in
+ tracebacks and debugging.
+
+ :param metaclass:
+ Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__
+ compatible callable to use as the meta type of the generated
+ declarative base class.
+
+ .. seealso::
+
+ :ref:`orm_declarative_mapping`
+
+ :func:`_orm.declarative_base`
+
+ """
+ metadata = self.metadata
+
+ bases = not isinstance(cls, tuple) and (cls,) or cls
+
+ class_dict = dict(registry=self, metadata=metadata)
+ if isinstance(cls, type):
+ class_dict["__doc__"] = cls.__doc__
+
+ if self.constructor:
+ class_dict["__init__"] = self.constructor
+
+ class_dict["__abstract__"] = True
+ if mapper:
+ class_dict["__mapper_cls__"] = mapper
+
+ if hasattr(cls, "__class_getitem__"):
+
+ def __class_getitem__(cls, key):
+ # allow generic classes in py3.9+
+ return cls
+
+ class_dict["__class_getitem__"] = __class_getitem__
+
+ return metaclass(name, bases, class_dict)
+
+ def mapped(self, cls):
+ """Class decorator that will apply the Declarative mapping process
+ to a given class.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ @mapper_registry.mapped
+ class Foo:
+ __tablename__ = 'some_table'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ See the section :ref:`orm_declarative_mapping` for complete
+ details and examples.
+
+ :param cls: class to be mapped.
+
+ :return: the class that was passed.
+
+ .. seealso::
+
+ :ref:`orm_declarative_mapping`
+
+ :meth:`_orm.registry.generate_base` - generates a base class
+ that will apply Declarative mapping to subclasses automatically
+ using a Python metaclass.
+
+ """
+ _as_declarative(self, cls, cls.__dict__)
+ return cls
+
+ def as_declarative_base(self, **kw):
+ """
+ Class decorator which will invoke
+ :meth:`_orm.registry.generate_base`
+ for a given base class.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ @mapper_registry.as_declarative_base()
+ class Base(object):
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+ id = Column(Integer, primary_key=True)
+
+ class MyMappedClass(Base):
+ # ...
+
+ All keyword arguments passed to
+ :meth:`_orm.registry.as_declarative_base` are passed
+ along to :meth:`_orm.registry.generate_base`.
+
+ """
+
+ def decorate(cls):
+ kw["cls"] = cls
+ kw["name"] = cls.__name__
+ return self.generate_base(**kw)
+
+ return decorate
+
+ def map_declaratively(self, cls):
+ """Map a class declaratively.
+
+ In this form of mapping, the class is scanned for mapping information,
+ including for columns to be associated with a table, and/or an
+ actual table object.
+
+ Returns the :class:`_orm.Mapper` object.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ class Foo:
+ __tablename__ = 'some_table'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ mapper = mapper_registry.map_declaratively(Foo)
+
+ This function is more conveniently invoked indirectly via either the
+ :meth:`_orm.registry.mapped` class decorator or by subclassing a
+ declarative metaclass generated from
+ :meth:`_orm.registry.generate_base`.
+
+ See the section :ref:`orm_declarative_mapping` for complete
+ details and examples.
+
+ :param cls: class to be mapped.
+
+ :return: a :class:`_orm.Mapper` object.
+
+ .. seealso::
+
+ :ref:`orm_declarative_mapping`
+
+ :meth:`_orm.registry.mapped` - more common decorator interface
+ to this function.
+
+ :meth:`_orm.registry.map_imperatively`
+
+ """
+ return _as_declarative(self, cls, cls.__dict__)
+
+ def map_imperatively(self, class_, local_table=None, **kw):
+ r"""Map a class imperatively.
+
+ In this form of mapping, the class is not scanned for any mapping
+ information. Instead, all mapping constructs are passed as
+ arguments.
+
+ This method is intended to be fully equivalent to the classic
+ SQLAlchemy :func:`_orm.mapper` function, except that it's in terms of
+ a particular registry.
+
+ E.g.::
+
+ from sqlalchemy.orm import registry
+
+ mapper_registry = registry()
+
+ my_table = Table(
+ "my_table",
+ mapper_registry.metadata,
+ Column('id', Integer, primary_key=True)
+ )
+
+ class MyClass:
+ pass
+
+ mapper_registry.map_imperatively(MyClass, my_table)
+
+ See the section :ref:`orm_imperative_mapping` for complete background
+ and usage examples.
+
+ :param class\_: The class to be mapped. Corresponds to the
+ :paramref:`_orm.mapper.class_` parameter.
+
+ :param local_table: the :class:`_schema.Table` or other
+ :class:`_sql.FromClause` object that is the subject of the mapping.
+ Corresponds to the
+ :paramref:`_orm.mapper.local_table` parameter.
+
+ :param \**kw: all other keyword arguments are passed to the
+ :func:`_orm.mapper` function directly.
+
+ .. seealso::
+
+ :ref:`orm_imperative_mapping`
+
+ :ref:`orm_declarative_mapping`
+
+ """
+ return _mapper(self, class_, local_table, kw)
+
+
+mapperlib._legacy_registry = registry()
+
+
+@util.deprecated_params(
+ bind=(
+ "2.0",
+ "The ``bind`` argument to as_declarative is "
+ "deprecated and will be removed in SQLAlchemy 2.0.",
+ )
+)
+def as_declarative(**kw):
+ """
+ Class decorator which will adapt a given class into a
+ :func:`_orm.declarative_base`.
+
+ This function makes use of the :meth:`_orm.registry.as_declarative_base`
+ method, by first creating a :class:`_orm.registry` automatically
+ and then invoking the decorator.
+
+ E.g.::
+
+ from sqlalchemy.orm import as_declarative
+
+ @as_declarative()
+ class Base(object):
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+ id = Column(Integer, primary_key=True)
+
+ class MyMappedClass(Base):
+ # ...
+
+ .. seealso::
+
+ :meth:`_orm.registry.as_declarative_base`
+
+ """
+ bind, metadata, class_registry = (
+ kw.pop("bind", None),
+ kw.pop("metadata", None),
+ kw.pop("class_registry", None),
+ )
+
+ return registry(
+ _bind=bind, metadata=metadata, class_registry=class_registry
+ ).as_declarative_base(**kw)
+
+
+@inspection._inspects(DeclarativeMeta)
+def _inspect_decl_meta(cls):
+ mp = _inspect_mapped_class(cls)
+ if mp is None:
+ if _DeferredMapperConfig.has_cls(cls):
+ _DeferredMapperConfig.raise_unmapped_for_cls(cls)
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s has a deferred mapping on it. It is not yet "
+ "usable as a mapped class." % orm_exc._safe_cls_name(cls),
+ )
+ return mp
diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py
new file mode 100644
index 0000000..6e1c797
--- /dev/null
+++ b/lib/sqlalchemy/orm/decl_base.py
@@ -0,0 +1,1210 @@
+# ext/declarative/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Internal implementation for declarative."""
+from __future__ import absolute_import
+
+import collections
+import weakref
+
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import instrumentation
+from . import clsregistry
+from . import exc as orm_exc
+from . import mapper as mapperlib
+from .attributes import InstrumentedAttribute
+from .attributes import QueryableAttribute
+from .base import _is_mapped_class
+from .base import InspectionAttr
+from .descriptor_props import CompositeProperty
+from .descriptor_props import SynonymProperty
+from .interfaces import MapperProperty
+from .mapper import Mapper as mapper
+from .properties import ColumnProperty
+from .util import class_mapper
+from .. import event
+from .. import exc
+from .. import util
+from ..sql import expression
+from ..sql.schema import Column
+from ..sql.schema import Table
+from ..util import topological
+
+
+def _declared_mapping_info(cls):
+ # deferred mapping
+ if _DeferredMapperConfig.has_cls(cls):
+ return _DeferredMapperConfig.config_for_cls(cls)
+ # regular mapping
+ elif _is_mapped_class(cls):
+ return class_mapper(cls, configure=False)
+ else:
+ return None
+
+
+def _resolve_for_abstract_or_classical(cls):
+ if cls is object:
+ return None
+
+ if cls.__dict__.get("__abstract__", False):
+ for sup in cls.__bases__:
+ sup = _resolve_for_abstract_or_classical(sup)
+ if sup is not None:
+ return sup
+ else:
+ return None
+ else:
+ clsmanager = _dive_for_cls_manager(cls)
+
+ if clsmanager:
+ return clsmanager.class_
+ else:
+ return cls
+
+
+def _get_immediate_cls_attr(cls, attrname, strict=False):
+ """return an attribute of the class that is either present directly
+ on the class, e.g. not on a superclass, or is from a superclass but
+ this superclass is a non-mapped mixin, that is, not a descendant of
+ the declarative base and is also not classically mapped.
+
+ This is used to detect attributes that indicate something about
+ a mapped class independently from any mapped classes that it may
+ inherit from.
+
+ """
+
+ # the rules are different for this name than others,
+ # make sure we've moved it out. transitional
+ assert attrname != "__abstract__"
+
+ if not issubclass(cls, object):
+ return None
+
+ if attrname in cls.__dict__:
+ return getattr(cls, attrname)
+
+ for base in cls.__mro__[1:]:
+        _is_classical_inherits = _dive_for_cls_manager(base)
+
+ if attrname in base.__dict__ and (
+ base is cls
+ or (
+ (base in cls.__bases__ if strict else True)
+                and not _is_classical_inherits
+ )
+ ):
+ return getattr(base, attrname)
+ else:
+ return None
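+
+# Illustration of the rule above (hypothetical classes and attribute name,
+# using strict=True): an attribute declared on a plain, non-mapped mixin that
+# is a direct base counts as "immediate", while the same name reachable only
+# through an already-mapped superclass does not:
+#
+#     class SomeMixin:               # plain class, not mapped
+#         _my_flag = True
+#
+#     class A(SomeMixin, Base): ...  # _get_immediate_cls_attr(A, "_my_flag",
+#                                    # strict=True) -> True (via the mixin)
+#
+#     class B(A): ...                # -> None: A is mapped, and SomeMixin is
+#                                    # not a direct base of B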
+
+
+def _dive_for_cls_manager(cls):
+ # because the class manager registration is pluggable,
+ # we need to do the search for every class in the hierarchy,
+ # rather than just a simple "cls._sa_class_manager"
+
+ # python 2 old style class
+ if not hasattr(cls, "__mro__"):
+ return None
+
+ for base in cls.__mro__:
+ manager = attributes.manager_of_class(base)
+ if manager:
+ return manager
+ return None
+
+
+def _as_declarative(registry, cls, dict_):
+
+ # declarative scans the class for attributes. no table or mapper
+ # args passed separately.
+
+ return _MapperConfig.setup_mapping(registry, cls, dict_, None, {})
+
+
+def _mapper(registry, cls, table, mapper_kw):
+ _ImperativeMapperConfig(registry, cls, table, mapper_kw)
+ return cls.__mapper__
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _is_declarative_props(obj):
+ declared_attr = util.preloaded.orm_decl_api.declared_attr
+
+ return isinstance(obj, (declared_attr, util.classproperty))
+
+
+def _check_declared_props_nocascade(obj, name, cls):
+ if _is_declarative_props(obj):
+ if getattr(obj, "_cascading", False):
+ util.warn(
+ "@declared_attr.cascading is not supported on the %s "
+ "attribute on class %s. This attribute invokes for "
+ "subclasses in any case." % (name, cls)
+ )
+ return True
+ else:
+ return False
+
+
+class _MapperConfig(object):
+ __slots__ = (
+ "cls",
+ "classname",
+ "properties",
+ "declared_attr_reg",
+ "__weakref__",
+ )
+
+ @classmethod
+ def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
+        manager = attributes.manager_of_class(cls_)
+ if manager and manager.class_ is cls_:
+ raise exc.InvalidRequestError(
+ "Class %r already has been " "instrumented declaratively" % cls
+ )
+
+ if cls_.__dict__.get("__abstract__", False):
+ return
+
+ defer_map = _get_immediate_cls_attr(
+ cls_, "_sa_decl_prepare_nocascade", strict=True
+ ) or hasattr(cls_, "_sa_decl_prepare")
+
+ if defer_map:
+ cfg_cls = _DeferredMapperConfig
+ else:
+ cfg_cls = _ClassScanMapperConfig
+
+ return cfg_cls(registry, cls_, dict_, table, mapper_kw)
+
+ def __init__(self, registry, cls_, mapper_kw):
+ self.cls = util.assert_arg_type(cls_, type, "cls_")
+ self.classname = cls_.__name__
+ self.properties = util.OrderedDict()
+ self.declared_attr_reg = {}
+
+ if not mapper_kw.get("non_primary", False):
+ instrumentation.register_class(
+ self.cls,
+ finalize=False,
+ registry=registry,
+ declarative_scan=self,
+ init_method=registry.constructor,
+ )
+ else:
+ manager = attributes.manager_of_class(self.cls)
+ if not manager or not manager.is_mapped:
+ raise exc.InvalidRequestError(
+ "Class %s has no primary mapper configured. Configure "
+ "a primary mapper first before setting up a non primary "
+ "Mapper." % self.cls
+ )
+
+ def set_cls_attribute(self, attrname, value):
+
+ manager = instrumentation.manager_of_class(self.cls)
+ manager.install_member(attrname, value)
+ return value
+
+ def _early_mapping(self, mapper_kw):
+ self.map(mapper_kw)
+
+
+class _ImperativeMapperConfig(_MapperConfig):
+ __slots__ = ("dict_", "local_table", "inherits")
+
+ def __init__(
+ self,
+ registry,
+ cls_,
+ table,
+ mapper_kw,
+ ):
+ super(_ImperativeMapperConfig, self).__init__(
+ registry, cls_, mapper_kw
+ )
+
+ self.dict_ = {}
+ self.local_table = self.set_cls_attribute("__table__", table)
+
+ with mapperlib._CONFIGURE_MUTEX:
+ if not mapper_kw.get("non_primary", False):
+ clsregistry.add_class(
+ self.classname, self.cls, registry._class_registry
+ )
+
+ self._setup_inheritance(mapper_kw)
+
+ self._early_mapping(mapper_kw)
+
+ def map(self, mapper_kw=util.EMPTY_DICT):
+ mapper_cls = mapper
+
+ return self.set_cls_attribute(
+ "__mapper__",
+ mapper_cls(self.cls, self.local_table, **mapper_kw),
+ )
+
+ def _setup_inheritance(self, mapper_kw):
+ cls = self.cls
+
+ inherits = mapper_kw.get("inherits", None)
+
+ if inherits is None:
+ # since we search for classical mappings now, search for
+ # multiple mapped bases as well and raise an error.
+ inherits_search = []
+ for c in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(c)
+ if c is None:
+ continue
+ if _declared_mapping_info(
+ c
+ ) is not None and not _get_immediate_cls_attr(
+ c, "_sa_decl_prepare_nocascade", strict=True
+ ):
+ inherits_search.append(c)
+
+ if inherits_search:
+ if len(inherits_search) > 1:
+ raise exc.InvalidRequestError(
+ "Class %s has multiple mapped bases: %r"
+ % (cls, inherits_search)
+ )
+ inherits = inherits_search[0]
+ elif isinstance(inherits, mapper):
+ inherits = inherits.class_
+
+ self.inherits = inherits
+
+
+class _ClassScanMapperConfig(_MapperConfig):
+ __slots__ = (
+ "dict_",
+ "local_table",
+ "persist_selectable",
+ "declared_columns",
+ "column_copies",
+ "table_args",
+ "tablename",
+ "mapper_args",
+ "mapper_args_fn",
+ "inherits",
+ )
+
+ def __init__(
+ self,
+ registry,
+ cls_,
+ dict_,
+ table,
+ mapper_kw,
+ ):
+
+ # grab class dict before the instrumentation manager has been added.
+ # reduces cycles
+ self.dict_ = dict(dict_) if dict_ else {}
+
+ super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw)
+
+ self.persist_selectable = None
+ self.declared_columns = set()
+ self.column_copies = {}
+ self._setup_declared_events()
+
+ self._scan_attributes()
+
+ with mapperlib._CONFIGURE_MUTEX:
+ clsregistry.add_class(
+ self.classname, self.cls, registry._class_registry
+ )
+
+ self._extract_mappable_attributes()
+
+ self._extract_declared_columns()
+
+ self._setup_table(table)
+
+ self._setup_inheritance(mapper_kw)
+
+ self._early_mapping(mapper_kw)
+
+ def _setup_declared_events(self):
+ if _get_immediate_cls_attr(self.cls, "__declare_last__"):
+
+ @event.listens_for(mapper, "after_configured")
+ def after_configured():
+ self.cls.__declare_last__()
+
+ if _get_immediate_cls_attr(self.cls, "__declare_first__"):
+
+ @event.listens_for(mapper, "before_configured")
+ def before_configured():
+ self.cls.__declare_first__()
+
+ def _cls_attr_override_checker(self, cls):
+ """Produce a function that checks if a class has overridden an
+ attribute, taking SQLAlchemy-enabled dataclass fields into account.
+
+ """
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__", None
+ )
+
+ if sa_dataclass_metadata_key is None:
+
+ def attribute_is_overridden(key, obj):
+ return getattr(cls, key) is not obj
+
+ else:
+
+ all_datacls_fields = {
+ f.name: f.metadata[sa_dataclass_metadata_key]
+ for f in util.dataclass_fields(cls)
+ if sa_dataclass_metadata_key in f.metadata
+ }
+ local_datacls_fields = {
+ f.name: f.metadata[sa_dataclass_metadata_key]
+ for f in util.local_dataclass_fields(cls)
+ if sa_dataclass_metadata_key in f.metadata
+ }
+
+ absent = object()
+
+ def attribute_is_overridden(key, obj):
+ if _is_declarative_props(obj):
+ obj = obj.fget
+
+ # this function likely has some failure modes still if
+ # someone is doing a deep mixing of the same attribute
+ # name as plain Python attribute vs. dataclass field.
+
+ ret = local_datacls_fields.get(key, absent)
+ if _is_declarative_props(ret):
+ ret = ret.fget
+
+ if ret is obj:
+ return False
+ elif ret is not absent:
+ return True
+
+ all_field = all_datacls_fields.get(key, absent)
+
+ ret = getattr(cls, key, obj)
+
+ if ret is obj:
+ return False
+
+ # for dataclasses, this could be the
+ # 'default' of the field. so filter more specifically
+ # for an already-mapped InstrumentedAttribute
+ if ret is not absent and isinstance(
+ ret, InstrumentedAttribute
+ ):
+ return True
+
+ if all_field is obj:
+ return False
+ elif all_field is not absent:
+ return True
+
+ # can't find another attribute
+ return False
+
+ return attribute_is_overridden
+
+ def _cls_attr_resolver(self, cls):
+ """produce a function to iterate the "attributes" of a class,
+ adjusting for SQLAlchemy fields embedded in dataclass fields.
+
+ """
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__", None
+ )
+
+ if sa_dataclass_metadata_key is None:
+
+ def local_attributes_for_class():
+ for name, obj in vars(cls).items():
+ yield name, obj, False
+
+ else:
+ field_names = set()
+
+ def local_attributes_for_class():
+ for field in util.local_dataclass_fields(cls):
+ if sa_dataclass_metadata_key in field.metadata:
+ field_names.add(field.name)
+ yield field.name, _as_dc_declaredattr(
+ field.metadata, sa_dataclass_metadata_key
+ ), True
+ for name, obj in vars(cls).items():
+ if name not in field_names:
+ yield name, obj, False
+
+ return local_attributes_for_class
+
+ def _scan_attributes(self):
+ cls = self.cls
+ dict_ = self.dict_
+ column_copies = self.column_copies
+ mapper_args_fn = None
+ table_args = inherited_table_args = None
+ tablename = None
+
+ attribute_is_overridden = self._cls_attr_override_checker(self.cls)
+
+ bases = []
+
+ for base in cls.__mro__:
+ # collect bases and make sure standalone columns are copied
+ # to be the column they will ultimately be on the class,
+ # so that declared_attr functions use the right columns.
+ # need to do this all the way up the hierarchy first
+ # (see #8190)
+
+ class_mapped = (
+ base is not cls
+ and _declared_mapping_info(base) is not None
+ and not _get_immediate_cls_attr(
+ base, "_sa_decl_prepare_nocascade", strict=True
+ )
+ )
+
+ local_attributes_for_class = self._cls_attr_resolver(base)
+
+ if not class_mapped and base is not cls:
+ locally_collected_columns = self._produce_column_copies(
+ local_attributes_for_class,
+ attribute_is_overridden,
+ )
+ else:
+ locally_collected_columns = {}
+
+ bases.append(
+ (
+ base,
+ class_mapped,
+ local_attributes_for_class,
+ locally_collected_columns,
+ )
+ )
+
+ for (
+ base,
+ class_mapped,
+ local_attributes_for_class,
+ locally_collected_columns,
+ ) in bases:
+
+ # this transfer can also take place as we scan each name
+ # for finer-grained control of how collected_attributes is
+ # populated, as this is what impacts column ordering.
+ # however it's simpler to get it out of the way here.
+ dict_.update(locally_collected_columns)
+
+ for name, obj, is_dataclass in local_attributes_for_class():
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (not class_mapped or check_decl):
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ def mapper_args_fn():
+ return dict(cls.__mapper_args__)
+
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
+ tablename = cls.__tablename__
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
+ table_args = cls.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))
+ ):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None"
+ )
+ if base is not cls:
+ inherited_table_args = True
+ elif class_mapped:
+ if _is_declarative_props(obj):
+ util.warn(
+ "Regular (i.e. not __special__) "
+ "attribute '%s.%s' uses @declared_attr, "
+ "but owning class %s is mapped - "
+ "not applying to subclass %s."
+ % (base.__name__, name, base, cls)
+ )
+ continue
+ elif base is not cls:
+ # we're a mixin, abstract base, or something that is
+ # acting like that for now.
+ if isinstance(obj, Column):
+ # already copied columns to the mapped class.
+ continue
+ elif isinstance(obj, MapperProperty):
+ raise exc.InvalidRequestError(
+ "Mapper properties (i.e. deferred,"
+ "column_property(), relationship(), etc.) must "
+ "be declared as @declared_attr callables "
+ "on declarative mixin classes. For dataclass "
+ "field() objects, use a lambda:"
+ )
+ elif _is_declarative_props(obj):
+ if obj._cascading:
+ if name in dict_:
+ # unfortunately, while we can use the user-
+ # defined attribute here to allow a clean
+ # override, if there's another
+ # subclass below then it still tries to use
+ # this. not sure if there is enough
+ # information here to add this as a feature
+ # later on.
+ util.warn(
+ "Attribute '%s' on class %s cannot be "
+ "processed due to "
+ "@declared_attr.cascading; "
+ "skipping" % (name, cls)
+ )
+ dict_[name] = column_copies[
+ obj
+ ] = ret = obj.__get__(obj, cls)
+ setattr(cls, name, ret)
+ else:
+ if is_dataclass:
+ # access attribute using normal class access
+ # first, to see if it's been mapped on a
+ # superclass. note if the dataclasses.field()
+ # has "default", this value can be anything.
+ ret = getattr(cls, name, None)
+
+ # so, if it's anything that's not ORM
+ # mapped, assume we should invoke the
+ # declared_attr
+ if not isinstance(ret, InspectionAttr):
+ ret = obj.fget()
+ else:
+ # access attribute using normal class access.
+ # if the declared attr already took place
+ # on a superclass that is mapped, then
+ # this is no longer a declared_attr, it will
+ # be the InstrumentedAttribute
+ ret = getattr(cls, name)
+
+ # correct for proxies created from hybrid_property
+ # or similar. note there is no known case that
+ # produces nested proxies, so we are only
+ # looking one level deep right now.
+ if (
+ isinstance(ret, InspectionAttr)
+ and ret._is_internal_proxy
+ and not isinstance(
+ ret.original_property, MapperProperty
+ )
+ ):
+ ret = ret.descriptor
+
+ dict_[name] = column_copies[obj] = ret
+ if (
+ isinstance(ret, (Column, MapperProperty))
+ and ret.doc is None
+ ):
+ ret.doc = obj.__doc__
+ # here, the attribute is some other kind of property that
+ # we assume is not part of the declarative mapping.
+ # however, check for some more common mistakes
+ else:
+ self._warn_for_decl_attributes(base, name, obj)
+ elif is_dataclass and (
+ name not in dict_ or dict_[name] is not obj
+ ):
+ # here, we are definitely looking at the target class
+ # and not a superclass. this is currently a
+ # dataclass-only path. if the name is only
+ # a dataclass field and isn't in local cls.__dict__,
+ # put the object there.
+ # assert that the dataclass-enabled resolver agrees
+ # with what we are seeing
+
+ assert not attribute_is_overridden(name, obj)
+
+ if _is_declarative_props(obj):
+ obj = obj.fget()
+
+ dict_[name] = obj
+
+ if inherited_table_args and not tablename:
+ table_args = None
+
+ self.table_args = table_args
+ self.tablename = tablename
+ self.mapper_args_fn = mapper_args_fn
+
+ def _warn_for_decl_attributes(self, cls, key, c):
+ if isinstance(c, expression.ColumnClause):
+ util.warn(
+ "Attribute '%s' on class %s appears to be a non-schema "
+ "'sqlalchemy.sql.column()' "
+ "object; this won't be part of the declarative mapping"
+ % (key, cls)
+ )
+
+ def _produce_column_copies(
+ self, attributes_for_class, attribute_is_overridden
+ ):
+ cls = self.cls
+ dict_ = self.dict_
+ locally_collected_attributes = {}
+ column_copies = self.column_copies
+ # copy mixin columns to the mapped class
+
+ for name, obj, is_dataclass in attributes_for_class():
+ if isinstance(obj, Column):
+ if attribute_is_overridden(name, obj):
+ # if column has been overridden
+ # (like by the InstrumentedAttribute of the
+ # superclass), skip
+ continue
+ elif obj.foreign_keys:
+ raise exc.InvalidRequestError(
+ "Columns with foreign keys to other columns "
+ "must be declared as @declared_attr callables "
+ "on declarative mixin classes. For dataclass "
+                        "field() objects, use a lambda."
+ )
+ elif name not in dict_ and not (
+ "__table__" in dict_
+ and (obj.name or name) in dict_["__table__"].c
+ ):
+ column_copies[obj] = copy_ = obj._copy()
+ copy_._creation_order = obj._creation_order
+ setattr(cls, name, copy_)
+ locally_collected_attributes[name] = copy_
+ return locally_collected_attributes
+
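For readers following the mixin-copy rules just above: a plain Column on a mixin is copied to each mapped subclass by _produce_column_copies(), while a Column carrying a ForeignKey must be produced inside @declared_attr so that a fresh column is created per subclass. A minimal sketch, assuming SQLAlchemy 1.4-style declarative imports (Base, Target, RefTargetMixin and Thing are hypothetical names):

from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.orm import declarative_base, declared_attr

Base = declarative_base()

class Target(Base):
    __tablename__ = "target"
    id = Column(Integer, primary_key=True)

class RefTargetMixin:
    # a plain Column here is simply copied per subclass; a Column with a
    # ForeignKey must be returned from @declared_attr, as enforced by the
    # InvalidRequestError raised in _produce_column_copies()
    @declared_attr
    def target_id(cls):
        return Column(Integer, ForeignKey("target.id"))

class Thing(RefTargetMixin, Base):
    __tablename__ = "thing"
    id = Column(Integer, primary_key=True)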
+ def _extract_mappable_attributes(self):
+ cls = self.cls
+ dict_ = self.dict_
+
+ our_stuff = self.properties
+
+ late_mapped = _get_immediate_cls_attr(
+ cls, "_sa_decl_prepare_nocascade", strict=True
+ )
+
+ for k in list(dict_):
+
+ if k in ("__table__", "__tablename__", "__mapper_args__"):
+ continue
+
+ value = dict_[k]
+ if _is_declarative_props(value):
+ if value._cascading:
+ util.warn(
+ "Use of @declared_attr.cascading only applies to "
+ "Declarative 'mixin' and 'abstract' classes. "
+ "Currently, this flag is ignored on mapped class "
+ "%s" % self.cls
+ )
+
+ value = getattr(cls, k)
+
+ elif (
+ isinstance(value, QueryableAttribute)
+ and value.class_ is not cls
+ and value.key != k
+ ):
+ # detect a QueryableAttribute that's already mapped being
+ # assigned elsewhere in userland, turn into a synonym()
+ value = SynonymProperty(value.key)
+ setattr(cls, k, value)
+
+ if (
+ isinstance(value, tuple)
+ and len(value) == 1
+ and isinstance(value[0], (Column, MapperProperty))
+ ):
+ util.warn(
+ "Ignoring declarative-like tuple value of attribute "
+ "'%s': possibly a copy-and-paste error with a comma "
+ "accidentally placed at the end of the line?" % k
+ )
+ continue
+ elif not isinstance(value, (Column, MapperProperty)):
+ # using @declared_attr for some object that
+ # isn't Column/MapperProperty; remove from the dict_
+ # and place the evaluated value onto the class.
+ if not k.startswith("__"):
+ dict_.pop(k)
+ self._warn_for_decl_attributes(cls, k, value)
+ if not late_mapped:
+ setattr(cls, k, value)
+ continue
+ # we expect to see the name 'metadata' in some valid cases;
+ # however at this point we see it's assigned to something trying
+ # to be mapped, so raise for that.
+ elif k == "metadata":
+ raise exc.InvalidRequestError(
+ "Attribute name 'metadata' is reserved "
+ "for the MetaData instance when using a "
+ "declarative base class."
+ )
+ our_stuff[k] = value
+
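The one-element-tuple check above guards against a common slip: a trailing comma turns a Column into a tuple, which would otherwise be silently dropped from the mapping. A minimal sketch of the mistake (hypothetical model, assuming 1.4-style imports); declaring the class emits the warning and maps without the 'name' column:

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Account(Base):
    __tablename__ = "account"
    id = Column(Integer, primary_key=True)
    name = Column(String(50)),  # stray comma: the value is a 1-tuple, so the
                                # scan warns and skips this attribute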
+ def _extract_declared_columns(self):
+ our_stuff = self.properties
+
+ # set up attributes in the order they were created
+ util.sort_dictionary(
+ our_stuff, key=lambda key: our_stuff[key]._creation_order
+ )
+
+ # extract columns from the class dict
+ declared_columns = self.declared_columns
+ name_to_prop_key = collections.defaultdict(set)
+ for key, c in list(our_stuff.items()):
+ if isinstance(c, (ColumnProperty, CompositeProperty)):
+ for col in c.columns:
+ if isinstance(col, Column) and col.table is None:
+ _undefer_column_name(key, col)
+ if not isinstance(c, CompositeProperty):
+ name_to_prop_key[col.name].add(key)
+ declared_columns.add(col)
+ elif isinstance(c, Column):
+ _undefer_column_name(key, c)
+ name_to_prop_key[c.name].add(key)
+ declared_columns.add(c)
+                # if the column has the same name as the key,
+ # remove it from the explicit properties dict.
+ # the normal rules for assigning column-based properties
+ # will take over, including precedence of columns
+ # in multi-column ColumnProperties.
+ if key == c.key:
+ del our_stuff[key]
+
+ for name, keys in name_to_prop_key.items():
+ if len(keys) > 1:
+ util.warn(
+ "On class %r, Column object %r named "
+ "directly multiple times, "
+ "only one will be used: %s. "
+ "Consider using orm.synonym instead"
+ % (self.classname, name, (", ".join(sorted(keys))))
+ )
+
+ def _setup_table(self, table=None):
+ cls = self.cls
+ tablename = self.tablename
+ table_args = self.table_args
+ dict_ = self.dict_
+ declared_columns = self.declared_columns
+
+ manager = attributes.manager_of_class(cls)
+
+ declared_columns = self.declared_columns = sorted(
+ declared_columns, key=lambda c: c._creation_order
+ )
+
+ if "__table__" not in dict_ and table is None:
+ if hasattr(cls, "__table_cls__"):
+ table_cls = util.unbound_method_to_callable(cls.__table_cls__)
+ else:
+ table_cls = Table
+
+ if tablename is not None:
+
+ args, table_kw = (), {}
+ if table_args:
+ if isinstance(table_args, dict):
+ table_kw = table_args
+ elif isinstance(table_args, tuple):
+ if isinstance(table_args[-1], dict):
+ args, table_kw = table_args[0:-1], table_args[-1]
+ else:
+ args = table_args
+
+ autoload_with = dict_.get("__autoload_with__")
+ if autoload_with:
+ table_kw["autoload_with"] = autoload_with
+
+ autoload = dict_.get("__autoload__")
+ if autoload:
+ table_kw["autoload"] = True
+
+ table = self.set_cls_attribute(
+ "__table__",
+ table_cls(
+ tablename,
+ self._metadata_for_cls(manager),
+ *(tuple(declared_columns) + tuple(args)),
+ **table_kw
+ ),
+ )
+ else:
+ if table is None:
+ table = cls.__table__
+ if declared_columns:
+ for c in declared_columns:
+ if not table.c.contains_column(c):
+ raise exc.ArgumentError(
+ "Can't add additional column %r when "
+ "specifying __table__" % c.key
+ )
+ self.local_table = table
+
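For reference, the two __table_args__ shapes unpacked above: a dict of Table keyword arguments, or a tuple of positional arguments optionally ending in such a dict. A minimal sketch (hypothetical model, assuming 1.4-style imports):

from sqlalchemy import Column, Integer, String, UniqueConstraint
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Widget(Base):
    __tablename__ = "widget"
    # positional Table arguments first, with an optional trailing dict of
    # keyword arguments; a plain dict such as {"sqlite_autoincrement": True}
    # is also accepted on its own
    __table_args__ = (
        UniqueConstraint("name"),
        {"sqlite_autoincrement": True},
    )
    id = Column(Integer, primary_key=True)
    name = Column(String(50))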
+ def _metadata_for_cls(self, manager):
+ if hasattr(self.cls, "metadata"):
+ return self.cls.metadata
+ else:
+ return manager.registry.metadata
+
+ def _setup_inheritance(self, mapper_kw):
+ table = self.local_table
+ cls = self.cls
+ table_args = self.table_args
+ declared_columns = self.declared_columns
+
+ inherits = mapper_kw.get("inherits", None)
+
+ if inherits is None:
+ # since we search for classical mappings now, search for
+ # multiple mapped bases as well and raise an error.
+ inherits_search = []
+ for c in cls.__bases__:
+ c = _resolve_for_abstract_or_classical(c)
+ if c is None:
+ continue
+ if _declared_mapping_info(
+ c
+ ) is not None and not _get_immediate_cls_attr(
+ c, "_sa_decl_prepare_nocascade", strict=True
+ ):
+ if c not in inherits_search:
+ inherits_search.append(c)
+
+ if inherits_search:
+ if len(inherits_search) > 1:
+ raise exc.InvalidRequestError(
+ "Class %s has multiple mapped bases: %r"
+ % (cls, inherits_search)
+ )
+ inherits = inherits_search[0]
+ elif isinstance(inherits, mapper):
+ inherits = inherits.class_
+
+ self.inherits = inherits
+
+ if (
+ table is None
+ and self.inherits is None
+ and not _get_immediate_cls_attr(cls, "__no_table__")
+ ):
+
+ raise exc.InvalidRequestError(
+ "Class %r does not have a __table__ or __tablename__ "
+ "specified and does not inherit from an existing "
+ "table-mapped class." % cls
+ )
+ elif self.inherits:
+ inherited_mapper = _declared_mapping_info(self.inherits)
+ inherited_table = inherited_mapper.local_table
+ inherited_persist_selectable = inherited_mapper.persist_selectable
+
+ if table is None:
+ # single table inheritance.
+ # ensure no table args
+ if table_args:
+ raise exc.ArgumentError(
+ "Can't place __table_args__ on an inherited class "
+ "with no table."
+ )
+ # add any columns declared here to the inherited table.
+ for c in declared_columns:
+ if c.name in inherited_table.c:
+ if inherited_table.c[c.name] is c:
+ continue
+ raise exc.ArgumentError(
+ "Column '%s' on class %s conflicts with "
+ "existing column '%s'"
+ % (c, cls, inherited_table.c[c.name])
+ )
+ if c.primary_key:
+ raise exc.ArgumentError(
+ "Can't place primary key columns on an inherited "
+ "class with no table."
+ )
+ inherited_table.append_column(c)
+ if (
+ inherited_persist_selectable is not None
+ and inherited_persist_selectable is not inherited_table
+ ):
+ inherited_persist_selectable._refresh_for_new_column(c)
+
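The single-table-inheritance branch above is the path taken by a subclass that declares extra columns but no table of its own; those columns are appended to the parent's table, and primary key columns or __table_args__ on such a subclass are rejected. A minimal sketch (hypothetical classes, assuming 1.4-style imports):

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Employee(Base):
    __tablename__ = "employee"
    id = Column(Integer, primary_key=True)
    type = Column(String(20))
    __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "employee"}

class Engineer(Employee):
    # no __tablename__ / __table__: single table inheritance.  this column
    # is appended to the parent "employee" table by _setup_inheritance()
    engineer_info = Column(String(50))
    __mapper_args__ = {"polymorphic_identity": "engineer"}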
+ def _prepare_mapper_arguments(self, mapper_kw):
+ properties = self.properties
+
+ if self.mapper_args_fn:
+ mapper_args = self.mapper_args_fn()
+ else:
+ mapper_args = {}
+
+ if mapper_kw:
+ mapper_args.update(mapper_kw)
+
+ if "properties" in mapper_args:
+ properties = dict(properties)
+ properties.update(mapper_args["properties"])
+
+ # make sure that column copies are used rather
+ # than the original columns from any mixins
+ for k in ("version_id_col", "polymorphic_on"):
+ if k in mapper_args:
+ v = mapper_args[k]
+ mapper_args[k] = self.column_copies.get(v, v)
+
+ if "inherits" in mapper_args:
+ inherits_arg = mapper_args["inherits"]
+ if isinstance(inherits_arg, mapper):
+ inherits_arg = inherits_arg.class_
+
+ if inherits_arg is not self.inherits:
+ raise exc.InvalidRequestError(
+ "mapper inherits argument given for non-inheriting "
+ "class %s" % (mapper_args["inherits"])
+ )
+
+ if self.inherits:
+ mapper_args["inherits"] = self.inherits
+
+ if self.inherits and not mapper_args.get("concrete", False):
+ # single or joined inheritance
+ # exclude any cols on the inherited table which are
+ # not mapped on the parent class, to avoid
+ # mapping columns specific to sibling/nephew classes
+ inherited_mapper = _declared_mapping_info(self.inherits)
+ inherited_table = inherited_mapper.local_table
+
+ if "exclude_properties" not in mapper_args:
+ mapper_args["exclude_properties"] = exclude_properties = set(
+ [
+ c.key
+ for c in inherited_table.c
+ if c not in inherited_mapper._columntoproperty
+ ]
+ ).union(inherited_mapper.exclude_properties or ())
+ exclude_properties.difference_update(
+ [c.key for c in self.declared_columns]
+ )
+
+ # look through columns in the current mapper that
+ # are keyed to a propname different than the colname
+ # (if names were the same, we'd have popped it out above,
+ # in which case the mapper makes this combination).
+ # See if the superclass has a similar column property.
+ # If so, join them together.
+ for k, col in list(properties.items()):
+ if not isinstance(col, expression.ColumnElement):
+ continue
+ if k in inherited_mapper._props:
+ p = inherited_mapper._props[k]
+ if isinstance(p, ColumnProperty):
+ # note here we place the subclass column
+ # first. See [ticket:1892] for background.
+ properties[k] = [col] + p.columns
+ result_mapper_args = mapper_args.copy()
+ result_mapper_args["properties"] = properties
+ self.mapper_args = result_mapper_args
+
+ def map(self, mapper_kw=util.EMPTY_DICT):
+ self._prepare_mapper_arguments(mapper_kw)
+ if hasattr(self.cls, "__mapper_cls__"):
+ mapper_cls = util.unbound_method_to_callable(
+ self.cls.__mapper_cls__
+ )
+ else:
+ mapper_cls = mapper
+
+ return self.set_cls_attribute(
+ "__mapper__",
+ mapper_cls(self.cls, self.local_table, **self.mapper_args),
+ )
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
+    # wrap lambdas inside dataclass fields in an ad-hoc declared_attr.
+    # we can't write it back because field.metadata is immutable :( so we
+    # have to go through extra trouble to compare these
+ decl_api = util.preloaded.orm_decl_api
+ obj = field_metadata[sa_dataclass_metadata_key]
+ if callable(obj) and not isinstance(obj, decl_api.declared_attr):
+ return decl_api.declared_attr(obj)
+ else:
+ return obj
+
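To illustrate what _as_dc_declaredattr() wraps: in the 1.4 dataclass-mapping support, a dataclasses.field() carries its Column, or a lambda producing one, in its metadata under the configured key; lambdas are promoted to ad-hoc declared_attr objects so they are evaluated at mapping time. A minimal sketch (hypothetical class, assuming the registry.mapped dataclass integration):

from dataclasses import dataclass, field

from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.orm import registry

mapper_registry = registry()

@mapper_registry.mapped
@dataclass
class Address:
    __tablename__ = "address"
    __sa_dataclass_metadata_key__ = "sa"

    id: int = field(init=False, metadata={"sa": Column(Integer, primary_key=True)})
    # a lambda is wrapped in an ad-hoc declared_attr by _as_dc_declaredattr(),
    # deferring evaluation of the Column until the class is mapped
    user_id: int = field(
        init=False, metadata={"sa": lambda: Column(Integer, ForeignKey("user.id"))}
    )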
+
+class _DeferredMapperConfig(_ClassScanMapperConfig):
+ _configs = util.OrderedDict()
+
+ def _early_mapping(self, mapper_kw):
+ pass
+
+ @property
+ def cls(self):
+ return self._cls()
+
+ @cls.setter
+ def cls(self, class_):
+ self._cls = weakref.ref(class_, self._remove_config_cls)
+ self._configs[self._cls] = self
+
+ @classmethod
+ def _remove_config_cls(cls, ref):
+ cls._configs.pop(ref, None)
+
+ @classmethod
+ def has_cls(cls, class_):
+ # 2.6 fails on weakref if class_ is an old style class
+ return isinstance(class_, type) and weakref.ref(class_) in cls._configs
+
+ @classmethod
+ def raise_unmapped_for_cls(cls, class_):
+ if hasattr(class_, "_sa_raise_deferred_config"):
+ class_._sa_raise_deferred_config()
+
+ raise orm_exc.UnmappedClassError(
+ class_,
+ msg="Class %s has a deferred mapping on it. It is not yet "
+ "usable as a mapped class." % orm_exc._safe_cls_name(class_),
+ )
+
+ @classmethod
+ def config_for_cls(cls, class_):
+ return cls._configs[weakref.ref(class_)]
+
+ @classmethod
+ def classes_for_base(cls, base_cls, sort=True):
+ classes_for_base = [
+ m
+ for m, cls_ in [(m, m.cls) for m in cls._configs.values()]
+ if cls_ is not None and issubclass(cls_, base_cls)
+ ]
+
+ if not sort:
+ return classes_for_base
+
+ all_m_by_cls = dict((m.cls, m) for m in classes_for_base)
+
+ tuples = []
+ for m_cls in all_m_by_cls:
+ tuples.extend(
+ (all_m_by_cls[base_cls], all_m_by_cls[m_cls])
+ for base_cls in m_cls.__bases__
+ if base_cls in all_m_by_cls
+ )
+ return list(topological.sort(tuples, classes_for_base))
+
+ def map(self, mapper_kw=util.EMPTY_DICT):
+ self._configs.pop(self._cls, None)
+ return super(_DeferredMapperConfig, self).map(mapper_kw)
+
+
+def _add_attribute(cls, key, value):
+ """add an attribute to an existing declarative class.
+
+ This runs through the logic to determine MapperProperty,
+ adds it to the Mapper, adds a column to the mapped Table, etc.
+
+ """
+
+ if "__mapper__" in cls.__dict__:
+ if isinstance(value, Column):
+ _undefer_column_name(key, value)
+ cls.__table__.append_column(value, replace_existing=True)
+ cls.__mapper__.add_property(key, value)
+ elif isinstance(value, ColumnProperty):
+ for col in value.columns:
+ if isinstance(col, Column) and col.table is None:
+ _undefer_column_name(key, col)
+ cls.__table__.append_column(col, replace_existing=True)
+ cls.__mapper__.add_property(key, value)
+ elif isinstance(value, MapperProperty):
+ cls.__mapper__.add_property(key, value)
+ elif isinstance(value, QueryableAttribute) and value.key != key:
+ # detect a QueryableAttribute that's already mapped being
+ # assigned elsewhere in userland, turn into a synonym()
+ value = SynonymProperty(value.key)
+ cls.__mapper__.add_property(key, value)
+ else:
+ type.__setattr__(cls, key, value)
+ cls.__mapper__._expire_memoizations()
+ else:
+ type.__setattr__(cls, key, value)
+
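As a usage sketch of the path above: assigning a new Column to an already-mapped declarative class (via the declarative metaclass __setattr__) runs through _add_attribute(), which appends the column to the existing Table and adds a mapper property for it. Hypothetical model, assuming 1.4-style imports:

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)

# routed through _add_attribute(): the column is appended to User.__table__
# and a corresponding property is added to User.__mapper__
User.nickname = Column(String(50))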
+
+def _del_attribute(cls, key):
+
+ if (
+ "__mapper__" in cls.__dict__
+ and key in cls.__dict__
+ and not cls.__mapper__._dispose_called
+ ):
+ value = cls.__dict__[key]
+ if isinstance(
+ value, (Column, ColumnProperty, MapperProperty, QueryableAttribute)
+ ):
+ raise NotImplementedError(
+ "Can't un-map individual mapped attributes on a mapped class."
+ )
+ else:
+ type.__delattr__(cls, key)
+ cls.__mapper__._expire_memoizations()
+ else:
+ type.__delattr__(cls, key)
+
+
+def _declarative_constructor(self, **kwargs):
+ """A simple constructor that allows initialization from kwargs.
+
+ Sets attributes on the constructed instance using the names and
+ values in ``kwargs``.
+
+ Only keys that are present as
+ attributes of the instance's class are allowed. These could be,
+ for example, any mapped columns or relationships.
+ """
+ cls_ = type(self)
+ for k in kwargs:
+ if not hasattr(cls_, k):
+ raise TypeError(
+ "%r is an invalid keyword argument for %s" % (k, cls_.__name__)
+ )
+ setattr(self, k, kwargs[k])
+
+
+_declarative_constructor.__name__ = "__init__"
+
+
+def _undefer_column_name(key, column):
+ if column.key is None:
+ column.key = key
+ if column.name is None:
+ column.name = key
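The _declarative_constructor defined above is what declarative_base() installs as __init__ by default; only names that already exist as attributes of the class are accepted. A minimal sketch (hypothetical model, assuming 1.4-style imports):

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()  # installs _declarative_constructor as __init__

class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))

user = User(name="spongebob")   # ok: 'name' is an attribute of the class
try:
    User(nickname="sb")   # not a mapped attribute
except TypeError as err:
    print(err)            # "'nickname' is an invalid keyword argument for User"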
diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
new file mode 100644
index 0000000..1b5be9a
--- /dev/null
+++ b/lib/sqlalchemy/orm/dependency.py
@@ -0,0 +1,1290 @@
+# orm/dependency.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Relationship dependencies.
+
+"""
+
+from . import attributes
+from . import exc
+from . import sync
+from . import unitofwork
+from . import util as mapperutil
+from .interfaces import MANYTOMANY
+from .interfaces import MANYTOONE
+from .interfaces import ONETOMANY
+from .. import exc as sa_exc
+from .. import sql
+from .. import util
+
+
+class DependencyProcessor(object):
+ def __init__(self, prop):
+ self.prop = prop
+ self.cascade = prop.cascade
+ self.mapper = prop.mapper
+ self.parent = prop.parent
+ self.secondary = prop.secondary
+ self.direction = prop.direction
+ self.post_update = prop.post_update
+ self.passive_deletes = prop.passive_deletes
+ self.passive_updates = prop.passive_updates
+ self.enable_typechecks = prop.enable_typechecks
+ if self.passive_deletes:
+ self._passive_delete_flag = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ self._passive_delete_flag = attributes.PASSIVE_OFF
+ if self.passive_updates:
+ self._passive_update_flag = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ self._passive_update_flag = attributes.PASSIVE_OFF
+
+ self.sort_key = "%s_%s" % (self.parent._sort_key, prop.key)
+ self.key = prop.key
+ if not self.prop.synchronize_pairs:
+ raise sa_exc.ArgumentError(
+ "Can't build a DependencyProcessor for relationship %s. "
+ "No target attributes to populate between parent and "
+ "child are present" % self.prop
+ )
+
+ @classmethod
+ def from_relationship(cls, prop):
+ return _direction_to_processor[prop.direction](prop)
+
+ def hasparent(self, state):
+ """return True if the given object instance has a parent,
+ according to the ``InstrumentedAttribute`` handled by this
+ ``DependencyProcessor``.
+
+ """
+ return self.parent.class_manager.get_impl(self.key).hasparent(state)
+
+ def per_property_preprocessors(self, uow):
+ """establish actions and dependencies related to a flush.
+
+ These actions will operate on all relevant states in
+ the aggregate.
+
+ """
+ uow.register_preprocessor(self, True)
+
+ def per_property_flush_actions(self, uow):
+ after_save = unitofwork.ProcessAll(uow, self, False, True)
+ before_delete = unitofwork.ProcessAll(uow, self, True, True)
+
+ parent_saves = unitofwork.SaveUpdateAll(
+ uow, self.parent.primary_base_mapper
+ )
+ child_saves = unitofwork.SaveUpdateAll(
+ uow, self.mapper.primary_base_mapper
+ )
+
+ parent_deletes = unitofwork.DeleteAll(
+ uow, self.parent.primary_base_mapper
+ )
+ child_deletes = unitofwork.DeleteAll(
+ uow, self.mapper.primary_base_mapper
+ )
+
+ self.per_property_dependencies(
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ )
+
+ def per_state_flush_actions(self, uow, states, isdelete):
+ """establish actions and dependencies related to a flush.
+
+ These actions will operate on all relevant states
+ individually. This occurs only if there are cycles
+ in the 'aggregated' version of events.
+
+ """
+
+ child_base_mapper = self.mapper.primary_base_mapper
+ child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper)
+ child_deletes = unitofwork.DeleteAll(uow, child_base_mapper)
+
+ # locate and disable the aggregate processors
+ # for this dependency
+
+ if isdelete:
+ before_delete = unitofwork.ProcessAll(uow, self, True, True)
+ before_delete.disabled = True
+ else:
+ after_save = unitofwork.ProcessAll(uow, self, False, True)
+ after_save.disabled = True
+
+ # check if the "child" side is part of the cycle
+
+ if child_saves not in uow.cycles:
+ # based on the current dependencies we use, the saves/
+ # deletes should always be in the 'cycles' collection
+ # together. if this changes, we will have to break up
+ # this method a bit more.
+ assert child_deletes not in uow.cycles
+
+ # child side is not part of the cycle, so we will link per-state
+ # actions to the aggregate "saves", "deletes" actions
+ child_actions = [(child_saves, False), (child_deletes, True)]
+ child_in_cycles = False
+ else:
+ child_in_cycles = True
+
+ # check if the "parent" side is part of the cycle
+ if not isdelete:
+ parent_saves = unitofwork.SaveUpdateAll(
+ uow, self.parent.base_mapper
+ )
+ parent_deletes = before_delete = None
+ if parent_saves in uow.cycles:
+ parent_in_cycles = True
+ else:
+ parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper)
+ parent_saves = after_save = None
+ if parent_deletes in uow.cycles:
+ parent_in_cycles = True
+
+        # now create actions / dependencies for each state.
+
+ for state in states:
+ # detect if there's anything changed or loaded
+ # by a preprocessor on this state/attribute. In the
+ # case of deletes we may try to load missing items here as well.
+ sum_ = state.manager[self.key].impl.get_all_pending(
+ state,
+ state.dict,
+ self._passive_delete_flag
+ if isdelete
+ else attributes.PASSIVE_NO_INITIALIZE,
+ )
+
+ if not sum_:
+ continue
+
+ if isdelete:
+ before_delete = unitofwork.ProcessState(uow, self, True, state)
+ if parent_in_cycles:
+ parent_deletes = unitofwork.DeleteState(uow, state)
+ else:
+ after_save = unitofwork.ProcessState(uow, self, False, state)
+ if parent_in_cycles:
+ parent_saves = unitofwork.SaveUpdateState(uow, state)
+
+ if child_in_cycles:
+ child_actions = []
+ for child_state, child in sum_:
+ if child_state not in uow.states:
+ child_action = (None, None)
+ else:
+ (deleted, listonly) = uow.states[child_state]
+ if deleted:
+ child_action = (
+ unitofwork.DeleteState(uow, child_state),
+ True,
+ )
+ else:
+ child_action = (
+ unitofwork.SaveUpdateState(uow, child_state),
+ False,
+ )
+ child_actions.append(child_action)
+
+ # establish dependencies between our possibly per-state
+ # parent action and our possibly per-state child action.
+ for child_action, childisdelete in child_actions:
+ self.per_state_dependencies(
+ uow,
+ parent_saves,
+ parent_deletes,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ )
+
+ def presort_deletes(self, uowcommit, states):
+ return False
+
+ def presort_saves(self, uowcommit, states):
+ return False
+
+ def process_deletes(self, uowcommit, states):
+ pass
+
+ def process_saves(self, uowcommit, states):
+ pass
+
+ def prop_has_changes(self, uowcommit, states, isdelete):
+ if not isdelete or self.passive_deletes:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ elif self.direction is MANYTOONE:
+            # here, we were hoping to optimize away the many-to-one history
+            # fetch and ignore it if there are no further cascades to take
+            # place. however, too many less common conditions still apply,
+            # and tests in test_relationships / test_cascade etc. will
+            # still fail.
+ passive = attributes.PASSIVE_NO_FETCH_RELATED
+ else:
+ passive = attributes.PASSIVE_OFF
+
+ for s in states:
+            # TODO: add a high speed method to InstanceState which returns
+            # whether the attribute has a non-None value, or had one
+ history = uowcommit.get_attribute_history(s, self.key, passive)
+ if history and not history.empty():
+ return True
+ else:
+ return (
+ states
+ and not self.prop._is_self_referential
+ and self.mapper in uowcommit.mappers
+ )
+
+ def _verify_canload(self, state):
+ if self.prop.uselist and state is None:
+ raise exc.FlushError(
+ "Can't flush None value found in "
+ "collection %s" % (self.prop,)
+ )
+ elif state is not None and not self.mapper._canload(
+ state, allow_subtypes=not self.enable_typechecks
+ ):
+ if self.mapper._canload(state, allow_subtypes=True):
+ raise exc.FlushError(
+ "Attempting to flush an item of type "
+ "%(x)s as a member of collection "
+ '"%(y)s". Expected an object of type '
+ "%(z)s or a polymorphic subclass of "
+ "this type. If %(x)s is a subclass of "
+ '%(z)s, configure mapper "%(zm)s" to '
+ "load this subtype polymorphically, or "
+ "set enable_typechecks=False to allow "
+ "any subtype to be accepted for flush. "
+ % {
+ "x": state.class_,
+ "y": self.prop,
+ "z": self.mapper.class_,
+ "zm": self.mapper,
+ }
+ )
+ else:
+ raise exc.FlushError(
+ "Attempting to flush an item of type "
+ "%(x)s as a member of collection "
+ '"%(y)s". Expected an object of type '
+ "%(z)s or a polymorphic subclass of "
+ "this type."
+ % {
+ "x": state.class_,
+ "y": self.prop,
+ "z": self.mapper.class_,
+ }
+ )
+
+ def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
+ raise NotImplementedError()
+
+ def _get_reversed_processed_set(self, uow):
+ if not self.prop._reverse_property:
+ return None
+
+ process_key = tuple(
+ sorted([self.key] + [p.key for p in self.prop._reverse_property])
+ )
+ return uow.memo(("reverse_key", process_key), set)
+
+ def _post_update(self, state, uowcommit, related, is_m2o_delete=False):
+ for x in related:
+ if not is_m2o_delete or x is not None:
+ uowcommit.register_post_update(
+ state, [r for l, r in self.prop.synchronize_pairs]
+ )
+ break
+
+ def _pks_changed(self, uowcommit, state):
+ raise NotImplementedError()
+
+ def __repr__(self):
+ return "%s(%s)" % (self.__class__.__name__, self.prop)
+
+
+class OneToManyDP(DependencyProcessor):
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
+ if self.post_update:
+ child_post_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, False
+ )
+ child_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (parent_saves, after_save),
+ (after_save, child_post_updates),
+ (before_delete, child_pre_updates),
+ (child_pre_updates, parent_deletes),
+ (child_pre_updates, child_deletes),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (parent_saves, after_save),
+ (after_save, child_saves),
+ (after_save, child_deletes),
+ (child_saves, parent_deletes),
+ (child_deletes, parent_deletes),
+ (before_delete, child_saves),
+ (before_delete, child_deletes),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
+
+ if self.post_update:
+
+ child_post_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, False
+ )
+ child_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.mapper.primary_base_mapper, True
+ )
+
+ # TODO: this whole block is not covered
+ # by any tests
+ if not isdelete:
+ if childisdelete:
+ uow.dependencies.update(
+ [
+ (child_action, after_save),
+ (after_save, child_post_updates),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (child_action, after_save),
+ (after_save, child_post_updates),
+ ]
+ )
+ else:
+ if childisdelete:
+ uow.dependencies.update(
+ [
+ (before_delete, child_pre_updates),
+ (child_pre_updates, delete_parent),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (before_delete, child_pre_updates),
+ (child_pre_updates, delete_parent),
+ ]
+ )
+ elif not isdelete:
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (after_save, child_action),
+ (save_parent, child_action),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [(before_delete, child_action), (child_action, delete_parent)]
+ )
+
+ def presort_deletes(self, uowcommit, states):
+        # the head object is being deleted, and we manage its list of
+        # child objects; the child objects have to have their
+        # foreign key to the parent set to NULL
+ should_null_fks = (
+ not self.cascade.delete and not self.passive_deletes == "all"
+ )
+
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.deleted:
+ if child is not None and self.hasparent(child) is False:
+ if self.cascade.delete_orphan:
+ uowcommit.register_object(child, isdelete=True)
+ else:
+ uowcommit.register_object(child)
+
+ if should_null_fks:
+ for child in history.unchanged:
+ if child is not None:
+ uowcommit.register_object(
+ child, operation="delete", prop=self.prop
+ )
+
+ def presort_saves(self, uowcommit, states):
+ children_added = uowcommit.memo(("children_added", self), set)
+
+ should_null_fks = (
+ not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ )
+
+ for state in states:
+ pks_changed = self._pks_changed(uowcommit, state)
+
+ if not pks_changed or self.passive_updates:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ passive = attributes.PASSIVE_OFF
+
+ history = uowcommit.get_attribute_history(state, self.key, passive)
+ if history:
+ for child in history.added:
+ if child is not None:
+ uowcommit.register_object(
+ child,
+ cancel_delete=True,
+ operation="add",
+ prop=self.prop,
+ )
+
+ children_added.update(history.added)
+
+ for child in history.deleted:
+ if not self.cascade.delete_orphan:
+ if should_null_fks:
+ uowcommit.register_object(
+ child,
+ isdelete=False,
+ operation="delete",
+ prop=self.prop,
+ )
+ elif self.hasparent(child) is False:
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
+ "delete", child
+ ):
+ uowcommit.register_object(st_, isdelete=True)
+
+ if pks_changed:
+ if history:
+ for child in history.unchanged:
+ if child is not None:
+ uowcommit.register_object(
+ child,
+ False,
+ self.passive_updates,
+ operation="pk change",
+ prop=self.prop,
+ )
+
+ def process_deletes(self, uowcommit, states):
+        # the head object is being deleted, and we manage its list of
+        # child objects; the child objects have to have their foreign
+        # key to the parent set to NULL. this phase can be called
+        # safely for any cascade but is unnecessary if delete cascade
+        # is on.
+
+ if self.post_update or not self.passive_deletes == "all":
+ children_added = uowcommit.memo(("children_added", self), set)
+
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.deleted:
+ if (
+ child is not None
+ and self.hasparent(child) is False
+ ):
+ self._synchronize(
+ state, child, None, True, uowcommit, False
+ )
+ if self.post_update and child:
+ self._post_update(child, uowcommit, [state])
+
+ if self.post_update or not self.cascade.delete:
+ for child in set(history.unchanged).difference(
+ children_added
+ ):
+ if child is not None:
+ self._synchronize(
+ state, child, None, True, uowcommit, False
+ )
+ if self.post_update and child:
+ self._post_update(
+ child, uowcommit, [state]
+ )
+
+ # technically, we can even remove each child from the
+ # collection here too. but this would be a somewhat
+ # inconsistent behavior since it wouldn't happen
+ # if the old parent wasn't deleted but child was moved.
+
+ def process_saves(self, uowcommit, states):
+ should_null_fks = (
+ not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ )
+
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history:
+ for child in history.added:
+ self._synchronize(
+ state, child, None, False, uowcommit, False
+ )
+ if child is not None and self.post_update:
+ self._post_update(child, uowcommit, [state])
+
+ for child in history.deleted:
+ if (
+ should_null_fks
+ and not self.cascade.delete_orphan
+ and not self.hasparent(child)
+ ):
+ self._synchronize(
+ state, child, None, True, uowcommit, False
+ )
+
+ if self._pks_changed(uowcommit, state):
+ for child in history.unchanged:
+ self._synchronize(
+ state, child, None, False, uowcommit, True
+ )
+
+ def _synchronize(
+ self, state, child, associationrow, clearkeys, uowcommit, pks_changed
+ ):
+ source = state
+ dest = child
+ self._verify_canload(child)
+ if dest is None or (
+ not self.post_update and uowcommit.is_deleted(dest)
+ ):
+ return
+ if clearkeys:
+ sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
+ else:
+ sync.populate(
+ source,
+ self.parent,
+ dest,
+ self.mapper,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ self.passive_updates and pks_changed,
+ )
+
+ def _pks_changed(self, uowcommit, state):
+ return sync.source_modified(
+ uowcommit, state, self.parent, self.prop.synchronize_pairs
+ )
+
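As a behavioral sketch of OneToManyDP's delete handling above: with no delete cascade configured, deleting the parent leaves the child rows in place and nulls out their foreign key (driven by process_deletes() calling _synchronize() with clearkeys=True). Hypothetical models, assuming 1.4-style imports and an in-memory SQLite database:

from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

class Parent(Base):
    __tablename__ = "parent"
    id = Column(Integer, primary_key=True)
    children = relationship("Child")   # no delete / delete-orphan cascade

class Child(Base):
    __tablename__ = "child"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("parent.id"))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Parent(id=1, children=[Child(id=1)]))
    session.flush()

    session.delete(session.get(Parent, 1))
    session.flush()
    # the child row survives; its foreign key has been set to NULL
    print(session.get(Child, 1).parent_id)   # None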
+
+class ManyToOneDP(DependencyProcessor):
+ def __init__(self, prop):
+ DependencyProcessor.__init__(self, prop)
+ for mapper in self.mapper.self_and_descendants:
+ mapper._dependency_processors.append(DetectKeySwitch(prop))
+
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
+
+ if self.post_update:
+ parent_post_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, False
+ )
+ parent_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (parent_saves, after_save),
+ (after_save, parent_post_updates),
+ (after_save, parent_pre_updates),
+ (before_delete, parent_pre_updates),
+ (parent_pre_updates, child_deletes),
+ (parent_pre_updates, parent_deletes),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (child_saves, after_save),
+ (after_save, parent_saves),
+ (parent_saves, child_deletes),
+ (parent_deletes, child_deletes),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
+
+ if self.post_update:
+
+ if not isdelete:
+ parent_post_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, False
+ )
+ if childisdelete:
+ uow.dependencies.update(
+ [
+ (after_save, parent_post_updates),
+ (parent_post_updates, child_action),
+ ]
+ )
+ else:
+ uow.dependencies.update(
+ [
+ (save_parent, after_save),
+ (child_action, after_save),
+ (after_save, parent_post_updates),
+ ]
+ )
+ else:
+ parent_pre_updates = unitofwork.PostUpdateAll(
+ uow, self.parent.primary_base_mapper, True
+ )
+
+ uow.dependencies.update(
+ [
+ (before_delete, parent_pre_updates),
+ (parent_pre_updates, delete_parent),
+ (parent_pre_updates, child_action),
+ ]
+ )
+
+ elif not isdelete:
+ if not childisdelete:
+ uow.dependencies.update(
+ [(child_action, after_save), (after_save, save_parent)]
+ )
+ else:
+ uow.dependencies.update([(after_save, save_parent)])
+
+ else:
+ if childisdelete:
+ uow.dependencies.update([(delete_parent, child_action)])
+
+ def presort_deletes(self, uowcommit, states):
+ if self.cascade.delete or self.cascade.delete_orphan:
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ if self.cascade.delete_orphan:
+ todelete = history.sum()
+ else:
+ todelete = history.non_deleted()
+ for child in todelete:
+ if child is None:
+ continue
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+ t = self.mapper.cascade_iterator("delete", child)
+ for c, m, st_, dct_ in t:
+ uowcommit.register_object(st_, isdelete=True)
+
+ def presort_saves(self, uowcommit, states):
+ for state in states:
+ uowcommit.register_object(state, operation="add", prop=self.prop)
+ if self.cascade.delete_orphan:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.deleted:
+ if self.hasparent(child) is False:
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+
+ t = self.mapper.cascade_iterator("delete", child)
+ for c, m, st_, dct_ in t:
+ uowcommit.register_object(st_, isdelete=True)
+
+ def process_deletes(self, uowcommit, states):
+ if (
+ self.post_update
+ and not self.cascade.delete_orphan
+ and not self.passive_deletes == "all"
+ ):
+
+ # post_update means we have to update our
+ # row to not reference the child object
+ # before we can DELETE the row
+ for state in states:
+ self._synchronize(state, None, None, True, uowcommit)
+ if state and self.post_update:
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ self._post_update(
+ state, uowcommit, history.sum(), is_m2o_delete=True
+ )
+
+ def process_saves(self, uowcommit, states):
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history:
+ if history.added:
+ for child in history.added:
+ self._synchronize(
+ state, child, None, False, uowcommit, "add"
+ )
+ elif history.deleted:
+ self._synchronize(
+ state, None, None, True, uowcommit, "delete"
+ )
+ if self.post_update:
+ self._post_update(state, uowcommit, history.sum())
+
+ def _synchronize(
+ self,
+ state,
+ child,
+ associationrow,
+ clearkeys,
+ uowcommit,
+ operation=None,
+ ):
+ if state is None or (
+ not self.post_update and uowcommit.is_deleted(state)
+ ):
+ return
+
+ if (
+ operation is not None
+ and child is not None
+ and not uowcommit.session._contains_state(child)
+ ):
+ util.warn(
+ "Object of type %s not in session, %s "
+ "operation along '%s' won't proceed"
+ % (mapperutil.state_class_str(child), operation, self.prop)
+ )
+ return
+
+ if clearkeys or child is None:
+ sync.clear(state, self.parent, self.prop.synchronize_pairs)
+ else:
+ self._verify_canload(child)
+ sync.populate(
+ child,
+ self.mapper,
+ state,
+ self.parent,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ False,
+ )
+
+
+class DetectKeySwitch(DependencyProcessor):
+ """For many-to-one relationships with no one-to-many backref,
+ searches for parents through the unit of work when a primary
+ key has changed and updates them.
+
+ Theoretically, this approach could be expanded to support transparent
+ deletion of objects referenced via many-to-one as well, although
+ the current attribute system doesn't do enough bookkeeping for this
+ to be efficient.
+
+ """
+
+ def per_property_preprocessors(self, uow):
+ if self.prop._reverse_property:
+ if self.passive_updates:
+ return
+ else:
+ if False in (
+ prop.passive_updates
+ for prop in self.prop._reverse_property
+ ):
+ return
+
+ uow.register_preprocessor(self, False)
+
+ def per_property_flush_actions(self, uow):
+ parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper)
+ after_save = unitofwork.ProcessAll(uow, self, False, False)
+ uow.dependencies.update([(parent_saves, after_save)])
+
+ def per_state_flush_actions(self, uow, states, isdelete):
+ pass
+
+ def presort_deletes(self, uowcommit, states):
+ pass
+
+ def presort_saves(self, uow, states):
+ if not self.passive_updates:
+ # for non-passive updates, register in the preprocess stage
+ # so that mapper save_obj() gets a hold of changes
+ self._process_key_switches(states, uow)
+
+ def prop_has_changes(self, uow, states, isdelete):
+ if not isdelete and self.passive_updates:
+ d = self._key_switchers(uow, states)
+ return bool(d)
+
+ return False
+
+ def process_deletes(self, uowcommit, states):
+ assert False
+
+ def process_saves(self, uowcommit, states):
+        # for passive updates, register objects in the process stage
+        # so that we avoid ManyToOneDP registering the object without
+        # the listonly flag in its own preprocess stage (which results
+        # in UPDATE statements being emitted)
+ assert self.passive_updates
+ self._process_key_switches(states, uowcommit)
+
+ def _key_switchers(self, uow, states):
+ switched, notswitched = uow.memo(
+ ("pk_switchers", self), lambda: (set(), set())
+ )
+
+ allstates = switched.union(notswitched)
+ for s in states:
+ if s not in allstates:
+ if self._pks_changed(uow, s):
+ switched.add(s)
+ else:
+ notswitched.add(s)
+ return switched
+
+ def _process_key_switches(self, deplist, uowcommit):
+ switchers = self._key_switchers(uowcommit, deplist)
+ if switchers:
+ # if primary key values have actually changed somewhere, perform
+ # a linear search through the UOW in search of a parent.
+ for state in uowcommit.session.identity_map.all_states():
+ if not issubclass(state.class_, self.parent.class_):
+ continue
+ dict_ = state.dict
+ related = state.get_impl(self.key).get(
+ state, dict_, passive=self._passive_update_flag
+ )
+ if (
+ related is not attributes.PASSIVE_NO_RESULT
+ and related is not None
+ ):
+ if self.prop.uselist:
+ if not related:
+ continue
+ related_obj = related[0]
+ else:
+ related_obj = related
+ related_state = attributes.instance_state(related_obj)
+ if related_state in switchers:
+ uowcommit.register_object(
+ state, False, self.passive_updates
+ )
+ sync.populate(
+ related_state,
+ self.mapper,
+ state,
+ self.parent,
+ self.prop.synchronize_pairs,
+ uowcommit,
+ self.passive_updates,
+ )
+
+ def _pks_changed(self, uowcommit, state):
+ return bool(state.key) and sync.source_modified(
+ uowcommit, state, self.mapper, self.prop.synchronize_pairs
+ )
+
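A behavioral sketch of the key-switch handling above: a many-to-one with no backref and passive_updates=False asks the ORM, rather than the database, to cascade a primary key change to in-memory dependents. Hypothetical models, assuming 1.4-style imports and SQLite; the exact SQL emitted will vary:

from sqlalchemy import Column, ForeignKey, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

class User(Base):
    __tablename__ = "user"
    id = Column(Integer, primary_key=True)   # mutable PK for the demo

class Address(Base):
    __tablename__ = "address"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("user.id"))
    # many-to-one, no backref: DetectKeySwitch is the processor that finds
    # Address objects pointing at a changed User primary key
    user = relationship("User", passive_updates=False)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    user = User(id=1)
    address = Address(id=1, user=user)
    session.add_all([user, address])
    session.flush()

    user.id = 2        # primary key switch
    session.flush()    # _process_key_switches() updates the dependent row
    print(address.user_id)   # 2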
+
+class ManyToManyDP(DependencyProcessor):
+ def per_property_dependencies(
+ self,
+ uow,
+ parent_saves,
+ child_saves,
+ parent_deletes,
+ child_deletes,
+ after_save,
+ before_delete,
+ ):
+
+ uow.dependencies.update(
+ [
+ (parent_saves, after_save),
+ (child_saves, after_save),
+ (after_save, child_deletes),
+ # a rowswitch on the parent from deleted to saved
+ # can make this one occur, as the "save" may remove
+ # an element from the
+ # "deleted" list before we have a chance to
+ # process its child rows
+ (before_delete, parent_saves),
+ (before_delete, parent_deletes),
+ (before_delete, child_deletes),
+ (before_delete, child_saves),
+ ]
+ )
+
+ def per_state_dependencies(
+ self,
+ uow,
+ save_parent,
+ delete_parent,
+ child_action,
+ after_save,
+ before_delete,
+ isdelete,
+ childisdelete,
+ ):
+ if not isdelete:
+ if childisdelete:
+ uow.dependencies.update(
+ [(save_parent, after_save), (after_save, child_action)]
+ )
+ else:
+ uow.dependencies.update(
+ [(save_parent, after_save), (child_action, after_save)]
+ )
+ else:
+ uow.dependencies.update(
+ [(before_delete, child_action), (before_delete, delete_parent)]
+ )
+
+ def presort_deletes(self, uowcommit, states):
+ # TODO: no tests fail if this whole
+ # thing is removed !!!!
+ if not self.passive_deletes:
+ # if no passive deletes, load history on
+ # the collection, so that prop_has_changes()
+ # returns True
+ for state in states:
+ uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+
+ def presort_saves(self, uowcommit, states):
+ if not self.passive_updates:
+ # if no passive updates, load history on
+ # each collection where parent has changed PK,
+ # so that prop_has_changes() returns True
+ for state in states:
+ if self._pks_changed(uowcommit, state):
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_OFF
+ )
+
+ if not self.cascade.delete_orphan:
+ return
+
+ # check for child items removed from the collection
+ # if delete_orphan check is turned on.
+ for state in states:
+ history = uowcommit.get_attribute_history(
+ state, self.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history:
+ for child in history.deleted:
+ if self.hasparent(child) is False:
+ uowcommit.register_object(
+ child,
+ isdelete=True,
+ operation="delete",
+ prop=self.prop,
+ )
+ for c, m, st_, dct_ in self.mapper.cascade_iterator(
+ "delete", child
+ ):
+ uowcommit.register_object(st_, isdelete=True)
+
+ def process_deletes(self, uowcommit, states):
+ secondary_delete = []
+ secondary_insert = []
+ secondary_update = []
+
+ processed = self._get_reversed_processed_set(uowcommit)
+ tmp = set()
+ for state in states:
+ # this history should be cached already, as
+            # we loaded it in presort_deletes
+ history = uowcommit.get_attribute_history(
+ state, self.key, self._passive_delete_flag
+ )
+ if history:
+ for child in history.non_added():
+ if child is None or (
+ processed is not None and (state, child) in processed
+ ):
+ continue
+ associationrow = {}
+ if not self._synchronize(
+ state,
+ child,
+ associationrow,
+ False,
+ uowcommit,
+ "delete",
+ ):
+ continue
+ secondary_delete.append(associationrow)
+
+ tmp.update((c, state) for c in history.non_added())
+
+ if processed is not None:
+ processed.update(tmp)
+
+ self._run_crud(
+ uowcommit, secondary_insert, secondary_update, secondary_delete
+ )
+
+ def process_saves(self, uowcommit, states):
+ secondary_delete = []
+ secondary_insert = []
+ secondary_update = []
+
+ processed = self._get_reversed_processed_set(uowcommit)
+ tmp = set()
+
+ for state in states:
+ need_cascade_pks = not self.passive_updates and self._pks_changed(
+ uowcommit, state
+ )
+ if need_cascade_pks:
+ passive = attributes.PASSIVE_OFF
+ else:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ history = uowcommit.get_attribute_history(state, self.key, passive)
+ if history:
+ for child in history.added:
+ if processed is not None and (state, child) in processed:
+ continue
+ associationrow = {}
+ if not self._synchronize(
+ state, child, associationrow, False, uowcommit, "add"
+ ):
+ continue
+ secondary_insert.append(associationrow)
+ for child in history.deleted:
+ if processed is not None and (state, child) in processed:
+ continue
+ associationrow = {}
+ if not self._synchronize(
+ state,
+ child,
+ associationrow,
+ False,
+ uowcommit,
+ "delete",
+ ):
+ continue
+ secondary_delete.append(associationrow)
+
+ tmp.update((c, state) for c in history.added + history.deleted)
+
+ if need_cascade_pks:
+
+ for child in history.unchanged:
+ associationrow = {}
+ sync.update(
+ state,
+ self.parent,
+ associationrow,
+ "old_",
+ self.prop.synchronize_pairs,
+ )
+ sync.update(
+ child,
+ self.mapper,
+ associationrow,
+ "old_",
+ self.prop.secondary_synchronize_pairs,
+ )
+
+ secondary_update.append(associationrow)
+
+ if processed is not None:
+ processed.update(tmp)
+
+ self._run_crud(
+ uowcommit, secondary_insert, secondary_update, secondary_delete
+ )
+
+ def _run_crud(
+ self, uowcommit, secondary_insert, secondary_update, secondary_delete
+ ):
+ connection = uowcommit.transaction.connection(self.mapper)
+
+ if secondary_delete:
+ associationrow = secondary_delete[0]
+ statement = self.secondary.delete().where(
+ sql.and_(
+ *[
+ c == sql.bindparam(c.key, type_=c.type)
+ for c in self.secondary.c
+ if c.key in associationrow
+ ]
+ )
+ )
+ result = connection.execute(statement, secondary_delete)
+
+ if (
+ result.supports_sane_multi_rowcount()
+ ) and result.rowcount != len(secondary_delete):
+ raise exc.StaleDataError(
+ "DELETE statement on table '%s' expected to delete "
+ "%d row(s); Only %d were matched."
+ % (
+ self.secondary.description,
+ len(secondary_delete),
+ result.rowcount,
+ )
+ )
+
+ if secondary_update:
+ associationrow = secondary_update[0]
+ statement = self.secondary.update().where(
+ sql.and_(
+ *[
+ c == sql.bindparam("old_" + c.key, type_=c.type)
+ for c in self.secondary.c
+ if c.key in associationrow
+ ]
+ )
+ )
+ result = connection.execute(statement, secondary_update)
+
+ if (
+ result.supports_sane_multi_rowcount()
+ ) and result.rowcount != len(secondary_update):
+ raise exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to update "
+ "%d row(s); Only %d were matched."
+ % (
+ self.secondary.description,
+ len(secondary_update),
+ result.rowcount,
+ )
+ )
+
+ if secondary_insert:
+ statement = self.secondary.insert()
+ connection.execute(statement, secondary_insert)
+
+ def _synchronize(
+ self, state, child, associationrow, clearkeys, uowcommit, operation
+ ):
+
+ # this checks for None if uselist=True
+ self._verify_canload(child)
+
+ # but if uselist=False we get here. If child is None,
+ # no association row can be generated, so return.
+ if child is None:
+ return False
+
+ if child is not None and not uowcommit.session._contains_state(child):
+ if not child.deleted:
+ util.warn(
+ "Object of type %s not in session, %s "
+ "operation along '%s' won't proceed"
+ % (mapperutil.state_class_str(child), operation, self.prop)
+ )
+ return False
+
+ sync.populate_dict(
+ state, self.parent, associationrow, self.prop.synchronize_pairs
+ )
+ sync.populate_dict(
+ child,
+ self.mapper,
+ associationrow,
+ self.prop.secondary_synchronize_pairs,
+ )
+
+ return True
+
+ def _pks_changed(self, uowcommit, state):
+ return sync.source_modified(
+ uowcommit, state, self.parent, self.prop.synchronize_pairs
+ )
+
+
+_direction_to_processor = {
+ ONETOMANY: OneToManyDP,
+ MANYTOONE: ManyToOneDP,
+ MANYTOMANY: ManyToManyDP,
+}
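Tying the three processors together: the dictionary above is consulted by DependencyProcessor.from_relationship() based on relationship direction. For the many-to-many case, the association rows inserted and deleted by _run_crud() come from an ordinary secondary table. A minimal sketch (hypothetical models, assuming 1.4-style imports and SQLite):

from sqlalchemy import Column, ForeignKey, Integer, Table, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

user_group = Table(
    "user_group",
    Base.metadata,
    Column("user_id", ForeignKey("users.id"), primary_key=True),
    Column("group_id", ForeignKey("groups.id"), primary_key=True),
)

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    groups = relationship("Group", secondary=user_group, backref="users")

class Group(Base):
    __tablename__ = "groups"
    id = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(id=1, groups=[Group(id=1), Group(id=2)]))
    session.flush()        # ManyToManyDP inserts two user_group rows
    session.get(User, 1).groups.pop()
    session.flush()        # ...and deletes one of them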
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
new file mode 100644
index 0000000..3d7f23b
--- /dev/null
+++ b/lib/sqlalchemy/orm/descriptor_props.py
@@ -0,0 +1,745 @@
+# orm/descriptor_props.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Descriptor properties are more "auxiliary" properties
+that exist as configurational elements, but don't participate
+as actively in the load/persist ORM loop.
+
+"""
+
+from . import attributes
+from . import util as orm_util
+from .interfaces import MapperProperty
+from .interfaces import PropComparator
+from .util import _none_set
+from .. import event
+from .. import exc as sa_exc
+from .. import schema
+from .. import sql
+from .. import util
+from ..sql import expression
+from ..sql import operators
+
+
+class DescriptorProperty(MapperProperty):
+ """:class:`.MapperProperty` which proxies access to a
+ user-defined descriptor."""
+
+ doc = None
+
+ uses_objects = False
+ _links_to_entity = False
+
+ def instrument_class(self, mapper):
+ prop = self
+
+ class _ProxyImpl(object):
+ accepts_scalar_loader = False
+ load_on_unexpire = True
+ collection = False
+
+ @property
+ def uses_objects(self):
+ return prop.uses_objects
+
+ def __init__(self, key):
+ self.key = key
+
+ if hasattr(prop, "get_history"):
+
+ def get_history(
+ self, state, dict_, passive=attributes.PASSIVE_OFF
+ ):
+ return prop.get_history(state, dict_, passive)
+
+ if self.descriptor is None:
+ desc = getattr(mapper.class_, self.key, None)
+ if mapper._is_userland_descriptor(self.key, desc):
+ self.descriptor = desc
+
+ if self.descriptor is None:
+
+ def fset(obj, value):
+ setattr(obj, self.name, value)
+
+ def fdel(obj):
+ delattr(obj, self.name)
+
+ def fget(obj):
+ return getattr(obj, self.name)
+
+ self.descriptor = property(fget=fget, fset=fset, fdel=fdel)
+
+ proxy_attr = attributes.create_proxied_attribute(self.descriptor)(
+ self.parent.class_,
+ self.key,
+ self.descriptor,
+ lambda: self._comparator_factory(mapper),
+ doc=self.doc,
+ original_property=self,
+ )
+ proxy_attr.impl = _ProxyImpl(self.key)
+ mapper.class_manager.instrument_attribute(self.key, proxy_attr)
+
+
+class CompositeProperty(DescriptorProperty):
+ """Defines a "composite" mapped attribute, representing a collection
+ of columns as one attribute.
+
+ :class:`.CompositeProperty` is constructed using the :func:`.composite`
+ function.
+
+ .. seealso::
+
+ :ref:`mapper_composite`
+
+ """
+
+ def __init__(self, class_, *attrs, **kwargs):
+ r"""Return a composite column-based property for use with a Mapper.
+
+ See the mapping documentation section :ref:`mapper_composite` for a
+ full usage example.
+
+ The :class:`.MapperProperty` returned by :func:`.composite`
+ is the :class:`.CompositeProperty`.
+
+ :param class\_:
+ The "composite type" class, or any classmethod or callable which
+ will produce a new instance of the composite object given the
+ column values in order.
+
+ :param \*cols:
+ List of Column objects to be mapped.
+
+ :param active_history=False:
+ When ``True``, indicates that the "previous" value for a
+ scalar attribute should be loaded when replaced, if not
+ already loaded. See the same flag on :func:`.column_property`.
+
+ :param group:
+ A group name for this property when marked as deferred.
+
+ :param deferred:
+ When True, the column property is "deferred", meaning that it does
+ not load immediately, and is instead loaded when the attribute is
+ first accessed on an instance. See also
+ :func:`~sqlalchemy.orm.deferred`.
+
+ :param comparator_factory: a class which extends
+ :class:`.CompositeProperty.Comparator` which provides custom SQL
+ clause generation for comparison operations.
+
+ :param doc:
+ optional string that will be applied as the doc on the
+ class-bound descriptor.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.MapperProperty.info` attribute of this object.
+
+ """
+ super(CompositeProperty, self).__init__()
+
+ self.attrs = attrs
+ self.composite_class = class_
+ self.active_history = kwargs.get("active_history", False)
+ self.deferred = kwargs.get("deferred", False)
+ self.group = kwargs.get("group", None)
+ self.comparator_factory = kwargs.pop(
+ "comparator_factory", self.__class__.Comparator
+ )
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+
+ util.set_creation_order(self)
+ self._create_descriptor()
+
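The canonical usage of the composite() construct documented above is a value object spanning two columns. A minimal sketch (hypothetical Point/Vertex classes, assuming 1.4-style imports; __composite_values__ supplies the column values in order):

import dataclasses

from sqlalchemy import Column, Integer
from sqlalchemy.orm import composite, declarative_base

Base = declarative_base()

@dataclasses.dataclass
class Point:
    x: int
    y: int

    def __composite_values__(self):
        # column values, in the order given to composite() below
        return self.x, self.y

class Vertex(Base):
    __tablename__ = "vertex"
    id = Column(Integer, primary_key=True)
    x1 = Column(Integer)
    y1 = Column(Integer)
    # presents the two columns as a single Point-valued attribute
    start = composite(Point, x1, y1)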
+ def instrument_class(self, mapper):
+ super(CompositeProperty, self).instrument_class(mapper)
+ self._setup_event_handlers()
+
+ def do_init(self):
+ """Initialization which occurs after the :class:`.CompositeProperty`
+ has been associated with its parent mapper.
+
+ """
+ self._setup_arguments_on_columns()
+
+ _COMPOSITE_FGET = object()
+
+ def _create_descriptor(self):
+ """Create the Python descriptor that will serve as
+ the access point on instances of the mapped class.
+
+ """
+
+ def fget(instance):
+ dict_ = attributes.instance_dict(instance)
+ state = attributes.instance_state(instance)
+
+ if self.key not in dict_:
+ # key not present. Iterate through related
+ # attributes, retrieve their values. This
+ # ensures they all load.
+ values = [
+ getattr(instance, key) for key in self._attribute_keys
+ ]
+
+ # current expected behavior here is that the composite is
+ # created on access if the object is persistent or if
+            # col attributes have non-None values. This would be better
+ # if the composite were created unconditionally,
+ # but that would be a behavioral change.
+ if self.key not in dict_ and (
+ state.key is not None or not _none_set.issuperset(values)
+ ):
+ dict_[self.key] = self.composite_class(*values)
+ state.manager.dispatch.refresh(
+ state, self._COMPOSITE_FGET, [self.key]
+ )
+
+ return dict_.get(self.key, None)
+
+ def fset(instance, value):
+ dict_ = attributes.instance_dict(instance)
+ state = attributes.instance_state(instance)
+ attr = state.manager[self.key]
+ previous = dict_.get(self.key, attributes.NO_VALUE)
+ for fn in attr.dispatch.set:
+ value = fn(state, value, previous, attr.impl)
+ dict_[self.key] = value
+ if value is None:
+ for key in self._attribute_keys:
+ setattr(instance, key, None)
+ else:
+ for key, value in zip(
+ self._attribute_keys, value.__composite_values__()
+ ):
+ setattr(instance, key, value)
+
+ def fdel(instance):
+ state = attributes.instance_state(instance)
+ dict_ = attributes.instance_dict(instance)
+ previous = dict_.pop(self.key, attributes.NO_VALUE)
+ attr = state.manager[self.key]
+ attr.dispatch.remove(state, previous, attr.impl)
+ for key in self._attribute_keys:
+ setattr(instance, key, None)
+
+ self.descriptor = property(fget, fset, fdel)
+
+ @util.memoized_property
+ def _comparable_elements(self):
+ return [getattr(self.parent.class_, prop.key) for prop in self.props]
+
+ @util.memoized_property
+ def props(self):
+ props = []
+ for attr in self.attrs:
+ if isinstance(attr, str):
+ prop = self.parent.get_property(attr, _configure_mappers=False)
+ elif isinstance(attr, schema.Column):
+ prop = self.parent._columntoproperty[attr]
+ elif isinstance(attr, attributes.InstrumentedAttribute):
+ prop = attr.property
+ else:
+ raise sa_exc.ArgumentError(
+ "Composite expects Column objects or mapped "
+ "attributes/attribute names as arguments, got: %r"
+ % (attr,)
+ )
+ props.append(prop)
+ return props
+
+ @property
+ def columns(self):
+ return [a for a in self.attrs if isinstance(a, schema.Column)]
+
+ def _setup_arguments_on_columns(self):
+ """Propagate configuration arguments made on this composite
+ to the target columns, for those that apply.
+
+ """
+ for prop in self.props:
+ prop.active_history = self.active_history
+ if self.deferred:
+ prop.deferred = self.deferred
+ prop.strategy_key = (("deferred", True), ("instrument", True))
+ prop.group = self.group
+
+ def _setup_event_handlers(self):
+ """Establish events that populate/expire the composite attribute."""
+
+ def load_handler(state, context):
+ _load_refresh_handler(state, context, None, is_refresh=False)
+
+ def refresh_handler(state, context, to_load):
+ # note this corresponds to sqlalchemy.ext.mutable load_attrs()
+
+ if not to_load or (
+ {self.key}.union(self._attribute_keys)
+ ).intersection(to_load):
+ _load_refresh_handler(state, context, to_load, is_refresh=True)
+
+ def _load_refresh_handler(state, context, to_load, is_refresh):
+ dict_ = state.dict
+
+ # if context indicates we are coming from the
+ # fget() handler, it already set the value; skip the
+ # handler here. (other handlers like mutablecomposite will still
+ # want to catch it.)
+ # There's an insufficiency here in that the fget() handler
+ # really should not be using the refresh event, and there should
+ # be some other event that mutablecomposite can subscribe
+ # to for this.
+
+ if (
+ not is_refresh or context is self._COMPOSITE_FGET
+ ) and self.key in dict_:
+ return
+
+ # if column elements aren't loaded, skip.
+ # __get__() will initiate a load for those
+ # columns
+ for k in self._attribute_keys:
+ if k not in dict_:
+ return
+
+ dict_[self.key] = self.composite_class(
+ *[state.dict[key] for key in self._attribute_keys]
+ )
+
+ def expire_handler(state, keys):
+ if keys is None or set(self._attribute_keys).intersection(keys):
+ state.dict.pop(self.key, None)
+
+ def insert_update_handler(mapper, connection, state):
+ """After an insert or update, some columns may be expired due
+ to server side defaults, or re-populated due to client side
+ defaults. Pop out the composite value here so that it
+ recreates.
+
+ """
+
+ state.dict.pop(self.key, None)
+
+ event.listen(
+ self.parent, "after_insert", insert_update_handler, raw=True
+ )
+ event.listen(
+ self.parent, "after_update", insert_update_handler, raw=True
+ )
+ event.listen(
+ self.parent, "load", load_handler, raw=True, propagate=True
+ )
+ event.listen(
+ self.parent, "refresh", refresh_handler, raw=True, propagate=True
+ )
+ event.listen(
+ self.parent, "expire", expire_handler, raw=True, propagate=True
+ )
+
+ # TODO: need a deserialize hook here
+
+ @util.memoized_property
+ def _attribute_keys(self):
+ return [prop.key for prop in self.props]
+
+ def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ """Provided for userland code that uses attributes.get_history()."""
+
+ added = []
+ deleted = []
+
+ has_history = False
+ for prop in self.props:
+ key = prop.key
+ hist = state.manager[key].impl.get_history(state, dict_)
+ if hist.has_changes():
+ has_history = True
+
+ non_deleted = hist.non_deleted()
+ if non_deleted:
+ added.extend(non_deleted)
+ else:
+ added.append(None)
+ if hist.deleted:
+ deleted.extend(hist.deleted)
+ else:
+ deleted.append(None)
+
+ if has_history:
+ return attributes.History(
+ [self.composite_class(*added)],
+ (),
+ [self.composite_class(*deleted)],
+ )
+ else:
+ return attributes.History((), [self.composite_class(*added)], ())
+
+ def _comparator_factory(self, mapper):
+ return self.comparator_factory(self, mapper)
+
+ class CompositeBundle(orm_util.Bundle):
+ def __init__(self, property_, expr):
+ self.property = property_
+ super(CompositeProperty.CompositeBundle, self).__init__(
+ property_.key, *expr
+ )
+
+ def create_row_processor(self, query, procs, labels):
+ def proc(row):
+ return self.property.composite_class(
+ *[proc(row) for proc in procs]
+ )
+
+ return proc
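+
+ # Hypothetical usage sketch (editorial note): selecting the composite
+ # attribute directly goes through this row processor and yields
+ # instances of composite_class, e.g. with an assumed Point/Vertex
+ # mapping:
+ #
+ #     session.query(Vertex.start).first()   # -> Point(x=3, y=4)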
+
+ class Comparator(PropComparator):
+ """Produce boolean, comparison, and other operators for
+ :class:`.CompositeProperty` attributes.
+
+ See the example in :ref:`composite_operations` for an overview
+ of usage, as well as the documentation for :class:`.PropComparator`.
+
+ .. seealso::
+
+ :class:`.PropComparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ __hash__ = None
+
+ @util.memoized_property
+ def clauses(self):
+ return expression.ClauseList(
+ group=False, *self._comparable_elements
+ )
+
+ def __clause_element__(self):
+ return self.expression
+
+ @util.memoized_property
+ def expression(self):
+ clauses = self.clauses._annotate(
+ {
+ "parententity": self._parententity,
+ "parentmapper": self._parententity,
+ "proxy_key": self.prop.key,
+ }
+ )
+ return CompositeProperty.CompositeBundle(self.prop, clauses)
+
+ def _bulk_update_tuples(self, value):
+ if isinstance(value, sql.elements.BindParameter):
+ value = value.value
+
+ if value is None:
+ values = [None for key in self.prop._attribute_keys]
+ elif isinstance(value, self.prop.composite_class):
+ values = value.__composite_values__()
+ else:
+ raise sa_exc.ArgumentError(
+ "Can't UPDATE composite attribute %s to %r"
+ % (self.prop, value)
+ )
+
+ return zip(self._comparable_elements, values)
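+
+ # Hypothetical usage sketch (editorial note, assuming a Point/Vertex
+ # composite mapping): a bulk UPDATE against the composite attribute is
+ # expanded by this method into per-column assignments:
+ #
+ #     session.query(Vertex).update({Vertex.start: Point(5, 6)})
+ #     # roughly: UPDATE vertices SET x1=5, y1=6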
+
+ @util.memoized_property
+ def _comparable_elements(self):
+ if self._adapt_to_entity:
+ return [
+ getattr(self._adapt_to_entity.entity, prop.key)
+ for prop in self.prop._comparable_elements
+ ]
+ else:
+ return self.prop._comparable_elements
+
+ def __eq__(self, other):
+ if other is None:
+ values = [None] * len(self.prop._comparable_elements)
+ else:
+ values = other.__composite_values__()
+ comparisons = [
+ a == b for a, b in zip(self.prop._comparable_elements, values)
+ ]
+ if self._adapt_to_entity:
+ comparisons = [self.adapter(x) for x in comparisons]
+ return sql.and_(*comparisons)
+
+ def __ne__(self, other):
+ return sql.not_(self.__eq__(other))
+
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
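+
+ # Hypothetical usage sketch (editorial note): this comparator expands
+ # composite comparisons into column-wise SQL. Assuming a Point value
+ # object and a Vertex mapping with composite(Point, x1, y1):
+ #
+ #     Vertex.start == Point(3, 4)
+ #     # -> vertices.x1 = :x1_1 AND vertices.y1 = :y1_1
+ #
+ #     Vertex.start != Point(3, 4)   # __ne__ wraps the expression in NOT()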
+
+
+class ConcreteInheritedProperty(DescriptorProperty):
+ """A 'do nothing' :class:`.MapperProperty` that disables
+ an attribute on a concrete subclass that is only present
+ on the inherited mapper, not the concrete classes' mapper.
+
+ Cases where this occurs include:
+
+ * When the superclass mapper is mapped against a
+ "polymorphic union", which includes all attributes from
+ all subclasses.
+ * When a relationship() is configured on an inherited mapper,
+ but not on the subclass mapper. Concrete mappers require
+ that relationship() is configured explicitly on each
+ subclass.
+
+ """
+
+ def _comparator_factory(self, mapper):
+ comparator_callable = None
+
+ for m in self.parent.iterate_to_root():
+ p = m._props[self.key]
+ if not isinstance(p, ConcreteInheritedProperty):
+ comparator_callable = p.comparator_factory
+ break
+ return comparator_callable
+
+ def __init__(self):
+ super(ConcreteInheritedProperty, self).__init__()
+
+ def warn():
+ raise AttributeError(
+ "Concrete %s does not implement "
+ "attribute %r at the instance level. Add "
+ "this property explicitly to %s."
+ % (self.parent, self.key, self.parent)
+ )
+
+ class NoninheritedConcreteProp(object):
+ def __set__(s, obj, value):
+ warn()
+
+ def __delete__(s, obj):
+ warn()
+
+ def __get__(s, obj, owner):
+ if obj is None:
+ return self.descriptor
+ warn()
+
+ self.descriptor = NoninheritedConcreteProp()
+
+
+class SynonymProperty(DescriptorProperty):
+ def __init__(
+ self,
+ name,
+ map_column=None,
+ descriptor=None,
+ comparator_factory=None,
+ doc=None,
+ info=None,
+ ):
+ """Denote an attribute name as a synonym to a mapped property,
+ in that the attribute will mirror the value and expression behavior
+ of another attribute.
+
+ e.g.::
+
+ class MyClass(Base):
+ __tablename__ = 'my_table'
+
+ id = Column(Integer, primary_key=True)
+ job_status = Column(String(50))
+
+ status = synonym("job_status")
+
+
+ :param name: the name of the existing mapped property. This
+ can refer to the string name of any ORM-mapped attribute
+ configured on the class, including column-bound attributes
+ and relationships.
+
+ :param descriptor: a Python :term:`descriptor` that will be used
+ as a getter (and potentially a setter) when this attribute is
+ accessed at the instance level.
+
+ :param map_column: **For classical mappings and mappings against
+ an existing Table object only**. if ``True``, the :func:`.synonym`
+ construct will locate the :class:`_schema.Column`
+ object upon the mapped
+ table that would normally be associated with the attribute name of
+ this synonym, and produce a new :class:`.ColumnProperty` that instead
+ maps this :class:`_schema.Column`
+ to the alternate name given as the "name"
+ argument of the synonym; in this way, the usual step of redefining
+ the mapping of the :class:`_schema.Column`
+ to be under a different name is
+ unnecessary. This is usually intended to be used when a
+ :class:`_schema.Column`
+ is to be replaced with an attribute that also uses a
+ descriptor, that is, in conjunction with the
+ :paramref:`.synonym.descriptor` parameter::
+
+ my_table = Table(
+ "my_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('job_status', String(50))
+ )
+
+ class MyClass(object):
+ @property
+ def _job_status_descriptor(self):
+ return "Status: %s" % self._job_status
+
+
+ mapper(
+ MyClass, my_table, properties={
+ "job_status": synonym(
+ "_job_status", map_column=True,
+ descriptor=MyClass._job_status_descriptor)
+ }
+ )
+
+ Above, the attribute named ``_job_status`` is automatically
+ mapped to the ``job_status`` column::
+
+ >>> j1 = MyClass()
+ >>> j1._job_status = "employed"
+ >>> j1.job_status
+ 'Status: employed'
+
+ When using Declarative, in order to provide a descriptor in
+ conjunction with a synonym, use the
+ :func:`sqlalchemy.ext.declarative.synonym_for` helper. However,
+ note that the :ref:`hybrid properties <mapper_hybrids>` feature
+ should usually be preferred, particularly when redefining attribute
+ behavior.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.InspectionAttr.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param comparator_factory: A subclass of :class:`.PropComparator`
+ that will provide custom comparison behavior at the SQL expression
+ level.
+
+ .. note::
+
+ For the use case of providing an attribute which redefines both
+ Python-level and SQL-expression level behavior of an attribute,
+ please refer to the Hybrid attribute introduced at
+ :ref:`mapper_hybrids` for a more effective technique.
+
+ .. seealso::
+
+ :ref:`synonyms` - Overview of synonyms
+
+ :func:`.synonym_for` - a helper oriented towards Declarative
+
+ :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an
+ updated approach to augmenting attribute behavior more flexibly
+ than can be achieved with synonyms.
+
+ """
+ super(SynonymProperty, self).__init__()
+
+ self.name = name
+ self.map_column = map_column
+ self.descriptor = descriptor
+ self.comparator_factory = comparator_factory
+ self.doc = doc or (descriptor and descriptor.__doc__) or None
+ if info:
+ self.info = info
+
+ util.set_creation_order(self)
+
+ @property
+ def uses_objects(self):
+ return getattr(self.parent.class_, self.name).impl.uses_objects
+
+ # TODO: when initialized, check _proxied_object,
+ # emit a warning if it's not a column-based property
+
+ @util.memoized_property
+ def _proxied_object(self):
+ attr = getattr(self.parent.class_, self.name)
+ if not hasattr(attr, "property") or not isinstance(
+ attr.property, MapperProperty
+ ):
+ # attribute is a non-MapperProperty proxy such as
+ # hybrid or association proxy
+ if isinstance(attr, attributes.QueryableAttribute):
+ return attr.comparator
+ elif isinstance(attr, operators.ColumnOperators):
+ return attr
+
+ raise sa_exc.InvalidRequestError(
+ """synonym() attribute "%s.%s" only supports """
+ """ORM mapped attributes, got %r"""
+ % (self.parent.class_.__name__, self.name, attr)
+ )
+ return attr.property
+
+ def _comparator_factory(self, mapper):
+ prop = self._proxied_object
+
+ if isinstance(prop, MapperProperty):
+ if self.comparator_factory:
+ comp = self.comparator_factory(prop, mapper)
+ else:
+ comp = prop.comparator_factory(prop, mapper)
+ return comp
+ else:
+ return prop
+
+ def get_history(self, *arg, **kw):
+ attr = getattr(self.parent.class_, self.name)
+ return attr.impl.get_history(*arg, **kw)
+
+ @util.preload_module("sqlalchemy.orm.properties")
+ def set_parent(self, parent, init):
+ properties = util.preloaded.orm_properties
+
+ if self.map_column:
+ # implement the 'map_column' option.
+ if self.key not in parent.persist_selectable.c:
+ raise sa_exc.ArgumentError(
+ "Can't compile synonym '%s': no column on table "
+ "'%s' named '%s'"
+ % (
+ self.name,
+ parent.persist_selectable.description,
+ self.key,
+ )
+ )
+ elif (
+ parent.persist_selectable.c[self.key]
+ in parent._columntoproperty
+ and parent._columntoproperty[
+ parent.persist_selectable.c[self.key]
+ ].key
+ == self.name
+ ):
+ raise sa_exc.ArgumentError(
+ "Can't call map_column=True for synonym %r=%r, "
+ "a ColumnProperty already exists keyed to the name "
+ "%r for column %r"
+ % (self.key, self.name, self.name, self.key)
+ )
+ p = properties.ColumnProperty(
+ parent.persist_selectable.c[self.key]
+ )
+ parent._configure_property(self.name, p, init=init, setparent=True)
+ p._mapped_by_synonym = self.key
+
+ self.parent = parent
diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py
new file mode 100644
index 0000000..ec62560
--- /dev/null
+++ b/lib/sqlalchemy/orm/dynamic.py
@@ -0,0 +1,491 @@
+# orm/dynamic.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Dynamic collection API.
+
+Dynamic collections act like Query() objects for read operations and support
+basic add/delete mutation.
+
+"""
+
+from . import attributes
+from . import exc as orm_exc
+from . import interfaces
+from . import object_mapper
+from . import object_session
+from . import relationships
+from . import strategies
+from . import util as orm_util
+from .query import Query
+from .. import exc
+from .. import log
+from .. import util
+from ..engine import result
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="dynamic")
+class DynaLoader(strategies.AbstractRelationshipLoader):
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+ if not self.uselist:
+ raise exc.InvalidRequestError(
+ "On relationship %s, 'dynamic' loaders cannot be used with "
+ "many-to-one/one-to-one relationships and/or "
+ "uselist=False." % self.parent_property
+ )
+ elif self.parent_property.direction not in (
+ interfaces.ONETOMANY,
+ interfaces.MANYTOMANY,
+ ):
+ util.warn(
+ "On relationship %s, 'dynamic' loaders cannot be used with "
+ "many-to-one/one-to-one relationships and/or "
+ "uselist=False. This warning will be an exception in a "
+ "future release." % self.parent_property
+ )
+
+ strategies._register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=True,
+ impl_class=DynamicAttributeImpl,
+ target_mapper=self.parent_property.mapper,
+ order_by=self.parent_property.order_by,
+ query_class=self.parent_property.query_class,
+ )
+
+
+class DynamicAttributeImpl(attributes.AttributeImpl):
+ uses_objects = True
+ default_accepts_scalar_loader = False
+ supports_population = False
+ collection = False
+ dynamic = True
+ order_by = ()
+
+ def __init__(
+ self,
+ class_,
+ key,
+ typecallable,
+ dispatch,
+ target_mapper,
+ order_by,
+ query_class=None,
+ **kw
+ ):
+ super(DynamicAttributeImpl, self).__init__(
+ class_, key, typecallable, dispatch, **kw
+ )
+ self.target_mapper = target_mapper
+ if order_by:
+ self.order_by = tuple(order_by)
+ if not query_class:
+ self.query_class = AppenderQuery
+ elif AppenderMixin in query_class.mro():
+ self.query_class = query_class
+ else:
+ self.query_class = mixin_user_query(query_class)
+
+ def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ if not passive & attributes.SQL_OK:
+ return self._get_collection_history(
+ state, attributes.PASSIVE_NO_INITIALIZE
+ ).added_items
+ else:
+ return self.query_class(self, state)
+
+ def get_collection(
+ self,
+ state,
+ dict_,
+ user_data=None,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ ):
+ if not passive & attributes.SQL_OK:
+ data = self._get_collection_history(state, passive).added_items
+ else:
+ history = self._get_collection_history(state, passive)
+ data = history.added_plus_unchanged
+ return DynamicCollectionAdapter(data)
+
+ @util.memoized_property
+ def _append_token(self):
+ return attributes.Event(self, attributes.OP_APPEND)
+
+ @util.memoized_property
+ def _remove_token(self):
+ return attributes.Event(self, attributes.OP_REMOVE)
+
+ def fire_append_event(
+ self, state, dict_, value, initiator, collection_history=None
+ ):
+ if collection_history is None:
+ collection_history = self._modified_event(state, dict_)
+
+ collection_history.add_added(value)
+
+ for fn in self.dispatch.append:
+ value = fn(state, value, initiator or self._append_token)
+
+ if self.trackparent and value is not None:
+ self.sethasparent(attributes.instance_state(value), state, True)
+
+ def fire_remove_event(
+ self, state, dict_, value, initiator, collection_history=None
+ ):
+ if collection_history is None:
+ collection_history = self._modified_event(state, dict_)
+
+ collection_history.add_removed(value)
+
+ if self.trackparent and value is not None:
+ self.sethasparent(attributes.instance_state(value), state, False)
+
+ for fn in self.dispatch.remove:
+ fn(state, value, initiator or self._remove_token)
+
+ def _modified_event(self, state, dict_):
+
+ if self.key not in state.committed_state:
+ state.committed_state[self.key] = CollectionHistory(self, state)
+
+ state._modified_event(dict_, self, attributes.NEVER_SET)
+
+ # this is a hack to allow the fixtures.ComparableEntity fixture
+ # to work
+ dict_[self.key] = True
+ return state.committed_state[self.key]
+
+ def set(
+ self,
+ state,
+ dict_,
+ value,
+ initiator=None,
+ passive=attributes.PASSIVE_OFF,
+ check_old=None,
+ pop=False,
+ _adapt=True,
+ ):
+ if initiator and initiator.parent_token is self.parent_token:
+ return
+
+ if pop and value is None:
+ return
+
+ iterable = value
+ new_values = list(iterable)
+ if state.has_identity:
+ old_collection = util.IdentitySet(self.get(state, dict_))
+
+ collection_history = self._modified_event(state, dict_)
+ if not state.has_identity:
+ old_collection = collection_history.added_items
+ else:
+ old_collection = old_collection.union(
+ collection_history.added_items
+ )
+
+ idset = util.IdentitySet
+ constants = old_collection.intersection(new_values)
+ additions = idset(new_values).difference(constants)
+ removals = old_collection.difference(constants)
+
+ for member in new_values:
+ if member in additions:
+ self.fire_append_event(
+ state,
+ dict_,
+ member,
+ None,
+ collection_history=collection_history,
+ )
+
+ for member in removals:
+ self.fire_remove_event(
+ state,
+ dict_,
+ member,
+ None,
+ collection_history=collection_history,
+ )
+
+ def delete(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def set_committed_value(self, state, dict_, value):
+ raise NotImplementedError(
+ "Dynamic attributes don't support " "collection population."
+ )
+
+ def get_history(self, state, dict_, passive=attributes.PASSIVE_OFF):
+ c = self._get_collection_history(state, passive)
+ return c.as_history()
+
+ def get_all_pending(
+ self, state, dict_, passive=attributes.PASSIVE_NO_INITIALIZE
+ ):
+ c = self._get_collection_history(state, passive)
+ return [(attributes.instance_state(x), x) for x in c.all_items]
+
+ def _get_collection_history(self, state, passive=attributes.PASSIVE_OFF):
+ if self.key in state.committed_state:
+ c = state.committed_state[self.key]
+ else:
+ c = CollectionHistory(self, state)
+
+ if state.has_identity and (passive & attributes.INIT_OK):
+ return CollectionHistory(self, state, apply_to=c)
+ else:
+ return c
+
+ def append(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
+ if initiator is not self:
+ self.fire_append_event(state, dict_, value, initiator)
+
+ def remove(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
+ if initiator is not self:
+ self.fire_remove_event(state, dict_, value, initiator)
+
+ def pop(
+ self, state, dict_, value, initiator, passive=attributes.PASSIVE_OFF
+ ):
+ self.remove(state, dict_, value, initiator, passive=passive)
+
+
+class DynamicCollectionAdapter(object):
+ """simplified CollectionAdapter for internal API consistency"""
+
+ def __init__(self, data):
+ self.data = data
+
+ def __iter__(self):
+ return iter(self.data)
+
+ def _reset_empty(self):
+ pass
+
+ def __len__(self):
+ return len(self.data)
+
+ def __bool__(self):
+ return True
+
+ __nonzero__ = __bool__
+
+
+class AppenderMixin(object):
+ query_class = None
+
+ def __init__(self, attr, state):
+ super(AppenderMixin, self).__init__(attr.target_mapper, None)
+ self.instance = instance = state.obj()
+ self.attr = attr
+
+ mapper = object_mapper(instance)
+ prop = mapper._props[self.attr.key]
+
+ if prop.secondary is not None:
+ # this is a hack right now. The Query only knows how to
+ # make subsequent joins() without a given left-hand side
+ # from self._from_obj[0]. We need to ensure prop.secondary
+ # is in the FROM. So we purposely put the mapper selectable
+ # in _from_obj[0] to ensure a user-defined join() later on
+ # doesn't fail, and secondary is then in _from_obj[1].
+
+ # note also, we are using the official ORM-annotated selectable
+ # from __clause_element__(), see #7868
+ self._from_obj = (prop.mapper.__clause_element__(), prop.secondary)
+
+ self._where_criteria = (
+ prop._with_parent(instance, alias_secondary=False),
+ )
+
+ if self.attr.order_by:
+ self._order_by_clauses = self.attr.order_by
+
+ def session(self):
+ sess = object_session(self.instance)
+ if (
+ sess is not None
+ and self.autoflush
+ and sess.autoflush
+ and self.instance in sess
+ ):
+ sess.flush()
+ if not orm_util.has_identity(self.instance):
+ return None
+ else:
+ return sess
+
+ session = property(session, lambda s, x: None)
+
+ def _iter(self):
+ sess = self.session
+ if sess is None:
+ state = attributes.instance_state(self.instance)
+ if state.detached:
+ util.warn(
+ "Instance %s is detached, dynamic relationship cannot "
+ "return a correct result. This warning will become "
+ "a DetachedInstanceError in a future release."
+ % (orm_util.state_str(state))
+ )
+
+ return result.IteratorResult(
+ result.SimpleResultMetaData([self.attr.class_.__name__]),
+ self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).added_items,
+ _source_supports_scalars=True,
+ ).scalars()
+ else:
+ return self._generate(sess)._iter()
+
+ def __getitem__(self, index):
+ sess = self.session
+ if sess is None:
+ return self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).indexed(index)
+ else:
+ return self._generate(sess).__getitem__(index)
+
+ def count(self):
+ sess = self.session
+ if sess is None:
+ return len(
+ self.attr._get_collection_history(
+ attributes.instance_state(self.instance),
+ attributes.PASSIVE_NO_INITIALIZE,
+ ).added_items
+ )
+ else:
+ return self._generate(sess).count()
+
+ def _generate(self, sess=None):
+ # note we're returning an entirely new Query class instance
+ # here without any assignment capabilities; the class of this
+ # query is determined by the session.
+ instance = self.instance
+ if sess is None:
+ sess = object_session(instance)
+ if sess is None:
+ raise orm_exc.DetachedInstanceError(
+ "Parent instance %s is not bound to a Session, and no "
+ "contextual session is established; lazy load operation "
+ "of attribute '%s' cannot proceed"
+ % (orm_util.instance_str(instance), self.attr.key)
+ )
+
+ if self.query_class:
+ query = self.query_class(self.attr.target_mapper, session=sess)
+ else:
+ query = sess.query(self.attr.target_mapper)
+
+ query._where_criteria = self._where_criteria
+ query._from_obj = self._from_obj
+ query._order_by_clauses = self._order_by_clauses
+
+ return query
+
+ def extend(self, iterator):
+ for item in iterator:
+ self.attr.append(
+ attributes.instance_state(self.instance),
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
+
+ def append(self, item):
+ self.attr.append(
+ attributes.instance_state(self.instance),
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
+
+ def remove(self, item):
+ self.attr.remove(
+ attributes.instance_state(self.instance),
+ attributes.instance_dict(self.instance),
+ item,
+ None,
+ )
+
+
+class AppenderQuery(AppenderMixin, Query):
+ """A dynamic query that supports basic collection storage operations."""
+
+
+def mixin_user_query(cls):
+ """Return a new class with AppenderQuery functionality layered over."""
+ name = "Appender" + cls.__name__
+ return type(name, (AppenderMixin, cls), {"query_class": cls})
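+
+# Hypothetical sketch (editorial note) of the query_class hook served by
+# mixin_user_query(); MyQuery is an assumed user-defined Query subclass:
+#
+#     class MyQuery(Query):
+#         def undeleted(self):
+#             return self.filter_by(deleted=False)
+#
+#     addresses = relationship("Address", lazy="dynamic", query_class=MyQuery)
+#
+#     # DynamicAttributeImpl then wraps MyQuery via mixin_user_query(),
+#     # so user.addresses.undeleted() is available.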
+
+
+class CollectionHistory(object):
+ """Overrides AttributeHistory to receive append/remove events directly."""
+
+ def __init__(self, attr, state, apply_to=None):
+ if apply_to:
+ coll = AppenderQuery(attr, state).autoflush(False)
+ self.unchanged_items = util.OrderedIdentitySet(coll)
+ self.added_items = apply_to.added_items
+ self.deleted_items = apply_to.deleted_items
+ self._reconcile_collection = True
+ else:
+ self.deleted_items = util.OrderedIdentitySet()
+ self.added_items = util.OrderedIdentitySet()
+ self.unchanged_items = util.OrderedIdentitySet()
+ self._reconcile_collection = False
+
+ @property
+ def added_plus_unchanged(self):
+ return list(self.added_items.union(self.unchanged_items))
+
+ @property
+ def all_items(self):
+ return list(
+ self.added_items.union(self.unchanged_items).union(
+ self.deleted_items
+ )
+ )
+
+ def as_history(self):
+ if self._reconcile_collection:
+ added = self.added_items.difference(self.unchanged_items)
+ deleted = self.deleted_items.intersection(self.unchanged_items)
+ unchanged = self.unchanged_items.difference(deleted)
+ else:
+ added, unchanged, deleted = (
+ self.added_items,
+ self.unchanged_items,
+ self.deleted_items,
+ )
+ return attributes.History(list(added), list(unchanged), list(deleted))
+
+ def indexed(self, index):
+ return list(self.added_items)[index]
+
+ def add_added(self, value):
+ self.added_items.add(value)
+
+ def add_removed(self, value):
+ if value in self.added_items:
+ self.added_items.remove(value)
+ else:
+ self.deleted_items.add(value)
diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py
new file mode 100644
index 0000000..dbbfba0
--- /dev/null
+++ b/lib/sqlalchemy/orm/evaluator.py
@@ -0,0 +1,241 @@
+# orm/evaluator.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import operator
+
+from .. import inspect
+from .. import util
+from ..sql import and_
+from ..sql import operators
+
+
+class UnevaluatableError(Exception):
+ pass
+
+
+class _NoObject(operators.ColumnOperators):
+ def operate(self, *arg, **kw):
+ return None
+
+ def reverse_operate(self, *arg, **kw):
+ return None
+
+
+_NO_OBJECT = _NoObject()
+
+_straight_ops = set(
+ getattr(operators, op)
+ for op in (
+ "add",
+ "mul",
+ "sub",
+ "div",
+ "mod",
+ "truediv",
+ "lt",
+ "le",
+ "ne",
+ "gt",
+ "ge",
+ "eq",
+ )
+)
+
+_extended_ops = {
+ operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
+ operators.not_in_op: (
+ lambda a, b: a not in b if a is not _NO_OBJECT else None
+ ),
+}
+
+_notimplemented_ops = set(
+ getattr(operators, op)
+ for op in (
+ "like_op",
+ "not_like_op",
+ "ilike_op",
+ "not_ilike_op",
+ "startswith_op",
+ "between_op",
+ "endswith_op",
+ "concat_op",
+ )
+)
+
+
+class EvaluatorCompiler(object):
+ def __init__(self, target_cls=None):
+ self.target_cls = target_cls
+
+ def process(self, *clauses):
+ if len(clauses) > 1:
+ clause = and_(*clauses)
+ elif clauses:
+ clause = clauses[0]
+
+ meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
+ if not meth:
+ raise UnevaluatableError(
+ "Cannot evaluate %s" % type(clause).__name__
+ )
+ return meth(clause)
+
+ def visit_grouping(self, clause):
+ return self.process(clause.element)
+
+ def visit_null(self, clause):
+ return lambda obj: None
+
+ def visit_false(self, clause):
+ return lambda obj: False
+
+ def visit_true(self, clause):
+ return lambda obj: True
+
+ def visit_column(self, clause):
+ if "parentmapper" in clause._annotations:
+ parentmapper = clause._annotations["parentmapper"]
+ if self.target_cls and not issubclass(
+ self.target_cls, parentmapper.class_
+ ):
+ raise UnevaluatableError(
+ "Can't evaluate criteria against alternate class %s"
+ % parentmapper.class_
+ )
+ key = parentmapper._columntoproperty[clause].key
+ else:
+ key = clause.key
+ if (
+ self.target_cls
+ and key in inspect(self.target_cls).column_attrs
+ ):
+ util.warn(
+ "Evaluating non-mapped column expression '%s' onto "
+ "ORM instances; this is a deprecated use case. Please "
+ "make use of the actual mapped columns in ORM-evaluated "
+ "UPDATE / DELETE expressions." % clause
+ )
+ else:
+ raise UnevaluatableError("Cannot evaluate column: %s" % clause)
+
+ get_corresponding_attr = operator.attrgetter(key)
+ return (
+ lambda obj: get_corresponding_attr(obj)
+ if obj is not None
+ else _NO_OBJECT
+ )
+
+ def visit_tuple(self, clause):
+ return self.visit_clauselist(clause)
+
+ def visit_clauselist(self, clause):
+ evaluators = list(map(self.process, clause.clauses))
+ if clause.operator is operators.or_:
+
+ def evaluate(obj):
+ has_null = False
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if value:
+ return True
+ has_null = has_null or value is None
+ if has_null:
+ return None
+ return False
+
+ elif clause.operator is operators.and_:
+
+ def evaluate(obj):
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if not value:
+ if value is None or value is _NO_OBJECT:
+ return None
+ return False
+ return True
+
+ elif clause.operator is operators.comma_op:
+
+ def evaluate(obj):
+ values = []
+ for sub_evaluate in evaluators:
+ value = sub_evaluate(obj)
+ if value is None or value is _NO_OBJECT:
+ return None
+ values.append(value)
+ return tuple(values)
+
+ else:
+ raise UnevaluatableError(
+ "Cannot evaluate clauselist with operator %s" % clause.operator
+ )
+
+ return evaluate
+
+ def visit_binary(self, clause):
+ eval_left, eval_right = list(
+ map(self.process, [clause.left, clause.right])
+ )
+ operator = clause.operator
+ if operator is operators.is_:
+
+ def evaluate(obj):
+ return eval_left(obj) == eval_right(obj)
+
+ elif operator is operators.is_not:
+
+ def evaluate(obj):
+ return eval_left(obj) != eval_right(obj)
+
+ elif operator in _extended_ops:
+
+ def evaluate(obj):
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is None or right_val is None:
+ return None
+
+ return _extended_ops[operator](left_val, right_val)
+
+ elif operator in _straight_ops:
+
+ def evaluate(obj):
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is None or right_val is None:
+ return None
+ return operator(left_val, right_val)
+
+ else:
+ raise UnevaluatableError(
+ "Cannot evaluate %s with operator %s"
+ % (type(clause).__name__, clause.operator)
+ )
+ return evaluate
+
+ def visit_unary(self, clause):
+ eval_inner = self.process(clause.element)
+ if clause.operator is operators.inv:
+
+ def evaluate(obj):
+ value = eval_inner(obj)
+ if value is None:
+ return None
+ return not value
+
+ return evaluate
+ raise UnevaluatableError(
+ "Cannot evaluate %s with operator %s"
+ % (type(clause).__name__, clause.operator)
+ )
+
+ def visit_bindparam(self, clause):
+ if clause.callable:
+ val = clause.callable()
+ else:
+ val = clause.value
+ return lambda obj: val
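+
+# Illustrative sketch (editorial note): the ORM uses this compiler internally
+# for synchronize_session="evaluate"; a direct call, assuming a mapped User
+# class, would look roughly like:
+#
+#     compiler = EvaluatorCompiler(User)
+#     is_match = compiler.process(User.name == "ed")
+#     is_match(some_user)   # -> True / False, or None for SQL NULL semantics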
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
new file mode 100644
index 0000000..39659c7
--- /dev/null
+++ b/lib/sqlalchemy/orm/events.py
@@ -0,0 +1,2876 @@
+# orm/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""ORM event interfaces.
+
+"""
+import weakref
+
+from . import instrumentation
+from . import interfaces
+from . import mapperlib
+from .attributes import QueryableAttribute
+from .base import _mapper_or_none
+from .query import Query
+from .scoping import scoped_session
+from .session import Session
+from .session import sessionmaker
+from .. import event
+from .. import exc
+from .. import util
+from ..util.compat import inspect_getfullargspec
+
+
+class InstrumentationEvents(event.Events):
+ """Events related to class instrumentation events.
+
+ The listeners here support being established against
+ any new style class, that is any object that is a subclass
+ of 'type'. Events will then be fired off for events
+ against that class. If the "propagate=True" flag is passed
+ to event.listen(), the event will fire off for subclasses
+ of that class as well.
+
+ The Python ``type`` builtin is also accepted as a target,
+ which when used has the effect of events being emitted
+ for all classes.
+
+ Note the "propagate" flag here is defaulted to ``True``,
+ unlike the other class level events where it defaults
+ to ``False``. This means that new subclasses will also
+ be the subject of these events, when a listener
+ is established on a superclass.
+
+ """
+
+ _target_class_doc = "SomeBaseClass"
+ _dispatch_target = instrumentation.InstrumentationFactory
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, type):
+ return _InstrumentationEventsHold(target)
+ else:
+ return None
+
+ @classmethod
+ def _listen(cls, event_key, propagate=True, **kw):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
+
+ def listen(target_cls, *arg):
+ listen_cls = target()
+
+ # if the weakref was collected; this is not something
+ # that normally happens. It was occurring during test teardown
+ # between mapper/registry/instrumentation_manager, however that
+ # interaction was changed to not rely upon the event system.
+ if listen_cls is None:
+ return None
+
+ if propagate and issubclass(target_cls, listen_cls):
+ return fn(target_cls, *arg)
+ elif not propagate and target_cls is listen_cls:
+ return fn(target_cls, *arg)
+
+ def remove(ref):
+ key = event.registry._EventKey(
+ None,
+ identifier,
+ listen,
+ instrumentation._instrumentation_factory,
+ )
+ getattr(
+ instrumentation._instrumentation_factory.dispatch, identifier
+ ).remove(key)
+
+ target = weakref.ref(target.class_, remove)
+
+ event_key.with_dispatch_target(
+ instrumentation._instrumentation_factory
+ ).with_wrapper(listen).base_listen(**kw)
+
+ @classmethod
+ def _clear(cls):
+ super(InstrumentationEvents, cls)._clear()
+ instrumentation._instrumentation_factory.dispatch._clear()
+
+ def class_instrument(self, cls):
+ """Called after the given class is instrumented.
+
+ To get at the :class:`.ClassManager`, use
+ :func:`.manager_of_class`.
+
+ """
+
+ def class_uninstrument(self, cls):
+ """Called before the given class is uninstrumented.
+
+ To get at the :class:`.ClassManager`, use
+ :func:`.manager_of_class`.
+
+ """
+
+ def attribute_instrument(self, cls, key, inst):
+ """Called when an attribute is instrumented."""
+
+
+class _InstrumentationEventsHold(object):
+ """temporary marker object used to transfer from _accept_with() to
+ _listen() on the InstrumentationEvents class.
+
+ """
+
+ def __init__(self, class_):
+ self.class_ = class_
+
+ dispatch = event.dispatcher(InstrumentationEvents)
+
+
+class InstanceEvents(event.Events):
+ """Define events specific to object lifecycle.
+
+ e.g.::
+
+ from sqlalchemy import event
+
+ def my_load_listener(target, context):
+ print("on load!")
+
+ event.listen(SomeClass, 'load', my_load_listener)
+
+ Available targets include:
+
+ * mapped classes
+ * unmapped superclasses of mapped or to-be-mapped classes
+ (using the ``propagate=True`` flag)
+ * :class:`_orm.Mapper` objects
+ * the :class:`_orm.Mapper` class itself and the :func:`.mapper`
+ function indicate listening for all mappers.
+
+ Instance events are closely related to mapper events, but
+ are more specific to the instance and its instrumentation,
+ rather than its system of persistence.
+
+ When using :class:`.InstanceEvents`, several modifiers are
+ available to the :func:`.event.listen` function.
+
+ :param propagate=False: When True, the event listener should
+ be applied to all inheriting classes as well as the
+ class which is the target of this listener.
+ :param raw=False: When True, the "target" argument passed
+ to applicable event listener functions will be the
+ instance's :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+ :param restore_load_context=False: Applies to the
+ :meth:`.InstanceEvents.load` and :meth:`.InstanceEvents.refresh`
+ events. Restores the loader context of the object when the event
+ hook is complete, so that ongoing eager load operations continue
+ to target the object appropriately. A warning is emitted if the
+ object is moved to a new loader context from within one of these
+ events if this flag is not set.
+
+ .. versionadded:: 1.3.14
+
+
+ """
+
+ _target_class_doc = "SomeClass"
+
+ _dispatch_target = instrumentation.ClassManager
+
+ @classmethod
+ def _new_classmanager_instance(cls, class_, classmanager):
+ _InstanceEventsHold.populate(class_, classmanager)
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm")
+ def _accept_with(cls, target):
+ orm = util.preloaded.orm
+
+ if isinstance(target, instrumentation.ClassManager):
+ return target
+ elif isinstance(target, mapperlib.Mapper):
+ return target.class_manager
+ elif target is orm.mapper:
+ return instrumentation.ClassManager
+ elif isinstance(target, type):
+ if issubclass(target, mapperlib.Mapper):
+ return instrumentation.ClassManager
+ else:
+ manager = instrumentation.manager_of_class(target)
+ if manager:
+ return manager
+ else:
+ return _InstanceEventsHold(target)
+ return None
+
+ @classmethod
+ def _listen(
+ cls,
+ event_key,
+ raw=False,
+ propagate=False,
+ restore_load_context=False,
+ **kw
+ ):
+ target, fn = (event_key.dispatch_target, event_key._listen_fn)
+
+ if not raw or restore_load_context:
+
+ def wrap(state, *arg, **kw):
+ if not raw:
+ target = state.obj()
+ else:
+ target = state
+ if restore_load_context:
+ runid = state.runid
+ try:
+ return fn(target, *arg, **kw)
+ finally:
+ if restore_load_context:
+ state.runid = runid
+
+ event_key = event_key.with_wrapper(wrap)
+
+ event_key.base_listen(propagate=propagate, **kw)
+
+ if propagate:
+ for mgr in target.subclass_managers(True):
+ event_key.with_dispatch_target(mgr).base_listen(propagate=True)
+
+ @classmethod
+ def _clear(cls):
+ super(InstanceEvents, cls)._clear()
+ _InstanceEventsHold._clear()
+
+ def first_init(self, manager, cls):
+ """Called when the first instance of a particular mapping is called.
+
+ This event is called when the ``__init__`` method of a class
+ is called the first time for that particular class. The event
+ invokes before ``__init__`` actually proceeds as well as before
+ the :meth:`.InstanceEvents.init` event is invoked.
+
+ """
+
+ def init(self, target, args, kwargs):
+ """Receive an instance when its constructor is called.
+
+ This method is only called during a userland construction of
+ an object, in conjunction with the object's constructor, e.g.
+ its ``__init__`` method. It is not called when an object is
+ loaded from the database; see the :meth:`.InstanceEvents.load`
+ event in order to intercept a database load.
+
+ The event is called before the actual ``__init__`` constructor
+ of the object is called. The ``kwargs`` dictionary may be
+ modified in-place in order to affect what is passed to
+ ``__init__``.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param args: positional arguments passed to the ``__init__`` method.
+ This is passed as a tuple and is currently immutable.
+ :param kwargs: keyword arguments passed to the ``__init__`` method.
+ This structure *can* be altered in place.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.init_failure`
+
+ :meth:`.InstanceEvents.load`
+
+ """
+
+ def init_failure(self, target, args, kwargs):
+ """Receive an instance when its constructor has been called,
+ and raised an exception.
+
+ This method is only called during a userland construction of
+ an object, in conjunction with the object's constructor, e.g.
+ its ``__init__`` method. It is not called when an object is loaded
+ from the database.
+
+ The event is invoked after an exception raised by the ``__init__``
+ method is caught. After the event
+ is invoked, the original exception is re-raised outwards, so that
+ the construction of the object still raises an exception. The
+ actual exception and stack trace raised should be present in
+ ``sys.exc_info()``.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param args: positional arguments that were passed to the ``__init__``
+ method.
+ :param kwargs: keyword arguments that were passed to the ``__init__``
+ method.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.init`
+
+ :meth:`.InstanceEvents.load`
+
+ """
+
+ def load(self, target, context):
+ """Receive an object instance after it has been created via
+ ``__new__``, and after initial attribute population has
+ occurred.
+
+ This typically occurs when the instance is created based on
+ incoming result rows, and is only called once for that
+ instance's lifetime.
+
+ .. warning::
+
+ During a result-row load, this event is invoked when the
+ first row received for this instance is processed. When using
+ eager loading with collection-oriented attributes, the additional
+ rows that are to be loaded / processed in order to load subsequent
+ collection items have not occurred yet. This has two effects:
+ collections will not be fully loaded, and if an operation within
+ this event handler emits another database load operation for the
+ object, the "loading context" for the object can change and
+ interfere with the existing eager loaders still in progress.
+
+ Examples of what can cause the "loading context" to change within
+ the event handler include, but are not necessarily limited to:
+
+ * accessing deferred attributes that weren't part of the row,
+ will trigger an "undefer" operation and refresh the object
+
+ * accessing attributes on a joined-inheritance subclass that
+ weren't part of the row, will trigger a refresh operation.
+
+ As of SQLAlchemy 1.3.14, a warning is emitted when this occurs. The
+ :paramref:`.InstanceEvents.restore_load_context` option may be
+ used on the event to prevent this warning; this will ensure that
+ the existing loading context is maintained for the object after the
+ event is called::
+
+ @event.listens_for(
+ SomeClass, "load", restore_load_context=True)
+ def on_load(instance, context):
+ instance.some_unloaded_attribute
+
+ .. versionchanged:: 1.3.14 Added
+ :paramref:`.InstanceEvents.restore_load_context`
+ and :paramref:`.SessionEvents.restore_load_context` flags which
+ apply to "on load" events, which will ensure that the loading
+ context for an object is restored when the event hook is
+ complete; a warning is emitted if the load context of the object
+ changes without this flag being set.
+
+
+ The :meth:`.InstanceEvents.load` event is also available in a
+ class-method decorator format called :func:`_orm.reconstructor`.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param context: the :class:`.QueryContext` corresponding to the
+ current :class:`_query.Query` in progress. This argument may be
+ ``None`` if the load does not correspond to a :class:`_query.Query`,
+ such as during :meth:`.Session.merge`.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.init`
+
+ :meth:`.InstanceEvents.refresh`
+
+ :meth:`.SessionEvents.loaded_as_persistent`
+
+ :ref:`mapping_constructors`
+
+ """
+
+ def refresh(self, target, context, attrs):
+ """Receive an object instance after one or more attributes have
+ been refreshed from a query.
+
+ Contrast this to the :meth:`.InstanceEvents.load` method, which
+ is invoked when the object is first loaded from a query.
+
+ .. note:: This event is invoked within the loader process before
+ eager loaders may have been completed, and the object's state may
+ not be complete. Additionally, invoking row-level refresh
+ operations on the object will place the object into a new loader
+ context, interfering with the existing load context. See the note
+ on :meth:`.InstanceEvents.load` for background on making use of the
+ :paramref:`.InstanceEvents.restore_load_context` parameter, in
+ order to resolve this scenario.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param context: the :class:`.QueryContext` corresponding to the
+ current :class:`_query.Query` in progress.
+ :param attrs: sequence of attribute names which
+ were populated, or None if all column-mapped, non-deferred
+ attributes were populated.
+
+ .. seealso::
+
+ :meth:`.InstanceEvents.load`
+
+ """
+
+ def refresh_flush(self, target, flush_context, attrs):
+ """Receive an object instance after one or more attributes that
+ contain a column-level default or onupdate handler have been refreshed
+ during persistence of the object's state.
+
+ This event is the same as :meth:`.InstanceEvents.refresh` except
+ it is invoked within the unit of work flush process, and includes
+ only non-primary-key columns that have column level default or
+ onupdate handlers, including Python callables as well as server side
+ defaults and triggers which may be fetched via the RETURNING clause.
+
+ .. note::
+
+ While the :meth:`.InstanceEvents.refresh_flush` event is triggered
+ for an object that was INSERTed as well as for an object that was
+ UPDATEd, the event is geared primarily towards the UPDATE process;
+ it is largely an internal artifact that INSERT actions also
+ trigger this event; note that **primary key columns for an
+ INSERTed row are explicitly omitted** from this event. In order to
+ intercept the newly INSERTed state of an object, the
+ :meth:`.SessionEvents.pending_to_persistent` and
+ :meth:`.MapperEvents.after_insert` are better choices.
+
+ .. versionadded:: 1.0.5
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+ :param attrs: sequence of attribute names which
+ were populated.
+
+ .. seealso::
+
+ :ref:`orm_server_defaults`
+
+ :ref:`metadata_defaults_toplevel`
+
+ """
+
+ def expire(self, target, attrs):
+ """Receive an object instance after its attributes or some subset
+ have been expired.
+
+ The ``attrs`` argument is a sequence of attribute names; if None,
+ the entire state was expired.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param attrs: sequence of attribute
+ names which were expired, or None if all attributes were
+ expired.
+
+ """
+
+ def pickle(self, target, state_dict):
+ """Receive an object instance when its associated state is
+ being pickled.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param state_dict: the dictionary returned by
+ :class:`.InstanceState.__getstate__`, containing the state
+ to be pickled.
+
+ """
+
+ def unpickle(self, target, state_dict):
+ """Receive an object instance after its associated state has
+ been unpickled.
+
+ :param target: the mapped instance. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :param state_dict: the dictionary sent to
+ :class:`.InstanceState.__setstate__`, containing the state
+ dictionary which was pickled.
+
+ """
+
+
+class _EventsHold(event.RefCollection):
+ """Hold onto listeners against unmapped, uninstrumented classes.
+
+ Establish _listen() for that class' mapper/instrumentation when
+ those objects are created for that class.
+
+ """
+
+ def __init__(self, class_):
+ self.class_ = class_
+
+ @classmethod
+ def _clear(cls):
+ cls.all_holds.clear()
+
+ class HoldEvents(object):
+ _dispatch_target = None
+
+ @classmethod
+ def _listen(
+ cls, event_key, raw=False, propagate=False, retval=False, **kw
+ ):
+ target = event_key.dispatch_target
+
+ if target.class_ in target.all_holds:
+ collection = target.all_holds[target.class_]
+ else:
+ collection = target.all_holds[target.class_] = {}
+
+ event.registry._stored_in_collection(event_key, target)
+ collection[event_key._key] = (
+ event_key,
+ raw,
+ propagate,
+ retval,
+ kw,
+ )
+
+ if propagate:
+ stack = list(target.class_.__subclasses__())
+ while stack:
+ subclass = stack.pop(0)
+ stack.extend(subclass.__subclasses__())
+ subject = target.resolve(subclass)
+ if subject is not None:
+ # we are already going through __subclasses__()
+ # so leave generic propagate flag False
+ event_key.with_dispatch_target(subject).listen(
+ raw=raw, propagate=False, retval=retval, **kw
+ )
+
+ def remove(self, event_key):
+ target = event_key.dispatch_target
+
+ if isinstance(target, _EventsHold):
+ collection = target.all_holds[target.class_]
+ del collection[event_key._key]
+
+ @classmethod
+ def populate(cls, class_, subject):
+ for subclass in class_.__mro__:
+ if subclass in cls.all_holds:
+ collection = cls.all_holds[subclass]
+ for (
+ event_key,
+ raw,
+ propagate,
+ retval,
+ kw,
+ ) in collection.values():
+ if propagate or subclass is class_:
+ # since we can't be sure in what order different
+ # classes in a hierarchy are triggered with
+ # populate(), we rely upon _EventsHold for all event
+ # assignment, instead of using the generic propagate
+ # flag.
+ event_key.with_dispatch_target(subject).listen(
+ raw=raw, propagate=False, retval=retval, **kw
+ )
+
+
+class _InstanceEventsHold(_EventsHold):
+ all_holds = weakref.WeakKeyDictionary()
+
+ def resolve(self, class_):
+ return instrumentation.manager_of_class(class_)
+
+ class HoldInstanceEvents(_EventsHold.HoldEvents, InstanceEvents):
+ pass
+
+ dispatch = event.dispatcher(HoldInstanceEvents)
+
+
+class MapperEvents(event.Events):
+ """Define events specific to mappings.
+
+ e.g.::
+
+ from sqlalchemy import event
+
+ def my_before_insert_listener(mapper, connection, target):
+ # execute a stored procedure upon INSERT,
+ # apply the value to the row to be inserted
+ target.calculated_value = connection.execute(
+ text("select my_special_function(%d)" % target.special_number)
+ ).scalar()
+
+ # associate the listener function with SomeClass,
+ # to execute during the "before_insert" hook
+ event.listen(
+ SomeClass, 'before_insert', my_before_insert_listener)
+
+ Available targets include:
+
+ * mapped classes
+ * unmapped superclasses of mapped or to-be-mapped classes
+ (using the ``propagate=True`` flag)
+ * :class:`_orm.Mapper` objects
+ * the :class:`_orm.Mapper` class itself and the :func:`.mapper`
+ function indicate listening for all mappers.
+
+ Mapper events provide hooks into critical sections of the
+ mapper, including those related to object instrumentation,
+ object loading, and object persistence. In particular, the
+ persistence methods :meth:`~.MapperEvents.before_insert`,
+ and :meth:`~.MapperEvents.before_update` are popular
+ places to augment the state being persisted - however, these
+ methods operate with several significant restrictions. The
+ user is encouraged to evaluate the
+ :meth:`.SessionEvents.before_flush` and
+ :meth:`.SessionEvents.after_flush` methods as more
+ flexible and user-friendly hooks in which to apply
+ additional database state during a flush.
+
+ When using :class:`.MapperEvents`, several modifiers are
+ available to the :func:`.event.listen` function.
+
+ :param propagate=False: When True, the event listener should
+ be applied to all inheriting mappers and/or the mappers of
+ inheriting classes, as well as any
+ mapper which is the target of this listener.
+ :param raw=False: When True, the "target" argument passed
+ to applicable event listener functions will be the
+ instance's :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+ :param retval=False: when True, the user-defined event function
+ must have a return value, the purpose of which is either to
+ control subsequent event propagation, or to otherwise alter
+ the operation in progress by the mapper. Possible return
+ values are:
+
+ * ``sqlalchemy.orm.interfaces.EXT_CONTINUE`` - continue event
+ processing normally.
+ * ``sqlalchemy.orm.interfaces.EXT_STOP`` - cancel all subsequent
+ event handlers in the chain.
+ * other values - the return value specified by specific listeners.
+
+ """
+
+ _target_class_doc = "SomeClass"
+ _dispatch_target = mapperlib.Mapper
+
+ @classmethod
+ def _new_mapper_instance(cls, class_, mapper):
+ _MapperEventsHold.populate(class_, mapper)
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm")
+ def _accept_with(cls, target):
+ orm = util.preloaded.orm
+
+ if target is orm.mapper:
+ return mapperlib.Mapper
+ elif isinstance(target, type):
+ if issubclass(target, mapperlib.Mapper):
+ return target
+ else:
+ mapper = _mapper_or_none(target)
+ if mapper is not None:
+ return mapper
+ else:
+ return _MapperEventsHold(target)
+ else:
+ return target
+
+ @classmethod
+ def _listen(
+ cls, event_key, raw=False, retval=False, propagate=False, **kw
+ ):
+ target, identifier, fn = (
+ event_key.dispatch_target,
+ event_key.identifier,
+ event_key._listen_fn,
+ )
+
+ if (
+ identifier in ("before_configured", "after_configured")
+ and target is not mapperlib.Mapper
+ ):
+ util.warn(
+ "'before_configured' and 'after_configured' ORM events "
+ "only invoke with the mapper() function or Mapper class "
+ "as the target."
+ )
+
+ if not raw or not retval:
+ if not raw:
+ meth = getattr(cls, identifier)
+ try:
+ target_index = (
+ inspect_getfullargspec(meth)[0].index("target") - 1
+ )
+ except ValueError:
+ target_index = None
+
+ def wrap(*arg, **kw):
+ if not raw and target_index is not None:
+ arg = list(arg)
+ arg[target_index] = arg[target_index].obj()
+ if not retval:
+ fn(*arg, **kw)
+ return interfaces.EXT_CONTINUE
+ else:
+ return fn(*arg, **kw)
+
+ event_key = event_key.with_wrapper(wrap)
+
+ if propagate:
+ for mapper in target.self_and_descendants:
+ event_key.with_dispatch_target(mapper).base_listen(
+ propagate=True, **kw
+ )
+ else:
+ event_key.base_listen(**kw)
+
+ @classmethod
+ def _clear(cls):
+ super(MapperEvents, cls)._clear()
+ _MapperEventsHold._clear()
+
+ def instrument_class(self, mapper, class_):
+ r"""Receive a class when the mapper is first constructed,
+ before instrumentation is applied to the mapped class.
+
+ This event is the earliest phase of mapper construction.
+ Most attributes of the mapper are not yet initialized.
+
+ This listener can either be applied to the :class:`_orm.Mapper`
+ class overall, or to any un-mapped class which serves as a base
+ for classes that will be mapped (using the ``propagate=True`` flag)::
+
+ Base = declarative_base()
+
+ @event.listens_for(Base, "instrument_class", propagate=True)
+ def on_new_class(mapper, cls_):
+ " ... "
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param class\_: the mapped class.
+
+ """
+
+ def before_mapper_configured(self, mapper, class_):
+ """Called right before a specific mapper is to be configured.
+
+ This event is intended to allow a specific mapper to be skipped during
+ the configure step, by returning the :attr:`.orm.interfaces.EXT_SKIP`
+ symbol which indicates to the :func:`.configure_mappers` call that this
+ particular mapper (or hierarchy of mappers, if ``propagate=True`` is
+ used) should be skipped in the current configuration run. When one or
+ more mappers are skipped, the "new mappers" flag will remain set,
+ meaning the :func:`.configure_mappers` function will continue to be
+ called when mappers are used, to continue to try to configure all
+ available mappers.
+
+ In comparison to the other configure-level events,
+ :meth:`.MapperEvents.before_configured`,
+ :meth:`.MapperEvents.after_configured`, and
+ :meth:`.MapperEvents.mapper_configured`, the
+        :meth:`.MapperEvents.before_mapper_configured` event provides for a
+ meaningful return value when it is registered with the ``retval=True``
+ parameter.
+
+ .. versionadded:: 1.3
+
+ e.g.::
+
+ from sqlalchemy.orm import EXT_SKIP
+
+ Base = declarative_base()
+
+ DontConfigureBase = declarative_base()
+
+ @event.listens_for(
+ DontConfigureBase,
+ "before_mapper_configured", retval=True, propagate=True)
+ def dont_configure(mapper, cls):
+ return EXT_SKIP
+
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_configured`
+
+ :meth:`.MapperEvents.after_configured`
+
+ :meth:`.MapperEvents.mapper_configured`
+
+ """
+
+ def mapper_configured(self, mapper, class_):
+ r"""Called when a specific mapper has completed its own configuration
+ within the scope of the :func:`.configure_mappers` call.
+
+ The :meth:`.MapperEvents.mapper_configured` event is invoked
+ for each mapper that is encountered when the
+ :func:`_orm.configure_mappers` function proceeds through the current
+ list of not-yet-configured mappers.
+ :func:`_orm.configure_mappers` is typically invoked
+ automatically as mappings are first used, as well as each time
+ new mappers have been made available and new mapper use is
+ detected.
+
+ When the event is called, the mapper should be in its final
+ state, but **not including backrefs** that may be invoked from
+ other mappers; they might still be pending within the
+ configuration operation. Bidirectional relationships that
+ are instead configured via the
+ :paramref:`.orm.relationship.back_populates` argument
+ *will* be fully available, since this style of relationship does not
+ rely upon other possibly-not-configured mappers to know that they
+ exist.
+
+ For an event that is guaranteed to have **all** mappers ready
+ to go including backrefs that are defined only on other
+ mappings, use the :meth:`.MapperEvents.after_configured`
+ event; this event invokes only after all known mappings have been
+ fully configured.
+
+ The :meth:`.MapperEvents.mapper_configured` event, unlike
+ :meth:`.MapperEvents.before_configured` or
+ :meth:`.MapperEvents.after_configured`,
+ is called for each mapper/class individually, and the mapper is
+ passed to the event itself. It also is called exactly once for
+ a particular mapper. The event is therefore useful for
+ configurational steps that benefit from being invoked just once
+ on a specific mapper basis, which don't require that "backref"
+ configurations are necessarily ready yet.
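+
+        A minimal sketch of such a listener, assuming an illustrative
+        declarative ``Base`` class, might look like::
+
+            Base = declarative_base()
+
+            @event.listens_for(Base, "mapper_configured", propagate=True)
+            def receive_mapper_configured(mapper, class_):
+                # hypothetical per-class setup; runs exactly once per mapper
+                print("configured mapper for %s" % class_.__name__)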
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param class\_: the mapped class.
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_configured`
+
+ :meth:`.MapperEvents.after_configured`
+
+ :meth:`.MapperEvents.before_mapper_configured`
+
+ """
+ # TODO: need coverage for this event
+
+ def before_configured(self):
+ """Called before a series of mappers have been configured.
+
+ The :meth:`.MapperEvents.before_configured` event is invoked
+ each time the :func:`_orm.configure_mappers` function is
+ invoked, before the function has done any of its work.
+ :func:`_orm.configure_mappers` is typically invoked
+ automatically as mappings are first used, as well as each time
+ new mappers have been made available and new mapper use is
+ detected.
+
+ This event can **only** be applied to the :class:`_orm.Mapper` class
+ or :func:`.mapper` function, and not to individual mappings or
+ mapped classes. It is only invoked for all mappings as a whole::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "before_configured")
+ def go():
+ # ...
+
+ Contrast this event to :meth:`.MapperEvents.after_configured`,
+ which is invoked after the series of mappers has been configured,
+ as well as :meth:`.MapperEvents.before_mapper_configured`
+ and :meth:`.MapperEvents.mapper_configured`, which are both invoked
+ on a per-mapper basis.
+
+ Theoretically this event is called once per
+ application, but is actually called any time new mappers
+ are to be affected by a :func:`_orm.configure_mappers`
+ call. If new mappings are constructed after existing ones have
+ already been used, this event will likely be called again. To ensure
+ that a particular event is only called once and no further, the
+ ``once=True`` argument (new in 0.9.4) can be applied::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "before_configured", once=True)
+ def go():
+ # ...
+
+
+ .. versionadded:: 0.9.3
+
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_mapper_configured`
+
+ :meth:`.MapperEvents.mapper_configured`
+
+ :meth:`.MapperEvents.after_configured`
+
+ """
+
+ def after_configured(self):
+ """Called after a series of mappers have been configured.
+
+ The :meth:`.MapperEvents.after_configured` event is invoked
+ each time the :func:`_orm.configure_mappers` function is
+ invoked, after the function has completed its work.
+ :func:`_orm.configure_mappers` is typically invoked
+ automatically as mappings are first used, as well as each time
+ new mappers have been made available and new mapper use is
+ detected.
+
+ Contrast this event to the :meth:`.MapperEvents.mapper_configured`
+ event, which is called on a per-mapper basis while the configuration
+ operation proceeds; unlike that event, when this event is invoked,
+ all cross-configurations (e.g. backrefs) will also have been made
+ available for any mappers that were pending.
+ Also contrast to :meth:`.MapperEvents.before_configured`,
+ which is invoked before the series of mappers has been configured.
+
+ This event can **only** be applied to the :class:`_orm.Mapper` class
+ or :func:`.mapper` function, and not to individual mappings or
+ mapped classes. It is only invoked for all mappings as a whole::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "after_configured")
+ def go():
+ # ...
+
+ Theoretically this event is called once per
+ application, but is actually called any time new mappers
+ have been affected by a :func:`_orm.configure_mappers`
+ call. If new mappings are constructed after existing ones have
+ already been used, this event will likely be called again. To ensure
+ that a particular event is only called once and no further, the
+ ``once=True`` argument (new in 0.9.4) can be applied::
+
+ from sqlalchemy.orm import mapper
+
+ @event.listens_for(mapper, "after_configured", once=True)
+ def go():
+ # ...
+
+ .. seealso::
+
+ :meth:`.MapperEvents.before_mapper_configured`
+
+ :meth:`.MapperEvents.mapper_configured`
+
+ :meth:`.MapperEvents.before_configured`
+
+ """
+
+ def before_insert(self, mapper, connection, target):
+ """Receive an object instance before an INSERT statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify local, non-object related
+ attributes on the instance before an INSERT occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
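+
+        As an illustrative sketch only (the ``User`` class and its
+        ``created_at`` attribute are hypothetical), a listener might
+        populate a timestamp column just before the row is INSERTed::
+
+            from datetime import datetime
+
+            from sqlalchemy import event
+
+            @event.listens_for(User, "before_insert")
+            def receive_before_insert(mapper, connection, target):
+                # set a value on the instance itself; it will be included
+                # in the INSERT statement emitted for this row
+                target.created_at = datetime.utcnow()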
+
+ The event is often called for a batch of objects of the
+ same class before their INSERT statements are emitted at
+ once in a later step. In the extremely rare case that
+ this is not desirable, the :func:`.mapper` can be
+ configured with ``batch=False``, which will cause
+ batches of instances to be broken up into individual
+ (and more poorly performing) event->persist->event
+ steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit INSERT statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_insert(self, mapper, connection, target):
+ """Receive an object instance after an INSERT statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify in-Python-only
+ state on the instance after an INSERT occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
+
+ The event is often called for a batch of objects of the
+ same class after their INSERT statements have been
+ emitted at once in a previous step. In the extremely
+ rare case that this is not desirable, the
+ :func:`.mapper` can be configured with ``batch=False``,
+ which will cause batches of instances to be broken up
+ into individual (and more poorly performing)
+ event->persist->event steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit INSERT statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def before_update(self, mapper, connection, target):
+ """Receive an object instance before an UPDATE statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify local, non-object related
+ attributes on the instance before an UPDATE occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
+
+ This method is called for all instances that are
+ marked as "dirty", *even those which have no net changes
+ to their column-based attributes*. An object is marked
+ as dirty when any of its column-based attributes have a
+ "set attribute" operation called or when any of its
+ collections are modified. If, at update time, no
+ column-based attributes have any net changes, no UPDATE
+ statement will be issued. This means that an instance
+ being sent to :meth:`~.MapperEvents.before_update` is
+ *not* a guarantee that an UPDATE statement will be
+ issued, although you can affect the outcome here by
+ modifying attributes so that a net change in value does
+ exist.
+
+ To detect if the column-based attributes on the object have net
+ changes, and will therefore generate an UPDATE statement, use
+ ``object_session(instance).is_modified(instance,
+ include_collections=False)``.
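+
+        For example, a sketch of a listener that acts only when an UPDATE
+        will actually be issued (the ``User`` class and ``updated_at``
+        column are hypothetical)::
+
+            from datetime import datetime
+
+            from sqlalchemy import event
+            from sqlalchemy.orm import object_session
+
+            @event.listens_for(User, "before_update")
+            def receive_before_update(mapper, connection, target):
+                # only act when column-based attributes have net changes
+                if object_session(target).is_modified(
+                    target, include_collections=False
+                ):
+                    target.updated_at = datetime.utcnow()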
+
+ The event is often called for a batch of objects of the
+ same class before their UPDATE statements are emitted at
+ once in a later step. In the extremely rare case that
+ this is not desirable, the :func:`.mapper` can be
+ configured with ``batch=False``, which will cause
+ batches of instances to be broken up into individual
+ (and more poorly performing) event->persist->event
+ steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit UPDATE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_update(self, mapper, connection, target):
+ """Receive an object instance after an UPDATE statement
+ is emitted corresponding to that instance.
+
+ This event is used to modify in-Python-only
+ state on the instance after an UPDATE occurs, as well
+ as to emit additional SQL statements on the given
+ connection.
+
+ This method is called for all instances that are
+ marked as "dirty", *even those which have no net changes
+ to their column-based attributes*, and for which
+ no UPDATE statement has proceeded. An object is marked
+ as dirty when any of its column-based attributes have a
+ "set attribute" operation called or when any of its
+ collections are modified. If, at update time, no
+ column-based attributes have any net changes, no UPDATE
+ statement will be issued. This means that an instance
+ being sent to :meth:`~.MapperEvents.after_update` is
+ *not* a guarantee that an UPDATE statement has been
+ issued.
+
+ To detect if the column-based attributes on the object have net
+ changes, and therefore resulted in an UPDATE statement, use
+ ``object_session(instance).is_modified(instance,
+ include_collections=False)``.
+
+ The event is often called for a batch of objects of the
+ same class after their UPDATE statements have been emitted at
+ once in a previous step. In the extremely rare case that
+ this is not desirable, the :func:`.mapper` can be
+ configured with ``batch=False``, which will cause
+ batches of instances to be broken up into individual
+ (and more poorly performing) event->persist->event
+ steps.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit UPDATE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being persisted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def before_delete(self, mapper, connection, target):
+ """Receive an object instance before a DELETE statement
+ is emitted corresponding to that instance.
+
+ This event is used to emit additional SQL statements on
+ the given connection as well as to perform application
+ specific bookkeeping related to a deletion event.
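+
+        A sketch of such a listener, assuming a hypothetical ``audit_table``
+        :class:`_schema.Table` and a mapped ``User`` class::
+
+            from sqlalchemy import event
+
+            @event.listens_for(User, "before_delete")
+            def receive_before_delete(mapper, connection, target):
+                # emit additional SQL on the same connection, within
+                # the same transaction as the pending DELETE
+                connection.execute(
+                    audit_table.insert().values(
+                        user_id=target.id, action="delete"
+                    )
+                )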
+
+ The event is often called for a batch of objects of the
+ same class before their DELETE statements are emitted at
+ once in a later step.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit DELETE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being deleted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_delete(self, mapper, connection, target):
+ """Receive an object instance after a DELETE statement
+ has been emitted corresponding to that instance.
+
+ This event is used to emit additional SQL statements on
+ the given connection as well as to perform application
+ specific bookkeeping related to a deletion event.
+
+ The event is often called for a batch of objects of the
+ same class after their DELETE statements have been emitted at
+ once in a previous step.
+
+ .. warning::
+
+ Mapper-level flush events only allow **very limited operations**,
+ on attributes local to the row being operated upon only,
+ as well as allowing any SQL to be emitted on the given
+ :class:`_engine.Connection`. **Please read fully** the notes
+ at :ref:`session_persistence_mapper` for guidelines on using
+ these methods; generally, the :meth:`.SessionEvents.before_flush`
+ method should be preferred for general on-flush changes.
+
+ :param mapper: the :class:`_orm.Mapper` which is the target
+ of this event.
+ :param connection: the :class:`_engine.Connection` being used to
+ emit DELETE statements for this instance. This
+ provides a handle into the current transaction on the
+ target database specific to this instance.
+ :param target: the mapped instance being deleted. If
+ the event is configured with ``raw=True``, this will
+ instead be the :class:`.InstanceState` state-management
+ object associated with the instance.
+ :return: No return value is supported by this event.
+
+ .. seealso::
+
+ :ref:`session_persistence_events`
+
+ """
+
+
+class _MapperEventsHold(_EventsHold):
+ all_holds = weakref.WeakKeyDictionary()
+
+ def resolve(self, class_):
+ return _mapper_or_none(class_)
+
+ class HoldMapperEvents(_EventsHold.HoldEvents, MapperEvents):
+ pass
+
+ dispatch = event.dispatcher(HoldMapperEvents)
+
+
+_sessionevents_lifecycle_event_names = set()
+
+
+class SessionEvents(event.Events):
+ """Define events specific to :class:`.Session` lifecycle.
+
+ e.g.::
+
+ from sqlalchemy import event
+ from sqlalchemy.orm import sessionmaker
+
+ def my_before_commit(session):
+ print("before commit!")
+
+ Session = sessionmaker()
+
+ event.listen(Session, "before_commit", my_before_commit)
+
+ The :func:`~.event.listen` function will accept
+ :class:`.Session` objects as well as the return result
+ of :class:`~.sessionmaker()` and :class:`~.scoped_session()`.
+
+ Additionally, it accepts the :class:`.Session` class which
+ will apply listeners to all :class:`.Session` instances
+ globally.
+
+ :param raw=False: When True, the "target" argument passed
+ to applicable event listener functions that work on individual
+ objects will be the instance's :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+
+ .. versionadded:: 1.3.14
+
+ :param restore_load_context=False: Applies to the
+ :meth:`.SessionEvents.loaded_as_persistent` event. Restores the loader
+ context of the object when the event hook is complete, so that ongoing
+ eager load operations continue to target the object appropriately. A
+ warning is emitted if the object is moved to a new loader context from
+ within this event if this flag is not set.
+
+ .. versionadded:: 1.3.14
+
+ """
+
+ _target_class_doc = "SomeSessionClassOrObject"
+
+ _dispatch_target = Session
+
+ def _lifecycle_event(fn):
+ _sessionevents_lifecycle_event_names.add(fn.__name__)
+ return fn
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, scoped_session):
+
+ target = target.session_factory
+ if not isinstance(target, sessionmaker) and (
+ not isinstance(target, type) or not issubclass(target, Session)
+ ):
+ raise exc.ArgumentError(
+ "Session event listen on a scoped_session "
+ "requires that its creation callable "
+ "is associated with the Session class."
+ )
+
+ if isinstance(target, sessionmaker):
+ return target.class_
+ elif isinstance(target, type):
+ if issubclass(target, scoped_session):
+ return Session
+ elif issubclass(target, Session):
+ return target
+ elif isinstance(target, Session):
+ return target
+ else:
+ # allows alternate SessionEvents-like-classes to be consulted
+ return event.Events._accept_with(target)
+
+ @classmethod
+ def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
+ is_instance_event = (
+ event_key.identifier in _sessionevents_lifecycle_event_names
+ )
+
+ if is_instance_event:
+ if not raw or restore_load_context:
+
+ fn = event_key._listen_fn
+
+ def wrap(session, state, *arg, **kw):
+ if not raw:
+ target = state.obj()
+ if target is None:
+ # existing behavior is that if the object is
+ # garbage collected, no event is emitted
+ return
+ else:
+ target = state
+ if restore_load_context:
+ runid = state.runid
+ try:
+ return fn(session, target, *arg, **kw)
+ finally:
+ if restore_load_context:
+ state.runid = runid
+
+ event_key = event_key.with_wrapper(wrap)
+
+ event_key.base_listen(**kw)
+
+ def do_orm_execute(self, orm_execute_state):
+ """Intercept statement executions that occur on behalf of an
+ ORM :class:`.Session` object.
+
+ This event is invoked for all top-level SQL statements invoked from the
+ :meth:`_orm.Session.execute` method, as well as related methods such as
+ :meth:`_orm.Session.scalars` and :meth:`_orm.Session.scalar`. As of
+ SQLAlchemy 1.4, all ORM queries emitted on behalf of a
+ :class:`_orm.Session` will flow through this method, so this event hook
+ provides the single point at which ORM queries of all types may be
+ intercepted before they are invoked, and additionally to replace their
+ execution with a different process.
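+
+        A minimal, illustrative listener which only inspects the statement
+        to be invoked, without replacing its execution, might look like::
+
+            from sqlalchemy import event
+            from sqlalchemy.orm import Session
+
+            @event.listens_for(Session, "do_orm_execute")
+            def receive_do_orm_execute(orm_execute_state):
+                # the statement about to be invoked is available here
+                print(orm_execute_state.statement)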
+
+ .. note:: The :meth:`_orm.SessionEvents.do_orm_execute` event hook
+ is triggered **for ORM statement executions only**, meaning those
+ invoked via the :meth:`_orm.Session.execute` and similar methods on
+ the :class:`_orm.Session` object. It does **not** trigger for
+ statements that are invoked by SQLAlchemy Core only, i.e. statements
+ invoked directly using :meth:`_engine.Connection.execute` or
+ otherwise originating from an :class:`_engine.Engine` object without
+ any :class:`_orm.Session` involved. To intercept **all** SQL
+ executions regardless of whether the Core or ORM APIs are in use,
+ see the event hooks at
+ :class:`.ConnectionEvents`, such as
+ :meth:`.ConnectionEvents.before_execute` and
+ :meth:`.ConnectionEvents.before_cursor_execute`.
+
+ This event is a ``do_`` event, meaning it has the capability to replace
+ the operation that the :meth:`_orm.Session.execute` method normally
+ performs. The intended use for this includes sharding and
+ result-caching schemes which may seek to invoke the same statement
+ across multiple database connections, returning a result that is
+ merged from each of them, or which don't invoke the statement at all,
+ instead returning data from a cache.
+
+ The hook intends to replace the use of the
+ ``Query._execute_and_instances`` method that could be subclassed prior
+ to SQLAlchemy 1.4.
+
+ :param orm_execute_state: an instance of :class:`.ORMExecuteState`
+ which contains all information about the current execution, as well
+ as helper functions used to derive other commonly required
+ information. See that object for details.
+
+ .. seealso::
+
+ :ref:`session_execute_events` - top level documentation on how
+ to use :meth:`_orm.SessionEvents.do_orm_execute`
+
+ :class:`.ORMExecuteState` - the object passed to the
+ :meth:`_orm.SessionEvents.do_orm_execute` event which contains
+ all information about the statement to be invoked. It also
+ provides an interface to extend the current statement, options,
+ and parameters as well as an option that allows programmatic
+ invocation of the statement at any point.
+
+ :ref:`examples_session_orm_events` - includes examples of using
+ :meth:`_orm.SessionEvents.do_orm_execute`
+
+ :ref:`examples_caching` - an example of how to integrate
+ Dogpile caching with the ORM :class:`_orm.Session` making use
+ of the :meth:`_orm.SessionEvents.do_orm_execute` event hook.
+
+ :ref:`examples_sharding` - the Horizontal Sharding example /
+ extension relies upon the
+ :meth:`_orm.SessionEvents.do_orm_execute` event hook to invoke a
+ SQL statement on multiple backends and return a merged result.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ def after_transaction_create(self, session, transaction):
+ """Execute when a new :class:`.SessionTransaction` is created.
+
+ This event differs from :meth:`~.SessionEvents.after_begin`
+ in that it occurs for each :class:`.SessionTransaction`
+ overall, as opposed to when transactions are begun
+ on individual database connections. It is also invoked
+ for nested transactions and subtransactions, and is always
+ matched by a corresponding
+ :meth:`~.SessionEvents.after_transaction_end` event
+ (assuming normal operation of the :class:`.Session`).
+
+ :param session: the target :class:`.Session`.
+ :param transaction: the target :class:`.SessionTransaction`.
+
+ To detect if this is the outermost
+ :class:`.SessionTransaction`, as opposed to a "subtransaction" or a
+ SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute
+ is ``None``::
+
+ @event.listens_for(session, "after_transaction_create")
+ def after_transaction_create(session, transaction):
+ if transaction.parent is None:
+ # work with top-level transaction
+
+ To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the
+ :attr:`.SessionTransaction.nested` attribute::
+
+ @event.listens_for(session, "after_transaction_create")
+ def after_transaction_create(session, transaction):
+ if transaction.nested:
+ # work with SAVEPOINT transaction
+
+
+ .. seealso::
+
+ :class:`.SessionTransaction`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ def after_transaction_end(self, session, transaction):
+ """Execute when the span of a :class:`.SessionTransaction` ends.
+
+ This event differs from :meth:`~.SessionEvents.after_commit`
+ in that it corresponds to all :class:`.SessionTransaction`
+ objects in use, including those for nested transactions
+ and subtransactions, and is always matched by a corresponding
+ :meth:`~.SessionEvents.after_transaction_create` event.
+
+ :param session: the target :class:`.Session`.
+ :param transaction: the target :class:`.SessionTransaction`.
+
+ To detect if this is the outermost
+ :class:`.SessionTransaction`, as opposed to a "subtransaction" or a
+ SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute
+ is ``None``::
+
+            @event.listens_for(session, "after_transaction_end")
+ def after_transaction_end(session, transaction):
+ if transaction.parent is None:
+ # work with top-level transaction
+
+ To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the
+ :attr:`.SessionTransaction.nested` attribute::
+
+            @event.listens_for(session, "after_transaction_end")
+ def after_transaction_end(session, transaction):
+ if transaction.nested:
+ # work with SAVEPOINT transaction
+
+
+ .. seealso::
+
+ :class:`.SessionTransaction`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ """
+
+ def before_commit(self, session):
+ """Execute before commit is called.
+
+ .. note::
+
+ The :meth:`~.SessionEvents.before_commit` hook is *not* per-flush,
+ that is, the :class:`.Session` can emit SQL to the database
+ many times within the scope of a transaction.
+ For interception of these events, use the
+ :meth:`~.SessionEvents.before_flush`,
+ :meth:`~.SessionEvents.after_flush`, or
+ :meth:`~.SessionEvents.after_flush_postexec`
+ events.
+
+ :param session: The target :class:`.Session`.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.after_commit`
+
+ :meth:`~.SessionEvents.after_begin`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ def after_commit(self, session):
+ """Execute after a commit has occurred.
+
+ .. note::
+
+ The :meth:`~.SessionEvents.after_commit` hook is *not* per-flush,
+ that is, the :class:`.Session` can emit SQL to the database
+ many times within the scope of a transaction.
+ For interception of these events, use the
+ :meth:`~.SessionEvents.before_flush`,
+ :meth:`~.SessionEvents.after_flush`, or
+ :meth:`~.SessionEvents.after_flush_postexec`
+ events.
+
+ .. note::
+
+ The :class:`.Session` is not in an active transaction
+ when the :meth:`~.SessionEvents.after_commit` event is invoked,
+ and therefore can not emit SQL. To emit SQL corresponding to
+ every transaction, use the :meth:`~.SessionEvents.before_commit`
+ event.
+
+ :param session: The target :class:`.Session`.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_commit`
+
+ :meth:`~.SessionEvents.after_begin`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ def after_rollback(self, session):
+ """Execute after a real DBAPI rollback has occurred.
+
+ Note that this event only fires when the *actual* rollback against
+ the database occurs - it does *not* fire each time the
+ :meth:`.Session.rollback` method is called, if the underlying
+ DBAPI transaction has already been rolled back. In many
+ cases, the :class:`.Session` will not be in
+ an "active" state during this event, as the current
+ transaction is not valid. To acquire a :class:`.Session`
+ which is active after the outermost rollback has proceeded,
+ use the :meth:`.SessionEvents.after_soft_rollback` event, checking the
+ :attr:`.Session.is_active` flag.
+
+ :param session: The target :class:`.Session`.
+
+ """
+
+ def after_soft_rollback(self, session, previous_transaction):
+ """Execute after any rollback has occurred, including "soft"
+ rollbacks that don't actually emit at the DBAPI level.
+
+ This corresponds to both nested and outer rollbacks, i.e.
+ the innermost rollback that calls the DBAPI's
+ rollback() method, as well as the enclosing rollback
+ calls that only pop themselves from the transaction stack.
+
+ The given :class:`.Session` can be used to invoke SQL and
+ :meth:`.Session.query` operations after an outermost rollback
+ by first checking the :attr:`.Session.is_active` flag::
+
+ @event.listens_for(Session, "after_soft_rollback")
+ def do_something(session, previous_transaction):
+ if session.is_active:
+ session.execute("select * from some_table")
+
+ :param session: The target :class:`.Session`.
+ :param previous_transaction: The :class:`.SessionTransaction`
+ transactional marker object which was just closed. The current
+ :class:`.SessionTransaction` for the given :class:`.Session` is
+ available via the :attr:`.Session.transaction` attribute.
+
+ """
+
+ def before_flush(self, session, flush_context, instances):
+ """Execute before flush process has started.
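+
+        A short sketch of a listener that inspects pending objects before
+        they are flushed (purely illustrative)::
+
+            from sqlalchemy import event
+            from sqlalchemy.orm import Session
+
+            @event.listens_for(Session, "before_flush")
+            def receive_before_flush(session, flush_context, instances):
+                # objects pending INSERT are present in session.new
+                for obj in session.new:
+                    print("about to insert: %r" % obj)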
+
+ :param session: The target :class:`.Session`.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+ :param instances: Usually ``None``, this is the collection of
+ objects which can be passed to the :meth:`.Session.flush` method
+ (note this usage is deprecated).
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.after_flush`
+
+ :meth:`~.SessionEvents.after_flush_postexec`
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_flush(self, session, flush_context):
+ """Execute after flush has completed, but before commit has been
+ called.
+
+ Note that the session's state is still in pre-flush, i.e. 'new',
+ 'dirty', and 'deleted' lists still show pre-flush state as well
+ as the history settings on instance attributes.
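+
+        An illustrative listener that reports on the flush that just
+        occurred might look like::
+
+            from sqlalchemy import event
+            from sqlalchemy.orm import Session
+
+            @event.listens_for(Session, "after_flush")
+            def receive_after_flush(session, flush_context):
+                # 'new' and 'dirty' still reflect the pre-flush state here
+                print("flushed %d new objects" % len(session.new))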
+
+ .. warning:: This event runs after the :class:`.Session` has emitted
+ SQL to modify the database, but **before** it has altered its
+ internal state to reflect those changes, including that newly
+ inserted objects are placed into the identity map. ORM operations
+ emitted within this event such as loads of related items
+ may produce new identity map entries that will immediately
+ be replaced, sometimes causing confusing results. SQLAlchemy will
+ emit a warning for this condition as of version 1.3.9.
+
+ :param session: The target :class:`.Session`.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_flush`
+
+ :meth:`~.SessionEvents.after_flush_postexec`
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_flush_postexec(self, session, flush_context):
+ """Execute after flush has completed, and after the post-exec
+ state occurs.
+
+ This will be when the 'new', 'dirty', and 'deleted' lists are in
+ their final state. An actual commit() may or may not have
+ occurred, depending on whether or not the flush started its own
+ transaction or participated in a larger transaction.
+
+ :param session: The target :class:`.Session`.
+ :param flush_context: Internal :class:`.UOWTransaction` object
+ which handles the details of the flush.
+
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_flush`
+
+ :meth:`~.SessionEvents.after_flush`
+
+ :ref:`session_persistence_events`
+
+ """
+
+ def after_begin(self, session, transaction, connection):
+ """Execute after a transaction is begun on a connection
+
+ :param session: The target :class:`.Session`.
+ :param transaction: The :class:`.SessionTransaction`.
+ :param connection: The :class:`_engine.Connection` object
+ which will be used for SQL statements.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_commit`
+
+ :meth:`~.SessionEvents.after_commit`
+
+ :meth:`~.SessionEvents.after_transaction_create`
+
+ :meth:`~.SessionEvents.after_transaction_end`
+
+ """
+
+ @_lifecycle_event
+ def before_attach(self, session, instance):
+ """Execute before an instance is attached to a session.
+
+ This is called before an add, delete or merge causes
+ the object to be part of the session.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.after_attach`
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def after_attach(self, session, instance):
+ """Execute after an instance is attached to a session.
+
+ This is called after an add, delete or merge.
+
+ .. note::
+
+ As of 0.8, this event fires off *after* the item
+ has been fully associated with the session, which is
+ different than previous releases. For event
+ handlers that require the object not yet
+ be part of session state (such as handlers which
+ may autoflush while the target object is not
+ yet complete) consider the
+ new :meth:`.before_attach` event.
+
+ .. seealso::
+
+ :meth:`~.SessionEvents.before_attach`
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @event._legacy_signature(
+ "0.9",
+ ["session", "query", "query_context", "result"],
+ lambda update_context: (
+ update_context.session,
+ update_context.query,
+ None,
+ update_context.result,
+ ),
+ )
+ def after_bulk_update(self, update_context):
+ """Execute after an ORM UPDATE against a WHERE expression has been
+ invoked.
+
+ This is called as a result of the :meth:`_query.Query.update` method.
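+
+        An illustrative listener (the use of
+        ``update_context.result.rowcount`` here is an example only)::
+
+            from sqlalchemy import event
+            from sqlalchemy.orm import Session
+
+            @event.listens_for(Session, "after_bulk_update")
+            def receive_after_bulk_update(update_context):
+                # number of rows matched by the bulk UPDATE
+                print(update_context.result.rowcount)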
+
+ :param update_context: an "update context" object which contains
+ details about the update, including these attributes:
+
+ * ``session`` - the :class:`.Session` involved
+          * ``query`` - the :class:`_query.Query`
+            object that this update operation
+            was called upon.
+          * ``values`` - the "values" dictionary that was passed to
+            :meth:`_query.Query.update`.
+          * ``result`` - the :class:`_engine.CursorResult`
+            returned as a result of the
+            bulk UPDATE operation.
+
+ .. versionchanged:: 1.4 the update_context no longer has a
+ ``QueryContext`` object associated with it.
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile_update`
+
+ :meth:`.SessionEvents.after_bulk_delete`
+
+ """
+
+ @event._legacy_signature(
+ "0.9",
+ ["session", "query", "query_context", "result"],
+ lambda delete_context: (
+ delete_context.session,
+ delete_context.query,
+ None,
+ delete_context.result,
+ ),
+ )
+ def after_bulk_delete(self, delete_context):
+ """Execute after ORM DELETE against a WHERE expression has been
+ invoked.
+
+ This is called as a result of the :meth:`_query.Query.delete` method.
+
+ :param delete_context: a "delete context" object which contains
+          details about the delete, including these attributes:
+
+          * ``session`` - the :class:`.Session` involved
+          * ``query`` - the :class:`_query.Query`
+            object that this delete operation
+            was called upon.
+          * ``result`` - the :class:`_engine.CursorResult`
+            returned as a result of the
+            bulk DELETE operation.
+
+        .. versionchanged:: 1.4 the delete_context no longer has a
+ ``QueryContext`` object associated with it.
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile_delete`
+
+ :meth:`.SessionEvents.after_bulk_update`
+
+ """
+
+ @_lifecycle_event
+ def transient_to_pending(self, session, instance):
+ """Intercept the "transient to pending" transition for a specific
+ object.
+
+ This event is a specialization of the
+ :meth:`.SessionEvents.after_attach` event which is only invoked
+ for this specific transition. It is invoked typically during the
+ :meth:`.Session.add` call.
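+
+        A minimal illustrative listener::
+
+            from sqlalchemy import event
+            from sqlalchemy.orm import Session
+
+            @event.listens_for(Session, "transient_to_pending")
+            def receive_transient_to_pending(session, instance):
+                # the instance has just become part of the session
+                print("now pending: %r" % instance)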
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def pending_to_transient(self, session, instance):
+ """Intercept the "pending to transient" transition for a specific
+ object.
+
+        This less common transition occurs when a pending object that has
+ not been flushed is evicted from the session; this can occur
+ when the :meth:`.Session.rollback` method rolls back the transaction,
+ or when the :meth:`.Session.expunge` method is used.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def persistent_to_transient(self, session, instance):
+ """Intercept the "persistent to transient" transition for a specific
+ object.
+
+        This less common transition occurs when a pending object that has
+        been flushed is evicted from the session; this can occur
+ when the :meth:`.Session.rollback` method rolls back the transaction.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def pending_to_persistent(self, session, instance):
+        """Intercept the "pending to persistent" transition for a specific
+ object.
+
+ This event is invoked within the flush process, and is
+ similar to scanning the :attr:`.Session.new` collection within
+ the :meth:`.SessionEvents.after_flush` event. However, in this
+ case the object has already been moved to the persistent state
+ when the event is called.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def detached_to_persistent(self, session, instance):
+ """Intercept the "detached to persistent" transition for a specific
+ object.
+
+ This event is a specialization of the
+ :meth:`.SessionEvents.after_attach` event which is only invoked
+ for this specific transition. It is invoked typically during the
+ :meth:`.Session.add` call, as well as during the
+ :meth:`.Session.delete` call if the object was not previously
+ associated with the
+ :class:`.Session` (note that an object marked as "deleted" remains
+ in the "persistent" state until the flush proceeds).
+
+ .. note::
+
+ If the object becomes persistent as part of a call to
+ :meth:`.Session.delete`, the object is **not** yet marked as
+ deleted when this event is called. To detect deleted objects,
+ check the ``deleted`` flag sent to the
+            :meth:`.SessionEvents.persistent_to_detached` event after the
+ flush proceeds, or check the :attr:`.Session.deleted` collection
+ within the :meth:`.SessionEvents.before_flush` event if deleted
+ objects need to be intercepted before the flush.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def loaded_as_persistent(self, session, instance):
+ """Intercept the "loaded as persistent" transition for a specific
+ object.
+
+ This event is invoked within the ORM loading process, and is invoked
+ very similarly to the :meth:`.InstanceEvents.load` event. However,
+ the event here is linkable to a :class:`.Session` class or instance,
+ rather than to a mapper or class hierarchy, and integrates
+ with the other session lifecycle events smoothly. The object
+ is guaranteed to be present in the session's identity map when
+ this event is called.
+
+ .. note:: This event is invoked within the loader process before
+ eager loaders may have been completed, and the object's state may
+ not be complete. Additionally, invoking row-level refresh
+ operations on the object will place the object into a new loader
+ context, interfering with the existing load context. See the note
+ on :meth:`.InstanceEvents.load` for background on making use of the
+ :paramref:`.SessionEvents.restore_load_context` parameter, which
+ works in the same manner as that of
+ :paramref:`.InstanceEvents.restore_load_context`, in order to
+ resolve this scenario.
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def persistent_to_deleted(self, session, instance):
+ """Intercept the "persistent to deleted" transition for a specific
+ object.
+
+ This event is invoked when a persistent object's identity
+ is deleted from the database within a flush, however the object
+ still remains associated with the :class:`.Session` until the
+ transaction completes.
+
+ If the transaction is rolled back, the object moves again
+ to the persistent state, and the
+ :meth:`.SessionEvents.deleted_to_persistent` event is called.
+ If the transaction is committed, the object becomes detached,
+ which will emit the :meth:`.SessionEvents.deleted_to_detached`
+ event.
+
+ Note that while the :meth:`.Session.delete` method is the primary
+ public interface to mark an object as deleted, many objects
+ get deleted due to cascade rules, which are not always determined
+ until flush time. Therefore, there's no way to catch
+ every object that will be deleted until the flush has proceeded.
+        The :meth:`.SessionEvents.persistent_to_deleted` event is therefore
+ invoked at the end of a flush.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def deleted_to_persistent(self, session, instance):
+ """Intercept the "deleted to persistent" transition for a specific
+ object.
+
+ This transition occurs only when an object that's been deleted
+ successfully in a flush is restored due to a call to
+ :meth:`.Session.rollback`. The event is not called under
+ any other circumstances.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def deleted_to_detached(self, session, instance):
+ """Intercept the "deleted to detached" transition for a specific
+ object.
+
+ This event is invoked when a deleted object is evicted
+ from the session. The typical case when this occurs is when
+ the transaction for a :class:`.Session` in which the object
+ was deleted is committed; the object moves from the deleted
+ state to the detached state.
+
+ It is also invoked for objects that were deleted in a flush
+ when the :meth:`.Session.expunge_all` or :meth:`.Session.close`
+        methods are called, as well as if the object is individually
+ expunged from its deleted state via :meth:`.Session.expunge`.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+ @_lifecycle_event
+ def persistent_to_detached(self, session, instance):
+ """Intercept the "persistent to detached" transition for a specific
+ object.
+
+ This event is invoked when a persistent object is evicted
+ from the session. There are many conditions that cause this
+ to happen, including:
+
+ * using a method such as :meth:`.Session.expunge`
+ or :meth:`.Session.close`
+
+ * Calling the :meth:`.Session.rollback` method, when the object
+ was part of an INSERT statement for that session's transaction
+
+
+ :param session: target :class:`.Session`
+
+ :param instance: the ORM-mapped instance being operated upon.
+
+ :param deleted: boolean. If True, indicates this object moved
+ to the detached state because it was marked as deleted and flushed.
+
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_lifecycle_events`
+
+ """
+
+
+class AttributeEvents(event.Events):
+ r"""Define events for object attributes.
+
+ These are typically defined on the class-bound descriptor for the
+ target class.
+
+ For example, to register a listener that will receive the
+ :meth:`_orm.AttributeEvents.append` event::
+
+ from sqlalchemy import event
+
+ @event.listens_for(MyClass.collection, 'append', propagate=True)
+ def my_append_listener(target, value, initiator):
+ print("received append event for target: %s" % target)
+
+
+ Listeners have the option to return a possibly modified version of the
+ value, when the :paramref:`.AttributeEvents.retval` flag is passed to
+ :func:`.event.listen` or :func:`.event.listens_for`, such as below,
+ illustrated using the :meth:`_orm.AttributeEvents.set` event::
+
+ def validate_phone(target, value, oldvalue, initiator):
+ "Strip non-numeric characters from a phone number"
+
+ return re.sub(r'\D', '', value)
+
+ # setup listener on UserContact.phone attribute, instructing
+ # it to use the return value
+ listen(UserContact.phone, 'set', validate_phone, retval=True)
+
+ A validation function like the above can also raise an exception
+ such as :exc:`ValueError` to halt the operation.
+
+ The :paramref:`.AttributeEvents.propagate` flag is also important when
+ applying listeners to mapped classes that also have mapped subclasses,
+ as when using mapper inheritance patterns::
+
+
+ @event.listens_for(MySuperClass.attr, 'set', propagate=True)
+ def receive_set(target, value, initiator):
+ print("value set: %s" % target)
+
+ The full list of modifiers available to the :func:`.event.listen`
+ and :func:`.event.listens_for` functions are below.
+
+ :param active_history=False: When True, indicates that the
+ "set" event would like to receive the "old" value being
+ replaced unconditionally, even if this requires firing off
+ database loads. Note that ``active_history`` can also be
+ set directly via :func:`.column_property` and
+ :func:`_orm.relationship`.
+
+ :param propagate=False: When True, the listener function will
+ be established not just for the class attribute given, but
+ for attributes of the same name on all current subclasses
+ of that class, as well as all future subclasses of that
+ class, using an additional listener that listens for
+ instrumentation events.
+ :param raw=False: When True, the "target" argument to the
+ event will be the :class:`.InstanceState` management
+ object, rather than the mapped instance itself.
+ :param retval=False: when True, the user-defined event
+ listening must return the "value" argument from the
+ function. This gives the listening function the opportunity
+ to change the value that is ultimately used for a "set"
+ or "append" event.
+
+ """
+
+ _target_class_doc = "SomeClass.some_attribute"
+ _dispatch_target = QueryableAttribute
+
+ @staticmethod
+ def _set_dispatch(cls, dispatch_cls):
+ dispatch = event.Events._set_dispatch(cls, dispatch_cls)
+ dispatch_cls._active_history = False
+ return dispatch
+
+ @classmethod
+ def _accept_with(cls, target):
+ # TODO: coverage
+ if isinstance(target, interfaces.MapperProperty):
+ return getattr(target.parent.class_, target.key)
+ else:
+ return target
+
+ @classmethod
+ def _listen(
+ cls,
+ event_key,
+ active_history=False,
+ raw=False,
+ retval=False,
+ propagate=False,
+ ):
+
+ target, fn = event_key.dispatch_target, event_key._listen_fn
+
+ if active_history:
+ target.dispatch._active_history = True
+
+ if not raw or not retval:
+
+ def wrap(target, *arg):
+ if not raw:
+ target = target.obj()
+ if not retval:
+ if arg:
+ value = arg[0]
+ else:
+ value = None
+ fn(target, *arg)
+ return value
+ else:
+ return fn(target, *arg)
+
+ event_key = event_key.with_wrapper(wrap)
+
+ event_key.base_listen(propagate=propagate)
+
+ if propagate:
+ manager = instrumentation.manager_of_class(target.class_)
+
+ for mgr in manager.subclass_managers(True):
+ event_key.with_dispatch_target(mgr[target.key]).base_listen(
+ propagate=True
+ )
+ if active_history:
+ mgr[target.key].dispatch._active_history = True
+
+ def append(self, target, value, initiator):
+ """Receive a collection append event.
+
+ The append event is invoked for each element as it is appended
+ to the collection. This occurs for single-item appends as well
+ as for a "bulk replace" operation.
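+
+        A sketch of a listener using ``retval=True`` to replace appended
+        values (``SomeClass.collection`` and ``_coerce()`` are illustrative
+        placeholders)::
+
+            from sqlalchemy import event
+
+            @event.listens_for(SomeClass.collection, "append", retval=True)
+            def receive_append(target, value, initiator):
+                # return the value (possibly replaced) to be appended
+                return _coerce(value)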
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being appended. If this listener
+ is registered with ``retval=True``, the listener
+ function must return this value, or a new value which
+ replaces it.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation, as well as be inspected for information
+ about the source of the event.
+ :return: if the event was registered with ``retval=True``,
+ the given value, or a new effective value, should be returned.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ :meth:`.AttributeEvents.bulk_replace`
+
+ """
+
+ def append_wo_mutation(self, target, value, initiator):
+ """Receive a collection append event where the collection was not
+ actually mutated.
+
+ This event differs from :meth:`_orm.AttributeEvents.append` in that
+ it is fired off for de-duplicating collections such as sets and
+ dictionaries, when the object already exists in the target collection.
+ The event does not have a return value and the identity of the
+ given object cannot be changed.
+
+ The event is used for cascading objects into a :class:`_orm.Session`
+ when the collection has already been mutated via a backref event.
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value that would be appended if the object did not
+ already exist in the collection.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation, as well as be inspected for information
+ about the source of the event.
+
+ :return: No return value is defined for this event.
+
+ .. versionadded:: 1.4.15
+
+ """
+
+ def bulk_replace(self, target, values, initiator):
+ """Receive a collection 'bulk replace' event.
+
+ This event is invoked for a sequence of values as they are incoming
+ to a bulk collection set operation, which can be
+ modified in place before the values are treated as ORM objects.
+ This is an "early hook" that runs before the bulk replace routine
+ attempts to reconcile which objects are already present in the
+ collection and which are being removed by the net replace operation.
+
+ It is typical that this method be combined with use of the
+ :meth:`.AttributeEvents.append` event. When using both of these
+ events, note that a bulk replace operation will invoke
+ the :meth:`.AttributeEvents.append` event for all new items,
+ even after :meth:`.AttributeEvents.bulk_replace` has been invoked
+ for the collection as a whole. In order to determine if an
+ :meth:`.AttributeEvents.append` event is part of a bulk replace,
+ use the symbol :attr:`~.attributes.OP_BULK_REPLACE` to test the
+ incoming initiator::
+
+ from sqlalchemy.orm.attributes import OP_BULK_REPLACE
+
+ @event.listens_for(SomeObject.collection, "bulk_replace")
+ def process_collection(target, values, initiator):
+ values[:] = [_make_value(value) for value in values]
+
+ @event.listens_for(SomeObject.collection, "append", retval=True)
+ def process_collection(target, value, initiator):
+ # make sure bulk_replace didn't already do it
+ if initiator is None or initiator.op is not OP_BULK_REPLACE:
+ return _make_value(value)
+ else:
+ return value
+
+ .. versionadded:: 1.2
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+        :param values: a sequence (e.g. a list) of the values being set. The
+ handler can modify this list in place.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+
+ """
+
+ def remove(self, target, value, initiator):
+ """Receive a collection remove event.
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being removed.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation.
+
+ .. versionchanged:: 0.9.0 the ``initiator`` argument is now
+ passed as a :class:`.attributes.Event` object, and may be
+ modified by backref handlers within a chain of backref-linked
+ events.
+
+ :return: No return value is defined for this event.
+
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+ def set(self, target, value, oldvalue, initiator):
+ """Receive a scalar set event.
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value being set. If this listener
+ is registered with ``retval=True``, the listener
+ function must return this value, or a new value which
+ replaces it.
+ :param oldvalue: the previous value being replaced. This
+ may also be the symbol ``NEVER_SET`` or ``NO_VALUE``.
+ If the listener is registered with ``active_history=True``,
+ the previous value of the attribute will be loaded from
+ the database if the existing value is currently unloaded
+ or expired.
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event. May be modified
+ from its original value by backref handlers in order to control
+ chained event propagation.
+
+ .. versionchanged:: 0.9.0 the ``initiator`` argument is now
+ passed as a :class:`.attributes.Event` object, and may be
+ modified by backref handlers within a chain of backref-linked
+ events.
+
+ :return: if the event was registered with ``retval=True``,
+ the given value, or a new effective value, should be returned.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+ def init_scalar(self, target, value, dict_):
+ r"""Receive a scalar "init" event.
+
+ This event is invoked when an uninitialized, unpersisted scalar
+ attribute is accessed, e.g. read::
+
+
+ x = my_object.some_attribute
+
+ The ORM's default behavior when this occurs for an un-initialized
+ attribute is to return the value ``None``; note this differs from
+ Python's usual behavior of raising ``AttributeError``. The
+ event here can be used to customize what value is actually returned,
+ with the assumption that the event listener would be mirroring
+ a default generator that is configured on the Core
+ :class:`_schema.Column`
+ object as well.
+
+ Since a default generator on a :class:`_schema.Column`
+ might also produce
+ a changing value such as a timestamp, the
+ :meth:`.AttributeEvents.init_scalar`
+ event handler can also be used to **set** the newly returned value, so
+ that a Core-level default generation function effectively fires off
+ only once, but at the moment the attribute is accessed on the
+ non-persisted object. Normally, no change to the object's state
+ is made when an uninitialized attribute is accessed (much older
+ SQLAlchemy versions did in fact change the object's state).
+
+ If a default generator on a column returned a particular constant,
+ a handler might be used as follows::
+
+ SOME_CONSTANT = 3.1415926
+
+ class MyClass(Base):
+ # ...
+
+ some_attribute = Column(Numeric, default=SOME_CONSTANT)
+
+ @event.listens_for(
+ MyClass.some_attribute, "init_scalar",
+ retval=True, propagate=True)
+ def _init_some_attribute(target, dict_, value):
+ dict_['some_attribute'] = SOME_CONSTANT
+ return SOME_CONSTANT
+
+ Above, we initialize the attribute ``MyClass.some_attribute`` to the
+ value of ``SOME_CONSTANT``. The above code includes the following
+ features:
+
+ * By setting the value ``SOME_CONSTANT`` in the given ``dict_``,
+ we indicate that this value is to be persisted to the database.
+ This supersedes the use of ``SOME_CONSTANT`` in the default generator
+ for the :class:`_schema.Column`. The ``active_column_defaults.py``
+ example given at :ref:`examples_instrumentation` illustrates using
+ the same approach for a changing default, e.g. a timestamp
+ generator. In this particular example, it is not strictly
+ necessary to do this since ``SOME_CONSTANT`` would be part of the
+ INSERT statement in either case.
+
+ * By establishing the ``retval=True`` flag, the value we return
+ from the function will be returned by the attribute getter.
+ Without this flag, the event is assumed to be a passive observer
+ and the return value of our function is ignored.
+
+ * The ``propagate=True`` flag is significant if the mapped class
+ includes inheriting subclasses, which would also make use of this
+ event listener. Without this flag, an inheriting subclass will
+ not use our event handler.
+
+ In the above example, the attribute set event
+ :meth:`.AttributeEvents.set`, as well as the related validation feature
+ provided by :obj:`_orm.validates`, is **not** invoked when we apply our
+ value to the given ``dict_``. To have these events invoke in
+ response to our newly generated value, apply the value to the given
+ object as a normal attribute set operation::
+
+ SOME_CONSTANT = 3.1415926
+
+ @event.listens_for(
+ MyClass.some_attribute, "init_scalar",
+ retval=True, propagate=True)
+ def _init_some_attribute(target, dict_, value):
+ # will also fire off attribute set events
+ target.some_attribute = SOME_CONSTANT
+ return SOME_CONSTANT
+
+ When multiple listeners are set up, the generation of the value
+ is "chained" from one listener to the next by passing the value
+ returned by the previous listener that specifies ``retval=True``
+ as the ``value`` argument of the next listener.
+
+ .. versionadded:: 1.1
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param value: the value that is to be returned before this event
+ listener is invoked. This value begins as the value ``None``;
+ however, it will be the return value of the previous event handler
+ function if multiple listeners are present.
+ :param dict\_: the attribute dictionary of this mapped object.
+ This is normally the ``__dict__`` of the object, but in all cases
+ represents the destination that the attribute system uses to get
+ at the actual value of this attribute. Placing the value in this
+ dictionary has the effect that the value will be used in the
+ INSERT statement generated by the unit of work.
+
+
+ .. seealso::
+
+ :meth:`.AttributeEvents.init_collection` - collection version
+ of this event
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ :ref:`examples_instrumentation` - see the
+ ``active_column_defaults.py`` example.
+
+ """
+
+ def init_collection(self, target, collection, collection_adapter):
+ """Receive a 'collection init' event.
+
+ This event is triggered for a collection-based attribute, when
+ the initial "empty collection" is first generated for a blank
+ attribute, as well as when the collection is replaced with
+ a new one, such as via a set event.
+
+ E.g., given that ``User.addresses`` is a relationship-based
+ collection, the event is triggered here::
+
+ u1 = User()
+ u1.addresses.append(a1) # <- new collection
+
+ and also during replace operations::
+
+ u1.addresses = [a2, a3] # <- new collection
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+ :param collection: the new collection. This will always be generated
+ from what was specified as
+ :paramref:`_orm.relationship.collection_class`, and will always
+ be empty.
+ :param collection_adapter: the :class:`.CollectionAdapter` that will
+ mediate internal access to the collection.
+
+ .. versionadded:: 1.0.0 :meth:`.AttributeEvents.init_collection`
+ and :meth:`.AttributeEvents.dispose_collection` events.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ :meth:`.AttributeEvents.init_scalar` - "scalar" version of this
+ event.
+
+ """
+
+ def dispose_collection(self, target, collection, collection_adapter):
+ """Receive a 'collection dispose' event.
+
+ This event is triggered for a collection-based attribute when
+ a collection is replaced, that is::
+
+ u1.addresses.append(a1)
+
+ u1.addresses = [a2, a3] # <- old collection is disposed
+
+ The old collection received will contain its previous contents.
+
+ .. versionchanged:: 1.2 The collection passed to
+ :meth:`.AttributeEvents.dispose_collection` will now have its
+ contents before the dispose intact; previously, the collection
+ would be empty.
+
+ .. versionadded:: 1.0.0 the :meth:`.AttributeEvents.init_collection`
+ and :meth:`.AttributeEvents.dispose_collection` events.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+ def modified(self, target, initiator):
+ """Receive a 'modified' event.
+
+ This event is triggered when the :func:`.attributes.flag_modified`
+ function is used to trigger a modify event on an attribute without
+ any specific value being set.
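+
+ For example (a minimal sketch; ``SomeObject.data`` is an assumed
+ mapped attribute and ``my_object`` an instance of ``SomeObject``
+ whose ``data`` attribute currently has a value set)::
+
+     from sqlalchemy.orm.attributes import flag_modified
+
+     @event.listens_for(SomeObject.data, "modified")
+     def receive_modified(target, initiator):
+         print("attribute flagged as modified on %r" % target)
+
+     flag_modified(my_object, "data")  # invokes the listener above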
+
+ .. versionadded:: 1.2
+
+ :param target: the object instance receiving the event.
+ If the listener is registered with ``raw=True``, this will
+ be the :class:`.InstanceState` object.
+
+ :param initiator: An instance of :class:`.attributes.Event`
+ representing the initiation of the event.
+
+ .. seealso::
+
+ :class:`.AttributeEvents` - background on listener options such
+ as propagation to subclasses.
+
+ """
+
+
+class QueryEvents(event.Events):
+ """Represent events within the construction of a :class:`_query.Query`
+ object.
+
+ The :class:`_orm.QueryEvents` hooks are now superseded by the
+ :meth:`_orm.SessionEvents.do_orm_execute` event hook.
+
+ """
+
+ _target_class_doc = "SomeQuery"
+ _dispatch_target = Query
+
+ def before_compile(self, query):
+ """Receive the :class:`_query.Query`
+ object before it is composed into a
+ core :class:`_expression.Select` object.
+
+ .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile` event
+ is superseded by the much more capable
+ :meth:`_orm.SessionEvents.do_orm_execute` hook. In version 1.4,
+ the :meth:`_orm.QueryEvents.before_compile` event is **no longer
+ used** for ORM-level attribute loads, such as loads of deferred
+ or expired attributes as well as relationship loaders. See the
+ new examples in :ref:`examples_session_orm_events` which
+ illustrate new ways of intercepting and modifying ORM queries
+ for the most common purpose of adding arbitrary filter criteria.
+
+
+ This event is intended to allow changes to the query given::
+
+ @event.listens_for(Query, "before_compile", retval=True)
+ def no_deleted(query):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+ return query
+
+ The event should normally be listened to with the ``retval=True``
+ parameter set, so that the modified query may be returned.
+
+ The :meth:`.QueryEvents.before_compile` event by default
+ will disallow "baked" queries from caching a query, if the event
+ hook returns a new :class:`_query.Query` object.
+ This affects both direct
+ use of the baked query extension as well as its operation within
+ lazy loaders and eager loaders for relationships. In order to
+ re-establish the query being cached, apply the event adding the
+ ``bake_ok`` flag::
+
+ @event.listens_for(
+ Query, "before_compile", retval=True, bake_ok=True)
+ def my_event(query):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+ return query
+
+ When ``bake_ok`` is set to True, the event hook will only be invoked
+ once, and not called for subsequent invocations of a particular query
+ that is being cached.
+
+ .. versionadded:: 1.3.11 - added the "bake_ok" flag to the
+ :meth:`.QueryEvents.before_compile` event and disallowed caching via
+ the "baked" extension from occurring for event handlers that
+ return a new :class:`_query.Query` object if this flag is not set.
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile_update`
+
+ :meth:`.QueryEvents.before_compile_delete`
+
+ :ref:`baked_with_before_compile`
+
+ """
+
+ def before_compile_update(self, query, update_context):
+ """Allow modifications to the :class:`_query.Query` object within
+ :meth:`_query.Query.update`.
+
+ .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_update`
+ event is superseded by the much more capable
+ :meth:`_orm.SessionEvents.do_orm_execute` hook.
+
+ Like the :meth:`.QueryEvents.before_compile` event, if the event
+ is to be used to alter the :class:`_query.Query` object, it should
+ be configured with ``retval=True``, and the modified
+ :class:`_query.Query` object returned, as in ::
+
+ @event.listens_for(Query, "before_compile_update", retval=True)
+ def no_deleted(query, update_context):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+
+ update_context.values['timestamp'] = datetime.utcnow()
+ return query
+
+ The ``.values`` dictionary of the "update context" object can also
+ be modified in place as illustrated above.
+
+ :param query: a :class:`_query.Query` instance; this is also
+ the ``.query`` attribute of the given "update context"
+ object.
+
+ :param update_context: an "update context" object which is
+ the same kind of object as described in
+ :paramref:`.QueryEvents.after_bulk_update.update_context`.
+ The object has a ``.values`` attribute in an UPDATE context which is
+ the dictionary of parameters passed to :meth:`_query.Query.update`.
+ This
+ dictionary can be modified to alter the VALUES clause of the
+ resulting UPDATE statement.
+
+ .. versionadded:: 1.2.17
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile`
+
+ :meth:`.QueryEvents.before_compile_delete`
+
+
+ """
+
+ def before_compile_delete(self, query, delete_context):
+ """Allow modifications to the :class:`_query.Query` object within
+ :meth:`_query.Query.delete`.
+
+ .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_delete`
+ event is superseded by the much more capable
+ :meth:`_orm.SessionEvents.do_orm_execute` hook.
+
+ Like the :meth:`.QueryEvents.before_compile` event, this event
+ should be configured with ``retval=True``, and the modified
+ :class:`_query.Query` object returned, as in ::
+
+ @event.listens_for(Query, "before_compile_delete", retval=True)
+ def no_deleted(query, delete_context):
+ for desc in query.column_descriptions:
+ if desc['type'] is User:
+ entity = desc['entity']
+ query = query.filter(entity.deleted == False)
+ return query
+
+ :param query: a :class:`_query.Query` instance; this is also
+ the ``.query`` attribute of the given "delete context"
+ object.
+
+ :param delete_context: a "delete context" object which is
+ the same kind of object as described in
+ :paramref:`.QueryEvents.after_bulk_delete.delete_context`.
+
+ .. versionadded:: 1.2.17
+
+ .. seealso::
+
+ :meth:`.QueryEvents.before_compile`
+
+ :meth:`.QueryEvents.before_compile_update`
+
+
+ """
+
+ @classmethod
+ def _listen(cls, event_key, retval=False, bake_ok=False, **kw):
+ fn = event_key._listen_fn
+
+ if not retval:
+
+ def wrap(*arg, **kw):
+ if not retval:
+ query = arg[0]
+ fn(*arg, **kw)
+ return query
+ else:
+ return fn(*arg, **kw)
+
+ event_key = event_key.with_wrapper(wrap)
+ else:
+ # don't assume we can apply an attribute to the callable
+ def wrap(*arg, **kw):
+ return fn(*arg, **kw)
+
+ event_key = event_key.with_wrapper(wrap)
+
+ wrap._bake_ok = bake_ok
+
+ event_key.base_listen(**kw)
diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py
new file mode 100644
index 0000000..8dd4d90
--- /dev/null
+++ b/lib/sqlalchemy/orm/exc.py
@@ -0,0 +1,204 @@
+# orm/exc.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQLAlchemy ORM exceptions."""
+from .. import exc as sa_exc
+from .. import util
+from ..exc import MultipleResultsFound # noqa
+from ..exc import NoResultFound # noqa
+
+
+NO_STATE = (AttributeError, KeyError)
+"""Exception types that may be raised by instrumentation implementations."""
+
+
+class StaleDataError(sa_exc.SQLAlchemyError):
+ """An operation encountered database state that is unaccounted for.
+
+ Conditions which cause this to happen include:
+
+ * A flush may have attempted to update or delete rows
+ and an unexpected number of rows were matched during
+ the UPDATE or DELETE statement. Note that when
+ version_id_col is used, rows in UPDATE or DELETE statements
+ are also matched against the current known version
+ identifier.
+
+ * A mapped object with version_id_col was refreshed,
+ and the version number coming back from the database does
+ not match that of the object itself.
+
+ * An object is detached from its parent object; however,
+ the object was previously attached to a different parent
+ identity which was garbage collected, and it cannot be
+ determined whether the new parent was really the most
+ recent "parent".
+
+ """
+
+
+ConcurrentModificationError = StaleDataError
+
+
+class FlushError(sa_exc.SQLAlchemyError):
+ """A invalid condition was detected during flush()."""
+
+
+class UnmappedError(sa_exc.InvalidRequestError):
+ """Base for exceptions that involve expected mappings not present."""
+
+
+class ObjectDereferencedError(sa_exc.SQLAlchemyError):
+ """An operation cannot complete due to an object being garbage
+ collected.
+
+ """
+
+
+class DetachedInstanceError(sa_exc.SQLAlchemyError):
+ """An attempt to access unloaded attributes on a
+ mapped instance that is detached."""
+
+ code = "bhk3"
+
+
+class UnmappedInstanceError(UnmappedError):
+ """An mapping operation was requested for an unknown instance."""
+
+ @util.preload_module("sqlalchemy.orm.base")
+ def __init__(self, obj, msg=None):
+ base = util.preloaded.orm_base
+
+ if not msg:
+ try:
+ base.class_mapper(type(obj))
+ name = _safe_cls_name(type(obj))
+ msg = (
+ "Class %r is mapped, but this instance lacks "
+ "instrumentation. This occurs when the instance "
+ "is created before sqlalchemy.orm.mapper(%s) "
+ "was called." % (name, name)
+ )
+ except UnmappedClassError:
+ msg = _default_unmapped(type(obj))
+ if isinstance(obj, type):
+ msg += (
+ "; was a class (%s) supplied where an instance was "
+ "required?" % _safe_cls_name(obj)
+ )
+ UnmappedError.__init__(self, msg)
+
+ def __reduce__(self):
+ return self.__class__, (None, self.args[0])
+
+
+class UnmappedClassError(UnmappedError):
+ """An mapping operation was requested for an unknown class."""
+
+ def __init__(self, cls, msg=None):
+ if not msg:
+ msg = _default_unmapped(cls)
+ UnmappedError.__init__(self, msg)
+
+ def __reduce__(self):
+ return self.__class__, (None, self.args[0])
+
+
+class ObjectDeletedError(sa_exc.InvalidRequestError):
+ """A refresh operation failed to retrieve the database
+ row corresponding to an object's known primary key identity.
+
+ A refresh operation proceeds when an expired attribute is
+ accessed on an object, or when :meth:`_query.Query.get` is
+ used to retrieve an object which is, upon retrieval, detected
+ as expired. A SELECT is emitted for the target row
+ based on primary key; if no row is returned, this
+ exception is raised.
+
+ The true meaning of this exception is simply that
+ no row exists for the primary key identifier associated
+ with a persistent object. The row may have been
+ deleted, or in some cases the primary key updated
+ to a new value, outside of the ORM's management of the target
+ object.
+
+ """
+
+ @util.preload_module("sqlalchemy.orm.base")
+ def __init__(self, state, msg=None):
+ base = util.preloaded.orm_base
+
+ if not msg:
+ msg = (
+ "Instance '%s' has been deleted, or its "
+ "row is otherwise not present." % base.state_str(state)
+ )
+
+ sa_exc.InvalidRequestError.__init__(self, msg)
+
+ def __reduce__(self):
+ return self.__class__, (None, self.args[0])
+
+
+class UnmappedColumnError(sa_exc.InvalidRequestError):
+ """Mapping operation was requested on an unknown column."""
+
+
+class LoaderStrategyException(sa_exc.InvalidRequestError):
+ """A loader strategy for an attribute does not exist."""
+
+ def __init__(
+ self,
+ applied_to_property_type,
+ requesting_property,
+ applies_to,
+ actual_strategy_type,
+ strategy_key,
+ ):
+ if actual_strategy_type is None:
+ sa_exc.InvalidRequestError.__init__(
+ self,
+ "Can't find strategy %s for %s"
+ % (strategy_key, requesting_property),
+ )
+ else:
+ sa_exc.InvalidRequestError.__init__(
+ self,
+ 'Can\'t apply "%s" strategy to property "%s", '
+ 'which is a "%s"; this loader strategy is intended '
+ 'to be used with a "%s".'
+ % (
+ util.clsname_as_plain_name(actual_strategy_type),
+ requesting_property,
+ util.clsname_as_plain_name(applied_to_property_type),
+ util.clsname_as_plain_name(applies_to),
+ ),
+ )
+
+
+def _safe_cls_name(cls):
+ try:
+ cls_name = ".".join((cls.__module__, cls.__name__))
+ except AttributeError:
+ cls_name = getattr(cls, "__name__", None)
+ if cls_name is None:
+ cls_name = repr(cls)
+ return cls_name
+
+
+@util.preload_module("sqlalchemy.orm.base")
+def _default_unmapped(cls):
+ base = util.preloaded.orm_base
+
+ try:
+ mappers = base.manager_of_class(cls).mappers
+ except (TypeError,) + NO_STATE:
+ mappers = {}
+ name = _safe_cls_name(cls)
+
+ if not mappers:
+ return "Class '%s' is not mapped" % name
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
new file mode 100644
index 0000000..7de8e2c
--- /dev/null
+++ b/lib/sqlalchemy/orm/identity.py
@@ -0,0 +1,254 @@
+# orm/identity.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import weakref
+
+from . import util as orm_util
+from .. import exc as sa_exc
+from .. import util
+
+
+class IdentityMap(object):
+ def __init__(self):
+ self._dict = {}
+ self._modified = set()
+ self._wr = weakref.ref(self)
+
+ def _kill(self):
+ self._add_unpresent = _killed
+
+ def keys(self):
+ return self._dict.keys()
+
+ def replace(self, state):
+ raise NotImplementedError()
+
+ def add(self, state):
+ raise NotImplementedError()
+
+ def _add_unpresent(self, state, key):
+ """optional inlined form of add() which can assume item isn't present
+ in the map"""
+ self.add(state)
+
+ def update(self, dict_):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def clear(self):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def _manage_incoming_state(self, state):
+ state._instance_dict = self._wr
+
+ if state.modified:
+ self._modified.add(state)
+
+ def _manage_removed_state(self, state):
+ del state._instance_dict
+ if state.modified:
+ self._modified.discard(state)
+
+ def _dirty_states(self):
+ return self._modified
+
+ def check_modified(self):
+ """return True if any InstanceStates present have been marked
+ as 'modified'.
+
+ """
+ return bool(self._modified)
+
+ def has_key(self, key):
+ return key in self
+
+ def popitem(self):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def pop(self, key, *args):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+ def setdefault(self, key, default=None):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def __len__(self):
+ return len(self._dict)
+
+ def copy(self):
+ raise NotImplementedError()
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError("IdentityMap uses add() to insert data")
+
+ def __delitem__(self, key):
+ raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+
+class WeakInstanceDict(IdentityMap):
+ def __getitem__(self, key):
+ state = self._dict[key]
+ o = state.obj()
+ if o is None:
+ raise KeyError(key)
+ return o
+
+ def __contains__(self, key):
+ try:
+ if key in self._dict:
+ state = self._dict[key]
+ o = state.obj()
+ else:
+ return False
+ except KeyError:
+ return False
+ else:
+ return o is not None
+
+ def contains_state(self, state):
+ if state.key in self._dict:
+ try:
+ return self._dict[state.key] is state
+ except KeyError:
+ return False
+ else:
+ return False
+
+ def replace(self, state):
+ if state.key in self._dict:
+ try:
+ existing = self._dict[state.key]
+ except KeyError:
+ # gc may have removed the key after we just checked for it
+ pass
+ else:
+ if existing is not state:
+ self._manage_removed_state(existing)
+ else:
+ return None
+ else:
+ existing = None
+
+ self._dict[state.key] = state
+ self._manage_incoming_state(state)
+ return existing
+
+ def add(self, state):
+ key = state.key
+ # inline of self.__contains__
+ if key in self._dict:
+ try:
+ existing_state = self._dict[key]
+ except KeyError:
+ # gc may have removed the key after we just checked for it
+ pass
+ else:
+ if existing_state is not state:
+ o = existing_state.obj()
+ if o is not None:
+ raise sa_exc.InvalidRequestError(
+ "Can't attach instance "
+ "%s; another instance with key %s is already "
+ "present in this session."
+ % (orm_util.state_str(state), state.key)
+ )
+ else:
+ return False
+ self._dict[key] = state
+ self._manage_incoming_state(state)
+ return True
+
+ def _add_unpresent(self, state, key):
+ # inlined form of add() called by loading.py
+ self._dict[key] = state
+ state._instance_dict = self._wr
+
+ def get(self, key, default=None):
+ if key not in self._dict:
+ return default
+ try:
+ state = self._dict[key]
+ except KeyError:
+ # gc may have removed the key after we just checked for it
+ return default
+ else:
+ o = state.obj()
+ if o is None:
+ return default
+ return o
+
+ def items(self):
+ values = self.all_states()
+ result = []
+ for state in values:
+ value = state.obj()
+ if value is not None:
+ result.append((state.key, value))
+ return result
+
+ def values(self):
+ values = self.all_states()
+ result = []
+ for state in values:
+ value = state.obj()
+ if value is not None:
+ result.append(value)
+
+ return result
+
+ def __iter__(self):
+ return iter(self.keys())
+
+ if util.py2k:
+
+ def iteritems(self):
+ return iter(self.items())
+
+ def itervalues(self):
+ return iter(self.values())
+
+ def all_states(self):
+ if util.py2k:
+ return self._dict.values()
+ else:
+ return list(self._dict.values())
+
+ def _fast_discard(self, state):
+ # used by InstanceState for state being
+ # GC'ed, inlines _manage_removed_state
+ try:
+ st = self._dict[state.key]
+ except KeyError:
+ # gc may have removed the key after we just checked for it
+ pass
+ else:
+ if st is state:
+ self._dict.pop(state.key, None)
+
+ def discard(self, state):
+ self.safe_discard(state)
+
+ def safe_discard(self, state):
+ if state.key in self._dict:
+ try:
+ st = self._dict[state.key]
+ except KeyError:
+ # gc may have removed the key after we just checked for it
+ pass
+ else:
+ if st is state:
+ self._dict.pop(state.key, None)
+ self._manage_removed_state(state)
+
+
+def _killed(state, key):
+ # external function to avoid creating cycles when assigned to
+ # the IdentityMap
+ raise sa_exc.InvalidRequestError(
+ "Object %s cannot be converted to 'persistent' state, as this "
+ "identity map is no longer valid. Has the owning Session "
+ "been closed?" % orm_util.state_str(state),
+ code="lkrp",
+ )
diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py
new file mode 100644
index 0000000..a7023a2
--- /dev/null
+++ b/lib/sqlalchemy/orm/instrumentation.py
@@ -0,0 +1,652 @@
+# orm/instrumentation.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines SQLAlchemy's system of class instrumentation.
+
+This module is usually not directly visible to user applications, but
+defines a large part of the ORM's interactivity.
+
+instrumentation.py deals with registration of end-user classes
+for state tracking. It interacts closely with state.py
+and attributes.py which establish per-instance and per-class-attribute
+instrumentation, respectively.
+
+The class instrumentation system can be customized on a per-class
+or global basis using the :mod:`sqlalchemy.ext.instrumentation`
+module, which provides the means to build and specify
+alternate instrumentation forms.
+
+.. versionchanged: 0.8
+ The instrumentation extension system was moved out of the
+ ORM and into the external :mod:`sqlalchemy.ext.instrumentation`
+ package. When that package is imported, it installs
+ itself within sqlalchemy.orm so that its more comprehensive
+ resolution mechanics take effect.
+
+"""
+
+
+import weakref
+
+from . import base
+from . import collections
+from . import exc
+from . import interfaces
+from . import state
+from .. import util
+from ..util import HasMemoized
+
+
+DEL_ATTR = util.symbol("DEL_ATTR")
+
+
+class ClassManager(HasMemoized, dict):
+ """Tracks state information at the class level."""
+
+ MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR
+ STATE_ATTR = base.DEFAULT_STATE_ATTR
+
+ _state_setter = staticmethod(util.attrsetter(STATE_ATTR))
+
+ expired_attribute_loader = None
+ "previously known as deferred_scalar_loader"
+
+ init_method = None
+
+ factory = None
+ mapper = None
+ declarative_scan = None
+ registry = None
+
+ @property
+ @util.deprecated(
+ "1.4",
+ message="The ClassManager.deferred_scalar_loader attribute is now "
+ "named expired_attribute_loader",
+ )
+ def deferred_scalar_loader(self):
+ return self.expired_attribute_loader
+
+ @deferred_scalar_loader.setter
+ @util.deprecated(
+ "1.4",
+ message="The ClassManager.deferred_scalar_loader attribute is now "
+ "named expired_attribute_loader",
+ )
+ def deferred_scalar_loader(self, obj):
+ self.expired_attribute_loader = obj
+
+ def __init__(self, class_):
+ self.class_ = class_
+ self.info = {}
+ self.new_init = None
+ self.local_attrs = {}
+ self.originals = {}
+ self._finalized = False
+
+ self._bases = [
+ mgr
+ for mgr in [
+ manager_of_class(base)
+ for base in self.class_.__bases__
+ if isinstance(base, type)
+ ]
+ if mgr is not None
+ ]
+
+ for base_ in self._bases:
+ self.update(base_)
+
+ self.dispatch._events._new_classmanager_instance(class_, self)
+
+ for basecls in class_.__mro__:
+ mgr = manager_of_class(basecls)
+ if mgr is not None:
+ self.dispatch._update(mgr.dispatch)
+
+ self.manage()
+
+ if "__del__" in class_.__dict__:
+ util.warn(
+ "__del__() method on class %s will "
+ "cause unreachable cycles and memory leaks, "
+ "as SQLAlchemy instrumentation often creates "
+ "reference cycles. Please remove this method." % class_
+ )
+
+ def _update_state(
+ self,
+ finalize=False,
+ mapper=None,
+ registry=None,
+ declarative_scan=None,
+ expired_attribute_loader=None,
+ init_method=None,
+ ):
+
+ if mapper:
+ self.mapper = mapper
+ if registry:
+ registry._add_manager(self)
+ if declarative_scan:
+ self.declarative_scan = weakref.ref(declarative_scan)
+ if expired_attribute_loader:
+ self.expired_attribute_loader = expired_attribute_loader
+
+ if init_method:
+ assert not self._finalized, (
+ "class is already instrumented, "
+ "init_method %s can't be applied" % init_method
+ )
+ self.init_method = init_method
+
+ if not self._finalized:
+ self.original_init = (
+ self.init_method
+ if self.init_method is not None
+ and self.class_.__init__ is object.__init__
+ else self.class_.__init__
+ )
+
+ if finalize and not self._finalized:
+ self._finalize()
+
+ def _finalize(self):
+ if self._finalized:
+ return
+ self._finalized = True
+
+ self._instrument_init()
+
+ _instrumentation_factory.dispatch.class_instrument(self.class_)
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return other is self
+
+ @property
+ def is_mapped(self):
+ return "mapper" in self.__dict__
+
+ @HasMemoized.memoized_attribute
+ def _all_key_set(self):
+ return frozenset(self)
+
+ @HasMemoized.memoized_attribute
+ def _collection_impl_keys(self):
+ return frozenset(
+ [attr.key for attr in self.values() if attr.impl.collection]
+ )
+
+ @HasMemoized.memoized_attribute
+ def _scalar_loader_impls(self):
+ return frozenset(
+ [
+ attr.impl
+ for attr in self.values()
+ if attr.impl.accepts_scalar_loader
+ ]
+ )
+
+ @HasMemoized.memoized_attribute
+ def _loader_impls(self):
+ return frozenset([attr.impl for attr in self.values()])
+
+ @util.memoized_property
+ def mapper(self):
+ # raises unless self.mapper has been assigned
+ raise exc.UnmappedClassError(self.class_)
+
+ def _all_sqla_attributes(self, exclude=None):
+ """return an iterator of all classbound attributes that are
+ implement :class:`.InspectionAttr`.
+
+ This includes :class:`.QueryableAttribute` as well as extension
+ types such as :class:`.hybrid_property` and
+ :class:`.AssociationProxy`.
+
+ """
+
+ found = {}
+
+ # constraints:
+ # 1. yield keys in cls.__dict__ order
+ # 2. if a subclass has the same key as a superclass, include that
+ # key as part of the ordering of the superclass, because an
+ # overridden key is usually installed by the mapper which is going
+ # on a different ordering
+ # 3. don't use getattr() as this fires off descriptors
+
+ for supercls in self.class_.__mro__[0:-1]:
+ inherits = supercls.__mro__[1]
+ for key in supercls.__dict__:
+ found.setdefault(key, supercls)
+ if key in inherits.__dict__:
+ continue
+ val = found[key].__dict__[key]
+ if (
+ isinstance(val, interfaces.InspectionAttr)
+ and val.is_attribute
+ ):
+ yield key, val
+
+ def _get_class_attr_mro(self, key, default=None):
+ """return an attribute on the class without tripping it."""
+
+ for supercls in self.class_.__mro__:
+ if key in supercls.__dict__:
+ return supercls.__dict__[key]
+ else:
+ return default
+
+ def _attr_has_impl(self, key):
+ """Return True if the given attribute is fully initialized.
+
+ i.e. has an impl.
+ """
+
+ return key in self and self[key].impl is not None
+
+ def _subclass_manager(self, cls):
+ """Create a new ClassManager for a subclass of this ClassManager's
+ class.
+
+ This is called automatically when attributes are instrumented so that
+ the attributes can be propagated to subclasses against their own
+ class-local manager, without the need for mappers etc. to have already
+ pre-configured managers for the full class hierarchy. Mappers
+ can post-configure the auto-generated ClassManager when needed.
+
+ """
+ return register_class(cls, finalize=False)
+
+ def _instrument_init(self):
+ self.new_init = _generate_init(self.class_, self, self.original_init)
+ self.install_member("__init__", self.new_init)
+
+ @util.memoized_property
+ def _state_constructor(self):
+ self.dispatch.first_init(self, self.class_)
+ return state.InstanceState
+
+ def manage(self):
+ """Mark this instance as the manager for its class."""
+
+ setattr(self.class_, self.MANAGER_ATTR, self)
+
+ @util.hybridmethod
+ def manager_getter(self):
+ return _default_manager_getter
+
+ @util.hybridmethod
+ def state_getter(self):
+ """Return a (instance) -> InstanceState callable.
+
+ "state getter" callables should raise either KeyError or
+ AttributeError if no InstanceState could be found for the
+ instance.
+ """
+
+ return _default_state_getter
+
+ @util.hybridmethod
+ def dict_getter(self):
+ return _default_dict_getter
+
+ def instrument_attribute(self, key, inst, propagated=False):
+ if propagated:
+ if key in self.local_attrs:
+ return # don't override local attr with inherited attr
+ else:
+ self.local_attrs[key] = inst
+ self.install_descriptor(key, inst)
+ self._reset_memoizations()
+ self[key] = inst
+
+ for cls in self.class_.__subclasses__():
+ manager = self._subclass_manager(cls)
+ manager.instrument_attribute(key, inst, True)
+
+ def subclass_managers(self, recursive):
+ for cls in self.class_.__subclasses__():
+ mgr = manager_of_class(cls)
+ if mgr is not None and mgr is not self:
+ yield mgr
+ if recursive:
+ for m in mgr.subclass_managers(True):
+ yield m
+
+ def post_configure_attribute(self, key):
+ _instrumentation_factory.dispatch.attribute_instrument(
+ self.class_, key, self[key]
+ )
+
+ def uninstrument_attribute(self, key, propagated=False):
+ if key not in self:
+ return
+ if propagated:
+ if key in self.local_attrs:
+ return # don't get rid of local attr
+ else:
+ del self.local_attrs[key]
+ self.uninstall_descriptor(key)
+ self._reset_memoizations()
+ del self[key]
+ for cls in self.class_.__subclasses__():
+ manager = manager_of_class(cls)
+ if manager:
+ manager.uninstrument_attribute(key, True)
+
+ def unregister(self):
+ """remove all instrumentation established by this ClassManager."""
+
+ for key in list(self.originals):
+ self.uninstall_member(key)
+
+ self.mapper = self.dispatch = self.new_init = None
+ self.info.clear()
+
+ for key in list(self):
+ if key in self.local_attrs:
+ self.uninstrument_attribute(key)
+
+ if self.MANAGER_ATTR in self.class_.__dict__:
+ delattr(self.class_, self.MANAGER_ATTR)
+
+ def install_descriptor(self, key, inst):
+ if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+ raise KeyError(
+ "%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key
+ )
+ setattr(self.class_, key, inst)
+
+ def uninstall_descriptor(self, key):
+ delattr(self.class_, key)
+
+ def install_member(self, key, implementation):
+ if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+ raise KeyError(
+ "%r: requested attribute name conflicts with "
+ "instrumentation attribute of the same name." % key
+ )
+ self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR))
+ setattr(self.class_, key, implementation)
+
+ def uninstall_member(self, key):
+ original = self.originals.pop(key, None)
+ if original is not DEL_ATTR:
+ setattr(self.class_, key, original)
+ else:
+ delattr(self.class_, key)
+
+ def instrument_collection_class(self, key, collection_class):
+ return collections.prepare_instrumentation(collection_class)
+
+ def initialize_collection(self, key, state, factory):
+ user_data = factory()
+ adapter = collections.CollectionAdapter(
+ self.get_impl(key), state, user_data
+ )
+ return adapter, user_data
+
+ def is_instrumented(self, key, search=False):
+ if search:
+ return key in self
+ else:
+ return key in self.local_attrs
+
+ def get_impl(self, key):
+ return self[key].impl
+
+ @property
+ def attributes(self):
+ return iter(self.values())
+
+ # InstanceState management
+
+ def new_instance(self, state=None):
+ instance = self.class_.__new__(self.class_)
+ if state is None:
+ state = self._state_constructor(instance, self)
+ self._state_setter(instance, state)
+ return instance
+
+ def setup_instance(self, instance, state=None):
+ if state is None:
+ state = self._state_constructor(instance, self)
+ self._state_setter(instance, state)
+
+ def teardown_instance(self, instance):
+ delattr(instance, self.STATE_ATTR)
+
+ def _serialize(self, state, state_dict):
+ return _SerializeManager(state, state_dict)
+
+ def _new_state_if_none(self, instance):
+ """Install a default InstanceState if none is present.
+
+ A private convenience method used by the __init__ decorator.
+
+ """
+ if hasattr(instance, self.STATE_ATTR):
+ return False
+ elif self.class_ is not instance.__class__ and self.is_mapped:
+ # this will create a new ClassManager for the
+ # subclass, without a mapper. This is likely a
+ # user error situation but allow the object
+ # to be constructed, so that it is usable
+ # in a non-ORM context at least.
+ return self._subclass_manager(
+ instance.__class__
+ )._new_state_if_none(instance)
+ else:
+ state = self._state_constructor(instance, self)
+ self._state_setter(instance, state)
+ return state
+
+ def has_state(self, instance):
+ return hasattr(instance, self.STATE_ATTR)
+
+ def has_parent(self, state, key, optimistic=False):
+ """TODO"""
+ return self.get_impl(key).hasparent(state, optimistic=optimistic)
+
+ def __bool__(self):
+ """All ClassManagers are non-zero regardless of attribute state."""
+ return True
+
+ __nonzero__ = __bool__
+
+ def __repr__(self):
+ return "<%s of %r at %x>" % (
+ self.__class__.__name__,
+ self.class_,
+ id(self),
+ )
+
+
+class _SerializeManager(object):
+ """Provide serialization of a :class:`.ClassManager`.
+
+ The :class:`.InstanceState` uses ``__init__()`` on serialize
+ and ``__call__()`` on deserialize.
+
+ """
+
+ def __init__(self, state, d):
+ self.class_ = state.class_
+ manager = state.manager
+ manager.dispatch.pickle(state, d)
+
+ def __call__(self, state, inst, state_dict):
+ state.manager = manager = manager_of_class(self.class_)
+ if manager is None:
+ raise exc.UnmappedInstanceError(
+ inst,
+ "Cannot deserialize object of type %r - "
+ "no mapper() has "
+ "been configured for this class within the current "
+ "Python process!" % self.class_,
+ )
+ elif manager.is_mapped and not manager.mapper.configured:
+ manager.mapper._check_configure()
+
+ # setup _sa_instance_state ahead of time so that
+ # unpickle events can access the object normally.
+ # see [ticket:2362]
+ if inst is not None:
+ manager.setup_instance(inst, state)
+ manager.dispatch.unpickle(state, state_dict)
+
+
+class InstrumentationFactory(object):
+ """Factory for new ClassManager instances."""
+
+ def create_manager_for_cls(self, class_):
+ assert class_ is not None
+ assert manager_of_class(class_) is None
+
+ # give a more complicated subclass
+ # a chance to do what it wants here
+ manager, factory = self._locate_extended_factory(class_)
+
+ if factory is None:
+ factory = ClassManager
+ manager = factory(class_)
+
+ self._check_conflicts(class_, factory)
+
+ manager.factory = factory
+
+ return manager
+
+ def _locate_extended_factory(self, class_):
+ """Overridden by a subclass to do an extended lookup."""
+ return None, None
+
+ def _check_conflicts(self, class_, factory):
+ """Overridden by a subclass to test for conflicting factories."""
+ return
+
+ def unregister(self, class_):
+ manager = manager_of_class(class_)
+ manager.unregister()
+ self.dispatch.class_uninstrument(class_)
+
+
+# this attribute is replaced by sqlalchemy.ext.instrumentation
+# when imported.
+_instrumentation_factory = InstrumentationFactory()
+
+# these attributes are replaced by sqlalchemy.ext.instrumentation
+# when a non-standard InstrumentationManager class is first
+# used to instrument a class.
+instance_state = _default_state_getter = base.instance_state
+
+instance_dict = _default_dict_getter = base.instance_dict
+
+manager_of_class = _default_manager_getter = base.manager_of_class
+
+
+def register_class(
+ class_,
+ finalize=True,
+ mapper=None,
+ registry=None,
+ declarative_scan=None,
+ expired_attribute_loader=None,
+ init_method=None,
+):
+ """Register class instrumentation.
+
+ Returns the existing or newly created class manager.
+
+ """
+
+ manager = manager_of_class(class_)
+ if manager is None:
+ manager = _instrumentation_factory.create_manager_for_cls(class_)
+ manager._update_state(
+ mapper=mapper,
+ registry=registry,
+ declarative_scan=declarative_scan,
+ expired_attribute_loader=expired_attribute_loader,
+ init_method=init_method,
+ finalize=finalize,
+ )
+
+ return manager
+
+
+def unregister_class(class_):
+ """Unregister class instrumentation."""
+
+ _instrumentation_factory.unregister(class_)
+
+
+def is_instrumented(instance, key):
+ """Return True if the given attribute on the given instance is
+ instrumented by the attributes package.
+
+ This function may be used regardless of instrumentation
+ applied directly to the class, i.e. no descriptors are required.
+
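+ E.g. (a minimal sketch; ``user`` is an assumed instance of a mapped
+ ``User`` class)::
+
+     from sqlalchemy.orm.instrumentation import is_instrumented
+
+     is_instrumented(user, "name")  # True for a mapped attribute
+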
+ """
+ return manager_of_class(instance.__class__).is_instrumented(
+ key, search=True
+ )
+
+
+def _generate_init(class_, class_manager, original_init):
+ """Build an __init__ decorator that triggers ClassManager events."""
+
+ # TODO: we should use the ClassManager's notion of the
+ # original '__init__' method, once ClassManager is fixed
+ # to always reference that.
+
+ if original_init is None:
+ original_init = class_.__init__
+
+ # Go through some effort here and don't change the user's __init__
+ # calling signature, including the unlikely case that it has
+ # a return value.
+ # FIXME: need to juggle local names to avoid constructor argument
+ # clashes.
+ func_body = """\
+def __init__(%(apply_pos)s):
+ new_state = class_manager._new_state_if_none(%(self_arg)s)
+ if new_state:
+ return new_state._initialize_instance(%(apply_kw)s)
+ else:
+ return original_init(%(apply_kw)s)
+"""
+ func_vars = util.format_argspec_init(original_init, grouped=False)
+ func_text = func_body % func_vars
+
+ if util.py2k:
+ func = getattr(original_init, "im_func", original_init)
+ func_defaults = getattr(func, "func_defaults", None)
+ else:
+ func_defaults = getattr(original_init, "__defaults__", None)
+ func_kw_defaults = getattr(original_init, "__kwdefaults__", None)
+
+ env = locals().copy()
+ env["__name__"] = __name__
+ exec(func_text, env)
+ __init__ = env["__init__"]
+ __init__.__doc__ = original_init.__doc__
+ __init__._sa_original_init = original_init
+
+ if func_defaults:
+ __init__.__defaults__ = func_defaults
+ if not util.py2k and func_kw_defaults:
+ __init__.__kwdefaults__ = func_kw_defaults
+
+ return __init__
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
new file mode 100644
index 0000000..63295d0
--- /dev/null
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -0,0 +1,978 @@
+# orm/interfaces.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+Contains various base classes used throughout the ORM.
+
+Defines some key base classes prominent within the internals.
+
+This module and the classes within are mostly private, though some attributes
+are exposed when inspecting mappings.
+
+"""
+
+from __future__ import absolute_import
+
+import collections
+
+from . import exc as orm_exc
+from . import path_registry
+from .base import _MappedAttribute # noqa
+from .base import EXT_CONTINUE
+from .base import EXT_SKIP
+from .base import EXT_STOP
+from .base import InspectionAttr # noqa
+from .base import InspectionAttrInfo # noqa
+from .base import MANYTOMANY
+from .base import MANYTOONE
+from .base import NOT_EXTENSION
+from .base import ONETOMANY
+from .. import inspect
+from .. import inspection
+from .. import util
+from ..sql import operators
+from ..sql import roles
+from ..sql import visitors
+from ..sql.base import ExecutableOption
+from ..sql.traversals import HasCacheKey
+
+
+__all__ = (
+ "EXT_CONTINUE",
+ "EXT_STOP",
+ "EXT_SKIP",
+ "ONETOMANY",
+ "MANYTOMANY",
+ "MANYTOONE",
+ "NOT_EXTENSION",
+ "LoaderStrategy",
+ "MapperOption",
+ "LoaderOption",
+ "MapperProperty",
+ "PropComparator",
+ "StrategizedProperty",
+)
+
+
+class ORMStatementRole(roles.StatementRole):
+ _role_name = (
+ "Executable SQL or text() construct, including ORM " "aware objects"
+ )
+
+
+class ORMColumnsClauseRole(roles.ColumnsClauseRole):
+ _role_name = "ORM mapped entity, aliased entity, or Column expression"
+
+
+class ORMEntityColumnsClauseRole(ORMColumnsClauseRole):
+ _role_name = "ORM mapped or aliased entity"
+
+
+class ORMFromClauseRole(roles.StrictFromClauseRole):
+ _role_name = "ORM mapped entity, aliased entity, or FROM expression"
+
+
+@inspection._self_inspects
+class MapperProperty(
+ HasCacheKey, _MappedAttribute, InspectionAttr, util.MemoizedSlots
+):
+ """Represent a particular class attribute mapped by :class:`_orm.Mapper`.
+
+ The most common occurrences of :class:`.MapperProperty` are the
+ mapped :class:`_schema.Column`, which is represented in a mapping as
+ an instance of :class:`.ColumnProperty`,
+ and a reference to another class produced by :func:`_orm.relationship`,
+ represented in the mapping as an instance of
+ :class:`.RelationshipProperty`.
+
+ """
+
+ __slots__ = (
+ "_configure_started",
+ "_configure_finished",
+ "parent",
+ "key",
+ "info",
+ )
+
+ _cache_key_traversal = [
+ ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("key", visitors.ExtendedInternalTraversal.dp_string),
+ ]
+
+ cascade = frozenset()
+ """The set of 'cascade' attribute names.
+
+ This collection is checked before the 'cascade_iterator' method is called.
+
+ The collection typically only applies to a RelationshipProperty.
+
+ """
+
+ is_property = True
+ """Part of the InspectionAttr interface; states this object is a
+ mapper property.
+
+ """
+
+ @property
+ def _links_to_entity(self):
+ """True if this MapperProperty refers to a mapped entity.
+
+ Should only be True for RelationshipProperty, False for all others.
+
+ """
+ raise NotImplementedError()
+
+ def _memoized_attr_info(self):
+ """Info dictionary associated with the object, allowing user-defined
+ data to be associated with this :class:`.InspectionAttr`.
+
+ The dictionary is generated when first accessed. Alternatively,
+ it can be specified as a constructor argument to the
+ :func:`.column_property`, :func:`_orm.relationship`, or
+ :func:`.composite`
+ functions.
+
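+ E.g., a minimal sketch passing ``info`` to :func:`.column_property`
+ (the column name and the dictionary contents here are assumptions)::
+
+     some_value = column_property(
+         Column("x", Integer), info={"doc_hint": "internal use only"}
+     )
+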
+ .. versionchanged:: 1.0.0 :attr:`.MapperProperty.info` is also
+ available on extension types via the
+ :attr:`.InspectionAttrInfo.info` attribute, so that it can apply
+ to a wider variety of ORM and extension constructs.
+
+ .. seealso::
+
+ :attr:`.QueryableAttribute.info`
+
+ :attr:`.SchemaItem.info`
+
+ """
+ return {}
+
+ def setup(self, context, query_entity, path, adapter, **kwargs):
+ """Called by Query for the purposes of constructing a SQL statement.
+
+ Each MapperProperty associated with the target mapper processes the
+ statement referenced by the query context, adding columns and/or
+ criterion as appropriate.
+
+ """
+
+ def create_row_processor(
+ self, context, query_entity, path, mapper, result, adapter, populators
+ ):
+ """Produce row processing functions and append to the given
+ set of populators lists.
+
+ """
+
+ def cascade_iterator(
+ self, type_, state, dict_, visited_states, halt_on=None
+ ):
+ """Iterate through instances related to the given instance for
+ a particular 'cascade', starting with this MapperProperty.
+
+ Return an iterator of 3-tuples (instance, mapper, state).
+
+ Note that the 'cascade' collection on this MapperProperty is
+ checked first for the given type before cascade_iterator is called.
+
+ This method typically only applies to RelationshipProperty.
+
+ """
+
+ return iter(())
+
+ def set_parent(self, parent, init):
+ """Set the parent mapper that references this MapperProperty.
+
+ This method is overridden by some subclasses to perform extra
+ setup when the mapper is first known.
+
+ """
+ self.parent = parent
+
+ def instrument_class(self, mapper):
+ """Hook called by the Mapper to the property to initiate
+ instrumentation of the class attribute managed by this
+ MapperProperty.
+
+ The MapperProperty here will typically call out to the
+ attributes module to set up an InstrumentedAttribute.
+
+ This step is the first of two steps to set up an InstrumentedAttribute,
+ and is called early in the mapper setup process.
+
+ The second step is typically the init_class_attribute step,
+ called from StrategizedProperty via the post_instrument_class()
+ hook. This step assigns additional state to the InstrumentedAttribute
+ (specifically the "impl") which has been determined after the
+ MapperProperty has determined what kind of persistence
+ management it needs to do (e.g. scalar, object, collection, etc).
+
+ """
+
+ def __init__(self):
+ self._configure_started = False
+ self._configure_finished = False
+
+ def init(self):
+ """Called after all mappers are created to assemble
+ relationships between mappers and perform other post-mapper-creation
+ initialization steps.
+
+
+ """
+ self._configure_started = True
+ self.do_init()
+ self._configure_finished = True
+
+ @property
+ def class_attribute(self):
+ """Return the class-bound descriptor corresponding to this
+ :class:`.MapperProperty`.
+
+ This is basically a ``getattr()`` call::
+
+ return getattr(self.parent.class_, self.key)
+
+ I.e. if this :class:`.MapperProperty` were named ``addresses``,
+ and the class to which it is mapped is ``User``, this sequence
+ is possible::
+
+ >>> from sqlalchemy import inspect
+ >>> mapper = inspect(User)
+ >>> addresses_property = mapper.attrs.addresses
+ >>> addresses_property.class_attribute is User.addresses
+ True
+ >>> User.addresses.property is addresses_property
+ True
+
+
+ """
+
+ return getattr(self.parent.class_, self.key)
+
+ def do_init(self):
+ """Perform subclass-specific initialization post-mapper-creation
+ steps.
+
+ This is a template method called by the ``MapperProperty``
+ object's init() method.
+
+ """
+
+ def post_instrument_class(self, mapper):
+ """Perform instrumentation adjustments that need to occur
+ after init() has completed.
+
+ The given Mapper is the Mapper invoking the operation, which
+ may not be the same Mapper as self.parent in an inheritance
+ scenario; however, Mapper will always at least be a sub-mapper of
+ self.parent.
+
+ This method is typically used by StrategizedProperty, which delegates
+ it to LoaderStrategy.init_class_attribute() to perform final setup
+ on the class-bound InstrumentedAttribute.
+
+ """
+
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
+ """Merge the attribute represented by this ``MapperProperty``
+ from source to destination object.
+
+ """
+
+ def __repr__(self):
+ return "<%s at 0x%x; %s>" % (
+ self.__class__.__name__,
+ id(self),
+ getattr(self, "key", "no key"),
+ )
+
+
+@inspection._self_inspects
+class PropComparator(operators.ColumnOperators):
+ r"""Defines SQL operators for :class:`.MapperProperty` objects.
+
+ SQLAlchemy allows for operators to
+ be redefined at both the Core and ORM level. :class:`.PropComparator`
+ is the base class of operator redefinition for ORM-level operations,
+ including those of :class:`.ColumnProperty`,
+ :class:`.RelationshipProperty`, and :class:`.CompositeProperty`.
+
+ .. note:: With the advent of Hybrid properties introduced in SQLAlchemy
+ 0.7, as well as Core-level operator redefinition in
+ SQLAlchemy 0.8, the use case for user-defined :class:`.PropComparator`
+ instances is extremely rare. See :ref:`hybrids_toplevel` as well
+ as :ref:`types_operators`.
+
+ User-defined subclasses of :class:`.PropComparator` may be created. The
+ built-in Python comparison and math operator methods, such as
+ :meth:`.operators.ColumnOperators.__eq__`,
+ :meth:`.operators.ColumnOperators.__lt__`, and
+ :meth:`.operators.ColumnOperators.__add__`, can be overridden to provide
+ new operator behavior. The custom :class:`.PropComparator` is passed to
+ the :class:`.MapperProperty` instance via the ``comparator_factory``
+ argument. In each case,
+ the appropriate subclass of :class:`.PropComparator` should be used::
+
+ # definition of custom PropComparator subclasses
+
+ from sqlalchemy.orm.properties import \
+ ColumnProperty,\
+ CompositeProperty,\
+ RelationshipProperty
+
+ class MyColumnComparator(ColumnProperty.Comparator):
+ def __eq__(self, other):
+ return self.__clause_element__() == other
+
+ class MyRelationshipComparator(RelationshipProperty.Comparator):
+ def any(self, expression):
+ "define the 'any' operation"
+ # ...
+
+ class MyCompositeComparator(CompositeProperty.Comparator):
+ def __gt__(self, other):
+ "redefine the 'greater than' operation"
+
+ return sql.and_(*[a>b for a, b in
+ zip(self.__clause_element__().clauses,
+ other.__composite_values__())])
+
+
+ # application of custom PropComparator subclasses
+
+ from sqlalchemy.orm import column_property, relationship, composite
+ from sqlalchemy import Column, String
+
+ class SomeMappedClass(Base):
+ some_column = column_property(Column("some_column", String),
+ comparator_factory=MyColumnComparator)
+
+ some_relationship = relationship(SomeOtherClass,
+ comparator_factory=MyRelationshipComparator)
+
+ some_composite = composite(
+ Column("a", String), Column("b", String),
+ comparator_factory=MyCompositeComparator
+ )
+
+ Note that for column-level operator redefinition, it's usually
+ simpler to define the operators at the Core level, using the
+ :attr:`.TypeEngine.comparator_factory` attribute. See
+ :ref:`types_operators` for more detail.
+
+ .. seealso::
+
+ :class:`.ColumnProperty.Comparator`
+
+ :class:`.RelationshipProperty.Comparator`
+
+ :class:`.CompositeProperty.Comparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ __slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
+
+ __visit_name__ = "orm_prop_comparator"
+
+ def __init__(
+ self,
+ prop,
+ parentmapper,
+ adapt_to_entity=None,
+ ):
+ self.prop = self.property = prop
+ self._parententity = adapt_to_entity or parentmapper
+ self._adapt_to_entity = adapt_to_entity
+
+ def __clause_element__(self):
+ raise NotImplementedError("%r" % self)
+
+ def _bulk_update_tuples(self, value):
+ """Receive a SQL expression that represents a value in the SET
+ clause of an UPDATE statement.
+
+ Return a tuple that can be passed to a :class:`_expression.Update`
+ construct.
+
+ """
+
+ return [(self.__clause_element__(), value)]
+
+ def adapt_to_entity(self, adapt_to_entity):
+ """Return a copy of this PropComparator which will use the given
+ :class:`.AliasedInsp` to produce corresponding expressions.
+ """
+ return self.__class__(self.prop, self._parententity, adapt_to_entity)
+
+ @property
+ def _parentmapper(self):
+ """legacy; this is renamed to _parententity to be
+ compatible with QueryableAttribute."""
+ return inspect(self._parententity).mapper
+
+ @property
+ def _propagate_attrs(self):
+ # this suits the case in coercions where we don't actually
+ # call ``__clause_element__()`` but still need to get
+ # resolved._propagate_attrs. See #6558.
+ return util.immutabledict(
+ {
+ "compile_state_plugin": "orm",
+ "plugin_subject": self._parentmapper,
+ }
+ )
+
+ @property
+ def adapter(self):
+ """Produce a callable that adapts column expressions
+ to suit an aliased version of this comparator.
+
+ """
+ if self._adapt_to_entity is None:
+ return None
+ else:
+ return self._adapt_to_entity._adapt_element
+
+ @property
+ def info(self):
+ return self.property.info
+
+ @staticmethod
+ def any_op(a, b, **kwargs):
+ return a.any(b, **kwargs)
+
+ @staticmethod
+ def has_op(a, b, **kwargs):
+ return a.has(b, **kwargs)
+
+ @staticmethod
+ def of_type_op(a, class_):
+ return a.of_type(class_)
+
+ def of_type(self, class_):
+ r"""Redefine this object in terms of a polymorphic subclass,
+ :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased`
+ construct.
+
+ Returns a new PropComparator from which further criterion can be
+ evaluated.
+
+ e.g.::
+
+ query.join(Company.employees.of_type(Engineer)).\
+ filter(Engineer.name=='foo')
+
+ :param \class_: a class or mapper indicating that criterion will be
+ against this specific subclass.
+
+ .. seealso::
+
+ :ref:`queryguide_join_onclause` - in the :ref:`queryguide_toplevel`
+
+ :ref:`inheritance_of_type`
+
+ """
+
+ return self.operate(PropComparator.of_type_op, class_)
+
+ def and_(self, *criteria):
+ """Add additional criteria to the ON clause that's represented by this
+ relationship attribute.
+
+ E.g.::
+
+
+ stmt = select(User).join(
+ User.addresses.and_(Address.email_address != 'foo')
+ )
+
+ stmt = select(User).options(
+ joinedload(User.addresses.and_(Address.email_address != 'foo'))
+ )
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`orm_queryguide_join_on_augmented`
+
+ :ref:`loader_option_criteria`
+
+ :func:`.with_loader_criteria`
+
+ """
+ return self.operate(operators.and_, *criteria)
+
+ def any(self, criterion=None, **kwargs):
+ r"""Return true if this collection contains any member that meets the
+ given criterion.
+
+ The usual implementation of ``any()`` is
+ :meth:`.RelationshipProperty.Comparator.any`.
+
+ :param criterion: an optional ClauseElement formulated against the
+ member class' table or attributes.
+
+ :param \**kwargs: key/value pairs corresponding to member class
+ attribute names which will be compared via equality to the
+ corresponding values.
+
+ """
+
+ return self.operate(PropComparator.any_op, criterion, **kwargs)
+
+ def has(self, criterion=None, **kwargs):
+ r"""Return true if this element references a member which meets the
+ given criterion.
+
+ The usual implementation of ``has()`` is
+ :meth:`.RelationshipProperty.Comparator.has`.
+
+ :param criterion: an optional ClauseElement formulated against the
+ member class' table or attributes.
+
+ :param \**kwargs: key/value pairs corresponding to member class
+ attribute names which will be compared via equality to the
+ corresponding values.
+
+ """
+
+ return self.operate(PropComparator.has_op, criterion, **kwargs)
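+
+ # Editor's note: illustrative usage sketch for any()/has() (not part of
+ # the original module), assuming a typical User/Address mapping where
+ # User.addresses is one-to-many and Address.user is its many-to-one side:
+ #
+ #     session.query(User).filter(
+ #         User.addresses.any(Address.email_address == "bar")
+ #     )
+ #     session.query(Address).filter(Address.user.has(name="ed"))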
+
+
+class StrategizedProperty(MapperProperty):
+ """A MapperProperty which uses selectable strategies to affect
+ loading behavior.
+
+ There is a single strategy selected by default. Alternate
+ strategies can be selected at Query time through the usage of
+ ``StrategizedOption`` objects via the Query.options() method.
+
+ The mechanics of StrategizedProperty are used for every Query
+ invocation for every mapped attribute participating in that Query,
+ to determine first how the attribute will be rendered in SQL
+ and secondly how the attribute will retrieve a value from a result
+ row and apply it to a mapped object. The routines here are very
+ performance-critical.
+
+ """
+
+ __slots__ = (
+ "_strategies",
+ "strategy",
+ "_wildcard_token",
+ "_default_path_loader_key",
+ )
+ inherit_cache = True
+ strategy_wildcard_key = None
+
+ def _memoized_attr__wildcard_token(self):
+ return (
+ "%s:%s"
+ % (self.strategy_wildcard_key, path_registry._WILDCARD_TOKEN),
+ )
+
+ def _memoized_attr__default_path_loader_key(self):
+ return (
+ "loader",
+ (
+ "%s:%s"
+ % (self.strategy_wildcard_key, path_registry._DEFAULT_TOKEN),
+ ),
+ )
+
+ def _get_context_loader(self, context, path):
+ load = None
+
+ search_path = path[self]
+
+ # search among: exact match, "attr.*", "default" strategy
+ # if any.
+ for path_key in (
+ search_path._loader_key,
+ search_path._wildcard_path_loader_key,
+ search_path._default_path_loader_key,
+ ):
+ if path_key in context.attributes:
+ load = context.attributes[path_key]
+ break
+
+ return load
+
+ def _get_strategy(self, key):
+ try:
+ return self._strategies[key]
+ except KeyError:
+ pass
+
+ # run outside to prevent transfer of exception context
+ cls = self._strategy_lookup(self, *key)
+ # this previously was setting self._strategies[cls], that's
+ # a bad idea; should use strategy key at all times because every
+ # strategy has multiple keys at this point
+ self._strategies[key] = strategy = cls(self, key)
+ return strategy
+
+ def setup(self, context, query_entity, path, adapter, **kwargs):
+ loader = self._get_context_loader(context, path)
+ if loader and loader.strategy:
+ strat = self._get_strategy(loader.strategy)
+ else:
+ strat = self.strategy
+ strat.setup_query(
+ context, query_entity, path, loader, adapter, **kwargs
+ )
+
+ def create_row_processor(
+ self, context, query_entity, path, mapper, result, adapter, populators
+ ):
+ loader = self._get_context_loader(context, path)
+ if loader and loader.strategy:
+ strat = self._get_strategy(loader.strategy)
+ else:
+ strat = self.strategy
+ strat.create_row_processor(
+ context,
+ query_entity,
+ path,
+ loader,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+ def do_init(self):
+ self._strategies = {}
+ self.strategy = self._get_strategy(self.strategy_key)
+
+ def post_instrument_class(self, mapper):
+ if (
+ not self.parent.non_primary
+ and not mapper.class_manager._attr_has_impl(self.key)
+ ):
+ self.strategy.init_class_attribute(mapper)
+
+ _all_strategies = collections.defaultdict(dict)
+
+ @classmethod
+ def strategy_for(cls, **kw):
+ def decorate(dec_cls):
+ # ensure each subclass of the strategy has its
+ # own _strategy_keys collection
+ if "_strategy_keys" not in dec_cls.__dict__:
+ dec_cls._strategy_keys = []
+ key = tuple(sorted(kw.items()))
+ cls._all_strategies[cls][key] = dec_cls
+ dec_cls._strategy_keys.append(key)
+ return dec_cls
+
+ return decorate
+
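+ # Editor's note (illustrative, not part of the original module): concrete
+ # loader strategies in orm/strategies.py register themselves with this
+ # decorator, roughly along these lines:
+ #
+ #     @RelationshipProperty.strategy_for(lazy="select")
+ #     class LazyLoader(LoaderStrategy):
+ #         ...
+ #
+ # _strategy_lookup() below then resolves the (("lazy", "select"),) key
+ # back to the registered class.
+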
+ @classmethod
+ def _strategy_lookup(cls, requesting_property, *key):
+ requesting_property.parent._with_polymorphic_mappers
+
+ for prop_cls in cls.__mro__:
+ if prop_cls in cls._all_strategies:
+ strategies = cls._all_strategies[prop_cls]
+ try:
+ return strategies[key]
+ except KeyError:
+ pass
+
+ for property_type, strats in cls._all_strategies.items():
+ if key in strats:
+ intended_property_type = property_type
+ actual_strategy = strats[key]
+ break
+ else:
+ intended_property_type = None
+ actual_strategy = None
+
+ raise orm_exc.LoaderStrategyException(
+ cls,
+ requesting_property,
+ intended_property_type,
+ actual_strategy,
+ key,
+ )
+
+
+class ORMOption(ExecutableOption):
+ """Base class for option objects that are passed to ORM queries.
+
+ These options may be consumed by :meth:`.Query.options`,
+ :meth:`.Select.options`, or in a more general sense by any
+ :meth:`.Executable.options` method. They are interpreted at
+ statement compile time or execution time in modern use. The
+ deprecated :class:`.MapperOption` is consumed at ORM query construction
+ time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ()
+
+ _is_legacy_option = False
+
+ propagate_to_loaders = False
+ """if True, indicate this option should be carried along
+ to "secondary" SELECT statements that occur for relationship
+ lazy loaders as well as attribute load / refresh operations.
+
+ """
+
+ _is_compile_state = False
+
+ _is_criteria_option = False
+
+ _is_strategy_option = False
+
+
+class CompileStateOption(HasCacheKey, ORMOption):
+ """base for :class:`.ORMOption` classes that affect the compilation of
+ a SQL query and therefore need to be part of the cache key.
+
+ .. note:: :class:`.CompileStateOption` is generally non-public and
+ should not be used as a base class for user-defined options; instead,
+ use :class:`.UserDefinedOption`, which is easier to use as it does not
+ interact with ORM compilation internals or caching.
+
+ :class:`.CompileStateOption` defines an internal attribute
+ ``_is_compile_state=True`` which has the effect that the ORM compilation
+ routines for SELECT and other statements will call upon these options when
+ a SQL string is being compiled. As such, these classes implement
+ :class:`.HasCacheKey` and need to provide robust ``_cache_key_traversal``
+ structures.
+
+ The :class:`.CompileStateOption` class is used to implement the ORM
+ :class:`.LoaderOption` and :class:`.CriteriaOption` classes.
+
+ .. versionadded:: 1.4.28
+
+
+ """
+
+ _is_compile_state = True
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ """Apply a modification to a given :class:`.CompileState`,
+ given entities that were replaced by with_only_columns() or
+ with_entities().
+
+ .. versionadded:: 1.4.19
+
+ """
+
+
+class LoaderOption(CompileStateOption):
+ """Describe a loader modification to an ORM statement at compilation time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ """Apply a modification to a given :class:`.CompileState`,
+ given entities that were replaced by with_only_columns() or
+ with_entities().
+
+ .. versionadded:: 1.4.19
+
+ """
+ self.process_compile_state(compile_state)
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
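+# Editor's note (illustrative, not part of the original module): the loader
+# options returned by joinedload(), selectinload(), defer() and friends are
+# LoaderOption subclasses (via strategy_options.Load), e.g.:
+#
+#     stmt = select(User).options(selectinload(User.addresses))
+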
+
+class CriteriaOption(CompileStateOption):
+ """Describe a WHERE criteria modification to an ORM statement at
+ compilation time.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _is_criteria_option = True
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ def get_global_criteria(self, attributes):
+ """update additional entity criteria options in the given
+ attributes dictionary.
+
+ """
+
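+# Editor's note (illustrative, not part of the original module): the main
+# public implementation of CriteriaOption is the with_loader_criteria()
+# option; a sketch, where User.is_active is a hypothetical column:
+#
+#     stmt = select(User).options(
+#         with_loader_criteria(User, User.is_active == True)
+#     )
+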
+
+class UserDefinedOption(ORMOption):
+ """Base class for a user-defined option that can be consumed from the
+ :meth:`.SessionEvents.do_orm_execute` event hook.
+
+ """
+
+ _is_legacy_option = False
+
+ propagate_to_loaders = False
+ """if True, indicate this option should be carried along
+ to "secondary" Query objects produced during lazy loads
+ or refresh operations.
+
+ """
+
+ def __init__(self, payload=None):
+ self.payload = payload
+
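+# Editor's note: illustrative usage sketch (not part of the original
+# module); "CacheOption" and its payload are hypothetical:
+#
+#     class CacheOption(UserDefinedOption):
+#         pass
+#
+#     stmt = select(User).options(CacheOption({"use_cache": True}))
+#
+#     @event.listens_for(Session, "do_orm_execute")
+#     def _receive(orm_execute_state):
+#         for opt in orm_execute_state.user_defined_options:
+#             ...  # inspect opt.payload and act on it
+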
+
+@util.deprecated_cls(
+ "1.4",
+ "The :class:`.MapperOption class is deprecated and will be removed "
+ "in a future release. For "
+ "modifications to queries on a per-execution basis, use the "
+ ":class:`.UserDefinedOption` class to establish state within a "
+ ":class:`.Query` or other Core statement, then use the "
+ ":meth:`.SessionEvents.before_orm_execute` hook to consume them.",
+ constructor=None,
+)
+class MapperOption(ORMOption):
+ """Describe a modification to a Query"""
+
+ _is_legacy_option = True
+
+ propagate_to_loaders = False
+ """if True, indicate this option should be carried along
+ to "secondary" Query objects produced during lazy loads
+ or refresh operations.
+
+ """
+
+ def process_query(self, query):
+ """Apply a modification to the given :class:`_query.Query`."""
+
+ def process_query_conditionally(self, query):
+ """same as process_query(), except that this option may not
+ apply to the given query.
+
+ This is typically applied during a lazy load or scalar refresh
+ operation to propagate options stated in the original Query to the
+ new Query being used for the load. It occurs for those options that
+ specify propagate_to_loaders=True.
+
+ """
+
+ self.process_query(query)
+
+
+class LoaderStrategy(object):
+ """Describe the loading behavior of a StrategizedProperty object.
+
+ The ``LoaderStrategy`` interacts with the querying process in three
+ ways:
+
+ * it controls the configuration of the ``InstrumentedAttribute``
+ placed on a class to handle the behavior of the attribute. This
+ may involve setting up class-level callable functions to fire
+ off a select operation when the attribute is first accessed
+ (i.e. a lazy load)
+
+ * it processes the ``QueryContext`` at statement construction time,
+ where it can modify the SQL statement that is being produced.
+ For example, simple column attributes will add their represented
+ column to the list of selected columns; a joined eager loader
+ may establish join clauses to add to the statement.
+
+ * It produces "row processor" functions at result fetching time.
+ These "row processor" functions populate a particular attribute
+ on a particular mapped instance.
+
+ """
+
+ __slots__ = (
+ "parent_property",
+ "is_class_level",
+ "parent",
+ "key",
+ "strategy_key",
+ "strategy_opts",
+ )
+
+ def __init__(self, parent, strategy_key):
+ self.parent_property = parent
+ self.is_class_level = False
+ self.parent = self.parent_property.parent
+ self.key = self.parent_property.key
+ self.strategy_key = strategy_key
+ self.strategy_opts = dict(strategy_key)
+
+ def init_class_attribute(self, mapper):
+ pass
+
+ def setup_query(
+ self, compile_state, query_entity, path, loadopt, adapter, **kwargs
+ ):
+ """Establish column and other state for a given QueryContext.
+
+ This method fulfills the contract specified by MapperProperty.setup().
+
+ StrategizedProperty delegates its setup() method
+ directly to this method.
+
+ """
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ """Establish row processing functions for a given QueryContext.
+
+ This method fulfills the contract specified by
+ MapperProperty.create_row_processor().
+
+ StrategizedProperty delegates its create_row_processor() method
+ directly to this method.
+
+ """
+
+ def __str__(self):
+ return str(self.parent_property)
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
new file mode 100644
index 0000000..b5691c0
--- /dev/null
+++ b/lib/sqlalchemy/orm/loading.py
@@ -0,0 +1,1465 @@
+# orm/loading.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used to convert database
+rows into object instances and associated state.
+
+the functions here are called primarily by Query, Mapper,
+as well as some of the attribute loading strategies.
+
+"""
+from __future__ import absolute_import
+
+from . import attributes
+from . import exc as orm_exc
+from . import path_registry
+from . import strategy_options
+from .base import _DEFER_FOR_STATE
+from .base import _RAISE_FOR_STATE
+from .base import _SET_DEFERRED_EXPIRED
+from .util import _none_set
+from .util import state_str
+from .. import exc as sa_exc
+from .. import future
+from .. import util
+from ..engine import result_tuple
+from ..engine.result import ChunkedIteratorResult
+from ..engine.result import FrozenResult
+from ..engine.result import SimpleResultMetaData
+from ..sql import util as sql_util
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectState
+
+_new_runid = util.counter()
+
+
+def instances(cursor, context):
+ """Return a :class:`.Result` given an ORM query context.
+
+ :param cursor: a :class:`.CursorResult`, generated by a statement
+ which came from :class:`.ORMCompileState`
+
+ :param context: a :class:`.QueryContext` object
+
+ :return: a :class:`.Result` object representing ORM results
+
+ .. versionchanged:: 1.4 The instances() function now uses
+ :class:`.Result` objects and has an all new interface.
+
+ """
+
+ context.runid = _new_runid()
+ context.post_load_paths = {}
+
+ compile_state = context.compile_state
+ filtered = compile_state._has_mapper_entities
+ single_entity = (
+ not context.load_options._only_return_tuples
+ and len(compile_state._entities) == 1
+ and compile_state._entities[0].supports_single_entity
+ )
+
+ try:
+ (process, labels, extra) = list(
+ zip(
+ *[
+ query_entity.row_processor(context, cursor)
+ for query_entity in context.compile_state._entities
+ ]
+ )
+ )
+
+ if context.yield_per and (
+ context.loaders_require_buffering
+ or context.loaders_require_uniquing
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Can't use yield_per with eager loaders that require uniquing "
+ "or row buffering, e.g. joinedload() against collections "
+ "or subqueryload(). Consider the selectinload() strategy "
+ "for better flexibility in loading objects."
+ )
+
+ except Exception:
+ with util.safe_reraise():
+ cursor.close()
+
+ def _no_unique(entry):
+ raise sa_exc.InvalidRequestError(
+ "Can't use the ORM yield_per feature in conjunction with unique()"
+ )
+
+ def _not_hashable(datatype):
+ def go(obj):
+ raise sa_exc.InvalidRequestError(
+ "Can't apply uniqueness to row tuple containing value of "
+ "type %r; this datatype produces non-hashable values"
+ % datatype
+ )
+
+ return go
+
+ if context.load_options._legacy_uniquing:
+ unique_filters = [
+ _no_unique
+ if context.yield_per
+ else id
+ if (
+ ent.use_id_for_hash
+ or ent._non_hashable_value
+ or ent._null_column_type
+ )
+ else None
+ for ent in context.compile_state._entities
+ ]
+ else:
+ unique_filters = [
+ _no_unique
+ if context.yield_per
+ else _not_hashable(ent.column.type)
+ if (not ent.use_id_for_hash and ent._non_hashable_value)
+ else id
+ if ent.use_id_for_hash
+ else None
+ for ent in context.compile_state._entities
+ ]
+
+ row_metadata = SimpleResultMetaData(
+ labels, extra, _unique_filters=unique_filters
+ )
+
+ def chunks(size):
+ while True:
+ yield_per = size
+
+ context.partials = {}
+
+ if yield_per:
+ fetch = cursor.fetchmany(yield_per)
+
+ if not fetch:
+ break
+ else:
+ fetch = cursor._raw_all_rows()
+
+ if single_entity:
+ proc = process[0]
+ rows = [proc(row) for row in fetch]
+ else:
+ rows = [
+ tuple([proc(row) for proc in process]) for row in fetch
+ ]
+
+ for path, post_load in context.post_load_paths.items():
+ post_load.invoke(context, path)
+
+ yield rows
+
+ if not yield_per:
+ break
+
+ if context.execution_options.get("prebuffer_rows", False):
+ # this is a bit of a hack at the moment.
+ # I would rather have some option in the result to pre-buffer
+ # internally.
+ _prebuffered = list(chunks(None))
+
+ def chunks(size):
+ return iter(_prebuffered)
+
+ result = ChunkedIteratorResult(
+ row_metadata,
+ chunks,
+ source_supports_scalars=single_entity,
+ raw=cursor,
+ dynamic_yield_per=cursor.context._is_server_side,
+ )
+
+ # filtered and single_entity are used to indicate to legacy Query that the
+ # query has ORM entities, so legacy deduping and scalars should be called
+ # on the result.
+ result._attributes = result._attributes.union(
+ dict(filtered=filtered, is_single_entity=single_entity)
+ )
+
+ # multi_row_eager_loaders OTOH is specific to joinedload.
+ if context.compile_state.multi_row_eager_loaders:
+
+ def require_unique(obj):
+ raise sa_exc.InvalidRequestError(
+ "The unique() method must be invoked on this Result, "
+ "as it contains results that include joined eager loads "
+ "against collections"
+ )
+
+ result._unique_filter_state = (None, require_unique)
+
+ if context.yield_per:
+ result.yield_per(context.yield_per)
+
+ return result
+
+
+@util.preload_module("sqlalchemy.orm.context")
+def merge_frozen_result(session, statement, frozen_result, load=True):
+ """Merge a :class:`_engine.FrozenResult` back into a :class:`_orm.Session`,
+ returning a new :class:`_engine.FrozenResult` object with :term:`persistent`
+ objects.
+
+ See the section :ref:`do_orm_execute_re_executing` for an example.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing`
+
+ :meth:`_engine.Result.freeze`
+
+ :class:`_engine.FrozenResult`
+
+ """
+ querycontext = util.preloaded.orm_context
+
+ if load:
+ # flush current contents if we expect to load data
+ session._autoflush()
+
+ ctx = querycontext.ORMSelectCompileState._create_entities_collection(
+ statement, legacy=False
+ )
+
+ autoflush = session.autoflush
+ try:
+ session.autoflush = False
+ mapped_entities = [
+ i
+ for i, e in enumerate(ctx._entities)
+ if isinstance(e, querycontext._MapperEntity)
+ ]
+ keys = [ent._label_name for ent in ctx._entities]
+
+ keyed_tuple = result_tuple(
+ keys, [ent._extra_entities for ent in ctx._entities]
+ )
+
+ result = []
+ for newrow in frozen_result.rewrite_rows():
+ for i in mapped_entities:
+ if newrow[i] is not None:
+ newrow[i] = session._merge(
+ attributes.instance_state(newrow[i]),
+ attributes.instance_dict(newrow[i]),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+
+ result.append(keyed_tuple(newrow))
+
+ return frozen_result.with_new_rows(result)
+ finally:
+ session.autoflush = autoflush
+
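+# Editor's note: illustrative usage sketch (not part of the original module),
+# following the pattern in the do_orm_execute_re_executing docs section;
+# "stmt" and both Session objects are hypothetical:
+#
+#     frozen = source_session.execute(stmt).freeze()
+#     merged = merge_frozen_result(target_session, stmt, frozen, load=True)
+#     users = merged().scalars().all()  # call the FrozenResult for a Result
+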
+
+@util.deprecated_20(
+ ":func:`_orm.merge_result`",
+ alternative="The function as well as the method on :class:`_orm.Query` "
+ "is superseded by the :func:`_orm.merge_frozen_result` function.",
+ becomes_legacy=True,
+)
+@util.preload_module("sqlalchemy.orm.context")
+def merge_result(query, iterator, load=True):
+ """Merge a result into the given :class:`.Query` object's Session.
+
+ See :meth:`_orm.Query.merge_result` for top-level documentation on this
+ function.
+
+ """
+
+ querycontext = util.preloaded.orm_context
+
+ session = query.session
+ if load:
+ # flush current contents if we expect to load data
+ session._autoflush()
+
+ # TODO: need test coverage and documentation for the FrozenResult
+ # use case.
+ if isinstance(iterator, FrozenResult):
+ frozen_result = iterator
+ iterator = iter(frozen_result.data)
+ else:
+ frozen_result = None
+
+ ctx = querycontext.ORMSelectCompileState._create_entities_collection(
+ query, legacy=True
+ )
+
+ autoflush = session.autoflush
+ try:
+ session.autoflush = False
+ single_entity = not frozen_result and len(ctx._entities) == 1
+
+ if single_entity:
+ if isinstance(ctx._entities[0], querycontext._MapperEntity):
+ result = [
+ session._merge(
+ attributes.instance_state(instance),
+ attributes.instance_dict(instance),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+ for instance in iterator
+ ]
+ else:
+ result = list(iterator)
+ else:
+ mapped_entities = [
+ i
+ for i, e in enumerate(ctx._entities)
+ if isinstance(e, querycontext._MapperEntity)
+ ]
+ result = []
+ keys = [ent._label_name for ent in ctx._entities]
+
+ keyed_tuple = result_tuple(
+ keys, [ent._extra_entities for ent in ctx._entities]
+ )
+
+ for row in iterator:
+ newrow = list(row)
+ for i in mapped_entities:
+ if newrow[i] is not None:
+ newrow[i] = session._merge(
+ attributes.instance_state(newrow[i]),
+ attributes.instance_dict(newrow[i]),
+ load=load,
+ _recursive={},
+ _resolve_conflict_map={},
+ )
+ result.append(keyed_tuple(newrow))
+
+ if frozen_result:
+ return frozen_result.with_data(result)
+ else:
+ return iter(result)
+ finally:
+ session.autoflush = autoflush
+
+
+def get_from_identity(session, mapper, key, passive):
+ """Look up the given key in the given session's identity map,
+ checking the object for expired state if found.
+
+ """
+ instance = session.identity_map.get(key)
+ if instance is not None:
+
+ state = attributes.instance_state(instance)
+
+ if mapper.inherits and not state.mapper.isa(mapper):
+ return attributes.PASSIVE_CLASS_MISMATCH
+
+ # expired - ensure it still exists
+ if state.expired:
+ if not passive & attributes.SQL_OK:
+ # TODO: no coverage here
+ return attributes.PASSIVE_NO_RESULT
+ elif not passive & attributes.RELATED_OBJECT_OK:
+ # this mode is used within a flush and the instance's
+ # expired state will be checked soon enough, if necessary.
+ # also used by immediateloader for a mutually-dependent
+ # o2m->m2m load, :ticket:`6301`
+ return instance
+ try:
+ state._load_expired(state, passive)
+ except orm_exc.ObjectDeletedError:
+ session._remove_newly_deleted([state])
+ return None
+ return instance
+ else:
+ return None
+
+
+def load_on_ident(
+ session,
+ statement,
+ key,
+ load_options=None,
+ refresh_state=None,
+ with_for_update=None,
+ only_load_props=None,
+ no_autoflush=False,
+ bind_arguments=util.EMPTY_DICT,
+ execution_options=util.EMPTY_DICT,
+):
+ """Load the given identity key from the database."""
+ if key is not None:
+ ident = key[1]
+ identity_token = key[2]
+ else:
+ ident = identity_token = None
+
+ return load_on_pk_identity(
+ session,
+ statement,
+ ident,
+ load_options=load_options,
+ refresh_state=refresh_state,
+ with_for_update=with_for_update,
+ only_load_props=only_load_props,
+ identity_token=identity_token,
+ no_autoflush=no_autoflush,
+ bind_arguments=bind_arguments,
+ execution_options=execution_options,
+ )
+
+
+def load_on_pk_identity(
+ session,
+ statement,
+ primary_key_identity,
+ load_options=None,
+ refresh_state=None,
+ with_for_update=None,
+ only_load_props=None,
+ identity_token=None,
+ no_autoflush=False,
+ bind_arguments=util.EMPTY_DICT,
+ execution_options=util.EMPTY_DICT,
+):
+
+ """Load the given primary key identity from the database."""
+
+ query = statement
+ q = query._clone()
+
+ assert not q._is_lambda_element
+
+ # TODO: fix these imports ....
+ from .context import QueryContext, ORMCompileState
+
+ if load_options is None:
+ load_options = QueryContext.default_load_options
+
+ if (
+ statement._compile_options
+ is SelectState.default_select_compile_options
+ ):
+ compile_options = ORMCompileState.default_compile_options
+ else:
+ compile_options = statement._compile_options
+
+ if primary_key_identity is not None:
+ mapper = query._propagate_attrs["plugin_subject"]
+
+ (_get_clause, _get_params) = mapper._get_clause
+
+ # None present in ident - turn those comparisons
+ # into "IS NULL"
+ if None in primary_key_identity:
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
+
+ _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones)
+
+ if len(nones) == len(primary_key_identity):
+ util.warn(
+ "fully NULL primary key identity cannot load any "
+ "object. This condition may raise an error in a future "
+ "release."
+ )
+
+ q._where_criteria = (
+ sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}),
+ )
+
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
+ else:
+ params = None
+
+ if with_for_update is not None:
+ version_check = True
+ q._for_update_arg = with_for_update
+ elif query._for_update_arg is not None:
+ version_check = True
+ q._for_update_arg = query._for_update_arg
+ else:
+ version_check = False
+
+ if refresh_state and refresh_state.load_options:
+ compile_options += {"_current_path": refresh_state.load_path.parent}
+ q = q.options(*refresh_state.load_options)
+
+ new_compile_options, load_options = _set_get_options(
+ compile_options,
+ load_options,
+ version_check=version_check,
+ only_load_props=only_load_props,
+ refresh_state=refresh_state,
+ identity_token=identity_token,
+ )
+ q._compile_options = new_compile_options
+ q._order_by = None
+
+ if no_autoflush:
+ load_options += {"_autoflush": False}
+
+ execution_options = util.EMPTY_DICT.merge_with(
+ execution_options, {"_sa_orm_load_options": load_options}
+ )
+ result = (
+ session.execute(
+ q,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ )
+ .unique()
+ .scalars()
+ )
+
+ try:
+ return result.one()
+ except orm_exc.NoResultFound:
+ return None
+
+
+def _set_get_options(
+ compile_opt,
+ load_opt,
+ populate_existing=None,
+ version_check=None,
+ only_load_props=None,
+ refresh_state=None,
+ identity_token=None,
+):
+
+ compile_options = {}
+ load_options = {}
+ if version_check:
+ load_options["_version_check"] = version_check
+ if populate_existing:
+ load_options["_populate_existing"] = populate_existing
+ if refresh_state:
+ load_options["_refresh_state"] = refresh_state
+ compile_options["_for_refresh_state"] = True
+ if only_load_props:
+ compile_options["_only_load_props"] = frozenset(only_load_props)
+ if identity_token:
+ load_options["_refresh_identity_token"] = identity_token
+
+ if load_options:
+ load_opt += load_options
+ if compile_options:
+ compile_opt += compile_options
+
+ return compile_opt, load_opt
+
+
+def _setup_entity_query(
+ compile_state,
+ mapper,
+ query_entity,
+ path,
+ adapter,
+ column_collection,
+ with_polymorphic=None,
+ only_load_props=None,
+ polymorphic_discriminator=None,
+ **kw
+):
+
+ if with_polymorphic:
+ poly_properties = mapper._iterate_polymorphic_properties(
+ with_polymorphic
+ )
+ else:
+ poly_properties = mapper._polymorphic_properties
+
+ quick_populators = {}
+
+ path.set(compile_state.attributes, "memoized_setups", quick_populators)
+
+ # for the lead entities in the path, e.g. not eager loads, and
+ # assuming a user-passed aliased class, e.g. not a from_self() or any
+ # implicit aliasing, don't add columns to the SELECT that aren't
+ # in the thing that's aliased.
+ check_for_adapt = adapter and len(path) == 1 and path[-1].is_aliased_class
+
+ for value in poly_properties:
+ if only_load_props and value.key not in only_load_props:
+ continue
+
+ value.setup(
+ compile_state,
+ query_entity,
+ path,
+ adapter,
+ only_load_props=only_load_props,
+ column_collection=column_collection,
+ memoized_populators=quick_populators,
+ check_for_adapt=check_for_adapt,
+ **kw
+ )
+
+ if (
+ polymorphic_discriminator is not None
+ and polymorphic_discriminator is not mapper.polymorphic_on
+ ):
+
+ if adapter:
+ pd = adapter.columns[polymorphic_discriminator]
+ else:
+ pd = polymorphic_discriminator
+ column_collection.append(pd)
+
+
+def _warn_for_runid_changed(state):
+ util.warn(
+ "Loading context for %s has changed within a load/refresh "
+ "handler, suggesting a row refresh operation took place. If this "
+ "event handler is expected to be "
+ "emitting row refresh operations within an existing load or refresh "
+ "operation, set restore_load_context=True when establishing the "
+ "listener to ensure the context remains unchanged when the event "
+ "handler completes." % (state_str(state),)
+ )
+
+
+def _instance_processor(
+ query_entity,
+ mapper,
+ context,
+ result,
+ path,
+ adapter,
+ only_load_props=None,
+ refresh_state=None,
+ polymorphic_discriminator=None,
+ _polymorphic_from=None,
+):
+ """Produce a mapper level row processor callable
+ which processes rows into mapped instances."""
+
+ # note that this method, most of which exists in a closure
+ # called _instance(), resists being broken out, as
+ # attempts to do so tend to add significant function
+ # call overhead. _instance() is the most
+ # performance-critical section in the whole ORM.
+
+ identity_class = mapper._identity_class
+ compile_state = context.compile_state
+
+ # look for "row getter" functions that have been assigned along
+ # with the compile state that were cached from a previous load.
+ # these are operator.itemgetter() objects that each will extract a
+ # particular column from each row.
+
+ getter_key = ("getters", mapper)
+ getters = path.get(compile_state.attributes, getter_key, None)
+
+ if getters is None:
+ # no getters, so go through a list of attributes we are loading for,
+ # and the ones that are column based will have already put information
+ # for us in another collection "memoized_setups", which represents the
+ # output of the LoaderStrategy.setup_query() method. We can just as
+ # easily call LoaderStrategy.create_row_processor for each, but by
+ # getting it all at once from setup_query we save another method call
+ # per attribute.
+ props = mapper._prop_set
+ if only_load_props is not None:
+ props = props.intersection(
+ mapper._props[k] for k in only_load_props
+ )
+
+ quick_populators = path.get(
+ context.attributes, "memoized_setups", _none_set
+ )
+
+ todo = []
+ cached_populators = {
+ "new": [],
+ "quick": [],
+ "deferred": [],
+ "expire": [],
+ "delayed": [],
+ "existing": [],
+ "eager": [],
+ }
+
+ if refresh_state is None:
+ # we can also get the "primary key" tuple getter function
+ pk_cols = mapper.primary_key
+
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+ primary_key_getter = result._tuple_getter(pk_cols)
+ else:
+ primary_key_getter = None
+
+ getters = {
+ "cached_populators": cached_populators,
+ "todo": todo,
+ "primary_key_getter": primary_key_getter,
+ }
+ for prop in props:
+ if prop in quick_populators:
+ # this is an inlined path just for column-based attributes.
+ col = quick_populators[prop]
+ if col is _DEFER_FOR_STATE:
+ cached_populators["new"].append(
+ (prop.key, prop._deferred_column_loader)
+ )
+ elif col is _SET_DEFERRED_EXPIRED:
+ # note that in this path, we are no longer
+ # searching in the result to see if the column might
+ # be present in some unexpected way.
+ cached_populators["expire"].append((prop.key, False))
+ elif col is _RAISE_FOR_STATE:
+ cached_populators["new"].append(
+ (prop.key, prop._raise_column_loader)
+ )
+ else:
+ getter = None
+ if adapter:
+ # this logic had been removed for all 1.4 releases
+ # up until 1.4.18; the adapter here is particularly
+ # the compound eager adapter which isn't accommodated
+ # in the quick_populators right now. The "fallback"
+ # logic below instead took over in many more cases
+ # until issue #6596 was identified.
+
+ # note there is still an issue where this codepath
+ # produces no "getter" for cases where a joined-inh
+ # mapping includes a labeled column property, meaning
+ # KeyError is caught internally and we fall back to
+ # _getter(col), which works anyway. The adapter
+ # here for joined inh without any aliasing might not
+ # be useful. Tests which see this include
+ # test.orm.inheritance.test_basic ->
+ # EagerTargetingTest.test_adapt_stringency
+ # OptimizedLoadTest.test_column_expression_joined
+ # PolymorphicOnNotLocalTest.test_polymorphic_on_column_prop # noqa: E501
+ #
+
+ adapted_col = adapter.columns[col]
+ if adapted_col is not None:
+ getter = result._getter(adapted_col, False)
+ if not getter:
+ getter = result._getter(col, False)
+ if getter:
+ cached_populators["quick"].append((prop.key, getter))
+ else:
+ # fall back to the ColumnProperty itself, which
+ # will iterate through all of its columns
+ # to see if one fits
+ prop.create_row_processor(
+ context,
+ query_entity,
+ path,
+ mapper,
+ result,
+ adapter,
+ cached_populators,
+ )
+ else:
+ # loader strategies like subqueryload, selectinload,
+ # joinedload, basically relationships, these need to interact
+ # with the context each time to work correctly.
+ todo.append(prop)
+
+ path.set(compile_state.attributes, getter_key, getters)
+
+ cached_populators = getters["cached_populators"]
+
+ populators = {key: list(value) for key, value in cached_populators.items()}
+ for prop in getters["todo"]:
+ prop.create_row_processor(
+ context, query_entity, path, mapper, result, adapter, populators
+ )
+
+ propagated_loader_options = context.propagated_loader_options
+ load_path = (
+ context.compile_state.current_path + path
+ if context.compile_state.current_path.path
+ else path
+ )
+
+ session_identity_map = context.session.identity_map
+
+ populate_existing = context.populate_existing or mapper.always_refresh
+ load_evt = bool(mapper.class_manager.dispatch.load)
+ refresh_evt = bool(mapper.class_manager.dispatch.refresh)
+ persistent_evt = bool(context.session.dispatch.loaded_as_persistent)
+ if persistent_evt:
+ loaded_as_persistent = context.session.dispatch.loaded_as_persistent
+ instance_state = attributes.instance_state
+ instance_dict = attributes.instance_dict
+ session_id = context.session.hash_key
+ runid = context.runid
+ identity_token = context.identity_token
+
+ version_check = context.version_check
+ if version_check:
+ version_id_col = mapper.version_id_col
+ if version_id_col is not None:
+ if adapter:
+ version_id_col = adapter.columns[version_id_col]
+ version_id_getter = result._getter(version_id_col)
+ else:
+ version_id_getter = None
+
+ if not refresh_state and _polymorphic_from is not None:
+ key = ("loader", path.path)
+ if key in context.attributes and context.attributes[key].strategy == (
+ ("selectinload_polymorphic", True),
+ ):
+ selectin_load_via = mapper._should_selectin_load(
+ context.attributes[key].local_opts["entities"],
+ _polymorphic_from,
+ )
+ else:
+ selectin_load_via = mapper._should_selectin_load(
+ None, _polymorphic_from
+ )
+
+ if selectin_load_via and selectin_load_via is not _polymorphic_from:
+ # only_load_props goes w/ refresh_state only, and in a refresh
+ # we are a single row query for the exact entity; polymorphic
+ # loading does not apply
+ assert only_load_props is None
+
+ callable_ = _load_subclass_via_in(context, path, selectin_load_via)
+
+ PostLoad.callable_for_path(
+ context,
+ load_path,
+ selectin_load_via.mapper,
+ selectin_load_via,
+ callable_,
+ selectin_load_via,
+ )
+
+ post_load = PostLoad.for_context(context, load_path, only_load_props)
+
+ if refresh_state:
+ refresh_identity_key = refresh_state.key
+ if refresh_identity_key is None:
+ # super-rare condition; a refresh is being called
+ # on a non-instance-key instance; this is meant to only
+ # occur within a flush()
+ refresh_identity_key = mapper._identity_key_from_state(
+ refresh_state
+ )
+ else:
+ refresh_identity_key = None
+
+ primary_key_getter = getters["primary_key_getter"]
+
+ if mapper.allow_partial_pks:
+ is_not_primary_key = _none_set.issuperset
+ else:
+ is_not_primary_key = _none_set.intersection
+
+ def _instance(row):
+
+ # determine the state that we'll be populating
+ if refresh_identity_key:
+ # fixed state that we're refreshing
+ state = refresh_state
+ instance = state.obj()
+ dict_ = instance_dict(instance)
+ isnew = state.runid != runid
+ currentload = True
+ loaded_instance = False
+ else:
+ # look at the row, see if that identity is in the
+ # session, or we have to create a new one
+ identitykey = (
+ identity_class,
+ primary_key_getter(row),
+ identity_token,
+ )
+
+ instance = session_identity_map.get(identitykey)
+
+ if instance is not None:
+ # existing instance
+ state = instance_state(instance)
+ dict_ = instance_dict(instance)
+
+ isnew = state.runid != runid
+ currentload = not isnew
+ loaded_instance = False
+
+ if version_check and version_id_getter and not currentload:
+ _validate_version_id(
+ mapper, state, dict_, row, version_id_getter
+ )
+
+ else:
+ # create a new instance
+
+ # check for non-NULL values in the primary key columns,
+ # else no entity is returned for the row
+ if is_not_primary_key(identitykey[1]):
+ return None
+
+ isnew = True
+ currentload = True
+ loaded_instance = True
+
+ instance = mapper.class_manager.new_instance()
+
+ dict_ = instance_dict(instance)
+ state = instance_state(instance)
+ state.key = identitykey
+ state.identity_token = identity_token
+
+ # attach instance to session.
+ state.session_id = session_id
+ session_identity_map._add_unpresent(state, identitykey)
+
+ effective_populate_existing = populate_existing
+ if refresh_state is state:
+ effective_populate_existing = True
+
+ # populate. this looks at whether this state is new
+ # for this load or was existing, and whether or not this
+ # row is the first row with this identity.
+ if currentload or effective_populate_existing:
+ # full population routines. Objects here are either
+ # just created, or we are doing a populate_existing
+
+ # be conservative about setting load_path when populate_existing
+ # is in effect; want to maintain options from the original
+ # load. see test_expire->test_refresh_maintains_deferred_options
+ if isnew and (
+ propagated_loader_options or not effective_populate_existing
+ ):
+ state.load_options = propagated_loader_options
+ state.load_path = load_path
+
+ _populate_full(
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ effective_populate_existing,
+ populators,
+ )
+
+ if isnew:
+ # state.runid should be equal to context.runid / runid
+ # here, however for event checks we are being more conservative
+ # and checking against existing run id
+ # assert state.runid == runid
+
+ existing_runid = state.runid
+
+ if loaded_instance:
+ if load_evt:
+ state.manager.dispatch.load(state, context)
+ if state.runid != existing_runid:
+ _warn_for_runid_changed(state)
+ if persistent_evt:
+ loaded_as_persistent(context.session, state)
+ if state.runid != existing_runid:
+ _warn_for_runid_changed(state)
+ elif refresh_evt:
+ state.manager.dispatch.refresh(
+ state, context, only_load_props
+ )
+ if state.runid != runid:
+ _warn_for_runid_changed(state)
+
+ if effective_populate_existing or state.modified:
+ if refresh_state and only_load_props:
+ state._commit(dict_, only_load_props)
+ else:
+ state._commit_all(dict_, session_identity_map)
+
+ if post_load:
+ post_load.add_state(state, True)
+
+ else:
+ # partial population routines, for objects that were already
+ # in the Session, but a row matches them; apply eager loaders
+ # on existing objects, etc.
+ unloaded = state.unloaded
+ isnew = state not in context.partials
+
+ if not isnew or unloaded or populators["eager"]:
+ # state is having a partial set of its attributes
+ # refreshed. Populate those attributes,
+ # and add to the "context.partials" collection.
+
+ to_load = _populate_partial(
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ unloaded,
+ populators,
+ )
+
+ if isnew:
+ if refresh_evt:
+ existing_runid = state.runid
+ state.manager.dispatch.refresh(state, context, to_load)
+ if state.runid != existing_runid:
+ _warn_for_runid_changed(state)
+
+ state._commit(dict_, to_load)
+
+ if post_load and context.invoke_all_eagers:
+ post_load.add_state(state, False)
+
+ return instance
+
+ if mapper.polymorphic_map and not _polymorphic_from and not refresh_state:
+ # if we are doing polymorphic, dispatch to a different _instance()
+ # method specific to the subclass mapper
+ def ensure_no_pk(row):
+ identitykey = (
+ identity_class,
+ primary_key_getter(row),
+ identity_token,
+ )
+ if not is_not_primary_key(identitykey[1]):
+ return identitykey
+ else:
+ return None
+
+ _instance = _decorate_polymorphic_switch(
+ _instance,
+ context,
+ query_entity,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+ ensure_no_pk,
+ )
+
+ return _instance
+
+
+def _load_subclass_via_in(context, path, entity):
+ mapper = entity.mapper
+
+ zero_idx = len(mapper.base_mapper.primary_key) == 1
+
+ if entity.is_aliased_class:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in(entity)
+ else:
+ q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper
+
+ def do_load(context, path, states, load_only, effective_entity):
+ orig_query = context.query
+
+ options = (enable_opt,) + orig_query._with_options + (disable_opt,)
+ q2 = q.options(*options)
+
+ q2._compile_options = context.compile_state.default_compile_options
+ q2._compile_options += {"_current_path": path.parent}
+
+ if context.populate_existing:
+ q2 = q2.execution_options(populate_existing=True)
+
+ context.session.execute(
+ q2,
+ dict(
+ primary_keys=[
+ state.key[1][0] if zero_idx else state.key[1]
+ for state, load_attrs in states
+ ]
+ ),
+ ).unique().scalars().all()
+
+ return do_load
+
+
+def _populate_full(
+ context,
+ row,
+ state,
+ dict_,
+ isnew,
+ load_path,
+ loaded_instance,
+ populate_existing,
+ populators,
+):
+ if isnew:
+ # first time we are seeing a row with this identity.
+ state.runid = context.runid
+
+ for key, getter in populators["quick"]:
+ dict_[key] = getter(row)
+ if populate_existing:
+ for key, set_callable in populators["expire"]:
+ dict_.pop(key, None)
+ if set_callable:
+ state.expired_attributes.add(key)
+ else:
+ for key, set_callable in populators["expire"]:
+ if set_callable:
+ state.expired_attributes.add(key)
+
+ for key, populator in populators["new"]:
+ populator(state, dict_, row)
+ for key, populator in populators["delayed"]:
+ populator(state, dict_, row)
+ elif load_path != state.load_path:
+ # new load path, e.g. object is present in more than one
+ # column position in a series of rows
+ state.load_path = load_path
+
+ # if we have data, and the data isn't in the dict, OK, let's put
+ # it in.
+ for key, getter in populators["quick"]:
+ if key not in dict_:
+ dict_[key] = getter(row)
+
+ # otherwise treat like an "already seen" row
+ for key, populator in populators["existing"]:
+ populator(state, dict_, row)
+ # TODO: allow "existing" populator to know this is
+ # a new path for the state:
+ # populator(state, dict_, row, new_path=True)
+
+ else:
+ # have already seen rows with this identity in this same path.
+ for key, populator in populators["existing"]:
+ populator(state, dict_, row)
+
+ # TODO: same path
+ # populator(state, dict_, row, new_path=False)
+
+
+def _populate_partial(
+ context, row, state, dict_, isnew, load_path, unloaded, populators
+):
+
+ if not isnew:
+ to_load = context.partials[state]
+ for key, populator in populators["existing"]:
+ if key in to_load:
+ populator(state, dict_, row)
+ else:
+ to_load = unloaded
+ context.partials[state] = to_load
+
+ for key, getter in populators["quick"]:
+ if key in to_load:
+ dict_[key] = getter(row)
+ for key, set_callable in populators["expire"]:
+ if key in to_load:
+ dict_.pop(key, None)
+ if set_callable:
+ state.expired_attributes.add(key)
+ for key, populator in populators["new"]:
+ if key in to_load:
+ populator(state, dict_, row)
+ for key, populator in populators["delayed"]:
+ if key in to_load:
+ populator(state, dict_, row)
+ for key, populator in populators["eager"]:
+ if key not in unloaded:
+ populator(state, dict_, row)
+
+ return to_load
+
+
+def _validate_version_id(mapper, state, dict_, row, getter):
+
+ if mapper._get_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ ) != getter(row):
+ raise orm_exc.StaleDataError(
+ "Instance '%s' has version id '%s' which "
+ "does not match database-loaded version id '%s'."
+ % (
+ state_str(state),
+ mapper._get_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ ),
+ getter(row),
+ )
+ )
+
+
+def _decorate_polymorphic_switch(
+ instance_fn,
+ context,
+ query_entity,
+ mapper,
+ result,
+ path,
+ polymorphic_discriminator,
+ adapter,
+ ensure_no_pk,
+):
+ if polymorphic_discriminator is not None:
+ polymorphic_on = polymorphic_discriminator
+ else:
+ polymorphic_on = mapper.polymorphic_on
+ if polymorphic_on is None:
+ return instance_fn
+
+ if adapter:
+ polymorphic_on = adapter.columns[polymorphic_on]
+
+ def configure_subclass_mapper(discriminator):
+ try:
+ sub_mapper = mapper.polymorphic_map[discriminator]
+ except KeyError:
+ raise AssertionError(
+ "No such polymorphic_identity %r is defined" % discriminator
+ )
+ else:
+ if sub_mapper is mapper:
+ return None
+ elif not sub_mapper.isa(mapper):
+ return False
+
+ return _instance_processor(
+ query_entity,
+ sub_mapper,
+ context,
+ result,
+ path,
+ adapter,
+ _polymorphic_from=mapper,
+ )
+
+ polymorphic_instances = util.PopulateDict(configure_subclass_mapper)
+
+ getter = result._getter(polymorphic_on)
+
+ def polymorphic_instance(row):
+ discriminator = getter(row)
+ if discriminator is not None:
+ _instance = polymorphic_instances[discriminator]
+ if _instance:
+ return _instance(row)
+ elif _instance is False:
+ identitykey = ensure_no_pk(row)
+
+ if identitykey:
+ raise sa_exc.InvalidRequestError(
+ "Row with identity key %s can't be loaded into an "
+ "object; the polymorphic discriminator column '%s' "
+ "refers to %s, which is not a sub-mapper of "
+ "the requested %s"
+ % (
+ identitykey,
+ polymorphic_on,
+ mapper.polymorphic_map[discriminator],
+ mapper,
+ )
+ )
+ else:
+ return None
+ else:
+ return instance_fn(row)
+ else:
+ identitykey = ensure_no_pk(row)
+
+ if identitykey:
+ raise sa_exc.InvalidRequestError(
+ "Row with identity key %s can't be loaded into an "
+ "object; the polymorphic discriminator column '%s' is "
+ "NULL" % (identitykey, polymorphic_on)
+ )
+ else:
+ return None
+
+ return polymorphic_instance
+
+
+class PostLoad(object):
+ """Track loaders and states for "post load" operations."""
+
+ __slots__ = "loaders", "states", "load_keys"
+
+ def __init__(self):
+ self.loaders = {}
+ self.states = util.OrderedDict()
+ self.load_keys = None
+
+ def add_state(self, state, overwrite):
+ # the states for a polymorphic load here are all shared
+ # within a single PostLoad object among multiple subtypes.
+ # Filtering of callables on a per-subclass basis needs to be done at
+ # the invocation level
+ self.states[state] = overwrite
+
+ def invoke(self, context, path):
+ if not self.states:
+ return
+ path = path_registry.PathRegistry.coerce(path)
+ for token, limit_to_mapper, loader, arg, kw in self.loaders.values():
+ states = [
+ (state, overwrite)
+ for state, overwrite in self.states.items()
+ if state.manager.mapper.isa(limit_to_mapper)
+ ]
+ if states:
+ loader(context, path, states, self.load_keys, *arg, **kw)
+ self.states.clear()
+
+ @classmethod
+ def for_context(cls, context, path, only_load_props):
+ pl = context.post_load_paths.get(path.path)
+ if pl is not None and only_load_props:
+ pl.load_keys = only_load_props
+ return pl
+
+ @classmethod
+ def path_exists(self, context, path, key):
+ return (
+ path.path in context.post_load_paths
+ and key in context.post_load_paths[path.path].loaders
+ )
+
+ @classmethod
+ def callable_for_path(
+ cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw
+ ):
+ if path.path in context.post_load_paths:
+ pl = context.post_load_paths[path.path]
+ else:
+ pl = context.post_load_paths[path.path] = PostLoad()
+ pl.loaders[token] = (token, limit_to_mapper, loader_callable, arg, kw)
+
+
+def load_scalar_attributes(mapper, state, attribute_names, passive):
+ """initiate a column-based attribute refresh operation."""
+
+ # assert mapper is _state_mapper(state)
+ session = state.session
+ if not session:
+ raise orm_exc.DetachedInstanceError(
+ "Instance %s is not bound to a Session; "
+ "attribute refresh operation cannot proceed" % (state_str(state))
+ )
+
+ has_key = bool(state.key)
+
+ result = False
+
+ no_autoflush = (
+ bool(passive & attributes.NO_AUTOFLUSH) or state.session.autocommit
+ )
+
+ # in the case of inheritance, particularly concrete and abstract
+ # concrete inheritance, the class manager might have some keys
+ # of attributes on the superclass that we didn't actually map.
+ # These could be mapped as "concrete, don't load" or could be completely
+ # excluded from the mapping and we know nothing about them. Filter them
+ # here to prevent them from coming through.
+ if attribute_names:
+ attribute_names = attribute_names.intersection(mapper.attrs.keys())
+
+ if mapper.inherits and not mapper.concrete:
+ # because we are using Core to produce a select() that we
+ # pass to the Query, we aren't calling setup() for mapped
+ # attributes; in 1.0 this means deferred attrs won't get loaded
+ # by default
+ statement = mapper._optimized_get_statement(state, attribute_names)
+ if statement is not None:
+ # this was previously aliased(mapper, statement), however,
+ # statement is a select() and Query's coercion now raises for this
+ # since you can't "select" from a "SELECT" statement. only
+ # from_statement() allows this.
+ # note: using from_statement() here means there is an adaption
+ # with adapt_on_names set up. the other option is to make the
+ # aliased() against a subquery which affects the SQL.
+
+ from .query import FromStatement
+
+ stmt = FromStatement(mapper, statement).options(
+ strategy_options.Load(mapper).undefer("*")
+ )
+
+ result = load_on_ident(
+ session,
+ stmt,
+ None,
+ only_load_props=attribute_names,
+ refresh_state=state,
+ no_autoflush=no_autoflush,
+ )
+
+ if result is False:
+ if has_key:
+ identity_key = state.key
+ else:
+ # this codepath is rare - only valid when inside a flush, and the
+ # object is becoming persistent but hasn't yet been assigned
+ # an identity_key.
+ # check here to ensure we have the attrs we need.
+ pk_attrs = [
+ mapper._columntoproperty[col].key for col in mapper.primary_key
+ ]
+ if state.expired_attributes.intersection(pk_attrs):
+ raise sa_exc.InvalidRequestError(
+ "Instance %s cannot be refreshed - it's not "
+ " persistent and does not "
+ "contain a full primary key." % state_str(state)
+ )
+ identity_key = mapper._identity_key_from_state(state)
+
+ if (
+ _none_set.issubset(identity_key) and not mapper.allow_partial_pks
+ ) or _none_set.issuperset(identity_key):
+ util.warn_limited(
+ "Instance %s to be refreshed doesn't "
+ "contain a full primary key - can't be refreshed "
+ "(and shouldn't be expired, either).",
+ state_str(state),
+ )
+ return
+
+ result = load_on_ident(
+ session,
+ future.select(mapper).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ ),
+ identity_key,
+ refresh_state=state,
+ only_load_props=attribute_names,
+ no_autoflush=no_autoflush,
+ )
+
+ # if instance is pending, a refresh operation
+ # may not complete (even if PK attributes are assigned)
+ if has_key and result is None:
+ raise orm_exc.ObjectDeletedError(state)
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
new file mode 100644
index 0000000..ed221a9
--- /dev/null
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -0,0 +1,3658 @@
+# orm/mapper.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Logic to map Python classes to and from selectables.
+
+Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central
+configurational unit which associates a class with a database table.
+
+This is a semi-private module; the main configurational API of the ORM is
+ available in :mod:`~sqlalchemy.orm`.
+
+"""
+from __future__ import absolute_import
+
+from collections import deque
+from itertools import chain
+import sys
+import weakref
+
+from . import attributes
+from . import exc as orm_exc
+from . import instrumentation
+from . import loading
+from . import properties
+from . import util as orm_util
+from .base import _class_to_mapper
+from .base import _state_mapper
+from .base import class_mapper
+from .base import state_str
+from .interfaces import _MappedAttribute
+from .interfaces import EXT_SKIP
+from .interfaces import InspectionAttr
+from .interfaces import MapperProperty
+from .interfaces import ORMEntityColumnsClauseRole
+from .interfaces import ORMFromClauseRole
+from .interfaces import StrategizedProperty
+from .path_registry import PathRegistry
+from .. import event
+from .. import exc as sa_exc
+from .. import inspection
+from .. import log
+from .. import schema
+from .. import sql
+from .. import util
+from ..sql import base as sql_base
+from ..sql import coercions
+from ..sql import expression
+from ..sql import operators
+from ..sql import roles
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import HasMemoized
+
+_mapper_registries = weakref.WeakKeyDictionary()
+
+_legacy_registry = None
+
+
+def _all_registries():
+ with _CONFIGURE_MUTEX:
+ return set(_mapper_registries)
+
+
+def _unconfigured_mappers():
+ for reg in _all_registries():
+ for mapper in reg._mappers_to_configure():
+ yield mapper
+
+
+_already_compiling = False
+
+
+# a constant returned by _get_attr_by_column to indicate
+# this mapper is not handling an attribute for a particular
+# column
+NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE")
+
+# lock used to synchronize the "mapper configure" step
+_CONFIGURE_MUTEX = util.threading.RLock()
+
+
+@inspection._self_inspects
+@log.class_logger
+class Mapper(
+ ORMFromClauseRole,
+ ORMEntityColumnsClauseRole,
+ sql_base.MemoizedHasCacheKey,
+ InspectionAttr,
+):
+ """Defines an association between a Python class and a database table or
+ other relational structure, so that ORM operations against the class may
+ proceed.
+
+ The :class:`_orm.Mapper` object is instantiated using mapping methods
+ present on the :class:`_orm.registry` object. For information
+ about instantiating new :class:`_orm.Mapper` objects, see
+ :ref:`orm_mapping_classes_toplevel`.
+
+ """
+
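+ # Editor's note (illustrative, not part of the original module): as the
+ # docstring above says, Mapper objects are normally created indirectly
+ # through a registry; "User" and "user_table" are hypothetical:
+ #
+ #     from sqlalchemy.orm import registry
+ #
+ #     mapper_registry = registry()
+ #     mapper_registry.map_imperatively(User, user_table)
+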
+ _dispose_called = False
+ _ready_for_configure = False
+
+ @util.deprecated_params(
+ non_primary=(
+ "1.3",
+ "The :paramref:`.mapper.non_primary` parameter is deprecated, "
+ "and will be removed in a future release. The functionality "
+ "of non primary mappers is now better suited using the "
+ ":class:`.AliasedClass` construct, which can also be used "
+ "as the target of a :func:`_orm.relationship` in 1.3.",
+ ),
+ )
+ def __init__(
+ self,
+ class_,
+ local_table=None,
+ properties=None,
+ primary_key=None,
+ non_primary=False,
+ inherits=None,
+ inherit_condition=None,
+ inherit_foreign_keys=None,
+ always_refresh=False,
+ version_id_col=None,
+ version_id_generator=None,
+ polymorphic_on=None,
+ _polymorphic_map=None,
+ polymorphic_identity=None,
+ concrete=False,
+ with_polymorphic=None,
+ polymorphic_load=None,
+ allow_partial_pks=True,
+ batch=True,
+ column_prefix=None,
+ include_properties=None,
+ exclude_properties=None,
+ passive_updates=True,
+ passive_deletes=False,
+ confirm_deleted_rows=True,
+ eager_defaults=False,
+ legacy_is_orphan=False,
+ _compiled_cache_size=100,
+ ):
+ r"""Direct constructor for a new :class:`_orm.Mapper` object.
+
+ The :func:`_orm.mapper` function is normally invoked through the
+ use of the :class:`_orm.registry` object through either the
+ :ref:`Declarative <orm_declarative_mapping>` or
+ :ref:`Imperative <orm_imperative_mapping>` mapping styles.
+
+ .. versionchanged:: 1.4 The :func:`_orm.mapper` function should not
+ be called directly for classical mapping; for a classical mapping
+ configuration, use the :meth:`_orm.registry.map_imperatively`
+ method. The :func:`_orm.mapper` function may become private in a
+ future release.
+
+ Parameters documented below may be passed to either the
+ :meth:`_orm.registry.map_imperatively` method, or may be passed in the
+ ``__mapper_args__`` declarative class attribute described at
+ :ref:`orm_declarative_mapper_options`.
+
+ :param class\_: The class to be mapped. When using Declarative,
+ this argument is automatically passed as the declared class
+ itself.
+
+ :param local_table: The :class:`_schema.Table` or other selectable
+ to which the class is mapped. May be ``None`` if
+ this mapper inherits from another mapper using single-table
+ inheritance. When using Declarative, this argument is
+ automatically passed by the extension, based on what
+ is configured via the ``__table__`` argument or via the
+ :class:`_schema.Table`
+ produced as a result of the ``__tablename__``
+ and :class:`_schema.Column` arguments present.
+
+ :param always_refresh: If True, all query operations for this mapped
+ class will overwrite all data within object instances that already
+ exist within the session, erasing any in-memory changes with
+ whatever information was loaded from the database. Usage of this
+ flag is highly discouraged; as an alternative, see the method
+ :meth:`_query.Query.populate_existing`.
+
+ :param allow_partial_pks: Defaults to True. Indicates that a
+ composite primary key with some NULL values should be considered as
+ possibly existing within the database. This affects whether a
+ mapper will assign an incoming row to an existing identity, as well
+          as whether :meth:`.Session.merge` will check the database first for a
+ particular primary key value. A "partial primary key" can occur if
+ one has mapped to an OUTER JOIN, for example.
+
+ :param batch: Defaults to ``True``, indicating that save operations
+ of multiple entities can be batched together for efficiency.
+ Setting to False indicates
+ that an instance will be fully saved before saving the next
+ instance. This is used in the extremely rare case that a
+ :class:`.MapperEvents` listener requires being called
+ in between individual row persistence operations.
+
+ :param column_prefix: A string which will be prepended
+ to the mapped attribute name when :class:`_schema.Column`
+ objects are automatically assigned as attributes to the
+ mapped class. Does not affect :class:`.Column` objects that
+ are mapped explicitly in the :paramref:`.mapper.properties`
+ dictionary.
+
+ This parameter is typically useful with imperative mappings
+ that keep the :class:`.Table` object separate. Below, assuming
+ the ``user_table`` :class:`.Table` object has columns named
+ ``user_id``, ``user_name``, and ``password``::
+
+ class User(Base):
+ __table__ = user_table
+ __mapper_args__ = {'column_prefix':'_'}
+
+ The above mapping will assign the ``user_id``, ``user_name``, and
+ ``password`` columns to attributes named ``_user_id``,
+ ``_user_name``, and ``_password`` on the mapped ``User`` class.
+
+ The :paramref:`.mapper.column_prefix` parameter is uncommon in
+ modern use. For dealing with reflected tables, a more flexible
+ approach to automating a naming scheme is to intercept the
+ :class:`.Column` objects as they are reflected; see the section
+ :ref:`mapper_automated_reflection_schemes` for notes on this usage
+ pattern.
+
+ :param concrete: If True, indicates this mapper should use concrete
+ table inheritance with its parent mapper.
+
+ See the section :ref:`concrete_inheritance` for an example.
+
+ :param confirm_deleted_rows: defaults to True; when a DELETE occurs
+          of one or more rows based on specific primary keys, a warning is
+ emitted when the number of rows matched does not equal the number
+ of rows expected. This parameter may be set to False to handle the
+ case where database ON DELETE CASCADE rules may be deleting some of
+ those rows automatically. The warning may be changed to an
+ exception in a future release.
+
+ .. versionadded:: 0.9.4 - added
+ :paramref:`.mapper.confirm_deleted_rows` as well as conditional
+ matched row checking on delete.
+
+ :param eager_defaults: if True, the ORM will immediately fetch the
+ value of server-generated default values after an INSERT or UPDATE,
+ rather than leaving them as expired to be fetched on next access.
+ This can be used for event schemes where the server-generated values
+ are needed immediately before the flush completes. By default,
+ this scheme will emit an individual ``SELECT`` statement per row
+          inserted or updated, which can add significant performance
+ overhead. However, if the
+ target database supports :term:`RETURNING`, the default values will
+ be returned inline with the INSERT or UPDATE statement, which can
+ greatly enhance performance for an application that needs frequent
+ access to just-generated server defaults.
+
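+          For example, a minimal sketch (assuming a declarative ``Base`` and
+          the usual ``Column`` / ``DateTime`` / ``func`` imports; names are
+          illustrative) where a server-side default is fetched eagerly::
+
+            class User(Base):
+                __tablename__ = 'user'
+
+                id = Column(Integer, primary_key=True)
+                created_at = Column(DateTime, server_default=func.now())
+
+                __mapper_args__ = {"eager_defaults": True}
+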
+ .. seealso::
+
+ :ref:`orm_server_defaults`
+
+ .. versionchanged:: 0.9.0 The ``eager_defaults`` option can now
+ make use of :term:`RETURNING` for backends which support it.
+
+ :param exclude_properties: A list or set of string column names to
+ be excluded from mapping.
+
+ See :ref:`include_exclude_cols` for an example.
+
+ :param include_properties: An inclusive list or set of string column
+ names to map.
+
+ See :ref:`include_exclude_cols` for an example.
+
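+          For example, a minimal sketch (assuming a declarative ``Base`` and
+          a reflected ``user_table``; the column names are illustrative) that
+          maps only a subset of the table's columns::
+
+            class User(Base):
+                __table__ = user_table
+                __mapper_args__ = {
+                    "include_properties": ["user_id", "user_name"]
+                }
+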
+ :param inherits: A mapped class or the corresponding
+ :class:`_orm.Mapper`
+          of one, indicating a superclass from which this :class:`_orm.Mapper`
+          should *inherit*. The mapped class here must be a subclass
+ of the other mapper's class. When using Declarative, this argument
+ is passed automatically as a result of the natural class
+ hierarchy of the declared classes.
+
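+          For example, a minimal sketch of an imperative mapping that names
+          the superclass explicitly (the declarative form passes this
+          automatically; ``mapper_registry``, ``Manager``, ``manager_table``
+          and ``Employee`` are illustrative names)::
+
+            mapper_registry.map_imperatively(
+                Manager,
+                manager_table,
+                inherits=Employee,
+            )
+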
+ .. seealso::
+
+ :ref:`inheritance_toplevel`
+
+ :param inherit_condition: For joined table inheritance, a SQL
+ expression which will
+ define how the two tables are joined; defaults to a natural join
+ between the two tables.
+
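+          For example, a minimal sketch of an explicit inherit condition for
+          an imperative joined-table mapping (names are illustrative)::
+
+            mapper_registry.map_imperatively(
+                Manager,
+                manager_table,
+                inherits=Employee,
+                inherit_condition=manager_table.c.id == employee_table.c.id,
+            )
+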
+ :param inherit_foreign_keys: When ``inherit_condition`` is used and
+ the columns present are missing a :class:`_schema.ForeignKey`
+ configuration, this parameter can be used to specify which columns
+          are "foreign". In most cases this can be left as ``None``.
+
+ :param legacy_is_orphan: Boolean, defaults to ``False``.
+ When ``True``, specifies that "legacy" orphan consideration
+ is to be applied to objects mapped by this mapper, which means
+ that a pending (that is, not persistent) object is auto-expunged
+ from an owning :class:`.Session` only when it is de-associated
+ from *all* parents that specify a ``delete-orphan`` cascade towards
+ this mapper. The new default behavior is that the object is
+ auto-expunged when it is de-associated with *any* of its parents
+ that specify ``delete-orphan`` cascade. This behavior is more
+ consistent with that of a persistent object, and allows behavior to
+          be consistent in more scenarios independently of whether or not the
+          orphan object has yet been flushed.
+
+ See the change note and example at :ref:`legacy_is_orphan_addition`
+ for more detail on this change.
+
+ :param non_primary: Specify that this :class:`_orm.Mapper`
+ is in addition
+ to the "primary" mapper, that is, the one used for persistence.
+ The :class:`_orm.Mapper` created here may be used for ad-hoc
+ mapping of the class to an alternate selectable, for loading
+ only.
+
+ .. seealso::
+
+ :ref:`relationship_aliased_class` - the new pattern that removes
+ the need for the :paramref:`_orm.Mapper.non_primary` flag.
+
+ :param passive_deletes: Indicates DELETE behavior of foreign key
+ columns when a joined-table inheritance entity is being deleted.
+ Defaults to ``False`` for a base mapper; for an inheriting mapper,
+ defaults to ``False`` unless the value is set to ``True``
+ on the superclass mapper.
+
+ When ``True``, it is assumed that ON DELETE CASCADE is configured
+ on the foreign key relationships that link this mapper's table
+ to its superclass table, so that when the unit of work attempts
+ to delete the entity, it need only emit a DELETE statement for the
+ superclass table, and not this table.
+
+ When ``False``, a DELETE statement is emitted for this mapper's
+ table individually. If the primary key attributes local to this
+ table are unloaded, then a SELECT must be emitted in order to
+ validate these attributes; note that the primary key columns
+ of a joined-table subclass are not part of the "primary key" of
+ the object as a whole.
+
+ Note that a value of ``True`` is **always** forced onto the
+ subclass mappers; that is, it's not possible for a superclass
+ to specify passive_deletes without this taking effect for
+ all subclass mappers.
+
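+          For example, a minimal sketch (assuming a declarative ``Base``;
+          names are illustrative) where the database-level foreign key of the
+          subclass table performs the cascade::
+
+            class Employee(Base):
+                __tablename__ = 'employee'
+
+                id = Column(Integer, primary_key=True)
+
+                __mapper_args__ = {"passive_deletes": True}
+
+            class Engineer(Employee):
+                __tablename__ = 'engineer'
+
+                id = Column(
+                    ForeignKey('employee.id', ondelete='CASCADE'),
+                    primary_key=True,
+                )
+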
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`passive_deletes` - description of similar feature as
+ used with :func:`_orm.relationship`
+
+ :paramref:`.mapper.passive_updates` - supporting ON UPDATE
+ CASCADE for joined-table inheritance mappers
+
+ :param passive_updates: Indicates UPDATE behavior of foreign key
+ columns when a primary key column changes on a joined-table
+ inheritance mapping. Defaults to ``True``.
+
+ When True, it is assumed that ON UPDATE CASCADE is configured on
+ the foreign key in the database, and that the database will handle
+ propagation of an UPDATE from a source column to dependent columns
+ on joined-table rows.
+
+ When False, it is assumed that the database does not enforce
+ referential integrity and will not be issuing its own CASCADE
+ operation for an update. The unit of work process will
+ emit an UPDATE statement for the dependent columns during a
+ primary key change.
+
+ .. seealso::
+
+ :ref:`passive_updates` - description of a similar feature as
+ used with :func:`_orm.relationship`
+
+ :paramref:`.mapper.passive_deletes` - supporting ON DELETE
+ CASCADE for joined-table inheritance mappers
+
+ :param polymorphic_load: Specifies "polymorphic loading" behavior
+ for a subclass in an inheritance hierarchy (joined and single
+ table inheritance only). Valid values are:
+
+          * ``'inline'`` - specifies this class should be part of the
+ "with_polymorphic" mappers, e.g. its columns will be included
+ in a SELECT query against the base.
+
+          * ``'selectin'`` - specifies that when instances of this class
+ are loaded, an additional SELECT will be emitted to retrieve
+ the columns specific to this subclass. The SELECT uses
+ IN to fetch multiple subclasses at once.
+
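+          For example, a minimal sketch (assuming a declarative joined-table
+          hierarchy with a base ``Employee`` class; names are illustrative)::
+
+            class Engineer(Employee):
+                __tablename__ = 'engineer'
+
+                id = Column(ForeignKey('employee.id'), primary_key=True)
+
+                __mapper_args__ = {
+                    "polymorphic_identity": "engineer",
+                    "polymorphic_load": "selectin",
+                }
+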
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`with_polymorphic_mapper_config`
+
+ :ref:`polymorphic_selectin`
+
+ :param polymorphic_on: Specifies the column, attribute, or
+ SQL expression used to determine the target class for an
+ incoming row, when inheriting classes are present.
+
+ This value is commonly a :class:`_schema.Column` object that's
+ present in the mapped :class:`_schema.Table`::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+
+ id = Column(Integer, primary_key=True)
+ discriminator = Column(String(50))
+
+ __mapper_args__ = {
+ "polymorphic_on":discriminator,
+ "polymorphic_identity":"employee"
+ }
+
+ It may also be specified
+ as a SQL expression, as in this example where we
+ use the :func:`.case` construct to provide a conditional
+ approach::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+
+ id = Column(Integer, primary_key=True)
+ discriminator = Column(String(50))
+
+ __mapper_args__ = {
+ "polymorphic_on":case([
+ (discriminator == "EN", "engineer"),
+ (discriminator == "MA", "manager"),
+ ], else_="employee"),
+ "polymorphic_identity":"employee"
+ }
+
+ It may also refer to any attribute
+ configured with :func:`.column_property`, or to the
+ string name of one::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+
+ id = Column(Integer, primary_key=True)
+ discriminator = Column(String(50))
+ employee_type = column_property(
+ case([
+ (discriminator == "EN", "engineer"),
+ (discriminator == "MA", "manager"),
+ ], else_="employee")
+ )
+
+ __mapper_args__ = {
+ "polymorphic_on":employee_type,
+ "polymorphic_identity":"employee"
+ }
+
+ When setting ``polymorphic_on`` to reference an
+ attribute or expression that's not present in the
+ locally mapped :class:`_schema.Table`, yet the value
+ of the discriminator should be persisted to the database,
+ the value of the
+ discriminator is not automatically set on new
+ instances; this must be handled by the user,
+ either through manual means or via event listeners.
+ A typical approach to establishing such a listener
+ looks like::
+
+ from sqlalchemy import event
+ from sqlalchemy.orm import object_mapper
+
+ @event.listens_for(Employee, "init", propagate=True)
+ def set_identity(instance, *arg, **kw):
+ mapper = object_mapper(instance)
+ instance.discriminator = mapper.polymorphic_identity
+
+ Where above, we assign the value of ``polymorphic_identity``
+ for the mapped class to the ``discriminator`` attribute,
+ thus persisting the value to the ``discriminator`` column
+ in the database.
+
+ .. warning::
+
+ Currently, **only one discriminator column may be set**, typically
+ on the base-most class in the hierarchy. "Cascading" polymorphic
+ columns are not yet supported.
+
+ .. seealso::
+
+ :ref:`inheritance_toplevel`
+
+ :param polymorphic_identity: Specifies the value which
+ identifies this particular class as returned by the
+ column expression referred to by the ``polymorphic_on``
+ setting. As rows are received, the value corresponding
+ to the ``polymorphic_on`` column expression is compared
+ to this value, indicating which subclass should
+ be used for the newly reconstructed object.
+
+ :param properties: A dictionary mapping the string names of object
+ attributes to :class:`.MapperProperty` instances, which define the
+ persistence behavior of that attribute. Note that
+ :class:`_schema.Column`
+ objects present in
+ the mapped :class:`_schema.Table` are automatically placed into
+ ``ColumnProperty`` instances upon mapping, unless overridden.
+ When using Declarative, this argument is passed automatically,
+ based on all those :class:`.MapperProperty` instances declared
+ in the declared class body.
+
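+          For example, a minimal sketch of an imperative mapping that places
+          a table column under a differently-named attribute
+          (``mapper_registry``, ``User`` and ``user_table`` are illustrative
+          names)::
+
+            mapper_registry.map_imperatively(
+                User,
+                user_table,
+                properties={
+                    "name": user_table.c.user_name,
+                },
+            )
+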
+ .. seealso::
+
+ :ref:`orm_mapping_properties` - in the
+ :ref:`orm_mapping_classes_toplevel`
+
+ :param primary_key: A list of :class:`_schema.Column`
+ objects which define
+ the primary key to be used against this mapper's selectable unit.
+ This is normally simply the primary key of the ``local_table``, but
+ can be overridden here.
+
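+          For example, a minimal sketch of an imperative mapping against a
+          join, where the effective primary key is stated explicitly
+          (``mapper_registry``, ``UserAddress`` and the table names are
+          illustrative)::
+
+            user_address_join = user_table.join(address_table)
+
+            mapper_registry.map_imperatively(
+                UserAddress,
+                user_address_join,
+                primary_key=[user_table.c.id, address_table.c.id],
+            )
+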
+ .. seealso::
+
+ :ref:`mapper_primary_key` - background and example use
+
+ :param version_id_col: A :class:`_schema.Column`
+ that will be used to keep a running version id of rows
+ in the table. This is used to detect concurrent updates or
+          the presence of stale data in a flush. The methodology is to
+          detect when an UPDATE statement does not match the last known
+          version id; in that case, a
+          :class:`~sqlalchemy.orm.exc.StaleDataError` exception is
+          thrown.
+ By default, the column must be of :class:`.Integer` type,
+ unless ``version_id_generator`` specifies an alternative version
+ generator.
+
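+          For example, a minimal sketch (assuming a declarative ``Base``;
+          the ``version_id`` column name is illustrative)::
+
+            class User(Base):
+                __tablename__ = 'user'
+
+                id = Column(Integer, primary_key=True)
+                version_id = Column(Integer, nullable=False)
+
+                __mapper_args__ = {"version_id_col": version_id}
+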
+ .. seealso::
+
+ :ref:`mapper_version_counter` - discussion of version counting
+ and rationale.
+
+ :param version_id_generator: Define how new version ids should
+ be generated. Defaults to ``None``, which indicates that
+ a simple integer counting scheme be employed. To provide a custom
+ versioning scheme, provide a callable function of the form::
+
+ def generate_version(version):
+ return next_version
+
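+          For example, a minimal sketch of a non-integer scheme based on the
+          standard library ``uuid`` module (assuming ``version_uuid`` is a
+          ``String`` column on the mapped table, passed as
+          ``version_id_col``)::
+
+            import uuid
+
+            __mapper_args__ = {
+                "version_id_col": version_uuid,
+                "version_id_generator": lambda version: uuid.uuid4().hex,
+            }
+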
+ Alternatively, server-side versioning functions such as triggers,
+ or programmatic versioning schemes outside of the version id
+ generator may be used, by specifying the value ``False``.
+ Please see :ref:`server_side_version_counter` for a discussion
+ of important points when using this option.
+
+ .. versionadded:: 0.9.0 ``version_id_generator`` supports
+ server-side version number generation.
+
+ .. seealso::
+
+ :ref:`custom_version_counter`
+
+ :ref:`server_side_version_counter`
+
+
+ :param with_polymorphic: A tuple in the form ``(<classes>,
+ <selectable>)`` indicating the default style of "polymorphic"
+          loading, that is, which tables are queried at once. ``<classes>`` is
+          a single mapper or class, or a list of mappers and/or classes,
+          indicating the descendant classes that should be loaded at once.
+          The special value
+          ``'*'`` may be used to indicate all descending classes should be
+          loaded immediately. The second tuple argument ``<selectable>``
+          indicates a selectable that will be used to query for multiple
+ classes.
+
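+          For example, a minimal sketch (assuming a declarative joined-table
+          hierarchy; names are illustrative) where all subclass tables are
+          joined into queries against the base by default::
+
+            class Employee(Base):
+                __tablename__ = 'employee'
+
+                id = Column(Integer, primary_key=True)
+                type = Column(String(50))
+
+                __mapper_args__ = {
+                    "polymorphic_on": type,
+                    "polymorphic_identity": "employee",
+                    "with_polymorphic": "*",
+                }
+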
+ .. seealso::
+
+ :ref:`with_polymorphic` - discussion of polymorphic querying
+ techniques.
+
+ """
+ self.class_ = util.assert_arg_type(class_, type, "class_")
+ self._sort_key = "%s.%s" % (
+ self.class_.__module__,
+ self.class_.__name__,
+ )
+
+ self.class_manager = None
+
+ self._primary_key_argument = util.to_list(primary_key)
+ self.non_primary = non_primary
+
+ self.always_refresh = always_refresh
+
+ if isinstance(version_id_col, MapperProperty):
+ self.version_id_prop = version_id_col
+ self.version_id_col = None
+ else:
+ self.version_id_col = version_id_col
+ if version_id_generator is False:
+ self.version_id_generator = False
+ elif version_id_generator is None:
+ self.version_id_generator = lambda x: (x or 0) + 1
+ else:
+ self.version_id_generator = version_id_generator
+
+ self.concrete = concrete
+ self.single = False
+ self.inherits = inherits
+ if local_table is not None:
+ self.local_table = coercions.expect(
+ roles.StrictFromClauseRole, local_table
+ )
+ else:
+ self.local_table = None
+
+ self.inherit_condition = inherit_condition
+ self.inherit_foreign_keys = inherit_foreign_keys
+ self._init_properties = properties or {}
+ self._delete_orphans = []
+ self.batch = batch
+ self.eager_defaults = eager_defaults
+ self.column_prefix = column_prefix
+ self.polymorphic_on = (
+ coercions.expect(
+ roles.ColumnArgumentOrKeyRole,
+ polymorphic_on,
+ argname="polymorphic_on",
+ )
+ if polymorphic_on is not None
+ else None
+ )
+ self._dependency_processors = []
+ self.validators = util.EMPTY_DICT
+ self.passive_updates = passive_updates
+ self.passive_deletes = passive_deletes
+ self.legacy_is_orphan = legacy_is_orphan
+ self._clause_adapter = None
+ self._requires_row_aliasing = False
+ self._inherits_equated_pairs = None
+ self._memoized_values = {}
+ self._compiled_cache_size = _compiled_cache_size
+ self._reconstructor = None
+ self.allow_partial_pks = allow_partial_pks
+
+ if self.inherits and not self.concrete:
+ self.confirm_deleted_rows = False
+ else:
+ self.confirm_deleted_rows = confirm_deleted_rows
+
+ self._set_with_polymorphic(with_polymorphic)
+ self.polymorphic_load = polymorphic_load
+
+ # our 'polymorphic identity', a string name that when located in a
+ # result set row indicates this Mapper should be used to construct
+ # the object instance for that row.
+ self.polymorphic_identity = polymorphic_identity
+
+ # a dictionary of 'polymorphic identity' names, associating those
+ # names with Mappers that will be used to construct object instances
+ # upon a select operation.
+ if _polymorphic_map is None:
+ self.polymorphic_map = {}
+ else:
+ self.polymorphic_map = _polymorphic_map
+
+ if include_properties is not None:
+ self.include_properties = util.to_set(include_properties)
+ else:
+ self.include_properties = None
+ if exclude_properties:
+ self.exclude_properties = util.to_set(exclude_properties)
+ else:
+ self.exclude_properties = None
+
+ # prevent this mapper from being constructed
+ # while a configure_mappers() is occurring (and defer a
+ # configure_mappers() until construction succeeds)
+ with _CONFIGURE_MUTEX:
+ self.dispatch._events._new_mapper_instance(class_, self)
+ self._configure_inheritance()
+ self._configure_class_instrumentation()
+ self._configure_properties()
+ self._configure_polymorphic_setter()
+ self._configure_pks()
+ self.registry._flag_new_mapper(self)
+ self._log("constructed")
+ self._expire_memoizations()
+
+    # major attributes initialized at the class level so that
+ # they can be Sphinx-documented.
+
+ is_mapper = True
+ """Part of the inspection API."""
+
+ represents_outer_join = False
+
+ @property
+ def mapper(self):
+ """Part of the inspection API.
+
+ Returns self.
+
+ """
+ return self
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self,)
+
+ @property
+ def entity(self):
+ r"""Part of the inspection API.
+
+ Returns self.class\_.
+
+ """
+ return self.class_
+
+ local_table = None
+ """The :class:`_expression.Selectable` which this :class:`_orm.Mapper`
+ manages.
+
+ Typically is an instance of :class:`_schema.Table` or
+ :class:`_expression.Alias`.
+ May also be ``None``.
+
+ The "local" table is the
+ selectable that the :class:`_orm.Mapper` is directly responsible for
+ managing from an attribute access and flush perspective. For
+ non-inheriting mappers, the local table is the same as the
+ "mapped" table. For joined-table inheritance mappers, local_table
+ will be the particular sub-table of the overall "join" which
+ this :class:`_orm.Mapper` represents. If this mapper is a
+ single-table inheriting mapper, local_table will be ``None``.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.persist_selectable`.
+
+ """
+
+ persist_selectable = None
+ """The :class:`_expression.Selectable` to which this :class:`_orm.Mapper`
+ is mapped.
+
+ Typically an instance of :class:`_schema.Table`,
+ :class:`_expression.Join`, or :class:`_expression.Alias`.
+
+ The :attr:`_orm.Mapper.persist_selectable` is separate from
+ :attr:`_orm.Mapper.selectable` in that the former represents columns
+ that are mapped on this class or its superclasses, whereas the
+ latter may be a "polymorphic" selectable that contains additional columns
+ which are in fact mapped on subclasses only.
+
+ "persist selectable" is the "thing the mapper writes to" and
+ "selectable" is the "thing the mapper selects from".
+
+ :attr:`_orm.Mapper.persist_selectable` is also separate from
+ :attr:`_orm.Mapper.local_table`, which represents the set of columns that
+ are locally mapped on this class directly.
+
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.selectable`.
+
+ :attr:`_orm.Mapper.local_table`.
+
+ """
+
+ inherits = None
+ """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper`
+ inherits from, if any.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ configured = False
+ """Represent ``True`` if this :class:`_orm.Mapper` has been configured.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ .. seealso::
+
+ :func:`.configure_mappers`.
+
+ """
+
+ concrete = None
+ """Represent ``True`` if this :class:`_orm.Mapper` is a concrete
+ inheritance mapper.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ tables = None
+ """An iterable containing the collection of :class:`_schema.Table` objects
+ which this :class:`_orm.Mapper` is aware of.
+
+ If the mapper is mapped to a :class:`_expression.Join`, or an
+ :class:`_expression.Alias`
+ representing a :class:`_expression.Select`, the individual
+ :class:`_schema.Table`
+ objects that comprise the full construct will be represented here.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ primary_key = None
+ """An iterable containing the collection of :class:`_schema.Column`
+ objects
+ which comprise the 'primary key' of the mapped table, from the
+ perspective of this :class:`_orm.Mapper`.
+
+ This list is against the selectable in
+ :attr:`_orm.Mapper.persist_selectable`.
+ In the case of inheriting mappers, some columns may be managed by a
+ superclass mapper. For example, in the case of a
+ :class:`_expression.Join`, the
+ primary key is determined by all of the primary key columns across all
+ tables referenced by the :class:`_expression.Join`.
+
+ The list is also not necessarily the same as the primary key column
+ collection associated with the underlying tables; the :class:`_orm.Mapper`
+ features a ``primary_key`` argument that can override what the
+ :class:`_orm.Mapper` considers as primary key columns.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ class_ = None
+ """The Python class which this :class:`_orm.Mapper` maps.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ class_manager = None
+ """The :class:`.ClassManager` which maintains event listeners
+ and class-bound descriptors for this :class:`_orm.Mapper`.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ single = None
+ """Represent ``True`` if this :class:`_orm.Mapper` is a single table
+ inheritance mapper.
+
+ :attr:`_orm.Mapper.local_table` will be ``None`` if this flag is set.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ non_primary = None
+ """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary"
+ mapper, e.g. a mapper that is used only to select rows but not for
+ persistence management.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ polymorphic_on = None
+ """The :class:`_schema.Column` or SQL expression specified as the
+ ``polymorphic_on`` argument
+ for this :class:`_orm.Mapper`, within an inheritance scenario.
+
+ This attribute is normally a :class:`_schema.Column` instance but
+ may also be an expression, such as one derived from
+ :func:`.cast`.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ polymorphic_map = None
+ """A mapping of "polymorphic identity" identifiers mapped to
+ :class:`_orm.Mapper` instances, within an inheritance scenario.
+
+ The identifiers can be of any type which is comparable to the
+ type of column represented by :attr:`_orm.Mapper.polymorphic_on`.
+
+ An inheritance chain of mappers will all reference the same
+ polymorphic map object. The object is used to correlate incoming
+ result rows to target mappers.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ polymorphic_identity = None
+ """Represent an identifier which is matched against the
+ :attr:`_orm.Mapper.polymorphic_on` column during result row loading.
+
+ Used only with inheritance, this object can be of any type which is
+ comparable to the type of column represented by
+ :attr:`_orm.Mapper.polymorphic_on`.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ base_mapper = None
+ """The base-most :class:`_orm.Mapper` in an inheritance chain.
+
+ In a non-inheriting scenario, this attribute will always be this
+ :class:`_orm.Mapper`. In an inheritance scenario, it references
+ the :class:`_orm.Mapper` which is parent to all other :class:`_orm.Mapper`
+ objects in the inheritance chain.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ columns = None
+ """A collection of :class:`_schema.Column` or other scalar expression
+ objects maintained by this :class:`_orm.Mapper`.
+
+ The collection behaves the same as that of the ``c`` attribute on
+ any :class:`_schema.Table` object,
+ except that only those columns included in
+ this mapping are present, and are keyed based on the attribute name
+ defined in the mapping, not necessarily the ``key`` attribute of the
+ :class:`_schema.Column` itself. Additionally, scalar expressions mapped
+ by :func:`.column_property` are also present here.
+
+ This is a *read only* attribute determined during mapper construction.
+ Behavior is undefined if directly modified.
+
+ """
+
+ validators = None
+ """An immutable dictionary of attributes which have been decorated
+ using the :func:`_orm.validates` decorator.
+
+ The dictionary contains string attribute names as keys
+ mapped to the actual validation method.
+
+ """
+
+ c = None
+ """A synonym for :attr:`_orm.Mapper.columns`."""
+
+ @property
+ @util.deprecated("1.3", "Use .persist_selectable")
+ def mapped_table(self):
+ return self.persist_selectable
+
+ @util.memoized_property
+ def _path_registry(self):
+ return PathRegistry.per_mapper(self)
+
+ def _configure_inheritance(self):
+ """Configure settings related to inheriting and/or inherited mappers
+ being present."""
+
+ # a set of all mappers which inherit from this one.
+ self._inheriting_mappers = util.WeakSequence()
+
+ if self.inherits:
+ if isinstance(self.inherits, type):
+ self.inherits = class_mapper(self.inherits, configure=False)
+ if not issubclass(self.class_, self.inherits.class_):
+ raise sa_exc.ArgumentError(
+ "Class '%s' does not inherit from '%s'"
+ % (self.class_.__name__, self.inherits.class_.__name__)
+ )
+
+ self.dispatch._update(self.inherits.dispatch)
+
+ if self.non_primary != self.inherits.non_primary:
+ np = not self.non_primary and "primary" or "non-primary"
+ raise sa_exc.ArgumentError(
+ "Inheritance of %s mapper for class '%s' is "
+ "only allowed from a %s mapper"
+ % (np, self.class_.__name__, np)
+ )
+ # inherit_condition is optional.
+ if self.local_table is None:
+ self.local_table = self.inherits.local_table
+ self.persist_selectable = self.inherits.persist_selectable
+ self.single = True
+ elif self.local_table is not self.inherits.local_table:
+ if self.concrete:
+ self.persist_selectable = self.local_table
+ for mapper in self.iterate_to_root():
+ if mapper.polymorphic_on is not None:
+ mapper._requires_row_aliasing = True
+ else:
+ if self.inherit_condition is None:
+ # figure out inherit condition from our table to the
+ # immediate table of the inherited mapper, not its
+ # full table which could pull in other stuff we don't
+ # want (allows test/inheritance.InheritTest4 to pass)
+ try:
+ self.inherit_condition = sql_util.join_condition(
+ self.inherits.local_table, self.local_table
+ )
+ except sa_exc.NoForeignKeysError as nfe:
+ assert self.inherits.local_table is not None
+ assert self.local_table is not None
+ util.raise_(
+ sa_exc.NoForeignKeysError(
+ "Can't determine the inherit condition "
+ "between inherited table '%s' and "
+ "inheriting "
+ "table '%s'; tables have no "
+ "foreign key relationships established. "
+ "Please ensure the inheriting table has "
+ "a foreign key relationship to the "
+ "inherited "
+ "table, or provide an "
+ "'on clause' using "
+ "the 'inherit_condition' mapper argument."
+ % (
+ self.inherits.local_table.description,
+ self.local_table.description,
+ )
+ ),
+ replace_context=nfe,
+ )
+ except sa_exc.AmbiguousForeignKeysError as afe:
+ assert self.inherits.local_table is not None
+ assert self.local_table is not None
+ util.raise_(
+ sa_exc.AmbiguousForeignKeysError(
+ "Can't determine the inherit condition "
+ "between inherited table '%s' and "
+ "inheriting "
+ "table '%s'; tables have more than one "
+ "foreign key relationship established. "
+ "Please specify the 'on clause' using "
+ "the 'inherit_condition' mapper argument."
+ % (
+ self.inherits.local_table.description,
+ self.local_table.description,
+ )
+ ),
+ replace_context=afe,
+ )
+ self.persist_selectable = sql.join(
+ self.inherits.persist_selectable,
+ self.local_table,
+ self.inherit_condition,
+ )
+
+ fks = util.to_set(self.inherit_foreign_keys)
+ self._inherits_equated_pairs = sql_util.criterion_as_pairs(
+ self.persist_selectable.onclause,
+ consider_as_foreign_keys=fks,
+ )
+ else:
+ self.persist_selectable = self.local_table
+
+ if self.polymorphic_identity is not None and not self.concrete:
+ self._identity_class = self.inherits._identity_class
+ else:
+ self._identity_class = self.class_
+
+ if self.version_id_col is None:
+ self.version_id_col = self.inherits.version_id_col
+ self.version_id_generator = self.inherits.version_id_generator
+ elif (
+ self.inherits.version_id_col is not None
+ and self.version_id_col is not self.inherits.version_id_col
+ ):
+ util.warn(
+ "Inheriting version_id_col '%s' does not match inherited "
+ "version_id_col '%s' and will not automatically populate "
+ "the inherited versioning column. "
+ "version_id_col should only be specified on "
+ "the base-most mapper that includes versioning."
+ % (
+ self.version_id_col.description,
+ self.inherits.version_id_col.description,
+ )
+ )
+
+ self.polymorphic_map = self.inherits.polymorphic_map
+ self.batch = self.inherits.batch
+ self.inherits._inheriting_mappers.append(self)
+ self.base_mapper = self.inherits.base_mapper
+ self.passive_updates = self.inherits.passive_updates
+ self.passive_deletes = (
+ self.inherits.passive_deletes or self.passive_deletes
+ )
+ self._all_tables = self.inherits._all_tables
+
+ if self.polymorphic_identity is not None:
+ if self.polymorphic_identity in self.polymorphic_map:
+ util.warn(
+ "Reassigning polymorphic association for identity %r "
+ "from %r to %r: Check for duplicate use of %r as "
+ "value for polymorphic_identity."
+ % (
+ self.polymorphic_identity,
+ self.polymorphic_map[self.polymorphic_identity],
+ self,
+ self.polymorphic_identity,
+ )
+ )
+ self.polymorphic_map[self.polymorphic_identity] = self
+
+ if self.polymorphic_load and self.concrete:
+ raise sa_exc.ArgumentError(
+ "polymorphic_load is not currently supported "
+ "with concrete table inheritance"
+ )
+ if self.polymorphic_load == "inline":
+ self.inherits._add_with_polymorphic_subclass(self)
+ elif self.polymorphic_load == "selectin":
+ pass
+ elif self.polymorphic_load is not None:
+ raise sa_exc.ArgumentError(
+ "unknown argument for polymorphic_load: %r"
+ % self.polymorphic_load
+ )
+
+ else:
+ self._all_tables = set()
+ self.base_mapper = self
+ self.persist_selectable = self.local_table
+ if self.polymorphic_identity is not None:
+ self.polymorphic_map[self.polymorphic_identity] = self
+ self._identity_class = self.class_
+
+ if self.persist_selectable is None:
+ raise sa_exc.ArgumentError(
+ "Mapper '%s' does not have a persist_selectable specified."
+ % self
+ )
+
+ def _set_with_polymorphic(self, with_polymorphic):
+ if with_polymorphic == "*":
+ self.with_polymorphic = ("*", None)
+ elif isinstance(with_polymorphic, (tuple, list)):
+ if isinstance(
+ with_polymorphic[0], util.string_types + (tuple, list)
+ ):
+ self.with_polymorphic = with_polymorphic
+ else:
+ self.with_polymorphic = (with_polymorphic, None)
+ elif with_polymorphic is not None:
+ raise sa_exc.ArgumentError("Invalid setting for with_polymorphic")
+ else:
+ self.with_polymorphic = None
+
+ if self.with_polymorphic and self.with_polymorphic[1] is not None:
+ self.with_polymorphic = (
+ self.with_polymorphic[0],
+ coercions.expect(
+ roles.StrictFromClauseRole,
+ self.with_polymorphic[1],
+ allow_select=True,
+ ),
+ )
+
+ if self.configured:
+ self._expire_memoizations()
+
+ def _add_with_polymorphic_subclass(self, mapper):
+ subcl = mapper.class_
+ if self.with_polymorphic is None:
+ self._set_with_polymorphic((subcl,))
+ elif self.with_polymorphic[0] != "*":
+ self._set_with_polymorphic(
+ (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1])
+ )
+
+ def _set_concrete_base(self, mapper):
+ """Set the given :class:`_orm.Mapper` as the 'inherits' for this
+ :class:`_orm.Mapper`, assuming this :class:`_orm.Mapper` is concrete
+ and does not already have an inherits."""
+
+ assert self.concrete
+ assert not self.inherits
+ assert isinstance(mapper, Mapper)
+ self.inherits = mapper
+ self.inherits.polymorphic_map.update(self.polymorphic_map)
+ self.polymorphic_map = self.inherits.polymorphic_map
+ for mapper in self.iterate_to_root():
+ if mapper.polymorphic_on is not None:
+ mapper._requires_row_aliasing = True
+ self.batch = self.inherits.batch
+ for mp in self.self_and_descendants:
+ mp.base_mapper = self.inherits.base_mapper
+ self.inherits._inheriting_mappers.append(self)
+ self.passive_updates = self.inherits.passive_updates
+ self._all_tables = self.inherits._all_tables
+
+ for key, prop in mapper._props.items():
+ if key not in self._props and not self._should_exclude(
+ key, key, local=False, column=None
+ ):
+ self._adapt_inherited_property(key, prop, False)
+
+ def _set_polymorphic_on(self, polymorphic_on):
+ self.polymorphic_on = polymorphic_on
+ self._configure_polymorphic_setter(True)
+
+ def _configure_class_instrumentation(self):
+ """If this mapper is to be a primary mapper (i.e. the
+ non_primary flag is not set), associate this Mapper with the
+ given class and entity name.
+
+ Subsequent calls to ``class_mapper()`` for the ``class_`` / ``entity``
+ name combination will return this mapper. Also decorate the
+ `__init__` method on the mapped class to include optional
+ auto-session attachment logic.
+
+ """
+
+ # we expect that declarative has applied the class manager
+ # already and set up a registry. if this is None,
+ # we will emit a deprecation warning below when we also see that
+ # it has no registry.
+ manager = attributes.manager_of_class(self.class_)
+
+ if self.non_primary:
+ if not manager or not manager.is_mapped:
+ raise sa_exc.InvalidRequestError(
+ "Class %s has no primary mapper configured. Configure "
+ "a primary mapper first before setting up a non primary "
+ "Mapper." % self.class_
+ )
+ self.class_manager = manager
+ self.registry = manager.registry
+ self._identity_class = manager.mapper._identity_class
+ manager.registry._add_non_primary_mapper(self)
+ return
+
+ if manager is not None:
+ assert manager.class_ is self.class_
+ if manager.is_mapped:
+ # changed in #7579:
+ # this message is defined in two places as of this change,
+ # also in decl_api -> _add_manager(). in 2.0, this codepath
+ # is removed as any calls to mapper() / Mapper without
+ # the registry setting up first will be rejected.
+ raise sa_exc.ArgumentError(
+ "Class '%s' already has a primary mapper defined. "
+ % self.class_
+ )
+ # else:
+ # a ClassManager may already exist as
+ # ClassManager.instrument_attribute() creates
+ # new managers for each subclass if they don't yet exist.
+
+ self.dispatch.instrument_class(self, self.class_)
+
+ # this invokes the class_instrument event and sets up
+ # the __init__ method. documented behavior is that this must
+ # occur after the instrument_class event above.
+ # yes two events with the same two words reversed and different APIs.
+ # :(
+
+ manager = instrumentation.register_class(
+ self.class_,
+ mapper=self,
+ expired_attribute_loader=util.partial(
+ loading.load_scalar_attributes, self
+ ),
+ # finalize flag means instrument the __init__ method
+ # and call the class_instrument event
+ finalize=True,
+ )
+
+ if not manager.registry:
+ util.warn_deprecated_20(
+ "Calling the mapper() function directly outside of a "
+ "declarative registry is deprecated."
+ " Please use the sqlalchemy.orm.registry.map_imperatively() "
+ "function for a classical mapping."
+ )
+ assert _legacy_registry is not None
+ _legacy_registry._add_manager(manager)
+
+ self.class_manager = manager
+ self.registry = manager.registry
+
+ # The remaining members can be added by any mapper,
+        # whether the entity name ("e_name") is None or not.
+ if manager.mapper is None:
+ return
+
+ event.listen(manager, "init", _event_on_init, raw=True)
+
+ for key, method in util.iterate_attributes(self.class_):
+ if key == "__init__" and hasattr(method, "_sa_original_init"):
+ method = method._sa_original_init
+ if hasattr(method, "__func__"):
+ method = method.__func__
+ if callable(method):
+ if hasattr(method, "__sa_reconstructor__"):
+ self._reconstructor = method
+ event.listen(manager, "load", _event_on_load, raw=True)
+ elif hasattr(method, "__sa_validators__"):
+ validation_opts = method.__sa_validation_opts__
+ for name in method.__sa_validators__:
+ if name in self.validators:
+ raise sa_exc.InvalidRequestError(
+ "A validation function for mapped "
+ "attribute %r on mapper %s already exists."
+ % (name, self)
+ )
+ self.validators = self.validators.union(
+ {name: (method, validation_opts)}
+ )
+
+ def _set_dispose_flags(self):
+ self.configured = True
+ self._ready_for_configure = True
+ self._dispose_called = True
+
+ self.__dict__.pop("_configure_failed", None)
+
+ def _configure_pks(self):
+ self.tables = sql_util.find_tables(self.persist_selectable)
+
+ self._pks_by_table = {}
+ self._cols_by_table = {}
+
+ all_cols = util.column_set(
+ chain(*[col.proxy_set for col in self._columntoproperty])
+ )
+
+ pk_cols = util.column_set(c for c in all_cols if c.primary_key)
+
+ # identify primary key columns which are also mapped by this mapper.
+ tables = set(self.tables + [self.persist_selectable])
+ self._all_tables.update(tables)
+ for t in tables:
+ if t.primary_key and pk_cols.issuperset(t.primary_key):
+ # ordering is important since it determines the ordering of
+ # mapper.primary_key (and therefore query.get())
+ self._pks_by_table[t] = util.ordered_column_set(
+ t.primary_key
+ ).intersection(pk_cols)
+ self._cols_by_table[t] = util.ordered_column_set(t.c).intersection(
+ all_cols
+ )
+
+ # if explicit PK argument sent, add those columns to the
+ # primary key mappings
+ if self._primary_key_argument:
+ for k in self._primary_key_argument:
+ if k.table not in self._pks_by_table:
+ self._pks_by_table[k.table] = util.OrderedSet()
+ self._pks_by_table[k.table].add(k)
+
+ # otherwise, see that we got a full PK for the mapped table
+ elif (
+ self.persist_selectable not in self._pks_by_table
+ or len(self._pks_by_table[self.persist_selectable]) == 0
+ ):
+ raise sa_exc.ArgumentError(
+ "Mapper %s could not assemble any primary "
+ "key columns for mapped table '%s'"
+ % (self, self.persist_selectable.description)
+ )
+ elif self.local_table not in self._pks_by_table and isinstance(
+ self.local_table, schema.Table
+ ):
+ util.warn(
+ "Could not assemble any primary "
+ "keys for locally mapped table '%s' - "
+ "no rows will be persisted in this Table."
+ % self.local_table.description
+ )
+
+ if (
+ self.inherits
+ and not self.concrete
+ and not self._primary_key_argument
+ ):
+ # if inheriting, the "primary key" for this mapper is
+ # that of the inheriting (unless concrete or explicit)
+ self.primary_key = self.inherits.primary_key
+ else:
+ # determine primary key from argument or persist_selectable pks
+ if self._primary_key_argument:
+ primary_key = [
+ self.persist_selectable.corresponding_column(c)
+ for c in self._primary_key_argument
+ ]
+ else:
+ # if heuristically determined PKs, reduce to the minimal set
+ # of columns by eliminating FK->PK pairs for a multi-table
+ # expression. May over-reduce for some kinds of UNIONs
+ # / CTEs; use explicit PK argument for these special cases
+ primary_key = sql_util.reduce_columns(
+ self._pks_by_table[self.persist_selectable],
+ ignore_nonexistent_tables=True,
+ )
+
+ if len(primary_key) == 0:
+ raise sa_exc.ArgumentError(
+ "Mapper %s could not assemble any primary "
+ "key columns for mapped table '%s'"
+ % (self, self.persist_selectable.description)
+ )
+
+ self.primary_key = tuple(primary_key)
+ self._log("Identified primary key columns: %s", primary_key)
+
+ # determine cols that aren't expressed within our tables; mark these
+ # as "read only" properties which are refreshed upon INSERT/UPDATE
+ self._readonly_props = set(
+ self._columntoproperty[col]
+ for col in self._columntoproperty
+ if self._columntoproperty[col] not in self._identity_key_props
+ and (
+ not hasattr(col, "table")
+ or col.table not in self._cols_by_table
+ )
+ )
+
+ def _configure_properties(self):
+
+ # TODO: consider using DedupeColumnCollection
+ self.columns = self.c = sql_base.ColumnCollection()
+
+ # object attribute names mapped to MapperProperty objects
+ self._props = util.OrderedDict()
+
+ # table columns mapped to MapperProperty
+ self._columntoproperty = _ColumnMapping(self)
+
+ # load custom properties
+ if self._init_properties:
+ for key, prop in self._init_properties.items():
+ self._configure_property(key, prop, False)
+
+ # pull properties from the inherited mapper if any.
+ if self.inherits:
+ for key, prop in self.inherits._props.items():
+ if key not in self._props and not self._should_exclude(
+ key, key, local=False, column=None
+ ):
+ self._adapt_inherited_property(key, prop, False)
+
+ # create properties for each column in the mapped table,
+ # for those columns which don't already map to a property
+ for column in self.persist_selectable.columns:
+ if column in self._columntoproperty:
+ continue
+
+ column_key = (self.column_prefix or "") + column.key
+
+ if self._should_exclude(
+ column.key,
+ column_key,
+ local=self.local_table.c.contains_column(column),
+ column=column,
+ ):
+ continue
+
+ # adjust the "key" used for this column to that
+ # of the inheriting mapper
+ for mapper in self.iterate_to_root():
+ if column in mapper._columntoproperty:
+ column_key = mapper._columntoproperty[column].key
+
+ self._configure_property(
+ column_key, column, init=False, setparent=True
+ )
+
+ def _configure_polymorphic_setter(self, init=False):
+ """Configure an attribute on the mapper representing the
+ 'polymorphic_on' column, if applicable, and not
+ already generated by _configure_properties (which is typical).
+
+ Also create a setter function which will assign this
+ attribute to the value of the 'polymorphic_identity'
+ upon instance construction, also if applicable. This
+ routine will run when an instance is created.
+
+ """
+ setter = False
+
+ if self.polymorphic_on is not None:
+ setter = True
+
+ if isinstance(self.polymorphic_on, util.string_types):
+ # polymorphic_on specified as a string - link
+ # it to mapped ColumnProperty
+ try:
+ self.polymorphic_on = self._props[self.polymorphic_on]
+ except KeyError as err:
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Can't determine polymorphic_on "
+ "value '%s' - no attribute is "
+ "mapped to this name." % self.polymorphic_on
+ ),
+ replace_context=err,
+ )
+
+ if self.polymorphic_on in self._columntoproperty:
+ # polymorphic_on is a column that is already mapped
+ # to a ColumnProperty
+ prop = self._columntoproperty[self.polymorphic_on]
+ elif isinstance(self.polymorphic_on, MapperProperty):
+ # polymorphic_on is directly a MapperProperty,
+ # ensure it's a ColumnProperty
+ if not isinstance(
+ self.polymorphic_on, properties.ColumnProperty
+ ):
+ raise sa_exc.ArgumentError(
+ "Only direct column-mapped "
+ "property or SQL expression "
+ "can be passed for polymorphic_on"
+ )
+ prop = self.polymorphic_on
+ else:
+ # polymorphic_on is a Column or SQL expression and
+ # doesn't appear to be mapped. this means it can be 1.
+ # only present in the with_polymorphic selectable or
+ # 2. a totally standalone SQL expression which we'd
+ # hope is compatible with this mapper's persist_selectable
+ col = self.persist_selectable.corresponding_column(
+ self.polymorphic_on
+ )
+ if col is None:
+ # polymorphic_on doesn't derive from any
+ # column/expression isn't present in the mapped
+ # table. we will make a "hidden" ColumnProperty
+ # for it. Just check that if it's directly a
+ # schema.Column and we have with_polymorphic, it's
+ # likely a user error if the schema.Column isn't
+ # represented somehow in either persist_selectable or
+ # with_polymorphic. Otherwise as of 0.7.4 we
+ # just go with it and assume the user wants it
+ # that way (i.e. a CASE statement)
+ setter = False
+ instrument = False
+ col = self.polymorphic_on
+ if isinstance(col, schema.Column) and (
+ self.with_polymorphic is None
+ or self.with_polymorphic[1].corresponding_column(col)
+ is None
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Could not map polymorphic_on column "
+ "'%s' to the mapped table - polymorphic "
+ "loads will not function properly"
+ % col.description
+ )
+ else:
+ # column/expression that polymorphic_on derives from
+ # is present in our mapped table
+ # and is probably mapped, but polymorphic_on itself
+ # is not. This happens when
+ # the polymorphic_on is only directly present in the
+                # with_polymorphic selectable, as when using
+ # polymorphic_union.
+ # we'll make a separate ColumnProperty for it.
+ instrument = True
+ key = getattr(col, "key", None)
+ if key:
+ if self._should_exclude(col.key, col.key, False, col):
+ raise sa_exc.InvalidRequestError(
+ "Cannot exclude or override the "
+ "discriminator column %r" % col.key
+ )
+ else:
+ self.polymorphic_on = col = col.label("_sa_polymorphic_on")
+ key = col.key
+
+ prop = properties.ColumnProperty(col, _instrument=instrument)
+ self._configure_property(key, prop, init=init, setparent=True)
+
+ # the actual polymorphic_on should be the first public-facing
+ # column in the property
+ self.polymorphic_on = prop.columns[0]
+ polymorphic_key = prop.key
+
+ else:
+ # no polymorphic_on was set.
+ # check inheriting mappers for one.
+ for mapper in self.iterate_to_root():
+ # determine if polymorphic_on of the parent
+ # should be propagated here. If the col
+ # is present in our mapped table, or if our mapped
+ # table is the same as the parent (i.e. single table
+ # inheritance), we can use it
+ if mapper.polymorphic_on is not None:
+ if self.persist_selectable is mapper.persist_selectable:
+ self.polymorphic_on = mapper.polymorphic_on
+ else:
+ self.polymorphic_on = (
+ self.persist_selectable
+ ).corresponding_column(mapper.polymorphic_on)
+ # we can use the parent mapper's _set_polymorphic_identity
+ # directly; it ensures the polymorphic_identity of the
+                    # instance's mapper is used, so it is portable to subclasses.
+ if self.polymorphic_on is not None:
+ self._set_polymorphic_identity = (
+ mapper._set_polymorphic_identity
+ )
+ self._validate_polymorphic_identity = (
+ mapper._validate_polymorphic_identity
+ )
+ else:
+ self._set_polymorphic_identity = None
+ return
+
+ if setter:
+
+ def _set_polymorphic_identity(state):
+ dict_ = state.dict
+ state.get_impl(polymorphic_key).set(
+ state,
+ dict_,
+ state.manager.mapper.polymorphic_identity,
+ None,
+ )
+
+ def _validate_polymorphic_identity(mapper, state, dict_):
+ if (
+ polymorphic_key in dict_
+ and dict_[polymorphic_key]
+ not in mapper._acceptable_polymorphic_identities
+ ):
+ util.warn_limited(
+ "Flushing object %s with "
+ "incompatible polymorphic identity %r; the "
+ "object may not refresh and/or load correctly",
+ (state_str(state), dict_[polymorphic_key]),
+ )
+
+ self._set_polymorphic_identity = _set_polymorphic_identity
+ self._validate_polymorphic_identity = (
+ _validate_polymorphic_identity
+ )
+ else:
+ self._set_polymorphic_identity = None
+
+ _validate_polymorphic_identity = None
+
+ @HasMemoized.memoized_attribute
+ def _version_id_prop(self):
+ if self.version_id_col is not None:
+ return self._columntoproperty[self.version_id_col]
+ else:
+ return None
+
+ @HasMemoized.memoized_attribute
+ def _acceptable_polymorphic_identities(self):
+ identities = set()
+
+ stack = deque([self])
+ while stack:
+ item = stack.popleft()
+ if item.persist_selectable is self.persist_selectable:
+ identities.add(item.polymorphic_identity)
+ stack.extend(item._inheriting_mappers)
+
+ return identities
+
+ @HasMemoized.memoized_attribute
+ def _prop_set(self):
+ return frozenset(self._props.values())
+
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def _adapt_inherited_property(self, key, prop, init):
+ descriptor_props = util.preloaded.orm_descriptor_props
+
+ if not self.concrete:
+ self._configure_property(key, prop, init=False, setparent=False)
+ elif key not in self._props:
+ # determine if the class implements this attribute; if not,
+ # or if it is implemented by the attribute that is handling the
+ # given superclass-mapped property, then we need to report that we
+ # can't use this at the instance level since we are a concrete
+ # mapper and we don't map this. don't trip user-defined
+ # descriptors that might have side effects when invoked.
+ implementing_attribute = self.class_manager._get_class_attr_mro(
+ key, prop
+ )
+ if implementing_attribute is prop or (
+ isinstance(
+ implementing_attribute, attributes.InstrumentedAttribute
+ )
+ and implementing_attribute._parententity is prop.parent
+ ):
+ self._configure_property(
+ key,
+ descriptor_props.ConcreteInheritedProperty(),
+ init=init,
+ setparent=True,
+ )
+
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def _configure_property(self, key, prop, init=True, setparent=True):
+ descriptor_props = util.preloaded.orm_descriptor_props
+ self._log("_configure_property(%s, %s)", key, prop.__class__.__name__)
+
+ if not isinstance(prop, MapperProperty):
+ prop = self._property_from_column(key, prop)
+
+ if isinstance(prop, properties.ColumnProperty):
+ col = self.persist_selectable.corresponding_column(prop.columns[0])
+
+ # if the column is not present in the mapped table,
+ # test if a column has been added after the fact to the
+ # parent table (or their parent, etc.) [ticket:1570]
+ if col is None and self.inherits:
+ path = [self]
+ for m in self.inherits.iterate_to_root():
+ col = m.local_table.corresponding_column(prop.columns[0])
+ if col is not None:
+ for m2 in path:
+ m2.persist_selectable._refresh_for_new_column(col)
+ col = self.persist_selectable.corresponding_column(
+ prop.columns[0]
+ )
+ break
+ path.append(m)
+
+ # subquery expression, column not present in the mapped
+ # selectable.
+ if col is None:
+ col = prop.columns[0]
+
+ # column is coming in after _readonly_props was
+ # initialized; check for 'readonly'
+ if hasattr(self, "_readonly_props") and (
+ not hasattr(col, "table")
+ or col.table not in self._cols_by_table
+ ):
+ self._readonly_props.add(prop)
+
+ else:
+ # if column is coming in after _cols_by_table was
+ # initialized, ensure the col is in the right set
+ if (
+ hasattr(self, "_cols_by_table")
+ and col.table in self._cols_by_table
+ and col not in self._cols_by_table[col.table]
+ ):
+ self._cols_by_table[col.table].add(col)
+
+ # if this properties.ColumnProperty represents the "polymorphic
+ # discriminator" column, mark it. We'll need this when rendering
+ # columns in SELECT statements.
+ if not hasattr(prop, "_is_polymorphic_discriminator"):
+ prop._is_polymorphic_discriminator = (
+ col is self.polymorphic_on
+ or prop.columns[0] is self.polymorphic_on
+ )
+
+ if isinstance(col, expression.Label):
+ # new in 1.4, get column property against expressions
+ # to be addressable in subqueries
+ col.key = col._tq_key_label = key
+
+ self.columns.add(col, key)
+ for col in prop.columns:
+ for col in col.proxy_set:
+ self._columntoproperty[col] = prop
+
+ prop.key = key
+
+ if setparent:
+ prop.set_parent(self, init)
+
+ if key in self._props and getattr(
+ self._props[key], "_mapped_by_synonym", False
+ ):
+ syn = self._props[key]._mapped_by_synonym
+ raise sa_exc.ArgumentError(
+ "Can't call map_column=True for synonym %r=%r, "
+ "a ColumnProperty already exists keyed to the name "
+ "%r for column %r" % (syn, key, key, syn)
+ )
+
+ if (
+ key in self._props
+ and not isinstance(prop, properties.ColumnProperty)
+ and not isinstance(
+ self._props[key],
+ (
+ properties.ColumnProperty,
+ descriptor_props.ConcreteInheritedProperty,
+ ),
+ )
+ ):
+ util.warn(
+ "Property %s on %s being replaced with new "
+ "property %s; the old property will be discarded"
+ % (self._props[key], self, prop)
+ )
+ oldprop = self._props[key]
+ self._path_registry.pop(oldprop, None)
+
+ self._props[key] = prop
+
+ if not self.non_primary:
+ prop.instrument_class(self)
+
+ for mapper in self._inheriting_mappers:
+ mapper._adapt_inherited_property(key, prop, init)
+
+ if init:
+ prop.init()
+ prop.post_instrument_class(self)
+
+ if self.configured:
+ self._expire_memoizations()
+
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def _property_from_column(self, key, prop):
+ """generate/update a :class:`.ColumnProperty` given a
+ :class:`_schema.Column` object."""
+ descriptor_props = util.preloaded.orm_descriptor_props
+ # we were passed a Column or a list of Columns;
+ # generate a properties.ColumnProperty
+ columns = util.to_list(prop)
+ column = columns[0]
+ assert isinstance(column, expression.ColumnElement)
+
+ prop = self._props.get(key, None)
+
+ if isinstance(prop, properties.ColumnProperty):
+ if (
+ (
+ not self._inherits_equated_pairs
+ or (prop.columns[0], column)
+ not in self._inherits_equated_pairs
+ )
+ and not prop.columns[0].shares_lineage(column)
+ and prop.columns[0] is not self.version_id_col
+ and column is not self.version_id_col
+ ):
+ warn_only = prop.parent is not self
+ msg = (
+ "Implicitly combining column %s with column "
+ "%s under attribute '%s'. Please configure one "
+ "or more attributes for these same-named columns "
+ "explicitly." % (prop.columns[-1], column, key)
+ )
+ if warn_only:
+ util.warn(msg)
+ else:
+ raise sa_exc.InvalidRequestError(msg)
+
+ # existing properties.ColumnProperty from an inheriting
+ # mapper. make a copy and append our column to it
+ prop = prop.copy()
+ prop.columns.insert(0, column)
+ self._log(
+ "inserting column to existing list "
+ "in properties.ColumnProperty %s" % (key)
+ )
+ return prop
+ elif prop is None or isinstance(
+ prop, descriptor_props.ConcreteInheritedProperty
+ ):
+ mapped_column = []
+ for c in columns:
+ mc = self.persist_selectable.corresponding_column(c)
+ if mc is None:
+ mc = self.local_table.corresponding_column(c)
+ if mc is not None:
+ # if the column is in the local table but not the
+ # mapped table, this corresponds to adding a
+ # column after the fact to the local table.
+ # [ticket:1523]
+ self.persist_selectable._refresh_for_new_column(mc)
+ mc = self.persist_selectable.corresponding_column(c)
+ if mc is None:
+ raise sa_exc.ArgumentError(
+ "When configuring property '%s' on %s, "
+ "column '%s' is not represented in the mapper's "
+ "table. Use the `column_property()` function to "
+ "force this column to be mapped as a read-only "
+ "attribute." % (key, self, c)
+ )
+ mapped_column.append(mc)
+ return properties.ColumnProperty(*mapped_column)
+ else:
+ raise sa_exc.ArgumentError(
+ "WARNING: when configuring property '%s' on %s, "
+ "column '%s' conflicts with property '%r'. "
+ "To resolve this, map the column to the class under a "
+ "different name in the 'properties' dictionary. Or, "
+ "to remove all awareness of the column entirely "
+ "(including its availability as a foreign key), "
+ "use the 'include_properties' or 'exclude_properties' "
+ "mapper arguments to control specifically which table "
+ "columns get mapped." % (key, self, column.key, prop)
+ )
+
+ def _check_configure(self):
+ if self.registry._new_mappers:
+ _configure_registries({self.registry}, cascade=True)
+
+ def _post_configure_properties(self):
+ """Call the ``init()`` method on all ``MapperProperties``
+ attached to this mapper.
+
+ This is a deferred configuration step which is intended
+ to execute once all mappers have been constructed.
+
+ """
+
+ self._log("_post_configure_properties() started")
+ l = [(key, prop) for key, prop in self._props.items()]
+ for key, prop in l:
+ self._log("initialize prop %s", key)
+
+ if prop.parent is self and not prop._configure_started:
+ prop.init()
+
+ if prop._configure_finished:
+ prop.post_instrument_class(self)
+
+ self._log("_post_configure_properties() complete")
+ self.configured = True
+
+ def add_properties(self, dict_of_properties):
+ """Add the given dictionary of properties to this mapper,
+ using `add_property`.
+
+ """
+ for key, value in dict_of_properties.items():
+ self.add_property(key, value)
+
+ def add_property(self, key, prop):
+ """Add an individual MapperProperty to this mapper.
+
+ If the mapper has not been configured yet, just adds the
+ property to the initial properties dictionary sent to the
+ constructor. If this Mapper has already been configured, then
+ the given MapperProperty is configured immediately.
+
+ """
+ self._init_properties[key] = prop
+ self._configure_property(key, prop, init=self.configured)
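+
+    # A minimal sketch of runtime property addition, assuming a mapped
+    # "User" class whose table has "firstname" and "lastname" columns
+    # (all names are illustrative):
+    #
+    #     from sqlalchemy.orm import column_property
+    #
+    #     User.__mapper__.add_property(
+    #         "fullname",
+    #         column_property(
+    #             User.__table__.c.firstname
+    #             + " "
+    #             + User.__table__.c.lastname
+    #         ),
+    #     )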
+
+ def _expire_memoizations(self):
+ for mapper in self.iterate_to_root():
+ mapper._reset_memoizations()
+
+ @property
+ def _log_desc(self):
+ return (
+ "("
+ + self.class_.__name__
+ + "|"
+ + (
+ self.local_table is not None
+ and self.local_table.description
+ or str(self.local_table)
+ )
+ + (self.non_primary and "|non-primary" or "")
+ + ")"
+ )
+
+ def _log(self, msg, *args):
+ self.logger.info("%s " + msg, *((self._log_desc,) + args))
+
+ def _log_debug(self, msg, *args):
+ self.logger.debug("%s " + msg, *((self._log_desc,) + args))
+
+ def __repr__(self):
+ return "<Mapper at 0x%x; %s>" % (id(self), self.class_.__name__)
+
+ def __str__(self):
+ return "mapped class %s%s->%s" % (
+ self.class_.__name__,
+ self.non_primary and " (non-primary)" or "",
+ self.local_table.description
+ if self.local_table is not None
+ else self.persist_selectable.description,
+ )
+
+ def _is_orphan(self, state):
+ orphan_possible = False
+ for mapper in self.iterate_to_root():
+ for (key, cls) in mapper._delete_orphans:
+ orphan_possible = True
+
+ has_parent = attributes.manager_of_class(cls).has_parent(
+ state, key, optimistic=state.has_identity
+ )
+
+ if self.legacy_is_orphan and has_parent:
+ return False
+ elif not self.legacy_is_orphan and not has_parent:
+ return True
+
+ if self.legacy_is_orphan:
+ return orphan_possible
+ else:
+ return False
+
+ def has_property(self, key):
+ return key in self._props
+
+ def get_property(self, key, _configure_mappers=True):
+ """return a MapperProperty associated with the given key."""
+
+ if _configure_mappers:
+ self._check_configure()
+
+ try:
+ return self._props[key]
+ except KeyError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Mapper '%s' has no property '%s'" % (self, key)
+ ),
+ replace_context=err,
+ )
+
+ def get_property_by_column(self, column):
+ """Given a :class:`_schema.Column` object, return the
+ :class:`.MapperProperty` which maps this column."""
+
+ return self._columntoproperty[column]
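+
+    # A brief sketch of property lookup, assuming a mapped "User" class
+    # with a "name" column (illustrative names):
+    #
+    #     from sqlalchemy import inspect
+    #
+    #     mapper = inspect(User)
+    #     prop = mapper.get_property("name")
+    #     same_prop = mapper.get_property_by_column(User.__table__.c.name)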
+
+ @property
+ def iterate_properties(self):
+ """return an iterator of all MapperProperty objects."""
+
+ self._check_configure()
+ return iter(self._props.values())
+
+ def _mappers_from_spec(self, spec, selectable):
+ """given a with_polymorphic() argument, return the set of mappers it
+ represents.
+
+ Trims the list of mappers to just those represented within the given
+ selectable, if present. This helps some more legacy-ish mappings.
+
+ """
+ if spec == "*":
+ mappers = list(self.self_and_descendants)
+ elif spec:
+ mappers = set()
+ for m in util.to_list(spec):
+ m = _class_to_mapper(m)
+ if not m.isa(self):
+ raise sa_exc.InvalidRequestError(
+ "%r does not inherit from %r" % (m, self)
+ )
+
+ if selectable is None:
+ mappers.update(m.iterate_to_root())
+ else:
+ mappers.add(m)
+ mappers = [m for m in self.self_and_descendants if m in mappers]
+ else:
+ mappers = []
+
+ if selectable is not None:
+ tables = set(
+ sql_util.find_tables(selectable, include_aliases=True)
+ )
+ mappers = [m for m in mappers if m.local_table in tables]
+ return mappers
+
+ def _selectable_from_mappers(self, mappers, innerjoin):
+ """given a list of mappers (assumed to be within this mapper's
+        inheritance hierarchy), construct an outerjoin amongst those mappers'
+ mapped tables.
+
+ """
+ from_obj = self.persist_selectable
+ for m in mappers:
+ if m is self:
+ continue
+ if m.concrete:
+ raise sa_exc.InvalidRequestError(
+ "'with_polymorphic()' requires 'selectable' argument "
+ "when concrete-inheriting mappers are used."
+ )
+ elif not m.single:
+ if innerjoin:
+ from_obj = from_obj.join(
+ m.local_table, m.inherit_condition
+ )
+ else:
+ from_obj = from_obj.outerjoin(
+ m.local_table, m.inherit_condition
+ )
+
+ return from_obj
+
+ @HasMemoized.memoized_attribute
+ def _single_table_criterion(self):
+ if self.single and self.inherits and self.polymorphic_on is not None:
+ return self.polymorphic_on._annotate(
+ {"parententity": self, "parentmapper": self}
+ ).in_(m.polymorphic_identity for m in self.self_and_descendants)
+ else:
+ return None
+
+ @HasMemoized.memoized_attribute
+ def _with_polymorphic_mappers(self):
+ self._check_configure()
+
+ if not self.with_polymorphic:
+ return []
+ return self._mappers_from_spec(*self.with_polymorphic)
+
+ @HasMemoized.memoized_attribute
+ def _post_inspect(self):
+ """This hook is invoked by attribute inspection.
+
+ E.g. when Query calls:
+
+ coercions.expect(roles.ColumnsClauseRole, ent, keep_inspect=True)
+
+        This allows the inspection process to run the configure-mappers hook.
+
+ """
+ self._check_configure()
+
+ @HasMemoized.memoized_attribute
+ def _with_polymorphic_selectable(self):
+ if not self.with_polymorphic:
+ return self.persist_selectable
+
+ spec, selectable = self.with_polymorphic
+ if selectable is not None:
+ return selectable
+ else:
+ return self._selectable_from_mappers(
+ self._mappers_from_spec(spec, selectable), False
+ )
+
+ with_polymorphic_mappers = _with_polymorphic_mappers
+ """The list of :class:`_orm.Mapper` objects included in the
+ default "polymorphic" query.
+
+ """
+
+ @HasMemoized.memoized_attribute
+ def _insert_cols_evaluating_none(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ col for col in columns if col.type.should_evaluate_none
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _insert_cols_as_none(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ col.key
+ for col in columns
+ if not col.primary_key
+ and not col.server_default
+ and not col.default
+ and not col.type.should_evaluate_none
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _propkey_to_col(self):
+ return dict(
+ (
+ table,
+ dict(
+ (self._columntoproperty[col].key, col) for col in columns
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _pk_keys_by_table(self):
+ return dict(
+ (table, frozenset([col.key for col in pks]))
+ for table, pks in self._pks_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _pk_attr_keys_by_table(self):
+ return dict(
+ (
+ table,
+ frozenset([self._columntoproperty[col].key for col in pks]),
+ )
+ for table, pks in self._pks_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _server_default_cols(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ [
+ col.key
+ for col in columns
+ if col.server_default is not None
+ ]
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_attribute
+ def _server_default_plus_onupdate_propkeys(self):
+ result = set()
+
+ for table, columns in self._cols_by_table.items():
+ for col in columns:
+ if (
+ col.server_default is not None
+ or col.server_onupdate is not None
+ ) and col in self._columntoproperty:
+ result.add(self._columntoproperty[col].key)
+
+ return result
+
+ @HasMemoized.memoized_attribute
+ def _server_onupdate_default_cols(self):
+ return dict(
+ (
+ table,
+ frozenset(
+ [
+ col.key
+ for col in columns
+ if col.server_onupdate is not None
+ ]
+ ),
+ )
+ for table, columns in self._cols_by_table.items()
+ )
+
+ @HasMemoized.memoized_instancemethod
+ def __clause_element__(self):
+
+ annotations = {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ }
+ if self.persist_selectable is not self.local_table:
+ # joined table inheritance, with polymorphic selectable,
+ # etc.
+ annotations["dml_table"] = self.local_table._annotate(
+ {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ }
+ )._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ return self.selectable._annotate(annotations)._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ @util.memoized_property
+ def select_identity_token(self):
+ return (
+ expression.null()
+ ._annotate(
+ {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ "identity_token": True,
+ }
+ )
+ ._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+ )
+
+ @property
+ def selectable(self):
+ """The :class:`_schema.FromClause` construct this
+ :class:`_orm.Mapper` selects from by default.
+
+ Normally, this is equivalent to :attr:`.persist_selectable`, unless
+ the ``with_polymorphic`` feature is in use, in which case the
+ full "polymorphic" selectable is returned.
+
+ """
+ return self._with_polymorphic_selectable
+
+ def _with_polymorphic_args(
+ self, spec=None, selectable=False, innerjoin=False
+ ):
+ if selectable not in (None, False):
+ selectable = coercions.expect(
+ roles.StrictFromClauseRole, selectable, allow_select=True
+ )
+
+ if self.with_polymorphic:
+ if not spec:
+ spec = self.with_polymorphic[0]
+ if selectable is False:
+ selectable = self.with_polymorphic[1]
+ elif selectable is False:
+ selectable = None
+ mappers = self._mappers_from_spec(spec, selectable)
+ if selectable is not None:
+ return mappers, selectable
+ else:
+ return mappers, self._selectable_from_mappers(mappers, innerjoin)
+
+ @HasMemoized.memoized_attribute
+ def _polymorphic_properties(self):
+ return list(
+ self._iterate_polymorphic_properties(
+ self._with_polymorphic_mappers
+ )
+ )
+
+ @property
+ def _all_column_expressions(self):
+ poly_properties = self._polymorphic_properties
+ adapter = self._polymorphic_adapter
+
+ return [
+ adapter.columns[prop.columns[0]] if adapter else prop.columns[0]
+ for prop in poly_properties
+ if isinstance(prop, properties.ColumnProperty)
+ and prop._renders_in_subqueries
+ ]
+
+ def _columns_plus_keys(self, polymorphic_mappers=()):
+ if polymorphic_mappers:
+ poly_properties = self._iterate_polymorphic_properties(
+ polymorphic_mappers
+ )
+ else:
+ poly_properties = self._polymorphic_properties
+
+ return [
+ (prop.key, prop.columns[0])
+ for prop in poly_properties
+ if isinstance(prop, properties.ColumnProperty)
+ ]
+
+ @HasMemoized.memoized_attribute
+ def _polymorphic_adapter(self):
+ if self.with_polymorphic:
+ return sql_util.ColumnAdapter(
+ self.selectable, equivalents=self._equivalent_columns
+ )
+ else:
+ return None
+
+ def _iterate_polymorphic_properties(self, mappers=None):
+ """Return an iterator of MapperProperty objects which will render into
+ a SELECT."""
+ if mappers is None:
+ mappers = self._with_polymorphic_mappers
+
+ if not mappers:
+ for c in self.iterate_properties:
+ yield c
+ else:
+ # in the polymorphic case, filter out discriminator columns
+ # from other mappers, as these are sometimes dependent on that
+ # mapper's polymorphic selectable (which we don't want rendered)
+ for c in util.unique_list(
+ chain(
+ *[
+ list(mapper.iterate_properties)
+ for mapper in [self] + mappers
+ ]
+ )
+ ):
+ if getattr(c, "_is_polymorphic_discriminator", False) and (
+ self.polymorphic_on is None
+ or c.columns[0] is not self.polymorphic_on
+ ):
+ continue
+ yield c
+
+ @HasMemoized.memoized_attribute
+ def attrs(self):
+ """A namespace of all :class:`.MapperProperty` objects
+        associated with this mapper.
+
+ This is an object that provides each property based on
+ its key name. For instance, the mapper for a
+ ``User`` class which has ``User.name`` attribute would
+ provide ``mapper.attrs.name``, which would be the
+ :class:`.ColumnProperty` representing the ``name``
+ column. The namespace object can also be iterated,
+ which would yield each :class:`.MapperProperty`.
+
+ :class:`_orm.Mapper` has several pre-filtered views
+ of this attribute which limit the types of properties
+ returned, including :attr:`.synonyms`, :attr:`.column_attrs`,
+ :attr:`.relationships`, and :attr:`.composites`.
+
+ .. warning::
+
+ The :attr:`_orm.Mapper.attrs` accessor namespace is an
+ instance of :class:`.OrderedProperties`. This is
+ a dictionary-like object which includes a small number of
+ named methods such as :meth:`.OrderedProperties.items`
+ and :meth:`.OrderedProperties.values`. When
+ accessing attributes dynamically, favor using the dict-access
+ scheme, e.g. ``mapper.attrs[somename]`` over
+ ``getattr(mapper.attrs, somename)`` to avoid name collisions.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.all_orm_descriptors`
+
+ """
+
+ self._check_configure()
+ return util.ImmutableProperties(self._props)
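+
+    # A short sketch of the attrs namespace, assuming a mapped "User"
+    # class with a "name" column attribute (illustrative names):
+    #
+    #     from sqlalchemy import inspect
+    #
+    #     mapper = inspect(User)
+    #     name_prop = mapper.attrs["name"]   # ColumnProperty for "name"
+    #     for prop in mapper.attrs:          # iterate all MapperProperty
+    #         print(prop.key)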
+
+ @HasMemoized.memoized_attribute
+ def all_orm_descriptors(self):
+ """A namespace of all :class:`.InspectionAttr` attributes associated
+ with the mapped class.
+
+ These attributes are in all cases Python :term:`descriptors`
+ associated with the mapped class or its superclasses.
+
+ This namespace includes attributes that are mapped to the class
+ as well as attributes declared by extension modules.
+ It includes any Python descriptor type that inherits from
+ :class:`.InspectionAttr`. This includes
+ :class:`.QueryableAttribute`, as well as extension types such as
+ :class:`.hybrid_property`, :class:`.hybrid_method` and
+ :class:`.AssociationProxy`.
+
+ To distinguish between mapped attributes and extension attributes,
+ the attribute :attr:`.InspectionAttr.extension_type` will refer
+ to a constant that distinguishes between different extension types.
+
+ The sorting of the attributes is based on the following rules:
+
+ 1. Iterate through the class and its superclasses in order from
+ subclass to superclass (i.e. iterate through ``cls.__mro__``)
+
+ 2. For each class, yield the attributes in the order in which they
+ appear in ``__dict__``, with the exception of those in step
+ 3 below. In Python 3.6 and above this ordering will be the
+ same as that of the class' construction, with the exception
+ of attributes that were added after the fact by the application
+ or the mapper.
+
+ 3. If a certain attribute key is also in the superclass ``__dict__``,
+ then it's included in the iteration for that class, and not the
+ class in which it first appeared.
+
+ The above process produces an ordering that is deterministic in terms
+ of the order in which attributes were assigned to the class.
+
+ .. versionchanged:: 1.3.19 ensured deterministic ordering for
+ :meth:`_orm.Mapper.all_orm_descriptors`.
+
+ When dealing with a :class:`.QueryableAttribute`, the
+ :attr:`.QueryableAttribute.property` attribute refers to the
+ :class:`.MapperProperty` property, which is what you get when
+ referring to the collection of mapped properties via
+ :attr:`_orm.Mapper.attrs`.
+
+ .. warning::
+
+ The :attr:`_orm.Mapper.all_orm_descriptors`
+ accessor namespace is an
+ instance of :class:`.OrderedProperties`. This is
+ a dictionary-like object which includes a small number of
+ named methods such as :meth:`.OrderedProperties.items`
+ and :meth:`.OrderedProperties.values`. When
+ accessing attributes dynamically, favor using the dict-access
+ scheme, e.g. ``mapper.all_orm_descriptors[somename]`` over
+ ``getattr(mapper.all_orm_descriptors, somename)`` to avoid name
+ collisions.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs`
+
+ """
+ return util.ImmutableProperties(
+ dict(self.class_manager._all_sqla_attributes())
+ )
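+
+    # A sketch of separating extension descriptors from mapped attributes,
+    # assuming a mapped "User" class that also defines a hybrid_property
+    # (illustrative only):
+    #
+    #     from sqlalchemy import inspect
+    #     from sqlalchemy.ext.hybrid import HYBRID_PROPERTY
+    #
+    #     for key, descriptor in inspect(User).all_orm_descriptors.items():
+    #         if descriptor.extension_type is HYBRID_PROPERTY:
+    #             print("hybrid attribute:", key)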
+
+ @HasMemoized.memoized_attribute
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def synonyms(self):
+ """Return a namespace of all :class:`.SynonymProperty`
+ properties maintained by this :class:`_orm.Mapper`.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ descriptor_props = util.preloaded.orm_descriptor_props
+
+ return self._filter_properties(descriptor_props.SynonymProperty)
+
+ @property
+ def entity_namespace(self):
+ return self.class_
+
+ @HasMemoized.memoized_attribute
+ def column_attrs(self):
+ """Return a namespace of all :class:`.ColumnProperty`
+ properties maintained by this :class:`_orm.Mapper`.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ return self._filter_properties(properties.ColumnProperty)
+
+ @util.preload_module("sqlalchemy.orm.relationships")
+ @HasMemoized.memoized_attribute
+ def relationships(self):
+ """A namespace of all :class:`.RelationshipProperty` properties
+ maintained by this :class:`_orm.Mapper`.
+
+ .. warning::
+
+ the :attr:`_orm.Mapper.relationships` accessor namespace is an
+ instance of :class:`.OrderedProperties`. This is
+ a dictionary-like object which includes a small number of
+ named methods such as :meth:`.OrderedProperties.items`
+ and :meth:`.OrderedProperties.values`. When
+ accessing attributes dynamically, favor using the dict-access
+ scheme, e.g. ``mapper.relationships[somename]`` over
+ ``getattr(mapper.relationships, somename)`` to avoid name
+ collisions.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ return self._filter_properties(
+ util.preloaded.orm_relationships.RelationshipProperty
+ )
+
+ @HasMemoized.memoized_attribute
+ @util.preload_module("sqlalchemy.orm.descriptor_props")
+ def composites(self):
+ """Return a namespace of all :class:`.CompositeProperty`
+ properties maintained by this :class:`_orm.Mapper`.
+
+ .. seealso::
+
+ :attr:`_orm.Mapper.attrs` - namespace of all
+ :class:`.MapperProperty`
+ objects.
+
+ """
+ return self._filter_properties(
+ util.preloaded.orm_descriptor_props.CompositeProperty
+ )
+
+ def _filter_properties(self, type_):
+ self._check_configure()
+ return util.ImmutableProperties(
+ util.OrderedDict(
+ (k, v) for k, v in self._props.items() if isinstance(v, type_)
+ )
+ )
+
+ @HasMemoized.memoized_attribute
+ def _get_clause(self):
+ """create a "get clause" based on the primary key. this is used
+ by query.get() and many-to-one lazyloads to load this item
+ by primary key.
+
+ """
+ params = [
+ (
+ primary_key,
+ sql.bindparam("pk_%d" % idx, type_=primary_key.type),
+ )
+ for idx, primary_key in enumerate(self.primary_key, 1)
+ ]
+ return (
+ sql.and_(*[k == v for (k, v) in params]),
+ util.column_dict(params),
+ )
+
+ @HasMemoized.memoized_attribute
+ def _equivalent_columns(self):
+ """Create a map of all equivalent columns, based on
+ the determination of column pairs that are equated to
+ one another based on inherit condition. This is designed
+ to work with the queries that util.polymorphic_union
+ comes up with, which often don't include the columns from
+ the base table directly (including the subclass table columns
+ only).
+
+ The resulting structure is a dictionary of columns mapped
+ to lists of equivalent columns, e.g.::
+
+ {
+ tablea.col1:
+ {tableb.col1, tablec.col1},
+ tablea.col2:
+ {tabled.col2}
+ }
+
+ """
+ result = util.column_dict()
+
+ def visit_binary(binary):
+ if binary.operator == operators.eq:
+ if binary.left in result:
+ result[binary.left].add(binary.right)
+ else:
+ result[binary.left] = util.column_set((binary.right,))
+ if binary.right in result:
+ result[binary.right].add(binary.left)
+ else:
+ result[binary.right] = util.column_set((binary.left,))
+
+ for mapper in self.base_mapper.self_and_descendants:
+ if mapper.inherit_condition is not None:
+ visitors.traverse(
+ mapper.inherit_condition, {}, {"binary": visit_binary}
+ )
+
+ return result
+
+ def _is_userland_descriptor(self, assigned_name, obj):
+ if isinstance(
+ obj,
+ (
+ _MappedAttribute,
+ instrumentation.ClassManager,
+ expression.ColumnElement,
+ ),
+ ):
+ return False
+ else:
+ return assigned_name not in self._dataclass_fields
+
+ @HasMemoized.memoized_attribute
+ def _dataclass_fields(self):
+ return [f.name for f in util.dataclass_fields(self.class_)]
+
+ def _should_exclude(self, name, assigned_name, local, column):
+ """determine whether a particular property should be implicitly
+ present on the class.
+
+ This occurs when properties are propagated from an inherited class, or
+ are applied from the columns present in the mapped table.
+
+ """
+
+ # check for class-bound attributes and/or descriptors,
+ # either local or from an inherited class
+ # ignore dataclass field default values
+ if local:
+ if self.class_.__dict__.get(
+ assigned_name, None
+ ) is not None and self._is_userland_descriptor(
+ assigned_name, self.class_.__dict__[assigned_name]
+ ):
+ return True
+ else:
+ attr = self.class_manager._get_class_attr_mro(assigned_name, None)
+ if attr is not None and self._is_userland_descriptor(
+ assigned_name, attr
+ ):
+ return True
+
+ if (
+ self.include_properties is not None
+ and name not in self.include_properties
+ and (column is None or column not in self.include_properties)
+ ):
+ self._log("not including property %s" % (name))
+ return True
+
+ if self.exclude_properties is not None and (
+ name in self.exclude_properties
+ or (column is not None and column in self.exclude_properties)
+ ):
+ self._log("excluding property %s" % (name))
+ return True
+
+ return False
+
+ def common_parent(self, other):
+ """Return true if the given mapper shares a
+        common inherited parent with this mapper."""
+
+ return self.base_mapper is other.base_mapper
+
+ def is_sibling(self, other):
+ """return true if the other mapper is an inheriting sibling to this
+ one. common parent but different branch
+
+ """
+ return (
+ self.base_mapper is other.base_mapper
+ and not self.isa(other)
+ and not other.isa(self)
+ )
+
+ def _canload(self, state, allow_subtypes):
+ s = self.primary_mapper()
+ if self.polymorphic_on is not None or allow_subtypes:
+ return _state_mapper(state).isa(s)
+ else:
+ return _state_mapper(state) is s
+
+ def isa(self, other):
+ """Return True if the this mapper inherits from the given mapper."""
+
+ m = self
+ while m and m is not other:
+ m = m.inherits
+ return bool(m)
+
+ def iterate_to_root(self):
+ m = self
+ while m:
+ yield m
+ m = m.inherits
+
+ @HasMemoized.memoized_attribute
+ def self_and_descendants(self):
+ """The collection including this mapper and all descendant mappers.
+
+ This includes not just the immediately inheriting mappers but
+ all their inheriting mappers as well.
+
+ """
+ descendants = []
+ stack = deque([self])
+ while stack:
+ item = stack.popleft()
+ descendants.append(item)
+ stack.extend(item._inheriting_mappers)
+ return util.WeakSequence(descendants)
+
+ def polymorphic_iterator(self):
+ """Iterate through the collection including this mapper and
+ all descendant mappers.
+
+ This includes not just the immediately inheriting mappers but
+ all their inheriting mappers as well.
+
+ To iterate through an entire hierarchy, use
+ ``mapper.base_mapper.polymorphic_iterator()``.
+
+ """
+ return iter(self.self_and_descendants)
+
+ def primary_mapper(self):
+ """Return the primary mapper corresponding to this mapper's class key
+ (class)."""
+
+ return self.class_manager.mapper
+
+ @property
+ def primary_base_mapper(self):
+ return self.class_manager.mapper.base_mapper
+
+ def _result_has_identity_key(self, result, adapter=None):
+ pk_cols = self.primary_key
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+ rk = result.keys()
+ for col in pk_cols:
+ if col not in rk:
+ return False
+ else:
+ return True
+
+ def identity_key_from_row(self, row, identity_token=None, adapter=None):
+ """Return an identity-map key for use in storing/retrieving an
+ item from the identity map.
+
+ :param row: A :class:`.Row` instance. The columns which are
+ mapped by this :class:`_orm.Mapper` should be locatable in the row,
+ preferably via the :class:`_schema.Column`
+ object directly (as is the case
+ when a :func:`_expression.select` construct is executed), or
+ via string names of the form ``<tablename>_<colname>``.
+
+ """
+ pk_cols = self.primary_key
+ if adapter:
+ pk_cols = [adapter.columns[c] for c in pk_cols]
+
+ return (
+ self._identity_class,
+ tuple(row[column] for column in pk_cols),
+ identity_token,
+ )
+
+ def identity_key_from_primary_key(self, primary_key, identity_token=None):
+ """Return an identity-map key for use in storing/retrieving an
+ item from an identity map.
+
+ :param primary_key: A list of values indicating the identifier.
+
+ """
+ return self._identity_class, tuple(primary_key), identity_token
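+
+    # A small sketch of identity-key construction, assuming a mapped
+    # "User" class with a single integer primary key (illustrative):
+    #
+    #     from sqlalchemy import inspect
+    #
+    #     key = inspect(User).identity_key_from_primary_key((5,))
+    #     # typically of the form (User, (5,), None)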
+
+ def identity_key_from_instance(self, instance):
+ """Return the identity key for the given instance, based on
+ its primary key attributes.
+
+ If the instance's state is expired, calling this method
+ will result in a database check to see if the object has been deleted.
+ If the row no longer exists,
+ :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ This value is typically also found on the instance state under the
+ attribute name `key`.
+
+ """
+ state = attributes.instance_state(instance)
+ return self._identity_key_from_state(state, attributes.PASSIVE_OFF)
+
+ def _identity_key_from_state(
+ self, state, passive=attributes.PASSIVE_RETURN_NO_VALUE
+ ):
+ dict_ = state.dict
+ manager = state.manager
+ return (
+ self._identity_class,
+ tuple(
+ [
+ manager[prop.key].impl.get(state, dict_, passive)
+ for prop in self._identity_key_props
+ ]
+ ),
+ state.identity_token,
+ )
+
+ def primary_key_from_instance(self, instance):
+ """Return the list of primary key values for the given
+ instance.
+
+ If the instance's state is expired, calling this method
+ will result in a database check to see if the object has been deleted.
+ If the row no longer exists,
+ :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ """
+ state = attributes.instance_state(instance)
+ identity_key = self._identity_key_from_state(
+ state, attributes.PASSIVE_OFF
+ )
+ return identity_key[1]
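+
+    # A sketch of reading key values from a loaded object, assuming
+    # "some_user" is a persistent instance of a mapped "User" class:
+    #
+    #     from sqlalchemy import inspect
+    #
+    #     pk_values = inspect(User).primary_key_from_instance(some_user)
+    #     identity = inspect(User).identity_key_from_instance(some_user)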
+
+ @HasMemoized.memoized_attribute
+ def _persistent_sortkey_fn(self):
+ key_fns = [col.type.sort_key_function for col in self.primary_key]
+
+ if set(key_fns).difference([None]):
+
+ def key(state):
+ return tuple(
+ key_fn(val) if key_fn is not None else val
+ for key_fn, val in zip(key_fns, state.key[1])
+ )
+
+ else:
+
+ def key(state):
+ return state.key[1]
+
+ return key
+
+ @HasMemoized.memoized_attribute
+ def _identity_key_props(self):
+ return [self._columntoproperty[col] for col in self.primary_key]
+
+ @HasMemoized.memoized_attribute
+ def _all_pk_cols(self):
+ collection = set()
+ for table in self.tables:
+ collection.update(self._pks_by_table[table])
+ return collection
+
+ @HasMemoized.memoized_attribute
+ def _should_undefer_in_wildcard(self):
+ cols = set(self.primary_key)
+ if self.polymorphic_on is not None:
+ cols.add(self.polymorphic_on)
+ return cols
+
+ @HasMemoized.memoized_attribute
+ def _primary_key_propkeys(self):
+ return {self._columntoproperty[col].key for col in self._all_pk_cols}
+
+ def _get_state_attr_by_column(
+ self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
+ ):
+ prop = self._columntoproperty[column]
+ return state.manager[prop.key].impl.get(state, dict_, passive=passive)
+
+ def _set_committed_state_attr_by_column(self, state, dict_, column, value):
+ prop = self._columntoproperty[column]
+ state.manager[prop.key].impl.set_committed_value(state, dict_, value)
+
+ def _set_state_attr_by_column(self, state, dict_, column, value):
+ prop = self._columntoproperty[column]
+ state.manager[prop.key].impl.set(state, dict_, value, None)
+
+ def _get_committed_attr_by_column(self, obj, column):
+ state = attributes.instance_state(obj)
+ dict_ = attributes.instance_dict(obj)
+ return self._get_committed_state_attr_by_column(
+ state, dict_, column, passive=attributes.PASSIVE_OFF
+ )
+
+ def _get_committed_state_attr_by_column(
+ self, state, dict_, column, passive=attributes.PASSIVE_RETURN_NO_VALUE
+ ):
+
+ prop = self._columntoproperty[column]
+ return state.manager[prop.key].impl.get_committed_value(
+ state, dict_, passive=passive
+ )
+
+ def _optimized_get_statement(self, state, attribute_names):
+ """assemble a WHERE clause which retrieves a given state by primary
+ key, using a minimized set of tables.
+
+ Applies to a joined-table inheritance mapper where the
+ requested attribute names are only present on joined tables,
+ not the base table. The WHERE clause attempts to include
+ only those tables to minimize joins.
+
+ """
+ props = self._props
+
+ col_attribute_names = set(attribute_names).intersection(
+ state.mapper.column_attrs.keys()
+ )
+ tables = set(
+ chain(
+ *[
+ sql_util.find_tables(c, check_columns=True)
+ for key in col_attribute_names
+ for c in props[key].columns
+ ]
+ )
+ )
+
+ if self.base_mapper.local_table in tables:
+ return None
+
+ def visit_binary(binary):
+ leftcol = binary.left
+ rightcol = binary.right
+ if leftcol is None or rightcol is None:
+ return
+
+ if leftcol.table not in tables:
+ leftval = self._get_committed_state_attr_by_column(
+ state,
+ state.dict,
+ leftcol,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+ if leftval in orm_util._none_set:
+ raise _OptGetColumnsNotAvailable()
+ binary.left = sql.bindparam(
+ None, leftval, type_=binary.right.type
+ )
+ elif rightcol.table not in tables:
+ rightval = self._get_committed_state_attr_by_column(
+ state,
+ state.dict,
+ rightcol,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+ if rightval in orm_util._none_set:
+ raise _OptGetColumnsNotAvailable()
+ binary.right = sql.bindparam(
+ None, rightval, type_=binary.right.type
+ )
+
+ allconds = []
+
+ start = False
+
+ # as of #7507, from the lowest base table on upwards,
+ # we include all intermediary tables.
+
+ for mapper in reversed(list(self.iterate_to_root())):
+ if mapper.local_table in tables:
+ start = True
+ elif not isinstance(mapper.local_table, expression.TableClause):
+ return None
+ if start and not mapper.single:
+ allconds.append(mapper.inherit_condition)
+ tables.add(mapper.local_table)
+
+ # only the bottom table needs its criteria to be altered to fit
+ # the primary key ident - the rest of the tables upwards to the
+ # descendant-most class should all be present and joined to each
+ # other.
+ try:
+ allconds[0] = visitors.cloned_traverse(
+ allconds[0], {}, {"binary": visit_binary}
+ )
+ except _OptGetColumnsNotAvailable:
+ return None
+
+ cond = sql.and_(*allconds)
+
+ cols = []
+ for key in col_attribute_names:
+ cols.extend(props[key].columns)
+ return (
+ sql.select(*cols)
+ .where(cond)
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ )
+
+ def _iterate_to_target_viawpoly(self, mapper):
+ if self.isa(mapper):
+ prev = self
+ for m in self.iterate_to_root():
+ yield m
+
+ if m is not prev and prev not in m._with_polymorphic_mappers:
+ break
+
+ prev = m
+ if m is mapper:
+ break
+
+ def _should_selectin_load(self, enabled_via_opt, polymorphic_from):
+ if not enabled_via_opt:
+ # common case, takes place for all polymorphic loads
+ mapper = polymorphic_from
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if m.polymorphic_load == "selectin":
+ return m
+ else:
+ # uncommon case, selectin load options were used
+ enabled_via_opt = set(enabled_via_opt)
+ enabled_via_opt_mappers = {e.mapper: e for e in enabled_via_opt}
+ for entity in enabled_via_opt.union([polymorphic_from]):
+ mapper = entity.mapper
+ for m in self._iterate_to_target_viawpoly(mapper):
+ if (
+ m.polymorphic_load == "selectin"
+ or m in enabled_via_opt_mappers
+ ):
+ return enabled_via_opt_mappers.get(m, m)
+
+ return None
+
+ @util.preload_module("sqlalchemy.orm.strategy_options")
+ def _subclass_load_via_in(self, entity):
+ """Assemble a that can load the columns local to
+ this subclass as a SELECT with IN.
+
+ """
+ strategy_options = util.preloaded.orm_strategy_options
+
+ assert self.inherits
+
+ if self.polymorphic_on is not None:
+ polymorphic_prop = self._columntoproperty[self.polymorphic_on]
+ keep_props = set([polymorphic_prop] + self._identity_key_props)
+ else:
+ keep_props = set(self._identity_key_props)
+
+ disable_opt = strategy_options.Load(entity)
+ enable_opt = strategy_options.Load(entity)
+
+ for prop in self.attrs:
+ if prop.parent is self or prop in keep_props:
+ # "enable" options, to turn on the properties that we want to
+ # load by default (subject to options from the query)
+ if not isinstance(prop, StrategizedProperty):
+ continue
+
+ enable_opt.set_generic_strategy(
+ # convert string name to an attribute before passing
+ # to loader strategy
+ (getattr(entity.entity_namespace, prop.key),),
+ dict(prop.strategy_key),
+ )
+ else:
+ # "disable" options, to turn off the properties from the
+ # superclass that we *don't* want to load, applied after
+ # the options from the query to override them
+ disable_opt.set_generic_strategy(
+ # convert string name to an attribute before passing
+ # to loader strategy
+ (getattr(entity.entity_namespace, prop.key),),
+ {"do_nothing": True},
+ )
+
+ primary_key = [
+ sql_util._deep_annotate(pk, {"_orm_adapt": True})
+ for pk in self.primary_key
+ ]
+
+ if len(primary_key) > 1:
+ in_expr = sql.tuple_(*primary_key)
+ else:
+ in_expr = primary_key[0]
+
+ if entity.is_aliased_class:
+ assert entity.mapper is self
+
+ q = sql.select(entity).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+
+ in_expr = entity._adapter.traverse(in_expr)
+ primary_key = [entity._adapter.traverse(k) for k in primary_key]
+ q = q.where(
+ in_expr.in_(sql.bindparam("primary_keys", expanding=True))
+ ).order_by(*primary_key)
+ else:
+
+ q = sql.select(self).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ q = q.where(
+ in_expr.in_(sql.bindparam("primary_keys", expanding=True))
+ ).order_by(*primary_key)
+
+ return q, enable_opt, disable_opt
+
+ @HasMemoized.memoized_attribute
+ def _subclass_load_via_in_mapper(self):
+ return self._subclass_load_via_in(self)
+
+ def cascade_iterator(self, type_, state, halt_on=None):
+ r"""Iterate each element and its mapper in an object graph,
+ for all relationships that meet the given cascade rule.
+
+ :param type\_:
+ The name of the cascade rule (i.e. ``"save-update"``, ``"delete"``,
+ etc.).
+
+ .. note:: the ``"all"`` cascade is not accepted here. For a generic
+ object traversal function, see :ref:`faq_walk_objects`.
+
+ :param state:
+          The lead InstanceState. Child items will be processed per
+ the relationships defined for this object's mapper.
+
+ :return: the method yields individual object instances.
+
+ .. seealso::
+
+ :ref:`unitofwork_cascades`
+
+ :ref:`faq_walk_objects` - illustrates a generic function to
+ traverse all objects without relying on cascades.
+
+ """
+ visited_states = set()
+ prp, mpp = object(), object()
+
+ assert state.mapper.isa(self)
+
+ visitables = deque(
+ [(deque(state.mapper._props.values()), prp, state, state.dict)]
+ )
+
+ while visitables:
+ iterator, item_type, parent_state, parent_dict = visitables[-1]
+ if not iterator:
+ visitables.pop()
+ continue
+
+ if item_type is prp:
+ prop = iterator.popleft()
+ if type_ not in prop.cascade:
+ continue
+ queue = deque(
+ prop.cascade_iterator(
+ type_,
+ parent_state,
+ parent_dict,
+ visited_states,
+ halt_on,
+ )
+ )
+ if queue:
+ visitables.append((queue, mpp, None, None))
+ elif item_type is mpp:
+ (
+ instance,
+ instance_mapper,
+ corresponding_state,
+ corresponding_dict,
+ ) = iterator.popleft()
+ yield (
+ instance,
+ instance_mapper,
+ corresponding_state,
+ corresponding_dict,
+ )
+ visitables.append(
+ (
+ deque(instance_mapper._props.values()),
+ prp,
+ corresponding_state,
+ corresponding_dict,
+ )
+ )
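+
+    # A sketch of cascade traversal, assuming "some_user" is a persistent
+    # instance of a mapped "User" class with related collections:
+    #
+    #     from sqlalchemy import inspect
+    #
+    #     state = inspect(some_user)
+    #     for obj, obj_mapper, obj_state, obj_dict in (
+    #         state.mapper.cascade_iterator("save-update", state)
+    #     ):
+    #         print("cascades to:", obj)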
+
+ @HasMemoized.memoized_attribute
+ def _compiled_cache(self):
+ return util.LRUCache(self._compiled_cache_size)
+
+ @HasMemoized.memoized_attribute
+ def _sorted_tables(self):
+ table_to_mapper = {}
+
+ for mapper in self.base_mapper.self_and_descendants:
+ for t in mapper.tables:
+ table_to_mapper.setdefault(t, mapper)
+
+ extra_dependencies = []
+ for table, mapper in table_to_mapper.items():
+ super_ = mapper.inherits
+ if super_:
+ extra_dependencies.extend(
+ [(super_table, table) for super_table in super_.tables]
+ )
+
+ def skip(fk):
+ # attempt to skip dependencies that are not
+ # significant to the inheritance chain
+ # for two tables that are related by inheritance.
+ # while that dependency may be important, it's technically
+ # not what we mean to sort on here.
+ parent = table_to_mapper.get(fk.parent.table)
+ dep = table_to_mapper.get(fk.column.table)
+ if (
+ parent is not None
+ and dep is not None
+ and dep is not parent
+ and dep.inherit_condition is not None
+ ):
+ cols = set(sql_util._find_columns(dep.inherit_condition))
+ if parent.inherit_condition is not None:
+ cols = cols.union(
+ sql_util._find_columns(parent.inherit_condition)
+ )
+ return fk.parent not in cols and fk.column not in cols
+ else:
+ return fk.parent not in cols
+ return False
+
+ sorted_ = sql_util.sort_tables(
+ table_to_mapper,
+ skip_fn=skip,
+ extra_dependencies=extra_dependencies,
+ )
+
+ ret = util.OrderedDict()
+ for t in sorted_:
+ ret[t] = table_to_mapper[t]
+ return ret
+
+ def _memo(self, key, callable_):
+ if key in self._memoized_values:
+ return self._memoized_values[key]
+ else:
+ self._memoized_values[key] = value = callable_()
+ return value
+
+ @util.memoized_property
+ def _table_to_equated(self):
+ """memoized map of tables to collections of columns to be
+ synchronized upwards to the base mapper."""
+
+ result = util.defaultdict(list)
+
+ for table in self._sorted_tables:
+ cols = set(table.c)
+ for m in self.iterate_to_root():
+ if m._inherits_equated_pairs and cols.intersection(
+ util.reduce(
+ set.union,
+ [l.proxy_set for l, r in m._inherits_equated_pairs],
+ )
+ ):
+ result[table].append((m, m._inherits_equated_pairs))
+
+ return result
+
+
+class _OptGetColumnsNotAvailable(Exception):
+ pass
+
+
+def configure_mappers():
+ """Initialize the inter-mapper relationships of all mappers that
+ have been constructed thus far across all :class:`_orm.registry`
+ collections.
+
+ The configure step is used to reconcile and initialize the
+ :func:`_orm.relationship` linkages between mapped classes, as well as to
+ invoke configuration events such as the
+ :meth:`_orm.MapperEvents.before_configured` and
+ :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM
+ extensions or user-defined extension hooks.
+
+ Mapper configuration is normally invoked automatically, the first time
+ mappings from a particular :class:`_orm.registry` are used, as well as
+ whenever mappings are used and additional not-yet-configured mappers have
+ been constructed. The automatic configuration process however is local only
+ to the :class:`_orm.registry` involving the target mapper and any related
+ :class:`_orm.registry` objects which it may depend on; this is
+ equivalent to invoking the :meth:`_orm.registry.configure` method
+ on a particular :class:`_orm.registry`.
+
+ By contrast, the :func:`_orm.configure_mappers` function will invoke the
+ configuration process on all :class:`_orm.registry` objects that
+ exist in memory, and may be useful for scenarios where many individual
+ :class:`_orm.registry` objects that are nonetheless interrelated are
+ in use.
+
+ .. versionchanged:: 1.4
+
+ As of SQLAlchemy 1.4.0b2, this function works on a
+ per-:class:`_orm.registry` basis, locating all :class:`_orm.registry`
+ objects present and invoking the :meth:`_orm.registry.configure` method
+ on each. The :meth:`_orm.registry.configure` method may be preferred to
+ limit the configuration of mappers to those local to a particular
+ :class:`_orm.registry` and/or declarative base class.
+
+ Points at which automatic configuration is invoked include when a mapped
+ class is instantiated into an instance, as well as when ORM queries
+ are emitted using :meth:`.Session.query` or :meth:`_orm.Session.execute`
+ with an ORM-enabled statement.
+
+ The mapper configure process, whether invoked by
+ :func:`_orm.configure_mappers` or from :meth:`_orm.registry.configure`,
+ provides several event hooks that can be used to augment the mapper
+ configuration step. These hooks include:
+
+ * :meth:`.MapperEvents.before_configured` - called once before
+ :func:`.configure_mappers` or :meth:`_orm.registry.configure` does any
+ work; this can be used to establish additional options, properties, or
+ related mappings before the operation proceeds.
+
+ * :meth:`.MapperEvents.mapper_configured` - called as each individual
+ :class:`_orm.Mapper` is configured within the process; will include all
+ mapper state except for backrefs set up by other mappers that are still
+ to be configured.
+
+ * :meth:`.MapperEvents.after_configured` - called once after
+ :func:`.configure_mappers` or :meth:`_orm.registry.configure` is
+ complete; at this stage, all :class:`_orm.Mapper` objects that fall
+ within the scope of the configuration operation will be fully configured.
+ Note that the calling application may still have other mappings that
+ haven't been produced yet, such as if they are in modules as yet
+ unimported, and may also have mappings that are still to be configured,
+ if they are in other :class:`_orm.registry` collections not part of the
+ current scope of configuration.
+
+ """
+
+ _configure_registries(_all_registries(), cascade=True)
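+
+# A usage sketch: configure_mappers() takes no arguments and is typically
+# invoked once after all mapped classes have been declared, e.g. at import
+# time of an application's model package (module names are illustrative):
+#
+#     from sqlalchemy.orm import configure_mappers
+#
+#     import myapp.models.user
+#     import myapp.models.address
+#
+#     configure_mappers()   # resolves relationship() linkages up front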
+
+
+def _configure_registries(registries, cascade):
+ for reg in registries:
+ if reg._new_mappers:
+ break
+ else:
+ return
+
+ with _CONFIGURE_MUTEX:
+ global _already_compiling
+ if _already_compiling:
+ return
+ _already_compiling = True
+ try:
+
+ # double-check inside mutex
+ for reg in registries:
+ if reg._new_mappers:
+ break
+ else:
+ return
+
+ Mapper.dispatch._for_class(Mapper).before_configured()
+ # initialize properties on all mappers
+ # note that _mapper_registry is unordered, which
+ # may randomly conceal/reveal issues related to
+ # the order of mapper compilation
+
+ _do_configure_registries(registries, cascade)
+ finally:
+ _already_compiling = False
+ Mapper.dispatch._for_class(Mapper).after_configured()
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _do_configure_registries(registries, cascade):
+
+ registry = util.preloaded.orm_decl_api.registry
+
+ orig = set(registries)
+
+ for reg in registry._recurse_with_dependencies(registries):
+ has_skip = False
+
+ for mapper in reg._mappers_to_configure():
+ run_configure = None
+ for fn in mapper.dispatch.before_mapper_configured:
+ run_configure = fn(mapper, mapper.class_)
+ if run_configure is EXT_SKIP:
+ has_skip = True
+ break
+ if run_configure is EXT_SKIP:
+ continue
+
+ if getattr(mapper, "_configure_failed", False):
+ e = sa_exc.InvalidRequestError(
+ "One or more mappers failed to initialize - "
+ "can't proceed with initialization of other "
+ "mappers. Triggering mapper: '%s'. "
+ "Original exception was: %s"
+ % (mapper, mapper._configure_failed)
+ )
+ e._configure_failed = mapper._configure_failed
+ raise e
+
+ if not mapper.configured:
+ try:
+ mapper._post_configure_properties()
+ mapper._expire_memoizations()
+ mapper.dispatch.mapper_configured(mapper, mapper.class_)
+ except Exception:
+ exc = sys.exc_info()[1]
+ if not hasattr(exc, "_configure_failed"):
+ mapper._configure_failed = exc
+ raise
+ if not has_skip:
+ reg._new_mappers = False
+
+ if not cascade and reg._dependencies.difference(orig):
+ raise sa_exc.InvalidRequestError(
+ "configure was called with cascade=False but "
+ "additional registries remain"
+ )
+
+
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _dispose_registries(registries, cascade):
+
+ registry = util.preloaded.orm_decl_api.registry
+
+ orig = set(registries)
+
+ for reg in registry._recurse_with_dependents(registries):
+ if not cascade and reg._dependents.difference(orig):
+ raise sa_exc.InvalidRequestError(
+ "Registry has dependent registries that are not disposed; "
+ "pass cascade=True to clear these also"
+ )
+
+ while reg._managers:
+ try:
+ manager, _ = reg._managers.popitem()
+ except KeyError:
+ # guard against race between while and popitem
+ pass
+ else:
+ reg._dispose_manager_and_mapper(manager)
+
+ reg._non_primary_mappers.clear()
+ reg._dependents.clear()
+ for dep in reg._dependencies:
+ dep._dependents.discard(reg)
+ reg._dependencies.clear()
+ # this wasn't done in the 1.3 clear_mappers() and in fact it
+ # was a bug, as it could cause configure_mappers() to invoke
+ # the "before_configured" event even though mappers had all been
+ # disposed.
+ reg._new_mappers = False
+
+
+def reconstructor(fn):
+ """Decorate a method as the 'reconstructor' hook.
+
+ Designates a single method as the "reconstructor", an ``__init__``-like
+ method that will be called by the ORM after the instance has been
+ loaded from the database or otherwise reconstituted.
+
+ The reconstructor will be invoked with no arguments. Scalar
+ (non-collection) database-mapped attributes of the instance will
+ be available for use within the function. Eagerly-loaded
+ collections are generally not yet available and will usually only
+ contain the first element. ORM state changes made to objects at
+ this stage will not be recorded for the next flush() operation, so
+ the activity within a reconstructor should be conservative.
+
+ .. seealso::
+
+ :ref:`mapping_constructors`
+
+ :meth:`.InstanceEvents.load`
+
+ """
+ fn.__sa_reconstructor__ = True
+ return fn
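+
+# A minimal sketch of the reconstructor hook, assuming a declarative
+# "User" class (all names are illustrative):
+#
+#     class User(Base):
+#         __tablename__ = "user"
+#         id = Column(Integer, primary_key=True)
+#         name = Column(String)
+#
+#         @reconstructor
+#         def init_on_load(self):
+#             # invoked with no arguments after the row is loaded; only
+#             # scalar column attributes are reliably available here
+#             self.display_name = (self.name or "").title()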
+
+
+def validates(*names, **kw):
+ r"""Decorate a method as a 'validator' for one or more named properties.
+
+ Designates a method as a validator, a method which receives the
+ name of the attribute as well as a value to be assigned, or in the
+ case of a collection, the value to be added to the collection.
+ The function can then raise validation exceptions to halt the
+ process from continuing (where Python's built-in ``ValueError``
+ and ``AssertionError`` exceptions are reasonable choices), or can
+ modify or replace the value before proceeding. The function should
+ otherwise return the given value.
+
+ Note that a validator for a collection **cannot** issue a load of that
+ collection within the validation routine - this usage raises
+ an assertion to avoid recursion overflows. This is a reentrant
+ condition which is not supported.
+
+ :param \*names: list of attribute names to be validated.
+ :param include_removes: if True, "remove" events will be
+ sent as well - the validation function must accept an additional
+ argument "is_remove" which will be a boolean.
+
+ :param include_backrefs: defaults to ``True``; if ``False``, the
+ validation function will not emit if the originator is an attribute
+ event related via a backref. This can be used for bi-directional
+ :func:`.validates` usage where only one validator should emit per
+ attribute operation.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`simple_validators` - usage examples for :func:`.validates`
+
+ """
+ include_removes = kw.pop("include_removes", False)
+ include_backrefs = kw.pop("include_backrefs", True)
+
+ def wrap(fn):
+ fn.__sa_validators__ = names
+ fn.__sa_validation_opts__ = {
+ "include_removes": include_removes,
+ "include_backrefs": include_backrefs,
+ }
+ return fn
+
+ return wrap
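+
+# A minimal sketch of an attribute validator, assuming a declarative
+# "User" class with an "email" column (all names are illustrative):
+#
+#     class User(Base):
+#         __tablename__ = "user"
+#         id = Column(Integer, primary_key=True)
+#         email = Column(String)
+#
+#         @validates("email")
+#         def validate_email(self, key, value):
+#             # reject obviously malformed values before assignment
+#             if "@" not in value:
+#                 raise ValueError("invalid email address: %r" % value)
+#             return value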
+
+
+def _event_on_load(state, ctx):
+ instrumenting_mapper = state.manager.mapper
+
+ if instrumenting_mapper._reconstructor:
+ instrumenting_mapper._reconstructor(state.obj())
+
+
+def _event_on_init(state, args, kwargs):
+ """Run init_instance hooks.
+
+ This also includes mapper compilation, normally not needed
+ here but helps with some piecemeal configuration
+ scenarios (such as in the ORM tutorial).
+
+ """
+
+ instrumenting_mapper = state.manager.mapper
+ if instrumenting_mapper:
+ instrumenting_mapper._check_configure()
+ if instrumenting_mapper._set_polymorphic_identity:
+ instrumenting_mapper._set_polymorphic_identity(state)
+
+
+class _ColumnMapping(dict):
+ """Error reporting helper for mapper._columntoproperty."""
+
+ __slots__ = ("mapper",)
+
+ def __init__(self, mapper):
+ # TODO: weakref would be a good idea here
+ self.mapper = mapper
+
+ def __missing__(self, column):
+ prop = self.mapper._props.get(column)
+ if prop:
+ raise orm_exc.UnmappedColumnError(
+ "Column '%s.%s' is not available, due to "
+ "conflicting property '%s':%r"
+ % (column.table.name, column.name, column.key, prop)
+ )
+ raise orm_exc.UnmappedColumnError(
+ "No column %s is configured on mapper %s..."
+ % (column, self.mapper)
+ )
diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py
new file mode 100644
index 0000000..331ddd7
--- /dev/null
+++ b/lib/sqlalchemy/orm/path_registry.py
@@ -0,0 +1,519 @@
+# orm/path_registry.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Path tracking utilities, representing mapper graph traversals.
+
+"""
+
+from itertools import chain
+import logging
+
+from . import base as orm_base
+from .. import exc
+from .. import inspection
+from .. import util
+from ..sql import visitors
+from ..sql.traversals import HasCacheKey
+
+log = logging.getLogger(__name__)
+
+
+def _unreduce_path(path):
+ return PathRegistry.deserialize(path)
+
+
+_WILDCARD_TOKEN = "*"
+_DEFAULT_TOKEN = "_sa_default"
+
+
+class PathRegistry(HasCacheKey):
+ """Represent query load paths and registry functions.
+
+ Basically represents structures like:
+
+ (<User mapper>, "orders", <Order mapper>, "items", <Item mapper>)
+
+ These structures are generated by things like
+ query options (joinedload(), subqueryload(), etc.) and are
+ used to compose keys stored in the query._attributes dictionary
+ for various options.
+
+ They are then re-composed at query compile/result row time as
+ the query is formed and as rows are fetched, where they again
+ serve to compose keys to look up options in the context.attributes
+ dictionary, which is copied from query._attributes.
+
+ The path structure has a limited amount of caching, where each
+ "root" ultimately pulls from a fixed registry associated with
+ the first mapper, that also contains elements for each of its
+ property keys. However paths longer than two elements, which
+ are the exception rather than the rule, are generated on an
+ as-needed basis.
+
+ """
+
+ __slots__ = ()
+
+ is_token = False
+ is_root = False
+
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list)
+ ]
+
+ def __eq__(self, other):
+ try:
+ return other is not None and self.path == other._path_for_compare
+ except AttributeError:
+ util.warn(
+ "Comparison of PathRegistry to %r is not supported"
+ % (type(other))
+ )
+ return False
+
+ def __ne__(self, other):
+ try:
+ return other is None or self.path != other._path_for_compare
+ except AttributeError:
+ util.warn(
+ "Comparison of PathRegistry to %r is not supported"
+ % (type(other))
+ )
+ return True
+
+ @property
+ def _path_for_compare(self):
+ return self.path
+
+ def set(self, attributes, key, value):
+ log.debug("set '%s' on path '%s' to '%s'", key, self, value)
+ attributes[(key, self.natural_path)] = value
+
+ def setdefault(self, attributes, key, value):
+ log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value)
+ attributes.setdefault((key, self.natural_path), value)
+
+ def get(self, attributes, key, value=None):
+ key = (key, self.natural_path)
+ if key in attributes:
+ return attributes[key]
+ else:
+ return value
+
+ def __len__(self):
+ return len(self.path)
+
+ def __hash__(self):
+ return id(self)
+
+ @property
+ def length(self):
+ return len(self.path)
+
+ def pairs(self):
+ path = self.path
+ for i in range(0, len(path), 2):
+ yield path[i], path[i + 1]
+
+ def contains_mapper(self, mapper):
+ for path_mapper in [self.path[i] for i in range(0, len(self.path), 2)]:
+ if path_mapper.is_mapper and path_mapper.isa(mapper):
+ return True
+ else:
+ return False
+
+ def contains(self, attributes, key):
+ return (key, self.path) in attributes
+
+ def __reduce__(self):
+ return _unreduce_path, (self.serialize(),)
+
+ @classmethod
+ def _serialize_path(cls, path):
+ return list(
+ zip(
+ [
+ m.class_ if (m.is_mapper or m.is_aliased_class) else str(m)
+ for m in [path[i] for i in range(0, len(path), 2)]
+ ],
+ [
+ path[i].key if (path[i].is_property) else str(path[i])
+ for i in range(1, len(path), 2)
+ ]
+ + [None],
+ )
+ )
+
+ @classmethod
+ def _deserialize_path(cls, path):
+ def _deserialize_mapper_token(mcls):
+ return (
+ # note: we likely dont want configure=True here however
+ # this is maintained at the moment for backwards compatibility
+ orm_base._inspect_mapped_class(mcls, configure=True)
+ if mcls not in PathToken._intern
+ else PathToken._intern[mcls]
+ )
+
+ def _deserialize_key_token(mcls, key):
+ if key is None:
+ return None
+ elif key in PathToken._intern:
+ return PathToken._intern[key]
+ else:
+ return orm_base._inspect_mapped_class(
+ mcls, configure=True
+ ).attrs[key]
+
+ p = tuple(
+ chain(
+ *[
+ (
+ _deserialize_mapper_token(mcls),
+ _deserialize_key_token(mcls, key),
+ )
+ for mcls, key in path
+ ]
+ )
+ )
+ if p and p[-1] is None:
+ p = p[0:-1]
+ return p
+
+ @classmethod
+ def serialize_context_dict(cls, dict_, tokens):
+ return [
+ ((key, cls._serialize_path(path)), value)
+ for (key, path), value in [
+ (k, v)
+ for k, v in dict_.items()
+ if isinstance(k, tuple) and k[0] in tokens
+ ]
+ ]
+
+ @classmethod
+ def deserialize_context_dict(cls, serialized):
+ return util.OrderedDict(
+ ((key, tuple(cls._deserialize_path(path))), value)
+ for (key, path), value in serialized
+ )
+
+ def serialize(self):
+ path = self.path
+ return self._serialize_path(path)
+
+ @classmethod
+ def deserialize(cls, path):
+ if path is None:
+ return None
+ p = cls._deserialize_path(path)
+ return cls.coerce(p)
+
+ @classmethod
+ def per_mapper(cls, mapper):
+ if mapper.is_mapper:
+ return CachingEntityRegistry(cls.root, mapper)
+ else:
+ return SlotsEntityRegistry(cls.root, mapper)
+
+ @classmethod
+ def coerce(cls, raw):
+ return util.reduce(lambda prev, next: prev[next], raw, cls.root)
+
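+    # Sketch of token() usage, assuming (as defined elsewhere in this module,
+    # not shown here) that _WILDCARD_TOKEN is "*" and _DEFAULT_TOKEN is
+    # "_sa_default": a wildcard loader option such as lazyload("*") arrives
+    # here as a string like "relationship:*" and produces a TokenRegistry
+    # whose path ends with that token rather than with a mapper or property.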
+ def token(self, token):
+ if token.endswith(":" + _WILDCARD_TOKEN):
+ return TokenRegistry(self, token)
+ elif token.endswith(":" + _DEFAULT_TOKEN):
+ return TokenRegistry(self.root, token)
+ else:
+ raise exc.ArgumentError("invalid token: %s" % token)
+
+ def __add__(self, other):
+ return util.reduce(lambda prev, next: prev[next], other.path, self)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.path)
+
+
+class RootRegistry(PathRegistry):
+ """Root registry, defers to mappers so that
+ paths are maintained per-root-mapper.
+
+ """
+
+ inherit_cache = True
+
+ path = natural_path = ()
+ has_entity = False
+ is_aliased_class = False
+ is_root = True
+
+ def __getitem__(self, entity):
+ if entity in PathToken._intern:
+ return PathToken._intern[entity]
+ else:
+ return entity._path_registry
+
+
+PathRegistry.root = RootRegistry()
+
+
+class PathToken(orm_base.InspectionAttr, HasCacheKey, str):
+ """cacheable string token"""
+
+ _intern = {}
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (str(self),)
+
+ @property
+ def _path_for_compare(self):
+ return None
+
+ @classmethod
+ def intern(cls, strvalue):
+ if strvalue in cls._intern:
+ return cls._intern[strvalue]
+ else:
+ cls._intern[strvalue] = result = PathToken(strvalue)
+ return result
+
+
+class TokenRegistry(PathRegistry):
+ __slots__ = ("token", "parent", "path", "natural_path")
+
+ inherit_cache = True
+
+ def __init__(self, parent, token):
+ token = PathToken.intern(token)
+
+ self.token = token
+ self.parent = parent
+ self.path = parent.path + (token,)
+ self.natural_path = parent.natural_path + (token,)
+
+ has_entity = False
+
+ is_token = True
+
+ def generate_for_superclasses(self):
+ if not self.parent.is_aliased_class and not self.parent.is_root:
+ for ent in self.parent.mapper.iterate_to_root():
+ yield TokenRegistry(self.parent.parent[ent], self.token)
+ elif (
+ self.parent.is_aliased_class
+ and self.parent.entity._is_with_polymorphic
+ ):
+ yield self
+ for ent in self.parent.entity._with_polymorphic_entities:
+ yield TokenRegistry(self.parent.parent[ent], self.token)
+ else:
+ yield self
+
+ def __getitem__(self, entity):
+ raise NotImplementedError()
+
+
+class PropRegistry(PathRegistry):
+ is_unnatural = False
+ inherit_cache = True
+
+ def __init__(self, parent, prop):
+ # restate this path in terms of the
+ # given MapperProperty's parent.
+ insp = inspection.inspect(parent[-1])
+ natural_parent = parent
+
+ if not insp.is_aliased_class or insp._use_mapper_path:
+ parent = natural_parent = parent.parent[prop.parent]
+ elif (
+ insp.is_aliased_class
+ and insp.with_polymorphic_mappers
+ and prop.parent in insp.with_polymorphic_mappers
+ ):
+ subclass_entity = parent[-1]._entity_for_mapper(prop.parent)
+ parent = parent.parent[subclass_entity]
+
+ # when building a path where with_polymorphic() is in use,
+ # special logic to determine the "natural path" when subclass
+ # entities are used.
+ #
+ # here we are trying to distinguish between a path that starts
+            # on the with_polymorphic entity vs. one that starts on a
+ # normal entity that introduces a with_polymorphic() in the
+ # middle using of_type():
+ #
+ # # as in test_polymorphic_rel->
+ # # test_subqueryload_on_subclass_uses_path_correctly
+ # wp = with_polymorphic(RegularEntity, "*")
+ # sess.query(wp).options(someload(wp.SomeSubEntity.foos))
+ #
+ # vs
+ #
+ # # as in test_relationship->JoinedloadWPolyOfTypeContinued
+ # wp = with_polymorphic(SomeFoo, "*")
+ # sess.query(RegularEntity).options(
+ # someload(RegularEntity.foos.of_type(wp))
+ # .someload(wp.SubFoo.bar)
+ # )
+ #
+ # in the former case, the Query as it generates a path that we
+ # want to match will be in terms of the with_polymorphic at the
+ # beginning. in the latter case, Query will generate simple
+ # paths that don't know about this with_polymorphic, so we must
+ # use a separate natural path.
+ #
+ #
+ if parent.parent:
+ natural_parent = parent.parent[subclass_entity.mapper]
+ self.is_unnatural = True
+ else:
+ natural_parent = parent
+ elif (
+ natural_parent.parent
+ and insp.is_aliased_class
+ and prop.parent # this should always be the case here
+ is not insp.mapper
+ and insp.mapper.isa(prop.parent)
+ ):
+ natural_parent = parent.parent[prop.parent]
+
+ self.prop = prop
+ self.parent = parent
+ self.path = parent.path + (prop,)
+ self.natural_path = natural_parent.natural_path + (prop,)
+
+ self._wildcard_path_loader_key = (
+ "loader",
+ parent.path + self.prop._wildcard_token,
+ )
+ self._default_path_loader_key = self.prop._default_path_loader_key
+ self._loader_key = ("loader", self.natural_path)
+
+ def __str__(self):
+ return " -> ".join(str(elem) for elem in self.path)
+
+ @util.memoized_property
+ def has_entity(self):
+ return self.prop._links_to_entity
+
+ @util.memoized_property
+ def entity(self):
+ return self.prop.entity
+
+ @property
+ def mapper(self):
+ return self.prop.mapper
+
+ @property
+ def entity_path(self):
+ return self[self.entity]
+
+ def __getitem__(self, entity):
+ if isinstance(entity, (int, slice)):
+ return self.path[entity]
+ else:
+ return SlotsEntityRegistry(self, entity)
+
+
+class AbstractEntityRegistry(PathRegistry):
+ __slots__ = ()
+
+ has_entity = True
+
+ def __init__(self, parent, entity):
+ self.key = entity
+ self.parent = parent
+ self.is_aliased_class = entity.is_aliased_class
+ self.entity = entity
+ self.path = parent.path + (entity,)
+
+ # the "natural path" is the path that we get when Query is traversing
+ # from the lead entities into the various relationships; it corresponds
+ # to the structure of mappers and relationships. when we are given a
+        # path that comes from loader options, as of 1.3 it can have ad-hoc
+ # with_polymorphic() and other AliasedInsp objects inside of it, which
+ # are usually not present in mappings. So here we track both the
+ # "enhanced" path in self.path and the "natural" path that doesn't
+ # include those objects so these two traversals can be matched up.
+
+ # the test here for "(self.is_aliased_class or parent.is_unnatural)"
+        # is to avoid the more expensive conditional logic that follows if we
+        # know we don't have to do it. This conditional could just as well be
+        # "if parent.path:"; it just incurs more function calls.
+ if parent.path and (self.is_aliased_class or parent.is_unnatural):
+ # this is an infrequent code path used only for loader strategies
+ # that also make use of of_type().
+ if entity.mapper.isa(parent.natural_path[-1].entity):
+ self.natural_path = parent.natural_path + (entity.mapper,)
+ else:
+ self.natural_path = parent.natural_path + (
+ parent.natural_path[-1].entity,
+ )
+ # it seems to make sense that since these paths get mixed up
+ # with statements that are cached or not, we should make
+ # sure the natural path is cacheable across different occurrences
+ # of equivalent AliasedClass objects. however, so far this
+ # does not seem to be needed for whatever reason.
+ # elif not parent.path and self.is_aliased_class:
+ # self.natural_path = (self.entity._generate_cache_key()[0], )
+ else:
+ # self.natural_path = parent.natural_path + (entity, )
+ self.natural_path = self.path
+
+ @property
+ def entity_path(self):
+ return self
+
+ @property
+ def mapper(self):
+ return inspection.inspect(self.entity).mapper
+
+ def __bool__(self):
+ return True
+
+ __nonzero__ = __bool__
+
+ def __getitem__(self, entity):
+ if isinstance(entity, (int, slice)):
+ return self.path[entity]
+ elif entity in PathToken._intern:
+ return TokenRegistry(self, PathToken._intern[entity])
+ else:
+ return PropRegistry(self, entity)
+
+
+class SlotsEntityRegistry(AbstractEntityRegistry):
+    # for an aliased class, return a lightweight version that creates
+    # no reference cycles
+ inherit_cache = True
+
+ __slots__ = (
+ "key",
+ "parent",
+ "is_aliased_class",
+ "entity",
+ "path",
+ "natural_path",
+ )
+
+
+class CachingEntityRegistry(AbstractEntityRegistry, dict):
+    # for a long-lived mapper, return a dict-based caching version
+    # that creates reference cycles
+
+ inherit_cache = True
+
+ def __getitem__(self, entity):
+ if isinstance(entity, (int, slice)):
+ return self.path[entity]
+ else:
+ return dict.__getitem__(self, entity)
+
+ def __missing__(self, key):
+ self[key] = item = PropRegistry(self, key)
+
+ return item
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
new file mode 100644
index 0000000..a17b24a
--- /dev/null
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -0,0 +1,2517 @@
+# orm/persistence.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used to emit INSERT, UPDATE
+and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
+mappers.
+
+The functions here are called only by the unit of work functions
+in unitofwork.py.
+
+"""
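+
+# A rough map of this module: save_obj(), post_update() and delete_obj() are
+# the flush-time entry points, with _bulk_insert() / _bulk_update() serving
+# the bulk operations; each organizes states per table, collects parameter
+# dictionaries through the _collect_*_commands() generators, and hands them
+# to the corresponding _emit_*_statements() function.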
+
+from itertools import chain
+from itertools import groupby
+import operator
+
+from . import attributes
+from . import evaluator
+from . import exc as orm_exc
+from . import loading
+from . import sync
+from .base import NO_VALUE
+from .base import state_str
+from .. import exc as sa_exc
+from .. import future
+from .. import sql
+from .. import util
+from ..engine import result as _result
+from ..sql import coercions
+from ..sql import expression
+from ..sql import operators
+from ..sql import roles
+from ..sql import select
+from ..sql import sqltypes
+from ..sql.base import _entity_namespace_key
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.dml import DeleteDMLState
+from ..sql.dml import InsertDMLState
+from ..sql.dml import UpdateDMLState
+from ..sql.elements import BooleanClauseList
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+
+
+def _bulk_insert(
+ mapper,
+ mappings,
+ session_transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+):
+ base_mapper = mapper.base_mapper
+
+ if session_transaction.session.connection_callable:
+ raise NotImplementedError(
+ "connection_callable / per-instance sharding "
+ "not supported in bulk_insert()"
+ )
+
+ if isstates:
+ if return_defaults:
+ states = [(state, state.dict) for state in mappings]
+ mappings = [dict_ for (state, dict_) in states]
+ else:
+ mappings = [state.dict for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ connection = session_transaction.connection(base_mapper)
+ for table, super_mapper in base_mapper._sorted_tables.items():
+ if not mapper.isa(super_mapper):
+ continue
+
+ records = (
+ (
+ None,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
+ for (
+ state,
+ state_dict,
+ params,
+ mp,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in _collect_insert_commands(
+ table,
+ ((None, mapping, mapper, connection) for mapping in mappings),
+ bulk=True,
+ return_defaults=return_defaults,
+ render_nulls=render_nulls,
+ )
+ )
+ _emit_insert_statements(
+ base_mapper,
+ None,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=return_defaults,
+ )
+
+ if return_defaults and isstates:
+ identity_cls = mapper._identity_class
+ identity_props = [p.key for p in mapper._identity_key_props]
+ for state, dict_ in states:
+ state.key = (
+ identity_cls,
+ tuple([dict_[key] for key in identity_props]),
+ )
+
+
+def _bulk_update(
+ mapper, mappings, session_transaction, isstates, update_changed_only
+):
+ base_mapper = mapper.base_mapper
+
+ search_keys = mapper._primary_key_propkeys
+ if mapper._version_id_prop:
+ search_keys = {mapper._version_id_prop.key}.union(search_keys)
+
+ def _changed_dict(mapper, state):
+ return dict(
+ (k, v)
+ for k, v in state.dict.items()
+ if k in state.committed_state or k in search_keys
+ )
+
+ if isstates:
+ if update_changed_only:
+ mappings = [_changed_dict(mapper, state) for state in mappings]
+ else:
+ mappings = [state.dict for state in mappings]
+ else:
+ mappings = list(mappings)
+
+ if session_transaction.session.connection_callable:
+ raise NotImplementedError(
+ "connection_callable / per-instance sharding "
+ "not supported in bulk_update()"
+ )
+
+ connection = session_transaction.connection(base_mapper)
+
+ for table, super_mapper in base_mapper._sorted_tables.items():
+ if not mapper.isa(super_mapper):
+ continue
+
+ records = _collect_update_commands(
+ None,
+ table,
+ (
+ (
+ None,
+ mapping,
+ mapper,
+ connection,
+ (
+ mapping[mapper._version_id_prop.key]
+ if mapper._version_id_prop
+ else None
+ ),
+ )
+ for mapping in mappings
+ ),
+ bulk=True,
+ )
+
+ _emit_update_statements(
+ base_mapper,
+ None,
+ super_mapper,
+ table,
+ records,
+ bookkeeping=False,
+ )
+
+
+def save_obj(base_mapper, states, uowtransaction, single=False):
+ """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
+ of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation, given a list of states to be flushed. The
+ base mapper in an inheritance hierarchy handles the inserts/
+ updates for all descendant mappers.
+
+ """
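+
+    # Rough flow, as implemented below: _organize_states_for_save() splits the
+    # states into insert and update lists; for each mapped table,
+    # _collect_insert_commands() / _collect_update_commands() build parameter
+    # dictionaries, _emit_insert_statements() / _emit_update_statements()
+    # execute them, and _finalize_insert_update_commands() fires after_* events.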
+
+    # if batch=False, call save_obj separately for each object
+ if not single and not base_mapper.batch:
+ for state in _sort_states(base_mapper, states):
+ save_obj(base_mapper, [state], uowtransaction, single=True)
+ return
+
+ states_to_update = []
+ states_to_insert = []
+
+ for (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ ) in _organize_states_for_save(base_mapper, states, uowtransaction):
+ if has_identity or row_switch:
+ states_to_update.append(
+ (state, dict_, mapper, connection, update_version_id)
+ )
+ else:
+ states_to_insert.append((state, dict_, mapper, connection))
+
+ for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
+ insert = _collect_insert_commands(table, states_to_insert)
+
+ update = _collect_update_commands(
+ uowtransaction, table, states_to_update
+ )
+
+ _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ update,
+ )
+
+ _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ insert,
+ )
+
+ _finalize_insert_update_commands(
+ base_mapper,
+ uowtransaction,
+ chain(
+ (
+ (state, state_dict, mapper, connection, False)
+ for (state, state_dict, mapper, connection) in states_to_insert
+ ),
+ (
+ (state, state_dict, mapper, connection, True)
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update
+ ),
+ ),
+ )
+
+
+def post_update(base_mapper, states, uowtransaction, post_update_cols):
+ """Issue UPDATE statements on behalf of a relationship() which
+ specifies post_update.
+
+ """
+
+ states_to_update = list(
+ _organize_states_for_post_update(base_mapper, states, uowtransaction)
+ )
+
+ for table, mapper in base_mapper._sorted_tables.items():
+ if table not in mapper._pks_by_table:
+ continue
+
+ update = (
+ (
+ state,
+ state_dict,
+ sub_mapper,
+ connection,
+ mapper._get_committed_state_attr_by_column(
+ state, state_dict, mapper.version_id_col
+ )
+ if mapper.version_id_col is not None
+ else None,
+ )
+ for state, state_dict, sub_mapper, connection in states_to_update
+ if table in sub_mapper._pks_by_table
+ )
+
+ update = _collect_post_update_commands(
+ base_mapper, uowtransaction, table, update, post_update_cols
+ )
+
+ _emit_post_update_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ update,
+ )
+
+
+def delete_obj(base_mapper, states, uowtransaction):
+ """Issue ``DELETE`` statements for a list of objects.
+
+ This is called within the context of a UOWTransaction during a
+ flush operation.
+
+ """
+
+ states_to_delete = list(
+ _organize_states_for_delete(base_mapper, states, uowtransaction)
+ )
+
+ table_to_mapper = base_mapper._sorted_tables
+
+ for table in reversed(list(table_to_mapper.keys())):
+ mapper = table_to_mapper[table]
+ if table not in mapper._pks_by_table:
+ continue
+ elif mapper.inherits and mapper.passive_deletes:
+ continue
+
+ delete = _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+ )
+
+ _emit_delete_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ delete,
+ )
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
+ mapper.dispatch.after_delete(mapper, connection, state)
+
+
+def _organize_states_for_save(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for INSERT or
+ UPDATE.
+
+    This includes splitting the states out into distinct lists for
+    INSERT and UPDATE, calling before_insert/before_update, and
+    obtaining key information for each state: its dictionary,
+    mapper, the connection to use for the execution, and the
+    identity flag.
+
+ """
+
+ for state, dict_, mapper, connection in _connections_for_states(
+ base_mapper, uowtransaction, states
+ ):
+
+ has_identity = bool(state.key)
+
+ instance_key = state.key or mapper._identity_key_from_state(state)
+
+ row_switch = update_version_id = None
+
+ # call before_XXX extensions
+ if not has_identity:
+ mapper.dispatch.before_insert(mapper, connection, state)
+ else:
+ mapper.dispatch.before_update(mapper, connection, state)
+
+ if mapper._validate_polymorphic_identity:
+ mapper._validate_polymorphic_identity(mapper, state, dict_)
+
+ # detect if we have a "pending" instance (i.e. has
+ # no instance_key attached to it), and another instance
+ # with the same identity key already exists as persistent.
+ # convert to an UPDATE if so.
+ if (
+ not has_identity
+ and instance_key in uowtransaction.session.identity_map
+ ):
+ instance = uowtransaction.session.identity_map[instance_key]
+ existing = attributes.instance_state(instance)
+
+ if not uowtransaction.was_already_deleted(existing):
+ if not uowtransaction.is_deleted(existing):
+ util.warn(
+ "New instance %s with identity key %s conflicts "
+ "with persistent instance %s"
+ % (state_str(state), instance_key, state_str(existing))
+ )
+ else:
+ base_mapper._log_debug(
+ "detected row switch for identity %s. "
+ "will update %s, remove %s from "
+ "transaction",
+ instance_key,
+ state_str(state),
+ state_str(existing),
+ )
+
+ # remove the "delete" flag from the existing element
+ uowtransaction.remove_state_actions(existing)
+ row_switch = existing
+
+ if (has_identity or row_switch) and mapper.version_id_col is not None:
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ row_switch if row_switch else state,
+ row_switch.dict if row_switch else dict_,
+ mapper.version_id_col,
+ )
+
+ yield (
+ state,
+ dict_,
+ mapper,
+ connection,
+ has_identity,
+ row_switch,
+ update_version_id,
+ )
+
+
+def _organize_states_for_post_update(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for UPDATE
+ corresponding to post_update.
+
+    This includes obtaining key information for each state:
+    its dictionary, mapper, and the connection to use for
+    the execution.
+
+ """
+ return _connections_for_states(base_mapper, uowtransaction, states)
+
+
+def _organize_states_for_delete(base_mapper, states, uowtransaction):
+ """Make an initial pass across a set of states for DELETE.
+
+    This includes calling before_delete and obtaining
+    key information for each state: its dictionary, mapper,
+    and the connection to use for the execution.
+
+ """
+ for state, dict_, mapper, connection in _connections_for_states(
+ base_mapper, uowtransaction, states
+ ):
+
+ mapper.dispatch.before_delete(mapper, connection, state)
+
+ if mapper.version_id_col is not None:
+ update_version_id = mapper._get_committed_state_attr_by_column(
+ state, dict_, mapper.version_id_col
+ )
+ else:
+ update_version_id = None
+
+ yield (state, dict_, mapper, connection, update_version_id)
+
+
+def _collect_insert_commands(
+ table,
+ states_to_insert,
+ bulk=False,
+ return_defaults=False,
+ render_nulls=False,
+):
+ """Identify sets of values to use in INSERT statements for a
+ list of states.
+
+ """
+ for state, state_dict, mapper, connection in states_to_insert:
+ if table not in mapper._pks_by_table:
+ continue
+
+ params = {}
+ value_params = {}
+
+ propkey_to_col = mapper._propkey_to_col[table]
+
+ eval_none = mapper._insert_cols_evaluating_none[table]
+
+ for propkey in set(propkey_to_col).intersection(state_dict):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+ if value is None and col not in eval_none and not render_nulls:
+ continue
+ elif not bulk and (
+ hasattr(value, "__clause_element__")
+ or isinstance(value, sql.ClauseElement)
+ ):
+ value_params[col] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ else:
+ params[col.key] = value
+
+ if not bulk:
+ # for all the columns that have no default and we don't have
+ # a value and where "None" is not a special value, add
+ # explicit None to the INSERT. This is a legacy behavior
+ # which might be worth removing, as it should not be necessary
+ # and also produces confusion, given that "missing" and None
+ # now have distinct meanings
+ for colkey in (
+ mapper._insert_cols_as_none[table]
+ .difference(params)
+ .difference([c.key for c in value_params])
+ ):
+ params[colkey] = None
+
+ if not bulk or return_defaults:
+ # params are in terms of Column key objects, so
+ # compare to pk_keys_by_table
+ has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
+
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_default_cols[table].issubset(
+ params
+ )
+ else:
+ has_all_defaults = True
+ else:
+ has_all_defaults = has_all_pks = True
+
+ if (
+ mapper.version_id_generator is not False
+ and mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ params[mapper.version_id_col.key] = mapper.version_id_generator(
+ None
+ )
+
+ yield (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ )
+
+
+def _collect_update_commands(
+ uowtransaction, table, states_to_update, bulk=False
+):
+ """Identify sets of values to use in UPDATE statements for a
+ list of states.
+
+ This function works intricately with the history system
+ to determine exactly what values should be updated
+ as well as how the row should be matched within an UPDATE
+ statement. Includes some tricky scenarios where the primary
+ key of an object might have been changed.
+
+ """
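+
+    # For example (a sketch): if an object's primary key attribute was changed
+    # from 5 to 7, the collected UPDATE will carry the new value 7 in its SET
+    # parameters, while the WHERE-clause parameters (pk_params below) locate
+    # the row using the old committed value 5, taken from history.deleted.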
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
+
+ if table not in mapper._pks_by_table:
+ continue
+
+ pks = mapper._pks_by_table[table]
+
+ value_params = {}
+
+ propkey_to_col = mapper._propkey_to_col[table]
+
+ if bulk:
+ # keys here are mapped attribute keys, so
+ # look at mapper attribute keys for pk
+ params = dict(
+ (propkey_to_col[propkey].key, state_dict[propkey])
+ for propkey in set(propkey_to_col)
+ .intersection(state_dict)
+ .difference(mapper._pk_attr_keys_by_table[table])
+ )
+ has_all_defaults = True
+ else:
+ params = {}
+ for propkey in set(propkey_to_col).intersection(
+ state.committed_state
+ ):
+ value = state_dict[propkey]
+ col = propkey_to_col[propkey]
+
+ if hasattr(value, "__clause_element__") or isinstance(
+ value, sql.ClauseElement
+ ):
+ value_params[col] = (
+ value.__clause_element__()
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ # guard against values that generate non-__nonzero__
+ # objects for __eq__()
+ elif (
+ state.manager[propkey].impl.is_equal(
+ value, state.committed_state[propkey]
+ )
+ is not True
+ ):
+ params[col.key] = value
+
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = (
+ mapper._server_onupdate_default_cols[table]
+ ).issubset(params)
+ else:
+ has_all_defaults = True
+
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+
+ if not bulk and not (params or value_params):
+ # HACK: check for history in other tables, in case the
+ # history is only in a different table than the one
+ # where the version_id_col is. This logic was lost
+ # from 0.9 -> 1.0.0 and restored in 1.0.6.
+ for prop in mapper._columntoproperty.values():
+ history = state.manager[prop.key].impl.get_history(
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history.added:
+ break
+ else:
+ # no net change, break
+ continue
+
+ col = mapper.version_id_col
+ no_params = not params and not value_params
+ params[col._label] = update_version_id
+
+ if (
+ bulk or col.key not in params
+ ) and mapper.version_id_generator is not False:
+ val = mapper.version_id_generator(update_version_id)
+ params[col.key] = val
+ elif mapper.version_id_generator is False and no_params:
+ # no version id generator, no values set on the table,
+ # and version id wasn't manually incremented.
+ # set version id to itself so we get an UPDATE
+ # statement
+ params[col.key] = update_version_id
+
+ elif not (params or value_params):
+ continue
+
+ has_all_pks = True
+ expect_pk_cascaded = False
+ if bulk:
+ # keys here are mapped attribute keys, so
+ # look at mapper attribute keys for pk
+ pk_params = dict(
+ (propkey_to_col[propkey]._label, state_dict.get(propkey))
+ for propkey in set(propkey_to_col).intersection(
+ mapper._pk_attr_keys_by_table[table]
+ )
+ )
+ else:
+ pk_params = {}
+ for col in pks:
+ propkey = mapper._columntoproperty[col].key
+
+ history = state.manager[propkey].impl.get_history(
+ state, state_dict, attributes.PASSIVE_OFF
+ )
+
+ if history.added:
+ if (
+ not history.deleted
+ or ("pk_cascaded", state, col)
+ in uowtransaction.attributes
+ ):
+ expect_pk_cascaded = True
+ pk_params[col._label] = history.added[0]
+ params.pop(col.key, None)
+ else:
+ # else, use the old value to locate the row
+ pk_params[col._label] = history.deleted[0]
+ if col in value_params:
+ has_all_pks = False
+ else:
+ pk_params[col._label] = history.unchanged[0]
+ if pk_params[col._label] is None:
+ raise orm_exc.FlushError(
+ "Can't update table %s using NULL for primary "
+ "key value on column %s" % (table, col)
+ )
+
+ if params or value_params:
+ params.update(pk_params)
+ yield (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ )
+ elif expect_pk_cascaded:
+ # no UPDATE occurs on this table, but we expect that CASCADE rules
+ # have changed the primary key of the row; propagate this event to
+            # other columns that expect to have been modified. this normally
+            # occurs after the UPDATE is emitted; however, we invoke it here
+            # explicitly since no UPDATE is being emitted
+ for m, equated_pairs in mapper._table_to_equated[table]:
+ sync.populate(
+ state,
+ m,
+ state,
+ m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates,
+ )
+
+
+def _collect_post_update_commands(
+ base_mapper, uowtransaction, table, states_to_update, post_update_cols
+):
+ """Identify sets of values to use in UPDATE statements for a
+ list of states within a post_update operation.
+
+ """
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_update:
+
+ # assert table in mapper._pks_by_table
+
+ pks = mapper._pks_by_table[table]
+ params = {}
+ hasdata = False
+
+ for col in mapper._cols_by_table[table]:
+ if col in pks:
+ params[col._label] = mapper._get_state_attr_by_column(
+ state, state_dict, col, passive=attributes.PASSIVE_OFF
+ )
+
+ elif col in post_update_cols or col.onupdate is not None:
+ prop = mapper._columntoproperty[col]
+ history = state.manager[prop.key].impl.get_history(
+ state, state_dict, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if history.added:
+ value = history.added[0]
+ params[col.key] = value
+ hasdata = True
+ if hasdata:
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+
+ col = mapper.version_id_col
+ params[col._label] = update_version_id
+
+ if (
+ bool(state.key)
+ and col.key not in params
+ and mapper.version_id_generator is not False
+ ):
+ val = mapper.version_id_generator(update_version_id)
+ params[col.key] = val
+ yield state, state_dict, mapper, connection, params
+
+
+def _collect_delete_commands(
+ base_mapper, uowtransaction, table, states_to_delete
+):
+ """Identify values to use in DELETE statements for a list of
+ states to be deleted."""
+
+ for (
+ state,
+ state_dict,
+ mapper,
+ connection,
+ update_version_id,
+ ) in states_to_delete:
+
+ if table not in mapper._pks_by_table:
+ continue
+
+ params = {}
+ for col in mapper._pks_by_table[table]:
+ params[
+ col.key
+ ] = value = mapper._get_committed_state_attr_by_column(
+ state, state_dict, col
+ )
+ if value is None:
+ raise orm_exc.FlushError(
+ "Can't delete from table %s "
+ "using NULL for primary "
+ "key value on column %s" % (table, col)
+ )
+
+ if (
+ update_version_id is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ params[mapper.version_id_col.key] = update_version_id
+ yield params, connection
+
+
+def _emit_update_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ update,
+ bookkeeping=True,
+):
+ """Emit UPDATE statements corresponding to value lists collected
+ by _collect_update_commands()."""
+
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
+ def update_stmt():
+ clauses = BooleanClauseList._construct_raw(operators.and_)
+
+ for col in mapper._pks_by_table[table]:
+ clauses.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
+
+ if needs_version_id:
+ clauses.clauses.append(
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col._label,
+ type_=mapper.version_id_col.type,
+ )
+ )
+
+ stmt = table.update().where(clauses)
+ return stmt
+
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
+
+ for (
+ (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
+ records,
+ ) in groupby(
+ update,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # set of parameter keys
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6], # has_all_defaults
+ rec[7], # has all pks
+ ),
+ ):
+ rows = 0
+ records = list(records)
+
+ statement = cached_stmt
+ return_defaults = False
+
+ if not has_all_pks:
+ statement = statement.return_defaults()
+ return_defaults = True
+ elif (
+ bookkeeping
+ and not has_all_defaults
+ and mapper.base_mapper.eager_defaults
+ ):
+ statement = statement.return_defaults()
+ return_defaults = True
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
+ return_defaults = True
+
+ assert_singlerow = (
+ connection.dialect.supports_sane_rowcount
+ if not return_defaults
+ else connection.dialect.supports_sane_rowcount_returning
+ )
+
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
+ allow_multirow = has_all_defaults and not needs_version_id
+
+ if hasvalue:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
+ )
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params,
+ True,
+ c.returned_defaults,
+ )
+ rows += c.rowcount
+ check_rowcount = assert_singlerow
+ else:
+ if not allow_multirow:
+ check_rowcount = assert_singlerow
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+
+ # TODO: why with bookkeeping=False?
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params,
+ True,
+ c.returned_defaults,
+ )
+ rows += c.rowcount
+ else:
+ multiparams = [rec[2] for rec in records]
+
+ check_rowcount = assert_multirow or (
+ assert_singlerow and len(multiparams) == 1
+ )
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ rows += c.rowcount
+
+ for (
+ state,
+ state_dict,
+ params,
+ mapper,
+ connection,
+ value_params,
+ has_all_defaults,
+ has_all_pks,
+ ) in records:
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params,
+ True,
+ c.returned_defaults
+ if not c.context.executemany
+ else None,
+ )
+
+ if check_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
+
+ elif needs_version_id:
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
+
+
+def _emit_insert_statements(
+ base_mapper,
+ uowtransaction,
+ mapper,
+ table,
+ insert,
+ bookkeeping=True,
+):
+ """Emit INSERT statements corresponding to value lists collected
+ by _collect_insert_commands()."""
+
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
+ for (
+ (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ records,
+ ) in groupby(
+ insert,
+ lambda rec: (
+ rec[4], # connection
+ set(rec[2]), # parameter keys
+ bool(rec[5]), # whether we have "value" parameters
+ rec[6],
+ rec[7],
+ ),
+ ):
+
+ statement = cached_stmt
+
+ if (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not connection.dialect.implicit_returning
+ )
+ and has_all_pks
+ and not hasvalue
+ ):
+ # the "we don't need newly generated values back" section.
+ # here we have all the PKs, all the defaults or we don't want
+ # to fetch them, or the dialect doesn't support RETURNING at all
+ # so we have to post-fetch / use lastrowid anyway.
+ records = list(records)
+ multiparams = [rec[2] for rec in records]
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ if bookkeeping:
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ ) in zip(records, c.context.compiled_parameters):
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params,
+ False,
+ c.returned_defaults
+ if not c.context.executemany
+ else None,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
+
+ else:
+ # here, we need defaults and/or pk values back.
+
+ records = list(records)
+ if (
+ not hasvalue
+ and connection.dialect.insert_executemany_returning
+ and len(records) > 1
+ ):
+ do_executemany = True
+ else:
+ do_executemany = False
+
+ if not has_all_defaults and base_mapper.eager_defaults:
+ statement = statement.return_defaults()
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
+ elif do_executemany:
+ statement = statement.return_defaults(*table.primary_key)
+
+ if do_executemany:
+ multiparams = [rec[2] for rec in records]
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ if bookkeeping:
+ for (
+ (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ conn,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ),
+ last_inserted_params,
+ inserted_primary_key,
+ returned_defaults,
+ ) in util.zip_longest(
+ records,
+ c.context.compiled_parameters,
+ c.inserted_primary_key_rows,
+ c.returned_defaults_rows or (),
+ ):
+ if inserted_primary_key is None:
+ # this is a real problem and means that we didn't
+ # get back as many PK rows. we can't continue
+ # since this indicates PK rows were missing, which
+ # means we likely mis-populated records starting
+ # at that point with incorrectly matched PK
+ # values.
+ raise orm_exc.FlushError(
+ "Multi-row INSERT statement for %s did not "
+ "produce "
+ "the correct number of INSERTed rows for "
+ "RETURNING. Ensure there are no triggers or "
+ "special driver issues preventing INSERT from "
+ "functioning properly." % mapper_rec
+ )
+
+ for pk, col in zip(
+ inserted_primary_key,
+ mapper._pks_by_table[table],
+ ):
+ prop = mapper_rec._columntoproperty[col]
+ if state_dict.get(prop.key) is None:
+ state_dict[prop.key] = pk
+
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params,
+ False,
+ returned_defaults,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
+ else:
+ for (
+ state,
+ state_dict,
+ params,
+ mapper_rec,
+ connection,
+ value_params,
+ has_all_pks,
+ has_all_defaults,
+ ) in records:
+ if value_params:
+ result = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
+ )
+ else:
+ result = connection._execute_20(
+ statement,
+ params,
+ execution_options=execution_options,
+ )
+
+ primary_key = result.inserted_primary_key
+ if primary_key is None:
+ raise orm_exc.FlushError(
+ "Single-row INSERT statement for %s "
+                            "did not produce a new primary key result. "
+                            "Ensure there are no triggers or "
+ "special driver issues preventing INSERT from "
+ "functioning properly." % (mapper_rec,)
+ )
+ for pk, col in zip(
+ primary_key, mapper._pks_by_table[table]
+ ):
+ prop = mapper_rec._columntoproperty[col]
+ if (
+ col in value_params
+ or state_dict.get(prop.key) is None
+ ):
+ state_dict[prop.key] = pk
+ if bookkeeping:
+ if state:
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result,
+ result.context.compiled_parameters[0],
+ value_params,
+ False,
+ result.returned_defaults
+ if not result.context.executemany
+ else None,
+ )
+ else:
+ _postfetch_bulk_save(mapper_rec, state_dict, table)
+
+
+def _emit_post_update_statements(
+ base_mapper, uowtransaction, mapper, table, update
+):
+ """Emit UPDATE statements corresponding to value lists collected
+ by _collect_post_update_commands()."""
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
+ needs_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
+
+ def update_stmt():
+ clauses = BooleanClauseList._construct_raw(operators.and_)
+
+ for col in mapper._pks_by_table[table]:
+ clauses.clauses.append(
+ col == sql.bindparam(col._label, type_=col.type)
+ )
+
+ if needs_version_id:
+ clauses.clauses.append(
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col._label,
+ type_=mapper.version_id_col.type,
+ )
+ )
+
+ stmt = table.update().where(clauses)
+
+ if mapper.version_id_col is not None:
+ stmt = stmt.return_defaults(mapper.version_id_col)
+
+ return stmt
+
+ statement = base_mapper._memo(("post_update", table), update_stmt)
+
+ # execute each UPDATE in the order according to the original
+ # list of states to guarantee row access order, but
+ # also group them into common (connection, cols) sets
+ # to support executemany().
+ for key, records in groupby(
+ update,
+ lambda rec: (rec[3], set(rec[4])), # connection # parameter keys
+ ):
+ rows = 0
+
+ records = list(records)
+ connection = key[0]
+
+ assert_singlerow = (
+ connection.dialect.supports_sane_rowcount
+ if mapper.version_id_col is None
+ else connection.dialect.supports_sane_rowcount_returning
+ )
+ assert_multirow = (
+ assert_singlerow
+ and connection.dialect.supports_sane_multi_rowcount
+ )
+ allow_multirow = not needs_version_id or assert_multirow
+
+ if not allow_multirow:
+ check_rowcount = assert_singlerow
+ for state, state_dict, mapper_rec, connection, params in records:
+
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+
+ _postfetch_post_update(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
+ rows += c.rowcount
+ else:
+ multiparams = [
+ params
+ for state, state_dict, mapper_rec, conn, params in records
+ ]
+
+ check_rowcount = assert_multirow or (
+ assert_singlerow and len(multiparams) == 1
+ )
+
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
+ rows += c.rowcount
+ for state, state_dict, mapper_rec, connection, params in records:
+ _postfetch_post_update(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ )
+
+ if check_rowcount:
+ if rows != len(records):
+ raise orm_exc.StaleDataError(
+ "UPDATE statement on table '%s' expected to "
+ "update %d row(s); %d were matched."
+ % (table.description, len(records), rows)
+ )
+
+ elif needs_version_id:
+ util.warn(
+ "Dialect %s does not support updated rowcount "
+ "- versioning cannot be verified."
+ % c.dialect.dialect_description
+ )
+
+
+def _emit_delete_statements(
+ base_mapper, uowtransaction, mapper, table, delete
+):
+ """Emit DELETE statements corresponding to value lists collected
+ by _collect_delete_commands()."""
+
+ need_version_id = (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ )
+
+ def delete_stmt():
+ clauses = BooleanClauseList._construct_raw(operators.and_)
+
+ for col in mapper._pks_by_table[table]:
+ clauses.clauses.append(
+ col == sql.bindparam(col.key, type_=col.type)
+ )
+
+ if need_version_id:
+ clauses.clauses.append(
+ mapper.version_id_col
+ == sql.bindparam(
+ mapper.version_id_col.key, type_=mapper.version_id_col.type
+ )
+ )
+
+ return table.delete().where(clauses)
+
+ statement = base_mapper._memo(("delete", table), delete_stmt)
+ for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
+ del_objects = [params for params, connection in recs]
+
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+ expected = len(del_objects)
+ rows_matched = -1
+ only_warn = False
+
+ if (
+ need_version_id
+ and not connection.dialect.supports_sane_multi_rowcount
+ ):
+ if connection.dialect.supports_sane_rowcount:
+ rows_matched = 0
+ # execute deletes individually so that versioned
+ # rows can be verified
+ for params in del_objects:
+
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+ rows_matched += c.rowcount
+ else:
+ util.warn(
+ "Dialect %s does not support deleted rowcount "
+ "- versioning cannot be verified."
+ % connection.dialect.dialect_description
+ )
+ connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
+ else:
+ c = connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
+
+ if not need_version_id:
+ only_warn = True
+
+ rows_matched = c.rowcount
+
+ if (
+ base_mapper.confirm_deleted_rows
+ and rows_matched > -1
+ and expected != rows_matched
+ and (
+ connection.dialect.supports_sane_multi_rowcount
+ or len(del_objects) == 1
+ )
+ ):
+ # TODO: why does this "only warn" if versioning is turned off,
+ # whereas the UPDATE raises?
+ if only_warn:
+ util.warn(
+ "DELETE statement on table '%s' expected to "
+ "delete %d row(s); %d were matched. Please set "
+ "confirm_deleted_rows=False within the mapper "
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
+ )
+ else:
+ raise orm_exc.StaleDataError(
+ "DELETE statement on table '%s' expected to "
+ "delete %d row(s); %d were matched. Please set "
+ "confirm_deleted_rows=False within the mapper "
+ "configuration to prevent this warning."
+ % (table.description, expected, rows_matched)
+ )
+
+
+def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
+ """finalize state on states that have been inserted or updated,
+ including calling after_insert/after_update events.
+
+ """
+ for state, state_dict, mapper, connection, has_identity in states:
+
+ if mapper._readonly_props:
+ readonly = state.unmodified_intersection(
+ [
+ p.key
+ for p in mapper._readonly_props
+ if (
+ p.expire_on_flush
+ and (not p.deferred or p.key in state.dict)
+ )
+ or (
+ not p.expire_on_flush
+ and not p.deferred
+ and p.key not in state.dict
+ )
+ ]
+ )
+ if readonly:
+ state._expire_attributes(state.dict, readonly)
+
+ # if eager_defaults option is enabled, load
+ # all expired cols. Else if we have a version_id_col, make sure
+ # it isn't expired.
+ toload_now = []
+
+ if base_mapper.eager_defaults:
+ toload_now.extend(
+ state._unloaded_non_object.intersection(
+ mapper._server_default_plus_onupdate_propkeys
+ )
+ )
+
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_generator is False
+ ):
+ if mapper._version_id_prop.key in state.unloaded:
+ toload_now.extend([mapper._version_id_prop.key])
+
+ if toload_now:
+ state.key = base_mapper._identity_key_from_state(state)
+ stmt = future.select(mapper).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ loading.load_on_ident(
+ uowtransaction.session,
+ stmt,
+ state.key,
+ refresh_state=state,
+ only_load_props=toload_now,
+ )
+
+ # call after_XXX extensions
+ if not has_identity:
+ mapper.dispatch.after_insert(mapper, connection, state)
+ else:
+ mapper.dispatch.after_update(mapper, connection, state)
+
+ if (
+ mapper.version_id_generator is False
+ and mapper.version_id_col is not None
+ ):
+ if state_dict[mapper._version_id_prop.key] is None:
+ raise orm_exc.FlushError(
+ "Instance does not contain a non-NULL version value"
+ )
+
+
+def _postfetch_post_update(
+ mapper, uowtransaction, table, state, dict_, result, params
+):
+ if uowtransaction.is_deleted(state):
+ return
+
+ prefetch_cols = result.context.compiled.prefetch
+ postfetch_cols = result.context.compiled.postfetch
+
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
+ refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
+ if refresh_flush:
+ load_evt_attrs = []
+
+ for c in prefetch_cols:
+ if c.key in params and c in mapper._columntoproperty:
+ dict_[mapper._columntoproperty[c].key] = params[c.key]
+ if refresh_flush:
+ load_evt_attrs.append(mapper._columntoproperty[c].key)
+
+ if refresh_flush and load_evt_attrs:
+ mapper.class_manager.dispatch.refresh_flush(
+ state, uowtransaction, load_evt_attrs
+ )
+
+ if postfetch_cols:
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
+
+
+def _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ dict_,
+ result,
+ params,
+ value_params,
+ isupdate,
+ returned_defaults,
+):
+ """Expire attributes in need of newly persisted database state,
+ after an INSERT or UPDATE statement has proceeded for that
+ state."""
+
+ prefetch_cols = result.context.compiled.prefetch
+ postfetch_cols = result.context.compiled.postfetch
+ returning_cols = result.context.compiled.returning
+
+ if (
+ mapper.version_id_col is not None
+ and mapper.version_id_col in mapper._cols_by_table[table]
+ ):
+ prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
+
+ refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
+ if refresh_flush:
+ load_evt_attrs = []
+
+ if returning_cols:
+ row = returned_defaults
+ if row is not None:
+ for row_value, col in zip(row, returning_cols):
+ # pk cols returned from insert are handled
+ # distinctly, don't step on the values here
+ if col.primary_key and result.context.isinsert:
+ continue
+
+ # note that columns can be in the "return defaults" that are
+ # not mapped to this mapper, typically because they are
+ # "excluded", which can be specified directly or also occurs
+ # when using declarative w/ single table inheritance
+ prop = mapper._columntoproperty.get(col)
+ if prop:
+ dict_[prop.key] = row_value
+ if refresh_flush:
+ load_evt_attrs.append(prop.key)
+
+ for c in prefetch_cols:
+ if c.key in params and c in mapper._columntoproperty:
+ dict_[mapper._columntoproperty[c].key] = params[c.key]
+ if refresh_flush:
+ load_evt_attrs.append(mapper._columntoproperty[c].key)
+
+ if refresh_flush and load_evt_attrs:
+ mapper.class_manager.dispatch.refresh_flush(
+ state, uowtransaction, load_evt_attrs
+ )
+
+ if isupdate and value_params:
+ # explicitly suit the use case specified by
+ # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
+ # database which are set to themselves in order to do a version bump.
+ postfetch_cols.extend(
+ [
+ col
+ for col in value_params
+ if col.primary_key and col not in returning_cols
+ ]
+ )
+
+ if postfetch_cols:
+ state._expire_attributes(
+ state.dict,
+ [
+ mapper._columntoproperty[c].key
+ for c in postfetch_cols
+ if c in mapper._columntoproperty
+ ],
+ )
+
+ # synchronize newly inserted ids from one table to the next
+ # TODO: this still goes a little too often. would be nice to
+ # have definitive list of "columns that changed" here
+ for m, equated_pairs in mapper._table_to_equated[table]:
+ sync.populate(
+ state,
+ m,
+ state,
+ m,
+ equated_pairs,
+ uowtransaction,
+ mapper.passive_updates,
+ )
+
+
+def _postfetch_bulk_save(mapper, dict_, table):
+ for m, equated_pairs in mapper._table_to_equated[table]:
+ sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
+
+
+def _connections_for_states(base_mapper, uowtransaction, states):
+ """Return an iterator of (state, state.dict, mapper, connection).
+
+ The states are sorted according to _sort_states, then paired
+ with the connection they should be using for the given
+ unit of work transaction.
+
+ """
+ # if session has a connection callable,
+ # organize individual states with the connection
+ # to use for update
+ if uowtransaction.session.connection_callable:
+ connection_callable = uowtransaction.session.connection_callable
+ else:
+ connection = uowtransaction.transaction.connection(base_mapper)
+ connection_callable = None
+
+ for state in _sort_states(base_mapper, states):
+ if connection_callable:
+ connection = connection_callable(base_mapper, state.obj())
+
+ mapper = state.manager.mapper
+
+ yield state, state.dict, mapper, connection
+
+
+def _sort_states(mapper, states):
+ pending = set(states)
+ persistent = set(s for s in pending if s.key is not None)
+ pending.difference_update(persistent)
+
+ try:
+ persistent_sorted = sorted(
+ persistent, key=mapper._persistent_sortkey_fn
+ )
+ except TypeError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Could not sort objects by primary key; primary key "
+ "values must be sortable in Python (was: %s)" % err
+ ),
+ replace_context=err,
+ )
+ return (
+ sorted(pending, key=operator.attrgetter("insert_order"))
+ + persistent_sorted
+ )
+
+
+_EMPTY_DICT = util.immutabledict()
+
+
+class BulkUDCompileState(CompileState):
+ class default_update_options(Options):
+ _synchronize_session = "evaluate"
+ _autoflush = True
+ _subject_mapper = None
+ _resolved_values = _EMPTY_DICT
+ _resolved_keys_as_propnames = _EMPTY_DICT
+ _value_evaluators = _EMPTY_DICT
+ _matched_objects = None
+ _matched_rows = None
+ _refresh_identity_token = None
+
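+    # Sketch of how these options are typically populated (``User`` is a
+    # hypothetical mapped class): an execution such as
+    #
+    #   session.execute(
+    #       update(User).where(User.name == "x").values(name="y"),
+    #       execution_options={"synchronize_session": "fetch"},
+    #   )
+    #
+    # routes the "synchronize_session" value into _synchronize_session via
+    # from_execution_options() in orm_pre_session_exec() below.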
+ @classmethod
+ def orm_pre_session_exec(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_reentrant_invoke,
+ ):
+ if is_reentrant_invoke:
+ return statement, execution_options
+
+ (
+ update_options,
+ execution_options,
+ ) = BulkUDCompileState.default_update_options.from_execution_options(
+ "_sa_orm_update_options",
+ {"synchronize_session"},
+ execution_options,
+ statement._execution_options,
+ )
+
+ sync = update_options._synchronize_session
+ if sync is not None:
+ if sync not in ("evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'evaluate', 'fetch', False"
+ )
+
+ bind_arguments["clause"] = statement
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
+ else:
+ bind_arguments["mapper"] = plugin_subject.mapper
+
+ update_options += {"_subject_mapper": plugin_subject.mapper}
+
+ if update_options._autoflush:
+ session._autoflush()
+
+ statement = statement._annotate(
+ {"synchronize_session": update_options._synchronize_session}
+ )
+
+ # this stage of the execution is called before the do_orm_execute event
+        # hook, meaning for an extension like horizontal sharding, this step
+ # happens before the extension splits out into multiple backends and
+ # runs only once. if we do pre_sync_fetch, we execute a SELECT
+ # statement, which the horizontal sharding extension splits amongst the
+ # shards and combines the results together.
+
+ if update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+
+ return (
+ statement,
+ util.immutabledict(execution_options).union(
+ {"_sa_orm_update_options": update_options}
+ ),
+ )
+
+ @classmethod
+ def orm_setup_cursor_result(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+
+ # this stage of the execution is called after the
+        # do_orm_execute event hook, meaning for an extension like
+ # horizontal sharding, this step happens *within* the horizontal
+ # sharding event handler which calls session.execute() re-entrantly
+ # and will occur for each backend individually.
+ # the sharding extension then returns its own merged result from the
+ # individual ones we return here.
+
+ update_options = execution_options["_sa_orm_update_options"]
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(session, result, update_options)
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(session, result, update_options)
+
+ return result
+
+ @classmethod
+ def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
+ """Apply extra criteria filtering.
+
+ For all distinct single-table-inheritance mappers represented in the
+ table being updated or deleted, produce additional WHERE criteria such
+ that only the appropriate subtypes are selected from the total results.
+
+ Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+ collected from the statement.
+
+ """
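+
+        # For example (a sketch; ``User`` is a hypothetical mapped class): a
+        # statement option created with
+        # with_loader_criteria(User, User.deleted == False) contributes its
+        # resolved criteria here, as does a single-table-inheritance
+        # discriminator condition on the mapper.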
+
+ return_crit = ()
+
+ adapter = ext_info._adapter if ext_info.is_aliased_class else None
+
+ if (
+ "additional_entity_criteria",
+ ext_info.mapper,
+ ) in global_attributes:
+ return_crit += tuple(
+ ae._resolve_where_criteria(ext_info)
+ for ae in global_attributes[
+ ("additional_entity_criteria", ext_info.mapper)
+ ]
+ if ae.include_aliases or ae.entity is ext_info
+ )
+
+ if ext_info.mapper._single_table_criterion is not None:
+ return_crit += (ext_info.mapper._single_table_criterion,)
+
+ if adapter:
+ return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
+
+ return return_crit
+
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+
+ value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
+
+ try:
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
+ if statement._where_criteria:
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(
+ global_attributes, mapper
+ )
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
+ else:
+
+ def eval_condition(obj):
+ return True
+
+ except evaluator.UnevaluatableError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ 'Could not evaluate current criteria in Python: "%s". '
+ "Specify 'fetch' or False for the "
+ "synchronize_session execution option." % err
+ ),
+ from_=err,
+ )
+
+ if statement.__visit_name__ == "lambda_element":
+ # ._resolved is called on every LambdaElement in order to
+ # generate the cache key, so this access does not add
+ # additional expense
+ effective_statement = statement._resolved
+ else:
+ effective_statement = statement
+
+ if effective_statement.__visit_name__ == "update":
+ resolved_values = cls._get_resolved_values(
+ mapper, effective_statement
+ )
+ value_evaluators = {}
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ # TODO: detect when the where clause is a trivial primary key match.
+ matched_objects = [
+ state.obj()
+ for state in session.identity_map.all_states()
+ if state.mapper.isa(mapper)
+ and not state.expired
+ and eval_condition(state.obj())
+ and (
+ update_options._refresh_identity_token is None
+ # TODO: coverage for the case where horizontal sharding
+ # invokes an update() or delete() given an explicit identity
+ # token up front
+ or state.identity_token
+ == update_options._refresh_identity_token
+ )
+ ]
+ return update_options + {
+ "_matched_objects": matched_objects,
+ "_value_evaluators": value_evaluators,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
+
+ @classmethod
+ def _get_resolved_values(cls, mapper, statement):
+ if statement._multi_values:
+ return []
+ elif statement._ordered_values:
+ return list(statement._ordered_values)
+ elif statement._values:
+ return list(statement._values.items())
+ else:
+ return []
+
+ @classmethod
+ def _resolved_keys_as_propnames(cls, mapper, resolved_values):
+ values = []
+ for k, v in resolved_values:
+ if isinstance(k, attributes.QueryableAttribute):
+ values.append((k.key, v))
+ continue
+ elif hasattr(k, "__clause_element__"):
+ k = k.__clause_element__()
+
+ if mapper and isinstance(k, expression.ColumnElement):
+ try:
+ attr = mapper._columntoproperty[k]
+ except orm_exc.UnmappedColumnError:
+ pass
+ else:
+ values.append((attr.key, v))
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Invalid expression type: %r" % k
+ )
+ return values
+
+ @classmethod
+ def _do_pre_synchronize_fetch(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
+
+ select_stmt = (
+ select(*(mapper.primary_key + (mapper.select_identity_token,)))
+ .select_from(mapper)
+ .options(*statement._with_options)
+ )
+ select_stmt._where_criteria = statement._where_criteria
+
+ def skip_for_full_returning(orm_context):
+ bind = orm_context.session.get_bind(**orm_context.bind_arguments)
+ if bind.dialect.full_returning:
+ return _result.null_result()
+ else:
+ return None
+
+ result = session.execute(
+ select_stmt,
+ params,
+ execution_options,
+ bind_arguments,
+ _add_event=skip_for_full_returning,
+ )
+ matched_rows = result.fetchall()
+
+ value_evaluators = _EMPTY_DICT
+
+ if statement.__visit_name__ == "lambda_element":
+ # ._resolved is called on every LambdaElement in order to
+ # generate the cache key, so this access does not add
+ # additional expense
+ effective_statement = statement._resolved
+ else:
+ effective_statement = statement
+
+ if effective_statement.__visit_name__ == "update":
+ target_cls = mapper.class_
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ resolved_values = cls._get_resolved_values(
+ mapper, effective_statement
+ )
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ value_evaluators = {}
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ else:
+ resolved_keys_as_propnames = _EMPTY_DICT
+
+ return update_options + {
+ "_value_evaluators": value_evaluators,
+ "_matched_rows": matched_rows,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
+
+
+class ORMDMLState:
+ @classmethod
+ def get_entity_description(cls, statement):
+ ext_info = statement.table._annotations["parententity"]
+ mapper = ext_info.mapper
+ if ext_info.is_aliased_class:
+ _label_name = ext_info.name
+ else:
+ _label_name = mapper.class_.__name__
+
+ return {
+ "name": _label_name,
+ "type": mapper.class_,
+ "expr": ext_info.entity,
+ "entity": ext_info.entity,
+ "table": mapper.local_table,
+ }
+
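+    # illustrative sketch: for a statement such as ``update(User)``, with
+    # ``User`` an assumed mapped class whose local table is ``user_table``,
+    # the description built above takes roughly the form
+    #
+    #   {"name": "User", "type": User, "expr": User,
+    #    "entity": User, "table": user_table}
+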
+ @classmethod
+ def get_returning_column_descriptions(cls, statement):
+ def _ent_for_col(c):
+ return c._annotations.get("parententity", None)
+
+ def _attr_for_col(c, ent):
+ if ent is None:
+ return c
+ proxy_key = c._annotations.get("proxy_key", None)
+ if not proxy_key:
+ return c
+ else:
+ return getattr(ent.entity, proxy_key, c)
+
+ return [
+ {
+ "name": c.key,
+ "type": c.type,
+ "expr": _attr_for_col(c, ent),
+ "aliased": ent.is_aliased_class,
+ "entity": ent.entity,
+ }
+ for c, ent in [
+ (c, _ent_for_col(c)) for c in statement._all_selected_columns
+ ]
+ ]
+
+
+@CompileState.plugin_for("orm", "insert")
+class ORMInsert(ORMDMLState, InsertDMLState):
+ @classmethod
+ def orm_pre_session_exec(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ is_reentrant_invoke,
+ ):
+ bind_arguments["clause"] = statement
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
+ else:
+ bind_arguments["mapper"] = plugin_subject.mapper
+
+ return (
+ statement,
+ util.immutabledict(execution_options),
+ )
+
+ @classmethod
+ def orm_setup_cursor_result(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+ return result
+
+
+@CompileState.plugin_for("orm", "update")
+class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+
+ self = cls.__new__(cls)
+
+ ext_info = statement.table._annotations["parententity"]
+
+ self.mapper = mapper = ext_info.mapper
+
+ self.extra_criteria_entities = {}
+
+ self._resolved_values = cls._get_resolved_values(mapper, statement)
+
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
+ if not statement._preserve_parameter_order and statement._values:
+ self._resolved_values = dict(self._resolved_values)
+
+ new_stmt = sql.Update.__new__(sql.Update)
+ new_stmt.__dict__.update(statement.__dict__)
+ new_stmt.table = mapper.local_table
+
+ # note if the statement has _multi_values, these
+ # are passed through to the new statement, which will then raise
+ # InvalidRequestError because UPDATE doesn't support multi_values
+ # right now.
+ if statement._ordered_values:
+ new_stmt._ordered_values = self._resolved_values
+ elif statement._values:
+ new_stmt._values = self._resolved_values
+
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ new_stmt = new_stmt.where(*new_crit)
+
+ # if we are against a lambda statement we might not be the
+ # topmost object that received per-execute annotations
+
+ if (
+ compiler._annotations.get("synchronize_session", None) == "fetch"
+ and compiler.dialect.full_returning
+ ):
+ if new_stmt._returning:
+ raise sa_exc.InvalidRequestError(
+ "Can't use synchronize_session='fetch' "
+ "with explicit returning()"
+ )
+ new_stmt = new_stmt.returning(*mapper.primary_key)
+
+ UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
+
+ return self
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator):
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return core_get_crud_kv_pairs(statement, kv_iterator)
+
+ mapper = plugin_subject.mapper
+
+ values = []
+
+ for k, v in kv_iterator:
+ k = coercions.expect(roles.DMLColumnRole, k)
+
+ if isinstance(k, util.string_types):
+ desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+ if desc is NO_VALUE:
+ values.append(
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+ )
+ else:
+ values.extend(
+ core_get_crud_kv_pairs(
+ statement, desc._bulk_update_tuples(v)
+ )
+ )
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["proxy_key"]
+ )
+ values.extend(
+ core_get_crud_kv_pairs(
+ statement, attr._bulk_update_tuples(v)
+ )
+ )
+ else:
+ values.append(
+ (
+ k,
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+ )
+ return values
+
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, result, update_options):
+
+ states = set()
+ evaluated_keys = list(update_options._value_evaluators.keys())
+ values = update_options._resolved_keys_as_propnames
+ attrib = set(k for k, v in values)
+ for obj in update_options._matched_objects:
+
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+
+ # the evaluated states were gathered across all identity tokens.
+ # however the post_sync events are called per identity token,
+ # so filter.
+ if (
+ update_options._refresh_identity_token is not None
+ and state.identity_token
+ != update_options._refresh_identity_token
+ ):
+ continue
+
+ # only evaluate unmodified attributes
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
+ for key in to_evaluate:
+ if key in dict_:
+ dict_[key] = update_options._value_evaluators[key](obj)
+
+ state.manager.dispatch.refresh(state, None, to_evaluate)
+
+ state._commit(dict_, list(to_evaluate))
+
+ to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ if to_expire:
+ state._expire_attributes(dict_, to_expire)
+
+ states.add(state)
+ session._register_altered(states)
+
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, result, update_options):
+ target_mapper = update_options._subject_mapper
+
+ states = set()
+ evaluated_keys = list(update_options._value_evaluators.keys())
+
+ if result.returns_rows:
+ matched_rows = [
+ tuple(row) + (update_options._refresh_identity_token,)
+ for row in result.all()
+ ]
+ else:
+ matched_rows = update_options._matched_rows
+
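+        # look up each matched primary key / identity token pair in the
+        # local identity map, skipping rows whose object is not present
+        # locally as well as rows for a different identity token when a
+        # specific one was requested up front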
+ objs = [
+ session.identity_map[identity_key]
+ for identity_key in [
+ target_mapper.identity_key_from_primary_key(
+ list(primary_key),
+ identity_token=identity_token,
+ )
+ for primary_key, identity_token in [
+ (row[0:-1], row[-1]) for row in matched_rows
+ ]
+ if update_options._refresh_identity_token is None
+ or identity_token == update_options._refresh_identity_token
+ ]
+ if identity_key in session.identity_map
+ ]
+
+ values = update_options._resolved_keys_as_propnames
+ attrib = set(k for k, v in values)
+
+ for obj in objs:
+ state, dict_ = (
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+
+ to_evaluate = state.unmodified.intersection(evaluated_keys)
+ for key in to_evaluate:
+ if key in dict_:
+ dict_[key] = update_options._value_evaluators[key](obj)
+ state.manager.dispatch.refresh(state, None, to_evaluate)
+
+ state._commit(dict_, list(to_evaluate))
+
+ to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ if to_expire:
+ state._expire_attributes(dict_, to_expire)
+
+ states.add(state)
+ session._register_altered(states)
+
+
+@CompileState.plugin_for("orm", "delete")
+class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ self = cls.__new__(cls)
+
+ ext_info = statement.table._annotations["parententity"]
+ self.mapper = mapper = ext_info.mapper
+
+ self.extra_criteria_entities = {}
+
+ extra_criteria_attributes = {}
+
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(extra_criteria_attributes)
+
+ new_crit = cls._adjust_for_extra_criteria(
+ extra_criteria_attributes, mapper
+ )
+ if new_crit:
+ statement = statement.where(*new_crit)
+
+ if (
+ mapper
+ and compiler._annotations.get("synchronize_session", None)
+ == "fetch"
+ and compiler.dialect.full_returning
+ ):
+ statement = statement.returning(*mapper.primary_key)
+
+ DeleteDMLState.__init__(self, statement, compiler, **kw)
+
+ return self
+
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, result, update_options):
+
+ session._remove_newly_deleted(
+ [
+ attributes.instance_state(obj)
+ for obj in update_options._matched_objects
+ ]
+ )
+
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, result, update_options):
+ target_mapper = update_options._subject_mapper
+
+ if result.returns_rows:
+ matched_rows = [
+ tuple(row) + (update_options._refresh_identity_token,)
+ for row in result.all()
+ ]
+ else:
+ matched_rows = update_options._matched_rows
+
+ for row in matched_rows:
+ primary_key = row[0:-1]
+ identity_token = row[-1]
+
+ # TODO: inline this and call remove_newly_deleted
+ # once
+ identity_key = target_mapper.identity_key_from_primary_key(
+ list(primary_key),
+ identity_token=identity_token,
+ )
+ if identity_key in session.identity_map:
+ session._remove_newly_deleted(
+ [
+ attributes.instance_state(
+ session.identity_map[identity_key]
+ )
+ ]
+ )
diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py
new file mode 100644
index 0000000..d32af17
--- /dev/null
+++ b/lib/sqlalchemy/orm/properties.py
@@ -0,0 +1,430 @@
+# orm/properties.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""MapperProperty implementations.
+
+This is a private module which defines the behavior of individual ORM-
+mapped attributes.
+
+"""
+from __future__ import absolute_import
+
+from . import attributes
+from .descriptor_props import CompositeProperty
+from .descriptor_props import ConcreteInheritedProperty
+from .descriptor_props import SynonymProperty
+from .interfaces import PropComparator
+from .interfaces import StrategizedProperty
+from .relationships import RelationshipProperty
+from .. import log
+from .. import util
+from ..sql import coercions
+from ..sql import roles
+
+
+__all__ = [
+ "ColumnProperty",
+ "CompositeProperty",
+ "ConcreteInheritedProperty",
+ "RelationshipProperty",
+ "SynonymProperty",
+]
+
+
+@log.class_logger
+class ColumnProperty(StrategizedProperty):
+ """Describes an object attribute that corresponds to a table column.
+
+ Public constructor is the :func:`_orm.column_property` function.
+
+ """
+
+ strategy_wildcard_key = "column"
+ inherit_cache = True
+ _links_to_entity = False
+
+ __slots__ = (
+ "columns",
+ "group",
+ "deferred",
+ "instrument",
+ "comparator_factory",
+ "descriptor",
+ "active_history",
+ "expire_on_flush",
+ "info",
+ "doc",
+ "strategy_key",
+ "_creation_order",
+ "_is_polymorphic_discriminator",
+ "_mapped_by_synonym",
+ "_deferred_column_loader",
+ "_raise_column_loader",
+ "_renders_in_subqueries",
+ "raiseload",
+ )
+
+ def __init__(self, *columns, **kwargs):
+ r"""Provide a column-level property for use with a mapping.
+
+ Column-based properties can normally be applied to the mapper's
+ ``properties`` dictionary using the :class:`_schema.Column`
+ element directly.
+ Use this function when the given column is not directly present within
+ the mapper's selectable; examples include SQL expressions, functions,
+ and scalar SELECT queries.
+
+ The :func:`_orm.column_property` function returns an instance of
+ :class:`.ColumnProperty`.
+
+ Columns that aren't present in the mapper's selectable won't be
+ persisted by the mapper and are effectively "read-only" attributes.
+
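+        E.g., a minimal sketch, assuming a declarative ``Base`` and the
+        usual :class:`_schema.Column` imports::
+
+            from sqlalchemy.orm import column_property
+
+            class User(Base):
+                __tablename__ = 'user'
+                id = Column(Integer, primary_key=True)
+                firstname = Column(String(50))
+                lastname = Column(String(50))
+                fullname = column_property(firstname + " " + lastname)
+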
+ :param \*cols:
+ list of Column objects to be mapped.
+
+ :param active_history=False:
+ When ``True``, indicates that the "previous" value for a
+ scalar attribute should be loaded when replaced, if not
+ already loaded. Normally, history tracking logic for
+ simple non-primary-key scalar values only needs to be
+ aware of the "new" value in order to perform a flush. This
+ flag is available for applications that make use of
+ :func:`.attributes.get_history` or :meth:`.Session.is_modified`
+ which also need to know
+ the "previous" value of the attribute.
+
+ :param comparator_factory: a class which extends
+ :class:`.ColumnProperty.Comparator` which provides custom SQL
+ clause generation for comparison operations.
+
+ :param group:
+ a group name for this property when marked as deferred.
+
+ :param deferred:
+ when True, the column property is "deferred", meaning that
+ it does not load immediately, and is instead loaded when the
+ attribute is first accessed on an instance. See also
+ :func:`~sqlalchemy.orm.deferred`.
+
+ :param doc:
+ optional string that will be applied as the doc on the
+ class-bound descriptor.
+
+ :param expire_on_flush=True:
+ Disable expiry on flush. A column_property() which refers
+ to a SQL expression (and not a single table-bound column)
+ is considered to be a "read only" property; populating it
+ has no effect on the state of data, and it can only return
+ database state. For this reason a column_property()'s value
+ is expired whenever the parent object is involved in a
+ flush, that is, has any kind of "dirty" state within a flush.
+ Setting this parameter to ``False`` will have the effect of
+ leaving any existing value present after the flush proceeds.
+            Note, however, that the :class:`.Session` with default expiration
+            settings still expires all attributes after a
+            :meth:`.Session.commit` call.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.MapperProperty.info` attribute of this object.
+
+ :param raiseload: if True, indicates the column should raise an error
+ when undeferred, rather than loading the value. This can be
+ altered at query time by using the :func:`.deferred` option with
+ raiseload=False.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`deferred_raiseload`
+
+ .. seealso::
+
+ :ref:`column_property_options` - to map columns while including
+ mapping options
+
+ :ref:`mapper_column_property_sql_expressions` - to map SQL
+ expressions
+
+ """
+ super(ColumnProperty, self).__init__()
+ self.columns = [
+ coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
+ ]
+ self.group = kwargs.pop("group", None)
+ self.deferred = kwargs.pop("deferred", False)
+ self.raiseload = kwargs.pop("raiseload", False)
+ self.instrument = kwargs.pop("_instrument", True)
+ self.comparator_factory = kwargs.pop(
+ "comparator_factory", self.__class__.Comparator
+ )
+ self.descriptor = kwargs.pop("descriptor", None)
+ self.active_history = kwargs.pop("active_history", False)
+ self.expire_on_flush = kwargs.pop("expire_on_flush", True)
+
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+
+ if "doc" in kwargs:
+ self.doc = kwargs.pop("doc")
+ else:
+ for col in reversed(self.columns):
+ doc = getattr(col, "doc", None)
+ if doc is not None:
+ self.doc = doc
+ break
+ else:
+ self.doc = None
+
+ if kwargs:
+ raise TypeError(
+ "%s received unexpected keyword argument(s): %s"
+ % (self.__class__.__name__, ", ".join(sorted(kwargs.keys())))
+ )
+
+ util.set_creation_order(self)
+
+ self.strategy_key = (
+ ("deferred", self.deferred),
+ ("instrument", self.instrument),
+ )
+ if self.raiseload:
+ self.strategy_key += (("raiseload", True),)
+
+ def _memoized_attr__renders_in_subqueries(self):
+ return ("deferred", True) not in self.strategy_key or (
+ self not in self.parent._readonly_props
+ )
+
+ @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
+ def _memoized_attr__deferred_column_loader(self):
+ state = util.preloaded.orm_state
+ strategies = util.preloaded.orm_strategies
+ return state.InstanceState._instance_level_callable_processor(
+ self.parent.class_manager,
+ strategies.LoadDeferredColumns(self.key),
+ self.key,
+ )
+
+ @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies")
+ def _memoized_attr__raise_column_loader(self):
+ state = util.preloaded.orm_state
+ strategies = util.preloaded.orm_strategies
+ return state.InstanceState._instance_level_callable_processor(
+ self.parent.class_manager,
+ strategies.LoadDeferredColumns(self.key, True),
+ self.key,
+ )
+
+ def __clause_element__(self):
+ """Allow the ColumnProperty to work in expression before it is turned
+ into an instrumented attribute.
+ """
+
+ return self.expression
+
+ @property
+ def expression(self):
+ """Return the primary column or expression for this ColumnProperty.
+
+ E.g.::
+
+
+ class File(Base):
+ # ...
+
+ name = Column(String(64))
+ extension = Column(String(8))
+ filename = column_property(name + '.' + extension)
+ path = column_property('C:/' + filename.expression)
+
+ .. seealso::
+
+ :ref:`mapper_column_property_sql_expressions_composed`
+
+ """
+ return self.columns[0]
+
+ def instrument_class(self, mapper):
+ if not self.instrument:
+ return
+
+ attributes.register_descriptor(
+ mapper.class_,
+ self.key,
+ comparator=self.comparator_factory(self, mapper),
+ parententity=mapper,
+ doc=self.doc,
+ )
+
+ def do_init(self):
+ super(ColumnProperty, self).do_init()
+
+ if len(self.columns) > 1 and set(self.parent.primary_key).issuperset(
+ self.columns
+ ):
+ util.warn(
+ (
+ "On mapper %s, primary key column '%s' is being combined "
+ "with distinct primary key column '%s' in attribute '%s'. "
+ "Use explicit properties to give each column its own "
+ "mapped attribute name."
+ )
+ % (self.parent, self.columns[1], self.columns[0], self.key)
+ )
+
+ def copy(self):
+ return ColumnProperty(
+ deferred=self.deferred,
+ group=self.group,
+ active_history=self.active_history,
+ *self.columns
+ )
+
+ def _getcommitted(
+ self, state, dict_, column, passive=attributes.PASSIVE_OFF
+ ):
+ return state.get_impl(self.key).get_committed_value(
+ state, dict_, passive=passive
+ )
+
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
+ if not self.instrument:
+ return
+ elif self.key in source_dict:
+ value = source_dict[self.key]
+
+ if not load:
+ dest_dict[self.key] = value
+ else:
+ impl = dest_state.get_impl(self.key)
+ impl.set(dest_state, dest_dict, value, None)
+ elif dest_state.has_identity and self.key not in dest_dict:
+ dest_state._expire_attributes(
+ dest_dict, [self.key], no_loader=True
+ )
+
+ class Comparator(util.MemoizedSlots, PropComparator):
+ """Produce boolean, comparison, and other operators for
+ :class:`.ColumnProperty` attributes.
+
+ See the documentation for :class:`.PropComparator` for a brief
+ overview.
+
+ .. seealso::
+
+ :class:`.PropComparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ __slots__ = "__clause_element__", "info", "expressions"
+
+ def _orm_annotate_column(self, column):
+ """annotate and possibly adapt a column to be returned
+ as the mapped-attribute exposed version of the column.
+
+ The column in this context needs to act as much like the
+        column in an ORM mapped context as possible, so it includes
+ annotations to give hints to various ORM functions as to
+ the source entity of this column. It also adapts it
+ to the mapper's with_polymorphic selectable if one is
+ present.
+
+ """
+
+ pe = self._parententity
+ annotations = {
+ "entity_namespace": pe,
+ "parententity": pe,
+ "parentmapper": pe,
+ "proxy_key": self.prop.key,
+ }
+
+ col = column
+
+ # for a mapper with polymorphic_on and an adapter, return
+ # the column against the polymorphic selectable.
+ # see also orm.util._orm_downgrade_polymorphic_columns
+ # for the reverse operation.
+ if self._parentmapper._polymorphic_adapter:
+ mapper_local_col = col
+ col = self._parentmapper._polymorphic_adapter.traverse(col)
+
+ # this is a clue to the ORM Query etc. that this column
+ # was adapted to the mapper's polymorphic_adapter. the
+                # ORM uses this hint to know which column it is adapting.
+ annotations["adapt_column"] = mapper_local_col
+
+ return col._annotate(annotations)._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": pe}
+ )
+
+ def _memoized_method___clause_element__(self):
+ if self.adapter:
+ return self.adapter(self.prop.columns[0], self.prop.key)
+ else:
+ return self._orm_annotate_column(self.prop.columns[0])
+
+ def _memoized_attr_info(self):
+ """The .info dictionary for this attribute."""
+
+ ce = self.__clause_element__()
+ try:
+ return ce.info
+ except AttributeError:
+ return self.prop.info
+
+ def _memoized_attr_expressions(self):
+ """The full sequence of columns referenced by this
+ attribute, adjusted for any aliasing in progress.
+
+ .. versionadded:: 1.3.17
+
+ """
+ if self.adapter:
+ return [
+ self.adapter(col, self.prop.key)
+ for col in self.prop.columns
+ ]
+ else:
+ return [
+ self._orm_annotate_column(col) for col in self.prop.columns
+ ]
+
+ def _fallback_getattr(self, key):
+ """proxy attribute access down to the mapped column.
+
+ this allows user-defined comparison methods to be accessed.
+ """
+ return getattr(self.__clause_element__(), key)
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.__clause_element__(), *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ col = self.__clause_element__()
+ return op(col._bind_param(op, other), col, **kwargs)
+
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
new file mode 100644
index 0000000..99e4591
--- /dev/null
+++ b/lib/sqlalchemy/orm/query.py
@@ -0,0 +1,3508 @@
+# orm/query.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The Query class and support.
+
+Defines the :class:`_query.Query` class, the central
+construct used by the ORM to construct database queries.
+
+The :class:`_query.Query` class should not be confused with the
+:class:`_expression.Select` class, which defines database
+SELECT operations at the SQL (non-ORM) level. ``Query`` differs from
+``Select`` in that it returns ORM-mapped objects and interacts with an
+ORM session, whereas the ``Select`` construct interacts directly with the
+database to return iterable result sets.
+
+"""
+import itertools
+import operator
+import types
+
+from . import exc as orm_exc
+from . import interfaces
+from . import loading
+from . import util as orm_util
+from .base import _assertions
+from .context import _column_descriptions
+from .context import _legacy_determine_last_joined_entity
+from .context import _legacy_filter_by_entity_zero
+from .context import LABEL_STYLE_LEGACY_ORM
+from .context import ORMCompileState
+from .context import ORMFromStatementCompileState
+from .context import QueryContext
+from .interfaces import ORMColumnsClauseRole
+from .util import aliased
+from .util import AliasedClass
+from .util import object_mapper
+from .util import with_parent
+from .util import with_polymorphic
+from .. import exc as sa_exc
+from .. import inspect
+from .. import inspection
+from .. import log
+from .. import sql
+from .. import util
+from ..sql import coercions
+from ..sql import elements
+from ..sql import expression
+from ..sql import roles
+from ..sql import Select
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.annotation import SupportsCloneAnnotations
+from ..sql.base import _entity_namespace_key
+from ..sql.base import _generative
+from ..sql.base import Executable
+from ..sql.selectable import _MemoizedSelectEntities
+from ..sql.selectable import _SelectFromElements
+from ..sql.selectable import ForUpdateArg
+from ..sql.selectable import GroupedElement
+from ..sql.selectable import HasHints
+from ..sql.selectable import HasPrefixes
+from ..sql.selectable import HasSuffixes
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectBase
+from ..sql.selectable import SelectStatementGrouping
+from ..sql.visitors import InternalTraversal
+from ..util import collections_abc
+
+__all__ = ["Query", "QueryContext", "aliased"]
+
+
+@inspection._self_inspects
+@log.class_logger
+class Query(
+ _SelectFromElements,
+ SupportsCloneAnnotations,
+ HasPrefixes,
+ HasSuffixes,
+ HasHints,
+ Executable,
+):
+
+ """ORM-level SQL construction object.
+
+ :class:`_query.Query`
+ is the source of all SELECT statements generated by the
+ ORM, both those formulated by end-user query operations as well as by
+ high level internal operations such as related collection loading. It
+ features a generative interface whereby successive calls return a new
+ :class:`_query.Query` object, a copy of the former with additional
+ criteria and options associated with it.
+
+ :class:`_query.Query` objects are normally initially generated using the
+ :meth:`~.Session.query` method of :class:`.Session`, and in
+ less common cases by instantiating the :class:`_query.Query` directly and
+ associating with a :class:`.Session` using the
+ :meth:`_query.Query.with_session`
+ method.
+
+ For a full walk through of :class:`_query.Query` usage, see the
+ :ref:`ormtutorial_toplevel`.
+
+ """
+
+ # elements that are in Core and can be cached in the same way
+ _where_criteria = ()
+ _having_criteria = ()
+
+ _order_by_clauses = ()
+ _group_by_clauses = ()
+ _limit_clause = None
+ _offset_clause = None
+
+ _distinct = False
+ _distinct_on = ()
+
+ _for_update_arg = None
+ _correlate = ()
+ _auto_correlate = True
+ _from_obj = ()
+ _setup_joins = ()
+ _legacy_setup_joins = ()
+ _label_style = LABEL_STYLE_LEGACY_ORM
+
+ _memoized_select_entities = ()
+
+ _compile_options = ORMCompileState.default_compile_options
+
+ load_options = QueryContext.default_load_options + {
+ "_legacy_uniquing": True
+ }
+
+ _params = util.EMPTY_DICT
+
+ # local Query builder state, not needed for
+ # compilation or execution
+ _aliased_generation = None
+ _enable_assertions = True
+ _last_joined_entity = None
+ _statement = None
+
+ # mirrors that of ClauseElement, used to propagate the "orm"
+ # plugin as well as the "subject" of the plugin, e.g. the mapper
+ # we are querying against.
+ _propagate_attrs = util.immutabledict()
+
+ def __init__(self, entities, session=None):
+ """Construct a :class:`_query.Query` directly.
+
+ E.g.::
+
+ q = Query([User, Address], session=some_session)
+
+ The above is equivalent to::
+
+ q = some_session.query(User, Address)
+
+ :param entities: a sequence of entities and/or SQL expressions.
+
+ :param session: a :class:`.Session` with which the
+ :class:`_query.Query`
+ will be associated. Optional; a :class:`_query.Query`
+ can be associated
+ with a :class:`.Session` generatively via the
+ :meth:`_query.Query.with_session` method as well.
+
+ .. seealso::
+
+ :meth:`.Session.query`
+
+ :meth:`_query.Query.with_session`
+
+ """
+
+ self.session = session
+ self._set_entities(entities)
+
+ def _set_propagate_attrs(self, values):
+ self._propagate_attrs = util.immutabledict(values)
+ return self
+
+ def _set_entities(self, entities):
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole,
+ ent,
+ apply_propagate_attrs=self,
+ post_inspect=True,
+ )
+ for ent in util.to_list(entities)
+ ]
+
+ def _entity_from_pre_ent_zero(self):
+ if not self._raw_columns:
+ return None
+
+ ent = self._raw_columns[0]
+
+ if "parententity" in ent._annotations:
+ return ent._annotations["parententity"]
+ elif isinstance(ent, ORMColumnsClauseRole):
+ return ent.entity
+ elif "bundle" in ent._annotations:
+ return ent._annotations["bundle"]
+ else:
+ # label, other SQL expression
+ for element in visitors.iterate(ent):
+ if "parententity" in element._annotations:
+ return element._annotations["parententity"]
+ else:
+ return None
+
+ def _only_full_mapper_zero(self, methname):
+ if (
+ len(self._raw_columns) != 1
+ or "parententity" not in self._raw_columns[0]._annotations
+ or not self._raw_columns[0].is_selectable
+ ):
+ raise sa_exc.InvalidRequestError(
+ "%s() can only be used against "
+ "a single mapped class." % methname
+ )
+
+ return self._raw_columns[0]._annotations["parententity"]
+
+ def _set_select_from(self, obj, set_base_alias):
+ fa = [
+ coercions.expect(
+ roles.StrictFromClauseRole,
+ elem,
+ allow_select=True,
+ apply_propagate_attrs=self,
+ )
+ for elem in obj
+ ]
+
+ self._compile_options += {"_set_base_alias": set_base_alias}
+ self._from_obj = tuple(fa)
+
+ @_generative
+ def _set_lazyload_from(self, state):
+ self.load_options += {"_lazy_loaded_from": state}
+
+ def _get_condition(self):
+ return self._no_criterion_condition(
+ "get", order_by=False, distinct=False
+ )
+
+ def _get_existing_condition(self):
+ self._no_criterion_assertion("get", order_by=False, distinct=False)
+
+ def _no_criterion_assertion(self, meth, order_by=True, distinct=True):
+ if not self._enable_assertions:
+ return
+ if (
+ self._where_criteria
+ or self._statement is not None
+ or self._from_obj
+ or self._legacy_setup_joins
+ or self._limit_clause is not None
+ or self._offset_clause is not None
+ or self._group_by_clauses
+ or (order_by and self._order_by_clauses)
+ or (distinct and self._distinct)
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Query.%s() being called on a "
+ "Query with existing criterion. " % meth
+ )
+
+ def _no_criterion_condition(self, meth, order_by=True, distinct=True):
+ self._no_criterion_assertion(meth, order_by, distinct)
+
+ self._from_obj = self._legacy_setup_joins = ()
+ if self._statement is not None:
+ self._compile_options += {"_statement": None}
+ self._where_criteria = ()
+ self._distinct = False
+
+ self._order_by_clauses = self._group_by_clauses = ()
+
+ def _no_clauseelement_condition(self, meth):
+ if not self._enable_assertions:
+ return
+ if self._order_by_clauses:
+ raise sa_exc.InvalidRequestError(
+ "Query.%s() being called on a "
+ "Query with existing criterion. " % meth
+ )
+ self._no_criterion_condition(meth)
+
+ def _no_statement_condition(self, meth):
+ if not self._enable_assertions:
+ return
+ if self._statement is not None:
+ raise sa_exc.InvalidRequestError(
+ (
+ "Query.%s() being called on a Query with an existing full "
+ "statement - can't apply criterion."
+ )
+ % meth
+ )
+
+ def _no_limit_offset(self, meth):
+ if not self._enable_assertions:
+ return
+ if self._limit_clause is not None or self._offset_clause is not None:
+ raise sa_exc.InvalidRequestError(
+ "Query.%s() being called on a Query which already has LIMIT "
+ "or OFFSET applied. Call %s() before limit() or offset() "
+ "are applied." % (meth, meth)
+ )
+
+ @property
+ def _has_row_limiting_clause(self):
+ return (
+ self._limit_clause is not None or self._offset_clause is not None
+ )
+
+ def _get_options(
+ self,
+ populate_existing=None,
+ version_check=None,
+ only_load_props=None,
+ refresh_state=None,
+ identity_token=None,
+ ):
+ load_options = {}
+ compile_options = {}
+
+ if version_check:
+ load_options["_version_check"] = version_check
+ if populate_existing:
+ load_options["_populate_existing"] = populate_existing
+ if refresh_state:
+ load_options["_refresh_state"] = refresh_state
+ compile_options["_for_refresh_state"] = True
+ if only_load_props:
+ compile_options["_only_load_props"] = frozenset(only_load_props)
+ if identity_token:
+ load_options["_refresh_identity_token"] = identity_token
+
+ if load_options:
+ self.load_options += load_options
+ if compile_options:
+ self._compile_options += compile_options
+
+ return self
+
+ def _clone(self):
+ return self._generate()
+
+ @property
+ def statement(self):
+ """The full SELECT statement represented by this Query.
+
+ The statement by default will not have disambiguating labels
+        applied to the construct unless with_labels() is called
+ first.
+
+ """
+
+ # .statement can return the direct future.Select() construct here, as
+ # long as we are not using subsequent adaption features that
+ # are made against raw entities, e.g. from_self(), with_polymorphic(),
+ # select_entity_from(). If these features are being used, then
+ # the Select() we return will not have the correct .selected_columns
+ # collection and will not embed in subsequent queries correctly.
+ # We could find a way to make this collection "correct", however
+ # this would not be too different from doing the full compile as
+ # we are doing in any case, the Select() would still not have the
+ # proper state for other attributes like whereclause, order_by,
+ # and these features are all deprecated in any case.
+ #
+ # for these reasons, Query is not a Select, it remains an ORM
+ # object for which __clause_element__() must be called in order for
+ # it to provide a real expression object.
+ #
+ # from there, it starts to look much like Query itself won't be
+        # passed into the execute process and won't generate its own cache
+ # key; this will all occur in terms of the ORM-enabled Select.
+ if (
+ not self._compile_options._set_base_alias
+ and not self._compile_options._with_polymorphic_adapt_map
+ ):
+ # if we don't have legacy top level aliasing features in use
+ # then convert to a future select() directly
+ stmt = self._statement_20(for_statement=True)
+ else:
+ stmt = self._compile_state(for_statement=True).statement
+
+ if self._params:
+ stmt = stmt.params(self._params)
+
+ return stmt
+
+ def _final_statement(self, legacy_query_style=True):
+ """Return the 'final' SELECT statement for this :class:`.Query`.
+
+ This is the Core-only select() that will be rendered by a complete
+ compilation of this query, and is what .statement used to return
+ in 1.3.
+
+        This method creates a complete compile state, so it is fairly
+        expensive.
+
+ """
+
+ q = self._clone()
+
+ return q._compile_state(
+ use_legacy_query_style=legacy_query_style
+ ).statement
+
+ def _statement_20(self, for_statement=False, use_legacy_query_style=True):
+ # TODO: this event needs to be deprecated, as it currently applies
+ # only to ORM query and occurs at this spot that is now more
+ # or less an artificial spot
+ if self.dispatch.before_compile:
+ for fn in self.dispatch.before_compile:
+ new_query = fn(self)
+ if new_query is not None and new_query is not self:
+ self = new_query
+ if not fn._bake_ok:
+ self._compile_options += {"_bake_ok": False}
+
+ compile_options = self._compile_options
+ compile_options += {
+ "_for_statement": for_statement,
+ "_use_legacy_query_style": use_legacy_query_style,
+ }
+
+ if self._statement is not None:
+ stmt = FromStatement(self._raw_columns, self._statement)
+ stmt.__dict__.update(
+ _with_options=self._with_options,
+ _with_context_options=self._with_context_options,
+ _compile_options=compile_options,
+ _execution_options=self._execution_options,
+ _propagate_attrs=self._propagate_attrs,
+ )
+ else:
+ # Query / select() internal attributes are 99% cross-compatible
+ stmt = Select._create_raw_select(**self.__dict__)
+ stmt.__dict__.update(
+ _label_style=self._label_style,
+ _compile_options=compile_options,
+ _propagate_attrs=self._propagate_attrs,
+ )
+ stmt.__dict__.pop("session", None)
+
+ # ensure the ORM context is used to compile the statement, even
+ # if it has no ORM entities. This is so ORM-only things like
+ # _legacy_joins are picked up that wouldn't be picked up by the
+ # Core statement context
+ if "compile_state_plugin" not in stmt._propagate_attrs:
+ stmt._propagate_attrs = stmt._propagate_attrs.union(
+ {"compile_state_plugin": "orm", "plugin_subject": None}
+ )
+
+ return stmt
+
+ def subquery(
+ self,
+ name=None,
+ with_labels=False,
+ reduce_columns=False,
+ ):
+ """Return the full SELECT statement represented by
+ this :class:`_query.Query`, embedded within an
+ :class:`_expression.Alias`.
+
+ Eager JOIN generation within the query is disabled.
+
+ :param name: string name to be assigned as the alias;
+ this is passed through to :meth:`_expression.FromClause.alias`.
+ If ``None``, a name will be deterministically generated
+ at compile time.
+
+ :param with_labels: if True, :meth:`.with_labels` will be called
+ on the :class:`_query.Query` first to apply table-qualified labels
+ to all columns.
+
+ :param reduce_columns: if True,
+ :meth:`_expression.Select.reduce_columns` will
+ be called on the resulting :func:`_expression.select` construct,
+ to remove same-named columns where one also refers to the other
+ via foreign key or WHERE clause equivalence.
+
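+        E.g., a minimal sketch, assuming ``User`` is a mapped class,
+        ``session`` is a :class:`.Session`, and ``select`` is imported from
+        ``sqlalchemy``::
+
+            subq = (
+                session.query(User.id, User.name)
+                .filter(User.name.like("e%"))
+                .subquery()
+            )
+
+            # the subquery behaves like a Core alias; columns are on .c
+            stmt = select(subq.c.id, subq.c.name)
+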
+ """
+ q = self.enable_eagerloads(False)
+ if with_labels:
+ q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ q = q.statement
+
+ if reduce_columns:
+ q = q.reduce_columns()
+ return q.alias(name=name)
+
+ def cte(self, name=None, recursive=False, nesting=False):
+ r"""Return the full SELECT statement represented by this
+ :class:`_query.Query` represented as a common table expression (CTE).
+
+ Parameters and usage are the same as those of the
+ :meth:`_expression.SelectBase.cte` method; see that method for
+ further details.
+
+ Here is the `PostgreSQL WITH
+ RECURSIVE example
+ <https://www.postgresql.org/docs/current/static/queries-with.html>`_.
+ Note that, in this example, the ``included_parts`` cte and the
+ ``incl_alias`` alias of it are Core selectables, which
+ means the columns are accessed via the ``.c.`` attribute. The
+ ``parts_alias`` object is an :func:`_orm.aliased` instance of the
+ ``Part`` entity, so column-mapped attributes are available
+ directly::
+
+ from sqlalchemy.orm import aliased
+
+ class Part(Base):
+ __tablename__ = 'part'
+ part = Column(String, primary_key=True)
+ sub_part = Column(String, primary_key=True)
+ quantity = Column(Integer)
+
+ included_parts = session.query(
+ Part.sub_part,
+ Part.part,
+ Part.quantity).\
+ filter(Part.part=="our part").\
+ cte(name="included_parts", recursive=True)
+
+ incl_alias = aliased(included_parts, name="pr")
+ parts_alias = aliased(Part, name="p")
+ included_parts = included_parts.union_all(
+ session.query(
+ parts_alias.sub_part,
+ parts_alias.part,
+ parts_alias.quantity).\
+ filter(parts_alias.part==incl_alias.c.sub_part)
+ )
+
+ q = session.query(
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).
+ label('total_quantity')
+ ).\
+ group_by(included_parts.c.sub_part)
+
+ .. seealso::
+
+ :meth:`_expression.HasCTE.cte`
+
+ """
+ return self.enable_eagerloads(False).statement.cte(
+ name=name, recursive=recursive, nesting=nesting
+ )
+
+ def label(self, name):
+ """Return the full SELECT statement represented by this
+ :class:`_query.Query`, converted
+ to a scalar subquery with a label of the given name.
+
+ Analogous to :meth:`sqlalchemy.sql.expression.SelectBase.label`.
+
+ """
+
+ return self.enable_eagerloads(False).statement.label(name)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_query.Query.as_scalar` method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_query.Query.scalar_subquery`.",
+ )
+ def as_scalar(self):
+ """Return the full SELECT statement represented by this
+ :class:`_query.Query`, converted to a scalar subquery.
+
+ """
+ return self.scalar_subquery()
+
+ def scalar_subquery(self):
+ """Return the full SELECT statement represented by this
+ :class:`_query.Query`, converted to a scalar subquery.
+
+ Analogous to
+ :meth:`sqlalchemy.sql.expression.SelectBase.scalar_subquery`.
+
+ .. versionchanged:: 1.4 The :meth:`_query.Query.scalar_subquery`
+ method replaces the :meth:`_query.Query.as_scalar` method.
+
+ """
+
+ return self.enable_eagerloads(False).statement.scalar_subquery()
+
+ @property
+ def selectable(self):
+ """Return the :class:`_expression.Select` object emitted by this
+ :class:`_query.Query`.
+
+ Used for :func:`_sa.inspect` compatibility, this is equivalent to::
+
+ query.enable_eagerloads(False).with_labels().statement
+
+ """
+ return self.__clause_element__()
+
+ def __clause_element__(self):
+ return (
+ self._with_compile_options(
+ _enable_eagerloads=False, _render_for_subquery=True
+ )
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .statement
+ )
+
+ @_generative
+ def only_return_tuples(self, value):
+ """When set to True, the query results will always be a tuple.
+
+ This is specifically for single element queries. The default is False.
+
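+        E.g., a minimal sketch, assuming ``User`` is a mapped class::
+
+            # each result row is a one-element tuple rather than a bare
+            # User instance
+            for row in session.query(User).only_return_tuples(True):
+                user = row[0]
+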
+ .. versionadded:: 1.2.5
+
+ .. seealso::
+
+ :meth:`_query.Query.is_single_entity`
+
+ """
+ self.load_options += dict(_only_return_tuples=value)
+
+ @property
+ def is_single_entity(self):
+ """Indicates if this :class:`_query.Query`
+ returns tuples or single entities.
+
+ Returns True if this query returns a single entity for each instance
+ in its result list, and False if this query returns a tuple of entities
+ for each result.
+
+ .. versionadded:: 1.3.11
+
+ .. seealso::
+
+ :meth:`_query.Query.only_return_tuples`
+
+ """
+ return (
+ not self.load_options._only_return_tuples
+ and len(self._raw_columns) == 1
+ and "parententity" in self._raw_columns[0]._annotations
+ and isinstance(
+ self._raw_columns[0]._annotations["parententity"],
+ ORMColumnsClauseRole,
+ )
+ )
+
+ @_generative
+ def enable_eagerloads(self, value):
+ """Control whether or not eager joins and subqueries are
+ rendered.
+
+ When set to False, the returned Query will not render
+ eager joins regardless of :func:`~sqlalchemy.orm.joinedload`,
+ :func:`~sqlalchemy.orm.subqueryload` options
+ or mapper-level ``lazy='joined'``/``lazy='subquery'``
+ configurations.
+
+ This is used primarily when nesting the Query's
+ statement into a subquery or other
+ selectable, or when using :meth:`_query.Query.yield_per`.
+
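+        E.g., a minimal sketch, assuming ``User`` is a mapped class::
+
+            # render the query as a subquery without any eager JOINs
+            subq = session.query(User).enable_eagerloads(False).subquery()
+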
+ """
+ self._compile_options += {"_enable_eagerloads": value}
+
+ @_generative
+ def _with_compile_options(self, **opt):
+ self._compile_options += opt
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.with_labels` and :meth:`_orm.Query.apply_labels`",
+ alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
+ "instead.",
+ )
+ def with_labels(self):
+ return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ apply_labels = with_labels
+
+ @property
+ def get_label_style(self):
+ """
+ Retrieve the current label style.
+
+ .. versionadded:: 1.4
+
+ """
+ return self._label_style
+
+ def set_label_style(self, style):
+ """Apply column labels to the return value of Query.statement.
+
+ Indicates that this Query's `statement` accessor should return
+ a SELECT statement that applies labels to all columns in the
+ form <tablename>_<columnname>; this is commonly used to
+ disambiguate columns from multiple tables which have the same
+ name.
+
+ When the `Query` actually issues SQL to load rows, it always
+ uses column labeling.
+
+        .. note:: The :meth:`_query.Query.set_label_style` method *only*
+            applies to the output of :attr:`_query.Query.statement`, and *not*
+            to any of
+ the result-row invoking systems of :class:`_query.Query` itself,
+ e.g.
+ :meth:`_query.Query.first`, :meth:`_query.Query.all`, etc.
+ To execute
+ a query using :meth:`_query.Query.set_label_style`, invoke the
+ :attr:`_query.Query.statement` using :meth:`.Session.execute`::
+
+ result = session.execute(
+ query
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .statement
+ )
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+ if self._label_style is not style:
+ self = self._generate()
+ self._label_style = style
+ return self
+
+ @_generative
+ def enable_assertions(self, value):
+ """Control whether assertions are generated.
+
+ When set to False, the returned Query will
+ not assert its state before certain operations,
+ including that LIMIT/OFFSET has not been applied
+ when filter() is called, no criterion exists
+ when get() is called, and no "from_statement()"
+ exists when filter()/order_by()/group_by() etc.
+ is called. This more permissive mode is used by
+ custom Query subclasses to specify criterion or
+ other modifiers outside of the usual usage patterns.
+
+ Care should be taken to ensure that the usage
+ pattern is even possible. A statement applied
+ by from_statement() will override any criterion
+ set by filter() or order_by(), for example.
+
+ """
+ self._enable_assertions = value
+
+ @property
+ def whereclause(self):
+ """A readonly attribute which returns the current WHERE criterion for
+ this Query.
+
+ This returned value is a SQL expression construct, or ``None`` if no
+ criterion has been established.
+
+ """
+ return sql.elements.BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
+ @_generative
+ def _with_current_path(self, path):
+ """indicate that this query applies to objects loaded
+ within a certain path.
+
+ Used by deferred loaders (see strategies.py) which transfer
+ query options from an originating query to a newly generated
+ query intended for the deferred load.
+
+ """
+ self._compile_options += {"_current_path": path}
+
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ @util.deprecated_20(
+ ":meth:`_orm.Query.with_polymorphic`",
+ alternative="Use the orm.with_polymorphic() standalone function",
+ )
+ def with_polymorphic(
+ self, cls_or_mappers, selectable=None, polymorphic_on=None
+ ):
+ """Load columns for inheriting classes.
+
+ This is a legacy method which is replaced by the
+ :func:`_orm.with_polymorphic` function.
+
+ .. warning:: The :meth:`_orm.Query.with_polymorphic` method does
+ **not** support 1.4/2.0 style features including
+ :func:`_orm.with_loader_criteria`. Please migrate code
+ to use :func:`_orm.with_polymorphic`.
+
+ :meth:`_query.Query.with_polymorphic` applies transformations
+ to the "main" mapped class represented by this :class:`_query.Query`.
+ The "main" mapped class here means the :class:`_query.Query`
+ object's first argument is a full class, i.e.
+ ``session.query(SomeClass)``. These transformations allow additional
+ tables to be present in the FROM clause so that columns for a
+ joined-inheritance subclass are available in the query, both for the
+ purposes of load-time efficiency as well as the ability to use
+ these columns at query time.
+
+ .. seealso::
+
+ :ref:`with_polymorphic` - illustrates current patterns
+
+ """
+
+ entity = _legacy_filter_by_entity_zero(self)
+
+ wp = with_polymorphic(
+ entity,
+ cls_or_mappers,
+ selectable=selectable,
+ polymorphic_on=polymorphic_on,
+ )
+
+ self._compile_options = self._compile_options.add_to_element(
+ "_with_polymorphic_adapt_map", ((entity, inspect(wp)),)
+ )
+
+ @_generative
+ def yield_per(self, count):
+ r"""Yield only ``count`` rows at a time.
+
+ The purpose of this method is when fetching very large result sets
+ (> 10K rows), to batch results in sub-collections and yield them
+        out partially, so that the Python interpreter doesn't need to allocate
+        very large areas of memory, which is both time consuming and leads
+ to excessive memory use. The performance from fetching hundreds of
+ thousands of rows can often double when a suitable yield-per setting
+ (e.g. approximately 1000) is used, even with DBAPIs that buffer
+ rows (which are most).
+
+ As of SQLAlchemy 1.4, the :meth:`_orm.Query.yield_per` method is
+ equivalent to using the ``yield_per`` execution option at the ORM
+ level. See the section :ref:`orm_queryguide_yield_per` for further
+ background on this option.
+
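+        E.g., a minimal sketch, assuming ``User`` is a mapped class::
+
+            for user in session.query(User).yield_per(1000):
+                print(user.name)
+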
+ .. seealso::
+
+ :ref:`orm_queryguide_yield_per`
+
+ """
+ self.load_options += {"_yield_per": count}
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.get`",
+ alternative="The method is now available as :meth:`_orm.Session.get`",
+ becomes_legacy=True,
+ )
+ def get(self, ident):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
+ E.g.::
+
+ my_user = session.query(User).get(5)
+
+ some_object = session.query(VersionedFoo).get((5, 10))
+
+ some_object = session.query(VersionedFoo).get(
+ {"id": 5, "version_id": 10})
+
+ :meth:`_query.Query.get` is special in that it provides direct
+ access to the identity map of the owning :class:`.Session`.
+ If the given primary key identifier is present
+ in the local identity map, the object is returned
+ directly from this collection and no SQL is emitted,
+ unless the object has been marked fully expired.
+ If not present,
+ a SELECT is performed in order to locate the object.
+
+ :meth:`_query.Query.get` also will perform a check if
+ the object is present in the identity map and
+ marked as expired - a SELECT
+ is emitted to refresh the object as well as to
+ ensure that the row is still present.
+ If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ :meth:`_query.Query.get` is only used to return a single
+ mapped instance, not multiple instances or
+ individual column constructs, and strictly
+ on a single primary key value. The originating
+ :class:`_query.Query` must be constructed in this way,
+ i.e. against a single mapped entity,
+ with no additional filtering criterion. Loading
+ options via :meth:`_query.Query.options` may be applied
+ however, and will be used if the object is not
+ yet locally present.
+
+ :param ident: A scalar, tuple, or dictionary representing the
+ primary key. For a composite (e.g. multiple column) primary key,
+ a tuple or dictionary should be passed.
+
+ For a single-column primary key, the scalar calling form is typically
+ the most expedient. If the primary key of a row is the value "5",
+ the call looks like::
+
+ my_object = query.get(5)
+
+ The tuple form contains primary key values typically in
+ the order in which they correspond to the mapped
+ :class:`_schema.Table`
+ object's primary key columns, or if the
+ :paramref:`_orm.Mapper.primary_key` configuration parameter were
+ used, in
+ the order used for that parameter. For example, if the primary key
+ of a row is represented by the integer
+ digits "5, 10" the call would look like::
+
+ my_object = query.get((5, 10))
+
+ The dictionary form should include as keys the mapped attribute names
+ corresponding to each element of the primary key. If the mapped class
+ has the attributes ``id``, ``version_id`` as the attributes which
+ store the object's primary key value, the call would look like::
+
+ my_object = query.get({"id": 5, "version_id": 10})
+
+ .. versionadded:: 1.3 the :meth:`_query.Query.get`
+ method now optionally
+ accepts a dictionary of attribute names to values in order to
+ indicate a primary key identifier.
+
+
+ :return: The object instance, or ``None``.
+
+ """
+ self._no_criterion_assertion("get", order_by=False, distinct=False)
+
+ # we still implement _get_impl() so that baked query can override
+ # it
+ return self._get_impl(ident, loading.load_on_pk_identity)
+
+ def _get_impl(self, primary_key_identity, db_load_fn, identity_token=None):
+ mapper = self._only_full_mapper_zero("get")
+ return self.session._get_impl(
+ mapper,
+ primary_key_identity,
+ db_load_fn,
+ populate_existing=self.load_options._populate_existing,
+ with_for_update=self._for_update_arg,
+ options=self._with_options,
+ identity_token=identity_token,
+ execution_options=self._execution_options,
+ )
+
+ @property
+ def lazy_loaded_from(self):
+ """An :class:`.InstanceState` that is using this :class:`_query.Query`
+ for a lazy load operation.
+
+ .. deprecated:: 1.4 This attribute should be viewed via the
+ :attr:`.ORMExecuteState.lazy_loaded_from` attribute, within
+ the context of the :meth:`.SessionEvents.do_orm_execute`
+ event.
+
+ .. seealso::
+
+ :attr:`.ORMExecuteState.lazy_loaded_from`
+
+ """
+ return self.load_options._lazy_loaded_from
+
+ @property
+ def _current_path(self):
+ return self._compile_options._current_path
+
+ @_generative
+ def correlate(self, *fromclauses):
+ """Return a :class:`.Query` construct which will correlate the given
+ FROM clauses to that of an enclosing :class:`.Query` or
+ :func:`~.expression.select`.
+
+ The method here accepts mapped classes, :func:`.aliased` constructs,
+ and :func:`.mapper` constructs as arguments, which are resolved
+ into expression constructs, in addition to accepting plain
+ expression constructs directly.
+
+ The correlation arguments are ultimately passed to
+ :meth:`_expression.Select.correlate`
+ after coercion to expression constructs.
+
+ The correlation arguments take effect in such cases
+ as when :meth:`_query.Query.from_self` is used, or when
+ a subquery as returned by :meth:`_query.Query.subquery` is
+ embedded in another :func:`_expression.select` construct.
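+
+ For example (an illustrative sketch, assuming mapped ``User`` and
+ ``Address`` classes and ``func`` imported from ``sqlalchemy``), a
+ correlated scalar subquery counting addresses per user might be
+ built as::
+
+ subq = session.query(func.count(Address.id)).\
+ filter(Address.user_id == User.id).\
+ correlate(User).\
+ scalar_subquery()
+
+ q = session.query(User.name, subq)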
+
+ """
+
+ self._auto_correlate = False
+ if fromclauses and fromclauses[0] in {None, False}:
+ self._correlate = ()
+ else:
+ self._correlate = set(self._correlate).union(
+ coercions.expect(roles.FromClauseRole, f) for f in fromclauses
+ )
+
+ @_generative
+ def autoflush(self, setting):
+ """Return a Query with a specific 'autoflush' setting.
+
+ As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method
+ is equivalent to using the ``autoflush`` execution option at the
+ ORM level. See the section :ref:`orm_queryguide_autoflush` for
+ further background on this option.
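+
+ For example (illustrative only, assuming a mapped ``User`` class),
+ autoflush may be disabled for a single query::
+
+ q = session.query(User).autoflush(False).filter(User.id == 5)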
+
+ """
+ self.load_options += {"_autoflush": setting}
+
+ @_generative
+ def populate_existing(self):
+ """Return a :class:`_query.Query`
+ that will expire and refresh all instances
+ as they are loaded, or reused from the current :class:`.Session`.
+
+ As of SQLAlchemy 1.4, the :meth:`_orm.Query.populate_existing` method
+ is equivalent to using the ``populate_existing`` execution option at
+ the ORM level. See the section :ref:`orm_queryguide_populate_existing`
+ for further background on this option.
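+
+ For example (illustrative only, assuming a mapped ``User`` class)::
+
+ users = session.query(User).populate_existing().all()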
+
+ """
+ self.load_options += {"_populate_existing": True}
+
+ @_generative
+ def _with_invoke_all_eagers(self, value):
+ """Set the 'invoke all eagers' flag which causes joined- and
+ subquery loaders to traverse into already-loaded related objects
+ and collections.
+
+ Default is that of :attr:`_query.Query._invoke_all_eagers`.
+
+ """
+ self.load_options += {"_invoke_all_eagers": value}
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.with_parent`",
+ alternative="Use the :func:`_orm.with_parent` standalone construct.",
+ becomes_legacy=True,
+ )
+ @util.preload_module("sqlalchemy.orm.relationships")
+ def with_parent(self, instance, property=None, from_entity=None): # noqa
+ """Add filtering criterion that relates the given instance
+ to a child object or collection, using its attribute state
+ as well as an established :func:`_orm.relationship()`
+ configuration.
+
+ The method uses the :func:`.with_parent` function to generate
+ the clause, the result of which is passed to
+ :meth:`_query.Query.filter`.
+
+ Parameters are the same as :func:`.with_parent`, with the exception
+ that the given property can be None, in which case a search is
+ performed against this :class:`_query.Query` object's target mapper.
+
+ :param instance:
+ An instance which has some :func:`_orm.relationship`.
+
+ :param property:
+ String property name, or class-bound attribute, which indicates
+ what relationship from the instance should be used to reconcile the
+ parent/child relationship.
+
+ :param from_entity:
+ Entity in which to consider as the left side. This defaults to the
+ "zero" entity of the :class:`_query.Query` itself.
+
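+ For example (an illustrative sketch, assuming mapped ``User`` and
+ ``Address`` classes related via ``User.addresses``)::
+
+ u1 = session.query(User).get(5)
+
+ addresses = session.query(Address).\
+ with_parent(u1, User.addresses).\
+ all()
+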
+ """
+ relationships = util.preloaded.orm_relationships
+
+ if from_entity:
+ entity_zero = inspect(from_entity)
+ else:
+ entity_zero = _legacy_filter_by_entity_zero(self)
+ if property is None:
+ # TODO: deprecate, property has to be supplied
+ mapper = object_mapper(instance)
+
+ for prop in mapper.iterate_properties:
+ if (
+ isinstance(prop, relationships.RelationshipProperty)
+ and prop.mapper is entity_zero.mapper
+ ):
+ property = prop # noqa
+ break
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Could not locate a property which relates instances "
+ "of class '%s' to instances of class '%s'"
+ % (
+ entity_zero.mapper.class_.__name__,
+ instance.__class__.__name__,
+ )
+ )
+
+ return self.filter(with_parent(instance, property, entity_zero.entity))
+
+ @_generative
+ def add_entity(self, entity, alias=None):
+ """add a mapped entity to the list of result columns
+ to be returned."""
+
+ if alias is not None:
+ # TODO: deprecate
+ entity = aliased(entity, alias)
+
+ self._raw_columns = list(self._raw_columns)
+
+ self._raw_columns.append(
+ coercions.expect(
+ roles.ColumnsClauseRole, entity, apply_propagate_attrs=self
+ )
+ )
+
+ @_generative
+ def with_session(self, session):
+ """Return a :class:`_query.Query` that will use the given
+ :class:`.Session`.
+
+ While the :class:`_query.Query`
+ object is normally instantiated using the
+ :meth:`.Session.query` method, it is legal to build the
+ :class:`_query.Query`
+ directly without necessarily using a :class:`.Session`. Such a
+ :class:`_query.Query` object, or any :class:`_query.Query`
+ already associated
+ with a different :class:`.Session`, can produce a new
+ :class:`_query.Query`
+ object associated with a target session using this method::
+
+ from sqlalchemy.orm import Query
+
+ query = Query([MyClass]).filter(MyClass.id == 5)
+
+ result = query.with_session(my_session).one()
+
+ """
+
+ self.session = session
+
+ @util.deprecated_20(
+ ":meth:`_query.Query.from_self`",
+ alternative="The new approach is to use the :func:`.orm.aliased` "
+ "construct in conjunction with a subquery. See the section "
+ ":ref:`Selecting from the query itself as a subquery "
+ "<migration_20_query_from_self>` in the 2.0 migration notes for an "
+ "example.",
+ )
+ def from_self(self, *entities):
+ r"""return a Query that selects from this Query's
+ SELECT statement.
+
+ :meth:`_query.Query.from_self` essentially turns the SELECT statement
+ into a SELECT of itself. Given a query such as::
+
+ q = session.query(User).filter(User.name.like('e%'))
+
+ Given the :meth:`_query.Query.from_self` version::
+
+ q = session.query(User).filter(User.name.like('e%')).from_self()
+
+ This query renders as:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1) AS anon_1
+
+ There are lots of cases where :meth:`_query.Query.from_self`
+ may be useful.
+ A simple one is the case above, where we may want to apply a row
+ LIMIT to the set of user objects we query against, and then apply
+ additional joins against that row-limited set::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self().\
+ join(User.addresses).filter(Address.email.like('q%'))
+
+ The above query joins to the ``Address`` entity but only against the
+ first five results of the ``User`` query:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1
+
+ **Automatic Aliasing**
+
+ Another key behavior of :meth:`_query.Query.from_self`
+ is that it applies
+ **automatic aliasing** to the entities inside the subquery, when
+ they are referenced on the outside. Above, if we continue to
+ refer to the ``User`` entity without any additional aliasing applied
+ to it, those references will be in terms of the subquery::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self().\
+ join(User.addresses).filter(Address.email.like('q%')).\
+ order_by(User.name)
+
+ The ORDER BY against ``User.name`` is aliased to be in terms of the
+ inner subquery:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1 ORDER BY anon_1.user_name
+
+ The automatic aliasing feature only works in a **limited** way,
+ for simple filters and orderings. More ambitious constructions
+ such as referring to the entity in joins should prefer to use
+ explicit subquery objects, typically making use of the
+ :meth:`_query.Query.subquery`
+ method to produce an explicit subquery object.
+ Always test the structure of queries by viewing the SQL to ensure
+ a particular structure does what's expected!
+
+ **Changing the Entities**
+
+ :meth:`_query.Query.from_self`
+ also includes the ability to modify what
+ columns are being queried. In our example, we want ``User.id``
+ to be queried by the inner query, so that we can join to the
+ ``Address`` entity on the outside, but we only wanted the outer
+ query to return the ``Address.email`` column::
+
+ q = session.query(User).filter(User.name.like('e%')).\
+ limit(5).from_self(Address.email).\
+ join(User.addresses).filter(Address.email.like('q%'))
+
+ yielding:
+
+ .. sourcecode:: sql
+
+ SELECT address.email AS address_email
+ FROM (SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE "user".name LIKE :name_1
+ LIMIT :param_1) AS anon_1
+ JOIN address ON anon_1.user_id = address.user_id
+ WHERE address.email LIKE :email_1
+
+ **Looking out for Inner / Outer Columns**
+
+ Keep in mind that when referring to columns that originate from
+ inside the subquery, we need to ensure they are present in the
+ columns clause of the subquery itself; this is an ordinary aspect of
+ SQL. For example, if we wanted to load from a joined entity inside
+ the subquery using :func:`.contains_eager`, we need to add those
+ columns. Below illustrates a join of ``Address`` to ``User``,
+ then a subquery, and then we'd like :func:`.contains_eager` to access
+ the ``User`` columns::
+
+ q = session.query(Address).join(Address.user).\
+ filter(User.name.like('e%'))
+
+ q = q.add_entity(User).from_self().\
+ options(contains_eager(Address.user))
+
+ We use :meth:`_query.Query.add_entity` above **before** we call
+ :meth:`_query.Query.from_self`
+ so that the ``User`` columns are present
+ in the inner subquery, so that they are available to the
+ :func:`.contains_eager` modifier we are using on the outside,
+ producing:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.address_id AS anon_1_address_id,
+ anon_1.address_email AS anon_1_address_email,
+ anon_1.address_user_id AS anon_1_address_user_id,
+ anon_1.user_id AS anon_1_user_id,
+ anon_1.user_name AS anon_1_user_name
+ FROM (
+ SELECT address.id AS address_id,
+ address.email AS address_email,
+ address.user_id AS address_user_id,
+ "user".id AS user_id,
+ "user".name AS user_name
+ FROM address JOIN "user" ON "user".id = address.user_id
+ WHERE "user".name LIKE :name_1) AS anon_1
+
+ If we didn't call ``add_entity(User)``, but still asked
+ :func:`.contains_eager` to load the ``User`` entity, it would be
+ forced to add the table on the outside without the correct
+ join criteria - note the ``anon_1, "user"`` phrase at
+ the end:
+
+ .. sourcecode:: sql
+
+ -- incorrect query
+ SELECT anon_1.address_id AS anon_1_address_id,
+ anon_1.address_email AS anon_1_address_email,
+ anon_1.address_user_id AS anon_1_address_user_id,
+ "user".id AS user_id,
+ "user".name AS user_name
+ FROM (
+ SELECT address.id AS address_id,
+ address.email AS address_email,
+ address.user_id AS address_user_id
+ FROM address JOIN "user" ON "user".id = address.user_id
+ WHERE "user".name LIKE :name_1) AS anon_1, "user"
+
+ :param \*entities: optional list of entities which will replace
+ those being selected.
+
+ """
+ return self._from_self(*entities)
+
+ def _from_self(self, *entities):
+ fromclause = (
+ self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .correlate(None)
+ .subquery()
+ ._anonymous_fromclause()
+ )
+
+ q = self._from_selectable(fromclause)
+
+ if entities:
+ q._set_entities(entities)
+ return q
+
+ @_generative
+ def _set_enable_single_crit(self, val):
+ self._compile_options += {"_enable_single_crit": val}
+
+ @_generative
+ def _from_selectable(self, fromclause, set_entity_from=True):
+ for attr in (
+ "_where_criteria",
+ "_order_by_clauses",
+ "_group_by_clauses",
+ "_limit_clause",
+ "_offset_clause",
+ "_last_joined_entity",
+ "_legacy_setup_joins",
+ "_memoized_select_entities",
+ "_distinct",
+ "_distinct_on",
+ "_having_criteria",
+ "_prefixes",
+ "_suffixes",
+ ):
+ self.__dict__.pop(attr, None)
+ self._set_select_from([fromclause], set_entity_from)
+ self._compile_options += {
+ "_enable_single_crit": False,
+ }
+
+ # this enables clause adaptation for non-ORM
+ # expressions.
+ # legacy. see test/orm/test_froms.py for various
+ # "oldstyle" tests that rely on this and the corresponding
+ # "newstyle" tests that do not.
+ self._compile_options += {"_orm_only_from_obj_alias": False}
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_query.Query.values` "
+ "is deprecated and will be removed in a "
+ "future release. Please use :meth:`_query.Query.with_entities`",
+ )
+ def values(self, *columns):
+ """Return an iterator yielding result tuples corresponding
+ to the given list of columns
+
+ """
+
+ if not columns:
+ return iter(())
+ q = self._clone().enable_eagerloads(False)
+ q._set_entities(columns)
+ if not q.load_options._yield_per:
+ q.load_options += {"_yield_per": 10}
+ return iter(q)
+
+ _values = values
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_query.Query.value` "
+ "is deprecated and will be removed in a "
+ "future release. Please use :meth:`_query.Query.with_entities` "
+ "in combination with :meth:`_query.Query.scalar`",
+ )
+ def value(self, column):
+ """Return a scalar result corresponding to the given
+ column expression.
+
+ """
+ try:
+ return next(self.values(column))[0]
+ except StopIteration:
+ return None
+
+ @_generative
+ def with_entities(self, *entities):
+ r"""Return a new :class:`_query.Query`
+ replacing the SELECT list with the
+ given entities.
+
+ e.g.::
+
+ # Users, filtered on some arbitrary criterion
+ # and then ordered by related email address
+ q = session.query(User).\
+ join(User.address).\
+ filter(User.name.like('%ed%')).\
+ order_by(Address.email)
+
+ # given *only* User.id==5, Address.email, and 'q', what
+ # would the *next* User in the result be ?
+ subq = q.with_entities(Address.email).\
+ order_by(None).\
+ filter(User.id==5).\
+ subquery()
+ q = q.join((subq, subq.c.email < Address.email)).\
+ limit(1)
+
+ """
+ _MemoizedSelectEntities._generate_for_statement(self)
+ self._set_entities(entities)
+
+ @_generative
+ def add_columns(self, *column):
+ """Add one or more column expressions to the list
+ of result columns to be returned."""
+
+ self._raw_columns = list(self._raw_columns)
+
+ self._raw_columns.extend(
+ coercions.expect(
+ roles.ColumnsClauseRole,
+ c,
+ apply_propagate_attrs=self,
+ post_inspect=True,
+ )
+ for c in column
+ )
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_query.Query.add_column` "
+ "is deprecated and will be removed in a "
+ "future release. Please use :meth:`_query.Query.add_columns`",
+ )
+ def add_column(self, column):
+ """Add a column expression to the list of result columns to be
+ returned.
+
+ """
+ return self.add_columns(column)
+
+ @_generative
+ def options(self, *args):
+ """Return a new :class:`_query.Query` object,
+ applying the given list of
+ mapper options.
+
+ Most supplied options regard changing how column- and
+ relationship-mapped attributes are loaded.
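+
+ For example (illustrative only, assuming a mapped ``User`` class
+ with a ``User.addresses`` relationship), a relationship loader
+ option may be applied as::
+
+ from sqlalchemy.orm import joinedload
+
+ q = session.query(User).options(joinedload(User.addresses))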
+
+ .. seealso::
+
+ :ref:`deferred_options`
+
+ :ref:`relationship_loader_options`
+
+ """
+
+ opts = tuple(util.flatten_iterator(args))
+ if self._compile_options._current_path:
+ for opt in opts:
+ if opt._is_legacy_option:
+ opt.process_query_conditionally(self)
+ else:
+ for opt in opts:
+ if opt._is_legacy_option:
+ opt.process_query(self)
+
+ self._with_options += opts
+
+ def with_transformation(self, fn):
+ """Return a new :class:`_query.Query` object transformed by
+ the given function.
+
+ E.g.::
+
+ def filter_something(criterion):
+ def transform(q):
+ return q.filter(criterion)
+ return transform
+
+ q = q.with_transformation(filter_something(x==5))
+
+ This allows ad-hoc recipes to be created for :class:`_query.Query`
+ objects. See the example at :ref:`hybrid_transformers`.
+
+ """
+ return fn(self)
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`_query.Query.execution_options`
+ """
+ return self._execution_options
+
+ @_generative
+ def execution_options(self, **kwargs):
+ """Set non-SQL options which take effect during execution.
+
+ Options allowed here include all of those accepted by
+ :meth:`_engine.Connection.execution_options`, as well as a series
+ of ORM specific options:
+
+ ``populate_existing=True`` - equivalent to using
+ :meth:`_orm.Query.populate_existing`
+
+ ``autoflush=True|False`` - equivalent to using
+ :meth:`_orm.Query.autoflush`
+
+ ``yield_per=<value>`` - equivalent to using
+ :meth:`_orm.Query.yield_per`
+
+ Note that the ``stream_results`` execution option is enabled
+ automatically if the :meth:`~sqlalchemy.orm.query.Query.yield_per()`
+ method or execution option is used.
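+
+ For example (illustrative only, assuming a mapped ``User`` class),
+ the ``yield_per`` option may be set for a single query as::
+
+ q = session.query(User).execution_options(yield_per=100)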
+
+ .. versionadded:: 1.4 - added ORM options to
+ :meth:`_orm.Query.execution_options`
+
+ The execution options may also be specified on a per execution basis
+ when using :term:`2.0 style` queries via the
+ :paramref:`_orm.Session.execution_options` parameter.
+
+ .. warning:: The
+ :paramref:`_engine.Connection.execution_options.schema_translate_map`
+ parameter should not be used at the level of individual ORM
+ statement executions, as the :class:`_orm.Session` will not track
+ objects from different schema translate maps within a single
+ session. For multiple schema translate maps within the scope of a
+ single :class:`_orm.Session`, see :ref:`examples_sharding`.
+
+
+ .. seealso::
+
+ :ref:`engine_stream_results`
+
+ :meth:`_query.Query.get_execution_options`
+
+ """
+ self._execution_options = self._execution_options.union(kwargs)
+
+ @_generative
+ def with_for_update(
+ self,
+ read=False,
+ nowait=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
+ """return a new :class:`_query.Query`
+ with the specified options for the
+ ``FOR UPDATE`` clause.
+
+ The behavior of this method is identical to that of
+ :meth:`_expression.GenerativeSelect.with_for_update`.
+ When called with no arguments,
+ the resulting ``SELECT`` statement will have a ``FOR UPDATE`` clause
+ appended. When additional arguments are specified, backend-specific
+ options such as ``FOR UPDATE NOWAIT`` or ``LOCK IN SHARE MODE``
+ can take effect.
+
+ E.g.::
+
+ q = sess.query(User).populate_existing().with_for_update(nowait=True, of=User)
+
+ The above query on a PostgreSQL backend will render like::
+
+ SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT
+
+ .. warning::
+
+ Using ``with_for_update`` in the context of eager loading
+ relationships is not officially supported or recommended by
+ SQLAlchemy and may not work with certain queries on various
+ database backends. When ``with_for_update`` is successfully used
+ with a query that involves :func:`_orm.joinedload`, SQLAlchemy will
+ attempt to emit SQL that locks all involved tables.
+
+ .. note:: It is generally a good idea to combine the use of the
+ :meth:`_orm.Query.populate_existing` method when using the
+ :meth:`_orm.Query.with_for_update` method. The purpose of
+ :meth:`_orm.Query.populate_existing` is to force all the data read
+ from the SELECT to be populated into the ORM objects returned,
+ even if these objects are already in the :term:`identity map`.
+
+ .. seealso::
+
+ :meth:`_expression.GenerativeSelect.with_for_update`
+ - Core level method with
+ full argument and behavioral description.
+
+ :meth:`_orm.Query.populate_existing` - overwrites attributes of
+ objects already loaded in the identity map.
+
+ """ # noqa: E501
+
+ self._for_update_arg = ForUpdateArg(
+ read=read,
+ nowait=nowait,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
+
+ @_generative
+ def params(self, *args, **kwargs):
+ r"""Add values for bind parameters which may have been
+ specified in filter().
+
+ Parameters may be specified using \**kwargs, or optionally a single
+ dictionary as the first positional argument. The reason for both is
+ that \**kwargs is convenient, however some parameter dictionaries
+ contain unicode keys in which case \**kwargs cannot be used.
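+
+ For example (illustrative only, assuming a mapped ``User`` class and
+ ``text`` imported from ``sqlalchemy``)::
+
+ q = session.query(User).\
+ filter(text("id<:value and name=:name")).\
+ params(value=224, name='fred')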
+
+ """
+ if len(args) == 1:
+ kwargs.update(args[0])
+ elif len(args) > 0:
+ raise sa_exc.ArgumentError(
+ "params() takes zero or one positional argument, "
+ "which is a dictionary."
+ )
+ self._params = self._params.union(kwargs)
+
+ def where(self, *criterion):
+ """A synonym for :meth:`.Query.filter`.
+
+ .. versionadded:: 1.4
+
+ """
+ return self.filter(*criterion)
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def filter(self, *criterion):
+ r"""Apply the given filtering criterion to a copy
+ of this :class:`_query.Query`, using SQL expressions.
+
+ e.g.::
+
+ session.query(MyClass).filter(MyClass.name == 'some name')
+
+ Multiple criteria may be specified as comma separated; the effect
+ is that they will be joined together using the :func:`.and_`
+ function::
+
+ session.query(MyClass).\
+ filter(MyClass.name == 'some name', MyClass.id > 5)
+
+ The criterion is any SQL expression object applicable to the
+ WHERE clause of a select. String expressions are coerced
+ into SQL expression constructs via the :func:`_expression.text`
+ construct.
+
+ .. seealso::
+
+ :meth:`_query.Query.filter_by` - filter on keyword expressions.
+
+ """
+ for criterion in list(criterion):
+ criterion = coercions.expect(
+ roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+ )
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if self._aliased_generation:
+ criterion = sql_util._deep_annotate(
+ criterion, {"aliased_generation": self._aliased_generation}
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ self._where_criteria += (criterion,)
+
+ @util.memoized_property
+ def _last_joined_entity(self):
+ if self._legacy_setup_joins:
+ return _legacy_determine_last_joined_entity(
+ self._legacy_setup_joins, self._entity_from_pre_ent_zero()
+ )
+ else:
+ return None
+
+ def _filter_by_zero(self):
+ """for the filter_by() method, return the target entity for which
+ we will attempt to derive an expression based on string name.
+
+ """
+
+ if self._legacy_setup_joins:
+ _last_joined_entity = self._last_joined_entity
+ if _last_joined_entity is not None:
+ return _last_joined_entity
+
+ # discussion related to #7239
+ # special check determines if we should try to derive attributes
+ # for filter_by() from the "from object", i.e., if the user
+ # called query.select_from(some selectable).filter_by(some_attr=value).
+ # We don't want to do that in the case that methods like
+ # from_self(), select_entity_from(), or a set op like union() were
+ # called; while these methods also place a
+ # selectable in the _from_obj collection, they also set up
+ # the _set_base_alias boolean which turns on the whole "adapt the
+ # entity to this selectable" thing, meaning the query still continues
+ # to construct itself in terms of the lead entity that was passed
+ # to query(), e.g. query(User).from_self() is still in terms of User,
+ # and not the subquery that from_self() created. This feature of
+ # "implicitly adapt all occurrences of entity X to some arbitrary
+ # subquery" is the main thing I am trying to do away with in 2.0 as
+ # users should now use aliased() for that, but I can't entirely get
+ # rid of it due to query.union() and other set ops relying upon it.
+ #
+ # compare this to the base Select()._filter_by_zero() which can
+ # just return self._from_obj[0] if present, because there is no
+ # "_set_base_alias" feature.
+ #
+ # IOW, this conditional essentially detects if
+ # "select_from(some_selectable)" has been called, as opposed to
+ # "select_entity_from()", "from_self()"
+ # or "union() / some_set_op()".
+ if self._from_obj and not self._compile_options._set_base_alias:
+ return self._from_obj[0]
+
+ return self._raw_columns[0]
+
+ def filter_by(self, **kwargs):
+ r"""Apply the given filtering criterion to a copy
+ of this :class:`_query.Query`, using keyword expressions.
+
+ e.g.::
+
+ session.query(MyClass).filter_by(name='some name')
+
+ Multiple criteria may be specified as comma separated; the effect
+ is that they will be joined together using the :func:`.and_`
+ function::
+
+ session.query(MyClass).\
+ filter_by(name='some name', id=5)
+
+ The keyword expressions are extracted from the primary
+ entity of the query, or the last entity that was the
+ target of a call to :meth:`_query.Query.join`.
+
+ .. seealso::
+
+ :meth:`_query.Query.filter` - filter on SQL expressions.
+
+ """
+ from_entity = self._filter_by_zero()
+ if from_entity is None:
+ raise sa_exc.InvalidRequestError(
+ "Can't use filter_by when the first entity '%s' of a query "
+ "is not a mapped class. Please use the filter method instead, "
+ "or change the order of the entities in the query"
+ % self._query_entity_zero()
+ )
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def order_by(self, *clauses):
+ """Apply one or more ORDER BY criteria to the query and return
+ the newly resulting :class:`_query.Query`.
+
+ e.g.::
+
+ q = session.query(Entity).order_by(Entity.id, Entity.name)
+
+ All existing ORDER BY criteria may be cancelled by passing
+ ``None`` by itself. New ORDER BY criteria may then be added by
+ invoking :meth:`_orm.Query.order_by` again, e.g.::
+
+ # will erase all ORDER BY and ORDER BY new_col alone
+ q = q.order_by(None).order_by(new_col)
+
+ .. seealso::
+
+ These sections describe ORDER BY in terms of :term:`2.0 style`
+ invocation but apply to :class:`_orm.Query` as well:
+
+ :ref:`tutorial_order_by` - in the :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ self._order_by_clauses = ()
+ else:
+ criterion = tuple(
+ coercions.expect(roles.OrderByRole, clause)
+ for clause in clauses
+ )
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if self._aliased_generation:
+ criterion = tuple(
+ [
+ sql_util._deep_annotate(
+ o, {"aliased_generation": self._aliased_generation}
+ )
+ for o in criterion
+ ]
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ self._order_by_clauses += criterion
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def group_by(self, *clauses):
+ """Apply one or more GROUP BY criterion to the query and return
+ the newly resulting :class:`_query.Query`.
+
+ All existing GROUP BY settings can be suppressed by
+ passing ``None`` - this will suppress any GROUP BY configured
+ on mappers as well.
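+
+ For example (an illustrative sketch, assuming mapped ``User`` and
+ ``Address`` classes and ``func`` imported from ``sqlalchemy``)::
+
+ q = session.query(User.name, func.count(Address.id)).\
+ join(Address, User.id == Address.user_id).\
+ group_by(User.name)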
+
+ .. seealso::
+
+ These sections describe GROUP BY in terms of :term:`2.0 style`
+ invocation but apply to :class:`_orm.Query` as well:
+
+ :ref:`tutorial_group_by_w_aggregates` - in the
+ :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and (clauses[0] is None or clauses[0] is False):
+ self._group_by_clauses = ()
+ else:
+ criterion = tuple(
+ coercions.expect(roles.GroupByRole, clause)
+ for clause in clauses
+ )
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if self._aliased_generation:
+ criterion = tuple(
+ [
+ sql_util._deep_annotate(
+ o, {"aliased_generation": self._aliased_generation}
+ )
+ for o in criterion
+ ]
+ )
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ self._group_by_clauses += criterion
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def having(self, criterion):
+ r"""Apply a HAVING criterion to the query and return the
+ newly resulting :class:`_query.Query`.
+
+ :meth:`_query.Query.having` is used in conjunction with
+ :meth:`_query.Query.group_by`.
+
+ HAVING criterion makes it possible to use filters on aggregate
+ functions like COUNT, SUM, AVG, MAX, and MIN, e.g.::
+
+ q = session.query(User.id).\
+ join(User.addresses).\
+ group_by(User.id).\
+ having(func.count(Address.id) > 2)
+
+ """
+
+ self._having_criteria += (
+ coercions.expect(
+ roles.WhereHavingRole, criterion, apply_propagate_attrs=self
+ ),
+ )
+
+ def _set_op(self, expr_fn, *q):
+ return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
+
+ def union(self, *q):
+ """Produce a UNION of this Query against one or more queries.
+
+ e.g.::
+
+ q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar')
+ q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo')
+
+ q3 = q1.union(q2)
+
+ The method accepts multiple Query objects so as to control
+ the level of nesting. A series of ``union()`` calls such as::
+
+ x.union(y).union(z).all()
+
+ will nest on each ``union()``, and produces::
+
+ SELECT * FROM (SELECT * FROM (SELECT * FROM x UNION
+ SELECT * FROM y) UNION SELECT * FROM z)
+
+ Whereas::
+
+ x.union(y, z).all()
+
+ produces::
+
+ SELECT * FROM (SELECT * FROM x UNION SELECT * FROM y UNION
+ SELECT * FROM z)
+
+ Note that many database backends do not allow ORDER BY to
+ be rendered on a query called within UNION, EXCEPT, etc.
+ To disable all ORDER BY clauses including those configured
+ on mappers, issue ``query.order_by(None)`` - the resulting
+ :class:`_query.Query` object will not render ORDER BY within
+ its SELECT statement.
+
+ """
+ return self._set_op(expression.union, *q)
+
+ def union_all(self, *q):
+ """Produce a UNION ALL of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.union_all, *q)
+
+ def intersect(self, *q):
+ """Produce an INTERSECT of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.intersect, *q)
+
+ def intersect_all(self, *q):
+ """Produce an INTERSECT ALL of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.intersect_all, *q)
+
+ def except_(self, *q):
+ """Produce an EXCEPT of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.except_, *q)
+
+ def except_all(self, *q):
+ """Produce an EXCEPT ALL of this Query against one or more queries.
+
+ Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See
+ that method for usage examples.
+
+ """
+ return self._set_op(expression.except_all, *q)
+
+ def _next_aliased_generation(self):
+ if "_aliased_generation_counter" not in self.__dict__:
+ self._aliased_generation_counter = 0
+ self._aliased_generation_counter += 1
+ return self._aliased_generation_counter
+
+ @_generative
+ @_assertions(_no_statement_condition, _no_limit_offset)
+ def join(self, target, *props, **kwargs):
+ r"""Create a SQL JOIN against this :class:`_query.Query`
+ object's criterion
+ and apply generatively, returning the newly resulting
+ :class:`_query.Query`.
+
+ **Simple Relationship Joins**
+
+ Consider a mapping between two classes ``User`` and ``Address``,
+ with a relationship ``User.addresses`` representing a collection
+ of ``Address`` objects associated with each ``User``. The most
+ common usage of :meth:`_query.Query.join`
+ is to create a JOIN along this
+ relationship, using the ``User.addresses`` attribute as an indicator
+ for how this should occur::
+
+ q = session.query(User).join(User.addresses)
+
+ Where above, the call to :meth:`_query.Query.join` along
+ ``User.addresses`` will result in SQL approximately equivalent to::
+
+ SELECT user.id, user.name
+ FROM user JOIN address ON user.id = address.user_id
+
+ In the above example we refer to ``User.addresses`` as passed to
+ :meth:`_query.Query.join` as the "on clause", that is, it indicates
+ how the "ON" portion of the JOIN should be constructed.
+
+ To construct a chain of joins, multiple :meth:`_query.Query.join`
+ calls may be used. The relationship-bound attribute implies both
+ the left and right side of the join at once::
+
+ q = session.query(User).\
+ join(User.orders).\
+ join(Order.items).\
+ join(Item.keywords)
+
+ .. note:: as seen in the above example, **the order in which each
+ call to the join() method occurs is important**. Query would not,
+ for example, know how to join correctly if we were to specify
+ ``User``, then ``Item``, then ``Order``, in our chain of joins; in
+ such a case, depending on the arguments passed, it may raise an
+ error that it doesn't know how to join, or it may produce invalid
+ SQL in which case the database will raise an error. In correct
+ practice, the
+ :meth:`_query.Query.join` method is invoked in such a way that lines
+ up with how we would want the JOIN clauses in SQL to be
+ rendered, and each call should represent a clear link from what
+ precedes it.
+
+ **Joins to a Target Entity or Selectable**
+
+ A second form of :meth:`_query.Query.join` allows any mapped entity or
+ core selectable construct as a target. In this usage,
+ :meth:`_query.Query.join` will attempt to create a JOIN along the
+ natural foreign key relationship between two entities::
+
+ q = session.query(User).join(Address)
+
+ In the above calling form, :meth:`_query.Query.join` is called upon to
+ create the "on clause" automatically for us. This calling form will
+ ultimately raise an error if either there are no foreign keys between
+ the two entities, or if there are multiple foreign key linkages between
+ the target entity and the entity or entities already present on the
+ left side such that creating a join requires more information. Note
+ that when indicating a join to a target without any ON clause, ORM
+ configured relationships are not taken into account.
+
+ **Joins to a Target with an ON Clause**
+
+ The third calling form allows both the target entity as well
+ as the ON clause to be passed explicitly. An example that includes
+ a SQL expression as the ON clause is as follows::
+
+ q = session.query(User).join(Address, User.id==Address.user_id)
+
+ The above form may also use a relationship-bound attribute as the
+ ON clause as well::
+
+ q = session.query(User).join(Address, User.addresses)
+
+ The above syntax can be useful for the case where we wish
+ to join to an alias of a particular target entity. If we wanted
+ to join to ``Address`` twice, it could be achieved using two
+ aliases set up using the :func:`~sqlalchemy.orm.aliased` function::
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+
+ q = session.query(User).\
+ join(a1, User.addresses).\
+ join(a2, User.addresses).\
+ filter(a1.email_address=='ed@foo.com').\
+ filter(a2.email_address=='ed@bar.com')
+
+ The relationship-bound calling form can also specify a target entity
+ using the :meth:`_orm.PropComparator.of_type` method; a query
+ equivalent to the one above would be::
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+
+ q = session.query(User).\
+ join(User.addresses.of_type(a1)).\
+ join(User.addresses.of_type(a2)).\
+ filter(a1.email_address == 'ed@foo.com').\
+ filter(a2.email_address == 'ed@bar.com')
+
+ **Augmenting Built-in ON Clauses**
+
+ As a substitute for providing a full custom ON condition for an
+ existing relationship, the :meth:`_orm.PropComparator.and_` function
+ may be applied to a relationship attribute to augment additional
+ criteria into the ON clause; the additional criteria will be combined
+ with the default criteria using AND::
+
+ q = session.query(User).join(
+ User.addresses.and_(Address.email_address != 'foo@bar.com')
+ )
+
+ .. versionadded:: 1.4
+
+ **Joining to Tables and Subqueries**
+
+
+ The target of a join may also be any table or SELECT statement,
+ which may be related to a target entity or not. Use the
+ appropriate ``.subquery()`` method in order to make a subquery
+ out of a query::
+
+ subq = session.query(Address).\
+ filter(Address.email_address == 'ed@foo.com').\
+ subquery()
+
+
+ q = session.query(User).join(
+ subq, User.id == subq.c.user_id
+ )
+
+ Joining to a subquery in terms of a specific relationship and/or
+ target entity may be achieved by linking the subquery to the
+ entity using :func:`_orm.aliased`::
+
+ subq = session.query(Address).\
+ filter(Address.email_address == 'ed@foo.com').\
+ subquery()
+
+ address_subq = aliased(Address, subq)
+
+ q = session.query(User).join(
+ User.addresses.of_type(address_subq)
+ )
+
+
+ **Controlling what to Join From**
+
+ In cases where the left side of the current state of
+ :class:`_query.Query` is not in line with what we want to join from,
+ the :meth:`_query.Query.select_from` method may be used::
+
+ q = session.query(Address).select_from(User).\
+ join(User.addresses).\
+ filter(User.name == 'ed')
+
+ Which will produce SQL similar to::
+
+ SELECT address.* FROM user
+ JOIN address ON user.id=address.user_id
+ WHERE user.name = :name_1
+
+ **Legacy Features of Query.join()**
+
+ .. deprecated:: 1.4 The following features are deprecated and will
+ be removed in SQLAlchemy 2.0.
+
+ The :meth:`_query.Query.join` method currently supports several
+ usage patterns and arguments that are considered to be legacy
+ as of SQLAlchemy 1.3. A deprecation path will follow
+ in the 1.4 series for the following features:
+
+
+ * Joining on relationship names rather than attributes::
+
+ session.query(User).join("addresses")
+
+ **Why it's legacy**: the string name does not provide enough context
+ for :meth:`_query.Query.join` to always know what is desired,
+ notably in that there is no indication of what the left side
+ of the join should be. This gives rise to flags like
+ ``from_joinpoint`` as well as the ability to place several
+ join clauses in a single :meth:`_query.Query.join` call
+ which don't solve the problem fully while also
+ adding new calling styles that are unnecessary and expensive to
+ accommodate internally.
+
+ **Modern calling pattern**: Use the actual relationship,
+ e.g. ``User.addresses`` in the above case::
+
+ session.query(User).join(User.addresses)
+
+ * Automatic aliasing with the ``aliased=True`` flag::
+
+ session.query(Node).join(Node.children, aliased=True).\
+ filter(Node.name == 'some name')
+
+ **Why it's legacy**: the automatic aliasing feature of
+ :class:`_query.Query` is intensely complicated, both in its internal
+ implementation as well as in its observed behavior, and is almost
+ never used. It is difficult to know upon inspection where and when
+ its aliasing of a target entity, ``Node`` in the above case, will be
+ applied and when it won't, and additionally the feature has to use
+ very elaborate heuristics to achieve this implicit behavior.
+
+ **Modern calling pattern**: Use the :func:`_orm.aliased` construct
+ explicitly::
+
+ from sqlalchemy.orm import aliased
+
+ n1 = aliased(Node)
+
+ session.query(Node).join(Node.children.of_type(n1)).\
+ filter(n1.name == 'some name')
+
+ * Multiple joins in one call::
+
+ session.query(User).join("orders", "items")
+
+ session.query(User).join(User.orders, Order.items)
+
+ session.query(User).join(
+ (Order, User.orders),
+ (Item, Item.order_id == Order.id)
+ )
+
+ session.query(User).join(Order, Item)
+
+ # ... and several more forms actually
+
+ **Why it's legacy**: being able to chain multiple ON clauses in one
+ call to :meth:`_query.Query.join` is yet another attempt to solve
+ the problem of being able to specify what entity to join from,
+ and is the source of a large variety of potential calling patterns
+ that are internally expensive and complicated to parse and
+ accommodate.
+
+ **Modern calling pattern**: Use relationship-bound attributes
+ or SQL-oriented ON clauses within separate calls, so that
+ each call to :meth:`_query.Query.join` knows what the left
+ side should be::
+
+ session.query(User).join(User.orders).join(
+ Item, Item.order_id == Order.id)
+
+
+ :param \*props: Incoming arguments for :meth:`_query.Query.join`,
+ the props collection in modern use should be considered to be a one
+ or two argument form, either as a single "target" entity or ORM
+ attribute-bound relationship, or as a target entity plus an "on
+ clause" which may be a SQL expression or ORM attribute-bound
+ relationship.
+
+ :param isouter=False: If True, the join used will be a left outer join,
+ just as if the :meth:`_query.Query.outerjoin` method were called.
+
+ :param full=False: render FULL OUTER JOIN; implies ``isouter``.
+
+ .. versionadded:: 1.1
+
+ :param from_joinpoint=False: When using ``aliased=True``, a setting
+ of True here will cause the join to be from the most recent
+ joined target, rather than starting back from the original
+ FROM clauses of the query.
+
+ .. note:: This flag is considered legacy.
+
+ :param aliased=False: If True, indicate that the JOIN target should be
+ anonymously aliased. Subsequent calls to :meth:`_query.Query.filter`
+ and similar will adapt the incoming criterion to the target
+ alias, until :meth:`_query.Query.reset_joinpoint` is called.
+
+ .. note:: This flag is considered legacy.
+
+ .. seealso::
+
+ :ref:`ormtutorial_joins` in the ORM tutorial.
+
+ :ref:`inheritance_toplevel` for details on how
+ :meth:`_query.Query.join` is used for inheritance relationships.
+
+ :func:`_orm.join` - a standalone ORM-level join function,
+ used internally by :meth:`_query.Query.join`, which in previous
+ SQLAlchemy versions was the primary ORM-level joining interface.
+
+ """
+
+ aliased, from_joinpoint, isouter, full = (
+ kwargs.pop("aliased", False),
+ kwargs.pop("from_joinpoint", False),
+ kwargs.pop("isouter", False),
+ kwargs.pop("full", False),
+ )
+
+ if aliased or from_joinpoint:
+ util.warn_deprecated_20(
+ "The ``aliased`` and ``from_joinpoint`` keyword arguments "
+ "to Query.join() are deprecated and will be removed "
+ "in SQLAlchemy 2.0."
+ )
+
+ if kwargs:
+ raise TypeError(
+ "unknown arguments: %s" % ", ".join(sorted(kwargs))
+ )
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if not from_joinpoint:
+ self._last_joined_entity = None
+ self._aliased_generation = None
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ if props:
+ onclause, legacy = props[0], props[1:]
+ else:
+ onclause = legacy = None
+
+ if not legacy and onclause is None and not isinstance(target, tuple):
+ # non legacy argument form
+ _props = [(target,)]
+ elif (
+ not legacy
+ and isinstance(
+ target,
+ (
+ expression.Selectable,
+ type,
+ AliasedClass,
+ types.FunctionType,
+ ),
+ )
+ and isinstance(
+ onclause,
+ (
+ elements.ColumnElement,
+ str,
+ interfaces.PropComparator,
+ types.FunctionType,
+ ),
+ )
+ ):
+ # non legacy argument form
+ _props = [(target, onclause)]
+ else:
+ # legacy forms. more time consuming :)
+ _props = []
+ _single = []
+ for prop in (target,) + props:
+ if isinstance(prop, tuple):
+ util.warn_deprecated_20(
+ "Query.join() will no longer accept tuples as "
+ "arguments in SQLAlchemy 2.0."
+ )
+ if _single:
+ _props.extend((_s,) for _s in _single)
+ _single = []
+
+ # this checks for an extremely ancient calling form of
+ # reversed tuples.
+ if isinstance(prop[0], (str, interfaces.PropComparator)):
+ prop = (prop[1], prop[0])
+
+ _props.append(prop)
+ else:
+ _single.append(prop)
+ if _single:
+ _props.extend((_s,) for _s in _single)
+
+ # legacy vvvvvvvvvvvvvvvvvvvvvvvvvvv
+ if aliased:
+ self._aliased_generation = self._next_aliased_generation()
+
+ if self._aliased_generation:
+ _props = [
+ (
+ prop[0],
+ sql_util._deep_annotate(
+ prop[1],
+ {"aliased_generation": self._aliased_generation},
+ )
+ if isinstance(prop[1], expression.ClauseElement)
+ else prop[1],
+ )
+ if len(prop) == 2
+ else prop
+ for prop in _props
+ ]
+
+ # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ joins_to_add = tuple(
+ (
+ coercions.expect(
+ roles.JoinTargetRole,
+ prop[0],
+ legacy=True,
+ apply_propagate_attrs=self,
+ ),
+ (
+ coercions.expect(roles.OnClauseRole, prop[1], legacy=True)
+ # if not isinstance(prop[1], str)
+ # else prop[1]
+ )
+ if len(prop) == 2
+ else None,
+ None,
+ {
+ "isouter": isouter,
+ "aliased": aliased,
+ "from_joinpoint": True if i > 0 else from_joinpoint,
+ "full": full,
+ "aliased_generation": self._aliased_generation,
+ },
+ )
+ for i, prop in enumerate(_props)
+ )
+
+ if len(joins_to_add) > 1:
+ util.warn_deprecated_20(
+ "Passing a chain of multiple join conditions to Query.join() "
+ "is deprecated and will be removed in SQLAlchemy 2.0. "
+ "Please use individual join() calls per relationship."
+ )
+
+ self._legacy_setup_joins += joins_to_add
+
+ self.__dict__.pop("_last_joined_entity", None)
+
+ def outerjoin(self, target, *props, **kwargs):
+ """Create a left outer join against this ``Query`` object's criterion
+ and apply generatively, returning the newly resulting ``Query``.
+
+ Usage is the same as the ``join()`` method.
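+
+ For example (illustrative only, assuming a mapped ``User`` class
+ with a ``User.addresses`` relationship)::
+
+ q = session.query(User).outerjoin(User.addresses)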
+
+ """
+ kwargs["isouter"] = True
+ return self.join(target, *props, **kwargs)
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def reset_joinpoint(self):
+ """Return a new :class:`.Query`, where the "join point" has
+ been reset back to the base FROM entities of the query.
+
+ This method is usually used in conjunction with the
+ ``aliased=True`` feature of the :meth:`~.Query.join`
+ method. See the example in :meth:`~.Query.join` for how
+ this is used.
+
+ """
+ self._last_joined_entity = None
+ self._aliased_generation = None
+
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ def select_from(self, *from_obj):
+ r"""Set the FROM clause of this :class:`.Query` explicitly.
+
+ :meth:`.Query.select_from` is often used in conjunction with
+ :meth:`.Query.join` in order to control which entity is selected
+ from on the "left" side of the join.
+
+ The entity or selectable object here effectively replaces the
+ "left edge" of any calls to :meth:`~.Query.join`, when no
+ joinpoint is otherwise established - usually, the default "join
+ point" is the leftmost entity in the :class:`~.Query` object's
+ list of entities to be selected.
+
+ A typical example::
+
+ q = session.query(Address).select_from(User).\
+ join(User.addresses).\
+ filter(User.name == 'ed')
+
+ Which produces SQL equivalent to::
+
+ SELECT address.* FROM user
+ JOIN address ON user.id=address.user_id
+ WHERE user.name = :name_1
+
+ :param \*from_obj: collection of one or more entities to apply
+ to the FROM clause. Entities can be mapped classes,
+ :class:`.AliasedClass` objects, :class:`.Mapper` objects
+ as well as core :class:`.FromClause` elements like subqueries.
+
+ .. versionchanged:: 0.9
+ This method no longer applies the given FROM object
+ to be the selectable from which matching entities
+ select from; the :meth:`.select_entity_from` method
+ now accomplishes this. See that method for a description
+ of this behavior.
+
+ .. seealso::
+
+ :meth:`~.Query.join`
+
+ :meth:`.Query.select_entity_from`
+
+ """
+
+ self._set_select_from(from_obj, False)
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.select_entity_from`",
+ alternative="Use the :func:`_orm.aliased` construct instead",
+ )
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ def select_entity_from(self, from_obj):
+ r"""Set the FROM clause of this :class:`_query.Query` to a
+ core selectable, applying it as a replacement FROM clause
+ for corresponding mapped entities.
+
+ The :meth:`_query.Query.select_entity_from`
+ method supplies an alternative
+ approach to the use case of applying an :func:`.aliased` construct
+ explicitly throughout a query. Instead of referring to the
+ :func:`.aliased` construct explicitly,
+ :meth:`_query.Query.select_entity_from` automatically *adapts* all
+ occurrences of the entity to the target selectable.
+
+ Given a case for :func:`.aliased` such as selecting ``User``
+ objects from a SELECT statement::
+
+ select_stmt = select(User).where(User.id == 7)
+ user_alias = aliased(User, select_stmt)
+
+ q = session.query(user_alias).\
+ filter(user_alias.name == 'ed')
+
+ Above, we apply the ``user_alias`` object explicitly throughout the
+ query. When it's not feasible for ``user_alias`` to be referenced
+ explicitly in many places, :meth:`_query.Query.select_entity_from`
+ may be
+ used at the start of the query to adapt the existing ``User`` entity::
+
+ q = session.query(User).\
+ select_entity_from(select_stmt.subquery()).\
+ filter(User.name == 'ed')
+
+ Above, the generated SQL will show that the ``User`` entity is
+ adapted to our statement, even in the case of the WHERE clause:
+
+ .. sourcecode:: sql
+
+ SELECT anon_1.id AS anon_1_id, anon_1.name AS anon_1_name
+ FROM (SELECT "user".id AS id, "user".name AS name
+ FROM "user"
+ WHERE "user".id = :id_1) AS anon_1
+ WHERE anon_1.name = :name_1
+
+ The :meth:`_query.Query.select_entity_from` method is similar to the
+ :meth:`_query.Query.select_from` method,
+ in that it sets the FROM clause
+ of the query. The difference is that it additionally applies
+ adaptation to the other parts of the query that refer to the
+ primary entity. If above we had used :meth:`_query.Query.select_from`
+ instead, the SQL generated would have been:
+
+ .. sourcecode:: sql
+
+ -- uses plain select_from(), not select_entity_from()
+ SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user", (SELECT "user".id AS id, "user".name AS name
+ FROM "user"
+ WHERE "user".id = :id_1) AS anon_1
+ WHERE "user".name = :name_1
+
+ To supply textual SQL to the :meth:`_query.Query.select_entity_from`
+ method,
+ we can make use of the :func:`_expression.text` construct. However,
+ the
+ :func:`_expression.text`
+ construct needs to be aligned with the columns of our
+ entity, which is achieved by making use of the
+ :meth:`_expression.TextClause.columns` method::
+
+ text_stmt = text("select id, name from user").columns(
+ User.id, User.name).subquery()
+ q = session.query(User).select_entity_from(text_stmt)
+
+ :meth:`_query.Query.select_entity_from` itself accepts an
+ :func:`.aliased`
+ object, so that the special options of :func:`.aliased` such as
+ :paramref:`.aliased.adapt_on_names` may be used within the
+ scope of the :meth:`_query.Query.select_entity_from`
+ method's adaptation
+ services. Suppose
+ a view ``user_view`` also returns rows from ``user``. If
+ we reflect this view into a :class:`_schema.Table`, this view has no
+ relationship to the :class:`_schema.Table` to which we are mapped,
+ however
+ we can use name matching to select from it::
+
+ user_view = Table('user_view', metadata,
+ autoload_with=engine)
+ user_view_alias = aliased(
+ User, user_view, adapt_on_names=True)
+ q = session.query(User).\
+ select_entity_from(user_view_alias).\
+ order_by(User.name)
+
+ .. versionchanged:: 1.1.7 The :meth:`_query.Query.select_entity_from`
+ method now accepts an :func:`.aliased` object as an alternative
+ to a :class:`_expression.FromClause` object.
+
+ :param from_obj: a :class:`_expression.FromClause`
+ object that will replace
+ the FROM clause of this :class:`_query.Query`.
+ It also may be an instance
+ of :func:`.aliased`.
+
+
+
+ .. seealso::
+
+ :meth:`_query.Query.select_from`
+
+ """
+
+ self._set_select_from([from_obj], True)
+ self._compile_options += {"_enable_single_crit": False}
+
+ def __getitem__(self, item):
+ return orm_util._getitem(
+ self,
+ item,
+ allow_negative=not self.session or not self.session.future,
+ )
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def slice(self, start, stop):
+ """Computes the "slice" of the :class:`_query.Query` represented by
+ the given indices and returns the resulting :class:`_query.Query`.
+
+ The start and stop indices behave like the arguments to Python's
+ built-in :func:`range` function. This method provides an
+ alternative to using ``LIMIT``/``OFFSET`` to get a slice of the
+ query.
+
+ For example, ::
+
+ session.query(User).order_by(User.id).slice(1, 3)
+
+ renders as
+
+ .. sourcecode:: sql
+
+ SELECT users.id AS users_id,
+ users.name AS users_name
+ FROM users ORDER BY users.id
+ LIMIT ? OFFSET ?
+ (2, 1)
+
+ .. seealso::
+
+ :meth:`_query.Query.limit`
+
+ :meth:`_query.Query.offset`
+
+ """
+
+ self._limit_clause, self._offset_clause = sql_util._make_slice(
+ self._limit_clause, self._offset_clause, start, stop
+ )
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def limit(self, limit):
+ """Apply a ``LIMIT`` to the query and return the newly resulting
+ ``Query``.
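+
+ For example (illustrative only, assuming a mapped ``User`` class)::
+
+ q = session.query(User).order_by(User.id).limit(10)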
+
+ """
+ self._limit_clause = sql_util._offset_or_limit_clause(limit)
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def offset(self, offset):
+ """Apply an ``OFFSET`` to the query and return the newly resulting
+ ``Query``.
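+
+ For example, often combined with :meth:`_query.Query.limit`
+ (illustrative only, assuming a mapped ``User`` class)::
+
+ q = session.query(User).order_by(User.id).limit(10).offset(20)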
+
+ """
+ self._offset_clause = sql_util._offset_or_limit_clause(offset)
+
+ @_generative
+ @_assertions(_no_statement_condition)
+ def distinct(self, *expr):
+ r"""Apply a ``DISTINCT`` to the query and return the newly resulting
+ ``Query``.
+
+
+ .. note::
+
+ The ORM-level :meth:`.distinct` call includes logic that will
+ automatically add columns from the ORDER BY of the query to the
+ columns clause of the SELECT statement, to satisfy the common need
+ of the database backend that ORDER BY columns be part of the SELECT
+ list when DISTINCT is used. These columns *are not* added to the
+ list of columns actually fetched by the :class:`_query.Query`,
+ however,
+ so would not affect results. The columns are passed through when
+ using the :attr:`_query.Query.statement` accessor, however.
+
+ .. deprecated:: 2.0 This logic is deprecated and will be removed
+ in SQLAlchemy 2.0. See :ref:`migration_20_query_distinct`
+ for a description of this use case in 2.0.
+
+ :param \*expr: optional column expressions. When present,
+ the PostgreSQL dialect will render a ``DISTINCT ON (<expressions>)``
+ construct.
+
+ .. deprecated:: 1.4 Using \*expr in other dialects is deprecated
+ and will raise :class:`_exc.CompileError` in a future version.
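+
+ For example (illustrative only, assuming a mapped ``User`` class),
+ a plain DISTINCT applied to a column-based query::
+
+ q = session.query(User.name).distinct()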
+
+ """
+ if expr:
+ self._distinct = True
+ self._distinct_on = self._distinct_on + tuple(
+ coercions.expect(roles.ByOfRole, e) for e in expr
+ )
+ else:
+ self._distinct = True
+
+ def all(self):
+ """Return the results represented by this :class:`_query.Query`
+ as a list.
+
+ This results in an execution of the underlying SQL statement.
+
+ .. warning:: The :class:`_query.Query` object,
+ when asked to return either
+ a sequence or iterator that consists of full ORM-mapped entities,
+ will **deduplicate entries based on primary key**. See the FAQ for
+ more details.
+
+ .. seealso::
+
+ :ref:`faq_query_deduplicating`
+ """
+ return self._iter().all()
+
+ @_generative
+ @_assertions(_no_clauseelement_condition)
+ def from_statement(self, statement):
+ """Execute the given SELECT statement and return results.
+
+ This method bypasses all internal statement compilation, and the
+ statement is executed without modification.
+
+ The statement is typically either a :func:`_expression.text`
+ or :func:`_expression.select` construct, and should return the set
+ of columns
+ appropriate to the entity class represented by this
+ :class:`_query.Query`.
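+
+ For example (illustrative only, assuming a mapped ``User`` class and
+ ``text`` imported from ``sqlalchemy``)::
+
+ stmt = text("SELECT * FROM users WHERE users.name = :name")
+
+ q = session.query(User).from_statement(stmt).params(name='ed')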
+
+ .. seealso::
+
+ :ref:`orm_tutorial_literal_sql` - usage examples in the
+ ORM tutorial
+
+ """
+ statement = coercions.expect(
+ roles.SelectStatementRole, statement, apply_propagate_attrs=self
+ )
+ self._statement = statement
+
+ def first(self):
+ """Return the first result of this ``Query`` or
+ None if the result doesn't contain any row.
+
+ first() applies a limit of one within the generated SQL, so that
+ only one primary entity row is generated on the server side
+ (note this may consist of multiple result rows if join-loaded
+ collections are present).
+
+ Calling :meth:`_query.Query.first`
+ results in an execution of the underlying
+ query.
+
+ .. seealso::
+
+ :meth:`_query.Query.one`
+
+ :meth:`_query.Query.one_or_none`
+
+ """
+ # replicates limit(1) behavior
+ if self._statement is not None:
+ return self._iter().first()
+ else:
+ return self.limit(1)._iter().first()
+
+ def one_or_none(self):
+ """Return at most one result or raise an exception.
+
+ Returns ``None`` if the query selects
+ no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound``
+ if multiple object identities are returned, or if multiple
+ rows are returned for a query that returns only scalar values
+ as opposed to full identity-mapped entities.
+
+ Calling :meth:`_query.Query.one_or_none`
+ results in an execution of the
+ underlying query.
+
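+ E.g., a brief sketch, assuming a mapped ``User`` class::
+
+ # a single User, or None; MultipleResultsFound if more than one row
+ user = session.query(User).filter(User.id == 5).one_or_none()
+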
+ .. versionadded:: 1.0.9
+
+ Added :meth:`_query.Query.one_or_none`
+
+ .. seealso::
+
+ :meth:`_query.Query.first`
+
+ :meth:`_query.Query.one`
+
+ """
+ return self._iter().one_or_none()
+
+ def one(self):
+ """Return exactly one result or raise an exception.
+
+ Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects
+ no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound``
+ if multiple object identities are returned, or if multiple
+ rows are returned for a query that returns only scalar values
+ as opposed to full identity-mapped entities.
+
+ Calling :meth:`.one` results in an execution of the underlying query.
+
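+ E.g., a brief sketch, assuming a mapped ``User`` class::
+
+ from sqlalchemy.orm.exc import NoResultFound
+
+ try:
+ # exactly one matching row is required
+ user = session.query(User).filter(User.id == 5).one()
+ except NoResultFound:
+ user = None
+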
+ .. seealso::
+
+ :meth:`_query.Query.first`
+
+ :meth:`_query.Query.one_or_none`
+
+ """
+ return self._iter().one()
+
+ def scalar(self):
+ """Return the first element of the first result or None
+ if no rows are present. If multiple rows are returned,
+ raises MultipleResultsFound.
+
+ >>> session.query(Item).scalar()
+ <Item>
+ >>> session.query(Item.id).scalar()
+ 1
+ >>> session.query(Item.id).filter(Item.id < 0).scalar()
+ None
+ >>> session.query(Item.id, Item.name).scalar()
+ 1
+ >>> session.query(func.count(Parent.id)).scalar()
+ 20
+
+ This results in an execution of the underlying query.
+
+ """
+ # TODO: not sure why we can't use result.scalar() here
+ try:
+ ret = self.one()
+ if not isinstance(ret, collections_abc.Sequence):
+ return ret
+ return ret[0]
+ except orm_exc.NoResultFound:
+ return None
+
+ def __iter__(self):
+ return self._iter().__iter__()
+
+ def _iter(self):
+ # new style execution.
+ params = self._params
+
+ statement = self._statement_20()
+ result = self.session.execute(
+ statement,
+ params,
+ execution_options={"_sa_orm_load_options": self.load_options},
+ )
+
+ # legacy: automatically set scalars, unique
+ if result._attributes.get("is_single_entity", False):
+ result = result.scalars()
+
+ if (
+ result._attributes.get("filtered", False)
+ and not self.load_options._yield_per
+ ):
+ result = result.unique()
+
+ return result
+
+ def __str__(self):
+ statement = self._statement_20()
+
+ try:
+ bind = (
+ self._get_bind_args(statement, self.session.get_bind)
+ if self.session
+ else None
+ )
+ except sa_exc.UnboundExecutionError:
+ bind = None
+
+ return str(statement.compile(bind))
+
+ def _get_bind_args(self, statement, fn, **kw):
+ return fn(clause=statement, **kw)
+
+ @property
+ def column_descriptions(self):
+ """Return metadata about the columns which would be
+ returned by this :class:`_query.Query`.
+
+ Format is a list of dictionaries::
+
+ user_alias = aliased(User, name='user2')
+ q = sess.query(User, User.id, user_alias)
+
+ # this expression:
+ q.column_descriptions
+
+ # would return:
+ [
+ {
+ 'name':'User',
+ 'type':User,
+ 'aliased':False,
+ 'expr':User,
+ 'entity': User
+ },
+ {
+ 'name':'id',
+ 'type':Integer(),
+ 'aliased':False,
+ 'expr':User.id,
+ 'entity': User
+ },
+ {
+ 'name':'user2',
+ 'type':User,
+ 'aliased':True,
+ 'expr':user_alias,
+ 'entity': user_alias
+ }
+ ]
+
+ .. seealso::
+
+ This API is available using :term:`2.0 style` queries as well,
+ documented at:
+
+ * :ref:`queryguide_inspection`
+
+ * :attr:`.Select.column_descriptions`
+
+ """
+
+ return _column_descriptions(self, legacy=True)
+
+ def instances(self, result_proxy, context=None):
+ """Return an ORM result given a :class:`_engine.CursorResult` and
+ :class:`.QueryContext`.
+
+ """
+ if context is None:
+ util.warn_deprecated(
+ "Using the Query.instances() method without a context "
+ "is deprecated and will be disallowed in a future release. "
+ "Please make use of :meth:`_query.Query.from_statement` "
+ "for linking ORM results to arbitrary select constructs.",
+ version="1.4",
+ )
+ compile_state = self._compile_state(for_statement=False)
+
+ context = QueryContext(
+ compile_state,
+ compile_state.statement,
+ self._params,
+ self.session,
+ self.load_options,
+ )
+
+ result = loading.instances(result_proxy, context)
+
+ # legacy: automatically set scalars, unique
+ if result._attributes.get("is_single_entity", False):
+ result = result.scalars()
+
+ if result._attributes.get("filtered", False):
+ result = result.unique()
+
+ return result
+
+ @util.deprecated_20(
+ ":meth:`_orm.Query.merge_result`",
+ alternative="The method is superseded by the "
+ ":func:`_orm.merge_frozen_result` function.",
+ becomes_legacy=True,
+ enable_warnings=False, # warnings occur via loading.merge_result
+ )
+ def merge_result(self, iterator, load=True):
+ """Merge a result into this :class:`_query.Query` object's Session.
+
+ Given an iterator returned by a :class:`_query.Query`
+ of the same structure
+ as this one, return an identical iterator of results, with all mapped
+ instances merged into the session using :meth:`.Session.merge`. This
+ is an optimized method which will merge all mapped instances,
+ preserving the structure of the result rows and unmapped columns with
+ less method overhead than that of calling :meth:`.Session.merge`
+ explicitly for each value.
+
+ The structure of the results is determined based on the column list of
+ this :class:`_query.Query` - if these do not correspond,
+ unchecked errors
+ will occur.
+
+ The 'load' argument is the same as that of :meth:`.Session.merge`.
+
+ For an example of how :meth:`_query.Query.merge_result` is used, see
+ the source code for the example :ref:`examples_caching`, where
+ :meth:`_query.Query.merge_result` is used to efficiently restore state
+ from a cache back into a target :class:`.Session`.
+
+ """
+
+ return loading.merge_result(self, iterator, load)
+
+ def exists(self):
+ """A convenience method that turns a query into an EXISTS subquery
+ of the form EXISTS (SELECT 1 FROM ... WHERE ...).
+
+ e.g.::
+
+ q = session.query(User).filter(User.name == 'fred')
+ session.query(q.exists())
+
+ Producing SQL similar to::
+
+ SELECT EXISTS (
+ SELECT 1 FROM users WHERE users.name = :name_1
+ ) AS anon_1
+
+ The EXISTS construct is usually used in the WHERE clause::
+
+ session.query(User.id).filter(q.exists()).scalar()
+
+ Note that some databases such as SQL Server don't allow an
+ EXISTS expression to be present in the columns clause of a
+ SELECT. To select a simple boolean value based on the EXISTS
+ as a WHERE criterion, use :func:`.literal`::
+
+ from sqlalchemy import literal
+
+ session.query(literal(True)).filter(q.exists()).scalar()
+
+ """
+
+ # .add_columns() for the case that we are a query().select_from(X),
+ # so that ".statement" can be produced (#2995) but also without
+ # omitting the FROM clause from a query(X) (#2818);
+ # .with_only_columns() after we have a core select() so that
+ # we get just "SELECT 1" without any entities.
+
+ inner = (
+ self.enable_eagerloads(False)
+ .add_columns(sql.literal_column("1"))
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .statement.with_only_columns(1)
+ )
+
+ ezero = self._entity_from_pre_ent_zero()
+ if ezero is not None:
+ inner = inner.select_from(ezero)
+
+ return sql.exists(inner)
+
+ def count(self):
+ r"""Return a count of rows this the SQL formed by this :class:`Query`
+ would return.
+
+ This generates the SQL for this Query as follows::
+
+ SELECT count(1) AS count_1 FROM (
+ SELECT <rest of query follows...>
+ ) AS anon_1
+
+ The above SQL returns a single row, which is the aggregate value
+ of the count function; the :meth:`_query.Query.count`
+ method then returns
+ that single integer value.
+
+ .. warning::
+
+ It is important to note that the value returned by
+ count() is **not the same as the number of ORM objects that this
+ Query would return from a method such as the .all() method**.
+ The :class:`_query.Query` object,
+ when asked to return full entities,
+ will **deduplicate entries based on primary key**, meaning if the
+ same primary key value would appear in the results more than once,
+ only one object of that primary key would be present. This does
+ not apply to a query that is against individual columns.
+
+ .. seealso::
+
+ :ref:`faq_query_deduplicating`
+
+ :ref:`orm_tutorial_query_returning`
+
+ For fine grained control over specific columns to count, to skip the
+ usage of a subquery or otherwise control the FROM clause, or to use
+ other aggregate functions, use :attr:`~sqlalchemy.sql.expression.func`
+ expressions in conjunction with :meth:`~.Session.query`, i.e.::
+
+ from sqlalchemy import func
+
+ # count User records, without
+ # using a subquery.
+ session.query(func.count(User.id))
+
+ # return count of user "id" grouped
+ # by "name"
+ session.query(func.count(User.id)).\
+ group_by(User.name)
+
+ from sqlalchemy import distinct
+
+ # count distinct "name" values
+ session.query(func.count(distinct(User.name)))
+
+ """
+ col = sql.func.count(sql.literal_column("*"))
+ return self._from_self(col).enable_eagerloads(False).scalar()
+
+ def delete(self, synchronize_session="evaluate"):
+ r"""Perform a DELETE with an arbitrary WHERE clause.
+
+ Deletes rows matched by this query from the database.
+
+ E.g.::
+
+ sess.query(User).filter(User.age == 25).\
+ delete(synchronize_session=False)
+
+ sess.query(User).filter(User.age == 25).\
+ delete(synchronize_session='evaluate')
+
+ .. warning::
+
+ See the section :ref:`orm_expression_update_delete` for important
+ caveats and warnings, including limitations when using bulk UPDATE
+ and DELETE with mapper inheritance configurations.
+
+ :param synchronize_session: chooses the strategy to update the
+ attributes on objects in the session. See the section
+ :ref:`orm_expression_update_delete` for a discussion of these
+ strategies.
+
+ :return: the count of rows matched as returned by the database's
+ "row count" feature.
+
+ .. seealso::
+
+ :ref:`orm_expression_update_delete`
+
+ """
+
+ bulk_del = BulkDelete(self)
+ if self.dispatch.before_compile_delete:
+ for fn in self.dispatch.before_compile_delete:
+ new_query = fn(bulk_del.query, bulk_del)
+ if new_query is not None:
+ bulk_del.query = new_query
+
+ self = bulk_del.query
+
+ delete_ = sql.delete(*self._raw_columns)
+ delete_._where_criteria = self._where_criteria
+ result = self.session.execute(
+ delete_,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ )
+ bulk_del.result = result
+ self.session.dispatch.after_bulk_delete(bulk_del)
+ result.close()
+
+ return result.rowcount
+
+ def update(self, values, synchronize_session="evaluate", update_args=None):
+ r"""Perform an UPDATE with an arbitrary WHERE clause.
+
+ Updates rows matched by this query in the database.
+
+ E.g.::
+
+ sess.query(User).filter(User.age == 25).\
+ update({User.age: User.age - 10}, synchronize_session=False)
+
+ sess.query(User).filter(User.age == 25).\
+ update({"age": User.age - 10}, synchronize_session='evaluate')
+
+ .. warning::
+
+ See the section :ref:`orm_expression_update_delete` for important
+ caveats and warnings, including limitations when using arbitrary
+ UPDATE and DELETE with mapper inheritance configurations.
+
+ :param values: a dictionary with attribute names, or alternatively
+ mapped attributes or SQL expressions, as keys, and literal
+ values or sql expressions as values. If :ref:`parameter-ordered
+ mode <tutorial_parameter_ordered_updates>` is desired, the values can
+ be passed as a list of 2-tuples; this requires that the
+ :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
+ flag is passed to the :paramref:`.Query.update.update_args` dictionary
+ as well.
+
+ :param synchronize_session: chooses the strategy to update the
+ attributes on objects in the session. See the section
+ :ref:`orm_expression_update_delete` for a discussion of these
+ strategies.
+
+ :param update_args: Optional dictionary, if present will be passed
+ to the underlying :func:`_expression.update`
+ construct as the ``**kw`` for
+ the object. May be used to pass dialect-specific arguments such
+ as ``mysql_limit``, as well as other special arguments such as
+ :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`.
+
+ :return: the count of rows matched as returned by the database's
+ "row count" feature.
+
+
+ .. seealso::
+
+ :ref:`orm_expression_update_delete`
+
+
+ """
+
+ update_args = update_args or {}
+
+ bulk_ud = BulkUpdate(self, values, update_args)
+
+ if self.dispatch.before_compile_update:
+ for fn in self.dispatch.before_compile_update:
+ new_query = fn(bulk_ud.query, bulk_ud)
+ if new_query is not None:
+ bulk_ud.query = new_query
+ self = bulk_ud.query
+
+ upd = sql.update(*self._raw_columns)
+
+ ppo = update_args.pop("preserve_parameter_order", False)
+ if ppo:
+ upd = upd.ordered_values(*values)
+ else:
+ upd = upd.values(values)
+ if update_args:
+ upd = upd.with_dialect_options(**update_args)
+
+ upd._where_criteria = self._where_criteria
+ result = self.session.execute(
+ upd,
+ self._params,
+ execution_options={"synchronize_session": synchronize_session},
+ )
+ bulk_ud.result = result
+ self.session.dispatch.after_bulk_update(bulk_ud)
+ result.close()
+ return result.rowcount
+
+ def _compile_state(self, for_statement=False, **kw):
+ """Create an out-of-compiler ORMCompileState object.
+
+ The ORMCompileState object is normally created directly as a result
+ of the SQLCompiler.process() method being handed a Select()
+ or FromStatement() object that uses the "orm" plugin. This method
+ provides a means of creating this ORMCompileState object directly
+ without using the compiler.
+
+ This method is used only for deprecated cases, which include
+ the .from_self() method for a Query that has multiple levels
+ of .from_self() in use, as well as the instances() method. It is
+ also used within the test suite to generate ORMCompileState objects
+ for test purposes.
+
+ """
+
+ stmt = self._statement_20(for_statement=for_statement, **kw)
+ assert for_statement == stmt._compile_options._for_statement
+
+ # this chooses between ORMFromStatementCompileState and
+ # ORMSelectCompileState. We could also base this on
+ # query._statement is not None as we have the ORM Query here
+ # however this is the more general path.
+ compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
+ stmt, "orm"
+ )
+
+ return compile_state_cls.create_for_statement(stmt, None)
+
+ def _compile_context(self, for_statement=False):
+ compile_state = self._compile_state(for_statement=for_statement)
+ context = QueryContext(
+ compile_state,
+ compile_state.statement,
+ self._params,
+ self.session,
+ self.load_options,
+ )
+
+ return context
+
+
+class FromStatement(GroupedElement, SelectBase, Executable):
+ """Core construct that represents a load of ORM objects from a finished
+ select or text construct.
+
+ """
+
+ __visit_name__ = "orm_from_statement"
+
+ _compile_options = ORMFromStatementCompileState.default_compile_options
+
+ _compile_state_factory = ORMFromStatementCompileState.create_for_statement
+
+ _for_update_arg = None
+
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("element", InternalTraversal.dp_clauseelement),
+ ] + Executable._executable_traverse_internals
+
+ _cache_key_traversal = _traverse_internals + [
+ ("_compile_options", InternalTraversal.dp_has_cache_key)
+ ]
+
+ def __init__(self, entities, element):
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole,
+ ent,
+ apply_propagate_attrs=self,
+ post_inspect=True,
+ )
+ for ent in util.to_list(entities)
+ ]
+ self.element = element
+
+ def get_label_style(self):
+ return self._label_style
+
+ def set_label_style(self, label_style):
+ return SelectStatementGrouping(
+ self.element.set_label_style(label_style)
+ )
+
+ @property
+ def _label_style(self):
+ return self.element._label_style
+
+ def _compiler_dispatch(self, compiler, **kw):
+
+ """provide a fixed _compiler_dispatch method.
+
+ This is roughly similar to using the sqlalchemy.ext.compiler
+ ``@compiles`` extension.
+
+ """
+
+ compile_state = self._compile_state_factory(self, compiler, **kw)
+
+ toplevel = not compiler.stack
+
+ if toplevel:
+ compiler.compile_state = compile_state
+
+ return compiler.process(compile_state.statement, **kw)
+
+ def _ensure_disambiguated_names(self):
+ return self
+
+ def get_children(self, **kw):
+ for elem in itertools.chain.from_iterable(
+ element._from_objects for element in self._raw_columns
+ ):
+ yield elem
+ for elem in super(FromStatement, self).get_children(**kw):
+ yield elem
+
+ @property
+ def _returning(self):
+ return self.element._returning if self.element.is_dml else None
+
+ @property
+ def _inline(self):
+ return self.element._inline if self.element.is_dml else None
+
+
+class AliasOption(interfaces.LoaderOption):
+ @util.deprecated(
+ "1.4",
+ "The :class:`.AliasOption` is not necessary "
+ "for entities to be matched up to a query that is established "
+ "via :meth:`.Query.from_statement` and now does nothing.",
+ )
+ def __init__(self, alias):
+ r"""Return a :class:`.MapperOption` that will indicate to the
+ :class:`_query.Query`
+ that the main table has been aliased.
+
+ """
+
+ inherit_cache = False
+
+ def process_compile_state(self, compile_state):
+ pass
+
+
+class BulkUD(object):
+ """State used for the orm.Query version of update() / delete().
+
+ This object is now specific to Query only.
+
+ """
+
+ def __init__(self, query):
+ self.query = query.enable_eagerloads(False)
+ self._validate_query_state()
+ self.mapper = self.query._entity_from_pre_ent_zero()
+
+ def _validate_query_state(self):
+ for attr, methname, notset, op in (
+ ("_limit_clause", "limit()", None, operator.is_),
+ ("_offset_clause", "offset()", None, operator.is_),
+ ("_order_by_clauses", "order_by()", (), operator.eq),
+ ("_group_by_clauses", "group_by()", (), operator.eq),
+ ("_distinct", "distinct()", False, operator.is_),
+ (
+ "_from_obj",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ (
+ "_legacy_setup_joins",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ ):
+ if not op(getattr(self.query, attr), notset):
+ raise sa_exc.InvalidRequestError(
+ "Can't call Query.update() or Query.delete() "
+ "when %s has been called" % (methname,)
+ )
+
+ @property
+ def session(self):
+ return self.query.session
+
+
+class BulkUpdate(BulkUD):
+ """BulkUD which handles UPDATEs."""
+
+ def __init__(self, query, values, update_kwargs):
+ super(BulkUpdate, self).__init__(query)
+ self.values = values
+ self.update_kwargs = update_kwargs
+
+
+class BulkDelete(BulkUD):
+ """BulkUD which handles DELETEs."""
diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py
new file mode 100644
index 0000000..b51ea0e
--- /dev/null
+++ b/lib/sqlalchemy/orm/relationships.py
@@ -0,0 +1,3684 @@
+# orm/relationships.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Heuristics related to join conditions as used in
+:func:`_orm.relationship`.
+
+Provides the :class:`.JoinCondition` object, which encapsulates
+SQL annotation and aliasing behavior focused on the `primaryjoin`
+and `secondaryjoin` aspects of :func:`_orm.relationship`.
+
+"""
+from __future__ import absolute_import
+
+import collections
+import re
+import weakref
+
+from . import attributes
+from .base import _is_mapped_class
+from .base import state_str
+from .interfaces import MANYTOMANY
+from .interfaces import MANYTOONE
+from .interfaces import ONETOMANY
+from .interfaces import PropComparator
+from .interfaces import StrategizedProperty
+from .util import _orm_annotate
+from .util import _orm_deannotate
+from .util import CascadeOptions
+from .. import exc as sa_exc
+from .. import log
+from .. import schema
+from .. import sql
+from .. import util
+from ..inspection import inspect
+from ..sql import coercions
+from ..sql import expression
+from ..sql import operators
+from ..sql import roles
+from ..sql import visitors
+from ..sql.util import _deep_deannotate
+from ..sql.util import _shallow_annotate
+from ..sql.util import adapt_criterion_to_null
+from ..sql.util import ClauseAdapter
+from ..sql.util import join_condition
+from ..sql.util import selectables_overlap
+from ..sql.util import visit_binary_product
+
+
+def remote(expr):
+ """Annotate a portion of a primaryjoin expression
+ with a 'remote' annotation.
+
+ See the section :ref:`relationship_custom_foreign` for a
+ description of use.
+
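+ E.g., a brief sketch of an explicitly annotated self-referential join
+ condition, using a hypothetical ``Node`` mapping and assuming the usual
+ declarative imports::
+
+ class Node(Base):
+ __tablename__ = "node"
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey("node.id"))
+
+ # hypothetical many-to-one to the parent row; node.id is the
+ # "remote" side, node.parent_id is the "foreign" side
+ parent = relationship(
+ "Node", primaryjoin=remote(id) == foreign(parent_id)
+ )
+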
+ .. seealso::
+
+ :ref:`relationship_custom_foreign`
+
+ :func:`.foreign`
+
+ """
+ return _annotate_columns(
+ coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True}
+ )
+
+
+def foreign(expr):
+ """Annotate a portion of a primaryjoin expression
+ with a 'foreign' annotation.
+
+ See the section :ref:`relationship_custom_foreign` for a
+ description of use.
+
+ .. seealso::
+
+ :ref:`relationship_custom_foreign`
+
+ :func:`.remote`
+
+ """
+
+ return _annotate_columns(
+ coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True}
+ )
+
+
+@log.class_logger
+class RelationshipProperty(StrategizedProperty):
+ """Describes an object property that holds a single item or list
+ of items that correspond to a related database table.
+
+ Public constructor is the :func:`_orm.relationship` function.
+
+ .. seealso::
+
+ :ref:`relationship_config_toplevel`
+
+ """
+
+ strategy_wildcard_key = "relationship"
+ inherit_cache = True
+
+ _links_to_entity = True
+
+ _persistence_only = dict(
+ passive_deletes=False,
+ passive_updates=True,
+ enable_typechecks=True,
+ active_history=False,
+ cascade_backrefs=True,
+ )
+
+ _dependency_processor = None
+
+ def __init__(
+ self,
+ argument,
+ secondary=None,
+ primaryjoin=None,
+ secondaryjoin=None,
+ foreign_keys=None,
+ uselist=None,
+ order_by=False,
+ backref=None,
+ back_populates=None,
+ overlaps=None,
+ post_update=False,
+ cascade=False,
+ viewonly=False,
+ lazy="select",
+ collection_class=None,
+ passive_deletes=_persistence_only["passive_deletes"],
+ passive_updates=_persistence_only["passive_updates"],
+ remote_side=None,
+ enable_typechecks=_persistence_only["enable_typechecks"],
+ join_depth=None,
+ comparator_factory=None,
+ single_parent=False,
+ innerjoin=False,
+ distinct_target_key=None,
+ doc=None,
+ active_history=_persistence_only["active_history"],
+ cascade_backrefs=_persistence_only["cascade_backrefs"],
+ load_on_pending=False,
+ bake_queries=True,
+ _local_remote_pairs=None,
+ query_class=None,
+ info=None,
+ omit_join=None,
+ sync_backref=None,
+ _legacy_inactive_history_style=False,
+ ):
+ """Provide a relationship between two mapped classes.
+
+ This corresponds to a parent-child or associative table relationship.
+ The constructed class is an instance of
+ :class:`.RelationshipProperty`.
+
+ A typical :func:`_orm.relationship`, used in a classical mapping::
+
+ mapper(Parent, properties={
+ 'children': relationship(Child)
+ })
+
+ Some arguments accepted by :func:`_orm.relationship`
+ optionally accept a
+ callable function, which when called produces the desired value.
+ The callable is invoked by the parent :class:`_orm.Mapper` at "mapper
+ initialization" time, which happens only when mappers are first used,
+ and is assumed to be after all mappings have been constructed. This
+ can be used to resolve order-of-declaration and other dependency
+ issues, such as if ``Child`` is declared below ``Parent`` in the same
+ file::
+
+ mapper(Parent, properties={
+ "children":relationship(lambda: Child,
+ order_by=lambda: Child.id)
+ })
+
+ When using the :ref:`declarative_toplevel` extension, the Declarative
+ initializer allows string arguments to be passed to
+ :func:`_orm.relationship`. These string arguments are converted into
+ callables that evaluate the string as Python code, using the
+ Declarative class-registry as a namespace. This allows the lookup of
+ related classes to be automatic via their string name, and removes the
+ need for related classes to be imported into the local module space
+ before the dependent classes have been declared. It is still required
+ that the modules in which these related classes appear are imported
+ anywhere in the application at some point before the related mappings
+ are actually used, else a lookup error will be raised when the
+ :func:`_orm.relationship`
+ attempts to resolve the string reference to the
+ related class. An example of a string-resolved class is as
+ follows::
+
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class Parent(Base):
+ __tablename__ = 'parent'
+ id = Column(Integer, primary_key=True)
+ children = relationship("Child", order_by="Child.id")
+
+ .. seealso::
+
+ :ref:`relationship_config_toplevel` - Full introductory and
+ reference documentation for :func:`_orm.relationship`.
+
+ :ref:`tutorial_orm_related_objects` - ORM tutorial introduction.
+
+ :param argument:
+ A mapped class, or actual :class:`_orm.Mapper` instance,
+ representing
+ the target of the relationship.
+
+ :paramref:`_orm.relationship.argument`
+ may also be passed as a callable
+ function which is evaluated at mapper initialization time, and may
+ be passed as a string name when using Declarative.
+
+ .. warning:: Prior to SQLAlchemy 1.3.16, this value is interpreted
+ using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ .. versionchanged:: 1.3.16
+
+ The string evaluation of the main "argument" no longer accepts an
+ open ended Python expression, instead only accepting a string
+ class name or dotted package-qualified name.
+
+ .. seealso::
+
+ :ref:`declarative_configuring_relationships` - further detail
+ on relationship configuration when using Declarative.
+
+ :param secondary:
+ For a many-to-many relationship, specifies the intermediary
+ table, and is typically an instance of :class:`_schema.Table`.
+ In less common circumstances, the argument may also be specified
+ as an :class:`_expression.Alias` construct, or even a
+ :class:`_expression.Join` construct.
+
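+ E.g., a brief sketch of a many-to-many configuration, using a
+ hypothetical ``association`` table::
+
+ # hypothetical association table linking "left" and "right" tables
+ association = Table(
+ "association",
+ Base.metadata,
+ Column("left_id", ForeignKey("left.id"), primary_key=True),
+ Column("right_id", ForeignKey("right.id"), primary_key=True),
+ )
+
+ children = relationship("Child", secondary=association)
+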
+ :paramref:`_orm.relationship.secondary` may
+ also be passed as a callable function which is evaluated at
+ mapper initialization time. When using Declarative, it may also
+ be a string argument noting the name of a :class:`_schema.Table`
+ that is
+ present in the :class:`_schema.MetaData`
+ collection associated with the
+ parent-mapped :class:`_schema.Table`.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ The :paramref:`_orm.relationship.secondary` keyword argument is
+ typically applied in the case where the intermediary
+ :class:`_schema.Table`
+ is not otherwise expressed in any direct class mapping. If the
+ "secondary" table is also explicitly mapped elsewhere (e.g. as in
+ :ref:`association_pattern`), one should consider applying the
+ :paramref:`_orm.relationship.viewonly` flag so that this
+ :func:`_orm.relationship`
+ is not used for persistence operations which
+ may conflict with those of the association object pattern.
+
+ .. seealso::
+
+ :ref:`relationships_many_to_many` - Reference example of "many
+ to many".
+
+ :ref:`self_referential_many_to_many` - Specifics on using
+ many-to-many in a self-referential case.
+
+ :ref:`declarative_many_to_many` - Additional options when using
+ Declarative.
+
+ :ref:`association_pattern` - an alternative to
+ :paramref:`_orm.relationship.secondary`
+ when composing association
+ table relationships, allowing additional attributes to be
+ specified on the association table.
+
+ :ref:`composite_secondary_join` - a lesser-used pattern which
+ in some cases can enable complex :func:`_orm.relationship` SQL
+ conditions to be used.
+
+ .. versionadded:: 0.9.2 :paramref:`_orm.relationship.secondary`
+ works
+ more effectively when referring to a :class:`_expression.Join`
+ instance.
+
+ :param active_history=False:
+ When ``True``, indicates that the "previous" value for a
+ many-to-one reference should be loaded when replaced, if
+ not already loaded. Normally, history tracking logic for
+ simple many-to-ones only needs to be aware of the "new"
+ value in order to perform a flush. This flag is available
+ for applications that make use of
+ :func:`.attributes.get_history` which also need to know
+ the "previous" value of the attribute.
+
+ :param backref:
+ A reference to a string relationship name, or a :func:`_orm.backref`
+ construct, which will be used to automatically generate a new
+ :func:`_orm.relationship` on the related class, which then refers to
+ this one using a bi-directional
+ :paramref:`_orm.relationship.back_populates` configuration.
+
+ In modern Python, explicit use of :func:`_orm.relationship` with
+ :paramref:`_orm.relationship.back_populates` should be preferred, as
+ it is more robust in terms of mapper configuration as well as more
+ conceptually straightforward. It also integrates with new :pep:`484`
+ typing features introduced in SQLAlchemy 2.0 which is not possible
+ with dynamically generated attributes.
+
+ .. seealso::
+
+ :ref:`relationships_backref` - notes on using
+ :paramref:`_orm.relationship.backref`
+
+ :ref:`tutorial_orm_related_objects` - in the
+ :ref:`unified_tutorial`, presents an overview of bi-directional
+ relationship configuration and behaviors using
+ :paramref:`_orm.relationship.back_populates`
+
+ :func:`.backref` - allows control over :func:`_orm.relationship`
+ configuration when using :paramref:`_orm.relationship.backref`.
+
+
+ :param back_populates:
+ Indicates the name of a :func:`_orm.relationship` on the related
+ class that will be synchronized with this one. It is usually
+ expected that the :func:`_orm.relationship` on the related class
+ also refer to this one. This allows objects on both sides of
+ each :func:`_orm.relationship` to synchronize in-Python state
+ changes and also provides directives to the :term:`unit of work`
+ flush process how changes along these relationships should
+ be persisted.
+
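+ E.g., a brief sketch of a bi-directional configuration, using
+ hypothetical ``Parent`` / ``Child`` classes::
+
+ class Parent(Base):
+ __tablename__ = "parent"
+ id = Column(Integer, primary_key=True)
+ children = relationship("Child", back_populates="parent")
+
+ class Child(Base):
+ __tablename__ = "child"
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(ForeignKey("parent.id"))
+ parent = relationship("Parent", back_populates="children")
+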
+ .. seealso::
+
+ :ref:`tutorial_orm_related_objects` - in the
+ :ref:`unified_tutorial`, presents an overview of bi-directional
+ relationship configuration and behaviors.
+
+ :ref:`relationship_patterns` - includes many examples of
+ :paramref:`_orm.relationship.back_populates`.
+
+ :param overlaps:
+ A string name or comma-delimited set of names of other relationships
+ on either this mapper, a descendant mapper, or a target mapper with
+ which this relationship may write to the same foreign keys upon
+ persistence. The only effect this has is to eliminate the
+ warning that this relationship will conflict with another upon
+ persistence. This is used for such relationships that are truly
+ capable of conflicting with each other on write, but the application
+ will ensure that no such conflicts occur.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`error_qzyx` - usage example
+
+ :param bake_queries=True:
+ Legacy parameter, not used.
+
+ .. versionchanged:: 1.4.23 the "lambda caching" system is no longer
+ used by loader strategies and the ``bake_queries`` parameter
+ has no effect.
+
+ :param cascade:
+ A comma-separated list of cascade rules which determines how
+ Session operations should be "cascaded" from parent to child.
+ This defaults to ``False``, which means the default cascade
+ should be used - this default cascade is ``"save-update, merge"``.
+
+ The available cascades are ``save-update``, ``merge``,
+ ``expunge``, ``delete``, ``delete-orphan``, and ``refresh-expire``.
+ An additional option, ``all`` indicates shorthand for
+ ``"save-update, merge, refresh-expire,
+ expunge, delete"``, and is often used as in ``"all, delete-orphan"``
+ to indicate that related objects should follow along with the
+ parent object in all cases, and be deleted when de-associated.
+
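+ E.g., a brief sketch of the commonly used ``"all, delete-orphan"``
+ setting::
+
+ # child objects follow the parent and are deleted when
+ # de-associated from the collection
+ children = relationship("Child", cascade="all, delete-orphan")
+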
+ .. seealso::
+
+ :ref:`unitofwork_cascades` - Full detail on each of the available
+ cascade options.
+
+ :param cascade_backrefs=True:
+ A boolean value indicating if the ``save-update`` cascade should
+ operate along an assignment event intercepted by a backref.
+ When set to ``False``, the attribute managed by this relationship
+ will not cascade an incoming transient object into the session of a
+ persistent parent, if the event is received via backref.
+
+ .. deprecated:: 1.4 The
+ :paramref:`_orm.relationship.cascade_backrefs`
+ flag will default to False in all cases in SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :ref:`backref_cascade` - Full discussion and examples on how
+ the :paramref:`_orm.relationship.cascade_backrefs` option is used.
+
+ :param collection_class:
+ A class or callable that returns a new list-holding object, which
+ will be used in place of a plain list for storing elements.
+
+ .. seealso::
+
+ :ref:`custom_collections` - Introductory documentation and
+ examples.
+
+ :param comparator_factory:
+ A class which extends :class:`.RelationshipProperty.Comparator`
+ which provides custom SQL clause generation for comparison
+ operations.
+
+ .. seealso::
+
+ :class:`.PropComparator` - some detail on redefining comparators
+ at this level.
+
+ :ref:`custom_comparators` - Brief intro to this feature.
+
+
+ :param distinct_target_key=None:
+ Indicate if a "subquery" eager load should apply the DISTINCT
+ keyword to the innermost SELECT statement. When left as ``None``,
+ the DISTINCT keyword will be applied in those cases when the target
+ columns do not comprise the full primary key of the target table.
+ When set to ``True``, the DISTINCT keyword is applied to the
+ innermost SELECT unconditionally.
+
+ It may be desirable to set this flag to False when the DISTINCT
+ reduces performance of the innermost subquery more than the
+ duplicate innermost rows themselves would.
+
+ .. versionchanged:: 0.9.0 -
+ :paramref:`_orm.relationship.distinct_target_key` now defaults to
+ ``None``, so that the feature enables itself automatically for
+ those cases where the innermost query targets a non-unique
+ key.
+
+ .. seealso::
+
+ :ref:`loading_toplevel` - includes an introduction to subquery
+ eager loading.
+
+ :param doc:
+ Docstring which will be applied to the resulting descriptor.
+
+ :param foreign_keys:
+
+ A list of columns which are to be used as "foreign key"
+ columns, or columns which refer to the value in a remote
+ column, within the context of this :func:`_orm.relationship`
+ object's :paramref:`_orm.relationship.primaryjoin` condition.
+ That is, if the :paramref:`_orm.relationship.primaryjoin`
+ condition of this :func:`_orm.relationship` is ``a.id ==
+ b.a_id``, and the values in ``b.a_id`` are required to be
+ present in ``a.id``, then the "foreign key" column of this
+ :func:`_orm.relationship` is ``b.a_id``.
+
+ In normal cases, the :paramref:`_orm.relationship.foreign_keys`
+ parameter is **not required.** :func:`_orm.relationship` will
+ automatically determine which columns in the
+ :paramref:`_orm.relationship.primaryjoin` condition are to be
+ considered "foreign key" columns based on those
+ :class:`_schema.Column` objects that specify
+ :class:`_schema.ForeignKey`,
+ or are otherwise listed as referencing columns in a
+ :class:`_schema.ForeignKeyConstraint` construct.
+ :paramref:`_orm.relationship.foreign_keys` is only needed when:
+
+ 1. There is more than one way to construct a join from the local
+ table to the remote table, as there are multiple foreign key
+ references present. Setting ``foreign_keys`` will limit the
+ :func:`_orm.relationship`
+ to consider just those columns specified
+ here as "foreign".
+
+ 2. The :class:`_schema.Table` being mapped does not actually have
+ :class:`_schema.ForeignKey` or
+ :class:`_schema.ForeignKeyConstraint`
+ constructs present, often because the table
+ was reflected from a database that does not support foreign key
+ reflection (MySQL MyISAM).
+
+ 3. The :paramref:`_orm.relationship.primaryjoin`
+ argument is used to
+ construct a non-standard join condition, which makes use of
+ columns or expressions that do not normally refer to their
+ "parent" column, such as a join condition expressed by a
+ complex comparison using a SQL function.
+
+ The :func:`_orm.relationship` construct will raise informative
+ error messages that suggest the use of the
+ :paramref:`_orm.relationship.foreign_keys` parameter when
+ presented with an ambiguous condition. In typical cases,
+ if :func:`_orm.relationship` doesn't raise any exceptions, the
+ :paramref:`_orm.relationship.foreign_keys` parameter is usually
+ not needed.
+
+ :paramref:`_orm.relationship.foreign_keys` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ .. seealso::
+
+ :ref:`relationship_foreign_keys`
+
+ :ref:`relationship_custom_foreign`
+
+ :func:`.foreign` - allows direct annotation of the "foreign"
+ columns within a :paramref:`_orm.relationship.primaryjoin`
+ condition.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.MapperProperty.info` attribute of this object.
+
+ :param innerjoin=False:
+ When ``True``, joined eager loads will use an inner join to join
+ against related tables instead of an outer join. The purpose
+ of this option is generally one of performance, as inner joins
+ generally perform better than outer joins.
+
+ This flag can be set to ``True`` when the relationship references an
+ object via many-to-one using local foreign keys that are not
+ nullable, or when the reference is one-to-one or a collection that
+ is guaranteed to have at least one entry.
+
+ The option supports the same "nested" and "unnested" options as
+ that of :paramref:`_orm.joinedload.innerjoin`. See that flag
+ for details on nested / unnested behaviors.
+
+ .. seealso::
+
+ :paramref:`_orm.joinedload.innerjoin` - the option as specified by
+ loader option, including detail on nesting behavior.
+
+ :ref:`what_kind_of_loading` - Discussion of some details of
+ various loader options.
+
+
+ :param join_depth:
+ When non-``None``, an integer value indicating how many levels
+ deep "eager" loaders should join on a self-referring or cyclical
+ relationship. The number counts how many times the same Mapper
+ shall be present in the loading condition along a particular join
+ branch. When left at its default of ``None``, eager loaders
+ will stop chaining when they encounter the same target mapper
+ which is already higher up in the chain. This option applies
+ both to joined- and subquery- eager loaders.
+
+ .. seealso::
+
+ :ref:`self_referential_eager_loading` - Introductory documentation
+ and examples.
+
+ :param lazy='select': specifies
+ how the related items should be loaded. Default value is
+ ``select``. Values include the following (a brief usage sketch
+ follows the list):
+
+ * ``select`` - items should be loaded lazily when the property is
+ first accessed, using a separate SELECT statement, or identity map
+ fetch for simple many-to-one references.
+
+ * ``immediate`` - items should be loaded as the parents are loaded,
+ using a separate SELECT statement, or identity map fetch for
+ simple many-to-one references.
+
+ * ``joined`` - items should be loaded "eagerly" in the same query as
+ that of the parent, using a JOIN or LEFT OUTER JOIN. Whether
+ the join is "outer" or not is determined by the
+ :paramref:`_orm.relationship.innerjoin` parameter.
+
+ * ``subquery`` - items should be loaded "eagerly" as the parents are
+ loaded, using one additional SQL statement, which issues a JOIN to
+ a subquery of the original statement, for each collection
+ requested.
+
+ * ``selectin`` - items should be loaded "eagerly" as the parents
+ are loaded, using one or more additional SQL statements, which
+ issues a JOIN to the immediate parent object, specifying primary
+ key identifiers using an IN clause.
+
+ .. versionadded:: 1.2
+
+ * ``noload`` - no loading should occur at any time. This is to
+ support "write-only" attributes, or attributes which are
+ populated in some manner specific to the application.
+
+ * ``raise`` - lazy loading is disallowed; accessing
+ the attribute, if its value were not already loaded via eager
+ loading, will raise an :exc:`~sqlalchemy.exc.InvalidRequestError`.
+ This strategy can be used when objects are to be detached from
+ their attached :class:`.Session` after they are loaded.
+
+ .. versionadded:: 1.1
+
+ * ``raise_on_sql`` - lazy loading that emits SQL is disallowed;
+ accessing the attribute, if its value were not already loaded via
+ eager loading, will raise an
+ :exc:`~sqlalchemy.exc.InvalidRequestError`, **if the lazy load
+ needs to emit SQL**. If the lazy load can pull the related value
+ from the identity map or determine that it should be None, the
+ value is loaded. This strategy can be used when objects will
+ remain associated with the attached :class:`.Session`, however
+ additional SELECT statements should be blocked.
+
+ .. versionadded:: 1.1
+
+ * ``dynamic`` - the attribute will return a pre-configured
+ :class:`_query.Query` object for all read
+ operations, onto which further filtering operations can be
+ applied before iterating the results. See
+ the section :ref:`dynamic_relationship` for more details.
+
+ * True - a synonym for 'select'
+
+ * False - a synonym for 'joined'
+
+ * None - a synonym for 'noload'
+
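+ A brief sketch of selecting a loader strategy at mapping time, using a
+ hypothetical ``Child`` relationship::
+
+ # emit an additional SELECT with an IN clause as parents are loaded
+ children = relationship("Child", lazy="selectin")
+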
+ .. seealso::
+
+ :doc:`/orm/loading_relationships` - Full documentation on
+ relationship loader configuration.
+
+ :ref:`dynamic_relationship` - detail on the ``dynamic`` option.
+
+ :ref:`collections_noload_raiseload` - notes on "noload" and "raise"
+
+ :param load_on_pending=False:
+ Indicates loading behavior for transient or pending parent objects.
+
+ When set to ``True``, causes the lazy-loader to
+ issue a query for a parent object that is not persistent, meaning it
+ has never been flushed. This may take effect for a pending object
+ when autoflush is disabled, or for a transient object that has been
+ "attached" to a :class:`.Session` but is not part of its pending
+ collection.
+
+ The :paramref:`_orm.relationship.load_on_pending`
+ flag does not improve
+ behavior when the ORM is used normally - object references should be
+ constructed at the object level, not at the foreign key level, so
+ that they are present in an ordinary way before a flush proceeds.
+ This flag is not intended for general use.
+
+ .. seealso::
+
+ :meth:`.Session.enable_relationship_loading` - this method
+ establishes "load on pending" behavior for the whole object, and
+ also allows loading on objects that remain transient or
+ detached.
+
+ :param order_by:
+ Indicates the ordering that should be applied when loading these
+ items. :paramref:`_orm.relationship.order_by`
+ is expected to refer to
+ one of the :class:`_schema.Column`
+ objects to which the target class is
+ mapped, or the attribute itself bound to the target class which
+ refers to the column.
+
+ :paramref:`_orm.relationship.order_by`
+ may also be passed as a callable
+ function which is evaluated at mapper initialization time, and may
+ be passed as a Python-evaluable string when using Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ :param passive_deletes=False:
+ Indicates loading behavior during delete operations.
+
+ A value of True indicates that unloaded child items should not
+ be loaded during a delete operation on the parent. Normally,
+ when a parent item is deleted, all child items are loaded so
+ that they can either be marked as deleted, or have their
+ foreign key to the parent set to NULL. Marking this flag as
+ True usually implies an ON DELETE <CASCADE|SET NULL> rule is in
+ place which will handle updating/deleting child rows on the
+ database side.
+
+ Additionally, setting the flag to the string value 'all' will
+ disable the "nulling out" of the child foreign keys, when the parent
+ object is deleted and there is no delete or delete-orphan cascade
+ enabled. This is typically used when a triggering or error raise
+ scenario is in place on the database side. Note that the foreign
+ key attributes on in-session child objects will not be changed after
+ a flush occurs, so this is a very special use-case setting.
+ Additionally, the "nulling out" will still occur if the child
+ object is de-associated with the parent.
+
+ .. seealso::
+
+ :ref:`passive_deletes` - Introductory documentation
+ and examples.
+
+ :param passive_updates=True:
+ Indicates the persistence behavior to take when a referenced
+ primary key value changes in place, indicating that the referencing
+ foreign key columns will also need their value changed.
+
+ When True, it is assumed that ``ON UPDATE CASCADE`` is configured on
+ the foreign key in the database, and that the database will
+ handle propagation of an UPDATE from a source column to
+ dependent rows. When False, the SQLAlchemy
+ :func:`_orm.relationship`
+ construct will attempt to emit its own UPDATE statements to
+ modify related targets. However note that SQLAlchemy **cannot**
+ emit an UPDATE for more than one level of cascade. Also,
+ setting this flag to False is not compatible in the case where
+ the database is in fact enforcing referential integrity, unless
+ those constraints are explicitly "deferred", if the target backend
+ supports it.
+
+ It is highly advised that an application which is employing
+ mutable primary keys keeps ``passive_updates`` set to True,
+ and instead uses the referential integrity features of the database
+ itself in order to handle the change efficiently and fully.
+
+ .. seealso::
+
+ :ref:`passive_updates` - Introductory documentation and
+ examples.
+
+ :paramref:`.mapper.passive_updates` - a similar flag which
+ takes effect for joined-table inheritance mappings.
+
+ :param post_update:
+ This indicates that the relationship should be handled by a
+ second UPDATE statement after an INSERT or before a
+ DELETE. Currently, it will also issue an UPDATE after the
+ instance itself has been UPDATEd, although this technically should
+ be improved. This flag is used to handle saving bi-directional
+ dependencies between two individual rows (i.e. each row
+ references the other), where it would otherwise be impossible to
+ INSERT or DELETE both rows fully since one row exists before the
+ other. Use this flag when a particular mapping arrangement will
+ incur two rows that are dependent on each other, such as a table
+ that has a one-to-many relationship to a set of child rows, and
+ also has a column that references a single child row within that
+ list (i.e. both tables contain a foreign key to each other). If
+ a flush operation returns an error that a "cyclical
+ dependency" was detected, this is a cue that you might want to
+ use :paramref:`_orm.relationship.post_update` to "break" the cycle.
+
+ .. seealso::
+
+ :ref:`post_update` - Introductory documentation and examples.
+
+ :param primaryjoin:
+ A SQL expression that will be used as the primary
+ join of the child object against the parent object, or in a
+ many-to-many relationship the join of the parent object to the
+ association table. By default, this value is computed based on the
+ foreign key relationships of the parent and child tables (or
+ association table).
+
+ :paramref:`_orm.relationship.primaryjoin` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ .. seealso::
+
+ :ref:`relationship_primaryjoin`
+
+ :param remote_side:
+ Used for self-referential relationships, indicates the column or
+ list of columns that form the "remote side" of the relationship.
+
+ :paramref:`_orm.relationship.remote_side` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ .. seealso::
+
+ :ref:`self_referential` - in-depth explanation of how
+ :paramref:`_orm.relationship.remote_side`
+ is used to configure self-referential relationships.
+
+ :func:`.remote` - an annotation function that accomplishes the
+ same purpose as :paramref:`_orm.relationship.remote_side`,
+ typically
+ when a custom :paramref:`_orm.relationship.primaryjoin` condition
+ is used.
+
+ :param query_class:
+ A :class:`_query.Query`
+ subclass that will be used internally by the
+ ``AppenderQuery`` returned by a "dynamic" relationship, that
+ is, a relationship that specifies ``lazy="dynamic"`` or was
+ otherwise constructed using the :func:`_orm.dynamic_loader`
+ function.
+
+ .. seealso::
+
+ :ref:`dynamic_relationship` - Introduction to "dynamic"
+ relationship loaders.
+
+ :param secondaryjoin:
+ A SQL expression that will be used as the join of
+ an association table to the child object. By default, this value is
+ computed based on the foreign key relationships of the association
+ and child tables.
+
+ :paramref:`_orm.relationship.secondaryjoin` may also be passed as a
+ callable function which is evaluated at mapper initialization time,
+ and may be passed as a Python-evaluable string when using
+ Declarative.
+
+ .. warning:: When passed as a Python-evaluable string, the
+ argument is interpreted using Python's ``eval()`` function.
+ **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**.
+ See :ref:`declarative_relationship_eval` for details on
+ declarative evaluation of :func:`_orm.relationship` arguments.
+
+ .. seealso::
+
+ :ref:`relationship_primaryjoin`
+
+ :param single_parent:
+ When True, installs a validator which will prevent objects
+ from being associated with more than one parent at a time.
+ This is used for many-to-one or many-to-many relationships that
+ should be treated either as one-to-one or one-to-many. Its usage
+ is optional, except for :func:`_orm.relationship` constructs which
+ are many-to-one or many-to-many and also
+ specify the ``delete-orphan`` cascade option. The
+ :func:`_orm.relationship` construct itself will raise an error
+ instructing when this option is required.
+
+ .. seealso::
+
+ :ref:`unitofwork_cascades` - includes detail on when the
+ :paramref:`_orm.relationship.single_parent`
+ flag may be appropriate.
+
+ :param uselist:
+ A boolean that indicates if this property should be loaded as a
+ list or a scalar. In most cases, this value is determined
+ automatically by :func:`_orm.relationship` at mapper configuration
+ time, based on the type and direction
+ of the relationship - one to many forms a list, many to one
+ forms a scalar, many to many is a list. If a scalar is desired
+ where normally a list would be present, such as a bi-directional
+ one-to-one relationship, set :paramref:`_orm.relationship.uselist`
+ to
+ False.
+
+ The :paramref:`_orm.relationship.uselist`
+ flag is also available on an
+ existing :func:`_orm.relationship`
+ construct as a read-only attribute,
+ which can be used to determine if this :func:`_orm.relationship`
+ deals
+ with collections or scalar attributes::
+
+ >>> User.addresses.property.uselist
+ True
+
+ .. seealso::
+
+ :ref:`relationships_one_to_one` - Introduction to the "one to
+ one" relationship pattern, which is typically when the
+ :paramref:`_orm.relationship.uselist` flag is needed.
+
+ :param viewonly=False:
+ When set to ``True``, the relationship is used only for loading
+ objects, and not for any persistence operation. A
+ :func:`_orm.relationship` which specifies
+ :paramref:`_orm.relationship.viewonly` can work
+ with a wider range of SQL operations within the
+ :paramref:`_orm.relationship.primaryjoin` condition, including
+ operations that feature the use of a variety of comparison operators
+ as well as SQL functions such as :func:`_expression.cast`. The
+ :paramref:`_orm.relationship.viewonly`
+ flag is also of general use when defining any kind of
+ :func:`_orm.relationship` that doesn't represent
+ the full set of related objects, to prevent modifications of the
+ collection from resulting in persistence operations.
+
+ When using the :paramref:`_orm.relationship.viewonly` flag in
+ conjunction with backrefs, the originating relationship for a
+ particular state change will not produce state changes within the
+ viewonly relationship. This is the behavior implied by
+ :paramref:`_orm.relationship.sync_backref` being set to False.
+
+ .. versionchanged:: 1.3.17 - the
+ :paramref:`_orm.relationship.sync_backref` flag is set to False
+ when using viewonly in conjunction with backrefs.
+
+ .. seealso::
+
+ :paramref:`_orm.relationship.sync_backref`
+
+ :param sync_backref:
+ A boolean that enables the events used to synchronize the in-Python
+ attributes when this relationship is target of either
+ :paramref:`_orm.relationship.backref` or
+ :paramref:`_orm.relationship.back_populates`.
+
+ Defaults to ``None``, which indicates that an automatic value should
+ be selected based on the value of the
+ :paramref:`_orm.relationship.viewonly` flag. When left at its
+ default, changes in state will be back-populated only if neither
+ side of the relationship is viewonly.
+
+ .. versionadded:: 1.3.17
+
+ .. versionchanged:: 1.4 - A relationship that specifies
+ :paramref:`_orm.relationship.viewonly` automatically implies
+ that :paramref:`_orm.relationship.sync_backref` is ``False``.
+
+ .. seealso::
+
+ :paramref:`_orm.relationship.viewonly`
+
+ :param omit_join:
+ Allows manual control over the "selectin" automatic join
+ optimization. Set to ``False`` to disable the "omit join" feature
+ added in SQLAlchemy 1.3; or leave as ``None`` to leave automatic
+ optimization in place.
+
+ .. note:: This flag may only be set to ``False``. It is not
+ necessary to set it to ``True`` as the "omit_join" optimization is
+ automatically detected; if it is not detected, then the
+ optimization is not supported.
+
+ .. versionchanged:: 1.3.11 setting ``omit_join`` to True will now
+ emit a warning as this was not the intended use of this flag.
+
+ .. versionadded:: 1.3
+
+
+ """
+ super(RelationshipProperty, self).__init__()
+
+ self.uselist = uselist
+ self.argument = argument
+ self.secondary = secondary
+ self.primaryjoin = primaryjoin
+ self.secondaryjoin = secondaryjoin
+ self.post_update = post_update
+ self.direction = None
+ self.viewonly = viewonly
+ if viewonly:
+ self._warn_for_persistence_only_flags(
+ passive_deletes=passive_deletes,
+ passive_updates=passive_updates,
+ enable_typechecks=enable_typechecks,
+ active_history=active_history,
+ cascade_backrefs=cascade_backrefs,
+ )
+ if viewonly and sync_backref:
+ raise sa_exc.ArgumentError(
+ "sync_backref and viewonly cannot both be True"
+ )
+ self.sync_backref = sync_backref
+ self.lazy = lazy
+ self.single_parent = single_parent
+ self._user_defined_foreign_keys = foreign_keys
+ self.collection_class = collection_class
+ self.passive_deletes = passive_deletes
+ self.cascade_backrefs = cascade_backrefs
+ self.passive_updates = passive_updates
+ self.remote_side = remote_side
+ self.enable_typechecks = enable_typechecks
+ self.query_class = query_class
+ self.innerjoin = innerjoin
+ self.distinct_target_key = distinct_target_key
+ self.doc = doc
+ self.active_history = active_history
+ self._legacy_inactive_history_style = _legacy_inactive_history_style
+
+ self.join_depth = join_depth
+ if omit_join:
+ util.warn(
+ "setting omit_join to True is not supported; selectin "
+ "loading of this relationship may not work correctly if this "
+ "flag is set explicitly. omit_join optimization is "
+ "automatically detected for conditions under which it is "
+ "supported."
+ )
+
+ self.omit_join = omit_join
+ self.local_remote_pairs = _local_remote_pairs
+ self.bake_queries = bake_queries
+ self.load_on_pending = load_on_pending
+ self.comparator_factory = (
+ comparator_factory or RelationshipProperty.Comparator
+ )
+ self.comparator = self.comparator_factory(self, None)
+ util.set_creation_order(self)
+
+ if info is not None:
+ self.info = info
+
+ self.strategy_key = (("lazy", self.lazy),)
+
+ self._reverse_property = set()
+ if overlaps:
+ self._overlaps = set(re.split(r"\s*,\s*", overlaps))
+ else:
+ self._overlaps = ()
+
+ if cascade is not False:
+ self.cascade = cascade
+ elif self.viewonly:
+ self.cascade = "none"
+ else:
+ self.cascade = "save-update, merge"
+
+ self.order_by = order_by
+
+ self.back_populates = back_populates
+
+ if self.back_populates:
+ if backref:
+ raise sa_exc.ArgumentError(
+ "backref and back_populates keyword arguments "
+ "are mutually exclusive"
+ )
+ self.backref = None
+ else:
+ self.backref = backref
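+ # Configuration sketch for the backref / back_populates handling above
+ # (hypothetical Parent/Child mapping): back_populates names the reverse
+ # relationship explicitly on each side, while backref creates it
+ # implicitly; supplying both on one relationship raises ArgumentError.
+ #
+ #   children = relationship("Child", back_populates="parent")
+ #   parent = relationship("Parent", back_populates="children")
+ #   # or, configured from one side only:
+ #   children = relationship("Child", backref="parent")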
+
+ def _warn_for_persistence_only_flags(self, **kw):
+ for k, v in kw.items():
+ if v != self._persistence_only[k]:
+ # we are warning here rather than warn deprecated as this is a
+ # configuration mistake, and Python shows regular warnings more
+ # aggressively than deprecation warnings by default. Unlike the
+ # case of setting viewonly with cascade, the settings being
+ # warned about here are not actively doing the wrong thing
+ # against viewonly=True, so it is not as urgent to have these
+ # raise an error.
+ util.warn(
+ "Setting %s on relationship() while also "
+ "setting viewonly=True does not make sense, as a "
+ "viewonly=True relationship does not perform persistence "
+ "operations. This configuration may raise an error "
+ "in a future release." % (k,)
+ )
+
+ def instrument_class(self, mapper):
+ attributes.register_descriptor(
+ mapper.class_,
+ self.key,
+ comparator=self.comparator_factory(self, mapper),
+ parententity=mapper,
+ doc=self.doc,
+ )
+
+ class Comparator(PropComparator):
+ """Produce boolean, comparison, and other operators for
+ :class:`.RelationshipProperty` attributes.
+
+ See the documentation for :class:`.PropComparator` for a brief
+ overview of ORM level operator definition.
+
+ .. seealso::
+
+ :class:`.PropComparator`
+
+ :class:`.ColumnProperty.Comparator`
+
+ :class:`.ColumnOperators`
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ """
+
+ _of_type = None
+ _extra_criteria = ()
+
+ def __init__(
+ self,
+ prop,
+ parentmapper,
+ adapt_to_entity=None,
+ of_type=None,
+ extra_criteria=(),
+ ):
+ """Construction of :class:`.RelationshipProperty.Comparator`
+ is internal to the ORM's attribute mechanics.
+
+ """
+ self.prop = prop
+ self._parententity = parentmapper
+ self._adapt_to_entity = adapt_to_entity
+ if of_type:
+ self._of_type = of_type
+ self._extra_criteria = extra_criteria
+
+ def adapt_to_entity(self, adapt_to_entity):
+ return self.__class__(
+ self.property,
+ self._parententity,
+ adapt_to_entity=adapt_to_entity,
+ of_type=self._of_type,
+ )
+
+ @util.memoized_property
+ def entity(self):
+ """The target entity referred to by this
+ :class:`.RelationshipProperty.Comparator`.
+
+ This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp`
+ object.
+
+ This is the "target" or "remote" side of the
+ :func:`_orm.relationship`.
+
+ """
+ # this is a relatively recent change made for
+ # 1.4.27 as part of #7244.
+ # TODO: shouldn't _of_type be inspected up front when received?
+ if self._of_type is not None:
+ return inspect(self._of_type)
+ else:
+ return self.property.entity
+
+ @util.memoized_property
+ def mapper(self):
+ """The target :class:`_orm.Mapper` referred to by this
+ :class:`.RelationshipProperty.Comparator`.
+
+ This is the "target" or "remote" side of the
+ :func:`_orm.relationship`.
+
+ """
+ return self.property.mapper
+
+ @util.memoized_property
+ def _parententity(self):
+ return self.property.parent
+
+ def _source_selectable(self):
+ if self._adapt_to_entity:
+ return self._adapt_to_entity.selectable
+ else:
+ return self.property.parent._with_polymorphic_selectable
+
+ def __clause_element__(self):
+ adapt_from = self._source_selectable()
+ if self._of_type:
+ of_type_entity = inspect(self._of_type)
+ else:
+ of_type_entity = None
+
+ (
+ pj,
+ sj,
+ source,
+ dest,
+ secondary,
+ target_adapter,
+ ) = self.property._create_joins(
+ source_selectable=adapt_from,
+ source_polymorphic=True,
+ of_type_entity=of_type_entity,
+ alias_secondary=True,
+ extra_criteria=self._extra_criteria,
+ )
+ if sj is not None:
+ return pj & sj
+ else:
+ return pj
+
+ def of_type(self, cls):
+ r"""Redefine this object in terms of a polymorphic subclass.
+
+ See :meth:`.PropComparator.of_type` for an example.
+
+
+ """
+ return RelationshipProperty.Comparator(
+ self.property,
+ self._parententity,
+ adapt_to_entity=self._adapt_to_entity,
+ of_type=cls,
+ extra_criteria=self._extra_criteria,
+ )
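+ # Usage sketch, assuming a polymorphic Company/Engineer mapping
+ # (hypothetical names): narrow the comparator to a subclass before
+ # applying criteria against subclass attributes, e.g.
+ #
+ #   session.query(Company).join(
+ #       Company.employees.of_type(Engineer)
+ #   ).filter(Engineer.primary_language == "python")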
+
+ def and_(self, *other):
+ """Add AND criteria.
+
+ See :meth:`.PropComparator.and_` for an example.
+
+ .. versionadded:: 1.4
+
+ """
+ return RelationshipProperty.Comparator(
+ self.property,
+ self._parententity,
+ adapt_to_entity=self._adapt_to_entity,
+ of_type=self._of_type,
+ extra_criteria=self._extra_criteria + other,
+ )
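+ # Usage sketch (SQLAlchemy 1.4+), reusing the hypothetical MyClass /
+ # SomeRelated names from the surrounding docstrings: AND additional
+ # criteria into the relationship's join condition, e.g.
+ #
+ #   session.query(MyClass).join(
+ #       MyClass.somereference.and_(SomeRelated.x == 2)
+ #   )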
+
+ def in_(self, other):
+ """Produce an IN clause - this is not implemented
+ for :func:`_orm.relationship`-based attributes at this time.
+
+ """
+ raise NotImplementedError(
+ "in_() not yet supported for "
+ "relationships. For a simple "
+ "many-to-one, use in_() against "
+ "the set of foreign key values."
+ )
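+ # Workaround sketch for the message above (hypothetical Address mapping
+ # with a user_id foreign key column): for a simple many-to-one, apply
+ # in_() to the foreign key column instead of the relationship, e.g.
+ #
+ #   session.query(Address).filter(Address.user_id.in_([1, 2, 3]))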
+
+ __hash__ = None
+
+ def __eq__(self, other):
+ """Implement the ``==`` operator.
+
+ In a many-to-one context, such as::
+
+ MyClass.some_prop == <some object>
+
+ this will typically produce a
+ clause such as::
+
+ mytable.related_id == <some id>
+
+ Where ``<some id>`` is the primary key of the given
+ object.
+
+ The ``==`` operator provides partial functionality for non-
+ many-to-one comparisons:
+
+ * Comparisons against collections are not supported.
+ Use :meth:`~.RelationshipProperty.Comparator.contains`.
+ * Compared to a scalar one-to-many, will produce a
+ clause that compares the target columns in the parent to
+ the given target.
+ * Compared to a scalar many-to-many, an alias
+ of the association table will be rendered as
+ well, forming a natural join that is part of the
+ main body of the query. This will not work for
+ queries that go beyond simple AND conjunctions of
+ comparisons, such as those which use OR. Use
+ explicit joins, outerjoins, or
+ :meth:`~.RelationshipProperty.Comparator.has` for
+ more comprehensive non-many-to-one scalar
+ membership tests.
+ * Comparisons against ``None`` given in a one-to-many
+ or many-to-many context produce a NOT EXISTS clause.
+
+ """
+ if isinstance(other, (util.NoneType, expression.Null)):
+ if self.property.direction in [ONETOMANY, MANYTOMANY]:
+ return ~self._criterion_exists()
+ else:
+ return _orm_annotate(
+ self.property._optimized_compare(
+ None, adapt_source=self.adapter
+ )
+ )
+ elif self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "Can't compare a collection to an object or collection; "
+ "use contains() to test for membership."
+ )
+ else:
+ return _orm_annotate(
+ self.property._optimized_compare(
+ other, adapt_source=self.adapter
+ )
+ )
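+ # Behavior sketch following the docstring above (hypothetical names):
+ #
+ #   MyClass.some_prop == some_obj      # many-to-one: related_id = :id
+ #   MyClass.somereference == None      # one-to-many / many-to-many:
+ #                                      #   NOT EXISTS (...)
+ #   MyClass.somereference == some_obj  # collection: raises
+ #                                      #   InvalidRequestError; use contains()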
+
+ def _criterion_exists(self, criterion=None, **kwargs):
+ if getattr(self, "_of_type", None):
+ info = inspect(self._of_type)
+ target_mapper, to_selectable, is_aliased_class = (
+ info.mapper,
+ info.selectable,
+ info.is_aliased_class,
+ )
+ if self.property._is_self_referential and not is_aliased_class:
+ to_selectable = to_selectable._anonymous_fromclause()
+
+ single_crit = target_mapper._single_table_criterion
+ if single_crit is not None:
+ if criterion is not None:
+ criterion = single_crit & criterion
+ else:
+ criterion = single_crit
+ else:
+ is_aliased_class = False
+ to_selectable = None
+
+ if self.adapter:
+ source_selectable = self._source_selectable()
+ else:
+ source_selectable = None
+
+ (
+ pj,
+ sj,
+ source,
+ dest,
+ secondary,
+ target_adapter,
+ ) = self.property._create_joins(
+ dest_selectable=to_selectable,
+ source_selectable=source_selectable,
+ )
+
+ for k in kwargs:
+ crit = getattr(self.property.mapper.class_, k) == kwargs[k]
+ if criterion is None:
+ criterion = crit
+ else:
+ criterion = criterion & crit
+
+ # annotate the *local* side of the join condition; in the case
+ # of pj + sj this is the full primaryjoin, in the case of just
+ # pj it's the local side of the primaryjoin.
+ if sj is not None:
+ j = _orm_annotate(pj) & sj
+ else:
+ j = _orm_annotate(pj, exclude=self.property.remote_side)
+
+ if (
+ criterion is not None
+ and target_adapter
+ and not is_aliased_class
+ ):
+ # limit this adapter to annotated only?
+ criterion = target_adapter.traverse(criterion)
+
+ # only the "joined left side" of what we return is
+ # subject to Query adaptation. The right side of it
+ # is used for an exists() subquery and should not
+ # correlate or otherwise reach out to anything in
+ # the enclosing query.
+ if criterion is not None:
+ criterion = criterion._annotate(
+ {"no_replacement_traverse": True}
+ )
+
+ crit = j & sql.True_._ifnone(criterion)
+
+ if secondary is not None:
+ ex = (
+ sql.exists(1)
+ .where(crit)
+ .select_from(dest, secondary)
+ .correlate_except(dest, secondary)
+ )
+ else:
+ ex = (
+ sql.exists(1)
+ .where(crit)
+ .select_from(dest)
+ .correlate_except(dest)
+ )
+ return ex
+
+ def any(self, criterion=None, **kwargs):
+ """Produce an expression that tests a collection against
+ particular criterion, using EXISTS.
+
+ An expression like::
+
+ session.query(MyClass).filter(
+ MyClass.somereference.any(SomeRelated.x==2)
+ )
+
+
+ Will produce a query like::
+
+ SELECT * FROM my_table WHERE
+ EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id
+ AND related.x=2)
+
+ Because :meth:`~.RelationshipProperty.Comparator.any` uses
+ a correlated subquery, its performance is not nearly as
+ good when compared against large target tables as that of
+ using a join.
+
+ :meth:`~.RelationshipProperty.Comparator.any` is particularly
+ useful for testing for empty collections::
+
+ session.query(MyClass).filter(
+ ~MyClass.somereference.any()
+ )
+
+ will produce::
+
+ SELECT * FROM my_table WHERE
+ NOT (EXISTS (SELECT 1 FROM related WHERE
+ related.my_id=my_table.id))
+
+ :meth:`~.RelationshipProperty.Comparator.any` is only
+ valid for collections, i.e. a :func:`_orm.relationship`
+ that has ``uselist=True``. For scalar references,
+ use :meth:`~.RelationshipProperty.Comparator.has`.
+
+ """
+ if not self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "'any()' not implemented for scalar "
+ "attributes. Use has()."
+ )
+
+ return self._criterion_exists(criterion, **kwargs)
+
+ def has(self, criterion=None, **kwargs):
+ """Produce an expression that tests a scalar reference against
+ particular criterion, using EXISTS.
+
+ An expression like::
+
+ session.query(MyClass).filter(
+ MyClass.somereference.has(SomeRelated.x==2)
+ )
+
+
+ Will produce a query like::
+
+ SELECT * FROM my_table WHERE
+ EXISTS (SELECT 1 FROM related WHERE
+ related.id==my_table.related_id AND related.x=2)
+
+ Because :meth:`~.RelationshipProperty.Comparator.has` uses
+ a correlated subquery, its performance is not nearly as
+ good when compared against large target tables as that of
+ using a join.
+
+ :meth:`~.RelationshipProperty.Comparator.has` is only
+ valid for scalar references, i.e. a :func:`_orm.relationship`
+ that has ``uselist=False``. For collection references,
+ use :meth:`~.RelationshipProperty.Comparator.any`.
+
+ """
+ if self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "'has()' not implemented for collections. " "Use any()."
+ )
+ return self._criterion_exists(criterion, **kwargs)
+
+ def contains(self, other, **kwargs):
+ """Return a simple expression that tests a collection for
+ containment of a particular item.
+
+ :meth:`~.RelationshipProperty.Comparator.contains` is
+ only valid for a collection, i.e. a
+ :func:`_orm.relationship` that implements
+ one-to-many or many-to-many with ``uselist=True``.
+
+ When used in a simple one-to-many context, an
+ expression like::
+
+ MyClass.somereference.contains(other)
+
+ Produces a clause like::
+
+ mytable.id == <some id>
+
+ Where ``<some id>`` is the value of the foreign key
+ attribute on ``other`` which refers to the primary
+ key of its parent object. From this it follows that
+ :meth:`~.RelationshipProperty.Comparator.contains` is
+ very useful when used with simple one-to-many
+ operations.
+
+ For many-to-many operations, the behavior of
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ has more caveats. The association table will be
+ rendered in the statement, producing an "implicit"
+ join, that is, includes multiple tables in the FROM
+ clause which are equated in the WHERE clause::
+
+ query(MyClass).filter(MyClass.somereference.contains(other))
+
+ Produces a query like::
+
+ SELECT * FROM my_table, my_association_table AS
+ my_association_table_1 WHERE
+ my_table.id = my_association_table_1.parent_id
+ AND my_association_table_1.child_id = <some id>
+
+ Where ``<some id>`` would be the primary key of
+ ``other``. From the above, it is clear that
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ will **not** work with many-to-many collections when
+ used in queries that move beyond simple AND
+ conjunctions, such as multiple
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ expressions joined by OR. In such cases subqueries or
+ explicit "outer joins" will need to be used instead.
+ See :meth:`~.RelationshipProperty.Comparator.any` for
+ a less-performant alternative using EXISTS, or refer
+ to :meth:`_query.Query.outerjoin`
+ as well as :ref:`orm_queryguide_joins`
+ for more details on constructing outer joins.
+
+ kwargs may be ignored by this operator but are required for API
+ conformance.
+ """
+ if not self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "'contains' not implemented for scalar "
+ "attributes. Use =="
+ )
+ clause = self.property._optimized_compare(
+ other, adapt_source=self.adapter
+ )
+
+ if self.property.secondaryjoin is not None:
+ clause.negation_clause = self.__negated_contains_or_equals(
+ other
+ )
+
+ return clause
+
+ def __negated_contains_or_equals(self, other):
+ if self.property.direction == MANYTOONE:
+ state = attributes.instance_state(other)
+
+ def state_bindparam(local_col, state, remote_col):
+ dict_ = state.dict
+ return sql.bindparam(
+ local_col.key,
+ type_=local_col.type,
+ unique=True,
+ callable_=self.property._get_attr_w_warn_on_none(
+ self.property.mapper, state, dict_, remote_col
+ ),
+ )
+
+ def adapt(col):
+ if self.adapter:
+ return self.adapter(col)
+ else:
+ return col
+
+ if self.property._use_get:
+ return sql.and_(
+ *[
+ sql.or_(
+ adapt(x)
+ != state_bindparam(adapt(x), state, y),
+ adapt(x) == None,
+ )
+ for (x, y) in self.property.local_remote_pairs
+ ]
+ )
+
+ criterion = sql.and_(
+ *[
+ x == y
+ for (x, y) in zip(
+ self.property.mapper.primary_key,
+ self.property.mapper.primary_key_from_instance(other),
+ )
+ ]
+ )
+
+ return ~self._criterion_exists(criterion)
+
+ def __ne__(self, other):
+ """Implement the ``!=`` operator.
+
+ In a many-to-one context, such as::
+
+ MyClass.some_prop != <some object>
+
+ This will typically produce a clause such as::
+
+ mytable.related_id != <some id>
+
+ Where ``<some id>`` is the primary key of the
+ given object.
+
+ The ``!=`` operator provides partial functionality for non-
+ many-to-one comparisons:
+
+ * Comparisons against collections are not supported.
+ Use
+ :meth:`~.RelationshipProperty.Comparator.contains`
+ in conjunction with :func:`_expression.not_`.
+ * Compared to a scalar one-to-many, will produce a
+ clause that compares the target columns in the parent to
+ the given target.
+ * Compared to a scalar many-to-many, an alias
+ of the association table will be rendered as
+ well, forming a natural join that is part of the
+ main body of the query. This will not work for
+ queries that go beyond simple AND conjunctions of
+ comparisons, such as those which use OR. Use
+ explicit joins, outerjoins, or
+ :meth:`~.RelationshipProperty.Comparator.has` in
+ conjunction with :func:`_expression.not_` for
+ more comprehensive non-many-to-one scalar
+ membership tests.
+ * Comparisons against ``None`` given in a one-to-many
+ or many-to-many context produce an EXISTS clause.
+
+ """
+ if isinstance(other, (util.NoneType, expression.Null)):
+ if self.property.direction == MANYTOONE:
+ return _orm_annotate(
+ ~self.property._optimized_compare(
+ None, adapt_source=self.adapter
+ )
+ )
+
+ else:
+ return self._criterion_exists()
+ elif self.property.uselist:
+ raise sa_exc.InvalidRequestError(
+ "Can't compare a collection"
+ " to an object or collection; use "
+ "contains() to test for membership."
+ )
+ else:
+ return _orm_annotate(self.__negated_contains_or_equals(other))
+
+ @util.memoized_property
+ def property(self):
+ self.prop.parent._check_configure()
+ return self.prop
+
+ def _with_parent(self, instance, alias_secondary=True, from_entity=None):
+ assert instance is not None
+ adapt_source = None
+ if from_entity is not None:
+ insp = inspect(from_entity)
+ if insp.is_aliased_class:
+ adapt_source = insp._adapter.adapt_clause
+ return self._optimized_compare(
+ instance,
+ value_is_parent=True,
+ adapt_source=adapt_source,
+ alias_secondary=alias_secondary,
+ )
+
+ def _optimized_compare(
+ self,
+ state,
+ value_is_parent=False,
+ adapt_source=None,
+ alias_secondary=True,
+ ):
+ if state is not None:
+ try:
+ state = inspect(state)
+ except sa_exc.NoInspectionAvailable:
+ state = None
+
+ if state is None or not getattr(state, "is_instance", False):
+ raise sa_exc.ArgumentError(
+ "Mapped instance expected for relationship "
+ "comparison to object. Classes, queries and other "
+ "SQL elements are not accepted in this context; for "
+ "comparison with a subquery, "
+ "use %s.has(**criteria)." % self
+ )
+ reverse_direction = not value_is_parent
+
+ if state is None:
+ return self._lazy_none_clause(
+ reverse_direction, adapt_source=adapt_source
+ )
+
+ if not reverse_direction:
+ criterion, bind_to_col = (
+ self._lazy_strategy._lazywhere,
+ self._lazy_strategy._bind_to_col,
+ )
+ else:
+ criterion, bind_to_col = (
+ self._lazy_strategy._rev_lazywhere,
+ self._lazy_strategy._rev_bind_to_col,
+ )
+
+ if reverse_direction:
+ mapper = self.mapper
+ else:
+ mapper = self.parent
+
+ dict_ = attributes.instance_dict(state.obj())
+
+ def visit_bindparam(bindparam):
+ if bindparam._identifying_key in bind_to_col:
+ bindparam.callable = self._get_attr_w_warn_on_none(
+ mapper,
+ state,
+ dict_,
+ bind_to_col[bindparam._identifying_key],
+ )
+
+ if self.secondary is not None and alias_secondary:
+ criterion = ClauseAdapter(
+ self.secondary._anonymous_fromclause()
+ ).traverse(criterion)
+
+ criterion = visitors.cloned_traverse(
+ criterion, {}, {"bindparam": visit_bindparam}
+ )
+
+ if adapt_source:
+ criterion = adapt_source(criterion)
+ return criterion
+
+ def _get_attr_w_warn_on_none(self, mapper, state, dict_, column):
+ """Create the callable that is used in a many-to-one expression.
+
+ E.g.::
+
+ u1 = s.query(User).get(5)
+
+ expr = Address.user == u1
+
+ Above, the SQL should be "address.user_id = 5". The callable
+ returned by this method produces the value "5" based on the identity
+ of ``u1``.
+
+ """
+
+ # in this callable, we're trying to thread the needle through
+ # a wide variety of scenarios, including:
+ #
+ # * the object hasn't been flushed yet and there's no value for
+ # the attribute as of yet
+ #
+ # * the object hasn't been flushed yet but it has a user-defined
+ # value
+ #
+ # * the object has a value but it's expired and not locally present
+ #
+ # * the object has a value but it's expired and not locally present,
+ # and the object is also detached
+ #
+ # * The object hadn't been flushed yet, there was no value, but
+ # later, the object has been expired and detached, and *now*
+ # they're trying to evaluate it
+ #
+ # * the object had a value, but it was changed to a new value, and
+ # then expired
+ #
+ # * the object had a value, but it was changed to a new value, and
+ # then expired, then the object was detached
+ #
+ # * the object has a user-set value, but it's None and we don't do
+ # the comparison correctly for that so warn
+ #
+
+ prop = mapper.get_property_by_column(column)
+
+ # by invoking this method, InstanceState will track the last known
+ # value for this key each time the attribute is to be expired.
+ # this feature was added explicitly for use in this method.
+ state._track_last_known_value(prop.key)
+
+ def _go():
+ last_known = to_return = state._last_known_values[prop.key]
+ existing_is_available = last_known is not attributes.NO_VALUE
+
+ # we support that the value may have changed, so here we
+ # try to get the most recent value, including re-fetching.
+ # only if we can't get a value now due to detachment do we
+ # return the last known value.
+ current_value = mapper._get_state_attr_by_column(
+ state,
+ dict_,
+ column,
+ passive=attributes.PASSIVE_OFF
+ if state.persistent
+ else attributes.PASSIVE_NO_FETCH ^ attributes.INIT_OK,
+ )
+
+ if current_value is attributes.NEVER_SET:
+ if not existing_is_available:
+ raise sa_exc.InvalidRequestError(
+ "Can't resolve value for column %s on object "
+ "%s; no value has been set for this column"
+ % (column, state_str(state))
+ )
+ elif current_value is attributes.PASSIVE_NO_RESULT:
+ if not existing_is_available:
+ raise sa_exc.InvalidRequestError(
+ "Can't resolve value for column %s on object "
+ "%s; the object is detached and the value was "
+ "expired" % (column, state_str(state))
+ )
+ else:
+ to_return = current_value
+ if to_return is None:
+ util.warn(
+ "Got None for value of column %s; this is unsupported "
+ "for a relationship comparison and will not "
+ "currently produce an IS comparison "
+ "(but may in a future release)" % column
+ )
+ return to_return
+
+ return _go
+
+ def _lazy_none_clause(self, reverse_direction=False, adapt_source=None):
+ if not reverse_direction:
+ criterion, bind_to_col = (
+ self._lazy_strategy._lazywhere,
+ self._lazy_strategy._bind_to_col,
+ )
+ else:
+ criterion, bind_to_col = (
+ self._lazy_strategy._rev_lazywhere,
+ self._lazy_strategy._rev_bind_to_col,
+ )
+
+ criterion = adapt_criterion_to_null(criterion, bind_to_col)
+
+ if adapt_source:
+ criterion = adapt_source(criterion)
+ return criterion
+
+ def __str__(self):
+ return str(self.parent.class_.__name__) + "." + self.key
+
+ def merge(
+ self,
+ session,
+ source_state,
+ source_dict,
+ dest_state,
+ dest_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ ):
+
+ if load:
+ for r in self._reverse_property:
+ if (source_state, r) in _recursive:
+ return
+
+ if "merge" not in self._cascade:
+ return
+
+ if self.key not in source_dict:
+ return
+
+ if self.uselist:
+ impl = source_state.get_impl(self.key)
+ instances_iterable = impl.get_collection(source_state, source_dict)
+
+ # if this is a CollectionAttributeImpl, then empty should
+ # be False, otherwise "self.key in source_dict" should not be
+ # True
+ assert not instances_iterable.empty if impl.collection else True
+
+ if load:
+ # for a full merge, pre-load the destination collection,
+ # so that individual _merge of each item pulls from identity
+ # map for those already present.
+ # also assumes CollectionAttributeImpl behavior of loading
+ # "old" list in any case
+ dest_state.get_impl(self.key).get(dest_state, dest_dict)
+
+ dest_list = []
+ for current in instances_iterable:
+ current_state = attributes.instance_state(current)
+ current_dict = attributes.instance_dict(current)
+ _recursive[(current_state, self)] = True
+ obj = session._merge(
+ current_state,
+ current_dict,
+ load=load,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
+ if obj is not None:
+ dest_list.append(obj)
+
+ if not load:
+ coll = attributes.init_state_collection(
+ dest_state, dest_dict, self.key
+ )
+ for c in dest_list:
+ coll.append_without_event(c)
+ else:
+ dest_state.get_impl(self.key).set(
+ dest_state, dest_dict, dest_list, _adapt=False
+ )
+ else:
+ current = source_dict[self.key]
+ if current is not None:
+ current_state = attributes.instance_state(current)
+ current_dict = attributes.instance_dict(current)
+ _recursive[(current_state, self)] = True
+ obj = session._merge(
+ current_state,
+ current_dict,
+ load=load,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
+ else:
+ obj = None
+
+ if not load:
+ dest_dict[self.key] = obj
+ else:
+ dest_state.get_impl(self.key).set(
+ dest_state, dest_dict, obj, None
+ )
+
+ def _value_as_iterable(
+ self, state, dict_, key, passive=attributes.PASSIVE_OFF
+ ):
+ """Return a list of tuples (state, obj) for the given
+ key.
+
+ Returns an empty list if the value is None/empty/PASSIVE_NO_RESULT.
+ """
+
+ impl = state.manager[key].impl
+ x = impl.get(state, dict_, passive=passive)
+ if x is attributes.PASSIVE_NO_RESULT or x is None:
+ return []
+ elif hasattr(impl, "get_collection"):
+ return [
+ (attributes.instance_state(o), o)
+ for o in impl.get_collection(state, dict_, x, passive=passive)
+ ]
+ else:
+ return [(attributes.instance_state(x), x)]
+
+ def cascade_iterator(
+ self, type_, state, dict_, visited_states, halt_on=None
+ ):
+ # assert type_ in self._cascade
+
+ # only actively lazy load on the 'delete' cascade
+ if type_ != "delete" or self.passive_deletes:
+ passive = attributes.PASSIVE_NO_INITIALIZE
+ else:
+ passive = attributes.PASSIVE_OFF
+
+ if type_ == "save-update":
+ tuples = state.manager[self.key].impl.get_all_pending(state, dict_)
+
+ else:
+ tuples = self._value_as_iterable(
+ state, dict_, self.key, passive=passive
+ )
+
+ skip_pending = (
+ type_ == "refresh-expire" and "delete-orphan" not in self._cascade
+ )
+
+ for instance_state, c in tuples:
+ if instance_state in visited_states:
+ continue
+
+ if c is None:
+ # we would like to emit a warning here, but that would
+ # not be consistent with the current behavior of
+ # collection.append(None), which silently skips None.
+ # see [ticket:2229]
+ continue
+
+ instance_dict = attributes.instance_dict(c)
+
+ if halt_on and halt_on(instance_state):
+ continue
+
+ if skip_pending and not instance_state.key:
+ continue
+
+ instance_mapper = instance_state.manager.mapper
+
+ if not instance_mapper.isa(self.mapper.class_manager.mapper):
+ raise AssertionError(
+ "Attribute '%s' on class '%s' "
+ "doesn't handle objects "
+ "of type '%s'"
+ % (self.key, self.parent.class_, c.__class__)
+ )
+
+ visited_states.add(instance_state)
+
+ yield c, instance_mapper, instance_state, instance_dict
+
+ @property
+ def _effective_sync_backref(self):
+ if self.viewonly:
+ return False
+ else:
+ return self.sync_backref is not False
+
+ @staticmethod
+ def _check_sync_backref(rel_a, rel_b):
+ if rel_a.viewonly and rel_b.sync_backref:
+ raise sa_exc.InvalidRequestError(
+ "Relationship %s cannot specify sync_backref=True since %s "
+ "includes viewonly=True." % (rel_b, rel_a)
+ )
+ if (
+ rel_a.viewonly
+ and not rel_b.viewonly
+ and rel_b.sync_backref is not False
+ ):
+ rel_b.sync_backref = False
+
+ def _add_reverse_property(self, key):
+ other = self.mapper.get_property(key, _configure_mappers=False)
+ if not isinstance(other, RelationshipProperty):
+ raise sa_exc.InvalidRequestError(
+ "back_populates on relationship '%s' refers to attribute '%s' "
+ "that is not a relationship. The back_populates parameter "
+ "should refer to the name of a relationship on the target "
+ "class." % (self, other)
+ )
+ # viewonly and sync_backref cases
+ # 1. self.viewonly==True and other.sync_backref==True -> error
+ # 2. self.viewonly==True and other.viewonly==False and
+ # other.sync_backref==None -> warn sync_backref=False, set to False
+ self._check_sync_backref(self, other)
+ # 3. other.viewonly==True and self.sync_backref==True -> error
+ # 4. other.viewonly==True and self.viewonly==False and
+ # self.sync_backref==None -> warn sync_backref=False, set to False
+ self._check_sync_backref(other, self)
+
+ self._reverse_property.add(other)
+ other._reverse_property.add(self)
+
+ if not other.mapper.common_parent(self.parent):
+ raise sa_exc.ArgumentError(
+ "reverse_property %r on "
+ "relationship %s references relationship %s, which "
+ "does not reference mapper %s"
+ % (key, self, other, self.parent)
+ )
+
+ if (
+ self.direction in (ONETOMANY, MANYTOONE)
+ and self.direction == other.direction
+ ):
+ raise sa_exc.ArgumentError(
+ "%s and back-reference %s are "
+ "both of the same direction %r. Did you mean to "
+ "set remote_side on the many-to-one side ?"
+ % (other, self, self.direction)
+ )
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.orm.mapper")
+ def entity(self):
+ """Return the target mapped entity, which is an inspect() of the
+ class or aliased class that is referred to.
+
+ """
+
+ mapperlib = util.preloaded.orm_mapper
+
+ if isinstance(self.argument, util.string_types):
+ argument = self._clsregistry_resolve_name(self.argument)()
+
+ elif callable(self.argument) and not isinstance(
+ self.argument, (type, mapperlib.Mapper)
+ ):
+ argument = self.argument()
+ else:
+ argument = self.argument
+
+ if isinstance(argument, type):
+ return mapperlib.class_mapper(argument, configure=False)
+
+ try:
+ entity = inspect(argument)
+ except sa_exc.NoInspectionAvailable:
+ pass
+ else:
+ if hasattr(entity, "mapper"):
+ return entity
+
+ raise sa_exc.ArgumentError(
+ "relationship '%s' expects "
+ "a class or a mapper argument (received: %s)"
+ % (self.key, type(argument))
+ )
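+ # Argument-form sketch for the resolution above (hypothetical Address
+ # class): the relationship target may be given as a mapped class, a
+ # string name resolved against the class registry, or a callable, e.g.
+ #
+ #   relationship(Address)
+ #   relationship("Address")
+ #   relationship(lambda: Address)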
+
+ @util.memoized_property
+ def mapper(self):
+ """Return the targeted :class:`_orm.Mapper` for this
+ :class:`.RelationshipProperty`.
+
+ This is a lazy-initializing static attribute.
+
+ """
+ return self.entity.mapper
+
+ def do_init(self):
+ self._check_conflicts()
+ self._process_dependent_arguments()
+ self._setup_registry_dependencies()
+ self._setup_join_conditions()
+ self._check_cascade_settings(self._cascade)
+ self._post_init()
+ self._generate_backref()
+ self._join_condition._warn_for_conflicting_sync_targets()
+ super(RelationshipProperty, self).do_init()
+ self._lazy_strategy = self._get_strategy((("lazy", "select"),))
+
+ def _setup_registry_dependencies(self):
+ self.parent.mapper.registry._set_depends_on(
+ self.entity.mapper.registry
+ )
+
+ def _process_dependent_arguments(self):
+ """Convert incoming configuration arguments to their
+ proper form.
+
+ Callables are resolved, ORM annotations removed.
+
+ """
+
+ # accept callables for other attributes which may require
+ # deferred initialization. This technique is used
+ # by declarative "string configs" and some recipes.
+ for attr in (
+ "order_by",
+ "primaryjoin",
+ "secondaryjoin",
+ "secondary",
+ "_user_defined_foreign_keys",
+ "remote_side",
+ ):
+ attr_value = getattr(self, attr)
+
+ if isinstance(attr_value, util.string_types):
+ setattr(
+ self,
+ attr,
+ self._clsregistry_resolve_arg(
+ attr_value, favor_tables=attr == "secondary"
+ )(),
+ )
+ elif callable(attr_value) and not _is_mapped_class(attr_value):
+ setattr(self, attr, attr_value())
+
+ # remove "annotations" which are present if mapped class
+ # descriptors are used to create the join expression.
+ for attr in "primaryjoin", "secondaryjoin":
+ val = getattr(self, attr)
+ if val is not None:
+ setattr(
+ self,
+ attr,
+ _orm_deannotate(
+ coercions.expect(
+ roles.ColumnArgumentRole, val, argname=attr
+ )
+ ),
+ )
+
+ if self.secondary is not None and _is_mapped_class(self.secondary):
+ raise sa_exc.ArgumentError(
+ "secondary argument %s passed to to relationship() %s must "
+ "be a Table object or other FROM clause; can't send a mapped "
+ "class directly as rows in 'secondary' are persisted "
+ "independently of a class that is mapped "
+ "to that same table." % (self.secondary, self)
+ )
+
+ # ensure expressions in self.order_by, foreign_keys,
+ # remote_side are all columns, not strings.
+ if self.order_by is not False and self.order_by is not None:
+ self.order_by = tuple(
+ coercions.expect(
+ roles.ColumnArgumentRole, x, argname="order_by"
+ )
+ for x in util.to_list(self.order_by)
+ )
+
+ self._user_defined_foreign_keys = util.column_set(
+ coercions.expect(
+ roles.ColumnArgumentRole, x, argname="foreign_keys"
+ )
+ for x in util.to_column_set(self._user_defined_foreign_keys)
+ )
+
+ self.remote_side = util.column_set(
+ coercions.expect(
+ roles.ColumnArgumentRole, x, argname="remote_side"
+ )
+ for x in util.to_column_set(self.remote_side)
+ )
+
+ self.target = self.entity.persist_selectable
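+ # Deferred-configuration sketch (hypothetical User/Address mapping):
+ # string and lambda values for arguments such as primaryjoin, order_by
+ # and foreign_keys are resolved by the loop above at mapper
+ # configuration time, e.g.
+ #
+ #   addresses = relationship(
+ #       "Address",
+ #       primaryjoin="User.id == Address.user_id",
+ #       order_by="Address.email_address",
+ #   )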
+
+ def _setup_join_conditions(self):
+ self._join_condition = jc = JoinCondition(
+ parent_persist_selectable=self.parent.persist_selectable,
+ child_persist_selectable=self.entity.persist_selectable,
+ parent_local_selectable=self.parent.local_table,
+ child_local_selectable=self.entity.local_table,
+ primaryjoin=self.primaryjoin,
+ secondary=self.secondary,
+ secondaryjoin=self.secondaryjoin,
+ parent_equivalents=self.parent._equivalent_columns,
+ child_equivalents=self.mapper._equivalent_columns,
+ consider_as_foreign_keys=self._user_defined_foreign_keys,
+ local_remote_pairs=self.local_remote_pairs,
+ remote_side=self.remote_side,
+ self_referential=self._is_self_referential,
+ prop=self,
+ support_sync=not self.viewonly,
+ can_be_synced_fn=self._columns_are_mapped,
+ )
+ self.primaryjoin = jc.primaryjoin
+ self.secondaryjoin = jc.secondaryjoin
+ self.direction = jc.direction
+ self.local_remote_pairs = jc.local_remote_pairs
+ self.remote_side = jc.remote_columns
+ self.local_columns = jc.local_columns
+ self.synchronize_pairs = jc.synchronize_pairs
+ self._calculated_foreign_keys = jc.foreign_key_columns
+ self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs
+
+ @property
+ def _clsregistry_resolve_arg(self):
+ return self._clsregistry_resolvers[1]
+
+ @property
+ def _clsregistry_resolve_name(self):
+ return self._clsregistry_resolvers[0]
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.orm.clsregistry")
+ def _clsregistry_resolvers(self):
+ _resolver = util.preloaded.orm_clsregistry._resolver
+
+ return _resolver(self.parent.class_, self)
+
+ @util.preload_module("sqlalchemy.orm.mapper")
+ def _check_conflicts(self):
+ """Test that this relationship is legal, warn about
+ inheritance conflicts."""
+ mapperlib = util.preloaded.orm_mapper
+ if self.parent.non_primary and not mapperlib.class_mapper(
+ self.parent.class_, configure=False
+ ).has_property(self.key):
+ raise sa_exc.ArgumentError(
+ "Attempting to assign a new "
+ "relationship '%s' to a non-primary mapper on "
+ "class '%s'. New relationships can only be added "
+ "to the primary mapper, i.e. the very first mapper "
+ "created for class '%s' "
+ % (
+ self.key,
+ self.parent.class_.__name__,
+ self.parent.class_.__name__,
+ )
+ )
+
+ @property
+ def cascade(self):
+ """Return the current cascade setting for this
+ :class:`.RelationshipProperty`.
+ """
+ return self._cascade
+
+ @cascade.setter
+ def cascade(self, cascade):
+ self._set_cascade(cascade)
+
+ def _set_cascade(self, cascade):
+ cascade = CascadeOptions(cascade)
+
+ if self.viewonly:
+ non_viewonly = set(cascade).difference(
+ CascadeOptions._viewonly_cascades
+ )
+ if non_viewonly:
+ raise sa_exc.ArgumentError(
+ 'Cascade settings "%s" apply to persistence operations '
+ "and should not be combined with a viewonly=True "
+ "relationship." % (", ".join(sorted(non_viewonly)))
+ )
+
+ if "mapper" in self.__dict__:
+ self._check_cascade_settings(cascade)
+ self._cascade = cascade
+
+ if self._dependency_processor:
+ self._dependency_processor.cascade = cascade
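+ # Configuration sketch for the viewonly check above (hypothetical
+ # names): persistence cascades may not be combined with viewonly, e.g.
+ #
+ #   relationship("Address", viewonly=True, cascade="all, delete-orphan")
+ #   # -> ArgumentError: Cascade settings ... should not be combined
+ #   #    with a viewonly=True relationship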
+
+ def _check_cascade_settings(self, cascade):
+ if (
+ cascade.delete_orphan
+ and not self.single_parent
+ and (self.direction is MANYTOMANY or self.direction is MANYTOONE)
+ ):
+ raise sa_exc.ArgumentError(
+ "For %(direction)s relationship %(rel)s, delete-orphan "
+ "cascade is normally "
+ 'configured only on the "one" side of a one-to-many '
+ "relationship, "
+ 'and not on the "many" side of a many-to-one or many-to-many '
+ "relationship. "
+ "To force this relationship to allow a particular "
+ '"%(relatedcls)s" object to be referred towards by only '
+ 'a single "%(clsname)s" object at a time via the '
+ "%(rel)s relationship, which "
+ "would allow "
+ "delete-orphan cascade to take place in this direction, set "
+ "the single_parent=True flag."
+ % {
+ "rel": self,
+ "direction": "many-to-one"
+ if self.direction is MANYTOONE
+ else "many-to-many",
+ "clsname": self.parent.class_.__name__,
+ "relatedcls": self.mapper.class_.__name__,
+ },
+ code="bbf0",
+ )
+
+ if self.passive_deletes == "all" and (
+ "delete" in cascade or "delete-orphan" in cascade
+ ):
+ raise sa_exc.ArgumentError(
+ "On %s, can't set passive_deletes='all' in conjunction "
+ "with 'delete' or 'delete-orphan' cascade" % self
+ )
+
+ if cascade.delete_orphan:
+ self.mapper.primary_mapper()._delete_orphans.append(
+ (self.key, self.parent.class_)
+ )
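+ # Sketch of the configuration requested by the delete-orphan error
+ # above (hypothetical many-to-one): allow the cascade on the "many"
+ # side by also constraining each child to a single parent, e.g.
+ #
+ #   parent = relationship(
+ #       "Parent", cascade="all, delete-orphan", single_parent=True
+ #   )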
+
+ def _persists_for(self, mapper):
+ """Return True if this property will persist values on behalf
+ of the given mapper.
+
+ """
+
+ return (
+ self.key in mapper.relationships
+ and mapper.relationships[self.key] is self
+ )
+
+ def _columns_are_mapped(self, *cols):
+ """Return True if all columns in the given collection are
+ mapped by the tables referenced by this :class:`.Relationship`.
+
+ """
+ for c in cols:
+ if (
+ self.secondary is not None
+ and self.secondary.c.contains_column(c)
+ ):
+ continue
+ if not self.parent.persist_selectable.c.contains_column(
+ c
+ ) and not self.target.c.contains_column(c):
+ return False
+ return True
+
+ def _generate_backref(self):
+ """Interpret the 'backref' instruction to create a
+ :func:`_orm.relationship` complementary to this one."""
+
+ if self.parent.non_primary:
+ return
+ if self.backref is not None and not self.back_populates:
+ if isinstance(self.backref, util.string_types):
+ backref_key, kwargs = self.backref, {}
+ else:
+ backref_key, kwargs = self.backref
+ mapper = self.mapper.primary_mapper()
+
+ if not mapper.concrete:
+ check = set(mapper.iterate_to_root()).union(
+ mapper.self_and_descendants
+ )
+ for m in check:
+ if m.has_property(backref_key) and not m.concrete:
+ raise sa_exc.ArgumentError(
+ "Error creating backref "
+ "'%s' on relationship '%s': property of that "
+ "name exists on mapper '%s'"
+ % (backref_key, self, m)
+ )
+
+ # determine primaryjoin/secondaryjoin for the
+ # backref. Use the one we had, so that
+ # a custom join doesn't have to be specified in
+ # both directions.
+ if self.secondary is not None:
+ # for many to many, just switch primaryjoin/
+ # secondaryjoin. use the annotated
+ # pj/sj on the _join_condition.
+ pj = kwargs.pop(
+ "primaryjoin",
+ self._join_condition.secondaryjoin_minus_local,
+ )
+ sj = kwargs.pop(
+ "secondaryjoin",
+ self._join_condition.primaryjoin_minus_local,
+ )
+ else:
+ pj = kwargs.pop(
+ "primaryjoin",
+ self._join_condition.primaryjoin_reverse_remote,
+ )
+ sj = kwargs.pop("secondaryjoin", None)
+ if sj:
+ raise sa_exc.InvalidRequestError(
+ "Can't assign 'secondaryjoin' on a backref "
+ "against a non-secondary relationship."
+ )
+
+ foreign_keys = kwargs.pop(
+ "foreign_keys", self._user_defined_foreign_keys
+ )
+ parent = self.parent.primary_mapper()
+ kwargs.setdefault("viewonly", self.viewonly)
+ kwargs.setdefault("post_update", self.post_update)
+ kwargs.setdefault("passive_updates", self.passive_updates)
+ kwargs.setdefault("sync_backref", self.sync_backref)
+ self.back_populates = backref_key
+ relationship = RelationshipProperty(
+ parent,
+ self.secondary,
+ pj,
+ sj,
+ foreign_keys=foreign_keys,
+ back_populates=self.key,
+ **kwargs
+ )
+ mapper._configure_property(backref_key, relationship)
+
+ if self.back_populates:
+ self._add_reverse_property(self.back_populates)
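+ # Backref-form sketch (hypothetical names): the backref argument may be
+ # a plain string, or a (name, kwargs) tuple as returned by
+ # sqlalchemy.orm.backref(), whose kwargs are applied to the generated
+ # reverse relationship, e.g.
+ #
+ #   addresses = relationship("Address", backref="user")
+ #   addresses = relationship(
+ #       "Address", backref=backref("user", lazy="joined")
+ #   )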
+
+ @util.preload_module("sqlalchemy.orm.dependency")
+ def _post_init(self):
+ dependency = util.preloaded.orm_dependency
+
+ if self.uselist is None:
+ self.uselist = self.direction is not MANYTOONE
+ if not self.viewonly:
+ self._dependency_processor = (
+ dependency.DependencyProcessor.from_relationship
+ )(self)
+
+ @util.memoized_property
+ def _use_get(self):
+ """memoize the 'use_get' attribute of this RelationshipLoader's
+ lazyloader."""
+
+ strategy = self._lazy_strategy
+ return strategy.use_get
+
+ @util.memoized_property
+ def _is_self_referential(self):
+ return self.mapper.common_parent(self.parent)
+
+ def _create_joins(
+ self,
+ source_polymorphic=False,
+ source_selectable=None,
+ dest_selectable=None,
+ of_type_entity=None,
+ alias_secondary=False,
+ extra_criteria=(),
+ ):
+
+ aliased = False
+
+ if alias_secondary and self.secondary is not None:
+ aliased = True
+
+ if source_selectable is None:
+ if source_polymorphic and self.parent.with_polymorphic:
+ source_selectable = self.parent._with_polymorphic_selectable
+
+ if of_type_entity:
+ dest_mapper = of_type_entity.mapper
+ if dest_selectable is None:
+ dest_selectable = of_type_entity.selectable
+ aliased = True
+ else:
+ dest_mapper = self.mapper
+
+ if dest_selectable is None:
+ dest_selectable = self.entity.selectable
+ if self.mapper.with_polymorphic:
+ aliased = True
+
+ if self._is_self_referential and source_selectable is None:
+ dest_selectable = dest_selectable._anonymous_fromclause()
+ aliased = True
+ elif (
+ dest_selectable is not self.mapper._with_polymorphic_selectable
+ or self.mapper.with_polymorphic
+ ):
+ aliased = True
+
+ single_crit = dest_mapper._single_table_criterion
+ aliased = aliased or (
+ source_selectable is not None
+ and (
+ source_selectable
+ is not self.parent._with_polymorphic_selectable
+ or source_selectable._is_subquery
+ )
+ )
+
+ (
+ primaryjoin,
+ secondaryjoin,
+ secondary,
+ target_adapter,
+ dest_selectable,
+ ) = self._join_condition.join_targets(
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit,
+ extra_criteria,
+ )
+ if source_selectable is None:
+ source_selectable = self.parent.local_table
+ if dest_selectable is None:
+ dest_selectable = self.entity.local_table
+ return (
+ primaryjoin,
+ secondaryjoin,
+ source_selectable,
+ dest_selectable,
+ secondary,
+ target_adapter,
+ )
+
+
+def _annotate_columns(element, annotations):
+ def clone(elem):
+ if isinstance(elem, expression.ColumnClause):
+ elem = elem._annotate(annotations.copy())
+ elem._copy_internals(clone=clone)
+ return elem
+
+ if element is not None:
+ element = clone(element)
+ clone = None # remove gc cycles
+ return element
+
+
+class JoinCondition(object):
+ def __init__(
+ self,
+ parent_persist_selectable,
+ child_persist_selectable,
+ parent_local_selectable,
+ child_local_selectable,
+ primaryjoin=None,
+ secondary=None,
+ secondaryjoin=None,
+ parent_equivalents=None,
+ child_equivalents=None,
+ consider_as_foreign_keys=None,
+ local_remote_pairs=None,
+ remote_side=None,
+ self_referential=False,
+ prop=None,
+ support_sync=True,
+ can_be_synced_fn=lambda *c: True,
+ ):
+ self.parent_persist_selectable = parent_persist_selectable
+ self.parent_local_selectable = parent_local_selectable
+ self.child_persist_selectable = child_persist_selectable
+ self.child_local_selectable = child_local_selectable
+ self.parent_equivalents = parent_equivalents
+ self.child_equivalents = child_equivalents
+ self.primaryjoin = primaryjoin
+ self.secondaryjoin = secondaryjoin
+ self.secondary = secondary
+ self.consider_as_foreign_keys = consider_as_foreign_keys
+ self._local_remote_pairs = local_remote_pairs
+ self._remote_side = remote_side
+ self.prop = prop
+ self.self_referential = self_referential
+ self.support_sync = support_sync
+ self.can_be_synced_fn = can_be_synced_fn
+ self._determine_joins()
+ self._sanitize_joins()
+ self._annotate_fks()
+ self._annotate_remote()
+ self._annotate_local()
+ self._annotate_parentmapper()
+ self._setup_pairs()
+ self._check_foreign_cols(self.primaryjoin, True)
+ if self.secondaryjoin is not None:
+ self._check_foreign_cols(self.secondaryjoin, False)
+ self._determine_direction()
+ self._check_remote_side()
+ self._log_joins()
+
+ def _log_joins(self):
+ if self.prop is None:
+ return
+ log = self.prop.logger
+ log.info("%s setup primary join %s", self.prop, self.primaryjoin)
+ log.info("%s setup secondary join %s", self.prop, self.secondaryjoin)
+ log.info(
+ "%s synchronize pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s => %s)" % (l, r) for (l, r) in self.synchronize_pairs
+ ),
+ )
+ log.info(
+ "%s secondary synchronize pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s => %s)" % (l, r)
+ for (l, r) in self.secondary_synchronize_pairs or []
+ ),
+ )
+ log.info(
+ "%s local/remote pairs [%s]",
+ self.prop,
+ ",".join(
+ "(%s / %s)" % (l, r) for (l, r) in self.local_remote_pairs
+ ),
+ )
+ log.info(
+ "%s remote columns [%s]",
+ self.prop,
+ ",".join("%s" % col for col in self.remote_columns),
+ )
+ log.info(
+ "%s local columns [%s]",
+ self.prop,
+ ",".join("%s" % col for col in self.local_columns),
+ )
+ log.info("%s relationship direction %s", self.prop, self.direction)
+
+ def _sanitize_joins(self):
+ """remove the parententity annotation from our join conditions which
+ can leak in here based on some declarative patterns and maybe others.
+
+ We'd want to remove "parentmapper" also, but apparently there's
+ an exotic use case in _join_fixture_inh_selfref_w_entity
+ that relies upon it being present, see :ticket:`3364`.
+
+ """
+
+ self.primaryjoin = _deep_deannotate(
+ self.primaryjoin, values=("parententity", "proxy_key")
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = _deep_deannotate(
+ self.secondaryjoin, values=("parententity", "proxy_key")
+ )
+
+ def _determine_joins(self):
+ """Determine the 'primaryjoin' and 'secondaryjoin' attributes,
+ if not passed to the constructor already.
+
+ This is based on analysis of the foreign key relationships
+ between the parent and target mapped selectables.
+
+ """
+ if self.secondaryjoin is not None and self.secondary is None:
+ raise sa_exc.ArgumentError(
+ "Property %s specified with secondary "
+ "join condition but "
+ "no secondary argument" % self.prop
+ )
+
+ # find a join between the given mapper's mapped table and
+ # the given table. will try the mapper's local table first
+ # for more specificity, then if not found will try the more
+ # general mapped table, which in the case of inheritance is
+ # a join.
+ try:
+ consider_as_foreign_keys = self.consider_as_foreign_keys or None
+ if self.secondary is not None:
+ if self.secondaryjoin is None:
+ self.secondaryjoin = join_condition(
+ self.child_persist_selectable,
+ self.secondary,
+ a_subset=self.child_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+ if self.primaryjoin is None:
+ self.primaryjoin = join_condition(
+ self.parent_persist_selectable,
+ self.secondary,
+ a_subset=self.parent_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+ else:
+ if self.primaryjoin is None:
+ self.primaryjoin = join_condition(
+ self.parent_persist_selectable,
+ self.child_persist_selectable,
+ a_subset=self.parent_local_selectable,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+ except sa_exc.NoForeignKeysError as nfe:
+ if self.secondary is not None:
+ util.raise_(
+ sa_exc.NoForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are no foreign keys "
+ "linking these tables via secondary table '%s'. "
+ "Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or "
+ "specify 'primaryjoin' and 'secondaryjoin' "
+ "expressions." % (self.prop, self.secondary)
+ ),
+ from_=nfe,
+ )
+ else:
+ util.raise_(
+ sa_exc.NoForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are no foreign keys "
+ "linking these tables. "
+ "Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or "
+ "specify a 'primaryjoin' expression." % self.prop
+ ),
+ from_=nfe,
+ )
+ except sa_exc.AmbiguousForeignKeysError as afe:
+ if self.secondary is not None:
+ util.raise_(
+ sa_exc.AmbiguousForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are multiple foreign key "
+ "paths linking the tables via secondary table '%s'. "
+ "Specify the 'foreign_keys' "
+ "argument, providing a list of those columns which "
+ "should be counted as containing a foreign key "
+ "reference from the secondary table to each of the "
+ "parent and child tables."
+ % (self.prop, self.secondary)
+ ),
+ from_=afe,
+ )
+ else:
+ util.raise_(
+ sa_exc.AmbiguousForeignKeysError(
+ "Could not determine join "
+ "condition between parent/child tables on "
+ "relationship %s - there are multiple foreign key "
+ "paths linking the tables. Specify the "
+ "'foreign_keys' argument, providing a list of those "
+ "columns which should be counted as containing a "
+ "foreign key reference to the parent table."
+ % self.prop
+ ),
+ from_=afe,
+ )
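+ # Inference sketch for the logic above (hypothetical user/address
+ # tables): a ForeignKey on the child table lets the primaryjoin be
+ # derived automatically; without one, primaryjoin (and secondaryjoin
+ # when a secondary table is used) must be given explicitly, e.g.
+ #
+ #   user_id = Column(Integer, ForeignKey("user.id"))
+ #   # derived primaryjoin: user.id == address.user_id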
+
+ @property
+ def primaryjoin_minus_local(self):
+ return _deep_deannotate(self.primaryjoin, values=("local", "remote"))
+
+ @property
+ def secondaryjoin_minus_local(self):
+ return _deep_deannotate(self.secondaryjoin, values=("local", "remote"))
+
+ @util.memoized_property
+ def primaryjoin_reverse_remote(self):
+ """Return the primaryjoin condition suitable for the
+ "reverse" direction.
+
+ If the primaryjoin was delivered here with pre-existing
+ "remote" annotations, the local/remote annotations
+ are reversed. Otherwise, the local/remote annotations
+ are removed.
+
+ """
+ if self._has_remote_annotations:
+
+ def replace(element):
+ if "remote" in element._annotations:
+ v = dict(element._annotations)
+ del v["remote"]
+ v["local"] = True
+ return element._with_annotations(v)
+ elif "local" in element._annotations:
+ v = dict(element._annotations)
+ del v["local"]
+ v["remote"] = True
+ return element._with_annotations(v)
+
+ return visitors.replacement_traverse(self.primaryjoin, {}, replace)
+ else:
+ if self._has_foreign_annotations:
+ # TODO: coverage
+ return _deep_deannotate(
+ self.primaryjoin, values=("local", "remote")
+ )
+ else:
+ return _deep_deannotate(self.primaryjoin)
+
+ def _has_annotation(self, clause, annotation):
+ for col in visitors.iterate(clause, {}):
+ if annotation in col._annotations:
+ return True
+ else:
+ return False
+
+ @util.memoized_property
+ def _has_foreign_annotations(self):
+ return self._has_annotation(self.primaryjoin, "foreign")
+
+ @util.memoized_property
+ def _has_remote_annotations(self):
+ return self._has_annotation(self.primaryjoin, "remote")
+
+ def _annotate_fks(self):
+ """Annotate the primaryjoin and secondaryjoin
+ structures with 'foreign' annotations marking columns
+ considered as foreign.
+
+ """
+ if self._has_foreign_annotations:
+ return
+
+ if self.consider_as_foreign_keys:
+ self._annotate_from_fk_list()
+ else:
+ self._annotate_present_fks()
+
+ def _annotate_from_fk_list(self):
+ def check_fk(col):
+ if col in self.consider_as_foreign_keys:
+ return col._annotate({"foreign": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, check_fk
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin, {}, check_fk
+ )
+
+ def _annotate_present_fks(self):
+ if self.secondary is not None:
+ secondarycols = util.column_set(self.secondary.c)
+ else:
+ secondarycols = set()
+
+ def is_foreign(a, b):
+ if isinstance(a, schema.Column) and isinstance(b, schema.Column):
+ if a.references(b):
+ return a
+ elif b.references(a):
+ return b
+
+ if secondarycols:
+ if a in secondarycols and b not in secondarycols:
+ return a
+ elif b in secondarycols and a not in secondarycols:
+ return b
+
+ def visit_binary(binary):
+ if not isinstance(
+ binary.left, sql.ColumnElement
+ ) or not isinstance(binary.right, sql.ColumnElement):
+ return
+
+ if (
+ "foreign" not in binary.left._annotations
+ and "foreign" not in binary.right._annotations
+ ):
+ col = is_foreign(binary.left, binary.right)
+ if col is not None:
+ if col.compare(binary.left):
+ binary.left = binary.left._annotate({"foreign": True})
+ elif col.compare(binary.right):
+ binary.right = binary.right._annotate(
+ {"foreign": True}
+ )
+
+ self.primaryjoin = visitors.cloned_traverse(
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
+ if self.secondaryjoin is not None:
+ self.secondaryjoin = visitors.cloned_traverse(
+ self.secondaryjoin, {}, {"binary": visit_binary}
+ )
+
+ def _refers_to_parent_table(self):
+ """Return True if the join condition contains column
+ comparisons where both columns are in both tables.
+
+ """
+ pt = self.parent_persist_selectable
+ mt = self.child_persist_selectable
+ result = [False]
+
+ def visit_binary(binary):
+ c, f = binary.left, binary.right
+ if (
+ isinstance(c, expression.ColumnClause)
+ and isinstance(f, expression.ColumnClause)
+ and pt.is_derived_from(c.table)
+ and pt.is_derived_from(f.table)
+ and mt.is_derived_from(c.table)
+ and mt.is_derived_from(f.table)
+ ):
+ result[0] = True
+
+ visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary})
+ return result[0]
+
+ def _tables_overlap(self):
+ """Return True if parent/child tables have some overlap."""
+
+ return selectables_overlap(
+ self.parent_persist_selectable, self.child_persist_selectable
+ )
+
+ def _annotate_remote(self):
+ """Annotate the primaryjoin and secondaryjoin
+ structures with 'remote' annotations marking columns
+ considered as part of the 'remote' side.
+
+ """
+ if self._has_remote_annotations:
+ return
+
+ if self.secondary is not None:
+ self._annotate_remote_secondary()
+ elif self._local_remote_pairs or self._remote_side:
+ self._annotate_remote_from_args()
+ elif self._refers_to_parent_table():
+ self._annotate_selfref(
+ lambda col: "foreign" in col._annotations, False
+ )
+ elif self._tables_overlap():
+ self._annotate_remote_with_overlap()
+ else:
+ self._annotate_remote_distinct_selectables()
+
+ def _annotate_remote_secondary(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when 'secondary' is present.
+
+ """
+
+ def repl(element):
+ if self.secondary.c.contains_column(element):
+ return element._annotate({"remote": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl
+ )
+ self.secondaryjoin = visitors.replacement_traverse(
+ self.secondaryjoin, {}, repl
+ )
+
+ def _annotate_selfref(self, fn, remote_side_given):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the relationship is detected as self-referential.
+
+ """
+
+ def visit_binary(binary):
+ equated = binary.left.compare(binary.right)
+ if isinstance(binary.left, expression.ColumnClause) and isinstance(
+ binary.right, expression.ColumnClause
+ ):
+ # assume one to many - FKs are "remote"
+ if fn(binary.left):
+ binary.left = binary.left._annotate({"remote": True})
+ if fn(binary.right) and not equated:
+ binary.right = binary.right._annotate({"remote": True})
+ elif not remote_side_given:
+ self._warn_non_column_elements()
+
+ self.primaryjoin = visitors.cloned_traverse(
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
+
+ def _annotate_remote_from_args(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the 'remote_side' or '_local_remote_pairs'
+ arguments are used.
+
+ """
+ if self._local_remote_pairs:
+ if self._remote_side:
+ raise sa_exc.ArgumentError(
+ "remote_side argument is redundant "
+ "against more detailed _local_remote_side "
+ "argument."
+ )
+
+ remote_side = [r for (l, r) in self._local_remote_pairs]
+ else:
+ remote_side = self._remote_side
+
+ if self._refers_to_parent_table():
+ self._annotate_selfref(lambda col: col in remote_side, True)
+ else:
+
+ def repl(element):
+ # use set() to avoid generating ``__eq__()`` expressions
+ # against each element
+ if element in set(remote_side):
+ return element._annotate({"remote": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl
+ )
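+ # remote_side sketch for the self-referential branch above (hypothetical
+ # adjacency-list Node mapping): marking the parent primary key column as
+ # remote makes this the many-to-one side, e.g.
+ #
+ #   parent = relationship("Node", remote_side=[id])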
+
+ def _annotate_remote_with_overlap(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the parent/child tables have some set of
+ tables in common, though the relationship is not fully
+ self-referential.
+
+ """
+
+ def visit_binary(binary):
+ binary.left, binary.right = proc_left_right(
+ binary.left, binary.right
+ )
+ binary.right, binary.left = proc_left_right(
+ binary.right, binary.left
+ )
+
+ check_entities = (
+ self.prop is not None and self.prop.mapper is not self.prop.parent
+ )
+
+ def proc_left_right(left, right):
+ if isinstance(left, expression.ColumnClause) and isinstance(
+ right, expression.ColumnClause
+ ):
+ if self.child_persist_selectable.c.contains_column(
+ right
+ ) and self.parent_persist_selectable.c.contains_column(left):
+ right = right._annotate({"remote": True})
+ elif (
+ check_entities
+ and right._annotations.get("parentmapper") is self.prop.mapper
+ ):
+ right = right._annotate({"remote": True})
+ elif (
+ check_entities
+ and left._annotations.get("parentmapper") is self.prop.mapper
+ ):
+ left = left._annotate({"remote": True})
+ else:
+ self._warn_non_column_elements()
+
+ return left, right
+
+ self.primaryjoin = visitors.cloned_traverse(
+ self.primaryjoin, {}, {"binary": visit_binary}
+ )
+
+ def _annotate_remote_distinct_selectables(self):
+ """annotate 'remote' in primaryjoin, secondaryjoin
+ when the parent/child tables are entirely
+ separate.
+
+ """
+
+ def repl(element):
+ if self.child_persist_selectable.c.contains_column(element) and (
+ not self.parent_local_selectable.c.contains_column(element)
+ or self.child_local_selectable.c.contains_column(element)
+ ):
+ return element._annotate({"remote": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, repl
+ )
+
+ def _warn_non_column_elements(self):
+ util.warn(
+ "Non-simple column elements in primary "
+ "join condition for property %s - consider using "
+ "remote() annotations to mark the remote side." % self.prop
+ )
+
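+    # Editor's note: a hedged usage sketch (not part of the original
+    # source) of the remote()/foreign() annotations that the warning above
+    # refers to.  The Base/User/Address names and the user_ref column are
+    # assumptions made only for illustration:
+    #
+    #     from sqlalchemy import cast, String
+    #     from sqlalchemy.orm import relationship, remote, foreign
+    #
+    #     class User(Base):
+    #         ...
+    #         addresses = relationship(
+    #             "Address",
+    #             primaryjoin=lambda: cast(User.id, String)
+    #             == remote(foreign(Address.user_ref)),
+    #             viewonly=True,
+    #         )
+    #
+    # Marking the remote side explicitly lets the annotation steps above
+    # skip the heuristics that would otherwise emit this warning.
+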
+ def _annotate_local(self):
+ """Annotate the primaryjoin and secondaryjoin
+ structures with 'local' annotations.
+
+ This annotates all column elements found
+ simultaneously in the parent table
+ and the join condition that don't have a
+ 'remote' annotation set up from
+ _annotate_remote() or user-defined.
+
+ """
+ if self._has_annotation(self.primaryjoin, "local"):
+ return
+
+ if self._local_remote_pairs:
+ local_side = util.column_set(
+ [l for (l, r) in self._local_remote_pairs]
+ )
+ else:
+ local_side = util.column_set(self.parent_persist_selectable.c)
+
+ def locals_(elem):
+ if "remote" not in elem._annotations and elem in local_side:
+ return elem._annotate({"local": True})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, locals_
+ )
+
+ def _annotate_parentmapper(self):
+ if self.prop is None:
+ return
+
+ def parentmappers_(elem):
+ if "remote" in elem._annotations:
+ return elem._annotate({"parentmapper": self.prop.mapper})
+ elif "local" in elem._annotations:
+ return elem._annotate({"parentmapper": self.prop.parent})
+
+ self.primaryjoin = visitors.replacement_traverse(
+ self.primaryjoin, {}, parentmappers_
+ )
+
+ def _check_remote_side(self):
+ if not self.local_remote_pairs:
+ raise sa_exc.ArgumentError(
+ "Relationship %s could "
+ "not determine any unambiguous local/remote column "
+ "pairs based on join condition and remote_side "
+ "arguments. "
+ "Consider using the remote() annotation to "
+ "accurately mark those elements of the join "
+ "condition that are on the remote side of "
+ "the relationship." % (self.prop,)
+ )
+
+ def _check_foreign_cols(self, join_condition, primary):
+ """Check the foreign key columns collected and emit error
+ messages."""
+
+ can_sync = False
+
+ foreign_cols = self._gather_columns_with_annotation(
+ join_condition, "foreign"
+ )
+
+ has_foreign = bool(foreign_cols)
+
+ if primary:
+ can_sync = bool(self.synchronize_pairs)
+ else:
+ can_sync = bool(self.secondary_synchronize_pairs)
+
+ if (
+ self.support_sync
+ and can_sync
+ or (not self.support_sync and has_foreign)
+ ):
+ return
+
+ # from here below is just determining the best error message
+ # to report. Check for a join condition using any operator
+ # (not just ==), perhaps they need to turn on "viewonly=True".
+ if self.support_sync and has_foreign and not can_sync:
+ err = (
+ "Could not locate any simple equality expressions "
+ "involving locally mapped foreign key columns for "
+ "%s join condition "
+ "'%s' on relationship %s."
+ % (
+ primary and "primary" or "secondary",
+ join_condition,
+ self.prop,
+ )
+ )
+ err += (
+ " Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or are "
+ "annotated in the join condition with the foreign() "
+ "annotation. To allow comparison operators other than "
+ "'==', the relationship can be marked as viewonly=True."
+ )
+
+ raise sa_exc.ArgumentError(err)
+ else:
+ err = (
+ "Could not locate any relevant foreign key columns "
+ "for %s join condition '%s' on relationship %s."
+ % (
+ primary and "primary" or "secondary",
+ join_condition,
+ self.prop,
+ )
+ )
+ err += (
+ " Ensure that referencing columns are associated "
+ "with a ForeignKey or ForeignKeyConstraint, or are "
+ "annotated in the join condition with the foreign() "
+ "annotation."
+ )
+ raise sa_exc.ArgumentError(err)
+
+ def _determine_direction(self):
+ """Determine if this relationship is one to many, many to one,
+ many to many.
+
+ """
+ if self.secondaryjoin is not None:
+ self.direction = MANYTOMANY
+ else:
+ parentcols = util.column_set(self.parent_persist_selectable.c)
+ targetcols = util.column_set(self.child_persist_selectable.c)
+
+ # fk collection which suggests ONETOMANY.
+ onetomany_fk = targetcols.intersection(self.foreign_key_columns)
+
+ # fk collection which suggests MANYTOONE.
+
+ manytoone_fk = parentcols.intersection(self.foreign_key_columns)
+
+ if onetomany_fk and manytoone_fk:
+ # fks on both sides. test for overlap of local/remote
+ # with foreign key.
+ # we will gather columns directly from their annotations
+ # without deannotating, so that we can distinguish on a column
+ # that refers to itself.
+
+ # 1. columns that are both remote and FK suggest
+ # onetomany.
+ onetomany_local = self._gather_columns_with_annotation(
+ self.primaryjoin, "remote", "foreign"
+ )
+
+ # 2. columns that are FK but are not remote (e.g. local)
+ # suggest manytoone.
+ manytoone_local = set(
+ [
+ c
+ for c in self._gather_columns_with_annotation(
+ self.primaryjoin, "foreign"
+ )
+ if "remote" not in c._annotations
+ ]
+ )
+
+ # 3. if both collections are present, remove columns that
+ # refer to themselves. This is for the case of
+ # and_(Me.id == Me.remote_id, Me.version == Me.version)
+ if onetomany_local and manytoone_local:
+ self_equated = self.remote_columns.intersection(
+ self.local_columns
+ )
+ onetomany_local = onetomany_local.difference(self_equated)
+ manytoone_local = manytoone_local.difference(self_equated)
+
+ # at this point, if only one or the other collection is
+ # present, we know the direction, otherwise it's still
+ # ambiguous.
+
+ if onetomany_local and not manytoone_local:
+ self.direction = ONETOMANY
+ elif manytoone_local and not onetomany_local:
+ self.direction = MANYTOONE
+ else:
+ raise sa_exc.ArgumentError(
+ "Can't determine relationship"
+ " direction for relationship '%s' - foreign "
+ "key columns within the join condition are present "
+ "in both the parent and the child's mapped tables. "
+ "Ensure that only those columns referring "
+ "to a parent column are marked as foreign, "
+ "either via the foreign() annotation or "
+ "via the foreign_keys argument." % self.prop
+ )
+ elif onetomany_fk:
+ self.direction = ONETOMANY
+ elif manytoone_fk:
+ self.direction = MANYTOONE
+ else:
+ raise sa_exc.ArgumentError(
+ "Can't determine relationship "
+ "direction for relationship '%s' - foreign "
+ "key columns are present in neither the parent "
+ "nor the child's mapped tables" % self.prop
+ )
+
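+    # Editor's note: a hedged sketch (not part of the original source) of
+    # the ambiguity described by the error above, using assumed class and
+    # column names.  When foreign keys point in both directions, the
+    # direction is disambiguated with the foreign_keys argument:
+    #
+    #     class Parent(Base):
+    #         ...
+    #         favorite_child = relationship(
+    #             "Child", foreign_keys="Parent.favorite_child_id"
+    #         )
+    #
+    #     class Child(Base):
+    #         ...
+    #         parent = relationship("Parent", foreign_keys="Child.parent_id")
+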
+ def _deannotate_pairs(self, collection):
+ """provide deannotation for the various lists of
+ pairs, so that using them in hashes doesn't incur
+ high-overhead __eq__() comparisons against
+ original columns mapped.
+
+ """
+ return [(x._deannotate(), y._deannotate()) for x, y in collection]
+
+ def _setup_pairs(self):
+ sync_pairs = []
+ lrp = util.OrderedSet([])
+ secondary_sync_pairs = []
+
+ def go(joincond, collection):
+ def visit_binary(binary, left, right):
+ if (
+ "remote" in right._annotations
+ and "remote" not in left._annotations
+ and self.can_be_synced_fn(left)
+ ):
+ lrp.add((left, right))
+ elif (
+ "remote" in left._annotations
+ and "remote" not in right._annotations
+ and self.can_be_synced_fn(right)
+ ):
+ lrp.add((right, left))
+ if binary.operator is operators.eq and self.can_be_synced_fn(
+ left, right
+ ):
+ if "foreign" in right._annotations:
+ collection.append((left, right))
+ elif "foreign" in left._annotations:
+ collection.append((right, left))
+
+ visit_binary_product(visit_binary, joincond)
+
+ for joincond, collection in [
+ (self.primaryjoin, sync_pairs),
+ (self.secondaryjoin, secondary_sync_pairs),
+ ]:
+ if joincond is None:
+ continue
+ go(joincond, collection)
+
+ self.local_remote_pairs = self._deannotate_pairs(lrp)
+ self.synchronize_pairs = self._deannotate_pairs(sync_pairs)
+ self.secondary_synchronize_pairs = self._deannotate_pairs(
+ secondary_sync_pairs
+ )
+
+ _track_overlapping_sync_targets = weakref.WeakKeyDictionary()
+
+ def _warn_for_conflicting_sync_targets(self):
+ if not self.support_sync:
+ return
+
+ # we would like to detect if we are synchronizing any column
+ # pairs in conflict with another relationship that wishes to sync
+ # an entirely different column to the same target. This is a
+ # very rare edge case so we will try to minimize the memory/overhead
+ # impact of this check
+ for from_, to_ in [
+ (from_, to_) for (from_, to_) in self.synchronize_pairs
+ ] + [
+ (from_, to_) for (from_, to_) in self.secondary_synchronize_pairs
+ ]:
+ # save ourselves a ton of memory and overhead by only
+            # considering columns that are subject to overlapping
+ # FK constraints at the core level. This condition can arise
+ # if multiple relationships overlap foreign() directly, but
+ # we're going to assume it's typically a ForeignKeyConstraint-
+ # level configuration that benefits from this warning.
+
+ if to_ not in self._track_overlapping_sync_targets:
+ self._track_overlapping_sync_targets[
+ to_
+ ] = weakref.WeakKeyDictionary({self.prop: from_})
+ else:
+ other_props = []
+ prop_to_from = self._track_overlapping_sync_targets[to_]
+
+ for pr, fr_ in prop_to_from.items():
+ if (
+ not pr.mapper._dispose_called
+ and pr not in self.prop._reverse_property
+ and pr.key not in self.prop._overlaps
+ and self.prop.key not in pr._overlaps
+ # note: the "__*" symbol is used internally by
+ # SQLAlchemy as a general means of suppressing the
+ # overlaps warning for some extension cases, however
+ # this is not currently
+ # a publicly supported symbol and may change at
+ # any time.
+ and "__*" not in self.prop._overlaps
+ and "__*" not in pr._overlaps
+ and not self.prop.parent.is_sibling(pr.parent)
+ and not self.prop.mapper.is_sibling(pr.mapper)
+ and not self.prop.parent.is_sibling(pr.mapper)
+ and not self.prop.mapper.is_sibling(pr.parent)
+ and (
+ self.prop.key != pr.key
+ or not self.prop.parent.common_parent(pr.parent)
+ )
+ ):
+
+ other_props.append((pr, fr_))
+
+ if other_props:
+ util.warn(
+ "relationship '%s' will copy column %s to column %s, "
+ "which conflicts with relationship(s): %s. "
+ "If this is not the intention, consider if these "
+ "relationships should be linked with "
+ "back_populates, or if viewonly=True should be "
+ "applied to one or more if they are read-only. "
+ "For the less common case that foreign key "
+ "constraints are partially overlapping, the "
+ "orm.foreign() "
+ "annotation can be used to isolate the columns that "
+ "should be written towards. To silence this "
+ "warning, add the parameter 'overlaps=\"%s\"' to the "
+ "'%s' relationship."
+ % (
+ self.prop,
+ from_,
+ to_,
+ ", ".join(
+ sorted(
+ "'%s' (copies %s to %s)" % (pr, fr_, to_)
+ for (pr, fr_) in other_props
+ )
+ ),
+ ",".join(sorted(pr.key for pr, fr in other_props)),
+ self.prop,
+ ),
+                        code="qzyw",
+ )
+ self._track_overlapping_sync_targets[to_][self.prop] = from_
+
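+    # Editor's note: a hedged sketch (not part of the original source) of
+    # the two remedies named in the warning above, using assumed models.
+    # Either link the two relationships so they are known to be the same
+    # linkage:
+    #
+    #     class Parent(Base):
+    #         ...
+    #         children = relationship("Child", back_populates="parent")
+    #
+    #     class Child(Base):
+    #         ...
+    #         parent = relationship("Parent", back_populates="children")
+    #
+    # or, for genuinely overlapping foreign key constraints, silence the
+    # warning explicitly with relationship(..., overlaps="children").
+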
+ @util.memoized_property
+ def remote_columns(self):
+ return self._gather_join_annotations("remote")
+
+ @util.memoized_property
+ def local_columns(self):
+ return self._gather_join_annotations("local")
+
+ @util.memoized_property
+ def foreign_key_columns(self):
+ return self._gather_join_annotations("foreign")
+
+ def _gather_join_annotations(self, annotation):
+ s = set(
+ self._gather_columns_with_annotation(self.primaryjoin, annotation)
+ )
+ if self.secondaryjoin is not None:
+ s.update(
+ self._gather_columns_with_annotation(
+ self.secondaryjoin, annotation
+ )
+ )
+ return {x._deannotate() for x in s}
+
+ def _gather_columns_with_annotation(self, clause, *annotation):
+ annotation = set(annotation)
+ return set(
+ [
+ col
+ for col in visitors.iterate(clause, {})
+ if annotation.issubset(col._annotations)
+ ]
+ )
+
+ def join_targets(
+ self,
+ source_selectable,
+ dest_selectable,
+ aliased,
+ single_crit=None,
+ extra_criteria=(),
+ ):
+ """Given a source and destination selectable, create a
+ join between them.
+
+ This takes into account aliasing the join clause
+ to reference the appropriate corresponding columns
+ in the target objects, as well as the extra child
+ criterion, equivalent column sets, etc.
+
+ """
+ # place a barrier on the destination such that
+ # replacement traversals won't ever dig into it.
+ # its internal structure remains fixed
+ # regardless of context.
+ dest_selectable = _shallow_annotate(
+ dest_selectable, {"no_replacement_traverse": True}
+ )
+
+ primaryjoin, secondaryjoin, secondary = (
+ self.primaryjoin,
+ self.secondaryjoin,
+ self.secondary,
+ )
+
+ # adjust the join condition for single table inheritance,
+ # in the case that the join is to a subclass
+ # this is analogous to the
+ # "_adjust_for_single_table_inheritance()" method in Query.
+
+ if single_crit is not None:
+ if secondaryjoin is not None:
+ secondaryjoin = secondaryjoin & single_crit
+ else:
+ primaryjoin = primaryjoin & single_crit
+
+ if extra_criteria:
+ if secondaryjoin is not None:
+ secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
+ else:
+ primaryjoin = primaryjoin & sql.and_(*extra_criteria)
+
+ if aliased:
+ if secondary is not None:
+ secondary = secondary._anonymous_fromclause(flat=True)
+ primary_aliasizer = ClauseAdapter(
+ secondary, exclude_fn=_ColInAnnotations("local")
+ )
+ secondary_aliasizer = ClauseAdapter(
+ dest_selectable, equivalents=self.child_equivalents
+ ).chain(primary_aliasizer)
+ if source_selectable is not None:
+ primary_aliasizer = ClauseAdapter(
+ secondary, exclude_fn=_ColInAnnotations("local")
+ ).chain(
+ ClauseAdapter(
+ source_selectable,
+ equivalents=self.parent_equivalents,
+ )
+ )
+
+ secondaryjoin = secondary_aliasizer.traverse(secondaryjoin)
+ else:
+ primary_aliasizer = ClauseAdapter(
+ dest_selectable,
+ exclude_fn=_ColInAnnotations("local"),
+ equivalents=self.child_equivalents,
+ )
+ if source_selectable is not None:
+ primary_aliasizer.chain(
+ ClauseAdapter(
+ source_selectable,
+ exclude_fn=_ColInAnnotations("remote"),
+ equivalents=self.parent_equivalents,
+ )
+ )
+ secondary_aliasizer = None
+
+ primaryjoin = primary_aliasizer.traverse(primaryjoin)
+ target_adapter = secondary_aliasizer or primary_aliasizer
+ target_adapter.exclude_fn = None
+ else:
+ target_adapter = None
+ return (
+ primaryjoin,
+ secondaryjoin,
+ secondary,
+ target_adapter,
+ dest_selectable,
+ )
+
+ def create_lazy_clause(self, reverse_direction=False):
+ binds = util.column_dict()
+ equated_columns = util.column_dict()
+
+ has_secondary = self.secondaryjoin is not None
+
+ if has_secondary:
+ lookup = collections.defaultdict(list)
+ for l, r in self.local_remote_pairs:
+ lookup[l].append((l, r))
+ equated_columns[r] = l
+ elif not reverse_direction:
+ for l, r in self.local_remote_pairs:
+ equated_columns[r] = l
+ else:
+ for l, r in self.local_remote_pairs:
+ equated_columns[l] = r
+
+ def col_to_bind(col):
+
+ if (
+ (not reverse_direction and "local" in col._annotations)
+ or reverse_direction
+ and (
+ (has_secondary and col in lookup)
+ or (not has_secondary and "remote" in col._annotations)
+ )
+ ):
+ if col not in binds:
+ binds[col] = sql.bindparam(
+ None, None, type_=col.type, unique=True
+ )
+ return binds[col]
+ return None
+
+ lazywhere = self.primaryjoin
+ if self.secondaryjoin is None or not reverse_direction:
+ lazywhere = visitors.replacement_traverse(
+ lazywhere, {}, col_to_bind
+ )
+
+ if self.secondaryjoin is not None:
+ secondaryjoin = self.secondaryjoin
+ if reverse_direction:
+ secondaryjoin = visitors.replacement_traverse(
+ secondaryjoin, {}, col_to_bind
+ )
+ lazywhere = sql.and_(lazywhere, secondaryjoin)
+
+ bind_to_col = {binds[col].key: col for col in binds}
+
+ return lazywhere, bind_to_col, equated_columns
+
+
+class _ColInAnnotations(object):
+ """Serializable object that tests for a name in c._annotations."""
+
+ __slots__ = ("name",)
+
+ def __init__(self, name):
+ self.name = name
+
+ def __call__(self, c):
+ return self.name in c._annotations
diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py
new file mode 100644
index 0000000..f323233
--- /dev/null
+++ b/lib/sqlalchemy/orm/scoping.py
@@ -0,0 +1,228 @@
+# orm/scoping.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import class_mapper
+from . import exc as orm_exc
+from .session import Session
+from .. import exc as sa_exc
+from ..util import create_proxy_methods
+from ..util import ScopedRegistry
+from ..util import ThreadLocalRegistry
+from ..util import warn
+from ..util import warn_deprecated
+
+__all__ = ["scoped_session", "ScopedSessionMixin"]
+
+
+class ScopedSessionMixin(object):
+ @property
+ def _proxied(self):
+ return self.registry()
+
+ def __call__(self, **kw):
+ r"""Return the current :class:`.Session`, creating it
+ using the :attr:`.scoped_session.session_factory` if not present.
+
+ :param \**kw: Keyword arguments will be passed to the
+ :attr:`.scoped_session.session_factory` callable, if an existing
+ :class:`.Session` is not present. If the :class:`.Session` is present
+ and keyword arguments have been passed,
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if kw:
+ if self.registry.has():
+ raise sa_exc.InvalidRequestError(
+ "Scoped session is already present; "
+ "no new arguments may be specified."
+ )
+ else:
+ sess = self.session_factory(**kw)
+ self.registry.set(sess)
+ else:
+ sess = self.registry()
+ if not self._support_async and sess._is_asyncio:
+ warn_deprecated(
+ "Using `scoped_session` with asyncio is deprecated and "
+ "will raise an error in a future version. "
+ "Please use `async_scoped_session` instead.",
+ "1.4.23",
+ )
+ return sess
+
+ def configure(self, **kwargs):
+ """reconfigure the :class:`.sessionmaker` used by this
+ :class:`.scoped_session`.
+
+ See :meth:`.sessionmaker.configure`.
+
+ """
+
+ if self.registry.has():
+ warn(
+ "At least one scoped session is already present. "
+ " configure() can not affect sessions that have "
+ "already been created."
+ )
+
+ self.session_factory.configure(**kwargs)
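+
+    # Editor's note: a hedged usage sketch (not part of the original
+    # source).  A scoped_session is commonly configured lazily, once an
+    # engine is available; "engine" here is an assumption:
+    #
+    #     Session = scoped_session(sessionmaker())
+    #     ...
+    #     Session.configure(bind=engine)  # proxied to sessionmaker.configure()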
+
+
+@create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_orm.scoping.scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get",
+ "get_bind",
+ "is_modified",
+ "bulk_save_objects",
+ "bulk_insert_mappings",
+ "bulk_update_mappings",
+ "merge",
+ "query",
+ "refresh",
+ "rollback",
+ "scalar",
+ "scalars",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ "autocommit",
+ ],
+)
+class scoped_session(ScopedSessionMixin):
+ """Provides scoped management of :class:`.Session` objects.
+
+ See :ref:`unitofwork_contextual` for a tutorial.
+
+ .. note::
+
+ When using :ref:`asyncio_toplevel`, the async-compatible
+ :class:`_asyncio.async_scoped_session` class should be
+ used in place of :class:`.scoped_session`.
+
+ """
+
+ _support_async = False
+
+ session_factory = None
+ """The `session_factory` provided to `__init__` is stored in this
+ attribute and may be accessed at a later time. This can be useful when
+ a new non-scoped :class:`.Session` or :class:`_engine.Connection` to the
+ database is needed."""
+
+ def __init__(self, session_factory, scopefunc=None):
+ """Construct a new :class:`.scoped_session`.
+
+ :param session_factory: a factory to create new :class:`.Session`
+ instances. This is usually, but not necessarily, an instance
+ of :class:`.sessionmaker`.
+ :param scopefunc: optional function which defines
+ the current scope. If not passed, the :class:`.scoped_session`
+ object assumes "thread-local" scope, and will use
+ a Python ``threading.local()`` in order to maintain the current
+ :class:`.Session`. If passed, the function should return
+ a hashable token; this token will be used as the key in a
+ dictionary in order to store and retrieve the current
+ :class:`.Session`.
+
+ """
+ self.session_factory = session_factory
+
+ if scopefunc:
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+ else:
+ self.registry = ThreadLocalRegistry(session_factory)
+
+ def remove(self):
+ """Dispose of the current :class:`.Session`, if present.
+
+ This will first call :meth:`.Session.close` method
+ on the current :class:`.Session`, which releases any existing
+ transactional/connection resources still being held; transactions
+ specifically are rolled back. The :class:`.Session` is then
+ discarded. Upon next usage within the same scope,
+ the :class:`.scoped_session` will produce a new
+ :class:`.Session` object.
+
+ """
+
+ if self.registry.has():
+ self.registry().close()
+ self.registry.clear()
+
+ def query_property(self, query_cls=None):
+ """return a class property which produces a :class:`_query.Query`
+ object
+ against the class and the current :class:`.Session` when called.
+
+ e.g.::
+
+ Session = scoped_session(sessionmaker())
+
+ class MyClass(object):
+ query = Session.query_property()
+
+ # after mappers are defined
+ result = MyClass.query.filter(MyClass.name=='foo').all()
+
+ Produces instances of the session's configured query class by
+ default. To override and use a custom implementation, provide
+ a ``query_cls`` callable. The callable will be invoked with
+ the class's mapper as a positional argument and a session
+ keyword argument.
+
+ There is no limit to the number of query properties placed on
+ a class.
+
+ """
+
+ class query(object):
+ def __get__(s, instance, owner):
+ try:
+ mapper = class_mapper(owner)
+ if mapper:
+ if query_cls:
+ # custom query class
+ return query_cls(mapper, session=self.registry())
+ else:
+ # session's configured query class
+ return self.registry().query(mapper)
+ except orm_exc.UnmappedClassError:
+ return None
+
+ return query()
+
+
+ScopedSession = scoped_session
+"""Old name for backwards compatibility."""
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
new file mode 100644
index 0000000..c6a9169
--- /dev/null
+++ b/lib/sqlalchemy/orm/session.py
@@ -0,0 +1,4386 @@
+# orm/session.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Provides the Session class and related utilities."""
+
+
+import itertools
+import sys
+import weakref
+
+from . import attributes
+from . import context
+from . import exc
+from . import identity
+from . import loading
+from . import persistence
+from . import query
+from . import state as statelib
+from .base import _class_to_mapper
+from .base import _none_set
+from .base import _state_mapper
+from .base import instance_str
+from .base import object_mapper
+from .base import object_state
+from .base import state_str
+from .unitofwork import UOWTransaction
+from .. import engine
+from .. import exc as sa_exc
+from .. import sql
+from .. import util
+from ..engine.util import TransactionalContext
+from ..inspection import inspect
+from ..sql import coercions
+from ..sql import dml
+from ..sql import roles
+from ..sql import visitors
+from ..sql.base import CompileState
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+
+__all__ = [
+ "Session",
+ "SessionTransaction",
+ "sessionmaker",
+ "ORMExecuteState",
+ "close_all_sessions",
+ "make_transient",
+ "make_transient_to_detached",
+ "object_session",
+]
+
+_sessions = weakref.WeakValueDictionary()
+"""Weak-referencing dictionary of :class:`.Session` objects.
+"""
+
+statelib._sessions = _sessions
+
+
+def _state_session(state):
+ """Given an :class:`.InstanceState`, return the :class:`.Session`
+ associated, if any.
+ """
+ return state.session
+
+
+class _SessionClassMethods(object):
+ """Class-level methods for :class:`.Session`, :class:`.sessionmaker`."""
+
+ @classmethod
+ @util.deprecated(
+ "1.3",
+ "The :meth:`.Session.close_all` method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":func:`.session.close_all_sessions`.",
+ )
+ def close_all(cls):
+ """Close *all* sessions in memory."""
+
+ close_all_sessions()
+
+ @classmethod
+ @util.preload_module("sqlalchemy.orm.util")
+ def identity_key(cls, *args, **kwargs):
+ """Return an identity key.
+
+ This is an alias of :func:`.util.identity_key`.
+
+ """
+ return util.preloaded.orm_util.identity_key(*args, **kwargs)
+
+ @classmethod
+ def object_session(cls, instance):
+ """Return the :class:`.Session` to which an object belongs.
+
+ This is an alias of :func:`.object_session`.
+
+ """
+
+ return object_session(instance)
+
+
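+# Editor's note (added comment): these symbols track the lifecycle of a
+# SessionTransaction as used by _assert_active(), commit(), rollback() and
+# close() below: ACTIVE -> PREPARED (set while committing / during two-phase
+# prepare) -> COMMITTED -> CLOSED, with DEACTIVE marking a transaction that
+# has been rolled back.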
+ACTIVE = util.symbol("ACTIVE")
+PREPARED = util.symbol("PREPARED")
+COMMITTED = util.symbol("COMMITTED")
+DEACTIVE = util.symbol("DEACTIVE")
+CLOSED = util.symbol("CLOSED")
+
+
+class ORMExecuteState(util.MemoizedSlots):
+ """Represents a call to the :meth:`_orm.Session.execute` method, as passed
+ to the :meth:`.SessionEvents.do_orm_execute` event hook.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`session_execute_events` - top level documentation on how
+ to use :meth:`_orm.SessionEvents.do_orm_execute`
+
+ """
+
+ __slots__ = (
+ "session",
+ "statement",
+ "parameters",
+ "execution_options",
+ "local_execution_options",
+ "bind_arguments",
+ "_compile_state_cls",
+ "_starting_event_idx",
+ "_events_todo",
+ "_update_execution_options",
+ )
+
+ def __init__(
+ self,
+ session,
+ statement,
+ parameters,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
+ events_todo,
+ ):
+ self.session = session
+ self.statement = statement
+ self.parameters = parameters
+ self.local_execution_options = execution_options
+ self.execution_options = statement._execution_options.union(
+ execution_options
+ )
+ self.bind_arguments = bind_arguments
+ self._compile_state_cls = compile_state_cls
+ self._events_todo = list(events_todo)
+
+ def _remaining_events(self):
+ return self._events_todo[self._starting_event_idx + 1 :]
+
+ def invoke_statement(
+ self,
+ statement=None,
+ params=None,
+ execution_options=None,
+ bind_arguments=None,
+ ):
+ """Execute the statement represented by this
+ :class:`.ORMExecuteState`, without re-invoking events that have
+ already proceeded.
+
+ This method essentially performs a re-entrant execution of the current
+ statement for which the :meth:`.SessionEvents.do_orm_execute` event is
+ being currently invoked. The use case for this is for event handlers
+ that want to override how the ultimate
+ :class:`_engine.Result` object is returned, such as for schemes that
+ retrieve results from an offline cache or which concatenate results
+ from multiple executions.
+
+ When the :class:`_engine.Result` object is returned by the actual
+ handler function within :meth:`_orm.SessionEvents.do_orm_execute` and
+ is propagated to the calling
+ :meth:`_orm.Session.execute` method, the remainder of the
+ :meth:`_orm.Session.execute` method is preempted and the
+ :class:`_engine.Result` object is returned to the caller of
+ :meth:`_orm.Session.execute` immediately.
+
+ :param statement: optional statement to be invoked, in place of the
+ statement currently represented by :attr:`.ORMExecuteState.statement`.
+
+ :param params: optional dictionary of parameters which will be merged
+ into the existing :attr:`.ORMExecuteState.parameters` of this
+ :class:`.ORMExecuteState`.
+
+ :param execution_options: optional dictionary of execution options
+ will be merged into the existing
+ :attr:`.ORMExecuteState.execution_options` of this
+ :class:`.ORMExecuteState`.
+
+ :param bind_arguments: optional dictionary of bind_arguments
+ which will be merged amongst the current
+ :attr:`.ORMExecuteState.bind_arguments`
+ of this :class:`.ORMExecuteState`.
+
+ :return: a :class:`_engine.Result` object with ORM-level results.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - background and examples on the
+ appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`.
+
+
+ """
+
+ if statement is None:
+ statement = self.statement
+
+ _bind_arguments = dict(self.bind_arguments)
+ if bind_arguments:
+ _bind_arguments.update(bind_arguments)
+ _bind_arguments["_sa_skip_events"] = True
+
+ if params:
+ _params = dict(self.parameters)
+ _params.update(params)
+ else:
+ _params = self.parameters
+
+ _execution_options = self.local_execution_options
+ if execution_options:
+ _execution_options = _execution_options.union(execution_options)
+
+ return self.session.execute(
+ statement,
+ _params,
+ _execution_options,
+ _bind_arguments,
+ _parent_execute_state=self,
+ )
+
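+    # Editor's note: a hedged sketch (not part of the original source) of
+    # the re-entrant pattern described above, e.g. from within a
+    # SessionEvents.do_orm_execute handler; "session_factory" is assumed
+    # to exist:
+    #
+    #     from sqlalchemy import event
+    #
+    #     @event.listens_for(session_factory, "do_orm_execute")
+    #     def _with_extra_option(orm_execute_state):
+    #         if orm_execute_state.is_select:
+    #             # returning a Result here preempts the normal execution
+    #             return orm_execute_state.invoke_statement(
+    #                 execution_options={"illustrative_option": True}
+    #             )
+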
+ @property
+ def bind_mapper(self):
+ """Return the :class:`_orm.Mapper` that is the primary "bind" mapper.
+
+ For an :class:`_orm.ORMExecuteState` object invoking an ORM
+ statement, that is, the :attr:`_orm.ORMExecuteState.is_orm_statement`
+ attribute is ``True``, this attribute will return the
+ :class:`_orm.Mapper` that is considered to be the "primary" mapper
+ of the statement. The term "bind mapper" refers to the fact that
+ a :class:`_orm.Session` object may be "bound" to multiple
+ :class:`_engine.Engine` objects keyed to mapped classes, and the
+ "bind mapper" determines which of those :class:`_engine.Engine` objects
+ would be selected.
+
+ For a statement that is invoked against a single mapped class,
+ :attr:`_orm.ORMExecuteState.bind_mapper` is intended to be a reliable
+ way of getting this mapper.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.all_mappers`
+
+
+ """
+ return self.bind_arguments.get("mapper", None)
+
+ @property
+ def all_mappers(self):
+ """Return a sequence of all :class:`_orm.Mapper` objects that are
+ involved at the top level of this statement.
+
+ By "top level" we mean those :class:`_orm.Mapper` objects that would
+ be represented in the result set rows for a :func:`_sql.select`
+ query, or for a :func:`_dml.update` or :func:`_dml.delete` query,
+ the mapper that is the main subject of the UPDATE or DELETE.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.bind_mapper`
+
+
+
+ """
+ if not self.is_orm_statement:
+ return []
+ elif self.is_select:
+ result = []
+ seen = set()
+ for d in self.statement.column_descriptions:
+ ent = d["entity"]
+ if ent:
+ insp = inspect(ent, raiseerr=False)
+ if insp and insp.mapper and insp.mapper not in seen:
+ seen.add(insp.mapper)
+ result.append(insp.mapper)
+ return result
+ elif self.is_update or self.is_delete:
+ return [self.bind_mapper]
+ else:
+ return []
+
+ @property
+ def is_orm_statement(self):
+ """return True if the operation is an ORM statement.
+
+ This indicates that the select(), update(), or delete() being
+ invoked contains ORM entities as subjects. For a statement
+ that does not have ORM entities and instead refers only to
+ :class:`.Table` metadata, it is invoked as a Core SQL statement
+ and no ORM-level automation takes place.
+
+ """
+ return self._compile_state_cls is not None
+
+ @property
+ def is_select(self):
+ """return True if this is a SELECT operation."""
+ return self.statement.is_select
+
+ @property
+ def is_insert(self):
+ """return True if this is an INSERT operation."""
+ return self.statement.is_dml and self.statement.is_insert
+
+ @property
+ def is_update(self):
+ """return True if this is an UPDATE operation."""
+ return self.statement.is_dml and self.statement.is_update
+
+ @property
+ def is_delete(self):
+ """return True if this is a DELETE operation."""
+ return self.statement.is_dml and self.statement.is_delete
+
+ @property
+ def _is_crud(self):
+ return isinstance(self.statement, (dml.Update, dml.Delete))
+
+ def update_execution_options(self, **opts):
+ # TODO: no coverage
+ self.local_execution_options = self.local_execution_options.union(opts)
+
+ def _orm_compile_options(self):
+ if not self.is_select:
+ return None
+ opts = self.statement._compile_options
+ if opts.isinstance(context.ORMCompileState.default_compile_options):
+ return opts
+ else:
+ return None
+
+ @property
+ def lazy_loaded_from(self):
+ """An :class:`.InstanceState` that is using this statement execution
+ for a lazy load operation.
+
+ The primary rationale for this attribute is to support the horizontal
+ sharding extension, where it is available within specific query
+ execution time hooks created by this extension. To that end, the
+ attribute is only intended to be meaningful at **query execution
+ time**, and importantly not any time prior to that, including query
+ compilation time.
+
+ """
+ return self.load_options._lazy_loaded_from
+
+ @property
+ def loader_strategy_path(self):
+ """Return the :class:`.PathRegistry` for the current load path.
+
+ This object represents the "path" in a query along relationships
+ when a particular object or collection is being loaded.
+
+ """
+ opts = self._orm_compile_options()
+ if opts is not None:
+ return opts._current_path
+ else:
+ return None
+
+ @property
+ def is_column_load(self):
+ """Return True if the operation is refreshing column-oriented
+ attributes on an existing ORM object.
+
+ This occurs during operations such as :meth:`_orm.Session.refresh`,
+ as well as when an attribute deferred by :func:`_orm.defer` is
+ being loaded, or an attribute that was expired either directly
+ by :meth:`_orm.Session.expire` or via a commit operation is being
+ loaded.
+
+ Handlers will very likely not want to add any options to queries
+ when such an operation is occurring as the query should be a straight
+ primary key fetch which should not have any additional WHERE criteria,
+ and loader options travelling with the instance
+ will have already been added to the query.
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.is_relationship_load`
+
+ """
+ opts = self._orm_compile_options()
+ return opts is not None and opts._for_refresh_state
+
+ @property
+ def is_relationship_load(self):
+ """Return True if this load is loading objects on behalf of a
+ relationship.
+
+ This means, the loader in effect is either a LazyLoader,
+ SelectInLoader, SubqueryLoader, or similar, and the entire
+ SELECT statement being emitted is on behalf of a relationship
+ load.
+
+ Handlers will very likely not want to add any options to queries
+ when such an operation is occurring, as loader options are already
+ capable of being propagated to relationship loaders and should
+ be already present.
+
+ .. seealso::
+
+ :attr:`_orm.ORMExecuteState.is_column_load`
+
+ """
+ opts = self._orm_compile_options()
+ if opts is None:
+ return False
+ path = self.loader_strategy_path
+ return path is not None and not path.is_root
+
+ @property
+ def load_options(self):
+ """Return the load_options that will be used for this execution."""
+
+ if not self.is_select:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against a SELECT statement "
+ "so there are no load options."
+ )
+ return self.execution_options.get(
+ "_sa_orm_load_options", context.QueryContext.default_load_options
+ )
+
+ @property
+ def update_delete_options(self):
+ """Return the update_delete_options that will be used for this
+ execution."""
+
+ if not self._is_crud:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against an UPDATE or DELETE "
+ "statement so there are no update options."
+ )
+ return self.execution_options.get(
+ "_sa_orm_update_options",
+ persistence.BulkUDCompileState.default_update_options,
+ )
+
+ @property
+ def user_defined_options(self):
+ """The sequence of :class:`.UserDefinedOptions` that have been
+ associated with the statement being invoked.
+
+ """
+ return [
+ opt
+ for opt in self.statement._with_options
+ if not opt._is_compile_state and not opt._is_legacy_option
+ ]
+
+
+class SessionTransaction(TransactionalContext):
+ """A :class:`.Session`-level transaction.
+
+ :class:`.SessionTransaction` is produced from the
+ :meth:`_orm.Session.begin`
+ and :meth:`_orm.Session.begin_nested` methods. It's largely an internal
+ object that in modern use provides a context manager for session
+ transactions.
+
+ Documentation on interacting with :class:`_orm.SessionTransaction` is
+ at: :ref:`unitofwork_transaction`.
+
+
+ .. versionchanged:: 1.4 The scoping and API methods to work with the
+ :class:`_orm.SessionTransaction` object directly have been simplified.
+
+ .. seealso::
+
+ :ref:`unitofwork_transaction`
+
+ :meth:`.Session.begin`
+
+ :meth:`.Session.begin_nested`
+
+ :meth:`.Session.rollback`
+
+ :meth:`.Session.commit`
+
+ :meth:`.Session.in_transaction`
+
+ :meth:`.Session.in_nested_transaction`
+
+ :meth:`.Session.get_transaction`
+
+ :meth:`.Session.get_nested_transaction`
+
+
+ """
+
+ _rollback_exception = None
+
+ def __init__(
+ self,
+ session,
+ parent=None,
+ nested=False,
+ autobegin=False,
+ ):
+ TransactionalContext._trans_ctx_check(session)
+
+ self.session = session
+ self._connections = {}
+ self._parent = parent
+ self.nested = nested
+ if nested:
+ self._previous_nested_transaction = session._nested_transaction
+ self._state = ACTIVE
+ if not parent and nested:
+ raise sa_exc.InvalidRequestError(
+ "Can't start a SAVEPOINT transaction when no existing "
+ "transaction is in progress"
+ )
+
+ self._take_snapshot(autobegin=autobegin)
+
+ # make sure transaction is assigned before we call the
+ # dispatch
+ self.session._transaction = self
+
+ self.session.dispatch.after_transaction_create(self.session, self)
+
+ @property
+ def parent(self):
+ """The parent :class:`.SessionTransaction` of this
+ :class:`.SessionTransaction`.
+
+ If this attribute is ``None``, indicates this
+ :class:`.SessionTransaction` is at the top of the stack, and
+ corresponds to a real "COMMIT"/"ROLLBACK"
+ block. If non-``None``, then this is either a "subtransaction"
+ or a "nested" / SAVEPOINT transaction. If the
+ :attr:`.SessionTransaction.nested` attribute is ``True``, then
+        this is a SAVEPOINT, and if ``False``, indicates this is a subtransaction.
+
+ .. versionadded:: 1.0.16 - use ._parent for previous versions
+
+ """
+ return self._parent
+
+ nested = False
+ """Indicates if this is a nested, or SAVEPOINT, transaction.
+
+ When :attr:`.SessionTransaction.nested` is True, it is expected
+    that :attr:`.SessionTransaction.parent` will be present as well.
+
+ """
+
+ @property
+ def is_active(self):
+ return self.session is not None and self._state is ACTIVE
+
+ def _assert_active(
+ self,
+ prepared_ok=False,
+ rollback_ok=False,
+ deactive_ok=False,
+ closed_msg="This transaction is closed",
+ ):
+ if self._state is COMMITTED:
+ raise sa_exc.InvalidRequestError(
+ "This session is in 'committed' state; no further "
+ "SQL can be emitted within this transaction."
+ )
+ elif self._state is PREPARED:
+ if not prepared_ok:
+ raise sa_exc.InvalidRequestError(
+ "This session is in 'prepared' state; no further "
+ "SQL can be emitted within this transaction."
+ )
+ elif self._state is DEACTIVE:
+ if not deactive_ok and not rollback_ok:
+ if self._rollback_exception:
+ raise sa_exc.PendingRollbackError(
+ "This Session's transaction has been rolled back "
+ "due to a previous exception during flush."
+ " To begin a new transaction with this Session, "
+ "first issue Session.rollback()."
+ " Original exception was: %s"
+ % self._rollback_exception,
+ code="7s2a",
+ )
+ elif not deactive_ok:
+ raise sa_exc.InvalidRequestError(
+ "This session is in 'inactive' state, due to the "
+ "SQL transaction being rolled back; no further "
+ "SQL can be emitted within this transaction."
+ )
+ elif self._state is CLOSED:
+ raise sa_exc.ResourceClosedError(closed_msg)
+
+ @property
+ def _is_transaction_boundary(self):
+ return self.nested or not self._parent
+
+ def connection(self, bindkey, execution_options=None, **kwargs):
+ self._assert_active()
+ bind = self.session.get_bind(bindkey, **kwargs)
+ return self._connection_for_bind(bind, execution_options)
+
+ def _begin(self, nested=False):
+ self._assert_active()
+ return SessionTransaction(self.session, self, nested=nested)
+
+ def _iterate_self_and_parents(self, upto=None):
+
+ current = self
+ result = ()
+ while current:
+ result += (current,)
+ if current._parent is upto:
+ break
+ elif current._parent is None:
+ raise sa_exc.InvalidRequestError(
+ "Transaction %s is not on the active transaction list"
+ % (upto)
+ )
+ else:
+ current = current._parent
+
+ return result
+
+ def _take_snapshot(self, autobegin=False):
+ if not self._is_transaction_boundary:
+ self._new = self._parent._new
+ self._deleted = self._parent._deleted
+ self._dirty = self._parent._dirty
+ self._key_switches = self._parent._key_switches
+ return
+
+ if not autobegin and not self.session._flushing:
+ self.session.flush()
+
+ self._new = weakref.WeakKeyDictionary()
+ self._deleted = weakref.WeakKeyDictionary()
+ self._dirty = weakref.WeakKeyDictionary()
+ self._key_switches = weakref.WeakKeyDictionary()
+
+ def _restore_snapshot(self, dirty_only=False):
+ """Restore the restoration state taken before a transaction began.
+
+ Corresponds to a rollback.
+
+ """
+ assert self._is_transaction_boundary
+
+ to_expunge = set(self._new).union(self.session._new)
+ self.session._expunge_states(to_expunge, to_transient=True)
+
+ for s, (oldkey, newkey) in self._key_switches.items():
+ # we probably can do this conditionally based on
+ # if we expunged or not, but safe_discard does that anyway
+ self.session.identity_map.safe_discard(s)
+
+ # restore the old key
+ s.key = oldkey
+
+ # now restore the object, but only if we didn't expunge
+ if s not in to_expunge:
+ self.session.identity_map.replace(s)
+
+ for s in set(self._deleted).union(self.session._deleted):
+ self.session._update_impl(s, revert_deletion=True)
+
+ assert not self.session._deleted
+
+ for s in self.session.identity_map.all_states():
+ if not dirty_only or s.modified or s in self._dirty:
+ s._expire(s.dict, self.session.identity_map._modified)
+
+ def _remove_snapshot(self):
+ """Remove the restoration state taken before a transaction began.
+
+ Corresponds to a commit.
+
+ """
+ assert self._is_transaction_boundary
+
+ if not self.nested and self.session.expire_on_commit:
+ for s in self.session.identity_map.all_states():
+ s._expire(s.dict, self.session.identity_map._modified)
+
+ statelib.InstanceState._detach_states(
+ list(self._deleted), self.session
+ )
+ self._deleted.clear()
+ elif self.nested:
+ self._parent._new.update(self._new)
+ self._parent._dirty.update(self._dirty)
+ self._parent._deleted.update(self._deleted)
+ self._parent._key_switches.update(self._key_switches)
+
+ def _connection_for_bind(self, bind, execution_options):
+ self._assert_active()
+
+ if bind in self._connections:
+ if execution_options:
+ util.warn(
+ "Connection is already established for the "
+ "given bind; execution_options ignored"
+ )
+ return self._connections[bind][0]
+
+ local_connect = False
+ should_commit = True
+
+ if self._parent:
+ conn = self._parent._connection_for_bind(bind, execution_options)
+ if not self.nested:
+ return conn
+ else:
+ if isinstance(bind, engine.Connection):
+ conn = bind
+ if conn.engine in self._connections:
+ raise sa_exc.InvalidRequestError(
+ "Session already has a Connection associated for the "
+ "given Connection's Engine"
+ )
+ else:
+ conn = bind.connect()
+ local_connect = True
+
+ try:
+ if execution_options:
+ conn = conn.execution_options(**execution_options)
+
+ if self.session.twophase and self._parent is None:
+ transaction = conn.begin_twophase()
+ elif self.nested:
+ transaction = conn.begin_nested()
+ elif conn.in_transaction():
+ # if given a future connection already in a transaction, don't
+ # commit that transaction unless it is a savepoint
+ if conn.in_nested_transaction():
+ transaction = conn.get_nested_transaction()
+ else:
+ transaction = conn.get_transaction()
+ should_commit = False
+ else:
+ transaction = conn.begin()
+ except:
+            # connection will not be associated with this Session;
+ # close it immediately so that it isn't closed under GC
+ if local_connect:
+ conn.close()
+ raise
+ else:
+ bind_is_connection = isinstance(bind, engine.Connection)
+
+ self._connections[conn] = self._connections[conn.engine] = (
+ conn,
+ transaction,
+ should_commit,
+ not bind_is_connection,
+ )
+ self.session.dispatch.after_begin(self.session, self, conn)
+ return conn
+
+ def prepare(self):
+ if self._parent is not None or not self.session.twophase:
+ raise sa_exc.InvalidRequestError(
+ "'twophase' mode not enabled, or not root transaction; "
+ "can't prepare."
+ )
+ self._prepare_impl()
+
+ def _prepare_impl(self):
+ self._assert_active()
+ if self._parent is None or self.nested:
+ self.session.dispatch.before_commit(self.session)
+
+ stx = self.session._transaction
+ if stx is not self:
+ for subtransaction in stx._iterate_self_and_parents(upto=self):
+ subtransaction.commit()
+
+ if not self.session._flushing:
+ for _flush_guard in range(100):
+ if self.session._is_clean():
+ break
+ self.session.flush()
+ else:
+ raise exc.FlushError(
+ "Over 100 subsequent flushes have occurred within "
+ "session.commit() - is an after_flush() hook "
+ "creating new objects?"
+ )
+
+ if self._parent is None and self.session.twophase:
+ try:
+ for t in set(self._connections.values()):
+ t[1].prepare()
+ except:
+ with util.safe_reraise():
+ self.rollback()
+
+ self._state = PREPARED
+
+ def commit(self, _to_root=False):
+ self._assert_active(prepared_ok=True)
+ if self._state is not PREPARED:
+ self._prepare_impl()
+
+ if self._parent is None or self.nested:
+ for conn, trans, should_commit, autoclose in set(
+ self._connections.values()
+ ):
+ if should_commit:
+ trans.commit()
+
+ self._state = COMMITTED
+ self.session.dispatch.after_commit(self.session)
+
+ self._remove_snapshot()
+
+ self.close()
+
+ if _to_root and self._parent:
+ return self._parent.commit(_to_root=True)
+
+ return self._parent
+
+ def rollback(self, _capture_exception=False, _to_root=False):
+ self._assert_active(prepared_ok=True, rollback_ok=True)
+
+ stx = self.session._transaction
+ if stx is not self:
+ for subtransaction in stx._iterate_self_and_parents(upto=self):
+ subtransaction.close()
+
+ boundary = self
+ rollback_err = None
+ if self._state in (ACTIVE, PREPARED):
+ for transaction in self._iterate_self_and_parents():
+ if transaction._parent is None or transaction.nested:
+ try:
+ for t in set(transaction._connections.values()):
+ t[1].rollback()
+
+ transaction._state = DEACTIVE
+ self.session.dispatch.after_rollback(self.session)
+ except:
+ rollback_err = sys.exc_info()
+ finally:
+ transaction._state = DEACTIVE
+ transaction._restore_snapshot(
+ dirty_only=transaction.nested
+ )
+ boundary = transaction
+ break
+ else:
+ transaction._state = DEACTIVE
+
+ sess = self.session
+
+ if not rollback_err and not sess._is_clean():
+
+ # if items were added, deleted, or mutated
+ # here, we need to re-restore the snapshot
+ util.warn(
+ "Session's state has been changed on "
+ "a non-active transaction - this state "
+ "will be discarded."
+ )
+ boundary._restore_snapshot(dirty_only=boundary.nested)
+
+ self.close()
+
+ if self._parent and _capture_exception:
+ self._parent._rollback_exception = sys.exc_info()[1]
+
+ if rollback_err:
+ util.raise_(rollback_err[1], with_traceback=rollback_err[2])
+
+ sess.dispatch.after_soft_rollback(sess, self)
+
+ if _to_root and self._parent:
+ return self._parent.rollback(_to_root=True)
+ return self._parent
+
+ def close(self, invalidate=False):
+ if self.nested:
+ self.session._nested_transaction = (
+ self._previous_nested_transaction
+ )
+
+ self.session._transaction = self._parent
+
+ if self._parent is None:
+ for connection, transaction, should_commit, autoclose in set(
+ self._connections.values()
+ ):
+ if invalidate:
+ connection.invalidate()
+ if should_commit and transaction.is_active:
+ transaction.close()
+ if autoclose:
+ connection.close()
+
+ self._state = CLOSED
+ self.session.dispatch.after_transaction_end(self.session, self)
+
+ self.session = None
+ self._connections = None
+
+ def _get_subject(self):
+ return self.session
+
+ def _transaction_is_active(self):
+ return self._state is ACTIVE
+
+ def _transaction_is_closed(self):
+ return self._state is CLOSED
+
+ def _rollback_can_be_called(self):
+ return self._state not in (COMMITTED, CLOSED)
+
+
+class Session(_SessionClassMethods):
+ """Manages persistence operations for ORM-mapped objects.
+
+ The Session's usage paradigm is described at :doc:`/orm/session`.
+
+
+ """
+
+ _is_asyncio = False
+
+ @util.deprecated_params(
+ autocommit=(
+ "2.0",
+ "The :paramref:`.Session.autocommit` parameter is deprecated "
+ "and will be removed in SQLAlchemy version 2.0. The "
+ ':class:`_orm.Session` now features "autobegin" behavior '
+ "such that the :meth:`.Session.begin` method may be called "
+            "if a transaction has not yet been started. See the section "
+ ":ref:`session_explicit_begin` for background.",
+ ),
+ )
+ def __init__(
+ self,
+ bind=None,
+ autoflush=True,
+ future=False,
+ expire_on_commit=True,
+ autocommit=False,
+ twophase=False,
+ binds=None,
+ enable_baked_queries=True,
+ info=None,
+ query_cls=None,
+ ):
+ r"""Construct a new Session.
+
+ See also the :class:`.sessionmaker` function which is used to
+ generate a :class:`.Session`-producing callable with a given
+ set of arguments.
+
+ :param autocommit:
+ Defaults to ``False``. When ``True``, the
+ :class:`.Session` does not automatically begin transactions for
+ individual statement executions, will acquire connections from the
+ engine on an as-needed basis, releasing to the connection pool
+ after each statement. Flushes will begin and commit (or possibly
+ rollback) their own transaction if no transaction is present.
+ When using this mode, the
+ :meth:`.Session.begin` method may be used to explicitly start
+ transactions, but the usual "autobegin" behavior is not present.
+
+ :param autoflush: When ``True``, all query operations will issue a
+ :meth:`~.Session.flush` call to this ``Session`` before proceeding.
+ This is a convenience feature so that :meth:`~.Session.flush` need
+ not be called repeatedly in order for database queries to retrieve
+ results. It's typical that ``autoflush`` is used in conjunction
+ with ``autocommit=False``. In this scenario, explicit calls to
+ :meth:`~.Session.flush` are rarely needed; you usually only need to
+ call :meth:`~.Session.commit` (which flushes) to finalize changes.
+
+ .. seealso::
+
+ :ref:`session_flushing` - additional background on autoflush
+
+ :param bind: An optional :class:`_engine.Engine` or
+ :class:`_engine.Connection` to
+ which this ``Session`` should be bound. When specified, all SQL
+ operations performed by this session will execute via this
+ connectable.
+
+ :param binds: A dictionary which may specify any number of
+ :class:`_engine.Engine` or :class:`_engine.Connection`
+ objects as the source of
+ connectivity for SQL operations on a per-entity basis. The keys
+ of the dictionary consist of any series of mapped classes,
+ arbitrary Python classes that are bases for mapped classes,
+ :class:`_schema.Table` objects and :class:`_orm.Mapper` objects.
+ The
+ values of the dictionary are then instances of
+ :class:`_engine.Engine`
+ or less commonly :class:`_engine.Connection` objects.
+ Operations which
+ proceed relative to a particular mapped class will consult this
+ dictionary for the closest matching entity in order to determine
+ which :class:`_engine.Engine` should be used for a particular SQL
+ operation. The complete heuristics for resolution are
+ described at :meth:`.Session.get_bind`. Usage looks like::
+
+ Session = sessionmaker(binds={
+ SomeMappedClass: create_engine('postgresql://engine1'),
+ SomeDeclarativeBase: create_engine('postgresql://engine2'),
+ some_mapper: create_engine('postgresql://engine3'),
+ some_table: create_engine('postgresql://engine4'),
+ })
+
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :meth:`.Session.bind_mapper`
+
+ :meth:`.Session.bind_table`
+
+ :meth:`.Session.get_bind`
+
+
+ :param \class_: Specify an alternate class other than
+ ``sqlalchemy.orm.session.Session`` which should be used by the
+ returned class. This is the only argument that is local to the
+ :class:`.sessionmaker` function, and is not sent directly to the
+ constructor for ``Session``.
+
+ :param enable_baked_queries: defaults to ``True``. A flag consumed
+ by the :mod:`sqlalchemy.ext.baked` extension to determine if
+ "baked queries" should be cached, as is the normal operation
+ of this extension. When set to ``False``, caching as used by
+ this particular extension is disabled.
+
+ .. versionchanged:: 1.4 The ``sqlalchemy.ext.baked`` extension is
+ legacy and is not used by any of SQLAlchemy's internals. This
+ flag therefore only affects applications that are making explicit
+ use of this extension within their own code.
+
+ :param expire_on_commit: Defaults to ``True``. When ``True``, all
+ instances will be fully expired after each :meth:`~.commit`,
+ so that all attribute/object access subsequent to a completed
+ transaction will load from the most recent database state.
+
+ .. seealso::
+
+ :ref:`session_committing`
+
+ :param future: if True, use 2.0 style transactional and engine
+ behavior. Future mode includes the following behaviors:
+
+ * The :class:`_orm.Session` will not use "bound" metadata in order
+ to locate an :class:`_engine.Engine`; the engine or engines in use
+ must be specified to the constructor of :class:`_orm.Session` or
+ otherwise be configured against the :class:`_orm.sessionmaker`
+ in use
+
+ * The "subtransactions" feature of :meth:`_orm.Session.begin` is
+ removed in version 2.0 and is disabled when the future flag is
+ set.
+
+ * The behavior of the :paramref:`_orm.relationship.cascade_backrefs`
+ flag on a :func:`_orm.relationship` will always assume
+ "False" behavior.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`migration_20_toplevel`
+
+ :param info: optional dictionary of arbitrary data to be associated
+ with this :class:`.Session`. Is available via the
+ :attr:`.Session.info` attribute. Note the dictionary is copied at
+ construction time so that modifications to the per-
+ :class:`.Session` dictionary will be local to that
+ :class:`.Session`.
+
+ :param query_cls: Class which should be used to create new Query
+ objects, as returned by the :meth:`~.Session.query` method.
+ Defaults to :class:`_query.Query`.
+
+ :param twophase: When ``True``, all transactions will be started as
+ a "two phase" transaction, i.e. using the "two phase" semantics
+ of the database in use along with an XID. During a
+ :meth:`~.commit`, after :meth:`~.flush` has been issued for all
+ attached databases, the :meth:`~.TwoPhaseTransaction.prepare`
+ method on each database's :class:`.TwoPhaseTransaction` will be
+ called. This allows each database to roll back the entire
+ transaction, before each transaction is committed.
+
+ """
+ self.identity_map = identity.WeakInstanceDict()
+
+ self._new = {} # InstanceState->object, strong refs object
+ self._deleted = {} # same
+ self.bind = bind
+ self.__binds = {}
+ self._flushing = False
+ self._warn_on_events = False
+ self._transaction = None
+ self._nested_transaction = None
+ self.future = future
+ self.hash_key = _new_sessionid()
+ self.autoflush = autoflush
+ self.expire_on_commit = expire_on_commit
+ self.enable_baked_queries = enable_baked_queries
+
+ if autocommit:
+ if future:
+ raise sa_exc.ArgumentError(
+ "Cannot use autocommit mode with future=True."
+ )
+ self.autocommit = True
+ else:
+ self.autocommit = False
+
+ self.twophase = twophase
+ self._query_cls = query_cls if query_cls else query.Query
+ if info:
+ self.info.update(info)
+
+ if binds is not None:
+ for key, bind in binds.items():
+ self._add_bind(key, bind)
+
+ _sessions[self.hash_key] = self
+
+ # used by sqlalchemy.engine.util.TransactionalContext
+ _trans_context_manager = None
+
+ connection_callable = None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ self.close()
+
+ @util.contextmanager
+ def _maker_context_manager(self):
+ with self:
+ with self.begin():
+ yield self
+
+ @property
+ @util.deprecated_20(
+ ":attr:`_orm.Session.transaction`",
+ alternative="For context manager use, use "
+ ":meth:`_orm.Session.begin`. To access "
+ "the current root transaction, use "
+ ":meth:`_orm.Session.get_transaction`.",
+ warn_on_attribute_access=True,
+ )
+ def transaction(self):
+ """The current active or inactive :class:`.SessionTransaction`.
+
+ May be None if no transaction has begun yet.
+
+ .. versionchanged:: 1.4 the :attr:`.Session.transaction` attribute
+ is now a read-only descriptor that also may return None if no
+ transaction has begun yet.
+
+
+ """
+ return self._legacy_transaction()
+
+ def _legacy_transaction(self):
+ if not self.future:
+ self._autobegin()
+ return self._transaction
+
+ def in_transaction(self):
+ """Return True if this :class:`_orm.Session` has begun a transaction.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_orm.Session.is_active`
+
+
+ """
+ return self._transaction is not None
+
+ def in_nested_transaction(self):
+ """Return True if this :class:`_orm.Session` has begun a nested
+ transaction, e.g. SAVEPOINT.
+
+ .. versionadded:: 1.4
+
+ """
+ return self._nested_transaction is not None
+
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+ trans = self._transaction
+ while trans is not None and trans._parent is not None:
+ trans = trans._parent
+ return trans
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ .. versionadded:: 1.4
+
+ """
+
+ return self._nested_transaction
+
+ @util.memoized_property
+ def info(self):
+ """A user-modifiable dictionary.
+
+ The initial value of this dictionary can be populated using the
+ ``info`` argument to the :class:`.Session` constructor or
+ :class:`.sessionmaker` constructor or factory methods. The dictionary
+ here is always local to this :class:`.Session` and can be modified
+ independently of all other :class:`.Session` objects.
+
+ """
+ return {}
+
+ def _autobegin(self):
+ if not self.autocommit and self._transaction is None:
+
+ trans = SessionTransaction(self, autobegin=True)
+ assert self._transaction is trans
+ return True
+
+ return False
+
+ @util.deprecated_params(
+ subtransactions=(
+ "2.0",
+ "The :paramref:`_orm.Session.begin.subtransactions` flag is "
+ "deprecated and "
+ "will be removed in SQLAlchemy version 2.0. See "
+ "the documentation at :ref:`session_subtransactions` for "
+ "background on a compatible alternative pattern.",
+ )
+ )
+ def begin(self, subtransactions=False, nested=False, _subtrans=False):
+ """Begin a transaction, or nested transaction,
+ on this :class:`.Session`, if one is not already begun.
+
+ The :class:`_orm.Session` object features **autobegin** behavior,
+ so that normally it is not necessary to call the
+ :meth:`_orm.Session.begin`
+ method explicitly. However, it may be used in order to control
+ the scope of when the transactional state is begun.
+
+ When used to begin the outermost transaction, an error is raised
+ if this :class:`.Session` is already inside of a transaction.
+
+ :param nested: if True, begins a SAVEPOINT transaction and is
+ equivalent to calling :meth:`~.Session.begin_nested`. For
+ documentation on SAVEPOINT transactions, please see
+ :ref:`session_begin_nested`.
+
+ :param subtransactions: if True, indicates that this
+ :meth:`~.Session.begin` can create a "subtransaction".
+
+ :return: the :class:`.SessionTransaction` object. Note that
+ :class:`.SessionTransaction`
+ acts as a Python context manager, allowing :meth:`.Session.begin`
+ to be used in a "with" block. See :ref:`session_autocommit` for
+ an example.
+
+ .. seealso::
+
+ :ref:`session_autobegin`
+
+ :ref:`unitofwork_transaction`
+
+ :meth:`.Session.begin_nested`
+
+
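+ A usage sketch as a context manager, assuming a configured
+ ``session`` and a mapped ``some_object`` (illustrative names)::
+
+ with session.begin():
+ session.add(some_object)
+ # commits on successful exit, rolls back on exception
+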
+ """
+
+ if subtransactions and self.future:
+ raise NotImplementedError(
+ "subtransactions are not implemented in future "
+ "Session objects."
+ )
+
+ if self._autobegin():
+ if not subtransactions and not nested and not _subtrans:
+ return self._transaction
+
+ if self._transaction is not None:
+ if subtransactions or _subtrans or nested:
+ trans = self._transaction._begin(nested=nested)
+ assert self._transaction is trans
+ if nested:
+ self._nested_transaction = trans
+ else:
+ raise sa_exc.InvalidRequestError(
+ "A transaction is already begun on this Session."
+ )
+ elif not self.autocommit:
+ # outermost transaction. must be a not nested and not
+ # a subtransaction
+
+ assert not nested and not _subtrans and not subtransactions
+ trans = SessionTransaction(self)
+ assert self._transaction is trans
+ else:
+ # legacy autocommit mode
+ assert not self.future
+ trans = SessionTransaction(self, nested=nested)
+ assert self._transaction is trans
+
+ return self._transaction # needed for __enter__/__exit__ hook
+
+ def begin_nested(self):
+ """Begin a "nested" transaction on this Session, e.g. SAVEPOINT.
+
+ The target database(s) and associated drivers must support SQL
+ SAVEPOINT for this method to function correctly.
+
+ For documentation on SAVEPOINT
+ transactions, please see :ref:`session_begin_nested`.
+
+ :return: the :class:`.SessionTransaction` object. Note that
+ :class:`.SessionTransaction` acts as a context manager, allowing
+ :meth:`.Session.begin_nested` to be used in a "with" block.
+ See :ref:`session_begin_nested` for a usage example.
+
+ .. seealso::
+
+ :ref:`session_begin_nested`
+
+ :ref:`pysqlite_serializable` - special workarounds required
+ with the SQLite driver in order for SAVEPOINT to work
+ correctly.
+
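+ A brief SAVEPOINT sketch; ``u1`` and ``u2`` are illustrative
+ mapped objects::
+
+ session.add(u1)
+ with session.begin_nested():
+ session.add(u2)
+ # an exception inside the block rolls back to the SAVEPOINT
+ # only; u1 remains part of the enclosing transaction
+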
+ """
+ return self.begin(nested=True)
+
+ def rollback(self):
+ """Rollback the current transaction in progress.
+
+ If no transaction is in progress, this method is a pass-through.
+
+ In :term:`1.x-style` use, this method rolls back the topmost
+ database transaction if no nested transactions are in effect, or
+ rolls back to the current nested transaction if one is in effect.
+
+ When
+ :term:`2.0-style` use is in effect via the
+ :paramref:`_orm.Session.future` flag, the method always rolls back
+ the topmost database transaction, discarding any nested
+ transactions that may be in progress.
+
+ .. seealso::
+
+ :ref:`session_rollback`
+
+ :ref:`unitofwork_transaction`
+
+ """
+ if self._transaction is None:
+ pass
+ else:
+ self._transaction.rollback(_to_root=self.future)
+
+ def commit(self):
+ """Flush pending changes and commit the current transaction.
+
+ When the COMMIT operation is complete, all objects are fully
+ :term:`expired`, erasing their internal contents, which will be
+ automatically re-loaded when the objects are next accessed. In the
+ interim, these objects are in an expired state and will not function if
+ they are :term:`detached` from the :class:`.Session`. Additionally,
+ this re-load operation is not supported when using asyncio-oriented
+ APIs. The :paramref:`.Session.expire_on_commit` parameter may be used
+ to disable this behavior.
+
+ When there is no transaction in place for the :class:`.Session`,
+ indicating that no operations were invoked on this :class:`.Session`
+ since the previous call to :meth:`.Session.commit`, the method will
+ begin and commit an internal-only "logical" transaction that does not
+ normally affect the database unless pending flush changes were
+ detected, but will still invoke event handlers and object expiration
+ rules.
+
+ If :term:`1.x-style` use is in effect and there are currently
+ SAVEPOINTs in progress via :meth:`_orm.Session.begin_nested`,
+ the operation will release the current SAVEPOINT but not commit
+ the outermost database transaction.
+
+ If :term:`2.0-style` use is in effect via the
+ :paramref:`_orm.Session.future` flag, the outermost database
+ transaction is committed unconditionally, automatically releasing any
+ SAVEPOINTs in effect.
+
+ When using legacy "autocommit" mode, this method is only
+ valid to call if a transaction is actually in progress, else
+ an error is raised. Similarly, when using legacy "subtransactions",
+ the method will instead close out the current "subtransaction",
+ rather than the actual database transaction, if a transaction
+ is in progress.
+
+ .. seealso::
+
+ :ref:`session_committing`
+
+ :ref:`unitofwork_transaction`
+
+ :ref:`asyncio_orm_avoid_lazyloads`
+
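+ A minimal flush-and-commit sketch, assuming a mapped ``User``
+ class::
+
+ session.add(User(name="some name"))
+ session.commit() # flushes pending changes, then commits
+ # with the default expire_on_commit=True, subsequent attribute
+ # access re-loads from the database
+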
+ """
+ if self._transaction is None:
+ if not self._autobegin():
+ raise sa_exc.InvalidRequestError("No transaction is begun.")
+
+ self._transaction.commit(_to_root=self.future)
+
+ def prepare(self):
+ """Prepare the current transaction in progress for two phase commit.
+
+ If no transaction is in progress, this method raises an
+ :exc:`~sqlalchemy.exc.InvalidRequestError`.
+
+ Only root transactions of two phase sessions can be prepared. If the
+ current transaction is not such, an
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if self._transaction is None:
+ if not self._autobegin():
+ raise sa_exc.InvalidRequestError("No transaction is begun.")
+
+ self._transaction.prepare()
+
+ def connection(
+ self,
+ bind_arguments=None,
+ close_with_result=False,
+ execution_options=None,
+ **kw
+ ):
+ r"""Return a :class:`_engine.Connection` object corresponding to this
+ :class:`.Session` object's transactional state.
+
+ If this :class:`.Session` is configured with ``autocommit=False``,
+ either the :class:`_engine.Connection` corresponding to the current
+ transaction is returned, or if no transaction is in progress, a new
+ one is begun and the :class:`_engine.Connection`
+ returned (note that no
+ transactional state is established with the DBAPI until the first
+ SQL statement is emitted).
+
+ Alternatively, if this :class:`.Session` is configured with
+ ``autocommit=True``, an ad-hoc :class:`_engine.Connection` is returned
+ using :meth:`_engine.Engine.connect` on the underlying
+ :class:`_engine.Engine`.
+
+ Ambiguity in multi-bind or unbound :class:`.Session` objects can be
+ resolved through any of the optional keyword arguments. This
+ ultimately makes use of the :meth:`.get_bind` method for resolution.
+
+ :param bind_arguments: dictionary of bind arguments. May include
+ "mapper", "bind", "clause", other custom arguments that are passed
+ to :meth:`.Session.get_bind`.
+
+ :param bind:
+ deprecated; use bind_arguments
+
+ :param mapper:
+ deprecated; use bind_arguments
+
+ :param clause:
+ deprecated; use bind_arguments
+
+ :param close_with_result: Passed to :meth:`_engine.Engine.connect`,
+ indicating the :class:`_engine.Connection` should be considered
+ "single use", automatically closing when the first result set is
+ closed. This flag only has an effect if this :class:`.Session` is
+ configured with ``autocommit=True`` and does not already have a
+ transaction in progress.
+
+ .. deprecated:: 1.4 this parameter is deprecated and will be removed
+ in SQLAlchemy 2.0
+
+ :param execution_options: a dictionary of execution options that will
+ be passed to :meth:`_engine.Connection.execution_options`, **when the
+ connection is first procured only**. If the connection is already
+ present within the :class:`.Session`, a warning is emitted and
+ the arguments are ignored.
+
+ .. seealso::
+
+ :ref:`session_transaction_isolation`
+
+ :param \**kw:
+ deprecated; use bind_arguments
+
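+ A sketch of procuring the transaction's connection for a raw SQL
+ statement, using :func:`_expression.text`::
+
+ from sqlalchemy import text
+
+ connection = session.connection()
+ connection.execute(text("SELECT 1"))
+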
+ """
+
+ if not bind_arguments:
+ bind_arguments = kw
+
+ bind = bind_arguments.pop("bind", None)
+ if bind is None:
+ bind = self.get_bind(**bind_arguments)
+
+ return self._connection_for_bind(
+ bind,
+ close_with_result=close_with_result,
+ execution_options=execution_options,
+ )
+
+ def _connection_for_bind(self, engine, execution_options=None, **kw):
+ TransactionalContext._trans_ctx_check(self)
+
+ if self._transaction is not None or self._autobegin():
+ return self._transaction._connection_for_bind(
+ engine, execution_options
+ )
+
+ assert self._transaction is None
+ assert self.autocommit
+ conn = engine.connect(**kw)
+ if execution_options:
+ conn = conn.execution_options(**execution_options)
+ return conn
+
+ def execute(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ _parent_execute_state=None,
+ _add_event=None,
+ **kw
+ ):
+ r"""Execute a SQL expression construct.
+
+ Returns a :class:`_engine.Result` object representing
+ results of the statement execution.
+
+ E.g.::
+
+ from sqlalchemy import select
+ result = session.execute(
+ select(User).where(User.id == 5)
+ )
+
+ The API contract of :meth:`_orm.Session.execute` is similar to that
+ of :meth:`_future.Connection.execute`, the :term:`2.0 style` version
+ of :class:`_future.Connection`.
+
+ .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is
+ now the primary point of ORM statement execution when using
+ :term:`2.0 style` ORM usage.
+
+ :param statement:
+ An executable statement (i.e. an :class:`.Executable` expression
+ such as :func:`_expression.select`).
+
+ :param params:
+ Optional dictionary, or list of dictionaries, containing
+ bound parameter values. If a single dictionary, single-row
+ execution occurs; if a list of dictionaries, an
+ "executemany" will be invoked. The keys in each dictionary
+ must correspond to parameter names present in the statement.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_engine.Connection.execution_options`, and may also
+ provide additional options understood only in an ORM context.
+
+ :param bind_arguments: dictionary of additional arguments to determine
+ the bind. May include "mapper", "bind", or other custom arguments.
+ Contents of this dictionary are passed to the
+ :meth:`.Session.get_bind` method.
+
+ :param mapper:
+ deprecated; use the bind_arguments dictionary
+
+ :param bind:
+ deprecated; use the bind_arguments dictionary
+
+ :param \**kw:
+ deprecated; use the bind_arguments dictionary
+
+ :return: a :class:`_engine.Result` object.
+
+
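+ A sketch of the two ``params`` forms, assuming a mapped ``User``
+ class::
+
+ from sqlalchemy import insert
+
+ # single dictionary: single-row execution
+ session.execute(insert(User), {"name": "n1"})
+
+ # list of dictionaries: "executemany"
+ session.execute(
+ insert(User),
+ [{"name": "n2"}, {"name": "n3"}],
+ )
+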
+ """
+ statement = coercions.expect(roles.StatementRole, statement)
+
+ if kw:
+ util.warn_deprecated_20(
+ "Passing bind arguments to Session.execute() as keyword "
+ "arguments is deprecated and will be removed SQLAlchemy 2.0. "
+ "Please use the bind_arguments parameter."
+ )
+ if not bind_arguments:
+ bind_arguments = kw
+ else:
+ bind_arguments.update(kw)
+ elif not bind_arguments:
+ bind_arguments = {}
+
+ if (
+ statement._propagate_attrs.get("compile_state_plugin", None)
+ == "orm"
+ ):
+ # note that even without "future" mode, we need the ORM-level
+ # compile state plugin for ORM-enabled statements
+ compile_state_cls = CompileState._get_plugin_class_for_plugin(
+ statement, "orm"
+ )
+ else:
+ compile_state_cls = None
+
+ execution_options = util.coerce_to_immutabledict(execution_options)
+
+ if compile_state_cls is not None:
+ (
+ statement,
+ execution_options,
+ ) = compile_state_cls.orm_pre_session_exec(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ _parent_execute_state is not None,
+ )
+ else:
+ bind_arguments.setdefault("clause", statement)
+ execution_options = execution_options.union(
+ {"future_result": True}
+ )
+
+ if _parent_execute_state:
+ events_todo = _parent_execute_state._remaining_events()
+ else:
+ events_todo = self.dispatch.do_orm_execute
+ if _add_event:
+ events_todo = list(events_todo) + [_add_event]
+
+ if events_todo:
+ orm_exec_state = ORMExecuteState(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
+ events_todo,
+ )
+ for idx, fn in enumerate(events_todo):
+ orm_exec_state._starting_event_idx = idx
+ result = fn(orm_exec_state)
+ if result:
+ return result
+
+ statement = orm_exec_state.statement
+ execution_options = orm_exec_state.local_execution_options
+
+ bind = self.get_bind(**bind_arguments)
+
+ if self.autocommit:
+ # legacy stuff, we can't use future_result w/ autocommit because
+ # we rely upon close_with_result, also legacy. it's all
+ # interrelated
+ conn = self._connection_for_bind(bind, close_with_result=True)
+ execution_options = execution_options.union(
+ dict(future_result=False)
+ )
+ else:
+ conn = self._connection_for_bind(bind)
+ result = conn._execute_20(statement, params or {}, execution_options)
+
+ if compile_state_cls:
+ result = compile_state_cls.orm_setup_cursor_result(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ return result
+
+ def scalar(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a scalar result.
+
+ Usage and parameters are the same as that of
+ :meth:`_orm.Session.execute`; the return result is a scalar Python
+ value.
+
+ """
+
+ return self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ ).scalar()
+
+ def scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return the results as scalars.
+
+ Usage and parameters are the same as that of
+ :meth:`_orm.Session.execute`; the return result is a
+ :class:`_result.ScalarResult` filtering object which
+ will return single elements rather than :class:`_row.Row` objects.
+
+ :return: a :class:`_result.ScalarResult` object
+
+ .. versionadded:: 1.4.24
+
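+ A usage sketch, assuming a mapped ``User`` class::
+
+ from sqlalchemy import select
+
+ user_objects = session.scalars(select(User)).all()
+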
+ """
+
+ return self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ ).scalars()
+
+ def close(self):
+ """Close out the transactional resources and ORM objects used by this
+ :class:`_orm.Session`.
+
+ This expunges all ORM objects associated with this
+ :class:`_orm.Session`, ends any transaction in progress and
+ :term:`releases` any :class:`_engine.Connection` objects which this
+ :class:`_orm.Session` itself has checked out from associated
+ :class:`_engine.Engine` objects. The operation then leaves the
+ :class:`_orm.Session` in a state in which it may be used again.
+
+ .. tip::
+
+ The :meth:`_orm.Session.close` method **does not prevent the
+ Session from being used again**. The :class:`_orm.Session` itself
+ does not actually have a distinct "closed" state; it merely means
+ the :class:`_orm.Session` will release all database connections
+ and ORM objects.
+
+ .. versionchanged:: 1.4 The :meth:`.Session.close` method does not
+ immediately create a new :class:`.SessionTransaction` object;
+ instead, the new :class:`.SessionTransaction` is created only if
+ the :class:`.Session` is used again for a database operation.
+
+ .. seealso::
+
+ :ref:`session_closing` - detail on the semantics of
+ :meth:`_orm.Session.close`
+
+ """
+ self._close_impl(invalidate=False)
+
+ def invalidate(self):
+ """Close this Session, using connection invalidation.
+
+ This is a variant of :meth:`.Session.close` that will additionally
+ ensure that the :meth:`_engine.Connection.invalidate`
+ method will be called on each :class:`_engine.Connection` object
+ that is currently in use for a transaction (typically there is only
+ one connection unless the :class:`_orm.Session` is used with
+ multiple engines).
+
+ This can be called when the database is known to be in a state where
+ the connections are no longer safe to be used.
+
+ Below illustrates a scenario when using `gevent
+ <https://www.gevent.org/>`_, which can produce ``Timeout`` exceptions
+ that may mean the underlying connection should be discarded::
+
+ import gevent
+
+ try:
+ sess = Session()
+ sess.add(User())
+ sess.commit()
+ except gevent.Timeout:
+ sess.invalidate()
+ raise
+ except:
+ sess.rollback()
+ raise
+
+ The method additionally does everything that :meth:`_orm.Session.close`
+ does, including that all ORM objects are expunged.
+
+ """
+ self._close_impl(invalidate=True)
+
+ def _close_impl(self, invalidate):
+ self.expunge_all()
+ if self._transaction is not None:
+ for transaction in self._transaction._iterate_self_and_parents():
+ transaction.close(invalidate)
+
+ def expunge_all(self):
+ """Remove all object instances from this ``Session``.
+
+ This is equivalent to calling ``expunge(obj)`` on all objects in this
+ ``Session``.
+
+ """
+
+ all_states = self.identity_map.all_states() + list(self._new)
+ self.identity_map._kill()
+ self.identity_map = identity.WeakInstanceDict()
+ self._new = {}
+ self._deleted = {}
+
+ statelib.InstanceState._detach_states(all_states, self)
+
+ def _add_bind(self, key, bind):
+ try:
+ insp = inspect(key)
+ except sa_exc.NoInspectionAvailable as err:
+ if not isinstance(key, type):
+ util.raise_(
+ sa_exc.ArgumentError(
+ "Not an acceptable bind target: %s" % key
+ ),
+ replace_context=err,
+ )
+ else:
+ self.__binds[key] = bind
+ else:
+ if insp.is_selectable:
+ self.__binds[insp] = bind
+ elif insp.is_mapper:
+ self.__binds[insp.class_] = bind
+ for _selectable in insp._all_tables:
+ self.__binds[_selectable] = bind
+ else:
+ raise sa_exc.ArgumentError(
+ "Not an acceptable bind target: %s" % key
+ )
+
+ def bind_mapper(self, mapper, bind):
+ """Associate a :class:`_orm.Mapper` or arbitrary Python class with a
+ "bind", e.g. an :class:`_engine.Engine` or
+ :class:`_engine.Connection`.
+
+ The given entity is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
+
+ :param mapper: a :class:`_orm.Mapper` object,
+ or an instance of a mapped
+ class, or any Python class that is the base of a set of mapped
+ classes.
+
+ :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection`
+ object.
+
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :paramref:`.Session.binds`
+
+ :meth:`.Session.bind_table`
+
+
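+ A partitioning sketch; the two engines and the mapped ``User``
+ and ``Account`` classes are illustrative::
+
+ session = Session()
+ session.bind_mapper(User, user_engine)
+ session.bind_mapper(Account, account_engine)
+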
+ """
+ self._add_bind(mapper, bind)
+
+ def bind_table(self, table, bind):
+ """Associate a :class:`_schema.Table` with a "bind", e.g. an
+ :class:`_engine.Engine`
+ or :class:`_engine.Connection`.
+
+ The given :class:`_schema.Table` is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
+
+ :param table: a :class:`_schema.Table` object,
+ which is typically the target
+ of an ORM mapping, or is present within a selectable that is
+ mapped.
+
+ :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection`
+ object.
+
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :paramref:`.Session.binds`
+
+ :meth:`.Session.bind_mapper`
+
+
+ """
+ self._add_bind(table, bind)
+
+ def get_bind(
+ self,
+ mapper=None,
+ clause=None,
+ bind=None,
+ _sa_skip_events=None,
+ _sa_skip_for_implicit_returning=False,
+ ):
+ """Return a "bind" to which this :class:`.Session` is bound.
+
+ The "bind" is usually an instance of :class:`_engine.Engine`,
+ except in the case where the :class:`.Session` has been
+ explicitly bound directly to a :class:`_engine.Connection`.
+
+ For a multiply-bound or unbound :class:`.Session`, the
+ ``mapper`` or ``clause`` arguments are used to determine the
+ appropriate bind to return.
+
+ Note that the "mapper" argument is usually present
+ when :meth:`.Session.get_bind` is called via an ORM
+ operation such as a :meth:`.Session.query`, each
+ individual INSERT/UPDATE/DELETE operation within a
+ :meth:`.Session.flush` call, etc.
+
+ The order of resolution is:
+
+ 1. if mapper given and :paramref:`.Session.binds` is present,
+ locate a bind based first on the mapper in use, then
+ on the mapped class in use, then on any base classes that are
+ present in the ``__mro__`` of the mapped class, from more specific
+ superclasses to more general.
+ 2. if clause given and ``Session.binds`` is present,
+ locate a bind based on :class:`_schema.Table` objects
+ found in the given clause present in ``Session.binds``.
+ 3. if ``Session.bind`` is present, return that.
+ 4. if clause given, attempt to return a bind
+ linked to the :class:`_schema.MetaData` ultimately
+ associated with the clause.
+ 5. if mapper given, attempt to return a bind
+ linked to the :class:`_schema.MetaData` ultimately
+ associated with the :class:`_schema.Table` or other
+ selectable to which the mapper is mapped.
+ 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError`
+ is raised.
+
+ Note that the :meth:`.Session.get_bind` method can be overridden on
+ a user-defined subclass of :class:`.Session` to provide any kind
+ of bind resolution scheme. See the example at
+ :ref:`session_custom_partitioning`.
+
+ :param mapper:
+ Optional :func:`.mapper` mapped class or instance of
+ :class:`_orm.Mapper`. The bind can be derived from a
+ :class:`_orm.Mapper`
+ first by consulting the "binds" map associated with this
+ :class:`.Session`, and secondly by consulting the
+ :class:`_schema.MetaData`
+ associated with the :class:`_schema.Table` to which the
+ :class:`_orm.Mapper`
+ is mapped for a bind.
+
+ :param clause:
+ A :class:`_expression.ClauseElement` (i.e.
+ :func:`_expression.select`,
+ :func:`_expression.text`,
+ etc.). If the ``mapper`` argument is not present or could not
+ produce a bind, the given expression construct will be searched
+ for a bound element, typically a :class:`_schema.Table`
+ associated with
+ bound :class:`_schema.MetaData`.
+
+ .. seealso::
+
+ :ref:`session_partitioning`
+
+ :paramref:`.Session.binds`
+
+ :meth:`.Session.bind_mapper`
+
+ :meth:`.Session.bind_table`
+
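+ A sketch of a custom resolution scheme via subclassing; the
+ ``RoutingSession`` name and the two engines are illustrative::
+
+ class RoutingSession(Session):
+ def get_bind(self, mapper=None, clause=None, **kw):
+ if self._flushing:
+ return writer_engine
+ else:
+ return reader_engine
+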
+ """
+
+ # this function is documented as a subclassing hook, so we have
+ # to call this method even if the return is simple
+ if bind:
+ return bind
+ elif not self.__binds and self.bind:
+ # simplest and most common case, we have a bind and no
+ # per-mapper/table binds, we're done
+ return self.bind
+
+ # we don't have self.bind and either have self.__binds
+ # or we don't have self.__binds (which is legacy). Look at the
+ # mapper and the clause
+ if mapper is clause is None:
+ if self.bind:
+ return self.bind
+ else:
+ raise sa_exc.UnboundExecutionError(
+ "This session is not bound to a single Engine or "
+ "Connection, and no context was provided to locate "
+ "a binding."
+ )
+
+ # look more closely at the mapper.
+ if mapper is not None:
+ try:
+ mapper = inspect(mapper)
+ except sa_exc.NoInspectionAvailable as err:
+ if isinstance(mapper, type):
+ util.raise_(
+ exc.UnmappedClassError(mapper),
+ replace_context=err,
+ )
+ else:
+ raise
+
+ # match up the mapper or clause in the __binds
+ if self.__binds:
+ # matching mappers and selectables to entries in the
+ # binds dictionary; supported use case.
+ if mapper:
+ for cls in mapper.class_.__mro__:
+ if cls in self.__binds:
+ return self.__binds[cls]
+ if clause is None:
+ clause = mapper.persist_selectable
+
+ if clause is not None:
+ plugin_subject = clause._propagate_attrs.get(
+ "plugin_subject", None
+ )
+
+ if plugin_subject is not None:
+ for cls in plugin_subject.mapper.class_.__mro__:
+ if cls in self.__binds:
+ return self.__binds[cls]
+
+ for obj in visitors.iterate(clause):
+ if obj in self.__binds:
+ return self.__binds[obj]
+
+ # none of the __binds matched, but we have a fallback bind.
+ # return that
+ if self.bind:
+ return self.bind
+
+ # now we are in legacy territory. looking for "bind" on tables
+ # that are via bound metadata. this goes away in 2.0.
+
+ future_msg = ""
+ future_code = ""
+
+ if mapper and clause is None:
+ clause = mapper.persist_selectable
+
+ if clause is not None:
+ if clause.bind:
+ if self.future:
+ future_msg = (
+ " A bind was located via legacy bound metadata, but "
+ "since future=True is set on this Session, this "
+ "bind is ignored."
+ )
+ else:
+ util.warn_deprecated_20(
+ "This Session located a target engine via bound "
+ "metadata; as this functionality will be removed in "
+ "SQLAlchemy 2.0, an Engine object should be passed "
+ "to the Session() constructor directly."
+ )
+ return clause.bind
+
+ if mapper:
+ if mapper.persist_selectable.bind:
+ if self.future:
+ future_msg = (
+ " A bind was located via legacy bound metadata, but "
+ "since future=True is set on this Session, this "
+ "bind is ignored."
+ )
+ else:
+ util.warn_deprecated_20(
+ "This Session located a target engine via bound "
+ "metadata; as this functionality will be removed in "
+ "SQLAlchemy 2.0, an Engine object should be passed "
+ "to the Session() constructor directly."
+ )
+ return mapper.persist_selectable.bind
+
+ context = []
+ if mapper is not None:
+ context.append("mapper %s" % mapper)
+ if clause is not None:
+ context.append("SQL expression")
+
+ raise sa_exc.UnboundExecutionError(
+ "Could not locate a bind configured on %s or this Session.%s"
+ % (", ".join(context), future_msg),
+ code=future_code,
+ )
+
+ def query(self, *entities, **kwargs):
+ """Return a new :class:`_query.Query` object corresponding to this
+ :class:`_orm.Session`.
+
+ """
+
+ return self._query_cls(entities, self, **kwargs)
+
+ def _identity_lookup(
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ passive=attributes.PASSIVE_OFF,
+ lazy_loaded_from=None,
+ ):
+ """Locate an object in the identity map.
+
+ Given a primary key identity, constructs an identity key and then
+ looks in the session's identity map. If present, the object may
+ be run through unexpiration rules (e.g. load unloaded attributes,
+ check if it was deleted).
+
+ e.g.::
+
+ obj = session._identity_lookup(inspect(SomeClass), (1, ))
+
+ :param mapper: mapper in use
+ :param primary_key_identity: the primary key we are searching for, as
+ a tuple.
+ :param identity_token: identity token that should be used to create
+ the identity key. Used as is, however overriding subclasses can
+ repurpose this in order to interpret the value in a special way,
+ such as if None then look among multiple target tokens.
+ :param passive: passive load flag passed to
+ :func:`.loading.get_from_identity`, which impacts the behavior if
+ the object is found; the object may be validated and/or unexpired
+ if the flag allows for SQL to be emitted.
+ :param lazy_loaded_from: an :class:`.InstanceState` that is
+ specifically asking for this identity as a related identity. Used
+ for sharding schemes where there is a correspondence between an object
+ and a related object being lazy-loaded (or otherwise
+ relationship-loaded).
+
+ :return: None if the object is not found in the identity map, *or*
+ if the object was unexpired and found to have been deleted.
+ If passive flags disallow SQL and the object is expired, returns
+ PASSIVE_NO_RESULT. In all other cases the instance is returned.
+
+ .. versionchanged:: 1.4.0 - the :meth:`.Session._identity_lookup`
+ method was moved from :class:`_query.Query` to
+ :class:`.Session`, to avoid having to instantiate the
+ :class:`_query.Query` object.
+
+
+ """
+
+ key = mapper.identity_key_from_primary_key(
+ primary_key_identity, identity_token=identity_token
+ )
+ return loading.get_from_identity(self, mapper, key, passive)
+
+ @property
+ @util.contextmanager
+ def no_autoflush(self):
+ """Return a context manager that disables autoflush.
+
+ e.g.::
+
+ with session.no_autoflush:
+
+ some_object = SomeClass()
+ session.add(some_object)
+ # won't autoflush
+ some_object.related_thing = session.query(SomeRelated).first()
+
+ Operations that proceed within the ``with:`` block
+ will not be subject to flushes occurring upon query
+ access. This is useful when initializing a series
+ of objects which involve existing database queries,
+ where the uncompleted object should not yet be flushed.
+
+ """
+ autoflush = self.autoflush
+ self.autoflush = False
+ try:
+ yield self
+ finally:
+ self.autoflush = autoflush
+
+ def _autoflush(self):
+ if self.autoflush and not self._flushing:
+ try:
+ self.flush()
+ except sa_exc.StatementError as e:
+ # note we are reraising StatementError as opposed to
+ # raising FlushError with "chaining" to remain compatible
+ # with code that catches StatementError, IntegrityError,
+ # etc.
+ e.add_detail(
+ "raised as a result of Query-invoked autoflush; "
+ "consider using a session.no_autoflush block if this "
+ "flush is occurring prematurely"
+ )
+ util.raise_(e, with_traceback=sys.exc_info()[2])
+
+ def refresh(self, instance, attribute_names=None, with_for_update=None):
+ """Expire and refresh attributes on the given instance.
+
+ The selected attributes will first be expired as they would when using
+ :meth:`_orm.Session.expire`; then a SELECT statement will be issued to
+ the database to refresh column-oriented attributes with the current
+ value available in the current transaction.
+
+ :func:`_orm.relationship` oriented attributes will also be immediately
+ loaded if they were already eagerly loaded on the object, using the
+ same eager loading strategy that they were loaded with originally.
+ Unloaded relationship attributes will remain unloaded, as will
+ relationship attributes that were originally lazy loaded.
+
+ .. versionadded:: 1.4 - the :meth:`_orm.Session.refresh` method
+ can also refresh eagerly loaded attributes.
+
+ .. tip::
+
+ While the :meth:`_orm.Session.refresh` method is capable of
+ refreshing both column and relationship oriented attributes, its
+ primary focus is on refreshing of local column-oriented attributes
+ on a single instance. For more open ended "refresh" functionality,
+ including the ability to refresh the attributes on many objects at
+ once while having explicit control over relationship loader
+ strategies, use the
+ :ref:`populate existing <orm_queryguide_populate_existing>` feature
+ instead.
+
+ Note that a highly isolated transaction will return the same values as
+ were previously read in that same transaction, regardless of changes
+ in database state outside of that transaction. Refreshing
+ attributes usually only makes sense at the start of a transaction
+ where database rows have not yet been accessed.
+
+ :param attribute_names: optional. An iterable collection of
+ string attribute names indicating a subset of attributes to
+ be refreshed.
+
+ :param with_for_update: optional boolean ``True`` indicating FOR UPDATE
+ should be used, or may be a dictionary containing flags to
+ indicate a more specific set of FOR UPDATE flags for the SELECT;
+ flags should match the parameters of
+ :meth:`_query.Query.with_for_update`.
+ Supersedes the :paramref:`.Session.refresh.lockmode` parameter.
+
+ .. seealso::
+
+ :ref:`session_expire` - introductory material
+
+ :meth:`.Session.expire`
+
+ :meth:`.Session.expire_all`
+
+ :ref:`orm_queryguide_populate_existing` - allows any ORM query
+ to refresh objects as they would be loaded normally.
+
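+ A brief sketch; ``some_object`` and the attribute names are
+ illustrative::
+
+ session.refresh(some_object)
+ session.refresh(some_object, attribute_names=["name", "status"])
+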
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+
+ self._expire_state(state, attribute_names)
+
+ if with_for_update == {}:
+ raise sa_exc.ArgumentError(
+ "with_for_update should be the boolean value "
+ "True, or a dictionary with options. "
+ "A blank dictionary is ambiguous."
+ )
+
+ with_for_update = query.ForUpdateArg._from_argument(with_for_update)
+
+ stmt = sql.select(object_mapper(instance))
+ if (
+ loading.load_on_ident(
+ self,
+ stmt,
+ state.key,
+ refresh_state=state,
+ with_for_update=with_for_update,
+ only_load_props=attribute_names,
+ )
+ is None
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Could not refresh instance '%s'" % instance_str(instance)
+ )
+
+ def expire_all(self):
+ """Expires all persistent instances within this Session.
+
+ When any attributes on a persistent instance are next accessed,
+ a query will be issued using the
+ :class:`.Session` object's current transactional context in order to
+ load all expired attributes for the given instance. Note that
+ a highly isolated transaction will return the same values as were
+ previously read in that same transaction, regardless of changes
+ in database state outside of that transaction.
+
+ To expire individual objects and individual attributes
+ on those objects, use :meth:`Session.expire`.
+
+ The :class:`.Session` object's default behavior is to
+ expire all state whenever the :meth:`Session.rollback`
+ or :meth:`Session.commit` methods are called, so that new
+ state can be loaded for the new transaction. For this reason,
+ calling :meth:`Session.expire_all` should not be needed when
+ autocommit is ``False``, assuming the transaction is isolated.
+
+ .. seealso::
+
+ :ref:`session_expire` - introductory material
+
+ :meth:`.Session.expire`
+
+ :meth:`.Session.refresh`
+
+ :meth:`_orm.Query.populate_existing`
+
+ """
+ for state in self.identity_map.all_states():
+ state._expire(state.dict, self.identity_map._modified)
+
+ def expire(self, instance, attribute_names=None):
+ """Expire the attributes on an instance.
+
+ Marks the attributes of an instance as out of date. When an expired
+ attribute is next accessed, a query will be issued to the
+ :class:`.Session` object's current transactional context in order to
+ load all expired attributes for the given instance. Note that
+ a highly isolated transaction will return the same values as were
+ previously read in that same transaction, regardless of changes
+ in database state outside of that transaction.
+
+ To expire all objects in the :class:`.Session` simultaneously,
+ use :meth:`Session.expire_all`.
+
+ The :class:`.Session` object's default behavior is to
+ expire all state whenever the :meth:`Session.rollback`
+ or :meth:`Session.commit` methods are called, so that new
+ state can be loaded for the new transaction. For this reason,
+ calling :meth:`Session.expire` only makes sense for the specific
+ case that a non-ORM SQL statement was emitted in the current
+ transaction.
+
+ :param instance: The instance to be refreshed.
+ :param attribute_names: optional list of string attribute names
+ indicating a subset of attributes to be expired.
+
+ .. seealso::
+
+ :ref:`session_expire` - introductory material
+
+ :meth:`.Session.expire`
+
+ :meth:`.Session.refresh`
+
+ :meth:`_orm.Query.populate_existing`
+
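+ A brief sketch; ``some_object`` and the attribute name are
+ illustrative::
+
+ session.expire(some_object) # expire all attributes
+ session.expire(some_object, ["name"]) # expire one attribute
+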
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ self._expire_state(state, attribute_names)
+
+ def _expire_state(self, state, attribute_names):
+ self._validate_persistent(state)
+ if attribute_names:
+ state._expire_attributes(state.dict, attribute_names)
+ else:
+ # pre-fetch the full cascade since the expire is going to
+ # remove associations
+ cascaded = list(
+ state.manager.mapper.cascade_iterator("refresh-expire", state)
+ )
+ self._conditional_expire(state)
+ for o, m, st_, dct_ in cascaded:
+ self._conditional_expire(st_)
+
+ def _conditional_expire(self, state, autoflush=None):
+ """Expire a state if persistent, else expunge if pending"""
+
+ if state.key:
+ state._expire(state.dict, self.identity_map._modified)
+ elif state in self._new:
+ self._new.pop(state)
+ state._detach(self)
+
+ def expunge(self, instance):
+ """Remove the `instance` from this ``Session``.
+
+ This will free all internal references to the instance. Cascading
+ will be applied according to the *expunge* cascade rule.
+
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ if state.session_id is not self.hash_key:
+ raise sa_exc.InvalidRequestError(
+ "Instance %s is not present in this Session" % state_str(state)
+ )
+
+ cascaded = list(
+ state.manager.mapper.cascade_iterator("expunge", state)
+ )
+ self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded])
+
+ def _expunge_states(self, states, to_transient=False):
+ for state in states:
+ if state in self._new:
+ self._new.pop(state)
+ elif self.identity_map.contains_state(state):
+ self.identity_map.safe_discard(state)
+ self._deleted.pop(state, None)
+ elif self._transaction:
+ # state is "detached" from being deleted, but still present
+ # in the transaction snapshot
+ self._transaction._deleted.pop(state, None)
+ statelib.InstanceState._detach_states(
+ states, self, to_transient=to_transient
+ )
+
+ def _register_persistent(self, states):
+ """Register all persistent objects from a flush.
+
+ This is used both for pending objects moving to the persistent
+ state as well as already persistent objects.
+
+ """
+
+ pending_to_persistent = self.dispatch.pending_to_persistent or None
+ for state in states:
+ mapper = _state_mapper(state)
+
+ # guard against last-minute dereferences of the object
+ obj = state.obj()
+ if obj is not None:
+
+ instance_key = mapper._identity_key_from_state(state)
+
+ if (
+ _none_set.intersection(instance_key[1])
+ and not mapper.allow_partial_pks
+ or _none_set.issuperset(instance_key[1])
+ ):
+ raise exc.FlushError(
+ "Instance %s has a NULL identity key. If this is an "
+ "auto-generated value, check that the database table "
+ "allows generation of new primary key values, and "
+ "that the mapped Column object is configured to "
+ "expect these generated values. Ensure also that "
+ "this flush() is not occurring at an inappropriate "
+ "time, such as within a load() event."
+ % state_str(state)
+ )
+
+ if state.key is None:
+ state.key = instance_key
+ elif state.key != instance_key:
+ # primary key switch. use safe_discard() in case another
+ # state has already replaced this one in the identity
+ # map (see test/orm/test_naturalpks.py ReversePKsTest)
+ self.identity_map.safe_discard(state)
+ if state in self._transaction._key_switches:
+ orig_key = self._transaction._key_switches[state][0]
+ else:
+ orig_key = state.key
+ self._transaction._key_switches[state] = (
+ orig_key,
+ instance_key,
+ )
+ state.key = instance_key
+
+ # there can be an existing state in the identity map
+ # that is replaced when the primary keys of two instances
+ # are swapped; see test/orm/test_naturalpks.py -> test_reverse
+ old = self.identity_map.replace(state)
+ if (
+ old is not None
+ and mapper._identity_key_from_state(old) == instance_key
+ and old.obj() is not None
+ ):
+ util.warn(
+ "Identity map already had an identity for %s, "
+ "replacing it with newly flushed object. Are there "
+ "load operations occurring inside of an event handler "
+ "within the flush?" % (instance_key,)
+ )
+ state._orphaned_outside_of_session = False
+
+ statelib.InstanceState._commit_all_states(
+ ((state, state.dict) for state in states), self.identity_map
+ )
+
+ self._register_altered(states)
+
+ if pending_to_persistent is not None:
+ for state in states.intersection(self._new):
+ pending_to_persistent(self, state)
+
+ # remove from new last, might be the last strong ref
+ for state in set(states).intersection(self._new):
+ self._new.pop(state)
+
+ def _register_altered(self, states):
+ if self._transaction:
+ for state in states:
+ if state in self._new:
+ self._transaction._new[state] = True
+ else:
+ self._transaction._dirty[state] = True
+
+ def _remove_newly_deleted(self, states):
+ persistent_to_deleted = self.dispatch.persistent_to_deleted or None
+ for state in states:
+ if self._transaction:
+ self._transaction._deleted[state] = True
+
+ if persistent_to_deleted is not None:
+ # get a strong reference before we pop out of
+ # self._deleted
+ obj = state.obj() # noqa
+
+ self.identity_map.safe_discard(state)
+ self._deleted.pop(state, None)
+ state._deleted = True
+ # can't call state._detach() here, because this state
+ # is still in the transaction snapshot and needs to be
+ # tracked as part of that
+ if persistent_to_deleted is not None:
+ persistent_to_deleted(self, state)
+
+ def add(self, instance, _warn=True):
+ """Place an object in the ``Session``.
+
+ Its state will be persisted to the database on the next flush
+ operation.
+
+ Repeated calls to ``add()`` will be ignored. The opposite of ``add()``
+ is ``expunge()``.
+
+ """
+ if _warn and self._warn_on_events:
+ self._flush_warning("Session.add()")
+
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+
+ self._save_or_update_state(state)
+
+ def add_all(self, instances):
+ """Add the given collection of instances to this ``Session``."""
+
+ if self._warn_on_events:
+ self._flush_warning("Session.add_all()")
+
+ for instance in instances:
+ self.add(instance, _warn=False)
+
+ def _save_or_update_state(self, state):
+ state._orphaned_outside_of_session = False
+ self._save_or_update_impl(state)
+
+ mapper = _state_mapper(state)
+ for o, m, st_, dct_ in mapper.cascade_iterator(
+ "save-update", state, halt_on=self._contains_state
+ ):
+ self._save_or_update_impl(st_)
+
+ def delete(self, instance):
+ """Mark an instance as deleted.
+
+ The database delete operation occurs upon ``flush()``.
+
+ """
+ if self._warn_on_events:
+ self._flush_warning("Session.delete()")
+
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+
+ self._delete_impl(state, instance, head=True)
+
+ def _delete_impl(self, state, obj, head):
+
+ if state.key is None:
+ if head:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is not persisted" % state_str(state)
+ )
+ else:
+ return
+
+ to_attach = self._before_attach(state, obj)
+
+ if state in self._deleted:
+ return
+
+ self.identity_map.add(state)
+
+ if to_attach:
+ self._after_attach(state, obj)
+
+ if head:
+ # grab the cascades before adding the item to the deleted list
+ # so that autoflush does not delete the item
+ # the strong reference to the instance itself is significant here
+ cascade_states = list(
+ state.manager.mapper.cascade_iterator("delete", state)
+ )
+
+ self._deleted[state] = obj
+
+ if head:
+ for o, m, st_, dct_ in cascade_states:
+ self._delete_impl(st_, o, False)
+
+ def get(
+ self,
+ entity,
+ ident,
+ options=None,
+ populate_existing=False,
+ with_for_update=None,
+ identity_token=None,
+ execution_options=None,
+ ):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
+ E.g.::
+
+ my_user = session.get(User, 5)
+
+ some_object = session.get(VersionedFoo, (5, 10))
+
+ some_object = session.get(
+ VersionedFoo,
+ {"id": 5, "version_id": 10}
+ )
+
+ .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved
+ from the now deprecated :meth:`_orm.Query.get` method.
+
+ :meth:`_orm.Session.get` is special in that it provides direct
+ access to the identity map of the :class:`.Session`.
+ If the given primary key identifier is present
+ in the local identity map, the object is returned
+ directly from this collection and no SQL is emitted,
+ unless the object has been marked fully expired.
+ If not present,
+ a SELECT is performed in order to locate the object.
+
+ :meth:`_orm.Session.get` will also perform a check if
+ the object is present in the identity map and
+ marked as expired - a SELECT
+ is emitted to refresh the object as well as to
+ ensure that the row is still present.
+ If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised.
+
+ :param entity: a mapped class or :class:`.Mapper` indicating the
+ type of entity to be loaded.
+
+ :param ident: A scalar, tuple, or dictionary representing the
+ primary key. For a composite (e.g. multiple column) primary key,
+ a tuple or dictionary should be passed.
+
+ For a single-column primary key, the scalar calling form is typically
+ the most expedient. If the primary key of a row is the value "5",
+ the call looks like::
+
+ my_object = session.get(SomeClass, 5)
+
+ The tuple form contains primary key values typically in
+ the order in which they correspond to the mapped
+ :class:`_schema.Table`
+ object's primary key columns, or if the
+ :paramref:`_orm.Mapper.primary_key` configuration parameter were
+ used, in
+ the order used for that parameter. For example, if the primary key
+ of a row is represented by the integer
+ digits "5, 10" the call would look like::
+
+ my_object = session.get(SomeClass, (5, 10))
+
+ The dictionary form should include as keys the mapped attribute names
+ corresponding to each element of the primary key. If the mapped class
+ has the attributes ``id``, ``version_id`` as the attributes which
+ store the object's primary key value, the call would look like::
+
+ my_object = session.get(SomeClass, {"id": 5, "version_id": 10})
+
+ :param options: optional sequence of loader options which will be
+ applied to the query, if one is emitted.
+
+ :param populate_existing: causes the method to unconditionally emit
+ a SQL query and refresh the object with the newly loaded data,
+ regardless of whether or not the object is already present.
+
+ :param with_for_update: optional boolean ``True`` indicating FOR UPDATE
+ should be used, or may be a dictionary containing flags to
+ indicate a more specific set of FOR UPDATE flags for the SELECT;
+ flags should match the parameters of
+ :meth:`_query.Query.with_for_update`.
+ Supersedes the :paramref:`.Session.refresh.lockmode` parameter.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the query execution if one is emitted.
+ This dictionary can provide a subset of the options that are
+ accepted by :meth:`_engine.Connection.execution_options`, and may
+ also provide additional options understood only in an ORM context.
+
+ .. versionadded:: 1.4.29
+
+ .. seealso::
+
+ :ref:`orm_queryguide_execution_options` - ORM-specific execution
+ options
+
+ :return: The object instance, or ``None``.
+
+ """
+ return self._get_impl(
+ entity,
+ ident,
+ loading.load_on_pk_identity,
+ options,
+ populate_existing=populate_existing,
+ with_for_update=with_for_update,
+ identity_token=identity_token,
+ execution_options=execution_options,
+ )
+
+ def _get_impl(
+ self,
+ entity,
+ primary_key_identity,
+ db_load_fn,
+ options=None,
+ populate_existing=False,
+ with_for_update=None,
+ identity_token=None,
+ execution_options=None,
+ ):
+
+ # convert composite types to individual args
+ if hasattr(primary_key_identity, "__composite_values__"):
+ primary_key_identity = primary_key_identity.__composite_values__()
+
+ mapper = inspect(entity)
+
+ if not mapper or not mapper.is_mapper:
+ raise sa_exc.ArgumentError(
+ "Expected mapped class or mapper, got: %r" % entity
+ )
+
+ is_dict = isinstance(primary_key_identity, dict)
+ if not is_dict:
+ primary_key_identity = util.to_list(
+ primary_key_identity, default=(None,)
+ )
+
+ if len(primary_key_identity) != len(mapper.primary_key):
+ raise sa_exc.InvalidRequestError(
+ "Incorrect number of values in identifier to formulate "
+ "primary key for session.get(); primary key columns "
+ "are %s" % ",".join("'%s'" % c for c in mapper.primary_key)
+ )
+
+ if is_dict:
+ try:
+ primary_key_identity = list(
+ primary_key_identity[prop.key]
+ for prop in mapper._identity_key_props
+ )
+
+ except KeyError as err:
+ util.raise_(
+ sa_exc.InvalidRequestError(
+ "Incorrect names of values in identifier to formulate "
+ "primary key for session.get(); primary key attribute "
+ "names are %s"
+ % ",".join(
+ "'%s'" % prop.key
+ for prop in mapper._identity_key_props
+ )
+ ),
+ replace_context=err,
+ )
+
+ if (
+ not populate_existing
+ and not mapper.always_refresh
+ and with_for_update is None
+ ):
+
+ instance = self._identity_lookup(
+ mapper, primary_key_identity, identity_token=identity_token
+ )
+
+ if instance is not None:
+ # reject calls for id in identity map but class
+ # mismatch.
+ if not issubclass(instance.__class__, mapper.class_):
+ return None
+ return instance
+ elif instance is attributes.PASSIVE_CLASS_MISMATCH:
+ return None
+
+ # set_label_style() not strictly necessary, however this will ensure
+ # that tablename_colname style is used which at the moment is
+ # asserted in a lot of unit tests :)
+
+ load_options = context.QueryContext.default_load_options
+
+ if populate_existing:
+ load_options += {"_populate_existing": populate_existing}
+ statement = sql.select(mapper).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )
+ if with_for_update is not None:
+ statement._for_update_arg = query.ForUpdateArg._from_argument(
+ with_for_update
+ )
+
+ if options:
+ statement = statement.options(*options)
+ if execution_options:
+ statement = statement.execution_options(**execution_options)
+ return db_load_fn(
+ self,
+ statement,
+ primary_key_identity,
+ load_options=load_options,
+ )
+
+ def merge(self, instance, load=True, options=None):
+ """Copy the state of a given instance into a corresponding instance
+ within this :class:`.Session`.
+
+ :meth:`.Session.merge` examines the primary key attributes of the
+ source instance, and attempts to reconcile it with an instance of the
+ same primary key in the session. If not found locally, it attempts
+ to load the object from the database based on primary key, and if
+ none can be located, creates a new instance. The state of each
+ attribute on the source instance is then copied to the target
+ instance. The resulting target instance is then returned by the
+ method; the original source instance is left unmodified, and
+ un-associated with the :class:`.Session` if not already.
+
+ This operation cascades to associated instances if the association is
+ mapped with ``cascade="merge"``.
+
+ See :ref:`unitofwork_merging` for a detailed discussion of merging.
+
+ .. versionchanged:: 1.1 - :meth:`.Session.merge` will now reconcile
+ pending objects with overlapping primary keys in the same way
+ as persistent. See :ref:`change_3601` for discussion.
+
+ :param instance: Instance to be merged.
+ :param load: Boolean, when False, :meth:`.merge` switches into
+ a "high performance" mode which causes it to forego emitting history
+ events as well as all database access. This flag is used for
+ cases such as transferring graphs of objects into a :class:`.Session`
+ from a second level cache, or to transfer just-loaded objects
+ into the :class:`.Session` owned by a worker thread or process
+ without re-querying the database.
+
+ The ``load=False`` use case adds the caveat that the given
+ object has to be in a "clean" state, that is, has no pending changes
+ to be flushed - even if the incoming object is detached from any
+ :class:`.Session`. This is so that when
+ the merge operation populates local attributes and
+ cascades to related objects and
+ collections, the values can be "stamped" onto the
+ target object as is, without generating any history or attribute
+ events, and without the need to reconcile the incoming data with
+ any existing related objects or collections that might not
+ be loaded. The resulting objects from ``load=False`` are always
+ produced as "clean", so it is only appropriate that the given objects
+ should be "clean" as well, else this suggests a mis-use of the
+ method.
+ :param options: optional sequence of loader options which will be
+ applied to the :meth:`_orm.Session.get` method when the merge
+ operation loads the existing version of the object from the database.
+
+ .. versionadded:: 1.4.24
+
+
+ .. seealso::
+
+ :func:`.make_transient_to_detached` - provides for an alternative
+ means of "merging" a single object into the :class:`.Session`
+
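+ A brief sketch, assuming a mapped ``User`` class with primary
+ key ``id``::
+
+ detached_user = User(id=5, name="updated name")
+ persistent_user = session.merge(detached_user)
+ # persistent_user is the in-session copy; detached_user is
+ # left unmodified
+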
+ """
+
+ if self._warn_on_events:
+ self._flush_warning("Session.merge()")
+
+ _recursive = {}
+ _resolve_conflict_map = {}
+
+ if load:
+ # flush current contents if we expect to load data
+ self._autoflush()
+
+ object_mapper(instance) # verify mapped
+ autoflush = self.autoflush
+ try:
+ self.autoflush = False
+ return self._merge(
+ attributes.instance_state(instance),
+ attributes.instance_dict(instance),
+ load=load,
+ options=options,
+ _recursive=_recursive,
+ _resolve_conflict_map=_resolve_conflict_map,
+ )
+ finally:
+ self.autoflush = autoflush
+
+ def _merge(
+ self,
+ state,
+ state_dict,
+ load=True,
+ options=None,
+ _recursive=None,
+ _resolve_conflict_map=None,
+ ):
+ mapper = _state_mapper(state)
+ if state in _recursive:
+ return _recursive[state]
+
+ new_instance = False
+ key = state.key
+
+ if key is None:
+ if state in self._new:
+ util.warn(
+ "Instance %s is already pending in this Session yet is "
+ "being merged again; this is probably not what you want "
+ "to do" % state_str(state)
+ )
+
+ if not load:
+ raise sa_exc.InvalidRequestError(
+ "merge() with load=False option does not support "
+ "objects transient (i.e. unpersisted) objects. flush() "
+ "all changes on mapped instances before merging with "
+ "load=False."
+ )
+ key = mapper._identity_key_from_state(state)
+ key_is_persistent = attributes.NEVER_SET not in key[1] and (
+ not _none_set.intersection(key[1])
+ or (
+ mapper.allow_partial_pks
+ and not _none_set.issuperset(key[1])
+ )
+ )
+ else:
+ key_is_persistent = True
+
+ if key in self.identity_map:
+ try:
+ merged = self.identity_map[key]
+ except KeyError:
+ # object was GC'ed right as we checked for it
+ merged = None
+ else:
+ merged = None
+
+ if merged is None:
+ if key_is_persistent and key in _resolve_conflict_map:
+ merged = _resolve_conflict_map[key]
+
+ elif not load:
+ if state.modified:
+ raise sa_exc.InvalidRequestError(
+ "merge() with load=False option does not support "
+ "objects marked as 'dirty'. flush() all changes on "
+ "mapped instances before merging with load=False."
+ )
+ merged = mapper.class_manager.new_instance()
+ merged_state = attributes.instance_state(merged)
+ merged_state.key = key
+ self._update_impl(merged_state)
+ new_instance = True
+
+ elif key_is_persistent:
+ merged = self.get(
+ mapper.class_,
+ key[1],
+ identity_token=key[2],
+ options=options,
+ )
+
+ if merged is None:
+ merged = mapper.class_manager.new_instance()
+ merged_state = attributes.instance_state(merged)
+ merged_dict = attributes.instance_dict(merged)
+ new_instance = True
+ self._save_or_update_state(merged_state)
+ else:
+ merged_state = attributes.instance_state(merged)
+ merged_dict = attributes.instance_dict(merged)
+
+ _recursive[state] = merged
+ _resolve_conflict_map[key] = merged
+
+ # check that we didn't just pull the exact same
+ # state out.
+ if state is not merged_state:
+ # version check if applicable
+ if mapper.version_id_col is not None:
+ existing_version = mapper._get_state_attr_by_column(
+ state,
+ state_dict,
+ mapper.version_id_col,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+
+ merged_version = mapper._get_state_attr_by_column(
+ merged_state,
+ merged_dict,
+ mapper.version_id_col,
+ passive=attributes.PASSIVE_NO_INITIALIZE,
+ )
+
+ if (
+ existing_version is not attributes.PASSIVE_NO_RESULT
+ and merged_version is not attributes.PASSIVE_NO_RESULT
+ and existing_version != merged_version
+ ):
+ raise exc.StaleDataError(
+ "Version id '%s' on merged state %s "
+ "does not match existing version '%s'. "
+ "Leave the version attribute unset when "
+ "merging to update the most recent version."
+ % (
+ existing_version,
+ state_str(merged_state),
+ merged_version,
+ )
+ )
+
+ merged_state.load_path = state.load_path
+ merged_state.load_options = state.load_options
+
+ # since we are copying load_options, we need to copy
+ # the callables_ that would have been generated by those
+ # load_options.
+ # assumes that the callables we put in state.callables_
+ # are not instance-specific (which they should not be)
+ merged_state._copy_callables(state)
+
+ for prop in mapper.iterate_properties:
+ prop.merge(
+ self,
+ state,
+ state_dict,
+ merged_state,
+ merged_dict,
+ load,
+ _recursive,
+ _resolve_conflict_map,
+ )
+
+ if not load:
+ # remove any history
+ merged_state._commit_all(merged_dict, self.identity_map)
+
+ if new_instance:
+ merged_state.manager.dispatch.load(merged_state, None)
+ return merged
+
+ def _validate_persistent(self, state):
+ if not self.identity_map.contains_state(state):
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is not persistent within this Session"
+ % state_str(state)
+ )
+
+ def _save_impl(self, state):
+ if state.key is not None:
+ raise sa_exc.InvalidRequestError(
+ "Object '%s' already has an identity - "
+ "it can't be registered as pending" % state_str(state)
+ )
+
+ obj = state.obj()
+ to_attach = self._before_attach(state, obj)
+ if state not in self._new:
+ self._new[state] = obj
+ state.insert_order = len(self._new)
+ if to_attach:
+ self._after_attach(state, obj)
+
+ def _update_impl(self, state, revert_deletion=False):
+ if state.key is None:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' is not persisted" % state_str(state)
+ )
+
+ if state._deleted:
+ if revert_deletion:
+ if not state._attached:
+ return
+ del state._deleted
+ else:
+ raise sa_exc.InvalidRequestError(
+ "Instance '%s' has been deleted. "
+ "Use the make_transient() "
+ "function to send this object back "
+ "to the transient state." % state_str(state)
+ )
+
+ obj = state.obj()
+
+ # check for late gc
+ if obj is None:
+ return
+
+ to_attach = self._before_attach(state, obj)
+
+ self._deleted.pop(state, None)
+ if revert_deletion:
+ self.identity_map.replace(state)
+ else:
+ self.identity_map.add(state)
+
+ if to_attach:
+ self._after_attach(state, obj)
+ elif revert_deletion:
+ self.dispatch.deleted_to_persistent(self, state)
+
+ def _save_or_update_impl(self, state):
+ if state.key is None:
+ self._save_impl(state)
+ else:
+ self._update_impl(state)
+
+ def enable_relationship_loading(self, obj):
+ """Associate an object with this :class:`.Session` for related
+ object loading.
+
+ .. warning::
+
+ :meth:`.enable_relationship_loading` exists to serve special
+ use cases and is not recommended for general use.
+
+ Accesses of attributes mapped with :func:`_orm.relationship`
+ will attempt to load a value from the database using this
+ :class:`.Session` as the source of connectivity. The values
+ will be loaded based on foreign key and primary key values
+ present on this object - if not present, then those relationships
+ will be unavailable.
+
+ The object will be attached to this session, but will
+ **not** participate in any persistence operations; its state
+ for almost all purposes will remain either "transient" or
+ "detached", except for the case of relationship loading.
+
+ Also note that backrefs will often not work as expected.
+ Altering a relationship-bound attribute on the target object
+ may not fire off a backref event, if the effective value
+ is what was already loaded from a foreign-key-holding value.
+
+ The :meth:`.Session.enable_relationship_loading` method is
+ similar to the ``load_on_pending`` flag on :func:`_orm.relationship`.
+ Unlike that flag, :meth:`.Session.enable_relationship_loading` allows
+ an object to remain transient while still being able to load
+ related items.
+
+ To make a transient object associated with a :class:`.Session`
+ via :meth:`.Session.enable_relationship_loading` pending, add
+ it to the :class:`.Session` using :meth:`.Session.add` normally.
+ If the object instead represents an existing identity in the database,
+ it should be merged using :meth:`.Session.merge`.
+
+ :meth:`.Session.enable_relationship_loading` does not improve
+ behavior when the ORM is used normally - object references should be
+ constructed at the object level, not at the foreign key level, so
+ that they are present in an ordinary way before flush()
+ proceeds. This method is not intended for general use.
+
+ .. seealso::
+
+ :paramref:`_orm.relationship.load_on_pending` - this flag
+ allows per-relationship loading of many-to-ones on items that
+ are pending.
+
+ :func:`.make_transient_to_detached` - allows for an object to
+ be added to a :class:`.Session` without SQL emitted, which then
+ will unexpire attributes on access.
+
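+ As an illustration only, assuming a hypothetical ``Address`` class
+ mapped with a many-to-one ``user`` relationship and a ``user_id``
+ foreign key column::
+
+     # transient object; not added to the session
+     address = Address(user_id=5)
+
+     session.enable_relationship_loading(address)
+
+     # attribute access now emits a lazy load using this Session,
+     # based on the foreign key value present on the object
+     user = address.user
+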
+ """
+ try:
+ state = attributes.instance_state(obj)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(obj),
+ replace_context=err,
+ )
+
+ to_attach = self._before_attach(state, obj)
+ state._load_pending = True
+ if to_attach:
+ self._after_attach(state, obj)
+
+ def _before_attach(self, state, obj):
+ self._autobegin()
+
+ if state.session_id == self.hash_key:
+ return False
+
+ if state.session_id and state.session_id in _sessions:
+ raise sa_exc.InvalidRequestError(
+ "Object '%s' is already attached to session '%s' "
+ "(this is '%s')"
+ % (state_str(state), state.session_id, self.hash_key)
+ )
+
+ self.dispatch.before_attach(self, state)
+
+ return True
+
+ def _after_attach(self, state, obj):
+ state.session_id = self.hash_key
+ if state.modified and state._strong_obj is None:
+ state._strong_obj = obj
+ self.dispatch.after_attach(self, state)
+
+ if state.key:
+ self.dispatch.detached_to_persistent(self, state)
+ else:
+ self.dispatch.transient_to_pending(self, state)
+
+ def __contains__(self, instance):
+ """Return True if the instance is associated with this session.
+
+ The instance may be pending or persistent within the Session for a
+ result of True.
+
+ """
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ return self._contains_state(state)
+
+ def __iter__(self):
+ """Iterate over all pending or persistent instances within this
+ Session.
+
+ """
+ return iter(
+ list(self._new.values()) + list(self.identity_map.values())
+ )
+
+ def _contains_state(self, state):
+ return state in self._new or self.identity_map.contains_state(state)
+
+ def flush(self, objects=None):
+ """Flush all the object changes to the database.
+
+ Writes out all pending object creations, deletions and modifications
+ to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are
+ automatically ordered by the Session's unit of work dependency
+ solver.
+
+ Database operations will be issued in the current transactional
+ context and do not affect the state of the transaction, unless an
+ error occurs, in which case the entire transaction is rolled back.
+ You may flush() as often as you like within a transaction to move
+ changes from Python to the database's transaction buffer.
+
+ For ``autocommit`` Sessions with no active manual transaction, flush()
+ will create a transaction on the fly that surrounds the entire set of
+ operations within the flush.
+
+ :param objects: Optional; restricts the flush operation to operate
+ only on elements that are in the given collection.
+
+ This feature is for an extremely narrow set of use cases where
+ particular objects may need to be operated upon before the
+ full flush() occurs. It is not intended for general use.
+
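+ A brief sketch, assuming a mapped ``User`` class and an open
+ ``session``::
+
+     session.add(User(name="spongebob"))
+
+     # emits the INSERT within the current transaction, but does not
+     # commit; a subsequent rollback would still undo it
+     session.flush()
+
+     session.commit()
+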
+ """
+
+ if self._flushing:
+ raise sa_exc.InvalidRequestError("Session is already flushing")
+
+ if self._is_clean():
+ return
+ try:
+ self._flushing = True
+ self._flush(objects)
+ finally:
+ self._flushing = False
+
+ def _flush_warning(self, method):
+ util.warn(
+ "Usage of the '%s' operation is not currently supported "
+ "within the execution stage of the flush process. "
+ "Results may not be consistent. Consider using alternative "
+ "event listeners or connection-level operations instead." % method
+ )
+
+ def _is_clean(self):
+ return (
+ not self.identity_map.check_modified()
+ and not self._deleted
+ and not self._new
+ )
+
+ def _flush(self, objects=None):
+
+ dirty = self._dirty_states
+ if not dirty and not self._deleted and not self._new:
+ self.identity_map._modified.clear()
+ return
+
+ flush_context = UOWTransaction(self)
+
+ if self.dispatch.before_flush:
+ self.dispatch.before_flush(self, flush_context, objects)
+ # re-establish "dirty states" in case the listeners
+ # added or modified objects
+ dirty = self._dirty_states
+
+ deleted = set(self._deleted)
+ new = set(self._new)
+
+ dirty = set(dirty).difference(deleted)
+
+ # create the set of all objects we want to operate upon
+ if objects:
+ # specific list passed in
+ objset = set()
+ for o in objects:
+ try:
+ state = attributes.instance_state(o)
+
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(o),
+ replace_context=err,
+ )
+ objset.add(state)
+ else:
+ objset = None
+
+ # store objects whose fate has been decided
+ processed = set()
+
+ # put all saves/updates into the flush context. detect top-level
+ # orphans and throw them into deleted.
+ if objset:
+ proc = new.union(dirty).intersection(objset).difference(deleted)
+ else:
+ proc = new.union(dirty).difference(deleted)
+
+ for state in proc:
+ is_orphan = _state_mapper(state)._is_orphan(state)
+
+ is_persistent_orphan = is_orphan and state.has_identity
+
+ if (
+ is_orphan
+ and not is_persistent_orphan
+ and state._orphaned_outside_of_session
+ ):
+ self._expunge_states([state])
+ else:
+ _reg = flush_context.register_object(
+ state, isdelete=is_persistent_orphan
+ )
+ assert _reg, "Failed to add object to the flush context!"
+ processed.add(state)
+
+ # put all remaining deletes into the flush context.
+ if objset:
+ proc = deleted.intersection(objset).difference(processed)
+ else:
+ proc = deleted.difference(processed)
+ for state in proc:
+ _reg = flush_context.register_object(state, isdelete=True)
+ assert _reg, "Failed to add object to the flush context!"
+
+ if not flush_context.has_work:
+ return
+
+ flush_context.transaction = transaction = self.begin(_subtrans=True)
+ try:
+ self._warn_on_events = True
+ try:
+ flush_context.execute()
+ finally:
+ self._warn_on_events = False
+
+ self.dispatch.after_flush(self, flush_context)
+
+ flush_context.finalize_flush_changes()
+
+ if not objects and self.identity_map._modified:
+ len_ = len(self.identity_map._modified)
+
+ statelib.InstanceState._commit_all_states(
+ [
+ (state, state.dict)
+ for state in self.identity_map._modified
+ ],
+ instance_dict=self.identity_map,
+ )
+ util.warn(
+ "Attribute history events accumulated on %d "
+ "previously clean instances "
+ "within inner-flush event handlers have been "
+ "reset, and will not result in database updates. "
+ "Consider using set_committed_value() within "
+ "inner-flush event handlers to avoid this warning." % len_
+ )
+
+ # useful assertions:
+ # if not objects:
+ # assert not self.identity_map._modified
+ # else:
+ # assert self.identity_map._modified == \
+ # self.identity_map._modified.difference(objects)
+
+ self.dispatch.after_flush_postexec(self, flush_context)
+
+ transaction.commit()
+
+ except:
+ with util.safe_reraise():
+ transaction.rollback(_capture_exception=True)
+
+ def bulk_save_objects(
+ self,
+ objects,
+ return_defaults=False,
+ update_changed_only=True,
+ preserve_order=True,
+ ):
+ """Perform a bulk save of the given list of objects.
+
+ The bulk save feature allows mapped objects to be used as the
+ source of simple INSERT and UPDATE operations which can be more easily
+ grouped together into higher performing "executemany"
+ operations; the extraction of data from the objects is also performed
+ using a lower-latency process that ignores whether or not attributes
+ have actually been modified in the case of UPDATEs, and also ignores
+ SQL expressions.
+
+ The objects as given are not added to the session and no additional
+ state is established on them. If the
+ :paramref:`_orm.Session.bulk_save_objects.return_defaults` flag is set,
+ then server-generated primary key values will be assigned to the
+ returned objects, but **not server side defaults**; this is a
+ limitation in the implementation. If stateful objects are desired,
+ please use the standard :meth:`_orm.Session.add_all` approach or,
+ as an alternative, newer mass-insert features such as
+ :ref:`orm_dml_returning_objects`.
+
+ .. warning::
+
+ The bulk save feature allows for a lower-latency INSERT/UPDATE
+ of rows at the expense of most other unit-of-work features.
+ Features such as object management, relationship handling,
+ and SQL clause support are **silently omitted** in favor of raw
+ INSERT/UPDATES of records.
+
+ Please note that newer versions of SQLAlchemy are **greatly
+ improving the efficiency** of the standard flush process. It is
+ **strongly recommended** to not use the bulk methods as they
+ represent a forking of SQLAlchemy's functionality and are slowly
+ being moved into legacy status. New features such as
+ :ref:`orm_dml_returning_objects` are both more efficient than
+ the "bulk" methods and provide more predictable functionality.
+
+ **Please read the list of caveats at**
+ :ref:`bulk_operations_caveats` **before using this method, and
+ fully test and confirm the functionality of all code developed
+ using these systems.**
+
+ :param objects: a sequence of mapped object instances. The mapped
+ objects are persisted as is, and are **not** associated with the
+ :class:`.Session` afterwards.
+
+ For each object, whether the object is sent as an INSERT or an
+ UPDATE is dependent on the same rules used by the :class:`.Session`
+ in traditional operation; if the object has the
+ :attr:`.InstanceState.key`
+ attribute set, then the object is assumed to be "detached" and
+ will result in an UPDATE. Otherwise, an INSERT is used.
+
+ In the case of an UPDATE, statements are grouped based on which
+ attributes have changed, and are thus to be the subject of each
+ SET clause. If ``update_changed_only`` is False, then all
+ attributes present within each object are applied to the UPDATE
+ statement, which may help in allowing the statements to be grouped
+ together into a larger executemany(), and will also reduce the
+ overhead of checking history on attributes.
+
+ :param return_defaults: when True, rows that are missing values which
+ generate defaults, namely integer primary key defaults and sequences,
+ will be inserted **one at a time**, so that the primary key value
+ is available. In particular this will allow joined-inheritance
+ and other multi-table mappings to insert correctly without the need
+ to provide primary key values ahead of time; however,
+ :paramref:`.Session.bulk_save_objects.return_defaults` **greatly
+ reduces the performance gains** of the method overall. It is strongly
+ advised to please use the standard :meth:`_orm.Session.add_all`
+ approach.
+
+ :param update_changed_only: when True, UPDATE statements are rendered
+ based on those attributes in each state that have logged changes.
+ When False, all attributes present are rendered into the SET clause
+ with the exception of primary key attributes.
+
+ :param preserve_order: when True, the order of inserts and updates
+ matches exactly the order in which the objects are given. When
+ False, common types of objects are grouped into inserts
+ and updates, to allow for more batching opportunities.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`bulk_operations`
+
+ :meth:`.Session.bulk_insert_mappings`
+
+ :meth:`.Session.bulk_update_mappings`
+
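+ A minimal sketch, assuming a mapped ``User`` class; the objects are
+ not attached to the ``session`` afterwards::
+
+     users = [User(name="user %d" % i) for i in range(10000)]
+
+     session.bulk_save_objects(users)
+     session.commit()
+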
+ """
+
+ obj_states = (attributes.instance_state(obj) for obj in objects)
+
+ if not preserve_order:
+ # the purpose of this sort is just so that common mappers
+ # and persistence states are grouped together, so that groupby
+ # will return a single group for a particular type of mapper.
+ # it's not trying to be deterministic beyond that.
+ obj_states = sorted(
+ obj_states,
+ key=lambda state: (id(state.mapper), state.key is not None),
+ )
+
+ def grouping_key(state):
+ return (state.mapper, state.key is not None)
+
+ for (mapper, isupdate), states in itertools.groupby(
+ obj_states, grouping_key
+ ):
+ self._bulk_save_mappings(
+ mapper,
+ states,
+ isupdate,
+ True,
+ return_defaults,
+ update_changed_only,
+ False,
+ )
+
+ def bulk_insert_mappings(
+ self, mapper, mappings, return_defaults=False, render_nulls=False
+ ):
+ """Perform a bulk insert of the given list of mapping dictionaries.
+
+ The bulk insert feature allows plain Python dictionaries to be used as
+ the source of simple INSERT operations which can be more easily
+ grouped together into higher performing "executemany"
+ operations. Using dictionaries, there are no "history" or session
+ state management features in use, reducing latency when inserting
+ large numbers of simple rows.
+
+ The values within the dictionaries as given are typically passed
+ without modification into Core :meth:`_expression.Insert` constructs,
+ after
+ organizing the values within them across the tables to which
+ the given mapper is mapped.
+
+ .. versionadded:: 1.0.0
+
+ .. warning::
+
+ The bulk insert feature allows for a lower-latency INSERT
+ of rows at the expense of most other unit-of-work features.
+ Features such as object management, relationship handling,
+ and SQL clause support are **silently omitted** in favor of raw
+ INSERT of records.
+
+ Please note that newer versions of SQLAlchemy are **greatly
+ improving the efficiency** of the standard flush process. It is
+ **strongly recommended** to not use the bulk methods as they
+ represent a forking of SQLAlchemy's functionality and are slowly
+ being moved into legacy status. New features such as
+ :ref:`orm_dml_returning_objects` are both more efficient than
+ the "bulk" methods and provide more predictable functionality.
+
+ **Please read the list of caveats at**
+ :ref:`bulk_operations_caveats` **before using this method, and
+ fully test and confirm the functionality of all code developed
+ using these systems.**
+
+ :param mapper: a mapped class, or the actual :class:`_orm.Mapper`
+ object,
+ representing the single kind of object represented within the mapping
+ list.
+
+ :param mappings: a sequence of dictionaries, each one containing the
+ state of the mapped row to be inserted, in terms of the attribute
+ names on the mapped class. If the mapping refers to multiple tables,
+ such as a joined-inheritance mapping, each dictionary must contain all
+ keys to be populated into all tables.
+
+ :param return_defaults: when True, rows that are missing values which
+ generate defaults, namely integer primary key defaults and sequences,
+ will be inserted **one at a time**, so that the primary key value
+ is available. In particular this will allow joined-inheritance
+ and other multi-table mappings to insert correctly without the need
+ to provide primary
+ key values ahead of time; however,
+ :paramref:`.Session.bulk_insert_mappings.return_defaults`
+ **greatly reduces the performance gains** of the method overall.
+ If the rows
+ to be inserted only refer to a single table, then there is no
+ reason this flag should be set as the returned default information
+ is not used.
+
+ :param render_nulls: When True, a value of ``None`` will result
+ in a NULL value being included in the INSERT statement, rather
+ than the column being omitted from the INSERT. This allows all
+ the rows being INSERTed to have the identical set of columns which
+ allows the full set of rows to be batched to the DBAPI. Normally,
+ each column-set that contains a different combination of NULL values
+ than the previous row must omit a different series of columns from
+ the rendered INSERT statement, which means it must be emitted as a
+ separate statement. By passing this flag, the full set of rows
+ are guaranteed to be batchable into one batch; the cost however is
+ that server-side defaults which are invoked by an omitted column will
+ be skipped, so care must be taken to ensure that these are not
+ necessary.
+
+ .. warning::
+
+ When this flag is set, **server side default SQL values will
+ not be invoked** for those columns that are inserted as NULL;
+ the NULL value will be sent explicitly. Care must be taken
+ to ensure that no server-side default functions need to be
+ invoked for the operation as a whole.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`bulk_operations`
+
+ :meth:`.Session.bulk_save_objects`
+
+ :meth:`.Session.bulk_update_mappings`
+
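+ A minimal sketch, assuming a mapped ``User`` class; plain
+ dictionaries keyed on attribute names are passed directly::
+
+     session.bulk_insert_mappings(
+         User,
+         [{"name": "user %d" % i} for i in range(10000)],
+     )
+     session.commit()
+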
+ """
+ self._bulk_save_mappings(
+ mapper,
+ mappings,
+ False,
+ False,
+ return_defaults,
+ False,
+ render_nulls,
+ )
+
+ def bulk_update_mappings(self, mapper, mappings):
+ """Perform a bulk update of the given list of mapping dictionaries.
+
+ The bulk update feature allows plain Python dictionaries to be used as
+ the source of simple UPDATE operations which can be more easily
+ grouped together into higher performing "executemany"
+ operations. Using dictionaries, there are no "history" or session
+ state management features in use, reducing latency when updating
+ large numbers of simple rows.
+
+ .. versionadded:: 1.0.0
+
+ .. warning::
+
+ The bulk update feature allows for a lower-latency UPDATE
+ of rows at the expense of most other unit-of-work features.
+ Features such as object management, relationship handling,
+ and SQL clause support are **silently omitted** in favor of raw
+ UPDATES of records.
+
+ Please note that newer versions of SQLAlchemy are **greatly
+ improving the efficiency** of the standard flush process. It is
+ **strongly recommended** to not use the bulk methods as they
+ represent a forking of SQLAlchemy's functionality and are slowly
+ being moved into legacy status. New features such as
+ :ref:`orm_dml_returning_objects` are both more efficient than
+ the "bulk" methods and provide more predictable functionality.
+
+ **Please read the list of caveats at**
+ :ref:`bulk_operations_caveats` **before using this method, and
+ fully test and confirm the functionality of all code developed
+ using these systems.**
+
+ :param mapper: a mapped class, or the actual :class:`_orm.Mapper`
+ object,
+ representing the single kind of object represented within the mapping
+ list.
+
+ :param mappings: a sequence of dictionaries, each one containing the
+ state of the mapped row to be updated, in terms of the attribute names
+ on the mapped class. If the mapping refers to multiple tables, such
+ as a joined-inheritance mapping, each dictionary may contain keys
+ corresponding to all tables. All those keys which are present and
+ are not part of the primary key are applied to the SET clause of the
+ UPDATE statement; the primary key values, which are required, are
+ applied to the WHERE clause.
+
+
+ .. seealso::
+
+ :ref:`bulk_operations`
+
+ :meth:`.Session.bulk_insert_mappings`
+
+ :meth:`.Session.bulk_save_objects`
+
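+ A minimal sketch, assuming a mapped ``User`` class whose primary key
+ attribute is ``id``; the primary key values locate the rows and the
+ remaining keys populate the SET clause::
+
+     session.bulk_update_mappings(
+         User,
+         [
+             {"id": 1, "name": "new name 1"},
+             {"id": 2, "name": "new name 2"},
+         ],
+     )
+     session.commit()
+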
+ """
+ self._bulk_save_mappings(
+ mapper, mappings, True, False, False, False, False
+ )
+
+ def _bulk_save_mappings(
+ self,
+ mapper,
+ mappings,
+ isupdate,
+ isstates,
+ return_defaults,
+ update_changed_only,
+ render_nulls,
+ ):
+ mapper = _class_to_mapper(mapper)
+ self._flushing = True
+
+ transaction = self.begin(_subtrans=True)
+ try:
+ if isupdate:
+ persistence._bulk_update(
+ mapper,
+ mappings,
+ transaction,
+ isstates,
+ update_changed_only,
+ )
+ else:
+ persistence._bulk_insert(
+ mapper,
+ mappings,
+ transaction,
+ isstates,
+ return_defaults,
+ render_nulls,
+ )
+ transaction.commit()
+
+ except:
+ with util.safe_reraise():
+ transaction.rollback(_capture_exception=True)
+ finally:
+ self._flushing = False
+
+ def is_modified(self, instance, include_collections=True):
+ r"""Return ``True`` if the given instance has locally
+ modified attributes.
+
+ This method retrieves the history for each instrumented
+ attribute on the instance and performs a comparison of the current
+ value to its previously committed value, if any.
+
+ It is in effect a more expensive and accurate
+ version of checking for the given instance in the
+ :attr:`.Session.dirty` collection; a full test for
+ each attribute's net "dirty" status is performed.
+
+ E.g.::
+
+ return session.is_modified(someobject)
+
+ A few caveats to this method apply:
+
+ * Instances present in the :attr:`.Session.dirty` collection may
+ report ``False`` when tested with this method. This is because
+ the object may have received change events via attribute mutation,
+ thus placing it in :attr:`.Session.dirty`, but ultimately the state
+ is the same as that loaded from the database, resulting in no net
+ change here.
+ * Scalar attributes may not have recorded the previously set
+ value when a new value was applied, if the attribute was not loaded,
+ or was expired, at the time the new value was received - in these
+ cases, the attribute is assumed to have a change, even if there is
+ ultimately no net change against its database value. SQLAlchemy in
+ most cases does not need the "old" value when a set event occurs, so
+ it skips the expense of a SQL call if the old value isn't present,
+ based on the assumption that an UPDATE of the scalar value is
+ usually needed, and in those few cases where it isn't, is less
+ expensive on average than issuing a defensive SELECT.
+
+ The "old" value is fetched unconditionally upon set only if the
+ attribute container has the ``active_history`` flag set to ``True``.
+ This flag is set typically for primary key attributes and scalar
+ object references that are not a simple many-to-one. To set this
+ flag for any arbitrary mapped column, use the ``active_history``
+ argument with :func:`.column_property`.
+
+ :param instance: mapped instance to be tested for pending changes.
+ :param include_collections: Indicates if multivalued collections
+ should be included in the operation. Setting this to ``False`` is a
+ way to detect only local-column based properties (i.e. scalar columns
+ or many-to-one foreign keys) that would result in an UPDATE for this
+ instance upon flush.
+
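+ To illustrate the first caveat above (a mapped ``User`` class is
+ assumed), an object may be present in :attr:`.Session.dirty` while
+ reporting no net change::
+
+     user = session.get(User, 5)
+     user.name = user.name  # attribute event fires; value is unchanged
+
+     assert user in session.dirty
+     assert not session.is_modified(user)
+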
+ """
+ state = object_state(instance)
+
+ if not state.modified:
+ return False
+
+ dict_ = state.dict
+
+ for attr in state.manager.attributes:
+ if (
+ not include_collections
+ and hasattr(attr.impl, "get_collection")
+ ) or not hasattr(attr.impl, "get_history"):
+ continue
+
+ (added, unchanged, deleted) = attr.impl.get_history(
+ state, dict_, passive=attributes.NO_CHANGE
+ )
+
+ if added or deleted:
+ return True
+ else:
+ return False
+
+ @property
+ def is_active(self):
+ """True if this :class:`.Session` not in "partial rollback" state.
+
+ .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins
+ a new transaction immediately, so this attribute will be False
+ when the :class:`_orm.Session` is first instantiated.
+
+ "partial rollback" state typically indicates that the flush process
+ of the :class:`_orm.Session` has failed, and that the
+ :meth:`_orm.Session.rollback` method must be emitted in order to
+ fully roll back the transaction.
+
+ If this :class:`_orm.Session` is not in a transaction at all, the
+ :class:`_orm.Session` will autobegin when it is first used, so in this
+ case :attr:`_orm.Session.is_active` will return True.
+
+ Otherwise, if this :class:`_orm.Session` is within a transaction,
+ and that transaction has not been rolled back internally, the
+ :attr:`_orm.Session.is_active` will also return True.
+
+ .. seealso::
+
+ :ref:`faq_session_rollback`
+
+ :meth:`_orm.Session.in_transaction`
+
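+ A sketch of the "partial rollback" case, assuming some error is
+ raised during flush::
+
+     from sqlalchemy import exc
+
+     try:
+         session.flush()
+     except exc.SQLAlchemyError:
+         # the failed flush leaves the transaction in "partial rollback"
+         assert not session.is_active
+
+         session.rollback()
+         assert session.is_active
+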
+ """
+ if self.autocommit:
+ return (
+ self._transaction is not None and self._transaction.is_active
+ )
+ else:
+ return self._transaction is None or self._transaction.is_active
+
+ identity_map = None
+ """A mapping of object identities to objects themselves.
+
+ Iterating through ``Session.identity_map.values()`` provides
+ access to the full set of persistent objects (i.e., those
+ that have row identity) currently in the session.
+
+ .. seealso::
+
+ :func:`.identity_key` - helper function to produce the keys used
+ in this dictionary.
+
+ """
+
+ @property
+ def _dirty_states(self):
+ """The set of all persistent states considered dirty.
+
+ This method returns all states that were modified including
+ those that were possibly deleted.
+
+ """
+ return self.identity_map._dirty_states()
+
+ @property
+ def dirty(self):
+ """The set of all persistent instances considered dirty.
+
+ E.g.::
+
+ some_mapped_object in session.dirty
+
+ Instances are considered dirty when they were modified but not
+ deleted.
+
+ Note that this 'dirty' calculation is 'optimistic'; most
+ attribute-setting or collection modification operations will
+ mark an instance as 'dirty' and place it in this set, even if
+ there is no net change to the attribute's value. At flush
+ time, the value of each attribute is compared to its
+ previously saved value, and if there's no net change, no SQL
+ operation will occur (this is a more expensive operation so
+ it's only done at flush time).
+
+ To check if an instance has actionable net changes to its
+ attributes, use the :meth:`.Session.is_modified` method.
+
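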
+ """
+ return util.IdentitySet(
+ [
+ state.obj()
+ for state in self._dirty_states
+ if state not in self._deleted
+ ]
+ )
+
+ @property
+ def deleted(self):
+ "The set of all instances marked as 'deleted' within this ``Session``"
+
+ return util.IdentitySet(list(self._deleted.values()))
+
+ @property
+ def new(self):
+ "The set of all instances marked as 'new' within this ``Session``."
+
+ return util.IdentitySet(list(self._new.values()))
+
+
+class sessionmaker(_SessionClassMethods):
+ """A configurable :class:`.Session` factory.
+
+ The :class:`.sessionmaker` factory generates new
+ :class:`.Session` objects when called, creating them given
+ the configurational arguments established here.
+
+ e.g.::
+
+ from sqlalchemy import create_engine
+ from sqlalchemy.orm import sessionmaker
+
+ # an Engine, which the Session will use for connection
+ # resources
+ engine = create_engine('postgresql://scott:tiger@localhost/')
+
+ Session = sessionmaker(engine)
+
+ with Session() as session:
+ session.add(some_object)
+ session.add(some_other_object)
+ session.commit()
+
+ Context manager use is optional; otherwise, the returned
+ :class:`_orm.Session` object may be closed explicitly via the
+ :meth:`_orm.Session.close` method. Using a
+ ``try:/finally:`` block is optional, however will ensure that the close
+ takes place even if there are database errors::
+
+ session = Session()
+ try:
+ session.add(some_object)
+ session.add(some_other_object)
+ session.commit()
+ finally:
+ session.close()
+
+ :class:`.sessionmaker` acts as a factory for :class:`_orm.Session`
+ objects in the same way as an :class:`_engine.Engine` acts as a factory
+ for :class:`_engine.Connection` objects. In this way it also includes
+ a :meth:`_orm.sessionmaker.begin` method that provides a context
+ manager which both begins and commits a transaction, as well as closes
+ out the :class:`_orm.Session` when complete, rolling back the transaction
+ if any errors occur::
+
+ Session = sessionmaker(engine)
+
+ with Session.begin() as session:
+ session.add(some_object)
+ session.add(some_other_object)
+ # commits transaction, closes session
+
+ .. versionadded:: 1.4
+
+ When calling upon :class:`_orm.sessionmaker` to construct a
+ :class:`_orm.Session`, keyword arguments may also be passed to the
+ method; these arguments will override the globally configured
+ parameters. Below we use a :class:`_orm.sessionmaker` bound to a certain
+ :class:`_engine.Engine` to produce a :class:`_orm.Session` that is instead
+ bound to a specific :class:`_engine.Connection` procured from that engine::
+
+ Session = sessionmaker(engine)
+
+ # bind an individual session to a connection
+
+ with engine.connect() as connection:
+ with Session(bind=connection) as session:
+ # work with session
+
+ The class also includes a method :meth:`_orm.sessionmaker.configure`, which
+ can be used to specify additional keyword arguments to the factory, which
+ will take effect for subsequent :class:`.Session` objects generated. This
+ is usually used to associate one or more :class:`_engine.Engine` objects
+ with an existing
+ :class:`.sessionmaker` factory before it is first used::
+
+ # application starts, sessionmaker does not have
+ # an engine bound yet
+ Session = sessionmaker()
+
+ # ... later, when an engine URL is read from a configuration
+ # file or other events allow the engine to be created
+ engine = create_engine('sqlite:///foo.db')
+ Session.configure(bind=engine)
+
+ sess = Session()
+ # work with session
+
+ .. seealso::
+
+ :ref:`session_getting` - introductory text on creating
+ sessions using :class:`.sessionmaker`.
+
+ """
+
+ def __init__(
+ self,
+ bind=None,
+ class_=Session,
+ autoflush=True,
+ autocommit=False,
+ expire_on_commit=True,
+ info=None,
+ **kw
+ ):
+ r"""Construct a new :class:`.sessionmaker`.
+
+ All arguments here except for ``class_`` correspond to arguments
+ accepted by :class:`.Session` directly. See the
+ :meth:`.Session.__init__` docstring for more details on parameters.
+
+ :param bind: a :class:`_engine.Engine` or other :class:`.Connectable`
+ with
+ which newly created :class:`.Session` objects will be associated.
+ :param class\_: class to use in order to create new :class:`.Session`
+ objects. Defaults to :class:`.Session`.
+ :param autoflush: The autoflush setting to use with newly created
+ :class:`.Session` objects.
+ :param autocommit: The autocommit setting to use with newly created
+ :class:`.Session` objects.
+ :param expire_on_commit=True: the
+ :paramref:`_orm.Session.expire_on_commit` setting to use
+ with newly created :class:`.Session` objects.
+
+ :param info: optional dictionary of information that will be available
+ via :attr:`.Session.info`. Note this dictionary is *updated*, not
+ replaced, when the ``info`` parameter is specified to the specific
+ :class:`.Session` construction operation.
+
+ :param \**kw: all other keyword arguments are passed to the
+ constructor of newly created :class:`.Session` objects.
+
+ """
+ kw["bind"] = bind
+ kw["autoflush"] = autoflush
+ kw["autocommit"] = autocommit
+ kw["expire_on_commit"] = expire_on_commit
+ if info is not None:
+ kw["info"] = info
+ self.kw = kw
+ # make our own subclass of the given class, so that
+ # events can be associated with it specifically.
+ self.class_ = type(class_.__name__, (class_,), {})
+
+ def begin(self):
+ """Produce a context manager that both provides a new
+ :class:`_orm.Session` as well as a transaction that commits.
+
+
+ e.g.::
+
+ Session = sessionmaker(some_engine)
+
+ with Session.begin() as session:
+ session.add(some_object)
+
+ # commits transaction, closes session
+
+ .. versionadded:: 1.4
+
+
+ """
+
+ session = self()
+ return session._maker_context_manager()
+
+ def __call__(self, **local_kw):
+ """Produce a new :class:`.Session` object using the configuration
+ established in this :class:`.sessionmaker`.
+
+ In Python, the ``__call__`` method is invoked on an object when
+ it is "called" in the same way as a function::
+
+ Session = sessionmaker()
+ session = Session() # invokes sessionmaker.__call__()
+
+ """
+ for k, v in self.kw.items():
+ if k == "info" and "info" in local_kw:
+ d = v.copy()
+ d.update(local_kw["info"])
+ local_kw["info"] = d
+ else:
+ local_kw.setdefault(k, v)
+ return self.class_(**local_kw)
+
+ def configure(self, **new_kw):
+ """(Re)configure the arguments for this sessionmaker.
+
+ e.g.::
+
+ Session = sessionmaker()
+
+ Session.configure(bind=create_engine('sqlite://'))
+ """
+ self.kw.update(new_kw)
+
+ def __repr__(self):
+ return "%s(class_=%r, %s)" % (
+ self.__class__.__name__,
+ self.class_.__name__,
+ ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()),
+ )
+
+
+def close_all_sessions():
+ """Close all sessions in memory.
+
+ This function consults a global registry of all :class:`.Session` objects
+ and calls :meth:`.Session.close` on them, which resets them to a clean
+ state.
+
+ This function is not for general use but may be useful for test suites
+ within the teardown scheme.
+
+ .. versionadded:: 1.3
+
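+ A typical sketch within a test suite teardown (the surrounding
+ fixture is illustrative only)::
+
+     from sqlalchemy.orm import close_all_sessions
+
+     def teardown():
+         close_all_sessions()
+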
+ """
+
+ for sess in _sessions.values():
+ sess.close()
+
+
+def make_transient(instance):
+ """Alter the state of the given instance so that it is :term:`transient`.
+
+ .. note::
+
+ :func:`.make_transient` is a special-case function for
+ advanced use cases only.
+
+ The given mapped instance is assumed to be in the :term:`persistent` or
+ :term:`detached` state. The function will remove its association with any
+ :class:`.Session` as well as its :attr:`.InstanceState.identity`. The
+ effect is that the object will behave as though it were newly constructed,
+ except retaining any attribute / collection values that were loaded at the
+ time of the call. The :attr:`.InstanceState.deleted` flag is also reset
+ if this object had been deleted as a result of using
+ :meth:`.Session.delete`.
+
+ .. warning::
+
+ :func:`.make_transient` does **not** "unexpire" or otherwise eagerly
+ load ORM-mapped attributes that are not currently loaded at the time
+ the function is called. This includes attributes which:
+
+ * were expired via :meth:`.Session.expire`
+
+ * were expired as the natural effect of committing a session
+ transaction, e.g. :meth:`.Session.commit`
+
+ * are normally :term:`lazy loaded` but are not currently loaded
+
+ * are "deferred" via :ref:`deferred` and are not yet loaded
+
+ * were not present in the query which loaded this object, such as that
+ which is common in joined table inheritance and other scenarios.
+
+ After :func:`.make_transient` is called, unloaded attributes such
+ as those above will normally resolve to the value ``None`` when
+ accessed, or an empty collection for a collection-oriented attribute.
+ As the object is transient and un-associated with any database
+ identity, it will no longer retrieve these values.
+
+ .. seealso::
+
+ :func:`.make_transient_to_detached`
+
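+ A brief sketch, assuming a mapped ``User`` class and an open
+ ``session``::
+
+     from sqlalchemy.orm import make_transient
+
+     user = session.get(User, 5)
+
+     make_transient(user)
+
+     # no longer associated with the session and has no identity key;
+     # attribute values that were already loaded are retained
+     assert user not in session
+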
+ """
+ state = attributes.instance_state(instance)
+ s = _state_session(state)
+ if s:
+ s._expunge_states([state])
+
+ # remove expired state
+ state.expired_attributes.clear()
+
+ # remove deferred callables
+ if state.callables:
+ del state.callables
+
+ if state.key:
+ del state.key
+ if state._deleted:
+ del state._deleted
+
+
+def make_transient_to_detached(instance):
+ """Make the given transient instance :term:`detached`.
+
+ .. note::
+
+ :func:`.make_transient_to_detached` is a special-case function for
+ advanced use cases only.
+
+ All attribute history on the given instance
+ will be reset as though the instance were freshly loaded
+ from a query. Missing attributes will be marked as expired.
+ The primary key attributes of the object, which are required, will be made
+ into the "key" of the instance.
+
+ The object can then be added to a session, or merged
+ possibly with the load=False flag, at which point it will look
+ as if it were loaded that way, without emitting SQL.
+
+ This is a special use case function that differs from a normal
+ call to :meth:`.Session.merge` in that a given persistent state
+ can be manufactured without any SQL calls.
+
+ .. seealso::
+
+ :func:`.make_transient`
+
+ :meth:`.Session.enable_relationship_loading`
+
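+ A minimal sketch, assuming a mapped ``User`` class whose state is
+ available from some external source such as a cache::
+
+     from sqlalchemy.orm import make_transient_to_detached
+
+     user = User(id=5, name="cached name")
+
+     make_transient_to_detached(user)
+
+     # the object now acts as though it were loaded; adding it emits
+     # no SELECT, and no INSERT occurs at flush time
+     session.add(user)
+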
+ """
+ state = attributes.instance_state(instance)
+ if state.session_id or state.key:
+ raise sa_exc.InvalidRequestError("Given object must be transient")
+ state.key = state.mapper._identity_key_from_state(state)
+ if state._deleted:
+ del state._deleted
+ state._commit_all(state.dict)
+ state._expire_attributes(state.dict, state.unloaded_expirable)
+
+
+def object_session(instance):
+ """Return the :class:`.Session` to which the given instance belongs.
+
+ This is essentially the same as the :attr:`.InstanceState.session`
+ accessor. See that attribute for details.
+
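+ E.g., assuming a mapped ``User`` class::
+
+     user = session.get(User, 5)
+
+     assert object_session(user) is session
+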
+ """
+
+ try:
+ state = attributes.instance_state(instance)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(instance),
+ replace_context=err,
+ )
+ else:
+ return _state_session(state)
+
+
+_new_sessionid = util.counter()
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
new file mode 100644
index 0000000..9718024
--- /dev/null
+++ b/lib/sqlalchemy/orm/state.py
@@ -0,0 +1,1025 @@
+# orm/state.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines instrumentation of instances.
+
+This module is usually not directly visible to user applications, but
+defines a large part of the ORM's interactivity.
+
+"""
+
+import weakref
+
+from . import base
+from . import exc as orm_exc
+from . import interfaces
+from .base import ATTR_WAS_SET
+from .base import INIT_OK
+from .base import NEVER_SET
+from .base import NO_VALUE
+from .base import PASSIVE_NO_INITIALIZE
+from .base import PASSIVE_NO_RESULT
+from .base import PASSIVE_OFF
+from .base import SQL_OK
+from .path_registry import PathRegistry
+from .. import exc as sa_exc
+from .. import inspection
+from .. import util
+
+
+# late-populated by session.py
+_sessions = None
+
+# optionally late-provided by sqlalchemy.ext.asyncio.session
+_async_provider = None
+
+
+@inspection._self_inspects
+class InstanceState(interfaces.InspectionAttrInfo):
+ """tracks state information at the instance level.
+
+ The :class:`.InstanceState` is a key object used by the
+ SQLAlchemy ORM in order to track the state of an object;
+ it is created the moment an object is instantiated, typically
+ as a result of :term:`instrumentation` which SQLAlchemy applies
+ to the ``__init__()`` method of the class.
+
+ :class:`.InstanceState` is also a semi-public object,
+ available for runtime inspection as to the state of a
+ mapped instance, including information such as its current
+ status within a particular :class:`.Session` and details
+ about data on individual attributes. The public API
+ in order to acquire a :class:`.InstanceState` object
+ is to use the :func:`_sa.inspect` system::
+
+ >>> from sqlalchemy import inspect
+ >>> insp = inspect(some_mapped_object)
+ >>> insp.attrs.nickname.history
+ History(added=['new nickname'], unchanged=(), deleted=['nickname'])
+
+ .. seealso::
+
+ :ref:`orm_mapper_inspection_instancestate`
+
+ """
+
+ session_id = None
+ key = None
+ runid = None
+ load_options = util.EMPTY_SET
+ load_path = PathRegistry.root
+ insert_order = None
+ _strong_obj = None
+ modified = False
+ expired = False
+ _deleted = False
+ _load_pending = False
+ _orphaned_outside_of_session = False
+ is_instance = True
+ identity_token = None
+ _last_known_values = ()
+
+ callables = ()
+ """A namespace where a per-state loader callable can be associated.
+
+ In SQLAlchemy 1.0, this is only used for lazy loaders / deferred
+ loaders that were set up via query option.
+
+ Previously, callables was used also to indicate expired attributes
+ by storing a link to the InstanceState itself in this dictionary.
+ This role is now handled by the expired_attributes set.
+
+ """
+
+ def __init__(self, obj, manager):
+ self.class_ = obj.__class__
+ self.manager = manager
+ self.obj = weakref.ref(obj, self._cleanup)
+ self.committed_state = {}
+ self.expired_attributes = set()
+
+ expired_attributes = None
+ """The set of keys which are 'expired' to be loaded by
+ the manager's deferred scalar loader, assuming no pending
+ changes.
+
+ see also the ``unmodified`` collection which is intersected
+ against this set when a refresh operation occurs."""
+
+ @util.memoized_property
+ def attrs(self):
+ """Return a namespace representing each attribute on
+ the mapped object, including its current value
+ and history.
+
+ The returned object is an instance of :class:`.AttributeState`.
+ This object allows inspection of the current data
+ within an attribute as well as attribute history
+ since the last flush.
+
+ """
+ return util.ImmutableProperties(
+ dict((key, AttributeState(self, key)) for key in self.manager)
+ )
+
+ @property
+ def transient(self):
+ """Return ``True`` if the object is :term:`transient`.
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is None and not self._attached
+
+ @property
+ def pending(self):
+ """Return ``True`` if the object is :term:`pending`.
+
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is None and self._attached
+
+ @property
+ def deleted(self):
+ """Return ``True`` if the object is :term:`deleted`.
+
+ An object that is in the deleted state is guaranteed to
+ not be within the :attr:`.Session.identity_map` of its parent
+ :class:`.Session`; however if the session's transaction is rolled
+ back, the object will be restored to the persistent state and
+ the identity map.
+
+ .. note::
+
+ The :attr:`.InstanceState.deleted` attribute refers to a specific
+ state of the object that occurs between the "persistent" and
+ "detached" states; once the object is :term:`detached`, the
+ :attr:`.InstanceState.deleted` attribute **no longer returns
+ True**; in order to detect that a state was deleted, regardless
+ of whether or not the object is associated with a
+ :class:`.Session`, use the :attr:`.InstanceState.was_deleted`
+ accessor.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is not None and self._attached and self._deleted
+
+ @property
+ def was_deleted(self):
+ """Return True if this object is or was previously in the
+ "deleted" state and has not been reverted to persistent.
+
+ This flag returns True once the object was deleted in flush.
+ When the object is expunged from the session either explicitly
+ or via transaction commit and enters the "detached" state,
+ this flag will continue to report True.
+
+ .. versionadded:: 1.1 - added a local method form of
+ :func:`.orm.util.was_deleted`.
+
+ .. seealso::
+
+ :attr:`.InstanceState.deleted` - refers to the "deleted" state
+
+ :func:`.orm.util.was_deleted` - standalone function
+
+ :ref:`session_object_states`
+
+ """
+ return self._deleted
+
+ @property
+ def persistent(self):
+ """Return ``True`` if the object is :term:`persistent`.
+
+ An object that is in the persistent state is guaranteed to
+ be within the :attr:`.Session.identity_map` of its parent
+ :class:`.Session`.
+
+ .. versionchanged:: 1.1 The :attr:`.InstanceState.persistent`
+ accessor no longer returns True for an object that was
+ "deleted" within a flush; use the :attr:`.InstanceState.deleted`
+ accessor to detect this state. This allows the "persistent"
+ state to guarantee membership in the identity map.
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is not None and self._attached and not self._deleted
+
+ @property
+ def detached(self):
+ """Return ``True`` if the object is :term:`detached`.
+
+ .. seealso::
+
+ :ref:`session_object_states`
+
+ """
+ return self.key is not None and not self._attached
+
+ @property
+ @util.preload_module("sqlalchemy.orm.session")
+ def _attached(self):
+ return (
+ self.session_id is not None
+ and self.session_id in util.preloaded.orm_session._sessions
+ )
+
+ def _track_last_known_value(self, key):
+ """Track the last known value of a particular key after expiration
+ operations.
+
+ .. versionadded:: 1.3
+
+ """
+
+ if key not in self._last_known_values:
+ self._last_known_values = dict(self._last_known_values)
+ self._last_known_values[key] = NO_VALUE
+
+ @property
+ def session(self):
+ """Return the owning :class:`.Session` for this instance,
+ or ``None`` if none available.
+
+ Note that the result here can in some cases be *different*
+ from that of ``obj in session``; an object that's been deleted
+ will report as not ``in session``; however, if the transaction is
+ still in progress, this attribute will still refer to that session.
+ Only when the transaction is completed does the object become
+ fully detached under normal circumstances.
+
+ .. seealso::
+
+ :attr:`_orm.InstanceState.async_session`
+
+ """
+ if self.session_id:
+ try:
+ return _sessions[self.session_id]
+ except KeyError:
+ pass
+ return None
+
+ @property
+ def async_session(self):
+ """Return the owning :class:`_asyncio.AsyncSession` for this instance,
+ or ``None`` if none available.
+
+ This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio`
+ API is in use for this ORM object. The returned
+ :class:`_asyncio.AsyncSession` object will be a proxy for the
+ :class:`_orm.Session` object that would be returned from the
+ :attr:`_orm.InstanceState.session` attribute for this
+ :class:`_orm.InstanceState`.
+
+ .. versionadded:: 1.4.18
+
+ .. seealso::
+
+ :ref:`asyncio_toplevel`
+
+ """
+ if _async_provider is None:
+ return None
+
+ sess = self.session
+ if sess is not None:
+ return _async_provider(sess)
+ else:
+ return None
+
+ @property
+ def object(self):
+ """Return the mapped object represented by this
+ :class:`.InstanceState`."""
+ return self.obj()
+
+ @property
+ def identity(self):
+ """Return the mapped identity of the mapped object.
+ This is the primary key identity as persisted by the ORM
+ which can always be passed directly to
+ :meth:`_query.Query.get`.
+
+ Returns ``None`` if the object has no primary key identity.
+
+ .. note::
+ An object which is :term:`transient` or :term:`pending`
+ does **not** have a mapped identity until it is flushed,
+ even if its attributes include primary key values.
+
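+ E.g., a sketch using :func:`_sa.inspect`, assuming a mapped ``User``
+ class with a single-column primary key::
+
+     >>> from sqlalchemy import inspect
+     >>> user = session.get(User, 5)
+     >>> inspect(user).identity
+     (5,)
+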
+ """
+ if self.key is None:
+ return None
+ else:
+ return self.key[1]
+
+ @property
+ def identity_key(self):
+ """Return the identity key for the mapped object.
+
+ This is the key used to locate the object within
+ the :attr:`.Session.identity_map` mapping. It contains
+ the identity as returned by :attr:`.identity` within it.
+
+
+ """
+ # TODO: just change .key to .identity_key across
+ # the board ? probably
+ return self.key
+
+ @util.memoized_property
+ def parents(self):
+ return {}
+
+ @util.memoized_property
+ def _pending_mutations(self):
+ return {}
+
+ @util.memoized_property
+ def _empty_collections(self):
+ return {}
+
+ @util.memoized_property
+ def mapper(self):
+ """Return the :class:`_orm.Mapper` used for this mapped object."""
+ return self.manager.mapper
+
+ @property
+ def has_identity(self):
+ """Return ``True`` if this object has an identity key.
+
+ This should always have the same value as the
+ expression ``state.persistent`` or ``state.detached``.
+
+ """
+ return bool(self.key)
+
+ @classmethod
+ def _detach_states(self, states, session, to_transient=False):
+ persistent_to_detached = (
+ session.dispatch.persistent_to_detached or None
+ )
+ deleted_to_detached = session.dispatch.deleted_to_detached or None
+ pending_to_transient = session.dispatch.pending_to_transient or None
+ persistent_to_transient = (
+ session.dispatch.persistent_to_transient or None
+ )
+
+ for state in states:
+ deleted = state._deleted
+ pending = state.key is None
+ persistent = not pending and not deleted
+
+ state.session_id = None
+
+ if to_transient and state.key:
+ del state.key
+ if persistent:
+ if to_transient:
+ if persistent_to_transient is not None:
+ persistent_to_transient(session, state)
+ elif persistent_to_detached is not None:
+ persistent_to_detached(session, state)
+ elif deleted and deleted_to_detached is not None:
+ deleted_to_detached(session, state)
+ elif pending and pending_to_transient is not None:
+ pending_to_transient(session, state)
+
+ state._strong_obj = None
+
+ def _detach(self, session=None):
+ if session:
+ InstanceState._detach_states([self], session)
+ else:
+ self.session_id = self._strong_obj = None
+
+ def _dispose(self):
+ self._detach()
+ del self.obj
+
+ def _cleanup(self, ref):
+ """Weakref callback cleanup.
+
+ This callable cleans out the state when it is being garbage
+ collected.
+
+ this _cleanup **assumes** that there are no strong refs to us!
+ Will not work otherwise!
+
+ """
+
+ # Python builtins become undefined during interpreter shutdown.
+ # Guard against exceptions during this phase, as the method cannot
+ # proceed in any case if builtins have been undefined.
+ if dict is None:
+ return
+
+ instance_dict = self._instance_dict()
+ if instance_dict is not None:
+ instance_dict._fast_discard(self)
+ del self._instance_dict
+
+ # we can't possibly be in instance_dict._modified
+ # b.c. this is weakref cleanup only, that set
+ # is strong referencing!
+ # assert self not in instance_dict._modified
+
+ self.session_id = self._strong_obj = None
+ del self.obj
+
+ def obj(self):
+ return None
+
+ @property
+ def dict(self):
+ """Return the instance dict used by the object.
+
+ Under normal circumstances, this is always synonymous
+ with the ``__dict__`` attribute of the mapped object,
+ unless an alternative instrumentation system has been
+ configured.
+
+ In the case that the actual object has been garbage
+ collected, this accessor returns a blank dictionary.
+
+ """
+ o = self.obj()
+ if o is not None:
+ return base.instance_dict(o)
+ else:
+ return {}
+
+ def _initialize_instance(*mixed, **kwargs):
+ self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa
+ manager = self.manager
+
+ manager.dispatch.init(self, args, kwargs)
+
+ try:
+ return manager.original_init(*mixed[1:], **kwargs)
+ except:
+ with util.safe_reraise():
+ manager.dispatch.init_failure(self, args, kwargs)
+
+ def get_history(self, key, passive):
+ return self.manager[key].impl.get_history(self, self.dict, passive)
+
+ def get_impl(self, key):
+ return self.manager[key].impl
+
+ def _get_pending_mutation(self, key):
+ if key not in self._pending_mutations:
+ self._pending_mutations[key] = PendingCollection()
+ return self._pending_mutations[key]
+
+ def __getstate__(self):
+ state_dict = {"instance": self.obj()}
+ state_dict.update(
+ (k, self.__dict__[k])
+ for k in (
+ "committed_state",
+ "_pending_mutations",
+ "modified",
+ "expired",
+ "callables",
+ "key",
+ "parents",
+ "load_options",
+ "class_",
+ "expired_attributes",
+ "info",
+ )
+ if k in self.__dict__
+ )
+ if self.load_path:
+ state_dict["load_path"] = self.load_path.serialize()
+
+ state_dict["manager"] = self.manager._serialize(self, state_dict)
+
+ return state_dict
+
+ def __setstate__(self, state_dict):
+ inst = state_dict["instance"]
+ if inst is not None:
+ self.obj = weakref.ref(inst, self._cleanup)
+ self.class_ = inst.__class__
+ else:
+            # None is possible here, which is generally new as of 0.7.4
+            # due to the storage of state in "parents".  "class_" is
+            # also new.
+ self.obj = None
+ self.class_ = state_dict["class_"]
+
+ self.committed_state = state_dict.get("committed_state", {})
+ self._pending_mutations = state_dict.get("_pending_mutations", {})
+ self.parents = state_dict.get("parents", {})
+ self.modified = state_dict.get("modified", False)
+ self.expired = state_dict.get("expired", False)
+ if "info" in state_dict:
+ self.info.update(state_dict["info"])
+ if "callables" in state_dict:
+ self.callables = state_dict["callables"]
+
+            try:
+                self.expired_attributes = state_dict["expired_attributes"]
+            except KeyError:
+                self.expired_attributes = set()
+                # 0.9 and earlier compat
+                for k in list(self.callables):
+                    if self.callables[k] is self:
+                        self.expired_attributes.add(k)
+                        del self.callables[k]
+ else:
+ if "expired_attributes" in state_dict:
+ self.expired_attributes = state_dict["expired_attributes"]
+ else:
+ self.expired_attributes = set()
+
+ self.__dict__.update(
+ [
+ (k, state_dict[k])
+ for k in ("key", "load_options")
+ if k in state_dict
+ ]
+ )
+ if self.key:
+ try:
+ self.identity_token = self.key[2]
+ except IndexError:
+ # 1.1 and earlier compat before identity_token
+ assert len(self.key) == 2
+ self.key = self.key + (None,)
+ self.identity_token = None
+
+ if "load_path" in state_dict:
+ self.load_path = PathRegistry.deserialize(state_dict["load_path"])
+
+ state_dict["manager"](self, inst, state_dict)
+
+ def _reset(self, dict_, key):
+ """Remove the given attribute and any
+ callables associated with it."""
+
+ old = dict_.pop(key, None)
+ if old is not None and self.manager[key].impl.collection:
+ self.manager[key].impl._invalidate_collection(old)
+ self.expired_attributes.discard(key)
+ if self.callables:
+ self.callables.pop(key, None)
+
+ def _copy_callables(self, from_):
+ if "callables" in from_.__dict__:
+ self.callables = dict(from_.callables)
+
+ @classmethod
+ def _instance_level_callable_processor(cls, manager, fn, key):
+ impl = manager[key].impl
+ if impl.collection:
+
+ def _set_callable(state, dict_, row):
+ if "callables" not in state.__dict__:
+ state.callables = {}
+ old = dict_.pop(key, None)
+ if old is not None:
+ impl._invalidate_collection(old)
+ state.callables[key] = fn
+
+ else:
+
+ def _set_callable(state, dict_, row):
+ if "callables" not in state.__dict__:
+ state.callables = {}
+ state.callables[key] = fn
+
+ return _set_callable
+
+ def _expire(self, dict_, modified_set):
+ self.expired = True
+ if self.modified:
+ modified_set.discard(self)
+ self.committed_state.clear()
+ self.modified = False
+
+ self._strong_obj = None
+
+ if "_pending_mutations" in self.__dict__:
+ del self.__dict__["_pending_mutations"]
+
+ if "parents" in self.__dict__:
+ del self.__dict__["parents"]
+
+ self.expired_attributes.update(
+ [impl.key for impl in self.manager._loader_impls]
+ )
+
+ if self.callables:
+ # the per state loader callables we can remove here are
+ # LoadDeferredColumns, which undefers a column at the instance
+ # level that is mapped with deferred, and LoadLazyAttribute,
+ # which lazy loads a relationship at the instance level that
+ # is mapped with "noload" or perhaps "immediateload".
+            # Before 1.4, only column-based attributes could be considered
+            # "expired", so they were the only ones "unexpired" here, which
+            # means they were made deferred again.  As of 1.4, the same
+            # treatment is also applied to relationships; that is, an
+            # instance-level lazy loader is reset in the same way as a
+            # column loader.
+ for k in self.expired_attributes.intersection(self.callables):
+ del self.callables[k]
+
+ for k in self.manager._collection_impl_keys.intersection(dict_):
+ collection = dict_.pop(k)
+ collection._sa_adapter.invalidated = True
+
+ if self._last_known_values:
+ self._last_known_values.update(
+ (k, dict_[k]) for k in self._last_known_values if k in dict_
+ )
+
+ for key in self.manager._all_key_set.intersection(dict_):
+ del dict_[key]
+
+ self.manager.dispatch.expire(self, None)
+
+ def _expire_attributes(self, dict_, attribute_names, no_loader=False):
+ pending = self.__dict__.get("_pending_mutations", None)
+
+ callables = self.callables
+
+ for key in attribute_names:
+ impl = self.manager[key].impl
+ if impl.accepts_scalar_loader:
+ if no_loader and (impl.callable_ or key in callables):
+ continue
+
+ self.expired_attributes.add(key)
+ if callables and key in callables:
+ del callables[key]
+ old = dict_.pop(key, NO_VALUE)
+ if impl.collection and old is not NO_VALUE:
+ impl._invalidate_collection(old)
+
+ if (
+ self._last_known_values
+ and key in self._last_known_values
+ and old is not NO_VALUE
+ ):
+ self._last_known_values[key] = old
+
+ self.committed_state.pop(key, None)
+ if pending:
+ pending.pop(key, None)
+
+ self.manager.dispatch.expire(self, attribute_names)
+
+ def _load_expired(self, state, passive):
+        """Allow the InstanceState to act as a deferred callable for
+        loading expired attributes, which is also serializable
+        (picklable).
+
+ """
+
+ if not passive & SQL_OK:
+ return PASSIVE_NO_RESULT
+
+ toload = self.expired_attributes.intersection(self.unmodified)
+ toload = toload.difference(
+ attr
+ for attr in toload
+ if not self.manager[attr].impl.load_on_unexpire
+ )
+
+ self.manager.expired_attribute_loader(self, toload, passive)
+
+ # if the loader failed, or this
+ # instance state didn't have an identity,
+ # the attributes still might be in the callables
+ # dict. ensure they are removed.
+ self.expired_attributes.clear()
+
+ return ATTR_WAS_SET
+
+ @property
+ def unmodified(self):
+ """Return the set of keys which have no uncommitted changes"""
+
+ return set(self.manager).difference(self.committed_state)
+
+ def unmodified_intersection(self, keys):
+ """Return self.unmodified.intersection(keys)."""
+
+ return (
+ set(keys)
+ .intersection(self.manager)
+ .difference(self.committed_state)
+ )
+
+ @property
+ def unloaded(self):
+ """Return the set of keys which do not have a loaded value.
+
+ This includes expired attributes and any other attribute that
+ was never populated or modified.
+
+ """
+ return (
+ set(self.manager)
+ .difference(self.committed_state)
+ .difference(self.dict)
+ )
+
+ @property
+ def unloaded_expirable(self):
+ """Return the set of keys which do not have a loaded value.
+
+ This includes expired attributes and any other attribute that
+ was never populated or modified.
+
+ """
+ return self.unloaded
+
+ @property
+ def _unloaded_non_object(self):
+ return self.unloaded.intersection(
+ attr
+ for attr in self.manager
+ if self.manager[attr].impl.accepts_scalar_loader
+ )
+
+ def _instance_dict(self):
+ return None
+
+ def _modified_event(
+ self, dict_, attr, previous, collection=False, is_userland=False
+ ):
+ if attr:
+ if not attr.send_modified_events:
+ return
+ if is_userland and attr.key not in dict_:
+ raise sa_exc.InvalidRequestError(
+ "Can't flag attribute '%s' modified; it's not present in "
+ "the object state" % attr.key
+ )
+ if attr.key not in self.committed_state or is_userland:
+ if collection:
+ if previous is NEVER_SET:
+ if attr.key in dict_:
+ previous = dict_[attr.key]
+
+ if previous not in (None, NO_VALUE, NEVER_SET):
+ previous = attr.copy(previous)
+ self.committed_state[attr.key] = previous
+
+ if attr.key in self._last_known_values:
+ self._last_known_values[attr.key] = NO_VALUE
+
+ # assert self._strong_obj is None or self.modified
+
+ if (self.session_id and self._strong_obj is None) or not self.modified:
+ self.modified = True
+ instance_dict = self._instance_dict()
+ if instance_dict:
+ has_modified = bool(instance_dict._modified)
+ instance_dict._modified.add(self)
+ else:
+ has_modified = False
+
+ # only create _strong_obj link if attached
+ # to a session
+
+ inst = self.obj()
+ if self.session_id:
+ self._strong_obj = inst
+
+ # if identity map already had modified objects,
+ # assume autobegin already occurred, else check
+ # for autobegin
+ if not has_modified:
+ # inline of autobegin, to ensure session transaction
+ # snapshot is established
+ try:
+ session = _sessions[self.session_id]
+ except KeyError:
+ pass
+ else:
+ if session._transaction is None:
+ session._autobegin()
+
+ if inst is None and attr:
+ raise orm_exc.ObjectDereferencedError(
+ "Can't emit change event for attribute '%s' - "
+ "parent object of type %s has been garbage "
+ "collected."
+ % (self.manager[attr.key], base.state_class_str(self))
+ )
+
+ def _commit(self, dict_, keys):
+ """Commit attributes.
+
+ This is used by a partial-attribute load operation to mark committed
+ those attributes which were refreshed from the database.
+
+ Attributes marked as "expired" can potentially remain "expired" after
+ this step if a value was not populated in state.dict.
+
+ """
+ for key in keys:
+ self.committed_state.pop(key, None)
+
+ self.expired = False
+
+ self.expired_attributes.difference_update(
+ set(keys).intersection(dict_)
+ )
+
+        # the per-key commit removes object-level callables, while
+        # _commit_all does not.  it's not clear whether this behavior
+        # has a rationale, however tests do ensure this is what it does.
+ if self.callables:
+ for key in (
+ set(self.callables).intersection(keys).intersection(dict_)
+ ):
+ del self.callables[key]
+
+ def _commit_all(self, dict_, instance_dict=None):
+ """commit all attributes unconditionally.
+
+ This is used after a flush() or a full load/refresh
+ to remove all pending state from the instance.
+
+ - all attributes are marked as "committed"
+ - the "strong dirty reference" is removed
+ - the "modified" flag is set to False
+ - any "expired" markers for scalar attributes loaded are removed.
+ - lazy load callables for objects / collections *stay*
+
+ Attributes marked as "expired" can potentially remain
+ "expired" after this step if a value was not populated in state.dict.
+
+ """
+ self._commit_all_states([(self, dict_)], instance_dict)
+
+ @classmethod
+ def _commit_all_states(self, iter_, instance_dict=None):
+ """Mass / highly inlined version of commit_all()."""
+
+ for state, dict_ in iter_:
+ state_dict = state.__dict__
+
+ state.committed_state.clear()
+
+ if "_pending_mutations" in state_dict:
+ del state_dict["_pending_mutations"]
+
+ state.expired_attributes.difference_update(dict_)
+
+ if instance_dict and state.modified:
+ instance_dict._modified.discard(state)
+
+ state.modified = state.expired = False
+ state._strong_obj = None
+
+
+class AttributeState(object):
+ """Provide an inspection interface corresponding
+ to a particular attribute on a particular mapped object.
+
+ The :class:`.AttributeState` object is accessed
+ via the :attr:`.InstanceState.attrs` collection
+ of a particular :class:`.InstanceState`::
+
+ from sqlalchemy import inspect
+
+ insp = inspect(some_mapped_object)
+ attr_state = insp.attrs.some_attribute
+
+ """
+
+ def __init__(self, state, key):
+ self.state = state
+ self.key = key
+
+ @property
+ def loaded_value(self):
+ """The current value of this attribute as loaded from the database.
+
+ If the value has not been loaded, or is otherwise not present
+ in the object's dictionary, returns NO_VALUE.
+
+ """
+ return self.state.dict.get(self.key, NO_VALUE)
+
+ @property
+ def value(self):
+ """Return the value of this attribute.
+
+ This operation is equivalent to accessing the object's
+ attribute directly or via ``getattr()``, and will fire
+ off any pending loader callables if needed.
+
+ """
+ return self.state.manager[self.key].__get__(
+ self.state.obj(), self.state.class_
+ )
+
+ @property
+ def history(self):
+ """Return the current **pre-flush** change history for
+ this attribute, via the :class:`.History` interface.
+
+ This method will **not** emit loader callables if the value of the
+ attribute is unloaded.
+
+ .. note::
+
+ The attribute history system tracks changes on a **per flush
+ basis**. Each time the :class:`.Session` is flushed, the history
+ of each attribute is reset to empty. The :class:`.Session` by
+            default autoflushes each time a :class:`_query.Query` is
+            invoked.  For options on how to control this, see
+            :ref:`session_flushing`.
+
+
+ .. seealso::
+
+ :meth:`.AttributeState.load_history` - retrieve history
+ using loader callables if the value is not locally present.
+
+ :func:`.attributes.get_history` - underlying function
+
+ """
+ return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE)
+
+ def load_history(self):
+ """Return the current **pre-flush** change history for
+ this attribute, via the :class:`.History` interface.
+
+ This method **will** emit loader callables if the value of the
+ attribute is unloaded.
+
+ .. note::
+
+ The attribute history system tracks changes on a **per flush
+ basis**. Each time the :class:`.Session` is flushed, the history
+ of each attribute is reset to empty. The :class:`.Session` by
+            default autoflushes each time a :class:`_query.Query` is
+            invoked.  For options on how to control this, see
+            :ref:`session_flushing`.
+
+ .. seealso::
+
+ :attr:`.AttributeState.history`
+
+ :func:`.attributes.get_history` - underlying function
+
+ .. versionadded:: 0.9.0
+
+ """
+ return self.state.get_history(self.key, PASSIVE_OFF ^ INIT_OK)
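+
+    # Illustrative example (not part of the SQLAlchemy source): assuming a
+    # mapped instance ``some_user`` with a ``name`` attribute, the accessors
+    # above differ roughly as follows::
+    #
+    #     from sqlalchemy import inspect
+    #
+    #     attr = inspect(some_user).attrs.name
+    #     attr.loaded_value    # value from __dict__, or NO_VALUE if unloaded
+    #     attr.value           # like getattr(); may invoke a lazy/deferred load
+    #     attr.history         # pre-flush History, without emitting loader SQL
+    #     attr.load_history()  # same, but emits loaders if the value is unloaded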
+
+
+class PendingCollection(object):
+ """A writable placeholder for an unloaded collection.
+
+ Stores items appended to and removed from a collection that has not yet
+ been loaded. When the collection is loaded, the changes stored in
+ PendingCollection are applied to it to produce the final result.
+
+ """
+
+ def __init__(self):
+ self.deleted_items = util.IdentitySet()
+ self.added_items = util.OrderedIdentitySet()
+
+ def append(self, value):
+ if value in self.deleted_items:
+ self.deleted_items.remove(value)
+ else:
+ self.added_items.add(value)
+
+ def remove(self, value):
+ if value in self.added_items:
+ self.added_items.remove(value)
+ else:
+ self.deleted_items.add(value)
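+
+
+# Illustrative sketch (not part of the SQLAlchemy source): a minimal
+# demonstration of the buffering semantics above.  A remove() cancels a
+# prior append() of the same object; removing something that was never
+# appended is recorded as a pending deletion to apply once the real
+# collection is loaded.
+def _pending_collection_sketch():
+    class _Item(object):
+        pass
+
+    a, b = _Item(), _Item()
+    pending = PendingCollection()
+    pending.append(a)   # buffered as a pending addition
+    pending.remove(a)   # cancels the earlier append
+    pending.remove(b)   # buffered as a pending deletion
+    assert a not in pending.added_items
+    assert b in pending.deleted_items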
diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py
new file mode 100644
index 0000000..71aae00
--- /dev/null
+++ b/lib/sqlalchemy/orm/strategies.py
@@ -0,0 +1,3141 @@
+# orm/strategies.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""sqlalchemy.orm.interfaces.LoaderStrategy
+ implementations, and related MapperOptions."""
+from __future__ import absolute_import
+
+import collections
+import itertools
+
+from . import attributes
+from . import exc as orm_exc
+from . import interfaces
+from . import loading
+from . import path_registry
+from . import properties
+from . import query
+from . import relationships
+from . import unitofwork
+from . import util as orm_util
+from .base import _DEFER_FOR_STATE
+from .base import _RAISE_FOR_STATE
+from .base import _SET_DEFERRED_EXPIRED
+from .context import _column_descriptions
+from .context import ORMCompileState
+from .context import ORMSelectCompileState
+from .context import QueryContext
+from .interfaces import LoaderStrategy
+from .interfaces import StrategizedProperty
+from .session import _state_session
+from .state import InstanceState
+from .util import _none_set
+from .util import aliased
+from .. import event
+from .. import exc as sa_exc
+from .. import inspect
+from .. import log
+from .. import sql
+from .. import util
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import Select
+
+
+def _register_attribute(
+ prop,
+ mapper,
+ useobject,
+ compare_function=None,
+ typecallable=None,
+ callable_=None,
+ proxy_property=None,
+ active_history=False,
+ impl_class=None,
+ **kw
+):
+
+ listen_hooks = []
+
+ uselist = useobject and prop.uselist
+
+ if useobject and prop.single_parent:
+ listen_hooks.append(single_parent_validator)
+
+ if prop.key in prop.parent.validators:
+ fn, opts = prop.parent.validators[prop.key]
+ listen_hooks.append(
+ lambda desc, prop: orm_util._validator_events(
+ desc, prop.key, fn, **opts
+ )
+ )
+
+ if useobject:
+ listen_hooks.append(unitofwork.track_cascade_events)
+
+ # need to assemble backref listeners
+ # after the singleparentvalidator, mapper validator
+ if useobject:
+ backref = prop.back_populates
+ if backref and prop._effective_sync_backref:
+ listen_hooks.append(
+ lambda desc, prop: attributes.backref_listeners(
+ desc, backref, uselist
+ )
+ )
+
+ # a single MapperProperty is shared down a class inheritance
+ # hierarchy, so we set up attribute instrumentation and backref event
+ # for each mapper down the hierarchy.
+
+ # typically, "mapper" is the same as prop.parent, due to the way
+ # the configure_mappers() process runs, however this is not strongly
+ # enforced, and in the case of a second configure_mappers() run the
+ # mapper here might not be prop.parent; also, a subclass mapper may
+ # be called here before a superclass mapper. That is, can't depend
+ # on mappers not already being set up so we have to check each one.
+
+ for m in mapper.self_and_descendants:
+ if prop is m._props.get(
+ prop.key
+ ) and not m.class_manager._attr_has_impl(prop.key):
+
+ desc = attributes.register_attribute_impl(
+ m.class_,
+ prop.key,
+ parent_token=prop,
+ uselist=uselist,
+ compare_function=compare_function,
+ useobject=useobject,
+ trackparent=useobject
+ and (
+ prop.single_parent
+ or prop.direction is interfaces.ONETOMANY
+ ),
+ typecallable=typecallable,
+ callable_=callable_,
+ active_history=active_history,
+ impl_class=impl_class,
+ send_modified_events=not useobject or not prop.viewonly,
+ doc=prop.doc,
+ **kw
+ )
+
+ for hook in listen_hooks:
+ hook(desc, prop)
+
+
+@properties.ColumnProperty.strategy_for(instrument=False, deferred=False)
+class UninstrumentedColumnLoader(LoaderStrategy):
+ """Represent a non-instrumented MapperProperty.
+
+ The polymorphic_on argument of mapper() often results in this,
+ if the argument is against the with_polymorphic selectable.
+
+ """
+
+ __slots__ = ("columns",)
+
+ def __init__(self, parent, strategy_key):
+ super(UninstrumentedColumnLoader, self).__init__(parent, strategy_key)
+ self.columns = self.parent_property.columns
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ **kwargs
+ ):
+ for c in self.columns:
+ if adapter:
+ c = adapter.columns[c]
+ compile_state._append_dedupe_col_collection(c, column_collection)
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ pass
+
+
+@log.class_logger
+@properties.ColumnProperty.strategy_for(instrument=True, deferred=False)
+class ColumnLoader(LoaderStrategy):
+ """Provide loading behavior for a :class:`.ColumnProperty`."""
+
+ __slots__ = "columns", "is_composite"
+
+ def __init__(self, parent, strategy_key):
+ super(ColumnLoader, self).__init__(parent, strategy_key)
+ self.columns = self.parent_property.columns
+ self.is_composite = hasattr(self.parent_property, "composite_class")
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ check_for_adapt=False,
+ **kwargs
+ ):
+ for c in self.columns:
+ if adapter:
+ if check_for_adapt:
+ c = adapter.adapt_check_present(c)
+ if c is None:
+ return
+ else:
+ c = adapter.columns[c]
+
+ compile_state._append_dedupe_col_collection(c, column_collection)
+
+ fetch = self.columns[0]
+ if adapter:
+ fetch = adapter.columns[fetch]
+ memoized_populators[self.parent_property] = fetch
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+ coltype = self.columns[0].type
+ # TODO: check all columns ? check for foreign key as well?
+ active_history = (
+ self.parent_property.active_history
+ or self.columns[0].primary_key
+ or (
+ mapper.version_id_col is not None
+ and mapper._columntoproperty.get(mapper.version_id_col, None)
+ is self.parent_property
+ )
+ )
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=False,
+ compare_function=coltype.compare_values,
+ active_history=active_history,
+ )
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ # look through list of columns represented here
+ # to see which, if any, is present in the row.
+ for col in self.columns:
+ if adapter:
+ col = adapter.columns[col]
+ getter = result._getter(col, False)
+ if getter:
+ populators["quick"].append((self.key, getter))
+ break
+ else:
+ populators["expire"].append((self.key, True))
+
+
+@log.class_logger
+@properties.ColumnProperty.strategy_for(query_expression=True)
+class ExpressionColumnLoader(ColumnLoader):
+ def __init__(self, parent, strategy_key):
+ super(ExpressionColumnLoader, self).__init__(parent, strategy_key)
+
+ # compare to the "default" expression that is mapped in
+ # the column. If it's sql.null, we don't need to render
+ # unless an expr is passed in the options.
+ null = sql.null().label(None)
+ self._have_default_expression = any(
+ not c.compare(null) for c in self.parent_property.columns
+ )
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ **kwargs
+ ):
+ columns = None
+ if loadopt and "expression" in loadopt.local_opts:
+ columns = [loadopt.local_opts["expression"]]
+ elif self._have_default_expression:
+ columns = self.parent_property.columns
+
+ if columns is None:
+ return
+
+ for c in columns:
+ if adapter:
+ c = adapter.columns[c]
+ compile_state._append_dedupe_col_collection(c, column_collection)
+
+ fetch = columns[0]
+ if adapter:
+ fetch = adapter.columns[fetch]
+ memoized_populators[self.parent_property] = fetch
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ # look through list of columns represented here
+ # to see which, if any, is present in the row.
+        if loadopt and "expression" in loadopt.local_opts:
+            columns = [loadopt.local_opts["expression"]]
+
+            for col in columns:
+                if adapter:
+                    col = adapter.columns[col]
+                getter = result._getter(col, False)
+                if getter:
+                    populators["quick"].append((self.key, getter))
+                    break
+            else:
+                populators["expire"].append((self.key, True))
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=False,
+ compare_function=self.columns[0].type.compare_values,
+ accepts_scalar_loader=False,
+ )
+
+
+@log.class_logger
+@properties.ColumnProperty.strategy_for(deferred=True, instrument=True)
+@properties.ColumnProperty.strategy_for(
+ deferred=True, instrument=True, raiseload=True
+)
+@properties.ColumnProperty.strategy_for(do_nothing=True)
+class DeferredColumnLoader(LoaderStrategy):
+ """Provide loading behavior for a deferred :class:`.ColumnProperty`."""
+
+ __slots__ = "columns", "group", "raiseload"
+
+ def __init__(self, parent, strategy_key):
+ super(DeferredColumnLoader, self).__init__(parent, strategy_key)
+ if hasattr(self.parent_property, "composite_class"):
+ raise NotImplementedError(
+ "Deferred loading for composite " "types not implemented yet"
+ )
+ self.raiseload = self.strategy_opts.get("raiseload", False)
+ self.columns = self.parent_property.columns
+ self.group = self.parent_property.group
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+
+ # for a DeferredColumnLoader, this method is only used during a
+ # "row processor only" query; see test_deferred.py ->
+ # tests with "rowproc_only" in their name. As of the 1.0 series,
+ # loading._instance_processor doesn't use a "row processing" function
+ # to populate columns, instead it uses data in the "populators"
+ # dictionary. Normally, the DeferredColumnLoader.setup_query()
+ # sets up that data in the "memoized_populators" dictionary
+ # and "create_row_processor()" here is never invoked.
+
+ if (
+ context.refresh_state
+ and context.query._compile_options._only_load_props
+ and self.key in context.query._compile_options._only_load_props
+ ):
+ self.parent_property._get_strategy(
+ (("deferred", False), ("instrument", True))
+ ).create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+ elif not self.is_class_level:
+ if self.raiseload:
+ set_deferred_for_local_state = (
+ self.parent_property._raise_column_loader
+ )
+ else:
+ set_deferred_for_local_state = (
+ self.parent_property._deferred_column_loader
+ )
+ populators["new"].append((self.key, set_deferred_for_local_state))
+ else:
+ populators["expire"].append((self.key, False))
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=False,
+ compare_function=self.columns[0].type.compare_values,
+ callable_=self._load_for_state,
+ load_on_unexpire=False,
+ )
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ only_load_props=None,
+ **kw
+ ):
+
+ if (
+ (
+ compile_state.compile_options._render_for_subquery
+ and self.parent_property._renders_in_subqueries
+ )
+ or (
+ loadopt
+ and "undefer_pks" in loadopt.local_opts
+ and set(self.columns).intersection(
+ self.parent._should_undefer_in_wildcard
+ )
+ )
+ or (
+ loadopt
+ and self.group
+ and loadopt.local_opts.get(
+ "undefer_group_%s" % self.group, False
+ )
+ )
+ or (only_load_props and self.key in only_load_props)
+ ):
+ self.parent_property._get_strategy(
+ (("deferred", False), ("instrument", True))
+ ).setup_query(
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ memoized_populators,
+ **kw
+ )
+ elif self.is_class_level:
+ memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED
+ elif not self.raiseload:
+ memoized_populators[self.parent_property] = _DEFER_FOR_STATE
+ else:
+ memoized_populators[self.parent_property] = _RAISE_FOR_STATE
+
+ def _load_for_state(self, state, passive):
+ if not state.key:
+ return attributes.ATTR_EMPTY
+
+ if not passive & attributes.SQL_OK:
+ return attributes.PASSIVE_NO_RESULT
+
+ localparent = state.manager.mapper
+
+ if self.group:
+ toload = [
+ p.key
+ for p in localparent.iterate_properties
+ if isinstance(p, StrategizedProperty)
+ and isinstance(p.strategy, DeferredColumnLoader)
+ and p.group == self.group
+ ]
+ else:
+ toload = [self.key]
+
+ # narrow the keys down to just those which have no history
+ group = [k for k in toload if k in state.unmodified]
+
+ session = _state_session(state)
+ if session is None:
+ raise orm_exc.DetachedInstanceError(
+ "Parent instance %s is not bound to a Session; "
+ "deferred load operation of attribute '%s' cannot proceed"
+ % (orm_util.state_str(state), self.key)
+ )
+
+ if self.raiseload:
+ self._invoke_raise_load(state, passive, "raise")
+
+ if (
+ loading.load_on_ident(
+ session,
+ sql.select(localparent).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ ),
+ state.key,
+ only_load_props=group,
+ refresh_state=state,
+ )
+ is None
+ ):
+ raise orm_exc.ObjectDeletedError(state)
+
+ return attributes.ATTR_WAS_SET
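+
+    # Illustrative example (not part of the SQLAlchemy source): the "group"
+    # handling above corresponds to mapping columns with
+    # ``deferred(..., group="...")``.  Assuming a mapped ``Photo`` class::
+    #
+    #     data = deferred(Column(LargeBinary), group="media")
+    #     thumbnail = deferred(Column(LargeBinary), group="media")
+    #
+    # touching either attribute loads every unmodified column in the
+    # "media" group with a single SELECT, and
+    # ``options(undefer_group("media"))`` loads them up front instead.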
+
+ def _invoke_raise_load(self, state, passive, lazy):
+ raise sa_exc.InvalidRequestError(
+ "'%s' is not available due to raiseload=True" % (self,)
+ )
+
+
+class LoadDeferredColumns(object):
+ """serializable loader object used by DeferredColumnLoader"""
+
+ def __init__(self, key, raiseload=False):
+ self.key = key
+ self.raiseload = raiseload
+
+ def __call__(self, state, passive=attributes.PASSIVE_OFF):
+ key = self.key
+
+ localparent = state.manager.mapper
+ prop = localparent._props[key]
+ if self.raiseload:
+ strategy_key = (
+ ("deferred", True),
+ ("instrument", True),
+ ("raiseload", True),
+ )
+ else:
+ strategy_key = (("deferred", True), ("instrument", True))
+ strategy = prop._get_strategy(strategy_key)
+ return strategy._load_for_state(state, passive)
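+
+
+# Illustrative example (not part of the SQLAlchemy source): LoadDeferredColumns
+# is what ends up stored per-instance when a query applies a ``defer()``
+# option.  Assuming a mapped ``Document`` class with a large ``body`` column::
+#
+#     from sqlalchemy.orm import defer
+#
+#     doc = session.query(Document).options(defer(Document.body)).first()
+#     doc.body  # attribute access invokes LoadDeferredColumns, which looks up
+#               # the ("deferred", True), ("instrument", True) strategy and
+#               # emits a SELECT for the deferred column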
+
+
+class AbstractRelationshipLoader(LoaderStrategy):
+ """LoaderStratgies which deal with related objects."""
+
+ __slots__ = "mapper", "target", "uselist", "entity"
+
+ def __init__(self, parent, strategy_key):
+ super(AbstractRelationshipLoader, self).__init__(parent, strategy_key)
+ self.mapper = self.parent_property.mapper
+ self.entity = self.parent_property.entity
+ self.target = self.parent_property.target
+ self.uselist = self.parent_property.uselist
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(do_nothing=True)
+class DoNothingLoader(LoaderStrategy):
+ """Relationship loader that makes no change to the object's state.
+
+ Compared to NoLoader, this loader does not initialize the
+ collection/attribute to empty/none; the usual default LazyLoader will
+ take effect.
+
+ """
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="noload")
+@relationships.RelationshipProperty.strategy_for(lazy=None)
+class NoLoader(AbstractRelationshipLoader):
+ """Provide loading behavior for a :class:`.RelationshipProperty`
+ with "lazy=None".
+
+ """
+
+ __slots__ = ()
+
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=True,
+ typecallable=self.parent_property.collection_class,
+ )
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ def invoke_no_load(state, dict_, row):
+ if self.uselist:
+ attributes.init_state_collection(state, dict_, self.key)
+ else:
+ dict_[self.key] = None
+
+ populators["new"].append((self.key, invoke_no_load))
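+
+
+# Illustrative example (not part of the SQLAlchemy source): NoLoader is
+# selected when a relationship is configured with ``lazy="noload"`` (or
+# ``lazy=None``).  Assuming a mapped ``Parent`` class::
+#
+#     children = relationship("Child", lazy="noload")
+#
+# loading a Parent populates ``parent.children`` as an empty collection
+# (or ``None`` for a scalar reference, per ``invoke_no_load`` above) and
+# never emits SQL for the relationship.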
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy=True)
+@relationships.RelationshipProperty.strategy_for(lazy="select")
+@relationships.RelationshipProperty.strategy_for(lazy="raise")
+@relationships.RelationshipProperty.strategy_for(lazy="raise_on_sql")
+@relationships.RelationshipProperty.strategy_for(lazy="baked_select")
+class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
+ """Provide loading behavior for a :class:`.RelationshipProperty`
+ with "lazy=True", that is loads when first accessed.
+
+ """
+
+ __slots__ = (
+ "_lazywhere",
+ "_rev_lazywhere",
+ "_lazyload_reverse_option",
+ "_order_by",
+ "use_get",
+ "is_aliased_class",
+ "_bind_to_col",
+ "_equated_columns",
+ "_rev_bind_to_col",
+ "_rev_equated_columns",
+ "_simple_lazy_clause",
+ "_raise_always",
+ "_raise_on_sql",
+ )
+
+ def __init__(self, parent, strategy_key):
+ super(LazyLoader, self).__init__(parent, strategy_key)
+ self._raise_always = self.strategy_opts["lazy"] == "raise"
+ self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql"
+
+ self.is_aliased_class = inspect(self.entity).is_aliased_class
+
+ join_condition = self.parent_property._join_condition
+ (
+ self._lazywhere,
+ self._bind_to_col,
+ self._equated_columns,
+ ) = join_condition.create_lazy_clause()
+
+ (
+ self._rev_lazywhere,
+ self._rev_bind_to_col,
+ self._rev_equated_columns,
+ ) = join_condition.create_lazy_clause(reverse_direction=True)
+
+ if self.parent_property.order_by:
+ self._order_by = [
+ sql_util._deep_annotate(elem, {"_orm_adapt": True})
+ for elem in util.to_list(self.parent_property.order_by)
+ ]
+ else:
+ self._order_by = None
+
+ self.logger.info("%s lazy loading clause %s", self, self._lazywhere)
+
+ # determine if our "lazywhere" clause is the same as the mapper's
+ # get() clause. then we can just use mapper.get()
+ #
+ # TODO: the "not self.uselist" can be taken out entirely; a m2o
+ # load that populates for a list (very unusual, but is possible with
+ # the API) can still set for "None" and the attribute system will
+ # populate as an empty list.
+ self.use_get = (
+ not self.is_aliased_class
+ and not self.uselist
+ and self.entity._get_clause[0].compare(
+ self._lazywhere,
+ use_proxies=True,
+ compare_keys=False,
+ equivalents=self.mapper._equivalent_columns,
+ )
+ )
+
+ if self.use_get:
+ for col in list(self._equated_columns):
+ if col in self.mapper._equivalent_columns:
+ for c in self.mapper._equivalent_columns[col]:
+ self._equated_columns[c] = self._equated_columns[col]
+
+ self.logger.info(
+ "%s will use Session.get() to " "optimize instance loads", self
+ )
+
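+    # Illustrative example (not part of the SQLAlchemy source): the
+    # ``use_get`` optimization above applies to a simple many-to-one whose
+    # lazy clause matches the related mapper's primary-key "get" clause.
+    # Assuming ``Order.customer`` is such a many-to-one::
+    #
+    #     order = session.query(Order).first()
+    #     order.customer  # consults the Session identity map first and only
+    #                     # falls back to a SELECT when the Customer identity
+    #                     # is not already present
+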
+ def init_class_attribute(self, mapper):
+ self.is_class_level = True
+
+ _legacy_inactive_history_style = (
+ self.parent_property._legacy_inactive_history_style
+ )
+
+ if self.parent_property.active_history:
+ active_history = True
+ _deferred_history = False
+
+ elif (
+ self.parent_property.direction is not interfaces.MANYTOONE
+ or not self.use_get
+ ):
+ if _legacy_inactive_history_style:
+ active_history = True
+ _deferred_history = False
+ else:
+ active_history = False
+ _deferred_history = True
+ else:
+ active_history = _deferred_history = False
+
+ _register_attribute(
+ self.parent_property,
+ mapper,
+ useobject=True,
+ callable_=self._load_for_state,
+ typecallable=self.parent_property.collection_class,
+ active_history=active_history,
+ _deferred_history=_deferred_history,
+ )
+
+ def _memoized_attr__simple_lazy_clause(self):
+
+ lazywhere = sql_util._deep_annotate(
+ self._lazywhere, {"_orm_adapt": True}
+ )
+
+ criterion, bind_to_col = (lazywhere, self._bind_to_col)
+
+ params = []
+
+ def visit_bindparam(bindparam):
+ bindparam.unique = False
+
+ visitors.traverse(criterion, {}, {"bindparam": visit_bindparam})
+
+ def visit_bindparam(bindparam):
+ if bindparam._identifying_key in bind_to_col:
+ params.append(
+ (
+ bindparam.key,
+ bind_to_col[bindparam._identifying_key],
+ None,
+ )
+ )
+ elif bindparam.callable is None:
+ params.append((bindparam.key, None, bindparam.value))
+
+ criterion = visitors.cloned_traverse(
+ criterion, {}, {"bindparam": visit_bindparam}
+ )
+
+ return criterion, params
+
+ def _generate_lazy_clause(self, state, passive):
+ criterion, param_keys = self._simple_lazy_clause
+
+ if state is None:
+ return sql_util.adapt_criterion_to_null(
+ criterion, [key for key, ident, value in param_keys]
+ )
+
+ mapper = self.parent_property.parent
+
+ o = state.obj() # strong ref
+ dict_ = attributes.instance_dict(o)
+
+ if passive & attributes.INIT_OK:
+ passive ^= attributes.INIT_OK
+
+ params = {}
+ for key, ident, value in param_keys:
+ if ident is not None:
+ if passive and passive & attributes.LOAD_AGAINST_COMMITTED:
+ value = mapper._get_committed_state_attr_by_column(
+ state, dict_, ident, passive
+ )
+ else:
+ value = mapper._get_state_attr_by_column(
+ state, dict_, ident, passive
+ )
+
+ params[key] = value
+
+ return criterion, params
+
+ def _invoke_raise_load(self, state, passive, lazy):
+ raise sa_exc.InvalidRequestError(
+ "'%s' is not available due to lazy='%s'" % (self, lazy)
+ )
+
+ def _load_for_state(self, state, passive, loadopt=None, extra_criteria=()):
+ if not state.key and (
+ (
+ not self.parent_property.load_on_pending
+ and not state._load_pending
+ )
+ or not state.session_id
+ ):
+ return attributes.ATTR_EMPTY
+
+ pending = not state.key
+ primary_key_identity = None
+
+ use_get = self.use_get and (not loadopt or not loadopt._extra_criteria)
+
+ if (not passive & attributes.SQL_OK and not use_get) or (
+ not passive & attributes.NON_PERSISTENT_OK and pending
+ ):
+ return attributes.PASSIVE_NO_RESULT
+
+ if (
+ # we were given lazy="raise"
+ self._raise_always
+ # the no_raise history-related flag was not passed
+ and not passive & attributes.NO_RAISE
+ and (
+ # if we are use_get and related_object_ok is disabled,
+ # which means we are at most looking in the identity map
+ # for history purposes or otherwise returning
+ # PASSIVE_NO_RESULT, don't raise. This is also a
+ # history-related flag
+ not use_get
+ or passive & attributes.RELATED_OBJECT_OK
+ )
+ ):
+
+ self._invoke_raise_load(state, passive, "raise")
+
+ session = _state_session(state)
+ if not session:
+ if passive & attributes.NO_RAISE:
+ return attributes.PASSIVE_NO_RESULT
+
+ raise orm_exc.DetachedInstanceError(
+ "Parent instance %s is not bound to a Session; "
+ "lazy load operation of attribute '%s' cannot proceed"
+ % (orm_util.state_str(state), self.key)
+ )
+
+ # if we have a simple primary key load, check the
+ # identity map without generating a Query at all
+ if use_get:
+ primary_key_identity = self._get_ident_for_use_get(
+ session, state, passive
+ )
+ if attributes.PASSIVE_NO_RESULT in primary_key_identity:
+ return attributes.PASSIVE_NO_RESULT
+ elif attributes.NEVER_SET in primary_key_identity:
+ return attributes.NEVER_SET
+
+ if _none_set.issuperset(primary_key_identity):
+ return None
+
+ if (
+ self.key in state.dict
+ and not passive & attributes.DEFERRED_HISTORY_LOAD
+ ):
+ return attributes.ATTR_WAS_SET
+
+ # look for this identity in the identity map. Delegate to the
+ # Query class in use, as it may have special rules for how it
+ # does this, including how it decides what the correct
+ # identity_token would be for this identity.
+
+ instance = session._identity_lookup(
+ self.entity,
+ primary_key_identity,
+ passive=passive,
+ lazy_loaded_from=state,
+ )
+
+ if instance is not None:
+ if instance is attributes.PASSIVE_CLASS_MISMATCH:
+ return None
+ else:
+ return instance
+ elif (
+ not passive & attributes.SQL_OK
+ or not passive & attributes.RELATED_OBJECT_OK
+ ):
+ return attributes.PASSIVE_NO_RESULT
+
+ return self._emit_lazyload(
+ session,
+ state,
+ primary_key_identity,
+ passive,
+ loadopt,
+ extra_criteria,
+ )
+
+ def _get_ident_for_use_get(self, session, state, passive):
+ instance_mapper = state.manager.mapper
+
+ if passive & attributes.LOAD_AGAINST_COMMITTED:
+ get_attr = instance_mapper._get_committed_state_attr_by_column
+ else:
+ get_attr = instance_mapper._get_state_attr_by_column
+
+ dict_ = state.dict
+
+ return [
+ get_attr(state, dict_, self._equated_columns[pk], passive=passive)
+ for pk in self.mapper.primary_key
+ ]
+
+ @util.preload_module("sqlalchemy.orm.strategy_options")
+ def _emit_lazyload(
+ self,
+ session,
+ state,
+ primary_key_identity,
+ passive,
+ loadopt,
+ extra_criteria,
+ ):
+ strategy_options = util.preloaded.orm_strategy_options
+
+ clauseelement = self.entity.__clause_element__()
+ stmt = Select._create_raw_select(
+ _raw_columns=[clauseelement],
+ _propagate_attrs=clauseelement._propagate_attrs,
+ _label_style=LABEL_STYLE_TABLENAME_PLUS_COL,
+ _compile_options=ORMCompileState.default_compile_options,
+ )
+ load_options = QueryContext.default_load_options
+
+ load_options += {
+ "_invoke_all_eagers": False,
+ "_lazy_loaded_from": state,
+ }
+
+ if self.parent_property.secondary is not None:
+ stmt = stmt.select_from(
+ self.mapper, self.parent_property.secondary
+ )
+
+ pending = not state.key
+
+ # don't autoflush on pending
+ if pending or passive & attributes.NO_AUTOFLUSH:
+ stmt._execution_options = util.immutabledict({"autoflush": False})
+
+ use_get = self.use_get
+
+ if state.load_options or (loadopt and loadopt._extra_criteria):
+ effective_path = state.load_path[self.parent_property]
+
+ opts = tuple(state.load_options)
+
+ if loadopt and loadopt._extra_criteria:
+ use_get = False
+ opts += (
+ orm_util.LoaderCriteriaOption(self.entity, extra_criteria),
+ )
+
+ stmt._with_options = opts
+ else:
+ # this path is used if there are not already any options
+ # in the query, but an event may want to add them
+ effective_path = state.mapper._path_registry[self.parent_property]
+
+ stmt._compile_options += {"_current_path": effective_path}
+
+ if use_get:
+ if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ self._invoke_raise_load(state, passive, "raise_on_sql")
+
+ return loading.load_on_pk_identity(
+ session, stmt, primary_key_identity, load_options=load_options
+ )
+
+ if self._order_by:
+ stmt._order_by_clauses = self._order_by
+
+ def _lazyload_reverse(compile_context):
+ for rev in self.parent_property._reverse_property:
+ # reverse props that are MANYTOONE are loading *this*
+ # object from get(), so don't need to eager out to those.
+ if (
+ rev.direction is interfaces.MANYTOONE
+ and rev._use_get
+ and not isinstance(rev.strategy, LazyLoader)
+ ):
+ strategy_options.Load.for_existing_path(
+ compile_context.compile_options._current_path[
+ rev.parent
+ ]
+ ).lazyload(rev).process_compile_state(compile_context)
+
+ stmt._with_context_options += (
+ (_lazyload_reverse, self.parent_property),
+ )
+
+ lazy_clause, params = self._generate_lazy_clause(state, passive)
+
+ execution_options = {
+ "_sa_orm_load_options": load_options,
+ }
+
+ if (
+ self.key in state.dict
+ and not passive & attributes.DEFERRED_HISTORY_LOAD
+ ):
+ return attributes.ATTR_WAS_SET
+
+ if pending:
+ if util.has_intersection(orm_util._none_set, params.values()):
+ return None
+
+ elif util.has_intersection(orm_util._never_set, params.values()):
+ return None
+
+ if self._raise_on_sql and not passive & attributes.NO_RAISE:
+ self._invoke_raise_load(state, passive, "raise_on_sql")
+
+ stmt._where_criteria = (lazy_clause,)
+
+ result = session.execute(
+ stmt, params, execution_options=execution_options
+ )
+
+ result = result.unique().scalars().all()
+
+ if self.uselist:
+ return result
+ else:
+ l = len(result)
+ if l:
+ if l > 1:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for lazily-loaded attribute '%s' "
+ % self.parent_property
+ )
+
+ return result[0]
+ else:
+ return None
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ key = self.key
+
+ if not self.is_class_level or (loadopt and loadopt._extra_criteria):
+ # we are not the primary manager for this attribute
+ # on this class - set up a
+ # per-instance lazyloader, which will override the
+ # class-level behavior.
+ # this currently only happens when using a
+ # "lazyload" option on a "no load"
+ # attribute - "eager" attributes always have a
+ # class-level lazyloader installed.
+ set_lazy_callable = (
+ InstanceState._instance_level_callable_processor
+ )(
+ mapper.class_manager,
+ LoadLazyAttribute(
+ key,
+ self,
+ loadopt,
+ loadopt._generate_extra_criteria(context)
+ if loadopt._extra_criteria
+ else None,
+ ),
+ key,
+ )
+
+ populators["new"].append((self.key, set_lazy_callable))
+ elif context.populate_existing or mapper.always_refresh:
+
+ def reset_for_lazy_callable(state, dict_, row):
+ # we are the primary manager for this attribute on
+ # this class - reset its
+ # per-instance attribute state, so that the class-level
+ # lazy loader is
+ # executed when next referenced on this instance.
+ # this is needed in
+ # populate_existing() types of scenarios to reset
+ # any existing state.
+ state._reset(dict_, key)
+
+ populators["new"].append((self.key, reset_for_lazy_callable))
+
+
+class LoadLazyAttribute(object):
+ """semi-serializable loader object used by LazyLoader
+
+ Historically, this object would be carried along with instances that
+ needed to run lazyloaders, so it had to be serializable to support
+ cached instances.
+
+    this is no longer a general requirement; the one case where this object
+    is still used is exactly the case we can't easily serialize, namely
+    when extra criteria are present in the loader option.
+
+    We can't reliably serialize that, as it refers to mapped entities and
+    AliasedClass objects that are local to the current process; these would
+    need to be matched up on deserialization, e.g. via the
+    sqlalchemy.ext.serializer approach.
+
+ """
+
+ def __init__(self, key, initiating_strategy, loadopt, extra_criteria):
+ self.key = key
+ self.strategy_key = initiating_strategy.strategy_key
+ self.loadopt = loadopt
+ self.extra_criteria = extra_criteria
+
+ def __getstate__(self):
+ if self.extra_criteria is not None:
+ util.warn(
+ "Can't reliably serialize a lazyload() option that "
+ "contains additional criteria; please use eager loading "
+ "for this case"
+ )
+ return {
+ "key": self.key,
+ "strategy_key": self.strategy_key,
+ "loadopt": self.loadopt,
+ "extra_criteria": (),
+ }
+
+ def __call__(self, state, passive=attributes.PASSIVE_OFF):
+ key = self.key
+ instance_mapper = state.manager.mapper
+ prop = instance_mapper._props[key]
+ strategy = prop._strategies[self.strategy_key]
+
+ return strategy._load_for_state(
+ state,
+ passive,
+ loadopt=self.loadopt,
+ extra_criteria=self.extra_criteria,
+ )
+
+
+class PostLoader(AbstractRelationshipLoader):
+ """A relationship loader that emits a second SELECT statement."""
+
+ def _check_recursive_postload(self, context, path, join_depth=None):
+ effective_path = (
+ context.compile_state.current_path or orm_util.PathRegistry.root
+ ) + path
+
+ if loading.PostLoad.path_exists(
+ context, effective_path, self.parent_property
+ ):
+ return True
+
+ path_w_prop = path[self.parent_property]
+ effective_path_w_prop = effective_path[self.parent_property]
+
+ if not path_w_prop.contains(context.attributes, "loader"):
+ if join_depth:
+ if effective_path_w_prop.length / 2 > join_depth:
+ return True
+ elif effective_path_w_prop.contains_mapper(self.mapper):
+ return True
+
+ return False
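+
+    # Illustrative example (not part of the SQLAlchemy source): the
+    # ``join_depth`` checked above comes from the relationship()
+    # configuration and bounds how deep a self-referential eager load may
+    # recurse.  Assuming a self-referential ``Node`` mapping::
+    #
+    #     children = relationship("Node", lazy="selectin", join_depth=2)
+    #
+    # eager loading stops after two levels of Node.children; deeper levels
+    # fall back to the default lazy loader.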
+
+ def _immediateload_create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ return self.parent_property._get_strategy(
+ (("lazy", "immediate"),)
+ ).create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+
+@relationships.RelationshipProperty.strategy_for(lazy="immediate")
+class ImmediateLoader(PostLoader):
+ __slots__ = ()
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ def load_immediate(state, dict_, row):
+ state.get_impl(self.key).get(state, dict_, flags)
+
+ if self._check_recursive_postload(context, path):
+            # this will not emit SQL and will only emit for a many-to-one
+            # "use get" load.  the "_RELATED" part means it may return the
+            # instance even if it's expired, since this is a
+            # mutually-recursive load operation.
+ flags = attributes.PASSIVE_NO_FETCH_RELATED | attributes.NO_RAISE
+ else:
+ flags = attributes.PASSIVE_OFF | attributes.NO_RAISE
+
+ populators["delayed"].append((self.key, load_immediate))
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="subquery")
+class SubqueryLoader(PostLoader):
+ __slots__ = ("join_depth",)
+
+ def __init__(self, parent, strategy_key):
+ super(SubqueryLoader, self).__init__(parent, strategy_key)
+ self.join_depth = self.parent_property.join_depth
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def _get_leftmost(
+ self,
+ orig_query_entity_index,
+ subq_path,
+ current_compile_state,
+ is_root,
+ ):
+ given_subq_path = subq_path
+ subq_path = subq_path.path
+ subq_mapper = orm_util._class_to_mapper(subq_path[0])
+
+ # determine attributes of the leftmost mapper
+ if (
+ self.parent.isa(subq_mapper)
+ and self.parent_property is subq_path[1]
+ ):
+ leftmost_mapper, leftmost_prop = self.parent, self.parent_property
+ else:
+ leftmost_mapper, leftmost_prop = subq_mapper, subq_path[1]
+
+ if is_root:
+ # the subq_path is also coming from cached state, so when we start
+ # building up this path, it has to also be converted to be in terms
+ # of the current state. this is for the specific case of the entity
+ # is an AliasedClass against a subquery that's not otherwise going
+ # to adapt
+ new_subq_path = current_compile_state._entities[
+ orig_query_entity_index
+ ].entity_zero._path_registry[leftmost_prop]
+ additional = len(subq_path) - len(new_subq_path)
+ if additional:
+ new_subq_path += path_registry.PathRegistry.coerce(
+ subq_path[-additional:]
+ )
+ else:
+ new_subq_path = given_subq_path
+
+ leftmost_cols = leftmost_prop.local_columns
+
+ leftmost_attr = [
+ getattr(
+ new_subq_path.path[0].entity,
+ leftmost_mapper._columntoproperty[c].key,
+ )
+ for c in leftmost_cols
+ ]
+
+ return leftmost_mapper, leftmost_attr, leftmost_prop, new_subq_path
+
+ def _generate_from_original_query(
+ self,
+ orig_compile_state,
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ orig_entity,
+ ):
+ # reformat the original query
+ # to look only for significant columns
+ q = orig_query._clone().correlate(None)
+
+ # LEGACY: make a Query back from the select() !!
+ # This suits at least two legacy cases:
+ # 1. applications which expect before_compile() to be called
+ # below when we run .subquery() on this query (Keystone)
+ # 2. applications which are doing subqueryload with complex
+ # from_self() queries, as query.subquery() / .statement
+ # has to do the full compile context for multiply-nested
+ # from_self() (Neutron) - see test_subqload_from_self
+ # for demo.
+ q2 = query.Query.__new__(query.Query)
+ q2.__dict__.update(q.__dict__)
+ q = q2
+
+ # set the query's "FROM" list explicitly to what the
+ # FROM list would be in any case, as we will be limiting
+ # the columns in the SELECT list which may no longer include
+ # all entities mentioned in things like WHERE, JOIN, etc.
+ if not q._from_obj:
+ q._enable_assertions = False
+ q.select_from.non_generative(
+ q,
+ *{
+ ent["entity"]
+ for ent in _column_descriptions(
+ orig_query, compile_state=orig_compile_state
+ )
+ if ent["entity"] is not None
+ }
+ )
+
+ # select from the identity columns of the outer (specifically, these
+ # are the 'local_cols' of the property). This will remove other
+ # columns from the query that might suggest the right entity which is
+ # why we do set select_from above. The attributes we have are
+ # coerced and adapted using the original query's adapter, which is
+ # needed only for the case of adapting a subclass column to
+ # that of a polymorphic selectable, e.g. we have
+ # Engineer.primary_language and the entity is Person. All other
+ # adaptations, e.g. from_self, select_entity_from(), will occur
+ # within the new query when it compiles, as the compile_state we are
+ # using here is only a partial one. If the subqueryload is from a
+ # with_polymorphic() or other aliased() object, left_attr will already
+ # be the correct attributes so no adaptation is needed.
+ target_cols = orig_compile_state._adapt_col_list(
+ [
+ sql.coercions.expect(sql.roles.ColumnsClauseRole, o)
+ for o in leftmost_attr
+ ],
+ orig_compile_state._get_current_adapter(),
+ )
+ q._raw_columns = target_cols
+
+ distinct_target_key = leftmost_relationship.distinct_target_key
+
+ if distinct_target_key is True:
+ q._distinct = True
+ elif distinct_target_key is None:
+ # if target_cols refer to a non-primary key or only
+ # part of a composite primary key, set the q as distinct
+ for t in set(c.table for c in target_cols):
+ if not set(target_cols).issuperset(t.primary_key):
+ q._distinct = True
+ break
+
+ # don't need ORDER BY if no limit/offset
+ if not q._has_row_limiting_clause:
+ q._order_by_clauses = ()
+
+ if q._distinct is True and q._order_by_clauses:
+ # the logic to automatically add the order by columns to the query
+ # when distinct is True is deprecated in the query
+ to_add = sql_util.expand_column_list_from_order_by(
+ target_cols, q._order_by_clauses
+ )
+ if to_add:
+ q._set_entities(target_cols + to_add)
+
+ # the original query now becomes a subquery
+ # which we'll join onto.
+ # LEGACY: as "q" is a Query, the before_compile() event is invoked
+ # here.
+ embed_q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).subquery()
+ left_alias = orm_util.AliasedClass(
+ leftmost_mapper, embed_q, use_mapper_path=True
+ )
+ return left_alias
+
+ def _prep_for_joins(self, left_alias, subq_path):
+ # figure out what's being joined. a.k.a. the fun part
+ to_join = []
+ pairs = list(subq_path.pairs())
+
+ for i, (mapper, prop) in enumerate(pairs):
+ if i > 0:
+ # look at the previous mapper in the chain -
+ # if it is as or more specific than this prop's
+ # mapper, use that instead.
+ # note we have an assumption here that
+ # the non-first element is always going to be a mapper,
+ # not an AliasedClass
+
+ prev_mapper = pairs[i - 1][1].mapper
+ to_append = prev_mapper if prev_mapper.isa(mapper) else mapper
+ else:
+ to_append = mapper
+
+ to_join.append((to_append, prop.key))
+
+ # determine the immediate parent class we are joining from,
+ # which needs to be aliased.
+
+ if len(to_join) < 2:
+ # in the case of a one level eager load, this is the
+ # leftmost "left_alias".
+ parent_alias = left_alias
+ else:
+ info = inspect(to_join[-1][0])
+ if info.is_aliased_class:
+ parent_alias = info.entity
+ else:
+ # alias a plain mapper as we may be
+ # joining multiple times
+ parent_alias = orm_util.AliasedClass(
+ info.entity, use_mapper_path=True
+ )
+
+ local_cols = self.parent_property.local_columns
+
+ local_attr = [
+ getattr(parent_alias, self.parent._columntoproperty[c].key)
+ for c in local_cols
+ ]
+ return to_join, local_attr, parent_alias
+
+ def _apply_joins(
+ self, q, to_join, left_alias, parent_alias, effective_entity
+ ):
+
+ ltj = len(to_join)
+ if ltj == 1:
+ to_join = [
+ getattr(left_alias, to_join[0][1]).of_type(effective_entity)
+ ]
+ elif ltj == 2:
+ to_join = [
+ getattr(left_alias, to_join[0][1]).of_type(parent_alias),
+ getattr(parent_alias, to_join[-1][1]).of_type(
+ effective_entity
+ ),
+ ]
+ elif ltj > 2:
+ middle = [
+ (
+ orm_util.AliasedClass(item[0])
+ if not inspect(item[0]).is_aliased_class
+ else item[0].entity,
+ item[1],
+ )
+ for item in to_join[1:-1]
+ ]
+ inner = []
+
+ while middle:
+ item = middle.pop(0)
+ attr = getattr(item[0], item[1])
+ if middle:
+ attr = attr.of_type(middle[0][0])
+ else:
+ attr = attr.of_type(parent_alias)
+
+ inner.append(attr)
+
+ to_join = (
+ [getattr(left_alias, to_join[0][1]).of_type(inner[0].parent)]
+ + inner
+ + [
+ getattr(parent_alias, to_join[-1][1]).of_type(
+ effective_entity
+ )
+ ]
+ )
+
+ for attr in to_join:
+ q = q.join(attr)
+
+ return q
+
+ def _setup_options(
+ self,
+ context,
+ q,
+ subq_path,
+ rewritten_path,
+ orig_query,
+ effective_entity,
+ loadopt,
+ ):
+
+        # note that because the subqueryload object does not re-use the
+        # cached query, but instead always makes use of the currently
+        # invoked query, the two queries we have here (orig and
+        # context.query) are both non-cached queries, and we can transfer
+        # the options as-is without adjusting for new criteria.  Some work
+        # on #6881 / #6889 brought this into question.
+ new_options = orig_query._with_options
+
+ if loadopt and loadopt._extra_criteria:
+
+ new_options += (
+ orm_util.LoaderCriteriaOption(
+ self.entity,
+ loadopt._generate_extra_criteria(context),
+ ),
+ )
+
+ # propagate loader options etc. to the new query.
+ # these will fire relative to subq_path.
+ q = q._with_current_path(rewritten_path)
+ q = q.options(*new_options)
+
+ return q
+
+ def _setup_outermost_orderby(self, q):
+ if self.parent_property.order_by:
+
+ def _setup_outermost_orderby(compile_context):
+ compile_context.eager_order_by += tuple(
+ util.to_list(self.parent_property.order_by)
+ )
+
+ q = q._add_context_option(
+ _setup_outermost_orderby, self.parent_property
+ )
+
+ return q
+
+ class _SubqCollections(object):
+ """Given a :class:`_query.Query` used to emit the "subquery load",
+ provide a load interface that executes the query at the
+ first moment a value is needed.
+
+ """
+
+ __slots__ = (
+ "session",
+ "execution_options",
+ "load_options",
+ "params",
+ "subq",
+ "_data",
+ )
+
+ def __init__(self, context, subq):
+            # avoid creating a cycle by not storing the context itself,
+            # even though having it would be preferable
+ self.session = context.session
+ self.execution_options = context.execution_options
+ self.load_options = context.load_options
+ self.params = context.params or {}
+ self.subq = subq
+ self._data = None
+
+ def get(self, key, default):
+ if self._data is None:
+ self._load()
+ return self._data.get(key, default)
+
+ def _load(self):
+ self._data = collections.defaultdict(list)
+
+ q = self.subq
+ assert q.session is None
+
+ q = q.with_session(self.session)
+
+ if self.load_options._populate_existing:
+ q = q.populate_existing()
+ # to work with baked query, the parameters may have been
+ # updated since this query was created, so take these into account
+
+ rows = list(q.params(self.params))
+ for k, v in itertools.groupby(rows, lambda x: x[1:]):
+ self._data[k].extend(vv[0] for vv in v)
+
+ def loader(self, state, dict_, row):
+ if self._data is None:
+ self._load()
+
+ def _setup_query_from_rowproc(
+ self,
+ context,
+ query_entity,
+ path,
+ entity,
+ loadopt,
+ adapter,
+ ):
+ compile_state = context.compile_state
+ if (
+ not compile_state.compile_options._enable_eagerloads
+ or compile_state.compile_options._for_refresh_state
+ ):
+ return
+
+ orig_query_entity_index = compile_state._entities.index(query_entity)
+ context.loaders_require_buffering = True
+
+ path = path[self.parent_property]
+
+ # build up a path indicating the path from the leftmost
+ # entity to the thing we're subquery loading.
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ effective_entity = with_poly_entity
+ else:
+ effective_entity = self.entity
+
+ subq_path, rewritten_path = context.query._execution_options.get(
+ ("subquery_paths", None),
+ (orm_util.PathRegistry.root, orm_util.PathRegistry.root),
+ )
+ is_root = subq_path is orm_util.PathRegistry.root
+ subq_path = subq_path + path
+ rewritten_path = rewritten_path + path
+
+ # if not via query option, check for
+ # a cycle
+ # TODO: why is this here??? this is now handled
+ # by the _check_recursive_postload call
+ if not path.contains(compile_state.attributes, "loader"):
+ if self.join_depth:
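+                # the path includes both mapper and relationship tokens,
+                # so divide its length by two to compare against the
+                # configured join_depth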
+ if (
+ (
+ compile_state.current_path.length
+ if compile_state.current_path
+ else 0
+ )
+ + path.length
+ ) / 2 > self.join_depth:
+ return
+ elif subq_path.contains_mapper(self.mapper):
+ return
+
+        # use the current query being invoked, not the compile state
+        # one, so that we get the current parameters. however, this
+        # means we can't use the existing compile state and have to make
+        # a new one. other approaches include possibly using the
+        # compiled query but swapping the params, which seems only
+        # marginally faster but more complicated
+ orig_query = context.query._execution_options.get(
+ ("orig_query", SubqueryLoader), context.query
+ )
+
+ # make a new compile_state for the query that's probably cached, but
+ # we're sort of undoing a bit of that caching :(
+ compile_state_cls = ORMCompileState._get_plugin_class_for_plugin(
+ orig_query, "orm"
+ )
+
+ if orig_query._is_lambda_element:
+ if context.load_options._lazy_loaded_from is None:
+ util.warn(
+ 'subqueryloader for "%s" must invoke lambda callable '
+ "at %r in "
+ "order to produce a new query, decreasing the efficiency "
+ "of caching for this statement. Consider using "
+ "selectinload() for more effective full-lambda caching"
+ % (self, orig_query)
+ )
+ orig_query = orig_query._resolved
+
+ # this is the more "quick" version, however it's not clear how
+ # much of this we need. in particular I can't get a test to
+ # fail if the "set_base_alias" is missing and not sure why that is.
+ orig_compile_state = compile_state_cls._create_entities_collection(
+ orig_query, legacy=False
+ )
+
+ (
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ rewritten_path,
+ ) = self._get_leftmost(
+ orig_query_entity_index,
+ rewritten_path,
+ orig_compile_state,
+ is_root,
+ )
+
+ # generate a new Query from the original, then
+ # produce a subquery from it.
+ left_alias = self._generate_from_original_query(
+ orig_compile_state,
+ orig_query,
+ leftmost_mapper,
+ leftmost_attr,
+ leftmost_relationship,
+ entity,
+ )
+
+ # generate another Query that will join the
+ # left alias to the target relationships.
+ # basically doing a longhand
+ # "from_self()". (from_self() itself not quite industrial
+ # strength enough for all contingencies...but very close)
+
+ q = query.Query(effective_entity)
+
+ q._execution_options = q._execution_options.union(
+ {
+ ("orig_query", SubqueryLoader): orig_query,
+ ("subquery_paths", None): (subq_path, rewritten_path),
+ }
+ )
+
+ q = q._set_enable_single_crit(False)
+ to_join, local_attr, parent_alias = self._prep_for_joins(
+ left_alias, subq_path
+ )
+
+ q = q.add_columns(*local_attr)
+ q = self._apply_joins(
+ q, to_join, left_alias, parent_alias, effective_entity
+ )
+
+ q = self._setup_options(
+ context,
+ q,
+ subq_path,
+ rewritten_path,
+ orig_query,
+ effective_entity,
+ loadopt,
+ )
+ q = self._setup_outermost_orderby(q)
+
+ return q
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+
+ if context.refresh_state:
+ return self._immediateload_create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+ # the subqueryloader does a similar check in setup_query() unlike
+ # the other post loaders, however we have this here for consistency
+ elif self._check_recursive_postload(context, path, self.join_depth):
+ return
+ elif not isinstance(context.compile_state, ORMSelectCompileState):
+ # issue 7505 - subqueryload() in 1.3 and previous would silently
+ # degrade for from_statement() without warning. this behavior
+ # is restored here
+ return
+
+ if not self.parent.class_manager[self.key].impl.supports_population:
+ raise sa_exc.InvalidRequestError(
+ "'%s' does not support object "
+ "population - eager loading cannot be applied." % self
+ )
+
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+        # the root; the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
+
+ subq = self._setup_query_from_rowproc(
+ context,
+ query_entity,
+ path,
+ path[-1],
+ loadopt,
+ adapter,
+ )
+
+ if subq is None:
+ return
+
+ assert subq.session is None
+
+ path = path[self.parent_property]
+
+ local_cols = self.parent_property.local_columns
+
+ # cache the loaded collections in the context
+ # so that inheriting mappers don't re-load when they
+ # call upon create_row_processor again
+ collections = path.get(context.attributes, "collections")
+ if collections is None:
+ collections = self._SubqCollections(context, subq)
+ path.set(context.attributes, "collections", collections)
+
+ if adapter:
+ local_cols = [adapter.columns[c] for c in local_cols]
+
+ if self.uselist:
+ self._create_collection_loader(
+ context, result, collections, local_cols, populators
+ )
+ else:
+ self._create_scalar_loader(
+ context, result, collections, local_cols, populators
+ )
+
+ def _create_collection_loader(
+ self, context, result, collections, local_cols, populators
+ ):
+ tuple_getter = result._tuple_getter(local_cols)
+
+ def load_collection_from_subq(state, dict_, row):
+ collection = collections.get(tuple_getter(row), ())
+ state.get_impl(self.key).set_committed_value(
+ state, dict_, collection
+ )
+
+ def load_collection_from_subq_existing_row(state, dict_, row):
+ if self.key not in dict_:
+ load_collection_from_subq(state, dict_, row)
+
+ populators["new"].append((self.key, load_collection_from_subq))
+ populators["existing"].append(
+ (self.key, load_collection_from_subq_existing_row)
+ )
+
+ if context.invoke_all_eagers:
+ populators["eager"].append((self.key, collections.loader))
+
+ def _create_scalar_loader(
+ self, context, result, collections, local_cols, populators
+ ):
+ tuple_getter = result._tuple_getter(local_cols)
+
+ def load_scalar_from_subq(state, dict_, row):
+ collection = collections.get(tuple_getter(row), (None,))
+ if len(collection) > 1:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for eagerly-loaded attribute '%s' " % self
+ )
+
+ scalar = collection[0]
+ state.get_impl(self.key).set_committed_value(state, dict_, scalar)
+
+ def load_scalar_from_subq_existing_row(state, dict_, row):
+ if self.key not in dict_:
+ load_scalar_from_subq(state, dict_, row)
+
+ populators["new"].append((self.key, load_scalar_from_subq))
+ populators["existing"].append(
+ (self.key, load_scalar_from_subq_existing_row)
+ )
+ if context.invoke_all_eagers:
+ populators["eager"].append((self.key, collections.loader))
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="joined")
+@relationships.RelationshipProperty.strategy_for(lazy=False)
+class JoinedLoader(AbstractRelationshipLoader):
+ """Provide loading behavior for a :class:`.RelationshipProperty`
+ using joined eager loading.
+
+ """
+
+ __slots__ = "join_depth", "_aliased_class_pool"
+
+ def __init__(self, parent, strategy_key):
+ super(JoinedLoader, self).__init__(parent, strategy_key)
+ self.join_depth = self.parent_property.join_depth
+ self._aliased_class_pool = []
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def setup_query(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection=None,
+ parentmapper=None,
+ chained_from_outerjoin=False,
+ **kwargs
+ ):
+ """Add a left outer join to the statement that's being constructed."""
+
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+ elif self.uselist:
+ compile_state.multi_row_eager_loaders = True
+
+ path = path[self.parent_property]
+
+ with_polymorphic = None
+
+ user_defined_adapter = (
+ self._init_user_defined_eager_proc(
+ loadopt, compile_state, compile_state.attributes
+ )
+ if loadopt
+ else False
+ )
+
+ if user_defined_adapter is not False:
+ (
+ clauses,
+ adapter,
+ add_to_collection,
+ ) = self._setup_query_on_user_defined_adapter(
+ compile_state,
+ query_entity,
+ path,
+ adapter,
+ user_defined_adapter,
+ )
+ else:
+ # if not via query option, check for
+ # a cycle
+ if not path.contains(compile_state.attributes, "loader"):
+ if self.join_depth:
+ if path.length / 2 > self.join_depth:
+ return
+ elif path.contains_mapper(self.mapper):
+ return
+
+ (
+ clauses,
+ adapter,
+ add_to_collection,
+ chained_from_outerjoin,
+ ) = self._generate_row_adapter(
+ compile_state,
+ query_entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ parentmapper,
+ chained_from_outerjoin,
+ )
+
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ with_polymorphic = inspect(
+ with_poly_entity
+ ).with_polymorphic_mappers
+ else:
+ with_polymorphic = None
+
+ path = path[self.entity]
+
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ query_entity,
+ path,
+ clauses,
+ add_to_collection,
+ with_polymorphic=with_polymorphic,
+ parentmapper=self.mapper,
+ chained_from_outerjoin=chained_from_outerjoin,
+ )
+
+ if with_poly_entity is not None and None in set(
+ compile_state.secondary_columns
+ ):
+ raise sa_exc.InvalidRequestError(
+ "Detected unaliased columns when generating joined "
+ "load. Make sure to use aliased=True or flat=True "
+ "when using joined loading with with_polymorphic()."
+ )
+
+ def _init_user_defined_eager_proc(
+ self, loadopt, compile_state, target_attributes
+ ):
+
+ # check if the opt applies at all
+ if "eager_from_alias" not in loadopt.local_opts:
+ # nope
+ return False
+
+ path = loadopt.path.parent
+
+ # the option applies. check if the "user_defined_eager_row_processor"
+ # has been built up.
+ adapter = path.get(
+ compile_state.attributes, "user_defined_eager_row_processor", False
+ )
+ if adapter is not False:
+ # just return it
+ return adapter
+
+ # otherwise figure it out.
+ alias = loadopt.local_opts["eager_from_alias"]
+ root_mapper, prop = path[-2:]
+
+ if alias is not None:
+ if isinstance(alias, str):
+ alias = prop.target.alias(alias)
+ adapter = sql_util.ColumnAdapter(
+ alias, equivalents=prop.mapper._equivalent_columns
+ )
+ else:
+ if path.contains(
+ compile_state.attributes, "path_with_polymorphic"
+ ):
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic"
+ )
+ adapter = orm_util.ORMAdapter(
+ with_poly_entity,
+ equivalents=prop.mapper._equivalent_columns,
+ )
+ else:
+ adapter = compile_state._polymorphic_adapters.get(
+ prop.mapper, None
+ )
+ path.set(
+ target_attributes,
+ "user_defined_eager_row_processor",
+ adapter,
+ )
+
+ return adapter
+
+ def _setup_query_on_user_defined_adapter(
+ self, context, entity, path, adapter, user_defined_adapter
+ ):
+
+ # apply some more wrapping to the "user defined adapter"
+ # if we are setting up the query for SQL render.
+ adapter = entity._get_entity_clauses(context)
+
+ if adapter and user_defined_adapter:
+ user_defined_adapter = user_defined_adapter.wrap(adapter)
+ path.set(
+ context.attributes,
+ "user_defined_eager_row_processor",
+ user_defined_adapter,
+ )
+ elif adapter:
+ user_defined_adapter = adapter
+ path.set(
+ context.attributes,
+ "user_defined_eager_row_processor",
+ user_defined_adapter,
+ )
+
+ add_to_collection = context.primary_columns
+ return user_defined_adapter, adapter, add_to_collection
+
+ def _gen_pooled_aliased_class(self, context):
+ # keep a local pool of AliasedClass objects that get re-used.
+ # we need one unique AliasedClass per query per appearance of our
+ # entity in the query.
+
+ if inspect(self.entity).is_aliased_class:
+ alt_selectable = inspect(self.entity).selectable
+ else:
+ alt_selectable = None
+
+ key = ("joinedloader_ac", self)
+ if key not in context.attributes:
+ context.attributes[key] = idx = 0
+ else:
+ context.attributes[key] = idx = context.attributes[key] + 1
+
+ if idx >= len(self._aliased_class_pool):
+ to_adapt = orm_util.AliasedClass(
+ self.mapper,
+ alias=alt_selectable._anonymous_fromclause(flat=True)
+ if alt_selectable is not None
+ else None,
+ flat=True,
+ use_mapper_path=True,
+ )
+
+ # load up the .columns collection on the Alias() before
+ # the object becomes shared among threads. this prevents
+ # races for column identities.
+ inspect(to_adapt).selectable.c
+ self._aliased_class_pool.append(to_adapt)
+
+ return self._aliased_class_pool[idx]
+
+ def _generate_row_adapter(
+ self,
+ compile_state,
+ entity,
+ path,
+ loadopt,
+ adapter,
+ column_collection,
+ parentmapper,
+ chained_from_outerjoin,
+ ):
+ with_poly_entity = path.get(
+ compile_state.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity:
+ to_adapt = with_poly_entity
+ else:
+ to_adapt = self._gen_pooled_aliased_class(compile_state)
+
+ clauses = inspect(to_adapt)._memo(
+ ("joinedloader_ormadapter", self),
+ orm_util.ORMAdapter,
+ to_adapt,
+ equivalents=self.mapper._equivalent_columns,
+ adapt_required=True,
+ allow_label_resolve=False,
+ anonymize_labels=True,
+ )
+
+ assert clauses.aliased_class is not None
+
+ innerjoin = (
+ loadopt.local_opts.get("innerjoin", self.parent_property.innerjoin)
+ if loadopt is not None
+ else self.parent_property.innerjoin
+ )
+
+ if not innerjoin:
+ # if this is an outer join, all non-nested eager joins from
+ # this path must also be outer joins
+ chained_from_outerjoin = True
+
+ compile_state.create_eager_joins.append(
+ (
+ self._create_eager_join,
+ entity,
+ path,
+ adapter,
+ parentmapper,
+ clauses,
+ innerjoin,
+ chained_from_outerjoin,
+ loadopt._extra_criteria if loadopt else (),
+ )
+ )
+
+ add_to_collection = compile_state.secondary_columns
+ path.set(compile_state.attributes, "eager_row_processor", clauses)
+
+ return clauses, adapter, add_to_collection, chained_from_outerjoin
+
+ def _create_eager_join(
+ self,
+ compile_state,
+ query_entity,
+ path,
+ adapter,
+ parentmapper,
+ clauses,
+ innerjoin,
+ chained_from_outerjoin,
+ extra_criteria,
+ ):
+ if parentmapper is None:
+ localparent = query_entity.mapper
+ else:
+ localparent = parentmapper
+
+ # whether or not the Query will wrap the selectable in a subquery,
+ # and then attach eager load joins to that (i.e., in the case of
+ # LIMIT/OFFSET etc.)
+ should_nest_selectable = (
+ compile_state.multi_row_eager_loaders
+ and compile_state._should_nest_selectable
+ )
+
+ query_entity_key = None
+
+ if (
+ query_entity not in compile_state.eager_joins
+ and not should_nest_selectable
+ and compile_state.from_clauses
+ ):
+
+ indexes = sql_util.find_left_clause_that_matches_given(
+ compile_state.from_clauses, query_entity.selectable
+ )
+
+ if len(indexes) > 1:
+ # for the eager load case, I can't reproduce this right
+ # now. For query.join() I can.
+ raise sa_exc.InvalidRequestError(
+ "Can't identify which query entity in which to joined "
+ "eager load from. Please use an exact match when "
+ "specifying the join path."
+ )
+
+ if indexes:
+ clause = compile_state.from_clauses[indexes[0]]
+ # join to an existing FROM clause on the query.
+ # key it to its list index in the eager_joins dict.
+ # Query._compile_context will adapt as needed and
+ # append to the FROM clause of the select().
+ query_entity_key, default_towrap = indexes[0], clause
+
+ if query_entity_key is None:
+ query_entity_key, default_towrap = (
+ query_entity,
+ query_entity.selectable,
+ )
+
+ towrap = compile_state.eager_joins.setdefault(
+ query_entity_key, default_towrap
+ )
+
+ if adapter:
+ if getattr(adapter, "aliased_class", None):
+ # joining from an adapted entity. The adapted entity
+ # might be a "with_polymorphic", so resolve that to our
+ # specific mapper's entity before looking for our attribute
+ # name on it.
+ efm = inspect(adapter.aliased_class)._entity_for_mapper(
+ localparent
+ if localparent.isa(self.parent)
+ else self.parent
+ )
+
+ # look for our attribute on the adapted entity, else fall back
+ # to our straight property
+ onclause = getattr(efm.entity, self.key, self.parent_property)
+ else:
+ onclause = getattr(
+ orm_util.AliasedClass(
+ self.parent, adapter.selectable, use_mapper_path=True
+ ),
+ self.key,
+ self.parent_property,
+ )
+
+ else:
+ onclause = self.parent_property
+
+ assert clauses.aliased_class is not None
+
+ attach_on_outside = (
+ not chained_from_outerjoin
+ or not innerjoin
+ or innerjoin == "unnested"
+ or query_entity.entity_zero.represents_outer_join
+ )
+
+ extra_join_criteria = extra_criteria
+ additional_entity_criteria = compile_state.global_attributes.get(
+ ("additional_entity_criteria", self.mapper), ()
+ )
+ if additional_entity_criteria:
+ extra_join_criteria += tuple(
+ ae._resolve_where_criteria(self.mapper)
+ for ae in additional_entity_criteria
+ if ae.propagate_to_loaders
+ )
+
+ if attach_on_outside:
+ # this is the "classic" eager join case.
+ eagerjoin = orm_util._ORMJoin(
+ towrap,
+ clauses.aliased_class,
+ onclause,
+ isouter=not innerjoin
+ or query_entity.entity_zero.represents_outer_join
+ or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
+ _left_memo=self.parent,
+ _right_memo=self.mapper,
+ _extra_criteria=extra_join_criteria,
+ )
+ else:
+ # all other cases are innerjoin=='nested' approach
+ eagerjoin = self._splice_nested_inner_join(
+ path, towrap, clauses, onclause, extra_join_criteria
+ )
+
+ compile_state.eager_joins[query_entity_key] = eagerjoin
+
+ # send a hint to the Query as to where it may "splice" this join
+ eagerjoin.stop_on = query_entity.selectable
+
+ if not parentmapper:
+ # for parentclause that is the non-eager end of the join,
+ # ensure all the parent cols in the primaryjoin are actually
+ # in the
+ # columns clause (i.e. are not deferred), so that aliasing applied
+ # by the Query propagates those columns outward.
+ # This has the effect
+ # of "undefering" those columns.
+ for col in sql_util._find_columns(
+ self.parent_property.primaryjoin
+ ):
+ if localparent.persist_selectable.c.contains_column(col):
+ if adapter:
+ col = adapter.columns[col]
+ compile_state._append_dedupe_col_collection(
+ col, compile_state.primary_columns
+ )
+
+ if self.parent_property.order_by:
+ compile_state.eager_order_by += tuple(
+ (eagerjoin._target_adapter.copy_and_process)(
+ util.to_list(self.parent_property.order_by)
+ )
+ )
+
+ def _splice_nested_inner_join(
+ self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
+ ):
+
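+        # recursively descend the right side of the existing join tree
+        # looking for the join point indicated by "splicing" (the parent
+        # of the entity being spliced in); if not found there, retry on
+        # the left side, then rebuild the enclosing join with the new
+        # inner join spliced into place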
+ if splicing is False:
+ # first call is always handed a join object
+ # from the outside
+ assert isinstance(join_obj, orm_util._ORMJoin)
+ elif isinstance(join_obj, sql.selectable.FromGrouping):
+ return self._splice_nested_inner_join(
+ path,
+ join_obj.element,
+ clauses,
+ onclause,
+ extra_criteria,
+ splicing,
+ )
+ elif not isinstance(join_obj, orm_util._ORMJoin):
+ if path[-2] is splicing:
+ return orm_util._ORMJoin(
+ join_obj,
+ clauses.aliased_class,
+ onclause,
+ isouter=False,
+ _left_memo=splicing,
+ _right_memo=path[-1].mapper,
+ _extra_criteria=extra_criteria,
+ )
+ else:
+ # only here if splicing == True
+ return None
+
+ target_join = self._splice_nested_inner_join(
+ path,
+ join_obj.right,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._right_memo,
+ )
+ if target_join is None:
+ right_splice = False
+ target_join = self._splice_nested_inner_join(
+ path,
+ join_obj.left,
+ clauses,
+ onclause,
+ extra_criteria,
+ join_obj._left_memo,
+ )
+ if target_join is None:
+ # should only return None when recursively called,
+ # e.g. splicing==True
+ assert (
+ splicing is not False
+ ), "assertion failed attempting to produce joined eager loads"
+ return None
+ else:
+ right_splice = True
+
+ if right_splice:
+ # for a right splice, attempt to flatten out
+ # a JOIN b JOIN c JOIN .. to avoid needless
+ # parenthesis nesting
+ if not join_obj.isouter and not target_join.isouter:
+ eagerjoin = join_obj._splice_into_center(target_join)
+ else:
+ eagerjoin = orm_util._ORMJoin(
+ join_obj.left,
+ target_join,
+ join_obj.onclause,
+ isouter=join_obj.isouter,
+ _left_memo=join_obj._left_memo,
+ )
+ else:
+ eagerjoin = orm_util._ORMJoin(
+ target_join,
+ join_obj.right,
+ join_obj.onclause,
+ isouter=join_obj.isouter,
+ _right_memo=join_obj._right_memo,
+ )
+
+ eagerjoin._target_adapter = target_join._target_adapter
+ return eagerjoin
+
+ def _create_eager_adapter(self, context, result, adapter, path, loadopt):
+ compile_state = context.compile_state
+
+ user_defined_adapter = (
+ self._init_user_defined_eager_proc(
+ loadopt, compile_state, context.attributes
+ )
+ if loadopt
+ else False
+ )
+
+ if user_defined_adapter is not False:
+ decorator = user_defined_adapter
+ # user defined eagerloads are part of the "primary"
+ # portion of the load.
+ # the adapters applied to the Query should be honored.
+ if compile_state.compound_eager_adapter and decorator:
+ decorator = decorator.wrap(
+ compile_state.compound_eager_adapter
+ )
+ elif compile_state.compound_eager_adapter:
+ decorator = compile_state.compound_eager_adapter
+ else:
+ decorator = path.get(
+ compile_state.attributes, "eager_row_processor"
+ )
+ if decorator is None:
+ return False
+
+ if self.mapper._result_has_identity_key(result, decorator):
+ return decorator
+ else:
+ # no identity key - don't return a row
+ # processor, will cause a degrade to lazy
+ return False
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+ if not self.parent.class_manager[self.key].impl.supports_population:
+ raise sa_exc.InvalidRequestError(
+ "'%s' does not support object "
+ "population - eager loading cannot be applied." % self
+ )
+
+ if self.uselist:
+ context.loaders_require_uniquing = True
+
+ our_path = path[self.parent_property]
+
+ eager_adapter = self._create_eager_adapter(
+ context, result, adapter, our_path, loadopt
+ )
+
+ if eager_adapter is not False:
+ key = self.key
+
+ _instance = loading._instance_processor(
+ query_entity,
+ self.mapper,
+ context,
+ result,
+ our_path[self.entity],
+ eager_adapter,
+ )
+
+ if not self.uselist:
+ self._create_scalar_loader(context, key, _instance, populators)
+ else:
+ self._create_collection_loader(
+ context, key, _instance, populators
+ )
+ else:
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+
+ def _create_collection_loader(self, context, key, _instance, populators):
+ def load_collection_from_joined_new_row(state, dict_, row):
+ # note this must unconditionally clear out any existing collection.
+ # an existing collection would be present only in the case of
+ # populate_existing().
+ collection = attributes.init_state_collection(state, dict_, key)
+ result_list = util.UniqueAppender(
+ collection, "append_without_event"
+ )
+ context.attributes[(state, key)] = result_list
+ inst = _instance(row)
+ if inst is not None:
+ result_list.append(inst)
+
+ def load_collection_from_joined_existing_row(state, dict_, row):
+ if (state, key) in context.attributes:
+ result_list = context.attributes[(state, key)]
+ else:
+ # appender_key can be absent from context.attributes
+ # with isnew=False when self-referential eager loading
+ # is used; the same instance may be present in two
+ # distinct sets of result columns
+ collection = attributes.init_state_collection(
+ state, dict_, key
+ )
+ result_list = util.UniqueAppender(
+ collection, "append_without_event"
+ )
+ context.attributes[(state, key)] = result_list
+ inst = _instance(row)
+ if inst is not None:
+ result_list.append(inst)
+
+ def load_collection_from_joined_exec(state, dict_, row):
+ _instance(row)
+
+ populators["new"].append(
+ (self.key, load_collection_from_joined_new_row)
+ )
+ populators["existing"].append(
+ (self.key, load_collection_from_joined_existing_row)
+ )
+ if context.invoke_all_eagers:
+ populators["eager"].append(
+ (self.key, load_collection_from_joined_exec)
+ )
+
+ def _create_scalar_loader(self, context, key, _instance, populators):
+ def load_scalar_from_joined_new_row(state, dict_, row):
+ # set a scalar object instance directly on the parent
+ # object, bypassing InstrumentedAttribute event handlers.
+ dict_[key] = _instance(row)
+
+ def load_scalar_from_joined_existing_row(state, dict_, row):
+ # call _instance on the row, even though the object has
+ # been created, so that we further descend into properties
+ existing = _instance(row)
+
+ # conflicting value already loaded, this shouldn't happen
+ if key in dict_:
+ if existing is not dict_[key]:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for eagerly-loaded attribute '%s' "
+ % self
+ )
+ else:
+ # this case is when one row has multiple loads of the
+ # same entity (e.g. via aliasing), one has an attribute
+ # that the other doesn't.
+ dict_[key] = existing
+
+ def load_scalar_from_joined_exec(state, dict_, row):
+ _instance(row)
+
+ populators["new"].append((self.key, load_scalar_from_joined_new_row))
+ populators["existing"].append(
+ (self.key, load_scalar_from_joined_existing_row)
+ )
+ if context.invoke_all_eagers:
+ populators["eager"].append(
+ (self.key, load_scalar_from_joined_exec)
+ )
+
+
+@log.class_logger
+@relationships.RelationshipProperty.strategy_for(lazy="selectin")
+class SelectInLoader(PostLoader, util.MemoizedSlots):
+ __slots__ = (
+ "join_depth",
+ "omit_join",
+ "_parent_alias",
+ "_query_info",
+ "_fallback_query_info",
+ )
+
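+    # query_info describes the structure of the SELECT .. IN query:
+    # load_only_child - many-to-one "omit join" mode keyed on the child's
+    # primary key; load_with_join - whether an aliased parent is joined;
+    # in_expr / pk_cols / zero_idx - the IN criteria columns and whether
+    # the key is a single column; child_lookup_cols - the parent's
+    # foreign-key columns used to look up the child identity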
+ query_info = collections.namedtuple(
+ "queryinfo",
+ [
+ "load_only_child",
+ "load_with_join",
+ "in_expr",
+ "pk_cols",
+ "zero_idx",
+ "child_lookup_cols",
+ ],
+ )
+
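+    # number of parent primary key values sent per IN expression; larger
+    # result sets are loaded across multiple SELECT statements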
+ _chunksize = 500
+
+ def __init__(self, parent, strategy_key):
+ super(SelectInLoader, self).__init__(parent, strategy_key)
+ self.join_depth = self.parent_property.join_depth
+ is_m2o = self.parent_property.direction is interfaces.MANYTOONE
+
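+        # "omit join" mode: the IN criteria can be expressed directly
+        # against the related table's columns (a plain foreign key
+        # pattern), so the SELECT does not need to join back to the
+        # parent entity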
+ if self.parent_property.omit_join is not None:
+ self.omit_join = self.parent_property.omit_join
+ else:
+ lazyloader = self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ )
+ if is_m2o:
+ self.omit_join = lazyloader.use_get
+ else:
+ self.omit_join = self.parent._get_clause[0].compare(
+ lazyloader._rev_lazywhere,
+ use_proxies=True,
+ compare_keys=False,
+ equivalents=self.parent._equivalent_columns,
+ )
+
+ if self.omit_join:
+ if is_m2o:
+ self._query_info = self._init_for_omit_join_m2o()
+ self._fallback_query_info = self._init_for_join()
+ else:
+ self._query_info = self._init_for_omit_join()
+ else:
+ self._query_info = self._init_for_join()
+
+ def _init_for_omit_join(self):
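+        # map each parent primary key column (and its equivalents) to the
+        # foreign-key column on the related table that refers to it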
+ pk_to_fk = dict(
+ self.parent_property._join_condition.local_remote_pairs
+ )
+ pk_to_fk.update(
+ (equiv, pk_to_fk[k])
+ for k in list(pk_to_fk)
+ for equiv in self.parent._equivalent_columns.get(k, ())
+ )
+
+ pk_cols = fk_cols = [
+ pk_to_fk[col] for col in self.parent.primary_key if col in pk_to_fk
+ ]
+ if len(fk_cols) > 1:
+ in_expr = sql.tuple_(*fk_cols)
+ zero_idx = False
+ else:
+ in_expr = fk_cols[0]
+ zero_idx = True
+
+ return self.query_info(False, False, in_expr, pk_cols, zero_idx, None)
+
+ def _init_for_omit_join_m2o(self):
+ pk_cols = self.mapper.primary_key
+ if len(pk_cols) > 1:
+ in_expr = sql.tuple_(*pk_cols)
+ zero_idx = False
+ else:
+ in_expr = pk_cols[0]
+ zero_idx = True
+
+ lazyloader = self.parent_property._get_strategy((("lazy", "select"),))
+ lookup_cols = [lazyloader._equated_columns[pk] for pk in pk_cols]
+
+ return self.query_info(
+ True, False, in_expr, pk_cols, zero_idx, lookup_cols
+ )
+
+ def _init_for_join(self):
+ self._parent_alias = aliased(self.parent.class_)
+ pa_insp = inspect(self._parent_alias)
+ pk_cols = [
+ pa_insp._adapt_element(col) for col in self.parent.primary_key
+ ]
+ if len(pk_cols) > 1:
+ in_expr = sql.tuple_(*pk_cols)
+ zero_idx = False
+ else:
+ in_expr = pk_cols[0]
+ zero_idx = True
+ return self.query_info(False, True, in_expr, pk_cols, zero_idx, None)
+
+ def init_class_attribute(self, mapper):
+ self.parent_property._get_strategy(
+ (("lazy", "select"),)
+ ).init_class_attribute(mapper)
+
+ def create_row_processor(
+ self,
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ ):
+
+ if context.refresh_state:
+ return self._immediateload_create_row_processor(
+ context,
+ query_entity,
+ path,
+ loadopt,
+ mapper,
+ result,
+ adapter,
+ populators,
+ )
+ elif self._check_recursive_postload(context, path, self.join_depth):
+ return
+
+ if not self.parent.class_manager[self.key].impl.supports_population:
+ raise sa_exc.InvalidRequestError(
+ "'%s' does not support object "
+ "population - eager loading cannot be applied." % self
+ )
+
+ # a little dance here as the "path" is still something that only
+ # semi-tracks the exact series of things we are loading, still not
+ # telling us about with_polymorphic() and stuff like that when it's at
+        # the root; the initial MapperEntity is more accurate for this case.
+ if len(path) == 1:
+ if not orm_util._entity_isa(query_entity.entity_zero, self.parent):
+ return
+ elif not orm_util._entity_isa(path[-1], self.parent):
+ return
+
+ selectin_path = (
+ context.compile_state.current_path or orm_util.PathRegistry.root
+ ) + path
+
+ path_w_prop = path[self.parent_property]
+
+ # build up a path indicating the path from the leftmost
+ # entity to the thing we're subquery loading.
+ with_poly_entity = path_w_prop.get(
+ context.attributes, "path_with_polymorphic", None
+ )
+ if with_poly_entity is not None:
+ effective_entity = inspect(with_poly_entity)
+ else:
+ effective_entity = self.entity
+
+ loading.PostLoad.callable_for_path(
+ context,
+ selectin_path,
+ self.parent,
+ self.parent_property,
+ self._load_for_path,
+ effective_entity,
+ loadopt,
+ )
+
+ def _load_for_path(
+ self, context, path, states, load_only, effective_entity, loadopt
+ ):
+ if load_only and self.key not in load_only:
+ return
+
+ query_info = self._query_info
+
+ if query_info.load_only_child:
+ our_states = collections.defaultdict(list)
+ none_states = []
+
+ mapper = self.parent
+
+ for state, overwrite in states:
+ state_dict = state.dict
+ related_ident = tuple(
+ mapper._get_state_attr_by_column(
+ state,
+ state_dict,
+ lk,
+ passive=attributes.PASSIVE_NO_FETCH,
+ )
+ for lk in query_info.child_lookup_cols
+ )
+ # if the loaded parent objects do not have the foreign key
+ # to the related item loaded, then degrade into the joined
+ # version of selectinload
+ if attributes.PASSIVE_NO_RESULT in related_ident:
+ query_info = self._fallback_query_info
+ break
+
+ # organize states into lists keyed to particular foreign
+ # key values.
+ if None not in related_ident:
+ our_states[related_ident].append(
+ (state, state_dict, overwrite)
+ )
+ else:
+ # For FK values that have None, add them to a
+ # separate collection that will be populated separately
+ none_states.append((state, state_dict, overwrite))
+
+ # note the above conditional may have changed query_info
+ if not query_info.load_only_child:
+ our_states = [
+ (state.key[1], state, state.dict, overwrite)
+ for state, overwrite in states
+ ]
+
+ pk_cols = query_info.pk_cols
+ in_expr = query_info.in_expr
+
+ if not query_info.load_with_join:
+ # in "omit join" mode, the primary key column and the
+ # "in" expression are in terms of the related entity. So
+ # if the related entity is polymorphic or otherwise aliased,
+ # we need to adapt our "pk_cols" and "in_expr" to that
+ # entity. in non-"omit join" mode, these are against the
+            # parent entity and do not need adaptation.
+ if effective_entity.is_aliased_class:
+ pk_cols = [
+ effective_entity._adapt_element(col) for col in pk_cols
+ ]
+ in_expr = effective_entity._adapt_element(in_expr)
+
+ bundle_ent = orm_util.Bundle("pk", *pk_cols)
+ bundle_sql = bundle_ent.__clause_element__()
+
+ entity_sql = effective_entity.__clause_element__()
+ q = Select._create_raw_select(
+ _raw_columns=[bundle_sql, entity_sql],
+ _label_style=LABEL_STYLE_TABLENAME_PLUS_COL,
+ _compile_options=ORMCompileState.default_compile_options,
+ _propagate_attrs={
+ "compile_state_plugin": "orm",
+ "plugin_subject": effective_entity,
+ },
+ )
+
+ if not query_info.load_with_join:
+ # the Bundle we have in the "omit_join" case is against raw, non
+ # annotated columns, so to ensure the Query knows its primary
+ # entity, we add it explicitly. If we made the Bundle against
+ # annotated columns, we hit a performance issue in this specific
+ # case, which is detailed in issue #4347.
+ q = q.select_from(effective_entity)
+ else:
+ # in the non-omit_join case, the Bundle is against the annotated/
+ # mapped column of the parent entity, but the #4347 issue does not
+ # occur in this case.
+ q = q.select_from(self._parent_alias).join(
+ getattr(self._parent_alias, self.parent_property.key).of_type(
+ effective_entity
+ )
+ )
+
+ q = q.filter(in_expr.in_(sql.bindparam("primary_keys")))
+
+ # a test which exercises what these comments talk about is
+ # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic
+ #
+ # effective_entity above is given to us in terms of the cached
+ # statement, namely this one:
+ orig_query = context.compile_state.select_statement
+
+ # the actual statement that was requested is this one:
+ # context_query = context.query
+ #
+ # that's not the cached one, however. So while it is of the identical
+ # structure, if it has entities like AliasedInsp, which we get from
+ # aliased() or with_polymorphic(), the AliasedInsp will likely be a
+ # different object identity each time, and will not match up
+ # hashing-wise to the corresponding AliasedInsp that's in the
+ # cached query, meaning it won't match on paths and loader lookups
+ # and loaders like this one will be skipped if it is used in options.
+ #
+ # Now we want to transfer loader options from the parent query to the
+ # "selectinload" query we're about to run. Which query do we transfer
+ # the options from? We use the cached query, because the options in
+ # that query will be in terms of the effective entity we were just
+ # handed.
+ #
+ # But now the selectinload query we are running is *also*
+ # cached. What if it's cached and running from some previous iteration
+ # of that AliasedInsp? Well in that case it will also use the previous
+ # iteration of the loader options. If the query expires and
+ # gets generated again, it will be handed the current effective_entity
+ # and the current _with_options, again in terms of whatever
+ # compile_state.select_statement happens to be right now, so the
+ # query will still be internally consistent and loader callables
+ # will be correctly invoked.
+
+ effective_path = path[self.parent_property]
+
+ if orig_query is context.query:
+ options = new_options = orig_query._with_options
+ user_defined_options = []
+ else:
+ options = orig_query._with_options
+
+ # propagate compile state options from the original query,
+ # updating their "extra_criteria" as necessary.
+ # note this will create a different cache key than
+ # "orig" options if extra_criteria is present, because the copy
+ # of extra_criteria will have different boundparam than that of
+ # the QueryableAttribute in the path
+
+ new_options = [
+ orig_opt._adjust_for_extra_criteria(context)
+ if orig_opt._is_strategy_option
+ else orig_opt
+ for orig_opt in options
+ if orig_opt._is_compile_state or orig_opt._is_legacy_option
+ ]
+
+ # propagate user defined options from the current query
+ user_defined_options = [
+ opt
+ for opt in context.query._with_options
+ if not opt._is_compile_state and not opt._is_legacy_option
+ ]
+
+ if loadopt and loadopt._extra_criteria:
+ new_options += (
+ orm_util.LoaderCriteriaOption(
+ effective_entity,
+ loadopt._generate_extra_criteria(context),
+ ),
+ )
+
+ q = q.options(*new_options)._update_compile_options(
+ {"_current_path": effective_path}
+ )
+ if user_defined_options:
+ q = q.options(*user_defined_options)
+
+ if context.populate_existing:
+ q = q.execution_options(populate_existing=True)
+
+ if self.parent_property.order_by:
+ if not query_info.load_with_join:
+ eager_order_by = self.parent_property.order_by
+ if effective_entity.is_aliased_class:
+ eager_order_by = [
+ effective_entity._adapt_element(elem)
+ for elem in eager_order_by
+ ]
+ q = q.order_by(*eager_order_by)
+ else:
+
+ def _setup_outermost_orderby(compile_context):
+ compile_context.eager_order_by += tuple(
+ util.to_list(self.parent_property.order_by)
+ )
+
+ q = q._add_context_option(
+ _setup_outermost_orderby, self.parent_property
+ )
+
+ if query_info.load_only_child:
+ self._load_via_child(
+ our_states, none_states, query_info, q, context
+ )
+ else:
+ self._load_via_parent(our_states, query_info, q, context)
+
+ def _load_via_child(self, our_states, none_states, query_info, q, context):
+ uselist = self.uselist
+
+ # this sort is really for the benefit of the unit tests
+ our_keys = sorted(our_states)
+ while our_keys:
+ chunk = our_keys[0 : self._chunksize]
+ our_keys = our_keys[self._chunksize :]
+ data = {
+ k: v
+ for k, v in context.session.execute(
+ q,
+ params={
+ "primary_keys": [
+ key[0] if query_info.zero_idx else key
+ for key in chunk
+ ]
+ },
+ ).unique()
+ }
+
+ for key in chunk:
+ # for a real foreign key and no concurrent changes to the
+ # DB while running this method, "key" is always present in
+ # data. However, for primaryjoins without real foreign keys
+ # a non-None primaryjoin condition may still refer to no
+ # related object.
+ related_obj = data.get(key, None)
+ for state, dict_, overwrite in our_states[key]:
+ if not overwrite and self.key in dict_:
+ continue
+
+ state.get_impl(self.key).set_committed_value(
+ state,
+ dict_,
+ related_obj if not uselist else [related_obj],
+ )
+ # populate none states with empty value / collection
+ for state, dict_, overwrite in none_states:
+ if not overwrite and self.key in dict_:
+ continue
+
+ # note it's OK if this is a uselist=True attribute, the empty
+ # collection will be populated
+ state.get_impl(self.key).set_committed_value(state, dict_, None)
+
+ def _load_via_parent(self, our_states, query_info, q, context):
+ uselist = self.uselist
+ _empty_result = () if uselist else None
+
+ while our_states:
+ chunk = our_states[0 : self._chunksize]
+ our_states = our_states[self._chunksize :]
+
+ primary_keys = [
+ key[0] if query_info.zero_idx else key
+ for key, state, state_dict, overwrite in chunk
+ ]
+
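+            # each result row is (parent key bundle, related entity);
+            # group the related entities by their parent key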
+ data = collections.defaultdict(list)
+ for k, v in itertools.groupby(
+ context.session.execute(
+ q, params={"primary_keys": primary_keys}
+ ).unique(),
+ lambda x: x[0],
+ ):
+ data[k].extend(vv[1] for vv in v)
+
+ for key, state, state_dict, overwrite in chunk:
+
+ if not overwrite and self.key in state_dict:
+ continue
+
+ collection = data.get(key, _empty_result)
+
+ if not uselist and collection:
+ if len(collection) > 1:
+ util.warn(
+ "Multiple rows returned with "
+ "uselist=False for eagerly-loaded "
+ "attribute '%s' " % self
+ )
+ state.get_impl(self.key).set_committed_value(
+ state, state_dict, collection[0]
+ )
+ else:
+ # note that empty tuple set on uselist=False sets the
+ # value to None
+ state.get_impl(self.key).set_committed_value(
+ state, state_dict, collection
+ )
+
+
+def single_parent_validator(desc, prop):
+ def _do_check(state, value, oldvalue, initiator):
+ if value is not None and initiator.key == prop.key:
+ hasparent = initiator.hasparent(attributes.instance_state(value))
+ if hasparent and oldvalue is not value:
+ raise sa_exc.InvalidRequestError(
+ "Instance %s is already associated with an instance "
+ "of %s via its %s attribute, and is only allowed a "
+ "single parent."
+ % (orm_util.instance_str(value), state.class_, prop),
+ code="bbf1",
+ )
+ return value
+
+ def append(state, value, initiator):
+ return _do_check(state, value, None, initiator)
+
+ def set_(state, value, oldvalue, initiator):
+ return _do_check(state, value, oldvalue, initiator)
+
+ event.listen(
+ desc, "append", append, raw=True, retval=True, active_history=True
+ )
+ event.listen(desc, "set", set_, raw=True, retval=True, active_history=True)
diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py
new file mode 100644
index 0000000..c3dd5df
--- /dev/null
+++ b/lib/sqlalchemy/orm/strategy_options.py
@@ -0,0 +1,2008 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+
+"""
+
+from . import util as orm_util
+from .attributes import QueryableAttribute
+from .base import _class_to_mapper
+from .base import _is_aliased_class
+from .base import _is_mapped_class
+from .base import InspectionAttr
+from .interfaces import LoaderOption
+from .interfaces import MapperProperty
+from .interfaces import PropComparator
+from .path_registry import _DEFAULT_TOKEN
+from .path_registry import _WILDCARD_TOKEN
+from .path_registry import PathRegistry
+from .path_registry import TokenRegistry
+from .util import _orm_full_deannotate
+from .. import exc as sa_exc
+from .. import inspect
+from .. import util
+from ..sql import and_
+from ..sql import coercions
+from ..sql import roles
+from ..sql import traversals
+from ..sql import visitors
+from ..sql.base import _generative
+from ..sql.base import Generative
+
+
+class Load(Generative, LoaderOption):
+ """Represents loader options which modify the state of a
+ :class:`_query.Query` in order to affect how various mapped attributes are
+ loaded.
+
+ The :class:`_orm.Load` object is in most cases used implicitly behind the
+ scenes when one makes use of a query option like :func:`_orm.joinedload`,
+ :func:`.defer`, or similar. However, the :class:`_orm.Load` object
+ can also be used directly, and in some cases can be useful.
+
+ To use :class:`_orm.Load` directly, instantiate it with the target mapped
+ class as the argument. This style of usage is
+ useful when dealing with a :class:`_query.Query`
+ that has multiple entities::
+
+ myopt = Load(MyClass).joinedload("widgets")
+
+ The above ``myopt`` can now be used with :meth:`_query.Query.options`,
+ where it
+ will only take effect for the ``MyClass`` entity::
+
+ session.query(MyClass, MyOtherClass).options(myopt)
+
+ One case where :class:`_orm.Load`
+ is useful as public API is when specifying
+ "wildcard" options that only take effect for a certain class::
+
+ session.query(Order).options(Load(Order).lazyload('*'))
+
+ Above, all relationships on ``Order`` will be lazy-loaded, but other
+ attributes on those descendant objects will load using their normal
+ loader strategy.
+
+ .. seealso::
+
+ :ref:`deferred_options`
+
+ :ref:`deferred_loading_w_multiple`
+
+ :ref:`relationship_loader_options`
+
+ """
+
+ _is_strategy_option = True
+
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+ ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
+ (
+ "_context_cache_key",
+ visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
+ ),
+ (
+ "local_opts",
+ visitors.ExtendedInternalTraversal.dp_string_multi_dict,
+ ),
+ ]
+
+ def __init__(self, entity):
+ insp = inspect(entity)
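+        # touching _post_inspect ensures the entity's mapper
+        # configuration has completed before a path is built from it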
+ insp._post_inspect
+
+ self.path = insp._path_registry
+ # note that this .context is shared among all descendant
+ # Load objects
+ self.context = util.OrderedDict()
+ self.local_opts = {}
+ self.is_class_strategy = False
+
+ @classmethod
+ def for_existing_path(cls, path):
+ load = cls.__new__(cls)
+ load.path = path
+ load.context = {}
+ load.local_opts = {}
+ load._of_type = None
+ load._extra_criteria = ()
+ return load
+
+ def _generate_extra_criteria(self, context):
+ """Apply the current bound parameters in a QueryContext to the
+ immediate "extra_criteria" stored with this Load object.
+
+ Load objects are typically pulled from the cached version of
+ the statement from a QueryContext. The statement currently being
+ executed will have new values (and keys) for bound parameters in the
+ extra criteria which need to be applied by loader strategies when
+ they handle this criteria for a result set.
+
+ """
+
+ assert (
+ self._extra_criteria
+ ), "this should only be called if _extra_criteria is present"
+
+ orig_query = context.compile_state.select_statement
+ current_query = context.query
+
+ # NOTE: while it seems like we should not do the "apply" operation
+ # here if orig_query is current_query, skipping it in the "optimized"
+ # case causes the query to be different from a cache key perspective,
+        # because we are creating a copy of the criteria which no longer
+        # has the same identity as the _extra_criteria in the loader option
+        # itself. cache key logic produces a different key for
+        # (A, copy_of_A) vs. (A, A), because in the latter case it shortens
+        # the second part of the key to just indicate the identity.
+
+ # if orig_query is current_query:
+ # not cached yet. just do the and_()
+ # return and_(*self._extra_criteria)
+
+ k1 = orig_query._generate_cache_key()
+ k2 = current_query._generate_cache_key()
+
+ return k2._apply_params_to_element(k1, and_(*self._extra_criteria))
+
+ def _adjust_for_extra_criteria(self, context):
+ """Apply the current bound parameters in a QueryContext to all
+ occurrences "extra_criteria" stored within al this Load object;
+ copying in place.
+
+ """
+ orig_query = context.compile_state.select_statement
+
+ applied = {}
+
+ ck = [None, None]
+
+ def process(opt):
+ if not opt._extra_criteria:
+ return
+
+ if ck[0] is None:
+ ck[:] = (
+ orig_query._generate_cache_key(),
+ context.query._generate_cache_key(),
+ )
+ k1, k2 = ck
+
+ opt._extra_criteria = tuple(
+ k2._apply_params_to_element(k1, crit)
+ for crit in opt._extra_criteria
+ )
+
+ return self._deep_clone(applied, process)
+
+ def _deep_clone(self, applied, process):
+ if self in applied:
+ return applied[self]
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ if self.context:
+ cloned.context = util.OrderedDict(
+ [
+ (
+ key,
+ value._deep_clone(applied, process)
+ if isinstance(value, Load)
+ else value,
+ )
+ for key, value in self.context.items()
+ ]
+ )
+
+ cloned.local_opts.update(self.local_opts)
+
+ process(cloned)
+
+ return cloned
+
+ @property
+ def _context_cache_key(self):
+ serialized = []
+ if self.context is None:
+ return []
+ for (key, loader_path), obj in self.context.items():
+ if key != "loader":
+ continue
+ serialized.append(loader_path + (obj,))
+ return serialized
+
+ def _generate(self):
+ cloned = super(Load, self)._generate()
+ cloned.local_opts = {}
+ return cloned
+
+ is_opts_only = False
+ is_class_strategy = False
+ strategy = None
+ propagate_to_loaders = False
+ _of_type = None
+ _extra_criteria = ()
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+
+ # process is being run here so that the options given are validated
+ # against what the lead entities were, as well as to accommodate
+ # for the entities having been replaced with equivalents
+ self._process(
+ compile_state,
+ mapper_entities,
+ not bool(compile_state.current_path),
+ )
+
+ def process_compile_state(self, compile_state):
+ if not compile_state.compile_options._enable_eagerloads:
+ return
+
+ self._process(
+ compile_state,
+ compile_state._lead_mapper_entities,
+ not bool(compile_state.current_path)
+ and not compile_state.compile_options._for_refresh_state,
+ )
+
+ def _process(self, compile_state, mapper_entities, raiseerr):
+ is_refresh = compile_state.compile_options._for_refresh_state
+ current_path = compile_state.current_path
+ if current_path:
+ for (token, start_path), loader in self.context.items():
+ if is_refresh and not loader.propagate_to_loaders:
+ continue
+ chopped_start_path = self._chop_path(start_path, current_path)
+ if chopped_start_path is not None:
+ compile_state.attributes[
+ (token, chopped_start_path)
+ ] = loader
+ else:
+ compile_state.attributes.update(self.context)
+
+ def _generate_path(
+ self,
+ path,
+ attr,
+ for_strategy,
+ wildcard_key,
+ raiseerr=True,
+ polymorphic_entity_context=None,
+ ):
+ existing_of_type = self._of_type
+ self._of_type = None
+ if raiseerr and not path.has_entity:
+ if isinstance(path, TokenRegistry):
+ raise sa_exc.ArgumentError(
+ "Wildcard token cannot be followed by another entity"
+ )
+ else:
+ raise sa_exc.ArgumentError(
+ "Mapped attribute '%s' does not "
+ "refer to a mapped entity" % (path.prop,)
+ )
+
+ if isinstance(attr, util.string_types):
+
+ default_token = attr.endswith(_DEFAULT_TOKEN)
+ attr_str_name = attr
+ if attr.endswith(_WILDCARD_TOKEN) or default_token:
+ if default_token:
+ self.propagate_to_loaders = False
+ if wildcard_key:
+ attr = "%s:%s" % (wildcard_key, attr)
+
+ # TODO: AliasedInsp inside the path for of_type is not
+ # working for a with_polymorphic entity because the
+ # relationship loaders don't render the with_poly into the
+ # path. See #4469 which will try to improve this
+ if existing_of_type and not existing_of_type.is_aliased_class:
+ path = path.parent[existing_of_type]
+ path = path.token(attr)
+ self.path = path
+ return path
+
+ if existing_of_type:
+ ent = inspect(existing_of_type)
+ else:
+ ent = path.entity
+
+ util.warn_deprecated_20(
+ "Using strings to indicate column or "
+ "relationship paths in loader options is deprecated "
+ "and will be removed in SQLAlchemy 2.0. Please use "
+ "the class-bound attribute directly.",
+ )
+ try:
+ # use getattr on the class to work around
+ # synonyms, hybrids, etc.
+ attr = getattr(ent.class_, attr)
+ except AttributeError as err:
+ if raiseerr:
+ util.raise_(
+ sa_exc.ArgumentError(
+ 'Can\'t find property named "%s" on '
+ "%s in this Query." % (attr, ent)
+ ),
+ replace_context=err,
+ )
+ else:
+ return None
+ else:
+ try:
+ attr = found_property = attr.property
+ except AttributeError as ae:
+ if not isinstance(attr, MapperProperty):
+ util.raise_(
+ sa_exc.ArgumentError(
+ 'Expected attribute "%s" on %s to be a '
+ "mapped attribute; "
+ "instead got %s object."
+ % (attr_str_name, ent, type(attr))
+ ),
+ replace_context=ae,
+ )
+ else:
+ raise
+
+ path = path[attr]
+ else:
+ insp = inspect(attr)
+
+ if insp.is_mapper or insp.is_aliased_class:
+ # TODO: this does not appear to be a valid codepath. "attr"
+ # would never be a mapper. This block is present in 1.2
+ # as well however does not seem to be accessed in any tests.
+ if not orm_util._entity_corresponds_to_use_path_impl(
+ attr.parent, path[-1]
+ ):
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Attribute '%s' does not "
+ "link from element '%s'" % (attr, path.entity)
+ )
+ else:
+ return None
+ elif insp.is_property:
+ prop = found_property = attr
+ path = path[prop]
+ elif insp.is_attribute:
+ prop = found_property = attr.property
+
+ if not orm_util._entity_corresponds_to_use_path_impl(
+ attr.parent, path[-1]
+ ):
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ 'Attribute "%s" does not '
+ 'link from element "%s".%s'
+ % (
+ attr,
+ path.entity,
+ (
+ " Did you mean to use "
+ "%s.of_type(%s)?"
+ % (path[-2], attr.class_.__name__)
+ if len(path) > 1
+ and path.entity.is_mapper
+ and attr.parent.is_aliased_class
+ else ""
+ ),
+ )
+ )
+ else:
+ return None
+
+ if attr._extra_criteria and not self._extra_criteria:
+ # in most cases, the process that brings us here will have
+ # already established _extra_criteria. however if not,
+ # and it's present on the attribute, then use that.
+ self._extra_criteria = attr._extra_criteria
+
+ if getattr(attr, "_of_type", None):
+ ac = attr._of_type
+ ext_info = of_type_info = inspect(ac)
+
+ if polymorphic_entity_context is None:
+ polymorphic_entity_context = self.context
+
+ existing = path.entity_path[prop].get(
+ polymorphic_entity_context, "path_with_polymorphic"
+ )
+
+ if not ext_info.is_aliased_class:
+ ac = orm_util.with_polymorphic(
+ ext_info.mapper.base_mapper,
+ ext_info.mapper,
+ aliased=True,
+ _use_mapper_path=True,
+ _existing_alias=inspect(existing)
+ if existing is not None
+ else None,
+ )
+
+ ext_info = inspect(ac)
+
+ path.entity_path[prop].set(
+ polymorphic_entity_context, "path_with_polymorphic", ac
+ )
+
+ path = path[prop][ext_info]
+
+ self._of_type = of_type_info
+
+ else:
+ path = path[prop]
+
+ if for_strategy is not None:
+ found_property._get_strategy(for_strategy)
+ if path.has_entity:
+ path = path.entity_path
+ self.path = path
+ return path
+
+ def __str__(self):
+ return "Load(strategy=%r)" % (self.strategy,)
+
+ def _coerce_strat(self, strategy):
+ if strategy is not None:
+ strategy = tuple(sorted(strategy.items()))
+ return strategy
+
+ def _apply_to_parent(self, parent, applied, bound):
+ raise NotImplementedError(
+ "Only 'unbound' loader options may be used with the "
+ "Load.options() method"
+ )
+
+ @_generative
+ def options(self, *opts):
+ r"""Apply a series of options as sub-options to this
+ :class:`_orm.Load`
+ object.
+
+ E.g.::
+
+ query = session.query(Author)
+ query = query.options(
+ joinedload(Author.book).options(
+ load_only(Book.summary, Book.excerpt),
+ joinedload(Book.citations).options(
+ joinedload(Citation.author)
+ )
+ )
+ )
+
+ :param \*opts: A series of loader option objects (ultimately
+ :class:`_orm.Load` objects) which should be applied to the path
+ specified by this :class:`_orm.Load` object.
+
+ .. versionadded:: 1.3.6
+
+ .. seealso::
+
+ :func:`.defaultload`
+
+ :ref:`relationship_loader_options`
+
+ :ref:`deferred_loading_w_multiple`
+
+ """
+ apply_cache = {}
+ bound = not isinstance(self, _UnboundLoad)
+ if bound:
+ raise NotImplementedError(
+ "The options() method is currently only supported "
+ "for 'unbound' loader options"
+ )
+ for opt in opts:
+ opt._apply_to_parent(self, apply_cache, bound)
+
+ @_generative
+ def set_relationship_strategy(
+ self, attr, strategy, propagate_to_loaders=True
+ ):
+ strategy = self._coerce_strat(strategy)
+ self.propagate_to_loaders = propagate_to_loaders
+ cloned = self._clone_for_bind_strategy(attr, strategy, "relationship")
+ self.path = cloned.path
+ self._of_type = cloned._of_type
+ self._extra_criteria = cloned._extra_criteria
+ cloned.is_class_strategy = self.is_class_strategy = False
+ self.propagate_to_loaders = cloned.propagate_to_loaders
+
+ @_generative
+ def set_column_strategy(self, attrs, strategy, opts=None, opts_only=False):
+ strategy = self._coerce_strat(strategy)
+ self.is_class_strategy = False
+ for attr in attrs:
+ cloned = self._clone_for_bind_strategy(
+ attr, strategy, "column", opts_only=opts_only, opts=opts
+ )
+ cloned.propagate_to_loaders = True
+
+ @_generative
+ def set_generic_strategy(self, attrs, strategy):
+ strategy = self._coerce_strat(strategy)
+ for attr in attrs:
+ cloned = self._clone_for_bind_strategy(attr, strategy, None)
+ cloned.propagate_to_loaders = True
+
+ @_generative
+ def set_class_strategy(self, strategy, opts):
+ strategy = self._coerce_strat(strategy)
+ cloned = self._clone_for_bind_strategy(None, strategy, None)
+ cloned.is_class_strategy = True
+ cloned.propagate_to_loaders = True
+ cloned.local_opts.update(opts)
+
+ def _clone_for_bind_strategy(
+ self, attr, strategy, wildcard_key, opts_only=False, opts=None
+ ):
+ """Create an anonymous clone of the Load/_UnboundLoad that is suitable
+ to be placed in the context / _to_bind collection of this Load
+ object. The clone will then lose references to context/_to_bind
+ in order to not create reference cycles.
+
+ """
+ cloned = self._generate()
+ cloned._generate_path(self.path, attr, strategy, wildcard_key)
+ cloned.strategy = strategy
+
+ cloned.local_opts = self.local_opts
+ if opts:
+ cloned.local_opts.update(opts)
+ if opts_only:
+ cloned.is_opts_only = True
+
+ if strategy or cloned.is_opts_only:
+ cloned._set_path_strategy()
+ return cloned
+
+ def _set_for_path(self, context, path, replace=True, merge_opts=False):
+ if merge_opts or not replace:
+ existing = path.get(context, "loader")
+ if existing:
+ if merge_opts:
+ existing.local_opts.update(self.local_opts)
+ existing._extra_criteria += self._extra_criteria
+ else:
+ path.set(context, "loader", self)
+ else:
+ existing = path.get(context, "loader")
+ path.set(context, "loader", self)
+ if existing and existing.is_opts_only:
+ self.local_opts.update(existing.local_opts)
+ existing._extra_criteria += self._extra_criteria
+
+ def _set_path_strategy(self):
+ if not self.is_class_strategy and self.path.has_entity:
+ effective_path = self.path.parent
+ else:
+ effective_path = self.path
+
+ if effective_path.is_token:
+ for path in effective_path.generate_for_superclasses():
+ self._set_for_path(
+ self.context,
+ path,
+ replace=True,
+ merge_opts=self.is_opts_only,
+ )
+ else:
+ self._set_for_path(
+ self.context,
+ effective_path,
+ replace=True,
+ merge_opts=self.is_opts_only,
+ )
+
+ # remove cycles; _set_path_strategy is always invoked on an
+ # anonymous clone of the Load / UnboundLoad object since #5056
+ self.context = None
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+
+ # can't pickle this right now; warning is raised by strategies
+ d["_extra_criteria"] = ()
+
+ if d["context"] is not None:
+ d["context"] = PathRegistry.serialize_context_dict(
+ d["context"], ("loader",)
+ )
+ d["path"] = self.path.serialize()
+ return d
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self.path = PathRegistry.deserialize(self.path)
+ if self.context is not None:
+ self.context = PathRegistry.deserialize_context_dict(self.context)
+
+ def _chop_path(self, to_chop, path):
+ i = -1
+
+ for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)):
+ if isinstance(c_token, util.string_types):
+ # TODO: this is approximated from the _UnboundLoad
+ # version and probably has issues, not fully covered.
+
+ if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN):
+ return to_chop
+ elif (
+ c_token != "relationship:%s" % (_WILDCARD_TOKEN,)
+ and c_token != p_token.key
+ ):
+ return None
+
+ if c_token is p_token:
+ continue
+ elif (
+ isinstance(c_token, InspectionAttr)
+ and c_token.is_mapper
+ and p_token.is_mapper
+ and c_token.isa(p_token)
+ ):
+ continue
+ else:
+ return None
+ return to_chop[i + 1 :]
+
+
+class _UnboundLoad(Load):
+ """Represent a loader option that isn't tied to a root entity.
+
+ The loader option will produce an entity-linked :class:`_orm.Load`
+ object when it is passed :meth:`_query.Query.options`.
+
+ This provides compatibility with the traditional system
+ of freestanding options, e.g. ``joinedload('x.y.z')``.
+
+ """
+
+ def __init__(self):
+ self.path = ()
+ self._to_bind = []
+ self.local_opts = {}
+ self._extra_criteria = ()
+
+ def _gen_cache_key(self, anon_map, bindparams, _unbound_option_seen=None):
+ """Inlined gen_cache_key
+
+ Original traversal is::
+
+
+ _cache_key_traversal = [
+ ("path", visitors.ExtendedInternalTraversal.dp_multi_list),
+ ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ (
+ "_to_bind",
+ visitors.ExtendedInternalTraversal.dp_has_cache_key_list,
+ ),
+ (
+ "_extra_criteria",
+ visitors.InternalTraversal.dp_clauseelement_list),
+ (
+ "local_opts",
+ visitors.ExtendedInternalTraversal.dp_string_multi_dict,
+ ),
+ ]
+
+ The inlining is so that the "_to_bind" list can be flattened to not
+ repeat the same UnboundLoad options over and over again.
+
+ See #6869
+
+ """
+
+ idself = id(self)
+ cls = self.__class__
+
+ if idself in anon_map:
+ return (anon_map[idself], cls)
+ else:
+ id_ = anon_map[idself]
+
+ vis = traversals._cache_key_traversal_visitor
+
+ seen = _unbound_option_seen
+ if seen is None:
+ seen = set()
+
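+        # set.add() returns None, so in the "elem not in seen and not
+        # seen.add(elem)" test used in the generator below, the second clause
+        # is always True for a new element and simply records it; the result
+        # is an order-preserving de-duplication of the _to_bind entries.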
+ return (
+ (id_, cls)
+ + vis.visit_multi_list(
+ "path", self.path, self, anon_map, bindparams
+ )
+ + ("strategy", self.strategy)
+ + (
+ (
+ "_to_bind",
+ tuple(
+ elem._gen_cache_key(
+ anon_map, bindparams, _unbound_option_seen=seen
+ )
+ for elem in self._to_bind
+ if elem not in seen and not seen.add(elem)
+ ),
+ )
+ if self._to_bind
+ else ()
+ )
+ + (
+ (
+ "_extra_criteria",
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in self._extra_criteria
+ ),
+ )
+ if self._extra_criteria
+ else ()
+ )
+ + (
+ vis.visit_string_multi_dict(
+ "local_opts", self.local_opts, self, anon_map, bindparams
+ )
+ if self.local_opts
+ else ()
+ )
+ )
+
+ _is_chain_link = False
+
+ def _set_path_strategy(self):
+ self._to_bind.append(self)
+
+ # remove cycles; _set_path_strategy is always invoked on an
+ # anonymous clone of the Load / UnboundLoad object since #5056
+ self._to_bind = None
+
+ def _deep_clone(self, applied, process):
+ if self in applied:
+ return applied[self]
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ cloned._to_bind = [
+ elem._deep_clone(applied, process) for elem in self._to_bind or ()
+ ]
+
+ cloned.local_opts.update(self.local_opts)
+
+ process(cloned)
+
+ return cloned
+
+ def _apply_to_parent(self, parent, applied, bound, to_bind=None):
+ if self in applied:
+ return applied[self]
+
+ if to_bind is None:
+ to_bind = self._to_bind
+
+ cloned = self._generate()
+
+ applied[self] = cloned
+
+ cloned.strategy = self.strategy
+ if self.path:
+ attr = self.path[-1]
+ if isinstance(attr, util.string_types) and attr.endswith(
+ _DEFAULT_TOKEN
+ ):
+ attr = attr.split(":")[0] + ":" + _WILDCARD_TOKEN
+ cloned._generate_path(
+ parent.path + self.path[0:-1], attr, self.strategy, None
+ )
+
+ # these assertions can go away once the "sub options" API is
+ # mature
+ assert cloned.propagate_to_loaders == self.propagate_to_loaders
+ assert cloned.is_class_strategy == self.is_class_strategy
+ assert cloned.is_opts_only == self.is_opts_only
+
+ uniq = set()
+
+ cloned._to_bind = parent._to_bind
+
+ cloned._to_bind[:] = [
+ elem
+ for elem in cloned._to_bind
+ if elem not in uniq and not uniq.add(elem)
+ ] + [
+ elem._apply_to_parent(parent, applied, bound, to_bind)
+ for elem in to_bind
+ if elem not in uniq and not uniq.add(elem)
+ ]
+
+ cloned.local_opts.update(self.local_opts)
+
+ return cloned
+
+ def _generate_path(self, path, attr, for_strategy, wildcard_key):
+ if (
+ wildcard_key
+ and isinstance(attr, util.string_types)
+ and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN)
+ ):
+ if attr == _DEFAULT_TOKEN:
+ self.propagate_to_loaders = False
+ attr = "%s:%s" % (wildcard_key, attr)
+ if path and _is_mapped_class(path[-1]) and not self.is_class_strategy:
+ path = path[0:-1]
+ if attr:
+ path = path + (attr,)
+ self.path = path
+ self._extra_criteria = getattr(attr, "_extra_criteria", ())
+
+ return path
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+
+ # can't pickle this right now; warning is raised by strategies
+ d["_extra_criteria"] = ()
+
+ d["path"] = self._serialize_path(self.path, filter_aliased_class=True)
+ return d
+
+ def __setstate__(self, state):
+ ret = []
+ for key in state["path"]:
+ if isinstance(key, tuple):
+ if len(key) == 2:
+ # support legacy
+ cls, propkey = key
+ of_type = None
+ else:
+ cls, propkey, of_type = key
+ prop = getattr(cls, propkey)
+ if of_type:
+ prop = prop.of_type(of_type)
+ ret.append(prop)
+ else:
+ ret.append(key)
+ state["path"] = tuple(ret)
+ self.__dict__ = state
+
+ def _process(self, compile_state, mapper_entities, raiseerr):
+ dedupes = compile_state.attributes["_unbound_load_dedupes"]
+ is_refresh = compile_state.compile_options._for_refresh_state
+ for val in self._to_bind:
+ if val not in dedupes:
+ dedupes.add(val)
+ if is_refresh and not val.propagate_to_loaders:
+ continue
+ val._bind_loader(
+ [ent.entity_zero for ent in mapper_entities],
+ compile_state.current_path,
+ compile_state.attributes,
+ raiseerr,
+ )
+
+ @classmethod
+ def _from_keys(cls, meth, keys, chained, kw):
+ opt = _UnboundLoad()
+
+ def _split_key(key):
+ if isinstance(key, util.string_types):
+ # coerce fooload('*') into "default loader strategy"
+ if key == _WILDCARD_TOKEN:
+ return (_DEFAULT_TOKEN,)
+ # coerce fooload(".*") into "wildcard on default entity"
+ elif key.startswith("." + _WILDCARD_TOKEN):
+ util.warn_deprecated(
+ "The undocumented `.{WILDCARD}` format is deprecated "
+ "and will be removed in a future version as it is "
+ "believed to be unused. "
+ "If you have been using this functionality, please "
+ "comment on Issue #4390 on the SQLAlchemy project "
+ "tracker.",
+ version="1.4",
+ )
+ key = key[1:]
+ return key.split(".")
+ else:
+ return (key,)
+
+ all_tokens = [token for key in keys for token in _split_key(key)]
+
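+        # When "chained" is False (illustratively,
+        # joinedload("orders.items")), only the final token receives the
+        # requested strategy and intermediate tokens are linked with
+        # defaultload(); when "chained" is True (as for contains_eager),
+        # every token receives the strategy.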
+ for token in all_tokens[0:-1]:
+ # set _is_chain_link first so that clones of the
+ # object also inherit this flag
+ opt._is_chain_link = True
+ if chained:
+ opt = meth(opt, token, **kw)
+ else:
+ opt = opt.defaultload(token)
+
+ opt = meth(opt, all_tokens[-1], **kw)
+ opt._is_chain_link = False
+ return opt
+
+ def _chop_path(self, to_chop, path):
+ i = -1
+ for i, (c_token, (p_entity, p_prop)) in enumerate(
+ zip(to_chop, path.pairs())
+ ):
+ if isinstance(c_token, util.string_types):
+ if i == 0 and c_token.endswith(":" + _DEFAULT_TOKEN):
+ return to_chop
+ elif (
+ c_token != "relationship:%s" % (_WILDCARD_TOKEN,)
+ and c_token != p_prop.key
+ ):
+ return None
+ elif isinstance(c_token, PropComparator):
+ if c_token.property is not p_prop or (
+ c_token._parententity is not p_entity
+ and (
+ not c_token._parententity.is_mapper
+ or not c_token._parententity.isa(p_entity)
+ )
+ ):
+ return None
+ else:
+ i += 1
+
+ return to_chop[i:]
+
+ def _serialize_path(self, path, filter_aliased_class=False):
+ ret = []
+ for token in path:
+ if isinstance(token, QueryableAttribute):
+ if (
+ filter_aliased_class
+ and token._of_type
+ and inspect(token._of_type).is_aliased_class
+ ):
+ ret.append((token._parentmapper.class_, token.key, None))
+ else:
+ ret.append(
+ (
+ token._parentmapper.class_,
+ token.key,
+ token._of_type.entity if token._of_type else None,
+ )
+ )
+ elif isinstance(token, PropComparator):
+ ret.append((token._parentmapper.class_, token.key, None))
+ else:
+ ret.append(token)
+ return ret
+
+ def _bind_loader(self, entities, current_path, context, raiseerr):
+ """Convert from an _UnboundLoad() object into a Load() object.
+
+ The _UnboundLoad() uses an informal "path" and does not necessarily
+ refer to a lead entity as it may use string tokens. The Load()
+ OTOH refers to a complete path. This method reconciles from a
+ given Query into a Load.
+
+ Example::
+
+
+ query = session.query(User).options(
+ joinedload("orders").joinedload("items"))
+
+ The above options will be an _UnboundLoad object along the lines
+ of (note this is not the exact API of _UnboundLoad)::
+
+ _UnboundLoad(
+ _to_bind=[
+ _UnboundLoad(["orders"], {"lazy": "joined"}),
+ _UnboundLoad(["orders", "items"], {"lazy": "joined"}),
+ ]
+ )
+
+ After this method, we get something more like this (again this is
+ not exact API)::
+
+ Load(
+ User,
+ (User, User.orders.property))
+ Load(
+ User,
+ (User, User.orders.property, Order, Order.items.property))
+
+ """
+
+ start_path = self.path
+
+ if self.is_class_strategy and current_path:
+ start_path += (entities[0],)
+
+ # _current_path implies we're in a
+ # secondary load with an existing path
+
+ if current_path:
+ start_path = self._chop_path(start_path, current_path)
+
+ if not start_path:
+ return None
+
+ # look at the first token and try to locate within the Query
+ # what entity we are referring towards.
+ token = start_path[0]
+
+ if isinstance(token, util.string_types):
+ entity = self._find_entity_basestring(entities, token, raiseerr)
+ elif isinstance(token, PropComparator):
+ prop = token.property
+ entity = self._find_entity_prop_comparator(
+ entities, prop, token._parententity, raiseerr
+ )
+ elif self.is_class_strategy and _is_mapped_class(token):
+ entity = inspect(token)
+ if entity not in entities:
+ entity = None
+ else:
+ raise sa_exc.ArgumentError(
+ "mapper option expects " "string key or list of attributes"
+ )
+
+ if not entity:
+ return
+
+ path_element = entity
+
+ # transfer our entity-less state into a Load() object
+ # with a real entity path. Start with the lead entity
+ # we just located, then go through the rest of our path
+ # tokens and populate into the Load().
+ loader = Load(path_element)
+
+ if context is None:
+ context = loader.context
+
+ loader.strategy = self.strategy
+ loader.is_opts_only = self.is_opts_only
+ loader.is_class_strategy = self.is_class_strategy
+ loader._extra_criteria = self._extra_criteria
+
+ path = loader.path
+
+ if not loader.is_class_strategy:
+ for idx, token in enumerate(start_path):
+ if not loader._generate_path(
+ loader.path,
+ token,
+ self.strategy if idx == len(start_path) - 1 else None,
+ None,
+ raiseerr,
+ polymorphic_entity_context=context,
+ ):
+ return
+
+ loader.local_opts.update(self.local_opts)
+
+ if not loader.is_class_strategy and loader.path.has_entity:
+ effective_path = loader.path.parent
+ else:
+ effective_path = loader.path
+
+ # prioritize "first class" options over those
+ # that were "links in the chain", e.g. "x" and "y" in
+ # someload("x.y.z") versus someload("x") / someload("x.y")
+
+ if effective_path.is_token:
+ for path in effective_path.generate_for_superclasses():
+ loader._set_for_path(
+ context,
+ path,
+ replace=not self._is_chain_link,
+ merge_opts=self.is_opts_only,
+ )
+ else:
+ loader._set_for_path(
+ context,
+ effective_path,
+ replace=not self._is_chain_link,
+ merge_opts=self.is_opts_only,
+ )
+
+ return loader
+
+ def _find_entity_prop_comparator(self, entities, prop, mapper, raiseerr):
+ if _is_aliased_class(mapper):
+ searchfor = mapper
+ else:
+ searchfor = _class_to_mapper(mapper)
+ for ent in entities:
+ if orm_util._entity_corresponds_to(ent, searchfor):
+ return ent
+ else:
+ if raiseerr:
+ if not list(entities):
+ raise sa_exc.ArgumentError(
+ "Query has only expression-based entities, "
+ 'which do not apply to %s "%s"'
+ % (util.clsname_as_plain_name(type(prop)), prop)
+ )
+ else:
+ raise sa_exc.ArgumentError(
+ 'Mapped attribute "%s" does not apply to any of the '
+ "root entities in this query, e.g. %s. Please "
+ "specify the full path "
+ "from one of the root entities to the target "
+ "attribute. "
+ % (prop, ", ".join(str(x) for x in entities))
+ )
+ else:
+ return None
+
+ def _find_entity_basestring(self, entities, token, raiseerr):
+ if token.endswith(":" + _WILDCARD_TOKEN):
+ if len(list(entities)) != 1:
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Can't apply wildcard ('*') or load_only() "
+ "loader option to multiple entities %s. Specify "
+ "loader options for each entity individually, such "
+ "as %s."
+ % (
+ ", ".join(str(ent) for ent in entities),
+ ", ".join(
+ "Load(%s).some_option('*')" % ent
+ for ent in entities
+ ),
+ )
+ )
+ elif token.endswith(_DEFAULT_TOKEN):
+ raiseerr = False
+
+ for ent in entities:
+ # return only the first _MapperEntity when searching
+ # based on string prop name. Ideally object
+ # attributes are used to specify more exactly.
+ return ent
+ else:
+ if raiseerr:
+ raise sa_exc.ArgumentError(
+ "Query has only expression-based entities - "
+ 'can\'t find property named "%s".' % (token,)
+ )
+ else:
+ return None
+
+
+class loader_option(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, fn):
+ self.name = name = fn.__name__
+ self.fn = fn
+ if hasattr(Load, name):
+ raise TypeError("Load class already has a %s method." % (name))
+ setattr(Load, name, fn)
+
+ return self
+
+ def _add_unbound_fn(self, fn):
+ self._unbound_fn = fn
+ fn_doc = self.fn.__doc__
+ self.fn.__doc__ = """Produce a new :class:`_orm.Load` object with the
+:func:`_orm.%(name)s` option applied.
+
+See :func:`_orm.%(name)s` for usage examples.
+
+""" % {
+ "name": self.name
+ }
+
+ fn.__doc__ = fn_doc
+ return self
+
+ def _add_unbound_all_fn(self, fn):
+ fn.__doc__ = """Produce a standalone "all" option for
+:func:`_orm.%(name)s`.
+
+.. deprecated:: 0.9
+
+ The :func:`_orm.%(name)s_all` function is deprecated, and will be removed
+ in a future release. Please use method chaining with
+ :func:`_orm.%(name)s` instead, as in::
+
+ session.query(MyClass).options(
+ %(name)s("someattribute").%(name)s("anotherattribute")
+ )
+
+""" % {
+ "name": self.name
+ }
+ fn = util.deprecated(
+            # This is used by `baked_lazyload_all`, which was only deprecated
+            # in version 1.2, so this must stick around until that is removed
+ "0.9",
+ "The :func:`.%(name)s_all` function is deprecated, and will be "
+ "removed in a future release. Please use method chaining with "
+ ":func:`.%(name)s` instead" % {"name": self.name},
+ add_deprecation_to_docstring=False,
+ )(fn)
+
+ self._unbound_all_fn = fn
+ return self
+
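+# Each loader option below is registered twice: the @loader_option()
+# decorated function becomes a method on Load (the "bound" form, e.g.
+# Load(SomeClass).joinedload(...)), while the function passed to
+# _add_unbound_fn builds an _UnboundLoad and backs the standalone form
+# (e.g. joinedload(...) used directly in Query.options()).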
+
+@loader_option()
+def contains_eager(loadopt, attr, alias=None):
+ r"""Indicate that the given attribute should be eagerly loaded from
+ columns stated manually in the query.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ The option is used in conjunction with an explicit join that loads
+ the desired rows, i.e.::
+
+ sess.query(Order).\
+ join(Order.user).\
+ options(contains_eager(Order.user))
+
+ The above query would join from the ``Order`` entity to its related
+ ``User`` entity, and the returned ``Order`` objects would have the
+ ``Order.user`` attribute pre-populated.
+
+ It may also be used for customizing the entries in an eagerly loaded
+ collection; queries will normally want to use the
+ :meth:`_query.Query.populate_existing` method assuming the primary
+ collection of parent objects may already have been loaded::
+
+ sess.query(User).\
+ join(User.addresses).\
+ filter(Address.email_address.like('%@aol.com')).\
+ options(contains_eager(User.addresses)).\
+ populate_existing()
+
+ See the section :ref:`contains_eager` for complete usage details.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`contains_eager`
+
+ """
+ if alias is not None:
+ if not isinstance(alias, str):
+ info = inspect(alias)
+ alias = info.selectable
+
+ else:
+ util.warn_deprecated(
+ "Passing a string name for the 'alias' argument to "
+ "'contains_eager()` is deprecated, and will not work in a "
+ "future release. Please use a sqlalchemy.alias() or "
+ "sqlalchemy.orm.aliased() construct.",
+ version="1.4",
+ )
+
+ elif getattr(attr, "_of_type", None):
+ ot = inspect(attr._of_type)
+ alias = ot.selectable
+
+ cloned = loadopt.set_relationship_strategy(
+ attr, {"lazy": "joined"}, propagate_to_loaders=False
+ )
+ cloned.local_opts["eager_from_alias"] = alias
+ return cloned
+
+
+@contains_eager._add_unbound_fn
+def contains_eager(*keys, **kw):
+ return _UnboundLoad()._from_keys(
+ _UnboundLoad.contains_eager, keys, True, kw
+ )
+
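+# Illustrative use of the "alias" argument with hypothetical User/Address
+# mappings:
+#
+#     adalias = aliased(Address)
+#     sess.query(User).outerjoin(adalias, User.addresses).options(
+#         contains_eager(User.addresses, alias=adalias)
+#     )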
+
+@loader_option()
+def load_only(loadopt, *attrs):
+ """Indicate that for a particular entity, only the given list
+ of column-based attribute names should be loaded; all others will be
+ deferred.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ Example - given a class ``User``, load only the ``name`` and ``fullname``
+ attributes::
+
+ session.query(User).options(load_only(User.name, User.fullname))
+
+ Example - given a relationship ``User.addresses -> Address``, specify
+ subquery loading for the ``User.addresses`` collection, but on each
+ ``Address`` object load only the ``email_address`` attribute::
+
+ session.query(User).options(
+ subqueryload(User.addresses).load_only(Address.email_address)
+ )
+
+ For a :class:`_query.Query` that has multiple entities,
+ the lead entity can be
+ specifically referred to using the :class:`_orm.Load` constructor::
+
+ session.query(User, Address).join(User.addresses).options(
+ Load(User).load_only(User.name, User.fullname),
+ Load(Address).load_only(Address.email_address)
+ )
+
+ .. note:: This method will still load a :class:`_schema.Column` even
+ if the column property is defined with ``deferred=True``
+ for the :func:`.column_property` function.
+
+ .. versionadded:: 0.9.0
+
+ """
+ cloned = loadopt.set_column_strategy(
+ attrs, {"deferred": False, "instrument": True}
+ )
+ cloned.set_column_strategy(
+ "*", {"deferred": True, "instrument": True}, {"undefer_pks": True}
+ )
+ return cloned
+
+
+@load_only._add_unbound_fn
+def load_only(*attrs):
+ return _UnboundLoad().load_only(*attrs)
+
+
+@loader_option()
+def joinedload(loadopt, attr, innerjoin=None):
+ """Indicate that the given attribute should be loaded using joined
+ eager loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ examples::
+
+ # joined-load the "orders" collection on "User"
+ query(User).options(joinedload(User.orders))
+
+ # joined-load Order.items and then Item.keywords
+ query(Order).options(
+ joinedload(Order.items).joinedload(Item.keywords))
+
+ # lazily load Order.items, but when Items are loaded,
+ # joined-load the keywords collection
+ query(Order).options(
+ lazyload(Order.items).joinedload(Item.keywords))
+
+ :param innerjoin: if ``True``, indicates that the joined eager load should
+ use an inner join instead of the default of left outer join::
+
+ query(Order).options(joinedload(Order.user, innerjoin=True))
+
+ In order to chain multiple eager joins together where some may be
+ OUTER and others INNER, right-nested joins are used to link them::
+
+ query(A).options(
+ joinedload(A.bs, innerjoin=False).
+ joinedload(B.cs, innerjoin=True)
+ )
+
+ The above query, linking A.bs via "outer" join and B.cs via "inner" join
+ would render the joins as "a LEFT OUTER JOIN (b JOIN c)". When using
+ older versions of SQLite (< 3.7.16), this form of JOIN is translated to
+ use full subqueries as this syntax is otherwise not directly supported.
+
+ The ``innerjoin`` flag can also be stated with the term ``"unnested"``.
+ This indicates that an INNER JOIN should be used, *unless* the join
+ is linked to a LEFT OUTER JOIN to the left, in which case it
+ will render as LEFT OUTER JOIN. For example, supposing ``A.bs``
+ is an outerjoin::
+
+ query(A).options(
+ joinedload(A.bs).
+ joinedload(B.cs, innerjoin="unnested")
+ )
+
+ The above join will render as "a LEFT OUTER JOIN b LEFT OUTER JOIN c",
+ rather than as "a LEFT OUTER JOIN (b JOIN c)".
+
+ .. note:: The "unnested" flag does **not** affect the JOIN rendered
+ from a many-to-many association table, e.g. a table configured
+ as :paramref:`_orm.relationship.secondary`, to the target table; for
+ correctness of results, these joins are always INNER and are
+ therefore right-nested if linked to an OUTER join.
+
+ .. versionchanged:: 1.0.0 ``innerjoin=True`` now implies
+ ``innerjoin="nested"``, whereas in 0.9 it implied
+ ``innerjoin="unnested"``. In order to achieve the pre-1.0 "unnested"
+ inner join behavior, use the value ``innerjoin="unnested"``.
+ See :ref:`migration_3008`.
+
+ .. note::
+
+ The joins produced by :func:`_orm.joinedload` are **anonymously
+ aliased**. The criteria by which the join proceeds cannot be
+ modified, nor can the :class:`_query.Query`
+ refer to these joins in any way,
+ including ordering. See :ref:`zen_of_eager_loading` for further
+ detail.
+
+ To produce a specific SQL JOIN which is explicitly available, use
+ :meth:`_query.Query.join`.
+ To combine explicit JOINs with eager loading
+ of collections, use :func:`_orm.contains_eager`; see
+ :ref:`contains_eager`.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`joined_eager_loading`
+
+ """
+ loader = loadopt.set_relationship_strategy(attr, {"lazy": "joined"})
+ if innerjoin is not None:
+ loader.local_opts["innerjoin"] = innerjoin
+ return loader
+
+
+@joinedload._add_unbound_fn
+def joinedload(*keys, **kw):
+ return _UnboundLoad._from_keys(_UnboundLoad.joinedload, keys, False, kw)
+
+
+@loader_option()
+def subqueryload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using
+ subquery eager loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ examples::
+
+ # subquery-load the "orders" collection on "User"
+ query(User).options(subqueryload(User.orders))
+
+ # subquery-load Order.items and then Item.keywords
+ query(Order).options(
+ subqueryload(Order.items).subqueryload(Item.keywords))
+
+ # lazily load Order.items, but when Items are loaded,
+ # subquery-load the keywords collection
+ query(Order).options(
+ lazyload(Order.items).subqueryload(Item.keywords))
+
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`subquery_eager_loading`
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "subquery"})
+
+
+@subqueryload._add_unbound_fn
+def subqueryload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.subqueryload, keys, False, {})
+
+
+@loader_option()
+def selectinload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using
+ SELECT IN eager loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ examples::
+
+ # selectin-load the "orders" collection on "User"
+ query(User).options(selectinload(User.orders))
+
+ # selectin-load Order.items and then Item.keywords
+ query(Order).options(
+ selectinload(Order.items).selectinload(Item.keywords))
+
+ # lazily load Order.items, but when Items are loaded,
+ # selectin-load the keywords collection
+ query(Order).options(
+ lazyload(Order.items).selectinload(Item.keywords))
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`selectin_eager_loading`
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "selectin"})
+
+
+@selectinload._add_unbound_fn
+def selectinload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.selectinload, keys, False, {})
+
+
+@loader_option()
+def lazyload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using "lazy"
+ loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`lazy_loading`
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "select"})
+
+
+@lazyload._add_unbound_fn
+def lazyload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.lazyload, keys, False, {})
+
+
+@loader_option()
+def immediateload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using
+ an immediate load with a per-attribute SELECT statement.
+
+ The load is achieved using the "lazyloader" strategy and does not
+ fire off any additional eager loaders.
+
+ The :func:`.immediateload` option is superseded in general
+ by the :func:`.selectinload` option, which performs the same task
+ more efficiently by emitting a SELECT for all loaded objects.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`selectin_eager_loading`
+
+ """
+ loader = loadopt.set_relationship_strategy(attr, {"lazy": "immediate"})
+ return loader
+
+
+@immediateload._add_unbound_fn
+def immediateload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.immediateload, keys, False, {})
+
+
+@loader_option()
+def noload(loadopt, attr):
+ """Indicate that the given relationship attribute should remain unloaded.
+
+ The relationship attribute will return ``None`` when accessed without
+ producing any loading effect.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ :func:`_orm.noload` applies to :func:`_orm.relationship` attributes; for
+ column-based attributes, see :func:`_orm.defer`.
+
+ .. note:: Setting this loading strategy as the default strategy
+ for a relationship using the :paramref:`.orm.relationship.lazy`
+       parameter may cause issues with flushes, such as when a delete
+       operation needs to load related objects and ``None`` is returned
+       instead.
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ """
+
+ return loadopt.set_relationship_strategy(attr, {"lazy": "noload"})
+
+
+@noload._add_unbound_fn
+def noload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.noload, keys, False, {})
+
+
+@loader_option()
+def raiseload(loadopt, attr, sql_only=False):
+ """Indicate that the given attribute should raise an error if accessed.
+
+ A relationship attribute configured with :func:`_orm.raiseload` will
+ raise an :exc:`~sqlalchemy.exc.InvalidRequestError` upon access. The
+ typical way this is useful is when an application is attempting to ensure
+ that all relationship attributes that are accessed in a particular context
+ would have been already loaded via eager loading. Instead of having
+ to read through SQL logs to ensure lazy loads aren't occurring, this
+ strategy will cause them to raise immediately.
+
+ :func:`_orm.raiseload` applies to :func:`_orm.relationship`
+ attributes only.
+ In order to apply raise-on-SQL behavior to a column-based attribute,
+ use the :paramref:`.orm.defer.raiseload` parameter on the :func:`.defer`
+ loader option.
+
+ :param sql_only: if True, raise only if the lazy load would emit SQL, but
+ not if it is only checking the identity map, or determining that the
+ related value should just be None due to missing keys. When False, the
+ strategy will raise for all varieties of relationship loading.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`loading_toplevel`
+
+ :ref:`prevent_lazy_with_raiseload`
+
+ :ref:`deferred_raiseload`
+
+ """
+
+ return loadopt.set_relationship_strategy(
+ attr, {"lazy": "raise_on_sql" if sql_only else "raise"}
+ )
+
+
+@raiseload._add_unbound_fn
+def raiseload(*keys, **kw):
+ return _UnboundLoad._from_keys(_UnboundLoad.raiseload, keys, False, kw)
+
+
+@loader_option()
+def defaultload(loadopt, attr):
+ """Indicate an attribute should load using its default loader style.
+
+ This method is used to link to other loader options further into
+ a chain of attributes without altering the loader style of the links
+ along the chain. For example, to set joined eager loading for an
+ element of an element::
+
+ session.query(MyClass).options(
+ defaultload(MyClass.someattribute).
+ joinedload(MyOtherClass.someotherattribute)
+ )
+
+ :func:`.defaultload` is also useful for setting column-level options
+ on a related class, namely that of :func:`.defer` and :func:`.undefer`::
+
+ session.query(MyClass).options(
+ defaultload(MyClass.someattribute).
+ defer("some_column").
+ undefer("some_other_column")
+ )
+
+ .. seealso::
+
+ :meth:`_orm.Load.options` - allows for complex hierarchical
+ loader option structures with less verbosity than with individual
+ :func:`.defaultload` directives.
+
+ :ref:`relationship_loader_options`
+
+ :ref:`deferred_loading_w_multiple`
+
+ """
+ return loadopt.set_relationship_strategy(attr, None)
+
+
+@defaultload._add_unbound_fn
+def defaultload(*keys):
+ return _UnboundLoad._from_keys(_UnboundLoad.defaultload, keys, False, {})
+
+
+@loader_option()
+def defer(loadopt, key, raiseload=False):
+ r"""Indicate that the given column-oriented attribute should be deferred,
+ e.g. not loaded until accessed.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ e.g.::
+
+ from sqlalchemy.orm import defer
+
+ session.query(MyClass).options(
+ defer("attribute_one"),
+ defer("attribute_two"))
+
+ session.query(MyClass).options(
+ defer(MyClass.attribute_one),
+ defer(MyClass.attribute_two))
+
+ To specify a deferred load of an attribute on a related class,
+ the path can be specified one token at a time, specifying the loading
+ style for each link along the chain. To leave the loading style
+ for a link unchanged, use :func:`_orm.defaultload`::
+
+ session.query(MyClass).options(defaultload("someattr").defer("some_column"))
+
+ A :class:`_orm.Load` object that is present on a certain path can have
+ :meth:`_orm.Load.defer` called multiple times,
+ each will operate on the same
+ parent entity::
+
+
+ session.query(MyClass).options(
+ defaultload("someattr").
+ defer("some_column").
+ defer("some_other_column").
+ defer("another_column")
+ )
+
+ :param key: Attribute to be deferred.
+
+    :param raiseload: raise :class:`.InvalidRequestError` if loading the
+       column value would require emitting SQL. Used to prevent unwanted
+       SQL from being emitted.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`deferred_raiseload`
+
+ :param \*addl_attrs: This option supports the old 0.8 style
+ of specifying a path as a series of attributes, which is now superseded
+ by the method-chained style.
+
+ .. deprecated:: 0.9 The \*addl_attrs on :func:`_orm.defer` is
+ deprecated and will be removed in a future release. Please
+ use method chaining in conjunction with defaultload() to
+ indicate a path.
+
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ :func:`_orm.undefer`
+
+ """
+ strategy = {"deferred": True, "instrument": True}
+ if raiseload:
+ strategy["raiseload"] = True
+ return loadopt.set_column_strategy((key,), strategy)
+
+
+@defer._add_unbound_fn
+def defer(key, *addl_attrs, **kw):
+ if addl_attrs:
+ util.warn_deprecated(
+ "The *addl_attrs on orm.defer is deprecated. Please use "
+ "method chaining in conjunction with defaultload() to "
+ "indicate a path.",
+ version="1.3",
+ )
+ return _UnboundLoad._from_keys(
+ _UnboundLoad.defer, (key,) + addl_attrs, False, kw
+ )
+
+
+@loader_option()
+def undefer(loadopt, key):
+ r"""Indicate that the given column-oriented attribute should be undeferred,
+ e.g. specified within the SELECT statement of the entity as a whole.
+
+ The column being undeferred is typically set up on the mapping as a
+ :func:`.deferred` attribute.
+
+ This function is part of the :class:`_orm.Load` interface and supports
+ both method-chained and standalone operation.
+
+ Examples::
+
+ # undefer two columns
+ session.query(MyClass).options(undefer("col1"), undefer("col2"))
+
+ # undefer all columns specific to a single class using Load + *
+ session.query(MyClass, MyOtherClass).options(
+ Load(MyClass).undefer("*"))
+
+ # undefer a column on a related object
+ session.query(MyClass).options(
+ defaultload(MyClass.items).undefer('text'))
+
+ :param key: Attribute to be undeferred.
+
+ :param \*addl_attrs: This option supports the old 0.8 style
+ of specifying a path as a series of attributes, which is now superseded
+ by the method-chained style.
+
+ .. deprecated:: 0.9 The \*addl_attrs on :func:`_orm.undefer` is
+ deprecated and will be removed in a future release. Please
+ use method chaining in conjunction with defaultload() to
+ indicate a path.
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ :func:`_orm.defer`
+
+ :func:`_orm.undefer_group`
+
+ """
+ return loadopt.set_column_strategy(
+ (key,), {"deferred": False, "instrument": True}
+ )
+
+
+@undefer._add_unbound_fn
+def undefer(key, *addl_attrs):
+ if addl_attrs:
+ util.warn_deprecated(
+ "The *addl_attrs on orm.undefer is deprecated. Please use "
+ "method chaining in conjunction with defaultload() to "
+ "indicate a path.",
+ version="1.3",
+ )
+ return _UnboundLoad._from_keys(
+ _UnboundLoad.undefer, (key,) + addl_attrs, False, {}
+ )
+
+
+@loader_option()
+def undefer_group(loadopt, name):
+ """Indicate that columns within the given deferred group name should be
+ undeferred.
+
+ The columns being undeferred are set up on the mapping as
+ :func:`.deferred` attributes and include a "group" name.
+
+    E.g.::
+
+ session.query(MyClass).options(undefer_group("large_attrs"))
+
+ To undefer a group of attributes on a related entity, the path can be
+ spelled out using relationship loader options, such as
+ :func:`_orm.defaultload`::
+
+ session.query(MyClass).options(
+ defaultload("someattr").undefer_group("large_attrs"))
+
+ .. versionchanged:: 0.9.0 :func:`_orm.undefer_group` is now specific to a
+ particular entity load path.
+
+ .. seealso::
+
+ :ref:`deferred`
+
+ :func:`_orm.defer`
+
+ :func:`_orm.undefer`
+
+ """
+ return loadopt.set_column_strategy(
+ "*", None, {"undefer_group_%s" % name: True}, opts_only=True
+ )
+
+
+@undefer_group._add_unbound_fn
+def undefer_group(name):
+ return _UnboundLoad().undefer_group(name)
+
+
+@loader_option()
+def with_expression(loadopt, key, expression):
+ r"""Apply an ad-hoc SQL expression to a "deferred expression" attribute.
+
+ This option is used in conjunction with the :func:`_orm.query_expression`
+ mapper-level construct that indicates an attribute which should be the
+ target of an ad-hoc SQL expression.
+
+ E.g.::
+
+
+ sess.query(SomeClass).options(
+ with_expression(SomeClass.x_y_expr, SomeClass.x + SomeClass.y)
+ )
+
+ .. versionadded:: 1.2
+
+    :param key: Attribute to be populated by the SQL expression.
+
+    :param expression: SQL expression to be applied to the attribute.
+
+ .. note:: the target attribute is populated only if the target object
+ is **not currently loaded** in the current :class:`_orm.Session`
+ unless the :meth:`_query.Query.populate_existing` method is used.
+ Please refer to :ref:`mapper_querytime_expression` for complete
+ usage details.
+
+ .. seealso::
+
+ :ref:`mapper_querytime_expression`
+
+ """
+
+ expression = coercions.expect(
+ roles.LabeledColumnExprRole, _orm_full_deannotate(expression)
+ )
+
+ return loadopt.set_column_strategy(
+ (key,), {"query_expression": True}, opts={"expression": expression}
+ )
+
+
+@with_expression._add_unbound_fn
+def with_expression(key, expression):
+ return _UnboundLoad._from_keys(
+ _UnboundLoad.with_expression, (key,), False, {"expression": expression}
+ )
+
+
+@loader_option()
+def selectin_polymorphic(loadopt, classes):
+ """Indicate an eager load should take place for all attributes
+ specific to a subclass.
+
+ This uses an additional SELECT with IN against all matched primary
+ key values, and is the per-query analogue to the ``"selectin"``
+ setting on the :paramref:`.mapper.polymorphic_load` parameter.
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`polymorphic_selectin`
+
+ """
+ loadopt.set_class_strategy(
+ {"selectinload_polymorphic": True},
+ opts={
+ "entities": tuple(
+ sorted((inspect(cls) for cls in classes), key=id)
+ )
+ },
+ )
+ return loadopt
+
+
+@selectin_polymorphic._add_unbound_fn
+def selectin_polymorphic(base_cls, classes):
+ ul = _UnboundLoad()
+ ul.is_class_strategy = True
+ ul.path = (inspect(base_cls),)
+ ul.selectin_polymorphic(classes)
+ return ul
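+
+# Illustrative usage with a hypothetical Employee/Manager inheritance
+# hierarchy:
+#
+#     session.query(Employee).options(
+#         selectin_polymorphic(Employee, [Manager])
+#     )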
diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
new file mode 100644
index 0000000..c041804
--- /dev/null
+++ b/lib/sqlalchemy/orm/sync.py
@@ -0,0 +1,167 @@
+# orm/sync.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""private module containing functions used for copying data
+between instances based on join conditions.
+
+"""
+
+from . import attributes
+from . import exc
+from . import util as orm_util
+from .. import util
+
+
+def populate(
+ source,
+ source_mapper,
+ dest,
+ dest_mapper,
+ synchronize_pairs,
+ uowcommit,
+ flag_cascaded_pks,
+):
+ source_dict = source.dict
+ dest_dict = dest.dict
+
+ for l, r in synchronize_pairs:
+ try:
+ # inline of source_mapper._get_state_attr_by_column
+ prop = source_mapper._columntoproperty[l]
+ value = source.manager[prop.key].impl.get(
+ source, source_dict, attributes.PASSIVE_OFF
+ )
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err)
+
+ try:
+ # inline of dest_mapper._set_state_attr_by_column
+ prop = dest_mapper._columntoproperty[r]
+ dest.manager[prop.key].impl.set(dest, dest_dict, value, None)
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err)
+
+ # technically the "r.primary_key" check isn't
+ # needed here, but we check for this condition to limit
+ # how often this logic is invoked for memory/performance
+ # reasons, since we only need this info for a primary key
+ # destination.
+ if (
+ flag_cascaded_pks
+ and l.primary_key
+ and r.primary_key
+ and r.references(l)
+ ):
+ uowcommit.attributes[("pk_cascaded", dest, r)] = True
+
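+# Illustrative note: for a hypothetical pair (parent.c.id, child.c.parent_id)
+# in synchronize_pairs, populate() reads the "id" attribute from the parent
+# state and writes it onto the child's "parent_id" attribute, so the flush
+# emits the child row with the correct foreign key value.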
+
+def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs):
+ # a simplified version of populate() used by bulk insert mode
+ for l, r in synchronize_pairs:
+ try:
+ prop = source_mapper._columntoproperty[l]
+ value = source_dict[prop.key]
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err)
+
+ try:
+ prop = source_mapper._columntoproperty[r]
+ source_dict[prop.key] = value
+        except exc.UnmappedColumnError as err:
+            _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err)
+
+
+def clear(dest, dest_mapper, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ if (
+ r.primary_key
+ and dest_mapper._get_state_attr_by_column(dest, dest.dict, r)
+ not in orm_util._none_set
+ ):
+
+ raise AssertionError(
+ "Dependency rule tried to blank-out primary key "
+ "column '%s' on instance '%s'" % (r, orm_util.state_str(dest))
+ )
+ try:
+ dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None)
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(True, None, l, dest_mapper, r, err)
+
+
+def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ try:
+ oldvalue = source_mapper._get_committed_attr_by_column(
+ source.obj(), l
+ )
+ value = source_mapper._get_state_attr_by_column(
+ source, source.dict, l, passive=attributes.PASSIVE_OFF
+ )
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, None, r, err)
+ dest[r.key] = value
+ dest[old_prefix + r.key] = oldvalue
+
+
+def populate_dict(source, source_mapper, dict_, synchronize_pairs):
+ for l, r in synchronize_pairs:
+ try:
+ value = source_mapper._get_state_attr_by_column(
+ source, source.dict, l, passive=attributes.PASSIVE_OFF
+ )
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, None, r, err)
+
+ dict_[r.key] = value
+
+
+def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
+ """return true if the source object has changes from an old to a
+ new value on the given synchronize pairs
+
+ """
+ for l, r in synchronize_pairs:
+ try:
+ prop = source_mapper._columntoproperty[l]
+ except exc.UnmappedColumnError as err:
+ _raise_col_to_prop(False, source_mapper, l, None, r, err)
+ history = uowcommit.get_attribute_history(
+ source, prop.key, attributes.PASSIVE_NO_INITIALIZE
+ )
+ if bool(history.deleted):
+ return True
+ else:
+ return False
+
+
+def _raise_col_to_prop(
+ isdest, source_mapper, source_column, dest_mapper, dest_column, err
+):
+ if isdest:
+ util.raise_(
+ exc.UnmappedColumnError(
+ "Can't execute sync rule for "
+ "destination column '%s'; mapper '%s' does not map "
+ "this column. Try using an explicit `foreign_keys` "
+ "collection which does not include this column (or use "
+ "a viewonly=True relation)." % (dest_column, dest_mapper)
+ ),
+ replace_context=err,
+ )
+ else:
+ util.raise_(
+ exc.UnmappedColumnError(
+ "Can't execute sync rule for "
+ "source column '%s'; mapper '%s' does not map this "
+ "column. Try using an explicit `foreign_keys` "
+ "collection which does not include destination column "
+ "'%s' (or use a viewonly=True relation)."
+ % (source_column, source_mapper, dest_column)
+ ),
+ replace_context=err,
+ )
diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
new file mode 100644
index 0000000..2257637
--- /dev/null
+++ b/lib/sqlalchemy/orm/unitofwork.py
@@ -0,0 +1,784 @@
+# orm/unitofwork.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The internals for the unit of work system.
+
+The session's flush() process passes objects to a contextual object
+here, which assembles flush tasks based on mappers and their properties,
+organizes them in order of dependency, and executes.
+
+"""
+
+from . import attributes
+from . import exc as orm_exc
+from . import util as orm_util
+from .. import event
+from .. import util
+from ..util import topological
+
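+# Rough shape of a flush, as coordinated here: Session.flush() creates a
+# UOWTransaction, register_object() collects the dirty/deleted states,
+# presort actions (Preprocess) expand that set via relationship cascades,
+# and the resulting PostSortRec actions are topologically sorted on their
+# dependencies and executed to emit the INSERT/UPDATE/DELETE statements.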
+
+def _warn_for_cascade_backrefs(state, prop):
+ util.warn_deprecated_20(
+ '"%s" object is being merged into a Session along the backref '
+ 'cascade path for relationship "%s"; in SQLAlchemy 2.0, this '
+ "reverse cascade will not take place. Set cascade_backrefs to "
+ "False in either the relationship() or backref() function for "
+ "the 2.0 behavior; or to set globally for the whole "
+ "Session, set the future=True flag" % (state.class_.__name__, prop),
+ code="s9r1",
+ )
+
+
+def track_cascade_events(descriptor, prop):
+ """Establish event listeners on object attributes which handle
+ cascade-on-set/append.
+
+ """
+ key = prop.key
+
+ def append(state, item, initiator):
+ # process "save_update" cascade rules for when
+ # an instance is appended to the list of another instance
+
+ if item is None:
+ return
+
+ sess = state.session
+ if sess:
+ if sess._warn_on_events:
+ sess._flush_warning("collection append")
+
+ prop = state.manager.mapper._props[key]
+ item_state = attributes.instance_state(item)
+
+ if (
+ prop._cascade.save_update
+ and (
+ (prop.cascade_backrefs and not sess.future)
+ or key == initiator.key
+ )
+ and not sess._contains_state(item_state)
+ ):
+ if key != initiator.key:
+ _warn_for_cascade_backrefs(item_state, prop)
+ sess._save_or_update_state(item_state)
+ return item
+
+ def remove(state, item, initiator):
+ if item is None:
+ return
+
+ sess = state.session
+
+ prop = state.manager.mapper._props[key]
+
+ if sess and sess._warn_on_events:
+ sess._flush_warning(
+ "collection remove"
+ if prop.uselist
+ else "related attribute delete"
+ )
+
+ if (
+ item is not None
+ and item is not attributes.NEVER_SET
+ and item is not attributes.PASSIVE_NO_RESULT
+ and prop._cascade.delete_orphan
+ ):
+ # expunge pending orphans
+ item_state = attributes.instance_state(item)
+
+ if prop.mapper._is_orphan(item_state):
+ if sess and item_state in sess._new:
+ sess.expunge(item)
+ else:
+ # the related item may or may not itself be in a
+ # Session, however the parent for which we are catching
+ # the event is not in a session, so memoize this on the
+ # item
+ item_state._orphaned_outside_of_session = True
+
+ def set_(state, newvalue, oldvalue, initiator):
+ # process "save_update" cascade rules for when an instance
+ # is attached to another instance
+ if oldvalue is newvalue:
+ return newvalue
+
+ sess = state.session
+ if sess:
+
+ if sess._warn_on_events:
+ sess._flush_warning("related attribute set")
+
+ prop = state.manager.mapper._props[key]
+ if newvalue is not None:
+ newvalue_state = attributes.instance_state(newvalue)
+ if (
+ prop._cascade.save_update
+ and (
+ (prop.cascade_backrefs and not sess.future)
+ or key == initiator.key
+ )
+ and not sess._contains_state(newvalue_state)
+ ):
+ if key != initiator.key:
+ _warn_for_cascade_backrefs(newvalue_state, prop)
+ sess._save_or_update_state(newvalue_state)
+
+ if (
+ oldvalue is not None
+ and oldvalue is not attributes.NEVER_SET
+ and oldvalue is not attributes.PASSIVE_NO_RESULT
+ and prop._cascade.delete_orphan
+ ):
+ # possible to reach here with attributes.NEVER_SET ?
+ oldvalue_state = attributes.instance_state(oldvalue)
+
+ if oldvalue_state in sess._new and prop.mapper._is_orphan(
+ oldvalue_state
+ ):
+ sess.expunge(oldvalue)
+ return newvalue
+
+ event.listen(descriptor, "append_wo_mutation", append, raw=True)
+ event.listen(descriptor, "append", append, raw=True, retval=True)
+ event.listen(descriptor, "remove", remove, raw=True, retval=True)
+ event.listen(descriptor, "set", set_, raw=True, retval=True)
+
+
+class UOWTransaction(object):
+ def __init__(self, session):
+ self.session = session
+
+ # dictionary used by external actors to
+ # store arbitrary state information.
+ self.attributes = {}
+
+ # dictionary of mappers to sets of
+ # DependencyProcessors, which are also
+ # set to be part of the sorted flush actions,
+ # which have that mapper as a parent.
+ self.deps = util.defaultdict(set)
+
+ # dictionary of mappers to sets of InstanceState
+ # items pending for flush which have that mapper
+ # as a parent.
+ self.mappers = util.defaultdict(set)
+
+ # a dictionary of Preprocess objects, which gather
+ # additional states impacted by the flush
+ # and determine if a flush action is needed
+ self.presort_actions = {}
+
+ # dictionary of PostSortRec objects, each
+ # one issues work during the flush within
+ # a certain ordering.
+ self.postsort_actions = {}
+
+ # a set of 2-tuples, each containing two
+ # PostSortRec objects where the second
+ # is dependent on the first being executed
+ # first
+ self.dependencies = set()
+
+ # dictionary of InstanceState-> (isdelete, listonly)
+ # tuples, indicating if this state is to be deleted
+ # or insert/updated, or just refreshed
+ self.states = {}
+
+ # tracks InstanceStates which will be receiving
+ # a "post update" call. Keys are mappers,
+ # values are a set of states and a set of the
+ # columns which should be included in the update.
+ self.post_update_states = util.defaultdict(lambda: (set(), set()))
+
+ @property
+ def has_work(self):
+ return bool(self.states)
+
+ def was_already_deleted(self, state):
+ """Return ``True`` if the given state is expired and was deleted
+ previously.
+ """
+ if state.expired:
+ try:
+ state._load_expired(state, attributes.PASSIVE_OFF)
+ except orm_exc.ObjectDeletedError:
+ self.session._remove_newly_deleted([state])
+ return True
+ return False
+
+ def is_deleted(self, state):
+ """Return ``True`` if the given state is marked as deleted
+ within this uowtransaction."""
+
+ return state in self.states and self.states[state][0]
+
+ def memo(self, key, callable_):
+ if key in self.attributes:
+ return self.attributes[key]
+ else:
+ self.attributes[key] = ret = callable_()
+ return ret
+
+ def remove_state_actions(self, state):
+ """Remove pending actions for a state from the uowtransaction."""
+
+ isdelete = self.states[state][0]
+
+ self.states[state] = (isdelete, True)
+
+ def get_attribute_history(
+ self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE
+ ):
+ """Facade to attributes.get_state_history(), including
+ caching of results."""
+
+ hashkey = ("history", state, key)
+
+ # cache the objects, not the states; the strong reference here
+ # prevents newly loaded objects from being dereferenced during the
+ # flush process
+
+ if hashkey in self.attributes:
+ history, state_history, cached_passive = self.attributes[hashkey]
+ # if the cached lookup was "passive" and now
+ # we want non-passive, do a non-passive lookup and re-cache
+
+ if (
+ not cached_passive & attributes.SQL_OK
+ and passive & attributes.SQL_OK
+ ):
+ impl = state.manager[key].impl
+ history = impl.get_history(
+ state,
+ state.dict,
+ attributes.PASSIVE_OFF
+ | attributes.LOAD_AGAINST_COMMITTED
+ | attributes.NO_RAISE,
+ )
+ if history and impl.uses_objects:
+ state_history = history.as_state()
+ else:
+ state_history = history
+ self.attributes[hashkey] = (history, state_history, passive)
+ else:
+ impl = state.manager[key].impl
+ # TODO: store the history as (state, object) tuples
+ # so we don't have to keep converting here
+ history = impl.get_history(
+ state,
+ state.dict,
+ passive
+ | attributes.LOAD_AGAINST_COMMITTED
+ | attributes.NO_RAISE,
+ )
+ if history and impl.uses_objects:
+ state_history = history.as_state()
+ else:
+ state_history = history
+ self.attributes[hashkey] = (history, state_history, passive)
+
+ return state_history
+
+ def has_dep(self, processor):
+ return (processor, True) in self.presort_actions
+
+ def register_preprocessor(self, processor, fromparent):
+ key = (processor, fromparent)
+ if key not in self.presort_actions:
+ self.presort_actions[key] = Preprocess(processor, fromparent)
+
+ def register_object(
+ self,
+ state,
+ isdelete=False,
+ listonly=False,
+ cancel_delete=False,
+ operation=None,
+ prop=None,
+ ):
+ if not self.session._contains_state(state):
+ # this condition is normal when objects are registered
+ # as part of a relationship cascade operation. it should
+ # not occur for the top-level register from Session.flush().
+ if not state.deleted and operation is not None:
+ util.warn(
+ "Object of type %s not in session, %s operation "
+ "along '%s' will not proceed"
+ % (orm_util.state_class_str(state), operation, prop)
+ )
+ return False
+
+ if state not in self.states:
+ mapper = state.manager.mapper
+
+ if mapper not in self.mappers:
+ self._per_mapper_flush_actions(mapper)
+
+ self.mappers[mapper].add(state)
+ self.states[state] = (isdelete, listonly)
+ else:
+ if not listonly and (isdelete or cancel_delete):
+ self.states[state] = (isdelete, False)
+ return True
+
+ def register_post_update(self, state, post_update_cols):
+ mapper = state.manager.mapper.base_mapper
+ states, cols = self.post_update_states[mapper]
+ states.add(state)
+ cols.update(post_update_cols)
+
+ def _per_mapper_flush_actions(self, mapper):
+ saves = SaveUpdateAll(self, mapper.base_mapper)
+ deletes = DeleteAll(self, mapper.base_mapper)
+ self.dependencies.add((saves, deletes))
+
+ for dep in mapper._dependency_processors:
+ dep.per_property_preprocessors(self)
+
+ for prop in mapper.relationships:
+ if prop.viewonly:
+ continue
+ dep = prop._dependency_processor
+ dep.per_property_preprocessors(self)
+
+ @util.memoized_property
+ def _mapper_for_dep(self):
+ """return a dynamic mapping of (Mapper, DependencyProcessor) to
+ True or False, indicating if the DependencyProcessor operates
+ on objects of that Mapper.
+
+ The result is stored in the dictionary persistently once
+ calculated.
+
+ """
+ return util.PopulateDict(
+ lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop
+ )
+
+ def filter_states_for_dep(self, dep, states):
+ """Filter the given list of InstanceStates to those relevant to the
+ given DependencyProcessor.
+
+ """
+ mapper_for_dep = self._mapper_for_dep
+ return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]]
+
+ def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
+ checktup = (isdelete, listonly)
+ for mapper in mapper.base_mapper.self_and_descendants:
+ for state in self.mappers[mapper]:
+ if self.states[state] == checktup:
+ yield state
+
+ def _generate_actions(self):
+ """Generate the full, unsorted collection of PostSortRecs as
+ well as dependency pairs for this UOWTransaction.
+
+ """
+ # execute presort_actions, until all states
+ # have been processed. a presort_action might
+ # add new states to the uow.
+ while True:
+ ret = False
+ for action in list(self.presort_actions.values()):
+ if action.execute(self):
+ ret = True
+ if not ret:
+ break
+
+ # see if the graph of mapper dependencies has cycles.
+ self.cycles = cycles = topological.find_cycles(
+ self.dependencies, list(self.postsort_actions.values())
+ )
+
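+        # A hypothetical self-referential mapping (e.g. an Employee row whose
+        # "manager_id" points to another Employee) makes a mapper depend on
+        # itself; in that case the per-mapper records are expanded into
+        # per-state records below so that individual rows can still be
+        # ordered.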
+ if cycles:
+ # if yes, break the per-mapper actions into
+ # per-state actions
+ convert = dict(
+ (rec, set(rec.per_state_flush_actions(self))) for rec in cycles
+ )
+
+ # rewrite the existing dependencies to point to
+ # the per-state actions for those per-mapper actions
+ # that were broken up.
+ for edge in list(self.dependencies):
+ if (
+ None in edge
+ or edge[0].disabled
+ or edge[1].disabled
+ or cycles.issuperset(edge)
+ ):
+ self.dependencies.remove(edge)
+ elif edge[0] in cycles:
+ self.dependencies.remove(edge)
+ for dep in convert[edge[0]]:
+ self.dependencies.add((dep, edge[1]))
+ elif edge[1] in cycles:
+ self.dependencies.remove(edge)
+ for dep in convert[edge[1]]:
+ self.dependencies.add((edge[0], dep))
+
+ return set(
+ [a for a in self.postsort_actions.values() if not a.disabled]
+ ).difference(cycles)
+
+ def execute(self):
+ postsort_actions = self._generate_actions()
+
+ postsort_actions = sorted(
+ postsort_actions,
+ key=lambda item: item.sort_key,
+ )
+ # sort = topological.sort(self.dependencies, postsort_actions)
+ # print "--------------"
+ # print "\ndependencies:", self.dependencies
+ # print "\ncycles:", self.cycles
+ # print "\nsort:", list(sort)
+ # print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions)
+
+ # execute
+ if self.cycles:
+ for subset in topological.sort_as_subsets(
+ self.dependencies, postsort_actions
+ ):
+ set_ = set(subset)
+ while set_:
+ n = set_.pop()
+ n.execute_aggregate(self, set_)
+ else:
+ for rec in topological.sort(self.dependencies, postsort_actions):
+ rec.execute(self)
+
+ def finalize_flush_changes(self):
+ """Mark processed objects as clean / deleted after a successful
+ flush().
+
+ This method is called within the flush() method after the
+ execute() method has succeeded and the transaction has been committed.
+
+ """
+ if not self.states:
+ return
+
+ states = set(self.states)
+ isdel = set(
+ s for (s, (isdelete, listonly)) in self.states.items() if isdelete
+ )
+ other = states.difference(isdel)
+ if isdel:
+ self.session._remove_newly_deleted(isdel)
+ if other:
+ self.session._register_persistent(other)
+
+
+class IterateMappersMixin(object):
+ def _mappers(self, uow):
+ if self.fromparent:
+ return iter(
+ m
+ for m in self.dependency_processor.parent.self_and_descendants
+ if uow._mapper_for_dep[(m, self.dependency_processor)]
+ )
+ else:
+ return self.dependency_processor.mapper.self_and_descendants
+
+
+class Preprocess(IterateMappersMixin):
+ __slots__ = (
+ "dependency_processor",
+ "fromparent",
+ "processed",
+ "setup_flush_actions",
+ )
+
+ def __init__(self, dependency_processor, fromparent):
+ self.dependency_processor = dependency_processor
+ self.fromparent = fromparent
+ self.processed = set()
+ self.setup_flush_actions = False
+
+ def execute(self, uow):
+ delete_states = set()
+ save_states = set()
+
+ for mapper in self._mappers(uow):
+ for state in uow.mappers[mapper].difference(self.processed):
+ (isdelete, listonly) = uow.states[state]
+ if not listonly:
+ if isdelete:
+ delete_states.add(state)
+ else:
+ save_states.add(state)
+
+ if delete_states:
+ self.dependency_processor.presort_deletes(uow, delete_states)
+ self.processed.update(delete_states)
+ if save_states:
+ self.dependency_processor.presort_saves(uow, save_states)
+ self.processed.update(save_states)
+
+ if delete_states or save_states:
+ if not self.setup_flush_actions and (
+ self.dependency_processor.prop_has_changes(
+ uow, delete_states, True
+ )
+ or self.dependency_processor.prop_has_changes(
+ uow, save_states, False
+ )
+ ):
+ self.dependency_processor.per_property_flush_actions(uow)
+ self.setup_flush_actions = True
+ return True
+ else:
+ return False
+
+
+class PostSortRec(object):
+ __slots__ = ("disabled",)
+
+ def __new__(cls, uow, *args):
+ key = (cls,) + args
+ if key in uow.postsort_actions:
+ return uow.postsort_actions[key]
+ else:
+ uow.postsort_actions[key] = ret = object.__new__(cls)
+ ret.disabled = False
+ return ret
+
+ def execute_aggregate(self, uow, recs):
+ self.execute(uow)
+
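+# Illustrative note (sketch, not part of the library source): PostSortRec
+# memoizes its instances per UOWTransaction, keyed on (class, *args), so
+# constructing the "same" record twice yields one shared object; e.g. with a
+# hypothetical base mapper ``some_mapper``:
+#
+#     rec1 = SaveUpdateAll(uow, some_mapper)
+#     rec2 = SaveUpdateAll(uow, some_mapper)
+#     assert rec1 is rec2
+#
+# Dependency tuples added to uow.dependencies therefore refer to a single
+# record per key, which is what the topological sort relies on.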
+
+class ProcessAll(IterateMappersMixin, PostSortRec):
+ __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key"
+
+ def __init__(self, uow, dependency_processor, isdelete, fromparent):
+ self.dependency_processor = dependency_processor
+ self.sort_key = (
+ "ProcessAll",
+ self.dependency_processor.sort_key,
+ isdelete,
+ )
+ self.isdelete = isdelete
+ self.fromparent = fromparent
+ uow.deps[dependency_processor.parent.base_mapper].add(
+ dependency_processor
+ )
+
+ def execute(self, uow):
+ states = self._elements(uow)
+ if self.isdelete:
+ self.dependency_processor.process_deletes(uow, states)
+ else:
+ self.dependency_processor.process_saves(uow, states)
+
+ def per_state_flush_actions(self, uow):
+ # this is handled by SaveUpdateAll and DeleteAll,
+ # since a ProcessAll should unconditionally be pulled
+ # into per-state if either the parent/child mappers
+ # are part of a cycle
+ return iter([])
+
+ def __repr__(self):
+ return "%s(%s, isdelete=%s)" % (
+ self.__class__.__name__,
+ self.dependency_processor,
+ self.isdelete,
+ )
+
+ def _elements(self, uow):
+ for mapper in self._mappers(uow):
+ for state in uow.mappers[mapper]:
+ (isdelete, listonly) = uow.states[state]
+ if isdelete == self.isdelete and not listonly:
+ yield state
+
+
+class PostUpdateAll(PostSortRec):
+ __slots__ = "mapper", "isdelete", "sort_key"
+
+ def __init__(self, uow, mapper, isdelete):
+ self.mapper = mapper
+ self.isdelete = isdelete
+ self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete)
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute(self, uow):
+ persistence = util.preloaded.orm_persistence
+ states, cols = uow.post_update_states[self.mapper]
+ states = [s for s in states if uow.states[s][0] == self.isdelete]
+
+ persistence.post_update(self.mapper, states, uow, cols)
+
+
+class SaveUpdateAll(PostSortRec):
+ __slots__ = ("mapper", "sort_key")
+
+ def __init__(self, uow, mapper):
+ self.mapper = mapper
+ self.sort_key = ("SaveUpdateAll", mapper._sort_key)
+ assert mapper is mapper.base_mapper
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute(self, uow):
+ util.preloaded.orm_persistence.save_obj(
+ self.mapper,
+ uow.states_for_mapper_hierarchy(self.mapper, False, False),
+ uow,
+ )
+
+ def per_state_flush_actions(self, uow):
+ states = list(
+ uow.states_for_mapper_hierarchy(self.mapper, False, False)
+ )
+ base_mapper = self.mapper.base_mapper
+ delete_all = DeleteAll(uow, base_mapper)
+ for state in states:
+ # keep saves before deletes -
+ # this ensures 'row switch' operations work
+ action = SaveUpdateState(uow, state)
+ uow.dependencies.add((action, delete_all))
+ yield action
+
+ for dep in uow.deps[self.mapper]:
+ states_for_prop = uow.filter_states_for_dep(dep, states)
+ dep.per_state_flush_actions(uow, states_for_prop, False)
+
+ def __repr__(self):
+ return "%s(%s)" % (self.__class__.__name__, self.mapper)
+
+
+class DeleteAll(PostSortRec):
+ __slots__ = ("mapper", "sort_key")
+
+ def __init__(self, uow, mapper):
+ self.mapper = mapper
+ self.sort_key = ("DeleteAll", mapper._sort_key)
+ assert mapper is mapper.base_mapper
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute(self, uow):
+ util.preloaded.orm_persistence.delete_obj(
+ self.mapper,
+ uow.states_for_mapper_hierarchy(self.mapper, True, False),
+ uow,
+ )
+
+ def per_state_flush_actions(self, uow):
+ states = list(
+ uow.states_for_mapper_hierarchy(self.mapper, True, False)
+ )
+ base_mapper = self.mapper.base_mapper
+ save_all = SaveUpdateAll(uow, base_mapper)
+ for state in states:
+ # keep saves before deletes -
+ # this ensures 'row switch' operations work
+ action = DeleteState(uow, state)
+ uow.dependencies.add((save_all, action))
+ yield action
+
+ for dep in uow.deps[self.mapper]:
+ states_for_prop = uow.filter_states_for_dep(dep, states)
+ dep.per_state_flush_actions(uow, states_for_prop, True)
+
+ def __repr__(self):
+ return "%s(%s)" % (self.__class__.__name__, self.mapper)
+
+
+class ProcessState(PostSortRec):
+ __slots__ = "dependency_processor", "isdelete", "state", "sort_key"
+
+ def __init__(self, uow, dependency_processor, isdelete, state):
+ self.dependency_processor = dependency_processor
+ self.sort_key = ("ProcessState", dependency_processor.sort_key)
+ self.isdelete = isdelete
+ self.state = state
+
+ def execute_aggregate(self, uow, recs):
+ cls_ = self.__class__
+ dependency_processor = self.dependency_processor
+ isdelete = self.isdelete
+ our_recs = [
+ r
+ for r in recs
+ if r.__class__ is cls_
+ and r.dependency_processor is dependency_processor
+ and r.isdelete is isdelete
+ ]
+ recs.difference_update(our_recs)
+ states = [self.state] + [r.state for r in our_recs]
+ if isdelete:
+ dependency_processor.process_deletes(uow, states)
+ else:
+ dependency_processor.process_saves(uow, states)
+
+ def __repr__(self):
+ return "%s(%s, %s, delete=%s)" % (
+ self.__class__.__name__,
+ self.dependency_processor,
+ orm_util.state_str(self.state),
+ self.isdelete,
+ )
+
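+# Illustrative note: when cycles force per-state records, the
+# execute_aggregate() hooks (above, and in SaveUpdateState / DeleteState
+# below) pull every sibling record of the same class out of the pending set
+# and issue a single batched process_saves() / process_deletes() /
+# save_obj() / delete_obj() call, rather than one call per state.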
+
+class SaveUpdateState(PostSortRec):
+ __slots__ = "state", "mapper", "sort_key"
+
+ def __init__(self, uow, state):
+ self.state = state
+ self.mapper = state.mapper.base_mapper
+ self.sort_key = ("ProcessState", self.mapper._sort_key)
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute_aggregate(self, uow, recs):
+ persistence = util.preloaded.orm_persistence
+ cls_ = self.__class__
+ mapper = self.mapper
+ our_recs = [
+ r for r in recs if r.__class__ is cls_ and r.mapper is mapper
+ ]
+ recs.difference_update(our_recs)
+ persistence.save_obj(
+ mapper, [self.state] + [r.state for r in our_recs], uow
+ )
+
+ def __repr__(self):
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ orm_util.state_str(self.state),
+ )
+
+
+class DeleteState(PostSortRec):
+ __slots__ = "state", "mapper", "sort_key"
+
+ def __init__(self, uow, state):
+ self.state = state
+ self.mapper = state.mapper.base_mapper
+ self.sort_key = ("DeleteState", self.mapper._sort_key)
+
+ @util.preload_module("sqlalchemy.orm.persistence")
+ def execute_aggregate(self, uow, recs):
+ persistence = util.preloaded.orm_persistence
+ cls_ = self.__class__
+ mapper = self.mapper
+ our_recs = [
+ r for r in recs if r.__class__ is cls_ and r.mapper is mapper
+ ]
+ recs.difference_update(our_recs)
+ states = [self.state] + [r.state for r in our_recs]
+ persistence.delete_obj(
+ mapper, [s for s in states if uow.states[s][0]], uow
+ )
+
+ def __repr__(self):
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ orm_util.state_str(self.state),
+ )
diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py
new file mode 100644
index 0000000..56aa9ff
--- /dev/null
+++ b/lib/sqlalchemy/orm/util.py
@@ -0,0 +1,2149 @@
+# orm/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+import re
+import types
+import weakref
+
+from . import attributes # noqa
+from .base import _class_to_mapper # noqa
+from .base import _never_set # noqa
+from .base import _none_set # noqa
+from .base import attribute_str # noqa
+from .base import class_mapper # noqa
+from .base import InspectionAttr # noqa
+from .base import instance_str # noqa
+from .base import object_mapper # noqa
+from .base import object_state # noqa
+from .base import state_attribute_str # noqa
+from .base import state_class_str # noqa
+from .base import state_str # noqa
+from .interfaces import CriteriaOption
+from .interfaces import MapperProperty # noqa
+from .interfaces import ORMColumnsClauseRole
+from .interfaces import ORMEntityColumnsClauseRole
+from .interfaces import ORMFromClauseRole
+from .interfaces import PropComparator # noqa
+from .path_registry import PathRegistry # noqa
+from .. import event
+from .. import exc as sa_exc
+from .. import inspection
+from .. import sql
+from .. import util
+from ..engine.result import result_tuple
+from ..sql import base as sql_base
+from ..sql import coercions
+from ..sql import expression
+from ..sql import lambdas
+from ..sql import roles
+from ..sql import util as sql_util
+from ..sql import visitors
+from ..sql.annotation import SupportsCloneAnnotations
+from ..sql.base import ColumnCollection
+
+
+all_cascades = frozenset(
+ (
+ "delete",
+ "delete-orphan",
+ "all",
+ "merge",
+ "expunge",
+ "save-update",
+ "refresh-expire",
+ "none",
+ )
+)
+
+
+class CascadeOptions(frozenset):
+ """Keeps track of the options sent to
+ :paramref:`.relationship.cascade`"""
+
+ _add_w_all_cascades = all_cascades.difference(
+ ["all", "none", "delete-orphan"]
+ )
+ _allowed_cascades = all_cascades
+
+ _viewonly_cascades = ["expunge", "all", "none", "refresh-expire"]
+
+ __slots__ = (
+ "save_update",
+ "delete",
+ "refresh_expire",
+ "merge",
+ "expunge",
+ "delete_orphan",
+ )
+
+ def __new__(cls, value_list):
+ if isinstance(value_list, util.string_types) or value_list is None:
+ return cls.from_string(value_list)
+ values = set(value_list)
+ if values.difference(cls._allowed_cascades):
+ raise sa_exc.ArgumentError(
+ "Invalid cascade option(s): %s"
+ % ", ".join(
+ [
+ repr(x)
+ for x in sorted(
+ values.difference(cls._allowed_cascades)
+ )
+ ]
+ )
+ )
+
+ if "all" in values:
+ values.update(cls._add_w_all_cascades)
+ if "none" in values:
+ values.clear()
+ values.discard("all")
+
+ self = frozenset.__new__(CascadeOptions, values)
+ self.save_update = "save-update" in values
+ self.delete = "delete" in values
+ self.refresh_expire = "refresh-expire" in values
+ self.merge = "merge" in values
+ self.expunge = "expunge" in values
+ self.delete_orphan = "delete-orphan" in values
+
+ if self.delete_orphan and not self.delete:
+ util.warn(
+ "The 'delete-orphan' cascade " "option requires 'delete'."
+ )
+ return self
+
+ def __repr__(self):
+ return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)]))
+
+ @classmethod
+ def from_string(cls, arg):
+ values = [c for c in re.split(r"\s*,\s*", arg or "") if c]
+ return cls(values)
+
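+# Illustrative sketch of CascadeOptions parsing (shown only to clarify the
+# expansion of "all" above; not part of the module's API surface):
+#
+#     opts = CascadeOptions("all, delete-orphan")
+#     assert opts.save_update and opts.merge and opts.delete
+#     assert opts.delete_orphan and "delete-orphan" in opts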
+
+def _validator_events(desc, key, validator, include_removes, include_backrefs):
+ """Runs a validation method on an attribute value to be set or
+ appended.
+ """
+
+ if not include_backrefs:
+
+ def detect_is_backref(state, initiator):
+ impl = state.manager[key].impl
+ return initiator.impl is not impl
+
+ if include_removes:
+
+ def append(state, value, initiator):
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
+ ):
+ return validator(state.obj(), key, value, False)
+ else:
+ return value
+
+ def bulk_set(state, values, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ obj = state.obj()
+ values[:] = [
+ validator(obj, key, value, False) for value in values
+ ]
+
+ def set_(state, value, oldvalue, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ return validator(state.obj(), key, value, False)
+ else:
+ return value
+
+ def remove(state, value, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ validator(state.obj(), key, value, True)
+
+ else:
+
+ def append(state, value, initiator):
+ if initiator.op is not attributes.OP_BULK_REPLACE and (
+ include_backrefs or not detect_is_backref(state, initiator)
+ ):
+ return validator(state.obj(), key, value)
+ else:
+ return value
+
+ def bulk_set(state, values, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ obj = state.obj()
+ values[:] = [validator(obj, key, value) for value in values]
+
+ def set_(state, value, oldvalue, initiator):
+ if include_backrefs or not detect_is_backref(state, initiator):
+ return validator(state.obj(), key, value)
+ else:
+ return value
+
+ event.listen(desc, "append", append, raw=True, retval=True)
+ event.listen(desc, "bulk_replace", bulk_set, raw=True)
+ event.listen(desc, "set", set_, raw=True, retval=True)
+ if include_removes:
+ event.listen(desc, "remove", remove, raw=True, retval=True)
+
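+# Illustrative note: the listeners above are the machinery behind the
+# ``@validates()`` decorator; a hypothetical mapping such as
+#
+#     class User(Base):
+#         __tablename__ = "user"
+#         ...
+#
+#         @validates("addresses", include_removes=True)
+#         def _validate_address(self, key, address, is_remove):
+#             ...
+#             return address
+#
+# ends up routed through the append/set_/remove hooks defined here.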
+
+def polymorphic_union(
+ table_map, typecolname, aliasname="p_union", cast_nulls=True
+):
+ """Create a ``UNION`` statement used by a polymorphic mapper.
+
+ See :ref:`concrete_inheritance` for an example of how
+ this is used.
+
+ :param table_map: mapping of polymorphic identities to
+ :class:`_schema.Table` objects.
+ :param typecolname: string name of a "discriminator" column, which will be
+ derived from the query, producing the polymorphic identity for
+ each row. If ``None``, no polymorphic discriminator is generated.
+ :param aliasname: name of the :func:`~sqlalchemy.sql.expression.alias()`
+ construct generated.
+ :param cast_nulls: if True, non-existent columns, which are represented
+ as labeled NULLs, will be passed into CAST. This is a legacy behavior
+ that is problematic on some backends such as Oracle - in which case it
+ can be set to False.
+
+ """
+
+ colnames = util.OrderedSet()
+ colnamemaps = {}
+ types = {}
+ for key in table_map:
+ table = table_map[key]
+
+ table = coercions.expect(
+ roles.StrictFromClauseRole, table, allow_select=True
+ )
+ table_map[key] = table
+
+ m = {}
+ for c in table.c:
+ if c.key == typecolname:
+ raise sa_exc.InvalidRequestError(
+ "Polymorphic union can't use '%s' as the discriminator "
+ "column due to mapped column %r; please apply the "
+ "'typecolname' "
+ "argument; this is available on "
+ "ConcreteBase as '_concrete_discriminator_name'"
+ % (typecolname, c)
+ )
+ colnames.add(c.key)
+ m[c.key] = c
+ types[c.key] = c.type
+ colnamemaps[table] = m
+
+ def col(name, table):
+ try:
+ return colnamemaps[table][name]
+ except KeyError:
+ if cast_nulls:
+ return sql.cast(sql.null(), types[name]).label(name)
+ else:
+ return sql.type_coerce(sql.null(), types[name]).label(name)
+
+ result = []
+ for type_, table in table_map.items():
+ if typecolname is not None:
+ result.append(
+ sql.select(
+ *(
+ [col(name, table) for name in colnames]
+ + [
+ sql.literal_column(
+ sql_util._quote_ddl_expr(type_)
+ ).label(typecolname)
+ ]
+ )
+ ).select_from(table)
+ )
+ else:
+ result.append(
+ sql.select(
+ *[col(name, table) for name in colnames]
+ ).select_from(table)
+ )
+ return sql.union_all(*result).alias(aliasname)
+
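+# Illustrative sketch (``employee_table`` / ``manager_table`` are
+# hypothetical concrete-inheritance tables): build a UNION with a "type"
+# discriminator column using the function above:
+#
+#     pjoin = polymorphic_union(
+#         {"employee": employee_table, "manager": manager_table},
+#         "type",
+#         aliasname="pjoin",
+#     )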
+
+def identity_key(*args, **kwargs):
+ r"""Generate "identity key" tuples, as are used as keys in the
+ :attr:`.Session.identity_map` dictionary.
+
+ This function has several call styles:
+
+ * ``identity_key(class, ident, identity_token=token)``
+
+ This form receives a mapped class and a primary key scalar or
+ tuple as an argument.
+
+ E.g.::
+
+ >>> identity_key(MyClass, (1, 2))
+ (<class '__main__.MyClass'>, (1, 2), None)
+
+ :param class: mapped class (must be a positional argument)
+ :param ident: primary key, may be a scalar or tuple argument.
+ :param identity_token: optional identity token
+
+ .. versionadded:: 1.2 added identity_token
+
+
+ * ``identity_key(instance=instance)``
+
+ This form will produce the identity key for a given instance. The
+ instance need not be persistent, only that its primary key attributes
+ are populated (else the key will contain ``None`` for those missing
+ values).
+
+ E.g.::
+
+ >>> instance = MyClass(1, 2)
+ >>> identity_key(instance=instance)
+ (<class '__main__.MyClass'>, (1, 2), None)
+
+    In this form, the given instance is ultimately run through
+ :meth:`_orm.Mapper.identity_key_from_instance`, which will have the
+ effect of performing a database check for the corresponding row
+ if the object is expired.
+
+ :param instance: object instance (must be given as a keyword arg)
+
+ * ``identity_key(class, row=row, identity_token=token)``
+
+ This form is similar to the class/tuple form, except is passed a
+ database result row as a :class:`.Row` object.
+
+ E.g.::
+
+ >>> row = engine.execute(\
+ text("select * from table where a=1 and b=2")\
+ ).first()
+ >>> identity_key(MyClass, row=row)
+ (<class '__main__.MyClass'>, (1, 2), None)
+
+ :param class: mapped class (must be a positional argument)
+ :param row: :class:`.Row` row returned by a :class:`_engine.CursorResult`
+ (must be given as a keyword arg)
+ :param identity_token: optional identity token
+
+ .. versionadded:: 1.2 added identity_token
+
+ """
+ if args:
+ row = None
+ largs = len(args)
+ if largs == 1:
+ class_ = args[0]
+ try:
+ row = kwargs.pop("row")
+ except KeyError:
+ ident = kwargs.pop("ident")
+ elif largs in (2, 3):
+ class_, ident = args
+ else:
+ raise sa_exc.ArgumentError(
+ "expected up to three positional arguments, " "got %s" % largs
+ )
+
+ identity_token = kwargs.pop("identity_token", None)
+ if kwargs:
+ raise sa_exc.ArgumentError(
+ "unknown keyword arguments: %s" % ", ".join(kwargs)
+ )
+ mapper = class_mapper(class_)
+ if row is None:
+ return mapper.identity_key_from_primary_key(
+ util.to_list(ident), identity_token=identity_token
+ )
+ else:
+ return mapper.identity_key_from_row(
+ row, identity_token=identity_token
+ )
+ else:
+ instance = kwargs.pop("instance")
+ if kwargs:
+ raise sa_exc.ArgumentError(
+ "unknown keyword arguments: %s" % ", ".join(kwargs.keys)
+ )
+ mapper = object_mapper(instance)
+ return mapper.identity_key_from_instance(instance)
+
+
+class ORMAdapter(sql_util.ColumnAdapter):
+ """ColumnAdapter subclass which excludes adaptation of entities from
+ non-matching mappers.
+
+ """
+
+ def __init__(
+ self,
+ entity,
+ equivalents=None,
+ adapt_required=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
+ info = inspection.inspect(entity)
+
+ self.mapper = info.mapper
+ selectable = info.selectable
+ is_aliased_class = info.is_aliased_class
+ if is_aliased_class:
+ self.aliased_class = entity
+ else:
+ self.aliased_class = None
+
+ sql_util.ColumnAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ adapt_required=adapt_required,
+ allow_label_resolve=allow_label_resolve,
+ anonymize_labels=anonymize_labels,
+ include_fn=self._include_fn,
+ )
+
+ def _include_fn(self, elem):
+ entity = elem._annotations.get("parentmapper", None)
+
+ return not entity or entity.isa(self.mapper) or self.mapper.isa(entity)
+
+
+class AliasedClass(object):
+ r"""Represents an "aliased" form of a mapped class for usage with Query.
+
+ The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias`
+ construct, this object mimics the mapped class using a
+ ``__getattr__`` scheme and maintains a reference to a
+ real :class:`~sqlalchemy.sql.expression.Alias` object.
+
+ A primary purpose of :class:`.AliasedClass` is to serve as an alternate
+ within a SQL statement generated by the ORM, such that an existing
+ mapped entity can be used in multiple contexts. A simple example::
+
+ # find all pairs of users with the same name
+ user_alias = aliased(User)
+ session.query(User, user_alias).\
+ join((user_alias, User.id > user_alias.id)).\
+ filter(User.name == user_alias.name)
+
+ :class:`.AliasedClass` is also capable of mapping an existing mapped
+ class to an entirely new selectable, provided this selectable is column-
+ compatible with the existing mapped selectable, and it can also be
+ configured in a mapping as the target of a :func:`_orm.relationship`.
+ See the links below for examples.
+
+ The :class:`.AliasedClass` object is constructed typically using the
+ :func:`_orm.aliased` function. It also is produced with additional
+ configuration when using the :func:`_orm.with_polymorphic` function.
+
+ The resulting object is an instance of :class:`.AliasedClass`.
+ This object implements an attribute scheme which produces the
+ same attribute and method interface as the original mapped
+ class, allowing :class:`.AliasedClass` to be compatible
+ with any attribute technique which works on the original class,
+ including hybrid attributes (see :ref:`hybrids_toplevel`).
+
+ The :class:`.AliasedClass` can be inspected for its underlying
+ :class:`_orm.Mapper`, aliased selectable, and other information
+ using :func:`_sa.inspect`::
+
+ from sqlalchemy import inspect
+ my_alias = aliased(MyClass)
+ insp = inspect(my_alias)
+
+ The resulting inspection object is an instance of :class:`.AliasedInsp`.
+
+
+ .. seealso::
+
+ :func:`.aliased`
+
+ :func:`.with_polymorphic`
+
+ :ref:`relationship_aliased_class`
+
+ :ref:`relationship_to_window_function`
+
+
+ """
+
+ def __init__(
+ self,
+ mapped_class_or_ac,
+ alias=None,
+ name=None,
+ flat=False,
+ adapt_on_names=False,
+ # TODO: None for default here?
+ with_polymorphic_mappers=(),
+ with_polymorphic_discriminator=None,
+ base_alias=None,
+ use_mapper_path=False,
+ represents_outer_join=False,
+ ):
+ insp = inspection.inspect(mapped_class_or_ac)
+ mapper = insp.mapper
+
+ nest_adapters = False
+
+ if alias is None:
+ if insp.is_aliased_class and insp.selectable._is_subquery:
+ alias = insp.selectable.alias()
+ else:
+ alias = (
+ mapper._with_polymorphic_selectable._anonymous_fromclause(
+ name=name,
+ flat=flat,
+ )
+ )
+ elif insp.is_aliased_class:
+ nest_adapters = True
+
+ self._aliased_insp = AliasedInsp(
+ self,
+ insp,
+ alias,
+ name,
+ with_polymorphic_mappers
+ if with_polymorphic_mappers
+ else mapper.with_polymorphic_mappers,
+ with_polymorphic_discriminator
+ if with_polymorphic_discriminator is not None
+ else mapper.polymorphic_on,
+ base_alias,
+ use_mapper_path,
+ adapt_on_names,
+ represents_outer_join,
+ nest_adapters,
+ )
+
+ self.__name__ = "AliasedClass_%s" % mapper.class_.__name__
+
+ @classmethod
+ def _reconstitute_from_aliased_insp(cls, aliased_insp):
+ obj = cls.__new__(cls)
+ obj.__name__ = "AliasedClass_%s" % aliased_insp.mapper.class_.__name__
+ obj._aliased_insp = aliased_insp
+
+ if aliased_insp._is_with_polymorphic:
+ for sub_aliased_insp in aliased_insp._with_polymorphic_entities:
+ if sub_aliased_insp is not aliased_insp:
+ ent = AliasedClass._reconstitute_from_aliased_insp(
+ sub_aliased_insp
+ )
+ setattr(obj, sub_aliased_insp.class_.__name__, ent)
+
+ return obj
+
+ def __getattr__(self, key):
+ try:
+ _aliased_insp = self.__dict__["_aliased_insp"]
+ except KeyError:
+ raise AttributeError()
+ else:
+ target = _aliased_insp._target
+ # maintain all getattr mechanics
+ attr = getattr(target, key)
+
+ # attribute is a method, that will be invoked against a
+ # "self"; so just return a new method with the same function and
+ # new self
+ if hasattr(attr, "__call__") and hasattr(attr, "__self__"):
+ return types.MethodType(attr.__func__, self)
+
+ # attribute is a descriptor, that will be invoked against a
+ # "self"; so invoke the descriptor against this self
+ if hasattr(attr, "__get__"):
+ attr = attr.__get__(None, self)
+
+ # attributes within the QueryableAttribute system will want this
+ # to be invoked so the object can be adapted
+ if hasattr(attr, "adapt_to_entity"):
+ attr = attr.adapt_to_entity(_aliased_insp)
+ setattr(self, key, attr)
+
+ return attr
+
+ def _get_from_serialized(self, key, mapped_class, aliased_insp):
+ # this method is only used in terms of the
+ # sqlalchemy.ext.serializer extension
+ attr = getattr(mapped_class, key)
+ if hasattr(attr, "__call__") and hasattr(attr, "__self__"):
+ return types.MethodType(attr.__func__, self)
+
+ # attribute is a descriptor, that will be invoked against a
+ # "self"; so invoke the descriptor against this self
+ if hasattr(attr, "__get__"):
+ attr = attr.__get__(None, self)
+
+ # attributes within the QueryableAttribute system will want this
+ # to be invoked so the object can be adapted
+ if hasattr(attr, "adapt_to_entity"):
+ aliased_insp._weak_entity = weakref.ref(self)
+ attr = attr.adapt_to_entity(aliased_insp)
+ setattr(self, key, attr)
+
+ return attr
+
+ def __repr__(self):
+ return "<AliasedClass at 0x%x; %s>" % (
+ id(self),
+ self._aliased_insp._target.__name__,
+ )
+
+ def __str__(self):
+ return str(self._aliased_insp)
+
+
+class AliasedInsp(
+ ORMEntityColumnsClauseRole,
+ ORMFromClauseRole,
+ sql_base.MemoizedHasCacheKey,
+ InspectionAttr,
+):
+ """Provide an inspection interface for an
+ :class:`.AliasedClass` object.
+
+ The :class:`.AliasedInsp` object is returned
+ given an :class:`.AliasedClass` using the
+ :func:`_sa.inspect` function::
+
+ from sqlalchemy import inspect
+ from sqlalchemy.orm import aliased
+
+ my_alias = aliased(MyMappedClass)
+ insp = inspect(my_alias)
+
+ Attributes on :class:`.AliasedInsp`
+ include:
+
+ * ``entity`` - the :class:`.AliasedClass` represented.
+ * ``mapper`` - the :class:`_orm.Mapper` mapping the underlying class.
+ * ``selectable`` - the :class:`_expression.Alias`
+ construct which ultimately
+ represents an aliased :class:`_schema.Table` or
+ :class:`_expression.Select`
+ construct.
+ * ``name`` - the name of the alias. Also is used as the attribute
+ name when returned in a result tuple from :class:`_query.Query`.
+ * ``with_polymorphic_mappers`` - collection of :class:`_orm.Mapper`
+ objects
+ indicating all those mappers expressed in the select construct
+ for the :class:`.AliasedClass`.
+ * ``polymorphic_on`` - an alternate column or SQL expression which
+ will be used as the "discriminator" for a polymorphic load.
+
+ .. seealso::
+
+ :ref:`inspection_toplevel`
+
+ """
+
+ def __init__(
+ self,
+ entity,
+ inspected,
+ selectable,
+ name,
+ with_polymorphic_mappers,
+ polymorphic_on,
+ _base_alias,
+ _use_mapper_path,
+ adapt_on_names,
+ represents_outer_join,
+ nest_adapters,
+ ):
+
+ mapped_class_or_ac = inspected.entity
+ mapper = inspected.mapper
+
+ self._weak_entity = weakref.ref(entity)
+ self.mapper = mapper
+ self.selectable = (
+ self.persist_selectable
+ ) = self.local_table = selectable
+ self.name = name
+ self.polymorphic_on = polymorphic_on
+ self._base_alias = weakref.ref(_base_alias or self)
+ self._use_mapper_path = _use_mapper_path
+ self.represents_outer_join = represents_outer_join
+ self._nest_adapters = nest_adapters
+
+ if with_polymorphic_mappers:
+ self._is_with_polymorphic = True
+ self.with_polymorphic_mappers = with_polymorphic_mappers
+ self._with_polymorphic_entities = []
+ for poly in self.with_polymorphic_mappers:
+ if poly is not mapper:
+ ent = AliasedClass(
+ poly.class_,
+ selectable,
+ base_alias=self,
+ adapt_on_names=adapt_on_names,
+ use_mapper_path=_use_mapper_path,
+ )
+
+ setattr(self.entity, poly.class_.__name__, ent)
+ self._with_polymorphic_entities.append(ent._aliased_insp)
+
+ else:
+ self._is_with_polymorphic = False
+ self.with_polymorphic_mappers = [mapper]
+
+ self._adapter = sql_util.ColumnAdapter(
+ selectable,
+ equivalents=mapper._equivalent_columns,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=True,
+ # make sure the adapter doesn't try to grab other tables that
+ # are not even the thing we are mapping, such as embedded
+ # selectables in subqueries or CTEs. See issue #6060
+ adapt_from_selectables={
+ m.selectable
+ for m in self.with_polymorphic_mappers
+ if not adapt_on_names
+ },
+ )
+
+ if nest_adapters:
+ self._adapter = inspected._adapter.wrap(self._adapter)
+
+ self._adapt_on_names = adapt_on_names
+ self._target = mapped_class_or_ac
+ # self._target = mapper.class_ # mapped_class_or_ac
+
+ @property
+ def entity(self):
+ # to eliminate reference cycles, the AliasedClass is held weakly.
+ # this produces some situations where the AliasedClass gets lost,
+ # particularly when one is created internally and only the AliasedInsp
+ # is passed around.
+ # to work around this case, we just generate a new one when we need
+ # it, as it is a simple class with very little initial state on it.
+ ent = self._weak_entity()
+ if ent is None:
+ ent = AliasedClass._reconstitute_from_aliased_insp(self)
+ self._weak_entity = weakref.ref(ent)
+ return ent
+
+ is_aliased_class = True
+ "always returns True"
+
+ @util.memoized_instancemethod
+ def __clause_element__(self):
+ return self.selectable._annotate(
+ {
+ "parentmapper": self.mapper,
+ "parententity": self,
+ "entity_namespace": self,
+ }
+ )._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ @property
+ def entity_namespace(self):
+ return self.entity
+
+ _cache_key_traversal = [
+ ("name", visitors.ExtendedInternalTraversal.dp_string),
+ ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean),
+ ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement),
+ ]
+
+ @property
+ def class_(self):
+ """Return the mapped class ultimately represented by this
+ :class:`.AliasedInsp`."""
+ return self.mapper.class_
+
+ @property
+ def _path_registry(self):
+ if self._use_mapper_path:
+ return self.mapper._path_registry
+ else:
+ return PathRegistry.per_mapper(self)
+
+ def __getstate__(self):
+ return {
+ "entity": self.entity,
+ "mapper": self.mapper,
+ "alias": self.selectable,
+ "name": self.name,
+ "adapt_on_names": self._adapt_on_names,
+ "with_polymorphic_mappers": self.with_polymorphic_mappers,
+ "with_polymorphic_discriminator": self.polymorphic_on,
+ "base_alias": self._base_alias(),
+ "use_mapper_path": self._use_mapper_path,
+ "represents_outer_join": self.represents_outer_join,
+ "nest_adapters": self._nest_adapters,
+ }
+
+ def __setstate__(self, state):
+ self.__init__(
+ state["entity"],
+ state["mapper"],
+ state["alias"],
+ state["name"],
+ state["with_polymorphic_mappers"],
+ state["with_polymorphic_discriminator"],
+ state["base_alias"],
+ state["use_mapper_path"],
+ state["adapt_on_names"],
+ state["represents_outer_join"],
+ state["nest_adapters"],
+ )
+
+ def _adapt_element(self, elem, key=None):
+ d = {
+ "parententity": self,
+ "parentmapper": self.mapper,
+ }
+ if key:
+ d["proxy_key"] = key
+ return (
+ self._adapter.traverse(elem)
+ ._annotate(d)
+ ._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+ )
+
+ def _entity_for_mapper(self, mapper):
+ self_poly = self.with_polymorphic_mappers
+ if mapper in self_poly:
+ if mapper is self.mapper:
+ return self
+ else:
+ return getattr(
+ self.entity, mapper.class_.__name__
+ )._aliased_insp
+ elif mapper.isa(self.mapper):
+ return self
+ else:
+ assert False, "mapper %s doesn't correspond to %s" % (mapper, self)
+
+ @util.memoized_property
+ def _get_clause(self):
+ onclause, replacemap = self.mapper._get_clause
+ return (
+ self._adapter.traverse(onclause),
+ {
+ self._adapter.traverse(col): param
+ for col, param in replacemap.items()
+ },
+ )
+
+ @util.memoized_property
+ def _memoized_values(self):
+ return {}
+
+ @util.memoized_property
+ def _all_column_expressions(self):
+ if self._is_with_polymorphic:
+ cols_plus_keys = self.mapper._columns_plus_keys(
+ [ent.mapper for ent in self._with_polymorphic_entities]
+ )
+ else:
+ cols_plus_keys = self.mapper._columns_plus_keys()
+
+ cols_plus_keys = [
+ (key, self._adapt_element(col)) for key, col in cols_plus_keys
+ ]
+
+ return ColumnCollection(cols_plus_keys)
+
+ def _memo(self, key, callable_, *args, **kw):
+ if key in self._memoized_values:
+ return self._memoized_values[key]
+ else:
+ self._memoized_values[key] = value = callable_(*args, **kw)
+ return value
+
+ def __repr__(self):
+ if self.with_polymorphic_mappers:
+ with_poly = "(%s)" % ", ".join(
+ mp.class_.__name__ for mp in self.with_polymorphic_mappers
+ )
+ else:
+ with_poly = ""
+ return "<AliasedInsp at 0x%x; %s%s>" % (
+ id(self),
+ self.class_.__name__,
+ with_poly,
+ )
+
+ def __str__(self):
+ if self._is_with_polymorphic:
+ return "with_polymorphic(%s, [%s])" % (
+ self._target.__name__,
+ ", ".join(
+ mp.class_.__name__
+ for mp in self.with_polymorphic_mappers
+ if mp is not self.mapper
+ ),
+ )
+ else:
+ return "aliased(%s)" % (self._target.__name__,)
+
+
+class _WrapUserEntity(object):
+ """A wrapper used within the loader_criteria lambda caller so that
+ we can bypass declared_attr descriptors on unmapped mixins, which
+ normally emit a warning for such use.
+
+ might also be useful for other per-lambda instrumentations should
+ the need arise.
+
+ """
+
+ __slots__ = ("subject",)
+
+ def __init__(self, subject):
+ self.subject = subject
+
+ @util.preload_module("sqlalchemy.orm.decl_api")
+ def __getattribute__(self, name):
+ decl_api = util.preloaded.orm.decl_api
+
+ subject = object.__getattribute__(self, "subject")
+ if name in subject.__dict__ and isinstance(
+ subject.__dict__[name], decl_api.declared_attr
+ ):
+ return subject.__dict__[name].fget(subject)
+ else:
+ return getattr(subject, name)
+
+
+class LoaderCriteriaOption(CriteriaOption):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ :class:`_orm.LoaderCriteriaOption` is invoked using the
+ :func:`_orm.with_loader_criteria` function; see that function for
+ details.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _traverse_internals = [
+ ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
+ ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+ ("where_criteria", visitors.InternalTraversal.dp_clauseelement),
+ ("include_aliases", visitors.InternalTraversal.dp_boolean),
+ ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
+ ]
+
+ def __init__(
+ self,
+ entity_or_base,
+ where_criteria,
+ loader_only=False,
+ include_aliases=False,
+ propagate_to_loaders=True,
+ track_closure_variables=True,
+ ):
+ """Add additional WHERE criteria to the load for all occurrences of
+ a particular entity.
+
+ .. versionadded:: 1.4
+
+ The :func:`_orm.with_loader_criteria` option is intended to add
+ limiting criteria to a particular kind of entity in a query,
+ **globally**, meaning it will apply to the entity as it appears
+ in the SELECT query as well as within any subqueries, join
+ conditions, and relationship loads, including both eager and lazy
+ loaders, without the need for it to be specified in any particular
+ part of the query. The rendering logic uses the same system used by
+ single table inheritance to ensure a certain discriminator is applied
+ to a table.
+
+ E.g., using :term:`2.0-style` queries, we can limit the way the
+ ``User.addresses`` collection is loaded, regardless of the kind
+ of loading used::
+
+ from sqlalchemy.orm import with_loader_criteria
+
+ stmt = select(User).options(
+ selectinload(User.addresses),
+            with_loader_criteria(Address, Address.email_address != 'foo'),
+        )
+
+ Above, the "selectinload" for ``User.addresses`` will apply the
+ given filtering criteria to the WHERE clause.
+
+ Another example, where the filtering will be applied to the
+ ON clause of the join, in this example using :term:`1.x style`
+ queries::
+
+ q = session.query(User).outerjoin(User.addresses).options(
+                with_loader_criteria(Address, Address.email_address != 'foo')
+            )
+
+ The primary purpose of :func:`_orm.with_loader_criteria` is to use
+ it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler
+ to ensure that all occurrences of a particular entity are filtered
+ in a certain way, such as filtering for access control roles. It
+ also can be used to apply criteria to relationship loads. In the
+ example below, we can apply a certain set of rules to all queries
+ emitted by a particular :class:`_orm.Session`::
+
+ session = Session(bind=engine)
+
+        @event.listens_for(session, "do_orm_execute")
+ def _add_filtering_criteria(execute_state):
+
+ if (
+ execute_state.is_select
+ and not execute_state.is_column_load
+ and not execute_state.is_relationship_load
+ ):
+ execute_state.statement = execute_state.statement.options(
+ with_loader_criteria(
+ SecurityRole,
+ lambda cls: cls.role.in_(['some_role']),
+ include_aliases=True
+ )
+ )
+
+ In the above example, the :meth:`_orm.SessionEvents.do_orm_execute`
+ event will intercept all queries emitted using the
+ :class:`_orm.Session`. For those queries which are SELECT statements
+ and are not attribute or relationship loads a custom
+ :func:`_orm.with_loader_criteria` option is added to the query. The
+ :func:`_orm.with_loader_criteria` option will be used in the given
+ statement and will also be automatically propagated to all relationship
+ loads that descend from this query.
+
+ The criteria argument given is a ``lambda`` that accepts a ``cls``
+        argument. The given class will expand to include all mapped subclasses
+        and need not itself be a mapped class.
+
+ .. tip::
+
+ When using :func:`_orm.with_loader_criteria` option in
+ conjunction with the :func:`_orm.contains_eager` loader option,
+ it's important to note that :func:`_orm.with_loader_criteria` only
+ affects the part of the query that determines what SQL is rendered
+ in terms of the WHERE and FROM clauses. The
+ :func:`_orm.contains_eager` option does not affect the rendering of
+ the SELECT statement outside of the columns clause, so does not have
+ any interaction with the :func:`_orm.with_loader_criteria` option.
+ However, the way things "work" is that :func:`_orm.contains_eager`
+ is meant to be used with a query that is already selecting from the
+ additional entities in some way, where
+            :func:`_orm.with_loader_criteria` can apply its additional
+ criteria.
+
+ In the example below, assuming a mapping relationship as
+ ``A -> A.bs -> B``, the given :func:`_orm.with_loader_criteria`
+ option will affect the way in which the JOIN is rendered::
+
+ stmt = select(A).join(A.bs).options(
+ contains_eager(A.bs),
+ with_loader_criteria(B, B.flag == 1)
+ )
+
+ Above, the given :func:`_orm.with_loader_criteria` option will
+ affect the ON clause of the JOIN that is specified by
+ ``.join(A.bs)``, so is applied as expected. The
+ :func:`_orm.contains_eager` option has the effect that columns from
+ ``B`` are added to the columns clause::
+
+ SELECT
+ b.id, b.a_id, b.data, b.flag,
+ a.id AS id_1,
+ a.data AS data_1
+ FROM a JOIN b ON a.id = b.a_id AND b.flag = :flag_1
+
+
+ The use of the :func:`_orm.contains_eager` option within the above
+ statement has no effect on the behavior of the
+ :func:`_orm.with_loader_criteria` option. If the
+ :func:`_orm.contains_eager` option were omitted, the SQL would be
+ the same as regards the FROM and WHERE clauses, where
+ :func:`_orm.with_loader_criteria` continues to add its criteria to
+ the ON clause of the JOIN. The addition of
+ :func:`_orm.contains_eager` only affects the columns clause, in that
+ additional columns against ``b`` are added which are then consumed
+ by the ORM to produce ``B`` instances.
+
+ .. warning:: The use of a lambda inside of the call to
+ :func:`_orm.with_loader_criteria` is only invoked **once per unique
+ class**. Custom functions should not be invoked within this lambda.
+ See :ref:`engine_lambda_caching` for an overview of the "lambda SQL"
+ feature, which is for advanced use only.
+
+ :param entity_or_base: a mapped class, or a class that is a super
+ class of a particular set of mapped classes, to which the rule
+ will apply.
+
+ :param where_criteria: a Core SQL expression that applies limiting
+ criteria. This may also be a "lambda:" or Python function that
+ accepts a target class as an argument, when the given class is
+ a base with many different mapped subclasses.
+
+ .. note:: To support pickling, use a module-level Python function to
+ produce the SQL expression instead of a lambda or a fixed SQL
+ expression, which tend to not be picklable.
+
+ :param include_aliases: if True, apply the rule to :func:`_orm.aliased`
+ constructs as well.
+
+ :param propagate_to_loaders: defaults to True, apply to relationship
+ loaders such as lazy loaders. This indicates that the
+ option object itself including SQL expression is carried along with
+ each loaded instance. Set to ``False`` to prevent the object from
+ being assigned to individual instances.
+
+ .. seealso::
+
+ :ref:`examples_session_orm_events` - includes examples of using
+ :func:`_orm.with_loader_criteria`.
+
+ :ref:`do_orm_execute_global_criteria` - basic example on how to
+ combine :func:`_orm.with_loader_criteria` with the
+ :meth:`_orm.SessionEvents.do_orm_execute` event.
+
+ :param track_closure_variables: when False, closure variables inside
+ of a lambda expression will not be used as part of
+ any cache key. This allows more complex expressions to be used
+ inside of a lambda expression but requires that the lambda ensures
+ it returns the identical SQL every time given a particular class.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ entity = inspection.inspect(entity_or_base, False)
+ if entity is None:
+ self.root_entity = entity_or_base
+ self.entity = None
+ else:
+ self.root_entity = None
+ self.entity = entity
+
+ self._where_crit_orig = where_criteria
+ if callable(where_criteria):
+ self.deferred_where_criteria = True
+ self.where_criteria = lambdas.DeferredLambdaElement(
+ where_criteria,
+ roles.WhereHavingRole,
+ lambda_args=(
+ _WrapUserEntity(
+ self.root_entity
+ if self.root_entity is not None
+ else self.entity.entity,
+ ),
+ ),
+ opts=lambdas.LambdaOptions(
+ track_closure_variables=track_closure_variables
+ ),
+ )
+ else:
+ self.deferred_where_criteria = False
+ self.where_criteria = coercions.expect(
+ roles.WhereHavingRole, where_criteria
+ )
+
+ self.include_aliases = include_aliases
+ self.propagate_to_loaders = propagate_to_loaders
+
+ @classmethod
+ def _unreduce(
+ cls, entity, where_criteria, include_aliases, propagate_to_loaders
+ ):
+ return LoaderCriteriaOption(
+ entity,
+ where_criteria,
+ include_aliases=include_aliases,
+ propagate_to_loaders=propagate_to_loaders,
+ )
+
+ def __reduce__(self):
+ return (
+ LoaderCriteriaOption._unreduce,
+ (
+ self.entity.class_ if self.entity else self.root_entity,
+ self._where_crit_orig,
+ self.include_aliases,
+ self.propagate_to_loaders,
+ ),
+ )
+
+ def _all_mappers(self):
+
+ if self.entity:
+ for ent in self.entity.mapper.self_and_descendants:
+ yield ent
+ else:
+ stack = list(self.root_entity.__subclasses__())
+ while stack:
+ subclass = stack.pop(0)
+ ent = inspection.inspect(subclass, raiseerr=False)
+ if ent:
+ for mp in ent.mapper.self_and_descendants:
+ yield mp
+ else:
+ stack.extend(subclass.__subclasses__())
+
+ def _should_include(self, compile_state):
+ if (
+ compile_state.select_statement._annotations.get(
+ "for_loader_criteria", None
+ )
+ is self
+ ):
+ return False
+ return True
+
+ def _resolve_where_criteria(self, ext_info):
+ if self.deferred_where_criteria:
+ crit = self.where_criteria._resolve_with_args(ext_info.entity)
+ else:
+ crit = self.where_criteria
+ return sql_util._deep_annotate(
+ crit, {"for_loader_criteria": self}, detect_subquery_cols=True
+ )
+
+ def process_compile_state_replaced_entities(
+ self, compile_state, mapper_entities
+ ):
+ return self.process_compile_state(compile_state)
+
+ def process_compile_state(self, compile_state):
+ """Apply a modification to a given :class:`.CompileState`."""
+
+ # if options to limit the criteria to immediate query only,
+ # use compile_state.attributes instead
+
+ if compile_state.compile_options._with_polymorphic_adapt_map:
+ util.warn(
+ "The with_loader_criteria() function may not work "
+ "correctly with the legacy Query.with_polymorphic() feature. "
+ "Please migrate code to use the with_polymorphic() standalone "
+ "function before using with_loader_criteria()."
+ )
+ self.get_global_criteria(compile_state.global_attributes)
+
+ def get_global_criteria(self, attributes):
+ for mp in self._all_mappers():
+ load_criteria = attributes.setdefault(
+ ("additional_entity_criteria", mp), []
+ )
+
+ load_criteria.append(self)
+
+
+inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
+inspection._inspects(AliasedInsp)(lambda target: target)
+
+
+def aliased(element, alias=None, name=None, flat=False, adapt_on_names=False):
+ """Produce an alias of the given element, usually an :class:`.AliasedClass`
+ instance.
+
+ E.g.::
+
+ my_alias = aliased(MyClass)
+
+ session.query(MyClass, my_alias).filter(MyClass.id > my_alias.id)
+
+ The :func:`.aliased` function is used to create an ad-hoc mapping of a
+ mapped class to a new selectable. By default, a selectable is generated
+ from the normally mapped selectable (typically a :class:`_schema.Table`
+ ) using the
+ :meth:`_expression.FromClause.alias` method. However, :func:`.aliased`
+ can also be
+ used to link the class to a new :func:`_expression.select` statement.
+ Also, the :func:`.with_polymorphic` function is a variant of
+ :func:`.aliased` that is intended to specify a so-called "polymorphic
+ selectable", that corresponds to the union of several joined-inheritance
+ subclasses at once.
+
+ For convenience, the :func:`.aliased` function also accepts plain
+ :class:`_expression.FromClause` constructs, such as a
+ :class:`_schema.Table` or
+ :func:`_expression.select` construct. In those cases, the
+ :meth:`_expression.FromClause.alias`
+ method is called on the object and the new
+ :class:`_expression.Alias` object returned. The returned
+ :class:`_expression.Alias` is not
+ ORM-mapped in this case.
+
+ .. seealso::
+
+ :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial`
+
+ :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel`
+
+ :param element: element to be aliased. Is normally a mapped class,
+ but for convenience can also be a :class:`_expression.FromClause`
+ element.
+
+ :param alias: Optional selectable unit to map the element to. This is
+ usually used to link the object to a subquery, and should be an aliased
+ select construct as one would produce from the
+ :meth:`_query.Query.subquery` method or
+ the :meth:`_expression.Select.subquery` or
+ :meth:`_expression.Select.alias` methods of the :func:`_expression.select`
+ construct.
+
+ :param name: optional string name to use for the alias, if not specified
+ by the ``alias`` parameter. The name, among other things, forms the
+ attribute name that will be accessible via tuples returned by a
+ :class:`_query.Query` object. Not supported when creating aliases
+ of :class:`_sql.Join` objects.
+
+ :param flat: Boolean, will be passed through to the
+ :meth:`_expression.FromClause.alias` call so that aliases of
+ :class:`_expression.Join` objects will alias the individual tables
+ inside the join, rather than creating a subquery. This is generally
+ supported by all modern databases with regards to right-nested joins
+ and generally produces more efficient queries.
+
+ :param adapt_on_names: if True, more liberal "matching" will be used when
+ mapping the mapped columns of the ORM entity to those of the
+ given selectable - a name-based match will be performed if the
+ given selectable doesn't otherwise have a column that corresponds
+ to one on the entity. The use case for this is when associating
+ an entity with some derived selectable such as one that uses
+ aggregate functions::
+
+ class UnitPrice(Base):
+ __tablename__ = 'unit_price'
+ ...
+ unit_id = Column(Integer)
+ price = Column(Numeric)
+
+ aggregated_unit_price = Session.query(
+ func.sum(UnitPrice.price).label('price')
+ ).group_by(UnitPrice.unit_id).subquery()
+
+ aggregated_unit_price = aliased(UnitPrice,
+ alias=aggregated_unit_price, adapt_on_names=True)
+
+ Above, functions on ``aggregated_unit_price`` which refer to
+ ``.price`` will return the
+ ``func.sum(UnitPrice.price).label('price')`` column, as it is
+ matched on the name "price". Ordinarily, the "price" function
+ wouldn't have any "column correspondence" to the actual
+ ``UnitPrice.price`` column as it is not a proxy of the original.
+
+ """
+ if isinstance(element, expression.FromClause):
+ if adapt_on_names:
+ raise sa_exc.ArgumentError(
+ "adapt_on_names only applies to ORM elements"
+ )
+ if name:
+ return element.alias(name=name, flat=flat)
+ else:
+ return coercions.expect(
+ roles.AnonymizedFromClauseRole, element, flat=flat
+ )
+ else:
+ return AliasedClass(
+ element,
+ alias=alias,
+ flat=flat,
+ name=name,
+ adapt_on_names=adapt_on_names,
+ )
+
+
+def with_polymorphic(
+ base,
+ classes,
+ selectable=False,
+ flat=False,
+ polymorphic_on=None,
+ aliased=False,
+ adapt_on_names=False,
+ innerjoin=False,
+ _use_mapper_path=False,
+ _existing_alias=None,
+):
+ """Produce an :class:`.AliasedClass` construct which specifies
+ columns for descendant mappers of the given base.
+
+ Using this method will ensure that each descendant mapper's
+ tables are included in the FROM clause, and will allow filter()
+ criterion to be used against those tables. The resulting
+ instances will also have those columns already loaded so that
+ no "post fetch" of those columns will be required.
+
+ .. seealso::
+
+ :ref:`with_polymorphic` - full discussion of
+ :func:`_orm.with_polymorphic`.
+
+ :param base: Base class to be aliased.
+
+ :param classes: a single class or mapper, or list of
+ class/mappers, which inherit from the base class.
+ Alternatively, it may also be the string ``'*'``, in which case
+ all descending mapped classes will be added to the FROM clause.
+
+ :param aliased: when True, the selectable will be aliased. For a
+ JOIN, this means the JOIN will be SELECTed from inside of a subquery
+ unless the :paramref:`_orm.with_polymorphic.flat` flag is set to
+ True, which is recommended for simpler use cases.
+
+ :param flat: Boolean, will be passed through to the
+ :meth:`_expression.FromClause.alias` call so that aliases of
+ :class:`_expression.Join` objects will alias the individual tables
+ inside the join, rather than creating a subquery. This is generally
+ supported by all modern databases with regards to right-nested joins
+ and generally produces more efficient queries. Setting this flag is
+ recommended as long as the resulting SQL is functional.
+
+ :param selectable: a table or subquery that will
+ be used in place of the generated FROM clause. This argument is
+ required if any of the desired classes use concrete table
+ inheritance, since SQLAlchemy currently cannot generate UNIONs
+ among tables automatically. If used, the ``selectable`` argument
+ must represent the full set of tables and columns mapped by every
+ mapped class. Otherwise, the unaccounted mapped columns will
+ result in their table being appended directly to the FROM clause
+ which will usually lead to incorrect results.
+
+ When left at its default value of ``False``, the polymorphic
+ selectable assigned to the base mapper is used for selecting rows.
+ However, it may also be passed as ``None``, which will bypass the
+ configured polymorphic selectable and instead construct an ad-hoc
+ selectable for the target classes given; for joined table inheritance
+ this will be a join that includes all target mappers and their
+ subclasses.
+
+ :param polymorphic_on: a column to be used as the "discriminator"
+ column for the given selectable. If not given, the polymorphic_on
+ attribute of the base classes' mapper will be used, if any. This
+ is useful for mappings that don't have polymorphic loading
+ behavior by default.
+
+ :param innerjoin: if True, an INNER JOIN will be used. This should
+ only be specified if querying for one specific subtype only
+
+ :param adapt_on_names: Passes through the
+ :paramref:`_orm.aliased.adapt_on_names`
+ parameter to the aliased object. This may be useful in situations where
+ the given selectable is not directly related to the existing mapped
+ selectable.
+
+ .. versionadded:: 1.4.33
+
+ """
+ primary_mapper = _class_to_mapper(base)
+
+ if selectable not in (None, False) and flat:
+ raise sa_exc.ArgumentError(
+ "the 'flat' and 'selectable' arguments cannot be passed "
+ "simultaneously to with_polymorphic()"
+ )
+
+ if _existing_alias:
+ assert _existing_alias.mapper is primary_mapper
+ classes = util.to_set(classes)
+ new_classes = set(
+ [mp.class_ for mp in _existing_alias.with_polymorphic_mappers]
+ )
+ if classes == new_classes:
+ return _existing_alias
+ else:
+ classes = classes.union(new_classes)
+ mappers, selectable = primary_mapper._with_polymorphic_args(
+ classes, selectable, innerjoin=innerjoin
+ )
+ if aliased or flat:
+ selectable = selectable._anonymous_fromclause(flat=flat)
+ return AliasedClass(
+ base,
+ selectable,
+ adapt_on_names=adapt_on_names,
+ with_polymorphic_mappers=mappers,
+ with_polymorphic_discriminator=polymorphic_on,
+ use_mapper_path=_use_mapper_path,
+ represents_outer_join=not innerjoin,
+ )
+
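+# Illustrative sketch (``Employee`` / ``Manager`` / ``Engineer`` are
+# hypothetical joined-inheritance mapped classes): select Employee rows with
+# the subclass tables joined in, so subclass columns are usable as criteria:
+#
+#     emp_poly = with_polymorphic(Employee, [Manager, Engineer])
+#     q = session.query(emp_poly).filter(
+#         emp_poly.Manager.manager_name == "dilbert"
+#     )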
+
+@inspection._self_inspects
+class Bundle(
+ ORMColumnsClauseRole,
+ SupportsCloneAnnotations,
+ sql_base.MemoizedHasCacheKey,
+ InspectionAttr,
+):
+ """A grouping of SQL expressions that are returned by a :class:`.Query`
+ under one namespace.
+
+ The :class:`.Bundle` essentially allows nesting of the tuple-based
+ results returned by a column-oriented :class:`_query.Query` object.
+ It also
+ is extensible via simple subclassing, where the primary capability
+ to override is that of how the set of expressions should be returned,
+ allowing post-processing as well as custom return types, without
+ involving ORM identity-mapped classes.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`bundles`
+
+
+ """
+
+ single_entity = False
+ """If True, queries for a single Bundle will be returned as a single
+ entity, rather than an element within a keyed tuple."""
+
+ is_clause_element = False
+
+ is_mapper = False
+
+ is_aliased_class = False
+
+ is_bundle = True
+
+ _propagate_attrs = util.immutabledict()
+
+ def __init__(self, name, *exprs, **kw):
+ r"""Construct a new :class:`.Bundle`.
+
+ e.g.::
+
+ bn = Bundle("mybundle", MyClass.x, MyClass.y)
+
+ for row in session.query(bn).filter(
+ bn.c.x == 5).filter(bn.c.y == 4):
+ print(row.mybundle.x, row.mybundle.y)
+
+ :param name: name of the bundle.
+ :param \*exprs: columns or SQL expressions comprising the bundle.
+ :param single_entity=False: if True, rows for this :class:`.Bundle`
+ can be returned as a "single entity" outside of any enclosing tuple
+ in the same manner as a mapped entity.
+
+ """
+ self.name = self._label = name
+ self.exprs = exprs = [
+ coercions.expect(
+ roles.ColumnsClauseRole, expr, apply_propagate_attrs=self
+ )
+ for expr in exprs
+ ]
+
+ self.c = self.columns = ColumnCollection(
+ (getattr(col, "key", col._label), col)
+ for col in [e._annotations.get("bundle", e) for e in exprs]
+ )
+ self.single_entity = kw.pop("single_entity", self.single_entity)
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return (self.__class__, self.name, self.single_entity) + tuple(
+ [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs]
+ )
+
+ @property
+ def mapper(self):
+ return self.exprs[0]._annotations.get("parentmapper", None)
+
+ @property
+ def entity(self):
+ return self.exprs[0]._annotations.get("parententity", None)
+
+ @property
+ def entity_namespace(self):
+ return self.c
+
+ columns = None
+ """A namespace of SQL expressions referred to by this :class:`.Bundle`.
+
+ e.g.::
+
+ bn = Bundle("mybundle", MyClass.x, MyClass.y)
+
+ q = sess.query(bn).filter(bn.c.x == 5)
+
+ Nesting of bundles is also supported::
+
+ b1 = Bundle("b1",
+ Bundle('b2', MyClass.a, MyClass.b),
+ Bundle('b3', MyClass.x, MyClass.y)
+ )
+
+ q = sess.query(b1).filter(
+ b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9)
+
+ .. seealso::
+
+ :attr:`.Bundle.c`
+
+ """
+
+ c = None
+ """An alias for :attr:`.Bundle.columns`."""
+
+ def _clone(self):
+ cloned = self.__class__.__new__(self.__class__)
+ cloned.__dict__.update(self.__dict__)
+ return cloned
+
+ def __clause_element__(self):
+ # ensure existing entity_namespace remains
+ annotations = {"bundle": self, "entity_namespace": self}
+ annotations.update(self._annotations)
+
+ plugin_subject = self.exprs[0]._propagate_attrs.get(
+ "plugin_subject", self.entity
+ )
+ return (
+ expression.ClauseList(
+ _literal_as_text_role=roles.ColumnsClauseRole,
+ group=False,
+ *[e._annotations.get("bundle", e) for e in self.exprs]
+ )
+ ._annotate(annotations)
+ ._set_propagate_attrs(
+ # the Bundle *must* use the orm plugin no matter what. the
+ # subject can be None but it's much better if it's not.
+ {
+ "compile_state_plugin": "orm",
+ "plugin_subject": plugin_subject,
+ }
+ )
+ )
+
+ @property
+ def clauses(self):
+ return self.__clause_element__().clauses
+
+ def label(self, name):
+ """Provide a copy of this :class:`.Bundle` passing a new label."""
+
+ cloned = self._clone()
+ cloned.name = name
+ return cloned
+
+ def create_row_processor(self, query, procs, labels):
+ """Produce the "row processing" function for this :class:`.Bundle`.
+
+ May be overridden by subclasses.
+
+ .. seealso::
+
+ :ref:`bundles` - includes an example of subclassing.
+
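+ A minimal sketch of such a subclass, illustrative only (the
+ hypothetical ``DictBundle`` below returns plain dictionaries
+ instead of keyed tuples)::
+
+     class DictBundle(Bundle):
+         def create_row_processor(self, query, procs, labels):
+             def proc(row):
+                 return dict(zip(labels, (p(row) for p in procs)))
+             return proc
+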
+ """
+ keyed_tuple = result_tuple(labels, [() for l in labels])
+
+ def proc(row):
+ return keyed_tuple([proc(row) for proc in procs])
+
+ return proc
+
+
+def _orm_annotate(element, exclude=None):
+ """Deep copy the given ClauseElement, annotating each element with the
+ "_orm_adapt" flag.
+
+ Elements within the exclude collection will be cloned but not annotated.
+
+ """
+ return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
+
+
+def _orm_deannotate(element):
+ """Remove annotations that link a column to a particular mapping.
+
+ Note this doesn't affect "remote" and "foreign" annotations
+ passed by the :func:`_orm.foreign` and :func:`_orm.remote`
+ annotators.
+
+ """
+
+ return sql_util._deep_deannotate(
+ element, values=("_orm_adapt", "parententity")
+ )
+
+
+def _orm_full_deannotate(element):
+ return sql_util._deep_deannotate(element)
+
+
+class _ORMJoin(expression.Join):
+ """Extend Join to support ORM constructs as input."""
+
+ __visit_name__ = expression.Join.__visit_name__
+
+ inherit_cache = True
+
+ def __init__(
+ self,
+ left,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ _left_memo=None,
+ _right_memo=None,
+ _extra_criteria=(),
+ ):
+ left_info = inspection.inspect(left)
+
+ right_info = inspection.inspect(right)
+ adapt_to = right_info.selectable
+
+ # used by joined eager loader
+ self._left_memo = _left_memo
+ self._right_memo = _right_memo
+
+ # legacy, for string attr name ON clause. if that's removed
+ # then the "_joined_from_info" concept can go
+ left_orm_info = getattr(left, "_joined_from_info", left_info)
+ self._joined_from_info = right_info
+ if isinstance(onclause, util.string_types):
+ onclause = getattr(left_orm_info.entity, onclause)
+ # ####
+
+ if isinstance(onclause, attributes.QueryableAttribute):
+ on_selectable = onclause.comparator._source_selectable()
+ prop = onclause.property
+ _extra_criteria += onclause._extra_criteria
+ elif isinstance(onclause, MapperProperty):
+ # used internally by joined eager loader...possibly not ideal
+ prop = onclause
+ on_selectable = prop.parent.selectable
+ else:
+ prop = None
+
+ if prop:
+ left_selectable = left_info.selectable
+
+ if sql_util.clause_is_present(on_selectable, left_selectable):
+ adapt_from = on_selectable
+ else:
+ adapt_from = left_selectable
+
+ (
+ pj,
+ sj,
+ source,
+ dest,
+ secondary,
+ target_adapter,
+ ) = prop._create_joins(
+ source_selectable=adapt_from,
+ dest_selectable=adapt_to,
+ source_polymorphic=True,
+ of_type_entity=right_info,
+ alias_secondary=True,
+ extra_criteria=_extra_criteria,
+ )
+
+ if sj is not None:
+ if isouter:
+ # note this is an inner join from secondary->right
+ right = sql.join(secondary, right, sj)
+ onclause = pj
+ else:
+ left = sql.join(left, secondary, pj, isouter)
+ onclause = sj
+ else:
+ onclause = pj
+
+ self._target_adapter = target_adapter
+
+ augment_onclause = onclause is None and _extra_criteria
+ expression.Join.__init__(self, left, right, onclause, isouter, full)
+
+ if augment_onclause:
+ self.onclause &= sql.and_(*_extra_criteria)
+
+ if (
+ not prop
+ and getattr(right_info, "mapper", None)
+ and right_info.mapper.single
+ ):
+ # if single inheritance target and we are using a manual
+ # or implicit ON clause, augment it the same way we'd augment the
+ # WHERE.
+ single_crit = right_info.mapper._single_table_criterion
+ if single_crit is not None:
+ if right_info.is_aliased_class:
+ single_crit = right_info._adapter.traverse(single_crit)
+ self.onclause = self.onclause & single_crit
+
+ def _splice_into_center(self, other):
+ """Splice a join into the center.
+
+ Given join(a, b) and join(b, c), return join(a, b).join(c)
+
+ """
+ leftmost = other
+ while isinstance(leftmost, sql.Join):
+ leftmost = leftmost.left
+
+ assert self.right is leftmost
+
+ left = _ORMJoin(
+ self.left,
+ other.left,
+ self.onclause,
+ isouter=self.isouter,
+ _left_memo=self._left_memo,
+ _right_memo=other._left_memo,
+ )
+
+ return _ORMJoin(
+ left,
+ other.right,
+ other.onclause,
+ isouter=other.isouter,
+ _right_memo=other._right_memo,
+ )
+
+ def join(
+ self,
+ right,
+ onclause=None,
+ isouter=False,
+ full=False,
+ join_to_left=None,
+ ):
+ return _ORMJoin(self, right, onclause, full=full, isouter=isouter)
+
+ def outerjoin(self, right, onclause=None, full=False, join_to_left=None):
+ return _ORMJoin(self, right, onclause, isouter=True, full=full)
+
+
+def join(
+ left, right, onclause=None, isouter=False, full=False, join_to_left=None
+):
+ r"""Produce an inner join between left and right clauses.
+
+ :func:`_orm.join` is an extension to the core join interface
+ provided by :func:`_expression.join()`, where the
+ left and right selectables may be not only core selectable
+ objects such as :class:`_schema.Table`, but also mapped classes or
+ :class:`.AliasedClass` instances. The "on" clause can
+ be a SQL expression or an ORM mapped attribute
+ referencing a configured :func:`_orm.relationship`.
+
+ .. deprecated:: 1.4 using a string relationship name for the "onclause"
+ is deprecated and will be removed in 2.0; the onclause may be only
+ an ORM-mapped relationship attribute or a SQL expression construct.
+
+ :func:`_orm.join` is not commonly needed in modern usage,
+ as its functionality is encapsulated within that of the
+ :meth:`_sql.Select.join` and :meth:`_query.Query.join`
+ methods, which feature a
+ significant amount of automation beyond :func:`_orm.join`
+ by itself. Explicit use of :func:`_orm.join`
+ with ORM-enabled SELECT statements involves use of the
+ :meth:`_sql.Select.select_from` method, as in::
+
+ from sqlalchemy.orm import join
+ stmt = select(User).\
+ select_from(join(User, Address, User.addresses)).\
+ filter(Address.email_address=='foo@bar.com')
+
+ In modern SQLAlchemy the above join can be written more
+ succinctly as::
+
+ stmt = select(User).\
+ join(User.addresses).\
+ filter(Address.email_address=='foo@bar.com')
+
+ See :ref:`orm_queryguide_joins` for information on modern usage
+ of ORM level joins.
+
+ .. deprecated:: 0.8
+
+ the ``join_to_left`` parameter is deprecated, and will be removed
+ in a future release. The parameter has no effect.
+
+ """
+ return _ORMJoin(left, right, onclause, isouter, full)
+
+
+def outerjoin(left, right, onclause=None, full=False, join_to_left=None):
+ """Produce a left outer join between left and right clauses.
+
+ This is the "outer join" version of the :func:`_orm.join` function,
+ featuring the same behavior except that an OUTER JOIN is generated.
+ See that function's documentation for other usage details.
+
+ """
+ return _ORMJoin(left, right, onclause, True, full)
+
+
+def with_parent(instance, prop, from_entity=None):
+ """Create filtering criterion that relates this query's primary entity
+ to the given related instance, using established
+ :func:`_orm.relationship()`
+ configuration.
+
+ E.g.::
+
+ stmt = select(Address).where(with_parent(some_user, User.addresses))
+
+
+ The SQL rendered is the same as that rendered when a lazy loader
+ would fire off from the given parent on that attribute, meaning
+ that the appropriate state is taken from the parent object in
+ Python without the need to render joins to the parent table
+ in the rendered statement.
+
+ The given property may also make use of :meth:`_orm.PropComparator.of_type`
+ to indicate the left side of the criteria::
+
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+ stmt = select(a1, a2).where(
+ with_parent(u1, User.addresses.of_type(a2))
+ )
+
+ The above use is equivalent to using the
+ :paramref:`_orm.with_parent.from_entity` parameter::
+
+ a1 = aliased(Address)
+ a2 = aliased(Address)
+ stmt = select(a1, a2).where(
+ with_parent(u1, User.addresses, from_entity=a2)
+ )
+
+ :param instance:
+ An instance which has some :func:`_orm.relationship`.
+
+ :param prop:
+ String property name, or class-bound attribute, which indicates
+ what relationship from the instance should be used to reconcile the
+ parent/child relationship.
+
+ .. deprecated:: 1.4 Using strings is deprecated and will be removed
+ in SQLAlchemy 2.0. Please use the class-bound attribute directly.
+
+ :param from_entity:
+ Entity to consider as the left side. This defaults to the
+ "zero" entity of the :class:`_query.Query` itself.
+
+ .. versionadded:: 1.2
+
+ """
+ if isinstance(prop, util.string_types):
+ util.warn_deprecated_20(
+ "Using strings to indicate relationship names in the ORM "
+ "with_parent() function is deprecated and will be removed "
+ "SQLAlchemy 2.0. Please use the class-bound attribute directly."
+ )
+ mapper = object_mapper(instance)
+ prop = getattr(mapper.class_, prop).property
+ elif isinstance(prop, attributes.QueryableAttribute):
+ if prop._of_type:
+ from_entity = prop._of_type
+ prop = prop.property
+
+ return prop._with_parent(instance, from_entity=from_entity)
+
+
+def has_identity(object_):
+ """Return True if the given object has a database
+ identity.
+
+ This typically corresponds to the object being
+ in either the persistent or detached state.
+
+ .. seealso::
+
+ :func:`.was_deleted`
+
+ """
+ state = attributes.instance_state(object_)
+ return state.has_identity
+
+
+def was_deleted(object_):
+ """Return True if the given object was deleted
+ within a session flush.
+
+ This is regardless of whether or not the object is
+ persistent or detached.
+
+ .. seealso::
+
+ :attr:`.InstanceState.was_deleted`
+
+ """
+
+ state = attributes.instance_state(object_)
+ return state.was_deleted
+
+
+def _entity_corresponds_to(given, entity):
+ """determine if 'given' corresponds to 'entity', in terms
+ of an entity passed to Query that would match the same entity
+ being referred to elsewhere in the query.
+
+ """
+ if entity.is_aliased_class:
+ if given.is_aliased_class:
+ if entity._base_alias() is given._base_alias():
+ return True
+ return False
+ elif given.is_aliased_class:
+ if given._use_mapper_path:
+ return entity in given.with_polymorphic_mappers
+ else:
+ return entity is given
+
+ return entity.common_parent(given)
+
+
+def _entity_corresponds_to_use_path_impl(given, entity):
+ """determine if 'given' corresponds to 'entity', in terms
+ of a path of loader options where a mapped attribute is taken to
+ be a member of a parent entity.
+
+ e.g.::
+
+ someoption(A).someoption(A.b) # -> fn(A, A) -> True
+ someoption(A).someoption(C.d) # -> fn(A, C) -> False
+
+ a1 = aliased(A)
+ someoption(a1).someoption(A.b) # -> fn(a1, A) -> False
+ someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True
+
+ wp = with_polymorphic(A, [A1, A2])
+ someoption(wp).someoption(A1.foo) # -> fn(wp, A1) -> False
+ someoption(wp).someoption(wp.A1.foo) # -> fn(wp, wp.A1) -> True
+
+
+ """
+ if given.is_aliased_class:
+ return (
+ entity.is_aliased_class
+ and not entity._use_mapper_path
+ and (given is entity or given in entity._with_polymorphic_entities)
+ )
+ elif not entity.is_aliased_class:
+ return given.common_parent(entity.mapper)
+ else:
+ return (
+ entity._use_mapper_path
+ and given in entity.with_polymorphic_mappers
+ )
+
+
+def _entity_isa(given, mapper):
+ """determine if 'given' "is a" mapper, in terms of the given
+ would load rows of type 'mapper'.
+
+ """
+ if given.is_aliased_class:
+ return mapper in given.with_polymorphic_mappers or given.mapper.isa(
+ mapper
+ )
+ elif given.with_polymorphic_mappers:
+ return mapper in given.with_polymorphic_mappers
+ else:
+ return given.isa(mapper)
+
+
+def randomize_unitofwork():
+ """Use random-ordering sets within the unit of work in order
+ to detect unit of work sorting issues.
+
+ This is a utility function that can be used to help reproduce
+ inconsistent unit of work sorting issues. For example,
+ if two kinds of objects A and B are being inserted, and
+ B has a foreign key reference to A, then A must be inserted first.
+ However, if there is no relationship between A and B, the unit of work
+ won't know to perform this sorting, and an operation may or may not
+ fail, depending on how the ordering works out. Since Python sets
+ and dictionaries have non-deterministic ordering, such an issue may
+ occur on some runs and not on others, and in practice it tends to
+ have a great dependence on the state of the interpreter. This leads
+ to so-called "heisenbugs" where changing entirely irrelevant aspects
+ of the test program can still cause the failure behavior to change.
+
+ By calling ``randomize_unitofwork()`` when a script first runs, the
+ ordering of a key series of sets within the unit of work implementation
+ is randomized, so that the script can be minimized down to the
+ fundamental mapping and operation that's failing, while still reproducing
+ the issue on at least some runs.
+
+ This utility is also available when running the test suite via the
+ ``--reversetop`` flag.
+
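+ As an illustrative sketch, a reproduction script might begin with::
+
+     from sqlalchemy.orm.util import randomize_unitofwork
+
+     randomize_unitofwork()
+
+     # ... mapper configuration and flush operations follow
+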
+ """
+ from sqlalchemy.orm import unitofwork, session, mapper, dependency
+ from sqlalchemy.util import topological
+ from sqlalchemy.testing.util import RandomSet
+
+ topological.set = (
+ unitofwork.set
+ ) = session.set = mapper.set = dependency.set = RandomSet
+
+
+def _getitem(iterable_query, item, allow_negative):
+ """calculate __getitem__ in terms of an iterable query object
+ that also has a slice() method.
+
+ """
+
+ def _no_negative_indexes():
+ if not allow_negative:
+ raise IndexError(
+ "negative indexes are not accepted by SQL "
+ "index / slice operators"
+ )
+ else:
+ util.warn_deprecated_20(
+ "Support for negative indexes for SQL index / slice operators "
+ "will be "
+ "removed in 2.0; these operators fetch the complete result "
+ "and do not work efficiently."
+ )
+
+ if isinstance(item, slice):
+ start, stop, step = util.decode_slice(item)
+
+ if (
+ isinstance(stop, int)
+ and isinstance(start, int)
+ and stop - start <= 0
+ ):
+ return []
+
+ elif (isinstance(start, int) and start < 0) or (
+ isinstance(stop, int) and stop < 0
+ ):
+ _no_negative_indexes()
+ return list(iterable_query)[item]
+
+ res = iterable_query.slice(start, stop)
+ if step is not None:
+ return list(res)[None : None : item.step]
+ else:
+ return list(res)
+ else:
+ if item == -1:
+ _no_negative_indexes()
+ return list(iterable_query)[-1]
+ else:
+ return list(iterable_query[item : item + 1])[0]
diff --git a/lib/sqlalchemy/pool/__init__.py b/lib/sqlalchemy/pool/__init__.py
new file mode 100644
index 0000000..6a00ef8
--- /dev/null
+++ b/lib/sqlalchemy/pool/__init__.py
@@ -0,0 +1,56 @@
+# sqlalchemy/pool/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""Connection pooling for DB-API connections.
+
+Provides a number of connection pool implementations for a variety of
+usage scenarios and thread behavior requirements imposed by the
+application, DB-API or database itself.
+
+Also provides a DB-API 2.0 connection proxying mechanism allowing
+regular DB-API connect() methods to be transparently managed by a
+SQLAlchemy connection pool.
+"""
+
+from . import events
+from .base import _ConnectionFairy
+from .base import _ConnectionRecord
+from .base import _finalize_fairy
+from .base import Pool
+from .base import reset_commit
+from .base import reset_none
+from .base import reset_rollback
+from .dbapi_proxy import clear_managers
+from .dbapi_proxy import manage
+from .impl import AssertionPool
+from .impl import AsyncAdaptedQueuePool
+from .impl import FallbackAsyncAdaptedQueuePool
+from .impl import NullPool
+from .impl import QueuePool
+from .impl import SingletonThreadPool
+from .impl import StaticPool
+
+
+__all__ = [
+ "Pool",
+ "reset_commit",
+ "reset_none",
+ "reset_rollback",
+ "clear_managers",
+ "manage",
+ "AssertionPool",
+ "NullPool",
+ "QueuePool",
+ "AsyncAdaptedQueuePool",
+ "FallbackAsyncAdaptedQueuePool",
+ "SingletonThreadPool",
+ "StaticPool",
+]
+
+# as these are likely to be used in various test suites, debugging
+# setups, keep them in the sqlalchemy.pool namespace
diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py
new file mode 100644
index 0000000..cde28c2
--- /dev/null
+++ b/lib/sqlalchemy/pool/base.py
@@ -0,0 +1,1121 @@
+# sqlalchemy/pool.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""Base constructs for connection pools.
+
+"""
+
+from collections import deque
+import time
+import weakref
+
+from .. import event
+from .. import exc
+from .. import log
+from .. import util
+
+
+reset_rollback = util.symbol("reset_rollback")
+reset_commit = util.symbol("reset_commit")
+reset_none = util.symbol("reset_none")
+
+
+class _ConnDialect(object):
+ """partial implementation of :class:`.Dialect`
+ which provides DBAPI connection methods.
+
+ When a :class:`_pool.Pool` is combined with an :class:`_engine.Engine`,
+ the :class:`_engine.Engine` replaces this with its own
+ :class:`.Dialect`.
+
+ """
+
+ is_async = False
+
+ def do_rollback(self, dbapi_connection):
+ dbapi_connection.rollback()
+
+ def do_commit(self, dbapi_connection):
+ dbapi_connection.commit()
+
+ def do_close(self, dbapi_connection):
+ dbapi_connection.close()
+
+ def do_ping(self, dbapi_connection):
+ raise NotImplementedError(
+ "The ping feature requires that a dialect is "
+ "passed to the connection pool."
+ )
+
+ def get_driver_connection(self, connection):
+ return connection
+
+
+class _AsyncConnDialect(_ConnDialect):
+ is_async = True
+
+
+class Pool(log.Identified):
+
+ """Abstract base class for connection pools."""
+
+ _dialect = _ConnDialect()
+
+ def __init__(
+ self,
+ creator,
+ recycle=-1,
+ echo=None,
+ logging_name=None,
+ reset_on_return=True,
+ events=None,
+ dialect=None,
+ pre_ping=False,
+ _dispatch=None,
+ ):
+ """
+ Construct a Pool.
+
+ :param creator: a callable function that returns a DB-API
+ connection object. The function will be called with no arguments,
+ or optionally with a single positional argument that receives the
+ :class:`._ConnectionRecord` in use.
+
+ :param recycle: If set to a value other than -1, the number of
+ seconds between connection recycling events; this means that upon
+ checkout, if this timeout is surpassed the connection will be
+ closed and replaced with a newly opened connection. Defaults to -1.
+
+ :param logging_name: String identifier which will be used within
+ the "name" field of logging records generated within the
+ "sqlalchemy.pool" logger. Defaults to a hexstring of the object's
+ id.
+
+ :param echo: if True, the connection pool will log
+ informational output such as when connections are invalidated
+ as well as when connections are recycled to the default log handler,
+ which defaults to ``sys.stdout`` for output. If set to the string
+ ``"debug"``, the logging will include pool checkouts and checkins.
+
+ The :paramref:`_pool.Pool.echo` parameter can also be set from the
+ :func:`_sa.create_engine` call by using the
+ :paramref:`_sa.create_engine.echo_pool` parameter.
+
+ .. seealso::
+
+ :ref:`dbengine_logging` - further detail on how to configure
+ logging.
+
+ :param reset_on_return: Determine steps to take on
+ connections as they are returned to the pool, which were
+ not otherwise handled by a :class:`_engine.Connection`.
+
+ reset_on_return can have any of these values:
+
+ * ``"rollback"`` - call rollback() on the connection,
+ to release locks and transaction resources.
+ This is the default value. The vast majority
+ of use cases should leave this value set.
+ * ``True`` - same as 'rollback', this is here for
+ backwards compatibility.
+ * ``"commit"`` - call commit() on the connection,
+ to release locks and transaction resources.
+ A commit here may be desirable for databases that
+ cache query plans if a commit is emitted,
+ such as Microsoft SQL Server. However, this
+ value is more dangerous than 'rollback' because
+ any data changes present on the transaction
+ are committed unconditionally.
+ * ``None`` - don't do anything on the connection.
+ This setting is only appropriate if the database / DBAPI
+ works in pure "autocommit" mode at all times, or if the
+ application uses the :class:`_engine.Engine` with consistent
+ connectivity patterns. See the section
+ :ref:`pool_reset_on_return` for more details.
+
+ * ``False`` - same as None, this is here for
+ backwards compatibility.
+
+ .. seealso::
+
+ :ref:`pool_reset_on_return`
+
+ :param events: a list of 2-tuples, each of the form
+ ``(callable, target)`` which will be passed to :func:`.event.listen`
+ upon construction. Provided here so that event listeners
+ can be assigned via :func:`_sa.create_engine` before dialect-level
+ listeners are applied.
+
+ :param dialect: a :class:`.Dialect` that will handle the job
+ of calling rollback(), close(), or commit() on DBAPI connections.
+ If omitted, a built-in "stub" dialect is used. Applications that
+ make use of :func:`_sa.create_engine` should not use this parameter
+ as it is handled by the engine creation strategy.
+
+ .. versionadded:: 1.1 - ``dialect`` is now a public parameter
+ to the :class:`_pool.Pool`.
+
+ :param pre_ping: if True, the pool will emit a "ping" (typically
+ "SELECT 1", but is dialect-specific) on the connection
+ upon checkout, to test if the connection is alive or not. If not,
+ the connection is transparently re-connected and upon success, all
+ other pooled connections established prior to that timestamp are
+ invalidated. Requires that a dialect is passed as well to
+ interpret the disconnection error.
+
+ .. versionadded:: 1.2
+
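+ Most applications configure a pool indirectly through
+ :func:`_sa.create_engine`; as an illustrative sketch only, a
+ :class:`.QueuePool` could also be constructed directly against a
+ DBAPI, e.g. the stdlib ``sqlite3`` module::
+
+     import sqlite3
+
+     from sqlalchemy.pool import QueuePool
+
+     pool = QueuePool(
+         lambda: sqlite3.connect("file.db"),
+         pool_size=5,
+         reset_on_return="rollback",
+     )
+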
+ """
+ if logging_name:
+ self.logging_name = self._orig_logging_name = logging_name
+ else:
+ self._orig_logging_name = None
+
+ log.instance_logger(self, echoflag=echo)
+ self._creator = creator
+ self._recycle = recycle
+ self._invalidate_time = 0
+ self._pre_ping = pre_ping
+ self._reset_on_return = util.symbol.parse_user_argument(
+ reset_on_return,
+ {
+ reset_rollback: ["rollback", True],
+ reset_none: ["none", None, False],
+ reset_commit: ["commit"],
+ },
+ "reset_on_return",
+ resolve_symbol_names=False,
+ )
+
+ self.echo = echo
+
+ if _dispatch:
+ self.dispatch._update(_dispatch, only_propagate=False)
+ if dialect:
+ self._dialect = dialect
+ if events:
+ for fn, target in events:
+ event.listen(self, target, fn)
+
+ @util.hybridproperty
+ def _is_asyncio(self):
+ return self._dialect.is_async
+
+ @property
+ def _creator(self):
+ return self.__dict__["_creator"]
+
+ @_creator.setter
+ def _creator(self, creator):
+ self.__dict__["_creator"] = creator
+ self._invoke_creator = self._should_wrap_creator(creator)
+
+ def _should_wrap_creator(self, creator):
+ """Detect if creator accepts a single argument, or is sent
+ as a legacy style no-arg function.
+
+ """
+
+ try:
+ argspec = util.get_callable_argspec(self._creator, no_self=True)
+ except TypeError:
+ return lambda crec: creator()
+
+ defaulted = argspec[3] is not None and len(argspec[3]) or 0
+ positionals = len(argspec[0]) - defaulted
+
+ # look for the exact arg signature that DefaultStrategy
+ # sends us
+ if (argspec[0], argspec[3]) == (["connection_record"], (None,)):
+ return creator
+ # or just a single positional
+ elif positionals == 1:
+ return creator
+ # all other cases, just wrap and assume legacy "creator" callable
+ # thing
+ else:
+ return lambda crec: creator()
+
+ def _close_connection(self, connection):
+ self.logger.debug("Closing connection %r", connection)
+
+ try:
+ self._dialect.do_close(connection)
+ except Exception:
+ self.logger.error(
+ "Exception closing connection %r", connection, exc_info=True
+ )
+
+ def _create_connection(self):
+ """Called by subclasses to create a new ConnectionRecord."""
+
+ return _ConnectionRecord(self)
+
+ def _invalidate(self, connection, exception=None, _checkin=True):
+ """Mark all connections established within the generation
+ of the given connection as invalidated.
+
+ If this pool's last invalidate time is before when the given
+ connection was created, update the timestamp to the current time.
+ Otherwise, no action is performed.
+
+ Connections with a start time prior to this pool's invalidation
+ time will be recycled upon next checkout.
+ """
+ rec = getattr(connection, "_connection_record", None)
+ if not rec or self._invalidate_time < rec.starttime:
+ self._invalidate_time = time.time()
+ if _checkin and getattr(connection, "is_valid", False):
+ connection.invalidate(exception)
+
+ def recreate(self):
+ """Return a new :class:`_pool.Pool`, of the same class as this one
+ and configured with identical creation arguments.
+
+ This method is used in conjunction with :meth:`dispose`
+ to close out an entire :class:`_pool.Pool` and create a new one in
+ its place.
+
+ """
+
+ raise NotImplementedError()
+
+ def dispose(self):
+ """Dispose of this pool.
+
+ This method leaves the possibility of checked-out connections
+ remaining open, as it only affects connections that are
+ idle in the pool.
+
+ .. seealso::
+
+ :meth:`Pool.recreate`
+
+ """
+
+ raise NotImplementedError()
+
+ def connect(self):
+ """Return a DBAPI connection from the pool.
+
+ The connection is instrumented such that when its
+ ``close()`` method is called, the connection will be returned to
+ the pool.
+
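+ For example, assuming ``pool`` is a configured :class:`_pool.Pool`
+ (illustrative sketch)::
+
+     conn = pool.connect()
+     cursor = conn.cursor()
+     cursor.execute("select 1")
+     cursor.close()
+     conn.close()  # checks the connection back into the pool
+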
+ """
+ return _ConnectionFairy._checkout(self)
+
+ def _return_conn(self, record):
+ """Given a _ConnectionRecord, return it to the :class:`_pool.Pool`.
+
+ This method is called when an instrumented DBAPI connection
+ has its ``close()`` method called.
+
+ """
+ self._do_return_conn(record)
+
+ def _do_get(self):
+ """Implementation for :meth:`get`, supplied by subclasses."""
+
+ raise NotImplementedError()
+
+ def _do_return_conn(self, conn):
+ """Implementation for :meth:`return_conn`, supplied by subclasses."""
+
+ raise NotImplementedError()
+
+ def status(self):
+ raise NotImplementedError()
+
+
+class _ConnectionRecord(object):
+
+ """Internal object which maintains an individual DBAPI connection
+ referenced by a :class:`_pool.Pool`.
+
+ The :class:`._ConnectionRecord` object always exists for any particular
+ DBAPI connection whether or not that DBAPI connection has been
+ "checked out". This is in contrast to the :class:`._ConnectionFairy`
+ which is only a public facade to the DBAPI connection while it is checked
+ out.
+
+ A :class:`._ConnectionRecord` may exist for a span longer than that
+ of a single DBAPI connection. For example, if the
+ :meth:`._ConnectionRecord.invalidate`
+ method is called, the DBAPI connection associated with this
+ :class:`._ConnectionRecord`
+ will be discarded, but the :class:`._ConnectionRecord` may be used again,
+ in which case a new DBAPI connection is produced when the
+ :class:`_pool.Pool`
+ next uses this record.
+
+ The :class:`._ConnectionRecord` is delivered along with connection
+ pool events, including :meth:`_events.PoolEvents.connect` and
+ :meth:`_events.PoolEvents.checkout`, however :class:`._ConnectionRecord`
+ still
+ remains an internal object whose API and internals may change.
+
+ .. seealso::
+
+ :class:`._ConnectionFairy`
+
+ """
+
+ def __init__(self, pool, connect=True):
+ self.__pool = pool
+ if connect:
+ self.__connect()
+ self.finalize_callback = deque()
+
+ fresh = False
+
+ fairy_ref = None
+
+ starttime = None
+
+ dbapi_connection = None
+ """A reference to the actual DBAPI connection being tracked.
+
+ May be ``None`` if this :class:`._ConnectionRecord` has been marked
+ as invalidated; a new DBAPI connection may replace it if the owning
+ pool calls upon this :class:`._ConnectionRecord` to reconnect.
+
+ For adapted drivers, like the Asyncio implementations, this is a
+ :class:`.AdaptedConnection` that adapts the driver connection
+ to the DBAPI protocol.
+ Use :attr:`._ConnectionRecord.driver_connection` to obtain the
+ connection object returned by the driver.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ @property
+ def driver_connection(self):
+ """The connection object as returned by the driver after a connect.
+
+ For normal sync drivers that support the DBAPI protocol, this object
+ is the same as the one referenced by
+ :attr:`._ConnectionRecord.dbapi_connection`.
+
+ For adapted drivers, like the Asyncio ones, this is the actual object
+ that was returned by the driver ``connect`` call.
+
+ As with :attr:`._ConnectionRecord.dbapi_connection`, it may be ``None``
+ if this :class:`._ConnectionRecord` has been marked as invalidated.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ if self.dbapi_connection is None:
+ return None
+ else:
+ return self.__pool._dialect.get_driver_connection(
+ self.dbapi_connection
+ )
+
+ @property
+ def connection(self):
+ """An alias to :attr:`._ConnectionRecord.dbapi_connection`.
+
+ This alias is deprecated, please use the new name.
+
+ .. deprecated:: 1.4.24
+
+ """
+ return self.dbapi_connection
+
+ @connection.setter
+ def connection(self, value):
+ self.dbapi_connection = value
+
+ _soft_invalidate_time = 0
+
+ @util.memoized_property
+ def info(self):
+ """The ``.info`` dictionary associated with the DBAPI connection.
+
+ This dictionary is shared among the :attr:`._ConnectionFairy.info`
+ and :attr:`_engine.Connection.info` accessors.
+
+ .. note::
+
+ The lifespan of this dictionary is linked to the
+ DBAPI connection itself, meaning that it is **discarded** each time
+ the DBAPI connection is closed and/or invalidated. The
+ :attr:`._ConnectionRecord.record_info` dictionary remains
+ persistent throughout the lifespan of the
+ :class:`._ConnectionRecord` container.
+
+ """
+ return {}
+
+ @util.memoized_property
+ def record_info(self):
+ """An "info' dictionary associated with the connection record
+ itself.
+
+ Unlike the :attr:`._ConnectionRecord.info` dictionary, which is linked
+ to the lifespan of the DBAPI connection, this dictionary is linked
+ to the lifespan of the :class:`._ConnectionRecord` container itself
+ and will remain persistent throughout the life of the
+ :class:`._ConnectionRecord`.
+
+ .. versionadded:: 1.1
+
+ """
+ return {}
+
+ @classmethod
+ def checkout(cls, pool):
+ rec = pool._do_get()
+ try:
+ dbapi_connection = rec.get_connection()
+ except Exception as err:
+ with util.safe_reraise():
+ rec._checkin_failed(err, _fairy_was_created=False)
+ echo = pool._should_log_debug()
+ fairy = _ConnectionFairy(dbapi_connection, rec, echo)
+
+ rec.fairy_ref = ref = weakref.ref(
+ fairy,
+ lambda ref: _finalize_fairy
+ and _finalize_fairy(None, rec, pool, ref, echo, True),
+ )
+ _strong_ref_connection_records[ref] = rec
+ if echo:
+ pool.logger.debug(
+ "Connection %r checked out from pool", dbapi_connection
+ )
+ return fairy
+
+ def _checkin_failed(self, err, _fairy_was_created=True):
+ self.invalidate(e=err)
+ self.checkin(
+ _fairy_was_created=_fairy_was_created,
+ )
+
+ def checkin(self, _fairy_was_created=True):
+ if self.fairy_ref is None and _fairy_was_created:
+ # _fairy_was_created is False for the initial get connection phase;
+ # meaning there was no _ConnectionFairy and we must unconditionally
+ # do a checkin.
+ #
+ # otherwise, if fairy_was_created==True, if fairy_ref is None here
+ # that means we were checked in already, so this looks like
+ # a double checkin.
+ util.warn("Double checkin attempted on %s" % self)
+ return
+ self.fairy_ref = None
+ connection = self.dbapi_connection
+ pool = self.__pool
+ while self.finalize_callback:
+ finalizer = self.finalize_callback.pop()
+ finalizer(connection)
+ if pool.dispatch.checkin:
+ pool.dispatch.checkin(connection, self)
+
+ pool._return_conn(self)
+
+ @property
+ def in_use(self):
+ return self.fairy_ref is not None
+
+ @property
+ def last_connect_time(self):
+ return self.starttime
+
+ def close(self):
+ if self.dbapi_connection is not None:
+ self.__close()
+
+ def invalidate(self, e=None, soft=False):
+ """Invalidate the DBAPI connection held by this
+ :class:`._ConnectionRecord`.
+
+ This method is called for all connection invalidations, including
+ when the :meth:`._ConnectionFairy.invalidate` or
+ :meth:`_engine.Connection.invalidate` methods are called,
+ as well as when any
+ so-called "automatic invalidation" condition occurs.
+
+ :param e: an exception object indicating a reason for the
+ invalidation.
+
+ :param soft: if True, the connection isn't closed; instead, this
+ connection will be recycled on next checkout.
+
+ .. versionadded:: 1.0.3
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
+ """
+ # already invalidated
+ if self.dbapi_connection is None:
+ return
+ if soft:
+ self.__pool.dispatch.soft_invalidate(
+ self.dbapi_connection, self, e
+ )
+ else:
+ self.__pool.dispatch.invalidate(self.dbapi_connection, self, e)
+ if e is not None:
+ self.__pool.logger.info(
+ "%sInvalidate connection %r (reason: %s:%s)",
+ "Soft " if soft else "",
+ self.dbapi_connection,
+ e.__class__.__name__,
+ e,
+ )
+ else:
+ self.__pool.logger.info(
+ "%sInvalidate connection %r",
+ "Soft " if soft else "",
+ self.dbapi_connection,
+ )
+
+ if soft:
+ self._soft_invalidate_time = time.time()
+ else:
+ self.__close()
+ self.dbapi_connection = None
+
+ def get_connection(self):
+ recycle = False
+
+ # NOTE: the various comparisons here are assuming that measurable time
+ # passes between these state changes. however, time.time() is not
+ # guaranteed to have sub-second precision. comparisons of
+ # "invalidation time" to "starttime" should perhaps use >= so that the
+ # state change can take place assuming no measurable time has passed,
+ # however this does not guarantee correct behavior here as if time
+ # continues to not pass, it will try to reconnect repeatedly until
+ # these timestamps diverge, so in that sense using > is safer. Per
+ # https://stackoverflow.com/a/1938096/34549, Windows time.time() may be
+ # within 16 milliseconds accuracy, so unit tests for connection
+ # invalidation need a sleep of at least this long between initial start
+ # time and invalidation for the logic below to work reliably.
+ if self.dbapi_connection is None:
+ self.info.clear()
+ self.__connect()
+ elif (
+ self.__pool._recycle > -1
+ and time.time() - self.starttime > self.__pool._recycle
+ ):
+ self.__pool.logger.info(
+ "Connection %r exceeded timeout; recycling",
+ self.dbapi_connection,
+ )
+ recycle = True
+ elif self.__pool._invalidate_time > self.starttime:
+ self.__pool.logger.info(
+ "Connection %r invalidated due to pool invalidation; "
+ + "recycling",
+ self.dbapi_connection,
+ )
+ recycle = True
+ elif self._soft_invalidate_time > self.starttime:
+ self.__pool.logger.info(
+ "Connection %r invalidated due to local soft invalidation; "
+ + "recycling",
+ self.dbapi_connection,
+ )
+ recycle = True
+
+ if recycle:
+ self.__close()
+ self.info.clear()
+
+ self.__connect()
+ return self.dbapi_connection
+
+ def _is_hard_or_soft_invalidated(self):
+ return (
+ self.dbapi_connection is None
+ or self.__pool._invalidate_time > self.starttime
+ or (self._soft_invalidate_time > self.starttime)
+ )
+
+ def __close(self):
+ self.finalize_callback.clear()
+ if self.__pool.dispatch.close:
+ self.__pool.dispatch.close(self.dbapi_connection, self)
+ self.__pool._close_connection(self.dbapi_connection)
+ self.dbapi_connection = None
+
+ def __connect(self):
+ pool = self.__pool
+
+ # ensure any existing connection is removed, so that if
+ # creator fails, this attribute stays None
+ self.dbapi_connection = None
+ try:
+ self.starttime = time.time()
+ self.dbapi_connection = connection = pool._invoke_creator(self)
+ pool.logger.debug("Created new connection %r", connection)
+ self.fresh = True
+ except Exception as e:
+ with util.safe_reraise():
+ pool.logger.debug("Error on connect(): %s", e)
+ else:
+ # in SQLAlchemy 1.4 the first_connect event is not used by
+ # the engine, so this will usually not be set
+ if pool.dispatch.first_connect:
+ pool.dispatch.first_connect.for_modify(
+ pool.dispatch
+ ).exec_once_unless_exception(self.dbapi_connection, self)
+
+ # init of the dialect now takes place within the connect
+ # event, so ensure a mutex is used on the first run
+ pool.dispatch.connect.for_modify(
+ pool.dispatch
+ )._exec_w_sync_on_first_run(self.dbapi_connection, self)
+
+
+def _finalize_fairy(
+ dbapi_connection,
+ connection_record,
+ pool,
+ ref, # this is None when called directly, not by the gc
+ echo,
+ reset=True,
+ fairy=None,
+):
+ """Cleanup for a :class:`._ConnectionFairy` whether or not it's already
+ been garbage collected.
+
+ When using an async dialect no IO can happen here (without using
+ a dedicated thread), since this is called outside the greenlet
+ context and with an already running loop. In this case the function
+ will only log a message and emit a warning.
+ """
+
+ if ref:
+ _strong_ref_connection_records.pop(ref, None)
+ elif fairy:
+ _strong_ref_connection_records.pop(weakref.ref(fairy), None)
+
+ if ref is not None:
+ if connection_record.fairy_ref is not ref:
+ return
+ assert dbapi_connection is None
+ dbapi_connection = connection_record.dbapi_connection
+
+ # null pool is not _is_asyncio but can be used also with async dialects
+ dont_restore_gced = pool._dialect.is_async
+
+ if dont_restore_gced:
+ detach = not connection_record or ref
+ can_manipulate_connection = not ref
+ else:
+ detach = not connection_record
+ can_manipulate_connection = True
+
+ if dbapi_connection is not None:
+ if connection_record and echo:
+ pool.logger.debug(
+ "Connection %r being returned to pool%s",
+ dbapi_connection,
+ ", transaction state was already reset by caller"
+ if not reset
+ else "",
+ )
+
+ try:
+ fairy = fairy or _ConnectionFairy(
+ dbapi_connection,
+ connection_record,
+ echo,
+ )
+ assert fairy.dbapi_connection is dbapi_connection
+ if reset and can_manipulate_connection:
+ fairy._reset(pool)
+
+ if detach:
+ if connection_record:
+ fairy._pool = pool
+ fairy.detach()
+
+ if can_manipulate_connection:
+ if pool.dispatch.close_detached:
+ pool.dispatch.close_detached(dbapi_connection)
+
+ pool._close_connection(dbapi_connection)
+ else:
+ message = (
+ "The garbage collector is trying to clean up "
+ "connection %r. This feature is unsupported on async "
+ "dbapi, since no IO can be performed at this stage to "
+ "reset the connection. Please close out all "
+ "connections when they are no longer used, calling "
+ "``close()`` or using a context manager to "
+ "manage their lifetime."
+ ) % dbapi_connection
+ pool.logger.error(message)
+ util.warn(message)
+
+ except BaseException as e:
+ pool.logger.error(
+ "Exception during reset or similar", exc_info=True
+ )
+ if connection_record:
+ connection_record.invalidate(e=e)
+ if not isinstance(e, Exception):
+ raise
+
+ if connection_record and connection_record.fairy_ref is not None:
+ connection_record.checkin()
+
+
+# a dictionary of the _ConnectionFairy weakrefs to _ConnectionRecord, so that
+# GC under pypy will call ConnectionFairy finalizers. linked directly to the
+# weakref that will empty itself when collected so that it should not create
+# any unmanaged memory references.
+_strong_ref_connection_records = {}
+
+
+class _ConnectionFairy(object):
+
+ """Proxies a DBAPI connection and provides return-on-dereference
+ support.
+
+ This is an internal object used by the :class:`_pool.Pool` implementation
+ to provide context management to a DBAPI connection delivered by
+ that :class:`_pool.Pool`.
+
+ The name "fairy" is inspired by the fact that the
+ :class:`._ConnectionFairy` object's lifespan is transitory, as it lasts
+ only for the length of a specific DBAPI connection being checked out from
+ the pool, and additionally that as a transparent proxy, it is mostly
+ invisible.
+
+ .. seealso::
+
+ :class:`._ConnectionRecord`
+
+ """
+
+ def __init__(self, dbapi_connection, connection_record, echo):
+ self.dbapi_connection = dbapi_connection
+ self._connection_record = connection_record
+ self._echo = echo
+
+ dbapi_connection = None
+ """A reference to the actual DBAPI connection being tracked.
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :attr:`._ConnectionFairy.driver_connection`
+
+ :attr:`._ConnectionRecord.dbapi_connection`
+
+ :ref:`faq_dbapi_connection`
+
+ """
+
+ _connection_record = None
+ """A reference to the :class:`._ConnectionRecord` object associated
+ with the DBAPI connection.
+
+ This is currently an internal accessor which is subject to change.
+
+ """
+
+ @property
+ def driver_connection(self):
+ """The connection object as returned by the driver after a connect.
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :attr:`._ConnectionFairy.dbapi_connection`
+
+ :attr:`._ConnectionRecord.driver_connection`
+
+ :ref:`faq_dbapi_connection`
+
+ """
+ return self._connection_record.driver_connection
+
+ @property
+ def connection(self):
+ """An alias to :attr:`._ConnectionFairy.dbapi_connection`.
+
+ This alias is deprecated, please use the new name.
+
+ .. deprecated:: 1.4.24
+
+ """
+ return self.dbapi_connection
+
+ @connection.setter
+ def connection(self, value):
+ self.dbapi_connection = value
+
+ @classmethod
+ def _checkout(cls, pool, threadconns=None, fairy=None):
+ if not fairy:
+ fairy = _ConnectionRecord.checkout(pool)
+
+ fairy._pool = pool
+ fairy._counter = 0
+
+ if threadconns is not None:
+ threadconns.current = weakref.ref(fairy)
+
+ if fairy.dbapi_connection is None:
+ raise exc.InvalidRequestError("This connection is closed")
+ fairy._counter += 1
+ if (
+ not pool.dispatch.checkout and not pool._pre_ping
+ ) or fairy._counter != 1:
+ return fairy
+
+ # Pool listeners can trigger a reconnection on checkout, as well
+ # as the pre-pinger.
+ # there are three attempts made here, but note that if the database
+ # is not accessible from a connection standpoint, those won't proceed
+ # here.
+ attempts = 2
+ while attempts > 0:
+ connection_is_fresh = fairy._connection_record.fresh
+ fairy._connection_record.fresh = False
+ try:
+ if pool._pre_ping:
+ if not connection_is_fresh:
+ if fairy._echo:
+ pool.logger.debug(
+ "Pool pre-ping on connection %s",
+ fairy.dbapi_connection,
+ )
+ result = pool._dialect.do_ping(fairy.dbapi_connection)
+ if not result:
+ if fairy._echo:
+ pool.logger.debug(
+ "Pool pre-ping on connection %s failed, "
+ "will invalidate pool",
+ fairy.dbapi_connection,
+ )
+ raise exc.InvalidatePoolError()
+ elif fairy._echo:
+ pool.logger.debug(
+ "Connection %s is fresh, skipping pre-ping",
+ fairy.dbapi_connection,
+ )
+
+ pool.dispatch.checkout(
+ fairy.dbapi_connection, fairy._connection_record, fairy
+ )
+ return fairy
+ except exc.DisconnectionError as e:
+ if e.invalidate_pool:
+ pool.logger.info(
+ "Disconnection detected on checkout, "
+ "invalidating all pooled connections prior to "
+ "current timestamp (reason: %r)",
+ e,
+ )
+ fairy._connection_record.invalidate(e)
+ pool._invalidate(fairy, e, _checkin=False)
+ else:
+ pool.logger.info(
+ "Disconnection detected on checkout, "
+ "invalidating individual connection %s (reason: %r)",
+ fairy.dbapi_connection,
+ e,
+ )
+ fairy._connection_record.invalidate(e)
+ try:
+ fairy.dbapi_connection = (
+ fairy._connection_record.get_connection()
+ )
+ except Exception as err:
+ with util.safe_reraise():
+ fairy._connection_record._checkin_failed(
+ err,
+ _fairy_was_created=True,
+ )
+
+ # prevent _ConnectionFairy from being carried
+ # in the stack trace. Do this after the
+ # connection record has been checked in, so that
+ # if the del triggers a finalize fairy, it won't
+ # try to checkin a second time.
+ del fairy
+
+ attempts -= 1
+
+ pool.logger.info("Reconnection attempts exhausted on checkout")
+ fairy.invalidate()
+ raise exc.InvalidRequestError("This connection is closed")
+
+ def _checkout_existing(self):
+ return _ConnectionFairy._checkout(self._pool, fairy=self)
+
+ def _checkin(self, reset=True):
+ _finalize_fairy(
+ self.dbapi_connection,
+ self._connection_record,
+ self._pool,
+ None,
+ self._echo,
+ reset=reset,
+ fairy=self,
+ )
+ self.dbapi_connection = None
+ self._connection_record = None
+
+ _close = _checkin
+
+ def _reset(self, pool):
+ if pool.dispatch.reset:
+ pool.dispatch.reset(self, self._connection_record)
+ if pool._reset_on_return is reset_rollback:
+ if self._echo:
+ pool.logger.debug(
+ "Connection %s rollback-on-return", self.dbapi_connection
+ )
+ pool._dialect.do_rollback(self)
+ elif pool._reset_on_return is reset_commit:
+ if self._echo:
+ pool.logger.debug(
+ "Connection %s commit-on-return",
+ self.dbapi_connection,
+ )
+ pool._dialect.do_commit(self)
+
+ @property
+ def _logger(self):
+ return self._pool.logger
+
+ @property
+ def is_valid(self):
+ """Return True if this :class:`._ConnectionFairy` still refers
+ to an active DBAPI connection."""
+
+ return self.dbapi_connection is not None
+
+ @util.memoized_property
+ def info(self):
+ """Info dictionary associated with the underlying DBAPI connection
+ referred to by this :class:`.ConnectionFairy`, allowing user-defined
+ data to be associated with the connection.
+
+ The data here will follow along with the DBAPI connection including
+ after it is returned to the connection pool and used again
+ in subsequent instances of :class:`._ConnectionFairy`. It is shared
+ with the :attr:`._ConnectionRecord.info` and
+ :attr:`_engine.Connection.info`
+ accessors.
+
+ The dictionary associated with a particular DBAPI connection is
+ discarded when the connection itself is discarded.
+
+ """
+ return self._connection_record.info
+
+ @property
+ def record_info(self):
+ """Info dictionary associated with the :class:`._ConnectionRecord
+ container referred to by this :class:`.ConnectionFairy`.
+
+ Unlike the :attr:`._ConnectionFairy.info` dictionary, the lifespan
+ of this dictionary is persistent across connections that are
+ disconnected and/or invalidated within the lifespan of a
+ :class:`._ConnectionRecord`.
+
+ .. versionadded:: 1.1
+
+ """
+ if self._connection_record:
+ return self._connection_record.record_info
+ else:
+ return None
+
+ def invalidate(self, e=None, soft=False):
+ """Mark this connection as invalidated.
+
+ This method can be called directly, and is also called as a result
+ of the :meth:`_engine.Connection.invalidate` method. When invoked,
+ the DBAPI connection is immediately closed and discarded from
+ further use by the pool. The invalidation mechanism proceeds
+ via the :meth:`._ConnectionRecord.invalidate` internal method.
+
+ :param e: an exception object indicating a reason for the invalidation.
+
+ :param soft: if True, the connection isn't closed; instead, this
+ connection will be recycled on next checkout.
+
+ .. versionadded:: 1.0.3
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
+ """
+
+ if self.dbapi_connection is None:
+ util.warn("Can't invalidate an already-closed connection.")
+ return
+ if self._connection_record:
+ self._connection_record.invalidate(e=e, soft=soft)
+ if not soft:
+ self.dbapi_connection = None
+ self._checkin()
+
+ def cursor(self, *args, **kwargs):
+ """Return a new DBAPI cursor for the underlying connection.
+
+ This method is a proxy for the ``connection.cursor()`` DBAPI
+ method.
+
+ """
+ return self.dbapi_connection.cursor(*args, **kwargs)
+
+ def __getattr__(self, key):
+ return getattr(self.dbapi_connection, key)
+
+ def detach(self):
+ """Separate this connection from its Pool.
+
+ This means that the connection will no longer be returned to the
+ pool when closed, and will instead be literally closed. The
+ containing ConnectionRecord is separated from the DB-API connection,
+ and will create a new connection when next used.
+
+ Note that any overall connection limiting constraints imposed by a
+ Pool implementation may be violated after a detach, as the detached
+ connection is removed from the pool's knowledge and control.
+ """
+
+ if self._connection_record is not None:
+ rec = self._connection_record
+ rec.fairy_ref = None
+ rec.dbapi_connection = None
+ # TODO: should this be _return_conn?
+ self._pool._do_return_conn(self._connection_record)
+ self.info = self.info.copy()
+ self._connection_record = None
+
+ if self._pool.dispatch.detach:
+ self._pool.dispatch.detach(self.dbapi_connection, rec)
+
+ def close(self):
+ self._counter -= 1
+ if self._counter == 0:
+ self._checkin()
+
+ def _close_no_reset(self):
+ self._counter -= 1
+ if self._counter == 0:
+ self._checkin(reset=False)
diff --git a/lib/sqlalchemy/pool/dbapi_proxy.py b/lib/sqlalchemy/pool/dbapi_proxy.py
new file mode 100644
index 0000000..b0c40f2
--- /dev/null
+++ b/lib/sqlalchemy/pool/dbapi_proxy.py
@@ -0,0 +1,147 @@
+# sqlalchemy/pool/dbapi_proxy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""DBAPI proxy utility.
+
+Provides transparent connection pooling on top of a Python DBAPI.
+
+This is legacy SQLAlchemy functionality that is not typically used
+today.
+
+"""
+
+from .impl import QueuePool
+from .. import util
+from ..util import threading
+
+proxies = {}
+
+
+@util.deprecated(
+ "1.3",
+ "The :func:`.pool.manage` function is deprecated, and will be "
+ "removed in a future release.",
+)
+def manage(module, **params):
+ r"""Return a proxy for a DB-API module that automatically
+ pools connections.
+
+ Given a DB-API 2.0 module and pool management parameters, returns
+ a proxy for the module that will automatically pool connections,
+ creating new connection pools for each distinct set of connection
+ arguments sent to the decorated module's connect() function.
+
+ :param module: a DB-API 2.0 database module
+
+ :param poolclass: the class used by the pool module to provide
+ pooling. Defaults to :class:`.QueuePool`.
+
+ :param \**params: will be passed through to *poolclass*
+
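+ An illustrative sketch of this (deprecated) usage, assuming the
+ stdlib ``sqlite3`` DBAPI::
+
+     import sqlite3
+
+     from sqlalchemy import pool
+
+     proxied = pool.manage(sqlite3, poolclass=pool.QueuePool)
+     conn = proxied.connect("file.db")
+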
+ """
+ try:
+ return proxies[module]
+ except KeyError:
+ return proxies.setdefault(module, _DBProxy(module, **params))
+
+
+def clear_managers():
+ """Remove all current DB-API 2.0 managers.
+
+ All pools and connections are disposed.
+ """
+
+ for manager in proxies.values():
+ manager.close()
+ proxies.clear()
+
+
+class _DBProxy(object):
+
+ """Layers connection pooling behavior on top of a standard DB-API module.
+
+ Proxies a DB-API 2.0 connect() call to a connection pool keyed to the
+ specific connect parameters. Other functions and attributes are delegated
+ to the underlying DB-API module.
+ """
+
+ def __init__(self, module, poolclass=QueuePool, **kw):
+ """Initializes a new proxy.
+
+ module
+ a DB-API 2.0 module
+
+ poolclass
+ a Pool class, defaulting to QueuePool
+
+ Other parameters are sent to the Pool object's constructor.
+
+ """
+
+ self.module = module
+ self.kw = kw
+ self.poolclass = poolclass
+ self.pools = {}
+ self._create_pool_mutex = threading.Lock()
+
+ def close(self):
+ for key in list(self.pools):
+ del self.pools[key]
+
+ def __del__(self):
+ self.close()
+
+ def __getattr__(self, key):
+ return getattr(self.module, key)
+
+ def get_pool(self, *args, **kw):
+ key = self._serialize(*args, **kw)
+ try:
+ return self.pools[key]
+ except KeyError:
+ with self._create_pool_mutex:
+ if key not in self.pools:
+ kw.pop("sa_pool_key", None)
+ pool = self.poolclass(
+ lambda: self.module.connect(*args, **kw), **self.kw
+ )
+ self.pools[key] = pool
+ return pool
+ else:
+ return self.pools[key]
+
+ def connect(self, *args, **kw):
+ """Activate a connection to the database.
+
+ Connect to the database using this DBProxy's module and the given
+ connect arguments. If the arguments match an existing pool, the
+ connection will be returned from the pool's current thread-local
+ connection instance, or, if there is no thread-local connection
+ instance, it will be checked out from the set of pooled connections.
+
+ If the pool has no available connections and allows new connections
+ to be created, a new database connection will be made.
+
+ """
+
+ return self.get_pool(*args, **kw).connect()
+
+ def dispose(self, *args, **kw):
+ """Dispose the pool referenced by the given connect arguments."""
+
+ key = self._serialize(*args, **kw)
+ try:
+ del self.pools[key]
+ except KeyError:
+ pass
+
+ def _serialize(self, *args, **kw):
+ if "sa_pool_key" in kw:
+ return kw["sa_pool_key"]
+
+ return tuple(list(args) + [(k, kw[k]) for k in sorted(kw)])
diff --git a/lib/sqlalchemy/pool/events.py b/lib/sqlalchemy/pool/events.py
new file mode 100644
index 0000000..2829a58
--- /dev/null
+++ b/lib/sqlalchemy/pool/events.py
@@ -0,0 +1,284 @@
+# sqlalchemy/pool/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .base import Pool
+from .. import event
+from ..engine.base import Engine
+
+
+class PoolEvents(event.Events):
+ """Available events for :class:`_pool.Pool`.
+
+ The methods here define the name of an event as well
+ as the names of members that are passed to listener
+ functions.
+
+ e.g.::
+
+ from sqlalchemy import event
+
+ def my_on_checkout(dbapi_conn, connection_rec, connection_proxy):
+ "handle an on checkout event"
+
+ event.listen(Pool, 'checkout', my_on_checkout)
+
+ In addition to accepting the :class:`_pool.Pool` class and
+ :class:`_pool.Pool` instances, :class:`_events.PoolEvents` also accepts
+ :class:`_engine.Engine` objects and the :class:`_engine.Engine` class as
+ targets, which will be resolved to the ``.pool`` attribute of the
+ given engine or the :class:`_pool.Pool` class::
+
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+
+ # will associate with engine.pool
+ event.listen(engine, 'checkout', my_on_checkout)
+
+ """
+
+ _target_class_doc = "SomeEngineOrPool"
+ _dispatch_target = Pool
+
+ @classmethod
+ def _accept_with(cls, target):
+ if isinstance(target, type):
+ if issubclass(target, Engine):
+ return Pool
+ elif issubclass(target, Pool):
+ return target
+ elif isinstance(target, Engine):
+ return target.pool
+ elif isinstance(target, Pool):
+ return target
+ elif hasattr(target, "dispatch") and hasattr(
+ target.dispatch._events, "_no_async_engine_events"
+ ):
+ target.dispatch._events._no_async_engine_events()
+ else:
+ return None
+
+ @classmethod
+ def _listen(cls, event_key, **kw):
+ target = event_key.dispatch_target
+
+ kw.setdefault("asyncio", target._is_asyncio)
+
+ event_key.base_listen(**kw)
+
+ def connect(self, dbapi_connection, connection_record):
+ """Called at the moment a particular DBAPI connection is first
+ created for a given :class:`_pool.Pool`.
+
+ This event allows one to capture the point directly after which
+ the DBAPI module-level ``.connect()`` method has been used in order
+ to produce a new DBAPI connection.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def first_connect(self, dbapi_connection, connection_record):
+ """Called exactly once for the first time a DBAPI connection is
+ checked out from a particular :class:`_pool.Pool`.
+
+ The rationale for :meth:`_events.PoolEvents.first_connect`
+ is to determine
+ information about a particular series of database connections based
+ on the settings used for all connections. Since a particular
+ :class:`_pool.Pool`
+ refers to a single "creator" function (which in terms
+ of a :class:`_engine.Engine`
+ refers to the URL and connection options used),
+ it is typically valid to make observations about a single connection
+ that can be safely assumed to be valid about all subsequent
+ connections, such as the database version, the server and client
+ encoding settings, collation settings, and many others.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def checkout(self, dbapi_connection, connection_record, connection_proxy):
+ """Called when a connection is retrieved from the Pool.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ :param connection_proxy: the :class:`._ConnectionFairy` object which
+ will proxy the public interface of the DBAPI connection for the
+ lifespan of the checkout.
+
+ If you raise a :class:`~sqlalchemy.exc.DisconnectionError`, the current
+ connection will be disposed and a fresh connection retrieved.
+ Processing of all checkout listeners will abort and restart
+ using the new connection.
+
+ .. seealso:: :meth:`_events.ConnectionEvents.engine_connect`
+ - a similar event
+ which occurs upon creation of a new :class:`_engine.Connection`.
+
+ """
+
+ def checkin(self, dbapi_connection, connection_record):
+ """Called when a connection returns to the pool.
+
+ Note that the connection may be closed, and may be None if the
+ connection has been invalidated. ``checkin`` will not be called
+ for detached connections. (They do not return to the pool.)
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def reset(self, dbapi_connection, connection_record):
+ """Called before the "reset" action occurs for a pooled connection.
+
+ This event is emitted
+ when the ``rollback()`` method is called on the DBAPI connection
+ before it is returned to the pool. The behavior of "reset" can
+ be controlled, including being disabled, using the ``reset_on_return``
+ pool argument.
+
+ The :meth:`_events.PoolEvents.reset` event is usually followed by the
+ :meth:`_events.PoolEvents.checkin` event, except in those
+ cases where the connection is discarded immediately after reset.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ .. seealso::
+
+ :meth:`_events.ConnectionEvents.rollback`
+
+ :meth:`_events.ConnectionEvents.commit`
+
+ """
+
+ def invalidate(self, dbapi_connection, connection_record, exception):
+ """Called when a DBAPI connection is to be "invalidated".
+
+ This event is called any time the :meth:`._ConnectionRecord.invalidate`
+ method is invoked, either from API usage or via "auto-invalidation",
+ without the ``soft`` flag.
+
+ The event occurs before a final attempt to call ``.close()`` on the
+ connection occurs.
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ :param exception: the exception object corresponding to the reason
+ for this invalidation, if any. May be ``None``.
+
+ .. versionadded:: 0.9.2 Added support for connection invalidation
+ listening.
+
+ .. seealso::
+
+ :ref:`pool_connection_invalidation`
+
+ """
+
+ def soft_invalidate(self, dbapi_connection, connection_record, exception):
+ """Called when a DBAPI connection is to be "soft invalidated".
+
+ This event is called any time the :meth:`._ConnectionRecord.invalidate`
+ method is invoked with the ``soft`` flag.
+
+ Soft invalidation refers to when the connection record that tracks
+ this connection will force a reconnect after the current connection
+ is checked in. It does not actively close the dbapi_connection
+ at the point at which it is called.
+
+ .. versionadded:: 1.0.3
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ :param exception: the exception object corresponding to the reason
+ for this invalidation, if any. May be ``None``.
+
+ """
+
+ def close(self, dbapi_connection, connection_record):
+ """Called when a DBAPI connection is closed.
+
+ The event is emitted before the close occurs.
+
+ The close of a connection can fail; typically this is because
+ the connection is already closed. If the close operation fails,
+ the connection is discarded.
+
+ The :meth:`.close` event corresponds to a connection that's still
+ associated with the pool. To intercept close events for detached
+ connections use :meth:`.close_detached`.
+
+ .. versionadded:: 1.1
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def detach(self, dbapi_connection, connection_record):
+ """Called when a DBAPI connection is "detached" from a pool.
+
+ This event is emitted after the detach occurs. The connection
+ is no longer associated with the given connection record.
+
+ .. versionadded:: 1.1
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ :param connection_record: the :class:`._ConnectionRecord` managing the
+ DBAPI connection.
+
+ """
+
+ def close_detached(self, dbapi_connection):
+ """Called when a detached DBAPI connection is closed.
+
+ The event is emitted before the close occurs.
+
+ The close of a connection can fail; typically this is because
+ the connection is already closed. If the close operation fails,
+ the connection is discarded.
+
+ .. versionadded:: 1.1
+
+ :param dbapi_connection: a DBAPI connection.
+ The :attr:`._ConnectionRecord.dbapi_connection` attribute.
+
+ """
diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py
new file mode 100644
index 0000000..91d0290
--- /dev/null
+++ b/lib/sqlalchemy/pool/impl.py
@@ -0,0 +1,514 @@
+# sqlalchemy/pool.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+"""Pool implementation classes.
+
+"""
+
+import traceback
+import weakref
+
+from .base import _AsyncConnDialect
+from .base import _ConnectionFairy
+from .base import _ConnectionRecord
+from .base import Pool
+from .. import exc
+from .. import util
+from ..util import chop_traceback
+from ..util import queue as sqla_queue
+from ..util import threading
+
+
+class QueuePool(Pool):
+
+ """A :class:`_pool.Pool`
+ that imposes a limit on the number of open connections.
+
+ :class:`.QueuePool` is the default pooling implementation used for
+ all :class:`_engine.Engine` objects, unless the SQLite dialect is in use.
+
+ """
+
+ _is_asyncio = False
+ _queue_class = sqla_queue.Queue
+
+ def __init__(
+ self,
+ creator,
+ pool_size=5,
+ max_overflow=10,
+ timeout=30.0,
+ use_lifo=False,
+ **kw
+ ):
+ r"""
+ Construct a QueuePool.
+
+ :param creator: a callable function that returns a DB-API
+ connection object, same as that of :paramref:`_pool.Pool.creator`.
+
+ :param pool_size: The size of the pool to be maintained,
+ defaults to 5. This is the largest number of connections that
+ will be kept persistently in the pool. Note that the pool
+ begins with no connections; once this number of connections
+ is requested, that number of connections will remain.
+ ``pool_size`` can be set to 0 to indicate no size limit; to
+ disable pooling, use a :class:`~sqlalchemy.pool.NullPool`
+ instead.
+
+ :param max_overflow: The maximum overflow size of the
+ pool. When the number of checked-out connections reaches the
+ size set in pool_size, additional connections will be
+ returned up to this limit. When those additional connections
+ are returned to the pool, they are disconnected and
+ discarded. It follows then that the total number of
+ simultaneous connections the pool will allow is pool_size +
+ `max_overflow`, and the total number of "sleeping"
+ connections the pool will allow is pool_size. `max_overflow`
+ can be set to -1 to indicate no overflow limit; no limit
+ will be placed on the total number of concurrent
+ connections. Defaults to 10.
+
+ :param timeout: The number of seconds to wait before giving up
+ on returning a connection. Defaults to 30.0. This can be a float
+ but is subject to the limitations of Python time functions which
+ may not be reliable in the tens of milliseconds.
+
+ :param use_lifo: use LIFO (last-in-first-out) when retrieving
+ connections instead of FIFO (first-in-first-out). Using LIFO, a
+ server-side timeout scheme can reduce the number of connections used
+ during non-peak periods of use. When planning for server-side
+ timeouts, ensure that a recycle or pre-ping strategy is in use to
+ gracefully handle stale connections.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`pool_use_lifo`
+
+ :ref:`pool_disconnects`
+
+ :param \**kw: Other keyword arguments including
+ :paramref:`_pool.Pool.recycle`, :paramref:`_pool.Pool.echo`,
+ :paramref:`_pool.Pool.reset_on_return` and others are passed to the
+ :class:`_pool.Pool` constructor.
+
+ """
+ Pool.__init__(self, creator, **kw)
+ self._pool = self._queue_class(pool_size, use_lifo=use_lifo)
+ self._overflow = 0 - pool_size
+ self._max_overflow = max_overflow
+ self._timeout = timeout
+ self._overflow_lock = threading.Lock()
+
+ def _do_return_conn(self, conn):
+ try:
+ self._pool.put(conn, False)
+ except sqla_queue.Full:
+ try:
+ conn.close()
+ finally:
+ self._dec_overflow()
+
+ def _do_get(self):
+ use_overflow = self._max_overflow > -1
+
+ try:
+ wait = use_overflow and self._overflow >= self._max_overflow
+ return self._pool.get(wait, self._timeout)
+ except sqla_queue.Empty:
+ # don't do things inside of "except Empty", because when we say
+ # we timed out or can't connect and raise, Python 3 tells
+ # people the real error is queue.Empty which it isn't.
+ pass
+ if use_overflow and self._overflow >= self._max_overflow:
+ if not wait:
+ return self._do_get()
+ else:
+ raise exc.TimeoutError(
+ "QueuePool limit of size %d overflow %d reached, "
+ "connection timed out, timeout %0.2f"
+ % (self.size(), self.overflow(), self._timeout),
+ code="3o7r",
+ )
+
+ if self._inc_overflow():
+ try:
+ return self._create_connection()
+ except:
+ with util.safe_reraise():
+ self._dec_overflow()
+ else:
+ return self._do_get()
+
+ def _inc_overflow(self):
+ if self._max_overflow == -1:
+ self._overflow += 1
+ return True
+ with self._overflow_lock:
+ if self._overflow < self._max_overflow:
+ self._overflow += 1
+ return True
+ else:
+ return False
+
+ def _dec_overflow(self):
+ if self._max_overflow == -1:
+ self._overflow -= 1
+ return True
+ with self._overflow_lock:
+ self._overflow -= 1
+ return True
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ self._creator,
+ pool_size=self._pool.maxsize,
+ max_overflow=self._max_overflow,
+ pre_ping=self._pre_ping,
+ use_lifo=self._pool.use_lifo,
+ timeout=self._timeout,
+ recycle=self._recycle,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ reset_on_return=self._reset_on_return,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def dispose(self):
+ while True:
+ try:
+ conn = self._pool.get(False)
+ conn.close()
+ except sqla_queue.Empty:
+ break
+
+ self._overflow = 0 - self.size()
+ self.logger.info("Pool disposed. %s", self.status())
+
+ def status(self):
+ return (
+ "Pool size: %d Connections in pool: %d "
+ "Current Overflow: %d Current Checked out "
+ "connections: %d"
+ % (
+ self.size(),
+ self.checkedin(),
+ self.overflow(),
+ self.checkedout(),
+ )
+ )
+
+ def size(self):
+ return self._pool.maxsize
+
+ def timeout(self):
+ return self._timeout
+
+ def checkedin(self):
+ return self._pool.qsize()
+
+ def overflow(self):
+ return self._overflow
+
+ def checkedout(self):
+ return self._pool.maxsize - self._pool.qsize() + self._overflow
+
+
+class AsyncAdaptedQueuePool(QueuePool):
+ _is_asyncio = True
+ _queue_class = sqla_queue.AsyncAdaptedQueue
+ _dialect = _AsyncConnDialect()
+
+
+class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool):
+ _queue_class = sqla_queue.FallbackAsyncAdaptedQueue
+
+
+class NullPool(Pool):
+
+ """A Pool which does not pool connections.
+
+ Instead it literally opens and closes the underlying DB-API connection
+ per each connection open/close.
+
+ Reconnect-related functions such as ``recycle`` and connection
+ invalidation are not supported by this Pool implementation, since
+ no connections are held persistently.
+
+ """
+
+ def status(self):
+ return "NullPool"
+
+ def _do_return_conn(self, conn):
+ conn.close()
+
+ def _do_get(self):
+ return self._create_connection()
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+
+ return self.__class__(
+ self._creator,
+ recycle=self._recycle,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ reset_on_return=self._reset_on_return,
+ pre_ping=self._pre_ping,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def dispose(self):
+ pass
+
+
+class SingletonThreadPool(Pool):
+
+ """A Pool that maintains one connection per thread.
+
+ Maintains one connection per each thread, never moving a connection to a
+ thread other than the one which it was created in.
+
+ .. warning:: the :class:`.SingletonThreadPool` will call ``.close()``
+ on arbitrary connections that exist beyond the size setting of
+ ``pool_size``, e.g. if more unique **thread identities**
+ than what ``pool_size`` states are used. This cleanup is
+ non-deterministic and not sensitive to whether or not the connections
+ linked to those thread identities are currently in use.
+
+ :class:`.SingletonThreadPool` may be improved in a future release;
+ however, in its current state it is generally used only for test
+ scenarios using a SQLite ``:memory:`` database and is not recommended
+ for production use.
+
+
+ Options are the same as those of :class:`_pool.Pool`, as well as:
+
+ :param pool_size: The number of threads in which to maintain connections
+ at once. Defaults to five.
+
+ :class:`.SingletonThreadPool` is used by the SQLite dialect
+ automatically when a memory-based database is used.
+ See :ref:`sqlite_toplevel`.
+
+ """
+
+ _is_asyncio = False
+
+ def __init__(self, creator, pool_size=5, **kw):
+ Pool.__init__(self, creator, **kw)
+ self._conn = threading.local()
+ self._fairy = threading.local()
+ self._all_conns = set()
+ self.size = pool_size
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ self._creator,
+ pool_size=self.size,
+ recycle=self._recycle,
+ echo=self.echo,
+ pre_ping=self._pre_ping,
+ logging_name=self._orig_logging_name,
+ reset_on_return=self._reset_on_return,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def dispose(self):
+ """Dispose of this pool."""
+
+ for conn in self._all_conns:
+ try:
+ conn.close()
+ except Exception:
+ # pysqlite won't even let you close a conn from a thread
+ # that didn't create it
+ pass
+
+ self._all_conns.clear()
+
+ def _cleanup(self):
+ while len(self._all_conns) >= self.size:
+ c = self._all_conns.pop()
+ c.close()
+
+ def status(self):
+ return "SingletonThreadPool id:%d size: %d" % (
+ id(self),
+ len(self._all_conns),
+ )
+
+ def _do_return_conn(self, conn):
+ pass
+
+ def _do_get(self):
+ try:
+ c = self._conn.current()
+ if c:
+ return c
+ except AttributeError:
+ pass
+ c = self._create_connection()
+ self._conn.current = weakref.ref(c)
+ if len(self._all_conns) >= self.size:
+ self._cleanup()
+ self._all_conns.add(c)
+ return c
+
+ def connect(self):
+ # vendored from Pool to include the now removed use_threadlocal
+ # behavior
+ try:
+ rec = self._fairy.current()
+ except AttributeError:
+ pass
+ else:
+ if rec is not None:
+ return rec._checkout_existing()
+
+ return _ConnectionFairy._checkout(self, self._fairy)
+
+ def _return_conn(self, record):
+ try:
+ del self._fairy.current
+ except AttributeError:
+ pass
+ self._do_return_conn(record)
+
+
+class StaticPool(Pool):
+
+ """A Pool of exactly one connection, used for all requests.
+
+ Reconnect-related functions such as ``recycle`` and connection
+ invalidation (which is also used to support auto-reconnect) are only
+ partially supported right now and may not yield good results.
+
+
+ """
+
+ @util.memoized_property
+ def connection(self):
+ return _ConnectionRecord(self)
+
+ def status(self):
+ return "StaticPool"
+
+ def dispose(self):
+ if (
+ "connection" in self.__dict__
+ and self.connection.dbapi_connection is not None
+ ):
+ self.connection.close()
+ del self.__dict__["connection"]
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ creator=self._creator,
+ recycle=self._recycle,
+ reset_on_return=self._reset_on_return,
+ pre_ping=self._pre_ping,
+ echo=self.echo,
+ logging_name=self._orig_logging_name,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def _transfer_from(self, other_static_pool):
+ # used by the test suite to make a new engine / pool without
+ # losing the state of an existing SQLite :memory: connection
+ self._invoke_creator = (
+ lambda crec: other_static_pool.connection.dbapi_connection
+ )
+
+ def _create_connection(self):
+ raise NotImplementedError()
+
+ def _do_return_conn(self, conn):
+ pass
+
+ def _do_get(self):
+ rec = self.connection
+ if rec._is_hard_or_soft_invalidated():
+ del self.__dict__["connection"]
+ rec = self.connection
+
+ return rec
+
+
+class AssertionPool(Pool):
+
+ """A :class:`_pool.Pool` that allows at most one checked out connection at
+ any given time.
+
+ This will raise an exception if more than one connection is checked out
+ at a time. Useful for debugging code that is using more connections
+ than desired.
+
+ """
+
+ def __init__(self, *args, **kw):
+ self._conn = None
+ self._checked_out = False
+ self._store_traceback = kw.pop("store_traceback", True)
+ self._checkout_traceback = None
+ Pool.__init__(self, *args, **kw)
+
+ def status(self):
+ return "AssertionPool"
+
+ def _do_return_conn(self, conn):
+ if not self._checked_out:
+ raise AssertionError("connection is not checked out")
+ self._checked_out = False
+ assert conn is self._conn
+
+ def dispose(self):
+ self._checked_out = False
+ if self._conn:
+ self._conn.close()
+
+ def recreate(self):
+ self.logger.info("Pool recreating")
+ return self.__class__(
+ self._creator,
+ echo=self.echo,
+ pre_ping=self._pre_ping,
+ recycle=self._recycle,
+ reset_on_return=self._reset_on_return,
+ logging_name=self._orig_logging_name,
+ _dispatch=self.dispatch,
+ dialect=self._dialect,
+ )
+
+ def _do_get(self):
+ if self._checked_out:
+ if self._checkout_traceback:
+ suffix = " at:\n%s" % "".join(
+ chop_traceback(self._checkout_traceback)
+ )
+ else:
+ suffix = ""
+ raise AssertionError("connection is already checked out" + suffix)
+
+ if not self._conn:
+ self._conn = self._create_connection()
+
+ self._checked_out = True
+ if self._store_traceback:
+ self._checkout_traceback = traceback.format_stack()
+ return self._conn
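
A short sketch of the implementations above in isolation (assuming the stdlib ``sqlite3`` module as the DBAPI); the same classes are normally selected through ``create_engine(..., poolclass=...)``:

    import sqlite3
    from sqlalchemy import create_engine
    from sqlalchemy.pool import NullPool, QueuePool

    # direct construction around a plain DBAPI creator callable
    q = QueuePool(lambda: sqlite3.connect(":memory:"),
                  pool_size=5, max_overflow=10, timeout=30.0)
    conn = q.connect()     # checks a connection out (overflow rules apply)
    print(q.status())      # pool size / checked-in / overflow / checked-out
    conn.close()           # returns the connection to the pool
    q.dispose()

    # NullPool opens and closes a real DBAPI connection per checkout
    engine = create_engine("sqlite://", poolclass=NullPool)
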
diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py
new file mode 100644
index 0000000..e7f388f
--- /dev/null
+++ b/lib/sqlalchemy/processors.py
@@ -0,0 +1,176 @@
+# sqlalchemy/processors.py
+# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""defines generic type conversion functions, as used in bind and result
+processors.
+
+They all share one common characteristic: None is passed through unchanged.
+
+"""
+
+import codecs
+import datetime
+import re
+
+from . import util
+
+
+def str_to_datetime_processor_factory(regexp, type_):
+ rmatch = regexp.match
+ # Even on Python 2.6, datetime.strptime is both slower than this code
+ # and does not support microseconds.
+ has_named_groups = bool(regexp.groupindex)
+
+ def process(value):
+ if value is None:
+ return None
+ else:
+ try:
+ m = rmatch(value)
+ except TypeError as err:
+ util.raise_(
+ ValueError(
+ "Couldn't parse %s string '%r' "
+ "- value is not a string." % (type_.__name__, value)
+ ),
+ from_=err,
+ )
+ if m is None:
+ raise ValueError(
+ "Couldn't parse %s string: "
+ "'%s'" % (type_.__name__, value)
+ )
+ if has_named_groups:
+ groups = m.groupdict(0)
+ return type_(
+ **dict(
+ list(
+ zip(
+ iter(groups.keys()),
+ list(map(int, iter(groups.values()))),
+ )
+ )
+ )
+ )
+ else:
+ return type_(*list(map(int, m.groups(0))))
+
+ return process
+
+
+def py_fallback():
+ def to_unicode_processor_factory(encoding, errors=None):
+ decoder = codecs.getdecoder(encoding)
+
+ def process(value):
+ if value is None:
+ return None
+ else:
+ # decoder returns a tuple: (value, len). Simply dropping the
+ # len part is safe: it is done that way in the normal
+ # 'xx'.decode(encoding) code path.
+ return decoder(value, errors)[0]
+
+ return process
+
+ def to_conditional_unicode_processor_factory(encoding, errors=None):
+ decoder = codecs.getdecoder(encoding)
+
+ def process(value):
+ if value is None:
+ return None
+ elif isinstance(value, util.text_type):
+ return value
+ else:
+ # decoder returns a tuple: (value, len). Simply dropping the
+ # len part is safe: it is done that way in the normal
+ # 'xx'.decode(encoding) code path.
+ return decoder(value, errors)[0]
+
+ return process
+
+ def to_decimal_processor_factory(target_class, scale):
+ fstring = "%%.%df" % scale
+
+ def process(value):
+ if value is None:
+ return None
+ else:
+ return target_class(fstring % value)
+
+ return process
+
+ def to_float(value): # noqa
+ if value is None:
+ return None
+ else:
+ return float(value)
+
+ def to_str(value): # noqa
+ if value is None:
+ return None
+ else:
+ return str(value)
+
+ def int_to_boolean(value): # noqa
+ if value is None:
+ return None
+ else:
+ return bool(value)
+
+ DATETIME_RE = re.compile(
+ r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?"
+ )
+ TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?")
+ DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)")
+
+ str_to_datetime = str_to_datetime_processor_factory( # noqa
+ DATETIME_RE, datetime.datetime
+ )
+ str_to_time = str_to_datetime_processor_factory( # noqa
+ TIME_RE, datetime.time
+ ) # noqa
+ str_to_date = str_to_datetime_processor_factory( # noqa
+ DATE_RE, datetime.date
+ ) # noqa
+ return locals()
+
+
+try:
+ from sqlalchemy.cprocessors import DecimalResultProcessor # noqa
+ from sqlalchemy.cprocessors import int_to_boolean # noqa
+ from sqlalchemy.cprocessors import str_to_date # noqa
+ from sqlalchemy.cprocessors import str_to_datetime # noqa
+ from sqlalchemy.cprocessors import str_to_time # noqa
+ from sqlalchemy.cprocessors import to_float # noqa
+ from sqlalchemy.cprocessors import to_str # noqa
+ from sqlalchemy.cprocessors import UnicodeResultProcessor # noqa
+
+ def to_unicode_processor_factory(encoding, errors=None):
+ if errors is not None:
+ return UnicodeResultProcessor(encoding, errors).process
+ else:
+ return UnicodeResultProcessor(encoding).process
+
+ def to_conditional_unicode_processor_factory(encoding, errors=None):
+ if errors is not None:
+ return UnicodeResultProcessor(encoding, errors).conditional_process
+ else:
+ return UnicodeResultProcessor(encoding).conditional_process
+
+ def to_decimal_processor_factory(target_class, scale):
+ # Note that the scale argument is not taken into account for integer
+ # values in the C implementation while it is in the Python one.
+ # For example, the Python implementation might return
+ # Decimal('5.00000') whereas the C implementation will
+ # return Decimal('5'). These are equivalent of course.
+ return DecimalResultProcessor(target_class, "%%.%df" % scale).process
+
+
+except ImportError:
+ globals().update(py_fallback())
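
As a small illustration of the processor factory above (a sketch assuming ``str_to_datetime_processor_factory`` is importable from ``sqlalchemy.processors`` as defined in this file):

    import datetime
    import re
    from sqlalchemy.processors import str_to_datetime_processor_factory

    # build a string-to-datetime result processor; None passes through
    DT_RE = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")
    to_datetime = str_to_datetime_processor_factory(DT_RE, datetime.datetime)

    print(to_datetime("2022-03-01 12:30:45.500000"))  # datetime(..., microsecond=500000)
    print(to_datetime(None))                          # None
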
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
new file mode 100644
index 0000000..61f82bb
--- /dev/null
+++ b/lib/sqlalchemy/schema.py
@@ -0,0 +1,59 @@
+# schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Compatibility namespace for sqlalchemy.sql.schema and related.
+
+"""
+
+from .sql.base import SchemaVisitor # noqa
+from .sql.ddl import _CreateDropBase # noqa
+from .sql.ddl import _DDLCompiles # noqa
+from .sql.ddl import _DropView # noqa
+from .sql.ddl import AddConstraint # noqa
+from .sql.ddl import CreateColumn # noqa
+from .sql.ddl import CreateIndex # noqa
+from .sql.ddl import CreateSchema # noqa
+from .sql.ddl import CreateSequence # noqa
+from .sql.ddl import CreateTable # noqa
+from .sql.ddl import DDL # noqa
+from .sql.ddl import DDLBase # noqa
+from .sql.ddl import DDLElement # noqa
+from .sql.ddl import DropColumnComment # noqa
+from .sql.ddl import DropConstraint # noqa
+from .sql.ddl import DropIndex # noqa
+from .sql.ddl import DropSchema # noqa
+from .sql.ddl import DropSequence # noqa
+from .sql.ddl import DropTable # noqa
+from .sql.ddl import DropTableComment # noqa
+from .sql.ddl import SetColumnComment # noqa
+from .sql.ddl import SetTableComment # noqa
+from .sql.ddl import sort_tables # noqa
+from .sql.ddl import sort_tables_and_constraints # noqa
+from .sql.naming import conv # noqa
+from .sql.schema import _get_table_key # noqa
+from .sql.schema import BLANK_SCHEMA # noqa
+from .sql.schema import CheckConstraint # noqa
+from .sql.schema import Column # noqa
+from .sql.schema import ColumnCollectionConstraint # noqa
+from .sql.schema import ColumnCollectionMixin # noqa
+from .sql.schema import ColumnDefault # noqa
+from .sql.schema import Computed # noqa
+from .sql.schema import Constraint # noqa
+from .sql.schema import DefaultClause # noqa
+from .sql.schema import DefaultGenerator # noqa
+from .sql.schema import FetchedValue # noqa
+from .sql.schema import ForeignKey # noqa
+from .sql.schema import ForeignKeyConstraint # noqa
+from .sql.schema import Identity # noqa
+from .sql.schema import Index # noqa
+from .sql.schema import MetaData # noqa
+from .sql.schema import PrimaryKeyConstraint # noqa
+from .sql.schema import SchemaItem # noqa
+from .sql.schema import Sequence # noqa
+from .sql.schema import Table # noqa
+from .sql.schema import ThreadLocalMetaData # noqa
+from .sql.schema import UniqueConstraint # noqa
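
A quick sanity check of the compatibility namespace (a sketch assuming a standard SQLAlchemy install): the names re-exported here are the same objects as their ``sqlalchemy.sql`` counterparts.

    from sqlalchemy import schema
    from sqlalchemy.sql.schema import Table
    from sqlalchemy.sql.ddl import CreateTable

    assert schema.Table is Table
    assert schema.CreateTable is CreateTable
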
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
new file mode 100644
index 0000000..2677441
--- /dev/null
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -0,0 +1,150 @@
+# sql/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .base import Executable
+from .compiler import COLLECT_CARTESIAN_PRODUCTS
+from .compiler import FROM_LINTING
+from .compiler import NO_LINTING
+from .compiler import WARN_LINTING
+from .expression import Alias
+from .expression import alias
+from .expression import all_
+from .expression import and_
+from .expression import any_
+from .expression import asc
+from .expression import between
+from .expression import bindparam
+from .expression import case
+from .expression import cast
+from .expression import ClauseElement
+from .expression import collate
+from .expression import column
+from .expression import ColumnCollection
+from .expression import ColumnElement
+from .expression import CompoundSelect
+from .expression import cte
+from .expression import Delete
+from .expression import delete
+from .expression import desc
+from .expression import distinct
+from .expression import except_
+from .expression import except_all
+from .expression import exists
+from .expression import extract
+from .expression import false
+from .expression import False_
+from .expression import FromClause
+from .expression import func
+from .expression import funcfilter
+from .expression import Insert
+from .expression import insert
+from .expression import intersect
+from .expression import intersect_all
+from .expression import Join
+from .expression import join
+from .expression import label
+from .expression import LABEL_STYLE_DEFAULT
+from .expression import LABEL_STYLE_DISAMBIGUATE_ONLY
+from .expression import LABEL_STYLE_NONE
+from .expression import LABEL_STYLE_TABLENAME_PLUS_COL
+from .expression import lambda_stmt
+from .expression import LambdaElement
+from .expression import lateral
+from .expression import literal
+from .expression import literal_column
+from .expression import modifier
+from .expression import not_
+from .expression import null
+from .expression import nulls_first
+from .expression import nulls_last
+from .expression import nullsfirst
+from .expression import nullslast
+from .expression import or_
+from .expression import outerjoin
+from .expression import outparam
+from .expression import over
+from .expression import quoted_name
+from .expression import Select
+from .expression import select
+from .expression import Selectable
+from .expression import StatementLambdaElement
+from .expression import Subquery
+from .expression import subquery
+from .expression import table
+from .expression import TableClause
+from .expression import TableSample
+from .expression import tablesample
+from .expression import text
+from .expression import true
+from .expression import True_
+from .expression import tuple_
+from .expression import type_coerce
+from .expression import union
+from .expression import union_all
+from .expression import Update
+from .expression import update
+from .expression import Values
+from .expression import values
+from .expression import within_group
+from .visitors import ClauseVisitor
+
+
+def __go(lcls):
+ global __all__
+ from .. import util as _sa_util
+
+ import inspect as _inspect
+
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
+
+ from .annotation import _prepare_annotations
+ from .annotation import Annotated
+ from .elements import AnnotatedColumnElement
+ from .elements import ClauseList
+ from .selectable import AnnotatedFromClause
+
+ # from .traversals import _preconfigure_traversals
+
+ from . import base
+ from . import coercions
+ from . import elements
+ from . import events
+ from . import lambdas
+ from . import selectable
+ from . import schema
+ from . import sqltypes
+ from . import traversals
+ from . import type_api
+
+ base.coercions = elements.coercions = coercions
+ base.elements = elements
+ base.type_api = type_api
+ coercions.elements = elements
+ coercions.lambdas = lambdas
+ coercions.schema = schema
+ coercions.selectable = selectable
+ coercions.sqltypes = sqltypes
+ coercions.traversals = traversals
+
+ _prepare_annotations(ColumnElement, AnnotatedColumnElement)
+ _prepare_annotations(FromClause, AnnotatedFromClause)
+ _prepare_annotations(ClauseList, Annotated)
+
+ # this is expensive at import time; elements that are used can create
+ # their traversals on demand
+ # _preconfigure_traversals(ClauseElement)
+
+ _sa_util.preloaded.import_prefix("sqlalchemy.sql")
+
+ from . import naming
+
+
+__go(locals())
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
new file mode 100644
index 0000000..5c000ed
--- /dev/null
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -0,0 +1,364 @@
+# sql/annotation.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The :class:`.Annotated` class and related routines; creates hash-equivalent
+copies of SQL constructs which contain context-specific markers and
+associations.
+
+"""
+
+from . import operators
+from .base import HasCacheKey
+from .traversals import anon_map
+from .visitors import InternalTraversal
+from .. import util
+
+EMPTY_ANNOTATIONS = util.immutabledict()
+
+
+class SupportsAnnotations(object):
+ _annotations = EMPTY_ANNOTATIONS
+
+ @util.memoized_property
+ def _annotations_cache_key(self):
+ anon_map_ = anon_map()
+ return (
+ "_annotations",
+ tuple(
+ (
+ key,
+ value._gen_cache_key(anon_map_, [])
+ if isinstance(value, HasCacheKey)
+ else value,
+ )
+ for key, value in [
+ (key, self._annotations[key])
+ for key in sorted(self._annotations)
+ ]
+ ),
+ )
+
+
+class SupportsCloneAnnotations(SupportsAnnotations):
+
+ _clone_annotations_traverse_internals = [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = new._annotations.union(values)
+ new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
+ return new
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = util.immutabledict(values)
+ new.__dict__.pop("_annotations_cache_key", None)
+ new.__dict__.pop("_generate_cache_key", None)
+ return new
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`_expression.ClauseElement`
+ with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone or self._annotations:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ new = self._clone()
+ new._annotations = util.immutabledict()
+ new.__dict__.pop("_annotations_cache_key", None)
+ return new
+ else:
+ return self
+
+
+class SupportsWrappingAnnotations(SupportsAnnotations):
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`_expression.ClauseElement`
+ with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone:
+ s = self._clone()
+ return s
+ else:
+ return self
+
+
+class Annotated(object):
+ """clones a SupportsAnnotated and applies an 'annotations' dictionary.
+
+ Unlike regular clones, this clone also mimics __hash__() and
+ __cmp__() of the original element so that it takes its place
+ in hashed collections.
+
+ A reference to the original element is maintained, for the important
+ reason of keeping its hash value current. When GC'ed, the
+ hash value may be reused, causing conflicts.
+
+ .. note:: The rationale for Annotated producing a brand new class,
+ rather than placing the functionality directly within ClauseElement,
+ is **performance**. The __hash__() method is absent on plain
+ ClauseElement which leads to significantly reduced function call
+ overhead, as the use of sets and dictionaries against ClauseElement
+ objects is prevalent, but most are not "annotated".
+
+ """
+
+ _is_column_operators = False
+
+ def __new__(cls, *args):
+ if not args:
+ # clone constructor
+ return object.__new__(cls)
+ else:
+ element, values = args
+ # pull appropriate subclass from registry of annotated
+ # classes
+ try:
+ cls = annotated_classes[element.__class__]
+ except KeyError:
+ cls = _new_annotation_type(element.__class__, cls)
+ return object.__new__(cls)
+
+ def __init__(self, element, values):
+ self.__dict__ = element.__dict__.copy()
+ self.__dict__.pop("_annotations_cache_key", None)
+ self.__dict__.pop("_generate_cache_key", None)
+ self.__element = element
+ self._annotations = util.immutabledict(values)
+ self._hash = hash(element)
+
+ def _annotate(self, values):
+ _values = self._annotations.union(values)
+ return self._with_annotations(_values)
+
+ def _with_annotations(self, values):
+ clone = self.__class__.__new__(self.__class__)
+ clone.__dict__ = self.__dict__.copy()
+ clone.__dict__.pop("_annotations_cache_key", None)
+ clone.__dict__.pop("_generate_cache_key", None)
+ clone._annotations = values
+ return clone
+
+ def _deannotate(self, values=None, clone=True):
+ if values is None:
+ return self.__element
+ else:
+ return self._with_annotations(
+ util.immutabledict(
+ {
+ key: value
+ for key, value in self._annotations.items()
+ if key not in values
+ }
+ )
+ )
+
+ def _compiler_dispatch(self, visitor, **kw):
+ return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
+
+ @property
+ def _constructor(self):
+ return self.__element._constructor
+
+ def _clone(self, **kw):
+ clone = self.__element._clone(**kw)
+ if clone is self.__element:
+ # detect immutable, don't change anything
+ return self
+ else:
+ # update the clone with any changes that have occurred
+ # to this object's __dict__.
+ clone.__dict__.update(self.__dict__)
+ return self.__class__(clone, self._annotations)
+
+ def __reduce__(self):
+ return self.__class__, (self.__element, self._annotations)
+
+ def __hash__(self):
+ return self._hash
+
+ def __eq__(self, other):
+ if self._is_column_operators:
+ return self.__element.__class__.__eq__(self, other)
+ else:
+ return hash(other) == hash(self)
+
+ @property
+ def entity_namespace(self):
+ if "entity_namespace" in self._annotations:
+ return self._annotations["entity_namespace"].entity_namespace
+ else:
+ return self.__element.entity_namespace
+
+
+# hard-generate Annotated subclasses. this technique
+# is used instead of on-the-fly types (i.e. type.__new__())
+# so that the resulting objects are pickleable; additionally, other
+# decisions can be made up front about the type of object being annotated
+# just once per class rather than per-instance.
+annotated_classes = {}
+
+
+def _deep_annotate(
+ element, annotations, exclude=None, detect_subquery_cols=False
+):
+ """Deep copy the given ClauseElement, annotating each element
+ with the given annotations dictionary.
+
+ Elements within the exclude collection will be cloned but not annotated.
+
+ """
+
+ # annotated objects hack the __hash__() method so if we want to
+ # uniquely process them we have to use id()
+
+ cloned_ids = {}
+
+ def clone(elem, **kw):
+ kw["detect_subquery_cols"] = detect_subquery_cols
+ id_ = id(elem)
+
+ if id_ in cloned_ids:
+ return cloned_ids[id_]
+
+ if (
+ exclude
+ and hasattr(elem, "proxy_set")
+ and elem.proxy_set.intersection(exclude)
+ ):
+ newelem = elem._clone(clone=clone, **kw)
+ elif annotations != elem._annotations:
+ if detect_subquery_cols and elem._is_immutable:
+ newelem = elem._clone(clone=clone, **kw)._annotate(annotations)
+ else:
+ newelem = elem._annotate(annotations)
+ else:
+ newelem = elem
+ newelem._copy_internals(clone=clone)
+ cloned_ids[id_] = newelem
+ return newelem
+
+ if element is not None:
+ element = clone(element)
+ clone = None # remove gc cycles
+ return element
+
+
+def _deep_deannotate(element, values=None):
+ """Deep copy the given element, removing annotations."""
+
+ cloned = {}
+
+ def clone(elem, **kw):
+ if values:
+ key = id(elem)
+ else:
+ key = elem
+
+ if key not in cloned:
+ newelem = elem._deannotate(values=values, clone=True)
+ newelem._copy_internals(clone=clone)
+ cloned[key] = newelem
+ return newelem
+ else:
+ return cloned[key]
+
+ if element is not None:
+ element = clone(element)
+ clone = None # remove gc cycles
+ return element
+
+
+def _shallow_annotate(element, annotations):
+ """Annotate the given ClauseElement and copy its internals so that
+ internal objects refer to the new annotated object.
+
+ Basically used to apply a "don't traverse" annotation to a
+ selectable, without digging throughout the whole
+ structure wasting time.
+ """
+ element = element._annotate(annotations)
+ element._copy_internals()
+ return element
+
+
+def _new_annotation_type(cls, base_cls):
+ if issubclass(cls, Annotated):
+ return cls
+ elif cls in annotated_classes:
+ return annotated_classes[cls]
+
+ for super_ in cls.__mro__:
+ # check if an Annotated subclass more specific than
+ # the given base_cls is already registered, such
+ # as AnnotatedColumnElement.
+ if super_ in annotated_classes:
+ base_cls = annotated_classes[super_]
+ break
+
+ annotated_classes[cls] = anno_cls = type(
+ "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ )
+ globals()["Annotated%s" % cls.__name__] = anno_cls
+
+ if "_traverse_internals" in cls.__dict__:
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+ elif cls.__dict__.get("inherit_cache", False):
+ anno_cls._traverse_internals = list(cls._traverse_internals) + [
+ ("_annotations", InternalTraversal.dp_annotations_key)
+ ]
+
+ # some classes include this even if they have traverse_internals
+ # e.g. BindParameter, add it if present.
+ if cls.__dict__.get("inherit_cache", False):
+ anno_cls.inherit_cache = True
+
+ anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
+
+ return anno_cls
+
+
+def _prepare_annotations(target_hierarchy, base_cls):
+ for cls in util.walk_subclasses(target_hierarchy):
+ _new_annotation_type(cls, base_cls)
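
An internal-API sketch of the annotation machinery above (the underscore-prefixed methods are private, and the key name ``"some_marker"`` is purely illustrative): annotated copies hash like the element they wrap, so they can stand in for it in sets and dictionaries.

    from sqlalchemy import Column, Integer

    col = Column("id", Integer)
    annotated = col._annotate({"some_marker": True})   # Annotated subclass instance

    assert hash(annotated) == hash(col)                # mimics the original's hash
    assert annotated._annotations["some_marker"] is True
    assert annotated._deannotate() is col              # returns the wrapped element
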
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
new file mode 100644
index 0000000..ec685d1
--- /dev/null
+++ b/lib/sqlalchemy/sql/base.py
@@ -0,0 +1,1702 @@
+# sql/base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Foundational utilities common to many sql modules.
+
+"""
+
+
+import itertools
+import operator
+import re
+
+from . import roles
+from . import visitors
+from .traversals import HasCacheKey # noqa
+from .traversals import HasCopyInternals # noqa
+from .traversals import MemoizedHasCacheKey # noqa
+from .visitors import ClauseVisitor
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import exc
+from .. import util
+from ..util import HasMemoized
+from ..util import hybridmethod
+
+
+coercions = None
+elements = None
+type_api = None
+
+PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
+NO_ARG = util.symbol("NO_ARG")
+
+
+class Immutable(object):
+ """mark a ClauseElement as 'immutable' when expressions are cloned."""
+
+ _is_immutable = True
+
+ def unique_params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
+ def params(self, *optionaldict, **kwargs):
+ raise NotImplementedError("Immutable objects do not support copying")
+
+ def _clone(self, **kw):
+ return self
+
+ def _copy_internals(self, **kw):
+ pass
+
+
+class SingletonConstant(Immutable):
+ """Represent SQL constants like NULL, TRUE, FALSE"""
+
+ _is_singleton_constant = True
+
+ def __new__(cls, *arg, **kw):
+ return cls._singleton
+
+ @classmethod
+ def _create_singleton(cls):
+ obj = object.__new__(cls)
+ obj.__init__()
+
+ # for a long time this was an empty frozenset, meaning
+ # a SingletonConstant would never be a "corresponding column" in
+ # a statement. This referred to #6259. However, in #7154 we see
+ # that we do in fact need "correspondence" to work when matching cols
+ # in result sets, so the non-correspondence was moved to a more
+ # specific level when we are actually adapting expressions for SQL
+ # render only.
+ obj.proxy_set = frozenset([obj])
+ cls._singleton = obj
+
+
+def _from_objects(*elements):
+ return itertools.chain.from_iterable(
+ [element._from_objects for element in elements]
+ )
+
+
+def _select_iterables(elements):
+ """expand tables into individual columns in the
+ given list of column expressions.
+
+ """
+ return itertools.chain.from_iterable(
+ [c._select_iterable for c in elements]
+ )
+
+
+def _generative(fn):
+ """non-caching _generative() decorator.
+
+ This is basically the legacy decorator that copies the object and
+ runs a method on the new copy.
+
+ """
+
+ @util.decorator
+ def _generative(fn, self, *args, **kw):
+ """Mark a method as generative."""
+
+ self = self._generate()
+ x = fn(self, *args, **kw)
+ assert x is None, "generative methods must have no return value"
+ return self
+
+ decorated = _generative(fn)
+ decorated.non_generative = fn
+ return decorated
+
+
+def _exclusive_against(*names, **kw):
+ msgs = kw.pop("msgs", {})
+
+ defaults = kw.pop("defaults", {})
+
+ getters = [
+ (name, operator.attrgetter(name), defaults.get(name, None))
+ for name in names
+ ]
+
+ @util.decorator
+ def check(fn, *args, **kw):
+ # make pylance happy by not including "self" in the argument
+ # list
+ self = args[0]
+ args = args[1:]
+ for name, getter, default_ in getters:
+ if getter(self) is not default_:
+ msg = msgs.get(
+ name,
+ "Method %s() has already been invoked on this %s construct"
+ % (fn.__name__, self.__class__),
+ )
+ raise exc.InvalidRequestError(msg)
+ return fn(self, *args, **kw)
+
+ return check
+
+
+def _clone(element, **kw):
+ return element._clone(**kw)
+
+
+def _expand_cloned(elements):
+ """expand the given set of ClauseElements to be the set of all 'cloned'
+ predecessors.
+
+ """
+ return itertools.chain(*[x._cloned_set for x in elements])
+
+
+def _cloned_intersection(a, b):
+ """return the intersection of sets a and b, counting
+ any overlap between 'cloned' predecessors.
+
+ The returned set is in terms of the entities present within 'a'.
+
+ """
+ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+ return set(
+ elem for elem in a if all_overlap.intersection(elem._cloned_set)
+ )
+
+
+def _cloned_difference(a, b):
+ all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
+ return set(
+ elem for elem in a if not all_overlap.intersection(elem._cloned_set)
+ )
+
+
+class _DialectArgView(util.collections_abc.MutableMapping):
+ """A dictionary view of dialect-level arguments in the form
+ <dialectname>_<argument_name>.
+
+ """
+
+ def __init__(self, obj):
+ self.obj = obj
+
+ def _key(self, key):
+ try:
+ dialect, value_key = key.split("_", 1)
+ except ValueError as err:
+ util.raise_(KeyError(key), replace_context=err)
+ else:
+ return dialect, value_key
+
+ def __getitem__(self, key):
+ dialect, value_key = self._key(key)
+
+ try:
+ opt = self.obj.dialect_options[dialect]
+ except exc.NoSuchModuleError as err:
+ util.raise_(KeyError(key), replace_context=err)
+ else:
+ return opt[value_key]
+
+ def __setitem__(self, key, value):
+ try:
+ dialect, value_key = self._key(key)
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Keys must be of the form <dialectname>_<argname>"
+ ),
+ replace_context=err,
+ )
+ else:
+ self.obj.dialect_options[dialect][value_key] = value
+
+ def __delitem__(self, key):
+ dialect, value_key = self._key(key)
+ del self.obj.dialect_options[dialect][value_key]
+
+ def __len__(self):
+ return sum(
+ len(args._non_defaults)
+ for args in self.obj.dialect_options.values()
+ )
+
+ def __iter__(self):
+ return (
+ "%s_%s" % (dialect_name, value_name)
+ for dialect_name in self.obj.dialect_options
+ for value_name in self.obj.dialect_options[
+ dialect_name
+ ]._non_defaults
+ )
+
+
+class _DialectArgDict(util.collections_abc.MutableMapping):
+ """A dictionary view of dialect-level arguments for a specific
+ dialect.
+
+ Maintains a separate collection of user-specified arguments
+ and dialect-specified default arguments.
+
+ """
+
+ def __init__(self):
+ self._non_defaults = {}
+ self._defaults = {}
+
+ def __len__(self):
+ return len(set(self._non_defaults).union(self._defaults))
+
+ def __iter__(self):
+ return iter(set(self._non_defaults).union(self._defaults))
+
+ def __getitem__(self, key):
+ if key in self._non_defaults:
+ return self._non_defaults[key]
+ else:
+ return self._defaults[key]
+
+ def __setitem__(self, key, value):
+ self._non_defaults[key] = value
+
+ def __delitem__(self, key):
+ del self._non_defaults[key]
+
+
+@util.preload_module("sqlalchemy.dialects")
+def _kw_reg_for_dialect(dialect_name):
+ dialect_cls = util.preloaded.dialects.registry.load(dialect_name)
+ if dialect_cls.construct_arguments is None:
+ return None
+ return dict(dialect_cls.construct_arguments)
+
+
+class DialectKWArgs(object):
+ """Establish the ability for a class to have dialect-specific arguments
+ with defaults and constructor validation.
+
+ The :class:`.DialectKWArgs` interacts with the
+ :attr:`.DefaultDialect.construct_arguments` present on a dialect.
+
+ .. seealso::
+
+ :attr:`.DefaultDialect.construct_arguments`
+
+ """
+
+ _dialect_kwargs_traverse_internals = [
+ ("dialect_options", InternalTraversal.dp_dialect_options)
+ ]
+
+ @classmethod
+ def argument_for(cls, dialect_name, argument_name, default):
+ """Add a new kind of dialect-specific keyword argument for this class.
+
+ E.g.::
+
+ Index.argument_for("mydialect", "length", None)
+
+ some_index = Index('a', 'b', mydialect_length=5)
+
+ The :meth:`.DialectKWArgs.argument_for` method is a per-argument
+ way of adding extra arguments to the
+ :attr:`.DefaultDialect.construct_arguments` dictionary. This
+ dictionary provides a list of argument names accepted by various
+ schema-level constructs on behalf of a dialect.
+
+ New dialects should typically specify this dictionary all at once as a
+ data member of the dialect class. The use case for ad-hoc addition of
+ argument names is typically for end-user code that is also using
+ a custom compilation scheme which consumes the additional arguments.
+
+ :param dialect_name: name of a dialect. The dialect must be
+ locatable, else a :class:`.NoSuchModuleError` is raised. The
+ dialect must also include an existing
+ :attr:`.DefaultDialect.construct_arguments` collection, indicating
+ that it participates in the keyword-argument validation and default
+ system, else :class:`.ArgumentError` is raised. If the dialect does
+ not include this collection, then any keyword argument can already be
+ specified on behalf of this dialect. All dialects packaged
+ within SQLAlchemy include this collection, however for third party
+ dialects, support may vary.
+
+ :param argument_name: name of the parameter.
+
+ :param default: default value of the parameter.
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
+ if construct_arg_dictionary is None:
+ raise exc.ArgumentError(
+ "Dialect '%s' does have keyword-argument "
+ "validation and defaults enabled configured" % dialect_name
+ )
+ if cls not in construct_arg_dictionary:
+ construct_arg_dictionary[cls] = {}
+ construct_arg_dictionary[cls][argument_name] = default
+
+ @util.memoized_property
+ def dialect_kwargs(self):
+ """A collection of keyword arguments specified as dialect-specific
+ options to this construct.
+
+ The arguments are present here in their original ``<dialect>_<kwarg>``
+ format. Only arguments that were actually passed are included;
+ unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
+ contains all options known by this dialect including defaults.
+
+ The collection is also writable; keys are accepted of the
+ form ``<dialect>_<kwarg>`` where the value will be assembled
+ into the list of options.
+
+ .. versionadded:: 0.9.2
+
+ .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
+ collection is now writable.
+
+ .. seealso::
+
+ :attr:`.DialectKWArgs.dialect_options` - nested dictionary form
+
+ """
+ return _DialectArgView(self)
+
+ @property
+ def kwargs(self):
+ """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
+ return self.dialect_kwargs
+
+ _kw_registry = util.PopulateDict(_kw_reg_for_dialect)
+
+ def _kw_reg_for_dialect_cls(self, dialect_name):
+ construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
+ d = _DialectArgDict()
+
+ if construct_arg_dictionary is None:
+ d._defaults.update({"*": None})
+ else:
+ for cls in reversed(self.__class__.__mro__):
+ if cls in construct_arg_dictionary:
+ d._defaults.update(construct_arg_dictionary[cls])
+ return d
+
+ @util.memoized_property
+ def dialect_options(self):
+ """A collection of keyword arguments specified as dialect-specific
+ options to this construct.
+
+ This is a two-level nested registry, keyed to ``<dialect_name>``
+ and ``<argument_name>``. For example, the ``postgresql_where``
+ argument would be locatable as::
+
+ arg = my_object.dialect_options['postgresql']['where']
+
+ .. versionadded:: 0.9.2
+
+ .. seealso::
+
+ :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form
+
+ """
+
+ return util.PopulateDict(
+ util.portable_instancemethod(self._kw_reg_for_dialect_cls)
+ )
+
+ def _validate_dialect_kwargs(self, kwargs):
+ # validate remaining kwargs that they all specify DB prefixes
+
+ if not kwargs:
+ return
+
+ for k in kwargs:
+ m = re.match("^(.+?)_(.+)$", k)
+ if not m:
+ raise TypeError(
+ "Additional arguments should be "
+ "named <dialectname>_<argument>, got '%s'" % k
+ )
+ dialect_name, arg_name = m.group(1, 2)
+
+ try:
+ construct_arg_dictionary = self.dialect_options[dialect_name]
+ except exc.NoSuchModuleError:
+ util.warn(
+ "Can't validate argument %r; can't "
+ "locate any SQLAlchemy dialect named %r"
+ % (k, dialect_name)
+ )
+ self.dialect_options[dialect_name] = d = _DialectArgDict()
+ d._defaults.update({"*": None})
+ d._non_defaults[arg_name] = kwargs[k]
+ else:
+ if (
+ "*" not in construct_arg_dictionary
+ and arg_name not in construct_arg_dictionary
+ ):
+ raise exc.ArgumentError(
+ "Argument %r is not accepted by "
+ "dialect %r on behalf of %r"
+ % (k, dialect_name, self.__class__)
+ )
+ else:
+ construct_arg_dictionary[arg_name] = kwargs[k]
+
+
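
A brief sketch of the ``<dialectname>_<argname>`` convention documented above, using the ``postgresql_where`` argument mentioned in the :attr:`.DialectKWArgs.dialect_options` docstring (this assumes only that the built-in postgresql dialect is importable; no DBAPI is required):

    from sqlalchemy import Column, Index, Integer, MetaData, Table

    metadata = MetaData()
    t = Table("t", metadata, Column("x", Integer))
    ix = Index("ix_t_x", t.c.x, postgresql_where=t.c.x > 5)

    print(dict(ix.dialect_kwargs))                    # flat form: {'postgresql_where': ...}
    print(ix.dialect_options["postgresql"]["where"])  # nested form, same expression
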
+class CompileState(object):
+ """Produces additional object state necessary for a statement to be
+ compiled.
+
+ the :class:`.CompileState` class is at the base of classes that assemble
+ state for a particular statement object that is then used by the
+ compiler. This process is essentially an extension of the process that
+ the SQLCompiler.visit_XYZ() method takes, however there is an emphasis
+ on converting raw user intent into more organized structures rather than
+ producing string output. The top-level :class:`.CompileState` for the
+ statement being executed is also accessible when the execution context
+ works with invoking the statement and collecting results.
+
+ The production of :class:`.CompileState` is specific to the compiler, such
+ as within the :meth:`.SQLCompiler.visit_insert`,
+ :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also
+ responsible for associating the :class:`.CompileState` with the
+ :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement,
+ i.e. the outermost SQL statement that's actually being executed.
+ There can be other :class:`.CompileState` objects that are not the
+ toplevel, such as when a SELECT subquery or CTE-nested
+ INSERT/UPDATE/DELETE is generated.
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ("statement",)
+
+ plugins = {}
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ # factory construction.
+
+ if statement._propagate_attrs:
+ plugin_name = statement._propagate_attrs.get(
+ "compile_state_plugin", "default"
+ )
+ klass = cls.plugins.get(
+ (plugin_name, statement._effective_plugin_target), None
+ )
+ if klass is None:
+ klass = cls.plugins[
+ ("default", statement._effective_plugin_target)
+ ]
+
+ else:
+ klass = cls.plugins[
+ ("default", statement._effective_plugin_target)
+ ]
+
+ if klass is cls:
+ return cls(statement, compiler, **kw)
+ else:
+ return klass.create_for_statement(statement, compiler, **kw)
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+ @classmethod
+ def get_plugin_class(cls, statement):
+ plugin_name = statement._propagate_attrs.get(
+ "compile_state_plugin", None
+ )
+
+ if plugin_name:
+ key = (plugin_name, statement._effective_plugin_target)
+ if key in cls.plugins:
+ return cls.plugins[key]
+
+        # there's no case where we call upon get_plugin_class() and want
+        # to get None back; there should always be a default.  Return that
+        # default if there was no plugin-specific class (e.g. Insert with
+        # "orm" plugin)
+ try:
+ return cls.plugins[("default", statement._effective_plugin_target)]
+ except KeyError:
+ return None
+
+ @classmethod
+ def _get_plugin_class_for_plugin(cls, statement, plugin_name):
+ try:
+ return cls.plugins[
+ (plugin_name, statement._effective_plugin_target)
+ ]
+ except KeyError:
+ return None
+
+ @classmethod
+ def plugin_for(cls, plugin_name, visit_name):
+ def decorate(cls_to_decorate):
+ cls.plugins[(plugin_name, visit_name)] = cls_to_decorate
+ return cls_to_decorate
+
+ return decorate
+
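+# A minimal sketch of how the ``plugins`` registry is populated via
+# ``plugin_for()``; the plugin name ``"my_plugin"``, the visit name
+# ``"some_element"`` and the class ``MyCompileState`` are hypothetical names
+# used only to illustrate the (plugin_name, visit_name) keying shown above::
+#
+#     >>> @CompileState.plugin_for("my_plugin", "some_element")
+#     ... class MyCompileState(CompileState):
+#     ...     pass
+#     >>> CompileState.plugins[("my_plugin", "some_element")] is MyCompileState
+#     True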
+
+class Generative(HasMemoized):
+ """Provide a method-chaining pattern in conjunction with the
+ @_generative decorator."""
+
+ def _generate(self):
+ skip = self._memoized_keys
+ cls = self.__class__
+ s = cls.__new__(cls)
+ if skip:
+ # ensure this iteration remains atomic
+ s.__dict__ = {
+ k: v for k, v in self.__dict__.copy().items() if k not in skip
+ }
+ else:
+ s.__dict__ = self.__dict__.copy()
+ return s
+
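+# The generative pattern as seen from the public API: each generative method
+# returns a copy produced by ``_generate()`` and leaves the original object
+# untouched.  A minimal sketch using a Core statement; the table and column
+# names are assumed::
+#
+#     >>> from sqlalchemy import column, select, table
+#     >>> t = table("t", column("x"))
+#     >>> s1 = select(t)
+#     >>> s2 = s1.execution_options(stream_results=True)
+#     >>> s2 is s1
+#     False
+#     >>> "stream_results" in s1.get_execution_options()
+#     False
+#     >>> "stream_results" in s2.get_execution_options()
+#     True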
+
+class InPlaceGenerative(HasMemoized):
+ """Provide a method-chaining pattern in conjunction with the
+ @_generative decorator that mutates in place."""
+
+ def _generate(self):
+ skip = self._memoized_keys
+ for k in skip:
+ self.__dict__.pop(k, None)
+ return self
+
+
+class HasCompileState(Generative):
+ """A class that has a :class:`.CompileState` associated with it."""
+
+ _compile_state_plugin = None
+
+ _attributes = util.immutabledict()
+
+ _compile_state_factory = CompileState.create_for_statement
+
+
+class _MetaOptions(type):
+ """metaclass for the Options class."""
+
+ def __init__(cls, classname, bases, dict_):
+ cls._cache_attrs = tuple(
+ sorted(
+ d
+ for d in dict_
+ if not d.startswith("__")
+ and d not in ("_cache_key_traversal",)
+ )
+ )
+ type.__init__(cls, classname, bases, dict_)
+
+ def __add__(self, other):
+ o1 = self()
+
+ if set(other).difference(self._cache_attrs):
+ raise TypeError(
+ "dictionary contains attributes not covered by "
+ "Options class %s: %r"
+ % (self, set(other).difference(self._cache_attrs))
+ )
+
+ o1.__dict__.update(other)
+ return o1
+
+
+class Options(util.with_metaclass(_MetaOptions)):
+ """A cacheable option dictionary with defaults."""
+
+ def __init__(self, **kw):
+ self.__dict__.update(kw)
+
+ def __add__(self, other):
+ o1 = self.__class__.__new__(self.__class__)
+ o1.__dict__.update(self.__dict__)
+
+ if set(other).difference(self._cache_attrs):
+ raise TypeError(
+ "dictionary contains attributes not covered by "
+ "Options class %s: %r"
+ % (self, set(other).difference(self._cache_attrs))
+ )
+
+ o1.__dict__.update(other)
+ return o1
+
+ def __eq__(self, other):
+ # TODO: very inefficient. This is used only in test suites
+ # right now.
+ for a, b in util.zip_longest(self._cache_attrs, other._cache_attrs):
+ if getattr(self, a) != getattr(other, b):
+ return False
+ return True
+
+ def __repr__(self):
+ # TODO: fairly inefficient, used only in debugging right now.
+
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ ", ".join(
+ "%s=%r" % (k, self.__dict__[k])
+ for k in self._cache_attrs
+ if k in self.__dict__
+ ),
+ )
+
+ @classmethod
+ def isinstance(cls, klass):
+ return issubclass(cls, klass)
+
+ @hybridmethod
+ def add_to_element(self, name, value):
+ return self + {name: getattr(self, name) + value}
+
+ @hybridmethod
+ def _state_dict(self):
+ return self.__dict__
+
+ _state_dict_const = util.immutabledict()
+
+ @_state_dict.classlevel
+ def _state_dict(cls):
+ return cls._state_dict_const
+
+ @classmethod
+ def safe_merge(cls, other):
+ d = other._state_dict()
+
+ # only support a merge with another object of our class
+ # and which does not have attrs that we don't. otherwise
+ # we risk having state that might not be part of our cache
+ # key strategy
+
+ if (
+ cls is not other.__class__
+ and other._cache_attrs
+ and set(other._cache_attrs).difference(cls._cache_attrs)
+ ):
+ raise TypeError(
+ "other element %r is not empty, is not of type %s, "
+ "and contains attributes not covered here %r"
+ % (
+ other,
+ cls,
+ set(other._cache_attrs).difference(cls._cache_attrs),
+ )
+ )
+ return cls + d
+
+ @classmethod
+ def from_execution_options(
+ cls, key, attrs, exec_options, statement_exec_options
+ ):
+ """process Options argument in terms of execution options.
+
+
+ e.g.::
+
+ (
+ load_options,
+ execution_options,
+ ) = QueryContext.default_load_options.from_execution_options(
+ "_sa_orm_load_options",
+ {
+ "populate_existing",
+ "autoflush",
+ "yield_per"
+ },
+ execution_options,
+ statement._execution_options,
+ )
+
+        The Options object is returned, and the "_sa_orm_load_options" key in
+        the exec options dict is refreshed with that Options object as well
+
+ """
+
+ # common case is that no options we are looking for are
+ # in either dictionary, so cancel for that first
+ check_argnames = attrs.intersection(
+ set(exec_options).union(statement_exec_options)
+ )
+
+ existing_options = exec_options.get(key, cls)
+
+ if check_argnames:
+ result = {}
+ for argname in check_argnames:
+ local = "_" + argname
+ if argname in exec_options:
+ result[local] = exec_options[argname]
+ elif argname in statement_exec_options:
+ result[local] = statement_exec_options[argname]
+
+ new_options = existing_options + result
+ exec_options = util.immutabledict().merge_with(
+ exec_options, {key: new_options}
+ )
+ return new_options, exec_options
+
+ else:
+ return existing_options, exec_options
+
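+# A minimal sketch of the ``+`` protocol implemented by _MetaOptions and
+# Options; ``MyOptions`` and its attributes are hypothetical and exist only
+# to illustrate how a dictionary is merged into a new instance::
+#
+#     >>> class MyOptions(Options):
+#     ...     flag = False
+#     ...     limit = None
+#     >>> o1 = MyOptions + {"flag": True}       # class-level __add__
+#     >>> (o1.flag, o1.limit)
+#     (True, None)
+#     >>> o2 = o1 + {"limit": 10}               # instance-level __add__
+#     >>> o2.limit
+#     10
+#     >>> # MyOptions + {"unknown": 1} would raise TypeError, since "unknown"
+#     >>> # is not part of MyOptions._cache_attrs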
+
+class CacheableOptions(Options, HasCacheKey):
+ @hybridmethod
+ def _gen_cache_key(self, anon_map, bindparams):
+ return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
+
+ @_gen_cache_key.classlevel
+ def _gen_cache_key(cls, anon_map, bindparams):
+ return (cls, ())
+
+ @hybridmethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key_for_object(self)
+
+
+class ExecutableOption(HasCopyInternals):
+ _annotations = util.EMPTY_DICT
+
+ __visit_name__ = "executable_option"
+
+ _is_has_cache_key = False
+
+ def _clone(self, **kw):
+ """Create a shallow copy of this ExecutableOption."""
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = dict(self.__dict__)
+ return c
+
+
+class Executable(roles.StatementRole, Generative):
+ """Mark a :class:`_expression.ClauseElement` as supporting execution.
+
+ :class:`.Executable` is a superclass for all "statement" types
+ of objects, including :func:`select`, :func:`delete`, :func:`update`,
+ :func:`insert`, :func:`text`.
+
+ """
+
+ supports_execution = True
+ _execution_options = util.immutabledict()
+ _bind = None
+ _with_options = ()
+ _with_context_options = ()
+
+ _executable_traverse_internals = [
+ ("_with_options", InternalTraversal.dp_executable_options),
+ (
+ "_with_context_options",
+ ExtendedInternalTraversal.dp_with_context_options,
+ ),
+ ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
+ ]
+
+ is_select = False
+ is_update = False
+ is_insert = False
+ is_text = False
+ is_delete = False
+ is_dml = False
+
+ @property
+ def _effective_plugin_target(self):
+ return self.__visit_name__
+
+ @_generative
+ def options(self, *options):
+ """Apply options to this statement.
+
+ In the general sense, options are any kind of Python object
+ that can be interpreted by the SQL compiler for the statement.
+ These options can be consumed by specific dialects or specific kinds
+ of compilers.
+
+        The most commonly known kinds of options are the ORM level options
+ that apply "eager load" and other loading behaviors to an ORM
+ query. However, options can theoretically be used for many other
+ purposes.
+
+ For background on specific kinds of options for specific kinds of
+ statements, refer to the documentation for those option objects.
+
+ .. versionchanged:: 1.4 - added :meth:`.Generative.options` to
+ Core statement objects towards the goal of allowing unified
+ Core / ORM querying capabilities.
+
+ .. seealso::
+
+ :ref:`deferred_options` - refers to options specific to the usage
+ of ORM queries
+
+ :ref:`relationship_loader_options` - refers to options specific
+ to the usage of ORM queries
+
+ """
+ self._with_options += tuple(
+ coercions.expect(roles.ExecutableOptionRole, opt)
+ for opt in options
+ )
+
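+    # For example, ORM loader options attach to a Core select() this way;
+    # ``User`` and ``User.addresses`` below stand in for an assumed mapped
+    # class and relationship and are not defined here::
+    #
+    #     >>> from sqlalchemy import select
+    #     >>> from sqlalchemy.orm import selectinload
+    #     >>> stmt = select(User).options(selectinload(User.addresses))
+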
+ @_generative
+ def _set_compile_options(self, compile_options):
+ """Assign the compile options to a new value.
+
+ :param compile_options: appropriate CacheableOptions structure
+
+ """
+
+ self._compile_options = compile_options
+
+ @_generative
+ def _update_compile_options(self, options):
+ """update the _compile_options with new keys."""
+
+ self._compile_options += options
+
+ @_generative
+ def _add_context_option(self, callable_, cache_args):
+ """Add a context option to this statement.
+
+ These are callable functions that will
+ be given the CompileState object upon compilation.
+
+ A second argument cache_args is required, which will be combined with
+ the ``__code__`` identity of the function itself in order to produce a
+ cache key.
+
+ """
+ self._with_context_options += ((callable_, cache_args),)
+
+ @_generative
+ def execution_options(self, **kw):
+ """Set non-SQL options for the statement which take effect during
+ execution.
+
+ Execution options can be set on a per-statement or
+ per :class:`_engine.Connection` basis. Additionally, the
+ :class:`_engine.Engine` and ORM :class:`~.orm.query.Query`
+ objects provide
+ access to execution options which they in turn configure upon
+ connections.
+
+ The :meth:`execution_options` method is generative. A new
+ instance of this statement is returned that contains the options::
+
+ statement = select(table.c.x, table.c.y)
+ statement = statement.execution_options(autocommit=True)
+
+ Note that only a subset of possible execution options can be applied
+ to a statement - these include "autocommit" and "stream_results",
+ but not "isolation_level" or "compiled_cache".
+ See :meth:`_engine.Connection.execution_options` for a full list of
+ possible options.
+
+ .. seealso::
+
+ :meth:`_engine.Connection.execution_options`
+
+ :meth:`_query.Query.execution_options`
+
+ :meth:`.Executable.get_execution_options`
+
+ """
+ if "isolation_level" in kw:
+ raise exc.ArgumentError(
+ "'isolation_level' execution option may only be specified "
+ "on Connection.execution_options(), or "
+ "per-engine using the isolation_level "
+ "argument to create_engine()."
+ )
+ if "compiled_cache" in kw:
+ raise exc.ArgumentError(
+ "'compiled_cache' execution option may only be specified "
+ "on Connection.execution_options(), not per statement."
+ )
+ self._execution_options = self._execution_options.union(kw)
+
+ def get_execution_options(self):
+ """Get the non-SQL options which will take effect during execution.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :meth:`.Executable.execution_options`
+ """
+ return self._execution_options
+
+ @util.deprecated_20(
+ ":meth:`.Executable.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, *multiparams, **params):
+ """Compile and execute this :class:`.Executable`."""
+ e = self.bind
+ if e is None:
+ label = (
+ getattr(self, "description", None) or self.__class__.__name__
+ )
+ msg = (
+ "This %s is not directly bound to a Connection or Engine. "
+ "Use the .execute() method of a Connection or Engine "
+ "to execute this construct." % label
+ )
+ raise exc.UnboundExecutionError(msg)
+ return e._execute_clauseelement(
+ self, multiparams, params, util.immutabledict()
+ )
+
+ @util.deprecated_20(
+ ":meth:`.Executable.scalar`",
+ alternative="Scalar execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.scalar` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.scalar` method of "
+ ":class:`.Session`.",
+ )
+ def scalar(self, *multiparams, **params):
+ """Compile and execute this :class:`.Executable`, returning the
+ result's scalar representation.
+
+ """
+ return self.execute(*multiparams, **params).scalar()
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Returns the :class:`_engine.Engine` or :class:`_engine.Connection`
+ to
+ which this :class:`.Executable` is bound, or None if none found.
+
+ This is a traversal which checks locally, then
+ checks among the "from" clauses of associated objects
+ until a bound engine or connection is found.
+
+ """
+ if self._bind is not None:
+ return self._bind
+
+ for f in _from_objects(self):
+ if f is self:
+ continue
+ engine = f.bind
+ if engine is not None:
+ return engine
+ else:
+ return None
+
+
+class prefix_anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Considers keys of the form "<ident> <name>" to produce
+ new symbols "<name>_<index>", where "index" is an incrementing integer
+ corresponding to <name>.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __missing__(self, key):
+ (ident, derived) = key.split(" ", 1)
+ anonymous_counter = self.get(derived, 1)
+ self[derived] = anonymous_counter + 1
+ value = derived + "_" + str(anonymous_counter)
+ self[key] = value
+ return value
+
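+# A minimal sketch of the "<ident> <name>" key scheme described above; the
+# ident values "32" and "47" are arbitrary::
+#
+#     >>> m = prefix_anon_map()
+#     >>> m["32 tbl"]
+#     'tbl_1'
+#     >>> m["47 tbl"]
+#     'tbl_2'
+#     >>> m["32 tbl"]
+#     'tbl_1'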
+
+class SchemaEventTarget(object):
+ """Base class for elements that are the targets of :class:`.DDLEvents`
+ events.
+
+ This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.
+
+ """
+
+ def _set_parent(self, parent, **kw):
+ """Associate with this SchemaEvent's parent object."""
+
+ def _set_parent_with_dispatch(self, parent, **kw):
+ self.dispatch.before_parent_attach(self, parent)
+ self._set_parent(parent, **kw)
+ self.dispatch.after_parent_attach(self, parent)
+
+
+class SchemaVisitor(ClauseVisitor):
+ """Define the visiting for ``SchemaItem`` objects."""
+
+ __traverse_options__ = {"schema_visitor": True}
+
+
+class ColumnCollection(object):
+ """Collection of :class:`_expression.ColumnElement` instances,
+ typically for
+ :class:`_sql.FromClause` objects.
+
+ The :class:`_sql.ColumnCollection` object is most commonly available
+ as the :attr:`_schema.Table.c` or :attr:`_schema.Table.columns` collection
+ on the :class:`_schema.Table` object, introduced at
+ :ref:`metadata_tables_and_columns`.
+
+ The :class:`_expression.ColumnCollection` has both mapping- and sequence-
+ like behaviors. A :class:`_expression.ColumnCollection` usually stores
+    :class:`_schema.Column` objects, which are then accessible both via
+    mapping-style access and via attribute-style access.
+
+ To access :class:`_schema.Column` objects using ordinary attribute-style
+    access, specify the name like any other object attribute, such as below,
+    where a column named ``employee_name`` is accessed::
+
+ >>> employee_table.c.employee_name
+
+ To access columns that have names with special characters or spaces,
+    index-style access is used, such as below, which illustrates how a
+    column named ``employee ' payment`` is accessed::
+
+ >>> employee_table.c["employee ' payment"]
+
+ As the :class:`_sql.ColumnCollection` object provides a Python dictionary
+ interface, common dictionary method names like
+ :meth:`_sql.ColumnCollection.keys`, :meth:`_sql.ColumnCollection.values`,
+ and :meth:`_sql.ColumnCollection.items` are available, which means that
+ database columns that are keyed under these names also need to use indexed
+ access::
+
+ >>> employee_table.c["values"]
+
+
+ The name for which a :class:`_schema.Column` would be present is normally
+ that of the :paramref:`_schema.Column.key` parameter. In some contexts,
+ such as a :class:`_sql.Select` object that uses a label style set
+ using the :meth:`_sql.Select.set_label_style` method, a column of a certain
+ key may instead be represented under a particular label name such
+ as ``tablename_columnname``::
+
+ >>> from sqlalchemy import select, column, table
+ >>> from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
+ >>> t = table("t", column("c"))
+ >>> stmt = select(t).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ >>> subq = stmt.subquery()
+ >>> subq.c.t_c
+ <sqlalchemy.sql.elements.ColumnClause at 0x7f59dcf04fa0; t_c>
+
+ :class:`.ColumnCollection` also indexes the columns in order and allows
+ them to be accessible by their integer position::
+
+ >>> cc[0]
+ Column('x', Integer(), table=None)
+ >>> cc[1]
+ Column('y', Integer(), table=None)
+
+ .. versionadded:: 1.4 :class:`_expression.ColumnCollection`
+ allows integer-based
+ index access to the collection.
+
+ Iterating the collection yields the column expressions in order::
+
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('y', Integer(), table=None)]
+
+ The base :class:`_expression.ColumnCollection` object can store
+ duplicates, which can
+ mean either two columns with the same key, in which case the column
+ returned by key access is **arbitrary**::
+
+ >>> x1, x2 = Column('x', Integer), Column('x', Integer)
+ >>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)])
+ >>> list(cc)
+ [Column('x', Integer(), table=None),
+ Column('x', Integer(), table=None)]
+ >>> cc['x'] is x1
+ False
+ >>> cc['x'] is x2
+ True
+
+ Or it can also mean the same column multiple times. These cases are
+ supported as :class:`_expression.ColumnCollection`
+ is used to represent the columns in
+ a SELECT statement which may include duplicates.
+
+ A special subclass :class:`.DedupeColumnCollection` exists which instead
+ maintains SQLAlchemy's older behavior of not allowing duplicates; this
+ collection is used for schema level objects like :class:`_schema.Table`
+ and
+ :class:`.PrimaryKeyConstraint` where this deduping is helpful. The
+ :class:`.DedupeColumnCollection` class also has additional mutation methods
+ as the schema constructs have more use cases that require removal and
+ replacement of columns.
+
+ .. versionchanged:: 1.4 :class:`_expression.ColumnCollection`
+ now stores duplicate
+ column keys as well as the same column in multiple positions. The
+ :class:`.DedupeColumnCollection` class is added to maintain the
+ former behavior in those cases where deduplication as well as
+ additional replace/remove operations are needed.
+
+
+ """
+
+ __slots__ = "_collection", "_index", "_colset"
+
+ def __init__(self, columns=None):
+ object.__setattr__(self, "_colset", set())
+ object.__setattr__(self, "_index", {})
+ object.__setattr__(self, "_collection", [])
+ if columns:
+ self._initial_populate(columns)
+
+ def _initial_populate(self, iter_):
+ self._populate_separate_keys(iter_)
+
+ @property
+ def _all_columns(self):
+ return [col for (k, col) in self._collection]
+
+ def keys(self):
+ """Return a sequence of string key names for all columns in this
+ collection."""
+ return [k for (k, col) in self._collection]
+
+ def values(self):
+ """Return a sequence of :class:`_sql.ColumnClause` or
+ :class:`_schema.Column` objects for all columns in this
+ collection."""
+ return [col for (k, col) in self._collection]
+
+ def items(self):
+ """Return a sequence of (key, column) tuples for all columns in this
+ collection each consisting of a string key name and a
+ :class:`_sql.ColumnClause` or
+ :class:`_schema.Column` object.
+ """
+
+ return list(self._collection)
+
+ def __bool__(self):
+ return bool(self._collection)
+
+ def __len__(self):
+ return len(self._collection)
+
+ def __iter__(self):
+ # turn to a list first to maintain over a course of changes
+ return iter([col for k, col in self._collection])
+
+ def __getitem__(self, key):
+ try:
+ return self._index[key]
+ except KeyError as err:
+ if isinstance(key, util.int_types):
+ util.raise_(IndexError(key), replace_context=err)
+ else:
+ raise
+
+ def __getattr__(self, key):
+ try:
+ return self._index[key]
+ except KeyError as err:
+ util.raise_(AttributeError(key), replace_context=err)
+
+ def __contains__(self, key):
+ if key not in self._index:
+ if not isinstance(key, util.string_types):
+ raise exc.ArgumentError(
+ "__contains__ requires a string argument"
+ )
+ return False
+ else:
+ return True
+
+ def compare(self, other):
+ """Compare this :class:`_expression.ColumnCollection` to another
+ based on the names of the keys"""
+
+ for l, r in util.zip_longest(self, other):
+ if l is not r:
+ return False
+ else:
+ return True
+
+ def __eq__(self, other):
+ return self.compare(other)
+
+ def get(self, key, default=None):
+ """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
+ based on a string key name from this
+ :class:`_expression.ColumnCollection`."""
+
+ if key in self._index:
+ return self._index[key]
+ else:
+ return default
+
+ def __str__(self):
+ return "%s(%s)" % (
+ self.__class__.__name__,
+ ", ".join(str(c) for c in self),
+ )
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError()
+
+ def __delitem__(self, key):
+ raise NotImplementedError()
+
+ def __setattr__(self, key, obj):
+ raise NotImplementedError()
+
+ def clear(self):
+ """Dictionary clear() is not implemented for
+ :class:`_sql.ColumnCollection`."""
+ raise NotImplementedError()
+
+ def remove(self, column):
+ """Dictionary remove() is not implemented for
+ :class:`_sql.ColumnCollection`."""
+ raise NotImplementedError()
+
+ def update(self, iter_):
+ """Dictionary update() is not implemented for
+ :class:`_sql.ColumnCollection`."""
+ raise NotImplementedError()
+
+ __hash__ = None
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
+ self._collection[:] = cols
+ self._colset.update(c for k, c in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ self._index.update({k: col for k, col in reversed(self._collection)})
+
+ def add(self, column, key=None):
+ """Add a column to this :class:`_sql.ColumnCollection`.
+
+ .. note::
+
+ This method is **not normally used by user-facing code**, as the
+ :class:`_sql.ColumnCollection` is usually part of an existing
+ object such as a :class:`_schema.Table`. To add a
+ :class:`_schema.Column` to an existing :class:`_schema.Table`
+ object, use the :meth:`_schema.Table.append_column` method.
+
+ """
+ if key is None:
+ key = column.key
+
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ if key not in self._index:
+ self._index[key] = column
+
+ def __getstate__(self):
+ return {"_collection": self._collection, "_index": self._index}
+
+ def __setstate__(self, state):
+ object.__setattr__(self, "_index", state["_index"])
+ object.__setattr__(self, "_collection", state["_collection"])
+ object.__setattr__(
+ self, "_colset", {col for k, col in self._collection}
+ )
+
+ def contains_column(self, col):
+ """Checks if a column object exists in this collection"""
+ if col not in self._colset:
+ if isinstance(col, util.string_types):
+ raise exc.ArgumentError(
+ "contains_column cannot be used with string arguments. "
+ "Use ``col_name in table.c`` instead."
+ )
+ return False
+ else:
+ return True
+
+ def as_immutable(self):
+ """Return an "immutable" form of this
+ :class:`_sql.ColumnCollection`."""
+
+ return ImmutableColumnCollection(self)
+
+ def corresponding_column(self, column, require_embedded=False):
+ """Given a :class:`_expression.ColumnElement`, return the exported
+ :class:`_expression.ColumnElement` object from this
+ :class:`_expression.ColumnCollection`
+ which corresponds to that original :class:`_expression.ColumnElement`
+ via a common
+ ancestor column.
+
+ :param column: the target :class:`_expression.ColumnElement`
+ to be matched.
+
+ :param require_embedded: only return corresponding columns for
+ the given :class:`_expression.ColumnElement`, if the given
+ :class:`_expression.ColumnElement`
+ is actually present within a sub-element
+ of this :class:`_expression.Selectable`.
+ Normally the column will match if
+ it merely shares a common ancestor with one of the exported
+ columns of this :class:`_expression.Selectable`.
+
+ .. seealso::
+
+ :meth:`_expression.Selectable.corresponding_column`
+ - invokes this method
+ against the collection returned by
+ :attr:`_expression.Selectable.exported_columns`.
+
+ .. versionchanged:: 1.4 the implementation for ``corresponding_column``
+ was moved onto the :class:`_expression.ColumnCollection` itself.
+
+ """
+
+ def embedded(expanded_proxy_set, target_set):
+ for t in target_set.difference(expanded_proxy_set):
+ if not set(_expand_cloned([t])).intersection(
+ expanded_proxy_set
+ ):
+ return False
+ return True
+
+ # don't dig around if the column is locally present
+ if column in self._colset:
+ return column
+ col, intersect = None, None
+ target_set = column.proxy_set
+ cols = [c for (k, c) in self._collection]
+ for c in cols:
+ expanded_proxy_set = set(_expand_cloned(c.proxy_set))
+ i = target_set.intersection(expanded_proxy_set)
+ if i and (
+ not require_embedded
+ or embedded(expanded_proxy_set, target_set)
+ ):
+ if col is None:
+
+ # no corresponding column yet, pick this one.
+
+ col, intersect = c, i
+ elif len(i) > len(intersect):
+
+ # 'c' has a larger field of correspondence than
+ # 'col'. i.e. selectable.c.a1_x->a1.c.x->table.c.x
+ # matches a1.c.x->table.c.x better than
+ # selectable.c.x->table.c.x does.
+
+ col, intersect = c, i
+ elif i == intersect:
+ # they have the same field of correspondence. see
+ # which proxy_set has fewer columns in it, which
+ # indicates a closer relationship with the root
+ # column. Also take into account the "weight"
+ # attribute which CompoundSelect() uses to give
+ # higher precedence to columns based on vertical
+ # position in the compound statement, and discard
+ # columns that have no reference to the target
+ # column (also occurs with CompoundSelect)
+
+ col_distance = util.reduce(
+ operator.add,
+ [
+ sc._annotations.get("weight", 1)
+ for sc in col._uncached_proxy_set()
+ if sc.shares_lineage(column)
+ ],
+ )
+ c_distance = util.reduce(
+ operator.add,
+ [
+ sc._annotations.get("weight", 1)
+ for sc in c._uncached_proxy_set()
+ if sc.shares_lineage(column)
+ ],
+ )
+ if c_distance < col_distance:
+ col, intersect = c, i
+ return col
+
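+# A minimal usage sketch of corresponding_column() against the ``.c``
+# collection of a subquery; the table and column names are assumed::
+#
+#     >>> from sqlalchemy import column, select, table
+#     >>> t = table("t", column("x"))
+#     >>> subq = select(t).subquery()
+#     >>> subq.c.corresponding_column(t.c.x) is subq.c.x
+#     True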
+
+class DedupeColumnCollection(ColumnCollection):
+ """A :class:`_expression.ColumnCollection`
+ that maintains deduplicating behavior.
+
+ This is useful by schema level objects such as :class:`_schema.Table` and
+ :class:`.PrimaryKeyConstraint`. The collection includes more
+ sophisticated mutator methods as well to suit schema objects which
+ require mutable column collections.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def add(self, column, key=None):
+
+ if key is not None and column.key != key:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ key = column.key
+
+ if key is None:
+ raise exc.ArgumentError(
+ "Can't add unnamed column to column collection"
+ )
+
+ if key in self._index:
+
+ existing = self._index[key]
+
+ if existing is column:
+ return
+
+ self.replace(column)
+
+ # pop out memoized proxy_set as this
+ # operation may very well be occurring
+ # in a _make_proxy operation
+ util.memoized_property.reset(column, "proxy_set")
+ else:
+ l = len(self._collection)
+ self._collection.append((key, column))
+ self._colset.add(column)
+ self._index[l] = column
+ self._index[key] = column
+
+ def _populate_separate_keys(self, iter_):
+ """populate from an iterator of (key, column)"""
+ cols = list(iter_)
+
+ replace_col = []
+ for k, col in cols:
+ if col.key != k:
+ raise exc.ArgumentError(
+ "DedupeColumnCollection requires columns be under "
+ "the same key as their .key"
+ )
+ if col.name in self._index and col.key != col.name:
+ replace_col.append(col)
+ elif col.key in self._index:
+ replace_col.append(col)
+ else:
+ self._index[k] = col
+ self._collection.append((k, col))
+ self._colset.update(c for (k, c) in self._collection)
+ self._index.update(
+ (idx, c) for idx, (k, c) in enumerate(self._collection)
+ )
+ for col in replace_col:
+ self.replace(col)
+
+ def extend(self, iter_):
+ self._populate_separate_keys((col.key, col) for col in iter_)
+
+ def remove(self, column):
+ if column not in self._colset:
+ raise ValueError(
+ "Can't remove column %r; column is not in this collection"
+ % column
+ )
+ del self._index[column.key]
+ self._colset.remove(column)
+ self._collection[:] = [
+ (k, c) for (k, c) in self._collection if c is not column
+ ]
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
+ )
+ # delete higher index
+ del self._index[len(self._collection)]
+
+ def replace(self, column):
+ """add the given column to this collection, removing unaliased
+ versions of this column as well as existing columns with the
+ same key.
+
+ e.g.::
+
+ t = Table('sometable', metadata, Column('col1', Integer))
+ t.columns.replace(Column('col1', Integer, key='columnone'))
+
+ will remove the original 'col1' from the collection, and add
+        the new column under the name 'columnone'.
+
+ Used by schema.Column to override columns during table reflection.
+
+ """
+
+ remove_col = set()
+ # remove up to two columns based on matches of name as well as key
+ if column.name in self._index and column.key != column.name:
+ other = self._index[column.name]
+ if other.name == other.key:
+ remove_col.add(other)
+
+ if column.key in self._index:
+ remove_col.add(self._index[column.key])
+
+ new_cols = []
+ replaced = False
+ for k, col in self._collection:
+ if col in remove_col:
+ if not replaced:
+ replaced = True
+ new_cols.append((column.key, column))
+ else:
+ new_cols.append((k, col))
+
+ if remove_col:
+ self._colset.difference_update(remove_col)
+
+ if not replaced:
+ new_cols.append((column.key, column))
+
+ self._colset.add(column)
+ self._collection[:] = new_cols
+
+ self._index.clear()
+ self._index.update(
+ {idx: col for idx, (k, col) in enumerate(self._collection)}
+ )
+ self._index.update(self._collection)
+
+
+class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+ __slots__ = ("_parent",)
+
+ def __init__(self, collection):
+ object.__setattr__(self, "_parent", collection)
+ object.__setattr__(self, "_colset", collection._colset)
+ object.__setattr__(self, "_index", collection._index)
+ object.__setattr__(self, "_collection", collection._collection)
+
+ def __getstate__(self):
+ return {"_parent": self._parent}
+
+ def __setstate__(self, state):
+ parent = state["_parent"]
+ self.__init__(parent)
+
+ add = extend = remove = util.ImmutableContainer._immutable
+
+
+class ColumnSet(util.ordered_column_set):
+ def contains_column(self, col):
+ return col in self
+
+ def extend(self, cols):
+ for col in cols:
+ self.add(col)
+
+ def __add__(self, other):
+ return list(self) + list(other)
+
+ def __eq__(self, other):
+ l = []
+ for c in other:
+ for local in self:
+ if c.shares_lineage(local):
+ l.append(c == local)
+ return elements.and_(*l)
+
+ def __hash__(self):
+ return hash(tuple(x for x in self))
+
+
+def _bind_or_error(schemaitem, msg=None):
+
+ util.warn_deprecated_20(
+ "The ``bind`` argument for schema methods that invoke SQL "
+ "against an engine or connection will be required in SQLAlchemy 2.0."
+ )
+ bind = schemaitem.bind
+ if not bind:
+ name = schemaitem.__class__.__name__
+ label = getattr(
+ schemaitem, "fullname", getattr(schemaitem, "name", None)
+ )
+ if label:
+ item = "%s object %r" % (name, label)
+ else:
+ item = "%s object" % name
+ if msg is None:
+ msg = (
+ "%s is not bound to an Engine or Connection. "
+ "Execution can not proceed without a database to execute "
+ "against." % item
+ )
+ raise exc.UnboundExecutionError(msg)
+ return bind
+
+
+def _entity_namespace(entity):
+ """Return the nearest .entity_namespace for the given entity.
+
+    If not immediately available, iterates through the element to find a
+    sub-element that has one, if any.
+
+ """
+ try:
+ return entity.entity_namespace
+ except AttributeError:
+ for elem in visitors.iterate(entity):
+ if hasattr(elem, "entity_namespace"):
+ return elem.entity_namespace
+ else:
+ raise
+
+
+def _entity_namespace_key(entity, key, default=NO_ARG):
+ """Return an entry from an entity_namespace.
+
+
+ Raises :class:`_exc.InvalidRequestError` rather than attribute error
+ on not found.
+
+ """
+
+ try:
+ ns = _entity_namespace(entity)
+ if default is not NO_ARG:
+ return getattr(ns, key, default)
+ else:
+ return getattr(ns, key)
+ except AttributeError as err:
+ util.raise_(
+ exc.InvalidRequestError(
+ 'Entity namespace for "%s" has no property "%s"'
+ % (entity, key)
+ ),
+ replace_context=err,
+ )
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
new file mode 100644
index 0000000..8cc73cb
--- /dev/null
+++ b/lib/sqlalchemy/sql/coercions.py
@@ -0,0 +1,1096 @@
+# sql/coercions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import numbers
+import re
+
+from . import operators
+from . import roles
+from . import visitors
+from .base import ExecutableOption
+from .base import Options
+from .traversals import HasCacheKey
+from .visitors import Visitable
+from .. import exc
+from .. import inspection
+from .. import util
+from ..util import collections_abc
+
+
+elements = None
+lambdas = None
+schema = None
+selectable = None
+sqltypes = None
+traversals = None
+
+
+def _is_literal(element):
+ """Return whether or not the element is a "literal" in the context
+ of a SQL expression construct.
+
+ """
+
+ return (
+ not isinstance(
+ element,
+ (Visitable, schema.SchemaEventTarget),
+ )
+ and not hasattr(element, "__clause_element__")
+ )
+
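+# For example, a plain Python value is a "literal" while a ClauseElement is
+# not::
+#
+#     >>> from sqlalchemy import column
+#     >>> _is_literal(5)
+#     True
+#     >>> _is_literal(column("x"))
+#     False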
+
+def _deep_is_literal(element):
+ """Return whether or not the element is a "literal" in the context
+ of a SQL expression construct.
+
+    Does a deeper, more esoteric check than _is_literal.  It is used
+    for lambda elements that have to distinguish values that would
+    be bound vs. not, without any context.
+
+ """
+
+ if isinstance(element, collections_abc.Sequence) and not isinstance(
+ element, str
+ ):
+ for elem in element:
+ if not _deep_is_literal(elem):
+ return False
+ else:
+ return True
+
+ return (
+ not isinstance(
+ element,
+ (
+ Visitable,
+ schema.SchemaEventTarget,
+ HasCacheKey,
+ Options,
+ util.langhelpers._symbol,
+ ),
+ )
+ and not hasattr(element, "__clause_element__")
+ and (
+ not isinstance(element, type)
+ or not issubclass(element, HasCacheKey)
+ )
+ )
+
+
+def _document_text_coercion(paramname, meth_rst, param_rst):
+ return util.add_parameter_text(
+ paramname,
+ (
+ ".. warning:: "
+ "The %s argument to %s can be passed as a Python string argument, "
+ "which will be treated "
+ "as **trusted SQL text** and rendered as given. **DO NOT PASS "
+ "UNTRUSTED INPUT TO THIS PARAMETER**."
+ )
+ % (param_rst, meth_rst),
+ )
+
+
+def _expression_collection_was_a_list(attrname, fnname, args):
+ if args and isinstance(args[0], (list, set, dict)) and len(args) == 1:
+ if isinstance(args[0], list):
+ util.warn_deprecated_20(
+ 'The "%s" argument to %s(), when referring to a sequence '
+ "of items, is now passed as a series of positional "
+ "elements, rather than as a list. " % (attrname, fnname)
+ )
+ return args[0]
+ else:
+ return args
+
+
+def expect(
+ role,
+ element,
+ apply_propagate_attrs=None,
+ argname=None,
+ post_inspect=False,
+ **kw
+):
+ if (
+ role.allows_lambda
+ # note callable() will not invoke a __getattr__() method, whereas
+ # hasattr(obj, "__call__") will. by keeping the callable() check here
+ # we prevent most needless calls to hasattr() and therefore
+ # __getattr__(), which is present on ColumnElement.
+ and callable(element)
+ and hasattr(element, "__code__")
+ ):
+ return lambdas.LambdaElement(
+ element,
+ role,
+ lambdas.LambdaOptions(**kw),
+ apply_propagate_attrs=apply_propagate_attrs,
+ )
+
+ # major case is that we are given a ClauseElement already, skip more
+ # elaborate logic up front if possible
+ impl = _impl_lookup[role]
+
+ original_element = element
+
+ if not isinstance(
+ element,
+ (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue),
+ ):
+ resolved = None
+
+ if impl._resolve_literal_only:
+ resolved = impl._literal_coercion(element, **kw)
+ else:
+
+ original_element = element
+
+ is_clause_element = False
+
+            # this is a special performance optimization for ORM
+            # joins, used by JoinTargetImpl, so that we don't go through the
+            # work of creating __clause_element__() when we only need the
+            # original QueryableAttribute, as the former will do clause
+            # adaption and all that which is just thrown away here.
+ if (
+ impl._skip_clauseelement_for_target_match
+ and isinstance(element, role)
+ and hasattr(element, "__clause_element__")
+ ):
+ is_clause_element = True
+ else:
+ while hasattr(element, "__clause_element__"):
+ is_clause_element = True
+
+ if not getattr(element, "is_clause_element", False):
+ element = element.__clause_element__()
+ else:
+ break
+
+ if not is_clause_element:
+ if impl._use_inspection:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ if post_inspect:
+ insp._post_inspect
+ try:
+ resolved = insp.__clause_element__()
+ except AttributeError:
+ impl._raise_for_expected(original_element, argname)
+
+ if resolved is None:
+ resolved = impl._literal_coercion(
+ element, argname=argname, **kw
+ )
+ else:
+ resolved = element
+ else:
+ resolved = element
+ if (
+ apply_propagate_attrs is not None
+ and not apply_propagate_attrs._propagate_attrs
+ and resolved._propagate_attrs
+ ):
+ apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs
+
+ if impl._role_class in resolved.__class__.__mro__:
+ if impl._post_coercion:
+ resolved = impl._post_coercion(
+ resolved,
+ argname=argname,
+ original_element=original_element,
+ **kw
+ )
+ return resolved
+ else:
+ return impl._implicit_coercions(
+ original_element, resolved, argname=argname, **kw
+ )
+
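+# A minimal sketch of expect() behavior for the ExpressionElementRole: an
+# object already fulfilling the role passes through unchanged, while a plain
+# Python value is coerced to a BindParameter::
+#
+#     >>> from sqlalchemy import column
+#     >>> from sqlalchemy.sql import coercions, roles
+#     >>> c = column("x")
+#     >>> coercions.expect(roles.ExpressionElementRole, c) is c
+#     True
+#     >>> type(coercions.expect(roles.ExpressionElementRole, 5)).__name__
+#     'BindParameter'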
+
+def expect_as_key(role, element, **kw):
+ kw["as_key"] = True
+ return expect(role, element, **kw)
+
+
+def expect_col_expression_collection(role, expressions):
+ for expr in expressions:
+ strname = None
+ column = None
+
+ resolved = expect(role, expr)
+ if isinstance(resolved, util.string_types):
+ strname = resolved = expr
+ else:
+ cols = []
+ visitors.traverse(resolved, {}, {"column": cols.append})
+ if cols:
+ column = cols[0]
+ add_element = column if column is not None else strname
+ yield resolved, column, strname, add_element
+
+
+class RoleImpl(object):
+ __slots__ = ("_role_class", "name", "_use_inspection")
+
+ def _literal_coercion(self, element, **kw):
+ raise NotImplementedError()
+
+ _post_coercion = None
+ _resolve_literal_only = False
+ _skip_clauseelement_for_target_match = False
+
+ def __init__(self, role_class):
+ self._role_class = role_class
+ self.name = role_class._role_name
+ self._use_inspection = issubclass(role_class, roles.UsesInspection)
+
+ def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ self._raise_for_expected(element, argname, resolved)
+
+ def _raise_for_expected(
+ self,
+ element,
+ argname=None,
+ resolved=None,
+ advice=None,
+ code=None,
+ err=None,
+ ):
+ if resolved is not None and resolved is not element:
+ got = "%r object resolved from %r object" % (resolved, element)
+ else:
+ got = repr(element)
+
+ if argname:
+ msg = "%s expected for argument %r; got %s." % (
+ self.name,
+ argname,
+ got,
+ )
+ else:
+ msg = "%s expected, got %s." % (self.name, got)
+
+ if advice:
+ msg += " " + advice
+
+ util.raise_(exc.ArgumentError(msg, code=code), replace_context=err)
+
+
+class _Deannotate(object):
+ __slots__ = ()
+
+ def _post_coercion(self, resolved, **kw):
+ from .util import _deep_deannotate
+
+ return _deep_deannotate(resolved)
+
+
+class _StringOnly(object):
+ __slots__ = ()
+
+ _resolve_literal_only = True
+
+
+class _ReturnsStringKey(object):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, util.string_types):
+ return original_element
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, **kw):
+ return element
+
+
+class _ColumnCoercions(object):
+ __slots__ = ()
+
+ def _warn_for_scalar_subquery_coercion(self):
+ util.warn(
+ "implicitly coercing SELECT object to scalar subquery; "
+ "please use the .scalar_subquery() method to produce a scalar "
+ "subquery.",
+ )
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if not getattr(resolved, "is_clause_element", False):
+ self._raise_for_expected(original_element, argname, resolved)
+ elif resolved._is_select_statement:
+ self._warn_for_scalar_subquery_coercion()
+ return resolved.scalar_subquery()
+ elif resolved._is_from_clause and isinstance(
+ resolved, selectable.Subquery
+ ):
+ self._warn_for_scalar_subquery_coercion()
+ return resolved.element.scalar_subquery()
+ elif self._role_class.allows_lambda and resolved._is_lambda_element:
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+def _no_text_coercion(
+ element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None
+):
+ util.raise_(
+ exc_cls(
+ "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be "
+ "explicitly declared as text(%(expr)r)"
+ % {
+ "expr": util.ellipses_string(element),
+ "argname": "for argument %s" % (argname,) if argname else "",
+ "extra": "%s " % extra if extra else "",
+ }
+ ),
+ replace_context=err,
+ )
+
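+# For example, a plain string passed where a SQL expression is expected is
+# rejected with an ArgumentError that suggests wrapping it in text(); the
+# table below is assumed::
+#
+#     >>> from sqlalchemy import column, select, table
+#     >>> t = table("t", column("x"))
+#     >>> select(t).where("x > 5")   # raises ArgumentError suggesting text("x > 5")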
+
+class _NoTextCoercion(object):
+ __slots__ = ()
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ if isinstance(element, util.string_types) and issubclass(
+ elements.TextClause, self._role_class
+ ):
+ _no_text_coercion(element, argname)
+ else:
+ self._raise_for_expected(element, argname)
+
+
+class _CoerceLiterals(object):
+ __slots__ = ()
+ _coerce_consts = False
+ _coerce_star = False
+ _coerce_numerics = False
+
+ def _text_coercion(self, element, argname=None):
+ return _no_text_coercion(element, argname)
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ if isinstance(element, util.string_types):
+ if self._coerce_star and element == "*":
+ return elements.ColumnClause("*", is_literal=True)
+ else:
+ return self._text_coercion(element, argname, **kw)
+
+ if self._coerce_consts:
+ if element is None:
+ return elements.Null()
+ elif element is False:
+ return elements.False_()
+ elif element is True:
+ return elements.True_()
+
+ if self._coerce_numerics and isinstance(element, (numbers.Number)):
+ return elements.ColumnClause(str(element), is_literal=True)
+
+ self._raise_for_expected(element, argname)
+
+
+class LiteralValueImpl(RoleImpl):
+ _resolve_literal_only = True
+
+ def _implicit_coercions(
+ self, element, resolved, argname, type_=None, **kw
+ ):
+ if not _is_literal(resolved):
+ self._raise_for_expected(
+ element, resolved=resolved, argname=argname, **kw
+ )
+
+ return elements.BindParameter(None, element, type_=type_, unique=True)
+
+ def _literal_coercion(self, element, argname=None, type_=None, **kw):
+ return element
+
+
+class _SelectIsNotFrom(object):
+ __slots__ = ()
+
+ def _raise_for_expected(self, element, argname=None, resolved=None, **kw):
+ if isinstance(element, roles.SelectStatementRole) or isinstance(
+ resolved, roles.SelectStatementRole
+ ):
+ advice = (
+ "To create a "
+ "FROM clause from a %s object, use the .subquery() method."
+ % (resolved.__class__ if resolved is not None else element,)
+ )
+ code = "89ve"
+ else:
+ advice = code = None
+
+ return super(_SelectIsNotFrom, self)._raise_for_expected(
+ element,
+ argname=argname,
+ resolved=resolved,
+ advice=advice,
+ code=code,
+ **kw
+ )
+
+
+class HasCacheKeyImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, traversals.HasCacheKey):
+ return original_element
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, **kw):
+ return element
+
+
+class ExecutableOptionImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, ExecutableOption):
+ return original_element
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, **kw):
+ return element
+
+
+class ExpressionElementImpl(_ColumnCoercions, RoleImpl):
+ __slots__ = ()
+
+ def _literal_coercion(
+ self, element, name=None, type_=None, argname=None, is_crud=False, **kw
+ ):
+ if (
+ element is None
+ and not is_crud
+ and (type_ is None or not type_.should_evaluate_none)
+ ):
+ # TODO: there's no test coverage now for the
+ # "should_evaluate_none" part of this, as outside of "crud" this
+ # codepath is not normally used except in some special cases
+ return elements.Null()
+ else:
+ try:
+ return elements.BindParameter(
+ name, element, type_, unique=True, _is_crud=is_crud
+ )
+ except exc.ArgumentError as err:
+ self._raise_for_expected(element, err=err)
+
+ def _raise_for_expected(self, element, argname=None, resolved=None, **kw):
+ if isinstance(element, roles.AnonymizedFromClauseRole):
+ advice = (
+ "To create a "
+ "column expression from a FROM clause row "
+ "as a whole, use the .table_valued() method."
+ )
+ else:
+ advice = None
+
+ return super(ExpressionElementImpl, self)._raise_for_expected(
+ element, argname=argname, resolved=resolved, advice=advice, **kw
+ )
+
+
+class BinaryElementImpl(ExpressionElementImpl, RoleImpl):
+
+ __slots__ = ()
+
+ def _literal_coercion(
+ self, element, expr, operator, bindparam_type=None, argname=None, **kw
+ ):
+ try:
+ return expr._bind_param(operator, element, type_=bindparam_type)
+ except exc.ArgumentError as err:
+ self._raise_for_expected(element, err=err)
+
+ def _post_coercion(self, resolved, expr, bindparam_type=None, **kw):
+ if resolved.type._isnull and not expr.type._isnull:
+ resolved = resolved._with_binary_element_type(
+ bindparam_type if bindparam_type is not None else expr.type
+ )
+ return resolved
+
+
+class InElementImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_from_clause:
+ if (
+ isinstance(resolved, selectable.Alias)
+ and resolved.element._is_select_statement
+ ):
+ self._warn_for_implicit_coercion(resolved)
+ return self._post_coercion(resolved.element, **kw)
+ else:
+ self._warn_for_implicit_coercion(resolved)
+ return self._post_coercion(resolved.select(), **kw)
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _warn_for_implicit_coercion(self, elem):
+ util.warn(
+ "Coercing %s object into a select() for use in IN(); "
+ "please pass a select() construct explicitly"
+ % (elem.__class__.__name__)
+ )
+
+ def _literal_coercion(self, element, expr, operator, **kw):
+ if isinstance(element, collections_abc.Iterable) and not isinstance(
+ element, util.string_types
+ ):
+ non_literal_expressions = {}
+ element = list(element)
+ for o in element:
+ if not _is_literal(o):
+ if not isinstance(o, operators.ColumnOperators):
+ self._raise_for_expected(element, **kw)
+ else:
+ non_literal_expressions[o] = o
+ elif o is None:
+ non_literal_expressions[o] = elements.Null()
+
+ if non_literal_expressions:
+ return elements.ClauseList(
+ *[
+ non_literal_expressions[o]
+ if o in non_literal_expressions
+ else expr._bind_param(operator, o)
+ for o in element
+ ]
+ )
+ else:
+ return expr._bind_param(operator, element, expanding=True)
+
+ else:
+ self._raise_for_expected(element, **kw)
+
+ def _post_coercion(self, element, expr, operator, **kw):
+ if element._is_select_statement:
+ # for IN, we are doing scalar_subquery() coercion without
+ # a warning
+ return element.scalar_subquery()
+ elif isinstance(element, elements.ClauseList):
+            assert len(element.clauses) != 0
+ return element.self_group(against=operator)
+
+ elif isinstance(element, elements.BindParameter):
+ element = element._clone(maintain_key=True)
+ element.expanding = True
+ element.expand_op = operator
+
+ return element
+ else:
+ return element
+
+
+class OnClauseImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, legacy=False, **kw
+ ):
+ if legacy and isinstance(resolved, str):
+ return resolved
+ else:
+ return super(OnClauseImpl, self)._implicit_coercions(
+ original_element,
+ resolved,
+ argname=argname,
+ legacy=legacy,
+ **kw
+ )
+
+ def _text_coercion(self, element, argname=None, legacy=False):
+ if legacy and isinstance(element, str):
+ util.warn_deprecated_20(
+ "Using strings to indicate relationship names in "
+ "Query.join() is deprecated and will be removed in "
+ "SQLAlchemy 2.0. Please use the class-bound attribute "
+ "directly."
+ )
+ return element
+
+ return super(OnClauseImpl, self)._text_coercion(element, argname)
+
+ def _post_coercion(self, resolved, original_element=None, **kw):
+ # this is a hack right now as we want to use coercion on an
+ # ORM InstrumentedAttribute, but we want to return the object
+ # itself if it is one, not its clause element.
+ # ORM context _join and _legacy_join() would need to be improved
+ # to look for annotations in a clause element form.
+ if isinstance(original_element, roles.JoinTargetRole):
+ return original_element
+ return resolved
+
+
+class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ return _no_text_coercion(element, argname)
+
+
+class StatementOptionImpl(_CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ return elements.TextClause(element)
+
+
+class ColumnArgumentImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+
+class ColumnArgumentOrKeyImpl(_ReturnsStringKey, RoleImpl):
+ __slots__ = ()
+
+
+class StrAsPlainColumnImpl(_CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ def _text_coercion(self, element, argname=None):
+ return elements.ColumnClause(element)
+
+
+class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole):
+
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ return elements._textual_label_reference(element)
+
+
+class OrderByImpl(ByOfImpl, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, resolved, **kw):
+ if (
+ isinstance(resolved, self._role_class)
+ and resolved._order_by_label_element is not None
+ ):
+ return elements._label_reference(resolved)
+ else:
+ return resolved
+
+
+class GroupByImpl(ByOfImpl, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(resolved, roles.StrictFromClauseRole):
+ return elements.ClauseList(*resolved.c)
+ else:
+ return resolved
+
+
+class DMLColumnImpl(_ReturnsStringKey, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, as_key=False, **kw):
+ if as_key:
+ return element.key
+ else:
+ return element
+
+
+class ConstExprImpl(RoleImpl):
+ __slots__ = ()
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ if element is None:
+ return elements.Null()
+ elif element is False:
+ return elements.False_()
+ elif element is True:
+ return elements.True_()
+ else:
+ self._raise_for_expected(element, argname)
+
+
+class TruncatedLabelImpl(_StringOnly, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(original_element, util.string_types):
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _literal_coercion(self, element, argname=None, **kw):
+ """coerce the given value to :class:`._truncated_label`.
+
+ Existing :class:`._truncated_label` and
+ :class:`._anonymous_label` objects are passed
+ unchanged.
+ """
+
+ if isinstance(element, elements._truncated_label):
+ return element
+ else:
+ return elements._truncated_label(element)
+
+
+class DDLExpressionImpl(_Deannotate, _CoerceLiterals, RoleImpl):
+
+ __slots__ = ()
+
+ _coerce_consts = True
+
+ def _text_coercion(self, element, argname=None):
+ # see #5754 for why we can't easily deprecate this coercion.
+ # essentially expressions like postgresql_where would have to be
+ # text() as they come back from reflection and we don't want to
+ # have text() elements wired into the inspection dictionaries.
+ return elements.TextClause(element)
+
+
+class DDLConstraintColumnImpl(_Deannotate, _ReturnsStringKey, RoleImpl):
+ __slots__ = ()
+
+
+class DDLReferredColumnImpl(DDLConstraintColumnImpl):
+ __slots__ = ()
+
+
+class LimitOffsetImpl(RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(self, element, resolved, argname=None, **kw):
+ if resolved is None:
+ return None
+ else:
+ self._raise_for_expected(element, argname, resolved)
+
+ def _literal_coercion(self, element, name, type_, **kw):
+ if element is None:
+ return None
+ else:
+ value = util.asint(element)
+ return selectable._OffsetLimitParam(
+ name, value, type_=type_, unique=True
+ )
+
+
+class LabeledColumnExprImpl(ExpressionElementImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if isinstance(resolved, roles.ExpressionElementRole):
+ return resolved.label(None)
+ else:
+ new = super(LabeledColumnExprImpl, self)._implicit_coercions(
+ original_element, resolved, argname=argname, **kw
+ )
+ if isinstance(new, roles.ExpressionElementRole):
+ return new.label(None)
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ _coerce_consts = True
+ _coerce_numerics = True
+ _coerce_star = True
+
+ _guess_straight_column = re.compile(r"^\w\S*$", re.I)
+
+ def _text_coercion(self, element, argname=None):
+ element = str(element)
+
+ guess_is_literal = not self._guess_straight_column.match(element)
+ raise exc.ArgumentError(
+ "Textual column expression %(column)r %(argname)sshould be "
+ "explicitly declared with text(%(column)r), "
+ "or use %(literal_column)s(%(column)r) "
+ "for more specificity"
+ % {
+ "column": util.ellipses_string(element),
+ "argname": "for argument %s" % (argname,) if argname else "",
+ "literal_column": "literal_column"
+ if guess_is_literal
+ else "column",
+ }
+ )
+
+
+class ReturnsRowsImpl(RoleImpl):
+ __slots__ = ()
+
+
+class StatementImpl(_CoerceLiterals, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, resolved, original_element, argname=None, **kw):
+ if resolved is not original_element and not isinstance(
+ original_element, util.string_types
+ ):
+ # use same method as Connection uses; this will later raise
+ # ObjectNotExecutableError
+ try:
+ original_element._execute_on_connection
+ except AttributeError:
+ util.warn_deprecated(
+ "Object %r should not be used directly in a SQL statement "
+ "context, such as passing to methods such as "
+ "session.execute(). This usage will be disallowed in a "
+ "future release. "
+ "Please use Core select() / update() / delete() etc. "
+ "with Session.execute() and other statement execution "
+ "methods." % original_element,
+ "1.4",
+ )
+
+ return resolved
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_lambda_element:
+ return resolved
+ else:
+ return super(StatementImpl, self)._implicit_coercions(
+ original_element, resolved, argname=argname, **kw
+ )
+
+ def _text_coercion(self, element, argname=None):
+ util.warn_deprecated_20(
+ "Using plain strings to indicate SQL statements without using "
+ "the text() construct is "
+ "deprecated and will be removed in version 2.0. Ensure plain "
+ "SQL statements are passed using the text() construct."
+ )
+ return elements.TextClause(element)
+
+
+class SelectStatementImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_text_clause:
+ return resolved.columns()
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class HasCTEImpl(ReturnsRowsImpl):
+ __slots__ = ()
+
+
+class IsCTEImpl(RoleImpl):
+ __slots__ = ()
+
+
+class JoinTargetImpl(RoleImpl):
+ __slots__ = ()
+
+ _skip_clauseelement_for_target_match = True
+
+ def _literal_coercion(self, element, legacy=False, **kw):
+ if isinstance(element, str):
+ return element
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, legacy=False, **kw
+ ):
+ if isinstance(original_element, roles.JoinTargetRole):
+ # note that this codepath no longer occurs as of
+ # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match
+ # were set to False.
+ return original_element
+ elif legacy and isinstance(resolved, str):
+ util.warn_deprecated_20(
+ "Using strings to indicate relationship names in "
+ "Query.join() is deprecated and will be removed in "
+ "SQLAlchemy 2.0. Please use the class-bound attribute "
+ "directly."
+ )
+ return resolved
+ elif legacy and isinstance(resolved, roles.WhereHavingRole):
+ return resolved
+ elif legacy and resolved._is_select_statement:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT "
+ "constructs into FROM clauses is deprecated; please call "
+ ".subquery() on any Core select or ORM Query object in "
+ "order to produce a subquery object.",
+ version="1.4",
+ )
+ # TODO: doing _implicit_subquery here causes tests to fail,
+ # how was this working before? probably that ORM
+ # join logic treated it as a select and subquery would happen
+ # in _ORMJoin->Join
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self,
+ original_element,
+ resolved,
+ argname=None,
+ explicit_subquery=False,
+ allow_select=True,
+ **kw
+ ):
+ if resolved._is_select_statement:
+ if explicit_subquery:
+ return resolved.subquery()
+ elif allow_select:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT "
+ "constructs into FROM clauses is deprecated; please call "
+ ".subquery() on any Core select or ORM Query object in "
+ "order to produce a subquery object.",
+ version="1.4",
+ )
+ return resolved._implicit_subquery
+ elif resolved._is_text_clause:
+ return resolved
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+ def _post_coercion(self, element, deannotate=False, **kw):
+ if deannotate:
+ return element._deannotate()
+ else:
+ return element
+
+
+class StrictFromClauseImpl(FromClauseImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self,
+ original_element,
+ resolved,
+ argname=None,
+ allow_select=False,
+ **kw
+ ):
+ if resolved._is_select_statement and allow_select:
+ util.warn_deprecated(
+ "Implicit coercion of SELECT and textual SELECT constructs "
+ "into FROM clauses is deprecated; please call .subquery() "
+ "on any Core select or ORM Query object in order to produce a "
+ "subquery object.",
+ version="1.4",
+ )
+ return resolved._implicit_subquery
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class AnonymizedFromClauseImpl(StrictFromClauseImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, flat=False, name=None, **kw):
+ assert name is None
+
+ return element._anonymous_fromclause(flat=flat)
+
+
+class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, **kw):
+ if "dml_table" in element._annotations:
+ return element._annotations["dml_table"]
+ else:
+ return element
+
+
+class DMLSelectImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _implicit_coercions(
+ self, original_element, resolved, argname=None, **kw
+ ):
+ if resolved._is_from_clause:
+ if (
+ isinstance(resolved, selectable.Alias)
+ and resolved.element._is_select_statement
+ ):
+ return resolved.element
+ else:
+ return resolved.select()
+ else:
+ self._raise_for_expected(original_element, argname, resolved)
+
+
+class CompoundElementImpl(_NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _raise_for_expected(self, element, argname=None, resolved=None, **kw):
+ if isinstance(element, roles.FromClauseRole):
+ if element._is_subquery:
+ advice = (
+ "Use the plain select() object without "
+ "calling .subquery() or .alias()."
+ )
+ else:
+ advice = (
+ "To SELECT from any FROM clause, use the .select() method."
+ )
+ else:
+ advice = None
+ return super(CompoundElementImpl, self)._raise_for_expected(
+ element, argname=argname, resolved=resolved, advice=advice, **kw
+ )
+
+
+_impl_lookup = {}
+
+
+for name in dir(roles):
+ cls = getattr(roles, name)
+ if name.endswith("Role"):
+ name = name.replace("Role", "Impl")
+ if name in globals():
+ impl = globals()[name](cls)
+ _impl_lookup[cls] = impl
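+
+# The loop above pairs each *Role class from sqlalchemy.sql.roles with a
+# singleton *Impl instance; the expect() coercion entrypoints earlier in this
+# module consult _impl_lookup to find the coercion rules for a given role.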
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
new file mode 100644
index 0000000..c9b6ba6
--- /dev/null
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -0,0 +1,5525 @@
+# sql/compiler.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Base SQL and DDL compiler implementations.
+
+Classes provided include:
+
+:class:`.compiler.SQLCompiler` - renders SQL
+strings
+
+:class:`.compiler.DDLCompiler` - renders DDL
+(data definition language) strings
+
+:class:`.compiler.GenericTypeCompiler` - renders
+type specification strings.
+
+To generate user-defined SQL strings, see
+:doc:`/ext/compiler`.
+
+"""
+
+import collections
+import contextlib
+import itertools
+import operator
+import re
+
+from . import base
+from . import coercions
+from . import crud
+from . import elements
+from . import functions
+from . import operators
+from . import schema
+from . import selectable
+from . import sqltypes
+from .base import NO_ARG
+from .base import prefix_anon_map
+from .elements import quoted_name
+from .. import exc
+from .. import util
+
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "authorization",
+ "between",
+ "binary",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "for",
+ "foreign",
+ "freeze",
+ "from",
+ "full",
+ "grant",
+ "group",
+ "having",
+ "ilike",
+ "in",
+ "initially",
+ "inner",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "leading",
+ "left",
+ "like",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "natural",
+ "new",
+ "not",
+ "notnull",
+ "null",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "outer",
+ "overlaps",
+ "placing",
+ "primary",
+ "references",
+ "right",
+ "select",
+ "session_user",
+ "set",
+ "similar",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "verbose",
+ "when",
+ "where",
+ ]
+)
+
+LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
+ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+
+FK_ON_DELETE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_ON_UPDATE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
+BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
+
+BIND_TEMPLATES = {
+ "pyformat": "%%(%(name)s)s",
+ "qmark": "?",
+ "format": "%%s",
+ "numeric": ":[_POSITION]",
+ "named": ":%(name)s",
+}
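+# For example, a bound parameter named "user_id" is rendered as %(user_id)s
+# under "pyformat", ? under "qmark", %s under "format" and :user_id under
+# "named"; the ":[_POSITION]" placeholder used for "numeric" is rewritten to
+# 1-based positions later by SQLCompiler._apply_numbered_params().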
+
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
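+# Bound parameter names containing any of the characters above are rewritten
+# using this translation (e.g. "%" -> "P") so they cannot interfere with the
+# rendered parameter syntax; the compiler records such renames in
+# SQLCompiler.escaped_bind_names so values can still be matched up with the
+# right keys at execution time.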
+
+OPERATORS = {
+ # binary
+ operators.and_: " AND ",
+ operators.or_: " OR ",
+ operators.add: " + ",
+ operators.mul: " * ",
+ operators.sub: " - ",
+ operators.div: " / ",
+ operators.mod: " % ",
+ operators.truediv: " / ",
+ operators.neg: "-",
+ operators.lt: " < ",
+ operators.le: " <= ",
+ operators.ne: " != ",
+ operators.gt: " > ",
+ operators.ge: " >= ",
+ operators.eq: " = ",
+ operators.is_distinct_from: " IS DISTINCT FROM ",
+ operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
+ operators.concat_op: " || ",
+ operators.match_op: " MATCH ",
+ operators.not_match_op: " NOT MATCH ",
+ operators.in_op: " IN ",
+ operators.not_in_op: " NOT IN ",
+ operators.comma_op: ", ",
+ operators.from_: " FROM ",
+ operators.as_: " AS ",
+ operators.is_: " IS ",
+ operators.is_not: " IS NOT ",
+ operators.collate: " COLLATE ",
+ # unary
+ operators.exists: "EXISTS ",
+ operators.distinct_op: "DISTINCT ",
+ operators.inv: "NOT ",
+ operators.any_op: "ANY ",
+ operators.all_op: "ALL ",
+ # modifiers
+ operators.desc_op: " DESC",
+ operators.asc_op: " ASC",
+ operators.nulls_first_op: " NULLS FIRST",
+ operators.nulls_last_op: " NULLS LAST",
+}
+
+FUNCTIONS = {
+ functions.coalesce: "coalesce",
+ functions.current_date: "CURRENT_DATE",
+ functions.current_time: "CURRENT_TIME",
+ functions.current_timestamp: "CURRENT_TIMESTAMP",
+ functions.current_user: "CURRENT_USER",
+ functions.localtime: "LOCALTIME",
+ functions.localtimestamp: "LOCALTIMESTAMP",
+ functions.random: "random",
+ functions.sysdate: "sysdate",
+ functions.session_user: "SESSION_USER",
+ functions.user: "USER",
+ functions.cube: "CUBE",
+ functions.rollup: "ROLLUP",
+ functions.grouping_sets: "GROUPING SETS",
+}
+
+EXTRACT_MAP = {
+ "month": "month",
+ "day": "day",
+ "year": "year",
+ "second": "second",
+ "hour": "hour",
+ "doy": "doy",
+ "minute": "minute",
+ "quarter": "quarter",
+ "dow": "dow",
+ "week": "week",
+ "epoch": "epoch",
+ "milliseconds": "milliseconds",
+ "microseconds": "microseconds",
+ "timezone_hour": "timezone_hour",
+ "timezone_minute": "timezone_minute",
+}
+
+COMPOUND_KEYWORDS = {
+ selectable.CompoundSelect.UNION: "UNION",
+ selectable.CompoundSelect.UNION_ALL: "UNION ALL",
+ selectable.CompoundSelect.EXCEPT: "EXCEPT",
+ selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
+ selectable.CompoundSelect.INTERSECT: "INTERSECT",
+ selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
+}
+
+
+RM_RENDERED_NAME = 0
+RM_NAME = 1
+RM_OBJECTS = 2
+RM_TYPE = 3
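+# Index positions of the fields within each _result_columns entry
+# (rendered key name, original name, targeted objects, type); the cursor
+# result metadata uses these to target columns in result rows.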
+
+
+ExpandedState = collections.namedtuple(
+ "ExpandedState",
+ [
+ "statement",
+ "additional_parameters",
+ "processors",
+ "positiontup",
+ "parameter_expansion",
+ ],
+)
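+# ExpandedState is returned by SQLCompiler._process_parameters_for_postcompile():
+# "statement" is the SQL string with __[POSTCOMPILE_...] placeholders replaced,
+# and the other fields carry the per-execution parameter dictionary, extra bind
+# processors, positional ordering and the expanded parameter names produced by
+# that step.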
+
+
+NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0)
+
+COLLECT_CARTESIAN_PRODUCTS = util.symbol(
+ "COLLECT_CARTESIAN_PRODUCTS",
+ "Collect data on FROMs and cartesian products and gather "
+ "into 'self.from_linter'",
+ canonical=1,
+)
+
+WARN_LINTING = util.symbol(
+ "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2
+)
+
+FROM_LINTING = util.symbol(
+ "FROM_LINTING",
+ "Warn for cartesian products; "
+ "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
+ canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
+)
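+# These linter symbols behave as bit flags: NO_LINTING (0) disables linting,
+# while FROM_LINTING combines COLLECT_CARTESIAN_PRODUCTS (1) with
+# WARN_LINTING (2) so that cartesian products are both detected and warned
+# about.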
+
+
+class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
+ def lint(self, start=None):
+ froms = self.froms
+ if not froms:
+ return None, None
+
+ edges = set(self.edges)
+ the_rest = set(froms)
+
+ if start is not None:
+ start_with = start
+ the_rest.remove(start_with)
+ else:
+ start_with = the_rest.pop()
+
+ stack = collections.deque([start_with])
+
+ while stack and the_rest:
+ node = stack.popleft()
+ the_rest.discard(node)
+
+ # comparison of nodes in edges here is based on hash equality, as
+ # there are "annotated" elements that match the non-annotated ones.
+ # to remove the need for in-python hash() calls, use native
+ # containment routines (e.g. "node in edge", "edge.index(node)")
+ to_remove = {edge for edge in edges if node in edge}
+
+ # appendleft the node in each edge that is not
+ # the one that matched.
+ stack.extendleft(edge[not edge.index(node)] for edge in to_remove)
+ edges.difference_update(to_remove)
+
+ # FROMS left over? boom
+ if the_rest:
+ return the_rest, start_with
+ else:
+ return None, None
+
+ def warn(self):
+ the_rest, start_with = self.lint()
+
+ # FROMS left over? boom
+ if the_rest:
+
+ froms = the_rest
+ if froms:
+ template = (
+ "SELECT statement has a cartesian product between "
+ "FROM element(s) {froms} and "
+ 'FROM element "{start}". Apply join condition(s) '
+ "between each element to resolve."
+ )
+ froms_str = ", ".join(
+ '"{elem}"'.format(elem=self.froms[from_])
+ for from_ in froms
+ )
+ message = template.format(
+ froms=froms_str, start=self.froms[start_with]
+ )
+
+ util.warn(message)
+
+
+class Compiled(object):
+
+ """Represent a compiled SQL or DDL expression.
+
+ The ``__str__`` method of the ``Compiled`` object should produce
+ the actual text of the statement. ``Compiled`` objects are
+ specific to their underlying database dialect, and also may
+ or may not be specific to the columns referenced within a
+ particular set of bind parameters. In no case should the
+ ``Compiled`` object be dependent on the actual values of those
+ bind parameters, even though it may reference those values as
+ defaults.
+ """
+
+ _cached_metadata = None
+
+ _result_columns = None
+
+ schema_translate_map = None
+
+ execution_options = util.EMPTY_DICT
+ """
+ Execution options propagated from the statement. In some cases,
+ sub-elements of the statement can modify these.
+ """
+
+ _annotations = util.EMPTY_DICT
+
+ compile_state = None
+ """Optional :class:`.CompileState` object that maintains additional
+ state used by the compiler.
+
+ Major executable objects such as :class:`_expression.Insert`,
+ :class:`_expression.Update`, :class:`_expression.Delete`,
+ :class:`_expression.Select` will generate this
+ state when compiled in order to calculate additional information about the
+ object. For the top level object that is to be executed, the state can be
+ stored here where it can also have applicability towards result set
+ processing.
+
+ .. versionadded:: 1.4
+
+ """
+
+ dml_compile_state = None
+ """Optional :class:`.CompileState` assigned at the same point that
+ .isinsert, .isupdate, or .isdelete is assigned.
+
+ This will normally be the same object as .compile_state, with the
+ exception of cases like the :class:`.ORMFromStatementCompileState`
+ object.
+
+ .. versionadded:: 1.4.40
+
+ """
+
+ cache_key = None
+ _gen_time = None
+
+ def __init__(
+ self,
+ dialect,
+ statement,
+ schema_translate_map=None,
+ render_schema_translate=False,
+ compile_kwargs=util.immutabledict(),
+ ):
+ """Construct a new :class:`.Compiled` object.
+
+ :param dialect: :class:`.Dialect` to compile against.
+
+ :param statement: :class:`_expression.ClauseElement` to be compiled.
+
+ :param schema_translate_map: dictionary of schema names to be
+ translated when forming the resultant SQL
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`schema_translating`
+
+ :param compile_kwargs: additional kwargs that will be
+ passed to the initial call to :meth:`.Compiled.process`.
+
+
+ """
+
+ self.dialect = dialect
+ self.preparer = self.dialect.identifier_preparer
+ if schema_translate_map:
+ self.schema_translate_map = schema_translate_map
+ self.preparer = self.preparer._with_schema_translate(
+ schema_translate_map
+ )
+
+ if statement is not None:
+ self.statement = statement
+ self.can_execute = statement.supports_execution
+ self._annotations = statement._annotations
+ if self.can_execute:
+ self.execution_options = statement._execution_options
+ self.string = self.process(self.statement, **compile_kwargs)
+
+ if render_schema_translate:
+ self.string = self.preparer._render_schema_translates(
+ self.string, schema_translate_map
+ )
+ self._gen_time = util.perf_counter()
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self.can_execute:
+ return connection._execute_compiled(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self.statement)
+
+ def visit_unsupported_compilation(self, element, err):
+ util.raise_(
+ exc.UnsupportedCompilationError(self, type(element)),
+ replace_context=err,
+ )
+
+ @property
+ def sql_compiler(self):
+ """Return a Compiled that is capable of processing SQL expressions.
+
+ If this compiler is one, it would likely just return 'self'.
+
+ """
+
+ raise NotImplementedError()
+
+ def process(self, obj, **kwargs):
+ return obj._compiler_dispatch(self, **kwargs)
+
+ def __str__(self):
+ """Return the string text of the generated SQL or DDL."""
+
+ return self.string or ""
+
+ def construct_params(
+ self, params=None, extracted_parameters=None, escape_names=True
+ ):
+ """Return the bind params for this compiled object.
+
+ :param params: a dict of string/object pairs whose values will
+ override bind values compiled in to the
+ statement.
+ """
+
+ raise NotImplementedError()
+
+ @property
+ def params(self):
+ """Return the bind params for this compiled object."""
+ return self.construct_params()
+
+
+class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
+ """Produces DDL specification for TypeEngine objects."""
+
+ ensure_kwarg = r"visit_\w+"
+
+ def __init__(self, dialect):
+ self.dialect = dialect
+
+ def process(self, type_, **kw):
+ return type_._compiler_dispatch(self, **kw)
+
+ def visit_unsupported_compilation(self, element, err, **kw):
+ util.raise_(
+ exc.UnsupportedCompilationError(self, element),
+ replace_context=err,
+ )
+
+
+# this was a Visitable, but to allow accurate detection of
+# column elements this is actually a column element
+class _CompileLabel(elements.ColumnElement):
+
+ """lightweight label object which acts as an expression.Label."""
+
+ __visit_name__ = "label"
+ __slots__ = "element", "name"
+
+ def __init__(self, col, name, alt_names=()):
+ self.element = col
+ self.name = name
+ self._alt_names = (col,) + alt_names
+
+ @property
+ def proxy_set(self):
+ return self.element.proxy_set
+
+ @property
+ def type(self):
+ return self.element.type
+
+ def self_group(self, **kw):
+ return self
+
+
+class SQLCompiler(Compiled):
+ """Default implementation of :class:`.Compiled`.
+
+ Compiles :class:`_expression.ClauseElement` objects into SQL strings.
+
+ """
+
+ extract_map = EXTRACT_MAP
+
+ compound_keywords = COMPOUND_KEYWORDS
+
+ isdelete = isinsert = isupdate = False
+ """class-level defaults which can be set at the instance
+ level to define if this Compiled instance represents
+ INSERT/UPDATE/DELETE
+ """
+
+ isplaintext = False
+
+ returning = None
+ """holds the "returning" collection of columns if
+ the statement is CRUD and defines returning columns
+ either implicitly or explicitly
+ """
+
+ returning_precedes_values = False
+ """set to True classwide to generate RETURNING
+ clauses before the VALUES or WHERE clause (i.e. MSSQL)
+ """
+
+ render_table_with_column_in_update_from = False
+ """set to True classwide to indicate the SET clause
+ in a multi-table UPDATE statement should qualify
+ columns with the table name (i.e. MySQL only)
+ """
+
+ ansi_bind_rules = False
+ """SQL 92 doesn't allow bind parameters to be used
+ in the columns clause of a SELECT, nor does it allow
+ ambiguous expressions like "? = ?". A compiler
+ subclass can set this flag to True if the target
+ driver/DB enforces this.
+ """
+
+ _textual_ordered_columns = False
+ """tell the result object that the column names as rendered are important,
+ but they are also "ordered" vs. what is in the compiled object here.
+ """
+
+ _ordered_columns = True
+ """
+ if False, means we can't be sure the list of entries
+ in _result_columns is actually the rendered order. Usually
+ True unless using an unordered TextualSelect.
+ """
+
+ _loose_column_name_matching = False
+ """tell the result object that the SQL statement is textual, wants to match
+ up to Column objects, and may be using the ._tq_label in the SELECT rather
+ than the base name.
+
+ """
+
+ _numeric_binds = False
+ """
+ True if paramstyle is "numeric". This paramstyle is trickier than
+ all the others.
+
+ """
+
+ _render_postcompile = False
+ """
+ whether to render out POSTCOMPILE params during the compile phase.
+
+ """
+
+ insert_single_values_expr = None
+ """When an INSERT is compiled with a single set of parameters inside
+ a VALUES expression, the string is assigned here, where it can be
+ used for insert batching schemes to rewrite the VALUES expression.
+
+ .. versionadded:: 1.3.8
+
+ """
+
+ literal_execute_params = frozenset()
+ """bindparameter objects that are rendered as literal values at statement
+ execution time.
+
+ """
+
+ post_compile_params = frozenset()
+ """bindparameter objects that are rendered as bound parameter placeholders
+ at statement execution time.
+
+ """
+
+ escaped_bind_names = util.EMPTY_DICT
+ """Late escaping of bound parameter names that has to be converted
+ to the original name when looking in the parameter dictionary.
+
+ """
+
+ has_out_parameters = False
+ """if True, there are bindparam() objects that have the isoutparam
+ flag set."""
+
+ insert_prefetch = update_prefetch = ()
+
+ postfetch_lastrowid = False
+ """if True, and this in insert, use cursor.lastrowid to populate
+ result.inserted_primary_key. """
+
+ _cache_key_bind_match = None
+ """a mapping that will relate the BindParameter object we compile
+ to those that are part of the extracted collection of parameters
+ in the cache key, if we were given a cache key.
+
+ """
+
+ positiontup = None
+ """for a compiled construct that uses a positional paramstyle, will be
+ a sequence of strings, indicating the names of bound parameters in order.
+
+ This is used in order to render bound parameters in their correct order,
+ and is combined with the :attr:`_sql.Compiled.params` dictionary to
+ render parameters.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string` - includes a usage example for
+ debugging use cases.
+
+ """
+
+ inline = False
+
+ def __init__(
+ self,
+ dialect,
+ statement,
+ cache_key=None,
+ column_keys=None,
+ for_executemany=False,
+ linting=NO_LINTING,
+ **kwargs
+ ):
+ """Construct a new :class:`.SQLCompiler` object.
+
+ :param dialect: :class:`.Dialect` to be used
+
+ :param statement: :class:`_expression.ClauseElement` to be compiled
+
+ :param column_keys: a list of column names to be compiled into an
+ INSERT or UPDATE statement.
+
+ :param for_executemany: whether INSERT / UPDATE statements should
+ expect that they are to be invoked in an "executemany" style,
+ which may impact how the statement will be expected to return the
+ values of defaults and autoincrement / sequences and similar.
+ Depending on the backend and driver in use, support for retrieving
+ these values may be disabled which means SQL expressions may
+ be rendered inline, RETURNING may not be rendered, etc.
+
+ :param kwargs: additional keyword arguments to be consumed by the
+ superclass.
+
+ """
+ self.column_keys = column_keys
+
+ self.cache_key = cache_key
+
+ if cache_key:
+ self._cache_key_bind_match = ckbm = {
+ b.key: b for b in cache_key[1]
+ }
+ ckbm.update({b: [b] for b in cache_key[1]})
+
+ # compile INSERT/UPDATE defaults/sequences to expect executemany
+ # style execution, which may mean no pre-execute of defaults,
+ # or no RETURNING
+ self.for_executemany = for_executemany
+
+ self.linting = linting
+
+ # a dictionary of bind parameter keys to BindParameter
+ # instances.
+ self.binds = {}
+
+ # a dictionary of BindParameter instances to "compiled" names
+ # that are actually present in the generated SQL
+ self.bind_names = util.column_dict()
+
+ # stack which keeps track of nested SELECT statements
+ self.stack = []
+
+ # relates label names in the final SQL to a tuple of local
+ # column/label name, ColumnElement object (if any) and
+ # TypeEngine. CursorResult uses this for type processing and
+ # column targeting
+ self._result_columns = []
+
+ # true if the paramstyle is positional
+ self.positional = dialect.positional
+ if self.positional:
+ self.positiontup = []
+ self._numeric_binds = dialect.paramstyle == "numeric"
+ self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+
+ self.ctes = None
+
+ self.label_length = (
+ dialect.label_length or dialect.max_identifier_length
+ )
+
+ # a map which tracks "anonymous" identifiers that are created on
+ # the fly here
+ self.anon_map = prefix_anon_map()
+
+ # a map which tracks "truncated" names based on
+ # dialect.label_length or dialect.max_identifier_length
+ self.truncated_names = {}
+
+ Compiled.__init__(self, dialect, statement, **kwargs)
+
+ if self.isinsert or self.isupdate or self.isdelete:
+ if statement._returning:
+ self.returning = statement._returning
+
+ if self.isinsert or self.isupdate:
+ if statement._inline:
+ self.inline = True
+ elif self.for_executemany and (
+ not self.isinsert
+ or (
+ self.dialect.insert_executemany_returning
+ and statement._return_defaults
+ )
+ ):
+ self.inline = True
+
+ if self.positional and self._numeric_binds:
+ self._apply_numbered_params()
+
+ if self._render_postcompile:
+ self._process_parameters_for_postcompile(_populate_self=True)
+
+ @property
+ def current_executable(self):
+ """Return the current 'executable' that is being compiled.
+
+ This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
+ :class:`_sql.Update`, :class:`_sql.Delete`,
+ :class:`_sql.CompoundSelect` object that is being compiled.
+ Specifically it's assigned to the ``self.stack`` list of elements.
+
+ When a statement like the above is being compiled, it normally
+ is also assigned to the ``.statement`` attribute of the
+ :class:`_sql.Compiler` object. However, all SQL constructs are
+ ultimately nestable, and this attribute should never be consulted
+ by a ``visit_`` method, as it is not guaranteed to be assigned
+ nor guaranteed to correspond to the current statement being compiled.
+
+ .. versionadded:: 1.3.21
+
+ For compatibility with previous versions, use the following
+ recipe::
+
+ statement = getattr(self, "current_executable", False)
+ if statement is False:
+ statement = self.stack[-1]["selectable"]
+
+ For versions 1.4 and above, ensure only .current_executable
+ is used; the format of "self.stack" may change.
+
+
+ """
+ try:
+ return self.stack[-1]["selectable"]
+ except IndexError as ie:
+ util.raise_(
+ IndexError("Compiler does not have a stack entry"),
+ replace_context=ie,
+ )
+
+ @property
+ def prefetch(self):
+ return list(self.insert_prefetch + self.update_prefetch)
+
+ @util.memoized_property
+ def _global_attributes(self):
+ return {}
+
+ @util.memoized_instancemethod
+ def _init_cte_state(self):
+ """Initialize collections related to CTEs only if
+ a CTE is located, to save on the overhead of
+ these collections otherwise.
+
+ """
+ # collect CTEs to tack on top of a SELECT
+ # To store the query to print - Dict[cte, text_query]
+ self.ctes = util.OrderedDict()
+
+ # Detect same CTE references - Dict[(level, name), cte]
+ # Level is required for supporting nesting
+ self.ctes_by_level_name = {}
+
+ # To retrieve key/level in ctes_by_level_name -
+ # Dict[cte_reference, (level, cte_name)]
+ self.level_name_by_cte = {}
+
+ self.ctes_recursive = False
+ if self.positional:
+ self.cte_positional = {}
+
+ @contextlib.contextmanager
+ def _nested_result(self):
+ """special API to support the use case of 'nested result sets'"""
+ result_columns, ordered_columns = (
+ self._result_columns,
+ self._ordered_columns,
+ )
+ self._result_columns, self._ordered_columns = [], False
+
+ try:
+ if self.stack:
+ entry = self.stack[-1]
+ entry["need_result_map_for_nested"] = True
+ else:
+ entry = None
+ yield self._result_columns, self._ordered_columns
+ finally:
+ if entry:
+ entry.pop("need_result_map_for_nested")
+ self._result_columns, self._ordered_columns = (
+ result_columns,
+ ordered_columns,
+ )
+
+ def _apply_numbered_params(self):
+ poscount = itertools.count(1)
+ self.string = re.sub(
+ r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string
+ )
+
+ @util.memoized_property
+ def _bind_processors(self):
+
+ return dict(
+ (
+ key,
+ value,
+ )
+ for key, value in (
+ (
+ self.bind_names[bindparam],
+ bindparam.type._cached_bind_processor(self.dialect)
+ if not bindparam.type._is_tuple_type
+ else tuple(
+ elem_type._cached_bind_processor(self.dialect)
+ for elem_type in bindparam.type.types
+ ),
+ )
+ for bindparam in self.bind_names
+ )
+ if value is not None
+ )
+
+ def is_subquery(self):
+ return len(self.stack) > 1
+
+ @property
+ def sql_compiler(self):
+ return self
+
+ def construct_params(
+ self,
+ params=None,
+ _group_number=None,
+ _check=True,
+ extracted_parameters=None,
+ escape_names=True,
+ ):
+ """return a dictionary of bind parameter keys and values"""
+
+ has_escaped_names = escape_names and bool(self.escaped_bind_names)
+
+ if extracted_parameters:
+ # related the bound parameters collected in the original cache key
+ # to those collected in the incoming cache key. They will not have
+ # matching names but they will line up positionally in the same
+ # way. The parameters present in self.bind_names may be clones of
+ # these original cache key params in the case of DML but the .key
+ # will be guaranteed to match.
+ try:
+ orig_extracted = self.cache_key[1]
+ except TypeError as err:
+ util.raise_(
+ exc.CompileError(
+ "This compiled object has no original cache key; "
+ "can't pass extracted_parameters to construct_params"
+ ),
+ replace_context=err,
+ )
+
+ ckbm = self._cache_key_bind_match
+ resolved_extracted = {
+ bind: extracted
+ for b, extracted in zip(orig_extracted, extracted_parameters)
+ for bind in ckbm[b]
+ }
+ else:
+ resolved_extracted = None
+
+ if params:
+ pd = {}
+ for bindparam, name in self.bind_names.items():
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if has_escaped_names
+ else name
+ )
+
+ if bindparam.key in params:
+ pd[escaped_name] = params[bindparam.key]
+ elif name in params:
+ pd[escaped_name] = params[name]
+
+ elif _check and bindparam.required:
+ if _group_number:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r, "
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
+ else:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r"
+ % bindparam.key,
+ code="cd3x",
+ )
+ else:
+ if resolved_extracted:
+ value_param = resolved_extracted.get(
+ bindparam, bindparam
+ )
+ else:
+ value_param = bindparam
+
+ if bindparam.callable:
+ pd[escaped_name] = value_param.effective_value
+ else:
+ pd[escaped_name] = value_param.value
+ return pd
+ else:
+ pd = {}
+ for bindparam, name in self.bind_names.items():
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if has_escaped_names
+ else name
+ )
+
+ if _check and bindparam.required:
+ if _group_number:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r, "
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
+ else:
+ raise exc.InvalidRequestError(
+ "A value is required for bind parameter %r"
+ % bindparam.key,
+ code="cd3x",
+ )
+
+ if resolved_extracted:
+ value_param = resolved_extracted.get(bindparam, bindparam)
+ else:
+ value_param = bindparam
+
+ if bindparam.callable:
+ pd[escaped_name] = value_param.effective_value
+ else:
+ pd[escaped_name] = value_param.value
+ return pd
+
+ @util.memoized_instancemethod
+ def _get_set_input_sizes_lookup(
+ self, include_types=None, exclude_types=None
+ ):
+ if not hasattr(self, "bind_names"):
+ return None
+
+ dialect = self.dialect
+ dbapi = self.dialect.dbapi
+
+ # _unwrapped_dialect_impl() is necessary so that we get the
+ # correct dialect type for a custom TypeDecorator, or a Variant,
+ # which is also a TypeDecorator. Special types like Interval,
+ # that use TypeDecorator but also might be mapped directly
+ # for a dialect impl, also subclass Emulated first which overrides
+ # this behavior in those cases to behave like the default.
+
+ if include_types is None and exclude_types is None:
+
+ def _lookup_type(typ):
+ dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
+ return dbtype
+
+ else:
+
+ def _lookup_type(typ):
+ # note we get dbtype from the possibly TypeDecorator-wrapped
+ # dialect_impl, but the dialect_impl itself that we use for
+ # include/exclude is the unwrapped version.
+
+ dialect_impl = typ._unwrapped_dialect_impl(dialect)
+
+ dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
+
+ if (
+ dbtype is not None
+ and (
+ exclude_types is None
+ or dbtype not in exclude_types
+ and type(dialect_impl) not in exclude_types
+ )
+ and (
+ include_types is None
+ or dbtype in include_types
+ or type(dialect_impl) in include_types
+ )
+ ):
+ return dbtype
+ else:
+ return None
+
+ inputsizes = {}
+ literal_execute_params = self.literal_execute_params
+
+ for bindparam in self.bind_names:
+ if bindparam in literal_execute_params:
+ continue
+
+ if bindparam.type._is_tuple_type:
+ inputsizes[bindparam] = [
+ _lookup_type(typ) for typ in bindparam.type.types
+ ]
+ else:
+ inputsizes[bindparam] = _lookup_type(bindparam.type)
+
+ return inputsizes
+
+ @property
+ def params(self):
+ """Return the bind param dictionary embedded into this
+ compiled object, for those values that are present.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string` - includes a usage example for
+ debugging use cases.
+
+ """
+ return self.construct_params(_check=False)
+
+ def _process_parameters_for_postcompile(
+ self, parameters=None, _populate_self=False
+ ):
+ """handle special post compile parameters.
+
+ These include:
+
+ * "expanding" parameters -typically IN tuples that are rendered
+ on a per-parameter basis for an otherwise fixed SQL statement string.
+
+ * literal_binds compiled with the literal_execute flag. Used for
+ things like SQL Server "TOP N" where the driver does not accommodate
+ N as a bound parameter.
+
+ """
+
+ if parameters is None:
+ parameters = self.construct_params(escape_names=False)
+
+ expanded_parameters = {}
+ if self.positional:
+ positiontup = []
+ else:
+ positiontup = None
+
+ processors = self._bind_processors
+
+ new_processors = {}
+
+ if self.positional and self._numeric_binds:
+ # I'm not familiar with any DBAPI that uses 'numeric'.
+ # strategy would likely be to make use of numbers greater than
+ # the highest number present; then for expanding parameters,
+ # append them to the end of the parameter list. that way
+ # we avoid having to renumber all the existing parameters.
+ raise NotImplementedError(
+ "'post-compile' bind parameters are not supported with "
+ "the 'numeric' paramstyle at this time."
+ )
+
+ replacement_expressions = {}
+ to_update_sets = {}
+
+ # notes:
+ # *unescaped* parameter names in:
+ # self.bind_names, self.binds, self._bind_processors
+ #
+ # *escaped* parameter names in:
+ # construct_params(), replacement_expressions
+
+ for name in (
+ self.positiontup if self.positional else self.bind_names.values()
+ ):
+ escaped_name = (
+ self.escaped_bind_names.get(name, name)
+ if self.escaped_bind_names
+ else name
+ )
+
+ parameter = self.binds[name]
+ if parameter in self.literal_execute_params:
+ if escaped_name not in replacement_expressions:
+ value = parameters.pop(name)
+
+ replacement_expressions[
+ escaped_name
+ ] = self.render_literal_bindparam(
+ parameter, render_literal_value=value
+ )
+ continue
+
+ if parameter in self.post_compile_params:
+ if escaped_name in replacement_expressions:
+ to_update = to_update_sets[escaped_name]
+ else:
+ # we are removing the parameter from parameters
+ # because it is a list value, which is not expected by
+ # TypeEngine objects that would otherwise be asked to
+ # process it. the single name is being replaced with
+ # individual numbered parameters for each value in the
+ # param.
+ #
+ # note we are also inserting *escaped* parameter names
+ # into the given dictionary. default dialect will
+ # use these param names directly as they will not be
+ # in the escaped_bind_names dictionary.
+ values = parameters.pop(name)
+
+ leep = self._literal_execute_expanding_parameter
+ to_update, replacement_expr = leep(
+ escaped_name, parameter, values
+ )
+
+ to_update_sets[escaped_name] = to_update
+ replacement_expressions[escaped_name] = replacement_expr
+
+ if not parameter.literal_execute:
+ parameters.update(to_update)
+ if parameter.type._is_tuple_type:
+ new_processors.update(
+ (
+ "%s_%s_%s" % (name, i, j),
+ processors[name][j - 1],
+ )
+ for i, tuple_element in enumerate(values, 1)
+ for j, value in enumerate(tuple_element, 1)
+ if name in processors
+ and processors[name][j - 1] is not None
+ )
+ else:
+ new_processors.update(
+ (key, processors[name])
+ for key, value in to_update
+ if name in processors
+ )
+ if self.positional:
+ positiontup.extend(name for name, value in to_update)
+ expanded_parameters[name] = [
+ expand_key for expand_key, value in to_update
+ ]
+ elif self.positional:
+ positiontup.append(name)
+
+ def process_expanding(m):
+ key = m.group(1)
+ expr = replacement_expressions[key]
+
+ # if POSTCOMPILE included a bind_expression, render that
+ # around each element
+ if m.group(2):
+ tok = m.group(2).split("~~")
+ be_left, be_right = tok[1], tok[3]
+ expr = ", ".join(
+ "%s%s%s" % (be_left, exp, be_right)
+ for exp in expr.split(", ")
+ )
+ return expr
+
+ statement = re.sub(
+ r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+ process_expanding,
+ self.string,
+ )
+
+ expanded_state = ExpandedState(
+ statement,
+ parameters,
+ new_processors,
+ positiontup,
+ expanded_parameters,
+ )
+
+ if _populate_self:
+ # this is for the "render_postcompile" flag, which is not
+ # otherwise used internally and is for end-user debugging and
+ # special use cases.
+ self.string = expanded_state.statement
+ self._bind_processors.update(expanded_state.processors)
+ self.positiontup = expanded_state.positiontup
+ self.post_compile_params = frozenset()
+ for key in expanded_state.parameter_expansion:
+ bind = self.binds.pop(key)
+ self.bind_names.pop(bind)
+ for value, expanded_key in zip(
+ bind.value, expanded_state.parameter_expansion[key]
+ ):
+ self.binds[expanded_key] = new_param = bind._with_value(
+ value
+ )
+ self.bind_names[new_param] = expanded_key
+
+ return expanded_state
+
+ @util.preload_module("sqlalchemy.engine.cursor")
+ def _create_result_map(self):
+ """utility method used for unit tests only."""
+ cursor = util.preloaded.engine_cursor
+ return cursor.CursorResultMetaData._create_description_match_map(
+ self._result_columns
+ )
+
+ @util.memoized_property
+ def _within_exec_param_key_getter(self):
+ getter = self._key_getters_for_crud_column[2]
+ return getter
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.engine.result")
+ def _inserted_primary_key_from_lastrowid_getter(self):
+ result = util.preloaded.engine_result
+
+ param_key_getter = self._within_exec_param_key_getter
+ table = self.statement.table
+
+ getters = [
+ (operator.methodcaller("get", param_key_getter(col), None), col)
+ for col in table.primary_key
+ ]
+
+ autoinc_col = table._autoincrement_column
+ if autoinc_col is not None:
+ # apply type post processors to the lastrowid
+ proc = autoinc_col.type._cached_result_processor(
+ self.dialect, None
+ )
+ else:
+ proc = None
+
+ row_fn = result.result_tuple([col.key for col in table.primary_key])
+
+ def get(lastrowid, parameters):
+ """given cursor.lastrowid value and the parameters used for INSERT,
+ return a "row" that represents the primary key, either by
+ using the "lastrowid" or by extracting values from the parameters
+ that were sent along with the INSERT.
+
+ """
+ if proc is not None:
+ lastrowid = proc(lastrowid)
+
+ if lastrowid is None:
+ return row_fn(getter(parameters) for getter, col in getters)
+ else:
+ return row_fn(
+ lastrowid if col is autoinc_col else getter(parameters)
+ for getter, col in getters
+ )
+
+ return get
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.engine.result")
+ def _inserted_primary_key_from_returning_getter(self):
+ result = util.preloaded.engine_result
+
+ param_key_getter = self._within_exec_param_key_getter
+ table = self.statement.table
+
+ ret = {col: idx for idx, col in enumerate(self.returning)}
+
+ getters = [
+ (operator.itemgetter(ret[col]), True)
+ if col in ret
+ else (
+ operator.methodcaller("get", param_key_getter(col), None),
+ False,
+ )
+ for col in table.primary_key
+ ]
+
+ row_fn = result.result_tuple([col.key for col in table.primary_key])
+
+ def get(row, parameters):
+ return row_fn(
+ getter(row) if use_row else getter(parameters)
+ for getter, use_row in getters
+ )
+
+ return get
+
+ def default_from(self):
+ """Called when a SELECT statement has no froms, and no FROM clause is
+ to be appended.
+
+ Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
+
+ """
+ return ""
+
+ def visit_grouping(self, grouping, asfrom=False, **kwargs):
+ return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
+
+ def visit_select_statement_grouping(self, grouping, **kwargs):
+ return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
+
+ def visit_label_reference(
+ self, element, within_columns_clause=False, **kwargs
+ ):
+ if self.stack and self.dialect.supports_simple_order_by_label:
+ compile_state = self.stack[-1]["compile_state"]
+
+ (
+ with_cols,
+ only_froms,
+ only_cols,
+ ) = compile_state._label_resolve_dict
+ if within_columns_clause:
+ resolve_dict = only_froms
+ else:
+ resolve_dict = only_cols
+
+ # this can be None in the case that a _label_reference()
+ # were subject to a replacement operation, in which case
+ # the replacement of the Label element may have changed
+ # to something else like a ColumnClause expression.
+ order_by_elem = element.element._order_by_label_element
+
+ if (
+ order_by_elem is not None
+ and order_by_elem.name in resolve_dict
+ and order_by_elem.shares_lineage(
+ resolve_dict[order_by_elem.name]
+ )
+ ):
+ kwargs[
+ "render_label_as_label"
+ ] = element.element._order_by_label_element
+ return self.process(
+ element.element,
+ within_columns_clause=within_columns_clause,
+ **kwargs
+ )
+
+ def visit_textual_label_reference(
+ self, element, within_columns_clause=False, **kwargs
+ ):
+ if not self.stack:
+ # compiling the element outside of the context of a SELECT
+ return self.process(element._text_clause)
+
+ compile_state = self.stack[-1]["compile_state"]
+ with_cols, only_froms, only_cols = compile_state._label_resolve_dict
+ try:
+ if within_columns_clause:
+ col = only_froms[element.element]
+ else:
+ col = with_cols[element.element]
+ except KeyError as err:
+ coercions._no_text_coercion(
+ element.element,
+ extra=(
+ "Can't resolve label reference for ORDER BY / "
+ "GROUP BY / DISTINCT etc."
+ ),
+ exc_cls=exc.CompileError,
+ err=err,
+ )
+ else:
+ kwargs["render_label_as_label"] = col
+ return self.process(
+ col, within_columns_clause=within_columns_clause, **kwargs
+ )
+
+ def visit_label(
+ self,
+ label,
+ add_to_result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False,
+ render_label_as_label=None,
+ result_map_targets=(),
+ **kw
+ ):
+ # only render labels within the columns clause
+ # or ORDER BY clause of a select. dialect-specific compilers
+ # can modify this behavior.
+ render_label_with_as = (
+ within_columns_clause and not within_label_clause
+ )
+ render_label_only = render_label_as_label is label
+
+ if render_label_only or render_label_with_as:
+ if isinstance(label.name, elements._truncated_label):
+ labelname = self._truncated_identifier("colident", label.name)
+ else:
+ labelname = label.name
+
+ if render_label_with_as:
+ if add_to_result_map is not None:
+ add_to_result_map(
+ labelname,
+ label.name,
+ (label, labelname) + label._alt_names + result_map_targets,
+ label.type,
+ )
+ return (
+ label.element._compiler_dispatch(
+ self,
+ within_columns_clause=True,
+ within_label_clause=True,
+ **kw
+ )
+ + OPERATORS[operators.as_]
+ + self.preparer.format_label(label, labelname)
+ )
+ elif render_label_only:
+ return self.preparer.format_label(label, labelname)
+ else:
+ return label.element._compiler_dispatch(
+ self, within_columns_clause=False, **kw
+ )
+
+ def _fallback_column_name(self, column):
+ raise exc.CompileError(
+ "Cannot compile Column object until " "its 'name' is assigned."
+ )
+
+ def visit_lambda_element(self, element, **kw):
+ sql_element = element._resolved
+ return self.process(sql_element, **kw)
+
+ def visit_column(
+ self,
+ column,
+ add_to_result_map=None,
+ include_table=True,
+ result_map_targets=(),
+ **kwargs
+ ):
+ name = orig_name = column.name
+ if name is None:
+ name = self._fallback_column_name(column)
+
+ is_literal = column.is_literal
+ if not is_literal and isinstance(name, elements._truncated_label):
+ name = self._truncated_identifier("colident", name)
+
+ if add_to_result_map is not None:
+ targets = (column, name, column.key) + result_map_targets
+ if column._tq_label:
+ targets += (column._tq_label,)
+
+ add_to_result_map(name, orig_name, targets, column.type)
+
+ if is_literal:
+ # note we are not currently accommodating for
+ # literal_column(quoted_name('ident', True)) here
+ name = self.escape_literal_column(name)
+ else:
+ name = self.preparer.quote(name)
+ table = column.table
+ if table is None or not include_table or not table.named_with_column:
+ return name
+ else:
+ effective_schema = self.preparer.schema_for_object(table)
+
+ if effective_schema:
+ schema_prefix = (
+ self.preparer.quote_schema(effective_schema) + "."
+ )
+ else:
+ schema_prefix = ""
+ tablename = table.name
+ if isinstance(tablename, elements._truncated_label):
+ tablename = self._truncated_identifier("alias", tablename)
+
+ return schema_prefix + self.preparer.quote(tablename) + "." + name
+
+ def visit_collation(self, element, **kw):
+ return self.preparer.format_collation(element.collation)
+
+ def visit_fromclause(self, fromclause, **kwargs):
+ return fromclause.name
+
+ def visit_index(self, index, **kwargs):
+ return index.name
+
+ def visit_typeclause(self, typeclause, **kw):
+ kw["type_expression"] = typeclause
+ kw["identifier_preparer"] = self.preparer
+ return self.dialect.type_compiler.process(typeclause.type, **kw)
+
+ def post_process_text(self, text):
+ if self.preparer._double_percents:
+ text = text.replace("%", "%%")
+ return text
+
+ def escape_literal_column(self, text):
+ if self.preparer._double_percents:
+ text = text.replace("%", "%%")
+ return text
+
+ def visit_textclause(self, textclause, add_to_result_map=None, **kw):
+ def do_bindparam(m):
+ name = m.group(1)
+ if name in textclause._bindparams:
+ return self.process(textclause._bindparams[name], **kw)
+ else:
+ return self.bindparam_string(name, **kw)
+
+ if not self.stack:
+ self.isplaintext = True
+
+ if add_to_result_map:
+ # text() object is present in the columns clause of a
+ # select(). Add a no-name entry to the result map so that
+ # row[text()] produces a result
+ add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
+
+ # un-escape any \:params
+ return BIND_PARAMS_ESC.sub(
+ lambda m: m.group(1),
+ BIND_PARAMS.sub(
+ do_bindparam, self.post_process_text(textclause.text)
+ ),
+ )
+
+ def visit_textual_select(
+ self, taf, compound_index=None, asfrom=False, **kw
+ ):
+
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ new_entry = {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": taf,
+ }
+ self.stack.append(new_entry)
+
+ if taf._independent_ctes:
+ for cte in taf._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
+
+ if populate_result_map:
+ self._ordered_columns = (
+ self._textual_ordered_columns
+ ) = taf.positional
+
+ # enable looser result column matching when the SQL text links to
+ # Column objects by name only
+ self._loose_column_name_matching = not taf.positional and bool(
+ taf.column_args
+ )
+
+ for c in taf.column_args:
+ self.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=self._add_to_result_map,
+ )
+
+ text = self.process(taf.element, **kw)
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+ self.stack.pop(-1)
+
+ return text
+
+ def visit_null(self, expr, **kw):
+ return "NULL"
+
+ def visit_true(self, expr, **kw):
+ if self.dialect.supports_native_boolean:
+ return "true"
+ else:
+ return "1"
+
+ def visit_false(self, expr, **kw):
+ if self.dialect.supports_native_boolean:
+ return "false"
+ else:
+ return "0"
+
+ def _generate_delimited_list(self, elements, separator, **kw):
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in elements)
+ if s
+ )
+
+ def _generate_delimited_and_list(self, clauses, **kw):
+
+ lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
+ operators.and_,
+ elements.True_._singleton,
+ elements.False_._singleton,
+ clauses,
+ )
+ if lcc == 1:
+ return clauses[0]._compiler_dispatch(self, **kw)
+ else:
+ separator = OPERATORS[operators.and_]
+ return separator.join(
+ s
+ for s in (c._compiler_dispatch(self, **kw) for c in clauses)
+ if s
+ )
+
+ def visit_tuple(self, clauselist, **kw):
+ return "(%s)" % self.visit_clauselist(clauselist, **kw)
+
+ def visit_clauselist(self, clauselist, **kw):
+ sep = clauselist.operator
+ if sep is None:
+ sep = " "
+ else:
+ sep = OPERATORS[clauselist.operator]
+
+ return self._generate_delimited_list(clauselist.clauses, sep, **kw)
+
+ def visit_case(self, clause, **kwargs):
+ x = "CASE "
+ if clause.value is not None:
+ x += clause.value._compiler_dispatch(self, **kwargs) + " "
+ for cond, result in clause.whens:
+ x += (
+ "WHEN "
+ + cond._compiler_dispatch(self, **kwargs)
+ + " THEN "
+ + result._compiler_dispatch(self, **kwargs)
+ + " "
+ )
+ if clause.else_ is not None:
+ x += (
+ "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
+ )
+ x += "END"
+ return x
+
+ def visit_type_coerce(self, type_coerce, **kw):
+ return type_coerce.typed_expression._compiler_dispatch(self, **kw)
+
+ def visit_cast(self, cast, **kwargs):
+ return "CAST(%s AS %s)" % (
+ cast.clause._compiler_dispatch(self, **kwargs),
+ cast.typeclause._compiler_dispatch(self, **kwargs),
+ )
+
+ def _format_frame_clause(self, range_, **kw):
+
+ return "%s AND %s" % (
+ "UNBOUNDED PRECEDING"
+ if range_[0] is elements.RANGE_UNBOUNDED
+ else "CURRENT ROW"
+ if range_[0] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[0])), **kw),)
+ if range_[0] < 0
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[0]), **kw),),
+ "UNBOUNDED FOLLOWING"
+ if range_[1] is elements.RANGE_UNBOUNDED
+ else "CURRENT ROW"
+ if range_[1] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[1])), **kw),)
+ if range_[1] < 0
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[1]), **kw),),
+ )
+
+ def visit_over(self, over, **kwargs):
+ if over.range_:
+ range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
+ over.range_, **kwargs
+ )
+ elif over.rows:
+ range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
+ over.rows, **kwargs
+ )
+ else:
+ range_ = None
+
+ return "%s OVER (%s)" % (
+ over.element._compiler_dispatch(self, **kwargs),
+ " ".join(
+ [
+ "%s BY %s"
+ % (word, clause._compiler_dispatch(self, **kwargs))
+ for word, clause in (
+ ("PARTITION", over.partition_by),
+ ("ORDER", over.order_by),
+ )
+ if clause is not None and len(clause)
+ ]
+ + ([range_] if range_ else [])
+ ),
+ )
+
+ def visit_withingroup(self, withingroup, **kwargs):
+ return "%s WITHIN GROUP (ORDER BY %s)" % (
+ withingroup.element._compiler_dispatch(self, **kwargs),
+ withingroup.order_by._compiler_dispatch(self, **kwargs),
+ )
+
+ def visit_funcfilter(self, funcfilter, **kwargs):
+ return "%s FILTER (WHERE %s)" % (
+ funcfilter.func._compiler_dispatch(self, **kwargs),
+ funcfilter.criterion._compiler_dispatch(self, **kwargs),
+ )
+
+ def visit_extract(self, extract, **kwargs):
+ field = self.extract_map.get(extract.field, extract.field)
+ return "EXTRACT(%s FROM %s)" % (
+ field,
+ extract.expr._compiler_dispatch(self, **kwargs),
+ )
+
+ def visit_scalar_function_column(self, element, **kw):
+ compiled_fn = self.visit_function(element.fn, **kw)
+ compiled_col = self.visit_column(element, **kw)
+ return "(%s).%s" % (compiled_fn, compiled_col)
+
+ def visit_function(self, func, add_to_result_map=None, **kwargs):
+ if add_to_result_map is not None:
+ add_to_result_map(func.name, func.name, (), func.type)
+
+ disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
+ if disp:
+ text = disp(func, **kwargs)
+ else:
+ name = FUNCTIONS.get(func._deannotate().__class__, None)
+ if name:
+ if func._has_args:
+ name += "%(expr)s"
+ else:
+ name = func.name
+ name = (
+ self.preparer.quote(name)
+ if self.preparer._requires_quotes_illegal_chars(name)
+ or isinstance(name, elements.quoted_name)
+ else name
+ )
+ name = name + "%(expr)s"
+ text = ".".join(
+ [
+ (
+ self.preparer.quote(tok)
+ if self.preparer._requires_quotes_illegal_chars(tok)
+ or isinstance(name, elements.quoted_name)
+ else tok
+ )
+ for tok in func.packagenames
+ ]
+ + [name]
+ ) % {"expr": self.function_argspec(func, **kwargs)}
+
+ if func._with_ordinality:
+ text += " WITH ORDINALITY"
+ return text
+
+ def visit_next_value_func(self, next_value, **kw):
+ return self.visit_sequence(next_value.sequence)
+
+ def visit_sequence(self, sequence, **kw):
+ raise NotImplementedError(
+ "Dialect '%s' does not support sequence increments."
+ % self.dialect.name
+ )
+
+ def function_argspec(self, func, **kwargs):
+ return func.clause_expr._compiler_dispatch(self, **kwargs)
+
+ def visit_compound_select(
+ self, cs, asfrom=False, compound_index=None, **kwargs
+ ):
+ toplevel = not self.stack
+
+ compile_state = cs._compile_state_factory(cs, self, **kwargs)
+
+ if toplevel and not self.compile_state:
+ self.compile_state = compile_state
+
+ compound_stmt = compile_state.statement
+
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+ need_result_map = toplevel or (
+ not compound_index
+ and entry.get("need_result_map_for_compound", False)
+ )
+
+ # indicates there is already a CompoundSelect in play
+ if compound_index == 0:
+ entry["select_0"] = cs
+
+ self.stack.append(
+ {
+ "correlate_froms": entry["correlate_froms"],
+ "asfrom_froms": entry["asfrom_froms"],
+ "selectable": cs,
+ "compile_state": compile_state,
+ "need_result_map_for_compound": need_result_map,
+ }
+ )
+
+ if compound_stmt._independent_ctes:
+ for cte in compound_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kwargs)
+
+ keyword = self.compound_keywords.get(cs.keyword)
+
+ text = (" " + keyword + " ").join(
+ (
+ c._compiler_dispatch(
+ self, asfrom=asfrom, compound_index=i, **kwargs
+ )
+ for i, c in enumerate(cs.selects)
+ )
+ )
+
+ kwargs["include_table"] = False
+ text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
+ text += self.order_by_clause(cs, **kwargs)
+ if cs._has_row_limiting_clause:
+ text += self._row_limit_clause(cs, **kwargs)
+
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
+
+ self.stack.pop(-1)
+ return text
+
+ def _row_limit_clause(self, cs, **kwargs):
+ if cs._fetch_clause is not None:
+ return self.fetch_clause(cs, **kwargs)
+ else:
+ return self.limit_clause(cs, **kwargs)
+
+ def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
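+ # builds a method name such as "visit_like_op_binary" or
+ # "visit_custom_op_unary_operator" and returns the bound method if this
+ # compiler defines it, else None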
+ attrname = "visit_%s_%s%s" % (
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
+ return getattr(self, attrname, None)
+
+ def visit_unary(
+ self, unary, add_to_result_map=None, result_map_targets=(), **kw
+ ):
+
+ if add_to_result_map is not None:
+ result_map_targets += (unary,)
+ kw["add_to_result_map"] = add_to_result_map
+ kw["result_map_targets"] = result_map_targets
+
+ if unary.operator:
+ if unary.modifier:
+ raise exc.CompileError(
+ "Unary expression does not support operator "
+ "and modifier simultaneously"
+ )
+ disp = self._get_operator_dispatch(
+ unary.operator, "unary", "operator"
+ )
+ if disp:
+ return disp(unary, unary.operator, **kw)
+ else:
+ return self._generate_generic_unary_operator(
+ unary, OPERATORS[unary.operator], **kw
+ )
+ elif unary.modifier:
+ disp = self._get_operator_dispatch(
+ unary.modifier, "unary", "modifier"
+ )
+ if disp:
+ return disp(unary, unary.modifier, **kw)
+ else:
+ return self._generate_generic_unary_modifier(
+ unary, OPERATORS[unary.modifier], **kw
+ )
+ else:
+ raise exc.CompileError(
+ "Unary expression has no operator or modifier"
+ )
+
+ def visit_is_true_unary_operator(self, element, operator, **kw):
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
+ return self.process(element.element, **kw)
+ else:
+ return "%s = 1" % self.process(element.element, **kw)
+
+ def visit_is_false_unary_operator(self, element, operator, **kw):
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
+ return "NOT %s" % self.process(element.element, **kw)
+ else:
+ return "%s = 0" % self.process(element.element, **kw)
+
+ def visit_not_match_op_binary(self, binary, operator, **kw):
+ return "NOT %s" % self.visit_binary(
+ binary, override_operator=operators.match_op
+ )
+
+ def visit_not_in_op_binary(self, binary, operator, **kw):
+ # The brackets are required in the NOT IN operation because the empty
+ # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
+ # The presence of the OR makes the brackets required.
+ return "(%s)" % self._generate_generic_binary(
+ binary, OPERATORS[operator], **kw
+ )
+
+ def visit_empty_set_op_expr(self, type_, expand_op):
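+ # renders the "IN ()" empty-collection case; the fragment returned here
+ # is spliced into the parentheses of the expanding parameter, so e.g.
+ # "col IN ()" becomes "col IN (NULL) AND (1 != 1)" (illustrative)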
+ if expand_op is operators.not_in_op:
+ if len(type_) > 1:
+ return "(%s)) OR (1 = 1" % (
+ ", ".join("NULL" for element in type_)
+ )
+ else:
+ return "NULL) OR (1 = 1"
+ elif expand_op is operators.in_op:
+ if len(type_) > 1:
+ return "(%s)) AND (1 != 1" % (
+ ", ".join("NULL" for element in type_)
+ )
+ else:
+ return "NULL) AND (1 != 1"
+ else:
+ return self.visit_empty_set_expr(type_)
+
+ def visit_empty_set_expr(self, element_types):
+ raise NotImplementedError(
+ "Dialect '%s' does not support empty set expression."
+ % self.dialect.name
+ )
+
+ def _literal_execute_expanding_parameter_literal_binds(
+ self, parameter, values
+ ):
+
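+ # renders an expanding (IN) parameter directly as literal values rather
+ # than bound parameters; used for literal_binds / literal_execute
+ # compilation, including tuple-valued IN expressions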
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
+ if not values:
+ if typ_dialect_impl._is_tuple_type:
+ replacement_expression = (
+ "VALUES " if self.dialect.tuple_in_values else ""
+ ) + self.visit_empty_set_op_expr(
+ parameter.type.types, parameter.expand_op
+ )
+
+ else:
+ replacement_expression = self.visit_empty_set_op_expr(
+ [parameter.type], parameter.expand_op
+ )
+
+ elif typ_dialect_impl._is_tuple_type or (
+ typ_dialect_impl._isnull
+ and isinstance(values[0], util.collections_abc.Sequence)
+ and not isinstance(
+ values[0], util.string_types + util.binary_types
+ )
+ ):
+
+ replacement_expression = (
+ "VALUES " if self.dialect.tuple_in_values else ""
+ ) + ", ".join(
+ "(%s)"
+ % (
+ ", ".join(
+ self.render_literal_value(value, param_type)
+ for value, param_type in zip(
+ tuple_element, parameter.type.types
+ )
+ )
+ )
+ for i, tuple_element in enumerate(values)
+ )
+ else:
+ replacement_expression = ", ".join(
+ self.render_literal_value(value, parameter.type)
+ for value in values
+ )
+
+ return (), replacement_expression
+
+ def _literal_execute_expanding_parameter(self, name, parameter, values):
+
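+ # expands a single "expanding" bind parameter into one bound parameter
+ # per value (name_1, name_2, ...) at execution time, returning the new
+ # (name, value) pairs along with the replacement SQL text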
+ if parameter.literal_execute:
+ return self._literal_execute_expanding_parameter_literal_binds(
+ parameter, values
+ )
+
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+
+ if not values:
+ to_update = []
+ if typ_dialect_impl._is_tuple_type:
+
+ replacement_expression = self.visit_empty_set_op_expr(
+ parameter.type.types, parameter.expand_op
+ )
+ else:
+ replacement_expression = self.visit_empty_set_op_expr(
+ [parameter.type], parameter.expand_op
+ )
+
+ elif typ_dialect_impl._is_tuple_type or (
+ typ_dialect_impl._isnull
+ and isinstance(values[0], util.collections_abc.Sequence)
+ and not isinstance(
+ values[0], util.string_types + util.binary_types
+ )
+ ):
+ assert not typ_dialect_impl._is_array
+ to_update = [
+ ("%s_%s_%s" % (name, i, j), value)
+ for i, tuple_element in enumerate(values, 1)
+ for j, value in enumerate(tuple_element, 1)
+ ]
+ replacement_expression = (
+ "VALUES " if self.dialect.tuple_in_values else ""
+ ) + ", ".join(
+ "(%s)"
+ % (
+ ", ".join(
+ self.bindtemplate
+ % {"name": to_update[i * len(tuple_element) + j][0]}
+ for j, value in enumerate(tuple_element)
+ )
+ )
+ for i, tuple_element in enumerate(values)
+ )
+ else:
+ to_update = [
+ ("%s_%s" % (name, i), value)
+ for i, value in enumerate(values, 1)
+ ]
+ replacement_expression = ", ".join(
+ self.bindtemplate % {"name": key} for key, value in to_update
+ )
+
+ return to_update, replacement_expression
+
+ def visit_binary(
+ self,
+ binary,
+ override_operator=None,
+ eager_grouping=False,
+ from_linter=None,
+ lateral_from_linter=None,
+ **kw
+ ):
+ if from_linter and operators.is_comparison(binary.operator):
+ if lateral_from_linter is not None:
+ enclosing_lateral = kw["enclosing_lateral"]
+ lateral_from_linter.edges.update(
+ itertools.product(
+ binary.left._from_objects + [enclosing_lateral],
+ binary.right._from_objects + [enclosing_lateral],
+ )
+ )
+ else:
+ from_linter.edges.update(
+ itertools.product(
+ binary.left._from_objects, binary.right._from_objects
+ )
+ )
+
+ # don't allow "? = ?" to render
+ if (
+ self.ansi_bind_rules
+ and isinstance(binary.left, elements.BindParameter)
+ and isinstance(binary.right, elements.BindParameter)
+ ):
+ kw["literal_execute"] = True
+
+ operator_ = override_operator or binary.operator
+ disp = self._get_operator_dispatch(operator_, "binary", None)
+ if disp:
+ return disp(binary, operator_, **kw)
+ else:
+ try:
+ opstring = OPERATORS[operator_]
+ except KeyError as err:
+ util.raise_(
+ exc.UnsupportedCompilationError(self, operator_),
+ replace_context=err,
+ )
+ else:
+ return self._generate_generic_binary(
+ binary,
+ opstring,
+ from_linter=from_linter,
+ lateral_from_linter=lateral_from_linter,
+ **kw
+ )
+
+ def visit_function_as_comparison_op_binary(self, element, operator, **kw):
+ return self.process(element.sql_function, **kw)
+
+ def visit_mod_binary(self, binary, operator, **kw):
+ if self.preparer._double_percents:
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
+ else:
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
+
+ def visit_custom_op_binary(self, element, operator, **kw):
+ kw["eager_grouping"] = operator.eager_grouping
+ return self._generate_generic_binary(
+ element,
+ " " + self.escape_literal_column(operator.opstring) + " ",
+ **kw
+ )
+
+ def visit_custom_op_unary_operator(self, element, operator, **kw):
+ return self._generate_generic_unary_operator(
+ element, self.escape_literal_column(operator.opstring) + " ", **kw
+ )
+
+ def visit_custom_op_unary_modifier(self, element, operator, **kw):
+ return self._generate_generic_unary_modifier(
+ element, " " + self.escape_literal_column(operator.opstring), **kw
+ )
+
+ def _generate_generic_binary(
+ self, binary, opstring, eager_grouping=False, **kw
+ ):
+
+ _in_binary = kw.get("_in_binary", False)
+
+ kw["_in_binary"] = True
+ kw["_binary_op"] = binary.operator
+ text = (
+ binary.left._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ + opstring
+ + binary.right._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ )
+
+ if _in_binary and eager_grouping:
+ text = "(%s)" % text
+ return text
+
+ def _generate_generic_unary_operator(self, unary, opstring, **kw):
+ return opstring + unary.element._compiler_dispatch(self, **kw)
+
+ def _generate_generic_unary_modifier(self, unary, opstring, **kw):
+ return unary.element._compiler_dispatch(self, **kw) + opstring
+
+ @util.memoized_property
+ def _like_percent_literal(self):
+ return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
+
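+ # contains(), startswith() and endswith() are rewritten into LIKE
+ # expressions by concatenating '%' onto the right-hand side, e.g.
+ # (illustrative) col.contains("x") -> "col LIKE '%' || :col_1 || '%'",
+ # with the concatenation operator supplied by the dialect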
+ def visit_contains_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right).concat(percent)
+ return self.visit_like_op_binary(binary, operator, **kw)
+
+ def visit_not_contains_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right).concat(percent)
+ return self.visit_not_like_op_binary(binary, operator, **kw)
+
+ def visit_startswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent._rconcat(binary.right)
+ return self.visit_like_op_binary(binary, operator, **kw)
+
+ def visit_not_startswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent._rconcat(binary.right)
+ return self.visit_not_like_op_binary(binary, operator, **kw)
+
+ def visit_endswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right)
+ return self.visit_like_op_binary(binary, operator, **kw)
+
+ def visit_not_endswith_op_binary(self, binary, operator, **kw):
+ binary = binary._clone()
+ percent = self._like_percent_literal
+ binary.right = percent.concat(binary.right)
+ return self.visit_not_like_op_binary(binary, operator, **kw)
+
+ def visit_like_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+
+ # TODO: use ternary here, not "and"/ "or"
+ return "%s LIKE %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_not_like_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "%s NOT LIKE %s" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "lower(%s) LIKE lower(%s)" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_not_ilike_op_binary(self, binary, operator, **kw):
+ escape = binary.modifiers.get("escape", None)
+ return "lower(%s) NOT LIKE lower(%s)" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
+
+ def visit_between_op_binary(self, binary, operator, **kw):
+ symmetric = binary.modifiers.get("symmetric", False)
+ return self._generate_generic_binary(
+ binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
+ )
+
+ def visit_not_between_op_binary(self, binary, operator, **kw):
+ symmetric = binary.modifiers.get("symmetric", False)
+ return self._generate_generic_binary(
+ binary,
+ " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
+ **kw
+ )
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ raise exc.CompileError(
+ "%s dialect does not support regular expressions"
+ % self.dialect.name
+ )
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ raise exc.CompileError(
+ "%s dialect does not support regular expressions"
+ % self.dialect.name
+ )
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ raise exc.CompileError(
+ "%s dialect does not support regular expression replacements"
+ % self.dialect.name
+ )
+
+ def visit_bindparam(
+ self,
+ bindparam,
+ within_columns_clause=False,
+ literal_binds=False,
+ skip_bind_expression=False,
+ literal_execute=False,
+ render_postcompile=False,
+ **kwargs
+ ):
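+ # central rendering point for a BindParameter: depending on flags this
+ # becomes the dialect's bind template (e.g. ":name" or "%s"), an inline
+ # literal value (literal_binds), or a "__[POSTCOMPILE_<name>]" marker
+ # resolved at execution time (expanding IN lists, literal_execute)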
+ if not skip_bind_expression:
+ impl = bindparam.type.dialect_impl(self.dialect)
+ if impl._has_bind_expression:
+ bind_expression = impl.bind_expression(bindparam)
+ wrapped = self.process(
+ bind_expression,
+ skip_bind_expression=True,
+ within_columns_clause=within_columns_clause,
+ literal_binds=literal_binds,
+ literal_execute=literal_execute,
+ render_postcompile=render_postcompile,
+ **kwargs
+ )
+ if bindparam.expanding:
+ # for postcompile w/ expanding, move the "wrapped" part
+ # of this into the inside
+ m = re.match(
+ r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
+ )
+ wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
+ m.group(2),
+ m.group(1),
+ m.group(3),
+ )
+ return wrapped
+
+ if not literal_binds:
+ literal_execute = (
+ literal_execute
+ or bindparam.literal_execute
+ or (within_columns_clause and self.ansi_bind_rules)
+ )
+ post_compile = literal_execute or bindparam.expanding
+ else:
+ post_compile = False
+
+ if literal_binds:
+ ret = self.render_literal_bindparam(
+ bindparam, within_columns_clause=True, **kwargs
+ )
+ if bindparam.expanding:
+ ret = "(%s)" % ret
+ return ret
+
+ name = self._truncate_bindparam(bindparam)
+
+ if name in self.binds:
+ existing = self.binds[name]
+ if existing is not bindparam:
+ if (
+ (existing.unique or bindparam.unique)
+ and not existing.proxy_set.intersection(
+ bindparam.proxy_set
+ )
+ and not existing._cloned_set.intersection(
+ bindparam._cloned_set
+ )
+ ):
+ raise exc.CompileError(
+ "Bind parameter '%s' conflicts with "
+ "unique bind parameter of the same name" % name
+ )
+ elif existing.expanding != bindparam.expanding:
+ raise exc.CompileError(
+ "Can't reuse bound parameter name '%s' in both "
+ "'expanding' (e.g. within an IN expression) and "
+ "non-expanding contexts. If this parameter is to "
+ "receive a list/array value, set 'expanding=True' on "
+ "it for expressions that aren't IN, otherwise use "
+ "a different parameter name." % (name,)
+ )
+ elif existing._is_crud or bindparam._is_crud:
+ raise exc.CompileError(
+ "bindparam() name '%s' is reserved "
+ "for automatic usage in the VALUES or SET "
+ "clause of this "
+ "insert/update statement. Please use a "
+ "name other than column name when using bindparam() "
+ "with insert() or update() (for example, 'b_%s')."
+ % (bindparam.key, bindparam.key)
+ )
+
+ self.binds[bindparam.key] = self.binds[name] = bindparam
+
+ # if we are given a cache key that we're going to match against,
+ # relate the bindparam here to one that is most likely present
+ # in the "extracted params" portion of the cache key. this is used
+ # to set up a positional mapping that is used to determine the
+ # correct parameters for a subsequent use of this compiled statement
+ # with a different set of parameter values. here, we accommodate
+ # parameters that may have been cloned both before and after the cache
+ # key was generated.
+ ckbm = self._cache_key_bind_match
+ if ckbm:
+ for bp in bindparam._cloned_set:
+ if bp.key in ckbm:
+ cb = ckbm[bp.key]
+ ckbm[cb].append(bindparam)
+
+ if bindparam.isoutparam:
+ self.has_out_parameters = True
+
+ if post_compile:
+ if render_postcompile:
+ self._render_postcompile = True
+
+ if literal_execute:
+ self.literal_execute_params |= {bindparam}
+ else:
+ self.post_compile_params |= {bindparam}
+
+ ret = self.bindparam_string(
+ name,
+ post_compile=post_compile,
+ expanding=bindparam.expanding,
+ **kwargs
+ )
+
+ if bindparam.expanding:
+ ret = "(%s)" % ret
+ return ret
+
+ def render_literal_bindparam(
+ self, bindparam, render_literal_value=NO_ARG, **kw
+ ):
+ if render_literal_value is not NO_ARG:
+ value = render_literal_value
+ else:
+ if bindparam.value is None and bindparam.callable is None:
+ op = kw.get("_binary_op", None)
+ if op and op not in (operators.is_, operators.is_not):
+ util.warn_limited(
+ "Bound parameter '%s' rendering literal NULL in a SQL "
+ "expression; comparisons to NULL should not use "
+ "operators outside of 'is' or 'is not'",
+ (bindparam.key,),
+ )
+ return self.process(sqltypes.NULLTYPE, **kw)
+ value = bindparam.effective_value
+
+ if bindparam.expanding:
+ leep = self._literal_execute_expanding_parameter_literal_binds
+ to_update, replacement_expr = leep(bindparam, value)
+ return replacement_expr
+ else:
+ return self.render_literal_value(value, bindparam.type)
+
+ def render_literal_value(self, value, type_):
+ """Render the value of a bind parameter as a quoted literal.
+
+ This is used for statement sections that do not accept bind parameters
+ on the target driver/database.
+
+ This should be implemented by subclasses using the quoting services
+ of the DBAPI.
+
+ """
+
+ processor = type_._cached_literal_processor(self.dialect)
+ if processor:
+ return processor(value)
+ else:
+ raise NotImplementedError(
+ "Don't know how to literal-quote value %r" % value
+ )
+
+ def _truncate_bindparam(self, bindparam):
+ if bindparam in self.bind_names:
+ return self.bind_names[bindparam]
+
+ bind_name = bindparam.key
+ if isinstance(bind_name, elements._truncated_label):
+ bind_name = self._truncated_identifier("bindparam", bind_name)
+
+ # add to bind_names for translation
+ self.bind_names[bindparam] = bind_name
+
+ return bind_name
+
+ def _truncated_identifier(self, ident_class, name):
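+ # truncates an anonymized identifier to the dialect's label_length,
+ # disambiguating with a hex counter, and caches the result per
+ # (ident_class, name) so repeated references render identically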
+ if (ident_class, name) in self.truncated_names:
+ return self.truncated_names[(ident_class, name)]
+
+ anonname = name.apply_map(self.anon_map)
+
+ if len(anonname) > self.label_length - 6:
+ counter = self.truncated_names.get(ident_class, 1)
+ truncname = (
+ anonname[0 : max(self.label_length - 6, 0)]
+ + "_"
+ + hex(counter)[2:]
+ )
+ self.truncated_names[ident_class] = counter + 1
+ else:
+ truncname = anonname
+ self.truncated_names[(ident_class, name)] = truncname
+ return truncname
+
+ def _anonymize(self, name):
+ return name % self.anon_map
+
+ def bindparam_string(
+ self,
+ name,
+ positional_names=None,
+ post_compile=False,
+ expanding=False,
+ escaped_from=None,
+ **kw
+ ):
+
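+ # produces the final parameter token: records the name in positiontup
+ # for positional dialects, escapes characters the DBAPI paramstyle
+ # cannot accept in named parameters, and emits the
+ # "__[POSTCOMPILE_<name>]" marker for post-compile parameters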
+ if self.positional:
+ if positional_names is not None:
+ positional_names.append(name)
+ else:
+ self.positiontup.append(name)
+ elif not escaped_from:
+
+ if _BIND_TRANSLATE_RE.search(name):
+ # not quite the translate use case as we want to
+ # also get a quick boolean if we even found
+ # unusual characters in the name
+ new_name = _BIND_TRANSLATE_RE.sub(
+ lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
+ name,
+ )
+ escaped_from = name
+ name = new_name
+
+ if escaped_from:
+ if not self.escaped_bind_names:
+ self.escaped_bind_names = {}
+ self.escaped_bind_names[escaped_from] = name
+ if post_compile:
+ return "__[POSTCOMPILE_%s]" % name
+ else:
+ return self.bindtemplate % {"name": name}
+
+ def visit_cte(
+ self,
+ cte,
+ asfrom=False,
+ ashint=False,
+ fromhints=None,
+ visiting_cte=None,
+ from_linter=None,
+ **kwargs
+ ):
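+ # accumulates this CTE's text into self.ctes keyed by (level, name);
+ # the actual WITH clause is emitted later by _render_cte_clause().
+ # when referenced as a FROM element, only the CTE's alias name is
+ # returned here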
+ self._init_cte_state()
+
+ kwargs["visiting_cte"] = cte
+
+ cte_name = cte.name
+
+ if isinstance(cte_name, elements._truncated_label):
+ cte_name = self._truncated_identifier("alias", cte_name)
+
+ is_new_cte = True
+ embedded_in_current_named_cte = False
+
+ _reference_cte = cte._get_reference_cte()
+
+ if _reference_cte in self.level_name_by_cte:
+ cte_level, _ = self.level_name_by_cte[_reference_cte]
+ assert _ == cte_name
+ else:
+ cte_level = len(self.stack) if cte.nesting else 1
+
+ cte_level_name = (cte_level, cte_name)
+ if cte_level_name in self.ctes_by_level_name:
+ existing_cte = self.ctes_by_level_name[cte_level_name]
+ embedded_in_current_named_cte = visiting_cte is existing_cte
+
+ # we've generated a same-named CTE that we are enclosed in,
+ # or this is the same CTE. just return the name.
+ if cte is existing_cte._restates or cte is existing_cte:
+ is_new_cte = False
+ elif existing_cte is cte._restates:
+ # we've generated a same-named CTE that is
+ # enclosed in us - we take precedence, so
+ # discard the text for the "inner".
+ del self.ctes[existing_cte]
+
+ existing_cte_reference_cte = existing_cte._get_reference_cte()
+
+ # TODO: determine if these assertions are correct. they
+ # pass for current test cases
+ # assert existing_cte_reference_cte is _reference_cte
+ # assert existing_cte_reference_cte is existing_cte
+
+ del self.level_name_by_cte[existing_cte_reference_cte]
+ else:
+ # if the two CTEs are deep-copy identical, consider them
+ # the same, **if** they are clones, that is, they came from
+ # the ORM or other visit method
+ if (
+ cte._is_clone_of is not None
+ or existing_cte._is_clone_of is not None
+ ) and cte.compare(existing_cte):
+ is_new_cte = False
+ else:
+ raise exc.CompileError(
+ "Multiple, unrelated CTEs found with "
+ "the same name: %r" % cte_name
+ )
+
+ if not asfrom and not is_new_cte:
+ return None
+
+ if cte._cte_alias is not None:
+ pre_alias_cte = cte._cte_alias
+ cte_pre_alias_name = cte._cte_alias.name
+ if isinstance(cte_pre_alias_name, elements._truncated_label):
+ cte_pre_alias_name = self._truncated_identifier(
+ "alias", cte_pre_alias_name
+ )
+ else:
+ pre_alias_cte = cte
+ cte_pre_alias_name = None
+
+ if is_new_cte:
+ self.ctes_by_level_name[cte_level_name] = cte
+ self.level_name_by_cte[_reference_cte] = cte_level_name
+
+ if (
+ "autocommit" in cte.element._execution_options
+ and "autocommit" not in self.execution_options
+ ):
+ self.execution_options = self.execution_options.union(
+ {
+ "autocommit": cte.element._execution_options[
+ "autocommit"
+ ]
+ }
+ )
+
+ if pre_alias_cte not in self.ctes:
+ self.visit_cte(pre_alias_cte, **kwargs)
+
+ if not cte_pre_alias_name and cte not in self.ctes:
+ if cte.recursive:
+ self.ctes_recursive = True
+ text = self.preparer.format_alias(cte, cte_name)
+ if cte.recursive:
+ if isinstance(cte.element, selectable.Select):
+ col_source = cte.element
+ elif isinstance(cte.element, selectable.CompoundSelect):
+ col_source = cte.element.selects[0]
+ else:
+ assert False, "cte should only be against SelectBase"
+
+ # TODO: can we get at the .columns_plus_names collection
+ # that is already (or will be?) generated for the SELECT
+ # rather than calling twice?
+ recur_cols = [
+ # TODO: proxy_name is not technically safe,
+ # see test_cte->
+ # test_with_recursive_no_name_currently_buggy. not
+ # clear what should be done with such a case
+ fallback_label_name or proxy_name
+ for (
+ _,
+ proxy_name,
+ fallback_label_name,
+ c,
+ repeated,
+ ) in (col_source._generate_columns_plus_names(True))
+ if not repeated
+ ]
+
+ text += "(%s)" % (
+ ", ".join(
+ self.preparer.format_label_name(
+ ident, anon_map=self.anon_map
+ )
+ for ident in recur_cols
+ )
+ )
+
+ if self.positional:
+ kwargs["positional_names"] = self.cte_positional[cte] = []
+
+ assert kwargs.get("subquery", False) is False
+
+ if not self.stack:
+ # toplevel, this is a stringify of the
+ # cte directly. just compile the inner
+ # the way alias() does.
+ return cte.element._compiler_dispatch(
+ self, asfrom=asfrom, **kwargs
+ )
+ else:
+ prefixes = self._generate_prefixes(
+ cte, cte._prefixes, **kwargs
+ )
+ inner = cte.element._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
+
+ text += " AS %s\n(%s)" % (prefixes, inner)
+
+ if cte._suffixes:
+ text += " " + self._generate_prefixes(
+ cte, cte._suffixes, **kwargs
+ )
+
+ self.ctes[cte] = text
+
+ if asfrom:
+ if from_linter:
+ from_linter.froms[cte] = cte_name
+
+ if not is_new_cte and embedded_in_current_named_cte:
+ return self.preparer.format_alias(cte, cte_name)
+
+ if cte_pre_alias_name:
+ text = self.preparer.format_alias(cte, cte_pre_alias_name)
+ if self.preparer._requires_quotes(cte_name):
+ cte_name = self.preparer.quote(cte_name)
+ text += self.get_render_as_alias_suffix(cte_name)
+ return text
+ else:
+ return self.preparer.format_alias(cte, cte_name)
+
+ def visit_table_valued_alias(self, element, **kw):
+ if element.joins_implicitly:
+ kw["from_linter"] = None
+ if element._is_lateral:
+ return self.visit_lateral(element, **kw)
+ else:
+ return self.visit_alias(element, **kw)
+
+ def visit_table_valued_column(self, element, **kw):
+ return self.visit_column(element, **kw)
+
+ def visit_alias(
+ self,
+ alias,
+ asfrom=False,
+ ashint=False,
+ iscrud=False,
+ fromhints=None,
+ subquery=False,
+ lateral=False,
+ enclosing_alias=None,
+ from_linter=None,
+ **kwargs
+ ):
+
+ if lateral:
+ if "enclosing_lateral" not in kwargs:
+ # if lateral is set and enclosing_lateral is not
+ # present, we assume we are being called directly
+ # from visit_lateral() and we need to set enclosing_lateral.
+ assert alias._is_lateral
+ kwargs["enclosing_lateral"] = alias
+
+ # for lateral objects, we track a second from_linter that is...
+ # lateral! to the level above us.
+ if (
+ from_linter
+ and "lateral_from_linter" not in kwargs
+ and "enclosing_lateral" in kwargs
+ ):
+ kwargs["lateral_from_linter"] = from_linter
+
+ if enclosing_alias is not None and enclosing_alias.element is alias:
+ inner = alias.element._compiler_dispatch(
+ self,
+ asfrom=asfrom,
+ ashint=ashint,
+ iscrud=iscrud,
+ fromhints=fromhints,
+ lateral=lateral,
+ enclosing_alias=alias,
+ **kwargs
+ )
+ if subquery and (asfrom or lateral):
+ inner = "(%s)" % (inner,)
+ return inner
+ else:
+ enclosing_alias = kwargs["enclosing_alias"] = alias
+
+ if asfrom or ashint:
+ if isinstance(alias.name, elements._truncated_label):
+ alias_name = self._truncated_identifier("alias", alias.name)
+ else:
+ alias_name = alias.name
+
+ if ashint:
+ return self.preparer.format_alias(alias, alias_name)
+ elif asfrom:
+ if from_linter:
+ from_linter.froms[alias] = alias_name
+
+ inner = alias.element._compiler_dispatch(
+ self, asfrom=True, lateral=lateral, **kwargs
+ )
+ if subquery:
+ inner = "(%s)" % (inner,)
+
+ ret = inner + self.get_render_as_alias_suffix(
+ self.preparer.format_alias(alias, alias_name)
+ )
+
+ if alias._supports_derived_columns and alias._render_derived:
+ ret += "(%s)" % (
+ ", ".join(
+ "%s%s"
+ % (
+ self.preparer.quote(col.name),
+ " %s"
+ % self.dialect.type_compiler.process(
+ col.type, **kwargs
+ )
+ if alias._render_derived_w_types
+ else "",
+ )
+ for col in alias.c
+ )
+ )
+
+ if fromhints and alias in fromhints:
+ ret = self.format_from_hint_text(
+ ret, alias, fromhints[alias], iscrud
+ )
+
+ return ret
+ else:
+ # note we cancel the "subquery" flag here as well
+ return alias.element._compiler_dispatch(
+ self, lateral=lateral, **kwargs
+ )
+
+ def visit_subquery(self, subquery, **kw):
+ kw["subquery"] = True
+ return self.visit_alias(subquery, **kw)
+
+ def visit_lateral(self, lateral_, **kw):
+ kw["lateral"] = True
+ return "LATERAL %s" % self.visit_alias(lateral_, **kw)
+
+ def visit_tablesample(self, tablesample, asfrom=False, **kw):
+ text = "%s TABLESAMPLE %s" % (
+ self.visit_alias(tablesample, asfrom=True, **kw),
+ tablesample._get_method()._compiler_dispatch(self, **kw),
+ )
+
+ if tablesample.seed is not None:
+ text += " REPEATABLE (%s)" % (
+ tablesample.seed._compiler_dispatch(self, **kw)
+ )
+
+ return text
+
+ def visit_values(self, element, asfrom=False, from_linter=None, **kw):
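+ # renders a VALUES construct, e.g. (illustrative)
+ # "VALUES (:param_1, :param_2), (:param_3, :param_4)", optionally as a
+ # named or LATERAL FROM element with an explicit column list appended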
+ kw.setdefault("literal_binds", element.literal_binds)
+ v = "VALUES %s" % ", ".join(
+ self.process(
+ elements.Tuple(
+ types=element._column_types, *elem
+ ).self_group(),
+ **kw
+ )
+ for chunk in element._data
+ for elem in chunk
+ )
+
+ if isinstance(element.name, elements._truncated_label):
+ name = self._truncated_identifier("values", element.name)
+ else:
+ name = element.name
+
+ if element._is_lateral:
+ lateral = "LATERAL "
+ else:
+ lateral = ""
+
+ if asfrom:
+ if from_linter:
+ from_linter.froms[element] = (
+ name if name is not None else "(unnamed VALUES element)"
+ )
+
+ if name:
+ v = "%s(%s)%s (%s)" % (
+ lateral,
+ v,
+ self.get_render_as_alias_suffix(self.preparer.quote(name)),
+ (
+ ", ".join(
+ c._compiler_dispatch(
+ self, include_table=False, **kw
+ )
+ for c in element.columns
+ )
+ ),
+ )
+ else:
+ v = "%s(%s)" % (lateral, v)
+ return v
+
+ def get_render_as_alias_suffix(self, alias_name_text):
+ return " AS " + alias_name_text
+
+ def _add_to_result_map(self, keyname, name, objects, type_):
+ if keyname is None or keyname == "*":
+ self._ordered_columns = False
+ self._textual_ordered_columns = True
+ if type_._is_tuple_type:
+ raise exc.CompileError(
+ "Most backends don't support SELECTing "
+ "from a tuple() object. If this is an ORM query, "
+ "consider using the Bundle object."
+ )
+ self._result_columns.append((keyname, name, objects, type_))
+
+ def _label_returning_column(self, stmt, column, column_clause_args=None):
+ """Render a column with necessary labels inside of a RETURNING clause.
+
+ This method is provided for individual dialects in place of calling
+ the _label_select_column method directly, so that the two use cases
+ of RETURNING vs. SELECT can be disambiguated going forward.
+
+ .. versionadded:: 1.4.21
+
+ """
+ return self._label_select_column(
+ None,
+ column,
+ True,
+ False,
+ {} if column_clause_args is None else column_clause_args,
+ )
+
+ def _label_select_column(
+ self,
+ select,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=None,
+ proxy_name=None,
+ fallback_label_name=None,
+ within_columns_clause=True,
+ column_is_repeated=False,
+ need_column_expressions=False,
+ ):
+ """produce labeled columns present in a select()."""
+ impl = column.type.dialect_impl(self.dialect)
+
+ if impl._has_column_expression and (
+ need_column_expressions or populate_result_map
+ ):
+ col_expr = impl.column_expression(column)
+ else:
+ col_expr = column
+
+ if populate_result_map:
+ # pass an "add_to_result_map" callable into the compilation
+ # of embedded columns. this collects information about the
+ # column as it will be fetched in the result and is coordinated
+ # with cursor.description when the query is executed.
+ add_to_result_map = self._add_to_result_map
+
+ # if the SELECT statement told us this column is a repeat,
+ # wrap the callable with one that prevents the addition of the
+ # targets
+ if column_is_repeated:
+ _add_to_result_map = add_to_result_map
+
+ def add_to_result_map(keyname, name, objects, type_):
+ _add_to_result_map(keyname, name, (), type_)
+
+ # if we redefined col_expr for type expressions, wrap the
+ # callable with one that adds the original column to the targets
+ elif col_expr is not column:
+ _add_to_result_map = add_to_result_map
+
+ def add_to_result_map(keyname, name, objects, type_):
+ _add_to_result_map(
+ keyname, name, (column,) + objects, type_
+ )
+
+ else:
+ add_to_result_map = None
+
+ # this method is used by some of the dialects for RETURNING,
+ # which has different inputs. _label_returning_column was added
+ # as the better target for this now; however, for 1.4 we will keep
+ # _label_select_column directly compatible with this use case.
+ # these assertions right now set up the current expected inputs
+ assert within_columns_clause, (
+ "_label_select_column is only relevant within "
+ "the columns clause of a SELECT or RETURNING"
+ )
+
+ if isinstance(column, elements.Label):
+ if col_expr is not column:
+ result_expr = _CompileLabel(
+ col_expr, column.name, alt_names=(column.element,)
+ )
+ else:
+ result_expr = col_expr
+
+ elif name:
+ # here, _columns_plus_names has determined there's an explicit
+ # label name we need to use. this is the default for
+ # tablenames_plus_columnnames as well as when columns are being
+ # deduplicated on name
+
+ assert (
+ proxy_name is not None
+ ), "proxy_name is required if 'name' is passed"
+
+ result_expr = _CompileLabel(
+ col_expr,
+ name,
+ alt_names=(
+ proxy_name,
+ # this is a hack to allow legacy result column lookups
+ # to work as they did before; this goes away in 2.0.
+ # TODO: this only seems to be tested indirectly
+ # via test/orm/test_deprecations.py. should be a
+ # resultset test for this
+ column._tq_label,
+ ),
+ )
+ else:
+ # determine here whether this column should be rendered in
+ # a labelled context or not, as we were given no required label
+ # name from the caller. Here we apply heuristics based on the kind
+ # of SQL expression involved.
+
+ if col_expr is not column:
+ # type-specific expression wrapping the given column,
+ # so we render a label
+ render_with_label = True
+ elif isinstance(column, elements.ColumnClause):
+ # table-bound column, we render its name as a label if we are
+ # inside of a subquery only
+ render_with_label = (
+ asfrom
+ and not column.is_literal
+ and column.table is not None
+ )
+ elif isinstance(column, elements.TextClause):
+ render_with_label = False
+ elif isinstance(column, elements.UnaryExpression):
+ render_with_label = column.wraps_column_expression or asfrom
+ elif (
+ # general class of expressions that don't have a SQL-column
+ # addressable name. includes scalar selects, bind parameters,
+ # SQL functions, others
+ not isinstance(column, elements.NamedColumn)
+ # deeper check that indicates there's no natural "name" to
+ # this element, which accommodates custom SQL constructs
+ # that might have a ".name" attribute (but aren't SQL
+ # functions) but are not implementing this more recently added
+ # base class. in theory the "NamedColumn" check should be
+ # enough, however here we seek to maintain legacy behaviors
+ # as well.
+ and column._non_anon_label is None
+ ):
+ render_with_label = True
+ else:
+ render_with_label = False
+
+ if render_with_label:
+ if not fallback_label_name:
+ # used by the RETURNING case right now. we generate it
+ # here as 3rd party dialects may be referring to
+ # _label_select_column method directly instead of the
+ # just-added _label_returning_column method
+ assert not column_is_repeated
+ fallback_label_name = column._anon_name_label
+
+ fallback_label_name = (
+ elements._truncated_label(fallback_label_name)
+ if not isinstance(
+ fallback_label_name, elements._truncated_label
+ )
+ else fallback_label_name
+ )
+
+ result_expr = _CompileLabel(
+ col_expr, fallback_label_name, alt_names=(proxy_name,)
+ )
+ else:
+ result_expr = col_expr
+
+ column_clause_args.update(
+ within_columns_clause=within_columns_clause,
+ add_to_result_map=add_to_result_map,
+ )
+ return result_expr._compiler_dispatch(self, **column_clause_args)
+
+ def format_from_hint_text(self, sqltext, table, hint, iscrud):
+ hinttext = self.get_from_hint_text(table, hint)
+ if hinttext:
+ sqltext += " " + hinttext
+ return sqltext
+
+ def get_select_hint_text(self, byfroms):
+ return None
+
+ def get_from_hint_text(self, table, text):
+ return None
+
+ def get_crud_hint_text(self, table, text):
+ return None
+
+ def get_statement_hint_text(self, hint_texts):
+ return " ".join(hint_texts)
+
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
+
+ def _display_froms_for_select(
+ self, select_stmt, asfrom, lateral=False, **kw
+ ):
+ # utility method to help external dialects
+ # get the correct from list for a select.
+ # specifically the oracle dialect needs this feature
+ # right now.
+ toplevel = not self.stack
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ compile_state = select_stmt._compile_state_factory(select_stmt, self)
+
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
+
+ if asfrom and not lateral:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
+ else:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms,
+ )
+ return froms
+
+ translate_select_structure = None
+ """if not ``None``, should be a callable which accepts ``(select_stmt,
+ **kw)`` and returns a select object. This is used for structural changes,
+ mostly to accommodate LIMIT/OFFSET schemes.
+
+ """
+
+ def visit_select(
+ self,
+ select_stmt,
+ asfrom=False,
+ insert_into=False,
+ fromhints=None,
+ compound_index=None,
+ select_wraps_for=None,
+ lateral=False,
+ from_linter=None,
+ **kwargs
+ ):
+ assert select_wraps_for is None, (
+ "SQLAlchemy 1.4 requires use of "
+ "the translate_select_structure hook for structural "
+ "translations of SELECT objects"
+ )
+
+ # initial setup of SELECT. the compile_state_factory may now
+ # be creating a totally different SELECT from the one that was
+ # passed in. for ORM use this will convert from an ORM-state
+ # SELECT to a regular "Core" SELECT. other composed operations
+ # such as computation of joins will be performed.
+
+ kwargs["within_columns_clause"] = False
+
+ compile_state = select_stmt._compile_state_factory(
+ select_stmt, self, **kwargs
+ )
+ select_stmt = compile_state.statement
+
+ toplevel = not self.stack
+
+ if toplevel and not self.compile_state:
+ self.compile_state = compile_state
+
+ is_embedded_select = compound_index is not None or insert_into
+
+ # translate step for Oracle, SQL Server which often need to
+ # restructure the SELECT to allow for LIMIT/OFFSET and possibly
+ # other conditions
+ if self.translate_select_structure:
+ new_select_stmt = self.translate_select_structure(
+ select_stmt, asfrom=asfrom, **kwargs
+ )
+
+ # if SELECT was restructured, maintain a link to the originals
+ # and assemble a new compile state
+ if new_select_stmt is not select_stmt:
+ compile_state_wraps_for = compile_state
+ select_wraps_for = select_stmt
+ select_stmt = new_select_stmt
+
+ compile_state = select_stmt._compile_state_factory(
+ select_stmt, self, **kwargs
+ )
+ select_stmt = compile_state.statement
+
+ entry = self._default_stack_entry if toplevel else self.stack[-1]
+
+ populate_result_map = need_column_expressions = (
+ toplevel
+ or entry.get("need_result_map_for_compound", False)
+ or entry.get("need_result_map_for_nested", False)
+ )
+
+ # indicates there is a CompoundSelect in play and we are not the
+ # first select
+ if compound_index:
+ populate_result_map = False
+
+ # this was first proposed as part of #3372; however, it is not
+ # reached in current tests and could possibly be an assertion
+ # instead.
+ if not populate_result_map and "add_to_result_map" in kwargs:
+ del kwargs["add_to_result_map"]
+
+ froms = self._setup_select_stack(
+ select_stmt, compile_state, entry, asfrom, lateral, compound_index
+ )
+
+ column_clause_args = kwargs.copy()
+ column_clause_args.update(
+ {"within_label_clause": False, "within_columns_clause": False}
+ )
+
+ text = "SELECT " # we're off to a good start !
+
+ if select_stmt._hints:
+ hint_text, byfrom = self._setup_select_hints(select_stmt)
+ if hint_text:
+ text += hint_text + " "
+ else:
+ byfrom = None
+
+ if select_stmt._independent_ctes:
+ for cte in select_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kwargs)
+
+ if select_stmt._prefixes:
+ text += self._generate_prefixes(
+ select_stmt, select_stmt._prefixes, **kwargs
+ )
+
+ text += self.get_select_precolumns(select_stmt, **kwargs)
+ # the actual list of columns to print in the SELECT column list.
+ inner_columns = [
+ c
+ for c in [
+ self._label_select_column(
+ select_stmt,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=name,
+ proxy_name=proxy_name,
+ fallback_label_name=fallback_label_name,
+ column_is_repeated=repeated,
+ need_column_expressions=need_column_expressions,
+ )
+ for (
+ name,
+ proxy_name,
+ fallback_label_name,
+ column,
+ repeated,
+ ) in compile_state.columns_plus_names
+ ]
+ if c is not None
+ ]
+
+ if populate_result_map and select_wraps_for is not None:
+ # if this select was generated from translate_select,
+ # rewrite the targeted columns in the result map
+
+ translate = dict(
+ zip(
+ [
+ name
+ for (
+ key,
+ proxy_name,
+ fallback_label_name,
+ name,
+ repeated,
+ ) in compile_state.columns_plus_names
+ ],
+ [
+ name
+ for (
+ key,
+ proxy_name,
+ fallback_label_name,
+ name,
+ repeated,
+ ) in compile_state_wraps_for.columns_plus_names
+ ],
+ )
+ )
+
+ self._result_columns = [
+ (key, name, tuple(translate.get(o, o) for o in obj), type_)
+ for key, name, obj, type_ in self._result_columns
+ ]
+
+ text = self._compose_select_body(
+ text,
+ select_stmt,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
+ )
+
+ if select_stmt._statement_hints:
+ per_dialect = [
+ ht
+ for (dialect_name, ht) in select_stmt._statement_hints
+ if dialect_name in ("*", self.dialect.name)
+ ]
+ if per_dialect:
+ text += " " + self.get_statement_hint_text(per_dialect)
+
+ if self.ctes:
+ # In a compound query, CTEs are shared at the compound level
+ if not is_embedded_select:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(nesting_level=nesting_level) + text
+ )
+
+ if select_stmt._suffixes:
+ text += " " + self._generate_prefixes(
+ select_stmt, select_stmt._suffixes, **kwargs
+ )
+
+ self.stack.pop(-1)
+
+ return text
+
+ def _setup_select_hints(self, select):
+ byfrom = dict(
+ [
+ (
+ from_,
+ hinttext
+ % {"name": from_._compiler_dispatch(self, ashint=True)},
+ )
+ for (from_, dialect), hinttext in select._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
+ hint_text = self.get_select_hint_text(byfrom)
+ return hint_text, byfrom
+
+ def _setup_select_stack(
+ self, select, compile_state, entry, asfrom, lateral, compound_index
+ ):
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
+
+ if compound_index == 0:
+ entry["select_0"] = select
+ elif compound_index:
+ select_0 = entry["select_0"]
+ numcols = len(select_0._all_selected_columns)
+
+ if len(compile_state.columns_plus_names) != numcols:
+ raise exc.CompileError(
+ "All selectables passed to "
+ "CompoundSelect must have identical numbers of "
+ "columns; select #%d has %d columns, select "
+ "#%d has %d"
+ % (
+ 1,
+ numcols,
+ compound_index + 1,
+ len(select._all_selected_columns),
+ )
+ )
+
+ if asfrom and not lateral:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms.difference(
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
+ else:
+ froms = compile_state._get_display_froms(
+ explicit_correlate_froms=correlate_froms,
+ implicit_correlate_froms=asfrom_froms,
+ )
+
+ new_correlate_froms = set(selectable._from_objects(*froms))
+ all_correlate_froms = new_correlate_froms.union(correlate_froms)
+
+ new_entry = {
+ "asfrom_froms": new_correlate_froms,
+ "correlate_froms": all_correlate_froms,
+ "selectable": select,
+ "compile_state": compile_state,
+ }
+ self.stack.append(new_entry)
+
+ return froms
+
+ def _compose_select_body(
+ self,
+ text,
+ select,
+ compile_state,
+ inner_columns,
+ froms,
+ byfrom,
+ toplevel,
+ kwargs,
+ ):
+ text += ", ".join(inner_columns)
+
+ if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+ from_linter = FromLinter({}, set())
+ warn_linting = self.linting & WARN_LINTING
+ if toplevel:
+ self.from_linter = from_linter
+ else:
+ from_linter = None
+ warn_linting = False
+
+ if froms:
+ text += " \nFROM "
+
+ if select._hints:
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self,
+ asfrom=True,
+ fromhints=byfrom,
+ from_linter=from_linter,
+ **kwargs
+ )
+ for f in froms
+ ]
+ )
+ else:
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self,
+ asfrom=True,
+ from_linter=from_linter,
+ **kwargs
+ )
+ for f in froms
+ ]
+ )
+ else:
+ text += self.default_from()
+
+ if select._where_criteria:
+ t = self._generate_delimited_and_list(
+ select._where_criteria, from_linter=from_linter, **kwargs
+ )
+ if t:
+ text += " \nWHERE " + t
+
+ if warn_linting:
+ from_linter.warn()
+
+ if select._group_by_clauses:
+ text += self.group_by_clause(select, **kwargs)
+
+ if select._having_criteria:
+ t = self._generate_delimited_and_list(
+ select._having_criteria, **kwargs
+ )
+ if t:
+ text += " \nHAVING " + t
+
+ if select._order_by_clauses:
+ text += self.order_by_clause(select, **kwargs)
+
+ if select._has_row_limiting_clause:
+ text += self._row_limit_clause(select, **kwargs)
+
+ if select._for_update_arg is not None:
+ text += self.for_update_clause(select, **kwargs)
+
+ return text
+
+ def _generate_prefixes(self, stmt, prefixes, **kw):
+ clause = " ".join(
+ prefix._compiler_dispatch(self, **kw)
+ for prefix, dialect_name in prefixes
+ if dialect_name is None or dialect_name == self.dialect.name
+ )
+ if clause:
+ clause += " "
+ return clause
+
+ def _render_cte_clause(
+ self,
+ nesting_level=None,
+ include_following_stack=False,
+ ):
+ """
+ include_following_stack
+ Also render the nesting CTEs on the next stack. Useful for
+ SQL structures like UNION or INSERT that can wrap SELECT
+ statements containing nesting CTEs.
+ """
+ if not self.ctes:
+ return ""
+
+ if nesting_level and nesting_level > 1:
+ ctes = util.OrderedDict()
+ for cte in list(self.ctes.keys()):
+ cte_level, cte_name = self.level_name_by_cte[
+ cte._get_reference_cte()
+ ]
+ is_rendered_level = cte_level == nesting_level or (
+ include_following_stack and cte_level == nesting_level + 1
+ )
+ if not (cte.nesting and is_rendered_level):
+ continue
+
+ ctes[cte] = self.ctes[cte]
+
+ else:
+ ctes = self.ctes
+
+ if not ctes:
+ return ""
+
+ ctes_recursive = any([cte.recursive for cte in ctes])
+
+ if self.positional:
+ self.positiontup = (
+ sum([self.cte_positional[cte] for cte in ctes], [])
+ + self.positiontup
+ )
+ cte_text = self.get_cte_preamble(ctes_recursive) + " "
+ cte_text += ", \n".join([txt for txt in ctes.values()])
+ cte_text += "\n "
+
+ if nesting_level and nesting_level > 1:
+ for cte in list(ctes.keys()):
+ cte_level, cte_name = self.level_name_by_cte[
+ cte._get_reference_cte()
+ ]
+ del self.ctes[cte]
+ del self.ctes_by_level_name[(cte_level, cte_name)]
+ del self.level_name_by_cte[cte._get_reference_cte()]
+
+ return cte_text
+
+ def get_cte_preamble(self, recursive):
+ if recursive:
+ return "WITH RECURSIVE"
+ else:
+ return "WITH"
+
+ def get_select_precolumns(self, select, **kw):
+ """Called when building a ``SELECT`` statement, position is just
+ before column list.
+
+ """
+ if select._distinct_on:
+ util.warn_deprecated(
+ "DISTINCT ON is currently supported only by the PostgreSQL "
+ "dialect. Use of DISTINCT ON for other backends is currently "
+ "silently ignored, however this usage is deprecated, and will "
+ "raise CompileError in a future release for all backends "
+ "that do not support this syntax.",
+ version="1.4",
+ )
+ return "DISTINCT " if select._distinct else ""
+
+ def group_by_clause(self, select, **kw):
+ """allow dialects to customize how GROUP BY is rendered."""
+
+ group_by = self._generate_delimited_list(
+ select._group_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
+ if group_by:
+ return " GROUP BY " + group_by
+ else:
+ return ""
+
+ def order_by_clause(self, select, **kw):
+ """allow dialects to customize how ORDER BY is rendered."""
+
+ order_by = self._generate_delimited_list(
+ select._order_by_clauses, OPERATORS[operators.comma_op], **kw
+ )
+
+ if order_by:
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+ def for_update_clause(self, select, **kw):
+ return " FOR UPDATE"
+
+ def returning_clause(self, stmt, returning_cols):
+ raise exc.CompileError(
+ "RETURNING is not supported by this "
+ "dialect's statement compiler."
+ )
+
+ def limit_clause(self, select, **kw):
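+ # generic LIMIT/OFFSET rendering; "LIMIT -1" is emitted when only an
+ # OFFSET is present, since some backends require a LIMIT in order to
+ # accept an OFFSET (dialects with other syntaxes override this method)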
+ text = ""
+ if select._limit_clause is not None:
+ text += "\n LIMIT " + self.process(select._limit_clause, **kw)
+ if select._offset_clause is not None:
+ if select._limit_clause is None:
+ text += "\n LIMIT -1"
+ text += " OFFSET " + self.process(select._offset_clause, **kw)
+ return text
+
+ def fetch_clause(self, select, **kw):
+ text = ""
+ if select._offset_clause is not None:
+ text += "\n OFFSET %s ROWS" % self.process(
+ select._offset_clause, **kw
+ )
+ if select._fetch_clause is not None:
+ text += "\n FETCH FIRST %s%s ROWS %s" % (
+ self.process(select._fetch_clause, **kw),
+ " PERCENT" if select._fetch_clause_options["percent"] else "",
+ "WITH TIES"
+ if select._fetch_clause_options["with_ties"]
+ else "ONLY",
+ )
+ return text
+
+ def visit_table(
+ self,
+ table,
+ asfrom=False,
+ iscrud=False,
+ ashint=False,
+ fromhints=None,
+ use_schema=True,
+ from_linter=None,
+ **kwargs
+ ):
+ if from_linter:
+ from_linter.froms[table] = table.fullname
+
+ if asfrom or ashint:
+ effective_schema = self.preparer.schema_for_object(table)
+
+ if use_schema and effective_schema:
+ ret = (
+ self.preparer.quote_schema(effective_schema)
+ + "."
+ + self.preparer.quote(table.name)
+ )
+ else:
+ ret = self.preparer.quote(table.name)
+ if fromhints and table in fromhints:
+ ret = self.format_from_hint_text(
+ ret, table, fromhints[table], iscrud
+ )
+ return ret
+ else:
+ return ""
+
+ def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+ if from_linter:
+ from_linter.edges.update(
+ itertools.product(
+ join.left._from_objects, join.right._from_objects
+ )
+ )
+
+ if join.full:
+ join_type = " FULL OUTER JOIN "
+ elif join.isouter:
+ join_type = " LEFT OUTER JOIN "
+ else:
+ join_type = " JOIN "
+ return (
+ join.left._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ + join_type
+ + join.right._compiler_dispatch(
+ self, asfrom=True, from_linter=from_linter, **kwargs
+ )
+ + " ON "
+ # TODO: likely need asfrom=True here?
+ + join.onclause._compiler_dispatch(
+ self, from_linter=from_linter, **kwargs
+ )
+ )
+
+ def _setup_crud_hints(self, stmt, table_text):
+ dialect_hints = dict(
+ [
+ (table, hint_text)
+ for (table, dialect), hint_text in stmt._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
+ if stmt.table in dialect_hints:
+ table_text = self.format_from_hint_text(
+ table_text, stmt.table, dialect_hints[stmt.table], True
+ )
+ return dialect_hints, table_text
+
+ def visit_insert(self, insert_stmt, **kw):
+
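+ # INSERT compilation: crud._get_crud_params() produces the per-column
+ # (column, expr, value) tuples, which are then assembled into
+ # "INSERT INTO t (cols) VALUES (...)", "DEFAULT VALUES", multi-row
+ # VALUES, or INSERT..SELECT, with RETURNING and WITH clauses placed
+ # according to dialect capabilities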
+ compile_state = insert_stmt._compile_state_factory(
+ insert_stmt, self, **kw
+ )
+ insert_stmt = compile_state.statement
+
+ toplevel = not self.stack
+
+ if toplevel:
+ self.isinsert = True
+ if not self.dml_compile_state:
+ self.dml_compile_state = compile_state
+ if not self.compile_state:
+ self.compile_state = compile_state
+
+ self.stack.append(
+ {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": insert_stmt,
+ }
+ )
+
+ crud_params = crud._get_crud_params(
+ self, insert_stmt, compile_state, **kw
+ )
+
+ if (
+ not crud_params
+ and not self.dialect.supports_default_values
+ and not self.dialect.supports_default_metavalue
+ and not self.dialect.supports_empty_insert
+ ):
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." % self.dialect.name
+ )
+
+ if compile_state._has_multi_parameters:
+ if not self.dialect.supports_multivalues_insert:
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support "
+ "in-place multirow inserts." % self.dialect.name
+ )
+ crud_params_single = crud_params[0]
+ else:
+ crud_params_single = crud_params
+
+ preparer = self.preparer
+ supports_default_values = self.dialect.supports_default_values
+
+ text = "INSERT "
+
+ if insert_stmt._prefixes:
+ text += self._generate_prefixes(
+ insert_stmt, insert_stmt._prefixes, **kw
+ )
+
+ text += "INTO "
+ table_text = preparer.format_table(insert_stmt.table)
+
+ if insert_stmt._hints:
+ _, table_text = self._setup_crud_hints(insert_stmt, table_text)
+
+ if insert_stmt._independent_ctes:
+ for cte in insert_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ text += table_text
+
+ if crud_params_single or not supports_default_values:
+ text += " (%s)" % ", ".join(
+ [expr for c, expr, value in crud_params_single]
+ )
+
+ if self.returning or insert_stmt._returning:
+ returning_clause = self.returning_clause(
+ insert_stmt, self.returning or insert_stmt._returning
+ )
+
+ if self.returning_precedes_values:
+ text += " " + returning_clause
+ else:
+ returning_clause = None
+
+ if insert_stmt.select is not None:
+ # placed here by crud.py
+ select_text = self.process(
+ self.stack[-1]["insert_from_select"], insert_into=True, **kw
+ )
+
+ if self.ctes and self.dialect.cte_follows_insert:
+ nesting_level = len(self.stack) if not toplevel else None
+ text += " %s%s" % (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ ),
+ select_text,
+ )
+ else:
+ text += " %s" % select_text
+ elif not crud_params and supports_default_values:
+ text += " DEFAULT VALUES"
+ elif compile_state._has_multi_parameters:
+ text += " VALUES %s" % (
+ ", ".join(
+ "(%s)"
+ % (", ".join(value for c, expr, value in crud_param_set))
+ for crud_param_set in crud_params
+ )
+ )
+ else:
+ insert_single_values_expr = ", ".join(
+ [value for c, expr, value in crud_params]
+ )
+ text += " VALUES (%s)" % insert_single_values_expr
+ if toplevel:
+ self.insert_single_values_expr = insert_single_values_expr
+
+ if insert_stmt._post_values_clause is not None:
+ post_values_clause = self.process(
+ insert_stmt._post_values_clause, **kw
+ )
+ if post_values_clause:
+ text += " " + post_values_clause
+
+ if returning_clause and not self.returning_precedes_values:
+ text += " " + returning_clause
+
+ if self.ctes and not self.dialect.cte_follows_insert:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level, include_following_stack=True
+ )
+ + text
+ )
+
+ self.stack.pop(-1)
+
+ return text
+
+ def update_limit_clause(self, update_stmt):
+ """Provide a hook for MySQL to add LIMIT to the UPDATE"""
+ return None
+
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ """Provide a hook to override the initial table clause
+ in an UPDATE statement.
+
+ MySQL overrides this.
+
+ """
+ kw["asfrom"] = True
+ return from_table._compiler_dispatch(self, iscrud=True, **kw)
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Provide a hook to override the generation of an
+ UPDATE..FROM clause.
+
+ MySQL and MSSQL override this.
+
+ """
+ raise NotImplementedError(
+ "This backend does not support multiple-table "
+ "criteria within UPDATE"
+ )
+
+ def visit_update(self, update_stmt, **kw):
+ compile_state = update_stmt._compile_state_factory(
+ update_stmt, self, **kw
+ )
+ update_stmt = compile_state.statement
+
+ toplevel = not self.stack
+ if toplevel:
+ self.isupdate = True
+ if not self.dml_compile_state:
+ self.dml_compile_state = compile_state
+ if not self.compile_state:
+ self.compile_state = compile_state
+
+ extra_froms = compile_state._extra_froms
+ is_multitable = bool(extra_froms)
+
+ if is_multitable:
+ # main table might be a JOIN
+ main_froms = set(selectable._from_objects(update_stmt.table))
+ render_extra_froms = [
+ f for f in extra_froms if f not in main_froms
+ ]
+ correlate_froms = main_froms.union(extra_froms)
+ else:
+ render_extra_froms = []
+ correlate_froms = {update_stmt.table}
+
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": update_stmt,
+ }
+ )
+
+ text = "UPDATE "
+
+ if update_stmt._prefixes:
+ text += self._generate_prefixes(
+ update_stmt, update_stmt._prefixes, **kw
+ )
+
+ table_text = self.update_tables_clause(
+ update_stmt, update_stmt.table, render_extra_froms, **kw
+ )
+ crud_params = crud._get_crud_params(
+ self, update_stmt, compile_state, **kw
+ )
+
+ if update_stmt._hints:
+ dialect_hints, table_text = self._setup_crud_hints(
+ update_stmt, table_text
+ )
+ else:
+ dialect_hints = None
+
+ if update_stmt._independent_ctes:
+ for cte in update_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ text += table_text
+
+ text += " SET "
+ text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
+
+ if self.returning or update_stmt._returning:
+ if self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ update_stmt, self.returning or update_stmt._returning
+ )
+
+ if extra_froms:
+ extra_from_text = self.update_from_clause(
+ update_stmt,
+ update_stmt.table,
+ render_extra_froms,
+ dialect_hints,
+ **kw
+ )
+ if extra_from_text:
+ text += " " + extra_from_text
+
+ if update_stmt._where_criteria:
+ t = self._generate_delimited_and_list(
+ update_stmt._where_criteria, **kw
+ )
+ if t:
+ text += " WHERE " + t
+
+ limit_clause = self.update_limit_clause(update_stmt)
+ if limit_clause:
+ text += " " + limit_clause
+
+ if (
+ self.returning or update_stmt._returning
+ ) and not self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ update_stmt, self.returning or update_stmt._returning
+ )
+
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+ self.stack.pop(-1)
+
+ return text
+
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ """Provide a hook to override the generation of an
+ DELETE..FROM clause.
+
+ This can be used to implement DELETE..USING for example.
+
+ MySQL and MSSQL override this.
+
+ """
+ raise NotImplementedError(
+ "This backend does not support multiple-table "
+ "criteria within DELETE"
+ )
+
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
+ return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
+
+ def visit_delete(self, delete_stmt, **kw):
+ compile_state = delete_stmt._compile_state_factory(
+ delete_stmt, self, **kw
+ )
+ delete_stmt = compile_state.statement
+
+ toplevel = not self.stack
+ if toplevel:
+ self.isdelete = True
+ if not self.dml_compile_state:
+ self.dml_compile_state = compile_state
+ if not self.compile_state:
+ self.compile_state = compile_state
+
+ extra_froms = compile_state._extra_froms
+
+ correlate_froms = {delete_stmt.table}.union(extra_froms)
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": delete_stmt,
+ }
+ )
+
+ text = "DELETE "
+
+ if delete_stmt._prefixes:
+ text += self._generate_prefixes(
+ delete_stmt, delete_stmt._prefixes, **kw
+ )
+
+ text += "FROM "
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
+
+ if delete_stmt._hints:
+ dialect_hints, table_text = self._setup_crud_hints(
+ delete_stmt, table_text
+ )
+ else:
+ dialect_hints = None
+
+ if delete_stmt._independent_ctes:
+ for cte in delete_stmt._independent_ctes:
+ cte._compiler_dispatch(self, **kw)
+
+ text += table_text
+
+ if delete_stmt._returning:
+ if self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ delete_stmt, delete_stmt._returning
+ )
+
+ if extra_froms:
+ extra_from_text = self.delete_extra_from_clause(
+ delete_stmt,
+ delete_stmt.table,
+ extra_froms,
+ dialect_hints,
+ **kw
+ )
+ if extra_from_text:
+ text += " " + extra_from_text
+
+ if delete_stmt._where_criteria:
+ t = self._generate_delimited_and_list(
+ delete_stmt._where_criteria, **kw
+ )
+ if t:
+ text += " WHERE " + t
+
+ if delete_stmt._returning and not self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ delete_stmt, delete_stmt._returning
+ )
+
+ if self.ctes:
+ nesting_level = len(self.stack) if not toplevel else None
+ text = self._render_cte_clause(nesting_level=nesting_level) + text
+
+ self.stack.pop(-1)
+
+ return text
+
+ def visit_savepoint(self, savepoint_stmt):
+ return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+ def visit_release_savepoint(self, savepoint_stmt):
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
+
+
+class StrSQLCompiler(SQLCompiler):
+ """A :class:`.SQLCompiler` subclass which allows a small selection
+ of non-standard SQL features to render into a string value.
+
+ The :class:`.StrSQLCompiler` is invoked whenever a Core expression
+ element is directly stringified without calling upon the
+ :meth:`_expression.ClauseElement.compile` method.
+ It can render a limited set
+ of non-standard SQL constructs to assist in basic stringification;
+ however, for more substantial custom or dialect-specific SQL constructs,
+ it will be necessary to make use of
+ :meth:`_expression.ClauseElement.compile`
+ directly.
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string`
+
+ """
+
+ def _fallback_column_name(self, column):
+ return "<name unknown>"
+
+ @util.preload_module("sqlalchemy.engine.url")
+ def visit_unsupported_compilation(self, element, err, **kw):
+ if element.stringify_dialect != "default":
+ url = util.preloaded.engine_url
+ dialect = url.URL.create(element.stringify_dialect).get_dialect()()
+
+ compiler = dialect.statement_compiler(dialect, None)
+ if not isinstance(compiler, StrSQLCompiler):
+ return compiler.process(element)
+
+ return super(StrSQLCompiler, self).visit_unsupported_compilation(
+ element, err
+ )
+
+ def visit_getitem_binary(self, binary, operator, **kw):
+ return "%s[%s]" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ )
+
+ def visit_json_getitem_op_binary(self, binary, operator, **kw):
+ return self.visit_getitem_binary(binary, operator, **kw)
+
+ def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+ return self.visit_getitem_binary(binary, operator, **kw)
+
+ def visit_sequence(self, seq, **kw):
+ return "<next sequence value: %s>" % self.preparer.format_sequence(seq)
+
+ def returning_clause(self, stmt, returning_cols):
+ columns = [
+ self._label_select_column(None, c, True, False, {})
+ for c in base._select_iterables(returning_cols)
+ ]
+
+ return "RETURNING " + ", ".join(columns)
+
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ kw["asfrom"] = True
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ kw["asfrom"] = True
+ return ", " + ", ".join(
+ t._compiler_dispatch(self, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
+
+ def visit_empty_set_expr(self, type_):
+ return "SELECT 1 WHERE 1!=1"
+
+ def get_from_hint_text(self, table, text):
+ return "[%s]" % text
+
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " <regexp> ", **kw)
+
+ def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
+ return self._generate_generic_binary(binary, " <not regexp> ", **kw)
+
+ def visit_regexp_replace_op_binary(self, binary, operator, **kw):
+ replacement = binary.modifiers["replacement"]
+ return "<regexp replace>(%s, %s, %s)" % (
+ binary.left._compiler_dispatch(self, **kw),
+ binary.right._compiler_dispatch(self, **kw),
+ replacement._compiler_dispatch(self, **kw),
+ )
+
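
A quick illustration of when this compiler comes into play (plain
stringification with no bound dialect); the table and column names here are
made up for the example:

    from sqlalchemy import table, column
    t = table("t", column("x"))
    print(str(t.update().values(x=5)))   # UPDATE t SET x=:x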
+
+class DDLCompiler(Compiled):
+ @util.memoized_property
+ def sql_compiler(self):
+ return self.dialect.statement_compiler(
+ self.dialect, None, schema_translate_map=self.schema_translate_map
+ )
+
+ @util.memoized_property
+ def type_compiler(self):
+ return self.dialect.type_compiler
+
+ def construct_params(
+ self, params=None, extracted_parameters=None, escape_names=True
+ ):
+ return None
+
+ def visit_ddl(self, ddl, **kwargs):
+ # table events can substitute table and schema name
+ context = ddl.context
+ if isinstance(ddl.target, schema.Table):
+ context = context.copy()
+
+ preparer = self.preparer
+ path = preparer.format_table_seq(ddl.target)
+ if len(path) == 1:
+ table, sch = path[0], ""
+ else:
+ table, sch = path[-1], path[0]
+
+ context.setdefault("table", table)
+ context.setdefault("schema", sch)
+ context.setdefault("fullname", preparer.format_table(ddl.target))
+
+ return self.sql_compiler.post_process_text(ddl.statement % context)
+
+ def visit_create_schema(self, create, **kw):
+ schema = self.preparer.format_schema(create.element)
+ return "CREATE SCHEMA " + schema
+
+ def visit_drop_schema(self, drop, **kw):
+ schema = self.preparer.format_schema(drop.element)
+ text = "DROP SCHEMA " + schema
+ if drop.cascade:
+ text += " CASCADE"
+ return text
+
+ def visit_create_table(self, create, **kw):
+ table = create.element
+ preparer = self.preparer
+
+ text = "\nCREATE "
+ if table._prefixes:
+ text += " ".join(table._prefixes) + " "
+
+ text += "TABLE "
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += preparer.format_table(table) + " "
+
+ create_table_suffix = self.create_table_suffix(table)
+ if create_table_suffix:
+ text += create_table_suffix + " "
+
+ text += "("
+
+ separator = "\n"
+
+ # if only one primary key, specify it along with the column
+ first_pk = False
+ for create_column in create.columns:
+ column = create_column.element
+ try:
+ processed = self.process(
+ create_column, first_pk=column.primary_key and not first_pk
+ )
+ if processed is not None:
+ text += separator
+ separator = ", \n"
+ text += "\t" + processed
+ if column.primary_key:
+ first_pk = True
+ except exc.CompileError as ce:
+ util.raise_(
+ exc.CompileError(
+ util.u("(in table '%s', column '%s'): %s")
+ % (table.description, column.name, ce.args[0])
+ ),
+ from_=ce,
+ )
+
+ const = self.create_table_constraints(
+ table,
+ _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
+ )
+ if const:
+ text += separator + "\t" + const
+
+ text += "\n)%s\n\n" % self.post_create_table(table)
+ return text
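
For instance, emitting CreateTable for a hypothetical two-column table passes
through this method and yields roughly:

    CREATE TABLE my_table (
        id INTEGER NOT NULL,
        name VARCHAR(50),
        PRIMARY KEY (id)
    )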
+
+ def visit_create_column(self, create, first_pk=False, **kw):
+ column = create.element
+
+ if column.system:
+ return None
+
+ text = self.get_column_specification(column, first_pk=first_pk)
+ const = " ".join(
+ self.process(constraint) for constraint in column.constraints
+ )
+ if const:
+ text += " " + const
+
+ return text
+
+ def create_table_constraints(
+ self, table, _include_foreign_key_constraints=None, **kw
+ ):
+
+ # On some databases the order is significant: visit PK first, then the
+ # other constraints (engine.ReflectionTest.testbasic failed on FB2)
+ constraints = []
+ if table.primary_key:
+ constraints.append(table.primary_key)
+
+ all_fkcs = table.foreign_key_constraints
+ if _include_foreign_key_constraints is not None:
+ omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
+ else:
+ omit_fkcs = set()
+
+ constraints.extend(
+ [
+ c
+ for c in table._sorted_constraints
+ if c is not table.primary_key and c not in omit_fkcs
+ ]
+ )
+
+ return ", \n\t".join(
+ p
+ for p in (
+ self.process(constraint)
+ for constraint in constraints
+ if (
+ constraint._create_rule is None
+ or constraint._create_rule(self)
+ )
+ and (
+ not self.dialect.supports_alter
+ or not getattr(constraint, "use_alter", False)
+ )
+ )
+ if p is not None
+ )
+
+ def visit_drop_table(self, drop, **kw):
+ text = "\nDROP TABLE "
+ if drop.if_exists:
+ text += "IF EXISTS "
+ return text + self.preparer.format_table(drop.element)
+
+ def visit_drop_view(self, drop, **kw):
+ return "\nDROP VIEW " + self.preparer.format_table(drop.element)
+
+ def _verify_index_table(self, index):
+ if index.table is None:
+ raise exc.CompileError(
+ "Index '%s' is not associated " "with any table." % index.name
+ )
+
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True, **kw
+ ):
+ index = create.element
+ self._verify_index_table(index)
+ preparer = self.preparer
+ text = "CREATE "
+ if index.unique:
+ text += "UNIQUE "
+ if index.name is None:
+ raise exc.CompileError(
+ "CREATE INDEX requires that the index have a name"
+ )
+
+ text += "INDEX "
+ if create.if_not_exists:
+ text += "IF NOT EXISTS "
+
+ text += "%s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(
+ index.table, use_schema=include_table_schema
+ ),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
+ return text
+
+ def visit_drop_index(self, drop, **kw):
+ index = drop.element
+
+ if index.name is None:
+ raise exc.CompileError(
+ "DROP INDEX requires that the index have a name"
+ )
+ text = "\nDROP INDEX "
+ if drop.if_exists:
+ text += "IF EXISTS "
+
+ return text + self._prepared_index_name(index, include_schema=True)
+
+ def _prepared_index_name(self, index, include_schema=False):
+ if index.table is not None:
+ effective_schema = self.preparer.schema_for_object(index.table)
+ else:
+ effective_schema = None
+ if include_schema and effective_schema:
+ schema_name = self.preparer.quote_schema(effective_schema)
+ else:
+ schema_name = None
+
+ index_name = self.preparer.format_index(index)
+
+ if schema_name:
+ index_name = schema_name + "." + index_name
+ return index_name
+
+ def visit_add_constraint(self, create, **kw):
+ return "ALTER TABLE %s ADD %s" % (
+ self.preparer.format_table(create.element.table),
+ self.process(create.element),
+ )
+
+ def visit_set_table_comment(self, create, **kw):
+ return "COMMENT ON TABLE %s IS %s" % (
+ self.preparer.format_table(create.element),
+ self.sql_compiler.render_literal_value(
+ create.element.comment, sqltypes.String()
+ ),
+ )
+
+ def visit_drop_table_comment(self, drop, **kw):
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
+
+ def visit_set_column_comment(self, create, **kw):
+ return "COMMENT ON COLUMN %s IS %s" % (
+ self.preparer.format_column(
+ create.element, use_table=True, use_schema=True
+ ),
+ self.sql_compiler.render_literal_value(
+ create.element.comment, sqltypes.String()
+ ),
+ )
+
+ def visit_drop_column_comment(self, drop, **kw):
+ return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
+ drop.element, use_table=True
+ )
+
+ def get_identity_options(self, identity_options):
+ text = []
+ if identity_options.increment is not None:
+ text.append("INCREMENT BY %d" % identity_options.increment)
+ if identity_options.start is not None:
+ text.append("START WITH %d" % identity_options.start)
+ if identity_options.minvalue is not None:
+ text.append("MINVALUE %d" % identity_options.minvalue)
+ if identity_options.maxvalue is not None:
+ text.append("MAXVALUE %d" % identity_options.maxvalue)
+ if identity_options.nominvalue is not None:
+ text.append("NO MINVALUE")
+ if identity_options.nomaxvalue is not None:
+ text.append("NO MAXVALUE")
+ if identity_options.cache is not None:
+ text.append("CACHE %d" % identity_options.cache)
+ if identity_options.order is not None:
+ text.append("ORDER" if identity_options.order else "NO ORDER")
+ if identity_options.cycle is not None:
+ text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
+ return " ".join(text)
+
+ def visit_create_sequence(self, create, prefix=None, **kw):
+ text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
+ if prefix:
+ text += prefix
+ if create.element.start is None:
+ create.element.start = self.dialect.default_sequence_base
+ options = self.get_identity_options(create.element)
+ if options:
+ text += " " + options
+ return text
+
+ def visit_drop_sequence(self, drop, **kw):
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
+
+ def visit_drop_constraint(self, drop, **kw):
+ constraint = drop.element
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ else:
+ formatted_name = None
+
+ if formatted_name is None:
+ raise exc.CompileError(
+ "Can't emit DROP CONSTRAINT for constraint %r; "
+ "it has no name" % drop.element
+ )
+ return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
+ self.preparer.format_table(drop.element.table),
+ formatted_name,
+ drop.cascade and " CASCADE" or "",
+ )
+
+ def get_column_specification(self, column, **kwargs):
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ if column.computed is not None:
+ colspec += " " + self.process(column.computed)
+
+ if (
+ column.identity is not None
+ and self.dialect.supports_identity_columns
+ ):
+ colspec += " " + self.process(column.identity)
+
+ if not column.nullable and (
+ not column.identity or not self.dialect.supports_identity_columns
+ ):
+ colspec += " NOT NULL"
+ return colspec
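
As a rough example (hypothetical column), Column("name", String(50),
nullable=False, server_default="unknown") comes out of this method as:

    name VARCHAR(50) DEFAULT 'unknown' NOT NULL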
+
+ def create_table_suffix(self, table):
+ return ""
+
+ def post_create_table(self, table):
+ return ""
+
+ def get_column_default_string(self, column):
+ if isinstance(column.server_default, schema.DefaultClause):
+ if isinstance(column.server_default.arg, util.string_types):
+ return self.sql_compiler.render_literal_value(
+ column.server_default.arg, sqltypes.STRINGTYPE
+ )
+ else:
+ return self.sql_compiler.process(
+ column.server_default.arg, literal_binds=True
+ )
+ else:
+ return None
+
+ def visit_table_or_column_check_constraint(self, constraint, **kw):
+ if constraint.is_column_level:
+ return self.visit_column_check_constraint(constraint)
+ else:
+ return self.visit_check_constraint(constraint)
+
+ def visit_check_constraint(self, constraint, **kw):
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_column_check_constraint(self, constraint, **kw):
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_primary_key_constraint(self, constraint, **kw):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "PRIMARY KEY "
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name)
+ for c in (
+ constraint.columns_autoinc_first
+ if constraint._implicit_generated
+ else constraint.columns
+ )
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def visit_foreign_key_constraint(self, constraint, **kw):
+ preparer = self.preparer
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ remote_table = list(constraint.elements)[0].column.table
+ text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+ ", ".join(
+ preparer.quote(f.parent.name) for f in constraint.elements
+ ),
+ self.define_constraint_remote_table(
+ constraint, remote_table, preparer
+ ),
+ ", ".join(
+ preparer.quote(f.column.name) for f in constraint.elements
+ ),
+ )
+ text += self.define_constraint_match(constraint)
+ text += self.define_constraint_cascades(constraint)
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def define_constraint_remote_table(self, constraint, table, preparer):
+ """Format the remote table clause of a CREATE CONSTRAINT clause."""
+
+ return preparer.format_table(table)
+
+ def visit_unique_constraint(self, constraint, **kw):
+ if len(constraint) == 0:
+ return ""
+ text = ""
+ if constraint.name is not None:
+ formatted_name = self.preparer.format_constraint(constraint)
+ if formatted_name is not None:
+ text += "CONSTRAINT %s " % formatted_name
+ text += "UNIQUE (%s)" % (
+ ", ".join(self.preparer.quote(c.name) for c in constraint)
+ )
+ text += self.define_constraint_deferrability(constraint)
+ return text
+
+ def define_constraint_cascades(self, constraint):
+ text = ""
+ if constraint.ondelete is not None:
+ text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
+ constraint.ondelete, FK_ON_DELETE
+ )
+ if constraint.onupdate is not None:
+ text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
+ constraint.onupdate, FK_ON_UPDATE
+ )
+ return text
+
+ def define_constraint_deferrability(self, constraint):
+ text = ""
+ if constraint.deferrable is not None:
+ if constraint.deferrable:
+ text += " DEFERRABLE"
+ else:
+ text += " NOT DEFERRABLE"
+ if constraint.initially is not None:
+ text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
+ constraint.initially, FK_INITIALLY
+ )
+ return text
+
+ def define_constraint_match(self, constraint):
+ text = ""
+ if constraint.match is not None:
+ text += " MATCH %s" % constraint.match
+ return text
+
+ def visit_computed_column(self, generated, **kw):
+ text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
+ generated.sqltext, include_table=False, literal_binds=True
+ )
+ if generated.persisted is True:
+ text += " STORED"
+ elif generated.persisted is False:
+ text += " VIRTUAL"
+ return text
+
+ def visit_identity_column(self, identity, **kw):
+ text = "GENERATED %s AS IDENTITY" % (
+ "ALWAYS" if identity.always else "BY DEFAULT",
+ )
+ options = self.get_identity_options(identity)
+ if options:
+ text += " (%s)" % options
+ return text
+
+
+class GenericTypeCompiler(TypeCompiler):
+ def visit_FLOAT(self, type_, **kw):
+ return "FLOAT"
+
+ def visit_REAL(self, type_, **kw):
+ return "REAL"
+
+ def visit_NUMERIC(self, type_, **kw):
+ if type_.precision is None:
+ return "NUMERIC"
+ elif type_.scale is None:
+ return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
+ else:
+ return "NUMERIC(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
+
+ def visit_DECIMAL(self, type_, **kw):
+ if type_.precision is None:
+ return "DECIMAL"
+ elif type_.scale is None:
+ return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
+ else:
+ return "DECIMAL(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
+
+ def visit_INTEGER(self, type_, **kw):
+ return "INTEGER"
+
+ def visit_SMALLINT(self, type_, **kw):
+ return "SMALLINT"
+
+ def visit_BIGINT(self, type_, **kw):
+ return "BIGINT"
+
+ def visit_TIMESTAMP(self, type_, **kw):
+ return "TIMESTAMP"
+
+ def visit_DATETIME(self, type_, **kw):
+ return "DATETIME"
+
+ def visit_DATE(self, type_, **kw):
+ return "DATE"
+
+ def visit_TIME(self, type_, **kw):
+ return "TIME"
+
+ def visit_CLOB(self, type_, **kw):
+ return "CLOB"
+
+ def visit_NCLOB(self, type_, **kw):
+ return "NCLOB"
+
+ def _render_string_type(self, type_, name):
+
+ text = name
+ if type_.length:
+ text += "(%d)" % type_.length
+ if type_.collation:
+ text += ' COLLATE "%s"' % type_.collation
+ return text
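
For example, String(30, collation="en_US") renders through this helper as
VARCHAR(30) COLLATE "en_US", while a plain String() with no length renders as
just VARCHAR.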
+
+ def visit_CHAR(self, type_, **kw):
+ return self._render_string_type(type_, "CHAR")
+
+ def visit_NCHAR(self, type_, **kw):
+ return self._render_string_type(type_, "NCHAR")
+
+ def visit_VARCHAR(self, type_, **kw):
+ return self._render_string_type(type_, "VARCHAR")
+
+ def visit_NVARCHAR(self, type_, **kw):
+ return self._render_string_type(type_, "NVARCHAR")
+
+ def visit_TEXT(self, type_, **kw):
+ return self._render_string_type(type_, "TEXT")
+
+ def visit_BLOB(self, type_, **kw):
+ return "BLOB"
+
+ def visit_BINARY(self, type_, **kw):
+ return "BINARY" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_VARBINARY(self, type_, **kw):
+ return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
+
+ def visit_BOOLEAN(self, type_, **kw):
+ return "BOOLEAN"
+
+ def visit_large_binary(self, type_, **kw):
+ return self.visit_BLOB(type_, **kw)
+
+ def visit_boolean(self, type_, **kw):
+ return self.visit_BOOLEAN(type_, **kw)
+
+ def visit_time(self, type_, **kw):
+ return self.visit_TIME(type_, **kw)
+
+ def visit_datetime(self, type_, **kw):
+ return self.visit_DATETIME(type_, **kw)
+
+ def visit_date(self, type_, **kw):
+ return self.visit_DATE(type_, **kw)
+
+ def visit_big_integer(self, type_, **kw):
+ return self.visit_BIGINT(type_, **kw)
+
+ def visit_small_integer(self, type_, **kw):
+ return self.visit_SMALLINT(type_, **kw)
+
+ def visit_integer(self, type_, **kw):
+ return self.visit_INTEGER(type_, **kw)
+
+ def visit_real(self, type_, **kw):
+ return self.visit_REAL(type_, **kw)
+
+ def visit_float(self, type_, **kw):
+ return self.visit_FLOAT(type_, **kw)
+
+ def visit_numeric(self, type_, **kw):
+ return self.visit_NUMERIC(type_, **kw)
+
+ def visit_string(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
+
+ def visit_unicode(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
+
+ def visit_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
+
+ def visit_unicode_text(self, type_, **kw):
+ return self.visit_TEXT(type_, **kw)
+
+ def visit_enum(self, type_, **kw):
+ return self.visit_VARCHAR(type_, **kw)
+
+ def visit_null(self, type_, **kw):
+ raise exc.CompileError(
+ "Can't generate DDL for %r; "
+ "did you forget to specify a "
+ "type on this Column?" % type_
+ )
+
+ def visit_type_decorator(self, type_, **kw):
+ return self.process(type_.type_engine(self.dialect), **kw)
+
+ def visit_user_defined(self, type_, **kw):
+ return type_.get_col_spec(**kw)
+
+
+class StrSQLTypeCompiler(GenericTypeCompiler):
+ def process(self, type_, **kw):
+ try:
+ _compiler_dispatch = type_._compiler_dispatch
+ except AttributeError:
+ return self._visit_unknown(type_, **kw)
+ else:
+ return _compiler_dispatch(self, **kw)
+
+ def __getattr__(self, key):
+ if key.startswith("visit_"):
+ return self._visit_unknown
+ else:
+ raise AttributeError(key)
+
+ def _visit_unknown(self, type_, **kw):
+ if type_.__class__.__name__ == type_.__class__.__name__.upper():
+ return type_.__class__.__name__
+ else:
+ return repr(type_)
+
+ def visit_null(self, type_, **kw):
+ return "NULL"
+
+ def visit_user_defined(self, type_, **kw):
+ try:
+ get_col_spec = type_.get_col_spec
+ except AttributeError:
+ return repr(type_)
+ else:
+ return get_col_spec(**kw)
+
+
+class IdentifierPreparer(object):
+
+ """Handle quoting and case-folding of identifiers based on options."""
+
+ reserved_words = RESERVED_WORDS
+
+ legal_characters = LEGAL_CHARACTERS
+
+ illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
+
+ schema_for_object = operator.attrgetter("schema")
+ """Return the .schema attribute for an object.
+
+ For the default IdentifierPreparer, the schema for an object is always
+ the value of the ".schema" attribute. if the preparer is replaced
+ with one that has a non-empty schema_translate_map, the value of the
+ ".schema" attribute is rendered a symbol that will be converted to a
+ real schema name from the mapping post-compile.
+
+ """
+
+ def __init__(
+ self,
+ dialect,
+ initial_quote='"',
+ final_quote=None,
+ escape_quote='"',
+ quote_case_sensitive_collations=True,
+ omit_schema=False,
+ ):
+ """Construct a new ``IdentifierPreparer`` object.
+
+ initial_quote
+ Character that begins a delimited identifier.
+
+ final_quote
+ Character that ends a delimited identifier. Defaults to
+ `initial_quote`.
+
+ omit_schema
+ Prevent prepending schema name. Useful for databases that do
+ not support schemas.
+ """
+
+ self.dialect = dialect
+ self.initial_quote = initial_quote
+ self.final_quote = final_quote or self.initial_quote
+ self.escape_quote = escape_quote
+ self.escape_to_quote = self.escape_quote * 2
+ self.omit_schema = omit_schema
+ self.quote_case_sensitive_collations = quote_case_sensitive_collations
+ self._strings = {}
+ self._double_percents = self.dialect.paramstyle in (
+ "format",
+ "pyformat",
+ )
+
+ def _with_schema_translate(self, schema_translate_map):
+ prep = self.__class__.__new__(self.__class__)
+ prep.__dict__.update(self.__dict__)
+
+ def symbol_getter(obj):
+ name = obj.schema
+ if name in schema_translate_map and obj._use_schema_map:
+ if name is not None and ("[" in name or "]" in name):
+ raise exc.CompileError(
+ "Square bracket characters ([]) not supported "
+ "in schema translate name '%s'" % name
+ )
+ return quoted_name(
+ "__[SCHEMA_%s]" % (name or "_none"), quote=False
+ )
+ else:
+ return obj.schema
+
+ prep.schema_for_object = symbol_getter
+ return prep
+
+ def _render_schema_translates(self, statement, schema_translate_map):
+ d = schema_translate_map
+ if None in d:
+ d["_none"] = d[None]
+
+ def replace(m):
+ name = m.group(2)
+ effective_schema = d[name]
+ if not effective_schema:
+ effective_schema = self.dialect.default_schema_name
+ if not effective_schema:
+ # TODO: no coverage here
+ raise exc.CompileError(
+ "Dialect has no default schema name; can't "
+ "use None as dynamic schema target."
+ )
+ return self.quote_schema(effective_schema)
+
+ return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
+
+ def _escape_identifier(self, value):
+ """Escape an identifier.
+
+ Subclasses should override this to provide database-dependent
+ escaping behavior.
+ """
+
+ value = value.replace(self.escape_quote, self.escape_to_quote)
+ if self._double_percents:
+ value = value.replace("%", "%%")
+ return value
+
+ def _unescape_identifier(self, value):
+ """Canonicalize an escaped identifier.
+
+ Subclasses should override this to provide database-dependent
+ unescaping behavior that reverses _escape_identifier.
+ """
+
+ return value.replace(self.escape_to_quote, self.escape_quote)
+
+ def validate_sql_phrase(self, element, reg):
+ """keyword sequence filter.
+
+ a filter for elements that are intended to represent keyword sequences,
+ such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters
+ should be present.
+
+ .. versionadded:: 1.3
+
+ """
+
+ if element is not None and not reg.match(element):
+ raise exc.CompileError(
+ "Unexpected SQL phrase: %r (matching against %r)"
+ % (element, reg.pattern)
+ )
+ return element
+
+ def quote_identifier(self, value):
+ """Quote an identifier.
+
+ Subclasses should override this to provide database-dependent
+ quoting behavior.
+ """
+
+ return (
+ self.initial_quote
+ + self._escape_identifier(value)
+ + self.final_quote
+ )
+
+ def _requires_quotes(self, value):
+ """Return True if the given identifier requires quoting."""
+ lc_value = value.lower()
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ or (lc_value != value)
+ )
+
+ def _requires_quotes_illegal_chars(self, value):
+ """Return True if the given identifier requires quoting, but
+ not taking case convention into account."""
+ return not self.legal_characters.match(util.text_type(value))
+
+ def quote_schema(self, schema, force=None):
+ """Conditionally quote a schema name.
+
+ The name is quoted if it is a reserved word, contains quote-necessary
+ characters, or is an instance of :class:`.quoted_name` which includes
+ ``quote`` set to ``True``.
+
+ Subclasses can override this to provide database-dependent
+ quoting behavior for schema names.
+
+ :param schema: string schema name
+ :param force: unused
+
+ .. deprecated:: 0.9
+
+ The :paramref:`.IdentifierPreparer.quote_schema.force`
+ parameter is deprecated and will be removed in a future
+ release. This flag has no effect on the behavior of the
+ :meth:`.IdentifierPreparer.quote` method; please refer to
+ :class:`.quoted_name`.
+
+ """
+ if force is not None:
+ # not using the util.deprecated_params() decorator in this
+ # case because of the additional function call overhead on this
+ # very performance-critical spot.
+ util.warn_deprecated(
+ "The IdentifierPreparer.quote_schema.force parameter is "
+ "deprecated and will be removed in a future release. This "
+ "flag has no effect on the behavior of the "
+ "IdentifierPreparer.quote method; please refer to "
+ "quoted_name().",
+ # deprecated 0.9. warning from 1.3
+ version="0.9",
+ )
+
+ return self.quote(schema)
+
+ def quote(self, ident, force=None):
+ """Conditionally quote an identifier.
+
+ The identifier is quoted if it is a reserved word, contains
+ quote-necessary characters, or is an instance of
+ :class:`.quoted_name` which includes ``quote`` set to ``True``.
+
+ Subclasses can override this to provide database-dependent
+ quoting behavior for identifier names.
+
+ :param ident: string identifier
+ :param force: unused
+
+ .. deprecated:: 0.9
+
+ The :paramref:`.IdentifierPreparer.quote.force`
+ parameter is deprecated and will be removed in a future
+ release. This flag has no effect on the behavior of the
+ :meth:`.IdentifierPreparer.quote` method; please refer to
+ :class:`.quoted_name`.
+
+ """
+ if force is not None:
+ # not using the util.deprecated_params() decorator in this
+ # case because of the additional function call overhead on this
+ # very performance-critical spot.
+ util.warn_deprecated(
+ "The IdentifierPreparer.quote.force parameter is "
+ "deprecated and will be removed in a future release. This "
+ "flag has no effect on the behavior of the "
+ "IdentifierPreparer.quote method; please refer to "
+ "quoted_name().",
+ # deprecated 0.9. warning from 1.3
+ version="0.9",
+ )
+
+ force = getattr(ident, "quote", None)
+
+ if force is None:
+ if ident in self._strings:
+ return self._strings[ident]
+ else:
+ if self._requires_quotes(ident):
+ self._strings[ident] = self.quote_identifier(ident)
+ else:
+ self._strings[ident] = ident
+ return self._strings[ident]
+ elif force:
+ return self.quote_identifier(ident)
+ else:
+ return ident
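
With the default rules above, for example:

    preparer.quote("my_table")   # -> my_table     (no quoting needed)
    preparer.quote("select")     # -> "select"     (reserved word)
    preparer.quote("MyTable")    # -> "MyTable"    (case-sensitive name)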
+
+ def format_collation(self, collation_name):
+ if self.quote_case_sensitive_collations:
+ return self.quote(collation_name)
+ else:
+ return collation_name
+
+ def format_sequence(self, sequence, use_schema=True):
+ name = self.quote(sequence.name)
+
+ effective_schema = self.schema_for_object(sequence)
+
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
+ name = self.quote_schema(effective_schema) + "." + name
+ return name
+
+ def format_label(self, label, name=None):
+ return self.quote(name or label.name)
+
+ def format_alias(self, alias, name=None):
+ return self.quote(name or alias.name)
+
+ def format_savepoint(self, savepoint, name=None):
+ # Running the savepoint name through quoting is unnecessary
+ # for all known dialects. This is here to support potential
+ # third party use cases
+ ident = name or savepoint.ident
+ if self._requires_quotes(ident):
+ ident = self.quote_identifier(ident)
+ return ident
+
+ @util.preload_module("sqlalchemy.sql.naming")
+ def format_constraint(self, constraint, _alembic_quote=True):
+ naming = util.preloaded.sql_naming
+
+ if constraint.name is elements._NONE_NAME:
+ name = naming._constraint_name_for_table(
+ constraint, constraint.table
+ )
+
+ if name is None:
+ return None
+ else:
+ name = constraint.name
+
+ if constraint.__visit_name__ == "index":
+ return self.truncate_and_render_index_name(
+ name, _alembic_quote=_alembic_quote
+ )
+ else:
+ return self.truncate_and_render_constraint_name(
+ name, _alembic_quote=_alembic_quote
+ )
+
+ def truncate_and_render_index_name(self, name, _alembic_quote=True):
+ # calculate these at format time so that ad-hoc changes
+ # to dialect.max_identifier_length etc. can be reflected
+ # as IdentifierPreparer is long lived
+ max_ = (
+ self.dialect.max_index_name_length
+ or self.dialect.max_identifier_length
+ )
+ return self._truncate_and_render_maxlen_name(
+ name, max_, _alembic_quote
+ )
+
+ def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
+ # calculate these at format time so that ad-hoc changes
+ # to dialect.max_identifier_length etc. can be reflected
+ # as IdentifierPreparer is long lived
+ max_ = (
+ self.dialect.max_constraint_name_length
+ or self.dialect.max_identifier_length
+ )
+ return self._truncate_and_render_maxlen_name(
+ name, max_, _alembic_quote
+ )
+
+ def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
+ if isinstance(name, elements._truncated_label):
+ if len(name) > max_:
+ name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
+ else:
+ self.dialect.validate_identifier(name)
+
+ if not _alembic_quote:
+ return name
+ else:
+ return self.quote(name)
+
+ def format_index(self, index):
+ return self.format_constraint(index)
+
+ def format_table(self, table, use_schema=True, name=None):
+ """Prepare a quoted table and schema name."""
+
+ if name is None:
+ name = table.name
+
+ result = self.quote(name)
+
+ effective_schema = self.schema_for_object(table)
+
+ if not self.omit_schema and use_schema and effective_schema:
+ result = self.quote_schema(effective_schema) + "." + result
+ return result
+
+ def format_schema(self, name):
+ """Prepare a quoted schema name."""
+
+ return self.quote(name)
+
+ def format_label_name(
+ self,
+ name,
+ anon_map=None,
+ ):
+ """Prepare a quoted column name."""
+
+ if anon_map is not None and isinstance(
+ name, elements._truncated_label
+ ):
+ name = name.apply_map(anon_map)
+
+ return self.quote(name)
+
+ def format_column(
+ self,
+ column,
+ use_table=False,
+ name=None,
+ table_name=None,
+ use_schema=False,
+ anon_map=None,
+ ):
+ """Prepare a quoted column name."""
+
+ if name is None:
+ name = column.name
+
+ if anon_map is not None and isinstance(
+ name, elements._truncated_label
+ ):
+ name = name.apply_map(anon_map)
+
+ if not getattr(column, "is_literal", False):
+ if use_table:
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + self.quote(name)
+ )
+ else:
+ return self.quote(name)
+ else:
+ # literal textual elements get stuck into ColumnClause a lot,
+ # which shouldn't get quoted
+
+ if use_table:
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + name
+ )
+ else:
+ return name
+
+ def format_table_seq(self, table, use_schema=True):
+ """Format table name and schema as a tuple."""
+
+ # Dialects with more levels in their fully qualified references
+ # ('database', 'owner', etc.) could override this and return
+ # a longer sequence.
+
+ effective_schema = self.schema_for_object(table)
+
+ if not self.omit_schema and use_schema and effective_schema:
+ return (
+ self.quote_schema(effective_schema),
+ self.format_table(table, use_schema=False),
+ )
+ else:
+ return (self.format_table(table, use_schema=False),)
+
+ @util.memoized_property
+ def _r_identifiers(self):
+ initial, final, escaped_final = [
+ re.escape(s)
+ for s in (
+ self.initial_quote,
+ self.final_quote,
+ self._escape_identifier(self.final_quote),
+ )
+ ]
+ r = re.compile(
+ r"(?:"
+ r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
+ r"|([^\.]+))(?=\.|$))+"
+ % {"initial": initial, "final": final, "escaped": escaped_final}
+ )
+ return r
+
+ def unformat_identifiers(self, identifiers):
+ """Unpack 'schema.table.column'-like strings into components."""
+
+ r = self._r_identifiers
+ return [
+ self._unescape_identifier(i)
+ for i in [a or b for a, b in r.findall(identifiers)]
+ ]
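
For example, '"my schema".my_table.col' unpacks via this method into
['my schema', 'my_table', 'col'].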
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
new file mode 100644
index 0000000..920c8b3
--- /dev/null
+++ b/lib/sqlalchemy/sql/crud.py
@@ -0,0 +1,1091 @@
+# sql/crud.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Functions used by compiler.py to determine the parameters rendered
+within INSERT and UPDATE statements.
+
+"""
+import functools
+import operator
+
+from . import coercions
+from . import dml
+from . import elements
+from . import roles
+from .selectable import Select
+from .. import exc
+from .. import util
+
+REQUIRED = util.symbol(
+ "REQUIRED",
+ """
+Placeholder for the value within a :class:`.BindParameter`
+which is required to be present when the statement is passed
+to :meth:`_engine.Connection.execute`.
+
+This symbol is typically used when a :func:`_expression.insert`
+or :func:`_expression.update` statement is compiled without parameter
+values present.
+
+""",
+)
+
+
+def _get_crud_params(compiler, stmt, compile_state, **kw):
+ """create a set of tuples representing column/string pairs for use
+ in an INSERT or UPDATE statement.
+
+ Also generates the Compiled object's postfetch, prefetch, and
+ returning column collections, used for default handling and ultimately
+ populating the CursorResult's prefetch_cols() and postfetch_cols()
+ collections.
+
+ """
+
+ compiler.postfetch = []
+ compiler.insert_prefetch = []
+ compiler.update_prefetch = []
+ compiler.returning = []
+
+ # getters - these are normally just column.key,
+ # but in the case of mysql multi-table update, the rules for
+ # .key must conditionally take tablename into account
+ (
+ _column_as_key,
+ _getattr_col_key,
+ _col_bind_name,
+ ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+
+ compiler._key_getters_for_crud_column = getters
+
+ # no parameters in the statement, no parameters in the
+ # compiled params - return binds for all columns
+ if compiler.column_keys is None and compile_state._no_parameters:
+ return [
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_bind_param(compiler, c, None, required=True),
+ )
+ for c in stmt.table.columns
+ ]
+
+ if compile_state._has_multi_parameters:
+ spd = compile_state._multi_parameters[0]
+ stmt_parameter_tuples = list(spd.items())
+ elif compile_state._ordered_values:
+ spd = compile_state._dict_parameters
+ stmt_parameter_tuples = compile_state._ordered_values
+ elif compile_state._dict_parameters:
+ spd = compile_state._dict_parameters
+ stmt_parameter_tuples = list(spd.items())
+ else:
+ stmt_parameter_tuples = spd = None
+
+ # if we have statement parameters - set defaults in the
+ # compiled params
+ if compiler.column_keys is None:
+ parameters = {}
+ elif stmt_parameter_tuples:
+ parameters = dict(
+ (_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if key not in spd
+ )
+ else:
+ parameters = dict(
+ (_column_as_key(key), REQUIRED) for key in compiler.column_keys
+ )
+
+ # create a list of column assignment clauses as tuples
+ values = []
+
+ if stmt_parameter_tuples is not None:
+ _get_stmt_parameter_tuples_params(
+ compiler,
+ compile_state,
+ parameters,
+ stmt_parameter_tuples,
+ _column_as_key,
+ values,
+ kw,
+ )
+
+ check_columns = {}
+
+ # special logic that only occurs for multi-table UPDATE
+ # statements
+ if compile_state.isupdate and compile_state.is_multitable:
+ _get_update_multitable_params(
+ compiler,
+ stmt,
+ compile_state,
+ stmt_parameter_tuples,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+ )
+
+ if compile_state.isinsert and stmt._select_names:
+ _scan_insert_from_select_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
+ else:
+ _scan_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
+
+ if parameters and stmt_parameter_tuples:
+ check = (
+ set(parameters)
+ .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples)
+ .difference(check_columns)
+ )
+ if check:
+ raise exc.CompileError(
+ "Unconsumed column names: %s"
+ % (", ".join("%s" % (c,) for c in check))
+ )
+
+ if compile_state._has_multi_parameters:
+ values = _extend_values_for_multiparams(
+ compiler,
+ stmt,
+ compile_state,
+ values,
+ _column_as_key,
+ kw,
+ )
+ elif (
+ not values
+ and compiler.for_executemany
+ and compiler.dialect.supports_default_metavalue
+ ):
+ # convert an "INSERT DEFAULT VALUES"
+ # into INSERT (firstcol) VALUES (DEFAULT) which can be turned
+ # into an in-place multi values. This supports
+ # insert_executemany_returning mode :)
+ values = [
+ (
+ stmt.table.columns[0],
+ compiler.preparer.format_column(stmt.table.columns[0]),
+ "DEFAULT",
+ )
+ ]
+
+ return values
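
As a rough illustration (hypothetical table): compiling my_table.insert()
with no values takes the "no parameters" path above and produces a REQUIRED
bind for every column, e.g.

    INSERT INTO my_table (id, name) VALUES (:id, :name)

so the actual values must be supplied when the statement is executed.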
+
+
+def _create_bind_param(
+ compiler, col, value, process=True, required=False, name=None, **kw
+):
+ if name is None:
+ name = col.key
+ bindparam = elements.BindParameter(
+ name, value, type_=col.type, required=required
+ )
+ bindparam._is_crud = True
+ if process:
+ bindparam = bindparam._compiler_dispatch(compiler, **kw)
+ return bindparam
+
+
+def _handle_values_anonymous_param(compiler, col, value, name, **kw):
+ # the insert() and update() constructs as of 1.4 will now produce anonymous
+ # bindparam() objects in the values() collections up front when given plain
+ # literal values. This is so that cache key behaviors, which need to
+ # produce bound parameters in deterministic order without invoking any
+ # compilation here, can be applied to these constructs when they include
+ # values() (but not yet multi-values, which are not included in caching
+ # right now).
+ #
+ # in order to produce the desired "crud" style name for these parameters,
+ # which will also be targetable in engine/default.py through the usual
+ # conventions, apply our desired name to these unique parameters by
+ # populating the compiler truncated names cache with the desired name,
+ # rather than having
+ # compiler.visit_bindparam()->compiler._truncated_identifier make up a
+ # name. Saves on call counts also.
+
+ # for INSERT/UPDATE that's a CTE, we don't need names to match to
+ # external parameters and these would also conflict in the case where
+ # multiple insert/update are combined together using CTEs
+ is_cte = "visiting_cte" in kw
+
+ if (
+ not is_cte
+ and value.unique
+ and isinstance(value.key, elements._truncated_label)
+ ):
+ compiler.truncated_names[("bindparam", value.key)] = name
+
+ if value.type._isnull:
+ # either unique parameter, or other bound parameters that were
+ # passed in directly
+ # set type to that of the column unconditionally
+ value = value._with_binary_element_type(col.type)
+
+ return value._compiler_dispatch(compiler, **kw)
+
+
+def _key_getters_for_crud_column(compiler, stmt, compile_state):
+ if compile_state.isupdate and compile_state._extra_froms:
+ # when extra tables are present, refer to the columns
+ # in those extra tables as table-qualified, including in
+ # dictionaries and when rendering bind param names.
+ # the "main" table of the statement remains unqualified,
+ # allowing the most compatibility with a non-multi-table
+ # statement.
+ _et = set(compile_state._extra_froms)
+
+ c_key_role = functools.partial(
+ coercions.expect_as_key, roles.DMLColumnRole
+ )
+
+ def _column_as_key(key):
+ str_key = c_key_role(key)
+ if hasattr(key, "table") and key.table in _et:
+ return (key.table.name, str_key)
+ else:
+ return str_key
+
+ def _getattr_col_key(col):
+ if col.table in _et:
+ return (col.table.name, col.key)
+ else:
+ return col.key
+
+ def _col_bind_name(col):
+ if col.table in _et:
+ return "%s_%s" % (col.table.name, col.key)
+ else:
+ return col.key
+
+ else:
+ _column_as_key = functools.partial(
+ coercions.expect_as_key, roles.DMLColumnRole
+ )
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+ return _column_as_key, _getattr_col_key, _col_bind_name
+
+
+def _scan_insert_from_select_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
+
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
+
+ assert compiler.stack[-1]["selectable"] is stmt
+
+ compiler.stack[-1]["insert_from_select"] = stmt.select
+
+ add_select_cols = []
+ if stmt.include_insert_from_select_defaults:
+ col_set = set(cols)
+ for col in stmt.table.columns:
+ if col not in col_set and col.default:
+ cols.append(col)
+
+ for c in cols:
+ col_key = _getattr_col_key(c)
+ if col_key in parameters and col_key not in check_columns:
+ parameters.pop(col_key)
+ values.append((c, compiler.preparer.format_column(c), None))
+ else:
+ _append_param_insert_select_hasdefault(
+ compiler, stmt, c, add_select_cols, kw
+ )
+
+ if add_select_cols:
+ values.extend(add_select_cols)
+ ins_from_select = compiler.stack[-1]["insert_from_select"]
+ if not isinstance(ins_from_select, Select):
+ raise exc.CompileError(
+ "Can't extend statement for INSERT..FROM SELECT to include "
+ "additional default-holding column(s) "
+ "%s. Convert the selectable to a subquery() first, or pass "
+ "include_defaults=False to Insert.from_select() to skip these "
+ "columns."
+ % (", ".join(repr(key) for _, key, _ in add_select_cols),)
+ )
+ ins_from_select = ins_from_select._generate()
+ # copy raw_columns
+ ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [
+ expr for col, col_expr, expr in add_select_cols
+ ]
+ compiler.stack[-1]["insert_from_select"] = ins_from_select
+
+
+def _scan_cols(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+ (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ ) = _get_returning_modifiers(compiler, stmt, compile_state)
+
+ if compile_state._parameter_ordering:
+ parameter_ordering = [
+ _column_as_key(key) for key in compile_state._parameter_ordering
+ ]
+ ordered_keys = set(parameter_ordering)
+ cols = [
+ stmt.table.c[key]
+ for key in parameter_ordering
+ if isinstance(key, util.string_types) and key in stmt.table.c
+ ] + [c for c in stmt.table.c if c.key not in ordered_keys]
+
+ else:
+ cols = stmt.table.columns
+
+ for c in cols:
+ # scan through every column in the target table
+
+ col_key = _getattr_col_key(c)
+
+ if col_key in parameters and col_key not in check_columns:
+ # parameter is present for the column. use that.
+
+ _append_param_parameter(
+ compiler,
+ stmt,
+ compile_state,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
+
+ elif compile_state.isinsert:
+ # no parameter is present and it's an insert.
+
+ if c.primary_key and need_pks:
+ # it's a primary key column, it will need to be generated by a
+ # default generator of some kind, and the statement expects
+ # inserted_primary_key to be available.
+
+ if implicit_returning:
+ # we can use RETURNING, find out how to invoke this
+ # column and get the value where RETURNING is an option.
+ # we can inline server-side functions in this case.
+
+ _append_param_insert_pk_returning(
+ compiler, stmt, c, values, kw
+ )
+ else:
+ # otherwise, find out how to invoke this column
+ # and get its value where RETURNING is not an option.
+ # if we have to invoke a server-side function, we need
+ # to pre-execute it. or if this is a straight
+ # autoincrement column and the dialect supports it
+ # we can use cursor.lastrowid.
+
+ _append_param_insert_pk_no_returning(
+ compiler, stmt, c, values, kw
+ )
+
+ elif c.default is not None:
+ # column has a default, but it's not a pk column, or it is but
+ # we don't need to get the pk back.
+ _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
+
+ elif c.server_default is not None:
+ # column has a DDL-level default, and is either not a pk
+ # column or we don't need the pk.
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif (
+ c.primary_key
+ and c is not stmt.table._autoincrement_column
+ and not c.nullable
+ ):
+ _warn_pk_with_no_anticipated_value(c)
+
+ elif compile_state.isupdate:
+ # no parameter is present and it's an update.
+
+ _append_param_update(
+ compiler,
+ compile_state,
+ stmt,
+ c,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
+
+
+def _append_param_parameter(
+ compiler,
+ stmt,
+ compile_state,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+):
+ value = parameters.pop(col_key)
+
+ col_value = compiler.preparer.format_column(
+ c, use_table=compile_state.include_table_with_column_exprs
+ )
+
+ if coercions._is_literal(value):
+ value = _create_bind_param(
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c)
+ if not compile_state._has_multi_parameters
+ else "%s_m0" % _col_bind_name(c),
+ **kw
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler,
+ c,
+ value,
+ name=_col_bind_name(c)
+ if not compile_state._has_multi_parameters
+ else "%s_m0" % _col_bind_name(c),
+ **kw
+ )
+ else:
+ # value is a SQL expression
+ value = compiler.process(value.self_group(), **kw)
+
+ if compile_state.isupdate:
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ else:
+ compiler.postfetch.append(c)
+ else:
+ if c.primary_key:
+
+ if implicit_returning:
+ compiler.returning.append(c)
+ elif compiler.dialect.postfetch_lastrowid:
+ compiler.postfetch_lastrowid = True
+
+ elif implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+
+ else:
+ # postfetch specifically means, "we can SELECT the row we just
+ # inserted by primary key to get back the server generated
+ # defaults". so by definition this can't be used to get the
+ # primary key value back, because we need to have it ahead of
+ # time.
+
+ compiler.postfetch.append(c)
+
+ values.append((c, col_value, value))
+
+
+def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
+ """Create a primary key expression in the INSERT statement where
+ we want to populate result.inserted_primary_key and RETURNING
+ is available.
+
+ """
+ if c.default is not None:
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ ):
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default, **kw),
+ )
+ )
+ compiler.returning.append(c)
+ elif c.default.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default.arg.self_group(), **kw),
+ )
+ )
+ compiler.returning.append(c)
+ else:
+ # client side default. OK we can't use RETURNING, need to
+ # do a "prefetch", which in fact fetches the default value
+ # on the Python side
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+ elif c is stmt.table._autoincrement_column or c.server_default is not None:
+ compiler.returning.append(c)
+ elif not c.nullable:
+ # no .default, no .server_default, not autoincrement, we have
+ # no indication this primary key column will have any value
+ _warn_pk_with_no_anticipated_value(c)
+
+
+def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw):
+ """Create a primary key expression in the INSERT statement where
+ we want to populate result.inserted_primary_key and we cannot use
+ RETURNING.
+
+ Depending on the kind of default here we may create a bound parameter
+ in the INSERT statement and pre-execute a default generation function,
+ or we may use cursor.lastrowid if supported by the dialect.
+
+
+ """
+
+ if (
+ # column has a Python-side default
+ c.default is not None
+ and (
+ # and it either is not a sequence, or it is and we support
+ # sequences and want to invoke it
+ not c.default.is_sequence
+ or (
+ compiler.dialect.supports_sequences
+ and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ )
+ )
+ )
+ ) or (
+ # column is the "autoincrement column"
+ c is stmt.table._autoincrement_column
+ and (
+ # dialect can't use cursor.lastrowid
+ not compiler.dialect.postfetch_lastrowid
+ and (
+ # column has a Sequence and we support those
+ (
+ c.default is not None
+ and c.default.is_sequence
+ and compiler.dialect.supports_sequences
+ )
+ or
+ # column has no default on it, but dialect can run the
+ # "autoincrement" mechanism explicitly, e.g. PostgreSQL
+ # SERIAL, where we know the sequence name
+ (
+ c.default is None
+ and compiler.dialect.preexecute_autoincrement_sequences
+ )
+ )
+ )
+ ):
+ # do a pre-execute of the default
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+ elif (
+ c.default is None
+ and c.server_default is None
+ and not c.nullable
+ and c is not stmt.table._autoincrement_column
+ ):
+ # no .default, no .server_default, not autoincrement, we have
+ # no indication this primary key column will have any value
+ _warn_pk_with_no_anticipated_value(c)
+ elif compiler.dialect.postfetch_lastrowid:
+ # finally, where it seems like there will be a generated primary key
+ # value and we haven't set up any other way to fetch it, and the
+ # dialect supports cursor.lastrowid, switch on the lastrowid flag so
+ # that the DefaultExecutionContext calls upon cursor.lastrowid
+ compiler.postfetch_lastrowid = True
+
+
+def _append_param_insert_hasdefault(
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default, **kw),
+ )
+ )
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ compiler.postfetch.append(c)
+ elif c.default.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ compiler.process(c.default.arg.self_group(), **kw),
+ )
+ )
+
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ elif not c.primary_key:
+ # don't add primary key column to postfetch
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+
+
+def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
+
+ if c.default.is_sequence:
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
+ values.append(
+ (c, compiler.preparer.format_column(c), c.default.next_value())
+ )
+ elif c.default.is_clause_element:
+ values.append(
+ (c, compiler.preparer.format_column(c), c.default.arg.self_group())
+ )
+ else:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_insert_prefetch_bind_param(
+ compiler, c, process=False, **kw
+ ),
+ )
+ )
+
+
+def _append_param_update(
+ compiler, compile_state, stmt, c, implicit_return_defaults, values, kw
+):
+
+ include_table = compile_state.include_table_with_column_exprs
+ if c.onupdate is not None and not c.onupdate.is_sequence:
+ if c.onupdate.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(
+ c,
+ use_table=include_table,
+ ),
+ compiler.process(c.onupdate.arg.self_group(), **kw),
+ )
+ )
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (
+ c,
+ compiler.preparer.format_column(
+ c,
+ use_table=include_table,
+ ),
+ _create_update_prefetch_bind_param(compiler, c, **kw),
+ )
+ )
+ elif c.server_onupdate is not None:
+ if implicit_return_defaults and c in implicit_return_defaults:
+ compiler.returning.append(c)
+ else:
+ compiler.postfetch.append(c)
+ elif (
+ implicit_return_defaults
+ and (stmt._return_defaults_columns or not stmt._return_defaults)
+ and c in implicit_return_defaults
+ ):
+ compiler.returning.append(c)
+
+
+def _create_insert_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
+ compiler.insert_prefetch.append(c)
+ return param
+
+
+def _create_update_prefetch_bind_param(
+ compiler, c, process=True, name=None, **kw
+):
+ param = _create_bind_param(
+ compiler, c, None, process=process, name=name, **kw
+ )
+ compiler.update_prefetch.append(c)
+ return param
+
+
+class _multiparam_column(elements.ColumnElement):
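+ """Stand-in column used for the second and subsequent parameter sets of
+ a multi-row VALUES clause.
+
+ Each additional row needs its own bound parameter when a column default
+ must be evaluated per row; the surrogate key takes the form
+ ``<original key>_m<n>``, with ``n`` numbering the extra rows from 1.
+
+ """
+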
+ _is_multiparam_column = True
+
+ def __init__(self, original, index):
+ self.index = index
+ self.key = "%s_m%d" % (original.key, index + 1)
+ self.original = original
+ self.default = original.default
+ self.type = original.type
+
+ def compare(self, other, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, other, **kw):
+ raise NotImplementedError()
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, _multiparam_column)
+ and other.key == self.key
+ and other.original == self.original
+ )
+
+
+def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
+ if not c.default:
+ raise exc.CompileError(
+ "INSERT value for column %s is explicitly rendered as a bound"
+ "parameter in the VALUES clause; "
+ "a Python-side value or SQL expression is required" % c
+ )
+ elif c.default.is_clause_element:
+ return compiler.process(c.default.arg.self_group(), **kw)
+ elif c.default.is_sequence:
+ # these conditions would have been established
+ # by append_param_insert_(?:hasdefault|pk_returning|pk_no_returning)
+ # in order for us to be here, so these don't need to be
+ # checked
+ # assert compiler.dialect.supports_sequences and (
+ # not c.default.optional
+ # or not compiler.dialect.sequences_optional
+ # )
+ return compiler.process(c.default, **kw)
+ else:
+ col = _multiparam_column(c, index)
+ if isinstance(stmt, dml.Insert):
+ return _create_insert_prefetch_bind_param(compiler, col, **kw)
+ else:
+ return _create_update_prefetch_bind_param(compiler, col, **kw)
+
+
+def _get_update_multitable_params(
+ compiler,
+ stmt,
+ compile_state,
+ stmt_parameter_tuples,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+):
+ normalized_params = dict(
+ (coercions.expect(roles.DMLColumnRole, c), param)
+ for c, param in stmt_parameter_tuples
+ )
+
+ include_table = compile_state.include_table_with_column_exprs
+
+ affected_tables = set()
+ for t in compile_state._extra_froms:
+ for c in t.c:
+ if c in normalized_params:
+ affected_tables.add(t)
+ check_columns[_getattr_col_key(c)] = c
+ value = normalized_params[c]
+
+ col_value = compiler.process(c, include_table=include_table)
+ if coercions._is_literal(value):
+ value = _create_bind_param(
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c),
+ **kw # TODO: no test coverage for literal binds here
+ )
+ elif value._is_bind_parameter:
+ value = _handle_values_anonymous_param(
+ compiler, c, value, name=_col_bind_name(c), **kw
+ )
+ else:
+ compiler.postfetch.append(c)
+ value = compiler.process(value.self_group(), **kw)
+ values.append((c, col_value, value))
+ # determine tables which are actually to be updated - process onupdate
+ # and server_onupdate for these
+ for t in affected_tables:
+ for c in t.c:
+ if c in normalized_params:
+ continue
+ elif c.onupdate is not None and not c.onupdate.is_sequence:
+ if c.onupdate.is_clause_element:
+ values.append(
+ (
+ c,
+ compiler.process(c, include_table=include_table),
+ compiler.process(
+ c.onupdate.arg.self_group(), **kw
+ ),
+ )
+ )
+ compiler.postfetch.append(c)
+ else:
+ values.append(
+ (
+ c,
+ compiler.process(c, include_table=include_table),
+ _create_update_prefetch_bind_param(
+ compiler, c, name=_col_bind_name(c), **kw
+ ),
+ )
+ )
+ elif c.server_onupdate is not None:
+ compiler.postfetch.append(c)
+
+
+def _extend_values_for_multiparams(
+ compiler,
+ stmt,
+ compile_state,
+ values,
+ _column_as_key,
+ kw,
+):
+ values_0 = values
+ values = [values]
+
+ for i, row in enumerate(compile_state._multi_parameters[1:]):
+ extension = []
+
+ row = {_column_as_key(key): v for key, v in row.items()}
+
+ for (col, col_expr, param) in values_0:
+ if col.key in row:
+ key = col.key
+
+ if coercions._is_literal(row[key]):
+ new_param = _create_bind_param(
+ compiler,
+ col,
+ row[key],
+ name="%s_m%d" % (col.key, i + 1),
+ **kw
+ )
+ else:
+ new_param = compiler.process(row[key].self_group(), **kw)
+ else:
+ new_param = _process_multiparam_default_bind(
+ compiler, stmt, col, i, kw
+ )
+
+ extension.append((col, col_expr, new_param))
+
+ values.append(extension)
+
+ return values
+
+
+def _get_stmt_parameter_tuples_params(
+ compiler,
+ compile_state,
+ parameters,
+ stmt_parameter_tuples,
+ _column_as_key,
+ values,
+ kw,
+):
+
+ for k, v in stmt_parameter_tuples:
+ colkey = _column_as_key(k)
+ if colkey is not None:
+ parameters.setdefault(colkey, v)
+ else:
+ # a non-Column expression on the left side;
+ # add it to values() in an "as-is" state,
+ # coercing right side to bound param
+
+ # note one of the main use cases for this is array slice
+ # updates on PostgreSQL, as the left side is also an expression.
+
+ col_expr = compiler.process(
+ k, include_table=compile_state.include_table_with_column_exprs
+ )
+
+ if coercions._is_literal(v):
+ v = compiler.process(
+ elements.BindParameter(None, v, type_=k.type), **kw
+ )
+ else:
+ if v._is_bind_parameter and v.type._isnull:
+ # either unique parameter, or other bound parameters that
+ # were passed in directly
+ # set type to that of the column unconditionally
+ v = v._with_binary_element_type(k.type)
+
+ v = compiler.process(v.self_group(), **kw)
+
+ values.append((k, col_expr, v))
+
+
+def _get_returning_modifiers(compiler, stmt, compile_state):
+
+ need_pks = (
+ compile_state.isinsert
+ and not stmt._inline
+ and (
+ not compiler.for_executemany
+ or (
+ compiler.dialect.insert_executemany_returning
+ and stmt._return_defaults
+ )
+ )
+ and not stmt._returning
+ and not compile_state._has_multi_parameters
+ )
+
+ implicit_returning = (
+ need_pks
+ and compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ )
+
+ if compile_state.isinsert:
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
+ elif compile_state.isupdate:
+ implicit_return_defaults = (
+ compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ and stmt._return_defaults
+ )
+ else:
+ # this line is unused, currently we are always
+ # isinsert or isupdate
+ implicit_return_defaults = False # pragma: no cover
+
+ if implicit_return_defaults:
+ if not stmt._return_defaults_columns:
+ implicit_return_defaults = set(stmt.table.c)
+ else:
+ implicit_return_defaults = set(stmt._return_defaults_columns)
+
+ postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
+
+ return (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ )
+
+
+def _warn_pk_with_no_anticipated_value(c):
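+ """Warn that a primary key column has no way to receive an INSERT value.
+
+ Illustrative trigger (a hypothetical table, for orientation only)::
+
+ t = Table(
+ "t",
+ MetaData(),
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ )
+
+ Compiling ``t.insert()`` without a value for ``id`` reaches this warning,
+ as the column has no Python-side or server-side default and is not
+ treated as autoincrementing.
+ """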
+ msg = (
+ "Column '%s.%s' is marked as a member of the "
+ "primary key for table '%s', "
+ "but has no Python-side or server-side default generator indicated, "
+ "nor does it indicate 'autoincrement=True' or 'nullable=True', "
+ "and no explicit value is passed. "
+ "Primary key columns typically may not store NULL."
+ % (c.table.fullname, c.name, c.table.fullname)
+ )
+ if len(c.table.primary_key) > 1:
+ msg += (
+ " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
+ "indicated explicitly for composite (e.g. multicolumn) primary "
+ "keys if AUTO_INCREMENT/SERIAL/IDENTITY "
+ "behavior is expected for one of the columns in the primary key. "
+ "CREATE TABLE statements are impacted by this change as well on "
+ "most backends."
+ )
+ util.warn(msg)
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
new file mode 100644
index 0000000..e608052
--- /dev/null
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -0,0 +1,1341 @@
+# sql/ddl.py
+# Copyright (C) 2009-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+Provides the hierarchy of DDL-defining schema items as well as routines
+to invoke them for a create/drop call.
+
+"""
+
+from . import roles
+from .base import _bind_or_error
+from .base import _generative
+from .base import Executable
+from .base import SchemaVisitor
+from .elements import ClauseElement
+from .. import exc
+from .. import util
+from ..util import topological
+
+
+class _DDLCompiles(ClauseElement):
+ _hierarchy_supports_caching = False
+ """disable cache warnings for all _DDLCompiles subclasses. """
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a
+ Dialect."""
+
+ return dialect.ddl_compiler(dialect, self, **kw)
+
+ def _compile_w_cache(self, *arg, **kw):
+ raise NotImplementedError()
+
+
+class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
+ """Base class for DDL expression constructs.
+
+ This class is the base for the general purpose :class:`.DDL` class,
+ as well as the various create/drop clause constructs such as
+ :class:`.CreateTable`, :class:`.DropTable`, :class:`.AddConstraint`,
+ etc.
+
+ :class:`.DDLElement` integrates closely with SQLAlchemy events,
+ introduced in :ref:`event_toplevel`. An instance of one is
+ itself an event receiving callable::
+
+ event.listen(
+ users,
+ 'after_create',
+ AddConstraint(constraint).execute_if(dialect='postgresql')
+ )
+
+ .. seealso::
+
+ :class:`.DDL`
+
+ :class:`.DDLEvents`
+
+ :ref:`event_toplevel`
+
+ :ref:`schema_ddl_sequences`
+
+ """
+
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
+
+ target = None
+ on = None
+ dialect = None
+ callable_ = None
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ return connection._execute_ddl(
+ self, multiparams, params, execution_options
+ )
+
+ @util.deprecated_20(
+ ":meth:`.DDLElement.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, bind=None, target=None):
+ """Execute this DDL immediately.
+
+ Executes the DDL statement in isolation using the supplied
+ :class:`.Connectable`, or, if no bind is supplied, the
+ :class:`.Connectable` assigned to the ``.bind`` property.
+ If the DDL has conditional ``on`` criteria, it will be invoked
+ with None as the event.
+
+ :param bind:
+ Optional, an ``Engine`` or ``Connection``. If not supplied, a valid
+ :class:`.Connectable` must be present in the
+ ``.bind`` property.
+
+ :param target:
+ Optional, defaults to None. The target :class:`_schema.SchemaItem`
+ for the execute call. This is equivalent to passing the
+ :class:`_schema.SchemaItem` to the :meth:`.DDLElement.against`
+ method and then invoking :meth:`_schema.DDLElement.execute`
+ upon the resulting :class:`_schema.DDLElement` object. See
+ :meth:`.DDLElement.against` for further detail.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+
+ if self._should_execute(target, bind):
+ return bind.execute(self.against(target))
+ else:
+ bind.engine.logger.info("DDL execution skipped, criteria not met.")
+
+ @_generative
+ def against(self, target):
+ """Return a copy of this :class:`_schema.DDLElement` which will include
+ the given target.
+
+ This essentially applies the given item to the ``.target`` attribute
+ of the returned :class:`_schema.DDLElement` object. This target
+ is then usable by event handlers and compilation routines in order to
+ provide services such as tokenization of a DDL string in terms of a
+ particular :class:`_schema.Table`.
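+
+ E.g. (an illustrative sketch; ``users`` stands for some
+ :class:`_schema.Table` and ``connection`` for a
+ :class:`_engine.Connection`)::
+
+ ddl = DDL("COMMENT ON TABLE %(table)s IS 'managed'")
+ connection.execute(ddl.against(users))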
+
+ When a :class:`_schema.DDLElement` object is established as an event
+ handler for the :meth:`_events.DDLEvents.before_create` or
+ :meth:`_events.DDLEvents.after_create` events, and the event
+ then occurs for a given target such as a :class:`_schema.Constraint`
+ or :class:`_schema.Table`, that target is established with a copy
+ of the :class:`_schema.DDLElement` object using this method, which
+ then proceeds to the :meth:`_schema.DDLElement.execute` method
+ in order to invoke the actual DDL instruction.
+
+ :param target: a :class:`_schema.SchemaItem` that will be the subject
+ of a DDL operation.
+
+ :return: a copy of this :class:`_schema.DDLElement` with the
+ ``.target`` attribute assigned to the given
+ :class:`_schema.SchemaItem`.
+
+ .. seealso::
+
+ :class:`_schema.DDL` - uses tokenization against the "target" when
+ processing the DDL string.
+
+ """
+
+ self.target = target
+
+ @_generative
+ def execute_if(self, dialect=None, callable_=None, state=None):
+ r"""Return a callable that will execute this
+ :class:`_ddl.DDLElement` conditionally within an event handler.
+
+ Used to provide a wrapper for event listening::
+
+ event.listen(
+ metadata,
+ 'before_create',
+ DDL("my_ddl").execute_if(dialect='postgresql')
+ )
+
+ :param dialect: May be a string or tuple of strings.
+ If a string, it will be compared to the name of the
+ executing database dialect::
+
+ DDL('something').execute_if(dialect='postgresql')
+
+ If a tuple, specifies multiple dialect names::
+
+ DDL('something').execute_if(dialect=('postgresql', 'mysql'))
+
+ :param callable\_: A callable, which will be invoked with
+ four positional arguments as well as optional keyword
+ arguments:
+
+ :ddl:
+ This DDL element.
+
+ :target:
+ The :class:`_schema.Table` or :class:`_schema.MetaData`
+ object which is the
+ target of this event. May be None if the DDL is executed
+ explicitly.
+
+ :bind:
+ The :class:`_engine.Connection` being used for DDL execution
+
+ :tables:
+ Optional keyword argument - a list of Table objects which are to
+ be created/ dropped within a MetaData.create_all() or drop_all()
+ method call.
+
+ :state:
+ Optional keyword argument - will be the ``state`` argument
+ passed to this function.
+
+ :checkfirst:
+ Keyword argument, will be True if the 'checkfirst' flag was
+ set during the call to ``create()``, ``create_all()``,
+ ``drop()``, ``drop_all()``.
+
+ If the callable returns a True value, the DDL statement will be
+ executed.
+
+ :param state: any value which will be passed to the callable\_
+ as the ``state`` keyword argument.
+
+ .. seealso::
+
+ :class:`.DDLEvents`
+
+ :ref:`event_toplevel`
+
+ """
+ self.dialect = dialect
+ self.callable_ = callable_
+ self.state = state
+
+ def _should_execute(self, target, bind, **kw):
+ if isinstance(self.dialect, util.string_types):
+ if self.dialect != bind.engine.name:
+ return False
+ elif isinstance(self.dialect, (tuple, list, set)):
+ if bind.engine.name not in self.dialect:
+ return False
+ if self.callable_ is not None and not self.callable_(
+ self, target, bind, state=self.state, **kw
+ ):
+ return False
+
+ return True
+
+ def __call__(self, target, bind, **kw):
+ """Execute the DDL as a ddl_listener."""
+
+ if self._should_execute(target, bind, **kw):
+ return bind.execute(self.against(target))
+
+ def bind(self):
+ if self._bind:
+ return self._bind
+
+ def _set_bind(self, bind):
+ self._bind = bind
+
+ bind = property(bind, _set_bind)
+
+ def _generate(self):
+ s = self.__class__.__new__(self.__class__)
+ s.__dict__ = self.__dict__.copy()
+ return s
+
+
+class DDL(DDLElement):
+ """A literal DDL statement.
+
+ Specifies literal SQL DDL to be executed by the database. DDL objects
+ function as DDL event listeners, and can be subscribed to those events
+ listed in :class:`.DDLEvents`, using either :class:`_schema.Table` or
+ :class:`_schema.MetaData` objects as targets.
+ Basic templating support allows
+ a single DDL instance to handle repetitive tasks for multiple tables.
+
+ Examples::
+
+ from sqlalchemy import event, DDL
+
+ tbl = Table('users', metadata, Column('uid', Integer))
+ event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger'))
+
+ spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE')
+ event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb'))
+
+ drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
+ connection.execute(drop_spow)
+
+ When operating on Table events, the following ``statement``
+ string substitutions are available::
+
+ %(table)s - the Table name, with any required quoting applied
+ %(schema)s - the schema name, with any required quoting applied
+ %(fullname)s - the Table name including schema, quoted if needed
+
+ The DDL's "context", if any, will be combined with the standard
+ substitutions noted above. Keys present in the context will override
+ the standard substitutions.
+
+ """
+
+ __visit_name__ = "ddl"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DDL.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, statement, context=None, bind=None):
+ """Create a DDL statement.
+
+ :param statement:
+ A string or unicode string to be executed. Statements will be
+ processed with Python's string formatting operator using
+ a fixed set of string substitutions, as well as additional
+ substitutions provided by the optional :paramref:`.DDL.context`
+ parameter.
+
+ A literal '%' in a statement must be escaped as '%%'.
+
+ SQL bind parameters are not available in DDL statements.
+
+ :param context:
+ Optional dictionary, defaults to None. These values will be
+ available for use in string substitutions on the DDL statement.
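+
+ For example, context keys combine with the standard table
+ substitutions when the DDL fires for a table event (an
+ illustrative sketch; ``report_role`` is a hypothetical grantee)::
+
+ DDL(
+ "GRANT SELECT ON %(fullname)s TO %(owner)s",
+ context={"owner": "report_role"},
+ )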
+
+ :param bind:
+ Optional. A :class:`.Connectable`, used by
+ default when ``execute()`` is invoked without a bind argument.
+
+
+ .. seealso::
+
+ :class:`.DDLEvents`
+
+ :ref:`event_toplevel`
+
+ """
+
+ if not isinstance(statement, util.string_types):
+ raise exc.ArgumentError(
+ "Expected a string or unicode SQL statement, got '%r'"
+ % statement
+ )
+
+ self.statement = statement
+ self.context = context or {}
+
+ self._bind = bind
+
+ def __repr__(self):
+ return "<%s@%s; %s>" % (
+ type(self).__name__,
+ id(self),
+ ", ".join(
+ [repr(self.statement)]
+ + [
+ "%s=%r" % (key, getattr(self, key))
+ for key in ("on", "context")
+ if getattr(self, key)
+ ]
+ ),
+ )
+
+
+class _CreateDropBase(DDLElement):
+ """Base class for DDL constructs that represent CREATE and DROP or
+ equivalents.
+
+ The common theme of _CreateDropBase is a single
+ ``element`` attribute which refers to the element
+ to be created or dropped.
+
+ """
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DDLElement.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ element,
+ bind=None,
+ if_exists=False,
+ if_not_exists=False,
+ _legacy_bind=None,
+ ):
+ self.element = element
+ if bind:
+ self.bind = bind
+ elif _legacy_bind:
+ self.bind = _legacy_bind
+ self.if_exists = if_exists
+ self.if_not_exists = if_not_exists
+
+ @property
+ def stringify_dialect(self):
+ return self.element.create_drop_stringify_dialect
+
+ def _create_rule_disable(self, compiler):
+ """Allow disable of _create_rule using a callable.
+
+ Pass to _create_rule using
+ util.portable_instancemethod(self._create_rule_disable)
+ to retain serializability.
+
+ """
+ return False
+
+
+class CreateSchema(_CreateDropBase):
+ """Represent a CREATE SCHEMA statement.
+
+ The argument here is the string name of the schema.
+
+ """
+
+ __visit_name__ = "create_schema"
+
+ def __init__(self, name, quote=None, **kw):
+ """Create a new :class:`.CreateSchema` construct."""
+
+ self.quote = quote
+ super(CreateSchema, self).__init__(name, **kw)
+
+
+class DropSchema(_CreateDropBase):
+ """Represent a DROP SCHEMA statement.
+
+ The argument here is the string name of the schema.
+
+ """
+
+ __visit_name__ = "drop_schema"
+
+ def __init__(self, name, quote=None, cascade=False, **kw):
+ """Create a new :class:`.DropSchema` construct."""
+
+ self.quote = quote
+ self.cascade = cascade
+ super(DropSchema, self).__init__(name, **kw)
+
+
+class CreateTable(_CreateDropBase):
+ """Represent a CREATE TABLE statement."""
+
+ __visit_name__ = "create_table"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.CreateTable.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ element,
+ bind=None,
+ include_foreign_key_constraints=None,
+ if_not_exists=False,
+ ):
+ """Create a :class:`.CreateTable` construct.
+
+ :param element: a :class:`_schema.Table` that's the subject
+ of the CREATE
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param include_foreign_key_constraints: optional sequence of
+ :class:`_schema.ForeignKeyConstraint` objects that will be included
+ inline within the CREATE construct; if omitted, all foreign key
+ constraints that do not specify use_alter=True are included.
+
+ .. versionadded:: 1.0.0
+
+ :param if_not_exists: if True, an IF NOT EXISTS operator will be
+ applied to the construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(CreateTable, self).__init__(
+ element, _legacy_bind=bind, if_not_exists=if_not_exists
+ )
+ self.columns = [CreateColumn(column) for column in element.columns]
+ self.include_foreign_key_constraints = include_foreign_key_constraints
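+
+ # Minimal usage sketch (hypothetical table): the construct may be
+ # stringified against the default dialect or executed on a connection:
+ #
+ # t = Table("t", MetaData(), Column("id", Integer, primary_key=True))
+ # print(CreateTable(t)) # "CREATE TABLE t (id INTEGER NOT NULL, ...)"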
+
+
+class _DropView(_CreateDropBase):
+ """Semi-public 'DROP VIEW' construct.
+
+ Used by the test suite for dialect-agnostic drops of views.
+ This object will eventually be part of a public "view" API.
+
+ """
+
+ __visit_name__ = "drop_view"
+
+
+class CreateColumn(_DDLCompiles):
+ """Represent a :class:`_schema.Column`
+ as rendered in a CREATE TABLE statement,
+ via the :class:`.CreateTable` construct.
+
+ This is provided to support custom column DDL within the generation
+ of CREATE TABLE statements, by using the
+ compiler extension documented in :ref:`sqlalchemy.ext.compiler_toplevel`
+ to extend :class:`.CreateColumn`.
+
+ Typical integration is to examine the incoming :class:`_schema.Column`
+ object, and to redirect compilation if a particular flag or condition
+ is found::
+
+ from sqlalchemy import schema
+ from sqlalchemy.ext.compiler import compiles
+
+ @compiles(schema.CreateColumn)
+ def compile(element, compiler, **kw):
+ column = element.element
+
+ if "special" not in column.info:
+ return compiler.visit_create_column(element, **kw)
+
+ text = "%s SPECIAL DIRECTIVE %s" % (
+ column.name,
+ compiler.type_compiler.process(column.type)
+ )
+ default = compiler.get_column_default_string(column)
+ if default is not None:
+ text += " DEFAULT " + default
+
+ if not column.nullable:
+ text += " NOT NULL"
+
+ if column.constraints:
+ text += " ".join(
+ compiler.process(const)
+ for const in column.constraints)
+ return text
+
+ The above construct can be applied to a :class:`_schema.Table`
+ as follows::
+
+ from sqlalchemy import Table, MetaData, Column, Integer, String
+ from sqlalchemy import schema
+
+ metadata = MetaData()
+
+ table = Table('mytable', metadata,
+ Column('x', Integer, info={"special":True}, primary_key=True),
+ Column('y', String(50)),
+ Column('z', String(20), info={"special":True})
+ )
+
+ metadata.create_all(conn)
+
+ Above, the directives we've added to the :attr:`_schema.Column.info`
+ collection
+ will be detected by our custom compilation scheme::
+
+ CREATE TABLE mytable (
+ x SPECIAL DIRECTIVE INTEGER NOT NULL,
+ y VARCHAR(50),
+ z SPECIAL DIRECTIVE VARCHAR(20),
+ PRIMARY KEY (x)
+ )
+
+ The :class:`.CreateColumn` construct can also be used to skip certain
+ columns when producing a ``CREATE TABLE``. This is accomplished by
+ creating a compilation rule that conditionally returns ``None``.
+ This is essentially how to produce the same effect as using the
+ ``system=True`` argument on :class:`_schema.Column`, which marks a column
+ as an implicitly-present "system" column.
+
+ For example, suppose we wish to produce a :class:`_schema.Table`
+ which skips
+ rendering of the PostgreSQL ``xmin`` column against the PostgreSQL
+ backend, but on other backends does render it, in anticipation of a
+ triggered rule. A conditional compilation rule could skip this name only
+ on PostgreSQL::
+
+ from sqlalchemy.schema import CreateColumn
+
+ @compiles(CreateColumn, "postgresql")
+ def skip_xmin(element, compiler, **kw):
+ if element.element.name == 'xmin':
+ return None
+ else:
+ return compiler.visit_create_column(element, **kw)
+
+
+ my_table = Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('xmin', Integer)
+ )
+
+ Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE``
+ which only includes the ``id`` column in the string; the ``xmin`` column
+ will be omitted, but only against the PostgreSQL backend.
+
+ """
+
+ __visit_name__ = "create_column"
+
+ def __init__(self, element):
+ self.element = element
+
+
+class DropTable(_CreateDropBase):
+ """Represent a DROP TABLE statement."""
+
+ __visit_name__ = "drop_table"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DropTable.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, element, bind=None, if_exists=False):
+ """Create a :class:`.DropTable` construct.
+
+ :param element: a :class:`_schema.Table` that's the subject
+ of the DROP.
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param if_exists: if True, an IF EXISTS operator will be applied to the
+ construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(DropTable, self).__init__(
+ element, _legacy_bind=bind, if_exists=if_exists
+ )
+
+
+class CreateSequence(_CreateDropBase):
+ """Represent a CREATE SEQUENCE statement."""
+
+ __visit_name__ = "create_sequence"
+
+
+class DropSequence(_CreateDropBase):
+ """Represent a DROP SEQUENCE statement."""
+
+ __visit_name__ = "drop_sequence"
+
+
+class CreateIndex(_CreateDropBase):
+ """Represent a CREATE INDEX statement."""
+
+ __visit_name__ = "create_index"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.CreateIndex.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, element, bind=None, if_not_exists=False):
+ """Create a :class:`.Createindex` construct.
+
+ :param element: a :class:`_schema.Index` that's the subject
+ of the CREATE.
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param if_not_exists: if True, an IF NOT EXISTS operator will be
+ applied to the construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(CreateIndex, self).__init__(
+ element, _legacy_bind=bind, if_not_exists=if_not_exists
+ )
+
+
+class DropIndex(_CreateDropBase):
+ """Represent a DROP INDEX statement."""
+
+ __visit_name__ = "drop_index"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_ddl.DropIndex.bind` argument is "
+ "deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, element, bind=None, if_exists=False):
+ """Create a :class:`.DropIndex` construct.
+
+ :param element: a :class:`_schema.Index` that's the subject
+ of the DROP.
+ :param on: See the description for 'on' in :class:`.DDL`.
+ :param bind: See the description for 'bind' in :class:`.DDL`.
+ :param if_exists: if True, an IF EXISTS operator will be applied to the
+ construct.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ super(DropIndex, self).__init__(
+ element, _legacy_bind=bind, if_exists=if_exists
+ )
+
+
+class AddConstraint(_CreateDropBase):
+ """Represent an ALTER TABLE ADD CONSTRAINT statement."""
+
+ __visit_name__ = "add_constraint"
+
+ def __init__(self, element, *args, **kw):
+ super(AddConstraint, self).__init__(element, *args, **kw)
+ element._create_rule = util.portable_instancemethod(
+ self._create_rule_disable
+ )
+
+
+class DropConstraint(_CreateDropBase):
+ """Represent an ALTER TABLE DROP CONSTRAINT statement."""
+
+ __visit_name__ = "drop_constraint"
+
+ def __init__(self, element, cascade=False, **kw):
+ self.cascade = cascade
+ super(DropConstraint, self).__init__(element, **kw)
+ element._create_rule = util.portable_instancemethod(
+ self._create_rule_disable
+ )
+
+
+class SetTableComment(_CreateDropBase):
+ """Represent a COMMENT ON TABLE IS statement."""
+
+ __visit_name__ = "set_table_comment"
+
+
+class DropTableComment(_CreateDropBase):
+ """Represent a COMMENT ON TABLE '' statement.
+
+ Note this varies a lot across database backends.
+
+ """
+
+ __visit_name__ = "drop_table_comment"
+
+
+class SetColumnComment(_CreateDropBase):
+ """Represent a COMMENT ON COLUMN IS statement."""
+
+ __visit_name__ = "set_column_comment"
+
+
+class DropColumnComment(_CreateDropBase):
+ """Represent a COMMENT ON COLUMN IS NULL statement."""
+
+ __visit_name__ = "drop_column_comment"
+
+
+class DDLBase(SchemaVisitor):
+ def __init__(self, connection):
+ self.connection = connection
+
+
+class SchemaGenerator(DDLBase):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
+ super(SchemaGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+ self.memo = {}
+
+ def _can_create_table(self, table):
+ self.dialect.validate_identifier(table.name)
+ effective_schema = self.connection.schema_for_object(table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or not self.dialect.has_table(
+ self.connection, table.name, schema=effective_schema
+ )
+
+ def _can_create_index(self, index):
+ effective_schema = self.connection.schema_for_object(index.table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or not self.dialect.has_index(
+ self.connection,
+ index.table.name,
+ index.name,
+ schema=effective_schema,
+ )
+
+ def _can_create_sequence(self, sequence):
+ effective_schema = self.connection.schema_for_object(sequence)
+
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or not self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
+ )
+ )
+ )
+
+ def visit_metadata(self, metadata):
+ if self.tables is not None:
+ tables = self.tables
+ else:
+ tables = list(metadata.tables.values())
+
+ collection = sort_tables_and_constraints(
+ [t for t in tables if self._can_create_table(t)]
+ )
+
+ seq_coll = [
+ s
+ for s in metadata._sequences.values()
+ if s.column is None and self._can_create_sequence(s)
+ ]
+
+ event_collection = [t for (t, fks) in collection if t is not None]
+ metadata.dispatch.before_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ for seq in seq_coll:
+ self.traverse_single(seq, create_ok=True)
+
+ for table, fkcs in collection:
+ if table is not None:
+ self.traverse_single(
+ table,
+ create_ok=True,
+ include_foreign_key_constraints=fkcs,
+ _is_metadata_operation=True,
+ )
+ else:
+ for fkc in fkcs:
+ self.traverse_single(fkc)
+
+ metadata.dispatch.after_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ def visit_table(
+ self,
+ table,
+ create_ok=False,
+ include_foreign_key_constraints=None,
+ _is_metadata_operation=False,
+ ):
+ if not create_ok and not self._can_create_table(table):
+ return
+
+ table.dispatch.before_create(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ for column in table.columns:
+ if column.default is not None:
+ self.traverse_single(column.default)
+
+ if not self.dialect.supports_alter:
+ # e.g., don't omit any foreign key constraints
+ include_foreign_key_constraints = None
+
+ self.connection.execute(
+ # fmt: off
+ CreateTable(
+ table,
+ include_foreign_key_constraints= # noqa
+ include_foreign_key_constraints, # noqa
+ )
+ # fmt: on
+ )
+
+ if hasattr(table, "indexes"):
+ for index in table.indexes:
+ self.traverse_single(index, create_ok=True)
+
+ if self.dialect.supports_comments and not self.dialect.inline_comments:
+ if table.comment is not None:
+ self.connection.execute(SetTableComment(table))
+
+ for column in table.columns:
+ if column.comment is not None:
+ self.connection.execute(SetColumnComment(column))
+
+ table.dispatch.after_create(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ def visit_foreign_key_constraint(self, constraint):
+ if not self.dialect.supports_alter:
+ return
+ self.connection.execute(AddConstraint(constraint))
+
+ def visit_sequence(self, sequence, create_ok=False):
+ if not create_ok and not self._can_create_sequence(sequence):
+ return
+ self.connection.execute(CreateSequence(sequence))
+
+ def visit_index(self, index, create_ok=False):
+ if not create_ok and not self._can_create_index(index):
+ return
+ self.connection.execute(CreateIndex(index))
+
+
+class SchemaDropper(DDLBase):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
+ super(SchemaDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+ self.tables = tables
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+ self.memo = {}
+
+ def visit_metadata(self, metadata):
+ if self.tables is not None:
+ tables = self.tables
+ else:
+ tables = list(metadata.tables.values())
+
+ try:
+ unsorted_tables = [t for t in tables if self._can_drop_table(t)]
+ collection = list(
+ reversed(
+ sort_tables_and_constraints(
+ unsorted_tables,
+ filter_fn=lambda constraint: False
+ if not self.dialect.supports_alter
+ or constraint.name is None
+ else None,
+ )
+ )
+ )
+ except exc.CircularDependencyError as err2:
+ if not self.dialect.supports_alter:
+ util.warn(
+ "Can't sort tables for DROP; an "
+ "unresolvable foreign key "
+ "dependency exists between tables: %s; and backend does "
+ "not support ALTER. To restore at least a partial sort, "
+ "apply use_alter=True to ForeignKey and "
+ "ForeignKeyConstraint "
+ "objects involved in the cycle to mark these as known "
+ "cycles that will be ignored."
+ % (", ".join(sorted([t.fullname for t in err2.cycles])))
+ )
+ collection = [(t, ()) for t in unsorted_tables]
+ else:
+ util.raise_(
+ exc.CircularDependencyError(
+ err2.args[0],
+ err2.cycles,
+ err2.edges,
+ msg="Can't sort tables for DROP; an "
+ "unresolvable foreign key "
+ "dependency exists between tables: %s. Please ensure "
+ "that the ForeignKey and ForeignKeyConstraint objects "
+ "involved in the cycle have "
+ "names so that they can be dropped using "
+ "DROP CONSTRAINT."
+ % (
+ ", ".join(
+ sorted([t.fullname for t in err2.cycles])
+ )
+ ),
+ ),
+ from_=err2,
+ )
+
+ seq_coll = [
+ s
+ for s in metadata._sequences.values()
+ if self._can_drop_sequence(s)
+ ]
+
+ event_collection = [t for (t, fks) in collection if t is not None]
+
+ metadata.dispatch.before_drop(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ for table, fkcs in collection:
+ if table is not None:
+ self.traverse_single(
+ table,
+ drop_ok=True,
+ _is_metadata_operation=True,
+ _ignore_sequences=seq_coll,
+ )
+ else:
+ for fkc in fkcs:
+ self.traverse_single(fkc)
+
+ for seq in seq_coll:
+ self.traverse_single(seq, drop_ok=seq.column is None)
+
+ metadata.dispatch.after_drop(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
+
+ def _can_drop_table(self, table):
+ self.dialect.validate_identifier(table.name)
+ effective_schema = self.connection.schema_for_object(table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or self.dialect.has_table(
+ self.connection, table.name, schema=effective_schema
+ )
+
+ def _can_drop_index(self, index):
+ effective_schema = self.connection.schema_for_object(index.table)
+ if effective_schema:
+ self.dialect.validate_identifier(effective_schema)
+ return not self.checkfirst or self.dialect.has_index(
+ self.connection,
+ index.table.name,
+ index.name,
+ schema=effective_schema,
+ )
+
+ def _can_drop_sequence(self, sequence):
+ effective_schema = self.connection.schema_for_object(sequence)
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
+ )
+ )
+ )
+
+ def visit_index(self, index, drop_ok=False):
+ if not drop_ok and not self._can_drop_index(index):
+ return
+
+ self.connection.execute(DropIndex(index))
+
+ def visit_table(
+ self,
+ table,
+ drop_ok=False,
+ _is_metadata_operation=False,
+ _ignore_sequences=(),
+ ):
+ if not drop_ok and not self._can_drop_table(table):
+ return
+
+ table.dispatch.before_drop(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ self.connection.execute(DropTable(table))
+
+ # traverse client side defaults which may refer to server-side
+ # sequences. noting that some of these client side defaults may also be
+ # set up as server side defaults (see https://docs.sqlalchemy.org/en/
+ # latest/core/defaults.html#associating-a-sequence-as-the-server-side-
+ # default), so have to be dropped after the table is dropped.
+ for column in table.columns:
+ if (
+ column.default is not None
+ and column.default not in _ignore_sequences
+ ):
+ self.traverse_single(column.default)
+
+ table.dispatch.after_drop(
+ table,
+ self.connection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ _is_metadata_operation=_is_metadata_operation,
+ )
+
+ def visit_foreign_key_constraint(self, constraint):
+ if not self.dialect.supports_alter:
+ return
+ self.connection.execute(DropConstraint(constraint))
+
+ def visit_sequence(self, sequence, drop_ok=False):
+
+ if not drop_ok and not self._can_drop_sequence(sequence):
+ return
+ self.connection.execute(DropSequence(sequence))
+
+
+def sort_tables(
+ tables,
+ skip_fn=None,
+ extra_dependencies=None,
+):
+ """Sort a collection of :class:`_schema.Table` objects based on
+ dependency.
+
+ This is a dependency-ordered sort which will emit :class:`_schema.Table`
+ objects such that they will follow their dependent :class:`_schema.Table`
+ objects.
+ Tables are dependent on another based on the presence of
+ :class:`_schema.ForeignKeyConstraint`
+ objects as well as explicit dependencies
+ added by :meth:`_schema.Table.add_is_dependent_on`.
+
+ .. warning::
+
+ The :func:`._schema.sort_tables` function cannot by itself
+ accommodate automatic resolution of dependency cycles between
+ tables, which are usually caused by mutually dependent foreign key
+ constraints. When these cycles are detected, the foreign keys
+ of these tables are omitted from consideration in the sort.
+ A warning is emitted when this condition occurs; in a future
+ release it will become an exception. Tables which are not part
+ of the cycle will still be returned in dependency order.
+
+ To resolve these cycles, the
+ :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter may be
+ applied to those constraints which create a cycle. Alternatively,
+ the :func:`_schema.sort_tables_and_constraints` function will
+ automatically return foreign key constraints in a separate
+ collection when cycles are detected so that they may be applied
+ to a schema separately.
+
+ .. versionchanged:: 1.3.17 - a warning is emitted when
+ :func:`_schema.sort_tables` cannot perform a proper sort due to
+ cyclical dependencies. This will be an exception in a future
+ release. Additionally, the sort will continue to return
+ other tables not involved in the cycle in dependency order
+ which was not the case previously.
+
+ :param tables: a sequence of :class:`_schema.Table` objects.
+
+ :param skip_fn: optional callable which will be passed a
+ :class:`_schema.ForeignKey` object; if it returns True, this
+ constraint will not be considered as a dependency. Note this is
+ **different** from the same parameter in
+ :func:`.sort_tables_and_constraints`, which is
+ instead passed the owning :class:`_schema.ForeignKeyConstraint` object.
+
+ :param extra_dependencies: a sequence of 2-tuples of tables which will
+ also be considered as dependent on each other.
+
+ .. seealso::
+
+ :func:`.sort_tables_and_constraints`
+
+ :attr:`_schema.MetaData.sorted_tables` - uses this function to sort
+
+
+ """
+
+ if skip_fn is not None:
+
+ def _skip_fn(fkc):
+ for fk in fkc.elements:
+ if skip_fn(fk):
+ return True
+ else:
+ return None
+
+ else:
+ _skip_fn = None
+
+ return [
+ t
+ for (t, fkcs) in sort_tables_and_constraints(
+ tables,
+ filter_fn=_skip_fn,
+ extra_dependencies=extra_dependencies,
+ _warn_for_cycles=True,
+ )
+ if t is not None
+ ]
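+
+# Minimal usage sketch (hypothetical tables; ``parent`` must sort before
+# ``child`` because of the foreign key):
+#
+# m = MetaData()
+# parent = Table("parent", m, Column("id", Integer, primary_key=True))
+# child = Table(
+# "child",
+# m,
+# Column("id", Integer, primary_key=True),
+# Column("pid", Integer, ForeignKey("parent.id")),
+# )
+# assert sort_tables([child, parent]) == [parent, child]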
+
+
+def sort_tables_and_constraints(
+ tables, filter_fn=None, extra_dependencies=None, _warn_for_cycles=False
+):
+ """Sort a collection of :class:`_schema.Table` /
+ :class:`_schema.ForeignKeyConstraint`
+ objects.
+
+ This is a dependency-ordered sort which will emit tuples of
+ ``(Table, [ForeignKeyConstraint, ...])`` such that each
+ :class:`_schema.Table` follows its dependent :class:`_schema.Table`
+ objects.
+ Remaining :class:`_schema.ForeignKeyConstraint`
+ objects that are separate due to
+ dependency rules not satisfied by the sort are emitted afterwards
+ as ``(None, [ForeignKeyConstraint ...])``.
+
+ Tables are dependent on another based on the presence of
+ :class:`_schema.ForeignKeyConstraint` objects, explicit dependencies
+ added by :meth:`_schema.Table.add_is_dependent_on`,
+ as well as dependencies
+ stated here using the :paramref:`~.sort_tables_and_constraints.skip_fn`
+ and/or :paramref:`~.sort_tables_and_constraints.extra_dependencies`
+ parameters.
+
+ :param tables: a sequence of :class:`_schema.Table` objects.
+
+ :param filter_fn: optional callable which will be passed a
+ :class:`_schema.ForeignKeyConstraint` object,
+ and returns a value based on
+ whether this constraint should definitely be included or excluded as
+ an inline constraint, or neither. If it returns False, the constraint
+ will definitely be included as a dependency that cannot be subject
+ to ALTER; if True, it will **only** be included as an ALTER result at
+ the end. Returning None means the constraint is included in the
+ table-based result unless it is detected as part of a dependency cycle.
+
+ :param extra_dependencies: a sequence of 2-tuples of tables which will
+ also be considered as dependent on each other.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :func:`.sort_tables`
+
+
+ """
+
+ fixed_dependencies = set()
+ mutable_dependencies = set()
+
+ if extra_dependencies is not None:
+ fixed_dependencies.update(extra_dependencies)
+
+ remaining_fkcs = set()
+ for table in tables:
+ for fkc in table.foreign_key_constraints:
+ if fkc.use_alter is True:
+ remaining_fkcs.add(fkc)
+ continue
+
+ if filter_fn:
+ filtered = filter_fn(fkc)
+
+ if filtered is True:
+ remaining_fkcs.add(fkc)
+ continue
+
+ dependent_on = fkc.referred_table
+ if dependent_on is not table:
+ mutable_dependencies.add((dependent_on, table))
+
+ fixed_dependencies.update(
+ (parent, table) for parent in table._extra_dependencies
+ )
+
+ try:
+ candidate_sort = list(
+ topological.sort(
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ )
+ )
+ except exc.CircularDependencyError as err:
+ if _warn_for_cycles:
+ util.warn(
+ "Cannot correctly sort tables; there are unresolvable cycles "
+ 'between tables "%s", which is usually caused by mutually '
+ "dependent foreign key constraints. Foreign key constraints "
+ "involving these tables will not be considered; this warning "
+ "may raise an error in a future release."
+ % (", ".join(sorted(t.fullname for t in err.cycles)),)
+ )
+ for edge in err.edges:
+ if edge in mutable_dependencies:
+ table = edge[1]
+ if table not in err.cycles:
+ continue
+ can_remove = [
+ fkc
+ for fkc in table.foreign_key_constraints
+ if filter_fn is None or filter_fn(fkc) is not False
+ ]
+ remaining_fkcs.update(can_remove)
+ for fkc in can_remove:
+ dependent_on = fkc.referred_table
+ if dependent_on is not table:
+ mutable_dependencies.discard((dependent_on, table))
+ candidate_sort = list(
+ topological.sort(
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ )
+ )
+
+ return [
+ (table, table.foreign_key_constraints.difference(remaining_fkcs))
+ for table in candidate_sort
+ ] + [(None, list(remaining_fkcs))]
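+
+# Return-shape sketch for the hypothetical parent/child tables above:
+# roughly [(parent, set()), (child, {<child FK constraint>}), (None, [])].
+# Constraints marked use_alter=True, excluded by filter_fn, or involved in
+# a dependency cycle appear in the trailing (None, [...]) entry instead, so
+# they can be issued separately via ALTER.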
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
new file mode 100644
index 0000000..70586c6
--- /dev/null
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -0,0 +1,360 @@
+# sql/default_comparator.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Default implementation of SQL comparison operations.
+"""
+
+
+from . import coercions
+from . import operators
+from . import roles
+from . import type_api
+from .elements import and_
+from .elements import BinaryExpression
+from .elements import ClauseList
+from .elements import collate
+from .elements import CollectionAggregate
+from .elements import False_
+from .elements import Null
+from .elements import or_
+from .elements import True_
+from .elements import UnaryExpression
+from .. import exc
+from .. import util
+
+
+def _boolean_compare(
+ expr,
+ op,
+ obj,
+ negate=None,
+ reverse=False,
+ _python_is_types=(util.NoneType, bool),
+ _any_all_expr=False,
+ result_type=None,
+ **kwargs
+):
+
+ if result_type is None:
+ result_type = type_api.BOOLEANTYPE
+
+ if isinstance(obj, _python_is_types + (Null, True_, False_)):
+ # allow x ==/!= True/False to be treated as a literal.
+ # this comes out to "== / != true/false" or "1/0" if those
+ # constants aren't supported and works on all platforms
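+ # (e.g. column("flag") == True renders "flag = true" on backends with
+ # boolean literals, or "flag = 1" where they aren't supported)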
+ if op in (operators.eq, operators.ne) and isinstance(
+ obj, (bool, True_, False_)
+ ):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
+ elif op in (
+ operators.is_distinct_from,
+ operators.is_not_distinct_from,
+ ):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
+ elif _any_all_expr:
+ obj = coercions.expect(
+ roles.ConstExprRole, element=obj, operator=op, expr=expr
+ )
+ else:
+ # all other None uses IS, IS NOT
+ if op in (operators.eq, operators.is_):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ operators.is_,
+ negate=operators.is_not,
+ type_=result_type,
+ )
+ elif op in (operators.ne, operators.is_not):
+ return BinaryExpression(
+ expr,
+ coercions.expect(roles.ConstExprRole, obj),
+ operators.is_not,
+ negate=operators.is_,
+ type_=result_type,
+ )
+ else:
+ raise exc.ArgumentError(
+ "Only '=', '!=', 'is_()', 'is_not()', "
+ "'is_distinct_from()', 'is_not_distinct_from()' "
+ "operators can be used with None/True/False"
+ )
+ else:
+ obj = coercions.expect(
+ roles.BinaryElementRole, element=obj, operator=op, expr=expr
+ )
+
+ if reverse:
+ return BinaryExpression(
+ obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
+ else:
+ return BinaryExpression(
+ expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
+
+
+def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
+ if result_type is None:
+ if op.return_type:
+ result_type = op.return_type
+ elif op.is_comparison:
+ result_type = type_api.BOOLEANTYPE
+
+ return _binary_operate(
+ expr, op, obj, reverse=reverse, result_type=result_type, **kw
+ )
+
+
+def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw):
+ obj = coercions.expect(
+ roles.BinaryElementRole, obj, expr=expr, operator=op
+ )
+
+ if reverse:
+ left, right = obj, expr
+ else:
+ left, right = expr, obj
+
+ if result_type is None:
+ op, result_type = left.comparator._adapt_expression(
+ op, right.comparator
+ )
+
+ return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
+
+
+def _conjunction_operate(expr, op, other, **kw):
+ if op is operators.and_:
+ return and_(expr, other)
+ elif op is operators.or_:
+ return or_(expr, other)
+ else:
+ raise NotImplementedError()
+
+
+def _scalar(expr, op, fn, **kw):
+ return fn(expr)
+
+
+def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
+ seq_or_selectable = coercions.expect(
+ roles.InElementRole, seq_or_selectable, expr=expr, operator=op
+ )
+ if "in_ops" in seq_or_selectable._annotations:
+ op, negate_op = seq_or_selectable._annotations["in_ops"]
+
+ return _boolean_compare(
+ expr, op, seq_or_selectable, negate=negate_op, **kw
+ )
+
+
+def _getitem_impl(expr, op, other, **kw):
+ if (
+ isinstance(expr.type, type_api.INDEXABLE)
+ or isinstance(expr.type, type_api.TypeDecorator)
+ and isinstance(expr.type.impl, type_api.INDEXABLE)
+ ):
+ other = coercions.expect(
+ roles.BinaryElementRole, other, expr=expr, operator=op
+ )
+ return _binary_operate(expr, op, other, **kw)
+ else:
+ _unsupported_impl(expr, op, other, **kw)
+
+
+def _unsupported_impl(expr, op, *arg, **kw):
+ raise NotImplementedError(
+ "Operator '%s' is not supported on this expression" % op.__name__
+ )
+
+
+def _inv_impl(expr, op, **kw):
+ """See :meth:`.ColumnOperators.__inv__`."""
+
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
+ if hasattr(expr, "negation_clause"):
+ return expr.negation_clause
+ else:
+ return expr._negate()
+
+
+def _neg_impl(expr, op, **kw):
+ """See :meth:`.ColumnOperators.__neg__`."""
+ return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
+
+
+def _match_impl(expr, op, other, **kw):
+ """See :meth:`.ColumnOperators.match`."""
+
+ return _boolean_compare(
+ expr,
+ operators.match_op,
+ coercions.expect(
+ roles.BinaryElementRole,
+ other,
+ expr=expr,
+ operator=operators.match_op,
+ ),
+ result_type=type_api.MATCHTYPE,
+ negate=operators.not_match_op
+ if op is operators.match_op
+ else operators.match_op,
+ **kw
+ )
+
+
+def _distinct_impl(expr, op, **kw):
+ """See :meth:`.ColumnOperators.distinct`."""
+ return UnaryExpression(
+ expr, operator=operators.distinct_op, type_=expr.type
+ )
+
+
+def _between_impl(expr, op, cleft, cright, **kw):
+ """See :meth:`.ColumnOperators.between`."""
+ return BinaryExpression(
+ expr,
+ ClauseList(
+ coercions.expect(
+ roles.BinaryElementRole,
+ cleft,
+ expr=expr,
+ operator=operators.and_,
+ ),
+ coercions.expect(
+ roles.BinaryElementRole,
+ cright,
+ expr=expr,
+ operator=operators.and_,
+ ),
+ operator=operators.and_,
+ group=False,
+ group_contents=False,
+ ),
+ op,
+ negate=operators.not_between_op
+ if op is operators.between_op
+ else operators.between_op,
+ modifiers=kw,
+ )
+
+
+def _collate_impl(expr, op, other, **kw):
+ return collate(expr, other)
+
+
+def _regexp_match_impl(expr, op, pattern, flags, **kw):
+ if flags is not None:
+ flags = coercions.expect(
+ roles.BinaryElementRole,
+ flags,
+ expr=expr,
+ operator=operators.regexp_replace_op,
+ )
+ return _boolean_compare(
+ expr,
+ op,
+ pattern,
+ flags=flags,
+ negate=operators.not_regexp_match_op
+ if op is operators.regexp_match_op
+ else operators.regexp_match_op,
+ **kw
+ )
+
+
+def _regexp_replace_impl(expr, op, pattern, replacement, flags, **kw):
+ replacement = coercions.expect(
+ roles.BinaryElementRole,
+ replacement,
+ expr=expr,
+ operator=operators.regexp_replace_op,
+ )
+ if flags is not None:
+ flags = coercions.expect(
+ roles.BinaryElementRole,
+ flags,
+ expr=expr,
+ operator=operators.regexp_replace_op,
+ )
+ return _binary_operate(
+ expr, op, pattern, replacement=replacement, flags=flags, **kw
+ )
+
+
+# a mapping of operators with the method they use, along with
+# their negated operator for comparison operators
+operator_lookup = {
+ "and_": (_conjunction_operate,),
+ "or_": (_conjunction_operate,),
+ "inv": (_inv_impl,),
+ "add": (_binary_operate,),
+ "mul": (_binary_operate,),
+ "sub": (_binary_operate,),
+ "div": (_binary_operate,),
+ "mod": (_binary_operate,),
+ "truediv": (_binary_operate,),
+ "custom_op": (_custom_op_operate,),
+ "json_path_getitem_op": (_binary_operate,),
+ "json_getitem_op": (_binary_operate,),
+ "concat_op": (_binary_operate,),
+ "any_op": (_scalar, CollectionAggregate._create_any),
+ "all_op": (_scalar, CollectionAggregate._create_all),
+ "lt": (_boolean_compare, operators.ge),
+ "le": (_boolean_compare, operators.gt),
+ "ne": (_boolean_compare, operators.eq),
+ "gt": (_boolean_compare, operators.le),
+ "ge": (_boolean_compare, operators.lt),
+ "eq": (_boolean_compare, operators.ne),
+ "is_distinct_from": (_boolean_compare, operators.is_not_distinct_from),
+ "is_not_distinct_from": (_boolean_compare, operators.is_distinct_from),
+ "like_op": (_boolean_compare, operators.not_like_op),
+ "ilike_op": (_boolean_compare, operators.not_ilike_op),
+ "not_like_op": (_boolean_compare, operators.like_op),
+ "not_ilike_op": (_boolean_compare, operators.ilike_op),
+ "contains_op": (_boolean_compare, operators.not_contains_op),
+ "startswith_op": (_boolean_compare, operators.not_startswith_op),
+ "endswith_op": (_boolean_compare, operators.not_endswith_op),
+ "desc_op": (_scalar, UnaryExpression._create_desc),
+ "asc_op": (_scalar, UnaryExpression._create_asc),
+ "nulls_first_op": (_scalar, UnaryExpression._create_nulls_first),
+ "nulls_last_op": (_scalar, UnaryExpression._create_nulls_last),
+ "in_op": (_in_impl, operators.not_in_op),
+ "not_in_op": (_in_impl, operators.in_op),
+ "is_": (_boolean_compare, operators.is_),
+ "is_not": (_boolean_compare, operators.is_not),
+ "collate": (_collate_impl,),
+ "match_op": (_match_impl,),
+ "not_match_op": (_match_impl,),
+ "distinct_op": (_distinct_impl,),
+ "between_op": (_between_impl,),
+ "not_between_op": (_between_impl,),
+ "neg": (_neg_impl,),
+ "getitem": (_getitem_impl,),
+ "lshift": (_unsupported_impl,),
+ "rshift": (_unsupported_impl,),
+ "contains": (_unsupported_impl,),
+ "regexp_match_op": (_regexp_match_impl,),
+ "not_regexp_match_op": (_regexp_match_impl,),
+ "regexp_replace_op": (_regexp_replace_impl,),
+}
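+
+# Consumption sketch (illustrative; the dispatching code is outside this
+# patch, roughly along the lines of ``type_api`` comparator dispatch):
+#
+#   o = operator_lookup[op.__name__]
+#   return o[0](expr, op, *(others + o[1:]), **kwargs)
+#
+# i.e. the first tuple element is the implementation function, and any
+# remaining elements (such as the negated operator) are passed through to it.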
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
new file mode 100644
index 0000000..07a4d7b
--- /dev/null
+++ b/lib/sqlalchemy/sql/dml.py
@@ -0,0 +1,1514 @@
+# sql/dml.py
+# Copyright (C) 2009-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""
+Provide :class:`_expression.Insert`, :class:`_expression.Update` and
+:class:`_expression.Delete`.
+
+"""
+from sqlalchemy.types import NullType
+from . import coercions
+from . import roles
+from . import util as sql_util
+from .base import _entity_namespace_key
+from .base import _exclusive_against
+from .base import _from_objects
+from .base import _generative
+from .base import ColumnCollection
+from .base import CompileState
+from .base import DialectKWArgs
+from .base import Executable
+from .base import HasCompileState
+from .elements import BooleanClauseList
+from .elements import ClauseElement
+from .elements import Null
+from .selectable import HasCTE
+from .selectable import HasPrefixes
+from .selectable import ReturnsRows
+from .visitors import InternalTraversal
+from .. import exc
+from .. import util
+from ..util import collections_abc
+
+
+class DMLState(CompileState):
+ _no_parameters = True
+ _dict_parameters = None
+ _multi_parameters = None
+ _ordered_values = None
+ _parameter_ordering = None
+ _has_multi_parameters = False
+ isupdate = False
+ isdelete = False
+ isinsert = False
+
+ def __init__(self, statement, compiler, **kw):
+ raise NotImplementedError()
+
+ @classmethod
+ def get_entity_description(cls, statement):
+ return {"name": statement.table.name, "table": statement.table}
+
+ @classmethod
+ def get_returning_column_descriptions(cls, statement):
+ return [
+ {
+ "name": c.key,
+ "type": c.type,
+ "expr": c,
+ }
+ for c in statement._all_selected_columns
+ ]
+
+ @property
+ def dml_table(self):
+ return self.statement.table
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator):
+ return [
+ (
+ coercions.expect(roles.DMLColumnRole, k),
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=NullType(),
+ is_crud=True,
+ ),
+ )
+ for k, v in kv_iterator
+ ]
+
+ def _make_extra_froms(self, statement):
+ froms = []
+
+ all_tables = list(sql_util.tables_from_leftmost(statement.table))
+ seen = {all_tables[0]}
+
+ for crit in statement._where_criteria:
+ for item in _from_objects(crit):
+ if not seen.intersection(item._cloned_set):
+ froms.append(item)
+ seen.update(item._cloned_set)
+
+ froms.extend(all_tables[1:])
+ return froms
+
+ def _process_multi_values(self, statement):
+ if not statement._supports_multi_parameters:
+ raise exc.InvalidRequestError(
+ "%s construct does not support "
+ "multiple parameter sets." % statement.__visit_name__.upper()
+ )
+
+ for parameters in statement._multi_values:
+ multi_parameters = [
+ {
+ c.key: value
+ for c, value in zip(statement.table.c, parameter_set)
+ }
+ if isinstance(parameter_set, collections_abc.Sequence)
+ else parameter_set
+ for parameter_set in parameters
+ ]
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._has_multi_parameters = True
+ self._multi_parameters = multi_parameters
+ self._dict_parameters = self._multi_parameters[0]
+ elif not self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ self._multi_parameters.extend(multi_parameters)
+
+ def _process_values(self, statement):
+ if self._no_parameters:
+ self._has_multi_parameters = False
+ self._dict_parameters = statement._values
+ self._no_parameters = False
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+
+ def _process_ordered_values(self, statement):
+ parameters = statement._ordered_values
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = dict(parameters)
+ self._ordered_values = parameters
+ self._parameter_ordering = [key for key, value in parameters]
+ elif self._has_multi_parameters:
+ self._cant_mix_formats_error()
+ else:
+ raise exc.InvalidRequestError(
+ "Can only invoke ordered_values() once, and not mixed "
+ "with any other values() call"
+ )
+
+ def _process_select_values(self, statement):
+ parameters = {
+ coercions.expect(roles.DMLColumnRole, name, as_key=True): Null()
+ for name in statement._select_names
+ }
+
+ if self._no_parameters:
+ self._no_parameters = False
+ self._dict_parameters = parameters
+ else:
+ # this condition is normally not reachable, as the Insert
+ # does not allow this construction to occur
+ assert False, "This statement already has parameters"
+
+ def _cant_mix_formats_error(self):
+ raise exc.InvalidRequestError(
+ "Can't mix single and multiple VALUES "
+ "formats in one INSERT statement; one style appends to a "
+ "list while the other replaces values, so the intent is "
+ "ambiguous."
+ )
+
+
+@CompileState.plugin_for("default", "insert")
+class InsertDMLState(DMLState):
+ isinsert = True
+
+ include_table_with_column_exprs = False
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+ self.isinsert = True
+ if statement._select_names:
+ self._process_select_values(statement)
+ if statement._values is not None:
+ self._process_values(statement)
+ if statement._multi_values:
+ self._process_multi_values(statement)
+
+ @util.memoized_property
+ def _insert_col_keys(self):
+ # this is also done in crud.py -> _key_getters_for_crud_column
+ return [
+ coercions.expect_as_key(roles.DMLColumnRole, col)
+ for col in self._dict_parameters
+ ]
+
+
+@CompileState.plugin_for("default", "update")
+class UpdateDMLState(DMLState):
+ isupdate = True
+
+ include_table_with_column_exprs = False
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+ self.isupdate = True
+ self._preserve_parameter_order = statement._preserve_parameter_order
+ if statement._ordered_values is not None:
+ self._process_ordered_values(statement)
+ elif statement._values is not None:
+ self._process_values(statement)
+ elif statement._multi_values:
+ self._process_multi_values(statement)
+ self._extra_froms = ef = self._make_extra_froms(statement)
+ self.is_multitable = mt = ef and self._dict_parameters
+ self.include_table_with_column_exprs = (
+ mt and compiler.render_table_with_column_in_update_from
+ )
+
+
+@CompileState.plugin_for("default", "delete")
+class DeleteDMLState(DMLState):
+ isdelete = True
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+
+ self.isdelete = True
+ self._extra_froms = self._make_extra_froms(statement)
+
+
+class UpdateBase(
+ roles.DMLRole,
+ HasCTE,
+ HasCompileState,
+ DialectKWArgs,
+ HasPrefixes,
+ ReturnsRows,
+ Executable,
+ ClauseElement,
+):
+ """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
+
+ __visit_name__ = "update_base"
+
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
+ _hints = util.immutabledict()
+ named_with_column = False
+
+ _return_defaults = False
+ _return_defaults_columns = None
+ _returning = ()
+
+ is_dml = True
+
+ @classmethod
+ def _constructor_20_deprecations(cls, fn_name, clsname, names):
+
+ param_to_method_lookup = dict(
+ whereclause=(
+ "The :paramref:`%(func)s.whereclause` parameter "
+ "will be removed "
+ "in SQLAlchemy 2.0. Please refer to the "
+ ":meth:`%(classname)s.where` method."
+ ),
+ values=(
+ "The :paramref:`%(func)s.values` parameter will be removed "
+ "in SQLAlchemy 2.0. Please refer to the "
+ ":meth:`%(classname)s.values` method."
+ ),
+ bind=(
+ "The :paramref:`%(func)s.bind` parameter will be removed in "
+ "SQLAlchemy 2.0. Please use explicit connection execution."
+ ),
+ inline=(
+ "The :paramref:`%(func)s.inline` parameter will be "
+ "removed in "
+ "SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.inline` method."
+ ),
+ prefixes=(
+ "The :paramref:`%(func)s.prefixes parameter will be "
+ "removed in "
+ "SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.prefix_with` "
+ "method."
+ ),
+ return_defaults=(
+ "The :paramref:`%(func)s.return_defaults` parameter will be "
+ "removed in SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.return_defaults` method."
+ ),
+ returning=(
+ "The :paramref:`%(func)s.returning` parameter will be "
+ "removed in SQLAlchemy 2.0. Please use the "
+ ":meth:`%(classname)s.returning`` method."
+ ),
+ preserve_parameter_order=(
+ "The :paramref:`%(func)s.preserve_parameter_order` parameter "
+ "will be removed in SQLAlchemy 2.0. Use the "
+ ":meth:`%(classname)s.ordered_values` method with a list "
+ "of tuples. "
+ ),
+ )
+
+ return util.deprecated_params(
+ **{
+ name: (
+ "2.0",
+ param_to_method_lookup[name]
+ % {
+ "func": "_expression.%s" % fn_name,
+ "classname": "_expression.%s" % clsname,
+ },
+ )
+ for name in names
+ }
+ )
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ fromclause._columns._populate_separate_keys(
+ col._make_proxy(fromclause) for col in self._returning
+ )
+
+ def params(self, *arg, **kw):
+ """Set the parameters for the statement.
+
+ This method raises ``NotImplementedError`` on the base class,
+ and is overridden by :class:`.ValuesBase` to provide the
+ SET/VALUES clause of UPDATE and INSERT.
+
+ """
+ raise NotImplementedError(
+ "params() is not supported for INSERT/UPDATE/DELETE statements."
+ " To set the values for an INSERT or UPDATE statement, use"
+ " stmt.values(**parameters)."
+ )
+
+ @_generative
+ def with_dialect_options(self, **opt):
+ """Add dialect options to this INSERT/UPDATE/DELETE object.
+
+ e.g.::
+
+ upd = table.update().with_dialect_options(mysql_limit=10)
+
+ .. versionadded:: 1.4 - this method supersedes the dialect options
+ associated with the constructor.
+
+
+ """
+ self._validate_dialect_kwargs(opt)
+
+ def _validate_dialect_kwargs_deprecated(self, dialect_kw):
+ util.warn_deprecated_20(
+ "Passing dialect keyword arguments directly to the "
+ "%s constructor is deprecated and will be removed in SQLAlchemy "
+ "2.0. Please use the ``with_dialect_options()`` method."
+ % (self.__class__.__name__)
+ )
+ self._validate_dialect_kwargs(dialect_kw)
+
+ def bind(self):
+ """Return a 'bind' linked to this :class:`.UpdateBase`
+ or a :class:`_schema.Table` associated with it.
+
+ """
+ return self._bind or self.table.bind
+
+ def _set_bind(self, bind):
+ self._bind = bind
+
+ bind = property(bind, _set_bind)
+
+ @_generative
+ def returning(self, *cols):
+ r"""Add a :term:`RETURNING` or equivalent clause to this statement.
+
+ e.g.:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = (
+ ... table.update()
+ ... .where(table.c.data == "value")
+ ... .values(status="X")
+ ... .returning(table.c.server_flag, table.c.updated_timestamp)
+ ... )
+ >>> print(stmt)
+ UPDATE some_table SET status=:status
+ WHERE some_table.data = :data_1
+ RETURNING some_table.server_flag, some_table.updated_timestamp
+
+ The method may be invoked multiple times to add new entries to the
+ list of expressions to be returned.
+
+ .. versionadded:: 1.4.0b2 The method may be invoked multiple times to
+ add new entries to the list of expressions to be returned.
+
+ The given collection of column expressions should be derived from the
+ table that is the target of the INSERT, UPDATE, or DELETE. While
+ :class:`_schema.Column` objects are typical, the elements can also be
+ expressions:
+
+ .. sourcecode:: pycon+sql
+
+ >>> stmt = table.insert().returning(
+ ... (table.c.first_name + " " + table.c.last_name).label("fullname")
+ ... )
+ >>> print(stmt)
+ INSERT INTO some_table (first_name, last_name)
+ VALUES (:first_name, :last_name)
+ RETURNING some_table.first_name || :first_name_1 || some_table.last_name AS fullname
+
+ Upon compilation, a RETURNING clause, or database equivalent,
+ will be rendered within the statement. For INSERT and UPDATE,
+ the values are the newly inserted/updated values. For DELETE,
+ the values are those of the rows which were deleted.
+
+ Upon execution, the values of the columns to be returned are made
+ available via the result set and can be iterated using
+ :meth:`_engine.CursorResult.fetchone` and similar.
+ For DBAPIs which do not
+ natively support returning values (e.g. cx_oracle), SQLAlchemy will
+ approximate this behavior at the result level so that a reasonable
+ amount of behavioral neutrality is provided.
+
+ Note that not all databases/DBAPIs
+ support RETURNING. For those backends with no support,
+ an exception is raised upon compilation and/or execution.
+ For those who do support it, the functionality across backends
+ varies greatly, including restrictions on executemany()
+ and other statements which return multiple rows. Please
+ read the documentation notes for the database in use in
+ order to determine the availability of RETURNING.
+
+ .. seealso::
+
+ :meth:`.ValuesBase.return_defaults` - an alternative method tailored
+ towards efficient fetching of server-side defaults and triggers
+ for single-row INSERTs or UPDATEs.
+
+ :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial`
+
+ """ # noqa: E501
+ if self._return_defaults:
+ raise exc.InvalidRequestError(
+ "return_defaults() is already configured on this statement"
+ )
+ self._returning += tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
+
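+ # Execution-side sketch (illustrative; assumes a Connection ``conn`` and the
+ # ``table`` / ``some_table`` Table used in the docstring above):
+ #
+ #   stmt = (
+ #       table.update()
+ #       .where(table.c.data == "value")
+ #       .values(status="X")
+ #       .returning(table.c.server_flag, table.c.updated_timestamp)
+ #   )
+ #   for server_flag, updated_timestamp in conn.execute(stmt):
+ #       ...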
+ @property
+ def _all_selected_columns(self):
+ return self._returning
+
+ @property
+ def exported_columns(self):
+ """Return the RETURNING columns as a column collection for this
+ statement.
+
+ .. versionadded:: 1.4
+
+ """
+ # TODO: no coverage here
+ return ColumnCollection(
+ (c.key, c) for c in self._all_selected_columns
+ ).as_immutable()
+
+ @_generative
+ def with_hint(self, text, selectable=None, dialect_name="*"):
+ """Add a table hint for a single table to this
+ INSERT/UPDATE/DELETE statement.
+
+ .. note::
+
+ :meth:`.UpdateBase.with_hint` currently applies only to
+ Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use
+ :meth:`.UpdateBase.prefix_with`.
+
+ The text of the hint is rendered in the appropriate
+ location for the database backend in use, relative
+ to the :class:`_schema.Table` that is the subject of this
+ statement, or optionally to that of the given
+ :class:`_schema.Table` passed as the ``selectable`` argument.
+
+ The ``dialect_name`` option will limit the rendering of a particular
+ hint to a particular backend. For example, to add a hint
+ that only takes effect for SQL Server::
+
+ mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")
+
+ :param text: Text of the hint.
+ :param selectable: optional :class:`_schema.Table` that specifies
+ an element of the FROM clause within an UPDATE or DELETE
+ to be the subject of the hint - applies only to certain backends.
+ :param dialect_name: defaults to ``*``, if specified as the name
+ of a particular dialect, will apply these hints only when
+ that dialect is in use.
+ """
+ if selectable is None:
+ selectable = self.table
+
+ self._hints = self._hints.union({(selectable, dialect_name): text})
+
+ @property
+ def entity_description(self):
+ """Return a :term:`plugin-enabled` description of the table and/or
+ entity which this DML construct is operating against.
+
+ This attribute is generally useful when using the ORM, as an
+ extended structure which includes information about mapped
+ entities is returned. The section :ref:`queryguide_inspection`
+ contains more background.
+
+ For a Core statement, the structure returned by this accessor
+ is derived from the :attr:`.UpdateBase.table` attribute, and
+ refers to the :class:`.Table` being inserted, updated, or deleted::
+
+ >>> stmt = insert(user_table)
+ >>> stmt.entity_description
+ {
+ "name": "user_table",
+ "table": Table("user_table", ...)
+ }
+
+ .. versionadded:: 1.4.33
+
+ .. seealso::
+
+ :attr:`.UpdateBase.returning_column_descriptions`
+
+ :attr:`.Select.column_descriptions` - entity information for
+ a :func:`.select` construct
+
+ :ref:`queryguide_inspection` - ORM background
+
+ """
+ meth = DMLState.get_plugin_class(self).get_entity_description
+ return meth(self)
+
+ @property
+ def returning_column_descriptions(self):
+ """Return a :term:`plugin-enabled` description of the columns
+ which this DML construct is RETURNING against, in other words
+ the expressions established as part of :meth:`.UpdateBase.returning`.
+
+ This attribute is generally useful when using the ORM, as an
+ extended structure which includes information about mapped
+ entities is returned. The section :ref:`queryguide_inspection`
+ contains more background.
+
+ For a Core statement, the structure returned by this accessor is
+ derived from the same objects that are returned by the
+ :attr:`.UpdateBase.exported_columns` accessor::
+
+ >>> stmt = insert(user_table).returning(user_table.c.id, user_table.c.name)
+ >>> stmt.returning_column_descriptions
+ [
+ {
+ "name": "id",
+ "type": Integer,
+ "expr": Column("id", Integer(), table=<user>, ...)
+ },
+ {
+ "name": "name",
+ "type": String(),
+ "expr": Column("name", String(), table=<user>, ...)
+ },
+ ]
+
+ .. versionadded:: 1.4.33
+
+ .. seealso::
+
+ :attr:`.UpdateBase.entity_description`
+
+ :attr:`.Select.column_descriptions` - entity information for
+ a :func:`.select` construct
+
+ :ref:`queryguide_inspection` - ORM background
+
+ """ # noqa: E501
+ meth = DMLState.get_plugin_class(
+ self
+ ).get_returning_column_descriptions
+ return meth(self)
+
+
+class ValuesBase(UpdateBase):
+ """Supplies support for :meth:`.ValuesBase.values` to
+ INSERT and UPDATE constructs."""
+
+ __visit_name__ = "values_base"
+
+ _supports_multi_parameters = False
+ _preserve_parameter_order = False
+ select = None
+ _post_values_clause = None
+
+ _values = None
+ _multi_values = ()
+ _ordered_values = None
+ _select_names = None
+
+ _returning = ()
+
+ def __init__(self, table, values, prefixes):
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
+ if values is not None:
+ self.values.non_generative(self, values)
+ if prefixes:
+ self._setup_prefixes(prefixes)
+
+ @_generative
+ @_exclusive_against(
+ "_select_names",
+ "_ordered_values",
+ msgs={
+ "_select_names": "This construct already inserts from a SELECT",
+ "_ordered_values": "This statement already has ordered "
+ "values present",
+ },
+ )
+ def values(self, *args, **kwargs):
+ r"""Specify a fixed VALUES clause for an INSERT statement, or the SET
+ clause for an UPDATE.
+
+ Note that the :class:`_expression.Insert` and
+ :class:`_expression.Update`
+ constructs support
+ per-execution time formatting of the VALUES and/or SET clauses,
+ based on the arguments passed to :meth:`_engine.Connection.execute`.
+ However, the :meth:`.ValuesBase.values` method can be used to "fix" a
+ particular set of parameters into the statement.
+
+ Multiple calls to :meth:`.ValuesBase.values` will produce a new
+ construct, each one with the parameter list modified to include
+ the new parameters sent. In the typical case of a single
+ dictionary of parameters, the newly passed keys will replace
+ the same keys in the previous construct. In the case of a list-based
+ "multiple values" construct, each new list of values is extended
+ onto the existing list of values.
+
+ :param \**kwargs: key value pairs representing the string key
+ of a :class:`_schema.Column`
+ mapped to the value to be rendered into the
+ VALUES or SET clause::
+
+ users.insert().values(name="some name")
+
+ users.update().where(users.c.id==5).values(name="some name")
+
+ :param \*args: As an alternative to passing key/value parameters,
+ a dictionary, tuple, or list of dictionaries or tuples can be passed
+ as a single positional argument in order to form the VALUES or
+ SET clause of the statement. The forms that are accepted vary
+ based on whether this is an :class:`_expression.Insert` or an
+ :class:`_expression.Update` construct.
+
+ For either an :class:`_expression.Insert` or
+ :class:`_expression.Update`
+ construct, a single dictionary can be passed, which works the same as
+ that of the kwargs form::
+
+ users.insert().values({"name": "some name"})
+
+ users.update().values({"name": "some new name"})
+
+ Also for either form but more typically for the
+ :class:`_expression.Insert` construct, a tuple that contains an
+ entry for every column in the table is also accepted::
+
+ users.insert().values((5, "some name"))
+
+ The :class:`_expression.Insert` construct also supports being
+ passed a list of dictionaries or full-table-tuples, which on the
+ server will render the less common SQL syntax of "multiple values" -
+ this syntax is supported on backends such as SQLite, PostgreSQL,
+ MySQL, but not necessarily others::
+
+ users.insert().values([
+ {"name": "some name"},
+ {"name": "some other name"},
+ {"name": "yet another name"},
+ ])
+
+ The above form would render a multiple VALUES statement similar to::
+
+ INSERT INTO users (name) VALUES
+ (:name_1),
+ (:name_2),
+ (:name_3)
+
+ It is essential to note that **passing multiple values is
+ NOT the same as using traditional executemany() form**. The above
+ syntax is a **special** syntax not typically used. To emit an
+ INSERT statement against multiple rows, the normal method is
+ to pass a multiple values list to the
+ :meth:`_engine.Connection.execute`
+ method, which is supported by all database backends and is generally
+ more efficient for a very large number of parameters.
+
+ .. seealso::
+
+ :ref:`tutorial_multiple_parameters` - an introduction to
+ the traditional Core method of multiple parameter set
+ invocation for INSERTs and other statements.
+
+ .. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES
+ clause, even a list of length one,
+ implies that the :paramref:`_expression.Insert.inline`
+ flag is set to
+ True, indicating that the statement will not attempt to fetch
+ the "last inserted primary key" or other defaults. The
+ statement deals with an arbitrary number of rows, so the
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ accessor does not
+ apply.
+
+ .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports
+ columns with Python side default values and callables in the
+ same way as that of an "executemany" style of invocation; the
+ callable is invoked for each row. See :ref:`bug_3288`
+ for other details.
+
+ The UPDATE construct also supports rendering the SET parameters
+ in a specific order. For this feature refer to the
+ :meth:`_expression.Update.ordered_values` method.
+
+ .. seealso::
+
+ :meth:`_expression.Update.ordered_values`
+
+
+ """
+ if args:
+ # positional case. this is currently expensive. we don't
+ # yet have positional-only args so we have to check the length.
+ # then we need to check multiparams vs. single dictionary.
+ # since the parameter format is needed in order to determine
+ # a cache key, we need to determine this up front.
+ arg = args[0]
+
+ if kwargs:
+ raise exc.ArgumentError(
+ "Can't pass positional and kwargs to values() "
+ "simultaneously"
+ )
+ elif len(args) > 1:
+ raise exc.ArgumentError(
+ "Only a single dictionary/tuple or list of "
+ "dictionaries/tuples is accepted positionally."
+ )
+
+ elif not self._preserve_parameter_order and isinstance(
+ arg, collections_abc.Sequence
+ ):
+
+ if arg and isinstance(arg[0], (list, dict, tuple)):
+ self._multi_values += (arg,)
+ return
+
+ # tuple values
+ arg = {c.key: value for c, value in zip(self.table.c, arg)}
+ elif self._preserve_parameter_order and not isinstance(
+ arg, collections_abc.Sequence
+ ):
+ raise ValueError(
+ "When preserve_parameter_order is True, "
+ "values() only accepts a list of 2-tuples"
+ )
+
+ else:
+ # kwarg path. this is the most common path for non-multi-params
+ # so this is fairly quick.
+ arg = kwargs
+ if args:
+ raise exc.ArgumentError(
+ "Only a single dictionary/tuple or list of "
+ "dictionaries/tuples is accepted positionally."
+ )
+
+ # for top level values(), convert literals to anonymous bound
+ # parameters at statement construction time, so that these values can
+ # participate in the cache key process like any other ClauseElement.
+ # crud.py now intercepts bound parameters with unique=True from here
+ # and ensures they get the "crud"-style name when rendered.
+
+ kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
+
+ if self._preserve_parameter_order:
+ self._ordered_values = kv_generator(self, arg)
+ else:
+ arg = {k: v for k, v in kv_generator(self, arg.items())}
+ if self._values:
+ self._values = self._values.union(arg)
+ else:
+ self._values = util.immutabledict(arg)
+
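+ # Merge-behavior sketch (illustrative; assumes a ``users`` Table with
+ # ``name`` and ``fullname`` columns):
+ #
+ #   stmt = users.insert().values(name="x").values(fullname="x y")
+ #   # -> both keys end up in the VALUES clause; repeating an existing key
+ #   #    in a later .values() call replaces the earlier value.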
+ @_generative
+ @_exclusive_against(
+ "_returning",
+ msgs={
+ "_returning": "RETURNING is already configured on this statement"
+ },
+ defaults={"_returning": _returning},
+ )
+ def return_defaults(self, *cols):
+ """Make use of a :term:`RETURNING` clause for the purpose
+ of fetching server-side expressions and defaults.
+
+ E.g.::
+
+ stmt = table.insert().values(data='newdata').return_defaults()
+
+ result = connection.execute(stmt)
+
+ server_created_at = result.returned_defaults['created_at']
+
+ When used against a backend that supports RETURNING, all column
+ values generated by SQL expression or server-side-default will be
+ added to any existing RETURNING clause, provided that
+ :meth:`.UpdateBase.returning` is not used simultaneously. The column
+ values will then be available on the result using the
+ :attr:`_engine.CursorResult.returned_defaults` accessor as
+ a dictionary,
+ referring to values keyed to the :class:`_schema.Column`
+ object as well as
+ its ``.key``.
+
+ This method differs from :meth:`.UpdateBase.returning` in these ways:
+
+ 1. :meth:`.ValuesBase.return_defaults` is only intended for use with an
+ INSERT or an UPDATE statement that matches exactly one row per
+ parameter set. While the RETURNING construct in the general sense
+ supports multiple rows for a multi-row UPDATE or DELETE statement,
+ or for special cases of INSERT that return multiple rows (e.g.
+ INSERT from SELECT, multi-valued VALUES clause),
+ :meth:`.ValuesBase.return_defaults` is intended only for an
+ "ORM-style" single-row INSERT/UPDATE statement. The row
+ returned by the statement is also consumed implicitly when
+ :meth:`.ValuesBase.return_defaults` is used. By contrast,
+ :meth:`.UpdateBase.returning` leaves the RETURNING result-set intact
+ with a collection of any number of rows.
+
+ 2. It is compatible with the existing logic to fetch auto-generated
+ primary key values, also known as "implicit returning". Backends
+ that support RETURNING will automatically make use of RETURNING in
+ order to fetch the value of newly generated primary keys; while the
+ :meth:`.UpdateBase.returning` method circumvents this behavior,
+ :meth:`.ValuesBase.return_defaults` leaves it intact.
+
+ 3. It can be called against any backend. Backends that don't support
+ RETURNING will skip the usage of the feature, rather than raising
+ an exception. The return value of
+ :attr:`_engine.CursorResult.returned_defaults` will be ``None``.
+
+ 4. An INSERT statement invoked with executemany() is supported if the
+ backend database driver supports the
+ ``insert_executemany_returning`` feature; currently this includes
+ PostgreSQL with psycopg2. When executemany is used, the
+ :attr:`_engine.CursorResult.returned_defaults_rows` and
+ :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
+ will return the inserted defaults and primary keys.
+
+ .. versionadded:: 1.4
+
+ :meth:`.ValuesBase.return_defaults` is used by the ORM to provide
+ an efficient implementation for the ``eager_defaults`` feature of
+ :func:`.mapper`.
+
+ :param cols: optional list of column key names or
+ :class:`_schema.Column`
+ objects. If omitted, all column expressions evaluated on the server
+ are added to the returning list.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :meth:`.UpdateBase.returning`
+
+ :attr:`_engine.CursorResult.returned_defaults`
+
+ :attr:`_engine.CursorResult.returned_defaults_rows`
+
+ :attr:`_engine.CursorResult.inserted_primary_key`
+
+ :attr:`_engine.CursorResult.inserted_primary_key_rows`
+
+ """
+ self._return_defaults = True
+ self._return_defaults_columns = cols
+
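+ # Result-side sketch (illustrative; continues the docstring example above,
+ # assuming the backend supports RETURNING):
+ #
+ #   stmt = table.insert().values(data="newdata").return_defaults()
+ #   result = connection.execute(stmt)
+ #   result.returned_defaults["created_at"]   # server-generated default
+ #   result.inserted_primary_key              # implicit returning still works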
+
+class Insert(ValuesBase):
+ """Represent an INSERT construct.
+
+ The :class:`_expression.Insert` object is created using the
+ :func:`_expression.insert()` function.
+
+ """
+
+ __visit_name__ = "insert"
+
+ _supports_multi_parameters = True
+
+ select = None
+ include_insert_from_select_defaults = False
+
+ is_insert = True
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_select_names", InternalTraversal.dp_string_list),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_multi_values", InternalTraversal.dp_dml_multi_values),
+ ("select", InternalTraversal.dp_clauseelement),
+ ("_post_values_clause", InternalTraversal.dp_clauseelement),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ("_return_defaults", InternalTraversal.dp_boolean),
+ (
+ "_return_defaults_columns",
+ InternalTraversal.dp_clauseelement_list,
+ ),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
+ + HasCTE._has_ctes_traverse_internals
+ )
+
+ @ValuesBase._constructor_20_deprecations(
+ "insert",
+ "Insert",
+ [
+ "values",
+ "inline",
+ "bind",
+ "prefixes",
+ "returning",
+ "return_defaults",
+ ],
+ )
+ def __init__(
+ self,
+ table,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ **dialect_kw
+ ):
+ """Construct an :class:`_expression.Insert` object.
+
+ E.g.::
+
+ from sqlalchemy import insert
+
+ stmt = (
+ insert(user_table).
+ values(name='username', fullname='Full Username')
+ )
+
+ Similar functionality is available via the
+ :meth:`_expression.TableClause.insert` method on
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial`
+
+
+ :param table: :class:`_expression.TableClause`
+ which is the subject of the
+ insert.
+
+ :param values: collection of values to be inserted; see
+ :meth:`_expression.Insert.values`
+ for a description of allowed formats here.
+ Can be omitted entirely; a :class:`_expression.Insert` construct
+ will also dynamically render the VALUES clause at execution time
+ based on the parameters passed to :meth:`_engine.Connection.execute`.
+
+ :param inline: if True, no attempt will be made to retrieve the
+ SQL-generated default values to be provided within the statement;
+ in particular,
+ this allows SQL expressions to be rendered 'inline' within the
+ statement without the need to pre-execute them beforehand; for
+ backends that support "returning", this turns off the "implicit
+ returning" feature for the statement.
+
+ If both :paramref:`_expression.Insert.values` and compile-time bind
+ parameters are present, the compile-time bind parameters override the
+ information specified within :paramref:`_expression.Insert.values` on a
+ per-key basis.
+
+ The keys within :paramref:`_expression.Insert.values` can be either
+ :class:`~sqlalchemy.schema.Column` objects or their string
+ identifiers. Each key may reference one of:
+
+ * a literal data value (i.e. string, number, etc.);
+ * a Column object;
+ * a SELECT statement.
+
+ If a ``SELECT`` statement is specified which references this
+ ``INSERT`` statement's table, the statement will be correlated
+ against the ``INSERT`` statement.
+
+ .. seealso::
+
+ :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial`
+
+ """
+ super(Insert, self).__init__(table, values, prefixes)
+ self._bind = bind
+ self._inline = inline
+ if returning:
+ self._returning = returning
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
+
+ if return_defaults:
+ self._return_defaults = True
+ if not isinstance(return_defaults, bool):
+ self._return_defaults_columns = return_defaults
+
+ @_generative
+ def inline(self):
+ """Make this :class:`_expression.Insert` construct "inline" .
+
+ When set, no attempt will be made to retrieve the
+ SQL-generated default values to be provided within the statement;
+ in particular,
+ this allows SQL expressions to be rendered 'inline' within the
+ statement without the need to pre-execute them beforehand; for
+ backends that support "returning", this turns off the "implicit
+ returning" feature for the statement.
+
+
+ .. versionchanged:: 1.4 the :paramref:`_expression.Insert.inline`
+ parameter
+ is now superseded by the :meth:`_expression.Insert.inline` method.
+
+ """
+ self._inline = True
+
+ @_generative
+ def from_select(self, names, select, include_defaults=True):
+ """Return a new :class:`_expression.Insert` construct which represents
+ an ``INSERT...FROM SELECT`` statement.
+
+ e.g.::
+
+ sel = select(table1.c.a, table1.c.b).where(table1.c.c > 5)
+ ins = table2.insert().from_select(['a', 'b'], sel)
+
+ :param names: a sequence of string column names or
+ :class:`_schema.Column`
+ objects representing the target columns.
+ :param select: a :func:`_expression.select` construct,
+ :class:`_expression.FromClause`
+ or other construct which resolves into a
+ :class:`_expression.FromClause`,
+ such as an ORM :class:`_query.Query` object, etc. The order of
+ columns returned from this FROM clause should correspond to the
+ order of columns sent as the ``names`` parameter; while this
+ is not checked before passing along to the database, the database
+ would normally raise an exception if these column lists don't
+ correspond.
+ :param include_defaults: if True, non-server default values and
+ SQL expressions as specified on :class:`_schema.Column` objects
+ (as documented in :ref:`metadata_defaults_toplevel`) not
+ otherwise specified in the list of names will be rendered
+ into the INSERT and SELECT statements, so that these values are also
+ included in the data to be inserted.
+
+ .. note:: A Python-side default that uses a Python callable function
+ will only be invoked **once** for the whole statement, and **not
+ per row**.
+
+ .. versionadded:: 1.0.0 - :meth:`_expression.Insert.from_select`
+ now renders
+ Python-side and SQL expression column defaults into the
+ SELECT statement for columns otherwise not included in the
+ list of column names.
+
+ .. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
+ implies that the :paramref:`_expression.insert.inline`
+ flag is set to
+ True, indicating that the statement will not attempt to fetch
+ the "last inserted primary key" or other defaults. The statement
+ deals with an arbitrary number of rows, so the
+ :attr:`_engine.CursorResult.inserted_primary_key`
+ accessor does not apply.
+
+ """
+
+ if self._values:
+ raise exc.InvalidRequestError(
+ "This construct already inserts value expressions"
+ )
+
+ self._select_names = names
+ self._inline = True
+ self.include_insert_from_select_defaults = include_defaults
+ self.select = coercions.expect(roles.DMLSelectRole, select)
+
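+ # Rendered-SQL sketch for the docstring example above (illustrative; exact
+ # output varies by dialect and by any defaults pulled in via
+ # ``include_defaults``):
+ #
+ #   INSERT INTO table2 (a, b)
+ #   SELECT table1.a, table1.b
+ #   FROM table1
+ #   WHERE table1.c > :c_1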
+
+class DMLWhereBase(object):
+ _where_criteria = ()
+
+ @_generative
+ def where(self, *whereclause):
+ """Return a new construct with the given expression(s) added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
+ Both :meth:`_dml.Update.where` and :meth:`_dml.Delete.where`
+ support multiple-table forms, including database-specific
+ ``UPDATE...FROM`` as well as ``DELETE..USING``. For backends that
+ don't have multiple-table support, a backend agnostic approach
+ to using multiple tables is to make use of correlated subqueries.
+ See the linked tutorial sections below for examples.
+
+ .. seealso::
+
+ :ref:`tutorial_correlated_updates`
+
+ :ref:`tutorial_update_from`
+
+ :ref:`tutorial_multi_table_deletes`
+
+ """
+
+ for criterion in whereclause:
+ where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ self._where_criteria += (where_criteria,)
+
+ def filter(self, *criteria):
+ """A synonym for the :meth:`_dml.DMLWhereBase.where` method.
+
+ .. versionadded:: 1.4
+
+ """
+
+ return self.where(*criteria)
+
+ def _filter_by_zero(self):
+ return self.table
+
+ def filter_by(self, **kwargs):
+ r"""apply the given filtering criterion as a WHERE clause
+ to this select.
+
+ """
+ from_entity = self._filter_by_zero()
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
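+ # Usage sketch (illustrative; assumes a ``users`` Table with a ``name``
+ # column):
+ #
+ #   stmt = users.delete().filter_by(name="jack")
+ #   # equivalent to: users.delete().where(users.c.name == "jack")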
+ @property
+ def whereclause(self):
+ """Return the completed WHERE clause for this :class:`.DMLWhereBase`
+ statement.
+
+ This assembles the current collection of WHERE criteria
+ into a single :class:`_expression.BooleanClauseList` construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
+
+class Update(DMLWhereBase, ValuesBase):
+ """Represent an Update construct.
+
+ The :class:`_expression.Update` object is created using the
+ :func:`_expression.update()` function.
+
+ """
+
+ __visit_name__ = "update"
+
+ is_update = True
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_inline", InternalTraversal.dp_boolean),
+ ("_ordered_values", InternalTraversal.dp_dml_ordered_values),
+ ("_values", InternalTraversal.dp_dml_values),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ("_return_defaults", InternalTraversal.dp_boolean),
+ (
+ "_return_defaults_columns",
+ InternalTraversal.dp_clauseelement_list,
+ ),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
+ + HasCTE._has_ctes_traverse_internals
+ )
+
+ @ValuesBase._constructor_20_deprecations(
+ "update",
+ "Update",
+ [
+ "whereclause",
+ "values",
+ "inline",
+ "bind",
+ "prefixes",
+ "returning",
+ "return_defaults",
+ "preserve_parameter_order",
+ ],
+ )
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ preserve_parameter_order=False,
+ **dialect_kw
+ ):
+ r"""Construct an :class:`_expression.Update` object.
+
+ E.g.::
+
+ from sqlalchemy import update
+
+ stmt = (
+ update(user_table).
+ where(user_table.c.id == 5).
+ values(name='user #5')
+ )
+
+ Similar functionality is available via the
+ :meth:`_expression.TableClause.update` method on
+ :class:`_schema.Table`.
+
+ :param table: A :class:`_schema.Table`
+ object representing the database
+ table to be updated.
+
+ :param whereclause: Optional SQL expression describing the ``WHERE``
+ condition of the ``UPDATE`` statement; is equivalent to using the
+ more modern :meth:`~Update.where()` method to specify the ``WHERE``
+ clause.
+
+ :param values:
+ Optional dictionary which specifies the ``SET`` conditions of the
+ ``UPDATE``. If left as ``None``, the ``SET``
+ conditions are determined from those parameters passed to the
+ statement during the execution and/or compilation of the
+ statement. When compiled standalone without any parameters,
+ the ``SET`` clause is generated for all columns.
+
+ Modern applications may prefer to use the generative
+ :meth:`_expression.Update.values` method to set the values of the
+ UPDATE statement.
+
+ :param inline:
+ if True, SQL defaults present on :class:`_schema.Column` objects via
+ the ``default`` keyword will be compiled 'inline' into the statement
+ and not pre-executed. This means that their values will not
+ be available in the dictionary returned from
+ :meth:`_engine.CursorResult.last_updated_params`.
+
+ :param preserve_parameter_order: if True, the update statement is
+ expected to receive parameters **only** via the
+ :meth:`_expression.Update.values` method,
+ and they must be passed as a Python
+ ``list`` of 2-tuples. The rendered UPDATE statement will emit the SET
+ clause for each referenced column maintaining this order.
+
+ .. versionadded:: 1.0.10
+
+ .. seealso::
+
+ :ref:`updates_order_parameters` - illustrates the
+ :meth:`_expression.Update.ordered_values` method.
+
+ If both ``values`` and compile-time bind parameters are present, the
+ compile-time bind parameters override the information specified
+ within ``values`` on a per-key basis.
+
+ The keys within ``values`` can be either :class:`_schema.Column`
+ objects or their string identifiers (specifically the "key" of the
+ :class:`_schema.Column`, normally but not necessarily equivalent to
+ its "name"). Normally, the
+ :class:`_schema.Column` objects used here are expected to be
+ part of the target :class:`_schema.Table` that is the table
+ to be updated. However when using MySQL, a multiple-table
+ UPDATE statement can refer to columns from any of
+ the tables referred to in the WHERE clause.
+
+ The values referred to in ``values`` are typically:
+
+ * a literal data value (i.e. string, number, etc.)
+ * a SQL expression, such as a related :class:`_schema.Column`,
+ a scalar-returning :func:`_expression.select` construct,
+ etc.
+
+ When combining :func:`_expression.select` constructs within the
+ values clause of an :func:`_expression.update`
+ construct, the subquery represented
+ by the :func:`_expression.select` should be *correlated* to the
+ parent table, that is, providing criterion which links the table inside
+ the subquery to the outer table being updated::
+
+ users.update().values(
+ name=select(addresses.c.email_address).\
+ where(addresses.c.user_id==users.c.id).\
+ scalar_subquery()
+ )
+
+ .. seealso::
+
+ :ref:`inserts_and_updates` - SQL Expression
+ Language Tutorial
+
+
+ """
+ self._preserve_parameter_order = preserve_parameter_order
+ super(Update, self).__init__(table, values, prefixes)
+ self._bind = bind
+ if returning:
+ self._returning = returning
+ if whereclause is not None:
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
+ )
+ self._inline = inline
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
+ self._return_defaults = return_defaults
+
+ @_generative
+ def ordered_values(self, *args):
+ """Specify the VALUES clause of this UPDATE statement with an explicit
+ parameter ordering that will be maintained in the SET clause of the
+ resulting UPDATE statement.
+
+ E.g.::
+
+ stmt = table.update().ordered_values(
+ ("name", "ed"), ("ident": "foo")
+ )
+
+ .. seealso::
+
+ :ref:`tutorial_parameter_ordered_updates` - full example of the
+ :meth:`_expression.Update.ordered_values` method.
+
+ .. versionchanged:: 1.4 The :meth:`_expression.Update.ordered_values`
+ method
+ supersedes the
+ :paramref:`_expression.update.preserve_parameter_order`
+ parameter, which will be removed in SQLAlchemy 2.0.
+
+ """
+ if self._values:
+ raise exc.ArgumentError(
+ "This statement already has values present"
+ )
+ elif self._ordered_values:
+ raise exc.ArgumentError(
+ "This statement already has ordered values present"
+ )
+
+ kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
+ self._ordered_values = kv_generator(self, args)
+
+ @_generative
+ def inline(self):
+ """Make this :class:`_expression.Update` construct "inline" .
+
+ When set, SQL defaults present on :class:`_schema.Column`
+ objects via the
+ ``default`` keyword will be compiled 'inline' into the statement and
+ not pre-executed. This means that their values will not be available
+ in the dictionary returned from
+ :meth:`_engine.CursorResult.last_updated_params`.
+
+ .. versionchanged:: 1.4 the :paramref:`_expression.update.inline`
+ parameter
+ is now superseded by the :meth:`_expression.Update.inline` method.
+
+ """
+ self._inline = True
+
+
+class Delete(DMLWhereBase, UpdateBase):
+ """Represent a DELETE construct.
+
+ The :class:`_expression.Delete` object is created using the
+ :func:`_expression.delete()` function.
+
+ """
+
+ __visit_name__ = "delete"
+
+ is_delete = True
+
+ _traverse_internals = (
+ [
+ ("table", InternalTraversal.dp_clauseelement),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_list),
+ ("_returning", InternalTraversal.dp_clauseelement_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + DialectKWArgs._dialect_kwargs_traverse_internals
+ + Executable._executable_traverse_internals
+ + HasCTE._has_ctes_traverse_internals
+ )
+
+ @ValuesBase._constructor_20_deprecations(
+ "delete",
+ "Delete",
+ ["whereclause", "values", "bind", "prefixes", "returning"],
+ )
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ bind=None,
+ returning=None,
+ prefixes=None,
+ **dialect_kw
+ ):
+ r"""Construct :class:`_expression.Delete` object.
+
+ E.g.::
+
+ from sqlalchemy import delete
+
+ stmt = (
+ delete(user_table).
+ where(user_table.c.id == 5)
+ )
+
+ Similar functionality is available via the
+ :meth:`_expression.TableClause.delete` method on
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :ref:`inserts_and_updates` - in the
+ :ref:`1.x tutorial <sqlexpression_toplevel>`
+
+ :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial`
+
+
+ :param table: The table to delete rows from.
+
+ :param whereclause: Optional SQL expression describing the ``WHERE``
+ condition of the ``DELETE`` statement; is equivalent to using the
+ more modern :meth:`~Delete.where()` method to specify the ``WHERE``
+ clause.
+
+ .. seealso::
+
+ :ref:`deletes` - SQL Expression Tutorial
+
+ """
+ self._bind = bind
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
+ if returning:
+ self._returning = returning
+
+ if prefixes:
+ self._setup_prefixes(prefixes)
+
+ if whereclause is not None:
+ self._where_criteria += (
+ coercions.expect(roles.WhereHavingRole, whereclause),
+ )
+
+ if dialect_kw:
+ self._validate_dialect_kwargs_deprecated(dialect_kw)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
new file mode 100644
index 0000000..268c0d6
--- /dev/null
+++ b/lib/sqlalchemy/sql/elements.py
@@ -0,0 +1,5415 @@
+# sql/elements.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Core SQL expression elements, including :class:`_expression.ClauseElement`,
+:class:`_expression.ColumnElement`, and derived classes.
+
+"""
+
+from __future__ import unicode_literals
+
+import itertools
+import operator
+import re
+
+from . import coercions
+from . import operators
+from . import roles
+from . import traversals
+from . import type_api
+from .annotation import Annotated
+from .annotation import SupportsWrappingAnnotations
+from .base import _clone
+from .base import _generative
+from .base import Executable
+from .base import HasMemoized
+from .base import Immutable
+from .base import NO_ARG
+from .base import PARSE_AUTOCOMMIT
+from .base import SingletonConstant
+from .coercions import _document_text_coercion
+from .traversals import HasCopyInternals
+from .traversals import MemoizedHasCacheKey
+from .traversals import NO_CACHE
+from .visitors import cloned_traverse
+from .visitors import InternalTraversal
+from .visitors import traverse
+from .visitors import Traversible
+from .. import exc
+from .. import inspection
+from .. import util
+
+
+def collate(expression, collation):
+ """Return the clause ``expression COLLATE collation``.
+
+ e.g.::
+
+ collate(mycolumn, 'utf8_bin')
+
+ produces::
+
+ mycolumn COLLATE utf8_bin
+
+ The collation expression is also quoted if it is a case sensitive
+ identifier, e.g. contains uppercase characters.
+
+ .. versionchanged:: 1.2 quoting is automatically applied to COLLATE
+ expressions if they are case sensitive.
+
+ """
+
+ expr = coercions.expect(roles.ExpressionElementRole, expression)
+ return BinaryExpression(
+ expr, CollationClause(collation), operators.collate, type_=expr.type
+ )
+
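+# Usage sketch (illustrative; ``mytable`` is an assumed Table and the
+# collation name is backend-specific):
+#
+#   stmt = select(mytable).order_by(collate(mytable.c.firstname, "utf8_bin"))
+#   # ... ORDER BY mytable.firstname COLLATE utf8_bin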
+
+def between(expr, lower_bound, upper_bound, symmetric=False):
+ """Produce a ``BETWEEN`` predicate clause.
+
+ E.g.::
+
+ from sqlalchemy import between
+ stmt = select(users_table).where(between(users_table.c.id, 5, 7))
+
+ Would produce SQL resembling::
+
+ SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2
+
+ The :func:`.between` function is a standalone version of the
+ :meth:`_expression.ColumnElement.between` method available on all
+ SQL expressions, as in::
+
+ stmt = select(users_table).where(users_table.c.id.between(5, 7))
+
+ All arguments passed to :func:`.between`, including the left side
+ column expression, are coerced from Python scalar values if the
+ value is not a :class:`_expression.ColumnElement` subclass.
+ For example,
+ three fixed values can be compared as in::
+
+ print(between(5, 3, 7))
+
+ Which would produce::
+
+ :param_1 BETWEEN :param_2 AND :param_3
+
+ :param expr: a column expression, typically a
+ :class:`_expression.ColumnElement`
+ instance or alternatively a Python scalar expression to be coerced
+ into a column expression, serving as the left side of the ``BETWEEN``
+ expression.
+
+ :param lower_bound: a column or Python scalar expression serving as the
+ lower bound of the right side of the ``BETWEEN`` expression.
+
+ :param upper_bound: a column or Python scalar expression serving as the
+ upper bound of the right side of the ``BETWEEN`` expression.
+
+ :param symmetric: if True, will render " BETWEEN SYMMETRIC ". Note
+ that not all databases support this syntax.
+
+ .. versionadded:: 0.9.5
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.between`
+
+ """
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+ return expr.between(lower_bound, upper_bound, symmetric=symmetric)
+
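+# Sketch of the ``symmetric`` flag (illustrative; not all backends accept it):
+#
+#   print(between(users_table.c.id, 7, 5, symmetric=True))
+#   # roughly: user.id BETWEEN SYMMETRIC :id_1 AND :id_2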
+
+def literal(value, type_=None):
+ r"""Return a literal clause, bound to a bind parameter.
+
+ Literal clauses are created automatically when non-
+ :class:`_expression.ClauseElement` objects (such as strings, ints, dates,
+ etc.) are
+ used in a comparison operation with a :class:`_expression.ColumnElement`
+ subclass,
+ such as a :class:`~sqlalchemy.schema.Column` object. Use this function
+ to force the generation of a literal clause, which will be created as a
+ :class:`BindParameter` with a bound value.
+
+ :param value: the value to be bound. Can be any Python object supported by
+ the underlying DB-API, or is translatable via the given type argument.
+
+ :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which
+ will provide bind-parameter translation for this literal.
+
+ """
+ return coercions.expect(roles.LiteralValueRole, value, type_=type_)
+
+
+def outparam(key, type_=None):
+ r"""Create an 'OUT' parameter for usage in functions (stored procedures),
+ for databases which support them.
+
+ The ``outparam`` can be used like a regular function parameter.
+ The "output" value will be available from the
+ :class:`~sqlalchemy.engine.CursorResult` object via its ``out_parameters``
+ attribute, which returns a dictionary containing the values.
+
+ """
+ return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
+
+
+def not_(clause):
+ """Return a negation of the given clause, i.e. ``NOT(clause)``.
+
+ The ``~`` operator is also overloaded on all
+ :class:`_expression.ColumnElement` subclasses to produce the
+ same result.
+
+ """
+ return operators.inv(coercions.expect(roles.ExpressionElementRole, clause))
+
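+# Usage sketch (illustrative; assumes a ``users`` Table):
+#
+#   stmt = select(users).where(not_(users.c.name == "jack"))
+#   # SQLAlchemy knows the negation of ``==`` and typically renders
+#   # ``users.name != :name_1`` rather than a wrapping NOT.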
+
+@inspection._self_inspects
+class ClauseElement(
+ roles.SQLRole,
+ SupportsWrappingAnnotations,
+ MemoizedHasCacheKey,
+ HasCopyInternals,
+ Traversible,
+):
+ """Base class for elements of a programmatically constructed SQL
+ expression.
+
+ """
+
+ __visit_name__ = "clause"
+
+ _propagate_attrs = util.immutabledict()
+ """like annotations, however these propagate outwards liberally
+ as SQL constructs are built, and are set up at construction time.
+
+ """
+
+ supports_execution = False
+
+ stringify_dialect = "default"
+
+ _from_objects = []
+ bind = None
+ description = None
+ _is_clone_of = None
+
+ is_clause_element = True
+ is_selectable = False
+
+ _is_textual = False
+ _is_from_clause = False
+ _is_returns_rows = False
+ _is_text_clause = False
+ _is_from_container = False
+ _is_select_container = False
+ _is_select_statement = False
+ _is_bind_parameter = False
+ _is_clause_list = False
+ _is_lambda_element = False
+ _is_singleton_constant = False
+ _is_immutable = False
+ _is_star = False
+
+ _order_by_label_element = None
+
+ _cache_key_traversal = None
+
+ def _set_propagate_attrs(self, values):
+ # usually, self._propagate_attrs is empty here. one case where it's
+ # not is a subquery against ORM select, that is then pulled as a
+ # property of an aliased class. should all be good
+
+ # assert not self._propagate_attrs
+
+ self._propagate_attrs = util.immutabledict(values)
+ return self
+
+ def _clone(self, **kw):
+ """Create a shallow copy of this ClauseElement.
+
+ This method may be used by a generative API. It's also used as
+ part of the "deep" copy afforded by a traversal that combines
+ the _copy_internals() method.
+
+ """
+ skip = self._memoized_keys
+ c = self.__class__.__new__(self.__class__)
+
+ if skip:
+ # ensure this iteration remains atomic
+ c.__dict__ = {
+ k: v for k, v in self.__dict__.copy().items() if k not in skip
+ }
+ else:
+ c.__dict__ = self.__dict__.copy()
+
+ # this is a marker that helps to "equate" clauses to each other
+ # when a Select returns its list of FROM clauses. the cloning
+ # process leaves around a lot of remnants of the previous clause
+ # typically in the form of column expressions still attached to the
+ # old table.
+ cc = self._is_clone_of
+ c._is_clone_of = cc if cc is not None else self
+ return c
+
+ def _negate_in_binary(self, negated_op, original_op):
+ """a hook to allow the right side of a binary expression to respond
+ to a negation of the binary expression.
+
+ Used for the special case of expanding bind parameter with IN.
+
+ """
+ return self
+
+ def _with_binary_element_type(self, type_):
+ """in the context of binary expression, convert the type of this
+ object to the one given.
+
+ applies only to :class:`_expression.ColumnElement` classes.
+
+ """
+ return self
+
+ @property
+ def _constructor(self):
+ """return the 'constructor' for this ClauseElement.
+
+ This is for the purpose of creating a new object of
+ this type. Usually, it's just the element's __class__.
+ However, the "Annotated" version of the object overrides
+ to return the class of its proxied element.
+
+ """
+ return self.__class__
+
+ @HasMemoized.memoized_attribute
+ def _cloned_set(self):
+ """Return the set consisting all cloned ancestors of this
+ ClauseElement.
+
+ Includes this ClauseElement. This accessor tends to be used for
+ FromClause objects to identify 'equivalent' FROM clauses, regardless
+ of transformative operations.
+
+ """
+ s = util.column_set()
+ f = self
+
+ # note this creates a cycle, asserted in test_memusage. however,
+ # turning this into a plain @property adds tens of thousands of method
+ # calls to Core / ORM performance tests, so the small overhead
+ # introduced by the relatively small number of short-term cycles
+ # produced here is preferable
+ while f is not None:
+ s.add(f)
+ f = f._is_clone_of
+ return s
+
+ @property
+ def entity_namespace(self):
+ raise AttributeError(
+ "This SQL expression has no entity namespace "
+ "with which to filter from."
+ )
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d.pop("_is_clone_of", None)
+ d.pop("_generate_cache_key", None)
+ return d
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options, _force=False
+ ):
+ if _force or self.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
+
+ def unique_params(self, *optionaldict, **kwargs):
+ """Return a copy with :func:`_expression.bindparam` elements
+ replaced.
+
+ Same functionality as :meth:`_expression.ClauseElement.params`,
+ except it adds ``unique=True``
+ to affected bind parameters so that multiple statements can be
+ used.
+
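+ For example, an illustrative sketch (``users`` is a hypothetical
+ :class:`_schema.Table`)::
+
+ stmt = select(users.c.id).where(users.c.name == bindparam('name'))
+
+ s1 = stmt.unique_params(name='wendy')
+ s2 = stmt.unique_params(name='jack')
+
+ # s1 and s2 now carry uniquified bind parameter names, so both
+ # copies may appear together, e.g. within union_all(s1, s2)
+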
+ """
+ return self._replace_params(True, optionaldict, kwargs)
+
+ def params(self, *optionaldict, **kwargs):
+ """Return a copy with :func:`_expression.bindparam` elements
+ replaced.
+
+ Returns a copy of this ClauseElement with
+ :func:`_expression.bindparam`
+ elements replaced with values taken from the given dictionary::
+
+ >>> clause = column('x') + bindparam('foo')
+ >>> print(clause.compile().params)
+ {'foo':None}
+ >>> print(clause.params({'foo':7}).compile().params)
+ {'foo':7}
+
+ """
+ return self._replace_params(False, optionaldict, kwargs)
+
+ def _replace_params(self, unique, optionaldict, kwargs):
+
+ if len(optionaldict) == 1:
+ kwargs.update(optionaldict[0])
+ elif len(optionaldict) > 1:
+ raise exc.ArgumentError(
+ "params() takes zero or one positional dictionary argument"
+ )
+
+ def visit_bindparam(bind):
+ if bind.key in kwargs:
+ bind.value = kwargs[bind.key]
+ bind.required = False
+ if unique:
+ bind._convert_to_unique()
+
+ return cloned_traverse(
+ self,
+ {"maintain_key": True, "detect_subquery_cols": True},
+ {"bindparam": visit_bindparam},
+ )
+
+ def compare(self, other, **kw):
+ r"""Compare this :class:`_expression.ClauseElement` to
+ the given :class:`_expression.ClauseElement`.
+
+ Subclasses should override the default behavior, which is a
+ straight identity comparison.
+
+ \**kw are arguments consumed by subclass ``compare()`` methods and
+ may be used to modify the criteria for comparison
+ (see :class:`_expression.ColumnElement`).
+
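+ For example, an illustrative comparison (results shown are
+ indicative)::
+
+ >>> from sqlalchemy import column
+ >>> column('a').compare(column('a'))
+ True
+ >>> column('a').compare(column('b'))
+ False
+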
+ """
+ return traversals.compare(self, other, **kw)
+
+ def self_group(self, against=None):
+ """Apply a 'grouping' to this :class:`_expression.ClauseElement`.
+
+ This method is overridden by subclasses to return a "grouping"
+ construct, i.e. parentheses. In particular it's used by "binary"
+ expressions to provide a grouping around themselves when placed into a
+ larger expression, as well as by :func:`_expression.select`
+ constructs when placed into the FROM clause of another
+ :func:`_expression.select`. (Note that subqueries should be
+ normally created using the :meth:`_expression.Select.alias` method,
+ as many
+ platforms require nested SELECT statements to be named).
+
+ As expressions are composed together, the application of
+ :meth:`self_group` is automatic - end-user code should never
+ need to use this method directly. Note that SQLAlchemy's
+ clause constructs take operator precedence into account -
+ so parentheses might not be needed, for example, in
+ an expression like ``x OR (y AND z)`` - AND takes precedence
+ over OR.
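+
+ For example, an illustrative sketch of automatic grouping (output
+ shown is indicative)::
+
+ >>> from sqlalchemy import column, and_, or_
+ >>> print(or_(column('x'), and_(column('y'), column('z'))))
+ x OR y AND z
+ >>> print(and_(or_(column('x'), column('y')), column('z')))
+ (x OR y) AND z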
+
+ The base :meth:`self_group` method of
+ :class:`_expression.ClauseElement`
+ just returns self.
+ """
+ return self
+
+ def _ungroup(self):
+ """Return this :class:`_expression.ClauseElement`
+ without any groupings.
+ """
+
+ return self
+
+ @util.preload_module("sqlalchemy.engine.default")
+ @util.preload_module("sqlalchemy.engine.url")
+ def compile(self, bind=None, dialect=None, **kw):
+ """Compile this SQL expression.
+
+ The return value is a :class:`~.Compiled` object.
+ Calling ``str()`` or ``unicode()`` on the returned value will yield a
+ string representation of the result. The
+ :class:`~.Compiled` object also can return a
+ dictionary of bind parameter names and values
+ using the ``params`` accessor.
+
+ :param bind: An ``Engine`` or ``Connection`` from which a
+ ``Compiled`` will be acquired. This argument takes precedence over
+ this :class:`_expression.ClauseElement`'s bound engine, if any.
+
+ :param column_keys: Used for INSERT and UPDATE statements, a list of
+ column names which should be present in the VALUES clause of the
+ compiled statement. If ``None``, all columns from the target table
+ object are rendered.
+
+ :param dialect: A ``Dialect`` instance from which a ``Compiled``
+ will be acquired. This argument takes precedence over the `bind`
+ argument as well as this :class:`_expression.ClauseElement`'s
+ bound engine, if any.
+
+ :param compile_kwargs: optional dictionary of additional parameters
+ that will be passed through to the compiler within all "visit"
+ methods. This allows any custom flag to be passed through to
+ a custom compilation construct, for example. It is also used
+ for the case of passing the ``literal_binds`` flag through::
+
+ from sqlalchemy.sql import table, column, select
+
+ t = table('t', column('x'))
+
+ s = select(t).where(t.c.x == 5)
+
+ print(s.compile(compile_kwargs={"literal_binds": True}))
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`faq_sql_expression_string`
+
+ """
+
+ if not dialect:
+ if bind:
+ dialect = bind.dialect
+ elif self.bind:
+ dialect = self.bind.dialect
+ else:
+ if self.stringify_dialect == "default":
+ default = util.preloaded.engine_default
+ dialect = default.StrCompileDialect()
+ else:
+ url = util.preloaded.engine_url
+ dialect = url.URL.create(
+ self.stringify_dialect
+ ).get_dialect()()
+
+ return self._compiler(dialect, **kw)
+
+ def _compile_w_cache(
+ self,
+ dialect,
+ compiled_cache=None,
+ column_keys=None,
+ for_executemany=False,
+ schema_translate_map=None,
+ **kw
+ ):
+ if compiled_cache is not None and dialect._supports_statement_cache:
+ elem_cache_key = self._generate_cache_key()
+ else:
+ elem_cache_key = None
+
+ if elem_cache_key:
+ cache_key, extracted_params = elem_cache_key
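+ # the cache key for the compiled form incorporates the dialect and
+ # the per-statement compilation options below, so that otherwise
+ # identical constructs compiled under different options do not
+ # collide in the compiled cache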
+ key = (
+ dialect,
+ cache_key,
+ tuple(column_keys),
+ bool(schema_translate_map),
+ for_executemany,
+ )
+ compiled_sql = compiled_cache.get(key)
+
+ if compiled_sql is None:
+ cache_hit = dialect.CACHE_MISS
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ for_executemany=for_executemany,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+ compiled_cache[key] = compiled_sql
+ else:
+ cache_hit = dialect.CACHE_HIT
+ else:
+ extracted_params = None
+ compiled_sql = self._compiler(
+ dialect,
+ cache_key=elem_cache_key,
+ column_keys=column_keys,
+ for_executemany=for_executemany,
+ schema_translate_map=schema_translate_map,
+ **kw
+ )
+
+ if not dialect._supports_statement_cache:
+ cache_hit = dialect.NO_DIALECT_SUPPORT
+ elif compiled_cache is None:
+ cache_hit = dialect.CACHING_DISABLED
+ else:
+ cache_hit = dialect.NO_CACHE_KEY
+
+ return compiled_sql, extracted_params, cache_hit
+
+ def _compiler(self, dialect, **kw):
+ """Return a compiler appropriate for this ClauseElement, given a
+ Dialect."""
+
+ return dialect.statement_compiler(dialect, self, **kw)
+
+ def __str__(self):
+ if util.py3k:
+ return str(self.compile())
+ else:
+ return unicode(self.compile()).encode( # noqa
+ "ascii", "backslashreplace"
+ ) # noqa
+
+ def __invert__(self):
+ # undocumented element currently used by the ORM for
+ # relationship.contains()
+ if hasattr(self, "negation_clause"):
+ return self.negation_clause
+ else:
+ return self._negate()
+
+ def _negate(self):
+ return UnaryExpression(
+ self.self_group(against=operators.inv), operator=operators.inv
+ )
+
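+ # illustrative note (not part of the original patch): expressions that
+ # define a natural inverse negate directly, e.g. ~(column('x') == 5)
+ # renders as "x != :x_1"; expressions without one are wrapped as
+ # "NOT <expr>" via the UnaryExpression produced by _negate() above
+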
+ def __bool__(self):
+ raise TypeError("Boolean value of this clause is not defined")
+
+ __nonzero__ = __bool__
+
+ def __repr__(self):
+ friendly = self.description
+ if friendly is None:
+ return object.__repr__(self)
+ else:
+ return "<%s.%s at 0x%x; %s>" % (
+ self.__module__,
+ self.__class__.__name__,
+ id(self),
+ friendly,
+ )
+
+
+class ColumnElement(
+ roles.ColumnArgumentOrKeyRole,
+ roles.StatementOptionRole,
+ roles.WhereHavingRole,
+ roles.BinaryElementRole,
+ roles.OrderByRole,
+ roles.ColumnsClauseRole,
+ roles.LimitOffsetRole,
+ roles.DMLColumnRole,
+ roles.DDLConstraintColumnRole,
+ roles.DDLExpressionRole,
+ operators.ColumnOperators,
+ ClauseElement,
+):
+ """Represent a column-oriented SQL expression suitable for usage in the
+ "columns" clause, WHERE clause etc. of a statement.
+
+ While the most familiar kind of :class:`_expression.ColumnElement` is the
+ :class:`_schema.Column` object, :class:`_expression.ColumnElement`
+ serves as the basis
+ for any unit that may be present in a SQL expression, including
+ the expressions themselves, SQL functions, bound parameters,
+ literal expressions, keywords such as ``NULL``, etc.
+ :class:`_expression.ColumnElement`
+ is the ultimate base class for all such elements.
+
+ A wide variety of SQLAlchemy Core functions work at the SQL expression
+ level, and are intended to accept instances of
+ :class:`_expression.ColumnElement` as
+ arguments. These functions will typically document that they accept a
+ "SQL expression" as an argument. What this means in terms of SQLAlchemy
+ usually refers to an input which is either already in the form of a
+ :class:`_expression.ColumnElement` object,
+ or a value which can be **coerced** into
+ one. The coercion rules followed by most, but not all, SQLAlchemy Core
+ functions with regard to SQL expressions are as follows (a brief
+ illustrative example follows the list):
+
+ * a literal Python value, such as a string, integer or floating
+ point value, boolean, datetime, ``Decimal`` object, or virtually
+ any other Python object, will be coerced into a "literal bound
+ value". This generally means that a :func:`.bindparam` will be
+ produced featuring the given value embedded into the construct; the
+ resulting :class:`.BindParameter` object is an instance of
+ :class:`_expression.ColumnElement`.
+ The Python value will ultimately be sent
+ to the DBAPI at execution time as a parameterized argument to the
+ ``execute()`` or ``executemany()`` methods, after SQLAlchemy
+ type-specific converters (e.g. those provided by any associated
+ :class:`.TypeEngine` objects) are applied to the value.
+
+ * any special object value, typically ORM-level constructs, which
+ feature an accessor called ``__clause_element__()``. The Core
+ expression system looks for this method when an object of otherwise
+ unknown type is passed to a function that is looking to coerce the
+ argument into a :class:`_expression.ColumnElement` and sometimes a
+ :class:`_expression.SelectBase` expression.
+ It is used within the ORM to
+ convert from ORM-specific objects like mapped classes and
+ mapped attributes into Core expression objects.
+
+ * The Python ``None`` value is typically interpreted as ``NULL``,
+ which in SQLAlchemy Core produces an instance of :func:`.null`.
+
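+ For example, following the first rule above, an integer compared
+ against a :class:`_expression.ColumnElement` is coerced into a bound
+ parameter (an illustrative sketch; the rendered bind name is
+ indicative)::
+
+ >>> from sqlalchemy import column
+ >>> expr = column('a') == 5
+ >>> type(expr.right).__name__
+ 'BindParameter'
+ >>> print(expr)
+ a = :a_1
+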
+ A :class:`_expression.ColumnElement` provides the ability to generate new
+ :class:`_expression.ColumnElement`
+ objects using Python expressions. This means that Python operators
+ such as ``==``, ``!=`` and ``<`` are overloaded to mimic SQL operations,
+ and allow the instantiation of further :class:`_expression.ColumnElement`
+ instances
+ which are composed from other, more fundamental
+ :class:`_expression.ColumnElement`
+ objects. For example, two :class:`.ColumnClause` objects can be added
+ together with the addition operator ``+`` to produce
+ a :class:`.BinaryExpression`.
+ Both :class:`.ColumnClause` and :class:`.BinaryExpression` are subclasses
+ of :class:`_expression.ColumnElement`::
+
+ >>> from sqlalchemy.sql import column
+ >>> column('a') + column('b')
+ <sqlalchemy.sql.expression.BinaryExpression object at 0x101029dd0>
+ >>> print(column('a') + column('b'))
+ a + b
+
+ .. seealso::
+
+ :class:`_schema.Column`
+
+ :func:`_expression.column`
+
+ """
+
+ __visit_name__ = "column_element"
+ primary_key = False
+ foreign_keys = []
+ _proxies = ()
+
+ _tq_label = None
+ """The named label that can be used to target
+ this column in a result set in a "table qualified" context.
+
+ This label is almost always the label used when
+ rendering <expr> AS <label> in a SELECT statement when using
+ the LABEL_STYLE_TABLENAME_PLUS_COL label style, which is what the legacy
+ ORM ``Query`` object uses as well.
+
+ For a regular Column bound to a Table, this is typically the label
+ <tablename>_<columnname>. For other constructs, different rules
+ may apply, such as anonymized labels and others.
+
+ .. versionchanged:: 1.4.21 renamed from ``._label``
+
+ """
+
+ key = None
+ """The 'key' that in some circumstances refers to this object in a
+ Python namespace.
+
+ This typically refers to the "key" of the column as present in the
+ ``.c`` collection of a selectable, e.g. ``sometable.c["somekey"]`` would
+ return a :class:`_schema.Column` with a ``.key`` of "somekey".
+
+ """
+
+ @HasMemoized.memoized_attribute
+ def _tq_key_label(self):
+ """A label-based version of 'key' that in some circumstances refers
+ to this object in a Python namespace.
+
+ _tq_key_label comes into play when a select() statement is constructed
+ with apply_labels(); in this case, all Column objects in the ``.c``
+ collection are rendered as <tablename>_<columnname> in SQL; this is
+ essentially the value of ._label. But to locate those columns in the
+ ``.c`` collection, the name is along the lines of <tablename>_<key>;
+ that's the typical value of .key_label.
+
+ .. versionchanged:: 1.4.21 renamed from ``._key_label``
+
+ """
+ return self._proxy_key
+
+ @property
+ def _key_label(self):
+ """legacy; renamed to _tq_key_label"""
+ return self._tq_key_label
+
+ @property
+ def _label(self):
+ """legacy; renamed to _tq_label"""
+ return self._tq_label
+
+ @property
+ def _non_anon_label(self):
+ """the 'name' that naturally applies this element when rendered in
+ SQL.
+
+ Concretely, this is the "name" of a column or a label in a
+ SELECT statement; ``<columnname>`` and ``<labelname>`` below::
+
+ SELECT <columnname> FROM table
+
+ SELECT column AS <labelname> FROM table
+
+ Above, the two names noted will be what's present in the DBAPI
+ ``cursor.description`` as the names.
+
+ If this attribute returns ``None``, it means that the SQL element as
+ written does not have a 100% fully predictable "name" that would appear
+ in the ``cursor.description``. Examples include SQL functions, CAST
+ functions, etc. While such things do return names in
+ ``cursor.description``, they are only predictable on a
+ database-specific basis; e.g. an expression like ``MAX(table.col)`` may
+ appear as the string ``max`` on one database (like PostgreSQL) or may
+ appear as the whole expression ``max(table.col)`` on SQLite.
+
+ The default implementation looks for a ``.name`` attribute on the
+ object, as has been the precedent established in SQLAlchemy for many
+ years. An exception is made on the ``FunctionElement`` subclass
+ so that the return value is always ``None``.
+
+ .. versionadded:: 1.4.21
+
+ """
+ return getattr(self, "name", None)
+
+ _render_label_in_columns_clause = True
+ """A flag used by select._columns_plus_names that helps to determine
+ we are actually going to render in terms of "SELECT <col> AS <label>".
+ This flag can be returned as False for some Column objects that want
+ to be rendered as simple "SELECT <col>"; typically columns that don't have
+ any parent table and are named the same as what the label would be
+ in any case.
+
+ """
+
+ _allow_label_resolve = True
+ """A flag that can be flipped to prevent a column from being resolvable
+ by string label name.
+
+ The joined eager loader strategy in the ORM uses this, for example.
+
+ """
+
+ _is_implicitly_boolean = False
+
+ _alt_names = ()
+
+ def self_group(self, against=None):
+ if (
+ against in (operators.and_, operators.or_, operators._asbool)
+ and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity
+ ):
+ return AsBoolean(self, operators.is_true, operators.is_false)
+ elif against in (operators.any_op, operators.all_op):
+ return Grouping(self)
+ else:
+ return self
+
+ def _negate(self):
+ if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+ return AsBoolean(self, operators.is_false, operators.is_true)
+ else:
+ return super(ColumnElement, self)._negate()
+
+ @util.memoized_property
+ def type(self):
+ return type_api.NULLTYPE
+
+ @HasMemoized.memoized_attribute
+ def comparator(self):
+ try:
+ comparator_factory = self.type.comparator_factory
+ except AttributeError as err:
+ util.raise_(
+ TypeError(
+ "Object %r associated with '.type' attribute "
+ "is not a TypeEngine class or object" % self.type
+ ),
+ replace_context=err,
+ )
+ else:
+ return comparator_factory(self)
+
+ def __getattr__(self, key):
+ try:
+ return getattr(self.comparator, key)
+ except AttributeError as err:
+ util.raise_(
+ AttributeError(
+ "Neither %r object nor %r object has an attribute %r"
+ % (
+ type(self).__name__,
+ type(self.comparator).__name__,
+ key,
+ )
+ ),
+ replace_context=err,
+ )
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.comparator, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.comparator, **kwargs)
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ expanding=expanding,
+ )
+
+ @property
+ def expression(self):
+ """Return a column expression.
+
+ Part of the inspection interface; returns self.
+
+ """
+ return self
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ @util.memoized_property
+ def base_columns(self):
+ return util.column_set(c for c in self.proxy_set if not c._proxies)
+
+ @util.memoized_property
+ def proxy_set(self):
+ s = util.column_set([self])
+ for c in self._proxies:
+ s.update(c.proxy_set)
+ return s
+
+ def _uncached_proxy_set(self):
+ """An 'uncached' version of proxy set.
+
+ This is so that we can read annotations from the list of columns
+ without breaking the caching of the above proxy_set.
+
+ """
+ s = util.column_set([self])
+ for c in self._proxies:
+ s.update(c._uncached_proxy_set())
+ return s
+
+ def shares_lineage(self, othercolumn):
+ """Return True if the given :class:`_expression.ColumnElement`
+ has a common ancestor to this :class:`_expression.ColumnElement`."""
+
+ return bool(self.proxy_set.intersection(othercolumn.proxy_set))
+
+ def _compare_name_for_result(self, other):
+ """Return True if the given column element compares to this one
+ when targeting within a result row."""
+
+ return (
+ hasattr(other, "name")
+ and hasattr(self, "name")
+ and other.name == self.name
+ )
+
+ @HasMemoized.memoized_attribute
+ def _proxy_key(self):
+ if self._annotations and "proxy_key" in self._annotations:
+ return self._annotations["proxy_key"]
+
+ name = self.key
+ if not name:
+ # there's a bit of a seeming contradiction which is that the
+ # "_non_anon_label" of a column can in fact be an
+ # "_anonymous_label"; this is when it's on a column that is
+ # proxying for an anonymous expression in a subquery.
+ name = self._non_anon_label
+
+ if isinstance(name, _anonymous_label):
+ return None
+ else:
+ return name
+
+ @HasMemoized.memoized_attribute
+ def _expression_label(self):
+ """a suggested label to use in the case that the column has no name,
+ which should be used if possible as the explicit 'AS <label>'
+ where this expression would normally have an anon label.
+
+ this is essentially what _proxy_key does, except it returns
+ None if the column has a normal name that can be used.
+
+ """
+
+ if getattr(self, "name", None) is not None:
+ return None
+ elif self._annotations and "proxy_key" in self._annotations:
+ return self._annotations["proxy_key"]
+ else:
+ return None
+
+ def _make_proxy(
+ self, selectable, name=None, key=None, name_is_truncatable=False, **kw
+ ):
+ """Create a new :class:`_expression.ColumnElement` representing this
+ :class:`_expression.ColumnElement` as it appears in the select list of
+ a descending selectable.
+
+ """
+ if name is None:
+ name = self._anon_name_label
+ if key is None:
+ key = self._proxy_key
+ else:
+ key = name
+
+ co = ColumnClause(
+ coercions.expect(roles.TruncatedLabelRole, name)
+ if name_is_truncatable
+ else name,
+ type_=getattr(self, "type", None),
+ _selectable=selectable,
+ )
+
+ co._propagate_attrs = selectable._propagate_attrs
+ co._proxies = [self]
+ if selectable._is_clone_of is not None:
+ co._is_clone_of = selectable._is_clone_of.columns.get(key)
+ return key, co
+
+ def cast(self, type_):
+ """Produce a type cast, i.e. ``CAST(<expression> AS <type>)``.
+
+ This is a shortcut to the :func:`_expression.cast` function.
+
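+ E.g., an illustrative sketch (``product_table`` is a hypothetical
+ :class:`_schema.Table`)::
+
+ from sqlalchemy import select, Numeric
+
+ stmt = select(product_table.c.unit_price.cast(Numeric(10, 4)))
+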
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`_expression.cast`
+
+ :func:`_expression.type_coerce`
+
+ .. versionadded:: 1.0.7
+
+ """
+ return Cast(self, type_)
+
+ def label(self, name):
+ """Produce a column label, i.e. ``<columnname> AS <name>``.
+
+ This is a shortcut to the :func:`_expression.label` function.
+
+ If 'name' is ``None``, an anonymous label name will be generated.
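+
+ E.g., an illustrative sketch (output shown is indicative)::
+
+ >>> from sqlalchemy import column, select
+ >>> print(select(column('x').label('total')))
+ SELECT x AS total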
+
+ """
+ return Label(name, self, self.type)
+
+ def _anon_label(self, seed, add_hash=None):
+ while self._is_clone_of is not None:
+ self = self._is_clone_of
+
+ # as of 1.4 anonymous label for ColumnElement uses hash(), not id(),
+ # as the identifier, because a column and its annotated version are
+ # the same thing in a SQL statement
+ hash_value = hash(self)
+
+ if add_hash:
+ # this path is used for disambiguating anon labels that would
+ # otherwise be the same name for the same element repeated.
+ # an additional numeric value is factored in for each label.
+
+ # shift hash(self) (which is id(self), typically an 8-byte integer)
+ # 16 bits leftward. fill extra add_hash on right
+ assert add_hash < (2 << 15)
+ assert seed
+ hash_value = (hash_value << 16) | add_hash
+
+ # extra underscore is added for labels with extra hash
+ # values, to isolate the "deduped anon" namespace from the
+ # regular namespace. eliminates chance of these
+ # manufactured hash values overlapping with regular ones for some
+ # undefined python interpreter
+ seed = seed + "_"
+
+ if isinstance(seed, _anonymous_label):
+ return _anonymous_label.safe_construct(
+ hash_value, "", enclosing_label=seed
+ )
+
+ return _anonymous_label.safe_construct(hash_value, seed or "anon")
+
+ @util.memoized_property
+ def _anon_name_label(self):
+ """Provides a constant 'anonymous label' for this ColumnElement.
+
+ This is a label() expression which will be named at compile time.
+ The same label() is returned each time ``anon_label`` is called so
+ that expressions can reference ``anon_label`` multiple times,
+ producing the same label name at compile time.
+
+ The compiler uses this function automatically at compile time
+ for expressions that are known to be 'unnamed' like binary
+ expressions and function calls.
+
+ .. versionchanged:: 1.4.9 - this attribute was not intended to be
+ public and is renamed to _anon_name_label. anon_name exists
+ for backwards compat
+
+ """
+ name = getattr(self, "name", None)
+ return self._anon_label(name)
+
+ @util.memoized_property
+ def _anon_key_label(self):
+ """Provides a constant 'anonymous key label' for this ColumnElement.
+
+ Compare to ``anon_label``, except that the "key" of the column,
+ if available, is used to generate the label.
+
+ This is used when a deduplicating key is placed into the columns
+ collection of a selectable.
+
+ .. versionchanged:: 1.4.9 - this attribute was not intended to be
+ public and is renamed to _anon_key_label. anon_key_label exists
+ for backwards compat
+
+ """
+ return self._anon_label(self._proxy_key)
+
+ @property
+ @util.deprecated(
+ "1.4",
+ "The :attr:`_expression.ColumnElement.anon_label` attribute is now "
+ "private, and the public accessor is deprecated.",
+ )
+ def anon_label(self):
+ return self._anon_name_label
+
+ @property
+ @util.deprecated(
+ "1.4",
+ "The :attr:`_expression.ColumnElement.anon_key_label` attribute is "
+ "now private, and the public accessor is deprecated.",
+ )
+ def anon_key_label(self):
+ return self._anon_key_label
+
+ def _dedupe_anon_label_idx(self, idx):
+ """label to apply to a column that is anon labeled, but repeated
+ in the SELECT, so that we have to make an "extra anon" label that
+ disambiguates it from the previous appearance.
+
+ these labels come out like "foo_bar_id__1" and have double underscores
+ in them.
+
+ """
+ label = getattr(self, "name", None)
+
+ # current convention is that if the element doesn't have a
+ # ".name" (usually because it is not NamedColumn), we try to
+ # use a "table qualified" form for the "dedupe anon" label,
+ # based on the notion that a label like
+ # "CAST(casttest.v1 AS DECIMAL) AS casttest_v1__1" looks better than
+ # "CAST(casttest.v1 AS DECIMAL) AS anon__1"
+
+ if label is None:
+ return self._dedupe_anon_tq_label_idx(idx)
+ else:
+ return self._anon_label(label, add_hash=idx)
+
+ @util.memoized_property
+ def _anon_tq_label(self):
+ return self._anon_label(getattr(self, "_tq_label", None))
+
+ @util.memoized_property
+ def _anon_tq_key_label(self):
+ return self._anon_label(getattr(self, "_tq_key_label", None))
+
+ def _dedupe_anon_tq_label_idx(self, idx):
+ label = getattr(self, "_tq_label", None) or "anon"
+
+ return self._anon_label(label, add_hash=idx)
+
+
+class WrapsColumnExpression(object):
+ """Mixin that defines a :class:`_expression.ColumnElement`
+ as a wrapper with special
+ labeling behavior for an expression that already has a name.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`change_4449`
+
+
+ """
+
+ @property
+ def wrapped_column_expression(self):
+ raise NotImplementedError()
+
+ @property
+ def _tq_label(self):
+ wce = self.wrapped_column_expression
+ if hasattr(wce, "_tq_label"):
+ return wce._tq_label
+ else:
+ return None
+
+ _label = _tq_label
+
+ @property
+ def _non_anon_label(self):
+ return None
+
+ @property
+ def _anon_name_label(self):
+ wce = self.wrapped_column_expression
+
+ # this logic tries to get the WrappedColumnExpression to render
+ # with "<expr> AS <name>", where "<name>" is the natural name
+ # within the expression itself. e.g. "CAST(table.foo) AS foo".
+ if not wce._is_text_clause:
+ nal = wce._non_anon_label
+ if nal:
+ return nal
+ elif hasattr(wce, "_anon_name_label"):
+ return wce._anon_name_label
+ return super(WrapsColumnExpression, self)._anon_name_label
+
+ def _dedupe_anon_label_idx(self, idx):
+ wce = self.wrapped_column_expression
+ nal = wce._non_anon_label
+ if nal:
+ return self._anon_label(nal + "_")
+ else:
+ return self._dedupe_anon_tq_label_idx(idx)
+
+ @property
+ def _proxy_key(self):
+ wce = self.wrapped_column_expression
+
+ if not wce._is_text_clause:
+ return wce._proxy_key
+ return super(WrapsColumnExpression, self)._proxy_key
+
+
+class BindParameter(roles.InElementRole, ColumnElement):
+ r"""Represent a "bound expression".
+
+ :class:`.BindParameter` is invoked explicitly using the
+ :func:`.bindparam` function, as in::
+
+ from sqlalchemy import bindparam
+
+ stmt = select(users_table).\
+ where(users_table.c.name == bindparam('username'))
+
+ Detailed discussion of how :class:`.BindParameter` is used is
+ at :func:`.bindparam`.
+
+ .. seealso::
+
+ :func:`.bindparam`
+
+ """
+
+ __visit_name__ = "bindparam"
+
+ _traverse_internals = [
+ ("key", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("callable", InternalTraversal.dp_plain_dict),
+ ("value", InternalTraversal.dp_plain_obj),
+ ("literal_execute", InternalTraversal.dp_boolean),
+ ]
+
+ _is_crud = False
+ _is_bind_parameter = True
+ _key_is_anon = False
+
+ # bindparam implements its own _gen_cache_key() method; however,
+ # we check subclasses for this flag, else no cache key is generated
+ inherit_cache = True
+
+ def __init__(
+ self,
+ key,
+ value=NO_ARG,
+ type_=None,
+ unique=False,
+ required=NO_ARG,
+ quote=None,
+ callable_=None,
+ expanding=False,
+ isoutparam=False,
+ literal_execute=False,
+ _compared_to_operator=None,
+ _compared_to_type=None,
+ _is_crud=False,
+ ):
+ r"""Produce a "bound expression".
+
+ The return value is an instance of :class:`.BindParameter`; this
+ is a :class:`_expression.ColumnElement`
+ subclass which represents a so-called
+ "placeholder" value in a SQL expression, the value of which is
+ supplied at the point at which the statement is executed against a
+ database connection.
+
+ In SQLAlchemy, the :func:`.bindparam` construct has
+ the ability to carry along the actual value that will be ultimately
+ used at expression time. In this way, it serves not just as
+ a "placeholder" for eventual population, but also as a means of
+ representing so-called "unsafe" values which should not be rendered
+ directly in a SQL statement, but rather should be passed along
+ to the :term:`DBAPI` as values which need to be correctly escaped
+ and potentially handled for type-safety.
+
+ When using :func:`.bindparam` explicitly, the use case is typically
+ one of traditional deferment of parameters; the :func:`.bindparam`
+ construct accepts a name which can then be referred to at execution
+ time::
+
+ from sqlalchemy import bindparam
+
+ stmt = select(users_table).\
+ where(users_table.c.name == bindparam('username'))
+
+ The above statement, when rendered, will produce SQL similar to::
+
+ SELECT id, name FROM user WHERE name = :username
+
+ In order to populate the value of ``:username`` above, the value
+ would typically be applied at execution time to a method
+ like :meth:`_engine.Connection.execute`::
+
+ result = connection.execute(stmt, username='wendy')
+
+ Explicit use of :func:`.bindparam` is also common when producing
+ UPDATE or DELETE statements that are to be invoked multiple times,
+ where the WHERE criterion of the statement is to change on each
+ invocation, such as::
+
+ stmt = (users_table.update().
+ where(users_table.c.name == bindparam('username')).
+ values(fullname=bindparam('fullname'))
+ )
+
+ connection.execute(
+ stmt, [{"username": "wendy", "fullname": "Wendy Smith"},
+ {"username": "jack", "fullname": "Jack Jones"},
+ ]
+ )
+
+ SQLAlchemy's Core expression system makes wide use of
+ :func:`.bindparam` in an implicit sense. It is typical that Python
+ literal values passed to virtually all SQL expression functions are
+ coerced into fixed :func:`.bindparam` constructs. For example, given
+ a comparison operation such as::
+
+ expr = users_table.c.name == 'Wendy'
+
+ The above expression will produce a :class:`.BinaryExpression`
+ construct, where the left side is the :class:`_schema.Column` object
+ representing the ``name`` column, and the right side is a
+ :class:`.BindParameter` representing the literal value::
+
+ print(repr(expr.right))
+ BindParameter('%(4327771088 name)s', 'Wendy', type_=String())
+
+ The expression above will render SQL such as::
+
+ user.name = :name_1
+
+ Where the ``:name_1`` parameter name is an anonymous name. The
+ actual string ``Wendy`` is not in the rendered string, but is carried
+ along where it is later used within statement execution. If we
+ invoke a statement like the following::
+
+ stmt = select(users_table).where(users_table.c.name == 'Wendy')
+ result = connection.execute(stmt)
+
+ We would see SQL logging output as::
+
+ SELECT "user".id, "user".name
+ FROM "user"
+ WHERE "user".name = %(name_1)s
+ {'name_1': 'Wendy'}
+
+ Above, we see that ``Wendy`` is passed as a parameter to the database,
+ while the placeholder ``:name_1`` is rendered in the appropriate form
+ for the target database, in this case the PostgreSQL database.
+
+ Similarly, :func:`.bindparam` is invoked automatically when working
+ with :term:`CRUD` statements as far as the "VALUES" portion is
+ concerned. The :func:`_expression.insert` construct produces an
+ ``INSERT`` expression which will, at statement execution time, generate
+ bound placeholders based on the arguments passed, as in::
+
+ stmt = users_table.insert()
+ result = connection.execute(stmt, name='Wendy')
+
+ The above will produce SQL output as::
+
+ INSERT INTO "user" (name) VALUES (%(name)s)
+ {'name': 'Wendy'}
+
+ The :class:`_expression.Insert` construct, at
+ compilation/execution time, rendered a single :func:`.bindparam`
+ mirroring the column name ``name`` as a result of the single ``name``
+ parameter we passed to the :meth:`_engine.Connection.execute` method.
+
+ :param key:
+ the key (e.g. the name) for this bind param.
+ Will be used in the generated
+ SQL statement for dialects that use named parameters. This
+ value may be modified when part of a compilation operation,
+ if other :class:`BindParameter` objects exist with the same
+ key, or if its length is too long and truncation is
+ required.
+
+ :param value:
+ Initial value for this bind param. Will be used at statement
+ execution time as the value for this parameter passed to the
+ DBAPI, if no other value is indicated to the statement execution
+ method for this particular parameter name. Defaults to ``None``.
+
+ :param callable\_:
+ A callable function that takes the place of "value". The function
+ will be called at statement execution time to determine the
+ ultimate value. Used for scenarios where the actual bind
+ value cannot be determined at the point at which the clause
+ construct is created, but embedded bind values are still desirable.
+
+ :param type\_:
+ A :class:`.TypeEngine` class or instance representing an optional
+ datatype for this :func:`.bindparam`. If not passed, a type
+ may be determined automatically for the bind, based on the given
+ value; for example, trivial Python types such as ``str``,
+ ``int``, ``bool``
+ may result in the :class:`.String`, :class:`.Integer` or
+ :class:`.Boolean` types being automatically selected.
+
+ The type of a :func:`.bindparam` is significant especially in that
+ the type will apply pre-processing to the value before it is
+ passed to the database. For example, a :func:`.bindparam` which
+ refers to a datetime value, and is specified as holding the
+ :class:`.DateTime` type, may apply conversion needed to the
+ value (such as stringification on SQLite) before passing the value
+ to the database.
+
+ :param unique:
+ if True, the key name of this :class:`.BindParameter` will be
+ modified if another :class:`.BindParameter` of the same name
+ already has been located within the containing
+ expression. This flag is used generally by the internals
+ when producing so-called "anonymous" bound expressions; it
+ isn't generally applicable to explicitly-named :func:`.bindparam`
+ constructs.
+
+ :param required:
+ If ``True``, a value is required at execution time. If not passed,
+ it defaults to ``True`` if neither :paramref:`.bindparam.value`
+ nor :paramref:`.bindparam.callable` was passed. If either of these
+ parameters are present, then :paramref:`.bindparam.required`
+ defaults to ``False``.
+
+ :param quote:
+ True if this parameter name requires quoting and is not
+ currently known as a SQLAlchemy reserved word; this currently
+ only applies to the Oracle backend, where bound names must
+ sometimes be quoted.
+
+ :param isoutparam:
+ if True, the parameter should be treated like a stored procedure
+ "OUT" parameter. This applies to backends such as Oracle which
+ support OUT parameters.
+
+ :param expanding:
+ if True, this parameter will be treated as an "expanding" parameter
+ at execution time; the parameter value is expected to be a sequence,
+ rather than a scalar value, and the string SQL statement will
+ be transformed on a per-execution basis to accommodate the sequence
+ with a variable number of parameter slots passed to the DBAPI.
+ This is to allow statement caching to be used in conjunction with
+ an IN clause.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.in_`
+
+ :ref:`baked_in` - with baked queries
+
+ .. note:: The "expanding" feature does not support "executemany"-
+ style parameter sets.
+
+ .. versionadded:: 1.2
+
+ .. versionchanged:: 1.3 the "expanding" bound parameter feature now
+ supports empty lists.
+
+ :param literal_execute:
+ if True, the bound parameter will be rendered in the compile phase
+ with a special "POSTCOMPILE" token, and the SQLAlchemy compiler will
+ render the final value of the parameter into the SQL statement at
+ statement execution time, omitting the value from the parameter
+ dictionary / list passed to DBAPI ``cursor.execute()``. This
+ produces a similar effect to that of using the ``literal_binds``
+ compilation flag, however it takes place as the statement is sent to
+ the DBAPI ``cursor.execute()`` method, rather than when the statement
+ is compiled. The primary use of this
+ capability is for rendering LIMIT / OFFSET clauses for database
+ drivers that can't accommodate bound parameters in these
+ contexts, while allowing SQL constructs to be cacheable at the
+ compilation level.
+
+ .. versionadded:: 1.4 Added "post compile" bound parameters
+
+ .. seealso::
+
+ :ref:`change_4808`.
+
+ .. seealso::
+
+ :ref:`tutorial_sending_parameters` - in the
+ :ref:`unified_tutorial`
+
+ """
+ if required is NO_ARG:
+ required = value is NO_ARG and callable_ is None
+ if value is NO_ARG:
+ value = None
+
+ if quote is not None:
+ key = quoted_name(key, quote)
+
+ if unique:
+ self.key = _anonymous_label.safe_construct(
+ id(self),
+ key
+ if key is not None and not isinstance(key, _anonymous_label)
+ else "param",
+ sanitize_key=True,
+ )
+ self._key_is_anon = True
+ elif key:
+ self.key = key
+ else:
+ self.key = _anonymous_label.safe_construct(id(self), "param")
+ self._key_is_anon = True
+
+ # identifying key that won't change across
+ # clones, used to identify the bind's logical
+ # identity
+ self._identifying_key = self.key
+
+ # key that was passed in the first place, used to
+ # generate new keys
+ self._orig_key = key or "param"
+
+ self.unique = unique
+ self.value = value
+ self.callable = callable_
+ self.isoutparam = isoutparam
+ self.required = required
+
+ # indicate an "expanding" parameter; the compiler sets this
+ # automatically in the compiler _render_in_expr_w_bindparam method
+ # for an IN expression
+ self.expanding = expanding
+
+ # this is another hint to help w/ expanding and is typically
+ # set in the compiler _render_in_expr_w_bindparam method for an
+ # IN expression
+ self.expand_op = None
+
+ self.literal_execute = literal_execute
+ if _is_crud:
+ self._is_crud = True
+
+ if type_ is None:
+ if expanding and value:
+ check_value = value[0]
+ else:
+ check_value = value
+ if _compared_to_type is not None:
+ self.type = _compared_to_type.coerce_compared_value(
+ _compared_to_operator, check_value
+ )
+ else:
+ self.type = type_api._resolve_value_to_type(check_value)
+ elif isinstance(type_, type):
+ self.type = type_()
+ elif type_._is_tuple_type and value:
+ if expanding:
+ check_value = value[0]
+ else:
+ check_value = value
+ self.type = type_._resolve_values_to_types(check_value)
+ else:
+ self.type = type_
+
+ def _with_value(self, value, maintain_key=False, required=NO_ARG):
+ """Return a copy of this :class:`.BindParameter` with the given value
+ set.
+ """
+ cloned = self._clone(maintain_key=maintain_key)
+ cloned.value = value
+ cloned.callable = None
+ cloned.required = required if required is not NO_ARG else self.required
+ if cloned.type is type_api.NULLTYPE:
+ cloned.type = type_api._resolve_value_to_type(value)
+ return cloned
+
+ @property
+ def effective_value(self):
+ """Return the value of this bound parameter,
+ taking into account whether the ``callable`` parameter
+ was set.
+
+ The ``callable`` value will be evaluated
+ and returned if present, else ``value``.
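+
+ E.g., an illustrative sketch::
+
+ >>> from sqlalchemy import bindparam
+ >>> bindparam('x', value=5).effective_value
+ 5
+ >>> bindparam('x', callable_=lambda: 10).effective_value
+ 10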
+
+ """
+ if self.callable:
+ return self.callable()
+ else:
+ return self.value
+
+ def render_literal_execute(self):
+ """Produce a copy of this bound parameter that will enable the
+ :paramref:`_sql.BindParameter.literal_execute` flag.
+
+ The :paramref:`_sql.BindParameter.literal_execute` flag will
+ have the effect of the parameter rendered in the compiled SQL
+ string using ``[POSTCOMPILE]`` form, which is a special form that
+ is converted to be a rendering of the literal value of the parameter
+ at SQL execution time. The rationale is to support caching
+ of SQL statement strings that can embed per-statement literal values,
+ such as LIMIT and OFFSET parameters, in the final SQL string that
+ is passed to the DBAPI. Dialects in particular may want to use
+ this method within custom compilation schemes.
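+
+ E.g., an illustrative sketch::
+
+ from sqlalchemy import bindparam
+
+ param = bindparam('limit', value=10)
+ literal_param = param.render_literal_execute()
+
+ # literal_param.literal_execute is now True; when a statement
+ # using it is compiled and executed, the value 10 is rendered
+ # into the SQL string itself rather than being passed to the
+ # DBAPI as a parameter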
+
+ .. versionadded:: 1.4.5
+
+ .. seealso::
+
+ :ref:`engine_thirdparty_caching`
+
+ """
+ return self.__class__(
+ self.key,
+ self.value,
+ type_=self.type,
+ literal_execute=True,
+ )
+
+ def _negate_in_binary(self, negated_op, original_op):
+ if self.expand_op is original_op:
+ bind = self._clone()
+ bind.expand_op = negated_op
+ return bind
+ else:
+ return self
+
+ def _with_binary_element_type(self, type_):
+ c = ClauseElement._clone(self)
+ c.type = type_
+ return c
+
+ def _clone(self, maintain_key=False, **kw):
+ c = ClauseElement._clone(self, **kw)
+ # ensure all the BindParameter objects stay in cloned set.
+ # in #7823, we changed "clone" so that a clone only keeps a reference
+ # to the "original" element, since for column correspondence, that's
+ # all we need. However, for BindParam, _cloned_set is used by
+ # the "cache key bind match" lookup, which means if any of those
+ # interim BindParameter objects became part of a cache key in the
+ # cache, we need it. So here, make sure all clones keep carrying
+ # forward.
+ c._cloned_set.update(self._cloned_set)
+ if not maintain_key and self.unique:
+ c.key = _anonymous_label.safe_construct(
+ id(c), c._orig_key or "param", sanitize_key=True
+ )
+ return c
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False)
+
+ if not _gen_cache_ok:
+ if anon_map is not None:
+ anon_map[NO_CACHE] = True
+ return None
+
+ idself = id(self)
+ if idself in anon_map:
+ return (anon_map[idself], self.__class__)
+ else:
+ # inline of
+ # id_ = anon_map[idself]
+ anon_map[idself] = id_ = str(anon_map.index)
+ anon_map.index += 1
+
+ if bindparams is not None:
+ bindparams.append(self)
+
+ return (
+ id_,
+ self.__class__,
+ self.type._static_cache_key,
+ self.key % anon_map if self._key_is_anon else self.key,
+ self.literal_execute,
+ )
+
+ def _convert_to_unique(self):
+ if not self.unique:
+ self.unique = True
+ self.key = _anonymous_label.safe_construct(
+ id(self), self._orig_key or "param", sanitize_key=True
+ )
+
+ def __getstate__(self):
+ """execute a deferred value for serialization purposes."""
+
+ d = self.__dict__.copy()
+ v = self.value
+ if self.callable:
+ v = self.callable()
+ d["callable"] = None
+ d["value"] = v
+ return d
+
+ def __setstate__(self, state):
+ if state.get("unique", False):
+ state["key"] = _anonymous_label.safe_construct(
+ id(self), state.get("_orig_key", "param"), sanitize_key=True
+ )
+ self.__dict__.update(state)
+
+ def __repr__(self):
+ return "%s(%r, %r, type_=%r)" % (
+ self.__class__.__name__,
+ self.key,
+ self.value,
+ self.type,
+ )
+
+
+class TypeClause(ClauseElement):
+ """Handle a type keyword in a SQL statement.
+
+ Used by the ``Case`` statement.
+
+ """
+
+ __visit_name__ = "typeclause"
+
+ _traverse_internals = [("type", InternalTraversal.dp_type)]
+
+ def __init__(self, type_):
+ self.type = type_
+
+
+class TextClause(
+ roles.DDLConstraintColumnRole,
+ roles.DDLExpressionRole,
+ roles.StatementOptionRole,
+ roles.WhereHavingRole,
+ roles.OrderByRole,
+ roles.FromClauseRole,
+ roles.SelectStatementRole,
+ roles.BinaryElementRole,
+ roles.InElementRole,
+ Executable,
+ ClauseElement,
+):
+ """Represent a literal SQL text fragment.
+
+ E.g.::
+
+ from sqlalchemy import text
+
+ t = text("SELECT * FROM users")
+ result = connection.execute(t)
+
+
+ The :class:`_expression.TextClause` construct is produced using the
+ :func:`_expression.text`
+ function; see that function for full documentation.
+
+ .. seealso::
+
+ :func:`_expression.text`
+
+ """
+
+ __visit_name__ = "textclause"
+
+ _traverse_internals = [
+ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
+ ("text", InternalTraversal.dp_string),
+ ]
+
+ _is_text_clause = True
+
+ _is_textual = True
+
+ _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE)
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": PARSE_AUTOCOMMIT}
+ )
+ _is_implicitly_boolean = False
+
+ _render_label_in_columns_clause = False
+
+ _hide_froms = ()
+
+ def __and__(self, other):
+ # support use in select.where(), query.filter()
+ return and_(self, other)
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ # help in those cases where text() is
+ # interpreted in a column expression situation
+ key = _label = None
+
+ _allow_label_resolve = False
+
+ @property
+ def _is_star(self):
+ return self.text == "*"
+
+ def __init__(self, text, bind=None):
+ self._bind = bind
+ self._bindparams = {}
+
+ def repl(m):
+ self._bindparams[m.group(1)] = BindParameter(m.group(1))
+ return ":%s" % m.group(1)
+
+ # scan the string and search for bind parameter names, add them
+ # to the list of bindparams
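+ # (illustrative) e.g. for text("x = :x AND note = '\:skip'") only
+ # "x" is detected; the backslash-escaped ":skip" is left as-is and
+ # no BindParameter is created for it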
+ self.text = self._bind_params_regex.sub(repl, text)
+
+ @classmethod
+ @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`")
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_sql.text.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def _create_text(cls, text, bind=None):
+ r"""Construct a new :class:`_expression.TextClause` clause,
+ representing
+ a textual SQL string directly.
+
+ E.g.::
+
+ from sqlalchemy import text
+
+ t = text("SELECT * FROM users")
+ result = connection.execute(t)
+
+ The advantages :func:`_expression.text`
+ provides over a plain string are
+ backend-neutral support for bind parameters, per-statement
+ execution options, as well as
+ bind parameter and result-column typing behavior, allowing
+ SQLAlchemy type constructs to play a role when executing
+ a statement that is specified literally. The construct can also
+ be provided with a ``.c`` collection of column elements, allowing
+ it to be embedded in other SQL expression constructs as a subquery.
+
+ Bind parameters are specified by name, using the format ``:name``.
+ E.g.::
+
+ t = text("SELECT * FROM users WHERE id=:user_id")
+ result = connection.execute(t, user_id=12)
+
+ For SQL statements where a colon is required verbatim, as within
+ an inline string, use a backslash to escape::
+
+ t = text("SELECT * FROM users WHERE name='\:username'")
+
+ The :class:`_expression.TextClause`
+ construct includes methods which can
+ provide information about the bound parameters as well as the column
+ values which would be returned from the textual statement, assuming
+ it's an executable SELECT type of statement. The
+ :meth:`_expression.TextClause.bindparams`
+ method is used to provide bound
+ parameter detail, and :meth:`_expression.TextClause.columns`
+ method allows
+ specification of return columns including names and types::
+
+ t = text("SELECT * FROM users WHERE id=:user_id").\
+ bindparams(user_id=7).\
+ columns(id=Integer, name=String)
+
+ for id, name in connection.execute(t):
+ print(id, name)
+
+ The :func:`_expression.text` construct is used in cases when
+ a literal string SQL fragment is specified as part of a larger query,
+ such as for the WHERE clause of a SELECT statement::
+
+ s = select(users.c.id, users.c.name).where(text("id=:user_id"))
+ result = connection.execute(s, user_id=12)
+
+ :func:`_expression.text` is also used for the construction
+ of a full, standalone statement using plain text.
+ As such, SQLAlchemy refers
+ to it as an :class:`.Executable` object, and it supports
+ the :meth:`Executable.execution_options` method. For example,
+ a :func:`_expression.text`
+ construct that should be subject to "autocommit"
+ can be set explicitly so using the
+ :paramref:`.Connection.execution_options.autocommit` option::
+
+ t = text("EXEC my_procedural_thing()").\
+ execution_options(autocommit=True)
+
+ .. deprecated:: 1.4 The "autocommit" execution option is deprecated
+ and will be removed in SQLAlchemy 2.0. See
+ :ref:`migration_20_autocommit` for discussion.
+
+ :param text:
+ the text of the SQL statement to be created. Use ``:<param>``
+ to specify bind parameters; they will be compiled to their
+ engine-specific format.
+
+ :param bind:
+ an optional connection or engine to be used for this text query.
+
+ .. seealso::
+
+ :ref:`tutorial_select_arbitrary_text`
+
+
+ """
+ return TextClause(text, bind=bind)
+
+ @_generative
+ def bindparams(self, *binds, **names_to_values):
+ """Establish the values and/or types of bound parameters within
+ this :class:`_expression.TextClause` construct.
+
+ Given a text construct such as::
+
+ from sqlalchemy import text
+ stmt = text("SELECT id, name FROM user WHERE name=:name "
+ "AND timestamp=:timestamp")
+
+ the :meth:`_expression.TextClause.bindparams`
+ method can be used to establish
+ the initial value of ``:name`` and ``:timestamp``,
+ using simple keyword arguments::
+
+ stmt = stmt.bindparams(name='jack',
+ timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5))
+
+ Where above, new :class:`.BindParameter` objects
+ will be generated with the names ``name`` and ``timestamp``, and
+ values of ``jack`` and ``datetime.datetime(2012, 10, 8, 15, 12, 5)``,
+ respectively. The types will be
+ inferred from the values given, in this case :class:`.String` and
+ :class:`.DateTime`.
+
+ When specific typing behavior is needed, the positional ``*binds``
+ argument can be used to specify :func:`.bindparam` constructs
+ directly. These constructs must include at least the ``key``
+ argument, then an optional value and type::
+
+ from sqlalchemy import bindparam
+ stmt = stmt.bindparams(
+ bindparam('name', value='jack', type_=String),
+ bindparam('timestamp', type_=DateTime)
+ )
+
+ Above, we specified the type of :class:`.DateTime` for the
+ ``timestamp`` bind, and the type of :class:`.String` for the ``name``
+ bind. In the case of ``name`` we also set the default value of
+ ``"jack"``.
+
+ Additional bound parameters can be supplied at statement execution
+ time, e.g.::
+
+ result = connection.execute(stmt,
+ timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5))
+
+ The :meth:`_expression.TextClause.bindparams`
+ method can be called repeatedly,
+ where it will re-use existing :class:`.BindParameter` objects to add
+ new information. For example, we can call
+ :meth:`_expression.TextClause.bindparams`
+ first with typing information, and a
+ second time with value information, and it will be combined::
+
+ stmt = text("SELECT id, name FROM user WHERE name=:name "
+ "AND timestamp=:timestamp")
+ stmt = stmt.bindparams(
+ bindparam('name', type_=String),
+ bindparam('timestamp', type_=DateTime)
+ )
+ stmt = stmt.bindparams(
+ name='jack',
+ timestamp=datetime.datetime(2012, 10, 8, 15, 12, 5)
+ )
+
+ The :meth:`_expression.TextClause.bindparams`
+ method also supports the concept of
+ **unique** bound parameters. These are parameters that are
+ "uniquified" on name at statement compilation time, so that multiple
+ :func:`_expression.text`
+ constructs may be combined together without the names
+ conflicting. To use this feature, specify the
+ :paramref:`.BindParameter.unique` flag on each :func:`.bindparam`
+ object::
+
+ stmt1 = text("select id from table where name=:name").bindparams(
+ bindparam("name", value='name1', unique=True)
+ )
+ stmt2 = text("select id from table where name=:name").bindparams(
+ bindparam("name", value='name2', unique=True)
+ )
+
+ union = union_all(
+ stmt1.columns(column("id")),
+ stmt2.columns(column("id"))
+ )
+
+ The above statement will render as::
+
+ select id from table where name=:name_1
+ UNION ALL select id from table where name=:name_2
+
+ .. versionadded:: 1.3.11 Added support for the
+ :paramref:`.BindParameter.unique` flag to work with
+ :func:`_expression.text`
+ constructs.
+
+ """
+ self._bindparams = new_params = self._bindparams.copy()
+
+ for bind in binds:
+ try:
+ # the regex used for text() currently will not match
+ # a unique/anonymous key in any case, so use the _orig_key
+ # so that a text() construct can support unique parameters
+ existing = new_params[bind._orig_key]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "This text() construct doesn't define a "
+ "bound parameter named %r" % bind._orig_key
+ ),
+ replace_context=err,
+ )
+ else:
+ new_params[existing._orig_key] = bind
+
+ for key, value in names_to_values.items():
+ try:
+ existing = new_params[key]
+ except KeyError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "This text() construct doesn't define a "
+ "bound parameter named %r" % key
+ ),
+ replace_context=err,
+ )
+ else:
+ new_params[key] = existing._with_value(value, required=False)
+
+ @util.preload_module("sqlalchemy.sql.selectable")
+ def columns(self, *cols, **types):
+ r"""Turn this :class:`_expression.TextClause` object into a
+ :class:`_expression.TextualSelect`
+ object that serves the same role as a SELECT
+ statement.
+
+ The :class:`_expression.TextualSelect` is part of the
+ :class:`_expression.SelectBase`
+ hierarchy and can be embedded into another statement by using the
+ :meth:`_expression.TextualSelect.subquery` method to produce a
+ :class:`.Subquery`
+ object, which can then be SELECTed from.
+
+ This function essentially bridges the gap between an entirely
+ textual SELECT statement and the SQL expression language concept
+ of a "selectable"::
+
+ from sqlalchemy.sql import column, text
+
+ stmt = text("SELECT id, name FROM some_table")
+ stmt = stmt.columns(column('id'), column('name')).subquery('st')
+
+ stmt = select(mytable).\
+ select_from(
+ mytable.join(stmt, mytable.c.name == stmt.c.name)
+ ).where(stmt.c.id > 5)
+
+ Above, we pass a series of :func:`_expression.column` elements to the
+ :meth:`_expression.TextClause.columns` method positionally. These
+ :func:`_expression.column`
+ elements now become first class elements upon the
+ :attr:`_expression.TextualSelect.selected_columns` column collection,
+ which then
+ become part of the :attr:`.Subquery.c` collection after
+ :meth:`_expression.TextualSelect.subquery` is invoked.
+
+ The column expressions we pass to
+ :meth:`_expression.TextClause.columns` may
+ also be typed; when we do so, these :class:`.TypeEngine` objects become
+ the effective return type of the column, so that SQLAlchemy's
+ result-set-processing systems may be used on the return values.
+ This is often needed for types such as date or boolean types, as well
+ as for unicode processing on some dialect configurations::
+
+ stmt = text("SELECT id, name, timestamp FROM some_table")
+ stmt = stmt.columns(
+ column('id', Integer),
+ column('name', Unicode),
+ column('timestamp', DateTime)
+ )
+
+ for id, name, timestamp in connection.execute(stmt):
+ print(id, name, timestamp)
+
+ As a shortcut to the above syntax, keyword arguments referring to
+ types alone may be used, if only type conversion is needed::
+
+ stmt = text("SELECT id, name, timestamp FROM some_table")
+ stmt = stmt.columns(
+ id=Integer,
+ name=Unicode,
+ timestamp=DateTime
+ )
+
+ for id, name, timestamp in connection.execute(stmt):
+ print(id, name, timestamp)
+
+ The positional form of :meth:`_expression.TextClause.columns`
+ also provides the
+ unique feature of **positional column targeting**, which is
+ particularly useful when using the ORM with complex textual queries. If
+ we specify the columns from our model to
+ :meth:`_expression.TextClause.columns`,
+ the result set will match to those columns positionally, meaning the
+ name or origin of the column in the textual SQL doesn't matter::
+
+ stmt = text("SELECT users.id, addresses.id, users.id, "
+ "users.name, addresses.email_address AS email "
+ "FROM users JOIN addresses ON users.id=addresses.user_id "
+ "WHERE users.id = 1").columns(
+ User.id,
+ Address.id,
+ Address.user_id,
+ User.name,
+ Address.email_address
+ )
+
+ query = session.query(User).from_statement(stmt).options(
+ contains_eager(User.addresses))
+
+ .. versionadded:: 1.1 the :meth:`_expression.TextClause.columns`
+ method now
+ offers positional column targeting in the result set when
+ the column expressions are passed purely positionally.
+
+ The :meth:`_expression.TextClause.columns` method provides a direct
+ route to calling :meth:`_expression.FromClause.subquery` as well as
+ :meth:`_expression.SelectBase.cte`
+ against a textual SELECT statement::
+
+ stmt = stmt.columns(id=Integer, name=String).cte('st')
+
+ stmt = select(sometable).where(sometable.c.id == stmt.c.id)
+
+ :param \*cols: A series of :class:`_expression.ColumnElement` objects,
+ typically
+ :class:`_schema.Column` objects from a :class:`_schema.Table`
+ or ORM level
+ column-mapped attributes, representing a set of columns that this
+ textual string will SELECT from.
+
+ :param \**types: A mapping of string names to :class:`.TypeEngine`
+ type objects indicating the datatypes to use for names that are
+ SELECTed from the textual string. Prefer to use the ``*cols``
+ argument as it also indicates positional ordering.
+
+ """
+ selectable = util.preloaded.sql_selectable
+ positional_input_cols = [
+ ColumnClause(col.key, types.pop(col.key))
+ if col.key in types
+ else col
+ for col in cols
+ ]
+ keyed_input_cols = [
+ ColumnClause(key, type_) for key, type_ in types.items()
+ ]
+
+ return selectable.TextualSelect(
+ self,
+ positional_input_cols + keyed_input_cols,
+ positional=bool(positional_input_cols) and not keyed_input_cols,
+ )
+
+ @property
+ def type(self):
+ return type_api.NULLTYPE
+
+ @property
+ def comparator(self):
+ return self.type.comparator_factory(self)
+
+ def self_group(self, against=None):
+ if against is operators.in_op:
+ return Grouping(self)
+ else:
+ return self
+
+
+class Null(SingletonConstant, roles.ConstExprRole, ColumnElement):
+ """Represent the NULL keyword in a SQL statement.
+
+ :class:`.Null` is accessed as a constant via the
+ :func:`.null` function.
+
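+ A minimal usage sketch (``t`` is an illustrative table; a comparison
+ against :func:`.null` is rendered using ``IS NULL``)::
+
+ from sqlalchemy import null, select
+
+ stmt = select(t.c.x).where(t.c.y == null())
+ # renders roughly: SELECT t.x FROM t WHERE t.y IS NULL
+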
+ """
+
+ __visit_name__ = "null"
+
+ _traverse_internals = []
+
+ @util.memoized_property
+ def type(self):
+ return type_api.NULLTYPE
+
+ @classmethod
+ def _instance(cls):
+ """Return a constant :class:`.Null` construct."""
+
+ return Null()
+
+
+Null._create_singleton()
+
+
+class False_(SingletonConstant, roles.ConstExprRole, ColumnElement):
+ """Represent the ``false`` keyword, or equivalent, in a SQL statement.
+
+ :class:`.False_` is accessed as a constant via the
+ :func:`.false` function.
+
+ """
+
+ __visit_name__ = "false"
+ _traverse_internals = []
+
+ @util.memoized_property
+ def type(self):
+ return type_api.BOOLEANTYPE
+
+ def _negate(self):
+ return True_()
+
+ @classmethod
+ def _instance(cls):
+ """Return a :class:`.False_` construct.
+
+ E.g.::
+
+ >>> from sqlalchemy import false
+ >>> print(select(t.c.x).where(false()))
+ SELECT x FROM t WHERE false
+
+ A backend which does not support true/false constants will render as
+ an expression against 1 or 0::
+
+ >>> print(select(t.c.x).where(false()))
+ SELECT x FROM t WHERE 0 = 1
+
+ The :func:`.true` and :func:`.false` constants also feature
+ "short circuit" operation within an :func:`.and_` or :func:`.or_`
+ conjunction::
+
+ >>> print(select(t.c.x).where(or_(t.c.x > 5, true())))
+ SELECT x FROM t WHERE true
+
+ >>> print(select(t.c.x).where(and_(t.c.x > 5, false())))
+ SELECT x FROM t WHERE false
+
+ .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature
+ better integrated behavior within conjunctions and on dialects
+ that don't support true/false constants.
+
+ .. seealso::
+
+ :func:`.true`
+
+ """
+
+ return False_()
+
+
+False_._create_singleton()
+
+
+class True_(SingletonConstant, roles.ConstExprRole, ColumnElement):
+ """Represent the ``true`` keyword, or equivalent, in a SQL statement.
+
+ :class:`.True_` is accessed as a constant via the
+ :func:`.true` function.
+
+ """
+
+ __visit_name__ = "true"
+
+ _traverse_internals = []
+
+ @util.memoized_property
+ def type(self):
+ return type_api.BOOLEANTYPE
+
+ def _negate(self):
+ return False_()
+
+ @classmethod
+ def _ifnone(cls, other):
+ if other is None:
+ return cls._instance()
+ else:
+ return other
+
+ @classmethod
+ def _instance(cls):
+ """Return a constant :class:`.True_` construct.
+
+ E.g.::
+
+ >>> from sqlalchemy import true
+ >>> print(select(t.c.x).where(true()))
+ SELECT x FROM t WHERE true
+
+ A backend which does not support true/false constants will render as
+ an expression against 1 or 0::
+
+ >>> print(select(t.c.x).where(true()))
+ SELECT x FROM t WHERE 1 = 1
+
+ The :func:`.true` and :func:`.false` constants also feature
+ "short circuit" operation within an :func:`.and_` or :func:`.or_`
+ conjunction::
+
+ >>> print(select(t.c.x).where(or_(t.c.x > 5, true())))
+ SELECT x FROM t WHERE true
+
+ >>> print(select(t.c.x).where(and_(t.c.x > 5, false())))
+ SELECT x FROM t WHERE false
+
+ .. versionchanged:: 0.9 :func:`.true` and :func:`.false` feature
+ better integrated behavior within conjunctions and on dialects
+ that don't support true/false constants.
+
+ .. seealso::
+
+ :func:`.false`
+
+ """
+
+ return True_()
+
+
+True_._create_singleton()
+
+
+class ClauseList(
+ roles.InElementRole,
+ roles.OrderByRole,
+ roles.ColumnsClauseRole,
+ roles.DMLColumnRole,
+ ClauseElement,
+):
+ """Describe a list of clauses, separated by an operator.
+
+ By default, it is comma-separated, such as in a column listing.
+
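+ A minimal sketch of the default comma-separated behavior (this is
+ primarily an internal construct; ``column`` is the public
+ :func:`_expression.column` factory)::
+
+ from sqlalchemy import column
+ from sqlalchemy.sql.elements import ClauseList
+
+ cl = ClauseList(column("a"), column("b"), column("c"))
+ # str(cl) renders roughly as: a, b, c
+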
+ """
+
+ __visit_name__ = "clauselist"
+
+ _is_clause_list = True
+
+ _traverse_internals = [
+ ("clauses", InternalTraversal.dp_clauseelement_list),
+ ("operator", InternalTraversal.dp_operator),
+ ]
+
+ def __init__(self, *clauses, **kwargs):
+ self.operator = kwargs.pop("operator", operators.comma_op)
+ self.group = kwargs.pop("group", True)
+ self.group_contents = kwargs.pop("group_contents", True)
+ if kwargs.pop("_flatten_sub_clauses", False):
+ clauses = util.flatten_iterator(clauses)
+ self._text_converter_role = text_converter_role = kwargs.pop(
+ "_literal_as_text_role", roles.WhereHavingRole
+ )
+ if self.group_contents:
+ self.clauses = [
+ coercions.expect(
+ text_converter_role, clause, apply_propagate_attrs=self
+ ).self_group(against=self.operator)
+ for clause in clauses
+ ]
+ else:
+ self.clauses = [
+ coercions.expect(
+ text_converter_role, clause, apply_propagate_attrs=self
+ )
+ for clause in clauses
+ ]
+ self._is_implicitly_boolean = operators.is_boolean(self.operator)
+
+ @classmethod
+ def _construct_raw(cls, operator, clauses=None):
+ self = cls.__new__(cls)
+ self.clauses = clauses if clauses else []
+ self.group = True
+ self.operator = operator
+ self.group_contents = True
+ self._is_implicitly_boolean = False
+ return self
+
+ def __iter__(self):
+ return iter(self.clauses)
+
+ def __len__(self):
+ return len(self.clauses)
+
+ @property
+ def _select_iterable(self):
+ return itertools.chain.from_iterable(
+ [elem._select_iterable for elem in self.clauses]
+ )
+
+ def append(self, clause):
+ if self.group_contents:
+ self.clauses.append(
+ coercions.expect(self._text_converter_role, clause).self_group(
+ against=self.operator
+ )
+ )
+ else:
+ self.clauses.append(
+ coercions.expect(self._text_converter_role, clause)
+ )
+
+ @property
+ def _from_objects(self):
+ return list(itertools.chain(*[c._from_objects for c in self.clauses]))
+
+ def self_group(self, against=None):
+ if self.group and operators.is_precedent(self.operator, against):
+ return Grouping(self)
+ else:
+ return self
+
+
+class BooleanClauseList(ClauseList, ColumnElement):
+ __visit_name__ = "clauselist"
+ inherit_cache = True
+
+ def __init__(self, *arg, **kw):
+ raise NotImplementedError(
+ "BooleanClauseList has a private constructor"
+ )
+
+ @classmethod
+ def _process_clauses_for_boolean(
+ cls, operator, continue_on, skip_on, clauses
+ ):
+ has_continue_on = None
+
+ convert_clauses = []
+
+ against = operators._asbool
+ lcc = 0
+
+ for clause in clauses:
+ if clause is continue_on:
+ # instance of continue_on, like and_(x, y, True, z), store it
+ # if we didn't find one already, we will use it if there
+ # are no other expressions here.
+ has_continue_on = clause
+ elif clause is skip_on:
+ # instance of skip_on, e.g. and_(x, y, False, z), cancels
+ # the rest out
+ convert_clauses = [clause]
+ lcc = 1
+ break
+ else:
+ if not lcc:
+ lcc = 1
+ else:
+ against = operator
+ # technically this would be len(convert_clauses) + 1
+ # however this only needs to indicate "greater than one"
+ lcc = 2
+ convert_clauses.append(clause)
+
+ if not convert_clauses and has_continue_on is not None:
+ convert_clauses = [has_continue_on]
+ lcc = 1
+
+ return lcc, [c.self_group(against=against) for c in convert_clauses]
+
+ @classmethod
+ def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
+ lcc, convert_clauses = cls._process_clauses_for_boolean(
+ operator,
+ continue_on,
+ skip_on,
+ [
+ coercions.expect(roles.WhereHavingRole, clause)
+ for clause in util.coerce_generator_arg(clauses)
+ ],
+ )
+
+ if lcc > 1:
+ # multiple elements. Return regular BooleanClauseList
+ # which will link elements against the operator.
+ return cls._construct_raw(operator, convert_clauses)
+ elif lcc == 1:
+ # just one element. return it as a single boolean element,
+ # not a list and discard the operator.
+ return convert_clauses[0]
+ else:
+ # no elements period. deprecated use case. return an empty
+ # ClauseList construct that generates nothing unless it has
+ # elements added to it.
+ util.warn_deprecated(
+ "Invoking %(name)s() without arguments is deprecated, and "
+ "will be disallowed in a future release. For an empty "
+ "%(name)s() construct, use %(name)s(%(continue_on)s, *args)."
+ % {
+ "name": operator.__name__,
+ "continue_on": "True"
+ if continue_on is True_._singleton
+ else "False",
+ },
+ version="1.4",
+ )
+ return cls._construct_raw(operator)
+
+ @classmethod
+ def _construct_for_whereclause(cls, clauses):
+ operator, continue_on, skip_on = (
+ operators.and_,
+ True_._singleton,
+ False_._singleton,
+ )
+
+ lcc, convert_clauses = cls._process_clauses_for_boolean(
+ operator,
+ continue_on,
+ skip_on,
+ clauses, # these are assumed to be coerced already
+ )
+
+ if lcc > 1:
+ # multiple elements. Return regular BooleanClauseList
+ # which will link elements against the operator.
+ return cls._construct_raw(operator, convert_clauses)
+ elif lcc == 1:
+ # just one element. return it as a single boolean element,
+ # not a list and discard the operator.
+ return convert_clauses[0]
+ else:
+ return None
+
+ @classmethod
+ def _construct_raw(cls, operator, clauses=None):
+ self = cls.__new__(cls)
+ self.clauses = clauses if clauses else []
+ self.group = True
+ self.operator = operator
+ self.group_contents = True
+ self.type = type_api.BOOLEANTYPE
+ self._is_implicitly_boolean = True
+ return self
+
+ @classmethod
+ def and_(cls, *clauses):
+ r"""Produce a conjunction of expressions joined by ``AND``.
+
+ E.g.::
+
+ from sqlalchemy import and_
+
+ stmt = select(users_table).where(
+ and_(
+ users_table.c.name == 'wendy',
+ users_table.c.enrolled == True
+ )
+ )
+
+ The :func:`.and_` conjunction is also available using the
+ Python ``&`` operator (though note that compound expressions
+ need to be parenthesized in order to function with Python
+ operator precedence behavior)::
+
+ stmt = select(users_table).where(
+ (users_table.c.name == 'wendy') &
+ (users_table.c.enrolled == True)
+ )
+
+ The :func:`.and_` operation is also implicit in some cases;
+ the :meth:`_expression.Select.where`
+ method for example can be invoked multiple
+ times against a statement, which will have the effect of each
+ clause being combined using :func:`.and_`::
+
+ stmt = select(users_table).\
+ where(users_table.c.name == 'wendy').\
+ where(users_table.c.enrolled == True)
+
+ The :func:`.and_` construct must be given at least one positional
+ argument in order to be valid; an :func:`.and_` construct with no
+ arguments is ambiguous. To produce an "empty" or dynamically
+ generated :func:`.and_` expression, from a given list of expressions,
+ a "default" element of ``True`` should be specified::
+
+ criteria = and_(True, *expressions)
+
+ The above expression will compile to SQL as the expression ``true``
+ or ``1 = 1``, depending on backend, if no other expressions are
+ present. If expressions are present, then the ``True`` value is
+ ignored as it does not affect the outcome of an AND expression that
+ has other elements.
+
+ .. deprecated:: 1.4 The :func:`.and_` element now requires that at
+ least one argument is passed; creating the :func:`.and_` construct
+ with no arguments is deprecated, and will emit a deprecation warning
+ while continuing to produce a blank SQL string.
+
+ .. seealso::
+
+ :func:`.or_`
+
+ """
+ return cls._construct(
+ operators.and_, True_._singleton, False_._singleton, *clauses
+ )
+
+ @classmethod
+ def or_(cls, *clauses):
+ """Produce a conjunction of expressions joined by ``OR``.
+
+ E.g.::
+
+ from sqlalchemy import or_
+
+ stmt = select(users_table).where(
+ or_(
+ users_table.c.name == 'wendy',
+ users_table.c.name == 'jack'
+ )
+ )
+
+ The :func:`.or_` conjunction is also available using the
+ Python ``|`` operator (though note that compound expressions
+ need to be parenthesized in order to function with Python
+ operator precedence behavior)::
+
+ stmt = select(users_table).where(
+ (users_table.c.name == 'wendy') |
+ (users_table.c.name == 'jack')
+ )
+
+ The :func:`.or_` construct must be given at least one positional
+ argument in order to be valid; an :func:`.or_` construct with no
+ arguments is ambiguous. To produce an "empty" or dynamically
+ generated :func:`.or_` expression, from a given list of expressions,
+ a "default" element of ``False`` should be specified::
+
+ or_criteria = or_(False, *expressions)
+
+ The above expression will compile to SQL as the expression ``false``
+ or ``0 = 1``, depending on backend, if no other expressions are
+ present. If expressions are present, then the ``False`` value is
+ ignored as it does not affect the outcome of an OR expression which
+ has other elements.
+
+ .. deprecated:: 1.4 The :func:`.or_` element now requires that at
+ least one argument is passed; creating the :func:`.or_` construct
+ with no arguments is deprecated, and will emit a deprecation warning
+ while continuing to produce a blank SQL string.
+
+ .. seealso::
+
+ :func:`.and_`
+
+ """
+ return cls._construct(
+ operators.or_, False_._singleton, True_._singleton, *clauses
+ )
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ def self_group(self, against=None):
+ if not self.clauses:
+ return self
+ else:
+ return super(BooleanClauseList, self).self_group(against=against)
+
+ def _negate(self):
+ return ClauseList._negate(self)
+
+
+and_ = BooleanClauseList.and_
+or_ = BooleanClauseList.or_
+
+
+class Tuple(ClauseList, ColumnElement):
+ """Represent a SQL tuple."""
+
+ __visit_name__ = "tuple"
+
+ _traverse_internals = ClauseList._traverse_internals + []
+
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def __init__(self, *clauses, **kw):
+ """Return a :class:`.Tuple`.
+
+ Main usage is to produce a composite IN construct using
+ :meth:`.ColumnOperators.in_` ::
+
+ from sqlalchemy import tuple_
+
+ tuple_(table.c.col1, table.c.col2).in_(
+ [(1, 2), (5, 12), (10, 19)]
+ )
+
+ .. versionchanged:: 1.3.6 Added support for SQLite IN tuples.
+
+ .. warning::
+
+ The composite IN construct is not supported by all backends, and is
+ currently known to work on PostgreSQL, MySQL, and SQLite.
+ Unsupported backends will raise a subclass of
+ :class:`~sqlalchemy.exc.DBAPIError` when such an expression is
+ invoked.
+
+ """
+ sqltypes = util.preloaded.sql_sqltypes
+
+ types = kw.pop("types", None)
+ if types is None:
+ clauses = [
+ coercions.expect(roles.ExpressionElementRole, c)
+ for c in clauses
+ ]
+ else:
+ if len(types) != len(clauses):
+ raise exc.ArgumentError(
+ "Wrong number of elements for %d-tuple: %r "
+ % (len(types), clauses)
+ )
+ clauses = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ type_=typ if not typ._isnull else None,
+ )
+ for typ, c in zip(types, clauses)
+ ]
+
+ self.type = sqltypes.TupleType(*[arg.type for arg in clauses])
+ super(Tuple, self).__init__(*clauses, **kw)
+
+ @property
+ def _select_iterable(self):
+ return (self,)
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ if expanding:
+ return BindParameter(
+ None,
+ value=obj,
+ _compared_to_operator=operator,
+ unique=True,
+ expanding=True,
+ type_=self.type,
+ )
+ else:
+ return Tuple(
+ *[
+ BindParameter(
+ None,
+ o,
+ _compared_to_operator=operator,
+ _compared_to_type=compared_to_type,
+ unique=True,
+ type_=type_,
+ )
+ for o, compared_to_type in zip(obj, self.type.types)
+ ]
+ )
+
+ def self_group(self, against=None):
+ # Tuple is parenthesized by definition.
+ return self
+
+
+class Case(ColumnElement):
+ """Represent a ``CASE`` expression.
+
+ :class:`.Case` is produced using the :func:`.case` factory function,
+ as in::
+
+ from sqlalchemy import case
+
+ stmt = select(users_table).\
+ where(
+ case(
+ (users_table.c.name == 'wendy', 'W'),
+ (users_table.c.name == 'jack', 'J'),
+ else_='E'
+ )
+ )
+
+ Details on :class:`.Case` usage are at :func:`.case`.
+
+ .. seealso::
+
+ :func:`.case`
+
+ """
+
+ __visit_name__ = "case"
+
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
+ # TODO: for Py2k removal, this will be:
+ # def __init__(self, *whens, value=None, else_=None):
+
+ def __init__(self, *whens, **kw):
+ r"""Produce a ``CASE`` expression.
+
+ The ``CASE`` construct in SQL is a conditional object that
+ acts somewhat analogously to an "if/then" construct in other
+ languages. It returns an instance of :class:`.Case`.
+
+ :func:`.case` in its usual form is passed a series of "when"
+ constructs, that is, a list of conditions and results as tuples::
+
+ from sqlalchemy import case
+
+ stmt = select(users_table).\
+ where(
+ case(
+ (users_table.c.name == 'wendy', 'W'),
+ (users_table.c.name == 'jack', 'J'),
+ else_='E'
+ )
+ )
+
+ The above statement will produce SQL resembling::
+
+ SELECT id, name FROM user
+ WHERE CASE
+ WHEN (name = :name_1) THEN :param_1
+ WHEN (name = :name_2) THEN :param_2
+ ELSE :param_3
+ END
+
+ When simple equality expressions of several values against a single
+ parent column are needed, :func:`.case` also has a "shorthand" format
+ used via the
+ :paramref:`.case.value` parameter, which is passed a column
+ expression to be compared. In this form, the :paramref:`.case.whens`
+ parameter is passed as a dictionary containing expressions to be
+ compared against keyed to result expressions. The statement below is
+ equivalent to the preceding statement::
+
+ stmt = select(users_table).\
+ where(
+ case(
+ {"wendy": "W", "jack": "J"},
+ value=users_table.c.name,
+ else_='E'
+ )
+ )
+
+ The values which are accepted as result values in
+ :paramref:`.case.whens` as well as with :paramref:`.case.else_` are
+ coerced from Python literals into :func:`.bindparam` constructs.
+ SQL expressions, e.g. :class:`_expression.ColumnElement` constructs,
+ are accepted
+ as well. To coerce a literal string expression into a constant
+ expression rendered inline, use the :func:`_expression.literal_column`
+ construct,
+ as in::
+
+ from sqlalchemy import case, literal_column
+
+ case(
+ (
+ orderline.c.qty > 100,
+ literal_column("'greaterthan100'")
+ ),
+ (
+ orderline.c.qty > 10,
+ literal_column("'greaterthan10'")
+ ),
+ else_=literal_column("'lessthan10'")
+ )
+
+ The above will render the given constants without using bound
+ parameters for the result values (but still for the comparison
+ values), as in::
+
+ CASE
+ WHEN (orderline.qty > :qty_1) THEN 'greaterthan100'
+ WHEN (orderline.qty > :qty_2) THEN 'greaterthan10'
+ ELSE 'lessthan10'
+ END
+
+ :param \*whens: The criteria to be compared against;
+ :paramref:`.case.whens` accepts two different forms, based on
+ whether or not :paramref:`.case.value` is used.
+
+ .. versionchanged:: 1.4 the :func:`_sql.case`
+ function now accepts the series of WHEN conditions positionally;
+ passing the expressions within a list is deprecated.
+
+ In the first form, it accepts a list of 2-tuples; each 2-tuple
+ consists of ``(<sql expression>, <value>)``, where the SQL
+ expression is a boolean expression and "value" is a resulting value,
+ e.g.::
+
+ case(
+ (users_table.c.name == 'wendy', 'W'),
+ (users_table.c.name == 'jack', 'J')
+ )
+
+ In the second form, it accepts a Python dictionary of comparison
+ values mapped to a resulting value; this form requires
+ :paramref:`.case.value` to be present, and values will be compared
+ using the ``==`` operator, e.g.::
+
+ case(
+ {"wendy": "W", "jack": "J"},
+ value=users_table.c.name
+ )
+
+ :param value: An optional SQL expression which will be used as a
+ fixed "comparison point" for candidate values within a dictionary
+ passed to :paramref:`.case.whens`.
+
+ :param else\_: An optional SQL expression which will be the evaluated
+ result of the ``CASE`` construct if all expressions within
+ :paramref:`.case.whens` evaluate to false. When omitted, most
+ databases will produce a result of NULL if none of the "when"
+ expressions evaluate to true.
+
+
+ """
+
+ if "whens" in kw:
+ util.warn_deprecated_20(
+ 'The "whens" argument to case() is now passed using '
+ "positional style only, not as a keyword argument."
+ )
+ whens = (kw.pop("whens"),)
+
+ whens = coercions._expression_collection_was_a_list(
+ "whens", "case", whens
+ )
+
+ try:
+ whens = util.dictlike_iteritems(whens)
+ except TypeError:
+ pass
+
+ value = kw.pop("value", None)
+
+ whenlist = [
+ (
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ apply_propagate_attrs=self,
+ ).self_group(),
+ coercions.expect(roles.ExpressionElementRole, r),
+ )
+ for (c, r) in whens
+ ]
+
+ if whenlist:
+ type_ = list(whenlist[-1])[-1].type
+ else:
+ type_ = None
+
+ if value is None:
+ self.value = None
+ else:
+ self.value = coercions.expect(roles.ExpressionElementRole, value)
+
+ self.type = type_
+ self.whens = whenlist
+
+ else_ = kw.pop("else_", None)
+ if else_ is not None:
+ self.else_ = coercions.expect(roles.ExpressionElementRole, else_)
+ else:
+ self.else_ = None
+
+ if kw:
+ raise TypeError("unknown arguments: %s" % (", ".join(sorted(kw))))
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(*[x._from_objects for x in self.get_children()])
+ )
+
+
+def literal_column(text, type_=None):
+ r"""Produce a :class:`.ColumnClause` object that has the
+ :paramref:`_expression.column.is_literal` flag set to True.
+
+ :func:`_expression.literal_column` is similar to
+ :func:`_expression.column`, except that
+ it is more often used as a "standalone" column expression that renders
+ exactly as stated; while :func:`_expression.column`
+ stores a string name that
+ will be assumed to be part of a table and may be quoted as such,
+ :func:`_expression.literal_column` can be that,
+ or any other arbitrary column-oriented
+ expression.
+
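+ For example (a minimal sketch; ``some_table`` is an illustrative
+ :class:`_schema.Table`)::
+
+ from sqlalchemy import literal_column, select
+
+ stmt = select(literal_column("'fixed value'"), some_table.c.id)
+ # renders roughly: SELECT 'fixed value', some_table.id FROM some_table
+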
+ :param text: the text of the expression; can be any SQL expression.
+ Quoting rules will not be applied. To specify a column-name expression
+ which should be subject to quoting rules, use the :func:`column`
+ function.
+
+ :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine`
+ object which will
+ provide result-set translation and additional expression semantics for
+ this column. If left as ``None`` the type will be :class:`.NullType`.
+
+ .. seealso::
+
+ :func:`_expression.column`
+
+ :func:`_expression.text`
+
+ :ref:`sqlexpression_literal_column`
+
+ """
+ return ColumnClause(text, type_=type_, is_literal=True)
+
+
+class Cast(WrapsColumnExpression, ColumnElement):
+ """Represent a ``CAST`` expression.
+
+ :class:`.Cast` is produced using the :func:`.cast` factory function,
+ as in::
+
+ from sqlalchemy import cast, Numeric
+
+ stmt = select(cast(product_table.c.unit_price, Numeric(10, 4)))
+
+ Details on :class:`.Cast` usage are at :func:`.cast`.
+
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`.cast`
+
+ :func:`.type_coerce` - an alternative to CAST that coerces the type
+ on the Python side only, which is often sufficient to generate the
+ correct SQL and data coercion.
+
+ """
+
+ __visit_name__ = "cast"
+
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("typeclause", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, expression, type_):
+ r"""Produce a ``CAST`` expression.
+
+ :func:`.cast` returns an instance of :class:`.Cast`.
+
+ E.g.::
+
+ from sqlalchemy import cast, Numeric
+
+ stmt = select(cast(product_table.c.unit_price, Numeric(10, 4)))
+
+ The above statement will produce SQL resembling::
+
+ SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product
+
+ The :func:`.cast` function performs two distinct functions when
+ used. The first is that it renders the ``CAST`` expression within
+ the resulting SQL string. The second is that it associates the given
+ type (e.g. :class:`.TypeEngine` class or instance) with the column
+ expression on the Python side, which means the expression will take
+ on the expression operator behavior associated with that type,
+ as well as the bound-value handling and result-row-handling behavior
+ of the type.
+
+ .. versionchanged:: 0.9.0 :func:`.cast` now applies the given type
+ to the expression such that it takes effect on the bound-value,
+ e.g. the Python-to-database direction, in addition to the
+ result handling, e.g. database-to-Python, direction.
+
+ An alternative to :func:`.cast` is the :func:`.type_coerce` function.
+ This function performs the second task of associating an expression
+ with a specific type, but does not render the ``CAST`` expression
+ in SQL.
+
+ :param expression: A SQL expression, such as a
+ :class:`_expression.ColumnElement`
+ expression or a Python string which will be coerced into a bound
+ literal value.
+
+ :param type\_: A :class:`.TypeEngine` class or instance indicating
+ the type to which the ``CAST`` should apply.
+
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`.type_coerce` - an alternative to CAST that coerces the type
+ on the Python side only, which is often sufficient to generate the
+ correct SQL and data coercion.
+
+
+ """
+ self.type = type_api.to_instance(type_)
+ self.clause = coercions.expect(
+ roles.ExpressionElementRole,
+ expression,
+ type_=self.type,
+ apply_propagate_attrs=self,
+ )
+ self.typeclause = TypeClause(self.type)
+
+ @property
+ def _from_objects(self):
+ return self.clause._from_objects
+
+ @property
+ def wrapped_column_expression(self):
+ return self.clause
+
+
+class TypeCoerce(WrapsColumnExpression, ColumnElement):
+ """Represent a Python-side type-coercion wrapper.
+
+ :class:`.TypeCoerce` supplies the :func:`_expression.type_coerce`
+ function; see that function for usage details.
+
+ .. versionchanged:: 1.1 The :func:`.type_coerce` function now produces
+ a persistent :class:`.TypeCoerce` wrapper object rather than
+ translating the given object in place.
+
+ .. seealso::
+
+ :func:`_expression.type_coerce`
+
+ :func:`.cast`
+
+ """
+
+ __visit_name__ = "type_coerce"
+
+ _traverse_internals = [
+ ("clause", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ def __init__(self, expression, type_):
+ r"""Associate a SQL expression with a particular type, without rendering
+ ``CAST``.
+
+ E.g.::
+
+ from sqlalchemy import type_coerce
+
+ stmt = select(type_coerce(log_table.date_string, StringDateTime()))
+
+ The above construct will produce a :class:`.TypeCoerce` object, which
+ does not modify the rendering in any way on the SQL side, with the
+ possible exception of a generated label if used in a columns clause
+ context::
+
+ SELECT date_string AS date_string FROM log
+
+ When result rows are fetched, the ``StringDateTime`` type processor
+ will be applied to result rows on behalf of the ``date_string`` column.
+
+ .. note:: the :func:`.type_coerce` construct does not render any
+ SQL syntax of its own, nor does it imply
+ parenthesization. Please use :meth:`.TypeCoerce.self_group`
+ if explicit parenthesization is required.
+
+ In order to provide a named label for the expression, use
+ :meth:`_expression.ColumnElement.label`::
+
+ stmt = select(
+ type_coerce(log_table.date_string, StringDateTime()).label('date')
+ )
+
+
+ A type that features bound-value handling will also have that behavior
+ take effect when literal values or :func:`.bindparam` constructs are
+ passed to :func:`.type_coerce` as targets.
+ For example, if a type implements the
+ :meth:`.TypeEngine.bind_expression`
+ method or :meth:`.TypeEngine.bind_processor` method or equivalent,
+ these functions will take effect at statement compilation/execution
+ time when a literal value is passed, as in::
+
+ # bound-value handling of MyStringType will be applied to the
+ # literal value "some string"
+ stmt = select(type_coerce("some string", MyStringType))
+
+ When using :func:`.type_coerce` with composed expressions, note that
+ **parentheses are not applied**. If :func:`.type_coerce` is being
+ used in an operator context where the parentheses normally present from
+ CAST are necessary, use the :meth:`.TypeCoerce.self_group` method::
+
+ >>> some_integer = column("someint", Integer)
+ >>> some_string = column("somestr", String)
+ >>> expr = type_coerce(some_integer + 5, String) + some_string
+ >>> print(expr)
+ someint + :someint_1 || somestr
+ >>> expr = type_coerce(some_integer + 5, String).self_group() + some_string
+ >>> print(expr)
+ (someint + :someint_1) || somestr
+
+ :param expression: A SQL expression, such as a
+ :class:`_expression.ColumnElement`
+ expression or a Python string which will be coerced into a bound
+ literal value.
+
+ :param type\_: A :class:`.TypeEngine` class or instance indicating
+ the type to which the expression is coerced.
+
+ .. seealso::
+
+ :ref:`tutorial_casts`
+
+ :func:`.cast`
+
+ """ # noqa
+ self.type = type_api.to_instance(type_)
+ self.clause = coercions.expect(
+ roles.ExpressionElementRole,
+ expression,
+ type_=self.type,
+ apply_propagate_attrs=self,
+ )
+
+ @property
+ def _from_objects(self):
+ return self.clause._from_objects
+
+ @HasMemoized.memoized_attribute
+ def typed_expression(self):
+ if isinstance(self.clause, BindParameter):
+ bp = self.clause._clone()
+ bp.type = self.type
+ return bp
+ else:
+ return self.clause
+
+ @property
+ def wrapped_column_expression(self):
+ return self.clause
+
+ def self_group(self, against=None):
+ grouped = self.clause.self_group(against=against)
+ if grouped is not self.clause:
+ return TypeCoerce(grouped, self.type)
+ else:
+ return self
+
+
+class Extract(ColumnElement):
+ """Represent a SQL EXTRACT clause, ``extract(field FROM expr)``."""
+
+ __visit_name__ = "extract"
+
+ _traverse_internals = [
+ ("expr", InternalTraversal.dp_clauseelement),
+ ("field", InternalTraversal.dp_string),
+ ]
+
+ def __init__(self, field, expr, **kwargs):
+ """Return a :class:`.Extract` construct.
+
+ This is typically available as :func:`.extract`
+ as well as ``func.extract`` from the
+ :data:`.func` namespace.
+
+ :param field: The field to extract.
+
+ :param expr: A column or Python scalar expression serving as the
+ right side of the ``EXTRACT`` expression.
+
+ E.g.::
+
+ from sqlalchemy import extract
+ from sqlalchemy import table, column
+
+ logged_table = table("user",
+ column("id"),
+ column("date_created"),
+ )
+
+ stmt = select(logged_table.c.id).where(
+ extract("YEAR", logged_table.c.date_created) == 2021
+ )
+
+ In the above example, the statement is used to select ids from the
+ database where the ``YEAR`` component matches a specific value.
+
+ Similarly, one can also select an extracted component::
+
+ stmt = select(
+ extract("YEAR", logged_table.c.date_created)
+ ).where(logged_table.c.id == 1)
+
+ The implementation of ``EXTRACT`` may vary across database backends.
+ Users are reminded to consult their database documentation.
+ """
+ self.type = type_api.INTEGERTYPE
+ self.field = field
+ self.expr = coercions.expect(roles.ExpressionElementRole, expr)
+
+ @property
+ def _from_objects(self):
+ return self.expr._from_objects
+
+
+class _label_reference(ColumnElement):
+ """Wrap a column expression as it appears in a 'reference' context.
+
+ This applies to any expression that includes an _order_by_label_element,
+ which is a Label, or a DESC / ASC construct wrapping a Label.
+
+ The production of _label_reference() should occur when an expression
+ is added to this context; this includes the ORDER BY or GROUP BY of a
+ SELECT statement, as well as a few other places, such as the ORDER BY
+ within an OVER clause.
+
+ """
+
+ __visit_name__ = "label_reference"
+
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
+ def __init__(self, element):
+ self.element = element
+
+ @property
+ def _from_objects(self):
+ return ()
+
+
+class _textual_label_reference(ColumnElement):
+ __visit_name__ = "textual_label_reference"
+
+ _traverse_internals = [("element", InternalTraversal.dp_string)]
+
+ def __init__(self, element):
+ self.element = element
+
+ @util.memoized_property
+ def _text_clause(self):
+ return TextClause._create_text(self.element)
+
+
+class UnaryExpression(ColumnElement):
+ """Define a 'unary' expression.
+
+ A unary expression has a single column expression
+ and an operator. The operator can be placed on the left
+ (where it is called the 'operator') or right (where it is called the
+ 'modifier') of the column expression.
+
+ :class:`.UnaryExpression` is the basis for several unary operators
+ including those used by :func:`.desc`, :func:`.asc`, :func:`.distinct`,
+ :func:`.nulls_first` and :func:`.nulls_last`.
+
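+ A minimal sketch of how such an expression arises (``users_table`` is an
+ illustrative table; :func:`.desc` constructs a :class:`.UnaryExpression`
+ with ``modifier=operators.desc_op``)::
+
+ from sqlalchemy import desc
+
+ expr = desc(users_table.c.name)
+ # renders roughly as: name DESC
+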
+ """
+
+ __visit_name__ = "unary"
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("modifier", InternalTraversal.dp_operator),
+ ]
+
+ def __init__(
+ self,
+ element,
+ operator=None,
+ modifier=None,
+ type_=None,
+ wraps_column_expression=False,
+ ):
+ self.operator = operator
+ self.modifier = modifier
+ self._propagate_attrs = element._propagate_attrs
+ self.element = element.self_group(
+ against=self.operator or self.modifier
+ )
+ self.type = type_api.to_instance(type_)
+ self.wraps_column_expression = wraps_column_expression
+
+ @classmethod
+ def _create_nulls_first(cls, column):
+ """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression.
+
+ :func:`.nulls_first` is intended to modify the expression produced
+ by :func:`.asc` or :func:`.desc`, and indicates how NULL values
+ should be handled when they are encountered during ordering::
+
+
+ from sqlalchemy import desc, nulls_first
+
+ stmt = select(users_table).order_by(
+ nulls_first(desc(users_table.c.name)))
+
+ The SQL expression from the above would resemble::
+
+ SELECT id, name FROM user ORDER BY name DESC NULLS FIRST
+
+ Like :func:`.asc` and :func:`.desc`, :func:`.nulls_first` is typically
+ invoked from the column expression itself using
+ :meth:`_expression.ColumnElement.nulls_first`,
+ rather than as its standalone
+ function version, as in::
+
+ stmt = select(users_table).order_by(
+ users_table.c.name.desc().nulls_first())
+
+ .. versionchanged:: 1.4 :func:`.nulls_first` is renamed from
+ :func:`.nullsfirst` in previous releases.
+ The previous name remains available for backwards compatibility.
+
+ .. seealso::
+
+ :func:`.asc`
+
+ :func:`.desc`
+
+ :func:`.nulls_last`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.nulls_first_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_nulls_last(cls, column):
+ """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression.
+
+ :func:`.nulls_last` is intended to modify the expression produced
+ by :func:`.asc` or :func:`.desc`, and indicates how NULL values
+ should be handled when they are encountered during ordering::
+
+
+ from sqlalchemy import desc, nulls_last
+
+ stmt = select(users_table).order_by(
+ nulls_last(desc(users_table.c.name)))
+
+ The SQL expression from the above would resemble::
+
+ SELECT id, name FROM user ORDER BY name DESC NULLS LAST
+
+ Like :func:`.asc` and :func:`.desc`, :func:`.nulls_last` is typically
+ invoked from the column expression itself using
+ :meth:`_expression.ColumnElement.nulls_last`,
+ rather than as its standalone
+ function version, as in::
+
+ stmt = select(users_table).order_by(
+ users_table.c.name.desc().nulls_last())
+
+ .. versionchanged:: 1.4 :func:`.nulls_last` is renamed from
+ :func:`.nullslast` in previous releases.
+ The previous name remains available for backwards compatibility.
+
+ .. seealso::
+
+ :func:`.asc`
+
+ :func:`.desc`
+
+ :func:`.nulls_first`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.nulls_last_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_desc(cls, column):
+ """Produce a descending ``ORDER BY`` clause element.
+
+ e.g.::
+
+ from sqlalchemy import desc
+
+ stmt = select(users_table).order_by(desc(users_table.c.name))
+
+ will produce SQL as::
+
+ SELECT id, name FROM user ORDER BY name DESC
+
+ The :func:`.desc` function is a standalone version of the
+ :meth:`_expression.ColumnElement.desc`
+ method available on all SQL expressions,
+ e.g.::
+
+
+ stmt = select(users_table).order_by(users_table.c.name.desc())
+
+ :param column: A :class:`_expression.ColumnElement` (e.g.
+ scalar SQL expression)
+ with which to apply the :func:`.desc` operation.
+
+ .. seealso::
+
+ :func:`.asc`
+
+ :func:`.nulls_first`
+
+ :func:`.nulls_last`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.desc_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_asc(cls, column):
+ """Produce an ascending ``ORDER BY`` clause element.
+
+ e.g.::
+
+ from sqlalchemy import asc
+ stmt = select(users_table).order_by(asc(users_table.c.name))
+
+ will produce SQL as::
+
+ SELECT id, name FROM user ORDER BY name ASC
+
+ The :func:`.asc` function is a standalone version of the
+ :meth:`_expression.ColumnElement.asc`
+ method available on all SQL expressions,
+ e.g.::
+
+
+ stmt = select(users_table).order_by(users_table.c.name.asc())
+
+ :param column: A :class:`_expression.ColumnElement` (e.g.
+ scalar SQL expression)
+ with which to apply the :func:`.asc` operation.
+
+ .. seealso::
+
+ :func:`.desc`
+
+ :func:`.nulls_first`
+
+ :func:`.nulls_last`
+
+ :meth:`_expression.Select.order_by`
+
+ """
+ return UnaryExpression(
+ coercions.expect(roles.ByOfRole, column),
+ modifier=operators.asc_op,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_distinct(cls, expr):
+ """Produce an column-expression-level unary ``DISTINCT`` clause.
+
+ This applies the ``DISTINCT`` keyword to an individual column
+ expression, and is typically contained within an aggregate function,
+ as in::
+
+ from sqlalchemy import distinct, func
+ stmt = select(func.count(distinct(users_table.c.name)))
+
+ The above would produce an expression resembling::
+
+ SELECT COUNT(DISTINCT name) FROM user
+
+ The :func:`.distinct` function is also available as a column-level
+ method, e.g. :meth:`_expression.ColumnElement.distinct`, as in::
+
+ stmt = select(func.count(users_table.c.name.distinct()))
+
+ The :func:`.distinct` operator is different from the
+ :meth:`_expression.Select.distinct` method of
+ :class:`_expression.Select`,
+ which produces a ``SELECT`` statement
+ with ``DISTINCT`` applied to the result set as a whole,
+ e.g. a ``SELECT DISTINCT`` expression. See that method for further
+ information.
+
+ .. seealso::
+
+ :meth:`_expression.ColumnElement.distinct`
+
+ :meth:`_expression.Select.distinct`
+
+ :data:`.func`
+
+ """
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+ return UnaryExpression(
+ expr,
+ operator=operators.distinct_op,
+ type_=expr.type,
+ wraps_column_expression=False,
+ )
+
+ @property
+ def _order_by_label_element(self):
+ if self.modifier in (operators.desc_op, operators.asc_op):
+ return self.element._order_by_label_element
+ else:
+ return None
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def _negate(self):
+ if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
+ return UnaryExpression(
+ self.self_group(against=operators.inv),
+ operator=operators.inv,
+ type_=type_api.BOOLEANTYPE,
+ wraps_column_expression=self.wraps_column_expression,
+ )
+ else:
+ return ClauseElement._negate(self)
+
+ def self_group(self, against=None):
+ if self.operator and operators.is_precedent(self.operator, against):
+ return Grouping(self)
+ else:
+ return self
+
+
+class CollectionAggregate(UnaryExpression):
+ """Forms the basis for right-hand collection operator modifiers
+ ANY and ALL.
+
+ The ANY and ALL keywords are available in different ways on different
+ backends. On PostgreSQL, they only work for an ARRAY type. On
+ MySQL, they only work for subqueries.
+
+ """
+
+ inherit_cache = True
+
+ @classmethod
+ def _create_any(cls, expr):
+ """Produce an ANY expression.
+
+ For dialects such as that of PostgreSQL, this operator applies
+ to usage of the :class:`_types.ARRAY` datatype, for that of
+ MySQL, it may apply to a subquery. e.g.::
+
+ # renders on PostgreSQL:
+ # '5 = ANY (somearray)'
+ expr = 5 == any_(mytable.c.somearray)
+
+ # renders on MySQL:
+ # '5 = ANY (SELECT value FROM table)'
+ expr = 5 == any_(select(table.c.value))
+
+ Comparison to NULL may work using ``None`` or :func:`_sql.null`::
+
+ None == any_(mytable.c.somearray)
+
+ The any_() / all_() operators also feature a special "operand flipping"
+ behavior such that if any_() / all_() are used on the left side of a
+ comparison using a standalone operator such as ``==``, ``!=``, etc.
+ (not including operator methods such as
+ :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped::
+
+ # would render '5 = ANY (column)'
+ any_(mytable.c.column) == 5
+
+ Or with ``None``; note that this will not perform
+ the usual step of rendering "IS" as is normally the case for NULL::
+
+ # would render 'NULL = ANY(somearray)'
+ any_(mytable.c.somearray) == None
+
+ .. versionchanged:: 1.4.26 repaired the use of any_() / all_()
+ comparing to NULL on the right side to be flipped to the left.
+
+ The column-level :meth:`_sql.ColumnElement.any_` method (not to be
+ confused with :class:`_types.ARRAY` level
+ :meth:`_types.ARRAY.Comparator.any`) is shorthand for
+ ``any_(col)``::
+
+ 5 == mytable.c.somearray.any_()
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.any_`
+
+ :func:`_expression.all_`
+
+ """
+
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+
+ expr = expr.self_group()
+ return CollectionAggregate(
+ expr,
+ operator=operators.any_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
+
+ @classmethod
+ def _create_all(cls, expr):
+ """Produce an ALL expression.
+
+ For dialects such as that of PostgreSQL, this operator applies
+ to usage of the :class:`_types.ARRAY` datatype, for that of
+ MySQL, it may apply to a subquery. e.g.::
+
+ # renders on PostgreSQL:
+ # '5 = ALL (somearray)'
+ expr = 5 == all_(mytable.c.somearray)
+
+ # renders on MySQL:
+ # '5 = ALL (SELECT value FROM table)'
+ expr = 5 == all_(select(table.c.value))
+
+ Comparison to NULL may work using ``None``::
+
+ None == all_(mytable.c.somearray)
+
+ The any_() / all_() operators also feature a special "operand flipping"
+ behavior such that if any_() / all_() are used on the left side of a
+ comparison using a standalone operator such as ``==``, ``!=``, etc.
+ (not including operator methods such as
+ :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped::
+
+ # would render '5 = ALL (column)'
+ all_(mytable.c.column) == 5
+
+ Or with ``None``; note that this will not perform
+ the usual step of rendering "IS" as is normally the case for NULL::
+
+ # would render 'NULL = ALL(somearray)'
+ all_(mytable.c.somearray) == None
+
+ .. versionchanged:: 1.4.26 repaired the use of any_() / all_()
+ comparing to NULL on the right side to be flipped to the left.
+
+ The column-level :meth:`_sql.ColumnElement.all_` method (not to be
+ confused with :class:`_types.ARRAY` level
+ :meth:`_types.ARRAY.Comparator.all`) is shorthand for
+ ``all_(col)``::
+
+ 5 == mytable.c.somearray.all_()
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.all_`
+
+ :func:`_expression.any_`
+
+ """
+ expr = coercions.expect(roles.ExpressionElementRole, expr)
+ expr = expr.self_group()
+ return CollectionAggregate(
+ expr,
+ operator=operators.all_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
+
+ # operate and reverse_operate are hardwired to
+ # dispatch onto the type comparator directly, so that we can
+ # ensure "reversed" behavior.
+ def operate(self, op, *other, **kwargs):
+ if not operators.is_comparison(op):
+ raise exc.ArgumentError(
+ "Only comparison operators may be used with ANY/ALL"
+ )
+ kwargs["reverse"] = kwargs["_any_all_expr"] = True
+ return self.comparator.operate(operators.mirror(op), *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ # comparison operators should never call reverse_operate
+ assert not operators.is_comparison(op)
+ raise exc.ArgumentError(
+ "Only comparison operators may be used with ANY/ALL"
+ )
+
+
+class AsBoolean(WrapsColumnExpression, UnaryExpression):
+ inherit_cache = True
+
+ def __init__(self, element, operator, negate):
+ self.element = element
+ self.type = type_api.BOOLEANTYPE
+ self.operator = operator
+ self.negate = negate
+ self.modifier = None
+ self.wraps_column_expression = True
+ self._is_implicitly_boolean = element._is_implicitly_boolean
+
+ @property
+ def wrapped_column_expression(self):
+ return self.element
+
+ def self_group(self, against=None):
+ return self
+
+ def _negate(self):
+ if isinstance(self.element, (True_, False_)):
+ return self.element._negate()
+ else:
+ return AsBoolean(self.element, self.negate, self.operator)
+
+
+class BinaryExpression(ColumnElement):
+ """Represent an expression that is ``LEFT <operator> RIGHT``.
+
+ A :class:`.BinaryExpression` is generated automatically
+ whenever two column expressions are used in a Python binary expression::
+
+ >>> from sqlalchemy.sql import column
+ >>> column('a') + column('b')
+ <sqlalchemy.sql.expression.BinaryExpression object at 0x101029dd0>
+ >>> print(column('a') + column('b'))
+ a + b
+
+ """
+
+ __visit_name__ = "binary"
+
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("operator", InternalTraversal.dp_operator),
+ ("negate", InternalTraversal.dp_operator),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ (
+ "type",
+ InternalTraversal.dp_type,
+ ), # affects JSON CAST operators
+ ]
+
+ _is_implicitly_boolean = True
+ """Indicates that any database will know this is a boolean expression
+ even if the database does not have an explicit boolean datatype.
+
+ """
+
+ def __init__(
+ self, left, right, operator, type_=None, negate=None, modifiers=None
+ ):
+ # allow compatibility with libraries that
+ # refer to BinaryExpression directly and pass strings
+ if isinstance(operator, util.string_types):
+ operator = operators.custom_op(operator)
+ self._orig = (left.__hash__(), right.__hash__())
+ self._propagate_attrs = left._propagate_attrs or right._propagate_attrs
+ self.left = left.self_group(against=operator)
+ self.right = right.self_group(against=operator)
+ self.operator = operator
+ self.type = type_api.to_instance(type_)
+ self.negate = negate
+ self._is_implicitly_boolean = operators.is_boolean(operator)
+
+ if modifiers is None:
+ self.modifiers = {}
+ else:
+ self.modifiers = modifiers
+
+ def __bool__(self):
+ if self.operator in (operator.eq, operator.ne):
+ return self.operator(*self._orig)
+ else:
+ raise TypeError("Boolean value of this clause is not defined")
+
+ __nonzero__ = __bool__
+
+ @property
+ def is_comparison(self):
+ return operators.is_comparison(self.operator)
+
+ @property
+ def _from_objects(self):
+ return self.left._from_objects + self.right._from_objects
+
+ def self_group(self, against=None):
+
+ if operators.is_precedent(self.operator, against):
+ return Grouping(self)
+ else:
+ return self
+
+ def _negate(self):
+ if self.negate is not None:
+ return BinaryExpression(
+ self.left,
+ self.right._negate_in_binary(self.negate, self.operator),
+ self.negate,
+ negate=self.operator,
+ type_=self.type,
+ modifiers=self.modifiers,
+ )
+ else:
+ return super(BinaryExpression, self)._negate()
+
+
+class Slice(ColumnElement):
+ """Represent SQL for a Python array-slice object.
+
+ This is not a specific SQL construct at this level, but
+ may be interpreted by specific dialects, e.g. PostgreSQL.
+
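+ A minimal sketch (assumes ``some_table.c.data`` is a PostgreSQL
+ ``ARRAY`` column; ordinary Python slicing of the column produces
+ a :class:`.Slice`)::
+
+ expr = some_table.c.data[2:5]
+ # renders on PostgreSQL roughly as: data[2:5]
+ # (with the slice bounds passed as bound parameters)
+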
+ """
+
+ __visit_name__ = "slice"
+
+ _traverse_internals = [
+ ("start", InternalTraversal.dp_clauseelement),
+ ("stop", InternalTraversal.dp_clauseelement),
+ ("step", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, start, stop, step, _name=None):
+ self.start = coercions.expect(
+ roles.ExpressionElementRole,
+ start,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.stop = coercions.expect(
+ roles.ExpressionElementRole,
+ stop,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.step = coercions.expect(
+ roles.ExpressionElementRole,
+ step,
+ name=_name,
+ type_=type_api.INTEGERTYPE,
+ )
+ self.type = type_api.NULLTYPE
+
+ def self_group(self, against=None):
+ assert against is operator.getitem
+ return self
+
+
+class IndexExpression(BinaryExpression):
+ """Represent the class of expressions that are like an "index"
+ operation."""
+
+ inherit_cache = True
+
+
+class GroupedElement(ClauseElement):
+ """Represent any parenthesized expression"""
+
+ __visit_name__ = "grouping"
+
+ def self_group(self, against=None):
+ return self
+
+ def _ungroup(self):
+ return self.element._ungroup()
+
+
+class Grouping(GroupedElement, ColumnElement):
+ """Represent a grouping within a column expression"""
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ def __init__(self, element):
+ self.element = element
+ self.type = getattr(element, "type", type_api.NULLTYPE)
+
+ def _with_binary_element_type(self, type_):
+ return self.__class__(self.element._with_binary_element_type(type_))
+
+ @util.memoized_property
+ def _is_implicitly_boolean(self):
+ return self.element._is_implicitly_boolean
+
+ @property
+ def _tq_label(self):
+ return (
+ getattr(self.element, "_tq_label", None) or self._anon_name_label
+ )
+
+ @property
+ def _proxies(self):
+ if isinstance(self.element, ColumnElement):
+ return [self.element]
+ else:
+ return []
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def __getattr__(self, attr):
+ return getattr(self.element, attr)
+
+ def __getstate__(self):
+ return {"element": self.element, "type": self.type}
+
+ def __setstate__(self, state):
+ self.element = state["element"]
+ self.type = state["type"]
+
+
+RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
+RANGE_CURRENT = util.symbol("RANGE_CURRENT")
+
+
+class Over(ColumnElement):
+ """Represent an OVER clause.
+
+ This is a special operator against a so-called
+ "window" function, as well as any aggregate function,
+ which produces results relative to the result set
+ itself. Most modern SQL backends now support window functions.
+
+ """
+
+ __visit_name__ = "over"
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ("partition_by", InternalTraversal.dp_clauseelement),
+ ("range_", InternalTraversal.dp_plain_obj),
+ ("rows", InternalTraversal.dp_plain_obj),
+ ]
+
+ order_by = None
+ partition_by = None
+
+ element = None
+ """The underlying expression object to which this :class:`.Over`
+ object refers."""
+
+ def __init__(
+ self, element, partition_by=None, order_by=None, range_=None, rows=None
+ ):
+ r"""Produce an :class:`.Over` object against a function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ :func:`_expression.over` is usually called using
+ the :meth:`.FunctionElement.over` method, e.g.::
+
+ func.row_number().over(order_by=mytable.c.some_column)
+
+ Would produce::
+
+ ROW_NUMBER() OVER(ORDER BY some_column)
+
+ Ranges are also possible using the :paramref:`.expression.over.range_`
+ and :paramref:`.expression.over.rows` parameters. These
+ mutually-exclusive parameters each accept a 2-tuple, which contains
+ a combination of integers and None::
+
+ func.row_number().over(
+ order_by=my_table.c.some_column, range_=(None, 0))
+
+ The above would produce::
+
+ ROW_NUMBER() OVER(ORDER BY some_column
+ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+
+ A value of ``None`` indicates "unbounded", a
+ value of zero indicates "current row", and negative / positive
+ integers indicate "preceding" and "following":
+
+ * RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING::
+
+ func.row_number().over(order_by='x', range_=(-5, 10))
+
+ * ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW::
+
+ func.row_number().over(order_by='x', rows=(None, 0))
+
+ * RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING::
+
+ func.row_number().over(order_by='x', range_=(-2, None))
+
+ * RANGE BETWEEN 1 FOLLOWING AND 3 FOLLOWING::
+
+ func.row_number().over(order_by='x', range_=(1, 3))
+
+ .. versionadded:: 1.1 support for RANGE / ROWS within a window
+
+
+ :param element: a :class:`.FunctionElement`, :class:`.WithinGroup`,
+ or other compatible construct.
+ :param partition_by: a column element or string, or a list
+ of such, that will be used as the PARTITION BY clause
+ of the OVER construct.
+ :param order_by: a column element or string, or a list
+ of such, that will be used as the ORDER BY clause
+ of the OVER construct.
+ :param range\_: optional range clause for the window. This is a
+ tuple value which can contain integer values or ``None``,
+ and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause.
+
+ .. versionadded:: 1.1
+
+ :param rows: optional rows clause for the window. This is a tuple
+ value which can contain integer values or None, and will render
+ a ROWS BETWEEN PRECEDING / FOLLOWING clause.
+
+ .. versionadded:: 1.1
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.over` method.
+
+ .. seealso::
+
+ :ref:`tutorial_window_functions` - in the :ref:`unified_tutorial`
+
+ :data:`.expression.func`
+
+ :func:`_expression.within_group`
+
+ """
+ self.element = element
+ if order_by is not None:
+ self.order_by = ClauseList(
+ *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
+ )
+ if partition_by is not None:
+ self.partition_by = ClauseList(
+ *util.to_list(partition_by),
+ _literal_as_text_role=roles.ByOfRole
+ )
+
+ if range_:
+ self.range_ = self._interpret_range(range_)
+ if rows:
+ raise exc.ArgumentError(
+ "'range_' and 'rows' are mutually exclusive"
+ )
+ else:
+ self.rows = None
+ elif rows:
+ self.rows = self._interpret_range(rows)
+ self.range_ = None
+ else:
+ self.rows = self.range_ = None
+
+ def __reduce__(self):
+ return self.__class__, (
+ self.element,
+ self.partition_by,
+ self.order_by,
+ self.range_,
+ self.rows,
+ )
+
+ def _interpret_range(self, range_):
+ if not isinstance(range_, tuple) or len(range_) != 2:
+ raise exc.ArgumentError("2-tuple expected for range/rows")
+
+ if range_[0] is None:
+ lower = RANGE_UNBOUNDED
+ else:
+ try:
+ lower = int(range_[0])
+ except ValueError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Integer or None expected for range value"
+ ),
+ replace_context=err,
+ )
+ else:
+ if lower == 0:
+ lower = RANGE_CURRENT
+
+ if range_[1] is None:
+ upper = RANGE_UNBOUNDED
+ else:
+ try:
+ upper = int(range_[1])
+ except ValueError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "Integer or None expected for range value"
+ ),
+ replace_context=err,
+ )
+ else:
+ if upper == 0:
+ upper = RANGE_CURRENT
+
+ return lower, upper
+
+ @util.memoized_property
+ def type(self):
+ return self.element.type
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
+ )
+ )
+
+
+class WithinGroup(ColumnElement):
+ """Represent a WITHIN GROUP (ORDER BY) clause.
+
+ This is a special operator against so-called
+ "ordered set aggregate" and "hypothetical
+ set aggregate" functions, including ``percentile_cont()``,
+ ``rank()``, ``dense_rank()``, etc.
+
+ It's supported only by certain database backends, such as PostgreSQL,
+ Oracle and MS SQL Server.
+
+ The :class:`.WithinGroup` construct extracts its type from the
+ method :meth:`.FunctionElement.within_group_type`. If this returns
+ ``None``, the function's ``.type`` is used.
+
+ """
+
+ __visit_name__ = "withingroup"
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("order_by", InternalTraversal.dp_clauseelement),
+ ]
+
+ order_by = None
+
+ def __init__(self, element, *order_by):
+ r"""Produce a :class:`.WithinGroup` object against a function.
+
+ Used against so-called "ordered set aggregate" and "hypothetical
+ set aggregate" functions, including :class:`.percentile_cont`,
+ :class:`.rank`, :class:`.dense_rank`, etc.
+
+ :func:`_expression.within_group` is usually called using
+ the :meth:`.FunctionElement.within_group` method, e.g.::
+
+ from sqlalchemy import within_group
+ stmt = select(
+ department.c.id,
+ func.percentile_cont(0.5).within_group(
+ department.c.salary.desc()
+ )
+ )
+
+ The above statement would produce SQL similar to
+ ``SELECT department.id, percentile_cont(0.5)
+ WITHIN GROUP (ORDER BY department.salary DESC)``.
+
+ :param element: a :class:`.FunctionElement` construct, typically
+ generated by :data:`~.expression.func`.
+ :param \*order_by: one or more column elements that will be used
+ as the ORDER BY clause of the WITHIN GROUP construct.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` - in the
+ :ref:`unified_tutorial`
+
+ :data:`.expression.func`
+
+ :func:`_expression.over`
+
+ """
+ self.element = element
+ if order_by is not None:
+ self.order_by = ClauseList(
+ *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole
+ )
+
+ def __reduce__(self):
+ return self.__class__, (self.element,) + tuple(self.order_by)
+
+ def over(self, partition_by=None, order_by=None, range_=None, rows=None):
+ """Produce an OVER clause against this :class:`.WithinGroup`
+ construct.
+
+ This function has the same signature as that of
+ :meth:`.FunctionElement.over`.
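+
+ E.g., a sketch combining the percentile example above with a
+ partition, assuming the same ``department`` table and a hypothetical
+ ``dept_id`` column::
+
+ func.percentile_cont(0.5).within_group(
+ department.c.salary.desc()
+ ).over(partition_by=department.c.dept_id)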
+
+ """
+ return Over(
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
+
+ @util.memoized_property
+ def type(self):
+ wgt = self.element.within_group_type(self)
+ if wgt is not None:
+ return wgt
+ else:
+ return self.element.type
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.order_by)
+ if c is not None
+ ]
+ )
+ )
+
+
+class FunctionFilter(ColumnElement):
+ """Represent a function FILTER clause.
+
+ This is a special operator against aggregate and window functions,
+ which controls which rows are passed to it.
+ It's supported only by certain database backends.
+
+ Invocation of :class:`.FunctionFilter` is via
+ :meth:`.FunctionElement.filter`::
+
+ func.count(1).filter(True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`.FunctionElement.filter`
+
+ """
+
+ __visit_name__ = "funcfilter"
+
+ _traverse_internals = [
+ ("func", InternalTraversal.dp_clauseelement),
+ ("criterion", InternalTraversal.dp_clauseelement),
+ ]
+
+ criterion = None
+
+ def __init__(self, func, *criterion):
+ """Produce a :class:`.FunctionFilter` object against a function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ E.g.::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), MyClass.name == 'some name')
+
+ Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.filter` method.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` - in the
+ :ref:`unified_tutorial`
+
+ :meth:`.FunctionElement.filter`
+
+ """
+ self.func = func
+ self.filter(*criterion)
+
+ def filter(self, *criterion):
+ """Produce an additional FILTER against the function.
+
+ This method adds additional criteria to the initial criteria
+ set up by :meth:`.FunctionElement.filter`.
+
+ Multiple criteria are joined together at SQL render time
+ via ``AND``.
+
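+ For example, a sketch assuming a mapped class ``MyClass`` with
+ columns ``y`` and ``z``::
+
+ func.count(1).filter(MyClass.y > 5).filter(MyClass.z < 10)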
+
+ """
+
+ for criterion in list(criterion):
+ criterion = coercions.expect(roles.WhereHavingRole, criterion)
+
+ if self.criterion is not None:
+ self.criterion = self.criterion & criterion
+ else:
+ self.criterion = criterion
+
+ return self
+
+ def over(self, partition_by=None, order_by=None, range_=None, rows=None):
+ """Produce an OVER clause against this filtered function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ The expression::
+
+ func.rank().filter(MyClass.y > 5).over(order_by='x')
+
+ is shorthand for::
+
+ from sqlalchemy import over, funcfilter
+ over(funcfilter(func.rank(), MyClass.y > 5), order_by='x')
+
+ See :func:`_expression.over` for a full description.
+
+ """
+ return Over(
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
+
+ def self_group(self, against=None):
+ if operators.is_precedent(operators.filter_op, against):
+ return Grouping(self)
+ else:
+ return self
+
+ @util.memoized_property
+ def type(self):
+ return self.func.type
+
+ @property
+ def _from_objects(self):
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.func, self.criterion)
+ if c is not None
+ ]
+ )
+ )
+
+
+class Label(roles.LabeledColumnExprRole, ColumnElement):
+ """Represents a column label (AS).
+
+ Represent a label, as typically applied to any column-level
+ element using the ``AS`` sql keyword.
+
+ """
+
+ __visit_name__ = "label"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("_type", InternalTraversal.dp_type),
+ ("_element", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, name, element, type_=None):
+ """Return a :class:`Label` object for the
+ given :class:`_expression.ColumnElement`.
+
+ A label changes the name of an element in the columns clause of a
+ ``SELECT`` statement, typically via the ``AS`` SQL keyword.
+
+ This functionality is more conveniently available via the
+ :meth:`_expression.ColumnElement.label` method on
+ :class:`_expression.ColumnElement`.
+
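+ E.g., a sketch assuming a ``user`` table construct::
+
+ stmt = select(user.c.name.label("user_name"))
+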
+ :param name: label name
+
+ :param obj: a :class:`_expression.ColumnElement`.
+
+ """
+
+ orig_element = element
+ element = coercions.expect(
+ roles.ExpressionElementRole,
+ element,
+ apply_propagate_attrs=self,
+ )
+ while isinstance(element, Label):
+ # TODO: this is only covered in test_text.py, but nothing
+ # fails if it's removed. determine rationale
+ element = element.element
+
+ if name:
+ self.name = name
+ else:
+ self.name = _anonymous_label.safe_construct(
+ id(self), getattr(element, "name", "anon")
+ )
+ if isinstance(orig_element, Label):
+ # TODO: no coverage for this block, again would be in
+ # test_text.py where the resolve_label concept is important
+ self._resolve_label = orig_element._label
+
+ self.key = self._tq_label = self._tq_key_label = self.name
+ self._element = element
+ self._type = type_
+ self._proxies = [element]
+
+ def __reduce__(self):
+ return self.__class__, (self.name, self._element, self._type)
+
+ @util.memoized_property
+ def _is_implicitly_boolean(self):
+ return self.element._is_implicitly_boolean
+
+ @HasMemoized.memoized_attribute
+ def _allow_label_resolve(self):
+ return self.element._allow_label_resolve
+
+ @property
+ def _order_by_label_element(self):
+ return self
+
+ @util.memoized_property
+ def type(self):
+ return type_api.to_instance(
+ self._type or getattr(self._element, "type", None)
+ )
+
+ @HasMemoized.memoized_attribute
+ def element(self):
+ return self._element.self_group(against=operators.as_)
+
+ def self_group(self, against=None):
+ return self._apply_to_inner(self._element.self_group, against=against)
+
+ def _negate(self):
+ return self._apply_to_inner(self._element._negate)
+
+ def _apply_to_inner(self, fn, *arg, **kw):
+ sub_element = fn(*arg, **kw)
+ if sub_element is not self._element:
+ return Label(self.name, sub_element, type_=self._type)
+ else:
+ return self
+
+ @property
+ def primary_key(self):
+ return self.element.primary_key
+
+ @property
+ def foreign_keys(self):
+ return self.element.foreign_keys
+
+ def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
+ self._reset_memoizations()
+ self._element = clone(self._element, **kw)
+ if anonymize_labels:
+ self.name = _anonymous_label.safe_construct(
+ id(self), getattr(self.element, "name", "anon")
+ )
+ self.key = self._tq_label = self._tq_key_label = self.name
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def _make_proxy(self, selectable, name=None, **kw):
+ name = self.name if not name else name
+
+ key, e = self.element._make_proxy(
+ selectable,
+ name=name,
+ disallow_is_literal=True,
+ name_is_truncatable=isinstance(name, _truncated_label),
+ )
+
+ # there was a note here to remove this assertion, which was here
+ # to determine if we later could support a use case where
+ # the key and name of a label are separate. But I don't know what
+ # that case was. For now, this is an unexpected case that occurs
+ # when a label name conflicts with other columns and select()
+ # is attempting to disambiguate an explicit label, which is not what
+ # the user would want. See issue #6090.
+ if key != self.name:
+ raise exc.InvalidRequestError(
+ "Label name %s is being renamed to an anonymous label due "
+ "to disambiguation "
+ "which is not supported right now. Please use unique names "
+ "for explicit labels." % (self.name)
+ )
+
+ e._propagate_attrs = selectable._propagate_attrs
+ e._proxies.append(self)
+ if self._type is not None:
+ e.type = self._type
+
+ return self.key, e
+
+
+class NamedColumn(ColumnElement):
+ is_literal = False
+ table = None
+
+ def _compare_name_for_result(self, other):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_label") and self._label == other._label
+ )
+
+ @util.memoized_property
+ def description(self):
+ if util.py3k:
+ return self.name
+ else:
+ return self.name.encode("ascii", "backslashreplace")
+
+ @HasMemoized.memoized_attribute
+ def _tq_key_label(self):
+ """table qualified label based on column key.
+
+ for table-bound columns this is <tablename>_<column key/proxy key>;
+
+ for all other expressions it resolves to the key/proxy key.
+
+ """
+ proxy_key = self._proxy_key
+ if proxy_key and proxy_key != self.name:
+ return self._gen_tq_label(proxy_key)
+ else:
+ return self._tq_label
+
+ @HasMemoized.memoized_attribute
+ def _tq_label(self):
+ """table qualified label based on column name.
+
+ for table-bound columns this is <tablename>_<columnname>; for all
+ other expressions it resolves to .name.
+
+ """
+ return self._gen_tq_label(self.name)
+
+ @HasMemoized.memoized_attribute
+ def _render_label_in_columns_clause(self):
+ return True
+
+ @HasMemoized.memoized_attribute
+ def _non_anon_label(self):
+ return self.name
+
+ def _gen_tq_label(self, name, dedupe_on_key=True):
+ return name
+
+ def _bind_param(self, operator, obj, type_=None, expanding=False):
+ return BindParameter(
+ self.key,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ expanding=expanding,
+ )
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ name_is_truncatable=False,
+ disallow_is_literal=False,
+ **kw
+ ):
+ c = ColumnClause(
+ coercions.expect(roles.TruncatedLabelRole, name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
+ type_=self.type,
+ _selectable=selectable,
+ is_literal=False,
+ )
+ c._propagate_attrs = selectable._propagate_attrs
+ if name is None:
+ c.key = self.key
+ c._proxies = [self]
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
+ return c.key, c
+
+
+class ColumnClause(
+ roles.DDLReferredColumnRole,
+ roles.LabeledColumnExprRole,
+ roles.StrAsPlainColumnRole,
+ Immutable,
+ NamedColumn,
+):
+ """Represents a column expression from any textual string.
+
+ The :class:`.ColumnClause`, a lightweight analogue to the
+ :class:`_schema.Column` class, is typically invoked using the
+ :func:`_expression.column` function, as in::
+
+ from sqlalchemy import column
+
+ id, name = column("id"), column("name")
+ stmt = select(id, name).select_from("user")
+
+ The above statement would produce SQL like::
+
+ SELECT id, name FROM user
+
+ :class:`.ColumnClause` is the immediate superclass of the schema-specific
+ :class:`_schema.Column` object. While the :class:`_schema.Column`
+ class has all the
+ same capabilities as :class:`.ColumnClause`, the :class:`.ColumnClause`
+ class is usable by itself in those cases where behavioral requirements
+ are limited to simple SQL expression generation. The object has none of
+ the associations with schema-level metadata or with execution-time
+ behavior that :class:`_schema.Column` does,
+ so in that sense is a "lightweight"
+ version of :class:`_schema.Column`.
+
+ Full details on :class:`.ColumnClause` usage is at
+ :func:`_expression.column`.
+
+ .. seealso::
+
+ :func:`_expression.column`
+
+ :class:`_schema.Column`
+
+ """
+
+ table = None
+ is_literal = False
+
+ __visit_name__ = "column"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("table", InternalTraversal.dp_clauseelement),
+ ("is_literal", InternalTraversal.dp_boolean),
+ ]
+
+ onupdate = default = server_default = server_onupdate = None
+
+ _is_multiparam_column = False
+
+ @property
+ def _is_star(self):
+ return self.is_literal and self.name == "*"
+
+ def __init__(self, text, type_=None, is_literal=False, _selectable=None):
+ """Produce a :class:`.ColumnClause` object.
+
+ The :class:`.ColumnClause` is a lightweight analogue to the
+ :class:`_schema.Column` class. The :func:`_expression.column`
+ function can
+ be invoked with just a name alone, as in::
+
+ from sqlalchemy import column
+
+ id, name = column("id"), column("name")
+ stmt = select(id, name).select_from("user")
+
+ The above statement would produce SQL like::
+
+ SELECT id, name FROM user
+
+ Once constructed, :func:`_expression.column`
+ may be used like any other SQL
+ expression element such as within :func:`_expression.select`
+ constructs::
+
+ from sqlalchemy.sql import column
+
+ id, name = column("id"), column("name")
+ stmt = select(id, name).select_from("user")
+
+ The text handled by :func:`_expression.column`
+ is assumed to be handled
+ like the name of a database column; if the string contains mixed case,
+ special characters, or matches a known reserved word on the target
+ backend, the column expression will render using the quoting
+ behavior determined by the backend. To produce a textual SQL
+ expression that is rendered exactly without any quoting,
+ use :func:`_expression.literal_column` instead,
+ or pass ``True`` as the
+ value of :paramref:`_expression.column.is_literal`. Additionally,
+ full SQL
+ statements are best handled using the :func:`_expression.text`
+ construct.
+
+ :func:`_expression.column` can be used in a table-like
+ fashion by combining it with the :func:`.table` function
+ (which is the lightweight analogue to :class:`_schema.Table`
+ ) to produce
+ a working table construct with minimal boilerplate::
+
+ from sqlalchemy import table, column, select
+
+ user = table("user",
+ column("id"),
+ column("name"),
+ column("description"),
+ )
+
+ stmt = select(user.c.description).where(user.c.name == 'wendy')
+
+ A :func:`_expression.column` / :func:`.table`
+ construct like that illustrated
+ above can be created in an
+ ad-hoc fashion and is not associated with any
+ :class:`_schema.MetaData`, DDL, or events, unlike its
+ :class:`_schema.Table` counterpart.
+
+ .. versionchanged:: 1.0.0 :func:`_expression.column` can now
+ be imported from the plain ``sqlalchemy`` namespace like any
+ other SQL element.
+
+ :param text: the text of the element.
+
+ :param type: :class:`_types.TypeEngine` object which can associate
+ this :class:`.ColumnClause` with a type.
+
+ :param is_literal: if True, the :class:`.ColumnClause` is assumed to
+ be an exact expression that will be delivered to the output with no
+ quoting rules applied regardless of case sensitive settings. the
+ :func:`_expression.literal_column()` function essentially invokes
+ :func:`_expression.column` while passing ``is_literal=True``.
+
+ .. seealso::
+
+ :class:`_schema.Column`
+
+ :func:`_expression.literal_column`
+
+ :func:`.table`
+
+ :func:`_expression.text`
+
+ :ref:`tutorial_select_arbitrary_text`
+
+ """
+ self.key = self.name = text
+ self.table = _selectable
+ self.type = type_api.to_instance(type_)
+ self.is_literal = is_literal
+
+ def get_children(self, column_tables=False, **kw):
+ # override base get_children() to not return the Table
+ # or selectable that is parent to this column. Traversals
+ # expect the columns of tables and subqueries to be leaf nodes.
+ return []
+
+ @property
+ def entity_namespace(self):
+ if self.table is not None:
+ return self.table.entity_namespace
+ else:
+ return super(ColumnClause, self).entity_namespace
+
+ def _clone(self, detect_subquery_cols=False, **kw):
+ if (
+ detect_subquery_cols
+ and self.table is not None
+ and self.table._is_subquery
+ ):
+ clone = kw.pop("clone")
+ table = clone(self.table, **kw)
+ new = table.c.corresponding_column(self)
+ return new
+
+ return super(ColumnClause, self)._clone(**kw)
+
+ @HasMemoized.memoized_attribute
+ def _from_objects(self):
+ t = self.table
+ if t is not None:
+ return [t]
+ else:
+ return []
+
+ @HasMemoized.memoized_attribute
+ def _render_label_in_columns_clause(self):
+ return self.table is not None
+
+ @property
+ def _ddl_label(self):
+ return self._gen_tq_label(self.name, dedupe_on_key=False)
+
+ def _compare_name_for_result(self, other):
+ if (
+ self.is_literal
+ or self.table is None
+ or self.table._is_textual
+ or not hasattr(other, "proxy_set")
+ or (
+ isinstance(other, ColumnClause)
+ and (
+ other.is_literal
+ or other.table is None
+ or other.table._is_textual
+ )
+ )
+ ):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_tq_label")
+ and self._tq_label == other._tq_label
+ )
+ else:
+ return other.proxy_set.intersection(self.proxy_set)
+
+ def _gen_tq_label(self, name, dedupe_on_key=True):
+ """generate table-qualified label
+
+ for a table-bound column this is <tablename>_<columnname>.
+
+ used primarily for LABEL_STYLE_TABLENAME_PLUS_COL
+ as well as the .columns collection on a Join object.
+
+ """
+ t = self.table
+ if self.is_literal:
+ return None
+ elif t is not None and t.named_with_column:
+ if getattr(t, "schema", None):
+ label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
+ else:
+ label = t.name + "_" + name
+
+ # propagate name quoting rules for labels.
+ if getattr(name, "quote", None) is not None:
+ if isinstance(label, quoted_name):
+ label.quote = name.quote
+ else:
+ label = quoted_name(label, name.quote)
+ elif getattr(t.name, "quote", None) is not None:
+ # can't get this situation to occur, so let's
+ # assert false on it for now
+ assert not isinstance(label, quoted_name)
+ label = quoted_name(label, t.name.quote)
+
+ if dedupe_on_key:
+ # ensure the label name doesn't conflict with that of an
+ # existing column. note that this implies that any Column
+ # must **not** set up its _label before its parent table has
+ # all of its other Column objects set up. There are several
+ # tables in the test suite which will fail otherwise; example:
+ # table "owner" has columns "name" and "owner_name". Therefore
+ # column owner.name cannot use the label "owner_name", it has
+ # to be "owner_name_1".
+ if label in t.c:
+ _label = label
+ counter = 1
+ while _label in t.c:
+ _label = label + "_" + str(counter)
+ counter += 1
+ label = _label
+
+ return coercions.expect(roles.TruncatedLabelRole, label)
+
+ else:
+ return name
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ name_is_truncatable=False,
+ disallow_is_literal=False,
+ **kw
+ ):
+ # the "is_literal" flag normally should never be propagated; a proxied
+ # column is always a SQL identifier and never the actual expression
+ # being evaluated. however, there is a case where the "is_literal" flag
+ # might be used to allow the given identifier to have a fixed quoting
+ # pattern already, so maintain the flag for the proxy unless a
+ # :class:`.Label` object is creating the proxy. See [ticket:4730].
+ is_literal = (
+ not disallow_is_literal
+ and self.is_literal
+ and (
+ # note this does not accommodate for quoted_name differences
+ # right now
+ name is None
+ or name == self.name
+ )
+ )
+ c = self._constructor(
+ coercions.expect(roles.TruncatedLabelRole, name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
+ type_=self.type,
+ _selectable=selectable,
+ is_literal=is_literal,
+ )
+ c._propagate_attrs = selectable._propagate_attrs
+ if name is None:
+ c.key = self.key
+ c._proxies = [self]
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
+ return c.key, c
+
+
+class TableValuedColumn(NamedColumn):
+ __visit_name__ = "table_valued_column"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("scalar_alias", InternalTraversal.dp_clauseelement),
+ ]
+
+ def __init__(self, scalar_alias, type_):
+ self.scalar_alias = scalar_alias
+ self.key = self.name = scalar_alias.name
+ self.type = type_
+
+ def _copy_internals(self, clone=_clone, **kw):
+ self.scalar_alias = clone(self.scalar_alias, **kw)
+ self.key = self.name = self.scalar_alias.name
+
+ @property
+ def _from_objects(self):
+ return [self.scalar_alias]
+
+
+class CollationClause(ColumnElement):
+ __visit_name__ = "collation"
+
+ _traverse_internals = [("collation", InternalTraversal.dp_string)]
+
+ def __init__(self, collation):
+ self.collation = collation
+
+
+class _IdentifiedClause(Executable, ClauseElement):
+
+ __visit_name__ = "identified"
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": False}
+ )
+
+ def __init__(self, ident):
+ self.ident = ident
+
+
+class SavepointClause(_IdentifiedClause):
+ __visit_name__ = "savepoint"
+ inherit_cache = False
+
+
+class RollbackToSavepointClause(_IdentifiedClause):
+ __visit_name__ = "rollback_to_savepoint"
+ inherit_cache = False
+
+
+class ReleaseSavepointClause(_IdentifiedClause):
+ __visit_name__ = "release_savepoint"
+ inherit_cache = False
+
+
+class quoted_name(util.MemoizedSlots, util.text_type):
+ """Represent a SQL identifier combined with quoting preferences.
+
+ :class:`.quoted_name` is a Python unicode/str subclass which
+ represents a particular identifier name along with a
+ ``quote`` flag. This ``quote`` flag, when set to
+ ``True`` or ``False``, overrides automatic quoting behavior
+ for this identifier in order to either unconditionally quote
+ or to not quote the name. If left at its default of ``None``,
+ quoting behavior is applied to the identifier on a per-backend basis
+ based on an examination of the token itself.
+
+ A :class:`.quoted_name` object with ``quote=True`` is also
+ prevented from being modified in the case of a so-called
+ "name normalize" option. Certain database backends, such as
+ Oracle, Firebird, and DB2 "normalize" case-insensitive names
+ as uppercase. The SQLAlchemy dialects for these backends
+ convert from SQLAlchemy's lower-case-means-insensitive convention
+ to the upper-case-means-insensitive conventions of those backends.
+ The ``quote=True`` flag here will prevent this conversion from occurring
+ to support an identifier that's quoted as all lower case against
+ such a backend.
+
+ The :class:`.quoted_name` object is normally created automatically
+ when specifying the name for key schema constructs such as
+ :class:`_schema.Table`, :class:`_schema.Column`, and others.
+ The class can also be
+ passed explicitly as the name to any function that receives a name which
+ can be quoted, such as when using the :meth:`_engine.Engine.has_table`
+ method with an unconditionally quoted name::
+
+ from sqlalchemy import create_engine
+ from sqlalchemy import inspect
+ from sqlalchemy.sql import quoted_name
+
+ engine = create_engine("oracle+cx_oracle://some_dsn")
+ print(inspect(engine).has_table(quoted_name("some_table", True)))
+
+ The above logic will run the "has table" logic against the Oracle backend,
+ passing the name exactly as ``"some_table"`` without converting to
+ upper case.
+
+ .. versionadded:: 0.9.0
+
+ .. versionchanged:: 1.2 The :class:`.quoted_name` construct is now
+ importable from ``sqlalchemy.sql``, in addition to the previous
+ location of ``sqlalchemy.sql.elements``.
+
+ """
+
+ __slots__ = "quote", "lower", "upper"
+
+ def __new__(cls, value, quote):
+ if value is None:
+ return None
+ # experimental - don't bother with quoted_name
+ # if quote flag is None. doesn't seem to make any dent
+ # in performance however
+ # elif not sprcls and quote is None:
+ # return value
+ elif isinstance(value, cls) and (
+ quote is None or value.quote == quote
+ ):
+ return value
+ self = super(quoted_name, cls).__new__(cls, value)
+
+ self.quote = quote
+ return self
+
+ def __reduce__(self):
+ return quoted_name, (util.text_type(self), self.quote)
+
+ def _memoized_method_lower(self):
+ if self.quote:
+ return self
+ else:
+ return util.text_type(self).lower()
+
+ def _memoized_method_upper(self):
+ if self.quote:
+ return self
+ else:
+ return util.text_type(self).upper()
+
+ def __repr__(self):
+ if util.py2k:
+ backslashed = self.encode("ascii", "backslashreplace")
+ return "'%s'" % backslashed
+ else:
+ return str.__repr__(self)
+
+
+def _find_columns(clause):
+ """locate Column objects within the given expression."""
+
+ cols = util.column_set()
+ traverse(clause, {}, {"column": cols.add})
+ return cols
+
+
+def _type_from_args(args):
+ for a in args:
+ if not a.type._isnull:
+ return a.type
+ else:
+ return type_api.NULLTYPE
+
+
+def _corresponding_column_or_error(fromclause, column, require_embedded=False):
+ c = fromclause.corresponding_column(
+ column, require_embedded=require_embedded
+ )
+ if c is None:
+ raise exc.InvalidRequestError(
+ "Given column '%s', attached to table '%s', "
+ "failed to locate a corresponding column from table '%s'"
+ % (column, getattr(column, "table", None), fromclause.description)
+ )
+ return c
+
+
+class AnnotatedColumnElement(Annotated):
+ def __init__(self, element, values):
+ Annotated.__init__(self, element, values)
+ for attr in (
+ "comparator",
+ "_proxy_key",
+ "_tq_key_label",
+ "_tq_label",
+ "_non_anon_label",
+ ):
+ self.__dict__.pop(attr, None)
+ for attr in ("name", "key", "table"):
+ if self.__dict__.get(attr, False) is None:
+ self.__dict__.pop(attr)
+
+ def _with_annotations(self, values):
+ clone = super(AnnotatedColumnElement, self)._with_annotations(values)
+ clone.__dict__.pop("comparator", None)
+ return clone
+
+ @util.memoized_property
+ def name(self):
+ """pull 'name' from parent, if not present"""
+ return self._Annotated__element.name
+
+ @util.memoized_property
+ def table(self):
+ """pull 'table' from parent, if not present"""
+ return self._Annotated__element.table
+
+ @util.memoized_property
+ def key(self):
+ """pull 'key' from parent, if not present"""
+ return self._Annotated__element.key
+
+ @util.memoized_property
+ def info(self):
+ return self._Annotated__element.info
+
+ @util.memoized_property
+ def _anon_name_label(self):
+ return self._Annotated__element._anon_name_label
+
+
+class _truncated_label(quoted_name):
+ """A unicode subclass used to identify symbolic "
+ "names that may require truncation."""
+
+ __slots__ = ()
+
+ def __new__(cls, value, quote=None):
+ quote = getattr(value, "quote", quote)
+ # return super(_truncated_label, cls).__new__(cls, value, quote, True)
+ return super(_truncated_label, cls).__new__(cls, value, quote)
+
+ def __reduce__(self):
+ return self.__class__, (util.text_type(self), self.quote)
+
+ def apply_map(self, map_):
+ return self
+
+
+class conv(_truncated_label):
+ """Mark a string indicating that a name has already been converted
+ by a naming convention.
+
+ This is a string subclass that indicates a name that should not be
+ subject to any further naming conventions.
+
+ E.g. when we create a :class:`.Constraint` using a naming convention
+ as follows::
+
+ m = MetaData(naming_convention={
+ "ck": "ck_%(table_name)s_%(constraint_name)s"
+ })
+ t = Table('t', m, Column('x', Integer),
+ CheckConstraint('x > 5', name='x5'))
+
+ The name of the above constraint will be rendered as ``"ck_t_x5"``.
+ That is, the existing name ``x5`` is used in the naming convention as the
+ ``constraint_name`` token.
+
+ In some situations, such as in migration scripts, we may be rendering
+ the above :class:`.CheckConstraint` with a name that's already been
+ converted. In order to make sure the name isn't double-modified, the
+ new name is applied using the :func:`_schema.conv` marker. We can
+ use this explicitly as follows::
+
+
+ m = MetaData(naming_convention={
+ "ck": "ck_%(table_name)s_%(constraint_name)s"
+ })
+ t = Table('t', m, Column('x', Integer),
+ CheckConstraint('x > 5', name=conv('ck_t_x5')))
+
+ Where above, the :func:`_schema.conv` marker indicates that the constraint
+ name here is final, and the name will render as ``"ck_t_x5"`` and not
+ ``"ck_t_ck_t_x5"``
+
+ .. versionadded:: 0.9.4
+
+ .. seealso::
+
+ :ref:`constraint_naming_conventions`
+
+ """
+
+ __slots__ = ()
+
+
+_NONE_NAME = util.symbol("NONE_NAME")
+"""indicate a 'deferred' name that was ultimately the value None."""
+
+# for backwards compatibility in case
+# someone is re-implementing the
+# _truncated_identifier() sequence in a custom
+# compiler
+_generated_label = _truncated_label
+
+
+class _anonymous_label(_truncated_label):
+ """A unicode subclass used to identify anonymously
+ generated names."""
+
+ __slots__ = ()
+
+ @classmethod
+ def safe_construct(
+ cls, seed, body, enclosing_label=None, sanitize_key=False
+ ):
+
+ if sanitize_key:
+ body = re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+
+ label = "%%(%d %s)s" % (seed, body.replace("%", "%%"))
+ if enclosing_label:
+ label = "%s%s" % (enclosing_label, label)
+
+ return _anonymous_label(label)
+
+ def __add__(self, other):
+ if "%" in other and not isinstance(other, _anonymous_label):
+ other = util.text_type(other).replace("%", "%%")
+ else:
+ other = util.text_type(other)
+
+ return _anonymous_label(
+ quoted_name(
+ util.text_type.__add__(self, other),
+ self.quote,
+ )
+ )
+
+ def __radd__(self, other):
+ if "%" in other and not isinstance(other, _anonymous_label):
+ other = util.text_type(other).replace("%", "%%")
+ else:
+ other = util.text_type(other)
+
+ return _anonymous_label(
+ quoted_name(
+ util.text_type.__add__(other, self),
+ self.quote,
+ )
+ )
+
+ def apply_map(self, map_):
+ if self.quote is not None:
+ # preserve quoting only if necessary
+ return quoted_name(self % map_, self.quote)
+ else:
+ # else skip the constructor call
+ return self % map_
diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py
new file mode 100644
index 0000000..c425789
--- /dev/null
+++ b/lib/sqlalchemy/sql/events.py
@@ -0,0 +1,331 @@
+# sqlalchemy/sql/events.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .base import SchemaEventTarget
+from .. import event
+
+
+class DDLEvents(event.Events):
+ """
+ Define event listeners for schema objects,
+ that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget`
+ subclasses, including :class:`_schema.MetaData`, :class:`_schema.Table`,
+ :class:`_schema.Column`.
+
+ :class:`_schema.MetaData` and :class:`_schema.Table` support events
+ specifically regarding when CREATE and DROP
+ DDL is emitted to the database.
+
+ Attachment events are also provided to customize
+ behavior whenever a child schema element is associated
+ with a parent, such as, when a :class:`_schema.Column` is associated
+ with its :class:`_schema.Table`, when a
+ :class:`_schema.ForeignKeyConstraint`
+ is associated with a :class:`_schema.Table`, etc.
+
+ Example using the ``after_create`` event::
+
+ from sqlalchemy import event
+ from sqlalchemy import Table, Column, MetaData, Integer, text
+
+ m = MetaData()
+ some_table = Table('some_table', m, Column('data', Integer))
+
+ def after_create(target, connection, **kw):
+ connection.execute(text(
+ "ALTER TABLE %s SET name=foo_%s" % (target.name, target.name)
+ ))
+
+ event.listen(some_table, "after_create", after_create)
+
+ DDL events integrate closely with the
+ :class:`.DDL` class and the :class:`.DDLElement` hierarchy
+ of DDL clause constructs, which are themselves appropriate
+ as listener callables::
+
+ from sqlalchemy import DDL
+ event.listen(
+ some_table,
+ "after_create",
+ DDL("ALTER TABLE %(table)s SET name=foo_%(table)s")
+ )
+
+ The methods here define the name of an event as well
+ as the names of members that are passed to listener
+ functions.
+
+ For all :class:`.DDLEvents` events, the ``propagate=True`` keyword argument
+ will ensure that a given event handler is propagated to copies of the
+ object, which are made when using the :meth:`_schema.Table.to_metadata`
+ method::
+
+ from sqlalchemy import DDL
+ event.listen(
+ some_table,
+ "after_create",
+ DDL("ALTER TABLE %(table)s SET name=foo_%(table)s"),
+ propagate=True
+ )
+
+ new_table = some_table.to_metadata(new_metadata)
+
+ The above :class:`.DDL` object will also be associated with the
+ :class:`_schema.Table` object represented by ``new_table``.
+
+ .. seealso::
+
+ :ref:`event_toplevel`
+
+ :class:`.DDLElement`
+
+ :class:`.DDL`
+
+ :ref:`schema_ddl_sequences`
+
+ """
+
+ _target_class_doc = "SomeSchemaClassOrObject"
+ _dispatch_target = SchemaEventTarget
+
+ def before_create(self, target, connection, **kw):
+ r"""Called before CREATE statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ CREATE statement or statements will be emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ :func:`.event.listen` accepts the ``insert=True``
+ modifier for this event; when True, the listener function will
+ be prepended to the internal list of events upon discovery, and execute
+ before registered listener functions that do not pass this argument.
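+
+ E.g., a sketch using a hypothetical listener function
+ ``on_before_create`` against the ``some_table`` object from the
+ class-level examples::
+
+ event.listen(
+ some_table, "before_create", on_before_create, propagate=True
+ )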
+
+ """
+
+ def after_create(self, target, connection, **kw):
+ r"""Called after CREATE statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ CREATE statement or statements have been emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def before_drop(self, target, connection, **kw):
+ r"""Called before DROP statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ DROP statement or statements will be emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def after_drop(self, target, connection, **kw):
+ r"""Called after DROP statements are emitted.
+
+ :param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
+ object which is the target of the event.
+ :param connection: the :class:`_engine.Connection` where the
+ DROP statement or statements have been emitted.
+ :param \**kw: additional keyword arguments relevant
+ to the event. The contents of this dictionary
+ may vary across releases, and include the
+ list of tables being generated for a metadata-level
+ event, the checkfirst flag, and other
+ elements used by internal events.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def before_parent_attach(self, target, parent):
+ """Called before a :class:`.SchemaItem` is associated with
+ a parent :class:`.SchemaItem`.
+
+ :param target: the target object
+ :param parent: the parent to which the target is being attached.
+
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def after_parent_attach(self, target, parent):
+ """Called after a :class:`.SchemaItem` is associated with
+ a parent :class:`.SchemaItem`.
+
+ :param target: the target object
+ :param parent: the parent to which the target is being attached.
+
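+ E.g., a sketch using a hypothetical listener that reports each
+ :class:`_schema.Column` as it is attached to its
+ :class:`_schema.Table`::
+
+ from sqlalchemy import Column, event
+
+ @event.listens_for(Column, "after_parent_attach")
+ def receive_attach(target, parent):
+ print("attached %s to %s" % (target.name, parent.name))
+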
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ """
+
+ def _sa_event_column_added_to_pk_constraint(self, const, col):
+ """internal event hook used for primary key naming convention
+ updates.
+
+ """
+
+ def column_reflect(self, inspector, table, column_info):
+ """Called for each unit of 'column info' retrieved when
+ a :class:`_schema.Table` is being reflected.
+
+ This event is most easily used by applying it to a specific
+ :class:`_schema.MetaData` instance, where it will take effect for
+ all :class:`_schema.Table` objects within that
+ :class:`_schema.MetaData` that undergo reflection::
+
+ metadata = MetaData()
+
+ @event.listens_for(metadata, 'column_reflect')
+ def receive_column_reflect(inspector, table, column_info):
+ # receives for all Table objects that are reflected
+ # under this MetaData
+
+
+ # will use the above event hook
+ my_table = Table("my_table", metadata, autoload_with=some_engine)
+
+
+ .. versionadded:: 1.4.0b2 The :meth:`_events.DDLEvents.column_reflect`
+ hook may now be applied to a :class:`_schema.MetaData` object as
+ well as the :class:`_schema.MetaData` class itself where it will
+ take place for all :class:`_schema.Table` objects associated with
+ the targeted :class:`_schema.MetaData`.
+
+ It may also be applied to the :class:`_schema.Table` class across
+ the board::
+
+ from sqlalchemy import Table
+
+ @event.listens_for(Table, 'column_reflect')
+ def receive_column_reflect(inspector, table, column_info):
+ # receives for all Table objects that are reflected
+
+ It can also be applied to a specific :class:`_schema.Table` at the
+ point that one is being reflected using the
+ :paramref:`_schema.Table.listeners` parameter::
+
+ t1 = Table(
+ "my_table",
+ autoload_with=some_engine,
+ listeners=[
+ ('column_reflect', receive_column_reflect)
+ ]
+ )
+
+ The dictionary of column information as returned by the
+ dialect is passed, and can be modified. The dictionary
+ is that returned in each element of the list returned
+ by :meth:`.reflection.Inspector.get_columns`:
+
+ * ``name`` - the column's name, is applied to the
+ :paramref:`_schema.Column.name` parameter
+
+ * ``type`` - the type of this column, which should be an instance
+ of :class:`~sqlalchemy.types.TypeEngine`, is applied to the
+ :paramref:`_schema.Column.type` parameter
+
+ * ``nullable`` - boolean flag if the column is NULL or NOT NULL,
+ is applied to the :paramref:`_schema.Column.nullable` parameter
+
+ * ``default`` - the column's server default value. This is
+ normally specified as a plain string SQL expression, however the
+ event can pass a :class:`.FetchedValue`, :class:`.DefaultClause`,
+ or :func:`_expression.text` object as well. Is applied to the
+ :paramref:`_schema.Column.server_default` parameter
+
+ The event is called before any action is taken against
+ this dictionary, and the contents can be modified; the following
+ additional keys may be added to the dictionary to further modify
+ how the :class:`_schema.Column` is constructed:
+
+
+ * ``key`` - the string key that will be used to access this
+ :class:`_schema.Column` in the ``.c`` collection; will be applied
+ to the :paramref:`_schema.Column.key` parameter. Is also used
+ for ORM mapping. See the section
+ :ref:`mapper_automated_reflection_schemes` for an example.
+
+ * ``quote`` - force or un-force quoting on the column name;
+ is applied to the :paramref:`_schema.Column.quote` parameter.
+
+ * ``info`` - a dictionary of arbitrary data to follow along with
+ the :class:`_schema.Column`, is applied to the
+ :paramref:`_schema.Column.info` parameter.
+
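+ Taken together, a sketch of a listener that applies a hypothetical
+ lower-casing convention to the ``.c`` collection key of each
+ reflected column::
+
+ @event.listens_for(Table, 'column_reflect')
+ def receive_column_reflect(inspector, table, column_info):
+ column_info['key'] = column_info['name'].lower()
+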
+ :func:`.event.listen` also accepts the ``propagate=True``
+ modifier for this event; when True, the listener function will
+ be established for any copies made of the target object,
+ i.e. those copies that are generated when
+ :meth:`_schema.Table.to_metadata` is used.
+
+ .. seealso::
+
+ :ref:`mapper_automated_reflection_schemes` -
+ in the ORM mapping documentation
+
+ :ref:`automap_intercepting_columns` -
+ in the :ref:`automap_toplevel` documentation
+
+ :ref:`metadata_reflection_dbagnostic_types` - in
+ the :ref:`metadata_reflection_toplevel` documentation
+
+ """
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
new file mode 100644
index 0000000..b4aa14e
--- /dev/null
+++ b/lib/sqlalchemy/sql/expression.py
@@ -0,0 +1,278 @@
+# sql/expression.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines the public namespace for SQL expression constructs.
+
+Prior to version 0.9, this module contained all of "elements", "dml",
+"default_comparator" and "selectable". The module was broken up
+and most "factory" functions were moved to be grouped with their associated
+class.
+
+"""
+
+__all__ = [
+ "Alias",
+ "AliasedReturnsRows",
+ "any_",
+ "all_",
+ "CacheKey",
+ "ClauseElement",
+ "ColumnCollection",
+ "ColumnElement",
+ "CompoundSelect",
+ "Delete",
+ "FromClause",
+ "Insert",
+ "Join",
+ "Lateral",
+ "LambdaElement",
+ "StatementLambdaElement",
+ "Select",
+ "Selectable",
+ "TableClause",
+ "TableValuedAlias",
+ "Update",
+ "Values",
+ "alias",
+ "and_",
+ "asc",
+ "between",
+ "bindparam",
+ "case",
+ "cast",
+ "column",
+ "custom_op",
+ "cte",
+ "delete",
+ "desc",
+ "distinct",
+ "except_",
+ "except_all",
+ "exists",
+ "extract",
+ "func",
+ "modifier",
+ "collate",
+ "insert",
+ "intersect",
+ "intersect_all",
+ "join",
+ "label",
+ "lateral",
+ "lambda_stmt",
+ "literal",
+ "literal_column",
+ "not_",
+ "null",
+ "nulls_first",
+ "nulls_last",
+ "or_",
+ "outparam",
+ "outerjoin",
+ "over",
+ "select",
+ "table",
+ "text",
+ "tuple_",
+ "type_coerce",
+ "quoted_name",
+ "union",
+ "union_all",
+ "update",
+ "quoted_name",
+ "within_group",
+ "Subquery",
+ "TableSample",
+ "tablesample",
+ "values",
+]
+
+
+from .base import _from_objects
+from .base import _select_iterables
+from .base import ColumnCollection
+from .base import Executable
+from .base import PARSE_AUTOCOMMIT
+from .dml import Delete
+from .dml import Insert
+from .dml import Update
+from .dml import UpdateBase
+from .dml import ValuesBase
+from .elements import _truncated_label
+from .elements import between
+from .elements import BinaryExpression
+from .elements import BindParameter
+from .elements import BooleanClauseList
+from .elements import Case
+from .elements import Cast
+from .elements import ClauseElement
+from .elements import ClauseList
+from .elements import collate
+from .elements import CollectionAggregate
+from .elements import ColumnClause
+from .elements import ColumnElement
+from .elements import Extract
+from .elements import False_
+from .elements import FunctionFilter
+from .elements import Grouping
+from .elements import Label
+from .elements import literal
+from .elements import literal_column
+from .elements import not_
+from .elements import Null
+from .elements import outparam
+from .elements import Over
+from .elements import quoted_name
+from .elements import ReleaseSavepointClause
+from .elements import RollbackToSavepointClause
+from .elements import SavepointClause
+from .elements import TextClause
+from .elements import True_
+from .elements import Tuple
+from .elements import TypeClause
+from .elements import TypeCoerce
+from .elements import UnaryExpression
+from .elements import WithinGroup
+from .functions import func
+from .functions import Function
+from .functions import FunctionElement
+from .functions import modifier
+from .lambdas import lambda_stmt
+from .lambdas import LambdaElement
+from .lambdas import StatementLambdaElement
+from .operators import ColumnOperators
+from .operators import custom_op
+from .operators import Operators
+from .selectable import Alias
+from .selectable import AliasedReturnsRows
+from .selectable import CompoundSelect
+from .selectable import CTE
+from .selectable import Exists
+from .selectable import FromClause
+from .selectable import FromGrouping
+from .selectable import GenerativeSelect
+from .selectable import HasCTE
+from .selectable import HasPrefixes
+from .selectable import HasSuffixes
+from .selectable import Join
+from .selectable import LABEL_STYLE_DEFAULT
+from .selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
+from .selectable import LABEL_STYLE_NONE
+from .selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from .selectable import Lateral
+from .selectable import ReturnsRows
+from .selectable import ScalarSelect
+from .selectable import Select
+from .selectable import Selectable
+from .selectable import SelectBase
+from .selectable import Subquery
+from .selectable import subquery
+from .selectable import TableClause
+from .selectable import TableSample
+from .selectable import TableValuedAlias
+from .selectable import TextAsFrom
+from .selectable import TextualSelect
+from .selectable import Values
+from .traversals import CacheKey
+from .visitors import Visitable
+from ..util.langhelpers import public_factory
+
+# factory functions - these pull class-bound constructors and classmethods
+# from SQL elements and selectables into public functions. This allows
+# the functions to be available in the sqlalchemy.sql.* namespace and
+# to be auto-cross-documenting from the function to the class itself.
+
+all_ = public_factory(CollectionAggregate._create_all, ".sql.expression.all_")
+any_ = public_factory(CollectionAggregate._create_any, ".sql.expression.any_")
+and_ = public_factory(BooleanClauseList.and_, ".sql.expression.and_")
+alias = public_factory(Alias._factory, ".sql.expression.alias")
+tablesample = public_factory(
+ TableSample._factory, ".sql.expression.tablesample"
+)
+lateral = public_factory(Lateral._factory, ".sql.expression.lateral")
+or_ = public_factory(BooleanClauseList.or_, ".sql.expression.or_")
+bindparam = public_factory(BindParameter, ".sql.expression.bindparam")
+select = public_factory(Select._create, ".sql.expression.select")
+text = public_factory(TextClause._create_text, ".sql.expression.text")
+table = public_factory(TableClause, ".sql.expression.table")
+column = public_factory(ColumnClause, ".sql.expression.column")
+over = public_factory(Over, ".sql.expression.over")
+within_group = public_factory(WithinGroup, ".sql.expression.within_group")
+label = public_factory(Label, ".sql.expression.label")
+case = public_factory(Case, ".sql.expression.case")
+cast = public_factory(Cast, ".sql.expression.cast")
+cte = public_factory(CTE._factory, ".sql.expression.cte")
+values = public_factory(Values, ".sql.expression.values")
+extract = public_factory(Extract, ".sql.expression.extract")
+tuple_ = public_factory(Tuple, ".sql.expression.tuple_")
+except_ = public_factory(
+ CompoundSelect._create_except, ".sql.expression.except_"
+)
+except_all = public_factory(
+ CompoundSelect._create_except_all, ".sql.expression.except_all"
+)
+intersect = public_factory(
+ CompoundSelect._create_intersect, ".sql.expression.intersect"
+)
+intersect_all = public_factory(
+ CompoundSelect._create_intersect_all, ".sql.expression.intersect_all"
+)
+union = public_factory(CompoundSelect._create_union, ".sql.expression.union")
+union_all = public_factory(
+ CompoundSelect._create_union_all, ".sql.expression.union_all"
+)
+exists = public_factory(Exists, ".sql.expression.exists")
+nulls_first = public_factory(
+ UnaryExpression._create_nulls_first, ".sql.expression.nulls_first"
+)
+nullsfirst = nulls_first # deprecated 1.4; see #5435
+nulls_last = public_factory(
+ UnaryExpression._create_nulls_last, ".sql.expression.nulls_last"
+)
+nullslast = nulls_last # deprecated 1.4; see #5435
+asc = public_factory(UnaryExpression._create_asc, ".sql.expression.asc")
+desc = public_factory(UnaryExpression._create_desc, ".sql.expression.desc")
+distinct = public_factory(
+ UnaryExpression._create_distinct, ".sql.expression.distinct"
+)
+type_coerce = public_factory(TypeCoerce, ".sql.expression.type_coerce")
+true = public_factory(True_._instance, ".sql.expression.true")
+false = public_factory(False_._instance, ".sql.expression.false")
+null = public_factory(Null._instance, ".sql.expression.null")
+join = public_factory(Join._create_join, ".sql.expression.join")
+outerjoin = public_factory(Join._create_outerjoin, ".sql.expression.outerjoin")
+insert = public_factory(Insert, ".sql.expression.insert")
+update = public_factory(Update, ".sql.expression.update")
+delete = public_factory(Delete, ".sql.expression.delete")
+funcfilter = public_factory(FunctionFilter, ".sql.expression.funcfilter")
+
+
+# internal functions still being called from tests and the ORM,
+# these might be better off in some other namespace
+
+
+# old names for compatibility
+_Executable = Executable
+_BindParamClause = BindParameter
+_Label = Label
+_SelectBase = SelectBase
+_BinaryExpression = BinaryExpression
+_Cast = Cast
+_Null = Null
+_False = False_
+_True = True_
+_TextClause = TextClause
+_UnaryExpression = UnaryExpression
+_Case = Case
+_Tuple = Tuple
+_Over = Over
+_TypeClause = TypeClause
+_Extract = Extract
+_Exists = Exists
+_Grouping = Grouping
+_FromGrouping = FromGrouping
+_ScalarSelect = ScalarSelect
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
new file mode 100644
index 0000000..29f4122
--- /dev/null
+++ b/lib/sqlalchemy/sql/functions.py
@@ -0,0 +1,1575 @@
+# sql/functions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQL function API, factories, and built-in functions.
+
+"""
+from . import annotation
+from . import coercions
+from . import operators
+from . import roles
+from . import schema
+from . import sqltypes
+from . import util as sqlutil
+from .base import _entity_namespace
+from .base import ColumnCollection
+from .base import Executable
+from .base import Generative
+from .base import HasMemoized
+from .elements import _type_from_args
+from .elements import BinaryExpression
+from .elements import BindParameter
+from .elements import Cast
+from .elements import ClauseList
+from .elements import ColumnElement
+from .elements import Extract
+from .elements import FunctionFilter
+from .elements import Grouping
+from .elements import literal_column
+from .elements import NamedColumn
+from .elements import Over
+from .elements import WithinGroup
+from .selectable import FromClause
+from .selectable import Select
+from .selectable import TableValuedAlias
+from .visitors import InternalTraversal
+from .visitors import TraversibleType
+from .. import util
+
+
+_registry = util.defaultdict(dict)
+
+
+def register_function(identifier, fn, package="_default"):
+ """Associate a callable with a particular func. name.
+
+ This is normally called by _GenericMeta, but is also
+ available by itself so that a non-Function construct
+ can be associated with the :data:`.func` accessor (i.e.
+ CAST, EXTRACT).
+
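+ E.g., a sketch of associating the ``Cast`` construct with
+ ``func.cast(...)``, as this module does for CAST and EXTRACT::
+
+ register_function("cast", Cast)
+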
+ """
+ reg = _registry[package]
+
+ identifier = util.text_type(identifier).lower()
+
+ # Check if a function with the same identifier is registered.
+ if identifier in reg:
+ util.warn(
+ "The GenericFunction '{}' is already registered and "
+ "is going to be overridden.".format(identifier)
+ )
+ reg[identifier] = fn
+
+
+class FunctionElement(Executable, ColumnElement, FromClause, Generative):
+ """Base for SQL function-oriented constructs.
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ :class:`.Function` - named SQL function.
+
+ :data:`.func` - namespace which produces registered or ad-hoc
+ :class:`.Function` instances.
+
+ :class:`.GenericFunction` - allows creation of registered function
+ types.
+
+ """
+
+ _traverse_internals = [
+ ("clause_expr", InternalTraversal.dp_clauseelement),
+ ("_with_ordinality", InternalTraversal.dp_boolean),
+ ("_table_value_type", InternalTraversal.dp_has_cache_key),
+ ]
+
+ packagenames = ()
+
+ _has_args = False
+ _with_ordinality = False
+ _table_value_type = None
+
+ def __init__(self, *clauses, **kwargs):
+ r"""Construct a :class:`.FunctionElement`.
+
+ :param \*clauses: list of column expressions that form the arguments
+ of the SQL function call.
+
+ :param \**kwargs: additional kwargs are typically consumed by
+ subclasses.
+
+ .. seealso::
+
+ :data:`.func`
+
+ :class:`.Function`
+
+ """
+ args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=getattr(self, "name", None),
+ apply_propagate_attrs=self,
+ )
+ for c in clauses
+ ]
+ self._has_args = self._has_args or bool(args)
+ self.clause_expr = ClauseList(
+ operator=operators.comma_op, group_contents=True, *args
+ ).self_group()
+
+ _non_anon_label = None
+
+ @property
+ def _proxy_key(self):
+ return super(FunctionElement, self)._proxy_key or getattr(
+ self, "name", None
+ )
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ return connection._execute_function(
+ self, multiparams, params, execution_options
+ )
+
+ def scalar_table_valued(self, name, type_=None):
+ """Return a column expression that's against this
+ :class:`_functions.FunctionElement` as a scalar
+ table-valued expression.
+
+ The returned expression is similar to that returned by a single column
+ accessed off of a :meth:`_functions.FunctionElement.table_valued`
+ construct, except no FROM clause is generated; the function is rendered
+ in the similar way as a scalar subquery.
+
+ E.g.::
+
+ >>> from sqlalchemy import func, select
+ >>> fn = func.jsonb_each("{'k', 'v'}").scalar_table_valued("key")
+ >>> print(select(fn))
+ SELECT (jsonb_each(:jsonb_each_1)).key
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :meth:`_functions.FunctionElement.table_valued`
+
+ :meth:`_functions.FunctionElement.alias`
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+ """ # noqa: E501
+
+ return ScalarFunctionColumn(self, name, type_)
+
+ def table_valued(self, *expr, **kw):
+ r"""Return a :class:`_sql.TableValuedAlias` representation of this
+ :class:`_functions.FunctionElement` with table-valued expressions added.
+
+ e.g.::
+
+ >>> fn = (
+ ... func.generate_series(1, 5).
+ ... table_valued("value", "start", "stop", "step")
+ ... )
+
+ >>> print(select(fn))
+ SELECT anon_1.value, anon_1.start, anon_1.stop, anon_1.step
+ FROM generate_series(:generate_series_1, :generate_series_2) AS anon_1
+
+ >>> print(select(fn.c.value, fn.c.stop).where(fn.c.value > 2))
+ SELECT anon_1.value, anon_1.stop
+ FROM generate_series(:generate_series_1, :generate_series_2) AS anon_1
+ WHERE anon_1.value > :value_1
+
+ A WITH ORDINALITY expression may be generated by passing the keyword
+ argument "with_ordinality"::
+
+ >>> fn = func.generate_series(4, 1, -1).table_valued("gen", with_ordinality="ordinality")
+ >>> print(select(fn))
+ SELECT anon_1.gen, anon_1.ordinality
+ FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) WITH ORDINALITY AS anon_1
+
+ :param \*expr: A series of string column names that will be added to the
+ ``.c`` collection of the resulting :class:`_sql.TableValuedAlias`
+ construct as columns. :func:`_sql.column` objects with or without
+ datatypes may also be used.
+
+ :param name: optional name to assign to the alias name that's generated.
+ If omitted, a unique anonymizing name is used.
+
+ :param with_ordinality: string name that when present results in the
+ ``WITH ORDINALITY`` clause being added to the alias, and the given
+ string name will be added as a column to the ``.c`` collection
+ of the resulting :class:`_sql.TableValuedAlias`.
+
+ :param joins_implicitly: when True, the table valued function may be
+ used in the FROM clause without any explicit JOIN to other tables
+ in the SQL query, and no "cartesian product" warning will be generated.
+ May be useful for SQL functions such as ``func.json_each()``.
+
+ .. versionadded:: 1.4.33
+
+ .. versionadded:: 1.4.0b2
+
+
+ .. seealso::
+
+ :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial`
+
+ :ref:`postgresql_table_valued` - in the :ref:`postgresql_toplevel` documentation
+
+ :meth:`_functions.FunctionElement.scalar_table_valued` - variant of
+ :meth:`_functions.FunctionElement.table_valued` which delivers the
+ complete table valued expression as a scalar column expression
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+ :meth:`_sql.TableValuedAlias.render_derived` - renders the alias
+ using a derived column clause, e.g. ``AS name(col1, col2, ...)``
+
+ """ # noqa: 501
+
+ new_func = self._generate()
+
+ with_ordinality = kw.pop("with_ordinality", None)
+ joins_implicitly = kw.pop("joins_implicitly", None)
+ name = kw.pop("name", None)
+
+ if with_ordinality:
+ expr += (with_ordinality,)
+ new_func._with_ordinality = True
+
+ new_func.type = new_func._table_value_type = sqltypes.TableValueType(
+ *expr
+ )
+
+ return new_func.alias(name=name, joins_implicitly=joins_implicitly)
+
+ def column_valued(self, name=None):
+ """Return this :class:`_functions.FunctionElement` as a column expression that
+ selects from itself as a FROM clause.
+
+ E.g.::
+
+ >>> from sqlalchemy import select, func
+ >>> gs = func.generate_series(1, 5, -1).column_valued()
+ >>> print(select(gs))
+ SELECT anon_1
+ FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) AS anon_1
+
+ This is shorthand for::
+
+ gs = func.generate_series(1, 5, -1).alias().column
+
+
+ .. seealso::
+
+ :ref:`tutorial_functions_column_valued` - in the :ref:`unified_tutorial`
+
+ :ref:`postgresql_column_valued` - in the :ref:`postgresql_toplevel` documentation
+
+ :meth:`_functions.FunctionElement.table_valued`
+
+ """ # noqa: 501
+
+ return self.alias(name=name).column
+
+ @property
+ def columns(self):
+ r"""The set of columns exported by this :class:`.FunctionElement`.
+
+ This is a placeholder collection that allows the function to be
+ placed in the FROM clause of a statement::
+
+ >>> from sqlalchemy import column, select, func
+ >>> stmt = select(column('x'), column('y')).select_from(func.myfunction())
+ >>> print(stmt)
+ SELECT x, y FROM myfunction()
+
+ The above form is a legacy feature that is now superseded by the
+ fully capable :meth:`_functions.FunctionElement.table_valued`
+ method; see that method for details.
+
+ .. seealso::
+
+ :meth:`_functions.FunctionElement.table_valued` - generates table-valued
+ SQL function expressions.
+
+ """ # noqa: E501
+
+ return ColumnCollection(
+ columns=[(col.key, col) for col in self._all_selected_columns]
+ )
+
+ @property
+ def _all_selected_columns(self):
+ if self.type._is_table_value:
+ cols = self.type._elements
+ else:
+ cols = [self.label(None)]
+
+ return cols
+
+ @property
+ def exported_columns(self):
+ return self.columns
+
+ @HasMemoized.memoized_attribute
+ def clauses(self):
+ """Return the underlying :class:`.ClauseList` which contains
+ the arguments for this :class:`.FunctionElement`.
+
+ """
+ return self.clause_expr.element
+
+ def over(self, partition_by=None, order_by=None, rows=None, range_=None):
+ """Produce an OVER clause against this function.
+
+ Used against aggregate or so-called "window" functions,
+ for database backends that support window functions.
+
+ The expression::
+
+ func.row_number().over(order_by='x')
+
+ is shorthand for::
+
+ from sqlalchemy import over
+ over(func.row_number(), order_by='x')
+
+ See :func:`_expression.over` for a full description.
+
+ .. seealso::
+
+ :func:`_expression.over`
+
+ :ref:`tutorial_window_functions` - in the :ref:`unified_tutorial`
+
+ """
+ return Over(
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ rows=rows,
+ range_=range_,
+ )
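+
+ # Illustrative sketch (not part of the upstream docstring): a window function
+ # combining a partition, an ordering and a row frame, assuming a hypothetical
+ # ``orders`` table:
+ #
+ #     func.sum(orders.c.amount).over(
+ #         partition_by=orders.c.customer_id,
+ #         order_by=orders.c.placed_at,
+ #         rows=(None, 0),  # UNBOUNDED PRECEDING .. CURRENT ROW
+ #     )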
+
+ def within_group(self, *order_by):
+ """Produce a WITHIN GROUP (ORDER BY expr) clause against this function.
+
+ Used against so-called "ordered set aggregate" and "hypothetical
+ set aggregate" functions, including :class:`.percentile_cont`,
+ :class:`.rank`, :class:`.dense_rank`, etc.
+
+ See :func:`_expression.within_group` for a full description.
+
+ .. versionadded:: 1.1
+
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` -
+ in the :ref:`unified_tutorial`
+
+
+ """
+ return WithinGroup(self, *order_by)
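+
+ # Illustrative sketch (not part of the upstream docstring): an ordered-set
+ # aggregate, assuming a hypothetical ``data`` table:
+ #
+ #     select(func.percentile_cont(0.5).within_group(data.c.value))
+ #
+ # renders roughly as
+ #     percentile_cont(:percentile_cont_1) WITHIN GROUP (ORDER BY data.value)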
+
+ def filter(self, *criterion):
+ """Produce a FILTER clause against this function.
+
+ Used against aggregate and window functions,
+ for database backends that support the "FILTER" clause.
+
+ The expression::
+
+ func.count(1).filter(True)
+
+ is shorthand for::
+
+ from sqlalchemy import funcfilter
+ funcfilter(func.count(1), True)
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :ref:`tutorial_functions_within_group` -
+ in the :ref:`unified_tutorial`
+
+ :class:`.FunctionFilter`
+
+ :func:`.funcfilter`
+
+
+ """
+ if not criterion:
+ return self
+ return FunctionFilter(self, *criterion)
+
+ def as_comparison(self, left_index, right_index):
+ """Interpret this expression as a boolean comparison between two
+ values.
+
+ This method is used for an ORM use case described at
+ :ref:`relationship_custom_operator_sql_function`.
+
+ A hypothetical SQL function "is_equal()" which compares two values
+ for equality would be written in the Core expression language as::
+
+ expr = func.is_equal("a", "b")
+
+ If "is_equal()" above is comparing "a" and "b" for equality, the
+ :meth:`.FunctionElement.as_comparison` method would be invoked as::
+
+ expr = func.is_equal("a", "b").as_comparison(1, 2)
+
+ Where above, the integer value "1" refers to the first argument of the
+ "is_equal()" function and the integer value "2" refers to the second.
+
+ This would create a :class:`.BinaryExpression` that is equivalent to::
+
+ BinaryExpression("a", "b", operator=op.eq)
+
+ However, at the SQL level it would still render as
+ "is_equal('a', 'b')".
+
+ The ORM, when it loads a related object or collection, needs to be able
+ to manipulate the "left" and "right" sides of the ON clause of a JOIN
+ expression. The purpose of this method is to provide a SQL function
+ construct that can also supply this information to the ORM, when used
+ with the :paramref:`_orm.relationship.primaryjoin` parameter. The
+ return value is a containment object called :class:`.FunctionAsBinary`.
+
+ An ORM example is as follows::
+
+ class Venue(Base):
+ __tablename__ = 'venue'
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ descendants = relationship(
+ "Venue",
+ primaryjoin=func.instr(
+ remote(foreign(name)), name + "/"
+ ).as_comparison(1, 2) == 1,
+ viewonly=True,
+ order_by=name
+ )
+
+ Above, the "Venue" class can load descendant "Venue" objects by
+ determining if the name of the parent Venue is contained within the
+ start of the hypothetical descendant value's name, e.g. "parent1" would
+ match up to "parent1/child1", but not to "parent2/child1".
+
+ Possible use cases include the "materialized path" example given above,
+ as well as making use of special SQL functions such as geometric
+ functions to create join conditions.
+
+ :param left_index: the integer 1-based index of the function argument
+ that serves as the "left" side of the expression.
+ :param right_index: the integer 1-based index of the function argument
+ that serves as the "right" side of the expression.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`relationship_custom_operator_sql_function` -
+ example use within the ORM
+
+ """
+ return FunctionAsBinary(self, left_index, right_index)
+
+ @property
+ def _from_objects(self):
+ return self.clauses._from_objects
+
+ def within_group_type(self, within_group):
+ """For types that define their return type as based on the criteria
+ within a WITHIN GROUP (ORDER BY) expression, called by the
+ :class:`.WithinGroup` construct.
+
+ Returns None by default, in which case the function's normal ``.type``
+ is used.
+
+ """
+
+ return None
+
+ def alias(self, name=None, joins_implicitly=False):
+ r"""Produce a :class:`_expression.Alias` construct against this
+ :class:`.FunctionElement`.
+
+ .. tip::
+
+ The :meth:`_functions.FunctionElement.alias` method is part of the
+ mechanism by which "table valued" SQL functions are created.
+ However, most use cases are covered by higher level methods on
+ :class:`_functions.FunctionElement` including
+ :meth:`_functions.FunctionElement.table_valued`, and
+ :meth:`_functions.FunctionElement.column_valued`.
+
+ This construct wraps the function in a named alias which
+ is suitable for the FROM clause, in the style accepted for example
+ by PostgreSQL. A column expression is also provided using the
+ special ``.column`` attribute, which may
+ be used to refer to the output of the function as a scalar value
+ in the columns or where clause, for a backend such as PostgreSQL.
+
+ For a full table-valued expression, use the
+ :meth:`_functions.FunctionElement.table_valued` method first to
+ establish named columns.
+
+ e.g.::
+
+ >>> from sqlalchemy import func, select, column
+ >>> data_view = func.unnest([1, 2, 3]).alias("data_view")
+ >>> print(select(data_view.column))
+ SELECT data_view
+ FROM unnest(:unnest_1) AS data_view
+
+ The :meth:`_functions.FunctionElement.column_valued` method provides
+ a shortcut for the above pattern::
+
+ >>> data_view = func.unnest([1, 2, 3]).column_valued("data_view")
+ >>> print(select(data_view))
+ SELECT data_view
+ FROM unnest(:unnest_1) AS data_view
+
+ .. versionadded:: 1.4.0b2 Added the ``.column`` accessor
+
+ :param name: alias name, will be rendered as ``AS <name>`` in the
+ FROM clause
+
+ :param joins_implicitly: when True, the table valued function may be
+ used in the FROM clause without any explicit JOIN to other tables
+ in the SQL query, and no "cartesian product" warning will be
+ generated. May be useful for SQL functions such as
+ ``func.json_each()``.
+
+ .. versionadded:: 1.4.33
+
+ .. seealso::
+
+ :ref:`tutorial_functions_table_valued` -
+ in the :ref:`unified_tutorial`
+
+ :meth:`_functions.FunctionElement.table_valued`
+
+ :meth:`_functions.FunctionElement.scalar_table_valued`
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+
+ """
+
+ return TableValuedAlias._construct(
+ self,
+ name,
+ table_value_type=self.type,
+ joins_implicitly=joins_implicitly,
+ )
+
+ def select(self):
+ """Produce a :func:`_expression.select` construct
+ against this :class:`.FunctionElement`.
+
+ This is shorthand for::
+
+ s = select(function_element)
+
+ """
+ s = Select._create_select(self)
+ if self._execution_options:
+ s = s.execution_options(**self._execution_options)
+ return s
+
+ @util.deprecated_20(
+ ":meth:`.FunctionElement.scalar`",
+ alternative="Scalar execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.scalar` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.scalar` method of "
+ ":class:`.Session`.",
+ )
+ def scalar(self):
+ """Execute this :class:`.FunctionElement` against an embedded
+ 'bind' and return a scalar value.
+
+ This first calls :meth:`~.FunctionElement.select` to
+ produce a SELECT construct.
+
+ Note that :class:`.FunctionElement` can be passed to
+ the :meth:`.Connectable.scalar` method of :class:`_engine.Connection`
+ or :class:`_engine.Engine`.
+
+ """
+ return self.select().execute().scalar()
+
+ @util.deprecated_20(
+ ":meth:`.FunctionElement.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self):
+ """Execute this :class:`.FunctionElement` against an embedded
+ 'bind'.
+
+ This first calls :meth:`~.FunctionElement.select` to
+ produce a SELECT construct.
+
+ Note that :class:`.FunctionElement` can be passed to
+ the :meth:`.Connectable.execute` method of :class:`_engine.Connection`
+ or :class:`_engine.Engine`.
+
+ """
+ return self.select().execute()
+
+ def _bind_param(self, operator, obj, type_=None, **kw):
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ unique=True,
+ type_=type_,
+ **kw
+ )
+
+ def self_group(self, against=None):
+ # for the moment, we are parenthesizing all array-returning
+ # expressions against getitem. This may need to be made
+ # more portable if in the future we support other DBs
+ # besides postgresql.
+ if against is operators.getitem and isinstance(
+ self.type, sqltypes.ARRAY
+ ):
+ return Grouping(self)
+ else:
+ return super(FunctionElement, self).self_group(against=against)
+
+ @property
+ def entity_namespace(self):
+ """overrides FromClause.entity_namespace as functions are generally
+ column expressions and not FromClauses.
+
+ """
+ # ideally functions would not be fromclauses but we failed to make
+ # this adjustment in 1.4
+ return _entity_namespace(self.clause_expr)
+
+
+class FunctionAsBinary(BinaryExpression):
+ _traverse_internals = [
+ ("sql_function", InternalTraversal.dp_clauseelement),
+ ("left_index", InternalTraversal.dp_plain_obj),
+ ("right_index", InternalTraversal.dp_plain_obj),
+ ("modifiers", InternalTraversal.dp_plain_dict),
+ ]
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ return ColumnElement._gen_cache_key(self, anon_map, bindparams)
+
+ def __init__(self, fn, left_index, right_index):
+ self.sql_function = fn
+ self.left_index = left_index
+ self.right_index = right_index
+
+ self.operator = operators.function_as_comparison_op
+ self.type = sqltypes.BOOLEANTYPE
+ self.negate = None
+ self._is_implicitly_boolean = True
+ self.modifiers = {}
+
+ @property
+ def left(self):
+ return self.sql_function.clauses.clauses[self.left_index - 1]
+
+ @left.setter
+ def left(self, value):
+ self.sql_function.clauses.clauses[self.left_index - 1] = value
+
+ @property
+ def right(self):
+ return self.sql_function.clauses.clauses[self.right_index - 1]
+
+ @right.setter
+ def right(self, value):
+ self.sql_function.clauses.clauses[self.right_index - 1] = value
+
+
+class ScalarFunctionColumn(NamedColumn):
+ __visit_name__ = "scalar_function_column"
+
+ _traverse_internals = [
+ ("name", InternalTraversal.dp_anon_name),
+ ("type", InternalTraversal.dp_type),
+ ("fn", InternalTraversal.dp_clauseelement),
+ ]
+
+ is_literal = False
+ table = None
+
+ def __init__(self, fn, name, type_=None):
+ self.fn = fn
+ self.name = name
+ self.type = sqltypes.to_instance(type_)
+
+
+class _FunctionGenerator(object):
+ """Generate SQL function expressions.
+
+ :data:`.func` is a special object instance which generates SQL
+ functions based on name-based attributes, e.g.::
+
+ >>> print(func.count(1))
+ count(:param_1)
+
+ The returned object is an instance of :class:`.Function`, and is a
+ column-oriented SQL element like any other, and is used in that way::
+
+ >>> print(select(func.count(table.c.id)))
+ SELECT count(sometable.id) FROM sometable
+
+ Any name can be given to :data:`.func`. If the function name is unknown to
+ SQLAlchemy, it will be rendered exactly as is. For common SQL functions
+ which SQLAlchemy is aware of, the name may be interpreted as a *generic
+ function* which will be compiled appropriately to the target database::
+
+ >>> print(func.current_timestamp())
+ CURRENT_TIMESTAMP
+
+ To call functions which are present in dot-separated packages,
+ specify them in the same manner::
+
+ >>> print(func.stats.yield_curve(5, 10))
+ stats.yield_curve(:yield_curve_1, :yield_curve_2)
+
+ SQLAlchemy can be made aware of the return type of functions to enable
+ type-specific lexical and result-based behavior. For example, to ensure
+ that a string-based function returns a Unicode value and is similarly
+ treated as a string in expressions, specify
+ :class:`~sqlalchemy.types.Unicode` as the type:
+
+ >>> print(func.my_string(u'hi', type_=Unicode) + ' ' +
+ ... func.my_string(u'there', type_=Unicode))
+ my_string(:my_string_1) || :my_string_2 || my_string(:my_string_3)
+
+ The object returned by a :data:`.func` call is usually an instance of
+ :class:`.Function`.
+ This object meets the "column" interface, including comparison and labeling
+ functions. The object can also be passed to the :meth:`~.Connectable.execute`
+ method of a :class:`_engine.Connection` or :class:`_engine.Engine`,
+ where it will be
+ wrapped inside of a SELECT statement first::
+
+ print(connection.execute(func.current_timestamp()).scalar())
+
+ In a few exceptional cases, the :data:`.func` accessor
+ will redirect a name to a built-in expression such as :func:`.cast`
+ or :func:`.extract`, as these names have well-known meaning
+ but are not exactly the same as "functions" from a SQLAlchemy
+ perspective.
+
+ Functions which are interpreted as "generic" functions know how to
+ calculate their return type automatically. For a listing of known generic
+ functions, see :ref:`generic_functions`.
+
+ .. note::
+
+ The :data:`.func` construct has only limited support for calling
+ standalone "stored procedures", especially those with special
+ parameterization concerns.
+
+ See the section :ref:`stored_procedures` for details on how to use
+ the DBAPI-level ``callproc()`` method for fully traditional stored
+ procedures.
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ :class:`.Function`
+
+ """
+
+ def __init__(self, **opts):
+ self.__names = []
+ self.opts = opts
+
+ def __getattr__(self, name):
+ # passthru __ attributes; fixes pydoc
+ if name.startswith("__"):
+ try:
+ return self.__dict__[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ elif name.endswith("_"):
+ name = name[0:-1]
+ f = _FunctionGenerator(**self.opts)
+ f.__names = list(self.__names) + [name]
+ return f
+
+ def __call__(self, *c, **kwargs):
+ o = self.opts.copy()
+ o.update(kwargs)
+
+ tokens = len(self.__names)
+
+ if tokens == 2:
+ package, fname = self.__names
+ elif tokens == 1:
+ package, fname = "_default", self.__names[0]
+ else:
+ package = None
+
+ if package is not None:
+ func = _registry[package].get(fname.lower())
+ if func is not None:
+ return func(*c, **o)
+
+ return Function(
+ self.__names[-1], packagenames=tuple(self.__names[0:-1]), *c, **o
+ )
+
+
+func = _FunctionGenerator()
+func.__doc__ = _FunctionGenerator.__doc__
+
+modifier = _FunctionGenerator(group=False)
+
+
+class Function(FunctionElement):
+ r"""Describe a named SQL function.
+
+ The :class:`.Function` object is typically generated from the
+ :data:`.func` generation object.
+
+
+ :param \*clauses: list of column expressions that form the arguments
+ of the SQL function call.
+
+ :param type\_: optional :class:`.TypeEngine` datatype object that will be
+ used as the return value of the column expression generated by this
+ function call.
+
+ :param packagenames: a sequence of strings which indicate package prefix names
+ to be prepended to the function name when the SQL is generated.
+ The :data:`.func` generator creates these when it is called using
+ dotted format, e.g.::
+
+ func.mypackage.some_function(col1, col2)
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ :data:`.func` - namespace which produces registered or ad-hoc
+ :class:`.Function` instances.
+
+ :class:`.GenericFunction` - allows creation of registered function
+ types.
+
+ """
+
+ __visit_name__ = "function"
+
+ _traverse_internals = FunctionElement._traverse_internals + [
+ ("packagenames", InternalTraversal.dp_plain_obj),
+ ("name", InternalTraversal.dp_string),
+ ("type", InternalTraversal.dp_type),
+ ]
+
+ type = sqltypes.NULLTYPE
+ """A :class:`_types.TypeEngine` object which refers to the SQL return
+ type represented by this SQL function.
+
+ This datatype may be configured when generating a
+ :class:`_functions.Function` object by passing the
+ :paramref:`_functions.Function.type_` parameter, e.g.::
+
+ >>> select(func.lower("some VALUE", type_=String))
+
+ The small number of built-in classes of :class:`_functions.Function` come
+ with a built-in datatype that's appropriate to the class of function and
+ its arguments. For functions that aren't known, the type defaults to the
+ "null type".
+
+ """
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_sql.text.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(self, name, *clauses, **kw):
+ """Construct a :class:`.Function`.
+
+ The :data:`.func` construct is normally used to construct
+ new :class:`.Function` instances.
+
+ """
+ self.packagenames = kw.pop("packagenames", None) or ()
+ self.name = name
+
+ self._bind = self._get_bind(kw)
+ self.type = sqltypes.to_instance(kw.get("type_", None))
+
+ FunctionElement.__init__(self, *clauses, **kw)
+
+ def _get_bind(self, kw):
+ if "bind" in kw:
+ util.warn_deprecated_20(
+ "The Function.bind argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ )
+ return kw["bind"]
+
+ def _bind_param(self, operator, obj, type_=None, **kw):
+ return BindParameter(
+ self.name,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ **kw
+ )
+
+
+class _GenericMeta(TraversibleType):
+ def __init__(cls, clsname, bases, clsdict):
+ if annotation.Annotated not in cls.__mro__:
+ cls.name = name = clsdict.get("name", clsname)
+ cls.identifier = identifier = clsdict.get("identifier", name)
+ package = clsdict.pop("package", "_default")
+ # legacy
+ if "__return_type__" in clsdict:
+ cls.type = clsdict["__return_type__"]
+
+ # Check _register attribute status
+ cls._register = getattr(cls, "_register", True)
+
+ # Register the function if required
+ if cls._register:
+ register_function(identifier, cls, package)
+ else:
+ # Set _register to True to register child classes by default
+ cls._register = True
+
+ super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
+
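+ # Illustrative sketch (not part of the upstream module): the metaclass above
+ # registers every GenericFunction subclass under its identifier unless the
+ # class itself sets ``_register = False``, e.g.:
+ #
+ #     class utility_base(GenericFunction):   # hypothetical helper base
+ #         _register = False                  # skipped by register_function()
+ #         inherit_cache = True
+ #
+ # Because _register is reset to True afterwards, subclasses of such a base
+ # are registered again by default.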
+
+class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
+ """Define a 'generic' function.
+
+ A generic function is a pre-established :class:`.Function`
+ class that is instantiated automatically when called
+ by name from the :data:`.func` attribute. Note that
+ calling any name from :data:`.func` has the effect that
+ a new :class:`.Function` instance is created automatically,
+ given that name. The primary use case for defining
+ a :class:`.GenericFunction` class is so that a function
+ of a particular name may be given a fixed return type.
+ It can also include custom argument parsing schemes as well
+ as additional methods.
+
+ Subclasses of :class:`.GenericFunction` are automatically
+ registered under the name of the class. For
+ example, a user-defined function ``as_utc()`` would
+ be available immediately::
+
+ from sqlalchemy.sql.functions import GenericFunction
+ from sqlalchemy.types import DateTime
+
+ class as_utc(GenericFunction):
+ type = DateTime
+ inherit_cache = True
+
+ print(select(func.as_utc()))
+
+ User-defined generic functions can be organized into
+ packages by specifying the "package" attribute when defining
+ :class:`.GenericFunction`. Third party libraries
+ containing many functions may want to use this in order
+ to avoid name conflicts with other systems. For example,
+ if our ``as_utc()`` function were part of a package
+ "time"::
+
+ class as_utc(GenericFunction):
+ type = DateTime
+ package = "time"
+ inherit_cache = True
+
+ The above function would be available from :data:`.func`
+ using the package name ``time``::
+
+ print(select(func.time.as_utc()))
+
+ A final option is to allow the function to be accessed
+ from one name in :data:`.func` but to render as a different name.
+ The ``identifier`` attribute will override the name used to
+ access the function as loaded from :data:`.func`, but will retain
+ the usage of ``name`` as the rendered name::
+
+ class GeoBuffer(GenericFunction):
+ type = Geometry
+ package = "geo"
+ name = "ST_Buffer"
+ identifier = "buffer"
+ inherit_cache = True
+
+ The above function will render as follows::
+
+ >>> print(func.geo.buffer())
+ ST_Buffer()
+
+ The name will be rendered as is, without quoting, unless the name
+ contains special characters that require quoting. To force quoting
+ on or off for the name, use the :class:`.sqlalchemy.sql.quoted_name`
+ construct::
+
+ from sqlalchemy.sql import quoted_name
+
+ class GeoBuffer(GenericFunction):
+ type = Geometry
+ package = "geo"
+ name = quoted_name("ST_Buffer", True)
+ identifier = "buffer"
+ inherit_cache = True
+
+ The above function will render as::
+
+ >>> print(func.geo.buffer())
+ "ST_Buffer"()
+
+ .. versionadded:: 1.3.13 The :class:`.quoted_name` construct is now
+ recognized for quoting when used with the "name" attribute of the
+ object, so that quoting can be forced on or off for the function
+ name.
+
+
+ """
+
+ coerce_arguments = True
+ _register = False
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ parsed_args = kwargs.pop("_parsed_args", None)
+ if parsed_args is None:
+ parsed_args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=self.name,
+ apply_propagate_attrs=self,
+ )
+ for c in args
+ ]
+ self._has_args = self._has_args or bool(parsed_args)
+ self.packagenames = ()
+ self._bind = self._get_bind(kwargs)
+ self.clause_expr = ClauseList(
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ ).self_group()
+ self.type = sqltypes.to_instance(
+ kwargs.pop("type_", None) or getattr(self, "type", None)
+ )
+
+
+register_function("cast", Cast)
+register_function("extract", Extract)
+
+
+class next_value(GenericFunction):
+ """Represent the 'next value', given a :class:`.Sequence`
+ as its single argument.
+
+ Compiles into the appropriate function on each backend,
+ or will raise NotImplementedError if used on a backend
+ that does not provide support for sequences.
+
+ """
+
+ type = sqltypes.Integer()
+ name = "next_value"
+
+ _traverse_internals = [
+ ("sequence", InternalTraversal.dp_named_ddl_element)
+ ]
+
+ def __init__(self, seq, **kw):
+ assert isinstance(
+ seq, schema.Sequence
+ ), "next_value() accepts a Sequence object as input."
+ self._bind = self._get_bind(kw)
+ self.sequence = seq
+ self.type = sqltypes.to_instance(
+ seq.data_type or getattr(self, "type", None)
+ )
+
+ def compare(self, other, **kw):
+ return (
+ isinstance(other, next_value)
+ and self.sequence.name == other.sequence.name
+ )
+
+ @property
+ def _from_objects(self):
+ return []
+
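+ # Illustrative sketch (not part of the upstream module): next_value is
+ # normally obtained from a Sequence rather than constructed directly, e.g.:
+ #
+ #     from sqlalchemy import Sequence, select
+ #     seq = Sequence("order_id_seq")        # hypothetical sequence name
+ #     stmt = select(seq.next_value())
+ #
+ # which compiles to the backend-specific form, e.g.
+ # nextval('order_id_seq') on PostgreSQL.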
+
+class AnsiFunction(GenericFunction):
+ """Define a function in "ansi" format, which doesn't render parenthesis."""
+
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ GenericFunction.__init__(self, *args, **kwargs)
+
+
+class ReturnTypeFromArgs(GenericFunction):
+ """Define a function whose return type is the same as its arguments."""
+
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=self.name,
+ apply_propagate_attrs=self,
+ )
+ for c in args
+ ]
+ kwargs.setdefault("type_", _type_from_args(args))
+ kwargs["_parsed_args"] = args
+ super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
+
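+ # Illustrative sketch (not part of the upstream module): for the subclasses
+ # below the return type is derived from the arguments via _type_from_args(),
+ # so assuming a hypothetical column ``t.c.price`` of type Numeric:
+ #
+ #     func.coalesce(t.c.price, 0).type   # -> Numeric()
+ #     func.max(t.c.price).type           # -> Numeric()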
+
+class coalesce(ReturnTypeFromArgs):
+ _has_args = True
+ inherit_cache = True
+
+
+class max(ReturnTypeFromArgs): # noqa: A001
+ """The SQL MAX() aggregate function."""
+
+ inherit_cache = True
+
+
+class min(ReturnTypeFromArgs): # noqa: A001
+ """The SQL MIN() aggregate function."""
+
+ inherit_cache = True
+
+
+class sum(ReturnTypeFromArgs): # noqa: A001
+ """The SQL SUM() aggregate function."""
+
+ inherit_cache = True
+
+
+class now(GenericFunction):
+ """The SQL now() datetime function.
+
+ SQLAlchemy dialects will usually render this particular function
+ in a backend-specific way, such as rendering it as ``CURRENT_TIMESTAMP``.
+
+ """
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class concat(GenericFunction):
+ """The SQL CONCAT() function, which concatenates strings.
+
+ E.g.::
+
+ >>> print(select(func.concat('a', 'b')))
+ SELECT concat(:concat_2, :concat_3) AS concat_1
+
+ String concatenation in SQLAlchemy is more commonly available using the
+ Python ``+`` operator with string datatypes, which will render a
+ backend-specific concatenation operator, such as ::
+
+ >>> print(select(literal("a") + "b"))
+ SELECT :param_1 || :param_2 AS anon_1
+
+
+ """
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class char_length(GenericFunction):
+ """The CHAR_LENGTH() SQL function."""
+
+ type = sqltypes.Integer
+ inherit_cache = True
+
+ def __init__(self, arg, **kwargs):
+ GenericFunction.__init__(self, arg, **kwargs)
+
+
+class random(GenericFunction):
+ """The RANDOM() SQL function."""
+
+ _has_args = True
+ inherit_cache = True
+
+
+class count(GenericFunction):
+ r"""The ANSI COUNT aggregate function. With no arguments,
+ emits COUNT \*.
+
+ E.g.::
+
+ from sqlalchemy import func
+ from sqlalchemy import select
+ from sqlalchemy import table, column
+
+ my_table = table('some_table', column('id'))
+
+ stmt = select(func.count()).select_from(my_table)
+
+ Executing ``stmt`` would emit::
+
+ SELECT count(*) AS count_1
+ FROM some_table
+
+
+ """
+ type = sqltypes.Integer
+ inherit_cache = True
+
+ def __init__(self, expression=None, **kwargs):
+ if expression is None:
+ expression = literal_column("*")
+ super(count, self).__init__(expression, **kwargs)
+
+
+class current_date(AnsiFunction):
+ """The CURRENT_DATE() SQL function."""
+
+ type = sqltypes.Date
+ inherit_cache = True
+
+
+class current_time(AnsiFunction):
+ """The CURRENT_TIME() SQL function."""
+
+ type = sqltypes.Time
+ inherit_cache = True
+
+
+class current_timestamp(AnsiFunction):
+ """The CURRENT_TIMESTAMP() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class current_user(AnsiFunction):
+ """The CURRENT_USER() SQL function."""
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class localtime(AnsiFunction):
+ """The localtime() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class localtimestamp(AnsiFunction):
+ """The localtimestamp() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class session_user(AnsiFunction):
+ """The SESSION_USER() SQL function."""
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class sysdate(AnsiFunction):
+ """The SYSDATE() SQL function."""
+
+ type = sqltypes.DateTime
+ inherit_cache = True
+
+
+class user(AnsiFunction):
+ """The USER() SQL function."""
+
+ type = sqltypes.String
+ inherit_cache = True
+
+
+class array_agg(GenericFunction):
+ """Support for the ARRAY_AGG function.
+
+ The ``func.array_agg(expr)`` construct returns an expression of
+ type :class:`_types.ARRAY`.
+
+ e.g.::
+
+ stmt = select(func.array_agg(table.c.values)[2:5])
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_postgresql.array_agg` - PostgreSQL-specific version that
+ returns :class:`_postgresql.ARRAY`, which has PG-specific operators
+ added.
+
+ """
+
+ type = sqltypes.ARRAY
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ args = [
+ coercions.expect(
+ roles.ExpressionElementRole, c, apply_propagate_attrs=self
+ )
+ for c in args
+ ]
+
+ default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
+ if "type_" not in kwargs:
+
+ type_from_args = _type_from_args(args)
+ if isinstance(type_from_args, sqltypes.ARRAY):
+ kwargs["type_"] = type_from_args
+ else:
+ kwargs["type_"] = default_array_type(type_from_args)
+ kwargs["_parsed_args"] = args
+ super(array_agg, self).__init__(*args, **kwargs)
+
+
+class OrderedSetAgg(GenericFunction):
+ """Define a function where the return type is based on the sort
+ expression type as defined by the expression passed to the
+ :meth:`.FunctionElement.within_group` method."""
+
+ array_for_multi_clause = False
+ inherit_cache = True
+
+ def within_group_type(self, within_group):
+ func_clauses = self.clause_expr.element
+ order_by = sqlutil.unwrap_order_by(within_group.order_by)
+ if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
+ return sqltypes.ARRAY(order_by[0].type)
+ else:
+ return order_by[0].type
+
+
+class mode(OrderedSetAgg):
+ """Implement the ``mode`` ordered-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is the same as the sort expression.
+
+ .. versionadded:: 1.1
+
+ """
+
+ inherit_cache = True
+
+
+class percentile_cont(OrderedSetAgg):
+ """Implement the ``percentile_cont`` ordered-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is the same as the sort expression,
+ or if the arguments are an array, an :class:`_types.ARRAY` of the sort
+ expression's type.
+
+ .. versionadded:: 1.1
+
+ """
+
+ array_for_multi_clause = True
+ inherit_cache = True
+
+
+class percentile_disc(OrderedSetAgg):
+ """Implement the ``percentile_disc`` ordered-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is the same as the sort expression,
+ or if the arguments are an array, an :class:`_types.ARRAY` of the sort
+ expression's type.
+
+ .. versionadded:: 1.1
+
+ """
+
+ array_for_multi_clause = True
+ inherit_cache = True
+
+
+class rank(GenericFunction):
+ """Implement the ``rank`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Integer`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Integer()
+ inherit_cache = True
+
+
+class dense_rank(GenericFunction):
+ """Implement the ``dense_rank`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Integer`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Integer()
+ inherit_cache = True
+
+
+class percent_rank(GenericFunction):
+ """Implement the ``percent_rank`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Numeric`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Numeric()
+ inherit_cache = True
+
+
+class cume_dist(GenericFunction):
+ """Implement the ``cume_dist`` hypothetical-set aggregate function.
+
+ This function must be used with the :meth:`.FunctionElement.within_group`
+ modifier to supply a sort expression to operate upon.
+
+ The return type of this function is :class:`.Numeric`.
+
+ .. versionadded:: 1.1
+
+ """
+
+ type = sqltypes.Numeric()
+ inherit_cache = True
+
+
+class cube(GenericFunction):
+ r"""Implement the ``CUBE`` grouping operation.
+
+ This function is used as part of the GROUP BY of a statement,
+ e.g. :meth:`_expression.Select.group_by`::
+
+ stmt = select(
+ func.sum(table.c.value), table.c.col_1, table.c.col_2
+ ).group_by(func.cube(table.c.col_1, table.c.col_2))
+
+ .. versionadded:: 1.2
+
+ """
+ _has_args = True
+ inherit_cache = True
+
+
+class rollup(GenericFunction):
+ r"""Implement the ``ROLLUP`` grouping operation.
+
+ This function is used as part of the GROUP BY of a statement,
+ e.g. :meth:`_expression.Select.group_by`::
+
+ stmt = select(
+ func.sum(table.c.value), table.c.col_1, table.c.col_2
+ ).group_by(func.rollup(table.c.col_1, table.c.col_2))
+
+ .. versionadded:: 1.2
+
+ """
+ _has_args = True
+ inherit_cache = True
+
+
+class grouping_sets(GenericFunction):
+ r"""Implement the ``GROUPING SETS`` grouping operation.
+
+ This function is used as part of the GROUP BY of a statement,
+ e.g. :meth:`_expression.Select.group_by`::
+
+ stmt = select(
+ func.sum(table.c.value), table.c.col_1, table.c.col_2
+ ).group_by(func.grouping_sets(table.c.col_1, table.c.col_2))
+
+ In order to group by multiple sets, use the :func:`.tuple_` construct::
+
+ from sqlalchemy import tuple_
+
+ stmt = select(
+ func.sum(table.c.value),
+ table.c.col_1, table.c.col_2,
+ table.c.col_3
+ ).group_by(
+ func.grouping_sets(
+ tuple_(table.c.col_1, table.c.col_2),
+ tuple_(table.c.value, table.c.col_3),
+ )
+ )
+
+
+ .. versionadded:: 1.2
+
+ """
+ _has_args = True
+ inherit_cache = True
diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py
new file mode 100644
index 0000000..584efe4
--- /dev/null
+++ b/lib/sqlalchemy/sql/lambdas.py
@@ -0,0 +1,1314 @@
+# sql/lambdas.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import inspect
+import itertools
+import operator
+import sys
+import threading
+import types
+import weakref
+
+from . import coercions
+from . import elements
+from . import roles
+from . import schema
+from . import traversals
+from . import type_api
+from . import visitors
+from .base import _clone
+from .base import Options
+from .operators import ColumnOperators
+from .. import exc
+from .. import inspection
+from .. import util
+from ..util import collections_abc
+from ..util import compat
+
+_closure_per_cache_key = util.LRUCache(1000)
+
+
+class LambdaOptions(Options):
+ enable_tracking = True
+ track_closure_variables = True
+ track_on = None
+ global_track_bound_values = True
+ track_bound_values = True
+ lambda_cache = None
+
+
+def lambda_stmt(
+ lmb,
+ enable_tracking=True,
+ track_closure_variables=True,
+ track_on=None,
+ global_track_bound_values=True,
+ track_bound_values=True,
+ lambda_cache=None,
+):
+ """Produce a SQL statement that is cached as a lambda.
+
+ The Python code object within the lambda is scanned for both Python
+ literals that will become bound parameters as well as closure variables
+ that refer to Core or ORM constructs that may vary. The lambda itself
+ will be invoked only once per particular set of constructs detected.
+
+ E.g.::
+
+ from sqlalchemy import lambda_stmt
+
+ stmt = lambda_stmt(lambda: table.select())
+ stmt += lambda s: s.where(table.c.id == 5)
+
+ result = connection.execute(stmt)
+
+ The object returned is an instance of :class:`_sql.StatementLambdaElement`.
+
+ .. versionadded:: 1.4
+
+ :param lmb: a Python function, typically a lambda, which takes no arguments
+ and returns a SQL expression construct
+ :param enable_tracking: when False, all scanning of the given lambda for
+ changes in closure variables or bound parameters is disabled. Use for
+ a lambda that produces the identical results in all cases with no
+ parameterization.
+ :param track_closure_variables: when False, changes in closure variables
+ within the lambda will not be scanned. Use for a lambda where the
+ state of its closure variables will never change the SQL structure
+ returned by the lambda.
+ :param track_bound_values: when False, bound parameter tracking will
+ be disabled for the given lambda. Use for a lambda that either does
+ not produce any bound values, or where the initial bound values never
+ change.
+ :param global_track_bound_values: when False, bound parameter tracking
+ will be disabled for the entire statement including additional links
+ added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
+ :param lambda_cache: a dictionary or other mapping-like object where
+ information about the lambda's Python code as well as the tracked closure
+ variables in the lambda itself will be stored. Defaults
+ to a global LRU cache. This cache is independent of the "compiled_cache"
+ used by the :class:`_engine.Connection` object.
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+
+ """
+
+ return StatementLambdaElement(
+ lmb,
+ roles.StatementRole,
+ LambdaOptions(
+ enable_tracking=enable_tracking,
+ track_on=track_on,
+ track_closure_variables=track_closure_variables,
+ global_track_bound_values=global_track_bound_values,
+ track_bound_values=track_bound_values,
+ lambda_cache=lambda_cache,
+ ),
+ )
+
+
+class LambdaElement(elements.ClauseElement):
+ """A SQL construct where the state is stored as an un-invoked lambda.
+
+ The :class:`_sql.LambdaElement` is produced transparently whenever
+ passing lambda expressions into SQL constructs, such as::
+
+ stmt = select(table).where(lambda: table.c.col == parameter)
+
+ The :class:`_sql.LambdaElement` is the base of the
+ :class:`_sql.StatementLambdaElement` which represents a full statement
+ within a lambda.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ _transforms = ()
+
+ parent_lambda = None
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
+
+ def __init__(
+ self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
+ ):
+ self.fn = fn
+ self.role = role
+ self.tracker_key = (fn.__code__,)
+ self.opts = opts
+
+ if apply_propagate_attrs is None and (role is roles.StatementRole):
+ apply_propagate_attrs = self
+
+ rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)
+
+ if apply_propagate_attrs is not None:
+ propagate_attrs = rec.propagate_attrs
+ if propagate_attrs:
+ apply_propagate_attrs._propagate_attrs = propagate_attrs
+
+ def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
+ lambda_cache = opts.lambda_cache
+ if lambda_cache is None:
+ lambda_cache = _closure_per_cache_key
+
+ tracker_key = self.tracker_key
+
+ fn = self.fn
+ closure = fn.__closure__
+ tracker = AnalyzedCode.get(
+ fn,
+ self,
+ opts,
+ )
+
+ self._resolved_bindparams = bindparams = []
+
+ if self.parent_lambda is not None:
+ parent_closure_cache_key = self.parent_lambda.closure_cache_key
+ else:
+ parent_closure_cache_key = ()
+
+ if parent_closure_cache_key is not traversals.NO_CACHE:
+ anon_map = traversals.anon_map()
+ cache_key = tuple(
+ [
+ getter(closure, opts, anon_map, bindparams)
+ for getter in tracker.closure_trackers
+ ]
+ )
+
+ if traversals.NO_CACHE not in anon_map:
+ cache_key = parent_closure_cache_key + cache_key
+
+ self.closure_cache_key = cache_key
+
+ try:
+ rec = lambda_cache[tracker_key + cache_key]
+ except KeyError:
+ rec = None
+ else:
+ cache_key = traversals.NO_CACHE
+ rec = None
+
+ else:
+ cache_key = traversals.NO_CACHE
+ rec = None
+
+ self.closure_cache_key = cache_key
+
+ if rec is None:
+ if cache_key is not traversals.NO_CACHE:
+
+ with AnalyzedCode._generation_mutex:
+ key = tracker_key + cache_key
+ if key not in lambda_cache:
+ rec = AnalyzedFunction(
+ tracker, self, apply_propagate_attrs, fn
+ )
+ rec.closure_bindparams = bindparams
+ lambda_cache[key] = rec
+ else:
+ rec = lambda_cache[key]
+ else:
+ rec = NonAnalyzedFunction(self._invoke_user_fn(fn))
+
+ else:
+ bindparams[:] = [
+ orig_bind._with_value(new_bind.value, maintain_key=True)
+ for orig_bind, new_bind in zip(
+ rec.closure_bindparams, bindparams
+ )
+ ]
+
+ self._rec = rec
+
+ if cache_key is not traversals.NO_CACHE:
+ if self.parent_lambda is not None:
+ bindparams[:0] = self.parent_lambda._resolved_bindparams
+
+ lambda_element = self
+ while lambda_element is not None:
+ rec = lambda_element._rec
+ if rec.bindparam_trackers:
+ tracker_instrumented_fn = rec.tracker_instrumented_fn
+ for tracker in rec.bindparam_trackers:
+ tracker(
+ lambda_element.fn,
+ tracker_instrumented_fn,
+ bindparams,
+ )
+ lambda_element = lambda_element.parent_lambda
+
+ return rec
+
+ def __getattr__(self, key):
+ return getattr(self._rec.expected_expr, key)
+
+ @property
+ def _is_sequence(self):
+ return self._rec.is_sequence
+
+ @property
+ def _select_iterable(self):
+ if self._is_sequence:
+ return itertools.chain.from_iterable(
+ [element._select_iterable for element in self._resolved]
+ )
+
+ else:
+ return self._resolved._select_iterable
+
+ @property
+ def _from_objects(self):
+ if self._is_sequence:
+ return itertools.chain.from_iterable(
+ [element._from_objects for element in self._resolved]
+ )
+
+ else:
+ return self._resolved._from_objects
+
+ def _param_dict(self):
+ return {b.key: b.value for b in self._resolved_bindparams}
+
+ def _setup_binds_for_tracked_expr(self, expr):
+ bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
+
+ def replace(thing):
+ if isinstance(thing, elements.BindParameter):
+
+ if thing.key in bindparam_lookup:
+ bind = bindparam_lookup[thing.key]
+ if thing.expanding:
+ bind.expanding = True
+ bind.expand_op = thing.expand_op
+ bind.type = thing.type
+ return bind
+
+ if self._rec.is_sequence:
+ expr = [
+ visitors.replacement_traverse(sub_expr, {}, replace)
+ for sub_expr in expr
+ ]
+ elif getattr(expr, "is_clause_element", False):
+ expr = visitors.replacement_traverse(expr, {}, replace)
+
+ return expr
+
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ # TODO: this needs A LOT of tests
+ self._resolved = clone(
+ self._resolved,
+ deferred_copy_internals=deferred_copy_internals,
+ **kw
+ )
+
+ @util.memoized_property
+ def _resolved(self):
+ expr = self._rec.expected_expr
+
+ if self._resolved_bindparams:
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ return expr
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self.closure_cache_key is traversals.NO_CACHE:
+ anon_map[traversals.NO_CACHE] = True
+ return None
+
+ cache_key = (
+ self.fn.__code__,
+ self.__class__,
+ ) + self.closure_cache_key
+
+ parent = self.parent_lambda
+ while parent is not None:
+ cache_key = (
+ (parent.fn.__code__,) + parent.closure_cache_key + cache_key
+ )
+
+ parent = parent.parent_lambda
+
+ if self._resolved_bindparams:
+ bindparams.extend(self._resolved_bindparams)
+ return cache_key
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn()
+
+
+class DeferredLambdaElement(LambdaElement):
+ """A LambdaElement where the lambda accepts arguments and is
+ invoked within the compile phase with special context.
+
+ This lambda doesn't normally produce its real SQL expression outside of the
+ compile phase. It is passed a fixed set of initial arguments
+ so that it can generate a sample expression.
+
+ """
+
+ def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
+ self.lambda_args = lambda_args
+ super(DeferredLambdaElement, self).__init__(fn, role, opts)
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(*self.lambda_args)
+
+ def _resolve_with_args(self, *lambda_args):
+ tracker_fn = self._rec.tracker_instrumented_fn
+ expr = tracker_fn(*lambda_args)
+
+ expr = coercions.expect(self.role, expr)
+
+ expr = self._setup_binds_for_tracked_expr(expr)
+
+ # this validation is getting very close, but not quite, to achieving
+ # #5767. The problem is if the base lambda uses an unnamed column
+ # as is very common with mixins, the parameter name is different
+ # and it produces a false positive; that is, for the documented case
+ # that is exactly what people will be doing, it doesn't work, so
+ # I'm not really sure how to handle this right now.
+ # expected_binds = [
+ # b._orig_key
+ # for b in self._rec.expr._generate_cache_key()[1]
+ # if b.required
+ # ]
+ # got_binds = [
+ # b._orig_key for b in expr._generate_cache_key()[1] if b.required
+ # ]
+ # if expected_binds != got_binds:
+ # raise exc.InvalidRequestError(
+ # "Lambda callable at %s produced a different set of bound "
+ # "parameters than its original run: %s"
+ # % (self.fn.__code__, ", ".join(got_binds))
+ # )
+
+ # TODO: TEST TEST TEST, this is very out there
+ for deferred_copy_internals in self._transforms:
+ expr = deferred_copy_internals(expr)
+
+ return expr
+
+ def _copy_internals(
+ self, clone=_clone, deferred_copy_internals=None, **kw
+ ):
+ super(DeferredLambdaElement, self)._copy_internals(
+ clone=clone,
+ deferred_copy_internals=deferred_copy_internals, # **kw
+ opts=kw,
+ )
+
+ # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know
+ # our expression yet. so hold onto the replacement
+ if deferred_copy_internals:
+ self._transforms += (deferred_copy_internals,)
+
+
+class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
+ """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
+
+ The :class:`_sql.StatementLambdaElement` is constructed using the
+ :func:`_sql.lambda_stmt` function::
+
+
+ from sqlalchemy import lambda_stmt
+
+ stmt = lambda_stmt(lambda: select(table))
+
+ Once constructed, additional criteria can be built onto the statement
+ by adding subsequent lambdas, which accept the existing statement
+ object as a single parameter::
+
+ stmt += lambda s: s.where(table.c.col == parameter)
+
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`engine_lambda_caching`
+
+ """
+
+ def __add__(self, other):
+ return self.add_criteria(other)
+
+ def add_criteria(
+ self,
+ other,
+ enable_tracking=True,
+ track_on=None,
+ track_closure_variables=True,
+ track_bound_values=True,
+ ):
+ """Add new criteria to this :class:`_sql.StatementLambdaElement`.
+
+ E.g.::
+
+ >>> def my_stmt(parameter):
+ ... stmt = lambda_stmt(
+ ... lambda: select(table.c.x, table.c.y),
+ ... )
+ ... stmt = stmt.add_criteria(
+ ... lambda: table.c.x > parameter
+ ... )
+ ... return stmt
+
+ The :meth:`_sql.StatementLambdaElement.add_criteria` method is
+ equivalent to using the Python addition operator to add a new
+ lambda, except that additional arguments may be added including
+ ``track_closure_values`` and ``track_on``::
+
+ >>> def my_stmt(self, foo):
+ ... stmt = lambda_stmt(
+ ... lambda: select(func.max(foo.x, foo.y)),
+ ... track_closure_variables=False
+ ... )
+ ... stmt = stmt.add_criteria(
+ ... lambda: self.where_criteria,
+ ... track_on=[self]
+ ... )
+ ... return stmt
+
+ See :func:`_sql.lambda_stmt` for a description of the parameters
+ accepted.
+
+ """
+
+ opts = self.opts + dict(
+ enable_tracking=enable_tracking,
+ track_closure_variables=track_closure_variables,
+ global_track_bound_values=self.opts.global_track_bound_values,
+ track_on=track_on,
+ track_bound_values=track_bound_values,
+ )
+
+ return LinkedLambdaElement(other, parent_lambda=self, opts=opts)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._rec.expected_expr.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
+
+ @property
+ def _with_options(self):
+ return self._rec.expected_expr._with_options
+
+ @property
+ def _effective_plugin_target(self):
+ return self._rec.expected_expr._effective_plugin_target
+
+ @property
+ def _execution_options(self):
+ return self._rec.expected_expr._execution_options
+
+ def spoil(self):
+ """Return a new :class:`.StatementLambdaElement` that will run
+ all lambdas unconditionally each time.
+
+ """
+ return NullLambdaStatement(self.fn())
+
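+ # Illustrative sketch (not part of the upstream module): spoil() is a
+ # debugging aid; the statement keeps the same composition API, but each
+ # lambda is invoked on the spot instead of being analyzed and cached:
+ #
+ #     stmt = lambda_stmt(lambda: select(table))
+ #     stmt = stmt.spoil()
+ #     stmt += lambda s: s.where(table.c.id == 5)   # runs immediately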
+
+class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
+ """Provides the :class:`.StatementLambdaElement` API but does not
+ cache or analyze lambdas.
+
+ The lambdas are instead invoked immediately.
+
+ The intended use is to isolate issues that may arise when using
+ lambda statements.
+
+ """
+
+ __visit_name__ = "lambda_element"
+
+ _is_lambda_element = True
+
+ _traverse_internals = [
+ ("_resolved", visitors.InternalTraversal.dp_clauseelement)
+ ]
+
+ def __init__(self, statement):
+ self._resolved = statement
+ self._propagate_attrs = statement._propagate_attrs
+
+ def __getattr__(self, key):
+ return getattr(self._resolved, key)
+
+ def __add__(self, other):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def add_criteria(self, other, **kw):
+ statement = other(self._resolved)
+
+ return NullLambdaStatement(statement)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ if self._resolved.supports_execution:
+ return connection._execute_clauseelement(
+ self, multiparams, params, execution_options
+ )
+ else:
+ raise exc.ObjectNotExecutableError(self)
+
+
+class LinkedLambdaElement(StatementLambdaElement):
+ """Represent subsequent links of a :class:`.StatementLambdaElement`."""
+
+ role = None
+
+ def __init__(self, fn, parent_lambda, opts):
+ self.opts = opts
+ self.fn = fn
+ self.parent_lambda = parent_lambda
+
+ self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
+ self._retrieve_tracker_rec(fn, self, opts)
+ self._propagate_attrs = parent_lambda._propagate_attrs
+
+ def _invoke_user_fn(self, fn, *arg):
+ return fn(self.parent_lambda._resolved)
+
+
+class AnalyzedCode(object):
+ __slots__ = (
+ "track_closure_variables",
+ "track_bound_values",
+ "bindparam_trackers",
+ "closure_trackers",
+ "build_py_wrappers",
+ )
+ _fns = weakref.WeakKeyDictionary()
+
+ _generation_mutex = threading.RLock()
+
+ @classmethod
+ def get(cls, fn, lambda_element, lambda_kw, **kw):
+ try:
+ # TODO: validate kw haven't changed?
+ return cls._fns[fn.__code__]
+ except KeyError:
+ pass
+
+ with cls._generation_mutex:
+ # check for other thread already created object
+ if fn.__code__ in cls._fns:
+ return cls._fns[fn.__code__]
+
+ cls._fns[fn.__code__] = analyzed = AnalyzedCode(
+ fn, lambda_element, lambda_kw, **kw
+ )
+ return analyzed
+
+ def __init__(self, fn, lambda_element, opts):
+ if inspect.ismethod(fn):
+ raise exc.ArgumentError(
+ "Method %s may not be passed as a SQL expression" % fn
+ )
+ closure = fn.__closure__
+
+ self.track_bound_values = (
+ opts.track_bound_values and opts.global_track_bound_values
+ )
+ enable_tracking = opts.enable_tracking
+ track_on = opts.track_on
+ track_closure_variables = opts.track_closure_variables
+
+ self.track_closure_variables = track_closure_variables and not track_on
+
+ # a list of callables generated from _bound_parameter_getter_*
+ # functions. Each of these uses a PyWrapper object to retrieve
+ # a parameter value
+ self.bindparam_trackers = []
+
+ # a list of callables generated from _cache_key_getter_* functions
+ # these callables work to generate a cache key for the lambda
+ # based on what's inside its closure variables.
+ self.closure_trackers = []
+
+ self.build_py_wrappers = []
+
+ if enable_tracking:
+ if track_on:
+ self._init_track_on(track_on)
+
+ self._init_globals(fn)
+
+ if closure:
+ self._init_closure(fn)
+
+ self._setup_additional_closure_trackers(fn, lambda_element, opts)
+
+ def _init_track_on(self, track_on):
+ self.closure_trackers.extend(
+ self._cache_key_getter_track_on(idx, elem)
+ for idx, elem in enumerate(track_on)
+ )
+
+ def _init_globals(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ bindparam_trackers = self.bindparam_trackers
+ track_bound_values = self.track_bound_values
+
+ for name in fn.__code__.co_names:
+ if name not in fn.__globals__:
+ continue
+
+ _bound_value = self._roll_down_to_literal(fn.__globals__[name])
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((name, None))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_globals(name)
+ )
+
+ def _init_closure(self, fn):
+ build_py_wrappers = self.build_py_wrappers
+ closure = fn.__closure__
+
+ track_bound_values = self.track_bound_values
+ track_closure_variables = self.track_closure_variables
+ bindparam_trackers = self.bindparam_trackers
+ closure_trackers = self.closure_trackers
+
+ for closure_index, (fv, cell) in enumerate(
+ zip(fn.__code__.co_freevars, closure)
+ ):
+ _bound_value = self._roll_down_to_literal(cell.cell_contents)
+
+ if coercions._deep_is_literal(_bound_value):
+ build_py_wrappers.append((fv, closure_index))
+ if track_bound_values:
+ bindparam_trackers.append(
+ self._bound_parameter_getter_func_closure(
+ fv, closure_index
+ )
+ )
+ else:
+ # for normal cell contents, add them to a list that
+ # we can compare later when we get new lambdas. if
+ # any identities have changed, then we will
+ # recalculate the whole lambda and run it again.
+
+ if track_closure_variables:
+ closure_trackers.append(
+ self._cache_key_getter_closure_variable(
+ fn, fv, closure_index, cell.cell_contents
+ )
+ )
+
+ def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
+ # an additional step is to actually run the function, then
+ # go through the PyWrapper objects that were set up to catch a bound
+        # parameter. for those that *didn't* make a param, they are some
+        # other object in the closure that we have to track for our cache
+        # key, so create trackers to catch those.
+
+ analyzed_function = AnalyzedFunction(
+ self,
+ lambda_element,
+ None,
+ fn,
+ )
+
+ closure_trackers = self.closure_trackers
+
+ for pywrapper in analyzed_function.closure_pywrappers:
+ if not pywrapper._sa__has_param:
+ closure_trackers.append(
+ self._cache_key_getter_tracked_literal(fn, pywrapper)
+ )
+
+ @classmethod
+ def _roll_down_to_literal(cls, element):
+ is_clause_element = hasattr(element, "__clause_element__")
+
+ if is_clause_element:
+ while not isinstance(
+ element, (elements.ClauseElement, schema.SchemaItem, type)
+ ):
+ try:
+ element = element.__clause_element__()
+ except AttributeError:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ try:
+ return insp.__clause_element__()
+ except AttributeError:
+ return insp
+
+ # TODO: should we coerce consts None/True/False here?
+ return element
+ else:
+ return element
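+
+    # Editorial note: "rolling down" unwraps ORM-level constructs -- an
+    # object exposing ``__clause_element__()`` (e.g. an instrumented column
+    # attribute) is reduced to the underlying SQL element, while plain
+    # values (ints, strings, None) fall through unchanged and are later
+    # classified by coercions._deep_is_literal() as potential bound
+    # literals.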
+
+ def _bound_parameter_getter_func_globals(self, name):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__globals__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__globals__[name]
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__globals__[name], result
+ )
+
+ return extract_parameter_value
+
+ def _bound_parameter_getter_func_closure(self, name, closure_index):
+ """Return a getter that will extend a list of bound parameters
+ with new entries from the ``__closure__`` collection of a particular
+ lambda.
+
+ """
+
+ def extract_parameter_value(
+ current_fn, tracker_instrumented_fn, result
+ ):
+ wrapper = tracker_instrumented_fn.__closure__[
+ closure_index
+ ].cell_contents
+ object.__getattribute__(wrapper, "_extract_bound_parameters")(
+ current_fn.__closure__[closure_index].cell_contents, result
+ )
+
+ return extract_parameter_value
+
+ def _cache_key_getter_track_on(self, idx, elem):
+ """Return a getter that will extend a cache key with new entries
+ from the "track_on" parameter passed to a :class:`.LambdaElement`.
+
+ """
+
+ if isinstance(elem, tuple):
+            # tuple must contain HasCacheKey elements
+ def get(closure, opts, anon_map, bindparams):
+ return tuple(
+ tup_elem._gen_cache_key(anon_map, bindparams)
+ for tup_elem in opts.track_on[idx]
+ )
+
+ elif isinstance(elem, traversals.HasCacheKey):
+
+ def get(closure, opts, anon_map, bindparams):
+ return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
+
+ else:
+
+ def get(closure, opts, anon_map, bindparams):
+ return opts.track_on[idx]
+
+ return get
+
+ def _cache_key_getter_closure_variable(
+ self,
+ fn,
+ variable_name,
+ idx,
+ cell_contents,
+ use_clause_element=False,
+ use_inspect=False,
+ ):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ """
+
+ if isinstance(cell_contents, traversals.HasCacheKey):
+
+ def get(closure, opts, anon_map, bindparams):
+
+ obj = closure[idx].cell_contents
+ if use_inspect:
+ obj = inspection.inspect(obj)
+ elif use_clause_element:
+ while hasattr(obj, "__clause_element__"):
+ if not getattr(obj, "is_clause_element", False):
+ obj = obj.__clause_element__()
+
+ return obj._gen_cache_key(anon_map, bindparams)
+
+ elif isinstance(cell_contents, types.FunctionType):
+
+ def get(closure, opts, anon_map, bindparams):
+ return closure[idx].cell_contents.__code__
+
+ elif isinstance(cell_contents, collections_abc.Sequence):
+
+ def get(closure, opts, anon_map, bindparams):
+ contents = closure[idx].cell_contents
+
+ try:
+ return tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in contents
+ )
+ except AttributeError as ae:
+ self._raise_for_uncacheable_closure_variable(
+ variable_name, fn, from_=ae
+ )
+
+ else:
+ # if the object is a mapped class or aliased class, or some
+ # other object in the ORM realm of things like that, imitate
+ # the logic used in coercions.expect() to roll it down to the
+ # SQL element
+ element = cell_contents
+ is_clause_element = False
+ while hasattr(element, "__clause_element__"):
+ is_clause_element = True
+ if not getattr(element, "is_clause_element", False):
+ element = element.__clause_element__()
+ else:
+ break
+
+ if not is_clause_element:
+ insp = inspection.inspect(element, raiseerr=False)
+ if insp is not None:
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, idx, insp, use_inspect=True
+ )
+ else:
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, idx, element, use_clause_element=True
+ )
+
+ self._raise_for_uncacheable_closure_variable(variable_name, fn)
+
+ return get
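+
+    # Editorial note: the getter chosen above depends on what the closure
+    # variable held when the lambda was first analyzed -- HasCacheKey
+    # objects contribute their ``_gen_cache_key()``, plain functions
+    # contribute their ``__code__``, sequences contribute a tuple of
+    # per-element keys, and anything else is rolled down to a SQL element
+    # (or inspected) or rejected as uncacheable.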
+
+ def _raise_for_uncacheable_closure_variable(
+ self, variable_name, fn, from_=None
+ ):
+ util.raise_(
+ exc.InvalidRequestError(
+ "Closure variable named '%s' inside of lambda callable %s "
+ "does not refer to a cacheable SQL element, and also does not "
+ "appear to be serving as a SQL literal bound value based on "
+ "the default "
+ "SQL expression returned by the function. This variable "
+ "needs to remain outside the scope of a SQL-generating lambda "
+ "so that a proper cache key may be generated from the "
+ "lambda's state. Evaluate this variable outside of the "
+ "lambda, set track_on=[<elements>] to explicitly select "
+ "closure elements to track, or set "
+ "track_closure_variables=False to exclude "
+ "closure variables from being part of the cache key."
+ % (variable_name, fn.__code__),
+ ),
+ from_=from_,
+ )
+
+ def _cache_key_getter_tracked_literal(self, fn, pytracker):
+ """Return a getter that will extend a cache key with new entries
+ from the ``__closure__`` collection of a particular lambda.
+
+ this getter differs from _cache_key_getter_closure_variable
+ in that these are detected after the function is run, and PyWrapper
+ objects have recorded that a particular literal value is in fact
+ not being interpreted as a bound parameter.
+
+ """
+
+ elem = pytracker._sa__to_evaluate
+ closure_index = pytracker._sa__closure_index
+ variable_name = pytracker._sa__name
+
+ return self._cache_key_getter_closure_variable(
+ fn, variable_name, closure_index, elem
+ )
+
+
+class NonAnalyzedFunction(object):
+ __slots__ = ("expr",)
+
+ closure_bindparams = None
+ bindparam_trackers = None
+
+ def __init__(self, expr):
+ self.expr = expr
+
+ @property
+ def expected_expr(self):
+ return self.expr
+
+
+class AnalyzedFunction(object):
+ __slots__ = (
+ "analyzed_code",
+ "fn",
+ "closure_pywrappers",
+ "tracker_instrumented_fn",
+ "expr",
+ "bindparam_trackers",
+ "expected_expr",
+ "is_sequence",
+ "propagate_attrs",
+ "closure_bindparams",
+ )
+
+ def __init__(
+ self,
+ analyzed_code,
+ lambda_element,
+ apply_propagate_attrs,
+ fn,
+ ):
+ self.analyzed_code = analyzed_code
+ self.fn = fn
+
+ self.bindparam_trackers = analyzed_code.bindparam_trackers
+
+ self._instrument_and_run_function(lambda_element)
+
+ self._coerce_expression(lambda_element, apply_propagate_attrs)
+
+ def _instrument_and_run_function(self, lambda_element):
+ analyzed_code = self.analyzed_code
+
+ fn = self.fn
+ self.closure_pywrappers = closure_pywrappers = []
+
+ build_py_wrappers = analyzed_code.build_py_wrappers
+
+ if not build_py_wrappers:
+ self.tracker_instrumented_fn = tracker_instrumented_fn = fn
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+ else:
+ track_closure_variables = analyzed_code.track_closure_variables
+ closure = fn.__closure__
+
+ # will form the __closure__ of the function when we rebuild it
+ if closure:
+ new_closure = {
+ fv: cell.cell_contents
+ for fv, cell in zip(fn.__code__.co_freevars, closure)
+ }
+ else:
+ new_closure = {}
+
+ # will form the __globals__ of the function when we rebuild it
+ new_globals = fn.__globals__.copy()
+
+ for name, closure_index in build_py_wrappers:
+ if closure_index is not None:
+ value = closure[closure_index].cell_contents
+ new_closure[name] = bind = PyWrapper(
+ fn,
+ name,
+ value,
+ closure_index=closure_index,
+ track_bound_values=(
+ self.analyzed_code.track_bound_values
+ ),
+ )
+ if track_closure_variables:
+ closure_pywrappers.append(bind)
+ else:
+ value = fn.__globals__[name]
+ new_globals[name] = bind = PyWrapper(fn, name, value)
+
+ # rewrite the original fn. things that look like they will
+ # become bound parameters are wrapped in a PyWrapper.
+ self.tracker_instrumented_fn = (
+ tracker_instrumented_fn
+ ) = self._rewrite_code_obj(
+ fn,
+ [new_closure[name] for name in fn.__code__.co_freevars],
+ new_globals,
+ )
+
+ # now invoke the function. This will give us a new SQL
+ # expression, but all the places that there would be a bound
+ # parameter, the PyWrapper in its place will give us a bind
+ # with a predictable name we can match up later.
+
+ # additionally, each PyWrapper will log that it did in fact
+ # create a parameter, otherwise, it's some kind of Python
+ # object in the closure and we want to track that, to make
+ # sure it doesn't change to something else, or if it does,
+ # that we create a different tracked function with that
+ # variable.
+ self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
+
+ def _coerce_expression(self, lambda_element, apply_propagate_attrs):
+ """Run the tracker-generated expression through coercion rules.
+
+ After the user-defined lambda has been invoked to produce a statement
+ for re-use, run it through coercion rules to both check that it's the
+ correct type of object and also to coerce it to its useful form.
+
+ """
+
+ parent_lambda = lambda_element.parent_lambda
+ expr = self.expr
+
+ if parent_lambda is None:
+ if isinstance(expr, collections_abc.Sequence):
+ self.expected_expr = [
+ coercions.expect(
+ lambda_element.role,
+ sub_expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ )
+ for sub_expr in expr
+ ]
+ self.is_sequence = True
+ else:
+ self.expected_expr = coercions.expect(
+ lambda_element.role,
+ expr,
+ apply_propagate_attrs=apply_propagate_attrs,
+ )
+ self.is_sequence = False
+ else:
+ self.expected_expr = expr
+ self.is_sequence = False
+
+ if apply_propagate_attrs is not None:
+ self.propagate_attrs = apply_propagate_attrs._propagate_attrs
+ else:
+ self.propagate_attrs = util.EMPTY_DICT
+
+ def _rewrite_code_obj(self, f, cell_values, globals_):
+ """Return a copy of f, with a new closure and new globals
+
+ yes it works in pypy :P
+
+ """
+
+ argrange = range(len(cell_values))
+
+ code = "def make_cells():\n"
+ if cell_values:
+ code += " (%s) = (%s)\n" % (
+ ", ".join("i%d" % i for i in argrange),
+ ", ".join("o%d" % i for i in argrange),
+ )
+ code += " def closure():\n"
+ code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
+ code += " return closure.__closure__"
+ vars_ = {"o%d" % i: cell_values[i] for i in argrange}
+ compat.exec_(code, vars_, vars_)
+ closure = vars_["make_cells"]()
+
+ func = type(f)(
+ f.__code__, globals_, f.__name__, f.__defaults__, closure
+ )
+ if sys.version_info >= (3,):
+ func.__annotations__ = f.__annotations__
+ func.__kwdefaults__ = f.__kwdefaults__
+ func.__doc__ = f.__doc__
+ func.__module__ = f.__module__
+
+ return func
+
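+    # Editorial note: for two cell values, the source generated above is
+    # exactly::
+    #
+    #     def make_cells():
+    #         (i0, i1) = (o0, o1)
+    #         def closure():
+    #             return i0, i1
+    #         return closure.__closure__
+    #
+    # i.e. the nested ``closure`` function captures i0/i1 as free variables,
+    # and its ``__closure__`` tuple of cell objects is handed to the rebuilt
+    # copy of ``f``.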
+
+class PyWrapper(ColumnOperators):
+ """A wrapper object that is injected into the ``__globals__`` and
+ ``__closure__`` of a Python function.
+
+ When the function is instrumented with :class:`.PyWrapper` objects, it is
+ then invoked just once in order to set up the wrappers. We look through
+ all the :class:`.PyWrapper` objects we made to find the ones that generated
+ a :class:`.BindParameter` object, e.g. the expression system interpreted
+ something as a literal. Those positions in the globals/closure are then
+ ones that we will look at, each time a new lambda comes in that refers to
+ the same ``__code__`` object. In this way, we keep a single version of
+ the SQL expression that this lambda produced, without calling upon the
+ Python function that created it more than once, unless its other closure
+ variables have changed. The expression is then transformed to have the
+ new bound values embedded into it.
+
+ """
+
+ def __init__(
+ self,
+ fn,
+ name,
+ to_evaluate,
+ closure_index=None,
+ getter=None,
+ track_bound_values=True,
+ ):
+ self.fn = fn
+ self._name = name
+ self._to_evaluate = to_evaluate
+ self._param = None
+ self._has_param = False
+ self._bind_paths = {}
+ self._getter = getter
+ self._closure_index = closure_index
+ self.track_bound_values = track_bound_values
+
+ def __call__(self, *arg, **kw):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ value = elem(*arg, **kw)
+ if (
+ self._sa_track_bound_values
+ and coercions._deep_is_literal(value)
+ and not isinstance(
+ # TODO: coverage where an ORM option or similar is here
+ value,
+ traversals.HasCacheKey,
+ )
+ ):
+ name = object.__getattribute__(self, "_name")
+ raise exc.InvalidRequestError(
+ "Can't invoke Python callable %s() inside of lambda "
+ "expression argument at %s; lambda SQL constructs should "
+ "not invoke functions from closure variables to produce "
+ "literal values since the "
+ "lambda SQL system normally extracts bound values without "
+ "actually "
+ "invoking the lambda or any functions within it. Call the "
+ "function outside of the "
+ "lambda and assign to a local variable that is used in the "
+ "lambda as a closure variable, or set "
+ "track_bound_values=False if the return value of this "
+ "function is used in some other way other than a SQL bound "
+ "value." % (name, self._sa_fn.__code__)
+ )
+ else:
+ return value
+
+ def operate(self, op, *other, **kwargs):
+ elem = object.__getattribute__(self, "__clause_element__")()
+ return op(elem, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ elem = object.__getattribute__(self, "__clause_element__")()
+ return op(other, elem, **kwargs)
+
+ def _extract_bound_parameters(self, starting_point, result_list):
+ param = object.__getattribute__(self, "_param")
+ if param is not None:
+ param = param._with_value(starting_point, maintain_key=True)
+ result_list.append(param)
+ for pywrapper in object.__getattribute__(self, "_bind_paths").values():
+ getter = object.__getattribute__(pywrapper, "_getter")
+ element = getter(starting_point)
+ pywrapper._sa__extract_bound_parameters(element, result_list)
+
+ def __clause_element__(self):
+ param = object.__getattribute__(self, "_param")
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ if param is None:
+ name = object.__getattribute__(self, "_name")
+ self._param = param = elements.BindParameter(
+ name, required=False, unique=True
+ )
+ self._has_param = True
+ param.type = type_api._resolve_value_to_type(to_evaluate)
+ return param._with_value(to_evaluate, maintain_key=True)
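+
+    # Editorial note: for a closure variable such as ``user_id = 5`` used
+    # inside the lambda as ``tbl.c.id == user_id``, the code above lazily
+    # creates a BindParameter named after the variable and carrying the
+    # value 5; later lambda invocations feed a new value in through
+    # ``_extract_bound_parameters()``, which reuses the same parameter key
+    # via ``_with_value()`` instead of re-running the lambda.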
+
+ def __bool__(self):
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ return bool(to_evaluate)
+
+ def __nonzero__(self):
+ to_evaluate = object.__getattribute__(self, "_to_evaluate")
+ return bool(to_evaluate)
+
+ def __getattribute__(self, key):
+ if key.startswith("_sa_"):
+ return object.__getattribute__(self, key[4:])
+ elif key in (
+ "__clause_element__",
+ "operate",
+ "reverse_operate",
+ "__class__",
+ "__dict__",
+ ):
+ return object.__getattribute__(self, key)
+
+ if key.startswith("__"):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return getattr(elem, key)
+ else:
+ return self._sa__add_getter(key, operator.attrgetter)
+
+ def __iter__(self):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ return iter(elem)
+
+ def __getitem__(self, key):
+ elem = object.__getattribute__(self, "_to_evaluate")
+ if not hasattr(elem, "__getitem__"):
+ raise AttributeError("__getitem__")
+
+ if isinstance(key, PyWrapper):
+ # TODO: coverage
+ raise exc.InvalidRequestError(
+ "Dictionary keys / list indexes inside of a cached "
+ "lambda must be Python literals only"
+ )
+ return self._sa__add_getter(key, operator.itemgetter)
+
+ def _add_getter(self, key, getter_fn):
+
+ bind_paths = object.__getattribute__(self, "_bind_paths")
+
+ bind_path_key = (key, getter_fn)
+ if bind_path_key in bind_paths:
+ return bind_paths[bind_path_key]
+
+ getter = getter_fn(key)
+ elem = object.__getattribute__(self, "_to_evaluate")
+ value = getter(elem)
+
+ rolled_down_value = AnalyzedCode._roll_down_to_literal(value)
+
+ if coercions._deep_is_literal(rolled_down_value):
+ wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
+ bind_paths[bind_path_key] = wrapper
+ return wrapper
+ else:
+ return value
+
+
+@inspection._inspects(LambdaElement)
+def insp(lmb):
+ return inspection.inspect(lmb._resolved)
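+
+
+# --- Illustrative sketch (editorial addition, not part of the library) ------
+# A minimal, hedged example of how the machinery above is typically consumed
+# through the public ``lambda_stmt`` construct; ``users`` is assumed to be an
+# ordinary Table::
+#
+#     from sqlalchemy import lambda_stmt, select
+#
+#     def lookup(connection, user_id):
+#         # ``user_id`` is a closure variable: AnalyzedCode classifies it as
+#         # a literal, PyWrapper substitutes a BindParameter for it, and on
+#         # later calls only the new bound value is extracted -- the lambda
+#         # body is not re-invoked while its ``__code__`` object and other
+#         # closure identities stay the same.
+#         stmt = lambda_stmt(
+#             lambda: select(users).where(users.c.id == user_id)
+#         )
+#         return connection.execute(stmt)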
diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py
new file mode 100644
index 0000000..b7ad221
--- /dev/null
+++ b/lib/sqlalchemy/sql/naming.py
@@ -0,0 +1,210 @@
+# sqlalchemy/naming.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Establish constraint and index naming conventions.
+
+
+"""
+
+import re
+
+from . import events # noqa
+from .elements import _NONE_NAME
+from .elements import conv
+from .schema import CheckConstraint
+from .schema import Column
+from .schema import Constraint
+from .schema import ForeignKeyConstraint
+from .schema import Index
+from .schema import PrimaryKeyConstraint
+from .schema import Table
+from .schema import UniqueConstraint
+from .. import event
+from .. import exc
+
+
+class ConventionDict(object):
+ def __init__(self, const, table, convention):
+ self.const = const
+ self._is_fk = isinstance(const, ForeignKeyConstraint)
+ self.table = table
+ self.convention = convention
+ self._const_name = const.name
+
+ def _key_table_name(self):
+ return self.table.name
+
+ def _column_X(self, idx, attrname):
+ if self._is_fk:
+ try:
+ fk = self.const.elements[idx]
+ except IndexError:
+ return ""
+ else:
+ return getattr(fk.parent, attrname)
+ else:
+ cols = list(self.const.columns)
+ try:
+ col = cols[idx]
+ except IndexError:
+ return ""
+ else:
+ return getattr(col, attrname)
+
+ def _key_constraint_name(self):
+ if self._const_name in (None, _NONE_NAME):
+ raise exc.InvalidRequestError(
+ "Naming convention including "
+ "%(constraint_name)s token requires that "
+ "constraint is explicitly named."
+ )
+ if not isinstance(self._const_name, conv):
+ self.const.name = None
+ return self._const_name
+
+ def _key_column_X_key(self, idx):
+ # note this method was missing before
+ # [ticket:3989], meaning tokens like ``%(column_0_key)s`` weren't
+ # working even though documented.
+ return self._column_X(idx, "key")
+
+ def _key_column_X_name(self, idx):
+ return self._column_X(idx, "name")
+
+ def _key_column_X_label(self, idx):
+ return self._column_X(idx, "_ddl_label")
+
+ def _key_referred_table_name(self):
+ fk = self.const.elements[0]
+ refs = fk.target_fullname.split(".")
+ if len(refs) == 3:
+ refschema, reftable, refcol = refs
+ else:
+ reftable, refcol = refs
+ return reftable
+
+ def _key_referred_column_X_name(self, idx):
+ fk = self.const.elements[idx]
+ # note that before [ticket:3989], this method was returning
+ # the specification for the :class:`.ForeignKey` itself, which normally
+ # would be using the ``.key`` of the column, not the name.
+ return fk.column.name
+
+ def __getitem__(self, key):
+ if key in self.convention:
+ return self.convention[key](self.const, self.table)
+ elif hasattr(self, "_key_%s" % key):
+ return getattr(self, "_key_%s" % key)()
+ else:
+ col_template = re.match(r".*_?column_(\d+)(_?N)?_.+", key)
+ if col_template:
+ idx = col_template.group(1)
+ multiples = col_template.group(2)
+
+ if multiples:
+ if self._is_fk:
+ elems = self.const.elements
+ else:
+ elems = list(self.const.columns)
+ tokens = []
+ for idx, elem in enumerate(elems):
+ attr = "_key_" + key.replace("0" + multiples, "X")
+ try:
+ tokens.append(getattr(self, attr)(idx))
+ except AttributeError:
+ raise KeyError(key)
+ sep = "_" if multiples.startswith("_") else ""
+ return sep.join(tokens)
+ else:
+ attr = "_key_" + key.replace(idx, "X")
+ idx = int(idx)
+ if hasattr(self, attr):
+ return getattr(self, attr)(idx)
+ raise KeyError(key)
+
+
+_prefix_dict = {
+ Index: "ix",
+ PrimaryKeyConstraint: "pk",
+ CheckConstraint: "ck",
+ UniqueConstraint: "uq",
+ ForeignKeyConstraint: "fk",
+}
+
+
+def _get_convention(dict_, key):
+
+ for super_ in key.__mro__:
+ if super_ in _prefix_dict and _prefix_dict[super_] in dict_:
+ return dict_[_prefix_dict[super_]]
+ elif super_ in dict_:
+ return dict_[super_]
+ else:
+ return None
+
+
+def _constraint_name_for_table(const, table):
+ metadata = table.metadata
+ convention = _get_convention(metadata.naming_convention, type(const))
+
+ if isinstance(const.name, conv):
+ return const.name
+ elif (
+ convention is not None
+ and not isinstance(const.name, conv)
+ and (
+ const.name is None
+ or "constraint_name" in convention
+ or const.name is _NONE_NAME
+ )
+ ):
+ return conv(
+ convention
+ % ConventionDict(const, table, metadata.naming_convention)
+ )
+ elif convention is _NONE_NAME:
+ return None
+
+
+@event.listens_for(
+ PrimaryKeyConstraint, "_sa_event_column_added_to_pk_constraint"
+)
+def _column_added_to_pk_constraint(pk_constraint, col):
+ if pk_constraint._implicit_generated:
+ # only operate upon the "implicit" pk constraint for now,
+ # as we have to force the name to None to reset it. the
+ # "implicit" constraint will only have a naming convention name
+ # if at all.
+ table = pk_constraint.table
+ pk_constraint.name = None
+ newname = _constraint_name_for_table(pk_constraint, table)
+ if newname:
+ pk_constraint.name = newname
+
+
+@event.listens_for(Constraint, "after_parent_attach")
+@event.listens_for(Index, "after_parent_attach")
+def _constraint_name(const, table):
+ if isinstance(table, Column):
+ # this path occurs for a CheckConstraint linked to a Column
+
+        # for a column-attached constraint, set up another event so that
+        # once the column itself is attached to a Table, this constraint
+        # is named against that table.
+ event.listen(
+ table,
+ "after_parent_attach",
+ lambda col, table: _constraint_name(const, table),
+ )
+
+ elif isinstance(table, Table):
+ if isinstance(const.name, conv) or const.name is _NONE_NAME:
+ return
+
+ newname = _constraint_name_for_table(const, table)
+ if newname:
+ const.name = newname
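+
+
+# --- Illustrative sketch (editorial addition, not part of the library) ------
+# How the tokens resolved by ConventionDict are typically supplied, using the
+# documented ``MetaData.naming_convention`` parameter::
+#
+#     from sqlalchemy import (
+#         Column, Integer, MetaData, Table, UniqueConstraint
+#     )
+#
+#     metadata = MetaData(naming_convention={
+#         "uq": "uq_%(table_name)s_%(column_0_name)s",
+#         "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+#     })
+#     user = Table(
+#         "user", metadata,
+#         Column("name", Integer),
+#         UniqueConstraint("name"),
+#     )
+#     # _constraint_name_for_table() fills in the "uq" template through
+#     # ConventionDict, so the unnamed constraint is named "uq_user_name".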
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
new file mode 100644
index 0000000..1da5032
--- /dev/null
+++ b/lib/sqlalchemy/sql/operators.py
@@ -0,0 +1,1688 @@
+# sql/operators.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Defines operators used in SQL expressions."""
+
+from operator import add
+from operator import and_
+from operator import contains
+from operator import eq
+from operator import ge
+from operator import getitem
+from operator import gt
+from operator import inv
+from operator import le
+from operator import lshift
+from operator import lt
+from operator import mod
+from operator import mul
+from operator import ne
+from operator import neg
+from operator import or_
+from operator import rshift
+from operator import sub
+from operator import truediv
+
+from .. import util
+
+
+if util.py2k:
+ from operator import div
+else:
+ div = truediv
+
+
+class Operators(object):
+ """Base of comparison and logical operators.
+
+ Implements base methods
+ :meth:`~sqlalchemy.sql.operators.Operators.operate` and
+ :meth:`~sqlalchemy.sql.operators.Operators.reverse_operate`, as well as
+ :meth:`~sqlalchemy.sql.operators.Operators.__and__`,
+ :meth:`~sqlalchemy.sql.operators.Operators.__or__`,
+ :meth:`~sqlalchemy.sql.operators.Operators.__invert__`.
+
+    It is usually used via its most common subclass
+ :class:`.ColumnOperators`.
+
+ """
+
+ __slots__ = ()
+
+ def __and__(self, other):
+ """Implement the ``&`` operator.
+
+ When used with SQL expressions, results in an
+ AND operation, equivalent to
+ :func:`_expression.and_`, that is::
+
+ a & b
+
+ is equivalent to::
+
+ from sqlalchemy import and_
+ and_(a, b)
+
+        Care should be taken when using ``&`` regarding
+        operator precedence; the ``&`` operator binds more tightly than
+        comparison operators such as ``==``, so the operands should be
+        enclosed in parentheses if they contain further sub expressions::
+
+ (a == 2) & (b == 4)
+
+ """
+ return self.operate(and_, other)
+
+ def __or__(self, other):
+ """Implement the ``|`` operator.
+
+ When used with SQL expressions, results in an
+ OR operation, equivalent to
+ :func:`_expression.or_`, that is::
+
+ a | b
+
+ is equivalent to::
+
+ from sqlalchemy import or_
+ or_(a, b)
+
+        Care should be taken when using ``|`` regarding
+        operator precedence; the ``|`` operator binds more tightly than
+        comparison operators such as ``==``, so the operands should be
+        enclosed in parentheses if they contain further sub expressions::
+
+ (a == 2) | (b == 4)
+
+ """
+ return self.operate(or_, other)
+
+ def __invert__(self):
+ """Implement the ``~`` operator.
+
+ When used with SQL expressions, results in a
+ NOT operation, equivalent to
+ :func:`_expression.not_`, that is::
+
+ ~a
+
+ is equivalent to::
+
+ from sqlalchemy import not_
+ not_(a)
+
+ """
+ return self.operate(inv)
+
+ def op(
+ self, opstring, precedence=0, is_comparison=False, return_type=None
+ ):
+ """Produce a generic operator function.
+
+ e.g.::
+
+ somecolumn.op("*")(5)
+
+ produces::
+
+ somecolumn * 5
+
+ This function can also be used to make bitwise operators explicit. For
+ example::
+
+ somecolumn.op('&')(0xff)
+
+ is a bitwise AND of the value in ``somecolumn``.
+
+        :param opstring: a string which will be output as the infix operator
+ between this element and the expression passed to the
+ generated function.
+
+ :param precedence: precedence to apply to the operator, when
+ parenthesizing expressions. A lower number will cause the expression
+ to be parenthesized when applied against another operator with
+ higher precedence. The default value of ``0`` is lower than all
+ operators except for the comma (``,``) and ``AS`` operators.
+ A value of 100 will be higher or equal to all operators, and -100
+ will be lower than or equal to all operators.
+
+ :param is_comparison: legacy; if True, the operator will be considered
+ as a "comparison" operator, that is which evaluates to a boolean
+ true/false value, like ``==``, ``>``, etc. This flag is provided
+ so that ORM relationships can establish that the operator is a
+ comparison operator when used in a custom join condition.
+
+ Using the ``is_comparison`` parameter is superseded by using the
+ :meth:`.Operators.bool_op` method instead; this more succinct
+ operator sets this parameter automatically. In SQLAlchemy 2.0 it
+ will also provide for improved typing support.
+
+ :param return_type: a :class:`.TypeEngine` class or object that will
+ force the return type of an expression produced by this operator
+ to be of that type. By default, operators that specify
+ :paramref:`.Operators.op.is_comparison` will resolve to
+ :class:`.Boolean`, and those that do not will be of the same
+ type as the left-hand operand.
+
+ .. seealso::
+
+ :meth:`.Operators.bool_op`
+
+ :ref:`types_operators`
+
+ :ref:`relationship_custom_operator`
+
+ """
+ operator = custom_op(opstring, precedence, is_comparison, return_type)
+
+ def against(other):
+ return operator(self, other)
+
+ return against
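+
+    # Illustrative sketch (editorial addition): with a concrete column the
+    # generated callable behaves like any other operator, e.g.::
+    #
+    #     from sqlalchemy import column
+    #
+    #     c = column("x")
+    #     c.op("&")(0xFF)        # renders "x & :x_1"
+    #     c.bool_op("~")("^ab")  # comparison-typed; renders "x ~ :x_1"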
+
+ def bool_op(self, opstring, precedence=0):
+ """Return a custom boolean operator.
+
+ This method is shorthand for calling
+ :meth:`.Operators.op` and passing the
+ :paramref:`.Operators.op.is_comparison`
+ flag with True. A key advantage to using :meth:`.Operators.bool_op`
+ is that when using column constructs, the "boolean" nature of the
+ returned expression will be present for :pep:`484` purposes.
+
+ .. seealso::
+
+ :meth:`.Operators.op`
+
+ """
+ return self.op(opstring, precedence=precedence, is_comparison=True)
+
+ def operate(self, op, *other, **kwargs):
+ r"""Operate on an argument.
+
+ This is the lowest level of operation, raises
+ :class:`NotImplementedError` by default.
+
+ Overriding this on a subclass can allow common
+ behavior to be applied to all operations.
+ For example, overriding :class:`.ColumnOperators`
+ to apply ``func.lower()`` to the left and right
+ side::
+
+ class MyComparator(ColumnOperators):
+ def operate(self, op, other, **kwargs):
+ return op(func.lower(self), func.lower(other), **kwargs)
+
+ :param op: Operator callable.
+ :param \*other: the 'other' side of the operation. Will
+ be a single scalar for most operations.
+ :param \**kwargs: modifiers. These may be passed by special
+ operators such as :meth:`ColumnOperators.contains`.
+
+
+ """
+ raise NotImplementedError(str(op))
+
+ def reverse_operate(self, op, other, **kwargs):
+ """Reverse operate on an argument.
+
+ Usage is the same as :meth:`operate`.
+
+ """
+ raise NotImplementedError(str(op))
+
+
+class custom_op(object):
+ """Represent a 'custom' operator.
+
+ :class:`.custom_op` is normally instantiated when the
+ :meth:`.Operators.op` or :meth:`.Operators.bool_op` methods
+ are used to create a custom operator callable. The class can also be
+ used directly when programmatically constructing expressions. E.g.
+ to represent the "factorial" operation::
+
+ from sqlalchemy.sql import UnaryExpression
+ from sqlalchemy.sql import operators
+ from sqlalchemy import Numeric
+
+ unary = UnaryExpression(table.c.somecolumn,
+ modifier=operators.custom_op("!"),
+ type_=Numeric)
+
+
+ .. seealso::
+
+ :meth:`.Operators.op`
+
+ :meth:`.Operators.bool_op`
+
+ """
+
+ __name__ = "custom_op"
+
+ def __init__(
+ self,
+ opstring,
+ precedence=0,
+ is_comparison=False,
+ return_type=None,
+ natural_self_precedent=False,
+ eager_grouping=False,
+ ):
+ self.opstring = opstring
+ self.precedence = precedence
+ self.is_comparison = is_comparison
+ self.natural_self_precedent = natural_self_precedent
+ self.eager_grouping = eager_grouping
+ self.return_type = (
+ return_type._to_instance(return_type) if return_type else None
+ )
+
+ def __eq__(self, other):
+ return isinstance(other, custom_op) and other.opstring == self.opstring
+
+ def __hash__(self):
+ return id(self)
+
+ def __call__(self, left, right, **kw):
+ return left.operate(self, right, **kw)
+
+
+class ColumnOperators(Operators):
+ """Defines boolean, comparison, and other operators for
+ :class:`_expression.ColumnElement` expressions.
+
+ By default, all methods call down to
+ :meth:`.operate` or :meth:`.reverse_operate`,
+ passing in the appropriate operator function from the
+ Python builtin ``operator`` module or
+ a SQLAlchemy-specific operator function from
+ :mod:`sqlalchemy.expression.operators`. For example
+ the ``__eq__`` function::
+
+ def __eq__(self, other):
+ return self.operate(operators.eq, other)
+
+ Where ``operators.eq`` is essentially::
+
+ def eq(a, b):
+ return a == b
+
+ The core column expression unit :class:`_expression.ColumnElement`
+ overrides :meth:`.Operators.operate` and others
+ to return further :class:`_expression.ColumnElement` constructs,
+ so that the ``==`` operation above is replaced by a clause
+ construct.
+
+ .. seealso::
+
+ :ref:`types_operators`
+
+ :attr:`.TypeEngine.comparator_factory`
+
+ :class:`.ColumnOperators`
+
+ :class:`.PropComparator`
+
+ """
+
+ __slots__ = ()
+
+ timetuple = None
+ """Hack, allows datetime objects to be compared on the LHS."""
+
+ def __lt__(self, other):
+ """Implement the ``<`` operator.
+
+ In a column context, produces the clause ``a < b``.
+
+ """
+ return self.operate(lt, other)
+
+ def __le__(self, other):
+ """Implement the ``<=`` operator.
+
+ In a column context, produces the clause ``a <= b``.
+
+ """
+ return self.operate(le, other)
+
+ __hash__ = Operators.__hash__
+
+ def __eq__(self, other):
+ """Implement the ``==`` operator.
+
+ In a column context, produces the clause ``a = b``.
+ If the target is ``None``, produces ``a IS NULL``.
+
+ """
+ return self.operate(eq, other)
+
+ def __ne__(self, other):
+ """Implement the ``!=`` operator.
+
+ In a column context, produces the clause ``a != b``.
+ If the target is ``None``, produces ``a IS NOT NULL``.
+
+ """
+ return self.operate(ne, other)
+
+ def is_distinct_from(self, other):
+ """Implement the ``IS DISTINCT FROM`` operator.
+
+ Renders "a IS DISTINCT FROM b" on most platforms;
+        on some, such as SQLite, it may render "a IS NOT b".
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(is_distinct_from, other)
+
+ def is_not_distinct_from(self, other):
+ """Implement the ``IS NOT DISTINCT FROM`` operator.
+
+ Renders "a IS NOT DISTINCT FROM b" on most platforms;
+        on some, such as SQLite, it may render "a IS b".
+
+ .. versionchanged:: 1.4 The ``is_not_distinct_from()`` operator is
+ renamed from ``isnot_distinct_from()`` in previous releases.
+ The previous name remains available for backwards compatibility.
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(is_not_distinct_from, other)
+
+ # deprecated 1.4; see #5435
+ isnot_distinct_from = is_not_distinct_from
+
+ def __gt__(self, other):
+ """Implement the ``>`` operator.
+
+ In a column context, produces the clause ``a > b``.
+
+ """
+ return self.operate(gt, other)
+
+ def __ge__(self, other):
+ """Implement the ``>=`` operator.
+
+ In a column context, produces the clause ``a >= b``.
+
+ """
+ return self.operate(ge, other)
+
+ def __neg__(self):
+ """Implement the ``-`` operator.
+
+ In a column context, produces the clause ``-a``.
+
+ """
+ return self.operate(neg)
+
+ def __contains__(self, other):
+ return self.operate(contains, other)
+
+ def __getitem__(self, index):
+ """Implement the [] operator.
+
+ This can be used by some database-specific types
+ such as PostgreSQL ARRAY and HSTORE.
+
+ """
+ return self.operate(getitem, index)
+
+ def __lshift__(self, other):
+ """implement the << operator.
+
+ Not used by SQLAlchemy core, this is provided
+ for custom operator systems which want to use
+ << as an extension point.
+ """
+ return self.operate(lshift, other)
+
+ def __rshift__(self, other):
+ """implement the >> operator.
+
+ Not used by SQLAlchemy core, this is provided
+ for custom operator systems which want to use
+ >> as an extension point.
+ """
+ return self.operate(rshift, other)
+
+ def concat(self, other):
+ """Implement the 'concat' operator.
+
+ In a column context, produces the clause ``a || b``,
+ or uses the ``concat()`` operator on MySQL.
+
+ """
+ return self.operate(concat_op, other)
+
+ def _rconcat(self, other):
+ """Implement an 'rconcat' operator.
+
+ this is for internal use at the moment
+
+ .. versionadded:: 1.4.40
+
+ """
+ return self.reverse_operate(concat_op, other)
+
+ def like(self, other, escape=None):
+ r"""Implement the ``like`` operator.
+
+ In a column context, produces the expression::
+
+ a LIKE other
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.like("%foobar%"))
+
+ :param other: expression to be compared
+ :param escape: optional escape character, renders the ``ESCAPE``
+ keyword, e.g.::
+
+ somecolumn.like("foo/%bar", escape="/")
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.ilike`
+
+ """
+ return self.operate(like_op, other, escape=escape)
+
+ def ilike(self, other, escape=None):
+ r"""Implement the ``ilike`` operator, e.g. case insensitive LIKE.
+
+ In a column context, produces an expression either of the form::
+
+ lower(a) LIKE lower(other)
+
+ Or on backends that support the ILIKE operator::
+
+ a ILIKE other
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.ilike("%foobar%"))
+
+ :param other: expression to be compared
+ :param escape: optional escape character, renders the ``ESCAPE``
+ keyword, e.g.::
+
+ somecolumn.ilike("foo/%bar", escape="/")
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(ilike_op, other, escape=escape)
+
+ def in_(self, other):
+ """Implement the ``in`` operator.
+
+ In a column context, produces the clause ``column IN <other>``.
+
+ The given parameter ``other`` may be:
+
+ * A list of literal values, e.g.::
+
+ stmt.where(column.in_([1, 2, 3]))
+
+ In this calling form, the list of items is converted to a set of
+ bound parameters the same length as the list given::
+
+ WHERE COL IN (?, ?, ?)
+
+ * A list of tuples may be provided if the comparison is against a
+ :func:`.tuple_` containing multiple expressions::
+
+ from sqlalchemy import tuple_
+ stmt.where(tuple_(col1, col2).in_([(1, 10), (2, 20), (3, 30)]))
+
+ * An empty list, e.g.::
+
+ stmt.where(column.in_([]))
+
+ In this calling form, the expression renders an "empty set"
+ expression. These expressions are tailored to individual backends
+ and are generally trying to get an empty SELECT statement as a
+          subquery. For example, on SQLite, the expression is::
+
+ WHERE col IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
+
+ .. versionchanged:: 1.4 empty IN expressions now use an
+ execution-time generated SELECT subquery in all cases.
+
+ * A bound parameter, e.g. :func:`.bindparam`, may be used if it
+ includes the :paramref:`.bindparam.expanding` flag::
+
+ stmt.where(column.in_(bindparam('value', expanding=True)))
+
+ In this calling form, the expression renders a special non-SQL
+ placeholder expression that looks like::
+
+ WHERE COL IN ([EXPANDING_value])
+
+ This placeholder expression is intercepted at statement execution
+ time to be converted into the variable number of bound parameter
+ form illustrated earlier. If the statement were executed as::
+
+ connection.execute(stmt, {"value": [1, 2, 3]})
+
+ The database would be passed a bound parameter for each value::
+
+ WHERE COL IN (?, ?, ?)
+
+ .. versionadded:: 1.2 added "expanding" bound parameters
+
+ If an empty list is passed, a special "empty list" expression,
+ which is specific to the database in use, is rendered. On
+ SQLite this would be::
+
+ WHERE COL IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
+
+ .. versionadded:: 1.3 "expanding" bound parameters now support
+ empty lists
+
+ * a :func:`_expression.select` construct, which is usually a
+ correlated scalar select::
+
+ stmt.where(
+ column.in_(
+ select(othertable.c.y).
+ where(table.c.x == othertable.c.x)
+ )
+ )
+
+ In this calling form, :meth:`.ColumnOperators.in_` renders as given::
+
+ WHERE COL IN (SELECT othertable.y
+ FROM othertable WHERE othertable.x = table.x)
+
+ :param other: a list of literals, a :func:`_expression.select`
+ construct, or a :func:`.bindparam` construct that includes the
+ :paramref:`.bindparam.expanding` flag set to True.
+
+ """
+ return self.operate(in_op, other)
+
+ def not_in(self, other):
+ """implement the ``NOT IN`` operator.
+
+ This is equivalent to using negation with
+ :meth:`.ColumnOperators.in_`, i.e. ``~x.in_(y)``.
+
+ In the case that ``other`` is an empty sequence, the compiler
+ produces an "empty not in" expression. This defaults to the
+ expression "1 = 1" to produce true in all cases. The
+ :paramref:`_sa.create_engine.empty_in_strategy` may be used to
+ alter this behavior.
+
+ .. versionchanged:: 1.4 The ``not_in()`` operator is renamed from
+ ``notin_()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. versionchanged:: 1.2 The :meth:`.ColumnOperators.in_` and
+ :meth:`.ColumnOperators.not_in` operators
+ now produce a "static" expression for an empty IN sequence
+ by default.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.in_`
+
+ """
+ return self.operate(not_in_op, other)
+
+ # deprecated 1.4; see #5429
+ notin_ = not_in
+
+ def not_like(self, other, escape=None):
+ """implement the ``NOT LIKE`` operator.
+
+ This is equivalent to using negation with
+ :meth:`.ColumnOperators.like`, i.e. ``~x.like(y)``.
+
+ .. versionchanged:: 1.4 The ``not_like()`` operator is renamed from
+ ``notlike()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(notlike_op, other, escape=escape)
+
+ # deprecated 1.4; see #5435
+ notlike = not_like
+
+ def not_ilike(self, other, escape=None):
+ """implement the ``NOT ILIKE`` operator.
+
+ This is equivalent to using negation with
+ :meth:`.ColumnOperators.ilike`, i.e. ``~x.ilike(y)``.
+
+ .. versionchanged:: 1.4 The ``not_ilike()`` operator is renamed from
+ ``notilike()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.ilike`
+
+ """
+ return self.operate(notilike_op, other, escape=escape)
+
+ # deprecated 1.4; see #5435
+ notilike = not_ilike
+
+ def is_(self, other):
+ """Implement the ``IS`` operator.
+
+ Normally, ``IS`` is generated automatically when comparing to a
+ value of ``None``, which resolves to ``NULL``. However, explicit
+ usage of ``IS`` may be desirable if comparing to boolean values
+ on certain platforms.
+
+ .. seealso:: :meth:`.ColumnOperators.is_not`
+
+ """
+ return self.operate(is_, other)
+
+ def is_not(self, other):
+ """Implement the ``IS NOT`` operator.
+
+ Normally, ``IS NOT`` is generated automatically when comparing to a
+ value of ``None``, which resolves to ``NULL``. However, explicit
+ usage of ``IS NOT`` may be desirable if comparing to boolean values
+ on certain platforms.
+
+ .. versionchanged:: 1.4 The ``is_not()`` operator is renamed from
+ ``isnot()`` in previous releases. The previous name remains
+ available for backwards compatibility.
+
+ .. seealso:: :meth:`.ColumnOperators.is_`
+
+ """
+ return self.operate(is_not, other)
+
+ # deprecated 1.4; see #5429
+ isnot = is_not
+
+ def startswith(self, other, **kwargs):
+ r"""Implement the ``startswith`` operator.
+
+ Produces a LIKE expression that tests against a match for the start
+ of a string value::
+
+ column LIKE <other> || '%'
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.startswith("foobar"))
+
+ Since the operator uses ``LIKE``, wildcard characters
+ ``"%"`` and ``"_"`` that are present inside the <other> expression
+ will behave like wildcards as well. For literal string
+ values, the :paramref:`.ColumnOperators.startswith.autoescape` flag
+ may be set to ``True`` to apply escaping to occurrences of these
+ characters within the string value so that they match as themselves
+ and not as wildcard characters. Alternatively, the
+ :paramref:`.ColumnOperators.startswith.escape` parameter will establish
+ a given character as an escape character which can be of use when
+ the target expression is not a literal string.
+
+ :param other: expression to be compared. This is usually a plain
+ string value, but can also be an arbitrary SQL expression. LIKE
+ wildcard characters ``%`` and ``_`` are not escaped by default unless
+ the :paramref:`.ColumnOperators.startswith.autoescape` flag is
+ set to True.
+
+ :param autoescape: boolean; when True, establishes an escape character
+ within the LIKE expression, then applies it to all occurrences of
+ ``"%"``, ``"_"`` and the escape character itself within the
+ comparison value, which is assumed to be a literal string and not a
+ SQL expression.
+
+ An expression such as::
+
+ somecolumn.startswith("foo%bar", autoescape=True)
+
+ Will render as::
+
+ somecolumn LIKE :param || '%' ESCAPE '/'
+
+ With the value of ``:param`` as ``"foo/%bar"``.
+
+ :param escape: a character which when given will render with the
+ ``ESCAPE`` keyword to establish that character as the escape
+ character. This character can then be placed preceding occurrences
+ of ``%`` and ``_`` to allow them to act as themselves and not
+ wildcard characters.
+
+ An expression such as::
+
+ somecolumn.startswith("foo/%bar", escape="^")
+
+ Will render as::
+
+ somecolumn LIKE :param || '%' ESCAPE '^'
+
+ The parameter may also be combined with
+ :paramref:`.ColumnOperators.startswith.autoescape`::
+
+ somecolumn.startswith("foo%bar^bat", escape="^", autoescape=True)
+
+ Where above, the given literal parameter will be converted to
+ ``"foo^%bar^^bat"`` before being passed to the database.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.endswith`
+
+ :meth:`.ColumnOperators.contains`
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(startswith_op, other, **kwargs)
+
+ def endswith(self, other, **kwargs):
+ r"""Implement the 'endswith' operator.
+
+ Produces a LIKE expression that tests against a match for the end
+ of a string value::
+
+ column LIKE '%' || <other>
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.endswith("foobar"))
+
+ Since the operator uses ``LIKE``, wildcard characters
+ ``"%"`` and ``"_"`` that are present inside the <other> expression
+ will behave like wildcards as well. For literal string
+ values, the :paramref:`.ColumnOperators.endswith.autoescape` flag
+ may be set to ``True`` to apply escaping to occurrences of these
+ characters within the string value so that they match as themselves
+ and not as wildcard characters. Alternatively, the
+ :paramref:`.ColumnOperators.endswith.escape` parameter will establish
+ a given character as an escape character which can be of use when
+ the target expression is not a literal string.
+
+ :param other: expression to be compared. This is usually a plain
+ string value, but can also be an arbitrary SQL expression. LIKE
+ wildcard characters ``%`` and ``_`` are not escaped by default unless
+ the :paramref:`.ColumnOperators.endswith.autoescape` flag is
+ set to True.
+
+ :param autoescape: boolean; when True, establishes an escape character
+ within the LIKE expression, then applies it to all occurrences of
+ ``"%"``, ``"_"`` and the escape character itself within the
+ comparison value, which is assumed to be a literal string and not a
+ SQL expression.
+
+ An expression such as::
+
+ somecolumn.endswith("foo%bar", autoescape=True)
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param ESCAPE '/'
+
+ With the value of ``:param`` as ``"foo/%bar"``.
+
+ :param escape: a character which when given will render with the
+ ``ESCAPE`` keyword to establish that character as the escape
+ character. This character can then be placed preceding occurrences
+ of ``%`` and ``_`` to allow them to act as themselves and not
+ wildcard characters.
+
+ An expression such as::
+
+ somecolumn.endswith("foo/%bar", escape="^")
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param ESCAPE '^'
+
+ The parameter may also be combined with
+ :paramref:`.ColumnOperators.endswith.autoescape`::
+
+ somecolumn.endswith("foo%bar^bat", escape="^", autoescape=True)
+
+ Where above, the given literal parameter will be converted to
+ ``"foo^%bar^^bat"`` before being passed to the database.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.startswith`
+
+ :meth:`.ColumnOperators.contains`
+
+ :meth:`.ColumnOperators.like`
+
+ """
+ return self.operate(endswith_op, other, **kwargs)
+
+ def contains(self, other, **kwargs):
+ r"""Implement the 'contains' operator.
+
+ Produces a LIKE expression that tests against a match for the middle
+ of a string value::
+
+ column LIKE '%' || <other> || '%'
+
+ E.g.::
+
+ stmt = select(sometable).\
+ where(sometable.c.column.contains("foobar"))
+
+ Since the operator uses ``LIKE``, wildcard characters
+ ``"%"`` and ``"_"`` that are present inside the <other> expression
+ will behave like wildcards as well. For literal string
+ values, the :paramref:`.ColumnOperators.contains.autoescape` flag
+ may be set to ``True`` to apply escaping to occurrences of these
+ characters within the string value so that they match as themselves
+ and not as wildcard characters. Alternatively, the
+ :paramref:`.ColumnOperators.contains.escape` parameter will establish
+ a given character as an escape character which can be of use when
+ the target expression is not a literal string.
+
+ :param other: expression to be compared. This is usually a plain
+ string value, but can also be an arbitrary SQL expression. LIKE
+ wildcard characters ``%`` and ``_`` are not escaped by default unless
+ the :paramref:`.ColumnOperators.contains.autoescape` flag is
+ set to True.
+
+ :param autoescape: boolean; when True, establishes an escape character
+ within the LIKE expression, then applies it to all occurrences of
+ ``"%"``, ``"_"`` and the escape character itself within the
+ comparison value, which is assumed to be a literal string and not a
+ SQL expression.
+
+ An expression such as::
+
+ somecolumn.contains("foo%bar", autoescape=True)
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param || '%' ESCAPE '/'
+
+ With the value of ``:param`` as ``"foo/%bar"``.
+
+ :param escape: a character which when given will render with the
+ ``ESCAPE`` keyword to establish that character as the escape
+ character. This character can then be placed preceding occurrences
+ of ``%`` and ``_`` to allow them to act as themselves and not
+ wildcard characters.
+
+ An expression such as::
+
+ somecolumn.contains("foo/%bar", escape="^")
+
+ Will render as::
+
+ somecolumn LIKE '%' || :param || '%' ESCAPE '^'
+
+ The parameter may also be combined with
+ :paramref:`.ColumnOperators.contains.autoescape`::
+
+ somecolumn.contains("foo%bar^bat", escape="^", autoescape=True)
+
+ Where above, the given literal parameter will be converted to
+ ``"foo^%bar^^bat"`` before being passed to the database.
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.startswith`
+
+ :meth:`.ColumnOperators.endswith`
+
+ :meth:`.ColumnOperators.like`
+
+
+ """
+ return self.operate(contains_op, other, **kwargs)
+
+ def match(self, other, **kwargs):
+ """Implements a database-specific 'match' operator.
+
+ :meth:`_sql.ColumnOperators.match` attempts to resolve to
+ a MATCH-like function or operator provided by the backend.
+ Examples include:
+
+ * PostgreSQL - renders ``x @@ to_tsquery(y)``
+ * MySQL - renders ``MATCH (x) AGAINST (y IN BOOLEAN MODE)``
+
+ .. seealso::
+
+ :class:`_mysql.match` - MySQL specific construct with
+ additional features.
+
+ * Oracle - renders ``CONTAINS(x, y)``
+ * other backends may provide special implementations.
+ * Backends without any special implementation will emit
+ the operator as "MATCH". This is compatible with SQLite, for
+ example.
+
+ """
+ return self.operate(match_op, other, **kwargs)
+
+ def regexp_match(self, pattern, flags=None):
+ """Implements a database-specific 'regexp match' operator.
+
+ E.g.::
+
+ stmt = select(table.c.some_column).where(
+ table.c.some_column.regexp_match('^(b|c)')
+ )
+
+ :meth:`_sql.ColumnOperators.regexp_match` attempts to resolve to
+ a REGEXP-like function or operator provided by the backend, however
+ the specific regular expression syntax and flags available are
+ **not backend agnostic**.
+
+ Examples include:
+
+ * PostgreSQL - renders ``x ~ y`` or ``x !~ y`` when negated.
+ * Oracle - renders ``REGEXP_LIKE(x, y)``
+ * SQLite - uses SQLite's ``REGEXP`` placeholder operator and calls into
+ the Python ``re.match()`` builtin.
+ * other backends may provide special implementations.
+ * Backends without any special implementation will emit
+ the operator as "REGEXP" or "NOT REGEXP". This is compatible with
+ SQLite and MySQL, for example.
+
+ Regular expression support is currently implemented for Oracle,
+ PostgreSQL, MySQL and MariaDB. Partial support is available for
+ SQLite. Support among third-party dialects may vary.
+
+ :param pattern: The regular expression pattern string or column
+ clause.
+ :param flags: Any regular expression string flags to apply. Flags
+ tend to be backend specific. It can be a string or a column clause.
+ Some backends, like PostgreSQL and MariaDB, may alternatively
+ specify the flags as part of the pattern.
+ When using the ignore case flag 'i' in PostgreSQL, the ignore case
+ regexp match operator ``~*`` or ``!~*`` will be used.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.regexp_replace`
+
+
+ """
+ return self.operate(regexp_match_op, pattern, flags=flags)
+
+ def regexp_replace(self, pattern, replacement, flags=None):
+ """Implements a database-specific 'regexp replace' operator.
+
+ E.g.::
+
+ stmt = select(
+ table.c.some_column.regexp_replace(
+ 'b(..)',
+ 'X\1Y',
+ flags='g'
+ )
+ )
+
+ :meth:`_sql.ColumnOperators.regexp_replace` attempts to resolve to
+        a REGEXP_REPLACE-like function provided by the backend, which
+        usually emits the function ``REGEXP_REPLACE()``. However,
+ the specific regular expression syntax and flags available are
+ **not backend agnostic**.
+
+ Regular expression replacement support is currently implemented for
+ Oracle, PostgreSQL, MySQL 8 or greater and MariaDB. Support among
+ third-party dialects may vary.
+
+ :param pattern: The regular expression pattern string or column
+ clause.
+        :param replacement: The replacement string or column clause.
+ :param flags: Any regular expression string flags to apply. Flags
+ tend to be backend specific. It can be a string or a column clause.
+ Some backends, like PostgreSQL and MariaDB, may alternatively
+ specify the flags as part of the pattern.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :meth:`_sql.ColumnOperators.regexp_match`
+
+ """
+ return self.operate(
+ regexp_replace_op, pattern, replacement=replacement, flags=flags
+ )
+
+ def desc(self):
+ """Produce a :func:`_expression.desc` clause against the
+ parent object."""
+ return self.operate(desc_op)
+
+ def asc(self):
+ """Produce a :func:`_expression.asc` clause against the
+ parent object."""
+ return self.operate(asc_op)
+
+ def nulls_first(self):
+ """Produce a :func:`_expression.nulls_first` clause against the
+ parent object.
+
+ .. versionchanged:: 1.4 The ``nulls_first()`` operator is
+ renamed from ``nullsfirst()`` in previous releases.
+ The previous name remains available for backwards compatibility.
+ """
+ return self.operate(nulls_first_op)
+
+ # deprecated 1.4; see #5435
+ nullsfirst = nulls_first
+
+ def nulls_last(self):
+ """Produce a :func:`_expression.nulls_last` clause against the
+ parent object.
+
+ .. versionchanged:: 1.4 The ``nulls_last()`` operator is
+ renamed from ``nullslast()`` in previous releases.
+ The previous name remains available for backwards compatibility.
+ """
+ return self.operate(nulls_last_op)
+
+ # deprecated 1.4; see #5429
+ nullslast = nulls_last
+
+ def collate(self, collation):
+ """Produce a :func:`_expression.collate` clause against
+ the parent object, given the collation string.
+
+ .. seealso::
+
+ :func:`_expression.collate`
+
+ """
+ return self.operate(collate, collation)
+
+ def __radd__(self, other):
+ """Implement the ``+`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__add__`.
+
+ """
+ return self.reverse_operate(add, other)
+
+ def __rsub__(self, other):
+ """Implement the ``-`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__sub__`.
+
+ """
+ return self.reverse_operate(sub, other)
+
+ def __rmul__(self, other):
+ """Implement the ``*`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__mul__`.
+
+ """
+ return self.reverse_operate(mul, other)
+
+ def __rdiv__(self, other):
+ """Implement the ``/`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__div__`.
+
+ """
+ return self.reverse_operate(div, other)
+
+ def __rmod__(self, other):
+ """Implement the ``%`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__mod__`.
+
+ """
+ return self.reverse_operate(mod, other)
+
+ def between(self, cleft, cright, symmetric=False):
+ """Produce a :func:`_expression.between` clause against
+ the parent object, given the lower and upper range.
+
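+ E.g., a minimal sketch assuming ``table.c.some_column`` is a
+ numeric column::
+
+ stmt = select(table).where(
+ table.c.some_column.between(5, 10)
+ )
+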
+ """
+ return self.operate(between_op, cleft, cright, symmetric=symmetric)
+
+ def distinct(self):
+ """Produce a :func:`_expression.distinct` clause against the
+ parent object.
+
+ """
+ return self.operate(distinct_op)
+
+ def any_(self):
+ """Produce an :func:`_expression.any_` clause against the
+ parent object.
+
+ See the documentation for :func:`_sql.any_` for examples.
+
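+ E.g., a hypothetical sketch against an ARRAY-typed column on a
+ backend such as PostgreSQL::
+
+ # renders "5 = ANY (some_array_column)"
+ expr = 5 == table.c.some_array_column.any_()
+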
+ .. note:: be sure to not confuse the newer
+ :meth:`_sql.ColumnOperators.any_` method with its older
+ :class:`_types.ARRAY`-specific counterpart, the
+ :meth:`_types.ARRAY.Comparator.any` method, which has a different
+ calling syntax and usage pattern.
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(any_op)
+
+ def all_(self):
+ """Produce an :func:`_expression.all_` clause against the
+ parent object.
+
+ See the documentation for :func:`_sql.all_` for examples.
+
+ .. note:: be sure to not confuse the newer
+ :meth:`_sql.ColumnOperators.all_` method with its older
+ :class:`_types.ARRAY`-specific counterpart, the
+ :meth:`_types.ARRAY.Comparator.all` method, which has a different
+ calling syntax and usage pattern.
+
+
+ .. versionadded:: 1.1
+
+ """
+ return self.operate(all_op)
+
+ def __add__(self, other):
+ """Implement the ``+`` operator.
+
+ In a column context, produces the clause ``a + b``
+ if the parent object has non-string affinity.
+ If the parent object has a string affinity,
+ produces the concatenation operator, ``a || b`` -
+ see :meth:`.ColumnOperators.concat`.
+
+ """
+ return self.operate(add, other)
+
+ def __sub__(self, other):
+ """Implement the ``-`` operator.
+
+ In a column context, produces the clause ``a - b``.
+
+ """
+ return self.operate(sub, other)
+
+ def __mul__(self, other):
+ """Implement the ``*`` operator.
+
+ In a column context, produces the clause ``a * b``.
+
+ """
+ return self.operate(mul, other)
+
+ def __div__(self, other):
+ """Implement the ``/`` operator.
+
+ In a column context, produces the clause ``a / b``.
+
+ """
+ return self.operate(div, other)
+
+ def __mod__(self, other):
+ """Implement the ``%`` operator.
+
+ In a column context, produces the clause ``a % b``.
+
+ """
+ return self.operate(mod, other)
+
+ def __truediv__(self, other):
+ """Implement the ``//`` operator.
+
+ In a column context, produces the clause ``a / b``.
+
+ """
+ return self.operate(truediv, other)
+
+ def __rtruediv__(self, other):
+ """Implement the ``//`` operator in reverse.
+
+ See :meth:`.ColumnOperators.__truediv__`.
+
+ """
+ return self.reverse_operate(truediv, other)
+
+
+_commutative = {eq, ne, add, mul}
+_comparison = {eq, ne, lt, gt, ge, le}
+
+
+def commutative_op(fn):
+ _commutative.add(fn)
+ return fn
+
+
+def comparison_op(fn):
+ _comparison.add(fn)
+ return fn
+
+
+def from_():
+ raise NotImplementedError()
+
+
+@comparison_op
+def function_as_comparison_op():
+ raise NotImplementedError()
+
+
+def as_():
+ raise NotImplementedError()
+
+
+def exists():
+ raise NotImplementedError()
+
+
+def is_true(a):
+ raise NotImplementedError()
+
+
+# 1.4 deprecated; see #5435
+istrue = is_true
+
+
+def is_false(a):
+ raise NotImplementedError()
+
+
+# 1.4 deprecated; see #5435
+isfalse = is_false
+
+
+@comparison_op
+def is_distinct_from(a, b):
+ return a.is_distinct_from(b)
+
+
+@comparison_op
+def is_not_distinct_from(a, b):
+ return a.is_not_distinct_from(b)
+
+
+# deprecated 1.4; see #5435
+isnot_distinct_from = is_not_distinct_from
+
+
+@comparison_op
+def is_(a, b):
+ return a.is_(b)
+
+
+@comparison_op
+def is_not(a, b):
+ return a.is_not(b)
+
+
+# 1.4 deprecated; see #5429
+isnot = is_not
+
+
+def collate(a, b):
+ return a.collate(b)
+
+
+def op(a, opstring, b):
+ return a.op(opstring)(b)
+
+
+@comparison_op
+def like_op(a, b, escape=None):
+ return a.like(b, escape=escape)
+
+
+@comparison_op
+def not_like_op(a, b, escape=None):
+ return a.notlike(b, escape=escape)
+
+
+# 1.4 deprecated; see #5435
+notlike_op = not_like_op
+
+
+@comparison_op
+def ilike_op(a, b, escape=None):
+ return a.ilike(b, escape=escape)
+
+
+@comparison_op
+def not_ilike_op(a, b, escape=None):
+ return a.not_ilike(b, escape=escape)
+
+
+# 1.4 deprecated; see #5435
+notilike_op = not_ilike_op
+
+
+@comparison_op
+def between_op(a, b, c, symmetric=False):
+ return a.between(b, c, symmetric=symmetric)
+
+
+@comparison_op
+def not_between_op(a, b, c, symmetric=False):
+ return ~a.between(b, c, symmetric=symmetric)
+
+
+# 1.4 deprecated; see #5435
+notbetween_op = not_between_op
+
+
+@comparison_op
+def in_op(a, b):
+ return a.in_(b)
+
+
+@comparison_op
+def not_in_op(a, b):
+ return a.not_in(b)
+
+
+# 1.4 deprecated; see #5429
+notin_op = not_in_op
+
+
+def distinct_op(a):
+ return a.distinct()
+
+
+def any_op(a):
+ return a.any_()
+
+
+def all_op(a):
+ return a.all_()
+
+
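+# Shared helper for the startswith / endswith / contains operators.  When
+# "autoescape" is set, double up occurrences of the escape character
+# (defaulting to "/") in the candidate string, then escape the LIKE
+# wildcards "%" and "_" with it before delegating to the bound
+# LIKE-producing method.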
+def _escaped_like_impl(fn, other, escape, autoescape):
+ if autoescape:
+ if autoescape is not True:
+ util.warn(
+ "The autoescape parameter is now a simple boolean True/False"
+ )
+ if escape is None:
+ escape = "/"
+
+ if not isinstance(other, util.compat.string_types):
+ raise TypeError("String value expected when autoescape=True")
+
+ if escape not in ("%", "_"):
+ other = other.replace(escape, escape + escape)
+
+ other = other.replace("%", escape + "%").replace("_", escape + "_")
+
+ return fn(other, escape=escape)
+
+
+@comparison_op
+def startswith_op(a, b, escape=None, autoescape=False):
+ return _escaped_like_impl(a.startswith, b, escape, autoescape)
+
+
+@comparison_op
+def not_startswith_op(a, b, escape=None, autoescape=False):
+ return ~_escaped_like_impl(a.startswith, b, escape, autoescape)
+
+
+# 1.4 deprecated; see #5435
+notstartswith_op = not_startswith_op
+
+
+@comparison_op
+def endswith_op(a, b, escape=None, autoescape=False):
+ return _escaped_like_impl(a.endswith, b, escape, autoescape)
+
+
+@comparison_op
+def not_endswith_op(a, b, escape=None, autoescape=False):
+ return ~_escaped_like_impl(a.endswith, b, escape, autoescape)
+
+
+# 1.4 deprecated; see #5435
+notendswith_op = not_endswith_op
+
+
+@comparison_op
+def contains_op(a, b, escape=None, autoescape=False):
+ return _escaped_like_impl(a.contains, b, escape, autoescape)
+
+
+@comparison_op
+def not_contains_op(a, b, escape=None, autoescape=False):
+ return ~_escaped_like_impl(a.contains, b, escape, autoescape)
+
+
+# 1.4 deprecated; see #5435
+notcontains_op = not_contains_op
+
+
+@comparison_op
+def match_op(a, b, **kw):
+ return a.match(b, **kw)
+
+
+@comparison_op
+def regexp_match_op(a, b, flags=None):
+ return a.regexp_match(b, flags=flags)
+
+
+@comparison_op
+def not_regexp_match_op(a, b, flags=None):
+ return ~a.regexp_match(b, flags=flags)
+
+
+def regexp_replace_op(a, b, replacement, flags=None):
+ return a.regexp_replace(b, replacement=replacement, flags=flags)
+
+
+@comparison_op
+def not_match_op(a, b, **kw):
+ return ~a.match(b, **kw)
+
+
+# 1.4 deprecated; see #5429
+notmatch_op = not_match_op
+
+
+def comma_op(a, b):
+ raise NotImplementedError()
+
+
+def filter_op(a, b):
+ raise NotImplementedError()
+
+
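+# If the left operand has no .concat() method (e.g. a plain Python string on
+# the left side of "||"), fall back to the right operand's reverse
+# concatenation hook.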
+def concat_op(a, b):
+ try:
+ concat = a.concat
+ except AttributeError:
+ return b._rconcat(a)
+ else:
+ return concat(b)
+
+
+def desc_op(a):
+ return a.desc()
+
+
+def asc_op(a):
+ return a.asc()
+
+
+def nulls_first_op(a):
+ return a.nulls_first()
+
+
+# 1.4 deprecated; see #5435
+nullsfirst_op = nulls_first_op
+
+
+def nulls_last_op(a):
+ return a.nulls_last()
+
+
+# 1.4 deprecated; see #5435
+nullslast_op = nulls_last_op
+
+
+def json_getitem_op(a, b):
+ raise NotImplementedError()
+
+
+def json_path_getitem_op(a, b):
+ raise NotImplementedError()
+
+
+def is_comparison(op):
+ return op in _comparison or isinstance(op, custom_op) and op.is_comparison
+
+
+def is_commutative(op):
+ return op in _commutative
+
+
+def is_ordering_modifier(op):
+ return op in (asc_op, desc_op, nulls_first_op, nulls_last_op)
+
+
+def is_natural_self_precedent(op):
+ return (
+ op in _natural_self_precedent
+ or isinstance(op, custom_op)
+ and op.natural_self_precedent
+ )
+
+
+_booleans = (inv, is_true, is_false, and_, or_)
+
+
+def is_boolean(op):
+ return is_comparison(op) or op in _booleans
+
+
+_mirror = {gt: lt, ge: le, lt: gt, le: ge}
+
+
+def mirror(op):
+ """rotate a comparison operator 180 degrees.
+
+ Note this is not the same as negation.
+
+ """
+ return _mirror.get(op, op)
+
+
+_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
+
+
+def is_associative(op):
+ return op in _associative
+
+
+_natural_self_precedent = _associative.union(
+ [getitem, json_getitem_op, json_path_getitem_op]
+)
+"""Operators where if we have (a op b) op c, we don't want to
+parenthesize (a op b).
+
+"""
+
+
+_asbool = util.symbol("_asbool", canonical=-10)
+_smallest = util.symbol("_smallest", canonical=-100)
+_largest = util.symbol("_largest", canonical=100)
+
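+# Operator precedence lookup used by is_precedent(); higher numbers bind more
+# tightly (e.g. multiplication at 8, addition at 7, comparisons at 5, AND at
+# 3, OR at 2).  _smallest and _largest act as sentinel values for operators
+# not listed here.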
+_PRECEDENCE = {
+ from_: 15,
+ function_as_comparison_op: 15,
+ any_op: 15,
+ all_op: 15,
+ getitem: 15,
+ json_getitem_op: 15,
+ json_path_getitem_op: 15,
+ mul: 8,
+ truediv: 8,
+ div: 8,
+ mod: 8,
+ neg: 8,
+ add: 7,
+ sub: 7,
+ concat_op: 6,
+ filter_op: 6,
+ match_op: 5,
+ not_match_op: 5,
+ regexp_match_op: 5,
+ not_regexp_match_op: 5,
+ regexp_replace_op: 5,
+ ilike_op: 5,
+ not_ilike_op: 5,
+ like_op: 5,
+ not_like_op: 5,
+ in_op: 5,
+ not_in_op: 5,
+ is_: 5,
+ is_not: 5,
+ eq: 5,
+ ne: 5,
+ is_distinct_from: 5,
+ is_not_distinct_from: 5,
+ gt: 5,
+ lt: 5,
+ ge: 5,
+ le: 5,
+ between_op: 5,
+ not_between_op: 5,
+ distinct_op: 5,
+ inv: 5,
+ is_true: 5,
+ is_false: 5,
+ and_: 3,
+ or_: 2,
+ comma_op: -1,
+ desc_op: 3,
+ asc_op: 3,
+ collate: 4,
+ as_: -1,
+ exists: 0,
+ _asbool: -10,
+ _smallest: _smallest,
+ _largest: _largest,
+}
+
+
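+# Return True if "operator", nested within an expression joined by "against",
+# binds no more tightly than "against" and therefore needs parenthesization
+# when the enclosing expression is compiled.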
+def is_precedent(operator, against):
+ if operator is against and is_natural_self_precedent(operator):
+ return False
+ else:
+ return _PRECEDENCE.get(
+ operator, getattr(operator, "precedence", _smallest)
+ ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest))
diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py
new file mode 100644
index 0000000..9e146f7
--- /dev/null
+++ b/lib/sqlalchemy/sql/roles.py
@@ -0,0 +1,239 @@
+# sql/roles.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .. import util
+
+
+class SQLRole(object):
+ """Define a "role" within a SQL statement structure.
+
+ Classes within SQL Core participate within SQLRole hierarchies in order
+ to more accurately indicate where they may be used within SQL statements
+ of all types.
+
+ .. versionadded:: 1.4
+
+ """
+
+ allows_lambda = False
+ uses_inspection = False
+
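+# A sketch of typical internal use: argument coercion helpers such as
+# coercions.expect(roles.ColumnsClauseRole, element) consult these role
+# classes when deciding how, or whether, a user-supplied argument may be
+# coerced into the construct expected at a given position in a statement.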
+
+class UsesInspection(object):
+ _post_inspect = None
+ uses_inspection = True
+
+
+class AllowsLambdaRole(object):
+ allows_lambda = True
+
+
+class HasCacheKeyRole(SQLRole):
+ _role_name = "Cacheable Core or ORM object"
+
+
+class ExecutableOptionRole(SQLRole):
+ __slots__ = ()
+ _role_name = "ExecutionOption Core or ORM object"
+
+
+class LiteralValueRole(SQLRole):
+ _role_name = "Literal Python value"
+
+
+class ColumnArgumentRole(SQLRole):
+ _role_name = "Column expression"
+
+
+class ColumnArgumentOrKeyRole(ColumnArgumentRole):
+ _role_name = "Column expression or string key"
+
+
+class StrAsPlainColumnRole(ColumnArgumentRole):
+ _role_name = "Column expression or string key"
+
+
+class ColumnListRole(SQLRole):
+ """Elements suitable for forming comma separated lists of expressions."""
+
+
+class TruncatedLabelRole(SQLRole):
+ _role_name = "String SQL identifier"
+
+
+class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
+ _role_name = "Column expression or FROM clause"
+
+ @property
+ def _select_iterable(self):
+ raise NotImplementedError()
+
+
+class LimitOffsetRole(SQLRole):
+ _role_name = "LIMIT / OFFSET expression"
+
+
+class ByOfRole(ColumnListRole):
+ _role_name = "GROUP BY / OF / etc. expression"
+
+
+class GroupByRole(AllowsLambdaRole, UsesInspection, ByOfRole):
+ # note there's a special case right now where you can pass a whole
+ # ORM entity to group_by() and it splits out. we may not want to keep
+ # this around
+
+ _role_name = "GROUP BY expression"
+
+
+class OrderByRole(AllowsLambdaRole, ByOfRole):
+ _role_name = "ORDER BY expression"
+
+
+class StructuralRole(SQLRole):
+ pass
+
+
+class StatementOptionRole(StructuralRole):
+ _role_name = "statement sub-expression element"
+
+
+class OnClauseRole(AllowsLambdaRole, StructuralRole):
+ _role_name = "SQL expression for ON clause"
+
+
+class WhereHavingRole(OnClauseRole):
+ _role_name = "SQL expression for WHERE/HAVING role"
+
+
+class ExpressionElementRole(SQLRole):
+ _role_name = "SQL expression element"
+
+
+class ConstExprRole(ExpressionElementRole):
+ _role_name = "Constant True/False/None expression"
+
+
+class LabeledColumnExprRole(ExpressionElementRole):
+ pass
+
+
+class BinaryElementRole(ExpressionElementRole):
+ _role_name = "SQL expression element or literal value"
+
+
+class InElementRole(SQLRole):
+ _role_name = (
+ "IN expression list, SELECT construct, or bound parameter object"
+ )
+
+
+class JoinTargetRole(AllowsLambdaRole, UsesInspection, StructuralRole):
+ _role_name = (
+ "Join target, typically a FROM expression, or ORM "
+ "relationship attribute"
+ )
+
+
+class FromClauseRole(ColumnsClauseRole, JoinTargetRole):
+ _role_name = "FROM expression, such as a Table or alias() object"
+
+ _is_subquery = False
+
+ @property
+ def _hide_froms(self):
+ raise NotImplementedError()
+
+
+class StrictFromClauseRole(FromClauseRole):
+ # does not allow text() or select() objects
+
+ @property
+ def description(self):
+ raise NotImplementedError()
+
+
+class AnonymizedFromClauseRole(StrictFromClauseRole):
+ # calls .alias() as a post processor
+
+ def _anonymous_fromclause(self, name=None, flat=False):
+ raise NotImplementedError()
+
+
+class ReturnsRowsRole(SQLRole):
+ _role_name = (
+ "Row returning expression such as a SELECT, a FROM clause, or an "
+ "INSERT/UPDATE/DELETE with RETURNING"
+ )
+
+
+class StatementRole(SQLRole):
+ _role_name = "Executable SQL or text() construct"
+
+ _propagate_attrs = util.immutabledict()
+
+
+class SelectStatementRole(StatementRole, ReturnsRowsRole):
+ _role_name = "SELECT construct or equivalent text() construct"
+
+ def subquery(self):
+ raise NotImplementedError(
+ "All SelectStatementRole objects should implement a "
+ ".subquery() method."
+ )
+
+
+class HasCTERole(ReturnsRowsRole):
+ pass
+
+
+class IsCTERole(SQLRole):
+ _role_name = "CTE object"
+
+
+class CompoundElementRole(AllowsLambdaRole, SQLRole):
+ """SELECT statements inside a CompoundSelect, e.g. UNION, EXTRACT, etc."""
+
+ _role_name = (
+ "SELECT construct for inclusion in a UNION or other set construct"
+ )
+
+
+# TODO: are we using this?
+class DMLRole(StatementRole):
+ pass
+
+
+class DMLTableRole(FromClauseRole):
+ _role_name = "subject table for an INSERT, UPDATE or DELETE"
+
+
+class DMLColumnRole(SQLRole):
+ _role_name = "SET/VALUES column expression or string key"
+
+
+class DMLSelectRole(SQLRole):
+ """A SELECT statement embedded in DML, typically INSERT from SELECT"""
+
+ _role_name = "SELECT statement or equivalent textual object"
+
+
+class DDLRole(StatementRole):
+ pass
+
+
+class DDLExpressionRole(StructuralRole):
+ _role_name = "SQL expression element for DDL constraint"
+
+
+class DDLConstraintColumnRole(SQLRole):
+ _role_name = "String column name or column expression for DDL constraint"
+
+
+class DDLReferredColumnRole(DDLConstraintColumnRole):
+ _role_name = (
+ "String column name or Column object for DDL foreign key constraint"
+ )
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
new file mode 100644
index 0000000..dde665c
--- /dev/null
+++ b/lib/sqlalchemy/sql/schema.py
@@ -0,0 +1,5268 @@
+# sql/schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The schema module provides the building blocks for database metadata.
+
+Each element within this module describes a database entity which can be
+created and dropped, or is otherwise part of such an entity. Examples include
+tables, columns, sequences, and indexes.
+
+All entities are subclasses of :class:`~sqlalchemy.schema.SchemaItem`, and as
+defined in this module they are intended to be agnostic of any vendor-specific
+constructs.
+
+A collection of entities are grouped into a unit called
+:class:`~sqlalchemy.schema.MetaData`. MetaData serves as a logical grouping of
+schema elements, and can also be associated with an actual database connection
+such that operations involving the contained elements can contact the database
+as needed.
+
+Two of the elements here also build upon their "syntactic" counterparts, which
+are defined in :mod:`~sqlalchemy.sql.expression`, specifically
+:class:`~sqlalchemy.schema.Table` and :class:`~sqlalchemy.schema.Column`.
+Since these objects are part of the SQL expression language, they are usable
+as components in SQL expressions.
+
+"""
+from __future__ import absolute_import
+
+import collections
+
+import sqlalchemy
+from . import coercions
+from . import ddl
+from . import roles
+from . import type_api
+from . import visitors
+from .base import _bind_or_error
+from .base import DedupeColumnCollection
+from .base import DialectKWArgs
+from .base import Executable
+from .base import SchemaEventTarget
+from .coercions import _document_text_coercion
+from .elements import ClauseElement
+from .elements import ColumnClause
+from .elements import ColumnElement
+from .elements import quoted_name
+from .elements import TextClause
+from .selectable import TableClause
+from .type_api import to_instance
+from .visitors import InternalTraversal
+from .. import event
+from .. import exc
+from .. import inspection
+from .. import util
+
+
+RETAIN_SCHEMA = util.symbol(
+ "retain_schema"
+ """Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence`
+ or in some cases a :class:`_schema.ForeignKey` object, in situations
+ where the object is being copied for a :meth:`.Table.to_metadata`
+ operation, should retain the schema name that it already has.
+
+ """
+)
+
+BLANK_SCHEMA = util.symbol(
+ "blank_schema",
+ """Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence`
+ or in some cases a :class:`_schema.ForeignKey` object
+ should have 'None' for its schema, even if the parent
+ :class:`_schema.MetaData` has specified a schema.
+
+ .. versionadded:: 1.0.14
+
+ """,
+)
+
+NULL_UNSPECIFIED = util.symbol(
+ "NULL_UNSPECIFIED",
+ """Symbol indicating the "nullable" keyword was not passed to a Column.
+
+ Normally we would expect None to be acceptable for this but some backends
+ such as that of SQL Server place special significance on a "nullability"
+ value of None.
+
+ """,
+)
+
+
+def _get_table_key(name, schema):
+ if schema is None:
+ return name
+ else:
+ return schema + "." + name
+
+
+# this should really be in sql/util.py but we'd have to
+# break an import cycle
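+# Given an expression referencing columns of source_table, return a copy with
+# those references swapped for the corresponding columns of target_table
+# (used when copying Index expressions in Table.to_metadata()).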
+def _copy_expression(expression, source_table, target_table):
+ if source_table is None or target_table is None:
+ return expression
+
+ def replace(col):
+ if (
+ isinstance(col, Column)
+ and col.table is source_table
+ and col.key in source_table.c
+ ):
+ return target_table.c[col.key]
+ else:
+ return None
+
+ return visitors.replacement_traverse(expression, {}, replace)
+
+
+@inspection._self_inspects
+class SchemaItem(SchemaEventTarget, visitors.Visitable):
+ """Base class for items that define a database schema."""
+
+ __visit_name__ = "schema_item"
+
+ create_drop_stringify_dialect = "default"
+
+ def _init_items(self, *args, **kw):
+ """Initialize the list of child items for this SchemaItem."""
+ for item in args:
+ if item is not None:
+ try:
+ spwd = item._set_parent_with_dispatch
+ except AttributeError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "'SchemaItem' object, such as a 'Column' or a "
+ "'Constraint' expected, got %r" % item
+ ),
+ replace_context=err,
+ )
+ else:
+ spwd(self, **kw)
+
+ def __repr__(self):
+ return util.generic_repr(self, omit_kwarg=["info"])
+
+ @util.memoized_property
+ def info(self):
+ """Info dictionary associated with the object, allowing user-defined
+ data to be associated with this :class:`.SchemaItem`.
+
+ The dictionary is automatically generated when first accessed.
+ It can also be specified in the constructor of some objects,
+ such as :class:`_schema.Table` and :class:`_schema.Column`.
+
+ """
+ return {}
+
+ def _schema_item_copy(self, schema_item):
+ if "info" in self.__dict__:
+ schema_item.info = self.info.copy()
+ schema_item.dispatch._update(self.dispatch)
+ return schema_item
+
+ _use_schema_map = True
+
+
+class Table(DialectKWArgs, SchemaItem, TableClause):
+ r"""Represent a table in a database.
+
+ e.g.::
+
+ mytable = Table(
+ "mytable", metadata,
+ Column('mytable_id', Integer, primary_key=True),
+ Column('value', String(50))
+ )
+
+ The :class:`_schema.Table`
+ object constructs a unique instance of itself based
+ on its name and optional schema name within the given
+ :class:`_schema.MetaData` object. Calling the :class:`_schema.Table`
+ constructor with the same name and same :class:`_schema.MetaData` argument
+ a second time will return the *same* :class:`_schema.Table`
+ object - in this way
+ the :class:`_schema.Table` constructor acts as a registry function.
+
+ .. seealso::
+
+ :ref:`metadata_describing` - Introduction to database metadata
+
+ Constructor arguments are as follows:
+
+ :param name: The name of this table as represented in the database.
+
+ The table name, along with the value of the ``schema`` parameter,
+ forms a key which uniquely identifies this :class:`_schema.Table`
+ within
+ the owning :class:`_schema.MetaData` collection.
+ Additional calls to :class:`_schema.Table` with the same name,
+ metadata,
+ and schema name will return the same :class:`_schema.Table` object.
+
+ Names which contain no upper case characters
+ will be treated as case insensitive names, and will not be quoted
+ unless they are a reserved word or contain special characters.
+ A name with any number of upper case characters is considered
+ to be case sensitive, and will be sent as quoted.
+
+ To enable unconditional quoting for the table name, specify the flag
+ ``quote=True`` to the constructor, or use the :class:`.quoted_name`
+ construct to specify the name.
+
+ :param metadata: a :class:`_schema.MetaData`
+ object which will contain this
+ table. The metadata is used as a point of association of this table
+ with other tables which are referenced via foreign key. It also
+ may be used to associate this table with a particular
+ :class:`.Connectable`.
+
+ :param \*args: Additional positional arguments are used primarily
+ to add the list of :class:`_schema.Column`
+ objects contained within this
+ table. Similar to the style of a CREATE TABLE statement, other
+ :class:`.SchemaItem` constructs may be added here, including
+ :class:`.PrimaryKeyConstraint`, and
+ :class:`_schema.ForeignKeyConstraint`.
+
+ :param autoload: Defaults to ``False``, unless
+ :paramref:`_schema.Table.autoload_with`
+ is set in which case it defaults to ``True``;
+ :class:`_schema.Column` objects
+ for this table should be reflected from the database, possibly
+ augmenting objects that were explicitly specified.
+ :class:`_schema.Column` and other objects explicitly set on the
+ table will replace corresponding reflected objects.
+
+ .. deprecated:: 1.4
+
+ The autoload parameter is deprecated and will be removed in
+ version 2.0. Please use the
+ :paramref:`_schema.Table.autoload_with` parameter, passing an
+ engine or connection.
+
+ .. seealso::
+
+ :ref:`metadata_reflection_toplevel`
+
+ :param autoload_replace: Defaults to ``True``; when using
+ :paramref:`_schema.Table.autoload`
+ in conjunction with :paramref:`_schema.Table.extend_existing`,
+ indicates
+ that :class:`_schema.Column` objects present in the already-existing
+ :class:`_schema.Table`
+ object should be replaced with columns of the same
+ name retrieved from the autoload process. When ``False``, columns
+ already present under existing names will be omitted from the
+ reflection process.
+
+ Note that this setting does not impact :class:`_schema.Column` objects
+ specified programmatically within the call to :class:`_schema.Table`
+ that
+ also is autoloading; those :class:`_schema.Column` objects will always
+ replace existing columns of the same name when
+ :paramref:`_schema.Table.extend_existing` is ``True``.
+
+ .. seealso::
+
+ :paramref:`_schema.Table.autoload`
+
+ :paramref:`_schema.Table.extend_existing`
+
+ :param autoload_with: An :class:`_engine.Engine` or
+ :class:`_engine.Connection` object,
+ or a :class:`_reflection.Inspector` object as returned by
+ :func:`_sa.inspect`
+ against one, with which this :class:`_schema.Table`
+ object will be reflected.
+ When set to a non-None value, the autoload process will take place
+ for this table against the given engine or connection.
+
+ :param extend_existing: When ``True``, indicates that if this
+ :class:`_schema.Table` is already present in the given
+ :class:`_schema.MetaData`,
+ apply further arguments within the constructor to the existing
+ :class:`_schema.Table`.
+
+ If :paramref:`_schema.Table.extend_existing` or
+ :paramref:`_schema.Table.keep_existing` are not set,
+ and the given name
+ of the new :class:`_schema.Table` refers to a :class:`_schema.Table`
+ that is
+ already present in the target :class:`_schema.MetaData` collection,
+ and
+ this :class:`_schema.Table`
+ specifies additional columns or other constructs
+ or flags that modify the table's state, an
+ error is raised. The purpose of these two mutually-exclusive flags
+ is to specify what action should be taken when a
+ :class:`_schema.Table`
+ is specified that matches an existing :class:`_schema.Table`,
+ yet specifies
+ additional constructs.
+
+ :paramref:`_schema.Table.extend_existing`
+ will also work in conjunction
+ with :paramref:`_schema.Table.autoload` to run a new reflection
+ operation against the database, even if a :class:`_schema.Table`
+ of the same name is already present in the target
+ :class:`_schema.MetaData`; newly reflected :class:`_schema.Column`
+ objects
+ and other options will be added into the state of the
+ :class:`_schema.Table`, potentially overwriting existing columns
+ and options of the same name.
+
+ As is always the case with :paramref:`_schema.Table.autoload`,
+ :class:`_schema.Column` objects can be specified in the same
+ :class:`_schema.Table`
+ constructor, which will take precedence. Below, the existing
+ table ``mytable`` will be augmented with :class:`_schema.Column`
+ objects
+ both reflected from the database, as well as the given
+ :class:`_schema.Column`
+ named "y"::
+
+ Table("mytable", metadata,
+ Column('y', Integer),
+ extend_existing=True,
+ autoload_with=engine
+ )
+
+ .. seealso::
+
+ :paramref:`_schema.Table.autoload`
+
+ :paramref:`_schema.Table.autoload_replace`
+
+ :paramref:`_schema.Table.keep_existing`
+
+
+ :param implicit_returning: True by default - indicates that
+ RETURNING can be used by default to fetch newly inserted primary key
+ values, for backends which support this. Note that
+ :func:`_sa.create_engine` also provides an ``implicit_returning``
+ flag.
+
+ :param include_columns: A list of strings indicating a subset of
+ columns to be loaded via the ``autoload`` operation; table columns that
+ aren't present in this list will not be represented on the resulting
+ ``Table`` object. Defaults to ``None`` which indicates all columns
+ should be reflected.
+
+ :param resolve_fks: Whether or not to reflect :class:`_schema.Table`
+ objects
+ related to this one via :class:`_schema.ForeignKey` objects, when
+ :paramref:`_schema.Table.autoload` or
+ :paramref:`_schema.Table.autoload_with` is
+ specified. Defaults to True. Set to False to disable reflection of
+ related tables as :class:`_schema.ForeignKey`
+ objects are encountered; may be
+ used either to save on SQL calls or to avoid issues with related tables
+ that can't be accessed. Note that if a related table is already present
+ in the :class:`_schema.MetaData` collection, or becomes present later,
+ a
+ :class:`_schema.ForeignKey` object associated with this
+ :class:`_schema.Table` will
+ resolve to that table normally.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :paramref:`.MetaData.reflect.resolve_fks`
+
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ :param keep_existing: When ``True``, indicates that if this Table
+ is already present in the given :class:`_schema.MetaData`, ignore
+ further arguments within the constructor to the existing
+ :class:`_schema.Table`, and return the :class:`_schema.Table`
+ object as
+ originally created. This is to allow a function that wishes
+ to define a new :class:`_schema.Table` on first call, but on
+ subsequent calls will return the same :class:`_schema.Table`,
+ without any of the declarations (particularly constraints)
+ being applied a second time.
+
+ If :paramref:`_schema.Table.extend_existing` or
+ :paramref:`_schema.Table.keep_existing` are not set,
+ and the given name
+ of the new :class:`_schema.Table` refers to a :class:`_schema.Table`
+ that is
+ already present in the target :class:`_schema.MetaData` collection,
+ and
+ this :class:`_schema.Table`
+ specifies additional columns or other constructs
+ or flags that modify the table's state, an
+ error is raised. The purpose of these two mutually-exclusive flags
+ is to specify what action should be taken when a
+ :class:`_schema.Table`
+ is specified that matches an existing :class:`_schema.Table`,
+ yet specifies
+ additional constructs.
+
+ .. seealso::
+
+ :paramref:`_schema.Table.extend_existing`
+
+ :param listeners: A list of tuples of the form ``(<eventname>, <fn>)``
+ which will be passed to :func:`.event.listen` upon construction.
+ This alternate hook to :func:`.event.listen` allows the establishment
+ of a listener function specific to this :class:`_schema.Table` before
+ the "autoload" process begins. Historically this has been intended
+ for use with the :meth:`.DDLEvents.column_reflect` event, however
+ note that this event hook may now be associated with the
+ :class:`_schema.MetaData` object directly::
+
+ def listen_for_reflect(table, column_info):
+ "handle the column reflection event"
+ # ...
+
+ t = Table(
+ 'sometable',
+ autoload_with=engine,
+ listeners=[
+ ('column_reflect', listen_for_reflect)
+ ])
+
+ .. seealso::
+
+ :meth:`_events.DDLEvents.column_reflect`
+
+ :param must_exist: When ``True``, indicates that this Table must already
+ be present in the given :class:`_schema.MetaData` collection, else
+ an exception is raised.
+
+ :param prefixes:
+ A list of strings to insert after CREATE in the CREATE TABLE
+ statement. They will be separated by spaces.
+
+ :param quote: Force quoting of this table's name on or off, corresponding
+ to ``True`` or ``False``. When left at its default of ``None``,
+ the table identifier will be quoted according to whether the name is
+ case sensitive (identifiers with at least one upper case character are
+ treated as case sensitive), or if it's a reserved word. This flag
+ is only needed to force quoting of a reserved word which is not known
+ by the SQLAlchemy dialect.
+
+ .. note:: setting this flag to ``False`` will not provide
+ case-insensitive behavior for table reflection; table reflection
+ will always search for a mixed-case name in a case sensitive
+ fashion. Case insensitive names are specified in SQLAlchemy only
+ by stating the name with all lower case characters.
+
+ :param quote_schema: same as 'quote' but applies to the schema identifier.
+
+ :param schema: The schema name for this table, which is required if
+ the table resides in a schema other than the default selected schema
+ for the engine's database connection. Defaults to ``None``.
+
+ If the owning :class:`_schema.MetaData` of this :class:`_schema.Table`
+ specifies its
+ own :paramref:`_schema.MetaData.schema` parameter,
+ then that schema name will
+ be applied to this :class:`_schema.Table`
+ if the schema parameter here is set
+ to ``None``. To set a blank schema name on a :class:`_schema.Table`
+ that
+ would otherwise use the schema set on the owning
+ :class:`_schema.MetaData`,
+ specify the special symbol :attr:`.BLANK_SCHEMA`.
+
+ .. versionadded:: 1.0.14 Added the :attr:`.BLANK_SCHEMA` symbol to
+ allow a :class:`_schema.Table`
+ to have a blank schema name even when the
+ parent :class:`_schema.MetaData` specifies
+ :paramref:`_schema.MetaData.schema`.
+
+ The quoting rules for the schema name are the same as those for the
+ ``name`` parameter, in that quoting is applied for reserved words or
+ case-sensitive names; to enable unconditional quoting for the schema
+ name, specify the flag ``quote_schema=True`` to the constructor, or use
+ the :class:`.quoted_name` construct to specify the name.
+
+ :param comment: Optional string that will render an SQL comment on table
+ creation.
+
+ .. versionadded:: 1.2 Added the :paramref:`_schema.Table.comment`
+ parameter
+ to :class:`_schema.Table`.
+
+ :param \**kw: Additional keyword arguments not mentioned above are
+ dialect specific, and passed in the form ``<dialectname>_<argname>``.
+ See the documentation regarding an individual dialect at
+ :ref:`dialect_toplevel` for detail on documented arguments.
+
+ """
+
+ __visit_name__ = "table"
+
+ constraints = None
+ """A collection of all :class:`_schema.Constraint` objects associated with
+ this :class:`_schema.Table`.
+
+ Includes :class:`_schema.PrimaryKeyConstraint`,
+ :class:`_schema.ForeignKeyConstraint`, :class:`_schema.UniqueConstraint`,
+ :class:`_schema.CheckConstraint`. A separate collection
+ :attr:`_schema.Table.foreign_key_constraints` refers to the collection
+ of all :class:`_schema.ForeignKeyConstraint` objects, and the
+ :attr:`_schema.Table.primary_key` attribute refers to the single
+ :class:`_schema.PrimaryKeyConstraint` associated with the
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.constraints`
+
+ :attr:`_schema.Table.primary_key`
+
+ :attr:`_schema.Table.foreign_key_constraints`
+
+ :attr:`_schema.Table.indexes`
+
+ :class:`_reflection.Inspector`
+
+
+ """
+
+ indexes = None
+ """A collection of all :class:`_schema.Index` objects associated with this
+ :class:`_schema.Table`.
+
+ .. seealso::
+
+ :meth:`_reflection.Inspector.get_indexes`
+
+ """
+
+ _traverse_internals = TableClause._traverse_internals + [
+ ("schema", InternalTraversal.dp_string)
+ ]
+
+ def _gen_cache_key(self, anon_map, bindparams):
+ if self._annotations:
+ return (self,) + self._annotations_cache_key
+ else:
+ return (self,)
+
+ @util.deprecated_params(
+ mustexist=(
+ "1.4",
+ "Deprecated alias of :paramref:`_schema.Table.must_exist`",
+ ),
+ autoload=(
+ "2.0",
+ "The autoload parameter is deprecated and will be removed in "
+ "version 2.0. Please use the "
+ "autoload_with parameter, passing an engine or connection.",
+ ),
+ )
+ def __new__(cls, *args, **kw):
+ if not args and not kw:
+ # python3k pickle seems to call this
+ return object.__new__(cls)
+
+ try:
+ name, metadata, args = args[0], args[1], args[2:]
+ except IndexError:
+ raise TypeError(
+ "Table() takes at least two positional-only "
+ "arguments 'name' and 'metadata'"
+ )
+
+ schema = kw.get("schema", None)
+ if schema is None:
+ schema = metadata.schema
+ elif schema is BLANK_SCHEMA:
+ schema = None
+ keep_existing = kw.get("keep_existing", False)
+ extend_existing = kw.get("extend_existing", False)
+
+ if keep_existing and extend_existing:
+ msg = "keep_existing and extend_existing are mutually exclusive."
+ raise exc.ArgumentError(msg)
+
+ must_exist = kw.pop("must_exist", kw.pop("mustexist", False))
+ key = _get_table_key(name, schema)
+ if key in metadata.tables:
+ if not keep_existing and not extend_existing and bool(args):
+ raise exc.InvalidRequestError(
+ "Table '%s' is already defined for this MetaData "
+ "instance. Specify 'extend_existing=True' "
+ "to redefine "
+ "options and columns on an "
+ "existing Table object." % key
+ )
+ table = metadata.tables[key]
+ if extend_existing:
+ table._init_existing(*args, **kw)
+ return table
+ else:
+ if must_exist:
+ raise exc.InvalidRequestError("Table '%s' not defined" % (key))
+ table = object.__new__(cls)
+ table.dispatch.before_parent_attach(table, metadata)
+ metadata._add_table(name, schema, table)
+ try:
+ table._init(name, metadata, *args, **kw)
+ table.dispatch.after_parent_attach(table, metadata)
+ return table
+ except Exception:
+ with util.safe_reraise():
+ metadata._remove_table(name, schema)
+
+ def __init__(self, *args, **kw):
+ """Constructor for :class:`_schema.Table`.
+
+ This method is a no-op. See the top-level
+ documentation for :class:`_schema.Table`
+ for constructor arguments.
+
+ """
+ # __init__ is overridden to prevent __new__ from
+ # calling the superclass constructor.
+
+ def _init(self, name, metadata, *args, **kwargs):
+ super(Table, self).__init__(
+ quoted_name(name, kwargs.pop("quote", None))
+ )
+ self.metadata = metadata
+
+ self.schema = kwargs.pop("schema", None)
+ if self.schema is None:
+ self.schema = metadata.schema
+ elif self.schema is BLANK_SCHEMA:
+ self.schema = None
+ else:
+ quote_schema = kwargs.pop("quote_schema", None)
+ self.schema = quoted_name(self.schema, quote_schema)
+
+ self.indexes = set()
+ self.constraints = set()
+ PrimaryKeyConstraint(
+ _implicit_generated=True
+ )._set_parent_with_dispatch(self)
+ self.foreign_keys = set()
+ self._extra_dependencies = set()
+ if self.schema is not None:
+ self.fullname = "%s.%s" % (self.schema, self.name)
+ else:
+ self.fullname = self.name
+
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
+ # this argument is only used with _init_existing()
+ kwargs.pop("autoload_replace", True)
+ keep_existing = kwargs.pop("keep_existing", False)
+ extend_existing = kwargs.pop("extend_existing", False)
+ _extend_on = kwargs.pop("_extend_on", None)
+
+ resolve_fks = kwargs.pop("resolve_fks", True)
+ include_columns = kwargs.pop("include_columns", None)
+
+ self.implicit_returning = kwargs.pop("implicit_returning", True)
+
+ self.comment = kwargs.pop("comment", None)
+
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+ if "listeners" in kwargs:
+ listeners = kwargs.pop("listeners")
+ for evt, fn in listeners:
+ event.listen(self, evt, fn)
+
+ self._prefixes = kwargs.pop("prefixes", None) or []
+
+ self._extra_kwargs(**kwargs)
+
+ # load column definitions from the database if 'autoload' is defined
+ # we do it after the table is in the singleton dictionary to support
+ # circular foreign keys
+ if autoload:
+ self._autoload(
+ metadata,
+ autoload_with,
+ include_columns,
+ _extend_on=_extend_on,
+ resolve_fks=resolve_fks,
+ )
+
+ # initialize all the column, etc. objects. done after reflection to
+ # allow user-overrides
+
+ self._init_items(
+ *args,
+ allow_replacements=extend_existing or keep_existing or autoload
+ )
+
+ def _autoload(
+ self,
+ metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns=(),
+ resolve_fks=True,
+ _extend_on=None,
+ ):
+ if autoload_with is None:
+ autoload_with = _bind_or_error(
+ metadata,
+ msg="No engine is bound to this Table's MetaData. "
+ "Pass an engine to the Table via "
+ "autoload_with=<someengine_or_connection>",
+ )
+
+ insp = inspection.inspect(autoload_with)
+ with insp._inspection_context() as conn_insp:
+ conn_insp.reflect_table(
+ self,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on=_extend_on,
+ )
+
+ @property
+ def _sorted_constraints(self):
+ """Return the set of constraints as a list, sorted by creation
+ order.
+
+ """
+ return sorted(self.constraints, key=lambda c: c._creation_order)
+
+ @property
+ def foreign_key_constraints(self):
+ """:class:`_schema.ForeignKeyConstraint` objects referred to by this
+ :class:`_schema.Table`.
+
+ This list is produced from the collection of
+ :class:`_schema.ForeignKey`
+ objects currently associated.
+
+
+ .. seealso::
+
+ :attr:`_schema.Table.constraints`
+
+ :attr:`_schema.Table.foreign_keys`
+
+ :attr:`_schema.Table.indexes`
+
+ """
+ return set(fkc.constraint for fkc in self.foreign_keys)
+
+ def _init_existing(self, *args, **kwargs):
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
+ autoload_replace = kwargs.pop("autoload_replace", True)
+ schema = kwargs.pop("schema", None)
+ _extend_on = kwargs.pop("_extend_on", None)
+ # these arguments are only used with _init()
+ kwargs.pop("extend_existing", False)
+ kwargs.pop("keep_existing", False)
+
+ if schema and schema != self.schema:
+ raise exc.ArgumentError(
+ "Can't change schema of existing table from '%s' to '%s'"
+ % (self.schema, schema)
+ )
+
+ include_columns = kwargs.pop("include_columns", None)
+ if include_columns is not None:
+ for c in self.c:
+ if c.name not in include_columns:
+ self._columns.remove(c)
+
+ resolve_fks = kwargs.pop("resolve_fks", True)
+
+ for key in ("quote", "quote_schema"):
+ if key in kwargs:
+ raise exc.ArgumentError(
+ "Can't redefine 'quote' or 'quote_schema' arguments"
+ )
+
+ # update `self` with these kwargs, if provided
+ self.comment = kwargs.pop("comment", self.comment)
+ self.implicit_returning = kwargs.pop(
+ "implicit_returning", self.implicit_returning
+ )
+ self.info = kwargs.pop("info", self.info)
+
+ if autoload:
+ if not autoload_replace:
+ # don't replace columns already present.
+ # we'd like to do this for constraints also however we don't
+ # have simple de-duping for unnamed constraints.
+ exclude_columns = [c.name for c in self.c]
+ else:
+ exclude_columns = ()
+ self._autoload(
+ self.metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns,
+ resolve_fks,
+ _extend_on=_extend_on,
+ )
+
+ self._extra_kwargs(**kwargs)
+ self._init_items(*args)
+
+ def _extra_kwargs(self, **kwargs):
+ self._validate_dialect_kwargs(kwargs)
+
+ def _init_collections(self):
+ pass
+
+ def _reset_exported(self):
+ pass
+
+ @property
+ def _autoincrement_column(self):
+ return self.primary_key._autoincrement_column
+
+ @property
+ def key(self):
+ """Return the 'key' for this :class:`_schema.Table`.
+
+ This value is used as the dictionary key within the
+ :attr:`_schema.MetaData.tables` collection. It is typically the same
+ as that of :attr:`_schema.Table.name` for a table with no
+ :attr:`_schema.Table.schema`
+ set; otherwise it is typically of the form
+ ``schemaname.tablename``.
+
+ """
+ return _get_table_key(self.name, self.schema)
+
+ def __repr__(self):
+ return "Table(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.metadata)]
+ + [repr(x) for x in self.columns]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]]
+ )
+
+ def __str__(self):
+ return _get_table_key(self.description, self.schema)
+
+ @property
+ def bind(self):
+ """Return the connectable associated with this Table."""
+
+ return self.metadata and self.metadata.bind or None
+
+ def add_is_dependent_on(self, table):
+ """Add a 'dependency' for this Table.
+
+ This is another Table object which must be created
+ first before this one can, or dropped after this one.
+
+ Usually, dependencies between tables are determined via
+ ForeignKey objects. However, for other situations that
+ create dependencies outside of foreign keys (rules, inheriting),
+ this method can manually establish such a link.
+
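+ E.g., a minimal sketch assuming ``accounts`` and ``audit_log`` are
+ both :class:`_schema.Table` objects in the same
+ :class:`_schema.MetaData`::
+
+ # ensure "accounts" is created before "audit_log" and
+ # dropped after it
+ audit_log.add_is_dependent_on(accounts)
+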
+ """
+ self._extra_dependencies.add(table)
+
+ def append_column(self, column, replace_existing=False):
+ """Append a :class:`_schema.Column` to this :class:`_schema.Table`.
+
+ The "key" of the newly added :class:`_schema.Column`, i.e. the
+ value of its ``.key`` attribute, will then be available
+ in the ``.c`` collection of this :class:`_schema.Table`, and the
+ column definition will be included in any CREATE TABLE, SELECT,
+ UPDATE, etc. statements generated from this :class:`_schema.Table`
+ construct.
+
+ Note that this does **not** change the definition of the table
+ as it exists within any underlying database, assuming that
+ table has already been created in the database. Relational
+ databases support the addition of columns to existing tables
+ using the SQL ALTER command, which would need to be
+ emitted for an already-existing table that doesn't contain
+ the newly added column.
+
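+ E.g., a minimal sketch assuming ``mytable`` is an existing
+ :class:`_schema.Table`::
+
+ from sqlalchemy import Column, Integer
+
+ mytable.append_column(Column("extra_data", Integer))
+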
+ :param replace_existing: When ``True``, allows replacing existing
+ columns. When ``False``, the default, a warning will be raised
+ if a column with the same ``.key`` already exists. A future
+ version of SQLAlchemy will instead raise an error.
+
+ .. versionadded:: 1.4.0
+ """
+
+ column._set_parent_with_dispatch(
+ self, allow_replacements=replace_existing
+ )
+
+ def append_constraint(self, constraint):
+ """Append a :class:`_schema.Constraint` to this
+ :class:`_schema.Table`.
+
+ This has the effect of the constraint being included in any
+ future CREATE TABLE statement, assuming specific DDL creation
+ events have not been associated with the given
+ :class:`_schema.Constraint` object.
+
+ Note that this does **not** produce the constraint within the
+ relational database automatically, for a table that already exists
+ in the database. To add a constraint to an
+ existing relational database table, the SQL ALTER command must
+ be used. SQLAlchemy also provides the
+ :class:`.AddConstraint` construct which can produce this SQL when
+ invoked as an executable clause.
+
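+ E.g., a minimal sketch assuming ``mytable`` has a ``name`` column::
+
+ from sqlalchemy import UniqueConstraint
+
+ mytable.append_constraint(UniqueConstraint("name"))
+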
+ """
+
+ constraint._set_parent_with_dispatch(self)
+
+ def _set_parent(self, metadata, **kw):
+ metadata._add_table(self.name, self.schema, self)
+ self.metadata = metadata
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Table.exists` method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_reflection.Inspector.has_table`.",
+ )
+ def exists(self, bind=None):
+ """Return True if this table exists."""
+
+ if bind is None:
+ bind = _bind_or_error(self)
+
+ insp = inspection.inspect(bind)
+ return insp.has_table(self.name, schema=self.schema)
+
+ def create(self, bind=None, checkfirst=False):
+ """Issue a ``CREATE`` statement for this
+ :class:`_schema.Table`, using the given :class:`.Connectable`
+ for connectivity.
+
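+ E.g., a minimal sketch assuming ``mytable`` is this
+ :class:`_schema.Table` and ``engine`` is an :class:`_engine.Engine`::
+
+ mytable.create(engine, checkfirst=True)
+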
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.create_all`.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=False):
+ """Issue a ``DROP`` statement for this
+ :class:`_schema.Table`, using the given :class:`.Connectable`
+ for connectivity.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.drop_all`.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ @util.deprecated(
+ "1.4",
+ ":meth:`_schema.Table.tometadata` is renamed to "
+ ":meth:`_schema.Table.to_metadata`",
+ )
+ def tometadata(
+ self,
+ metadata,
+ schema=RETAIN_SCHEMA,
+ referred_schema_fn=None,
+ name=None,
+ ):
+ """Return a copy of this :class:`_schema.Table`
+ associated with a different
+ :class:`_schema.MetaData`.
+
+ See :meth:`_schema.Table.to_metadata` for a full description.
+
+ """
+ return self.to_metadata(
+ metadata,
+ schema=schema,
+ referred_schema_fn=referred_schema_fn,
+ name=name,
+ )
+
+ def to_metadata(
+ self,
+ metadata,
+ schema=RETAIN_SCHEMA,
+ referred_schema_fn=None,
+ name=None,
+ ):
+ """Return a copy of this :class:`_schema.Table` associated with a
+ different :class:`_schema.MetaData`.
+
+ E.g.::
+
+ m1 = MetaData()
+
+ user = Table('user', m1, Column('id', Integer, primary_key=True))
+
+ m2 = MetaData()
+ user_copy = user.to_metadata(m2)
+
+ .. versionchanged:: 1.4 The :meth:`_schema.Table.to_metadata` function
+ was renamed from :meth:`_schema.Table.tometadata`.
+
+
+ :param metadata: Target :class:`_schema.MetaData` object,
+ into which the
+ new :class:`_schema.Table` object will be created.
+
+ :param schema: optional string name indicating the target schema.
+ Defaults to the special symbol :attr:`.RETAIN_SCHEMA` which indicates
+ that no change to the schema name should be made in the new
+ :class:`_schema.Table`. If set to a string name, the new
+ :class:`_schema.Table`
+ will have this new name as the ``.schema``. If set to ``None``, the
+ schema will be set to that of the schema set on the target
+ :class:`_schema.MetaData`, which is typically ``None`` as well,
+ unless
+ set explicitly::
+
+ m2 = MetaData(schema='newschema')
+
+ # user_copy_one will have "newschema" as the schema name
+ user_copy_one = user.to_metadata(m2, schema=None)
+
+ m3 = MetaData() # schema defaults to None
+
+ # user_copy_two will have None as the schema name
+ user_copy_two = user.to_metadata(m3, schema=None)
+
+ :param referred_schema_fn: optional callable which can be supplied
+ in order to provide for the schema name that should be assigned
+ to the referenced table of a :class:`_schema.ForeignKeyConstraint`.
+ The callable accepts this parent :class:`_schema.Table`, the
+ target schema that we are changing to, the
+ :class:`_schema.ForeignKeyConstraint` object, and the existing
+ "target schema" of that constraint. The function should return the
+ string schema name that should be applied. To reset the schema
+ to "none", return the symbol :data:`.BLANK_SCHEMA`. To effect no
+ change, return ``None`` or :data:`.RETAIN_SCHEMA`.
+
+ .. versionchanged:: 1.4.33 The ``referred_schema_fn`` function
+ may return the :data:`.BLANK_SCHEMA` or :data:`.RETAIN_SCHEMA`
+ symbols.
+
+ E.g.::
+
+ def referred_schema_fn(table, to_schema,
+ constraint, referred_schema):
+ if referred_schema == 'base_tables':
+ return referred_schema
+ else:
+ return to_schema
+
+ new_table = table.to_metadata(m2, schema="alt_schema",
+ referred_schema_fn=referred_schema_fn)
+
+ .. versionadded:: 0.9.2
+
+ :param name: optional string name indicating the target table name.
+ If not specified or None, the table name is retained. This allows
+ a :class:`_schema.Table` to be copied to the same
+ :class:`_schema.MetaData` target
+ with a new name.
+
+ .. versionadded:: 1.0.0
+
+ """
+ if name is None:
+ name = self.name
+ if schema is RETAIN_SCHEMA:
+ schema = self.schema
+ elif schema is None:
+ schema = metadata.schema
+ key = _get_table_key(name, schema)
+ if key in metadata.tables:
+ util.warn(
+ "Table '%s' already exists within the given "
+ "MetaData - not copying." % self.description
+ )
+ return metadata.tables[key]
+
+ args = []
+ for c in self.columns:
+ args.append(c._copy(schema=schema))
+ table = Table(
+ name,
+ metadata,
+ schema=schema,
+ comment=self.comment,
+ *args,
+ **self.kwargs
+ )
+ for c in self.constraints:
+ if isinstance(c, ForeignKeyConstraint):
+ referred_schema = c._referred_schema
+ if referred_schema_fn:
+ fk_constraint_schema = referred_schema_fn(
+ self, schema, c, referred_schema
+ )
+ else:
+ fk_constraint_schema = (
+ schema if referred_schema == self.schema else None
+ )
+ table.append_constraint(
+ c._copy(schema=fk_constraint_schema, target_table=table)
+ )
+ elif not c._type_bound:
+ # skip unique constraints that would be generated
+ # by the 'unique' flag on Column
+ if c._column_flag:
+ continue
+
+ table.append_constraint(
+ c._copy(schema=schema, target_table=table)
+ )
+ for index in self.indexes:
+ # skip indexes that would be generated
+ # by the 'index' flag on Column
+ if index._column_flag:
+ continue
+ Index(
+ index.name,
+ unique=index.unique,
+ *[
+ _copy_expression(expr, self, table)
+ for expr in index.expressions
+ ],
+ _table=table,
+ **index.kwargs
+ )
+ return self._schema_item_copy(table)
+
+
+class Column(DialectKWArgs, SchemaItem, ColumnClause):
+ """Represents a column in a database table."""
+
+ __visit_name__ = "column"
+
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ r"""
+ Construct a new ``Column`` object.
+
+ :param name: The name of this column as represented in the database.
+ This argument may be the first positional argument, or specified
+ via keyword.
+
+ Names which contain no upper case characters
+ will be treated as case insensitive names, and will not be quoted
+ unless they are a reserved word. Names with any number of upper
+ case characters will be quoted and sent exactly. Note that this
+ behavior applies even for databases which standardize upper
+ case names as case insensitive such as Oracle.
+
+ The name field may be omitted at construction time and applied
+ later, at any time before the Column is associated with a
+ :class:`_schema.Table`. This is to support convenient
+ usage within the :mod:`~sqlalchemy.ext.declarative` extension.
+
+ :param type\_: The column's type, indicated using an instance which
+ subclasses :class:`~sqlalchemy.types.TypeEngine`. If no arguments
+ are required for the type, the class of the type can be sent
+ as well, e.g.::
+
+ # use a type with arguments
+ Column('data', String(50))
+
+ # use no arguments
+ Column('level', Integer)
+
+ The ``type`` argument may be the second positional argument
+ or specified by keyword.
+
+ If the ``type`` is ``None`` or is omitted, it will first default to
+ the special type :class:`.NullType`. If and when this
+ :class:`_schema.Column` is made to refer to another column using
+ :class:`_schema.ForeignKey` and/or
+ :class:`_schema.ForeignKeyConstraint`, the type
+ of the remote-referenced column will be copied to this column as
+ well, at the moment that the foreign key is resolved against that
+ remote :class:`_schema.Column` object.
+
+ .. versionchanged:: 0.9.0
+ Support for propagation of type to a :class:`_schema.Column`
+ from its
+ :class:`_schema.ForeignKey` object has been improved and should be
+ more reliable and timely.
+
+ :param \*args: Additional positional arguments include various
+ :class:`.SchemaItem` derived constructs which will be applied
+ as options to the column. These include instances of
+ :class:`.Constraint`, :class:`_schema.ForeignKey`,
+ :class:`.ColumnDefault`, :class:`.Sequence`, :class:`.Computed`
+ and :class:`.Identity`. In some cases an
+ equivalent keyword argument is available such as ``server_default``,
+ ``default`` and ``unique``.
+
+ :param autoincrement: Set up "auto increment" semantics for an
+ **integer primary key column with no foreign key dependencies**
+ (see later in this docstring for a more specific definition).
+ This may influence the :term:`DDL` that will be emitted for
+ this column during a table create, as well as how the column
+ will be considered when INSERT statements are compiled and
+ executed.
+
+ The default value is the string ``"auto"``,
+ which indicates that a single-column (i.e. non-composite) primary key
+ that is of an INTEGER type with no other client-side or server-side
+ default constructs indicated should receive auto increment semantics
+ automatically. Other values include ``True`` (force this column to
+ have auto-increment semantics for a :term:`composite primary key` as
+ well), ``False`` (this column should never have auto-increment
+ semantics), and the string ``"ignore_fk"`` (special-case for foreign
+ key columns, see below).
+
+ The term "auto increment semantics" refers both to the kind of DDL
+ that will be emitted for the column within a CREATE TABLE statement,
+ when methods such as :meth:`.MetaData.create_all` and
+ :meth:`.Table.create` are invoked, as well as how the column will be
+ considered when an INSERT statement is compiled and emitted to the
+ database:
+
+ * **DDL rendering** (i.e. :meth:`.MetaData.create_all`,
+ :meth:`.Table.create`): When used on a :class:`.Column` that has
+ no other
+ default-generating construct associated with it (such as a
+ :class:`.Sequence` or :class:`.Identity` construct), the parameter
+ will imply that database-specific keywords such as PostgreSQL
+ ``SERIAL``, MySQL ``AUTO_INCREMENT``, or ``IDENTITY`` on SQL Server
+ should also be rendered. Not every database backend has an
+ "implied" default generator available; for example the Oracle
+ backend always needs an explicit construct such as
+ :class:`.Identity` to be included with a :class:`.Column` in order
+ for the DDL rendered to include auto-generating constructs to also
+ be produced in the database.
+
+ * **INSERT semantics** (i.e. when a :func:`_sql.insert` construct is
+ compiled into a SQL string and is then executed on a database using
+ :meth:`_engine.Connection.execute` or equivalent): A single-row
+ INSERT statement will be known to produce a new integer primary key
+ value automatically for this column, which will be accessible
+ after the statement is invoked via the
+ :attr:`.CursorResult.inserted_primary_key` attribute upon the
+ :class:`_result.Result` object. This also applies towards use of the
+ ORM when ORM-mapped objects are persisted to the database,
+ indicating that a new integer primary key will be available to
+ become part of the :term:`identity key` for that object. This
+ behavior takes place regardless of what DDL constructs are
+ associated with the :class:`_schema.Column` and is independent
+ of the "DDL Rendering" behavior discussed in the previous note
+ above.
+
+ The parameter may be set to ``True`` to indicate that a column which
+ is part of a composite (i.e. multi-column) primary key should
+ have autoincrement semantics, though note that only one column
+ within a primary key may have this setting. It can also
+ be set to ``True`` to indicate autoincrement semantics on a
+ column that has a client-side or server-side default configured;
+ note, however, that not all dialects can accommodate all styles
+ of default as an "autoincrement". It can also be
+ set to ``False`` on a single-column primary key that has a
+ datatype of INTEGER in order to disable auto increment semantics
+ for that column.
+
+ .. versionchanged:: 1.1 The autoincrement flag now defaults to
+ ``"auto"`` which indicates autoincrement semantics by default
+ for single-column integer primary keys only; for composite
+ (multi-column) primary keys, autoincrement is never implicitly
+ enabled; as always, ``autoincrement=True`` will allow for
+ at most one of those columns to be an "autoincrement" column.
+ ``autoincrement=True`` may also be set on a
+ :class:`_schema.Column`
+ that has an explicit client-side or server-side default,
+ subject to limitations of the backend database and dialect.
+
+ The setting *only* has an effect for columns which are:
+
+ * Integer derived (i.e. INT, SMALLINT, BIGINT).
+
+ * Part of the primary key
+
+ * Not referring to another column via :class:`_schema.ForeignKey`,
+ unless
+ the value is specified as ``'ignore_fk'``::
+
+ # turn on autoincrement for this column despite
+ # the ForeignKey()
+ Column('id', ForeignKey('other.id'),
+ primary_key=True, autoincrement='ignore_fk')
+
+ It is typically not desirable to have "autoincrement" enabled on a
+ column that refers to another via foreign key, as such a column is
+ required to refer to a value that originates from elsewhere.
+
+ The setting has these effects on columns that meet the
+ above criteria:
+
+ * DDL issued for the column, if the column does not already include
+ a default generating construct supported by the backend such as
+ :class:`.Identity`, will include database-specific
+ keywords intended to signify this column as an
+ "autoincrement" column for specific backends. Behavior for
+ primary SQLAlchemy dialects includes:
+
+ * AUTO_INCREMENT on MySQL and MariaDB
+ * SERIAL on PostgreSQL
+ * IDENTITY on MS-SQL - this occurs even without the
+ :class:`.Identity` construct as the
+ :paramref:`.Column.autoincrement` parameter pre-dates this
+ construct.
+ * SQLite - SQLite integer primary key columns are implicitly
+ "auto incrementing" and no additional keywords are rendered;
+ the special SQLite keyword ``AUTOINCREMENT`` is not rendered,
+ as it is unnecessary and not recommended by the database
+ vendor. See the section :ref:`sqlite_autoincrement` for more
+ background.
+ * Oracle - The Oracle dialect has no default "autoincrement"
+ feature available at this time, instead the :class:`.Identity`
+ construct is recommended to achieve this (the :class:`.Sequence`
+ construct may also be used).
+ * Third-party dialects - consult those dialects' documentation
+ for details on their specific behaviors.
+
+ * When a single-row :func:`_sql.insert` construct is compiled and
+ executed, which does not set the :meth:`_sql.Insert.inline`
+ modifier, newly generated primary key values for this column
+ will be automatically retrieved upon statement execution
+ using a method specific to the database driver in use (a brief
+ example follows this list):
+
+ * MySQL, SQLite - consulting the ``cursor.lastrowid`` attribute
+ (see
+ `https://www.python.org/dev/peps/pep-0249/#lastrowid
+ <https://www.python.org/dev/peps/pep-0249/#lastrowid>`_)
+ * PostgreSQL, SQL Server, Oracle - use RETURNING or an equivalent
+ construct when rendering an INSERT statement, and then retrieving
+ the newly generated primary key values after execution
+ * PostgreSQL, Oracle for :class:`_schema.Table` objects that
+ set :paramref:`_schema.Table.implicit_returning` to False -
+ for a :class:`.Sequence` only, the :class:`.Sequence` is invoked
+ explicitly before the INSERT statement takes place so that the
+ newly generated primary key value is available to the client
+ * SQL Server for :class:`_schema.Table` objects that
+ set :paramref:`_schema.Table.implicit_returning` to False -
+ the ``SELECT scope_identity()`` construct is used after the
+ INSERT statement is invoked to retrieve the newly generated
+ primary key value.
+ * Third-party dialects - consult those dialects' documentation
+ for details on their specific behaviors.
+
+ * For multiple-row :func:`_sql.insert` constructs invoked with
+ a list of parameters (i.e. "executemany" semantics), primary-key
+ retrieving behaviors are generally disabled; however, there may
+ be special APIs that can be used to retrieve lists of new
+ primary key values for an "executemany", such as the psycopg2
+ "fast insertmany" feature. Such features are very new and
+ may not yet be well covered in documentation.
+
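+ Putting the above together, a brief sketch of retrieving the newly
+ generated primary key after a single-row INSERT (the ``user_table``
+ and ``engine`` names are illustrative only)::
+
+ stmt = user_table.insert().values(name="some name")
+ with engine.begin() as conn:
+     result = conn.execute(stmt)
+     # the new integer primary key value, as described above
+     new_id = result.inserted_primary_key[0]
+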
+ :param default: A scalar, Python callable, or
+ :class:`_expression.ColumnElement` expression representing the
+ *default value* for this column, which will be invoked upon insert
+ if this column is otherwise not specified in the VALUES clause of
+ the insert. This is a shortcut to using :class:`.ColumnDefault` as
+ a positional argument; see that class for full detail on the
+ structure of the argument.
+
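+ For example, a minimal sketch of a scalar and of a Python-callable
+ default (the column names are illustrative only)::
+
+ from datetime import datetime
+
+ # scalar default, used as-is for each INSERT
+ Column('status', String(20), default="pending")
+
+ # Python callable, invoked for each INSERT
+ Column('created_at', DateTime, default=datetime.utcnow)
+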
+ Contrast this argument to
+ :paramref:`_schema.Column.server_default`
+ which creates a default generator on the database side.
+
+ .. seealso::
+
+ :ref:`metadata_defaults_toplevel`
+
+ :param doc: optional String that can be used by the ORM or similar
+ to document attributes on the Python side. This attribute does
+ **not** render SQL comments; use the
+ :paramref:`_schema.Column.comment`
+ parameter for this purpose.
+
+ :param key: An optional string identifier which will identify this
+ ``Column`` object on the :class:`_schema.Table`.
+ When a key is provided,
+ this is the only identifier referencing the ``Column`` within the
+ application, including ORM attribute mapping; the ``name`` field
+ is used only when rendering SQL.
+
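+ For example, a minimal sketch (names are illustrative only)::
+
+ # addressed as some_table.c.username within the application;
+ # rendered as "user_name" in SQL
+ Column('user_name', String(50), key='username')
+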
+ :param index: When ``True``, indicates that a :class:`_schema.Index`
+ construct will be automatically generated for this
+ :class:`_schema.Column`, which will result in a "CREATE INDEX"
+ statement being emitted for the :class:`_schema.Table` when the DDL
+ create operation is invoked.
+
+ Using this flag is equivalent to making use of the
+ :class:`_schema.Index` construct explicitly at the level of the
+ :class:`_schema.Table` construct itself::
+
+ Table(
+ "some_table",
+ metadata,
+ Column("x", Integer),
+ Index("ix_some_table_x", "x")
+ )
+
+ To add the :paramref:`_schema.Index.unique` flag to the
+ :class:`_schema.Index`, set both the
+ :paramref:`_schema.Column.unique` and
+ :paramref:`_schema.Column.index` flags to True simultaneously,
+ which will have the effect of rendering the "CREATE UNIQUE INDEX"
+ DDL instruction instead of "CREATE INDEX".
+
+ The name of the index is generated using the
+ :ref:`default naming convention <constraint_default_naming_convention>`
+ which for the :class:`_schema.Index` construct is of the form
+ ``ix_<tablename>_<columnname>``.
+
+ As this flag is intended only as a convenience for the common case
+ of adding a single-column, default configured index to a table
+ definition, explicit use of the :class:`_schema.Index` construct
+ should be preferred for most use cases, including composite indexes
+ that encompass more than one column, indexes with SQL expressions
+ or ordering, backend-specific index configuration options, and
+ indexes that use a specific name.
+
+ .. note:: the :attr:`_schema.Column.index` attribute on
+ :class:`_schema.Column`
+ **does not indicate** if this column is indexed or not, only
+ if this flag was explicitly set here. To view indexes on
+ a column, view the :attr:`_schema.Table.indexes` collection
+ or use :meth:`_reflection.Inspector.get_indexes`.
+
+ .. seealso::
+
+ :ref:`schema_indexes`
+
+ :ref:`constraint_naming_conventions`
+
+ :paramref:`_schema.Column.unique`
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ :param nullable: When set to ``False``, will cause the "NOT NULL"
+ phrase to be added when generating DDL for the column. When
+ ``True``, will normally generate nothing (in SQL this defaults to
+ "NULL"), except in certain backend-specific edge cases
+ where "NULL" may render explicitly.
+ Defaults to ``True`` unless :paramref:`_schema.Column.primary_key`
+ is also ``True`` or the column specifies a :class:`_sql.Identity`,
+ in which case it defaults to ``False``.
+ This parameter is only used when issuing CREATE TABLE statements.
+
+ .. note::
+
+ When the column specifies a :class:`_sql.Identity` this
+ parameter is in general ignored by the DDL compiler. The
+ PostgreSQL database allows a nullable identity column when this
+ parameter is set to ``True`` explicitly.
+
+ :param onupdate: A scalar, Python callable, or
+ :class:`~sqlalchemy.sql.expression.ClauseElement` representing a
+ default value to be applied to the column within UPDATE
+ statements, which will be invoked upon update if this column is not
+ present in the SET clause of the update. This is a shortcut to
+ using :class:`.ColumnDefault` as a positional argument with
+ ``for_update=True``.
+
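+ For example, a minimal sketch (the column name is illustrative
+ only)::
+
+ from datetime import datetime
+
+ # invoked for each UPDATE that does not otherwise set this column
+ Column('updated_at', DateTime,
+        default=datetime.utcnow, onupdate=datetime.utcnow)
+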
+ .. seealso::
+
+ :ref:`metadata_defaults` - complete discussion of onupdate
+
+ :param primary_key: If ``True``, marks this column as a primary key
+ column. Multiple columns can have this flag set to specify
+ composite primary keys. As an alternative, the primary key of a
+ :class:`_schema.Table` can be specified via an explicit
+ :class:`.PrimaryKeyConstraint` object.
+
+ :param server_default: A :class:`.FetchedValue` instance, str, Unicode
+ or :func:`~sqlalchemy.sql.expression.text` construct representing
+ the DDL DEFAULT value for the column.
+
+ String types will be emitted as-is, surrounded by single quotes::
+
+ Column('x', Text, server_default="val")
+
+ x TEXT DEFAULT 'val'
+
+ A :func:`~sqlalchemy.sql.expression.text` expression will be
+ rendered as-is, without quotes::
+
+ Column('y', DateTime, server_default=text('NOW()'))
+
+ y DATETIME DEFAULT NOW()
+
+ Strings and text() will be converted into a
+ :class:`.DefaultClause` object upon initialization.
+
+ This parameter can also accept complex combinations of contextually
+ valid SQLAlchemy expressions or constructs::
+
+ from sqlalchemy import create_engine
+ from sqlalchemy import Table, Column, MetaData, ARRAY, Text
+ from sqlalchemy.dialects.postgresql import array
+
+ engine = create_engine(
+ 'postgresql://scott:tiger@localhost/mydatabase'
+ )
+ metadata_obj = MetaData()
+ tbl = Table(
+ "foo",
+ metadata_obj,
+ Column("bar",
+ ARRAY(Text),
+ server_default=array(["biz", "bang", "bash"])
+ )
+ )
+ metadata_obj.create_all(engine)
+
+ The above results in a table created with the following SQL::
+
+ CREATE TABLE foo (
+ bar TEXT[] DEFAULT ARRAY['biz', 'bang', 'bash']
+ )
+
+ Use :class:`.FetchedValue` to indicate that an already-existing
+ column will generate a default value on the database side which
+ will be available to SQLAlchemy for post-fetch after inserts. This
+ construct does not specify any DDL and the implementation is left
+ to the database, such as via a trigger.
+
+ .. seealso::
+
+ :ref:`server_defaults` - complete discussion of server side
+ defaults
+
+ :param server_onupdate: A :class:`.FetchedValue` instance
+ representing a database-side default generation function,
+ such as a trigger. This
+ indicates to SQLAlchemy that a newly generated value will be
+ available after updates. This construct does not actually
+ implement any kind of generation function within the database,
+ which instead must be specified separately.
+
+
+ .. warning:: This directive **does not** currently produce MySQL's
+ "ON UPDATE CURRENT_TIMESTAMP()" clause. See
+ :ref:`mysql_timestamp_onupdate` for background on how to
+ produce this clause.
+
+ .. seealso::
+
+ :ref:`triggered_columns`
+
+ :param quote: Force quoting of this column's name on or off,
+ corresponding to ``True`` or ``False``. When left at its default
+ of ``None``, the column identifier will be quoted according to
+ whether the name is case sensitive (identifiers with at least one
+ upper case character are treated as case sensitive), or if it's a
+ reserved word. This flag is only needed to force quoting of a
+ reserved word which is not known by the SQLAlchemy dialect.
+
+ :param unique: When ``True``, and the :paramref:`_schema.Column.index`
+ parameter is left at its default value of ``False``,
+ indicates that a :class:`_schema.UniqueConstraint`
+ construct will be automatically generated for this
+ :class:`_schema.Column`,
+ which will result in a "UNIQUE CONSTRAINT" clause referring
+ to this column being included
+ in the ``CREATE TABLE`` statement emitted, when the DDL create
+ operation for the :class:`_schema.Table` object is invoked.
+
+ When this flag is ``True`` while the
+ :paramref:`_schema.Column.index` parameter is simultaneously
+ set to ``True``, the effect instead is that a
+ :class:`_schema.Index` construct which includes the
+ :paramref:`_schema.Index.unique` parameter set to ``True``
+ is generated. See the documentation for
+ :paramref:`_schema.Column.index` for additional detail.
+
+ Using this flag is equivalent to making use of the
+ :class:`_schema.UniqueConstraint` construct explicitly at the
+ level of the :class:`_schema.Table` construct itself::
+
+ Table(
+ "some_table",
+ metadata,
+ Column("x", Integer),
+ UniqueConstraint("x")
+ )
+
+ The :paramref:`_schema.UniqueConstraint.name` parameter
+ of the unique constraint object is left at its default value
+ of ``None``; in the absence of a :ref:`naming convention <constraint_naming_conventions>`
+ for the enclosing :class:`_schema.MetaData`, the UNIQUE CONSTRAINT
+ construct will be emitted as unnamed, in which case the database
+ typically applies its own naming convention.
+
+ As this flag is intended only as a convenience for the common case
+ of adding a single-column, default configured unique constraint to
+ a table definition, explicit use of the
+ :class:`_schema.UniqueConstraint` construct should be preferred for
+ most use cases, including composite constraints that encompass more
+ than one column, backend-specific index configuration options, and
+ constraints that use a specific name.
+
+ .. note:: the :attr:`_schema.Column.unique` attribute on
+ :class:`_schema.Column`
+ **does not indicate** if this column has a unique constraint or
+ not, only if this flag was explicitly set here. To view
+ indexes and unique constraints that may involve this column,
+ view the
+ :attr:`_schema.Table.indexes` and/or
+ :attr:`_schema.Table.constraints` collections or use
+ :meth:`_reflection.Inspector.get_indexes` and/or
+ :meth:`_reflection.Inspector.get_unique_constraints`
+
+ .. seealso::
+
+ :ref:`schema_unique_constraint`
+
+ :ref:`constraint_naming_conventions`
+
+ :paramref:`_schema.Column.index`
+
+ :param system: When ``True``, indicates this is a "system" column,
+ that is, a column which is automatically made available by the
+ database, and should not be included in the columns list for a
+ ``CREATE TABLE`` statement.
+
+ For more elaborate scenarios where columns should be
+ conditionally rendered differently on different backends,
+ consider custom compilation rules for :class:`.CreateColumn`.
+
+ :param comment: Optional string that will render an SQL comment on
+ table creation.
+
+ .. versionadded:: 1.2 Added the
+ :paramref:`_schema.Column.comment`
+ parameter to :class:`_schema.Column`.
+
+
+ """ # noqa: E501, RST201, RST202
+
+ name = kwargs.pop("name", None)
+ type_ = kwargs.pop("type_", None)
+ args = list(args)
+ if args:
+ if isinstance(args[0], util.string_types):
+ if name is not None:
+ raise exc.ArgumentError(
+ "May not pass name positionally and as a keyword."
+ )
+ name = args.pop(0)
+ if args:
+ coltype = args[0]
+
+ if hasattr(coltype, "_sqla_type"):
+ if type_ is not None:
+ raise exc.ArgumentError(
+ "May not pass type_ positionally and as a keyword."
+ )
+ type_ = args.pop(0)
+
+ if name is not None:
+ name = quoted_name(name, kwargs.pop("quote", None))
+ elif "quote" in kwargs:
+ raise exc.ArgumentError(
+ "Explicit 'name' is required when " "sending 'quote' argument"
+ )
+
+ super(Column, self).__init__(name, type_)
+ self.key = kwargs.pop("key", name)
+ self.primary_key = primary_key = kwargs.pop("primary_key", False)
+
+ self._user_defined_nullable = udn = kwargs.pop(
+ "nullable", NULL_UNSPECIFIED
+ )
+
+ if udn is not NULL_UNSPECIFIED:
+ self.nullable = udn
+ else:
+ self.nullable = not primary_key
+
+ self.default = kwargs.pop("default", None)
+ self.server_default = kwargs.pop("server_default", None)
+ self.server_onupdate = kwargs.pop("server_onupdate", None)
+
+ # these default to None because .index and .unique are *not*
+ # informational flags about Column - there can still be an
+ # Index or UniqueConstraint referring to this Column.
+ self.index = kwargs.pop("index", None)
+ self.unique = kwargs.pop("unique", None)
+
+ self.system = kwargs.pop("system", False)
+ self.doc = kwargs.pop("doc", None)
+ self.onupdate = kwargs.pop("onupdate", None)
+ self.autoincrement = kwargs.pop("autoincrement", "auto")
+ self.constraints = set()
+ self.foreign_keys = set()
+ self.comment = kwargs.pop("comment", None)
+ self.computed = None
+ self.identity = None
+
+ # check if this Column is proxying another column
+ if "_proxies" in kwargs:
+ self._proxies = kwargs.pop("_proxies")
+ # otherwise, add DDL-related events
+ elif isinstance(self.type, SchemaEventTarget):
+ self.type._set_parent_with_dispatch(self)
+
+ if self.default is not None:
+ if isinstance(self.default, (ColumnDefault, Sequence)):
+ args.append(self.default)
+ else:
+ if getattr(self.type, "_warn_on_bytestring", False):
+ if isinstance(self.default, util.binary_type):
+ util.warn(
+ "Unicode column '%s' has non-unicode "
+ "default value %r specified."
+ % (self.key, self.default)
+ )
+ args.append(ColumnDefault(self.default))
+
+ if self.server_default is not None:
+ if isinstance(self.server_default, FetchedValue):
+ args.append(self.server_default._as_for_update(False))
+ else:
+ args.append(DefaultClause(self.server_default))
+
+ if self.onupdate is not None:
+ if isinstance(self.onupdate, (ColumnDefault, Sequence)):
+ args.append(self.onupdate)
+ else:
+ args.append(ColumnDefault(self.onupdate, for_update=True))
+
+ if self.server_onupdate is not None:
+ if isinstance(self.server_onupdate, FetchedValue):
+ args.append(self.server_onupdate._as_for_update(True))
+ else:
+ args.append(
+ DefaultClause(self.server_onupdate, for_update=True)
+ )
+ self._init_items(*args)
+
+ util.set_creation_order(self)
+
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+
+ self._extra_kwargs(**kwargs)
+
+ foreign_keys = None
+ """A collection of all :class:`_schema.ForeignKey` marker objects
+ associated with this :class:`_schema.Column`.
+
+ Each object is a member of a :class:`_schema.Table`-wide
+ :class:`_schema.ForeignKeyConstraint`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.foreign_keys`
+
+ """
+
+ index = None
+ """The value of the :paramref:`_schema.Column.index` parameter.
+
+ Does not indicate if this :class:`_schema.Column` is actually indexed
+ or not; use :attr:`_schema.Table.indexes`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.indexes`
+ """
+
+ unique = None
+ """The value of the :paramref:`_schema.Column.unique` parameter.
+
+ Does not indicate if this :class:`_schema.Column` is actually subject to
+ a unique constraint or not; use :attr:`_schema.Table.indexes` and
+ :attr:`_schema.Table.constraints`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.indexes`
+
+ :attr:`_schema.Table.constraints`.
+
+ """
+
+ def _extra_kwargs(self, **kwargs):
+ self._validate_dialect_kwargs(kwargs)
+
+ def __str__(self):
+ if self.name is None:
+ return "(no name)"
+ elif self.table is not None:
+ if self.table.named_with_column:
+ return self.table.description + "." + self.description
+ else:
+ return self.description
+ else:
+ return self.description
+
+ def references(self, column):
+ """Return True if this Column references the given column via foreign
+ key."""
+
+ for fk in self.foreign_keys:
+ if fk.column.proxy_set.intersection(column.proxy_set):
+ return True
+ else:
+ return False
+
+ def append_foreign_key(self, fk):
+ fk._set_parent_with_dispatch(self)
+
+ def __repr__(self):
+ kwarg = []
+ if self.key != self.name:
+ kwarg.append("key")
+ if self.primary_key:
+ kwarg.append("primary_key")
+ if not self.nullable:
+ kwarg.append("nullable")
+ if self.onupdate:
+ kwarg.append("onupdate")
+ if self.default:
+ kwarg.append("default")
+ if self.server_default:
+ kwarg.append("server_default")
+ if self.comment:
+ kwarg.append("comment")
+ return "Column(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.type)]
+ + [repr(x) for x in self.foreign_keys if x is not None]
+ + [repr(x) for x in self.constraints]
+ + [
+ (
+ self.table is not None
+ and "table=<%s>" % self.table.description
+ or "table=None"
+ )
+ ]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
+ )
+
+ def _set_parent(self, table, allow_replacements=True):
+ if not self.name:
+ raise exc.ArgumentError(
+ "Column must be constructed with a non-blank name or "
+ "assign a non-blank .name before adding to a Table."
+ )
+
+ self._reset_memoizations()
+
+ if self.key is None:
+ self.key = self.name
+
+ existing = getattr(self, "table", None)
+ if existing is not None and existing is not table:
+ raise exc.ArgumentError(
+ "Column object '%s' already assigned to Table '%s'"
+ % (self.key, existing.description)
+ )
+
+ if self.key in table._columns:
+ col = table._columns.get(self.key)
+ if col is not self:
+ if not allow_replacements:
+ util.warn_deprecated(
+ "A column with name '%s' is already present "
+ "in table '%s'. Please use method "
+ ":meth:`_schema.Table.append_column` with the "
+ "parameter ``replace_existing=True`` to replace an "
+ "existing column." % (self.key, table.name),
+ "1.4",
+ )
+ for fk in col.foreign_keys:
+ table.foreign_keys.remove(fk)
+ if fk.constraint in table.constraints:
+ # this might have been removed
+ # already, if it's a composite constraint
+ # and more than one col being replaced
+ table.constraints.remove(fk.constraint)
+
+ table._columns.replace(self)
+
+ self.table = table
+
+ if self.primary_key:
+ table.primary_key._replace(self)
+ elif self.key in table.primary_key:
+ raise exc.ArgumentError(
+ "Trying to redefine primary-key column '%s' as a "
+ "non-primary-key column on table '%s'"
+ % (self.key, table.fullname)
+ )
+
+ if self.index:
+ if isinstance(self.index, util.string_types):
+ raise exc.ArgumentError(
+ "The 'index' keyword argument on Column is boolean only. "
+ "To create indexes with a specific name, create an "
+ "explicit Index object external to the Table."
+ )
+ table.append_constraint(
+ Index(
+ None, self.key, unique=bool(self.unique), _column_flag=True
+ )
+ )
+
+ elif self.unique:
+ if isinstance(self.unique, util.string_types):
+ raise exc.ArgumentError(
+ "The 'unique' keyword argument on Column is boolean "
+ "only. To create unique constraints or indexes with a "
+ "specific name, append an explicit UniqueConstraint to "
+ "the Table's list of elements, or create an explicit "
+ "Index object external to the Table."
+ )
+ table.append_constraint(
+ UniqueConstraint(self.key, _column_flag=True)
+ )
+
+ self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table))
+
+ if self.identity and (
+ isinstance(self.default, Sequence)
+ or isinstance(self.onupdate, Sequence)
+ ):
+ raise exc.ArgumentError(
+ "A column cannot specify both Identity and Sequence."
+ )
+
+ def _setup_on_memoized_fks(self, fn):
+ fk_keys = [
+ ((self.table.key, self.key), False),
+ ((self.table.key, self.name), True),
+ ]
+ for fk_key, link_to_name in fk_keys:
+ if fk_key in self.table.metadata._fk_memos:
+ for fk in self.table.metadata._fk_memos[fk_key]:
+ if fk.link_to_name is link_to_name:
+ fn(fk)
+
+ def _on_table_attach(self, fn):
+ if self.table is not None:
+ fn(self, self.table)
+ else:
+ event.listen(self, "after_parent_attach", fn)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Column.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, **kw):
+ return self._copy(**kw)
+
+ def _copy(self, **kw):
+ """Create a copy of this ``Column``, uninitialized.
+
+ This is used in :meth:`_schema.Table.to_metadata`.
+
+ """
+
+ # Constraint objects plus non-constraint-bound ForeignKey objects
+ args = [
+ c._copy(**kw) for c in self.constraints if not c._type_bound
+ ] + [c._copy(**kw) for c in self.foreign_keys if not c.constraint]
+
+ # ticket #5276
+ column_kwargs = {}
+ for dialect_name in self.dialect_options:
+ dialect_options = self.dialect_options[dialect_name]._non_defaults
+ for (
+ dialect_option_key,
+ dialect_option_value,
+ ) in dialect_options.items():
+ column_kwargs[
+ dialect_name + "_" + dialect_option_key
+ ] = dialect_option_value
+
+ server_default = self.server_default
+ server_onupdate = self.server_onupdate
+ if isinstance(server_default, (Computed, Identity)):
+ server_default = server_onupdate = None
+ args.append(self.server_default._copy(**kw))
+
+ type_ = self.type
+ if isinstance(type_, SchemaEventTarget):
+ type_ = type_.copy(**kw)
+
+ if self._user_defined_nullable is not NULL_UNSPECIFIED:
+ column_kwargs["nullable"] = self._user_defined_nullable
+
+ c = self._constructor(
+ name=self.name,
+ type_=type_,
+ key=self.key,
+ primary_key=self.primary_key,
+ unique=self.unique,
+ system=self.system,
+ # quote=self.quote, # disabled 2013-08-27 (commit 031ef080)
+ index=self.index,
+ autoincrement=self.autoincrement,
+ default=self.default,
+ server_default=server_default,
+ onupdate=self.onupdate,
+ server_onupdate=server_onupdate,
+ doc=self.doc,
+ comment=self.comment,
+ *args,
+ **column_kwargs
+ )
+ return self._schema_item_copy(c)
+
+ def _make_proxy(
+ self, selectable, name=None, key=None, name_is_truncatable=False, **kw
+ ):
+ """Create a *proxy* for this column.
+
+ This is a copy of this ``Column`` referenced by a different parent
+ (such as an alias or select statement). The column should
+ be used only in select scenarios, as its full DDL/default
+ information is not transferred.
+
+ """
+
+ fk = [
+ ForeignKey(
+ col if col is not None else f._colspec,
+ _unresolvable=col is None,
+ _constraint=f.constraint,
+ )
+ for f, col in [
+ (fk, fk._resolve_column(raiseerr=False))
+ for fk in self.foreign_keys
+ ]
+ ]
+
+ if name is None and self.name is None:
+ raise exc.InvalidRequestError(
+ "Cannot initialize a sub-selectable"
+ " with this Column object until its 'name' has "
+ "been assigned."
+ )
+ try:
+ c = self._constructor(
+ coercions.expect(
+ roles.TruncatedLabelRole, name if name else self.name
+ )
+ if name_is_truncatable
+ else (name or self.name),
+ self.type,
+ # this may actually be ._proxy_key when the key is incoming
+ key=key if key else name if name else self.key,
+ primary_key=self.primary_key,
+ nullable=self.nullable,
+ _proxies=[self],
+ *fk
+ )
+ except TypeError as err:
+ util.raise_(
+ TypeError(
+ "Could not create a copy of this %r object. "
+ "Ensure the class includes a _constructor() "
+ "attribute or method which accepts the "
+ "standard Column constructor arguments, or "
+ "references the Column class itself." % self.__class__
+ ),
+ from_=err,
+ )
+
+ c.table = selectable
+ c._propagate_attrs = selectable._propagate_attrs
+ if selectable._is_clone_of is not None:
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
+ if self.primary_key:
+ selectable.primary_key.add(c)
+ if fk:
+ selectable.foreign_keys.update(fk)
+ return c.key, c
+
+
+class ForeignKey(DialectKWArgs, SchemaItem):
+ """Defines a dependency between two columns.
+
+ ``ForeignKey`` is specified as an argument to a :class:`_schema.Column`
+ object,
+ e.g.::
+
+ t = Table("remote_table", metadata,
+ Column("remote_id", ForeignKey("main_table.id"))
+ )
+
+ Note that ``ForeignKey`` is only a marker object that defines
+ a dependency between two columns. The actual constraint
+ is in all cases represented by the :class:`_schema.ForeignKeyConstraint`
+ object. This object will be generated automatically when
+ a ``ForeignKey`` is associated with a :class:`_schema.Column` which
+ in turn is associated with a :class:`_schema.Table`. Conversely,
+ when :class:`_schema.ForeignKeyConstraint` is applied to a
+ :class:`_schema.Table`,
+ ``ForeignKey`` markers are automatically generated to be
+ present on each associated :class:`_schema.Column`, which are also
+ associated with the constraint object.
+
+ Note that you cannot define a "composite" foreign key constraint,
+ that is a constraint between a grouping of multiple parent/child
+ columns, using ``ForeignKey`` objects. To define this grouping,
+ the :class:`_schema.ForeignKeyConstraint` object must be used, and applied
+ to the :class:`_schema.Table`. The associated ``ForeignKey`` objects
+ are created automatically.
+
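+ For example, a composite foreign key expressed with
+ :class:`_schema.ForeignKeyConstraint` might look like the following
+ (a minimal sketch; table and column names are illustrative only)::
+
+ Table("invoice_item", metadata,
+     Column("invoice_id", Integer),
+     Column("ref_num", Integer),
+     ForeignKeyConstraint(
+         ["invoice_id", "ref_num"],
+         ["invoice.invoice_id", "invoice.ref_num"]
+     )
+ )
+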
+ The ``ForeignKey`` objects associated with an individual
+ :class:`_schema.Column`
+ object are available in the `foreign_keys` collection
+ of that column.
+
+ Further examples of foreign key configuration are in
+ :ref:`metadata_foreignkeys`.
+
+ """
+
+ __visit_name__ = "foreign_key"
+
+ def __init__(
+ self,
+ column,
+ _constraint=None,
+ use_alter=False,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ link_to_name=False,
+ match=None,
+ info=None,
+ _unresolvable=False,
+ **dialect_kw
+ ):
+ r"""
+ Construct a column-level FOREIGN KEY.
+
+ The :class:`_schema.ForeignKey` object when constructed generates a
+ :class:`_schema.ForeignKeyConstraint`
+ which is associated with the parent
+ :class:`_schema.Table` object's collection of constraints.
+
+ :param column: A single target column for the key relationship. A
+ :class:`_schema.Column` object or a column name as a string:
+ ``tablename.columnkey`` or ``schema.tablename.columnkey``.
+ ``columnkey`` is the ``key`` which has been assigned to the column
+ (defaults to the column name itself), unless ``link_to_name`` is
+ ``True`` in which case the rendered name of the column is used.
+
+ :param name: Optional string. An in-database name for the key if
+ `constraint` is not provided.
+
+ :param onupdate: Optional string. If set, emit ON UPDATE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+ SET NULL and RESTRICT.
+
+ :param ondelete: Optional string. If set, emit ON DELETE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+ SET NULL and RESTRICT.
+
+ :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT
+ DEFERRABLE when issuing DDL for this constraint.
+
+ :param initially: Optional string. If set, emit INITIALLY <value> when
+ issuing DDL for this constraint.
+
+ :param link_to_name: if True, the string name given in ``column`` is
+ the rendered name of the referenced column, not its locally
+ assigned ``key``.
+
+ :param use_alter: passed to the underlying
+ :class:`_schema.ForeignKeyConstraint`
+ to indicate the constraint should
+ be generated/dropped externally from the CREATE TABLE/ DROP TABLE
+ statement. See :paramref:`_schema.ForeignKeyConstraint.use_alter`
+ for further description.
+
+ .. seealso::
+
+ :paramref:`_schema.ForeignKeyConstraint.use_alter`
+
+ :ref:`use_alter`
+
+ :param match: Optional string. If set, emit MATCH <value> when issuing
+ DDL for this constraint. Typical values include SIMPLE, PARTIAL
+ and FULL.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**dialect_kw: Additional keyword arguments are dialect
+ specific, and passed in the form ``<dialectname>_<argname>``. The
+ arguments are ultimately handled by a corresponding
+ :class:`_schema.ForeignKeyConstraint`.
+ See the documentation regarding
+ an individual dialect at :ref:`dialect_toplevel` for detail on
+ documented arguments.
+
+ .. versionadded:: 0.9.2
+
+ """
+
+ self._colspec = coercions.expect(roles.DDLReferredColumnRole, column)
+ self._unresolvable = _unresolvable
+
+ if isinstance(self._colspec, util.string_types):
+ self._table_column = None
+ else:
+ self._table_column = self._colspec
+
+ if not isinstance(
+ self._table_column.table, (util.NoneType, TableClause)
+ ):
+ raise exc.ArgumentError(
+ "ForeignKey received Column not bound "
+ "to a Table, got: %r" % self._table_column.table
+ )
+
+ # the linked ForeignKeyConstraint.
+ # ForeignKey will create this when parent Column
+ # is attached to a Table, *or* ForeignKeyConstraint
+ # object passes itself in when creating ForeignKey
+ # markers.
+ self.constraint = _constraint
+ self.parent = None
+ self.use_alter = use_alter
+ self.name = name
+ self.onupdate = onupdate
+ self.ondelete = ondelete
+ self.deferrable = deferrable
+ self.initially = initially
+ self.link_to_name = link_to_name
+ self.match = match
+ if info:
+ self.info = info
+ self._unvalidated_dialect_kw = dialect_kw
+
+ def __repr__(self):
+ return "ForeignKey(%r)" % self._get_colspec()
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.ForeignKey.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, schema=None, **kw):
+ return self._copy(schema=schema, **kw)
+
+ def _copy(self, schema=None, **kw):
+ """Produce a copy of this :class:`_schema.ForeignKey` object.
+
+ The new :class:`_schema.ForeignKey` will not be bound
+ to any :class:`_schema.Column`.
+
+ This method is usually used by the internal
+ copy procedures of :class:`_schema.Column`, :class:`_schema.Table`,
+ and :class:`_schema.MetaData`.
+
+ :param schema: The returned :class:`_schema.ForeignKey` will
+ reference the original table and column name, qualified
+ by the given string schema name.
+
+ """
+
+ fk = ForeignKey(
+ self._get_colspec(schema=schema),
+ use_alter=self.use_alter,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ link_to_name=self.link_to_name,
+ match=self.match,
+ **self._unvalidated_dialect_kw
+ )
+ return self._schema_item_copy(fk)
+
+ def _get_colspec(self, schema=None, table_name=None):
+ """Return a string based 'column specification' for this
+ :class:`_schema.ForeignKey`.
+
+ This is usually the equivalent of the string-based "tablename.colname"
+ argument first passed to the object's constructor.
+
+ """
+ if schema not in (None, RETAIN_SCHEMA):
+ _schema, tname, colname = self._column_tokens
+ if table_name is not None:
+ tname = table_name
+ if schema is BLANK_SCHEMA:
+ return "%s.%s" % (tname, colname)
+ else:
+ return "%s.%s.%s" % (schema, tname, colname)
+ elif table_name:
+ schema, tname, colname = self._column_tokens
+ if schema:
+ return "%s.%s.%s" % (schema, table_name, colname)
+ else:
+ return "%s.%s" % (table_name, colname)
+ elif self._table_column is not None:
+ return "%s.%s" % (
+ self._table_column.table.fullname,
+ self._table_column.key,
+ )
+ else:
+ return self._colspec
+
+ @property
+ def _referred_schema(self):
+ return self._column_tokens[0]
+
+ def _table_key(self):
+ if self._table_column is not None:
+ if self._table_column.table is None:
+ return None
+ else:
+ return self._table_column.table.key
+ else:
+ schema, tname, colname = self._column_tokens
+ return _get_table_key(tname, schema)
+
+ target_fullname = property(_get_colspec)
+
+ def references(self, table):
+ """Return True if the given :class:`_schema.Table`
+ is referenced by this
+ :class:`_schema.ForeignKey`."""
+
+ return table.corresponding_column(self.column) is not None
+
+ def get_referent(self, table):
+ """Return the :class:`_schema.Column` in the given
+ :class:`_schema.Table`
+ referenced by this :class:`_schema.ForeignKey`.
+
+ Returns None if this :class:`_schema.ForeignKey`
+ does not reference the given
+ :class:`_schema.Table`.
+
+ """
+
+ return table.corresponding_column(self.column)
+
+ @util.memoized_property
+ def _column_tokens(self):
+ """parse a string-based _colspec into its component parts."""
+
+ m = self._get_colspec().split(".")
+ if m is None:
+ raise exc.ArgumentError(
+ "Invalid foreign key column specification: %s" % self._colspec
+ )
+ if len(m) == 1:
+ tname = m.pop()
+ colname = None
+ else:
+ colname = m.pop()
+ tname = m.pop()
+
+ # A FK between column 'bar' and table 'foo' can be
+ # specified as 'foo', 'foo.bar', 'dbo.foo.bar',
+ # 'otherdb.dbo.foo.bar'. Once we have the column name and
+ # the table name, treat everything else as the schema
+ # name. Some databases (e.g. Sybase) support
+ # inter-database foreign keys. See tickets #1341 and --
+ # indirectly related -- Ticket #594. This assumes that '.'
+ # will never appear *within* any component of the FK.
+
+ if len(m) > 0:
+ schema = ".".join(m)
+ else:
+ schema = None
+ return schema, tname, colname
+
+ def _resolve_col_tokens(self):
+ if self.parent is None:
+ raise exc.InvalidRequestError(
+ "this ForeignKey object does not yet have a "
+ "parent Column associated with it."
+ )
+
+ elif self.parent.table is None:
+ raise exc.InvalidRequestError(
+ "this ForeignKey's parent column is not yet associated "
+ "with a Table."
+ )
+
+ parenttable = self.parent.table
+
+ if self._unresolvable:
+ schema, tname, colname = self._column_tokens
+ tablekey = _get_table_key(tname, schema)
+ return parenttable, tablekey, colname
+
+ # assertion
+ # basically Column._make_proxy() sends the actual
+ # target Column to the ForeignKey object, so the
+ # string resolution here is never called.
+ for c in self.parent.base_columns:
+ if isinstance(c, Column):
+ assert c.table is parenttable
+ break
+ else:
+ assert False
+ ######################
+
+ schema, tname, colname = self._column_tokens
+
+ if schema is None and parenttable.metadata.schema is not None:
+ schema = parenttable.metadata.schema
+
+ tablekey = _get_table_key(tname, schema)
+ return parenttable, tablekey, colname
+
+ def _link_to_col_by_colstring(self, parenttable, table, colname):
+
+ _column = None
+ if colname is None:
+ # colname is None in the case that ForeignKey argument
+ # was specified as table name only, in which case we
+ # match the column name to the same column on the
+ # parent.
+ # this use case wasn't working in later 1.x series
+ # as it had no test coverage; fixed in 2.0
+ parent = self.parent
+ assert parent is not None
+ key = parent.key
+ _column = table.c.get(key, None)
+ elif self.link_to_name:
+ key = colname
+ for c in table.c:
+ if c.name == colname:
+ _column = c
+ else:
+ key = colname
+ _column = table.c.get(colname, None)
+
+ if _column is None:
+ raise exc.NoReferencedColumnError(
+ "Could not initialize target column "
+ "for ForeignKey '%s' on table '%s': "
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, table.name, key),
+ table.name,
+ key,
+ )
+
+ return _column
+
+ def _set_target_column(self, column):
+ assert self.parent is not None
+
+ # propagate TypeEngine to parent if it didn't have one
+ if self.parent.type._isnull:
+ self.parent.type = column.type
+
+ # super-edgy case, if other FKs point to our column,
+ # they'd get the type propagated out also.
+
+ def set_type(fk):
+ if fk.parent.type._isnull:
+ fk.parent.type = column.type
+
+ self.parent._setup_on_memoized_fks(set_type)
+
+ self.column = column
+
+ @util.memoized_property
+ def column(self):
+ """Return the target :class:`_schema.Column` referenced by this
+ :class:`_schema.ForeignKey`.
+
+ If no target column has been established, an exception
+ is raised.
+
+ .. versionchanged:: 0.9.0
+ Foreign key target column resolution now occurs as soon as both
+ the ForeignKey object and the remote Column to which it refers
+ are both associated with the same MetaData object.
+
+ """
+
+ return self._resolve_column()
+
+ def _resolve_column(self, raiseerr=True):
+
+ if isinstance(self._colspec, util.string_types):
+
+ parenttable, tablekey, colname = self._resolve_col_tokens()
+
+ if self._unresolvable or tablekey not in parenttable.metadata:
+ if not raiseerr:
+ return None
+ raise exc.NoReferencedTableError(
+ "Foreign key associated with column '%s' could not find "
+ "table '%s' with which to generate a "
+ "foreign key to target column '%s'"
+ % (self.parent, tablekey, colname),
+ tablekey,
+ )
+ elif parenttable.key not in parenttable.metadata:
+ if not raiseerr:
+ return None
+ raise exc.InvalidRequestError(
+ "Table %s is no longer associated with its "
+ "parent MetaData" % parenttable
+ )
+ else:
+ table = parenttable.metadata.tables[tablekey]
+ return self._link_to_col_by_colstring(
+ parenttable, table, colname
+ )
+
+ elif hasattr(self._colspec, "__clause_element__"):
+ _column = self._colspec.__clause_element__()
+ return _column
+ else:
+ _column = self._colspec
+ return _column
+
+ def _set_parent(self, column, **kw):
+ if self.parent is not None and self.parent is not column:
+ raise exc.InvalidRequestError(
+ "This ForeignKey already has a parent!"
+ )
+ self.parent = column
+ self.parent.foreign_keys.add(self)
+ self.parent._on_table_attach(self._set_table)
+
+ def _set_remote_table(self, table):
+ parenttable, tablekey, colname = self._resolve_col_tokens()
+
+ _column = self._link_to_col_by_colstring(parenttable, table, colname)
+ self._set_target_column(_column)
+ assert self.constraint is not None
+
+ self.constraint._validate_dest_table(table)
+
+ def _remove_from_metadata(self, metadata):
+ parenttable, table_key, colname = self._resolve_col_tokens()
+ fk_key = (table_key, colname)
+
+ if self in metadata._fk_memos[fk_key]:
+ # TODO: no test coverage for self not in memos
+ metadata._fk_memos[fk_key].remove(self)
+
+ def _set_table(self, column, table):
+ # standalone ForeignKey - create ForeignKeyConstraint
+ # on the hosting Table when attached to the Table.
+ assert isinstance(table, Table)
+ if self.constraint is None:
+ self.constraint = ForeignKeyConstraint(
+ [],
+ [],
+ use_alter=self.use_alter,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ match=self.match,
+ **self._unvalidated_dialect_kw
+ )
+ self.constraint._append_element(column, self)
+ self.constraint._set_parent_with_dispatch(table)
+ table.foreign_keys.add(self)
+ # set up remote ".column" attribute, or a note to pick it
+ # up when the other Table/Column shows up
+ if isinstance(self._colspec, util.string_types):
+ parenttable, table_key, colname = self._resolve_col_tokens()
+ fk_key = (table_key, colname)
+ if table_key in parenttable.metadata.tables:
+ table = parenttable.metadata.tables[table_key]
+ try:
+ _column = self._link_to_col_by_colstring(
+ parenttable, table, colname
+ )
+ except exc.NoReferencedColumnError:
+ # this is OK, we'll try later
+ pass
+ else:
+ self._set_target_column(_column)
+ parenttable.metadata._fk_memos[fk_key].append(self)
+ elif hasattr(self._colspec, "__clause_element__"):
+ _column = self._colspec.__clause_element__()
+ self._set_target_column(_column)
+ else:
+ _column = self._colspec
+ self._set_target_column(_column)
+
+
+class DefaultGenerator(Executable, SchemaItem):
+ """Base class for column *default* values."""
+
+ __visit_name__ = "default_generator"
+
+ is_sequence = False
+ is_server_default = False
+ column = None
+
+ def __init__(self, for_update=False):
+ self.for_update = for_update
+
+ def _set_parent(self, column, **kw):
+ self.column = column
+ if self.for_update:
+ self.column.onupdate = self
+ else:
+ self.column.default = self
+
+ @util.deprecated_20(
+ ":meth:`.DefaultGenerator.execute`",
+ alternative="All statement execution in SQLAlchemy 2.0 is performed "
+ "by the :meth:`_engine.Connection.execute` method of "
+ ":class:`_engine.Connection`, "
+ "or in the ORM by the :meth:`.Session.execute` method of "
+ ":class:`.Session`.",
+ )
+ def execute(self, bind=None):
+ if bind is None:
+ bind = _bind_or_error(self)
+ return bind._execute_default(self, (), util.EMPTY_DICT)
+
+ def _execute_on_connection(
+ self, connection, multiparams, params, execution_options
+ ):
+ return connection._execute_default(
+ self, multiparams, params, execution_options
+ )
+
+ @property
+ def bind(self):
+ """Return the connectable associated with this default."""
+ if getattr(self, "column", None) is not None:
+ return self.column.table.bind
+ else:
+ return None
+
+
+class ColumnDefault(DefaultGenerator):
+ """A plain default value on a column.
+
+ This could correspond to a constant, a callable function,
+ or a SQL clause.
+
+ :class:`.ColumnDefault` is generated automatically
+ whenever the ``default``, ``onupdate`` arguments of
+ :class:`_schema.Column` are used. A :class:`.ColumnDefault`
+ can be passed positionally as well.
+
+ For example, the following::
+
+ Column('foo', Integer, default=50)
+
+ Is equivalent to::
+
+ Column('foo', Integer, ColumnDefault(50))
+
+
+ """
+
+ def __init__(self, arg, **kwargs):
+ """Construct a new :class:`.ColumnDefault`.
+
+
+ :param arg: argument representing the default value.
+ May be one of the following:
+
+ * a plain non-callable Python value, such as a
+ string, integer, boolean, or other simple type.
+ The default value will be used as is each time.
+ * a SQL expression, that is one which derives from
+ :class:`_expression.ColumnElement`. The SQL expression will
+ be rendered into the INSERT or UPDATE statement,
+ or in the case of a primary key column when
+ RETURNING is not used may be
+ pre-executed before an INSERT within a SELECT.
+ * A Python callable. The function will be invoked for each
+ new row subject to an INSERT or UPDATE.
+ The callable must accept exactly
+ The callable must accept exactly zero or one positional argument
+ (see the sketch following this list). The one-argument form
+ which provides contextual information as to the current
+ :class:`_engine.Connection` in use as well as the current
+ statement and parameters.
+
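+ As a sketch of the one-argument form referenced above (the column
+ and function names are illustrative only)::
+
+ def mydefault(context):
+     # current parameters of the statement being executed
+     return context.get_current_parameters()['counter'] + 12
+
+ Column('counter_plus_twelve', Integer, default=mydefault)
+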
+ """
+ super(ColumnDefault, self).__init__(**kwargs)
+ if isinstance(arg, FetchedValue):
+ raise exc.ArgumentError(
+ "ColumnDefault may not be a server-side default type."
+ )
+ if callable(arg):
+ arg = self._maybe_wrap_callable(arg)
+ self.arg = arg
+
+ @util.memoized_property
+ def is_callable(self):
+ return callable(self.arg)
+
+ @util.memoized_property
+ def is_clause_element(self):
+ return isinstance(self.arg, ClauseElement)
+
+ @util.memoized_property
+ def is_scalar(self):
+ return (
+ not self.is_callable
+ and not self.is_clause_element
+ and not self.is_sequence
+ )
+
+ @util.memoized_property
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def _arg_is_typed(self):
+ sqltypes = util.preloaded.sql_sqltypes
+
+ if self.is_clause_element:
+ return not isinstance(self.arg.type, sqltypes.NullType)
+ else:
+ return False
+
+ def _maybe_wrap_callable(self, fn):
+ """Wrap callables that don't accept a context.
+
+ This is to allow easy compatibility with default callables
+ that do not accept a context argument.
+
+ """
+ try:
+ argspec = util.get_callable_argspec(fn, no_self=True)
+ except TypeError:
+ return util.wrap_callable(lambda ctx: fn(), fn)
+
+ defaulted = argspec[3] is not None and len(argspec[3]) or 0
+ positionals = len(argspec[0]) - defaulted
+
+ if positionals == 0:
+ return util.wrap_callable(lambda ctx: fn(), fn)
+
+ elif positionals == 1:
+ return fn
+ else:
+ raise exc.ArgumentError(
+ "ColumnDefault Python function takes zero or one "
+ "positional arguments"
+ )
+
+ def __repr__(self):
+ return "ColumnDefault(%r)" % (self.arg,)
+
+
+class IdentityOptions(object):
+ """Defines options for a named database sequence or an identity column.
+
+ .. versionadded:: 1.3.18
+
+ .. seealso::
+
+ :class:`.Sequence`
+
+ """
+
+ def __init__(
+ self,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ cache=None,
+ order=None,
+ ):
+ """Construct a :class:`.IdentityOptions` object.
+
+ See the :class:`.Sequence` documentation for a complete description
+ of the parameters.
+
+ :param start: the starting index of the sequence.
+ :param increment: the increment value of the sequence.
+ :param minvalue: the minimum value of the sequence.
+ :param maxvalue: the maximum value of the sequence.
+ :param nominvalue: no minimum value of the sequence.
+ :param nomaxvalue: no maximum value of the sequence.
+ :param cycle: allows the sequence to wrap around when the maxvalue
+ or minvalue has been reached.
+ :param cache: optional integer value; number of future values in the
+ sequence which are calculated in advance.
+ :param order: optional boolean value; if ``True``, renders the
+ ORDER keyword.
+
+ """
+ self.start = start
+ self.increment = increment
+ self.minvalue = minvalue
+ self.maxvalue = maxvalue
+ self.nominvalue = nominvalue
+ self.nomaxvalue = nomaxvalue
+ self.cycle = cycle
+ self.cache = cache
+ self.order = order
+
+
+class Sequence(IdentityOptions, DefaultGenerator):
+ """Represents a named database sequence.
+
+ The :class:`.Sequence` object represents the name and configurational
+ parameters of a database sequence. It also represents
+ a construct that can be "executed" by a SQLAlchemy :class:`_engine.Engine`
+ or :class:`_engine.Connection`,
+ rendering the appropriate "next value" function
+ for the target database and returning a result.
+
+ The :class:`.Sequence` is typically associated with a primary key column::
+
+ some_table = Table(
+ 'some_table', metadata,
+ Column('id', Integer, Sequence('some_table_seq'),
+ primary_key=True)
+ )
+
+ When CREATE TABLE is emitted for the above :class:`_schema.Table`, if the
+ target platform supports sequences, a CREATE SEQUENCE statement will
+ be emitted as well. For platforms that don't support sequences,
+ the :class:`.Sequence` construct is ignored.
+
+ .. seealso::
+
+ :ref:`defaults_sequences`
+
+ :class:`.CreateSequence`
+
+ :class:`.DropSequence`
+
+ """
+
+ __visit_name__ = "sequence"
+
+ is_sequence = True
+
+ def __init__(
+ self,
+ name,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ schema=None,
+ cache=None,
+ order=None,
+ data_type=None,
+ optional=False,
+ quote=None,
+ metadata=None,
+ quote_schema=None,
+ for_update=False,
+ ):
+ """Construct a :class:`.Sequence` object.
+
+ :param name: the name of the sequence.
+
+ :param start: the starting index of the sequence. This value is
+ used when the CREATE SEQUENCE command is emitted to the database
+ as the value of the "START WITH" clause. If ``None``, the
+ clause is omitted, which on most platforms indicates a starting
+ value of 1.
+ :param increment: the increment value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "INCREMENT BY" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates an
+ increment of 1.
+ :param minvalue: the minimum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "MINVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ minvalue of 1 and -2^63-1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param maxvalue: the maximum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "MAXVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ maxvalue of 2^63-1 and -1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param nominvalue: no minimum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "NO MINVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ minvalue of 1 and -2^63-1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param nomaxvalue: no maximum value of the sequence. This
+ value is used when the CREATE SEQUENCE command is emitted to
+ the database as the value of the "NO MAXVALUE" clause. If ``None``,
+ the clause is omitted, which on most platforms indicates a
+ maxvalue of 2^63-1 and -1 for ascending and descending sequences,
+ respectively.
+
+ .. versionadded:: 1.0.7
+
+ :param cycle: allows the sequence to wrap around when the maxvalue
+ or minvalue has been reached by an ascending or descending sequence
+ respectively. This value is used when the CREATE SEQUENCE command
+ is emitted to the database as the "CYCLE" clause. If the limit is
+ reached, the next number generated will be the minvalue or maxvalue,
+ respectively. If cycle=False (the default) any calls to nextval
+ after the sequence has reached its maximum value will return an
+ error.
+
+ .. versionadded:: 1.0.7
+
+ :param schema: optional schema name for the sequence, if located
+ in a schema other than the default. The rules for selecting the
+ schema name when a :class:`_schema.MetaData`
+ is also present are the same
+ as that of :paramref:`_schema.Table.schema`.
+
+ :param cache: optional integer value; number of future values in the
+ sequence which are calculated in advance. Renders the CACHE keyword
+ understood by Oracle and PostgreSQL.
+
+ .. versionadded:: 1.1.12
+
+ :param order: optional boolean value; if ``True``, renders the
+ ORDER keyword, understood by Oracle, indicating the sequence is
+ definitively ordered. May be necessary to provide deterministic
+ ordering using Oracle RAC.
+
+ .. versionadded:: 1.1.12
+
+ :param data_type: The type to be returned by the sequence, for
+ dialects that allow us to choose between INTEGER, BIGINT, etc.
+ (e.g., mssql).
+
+ .. versionadded:: 1.4.0
+
+ :param optional: boolean value, when ``True``, indicates that this
+ :class:`.Sequence` object only needs to be explicitly generated
+ on backends that don't provide another way to generate primary
+ key identifiers. Currently, it essentially means, "don't create
+ this sequence on the PostgreSQL backend, where the SERIAL keyword
+ creates a sequence for us automatically".
+ :param quote: boolean value, when ``True`` or ``False``, explicitly
+ forces quoting of the :paramref:`_schema.Sequence.name` on or off.
+ When left at its default of ``None``, normal quoting rules based
+ on casing and reserved words take place.
+ :param quote_schema: Set the quoting preferences for the ``schema``
+ name.
+
+ :param metadata: optional :class:`_schema.MetaData` object which this
+ :class:`.Sequence` will be associated with. A :class:`.Sequence`
+ that is associated with a :class:`_schema.MetaData`
+ gains the following
+ capabilities:
+
+ * The :class:`.Sequence` will inherit the
+ :paramref:`_schema.MetaData.schema`
+ parameter specified to the target :class:`_schema.MetaData`, which
+ affects the production of CREATE / DROP DDL, if any.
+
+ * The :meth:`.Sequence.create` and :meth:`.Sequence.drop` methods
+ automatically use the engine bound to the :class:`_schema.MetaData`
+ object, if any.
+
+ * The :meth:`_schema.MetaData.create_all` and
+ :meth:`_schema.MetaData.drop_all`
+ methods will emit CREATE / DROP for this :class:`.Sequence`,
+ even if the :class:`.Sequence` is not associated with any
+ :class:`_schema.Table` / :class:`_schema.Column`
+ that's a member of this
+ :class:`_schema.MetaData`.
+
+ The above behaviors can only occur if the :class:`.Sequence` is
+ explicitly associated with the :class:`_schema.MetaData`
+ via this parameter.
+
+ .. seealso::
+
+ :ref:`sequence_metadata` - full discussion of the
+ :paramref:`.Sequence.metadata` parameter.
+
+ :param for_update: Indicates this :class:`.Sequence`, when associated
+ with a :class:`_schema.Column`,
+ should be invoked for UPDATE statements
+ on that column's table, rather than for INSERT statements, when
+ no value is otherwise present for that column in the statement.
+
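+ For example, a minimal sketch of a :class:`.Sequence` used as the
+ primary key generator for a column; the sequence, table and column
+ names here are illustrative only::
+
+ # 'user_id_seq', 'users' and 'metadata' are illustrative names
+ user_id_seq = Sequence('user_id_seq', start=1, increment=1)
+ users = Table('users', metadata,
+ Column('id', Integer, user_id_seq, primary_key=True)
+ )
+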
+ """
+ DefaultGenerator.__init__(self, for_update=for_update)
+ IdentityOptions.__init__(
+ self,
+ start=start,
+ increment=increment,
+ minvalue=minvalue,
+ maxvalue=maxvalue,
+ nominvalue=nominvalue,
+ nomaxvalue=nomaxvalue,
+ cycle=cycle,
+ cache=cache,
+ order=order,
+ )
+ self.name = quoted_name(name, quote)
+ self.optional = optional
+ if schema is BLANK_SCHEMA:
+ self.schema = schema = None
+ elif metadata is not None and schema is None and metadata.schema:
+ self.schema = schema = metadata.schema
+ else:
+ self.schema = quoted_name(schema, quote_schema)
+ self.metadata = metadata
+ self._key = _get_table_key(name, schema)
+ if metadata:
+ self._set_metadata(metadata)
+ if data_type is not None:
+ self.data_type = to_instance(data_type)
+ else:
+ self.data_type = None
+
+ @util.memoized_property
+ def is_callable(self):
+ return False
+
+ @util.memoized_property
+ def is_clause_element(self):
+ return False
+
+ @util.preload_module("sqlalchemy.sql.functions")
+ def next_value(self):
+ """Return a :class:`.next_value` function element
+ which will render the appropriate increment function
+ for this :class:`.Sequence` within any SQL expression.
+
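+ For example, a brief sketch of embedding the next value into an
+ INSERT statement; the sequence and table names are illustrative::
+
+ # illustrative names; any table-bound usage works the same way
+ user_id_seq = Sequence('user_id_seq')
+ stmt = users.insert().values(id=user_id_seq.next_value())
+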
+ """
+ if self.bind:
+ return util.preloaded.sql_functions.func.next_value(
+ self, bind=self.bind
+ )
+ else:
+ return util.preloaded.sql_functions.func.next_value(self)
+
+ def _set_parent(self, column, **kw):
+ super(Sequence, self)._set_parent(column)
+ column._on_table_attach(self._set_table)
+
+ def _set_table(self, column, table):
+ self._set_metadata(table.metadata)
+
+ def _set_metadata(self, metadata):
+ self.metadata = metadata
+ self.metadata._sequences[self._key] = self
+
+ @property
+ def bind(self):
+ if self.metadata:
+ return self.metadata.bind
+ else:
+ return None
+
+ def create(self, bind=None, checkfirst=True):
+ """Creates this sequence in the database.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=True):
+ """Drops this sequence from the database.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ """
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ def _not_a_column_expr(self):
+ raise exc.InvalidRequestError(
+ "This %s cannot be used directly "
+ "as a column expression. Use func.next_value(sequence) "
+ "to produce a 'next value' function that's usable "
+ "as a column element." % self.__class__.__name__
+ )
+
+
+@inspection._self_inspects
+class FetchedValue(SchemaEventTarget):
+ """A marker for a transparent database-side default.
+
+ Use :class:`.FetchedValue` when the database is configured
+ to provide some automatic default for a column.
+
+ E.g.::
+
+ Column('foo', Integer, FetchedValue())
+
+ Would indicate that some trigger or default generator
+ will create a new value for the ``foo`` column during an
+ INSERT.
+
+ .. seealso::
+
+ :ref:`triggered_columns`
+
+ """
+
+ is_server_default = True
+ reflected = False
+ has_argument = False
+ is_clause_element = False
+
+ def __init__(self, for_update=False):
+ self.for_update = for_update
+
+ def _as_for_update(self, for_update):
+ if for_update == self.for_update:
+ return self
+ else:
+ return self._clone(for_update)
+
+ def _clone(self, for_update):
+ n = self.__class__.__new__(self.__class__)
+ n.__dict__.update(self.__dict__)
+ n.__dict__.pop("column", None)
+ n.for_update = for_update
+ return n
+
+ def _set_parent(self, column, **kw):
+ self.column = column
+ if self.for_update:
+ self.column.server_onupdate = self
+ else:
+ self.column.server_default = self
+
+ def __repr__(self):
+ return util.generic_repr(self)
+
+
+class DefaultClause(FetchedValue):
+ """A DDL-specified DEFAULT column value.
+
+ :class:`.DefaultClause` is a :class:`.FetchedValue`
+ that also generates a "DEFAULT" clause when
+ "CREATE TABLE" is emitted.
+
+ :class:`.DefaultClause` is generated automatically
+ whenever the ``server_default`` or ``server_onupdate`` arguments of
+ :class:`_schema.Column` are used. A :class:`.DefaultClause`
+ can be passed positionally as well.
+
+ For example, the following::
+
+ Column('foo', Integer, server_default="50")
+
+ Is equivalent to::
+
+ Column('foo', Integer, DefaultClause("50"))
+
+ """
+
+ has_argument = True
+
+ def __init__(self, arg, for_update=False, _reflected=False):
+ util.assert_arg_type(
+ arg, (util.string_types[0], ClauseElement, TextClause), "arg"
+ )
+ super(DefaultClause, self).__init__(for_update)
+ self.arg = arg
+ self.reflected = _reflected
+
+ def __repr__(self):
+ return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
+
+
+class Constraint(DialectKWArgs, SchemaItem):
+ """A table-level SQL constraint.
+
+ :class:`_schema.Constraint` serves as the base class for the series of
+ constraint objects that can be associated with :class:`_schema.Table`
+ objects, including :class:`_schema.PrimaryKeyConstraint`,
+ :class:`_schema.ForeignKeyConstraint`,
+ :class:`_schema.UniqueConstraint`, and
+ :class:`_schema.CheckConstraint`.
+
+ """
+
+ __visit_name__ = "constraint"
+
+ def __init__(
+ self,
+ name=None,
+ deferrable=None,
+ initially=None,
+ _create_rule=None,
+ info=None,
+ _type_bound=False,
+ **dialect_kw
+ ):
+ r"""Create a SQL constraint.
+
+ :param name:
+ Optional, the in-database name of this ``Constraint``.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**dialect_kw: Additional keyword arguments are dialect
+ specific, and passed in the form ``<dialectname>_<argname>``. See
+ the documentation regarding an individual dialect at
+ :ref:`dialect_toplevel` for detail on documented arguments.
+
+ :param _create_rule:
+ used internally by some datatypes that also create constraints.
+
+ :param _type_bound:
+ used internally to indicate that this constraint is associated with
+ a specific datatype.
+
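+ :class:`.Constraint` itself is not usually constructed directly; as
+ a brief sketch, the ``deferrable`` and ``initially`` options above
+ may be passed to a subclass such as :class:`.UniqueConstraint` (the
+ constraint and column names are illustrative)::
+
+ # illustrative constraint and column names
+ UniqueConstraint('email', name='uq_user_email',
+ deferrable=True, initially='DEFERRED')
+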
+ """
+
+ self.name = name
+ self.deferrable = deferrable
+ self.initially = initially
+ if info:
+ self.info = info
+ self._create_rule = _create_rule
+ self._type_bound = _type_bound
+ util.set_creation_order(self)
+ self._validate_dialect_kwargs(dialect_kw)
+
+ @property
+ def table(self):
+ try:
+ if isinstance(self.parent, Table):
+ return self.parent
+ except AttributeError:
+ pass
+ raise exc.InvalidRequestError(
+ "This constraint is not bound to a table. Did you "
+ "mean to call table.append_constraint(constraint) ?"
+ )
+
+ def _set_parent(self, parent, **kw):
+ self.parent = parent
+ parent.constraints.add(self)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Constraint.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, **kw):
+ return self._copy(**kw)
+
+ def _copy(self, **kw):
+ raise NotImplementedError()
+
+
+class ColumnCollectionMixin(object):
+
+ columns = None
+ """A :class:`_expression.ColumnCollection` of :class:`_schema.Column`
+ objects.
+
+ This collection represents the columns which are referred to by
+ this object.
+
+ """
+
+ _allow_multiple_tables = False
+
+ def __init__(self, *columns, **kw):
+ _autoattach = kw.pop("_autoattach", True)
+ self._column_flag = kw.pop("_column_flag", False)
+ self.columns = DedupeColumnCollection()
+
+ processed_expressions = kw.pop("_gather_expressions", None)
+ if processed_expressions is not None:
+ self._pending_colargs = []
+ for (
+ expr,
+ column,
+ strname,
+ add_element,
+ ) in coercions.expect_col_expression_collection(
+ roles.DDLConstraintColumnRole, columns
+ ):
+ self._pending_colargs.append(add_element)
+ processed_expressions.append(expr)
+ else:
+ self._pending_colargs = [
+ coercions.expect(roles.DDLConstraintColumnRole, column)
+ for column in columns
+ ]
+
+ if _autoattach and self._pending_colargs:
+ self._check_attach()
+
+ def _check_attach(self, evt=False):
+ col_objs = [c for c in self._pending_colargs if isinstance(c, Column)]
+
+ cols_w_table = [c for c in col_objs if isinstance(c.table, Table)]
+
+ cols_wo_table = set(col_objs).difference(cols_w_table)
+ if cols_wo_table:
+ # feature #3341 - place event listeners for Column objects
+ # such that when all those cols are attached, we autoattach.
+ assert not evt, "Should not reach here on event call"
+
+ # issue #3411 - don't do the per-column auto-attach if some of the
+ # columns are specified as strings.
+ has_string_cols = set(
+ c for c in self._pending_colargs if c is not None
+ ).difference(col_objs)
+ if not has_string_cols:
+
+ def _col_attached(column, table):
+ # this isinstance() corresponds with the
+ # isinstance() above; only want to count Table-bound
+ # columns
+ if isinstance(table, Table):
+ cols_wo_table.discard(column)
+ if not cols_wo_table:
+ self._check_attach(evt=True)
+
+ self._cols_wo_table = cols_wo_table
+ for col in cols_wo_table:
+ col._on_table_attach(_col_attached)
+ return
+
+ columns = cols_w_table
+
+ tables = {c.table for c in columns}
+ if len(tables) == 1:
+ self._set_parent_with_dispatch(tables.pop())
+ elif len(tables) > 1 and not self._allow_multiple_tables:
+ table = columns[0].table
+ others = [c for c in columns[1:] if c.table is not table]
+ if others:
+ raise exc.ArgumentError(
+ "Column(s) %s are not part of table '%s'."
+ % (
+ ", ".join("'%s'" % c for c in others),
+ table.description,
+ )
+ )
+
+ def _col_expressions(self, table):
+ return [
+ table.c[col] if isinstance(col, util.string_types) else col
+ for col in self._pending_colargs
+ ]
+
+ def _set_parent(self, table, **kw):
+ for col in self._col_expressions(table):
+ if col is not None:
+ self.columns.add(col)
+
+
+class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
+ """A constraint that proxies a ColumnCollection."""
+
+ def __init__(self, *columns, **kw):
+ r"""
+ :param \*columns:
+ A sequence of column names or Column objects.
+
+ :param name:
+ Optional, the in-database name of this constraint.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param \**kw: other keyword arguments including dialect-specific
+ arguments are propagated to the :class:`.Constraint` superclass.
+
+ """
+ _autoattach = kw.pop("_autoattach", True)
+ _column_flag = kw.pop("_column_flag", False)
+ Constraint.__init__(self, **kw)
+ ColumnCollectionMixin.__init__(
+ self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
+ )
+
+ columns = None
+ """A :class:`_expression.ColumnCollection` representing the set of columns
+ for this constraint.
+
+ """
+
+ def _set_parent(self, table, **kw):
+ Constraint._set_parent(self, table)
+ ColumnCollectionMixin._set_parent(self, table)
+
+ def __contains__(self, x):
+ return x in self.columns
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.ColumnCollectionConstraint.copy` method "
+ "is deprecated and will be removed in a future release.",
+ )
+ def copy(self, target_table=None, **kw):
+ return self._copy(target_table=target_table, **kw)
+
+ def _copy(self, target_table=None, **kw):
+ # ticket #5276
+ constraint_kwargs = {}
+ for dialect_name in self.dialect_options:
+ dialect_options = self.dialect_options[dialect_name]._non_defaults
+ for (
+ dialect_option_key,
+ dialect_option_value,
+ ) in dialect_options.items():
+ constraint_kwargs[
+ dialect_name + "_" + dialect_option_key
+ ] = dialect_option_value
+
+ c = self.__class__(
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ *[
+ _copy_expression(expr, self.parent, target_table)
+ for expr in self.columns
+ ],
+ **constraint_kwargs
+ )
+ return self._schema_item_copy(c)
+
+ def contains_column(self, col):
+ """Return True if this constraint contains the given column.
+
+ Note that this object also contains an attribute ``.columns``
+ which is a :class:`_expression.ColumnCollection` of
+ :class:`_schema.Column` objects.
+
+ """
+
+ return self.columns.contains_column(col)
+
+ def __iter__(self):
+ return iter(self.columns)
+
+ def __len__(self):
+ return len(self.columns)
+
+
+class CheckConstraint(ColumnCollectionConstraint):
+ """A table- or column-level CHECK constraint.
+
+ Can be included in the definition of a Table or Column.
+ """
+
+ _allow_multiple_tables = True
+
+ __visit_name__ = "table_or_column_check_constraint"
+
+ @_document_text_coercion(
+ "sqltext",
+ ":class:`.CheckConstraint`",
+ ":paramref:`.CheckConstraint.sqltext`",
+ )
+ def __init__(
+ self,
+ sqltext,
+ name=None,
+ deferrable=None,
+ initially=None,
+ table=None,
+ info=None,
+ _create_rule=None,
+ _autoattach=True,
+ _type_bound=False,
+ **kw
+ ):
+ r"""Construct a CHECK constraint.
+
+ :param sqltext:
+ A string containing the constraint definition, which will be used
+ verbatim, or a SQL expression construct. If given as a string,
+ the object is converted to a :func:`_expression.text` object.
+ If the textual
+ string includes a colon character, escape this using a backslash::
+
+ CheckConstraint(r"foo ~ E'a(?\:b|c)d'")
+
+ :param name:
+ Optional, the in-database name of the constraint.
+
+ :param deferrable:
+ Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
+ issuing DDL for this constraint.
+
+ :param initially:
+ Optional string. If set, emit INITIALLY <value> when issuing DDL
+ for this constraint.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
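+ As a brief sketch, a table-level CHECK constraint with an explicit
+ name (the table, column and constraint names are illustrative)::
+
+ # illustrative table, column and constraint names
+ Table('accounts', metadata,
+ Column('balance', Integer),
+ CheckConstraint('balance >= 0', name='ck_accounts_balance')
+ )
+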
+ """
+
+ self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
+ columns = []
+ visitors.traverse(self.sqltext, {}, {"column": columns.append})
+
+ super(CheckConstraint, self).__init__(
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ _create_rule=_create_rule,
+ info=info,
+ _type_bound=_type_bound,
+ _autoattach=_autoattach,
+ *columns,
+ **kw
+ )
+ if table is not None:
+ self._set_parent_with_dispatch(table)
+
+ @property
+ def is_column_level(self):
+ return not isinstance(self.parent, Table)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.CheckConstraint.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, target_table=None, **kw):
+ return self._copy(target_table=target_table, **kw)
+
+ def _copy(self, target_table=None, **kw):
+ if target_table is not None:
+ # note that target_table is None for the copy process of
+ # a column-bound CheckConstraint, so this path is not reached
+ # in that case.
+ sqltext = _copy_expression(self.sqltext, self.table, target_table)
+ else:
+ sqltext = self.sqltext
+ c = CheckConstraint(
+ sqltext,
+ name=self.name,
+ initially=self.initially,
+ deferrable=self.deferrable,
+ _create_rule=self._create_rule,
+ table=target_table,
+ _autoattach=False,
+ _type_bound=self._type_bound,
+ )
+ return self._schema_item_copy(c)
+
+
+class ForeignKeyConstraint(ColumnCollectionConstraint):
+ """A table-level FOREIGN KEY constraint.
+
+ Defines a single column or composite FOREIGN KEY ... REFERENCES
+ constraint. For a no-frills, single column foreign key, adding a
+ :class:`_schema.ForeignKey` to the definition of a :class:`_schema.Column`
+ is a
+ shorthand equivalent for an unnamed, single column
+ :class:`_schema.ForeignKeyConstraint`.
+
+ Examples of foreign key configuration are in :ref:`metadata_foreignkeys`.
+
+ """
+
+ __visit_name__ = "foreign_key_constraint"
+
+ def __init__(
+ self,
+ columns,
+ refcolumns,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ use_alter=False,
+ link_to_name=False,
+ match=None,
+ table=None,
+ info=None,
+ **dialect_kw
+ ):
+ r"""Construct a composite-capable FOREIGN KEY.
+
+ :param columns: A sequence of local column names. The named columns
+ must be defined and present in the parent Table. The names should
+ match the ``key`` given to each column (defaults to the name) unless
+ ``link_to_name`` is True.
+
+ :param refcolumns: A sequence of foreign column names or Column
+ objects. The columns must all be located within the same Table.
+
+ :param name: Optional, the in-database name of the key.
+
+ :param onupdate: Optional string. If set, emit ON UPDATE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+ SET NULL and RESTRICT.
+
+ :param ondelete: Optional string. If set, emit ON DELETE <value> when
+ issuing DDL for this constraint. Typical values include CASCADE,
+ SET NULL and RESTRICT.
+
+ :param deferrable: Optional bool. If set, emit DEFERRABLE or NOT
+ DEFERRABLE when issuing DDL for this constraint.
+
+ :param initially: Optional string. If set, emit INITIALLY <value> when
+ issuing DDL for this constraint.
+
+ :param link_to_name: if True, the string name given in ``column`` is
+ the rendered name of the referenced column, not its locally assigned
+ ``key``.
+
+ :param use_alter: If True, do not emit the DDL for this constraint as
+ part of the CREATE TABLE definition. Instead, generate it via an
+ ALTER TABLE statement issued after the full collection of tables
+ have been created, and drop it via an ALTER TABLE statement before
+ the full collection of tables are dropped.
+
+ The use of :paramref:`_schema.ForeignKeyConstraint.use_alter` is
+ particularly geared towards the case where two or more tables
+ are established within a mutually-dependent foreign key constraint
+ relationship; however, the :meth:`_schema.MetaData.create_all` and
+ :meth:`_schema.MetaData.drop_all`
+ methods will perform this resolution
+ automatically, so the flag is normally not needed.
+
+ .. versionchanged:: 1.0.0 Automatic resolution of foreign key
+ cycles has been added, removing the need to use the
+ :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter in
+ typical use cases.
+
+ .. seealso::
+
+ :ref:`use_alter`
+
+ :param match: Optional string. If set, emit MATCH <value> when issuing
+ DDL for this constraint. Typical values include SIMPLE, PARTIAL
+ and FULL.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**dialect_kw: Additional keyword arguments are dialect
+ specific, and passed in the form ``<dialectname>_<argname>``. See
+ the documentation regarding an individual dialect at
+ :ref:`dialect_toplevel` for detail on documented arguments.
+
+ .. versionadded:: 0.9.2
+
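+ As a brief sketch, a composite FOREIGN KEY constraint declared
+ inline on a table (all table, column and constraint names are
+ illustrative)::
+
+ # illustrative names; 'invoice' is the referenced table
+ Table('invoice_item', metadata,
+ Column('invoice_id', Integer),
+ Column('ref_num', Integer),
+ ForeignKeyConstraint(
+ ['invoice_id', 'ref_num'],
+ ['invoice.invoice_id', 'invoice.ref_num'],
+ name='fk_invoice_item_invoice'
+ )
+ )
+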
+ """
+
+ Constraint.__init__(
+ self,
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ info=info,
+ **dialect_kw
+ )
+ self.onupdate = onupdate
+ self.ondelete = ondelete
+ self.link_to_name = link_to_name
+ self.use_alter = use_alter
+ self.match = match
+
+ if len(set(columns)) != len(refcolumns):
+ if len(set(columns)) != len(columns):
+ # e.g. FOREIGN KEY (a, a) REFERENCES r (b, c)
+ raise exc.ArgumentError(
+ "ForeignKeyConstraint with duplicate source column "
+ "references are not supported."
+ )
+ else:
+ # e.g. FOREIGN KEY (a) REFERENCES r (b, c)
+ # paraphrasing
+ # https://www.postgresql.org/docs/current/static/ddl-constraints.html
+ raise exc.ArgumentError(
+ "ForeignKeyConstraint number "
+ "of constrained columns must match the number of "
+ "referenced columns."
+ )
+
+ # standalone ForeignKeyConstraint - create
+ # associated ForeignKey objects which will be applied to hosted
+ # Column objects (in col.foreign_keys), either now or when attached
+ # to the Table for string-specified names
+ self.elements = [
+ ForeignKey(
+ refcol,
+ _constraint=self,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ use_alter=self.use_alter,
+ link_to_name=self.link_to_name,
+ match=self.match,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ **self.dialect_kwargs
+ )
+ for refcol in refcolumns
+ ]
+
+ ColumnCollectionMixin.__init__(self, *columns)
+ if table is not None:
+ if hasattr(self, "parent"):
+ assert table is self.parent
+ self._set_parent_with_dispatch(table)
+
+ def _append_element(self, column, fk):
+ self.columns.add(column)
+ self.elements.append(fk)
+
+ columns = None
+ """A :class:`_expression.ColumnCollection` representing the set of columns
+ for this constraint.
+
+ """
+
+ elements = None
+ """A sequence of :class:`_schema.ForeignKey` objects.
+
+ Each :class:`_schema.ForeignKey`
+ represents a single referring column/referred
+ column pair.
+
+ This collection is intended to be read-only.
+
+ """
+
+ @property
+ def _elements(self):
+ # legacy - provide a dictionary view of (column_key, fk)
+ return util.OrderedDict(zip(self.column_keys, self.elements))
+
+ @property
+ def _referred_schema(self):
+ # return the referred schema of the first element, if any
+ for elem in self.elements:
+ return elem._referred_schema
+ return None
+
+ @property
+ def referred_table(self):
+ """The :class:`_schema.Table` object to which this
+ :class:`_schema.ForeignKeyConstraint` refers.
+
+ This is a dynamically calculated attribute which may not be available
+ if the constraint and/or parent table is not yet associated with
+ a metadata collection that contains the referred table.
+
+ .. versionadded:: 1.0.0
+
+ """
+ return self.elements[0].column.table
+
+ def _validate_dest_table(self, table):
+ table_keys = set([elem._table_key() for elem in self.elements])
+ if None not in table_keys and len(table_keys) > 1:
+ elem0, elem1 = sorted(table_keys)[0:2]
+ raise exc.ArgumentError(
+ "ForeignKeyConstraint on %s(%s) refers to "
+ "multiple remote tables: %s and %s"
+ % (table.fullname, self._col_description, elem0, elem1)
+ )
+
+ @property
+ def column_keys(self):
+ """Return a list of string keys representing the local
+ columns in this :class:`_schema.ForeignKeyConstraint`.
+
+ This list is either the original string arguments sent
+ to the constructor of the :class:`_schema.ForeignKeyConstraint`,
+ or if the constraint has been initialized with :class:`_schema.Column`
+ objects, is the string ``.key`` of each element.
+
+ .. versionadded:: 1.0.0
+
+ """
+ if hasattr(self, "parent"):
+ return self.columns.keys()
+ else:
+ return [
+ col.key if isinstance(col, ColumnElement) else str(col)
+ for col in self._pending_colargs
+ ]
+
+ @property
+ def _col_description(self):
+ return ", ".join(self.column_keys)
+
+ def _set_parent(self, table, **kw):
+ Constraint._set_parent(self, table)
+
+ try:
+ ColumnCollectionConstraint._set_parent(self, table)
+ except KeyError as ke:
+ util.raise_(
+ exc.ArgumentError(
+ "Can't create ForeignKeyConstraint "
+ "on table '%s': no column "
+ "named '%s' is present." % (table.description, ke.args[0])
+ ),
+ from_=ke,
+ )
+
+ for col, fk in zip(self.columns, self.elements):
+ if not hasattr(fk, "parent") or fk.parent is not col:
+ fk._set_parent_with_dispatch(col)
+
+ self._validate_dest_table(table)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.ForeignKeyConstraint.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, schema=None, target_table=None, **kw):
+ return self._copy(schema=schema, target_table=target_table, **kw)
+
+ def _copy(self, schema=None, target_table=None, **kw):
+ fkc = ForeignKeyConstraint(
+ [x.parent.key for x in self.elements],
+ [
+ x._get_colspec(
+ schema=schema,
+ table_name=target_table.name
+ if target_table is not None
+ and x._table_key() == x.parent.table.key
+ else None,
+ )
+ for x in self.elements
+ ],
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ use_alter=self.use_alter,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ link_to_name=self.link_to_name,
+ match=self.match,
+ )
+ for self_fk, other_fk in zip(self.elements, fkc.elements):
+ self_fk._schema_item_copy(other_fk)
+ return self._schema_item_copy(fkc)
+
+
+class PrimaryKeyConstraint(ColumnCollectionConstraint):
+ """A table-level PRIMARY KEY constraint.
+
+ The :class:`.PrimaryKeyConstraint` object is present automatically
+ on any :class:`_schema.Table` object; it is assigned a set of
+ :class:`_schema.Column` objects corresponding to those marked with
+ the :paramref:`_schema.Column.primary_key` flag::
+
+ >>> my_table = Table('mytable', metadata,
+ ... Column('id', Integer, primary_key=True),
+ ... Column('version_id', Integer, primary_key=True),
+ ... Column('data', String(50))
+ ... )
+ >>> my_table.primary_key
+ PrimaryKeyConstraint(
+ Column('id', Integer(), table=<mytable>,
+ primary_key=True, nullable=False),
+ Column('version_id', Integer(), table=<mytable>,
+ primary_key=True, nullable=False)
+ )
+
+ The primary key of a :class:`_schema.Table` can also be specified by using
+ a :class:`.PrimaryKeyConstraint` object explicitly; in this mode of usage,
+ the "name" of the constraint can also be specified, as well as other
+ options which may be recognized by dialects::
+
+ my_table = Table('mytable', metadata,
+ Column('id', Integer),
+ Column('version_id', Integer),
+ Column('data', String(50)),
+ PrimaryKeyConstraint('id', 'version_id',
+ name='mytable_pk')
+ )
+
+ The two styles of column-specification should generally not be mixed.
+ A warning is emitted if the columns present in the
+ :class:`.PrimaryKeyConstraint`
+ don't match the columns that were marked as ``primary_key=True``, if both
+ are present; in this case, the columns are taken strictly from the
+ :class:`.PrimaryKeyConstraint` declaration, and those columns otherwise
+ marked as ``primary_key=True`` are ignored. This behavior is intended to
+ be backwards compatible with previous behavior.
+
+ .. versionchanged:: 0.9.2 Using a mixture of columns within a
+ :class:`.PrimaryKeyConstraint` in addition to columns marked as
+ ``primary_key=True`` now emits a warning if the lists don't match.
+ The ultimate behavior of ignoring those columns marked with the flag
+ only is currently maintained for backwards compatibility; this warning
+ may raise an exception in a future release.
+
+ For the use case where specific options are to be specified on the
+ :class:`.PrimaryKeyConstraint`, but the usual style of using
+ ``primary_key=True`` flags is still desirable, an empty
+ :class:`.PrimaryKeyConstraint` may be specified, which will take on the
+ primary key column collection from the :class:`_schema.Table` based on the
+ flags::
+
+ my_table = Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('version_id', Integer, primary_key=True),
+ Column('data', String(50)),
+ PrimaryKeyConstraint(name='mytable_pk',
+ mssql_clustered=True)
+ )
+
+ .. versionadded:: 0.9.2 an empty :class:`.PrimaryKeyConstraint` may now
+ be specified for the purposes of establishing keyword arguments with
+ the constraint, independently of the specification of "primary key"
+ columns within the :class:`_schema.Table` itself; columns marked as
+ ``primary_key=True`` will be gathered into the empty constraint's
+ column collection.
+
+ """
+
+ __visit_name__ = "primary_key_constraint"
+
+ def __init__(self, *columns, **kw):
+ self._implicit_generated = kw.pop("_implicit_generated", False)
+ super(PrimaryKeyConstraint, self).__init__(*columns, **kw)
+
+ def _set_parent(self, table, **kw):
+ super(PrimaryKeyConstraint, self)._set_parent(table)
+
+ if table.primary_key is not self:
+ table.constraints.discard(table.primary_key)
+ table.primary_key = self
+ table.constraints.add(self)
+
+ table_pks = [c for c in table.c if c.primary_key]
+ if self.columns and table_pks and set(table_pks) != set(self.columns):
+ util.warn(
+ "Table '%s' specifies columns %s as primary_key=True, "
+ "not matching locally specified columns %s; setting the "
+ "current primary key columns to %s. This warning "
+ "may become an exception in a future release"
+ % (
+ table.name,
+ ", ".join("'%s'" % c.name for c in table_pks),
+ ", ".join("'%s'" % c.name for c in self.columns),
+ ", ".join("'%s'" % c.name for c in self.columns),
+ )
+ )
+ table_pks[:] = []
+
+ for c in self.columns:
+ c.primary_key = True
+ if c._user_defined_nullable is NULL_UNSPECIFIED:
+ c.nullable = False
+ if table_pks:
+ self.columns.extend(table_pks)
+
+ def _reload(self, columns):
+ """repopulate this :class:`.PrimaryKeyConstraint` given
+ a set of columns.
+
+ Existing columns in the table that are marked as primary_key=True
+ are maintained.
+
+ Also fires a new event.
+
+ This is basically like putting a whole new
+ :class:`.PrimaryKeyConstraint` object on the parent
+ :class:`_schema.Table` object without actually replacing the object.
+
+ The ordering of the given list of columns is also maintained; these
+ columns will be appended to the list of columns after any which
+ are already present.
+
+ """
+ # set the primary key flag on new columns.
+ # note any existing PK cols on the table also have their
+ # flag still set.
+ for col in columns:
+ col.primary_key = True
+
+ self.columns.extend(columns)
+
+ PrimaryKeyConstraint._autoincrement_column._reset(self)
+ self._set_parent_with_dispatch(self.table)
+
+ def _replace(self, col):
+ PrimaryKeyConstraint._autoincrement_column._reset(self)
+ self.columns.replace(col)
+
+ self.dispatch._sa_event_column_added_to_pk_constraint(self, col)
+
+ @property
+ def columns_autoinc_first(self):
+ autoinc = self._autoincrement_column
+
+ if autoinc is not None:
+ return [autoinc] + [c for c in self.columns if c is not autoinc]
+ else:
+ return list(self.columns)
+
+ @util.memoized_property
+ def _autoincrement_column(self):
+ def _validate_autoinc(col, autoinc_true):
+ if col.type._type_affinity is None or not issubclass(
+ col.type._type_affinity,
+ (
+ type_api.INTEGERTYPE._type_affinity,
+ type_api.NUMERICTYPE._type_affinity,
+ ),
+ ):
+ if autoinc_true:
+ raise exc.ArgumentError(
+ "Column type %s on column '%s' is not "
+ "compatible with autoincrement=True" % (col.type, col)
+ )
+ else:
+ return False
+ elif (
+ not isinstance(col.default, (type(None), Sequence))
+ and not autoinc_true
+ ):
+ return False
+ elif (
+ col.server_default is not None
+ and not isinstance(col.server_default, Identity)
+ and not autoinc_true
+ ):
+ return False
+ elif col.foreign_keys and col.autoincrement not in (
+ True,
+ "ignore_fk",
+ ):
+ return False
+ return True
+
+ if len(self.columns) == 1:
+ col = list(self.columns)[0]
+
+ if col.autoincrement is True:
+ _validate_autoinc(col, True)
+ return col
+ elif (
+ col.autoincrement
+ in (
+ "auto",
+ "ignore_fk",
+ )
+ and _validate_autoinc(col, False)
+ ):
+ return col
+
+ else:
+ autoinc = None
+ for col in self.columns:
+ if col.autoincrement is True:
+ _validate_autoinc(col, True)
+ if autoinc is not None:
+ raise exc.ArgumentError(
+ "Only one Column may be marked "
+ "autoincrement=True, found both %s and %s."
+ % (col.name, autoinc.name)
+ )
+ else:
+ autoinc = col
+
+ return autoinc
+
+
+class UniqueConstraint(ColumnCollectionConstraint):
+ """A table-level UNIQUE constraint.
+
+ Defines a single column or composite UNIQUE constraint. For a no-frills,
+ single column constraint, adding ``unique=True`` to the ``Column``
+ definition is a shorthand equivalent for an unnamed, single column
+ UniqueConstraint.
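+
+ As a brief sketch, a named composite UNIQUE constraint (table,
+ column and constraint names are illustrative)::
+
+ # illustrative table, column and constraint names
+ Table('user', metadata,
+ Column('name', String(30)),
+ Column('email', String(100)),
+ UniqueConstraint('name', 'email', name='uq_user_name_email')
+ )
+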
+ """
+
+ __visit_name__ = "unique_constraint"
+
+
+class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
+ """A table-level INDEX.
+
+ Defines a composite (one or more column) INDEX.
+
+ E.g.::
+
+ sometable = Table("sometable", metadata,
+ Column("name", String(50)),
+ Column("address", String(100))
+ )
+
+ Index("some_index", sometable.c.name)
+
+ For a no-frills, single column index, the
+ :class:`_schema.Column` construct also supports ``index=True``::
+
+ sometable = Table("sometable", metadata,
+ Column("name", String(50), index=True)
+ )
+
+ For a composite index, multiple columns can be specified::
+
+ Index("some_index", sometable.c.name, sometable.c.address)
+
+ Functional indexes are supported as well, typically by using the
+ :data:`.func` construct in conjunction with table-bound
+ :class:`_schema.Column` objects::
+
+ Index("some_index", func.lower(sometable.c.name))
+
+ An :class:`.Index` can also be manually associated with a
+ :class:`_schema.Table`,
+ either through inline declaration or using
+ :meth:`_schema.Table.append_constraint`. When this approach is used,
+ the names
+ of the indexed columns can be specified as strings::
+
+ Table("sometable", metadata,
+ Column("name", String(50)),
+ Column("address", String(100)),
+ Index("some_index", "name", "address")
+ )
+
+ To support functional or expression-based indexes in this form, the
+ :func:`_expression.text` construct may be used::
+
+ from sqlalchemy import text
+
+ Table("sometable", metadata,
+ Column("name", String(50)),
+ Column("address", String(100)),
+ Index("some_index", text("lower(name)"))
+ )
+
+ .. versionadded:: 0.9.5 the :func:`_expression.text`
+ construct may be used to
+ specify :class:`.Index` expressions, provided the :class:`.Index`
+ is explicitly associated with the :class:`_schema.Table`.
+
+
+ .. seealso::
+
+ :ref:`schema_indexes` - General information on :class:`.Index`.
+
+ :ref:`postgresql_indexes` - PostgreSQL-specific options available for
+ the :class:`.Index` construct.
+
+ :ref:`mysql_indexes` - MySQL-specific options available for the
+ :class:`.Index` construct.
+
+ :ref:`mssql_indexes` - MSSQL-specific options available for the
+ :class:`.Index` construct.
+
+ """
+
+ __visit_name__ = "index"
+
+ def __init__(self, name, *expressions, **kw):
+ r"""Construct an index object.
+
+ :param name:
+ The name of the index
+
+ :param \*expressions:
+ Column expressions to include in the index. The expressions
+ are normally instances of :class:`_schema.Column`, but may also
+ be arbitrary SQL expressions which ultimately refer to a
+ :class:`_schema.Column`.
+
+ :param unique=False:
+ Keyword only argument; if True, create a unique index.
+
+ :param quote=None:
+ Keyword only argument; whether to apply quoting to the name of
+ the index. Works in the same manner as that of
+ :paramref:`_schema.Column.quote`.
+
+ :param info=None: Optional data dictionary which will be populated
+ into the :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param \**kw: Additional keyword arguments not mentioned above are
+ dialect specific, and passed in the form
+ ``<dialectname>_<argname>``. See the documentation regarding an
+ individual dialect at :ref:`dialect_toplevel` for detail on
+ documented arguments.
+
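+ As a brief sketch, a unique index combining the keyword arguments
+ above (the index and table names are illustrative)::
+
+ # illustrative index name; 'sometable' as in the class-level examples
+ Index('uq_sometable_name', sometable.c.name, unique=True)
+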
+ """
+ self.table = table = None
+
+ self.name = quoted_name(name, kw.pop("quote", None))
+ self.unique = kw.pop("unique", False)
+ _column_flag = kw.pop("_column_flag", False)
+ if "info" in kw:
+ self.info = kw.pop("info")
+
+ # TODO: consider "table" argument being public, but for
+ # the purpose of the fix here, it starts as private.
+ if "_table" in kw:
+ table = kw.pop("_table")
+
+ self._validate_dialect_kwargs(kw)
+
+ self.expressions = []
+ # will call _set_parent() if table-bound column
+ # objects are present
+ ColumnCollectionMixin.__init__(
+ self,
+ *expressions,
+ _column_flag=_column_flag,
+ _gather_expressions=self.expressions
+ )
+
+ if table is not None:
+ self._set_parent(table)
+
+ def _set_parent(self, table, **kw):
+ ColumnCollectionMixin._set_parent(self, table)
+
+ if self.table is not None and table is not self.table:
+ raise exc.ArgumentError(
+ "Index '%s' is against table '%s', and "
+ "cannot be associated with table '%s'."
+ % (self.name, self.table.description, table.description)
+ )
+ self.table = table
+ table.indexes.add(self)
+
+ expressions = self.expressions
+ col_expressions = self._col_expressions(table)
+ assert len(expressions) == len(col_expressions)
+ self.expressions = [
+ expr if isinstance(expr, ClauseElement) else colexpr
+ for expr, colexpr in zip(expressions, col_expressions)
+ ]
+
+ @property
+ def bind(self):
+ """Return the connectable associated with this Index."""
+
+ return self.table.bind
+
+ def create(self, bind=None, checkfirst=False):
+ """Issue a ``CREATE`` statement for this
+ :class:`.Index`, using the given :class:`.Connectable`
+ for connectivity.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.create_all`.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
+ return self
+
+ def drop(self, bind=None, checkfirst=False):
+ """Issue a ``DROP`` statement for this
+ :class:`.Index`, using the given :class:`.Connectable`
+ for connectivity.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :meth:`_schema.MetaData.drop_all`.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ def __repr__(self):
+ return "Index(%s)" % (
+ ", ".join(
+ [repr(self.name)]
+ + [repr(e) for e in self.expressions]
+ + (self.unique and ["unique=True"] or [])
+ )
+ )
+
+
+DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"})
+
+
+class MetaData(SchemaItem):
+ """A collection of :class:`_schema.Table`
+ objects and their associated schema
+ constructs.
+
+ Holds a collection of :class:`_schema.Table` objects as well as
+ an optional binding to an :class:`_engine.Engine` or
+ :class:`_engine.Connection`. If bound, the :class:`_schema.Table` objects
+ in the collection and their columns may participate in implicit SQL
+ execution.
+
+ The :class:`_schema.Table` objects themselves are stored in the
+ :attr:`_schema.MetaData.tables` dictionary.
+
+ :class:`_schema.MetaData` is a thread-safe object for read operations.
+ Construction of new tables within a single :class:`_schema.MetaData`
+ object,
+ either explicitly or via reflection, may not be completely thread-safe.
+
+ .. seealso::
+
+ :ref:`metadata_describing` - Introduction to database metadata
+
+ """
+
+ __visit_name__ = "metadata"
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_schema.MetaData.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ bind=None,
+ schema=None,
+ quote_schema=None,
+ naming_convention=None,
+ info=None,
+ ):
+ """Create a new MetaData object.
+
+ :param bind:
+ An Engine or Connection to bind to. May also be a string or URL
+ instance, these are passed to :func:`_sa.create_engine` and
+ this :class:`_schema.MetaData` will
+ be bound to the resulting engine.
+
+ :param schema:
+ The default schema to use for the :class:`_schema.Table`,
+ :class:`.Sequence`, and potentially other objects associated with
+ this :class:`_schema.MetaData`. Defaults to ``None``.
+
+ .. seealso::
+
+ :ref:`schema_metadata_schema_name` - details on how the
+ :paramref:`_schema.MetaData.schema` parameter is used.
+
+ :paramref:`_schema.Table.schema`
+
+ :paramref:`.Sequence.schema`
+
+ :param quote_schema:
+ Sets the ``quote_schema`` flag for those :class:`_schema.Table`,
+ :class:`.Sequence`, and other objects which make usage of the
+ local ``schema`` name.
+
+ :param info: Optional data dictionary which will be populated into the
+ :attr:`.SchemaItem.info` attribute of this object.
+
+ .. versionadded:: 1.0.0
+
+ :param naming_convention: a dictionary referring to values which
+ will establish default naming conventions for :class:`.Constraint`
+ and :class:`.Index` objects, for those objects which are not given
+ a name explicitly.
+
+ The keys of this dictionary may be:
+
+ * a constraint or Index class, e.g. the :class:`.UniqueConstraint`,
+ :class:`_schema.ForeignKeyConstraint` class, the :class:`.Index`
+ class
+
+ * a string mnemonic for one of the known constraint classes;
+ ``"fk"``, ``"pk"``, ``"ix"``, ``"ck"``, ``"uq"`` for foreign key,
+ primary key, index, check, and unique constraint, respectively.
+
+ * the string name of a user-defined "token" that can be used
+ to define new naming tokens.
+
+ The values associated with each "constraint class" or "constraint
+ mnemonic" key are string naming templates, such as
+ ``"uq_%(table_name)s_%(column_0_name)s"``,
+ which describe how the name should be composed. The values
+ associated with user-defined "token" keys should be callables of the
+ form ``fn(constraint, table)``, which accepts the constraint/index
+ object and :class:`_schema.Table` as arguments, returning a string
+ result.
+
+ The built-in names are as follows, some of which may only be
+ available for certain types of constraint:
+
+ * ``%(table_name)s`` - the name of the :class:`_schema.Table`
+ object
+ associated with the constraint.
+
+ * ``%(referred_table_name)s`` - the name of the
+ :class:`_schema.Table`
+ object associated with the referencing target of a
+ :class:`_schema.ForeignKeyConstraint`.
+
+ * ``%(column_0_name)s`` - the name of the :class:`_schema.Column`
+ at
+ index position "0" within the constraint.
+
+ * ``%(column_0N_name)s`` - the name of all :class:`_schema.Column`
+ objects in order within the constraint, joined without a
+ separator.
+
+ * ``%(column_0_N_name)s`` - the name of all
+ :class:`_schema.Column`
+ objects in order within the constraint, joined with an
+ underscore as a separator.
+
+ * ``%(column_0_label)s``, ``%(column_0N_label)s``,
+ ``%(column_0_N_label)s`` - the label of either the zeroth
+ :class:`_schema.Column` or all :class:`.Columns`, separated with
+ or without an underscore
+
+ * ``%(column_0_key)s``, ``%(column_0N_key)s``,
+ ``%(column_0_N_key)s`` - the key of either the zeroth
+ :class:`_schema.Column` or all :class:`.Columns`, separated with
+ or without an underscore
+
+ * ``%(referred_column_0_name)s``, ``%(referred_column_0N_name)s``
+ ``%(referred_column_0_N_name)s``, ``%(referred_column_0_key)s``,
+ ``%(referred_column_0N_key)s``, ... column tokens which
+ render the names/keys/labels of columns that are referenced
+ by a :class:`_schema.ForeignKeyConstraint`.
+
+ * ``%(constraint_name)s`` - a special key that refers to the
+ existing name given to the constraint. When this key is
+ present, the :class:`.Constraint` object's existing name will be
+ replaced with one that is composed from the template string that
+ uses this token. When this token is present, it is required that
+ the :class:`.Constraint` is given an explicit name ahead of time.
+
+ * user-defined: any additional token may be implemented by passing
+ it along with a ``fn(constraint, table)`` callable to the
+ naming_convention dictionary.
+
+ .. versionadded:: 1.3.0 - added new ``%(column_0N_name)s``,
+ ``%(column_0_N_name)s``, and related tokens that produce
+ concatenations of names, keys, or labels for all columns referred
+ to by a given constraint.
+
+ .. seealso::
+
+ :ref:`constraint_naming_conventions` - for detailed usage
+ examples.
+
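+ As a brief sketch, a naming convention covering the common
+ constraint mnemonics; the template strings here are illustrative
+ and may be adjusted per project::
+
+ # illustrative convention; "ck" requires constraints to be named
+ metadata = MetaData(naming_convention={
+ "ix": "ix_%(column_0_label)s",
+ "uq": "uq_%(table_name)s_%(column_0_name)s",
+ "ck": "ck_%(table_name)s_%(constraint_name)s",
+ "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ "pk": "pk_%(table_name)s"
+ })
+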
+ """
+ self.tables = util.FacadeDict()
+ self.schema = quoted_name(schema, quote_schema)
+ self.naming_convention = (
+ naming_convention
+ if naming_convention
+ else DEFAULT_NAMING_CONVENTION
+ )
+ if info:
+ self.info = info
+ self._schemas = set()
+ self._sequences = {}
+ self._fk_memos = collections.defaultdict(list)
+
+ self.bind = bind
+
+ tables = None
+ """A dictionary of :class:`_schema.Table`
+ objects keyed to their name or "table key".
+
+ The exact key is that determined by the :attr:`_schema.Table.key`
+ attribute;
+ for a table with no :attr:`_schema.Table.schema` attribute,
+ this is the same
+ as :attr:`_schema.Table.name`. For a table with a schema,
+ it is typically of the
+ form ``schemaname.tablename``.
+
+ .. seealso::
+
+ :attr:`_schema.MetaData.sorted_tables`
+
+ """
+
+ def __repr__(self):
+ if self.bind:
+ return "MetaData(bind=%r)" % self.bind
+ else:
+ return "MetaData()"
+
+ def __contains__(self, table_or_key):
+ if not isinstance(table_or_key, util.string_types):
+ table_or_key = table_or_key.key
+ return table_or_key in self.tables
+
+ def _add_table(self, name, schema, table):
+ key = _get_table_key(name, schema)
+ self.tables._insert_item(key, table)
+ if schema:
+ self._schemas.add(schema)
+
+ def _remove_table(self, name, schema):
+ key = _get_table_key(name, schema)
+ removed = dict.pop(self.tables, key, None)
+ if removed is not None:
+ for fk in removed.foreign_keys:
+ fk._remove_from_metadata(self)
+ if self._schemas:
+ self._schemas = set(
+ [
+ t.schema
+ for t in self.tables.values()
+ if t.schema is not None
+ ]
+ )
+
+ def __getstate__(self):
+ return {
+ "tables": self.tables,
+ "schema": self.schema,
+ "schemas": self._schemas,
+ "sequences": self._sequences,
+ "fk_memos": self._fk_memos,
+ "naming_convention": self.naming_convention,
+ }
+
+ def __setstate__(self, state):
+ self.tables = state["tables"]
+ self.schema = state["schema"]
+ self.naming_convention = state["naming_convention"]
+ self._bind = None
+ self._sequences = state["sequences"]
+ self._schemas = state["schemas"]
+ self._fk_memos = state["fk_memos"]
+
+ def is_bound(self):
+ """True if this MetaData is bound to an Engine or Connection."""
+
+ return self._bind is not None
+
+ def bind(self):
+ """An :class:`_engine.Engine` or :class:`_engine.Connection`
+ to which this
+ :class:`_schema.MetaData` is bound.
+
+ Typically, a :class:`_engine.Engine` is assigned to this attribute
+ so that "implicit execution" may be used, or alternatively
+ as a means of providing engine binding information to an
+ ORM :class:`.Session` object::
+
+ engine = create_engine("someurl://")
+ metadata.bind = engine
+
+ .. deprecated:: 1.4
+
+ The metadata.bind attribute, as part of the deprecated system
+ of "implicit execution", is itself deprecated and will be
+ removed in SQLAlchemy 2.0.
+
+ .. seealso::
+
+ :ref:`dbengine_implicit` - background on "bound metadata"
+
+ """
+ return self._bind
+
+ @util.preload_module("sqlalchemy.engine.url")
+ def _bind_to(self, bind):
+ """Bind this MetaData to an Engine, Connection, string or URL."""
+ url = util.preloaded.engine_url
+ if isinstance(bind, util.string_types + (url.URL,)):
+ self._bind = sqlalchemy.create_engine(bind)
+ else:
+ self._bind = bind
+
+ bind = property(bind, _bind_to)
+
+ def clear(self):
+ """Clear all Table objects from this MetaData."""
+
+ dict.clear(self.tables)
+ self._schemas.clear()
+ self._fk_memos.clear()
+
+ def remove(self, table):
+ """Remove the given Table object from this MetaData."""
+
+ self._remove_table(table.name, table.schema)
+
+ @property
+ def sorted_tables(self):
+ """Returns a list of :class:`_schema.Table` objects sorted in order of
+ foreign key dependency.
+
+ The sorting will place the :class:`_schema.Table`
+ objects that are depended upon first, before the tables that
+ depend upon them, representing the
+ order in which they can be created. To get the order in which
+ the tables would be dropped, use the ``reversed()`` Python built-in.
+
+ .. warning::
+
+ The :attr:`.MetaData.sorted_tables` attribute cannot by itself
+ accommodate automatic resolution of dependency cycles between
+ tables, which are usually caused by mutually dependent foreign key
+ constraints. When these cycles are detected, the foreign keys
+ of these tables are omitted from consideration in the sort.
+ A warning is emitted when this condition occurs, which will raise
+ an exception in a future release. Tables which are not part
+ of the cycle will still be returned in dependency order.
+
+ To resolve these cycles, the
+ :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter may be
+ applied to those constraints which create a cycle. Alternatively,
+ the :func:`_schema.sort_tables_and_constraints` function will
+ automatically return foreign key constraints in a separate
+ collection when cycles are detected so that they may be applied
+ to a schema separately.
+
+ .. versionchanged:: 1.3.17 - a warning is emitted when
+ :attr:`.MetaData.sorted_tables` cannot perform a proper sort
+ due to cyclical dependencies. This will be an exception in a
+ future release. Additionally, the sort will continue to return
+ other tables not involved in the cycle in dependency order which
+ was not the case previously.
+
+ .. seealso::
+
+ :func:`_schema.sort_tables`
+
+ :func:`_schema.sort_tables_and_constraints`
+
+ :attr:`_schema.MetaData.tables`
+
+ :meth:`_reflection.Inspector.get_table_names`
+
+ :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names`
+
+
+ """
+ return ddl.sort_tables(
+ sorted(self.tables.values(), key=lambda t: t.key)
+ )
+
+ def reflect(
+ self,
+ bind=None,
+ schema=None,
+ views=False,
+ only=None,
+ extend_existing=False,
+ autoload_replace=True,
+ resolve_fks=True,
+ **dialect_kwargs
+ ):
+ r"""Load all available table definitions from the database.
+
+ Automatically creates ``Table`` entries in this ``MetaData`` for any
+ table available in the database but not yet present in the
+ ``MetaData``. May be called multiple times to pick up tables recently
+ added to the database; however, no special action is taken if a table
+ in this ``MetaData`` no longer exists in the database.
+
+ :param bind:
+ A :class:`.Connectable` used to access the database; if None, uses
+ the existing bind on this ``MetaData``, if any.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ :param schema:
+ Optional, query and reflect tables from an alternate schema.
+ If None, the schema associated with this :class:`_schema.MetaData`
+ is used, if any.
+
+ :param views:
+ If True, also reflect views.
+
+ :param only:
+ Optional. Load only a sub-set of available named tables. May be
+ specified as a sequence of names or a callable.
+
+ If a sequence of names is provided, only those tables will be
+ reflected. An error is raised if a table is requested but not
+ available. Named tables already present in this ``MetaData`` are
+ ignored.
+
+ If a callable is provided, it will be used as a boolean predicate to
+ filter the list of potential table names. The callable is called
+ with a table name and this ``MetaData`` instance as positional
+ arguments and should return a true value for any table to reflect.
+
+ :param extend_existing: Passed along to each :class:`_schema.Table` as
+ :paramref:`_schema.Table.extend_existing`.
+
+ .. versionadded:: 0.9.1
+
+ :param autoload_replace: Passed along to each :class:`_schema.Table`
+ as
+ :paramref:`_schema.Table.autoload_replace`.
+
+ .. versionadded:: 0.9.1
+
+ :param resolve_fks: if True, reflect :class:`_schema.Table`
+ objects linked
+ to :class:`_schema.ForeignKey` objects located in each
+ :class:`_schema.Table`.
+ For :meth:`_schema.MetaData.reflect`,
+ this has the effect of reflecting
+ related tables that might otherwise not be in the list of tables
+ being reflected, for example if the referenced table is in a
+ different schema or is omitted via the
+ :paramref:`.MetaData.reflect.only` parameter. When False,
+ :class:`_schema.ForeignKey` objects are not followed to the
+ :class:`_schema.Table`
+ in which they link, however if the related table is also part of the
+ list of tables that would be reflected in any case, the
+ :class:`_schema.ForeignKey` object will still resolve to its related
+ :class:`_schema.Table` after the :meth:`_schema.MetaData.reflect`
+ operation is
+ complete. Defaults to True.
+
+ .. versionadded:: 1.3.0
+
+ .. seealso::
+
+ :paramref:`_schema.Table.resolve_fks`
+
+ :param \**dialect_kwargs: Additional keyword arguments not mentioned
+ above are dialect specific, and passed in the form
+ ``<dialectname>_<argname>``. See the documentation regarding an
+ individual dialect at :ref:`dialect_toplevel` for detail on
+ documented arguments.
+
+ .. versionadded:: 0.9.2 - Added
+ :paramref:`.MetaData.reflect.**dialect_kwargs` to support
+ dialect-level reflection options for all :class:`_schema.Table`
+ objects reflected.
+
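+ As a brief sketch, reflecting only a subset of tables using a
+ callable predicate (the engine URL and the name prefix are
+ illustrative)::
+
+ # illustrative URL and prefix; the callable receives (name, metadata)
+ engine = create_engine("postgresql://scott:tiger@localhost/test")
+ metadata.reflect(bind=engine,
+ only=lambda name, _meta: name.startswith("billing_"))
+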
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+
+ with inspection.inspect(bind)._inspection_context() as insp:
+ reflect_opts = {
+ "autoload_with": insp,
+ "extend_existing": extend_existing,
+ "autoload_replace": autoload_replace,
+ "resolve_fks": resolve_fks,
+ "_extend_on": set(),
+ }
+
+ reflect_opts.update(dialect_kwargs)
+
+ if schema is None:
+ schema = self.schema
+
+ if schema is not None:
+ reflect_opts["schema"] = schema
+
+ available = util.OrderedSet(insp.get_table_names(schema))
+ if views:
+ available.update(insp.get_view_names(schema))
+
+ if schema is not None:
+ available_w_schema = util.OrderedSet(
+ ["%s.%s" % (schema, name) for name in available]
+ )
+ else:
+ available_w_schema = available
+
+ current = set(self.tables)
+
+ if only is None:
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if extend_existing or schname not in current
+ ]
+ elif callable(only):
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if (extend_existing or schname not in current)
+ and only(name, self)
+ ]
+ else:
+ missing = [name for name in only if name not in available]
+ if missing:
+ s = schema and (" schema '%s'" % schema) or ""
+ raise exc.InvalidRequestError(
+ "Could not reflect: requested table(s) not available "
+ "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing))
+ )
+ load = [
+ name
+ for name in only
+ if extend_existing or name not in current
+ ]
+
+ for name in load:
+ try:
+ Table(name, self, **reflect_opts)
+ except exc.UnreflectableTableError as uerr:
+ util.warn("Skipping table %s: %s" % (name, uerr))
+
+ def create_all(self, bind=None, tables=None, checkfirst=True):
+ """Create all tables stored in this metadata.
+
+ Conditional by default, will not attempt to recreate tables already
+ present in the target database.
+
+ :param bind:
+ A :class:`.Connectable` used to access the
+ database; if None, uses the existing bind on this ``MetaData``, if
+ any.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ :param tables:
+ Optional list of ``Table`` objects, which is a subset of the total
+ tables in the ``MetaData`` (others are ignored).
+
+ :param checkfirst:
+ Defaults to True, don't issue CREATEs for tables already present
+ in the target database.
+
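+ As a brief sketch (the engine URL is illustrative)::
+
+ # illustrative in-memory SQLite engine
+ engine = create_engine("sqlite://")
+ metadata.create_all(bind=engine, checkfirst=True)
+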
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(
+ ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables
+ )
+
+ def drop_all(self, bind=None, tables=None, checkfirst=True):
+ """Drop all tables stored in this metadata.
+
+ Conditional by default, will not attempt to drop tables not present in
+ the target database.
+
+ :param bind:
+ A :class:`.Connectable` used to access the
+ database; if None, uses the existing bind on this ``MetaData``, if
+ any.
+
+ .. note:: the "bind" argument will be required in
+ SQLAlchemy 2.0.
+
+ :param tables:
+ Optional list of ``Table`` objects, which is a subset of the
+ total tables in the ``MetaData`` (others are ignored).
+
+ :param checkfirst:
+ Defaults to True, only issue DROPs for tables confirmed to be
+ present in the target database.
+
+ """
+ if bind is None:
+ bind = _bind_or_error(self)
+ bind._run_ddl_visitor(
+ ddl.SchemaDropper, self, checkfirst=checkfirst, tables=tables
+ )
+
+
+@util.deprecated_cls(
+ "1.4",
+ ":class:`.ThreadLocalMetaData` is deprecated and will be removed "
+ "in a future release.",
+ constructor="__init__",
+)
+class ThreadLocalMetaData(MetaData):
+ """A MetaData variant that presents a different ``bind`` in every thread.
+
+ Makes the ``bind`` property of the MetaData a thread-local value, allowing
+ this collection of tables to be bound to different ``Engine``
+ implementations or connections in each thread.
+
+ The ThreadLocalMetaData starts off bound to None in each thread. Binds
+ must be made explicitly by assigning to the ``bind`` property or using
+ ``connect()``. You can also re-bind dynamically multiple times per
+ thread, just like a regular ``MetaData``.
+
+ """
+
+ __visit_name__ = "metadata"
+
+ def __init__(self):
+ """Construct a ThreadLocalMetaData."""
+
+ self.context = util.threading.local()
+ self.__engines = {}
+ super(ThreadLocalMetaData, self).__init__()
+
+ def bind(self):
+ """The bound Engine or Connection for this thread.
+
+ This property may be assigned an Engine or Connection, or assigned a
+ string or URL to automatically create a basic Engine for this bind
+ with ``create_engine()``."""
+
+ return getattr(self.context, "_engine", None)
+
+ @util.preload_module("sqlalchemy.engine.url")
+ def _bind_to(self, bind):
+ """Bind to a Connectable in the caller's thread."""
+ url = util.preloaded.engine_url
+ if isinstance(bind, util.string_types + (url.URL,)):
+ try:
+ self.context._engine = self.__engines[bind]
+ except KeyError:
+ e = sqlalchemy.create_engine(bind)
+ self.__engines[bind] = e
+ self.context._engine = e
+ else:
+ # TODO: this is squirrely. we shouldn't have to hold onto engines
+ # in a case like this
+ if bind not in self.__engines:
+ self.__engines[bind] = bind
+ self.context._engine = bind
+
+ bind = property(bind, _bind_to)
+
+ def is_bound(self):
+ """True if there is a bind for this thread."""
+ return (
+ hasattr(self.context, "_engine")
+ and self.context._engine is not None
+ )
+
+ def dispose(self):
+ """Dispose all bound engines, in all thread contexts."""
+
+ for e in self.__engines.values():
+ if hasattr(e, "dispose"):
+ e.dispose()
+
+
+class Computed(FetchedValue, SchemaItem):
+ """Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax.
+
+ The :class:`.Computed` construct is an inline construct added to the
+ argument list of a :class:`_schema.Column` object::
+
+ from sqlalchemy import Computed
+
+ Table('square', metadata_obj,
+ Column('side', Float, nullable=False),
+ Column('area', Float, Computed('side * side'))
+ )
+
+ See the linked documentation below for complete details.
+
+ .. versionadded:: 1.3.11
+
+ .. seealso::
+
+ :ref:`computed_ddl`
+
+ """
+
+ __visit_name__ = "computed_column"
+
+ @_document_text_coercion(
+ "sqltext", ":class:`.Computed`", ":paramref:`.Computed.sqltext`"
+ )
+ def __init__(self, sqltext, persisted=None):
+ """Construct a GENERATED ALWAYS AS DDL construct to accompany a
+ :class:`_schema.Column`.
+
+ :param sqltext:
+ A string containing the column generation expression, which will be
+ used verbatim, or a SQL expression construct, such as a
+ :func:`_expression.text`
+ object. If given as a string, the object is converted to a
+ :func:`_expression.text` object.
+
+ :param persisted:
+ Optional, controls how this column should be persisted by the
+ database. Possible values are:
+
+ * ``None``, the default, will use the default persistence
+ defined by the database.
+ * ``True``, will render ``GENERATED ALWAYS AS ... STORED``, or the
+ equivalent for the target database if supported.
+ * ``False``, will render ``GENERATED ALWAYS AS ... VIRTUAL``, or
+ the equivalent for the target database if supported.
+
+ Specifying ``True`` or ``False`` may raise an error when the DDL
+ is emitted to the target database if the database does not support
+ that persistence option. Leaving this parameter at its default
+ of ``None`` is guaranteed to succeed for all databases that support
+ ``GENERATED ALWAYS AS``.
+
+ """
+ self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
+ self.persisted = persisted
+ self.column = None
+
+ def _set_parent(self, parent, **kw):
+ if not isinstance(
+ parent.server_default, (type(None), Computed)
+ ) or not isinstance(parent.server_onupdate, (type(None), Computed)):
+ raise exc.ArgumentError(
+ "A generated column cannot specify a server_default or a "
+ "server_onupdate argument"
+ )
+ self.column = parent
+ parent.computed = self
+ self.column.server_onupdate = self
+ self.column.server_default = self
+
+ def _as_for_update(self, for_update):
+ return self
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Computed.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, target_table=None, **kw):
+ return self._copy(target_table, **kw)
+
+ def _copy(self, target_table=None, **kw):
+ sqltext = _copy_expression(
+ self.sqltext,
+ self.column.table if self.column is not None else None,
+ target_table,
+ )
+ g = Computed(sqltext, persisted=self.persisted)
+
+ return self._schema_item_copy(g)
+
+
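+# Illustrative standalone sketch (not library code): the Computed construct
+# attached to a Column, mirroring the "square" example in the docstring
+# above.  The table and column names are assumptions for demonstration only.
+from sqlalchemy import Column, Computed, Float, MetaData, Table
+
+metadata_obj = MetaData()
+square = Table(
+ "square",
+ metadata_obj,
+ Column("side", Float, nullable=False),
+ # persisted=None (the default) lets the backend choose STORED vs. VIRTUAL;
+ # passing True or False forces one form and may fail where unsupported.
+ Column("area", Float, Computed("side * side")),
+)
+
+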
+class Identity(IdentityOptions, FetchedValue, SchemaItem):
+ """Defines an identity column, i.e. "GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY" syntax.
+
+ The :class:`.Identity` construct is an inline construct added to the
+ argument list of a :class:`_schema.Column` object::
+
+ from sqlalchemy import Identity
+
+ Table('foo', metadata_obj,
+ Column('id', Integer, Identity()),
+ Column('description', Text),
+ )
+
+ See the linked documentation below for complete details.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :ref:`identity_ddl`
+
+ """
+
+ __visit_name__ = "identity_column"
+
+ def __init__(
+ self,
+ always=False,
+ on_null=None,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ cache=None,
+ order=None,
+ ):
+ """Construct a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY DDL
+ construct to accompany a :class:`_schema.Column`.
+
+ See the :class:`.Sequence` documentation for a complete description
+ of most parameters.
+
+ .. note::
+ MSSQL supports this construct as the preferred alternative to
+ generate an IDENTITY on a column, but it uses non-standard
+ syntax that only supports :paramref:`_schema.Identity.start`
+ and :paramref:`_schema.Identity.increment`.
+ All other parameters are ignored.
+
+ :param always:
+ A boolean that indicates the type of identity column.
+ If ``False`` is specified, the default, then the user-specified
+ value takes precedence.
+ If ``True`` is specified, a user-specified value is not accepted (
+ on some backends, like PostgreSQL, OVERRIDING SYSTEM VALUE, or
+ similar, may be specified in an INSERT to override the sequence
+ value).
+ Some backends also have a default value for this parameter;
+ ``None`` can be used to omit rendering this part in the DDL. It
+ will be treated as ``False`` if a backend does not have a default
+ value.
+
+ :param on_null:
+ Set to ``True`` to specify ON NULL in conjunction with an
+ ``always=False`` identity column. This option is only supported on
+ some backends, like Oracle.
+
+ :param start: the starting index of the sequence.
+ :param increment: the increment value of the sequence.
+ :param minvalue: the minimum value of the sequence.
+ :param maxvalue: the maximum value of the sequence.
+ :param nominvalue: no minimum value of the sequence.
+ :param nomaxvalue: no maximum value of the sequence.
+ :param cycle: allows the sequence to wrap around when the maxvalue
+ or minvalue has been reached.
+ :param cache: optional integer value; number of future values in the
+ sequence which are calculated in advance.
+ :param order: optional boolean value; if true, renders the
+ ORDER keyword.
+
+ """
+ IdentityOptions.__init__(
+ self,
+ start=start,
+ increment=increment,
+ minvalue=minvalue,
+ maxvalue=maxvalue,
+ nominvalue=nominvalue,
+ nomaxvalue=nomaxvalue,
+ cycle=cycle,
+ cache=cache,
+ order=order,
+ )
+ self.always = always
+ self.on_null = on_null
+ self.column = None
+
+ def _set_parent(self, parent, **kw):
+ if not isinstance(
+ parent.server_default, (type(None), Identity)
+ ) or not isinstance(parent.server_onupdate, type(None)):
+ raise exc.ArgumentError(
+ "A column with an Identity object cannot specify a "
+ "server_default or a server_onupdate argument"
+ )
+ if parent.autoincrement is False:
+ raise exc.ArgumentError(
+ "A column with an Identity object cannot specify "
+ "autoincrement=False"
+ )
+ self.column = parent
+
+ parent.identity = self
+ if parent._user_defined_nullable is NULL_UNSPECIFIED:
+ parent.nullable = False
+
+ parent.server_default = self
+
+ def _as_for_update(self, for_update):
+ return self
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_schema.Identity.copy` method is deprecated "
+ "and will be removed in a future release.",
+ )
+ def copy(self, **kw):
+ return self._copy(**kw)
+
+ def _copy(self, **kw):
+ i = Identity(
+ always=self.always,
+ on_null=self.on_null,
+ start=self.start,
+ increment=self.increment,
+ minvalue=self.minvalue,
+ maxvalue=self.maxvalue,
+ nominvalue=self.nominvalue,
+ nomaxvalue=self.nomaxvalue,
+ cycle=self.cycle,
+ cache=self.cache,
+ order=self.order,
+ )
+
+ return self._schema_item_copy(i)
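+
+
+# Illustrative standalone sketch (not library code): the Identity construct
+# attached to an Integer primary key column, mirroring the docstring example
+# above.  The table name and the specific start/cycle values are assumptions
+# for demonstration only.
+from sqlalchemy import Column, Identity, Integer, MetaData, Table, Text
+
+metadata_obj = MetaData()
+foo = Table(
+ "foo",
+ metadata_obj,
+ # Identity() implies nullable=False and is rendered as
+ # GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY on supporting backends.
+ Column("id", Integer, Identity(start=42, cycle=True), primary_key=True),
+ Column("description", Text),
+)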
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
new file mode 100644
index 0000000..8379e1c
--- /dev/null
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -0,0 +1,6946 @@
+# sql/selectable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""The :class:`_expression.FromClause` class of SQL expression elements,
+representing
+SQL tables and derived rowsets.
+
+"""
+
+import collections
+import itertools
+from operator import attrgetter
+
+from . import coercions
+from . import operators
+from . import roles
+from . import traversals
+from . import type_api
+from . import visitors
+from .annotation import Annotated
+from .annotation import SupportsCloneAnnotations
+from .base import _clone
+from .base import _cloned_difference
+from .base import _cloned_intersection
+from .base import _entity_namespace_key
+from .base import _expand_cloned
+from .base import _from_objects
+from .base import _generative
+from .base import _select_iterables
+from .base import CacheableOptions
+from .base import ColumnCollection
+from .base import ColumnSet
+from .base import CompileState
+from .base import DedupeColumnCollection
+from .base import Executable
+from .base import Generative
+from .base import HasCompileState
+from .base import HasMemoized
+from .base import Immutable
+from .base import prefix_anon_map
+from .coercions import _document_text_coercion
+from .elements import _anonymous_label
+from .elements import and_
+from .elements import BindParameter
+from .elements import BooleanClauseList
+from .elements import ClauseElement
+from .elements import ClauseList
+from .elements import ColumnClause
+from .elements import GroupedElement
+from .elements import Grouping
+from .elements import literal_column
+from .elements import TableValuedColumn
+from .elements import UnaryExpression
+from .visitors import InternalTraversal
+from .. import exc
+from .. import util
+from ..inspection import inspect
+
+
+class _OffsetLimitParam(BindParameter):
+ inherit_cache = True
+
+ @property
+ def _limit_offset_value(self):
+ return self.effective_value
+
+
+@util.deprecated(
+ "1.4",
+ "The standalone :func:`.subquery` function is deprecated "
+ "and will be removed in a future release. Use select().subquery().",
+)
+def subquery(alias, *args, **kwargs):
+ r"""Return an :class:`.Subquery` object derived
+ from a :class:`_expression.Select`.
+
+ :param alias: the alias name for the subquery
+
+ :param \*args, \**kwargs: all other arguments are passed through to the
+ :func:`_expression.select` function.
+
+ """
+ return Select.create_legacy_select(*args, **kwargs).subquery(alias)
+
+
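+# Illustrative standalone sketch (not library code): the non-deprecated
+# spelling recommended by the warning above, i.e. select(...).subquery()
+# rather than the standalone subquery() function.  The table definition and
+# filter criteria are assumptions for demonstration only.
+from sqlalchemy import column, select, table
+
+user = table("user", column("id"), column("name"))
+# replaces the legacy standalone subquery() call with the 1.4-style API
+subq = select(user).where(user.c.id > 5).subquery("u")
+stmt = select(subq.c.name)
+
+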
+class ReturnsRows(roles.ReturnsRowsRole, ClauseElement):
+ """The base-most class for Core constructs that have some concept of
+ columns that can represent rows.
+
+ While the SELECT statement and TABLE are the primary things we think
+ of in this category, DML like INSERT, UPDATE and DELETE can also specify
+ RETURNING, which means they can be used in CTEs and other forms;
+ PostgreSQL also has functions that return rows.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _is_returns_rows = True
+
+ # sub-elements of returns_rows
+ _is_from_clause = False
+ _is_select_statement = False
+ _is_lateral = False
+
+ @property
+ def selectable(self):
+ return self
+
+ @property
+ def _all_selected_columns(self):
+ """A sequence of column expression objects that represents the
+ "selected" columns of this :class:`_expression.ReturnsRows`.
+
+ This is typically equivalent to .exported_columns except it is
+ delivered in the form of a straight sequence and not keyed
+ :class:`_expression.ColumnCollection`.
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def exported_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ that represents the "exported"
+ columns of this :class:`_expression.ReturnsRows`.
+
+ The "exported" columns represent the collection of
+ :class:`_expression.ColumnElement`
+ expressions that are rendered by this SQL
+ construct. The primary varieties are the "FROM clause columns"
+ of a FROM clause, such as a table, join, or subquery; the
+ "SELECTed columns", which are the columns in the "columns clause"
+ of a SELECT statement; and the RETURNING columns in a DML
+ statement.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_expression.FromClause.exported_columns`
+
+ :attr:`_expression.SelectBase.exported_columns`
+ """
+
+ raise NotImplementedError()
+
+
+class Selectable(ReturnsRows):
+ """Mark a class as being selectable."""
+
+ __visit_name__ = "selectable"
+
+ is_selectable = True
+
+ def _refresh_for_new_column(self, column):
+ raise NotImplementedError()
+
+ def lateral(self, name=None):
+ """Return a LATERAL alias of this :class:`_expression.Selectable`.
+
+ The return value is the :class:`_expression.Lateral` construct also
+ provided by the top-level :func:`_expression.lateral` function.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+ """
+ return Lateral._construct(self, name)
+
+ @util.deprecated(
+ "1.4",
+ message="The :meth:`.Selectable.replace_selectable` method is "
+ "deprecated, and will be removed in a future release. Similar "
+ "functionality is available via the sqlalchemy.sql.visitors module.",
+ )
+ @util.preload_module("sqlalchemy.sql.util")
+ def replace_selectable(self, old, alias):
+ """Replace all occurrences of :class:`_expression.FromClause`
+ 'old' with the given :class:`_expression.Alias`
+ object, returning a copy of this :class:`_expression.FromClause`.
+
+ """
+ return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self)
+
+ def corresponding_column(self, column, require_embedded=False):
+ """Given a :class:`_expression.ColumnElement`, return the exported
+ :class:`_expression.ColumnElement` object from the
+ :attr:`_expression.Selectable.exported_columns`
+ collection of this :class:`_expression.Selectable`
+ which corresponds to that
+ original :class:`_expression.ColumnElement` via a common ancestor
+ column.
+
+ :param column: the target :class:`_expression.ColumnElement`
+ to be matched.
+
+ :param require_embedded: only return corresponding columns for
+ the given :class:`_expression.ColumnElement`, if the given
+ :class:`_expression.ColumnElement`
+ is actually present within a sub-element
+ of this :class:`_expression.Selectable`.
+ Normally the column will match if
+ it merely shares a common ancestor with one of the exported
+ columns of this :class:`_expression.Selectable`.
+
+ .. seealso::
+
+ :attr:`_expression.Selectable.exported_columns` - the
+ :class:`_expression.ColumnCollection`
+ that is used for the operation.
+
+ :meth:`_expression.ColumnCollection.corresponding_column`
+ - implementation
+ method.
+
+ """
+
+ return self.exported_columns.corresponding_column(
+ column, require_embedded
+ )
+
+
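+# Illustrative standalone sketch (not library code):
+# Selectable.corresponding_column() locating the column exported by a
+# subquery that derives from a given table column.  The table definition is
+# an assumption for demonstration only.
+from sqlalchemy import column, select, table
+
+user = table("user", column("id"), column("name"))
+subq = select(user).subquery()
+# returns subq.c.id, the proxy for user.c.id exported by the subquery
+proxy = subq.corresponding_column(user.c.id)
+
+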
+class HasPrefixes(object):
+ _prefixes = ()
+
+ _has_prefixes_traverse_internals = [
+ ("_prefixes", InternalTraversal.dp_prefix_sequence)
+ ]
+
+ @_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`_expression.HasPrefixes.prefix_with`",
+ ":paramref:`.HasPrefixes.prefix_with.*expr`",
+ )
+ def prefix_with(self, *expr, **kw):
+ r"""Add one or more expressions following the statement keyword, i.e.
+ SELECT, INSERT, UPDATE, or DELETE. Generative.
+
+ This is used to support backend-specific prefix keywords such as those
+ provided by MySQL.
+
+ E.g.::
+
+ stmt = table.insert().prefix_with("LOW_PRIORITY", dialect="mysql")
+
+ # MySQL 5.7 optimizer hints
+ stmt = select(table).prefix_with(
+ "/*+ BKA(t1) */", dialect="mysql")
+
+ Multiple prefixes can be specified by multiple calls
+ to :meth:`_expression.HasPrefixes.prefix_with`.
+
+ :param \*expr: textual or :class:`_expression.ClauseElement`
+ construct which
+ will be rendered following the INSERT, UPDATE, or DELETE
+ keyword.
+ :param \**kw: A single keyword 'dialect' is accepted. This is an
+ optional string dialect name which will
+ limit rendering of this prefix to only that dialect.
+
+ """
+ dialect = kw.pop("dialect", None)
+ if kw:
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
+ self._setup_prefixes(expr, dialect)
+
+ def _setup_prefixes(self, prefixes, dialect=None):
+ self._prefixes = self._prefixes + tuple(
+ [
+ (coercions.expect(roles.StatementOptionRole, p), dialect)
+ for p in prefixes
+ ]
+ )
+
+
+class HasSuffixes(object):
+ _suffixes = ()
+
+ _has_suffixes_traverse_internals = [
+ ("_suffixes", InternalTraversal.dp_prefix_sequence)
+ ]
+
+ @_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`_expression.HasSuffixes.suffix_with`",
+ ":paramref:`.HasSuffixes.suffix_with.*expr`",
+ )
+ def suffix_with(self, *expr, **kw):
+ r"""Add one or more expressions following the statement as a whole.
+
+ This is used to support backend-specific suffix keywords on
+ certain constructs.
+
+ E.g.::
+
+ stmt = select(col1, col2).cte().suffix_with(
+ "cycle empno set y_cycle to 1 default 0", dialect="oracle")
+
+ Multiple suffixes can be specified by multiple calls
+ to :meth:`_expression.HasSuffixes.suffix_with`.
+
+ :param \*expr: textual or :class:`_expression.ClauseElement`
+ construct which
+ will be rendered following the target clause.
+ :param \**kw: A single keyword 'dialect' is accepted. This is an
+ optional string dialect name which will
+ limit rendering of this suffix to only that dialect.
+
+ """
+ dialect = kw.pop("dialect", None)
+ if kw:
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
+ self._setup_suffixes(expr, dialect)
+
+ def _setup_suffixes(self, suffixes, dialect=None):
+ self._suffixes = self._suffixes + tuple(
+ [
+ (coercions.expect(roles.StatementOptionRole, p), dialect)
+ for p in suffixes
+ ]
+ )
+
+
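+# Illustrative standalone sketch (not library code): prefix_with() and
+# suffix_with() as described above.  The MySQL-only INSERT prefix follows the
+# docstring example; the table definition and the Oracle CYCLE clause text
+# are assumptions for demonstration only.
+from sqlalchemy import column, select, table
+
+user = table("user", column("id"), column("name"))
+# renders "INSERT LOW_PRIORITY INTO user ..." only on the mysql dialect
+ins = user.insert().prefix_with("LOW_PRIORITY", dialect="mysql")
+# a suffix is rendered after the target clause as a whole, here a CTE
+cte = select(user).cte("u").suffix_with(
+ "cycle id set looped to 1 default 0", dialect="oracle"
+)
+
+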
+class HasHints(object):
+ _hints = util.immutabledict()
+ _statement_hints = ()
+
+ _has_hints_traverse_internals = [
+ ("_statement_hints", InternalTraversal.dp_statement_hint_list),
+ ("_hints", InternalTraversal.dp_table_hint_list),
+ ]
+
+ def with_statement_hint(self, text, dialect_name="*"):
+ """Add a statement hint to this :class:`_expression.Select` or
+ other selectable object.
+
+ This method is similar to :meth:`_expression.Select.with_hint`
+ except that
+ it does not require an individual table, and instead applies to the
+ statement as a whole.
+
+ Hints here are specific to the backend database and may include
+ directives such as isolation levels, file directives, fetch directives,
+ etc.
+
+ .. versionadded:: 1.0.0
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_hint`
+
+ :meth:`_expression.Select.prefix_with` - generic SELECT prefixing
+ which also can suit some database-specific HINT syntaxes such as
+ MySQL optimizer hints
+
+ """
+ return self.with_hint(None, text, dialect_name)
+
+ @_generative
+ def with_hint(self, selectable, text, dialect_name="*"):
+ r"""Add an indexing or other executional context hint for the given
+ selectable to this :class:`_expression.Select` or other selectable
+ object.
+
+ The text of the hint is rendered in the appropriate
+ location for the database backend in use, relative
+ to the given :class:`_schema.Table` or :class:`_expression.Alias`
+ passed as the
+ ``selectable`` argument. The dialect implementation
+ typically uses Python string substitution syntax
+ with the token ``%(name)s`` to render the name of
+ the table or alias. E.g. when using Oracle, the
+ following::
+
+ select(mytable).\
+ with_hint(mytable, "index(%(name)s ix_mytable)")
+
+ Would render SQL as::
+
+ select /*+ index(mytable ix_mytable) */ ... from mytable
+
+ The ``dialect_name`` option will limit the rendering of a particular
+ hint to a particular backend. Such as, to add hints for both Oracle
+ and Sybase simultaneously::
+
+ select(mytable).\
+ with_hint(mytable, "index(%(name)s ix_mytable)", 'oracle').\
+ with_hint(mytable, "WITH INDEX ix_mytable", 'sybase')
+
+ .. seealso::
+
+ :meth:`_expression.Select.with_statement_hint`
+
+ """
+ if selectable is None:
+ self._statement_hints += ((dialect_name, text),)
+ else:
+ self._hints = self._hints.union(
+ {
+ (
+ coercions.expect(roles.FromClauseRole, selectable),
+ dialect_name,
+ ): text
+ }
+ )
+
+
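+# Illustrative standalone sketch (not library code): the two hinting methods
+# defined above.  The Oracle index hint follows the docstring example; the
+# table and index names are assumptions for demonstration only.
+from sqlalchemy import column, select, table
+
+mytable = table("mytable", column("id"))
+# per-table hint; %(name)s is replaced with the table or alias name
+stmt = select(mytable).with_hint(
+ mytable, "index(%(name)s ix_mytable)", "oracle"
+)
+# with_statement_hint() takes the same kind of text but applies to the
+# statement as a whole rather than to one table
+stmt2 = select(mytable).with_statement_hint("index(mytable ix_mytable)", "oracle")
+
+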
+class FromClause(roles.AnonymizedFromClauseRole, Selectable):
+ """Represent an element that can be used within the ``FROM``
+ clause of a ``SELECT`` statement.
+
+ The most common forms of :class:`_expression.FromClause` are the
+ :class:`_schema.Table` and the :func:`_expression.select` constructs. Key
+ features common to all :class:`_expression.FromClause` objects include:
+
+ * a :attr:`.c` collection, which provides per-name access to a collection
+ of :class:`_expression.ColumnElement` objects.
+ * a :attr:`.primary_key` attribute, which is a collection of all those
+ :class:`_expression.ColumnElement`
+ objects that indicate the ``primary_key`` flag.
+ * Methods to generate various derivations of a "from" clause, including
+ :meth:`_expression.FromClause.alias`,
+ :meth:`_expression.FromClause.join`,
+ :meth:`_expression.FromClause.select`.
+
+
+ """
+
+ __visit_name__ = "fromclause"
+ named_with_column = False
+ _hide_froms = []
+
+ schema = None
+ """Define the 'schema' attribute for this :class:`_expression.FromClause`.
+
+ This is typically ``None`` for most objects except that of
+ :class:`_schema.Table`, where it is taken as the value of the
+ :paramref:`_schema.Table.schema` argument.
+
+ """
+
+ is_selectable = True
+ _is_from_clause = True
+ _is_join = False
+
+ _use_schema_map = False
+
+ @util.deprecated_params(
+ whereclause=(
+ "2.0",
+ "The :paramref:`_sql.FromClause.select().whereclause` parameter "
+ "is deprecated and will be removed in version 2.0. "
+ "Please make use of "
+ "the :meth:`.Select.where` "
+ "method to add WHERE criteria to the SELECT statement.",
+ ),
+ kwargs=(
+ "2.0",
+ "The :meth:`_sql.FromClause.select` method will no longer accept "
+ "keyword arguments in version 2.0. Please use generative methods "
+ "from the "
+ ":class:`_sql.Select` construct in order to apply additional "
+ "modifications.",
+ ),
+ )
+ def select(self, whereclause=None, **kwargs):
+ r"""Return a SELECT of this :class:`_expression.FromClause`.
+
+
+ e.g.::
+
+ stmt = some_table.select().where(some_table.c.id == 5)
+
+ :param whereclause: a WHERE clause, equivalent to calling the
+ :meth:`_sql.Select.where` method.
+
+ :param \**kwargs: additional keyword arguments are passed to the
+ legacy constructor for :class:`_sql.Select` described at
+ :meth:`_sql.Select.create_legacy_select`.
+
+ .. seealso::
+
+ :func:`_expression.select` - general purpose
+ method which allows for arbitrary column lists.
+
+ """
+ if whereclause is not None:
+ kwargs["whereclause"] = whereclause
+ return Select._create_select_from_fromclause(self, [self], **kwargs)
+
+ def join(self, right, onclause=None, isouter=False, full=False):
+ """Return a :class:`_expression.Join` from this
+ :class:`_expression.FromClause`
+ to another :class:`FromClause`.
+
+ E.g.::
+
+ from sqlalchemy import join
+
+ j = user_table.join(address_table,
+ user_table.c.id == address_table.c.user_id)
+ stmt = select(user_table).select_from(j)
+
+ would emit SQL along the lines of::
+
+ SELECT user.id, user.name FROM user
+ JOIN address ON user.id = address.user_id
+
+ :param right: the right side of the join; this is any
+ :class:`_expression.FromClause` object such as a
+ :class:`_schema.Table` object, and
+ may also be a selectable-compatible object such as an ORM-mapped
+ class.
+
+ :param onclause: a SQL expression representing the ON clause of the
+ join. If left at ``None``, :meth:`_expression.FromClause.join`
+ will attempt to
+ join the two tables based on a foreign key relationship.
+
+ :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN.
+
+ :param full: if True, render a FULL OUTER JOIN, instead of LEFT OUTER
+ JOIN. Implies :paramref:`.FromClause.join.isouter`.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_expression.join` - standalone function
+
+ :class:`_expression.Join` - the type of object produced
+
+ """
+
+ return Join(self, right, onclause, isouter, full)
+
+ def outerjoin(self, right, onclause=None, full=False):
+ """Return a :class:`_expression.Join` from this
+ :class:`_expression.FromClause`
+ to another :class:`FromClause`, with the "isouter" flag set to
+ True.
+
+ E.g.::
+
+ from sqlalchemy import outerjoin
+
+ j = user_table.outerjoin(address_table,
+ user_table.c.id == address_table.c.user_id)
+
+ The above is equivalent to::
+
+ j = user_table.join(
+ address_table,
+ user_table.c.id == address_table.c.user_id,
+ isouter=True)
+
+ :param right: the right side of the join; this is any
+ :class:`_expression.FromClause` object such as a
+ :class:`_schema.Table` object, and
+ may also be a selectable-compatible object such as an ORM-mapped
+ class.
+
+ :param onclause: a SQL expression representing the ON clause of the
+ join. If left at ``None``, :meth:`_expression.FromClause.join`
+ will attempt to
+ join the two tables based on a foreign key relationship.
+
+ :param full: if True, render a FULL OUTER JOIN, instead of
+ LEFT OUTER JOIN.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :meth:`_expression.FromClause.join`
+
+ :class:`_expression.Join`
+
+ """
+
+ return Join(self, right, onclause, True, full)
+
+ def alias(self, name=None, flat=False):
+ """Return an alias of this :class:`_expression.FromClause`.
+
+ E.g.::
+
+ a2 = some_table.alias('a2')
+
+ The above code creates an :class:`_expression.Alias`
+ object which can be used
+ as a FROM clause in any SELECT statement.
+
+ .. seealso::
+
+ :ref:`tutorial_using_aliases`
+
+ :func:`_expression.alias`
+
+ """
+
+ return Alias._construct(self, name)
+
+ @util.preload_module("sqlalchemy.sql.sqltypes")
+ def table_valued(self):
+ """Return a :class:`_sql.TableValuedColumn` object for this
+ :class:`_expression.FromClause`.
+
+ A :class:`_sql.TableValuedColumn` is a :class:`_sql.ColumnElement` that
+ represents a complete row in a table. Support for this construct is
+ backend dependent, and is supported in various forms by backends
+ such as PostgreSQL, Oracle and SQL Server.
+
+ E.g.::
+
+ >>> from sqlalchemy import select, column, func, table
+ >>> a = table("a", column("id"), column("x"), column("y"))
+ >>> stmt = select(func.row_to_json(a.table_valued()))
+ >>> print(stmt)
+ SELECT row_to_json(a) AS row_to_json_1
+ FROM a
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`tutorial_functions` - in the :ref:`unified_tutorial`
+
+ """
+ return TableValuedColumn(self, type_api.TABLEVALUE)
+
+ def tablesample(self, sampling, name=None, seed=None):
+ """Return a TABLESAMPLE alias of this :class:`_expression.FromClause`.
+
+ The return value is the :class:`_expression.TableSample`
+ construct also
+ provided by the top-level :func:`_expression.tablesample` function.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_expression.tablesample` - usage guidelines and parameters
+
+ """
+ return TableSample._construct(self, sampling, name, seed)
+
+ def is_derived_from(self, fromclause):
+ """Return ``True`` if this :class:`_expression.FromClause` is
+ 'derived' from the given ``FromClause``.
+
+ For example, an Alias of a Table is derived from that Table.
+
+ """
+ # this is essentially an "identity" check in the base class.
+ # Other constructs override this to traverse through
+ # contained elements.
+ return fromclause in self._cloned_set
+
+ def _is_lexical_equivalent(self, other):
+ """Return ``True`` if this :class:`_expression.FromClause` and
+ the other represent the same lexical identity.
+
+ This tests if either one is a copy of the other, or
+ if they are the same via annotation identity.
+
+ """
+ return self._cloned_set.intersection(other._cloned_set)
+
+ @property
+ def description(self):
+ """A brief description of this :class:`_expression.FromClause`.
+
+ Used primarily for error message formatting.
+
+ """
+ return getattr(self, "name", self.__class__.__name__ + " object")
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ fromclause._columns._populate_separate_keys(
+ col._make_proxy(fromclause) for col in self.c
+ )
+
+ @property
+ def exported_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ that represents the "exported"
+ columns of this :class:`_expression.Selectable`.
+
+ The "exported" columns for a :class:`_expression.FromClause`
+ object are synonymous
+ with the :attr:`_expression.FromClause.columns` collection.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_expression.Selectable.exported_columns`
+
+ :attr:`_expression.SelectBase.exported_columns`
+
+
+ """
+ return self.columns
+
+ @util.memoized_property
+ def columns(self):
+ """A named-based collection of :class:`_expression.ColumnElement`
+ objects maintained by this :class:`_expression.FromClause`.
+
+ The :attr:`.columns`, or :attr:`.c` collection, is the gateway
+ to the construction of SQL expressions using table-bound or
+ other selectable-bound columns::
+
+ select(mytable).where(mytable.c.somecolumn == 5)
+
+ :return: a :class:`.ColumnCollection` object.
+
+ """
+
+ if "_columns" not in self.__dict__:
+ self._init_collections()
+ self._populate_column_collection()
+ return self._columns.as_immutable()
+
+ @property
+ def entity_namespace(self):
+ """Return a namespace used for name-based access in SQL expressions.
+
+ This is the namespace that is used to resolve "filter_by()" type
+ expressions, such as::
+
+ stmt.filter_by(address='some address')
+
+ It defaults to the ``.c`` collection, however internally it can
+ be overridden using the "entity_namespace" annotation to deliver
+ alternative results.
+
+ """
+ return self.columns
+
+ @util.memoized_property
+ def primary_key(self):
+ """Return the iterable collection of :class:`_schema.Column` objects
+ which comprise the primary key of this :class:`_selectable.FromClause`.
+
+ For a :class:`_schema.Table` object, this collection is represented
+ by the :class:`_schema.PrimaryKeyConstraint` which itself is an
+ iterable collection of :class:`_schema.Column` objects.
+
+ """
+ self._init_collections()
+ self._populate_column_collection()
+ return self.primary_key
+
+ @util.memoized_property
+ def foreign_keys(self):
+ """Return the collection of :class:`_schema.ForeignKey` marker objects
+ which this FromClause references.
+
+ Each :class:`_schema.ForeignKey` is a member of a
+ :class:`_schema.Table`-wide
+ :class:`_schema.ForeignKeyConstraint`.
+
+ .. seealso::
+
+ :attr:`_schema.Table.foreign_key_constraints`
+
+ """
+ self._init_collections()
+ self._populate_column_collection()
+ return self.foreign_keys
+
+ def _reset_column_collection(self):
+ """Reset the attributes linked to the ``FromClause.c`` attribute.
+
+ This collection is separate from all the other memoized things
+ as it has shown itself to be sensitive to being cleared out in situations
+ where enclosing code, typically in a replacement traversal scenario,
+ has already established strong relationships
+ with the exported columns.
+
+ The collection is cleared for the case where a table is having a
+ column added to it as well as within a Join during copy internals.
+
+ """
+
+ for key in ["_columns", "columns", "primary_key", "foreign_keys"]:
+ self.__dict__.pop(key, None)
+
+ c = property(
+ attrgetter("columns"),
+ doc="""
+ A named-based collection of :class:`_expression.ColumnElement`
+ objects maintained by this :class:`_expression.FromClause`.
+
+ The :attr:`_sql.FromClause.c` attribute is an alias for the
+ :attr:`_sql.FromClause.columns` attribute.
+
+ :return: a :class:`.ColumnCollection`
+
+ """,
+ )
+ _select_iterable = property(attrgetter("columns"))
+
+ def _init_collections(self):
+ assert "_columns" not in self.__dict__
+ assert "primary_key" not in self.__dict__
+ assert "foreign_keys" not in self.__dict__
+
+ self._columns = ColumnCollection()
+ self.primary_key = ColumnSet()
+ self.foreign_keys = set()
+
+ @property
+ def _cols_populated(self):
+ return "_columns" in self.__dict__
+
+ def _populate_column_collection(self):
+ """Called on subclasses to establish the .c collection.
+
+ Each implementation has a different way of establishing
+ this collection.
+
+ """
+
+ def _refresh_for_new_column(self, column):
+ """Given a column added to the .c collection of an underlying
+ selectable, produce the local version of that column, assuming this
+ selectable ultimately should proxy this column.
+
+ this is used to "ping" a derived selectable to add a new column
+ to its .c. collection when a Column has been added to one of the
+ Table objects it ultimately derives from.
+
+ If the given selectable hasn't populated its .c. collection yet,
+ it should at least pass on the message to the contained selectables,
+ but it will return None.
+
+ This method is currently used by Declarative to allow Table
+ columns to be added to a partially constructed inheritance
+ mapping that may have already produced joins. The method
+ isn't public right now, as the full span of implications
+ and/or caveats aren't yet clear.
+
+ It's also possible that this functionality could be invoked by
+ default via an event, which would require that
+ selectables maintain a weak referencing collection of all
+ derivations.
+
+ """
+ self._reset_column_collection()
+
+ def _anonymous_fromclause(self, name=None, flat=False):
+ return self.alias(name=name)
+
+
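+# Illustrative standalone sketch (not library code): the core FromClause
+# derivation methods documented above -- join(), alias() and select().  The
+# user/address tables follow the docstring examples and are assumptions for
+# demonstration only.
+from sqlalchemy import column, select, table
+
+user_table = table("user", column("id"), column("name"))
+address_table = table("address", column("id"), column("user_id"), column("email"))
+
+# JOIN with an explicit ON clause (table() objects carry no ForeignKey metadata)
+j = user_table.join(address_table, user_table.c.id == address_table.c.user_id)
+stmt = select(user_table.c.name, address_table.c.email).select_from(j)
+
+# an alias renders as "user AS u" wherever it is used as a FROM element
+u = user_table.alias("u")
+stmt2 = u.select().where(u.c.id == 5)
+
+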
+LABEL_STYLE_NONE = util.symbol(
+ "LABEL_STYLE_NONE",
+ """Label style indicating no automatic labeling should be applied to the
+ columns clause of a SELECT statement.
+
+ Below, the columns named ``columna`` are both rendered as is, meaning that
+ the name ``columna`` can only refer to the first occurrence of this name
+ within a result set, as well as if the statement were used as a subquery::
+
+ >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_NONE
+ >>> table1 = table("table1", column("columna"), column("columnb"))
+ >>> table2 = table("table2", column("columna"), column("columnc"))
+ >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_NONE))
+ SELECT table1.columna, table1.columnb, table2.columna, table2.columnc
+ FROM table1 JOIN table2 ON true
+
+ Used with the :meth:`_sql.Select.set_label_style` method.
+
+ .. versionadded:: 1.4
+
+""", # noqa: E501
+)
+
+LABEL_STYLE_TABLENAME_PLUS_COL = util.symbol(
+ "LABEL_STYLE_TABLENAME_PLUS_COL",
+ """Label style indicating all columns should be labeled as
+ ``<tablename>_<columnname>`` when generating the columns clause of a SELECT
+ statement, to disambiguate same-named columns referenced from different
+ tables, aliases, or subqueries.
+
+ Below, all column names are given a label so that the two same-named
+ columns ``columna`` are disambiguated as ``table1_columna`` and
+ ``table2_columna``::
+
+ >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_TABLENAME_PLUS_COL
+ >>> table1 = table("table1", column("columna"), column("columnb"))
+ >>> table2 = table("table2", column("columna"), column("columnc"))
+ >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL))
+ SELECT table1.columna AS table1_columna, table1.columnb AS table1_columnb, table2.columna AS table2_columna, table2.columnc AS table2_columnc
+ FROM table1 JOIN table2 ON true
+
+ Used with the :meth:`_sql.GenerativeSelect.set_label_style` method.
+ Equivalent to the legacy method ``Select.apply_labels()``;
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL` is SQLAlchemy's legacy
+ auto-labeling style. :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY` provides a
+ less intrusive approach to disambiguation of same-named column expressions.
+
+
+ .. versionadded:: 1.4
+
+""", # noqa: E501
+)
+
+
+LABEL_STYLE_DISAMBIGUATE_ONLY = util.symbol(
+ "LABEL_STYLE_DISAMBIGUATE_ONLY",
+ """Label style indicating that columns with a name that conflicts with
+ an existing name should be labeled with a semi-anonymizing label
+ when generating the columns clause of a SELECT statement.
+
+ Below, most column names are left unaffected, except for the second
+ occurrence of the name ``columna``, which is labeled using the
+ label ``columna_1`` to disambiguate it from that of ``table1.columna``::
+
+ >>> from sqlalchemy import table, column, select, true, LABEL_STYLE_DISAMBIGUATE_ONLY
+ >>> table1 = table("table1", column("columna"), column("columnb"))
+ >>> table2 = table("table2", column("columna"), column("columnc"))
+ >>> print(select(table1, table2).join(table2, true()).set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY))
+ SELECT table1.columna, table1.columnb, table2.columna AS columna_1, table2.columnc
+ FROM table1 JOIN table2 ON true
+
+ Used with the :meth:`_sql.GenerativeSelect.set_label_style` method,
+ :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY` is the default labeling style
+ for all SELECT statements outside of :term:`1.x style` ORM queries.
+
+ .. versionadded:: 1.4
+
+""", # noqa: E501,
+)
+
+
+LABEL_STYLE_DEFAULT = LABEL_STYLE_DISAMBIGUATE_ONLY
+"""The default label style, refers to
+:data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`.
+
+.. versionadded:: 1.4
+
+"""
+
+
+class Join(roles.DMLTableRole, FromClause):
+ """Represent a ``JOIN`` construct between two
+ :class:`_expression.FromClause`
+ elements.
+
+ The public constructor function for :class:`_expression.Join`
+ is the module-level
+ :func:`_expression.join()` function, as well as the
+ :meth:`_expression.FromClause.join` method
+ of any :class:`_expression.FromClause` (e.g. such as
+ :class:`_schema.Table`).
+
+ .. seealso::
+
+ :func:`_expression.join`
+
+ :meth:`_expression.FromClause.join`
+
+ """
+
+ __visit_name__ = "join"
+
+ _traverse_internals = [
+ ("left", InternalTraversal.dp_clauseelement),
+ ("right", InternalTraversal.dp_clauseelement),
+ ("onclause", InternalTraversal.dp_clauseelement),
+ ("isouter", InternalTraversal.dp_boolean),
+ ("full", InternalTraversal.dp_boolean),
+ ]
+
+ _is_join = True
+
+ def __init__(self, left, right, onclause=None, isouter=False, full=False):
+ """Construct a new :class:`_expression.Join`.
+
+ The usual entrypoint here is the :func:`_expression.join`
+ function or the :meth:`_expression.FromClause.join` method of any
+ :class:`_expression.FromClause` object.
+
+ """
+ self.left = coercions.expect(
+ roles.FromClauseRole, left, deannotate=True
+ )
+ self.right = coercions.expect(
+ roles.FromClauseRole, right, deannotate=True
+ ).self_group()
+
+ if onclause is None:
+ self.onclause = self._match_primaries(self.left, self.right)
+ else:
+ # note: taken from If91f61527236fd4d7ae3cad1f24c38be921c90ba
+ # not merged yet
+ self.onclause = coercions.expect(
+ roles.OnClauseRole, onclause
+ ).self_group(against=operators._asbool)
+
+ self.isouter = isouter
+ self.full = full
+
+ @classmethod
+ def _create_outerjoin(cls, left, right, onclause=None, full=False):
+ """Return an ``OUTER JOIN`` clause element.
+
+ The returned object is an instance of :class:`_expression.Join`.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.outerjoin` method on any
+ :class:`_expression.FromClause`.
+
+ :param left: The left side of the join.
+
+ :param right: The right side of the join.
+
+ :param onclause: Optional criterion for the ``ON`` clause, is
+ derived from foreign key relationships established between
+ left and right otherwise.
+
+ To chain joins together, use the :meth:`_expression.FromClause.join`
+ or
+ :meth:`_expression.FromClause.outerjoin` methods on the resulting
+ :class:`_expression.Join` object.
+
+ """
+ return cls(left, right, onclause, isouter=True, full=full)
+
+ @classmethod
+ def _create_join(
+ cls, left, right, onclause=None, isouter=False, full=False
+ ):
+ """Produce a :class:`_expression.Join` object, given two
+ :class:`_expression.FromClause`
+ expressions.
+
+ E.g.::
+
+ j = join(user_table, address_table,
+ user_table.c.id == address_table.c.user_id)
+ stmt = select(user_table).select_from(j)
+
+ would emit SQL along the lines of::
+
+ SELECT user.id, user.name FROM user
+ JOIN address ON user.id = address.user_id
+
+ Similar functionality is available given any
+ :class:`_expression.FromClause` object (e.g. such as a
+ :class:`_schema.Table`) using
+ the :meth:`_expression.FromClause.join` method.
+
+ :param left: The left side of the join.
+
+ :param right: the right side of the join; this is any
+ :class:`_expression.FromClause` object such as a
+ :class:`_schema.Table` object, and
+ may also be a selectable-compatible object such as an ORM-mapped
+ class.
+
+ :param onclause: a SQL expression representing the ON clause of the
+ join. If left at ``None``, :meth:`_expression.FromClause.join`
+ will attempt to
+ join the two tables based on a foreign key relationship.
+
+ :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN.
+
+ :param full: if True, render a FULL OUTER JOIN, instead of JOIN.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :meth:`_expression.FromClause.join` - method form,
+ based on a given left side.
+
+ :class:`_expression.Join` - the type of object produced.
+
+ """
+
+ return cls(left, right, onclause, isouter, full)
+
+ @property
+ def description(self):
+ return "Join object on %s(%d) and %s(%d)" % (
+ self.left.description,
+ id(self.left),
+ self.right.description,
+ id(self.right),
+ )
+
+ def is_derived_from(self, fromclause):
+ return (
+ # use hash() to ensure direct comparison to annotated works
+ # as well
+ hash(fromclause) == hash(self)
+ or self.left.is_derived_from(fromclause)
+ or self.right.is_derived_from(fromclause)
+ )
+
+ def self_group(self, against=None):
+ return FromGrouping(self)
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _populate_column_collection(self):
+ sqlutil = util.preloaded.sql_util
+ columns = [c for c in self.left.columns] + [
+ c for c in self.right.columns
+ ]
+
+ self.primary_key.extend(
+ sqlutil.reduce_columns(
+ (c for c in columns if c.primary_key), self.onclause
+ )
+ )
+ self._columns._populate_separate_keys(
+ (col._tq_key_label, col) for col in columns
+ )
+ self.foreign_keys.update(
+ itertools.chain(*[col.foreign_keys for col in columns])
+ )
+
+ def _copy_internals(self, clone=_clone, **kw):
+ # see Select._copy_internals() for similar concept
+
+ # here we pre-clone "left" and "right" so that we can
+ # determine the new FROM clauses
+ all_the_froms = set(
+ itertools.chain(
+ _from_objects(self.left),
+ _from_objects(self.right),
+ )
+ )
+
+ # run the clone on those. these will be placed in the
+ # cache used by the clone function
+ new_froms = {f: clone(f, **kw) for f in all_the_froms}
+
+ # set up a special replace function that will replace for
+ # ColumnClause with parent table referring to those
+ # replaced FromClause objects
+ def replace(obj, **kw):
+ if isinstance(obj, ColumnClause) and obj.table in new_froms:
+ newelem = new_froms[obj.table].corresponding_column(obj)
+ return newelem
+
+ kw["replace"] = replace
+
+ # run normal _copy_internals. the clones for
+ # left and right will come from the clone function's
+ # cache
+ super(Join, self)._copy_internals(clone=clone, **kw)
+
+ self._reset_memoizations()
+
+ def _refresh_for_new_column(self, column):
+ super(Join, self)._refresh_for_new_column(column)
+ self.left._refresh_for_new_column(column)
+ self.right._refresh_for_new_column(column)
+
+ def _match_primaries(self, left, right):
+ if isinstance(left, Join):
+ left_right = left.right
+ else:
+ left_right = None
+ return self._join_condition(left, right, a_subset=left_right)
+
+ @classmethod
+ def _join_condition(
+ cls, a, b, a_subset=None, consider_as_foreign_keys=None
+ ):
+ """Create a join condition between two tables or selectables.
+
+ e.g.::
+
+ join_condition(tablea, tableb)
+
+ would produce an expression along the lines of::
+
+ tablea.c.id==tableb.c.tablea_id
+
+ The join is determined based on the foreign key relationships
+ between the two selectables. If there are multiple ways
+ to join, or no way to join, an error is raised.
+
+ :param a_subset: An optional expression that is a sub-component
+ of ``a``. An attempt will be made to join to just this sub-component
+ first before looking at the full ``a`` construct, and if found
+ will be successful even if there are other ways to join to ``a``.
+ This allows the "right side" of a join to be passed thereby
+ providing a "natural join".
+
+ """
+ constraints = cls._joincond_scan_left_right(
+ a, a_subset, b, consider_as_foreign_keys
+ )
+
+ if len(constraints) > 1:
+ cls._joincond_trim_constraints(
+ a, b, constraints, consider_as_foreign_keys
+ )
+
+ if len(constraints) == 0:
+ if isinstance(b, FromGrouping):
+ hint = (
+ " Perhaps you meant to convert the right side to a "
+ "subquery using alias()?"
+ )
+ else:
+ hint = ""
+ raise exc.NoForeignKeysError(
+ "Can't find any foreign key relationships "
+ "between '%s' and '%s'.%s"
+ % (a.description, b.description, hint)
+ )
+
+ crit = [(x == y) for x, y in list(constraints.values())[0]]
+ if len(crit) == 1:
+ return crit[0]
+ else:
+ return and_(*crit)
+
+ @classmethod
+ def _can_join(cls, left, right, consider_as_foreign_keys=None):
+ if isinstance(left, Join):
+ left_right = left.right
+ else:
+ left_right = None
+
+ constraints = cls._joincond_scan_left_right(
+ a=left,
+ b=right,
+ a_subset=left_right,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
+
+ return bool(constraints)
+
+ @classmethod
+ @util.preload_module("sqlalchemy.sql.util")
+ def _joincond_scan_left_right(
+ cls, a, a_subset, b, consider_as_foreign_keys
+ ):
+ sql_util = util.preloaded.sql_util
+
+ a = coercions.expect(roles.FromClauseRole, a)
+ b = coercions.expect(roles.FromClauseRole, b)
+
+ constraints = collections.defaultdict(list)
+
+ for left in (a_subset, a):
+ if left is None:
+ continue
+ for fk in sorted(
+ b.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
+ continue
+ try:
+ col = fk.get_referent(left)
+ except exc.NoReferenceError as nrte:
+ table_names = {t.name for t in sql_util.find_tables(left)}
+ if nrte.table_name in table_names:
+ raise
+ else:
+ continue
+
+ if col is not None:
+ constraints[fk.constraint].append((col, fk.parent))
+ if left is not b:
+ for fk in sorted(
+ left.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
+ continue
+ try:
+ col = fk.get_referent(b)
+ except exc.NoReferenceError as nrte:
+ table_names = {t.name for t in sql_util.find_tables(b)}
+ if nrte.table_name in table_names:
+ raise
+ else:
+ continue
+
+ if col is not None:
+ constraints[fk.constraint].append((col, fk.parent))
+ if constraints:
+ break
+ return constraints
+
+ @classmethod
+ def _joincond_trim_constraints(
+ cls, a, b, constraints, consider_as_foreign_keys
+ ):
+ # more than one constraint matched. narrow down the list
+ # to include just those FKCs that match exactly to
+ # "consider_as_foreign_keys".
+ if consider_as_foreign_keys:
+ for const in list(constraints):
+ if set(f.parent for f in const.elements) != set(
+ consider_as_foreign_keys
+ ):
+ del constraints[const]
+
+ # if still multiple constraints, but
+ # they all refer to the exact same end result, use it.
+ if len(constraints) > 1:
+ dedupe = set(tuple(crit) for crit in constraints.values())
+ if len(dedupe) == 1:
+ key = list(constraints)[0]
+ constraints = {key: constraints[key]}
+
+ if len(constraints) != 1:
+ raise exc.AmbiguousForeignKeysError(
+ "Can't determine join between '%s' and '%s'; "
+ "tables have more than one foreign key "
+ "constraint relationship between them. "
+ "Please specify the 'onclause' of this "
+ "join explicitly." % (a.description, b.description)
+ )
+
+ @util.deprecated_params(
+ whereclause=(
+ "2.0",
+ "The :paramref:`_sql.Join.select().whereclause` parameter "
+ "is deprecated and will be removed in version 2.0. "
+ "Please make use of "
+ "the :meth:`.Select.where` "
+ "method to add WHERE criteria to the SELECT statement.",
+ ),
+ kwargs=(
+ "2.0",
+ "The :meth:`_sql.Join.select` method will no longer accept "
+ "keyword arguments in version 2.0. Please use generative "
+ "methods from the "
+ ":class:`_sql.Select` construct in order to apply additional "
+ "modifications.",
+ ),
+ )
+ def select(self, whereclause=None, **kwargs):
+ r"""Create a :class:`_expression.Select` from this
+ :class:`_expression.Join`.
+
+ E.g.::
+
+ stmt = table_a.join(table_b, table_a.c.id == table_b.c.a_id)
+
+ stmt = stmt.select()
+
+ The above will produce a SQL string resembling::
+
+ SELECT table_a.id, table_a.col, table_b.id, table_b.a_id
+ FROM table_a JOIN table_b ON table_a.id = table_b.a_id
+
+ :param whereclause: WHERE criteria, same as calling
+ :meth:`_sql.Select.where` on the resulting statement
+
+ :param \**kwargs: additional keyword arguments are passed to the
+ legacy constructor for :class:`_sql.Select` described at
+ :meth:`_sql.Select.create_legacy_select`.
+
+ """
+ collist = [self.left, self.right]
+
+ if whereclause is not None:
+ kwargs["whereclause"] = whereclause
+ return Select._create_select_from_fromclause(
+ self, collist, **kwargs
+ ).select_from(self)
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Return the bound engine associated with either the left or right
+ side of this :class:`_sql.Join`.
+
+ """
+
+ return self.left.bind or self.right.bind
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _anonymous_fromclause(self, name=None, flat=False):
+ sqlutil = util.preloaded.sql_util
+ if flat:
+ if name is not None:
+ raise exc.ArgumentError("Can't send name argument with flat")
+ left_a, right_a = (
+ self.left._anonymous_fromclause(flat=True),
+ self.right._anonymous_fromclause(flat=True),
+ )
+ adapter = sqlutil.ClauseAdapter(left_a).chain(
+ sqlutil.ClauseAdapter(right_a)
+ )
+
+ return left_a.join(
+ right_a,
+ adapter.traverse(self.onclause),
+ isouter=self.isouter,
+ full=self.full,
+ )
+ else:
+ return (
+ self.select()
+ .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+ .correlate(None)
+ .alias(name)
+ )
+
+ @util.deprecated_20(
+ ":meth:`_sql.Join.alias`",
+ alternative="Create a select + subquery, or alias the "
+ "individual tables inside the join, instead.",
+ )
+ def alias(self, name=None, flat=False):
+ r"""Return an alias of this :class:`_expression.Join`.
+
+ The default behavior here is to first produce a SELECT
+ construct from this :class:`_expression.Join`, then to produce an
+ :class:`_expression.Alias` from that. So given a join of the form::
+
+ j = table_a.join(table_b, table_a.c.id == table_b.c.a_id)
+
+ The JOIN by itself would look like::
+
+ table_a JOIN table_b ON table_a.id = table_b.a_id
+
+ Whereas the alias of the above, ``j.alias()``, would in a
+ SELECT context look like::
+
+ (SELECT table_a.id AS table_a_id, table_b.id AS table_b_id,
+ table_b.a_id AS table_b_a_id
+ FROM table_a
+ JOIN table_b ON table_a.id = table_b.a_id) AS anon_1
+
+ The equivalent long-hand form, given a :class:`_expression.Join`
+ object ``j``, is::
+
+ from sqlalchemy import select, alias
+ j = alias(
+ select(j.left, j.right).\
+ select_from(j).\
+ set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).\
+ correlate(False),
+ name=name
+ )
+
+ The selectable produced by :meth:`_expression.Join.alias`
+ features the same
+ columns as that of the two individual selectables presented under
+ a single name - the individual columns are "auto-labeled", meaning
+ the ``.c.`` collection of the resulting :class:`_expression.Alias`
+ represents
+ the names of the individual columns using a
+ ``<tablename>_<columname>`` scheme::
+
+ j.c.table_a_id
+ j.c.table_b_a_id
+
+ :meth:`_expression.Join.alias` also features an alternate
+ option for aliasing joins which produces no enclosing SELECT and
+ does not normally apply labels to the column names. The
+ ``flat=True`` option will call :meth:`_expression.FromClause.alias`
+ against the left and right sides individually.
+ Using this option, no new ``SELECT`` is produced;
+ we instead, from a construct as below::
+
+ j = table_a.join(table_b, table_a.c.id == table_b.c.a_id)
+ j = j.alias(flat=True)
+
+ we get a result like this::
+
+ table_a AS table_a_1 JOIN table_b AS table_b_1 ON
+ table_a_1.id = table_b_1.a_id
+
+ The ``flat=True`` argument is also propagated to the contained
+ selectables, so that a composite join such as::
+
+ j = table_a.join(
+ table_b.join(table_c,
+ table_b.c.id == table_c.c.b_id),
+ table_b.c.a_id == table_a.c.id
+ ).alias(flat=True)
+
+ Will produce an expression like::
+
+ table_a AS table_a_1 JOIN (
+ table_b AS table_b_1 JOIN table_c AS table_c_1
+ ON table_b_1.id = table_c_1.b_id
+ ) ON table_a_1.id = table_b_1.a_id
+
+ The standalone :func:`_expression.alias` function as well as the
+ base :meth:`_expression.FromClause.alias`
+ method also support the ``flat=True``
+ argument as a no-op, so that the argument can be passed to the
+ ``alias()`` method of any selectable.
+
+ :param name: name given to the alias.
+
+ :param flat: if True, produce an alias of the left and right
+ sides of this :class:`_expression.Join` and return the join of those
+ two selectables. This produces a join expression that does not
+ include an enclosing SELECT.
+
+ .. seealso::
+
+ :ref:`core_tutorial_aliases`
+
+ :func:`_expression.alias`
+
+ """
+ return self._anonymous_fromclause(flat=flat, name=name)
+
+ @property
+ def _hide_froms(self):
+ return itertools.chain(
+ *[_from_objects(x.left, x.right) for x in self._cloned_set]
+ )
+
+ @property
+ def _from_objects(self):
+ return [self] + self.left._from_objects + self.right._from_objects
+
+
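+# Illustrative standalone sketch (not library code): the flat=True aliasing
+# form described in Join.alias() above, which aliases the two sides of the
+# join individually rather than wrapping the join in a SELECT subquery.  The
+# table definitions are assumptions for demonstration only.
+from sqlalchemy import column, table
+
+table_a = table("table_a", column("id"), column("col"))
+table_b = table("table_b", column("id"), column("a_id"))
+
+j = table_a.join(table_b, table_a.c.id == table_b.c.a_id)
+# renders roughly: table_a AS table_a_1 JOIN table_b AS table_b_1
+#                  ON table_a_1.id = table_b_1.a_id
+flat_alias = j.alias(flat=True)
+
+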
+class NoInit(object):
+ def __init__(self, *arg, **kw):
+ raise NotImplementedError(
+ "The %s class is not intended to be constructed "
+ "directly. Please use the %s() standalone "
+ "function or the %s() method available from appropriate "
+ "selectable objects."
+ % (
+ self.__class__.__name__,
+ self.__class__.__name__.lower(),
+ self.__class__.__name__.lower(),
+ )
+ )
+
+
+# FromClause ->
+# AliasedReturnsRows
+# -> Alias only for FromClause
+# -> Subquery only for SelectBase
+# -> CTE only for HasCTE -> SelectBase, DML
+# -> Lateral -> FromClause, but we accept SelectBase
+# w/ non-deprecated coercion
+# -> TableSample -> only for FromClause
+class AliasedReturnsRows(NoInit, FromClause):
+ """Base class of aliases against tables, subqueries, and other
+ selectables."""
+
+ _is_from_container = True
+ named_with_column = True
+
+ _supports_derived_columns = False
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("name", InternalTraversal.dp_anon_name),
+ ]
+
+ @classmethod
+ def _construct(cls, *arg, **kw):
+ obj = cls.__new__(cls)
+ obj._init(*arg, **kw)
+ return obj
+
+ @classmethod
+ def _factory(cls, returnsrows, name=None):
+ """Base factory method. Subclasses need to provide this."""
+ raise NotImplementedError()
+
+ def _init(self, selectable, name=None):
+ self.element = coercions.expect(
+ roles.ReturnsRowsRole, selectable, apply_propagate_attrs=self
+ )
+ self.element = selectable
+ self._orig_name = name
+ if name is None:
+ if (
+ isinstance(selectable, FromClause)
+ and selectable.named_with_column
+ ):
+ name = getattr(selectable, "name", None)
+ if isinstance(name, _anonymous_label):
+ name = None
+ name = _anonymous_label.safe_construct(id(self), name or "anon")
+ self.name = name
+
+ def _refresh_for_new_column(self, column):
+ super(AliasedReturnsRows, self)._refresh_for_new_column(column)
+ self.element._refresh_for_new_column(column)
+
+ @property
+ def description(self):
+ name = self.name
+ if isinstance(name, _anonymous_label):
+ name = "anon_1"
+
+ if util.py3k:
+ return name
+ else:
+ return name.encode("ascii", "backslashreplace")
+
+ @property
+ def original(self):
+ """Legacy for dialects that are referring to Alias.original."""
+ return self.element
+
+ def is_derived_from(self, fromclause):
+ if fromclause in self._cloned_set:
+ return True
+ return self.element.is_derived_from(fromclause)
+
+ def _populate_column_collection(self):
+ self.element._generate_fromclause_column_proxies(self)
+
+ def _copy_internals(self, clone=_clone, **kw):
+ existing_element = self.element
+
+ super(AliasedReturnsRows, self)._copy_internals(clone=clone, **kw)
+
+ # the element clone is usually against a Table that returns the
+ # same object. don't reset exported .c. collections and other
+ # memoized details if it was not changed. this saves a lot on
+ # performance.
+ if existing_element is not self.element:
+ self._reset_column_collection()
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+ @property
+ def bind(self):
+ return self.element.bind
+
+
+class Alias(roles.DMLTableRole, AliasedReturnsRows):
+ """Represents an table or selectable alias (AS).
+
+ Represents an alias, as typically applied to any table or
+ sub-select within a SQL statement using the ``AS`` keyword (or
+ without the keyword on certain databases such as Oracle).
+
+ This object is constructed from the :func:`_expression.alias` module
+ level function as well as the :meth:`_expression.FromClause.alias`
+ method available
+ on all :class:`_expression.FromClause` subclasses.
+
+ .. seealso::
+
+ :meth:`_expression.FromClause.alias`
+
+ """
+
+ __visit_name__ = "alias"
+
+ inherit_cache = True
+
+ @classmethod
+ def _factory(cls, selectable, name=None, flat=False):
+ """Return an :class:`_expression.Alias` object.
+
+ An :class:`_expression.Alias` represents any
+ :class:`_expression.FromClause`
+ with an alternate name assigned within SQL, typically using the ``AS``
+ clause when generated, e.g. ``SELECT * FROM table AS aliasname``.
+
+ Similar functionality is available via the
+ :meth:`_expression.FromClause.alias`
+ method available on all :class:`_expression.FromClause` subclasses.
+ In terms of
+ a SELECT object as generated from the :func:`_expression.select`
+ function, the :meth:`_expression.SelectBase.alias` method returns an
+ :class:`_expression.Alias` or similar object which represents a named,
+ parenthesized subquery.
+
+ When an :class:`_expression.Alias` is created from a
+ :class:`_schema.Table` object,
+ this has the effect of the table being rendered
+ as ``tablename AS aliasname`` in a SELECT statement.
+
+ For :func:`_expression.select` objects, the effect is that of
+ creating a named subquery, i.e. ``(select ...) AS aliasname``.
+
+ The ``name`` parameter is optional, and provides the name
+ to use in the rendered SQL. If blank, an "anonymous" name
+ will be deterministically generated at compile time.
+ Deterministic means the name is guaranteed to be unique against
+ other constructs used in the same statement, and will also be the
+ same name for each successive compilation of the same statement
+ object.
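+
+ For example, assuming a hypothetical ``user_table`` :class:`_schema.Table`
+ (used here purely for illustration), the standalone function and the
+ method form produce equivalent constructs::
+
+ from sqlalchemy import alias, select
+
+ user_alias = alias(user_table, name="u")
+ # equivalent spelling using the method form
+ user_alias = user_table.alias("u")
+
+ stmt = select(user_alias.c.id)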
+
+ :param selectable: any :class:`_expression.FromClause` subclass,
+ such as a table, select statement, etc.
+
+ :param name: string name to be assigned as the alias.
+ If ``None``, a name will be deterministically generated
+ at compile time.
+
+ :param flat: Will be passed through to :meth:`_expression.Join.alias`
+ if the given selectable
+ is an instance of :class:`_expression.Join` - see that method
+ for details.
+
+ """
+ return coercions.expect(
+ roles.FromClauseRole, selectable, allow_select=True
+ ).alias(name=name, flat=flat)
+
+
+class TableValuedAlias(Alias):
+ """An alias against a "table valued" SQL function.
+
+ This construct provides for a SQL function that returns columns
+ to be used in the FROM clause of a SELECT statement. The
+ object is generated using the :meth:`_functions.FunctionElement.table_valued`
+ method, e.g.::
+
+ >>> from sqlalchemy import select, func
+ >>> fn = func.json_array_elements_text('["one", "two", "three"]').table_valued("value")
+ >>> print(select(fn.c.value))
+ SELECT anon_1.value
+ FROM json_array_elements_text(:json_array_elements_text_1) AS anon_1
+
+ .. versionadded:: 1.4.0b2
+
+ .. seealso::
+
+ :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial`
+
+ """ # noqa: E501
+
+ __visit_name__ = "table_valued_alias"
+
+ _supports_derived_columns = True
+ _render_derived = False
+ _render_derived_w_types = False
+ joins_implicitly = False
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("name", InternalTraversal.dp_anon_name),
+ ("_tableval_type", InternalTraversal.dp_type),
+ ("_render_derived", InternalTraversal.dp_boolean),
+ ("_render_derived_w_types", InternalTraversal.dp_boolean),
+ ]
+
+ def _init(
+ self,
+ selectable,
+ name=None,
+ table_value_type=None,
+ joins_implicitly=False,
+ ):
+ super(TableValuedAlias, self)._init(selectable, name=name)
+
+ self.joins_implicitly = joins_implicitly
+ self._tableval_type = (
+ type_api.TABLEVALUE
+ if table_value_type is None
+ else table_value_type
+ )
+
+ @HasMemoized.memoized_attribute
+ def column(self):
+ """Return a column expression representing this
+ :class:`_sql.TableValuedAlias`.
+
+ This accessor is used to implement the
+ :meth:`_functions.FunctionElement.column_valued` method. See that
+ method for further details.
+
+ E.g.::
+
+ >>> print(select(func.some_func().table_valued("value").column))
+ SELECT anon_1 FROM some_func() AS anon_1
+
+ .. seealso::
+
+ :meth:`_functions.FunctionElement.column_valued`
+
+ """
+
+ return TableValuedColumn(self, self._tableval_type)
+
+ def alias(self, name=None):
+ """Return a new alias of this :class:`_sql.TableValuedAlias`.
+
+ This creates a distinct FROM object that will be distinguished
+ from the original one when used in a SQL statement.
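+
+ A brief sketch, reusing the ``json_array_elements_text`` example from
+ the class-level documentation above::
+
+ fn = func.json_array_elements_text('["a", "b"]').table_valued("value")
+ fn2 = fn.alias("fn2")
+ stmt = select(fn.c.value, fn2.c.value)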
+
+ """
+
+ tva = TableValuedAlias._construct(
+ self,
+ name=name,
+ table_value_type=self._tableval_type,
+ joins_implicitly=self.joins_implicitly,
+ )
+
+ if self._render_derived:
+ tva._render_derived = True
+ tva._render_derived_w_types = self._render_derived_w_types
+
+ return tva
+
+ def lateral(self, name=None):
+ """Return a new :class:`_sql.TableValuedAlias` with the lateral flag
+ set, so that it renders as LATERAL.
+
+ .. seealso::
+
+ :func:`_expression.lateral`
+
+ """
+ tva = self.alias(name=name)
+ tva._is_lateral = True
+ return tva
+
+ def render_derived(self, name=None, with_types=False):
+ """Apply "render derived" to this :class:`_sql.TableValuedAlias`.
+
+ This has the effect of the individual column names listed out
+ after the alias name in the "AS" sequence, e.g.::
+
+ >>> print(
+ ... select(
+ ... func.unnest(array(["one", "two", "three"])).
+ table_valued("x", with_ordinality="o").render_derived()
+ ... )
+ ... )
+ SELECT anon_1.x, anon_1.o
+ FROM unnest(ARRAY[%(param_1)s, %(param_2)s, %(param_3)s]) WITH ORDINALITY AS anon_1(x, o)
+
+ The ``with_types`` keyword will render column types inline within
+ the alias expression (this syntax currently applies to the
+ PostgreSQL database)::
+
+ >>> print(
+ ... select(
+ ... func.json_to_recordset(
+ ... '[{"a":1,"b":"foo"},{"a":"2","c":"bar"}]'
+ ... )
+ ... .table_valued(column("a", Integer), column("b", String))
+ ... .render_derived(with_types=True)
+ ... )
+ ... )
+ SELECT anon_1.a, anon_1.b FROM json_to_recordset(:json_to_recordset_1)
+ AS anon_1(a INTEGER, b VARCHAR)
+
+ :param name: optional string name that will be applied to the alias
+ generated. If left as None, a unique anonymizing name will be used.
+
+ :param with_types: if True, the derived columns will include the
+ datatype specification with each column. This is a special syntax
+ currently known to be required by PostgreSQL for some SQL functions.
+
+ """ # noqa: E501
+
+ # note: don't use the @_generative system here, keep a reference
+ # to the original object. otherwise you can have re-use of the
+ # python id() of the original which can cause name conflicts if
+ # a new anon-name grabs the same identifier as the local anon-name
+ # (just saw it happen on CI)
+
+ # construct against original to prevent memory growth
+ # for repeated generations
+ new_alias = TableValuedAlias._construct(
+ self.element,
+ name=name,
+ table_value_type=self._tableval_type,
+ joins_implicitly=self.joins_implicitly,
+ )
+ new_alias._render_derived = True
+ new_alias._render_derived_w_types = with_types
+ return new_alias
+
+
+class Lateral(AliasedReturnsRows):
+ """Represent a LATERAL subquery.
+
+ This object is constructed from the :func:`_expression.lateral` module
+ level function as well as the :meth:`_expression.FromClause.lateral`
+ method available
+ on all :class:`_expression.FromClause` subclasses.
+
+ While LATERAL is part of the SQL standard, currently only more recent
+ PostgreSQL versions provide support for this keyword.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+ """
+
+ __visit_name__ = "lateral"
+ _is_lateral = True
+
+ inherit_cache = True
+
+ @classmethod
+ def _factory(cls, selectable, name=None):
+ """Return a :class:`_expression.Lateral` object.
+
+ :class:`_expression.Lateral` is an :class:`_expression.Alias`
+ subclass that represents
+ a subquery with the LATERAL keyword applied to it.
+
+ The special behavior of a LATERAL subquery is that it appears in the
+ FROM clause of an enclosing SELECT, but may correlate to other
+ FROM clauses of that SELECT. It is a special case of subquery
+ only supported by a small number of backends, currently more recent
+ PostgreSQL versions.
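+
+ An illustrative sketch; the ``customers`` and ``orders`` tables here
+ are hypothetical::
+
+ subq = (
+ select(orders.c.amount)
+ .where(orders.c.customer_id == customers.c.id)
+ .lateral("customer_orders")
+ )
+ stmt = select(customers.c.name, subq.c.amount)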
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+
+ """
+ return coercions.expect(
+ roles.FromClauseRole, selectable, explicit_subquery=True
+ ).lateral(name=name)
+
+
+class TableSample(AliasedReturnsRows):
+ """Represent a TABLESAMPLE clause.
+
+ This object is constructed from the :func:`_expression.tablesample` module
+ level function as well as the :meth:`_expression.FromClause.tablesample`
+ method
+ available on all :class:`_expression.FromClause` subclasses.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :func:`_expression.tablesample`
+
+ """
+
+ __visit_name__ = "tablesample"
+
+ _traverse_internals = AliasedReturnsRows._traverse_internals + [
+ ("sampling", InternalTraversal.dp_clauseelement),
+ ("seed", InternalTraversal.dp_clauseelement),
+ ]
+
+ @classmethod
+ def _factory(cls, selectable, sampling, name=None, seed=None):
+ """Return a :class:`_expression.TableSample` object.
+
+ :class:`_expression.TableSample` is an :class:`_expression.Alias`
+ subclass that represents
+ a table with the TABLESAMPLE clause applied to it.
+ :func:`_expression.tablesample`
+ is also available from the :class:`_expression.FromClause`
+ class via the
+ :meth:`_expression.FromClause.tablesample` method.
+
+ The TABLESAMPLE clause allows selecting an approximate
+ percentage of rows from a table at random. It supports multiple sampling methods,
+ most commonly BERNOULLI and SYSTEM.
+
+ e.g.::
+
+ from sqlalchemy import func
+
+ selectable = people.tablesample(
+ func.bernoulli(1),
+ name='alias',
+ seed=func.random())
+ stmt = select(selectable.c.people_id)
+
+ Assuming ``people`` with a column ``people_id``, the above
+ statement would render as::
+
+ SELECT alias.people_id FROM
+ people AS alias TABLESAMPLE bernoulli(:bernoulli_1)
+ REPEATABLE (random())
+
+ .. versionadded:: 1.1
+
+ :param sampling: a ``float`` percentage between 0 and 100 or
+ :class:`_functions.Function`.
+
+ :param name: optional alias name
+
+ :param seed: any real-valued SQL expression. When specified, the
+ REPEATABLE sub-clause is also rendered.
+
+ """
+ return coercions.expect(roles.FromClauseRole, selectable).tablesample(
+ sampling, name=name, seed=seed
+ )
+
+ @util.preload_module("sqlalchemy.sql.functions")
+ def _init(self, selectable, sampling, name=None, seed=None):
+ functions = util.preloaded.sql_functions
+ if not isinstance(sampling, functions.Function):
+ sampling = functions.func.system(sampling)
+
+ self.sampling = sampling
+ self.seed = seed
+ super(TableSample, self)._init(selectable, name=name)
+
+ def _get_method(self):
+ return self.sampling
+
+
+class CTE(
+ roles.DMLTableRole,
+ roles.IsCTERole,
+ Generative,
+ HasPrefixes,
+ HasSuffixes,
+ AliasedReturnsRows,
+):
+ """Represent a Common Table Expression.
+
+ The :class:`_expression.CTE` object is obtained using the
+ :meth:`_sql.SelectBase.cte` method from any SELECT statement. A less often
+ available syntax also allows use of the :meth:`_sql.HasCTE.cte` method
+ present on :term:`DML` constructs such as :class:`_sql.Insert`,
+ :class:`_sql.Update` and
+ :class:`_sql.Delete`. See the :meth:`_sql.HasCTE.cte` method for
+ usage details on CTEs.
+
+ .. seealso::
+
+ :ref:`tutorial_subqueries_ctes` - in the 2.0 tutorial
+
+ :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+ """
+
+ __visit_name__ = "cte"
+
+ _traverse_internals = (
+ AliasedReturnsRows._traverse_internals
+ + [
+ ("_cte_alias", InternalTraversal.dp_clauseelement),
+ ("_restates", InternalTraversal.dp_clauseelement),
+ ("recursive", InternalTraversal.dp_boolean),
+ ("nesting", InternalTraversal.dp_boolean),
+ ]
+ + HasPrefixes._has_prefixes_traverse_internals
+ + HasSuffixes._has_suffixes_traverse_internals
+ )
+
+ @classmethod
+ def _factory(cls, selectable, name=None, recursive=False):
+ r"""Return a new :class:`_expression.CTE`,
+ or Common Table Expression instance.
+
+ Please see :meth:`_expression.HasCTE.cte` for detail on CTE usage.
+
+ """
+ return coercions.expect(roles.HasCTERole, selectable).cte(
+ name=name, recursive=recursive
+ )
+
+ def _init(
+ self,
+ selectable,
+ name=None,
+ recursive=False,
+ nesting=False,
+ _cte_alias=None,
+ _restates=None,
+ _prefixes=None,
+ _suffixes=None,
+ ):
+ self.recursive = recursive
+ self.nesting = nesting
+ self._cte_alias = _cte_alias
+ # Keep recursivity reference with union/union_all
+ self._restates = _restates
+ if _prefixes:
+ self._prefixes = _prefixes
+ if _suffixes:
+ self._suffixes = _suffixes
+ super(CTE, self)._init(selectable, name=name)
+
+ def _populate_column_collection(self):
+ if self._cte_alias is not None:
+ self._cte_alias._generate_fromclause_column_proxies(self)
+ else:
+ self.element._generate_fromclause_column_proxies(self)
+
+ def alias(self, name=None, flat=False):
+ """Return an :class:`_expression.Alias` of this
+ :class:`_expression.CTE`.
+
+ This method is a CTE-specific specialization of the
+ :meth:`_expression.FromClause.alias` method.
+
+ .. seealso::
+
+ :ref:`tutorial_using_aliases`
+
+ :func:`_expression.alias`
+
+ """
+ return CTE._construct(
+ self.element,
+ name=name,
+ recursive=self.recursive,
+ nesting=self.nesting,
+ _cte_alias=self,
+ _prefixes=self._prefixes,
+ _suffixes=self._suffixes,
+ )
+
+ def union(self, *other):
+ r"""Return a new :class:`_expression.CTE` with a SQL ``UNION``
+ of the original CTE against the given selectables provided
+ as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ UNION.
+
+ .. versionchanged:: 1.4.28 multiple elements are now accepted.
+
+ .. seealso::
+
+ :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+ """
+ return CTE._construct(
+ self.element.union(*other),
+ name=self.name,
+ recursive=self.recursive,
+ nesting=self.nesting,
+ _restates=self,
+ _prefixes=self._prefixes,
+ _suffixes=self._suffixes,
+ )
+
+ def union_all(self, *other):
+ r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL``
+ of the original CTE against the given selectables provided
+ as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ UNION ALL.
+
+ .. versionchanged:: 1.4.28 multiple elements are now accepted.
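+
+ A sketch of the typical recursive pattern, using a hypothetical
+ ``parts`` table as in the examples at :meth:`_sql.HasCTE.cte`::
+
+ included = (
+ select(parts.c.sub_part)
+ .where(parts.c.part == "our part")
+ .cte("included", recursive=True)
+ )
+ included = included.union_all(
+ select(parts.c.sub_part).where(parts.c.part == included.c.sub_part)
+ )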
+
+ .. seealso::
+
+ :meth:`_sql.HasCTE.cte` - examples of calling styles
+
+ """
+ return CTE._construct(
+ self.element.union_all(*other),
+ name=self.name,
+ recursive=self.recursive,
+ nesting=self.nesting,
+ _restates=self,
+ _prefixes=self._prefixes,
+ _suffixes=self._suffixes,
+ )
+
+ def _get_reference_cte(self):
+ """
+ A recursive CTE is updated to attach the recursive part.
+ Updated CTEs should still refer to the original CTE.
+ This method returns the original CTE, which serves as that reference.
+ """
+ return self._restates if self._restates is not None else self
+
+
+class HasCTE(roles.HasCTERole):
+ """Mixin that declares a class to include CTE support.
+
+ .. versionadded:: 1.1
+
+ """
+
+ _has_ctes_traverse_internals = [
+ ("_independent_ctes", InternalTraversal.dp_clauseelement_list),
+ ]
+
+ _independent_ctes = ()
+
+ @_generative
+ def add_cte(self, cte):
+ """Add a :class:`_sql.CTE` to this statement object that will be
+ independently rendered even if not referenced in the statement
+ otherwise.
+
+ This feature is useful for the use case of embedding a DML statement
+ such as an INSERT or UPDATE as a CTE inline with a primary statement
+ that may draw from its results indirectly; while PostgreSQL is known
+ to support this usage, it may not be supported by other backends.
+
+ E.g.::
+
+ from sqlalchemy import table, column, select
+ t = table('t', column('c1'), column('c2'))
+
+ ins = t.insert().values({"c1": "x", "c2": "y"}).cte()
+
+ stmt = select(t).add_cte(ins)
+
+ Would render::
+
+ WITH anon_1 AS
+ (INSERT INTO t (c1, c2) VALUES (:param_1, :param_2))
+ SELECT t.c1, t.c2
+ FROM t
+
+ Above, the "anon_1" CTE is not referred towards in the SELECT
+ statement, however still accomplishes the task of running an INSERT
+ statement.
+
+ Similarly in a DML-related context, using the PostgreSQL
+ :class:`_postgresql.Insert` construct to generate an "upsert"::
+
+ from sqlalchemy import table, column
+ from sqlalchemy.dialects.postgresql import insert
+
+ t = table("t", column("c1"), column("c2"))
+
+ delete_statement_cte = (
+ t.delete().where(t.c.c1 < 1).cte("deletions")
+ )
+
+ insert_stmt = insert(t).values({"c1": 1, "c2": 2})
+ update_statement = insert_stmt.on_conflict_do_update(
+ index_elements=[t.c.c1],
+ set_={
+ "c1": insert_stmt.excluded.c1,
+ "c2": insert_stmt.excluded.c2,
+ },
+ ).add_cte(delete_statement_cte)
+
+ print(update_statement)
+
+ The above statement renders as::
+
+ WITH deletions AS
+ (DELETE FROM t WHERE t.c1 < %(c1_1)s)
+ INSERT INTO t (c1, c2) VALUES (%(c1)s, %(c2)s)
+ ON CONFLICT (c1) DO UPDATE SET c1 = excluded.c1, c2 = excluded.c2
+
+ .. versionadded:: 1.4.21
+
+ """
+ cte = coercions.expect(roles.IsCTERole, cte)
+ self._independent_ctes += (cte,)
+
+ def cte(self, name=None, recursive=False, nesting=False):
+ r"""Return a new :class:`_expression.CTE`,
+ or Common Table Expression instance.
+
+ Common table expressions are a SQL standard whereby SELECT
+ statements can draw upon secondary statements specified along
+ with the primary statement, using a clause called "WITH".
+ Special semantics regarding UNION can also be employed to
+ allow "recursive" queries, where a SELECT statement can draw
+ upon the set of rows that have previously been selected.
+
+ CTEs can also be applied to DML constructs UPDATE, INSERT
+ and DELETE on some databases, both as a source of CTE rows
+ when combined with RETURNING, as well as a consumer of
+ CTE rows.
+
+ .. versionchanged:: 1.1 Added support for UPDATE/INSERT/DELETE as
+ CTE, CTEs added to UPDATE/INSERT/DELETE.
+
+ SQLAlchemy detects :class:`_expression.CTE` objects, which are treated
+ similarly to :class:`_expression.Alias` objects, as special elements
+ to be delivered to the FROM clause of the statement as well
+ as to a WITH clause at the top of the statement.
+
+ For special prefixes such as PostgreSQL "MATERIALIZED" and
+ "NOT MATERIALIZED", the :meth:`_expression.CTE.prefix_with`
+ method may be
+ used to establish these.
+
+ .. versionchanged:: 1.3.13 Added support for prefixes.
+ In particular - MATERIALIZED and NOT MATERIALIZED.
+
+ :param name: name given to the common table expression. Like
+ :meth:`_expression.FromClause.alias`, the name can be left as
+ ``None`` in which case an anonymous symbol will be used at query
+ compile time.
+ :param recursive: if ``True``, will render ``WITH RECURSIVE``.
+ A recursive common table expression is intended to be used in
+ conjunction with UNION ALL in order to derive rows
+ from those already selected.
+ :param nesting: if ``True``, will render the CTE locally to the
+ statement in which it is referenced, rather than at the top level
+ of the enclosing statement (see Example 4 below).
+
+ .. versionadded:: 1.4.24
+
+ The following examples include two from PostgreSQL's documentation at
+ https://www.postgresql.org/docs/current/static/queries-with.html,
+ as well as additional examples.
+
+ Example 1, non recursive::
+
+ from sqlalchemy import (Table, Column, String, Integer,
+ MetaData, select, func)
+
+ metadata = MetaData()
+
+ orders = Table('orders', metadata,
+ Column('region', String),
+ Column('amount', Integer),
+ Column('product', String),
+ Column('quantity', Integer)
+ )
+
+ regional_sales = select(
+ orders.c.region,
+ func.sum(orders.c.amount).label('total_sales')
+ ).group_by(orders.c.region).cte("regional_sales")
+
+
+ top_regions = select(regional_sales.c.region).\
+ where(
+ regional_sales.c.total_sales >
+ select(
+ func.sum(regional_sales.c.total_sales) / 10
+ )
+ ).cte("top_regions")
+
+ statement = select(
+ orders.c.region,
+ orders.c.product,
+ func.sum(orders.c.quantity).label("product_units"),
+ func.sum(orders.c.amount).label("product_sales")
+ ).where(orders.c.region.in_(
+ select(top_regions.c.region)
+ )).group_by(orders.c.region, orders.c.product)
+
+ result = conn.execute(statement).fetchall()
+
+ Example 2, WITH RECURSIVE::
+
+ from sqlalchemy import (Table, Column, String, Integer,
+ MetaData, select, func)
+
+ metadata = MetaData()
+
+ parts = Table('parts', metadata,
+ Column('part', String),
+ Column('sub_part', String),
+ Column('quantity', Integer),
+ )
+
+ included_parts = select(\
+ parts.c.sub_part, parts.c.part, parts.c.quantity\
+ ).\
+ where(parts.c.part=='our part').\
+ cte(recursive=True)
+
+
+ incl_alias = included_parts.alias()
+ parts_alias = parts.alias()
+ included_parts = included_parts.union_all(
+ select(
+ parts_alias.c.sub_part,
+ parts_alias.c.part,
+ parts_alias.c.quantity
+ ).\
+ where(parts_alias.c.part==incl_alias.c.sub_part)
+ )
+
+ statement = select(
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).
+ label('total_quantity')
+ ).\
+ group_by(included_parts.c.sub_part)
+
+ result = conn.execute(statement).fetchall()
+
+ Example 3, an upsert using UPDATE and INSERT with CTEs::
+
+ from datetime import date
+ from sqlalchemy import (MetaData, Table, Column, Integer,
+ Date, select, literal, and_, exists)
+
+ metadata = MetaData()
+
+ visitors = Table('visitors', metadata,
+ Column('product_id', Integer, primary_key=True),
+ Column('date', Date, primary_key=True),
+ Column('count', Integer),
+ )
+
+ # add 5 visitors for the product_id == 1
+ product_id = 1
+ day = date.today()
+ count = 5
+
+ update_cte = (
+ visitors.update()
+ .where(and_(visitors.c.product_id == product_id,
+ visitors.c.date == day))
+ .values(count=visitors.c.count + count)
+ .returning(literal(1))
+ .cte('update_cte')
+ )
+
+ upsert = visitors.insert().from_select(
+ [visitors.c.product_id, visitors.c.date, visitors.c.count],
+ select(literal(product_id), literal(day), literal(count))
+ .where(~exists(update_cte.select()))
+ )
+
+ connection.execute(upsert)
+
+ Example 4, Nesting CTE (SQLAlchemy 1.4.24 and above)::
+
+ value_a = select(
+ literal("root").label("n")
+ ).cte("value_a")
+
+ # A nested CTE with the same name as the root one
+ value_a_nested = select(
+ literal("nesting").label("n")
+ ).cte("value_a", nesting=True)
+
+ # Nested CTEs take precedence locally
+ # over CTEs defined at a higher level
+ value_b = select(value_a_nested.c.n).cte("value_b")
+
+ value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b"))
+
+ The above query will render the second CTE nested inside the first,
+ shown with inline parameters below as::
+
+ WITH
+ value_a AS
+ (SELECT 'root' AS n),
+ value_b AS
+ (WITH value_a AS
+ (SELECT 'nesting' AS n)
+ SELECT value_a.n AS n FROM value_a)
+ SELECT value_a.n AS a, value_b.n AS b
+ FROM value_a, value_b
+
+ Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above)::
+
+ edge = Table(
+ "edge",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("left", Integer),
+ Column("right", Integer),
+ )
+
+ root_node = select(literal(1).label("node")).cte(
+ "nodes", recursive=True
+ )
+
+ left_edge = select(edge.c.left).join(
+ root_node, edge.c.right == root_node.c.node
+ )
+ right_edge = select(edge.c.right).join(
+ root_node, edge.c.left == root_node.c.node
+ )
+
+ subgraph_cte = root_node.union(left_edge, right_edge)
+
+ subgraph = select(subgraph_cte)
+
+ The above query will render 2 UNIONs inside the recursive CTE::
+
+ WITH RECURSIVE nodes(node) AS (
+ SELECT 1 AS node
+ UNION
+ SELECT edge."left" AS "left"
+ FROM edge JOIN nodes ON edge."right" = nodes.node
+ UNION
+ SELECT edge."right" AS "right"
+ FROM edge JOIN nodes ON edge."left" = nodes.node
+ )
+ SELECT nodes.node FROM nodes
+
+ .. seealso::
+
+ :meth:`_orm.Query.cte` - ORM version of
+ :meth:`_expression.HasCTE.cte`.
+
+ """
+ return CTE._construct(
+ self, name=name, recursive=recursive, nesting=nesting
+ )
+
+
+class Subquery(AliasedReturnsRows):
+ """Represent a subquery of a SELECT.
+
+ A :class:`.Subquery` is created by invoking the
+ :meth:`_expression.SelectBase.subquery` method, or for convenience the
+ :meth:`_expression.SelectBase.alias` method, on any
+ :class:`_expression.SelectBase` subclass
+ which includes :class:`_expression.Select`,
+ :class:`_expression.CompoundSelect`, and
+ :class:`_expression.TextualSelect`. As rendered in a FROM clause,
+ it represents the
+ body of the SELECT statement inside of parentheses, followed by the usual
+ "AS <somename>" that defines all "alias" objects.
+
+ The :class:`.Subquery` object is very similar to the
+ :class:`_expression.Alias`
+ object and can be used in an equivalent way. The difference between
+ :class:`_expression.Alias` and :class:`.Subquery` is that
+ :class:`_expression.Alias` always
+ contains a :class:`_expression.FromClause` object whereas
+ :class:`.Subquery`
+ always contains a :class:`_expression.SelectBase` object.
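+
+ A short example, with ``user_table`` standing in for any ordinary
+ table used purely for illustration::
+
+ subq = select(user_table.c.id, user_table.c.name).subquery()
+ stmt = select(subq.c.id).where(subq.c.name.like("a%"))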
+
+ .. versionadded:: 1.4 The :class:`.Subquery` class was added which now
+ serves the purpose of providing an aliased version of a SELECT
+ statement.
+
+ """
+
+ __visit_name__ = "subquery"
+
+ _is_subquery = True
+
+ inherit_cache = True
+
+ @classmethod
+ def _factory(cls, selectable, name=None):
+ """Return a :class:`.Subquery` object."""
+ return coercions.expect(
+ roles.SelectStatementRole, selectable
+ ).subquery(name=name)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`.Subquery.as_scalar` method, which was previously "
+ "``Alias.as_scalar()`` prior to version 1.4, is deprecated and "
+ "will be removed in a future release; Please use the "
+ ":meth:`_expression.Select.scalar_subquery` method of the "
+ ":func:`_expression.select` "
+ "construct before constructing a subquery object, or with the ORM "
+ "use the :meth:`_query.Query.scalar_subquery` method.",
+ )
+ def as_scalar(self):
+ return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery()
+
+ def _execute_on_connection(
+ self,
+ connection,
+ multiparams,
+ params,
+ execution_options,
+ ):
+ util.warn_deprecated(
+ "Executing a subquery object is deprecated and will raise "
+ "ObjectNotExecutableError in an upcoming release. Please "
+ "execute the underlying select() statement directly.",
+ "1.4",
+ )
+ return self.element._execute_on_connection(
+ connection, multiparams, params, execution_options, _force=True
+ )
+
+
+class FromGrouping(GroupedElement, FromClause):
+ """Represent a grouping of a FROM clause"""
+
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
+ def __init__(self, element):
+ self.element = coercions.expect(roles.FromClauseRole, element)
+
+ def _init_collections(self):
+ pass
+
+ @property
+ def columns(self):
+ return self.element.columns
+
+ @property
+ def primary_key(self):
+ return self.element.primary_key
+
+ @property
+ def foreign_keys(self):
+ return self.element.foreign_keys
+
+ def is_derived_from(self, element):
+ return self.element.is_derived_from(element)
+
+ def alias(self, **kw):
+ return FromGrouping(self.element.alias(**kw))
+
+ def _anonymous_fromclause(self, **kw):
+ return FromGrouping(self.element._anonymous_fromclause(**kw))
+
+ @property
+ def _hide_froms(self):
+ return self.element._hide_froms
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+ def __getstate__(self):
+ return {"element": self.element}
+
+ def __setstate__(self, state):
+ self.element = state["element"]
+
+
+class TableClause(roles.DMLTableRole, Immutable, FromClause):
+ """Represents a minimal "table" construct.
+
+ This is a lightweight table object that has only a name, a
+ collection of columns, which are typically produced
+ by the :func:`_expression.column` function, and a schema::
+
+ from sqlalchemy import table, column
+
+ user = table("user",
+ column("id"),
+ column("name"),
+ column("description"),
+ )
+
+ The :class:`_expression.TableClause` construct serves as the base for
+ the more commonly used :class:`_schema.Table` object, providing
+ the usual set of :class:`_expression.FromClause` services including
+ the ``.c.`` collection and statement generation methods.
+
+ It does **not** provide all the additional schema-level services
+ of :class:`_schema.Table`, including constraints, references to other
+ tables, or support for :class:`_schema.MetaData`-level services.
+ It's useful
+ on its own as an ad-hoc construct used to generate quick SQL
+ statements when a more fully fledged :class:`_schema.Table`
+ is not on hand.
+
+ """
+
+ __visit_name__ = "table"
+
+ _traverse_internals = [
+ (
+ "columns",
+ InternalTraversal.dp_fromclause_canonical_column_collection,
+ ),
+ ("name", InternalTraversal.dp_string),
+ ]
+
+ named_with_column = True
+
+ implicit_returning = False
+ """:class:`_expression.TableClause`
+ doesn't support having a primary key or column-level
+ defaults, so implicit returning doesn't apply."""
+
+ _autoincrement_column = None
+ """No PK or default support so no autoincrement column."""
+
+ def __init__(self, name, *columns, **kw):
+ """Produce a new :class:`_expression.TableClause`.
+
+ The object returned is an instance of
+ :class:`_expression.TableClause`, which
+ represents the "syntactical" portion of the schema-level
+ :class:`_schema.Table` object.
+ It may be used to construct lightweight table constructs.
+
+ .. versionchanged:: 1.0.0 :func:`_expression.table` can now
+ be imported from the plain ``sqlalchemy`` namespace like any
+ other SQL element.
+
+
+ :param name: Name of the table.
+
+ :param columns: A collection of :func:`_expression.column` constructs.
+
+ :param schema: The schema name for this table.
+
+ .. versionadded:: 1.3.18 :func:`_expression.table` can now
+ accept a ``schema`` argument.
+ """
+
+ super(TableClause, self).__init__()
+ self.name = name
+ self._columns = DedupeColumnCollection()
+ self.primary_key = ColumnSet()
+ self.foreign_keys = set()
+ for c in columns:
+ self.append_column(c)
+
+ schema = kw.pop("schema", None)
+ if schema is not None:
+ self.schema = schema
+ if self.schema is not None:
+ self.fullname = "%s.%s" % (self.schema, self.name)
+ else:
+ self.fullname = self.name
+ if kw:
+ raise exc.ArgumentError("Unsupported argument(s): %s" % list(kw))
+
+ def __str__(self):
+ if self.schema is not None:
+ return self.schema + "." + self.name
+ else:
+ return self.name
+
+ def _refresh_for_new_column(self, column):
+ pass
+
+ def _init_collections(self):
+ pass
+
+ @util.memoized_property
+ def description(self):
+ if util.py3k:
+ return self.name
+ else:
+ return self.name.encode("ascii", "backslashreplace")
+
+ def append_column(self, c, **kw):
+ existing = c.table
+ if existing is not None and existing is not self:
+ raise exc.ArgumentError(
+ "column object '%s' already assigned to table '%s'"
+ % (c.key, existing)
+ )
+
+ self._columns.add(c)
+ c.table = self
+
+ @util.preload_module("sqlalchemy.sql.dml")
+ def insert(self, values=None, inline=False, **kwargs):
+ """Generate an :func:`_expression.insert` construct against this
+ :class:`_expression.TableClause`.
+
+ E.g.::
+
+ table.insert().values(name='foo')
+
+ See :func:`_expression.insert` for argument and usage information.
+
+ """
+ return util.preloaded.sql_dml.Insert(
+ self, values=values, inline=inline, **kwargs
+ )
+
+ @util.preload_module("sqlalchemy.sql.dml")
+ def update(self, whereclause=None, values=None, inline=False, **kwargs):
+ """Generate an :func:`_expression.update` construct against this
+ :class:`_expression.TableClause`.
+
+ E.g.::
+
+ table.update().where(table.c.id==7).values(name='foo')
+
+ See :func:`_expression.update` for argument and usage information.
+
+ """
+ return util.preloaded.sql_dml.Update(
+ self,
+ whereclause=whereclause,
+ values=values,
+ inline=inline,
+ **kwargs
+ )
+
+ @util.preload_module("sqlalchemy.sql.dml")
+ def delete(self, whereclause=None, **kwargs):
+ """Generate a :func:`_expression.delete` construct against this
+ :class:`_expression.TableClause`.
+
+ E.g.::
+
+ table.delete().where(table.c.id==7)
+
+ See :func:`_expression.delete` for argument and usage information.
+
+ """
+ return util.preloaded.sql_dml.Delete(self, whereclause, **kwargs)
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+
+class ForUpdateArg(ClauseElement):
+ _traverse_internals = [
+ ("of", InternalTraversal.dp_clauseelement_list),
+ ("nowait", InternalTraversal.dp_boolean),
+ ("read", InternalTraversal.dp_boolean),
+ ("skip_locked", InternalTraversal.dp_boolean),
+ ]
+
+ @classmethod
+ def _from_argument(cls, with_for_update):
+ if isinstance(with_for_update, ForUpdateArg):
+ return with_for_update
+ elif with_for_update in (None, False):
+ return None
+ elif with_for_update is True:
+ return ForUpdateArg()
+ else:
+ return ForUpdateArg(**with_for_update)
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, ForUpdateArg)
+ and other.nowait == self.nowait
+ and other.read == self.read
+ and other.skip_locked == self.skip_locked
+ and other.key_share == self.key_share
+ and other.of is self.of
+ )
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __hash__(self):
+ return id(self)
+
+ def __init__(
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
+ """Represents arguments specified to
+ :meth:`_expression.Select.for_update`.
+
+ """
+
+ self.nowait = nowait
+ self.read = read
+ self.skip_locked = skip_locked
+ self.key_share = key_share
+ if of is not None:
+ self.of = [
+ coercions.expect(roles.ColumnsClauseRole, elem)
+ for elem in util.to_list(of)
+ ]
+ else:
+ self.of = None
+
+
+class Values(Generative, FromClause):
+ """Represent a ``VALUES`` construct that can be used as a FROM element
+ in a statement.
+
+ The :class:`_expression.Values` object is created from the
+ :func:`_expression.values` function.
+
+ .. versionadded:: 1.4
+
+ """
+
+ named_with_column = True
+ __visit_name__ = "values"
+
+ _data = ()
+
+ _traverse_internals = [
+ ("_column_args", InternalTraversal.dp_clauseelement_list),
+ ("_data", InternalTraversal.dp_dml_multi_values),
+ ("name", InternalTraversal.dp_string),
+ ("literal_binds", InternalTraversal.dp_boolean),
+ ]
+
+ def __init__(self, *columns, **kw):
+ r"""Construct a :class:`_expression.Values` construct.
+
+ The column expressions and the actual data for
+ :class:`_expression.Values` are given in two separate steps. The
+ constructor receives the column expressions typically as
+ :func:`_expression.column` constructs,
+ and the data is then passed via the
+ :meth:`_expression.Values.data` method as a list,
+ which can be called multiple
+ times to add more data, e.g.::
+
+ from sqlalchemy import column
+ from sqlalchemy import values
+
+ value_expr = values(
+ column('id', Integer),
+ column('name', String),
+ name="my_values"
+ ).data(
+ [(1, 'name1'), (2, 'name2'), (3, 'name3')]
+ )
+
+ :param \*columns: column expressions, typically composed using
+ :func:`_expression.column` objects.
+
+ :param name: the name for this VALUES construct. If omitted, the
+ VALUES construct will be unnamed in a SQL expression. Different
+ backends may have different requirements here.
+
+ :param literal_binds: Defaults to False. Whether or not to render
+ the data values inline in the SQL output, rather than using bound
+ parameters.
+
+ """
+
+ super(Values, self).__init__()
+ self._column_args = columns
+ self.name = kw.pop("name", None)
+ self.literal_binds = kw.pop("literal_binds", False)
+ self.named_with_column = self.name is not None
+
+ @property
+ def _column_types(self):
+ return [col.type for col in self._column_args]
+
+ @_generative
+ def alias(self, name, **kw):
+ """Return a new :class:`_expression.Values`
+ construct that is a copy of this
+ one with the given name.
+
+ This method is a VALUES-specific specialization of the
+ :meth:`_expression.FromClause.alias` method.
+
+ .. seealso::
+
+ :ref:`tutorial_using_aliases`
+
+ :func:`_expression.alias`
+
+ """
+ self.name = name
+ self.named_with_column = self.name is not None
+
+ @_generative
+ def lateral(self, name=None):
+ """Return a new :class:`_expression.Values` with the lateral flag set,
+ so that
+ it renders as LATERAL.
+
+ .. seealso::
+
+ :func:`_expression.lateral`
+
+ """
+ self._is_lateral = True
+ if name is not None:
+ self.name = name
+
+ @_generative
+ def data(self, values):
+ """Return a new :class:`_expression.Values` construct,
+ adding the given data
+ to the data list.
+
+ E.g.::
+
+ my_values = my_values.data([(1, 'value 1'), (2, 'value2')])
+
+ :param values: a sequence (i.e. list) of tuples that map to the
+ column expressions given in the :class:`_expression.Values`
+ constructor.
+
+ """
+
+ self._data += (values,)
+
+ def _populate_column_collection(self):
+ for c in self._column_args:
+ self._columns.add(c)
+ c.table = self
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+
+class SelectBase(
+ roles.SelectStatementRole,
+ roles.DMLSelectRole,
+ roles.CompoundElementRole,
+ roles.InElementRole,
+ HasCTE,
+ Executable,
+ SupportsCloneAnnotations,
+ Selectable,
+):
+ """Base class for SELECT statements.
+
+
+ This includes :class:`_expression.Select`,
+ :class:`_expression.CompoundSelect` and
+ :class:`_expression.TextualSelect`.
+
+
+ """
+
+ _is_select_statement = True
+ is_select = True
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ raise NotImplementedError()
+
+ def _refresh_for_new_column(self, column):
+ self._reset_memoizations()
+
+ @property
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set.
+
+ This collection differs from the :attr:`_expression.FromClause.columns`
+ collection of a :class:`_expression.FromClause` in that the columns
+ within this collection cannot be directly nested inside another SELECT
+ statement; a subquery must be applied first which provides for the
+ necessary parenthesization required by SQL.
+
+ .. note::
+
+ The :attr:`_sql.SelectBase.selected_columns` collection does not
+ include expressions established in the columns clause using the
+ :func:`_sql.text` construct; these are silently omitted from the
+ collection. To use plain textual column expressions inside of a
+ :class:`_sql.Select` construct, use the :func:`_sql.literal_column`
+ construct.
+
+ .. seealso::
+
+ :attr:`_sql.Select.selected_columns`
+
+ .. versionadded:: 1.4
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def _all_selected_columns(self):
+ """A sequence of expressions that correspond to what is rendered
+ in the columns clause, including :class:`_sql.TextClause`
+ constructs.
+
+ .. versionadded:: 1.4.12
+
+ .. seealso::
+
+ :attr:`_sql.SelectBase.exported_columns`
+
+ """
+ raise NotImplementedError()
+
+ @property
+ def exported_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ that represents the "exported"
+ columns of this :class:`_expression.Selectable`, not including
+ :class:`_sql.TextClause` constructs.
+
+ The "exported" columns for a :class:`_expression.SelectBase`
+ object are synonymous
+ with the :attr:`_expression.SelectBase.selected_columns` collection.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_expression.Select.exported_columns`
+
+ :attr:`_expression.Selectable.exported_columns`
+
+ :attr:`_expression.FromClause.exported_columns`
+
+
+ """
+ return self.selected_columns
+
+ @property
+ @util.deprecated(
+ "1.4",
+ "The :attr:`_expression.SelectBase.c` and "
+ ":attr:`_expression.SelectBase.columns` attributes "
+ "are deprecated and will be removed in a future release; these "
+ "attributes implicitly create a subquery that should be explicit. "
+ "Please call :meth:`_expression.SelectBase.subquery` "
+ "first in order to create "
+ "a subquery, which then contains this attribute. To access the "
+ "columns that this SELECT object SELECTs "
+ "from, use the :attr:`_expression.SelectBase.selected_columns` "
+ "attribute.",
+ )
+ def c(self):
+ return self._implicit_subquery.columns
+
+ @property
+ def columns(self):
+ return self.c
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.SelectBase.select` method is deprecated "
+ "and will be removed in a future release; this method implicitly "
+ "creates a subquery that should be explicit. "
+ "Please call :meth:`_expression.SelectBase.subquery` "
+ "first in order to create "
+ "a subquery, which then can be selected.",
+ )
+ def select(self, *arg, **kw):
+ return self._implicit_subquery.select(*arg, **kw)
+
+ @HasMemoized.memoized_attribute
+ def _implicit_subquery(self):
+ return self.subquery()
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.SelectBase.as_scalar` "
+ "method is deprecated and will be "
+ "removed in a future release. Please refer to "
+ ":meth:`_expression.SelectBase.scalar_subquery`.",
+ )
+ def as_scalar(self):
+ return self.scalar_subquery()
+
+ def exists(self):
+ """Return an :class:`_sql.Exists` representation of this selectable,
+ which can be used as a column expression.
+
+ The returned object is an instance of :class:`_sql.Exists`.
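+
+ A minimal sketch, with ``users`` as a placeholder table::
+
+ criteria = select(users.c.id).where(users.c.name == "some name").exists()
+ stmt = select(users.c.name).where(criteria)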
+
+ .. seealso::
+
+ :func:`_sql.exists`
+
+ :ref:`tutorial_exists` - in the :term:`2.0 style` tutorial.
+
+ .. versionadded:: 1.4
+
+ """
+ return Exists(self)
+
+ def scalar_subquery(self):
+ """Return a 'scalar' representation of this selectable, which can be
+ used as a column expression.
+
+ The returned object is an instance of :class:`_sql.ScalarSelect`.
+
+ Typically, a select statement which has only one column in its columns
+ clause is eligible to be used as a scalar expression. The scalar
+ subquery can then be used in the WHERE clause or columns clause of
+ an enclosing SELECT.
+
+ Note that the scalar subquery differentiates from the FROM-level
+ subquery that can be produced using the
+ :meth:`_expression.SelectBase.subquery`
+ method.
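+
+ For example, a correlated count, sketched here against hypothetical
+ ``users`` and ``addresses`` tables::
+
+ address_count = (
+ select(func.count(addresses.c.id))
+ .where(addresses.c.user_id == users.c.id)
+ .scalar_subquery()
+ )
+ stmt = select(users.c.name, address_count.label("address_count"))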
+
+ .. versionchanged:: 1.4 - the ``.as_scalar()`` method was renamed to
+ :meth:`_expression.SelectBase.scalar_subquery`.
+
+ .. seealso::
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+ """
+ if self._label_style is not LABEL_STYLE_NONE:
+ self = self.set_label_style(LABEL_STYLE_NONE)
+
+ return ScalarSelect(self)
+
+ def label(self, name):
+ """Return a 'scalar' representation of this selectable, embedded as a
+ subquery with a label.
+
+ .. seealso::
+
+ :meth:`_expression.SelectBase.as_scalar`.
+
+ """
+ return self.scalar_subquery().label(name)
+
+ def lateral(self, name=None):
+ """Return a LATERAL alias of this :class:`_expression.Selectable`.
+
+ The return value is the :class:`_expression.Lateral` construct also
+ provided by the top-level :func:`_expression.lateral` function.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`tutorial_lateral_correlation` - overview of usage.
+
+ """
+ return Lateral._factory(self, name)
+
+ @property
+ def _from_objects(self):
+ return [self]
+
+ def subquery(self, name=None):
+ """Return a subquery of this :class:`_expression.SelectBase`.
+
+ A subquery is from a SQL perspective a parenthesized, named
+ construct that can be placed in the FROM clause of another
+ SELECT statement.
+
+ Given a SELECT statement such as::
+
+ stmt = select(table.c.id, table.c.name)
+
+ The above statement might look like::
+
+ SELECT table.id, table.name FROM table
+
+ The subquery form by itself renders the same way, however when
+ embedded into the FROM clause of another SELECT statement, it becomes
+ a named sub-element::
+
+ subq = stmt.subquery()
+ new_stmt = select(subq)
+
+ The above renders as::
+
+ SELECT anon_1.id, anon_1.name
+ FROM (SELECT table.id, table.name FROM table) AS anon_1
+
+ Historically, :meth:`_expression.SelectBase.subquery`
+ is equivalent to calling
+ the :meth:`_expression.FromClause.alias`
+ method on a FROM object; however,
+ as a :class:`_expression.SelectBase`
+ object is not directly a FROM object,
+ the :meth:`_expression.SelectBase.subquery`
+ method provides clearer semantics.
+
+ .. versionadded:: 1.4
+
+ """
+
+ return Subquery._construct(self._ensure_disambiguated_names(), name)
+
+ def _ensure_disambiguated_names(self):
+ """Ensure that the names generated by this selectbase will be
+ disambiguated in some way, if possible.
+
+ """
+
+ raise NotImplementedError()
+
+ def alias(self, name=None, flat=False):
+ """Return a named subquery against this
+ :class:`_expression.SelectBase`.
+
+ For a :class:`_expression.SelectBase` (as opposed to a
+ :class:`_expression.FromClause`),
+ this returns a :class:`.Subquery` object which behaves mostly the
+ same as the :class:`_expression.Alias` object that is used with a
+ :class:`_expression.FromClause`.
+
+ .. versionchanged:: 1.4 The :meth:`_expression.SelectBase.alias`
+ method is now
+ a synonym for the :meth:`_expression.SelectBase.subquery` method.
+
+ """
+ return self.subquery(name=name)
+
+
+class SelectStatementGrouping(GroupedElement, SelectBase):
+ """Represent a grouping of a :class:`_expression.SelectBase`.
+
+ This differs from :class:`.Subquery` in that we are still
+ an "inner" SELECT statement, this is strictly for grouping inside of
+ compound selects.
+
+ """
+
+ __visit_name__ = "select_statement_grouping"
+ _traverse_internals = [("element", InternalTraversal.dp_clauseelement)]
+
+ _is_select_container = True
+
+ def __init__(self, element):
+ self.element = coercions.expect(roles.SelectStatementRole, element)
+
+ def _ensure_disambiguated_names(self):
+ new_element = self.element._ensure_disambiguated_names()
+ if new_element is not self.element:
+ return SelectStatementGrouping(new_element)
+ else:
+ return self
+
+ def get_label_style(self):
+ return self._label_style
+
+ def set_label_style(self, label_style):
+ return SelectStatementGrouping(
+ self.element.set_label_style(label_style)
+ )
+
+ @property
+ def _label_style(self):
+ return self.element._label_style
+
+ @property
+ def select_statement(self):
+ return self.element
+
+ def self_group(self, against=None):
+ return self
+
+ def _generate_columns_plus_names(self, anon_for_dupe_key):
+ return self.element._generate_columns_plus_names(anon_for_dupe_key)
+
+ def _generate_fromclause_column_proxies(self, subquery):
+ self.element._generate_fromclause_column_proxies(subquery)
+
+ def _generate_proxy_for_new_column(self, column, subquery):
+ return self.element._generate_proxy_for_new_column(subquery)
+
+ @property
+ def _all_selected_columns(self):
+ return self.element._all_selected_columns
+
+ @property
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ the embedded SELECT statement returns in its result set, not including
+ :class:`_sql.TextClause` constructs.
+
+ .. versionadded:: 1.4
+
+ .. seealso::
+
+ :attr:`_sql.Select.selected_columns`
+
+ """
+ return self.element.selected_columns
+
+ @property
+ def _from_objects(self):
+ return self.element._from_objects
+
+
+class DeprecatedSelectBaseGenerations(object):
+ """A collection of methods available on :class:`_sql.Select` and
+ :class:`_sql.CompoundSelect`; these are all **deprecated** methods as they
+ modify the object in-place.
+
+ """
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.GenerativeSelect.append_order_by` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative method "
+ ":meth:`_expression.GenerativeSelect.order_by`.",
+ )
+ def append_order_by(self, *clauses):
+ """Append the given ORDER BY criterion applied to this selectable.
+
+ The criterion will be appended to any pre-existing ORDER BY criterion.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.GenerativeSelect.order_by` method is preferred,
+ as it
+ provides standard :term:`method chaining`.
+
+ .. seealso::
+
+ :meth:`_expression.GenerativeSelect.order_by`
+
+ """
+ self.order_by.non_generative(self, *clauses)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.GenerativeSelect.append_group_by` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative method "
+ ":meth:`_expression.GenerativeSelect.group_by`.",
+ )
+ def append_group_by(self, *clauses):
+ """Append the given GROUP BY criterion applied to this selectable.
+
+ The criterion will be appended to any pre-existing GROUP BY criterion.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.GenerativeSelect.group_by` method is preferred,
+ as it
+ provides standard :term:`method chaining`.
+
+
+ """
+ self.group_by.non_generative(self, *clauses)
+
+
+class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
+ """Base class for SELECT statements where additional elements can be
+ added.
+
+ This serves as the base for :class:`_expression.Select` and
+ :class:`_expression.CompoundSelect`
+ where elements such as ORDER BY, GROUP BY can be added and column
+ rendering can be controlled. Compare to
+ :class:`_expression.TextualSelect`, which,
+ while it subclasses :class:`_expression.SelectBase`
+ and is also a SELECT construct,
+ represents a fixed textual string which cannot be altered at this level,
+ only wrapped as a subquery.
+
+ """
+
+ _order_by_clauses = ()
+ _group_by_clauses = ()
+ _limit_clause = None
+ _offset_clause = None
+ _fetch_clause = None
+ _fetch_clause_options = None
+ _for_update_arg = None
+
+ @util.deprecated_params(
+ bind=(
+ "2.0",
+ "The :paramref:`_sql.select.bind` argument is deprecated and "
+ "will be removed in SQLAlchemy 2.0.",
+ ),
+ )
+ def __init__(
+ self,
+ _label_style=LABEL_STYLE_DEFAULT,
+ use_labels=False,
+ limit=None,
+ offset=None,
+ order_by=None,
+ group_by=None,
+ bind=None,
+ ):
+ if use_labels:
+ if util.SQLALCHEMY_WARN_20:
+ util.warn_deprecated_20(
+ "The use_labels=True keyword argument to GenerativeSelect "
+ "is deprecated and will be removed in version 2.0. Please "
+ "use "
+ "select.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
+ "if you need to replicate this legacy behavior.",
+ stacklevel=4,
+ )
+ _label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+
+ self._label_style = _label_style
+
+ if limit is not None:
+ self.limit.non_generative(self, limit)
+ if offset is not None:
+ self.offset.non_generative(self, offset)
+
+ if order_by is not None:
+ self.order_by.non_generative(self, *util.to_list(order_by))
+ if group_by is not None:
+ self.group_by.non_generative(self, *util.to_list(group_by))
+
+ self._bind = bind
+
+ @_generative
+ def with_for_update(
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
+ """Specify a ``FOR UPDATE`` clause for this
+ :class:`_expression.GenerativeSelect`.
+
+ E.g.::
+
+ stmt = select(table).with_for_update(nowait=True)
+
+ On a database like PostgreSQL or Oracle, the above would render a
+ statement like::
+
+ SELECT table.a, table.b FROM table FOR UPDATE NOWAIT
+
+ On other backends, the ``nowait`` option is ignored and instead
+ would produce::
+
+ SELECT table.a, table.b FROM table FOR UPDATE
+
+ When called with no arguments, the statement will render with
+ the suffix ``FOR UPDATE``. Additional arguments can then be
+ provided which allow for common database-specific
+ variants.
+
+ :param nowait: boolean; will render ``FOR UPDATE NOWAIT`` on Oracle
+ and PostgreSQL dialects.
+
+ :param read: boolean; will render ``LOCK IN SHARE MODE`` on MySQL,
+ ``FOR SHARE`` on PostgreSQL. On PostgreSQL, when combined with
+ ``nowait``, will render ``FOR SHARE NOWAIT``.
+
+ :param of: SQL expression or list of SQL expression elements
+ (typically :class:`_schema.Column`
+ objects or a compatible expression) which
+ will render into a ``FOR UPDATE OF`` clause; supported by PostgreSQL
+ and Oracle. May render as a table or as a column depending on
+ backend.
+
+ :param skip_locked: boolean, will render ``FOR UPDATE SKIP LOCKED``
+ on Oracle and PostgreSQL dialects or ``FOR SHARE SKIP LOCKED`` if
+ ``read=True`` is also specified.
+
+ :param key_share: boolean, will render ``FOR NO KEY UPDATE``,
+ or if combined with ``read=True`` will render ``FOR KEY SHARE``,
+ on the PostgreSQL dialect.
+
+ """
+ self._for_update_arg = ForUpdateArg(
+ nowait=nowait,
+ read=read,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
+
+ def get_label_style(self):
+ """
+ Retrieve the current label style.
+
+ .. versionadded:: 1.4
+
+ """
+ return self._label_style
+
+ def set_label_style(self, style):
+ """Return a new selectable with the specified label style.
+
+ There are three "label styles" available,
+ :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`,
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`, and
+ :data:`_sql.LABEL_STYLE_NONE`. The default style is
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`.
+
+ In modern SQLAlchemy, there is not generally a need to change the
+ labeling style, as per-expression labels are more effectively used by
+ making use of the :meth:`_sql.ColumnElement.label` method. In past
+ versions, :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL` was used to
+ disambiguate same-named columns from different tables, aliases, or
+ subqueries; the newer :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY` now
+ applies labels only to names that conflict with an existing name so
+ that the impact of this labeling is minimal.
+
+ The rationale for disambiguation is mostly so that all column
+ expressions are available from a given :attr:`_sql.FromClause.c`
+ collection when a subquery is created.
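+
+ E.g., a sketch with placeholder ``user_table`` and ``address_table``
+ tables; the label style constants are importable from the top-level
+ ``sqlalchemy`` package in the 1.4 series::
+
+ from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL, select
+
+ stmt = select(user_table, address_table).set_label_style(
+ LABEL_STYLE_TABLENAME_PLUS_COL
+ )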
+
+ .. versionadded:: 1.4 - the
+ :meth:`_sql.GenerativeSelect.set_label_style` method replaces the
+ previous combination of ``.apply_labels()``, ``.with_labels()`` and
+ ``use_labels=True`` methods and/or parameters.
+
+ .. seealso::
+
+ :data:`_sql.LABEL_STYLE_DISAMBIGUATE_ONLY`
+
+ :data:`_sql.LABEL_STYLE_TABLENAME_PLUS_COL`
+
+ :data:`_sql.LABEL_STYLE_NONE`
+
+ :data:`_sql.LABEL_STYLE_DEFAULT`
+
+ """
+ if self._label_style is not style:
+ self = self._generate()
+ self._label_style = style
+ return self
+
+ @util.deprecated_20(
+ ":meth:`_sql.GenerativeSelect.apply_labels`",
+ alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) "
+ "instead.",
+ )
+ def apply_labels(self):
+ return self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
+
+ @property
+ def _group_by_clause(self):
+ """ClauseList access to group_by_clauses for legacy dialects"""
+ return ClauseList._construct_raw(
+ operators.comma_op, self._group_by_clauses
+ )
+
+ @property
+ def _order_by_clause(self):
+ """ClauseList access to order_by_clauses for legacy dialects"""
+ return ClauseList._construct_raw(
+ operators.comma_op, self._order_by_clauses
+ )
+
+ def _offset_or_limit_clause(self, element, name=None, type_=None):
+ """Convert the given value to an "offset or limit" clause.
+
+ This handles incoming integers and converts to an expression; if
+ an expression is already given, it is passed through.
+
+ """
+ return coercions.expect(
+ roles.LimitOffsetRole, element, name=name, type_=type_
+ )
+
+ def _offset_or_limit_clause_asint(self, clause, attrname):
+ """Convert the "offset or limit" clause of a select construct to an
+ integer.
+
+ This is only possible if the value is stored as a simple bound
+ parameter. Otherwise, a compilation error is raised.
+
+ """
+ if clause is None:
+ return None
+ try:
+ value = clause._limit_offset_value
+ except AttributeError as err:
+ util.raise_(
+ exc.CompileError(
+ "This SELECT structure does not use a simple "
+ "integer value for %s" % attrname
+ ),
+ replace_context=err,
+ )
+ else:
+ return util.asint(value)
+
+ @property
+ def _limit(self):
+ """Get an integer value for the limit. This should only be used
+ by code that cannot support a limit as a BindParameter or
+ other custom clause as it will throw an exception if the limit
+ isn't currently set to an integer.
+
+ """
+ return self._offset_or_limit_clause_asint(self._limit_clause, "limit")
+
+ def _simple_int_clause(self, clause):
+ """True if the clause is a simple integer, False
+ if it is not present or is a SQL expression.
+ """
+ return isinstance(clause, _OffsetLimitParam)
+
+ @property
+ def _offset(self):
+ """Get an integer value for the offset. This should only be used
+ by code that cannot support an offset as a BindParameter or
+ other custom clause as it will throw an exception if the
+ offset isn't currently set to an integer.
+
+ """
+ return self._offset_or_limit_clause_asint(
+ self._offset_clause, "offset"
+ )
+
+ @property
+ def _has_row_limiting_clause(self):
+ return (
+ self._limit_clause is not None
+ or self._offset_clause is not None
+ or self._fetch_clause is not None
+ )
+
+ @_generative
+ def limit(self, limit):
+ """Return a new selectable with the given LIMIT criterion
+ applied.
+
+ This is a numerical value which usually renders as a ``LIMIT``
+ expression in the resulting select. Backends that don't
+ support ``LIMIT`` will attempt to provide similar
+ functionality.
+
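+ For example, as a minimal sketch, assuming a hypothetical ``users``
+ :class:`_schema.Table`::
+
+ stmt = select(users).limit(10)
+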
+ .. note::
+
+ The :meth:`_sql.GenerativeSelect.limit` method will replace
+ any clause applied with :meth:`_sql.GenerativeSelect.fetch`.
+
+ .. versionchanged:: 1.0.0 - :meth:`_expression.Select.limit` can now
+ accept arbitrary SQL expressions as well as integer values.
+
+ :param limit: an integer LIMIT parameter, or a SQL expression
+ that provides an integer result. Pass ``None`` to reset it.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.fetch`
+
+ :meth:`_sql.GenerativeSelect.offset`
+
+ """
+
+ self._fetch_clause = self._fetch_clause_options = None
+ self._limit_clause = self._offset_or_limit_clause(limit)
+
+ @_generative
+ def fetch(self, count, with_ties=False, percent=False):
+ """Return a new selectable with the given FETCH FIRST criterion
+ applied.
+
+ This is a numeric value which usually renders as a
+ ``FETCH {FIRST | NEXT} [ count ] {ROW | ROWS} {ONLY | WITH TIES}``
+ expression in the resulting select. This functionality is
+ currently implemented for Oracle, PostgreSQL and MSSQL.
+
+ Use :meth:`_sql.GenerativeSelect.offset` to specify the offset.
+
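+ For example, as a minimal sketch, again assuming a hypothetical
+ ``users`` table::
+
+ stmt = select(users).order_by(users.c.id).fetch(5, with_ties=True)
+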
+ .. note::
+
+ The :meth:`_sql.GenerativeSelect.fetch` method will replace
+ any clause applied with :meth:`_sql.GenerativeSelect.limit`.
+
+ .. versionadded:: 1.4
+
+ :param count: an integer COUNT parameter, or a SQL expression
+ that provides an integer result. When ``percent=True`` this will
+ represent the percentage of rows to return, not the absolute value.
+ Pass ``None`` to reset it.
+
+ :param with_ties: When ``True``, the WITH TIES option is used
+ to return any additional rows that tie for the last place in the
+ result set according to the ``ORDER BY`` clause. The
+ ``ORDER BY`` may be mandatory in this case. Defaults to ``False``.
+
+ :param percent: When ``True``, ``count`` represents the percentage
+ of the total number of selected rows to return. Defaults to ``False``.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.limit`
+
+ :meth:`_sql.GenerativeSelect.offset`
+
+ """
+
+ self._limit_clause = None
+ if count is None:
+ self._fetch_clause = self._fetch_clause_options = None
+ else:
+ self._fetch_clause = self._offset_or_limit_clause(count)
+ self._fetch_clause_options = {
+ "with_ties": with_ties,
+ "percent": percent,
+ }
+
+ @_generative
+ def offset(self, offset):
+ """Return a new selectable with the given OFFSET criterion
+ applied.
+
+
+ This is a numeric value which usually renders as an ``OFFSET``
+ expression in the resulting select. Backends that don't
+ support ``OFFSET`` will attempt to provide similar
+ functionality.
+
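+ For example, as a minimal sketch, assuming a hypothetical ``users``
+ table::
+
+ stmt = select(users).limit(10).offset(20)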
+
+ .. versionchanged:: 1.0.0 - :meth:`_expression.Select.offset` can now
+ accept arbitrary SQL expressions as well as integer values.
+
+ :param offset: an integer OFFSET parameter, or a SQL expression
+ that provides an integer result. Pass ``None`` to reset it.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.limit`
+
+ :meth:`_sql.GenerativeSelect.fetch`
+
+ """
+
+ self._offset_clause = self._offset_or_limit_clause(offset)
+
+ @_generative
+ @util.preload_module("sqlalchemy.sql.util")
+ def slice(self, start, stop):
+ """Apply LIMIT / OFFSET to this statement based on a slice.
+
+ The start and stop indices behave like the arguments to Python's
+ built-in :func:`range` function. This method provides an
+ alternative to using ``LIMIT``/``OFFSET`` to get a slice of the
+ query.
+
+ For example, ::
+
+ stmt = select(User).order_by(User.id).slice(1, 3)
+
+ renders as
+
+ .. sourcecode:: sql
+
+ SELECT users.id AS users_id,
+ users.name AS users_name
+ FROM users ORDER BY users.id
+ LIMIT ? OFFSET ?
+ (2, 1)
+
+ .. note::
+
+ The :meth:`_sql.GenerativeSelect.slice` method will replace
+ any clause applied with :meth:`_sql.GenerativeSelect.fetch`.
+
+ .. versionadded:: 1.4 Added the :meth:`_sql.GenerativeSelect.slice`
+ method generalized from the ORM.
+
+ .. seealso::
+
+ :meth:`_sql.GenerativeSelect.limit`
+
+ :meth:`_sql.GenerativeSelect.offset`
+
+ :meth:`_sql.GenerativeSelect.fetch`
+
+ """
+ sql_util = util.preloaded.sql_util
+ self._fetch_clause = self._fetch_clause_options = None
+ self._limit_clause, self._offset_clause = sql_util._make_slice(
+ self._limit_clause, self._offset_clause, start, stop
+ )
+
+ @_generative
+ def order_by(self, *clauses):
+ r"""Return a new selectable with the given list of ORDER BY
+ criteria applied.
+
+ e.g.::
+
+ stmt = select(table).order_by(table.c.id, table.c.name)
+
+ All existing ORDER BY criteria may be cancelled by passing
+ ``None`` by itself. New ORDER BY criteria may then be added by
+ invoking :meth:`_sql.Select.order_by` again, e.g.::
+
+ # will erase all existing ORDER BY criteria and ORDER BY new_col alone
+ stmt = stmt.order_by(None).order_by(new_col)
+
+ :param \*clauses: a series of :class:`_expression.ColumnElement`
+ constructs
+ which will be used to generate an ORDER BY clause.
+
+ .. seealso::
+
+ :ref:`tutorial_order_by` - in the :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and clauses[0] is None:
+ self._order_by_clauses = ()
+ else:
+ self._order_by_clauses += tuple(
+ coercions.expect(roles.OrderByRole, clause)
+ for clause in clauses
+ )
+
+ @_generative
+ def group_by(self, *clauses):
+ r"""Return a new selectable with the given list of GROUP BY
+ criteria applied.
+
+ All existing GROUP BY settings can be suppressed by passing ``None``.
+
+ e.g.::
+
+ stmt = select(table.c.name, func.max(table.c.stat)).\
+ group_by(table.c.name)
+
+ :param \*clauses: a series of :class:`_expression.ColumnElement`
+ constructs
+ which will be used to generate a GROUP BY clause.
+
+ .. seealso::
+
+ :ref:`tutorial_group_by_w_aggregates` - in the
+ :ref:`unified_tutorial`
+
+ :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial`
+
+ """
+
+ if len(clauses) == 1 and clauses[0] is None:
+ self._group_by_clauses = ()
+ else:
+ self._group_by_clauses += tuple(
+ coercions.expect(roles.GroupByRole, clause)
+ for clause in clauses
+ )
+
+
+@CompileState.plugin_for("default", "compound_select")
+class CompoundSelectState(CompileState):
+ @util.memoized_property
+ def _label_resolve_dict(self):
+ # TODO: this is hacky and slow
+ hacky_subquery = self.statement.subquery()
+ hacky_subquery.named_with_column = False
+ d = dict((c.key, c) for c in hacky_subquery.c)
+ return d, d, d
+
+
+class CompoundSelect(HasCompileState, GenerativeSelect):
+ """Forms the basis of ``UNION``, ``UNION ALL``, and other
+ SELECT-based set operations.
+
+
+ .. seealso::
+
+ :func:`_expression.union`
+
+ :func:`_expression.union_all`
+
+ :func:`_expression.intersect`
+
+ :func:`_expression.intersect_all`
+
+ :func:`_expression.except_`
+
+ :func:`_expression.except_all`
+
+ """
+
+ __visit_name__ = "compound_select"
+
+ _traverse_internals = [
+ ("selects", InternalTraversal.dp_clauseelement_list),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
+ ("_order_by_clauses", InternalTraversal.dp_clauseelement_list),
+ ("_group_by_clauses", InternalTraversal.dp_clauseelement_list),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("keyword", InternalTraversal.dp_string),
+ ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
+
+ UNION = util.symbol("UNION")
+ UNION_ALL = util.symbol("UNION ALL")
+ EXCEPT = util.symbol("EXCEPT")
+ EXCEPT_ALL = util.symbol("EXCEPT ALL")
+ INTERSECT = util.symbol("INTERSECT")
+ INTERSECT_ALL = util.symbol("INTERSECT ALL")
+
+ _is_from_container = True
+
+ def __init__(self, keyword, *selects, **kwargs):
+ self._auto_correlate = kwargs.pop("correlate", False)
+ self.keyword = keyword
+ self.selects = [
+ coercions.expect(roles.CompoundElementRole, s).self_group(
+ against=self
+ )
+ for s in selects
+ ]
+
+ if kwargs and util.SQLALCHEMY_WARN_20:
+ util.warn_deprecated_20(
+ "Set functions such as union(), union_all(), extract(), etc. "
+ "in SQLAlchemy 2.0 will accept a "
+ "series of SELECT statements only. "
+ "Please use generative methods such as order_by() for "
+ "additional modifications to this CompoundSelect.",
+ stacklevel=4,
+ )
+
+ GenerativeSelect.__init__(self, **kwargs)
+
+ @classmethod
+ def _create_union(cls, *selects, **kwargs):
+ r"""Return a ``UNION`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ A similar :func:`union()` method is available on all
+ :class:`_expression.FromClause` subclasses.
+
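+ As an illustrative sketch, assuming hypothetical ``users`` and
+ ``managers`` tables with compatible columns::
+
+ from sqlalchemy import select, union
+
+ stmt = union(select(users), select(managers))
+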
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.UNION, *selects, **kwargs)
+
+ @classmethod
+ def _create_union_all(cls, *selects, **kwargs):
+ r"""Return a ``UNION ALL`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ A similar :func:`union_all()` method is available on all
+ :class:`_expression.FromClause` subclasses.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.UNION_ALL, *selects, **kwargs)
+
+ @classmethod
+ def _create_except(cls, *selects, **kwargs):
+ r"""Return an ``EXCEPT`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
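+ As an illustrative sketch, assuming the same hypothetical ``users``
+ and ``managers`` tables::
+
+ from sqlalchemy import except_, select
+
+ stmt = except_(select(users), select(managers))
+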
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.EXCEPT, *selects, **kwargs)
+
+ @classmethod
+ def _create_except_all(cls, *selects, **kwargs):
+ r"""Return an ``EXCEPT ALL`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.EXCEPT_ALL, *selects, **kwargs)
+
+ @classmethod
+ def _create_intersect(cls, *selects, **kwargs):
+ r"""Return an ``INTERSECT`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.INTERSECT, *selects, **kwargs)
+
+ @classmethod
+ def _create_intersect_all(cls, *selects, **kwargs):
+ r"""Return an ``INTERSECT ALL`` of multiple selectables.
+
+ The returned object is an instance of
+ :class:`_expression.CompoundSelect`.
+
+ :param \*selects:
+ a list of :class:`_expression.Select` instances.
+
+ :param \**kwargs:
+ available keyword arguments are the same as those of
+ :func:`select`.
+
+ """
+ return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
+
+ def _scalar_type(self):
+ return self.selects[0]._scalar_type()
+
+ def self_group(self, against=None):
+ return SelectStatementGrouping(self)
+
+ def is_derived_from(self, fromclause):
+ for s in self.selects:
+ if s.is_derived_from(fromclause):
+ return True
+ return False
+
+ def _set_label_style(self, style):
+ if self._label_style is not style:
+ self = self._generate()
+ select_0 = self.selects[0]._set_label_style(style)
+ self.selects = [select_0] + self.selects[1:]
+
+ return self
+
+ def _ensure_disambiguated_names(self):
+ new_select = self.selects[0]._ensure_disambiguated_names()
+ if new_select is not self.selects[0]:
+ self = self._generate()
+ self.selects = [new_select] + self.selects[1:]
+
+ return self
+
+ def _generate_fromclause_column_proxies(self, subquery):
+
+ # this is a slightly hacky thing - the union exports a
+ # column that resembles just that of the *first* selectable.
+ # to get at a "composite" column, particularly foreign keys,
+ # you have to dig through the proxies collection which we
+ # generate below. We may want to improve upon this, such as
+ # perhaps _make_proxy can accept a list of other columns
+ # that are "shared" - schema.column can then copy all the
+ # ForeignKeys in. This would allow the union() to have all
+ # those fks too.
+ select_0 = self.selects[0]
+
+ if self._label_style is not LABEL_STYLE_DEFAULT:
+ select_0 = select_0.set_label_style(self._label_style)
+ select_0._generate_fromclause_column_proxies(subquery)
+
+ # hand-construct the "_proxies" collection to include all
+ # derived columns; place a 'weight' annotation corresponding
+ # to how low in the list of select()s the column occurs, so
+ # that the corresponding_column() operation can resolve
+ # conflicts
+
+ for subq_col, select_cols in zip(
+ subquery.c._all_columns,
+ zip(*[s.selected_columns for s in self.selects]),
+ ):
+ subq_col._proxies = [
+ c._annotate({"weight": i + 1})
+ for (i, c) in enumerate(select_cols)
+ ]
+
+ def _refresh_for_new_column(self, column):
+ super(CompoundSelect, self)._refresh_for_new_column(column)
+ for select in self.selects:
+ select._refresh_for_new_column(column)
+
+ @property
+ def _all_selected_columns(self):
+ return self.selects[0]._all_selected_columns
+
+ @property
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set,
+ not including :class:`_sql.TextClause` constructs.
+
+ For a :class:`_expression.CompoundSelect`, the
+ :attr:`_expression.CompoundSelect.selected_columns`
+ attribute returns the selected
+ columns of the first SELECT statement contained within the series of
+ statements within the set operation.
+
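+ For example, as an illustrative sketch, assuming hypothetical
+ ``users`` and ``managers`` tables that each have a ``name`` column::
+
+ u = union(select(users), select(managers))
+ name_col = u.selected_columns.name # column of the first SELECT
+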
+ .. seealso::
+
+ :attr:`_sql.Select.selected_columns`
+
+ .. versionadded:: 1.4
+
+ """
+ return self.selects[0].selected_columns
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Returns the :class:`_engine.Engine` or :class:`_engine.Connection`
+ to which this :class:`.Executable` is bound, or None if none found.
+
+ """
+ if self._bind:
+ return self._bind
+ for s in self.selects:
+ e = s.bind
+ if e:
+ return e
+ else:
+ return None
+
+ @bind.setter
+ def bind(self, bind):
+ self._bind = bind
+
+
+class DeprecatedSelectGenerations(object):
+ """A collection of methods available on :class:`_sql.Select`, these
+ are all **deprecated** methods as they modify the :class:`_sql.Select`
+ object in -place.
+
+ """
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_correlation` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.correlate`.",
+ )
+ def append_correlation(self, fromclause):
+ """Append the given correlation expression to this select()
+ construct.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.correlate` method is preferred,
+ as it provides
+ standard :term:`method chaining`.
+
+ """
+
+ self.correlate.non_generative(self, fromclause)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_column` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.add_columns`.",
+ )
+ def append_column(self, column):
+ """Append the given column expression to the columns clause of this
+ select() construct.
+
+ E.g.::
+
+ my_select.append_column(some_table.c.new_column)
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.add_columns` method is preferred,
+ as it provides standard
+ :term:`method chaining`.
+
+ """
+ self.add_columns.non_generative(self, column)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_prefix` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.prefix_with`.",
+ )
+ def append_prefix(self, clause):
+ """Append the given columns clause prefix expression to this select()
+ construct.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.prefix_with` method is preferred,
+ as it provides
+ standard :term:`method chaining`.
+
+ """
+ self.prefix_with.non_generative(self, clause)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_whereclause` "
+ "method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.where`.",
+ )
+ def append_whereclause(self, whereclause):
+ """Append the given expression to this select() construct's WHERE
+ criterion.
+
+ The expression will be joined to existing WHERE criterion via AND.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.where` method is preferred,
+ as it provides standard
+ :term:`method chaining`.
+
+ """
+ self.where.non_generative(self, whereclause)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_having` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.having`.",
+ )
+ def append_having(self, having):
+ """Append the given expression to this select() construct's HAVING
+ criterion.
+
+ The expression will be joined to existing HAVING criterion via AND.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.having` method is preferred,
+ as it provides standard
+ :term:`method chaining`.
+
+ """
+
+ self.having.non_generative(self, having)
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.append_from` method is deprecated "
+ "and will be removed in a future release. Use the generative "
+ "method :meth:`_expression.Select.select_from`.",
+ )
+ def append_from(self, fromclause):
+ """Append the given :class:`_expression.FromClause` expression
+ to this select() construct's FROM clause.
+
+ This is an **in-place** mutation method; the
+ :meth:`_expression.Select.select_from` method is preferred,
+ as it provides
+ standard :term:`method chaining`.
+
+ """
+ self.select_from.non_generative(self, fromclause)
+
+
+@CompileState.plugin_for("default", "select")
+class SelectState(util.MemoizedSlots, CompileState):
+ __slots__ = (
+ "from_clauses",
+ "froms",
+ "columns_plus_names",
+ "_label_resolve_dict",
+ )
+
+ class default_select_compile_options(CacheableOptions):
+ _cache_key_traversal = []
+
+ def __init__(self, statement, compiler, **kw):
+ self.statement = statement
+ self.from_clauses = statement._from_obj
+
+ for memoized_entities in statement._memoized_select_entities:
+ self._setup_joins(
+ memoized_entities._setup_joins, memoized_entities._raw_columns
+ )
+
+ if statement._setup_joins:
+ self._setup_joins(statement._setup_joins, statement._raw_columns)
+
+ self.froms = self._get_froms(statement)
+
+ self.columns_plus_names = statement._generate_columns_plus_names(True)
+
+ @classmethod
+ def _plugin_not_implemented(cls):
+ raise NotImplementedError(
+ "The default SELECT construct without plugins does not "
+ "implement this method."
+ )
+
+ @classmethod
+ def get_column_descriptions(cls, statement):
+ return [
+ {
+ "name": name,
+ "type": element.type,
+ "expr": element,
+ }
+ for _, name, _, element, _ in (
+ statement._generate_columns_plus_names(False)
+ )
+ ]
+
+ @classmethod
+ def from_statement(cls, statement, from_statement):
+ cls._plugin_not_implemented()
+
+ @classmethod
+ def get_columns_clause_froms(cls, statement):
+ return cls._normalize_froms(
+ itertools.chain.from_iterable(
+ element._from_objects for element in statement._raw_columns
+ )
+ )
+
+ @classmethod
+ def _column_naming_convention(cls, label_style):
+
+ table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ dedupe = label_style is not LABEL_STYLE_NONE
+
+ pa = prefix_anon_map()
+ names = set()
+
+ def go(c, col_name=None):
+ if c._is_text_clause:
+ return None
+
+ elif not dedupe:
+ name = c._proxy_key
+ if name is None:
+ name = "_no_label"
+ return name
+
+ name = c._tq_key_label if table_qualified else c._proxy_key
+
+ if name is None:
+ name = "_no_label"
+ if name in names:
+ return c._anon_label(name) % pa
+ else:
+ names.add(name)
+ return name
+
+ elif name in names:
+ return (
+ c._anon_tq_key_label % pa
+ if table_qualified
+ else c._anon_key_label % pa
+ )
+ else:
+ names.add(name)
+ return name
+
+ return go
+
+ def _get_froms(self, statement):
+ return self._normalize_froms(
+ itertools.chain(
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._raw_columns
+ ]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ self.from_clauses,
+ ),
+ check_statement=statement,
+ )
+
+ @classmethod
+ def _normalize_froms(cls, iterable_of_froms, check_statement=None):
+ """given an iterable of things to select FROM, reduce them to what
+ would actually render in the FROM clause of a SELECT.
+
+ This does the job of checking for JOINs, tables, etc. that are in fact
+ overlapping due to cloning, adaptation, or presence in overlapping joins,
+ etc.
+
+ """
+ seen = set()
+ froms = []
+
+ for item in iterable_of_froms:
+ if item._is_subquery and item.element is check_statement:
+ raise exc.InvalidRequestError(
+ "select() construct refers to itself as a FROM"
+ )
+
+ if not seen.intersection(item._cloned_set):
+ froms.append(item)
+ seen.update(item._cloned_set)
+
+ if froms:
+ toremove = set(
+ itertools.chain.from_iterable(
+ [_expand_cloned(f._hide_froms) for f in froms]
+ )
+ )
+ if toremove:
+ # filter the FROM clauses down to those not marked
+ # for removal, using a list to maintain ordering
+ froms = [f for f in froms if f not in toremove]
+
+ return froms
+
+ def _get_display_froms(
+ self, explicit_correlate_froms=None, implicit_correlate_froms=None
+ ):
+ """Return the full list of 'from' clauses to be displayed.
+
+ Takes into account a set of existing froms which may be
+ rendered in the FROM clause of enclosing selects; this Select
+ may want to leave those absent if it is automatically
+ correlating.
+
+ """
+
+ froms = self.froms
+
+ if self.statement._correlate:
+ to_correlate = self.statement._correlate
+ if to_correlate:
+ froms = [
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(
+ _cloned_intersection(
+ froms, explicit_correlate_froms or ()
+ ),
+ to_correlate,
+ )
+ ]
+
+ if self.statement._correlate_except is not None:
+
+ froms = [
+ f
+ for f in froms
+ if f
+ not in _cloned_difference(
+ _cloned_intersection(
+ froms, explicit_correlate_froms or ()
+ ),
+ self.statement._correlate_except,
+ )
+ ]
+
+ if (
+ self.statement._auto_correlate
+ and implicit_correlate_froms
+ and len(froms) > 1
+ ):
+
+ froms = [
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(froms, implicit_correlate_froms)
+ ]
+
+ if not len(froms):
+ raise exc.InvalidRequestError(
+ "Select statement '%r"
+ "' returned no FROM clauses "
+ "due to auto-correlation; "
+ "specify correlate(<tables>) "
+ "to control correlation "
+ "manually." % self.statement
+ )
+
+ return froms
+
+ def _memoized_attr__label_resolve_dict(self):
+ with_cols = dict(
+ (c._tq_label or c.key, c)
+ for c in self.statement._all_selected_columns
+ if c._allow_label_resolve
+ )
+ only_froms = dict(
+ (c.key, c)
+ for c in _select_iterables(self.froms)
+ if c._allow_label_resolve
+ )
+ only_cols = with_cols.copy()
+ for key, value in only_froms.items():
+ with_cols.setdefault(key, value)
+
+ return with_cols, only_froms, only_cols
+
+ @classmethod
+ def determine_last_joined_entity(cls, stmt):
+ if stmt._setup_joins:
+ return stmt._setup_joins[-1][0]
+ else:
+ return None
+
+ @classmethod
+ def all_selected_columns(cls, statement):
+ return [c for c in _select_iterables(statement._raw_columns)]
+
+ def _setup_joins(self, args, raw_columns):
+ for (right, onclause, left, flags) in args:
+ isouter = flags["isouter"]
+ full = flags["full"]
+
+ if left is None:
+ (
+ left,
+ replace_from_obj_index,
+ ) = self._join_determine_implicit_left_side(
+ raw_columns, left, right, onclause
+ )
+ else:
+ (replace_from_obj_index) = self._join_place_explicit_left_side(
+ left
+ )
+
+ if replace_from_obj_index is not None:
+ # splice into an existing element in the
+ # self._from_obj list
+ left_clause = self.from_clauses[replace_from_obj_index]
+
+ self.from_clauses = (
+ self.from_clauses[:replace_from_obj_index]
+ + (
+ Join(
+ left_clause,
+ right,
+ onclause,
+ isouter=isouter,
+ full=full,
+ ),
+ )
+ + self.from_clauses[replace_from_obj_index + 1 :]
+ )
+ else:
+
+ self.from_clauses = self.from_clauses + (
+ Join(left, right, onclause, isouter=isouter, full=full),
+ )
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_determine_implicit_left_side(
+ self, raw_columns, left, right, onclause
+ ):
+ """When join conditions don't express the left side explicitly,
+ determine if an existing FROM or entity in this query
+ can serve as the left hand side.
+
+ """
+
+ sql_util = util.preloaded.sql_util
+
+ replace_from_obj_index = None
+
+ from_clauses = self.from_clauses
+
+ if from_clauses:
+
+ indexes = sql_util.find_left_clause_to_join_from(
+ from_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ replace_from_obj_index = indexes[0]
+ left = from_clauses[replace_from_obj_index]
+ else:
+ potential = {}
+ statement = self.statement
+
+ for from_clause in itertools.chain(
+ itertools.chain.from_iterable(
+ [element._from_objects for element in raw_columns]
+ ),
+ itertools.chain.from_iterable(
+ [
+ element._from_objects
+ for element in statement._where_criteria
+ ]
+ ),
+ ):
+
+ potential[from_clause] = ()
+
+ all_clauses = list(potential.keys())
+ indexes = sql_util.find_left_clause_to_join_from(
+ all_clauses, right, onclause
+ )
+
+ if len(indexes) == 1:
+ left = all_clauses[indexes[0]]
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't determine which FROM clause to join "
+ "from, there are multiple FROMS which can "
+ "join to this entity. Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already to "
+ "help resolve the ambiguity."
+ )
+ elif not indexes:
+ raise exc.InvalidRequestError(
+ "Don't know how to join to %r. "
+ "Please use the .select_from() "
+ "method to establish an explicit left side, as well as "
+ "providing an explicit ON clause if not present already to "
+ "help resolve the ambiguity." % (right,)
+ )
+ return left, replace_from_obj_index
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def _join_place_explicit_left_side(self, left):
+ replace_from_obj_index = None
+
+ sql_util = util.preloaded.sql_util
+
+ from_clauses = list(self.statement._iterate_from_elements())
+
+ if from_clauses:
+ indexes = sql_util.find_left_clause_that_matches_given(
+ self.from_clauses, left
+ )
+ else:
+ indexes = []
+
+ if len(indexes) > 1:
+ raise exc.InvalidRequestError(
+ "Can't identify which entity in which to assign the "
+ "left side of this join. Please use a more specific "
+ "ON clause."
+ )
+
+ # have an index, means the left side is already present in
+ # an existing FROM in the self._from_obj tuple
+ if indexes:
+ replace_from_obj_index = indexes[0]
+
+ # no index, means we need to add a new element to the
+ # self._from_obj tuple
+
+ return replace_from_obj_index
+
+
+class _SelectFromElements(object):
+ def _iterate_from_elements(self):
+ # note this does not include elements
+ # in _setup_joins or _legacy_setup_joins
+
+ seen = set()
+ for element in self._raw_columns:
+ for fr in element._from_objects:
+ if fr in seen:
+ continue
+ seen.add(fr)
+ yield fr
+ for element in self._where_criteria:
+ for fr in element._from_objects:
+ if fr in seen:
+ continue
+ seen.add(fr)
+ yield fr
+ for element in self._from_obj:
+ if element in seen:
+ continue
+ seen.add(element)
+ yield element
+
+
+class _MemoizedSelectEntities(
+ traversals.HasCacheKey, traversals.HasCopyInternals, visitors.Traversible
+):
+ __visit_name__ = "memoized_select_entities"
+
+ _traverse_internals = [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_with_options", InternalTraversal.dp_executable_options),
+ ]
+
+ _annotations = util.EMPTY_DICT
+
+ def _clone(self, **kw):
+ c = self.__class__.__new__(self.__class__)
+ c.__dict__ = {k: v for k, v in self.__dict__.items()}
+
+ c._is_clone_of = self.__dict__.get("_is_clone_of", self)
+ return c
+
+ @classmethod
+ def _generate_for_statement(cls, select_stmt):
+ if (
+ select_stmt._setup_joins
+ or select_stmt._legacy_setup_joins
+ or select_stmt._with_options
+ ):
+ self = _MemoizedSelectEntities()
+ self._raw_columns = select_stmt._raw_columns
+ self._setup_joins = select_stmt._setup_joins
+ self._legacy_setup_joins = select_stmt._legacy_setup_joins
+ self._with_options = select_stmt._with_options
+
+ select_stmt._memoized_select_entities += (self,)
+ select_stmt._raw_columns = (
+ select_stmt._setup_joins
+ ) = (
+ select_stmt._legacy_setup_joins
+ ) = select_stmt._with_options = ()
+
+
+class Select(
+ HasPrefixes,
+ HasSuffixes,
+ HasHints,
+ HasCompileState,
+ DeprecatedSelectGenerations,
+ _SelectFromElements,
+ GenerativeSelect,
+):
+ """Represents a ``SELECT`` statement.
+
+ The :class:`_sql.Select` object is normally constructed using the
+ :func:`_sql.select` function. See that function for details.
+
+ .. seealso::
+
+ :func:`_sql.select`
+
+ :ref:`tutorial_selecting_data` - in the 2.0 tutorial
+
+ """
+
+ __visit_name__ = "select"
+
+ _setup_joins = ()
+ _legacy_setup_joins = ()
+ _memoized_select_entities = ()
+
+ _distinct = False
+ _distinct_on = ()
+ _correlate = ()
+ _correlate_except = None
+ _where_criteria = ()
+ _having_criteria = ()
+ _from_obj = ()
+ _auto_correlate = True
+
+ _compile_options = SelectState.default_select_compile_options
+
+ _traverse_internals = (
+ [
+ ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+ (
+ "_memoized_select_entities",
+ InternalTraversal.dp_memoized_select_entities,
+ ),
+ ("_from_obj", InternalTraversal.dp_clauseelement_list),
+ ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_having_criteria", InternalTraversal.dp_clauseelement_tuple),
+ ("_order_by_clauses", InternalTraversal.dp_clauseelement_tuple),
+ ("_group_by_clauses", InternalTraversal.dp_clauseelement_tuple),
+ ("_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_legacy_setup_joins", InternalTraversal.dp_setup_join_tuple),
+ ("_correlate", InternalTraversal.dp_clauseelement_tuple),
+ ("_correlate_except", InternalTraversal.dp_clauseelement_tuple),
+ ("_limit_clause", InternalTraversal.dp_clauseelement),
+ ("_offset_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause", InternalTraversal.dp_clauseelement),
+ ("_fetch_clause_options", InternalTraversal.dp_plain_dict),
+ ("_for_update_arg", InternalTraversal.dp_clauseelement),
+ ("_distinct", InternalTraversal.dp_boolean),
+ ("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
+ ("_label_style", InternalTraversal.dp_plain_obj),
+ ]
+ + HasCTE._has_ctes_traverse_internals
+ + HasPrefixes._has_prefixes_traverse_internals
+ + HasSuffixes._has_suffixes_traverse_internals
+ + HasHints._has_hints_traverse_internals
+ + SupportsCloneAnnotations._clone_annotations_traverse_internals
+ + Executable._executable_traverse_internals
+ )
+
+ _cache_key_traversal = _traverse_internals + [
+ ("_compile_options", InternalTraversal.dp_has_cache_key)
+ ]
+
+ @classmethod
+ def _create_select_from_fromclause(cls, target, entities, *arg, **kw):
+ if arg or kw:
+ return Select.create_legacy_select(entities, *arg, **kw)
+ else:
+ return Select._create_select(*entities)
+
+ @classmethod
+ @util.deprecated(
+ "2.0",
+ "The legacy calling style of :func:`_sql.select` is deprecated and "
+ "will be removed in SQLAlchemy 2.0. Please use the new calling "
+ "style described at :func:`_sql.select`.",
+ )
+ def create_legacy_select(
+ cls,
+ columns=None,
+ whereclause=None,
+ from_obj=None,
+ distinct=False,
+ having=None,
+ correlate=True,
+ prefixes=None,
+ suffixes=None,
+ **kwargs
+ ):
+ """Construct a new :class:`_expression.Select` using the 1.x style API.
+
+ This method is called implicitly when the :func:`_expression.select`
+ construct is used and the first argument is a Python list or other
+ plain sequence object, which is taken to refer to the columns
+ collection.
+
+ .. versionchanged:: 1.4 Added the :meth:`.Select.create_legacy_select`
+ constructor which documents the calling style in use when the
+ :func:`.select` construct is invoked using 1.x-style arguments.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.select` method on any
+ :class:`_expression.FromClause`.
+
+ All arguments which accept :class:`_expression.ClauseElement` arguments
+ also accept string arguments, which will be converted as appropriate
+ into either :func:`_expression.text()` or
+ :func:`_expression.literal_column()` constructs.
+
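+ As an illustrative sketch of the 1.x calling style, assuming a
+ hypothetical ``users`` table::
+
+ stmt = select([users.c.id, users.c.name], users.c.id > 5)
+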
+ .. seealso::
+
+ :ref:`tutorial_selecting_data` - in the :ref:`unified_tutorial`
+
+ :param columns:
+ A list of :class:`_expression.ColumnElement` or
+ :class:`_expression.FromClause`
+ objects which will form the columns clause of the resulting
+ statement. For those objects that are instances of
+ :class:`_expression.FromClause` (typically :class:`_schema.Table`
+ or :class:`_expression.Alias`
+ objects), the :attr:`_expression.FromClause.c`
+ collection is extracted
+ to form a collection of :class:`_expression.ColumnElement` objects.
+
+ This parameter will also accept :class:`_expression.TextClause`
+ constructs as
+ given, as well as ORM-mapped classes.
+
+ .. note::
+
+ The :paramref:`_expression.select.columns`
+ parameter is not available
+ in the method form of :func:`_expression.select`, e.g.
+ :meth:`_expression.FromClause.select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.column`
+
+ :meth:`_expression.Select.with_only_columns`
+
+ :param whereclause:
+ A :class:`_expression.ClauseElement`
+ expression which will be used to form the
+ ``WHERE`` clause. It is typically preferable to add WHERE
+ criterion to an existing :class:`_expression.Select`
+ using method chaining
+ with :meth:`_expression.Select.where`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.where`
+
+ :param from_obj:
+ A list of :class:`_expression.ClauseElement`
+ objects which will be added to the
+ ``FROM`` clause of the resulting statement. This is equivalent
+ to calling :meth:`_expression.Select.select_from`
+ using method chaining on
+ an existing :class:`_expression.Select` object.
+
+ .. seealso::
+
+ :meth:`_expression.Select.select_from`
+ - full description of explicit
+ FROM clause specification.
+
+ :param bind=None:
+ an :class:`_engine.Engine` or :class:`_engine.Connection` instance
+ to which the
+ resulting :class:`_expression.Select` object will be bound. The
+ :class:`_expression.Select`
+ object will otherwise automatically bind to
+ whatever :class:`~.base.Connectable` instances can be located within
+ its contained :class:`_expression.ClauseElement` members.
+
+ :param correlate=True:
+ indicates that this :class:`_expression.Select`
+ object should have its
+ contained :class:`_expression.FromClause`
+ elements "correlated" to an enclosing
+ :class:`_expression.Select` object.
+ It is typically preferable to specify
+ correlations on an existing :class:`_expression.Select`
+ construct using
+ :meth:`_expression.Select.correlate`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.correlate`
+ - full description of correlation.
+
+ :param distinct=False:
+ when ``True``, applies a ``DISTINCT`` qualifier to the columns
+ clause of the resulting statement.
+
+ The boolean argument may also be a column expression or list
+ of column expressions - this is a special calling form which
+ is understood by the PostgreSQL dialect to render the
+ ``DISTINCT ON (<columns>)`` syntax.
+
+ ``distinct`` is also available on an existing
+ :class:`_expression.Select`
+ object via the :meth:`_expression.Select.distinct` method.
+
+ .. seealso::
+
+ :meth:`_expression.Select.distinct`
+
+ :param group_by:
+ a list of :class:`_expression.ClauseElement`
+ objects which will comprise the
+ ``GROUP BY`` clause of the resulting select. This parameter
+ is typically specified more naturally using the
+ :meth:`_expression.Select.group_by` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.group_by`
+
+ :param having:
+ a :class:`_expression.ClauseElement`
+ that will comprise the ``HAVING`` clause
+ of the resulting select when ``GROUP BY`` is used. This parameter
+ is typically specified more naturally using the
+ :meth:`_expression.Select.having` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.having`
+
+ :param limit=None:
+ a numerical value which usually renders as a ``LIMIT``
+ expression in the resulting select. Backends that don't
+ support ``LIMIT`` will attempt to provide similar
+ functionality. This parameter is typically specified more
+ naturally using the :meth:`_expression.Select.limit`
+ method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.limit`
+
+ :param offset=None:
+ a numeric value which usually renders as an ``OFFSET``
+ expression in the resulting select. Backends that don't
+ support ``OFFSET`` will attempt to provide similar
+ functionality. This parameter is typically specified more naturally
+ using the :meth:`_expression.Select.offset` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.offset`
+
+ :param order_by:
+ a scalar or list of :class:`_expression.ClauseElement`
+ objects which will
+ comprise the ``ORDER BY`` clause of the resulting select.
+ This parameter is typically specified more naturally using the
+ :meth:`_expression.Select.order_by` method on an existing
+ :class:`_expression.Select`.
+
+ .. seealso::
+
+ :meth:`_expression.Select.order_by`
+
+ :param use_labels=False:
+ when ``True``, the statement will be generated using labels
+ for each column in the columns clause, which qualify each
+ column with its parent table's (or alias's) name so that name
+ conflicts between columns in different tables don't occur.
+ The format of the label is ``<tablename>_<column>``. The "c"
+ collection of a :class:`_expression.Subquery` created
+ against this :class:`_expression.Select`
+ object, as well as the :attr:`_expression.Select.selected_columns`
+ collection of the :class:`_expression.Select` itself, will use these
+ names for targeting column members.
+
+ This parameter can also be specified on an existing
+ :class:`_expression.Select` object using the
+ :meth:`_expression.Select.set_label_style`
+ method.
+
+ .. seealso::
+
+ :meth:`_expression.Select.set_label_style`
+
+ """
+ self = cls.__new__(cls)
+
+ self._auto_correlate = correlate
+
+ if distinct is not False:
+ if distinct is True:
+ self.distinct.non_generative(self)
+ else:
+ self.distinct.non_generative(self, *util.to_list(distinct))
+
+ if from_obj is not None:
+ self.select_from.non_generative(self, *util.to_list(from_obj))
+
+ try:
+ cols_present = bool(columns)
+ except TypeError as err:
+ util.raise_(
+ exc.ArgumentError(
+ "select() construct created in legacy mode, i.e. with "
+ "keyword arguments, must provide the columns argument as "
+ "a Python list or other iterable.",
+ code="c9ae",
+ ),
+ from_=err,
+ )
+
+ if cols_present:
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole, c, apply_propagate_attrs=self
+ )
+ for c in columns
+ ]
+ else:
+ self._raw_columns = []
+
+ if whereclause is not None:
+ self.where.non_generative(self, whereclause)
+
+ if having is not None:
+ self.having.non_generative(self, having)
+
+ if prefixes:
+ self._setup_prefixes(prefixes)
+
+ if suffixes:
+ self._setup_suffixes(suffixes)
+
+ GenerativeSelect.__init__(self, **kwargs)
+ return self
+
+ @classmethod
+ def _create_future_select(cls, *entities):
+ r"""Construct a new :class:`_expression.Select` using the 2.
+ x style API.
+
+ .. versionadded:: 1.4 - The :func:`_sql.select` function now accepts
+ column arguments positionally. The top-level :func:`_sql.select`
+ function will automatically use the 1.x or 2.x style API based on
+ the incoming arguments; using :func:`_future.select` from the
+ ``sqlalchemy.future`` module will enforce that only the 2.x style
+ constructor is used.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.select` method on any
+ :class:`_expression.FromClause`.
+
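+ As an illustrative sketch of the 2.x calling style, assuming a
+ hypothetical ``users`` table::
+
+ stmt = select(users.c.id, users.c.name).where(users.c.id > 5)
+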
+ .. seealso::
+
+ :ref:`coretutorial_selecting` - Core Tutorial description of
+ :func:`_expression.select`.
+
+ :param \*entities:
+ Entities to SELECT from. For Core usage, this is typically a series
+ of :class:`_expression.ColumnElement` and / or
+ :class:`_expression.FromClause`
+ objects which will form the columns clause of the resulting
+ statement. For those objects that are instances of
+ :class:`_expression.FromClause` (typically :class:`_schema.Table`
+ or :class:`_expression.Alias`
+ objects), the :attr:`_expression.FromClause.c`
+ collection is extracted
+ to form a collection of :class:`_expression.ColumnElement` objects.
+
+ This parameter will also accept :class:`_expression.TextClause`
+ constructs as
+ given, as well as ORM-mapped classes.
+
+ """
+
+ self = cls.__new__(cls)
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
+ )
+ for ent in entities
+ ]
+
+ GenerativeSelect.__init__(self)
+
+ return self
+
+ _create_select = _create_future_select
+
+ @classmethod
+ def _create_raw_select(cls, **kw):
+ """Create a :class:`.Select` using raw ``__new__`` with no coercions.
+
+ Used internally to build up :class:`.Select` constructs with
+ pre-established state.
+
+ """
+
+ stmt = Select.__new__(Select)
+ stmt.__dict__.update(kw)
+ return stmt
+
+ @classmethod
+ def _create(cls, *args, **kw):
+ r"""Create a :class:`.Select` using either the 1.x or 2.0 constructor
+ style.
+
+ For the legacy calling style, see :meth:`.Select.create_legacy_select`.
+ If the first argument passed is a Python sequence or if keyword
+ arguments are present, this style is used.
+
+ .. versionadded:: 2.0 - the :func:`_future.select` construct is
+ the same construct as the one returned by
+ :func:`_expression.select`, except that the function only
+ accepts the "columns clause" entities up front; the rest of the
+ state of the SELECT should be built up using generative methods.
+
+ Similar functionality is also available via the
+ :meth:`_expression.FromClause.select` method on any
+ :class:`_expression.FromClause`.
+
+ .. seealso::
+
+ :ref:`coretutorial_selecting` - Core Tutorial description of
+ :func:`_expression.select`.
+
+ :param \*entities:
+ Entities to SELECT from. For Core usage, this is typically a series
+ of :class:`_expression.ColumnElement` and / or
+ :class:`_expression.FromClause`
+ objects which will form the columns clause of the resulting
+ statement. For those objects that are instances of
+ :class:`_expression.FromClause` (typically :class:`_schema.Table`
+ or :class:`_expression.Alias`
+ objects), the :attr:`_expression.FromClause.c`
+ collection is extracted
+ to form a collection of :class:`_expression.ColumnElement` objects.
+
+ This parameter will also accept :class:`_expression.TextClause`
+ constructs as given, as well as ORM-mapped classes.
+
+ """
+ if (
+ args
+ and (
+ isinstance(args[0], list)
+ or (
+ hasattr(args[0], "__iter__")
+ and not isinstance(
+ args[0], util.string_types + (ClauseElement,)
+ )
+ and inspect(args[0], raiseerr=False) is None
+ and not hasattr(args[0], "__clause_element__")
+ )
+ )
+ ) or kw:
+ return cls.create_legacy_select(*args, **kw)
+ else:
+ return cls._create_future_select(*args)
+
+ def __init__(self):
+ raise NotImplementedError()
+
+ def _scalar_type(self):
+ elem = self._raw_columns[0]
+ cols = list(elem._select_iterable)
+ return cols[0].type
+
+ def filter(self, *criteria):
+ """A synonym for the :meth:`_future.Select.where` method."""
+
+ return self.where(*criteria)
+
+ def _filter_by_zero(self):
+ if self._setup_joins:
+ meth = SelectState.get_plugin_class(
+ self
+ ).determine_last_joined_entity
+ _last_joined_entity = meth(self)
+ if _last_joined_entity is not None:
+ return _last_joined_entity
+
+ if self._from_obj:
+ return self._from_obj[0]
+
+ return self._raw_columns[0]
+
+ def filter_by(self, **kwargs):
+ r"""apply the given filtering criterion as a WHERE clause
+ to this select.
+
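+ For example, as an illustrative sketch, assuming a hypothetical
+ ORM-mapped ``User`` class with a ``name`` attribute::
+
+ stmt = select(User).filter_by(name="some name")
+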
+ """
+ from_entity = self._filter_by_zero()
+
+ clauses = [
+ _entity_namespace_key(from_entity, key) == value
+ for key, value in kwargs.items()
+ ]
+ return self.filter(*clauses)
+
+ @property
+ def column_descriptions(self):
+ """Return a :term:`plugin-enabled` 'column descriptions' structure
+ referring to the columns which are SELECTed by this statement.
+
+ This attribute is generally useful when using the ORM, as an
+ extended structure which includes information about mapped
+ entities is returned. The section :ref:`queryguide_inspection`
+ contains more background.
+
+ For a Core-only statement, the structure returned by this accessor
+ is derived from the same objects that are returned by the
+ :attr:`.Select.selected_columns` accessor, formatted as a list of
+ dictionaries which contain the keys ``name``, ``type`` and ``expr``,
+ which indicate the column expressions to be selected::
+
+ >>> stmt = select(user_table)
+ >>> stmt.column_descriptions
+ [
+ {
+ 'name': 'id',
+ 'type': Integer(),
+ 'expr': Column('id', Integer(), ...)},
+ {
+ 'name': 'name',
+ 'type': String(length=30),
+ 'expr': Column('name', String(length=30), ...)}
+ ]
+
+ .. versionchanged:: 1.4.33 The :attr:`.Select.column_descriptions`
+ attribute returns a structure for a Core-only set of entities,
+ not just ORM-only entities.
+
+ .. seealso::
+
+ :attr:`.UpdateBase.entity_description` - entity information for
+ an :func:`.insert`, :func:`.update`, or :func:`.delete`
+
+ :ref:`queryguide_inspection` - ORM background
+
+ """
+ meth = SelectState.get_plugin_class(self).get_column_descriptions
+ return meth(self)
+
+ def from_statement(self, statement):
+ """Apply the columns which this :class:`.Select` would select
+ onto another statement.
+
+ This operation is :term:`plugin-specific` and will raise a not
+ supported exception if this :class:`_sql.Select` does not select from
+ plugin-enabled entities.
+
+
+ The statement is typically either a :func:`_expression.text` or
+ :func:`_expression.select` construct, and should return the set of
+ columns appropriate to the entities represented by this
+ :class:`.Select`.
+
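+ For example, as an illustrative sketch, assuming a hypothetical
+ ORM-mapped ``User`` class and an imported :func:`_expression.text`
+ construct::
+
+ stmt = select(User).from_statement(text("SELECT * FROM users"))
+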
+ .. seealso::
+
+ :ref:`orm_queryguide_selecting_text` - usage examples in the
+ ORM Querying Guide
+
+ """
+ meth = SelectState.get_plugin_class(self).from_statement
+ return meth(self, statement)
+
+ @_generative
+ def join(self, target, onclause=None, isouter=False, full=False):
+ r"""Create a SQL JOIN against this :class:`_expression.Select`
+ object's criterion
+ and apply generatively, returning the newly resulting
+ :class:`_expression.Select`.
+
+ E.g.::
+
+ stmt = select(user_table).join(address_table, user_table.c.id == address_table.c.user_id)
+
+ The above statement generates SQL similar to::
+
+ SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id
+
+ .. versionchanged:: 1.4 :meth:`_expression.Select.join` now creates
+ a :class:`_sql.Join` object between a :class:`_sql.FromClause`
+ source that is within the FROM clause of the existing SELECT,
+ and a given target :class:`_sql.FromClause`, and then adds
+ this :class:`_sql.Join` to the FROM clause of the newly generated
+ SELECT statement. This is completely reworked from the behavior
+ in 1.3, which would instead create a subquery of the entire
+ :class:`_expression.Select` and then join that subquery to the
+ target.
+
+ This is a **backwards incompatible change** as the previous behavior
+ was mostly useless, producing an unnamed subquery rejected by
+ most databases in any case. The new behavior is modeled after
+ that of the very successful :meth:`_orm.Query.join` method in the
+ ORM, in order to support the functionality of :class:`_orm.Query`
+ being available by using a :class:`_sql.Select` object with an
+ :class:`_orm.Session`.
+
+ See the notes for this change at :ref:`change_select_join`.
+
+
+ :param target: target table to join towards
+
+ :param onclause: ON clause of the join. If omitted, an ON clause
+ is generated automatically based on the :class:`_schema.ForeignKey`
+ linkages between the two tables, if one can be unambiguously
+ determined, otherwise an error is raised.
+
+ :param isouter: if True, generate LEFT OUTER join. Same as
+ :meth:`_expression.Select.outerjoin`.
+
+ :param full: if True, generate FULL OUTER join.
+
+ .. seealso::
+
+ :ref:`tutorial_select_join` - in the :doc:`/tutorial/index`
+
+ :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel`
+
+ :meth:`_expression.Select.join_from`
+
+ :meth:`_expression.Select.outerjoin`
+
+ """ # noqa: E501
+ target = coercions.expect(
+ roles.JoinTargetRole, target, apply_propagate_attrs=self
+ )
+ if onclause is not None:
+ onclause = coercions.expect(roles.OnClauseRole, onclause)
+ self._setup_joins += (
+ (target, onclause, None, {"isouter": isouter, "full": full}),
+ )
+
+ def outerjoin_from(self, from_, target, onclause=None, full=False):
+ r"""Create a SQL LEFT OUTER JOIN against this
+ :class:`_expression.Select` object's criterion and apply generatively,
+ returning the newly resulting :class:`_expression.Select`.
+
+ Usage is the same as that of :meth:`_selectable.Select.join_from`.
+
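+ E.g., as an illustrative sketch using the same hypothetical
+ ``user_table`` and ``address_table`` as in the
+ :meth:`_expression.Select.join` example::
+
+ stmt = select(user_table, address_table).outerjoin_from(
+ user_table, address_table, user_table.c.id == address_table.c.user_id
+ )
+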
+ """
+ return self.join_from(
+ from_, target, onclause=onclause, isouter=True, full=full
+ )
+
+ @_generative
+ def join_from(
+ self, from_, target, onclause=None, isouter=False, full=False
+ ):
+ r"""Create a SQL JOIN against this :class:`_expression.Select`
+ object's criterion
+ and apply generatively, returning the newly resulting
+ :class:`_expression.Select`.
+
+ E.g.::
+
+ stmt = select(user_table, address_table).join_from(
+ user_table, address_table, user_table.c.id == address_table.c.user_id
+ )
+
+ The above statement generates SQL similar to::
+
+ SELECT user.id, user.name, address.id, address.email, address.user_id
+ FROM user JOIN address ON user.id = address.user_id
+
+ .. versionadded:: 1.4
+
+ :param from\_: the left side of the join, will be rendered in the
+ FROM clause and is roughly equivalent to using the
+ :meth:`.Select.select_from` method.
+
+ :param target: target table to join towards
+
+ :param onclause: ON clause of the join.
+
+ :param isouter: if True, generate LEFT OUTER join. Same as
+ :meth:`_expression.Select.outerjoin`.
+
+ :param full: if True, generate FULL OUTER join.
+
+ .. seealso::
+
+ :ref:`tutorial_select_join` - in the :doc:`/tutorial/index`
+
+ :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel`
+
+ :meth:`_expression.Select.join`
+
+ """ # noqa: E501
+
+ # note the order of parsing from vs. target is important here, as we
+ # are also deriving the source of the plugin (i.e. the subject mapper
+ # in an ORM query) which should favor the "from_" over the "target"
+
+ from_ = coercions.expect(
+ roles.FromClauseRole, from_, apply_propagate_attrs=self
+ )
+ target = coercions.expect(
+ roles.JoinTargetRole, target, apply_propagate_attrs=self
+ )
+ if onclause is not None:
+ onclause = coercions.expect(roles.OnClauseRole, onclause)
+
+ self._setup_joins += (
+ (target, onclause, from_, {"isouter": isouter, "full": full}),
+ )
+
+ def outerjoin(self, target, onclause=None, full=False):
+ """Create a left outer join.
+
+ Parameters are the same as that of :meth:`_expression.Select.join`.
+
+ .. versionchanged:: 1.4 :meth:`_expression.Select.outerjoin` now
+ creates a :class:`_sql.Join` object between a
+ :class:`_sql.FromClause` source that is within the FROM clause of
+ the existing SELECT, and a given target :class:`_sql.FromClause`,
+ and then adds this :class:`_sql.Join` to the FROM clause of the
+ newly generated SELECT statement. This is completely reworked
+ from the behavior in 1.3, which would instead create a subquery of
+ the entire
+ :class:`_expression.Select` and then join that subquery to the
+ target.
+
+ This is a **backwards incompatible change** as the previous behavior
+ was mostly useless, producing an unnamed subquery rejected by
+ most databases in any case. The new behavior is modeled after
+ that of the very successful :meth:`_orm.Query.join` method in the
+ ORM, in order to support the functionality of :class:`_orm.Query`
+ being available by using a :class:`_sql.Select` object with an
+ :class:`_orm.Session`.
+
+ See the notes for this change at :ref:`change_select_join`.
+
+ .. seealso::
+
+ :ref:`tutorial_select_join` - in the :doc:`/tutorial/index`
+
+ :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel`
+
+ :meth:`_expression.Select.join`
+
+ """
+ return self.join(target, onclause=onclause, isouter=True, full=full)
+
+ def get_final_froms(self):
+ """Compute the final displayed list of :class:`_expression.FromClause`
+ elements.
+
+ This method will run through the full computation required to
+ determine what FROM elements will be displayed in the resulting
+ SELECT statement, including shadowing individual tables with
+ JOIN objects, as well as full computation for ORM use cases including
+ eager loading clauses.
+
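+ For example, as an illustrative sketch for Core use, assuming
+ hypothetical ``user_table`` and ``address_table`` objects::
+
+ stmt = select(user_table).join(
+ address_table, user_table.c.id == address_table.c.user_id
+ )
+ froms = stmt.get_final_froms() # list containing a single Join
+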
+ For ORM use, this accessor returns the **post compilation**
+ list of FROM objects; this collection will include elements such as
+ eagerly loaded tables and joins. The objects will **not** be
+ ORM enabled and will not work as a replacement for the
+ :meth:`_sql.Select.select_froms` collection; additionally, the
+ method does not perform well for an ORM enabled statement as it
+ will incur the full ORM construction process.
+
+ To retrieve the FROM list that's implied by the "columns" collection
+ passed to the :class:`_sql.Select` originally, use the
+ :attr:`_sql.Select.columns_clause_froms` accessor.
+
+ To select from an alternative set of columns while maintaining the
+ FROM list, use the :meth:`_sql.Select.with_only_columns` method and
+ pass the
+ :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+ parameter.
+
+ .. versionadded:: 1.4.23 - the :meth:`_sql.Select.get_final_froms`
+ method replaces the previous :attr:`_sql.Select.froms` accessor,
+ which is deprecated.
+
+ .. seealso::
+
+ :attr:`_sql.Select.columns_clause_froms`
+
+ """
+ return self._compile_state_factory(self, None)._get_display_froms()
+
+ @property
+ @util.deprecated(
+ "1.4.23",
+ "The :attr:`_expression.Select.froms` attribute is moved to "
+ "the :meth:`_expression.Select.get_final_froms` method.",
+ )
+ def froms(self):
+ """Return the displayed list of :class:`_expression.FromClause`
+ elements.
+
+
+ """
+ return self.get_final_froms()
+
+ @property
+ def columns_clause_froms(self):
+ """Return the set of :class:`_expression.FromClause` objects implied
+ by the columns clause of this SELECT statement.
+
+ .. versionadded:: 1.4.23
+
+ .. seealso::
+
+ :attr:`_sql.Select.froms` - "final" FROM list taking the full
+ statement into account
+
+ :meth:`_sql.Select.with_only_columns` - makes use of this
+ collection to set up a new FROM list
+
+ """
+
+ return SelectState.get_plugin_class(self).get_columns_clause_froms(
+ self
+ )
+
+ @property
+ def inner_columns(self):
+ """An iterator of all :class:`_expression.ColumnElement`
+ expressions which would
+ be rendered into the columns clause of the resulting SELECT statement.
+
+ This method is legacy as of 1.4 and is superseded by the
+ :attr:`_expression.Select.exported_columns` collection.
+
+ """
+
+ return iter(self._all_selected_columns)
+
+ def is_derived_from(self, fromclause):
+ if self in fromclause._cloned_set:
+ return True
+
+ for f in self._iterate_from_elements():
+ if f.is_derived_from(fromclause):
+ return True
+ return False
+
+ def _copy_internals(self, clone=_clone, **kw):
+ # Select() object has been cloned and probably adapted by the
+ # given clone function. Apply the cloning function to internal
+ # objects
+
+ # 1. keep a dictionary of the froms we've cloned, and what
+ # they've become. This allows us to ensure the same cloned from
+ # is used when other items such as columns are "cloned"
+
+ all_the_froms = set(
+ itertools.chain(
+ _from_objects(*self._raw_columns),
+ _from_objects(*self._where_criteria),
+ _from_objects(*[elem[0] for elem in self._setup_joins]),
+ )
+ )
+
+ # do a clone for the froms we've gathered. what is important here
+ # is if any of the things we are selecting from, like tables,
+ # were converted into Join objects. if so, these need to be
+ # added to _from_obj explicitly, because otherwise they won't be
+ # part of the new state, as they don't associate themselves with
+ # their columns.
+ new_froms = {f: clone(f, **kw) for f in all_the_froms}
+
+ # 2. copy FROM collections, adding in joins that we've created.
+ existing_from_obj = [clone(f, **kw) for f in self._from_obj]
+ add_froms = (
+ set(f for f in new_froms.values() if isinstance(f, Join))
+ .difference(all_the_froms)
+ .difference(existing_from_obj)
+ )
+
+ self._from_obj = tuple(existing_from_obj) + tuple(add_froms)
+
+ # 3. clone everything else, making sure we use columns
+ # corresponding to the froms we just made.
+ def replace(obj, **kw):
+ if isinstance(obj, ColumnClause) and obj.table in new_froms:
+ newelem = new_froms[obj.table].corresponding_column(obj)
+ return newelem
+
+ kw["replace"] = replace
+
+ # copy everything else. for table-ish things like correlate,
+ # correlate_except, setup_joins, these clone normally. For
+ # column-expression oriented things like raw_columns, where_criteria,
+ # order by, we get this from the new froms.
+ super(Select, self)._copy_internals(
+ clone=clone, omit_attrs=("_from_obj",), **kw
+ )
+
+ self._reset_memoizations()
+
+ def get_children(self, **kwargs):
+ return itertools.chain(
+ super(Select, self).get_children(
+ omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
+ ),
+ self._iterate_from_elements(),
+ )
+
+ @_generative
+ def add_columns(self, *columns):
+ """Return a new :func:`_expression.select` construct with
+ the given column expressions added to its columns clause.
+
+ E.g.::
+
+ my_select = my_select.add_columns(table.c.new_column)
+
+ See the documentation for
+ :meth:`_expression.Select.with_only_columns`
+ for guidelines on adding / replacing the columns of a
+ :class:`_expression.Select` object.
+
+ """
+ self._reset_memoizations()
+
+ self._raw_columns = self._raw_columns + [
+ coercions.expect(
+ roles.ColumnsClauseRole, column, apply_propagate_attrs=self
+ )
+ for column in columns
+ ]
+
+ def _set_entities(self, entities):
+ self._raw_columns = [
+ coercions.expect(
+ roles.ColumnsClauseRole, ent, apply_propagate_attrs=self
+ )
+ for ent in util.to_list(entities)
+ ]
+
+ @util.deprecated(
+ "1.4",
+ "The :meth:`_expression.Select.column` method is deprecated and will "
+ "be removed in a future release. Please use "
+ ":meth:`_expression.Select.add_columns`",
+ )
+ def column(self, column):
+ """Return a new :func:`_expression.select` construct with
+ the given column expression added to its columns clause.
+
+ E.g.::
+
+ my_select = my_select.column(table.c.new_column)
+
+ See the documentation for
+ :meth:`_expression.Select.with_only_columns`
+ for guidelines on adding / replacing the columns of a
+ :class:`_expression.Select` object.
+
+ """
+ return self.add_columns(column)
+
+ @util.preload_module("sqlalchemy.sql.util")
+ def reduce_columns(self, only_synonyms=True):
+ """Return a new :func:`_expression.select` construct with redundantly
+ named, equivalently-valued columns removed from the columns clause.
+
+ "Redundant" here means two columns where one refers to the
+ other either based on foreign key, or via a simple equality
+ comparison in the WHERE clause of the statement. The primary purpose
+ of this method is to automatically construct a select statement
+ with all uniquely-named columns, without the need to use
+ table-qualified labels as
+ :meth:`_expression.Select.set_label_style`
+ does.
+
+ When columns are omitted based on foreign key, the referred-to
+ column is the one that's kept. When columns are omitted based on
+ WHERE equivalence, the first column in the columns clause is the
+ one that's kept.
+
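+ E.g., a sketch assuming hypothetical ``parent`` and ``child`` tables
+ that both have an ``id`` column::
+
+ stmt = select(parent, child).where(parent.c.id == child.c.id)
+ stmt = stmt.reduce_columns() # child.id is omitted; parent.id is kept
+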
+ :param only_synonyms: when True, limit the removal of columns
+ to those which have the same name as the equivalent. Otherwise,
+ all columns that are equivalent to another are removed.
+
+ """
+ return self.with_only_columns(
+ *util.preloaded.sql_util.reduce_columns(
+ self._all_selected_columns,
+ only_synonyms=only_synonyms,
+ *(self._where_criteria + self._from_obj)
+ )
+ )
+
+ @_generative
+ def with_only_columns(self, *columns, **kw):
+ r"""Return a new :func:`_expression.select` construct with its columns
+ clause replaced with the given columns.
+
+ By default, this method behaves exactly as if the original
+ :func:`_expression.select` had been called with the given columns
+ clause. E.g. a statement::
+
+ s = select(table1.c.a, table1.c.b)
+ s = s.with_only_columns(table1.c.b)
+
+ should be exactly equivalent to::
+
+ s = select(table1.c.b)
+
+ In this mode of operation, :meth:`_sql.Select.with_only_columns`
+ will also dynamically alter the FROM clause of the
+ statement if it is not explicitly stated.
+ To maintain the existing set of FROMs including those implied by the
+ current columns clause, add the
+ :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+ parameter::
+
+ s = select(table1.c.a, table2.c.b)
+ s = s.with_only_columns(table1.c.a, maintain_column_froms=True)
+
+ The above parameter performs a transfer of the effective FROMs
+ in the columns collection to the :meth:`_sql.Select.select_from`
+ method, as though the following were invoked::
+
+ s = select(table1.c.a, table2.c.b)
+ s = s.select_from(table1, table2).with_only_columns(table1.c.a)
+
+ The :paramref:`_sql.Select.with_only_columns.maintain_column_froms`
+ parameter makes use of the :attr:`_sql.Select.columns_clause_froms`
+ collection and performs an operation equivalent to the following::
+
+ s = select(table1.c.a, table2.c.b)
+ s = s.select_from(*s.columns_clause_froms).with_only_columns(table1.c.a)
+
+ :param \*columns: column expressions to be used.
+
+ .. versionchanged:: 1.4 the :meth:`_sql.Select.with_only_columns`
+ method accepts the list of column expressions positionally;
+ passing the expressions as a list is deprecated.
+
+ :param maintain_column_froms: boolean parameter that will ensure the
+ FROM list implied from the current columns clause will be transferred
+ to the :meth:`_sql.Select.select_from` method first.
+
+ .. versionadded:: 1.4.23
+
+ """ # noqa: E501
+
+ # memoizations should be cleared here as of
+ # I95c560ffcbfa30b26644999412fb6a385125f663 , asserting this
+ # is the case for now.
+ self._assert_no_memoizations()
+
+ maintain_column_froms = kw.pop("maintain_column_froms", False)
+ if kw:
+ raise TypeError("unknown parameters: %s" % (", ".join(kw),))
+
+ if maintain_column_froms:
+ self.select_from.non_generative(self, *self.columns_clause_froms)
+
+ # then memoize the FROMs etc.
+ _MemoizedSelectEntities._generate_for_statement(self)
+
+ self._raw_columns = [
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in coercions._expression_collection_was_a_list(
+ "columns", "Select.with_only_columns", columns
+ )
+ ]
+
+ @property
+ def whereclause(self):
+ """Return the completed WHERE clause for this
+ :class:`_expression.Select` statement.
+
+ This assembles the current collection of WHERE criteria
+ into a single :class:`_expression.BooleanClauseList` construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
+ _whereclause = whereclause
+
+ @_generative
+ def where(self, *whereclause):
+ """Return a new :func:`_expression.select` construct with
+ the given expression added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
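+ E.g., a minimal sketch with a hypothetical ``table1``::
+
+ stmt = select(table1.c.a)
+ stmt = stmt.where(table1.c.a > 5).where(table1.c.b == 10)
+ # renders roughly: WHERE a > :a_1 AND b = :b_1
+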
+ """
+
+ assert isinstance(self._where_criteria, tuple)
+
+ for criterion in whereclause:
+ where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
+ self._where_criteria += (where_criteria,)
+
+ @_generative
+ def having(self, having):
+ """Return a new :func:`_expression.select` construct with
+ the given expression added to
+ its HAVING clause, joined to the existing clause via AND, if any.
+
+ """
+ self._having_criteria += (
+ coercions.expect(roles.WhereHavingRole, having),
+ )
+
+ @_generative
+ def distinct(self, *expr):
+ r"""Return a new :func:`_expression.select` construct which
+ will apply DISTINCT to its columns clause.
+
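+ E.g., a minimal sketch with a hypothetical ``table1``::
+
+ stmt = select(table1.c.a).distinct()
+ # with the PostgreSQL dialect, renders DISTINCT ON (a):
+ pg_stmt = select(table1.c.a, table1.c.b).distinct(table1.c.a)
+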
+ :param \*expr: optional column expressions. When present,
+ the PostgreSQL dialect will render a ``DISTINCT ON (<expressions>)``
+ construct.
+
+ .. deprecated:: 1.4 Using \*expr in other dialects is deprecated
+ and will raise :class:`_exc.CompileError` in a future version.
+
+ """
+ if expr:
+ self._distinct = True
+ self._distinct_on = self._distinct_on + tuple(
+ coercions.expect(roles.ByOfRole, e) for e in expr
+ )
+ else:
+ self._distinct = True
+
+ @_generative
+ def select_from(self, *froms):
+ r"""Return a new :func:`_expression.select` construct with the
+ given FROM expression(s)
+ merged into its list of FROM objects.
+
+ E.g.::
+
+ table1 = table('t1', column('a'))
+ table2 = table('t2', column('b'))
+ s = select(table1.c.a).\
+ select_from(
+ table1.join(table2, table1.c.a==table2.c.b)
+ )
+
+ The "from" list is a unique set on the identity of each element,
+ so adding an already present :class:`_schema.Table`
+ or other selectable
+ will have no effect. Passing a :class:`_expression.Join` that refers
+ to an already present :class:`_schema.Table`
+ or other selectable will have
+ the effect of concealing the presence of that selectable as
+ an individual element in the rendered FROM list, instead
+ rendering it into a JOIN clause.
+
+ While the typical purpose of :meth:`_expression.Select.select_from`
+ is to
+ replace the default, derived FROM clause with a join, it can
+ also be called with individual table elements, multiple times
+ if desired, in the case that the FROM clause cannot be fully
+ derived from the columns clause::
+
+ select(func.count('*')).select_from(table1)
+
+ """
+
+ self._from_obj += tuple(
+ coercions.expect(
+ roles.FromClauseRole, fromclause, apply_propagate_attrs=self
+ )
+ for fromclause in froms
+ )
+
+ @_generative
+ def correlate(self, *fromclauses):
+ r"""Return a new :class:`_expression.Select`
+ which will correlate the given FROM
+ clauses to that of an enclosing :class:`_expression.Select`.
+
+ Calling this method turns off the :class:`_expression.Select` object's
+ default behavior of "auto-correlation". Normally, FROM elements
+ which appear in a :class:`_expression.Select`
+ that encloses this one via
+ its :term:`WHERE clause`, ORDER BY, HAVING or
+ :term:`columns clause` will be omitted from this
+ :class:`_expression.Select`
+ object's :term:`FROM clause`.
+ Setting an explicit correlation collection using the
+ :meth:`_expression.Select.correlate`
+ method provides a fixed list of FROM objects
+ that can potentially take place in this process.
+
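+ E.g., a correlated scalar subquery sketch, assuming hypothetical
+ ``users`` and ``addresses`` tables::
+
+ subq = (
+ select(func.count(addresses.c.id))
+ .where(addresses.c.user_id == users.c.id)
+ .correlate(users)
+ .scalar_subquery()
+ )
+ stmt = select(users.c.name, subq.label("address_count"))
+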
+ When :meth:`_expression.Select.correlate`
+ is used to apply specific FROM clauses
+ for correlation, the FROM elements become candidates for
+ correlation regardless of how deeply nested this
+ :class:`_expression.Select`
+ object is, relative to an enclosing :class:`_expression.Select`
+ which refers to
+ the same FROM object. This is in contrast to the behavior of
+ "auto-correlation" which only correlates to an immediate enclosing
+ :class:`_expression.Select`.
+ Multi-level correlation ensures that the link
+ between enclosed and enclosing :class:`_expression.Select`
+ is always via
+ at least one WHERE/ORDER BY/HAVING/columns clause in order for
+ correlation to take place.
+
+ If ``None`` is passed, the :class:`_expression.Select`
+ object will correlate
+ none of its FROM entries, and all will render unconditionally
+ in the local FROM clause.
+
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate collection.
+
+ .. seealso::
+
+ :meth:`_expression.Select.correlate_except`
+
+ :ref:`tutorial_scalar_subquery`
+
+ """
+
+ self._auto_correlate = False
+ if fromclauses and fromclauses[0] in {None, False}:
+ self._correlate = ()
+ else:
+ self._correlate = self._correlate + tuple(
+ coercions.expect(roles.FromClauseRole, f) for f in fromclauses
+ )
+
+ @_generative
+ def correlate_except(self, *fromclauses):
+ r"""Return a new :class:`_expression.Select`
+ which will omit the given FROM
+ clauses from the auto-correlation process.
+
+ Calling :meth:`_expression.Select.correlate_except` turns off the
+ :class:`_expression.Select` object's default behavior of
+ "auto-correlation" for the given FROM elements. An element
+ specified here will unconditionally appear in the FROM list, while
+ all other FROM elements remain subject to normal auto-correlation
+ behaviors.
+
+ If ``None`` is passed, the :class:`_expression.Select`
+ object will correlate
+ all of its FROM entries.
+
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate-exception collection.
+
+ .. seealso::
+
+ :meth:`_expression.Select.correlate`
+
+ :ref:`tutorial_scalar_subquery`
+
+ """
+
+ self._auto_correlate = False
+ if fromclauses and fromclauses[0] in {None, False}:
+ self._correlate_except = ()
+ else:
+ self._correlate_except = (self._correlate_except or ()) + tuple(
+ coercions.expect(roles.FromClauseRole, f) for f in fromclauses
+ )
+
+ @HasMemoized.memoized_attribute
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set,
+ not including :class:`_sql.TextClause` constructs.
+
+ This collection differs from the :attr:`_expression.FromClause.columns`
+ collection of a :class:`_expression.FromClause` in that the columns
+ within this collection cannot be directly nested inside another SELECT
+ statement; a subquery must be applied first which provides for the
+ necessary parenthesization required by SQL.
+
+ For a :func:`_expression.select` construct, the collection here is
+ exactly what would be rendered inside the "SELECT" statement, and the
+ :class:`_expression.ColumnElement` objects are directly present as they
+ were given, e.g.::
+
+ col1 = column('q', Integer)
+ col2 = column('p', Integer)
+ stmt = select(col1, col2)
+
+ Above, ``stmt.selected_columns`` would be a collection that contains
+ the ``col1`` and ``col2`` objects directly. For a statement that is
+ against a :class:`_schema.Table` or other
+ :class:`_expression.FromClause`, the collection will use the
+ :class:`_expression.ColumnElement` objects that are in the
+ :attr:`_expression.FromClause.c` collection of the from element.
+
+ .. note::
+
+ The :attr:`_sql.Select.selected_columns` collection does not
+ include expressions established in the columns clause using the
+ :func:`_sql.text` construct; these are silently omitted from the
+ collection. To use plain textual column expressions inside of a
+ :class:`_sql.Select` construct, use the :func:`_sql.literal_column`
+ construct.
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ # compare to SelectState._generate_columns_plus_names, which
+ # generates the actual names used in the SELECT string. that
+ # method is more complex because it also renders columns that are
+ # fully ambiguous, e.g. same column more than once.
+ conv = SelectState._column_naming_convention(self._label_style)
+
+ return ColumnCollection(
+ [
+ (conv(c), c)
+ for c in self._all_selected_columns
+ if not c._is_text_clause
+ ]
+ ).as_immutable()
+
+ @HasMemoized.memoized_attribute
+ def _all_selected_columns(self):
+ meth = SelectState.get_plugin_class(self).all_selected_columns
+ return list(meth(self))
+
+ def _ensure_disambiguated_names(self):
+ if self._label_style is LABEL_STYLE_NONE:
+ self = self.set_label_style(LABEL_STYLE_DISAMBIGUATE_ONLY)
+ return self
+
+ def _generate_columns_plus_names(self, anon_for_dupe_key):
+ """Generate column names as rendered in a SELECT statement by
+ the compiler.
+
+ This is distinct from the _column_naming_convention generator that's
+ intended for population of .c collections and similar, which has
+ different rules. The collection returned here calls upon the
+ _column_naming_convention as well.
+
+ """
+ cols = self._all_selected_columns
+
+ key_naming_convention = SelectState._column_naming_convention(
+ self._label_style
+ )
+
+ names = {}
+
+ result = []
+ result_append = result.append
+
+ table_qualified = self._label_style is LABEL_STYLE_TABLENAME_PLUS_COL
+ label_style_none = self._label_style is LABEL_STYLE_NONE
+
+ # a counter used for "dedupe" labels, which have double underscores
+ # in them and are never referred by name; they only act
+ # as positional placeholders. they need only be unique within
+ # the single columns clause they're rendered within (required by
+ # some dbs such as mysql). So their anon identity is tracked against
+ # a fixed counter rather than hash() identity.
+ dedupe_hash = 1
+
+ for c in cols:
+ repeated = False
+
+ if not c._render_label_in_columns_clause:
+ effective_name = (
+ required_label_name
+ ) = fallback_label_name = None
+ elif label_style_none:
+ effective_name = required_label_name = None
+ fallback_label_name = c._non_anon_label or c._anon_name_label
+ else:
+ if table_qualified:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = c._tq_label
+ else:
+ effective_name = fallback_label_name = c._non_anon_label
+ required_label_name = None
+
+ if effective_name is None:
+ # it seems like this could be _proxy_key and we would
+ # not need _expression_label but it isn't
+ # giving us a clue when to use anon_label instead
+ expr_label = c._expression_label
+ if expr_label is None:
+ repeated = c._anon_name_label in names
+ names[c._anon_name_label] = c
+ effective_name = required_label_name = None
+
+ if repeated:
+ # here, "required_label_name" is sent as
+ # "None" and "fallback_label_name" is sent.
+ if table_qualified:
+ fallback_label_name = (
+ c._dedupe_anon_tq_label_idx(dedupe_hash)
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._dedupe_anon_label_idx(
+ dedupe_hash
+ )
+ dedupe_hash += 1
+ else:
+ fallback_label_name = c._anon_name_label
+ else:
+ required_label_name = (
+ effective_name
+ ) = fallback_label_name = expr_label
+
+ if effective_name is not None:
+ if effective_name in names:
+ # when looking to see if names[name] is the same column as
+ # c, use hash(), so that an annotated version of the column
+ # is seen as the same as the non-annotated
+ if hash(names[effective_name]) != hash(c):
+
+ # different column under the same name. apply
+ # disambiguating label
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_tq_label
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._anon_name_label
+
+ if anon_for_dupe_key and required_label_name in names:
+ # here, c._anon_tq_label is definitely unique to
+ # that column identity (or annotated version), so
+ # this should always be true.
+ # this is also an infrequent codepath because
+ # you need two levels of duplication to be here
+ assert hash(names[required_label_name]) == hash(c)
+
+ # the column under the disambiguating label is
+ # already present. apply the "dedupe" label to
+ # subsequent occurrences of the column so that the
+ # original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[required_label_name] = c
+ elif anon_for_dupe_key:
+ # same column under the same name. apply the "dedupe"
+ # label so that the original stays non-ambiguous
+ if table_qualified:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_tq_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ else:
+ required_label_name = (
+ fallback_label_name
+ ) = c._dedupe_anon_label_idx(dedupe_hash)
+ dedupe_hash += 1
+ repeated = True
+ else:
+ names[effective_name] = c
+
+ result_append(
+ (
+ # string label name, if non-None, must be rendered as a
+ # label, i.e. "AS <name>"
+ required_label_name,
+ # proxy_key that is to be part of the result map for this
+ # col. this is also the key in a fromclause.c or
+ # select.selected_columns collection
+ key_naming_convention(c),
+ # name that can be used to render an "AS <name>" when
+ # we have to render a label even though
+ # required_label_name was not given
+ fallback_label_name,
+ # the ColumnElement itself
+ c,
+ # True if this is a duplicate of a previous column
+ # in the list of columns
+ repeated,
+ )
+ )
+
+ return result
+
+ def _generate_fromclause_column_proxies(self, subquery):
+ """Generate column proxies to place in the exported ``.c``
+ collection of a subquery."""
+
+ prox = [
+ c._make_proxy(
+ subquery,
+ key=proxy_key,
+ name=required_label_name,
+ name_is_truncatable=True,
+ )
+ for (
+ required_label_name,
+ proxy_key,
+ fallback_label_name,
+ c,
+ repeated,
+ ) in (self._generate_columns_plus_names(False))
+ if not c._is_text_clause
+ ]
+
+ subquery._columns._populate_separate_keys(prox)
+
+ def _needs_parens_for_grouping(self):
+ return self._has_row_limiting_clause or bool(
+ self._order_by_clause.clauses
+ )
+
+ def self_group(self, against=None):
+ """Return a 'grouping' construct as per the
+ :class:`_expression.ClauseElement` specification.
+
+ This produces an element that can be embedded in an expression. Note
+ that this method is called automatically as needed when constructing
+ expressions and should not require explicit use.
+
+ """
+ if (
+ isinstance(against, CompoundSelect)
+ and not self._needs_parens_for_grouping()
+ ):
+ return self
+ else:
+ return SelectStatementGrouping(self)
+
+ def union(self, *other, **kwargs):
+ r"""Return a SQL ``UNION`` of this select() construct against
+ the given selectables provided as positional arguments.
+
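+ E.g., a minimal sketch with hypothetical ``table1`` and ``table2``::
+
+ stmt1 = select(table1.c.a)
+ stmt2 = select(table2.c.a)
+ union_stmt = stmt1.union(stmt2)
+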
+ :param \*other: one or more elements with which to create a
+ UNION.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_union(self, *other, **kwargs)
+
+ def union_all(self, *other, **kwargs):
+ r"""Return a SQL ``UNION ALL`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create a
+ UNION ALL.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_union_all(self, *other, **kwargs)
+
+ def except_(self, *other, **kwargs):
+ r"""Return a SQL ``EXCEPT`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create an
+ EXCEPT.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_except(self, *other, **kwargs)
+
+ def except_all(self, *other, **kwargs):
+ r"""Return a SQL ``EXCEPT ALL`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create an
+ EXCEPT ALL.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_except_all(self, *other, **kwargs)
+
+ def intersect(self, *other, **kwargs):
+ r"""Return a SQL ``INTERSECT`` of this select() construct against
+ the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create an
+ INTERSECT.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_intersect(self, *other, **kwargs)
+
+ def intersect_all(self, *other, **kwargs):
+ r"""Return a SQL ``INTERSECT ALL`` of this select() construct
+ against the given selectables provided as positional arguments.
+
+ :param \*other: one or more elements with which to create an
+ INTERSECT ALL.
+
+ .. versionchanged:: 1.4.28
+
+ multiple elements are now accepted.
+
+ :param \**kwargs: keyword arguments are forwarded to the constructor
+ for the newly created :class:`_sql.CompoundSelect` object.
+
+ """
+ return CompoundSelect._create_intersect_all(self, *other, **kwargs)
+
+ @property
+ @util.deprecated_20(
+ ":attr:`.Executable.bind`",
+ alternative="Bound metadata is being removed as of SQLAlchemy 2.0.",
+ enable_warnings=False,
+ )
+ def bind(self):
+ """Returns the :class:`_engine.Engine` or :class:`_engine.Connection`
+ to which this :class:`.Executable` is bound, or None if none found.
+
+ """
+ if self._bind:
+ return self._bind
+
+ for item in self._iterate_from_elements():
+ if item._is_subquery and item.element is self:
+ raise exc.InvalidRequestError(
+ "select() construct refers to itself as a FROM"
+ )
+
+ e = item.bind
+ if e:
+ self._bind = e
+ return e
+ else:
+ break
+
+ for c in self._raw_columns:
+ e = c.bind
+ if e:
+ self._bind = e
+ return e
+
+ @bind.setter
+ def bind(self, bind):
+ self._bind = bind
+
+
+class ScalarSelect(roles.InElementRole, Generative, Grouping):
+ """Represent a scalar subquery.
+
+
+ A :class:`_sql.ScalarSelect` is created by invoking the
+ :meth:`_sql.SelectBase.scalar_subquery` method. The object
+ then participates in other SQL expressions as a SQL column expression
+ within the :class:`_sql.ColumnElement` hierarchy.
+
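+ E.g., a minimal sketch with a hypothetical ``table1``::
+
+ scalar_subq = select(func.max(table1.c.b)).scalar_subquery()
+ stmt = select(table1.c.a, scalar_subq)
+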
+ .. seealso::
+
+ :meth:`_sql.SelectBase.scalar_subquery`
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+ """
+
+ _from_objects = []
+ _is_from_container = True
+ _is_implicitly_boolean = False
+ inherit_cache = True
+
+ def __init__(self, element):
+ self.element = element
+ self.type = element._scalar_type()
+
+ @property
+ def columns(self):
+ raise exc.InvalidRequestError(
+ "Scalar Select expression has no "
+ "columns; use this object directly "
+ "within a column-level expression."
+ )
+
+ c = columns
+
+ @_generative
+ def where(self, crit):
+ """Apply a WHERE clause to the SELECT statement referred to
+ by this :class:`_expression.ScalarSelect`.
+
+ """
+ self.element = self.element.where(crit)
+
+ def self_group(self, **kwargs):
+ return self
+
+ @_generative
+ def correlate(self, *fromclauses):
+ r"""Return a new :class:`_expression.ScalarSelect`
+ which will correlate the given FROM
+ clauses to that of an enclosing :class:`_expression.Select`.
+
+ This method is mirrored from the :meth:`_sql.Select.correlate` method
+ of the underlying :class:`_sql.Select`. The method applies the
+ :meth:`_sql.Select.correlate` method, then returns a new
+ :class:`_sql.ScalarSelect` against that statement.
+
+ .. versionadded:: 1.4 Previously, the
+ :meth:`_sql.ScalarSelect.correlate`
+ method was only available from :class:`_sql.Select`.
+
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate collection.
+
+ .. seealso::
+
+ :meth:`_expression.ScalarSelect.correlate_except`
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+
+ """
+ self.element = self.element.correlate(*fromclauses)
+
+ @_generative
+ def correlate_except(self, *fromclauses):
+ r"""Return a new :class:`_expression.ScalarSelect`
+ which will omit the given FROM
+ clauses from the auto-correlation process.
+
+ This method is mirrored from the
+ :meth:`_sql.Select.correlate_except` method of the underlying
+ :class:`_sql.Select`. The method applies the
+ :meth:`_sql.Select.correlate_except` method, then returns a new
+ :class:`_sql.ScalarSelect` against that statement.
+
+ .. versionadded:: 1.4 Previously, the
+ :meth:`_sql.ScalarSelect.correlate_except`
+ method was only available from :class:`_sql.Select`.
+
+ :param \*fromclauses: a list of one or more
+ :class:`_expression.FromClause`
+ constructs, or other compatible constructs (i.e. ORM-mapped
+ classes) to become part of the correlate-exception collection.
+
+ .. seealso::
+
+ :meth:`_expression.ScalarSelect.correlate`
+
+ :ref:`tutorial_scalar_subquery` - in the 2.0 tutorial
+
+
+ """
+
+ self.element = self.element.correlate_except(*fromclauses)
+
+
+class Exists(UnaryExpression):
+ """Represent an ``EXISTS`` clause.
+
+ See :func:`_sql.exists` for a description of usage.
+
+ An ``EXISTS`` clause can also be constructed from a :func:`_sql.select`
+ instance by calling :meth:`_sql.SelectBase.exists`.
+
+ """
+
+ _from_objects = []
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ """Construct a new :class:`_expression.Exists` construct.
+
+ The :func:`_sql.exists` can be invoked by itself to produce an
+ :class:`_sql.Exists` construct, which will accept simple WHERE
+ criteria::
+
+ exists_criteria = exists().where(table1.c.col1 == table2.c.col2)
+
+ However, for greater flexibility in constructing the SELECT, an
+ existing :class:`_sql.Select` construct may be converted to an
+ :class:`_sql.Exists`, most conveniently by making use of the
+ :meth:`_sql.SelectBase.exists` method::
+
+ exists_criteria = (
+ select(table2.c.col2).
+ where(table1.c.col1 == table2.c.col2).
+ exists()
+ )
+
+ The EXISTS criteria is then used inside of an enclosing SELECT::
+
+ stmt = select(table1.c.col1).where(exists_criteria)
+
+ The above statement will then be of the form::
+
+ SELECT col1 FROM table1 WHERE EXISTS
+ (SELECT table2.col2 FROM table2 WHERE table2.col2 = table1.col1)
+
+ .. seealso::
+
+ :ref:`tutorial_exists` - in the :term:`2.0 style` tutorial.
+
+ :meth:`_sql.SelectBase.exists` - method to transform a ``SELECT`` to an
+ ``EXISTS`` clause.
+
+ """ # noqa: E501
+ if args and isinstance(args[0], (SelectBase, ScalarSelect)):
+ s = args[0]
+ else:
+ if not args:
+ args = (literal_column("*"),)
+ s = Select._create(*args, **kwargs).scalar_subquery()
+
+ UnaryExpression.__init__(
+ self,
+ s,
+ operator=operators.exists,
+ type_=type_api.BOOLEANTYPE,
+ wraps_column_expression=True,
+ )
+
+ def _regroup(self, fn):
+ element = self.element._ungroup()
+ element = fn(element)
+ return element.self_group(against=operators.exists)
+
+ @util.deprecated_params(
+ whereclause=(
+ "2.0",
+ "The :paramref:`_sql.Exists.select().whereclause` parameter "
+ "is deprecated and will be removed in version 2.0. "
+ "Please make use "
+ "of the :meth:`.Select.where` "
+ "method to add WHERE criteria to the SELECT statement.",
+ ),
+ kwargs=(
+ "2.0",
+ "The :meth:`_sql.Exists.select` method will no longer accept "
+ "keyword arguments in version 2.0. "
+ "Please use generative methods from the "
+ ":class:`_sql.Select` construct in order to apply additional "
+ "modifications.",
+ ),
+ )
+ def select(self, whereclause=None, **kwargs):
+ r"""Return a SELECT of this :class:`_expression.Exists`.
+
+ e.g.::
+
+ stmt = exists(some_table.c.id).where(some_table.c.id == 5).select()
+
+ This will produce a statement resembling::
+
+ SELECT EXISTS (SELECT id FROM some_table WHERE some_table.id = :param) AS anon_1
+
+ :param whereclause: a WHERE clause, equivalent to calling the
+ :meth:`_sql.Select.where` method.
+
+ :param \**kwargs: additional keyword arguments are passed to the
+ legacy constructor for :class:`_sql.Select` described at
+ :meth:`_sql.Select.create_legacy_select`.
+
+ .. seealso::
+
+ :func:`_expression.select` - general purpose
+ method which allows for arbitrary column lists.
+
+ """ # noqa
+
+ if whereclause is not None:
+ kwargs["whereclause"] = whereclause
+ return Select._create_select_from_fromclause(self, [self], **kwargs)
+
+ def correlate(self, *fromclause):
+ """Apply correlation to the subquery noted by this
+ :class:`_sql.Exists`.
+
+ .. seealso::
+
+ :meth:`_sql.ScalarSelect.correlate`
+
+ """
+ e = self._clone()
+ e.element = self._regroup(
+ lambda element: element.correlate(*fromclause)
+ )
+ return e
+
+ def correlate_except(self, *fromclause):
+ """Apply correlation to the subquery noted by this
+ :class:`_sql.Exists`.
+
+ .. seealso::
+
+ :meth:`_sql.ScalarSelect.correlate_except`
+
+ """
+
+ e = self._clone()
+ e.element = self._regroup(
+ lambda element: element.correlate_except(*fromclause)
+ )
+ return e
+
+ def select_from(self, *froms):
+ """Return a new :class:`_expression.Exists` construct,
+ applying the given
+ expression to the :meth:`_expression.Select.select_from`
+ method of the select
+ statement contained.
+
+ .. note:: it is typically preferable to build a :class:`_sql.Select`
+ statement first, including the desired WHERE clause, then use the
+ :meth:`_sql.SelectBase.exists` method to produce an
+ :class:`_sql.Exists` object at once.
+
+ """
+ e = self._clone()
+ e.element = self._regroup(lambda element: element.select_from(*froms))
+ return e
+
+ def where(self, *clause):
+ """Return a new :func:`_expression.exists` construct with the
+ given expression added to
+ its WHERE clause, joined to the existing clause via AND, if any.
+
+
+ .. note:: it is typically preferable to build a :class:`_sql.Select`
+ statement first, including the desired WHERE clause, then use the
+ :meth:`_sql.SelectBase.exists` method to produce an
+ :class:`_sql.Exists` object at once.
+
+ """
+ e = self._clone()
+ e.element = self._regroup(lambda element: element.where(*clause))
+ return e
+
+
+class TextualSelect(SelectBase):
+ """Wrap a :class:`_expression.TextClause` construct within a
+ :class:`_expression.SelectBase`
+ interface.
+
+ This allows the :class:`_expression.TextClause` object to gain a
+ ``.c`` collection
+ and other FROM-like capabilities such as
+ :meth:`_expression.FromClause.alias`,
+ :meth:`_expression.SelectBase.cte`, etc.
+
+ The :class:`_expression.TextualSelect` construct is produced via the
+ :meth:`_expression.TextClause.columns`
+ method - see that method for details.
+
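+ E.g., a minimal sketch (the table and column names are hypothetical)::
+
+ from sqlalchemy import column, text, Integer
+
+ stmt = text("SELECT id, name FROM user_table").columns(
+ column("id", Integer), column("name")
+ )
+ subq = stmt.subquery() # now provides a .c collection
+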
+ .. versionchanged:: 1.4 the :class:`_expression.TextualSelect`
+ class was renamed
+ from ``TextAsFrom``, to more correctly suit its role as a
+ SELECT-oriented object and not a FROM clause.
+
+ .. seealso::
+
+ :func:`_expression.text`
+
+ :meth:`_expression.TextClause.columns` - primary creation interface.
+
+ """
+
+ __visit_name__ = "textual_select"
+
+ _label_style = LABEL_STYLE_NONE
+
+ _traverse_internals = [
+ ("element", InternalTraversal.dp_clauseelement),
+ ("column_args", InternalTraversal.dp_clauseelement_list),
+ ] + SupportsCloneAnnotations._clone_annotations_traverse_internals
+
+ _is_textual = True
+
+ is_text = True
+ is_select = True
+
+ def __init__(self, text, columns, positional=False):
+ self.element = text
+ # convert for ORM attributes->columns, etc
+ self.column_args = [
+ coercions.expect(roles.ColumnsClauseRole, c) for c in columns
+ ]
+ self.positional = positional
+
+ @HasMemoized.memoized_attribute
+ def selected_columns(self):
+ """A :class:`_expression.ColumnCollection`
+ representing the columns that
+ this SELECT statement or similar construct returns in its result set,
+ not including :class:`_sql.TextClause` constructs.
+
+ This collection differs from the :attr:`_expression.FromClause.columns`
+ collection of a :class:`_expression.FromClause` in that the columns
+ within this collection cannot be directly nested inside another SELECT
+ statement; a subquery must be applied first which provides for the
+ necessary parenthesization required by SQL.
+
+ For a :class:`_expression.TextualSelect` construct, the collection
+ contains the :class:`_expression.ColumnElement` objects that were
+ passed to the constructor, typically via the
+ :meth:`_expression.TextClause.columns` method.
+
+
+ .. versionadded:: 1.4
+
+ """
+ return ColumnCollection(
+ (c.key, c) for c in self.column_args
+ ).as_immutable()
+
+ @property
+ def _all_selected_columns(self):
+ return self.column_args
+
+ def _set_label_style(self, style):
+ return self
+
+ def _ensure_disambiguated_names(self):
+ return self
+
+ @property
+ def _bind(self):
+ return self.element._bind
+
+ @_generative
+ def bindparams(self, *binds, **bind_as_values):
+ self.element = self.element.bindparams(*binds, **bind_as_values)
+
+ def _generate_fromclause_column_proxies(self, fromclause):
+ fromclause._columns._populate_separate_keys(
+ c._make_proxy(fromclause) for c in self.column_args
+ )
+
+ def _scalar_type(self):
+ return self.column_args[0].type
+
+
+TextAsFrom = TextualSelect
+"""Backwards compatibility with the previous name"""
+
+
+class AnnotatedFromClause(Annotated):
+ def __init__(self, element, values):
+ # force FromClause to generate their internal
+ # collections into __dict__
+ element.c
+ Annotated.__init__(self, element, values)
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
new file mode 100644
index 0000000..322bfec
--- /dev/null
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -0,0 +1,3351 @@
+# sql/sqltypes.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""SQL specific types.
+
+"""
+
+import codecs
+import datetime as dt
+import decimal
+import json
+
+from . import coercions
+from . import elements
+from . import operators
+from . import roles
+from . import type_api
+from .base import _bind_or_error
+from .base import NO_ARG
+from .base import SchemaEventTarget
+from .elements import _NONE_NAME
+from .elements import quoted_name
+from .elements import Slice
+from .elements import TypeCoerce as type_coerce # noqa
+from .traversals import HasCacheKey
+from .traversals import InternalTraversal
+from .type_api import Emulated
+from .type_api import NativeForEmulated # noqa
+from .type_api import to_instance
+from .type_api import TypeDecorator
+from .type_api import TypeEngine
+from .type_api import Variant
+from .. import event
+from .. import exc
+from .. import inspection
+from .. import processors
+from .. import util
+from ..util import compat
+from ..util import langhelpers
+from ..util import OrderedDict
+from ..util import pickle
+
+
+class _LookupExpressionAdapter(object):
+
+ """Mixin expression adaptations based on lookup tables.
+
+ These rules are currently used by the numeric, integer and date types
+ which have detailed cross-expression coercion rules.
+
+ """
+
+ @property
+ def _expression_adaptations(self):
+ raise NotImplementedError()
+
+ class Comparator(TypeEngine.Comparator):
+ _blank_dict = util.immutabledict()
+
+ def _adapt_expression(self, op, other_comparator):
+ othertype = other_comparator.type._type_affinity
+ lookup = self.type._expression_adaptations.get(
+ op, self._blank_dict
+ ).get(othertype, self.type)
+ if lookup is othertype:
+ return (op, other_comparator.type)
+ elif lookup is self.type._type_affinity:
+ return (op, self.type)
+ else:
+ return (op, to_instance(lookup))
+
+ comparator_factory = Comparator
+
+
+class Concatenable(object):
+
+ """A mixin that marks a type as supporting 'concatenation',
+ typically strings."""
+
+ class Comparator(TypeEngine.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ if op is operators.add and isinstance(
+ other_comparator,
+ (Concatenable.Comparator, NullType.Comparator),
+ ):
+ return operators.concat_op, self.expr.type
+ else:
+ return super(Concatenable.Comparator, self)._adapt_expression(
+ op, other_comparator
+ )
+
+ comparator_factory = Comparator
+
+
+class Indexable(object):
+ """A mixin that marks a type as supporting indexing operations,
+ such as array or JSON structures.
+
+
+ .. versionadded:: 1.1.0
+
+
+ """
+
+ class Comparator(TypeEngine.Comparator):
+ def _setup_getitem(self, index):
+ raise NotImplementedError()
+
+ def __getitem__(self, index):
+ (
+ adjusted_op,
+ adjusted_right_expr,
+ result_type,
+ ) = self._setup_getitem(index)
+ return self.operate(
+ adjusted_op, adjusted_right_expr, result_type=result_type
+ )
+
+ comparator_factory = Comparator
+
+
+class String(Concatenable, TypeEngine):
+
+ """The base for all string and character types.
+
+ In SQL, corresponds to VARCHAR. Can also take Python unicode objects
+ and encode to the database's encoding in bind params (and the reverse for
+ result sets.)
+
+ The `length` field is usually required when the `String` type is
+ used within a CREATE TABLE statement, as VARCHAR requires a length
+ on most databases.
+
+ """
+
+ __visit_name__ = "string"
+
+ RETURNS_UNICODE = util.symbol(
+ "RETURNS_UNICODE",
+ """Indicates that the DBAPI returns Python Unicode for VARCHAR,
+ NVARCHAR, and other character-based datatypes in all cases.
+
+ This is the default value for
+ :attr:`.DefaultDialect.returns_unicode_strings` under Python 3.
+
+ .. versionadded:: 1.4
+
+ """,
+ )
+
+ RETURNS_BYTES = util.symbol(
+ "RETURNS_BYTES",
+ """Indicates that the DBAPI returns byte objects under Python 3
+ or non-Unicode string objects under Python 2 for VARCHAR, NVARCHAR,
+ and other character-based datatypes in all cases.
+
+ This may be applied to the
+ :attr:`.DefaultDialect.returns_unicode_strings` attribute.
+
+ .. versionadded:: 1.4
+
+ """,
+ )
+
+ RETURNS_CONDITIONAL = util.symbol(
+ "RETURNS_CONDITIONAL",
+ """Indicates that the DBAPI may return Unicode or bytestrings for
+ VARCHAR, NVARCHAR, and other character-based datatypes, and that
+ SQLAlchemy's default String datatype will need to test on a per-row
+ basis for Unicode or bytes.
+
+ This may be applied to the
+ :attr:`.DefaultDialect.returns_unicode_strings` attribute.
+
+ .. versionadded:: 1.4
+
+ """,
+ )
+
+ RETURNS_UNKNOWN = util.symbol(
+ "RETURNS_UNKNOWN",
+ """Indicates that the dialect should test on first connect what the
+ string-returning behavior of character-based datatypes is.
+
+ This is the default value for DefaultDialect.unicode_returns under
+ Python 2.
+
+ This may be applied to the
+ :attr:`.DefaultDialect.returns_unicode_strings` attribute under
+ Python 2 only. The value is disallowed under Python 3.
+
+ .. versionadded:: 1.4
+
+ .. deprecated:: 1.4 This value will be removed in SQLAlchemy 2.0.
+
+ """,
+ )
+
+ @util.deprecated_params(
+ convert_unicode=(
+ "1.3",
+ "The :paramref:`.String.convert_unicode` parameter is deprecated "
+ "and will be removed in a future release. All modern DBAPIs "
+ "now support Python Unicode directly and this parameter is "
+ "unnecessary.",
+ ),
+ unicode_error=(
+ "1.3",
+ "The :paramref:`.String.unicode_errors` parameter is deprecated "
+ "and will be removed in a future release. This parameter is "
+ "unnecessary for modern Python DBAPIs and degrades performance "
+ "significantly.",
+ ),
+ )
+ def __init__(
+ self,
+ length=None,
+ collation=None,
+ convert_unicode=False,
+ unicode_error=None,
+ _warn_on_bytestring=False,
+ _expect_unicode=False,
+ ):
+ """
+ Create a string-holding type.
+
+ :param length: optional, a length for the column for use in
+ DDL and CAST expressions. May be safely omitted if no ``CREATE
+ TABLE`` will be issued. Certain databases may require a
+ ``length`` for use in DDL, and will raise an exception when
+ the ``CREATE TABLE`` DDL is issued if a ``VARCHAR``
+ with no length is included. Whether the value is
+ interpreted as bytes or characters is database specific.
+
+ :param collation: Optional, a column-level collation for
+ use in DDL and CAST expressions. Renders using the
+ COLLATE keyword supported by SQLite, MySQL, and PostgreSQL.
+ E.g.::
+
+ >>> from sqlalchemy import cast, select, String
+ >>> print(select(cast('some string', String(collation='utf8'))))
+ SELECT CAST(:param_1 AS VARCHAR COLLATE utf8) AS anon_1
+
+ :param convert_unicode: When set to ``True``, the
+ :class:`.String` type will assume that
+ input is to be passed as Python Unicode objects under Python 2,
+ and results returned as Python Unicode objects.
+ In the rare circumstance that the DBAPI does not support
+ Python unicode under Python 2, SQLAlchemy will use its own
+ encoder/decoder functionality on strings, referring to the
+ value of the :paramref:`_sa.create_engine.encoding` parameter
+ passed to :func:`_sa.create_engine` as the encoding.
+
+ For the extremely rare case that Python Unicode
+ is to be encoded/decoded by SQLAlchemy on a backend
+ that *does* natively support Python Unicode,
+ the string value ``"force"`` can be passed here which will
+ cause SQLAlchemy's encode/decode services to be
+ used unconditionally.
+
+ .. note::
+
+ SQLAlchemy's unicode-conversion flags and features only apply
+ to Python 2; in Python 3, all string objects are Unicode objects.
+ For this reason, as well as the fact that virtually all modern
+ DBAPIs now support Unicode natively even under Python 2,
+ the :paramref:`.String.convert_unicode` flag is inherently a
+ legacy feature.
+
+ .. note::
+
+ In the vast majority of cases, the :class:`.Unicode` or
+ :class:`.UnicodeText` datatypes should be used for a
+ :class:`_schema.Column` that expects to store non-ascii data.
+ These
+ datatypes will ensure that the correct types are used on the
+ database side as well as set up the correct Unicode behaviors
+ under Python 2.
+
+ .. seealso::
+
+ :paramref:`_sa.create_engine.convert_unicode` -
+ :class:`_engine.Engine`-wide parameter
+
+ :param unicode_error: Optional, a method to use to handle Unicode
+ conversion errors. Behaves like the ``errors`` keyword argument to
+ the standard library's ``string.decode()`` functions, requires
+ that :paramref:`.String.convert_unicode` is set to
+ ``"force"``
+
+ """
+ if unicode_error is not None and convert_unicode != "force":
+ raise exc.ArgumentError(
+ "convert_unicode must be 'force' " "when unicode_error is set."
+ )
+
+ self.length = length
+ self.collation = collation
+ self._expect_unicode = convert_unicode or _expect_unicode
+ self._expect_unicode_error = unicode_error
+
+ self._warn_on_bytestring = _warn_on_bytestring
+
+ def literal_processor(self, dialect):
+ def process(value):
+ value = value.replace("'", "''")
+
+ if dialect.identifier_preparer._double_percents:
+ value = value.replace("%", "%%")
+
+ return "'%s'" % value
+
+ return process
+
+ def bind_processor(self, dialect):
+ if self._expect_unicode or dialect.convert_unicode:
+ if (
+ dialect.supports_unicode_binds
+ and self._expect_unicode != "force"
+ ):
+ if self._warn_on_bytestring:
+
+ def process(value):
+ if isinstance(value, util.binary_type):
+ util.warn_limited(
+ "Unicode type received non-unicode "
+ "bind param value %r.",
+ (util.ellipses_string(value),),
+ )
+ return value
+
+ return process
+ else:
+ return None
+ else:
+ encoder = codecs.getencoder(dialect.encoding)
+ warn_on_bytestring = self._warn_on_bytestring
+
+ def process(value):
+ if isinstance(value, util.text_type):
+ return encoder(value, self._expect_unicode_error)[0]
+ elif warn_on_bytestring and value is not None:
+ util.warn_limited(
+ "Unicode type received non-unicode bind "
+ "param value %r.",
+ (util.ellipses_string(value),),
+ )
+ return value
+
+ return process
+ else:
+ return None
+
+ def result_processor(self, dialect, coltype):
+ wants_unicode = self._expect_unicode or dialect.convert_unicode
+ needs_convert = wants_unicode and (
+ dialect.returns_unicode_strings is not String.RETURNS_UNICODE
+ or self._expect_unicode in ("force", "force_nocheck")
+ )
+ needs_isinstance = (
+ needs_convert
+ and dialect.returns_unicode_strings
+ in (
+ String.RETURNS_CONDITIONAL,
+ String.RETURNS_UNICODE,
+ )
+ and self._expect_unicode != "force_nocheck"
+ )
+ if needs_convert:
+ if needs_isinstance:
+ return processors.to_conditional_unicode_processor_factory(
+ dialect.encoding, self._expect_unicode_error
+ )
+ else:
+ return processors.to_unicode_processor_factory(
+ dialect.encoding, self._expect_unicode_error
+ )
+ else:
+ return None
+
+ @property
+ def python_type(self):
+ if self._expect_unicode:
+ return util.text_type
+ else:
+ return str
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.STRING
+
+ @classmethod
+ def _warn_deprecated_unicode(cls):
+ util.warn_deprecated(
+ "The convert_unicode on Engine and String as well as the "
+ "unicode_error flag on String are deprecated. All modern "
+ "DBAPIs now support Python Unicode natively under Python 2, and "
+ "under Python 3 all strings are inherently Unicode. These flags "
+ "will be removed in a future release.",
+ version="1.3",
+ )
+
+
+class Text(String):
+
+ """A variably sized string type.
+
+ In SQL, usually corresponds to CLOB or TEXT. Can also take Python
+ unicode objects and encode to the database's encoding in bind
+ params (and the reverse for result sets.) In general, TEXT objects
+ do not have a length; while some databases will accept a length
+ argument here, it will be rejected by others.
+
+ """
+
+ __visit_name__ = "text"
+
+
+class Unicode(String):
+
+ """A variable length Unicode string type.
+
+ The :class:`.Unicode` type is a :class:`.String` subclass that assumes
+ input and output strings that may contain non-ASCII characters, and for
+ some backends implies an underlying column type that explicitly
+ supports non-ASCII data, such as ``NVARCHAR`` on Oracle and SQL
+ Server. This will impact the output of ``CREATE TABLE`` statements and
+ ``CAST`` functions at the dialect level, and also in some cases will
+ indicate different behavior in the DBAPI itself in how it handles bound
+ parameters.
+
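+ E.g., a minimal usage sketch (``name`` is a hypothetical column)::
+
+ from sqlalchemy import Column, Unicode
+
+ name = Column("name", Unicode(255))
+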
+ The character encoding used by the :class:`.Unicode` type that is used to
+ transmit and receive data to the database is usually determined by the
+ DBAPI itself. All modern DBAPIs accommodate non-ASCII strings but may have
+ different methods of managing database encodings; if necessary, this
+ encoding should be configured as detailed in the notes for the target DBAPI
+ in the :ref:`dialect_toplevel` section.
+
+ In modern SQLAlchemy, use of the :class:`.Unicode` datatype does not
+ typically imply any encoding/decoding behavior within SQLAlchemy itself.
+ Historically, when DBAPIs did not support Python ``unicode`` objects under
+ Python 2, SQLAlchemy handled unicode encoding/decoding services itself
+ which would be controlled by the flag :paramref:`.String.convert_unicode`;
+ this flag is deprecated as it is no longer needed for Python 3.
+
+ When using Python 2, data that is passed to columns that use the
+ :class:`.Unicode` datatype must be of type ``unicode``, and not ``str``
+ which in Python 2 is equivalent to ``bytes``. In Python 3, all data
+ passed to columns that use the :class:`.Unicode` datatype should be
+ of type ``str``. See the flag :paramref:`.String.convert_unicode` for
+ more discussion of unicode encode/decode behavior under Python 2.
+
+ .. warning:: Some database backends, particularly SQL Server with pyodbc,
+ are known to have undesirable behaviors regarding data that is noted
+ as being of ``NVARCHAR`` type as opposed to ``VARCHAR``, including
+ datatype mismatch errors and non-use of indexes. See the section
+ on :meth:`.DialectEvents.do_setinputsizes` for background on working
+ around unicode character issues for backends like SQL Server with
+ pyodbc as well as cx_Oracle.
+
+ .. seealso::
+
+ :class:`.UnicodeText` - unlengthed textual counterpart
+ to :class:`.Unicode`.
+
+ :paramref:`.String.convert_unicode`
+
+ :meth:`.DialectEvents.do_setinputsizes`
+
+
+ """
+
+ __visit_name__ = "unicode"
+
+ def __init__(self, length=None, **kwargs):
+ """
+ Create a :class:`.Unicode` object.
+
+ Parameters are the same as that of :class:`.String`,
+ with the exception that ``convert_unicode``
+ defaults to ``True``.
+
+ """
+ kwargs.setdefault("_expect_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
+ super(Unicode, self).__init__(length=length, **kwargs)
+
+
+class UnicodeText(Text):
+
+ """An unbounded-length Unicode string type.
+
+ See :class:`.Unicode` for details on the unicode
+ behavior of this object.
+
+ Like :class:`.Unicode`, usage of the :class:`.UnicodeText` type implies a
+ unicode-capable type being used on the backend, such as
+ ``NCLOB``, ``NTEXT``.
+
+ """
+
+ __visit_name__ = "unicode_text"
+
+ def __init__(self, length=None, **kwargs):
+ """
+ Create a Unicode-converting Text type.
+
+ Parameters are the same as that of :class:`_expression.TextClause`,
+ with the exception that ``convert_unicode``
+ defaults to ``True``.
+
+ """
+ kwargs.setdefault("_expect_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
+ super(UnicodeText, self).__init__(length=length, **kwargs)
+
+ def _warn_deprecated_unicode(self):
+ pass
+
+
+class Integer(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``int`` integers."""
+
+ __visit_name__ = "integer"
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ @property
+ def python_type(self):
+ return int
+
+ def literal_processor(self, dialect):
+ def process(value):
+ return str(int(value))
+
+ return process
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # TODO: need a dictionary object that will
+ # handle operators generically here, this is incomplete
+ return {
+ operators.add: {
+ Date: Date,
+ Integer: self.__class__,
+ Numeric: Numeric,
+ },
+ operators.mul: {
+ Interval: Interval,
+ Integer: self.__class__,
+ Numeric: Numeric,
+ },
+ operators.div: {Integer: self.__class__, Numeric: Numeric},
+ operators.truediv: {Integer: self.__class__, Numeric: Numeric},
+ operators.sub: {Integer: self.__class__, Numeric: Numeric},
+ }
+
+
+class SmallInteger(Integer):
+
+ """A type for smaller ``int`` integers.
+
+ Typically generates a ``SMALLINT`` in DDL, and otherwise acts like
+ a normal :class:`.Integer` on the Python side.
+
+ """
+
+ __visit_name__ = "small_integer"
+
+
+class BigInteger(Integer):
+
+ """A type for bigger ``int`` integers.
+
+ Typically generates a ``BIGINT`` in DDL, and otherwise acts like
+ a normal :class:`.Integer` on the Python side.
+
+ """
+
+ __visit_name__ = "big_integer"
+
+
+class Numeric(_LookupExpressionAdapter, TypeEngine):
+
+ """Base for non-integer numeric types, such as
+ ``NUMERIC``, ``FLOAT``, ``DECIMAL``, and other variants.
+
+ The :class:`.Numeric` datatype when used directly will render DDL
+ corresponding to precision numerics if available, such as
+ ``NUMERIC(precision, scale)``. The :class:`.Float` subclass will
+ attempt to render a floating-point datatype such as ``FLOAT(precision)``.
+
+ :class:`.Numeric` returns Python ``decimal.Decimal`` objects by default,
+ based on the default value of ``True`` for the
+ :paramref:`.Numeric.asdecimal` parameter. If this parameter is set to
+ False, returned values are coerced to Python ``float`` objects.
+
+ The :class:`.Float` subtype, being more specific to floating point,
+ defaults the :paramref:`.Float.asdecimal` flag to False so that the
+ default Python datatype is ``float``.
+
+ .. note::
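+ E.g., a usage sketch (the column names are hypothetical)::
+
+ from sqlalchemy import Column, Numeric
+
+ amount = Column("amount", Numeric(10, 2)) # Decimal results by default
+ ratio = Column("ratio", Numeric(asdecimal=False)) # float results
+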
+
+ When using a :class:`.Numeric` datatype against a database type that
+ returns Python floating point values to the driver, the accuracy of the
+ decimal conversion indicated by :paramref:`.Numeric.asdecimal` may be
+ limited. The behavior of specific numeric/floating point datatypes
+ is a product of the SQL datatype in use, the Python :term:`DBAPI`
+ in use, as well as strategies that may be present within
+ the SQLAlchemy dialect in use. Users requiring specific precision/
+ scale are encouraged to experiment with the available datatypes
+ in order to determine the best results.
+
+ """
+
+ __visit_name__ = "numeric"
+
+ _default_decimal_return_scale = 10
+
+ def __init__(
+ self,
+ precision=None,
+ scale=None,
+ decimal_return_scale=None,
+ asdecimal=True,
+ ):
+ """
+ Construct a Numeric.
+
+ :param precision: the numeric precision for use in DDL ``CREATE
+ TABLE``.
+
+ :param scale: the numeric scale for use in DDL ``CREATE TABLE``.
+
+        :param asdecimal: default True.  When True, values returned from
+         the database are coerced to Python ``decimal.Decimal`` objects;
+         when False, they are coerced to ``float``.  Different DBAPIs send
+         one or the other based on datatypes - the Numeric type will
+         ensure that return values are one or the other across DBAPIs
+         consistently.
+
+ :param decimal_return_scale: Default scale to use when converting
+ from floats to Python decimals. Floating point values will typically
+ be much longer due to decimal inaccuracy, and most floating point
+ database types don't have a notion of "scale", so by default the
+ float type looks for the first ten decimal places when converting.
+ Specifying this value will override that length. Types which
+ do include an explicit ".scale" value, such as the base
+ :class:`.Numeric` as well as the MySQL float types, will use the
+ value of ".scale" as the default for decimal_return_scale, if not
+ otherwise specified.
+
+ When using the ``Numeric`` type, care should be taken to ensure
+ that the asdecimal setting is appropriate for the DBAPI in use -
+ when Numeric applies a conversion from Decimal->float or float->
+ Decimal, this conversion incurs an additional performance overhead
+ for all result columns received.
+
+ DBAPIs that return Decimal natively (e.g. psycopg2) will have
+ better accuracy and higher performance with a setting of ``True``,
+ as the native translation to Decimal reduces the amount of floating-
+ point issues at play, and the Numeric type itself doesn't need
+ to apply any further conversions. However, another DBAPI which
+ returns floats natively *will* incur an additional conversion
+ overhead, and is still subject to floating point data loss - in
+ which case ``asdecimal=False`` will at least remove the extra
+ conversion overhead.
+
+ """
+ self.precision = precision
+ self.scale = scale
+ self.decimal_return_scale = decimal_return_scale
+ self.asdecimal = asdecimal
+
+ @property
+ def _effective_decimal_return_scale(self):
+ if self.decimal_return_scale is not None:
+ return self.decimal_return_scale
+ elif getattr(self, "scale", None) is not None:
+ return self.scale
+ else:
+ return self._default_decimal_return_scale
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def literal_processor(self, dialect):
+ def process(value):
+ return str(value)
+
+ return process
+
+ @property
+ def python_type(self):
+ if self.asdecimal:
+ return decimal.Decimal
+ else:
+ return float
+
+ def bind_processor(self, dialect):
+ if dialect.supports_native_decimal:
+ return None
+ else:
+ return processors.to_float
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ if dialect.supports_native_decimal:
+ # we're a "numeric", DBAPI will give us Decimal directly
+ return None
+ else:
+ util.warn(
+ "Dialect %s+%s does *not* support Decimal "
+ "objects natively, and SQLAlchemy must "
+ "convert from floating point - rounding "
+ "errors and other issues may occur. Please "
+ "consider storing Decimal numbers as strings "
+ "or integers on this platform for lossless "
+ "storage." % (dialect.name, dialect.driver)
+ )
+
+ # we're a "numeric", DBAPI returns floats, convert.
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal,
+ self.scale
+ if self.scale is not None
+ else self._default_decimal_return_scale,
+ )
+ else:
+ if dialect.supports_native_decimal:
+ return processors.to_float
+ else:
+ return None
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ return {
+ operators.mul: {
+ Interval: Interval,
+ Numeric: self.__class__,
+ Integer: self.__class__,
+ },
+ operators.div: {Numeric: self.__class__, Integer: self.__class__},
+ operators.truediv: {
+ Numeric: self.__class__,
+ Integer: self.__class__,
+ },
+ operators.add: {Numeric: self.__class__, Integer: self.__class__},
+ operators.sub: {Numeric: self.__class__, Integer: self.__class__},
+ }
+
+
+class Float(Numeric):
+
+ """Type representing floating point types, such as ``FLOAT`` or ``REAL``.
+
+ This type returns Python ``float`` objects by default, unless the
+ :paramref:`.Float.asdecimal` flag is set to True, in which case they
+ are coerced to ``decimal.Decimal`` objects.
+
+
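+    A brief usage sketch (the column names are arbitrary)::
+
+        from sqlalchemy import Column, Float
+
+        # stores and returns Python float values
+        score = Column("score", Float(precision=8))
+
+        # returns decimal.Decimal instead, converted from the float value
+        ratio = Column("ratio", Float(precision=8, asdecimal=True))
+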
+ """
+
+ __visit_name__ = "float"
+
+ scale = None
+
+ def __init__(
+ self, precision=None, asdecimal=False, decimal_return_scale=None
+ ):
+ r"""
+ Construct a Float.
+
+ :param precision: the numeric precision for use in DDL ``CREATE
+ TABLE``.
+
+ :param asdecimal: the same flag as that of :class:`.Numeric`, but
+     defaults to ``False``.  Note that setting this flag to ``True``
+     results in an additional float-to-``decimal.Decimal`` conversion
+     step being applied to result values.
+
+ :param decimal_return_scale: Default scale to use when converting
+ from floats to Python decimals. Floating point values will typically
+ be much longer due to decimal inaccuracy, and most floating point
+ database types don't have a notion of "scale", so by default the
+ float type looks for the first ten decimal places when converting.
+ Specifying this value will override that length. Note that the
+ MySQL float types, which do include "scale", will use "scale"
+ as the default for decimal_return_scale, if not otherwise specified.
+
+ .. versionadded:: 0.9.0
+
+ """
+ self.precision = precision
+ self.asdecimal = asdecimal
+ self.decimal_return_scale = decimal_return_scale
+
+ def result_processor(self, dialect, coltype):
+ if self.asdecimal:
+ return processors.to_decimal_processor_factory(
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
+ elif dialect.supports_native_decimal:
+ return processors.to_float
+ else:
+ return None
+
+
+class DateTime(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``datetime.datetime()`` objects.
+
+ Date and time types return objects from the Python ``datetime``
+ module. Most DBAPIs have built in support for the datetime
+ module, with the noted exception of SQLite. In the case of
+ SQLite, date and time types are stored as strings which are then
+ converted back to datetime objects when rows are returned.
+
+ For the time representation within the datetime type, some
+ backends include additional options, such as timezone support and
+ fractional seconds support. For fractional seconds, use the
+ dialect-specific datatype, such as :class:`.mysql.TIME`. For
+ timezone support, use at least the :class:`_types.TIMESTAMP` datatype,
+ if not the dialect-specific datatype object.
+
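+    A minimal sketch (the table and column names are arbitrary); for
+    timezone-aware storage the :class:`_types.TIMESTAMP` datatype may be
+    preferable, as noted above::
+
+        import datetime
+        from sqlalchemy import Column, DateTime, Integer, MetaData, Table
+
+        events = Table(
+            "events", MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("created_at", DateTime(timezone=True)),
+        )
+
+        # bound parameter values are datetime.datetime objects
+        stmt = events.insert().values(created_at=datetime.datetime.utcnow())
+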
+ """
+
+ __visit_name__ = "datetime"
+
+ def __init__(self, timezone=False):
+ """Construct a new :class:`.DateTime`.
+
+ :param timezone: boolean. Indicates that the datetime type should
+ enable timezone support, if available on the
+ **base date/time-holding type only**. It is recommended
+ to make use of the :class:`_types.TIMESTAMP` datatype directly when
+ using this flag, as some databases include separate generic
+ date/time-holding types distinct from the timezone-capable
+ TIMESTAMP datatype, such as Oracle.
+
+
+ """
+ self.timezone = timezone
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATETIME
+
+ def _resolve_for_literal(self, value):
+ with_timezone = value.tzinfo is not None
+ if with_timezone and not self.timezone:
+ return DATETIME_TIMEZONE
+ else:
+ return self
+
+ @property
+ def python_type(self):
+ return dt.datetime
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {Interval: self.__class__},
+ operators.sub: {Interval: self.__class__, DateTime: Interval},
+ }
+
+
+class Date(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``datetime.date()`` objects."""
+
+ __visit_name__ = "date"
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATETIME
+
+ @property
+ def python_type(self):
+ return dt.date
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {
+ Integer: self.__class__,
+ Interval: DateTime,
+ Time: DateTime,
+ },
+ operators.sub: {
+ # date - integer = date
+ Integer: self.__class__,
+ # date - date = integer.
+ Date: Integer,
+ Interval: DateTime,
+ # date - datetime = interval,
+ # this one is not in the PG docs
+ # but works
+ DateTime: Interval,
+ },
+ }
+
+
+class Time(_LookupExpressionAdapter, TypeEngine):
+
+ """A type for ``datetime.time()`` objects."""
+
+ __visit_name__ = "time"
+
+ def __init__(self, timezone=False):
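+        """Construct a new :class:`.Time`.
+
+        :param timezone: boolean.  When True, and where supported by the
+         backend, the time type will include timezone support.
+
+        """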
+ self.timezone = timezone
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.DATETIME
+
+ @property
+ def python_type(self):
+ return dt.time
+
+ def _resolve_for_literal(self, value):
+ with_timezone = value.tzinfo is not None
+ if with_timezone and not self.timezone:
+ return TIME_TIMEZONE
+ else:
+ return self
+
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {Date: DateTime, Interval: self.__class__},
+ operators.sub: {Time: Interval, Interval: self.__class__},
+ }
+
+
+class _Binary(TypeEngine):
+
+ """Define base behavior for binary types."""
+
+ def __init__(self, length=None):
+ self.length = length
+
+ def literal_processor(self, dialect):
+ def process(value):
+ value = value.decode(dialect.encoding).replace("'", "''")
+ return "'%s'" % value
+
+ return process
+
+ @property
+ def python_type(self):
+ return util.binary_type
+
+ # Python 3 - sqlite3 doesn't need the `Binary` conversion
+ # here, though pg8000 does to indicate "bytea"
+ def bind_processor(self, dialect):
+ if dialect.dbapi is None:
+ return None
+
+ DBAPIBinary = dialect.dbapi.Binary
+
+ def process(value):
+ if value is not None:
+ return DBAPIBinary(value)
+ else:
+ return None
+
+ return process
+
+ # Python 3 has native bytes() type
+ # both sqlite3 and pg8000 seem to return it,
+ # psycopg2 as of 2.5 returns 'memoryview'
+ if util.py2k:
+
+ def result_processor(self, dialect, coltype):
+ return processors.to_str
+
+ else:
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is not None:
+ value = bytes(value)
+ return value
+
+ return process
+
+ def coerce_compared_value(self, op, value):
+ """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+
+ if isinstance(value, util.string_types):
+ return self
+ else:
+ return super(_Binary, self).coerce_compared_value(op, value)
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.BINARY
+
+
+class LargeBinary(_Binary):
+
+ """A type for large binary byte data.
+
+ The :class:`.LargeBinary` type corresponds to a large and/or unlengthed
+ binary type for the target platform, such as BLOB on MySQL and BYTEA for
+ PostgreSQL. It also handles the necessary conversions for the DBAPI.
+
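+    A short usage sketch (the table and column names are arbitrary)::
+
+        from sqlalchemy import Column, Integer, LargeBinary, MetaData, Table
+
+        files = Table(
+            "files", MetaData(),
+            Column("id", Integer, primary_key=True),
+            # renders BLOB, BYTEA, etc.; accepts and returns bytes
+            Column("payload", LargeBinary()),
+        )
+
+        stmt = files.insert().values(payload=b"\x89PNG...")
+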
+ """
+
+ __visit_name__ = "large_binary"
+
+ def __init__(self, length=None):
+ """
+ Construct a LargeBinary type.
+
+ :param length: optional, a length for the column for use in
+ DDL statements, for those binary types that accept a length,
+ such as the MySQL BLOB type.
+
+ """
+ _Binary.__init__(self, length=length)
+
+
+class SchemaType(SchemaEventTarget):
+
+ """Mark a type as possibly requiring schema-level DDL for usage.
+
+ Supports types that must be explicitly created/dropped (i.e. PG ENUM type)
+    as well as types that are complemented by table or schema level
+ constraints, triggers, and other rules.
+
+ :class:`.SchemaType` classes can also be targets for the
+ :meth:`.DDLEvents.before_parent_attach` and
+ :meth:`.DDLEvents.after_parent_attach` events, where the events fire off
+ surrounding the association of the type object with a parent
+ :class:`_schema.Column`.
+
+ .. seealso::
+
+ :class:`.Enum`
+
+ :class:`.Boolean`
+
+
+ """
+
+ _use_schema_map = True
+
+ def __init__(
+ self,
+ name=None,
+ schema=None,
+ metadata=None,
+ inherit_schema=False,
+ quote=None,
+ _create_events=True,
+ ):
+ if name is not None:
+ self.name = quoted_name(name, quote)
+ else:
+ self.name = None
+ self.schema = schema
+ self.metadata = metadata
+ self.inherit_schema = inherit_schema
+ self._create_events = _create_events
+
+ if _create_events and self.metadata:
+ event.listen(
+ self.metadata,
+ "before_create",
+ util.portable_instancemethod(self._on_metadata_create),
+ )
+ event.listen(
+ self.metadata,
+ "after_drop",
+ util.portable_instancemethod(self._on_metadata_drop),
+ )
+
+ def _set_parent(self, column, **kw):
+ column._on_table_attach(util.portable_instancemethod(self._set_table))
+
+ def _variant_mapping_for_set_table(self, column):
+ if isinstance(column.type, Variant):
+ variant_mapping = column.type.mapping.copy()
+ variant_mapping["_default"] = column.type.impl
+ else:
+ variant_mapping = None
+ return variant_mapping
+
+ def _set_table(self, column, table):
+ if self.inherit_schema:
+ self.schema = table.schema
+ elif self.metadata and self.schema is None and self.metadata.schema:
+ self.schema = self.metadata.schema
+
+ if not self._create_events:
+ return
+
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
+ event.listen(
+ table,
+ "before_create",
+ util.portable_instancemethod(
+ self._on_table_create, {"variant_mapping": variant_mapping}
+ ),
+ )
+ event.listen(
+ table,
+ "after_drop",
+ util.portable_instancemethod(
+ self._on_table_drop, {"variant_mapping": variant_mapping}
+ ),
+ )
+ if self.metadata is None:
+ # TODO: what's the difference between self.metadata
+ # and table.metadata here ?
+ event.listen(
+ table.metadata,
+ "before_create",
+ util.portable_instancemethod(
+ self._on_metadata_create,
+ {"variant_mapping": variant_mapping},
+ ),
+ )
+ event.listen(
+ table.metadata,
+ "after_drop",
+ util.portable_instancemethod(
+ self._on_metadata_drop,
+ {"variant_mapping": variant_mapping},
+ ),
+ )
+
+ def copy(self, **kw):
+ return self.adapt(self.__class__, _create_events=True)
+
+ def adapt(self, impltype, **kw):
+ schema = kw.pop("schema", self.schema)
+ metadata = kw.pop("metadata", self.metadata)
+ _create_events = kw.pop("_create_events", False)
+ return impltype(
+ name=self.name,
+ schema=schema,
+ inherit_schema=self.inherit_schema,
+ metadata=metadata,
+ _create_events=_create_events,
+ **kw
+ )
+
+ @property
+ def bind(self):
+ return self.metadata and self.metadata.bind or None
+
+ def create(self, bind=None, checkfirst=False):
+ """Issue CREATE DDL for this type, if applicable."""
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t.create(bind=bind, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=False):
+ """Issue DROP DDL for this type, if applicable."""
+
+ if bind is None:
+ bind = _bind_or_error(self)
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t.drop(bind=bind, checkfirst=checkfirst)
+
+ def _on_table_create(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_table_create(target, bind, **kw)
+
+ def _on_table_drop(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_table_drop(target, bind, **kw)
+
+ def _on_metadata_create(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_metadata_create(target, bind, **kw)
+
+ def _on_metadata_drop(self, target, bind, **kw):
+ if not self._is_impl_for_variant(bind.dialect, kw):
+ return
+
+ t = self.dialect_impl(bind.dialect)
+ if t.__class__ is not self.__class__ and isinstance(t, SchemaType):
+ t._on_metadata_drop(target, bind, **kw)
+
+ def _is_impl_for_variant(self, dialect, kw):
+ variant_mapping = kw.pop("variant_mapping", None)
+ if variant_mapping is None:
+ return True
+
+ # since PostgreSQL is the only DB that has ARRAY this can only
+ # be integration tested by PG-specific tests
+ def _we_are_the_impl(typ):
+ return (
+ typ is self or isinstance(typ, ARRAY) and typ.item_type is self
+ )
+
+ if dialect.name in variant_mapping and _we_are_the_impl(
+ variant_mapping[dialect.name]
+ ):
+ return True
+ elif dialect.name not in variant_mapping:
+ return _we_are_the_impl(variant_mapping["_default"])
+
+
+class Enum(Emulated, String, SchemaType):
+ """Generic Enum Type.
+
+ The :class:`.Enum` type provides a set of possible string values
+ which the column is constrained towards.
+
+ The :class:`.Enum` type will make use of the backend's native "ENUM"
+ type if one is available; otherwise, it uses a VARCHAR datatype.
+ An option also exists to automatically produce a CHECK constraint
+ when the VARCHAR (so called "non-native") variant is produced;
+ see the :paramref:`.Enum.create_constraint` flag.
+
+ The :class:`.Enum` type also provides in-Python validation of string
+ values during both read and write operations. When reading a value
+ from the database in a result set, the string value is always checked
+ against the list of possible values and a ``LookupError`` is raised
+ if no match is found. When passing a value to the database as a
+ plain string within a SQL statement, if the
+ :paramref:`.Enum.validate_strings` parameter is
+ set to True, a ``LookupError`` is raised for any string value that's
+ not located in the given list of possible values; note that this
+ impacts usage of LIKE expressions with enumerated values (an unusual
+ use case).
+
+ .. versionchanged:: 1.1 the :class:`.Enum` type now provides in-Python
+ validation of input values as well as on data being returned by
+ the database.
+
+ The source of enumerated values may be a list of string values, or
+ alternatively a PEP-435-compliant enumerated class. For the purposes
+ of the :class:`.Enum` datatype, this class need only provide a
+    ``__members__`` attribute.
+
+ When using an enumerated class, the enumerated objects are used
+ both for input and output, rather than strings as is the case with
+ a plain-string enumerated type::
+
+ import enum
+ from sqlalchemy import Enum
+
+ class MyEnum(enum.Enum):
+ one = 1
+ two = 2
+ three = 3
+
+ t = Table(
+ 'data', MetaData(),
+ Column('value', Enum(MyEnum))
+ )
+
+ connection.execute(t.insert(), {"value": MyEnum.two})
+ assert connection.scalar(t.select()) is MyEnum.two
+
+ Above, the string names of each element, e.g. "one", "two", "three",
+ are persisted to the database; the values of the Python Enum, here
+ indicated as integers, are **not** used; the value of each enum can
+ therefore be any kind of Python object whether or not it is persistable.
+
+ In order to persist the values and not the names, the
+ :paramref:`.Enum.values_callable` parameter may be used. The value of
+ this parameter is a user-supplied callable, which is intended to be used
+ with a PEP-435-compliant enumerated class and returns a list of string
+ values to be persisted. For a simple enumeration that uses string values,
+ a callable such as ``lambda x: [e.value for e in x]`` is sufficient.
+
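+    For instance, a sketch of persisting each member's ``.value`` rather
+    than its name (the enumeration shown is illustrative only)::
+
+        import enum
+        from sqlalchemy import Enum
+
+        class Color(enum.Enum):
+            red = "r"
+            green = "g"
+
+        color_type = Enum(Color, values_callable=lambda x: [e.value for e in x])
+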
+ .. versionadded:: 1.1 - support for PEP-435-style enumerated
+ classes.
+
+
+ .. seealso::
+
+ :class:`_postgresql.ENUM` - PostgreSQL-specific type,
+ which has additional functionality.
+
+ :class:`.mysql.ENUM` - MySQL-specific type
+
+ """
+
+ __visit_name__ = "enum"
+
+ @util.deprecated_params(
+ convert_unicode=(
+ "1.3",
+ "The :paramref:`.Enum.convert_unicode` parameter is deprecated "
+ "and will be removed in a future release. All modern DBAPIs "
+ "now support Python Unicode directly and this parameter is "
+ "unnecessary.",
+ )
+ )
+ def __init__(self, *enums, **kw):
+ r"""Construct an enum.
+
+ Keyword arguments which don't apply to a specific backend are ignored
+ by that backend.
+
+ :param \*enums: either exactly one PEP-435 compliant enumerated type
+ or one or more string labels.
+
+ .. versionadded:: 1.1 a PEP-435 style enumerated class may be
+ passed.
+
+ :param convert_unicode: Enable unicode-aware bind parameter and
+ result-set processing for this Enum's data under Python 2 only.
+ Under Python 2, this is set automatically based on the presence of
+ unicode label strings. This flag will be removed in SQLAlchemy 2.0.
+
+ :param create_constraint: defaults to False. When creating a
+ non-native enumerated type, also build a CHECK constraint on the
+ database against the valid values.
+
+ .. note:: it is strongly recommended that the CHECK constraint
+ have an explicit name in order to support schema-management
+ concerns. This can be established either by setting the
+ :paramref:`.Enum.name` parameter or by setting up an
+ appropriate naming convention; see
+ :ref:`constraint_naming_conventions` for background.
+
+ .. versionchanged:: 1.4 - this flag now defaults to False, meaning
+ no CHECK constraint is generated for a non-native enumerated
+ type.
+
+ :param metadata: Associate this type directly with a ``MetaData``
+ object. For types that exist on the target database as an
+ independent schema construct (PostgreSQL), this type will be
+ created and dropped within ``create_all()`` and ``drop_all()``
+ operations. If the type is not associated with any ``MetaData``
+ object, it will associate itself with each ``Table`` in which it is
+ used, and will be created when any of those individual tables are
+ created, after a check is performed for its existence. The type is
+ only dropped when ``drop_all()`` is called for that ``Table``
+ object's metadata, however.
+
+ The value of the :paramref:`_schema.MetaData.schema` parameter of
+ the :class:`_schema.MetaData` object, if set, will be used as the
+ default value of the :paramref:`_types.Enum.schema` on this object
+ if an explicit value is not otherwise supplied.
+
+ .. versionchanged:: 1.4.12 :class:`_types.Enum` inherits the
+ :paramref:`_schema.MetaData.schema` parameter of the
+ :class:`_schema.MetaData` object if present, when passed using
+ the :paramref:`_types.Enum.metadata` parameter.
+
+ :param name: The name of this type. This is required for PostgreSQL
+ and any future supported database which requires an explicitly
+ named type, or an explicitly named constraint in order to generate
+ the type and/or a table that uses it. If a PEP-435 enumerated
+ class was used, its name (converted to lower case) is used by
+ default.
+
+ :param native_enum: Use the database's native ENUM type when
+ available. Defaults to True. When False, uses VARCHAR + check
+ constraint for all backends. When False, the VARCHAR length can be
+ controlled with :paramref:`.Enum.length`; currently "length" is
+ ignored if native_enum=True.
+
+ :param length: Allows specifying a custom length for the VARCHAR
+ when :paramref:`.Enum.native_enum` is False. By default it uses the
+ length of the longest value.
+
+ .. versionadded:: 1.3.16
+
+ :param schema: Schema name of this type. For types that exist on the
+ target database as an independent schema construct (PostgreSQL),
+ this parameter specifies the named schema in which the type is
+ present.
+
+ If not present, the schema name will be taken from the
+ :class:`_schema.MetaData` collection if passed as
+ :paramref:`_types.Enum.metadata`, for a :class:`_schema.MetaData`
+ that includes the :paramref:`_schema.MetaData.schema` parameter.
+
+ .. versionchanged:: 1.4.12 :class:`_types.Enum` inherits the
+ :paramref:`_schema.MetaData.schema` parameter of the
+ :class:`_schema.MetaData` object if present, when passed using
+ the :paramref:`_types.Enum.metadata` parameter.
+
+ Otherwise, if the :paramref:`_types.Enum.inherit_schema` flag is set
+ to ``True``, the schema will be inherited from the associated
+ :class:`_schema.Table` object if any; when
+ :paramref:`_types.Enum.inherit_schema` is at its default of
+ ``False``, the owning table's schema is **not** used.
+
+
+ :param quote: Set explicit quoting preferences for the type's name.
+
+ :param inherit_schema: When ``True``, the "schema" from the owning
+ :class:`_schema.Table`
+ will be copied to the "schema" attribute of this
+ :class:`.Enum`, replacing whatever value was passed for the
+ ``schema`` attribute. This also takes effect when using the
+ :meth:`_schema.Table.to_metadata` operation.
+
+ :param validate_strings: when True, string values that are being
+ passed to the database in a SQL statement will be checked
+ for validity against the list of enumerated values. Unrecognized
+ values will result in a ``LookupError`` being raised.
+
+ .. versionadded:: 1.1.0b2
+
+ :param values_callable: A callable which will be passed the PEP-435
+ compliant enumerated type, which should then return a list of string
+ values to be persisted. This allows for alternate usages such as
+ using the string value of an enum to be persisted to the database
+ instead of its name.
+
+ .. versionadded:: 1.2.3
+
+ :param sort_key_function: a Python callable which may be used as the
+ "key" argument in the Python ``sorted()`` built-in. The SQLAlchemy
+ ORM requires that primary key columns which are mapped must
+ be sortable in some way. When using an unsortable enumeration
+ object such as a Python 3 ``Enum`` object, this parameter may be
+ used to set a default sort key function for the objects. By
+ default, the database value of the enumeration is used as the
+ sorting function.
+
+ .. versionadded:: 1.3.8
+
+ :param omit_aliases: A boolean that when true will remove aliases from
+          PEP-435 enums. For backward compatibility it defaults to ``False``.
+ A deprecation warning is raised if the enum has aliases and this
+ flag was not set.
+
+ .. versionadded:: 1.4.5
+
+ .. deprecated:: 1.4 The default will be changed to ``True`` in
+ SQLAlchemy 2.0.
+
+ """
+ self._enum_init(enums, kw)
+
+ @property
+ def _enums_argument(self):
+ if self.enum_class is not None:
+ return [self.enum_class]
+ else:
+ return self.enums
+
+ def _enum_init(self, enums, kw):
+ """internal init for :class:`.Enum` and subclasses.
+
+ friendly init helper used by subclasses to remove
+ all the Enum-specific keyword arguments from kw. Allows all
+ other arguments in kw to pass through.
+
+ """
+ self.native_enum = kw.pop("native_enum", True)
+ self.create_constraint = kw.pop("create_constraint", False)
+ self.values_callable = kw.pop("values_callable", None)
+ self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
+ length_arg = kw.pop("length", NO_ARG)
+ self._omit_aliases = kw.pop("omit_aliases", NO_ARG)
+ _disable_warnings = kw.pop("_disable_warnings", False)
+ values, objects = self._parse_into_values(enums, kw)
+ self._setup_for_values(values, objects, kw)
+
+ convert_unicode = kw.pop("convert_unicode", None)
+ self.validate_strings = kw.pop("validate_strings", False)
+
+ if convert_unicode is None:
+ for e in self.enums:
+ # this is all py2k logic that can go away for py3k only,
+ # "expect unicode" will always be implicitly true
+ if isinstance(e, util.text_type):
+ _expect_unicode = True
+ break
+ else:
+ _expect_unicode = False
+ else:
+ _expect_unicode = convert_unicode
+
+ if self.enums:
+ self._default_length = length = max(len(x) for x in self.enums)
+ else:
+ self._default_length = length = 0
+
+ if length_arg is not NO_ARG:
+ if self.native_enum:
+ if not _disable_warnings:
+ util.warn(
+ "Enum 'length' argument is currently ignored unless "
+ "native_enum is specified as False, including for DDL "
+ "that renders VARCHAR in any case. This may change "
+ "in a future release."
+ )
+ else:
+ if not _disable_warnings and length_arg < length:
+ raise ValueError(
+ "When provided, length must be larger or equal"
+ " than the length of the longest enum value. %s < %s"
+ % (length_arg, length)
+ )
+ length = length_arg
+
+ self._valid_lookup[None] = self._object_lookup[None] = None
+
+ super(Enum, self).__init__(
+ length=length, _expect_unicode=_expect_unicode
+ )
+
+ if self.enum_class:
+ kw.setdefault("name", self.enum_class.__name__.lower())
+ SchemaType.__init__(
+ self,
+ name=kw.pop("name", None),
+ schema=kw.pop("schema", None),
+ metadata=kw.pop("metadata", None),
+ inherit_schema=kw.pop("inherit_schema", False),
+ quote=kw.pop("quote", None),
+ _create_events=kw.pop("_create_events", True),
+ )
+
+ def _parse_into_values(self, enums, kw):
+ if not enums and "_enums" in kw:
+ enums = kw.pop("_enums")
+
+ if len(enums) == 1 and hasattr(enums[0], "__members__"):
+ self.enum_class = enums[0]
+
+ _members = self.enum_class.__members__
+
+ aliases = [n for n, v in _members.items() if v.name != n]
+ if self._omit_aliases is NO_ARG and aliases:
+ util.warn_deprecated_20(
+ "The provided enum %s contains the aliases %s. The "
+ "``omit_aliases`` will default to ``True`` in SQLAlchemy "
+ "2.0. Specify a value to silence this warning."
+ % (self.enum_class.__name__, aliases)
+ )
+ if self._omit_aliases is True:
+ # remove aliases
+ members = OrderedDict(
+ (n, v) for n, v in _members.items() if v.name == n
+ )
+ else:
+ members = _members
+ if self.values_callable:
+ values = self.values_callable(self.enum_class)
+ else:
+ values = list(members)
+ objects = [members[k] for k in members]
+ return values, objects
+ else:
+ self.enum_class = None
+ return enums, enums
+
+ def _setup_for_values(self, values, objects, kw):
+ self.enums = list(values)
+
+ self._valid_lookup = dict(zip(reversed(objects), reversed(values)))
+
+ self._object_lookup = dict(zip(values, objects))
+
+ self._valid_lookup.update(
+ [
+ (value, self._valid_lookup[self._object_lookup[value]])
+ for value in values
+ ]
+ )
+
+ @property
+ def sort_key_function(self):
+ if self._sort_key_function is NO_ARG:
+ return self._db_value_for_elem
+ else:
+ return self._sort_key_function
+
+ @property
+ def native(self):
+ return self.native_enum
+
+ def _db_value_for_elem(self, elem):
+ try:
+ return self._valid_lookup[elem]
+ except KeyError as err:
+ # for unknown string values, we return as is. While we can
+ # validate these if we wanted, that does not allow for lesser-used
+ # end-user use cases, such as using a LIKE comparison with an enum,
+ # or for an application that wishes to apply string tests to an
+ # ENUM (see [ticket:3725]). While we can decide to differentiate
+ # here between an INSERT statement and a criteria used in a SELECT,
+ # for now we're staying conservative w/ behavioral changes (perhaps
+ # someone has a trigger that handles strings on INSERT)
+ if not self.validate_strings and isinstance(
+ elem, compat.string_types
+ ):
+ return elem
+ else:
+ util.raise_(
+ LookupError(
+ "'%s' is not among the defined enum values. "
+ "Enum name: %s. Possible values: %s"
+ % (
+ elem,
+ self.name,
+ langhelpers.repr_tuple_names(self.enums),
+ )
+ ),
+ replace_context=err,
+ )
+
+ class Comparator(String.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ op, typ = super(Enum.Comparator, self)._adapt_expression(
+ op, other_comparator
+ )
+ if op is operators.concat_op:
+ typ = String(
+ self.type.length, _expect_unicode=self.type._expect_unicode
+ )
+ return op, typ
+
+ comparator_factory = Comparator
+
+ def _object_value_for_elem(self, elem):
+ try:
+ return self._object_lookup[elem]
+ except KeyError as err:
+ util.raise_(
+ LookupError(
+ "'%s' is not among the defined enum values. "
+ "Enum name: %s. Possible values: %s"
+ % (
+ elem,
+ self.name,
+ langhelpers.repr_tuple_names(self.enums),
+ )
+ ),
+ replace_context=err,
+ )
+
+ def __repr__(self):
+ return util.generic_repr(
+ self,
+ additional_kw=[
+ ("native_enum", True),
+ ("create_constraint", False),
+ ("length", self._default_length),
+ ],
+ to_inspect=[Enum, SchemaType],
+ )
+
+ def as_generic(self, allow_nulltype=False):
+ if hasattr(self, "enums"):
+ args = self.enums
+ else:
+ raise NotImplementedError(
+ "TypeEngine.as_generic() heuristic "
+ "is undefined for types that inherit Enum but do not have "
+ "an `enums` attribute."
+ )
+
+ return util.constructor_copy(
+ self, self._generic_type_affinity, *args, _disable_warnings=True
+ )
+
+ def adapt_to_emulated(self, impltype, **kw):
+ kw.setdefault("_expect_unicode", self._expect_unicode)
+ kw.setdefault("validate_strings", self.validate_strings)
+ kw.setdefault("name", self.name)
+ kw["_disable_warnings"] = True
+ kw.setdefault("schema", self.schema)
+ kw.setdefault("inherit_schema", self.inherit_schema)
+ kw.setdefault("metadata", self.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("native_enum", self.native_enum)
+ kw.setdefault("values_callable", self.values_callable)
+ kw.setdefault("create_constraint", self.create_constraint)
+ kw.setdefault("length", self.length)
+ kw.setdefault("omit_aliases", self._omit_aliases)
+ assert "_enums" in kw
+ return impltype(**kw)
+
+ def adapt(self, impltype, **kw):
+ kw["_enums"] = self._enums_argument
+ kw["_disable_warnings"] = True
+ return super(Enum, self).adapt(impltype, **kw)
+
+ def _should_create_constraint(self, compiler, **kw):
+ if not self._is_impl_for_variant(compiler.dialect, kw):
+ return False
+ return (
+ not self.native_enum or not compiler.dialect.supports_native_enum
+ )
+
+ @util.preload_module("sqlalchemy.sql.schema")
+ def _set_table(self, column, table):
+ schema = util.preloaded.sql_schema
+ SchemaType._set_table(self, column, table)
+
+ if not self.create_constraint:
+ return
+
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
+ e = schema.CheckConstraint(
+ type_coerce(column, String()).in_(self.enums),
+ name=_NONE_NAME if self.name is None else self.name,
+ _create_rule=util.portable_instancemethod(
+ self._should_create_constraint,
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
+ )
+ assert e.table is table
+
+ def literal_processor(self, dialect):
+ parent_processor = super(Enum, self).literal_processor(dialect)
+
+ def process(value):
+ value = self._db_value_for_elem(value)
+ if parent_processor:
+ value = parent_processor(value)
+ return value
+
+ return process
+
+ def bind_processor(self, dialect):
+ parent_processor = super(Enum, self).bind_processor(dialect)
+
+ def process(value):
+ value = self._db_value_for_elem(value)
+ if parent_processor:
+ value = parent_processor(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ parent_processor = super(Enum, self).result_processor(dialect, coltype)
+
+ def process(value):
+ if parent_processor:
+ value = parent_processor(value)
+
+ value = self._object_value_for_elem(value)
+ return value
+
+ return process
+
+ def copy(self, **kw):
+ return SchemaType.copy(self, **kw)
+
+ @property
+ def python_type(self):
+ if self.enum_class:
+ return self.enum_class
+ else:
+ return super(Enum, self).python_type
+
+
+class PickleType(TypeDecorator):
+ """Holds Python objects, which are serialized using pickle.
+
+ PickleType builds upon the Binary type to apply Python's
+ ``pickle.dumps()`` to incoming objects, and ``pickle.loads()`` on
+ the way out, allowing any pickleable Python object to be stored as
+ a serialized binary field.
+
+ To allow ORM change events to propagate for elements associated
+ with :class:`.PickleType`, see :ref:`mutable_toplevel`.
+
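+    A minimal sketch (the column name is arbitrary); any pickleable
+    Python object may be stored::
+
+        from sqlalchemy import Column, PickleType
+
+        # serialized with pickle.dumps() into a LargeBinary column;
+        # e.g. {"theme": "dark", "retries": 3} round-trips as a dict
+        settings = Column("settings", PickleType())
+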
+ """
+
+ impl = LargeBinary
+ cache_ok = True
+
+ def __init__(
+ self,
+ protocol=pickle.HIGHEST_PROTOCOL,
+ pickler=None,
+ comparator=None,
+ impl=None,
+ ):
+ """
+ Construct a PickleType.
+
+ :param protocol: defaults to ``pickle.HIGHEST_PROTOCOL``.
+
+ :param pickler: defaults to cPickle.pickle or pickle.pickle if
+ cPickle is not available. May be any object with
+ pickle-compatible ``dumps`` and ``loads`` methods.
+
+ :param comparator: a 2-arg callable predicate used
+ to compare values of this type. If left as ``None``,
+ the Python "equals" operator is used to compare values.
+
+ :param impl: A binary-storing :class:`_types.TypeEngine` class or
+ instance to use in place of the default :class:`_types.LargeBinary`.
+          For example the :class:`_mysql.LONGBLOB` class may be more effective
+ when using MySQL.
+
+ .. versionadded:: 1.4.20
+
+ """
+ self.protocol = protocol
+ self.pickler = pickler or pickle
+ self.comparator = comparator
+ super(PickleType, self).__init__()
+
+ if impl:
+ self.impl = to_instance(impl)
+
+ def __reduce__(self):
+ return PickleType, (self.protocol, None, self.comparator)
+
+ def bind_processor(self, dialect):
+ impl_processor = self.impl.bind_processor(dialect)
+ dumps = self.pickler.dumps
+ protocol = self.protocol
+ if impl_processor:
+
+ def process(value):
+ if value is not None:
+ value = dumps(value, protocol)
+ return impl_processor(value)
+
+ else:
+
+ def process(value):
+ if value is not None:
+ value = dumps(value, protocol)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ impl_processor = self.impl.result_processor(dialect, coltype)
+ loads = self.pickler.loads
+ if impl_processor:
+
+ def process(value):
+ value = impl_processor(value)
+ if value is None:
+ return None
+ return loads(value)
+
+ else:
+
+ def process(value):
+ if value is None:
+ return None
+ return loads(value)
+
+ return process
+
+ def compare_values(self, x, y):
+ if self.comparator:
+ return self.comparator(x, y)
+ else:
+ return x == y
+
+
+class Boolean(Emulated, TypeEngine, SchemaType):
+
+ """A bool datatype.
+
+ :class:`.Boolean` typically uses BOOLEAN or SMALLINT on the DDL side,
+ and on the Python side deals in ``True`` or ``False``.
+
+ The :class:`.Boolean` datatype currently has two levels of assertion
+ that the values persisted are simple true/false values. For all
+ backends, only the Python values ``None``, ``True``, ``False``, ``1``
+ or ``0`` are accepted as parameter values. For those backends that
+ don't support a "native boolean" datatype, an option exists to
+    also create a CHECK constraint on the target column.
+
+ .. versionchanged:: 1.2 the :class:`.Boolean` datatype now asserts that
+ incoming Python values are already in pure boolean form.
+
+
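+    A short sketch of the CHECK-constraint form for backends without a
+    native boolean type (the constraint name is arbitrary)::
+
+        from sqlalchemy import Boolean, Column
+
+        is_active = Column(
+            "is_active",
+            Boolean(create_constraint=True, name="ck_users_is_active"),
+        )
+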
+ """
+
+ __visit_name__ = "boolean"
+ native = True
+
+ def __init__(
+ self, create_constraint=False, name=None, _create_events=True
+ ):
+ """Construct a Boolean.
+
+ :param create_constraint: defaults to False. If the boolean
+ is generated as an int/smallint, also create a CHECK constraint
+ on the table that ensures 1 or 0 as a value.
+
+ .. note:: it is strongly recommended that the CHECK constraint
+ have an explicit name in order to support schema-management
+ concerns. This can be established either by setting the
+ :paramref:`.Boolean.name` parameter or by setting up an
+ appropriate naming convention; see
+ :ref:`constraint_naming_conventions` for background.
+
+ .. versionchanged:: 1.4 - this flag now defaults to False, meaning
+           no CHECK constraint is generated for a non-native boolean
+           datatype.
+
+ :param name: if a CHECK constraint is generated, specify
+ the name of the constraint.
+
+ """
+ self.create_constraint = create_constraint
+ self.name = name
+ self._create_events = _create_events
+
+ def _should_create_constraint(self, compiler, **kw):
+ if not self._is_impl_for_variant(compiler.dialect, kw):
+ return False
+ return (
+ not compiler.dialect.supports_native_boolean
+ and compiler.dialect.non_native_boolean_check_constraint
+ )
+
+ @util.preload_module("sqlalchemy.sql.schema")
+ def _set_table(self, column, table):
+ schema = util.preloaded.sql_schema
+ if not self.create_constraint:
+ return
+
+ variant_mapping = self._variant_mapping_for_set_table(column)
+
+ e = schema.CheckConstraint(
+ type_coerce(column, self).in_([0, 1]),
+ name=_NONE_NAME if self.name is None else self.name,
+ _create_rule=util.portable_instancemethod(
+ self._should_create_constraint,
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
+ )
+ assert e.table is table
+
+ @property
+ def python_type(self):
+ return bool
+
+ _strict_bools = frozenset([None, True, False])
+
+ def _strict_as_bool(self, value):
+ if value not in self._strict_bools:
+ if not isinstance(value, int):
+ raise TypeError("Not a boolean value: %r" % (value,))
+ else:
+ raise ValueError(
+ "Value %r is not None, True, or False" % (value,)
+ )
+ return value
+
+ def literal_processor(self, dialect):
+ compiler = dialect.statement_compiler(dialect, None)
+ true = compiler.visit_true(None)
+ false = compiler.visit_false(None)
+
+ def process(value):
+ return true if self._strict_as_bool(value) else false
+
+ return process
+
+ def bind_processor(self, dialect):
+ _strict_as_bool = self._strict_as_bool
+ if dialect.supports_native_boolean:
+ _coerce = bool
+ else:
+ _coerce = int
+
+ def process(value):
+ value = _strict_as_bool(value)
+ if value is not None:
+ value = _coerce(value)
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ if dialect.supports_native_boolean:
+ return None
+ else:
+ return processors.int_to_boolean
+
+
+class _AbstractInterval(_LookupExpressionAdapter, TypeEngine):
+ @util.memoized_property
+ def _expression_adaptations(self):
+ # Based on
+ # https://www.postgresql.org/docs/current/static/functions-datetime.html.
+
+ return {
+ operators.add: {
+ Date: DateTime,
+ Interval: self.__class__,
+ DateTime: DateTime,
+ Time: Time,
+ },
+ operators.sub: {Interval: self.__class__},
+ operators.mul: {Numeric: self.__class__},
+ operators.truediv: {Numeric: self.__class__},
+ operators.div: {Numeric: self.__class__},
+ }
+
+ @property
+ def _type_affinity(self):
+ return Interval
+
+ def coerce_compared_value(self, op, value):
+ """See :meth:`.TypeEngine.coerce_compared_value` for a description."""
+ return self.impl.coerce_compared_value(op, value)
+
+
+class Interval(Emulated, _AbstractInterval, TypeDecorator):
+
+ """A type for ``datetime.timedelta()`` objects.
+
+ The Interval type deals with ``datetime.timedelta`` objects. In
+ PostgreSQL, the native ``INTERVAL`` type is used; for others, the
+ value is stored as a date which is relative to the "epoch"
+ (Jan. 1, 1970).
+
+ Note that the ``Interval`` type does not currently provide date arithmetic
+ operations on platforms which do not support interval types natively. Such
+ operations usually require transformation of both sides of the expression
+ (such as, conversion of both sides into integer epoch values first) which
+ currently is a manual procedure (such as via
+ :attr:`~sqlalchemy.sql.expression.func`).
+
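+    A brief sketch (the column name is arbitrary); bound and result values
+    are ``datetime.timedelta`` objects in either the native or the emulated
+    (epoch-relative DATETIME) case::
+
+        import datetime
+        from sqlalchemy import Column, Interval
+
+        # native=False forces the epoch-relative DATETIME representation
+        duration = Column("duration", Interval(native=False))
+
+        value = datetime.timedelta(hours=3, minutes=30)
+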
+ """
+
+ impl = DateTime
+ epoch = dt.datetime.utcfromtimestamp(0)
+ cache_ok = True
+
+    def __init__(
+        self, native=True, second_precision=None, day_precision=None
+    ):
+ """Construct an Interval object.
+
+ :param native: when True, use the actual
+ INTERVAL type provided by the database, if
+ supported (currently PostgreSQL, Oracle).
+ Otherwise, represent the interval data as
+ an epoch value regardless.
+
+ :param second_precision: For native interval types
+ which support a "fractional seconds precision" parameter,
+ i.e. Oracle and PostgreSQL
+
+ :param day_precision: for native interval types which
+ support a "day precision" parameter, i.e. Oracle.
+
+ """
+ super(Interval, self).__init__()
+ self.native = native
+ self.second_precision = second_precision
+ self.day_precision = day_precision
+
+ @property
+ def python_type(self):
+ return dt.timedelta
+
+ def adapt_to_emulated(self, impltype, **kw):
+ return _AbstractInterval.adapt(self, impltype, **kw)
+
+ def bind_processor(self, dialect):
+ impl_processor = self.impl.bind_processor(dialect)
+ epoch = self.epoch
+ if impl_processor:
+
+ def process(value):
+ if value is not None:
+ value = epoch + value
+ return impl_processor(value)
+
+ else:
+
+ def process(value):
+ if value is not None:
+ value = epoch + value
+ return value
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ impl_processor = self.impl.result_processor(dialect, coltype)
+ epoch = self.epoch
+ if impl_processor:
+
+ def process(value):
+ value = impl_processor(value)
+ if value is None:
+ return None
+ return value - epoch
+
+ else:
+
+ def process(value):
+ if value is None:
+ return None
+ return value - epoch
+
+ return process
+
+
+class JSON(Indexable, TypeEngine):
+ """Represent a SQL JSON type.
+
+ .. note:: :class:`_types.JSON`
+ is provided as a facade for vendor-specific
+ JSON types. Since it supports JSON SQL operations, it only
+ works on backends that have an actual JSON type, currently:
+
+ * PostgreSQL - see :class:`sqlalchemy.dialects.postgresql.JSON` and
+ :class:`sqlalchemy.dialects.postgresql.JSONB` for backend-specific
+ notes
+
+ * MySQL - see
+ :class:`sqlalchemy.dialects.mysql.JSON` for backend-specific notes
+
+ * SQLite as of version 3.9 - see
+ :class:`sqlalchemy.dialects.sqlite.JSON` for backend-specific notes
+
+ * Microsoft SQL Server 2016 and later - see
+ :class:`sqlalchemy.dialects.mssql.JSON` for backend-specific notes
+
+ :class:`_types.JSON` is part of the Core in support of the growing
+ popularity of native JSON datatypes.
+
+ The :class:`_types.JSON` type stores arbitrary JSON format data, e.g.::
+
+ data_table = Table('data_table', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', JSON)
+ )
+
+ with engine.connect() as conn:
+ conn.execute(
+ data_table.insert(),
+ {"data": {"key1": "value1", "key2": "value2"}}
+ )
+
+ **JSON-Specific Expression Operators**
+
+ The :class:`_types.JSON`
+ datatype provides these additional SQL operations:
+
+ * Keyed index operations::
+
+ data_table.c.data['some key']
+
+ * Integer index operations::
+
+ data_table.c.data[3]
+
+ * Path index operations::
+
+ data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
+
+ * Data casters for specific JSON element types, subsequent to an index
+ or path operation being invoked::
+
+ data_table.c.data["some key"].as_integer()
+
+ .. versionadded:: 1.3.11
+
+ Additional operations may be available from the dialect-specific versions
+ of :class:`_types.JSON`, such as
+ :class:`sqlalchemy.dialects.postgresql.JSON` and
+ :class:`sqlalchemy.dialects.postgresql.JSONB` which both offer additional
+ PostgreSQL-specific operations.
+
+ **Casting JSON Elements to Other Types**
+
+ Index operations, i.e. those invoked by calling upon the expression using
+ the Python bracket operator as in ``some_column['some key']``, return an
+    expression object whose type defaults to :class:`_types.JSON`, so that
+ further JSON-oriented instructions may be called upon the result type.
+ However, it is likely more common that an index operation is expected
+ to return a specific scalar element, such as a string or integer. In
+ order to provide access to these elements in a backend-agnostic way,
+ a series of data casters are provided:
+
+ * :meth:`.JSON.Comparator.as_string` - return the element as a string
+
+ * :meth:`.JSON.Comparator.as_boolean` - return the element as a boolean
+
+ * :meth:`.JSON.Comparator.as_float` - return the element as a float
+
+ * :meth:`.JSON.Comparator.as_integer` - return the element as an integer
+
+ These data casters are implemented by supporting dialects in order to
+ assure that comparisons to the above types will work as expected, such as::
+
+ # integer comparison
+ data_table.c.data["some_integer_key"].as_integer() == 5
+
+ # boolean comparison
+ data_table.c.data["some_boolean"].as_boolean() == True
+
+ .. versionadded:: 1.3.11 Added type-specific casters for the basic JSON
+ data element types.
+
+ .. note::
+
+ The data caster functions are new in version 1.3.11, and supersede
+ the previous documented approaches of using CAST; for reference,
+ this looked like::
+
+ from sqlalchemy import cast, type_coerce
+ from sqlalchemy import String, JSON
+ cast(
+ data_table.c.data['some_key'], String
+ ) == type_coerce(55, JSON)
+
+ The above case now works directly as::
+
+ data_table.c.data['some_key'].as_integer() == 5
+
+ For details on the previous comparison approach within the 1.3.x
+ series, see the documentation for SQLAlchemy 1.2 or the included HTML
+ files in the doc/ directory of the version's distribution.
+
+ **Detecting Changes in JSON columns when using the ORM**
+
+ The :class:`_types.JSON` type, when used with the SQLAlchemy ORM, does not
+ detect in-place mutations to the structure. In order to detect these, the
+ :mod:`sqlalchemy.ext.mutable` extension must be used. This extension will
+ allow "in-place" changes to the datastructure to produce events which
+ will be detected by the unit of work. See the example at :class:`.HSTORE`
+ for a simple example involving a dictionary.
+
+ **Support for JSON null vs. SQL NULL**
+
+ When working with NULL values, the :class:`_types.JSON` type recommends the
+ use of two specific constants in order to differentiate between a column
+ that evaluates to SQL NULL, e.g. no value, vs. the JSON-encoded string of
+ ``"null"``. To insert or select against a value that is SQL NULL, use the
+ constant :func:`.null`. This symbol may be passed as a parameter value
+ specifically when using the :class:`_types.JSON` datatype, which contains
+ special logic that interprets this symbol to mean that the column value
+ should be SQL NULL as opposed to JSON ``"null"``::
+
+ from sqlalchemy import null
+ conn.execute(table.insert(), {"json_value": null()})
+
+ To insert or select against a value that is JSON ``"null"``, use the
+ constant :attr:`_types.JSON.NULL`::
+
+ conn.execute(table.insert(), {"json_value": JSON.NULL})
+
+ The :class:`_types.JSON` type supports a flag
+ :paramref:`_types.JSON.none_as_null` which when set to True will result
+ in the Python constant ``None`` evaluating to the value of SQL
+ NULL, and when set to False results in the Python constant
+ ``None`` evaluating to the value of JSON ``"null"``. The Python
+ value ``None`` may be used in conjunction with either
+ :attr:`_types.JSON.NULL` and :func:`.null` in order to indicate NULL
+ values, but care must be taken as to the value of the
+ :paramref:`_types.JSON.none_as_null` in these cases.
+
+ **Customizing the JSON Serializer**
+
+ The JSON serializer and deserializer used by :class:`_types.JSON`
+ defaults to
+ Python's ``json.dumps`` and ``json.loads`` functions; in the case of the
+ psycopg2 dialect, psycopg2 may be using its own custom loader function.
+
+ In order to affect the serializer / deserializer, they are currently
+ configurable at the :func:`_sa.create_engine` level via the
+ :paramref:`_sa.create_engine.json_serializer` and
+ :paramref:`_sa.create_engine.json_deserializer` parameters. For example,
+ to turn off ``ensure_ascii``::
+
+ engine = create_engine(
+ "sqlite://",
+ json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False))
+
+ .. versionchanged:: 1.3.7
+
+ SQLite dialect's ``json_serializer`` and ``json_deserializer``
+ parameters renamed from ``_json_serializer`` and
+ ``_json_deserializer``.
+
+ .. seealso::
+
+ :class:`sqlalchemy.dialects.postgresql.JSON`
+
+ :class:`sqlalchemy.dialects.postgresql.JSONB`
+
+ :class:`sqlalchemy.dialects.mysql.JSON`
+
+ :class:`sqlalchemy.dialects.sqlite.JSON`
+
+ .. versionadded:: 1.1
+
+
+ """
+
+ __visit_name__ = "JSON"
+
+ hashable = False
+ NULL = util.symbol("JSON_NULL")
+ """Describe the json value of NULL.
+
+ This value is used to force the JSON value of ``"null"`` to be
+ used as the value. A value of Python ``None`` will be recognized
+ either as SQL NULL or JSON ``"null"``, based on the setting
+ of the :paramref:`_types.JSON.none_as_null` flag; the
+ :attr:`_types.JSON.NULL`
+ constant can be used to always resolve to JSON ``"null"`` regardless
+ of this setting. This is in contrast to the :func:`_expression.null`
+ construct,
+ which always resolves to SQL NULL. E.g.::
+
+ from sqlalchemy import null
+ from sqlalchemy.dialects.postgresql import JSON
+
+ # will *always* insert SQL NULL
+ obj1 = MyObject(json_value=null())
+
+ # will *always* insert JSON string "null"
+ obj2 = MyObject(json_value=JSON.NULL)
+
+ session.add_all([obj1, obj2])
+ session.commit()
+
+ In order to set JSON NULL as a default value for a column, the most
+ transparent method is to use :func:`_expression.text`::
+
+ Table(
+ 'my_table', metadata,
+ Column('json_data', JSON, default=text("'null'"))
+ )
+
+ While it is possible to use :attr:`_types.JSON.NULL` in this context, the
+ :attr:`_types.JSON.NULL` value will be returned as the value of the
+ column,
+ which in the context of the ORM or other repurposing of the default
+ value, may not be desirable. Using a SQL expression means the value
+ will be re-fetched from the database within the context of retrieving
+ generated defaults.
+
+
+ """
+
+ def __init__(self, none_as_null=False):
+ """Construct a :class:`_types.JSON` type.
+
+ :param none_as_null=False: if True, persist the value ``None`` as a
+ SQL NULL value, not the JSON encoding of ``null``. Note that when this
+ flag is False, the :func:`.null` construct can still be used to
+ persist a NULL value, which may be passed directly as a parameter
+ value that is specially interpreted by the :class:`_types.JSON` type
+ as SQL NULL::
+
+ from sqlalchemy import null
+ conn.execute(table.insert(), {"data": null()})
+
+ .. note::
+
+ :paramref:`_types.JSON.none_as_null` does **not** apply to the
+ values passed to :paramref:`_schema.Column.default` and
+ :paramref:`_schema.Column.server_default`; a value of ``None``
+ passed for these parameters means "no default present".
+
+ Additionally, when used in SQL comparison expressions, the
+ Python value ``None`` continues to refer to SQL null, and not
+ JSON NULL. The :paramref:`_types.JSON.none_as_null` flag refers
+ explicitly to the **persistence** of the value within an
+ INSERT or UPDATE statement. The :attr:`_types.JSON.NULL`
+ value should be used for SQL expressions that wish to compare to
+ JSON null.
+
+ .. seealso::
+
+ :attr:`.types.JSON.NULL`
+
+ """
+ self.none_as_null = none_as_null
+
+ class JSONElementType(TypeEngine):
+ """Common function for index / path elements in a JSON expression."""
+
+ _integer = Integer()
+ _string = String()
+
+ def string_bind_processor(self, dialect):
+ return self._string._cached_bind_processor(dialect)
+
+ def string_literal_processor(self, dialect):
+ return self._string._cached_literal_processor(dialect)
+
+ def bind_processor(self, dialect):
+ int_processor = self._integer._cached_bind_processor(dialect)
+ string_processor = self.string_bind_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ def literal_processor(self, dialect):
+ int_processor = self._integer._cached_literal_processor(dialect)
+ string_processor = self.string_literal_processor(dialect)
+
+ def process(value):
+ if int_processor and isinstance(value, int):
+ value = int_processor(value)
+ elif string_processor and isinstance(value, util.string_types):
+ value = string_processor(value)
+ return value
+
+ return process
+
+ class JSONIndexType(JSONElementType):
+ """Placeholder for the datatype of a JSON index value.
+
+ This allows execution-time processing of JSON index values
+ for special syntaxes.
+
+ """
+
+ class JSONIntIndexType(JSONIndexType):
+ """Placeholder for the datatype of a JSON index value.
+
+ This allows execution-time processing of JSON index values
+ for special syntaxes.
+
+ """
+
+ class JSONStrIndexType(JSONIndexType):
+ """Placeholder for the datatype of a JSON index value.
+
+ This allows execution-time processing of JSON index values
+ for special syntaxes.
+
+ """
+
+ class JSONPathType(JSONElementType):
+ """Placeholder type for JSON path operations.
+
+ This allows execution-time processing of a path-based
+ index value into a specific SQL syntax.
+
+ """
+
+ class Comparator(Indexable.Comparator, Concatenable.Comparator):
+ """Define comparison operations for :class:`_types.JSON`."""
+
+ def _setup_getitem(self, index):
+ if not isinstance(index, util.string_types) and isinstance(
+ index, compat.collections_abc.Sequence
+ ):
+ index = coercions.expect(
+ roles.BinaryElementRole,
+ index,
+ expr=self.expr,
+ operator=operators.json_path_getitem_op,
+ bindparam_type=JSON.JSONPathType,
+ )
+
+ operator = operators.json_path_getitem_op
+ else:
+ index = coercions.expect(
+ roles.BinaryElementRole,
+ index,
+ expr=self.expr,
+ operator=operators.json_getitem_op,
+ bindparam_type=JSON.JSONIntIndexType
+ if isinstance(index, int)
+ else JSON.JSONStrIndexType,
+ )
+ operator = operators.json_getitem_op
+
+ return operator, index, self.type
+
+ def as_boolean(self):
+ """Cast an indexed value as boolean.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_boolean()
+ ).where(
+ mytable.c.json_column['some_data'].as_boolean() == True
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(Boolean(), "as_boolean")
+
+ def as_string(self):
+ """Cast an indexed value as string.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_string()
+ ).where(
+ mytable.c.json_column['some_data'].as_string() ==
+ 'some string'
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(String(), "as_string")
+
+ def as_integer(self):
+ """Cast an indexed value as integer.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_integer()
+ ).where(
+ mytable.c.json_column['some_data'].as_integer() == 5
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(Integer(), "as_integer")
+
+ def as_float(self):
+ """Cast an indexed value as float.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_float()
+ ).where(
+ mytable.c.json_column['some_data'].as_float() == 29.75
+ )
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self._binary_w_type(Float(), "as_float")
+
+ def as_numeric(self, precision, scale, asdecimal=True):
+ """Cast an indexed value as numeric/decimal.
+
+ e.g.::
+
+ stmt = select(
+ mytable.c.json_column['some_data'].as_numeric(10, 6)
+ ).where(
+ mytable.c.json_column['some_data'].as_numeric(10, 6) == 29.75
+ )
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ return self._binary_w_type(
+ Numeric(precision, scale, asdecimal=asdecimal), "as_numeric"
+ )
+
+ def as_json(self):
+ """Cast an indexed value as JSON.
+
+ e.g.::
+
+ stmt = select(mytable.c.json_column['some_data'].as_json())
+
+ This is typically the default behavior of indexed elements in any
+ case.
+
+ Note that comparison of full JSON structures may not be
+ supported by all backends.
+
+ .. versionadded:: 1.3.11
+
+ """
+ return self.expr
+
+ def _binary_w_type(self, typ, method_name):
+ if not isinstance(
+ self.expr, elements.BinaryExpression
+ ) or self.expr.operator not in (
+ operators.json_getitem_op,
+ operators.json_path_getitem_op,
+ ):
+ raise exc.InvalidRequestError(
+ "The JSON cast operator JSON.%s() only works with a JSON "
+ "index expression e.g. col['q'].%s()"
+ % (method_name, method_name)
+ )
+ expr = self.expr._clone()
+ expr.type = typ
+ return expr
+
+ comparator_factory = Comparator
+
+ @property
+ def python_type(self):
+ return dict
+
+ @property
+ def should_evaluate_none(self):
+ """Alias of :attr:`_types.JSON.none_as_null`"""
+ return not self.none_as_null
+
+ @should_evaluate_none.setter
+ def should_evaluate_none(self, value):
+ self.none_as_null = not value
+
+ @util.memoized_property
+ def _str_impl(self):
+ return String(_expect_unicode=True)
+
+ def bind_processor(self, dialect):
+ string_process = self._str_impl.bind_processor(dialect)
+
+ json_serializer = dialect._json_serializer or json.dumps
+
+ def process(value):
+ if value is self.NULL:
+ value = None
+ elif isinstance(value, elements.Null) or (
+ value is None and self.none_as_null
+ ):
+ return None
+
+ serialized = json_serializer(value)
+ if string_process:
+ serialized = string_process(serialized)
+ return serialized
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ string_process = self._str_impl.result_processor(dialect, coltype)
+ json_deserializer = dialect._json_deserializer or json.loads
+
+ def process(value):
+ if value is None:
+ return None
+ if string_process:
+ value = string_process(value)
+ return json_deserializer(value)
+
+ return process
+
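+ # Note on the (de)serializer hooks used above: ``dialect._json_serializer``
+ # and ``dialect._json_deserializer`` default to the stdlib ``json`` module.
+ # On dialects that expose them, they can typically be supplied at engine
+ # creation time (sketch only, not part of this change):
+ #
+ #     engine = create_engine(
+ #         "postgresql://scott:tiger@localhost/test",
+ #         json_serializer=lambda obj: json.dumps(obj, default=str),
+ #         json_deserializer=json.loads,
+ #     )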
+
+class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
+ """Represent a SQL Array type.
+
+ .. note:: This type serves as the basis for all ARRAY operations.
+ However, currently **only the PostgreSQL backend has support for SQL
+ arrays in SQLAlchemy**. It is recommended to use the PostgreSQL-specific
+ :class:`sqlalchemy.dialects.postgresql.ARRAY` type directly when using
+ ARRAY types with PostgreSQL, as it provides additional operators
+ specific to that backend.
+
+ :class:`_types.ARRAY` is part of the Core in support of various SQL
+ standard functions such as :class:`_functions.array_agg`
+ which explicitly involve
+ arrays; however, with the exception of the PostgreSQL backend and possibly
+ some third-party dialects, no other SQLAlchemy built-in dialect has support
+ for this type.
+
+ An :class:`_types.ARRAY` type is constructed given the "type"
+ of element::
+
+ mytable = Table("mytable", metadata,
+ Column("data", ARRAY(Integer))
+ )
+
+ The above type represents an N-dimensional array,
+ meaning a supporting backend such as PostgreSQL will interpret values
+ with any number of dimensions automatically. To produce an INSERT
+ construct that passes in a 1-dimensional array of integers::
+
+ connection.execute(
+ mytable.insert(),
+ {"data": [1,2,3]}
+ )
+
+ The :class:`_types.ARRAY` type can be constructed given a fixed number
+ of dimensions::
+
+ mytable = Table("mytable", metadata,
+ Column("data", ARRAY(Integer, dimensions=2))
+ )
+
+ Sending a number of dimensions is optional, but recommended if the
+ datatype is to represent arrays of more than one dimension. This number
+ is used:
+
+ * When emitting the type declaration itself to the database, e.g.
+ ``INTEGER[][]``
+
+ * When translating Python values to database values, and vice versa, e.g.
+ an ARRAY of :class:`.Unicode` objects uses this number to efficiently
+ access the string values inside of array structures without resorting
+ to per-row type inspection
+
+ * When used with the Python ``getitem`` accessor, the number of dimensions
+ serves to define the kind of type that the ``[]`` operator should
+ return, e.g. for an ARRAY of INTEGER with two dimensions::
+
+ >>> expr = table.c.column[5] # returns ARRAY(Integer, dimensions=1)
+ >>> expr = expr[6] # returns Integer
+
+ For 1-dimensional arrays, an :class:`_types.ARRAY` instance with no
+ dimension parameter will generally assume single-dimensional behaviors.
+
+ SQL expressions of type :class:`_types.ARRAY` have support for "index" and
+ "slice" behavior. The Python ``[]`` operator works normally here, given
+ integer indexes or slices. Arrays default to 1-based indexing.
+ The operator produces binary expression
+ constructs which will produce the appropriate SQL, both for
+ SELECT statements::
+
+ select(mytable.c.data[5], mytable.c.data[2:7])
+
+ as well as UPDATE statements when the :meth:`_expression.Update.values`
+ method
+ is used::
+
+ mytable.update().values({
+ mytable.c.data[5]: 7,
+ mytable.c.data[2:7]: [1, 2, 3]
+ })
+
+ The :class:`_types.ARRAY` type also provides for the operators
+ :meth:`.types.ARRAY.Comparator.any` and
+ :meth:`.types.ARRAY.Comparator.all`. The PostgreSQL-specific version of
+ :class:`_types.ARRAY` also provides additional operators.
+
+ .. versionadded:: 1.1.0
+
+ .. seealso::
+
+ :class:`sqlalchemy.dialects.postgresql.ARRAY`
+
+ """
+
+ __visit_name__ = "ARRAY"
+
+ _is_array = True
+
+ zero_indexes = False
+ """If True, Python zero-based indexes should be interpreted as one-based
+ on the SQL expression side."""
+
+ class Comparator(Indexable.Comparator, Concatenable.Comparator):
+
+ """Define comparison operations for :class:`_types.ARRAY`.
+
+ More operators are available on the dialect-specific form
+ of this type. See :class:`.postgresql.ARRAY.Comparator`.
+
+ """
+
+ def _setup_getitem(self, index):
+ if isinstance(index, slice):
+ return_type = self.type
+ if self.type.zero_indexes:
+ index = slice(index.start + 1, index.stop + 1, index.step)
+ slice_ = Slice(
+ index.start, index.stop, index.step, _name=self.expr.key
+ )
+ return operators.getitem, slice_, return_type
+ else:
+ if self.type.zero_indexes:
+ index += 1
+ if self.type.dimensions is None or self.type.dimensions == 1:
+ return_type = self.type.item_type
+ else:
+ adapt_kw = {"dimensions": self.type.dimensions - 1}
+ return_type = self.type.adapt(
+ self.type.__class__, **adapt_kw
+ )
+
+ return operators.getitem, index, return_type
+
+ def contains(self, *arg, **kw):
+ raise NotImplementedError(
+ "ARRAY.contains() not implemented for the base "
+ "ARRAY type; please use the dialect-specific ARRAY type"
+ )
+
+ @util.preload_module("sqlalchemy.sql.elements")
+ def any(self, other, operator=None):
+ """Return ``other operator ANY (array)`` clause.
+
+ .. note:: This method is an :class:`_types.ARRAY` - specific
+ construct that is now superseded by the :func:`_sql.any_`
+ function, which features a different calling style. The
+ :func:`_sql.any_` function is also mirrored at the method level
+ via the :meth:`_sql.ColumnOperators.any_` method.
+
+ Usage of array-specific :meth:`_types.ARRAY.Comparator.any`
+ is as follows::
+
+ from sqlalchemy.sql import operators
+
+ conn.execute(
+ select(table.c.data).where(
+ table.c.data.any(7, operator=operators.lt)
+ )
+ )
+
+ :param other: expression to be compared
+ :param operator: an operator object from the
+ :mod:`sqlalchemy.sql.operators`
+ package, defaults to :func:`.operators.eq`.
+
+ .. seealso::
+
+ :func:`_expression.any_`
+
+ :meth:`.types.ARRAY.Comparator.all`
+
+ """
+ elements = util.preloaded.sql_elements
+ operator = operator if operator else operators.eq
+
+ arr_type = self.type
+
+ # send plain BinaryExpression so that negate remains at None,
+ # leading to NOT expr for negation.
+ return elements.BinaryExpression(
+ coercions.expect(
+ roles.BinaryElementRole,
+ element=other,
+ operator=operator,
+ expr=self.expr,
+ bindparam_type=arr_type.item_type,
+ ),
+ elements.CollectionAggregate._create_any(self.expr),
+ operator,
+ )
+
+ @util.preload_module("sqlalchemy.sql.elements")
+ def all(self, other, operator=None):
+ """Return ``other operator ALL (array)`` clause.
+
+ .. note:: This method is an :class:`_types.ARRAY` - specific
+ construct that is now superseded by the :func:`_sql.all_`
+ function, which features a different calling style. The
+ :func:`_sql.all_` function is also mirrored at the method level
+ via the :meth:`_sql.ColumnOperators.all_` method.
+
+ Usage of array-specific :meth:`_types.ARRAY.Comparator.all`
+ is as follows::
+
+ from sqlalchemy.sql import operators
+
+ conn.execute(
+ select(table.c.data).where(
+ table.c.data.all(7, operator=operators.lt)
+ )
+ )
+
+ :param other: expression to be compared
+ :param operator: an operator object from the
+ :mod:`sqlalchemy.sql.operators`
+ package, defaults to :func:`.operators.eq`.
+
+ .. seealso::
+
+ :func:`_expression.all_`
+
+ :meth:`.types.ARRAY.Comparator.any`
+
+ """
+ elements = util.preloaded.sql_elements
+ operator = operator if operator else operators.eq
+
+ arr_type = self.type
+
+ # send plain BinaryExpression so that negate remains at None,
+ # leading to NOT expr for negation.
+ return elements.BinaryExpression(
+ coercions.expect(
+ roles.BinaryElementRole,
+ element=other,
+ operator=operator,
+ expr=self.expr,
+ bindparam_type=arr_type.item_type,
+ ),
+ elements.CollectionAggregate._create_all(self.expr),
+ operator,
+ )
+
+ comparator_factory = Comparator
+
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
+ """Construct an :class:`_types.ARRAY`.
+
+ E.g.::
+
+ Column('myarray', ARRAY(Integer))
+
+ Arguments are:
+
+ :param item_type: The data type of items of this array. Note that
+ dimensionality is irrelevant here, so multi-dimensional arrays like
+ ``INTEGER[][]`` are constructed as ``ARRAY(Integer)``, not as
+ ``ARRAY(ARRAY(Integer))`` or such.
+
+ :param as_tuple=False: Specify whether return results
+ should be converted to tuples from lists. This parameter is
+ not generally needed as a Python list corresponds well
+ to a SQL array.
+
+ :param dimensions: if non-None, the ARRAY will assume a fixed
+ number of dimensions. This impacts how the array is declared
+ on the database, how it goes about interpreting Python and
+ result values, as well as how expression behavior in conjunction
+ with the "getitem" operator works. See the description at
+ :class:`_types.ARRAY` for additional detail.
+
+ :param zero_indexes=False: when True, index values will be converted
+ between Python zero-based and SQL one-based indexes, e.g.
+ a value of one will be added to all index values before passing
+ to the database.
+
+ """
+ if isinstance(item_type, ARRAY):
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
+ if isinstance(item_type, type):
+ item_type = item_type()
+ self.item_type = item_type
+ self.as_tuple = as_tuple
+ self.dimensions = dimensions
+ self.zero_indexes = zero_indexes
+
+ @property
+ def hashable(self):
+ return self.as_tuple
+
+ @property
+ def python_type(self):
+ return list
+
+ def compare_values(self, x, y):
+ return x == y
+
+ def _set_parent(self, column, outer=False, **kw):
+ """Support SchemaEventTarget"""
+
+ if not outer and isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent(column, **kw)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True)
+
+ if isinstance(self.item_type, SchemaEventTarget):
+ self.item_type._set_parent_with_dispatch(parent)
+
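+ # Sketch of the ``zero_indexes`` behavior described above (hypothetical
+ # ``mytable``): index values are shifted by one before being rendered, so
+ # Python-style zero-based access addresses the first SQL array element:
+ #
+ #     Column("data", ARRAY(Integer, zero_indexes=True))
+ #     mytable.c.data[0]   # renders as data[1] on the SQL side
+ #
+ #     Column("data", ARRAY(Integer))
+ #     mytable.c.data[1]   # arrays are 1-based by default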
+
+class TupleType(TypeEngine):
+ """represent the composite type of a Tuple."""
+
+ _is_tuple_type = True
+
+ def __init__(self, *types):
+ self._fully_typed = NULLTYPE not in types
+ self.types = [
+ item_type() if isinstance(item_type, type) else item_type
+ for item_type in types
+ ]
+
+ def _resolve_values_to_types(self, value):
+ if self._fully_typed:
+ return self
+ else:
+ return TupleType(
+ *[
+ _resolve_value_to_type(elem) if typ is NULLTYPE else typ
+ for typ, elem in zip(self.types, value)
+ ]
+ )
+
+ def result_processor(self, dialect, coltype):
+ raise NotImplementedError(
+ "The tuple type does not support being fetched "
+ "as a column in a result row."
+ )
+
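+ # The composite type above is the datatype carried by the ``tuple_()``
+ # construct; a common use is a composite IN comparison (sketch, hypothetical
+ # ``table``):
+ #
+ #     from sqlalchemy import tuple_
+ #     stmt = select(table).where(
+ #         tuple_(table.c.a, table.c.b).in_([(1, "x"), (2, "y")])
+ #     )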
+
+class REAL(Float):
+
+ """The SQL REAL type."""
+
+ __visit_name__ = "REAL"
+
+
+class FLOAT(Float):
+
+ """The SQL FLOAT type."""
+
+ __visit_name__ = "FLOAT"
+
+
+class NUMERIC(Numeric):
+
+ """The SQL NUMERIC type."""
+
+ __visit_name__ = "NUMERIC"
+
+
+class DECIMAL(Numeric):
+
+ """The SQL DECIMAL type."""
+
+ __visit_name__ = "DECIMAL"
+
+
+class INTEGER(Integer):
+
+ """The SQL INT or INTEGER type."""
+
+ __visit_name__ = "INTEGER"
+
+
+INT = INTEGER
+
+
+class SMALLINT(SmallInteger):
+
+ """The SQL SMALLINT type."""
+
+ __visit_name__ = "SMALLINT"
+
+
+class BIGINT(BigInteger):
+
+ """The SQL BIGINT type."""
+
+ __visit_name__ = "BIGINT"
+
+
+class TIMESTAMP(DateTime):
+
+ """The SQL TIMESTAMP type.
+
+ :class:`_types.TIMESTAMP` datatypes have support for timezone
+ storage on some backends, such as PostgreSQL and Oracle. Use the
+ :paramref:`~types.TIMESTAMP.timezone` argument in order to enable
+ "TIMESTAMP WITH TIMEZONE" for these backends.
+
+ """
+
+ __visit_name__ = "TIMESTAMP"
+
+ def __init__(self, timezone=False):
+ """Construct a new :class:`_types.TIMESTAMP`.
+
+ :param timezone: boolean. Indicates that the TIMESTAMP type should
+ enable timezone support, if available on the target database.
+ On a per-dialect basis, this is similar to "TIMESTAMP WITH TIMEZONE".
+ If the target database does not support timezones, this flag is
+ ignored.
+
+
+ """
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.TIMESTAMP
+
+
+class DATETIME(DateTime):
+
+ """The SQL DATETIME type."""
+
+ __visit_name__ = "DATETIME"
+
+
+class DATE(Date):
+
+ """The SQL DATE type."""
+
+ __visit_name__ = "DATE"
+
+
+class TIME(Time):
+
+ """The SQL TIME type."""
+
+ __visit_name__ = "TIME"
+
+
+class TEXT(Text):
+
+ """The SQL TEXT type."""
+
+ __visit_name__ = "TEXT"
+
+
+class CLOB(Text):
+
+ """The CLOB type.
+
+ This type is found in Oracle and Informix.
+ """
+
+ __visit_name__ = "CLOB"
+
+
+class VARCHAR(String):
+
+ """The SQL VARCHAR type."""
+
+ __visit_name__ = "VARCHAR"
+
+
+class NVARCHAR(Unicode):
+
+ """The SQL NVARCHAR type."""
+
+ __visit_name__ = "NVARCHAR"
+
+
+class CHAR(String):
+
+ """The SQL CHAR type."""
+
+ __visit_name__ = "CHAR"
+
+
+class NCHAR(Unicode):
+
+ """The SQL NCHAR type."""
+
+ __visit_name__ = "NCHAR"
+
+
+class BLOB(LargeBinary):
+
+ """The SQL BLOB type."""
+
+ __visit_name__ = "BLOB"
+
+
+class BINARY(_Binary):
+
+ """The SQL BINARY type."""
+
+ __visit_name__ = "BINARY"
+
+
+class VARBINARY(_Binary):
+
+ """The SQL VARBINARY type."""
+
+ __visit_name__ = "VARBINARY"
+
+
+class BOOLEAN(Boolean):
+
+ """The SQL BOOLEAN type."""
+
+ __visit_name__ = "BOOLEAN"
+
+
+class NullType(TypeEngine):
+
+ """An unknown type.
+
+ :class:`.NullType` is used as a default type for those cases where
+ a type cannot be determined, including:
+
+ * During table reflection, when the type of a column is not recognized
+ by the :class:`.Dialect`
+ * When constructing SQL expressions using plain Python objects of
+ unknown types (e.g. ``somecolumn == my_special_object``)
+ * When a new :class:`_schema.Column` is created,
+ and the given type is passed
+ as ``None`` or is not passed at all.
+
+ The :class:`.NullType` can be used within SQL expression invocation
+ without issue; it just has no behavior either at the expression
+ construction level or at the bind-parameter/result processing level.
+ :class:`.NullType` will result in a :exc:`.CompileError` if the compiler
+ is asked to render the type itself, such as if it is used in a
+ :func:`.cast` operation or within a schema creation operation such as that
+ invoked by :meth:`_schema.MetaData.create_all` or the
+ :class:`.CreateTable`
+ construct.
+
+ """
+
+ __visit_name__ = "null"
+
+ _isnull = True
+
+ def literal_processor(self, dialect):
+ def process(value):
+ raise exc.CompileError(
+ "Don't know how to render literal SQL value: %r" % (value,)
+ )
+
+ return process
+
+ class Comparator(TypeEngine.Comparator):
+ def _adapt_expression(self, op, other_comparator):
+ if isinstance(
+ other_comparator, NullType.Comparator
+ ) or not operators.is_commutative(op):
+ return op, self.expr.type
+ else:
+ return other_comparator._adapt_expression(op, self)
+
+ comparator_factory = Comparator
+
+
+class TableValueType(HasCacheKey, TypeEngine):
+ """Refers to a table value type."""
+
+ _is_table_value = True
+
+ _traverse_internals = [
+ ("_elements", InternalTraversal.dp_clauseelement_list),
+ ]
+
+ def __init__(self, *elements):
+ self._elements = [
+ coercions.expect(roles.StrAsPlainColumnRole, elem)
+ for elem in elements
+ ]
+
+
+class MatchType(Boolean):
+ """Refers to the return type of the MATCH operator.
+
+ As the :meth:`.ColumnOperators.match` method is probably the most open-ended
+ operator in generic SQLAlchemy Core, we can't assume the return type
+ at SQL evaluation time, as MySQL returns a floating point, not a boolean,
+ and other backends might do something different. So this type
+ acts as a placeholder, currently subclassing :class:`.Boolean`.
+ The type allows dialects to inject result-processing functionality
+ if needed, and on MySQL will return floating-point values.
+
+ .. versionadded:: 1.0.0
+
+ """
+
+
+NULLTYPE = NullType()
+BOOLEANTYPE = Boolean()
+STRINGTYPE = String()
+INTEGERTYPE = Integer()
+NUMERICTYPE = Numeric()
+MATCHTYPE = MatchType()
+TABLEVALUE = TableValueType()
+DATETIME_TIMEZONE = DateTime(timezone=True)
+TIME_TIMEZONE = Time(timezone=True)
+
+_type_map = {
+ int: Integer(),
+ float: Float(),
+ bool: BOOLEANTYPE,
+ decimal.Decimal: Numeric(),
+ dt.date: Date(),
+ dt.datetime: DateTime(),
+ dt.time: Time(),
+ dt.timedelta: Interval(),
+ util.NoneType: NULLTYPE,
+}
+
+if util.py3k:
+ _type_map[bytes] = LargeBinary() # noqa
+ _type_map[str] = Unicode()
+else:
+ _type_map[unicode] = Unicode() # noqa
+ _type_map[str] = String()
+
+
+_type_map_get = _type_map.get
+
+
+def _resolve_value_to_type(value):
+ _result_type = _type_map_get(type(value), False)
+ if _result_type is False:
+ # use inspect() to detect SQLAlchemy built-in
+ # objects.
+ insp = inspection.inspect(value, False)
+ if (
+ insp is not None
+ and
+ # foil mock.Mock() and other impostors by ensuring
+ # the inspection target itself self-inspects
+ insp.__class__ in inspection._registrars
+ ):
+ raise exc.ArgumentError(
+ "Object %r is not legal as a SQL literal value" % (value,)
+ )
+ return NULLTYPE
+ else:
+ return _result_type._resolve_for_literal(value)
+
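+ # Sketch of the resolution above: plain Python literals map through
+ # ``_type_map``, and unrecognized objects fall back to ``NULLTYPE``:
+ #
+ #     _resolve_value_to_type(5)         # -> an Integer instance
+ #     _resolve_value_to_type("text")    # -> a String / Unicode instance
+ #     _resolve_value_to_type(object())  # -> NULLTYPE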
+
+# back-assign to type_api
+type_api.BOOLEANTYPE = BOOLEANTYPE
+type_api.STRINGTYPE = STRINGTYPE
+type_api.INTEGERTYPE = INTEGERTYPE
+type_api.NULLTYPE = NULLTYPE
+type_api.NUMERICTYPE = NUMERICTYPE
+type_api.MATCHTYPE = MATCHTYPE
+type_api.INDEXABLE = Indexable
+type_api.TABLEVALUE = TABLEVALUE
+type_api._resolve_value_to_type = _resolve_value_to_type
+TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
new file mode 100644
index 0000000..9da61ab
--- /dev/null
+++ b/lib/sqlalchemy/sql/traversals.py
@@ -0,0 +1,1559 @@
+from collections import deque
+from collections import namedtuple
+import itertools
+import operator
+
+from . import operators
+from .visitors import ExtendedInternalTraversal
+from .visitors import InternalTraversal
+from .. import util
+from ..inspection import inspect
+from ..util import collections_abc
+from ..util import HasMemoized
+from ..util import py37
+
+SKIP_TRAVERSE = util.symbol("skip_traverse")
+COMPARE_FAILED = False
+COMPARE_SUCCEEDED = True
+NO_CACHE = util.symbol("no_cache")
+CACHE_IN_PLACE = util.symbol("cache_in_place")
+CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
+STATIC_CACHE_KEY = util.symbol("static_cache_key")
+PROPAGATE_ATTRS = util.symbol("propagate_attrs")
+ANON_NAME = util.symbol("anon_name")
+
+
+def compare(obj1, obj2, **kw):
+ if kw.get("use_proxies", False):
+ strategy = ColIdentityComparatorStrategy()
+ else:
+ strategy = TraversalComparatorStrategy()
+
+ return strategy.compare(obj1, obj2, **kw)
+
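+ # Sketch of the entry point above (hypothetical table ``t``): structural
+ # comparison of two expression trees; passing ``use_proxies=True`` selects
+ # the proxy-aware ColIdentityComparatorStrategy instead:
+ #
+ #     compare(t.c.x == 5, t.c.x == 5)    # -> True  (same structure)
+ #     compare(t.c.x == 5, t.c.y == 5)    # -> False (different column)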
+
+def _preconfigure_traversals(target_hierarchy):
+ for cls in util.walk_subclasses(target_hierarchy):
+ if hasattr(cls, "_traverse_internals"):
+ cls._generate_cache_attrs()
+ _copy_internals.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_copy_internals_traversal",
+ )
+ _get_children.generate_dispatch(
+ cls,
+ cls._traverse_internals,
+ "_generated_get_children_traversal",
+ )
+
+
+class HasCacheKey(object):
+ """Mixin for objects which can produce a cache key.
+
+ .. seealso::
+
+ :class:`.CacheKey`
+
+ :ref:`sql_caching`
+
+ """
+
+ _cache_key_traversal = NO_CACHE
+
+ _is_has_cache_key = True
+
+ _hierarchy_supports_caching = True
+ """private attribute which may be set to False to prevent the
+ inherit_cache warning from being emitted for a hierarchy of subclasses.
+
+ Currently applies to the DDLElement hierarchy which does not implement
+ caching.
+
+ """
+
+ inherit_cache = None
+ """Indicate if this :class:`.HasCacheKey` instance should make use of the
+ cache key generation scheme used by its immediate superclass.
+
+ The attribute defaults to ``None``, which indicates that a construct has
+ not yet taken into account whether or not it is appropriate for it to
+ participate in caching; this is functionally equivalent to setting the
+ value to ``False``, except that a warning is also emitted.
+
+ This flag can be set to ``True`` on a particular class if the SQL that
+ corresponds to the object does not change based on attributes which
+ are local to this class and not its superclass.
+
+ .. seealso::
+
+ :ref:`compilerext_caching` - General guidelines for setting the
+ :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user
+ defined SQL constructs.
+
+ """
+
+ __slots__ = ()
+
+ @classmethod
+ def _generate_cache_attrs(cls):
+ """generate cache key dispatcher for a new class.
+
+ This sets the _generated_cache_key_traversal attribute once called
+ so should only be called once per class.
+
+ """
+ inherit_cache = cls.__dict__.get("inherit_cache", None)
+ inherit = bool(inherit_cache)
+
+ if inherit:
+ _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
+ if _cache_key_traversal is None:
+ try:
+ _cache_key_traversal = cls._traverse_internals
+ except AttributeError:
+ cls._generated_cache_key_traversal = NO_CACHE
+ return NO_CACHE
+
+ # TODO: wouldn't we instead get this from our superclass?
+ # also, our superclass may not have this yet, but in any case,
+ # we'd generate for the superclass that has it. this is a little
+ # more complicated, so for the moment this is a little less
+ # efficient on startup but simpler.
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+ else:
+ _cache_key_traversal = cls.__dict__.get(
+ "_cache_key_traversal", None
+ )
+ if _cache_key_traversal is None:
+ _cache_key_traversal = cls.__dict__.get(
+ "_traverse_internals", None
+ )
+ if _cache_key_traversal is None:
+ cls._generated_cache_key_traversal = NO_CACHE
+ if (
+ inherit_cache is None
+ and cls._hierarchy_supports_caching
+ ):
+ util.warn(
+ "Class %s will not make use of SQL compilation "
+ "caching as it does not set the 'inherit_cache' "
+ "attribute to ``True``. This can have "
+ "significant performance implications including "
+ "some performance degradations in comparison to "
+ "prior SQLAlchemy versions. Set this attribute "
+ "to True if this object can make use of the cache "
+ "key generated by the superclass. Alternatively, "
+ "this attribute may be set to False which will "
+ "disable this warning." % (cls.__name__),
+ code="cprf",
+ )
+ return NO_CACHE
+
+ return _cache_key_traversal_visitor.generate_dispatch(
+ cls, _cache_key_traversal, "_generated_cache_key_traversal"
+ )
+
+ @util.preload_module("sqlalchemy.sql.elements")
+ def _gen_cache_key(self, anon_map, bindparams):
+ """return an optional cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the structures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ If a structure cannot produce a useful cache key, the NO_CACHE
+ symbol should be added to the anon_map and the method should
+ return None.
+
+ """
+
+ idself = id(self)
+ cls = self.__class__
+
+ if idself in anon_map:
+ return (anon_map[idself], cls)
+ else:
+ # inline of
+ # id_ = anon_map[idself]
+ anon_map[idself] = id_ = str(anon_map.index)
+ anon_map.index += 1
+
+ try:
+ dispatcher = cls.__dict__["_generated_cache_key_traversal"]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = cls._generate_cache_attrs()
+
+ if dispatcher is NO_CACHE:
+ anon_map[NO_CACHE] = True
+ return None
+
+ result = (id_, cls)
+
+ # inline of _cache_key_traversal_visitor.run_generated_dispatch()
+
+ for attrname, obj, meth in dispatcher(
+ self, _cache_key_traversal_visitor
+ ):
+ if obj is not None:
+ # TODO: see if C code can help here as Python lacks an
+ # efficient switch construct
+
+ if meth is STATIC_CACHE_KEY:
+ sck = obj._static_cache_key
+ if sck is NO_CACHE:
+ anon_map[NO_CACHE] = True
+ return None
+ result += (attrname, sck)
+ elif meth is ANON_NAME:
+ elements = util.preloaded.sql_elements
+ if isinstance(obj, elements._anonymous_label):
+ obj = obj.apply_map(anon_map)
+ result += (attrname, obj)
+ elif meth is CALL_GEN_CACHE_KEY:
+ result += (
+ attrname,
+ obj._gen_cache_key(anon_map, bindparams),
+ )
+
+ # remaining cache functions are against
+ # Python tuples, dicts, lists, etc. so we can skip
+ # if they are empty
+ elif obj:
+ if meth is CACHE_IN_PLACE:
+ result += (attrname, obj)
+ elif meth is PROPAGATE_ATTRS:
+ result += (
+ attrname,
+ obj["compile_state_plugin"],
+ obj["plugin_subject"]._gen_cache_key(
+ anon_map, bindparams
+ )
+ if obj["plugin_subject"]
+ else None,
+ )
+ elif meth is InternalTraversal.dp_annotations_key:
+ # obj here is the _annotations dict. however, we
+ # want to use the memoized cache key version of it. for
+ # Columns, this should be long lived. For select()
+ # statements, not so much, but they usually won't have
+ # annotations.
+ result += self._annotations_cache_key
+ elif (
+ meth is InternalTraversal.dp_clauseelement_list
+ or meth is InternalTraversal.dp_clauseelement_tuple
+ or meth
+ is InternalTraversal.dp_memoized_select_entities
+ ):
+ result += (
+ attrname,
+ tuple(
+ [
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in obj
+ ]
+ ),
+ )
+ else:
+ result += meth(
+ attrname, obj, self, anon_map, bindparams
+ )
+ return result
+
+ def _generate_cache_key(self):
+ """return a cache key.
+
+ The cache key is a tuple which can contain any series of
+ objects that are hashable and also identifies
+ this object uniquely within the presence of a larger SQL expression
+ or statement, for the purposes of caching the resulting query.
+
+ The cache key should be based on the SQL compiled structure that would
+ ultimately be produced. That is, two structures that are composed in
+ exactly the same way should produce the same cache key; any difference
+ in the structures that would affect the SQL string or the type handlers
+ should result in a different cache key.
+
+ The cache key returned by this method is an instance of
+ :class:`.CacheKey`, which consists of a tuple representing the
+ cache key, as well as a list of :class:`.BindParameter` objects
+ which are extracted from the expression. While two expressions
+ that produce identical cache key tuples will themselves generate
+ identical SQL strings, the list of :class:`.BindParameter` objects
+ indicates the bound values which may have different values in
+ each one; these bound parameters must be consulted in order to
+ execute the statement with the correct parameters.
+
+ A :class:`_expression.ClauseElement` structure that does not implement
+ a :meth:`._gen_cache_key` method and does not implement a
+ :attr:`.traverse_internals` attribute will not be cacheable; when
+ such an element is embedded into a larger structure, this method
+ will return None, indicating no cache key is available.
+
+ """
+
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = self._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
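+
+ # Sketch (hypothetical table ``t``): two statements composed the same way
+ # produce equal key tuples; the differing bound values live only in the
+ # extracted ``bindparams`` list:
+ #
+ #     k1 = select(t.c.x).where(t.c.x == 5)._generate_cache_key()
+ #     k2 = select(t.c.x).where(t.c.x == 10)._generate_cache_key()
+ #     assert k1 == k2    # CacheKey.__eq__ compares the .key portion only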
+
+ @classmethod
+ def _generate_cache_key_for_object(cls, obj):
+ bindparams = []
+
+ _anon_map = anon_map()
+ key = obj._gen_cache_key(_anon_map, bindparams)
+ if NO_CACHE in _anon_map:
+ return None
+ else:
+ return CacheKey(key, bindparams)
+
+
+class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
+ @HasMemoized.memoized_instancemethod
+ def _generate_cache_key(self):
+ return HasCacheKey._generate_cache_key(self)
+
+
+class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
+ """The key used to identify a SQL statement construct in the
+ SQL compilation cache.
+
+ .. seealso::
+
+ :ref:`sql_caching`
+
+ """
+
+ def __hash__(self):
+ """CacheKey itself is not hashable - hash the .key portion"""
+
+ return None
+
+ def to_offline_string(self, statement_cache, statement, parameters):
+ """Generate an "offline string" form of this :class:`.CacheKey`
+
+ The "offline string" is basically the string SQL for the
+ statement plus a repr of the bound parameter values in series.
+ Whereas the :class:`.CacheKey` object is dependent on in-memory
+ identities in order to work as a cache key, the "offline" version
+ is suitable for a cache that will work for other processes as well.
+
+ The given ``statement_cache`` is a dictionary-like object where the
+ string form of the statement itself will be cached. This dictionary
+ should be in a longer lived scope in order to reduce the time spent
+ stringifying statements.
+
+
+ """
+ if self.key not in statement_cache:
+ statement_cache[self.key] = sql_str = str(statement)
+ else:
+ sql_str = statement_cache[self.key]
+
+ if not self.bindparams:
+ param_tuple = tuple(parameters[key] for key in sorted(parameters))
+ else:
+ param_tuple = tuple(
+ parameters.get(bindparam.key, bindparam.value)
+ for bindparam in self.bindparams
+ )
+
+ return repr((sql_str, param_tuple))
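+
+ # Sketch (names hypothetical): the statement cache should be a long-lived
+ # dict reused across calls, as described above:
+ #
+ #     stmt_cache = {}
+ #     ck = stmt._generate_cache_key()
+ #     offline = ck.to_offline_string(stmt_cache, stmt, {"x_1": 5})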
+
+ def __eq__(self, other):
+ return self.key == other.key
+
+ @classmethod
+ def _diff_tuples(cls, left, right):
+ ck1 = CacheKey(left, [])
+ ck2 = CacheKey(right, [])
+ return ck1._diff(ck2)
+
+ def _whats_different(self, other):
+
+ k1 = self.key
+ k2 = other.key
+
+ stack = []
+ pickup_index = 0
+ while True:
+ s1, s2 = k1, k2
+ for idx in stack:
+ s1 = s1[idx]
+ s2 = s2[idx]
+
+ for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)):
+ if idx < pickup_index:
+ continue
+ if e1 != e2:
+ if isinstance(e1, tuple) and isinstance(e2, tuple):
+ stack.append(idx)
+ break
+ else:
+ yield "key%s[%d]: %s != %s" % (
+ "".join("[%d]" % id_ for id_ in stack),
+ idx,
+ e1,
+ e2,
+ )
+ else:
+ pickup_index = stack.pop(-1)
+ break
+
+ def _diff(self, other):
+ return ", ".join(self._whats_different(other))
+
+ def __str__(self):
+ stack = [self.key]
+
+ output = []
+ sentinel = object()
+ indent = -1
+ while stack:
+ elem = stack.pop(0)
+ if elem is sentinel:
+ output.append((" " * (indent * 2)) + "),")
+ indent -= 1
+ elif isinstance(elem, tuple):
+ if not elem:
+ output.append((" " * ((indent + 1) * 2)) + "()")
+ else:
+ indent += 1
+ stack = list(elem) + [sentinel] + stack
+ output.append((" " * (indent * 2)) + "(")
+ else:
+ if isinstance(elem, HasCacheKey):
+ repr_ = "<%s object at %s>" % (
+ type(elem).__name__,
+ hex(id(elem)),
+ )
+ else:
+ repr_ = repr(elem)
+ output.append((" " * (indent * 2)) + " " + repr_ + ", ")
+
+ return "CacheKey(key=%s)" % ("\n".join(output),)
+
+ def _generate_param_dict(self):
+ """used for testing"""
+
+ from .compiler import prefix_anon_map
+
+ _anon_map = prefix_anon_map()
+ return {b.key % _anon_map: b.effective_value for b in self.bindparams}
+
+ def _apply_params_to_element(self, original_cache_key, target_element):
+ translate = {
+ k.key: v.value
+ for k, v in zip(original_cache_key.bindparams, self.bindparams)
+ }
+
+ return target_element.params(translate)
+
+
+def _clone(element, **kw):
+ return element._clone()
+
+
+class _CacheKey(ExtendedInternalTraversal):
+ # very common elements are inlined into the main _get_cache_key() method
+ # to produce a dramatic savings in Python function call overhead
+
+ visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
+ visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
+ visit_annotations_key = InternalTraversal.dp_annotations_key
+ visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
+ visit_memoized_select_entities = (
+ InternalTraversal.dp_memoized_select_entities
+ )
+
+ visit_string = (
+ visit_boolean
+ ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
+ visit_statement_hint_list = CACHE_IN_PLACE
+ visit_type = STATIC_CACHE_KEY
+ visit_anon_name = ANON_NAME
+
+ visit_propagate_attrs = PROPAGATE_ATTRS
+
+ def visit_with_context_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return tuple((fn.__code__, c_key) for fn, c_key in obj)
+
+ def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
+
+ def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
+ return tuple(obj)
+
+ def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ obj._gen_cache_key(anon_map, bindparams)
+ if isinstance(obj, HasCacheKey)
+ else obj,
+ )
+
+ def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
+ return (
+ attrname,
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ if isinstance(elem, HasCacheKey)
+ else elem
+ for elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple(
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in tup_elem
+ )
+ for tup_elem in obj
+ ),
+ )
+
+ def visit_has_cache_key_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
+ )
+
+ def visit_executable_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple(
+ elem._gen_cache_key(anon_map, bindparams)
+ for elem in obj
+ if elem._is_has_cache_key
+ ),
+ )
+
+ def visit_inspectable_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_list(
+ attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
+ )
+
+ def visit_clauseelement_tuples(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return self.visit_has_cache_key_tuples(
+ attrname, obj, parent, anon_map, bindparams
+ )
+
+ def visit_fromclause_ordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ return (
+ attrname,
+ tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]),
+ )
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+ cache_keys = [
+ elem._gen_cache_key(anon_map, bindparams) for elem in obj
+ ]
+ return (
+ attrname,
+ tuple(
+ sorted(cache_keys)
+ ), # cache keys all start with (id_, class)
+ )
+
+ def visit_named_ddl_element(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (attrname, obj.name)
+
+ def visit_prefix_sequence(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+
+ return (
+ attrname,
+ tuple(
+ [
+ (clause._gen_cache_key(anon_map, bindparams), strval)
+ for clause, strval in obj
+ ]
+ ),
+ )
+
+ def visit_setup_join_tuple(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ is_legacy = "legacy" in attrname
+
+ return tuple(
+ (
+ target
+ if is_legacy and isinstance(target, str)
+ else target._gen_cache_key(anon_map, bindparams),
+ onclause
+ if is_legacy and isinstance(onclause, str)
+ else onclause._gen_cache_key(anon_map, bindparams)
+ if onclause is not None
+ else None,
+ from_._gen_cache_key(anon_map, bindparams)
+ if from_ is not None
+ else None,
+ tuple([(key, flags[key]) for key in sorted(flags)]),
+ )
+ for (target, onclause, from_, flags) in obj
+ )
+
+ def visit_table_hint_list(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ if not obj:
+ return ()
+
+ return (
+ attrname,
+ tuple(
+ [
+ (
+ clause._gen_cache_key(anon_map, bindparams),
+ dialect_name,
+ text,
+ )
+ for (clause, dialect_name), text in obj.items()
+ ]
+ ),
+ )
+
+ def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
+ return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
+
+ def visit_dialect_options(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ dialect_name,
+ tuple(
+ [
+ (key, obj[dialect_name][key])
+ for key in sorted(obj[dialect_name])
+ ]
+ ),
+ )
+ for dialect_name in sorted(obj)
+ ),
+ )
+
+ def visit_string_clauseelement_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (key, obj[key]._gen_cache_key(anon_map, bindparams))
+ for key in sorted(obj)
+ ),
+ )
+
+ def visit_string_multi_dict(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key,
+ value._gen_cache_key(anon_map, bindparams)
+ if isinstance(value, HasCacheKey)
+ else value,
+ )
+ for key, value in [(key, obj[key]) for key in sorted(obj)]
+ ),
+ )
+
+ def visit_fromclause_canonical_column_collection(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # inlining into the internals of ColumnCollection
+ return (
+ attrname,
+ tuple(
+ col._gen_cache_key(anon_map, bindparams)
+ for k, col in obj._collection
+ ),
+ )
+
+ def visit_unknown_structure(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ anon_map[NO_CACHE] = True
+ return ()
+
+ def visit_dml_ordered_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ return (
+ attrname,
+ tuple(
+ (
+ key._gen_cache_key(anon_map, bindparams)
+ if hasattr(key, "__clause_element__")
+ else key,
+ value._gen_cache_key(anon_map, bindparams),
+ )
+ for key, value in obj
+ ),
+ )
+
+ def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ if py37:
+ # in py37 we can assume two dictionaries created in the same
+ # insert ordering will retain that sorting
+ return (
+ attrname,
+ tuple(
+ (
+ k._gen_cache_key(anon_map, bindparams)
+ if hasattr(k, "__clause_element__")
+ else k,
+ obj[k]._gen_cache_key(anon_map, bindparams),
+ )
+ for k in obj
+ ),
+ )
+ else:
+ expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+ if expr_values:
+ # expr values can't be sorted deterministically right now,
+ # so no cache
+ anon_map[NO_CACHE] = True
+ return ()
+
+ str_values = expr_values.symmetric_difference(obj)
+
+ return (
+ attrname,
+ tuple(
+ (k, obj[k]._gen_cache_key(anon_map, bindparams))
+ for k in sorted(str_values)
+ ),
+ )
+
+ def visit_dml_multi_values(
+ self, attrname, obj, parent, anon_map, bindparams
+ ):
+ # multivalues are simply not cacheable right now
+ anon_map[NO_CACHE] = True
+ return ()
+
+
+_cache_key_traversal_visitor = _CacheKey()
+
+
+class HasCopyInternals(object):
+ def _clone(self, **kw):
+ raise NotImplementedError()
+
+ def _copy_internals(self, omit_attrs=(), **kw):
+ """Reassign internal elements to be clones of themselves.
+
+ Called during a copy-and-traverse operation on newly
+ shallow-copied elements to create a deep copy.
+
+ The given clone function should be used, which may apply
+ additional transformations to the element (i.e. replacement
+ traversal, cloned traversal, annotations).
+
+ """
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return
+
+ for attrname, obj, meth in _copy_internals.run_generated_dispatch(
+ self, traverse_internals, "_generated_copy_internals_traversal"
+ ):
+ if attrname in omit_attrs:
+ continue
+
+ if obj is not None:
+ result = meth(attrname, self, obj, **kw)
+ if result is not None:
+ setattr(self, attrname, result)
+
+
+class _CopyInternals(InternalTraversal):
+ """Generate a _copy_internals internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_clauseelement(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return clone(element, **kw)
+
+ def visit_clauseelement_list(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return [clone(clause, **kw) for clause in element]
+
+ def visit_clauseelement_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
+ def visit_executable_options(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple([clone(clause, **kw) for clause in element])
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return {clone(clause, **kw) for clause in element}
+
+ def visit_clauseelement_tuples(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return [
+ tuple(clone(tup_elem, **kw) for tup_elem in elem)
+ for elem in element
+ ]
+
+ def visit_string_clauseelement_dict(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return dict(
+ (key, clone(value, **kw)) for key, value in element.items()
+ )
+
+ def visit_setup_join_tuple(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return tuple(
+ (
+ clone(target, **kw) if target is not None else None,
+ clone(onclause, **kw) if onclause is not None else None,
+ clone(from_, **kw) if from_ is not None else None,
+ flags,
+ )
+ for (target, onclause, from_, flags) in element
+ )
+
+ def visit_memoized_select_entities(self, attrname, parent, element, **kw):
+ return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
+
+ def visit_dml_ordered_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ # sequence of 2-tuples
+ return [
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key,
+ clone(value, **kw),
+ )
+ for key, value in element
+ ]
+
+ def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
+ return {
+ (
+ clone(key, **kw) if hasattr(key, "__clause_element__") else key
+ ): clone(value, **kw)
+ for key, value in element.items()
+ }
+
+ def visit_dml_multi_values(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ # sequence of sequences, each sequence contains a list/dict/tuple
+
+ def copy(elem):
+ if isinstance(elem, (list, tuple)):
+ return [
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ for value in elem
+ ]
+ elif isinstance(elem, dict):
+ return {
+ (
+ clone(key, **kw)
+ if hasattr(key, "__clause_element__")
+ else key
+ ): (
+ clone(value, **kw)
+ if hasattr(value, "__clause_element__")
+ else value
+ )
+ for key, value in elem.items()
+ }
+ else:
+ # TODO: use abc classes
+ assert False
+
+ return [
+ [copy(sub_element) for sub_element in sequence]
+ for sequence in element
+ ]
+
+ def visit_propagate_attrs(
+ self, attrname, parent, element, clone=_clone, **kw
+ ):
+ return element
+
+
+_copy_internals = _CopyInternals()
+
+
+def _flatten_clauseelement(element):
+ while hasattr(element, "__clause_element__") and not getattr(
+ element, "is_clause_element", False
+ ):
+ element = element.__clause_element__()
+
+ return element
+
+
+class _GetChildren(InternalTraversal):
+ """Generate a _children_traversal internal traversal dispatch for classes
+ with a _traverse_internals collection."""
+
+ def visit_has_cache_key(self, element, **kw):
+ # the GetChildren traversal refers explicitly to ClauseElement
+ # structures. Within these, a plain HasCacheKey is not a
+ # ClauseElement, so don't include these.
+ return ()
+
+ def visit_clauseelement(self, element, **kw):
+ return (element,)
+
+ def visit_clauseelement_list(self, element, **kw):
+ return element
+
+ def visit_clauseelement_tuple(self, element, **kw):
+ return element
+
+ def visit_clauseelement_tuples(self, element, **kw):
+ return itertools.chain.from_iterable(element)
+
+ def visit_fromclause_canonical_column_collection(self, element, **kw):
+ return ()
+
+ def visit_string_clauseelement_dict(self, element, **kw):
+ return element.values()
+
+ def visit_fromclause_ordered_set(self, element, **kw):
+ return element
+
+ def visit_clauseelement_unordered_set(self, element, **kw):
+ return element
+
+ def visit_setup_join_tuple(self, element, **kw):
+ for (target, onclause, from_, flags) in element:
+ if from_ is not None:
+ yield from_
+
+ if not isinstance(target, str):
+ yield _flatten_clauseelement(target)
+
+ if onclause is not None and not isinstance(onclause, str):
+ yield _flatten_clauseelement(onclause)
+
+ def visit_memoized_select_entities(self, element, **kw):
+ return self.visit_clauseelement_tuple(element, **kw)
+
+ def visit_dml_ordered_values(self, element, **kw):
+ for k, v in element:
+ if hasattr(k, "__clause_element__"):
+ yield k
+ yield v
+
+ def visit_dml_values(self, element, **kw):
+ expr_values = {k for k in element if hasattr(k, "__clause_element__")}
+ str_values = expr_values.symmetric_difference(element)
+
+ for k in sorted(str_values):
+ yield element[k]
+ for k in expr_values:
+ yield k
+ yield element[k]
+
+ def visit_dml_multi_values(self, element, **kw):
+ return ()
+
+ def visit_propagate_attrs(self, element, **kw):
+ return ()
+
+
+_get_children = _GetChildren()
+
+
+@util.preload_module("sqlalchemy.sql.elements")
+def _resolve_name_for_compare(element, name, anon_map, **kw):
+ if isinstance(name, util.preloaded.sql_elements._anonymous_label):
+ name = name.apply_map(anon_map)
+
+ return name
+
+
+class anon_map(dict):
+ """A map that creates new keys for missing key access.
+
+ Produces an incrementing sequence given a series of unique keys.
+
+ This is similar to the compiler prefix_anon_map class although simpler.
+
+ Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
+ is otherwise usually used for this type of operation.
+
+ """
+
+ def __init__(self):
+ self.index = 0
+
+ def __missing__(self, key):
+ self[key] = val = str(self.index)
+ self.index += 1
+ return val
+
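+ # Sketch of the mapping behavior above (pure dict mechanics):
+ #
+ #     m = anon_map()
+ #     m["first"]    # -> "0"  (created on first access)
+ #     m["second"]   # -> "1"
+ #     m["first"]    # -> "0"  (stable for repeated access)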
+
+class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
+ __slots__ = "stack", "cache", "anon_map"
+
+ def __init__(self):
+ self.stack = deque()
+ self.cache = set()
+
+ def _memoized_attr_anon_map(self):
+ return (anon_map(), anon_map())
+
+ def compare(self, obj1, obj2, **kw):
+ stack = self.stack
+ cache = self.cache
+
+ compare_annotations = kw.get("compare_annotations", False)
+
+ stack.append((obj1, obj2))
+
+ while stack:
+ left, right = stack.popleft()
+
+ if left is right:
+ continue
+ elif left is None or right is None:
+ # we know they are different so no match
+ return False
+ elif (left, right) in cache:
+ continue
+ cache.add((left, right))
+
+ visit_name = left.__visit_name__
+ if visit_name != right.__visit_name__:
+ return False
+
+ meth = getattr(self, "compare_%s" % visit_name, None)
+
+ if meth:
+ attributes_compared = meth(left, right, **kw)
+ if attributes_compared is COMPARE_FAILED:
+ return False
+ elif attributes_compared is SKIP_TRAVERSE:
+ continue
+
+ # attributes_compared is returned as a list of attribute
+ # names that were "handled" by the comparison method above.
+ # remaining attribute names in the _traverse_internals
+ # will be compared.
+ else:
+ attributes_compared = ()
+
+ for (
+ (left_attrname, left_visit_sym),
+ (right_attrname, right_visit_sym),
+ ) in util.zip_longest(
+ left._traverse_internals,
+ right._traverse_internals,
+ fillvalue=(None, None),
+ ):
+ if not compare_annotations and (
+ (left_attrname == "_annotations")
+ or (right_attrname == "_annotations")
+ ):
+ continue
+
+ if (
+ left_attrname != right_attrname
+ or left_visit_sym is not right_visit_sym
+ ):
+ return False
+ elif left_attrname in attributes_compared:
+ continue
+
+ dispatch = self.dispatch(left_visit_sym)
+ left_child = operator.attrgetter(left_attrname)(left)
+ right_child = operator.attrgetter(right_attrname)(right)
+ if left_child is None:
+ if right_child is not None:
+ return False
+ else:
+ continue
+
+ comparison = dispatch(
+ left_attrname, left, left_child, right, right_child, **kw
+ )
+ if comparison is COMPARE_FAILED:
+ return False
+
+ return True
+
+ def compare_inner(self, obj1, obj2, **kw):
+ comparator = self.__class__()
+ return comparator.compare(obj1, obj2, **kw)
+
+ def visit_has_cache_key(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
+ def visit_propagate_attrs(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.compare_inner(
+ left.get("plugin_subject", None), right.get("plugin_subject", None)
+ )
+
+ def visit_has_cache_key_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
+ self.anon_map[1], []
+ ):
+ return COMPARE_FAILED
+
+ def visit_executable_options(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ if (
+ l._gen_cache_key(self.anon_map[0], [])
+ if l._is_has_cache_key
+ else l
+ ) != (
+ r._gen_cache_key(self.anon_map[1], [])
+ if r._is_has_cache_key
+ else r
+ ):
+ return COMPARE_FAILED
+
+ def visit_clauseelement(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ self.stack.append((left, right))
+
+ def visit_fromclause_canonical_column_collection(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((lcol, rcol))
+
+ def visit_fromclause_derived_column_collection(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ pass
+
+ def visit_string_clauseelement_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for lstr, rstr in util.zip_longest(
+ sorted(left), sorted(right), fillvalue=None
+ ):
+ if lstr != rstr:
+ return COMPARE_FAILED
+ self.stack.append((left[lstr], right[rstr]))
+
+ def visit_clauseelement_tuples(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
+ if ltup is None or rtup is None:
+ return COMPARE_FAILED
+
+ for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_clauseelement_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_clauseelement_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def _compare_unordered_sequences(self, seq1, seq2, **kw):
+ if seq1 is None:
+ return seq2 is None
+
+ completed = set()
+ for clause in seq1:
+ for other_clause in set(seq2).difference(completed):
+ if self.compare_inner(clause, other_clause, **kw):
+ completed.add(other_clause)
+ break
+ return len(completed) == len(seq1) == len(seq2)
+
+ def visit_clauseelement_unordered_set(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self._compare_unordered_sequences(left, right, **kw)
+
+ def visit_fromclause_ordered_set(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in util.zip_longest(left, right, fillvalue=None):
+ self.stack.append((l, r))
+
+ def visit_string(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_string_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_anon_name(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return _resolve_name_for_compare(
+ left_parent, left, self.anon_map[0], **kw
+ ) == _resolve_name_for_compare(
+ right_parent, right, self.anon_map[1], **kw
+ )
+
+ def visit_boolean(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_operator(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left is right
+
+ def visit_type(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left._compare_type_affinity(right)
+
+ def visit_plain_dict(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_dialect_options(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_annotations_key(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left and right:
+ return (
+ left_parent._annotations_cache_key
+ == right_parent._annotations_cache_key
+ )
+ else:
+ return left == right
+
+ def visit_with_context_options(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
+ (fn.__code__, c_key) for fn, c_key in right
+ )
+
+ def visit_plain_obj(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_named_ddl_element(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left is None:
+ if right is not None:
+ return COMPARE_FAILED
+
+ return left.name == right.name
+
+ def visit_prefix_sequence(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ if l_str != r_str:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((l_clause, r_clause))
+
+ def visit_setup_join_tuple(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ # TODO: look at attrname for "legacy_join" and use different structure
+ for (
+ (l_target, l_onclause, l_from, l_flags),
+ (r_target, r_onclause, r_from, r_flags),
+ ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)):
+ if l_flags != r_flags:
+ return COMPARE_FAILED
+ self.stack.append((l_target, r_target))
+ self.stack.append((l_onclause, r_onclause))
+ self.stack.append((l_from, r_from))
+
+ def visit_memoized_select_entities(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return self.visit_clauseelement_tuple(
+ attrname, left_parent, left, right_parent, right, **kw
+ )
+
+ def visit_table_hint_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
+ right_keys = sorted(
+ right, key=lambda elem: (elem[0].fullname, elem[1])
+ )
+ for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
+ left_keys, right_keys, fillvalue=(None, None)
+ ):
+ if ldialect != rdialect:
+ return COMPARE_FAILED
+ elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((ltable, rtable))
+
+ def visit_statement_hint_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
+ def visit_unknown_structure(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ raise NotImplementedError()
+
+ def visit_dml_ordered_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ # sequence of tuple pairs
+
+ for (lk, lv), (rk, rv) in util.zip_longest(
+ left, right, fillvalue=(None, None)
+ ):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
+ return COMPARE_FAILED
+
+ def _compare_dml_values_or_ce(self, lv, rv, **kw):
+ lvce = hasattr(lv, "__clause_element__")
+ rvce = hasattr(rv, "__clause_element__")
+ if lvce != rvce:
+ return False
+ elif lvce and not self.compare_inner(lv, rv, **kw):
+ return False
+ elif not lvce and lv != rv:
+ return False
+ elif not self.compare_inner(lv, rv, **kw):
+ return False
+
+ return True
+
+ def visit_dml_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ if left is None or right is None or len(left) != len(right):
+ return COMPARE_FAILED
+
+ if isinstance(left, collections_abc.Sequence):
+ for lv, rv in zip(left, right):
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+ elif isinstance(right, collections_abc.Sequence):
+ return COMPARE_FAILED
+ elif py37:
+ # dictionaries guaranteed to support insert ordering in
+ # py37 so that we can compare the keys in order. without
+ # this, we can't compare SQL expression keys because we don't
+ # know which key is which
+ for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
+ return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+ else:
+ for lk in left:
+ lv = left[lk]
+
+ if lk not in right:
+ return COMPARE_FAILED
+ rv = right[lk]
+
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
+
+ def visit_dml_multi_values(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
+ if lseq is None or rseq is None:
+ return COMPARE_FAILED
+
+ for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
+ if (
+ self.visit_dml_values(
+ attrname, left_parent, ld, right_parent, rd, **kw
+ )
+ is COMPARE_FAILED
+ ):
+ return COMPARE_FAILED
+
+ def compare_clauselist(self, left, right, **kw):
+ if left.operator is right.operator:
+ if operators.is_associative(left.operator):
+ if self._compare_unordered_sequences(
+ left.clauses, right.clauses, **kw
+ ):
+ return ["operator", "clauses"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator"]
+ else:
+ return COMPARE_FAILED
+
+ def compare_binary(self, left, right, **kw):
+ if left.operator == right.operator:
+ if operators.is_commutative(left.operator):
+ if (
+ self.compare_inner(left.left, right.left, **kw)
+ and self.compare_inner(left.right, right.right, **kw)
+ ) or (
+ self.compare_inner(left.left, right.right, **kw)
+ and self.compare_inner(left.right, right.left, **kw)
+ ):
+ return ["operator", "negate", "left", "right"]
+ else:
+ return COMPARE_FAILED
+ else:
+ return ["operator", "negate"]
+ else:
+ return COMPARE_FAILED
+
+ def compare_bindparam(self, left, right, **kw):
+ compare_keys = kw.pop("compare_keys", True)
+ compare_values = kw.pop("compare_values", True)
+
+ if compare_values:
+ omit = []
+ else:
+ # this means, "skip these, we already compared"
+ omit = ["callable", "value"]
+
+ if not compare_keys:
+ omit.append("key")
+
+ return omit
+
+
+class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
+ def compare_column_element(
+ self, left, right, use_proxies=True, equivalents=(), **kw
+ ):
+ """Compare ColumnElements using proxies and equivalent collections.
+
+ This is a comparison strategy specific to the ORM.
+ """
+
+ to_compare = (right,)
+ if equivalents and right in equivalents:
+ to_compare = equivalents[right].union(to_compare)
+
+ for oth in to_compare:
+ if use_proxies and left.shares_lineage(oth):
+ return SKIP_TRAVERSE
+ elif hash(left) == hash(right):
+ return SKIP_TRAVERSE
+        else:
+            return COMPARE_FAILED
+
+ def compare_column(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_label(self, left, right, **kw):
+ return self.compare_column_element(left, right, **kw)
+
+ def compare_table(self, left, right, **kw):
+ # tables compare on identity, since it's not really feasible to
+ # compare them column by column with the above rules
+ return SKIP_TRAVERSE if left is right else COMPARE_FAILED
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
new file mode 100644
index 0000000..29dc749
--- /dev/null
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -0,0 +1,1974 @@
+# sql/type_api.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Base types API.
+
+"""
+
+
+from . import operators
+from .base import SchemaEventTarget
+from .traversals import NO_CACHE
+from .visitors import Traversible
+from .visitors import TraversibleType
+from .. import exc
+from .. import util
+
+# these are back-assigned by sqltypes.
+BOOLEANTYPE = None
+INTEGERTYPE = None
+NULLTYPE = None
+NUMERICTYPE = None
+STRINGTYPE = None
+MATCHTYPE = None
+INDEXABLE = None
+TABLEVALUE = None
+_resolve_value_to_type = None
+
+
+class TypeEngine(Traversible):
+ """The ultimate base class for all SQL datatypes.
+
+ Common subclasses of :class:`.TypeEngine` include
+ :class:`.String`, :class:`.Integer`, and :class:`.Boolean`.
+
+ For an overview of the SQLAlchemy typing system, see
+ :ref:`types_toplevel`.
+
+ .. seealso::
+
+ :ref:`types_toplevel`
+
+ """
+
+ _sqla_type = True
+ _isnull = False
+ _is_tuple_type = False
+ _is_table_value = False
+ _is_array = False
+ _is_type_decorator = False
+
+ class Comparator(operators.ColumnOperators):
+ """Base class for custom comparison operations defined at the
+ type level. See :attr:`.TypeEngine.comparator_factory`.
+
+
+ """
+
+ __slots__ = "expr", "type"
+
+ default_comparator = None
+
+ def __clause_element__(self):
+ return self.expr
+
+ def __init__(self, expr):
+ self.expr = expr
+ self.type = expr.type
+
+ @util.preload_module("sqlalchemy.sql.default_comparator")
+ def operate(self, op, *other, **kwargs):
+ default_comparator = util.preloaded.sql_default_comparator
+ o = default_comparator.operator_lookup[op.__name__]
+ return o[0](self.expr, op, *(other + o[1:]), **kwargs)
+
+ @util.preload_module("sqlalchemy.sql.default_comparator")
+ def reverse_operate(self, op, other, **kwargs):
+ default_comparator = util.preloaded.sql_default_comparator
+ o = default_comparator.operator_lookup[op.__name__]
+ return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs)
+
+ def _adapt_expression(self, op, other_comparator):
+ """evaluate the return type of <self> <op> <othertype>,
+ and apply any adaptations to the given operator.
+
+ This method determines the type of a resulting binary expression
+ given two source types and an operator. For example, two
+ :class:`_schema.Column` objects, both of the type
+ :class:`.Integer`, will
+ produce a :class:`.BinaryExpression` that also has the type
+ :class:`.Integer` when compared via the addition (``+``) operator.
+ However, using the addition operator with an :class:`.Integer`
+ and a :class:`.Date` object will produce a :class:`.Date`, assuming
+ "days delta" behavior by the database (in reality, most databases
+ other than PostgreSQL don't accept this particular operation).
+
+ The method returns a tuple of the form <operator>, <type>.
+ The resulting operator and type will be those applied to the
+ resulting :class:`.BinaryExpression` as the final operator and the
+ right-hand side of the expression.
+
+ Note that only a subset of operators make usage of
+ :meth:`._adapt_expression`,
+ including math operators and user-defined operators, but not
+ boolean comparison or special SQL keywords like MATCH or BETWEEN.
+
+ """
+
+ return op, self.type
+
+ def __reduce__(self):
+ return _reconstitute_comparator, (self.expr,)
+
+ hashable = True
+ """Flag, if False, means values from this type aren't hashable.
+
+ Used by the ORM when uniquing result lists.
+
+ """
+
+ comparator_factory = Comparator
+ """A :class:`.TypeEngine.Comparator` class which will apply
+ to operations performed by owning :class:`_expression.ColumnElement`
+ objects.
+
+ The :attr:`.comparator_factory` attribute is a hook consulted by
+ the core expression system when column and SQL expression operations
+ are performed. When a :class:`.TypeEngine.Comparator` class is
+ associated with this attribute, it allows custom re-definition of
+ all existing operators, as well as definition of new operators.
+ Existing operators include those provided by Python operator overloading
+ such as :meth:`.operators.ColumnOperators.__add__` and
+ :meth:`.operators.ColumnOperators.__eq__`,
+ those provided as standard
+ attributes of :class:`.operators.ColumnOperators` such as
+ :meth:`.operators.ColumnOperators.like`
+ and :meth:`.operators.ColumnOperators.in_`.
+
+ Rudimentary usage of this hook is allowed through simple subclassing
+ of existing types, or alternatively by using :class:`.TypeDecorator`.
+ See the documentation section :ref:`types_operators` for examples.
+
+ """
+
+ sort_key_function = None
+ """A sorting function that can be passed as the key to sorted.
+
+ The default value of ``None`` indicates that the values stored by
+ this type are self-sorting.
+
+ .. versionadded:: 1.3.8
+
+ """
+
+ should_evaluate_none = False
+ """If True, the Python constant ``None`` is considered to be handled
+ explicitly by this type.
+
+ The ORM uses this flag to indicate that a positive value of ``None``
+ is passed to the column in an INSERT statement, rather than omitting
+ the column from the INSERT statement which has the effect of firing
+ off column-level defaults. It also allows types which have special
+ behavior for Python None, such as a JSON type, to indicate that
+ they'd like to handle the None value explicitly.
+
+ To set this flag on an existing type, use the
+ :meth:`.TypeEngine.evaluates_none` method.
+
+ .. seealso::
+
+ :meth:`.TypeEngine.evaluates_none`
+
+ .. versionadded:: 1.1
+
+
+ """
+
+ def evaluates_none(self):
+ """Return a copy of this type which has the
+ :attr:`.should_evaluate_none` flag set to True.
+
+ E.g.::
+
+ Table(
+ 'some_table', metadata,
+ Column(
+                    'value', String(50).evaluates_none(),
+ nullable=True,
+ server_default='no value')
+ )
+
+ The ORM uses this flag to indicate that a positive value of ``None``
+ is passed to the column in an INSERT statement, rather than omitting
+ the column from the INSERT statement which has the effect of firing
+ off column-level defaults. It also allows for types which have
+ special behavior associated with the Python None value to indicate
+ that the value doesn't necessarily translate into SQL NULL; a
+ prime example of this is a JSON type which may wish to persist the
+ JSON value ``'null'``.
+
+        In all cases, the actual NULL SQL value can always be
+ persisted in any column by using
+ the :obj:`_expression.null` SQL construct in an INSERT statement
+ or associated with an ORM-mapped attribute.
+
+ .. note::
+
+ The "evaluates none" flag does **not** apply to a value
+ of ``None`` passed to :paramref:`_schema.Column.default` or
+ :paramref:`_schema.Column.server_default`; in these cases,
+ ``None``
+ still means "no default".
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :ref:`session_forcing_null` - in the ORM documentation
+
+ :paramref:`.postgresql.JSON.none_as_null` - PostgreSQL JSON
+ interaction with this flag.
+
+ :attr:`.TypeEngine.should_evaluate_none` - class-level flag
+
+ """
+ typ = self.copy()
+ typ.should_evaluate_none = True
+ return typ
+
+ def copy(self, **kw):
+ return self.adapt(self.__class__)
+
+ def compare_against_backend(self, dialect, conn_type):
+ """Compare this type against the given backend type.
+
+ This function is currently not implemented for SQLAlchemy
+        types, and for all built-in types will return ``None``.  However,
+ it can be implemented by a user-defined type
+ where it can be consumed by schema comparison tools such as
+ Alembic autogenerate.
+
+ A future release of SQLAlchemy will potentially implement this method
+        for built-in types as well.
+
+ The function should return True if this type is equivalent to the
+ given type; the type is typically reflected from the database
+ so should be database specific. The dialect in use is also
+ passed. It can also return False to assert that the type is
+ not equivalent.
+
+ :param dialect: a :class:`.Dialect` that is involved in the comparison.
+
+ :param conn_type: the type object reflected from the backend.
+
+ .. versionadded:: 1.0.3
+
+ """
+ return None
+
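+    # One possible shape for the hook above in a user-defined type, as might
+    # be consumed by a schema comparison tool (a sketch only; the equivalence
+    # rule is entirely up to the type author):
+    #
+    #     from sqlalchemy.types import String, UserDefinedType
+    #
+    #     class MyStringy(UserDefinedType):
+    #         cache_ok = True
+    #
+    #         def get_col_spec(self, **kw):
+    #             return "VARCHAR(255)"
+    #
+    #         def compare_against_backend(self, dialect, conn_type):
+    #             # reflected VARCHAR-affinity columns count as equivalent
+    #             return isinstance(conn_type, String)
+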
+ def copy_value(self, value):
+ return value
+
+ def literal_processor(self, dialect):
+ """Return a conversion function for processing literal values that are
+ to be rendered directly without using binds.
+
+ This function is used when the compiler makes use of the
+ "literal_binds" flag, typically used in DDL generation as well
+ as in certain scenarios where backends don't accept bound parameters.
+
+ Returns a callable which will receive a literal Python value
+ as the sole positional argument and will return a string representation
+ to be rendered in a SQL statement.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.literal_processor`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.literal_processor`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.process_literal_param`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+
+ """
+ return None
+
+ def bind_processor(self, dialect):
+ """Return a conversion function for processing bind values.
+
+ Returns a callable which will receive a bind parameter value
+ as the sole positional argument and will return a value to
+ send to the DB-API.
+
+ If processing is not necessary, the method should return ``None``.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.bind_processor`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.bind_processor`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.process_bind_param`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+
+ :param dialect: Dialect instance in use.
+
+ """
+ return None
+
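+    # A sketch of the callable this hook is expected to return, for a
+    # hypothetical dialect-level type that sends values to the DB-API as
+    # strings (illustrative only):
+    #
+    #     def bind_processor(self, dialect):
+    #         def process(value):
+    #             return str(value) if value is not None else None
+    #         return process
+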
+ def result_processor(self, dialect, coltype):
+ """Return a conversion function for processing result row values.
+
+ Returns a callable which will receive a result row column
+ value as the sole positional argument and will return a value
+ to return to the user.
+
+ If processing is not necessary, the method should return ``None``.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.result_processor`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.result_processor`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.process_result_value`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ :param dialect: Dialect instance in use.
+
+ :param coltype: DBAPI coltype argument received in cursor.description.
+
+ """
+ return None
+
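+    # Correspondingly, a result_processor() returns a per-row conversion
+    # callable; e.g. a hypothetical type whose DB-API values arrive as
+    # strings of digits could parse them back to ``int`` (illustrative only):
+    #
+    #     def result_processor(self, dialect, coltype):
+    #         def process(value):
+    #             return int(value) if value is not None else None
+    #         return process
+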
+ def column_expression(self, colexpr):
+ """Given a SELECT column expression, return a wrapping SQL expression.
+
+ This is typically a SQL function that wraps a column expression
+ as rendered in the columns clause of a SELECT statement.
+ It is used for special data types that require
+ columns to be wrapped in some special database function in order
+ to coerce the value before being sent back to the application.
+ It is the SQL analogue of the :meth:`.TypeEngine.result_processor`
+ method.
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** called
+ against specific values.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.column_expression`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.column_expression`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.column_expression`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+
+ .. seealso::
+
+ :ref:`types_sql_value_processing`
+
+ """
+
+ return None
+
+ @util.memoized_property
+ def _has_column_expression(self):
+ """memoized boolean, check if column_expression is implemented.
+
+ Allows the method to be skipped for the vast majority of expression
+ types that don't use this feature.
+
+ """
+
+ return (
+ self.__class__.column_expression.__code__
+ is not TypeEngine.column_expression.__code__
+ )
+
+ def bind_expression(self, bindvalue):
+ """Given a bind value (i.e. a :class:`.BindParameter` instance),
+ return a SQL expression in its place.
+
+ This is typically a SQL function that wraps the existing bound
+ parameter within the statement. It is used for special data types
+ that require literals being wrapped in some special database function
+ in order to coerce an application-level value into a database-specific
+ format. It is the SQL analogue of the
+ :meth:`.TypeEngine.bind_processor` method.
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** called
+ against specific values.
+
+ Note that this method, when implemented, should always return
+ the exact same structure, without any conditional logic, as it
+ may be used in an executemany() call against an arbitrary number
+ of bound parameter sets.
+
+ .. note::
+
+ This method is only called relative to a **dialect specific type
+ object**, which is often **private to a dialect in use** and is not
+ the same type object as the public facing one, which means it's not
+ feasible to subclass a :class:`.types.TypeEngine` class in order to
+ provide an alternate :meth:`_types.TypeEngine.bind_expression`
+ method, unless subclassing the :class:`_types.UserDefinedType`
+ class explicitly.
+
+ To provide alternate behavior for
+ :meth:`_types.TypeEngine.bind_expression`, implement a
+ :class:`_types.TypeDecorator` class and provide an implementation
+ of :meth:`_types.TypeDecorator.bind_expression`.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ .. seealso::
+
+ :ref:`types_sql_value_processing`
+
+ """
+ return None
+
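+    # bind_expression() and column_expression() are the SQL-level analogues
+    # of the value processors above.  A sketch in the spirit of the
+    # documented PostGIS-style example (the ST_* function names are
+    # assumptions about the backend, not something this module provides):
+    #
+    #     from sqlalchemy import func
+    #     from sqlalchemy.types import UserDefinedType
+    #
+    #     class Geometry(UserDefinedType):
+    #         cache_ok = True
+    #
+    #         def get_col_spec(self, **kw):
+    #             return "GEOMETRY"
+    #
+    #         def bind_expression(self, bindvalue):
+    #             # wrap incoming bound parameters in ST_GeomFromText()
+    #             return func.ST_GeomFromText(bindvalue, type_=self)
+    #
+    #         def column_expression(self, col):
+    #             # wrap selected columns in ST_AsText()
+    #             return func.ST_AsText(col, type_=self)
+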
+ @util.memoized_property
+ def _has_bind_expression(self):
+ """memoized boolean, check if bind_expression is implemented.
+
+ Allows the method to be skipped for the vast majority of expression
+ types that don't use this feature.
+
+ """
+
+ return util.method_is_overridden(self, TypeEngine.bind_expression)
+
+ @staticmethod
+ def _to_instance(cls_or_self):
+ return to_instance(cls_or_self)
+
+ def compare_values(self, x, y):
+ """Compare two values for equality."""
+
+ return x == y
+
+ def get_dbapi_type(self, dbapi):
+ """Return the corresponding type object from the underlying DB-API, if
+ any.
+
+ This can be useful for calling ``setinputsizes()``, for example.
+
+ """
+ return None
+
+ @property
+ def python_type(self):
+ """Return the Python type object expected to be returned
+ by instances of this type, if known.
+
+        Basically, for those types which enforce a return type,
+        or are known across the board to do such for all common
+        DBAPIs (like ``int`` for example), this method will return
+        that type.
+
+ If a return type is not defined, raises
+ ``NotImplementedError``.
+
+ Note that any type also accommodates NULL in SQL which
+ means you can also get back ``None`` from any type
+ in practice.
+
+ """
+ raise NotImplementedError()
+
+ def with_variant(self, type_, dialect_name):
+ r"""Produce a new type object that will utilize the given
+ type when applied to the dialect of the given name.
+
+ e.g.::
+
+ from sqlalchemy.types import String
+ from sqlalchemy.dialects import mysql
+
+ s = String()
+
+ s = s.with_variant(mysql.VARCHAR(collation='foo'), 'mysql')
+
+ The construction of :meth:`.TypeEngine.with_variant` is always
+ from the "fallback" type to that which is dialect specific.
+ The returned type is an instance of :class:`.Variant`, which
+ itself provides a :meth:`.Variant.with_variant`
+ that can be called repeatedly.
+
+ :param type\_: a :class:`.TypeEngine` that will be selected
+ as a variant from the originating type, when a dialect
+ of the given name is in use.
+ :param dialect_name: base name of the dialect which uses
+ this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
+
+ """
+ return Variant(self, {dialect_name: to_instance(type_)})
+
+ def _resolve_for_literal(self, value):
+ """adjust this type given a literal Python value that will be
+ stored in a bound parameter.
+
+ Used exclusively by _resolve_value_to_type().
+
+ .. versionadded:: 1.4.30 or 2.0
+
+ """
+ return self
+
+ @util.memoized_property
+ def _type_affinity(self):
+ """Return a rudimental 'affinity' value expressing the general class
+ of type."""
+
+ typ = None
+ for t in self.__class__.__mro__:
+ if t in (TypeEngine, UserDefinedType):
+ return typ
+ elif issubclass(t, (TypeEngine, UserDefinedType)):
+ typ = t
+ else:
+ return self.__class__
+
+ @util.memoized_property
+ def _generic_type_affinity(self):
+ best_camelcase = None
+ best_uppercase = None
+
+ if not isinstance(self, (TypeEngine, UserDefinedType)):
+ return self.__class__
+
+ for t in self.__class__.__mro__:
+ if (
+ t.__module__
+ in (
+ "sqlalchemy.sql.sqltypes",
+ "sqlalchemy.sql.type_api",
+ )
+ and issubclass(t, TypeEngine)
+ and t is not TypeEngine
+ and t.__name__[0] != "_"
+ ):
+ if t.__name__.isupper() and not best_uppercase:
+ best_uppercase = t
+ elif not t.__name__.isupper() and not best_camelcase:
+ best_camelcase = t
+
+ return best_camelcase or best_uppercase or NULLTYPE.__class__
+
+ def as_generic(self, allow_nulltype=False):
+ """
+ Return an instance of the generic type corresponding to this type
+ using heuristic rule. The method may be overridden if this
+ heuristic rule is not sufficient.
+
+ >>> from sqlalchemy.dialects.mysql import INTEGER
+ >>> INTEGER(display_width=4).as_generic()
+ Integer()
+
+ >>> from sqlalchemy.dialects.mysql import NVARCHAR
+ >>> NVARCHAR(length=100).as_generic()
+ Unicode(length=100)
+
+ .. versionadded:: 1.4.0b2
+
+
+ .. seealso::
+
+ :ref:`metadata_reflection_dbagnostic_types` - describes the
+ use of :meth:`_types.TypeEngine.as_generic` in conjunction with
+ the :meth:`_sql.DDLEvents.column_reflect` event, which is its
+ intended use.
+
+ """
+ if (
+ not allow_nulltype
+ and self._generic_type_affinity == NULLTYPE.__class__
+ ):
+ raise NotImplementedError(
+ "Default TypeEngine.as_generic() "
+ "heuristic method was unsuccessful for {}. A custom "
+ "as_generic() method must be implemented for this "
+ "type class.".format(
+ self.__class__.__module__ + "." + self.__class__.__name__
+ )
+ )
+
+ return util.constructor_copy(self, self._generic_type_affinity)
+
+ def dialect_impl(self, dialect):
+ """Return a dialect-specific implementation for this
+ :class:`.TypeEngine`.
+
+ """
+ try:
+ return dialect._type_memos[self]["impl"]
+ except KeyError:
+ pass
+ return self._dialect_info(dialect)["impl"]
+
+ def _unwrapped_dialect_impl(self, dialect):
+ """Return the 'unwrapped' dialect impl for this type.
+
+ For a type that applies wrapping logic (e.g. TypeDecorator), give
+ us the real, actual dialect-level type that is used.
+
+        This is used by TypeDecorator itself as well as in at least one case
+        where dialects need to check that a particular dialect-level
+ type is in use, within the :meth:`.DefaultDialect.set_input_sizes`
+ method.
+
+ """
+ return self.dialect_impl(dialect)
+
+ def _cached_literal_processor(self, dialect):
+ """Return a dialect-specific literal processor for this type."""
+ try:
+ return dialect._type_memos[self]["literal"]
+ except KeyError:
+ pass
+        # the KeyError is swallowed above so that it doesn't become the
+        # exception context if literal_processor() raises
+ d = self._dialect_info(dialect)
+ d["literal"] = lp = d["impl"].literal_processor(dialect)
+ return lp
+
+ def _cached_bind_processor(self, dialect):
+ """Return a dialect-specific bind processor for this type."""
+
+ try:
+ return dialect._type_memos[self]["bind"]
+ except KeyError:
+ pass
+        # the KeyError is swallowed above so that it doesn't become the
+        # exception context if bind_processor() raises
+ d = self._dialect_info(dialect)
+ d["bind"] = bp = d["impl"].bind_processor(dialect)
+ return bp
+
+ def _cached_result_processor(self, dialect, coltype):
+ """Return a dialect-specific result processor for this type."""
+
+ try:
+ return dialect._type_memos[self][coltype]
+ except KeyError:
+ pass
+        # the KeyError is swallowed above so that it doesn't become the
+        # exception context if result_processor() raises
+ d = self._dialect_info(dialect)
+ # key assumption: DBAPI type codes are
+ # constants. Else this dictionary would
+ # grow unbounded.
+ d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
+ return rp
+
+ def _cached_custom_processor(self, dialect, key, fn):
+ try:
+ return dialect._type_memos[self][key]
+ except KeyError:
+ pass
+        # the KeyError is swallowed above so that it doesn't become the
+        # exception context if fn() raises
+ d = self._dialect_info(dialect)
+ impl = d["impl"]
+ d[key] = result = fn(impl)
+ return result
+
+ def _dialect_info(self, dialect):
+ """Return a dialect-specific registry which
+ caches a dialect-specific implementation, bind processing
+ function, and one or more result processing functions."""
+
+ if self in dialect._type_memos:
+ return dialect._type_memos[self]
+ else:
+ impl = self._gen_dialect_impl(dialect)
+ if impl is self:
+ impl = self.adapt(type(self))
+ # this can't be self, else we create a cycle
+ assert impl is not self
+ dialect._type_memos[self] = d = {"impl": impl}
+ return d
+
+ def _gen_dialect_impl(self, dialect):
+ return dialect.type_descriptor(self)
+
+ @util.memoized_property
+ def _static_cache_key(self):
+ names = util.get_cls_kwargs(self.__class__)
+ return (self.__class__,) + tuple(
+ (
+ k,
+ self.__dict__[k]._static_cache_key
+ if isinstance(self.__dict__[k], TypeEngine)
+ else self.__dict__[k],
+ )
+ for k in names
+ if k in self.__dict__ and not k.startswith("_")
+ )
+
+ def adapt(self, cls, **kw):
+ """Produce an "adapted" form of this type, given an "impl" class
+ to work with.
+
+ This method is used internally to associate generic
+ types with "implementation" types that are specific to a particular
+ dialect.
+ """
+ return util.constructor_copy(self, cls, **kw)
+
+ def coerce_compared_value(self, op, value):
+ """Suggest a type for a 'coerced' Python value in an expression.
+
+ Given an operator and value, gives the type a chance
+ to return a type which the value should be coerced into.
+
+ The default behavior here is conservative; if the right-hand
+ side is already coerced into a SQL type based on its
+ Python type, it is usually left alone.
+
+ End-user functionality extension here should generally be via
+ :class:`.TypeDecorator`, which provides more liberal behavior in that
+ it defaults to coercing the other side of the expression into this
+ type, thus applying special Python conversions above and beyond those
+        needed by the DBAPI to both sides.  It also provides the public method
+ :meth:`.TypeDecorator.coerce_compared_value` which is intended for
+ end-user customization of this behavior.
+
+ """
+ _coerced_type = _resolve_value_to_type(value)
+ if (
+ _coerced_type is NULLTYPE
+ or _coerced_type._type_affinity is self._type_affinity
+ ):
+ return self
+ else:
+ return _coerced_type
+
+ def _compare_type_affinity(self, other):
+ return self._type_affinity is other._type_affinity
+
+ def compile(self, dialect=None):
+ """Produce a string-compiled form of this :class:`.TypeEngine`.
+
+ When called with no arguments, uses a "default" dialect
+ to produce a string result.
+
+ :param dialect: a :class:`.Dialect` instance.
+
+ """
+ # arg, return value is inconsistent with
+ # ClauseElement.compile()....this is a mistake.
+
+ if not dialect:
+ dialect = self._default_dialect()
+
+ return dialect.type_compiler.process(self)
+
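+    # For instance, compile() can be called directly against a type instance
+    # (illustrative):
+    #
+    #     from sqlalchemy import String
+    #     from sqlalchemy.dialects import mysql
+    #
+    #     String(50).compile(dialect=mysql.dialect())   # "VARCHAR(50)"
+    #     String(50).compile()                          # default dialect
+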
+ @util.preload_module("sqlalchemy.engine.default")
+ def _default_dialect(self):
+ default = util.preloaded.engine_default
+ return default.StrCompileDialect()
+
+ def __str__(self):
+ if util.py2k:
+ return unicode(self.compile()).encode( # noqa
+ "ascii", "backslashreplace"
+ ) # noqa
+ else:
+ return str(self.compile())
+
+ def __repr__(self):
+ return util.generic_repr(self)
+
+
+class VisitableCheckKWArg(util.EnsureKWArgType, TraversibleType):
+ pass
+
+
+class ExternalType(object):
+ """mixin that defines attributes and behaviors specific to third-party
+ datatypes.
+
+ "Third party" refers to datatypes that are defined outside the scope
+ of SQLAlchemy within either end-user application code or within
+ external extensions to SQLAlchemy.
+
+ Subclasses currently include :class:`.TypeDecorator` and
+ :class:`.UserDefinedType`.
+
+ .. versionadded:: 1.4.28
+
+ """
+
+ cache_ok = None
+ """Indicate if statements using this :class:`.ExternalType` are "safe to
+ cache".
+
+ The default value ``None`` will emit a warning and then not allow caching
+ of a statement which includes this type. Set to ``False`` to disable
+ statements using this type from being cached at all without a warning.
+ When set to ``True``, the object's class and selected elements from its
+ state will be used as part of the cache key. For example, using a
+ :class:`.TypeDecorator`::
+
+ class MyType(TypeDecorator):
+ impl = String
+
+ cache_ok = True
+
+ def __init__(self, choices):
+ self.choices = tuple(choices)
+ self.internal_only = True
+
+ The cache key for the above type would be equivalent to::
+
+ >>> MyType(["a", "b", "c"])._static_cache_key
+ (<class '__main__.MyType'>, ('choices', ('a', 'b', 'c')))
+
+ The caching scheme will extract attributes from the type that correspond
+ to the names of parameters in the ``__init__()`` method. Above, the
+ "choices" attribute becomes part of the cache key but "internal_only"
+ does not, because there is no parameter named "internal_only".
+
+    The requirements for cacheable elements are that they are hashable
+ and also that they indicate the same SQL rendered for expressions using
+ this type every time for a given cache value.
+
+    To accommodate datatypes that refer to unhashable structures such
+ as dictionaries, sets and lists, these objects can be made "cacheable"
+ by assigning hashable structures to the attributes whose names
+ correspond with the names of the arguments. For example, a datatype
+ which accepts a dictionary of lookup values may publish this as a sorted
+ series of tuples. Given a previously un-cacheable type as::
+
+ class LookupType(UserDefinedType):
+ '''a custom type that accepts a dictionary as a parameter.
+
+ this is the non-cacheable version, as "self.lookup" is not
+ hashable.
+
+ '''
+
+ def __init__(self, lookup):
+ self.lookup = lookup
+
+ def get_col_spec(self, **kw):
+ return "VARCHAR(255)"
+
+ def bind_processor(self, dialect):
+ # ... works with "self.lookup" ...
+
+ Where "lookup" is a dictionary. The type will not be able to generate
+ a cache key::
+
+ >>> type_ = LookupType({"a": 10, "b": 20})
+ >>> type_._static_cache_key
+ <stdin>:1: SAWarning: UserDefinedType LookupType({'a': 10, 'b': 20}) will not
+ produce a cache key because the ``cache_ok`` flag is not set to True.
+ Set this flag to True if this type object's state is safe to use
+ in a cache key, or False to disable this warning.
+ symbol('no_cache')
+
+ If we **did** set up such a cache key, it wouldn't be usable. We would
+ get a tuple structure that contains a dictionary inside of it, which
+ cannot itself be used as a key in a "cache dictionary" such as SQLAlchemy's
+ statement cache, since Python dictionaries aren't hashable::
+
+ >>> # set cache_ok = True
+ >>> type_.cache_ok = True
+
+ >>> # this is the cache key it would generate
+ >>> key = type_._static_cache_key
+ >>> key
+ (<class '__main__.LookupType'>, ('lookup', {'a': 10, 'b': 20}))
+
+ >>> # however this key is not hashable, will fail when used with
+ >>> # SQLAlchemy statement cache
+ >>> some_cache = {key: "some sql value"}
+ Traceback (most recent call last): File "<stdin>", line 1,
+ in <module> TypeError: unhashable type: 'dict'
+
+ The type may be made cacheable by assigning a sorted tuple of tuples
+ to the ".lookup" attribute::
+
+ class LookupType(UserDefinedType):
+ '''a custom type that accepts a dictionary as a parameter.
+
+ The dictionary is stored both as itself in a private variable,
+ and published in a public variable as a sorted tuple of tuples,
+ which is hashable and will also return the same value for any
+ two equivalent dictionaries. Note it assumes the keys and
+ values of the dictionary are themselves hashable.
+
+ '''
+
+ cache_ok = True
+
+ def __init__(self, lookup):
+ self._lookup = lookup
+
+ # assume keys/values of "lookup" are hashable; otherwise
+ # they would also need to be converted in some way here
+ self.lookup = tuple(
+ (key, lookup[key]) for key in sorted(lookup)
+ )
+
+ def get_col_spec(self, **kw):
+ return "VARCHAR(255)"
+
+ def bind_processor(self, dialect):
+ # ... works with "self._lookup" ...
+
+ Where above, the cache key for ``LookupType({"a": 10, "b": 20})`` will be::
+
+ >>> LookupType({"a": 10, "b": 20})._static_cache_key
+ (<class '__main__.LookupType'>, ('lookup', (('a', 10), ('b', 20))))
+
+ .. versionadded:: 1.4.14 - added the ``cache_ok`` flag to allow
+ some configurability of caching for :class:`.TypeDecorator` classes.
+
+ .. versionadded:: 1.4.28 - added the :class:`.ExternalType` mixin which
+ generalizes the ``cache_ok`` flag to both the :class:`.TypeDecorator`
+ and :class:`.UserDefinedType` classes.
+
+ .. seealso::
+
+ :ref:`sql_caching`
+
+ """ # noqa: E501
+
+ @property
+ def _static_cache_key(self):
+ cache_ok = self.__class__.__dict__.get("cache_ok", None)
+
+ if cache_ok is None:
+ subtype_idx = self.__class__.__mro__.index(ExternalType)
+ subtype = self.__class__.__mro__[max(subtype_idx - 1, 0)]
+
+ util.warn(
+ "%s %r will not produce a cache key because "
+ "the ``cache_ok`` attribute is not set to True. This can "
+ "have significant performance implications including some "
+ "performance degradations in comparison to prior SQLAlchemy "
+ "versions. Set this attribute to True if this type object's "
+ "state is safe to use in a cache key, or False to "
+ "disable this warning." % (subtype.__name__, self),
+ code="cprf",
+ )
+ elif cache_ok is True:
+ return super(ExternalType, self)._static_cache_key
+
+ return NO_CACHE
+
+
+class UserDefinedType(
+ util.with_metaclass(VisitableCheckKWArg, ExternalType, TypeEngine)
+):
+ """Base for user defined types.
+
+ This should be the base of new types. Note that
+ for most cases, :class:`.TypeDecorator` is probably
+ more appropriate::
+
+ import sqlalchemy.types as types
+
+ class MyType(types.UserDefinedType):
+ cache_ok = True
+
+ def __init__(self, precision = 8):
+ self.precision = precision
+
+ def get_col_spec(self, **kw):
+ return "MYTYPE(%s)" % self.precision
+
+ def bind_processor(self, dialect):
+ def process(value):
+ return value
+ return process
+
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ return value
+ return process
+
+ Once the type is made, it's immediately usable::
+
+ table = Table('foo', metadata_obj,
+ Column('id', Integer, primary_key=True),
+ Column('data', MyType(16))
+ )
+
+ The ``get_col_spec()`` method will in most cases receive a keyword
+ argument ``type_expression`` which refers to the owning expression
+ of the type as being compiled, such as a :class:`_schema.Column` or
+ :func:`.cast` construct. This keyword is only sent if the method
+ accepts keyword arguments (e.g. ``**kw``) in its argument signature;
+ introspection is used to check for this in order to support legacy
+ forms of this function.
+
+ .. versionadded:: 1.0.0 the owning expression is passed to
+ the ``get_col_spec()`` method via the keyword argument
+ ``type_expression``, if it receives ``**kw`` in its signature.
+
+ The :attr:`.UserDefinedType.cache_ok` class-level flag indicates if this
+ custom :class:`.UserDefinedType` is safe to be used as part of a cache key.
+ This flag defaults to ``None`` which will initially generate a warning
+ when the SQL compiler attempts to generate a cache key for a statement
+ that uses this type. If the :class:`.UserDefinedType` is not guaranteed
+ to produce the same bind/result behavior and SQL generation
+ every time, this flag should be set to ``False``; otherwise if the
+ class produces the same behavior each time, it may be set to ``True``.
+ See :attr:`.UserDefinedType.cache_ok` for further notes on how this works.
+
+ .. versionadded:: 1.4.28 Generalized the :attr:`.ExternalType.cache_ok`
+ flag so that it is available for both :class:`.TypeDecorator` as well
+ as :class:`.UserDefinedType`.
+
+ """
+
+ __visit_name__ = "user_defined"
+
+ ensure_kwarg = "get_col_spec"
+
+ def coerce_compared_value(self, op, value):
+ """Suggest a type for a 'coerced' Python value in an expression.
+
+ Default behavior for :class:`.UserDefinedType` is the
+ same as that of :class:`.TypeDecorator`; by default it returns
+ ``self``, assuming the compared value should be coerced into
+ the same type as this one. See
+ :meth:`.TypeDecorator.coerce_compared_value` for more detail.
+
+ """
+
+ return self
+
+
+class Emulated(object):
+ """Mixin for base types that emulate the behavior of a DB-native type.
+
+ An :class:`.Emulated` type will use an available database type
+ in conjunction with Python-side routines and/or database constraints
+ in order to approximate the behavior of a database type that is provided
+ natively by some backends. When a native-providing backend is in
+ use, the native version of the type is used. This native version
+ should include the :class:`.NativeForEmulated` mixin to allow it to be
+ distinguished from :class:`.Emulated`.
+
+ Current examples of :class:`.Emulated` are: :class:`.Interval`,
+ :class:`.Enum`, :class:`.Boolean`.
+
+ .. versionadded:: 1.2.0b3
+
+ """
+
+ def adapt_to_emulated(self, impltype, **kw):
+ """Given an impl class, adapt this type to the impl assuming
+ "emulated".
+
+ The impl should also be an "emulated" version of this type,
+ most likely the same class as this type itself.
+
+ e.g.: sqltypes.Enum adapts to the Enum class.
+
+ """
+ return super(Emulated, self).adapt(impltype, **kw)
+
+ def adapt(self, impltype, **kw):
+ if hasattr(impltype, "adapt_emulated_to_native"):
+ if self.native:
+ # native support requested, dialect gave us a native
+ # implementor, pass control over to it
+ return impltype.adapt_emulated_to_native(self, **kw)
+ else:
+ # non-native support, let the native implementor
+ # decide also, at the moment this is just to help debugging
+ # as only the default logic is implemented.
+ return impltype.adapt_native_to_emulated(self, **kw)
+ else:
+ if issubclass(impltype, self.__class__):
+ return self.adapt_to_emulated(impltype, **kw)
+ else:
+ return super(Emulated, self).adapt(impltype, **kw)
+
+
+class NativeForEmulated(object):
+ """Indicates DB-native types supported by an :class:`.Emulated` type.
+
+ .. versionadded:: 1.2.0b3
+
+ """
+
+ @classmethod
+ def adapt_native_to_emulated(cls, impl, **kw):
+ """Given an impl, adapt this type's class to the impl assuming
+ "emulated".
+
+
+ """
+ impltype = impl.__class__
+ return impl.adapt(impltype, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Given an impl, adapt this type's class to the impl assuming
+ "native".
+
+ The impl will be an :class:`.Emulated` class but not a
+ :class:`.NativeForEmulated`.
+
+ e.g.: postgresql.ENUM produces a type given an Enum instance.
+
+ """
+ return cls(**kw)
+
+
+class TypeDecorator(ExternalType, SchemaEventTarget, TypeEngine):
+ """Allows the creation of types which add additional functionality
+ to an existing type.
+
+ This method is preferred to direct subclassing of SQLAlchemy's
+ built-in types as it ensures that all required functionality of
+ the underlying type is kept in place.
+
+ Typical usage::
+
+ import sqlalchemy.types as types
+
+ class MyType(types.TypeDecorator):
+ '''Prefixes Unicode values with "PREFIX:" on the way in and
+ strips it off on the way out.
+ '''
+
+ impl = types.Unicode
+
+ cache_ok = True
+
+ def process_bind_param(self, value, dialect):
+ return "PREFIX:" + value
+
+ def process_result_value(self, value, dialect):
+ return value[7:]
+
+ def copy(self, **kw):
+ return MyType(self.impl.length)
+
+ The class-level ``impl`` attribute is required, and can reference any
+ :class:`.TypeEngine` class. Alternatively, the :meth:`load_dialect_impl`
+ method can be used to provide different type classes based on the dialect
+ given; in this case, the ``impl`` variable can reference
+ ``TypeEngine`` as a placeholder.
+
+ The :attr:`.TypeDecorator.cache_ok` class-level flag indicates if this
+ custom :class:`.TypeDecorator` is safe to be used as part of a cache key.
+ This flag defaults to ``None`` which will initially generate a warning
+ when the SQL compiler attempts to generate a cache key for a statement
+ that uses this type. If the :class:`.TypeDecorator` is not guaranteed
+ to produce the same bind/result behavior and SQL generation
+ every time, this flag should be set to ``False``; otherwise if the
+ class produces the same behavior each time, it may be set to ``True``.
+ See :attr:`.TypeDecorator.cache_ok` for further notes on how this works.
+
+ Types that receive a Python type that isn't similar to the ultimate type
+ used may want to define the :meth:`TypeDecorator.coerce_compared_value`
+ method. This is used to give the expression system a hint when coercing
+ Python objects into bind parameters within expressions. Consider this
+ expression::
+
+ mytable.c.somecol + datetime.date(2009, 5, 15)
+
+ Above, if "somecol" is an ``Integer`` variant, it makes sense that
+ we're doing date arithmetic, where above is usually interpreted
+ by databases as adding a number of days to the given date.
+ The expression system does the right thing by not attempting to
+ coerce the "date()" value into an integer-oriented bind parameter.
+
+ However, in the case of ``TypeDecorator``, we are usually changing an
+ incoming Python type to something new - ``TypeDecorator`` by default will
+ "coerce" the non-typed side to be the same type as itself. Such as below,
+ we define an "epoch" type that stores a date value as an integer::
+
+ class MyEpochType(types.TypeDecorator):
+ impl = types.Integer
+
+ epoch = datetime.date(1970, 1, 1)
+
+ def process_bind_param(self, value, dialect):
+ return (value - self.epoch).days
+
+ def process_result_value(self, value, dialect):
+ return self.epoch + timedelta(days=value)
+
+ Our expression of ``somecol + date`` with the above type will coerce the
+ "date" on the right side to also be treated as ``MyEpochType``.
+
+ This behavior can be overridden via the
+ :meth:`~TypeDecorator.coerce_compared_value` method, which returns a type
+ that should be used for the value of the expression. Below we set it such
+ that an integer value will be treated as an ``Integer``, and any other
+ value is assumed to be a date and will be treated as a ``MyEpochType``::
+
+ def coerce_compared_value(self, op, value):
+ if isinstance(value, int):
+ return Integer()
+ else:
+ return self
+
+ .. warning::
+
+ Note that the **behavior of coerce_compared_value is not inherited
+ by default from that of the base type**.
+ If the :class:`.TypeDecorator` is augmenting a
+ type that requires special logic for certain types of operators,
+ this method **must** be overridden. A key example is when decorating
+ the :class:`_postgresql.JSON` and :class:`_postgresql.JSONB` types;
+ the default rules of :meth:`.TypeEngine.coerce_compared_value` should
+ be used in order to deal with operators like index operations::
+
+ from sqlalchemy import JSON
+ from sqlalchemy import TypeDecorator
+
+ class MyJsonType(TypeDecorator):
+ impl = JSON
+
+ cache_ok = True
+
+ def coerce_compared_value(self, op, value):
+ return self.impl.coerce_compared_value(op, value)
+
+ Without the above step, index operations such as ``mycol['foo']``
+ will cause the index value ``'foo'`` to be JSON encoded.
+
+ Similarly, when working with the :class:`.ARRAY` datatype, the
+ type coercion for index operations (e.g. ``mycol[5]``) is also
+ handled by :meth:`.TypeDecorator.coerce_compared_value`, where
+ again a simple override is sufficient unless special rules are needed
+ for particular operators::
+
+ from sqlalchemy import ARRAY
+ from sqlalchemy import TypeDecorator
+
+ class MyArrayType(TypeDecorator):
+ impl = ARRAY
+
+ cache_ok = True
+
+ def coerce_compared_value(self, op, value):
+ return self.impl.coerce_compared_value(op, value)
+
+
+ """
+
+ __visit_name__ = "type_decorator"
+
+ _is_type_decorator = True
+
+ def __init__(self, *args, **kwargs):
+ """Construct a :class:`.TypeDecorator`.
+
+ Arguments sent here are passed to the constructor
+ of the class assigned to the ``impl`` class level attribute,
+ assuming the ``impl`` is a callable, and the resulting
+ object is assigned to the ``self.impl`` instance attribute
+ (thus overriding the class attribute of the same name).
+
+ If the class level ``impl`` is not a callable (the unusual case),
+ it will be assigned to the same instance attribute 'as-is',
+ ignoring those arguments passed to the constructor.
+
+ Subclasses can override this to customize the generation
+ of ``self.impl`` entirely.
+
+ """
+
+ if not hasattr(self.__class__, "impl"):
+ raise AssertionError(
+ "TypeDecorator implementations "
+ "require a class-level variable "
+ "'impl' which refers to the class of "
+ "type being decorated"
+ )
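+        # e.g. with ``impl = types.Unicode``, ``MyType(100)`` results in
+        # ``self.impl`` being ``Unicode(100)`` (illustrative)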
+ self.impl = to_instance(self.__class__.impl, *args, **kwargs)
+
+ coerce_to_is_types = (util.NoneType,)
+ """Specify those Python types which should be coerced at the expression
+ level to "IS <constant>" when compared using ``==`` (and same for
+ ``IS NOT`` in conjunction with ``!=``).
+
+ For most SQLAlchemy types, this includes ``NoneType``, as well as
+ ``bool``.
+
+ :class:`.TypeDecorator` modifies this list to only include ``NoneType``,
+ as typedecorator implementations that deal with boolean types are common.
+
+ Custom :class:`.TypeDecorator` classes can override this attribute to
+ return an empty tuple, in which case no values will be coerced to
+ constants.
+
+ """
+
+ class Comparator(TypeEngine.Comparator):
+ """A :class:`.TypeEngine.Comparator` that is specific to
+ :class:`.TypeDecorator`.
+
+ User-defined :class:`.TypeDecorator` classes should not typically
+ need to modify this.
+
+
+ """
+
+ __slots__ = ()
+
+ def operate(self, op, *other, **kwargs):
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
+ return super(TypeDecorator.Comparator, self).operate(
+ op, *other, **kwargs
+ )
+
+ def reverse_operate(self, op, other, **kwargs):
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
+ return super(TypeDecorator.Comparator, self).reverse_operate(
+ op, other, **kwargs
+ )
+
+ @property
+ def comparator_factory(self):
+ if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:
+ return self.impl.comparator_factory
+ else:
+ return type(
+ "TDComparator",
+ (TypeDecorator.Comparator, self.impl.comparator_factory),
+ {},
+ )
+
+ def _gen_dialect_impl(self, dialect):
+ """
+ #todo
+ """
+ adapted = dialect.type_descriptor(self)
+ if adapted is not self:
+ return adapted
+
+ # otherwise adapt the impl type, link
+ # to a copy of this TypeDecorator and return
+ # that.
+ typedesc = self.load_dialect_impl(dialect).dialect_impl(dialect)
+ tt = self.copy()
+ if not isinstance(tt, self.__class__):
+ raise AssertionError(
+ "Type object %s does not properly "
+ "implement the copy() method, it must "
+ "return an object of type %s" % (self, self.__class__)
+ )
+ tt.impl = typedesc
+ return tt
+
+ @property
+ def _type_affinity(self):
+ """
+ #todo
+ """
+ return self.impl._type_affinity
+
+ def _set_parent(self, column, outer=False, **kw):
+ """Support SchemaEventTarget"""
+
+ super(TypeDecorator, self)._set_parent(column)
+
+ if not outer and isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent(column, outer=False, **kw)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ super(TypeDecorator, self)._set_parent_with_dispatch(
+ parent, outer=True
+ )
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent_with_dispatch(parent)
+
+ def type_engine(self, dialect):
+ """Return a dialect-specific :class:`.TypeEngine` instance
+ for this :class:`.TypeDecorator`.
+
+ In most cases this returns a dialect-adapted form of
+ the :class:`.TypeEngine` type represented by ``self.impl``.
+ Makes usage of :meth:`dialect_impl`.
+ Behavior can be customized here by overriding
+ :meth:`load_dialect_impl`.
+
+ """
+ adapted = dialect.type_descriptor(self)
+ if not isinstance(adapted, type(self)):
+ return adapted
+ else:
+ return self.load_dialect_impl(dialect)
+
+ def load_dialect_impl(self, dialect):
+ """Return a :class:`.TypeEngine` object corresponding to a dialect.
+
+ This is an end-user override hook that can be used to provide
+ differing types depending on the given dialect. It is used
+ by the :class:`.TypeDecorator` implementation of :meth:`type_engine`
+ to help determine what type should ultimately be returned
+ for a given :class:`.TypeDecorator`.
+
+ By default returns ``self.impl``.
+
+ """
+ return self.impl
+
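+    # A common shape for this override, adapted from the well-known
+    # backend-agnostic GUID recipe (names here are illustrative): use the
+    # native PostgreSQL UUID type on that backend, CHAR(32) elsewhere.
+    #
+    #     from sqlalchemy.types import CHAR, TypeDecorator
+    #     from sqlalchemy.dialects.postgresql import UUID
+    #
+    #     class GUID(TypeDecorator):
+    #         impl = CHAR
+    #         cache_ok = True
+    #
+    #         def load_dialect_impl(self, dialect):
+    #             if dialect.name == "postgresql":
+    #                 return dialect.type_descriptor(UUID())
+    #             else:
+    #                 return dialect.type_descriptor(CHAR(32))
+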
+ def _unwrapped_dialect_impl(self, dialect):
+ """Return the 'unwrapped' dialect impl for this type.
+
+ This is used by the :meth:`.DefaultDialect.set_input_sizes`
+ method.
+
+ """
+ # some dialects have a lookup for a TypeDecorator subclass directly.
+ # postgresql.INTERVAL being the main example
+ typ = self.dialect_impl(dialect)
+
+ # if we are still a type decorator, load the per-dialect switch
+ # (such as what Variant uses), then get the dialect impl for that.
+ if isinstance(typ, self.__class__):
+ return typ.load_dialect_impl(dialect).dialect_impl(dialect)
+ else:
+ return typ
+
+ def __getattr__(self, key):
+ """Proxy all other undefined accessors to the underlying
+ implementation."""
+ return getattr(self.impl, key)
+
+ def process_literal_param(self, value, dialect):
+ """Receive a literal parameter value to be rendered inline within
+ a statement.
+
+ .. note::
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. Unlike other SQL
+ compilation methods, it is passed a specific Python value to be
+ rendered as a string. However it should not be confused with the
+ :meth:`_types.TypeDecorator.process_bind_param` method, which is
+ the more typical method that processes the actual value passed to a
+ particular parameter at statement execution time.
+
+ Custom subclasses of :class:`_types.TypeDecorator` should override
+ this method to provide custom behaviors for incoming data values
+ that are in the special case of being rendered as literals.
+
+ The returned string will be rendered into the output string.
+
+ """
+ raise NotImplementedError()
+
+ def process_bind_param(self, value, dialect):
+ """Receive a bound parameter value to be converted.
+
+ Custom subclasses of :class:`_types.TypeDecorator` should override
+ this method to provide custom behaviors for incoming data values.
+ This method is called at **statement execution time** and is passed
+ the literal Python data value which is to be associated with a bound
+ parameter in the statement.
+
+ The operation could be anything desired to perform custom
+ behavior, such as transforming or serializing data.
+ This could also be used as a hook for validating logic.
+
+ :param value: Data to operate upon, of any type expected by
+ this method in the subclass. Can be ``None``.
+ :param dialect: the :class:`.Dialect` in use.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ :meth:`_types.TypeDecorator.process_result_value`
+
+ """
+
+ raise NotImplementedError()
+
+ def process_result_value(self, value, dialect):
+ """Receive a result-row column value to be converted.
+
+ Custom subclasses of :class:`_types.TypeDecorator` should override
+ this method to provide custom behaviors for data values
+ being received in result rows coming from the database.
+ This method is called at **result fetching time** and is passed
+ the literal Python data value that's extracted from a database result
+ row.
+
+ The operation could be anything desired to perform custom
+ behavior, such as transforming or deserializing data.
+
+ :param value: Data to operate upon, of any type expected by
+ this method in the subclass. Can be ``None``.
+ :param dialect: the :class:`.Dialect` in use.
+
+ .. seealso::
+
+ :ref:`types_typedecorator`
+
+ :meth:`_types.TypeDecorator.process_bind_param`
+
+
+ """
+
+ raise NotImplementedError()
+
+ @util.memoized_property
+ def _has_bind_processor(self):
+ """memoized boolean, check if process_bind_param is implemented.
+
+ Allows the base process_bind_param to raise
+ NotImplementedError without needing to test an expensive
+ exception throw.
+
+ """
+
+ return util.method_is_overridden(
+ self, TypeDecorator.process_bind_param
+ )
+
+ @util.memoized_property
+ def _has_literal_processor(self):
+ """memoized boolean, check if process_literal_param is implemented."""
+
+ return util.method_is_overridden(
+ self, TypeDecorator.process_literal_param
+ )
+
+ def literal_processor(self, dialect):
+ """Provide a literal processing function for the given
+ :class:`.Dialect`.
+
+ This is the method that fulfills the :class:`.TypeEngine`
+ contract for literal value conversion which normally occurs via
+ the :meth:`_types.TypeEngine.literal_processor` method.
+
+ .. note::
+
+ User-defined subclasses of :class:`_types.TypeDecorator` should
+ **not** implement this method, and should instead implement
+ :meth:`_types.TypeDecorator.process_literal_param` so that the
+ "inner" processing provided by the implementing type is maintained.
+
+ """
+ if self._has_literal_processor:
+ process_param = self.process_literal_param
+ elif self._has_bind_processor:
+            # the bind processor should normally be OK for TypeDecorator;
+            # since it isn't doing DB-level handling, the processing here
+            # won't be different for bound values vs. literals.
+ process_param = self.process_bind_param
+ else:
+ process_param = None
+
+ if process_param:
+ impl_processor = self.impl.literal_processor(dialect)
+ if impl_processor:
+
+ def process(value):
+ return impl_processor(process_param(value, dialect))
+
+ else:
+
+ def process(value):
+ return process_param(value, dialect)
+
+ return process
+ else:
+ return self.impl.literal_processor(dialect)
+
+ def bind_processor(self, dialect):
+ """Provide a bound value processing function for the
+ given :class:`.Dialect`.
+
+ This is the method that fulfills the :class:`.TypeEngine`
+ contract for bound value conversion which normally occurs via
+ the :meth:`_types.TypeEngine.bind_processor` method.
+
+ .. note::
+
+ User-defined subclasses of :class:`_types.TypeDecorator` should
+ **not** implement this method, and should instead implement
+ :meth:`_types.TypeDecorator.process_bind_param` so that the "inner"
+ processing provided by the implementing type is maintained.
+
+ :param dialect: Dialect instance in use.
+
+ """
+ if self._has_bind_processor:
+ process_param = self.process_bind_param
+ impl_processor = self.impl.bind_processor(dialect)
+ if impl_processor:
+
+ def process(value):
+ return impl_processor(process_param(value, dialect))
+
+ else:
+
+ def process(value):
+ return process_param(value, dialect)
+
+ return process
+ else:
+ return self.impl.bind_processor(dialect)
+
+ @util.memoized_property
+ def _has_result_processor(self):
+ """memoized boolean, check if process_result_value is implemented.
+
+ Allows the base process_result_value to raise
+ NotImplementedError without needing to test an expensive
+ exception throw.
+
+ """
+
+ return util.method_is_overridden(
+ self, TypeDecorator.process_result_value
+ )
+
+ def result_processor(self, dialect, coltype):
+ """Provide a result value processing function for the given
+ :class:`.Dialect`.
+
+ This is the method that fulfills the :class:`.TypeEngine`
+ contract for result value conversion which normally occurs via
+ the :meth:`_types.TypeEngine.result_processor` method.
+
+ .. note::
+
+ User-defined subclasses of :class:`_types.TypeDecorator` should
+ **not** implement this method, and should instead implement
+ :meth:`_types.TypeDecorator.process_result_value` so that the
+ "inner" processing provided by the implementing type is maintained.
+
+ :param dialect: Dialect instance in use.
+ :param coltype: A SQLAlchemy data type
+
+ """
+ if self._has_result_processor:
+ process_value = self.process_result_value
+ impl_processor = self.impl.result_processor(dialect, coltype)
+ if impl_processor:
+
+ def process(value):
+ return process_value(impl_processor(value), dialect)
+
+ else:
+
+ def process(value):
+ return process_value(value, dialect)
+
+ return process
+ else:
+ return self.impl.result_processor(dialect, coltype)
+
+ @util.memoized_property
+ def _has_bind_expression(self):
+
+ return (
+ util.method_is_overridden(self, TypeDecorator.bind_expression)
+ or self.impl._has_bind_expression
+ )
+
+ def bind_expression(self, bindparam):
+ """Given a bind value (i.e. a :class:`.BindParameter` instance),
+ return a SQL expression which will typically wrap the given parameter.
+
+ .. note::
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** necessarily
+ called against specific values, and should not be confused with the
+ :meth:`_types.TypeDecorator.process_bind_param` method, which is
+ the more typical method that processes the actual value passed to a
+ particular parameter at statement execution time.
+
+ Subclasses of :class:`_types.TypeDecorator` can override this method
+ to provide custom bind expression behavior for the type. This
+ implementation will **replace** that of the underlying implementation
+ type.
+
+ """
+ return self.impl.bind_expression(bindparam)
+
+ @util.memoized_property
+ def _has_column_expression(self):
+ """memoized boolean, check if column_expression is implemented.
+
+ Allows the method to be skipped for the vast majority of expression
+ types that don't use this feature.
+
+ """
+
+ return (
+ util.method_is_overridden(self, TypeDecorator.column_expression)
+ or self.impl._has_column_expression
+ )
+
+ def column_expression(self, column):
+ """Given a SELECT column expression, return a wrapping SQL expression.
+
+ .. note::
+
+ This method is called during the **SQL compilation** phase of a
+ statement, when rendering a SQL string. It is **not** called
+ against specific values, and should not be confused with the
+ :meth:`_types.TypeDecorator.process_result_value` method, which is
+ the more typical method that processes the actual value returned
+ in a result row subsequent to statement execution time.
+
+ Subclasses of :class:`_types.TypeDecorator` can override this method
+ to provide custom column expression behavior for the type. This
+ implementation will **replace** that of the underlying implementation
+ type.
+
+ See the description of :meth:`_types.TypeEngine.column_expression`
+ for a complete description of the method's use.
+
+ """
+
+ return self.impl.column_expression(column)
+
+ def coerce_compared_value(self, op, value):
+ """Suggest a type for a 'coerced' Python value in an expression.
+
+ By default, returns self. This method is called by
+ the expression system when an object using this type is
+ on the left or right side of an expression against a plain Python
+ object which does not yet have a SQLAlchemy type assigned::
+
+ expr = table.c.somecolumn + 35
+
+ Where above, if ``somecolumn`` uses this type, this method will
+ be called with ``operator.add`` as the operator
+ and ``35`` as the value. The return value is whatever SQLAlchemy type should
+ be used for ``35`` for this particular operation.
+
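+ A subclass that wraps a richer implementation type will often want to
+ delegate this decision to that type rather than returning ``self``; a
+ brief sketch of such an override, shown for a hypothetical subclass::
+
+     def coerce_compared_value(self, op, value):
+         return self.impl.coerce_compared_value(op, value)
+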
+ """
+ return self
+
+ def copy(self, **kw):
+ """Produce a copy of this :class:`.TypeDecorator` instance.
+
+ This is a shallow copy and is provided to fulfill part of
+ the :class:`.TypeEngine` contract. It usually does not
+ need to be overridden unless the user-defined :class:`.TypeDecorator`
+ has local state that should be deep-copied.
+
+ """
+
+ instance = self.__class__.__new__(self.__class__)
+ instance.__dict__.update(self.__dict__)
+ return instance
+
+ def get_dbapi_type(self, dbapi):
+ """Return the DBAPI type object represented by this
+ :class:`.TypeDecorator`.
+
+ By default this calls upon :meth:`.TypeEngine.get_dbapi_type` of the
+ underlying "impl".
+ """
+ return self.impl.get_dbapi_type(dbapi)
+
+ def compare_values(self, x, y):
+ """Given two values, compare them for equality.
+
+ By default this calls upon :meth:`.TypeEngine.compare_values`
+ of the underlying "impl", which in turn usually
+ uses the Python equals operator ``==``.
+
+ This function is used by the ORM to compare
+ an original-loaded value with an intercepted
+ "changed" value, to determine if a net change
+ has occurred.
+
+ """
+ return self.impl.compare_values(x, y)
+
+ @property
+ def sort_key_function(self):
+ return self.impl.sort_key_function
+
+ def __repr__(self):
+ return util.generic_repr(self, to_inspect=self.impl)
+
+
+class Variant(TypeDecorator):
+ """A wrapping type that selects among a variety of
+ implementations based on dialect in use.
+
+ The :class:`.Variant` type is typically constructed
+ using the :meth:`.TypeEngine.with_variant` method.
+
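+ E.g., a sketch of typical construction; the MySQL-specific type shown is
+ only one possible choice::
+
+     from sqlalchemy import String
+     from sqlalchemy.dialects.mysql import LONGTEXT
+
+     string_type = String().with_variant(LONGTEXT, "mysql")
+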
+ .. seealso:: :meth:`.TypeEngine.with_variant` for an example of use.
+
+ """
+
+ cache_ok = True
+
+ def __init__(self, base, mapping):
+ """Construct a new :class:`.Variant`.
+
+ :param base: the base 'fallback' type
+ :param mapping: dictionary of string dialect names to
+ :class:`.TypeEngine` instances.
+
+ """
+ self.impl = base
+ self.mapping = mapping
+
+ @util.memoized_property
+ def _static_cache_key(self):
+ # TODO: needs tests in test/sql/test_compare.py
+ return (self.__class__,) + (
+ self.impl._static_cache_key,
+ tuple(
+ (key, self.mapping[key]._static_cache_key)
+ for key in sorted(self.mapping)
+ ),
+ )
+
+ def coerce_compared_value(self, operator, value):
+ result = self.impl.coerce_compared_value(operator, value)
+ if result is self.impl:
+ return self
+ else:
+ return result
+
+ def load_dialect_impl(self, dialect):
+ if dialect.name in self.mapping:
+ return self.mapping[dialect.name]
+ else:
+ return self.impl
+
+ def _set_parent(self, column, outer=False, **kw):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent(column, **kw)
+ for impl in self.mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent(column, **kw)
+
+ def _set_parent_with_dispatch(self, parent):
+ """Support SchemaEventTarget"""
+
+ if isinstance(self.impl, SchemaEventTarget):
+ self.impl._set_parent_with_dispatch(parent)
+ for impl in self.mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent_with_dispatch(parent)
+
+ def with_variant(self, type_, dialect_name):
+ r"""Return a new :class:`.Variant` which adds the given
+ type + dialect name to the mapping, in addition to the
+ mapping present in this :class:`.Variant`.
+
+ :param type\_: a :class:`.TypeEngine` that will be selected
+ as a variant from the originating type, when a dialect
+ of the given name is in use.
+ :param dialect_name: base name of the dialect which uses
+ this type. (i.e. ``'postgresql'``, ``'mysql'``, etc.)
+
+ """
+
+ if dialect_name in self.mapping:
+ raise exc.ArgumentError(
+ "Dialect '%s' is already present in "
+ "the mapping for this Variant" % dialect_name
+ )
+ mapping = self.mapping.copy()
+ mapping[dialect_name] = type_
+ return Variant(self.impl, mapping)
+
+ @property
+ def comparator_factory(self):
+ """express comparison behavior in terms of the base type"""
+ return self.impl.comparator_factory
+
+
+def _reconstitute_comparator(expression):
+ return expression.comparator
+
+
+def to_instance(typeobj, *arg, **kw):
+ if typeobj is None:
+ return NULLTYPE
+
+ if callable(typeobj):
+ return typeobj(*arg, **kw)
+ else:
+ return typeobj
+
+
+def adapt_type(typeobj, colspecs):
+ if isinstance(typeobj, type):
+ typeobj = typeobj()
+ for t in typeobj.__class__.__mro__[0:-1]:
+ try:
+ impltype = colspecs[t]
+ break
+ except KeyError:
+ pass
+ else:
+ # couldn't adapt - so just return the type itself
+ # (it may be a user-defined type)
+ return typeobj
+ # if we adapted the given generic type to a database-specific type,
+ # but it turns out the originally given "generic" type
+ # is actually a subclass of our resulting type, then we were already
+ # given a more specific type than that required; so use that.
+ if issubclass(typeobj.__class__, impltype):
+ return typeobj
+ return typeobj.adapt(impltype)
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
new file mode 100644
index 0000000..019b29e
--- /dev/null
+++ b/lib/sqlalchemy/sql/util.py
@@ -0,0 +1,1120 @@
+# sql/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""High level utilities which build upon other modules here.
+
+"""
+
+from collections import deque
+from itertools import chain
+
+from . import coercions
+from . import operators
+from . import roles
+from . import visitors
+from .annotation import _deep_annotate # noqa
+from .annotation import _deep_deannotate # noqa
+from .annotation import _shallow_annotate # noqa
+from .base import _expand_cloned
+from .base import _from_objects
+from .base import ColumnSet
+from .ddl import sort_tables # noqa
+from .elements import _find_columns # noqa
+from .elements import _label_reference
+from .elements import _textual_label_reference
+from .elements import BindParameter
+from .elements import ColumnClause
+from .elements import ColumnElement
+from .elements import Grouping
+from .elements import Label
+from .elements import Null
+from .elements import UnaryExpression
+from .schema import Column
+from .selectable import Alias
+from .selectable import FromClause
+from .selectable import FromGrouping
+from .selectable import Join
+from .selectable import ScalarSelect
+from .selectable import SelectBase
+from .selectable import TableClause
+from .traversals import HasCacheKey # noqa
+from .. import exc
+from .. import util
+
+
+join_condition = util.langhelpers.public_factory(
+ Join._join_condition, ".sql.util.join_condition"
+)
+
+
+def find_join_source(clauses, join_to):
+ """Given a list of FROM clauses and a selectable,
+ return the indexes from the list of
+ clauses which can be joined against the selectable. Returns
+ an empty list if no match is found.
+
+ e.g.::
+
+ clause1 = table1.join(table2)
+ clause2 = table4.join(table5)
+
+ join_to = table2.join(table3)
+
+ find_join_source([clause1, clause2], join_to) == [0]
+
+ """
+
+ selectables = list(_from_objects(join_to))
+ idx = []
+ for i, f in enumerate(clauses):
+ for s in selectables:
+ if f.is_derived_from(s):
+ idx.append(i)
+ return idx
+
+
+def find_left_clause_that_matches_given(clauses, join_from):
+ """Given a list of FROM clauses and a selectable,
+ return the indexes from the list of
+ clauses which is derived from the selectable.
+
+ """
+
+ selectables = list(_from_objects(join_from))
+ liberal_idx = []
+ for i, f in enumerate(clauses):
+ for s in selectables:
+ # basic check, if f is derived from s.
+ # this can be joins containing a table, or an aliased table
+ # or select statement matching to a table. This check
+ # will match a table to a selectable that is adapted from
+ # that table. With Query, this suits the case where a join
+ # is being made to an adapted entity
+ if f.is_derived_from(s):
+ liberal_idx.append(i)
+ break
+
+ # in an extremely small set of use cases, a join is being made where
+ # there are multiple FROM clauses where our target table is represented
+ # in more than one, such as embedded or similar. in this case, do
+ # another pass where we try to get a more exact match where we aren't
+ # looking at adaption relationships.
+ if len(liberal_idx) > 1:
+ conservative_idx = []
+ for idx in liberal_idx:
+ f = clauses[idx]
+ for s in selectables:
+ if set(surface_selectables(f)).intersection(
+ surface_selectables(s)
+ ):
+ conservative_idx.append(idx)
+ break
+ if conservative_idx:
+ return conservative_idx
+
+ return liberal_idx
+
+
+def find_left_clause_to_join_from(clauses, join_to, onclause):
+ """Given a list of FROM clauses, a selectable,
+ and optional ON clause, return a list of integer indexes from the
+ clauses list indicating the clauses that can be joined from.
+
+ The presence of an "onclause" indicates that at least one clause can
+ definitely be joined from; if the list of clauses is of length one
+ and the onclause is given, returns that index. If the list of clauses
+ is more than length one, and the onclause is given, attempts to locate
+ which clauses contain the same columns.
+
+ """
+ idx = []
+ selectables = set(_from_objects(join_to))
+
+ # if we are given more than one target clause to join
+ # from, use the onclause to provide a more specific answer.
+ # otherwise, don't try to limit, after all, "ON TRUE" is a valid
+ # on clause
+ if len(clauses) > 1 and onclause is not None:
+ resolve_ambiguity = True
+ cols_in_onclause = _find_columns(onclause)
+ else:
+ resolve_ambiguity = False
+ cols_in_onclause = None
+
+ for i, f in enumerate(clauses):
+ for s in selectables.difference([f]):
+ if resolve_ambiguity:
+ if set(f.c).union(s.c).issuperset(cols_in_onclause):
+ idx.append(i)
+ break
+ elif onclause is not None or Join._can_join(f, s):
+ idx.append(i)
+ break
+
+ if len(idx) > 1:
+ # this is the same "hide froms" logic from
+ # Selectable._get_display_froms
+ toremove = set(
+ chain(*[_expand_cloned(f._hide_froms) for f in clauses])
+ )
+ idx = [i for i in idx if clauses[i] not in toremove]
+
+ # onclause was given and none of them resolved, so assume
+ # all indexes can match
+ if not idx and onclause is not None:
+ return range(len(clauses))
+ else:
+ return idx
+
+
+def visit_binary_product(fn, expr):
+ """Produce a traversal of the given expression, delivering
+ column comparisons to the given function.
+
+ The function is of the form::
+
+ def my_fn(binary, left, right)
+
+ For each binary expression located which has a
+ comparison operator, the product of "left" and
+ "right" will be delivered to that function,
+ in terms of that binary.
+
+ Hence an expression like::
+
+ and_(
+ (a + b) == q + func.sum(e + f),
+ j == r
+ )
+
+ would have the traversal::
+
+ a <eq> q
+ a <eq> e
+ a <eq> f
+ b <eq> q
+ b <eq> e
+ b <eq> f
+ j <eq> r
+
+ That is, every combination of "left" and
+ "right" that doesn't further contain
+ a binary comparison is passed as pairs.
+
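+ A brief sketch of usage; the callback name and the printed format are
+ illustrative only::
+
+     def print_pair(binary, left, right):
+         print(left, binary.operator, right)
+
+     visit_binary_product(print_pair, expr)
+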
+ """
+ stack = []
+
+ def visit(element):
+ if isinstance(element, ScalarSelect):
+ # we don't want to dig into correlated subqueries,
+ # those are just column elements by themselves
+ yield element
+ elif element.__visit_name__ == "binary" and operators.is_comparison(
+ element.operator
+ ):
+ stack.insert(0, element)
+ for l in visit(element.left):
+ for r in visit(element.right):
+ fn(stack[0], l, r)
+ stack.pop(0)
+ for elem in element.get_children():
+ visit(elem)
+ else:
+ if isinstance(element, ColumnClause):
+ yield element
+ for elem in element.get_children():
+ for e in visit(elem):
+ yield e
+
+ list(visit(expr))
+ visit = None # remove gc cycles
+
+
+def find_tables(
+ clause,
+ check_columns=False,
+ include_aliases=False,
+ include_joins=False,
+ include_selects=False,
+ include_crud=False,
+):
+ """locate Table objects within the given expression."""
+
+ tables = []
+ _visitors = {}
+
+ if include_selects:
+ _visitors["select"] = _visitors["compound_select"] = tables.append
+
+ if include_joins:
+ _visitors["join"] = tables.append
+
+ if include_aliases:
+ _visitors["alias"] = _visitors["subquery"] = _visitors[
+ "tablesample"
+ ] = _visitors["lateral"] = tables.append
+
+ if include_crud:
+ _visitors["insert"] = _visitors["update"] = _visitors[
+ "delete"
+ ] = lambda ent: tables.append(ent.table)
+
+ if check_columns:
+
+ def visit_column(column):
+ tables.append(column.table)
+
+ _visitors["column"] = visit_column
+
+ _visitors["table"] = tables.append
+
+ visitors.traverse(clause, {}, _visitors)
+ return tables
+
+
+def unwrap_order_by(clause):
+ """Break up an 'order by' expression into individual column-expressions,
+ without DESC/ASC/NULLS FIRST/NULLS LAST"""
+
+ cols = util.column_set()
+ result = []
+ stack = deque([clause])
+
+ # examples
+ # column -> ASC/DESC == column
+ # column -> ASC/DESC -> label == column
+ # column -> label -> ASC/DESC -> label == column
+ # scalar_select -> label -> ASC/DESC == scalar_select -> label
+
+ while stack:
+ t = stack.popleft()
+ if isinstance(t, ColumnElement) and (
+ not isinstance(t, UnaryExpression)
+ or not operators.is_ordering_modifier(t.modifier)
+ ):
+ if isinstance(t, Label) and not isinstance(
+ t.element, ScalarSelect
+ ):
+ t = t.element
+
+ if isinstance(t, Grouping):
+ t = t.element
+
+ stack.append(t)
+ continue
+ elif isinstance(t, _label_reference):
+ t = t.element
+
+ stack.append(t)
+ continue
+ if isinstance(t, (_textual_label_reference)):
+ continue
+ if t not in cols:
+ cols.add(t)
+ result.append(t)
+
+ else:
+ for c in t.get_children():
+ stack.append(c)
+ return result
+
+
+def unwrap_label_reference(element):
+ def replace(elem):
+ if isinstance(elem, (_label_reference, _textual_label_reference)):
+ return elem.element
+
+ return visitors.replacement_traverse(element, {}, replace)
+
+
+def expand_column_list_from_order_by(collist, order_by):
+ """Given the columns clause and ORDER BY of a selectable,
+ return a list of column expressions that can be added to the collist
+ corresponding to the ORDER BY, without repeating those already
+ in the collist.
+
+ """
+ cols_already_present = set(
+ [
+ col.element if col._order_by_label_element is not None else col
+ for col in collist
+ ]
+ )
+
+ to_look_for = list(chain(*[unwrap_order_by(o) for o in order_by]))
+
+ return [col for col in to_look_for if col not in cols_already_present]
+
+
+def clause_is_present(clause, search):
+ """Given a target clause and a second to search within, return True
+ if the target is plainly present in the search without any
+ subqueries or aliases involved.
+
+ Basically descends through Joins.
+
+ """
+
+ for elem in surface_selectables(search):
+ if clause == elem: # use == here so that Annotated's comparison is used
+ return True
+ else:
+ return False
+
+
+def tables_from_leftmost(clause):
+ if isinstance(clause, Join):
+ for t in tables_from_leftmost(clause.left):
+ yield t
+ for t in tables_from_leftmost(clause.right):
+ yield t
+ elif isinstance(clause, FromGrouping):
+ for t in tables_from_leftmost(clause.element):
+ yield t
+ else:
+ yield clause
+
+
+def surface_selectables(clause):
+ stack = [clause]
+ while stack:
+ elem = stack.pop()
+ yield elem
+ if isinstance(elem, Join):
+ stack.extend((elem.left, elem.right))
+ elif isinstance(elem, FromGrouping):
+ stack.append(elem.element)
+
+
+def surface_selectables_only(clause):
+ stack = [clause]
+ while stack:
+ elem = stack.pop()
+ if isinstance(elem, (TableClause, Alias)):
+ yield elem
+ if isinstance(elem, Join):
+ stack.extend((elem.left, elem.right))
+ elif isinstance(elem, FromGrouping):
+ stack.append(elem.element)
+ elif isinstance(elem, ColumnClause):
+ if elem.table is not None:
+ stack.append(elem.table)
+ else:
+ yield elem
+ elif elem is not None:
+ yield elem
+
+
+def extract_first_column_annotation(column, annotation_name):
+ filter_ = (FromGrouping, SelectBase)
+
+ stack = deque([column])
+ while stack:
+ elem = stack.popleft()
+ if annotation_name in elem._annotations:
+ return elem._annotations[annotation_name]
+ for sub in elem.get_children():
+ if isinstance(sub, filter_):
+ continue
+ stack.append(sub)
+ return None
+
+
+def selectables_overlap(left, right):
+ """Return True if left/right have some overlapping selectable"""
+
+ return bool(
+ set(surface_selectables(left)).intersection(surface_selectables(right))
+ )
+
+
+def bind_values(clause):
+ """Return an ordered list of "bound" values in the given clause.
+
+ E.g.::
+
+ >>> expr = and_(
+ ... table.c.foo==5, table.c.foo==7
+ ... )
+ >>> bind_values(expr)
+ [5, 7]
+ """
+
+ v = []
+
+ def visit_bindparam(bind):
+ v.append(bind.effective_value)
+
+ visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
+ return v
+
+
+def _quote_ddl_expr(element):
+ if isinstance(element, util.string_types):
+ element = element.replace("'", "''")
+ return "'%s'" % element
+ else:
+ return repr(element)
+
+
+class _repr_base(object):
+ _LIST = 0
+ _TUPLE = 1
+ _DICT = 2
+
+ __slots__ = ("max_chars",)
+
+ def trunc(self, value):
+ rep = repr(value)
+ lenrep = len(rep)
+ if lenrep > self.max_chars:
+ segment_length = self.max_chars // 2
+ rep = (
+ rep[0:segment_length]
+ + (
+ " ... (%d characters truncated) ... "
+ % (lenrep - self.max_chars)
+ )
+ + rep[-segment_length:]
+ )
+ return rep
+
+
+class _repr_row(_repr_base):
+ """Provide a string view of a row."""
+
+ __slots__ = ("row",)
+
+ def __init__(self, row, max_chars=300):
+ self.row = row
+ self.max_chars = max_chars
+
+ def __repr__(self):
+ trunc = self.trunc
+ return "(%s%s)" % (
+ ", ".join(trunc(value) for value in self.row),
+ "," if len(self.row) == 1 else "",
+ )
+
+
+class _repr_params(_repr_base):
+ """Provide a string view of bound parameters.
+
+ Truncates display to a given number of 'multi' parameter sets,
+ as well as long values to a given number of characters.
+
+ """
+
+ __slots__ = "params", "batches", "ismulti"
+
+ def __init__(self, params, batches, max_chars=300, ismulti=None):
+ self.params = params
+ self.ismulti = ismulti
+ self.batches = batches
+ self.max_chars = max_chars
+
+ def __repr__(self):
+ if self.ismulti is None:
+ return self.trunc(self.params)
+
+ if isinstance(self.params, list):
+ typ = self._LIST
+
+ elif isinstance(self.params, tuple):
+ typ = self._TUPLE
+ elif isinstance(self.params, dict):
+ typ = self._DICT
+ else:
+ return self.trunc(self.params)
+
+ if self.ismulti and len(self.params) > self.batches:
+ msg = " ... displaying %i of %i total bound parameter sets ... "
+ return " ".join(
+ (
+ self._repr_multi(self.params[: self.batches - 2], typ)[
+ 0:-1
+ ],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(self.params[-2:], typ)[1:],
+ )
+ )
+ elif self.ismulti:
+ return self._repr_multi(self.params, typ)
+ else:
+ return self._repr_params(self.params, typ)
+
+ def _repr_multi(self, multi_params, typ):
+ if multi_params:
+ if isinstance(multi_params[0], list):
+ elem_type = self._LIST
+ elif isinstance(multi_params[0], tuple):
+ elem_type = self._TUPLE
+ elif isinstance(multi_params[0], dict):
+ elem_type = self._DICT
+ else:
+ assert False, "Unknown parameter type %s" % (
+ type(multi_params[0])
+ )
+
+ elements = ", ".join(
+ self._repr_params(params, elem_type) for params in multi_params
+ )
+ else:
+ elements = ""
+
+ if typ == self._LIST:
+ return "[%s]" % elements
+ else:
+ return "(%s)" % elements
+
+ def _repr_params(self, params, typ):
+ trunc = self.trunc
+ if typ is self._DICT:
+ return "{%s}" % (
+ ", ".join(
+ "%r: %s" % (key, trunc(value))
+ for key, value in params.items()
+ )
+ )
+ elif typ is self._TUPLE:
+ return "(%s%s)" % (
+ ", ".join(trunc(value) for value in params),
+ "," if len(params) == 1 else "",
+ )
+ else:
+ return "[%s]" % (", ".join(trunc(value) for value in params))
+
+
+def adapt_criterion_to_null(crit, nulls):
+ """given criterion containing bind params, convert selected elements
+ to IS NULL.
+
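+ E.g., a sketch of the effect; ``tab`` and the parameter key ``"x"`` are
+ illustrative::
+
+     crit = tab.c.x == bindparam("x")
+     null_crit = adapt_criterion_to_null(crit, {"x"})
+     # null_crit renders as "tab.x IS NULL"
+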
+ """
+
+ def visit_binary(binary):
+ if (
+ isinstance(binary.left, BindParameter)
+ and binary.left._identifying_key in nulls
+ ):
+ # reverse order if the NULL is on the left side
+ binary.left = binary.right
+ binary.right = Null()
+ binary.operator = operators.is_
+ binary.negate = operators.is_not
+ elif (
+ isinstance(binary.right, BindParameter)
+ and binary.right._identifying_key in nulls
+ ):
+ binary.right = Null()
+ binary.operator = operators.is_
+ binary.negate = operators.is_not
+
+ return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
+
+
+def splice_joins(left, right, stop_on=None):
+ if left is None:
+ return right
+
+ stack = [(right, None)]
+
+ adapter = ClauseAdapter(left)
+ ret = None
+ while stack:
+ (right, prevright) = stack.pop()
+ if isinstance(right, Join) and right is not stop_on:
+ right = right._clone()
+ right.onclause = adapter.traverse(right.onclause)
+ stack.append((right.left, right))
+ else:
+ right = adapter.traverse(right)
+ if prevright is not None:
+ prevright.left = right
+ if ret is None:
+ ret = right
+
+ return ret
+
+
+def reduce_columns(columns, *clauses, **kw):
+ r"""given a list of columns, return a 'reduced' set based on natural
+ equivalents.
+
+ the set is reduced to the smallest list of columns which have no natural
+ equivalent present in the list. A "natural equivalent" means that two
+ columns will ultimately represent the same value because they are related
+ by a foreign key.
+
+ \*clauses is an optional list of join clauses which will be traversed
+ to further identify columns that are "equivalent".
+
+ \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
+ whose tables are not yet configured, or columns that aren't yet present.
+
+ This function is primarily used to determine the most minimal "primary
+ key" from a selectable, by reducing the set of primary key columns present
+ in the selectable to just those that are not repeated.
+
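+ E.g., assuming ``child.c.parent_id`` has a foreign key reference to
+ ``parent.c.id`` (table and column names illustrative)::
+
+     reduce_columns([parent.c.id, child.c.parent_id])
+     # -> ColumnSet([parent.c.id])
+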
+ """
+ ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
+ only_synonyms = kw.pop("only_synonyms", False)
+
+ columns = util.ordered_column_set(columns)
+
+ omit = util.column_set()
+ for col in columns:
+ for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
+ for c in columns:
+ if c is col:
+ continue
+ try:
+ fk_col = fk.column
+ except exc.NoReferencedColumnError:
+ # TODO: add specific coverage here
+ # to test/sql/test_selectable ReduceTest
+ if ignore_nonexistent_tables:
+ continue
+ else:
+ raise
+ except exc.NoReferencedTableError:
+ # TODO: add specific coverage here
+ # to test/sql/test_selectable ReduceTest
+ if ignore_nonexistent_tables:
+ continue
+ else:
+ raise
+ if fk_col.shares_lineage(c) and (
+ not only_synonyms or c.name == col.name
+ ):
+ omit.add(col)
+ break
+
+ if clauses:
+
+ def visit_binary(binary):
+ if binary.operator == operators.eq:
+ cols = util.column_set(
+ chain(*[c.proxy_set for c in columns.difference(omit)])
+ )
+ if binary.left in cols and binary.right in cols:
+ for c in reversed(columns):
+ if c.shares_lineage(binary.right) and (
+ not only_synonyms or c.name == binary.left.name
+ ):
+ omit.add(c)
+ break
+
+ for clause in clauses:
+ if clause is not None:
+ visitors.traverse(clause, {}, {"binary": visit_binary})
+
+ return ColumnSet(columns.difference(omit))
+
+
+def criterion_as_pairs(
+ expression,
+ consider_as_foreign_keys=None,
+ consider_as_referenced_keys=None,
+ any_operator=False,
+):
+ """traverse an expression and locate binary criterion pairs."""
+
+ if consider_as_foreign_keys and consider_as_referenced_keys:
+ raise exc.ArgumentError(
+ "Can only specify one of "
+ "'consider_as_foreign_keys' or "
+ "'consider_as_referenced_keys'"
+ )
+
+ def col_is(a, b):
+ # return a is b
+ return a.compare(b)
+
+ def visit_binary(binary):
+ if not any_operator and binary.operator is not operators.eq:
+ return
+ if not isinstance(binary.left, ColumnElement) or not isinstance(
+ binary.right, ColumnElement
+ ):
+ return
+
+ if consider_as_foreign_keys:
+ if binary.left in consider_as_foreign_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_foreign_keys
+ ):
+ pairs.append((binary.right, binary.left))
+ elif binary.right in consider_as_foreign_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_foreign_keys
+ ):
+ pairs.append((binary.left, binary.right))
+ elif consider_as_referenced_keys:
+ if binary.left in consider_as_referenced_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_referenced_keys
+ ):
+ pairs.append((binary.left, binary.right))
+ elif binary.right in consider_as_referenced_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_referenced_keys
+ ):
+ pairs.append((binary.right, binary.left))
+ else:
+ if isinstance(binary.left, Column) and isinstance(
+ binary.right, Column
+ ):
+ if binary.left.references(binary.right):
+ pairs.append((binary.right, binary.left))
+ elif binary.right.references(binary.left):
+ pairs.append((binary.left, binary.right))
+
+ pairs = []
+ visitors.traverse(expression, {}, {"binary": visit_binary})
+ return pairs
+
+
+class ClauseAdapter(visitors.ReplacingExternalTraversal):
+ """Clones and modifies clauses based on column correspondence.
+
+ E.g.::
+
+ table1 = Table('sometable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+ table2 = Table('someothertable', metadata,
+ Column('col1', Integer),
+ Column('col2', Integer)
+ )
+
+ condition = table1.c.col1 == table2.c.col1
+
+ make an alias of table1::
+
+ s = table1.alias('foo')
+
+ calling ``ClauseAdapter(s).traverse(condition)`` converts
+ condition to read::
+
+ s.c.col1 == table2.c.col1
+
+ """
+
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ anonymize_labels=False,
+ adapt_from_selectables=None,
+ ):
+ self.__traverse_options__ = {
+ "stop_on": [selectable],
+ "anonymize_labels": anonymize_labels,
+ }
+ self.selectable = selectable
+ self.include_fn = include_fn
+ self.exclude_fn = exclude_fn
+ self.equivalents = util.column_dict(equivalents or {})
+ self.adapt_on_names = adapt_on_names
+ self.adapt_from_selectables = adapt_from_selectables
+
+ def _corresponding_column(
+ self, col, require_embedded, _seen=util.EMPTY_SET
+ ):
+
+ newcol = self.selectable.corresponding_column(
+ col, require_embedded=require_embedded
+ )
+ if newcol is None and col in self.equivalents and col not in _seen:
+ for equiv in self.equivalents[col]:
+ newcol = self._corresponding_column(
+ equiv,
+ require_embedded=require_embedded,
+ _seen=_seen.union([col]),
+ )
+ if newcol is not None:
+ return newcol
+ if self.adapt_on_names and newcol is None:
+ newcol = self.selectable.exported_columns.get(col.name)
+ return newcol
+
+ @util.preload_module("sqlalchemy.sql.functions")
+ def replace(self, col, _include_singleton_constants=False):
+ functions = util.preloaded.sql_functions
+
+ if isinstance(col, FromClause) and not isinstance(
+ col, functions.FunctionElement
+ ):
+
+ if self.selectable.is_derived_from(col):
+ if self.adapt_from_selectables:
+ for adp in self.adapt_from_selectables:
+ if adp.is_derived_from(col):
+ break
+ else:
+ return None
+ return self.selectable
+ elif isinstance(col, Alias) and isinstance(
+ col.element, TableClause
+ ):
+ # we are a SELECT statement and not derived from an alias of a
+ # table (which nonetheless may be a table our SELECT derives
+ # from), so return the alias to prevent further traversal
+ # or
+ # we are an alias of a table and we are not derived from an
+ # alias of a table (which nonetheless may be the same table
+ # as ours) so, same thing
+ return col
+ else:
+ # other cases where we are a selectable and the element
+ # is another join or selectable that contains a table which our
+ # selectable derives from, that we want to process
+ return None
+
+ elif not isinstance(col, ColumnElement):
+ return None
+ elif not _include_singleton_constants and col._is_singleton_constant:
+ # dont swap out NULL, TRUE, FALSE for a label name
+ # in a SQL statement that's being rewritten,
+ # leave them as the constant. This is first noted in #6259,
+ # however the logic to check this moved here as of #7154 so that
+ # it is made specific to SQL rewriting and not all column
+ # correspondence
+ return None
+
+ if "adapt_column" in col._annotations:
+ col = col._annotations["adapt_column"]
+
+ if self.adapt_from_selectables and col not in self.equivalents:
+ for adp in self.adapt_from_selectables:
+ if adp.c.corresponding_column(col, False) is not None:
+ break
+ else:
+ return None
+
+ if self.include_fn and not self.include_fn(col):
+ return None
+ elif self.exclude_fn and self.exclude_fn(col):
+ return None
+ else:
+ return self._corresponding_column(col, True)
+
+
+class ColumnAdapter(ClauseAdapter):
+ """Extends ClauseAdapter with extra utility functions.
+
+ Key aspects of ColumnAdapter include:
+
+ * Expressions that are adapted are stored in a persistent
+ .columns collection, so that an expression E adapted into
+ an expression E1 will return the same object E1 when adapted
+ a second time. This is important in particular for things like
+ Label objects that are anonymized, so that the ColumnAdapter can
+ be used to present a consistent "adapted" view of things.
+
+ * Exclusion of items from the persistent collection based on
+ include/exclude rules, but also independent of hash identity.
+ This is because "annotated" items all have the same hash identity as their
+ parent.
+
+ * "wrapping" capability is added, so that the replacement of an expression
+ E can proceed through a series of adapters. This differs from the
+ visitor's "chaining" feature in that the resulting object is passed
+ through all replacing functions unconditionally, rather than stopping
+ at the first one that returns non-None.
+
+ * An adapt_required option, used by eager loading to indicate that
+ we don't trust a result row column that is not translated.
+ This is to prevent a column from being interpreted as that
+ of the child row in a self-referential scenario; see
+ inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
+
+ """
+
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ adapt_required=False,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ adapt_from_selectables=None,
+ ):
+ ClauseAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ include_fn=include_fn,
+ exclude_fn=exclude_fn,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=anonymize_labels,
+ adapt_from_selectables=adapt_from_selectables,
+ )
+
+ self.columns = util.WeakPopulateDict(self._locate_col)
+ if self.include_fn or self.exclude_fn:
+ self.columns = self._IncludeExcludeMapping(self, self.columns)
+ self.adapt_required = adapt_required
+ self.allow_label_resolve = allow_label_resolve
+ self._wrap = None
+
+ class _IncludeExcludeMapping(object):
+ def __init__(self, parent, columns):
+ self.parent = parent
+ self.columns = columns
+
+ def __getitem__(self, key):
+ if (
+ self.parent.include_fn and not self.parent.include_fn(key)
+ ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
+ if self.parent._wrap:
+ return self.parent._wrap.columns[key]
+ else:
+ return key
+ return self.columns[key]
+
+ def wrap(self, adapter):
+ ac = self.__class__.__new__(self.__class__)
+ ac.__dict__.update(self.__dict__)
+ ac._wrap = adapter
+ ac.columns = util.WeakPopulateDict(ac._locate_col)
+ if ac.include_fn or ac.exclude_fn:
+ ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
+
+ return ac
+
+ def traverse(self, obj):
+ return self.columns[obj]
+
+ adapt_clause = traverse
+ adapt_list = ClauseAdapter.copy_and_process
+
+ def adapt_check_present(self, col):
+ newcol = self.columns[col]
+
+ if newcol is col and self._corresponding_column(col, True) is None:
+ return None
+
+ return newcol
+
+ def _locate_col(self, col):
+ # both replace and traverse() are overly complicated for what
+ # we are doing here and we would do better to have an inlined
+ # version that doesn't build up as much overhead. the issue is that
+ # sometimes the lookup does in fact have to adapt the insides of
+ # say a labeled scalar subquery. However, if the object is an
+ # Immutable, i.e. Column objects, we can skip the "clone" /
+ # "copy internals" part since those will be no-ops in any case.
+ # additionally we want to catch singleton objects null/true/false
+ # and make sure they are adapted as well here.
+
+ if col._is_immutable:
+ for vis in self.visitor_iterator:
+ c = vis.replace(col, _include_singleton_constants=True)
+ if c is not None:
+ break
+ else:
+ c = col
+ else:
+ c = ClauseAdapter.traverse(self, col)
+
+ if self._wrap:
+ c2 = self._wrap._locate_col(c)
+ if c2 is not None:
+ c = c2
+
+ if self.adapt_required and c is col:
+ return None
+
+ c._allow_label_resolve = self.allow_label_resolve
+
+ return c
+
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ del d["columns"]
+ return d
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self.columns = util.WeakPopulateDict(self._locate_col)
+
+
+def _offset_or_limit_clause(element, name=None, type_=None):
+ """Convert the given value to an "offset or limit" clause.
+
+ This handles incoming integers and converts to an expression; if
+ an expression is already given, it is passed through.
+
+ """
+ return coercions.expect(
+ roles.LimitOffsetRole, element, name=name, type_=type_
+ )
+
+
+def _offset_or_limit_clause_asint_if_possible(clause):
+ """Return the offset or limit clause as a simple integer if possible,
+ else return the clause.
+
+ """
+ if clause is None:
+ return None
+ if hasattr(clause, "_limit_offset_value"):
+ value = clause._limit_offset_value
+ return util.asint(value)
+ else:
+ return clause
+
+
+def _make_slice(limit_clause, offset_clause, start, stop):
+ """Compute LIMIT/OFFSET in terms of slice start/end"""
+
+ # for a calculated limit/offset, try to do the addition of
+ # values to the offset in Python; however, if a SQL clause is present
+ # then the addition has to happen on the SQL side.
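+ # e.g. stmt[10:20] -> start=10, stop=20 -> LIMIT 10 OFFSET 10;
+ # stmt[:5] -> LIMIT 5 with any existing offset left unchanged
+ # (illustrative slices; the actual clauses come from the caller)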
+ if start is not None and stop is not None:
+ offset_clause = _offset_or_limit_clause_asint_if_possible(
+ offset_clause
+ )
+ if offset_clause is None:
+ offset_clause = 0
+
+ if start != 0:
+ offset_clause = offset_clause + start
+
+ if offset_clause == 0:
+ offset_clause = None
+ else:
+ offset_clause = _offset_or_limit_clause(offset_clause)
+
+ limit_clause = _offset_or_limit_clause(stop - start)
+
+ elif start is None and stop is not None:
+ limit_clause = _offset_or_limit_clause(stop)
+ elif start is not None and stop is None:
+ offset_clause = _offset_or_limit_clause_asint_if_possible(
+ offset_clause
+ )
+ if offset_clause is None:
+ offset_clause = 0
+
+ if start != 0:
+ offset_clause = offset_clause + start
+
+ if offset_clause == 0:
+ offset_clause = None
+ else:
+ offset_clause = _offset_or_limit_clause(offset_clause)
+
+ return limit_clause, offset_clause
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
new file mode 100644
index 0000000..f72d83a
--- /dev/null
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -0,0 +1,852 @@
+# sql/visitors.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Visitor/traversal interface and library functions.
+
+SQLAlchemy schema and expression constructs rely on a Python-centric
+version of the classic "visitor" pattern as the primary way in which
+they apply functionality. The most common use of this pattern
+is statement compilation, where individual expression classes match
+up to rendering methods that produce a string result. Beyond this,
+the visitor system is also used to inspect expressions for various
+information and patterns, as well as for the purposes of applying
+transformations to expressions.
+
+Examples of how the visit system is used can be seen in the source code of,
+for example, the ``sqlalchemy.sql.util`` and ``sqlalchemy.sql.compiler``
+modules. Some background on clause adaptation is also at
+https://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
+
+"""
+
+from collections import deque
+import itertools
+import operator
+
+from .. import exc
+from .. import util
+from ..util import langhelpers
+from ..util import symbol
+
+__all__ = [
+ "iterate",
+ "traverse_using",
+ "traverse",
+ "cloned_traverse",
+ "replacement_traverse",
+ "Traversible",
+ "TraversibleType",
+ "ExternalTraversal",
+ "InternalTraversal",
+]
+
+
+def _generate_compiler_dispatch(cls):
+ """Generate a _compiler_dispatch() external traversal on classes with a
+ __visit_name__ attribute.
+
+ """
+ visit_name = cls.__visit_name__
+
+ if "_compiler_dispatch" in cls.__dict__:
+ # class has a fixed _compiler_dispatch() method.
+ # copy it to "original" so that we can get it back if
+ # sqlalchemy.ext.compiles overrides it.
+ cls._original_compiler_dispatch = cls._compiler_dispatch
+ return
+
+ if not isinstance(visit_name, util.compat.string_types):
+ raise exc.InvalidRequestError(
+ "__visit_name__ on class %s must be a string at the class level"
+ % cls.__name__
+ )
+
+ name = "visit_%s" % visit_name
+ getter = operator.attrgetter(name)
+
+ def _compiler_dispatch(self, visitor, **kw):
+ """Look for an attribute named "visit_<visit_name>" on the
+ visitor, and call it with the same kw params.
+
+ """
+ try:
+ meth = getter(visitor)
+ except AttributeError as err:
+ return visitor.visit_unsupported_compilation(self, err, **kw)
+
+ else:
+ return meth(self, **kw)
+
+ cls._compiler_dispatch = (
+ cls._original_compiler_dispatch
+ ) = _compiler_dispatch
+
+
+class TraversibleType(type):
+ """Metaclass which assigns dispatch attributes to various kinds of
+ "visitable" classes.
+
+ Attributes include:
+
+ * The ``_compiler_dispatch`` method, corresponding to ``__visit_name__``.
+ This is called "external traversal" because the caller of each visit()
+ method is responsible for sub-traversing the inner elements of each
+ object. This is appropriate for string compilers and other traversals
+ that need to call upon the inner elements in a specific pattern.
+
+ * internal traversal collections ``_children_traversal``,
+ ``_cache_key_traversal``, ``_copy_internals_traversal``, generated from
+ an optional ``_traverse_internals`` collection of symbols which comes
+ from the :class:`.InternalTraversal` list of symbols. This is called
+ "internal traversal" because the generated collections, rather than an
+ external caller, drive traversal of the object's internal state.
+
+ """
+
+ def __init__(cls, clsname, bases, clsdict):
+ if clsname != "Traversible":
+ if "__visit_name__" in clsdict:
+ _generate_compiler_dispatch(cls)
+
+ super(TraversibleType, cls).__init__(clsname, bases, clsdict)
+
+
+class Traversible(util.with_metaclass(TraversibleType)):
+ """Base class for visitable objects, applies the
+ :class:`.visitors.TraversibleType` metaclass.
+
+ """
+
+ def __class_getitem__(cls, key):
+ # allow generic classes in py3.9+
+ return cls
+
+ @util.preload_module("sqlalchemy.sql.traversals")
+ def get_children(self, omit_attrs=(), **kw):
+ r"""Return immediate child :class:`.visitors.Traversible`
+ elements of this :class:`.visitors.Traversible`.
+
+ This is used for visit traversal.
+
+ \**kw may contain flags that change the collection that is
+ returned, for example to return a subset of items in order to
+ cut down on larger traversals, or to return child items from a
+ different context (such as schema-level collections instead of
+ clause-level).
+
+ """
+
+ traversals = util.preloaded.sql_traversals
+
+ try:
+ traverse_internals = self._traverse_internals
+ except AttributeError:
+ # user-defined classes may not have a _traverse_internals
+ return []
+
+ dispatch = traversals._get_children.run_generated_dispatch
+ return itertools.chain.from_iterable(
+ meth(obj, **kw)
+ for attrname, obj, meth in dispatch(
+ self, traverse_internals, "_generated_get_children_traversal"
+ )
+ if attrname not in omit_attrs and obj is not None
+ )
+
+
+class _InternalTraversalType(type):
+ def __init__(cls, clsname, bases, clsdict):
+ if cls.__name__ in ("InternalTraversal", "ExtendedInternalTraversal"):
+ lookup = {}
+ for key, sym in clsdict.items():
+ if key.startswith("dp_"):
+ visit_key = key.replace("dp_", "visit_")
+ sym_name = sym.name
+ assert sym_name not in lookup, sym_name
+ lookup[sym] = lookup[sym_name] = visit_key
+ if hasattr(cls, "_dispatch_lookup"):
+ lookup.update(cls._dispatch_lookup)
+ cls._dispatch_lookup = lookup
+
+ super(_InternalTraversalType, cls).__init__(clsname, bases, clsdict)
+
+
+def _generate_dispatcher(visitor, internal_dispatch, method_name):
+ names = []
+ for attrname, visit_sym in internal_dispatch:
+ meth = visitor.dispatch(visit_sym)
+ if meth:
+ visit_name = ExtendedInternalTraversal._dispatch_lookup[visit_sym]
+ names.append((attrname, visit_name))
+
+ code = (
+ (" return [\n")
+ + (
+ ", \n".join(
+ " (%r, self.%s, visitor.%s)"
+ % (attrname, attrname, visit_name)
+ for attrname, visit_name in names
+ )
+ )
+ + ("\n ]\n")
+ )
+ meth_text = ("def %s(self, visitor):\n" % method_name) + code + "\n"
+ # print(meth_text)
+ return langhelpers._exec_code_in_env(meth_text, {}, method_name)
+
+
+class InternalTraversal(util.with_metaclass(_InternalTraversalType, object)):
+ r"""Defines visitor symbols used for internal traversal.
+
+ The :class:`.InternalTraversal` class is used in two ways. One is that
+ it can serve as the superclass for an object that implements the
+ various visit methods of the class. The other is that the symbols
+ themselves of :class:`.InternalTraversal` are used within
+ the ``_traverse_internals`` collection. Such as, the :class:`.Case`
+ object defines ``_traverse_internals`` as ::
+
+ _traverse_internals = [
+ ("value", InternalTraversal.dp_clauseelement),
+ ("whens", InternalTraversal.dp_clauseelement_tuples),
+ ("else_", InternalTraversal.dp_clauseelement),
+ ]
+
+ Above, the :class:`.Case` class indicates its internal state as the
+ attributes named ``value``, ``whens``, and ``else_``. They each
+ link to an :class:`.InternalTraversal` method which indicates the type
+ of datastructure referred towards.
+
+ Using the ``_traverse_internals`` structure, objects of type
+ :class:`.Traversible` will have the following methods automatically
+ implemented:
+
+ * :meth:`.Traversible.get_children`
+
+ * :meth:`.Traversible._copy_internals`
+
+ * :meth:`.Traversible._gen_cache_key`
+
+ Subclasses can also implement these methods directly, particularly for the
+ :meth:`.Traversible._copy_internals` method, when special steps
+ are needed.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def dispatch(self, visit_symbol):
+ """Given a method from :class:`.InternalTraversal`, return the
+ corresponding method on a subclass.
+
+ """
+ name = self._dispatch_lookup[visit_symbol]
+ return getattr(self, name, None)
+
+ def run_generated_dispatch(
+ self, target, internal_dispatch, generate_dispatcher_name
+ ):
+ try:
+ dispatcher = target.__class__.__dict__[generate_dispatcher_name]
+ except KeyError:
+ # most of the dispatchers are generated up front
+ # in sqlalchemy/sql/__init__.py ->
+ # traversals.py-> _preconfigure_traversals().
+ # this block will generate any remaining dispatchers.
+ dispatcher = self.generate_dispatch(
+ target.__class__, internal_dispatch, generate_dispatcher_name
+ )
+ return dispatcher(target, self)
+
+ def generate_dispatch(
+ self, target_cls, internal_dispatch, generate_dispatcher_name
+ ):
+ dispatcher = _generate_dispatcher(
+ self, internal_dispatch, generate_dispatcher_name
+ )
+ # assert isinstance(target_cls, type)
+ setattr(target_cls, generate_dispatcher_name, dispatcher)
+ return dispatcher
+
+ dp_has_cache_key = symbol("HC")
+ """Visit a :class:`.HasCacheKey` object."""
+
+ dp_has_cache_key_list = symbol("HL")
+ """Visit a list of :class:`.HasCacheKey` objects."""
+
+ dp_clauseelement = symbol("CE")
+ """Visit a :class:`_expression.ClauseElement` object."""
+
+ dp_fromclause_canonical_column_collection = symbol("FC")
+ """Visit a :class:`_expression.FromClause` object in the context of the
+ ``columns`` attribute.
+
+ The column collection is "canonical", meaning it is the originally
+ defined location of the :class:`.ColumnClause` objects. Right now
+ this means that the object being visited is a
+ :class:`_expression.TableClause`
+ or :class:`_schema.Table` object only.
+
+ """
+
+ dp_clauseelement_tuples = symbol("CTS")
+ """Visit a list of tuples which contain :class:`_expression.ClauseElement`
+ objects.
+
+ """
+
+ dp_clauseelement_list = symbol("CL")
+ """Visit a list of :class:`_expression.ClauseElement` objects.
+
+ """
+
+ dp_clauseelement_tuple = symbol("CT")
+ """Visit a tuple of :class:`_expression.ClauseElement` objects.
+
+ """
+
+ dp_executable_options = symbol("EO")
+
+ dp_with_context_options = symbol("WC")
+
+ dp_fromclause_ordered_set = symbol("CO")
+ """Visit an ordered set of :class:`_expression.FromClause` objects. """
+
+ dp_string = symbol("S")
+ """Visit a plain string value.
+
+ Examples include table and column names, bound parameter keys, special
+ keywords such as "UNION", "UNION ALL".
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_string_list = symbol("SL")
+ """Visit a list of strings."""
+
+ dp_anon_name = symbol("AN")
+ """Visit a potentially "anonymized" string value.
+
+ The string value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_boolean = symbol("B")
+ """Visit a boolean value.
+
+ The boolean value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_operator = symbol("O")
+ """Visit an operator.
+
+ The operator is a function from the :mod:`sqlalchemy.sql.operators`
+ module.
+
+ The operator value is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_type = symbol("T")
+ """Visit a :class:`.TypeEngine` object
+
+ The type object is considered to be significant for cache key
+ generation.
+
+ """
+
+ dp_plain_dict = symbol("PD")
+ """Visit a dictionary with string keys.
+
+ The keys of the dictionary should be strings, the values should
+ be immutable and hashable. The dictionary is considered to be
+ significant for cache key generation.
+
+ """
+
+ dp_dialect_options = symbol("DO")
+ """Visit a dialect options structure."""
+
+ dp_string_clauseelement_dict = symbol("CD")
+ """Visit a dictionary of string keys to :class:`_expression.ClauseElement`
+ objects.
+
+ """
+
+ dp_string_multi_dict = symbol("MD")
+ """Visit a dictionary of string keys to values which may either be
+ plain immutable/hashable or :class:`.HasCacheKey` objects.
+
+ """
+
+ dp_annotations_key = symbol("AK")
+ """Visit the _annotations_cache_key element.
+
+ This is a dictionary of additional information about a ClauseElement
+ that modifies its role. It should be included when comparing or caching
+ objects, however generating this key is relatively expensive. Visitors
+ should check the "_annotations" dict for non-None first before creating
+ this key.
+
+ """
+
+ dp_plain_obj = symbol("PO")
+ """Visit a plain python object.
+
+ The value should be immutable and hashable, such as an integer.
+ The value is considered to be significant for cache key generation.
+
+ """
+
+ dp_named_ddl_element = symbol("DD")
+ """Visit a simple named DDL element.
+
+ The current object used by this method is the :class:`.Sequence`.
+
+ The object is only considered to be important for cache key generation
+ as far as its name, but not any other aspects of it.
+
+ """
+
+ dp_prefix_sequence = symbol("PS")
+ """Visit the sequence represented by :class:`_expression.HasPrefixes`
+ or :class:`_expression.HasSuffixes`.
+
+ """
+
+ dp_table_hint_list = symbol("TH")
+ """Visit the ``_hints`` collection of a :class:`_expression.Select`
+ object.
+
+ """
+
+ dp_setup_join_tuple = symbol("SJ")
+
+ dp_memoized_select_entities = symbol("ME")
+
+ dp_statement_hint_list = symbol("SH")
+ """Visit the ``_statement_hints`` collection of a
+ :class:`_expression.Select`
+ object.
+
+ """
+
+ dp_unknown_structure = symbol("UK")
+ """Visit an unknown structure.
+
+ """
+
+ dp_dml_ordered_values = symbol("DML_OV")
+ """Visit the values() ordered tuple list of an
+ :class:`_expression.Update` object."""
+
+ dp_dml_values = symbol("DML_V")
+ """Visit the values() dictionary of a :class:`.ValuesBase`
+ (e.g. Insert or Update) object.
+
+ """
+
+ dp_dml_multi_values = symbol("DML_MV")
+ """Visit the values() multi-valued list of dictionaries of an
+ :class:`_expression.Insert` object.
+
+ """
+
+ dp_propagate_attrs = symbol("PA")
+ """Visit the propagate attrs dict. This hardcodes to the particular
+ elements we care about right now."""
+
+
+class ExtendedInternalTraversal(InternalTraversal):
+ """Defines additional symbols that are useful in caching applications.
+
+ Traversals for :class:`_expression.ClauseElement` objects only need to use
+ those symbols present in :class:`.InternalTraversal`. However, for
+ additional caching use cases within the ORM, symbols dealing with the
+ :class:`.HasCacheKey` class are added here.
+
+ """
+
+ dp_ignore = symbol("IG")
+ """Specify an object that should be ignored entirely.
+
+ This currently applies to function call argument caching, where some
+ arguments should not be considered to be part of a cache key.
+
+ """
+
+ dp_inspectable = symbol("IS")
+ """Visit an inspectable object where the return value is a
+ :class:`.HasCacheKey` object."""
+
+ dp_multi = symbol("M")
+ """Visit an object that may be a :class:`.HasCacheKey` or may be a
+ plain hashable object."""
+
+ dp_multi_list = symbol("MT")
+ """Visit a tuple containing elements that may be :class:`.HasCacheKey` or
+ may be a plain hashable object."""
+
+ dp_has_cache_key_tuples = symbol("HT")
+ """Visit a list of tuples which contain :class:`.HasCacheKey`
+ objects.
+
+ """
+
+ dp_inspectable_list = symbol("IL")
+ """Visit a list of inspectable objects which upon inspection are
+ HasCacheKey objects."""
+
+
+class ExternalTraversal(object):
+ """Base class for visitor objects which can traverse externally using
+ the :func:`.visitors.traverse` function.
+
+ Direct usage of the :func:`.visitors.traverse` function is usually
+ preferred.
+
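+ E.g., a minimal sketch of a subclass that collects column names from an
+ expression; the class name and ``some_expression`` are placeholders::
+
+     from sqlalchemy.sql import visitors
+
+     class ColumnNames(visitors.ExternalTraversal):
+         def __init__(self):
+             self.names = []
+
+         def visit_column(self, column):
+             self.names.append(column.name)
+
+     v = ColumnNames()
+     v.traverse(some_expression)
+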
+ """
+
+ __traverse_options__ = {}
+
+ def traverse_single(self, obj, **kw):
+ for v in self.visitor_iterator:
+ meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
+ if meth:
+ return meth(obj, **kw)
+
+ def iterate(self, obj):
+ """Traverse the given expression structure, returning an iterator
+ of all elements.
+
+ """
+ return iterate(obj, self.__traverse_options__)
+
+ def traverse(self, obj):
+ """Traverse and visit the given expression structure."""
+
+ return traverse(obj, self.__traverse_options__, self._visitor_dict)
+
+ @util.memoized_property
+ def _visitor_dict(self):
+ visitors = {}
+
+ for name in dir(self):
+ if name.startswith("visit_"):
+ visitors[name[6:]] = getattr(self, name)
+ return visitors
+
+ @property
+ def visitor_iterator(self):
+ """Iterate through this visitor and each 'chained' visitor."""
+
+ v = self
+ while v:
+ yield v
+ v = getattr(v, "_next", None)
+
+ def chain(self, visitor):
+ """'Chain' an additional ClauseVisitor onto this ClauseVisitor.
+
+ The chained visitor will receive all visit events after this one.
+
+ """
+ tail = list(self.visitor_iterator)[-1]
+ tail._next = visitor
+ return self
+
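+# A hedged usage sketch (not part of the original module): an
+# ExternalTraversal subclass that collects bound parameters from a
+# statement.  The class name and the ``stmt`` element are hypothetical:
+#
+#     class FindBinds(ExternalTraversal):
+#         def __init__(self):
+#             self.binds = []
+#
+#         def visit_bindparam(self, bind):
+#             self.binds.append(bind)
+#
+#     finder = FindBinds()
+#     finder.traverse(stmt)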
+
+class CloningExternalTraversal(ExternalTraversal):
+ """Base class for visitor objects which can traverse using
+ the :func:`.visitors.cloned_traverse` function.
+
+ Direct usage of the :func:`.visitors.cloned_traverse` function is usually
+ preferred.
+
+
+ """
+
+ def copy_and_process(self, list_):
+ """Apply cloned traversal to the given list of elements, and return
+ the new list.
+
+ """
+ return [self.traverse(x) for x in list_]
+
+ def traverse(self, obj):
+ """Traverse and visit the given expression structure."""
+
+ return cloned_traverse(
+ obj, self.__traverse_options__, self._visitor_dict
+ )
+
+
+class ReplacingExternalTraversal(CloningExternalTraversal):
+ """Base class for visitor objects which can traverse using
+ the :func:`.visitors.replacement_traverse` function.
+
+ Direct usage of the :func:`.visitors.replacement_traverse` function is
+ usually preferred.
+
+ """
+
+ def replace(self, elem):
+ """Receive pre-copied elements during a cloning traversal.
+
+ If the method returns a new element, the element is used
+ instead of creating a simple copy of the element. Traversal
+ will halt on the newly returned element if it is re-encountered.
+ """
+ return None
+
+ def traverse(self, obj):
+ """Traverse and visit the given expression structure."""
+
+ def replace(elem):
+ for v in self.visitor_iterator:
+ e = v.replace(elem)
+ if e is not None:
+ return e
+
+ return replacement_traverse(obj, self.__traverse_options__, replace)
+
+
+# backwards compatibility
+Visitable = Traversible
+VisitableType = TraversibleType
+ClauseVisitor = ExternalTraversal
+CloningVisitor = CloningExternalTraversal
+ReplacingCloningVisitor = ReplacingExternalTraversal
+
+
+def iterate(obj, opts=util.immutabledict()):
+ r"""Traverse the given expression structure, returning an iterator.
+
+ Traversal is configured to be breadth-first.
+
+ The central API feature used by the :func:`.visitors.iterate`
+ function is the
+ :meth:`_expression.ClauseElement.get_children` method of
+ :class:`_expression.ClauseElement` objects. This method should return all
+ the :class:`_expression.ClauseElement` objects which are associated with a
+ particular :class:`_expression.ClauseElement` object. For example, a
+ :class:`.Case` structure will refer to a series of
+ :class:`_expression.ColumnElement` objects within its "whens" and "else\_"
+ member variables.
+
+ :param obj: :class:`_expression.ClauseElement` structure to be traversed
+
+ :param opts: dictionary of iteration options. This dictionary is usually
+ empty in modern usage.
+
+ """
+ yield obj
+ children = obj.get_children(**opts)
+
+ if not children:
+ return
+
+ stack = deque([children])
+ while stack:
+ t_iterator = stack.popleft()
+ for t in t_iterator:
+ yield t
+ stack.append(t.get_children(**opts))
+
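+# A minimal sketch of the breadth-first iteration described above; the
+# table and column names are hypothetical:
+#
+#     from sqlalchemy import column, select, table
+#
+#     t = table("t", column("x"))
+#     for elem in iterate(select(t).where(t.c.x == 5)):
+#         print(elem.__visit_name__)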
+
+def traverse_using(iterator, obj, visitors):
+ """Visit the given expression structure using the given iterator of
+ objects.
+
+ :func:`.visitors.traverse_using` is usually called internally as the result
+ of the :func:`.visitors.traverse` function.
+
+ :param iterator: an iterable or sequence which will yield
+ :class:`_expression.ClauseElement`
+ structures; the iterator is assumed to be the
+ product of the :func:`.visitors.iterate` function.
+
+ :param obj: the :class:`_expression.ClauseElement`
+ that was used as the target of the
+ :func:`.iterate` function.
+
+ :param visitors: dictionary of visit functions. See :func:`.traverse`
+ for details on this dictionary.
+
+ .. seealso::
+
+ :func:`.traverse`
+
+
+ """
+ for target in iterator:
+ meth = visitors.get(target.__visit_name__, None)
+ if meth:
+ meth(target)
+ return obj
+
+
+def traverse(obj, opts, visitors):
+ """Traverse and visit the given expression structure using the default
+ iterator.
+
+ e.g.::
+
+ from sqlalchemy.sql import visitors
+
+ stmt = select(some_table).where(some_table.c.foo == 'bar')
+
+ def visit_bindparam(bind_param):
+ print("found bound value: %s" % bind_param.value)
+
+ visitors.traverse(stmt, {}, {"bindparam": visit_bindparam})
+
+ The iteration of objects uses the :func:`.visitors.iterate` function,
+ which does a breadth-first traversal using a stack.
+
+ :param obj: :class:`_expression.ClauseElement` structure to be traversed
+
+ :param opts: dictionary of iteration options. This dictionary is usually
+ empty in modern usage.
+
+ :param visitors: dictionary of visit functions. The dictionary should
+ have strings as keys, each of which would correspond to the
+ ``__visit_name__`` of a particular kind of SQL expression object, and
+ callable functions as values, each of which represents a visitor function
+ for that kind of object.
+
+ """
+ return traverse_using(iterate(obj, opts), obj, visitors)
+
+
+def cloned_traverse(obj, opts, visitors):
+ """Clone the given expression structure, allowing modifications by
+ visitors.
+
+ Traversal usage is the same as that of :func:`.visitors.traverse`.
+ The visitor functions present in the ``visitors`` dictionary may also
+ modify the internals of the given structure as the traversal proceeds.
+
+ The central API feature used by the :func:`.visitors.cloned_traverse`
+ and :func:`.visitors.replacement_traverse` functions, in addition to the
+ :meth:`_expression.ClauseElement.get_children`
+ function that is used to achieve
+ the iteration, is the :meth:`_expression.ClauseElement._copy_internals`
+ method.
+ For a :class:`_expression.ClauseElement`
+ structure to support cloning and replacement
+ traversals correctly, it needs to be able to pass a cloning function into
+ its internal members in order to make copies of them.
+
+ .. seealso::
+
+ :func:`.visitors.traverse`
+
+ :func:`.visitors.replacement_traverse`
+
+ """
+
+ cloned = {}
+ stop_on = set(opts.get("stop_on", []))
+
+ def deferred_copy_internals(obj):
+ return cloned_traverse(obj, opts, visitors)
+
+ def clone(elem, **kw):
+ if elem in stop_on:
+ return elem
+ else:
+ if id(elem) not in cloned:
+
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[id(elem)] = newelem
+ return newelem
+
+ cloned[id(elem)] = newelem = elem._clone(clone=clone, **kw)
+ newelem._copy_internals(clone=clone, **kw)
+ meth = visitors.get(newelem.__visit_name__, None)
+ if meth:
+ meth(newelem)
+ return cloned[id(elem)]
+
+ if obj is not None:
+ obj = clone(
+ obj, deferred_copy_internals=deferred_copy_internals, **opts
+ )
+ clone = None # remove gc cycles
+ return obj
+
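+# A hedged usage sketch: clone a statement while changing bound parameter
+# values in the copy, leaving the original untouched.  The ``stmt``
+# element is hypothetical:
+#
+#     def visit_bindparam(bind):
+#         bind.value = "replaced"
+#
+#     cloned_stmt = cloned_traverse(
+#         stmt, {}, {"bindparam": visit_bindparam}
+#     )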
+
+def replacement_traverse(obj, opts, replace):
+ """Clone the given expression structure, allowing element
+ replacement by a given replacement function.
+
+ This function is very similar to the :func:`.visitors.cloned_traverse`
+ function, except instead of being passed a dictionary of visitors, all
+ elements are unconditionally passed into the given replace function.
+ The replace function then has the option to return an entirely new object
+ which will replace the one given. If it returns ``None``, then the object
+ is kept in place.
+
+ The difference in usage between :func:`.visitors.cloned_traverse` and
+ :func:`.visitors.replacement_traverse` is that in the former case, an
+ already-cloned object is passed to the visitor function, and the visitor
+ function can then manipulate the internal state of the object.
+ In the case of the latter, the visitor function should only return an
+ entirely different object, or do nothing.
+
+ The use case for :func:`.visitors.replacement_traverse` is that of
+ replacing a FROM clause inside of a SQL structure with a different one,
+ as is a common use case within the ORM.
+
+ """
+
+ cloned = {}
+ stop_on = {id(x) for x in opts.get("stop_on", [])}
+
+ def deferred_copy_internals(obj):
+ return replacement_traverse(obj, opts, replace)
+
+ def clone(elem, **kw):
+ if (
+ id(elem) in stop_on
+ or "no_replacement_traverse" in elem._annotations
+ ):
+ return elem
+ else:
+ newelem = replace(elem)
+ if newelem is not None:
+ stop_on.add(id(newelem))
+ return newelem
+ else:
+ # base "already seen" on id(), not hash, so that we don't
+ # replace an Annotated element with its non-annotated one, and
+ # vice versa
+ id_elem = id(elem)
+ if id_elem not in cloned:
+ if "replace" in kw:
+ newelem = kw["replace"](elem)
+ if newelem is not None:
+ cloned[id_elem] = newelem
+ return newelem
+
+ cloned[id_elem] = newelem = elem._clone(**kw)
+ newelem._copy_internals(clone=clone, **kw)
+ return cloned[id_elem]
+
+ if obj is not None:
+ obj = clone(
+ obj, deferred_copy_internals=deferred_copy_internals, **opts
+ )
+ clone = None # remove gc cycles
+ return obj
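+# A sketch of the FROM-replacement use case described above; ``t1`` and
+# ``t2`` are hypothetical tables and ``stmt`` selects from ``t1``:
+#
+#     def replace(elem):
+#         return t2 if elem is t1 else None
+#
+#     new_stmt = replacement_traverse(stmt, {}, replace)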
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
new file mode 100644
index 0000000..80d344f
--- /dev/null
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -0,0 +1,86 @@
+# testing/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from . import config
+from . import mock
+from .assertions import assert_raises
+from .assertions import assert_raises_context_ok
+from .assertions import assert_raises_message
+from .assertions import assert_raises_message_context_ok
+from .assertions import assert_warns
+from .assertions import assert_warns_message
+from .assertions import AssertsCompiledSQL
+from .assertions import AssertsExecutionResults
+from .assertions import ComparesTables
+from .assertions import emits_warning
+from .assertions import emits_warning_on
+from .assertions import eq_
+from .assertions import eq_ignore_whitespace
+from .assertions import eq_regex
+from .assertions import expect_deprecated
+from .assertions import expect_deprecated_20
+from .assertions import expect_raises
+from .assertions import expect_raises_message
+from .assertions import expect_warnings
+from .assertions import in_
+from .assertions import is_
+from .assertions import is_false
+from .assertions import is_instance_of
+from .assertions import is_none
+from .assertions import is_not
+from .assertions import is_not_
+from .assertions import is_not_none
+from .assertions import is_true
+from .assertions import le_
+from .assertions import ne_
+from .assertions import not_in
+from .assertions import not_in_
+from .assertions import startswith_
+from .assertions import uses_deprecated
+from .config import async_test
+from .config import combinations
+from .config import combinations_list
+from .config import db
+from .config import fixture
+from .config import requirements as requires
+from .exclusions import _is_excluded
+from .exclusions import _server_version
+from .exclusions import against as _against
+from .exclusions import db_spec
+from .exclusions import exclude
+from .exclusions import fails
+from .exclusions import fails_if
+from .exclusions import fails_on
+from .exclusions import fails_on_everything_except
+from .exclusions import future
+from .exclusions import only_if
+from .exclusions import only_on
+from .exclusions import skip
+from .exclusions import skip_if
+from .schema import eq_clause_element
+from .schema import eq_type_affinity
+from .util import adict
+from .util import fail
+from .util import flag_combinations
+from .util import force_drop_names
+from .util import lambda_combinations
+from .util import metadata_fixture
+from .util import provide_metadata
+from .util import resolve_lambda
+from .util import rowset
+from .util import run_as_contextmanager
+from .util import teardown_events
+from .warnings import assert_warnings
+from .warnings import warn_test_suite
+
+
+def against(*queries):
+ return _against(config._current, *queries)
+
+
+crashes = skip
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
new file mode 100644
index 0000000..9a3c06b
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -0,0 +1,845 @@
+# testing/assertions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import contextlib
+import re
+import sys
+import warnings
+
+from . import assertsql
+from . import config
+from . import engines
+from . import mock
+from .exclusions import db_spec
+from .util import fail
+from .. import exc as sa_exc
+from .. import schema
+from .. import sql
+from .. import types as sqltypes
+from .. import util
+from ..engine import default
+from ..engine import url
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import compat
+from ..util import decorator
+
+
+def expect_warnings(*messages, **kw):
+ """Context manager which expects one or more warnings.
+
+ With no arguments, squelches all SAWarning and RemovedIn20Warning emitted via
+ sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
+ pass string expressions that will match selected warnings via regex;
+ all non-matching warnings are sent through.
+
+ The expect version **asserts** that the warnings were in fact seen.
+
+ Note that the test suite sets SAWarning warnings to raise exceptions.
+
+ """ # noqa
+ return _expect_warnings(
+ (sa_exc.RemovedIn20Warning, sa_exc.SAWarning), messages, **kw
+ )
+
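+# A hedged usage sketch; the warning pattern and the function invoked
+# inside the block are hypothetical:
+#
+#     with expect_warnings("Dialect .* does not support .*"):
+#         do_something_that_warns()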
+
+@contextlib.contextmanager
+def expect_warnings_on(db, *messages, **kw):
+ """Context manager which expects one or more warnings on specific
+ dialects.
+
+ The expect version **asserts** that the warnings were in fact seen.
+
+ """
+ spec = db_spec(db)
+
+ if isinstance(db, util.string_types) and not spec(config._current):
+ yield
+ else:
+ with expect_warnings(*messages, **kw):
+ yield
+
+
+def emits_warning(*messages):
+ """Decorator form of expect_warnings().
+
+ Note that emits_warning does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_warnings(assert_=False, *messages):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+def expect_deprecated(*messages, **kw):
+ return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
+
+
+def expect_deprecated_20(*messages, **kw):
+ return _expect_warnings(sa_exc.Base20DeprecationWarning, messages, **kw)
+
+
+def emits_warning_on(db, *messages):
+ """Mark a test as emitting a warning on a specific dialect.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+
+ Note that emits_warning_on does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_warnings_on(db, assert_=False, *messages):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+def uses_deprecated(*messages):
+ """Mark a test as immune from fatal deprecation warnings.
+
+ With no arguments, squelches all SADeprecationWarning failures.
+ Or pass one or more strings; these will be matched to the root
+ of the warning description by warnings.filterwarnings().
+
+ As a special case, you may pass a function name prefixed with //
+ and it will be re-written as needed to match the standard warning
+ verbiage emitted by the sqlalchemy.util.deprecated decorator.
+
+ Note that uses_deprecated does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_deprecated(*messages, assert_=False):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+_FILTERS = None
+_SEEN = None
+_EXC_CLS = None
+
+
+@contextlib.contextmanager
+def _expect_warnings(
+ exc_cls,
+ messages,
+ regex=True,
+ search_msg=False,
+ assert_=True,
+ py2konly=False,
+ raise_on_any_unexpected=False,
+ squelch_other_warnings=False,
+):
+
+ global _FILTERS, _SEEN, _EXC_CLS
+
+ if regex or search_msg:
+ filters = [re.compile(msg, re.I | re.S) for msg in messages]
+ else:
+ filters = list(messages)
+
+ if _FILTERS is not None:
+ # nested call; update _FILTERS and _SEEN, return. outer
+ # block will assert our messages
+ assert _SEEN is not None
+ assert _EXC_CLS is not None
+ _FILTERS.extend(filters)
+ _SEEN.update(filters)
+ _EXC_CLS += (exc_cls,)
+ yield
+ else:
+ seen = _SEEN = set(filters)
+ _FILTERS = filters
+ _EXC_CLS = (exc_cls,)
+
+ if raise_on_any_unexpected:
+
+ def real_warn(msg, *arg, **kw):
+ raise AssertionError("Got unexpected warning: %r" % msg)
+
+ else:
+ real_warn = warnings.warn
+
+ def our_warn(msg, *arg, **kw):
+
+ if isinstance(msg, _EXC_CLS):
+ exception = type(msg)
+ msg = str(msg)
+ elif arg:
+ exception = arg[0]
+ else:
+ exception = None
+
+ if not exception or not issubclass(exception, _EXC_CLS):
+ if not squelch_other_warnings:
+ return real_warn(msg, *arg, **kw)
+ else:
+ return
+
+ if not filters and not raise_on_any_unexpected:
+ return
+
+ for filter_ in filters:
+ if (
+ (search_msg and filter_.search(msg))
+ or (regex and filter_.match(msg))
+ or (not regex and filter_ == msg)
+ ):
+ seen.discard(filter_)
+ break
+ else:
+ if not squelch_other_warnings:
+ real_warn(msg, *arg, **kw)
+
+ with mock.patch("warnings.warn", our_warn), mock.patch(
+ "sqlalchemy.util.SQLALCHEMY_WARN_20", True
+ ), mock.patch(
+ "sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20", True
+ ), mock.patch(
+ "sqlalchemy.engine.row.LegacyRow._default_key_style", 2
+ ):
+ try:
+ yield
+ finally:
+ _SEEN = _FILTERS = _EXC_CLS = None
+
+ if assert_ and (not py2konly or not compat.py3k):
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
+
+
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+ _assert_no_stray_pool_connections()
+
+
+def _assert_no_stray_pool_connections():
+ engines.testing_reaper.assert_all_closed()
+
+
+def eq_regex(a, b, msg=None):
+ assert re.match(b, a), msg or "%r !~ %r" % (a, b)
+
+
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+
+def le_(a, b, msg=None):
+ """Assert a <= b, with repr messaging on failure."""
+ assert a <= b, msg or "%r != %r" % (a, b)
+
+
+def is_instance_of(a, b, msg=None):
+ assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
+
+
+def is_none(a, msg=None):
+ is_(a, None, msg=msg)
+
+
+def is_not_none(a, msg=None):
+ is_not(a, None, msg=msg)
+
+
+def is_true(a, msg=None):
+ is_(bool(a), True, msg=msg)
+
+
+def is_false(a, msg=None):
+ is_(bool(a), False, msg=msg)
+
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+
+def is_not(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+
+# deprecated. See #5429
+is_not_ = is_not
+
+
+def in_(a, b, msg=None):
+ """Assert a in b, with repr messaging on failure."""
+ assert a in b, msg or "%r not in %r" % (a, b)
+
+
+def not_in(a, b, msg=None):
+ """Assert a not in b, with repr messaging on failure."""
+ assert a not in b, msg or "%r is in %r" % (a, b)
+
+
+# deprecated. See #5429
+not_in_ = not_in
+
+
+def startswith_(a, fragment, msg=None):
+ """Assert a.startswith(fragment), with repr messaging on failure."""
+ assert a.startswith(fragment), msg or "%r does not start with %r" % (
+ a,
+ fragment,
+ )
+
+
+def eq_ignore_whitespace(a, b, msg=None):
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
+
+ assert a == b, msg or "%r != %r" % (a, b)
+
+
+def _assert_proper_exception_context(exception):
+ """assert that any exception we're catching does not have a __context__
+ without a __cause__, and that __suppress_context__ is never set.
+
+ Python 3 will report nested exceptions as "during the handling of
+ error X, error Y occurred". That's not what we want to do; we want
+ these exceptions in a cause chain.
+
+ """
+
+ if not util.py3k:
+ return
+
+ if (
+ exception.__context__ is not exception.__cause__
+ and not exception.__suppress_context__
+ ):
+ assert False, (
+ "Exception %r was correctly raised but did not set its "
+ "context %r as its cause."
+ % (exception, exception.__context__)
+ )
+
+
+def assert_raises(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ return _assert_raises(
+ except_cls, callable_, args, kwargs, msg=msg, check_context=True
+ )
+
+
+def assert_warns(except_cls, callable_, *args, **kwargs):
+ """Legacy adapter for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ Has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+
+ """
+ with _expect_warnings(except_cls, [".*"], squelch_other_warnings=True):
+ return callable_(*args, **kwargs)
+
+
+def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
+ """Legacy adapter for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ Has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+ Also uses regex.search() to match the given message to the error string
+ rather than regex.match().
+
+ """
+ with _expect_warnings(
+ except_cls,
+ [msg],
+ search_msg=True,
+ regex=False,
+ squelch_other_warnings=True,
+ ):
+ return callable_(*args, **kwargs)
+
+
+def assert_raises_message_context_ok(
+ except_cls, msg, callable_, *args, **kwargs
+):
+ return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+ except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
+
+ with _expect_raises(except_cls, msg, check_context) as ec:
+ callable_(*args, **kwargs)
+ return ec.error
+
+
+class _ErrorContainer(object):
+ error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+ if (
+ isinstance(except_cls, type)
+ and issubclass(except_cls, Warning)
+ or isinstance(except_cls, Warning)
+ ):
+ raise TypeError(
+ "Use expect_warnings for warnings, not "
+ "expect_raises / assert_raises"
+ )
+ ec = _ErrorContainer()
+ if check_context:
+ are_we_already_in_a_traceback = sys.exc_info()[0]
+ try:
+ yield ec
+ success = False
+ except except_cls as err:
+ ec.error = err
+ success = True
+ if msg is not None:
+ assert re.search(
+ msg, util.text_type(err), re.UNICODE
+ ), "%r !~ %s" % (msg, err)
+ if check_context and not are_we_already_in_a_traceback:
+ _assert_proper_exception_context(err)
+ print(util.text_type(err).encode("utf-8"))
+
+ # it's generally a good idea to not carry traceback objects outside
+ # of the except: block, but in this case especially we seem to have
+ # hit some bug in either python 3.10.0b2 or greenlet or both which
+ # this seems to fix:
+ # https://github.com/python-greenlet/greenlet/issues/242
+ del ec
+
+ # assert outside the block so it works for AssertionError too !
+ assert success, "Callable did not raise an exception"
+
+
+def expect_raises(except_cls, check_context=True):
+ return _expect_raises(except_cls, check_context=check_context)
+
+
+def expect_raises_message(except_cls, msg, check_context=True):
+ return _expect_raises(except_cls, msg=msg, check_context=check_context)
+
+
+class AssertsCompiledSQL(object):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ for_executemany=False,
+ check_literal_execute=None,
+ check_post_param=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ supports_default_values=True,
+ supports_default_metavalue=True,
+ literal_binds=False,
+ render_postcompile=False,
+ schema_translate_map=None,
+ render_schema_translate=False,
+ default_schema_name=None,
+ from_linting=False,
+ ):
+ if use_default_dialect:
+ dialect = default.DefaultDialect()
+ dialect.supports_default_values = supports_default_values
+ dialect.supports_default_metavalue = supports_default_metavalue
+ elif allow_dialect_select:
+ dialect = None
+ else:
+ if dialect is None:
+ dialect = getattr(self, "__dialect__", None)
+
+ if dialect is None:
+ dialect = config.db.dialect
+ elif dialect == "default":
+ dialect = default.DefaultDialect()
+ dialect.supports_default_values = supports_default_values
+ dialect.supports_default_metavalue = supports_default_metavalue
+ elif dialect == "default_enhanced":
+ dialect = default.StrCompileDialect()
+ elif isinstance(dialect, util.string_types):
+ dialect = url.URL.create(dialect).get_dialect()()
+
+ if default_schema_name:
+ dialect.default_schema_name = default_schema_name
+
+ kw = {}
+ compile_kwargs = {}
+
+ if schema_translate_map:
+ kw["schema_translate_map"] = schema_translate_map
+
+ if params is not None:
+ kw["column_keys"] = list(params)
+
+ if literal_binds:
+ compile_kwargs["literal_binds"] = True
+
+ if render_postcompile:
+ compile_kwargs["render_postcompile"] = True
+
+ if for_executemany:
+ kw["for_executemany"] = True
+
+ if render_schema_translate:
+ kw["render_schema_translate"] = True
+
+ if from_linting or getattr(self, "assert_from_linting", False):
+ kw["linting"] = sql.FROM_LINTING
+
+ from sqlalchemy import orm
+
+ if isinstance(clause, orm.Query):
+ stmt = clause._statement_20()
+ stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+ clause = stmt
+
+ if compile_kwargs:
+ kw["compile_kwargs"] = compile_kwargs
+
+ class DontAccess(object):
+ def __getattribute__(self, key):
+ raise NotImplementedError(
+ "compiler accessed .statement; use "
+ "compiler.current_executable"
+ )
+
+ class CheckCompilerAccess(object):
+ def __init__(self, test_statement):
+ self.test_statement = test_statement
+ self._annotations = {}
+ self.supports_execution = getattr(
+ test_statement, "supports_execution", False
+ )
+
+ if self.supports_execution:
+ self._execution_options = test_statement._execution_options
+
+ if hasattr(test_statement, "_returning"):
+ self._returning = test_statement._returning
+ if hasattr(test_statement, "_inline"):
+ self._inline = test_statement._inline
+ if hasattr(test_statement, "_return_defaults"):
+ self._return_defaults = test_statement._return_defaults
+
+ def _default_dialect(self):
+ return self.test_statement._default_dialect()
+
+ def compile(self, dialect, **kw):
+ return self.test_statement.compile.__func__(
+ self, dialect=dialect, **kw
+ )
+
+ def _compiler(self, dialect, **kw):
+ return self.test_statement._compiler.__func__(
+ self, dialect, **kw
+ )
+
+ def _compiler_dispatch(self, compiler, **kwargs):
+ if hasattr(compiler, "statement"):
+ with mock.patch.object(
+ compiler, "statement", DontAccess()
+ ):
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+ else:
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+
+ # no construct can assume it's the "top level" construct in all cases
+ # as anything can be nested. ensure constructs don't assume they
+ # are the "self.statement" element
+ c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
+
+ if isinstance(clause, sqltypes.TypeEngine):
+ cache_key_no_warnings = clause._static_cache_key
+ if cache_key_no_warnings:
+ hash(cache_key_no_warnings)
+ else:
+ cache_key_no_warnings = clause._generate_cache_key()
+ if cache_key_no_warnings:
+ hash(cache_key_no_warnings[0])
+
+ param_str = repr(getattr(c, "params", {}))
+ if util.py3k:
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
+ print(
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
+ else:
+ print(
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
+
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
+
+ eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+ if checkparams is not None:
+ eq_(c.construct_params(params), checkparams)
+ if checkpositional is not None:
+ p = c.construct_params(params)
+ eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+ if check_prefetch is not None:
+ eq_(c.prefetch, check_prefetch)
+ if check_literal_execute is not None:
+ eq_(
+ {
+ c.bind_names[b]: b.effective_value
+ for b in c.literal_execute_params
+ },
+ check_literal_execute,
+ )
+ if check_post_param is not None:
+ eq_(
+ {
+ c.bind_names[b]: b.effective_value
+ for b in c.post_compile_params
+ },
+ check_post_param,
+ )
+
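+# A minimal sketch of how a test case typically uses this mixin; the
+# table, statement and expected SQL string are hypothetical, and the
+# "default" dialect is assumed:
+#
+#     class MyCompileTest(AssertsCompiledSQL):
+#         __dialect__ = "default"
+#
+#         def test_simple_select(self):
+#             t = table("t", column("x"))
+#             self.assert_compile(
+#                 select(t).where(t.c.x == 5),
+#                 "SELECT t.x FROM t WHERE t.x = :x_1",
+#                 checkparams={"x_1": 5},
+#             )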
+
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ eq_(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
+
+ if strict_types:
+ msg = "Type '%s' doesn't correspond to type '%s'"
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
+ else:
+ self.assert_types_base(reflected_c, c)
+
+ if isinstance(c.type, sqltypes.String):
+ eq_(c.type.length, reflected_c.type.length)
+
+ eq_(
+ {f.column.name for f in c.foreign_keys},
+ {f.column.name for f in reflected_c.foreign_keys},
+ )
+ if c.server_default:
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
+
+ def assert_types_base(self, c1, c2):
+ assert c1.type._compare_type_affinity(
+ c2.type
+ ), "On column %r, type '%s' doesn't correspond to type '%s'" % (
+ c1.name,
+ c1.type,
+ c2.type,
+ )
+
+
+class AssertsExecutionResults(object):
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print(repr(result))
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list_):
+ self.assert_(
+ len(result) == len(list_),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
+ for i in range(0, len(list_)):
+ self.assert_row(class_, result[i], list_[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
+ for key, value in desc.items():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
+
+ def assert_unordered_result(self, result, cls, *expected):
+ """As assert_result, but the order of objects is not considered.
+
+ The algorithm is very expensive but not a big deal for the small
+ numbers of rows that the test suite manipulates.
+ """
+
+ class immutabledict(dict):
+ def __hash__(self):
+ return id(self)
+
+ found = util.IdentitySet(result)
+ expected = {immutabledict(e) for e in expected}
+
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
+
+ if len(found) != len(expected):
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
+
+ NOVALUE = object()
+
+ def _compare_item(obj, spec):
+ for key, value in spec.items():
+ if isinstance(value, tuple):
+ try:
+ self.assert_unordered_result(
+ getattr(obj, key), value[0], *value[1]
+ )
+ except AssertionError:
+ return False
+ else:
+ if getattr(obj, key, NOVALUE) != value:
+ return False
+ return True
+
+ for expected_item in expected:
+ for found_item in found:
+ if _compare_item(found_item, expected_item):
+ found.remove(found_item)
+ break
+ else:
+ fail(
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
+ return True
+
+ def sql_execution_asserter(self, db=None):
+ if db is None:
+ from . import db as db
+
+ return assertsql.assert_engine(db)
+
+ def assert_sql_execution(self, db, callable_, *rules):
+ with self.sql_execution_asserter(db) as asserter:
+ result = callable_()
+ asserter.assert_(*rules)
+ return result
+
+ def assert_sql(self, db, callable_, rules):
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
+ else:
+ newrule = assertsql.CompiledSQL(*rule)
+ newrules.append(newrule)
+
+ return self.assert_sql_execution(db, callable_, *newrules)
+
+ def assert_sql_count(self, db, callable_, count):
+ self.assert_sql_execution(
+ db, callable_, assertsql.CountStatements(count)
+ )
+
+ def assert_multiple_sql_count(self, dbs, callable_, counts):
+ recs = [
+ (self.sql_execution_asserter(db), db, count)
+ for (db, count) in zip(dbs, counts)
+ ]
+ asserters = []
+ for ctx, db, count in recs:
+ asserters.append(ctx.__enter__())
+ try:
+ return callable_()
+ finally:
+ for asserter, (ctx, db, count) in zip(asserters, recs):
+ ctx.__exit__(None, None, None)
+ asserter.assert_(assertsql.CountStatements(count))
+
+ @contextlib.contextmanager
+ def assert_execution(self, db, *rules):
+ with self.sql_execution_asserter(db) as asserter:
+ yield
+ asserter.assert_(*rules)
+
+ def assert_statement_count(self, db, count):
+ return self.assert_execution(db, assertsql.CountStatements(count))
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
new file mode 100644
index 0000000..565b3ed
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -0,0 +1,457 @@
+# testing/assertsql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import collections
+import contextlib
+import re
+
+from .. import event
+from .. import util
+from ..engine import url
+from ..engine.default import DefaultDialect
+from ..engine.util import _distill_cursor_params
+from ..schema import _DDLCompiles
+
+
+class AssertRule(object):
+
+ is_consumed = False
+ errormessage = None
+ consume_statement = True
+
+ def process_statement(self, execute_observed):
+ pass
+
+ def no_more_statements(self):
+ assert False, (
+ "All statements are complete, but pending "
+ "assertion rules remain"
+ )
+
+
+class SQLMatchRule(AssertRule):
+ pass
+
+
+class CursorSQL(SQLMatchRule):
+ def __init__(self, statement, params=None, consume_statement=True):
+ self.statement = statement
+ self.params = params
+ self.consume_statement = consume_statement
+
+ def process_statement(self, execute_observed):
+ stmt = execute_observed.statements[0]
+ if self.statement != stmt.statement or (
+ self.params is not None and self.params != stmt.parameters
+ ):
+ self.errormessage = (
+ "Testing for exact SQL %s parameters %s received %s %s"
+ % (
+ self.statement,
+ self.params,
+ stmt.statement,
+ stmt.parameters,
+ )
+ )
+ else:
+ execute_observed.statements.pop(0)
+ self.is_consumed = True
+ if not execute_observed.statements:
+ self.consume_statement = True
+
+
+class CompiledSQL(SQLMatchRule):
+ def __init__(self, statement, params=None, dialect="default"):
+ self.statement = statement
+ self.params = params
+ self.dialect = dialect
+
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r"[\n\t]", "", self.statement)
+ return received_statement == stmt
+
+ def _compile_dialect(self, execute_observed):
+ if self.dialect == "default":
+ dialect = DefaultDialect()
+ # this is currently what tests are expecting
+ # dialect.supports_default_values = True
+ dialect.supports_default_metavalue = True
+ return dialect
+ else:
+ # ugh
+ if self.dialect == "postgresql":
+ params = {"implicit_returning": True}
+ else:
+ params = {}
+ return url.URL.create(self.dialect).get_dialect()(**params)
+
+ def _received_statement(self, execute_observed):
+ """reconstruct the statement and params in terms
+ of a target dialect, which for CompiledSQL is just DefaultDialect."""
+
+ context = execute_observed.context
+ compare_dialect = self._compile_dialect(execute_observed)
+
+ # received_statement runs a full compile(). we should not need to
+ # consider extracted_parameters; if we do this indicates some state
+ # is being sent from a previous cached query, which some misbehaviors
+ # in the ORM can cause, see #6881
+ cache_key = None # execute_observed.context.compiled.cache_key
+ extracted_parameters = (
+ None # execute_observed.context.extracted_parameters
+ )
+
+ if "schema_translate_map" in context.execution_options:
+ map_ = context.execution_options["schema_translate_map"]
+ else:
+ map_ = None
+
+ if isinstance(execute_observed.clauseelement, _DDLCompiles):
+
+ compiled = execute_observed.clauseelement.compile(
+ dialect=compare_dialect,
+ schema_translate_map=map_,
+ )
+ else:
+ compiled = execute_observed.clauseelement.compile(
+ cache_key=cache_key,
+ dialect=compare_dialect,
+ column_keys=context.compiled.column_keys,
+ for_executemany=context.compiled.for_executemany,
+ schema_translate_map=map_,
+ )
+ _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
+ parameters = execute_observed.parameters
+
+ if not parameters:
+ _received_parameters = [
+ compiled.construct_params(
+ extracted_parameters=extracted_parameters
+ )
+ ]
+ else:
+ _received_parameters = [
+ compiled.construct_params(
+ m, extracted_parameters=extracted_parameters
+ )
+ for m in parameters
+ ]
+
+ return _received_statement, _received_parameters
+
+ def process_statement(self, execute_observed):
+ context = execute_observed.context
+
+ _received_statement, _received_parameters = self._received_statement(
+ execute_observed
+ )
+ params = self._all_params(context)
+
+ equivalent = self._compare_sql(execute_observed, _received_statement)
+
+ if equivalent:
+ if params is not None:
+ all_params = list(params)
+ all_received = list(_received_parameters)
+ while all_params and all_received:
+ param = dict(all_params.pop(0))
+
+ for idx, received in enumerate(list(all_received)):
+ # do a positive compare only
+ for param_key in param:
+ # a key in param did not match current
+ # 'received'
+ if (
+ param_key not in received
+ or received[param_key] != param[param_key]
+ ):
+ break
+ else:
+ # all keys in param matched 'received';
+ # onto next param
+ del all_received[idx]
+ break
+ else:
+ # param did not match any entry
+ # in all_received
+ equivalent = False
+ break
+ if all_params or all_received:
+ equivalent = False
+
+ if equivalent:
+ self.is_consumed = True
+ self.errormessage = None
+ else:
+ self.errormessage = self._failure_message(params) % {
+ "received_statement": _received_statement,
+ "received_parameters": _received_parameters,
+ }
+
+ def _all_params(self, context):
+ if self.params:
+ if callable(self.params):
+ params = self.params(context)
+ else:
+ params = self.params
+ if not isinstance(params, list):
+ params = [params]
+ return params
+ else:
+ return None
+
+ def _failure_message(self, expected_params):
+ return (
+ "Testing for compiled statement\n%r partial params %s, "
+ "received\n%%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (
+ self.statement.replace("%", "%%"),
+ repr(expected_params).replace("%", "%%"),
+ )
+ )
+
+
+class RegexSQL(CompiledSQL):
+ def __init__(self, regex, params=None, dialect="default"):
+ SQLMatchRule.__init__(self)
+ self.regex = re.compile(regex)
+ self.orig_regex = regex
+ self.params = params
+ self.dialect = dialect
+
+ def _failure_message(self, expected_params):
+ return (
+ "Testing for compiled statement ~%r partial params %s, "
+ "received %%(received_statement)r with params "
+ "%%(received_parameters)r"
+ % (
+ self.orig_regex.replace("%", "%%"),
+ repr(expected_params).replace("%", "%%"),
+ )
+ )
+
+ def _compare_sql(self, execute_observed, received_statement):
+ return bool(self.regex.match(received_statement))
+
+
+class DialectSQL(CompiledSQL):
+ def _compile_dialect(self, execute_observed):
+ return execute_observed.context.dialect
+
+ def _compare_no_space(self, real_stmt, received_stmt):
+ stmt = re.sub(r"[\n\t]", "", real_stmt)
+ return received_stmt == stmt
+
+ def _received_statement(self, execute_observed):
+ received_stmt, received_params = super(
+ DialectSQL, self
+ )._received_statement(execute_observed)
+
+ # TODO: why do we need this part?
+ for real_stmt in execute_observed.statements:
+ if self._compare_no_space(real_stmt.statement, received_stmt):
+ break
+ else:
+ raise AssertionError(
+ "Can't locate compiled statement %r in list of "
+ "statements actually invoked" % received_stmt
+ )
+
+ return received_stmt, execute_observed.context.compiled_parameters
+
+ def _compare_sql(self, execute_observed, received_statement):
+ stmt = re.sub(r"[\n\t]", "", self.statement)
+ # convert our comparison statement to have the
+ # paramstyle of the received
+ paramstyle = execute_observed.context.dialect.paramstyle
+ if paramstyle == "pyformat":
+ stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
+ else:
+ # positional params
+ repl = None
+ if paramstyle == "qmark":
+ repl = "?"
+ elif paramstyle == "format":
+ repl = r"%s"
+ elif paramstyle == "numeric":
+ repl = None
+ stmt = re.sub(r":([\w_]+)", repl, stmt)
+
+ return received_statement == stmt
+
+
+class CountStatements(AssertRule):
+ def __init__(self, count):
+ self.count = count
+ self._statement_count = 0
+
+ def process_statement(self, execute_observed):
+ self._statement_count += 1
+
+ def no_more_statements(self):
+ if self.count != self._statement_count:
+ assert False, "desired statement count %d does not match %d" % (
+ self.count,
+ self._statement_count,
+ )
+
+
+class AllOf(AssertRule):
+ def __init__(self, *rules):
+ self.rules = set(rules)
+
+ def process_statement(self, execute_observed):
+ for rule in list(self.rules):
+ rule.errormessage = None
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.discard(rule)
+ if not self.rules:
+ self.is_consumed = True
+ break
+ elif not rule.errormessage:
+ # rule is not done yet
+ self.errormessage = None
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
+
+
+class EachOf(AssertRule):
+ def __init__(self, *rules):
+ self.rules = list(rules)
+
+ def process_statement(self, execute_observed):
+ while self.rules:
+ rule = self.rules[0]
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.rules.pop(0)
+ elif rule.errormessage:
+ self.errormessage = rule.errormessage
+ if rule.consume_statement:
+ break
+
+ if not self.rules:
+ self.is_consumed = True
+
+ def no_more_statements(self):
+ if self.rules and not self.rules[0].is_consumed:
+ self.rules[0].no_more_statements()
+ elif self.rules:
+ super(EachOf, self).no_more_statements()
+
+
+class Conditional(EachOf):
+ def __init__(self, condition, rules, else_rules):
+ if condition:
+ super(Conditional, self).__init__(*rules)
+ else:
+ super(Conditional, self).__init__(*else_rules)
+
+
+class Or(AllOf):
+ def process_statement(self, execute_observed):
+ for rule in self.rules:
+ rule.process_statement(execute_observed)
+ if rule.is_consumed:
+ self.is_consumed = True
+ break
+ else:
+ self.errormessage = list(self.rules)[0].errormessage
+
+
+class SQLExecuteObserved(object):
+ def __init__(self, context, clauseelement, multiparams, params):
+ self.context = context
+ self.clauseelement = clauseelement
+ self.parameters = _distill_cursor_params(
+ context.connection, tuple(multiparams), params
+ )
+ self.statements = []
+
+ def __repr__(self):
+ return str(self.statements)
+
+
+class SQLCursorExecuteObserved(
+ collections.namedtuple(
+ "SQLCursorExecuteObserved",
+ ["statement", "parameters", "context", "executemany"],
+ )
+):
+ pass
+
+
+class SQLAsserter(object):
+ def __init__(self):
+ self.accumulated = []
+
+ def _close(self):
+ self._final = self.accumulated
+ del self.accumulated
+
+ def assert_(self, *rules):
+ rule = EachOf(*rules)
+
+ observed = list(self._final)
+ while observed:
+ statement = observed.pop(0)
+ rule.process_statement(statement)
+ if rule.is_consumed:
+ break
+ elif rule.errormessage:
+ assert False, rule.errormessage
+ if observed:
+ assert False, "Additional SQL statements remain:\n%s" % observed
+ elif not rule.is_consumed:
+ rule.no_more_statements()
+
+
+@contextlib.contextmanager
+def assert_engine(engine):
+ asserter = SQLAsserter()
+
+ orig = []
+
+ @event.listens_for(engine, "before_execute")
+ def connection_execute(
+ conn, clauseelement, multiparams, params, execution_options
+ ):
+ # grab the original statement + params before any cursor
+ # execution
+ orig[:] = clauseelement, multiparams, params
+
+ @event.listens_for(engine, "after_cursor_execute")
+ def cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ if not context:
+ return
+ # then grab real cursor statements and associate them all
+ # around a single context
+ if (
+ asserter.accumulated
+ and asserter.accumulated[-1].context is context
+ ):
+ obs = asserter.accumulated[-1]
+ else:
+ obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
+ asserter.accumulated.append(obs)
+ obs.statements.append(
+ SQLCursorExecuteObserved(
+ statement, parameters, context, executemany
+ )
+ )
+
+ try:
+ yield asserter
+ finally:
+ event.remove(engine, "after_cursor_execute", cursor_execute)
+ event.remove(engine, "before_execute", connection_execute)
+ asserter._close()
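+# A hedged usage sketch pairing assert_engine with a CompiledSQL rule;
+# the engine, table and parameters shown are hypothetical:
+#
+#     with assert_engine(engine) as asserter:
+#         with engine.begin() as conn:
+#             conn.execute(users.insert(), {"name": "spongebob"})
+#     asserter.assert_(
+#         CompiledSQL(
+#             "INSERT INTO users (name) VALUES (:name)",
+#             [{"name": "spongebob"}],
+#         )
+#     )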
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
new file mode 100644
index 0000000..2189060
--- /dev/null
+++ b/lib/sqlalchemy/testing/asyncio.py
@@ -0,0 +1,128 @@
+# testing/asyncio.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+# functions and wrappers to run tests, fixtures, provisioning and
+# setup/teardown in an asyncio event loop, conditionally based on the
+# current DB driver being used for a test.
+
+# note that SQLAlchemy's asyncio integration also supports a method
+# of running individual asyncio functions inside of separate event loops
+# using "async_fallback" mode; however running whole functions in the event
+# loop is a more accurate test for how SQLAlchemy's asyncio features
+# would run in the real world.
+
+
+from functools import wraps
+import inspect
+
+from . import config
+from ..util.concurrency import _util_async_run
+from ..util.concurrency import _util_async_run_coroutine_function
+
+# may be set to False if the
+# --disable-asyncio flag is passed to the test runner.
+ENABLE_ASYNCIO = True
+
+
+def _run_coroutine_function(fn, *args, **kwargs):
+ return _util_async_run_coroutine_function(fn, *args, **kwargs)
+
+
+def _assume_async(fn, *args, **kwargs):
+ """Run a function in an asyncio loop unconditionally.
+
+ This function is used for provisioning features like
+ testing a database connection for server info.
+
+ Note that for blocking IO database drivers, this means they block the
+ event loop.
+
+ """
+
+ if not ENABLE_ASYNCIO:
+ return fn(*args, **kwargs)
+
+ return _util_async_run(fn, *args, **kwargs)
+
+
+def _maybe_async_provisioning(fn, *args, **kwargs):
+ """Run a function in an asyncio loop if any current drivers might need it.
+
+ This function is used for provisioning features that take
+ place outside of a specific database driver being selected, so if the
+ current driver that happens to be used for the provisioning operation
+ is an async driver, it will run in asyncio and not fail.
+
+ Note that for blocking IO database drivers, this means they block the
+ event loop.
+
+ """
+ if not ENABLE_ASYNCIO:
+ return fn(*args, **kwargs)
+
+ if config.any_async:
+ return _util_async_run(fn, *args, **kwargs)
+ else:
+ return fn(*args, **kwargs)
+
+
+def _maybe_async(fn, *args, **kwargs):
+ """Run a function in an asyncio loop if the current selected driver is
+ async.
+
+ This function is used for test setup/teardown and tests themselves
+ where the current DB driver is known.
+
+
+ """
+ if not ENABLE_ASYNCIO:
+
+ return fn(*args, **kwargs)
+
+ is_async = config._current.is_async
+
+ if is_async:
+ return _util_async_run(fn, *args, **kwargs)
+ else:
+ return fn(*args, **kwargs)
+
+
+def _maybe_async_wrapper(fn):
+ """Apply the _maybe_async function to an existing function and return
+ as a wrapped callable, supporting generator functions as well.
+
+ This is currently used for pytest fixtures that support generator use.
+
+ """
+
+ if inspect.isgeneratorfunction(fn):
+ _stop = object()
+
+ def call_next(gen):
+ try:
+ return next(gen)
+ # can't raise StopIteration in an awaitable.
+ except StopIteration:
+ return _stop
+
+ @wraps(fn)
+ def wrap_fixture(*args, **kwargs):
+ gen = fn(*args, **kwargs)
+ while True:
+ value = _maybe_async(call_next, gen)
+ if value is _stop:
+ break
+ yield value
+
+ else:
+
+ @wraps(fn)
+ def wrap_fixture(*args, **kwargs):
+ return _maybe_async(fn, *args, **kwargs)
+
+ return wrap_fixture
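+# A sketch of how the wrapper is typically applied to a generator-style
+# fixture function; the fixture body is hypothetical:
+#
+#     @_maybe_async_wrapper
+#     def connection_fixture():
+#         conn = config.db.connect()
+#         yield conn
+#         conn.close()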
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
new file mode 100644
index 0000000..fc13a16
--- /dev/null
+++ b/lib/sqlalchemy/testing/config.py
@@ -0,0 +1,209 @@
+# testing/config.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import collections
+
+from .. import util
+
+requirements = None
+db = None
+db_url = None
+db_opts = None
+file_config = None
+test_schema = None
+test_schema_2 = None
+any_async = False
+_current = None
+ident = "main"
+
+_fixture_functions = None # installed by plugin_base
+
+
+def combinations(*comb, **kw):
+ r"""Deliver multiple versions of a test based on positional combinations.
+
+ This is a facade over pytest.mark.parametrize.
+
+
+ :param \*comb: argument combinations. These are tuples that will be passed
+ positionally to the decorated function.
+
+ :param argnames: optional list of argument names. These are the names
+ of the arguments in the test function that correspond to the entries
+ in each argument tuple. pytest.mark.parametrize requires this, however
+ the combinations function will derive it automatically if not present
+ by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
+ first argument is "self" which is discarded.
+
+ :param id\_: optional id template. This is a string template that
+ describes how the "id" for each parameter set should be defined, if any.
+ The number of characters in the template should match the number of
+ entries in each argument tuple. Each character describes how the
+ corresponding entry in the argument tuple should be handled, as far as
+ whether or not it is included in the arguments passed to the function, as
+ well as if it is included in the tokens used to create the id of the
+ parameter set.
+
+ If omitted, the argument combinations are passed to parametrize as is. If
+ passed, each argument combination is turned into a pytest.param() object,
+ mapping the elements of the argument tuple to produce an id based on a
+ character value in the same position within the string template using the
+ following scheme::
+
+ i - the given argument is a string that is part of the id only, don't
+ pass it as an argument
+
+ n - the given argument should be passed and it should be added to the
+ id by calling the .__name__ attribute
+
+ r - the given argument should be passed and it should be added to the
+ id by calling repr()
+
+ s - the given argument should be passed and it should be added to the
+ id by calling str()
+
+ a - (argument) the given argument should be passed and it should not
+ be used to generate the id
+
+ e.g.::
+
+ @testing.combinations(
+ (operator.eq, "eq"),
+ (operator.ne, "ne"),
+ (operator.gt, "gt"),
+ (operator.lt, "lt"),
+ id_="na"
+ )
+ def test_operator(self, opfunc, name):
+ pass
+
+ The above combination will call ``.__name__`` on the first member of
+ each tuple and use that as the "id" to pytest.param().
+
+
+ """
+ return _fixture_functions.combinations(*comb, **kw)
+
+
+def combinations_list(arg_iterable, **kw):
+ "As combinations, but takes a single iterable"
+ return combinations(*arg_iterable, **kw)
+
+
+def fixture(*arg, **kw):
+ return _fixture_functions.fixture(*arg, **kw)
+
+
+def get_current_test_name():
+ return _fixture_functions.get_current_test_name()
+
+
+def mark_base_test_class():
+ return _fixture_functions.mark_base_test_class()
+
+
+class Config(object):
+ def __init__(self, db, db_opts, options, file_config):
+ self._set_name(db)
+ self.db = db
+ self.db_opts = db_opts
+ self.options = options
+ self.file_config = file_config
+ self.test_schema = "test_schema"
+ self.test_schema_2 = "test_schema_2"
+
+ self.is_async = db.dialect.is_async and not util.asbool(
+ db.url.query.get("async_fallback", False)
+ )
+
+ _stack = collections.deque()
+ _configs = set()
+
+ def _set_name(self, db):
+ if db.dialect.server_version_info:
+ svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
+ self.name = "%s+%s_[%s]" % (db.name, db.driver, svi)
+ else:
+ self.name = "%s+%s" % (db.name, db.driver)
+
+ @classmethod
+ def register(cls, db, db_opts, options, file_config):
+ """add a config as one of the global configs.
+
+ If there are no configs set up yet, this config also
+ gets set as the "_current".
+ """
+ global any_async
+
+ cfg = Config(db, db_opts, options, file_config)
+
+ # if any backends include an async driver, then ensure
+ # all setup/teardown and tests are wrapped in the maybe_async()
+ # decorator that will set up a greenlet context for async drivers.
+ any_async = any_async or cfg.is_async
+
+ cls._configs.add(cfg)
+ return cfg
+
+ @classmethod
+ def set_as_current(cls, config, namespace):
+ global db, _current, db_url, test_schema, test_schema_2, db_opts
+ _current = config
+ db_url = config.db.url
+ db_opts = config.db_opts
+ test_schema = config.test_schema
+ test_schema_2 = config.test_schema_2
+ namespace.db = db = config.db
+
+ @classmethod
+ def push_engine(cls, db, namespace):
+ assert _current, "Can't push without a default Config set up"
+ cls.push(
+ Config(
+ db, _current.db_opts, _current.options, _current.file_config
+ ),
+ namespace,
+ )
+
+ @classmethod
+ def push(cls, config, namespace):
+ cls._stack.append(_current)
+ cls.set_as_current(config, namespace)
+
+ @classmethod
+ def pop(cls, namespace):
+ if cls._stack:
+ # a failed test w/ -x option can call reset() ahead of time
+ _current = cls._stack[-1]
+ del cls._stack[-1]
+ cls.set_as_current(_current, namespace)
+
+ @classmethod
+ def reset(cls, namespace):
+ if cls._stack:
+ cls.set_as_current(cls._stack[0], namespace)
+ cls._stack.clear()
+
+ @classmethod
+ def all_configs(cls):
+ return cls._configs
+
+ @classmethod
+ def all_dbs(cls):
+ for cfg in cls.all_configs():
+ yield cfg.db
+
+ def skip_test(self, msg):
+ skip_test(msg)
+
+
+def skip_test(msg):
+ raise _fixture_functions.skip_test_exception(msg)
+
+
+def async_test(fn):
+ return _fixture_functions.async_test(fn)
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
new file mode 100644
index 0000000..b8be6b9
--- /dev/null
+++ b/lib/sqlalchemy/testing/engines.py
@@ -0,0 +1,465 @@
+# testing/engines.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import collections
+import re
+import warnings
+import weakref
+
+from . import config
+from .util import decorator
+from .util import gc_collect
+from .. import event
+from .. import pool
+from ..util import await_only
+
+
+class ConnectionKiller(object):
+ def __init__(self):
+ self.proxy_refs = weakref.WeakKeyDictionary()
+ self.testing_engines = collections.defaultdict(set)
+ self.dbapi_connections = set()
+
+ def add_pool(self, pool):
+ event.listen(pool, "checkout", self._add_conn)
+ event.listen(pool, "checkin", self._remove_conn)
+ event.listen(pool, "close", self._remove_conn)
+ event.listen(pool, "close_detached", self._remove_conn)
+ # note we are keeping "invalidated" here, as those are still
+ # open connections we would like to roll back

+
+ def _add_conn(self, dbapi_con, con_record, con_proxy):
+ self.dbapi_connections.add(dbapi_con)
+ self.proxy_refs[con_proxy] = True
+
+ def _remove_conn(self, dbapi_conn, *arg):
+ self.dbapi_connections.discard(dbapi_conn)
+
+ def add_engine(self, engine, scope):
+ self.add_pool(engine.pool)
+
+ assert scope in ("class", "global", "function", "fixture")
+ self.testing_engines[scope].add(engine)
+
+ def _safe(self, fn):
+ try:
+ fn()
+ except Exception as e:
+ warnings.warn(
+ "testing_reaper couldn't rollback/close connection: %s" % e
+ )
+
+ def rollback_all(self):
+ for rec in list(self.proxy_refs):
+ if rec is not None and rec.is_valid:
+ self._safe(rec.rollback)
+
+ def checkin_all(self):
+ # run pool.checkin() for all ConnectionFairy instances we have
+ # tracked.
+
+ for rec in list(self.proxy_refs):
+ if rec is not None and rec.is_valid:
+ self.dbapi_connections.discard(rec.dbapi_connection)
+ self._safe(rec._checkin)
+
+ # for fairy refs that were GCed and could not close the connection,
+ # such as asyncio, roll back those remaining connections
+ for con in self.dbapi_connections:
+ self._safe(con.rollback)
+ self.dbapi_connections.clear()
+
+ def close_all(self):
+ self.checkin_all()
+
+ def prepare_for_drop_tables(self, connection):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ from . import provision
+
+ provision.prepare_for_drop_tables(connection.engine.url, connection)
+
+ def _drop_testing_engines(self, scope):
+ eng = self.testing_engines[scope]
+ for rec in list(eng):
+ for proxy_ref in list(self.proxy_refs):
+ if proxy_ref is not None and proxy_ref.is_valid:
+ if (
+ proxy_ref._pool is not None
+ and proxy_ref._pool is rec.pool
+ ):
+ self._safe(proxy_ref._checkin)
+ if hasattr(rec, "sync_engine"):
+ await_only(rec.dispose())
+ else:
+ rec.dispose()
+ eng.clear()
+
+ def after_test(self):
+ self._drop_testing_engines("function")
+
+ def after_test_outside_fixtures(self, test):
+ # don't do aggressive checks for third party test suites
+ if not config.bootstrapped_as_sqlalchemy:
+ return
+
+ if test.__class__.__leave_connections_for_teardown__:
+ return
+
+ self.checkin_all()
+
+ # on PostgreSQL, this will test for any "idle in transaction"
+ # connections. useful to identify tests with unusual patterns
+ # that can't be cleaned up correctly.
+ from . import provision
+
+ with config.db.connect() as conn:
+ provision.prepare_for_drop_tables(conn.engine.url, conn)
+
+ def stop_test_class_inside_fixtures(self):
+ self.checkin_all()
+ self._drop_testing_engines("function")
+ self._drop_testing_engines("class")
+
+ def stop_test_class_outside_fixtures(self):
+ # ensure no refs to checked out connections at all.
+
+ if pool.base._strong_ref_connection_records:
+ gc_collect()
+
+ if pool.base._strong_ref_connection_records:
+ ln = len(pool.base._strong_ref_connection_records)
+ pool.base._strong_ref_connection_records.clear()
+ assert (
+ False
+ ), "%d connection recs not cleared after test suite" % (ln)
+
+ def final_cleanup(self):
+ self.checkin_all()
+ for scope in self.testing_engines:
+ self._drop_testing_engines(scope)
+
+ def assert_all_closed(self):
+ for rec in self.proxy_refs:
+ if rec.is_valid:
+ assert False
+
+
+testing_reaper = ConnectionKiller()
+
+
+@decorator
+def assert_conns_closed(fn, *args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.assert_all_closed()
+
+
+@decorator
+def rollback_open_connections(fn, *args, **kw):
+ """Decorator that rolls back all open connections after fn execution."""
+
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.rollback_all()
+
+
+@decorator
+def close_first(fn, *args, **kw):
+ """Decorator that closes all connections before fn execution."""
+
+ testing_reaper.checkin_all()
+ fn(*args, **kw)
+
+
+@decorator
+def close_open_connections(fn, *args, **kw):
+ """Decorator that closes all connections after fn execution."""
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.checkin_all()
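+
+ # Hypothetical usage sketch (not part of this module): these decorators are
+ # applied to individual test functions that intentionally leave connections
+ # in unusual states, e.g.:
+ #
+ #     @close_open_connections
+ #     def test_something_that_leaves_a_connection_open(self):
+ #         ...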
+
+
+def all_dialects(exclude=None):
+ import sqlalchemy.dialects as d
+
+ for name in d.__all__:
+ # TEMPORARY
+ if exclude and name in exclude:
+ continue
+ mod = getattr(d, name, None)
+ if not mod:
+ mod = getattr(
+ __import__("sqlalchemy.dialects.%s" % name).dialects, name
+ )
+ yield mod.dialect()
+
+
+class ReconnectFixture(object):
+ def __init__(self, dbapi):
+ self.dbapi = dbapi
+ self.connections = []
+ self.is_stopped = False
+
+ def __getattr__(self, key):
+ return getattr(self.dbapi, key)
+
+ def connect(self, *args, **kwargs):
+
+ conn = self.dbapi.connect(*args, **kwargs)
+ if self.is_stopped:
+ self._safe(conn.close)
+ curs = conn.cursor() # should fail on Oracle etc.
+ # should fail for everything that didn't fail
+ # above, connection is closed
+ curs.execute("select 1")
+ assert False, "simulated connect failure didn't work"
+ else:
+ self.connections.append(conn)
+ return conn
+
+ def _safe(self, fn):
+ try:
+ fn()
+ except Exception as e:
+ warnings.warn("ReconnectFixture couldn't close connection: %s" % e)
+
+ def shutdown(self, stop=False):
+ # TODO: this doesn't cover all cases
+ # as nicely as we'd like, namely MySQLdb.
+ # would need to implement R. Brewer's
+ # proxy server idea to get better
+ # coverage.
+ self.is_stopped = stop
+ for c in list(self.connections):
+ self._safe(c.close)
+ self.connections = []
+
+ def restart(self):
+ self.is_stopped = False
+
+
+def reconnecting_engine(url=None, options=None):
+ url = url or config.db.url
+ dbapi = config.db.dialect.dbapi
+ if not options:
+ options = {}
+ options["module"] = ReconnectFixture(dbapi)
+ engine = testing_engine(url, options)
+ _dispose = engine.dispose
+
+ def dispose():
+ engine.dialect.dbapi.shutdown()
+ engine.dialect.dbapi.is_stopped = False
+ _dispose()
+
+ engine.test_shutdown = engine.dialect.dbapi.shutdown
+ engine.test_restart = engine.dialect.dbapi.restart
+ engine.dispose = dispose
+ return engine
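+
+ # A minimal usage sketch (hypothetical, not part of this module): the
+ # returned engine exposes test_shutdown() / test_restart() so a test can
+ # simulate database connections being severed and then recovering:
+ #
+ #     eng = reconnecting_engine()
+ #     conn = eng.connect()
+ #     eng.test_shutdown()      # close the underlying DBAPI connections
+ #     eng.test_restart()       # allow new connections again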
+
+
+def testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+ _sqlite_savepoint=False,
+):
+ """Produce an engine configured by --options with optional overrides."""
+
+ if asyncio:
+ assert not _sqlite_savepoint
+ from sqlalchemy.ext.asyncio import (
+ create_async_engine as create_engine,
+ )
+ elif future or (
+ config.db and config.db._is_future and future is not False
+ ):
+ from sqlalchemy.future import create_engine
+ else:
+ from sqlalchemy import create_engine
+ from sqlalchemy.engine.url import make_url
+
+ if not options:
+ use_reaper = True
+ scope = "function"
+ sqlite_savepoint = False
+ else:
+ use_reaper = options.pop("use_reaper", True)
+ scope = options.pop("scope", "function")
+ sqlite_savepoint = options.pop("sqlite_savepoint", False)
+
+ url = url or config.db.url
+
+ url = make_url(url)
+ if options is None:
+ if config.db is None or url.drivername == config.db.url.drivername:
+ options = config.db_opts
+ else:
+ options = {}
+ elif config.db is not None and url.drivername == config.db.url.drivername:
+ default_opt = config.db_opts.copy()
+ default_opt.update(options)
+
+ engine = create_engine(url, **options)
+
+ if sqlite_savepoint and engine.name == "sqlite":
+ # apply SQLite savepoint workaround
+ @event.listens_for(engine, "connect")
+ def do_connect(dbapi_connection, connection_record):
+ dbapi_connection.isolation_level = None
+
+ @event.listens_for(engine, "begin")
+ def do_begin(conn):
+ conn.exec_driver_sql("BEGIN")
+
+ if transfer_staticpool:
+ from sqlalchemy.pool import StaticPool
+
+ if config.db is not None and isinstance(config.db.pool, StaticPool):
+ use_reaper = False
+ engine.pool._transfer_from(config.db.pool)
+
+ if scope == "global":
+ if asyncio:
+ engine.sync_engine._has_events = True
+ else:
+ engine._has_events = (
+ True # enable event blocks, helps with profiling
+ )
+
+ if isinstance(engine.pool, pool.QueuePool):
+ engine.pool._timeout = 0
+ engine.pool._max_overflow = 0
+ if use_reaper:
+ testing_reaper.add_engine(engine, scope)
+
+ return engine
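+
+ # Hypothetical usage sketch (not part of this module): tests usually call
+ # this with no arguments to get a throwaway engine tracked by the reaper,
+ # or pass options to override create_engine() arguments:
+ #
+ #     eng = testing_engine()                               # current --dburi
+ #     eng2 = testing_engine(options={"pool_timeout": 5})   # with overrides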
+
+
+def mock_engine(dialect_name=None):
+ """Provides a mocking engine based on the current testing.db.
+
+ This is normally used to test DDL generation flow as emitted
+ by an Engine.
+
+ It should not be used in other cases, as assert_compile() and
+ assert_sql_execution() are much better choices with fewer
+ moving parts.
+
+ """
+
+ from sqlalchemy import create_mock_engine
+
+ if not dialect_name:
+ dialect_name = config.db.name
+
+ buffer = []
+
+ def executor(sql, *a, **kw):
+ buffer.append(sql)
+
+ def assert_sql(stmts):
+ recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
+ assert recv == stmts, recv
+
+ def print_sql():
+ d = engine.dialect
+ return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)
+
+ engine = create_mock_engine(dialect_name + "://", executor)
+ assert not hasattr(engine, "mock")
+ engine.mock = buffer
+ engine.assert_sql = assert_sql
+ engine.print_sql = print_sql
+ return engine
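+
+ # A minimal usage sketch (hypothetical, not part of this module): DDL
+ # emitted against the mock engine is buffered rather than executed, so a
+ # test can assert on the generated statements:
+ #
+ #     eng = mock_engine("sqlite")
+ #     metadata.create_all(eng)
+ #     eng.assert_sql(["CREATE TABLE ..."])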
+
+
+class DBAPIProxyCursor(object):
+ """Proxy a DBAPI cursor.
+
+ Tests can provide subclasses of this to intercept
+ DBAPI-level cursor operations.
+
+ """
+
+ def __init__(self, engine, conn, *args, **kwargs):
+ self.engine = engine
+ self.connection = conn
+ self.cursor = conn.cursor(*args, **kwargs)
+
+ def execute(self, stmt, parameters=None, **kw):
+ if parameters:
+ return self.cursor.execute(stmt, parameters, **kw)
+ else:
+ return self.cursor.execute(stmt, **kw)
+
+ def executemany(self, stmt, params, **kw):
+ return self.cursor.executemany(stmt, params, **kw)
+
+ def __iter__(self):
+ return iter(self.cursor)
+
+ def __getattr__(self, key):
+ return getattr(self.cursor, key)
+
+
+class DBAPIProxyConnection(object):
+ """Proxy a DBAPI connection.
+
+ Tests can provide subclasses of this to intercept
+ DBAPI-level connection operations.
+
+ """
+
+ def __init__(self, engine, cursor_cls):
+ self.conn = engine.pool._creator()
+ self.engine = engine
+ self.cursor_cls = cursor_cls
+
+ def cursor(self, *args, **kwargs):
+ return self.cursor_cls(self.engine, self.conn, *args, **kwargs)
+
+ def close(self):
+ self.conn.close()
+
+ def __getattr__(self, key):
+ return getattr(self.conn, key)
+
+
+def proxying_engine(
+ conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
+):
+ """Produce an engine that provides proxy hooks for
+ common methods.
+
+ """
+
+ def mock_conn():
+ return conn_cls(config.db, cursor_cls)
+
+ def _wrap_do_on_connect(do_on_connect):
+ def go(dbapi_conn):
+ return do_on_connect(dbapi_conn.conn)
+
+ return go
+
+ return testing_engine(
+ options={
+ "creator": mock_conn,
+ "_wrap_do_on_connect": _wrap_do_on_connect,
+ }
+ )
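+
+ # Hypothetical usage sketch (not part of this module): a test subclasses
+ # DBAPIProxyCursor to intercept cursor-level calls, then builds an engine
+ # with proxying_engine():
+ #
+ #     counts = []
+ #
+ #     class CountingCursor(DBAPIProxyCursor):
+ #         def execute(self, stmt, parameters=None, **kw):
+ #             counts.append(stmt)
+ #             return super(CountingCursor, self).execute(
+ #                 stmt, parameters, **kw)
+ #
+ #     eng = proxying_engine(cursor_cls=CountingCursor)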
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
new file mode 100644
index 0000000..8ea65d6
--- /dev/null
+++ b/lib/sqlalchemy/testing/entities.py
@@ -0,0 +1,111 @@
+# testing/entities.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import sqlalchemy as sa
+from .. import exc as sa_exc
+from ..util import compat
+
+_repr_stack = set()
+
+
+class BasicEntity(object):
+ def __init__(self, **kw):
+ for key, value in kw.items():
+ setattr(self, key, value)
+
+ def __repr__(self):
+ if id(self) in _repr_stack:
+ return object.__repr__(self)
+ _repr_stack.add(id(self))
+ try:
+ return "%s(%s)" % (
+ (self.__class__.__name__),
+ ", ".join(
+ [
+ "%s=%r" % (key, getattr(self, key))
+ for key in sorted(self.__dict__.keys())
+ if not key.startswith("_")
+ ]
+ ),
+ )
+ finally:
+ _repr_stack.remove(id(self))
+
+
+_recursion_stack = set()
+
+
+class ComparableMixin(object):
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __eq__(self, other):
+ """'Deep, sparse compare.
+
+ Deeply compare two entities, following the non-None attributes of the
+ non-persisted object, if possible.
+
+ """
+ if other is self:
+ return True
+ elif not self.__class__ == other.__class__:
+ return False
+
+ if id(self) in _recursion_stack:
+ return True
+ _recursion_stack.add(id(self))
+
+ try:
+ # pick the entity that's not SA persisted as the source
+ try:
+ self_key = sa.orm.attributes.instance_state(self).key
+ except sa.orm.exc.NO_STATE:
+ self_key = None
+
+ if other is None:
+ a = self
+ b = other
+ elif self_key is not None:
+ a = other
+ b = self
+ else:
+ a = self
+ b = other
+
+ for attr in list(a.__dict__):
+ if attr.startswith("_"):
+ continue
+ value = getattr(a, attr)
+
+ try:
+ # handle lazy loader errors
+ battr = getattr(b, attr)
+ except (AttributeError, sa_exc.UnboundExecutionError):
+ return False
+
+ if hasattr(value, "__iter__") and not isinstance(
+ value, compat.string_types
+ ):
+ if hasattr(value, "__getitem__") and not hasattr(
+ value, "keys"
+ ):
+ if list(value) != list(battr):
+ return False
+ else:
+ if set(value) != set(battr):
+ return False
+ else:
+ if value is not None and value != battr:
+ return False
+ return True
+ finally:
+ _recursion_stack.remove(id(self))
+
+
+class ComparableEntity(ComparableMixin, BasicEntity):
+ def __hash__(self):
+ return hash(self.__class__)
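+
+ # Illustrative sketch (hypothetical, not part of this module): two
+ # ComparableEntity instances compare equal when their non-underscored
+ # attributes match, which is what ORM round-trip tests rely on:
+ #
+ #     class User(ComparableEntity):
+ #         pass
+ #
+ #     User(id=1, name="ed") == User(id=1, name="ed")   # True
+ #     User(id=1, name="ed") == User(id=2, name="ed")   # False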
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
new file mode 100644
index 0000000..521a4aa
--- /dev/null
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -0,0 +1,465 @@
+# testing/exclusions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+import contextlib
+import operator
+import re
+import sys
+
+from . import config
+from .. import util
+from ..util import decorator
+from ..util.compat import inspect_getfullargspec
+
+
+def skip_if(predicate, reason=None):
+ rule = compound()
+ pred = _as_predicate(predicate, reason)
+ rule.skips.add(pred)
+ return rule
+
+
+def fails_if(predicate, reason=None):
+ rule = compound()
+ pred = _as_predicate(predicate, reason)
+ rule.fails.add(pred)
+ return rule
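+
+ # Hypothetical usage sketch (not part of this module): both functions
+ # return a `compound` object that is used as a decorator on tests, e.g.:
+ #
+ #     @skip_if("mssql", "not supported on this backend")
+ #     @fails_if("postgresql >= 14", "known issue on newer versions")
+ #     def test_something(self):
+ #         ...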
+
+
+class compound(object):
+ def __init__(self):
+ self.fails = set()
+ self.skips = set()
+ self.tags = set()
+
+ def __add__(self, other):
+ return self.add(other)
+
+ def as_skips(self):
+ rule = compound()
+ rule.skips.update(self.skips)
+ rule.skips.update(self.fails)
+ rule.tags.update(self.tags)
+ return rule
+
+ def add(self, *others):
+ copy = compound()
+ copy.fails.update(self.fails)
+ copy.skips.update(self.skips)
+ copy.tags.update(self.tags)
+ for other in others:
+ copy.fails.update(other.fails)
+ copy.skips.update(other.skips)
+ copy.tags.update(other.tags)
+ return copy
+
+ def not_(self):
+ copy = compound()
+ copy.fails.update(NotPredicate(fail) for fail in self.fails)
+ copy.skips.update(NotPredicate(skip) for skip in self.skips)
+ copy.tags.update(self.tags)
+ return copy
+
+ @property
+ def enabled(self):
+ return self.enabled_for_config(config._current)
+
+ def enabled_for_config(self, config):
+ for predicate in self.skips.union(self.fails):
+ if predicate(config):
+ return False
+ else:
+ return True
+
+ def matching_config_reasons(self, config):
+ return [
+ predicate._as_string(config)
+ for predicate in self.skips.union(self.fails)
+ if predicate(config)
+ ]
+
+ def include_test(self, include_tags, exclude_tags):
+ return bool(
+ not self.tags.intersection(exclude_tags)
+ and (not include_tags or self.tags.intersection(include_tags))
+ )
+
+ def _extend(self, other):
+ self.skips.update(other.skips)
+ self.fails.update(other.fails)
+ self.tags.update(other.tags)
+
+ def __call__(self, fn):
+ if hasattr(fn, "_sa_exclusion_extend"):
+ fn._sa_exclusion_extend._extend(self)
+ return fn
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ return self._do(config._current, fn, *args, **kw)
+
+ decorated = decorate(fn)
+ decorated._sa_exclusion_extend = self
+ return decorated
+
+ @contextlib.contextmanager
+ def fail_if(self):
+ all_fails = compound()
+ all_fails.fails.update(self.skips.union(self.fails))
+
+ try:
+ yield
+ except Exception as ex:
+ all_fails._expect_failure(config._current, ex)
+ else:
+ all_fails._expect_success(config._current)
+
+ def _do(self, cfg, fn, *args, **kw):
+ for skip in self.skips:
+ if skip(cfg):
+ msg = "'%s' : %s" % (
+ config.get_current_test_name(),
+ skip._as_string(cfg),
+ )
+ config.skip_test(msg)
+
+ try:
+ return_value = fn(*args, **kw)
+ except Exception as ex:
+ self._expect_failure(cfg, ex, name=fn.__name__)
+ else:
+ self._expect_success(cfg, name=fn.__name__)
+ return return_value
+
+ def _expect_failure(self, config, ex, name="block"):
+ for fail in self.fails:
+ if fail(config):
+ if util.py2k:
+ str_ex = unicode(ex).encode( # noqa: F821
+ "utf-8", errors="ignore"
+ )
+ else:
+ str_ex = str(ex)
+ print(
+ (
+ "%s failed as expected (%s): %s "
+ % (name, fail._as_string(config), str_ex)
+ )
+ )
+ break
+ else:
+ util.raise_(ex, with_traceback=sys.exc_info()[2])
+
+ def _expect_success(self, config, name="block"):
+ if not self.fails:
+ return
+
+ for fail in self.fails:
+ if fail(config):
+ raise AssertionError(
+ "Unexpected success for '%s' (%s)"
+ % (
+ name,
+ " and ".join(
+ fail._as_string(config) for fail in self.fails
+ ),
+ )
+ )
+
+
+def requires_tag(tagname):
+ return tags([tagname])
+
+
+def tags(tagnames):
+ comp = compound()
+ comp.tags.update(tagnames)
+ return comp
+
+
+def only_if(predicate, reason=None):
+ predicate = _as_predicate(predicate)
+ return skip_if(NotPredicate(predicate), reason)
+
+
+def succeeds_if(predicate, reason=None):
+ predicate = _as_predicate(predicate)
+ return fails_if(NotPredicate(predicate), reason)
+
+
+class Predicate(object):
+ @classmethod
+ def as_predicate(cls, predicate, description=None):
+ if isinstance(predicate, compound):
+ return cls.as_predicate(predicate.enabled_for_config, description)
+ elif isinstance(predicate, Predicate):
+ if description and predicate.description is None:
+ predicate.description = description
+ return predicate
+ elif isinstance(predicate, (list, set)):
+ return OrPredicate(
+ [cls.as_predicate(pred) for pred in predicate], description
+ )
+ elif isinstance(predicate, tuple):
+ return SpecPredicate(*predicate)
+ elif isinstance(predicate, util.string_types):
+ tokens = re.match(
+ r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
+ )
+ if not tokens:
+ raise ValueError(
+ "Couldn't locate DB name in predicate: %r" % predicate
+ )
+ db = tokens.group(1)
+ op = tokens.group(2)
+ spec = (
+ tuple(int(d) for d in tokens.group(3).split("."))
+ if tokens.group(3)
+ else None
+ )
+
+ return SpecPredicate(db, op, spec, description=description)
+ elif callable(predicate):
+ return LambdaPredicate(predicate, description)
+ else:
+ assert False, "unknown predicate type: %s" % predicate
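+
+ # Illustrative note (not part of this module): string predicates accepted
+ # above take forms such as:
+ #
+ #     "sqlite"                 -> backend name only
+ #     "postgresql+psycopg2"    -> backend plus driver
+ #     "mysql >= 8.0.19"        -> backend with a server version comparison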
+
+ def _format_description(self, config, negate=False):
+ bool_ = self(config)
+ if negate:
+ bool_ = not negate
+ return self.description % {
+ "driver": config.db.url.get_driver_name()
+ if config
+ else "<no driver>",
+ "database": config.db.url.get_backend_name()
+ if config
+ else "<no database>",
+ "doesnt_support": "doesn't support" if bool_ else "does support",
+ "does_support": "does support" if bool_ else "doesn't support",
+ }
+
+ def _as_string(self, config=None, negate=False):
+ raise NotImplementedError()
+
+
+class BooleanPredicate(Predicate):
+ def __init__(self, value, description=None):
+ self.value = value
+ self.description = description or "boolean %s" % value
+
+ def __call__(self, config):
+ return self.value
+
+ def _as_string(self, config, negate=False):
+ return self._format_description(config, negate=negate)
+
+
+class SpecPredicate(Predicate):
+ def __init__(self, db, op=None, spec=None, description=None):
+ self.db = db
+ self.op = op
+ self.spec = spec
+ self.description = description
+
+ _ops = {
+ "<": operator.lt,
+ ">": operator.gt,
+ "==": operator.eq,
+ "!=": operator.ne,
+ "<=": operator.le,
+ ">=": operator.ge,
+ "in": operator.contains,
+ "between": lambda val, pair: val >= pair[0] and val <= pair[1],
+ }
+
+ def __call__(self, config):
+ if config is None:
+ return False
+
+ engine = config.db
+
+ if "+" in self.db:
+ dialect, driver = self.db.split("+")
+ else:
+ dialect, driver = self.db, None
+
+ if dialect and engine.name != dialect:
+ return False
+ if driver is not None and engine.driver != driver:
+ return False
+
+ if self.op is not None:
+ assert driver is None, "DBAPI version specs not supported yet"
+
+ version = _server_version(engine)
+ oper = (
+ hasattr(self.op, "__call__") and self.op or self._ops[self.op]
+ )
+ return oper(version, self.spec)
+ else:
+ return True
+
+ def _as_string(self, config, negate=False):
+ if self.description is not None:
+ return self._format_description(config)
+ elif self.op is None:
+ if negate:
+ return "not %s" % self.db
+ else:
+ return "%s" % self.db
+ else:
+ if negate:
+ return "not %s %s %s" % (self.db, self.op, self.spec)
+ else:
+ return "%s %s %s" % (self.db, self.op, self.spec)
+
+
+class LambdaPredicate(Predicate):
+ def __init__(self, lambda_, description=None, args=None, kw=None):
+ spec = inspect_getfullargspec(lambda_)
+ if not spec[0]:
+ self.lambda_ = lambda db: lambda_()
+ else:
+ self.lambda_ = lambda_
+ self.args = args or ()
+ self.kw = kw or {}
+ if description:
+ self.description = description
+ elif lambda_.__doc__:
+ self.description = lambda_.__doc__
+ else:
+ self.description = "custom function"
+
+ def __call__(self, config):
+ return self.lambda_(config)
+
+ def _as_string(self, config, negate=False):
+ return self._format_description(config)
+
+
+class NotPredicate(Predicate):
+ def __init__(self, predicate, description=None):
+ self.predicate = predicate
+ self.description = description
+
+ def __call__(self, config):
+ return not self.predicate(config)
+
+ def _as_string(self, config, negate=False):
+ if self.description:
+ return self._format_description(config, not negate)
+ else:
+ return self.predicate._as_string(config, not negate)
+
+
+class OrPredicate(Predicate):
+ def __init__(self, predicates, description=None):
+ self.predicates = predicates
+ self.description = description
+
+ def __call__(self, config):
+ for pred in self.predicates:
+ if pred(config):
+ return True
+ return False
+
+ def _eval_str(self, config, negate=False):
+ if negate:
+ conjunction = " and "
+ else:
+ conjunction = " or "
+ return conjunction.join(
+ p._as_string(config, negate=negate) for p in self.predicates
+ )
+
+ def _negation_str(self, config):
+ if self.description is not None:
+ return "Not " + self._format_description(config)
+ else:
+ return self._eval_str(config, negate=True)
+
+ def _as_string(self, config, negate=False):
+ if negate:
+ return self._negation_str(config)
+ else:
+ if self.description is not None:
+ return self._format_description(config)
+ else:
+ return self._eval_str(config)
+
+
+_as_predicate = Predicate.as_predicate
+
+
+def _is_excluded(db, op, spec):
+ return SpecPredicate(db, op, spec)(config._current)
+
+
+def _server_version(engine):
+ """Return a server_version_info tuple."""
+
+ # force metadata to be retrieved
+ conn = engine.connect()
+ version = getattr(engine.dialect, "server_version_info", None)
+ if version is None:
+ version = ()
+ conn.close()
+ return version
+
+
+def db_spec(*dbs):
+ return OrPredicate([Predicate.as_predicate(db) for db in dbs])
+
+
+def open(): # noqa
+ return skip_if(BooleanPredicate(False, "mark as execute"))
+
+
+def closed():
+ return skip_if(BooleanPredicate(True, "marked as skip"))
+
+
+def fails(reason=None):
+ return fails_if(BooleanPredicate(True, reason or "expected to fail"))
+
+
+@decorator
+def future(fn, *arg):
+ return fails_if(LambdaPredicate(fn), "Future feature")
+
+
+def fails_on(db, reason=None):
+ return fails_if(db, reason)
+
+
+def fails_on_everything_except(*dbs):
+ return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
+
+
+def skip(db, reason=None):
+ return skip_if(db, reason)
+
+
+def only_on(dbs, reason=None):
+ return only_if(
+ OrPredicate(
+ [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
+ )
+ )
+
+
+def exclude(db, op, spec, reason=None):
+ return skip_if(SpecPredicate(db, op, spec), reason)
+
+
+def against(config, *queries):
+ assert queries, "no queries sent!"
+ return OrPredicate([Predicate.as_predicate(query) for query in queries])(
+ config
+ )
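+
+ # Hypothetical usage sketch (not part of this module): against() answers
+ # "does the given config match any of these backend specs?":
+ #
+ #     if against(config._current, "postgresql", "mysql >= 8"):
+ #         ...   # backend-specific expectations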
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
new file mode 100644
index 0000000..0a2d63b
--- /dev/null
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -0,0 +1,870 @@
+# testing/fixtures.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import contextlib
+import re
+import sys
+
+import sqlalchemy as sa
+from . import assertions
+from . import config
+from . import schema
+from .entities import BasicEntity
+from .entities import ComparableEntity
+from .entities import ComparableMixin # noqa
+from .util import adict
+from .util import drop_all_tables_from_metadata
+from .. import event
+from .. import util
+from ..orm import declarative_base
+from ..orm import registry
+from ..orm.decl_api import DeclarativeMeta
+from ..schema import sort_tables_and_constraints
+
+
+@config.mark_base_test_class()
+class TestBase(object):
+ # A sequence of requirement names matching testing.requires decorators
+ __requires__ = ()
+
+ # A sequence of dialect names to exclude from the test class.
+ __unsupported_on__ = ()
+
+ # If present, test class is only runnable for the *single* specified
+ # dialect. If you need multiple, use __unsupported_on__ and invert.
+ __only_on__ = None
+
+ # A sequence of no-arg callables. If any are True, the entire testcase is
+ # skipped.
+ __skip_if__ = None
+
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
+ def assert_(self, val, msg=None):
+ assert val, msg
+
+ @config.fixture()
+ def nocache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+ @config.fixture()
+ def connection_no_trans(self):
+ eng = getattr(self, "bind", None) or config.db
+
+ with eng.connect() as conn:
+ yield conn
+
+ @config.fixture()
+ def connection(self):
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
+
+ conn = eng.connect()
+ trans = conn.begin()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+ # trans would not be active here if the test is still using
+ # the legacy @provide_metadata decorator, as that decorator
+ # will have closed all connections.
+ conn.close()
+
+ @config.fixture()
+ def close_result_when_finished(self):
+ to_close = []
+ to_consume = []
+
+ def go(result, consume=False):
+ to_close.append(result)
+ if consume:
+ to_consume.append(result)
+
+ yield go
+ for r in to_consume:
+ try:
+ r.all()
+ except:
+ pass
+ for r in to_close:
+ try:
+ r.close()
+ except:
+ pass
+
+ @config.fixture()
+ def registry(self, metadata):
+ reg = registry(metadata=metadata)
+ yield reg
+ reg.dispose()
+
+ @config.fixture
+ def decl_base(self, registry):
+ return registry.generate_base()
+
+ @config.fixture()
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
+
+ @config.fixture()
+ def future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from . import engines
+
+ def gen_testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url,
+ options=options,
+ future=future,
+ asyncio=asyncio,
+ transfer_staticpool=transfer_staticpool,
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
+
+ @config.fixture()
+ def async_testing_engine(self, testing_engine):
+ def go(**kw):
+ kw["asyncio"] = True
+ return testing_engine(**kw)
+
+ return go
+
+ @config.fixture
+ def fixture_session(self):
+ return fixture_session()
+
+ @config.fixture()
+ def metadata(self, request):
+ """Provide bound MetaData for a single test, dropping afterwards."""
+
+ from ..sql import schema
+
+ metadata = schema.MetaData()
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
+
+ @config.fixture(
+ params=[
+ (rollback, second_operation, begin_nested)
+ for rollback in (True, False)
+ for second_operation in ("none", "execute", "begin")
+ for begin_nested in (
+ True,
+ False,
+ )
+ ]
+ )
+ def trans_ctx_manager_fixture(self, request, metadata):
+ rollback, second_operation, begin_nested = request.param
+
+ from sqlalchemy import Table, Column, Integer, func, select
+ from . import eq_
+
+ t = Table("test", metadata, Column("data", Integer))
+ eng = getattr(self, "bind", None) or config.db
+
+ t.create(eng)
+
+ def run_test(subject, trans_on_subject, execute_on_subject):
+ with subject.begin() as trans:
+
+ if begin_nested:
+ if not config.requirements.savepoints.enabled:
+ config.skip_test("savepoints not enabled")
+ if execute_on_subject:
+ nested_trans = subject.begin_nested()
+ else:
+ nested_trans = trans.begin_nested()
+
+ with nested_trans:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ # for nested trans, we always commit/rollback on the
+ # "nested trans" object itself.
+ # only Session(future=False) will affect savepoint
+ # transaction for session.commit/rollback
+
+ if rollback:
+ nested_trans.rollback()
+ else:
+ nested_trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction "
+ "inside context "
+ "manager. Please complete the context "
+ "manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(
+ t.insert(), {"data": 12}
+ )
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ # outside the nested trans block, but still inside the
+ # transaction block, we can run SQL, and it will be
+ # committed
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 14})
+ else:
+ trans.execute(t.insert(), {"data": 14})
+
+ else:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ if trans_on_subject:
+ if rollback:
+ subject.rollback()
+ else:
+ subject.commit()
+ else:
+ if rollback:
+ trans.rollback()
+ else:
+ trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction inside "
+ "context "
+ "manager. Please complete the context manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 12})
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if hasattr(trans, "begin"):
+ trans.begin()
+ else:
+ subject.begin()
+ elif second_operation == "begin_nested":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ expected_committed = 0
+ if begin_nested:
+ # begin_nested variant, we inserted a row after the nested
+ # block
+ expected_committed += 1
+ if not rollback:
+ # not rollback variant, our row inserted in the target
+ # block itself would be committed
+ expected_committed += 1
+
+ if execute_on_subject:
+ eq_(
+ subject.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+ else:
+ with subject.connect() as conn:
+ eq_(
+ conn.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+
+ return run_test
+
+
+_connection_fixture_connection = None
+
+
+@contextlib.contextmanager
+def _push_future_engine(engine):
+
+ from ..future.engine import Engine
+ from sqlalchemy import testing
+
+ facade = Engine._future_facade(engine)
+ config._current.push_engine(facade, testing)
+
+ yield facade
+
+ config._current.pop(testing)
+
+
+class FutureEngineMixin(object):
+ @config.fixture(autouse=True, scope="class")
+ def _push_future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+
+class TablesTest(TestBase):
+
+ # 'once', None
+ run_setup_bind = "once"
+
+ # 'once', 'each', None
+ run_define_tables = "once"
+
+ # 'once', 'each', None
+ run_create_tables = "once"
+
+ # 'once', 'each', None
+ run_inserts = "each"
+
+ # 'each', None
+ run_deletes = "each"
+
+ # 'once', None
+ run_dispose_bind = None
+
+ bind = None
+ _tables_metadata = None
+ tables = None
+ other = None
+ sequences = None
+
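+ # A hypothetical subclass sketch (not part of this module) showing how the
+ # run_* flags above are typically combined with define_tables() /
+ # insert_data():
+ #
+ #     class MyTest(TablesTest):
+ #         run_create_tables = "each"
+ #
+ #         @classmethod
+ #         def define_tables(cls, metadata):
+ #             Table("t", metadata, Column("id", Integer, primary_key=True))
+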
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ cls._setup_once_tables()
+
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
+ @classmethod
+ def _init_class(cls):
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
+
+ cls.other = adict()
+ cls.tables = adict()
+ cls.sequences = adict()
+
+ cls.bind = cls.setup_bind()
+ cls._tables_metadata = sa.MetaData()
+
+ @classmethod
+ def _setup_once_inserts(cls):
+ if cls.run_inserts == "once":
+ cls._load_fixtures()
+ with cls.bind.begin() as conn:
+ cls.insert_data(conn)
+
+ @classmethod
+ def _setup_once_tables(cls):
+ if cls.run_define_tables == "once":
+ cls.define_tables(cls._tables_metadata)
+ if cls.run_create_tables == "once":
+ cls._tables_metadata.create_all(cls.bind)
+ cls.tables.update(cls._tables_metadata.tables)
+ cls.sequences.update(cls._tables_metadata._sequences)
+
+ def _setup_each_tables(self):
+ if self.run_define_tables == "each":
+ self.define_tables(self._tables_metadata)
+ if self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+ self.tables.update(self._tables_metadata.tables)
+ self.sequences.update(self._tables_metadata._sequences)
+ elif self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+
+ def _setup_each_inserts(self):
+ if self.run_inserts == "each":
+ self._load_fixtures()
+ with self.bind.begin() as conn:
+ self.insert_data(conn)
+
+ def _teardown_each_tables(self):
+ if self.run_define_tables == "each":
+ self.tables.clear()
+ if self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+ self._tables_metadata.clear()
+ elif self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+
+ savepoints = getattr(config.requirements, "savepoints", False)
+ if savepoints:
+ savepoints = savepoints.enabled
+
+ # no need to run deletes if tables are recreated on setup
+ if (
+ self.run_define_tables != "each"
+ and self.run_create_tables != "each"
+ and self.run_deletes == "each"
+ ):
+ with self.bind.begin() as conn:
+ for table in reversed(
+ [
+ t
+ for (t, fks) in sort_tables_and_constraints(
+ self._tables_metadata.tables.values()
+ )
+ if t is not None
+ ]
+ ):
+ try:
+ if savepoints:
+ with conn.begin_nested():
+ conn.execute(table.delete())
+ else:
+ conn.execute(table.delete())
+ except sa.exc.DBAPIError as ex:
+ util.print_(
+ ("Error emptying table %s: %r" % (table, ex)),
+ file=sys.stderr,
+ )
+
+ @classmethod
+ def _teardown_once_metadata_bind(cls):
+ if cls.run_create_tables:
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
+
+ if cls.run_dispose_bind == "once":
+ cls.dispose_bind(cls.bind)
+
+ cls._tables_metadata.bind = None
+
+ if cls.run_setup_bind is not None:
+ cls.bind = None
+
+ @classmethod
+ def setup_bind(cls):
+ return config.db
+
+ @classmethod
+ def dispose_bind(cls, bind):
+ if hasattr(bind, "dispose"):
+ bind.dispose()
+ elif hasattr(bind, "close"):
+ bind.close()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ pass
+
+ @classmethod
+ def fixtures(cls):
+ return {}
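+
+ # Illustrative sketch (hypothetical, not part of this module): subclasses
+ # override fixtures() to return rows keyed by table (or table name), where
+ # the first tuple is the column headers; _load_fixtures() below consumes
+ # this format:
+ #
+ #     return {
+ #         "users": (
+ #             ("id", "name"),
+ #             (1, "ed"),
+ #             (2, "wendy"),
+ #         )
+ #     }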
+
+ @classmethod
+ def insert_data(cls, connection):
+ pass
+
+ def sql_count_(self, count, fn):
+ self.assert_sql_count(self.bind, fn, count)
+
+ def sql_eq_(self, callable_, statements):
+ self.assert_sql(self.bind, callable_, statements)
+
+ @classmethod
+ def _load_fixtures(cls):
+ """Insert rows as represented by the fixtures() method."""
+ headers, rows = {}, {}
+ for table, data in cls.fixtures().items():
+ if len(data) < 2:
+ continue
+ if isinstance(table, util.string_types):
+ table = cls.tables[table]
+ headers[table] = data[0]
+ rows[table] = data[1:]
+ for table, fks in sort_tables_and_constraints(
+ cls._tables_metadata.tables.values()
+ ):
+ if table is None:
+ continue
+ if table not in headers:
+ continue
+ with cls.bind.begin() as conn:
+ conn.execute(
+ table.insert(),
+ [
+ dict(zip(headers[table], column_values))
+ for column_values in rows[table]
+ ],
+ )
+
+
+class NoCache(object):
+ @config.fixture(autouse=True, scope="function")
+ def _disable_cache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+
+class RemovesEvents(object):
+ @util.memoized_property
+ def _event_fns(self):
+ return set()
+
+ def event_listen(self, target, name, fn, **kw):
+ self._event_fns.add((target, name, fn))
+ event.listen(target, name, fn, **kw)
+
+ @config.fixture(autouse=True, scope="function")
+ def _remove_events(self):
+ yield
+ for key in self._event_fns:
+ event.remove(*key)
+
+
+_fixture_sessions = set()
+
+
+def fixture_session(**kw):
+ kw.setdefault("autoflush", True)
+ kw.setdefault("expire_on_commit", True)
+
+ bind = kw.pop("bind", config.db)
+
+ sess = sa.orm.Session(bind, **kw)
+ _fixture_sessions.add(sess)
+ return sess
+
+
+def _close_all_sessions():
+ # will close all still-referenced sessions
+ sa.orm.session.close_all_sessions()
+ _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+ _close_all_sessions()
+ sa.orm.clear_mappers()
+
+
+def after_test():
+ if _fixture_sessions:
+ _close_all_sessions()
+
+
+class ORMTest(TestBase):
+ pass
+
+
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
+ # 'once', 'each', None
+ run_setup_classes = "once"
+
+ # 'once', 'each', None
+ run_setup_mappers = "each"
+
+ classes = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ if cls.classes is None:
+ cls.classes = adict()
+
+ cls._setup_once_tables()
+ cls._setup_once_classes()
+ cls._setup_once_mappers()
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_class()
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_classes()
+ self._setup_each_mappers()
+ self._setup_each_inserts()
+
+ yield
+
+ sa.orm.session.close_all_sessions()
+ self._teardown_each_mappers()
+ self._teardown_each_classes()
+ self._teardown_each_tables()
+
+ @classmethod
+ def _teardown_once_class(cls):
+ cls.classes.clear()
+
+ @classmethod
+ def _setup_once_classes(cls):
+ if cls.run_setup_classes == "once":
+ cls._with_register_classes(cls.setup_classes)
+
+ @classmethod
+ def _setup_once_mappers(cls):
+ if cls.run_setup_mappers == "once":
+ cls.mapper_registry, cls.mapper = cls._generate_registry()
+ cls._with_register_classes(cls.setup_mappers)
+
+ def _setup_each_mappers(self):
+ if self.run_setup_mappers != "once":
+ (
+ self.__class__.mapper_registry,
+ self.__class__.mapper,
+ ) = self._generate_registry()
+
+ if self.run_setup_mappers == "each":
+ self._with_register_classes(self.setup_mappers)
+
+ def _setup_each_classes(self):
+ if self.run_setup_classes == "each":
+ self._with_register_classes(self.setup_classes)
+
+ @classmethod
+ def _generate_registry(cls):
+ decl = registry(metadata=cls._tables_metadata)
+ return decl, decl.map_imperatively
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ """Run a setup method, framing the operation with a Base class
+ that will catch new subclasses to be established within
+ the "classes" registry.
+
+ """
+ cls_registry = cls.classes
+
+ assert cls_registry is not None
+
+ class FindFixture(type):
+ def __init__(cls, classname, bases, dict_):
+ cls_registry[classname] = cls
+ type.__init__(cls, classname, bases, dict_)
+
+ class _Base(util.with_metaclass(FindFixture, object)):
+ pass
+
+ class Basic(BasicEntity, _Base):
+ pass
+
+ class Comparable(ComparableEntity, _Base):
+ pass
+
+ cls.Basic = Basic
+ cls.Comparable = Comparable
+ fn()
+
+ def _teardown_each_mappers(self):
+ # some tests create mappers in the test bodies
+ # and will define setup_mappers as None -
+ # clear mappers in any case
+ if self.run_setup_mappers != "once":
+ sa.orm.clear_mappers()
+
+ def _teardown_each_classes(self):
+ if self.run_setup_classes != "once":
+ self.classes.clear()
+
+ @classmethod
+ def setup_classes(cls):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ pass
+
+
+class DeclarativeMappedTest(MappedTest):
+ run_setup_classes = "once"
+ run_setup_mappers = "once"
+
+ @classmethod
+ def _setup_once_tables(cls):
+ pass
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ cls_registry = cls.classes
+
+ class FindFixtureDeclarative(DeclarativeMeta):
+ def __init__(cls, classname, bases, dict_):
+ cls_registry[classname] = cls
+ DeclarativeMeta.__init__(cls, classname, bases, dict_)
+
+ class DeclarativeBasic(object):
+ __table_cls__ = schema.Table
+
+ _DeclBase = declarative_base(
+ metadata=cls._tables_metadata,
+ metaclass=FindFixtureDeclarative,
+ cls=DeclarativeBasic,
+ )
+
+ cls.DeclarativeBasic = _DeclBase
+
+ # sets up cls.Basic which is helpful for things like composite
+ # classes
+ super(DeclarativeMappedTest, cls)._with_register_classes(fn)
+
+ if cls._tables_metadata.tables and cls.run_create_tables:
+ cls._tables_metadata.create_all(config.db)
+
+
+class ComputedReflectionFixtureTest(TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("computed_columns", "table_reflection")
+
+ regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
+
+ def normalize(self, text):
+ return self.regexp.sub("", text).lower()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ from .. import Integer
+ from .. import testing
+ from ..schema import Column
+ from ..schema import Computed
+ from ..schema import Table
+
+ Table(
+ "computed_default_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_col", Integer, Computed("normal + 42")),
+ Column("with_default", Integer, server_default="42"),
+ )
+
+ t = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal + 42")),
+ )
+
+ if testing.requires.schemas.enabled:
+ t2 = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal / 42")),
+ schema=config.test_schema,
+ )
+
+ if testing.requires.computed_columns_virtual.enabled:
+ t.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal + 2", persisted=False),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal / 2", persisted=False),
+ )
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ t.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal - 42", persisted=True),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal * 42", persisted=True),
+ )
+ )
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
new file mode 100644
index 0000000..e333c70
--- /dev/null
+++ b/lib/sqlalchemy/testing/mock.py
@@ -0,0 +1,32 @@
+# testing/mock.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Import stub for mock library.
+"""
+from __future__ import absolute_import
+
+from ..util import py3k
+
+
+if py3k:
+ from unittest.mock import MagicMock
+ from unittest.mock import Mock
+ from unittest.mock import call
+ from unittest.mock import patch
+ from unittest.mock import ANY
+else:
+ try:
+ from mock import MagicMock # noqa
+ from mock import Mock # noqa
+ from mock import call # noqa
+ from mock import patch # noqa
+ from mock import ANY # noqa
+ except ImportError:
+ raise ImportError(
+ "SQLAlchemy's test suite requires the "
+ "'mock' library as of 0.8.2."
+ )
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
new file mode 100644
index 0000000..f05960c
--- /dev/null
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -0,0 +1,151 @@
+# testing/pickleable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Classes used in pickling tests, need to be at the module level for
+unpickling.
+"""
+
+from . import fixtures
+from ..schema import Column
+from ..types import String
+
+
+class User(fixtures.ComparableEntity):
+ pass
+
+
+class Order(fixtures.ComparableEntity):
+ pass
+
+
+class Dingaling(fixtures.ComparableEntity):
+ pass
+
+
+class EmailUser(User):
+ pass
+
+
+class Address(fixtures.ComparableEntity):
+ pass
+
+
+# TODO: these are kind of arbitrary....
+class Child1(fixtures.ComparableEntity):
+ pass
+
+
+class Child2(fixtures.ComparableEntity):
+ pass
+
+
+class Parent(fixtures.ComparableEntity):
+ pass
+
+
+class Screen(object):
+ def __init__(self, obj, parent=None):
+ self.obj = obj
+ self.parent = parent
+
+
+class Mixin(object):
+ email_address = Column(String)
+
+
+class AddressWMixin(Mixin, fixtures.ComparableEntity):
+ pass
+
+
+class Foo(object):
+ def __init__(self, moredata, stuff="im stuff"):
+ self.data = "im data"
+ self.stuff = stuff
+ self.moredata = moredata
+
+ __hash__ = object.__hash__
+
+ def __eq__(self, other):
+ return (
+ other.data == self.data
+ and other.stuff == self.stuff
+ and other.moredata == self.moredata
+ )
+
+
+class Bar(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ __hash__ = object.__hash__
+
+ def __eq__(self, other):
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
+
+ def __str__(self):
+ return "Bar(%d, %d)" % (self.x, self.y)
+
+
+class OldSchool:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __eq__(self, other):
+ return (
+ other.__class__ is self.__class__
+ and other.x == self.x
+ and other.y == self.y
+ )
+
+
+class OldSchoolWithoutCompare:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+
+class BarWithoutCompare(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __str__(self):
+ return "Bar(%d, %d)" % (self.x, self.y)
+
+
+class NotComparable(object):
+ def __init__(self, data):
+ self.data = data
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return NotImplemented
+
+ def __ne__(self, other):
+ return NotImplemented
+
+
+class BrokenComparable(object):
+ def __init__(self, data):
+ self.data = data
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ raise NotImplementedError
+
+ def __ne__(self, other):
+ raise NotImplementedError
diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/__init__.py
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
new file mode 100644
index 0000000..6721f48
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -0,0 +1,54 @@
+"""
+Bootstrapper for test framework plugins.
+
+The entire rationale for this system is to get the modules in plugin/
+imported without importing all of the supporting library, so that we can
+set up things for testing before coverage starts.
+
+The rationale for all of plugin/ being *in* the supporting library in the
+first place is so that the testing and plugin suite is available to other
+libraries, mainly external SQLAlchemy and Alembic dialects, to make use
+of the same test environment and standard suites available to
+SQLAlchemy/Alembic themselves without the need to ship/install a separate
+package outside of SQLAlchemy.
+
+
+"""
+
+import os
+import sys
+
+
+bootstrap_file = locals()["bootstrap_file"]
+to_bootstrap = locals()["to_bootstrap"]
+
+
+def load_file_as_module(name):
+ path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
+
+ if sys.version_info >= (3, 5):
+ import importlib.util
+
+ spec = importlib.util.spec_from_file_location(name, path)
+ assert spec is not None
+ assert spec.loader is not None
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ else:
+ import imp
+
+ mod = imp.load_source(name, path)
+
+ return mod
+
+
+if to_bootstrap == "pytest":
+ sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+ sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
+ if sys.version_info < (3, 0):
+ sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
+ "reinvent_fixtures_py2k"
+ )
+ sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
+else:
+ raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
new file mode 100644
index 0000000..d59564e
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -0,0 +1,789 @@
+# plugin/plugin_base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Testing extensions.
+
+this module is designed to work as a testing-framework-agnostic library,
+created so that multiple test frameworks can be supported at once
+(mostly so that we can migrate to new ones). The current target
+is pytest.
+
+"""
+
+from __future__ import absolute_import
+
+import abc
+import logging
+import re
+import sys
+
+# flag which indicates we are in the SQLAlchemy testing suite,
+# and not that of Alembic or a third party dialect.
+bootstrapped_as_sqlalchemy = False
+
+log = logging.getLogger("sqlalchemy.testing.plugin_base")
+
+
+py3k = sys.version_info >= (3, 0)
+
+if py3k:
+ import configparser
+
+ ABC = abc.ABC
+else:
+ import ConfigParser as configparser
+ import collections as collections_abc # noqa
+
+ class ABC(object):
+ __metaclass__ = abc.ABCMeta
+
+
+# late imports
+fixtures = None
+engines = None
+exclusions = None
+warnings = None
+profiling = None
+provision = None
+assertions = None
+requirements = None
+config = None
+testing = None
+util = None
+file_config = None
+
+logging = None
+include_tags = set()
+exclude_tags = set()
+options = None
+
+
+def setup_options(make_option):
+ make_option(
+ "--log-info",
+ action="callback",
+ type=str,
+ callback=_log,
+ help="turn on info logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--log-debug",
+ action="callback",
+ type=str,
+ callback=_log,
+ help="turn on debug logging for <LOG> (multiple OK)",
+ )
+ make_option(
+ "--db",
+ action="append",
+ type=str,
+ dest="db",
+ help="Use prefab database uri. Multiple OK, "
+ "first one is run by default.",
+ )
+ make_option(
+ "--dbs",
+ action="callback",
+ zeroarg_callback=_list_dbs,
+ help="List available prefab dbs",
+ )
+ make_option(
+ "--dburi",
+ action="append",
+ type=str,
+ dest="dburi",
+ help="Database uri. Multiple OK, " "first one is run by default.",
+ )
+ make_option(
+ "--dbdriver",
+ action="append",
+ type=str,
+ dest="dbdriver",
+ help="Additional database drivers to include in tests. "
+ "These are linked to the existing database URLs by the "
+ "provisioning system.",
+ )
+ make_option(
+ "--dropfirst",
+ action="store_true",
+ dest="dropfirst",
+ help="Drop all tables in the target database first",
+ )
+ make_option(
+ "--disable-asyncio",
+ action="store_true",
+ help="disable test / fixtures / provisoning running in asyncio",
+ )
+ make_option(
+ "--backend-only",
+ action="store_true",
+ dest="backend_only",
+ help="Run only tests marked with __backend__ or __sparse_backend__",
+ )
+ make_option(
+ "--nomemory",
+ action="store_true",
+ dest="nomemory",
+ help="Don't run memory profiling tests",
+ )
+ make_option(
+ "--notimingintensive",
+ action="store_true",
+ dest="notimingintensive",
+ help="Don't run timing intensive tests",
+ )
+ make_option(
+ "--profile-sort",
+ type=str,
+ default="cumulative",
+ dest="profilesort",
+ help="Type of sort for profiling standard output",
+ )
+ make_option(
+ "--profile-dump",
+ type=str,
+ dest="profiledump",
+ help="Filename where a single profile run will be dumped",
+ )
+ make_option(
+ "--postgresql-templatedb",
+ type=str,
+ help="name of template database to use for PostgreSQL "
+ "CREATE DATABASE (defaults to current database)",
+ )
+ make_option(
+ "--low-connections",
+ action="store_true",
+ dest="low_connections",
+ help="Use a low number of distinct connections - "
+ "i.e. for Oracle TNS",
+ )
+ make_option(
+ "--write-idents",
+ type=str,
+ dest="write_idents",
+ help="write out generated follower idents to <file>, "
+ "when -n<num> is used",
+ )
+ make_option(
+ "--reversetop",
+ action="store_true",
+ dest="reversetop",
+ default=False,
+ help="Use a random-ordering set implementation in the ORM "
+ "(helps reveal dependency issues)",
+ )
+ make_option(
+ "--requirements",
+ action="callback",
+ type=str,
+ callback=_requirements_opt,
+ help="requirements class for testing, overrides setup.cfg",
+ )
+ make_option(
+ "--with-cdecimal",
+ action="store_true",
+ dest="cdecimal",
+ default=False,
+ help="Monkeypatch the cdecimal library into Python 'decimal' "
+ "for all tests",
+ )
+ make_option(
+ "--include-tag",
+ action="callback",
+ callback=_include_tag,
+ type=str,
+ help="Include tests with tag <tag>",
+ )
+ make_option(
+ "--exclude-tag",
+ action="callback",
+ callback=_exclude_tag,
+ type=str,
+ help="Exclude tests with tag <tag>",
+ )
+ make_option(
+ "--write-profiles",
+ action="store_true",
+ dest="write_profiles",
+ default=False,
+ help="Write/update failing profiling data.",
+ )
+ make_option(
+ "--force-write-profiles",
+ action="store_true",
+ dest="force_write_profiles",
+ default=False,
+ help="Unconditionally write/update profiling data.",
+ )
+ make_option(
+ "--dump-pyannotate",
+ type=str,
+ dest="dump_pyannotate",
+ help="Run pyannotate and dump json info to given file",
+ )
+ make_option(
+ "--mypy-extra-test-path",
+ type=str,
+ action="append",
+ default=[],
+ dest="mypy_extra_test_paths",
+ help="Additional test directories to add to the mypy tests. "
+ "This is used only when running mypy tests. Multiple OK",
+ )
+
+
+def configure_follower(follower_ident):
+ """Configure required state for a follower.
+
+    This is invoked in the parent process and typically includes
+    database creation.
+
+ """
+ from sqlalchemy.testing import provision
+
+ provision.FOLLOWER_IDENT = follower_ident
+
+
+def memoize_important_follower_config(dict_):
+ """Store important configuration we will need to send to a follower.
+
+    This is invoked in the parent process after normal config is set up.
+
+ This is necessary as pytest seems to not be using forking, so we
+ start with nothing in memory, *but* it isn't running our argparse
+ callables, so we have to just copy all of that over.
+
+ """
+ dict_["memoized_config"] = {
+ "include_tags": include_tags,
+ "exclude_tags": exclude_tags,
+ }
+
+
+def restore_important_follower_config(dict_):
+ """Restore important configuration needed by a follower.
+
+    This is invoked in the follower process.
+
+ """
+ global include_tags, exclude_tags
+ include_tags.update(dict_["memoized_config"]["include_tags"])
+ exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
+
+
+def read_config():
+ global file_config
+ file_config = configparser.ConfigParser()
+ file_config.read(["setup.cfg", "test.cfg"])
+
+
+def pre_begin(opt):
+    """things to set up early, before coverage might be set up."""
+ global options
+ options = opt
+ for fn in pre_configure:
+ fn(options, file_config)
+
+
+def set_coverage_flag(value):
+ options.has_coverage = value
+
+
+def post_begin():
+ """things to set up later, once we know coverage is running."""
+ # Lazy setup of other options (post coverage)
+ for fn in post_configure:
+ fn(options, file_config)
+
+ # late imports, has to happen after config.
+ global util, fixtures, engines, exclusions, assertions, provision
+ global warnings, profiling, config, testing
+ from sqlalchemy import testing # noqa
+ from sqlalchemy.testing import fixtures, engines, exclusions # noqa
+ from sqlalchemy.testing import assertions, warnings, profiling # noqa
+ from sqlalchemy.testing import config, provision # noqa
+ from sqlalchemy import util # noqa
+
+ warnings.setup_filters()
+
+
+def _log(opt_str, value, parser):
+ global logging
+ if not logging:
+ import logging
+
+ logging.basicConfig()
+
+ if opt_str.endswith("-info"):
+ logging.getLogger(value).setLevel(logging.INFO)
+ elif opt_str.endswith("-debug"):
+ logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+ print("Available --db options (use --dburi to override)")
+ for macro in sorted(file_config.options("db")):
+ print("%20s\t%s" % (macro, file_config.get("db", macro)))
+ sys.exit(0)
+
+
+def _requirements_opt(opt_str, value, parser):
+ _setup_requirements(value)
+
+
+def _exclude_tag(opt_str, value, parser):
+ exclude_tags.add(value.replace("-", "_"))
+
+
+def _include_tag(opt_str, value, parser):
+ include_tags.add(value.replace("-", "_"))
+
+
+pre_configure = []
+post_configure = []
+
+
+def pre(fn):
+ pre_configure.append(fn)
+ return fn
+
+
+def post(fn):
+ post_configure.append(fn)
+ return fn
+
+
+@pre
+def _setup_options(opt, file_config):
+ global options
+ options = opt
+
+
+@pre
+def _set_nomemory(opt, file_config):
+ if opt.nomemory:
+ exclude_tags.add("memory_intensive")
+
+
+@pre
+def _set_notimingintensive(opt, file_config):
+ if opt.notimingintensive:
+ exclude_tags.add("timing_intensive")
+
+
+@pre
+def _monkeypatch_cdecimal(options, file_config):
+ if options.cdecimal:
+ import cdecimal
+
+ sys.modules["decimal"] = cdecimal
+
+
+@post
+def _init_symbols(options, file_config):
+ from sqlalchemy.testing import config
+
+ config._fixture_functions = _fixture_fn_class()
+
+
+@post
+def _set_disable_asyncio(opt, file_config):
+ if opt.disable_asyncio or not py3k:
+ from sqlalchemy.testing import asyncio
+
+ asyncio.ENABLE_ASYNCIO = False
+
+
+@post
+def _engine_uri(options, file_config):
+
+ from sqlalchemy import testing
+ from sqlalchemy.testing import config
+ from sqlalchemy.testing import provision
+
+ if options.dburi:
+ db_urls = list(options.dburi)
+ else:
+ db_urls = []
+
+ extra_drivers = options.dbdriver or []
+
+ if options.db:
+ for db_token in options.db:
+ for db in re.split(r"[,\s]+", db_token):
+ if db not in file_config.options("db"):
+ raise RuntimeError(
+ "Unknown URI specifier '%s'. "
+ "Specify --dbs for known uris." % db
+ )
+ else:
+ db_urls.append(file_config.get("db", db))
+
+ if not db_urls:
+ db_urls.append(file_config.get("db", "default"))
+
+ config._current = None
+
+ expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
+
+ for db_url in expanded_urls:
+ log.info("Adding database URL: %s", db_url)
+
+ if options.write_idents and provision.FOLLOWER_IDENT:
+ with open(options.write_idents, "a") as file_:
+ file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
+
+ cfg = provision.setup_config(
+ db_url, options, file_config, provision.FOLLOWER_IDENT
+ )
+ if not config._current:
+ cfg.set_as_current(cfg, testing)
+
+
+@post
+def _requirements(options, file_config):
+
+ requirement_cls = file_config.get("sqla_testing", "requirement_cls")
+ _setup_requirements(requirement_cls)
+
+
+def _setup_requirements(argument):
+ from sqlalchemy.testing import config
+ from sqlalchemy import testing
+
+ if config.requirements is not None:
+ return
+
+ modname, clsname = argument.split(":")
+
+ # importlib.import_module() only introduced in 2.7, a little
+ # late
+ mod = __import__(modname)
+ for component in modname.split(".")[1:]:
+ mod = getattr(mod, component)
+ req_cls = getattr(mod, clsname)
+
+ config.requirements = testing.requires = req_cls()
+
+ config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
+
+
+@post
+def _prep_testing_database(options, file_config):
+ from sqlalchemy.testing import config
+
+ if options.dropfirst:
+ from sqlalchemy.testing import provision
+
+ for cfg in config.Config.all_configs():
+ provision.drop_all_schema_objects(cfg, cfg.db)
+
+
+@post
+def _reverse_topological(options, file_config):
+ if options.reversetop:
+ from sqlalchemy.orm.util import randomize_unitofwork
+
+ randomize_unitofwork()
+
+
+@post
+def _post_setup_options(opt, file_config):
+ from sqlalchemy.testing import config
+
+ config.options = options
+ config.file_config = file_config
+
+
+@post
+def _setup_profiling(options, file_config):
+ from sqlalchemy.testing import profiling
+
+ profiling._profile_stats = profiling.ProfileStatsFile(
+ file_config.get("sqla_testing", "profile_file"),
+ sort=options.profilesort,
+ dump=options.profiledump,
+ )
+
+
+def want_class(name, cls):
+ if not issubclass(cls, fixtures.TestBase):
+ return False
+ elif name.startswith("_"):
+ return False
+ elif (
+ config.options.backend_only
+ and not getattr(cls, "__backend__", False)
+ and not getattr(cls, "__sparse_backend__", False)
+ and not getattr(cls, "__only_on__", False)
+ ):
+ return False
+ else:
+ return True
+
+
+def want_method(cls, fn):
+ if not fn.__name__.startswith("test_"):
+ return False
+ elif fn.__module__ is None:
+ return False
+ elif include_tags:
+ return (
+ hasattr(cls, "__tags__")
+ and exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
+ ) or (
+ hasattr(fn, "_sa_exclusion_extend")
+ and fn._sa_exclusion_extend.include_test(
+ include_tags, exclude_tags
+ )
+ )
+ elif exclude_tags and hasattr(cls, "__tags__"):
+ return exclusions.tags(cls.__tags__).include_test(
+ include_tags, exclude_tags
+ )
+ elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
+ return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
+ else:
+ return True
+
+
+def generate_sub_tests(cls, module):
+ if getattr(cls, "__backend__", False) or getattr(
+ cls, "__sparse_backend__", False
+ ):
+ sparse = getattr(cls, "__sparse_backend__", False)
+ for cfg in _possible_configs_for_cls(cls, sparse=sparse):
+ orig_name = cls.__name__
+
+ # we can have special chars in these names except for the
+ # pytest junit plugin, which is tripped up by the brackets
+ # and periods, so sanitize
+
+ alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
+ alpha_name = re.sub(r"_+$", "", alpha_name)
+ name = "%s_%s" % (cls.__name__, alpha_name)
+ subcls = type(
+ name,
+ (cls,),
+ {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
+ )
+ setattr(module, name, subcls)
+ yield subcls
+ else:
+ yield cls
+
+
+def start_test_class_outside_fixtures(cls):
+ _do_skips(cls)
+ _setup_engine(cls)
+
+
+def stop_test_class(cls):
+ # close sessions, immediate connections, etc.
+ fixtures.stop_test_class_inside_fixtures(cls)
+
+ # close outstanding connection pool connections, dispose of
+ # additional engines
+ engines.testing_reaper.stop_test_class_inside_fixtures()
+
+
+def stop_test_class_outside_fixtures(cls):
+ engines.testing_reaper.stop_test_class_outside_fixtures()
+ provision.stop_test_class_outside_fixtures(config, config.db, cls)
+ try:
+ if not options.low_connections:
+ assertions.global_cleanup_assertions()
+ finally:
+ _restore_engine()
+
+
+def _restore_engine():
+ if config._current:
+ config._current.reset(testing)
+
+
+def final_process_cleanup():
+ engines.testing_reaper.final_cleanup()
+ assertions.global_cleanup_assertions()
+ _restore_engine()
+
+
+def _setup_engine(cls):
+ if getattr(cls, "__engine_options__", None):
+ opts = dict(cls.__engine_options__)
+ opts["scope"] = "class"
+ eng = engines.testing_engine(options=opts)
+ config._current.push_engine(eng, testing)
+
+
+def before_test(test, test_module_name, test_class, test_name):
+
+ # format looks like:
+ # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
+
+ name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
+
+ id_ = "%s.%s.%s" % (test_module_name, name, test_name)
+
+ profiling._start_current_test(id_)
+
+
+def after_test(test):
+ fixtures.after_test()
+ engines.testing_reaper.after_test()
+
+
+def after_test_fixtures(test):
+ engines.testing_reaper.after_test_outside_fixtures(test)
+
+
+def _possible_configs_for_cls(cls, reasons=None, sparse=False):
+ all_configs = set(config.Config.all_configs())
+
+ if cls.__unsupported_on__:
+ spec = exclusions.db_spec(*cls.__unsupported_on__)
+ for config_obj in list(all_configs):
+ if spec(config_obj):
+ all_configs.remove(config_obj)
+
+ if getattr(cls, "__only_on__", None):
+ spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
+ for config_obj in list(all_configs):
+ if not spec(config_obj):
+ all_configs.remove(config_obj)
+
+ if getattr(cls, "__only_on_config__", None):
+ all_configs.intersection_update([cls.__only_on_config__])
+
+ if hasattr(cls, "__requires__"):
+ requirements = config.requirements
+ for config_obj in list(all_configs):
+ for requirement in cls.__requires__:
+ check = getattr(requirements, requirement)
+
+ skip_reasons = check.matching_config_reasons(config_obj)
+ if skip_reasons:
+ all_configs.remove(config_obj)
+ if reasons is not None:
+ reasons.extend(skip_reasons)
+ break
+
+ if hasattr(cls, "__prefer_requires__"):
+ non_preferred = set()
+ requirements = config.requirements
+ for config_obj in list(all_configs):
+ for requirement in cls.__prefer_requires__:
+ check = getattr(requirements, requirement)
+
+ if not check.enabled_for_config(config_obj):
+ non_preferred.add(config_obj)
+ if all_configs.difference(non_preferred):
+ all_configs.difference_update(non_preferred)
+
+ if sparse:
+ # pick only one config from each base dialect
+ # sorted so we get the same backend each time selecting the highest
+ # server version info.
+ per_dialect = {}
+ for cfg in reversed(
+ sorted(
+ all_configs,
+ key=lambda cfg: (
+ cfg.db.name,
+ cfg.db.driver,
+ cfg.db.dialect.server_version_info,
+ ),
+ )
+ ):
+ db = cfg.db.name
+ if db not in per_dialect:
+ per_dialect[db] = cfg
+ return per_dialect.values()
+
+ return all_configs
+
+
+def _do_skips(cls):
+ reasons = []
+ all_configs = _possible_configs_for_cls(cls, reasons)
+
+ if getattr(cls, "__skip_if__", False):
+ for c in getattr(cls, "__skip_if__"):
+ if c():
+ config.skip_test(
+ "'%s' skipped by %s" % (cls.__name__, c.__name__)
+ )
+
+ if not all_configs:
+ msg = "'%s' unsupported on any DB implementation %s%s" % (
+ cls.__name__,
+ ", ".join(
+ "'%s(%s)+%s'"
+ % (
+ config_obj.db.name,
+ ".".join(
+ str(dig)
+ for dig in exclusions._server_version(config_obj.db)
+ ),
+ config_obj.db.driver,
+ )
+ for config_obj in config.Config.all_configs()
+ ),
+ ", ".join(reasons),
+ )
+ config.skip_test(msg)
+ elif hasattr(cls, "__prefer_backends__"):
+ non_preferred = set()
+ spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
+ for config_obj in all_configs:
+ if not spec(config_obj):
+ non_preferred.add(config_obj)
+ if all_configs.difference(non_preferred):
+ all_configs.difference_update(non_preferred)
+
+ if config._current not in all_configs:
+ _setup_config(all_configs.pop(), cls)
+
+
+def _setup_config(config_obj, ctx):
+ config._current.push(config_obj, testing)
+
+
+class FixtureFunctions(ABC):
+ @abc.abstractmethod
+ def skip_test_exception(self, *arg, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def combinations(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def param_ident(self, *args, **kw):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def fixture(self, *arg, **kw):
+ raise NotImplementedError()
+
+ def get_current_test_name(self):
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def mark_base_test_class(self):
+ raise NotImplementedError()
+
+
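+# _fixture_fn_class is set at runtime to PytestFixtureFunctions (defined in
+# pytestplugin.py later in this patch) via set_fixture_functions() below.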
+_fixture_fn_class = None
+
+
+def set_fixture_functions(fixture_fn_class):
+ global _fixture_fn_class
+ _fixture_fn_class = fixture_fn_class
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
new file mode 100644
index 0000000..5a51582
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -0,0 +1,820 @@
+try:
+ # installed by bootstrap.py
+ import sqla_plugin_base as plugin_base
+except ImportError:
+ # assume we're a package, use traditional import
+ from . import plugin_base
+
+import argparse
+import collections
+from functools import update_wrapper
+import inspect
+import itertools
+import operator
+import os
+import re
+import sys
+import uuid
+
+import pytest
+
+
+py2k = sys.version_info < (3, 0)
+if py2k:
+ try:
+ import sqla_reinvent_fixtures as reinvent_fixtures_py2k
+ except ImportError:
+ from . import reinvent_fixtures_py2k
+
+
+def pytest_addoption(parser):
+ group = parser.getgroup("sqlalchemy")
+
+ def make_option(name, **kw):
+ callback_ = kw.pop("callback", None)
+ if callback_:
+
+ class CallableAction(argparse.Action):
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
+ callback_(option_string, values, parser)
+
+ kw["action"] = CallableAction
+
+ zeroarg_callback = kw.pop("zeroarg_callback", None)
+ if zeroarg_callback:
+
+ class CallableAction(argparse.Action):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ default=False,
+ required=False,
+ help=None, # noqa
+ ):
+ super(CallableAction, self).__init__(
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ const=True,
+ default=default,
+ required=required,
+ help=help,
+ )
+
+ def __call__(
+ self, parser, namespace, values, option_string=None
+ ):
+ zeroarg_callback(option_string, values, parser)
+
+ kw["action"] = CallableAction
+
+ group.addoption(name, **kw)
+
+ plugin_base.setup_options(make_option)
+ plugin_base.read_config()
+
+
+def pytest_configure(config):
+ if config.pluginmanager.hasplugin("xdist"):
+ config.pluginmanager.register(XDistHooks())
+
+ if hasattr(config, "workerinput"):
+ plugin_base.restore_important_follower_config(config.workerinput)
+ plugin_base.configure_follower(config.workerinput["follower_ident"])
+ else:
+ if config.option.write_idents and os.path.exists(
+ config.option.write_idents
+ ):
+ os.remove(config.option.write_idents)
+
+ plugin_base.pre_begin(config.option)
+
+ plugin_base.set_coverage_flag(
+ bool(getattr(config.option, "cov_source", False))
+ )
+
+ plugin_base.set_fixture_functions(PytestFixtureFunctions)
+
+ if config.option.dump_pyannotate:
+ global DUMP_PYANNOTATE
+ DUMP_PYANNOTATE = True
+
+
+DUMP_PYANNOTATE = False
+
+
+@pytest.fixture(autouse=True)
+def collect_types_fixture():
+ if DUMP_PYANNOTATE:
+ from pyannotate_runtime import collect_types
+
+ collect_types.start()
+ yield
+ if DUMP_PYANNOTATE:
+ collect_types.stop()
+
+
+def pytest_sessionstart(session):
+ from sqlalchemy.testing import asyncio
+
+ asyncio._assume_async(plugin_base.post_begin)
+
+
+def pytest_sessionfinish(session):
+ from sqlalchemy.testing import asyncio
+
+ asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
+
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ collect_types.dump_stats(session.config.option.dump_pyannotate)
+
+
+def pytest_collection_finish(session):
+ if session.config.option.dump_pyannotate:
+ from pyannotate_runtime import collect_types
+
+ lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
+
+ def _filter(filename):
+ filename = os.path.normpath(os.path.abspath(filename))
+ if "lib/sqlalchemy" not in os.path.commonpath(
+ [filename, lib_sqlalchemy]
+ ):
+ return None
+ if "testing" in filename:
+ return None
+
+ return filename
+
+ collect_types.init_types_collection(filter_filename=_filter)
+
+
+class XDistHooks(object):
+ def pytest_configure_node(self, node):
+ from sqlalchemy.testing import provision
+ from sqlalchemy.testing import asyncio
+
+ # the master for each node fills workerinput dictionary
+ # which pytest-xdist will transfer to the subprocess
+
+ plugin_base.memoize_important_follower_config(node.workerinput)
+
+ node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
+
+ asyncio._maybe_async_provisioning(
+ provision.create_follower_db, node.workerinput["follower_ident"]
+ )
+
+ def pytest_testnodedown(self, node, error):
+ from sqlalchemy.testing import provision
+ from sqlalchemy.testing import asyncio
+
+ asyncio._maybe_async_provisioning(
+ provision.drop_follower_db, node.workerinput["follower_ident"]
+ )
+
+
+def pytest_collection_modifyitems(session, config, items):
+
+ # look for all those classes that specify __backend__ and
+ # expand them out into per-database test cases.
+
+ # this is much easier to do within pytest_pycollect_makeitem, however
+    # pytest is iterating through cls.__dict__ as makeitem is
+    # called, which causes a "dictionary changed size" error on py3k.
+    # I'd submit a pullreq for them to turn it into a list first, but it
+    # would only suit the rather odd use case here, which is that we are
+    # adding new classes to a module on the fly.
+
+ from sqlalchemy.testing import asyncio
+
+ rebuilt_items = collections.defaultdict(
+ lambda: collections.defaultdict(list)
+ )
+
+ items[:] = [
+ item
+ for item in items
+ if item.getparent(pytest.Class) is not None
+ and not item.getparent(pytest.Class).name.startswith("_")
+ ]
+
+ test_classes = set(item.getparent(pytest.Class) for item in items)
+
+ def collect(element):
+ for inst_or_fn in element.collect():
+ if isinstance(inst_or_fn, pytest.Collector):
+ # no yield from in 2.7
+ for el in collect(inst_or_fn):
+ yield el
+ else:
+ yield inst_or_fn
+
+ def setup_test_classes():
+ for test_class in test_classes:
+ for sub_cls in plugin_base.generate_sub_tests(
+ test_class.cls, test_class.module
+ ):
+ if sub_cls is not test_class.cls:
+ per_cls_dict = rebuilt_items[test_class.cls]
+
+ # support pytest 5.4.0 and above pytest.Class.from_parent
+ ctor = getattr(pytest.Class, "from_parent", pytest.Class)
+ module = test_class.getparent(pytest.Module)
+ for fn in collect(
+ ctor(name=sub_cls.__name__, parent=module)
+ ):
+ per_cls_dict[fn.name].append(fn)
+
+ # class requirements will sometimes need to access the DB to check
+ # capabilities, so need to do this for async
+ asyncio._maybe_async_provisioning(setup_test_classes)
+
+ newitems = []
+ for item in items:
+ cls_ = item.cls
+ if cls_ in rebuilt_items:
+ newitems.extend(rebuilt_items[cls_][item.name])
+ else:
+ newitems.append(item)
+
+ if py2k:
+ for item in newitems:
+ reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)
+
+ # seems like the functions attached to a test class aren't sorted already?
+ # is that true and why's that? (when using unittest, they're sorted)
+ items[:] = sorted(
+ newitems,
+ key=lambda item: (
+ item.getparent(pytest.Module).name,
+ item.getparent(pytest.Class).name,
+ item.name,
+ ),
+ )
+
+
+def pytest_pycollect_makeitem(collector, name, obj):
+ if inspect.isclass(obj) and plugin_base.want_class(name, obj):
+ from sqlalchemy.testing import config
+
+ if config.any_async:
+ obj = _apply_maybe_async(obj)
+
+ ctor = getattr(pytest.Class, "from_parent", pytest.Class)
+ return [
+ ctor(name=parametrize_cls.__name__, parent=collector)
+ for parametrize_cls in _parametrize_cls(collector.module, obj)
+ ]
+ elif (
+ inspect.isfunction(obj)
+ and collector.cls is not None
+ and plugin_base.want_method(collector.cls, obj)
+ ):
+ # None means, fall back to default logic, which includes
+ # method-level parametrize
+ return None
+ else:
+ # empty list means skip this item
+ return []
+
+
+def _is_wrapped_coroutine_function(fn):
+ while hasattr(fn, "__wrapped__"):
+ fn = fn.__wrapped__
+
+ return inspect.iscoroutinefunction(fn)
+
+
+def _apply_maybe_async(obj, recurse=True):
+ from sqlalchemy.testing import asyncio
+
+ for name, value in vars(obj).items():
+ if (
+ (callable(value) or isinstance(value, classmethod))
+ and not getattr(value, "_maybe_async_applied", False)
+ and (name.startswith("test_"))
+ and not _is_wrapped_coroutine_function(value)
+ ):
+ is_classmethod = False
+ if isinstance(value, classmethod):
+ value = value.__func__
+ is_classmethod = True
+
+ @_pytest_fn_decorator
+ def make_async(fn, *args, **kwargs):
+ return asyncio._maybe_async(fn, *args, **kwargs)
+
+ do_async = make_async(value)
+ if is_classmethod:
+ do_async = classmethod(do_async)
+ do_async._maybe_async_applied = True
+
+ setattr(obj, name, do_async)
+ if recurse:
+ for cls in obj.mro()[1:]:
+ if cls != object:
+ _apply_maybe_async(cls, False)
+ return obj
+
+
+def _parametrize_cls(module, cls):
+ """implement a class-based version of pytest parametrize."""
+
+ if "_sa_parametrize" not in cls.__dict__:
+ return [cls]
+
+ _sa_parametrize = cls._sa_parametrize
+ classes = []
+ for full_param_set in itertools.product(
+ *[params for argname, params in _sa_parametrize]
+ ):
+ cls_variables = {}
+
+ for argname, param in zip(
+ [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
+ ):
+ if not argname:
+ raise TypeError("need argnames for class-based combinations")
+ argname_split = re.split(r",\s*", argname)
+ for arg, val in zip(argname_split, param.values):
+ cls_variables[arg] = val
+ parametrized_name = "_".join(
+ # token is a string, but in py2k pytest is giving us a unicode,
+ # so call str() on it.
+ str(re.sub(r"\W", "", token))
+ for param in full_param_set
+ for token in param.id.split("-")
+ )
+ name = "%s_%s" % (cls.__name__, parametrized_name)
+ newcls = type.__new__(type, name, (cls,), cls_variables)
+ setattr(module, name, newcls)
+ classes.append(newcls)
+ return classes
+
+
+_current_class = None
+
+
+def pytest_runtest_setup(item):
+ from sqlalchemy.testing import asyncio
+
+ # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
+ # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
+ # for the whole class and has to run things that are across all current
+ # databases, so we run this outside of the pytest fixture system altogether
+ # and ensure asyncio greenlet if any engines are async
+
+ global _current_class
+
+ if isinstance(item, pytest.Function) and _current_class is None:
+ asyncio._maybe_async_provisioning(
+ plugin_base.start_test_class_outside_fixtures,
+ item.cls,
+ )
+ _current_class = item.getparent(pytest.Class)
+
+
+@pytest.hookimpl(hookwrapper=True)
+def pytest_runtest_teardown(item, nextitem):
+ # runs inside of pytest function fixture scope
+ # after test function runs
+ from sqlalchemy.testing import asyncio
+ from sqlalchemy.util import string_types
+
+ asyncio._maybe_async(plugin_base.after_test, item)
+
+ yield
+ # this is now after all the fixture teardown have run, the class can be
+ # finalized. Since pytest v7 this finalizer can no longer be added in
+ # pytest_runtest_setup since the class has not yet been setup at that
+ # time.
+ # See https://github.com/pytest-dev/pytest/issues/9343
+ global _current_class, _current_report
+
+ if _current_class is not None and (
+ # last test or a new class
+ nextitem is None
+ or nextitem.getparent(pytest.Class) is not _current_class
+ ):
+ _current_class = None
+
+ try:
+ asyncio._maybe_async_provisioning(
+ plugin_base.stop_test_class_outside_fixtures, item.cls
+ )
+ except Exception as e:
+ # in case of an exception during teardown attach the original
+ # error to the exception message, otherwise it will get lost
+ if _current_report.failed:
+ if not e.args:
+ e.args = (
+ "__Original test failure__:\n"
+ + _current_report.longreprtext,
+ )
+ elif e.args[-1] and isinstance(e.args[-1], string_types):
+ args = list(e.args)
+ args[-1] += (
+ "\n__Original test failure__:\n"
+ + _current_report.longreprtext
+ )
+ e.args = tuple(args)
+ else:
+ e.args += (
+ "__Original test failure__",
+ _current_report.longreprtext,
+ )
+ raise
+ finally:
+ _current_report = None
+
+
+def pytest_runtest_call(item):
+ # runs inside of pytest function fixture scope
+ # before test function runs
+
+ from sqlalchemy.testing import asyncio
+
+ asyncio._maybe_async(
+ plugin_base.before_test,
+ item,
+ item.module.__name__,
+ item.cls,
+ item.name,
+ )
+
+
+_current_report = None
+
+
+def pytest_runtest_logreport(report):
+ global _current_report
+ if report.when == "call":
+ _current_report = report
+
+
+@pytest.fixture(scope="class")
+def setup_class_methods(request):
+ from sqlalchemy.testing import asyncio
+
+ cls = request.cls
+
+ if hasattr(cls, "setup_test_class"):
+ asyncio._maybe_async(cls.setup_test_class)
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_setup(request)
+
+ yield
+
+ if py2k:
+ reinvent_fixtures_py2k.run_class_fixture_teardown(request)
+
+ if hasattr(cls, "teardown_test_class"):
+ asyncio._maybe_async(cls.teardown_test_class)
+
+ asyncio._maybe_async(plugin_base.stop_test_class, cls)
+
+
+@pytest.fixture(scope="function")
+def setup_test_methods(request):
+ from sqlalchemy.testing import asyncio
+
+ # called for each test
+
+ self = request.instance
+
+ # before this fixture runs:
+
+ # 1. function level "autouse" fixtures under py3k (examples: TablesTest
+ # define tables / data, MappedTest define tables / mappers / data)
+
+ # 2. run homegrown function level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_setup(request)
+
+ # 3. run outer xdist-style setup
+ if hasattr(self, "setup_test"):
+ asyncio._maybe_async(self.setup_test)
+
+ # alembic test suite is using setUp and tearDown
+ # xdist methods; support these in the test suite
+ # for the near term
+ if hasattr(self, "setUp"):
+ asyncio._maybe_async(self.setUp)
+
+ # inside the yield:
+ # 4. function level fixtures defined on test functions themselves,
+ # e.g. "connection", "metadata" run next
+
+ # 5. pytest hook pytest_runtest_call then runs
+
+ # 6. test itself runs
+
+ yield
+
+ # yield finishes:
+
+ # 7. function level fixtures defined on test functions
+ # themselves, e.g. "connection" rolls back the transaction, "metadata"
+ # emits drop all
+
+ # 8. pytest hook pytest_runtest_teardown hook runs, this is associated
+ # with fixtures close all sessions, provisioning.stop_test_class(),
+ # engines.testing_reaper -> ensure all connection pool connections
+ # are returned, engines created by testing_engine that aren't the
+ # config engine are disposed
+
+ asyncio._maybe_async(plugin_base.after_test_fixtures, self)
+
+ # 10. run xdist-style teardown
+ if hasattr(self, "tearDown"):
+ asyncio._maybe_async(self.tearDown)
+
+ if hasattr(self, "teardown_test"):
+ asyncio._maybe_async(self.teardown_test)
+
+ # 11. run homegrown function-level "autouse" fixtures under py2k
+ if py2k:
+ reinvent_fixtures_py2k.run_fn_fixture_teardown(request)
+
+ # 12. function level "autouse" fixtures under py3k (examples: TablesTest /
+ # MappedTest delete table data, possibly drop tables and clear mappers
+ # depending on the flags defined by the test class)
+
+
+def getargspec(fn):
+ if sys.version_info.major == 3:
+ return inspect.getfullargspec(fn)
+ else:
+ return inspect.getargspec(fn)
+
+
+def _pytest_fn_decorator(target):
+ """Port of langhelpers.decorator with pytest-specific tricks."""
+
+ from sqlalchemy.util.langhelpers import format_argspec_plus
+ from sqlalchemy.util.compat import inspect_getfullargspec
+
+ def _exec_code_in_env(code, env, fn_name):
+ exec(code, env)
+ return env[fn_name]
+
+ def decorate(fn, add_positional_parameters=()):
+
+ spec = inspect_getfullargspec(fn)
+ if add_positional_parameters:
+ spec.args.extend(add_positional_parameters)
+
+ metadata = dict(
+ __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
+ )
+ metadata.update(format_argspec_plus(spec, grouped=False))
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
+"""
+ % metadata
+ )
+ decorated = _exec_code_in_env(
+ code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
+ )
+ if not add_positional_parameters:
+ decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ decorated.__wrapped__ = fn
+ return update_wrapper(decorated, fn)
+ else:
+ # this is the pytest hacky part. don't do a full update wrapper
+ # because pytest is really being sneaky about finding the args
+ # for the wrapped function
+ decorated.__module__ = fn.__module__
+ decorated.__name__ = fn.__name__
+ if hasattr(fn, "pytestmark"):
+ decorated.pytestmark = fn.pytestmark
+ return decorated
+
+ return decorate
+
+
+class PytestFixtureFunctions(plugin_base.FixtureFunctions):
+ def skip_test_exception(self, *arg, **kw):
+ return pytest.skip.Exception(*arg, **kw)
+
+ def mark_base_test_class(self):
+ return pytest.mark.usefixtures(
+ "setup_class_methods", "setup_test_methods"
+ )
+
+ _combination_id_fns = {
+ "i": lambda obj: obj,
+ "r": repr,
+ "s": str,
+ "n": lambda obj: obj.__name__
+ if hasattr(obj, "__name__")
+ else type(obj).__name__,
+ }
+
+ def combinations(self, *arg_sets, **kw):
+ """Facade for pytest.mark.parametrize.
+
+ Automatically derives argument names from the callable which in our
+ case is always a method on a class with positional arguments.
+
+ ids for parameter sets are derived using an optional template.
+
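+        As an illustrative note (not part of the original docstring): per the
+        _combination_id_fns handling below, id_="ia" uses the first member of
+        each argument tuple as the literal id token, while the remaining "a"
+        (argument) members are passed to the test but do not contribute to
+        the id.
+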
+ """
+ from sqlalchemy.testing import exclusions
+
+ if sys.version_info.major == 3:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
+ arg_sets = list(arg_sets[0])
+ else:
+ if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
+ arg_sets = list(arg_sets[0])
+
+ argnames = kw.pop("argnames", None)
+
+ def _filter_exclusions(args):
+ result = []
+ gathered_exclusions = []
+ for a in args:
+ if isinstance(a, exclusions.compound):
+ gathered_exclusions.append(a)
+ else:
+ result.append(a)
+
+ return result, gathered_exclusions
+
+ id_ = kw.pop("id_", None)
+
+ tobuild_pytest_params = []
+ has_exclusions = False
+ if id_:
+ _combination_id_fns = self._combination_id_fns
+
+ # because itemgetter is not consistent for one argument vs.
+ # multiple, make it multiple in all cases and use a slice
+ # to omit the first argument
+ _arg_getter = operator.itemgetter(
+ 0,
+ *[
+ idx
+ for idx, char in enumerate(id_)
+ if char in ("n", "r", "s", "a")
+ ]
+ )
+ fns = [
+ (operator.itemgetter(idx), _combination_id_fns[char])
+ for idx, char in enumerate(id_)
+ if char in _combination_id_fns
+ ]
+
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ parameters = _arg_getter(fn_params)[1:]
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (
+ parameters,
+ param_exclusions,
+ "-".join(
+ comb_fn(getter(arg)) for getter, comb_fn in fns
+ ),
+ )
+ )
+
+ else:
+
+ for arg in arg_sets:
+ if not isinstance(arg, tuple):
+ arg = (arg,)
+
+ fn_params, param_exclusions = _filter_exclusions(arg)
+
+ if param_exclusions:
+ has_exclusions = True
+
+ tobuild_pytest_params.append(
+ (fn_params, param_exclusions, None)
+ )
+
+ pytest_params = []
+ for parameters, param_exclusions, id_ in tobuild_pytest_params:
+ if has_exclusions:
+ parameters += (param_exclusions,)
+
+ param = pytest.param(*parameters, id=id_)
+ pytest_params.append(param)
+
+ def decorate(fn):
+ if inspect.isclass(fn):
+ if has_exclusions:
+ raise NotImplementedError(
+ "exclusions not supported for class level combinations"
+ )
+ if "_sa_parametrize" not in fn.__dict__:
+ fn._sa_parametrize = []
+ fn._sa_parametrize.append((argnames, pytest_params))
+ return fn
+ else:
+ if argnames is None:
+ _argnames = getargspec(fn).args[1:]
+ else:
+ _argnames = re.split(r", *", argnames)
+
+ if has_exclusions:
+ _argnames += ["_exclusions"]
+
+ @_pytest_fn_decorator
+ def check_exclusions(fn, *args, **kw):
+ _exclusions = args[-1]
+ if _exclusions:
+ exlu = exclusions.compound().add(*_exclusions)
+ fn = exlu(fn)
+ return fn(*args[0:-1], **kw)
+
+ def process_metadata(spec):
+ spec.args.append("_exclusions")
+
+ fn = check_exclusions(
+ fn, add_positional_parameters=("_exclusions",)
+ )
+
+ return pytest.mark.parametrize(_argnames, pytest_params)(fn)
+
+ return decorate
+
+ def param_ident(self, *parameters):
+ ident = parameters[0]
+ return pytest.param(*parameters[1:], id=ident)
+
+ def fixture(self, *arg, **kw):
+ from sqlalchemy.testing import config
+ from sqlalchemy.testing import asyncio
+
+ # wrapping pytest.fixture function. determine if
+ # decorator was called as @fixture or @fixture().
+ if len(arg) > 0 and callable(arg[0]):
+            # was called as plain @fixture; the function to wrap is passed directly.
+ fn = arg[0]
+ arg = arg[1:]
+ else:
+            # was called as @fixture(...); we don't have the function yet.
+ fn = None
+
+ # create a pytest.fixture marker. because the fn is not being
+ # passed, this is always a pytest.FixtureFunctionMarker()
+ # object (or whatever pytest is calling it when you read this)
+ # that is waiting for a function.
+ fixture = pytest.fixture(*arg, **kw)
+
+ # now apply wrappers to the function, including fixture itself
+
+ def wrap(fn):
+ if config.any_async:
+ fn = asyncio._maybe_async_wrapper(fn)
+ # other wrappers may be added here
+
+ if py2k and "autouse" in kw:
+ # py2k workaround for too-slow collection of autouse fixtures
+ # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for
+ # rationale.
+
+ # comment this condition out in order to disable the
+ # py2k workaround entirely.
+ reinvent_fixtures_py2k.add_fixture(fn, fixture)
+ else:
+ # now apply FixtureFunctionMarker
+ fn = fixture(fn)
+
+ return fn
+
+ if fn:
+ return wrap(fn)
+ else:
+ return wrap
+
+ def get_current_test_name(self):
+ return os.environ.get("PYTEST_CURRENT_TEST")
+
+ def async_test(self, fn):
+ from sqlalchemy.testing import asyncio
+
+ @_pytest_fn_decorator
+ def decorate(fn, *args, **kwargs):
+ asyncio._run_coroutine_function(fn, *args, **kwargs)
+
+ return decorate(fn)
diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
new file mode 100644
index 0000000..36b6841
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
@@ -0,0 +1,112 @@
+"""
+invent a quick version of pytest autouse fixtures to work around pytest's
+unacceptably slow collection / high memory use in pytest 4.6.11, which is the
+highest version that works in py2k.
+
+by "too-slow" we mean the test suite can't even manage to be collected for a
+single process in less than 70 seconds or so and memory use seems to be very
+high as well. for two or four workers the job just times out after ten
+minutes.
+
+so instead we have invented a very limited form of these fixtures, as our
+current use of "autouse" fixtures is limited to those in fixtures.py.
+
+assumptions for these fixtures:
+
+1. we are only using "function" or "class" scope
+
+2. the functions must be associated with a test class
+
+3. the fixture functions cannot themselves use pytest fixtures
+
+4. the fixture functions must use yield, not return
+
+When py2k support is removed and we can stay on a modern pytest version, this
+can all be removed.
+
+
+"""
+import collections
+
+
+_py2k_fixture_fn_names = collections.defaultdict(set)
+_py2k_class_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+_py2k_function_fixtures = collections.defaultdict(
+ lambda: collections.defaultdict(set)
+)
+
+_py2k_cls_fixture_stack = []
+_py2k_fn_fixture_stack = []
+
+
+def add_fixture(fn, fixture):
+ assert fixture.scope in ("class", "function")
+ _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
+
+
+def scan_for_fixtures_to_use_for_class(item):
+ test_class = item.parent.parent.obj
+
+ for name in _py2k_fixture_fn_names:
+ for fixture_fn, scope in _py2k_fixture_fn_names[name]:
+ meth = getattr(test_class, name, None)
+ if meth and meth.im_func is fixture_fn:
+ for sup in test_class.__mro__:
+ if name in sup.__dict__:
+ if scope == "class":
+ _py2k_class_fixtures[test_class][sup].add(meth)
+ elif scope == "function":
+ _py2k_function_fixtures[test_class][sup].add(meth)
+ break
+ break
+
+
+def run_class_fixture_setup(request):
+
+ cls = request.cls
+ self = cls.__new__(cls)
+
+ fixtures_for_this_class = _py2k_class_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in cls.__mro__:
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_cls_fixture_stack.append(iter_)
+
+
+def run_class_fixture_teardown(request):
+ while _py2k_cls_fixture_stack:
+ iter_ = _py2k_cls_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
+
+
+def run_fn_fixture_setup(request):
+ cls = request.cls
+ self = request.instance
+
+ fixtures_for_this_class = _py2k_function_fixtures.get(cls)
+
+ if fixtures_for_this_class:
+ for sup_ in reversed(cls.__mro__):
+ for fn in fixtures_for_this_class.get(sup_, ()):
+ iter_ = fn(self)
+ next(iter_)
+
+ _py2k_fn_fixture_stack.append(iter_)
+
+
+def run_fn_fixture_teardown(request):
+ while _py2k_fn_fixture_stack:
+ iter_ = _py2k_fn_fixture_stack.pop(-1)
+ try:
+ next(iter_)
+ except StopIteration:
+ pass
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
new file mode 100644
index 0000000..4132630
--- /dev/null
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -0,0 +1,335 @@
+# testing/profiling.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Profiling support for unit and performance tests.
+
+These are special purpose profiling methods which operate
+in a more fine-grained way than nose's profiling plugin.
+
+"""
+
+import collections
+import contextlib
+import os
+import platform
+import pstats
+import re
+import sys
+
+from . import config
+from .util import gc_collect
+from ..util import has_compiled_ext
+
+
+try:
+ import cProfile
+except ImportError:
+ cProfile = None
+
+_profile_stats = None
+"""global ProfileStatsFileInstance.
+
+plugin_base assigns this at the start of all tests.
+
+"""
+
+
+_current_test = None
+"""String id of current test.
+
+plugin_base assigns this at the start of each test using
+_start_current_test.
+
+"""
+
+
+def _start_current_test(id_):
+ global _current_test
+ _current_test = id_
+
+ if _profile_stats.force_write:
+ _profile_stats.reset_count()
+
+
+class ProfileStatsFile(object):
+ """Store per-platform/fn profiling results in a file.
+
+    There was no json module available when this was written, but the file
+    format, which is very deterministically line oriented, turns out to be
+    handy in any case for diffs and merges.
+
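+    Each data line, per the _read() and _write() methods below, has the
+    form::
+
+        <test_key> <platform_key> <comma-separated callcounts>
+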
+ """
+
+ def __init__(self, filename, sort="cumulative", dump=None):
+ self.force_write = (
+ config.options is not None and config.options.force_write_profiles
+ )
+ self.write = self.force_write or (
+ config.options is not None and config.options.write_profiles
+ )
+ self.fname = os.path.abspath(filename)
+ self.short_fname = os.path.split(self.fname)[-1]
+ self.data = collections.defaultdict(
+ lambda: collections.defaultdict(dict)
+ )
+ self.dump = dump
+ self.sort = sort
+ self._read()
+ if self.write:
+ # rewrite for the case where features changed,
+ # etc.
+ self._write()
+
+ @property
+ def platform_key(self):
+
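+        # assembles a key along the lines of (an illustrative example of the
+        # token ordering built below):
+        # x86_64_linux_cpython_3.8_postgresql_psycopg2_nativeunicode_cextensions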
+ dbapi_key = config.db.name + "_" + config.db.driver
+
+ if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
+ config.db.url
+ ):
+ dbapi_key += "_file"
+
+ # keep it at 2.7, 3.1, 3.2, etc. for now.
+ py_version = ".".join([str(v) for v in sys.version_info[0:2]])
+
+ platform_tokens = [
+ platform.machine(),
+ platform.system().lower(),
+ platform.python_implementation().lower(),
+ py_version,
+ dbapi_key,
+ ]
+
+ platform_tokens.append(
+ "nativeunicode"
+ if config.db.dialect.convert_unicode
+ else "dbapiunicode"
+ )
+ _has_cext = has_compiled_ext()
+ platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
+ return "_".join(platform_tokens)
+
+ def has_stats(self):
+ test_key = _current_test
+ return (
+ test_key in self.data and self.platform_key in self.data[test_key]
+ )
+
+ def result(self, callcount):
+ test_key = _current_test
+ per_fn = self.data[test_key]
+ per_platform = per_fn[self.platform_key]
+
+ if "counts" not in per_platform:
+ per_platform["counts"] = counts = []
+ else:
+ counts = per_platform["counts"]
+
+ if "current_count" not in per_platform:
+ per_platform["current_count"] = current_count = 0
+ else:
+ current_count = per_platform["current_count"]
+
+ has_count = len(counts) > current_count
+
+ if not has_count:
+ counts.append(callcount)
+ if self.write:
+ self._write()
+ result = None
+ else:
+ result = per_platform["lineno"], counts[current_count]
+ per_platform["current_count"] += 1
+ return result
+
+ def reset_count(self):
+ test_key = _current_test
+ # since self.data is a defaultdict, don't access a key
+ # if we don't know it's there first.
+ if test_key not in self.data:
+ return
+ per_fn = self.data[test_key]
+ if self.platform_key not in per_fn:
+ return
+ per_platform = per_fn[self.platform_key]
+ if "counts" in per_platform:
+ per_platform["counts"][:] = []
+
+ def replace(self, callcount):
+ test_key = _current_test
+ per_fn = self.data[test_key]
+ per_platform = per_fn[self.platform_key]
+ counts = per_platform["counts"]
+ current_count = per_platform["current_count"]
+ if current_count < len(counts):
+ counts[current_count - 1] = callcount
+ else:
+ counts[-1] = callcount
+ if self.write:
+ self._write()
+
+ def _header(self):
+ return (
+ "# %s\n"
+ "# This file is written out on a per-environment basis.\n"
+ "# For each test in aaa_profiling, the corresponding "
+ "function and \n"
+ "# environment is located within this file. "
+ "If it doesn't exist,\n"
+ "# the test is skipped.\n"
+ "# If a callcount does exist, it is compared "
+ "to what we received. \n"
+ "# assertions are raised if the counts do not match.\n"
+ "# \n"
+ "# To add a new callcount test, apply the function_call_count \n"
+ "# decorator and re-run the tests using the --write-profiles \n"
+ "# option - this file will be rewritten including the new count.\n"
+ "# \n"
+ ) % (self.fname)
+
+ def _read(self):
+ try:
+ profile_f = open(self.fname)
+ except IOError:
+ return
+ for lineno, line in enumerate(profile_f):
+ line = line.strip()
+ if not line or line.startswith("#"):
+ continue
+
+ test_key, platform_key, counts = line.split()
+ per_fn = self.data[test_key]
+ per_platform = per_fn[platform_key]
+ c = [int(count) for count in counts.split(",")]
+ per_platform["counts"] = c
+ per_platform["lineno"] = lineno + 1
+ per_platform["current_count"] = 0
+ profile_f.close()
+
+ def _write(self):
+ print(("Writing profile file %s" % self.fname))
+ profile_f = open(self.fname, "w")
+ profile_f.write(self._header())
+ for test_key in sorted(self.data):
+
+ per_fn = self.data[test_key]
+ profile_f.write("\n# TEST: %s\n\n" % test_key)
+ for platform_key in sorted(per_fn):
+ per_platform = per_fn[platform_key]
+ c = ",".join(str(count) for count in per_platform["counts"])
+ profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
+ profile_f.close()
+
+
+def function_call_count(variance=0.05, times=1, warmup=0):
+ """Assert a target for a test case's function call count.
+
+ The main purpose of this assertion is to detect changes in
+ callcounts for various functions - the actual number is not as important.
+ Callcounts are stored in a file keyed to Python version and OS platform
+ information. This file is generated automatically for new tests,
+ and versioned so that unexpected changes in callcounts will be detected.
+
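+    A minimal usage sketch (hypothetical test method and statement, shown
+    only for illustration)::
+
+        @profiling.function_call_count(variance=0.1, times=3)
+        def test_compile_statement(self):
+            stmt.compile()
+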
+ """
+
+ # use signature-rewriting decorator function so that pytest fixtures
+ # still work on py27. In Py3, update_wrapper() alone is good enough,
+ # likely due to the introduction of __signature__.
+
+ from sqlalchemy.util import decorator
+ from sqlalchemy.util import deprecations
+ from sqlalchemy.engine import row
+ from sqlalchemy.testing import mock
+
+ @decorator
+ def wrap(fn, *args, **kw):
+
+ with mock.patch.object(
+ deprecations, "SQLALCHEMY_WARN_20", False
+ ), mock.patch.object(
+ row.LegacyRow, "_default_key_style", row.KEY_OBJECTS_NO_WARN
+ ):
+ for warm in range(warmup):
+ fn(*args, **kw)
+
+ timerange = range(times)
+ with count_functions(variance=variance):
+ for time in timerange:
+ rv = fn(*args, **kw)
+ return rv
+
+ return wrap
+
+
+@contextlib.contextmanager
+def count_functions(variance=0.05):
+ if cProfile is None:
+ raise config._skip_test_exception("cProfile is not installed")
+
+ if not _profile_stats.has_stats() and not _profile_stats.write:
+ config.skip_test(
+ "No profiling stats available on this "
+ "platform for this function. Run tests with "
+ "--write-profiles to add statistics to %s for "
+ "this platform." % _profile_stats.short_fname
+ )
+
+ gc_collect()
+
+ pr = cProfile.Profile()
+ pr.enable()
+ # began = time.time()
+ yield
+ # ended = time.time()
+ pr.disable()
+
+ # s = compat.StringIO()
+ stats = pstats.Stats(pr, stream=sys.stdout)
+
+ # timespent = ended - began
+ callcount = stats.total_calls
+
+ expected = _profile_stats.result(callcount)
+
+ if expected is None:
+ expected_count = None
+ else:
+ line_no, expected_count = expected
+
+ print(("Pstats calls: %d Expected %s" % (callcount, expected_count)))
+ stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort))
+ stats.print_stats()
+ if _profile_stats.dump:
+ base, ext = os.path.splitext(_profile_stats.dump)
+ test_name = _current_test.split(".")[-1]
+ dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile")
+ stats.dump_stats(dumpfile)
+ print("Dumped stats to file %s" % dumpfile)
+ # stats.print_callers()
+ if _profile_stats.force_write:
+ _profile_stats.replace(callcount)
+ elif expected_count:
+ deviance = int(callcount * variance)
+ failed = abs(callcount - expected_count) > deviance
+
+ if failed:
+ if _profile_stats.write:
+ _profile_stats.replace(callcount)
+ else:
+ raise AssertionError(
+ "Adjusted function call count %s not within %s%% "
+ "of expected %s, platform %s. Rerun with "
+ "--write-profiles to "
+ "regenerate this callcount."
+ % (
+ callcount,
+ (variance * 100),
+ expected_count,
+ _profile_stats.platform_key,
+ )
+ )
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
new file mode 100644
index 0000000..90c4d93
--- /dev/null
+++ b/lib/sqlalchemy/testing/provision.py
@@ -0,0 +1,416 @@
+import collections
+import logging
+
+from . import config
+from . import engines
+from . import util
+from .. import exc
+from .. import inspect
+from ..engine import url as sa_url
+from ..sql import ddl
+from ..sql import schema
+from ..util import compat
+
+
+log = logging.getLogger(__name__)
+
+FOLLOWER_IDENT = None
+
+
+class register(object):
+ def __init__(self):
+ self.fns = {}
+
+ @classmethod
+ def init(cls, fn):
+ return register().for_db("*")(fn)
+
+ def for_db(self, *dbnames):
+ def decorate(fn):
+ for dbname in dbnames:
+ self.fns[dbname] = fn
+ return self
+
+ return decorate
+
+ def __call__(self, cfg, *arg):
+ if isinstance(cfg, compat.string_types):
+ url = sa_url.make_url(cfg)
+ elif isinstance(cfg, sa_url.URL):
+ url = cfg
+ else:
+ url = cfg.db.url
+ backend = url.get_backend_name()
+ if backend in self.fns:
+ return self.fns[backend](cfg, *arg)
+ else:
+ return self.fns["*"](cfg, *arg)
+
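+# Many of the provisioning hooks below (e.g. create_db, drop_db,
+# update_db_opts) are declared with @register.init and can be specialized
+# per backend by a dialect's provisioning module, along these lines (an
+# illustrative sketch, not part of this diff):
+#
+#     @create_db.for_db("postgresql")
+#     def _pg_create_db(cfg, eng, ident):
+#         ...
+#
+# register.__call__ then dispatches on the URL's backend name, falling back
+# to the "*" implementation registered via @register.init.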
+
+def create_follower_db(follower_ident):
+ for cfg in _configs_for_db_operation():
+ log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
+ create_db(cfg, cfg.db, follower_ident)
+
+
+def setup_config(db_url, options, file_config, follower_ident):
+ # load the dialect, which should also have it set up its provision
+ # hooks
+
+ dialect = sa_url.make_url(db_url).get_dialect()
+ dialect.load_provisioning()
+
+ if follower_ident:
+ db_url = follower_url_from_main(db_url, follower_ident)
+ db_opts = {}
+ update_db_opts(db_url, db_opts)
+ db_opts["scope"] = "global"
+ eng = engines.testing_engine(db_url, db_opts)
+ post_configure_engine(db_url, eng, follower_ident)
+ eng.connect().close()
+
+ cfg = config.Config.register(eng, db_opts, options, file_config)
+
+ # a symbolic name that tests can use if they need to disambiguate
+ # names across databases
+ if follower_ident:
+ config.ident = follower_ident
+
+ if follower_ident:
+ configure_follower(cfg, follower_ident)
+ return cfg
+
+
+def drop_follower_db(follower_ident):
+ for cfg in _configs_for_db_operation():
+ log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
+ drop_db(cfg, cfg.db, follower_ident)
+
+
+def generate_db_urls(db_urls, extra_drivers):
+ """Generate a set of URLs to test given configured URLs plus additional
+ driver names.
+
+ Given::
+
+ --dburi postgresql://db1 \
+ --dburi postgresql://db2 \
+        --dburi postgresql://db3 \
+ --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true
+
+ Noting that the default postgresql driver is psycopg2, the output
+ would be::
+
+ postgresql+psycopg2://db1
+ postgresql+asyncpg://db1
+ postgresql+psycopg2://db2
+ postgresql+psycopg2://db3
+
+    That is, for the driver in a --dburi, we want to keep that and use that
+    driver for each URL it's part of. For a driver that is only
+    in --dbdrivers, we want to use it just once for one of the URLs.
+    For a driver that comes from both --dburi and --dbdrivers,
+    we want to keep it in that dburi.
+
+    Driver-specific query options can be specified by adding them to the
+    driver name. For example, to enable the async fallback option for
+    asyncpg::
+
+ --dburi postgresql://db1 \
+ --dbdriver=asyncpg?async_fallback=true
+
+ """
+ urls = set()
+
+ backend_to_driver_we_already_have = collections.defaultdict(set)
+
+ urls_plus_dialects = [
+ (url_obj, url_obj.get_dialect())
+ for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
+ ]
+
+ for url_obj, dialect in urls_plus_dialects:
+ backend_to_driver_we_already_have[dialect.name].add(dialect.driver)
+
+ backend_to_driver_we_need = {}
+
+ for url_obj, dialect in urls_plus_dialects:
+ backend = dialect.name
+ dialect.load_provisioning()
+
+ if backend not in backend_to_driver_we_need:
+ backend_to_driver_we_need[backend] = extra_per_backend = set(
+ extra_drivers
+ ).difference(backend_to_driver_we_already_have[backend])
+ else:
+ extra_per_backend = backend_to_driver_we_need[backend]
+
+ for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
+ if driver_url in urls:
+ continue
+ urls.add(driver_url)
+ yield driver_url
+
+
+def _generate_driver_urls(url, extra_drivers):
+ main_driver = url.get_driver_name()
+ extra_drivers.discard(main_driver)
+
+ url = generate_driver_url(url, main_driver, "")
+ yield str(url)
+
+ for drv in list(extra_drivers):
+
+ if "?" in drv:
+
+ driver_only, query_str = drv.split("?", 1)
+
+ else:
+ driver_only = drv
+ query_str = None
+
+ new_url = generate_driver_url(url, driver_only, query_str)
+ if new_url:
+ extra_drivers.remove(drv)
+
+ yield str(new_url)
+
+
+@register.init
+def generate_driver_url(url, driver, query_str):
+ backend = url.get_backend_name()
+
+ new_url = url.set(
+ drivername="%s+%s" % (backend, driver),
+ )
+ if query_str:
+ new_url = new_url.update_query_string(query_str)
+
+ try:
+ new_url.get_dialect()
+ except exc.NoSuchModuleError:
+ return None
+ else:
+ return new_url
+
+
+def _configs_for_db_operation():
+ hosts = set()
+
+ for cfg in config.Config.all_configs():
+ cfg.db.dispose()
+
+ for cfg in config.Config.all_configs():
+ url = cfg.db.url
+ backend = url.get_backend_name()
+ host_conf = (backend, url.username, url.host, url.database)
+
+ if host_conf not in hosts:
+ yield cfg
+ hosts.add(host_conf)
+
+ for cfg in config.Config.all_configs():
+ cfg.db.dispose()
+
+
+@register.init
+def drop_all_schema_objects_pre_tables(cfg, eng):
+ pass
+
+
+@register.init
+def drop_all_schema_objects_post_tables(cfg, eng):
+ pass
+
+
+def drop_all_schema_objects(cfg, eng):
+
+ drop_all_schema_objects_pre_tables(cfg, eng)
+
+ inspector = inspect(eng)
+ try:
+ view_names = inspector.get_view_names()
+ except NotImplementedError:
+ pass
+ else:
+ with eng.begin() as conn:
+ for vname in view_names:
+ conn.execute(
+ ddl._DropView(schema.Table(vname, schema.MetaData()))
+ )
+
+ if config.requirements.schemas.enabled_for_config(cfg):
+ try:
+ view_names = inspector.get_view_names(schema="test_schema")
+ except NotImplementedError:
+ pass
+ else:
+ with eng.begin() as conn:
+ for vname in view_names:
+ conn.execute(
+ ddl._DropView(
+ schema.Table(
+ vname,
+ schema.MetaData(),
+ schema="test_schema",
+ )
+ )
+ )
+
+ util.drop_all_tables(eng, inspector)
+ if config.requirements.schemas.enabled_for_config(cfg):
+ util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
+ util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
+
+ drop_all_schema_objects_post_tables(cfg, eng)
+
+ if config.requirements.sequences.enabled_for_config(cfg):
+ with eng.begin() as conn:
+ for seq in inspector.get_sequence_names():
+ conn.execute(ddl.DropSequence(schema.Sequence(seq)))
+ if config.requirements.schemas.enabled_for_config(cfg):
+ for schema_name in [cfg.test_schema, cfg.test_schema_2]:
+ for seq in inspector.get_sequence_names(
+ schema=schema_name
+ ):
+ conn.execute(
+ ddl.DropSequence(
+ schema.Sequence(seq, schema=schema_name)
+ )
+ )
+
+
+@register.init
+def create_db(cfg, eng, ident):
+ """Dynamically create a database for testing.
+
+ Used when a test run will employ multiple processes, e.g., when run
+ via `tox` or `pytest -n4`.
+ """
+ raise NotImplementedError(
+ "no DB creation routine for cfg: %s" % (eng.url,)
+ )
+
+
+@register.init
+def drop_db(cfg, eng, ident):
+ """Drop a database that we dynamically created for testing."""
+ raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))
+
+
+@register.init
+def update_db_opts(db_url, db_opts):
+ """Set database options (db_opts) for a test database that we created."""
+ pass
+
+
+@register.init
+def post_configure_engine(url, engine, follower_ident):
+ """Perform extra steps after configuring an engine for testing.
+
+ (For the internal dialects, currently only used by sqlite, oracle)
+ """
+ pass
+
+
+@register.init
+def follower_url_from_main(url, ident):
+ """Create a connection URL for a dynamically-created test database.
+
+ :param url: the connection URL specified when the test run was invoked
+ :param ident: the pytest-xdist "worker identifier" to be used as the
+ database name
+ """
+ url = sa_url.make_url(url)
+ return url.set(database=ident)
+
+
+@register.init
+def configure_follower(cfg, ident):
+ """Create dialect-specific config settings for a follower database."""
+ pass
+
+
+@register.init
+def run_reap_dbs(url, ident):
+ """Remove databases that were created during the test process, after the
+ process has ended.
+
+ This is an optional step that is invoked for certain backends that do not
+ reliably release locks on the database as long as a process is still in
+ use. For the internal dialects, this is currently only necessary for
+ mssql and oracle.
+ """
+ pass
+
+
+def reap_dbs(idents_file):
+ log.info("Reaping databases...")
+
+ urls = collections.defaultdict(set)
+ idents = collections.defaultdict(set)
+ dialects = {}
+
+ with open(idents_file) as file_:
+ for line in file_:
+ line = line.strip()
+ db_name, db_url = line.split(" ")
+ url_obj = sa_url.make_url(db_url)
+ if db_name not in dialects:
+ dialects[db_name] = url_obj.get_dialect()
+ dialects[db_name].load_provisioning()
+ url_key = (url_obj.get_backend_name(), url_obj.host)
+ urls[url_key].add(db_url)
+ idents[url_key].add(db_name)
+
+ for url_key in urls:
+ url = list(urls[url_key])[0]
+ ident = idents[url_key]
+ run_reap_dbs(url, ident)
+
+
+@register.init
+def temp_table_keyword_args(cfg, eng):
+ """Specify keyword arguments for creating a temporary Table.
+
+ Dialect-specific implementations of this method will return the
+ kwargs that are passed to the Table constructor when creating a temporary
+ table for testing, e.g., in the define_temp_tables method of the
+ ComponentReflectionTest class in suite/test_reflection.py
+ """
+ raise NotImplementedError(
+ "no temp table keyword args routine for cfg: %s" % (eng.url,)
+ )
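+
+# A minimal sketch, not part of the upstream module: a dialect override would
+# typically just request a TEMPORARY prefix (assumes the ``for_db()`` hook
+# described above)::
+#
+#     @temp_table_keyword_args.for_db("sqlite")
+#     def _sqlite_temp_table_keyword_args(cfg, eng):
+#         return {"prefixes": ["TEMPORARY"]}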
+
+
+@register.init
+def prepare_for_drop_tables(config, connection):
+ pass
+
+
+@register.init
+def stop_test_class_outside_fixtures(config, db, testcls):
+ pass
+
+
+@register.init
+def get_temp_table_name(cfg, eng, base_name):
+ """Specify table name for creating a temporary Table.
+
+ Dialect-specific implementations of this method will return the
+ name to use when creating a temporary table for testing,
+ e.g., in the define_temp_tables method of the
+ ComponentReflectionTest class in suite/test_reflection.py
+
+ Default to just the base name since that's what most dialects will
+ use. The mssql dialect's implementation will need a "#" prepended.
+ """
+ return base_name
+
+
+@register.init
+def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
+ raise NotImplementedError(
+ "backend does not implement a schema name set function: %s"
+ % (cfg.db.url,)
+ )
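+
+# A hedged sketch of a backend-specific implementation, not part of the
+# upstream module: a PostgreSQL-style backend could switch the default schema
+# at the DBAPI level roughly like this (assumes the ``for_db()`` hook)::
+#
+#     @set_default_schema_on_connection.for_db("postgresql")
+#     def _pg_set_default_schema(cfg, dbapi_connection, schema_name):
+#         cursor = dbapi_connection.cursor()
+#         cursor.execute("SET SESSION search_path='%s'" % schema_name)
+#         cursor.close()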
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
new file mode 100644
index 0000000..857d1fd
--- /dev/null
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -0,0 +1,1518 @@
+# testing/requirements.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+External dialect test suites should subclass SuiteRequirements
+to provide specific inclusion/exclusions.
+
+"""
+
+import platform
+import sys
+
+from . import exclusions
+from . import only_on
+from .. import util
+from ..pool import QueuePool
+
+
+class Requirements(object):
+ pass
+
+
+class SuiteRequirements(Requirements):
+ @property
+ def create_table(self):
+ """target platform can emit basic CreateTable DDL."""
+
+ return exclusions.open()
+
+ @property
+ def drop_table(self):
+ """target platform can emit basic DropTable DDL."""
+
+ return exclusions.open()
+
+ @property
+ def table_ddl_if_exists(self):
+ """target platform supports IF NOT EXISTS / IF EXISTS for tables."""
+
+ return exclusions.closed()
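+
+ # Illustrative sketch, not part of the upstream module: an external
+ # dialect's test suite turns a default like the one above on or off by
+ # subclassing and overriding the property, e.g.::
+ #
+ #     class Requirements(SuiteRequirements):
+ #         @property
+ #         def table_ddl_if_exists(self):
+ #             return exclusions.open()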
+
+ @property
+ def index_ddl_if_exists(self):
+ """target platform supports IF NOT EXISTS / IF EXISTS for indexes."""
+
+ return exclusions.closed()
+
+ @property
+ def foreign_keys(self):
+ """Target database must support foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def table_value_constructor(self):
+ """Database / dialect supports a query like::
+
+ SELECT * FROM VALUES ( (c1, c2), (c1, c2), ...)
+ AS some_table(col1, col2)
+
+ SQLAlchemy generates this with the :func:`_sql.values` function.
+
+ """
+ return exclusions.closed()
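+
+ # Illustrative usage sketch (names are examples only): the construct
+ # tested by this requirement is produced with ``sqlalchemy.values()``,
+ # e.g.::
+ #
+ #     from sqlalchemy import column, Integer, String, select, values
+ #
+ #     t = values(
+ #         column("id", Integer), column("name", String), name="some_table"
+ #     ).data([(1, "a"), (2, "b")])
+ #     stmt = select(t)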
+
+ @property
+ def standard_cursor_sql(self):
+ """Target database passes SQL-92 style statements to cursor.execute()
+ when a statement like select() or insert() is run.
+
+ A very small portion of dialect-level tests will ensure that certain
+ conditions are present in SQL strings, and these tests use very basic
+ SQL that will work on any SQL-like platform in order to assert results.
+
+ It's normally a given for any pep-249 DBAPI that a statement like
+ "SELECT id, name FROM table WHERE some_table.id=5" will work.
+ However, there are dialects that don't actually produce SQL strings
+ and may instead work with symbolic objects, or dialects that
+ aren't working with SQL at all, so for those this requirement can be marked
+ as excluded.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def on_update_cascade(self):
+ """target database must support ON UPDATE..CASCADE behavior in
+ foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def non_updating_cascade(self):
+ """target database must *not* support ON UPDATE..CASCADE behavior in
+ foreign keys."""
+ return exclusions.closed()
+
+ @property
+ def deferrable_fks(self):
+ return exclusions.closed()
+
+ @property
+ def on_update_or_deferrable_fks(self):
+ # TODO: exclusions should be composable,
+ # somehow only_if([x, y]) isn't working here, negation/conjunctions
+ # getting confused.
+ return exclusions.only_if(
+ lambda: self.on_update_cascade.enabled
+ or self.deferrable_fks.enabled
+ )
+
+ @property
+ def queue_pool(self):
+ """target database is using QueuePool"""
+
+ def go(config):
+ return isinstance(config.db.pool, QueuePool)
+
+ return exclusions.only_if(go)
+
+ @property
+ def self_referential_foreign_keys(self):
+ """Target database must support self-referential foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def foreign_key_ddl(self):
+ """Target database must support the DDL phrases for FOREIGN KEY."""
+
+ return exclusions.open()
+
+ @property
+ def named_constraints(self):
+ """target database must support names for constraints."""
+
+ return exclusions.open()
+
+ @property
+ def implicitly_named_constraints(self):
+ """target database must apply names to unnamed constraints."""
+
+ return exclusions.open()
+
+ @property
+ def subqueries(self):
+ """Target database must support subqueries."""
+
+ return exclusions.open()
+
+ @property
+ def offset(self):
+ """target database can render OFFSET, or an equivalent, in a
+ SELECT.
+ """
+
+ return exclusions.open()
+
+ @property
+ def bound_limit_offset(self):
+ """target database can render LIMIT and/or OFFSET using a bound
+ parameter
+ """
+
+ return exclusions.open()
+
+ @property
+ def sql_expression_limit_offset(self):
+ """target database can render LIMIT and/or OFFSET with a complete
+ SQL expression, such as one that uses the addition operator.
+ """
+
+ return exclusions.open()
+
+ @property
+ def parens_in_union_contained_select_w_limit_offset(self):
+ """Target database must support parenthesized SELECT in UNION
+ when LIMIT/OFFSET is specifically present.
+
+ E.g. (SELECT .. LIMIT ..) UNION (SELECT .. OFFSET ..)
+
+ This is known to fail on SQLite.
+
+ """
+ return exclusions.open()
+
+ @property
+ def parens_in_union_contained_select_wo_limit_offset(self):
+ """Target database must support parenthesized SELECT in UNION
+ when OFFSET/LIMIT is specifically not present.
+
+ E.g. (SELECT ...) UNION (SELECT ..)
+
+ This is known to fail on SQLite. It also fails on Oracle
+ because without LIMIT/OFFSET, there is currently no step that
+ creates an additional subquery.
+
+ """
+ return exclusions.open()
+
+ @property
+ def boolean_col_expressions(self):
+ """Target database must support boolean expressions as columns"""
+
+ return exclusions.closed()
+
+ @property
+ def nullable_booleans(self):
+ """Target database allows boolean columns to store NULL."""
+
+ return exclusions.open()
+
+ @property
+ def nullsordering(self):
+ """Target backends that support nulls ordering."""
+
+ return exclusions.closed()
+
+ @property
+ def standalone_binds(self):
+ """target database/driver supports bound parameters as column
+ expressions without being in the context of a typed column.
+ """
+ return exclusions.closed()
+
+ @property
+ def standalone_null_binds_whereclause(self):
+ """target database/driver supports bound parameters with NULL in the
+ WHERE clause, in situations where it has to be typed.
+
+ """
+ return exclusions.open()
+
+ @property
+ def intersect(self):
+ """Target database must support INTERSECT or equivalent."""
+ return exclusions.closed()
+
+ @property
+ def except_(self):
+ """Target database must support EXCEPT or equivalent (i.e. MINUS)."""
+ return exclusions.closed()
+
+ @property
+ def window_functions(self):
+ """Target database must support window functions."""
+ return exclusions.closed()
+
+ @property
+ def ctes(self):
+ """Target database supports CTEs"""
+
+ return exclusions.closed()
+
+ @property
+ def ctes_with_update_delete(self):
+ """target database supports CTES that ride on top of a normal UPDATE
+ or DELETE statement which refers to the CTE in a correlated subquery.
+
+ """
+
+ return exclusions.closed()
+
+ @property
+ def ctes_on_dml(self):
+ """target database supports CTES which consist of INSERT, UPDATE
+ or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)"""
+
+ return exclusions.closed()
+
+ @property
+ def autoincrement_insert(self):
+ """target platform generates new surrogate integer primary key values
+ when insert() is executed with the pk column omitted."""
+
+ return exclusions.open()
+
+ @property
+ def fetch_rows_post_commit(self):
+ """target platform will allow cursor.fetchone() to proceed after a
+ COMMIT.
+
+ Typically this refers to an INSERT statement with RETURNING which
+ is invoked within "autocommit". If the row can be returned
+ after the autocommit, then this rule can be open.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def group_by_complex_expression(self):
+ """target platform supports SQL expressions in GROUP BY
+
+ e.g.
+
+ SELECT x + y AS somelabel FROM table GROUP BY x + y
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def sane_rowcount(self):
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_sane_rowcount,
+ "driver doesn't support 'sane' rowcount",
+ )
+
+ @property
+ def sane_multi_rowcount(self):
+ return exclusions.fails_if(
+ lambda config: not config.db.dialect.supports_sane_multi_rowcount,
+ "driver %(driver)s %(doesnt_support)s 'sane' multi row count",
+ )
+
+ @property
+ def sane_rowcount_w_returning(self):
+ return exclusions.fails_if(
+ lambda config: not (
+ config.db.dialect.supports_sane_rowcount_returning
+ ),
+ "driver doesn't support 'sane' rowcount when returning is on",
+ )
+
+ @property
+ def empty_inserts(self):
+ """target platform supports INSERT with no values, i.e.
+ INSERT DEFAULT VALUES or equivalent."""
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.supports_empty_insert
+ or config.db.dialect.supports_default_values
+ or config.db.dialect.supports_default_metavalue,
+ "empty inserts not supported",
+ )
+
+ @property
+ def empty_inserts_executemany(self):
+ """target platform supports INSERT with no values, i.e.
+ INSERT DEFAULT VALUES or equivalent, within executemany()"""
+
+ return self.empty_inserts
+
+ @property
+ def insert_from_select(self):
+ """target platform supports INSERT from a SELECT."""
+
+ return exclusions.open()
+
+ @property
+ def full_returning(self):
+ """target platform supports RETURNING completely, including
+ multiple rows returned.
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.full_returning,
+ "%(database)s %(does_support)s 'RETURNING of multiple rows'",
+ )
+
+ @property
+ def insert_executemany_returning(self):
+ """target platform supports RETURNING when INSERT is used with
+ executemany(), e.g. multiple parameter sets, indicating
+ that as many rows come back as parameter sets were passed.
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.insert_executemany_returning,
+ "%(database)s %(does_support)s 'RETURNING of "
+ "multiple rows with INSERT executemany'",
+ )
+
+ @property
+ def returning(self):
+ """target platform supports RETURNING for at least one row.
+
+ .. seealso::
+
+ :attr:`.Requirements.full_returning`
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.implicit_returning,
+ "%(database)s %(does_support)s 'RETURNING of a single row'",
+ )
+
+ @property
+ def tuple_in(self):
+ """Target platform supports the syntax
+ "(x, y) IN ((x1, y1), (x2, y2), ...)"
+ """
+
+ return exclusions.closed()
+
+ @property
+ def tuple_in_w_empty(self):
+ """Target platform tuple IN w/ empty set"""
+ return self.tuple_in
+
+ @property
+ def duplicate_names_in_cursor_description(self):
+ """target platform supports a SELECT statement that has
+ the same name repeated more than once in the columns list."""
+
+ return exclusions.open()
+
+ @property
+ def denormalized_names(self):
+ """Target database must have 'denormalized', i.e.
+ UPPERCASE as case-insensitive names."""
+
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.requires_name_normalize,
+ "Backend does not require denormalized names.",
+ )
+
+ @property
+ def multivalues_inserts(self):
+ """target database must support multiple VALUES clauses in an
+ INSERT statement."""
+
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_multivalues_insert,
+ "Backend does not support multirow inserts.",
+ )
+
+ @property
+ def implements_get_lastrowid(self):
+ """target dialect implements the executioncontext.get_lastrowid()
+ method without reliance on RETURNING.
+
+ """
+ return exclusions.open()
+
+ @property
+ def emulated_lastrowid(self):
+ """target dialect retrieves cursor.lastrowid, or fetches
+ from a database-side function after an insert() construct executes,
+ within the get_lastrowid() method.
+
+ Only dialects that "pre-execute", or need RETURNING to get last
+ inserted id, would return closed/fail/skip for this.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def emulated_lastrowid_even_with_sequences(self):
+ """target dialect retrieves cursor.lastrowid or an equivalent
+ after an insert() construct executes, even if the table has a
+ Sequence on it.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def dbapi_lastrowid(self):
+ """target platform includes a 'lastrowid' accessor on the DBAPI
+ cursor object.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def views(self):
+ """Target database must support VIEWs."""
+
+ return exclusions.closed()
+
+ @property
+ def schemas(self):
+ """Target database must support external schemas, and have one
+ named 'test_schema'."""
+
+ return only_on(lambda config: config.db.dialect.supports_schemas)
+
+ @property
+ def cross_schema_fk_reflection(self):
+ """target system must support reflection of inter-schema
+ foreign keys"""
+ return exclusions.closed()
+
+ @property
+ def foreign_key_constraint_name_reflection(self):
+ """Target supports refleciton of FOREIGN KEY constraints and
+ will return the name of the constraint that was used in the
+ "CONSTRAINT <name> FOREIGN KEY" DDL.
+
+ MySQL prior to version 8 and MariaDB prior to version 10.5
+ don't support this.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def implicit_default_schema(self):
+ """target system has a strong concept of 'default' schema that can
+ be referred to implicitly.
+
+ basically, PostgreSQL.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def default_schema_name_switch(self):
+ """target dialect implements provisioning module including
+ set_default_schema_on_connection"""
+
+ return exclusions.closed()
+
+ @property
+ def server_side_cursors(self):
+ """Target dialect must support server side cursors."""
+
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_server_side_cursors],
+ "no server side cursors support",
+ )
+
+ @property
+ def sequences(self):
+ """Target database must support SEQUENCEs."""
+
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_sequences],
+ "no sequence support",
+ )
+
+ @property
+ def no_sequences(self):
+ """the opposite of "sequences", DB does not support sequences at
+ all."""
+
+ return exclusions.NotPredicate(self.sequences)
+
+ @property
+ def sequences_optional(self):
+ """Target database supports sequences, but also optionally
+ as a means of generating new PK values."""
+
+ return exclusions.only_if(
+ [
+ lambda config: config.db.dialect.supports_sequences
+ and config.db.dialect.sequences_optional
+ ],
+ "no sequence support, or sequences not optional",
+ )
+
+ @property
+ def supports_lastrowid(self):
+ """target database / driver supports cursor.lastrowid as a means
+ of retrieving the last inserted primary key value.
+
+ note that if the target DB supports sequences also, this is still
+ assumed to work. This is a new use case brought on by MariaDB 10.3.
+
+ """
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.postfetch_lastrowid]
+ )
+
+ @property
+ def no_lastrowid_support(self):
+ """the opposite of supports_lastrowid"""
+ return exclusions.only_if(
+ [lambda config: not config.db.dialect.postfetch_lastrowid]
+ )
+
+ @property
+ def reflects_pk_names(self):
+ return exclusions.closed()
+
+ @property
+ def table_reflection(self):
+ """target database has general support for table reflection"""
+ return exclusions.open()
+
+ @property
+ def reflect_tables_no_columns(self):
+ """target database supports creation and reflection of tables with no
+ columns, or at least tables that seem to have no columns."""
+
+ return exclusions.closed()
+
+ @property
+ def comment_reflection(self):
+ return exclusions.closed()
+
+ @property
+ def view_column_reflection(self):
+ """target database must support retrieval of the columns in a view,
+ similarly to how a table is inspected.
+
+ This does not include the full CREATE VIEW definition.
+
+ """
+ return self.views
+
+ @property
+ def view_reflection(self):
+ """target database must support inspection of the full CREATE VIEW
+ definition."""
+ return self.views
+
+ @property
+ def schema_reflection(self):
+ return self.schemas
+
+ @property
+ def primary_key_constraint_reflection(self):
+ return exclusions.open()
+
+ @property
+ def foreign_key_constraint_reflection(self):
+ return exclusions.open()
+
+ @property
+ def foreign_key_constraint_option_reflection_ondelete(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_ondelete_restrict(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_ondelete_noaction(self):
+ return exclusions.closed()
+
+ @property
+ def foreign_key_constraint_option_reflection_onupdate(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_onupdate_restrict(self):
+ return exclusions.closed()
+
+ @property
+ def temp_table_reflection(self):
+ return exclusions.open()
+
+ @property
+ def temp_table_reflect_indexes(self):
+ return self.temp_table_reflection
+
+ @property
+ def temp_table_names(self):
+ """target dialect supports listing of temporary table names"""
+ return exclusions.closed()
+
+ @property
+ def temporary_tables(self):
+ """target database supports temporary tables"""
+ return exclusions.open()
+
+ @property
+ def temporary_views(self):
+ """target database supports temporary views"""
+ return exclusions.closed()
+
+ @property
+ def index_reflection(self):
+ return exclusions.open()
+
+ @property
+ def index_reflects_included_columns(self):
+ return exclusions.closed()
+
+ @property
+ def indexes_with_ascdesc(self):
+ """target database supports CREATE INDEX with per-column ASC/DESC."""
+ return exclusions.open()
+
+ @property
+ def indexes_with_expressions(self):
+ """target database supports CREATE INDEX against SQL expressions."""
+ return exclusions.closed()
+
+ @property
+ def unique_constraint_reflection(self):
+ """target dialect supports reflection of unique constraints"""
+ return exclusions.open()
+
+ @property
+ def check_constraint_reflection(self):
+ """target dialect supports reflection of check constraints"""
+ return exclusions.closed()
+
+ @property
+ def duplicate_key_raises_integrity_error(self):
+ """target dialect raises IntegrityError when reporting an INSERT
+ with a primary key violation. (hint: it should)
+
+ """
+ return exclusions.open()
+
+ @property
+ def unbounded_varchar(self):
+ """Target database must support VARCHAR with no length"""
+
+ return exclusions.open()
+
+ @property
+ def unicode_data(self):
+ """Target database/dialect must support Python unicode objects with
+ non-ASCII characters represented, delivered as bound parameters
+ as well as in result rows.
+
+ """
+ return exclusions.open()
+
+ @property
+ def unicode_ddl(self):
+ """Target driver must support some degree of non-ascii symbol
+ names.
+ """
+ return exclusions.closed()
+
+ @property
+ def symbol_names_w_double_quote(self):
+ """Target driver can create tables with a name like 'some " table'"""
+ return exclusions.open()
+
+ @property
+ def datetime_literals(self):
+ """target dialect supports rendering of a date, time, or datetime as a
+ literal string, e.g. via the TypeEngine.literal_processor() method.
+
+ """
+
+ return exclusions.closed()
+
+ @property
+ def datetime(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects."""
+
+ return exclusions.open()
+
+ @property
+ def datetime_timezone(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with tzinfo with DateTime(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def time_timezone(self):
+ """target dialect supports representation of Python
+ datetime.time() with tzinfo with Time(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def datetime_implicit_bound(self):
+ """target dialect when given a datetime object will bind it such
+ that the database server knows the object is a datetime, and not
+ a plain string.
+
+ """
+ return exclusions.open()
+
+ @property
+ def datetime_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with microsecond objects."""
+
+ return exclusions.open()
+
+ @property
+ def timestamp_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with microsecond objects but only
+ if TIMESTAMP is used."""
+ return exclusions.closed()
+
+ @property
+ def timestamp_microseconds_implicit_bound(self):
+ """target dialect when given a datetime object which also includes
+ a microseconds portion when using the TIMESTAMP data type
+ will bind it such that the database server knows
+ the object is a datetime with microseconds, and not a plain string.
+
+ """
+ return self.timestamp_microseconds
+
+ @property
+ def datetime_historic(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects with historic (pre 1970) values."""
+
+ return exclusions.closed()
+
+ @property
+ def date(self):
+ """target dialect supports representation of Python
+ datetime.date() objects."""
+
+ return exclusions.open()
+
+ @property
+ def date_coerces_from_datetime(self):
+ """target dialect accepts a datetime object as the target
+ of a date column."""
+
+ return exclusions.open()
+
+ @property
+ def date_historic(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects with historic (pre 1970) values."""
+
+ return exclusions.closed()
+
+ @property
+ def time(self):
+ """target dialect supports representation of Python
+ datetime.time() objects."""
+
+ return exclusions.open()
+
+ @property
+ def time_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.time() with microsecond objects."""
+
+ return exclusions.open()
+
+ @property
+ def binary_comparisons(self):
+ """target database/driver can allow BLOB/BINARY fields to be compared
+ against a bound parameter value.
+ """
+
+ return exclusions.open()
+
+ @property
+ def binary_literals(self):
+ """target backend supports simple binary literals, e.g. an
+ expression like::
+
+ SELECT CAST('foo' AS BINARY)
+
+ Where ``BINARY`` is the type emitted from :class:`.LargeBinary`,
+ e.g. it could be ``BLOB`` or similar.
+
+ Basically fails on Oracle.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def autocommit(self):
+ """target dialect supports 'AUTOCOMMIT' as an isolation_level"""
+ return exclusions.closed()
+
+ @property
+ def isolation_level(self):
+ """target dialect supports general isolation level settings.
+
+ Note that this requirement, when enabled, also requires that
+ the get_isolation_levels() method be implemented.
+
+ """
+ return exclusions.closed()
+
+ def get_isolation_levels(self, config):
+ """Return a structure of supported isolation levels for the current
+ testing dialect.
+
+ The structure indicates to the testing suite what the expected
+ "default" isolation should be, as well as the other values that
+ are accepted. The dictionary has two keys, "default" and "supported".
+ The "supported" key refers to a list of all supported levels and
+ it should include AUTOCOMMIT if the dialect supports it.
+
+ If the :meth:`.DefaultRequirements.isolation_level` requirement is
+ not open, then this method has no return value.
+
+ E.g.::
+
+ >>> testing.requirements.get_isolation_levels()
+ {
+ "default": "READ_COMMITTED",
+ "supported": [
+ "SERIALIZABLE", "READ UNCOMMITTED",
+ "READ COMMITTED", "REPEATABLE READ",
+ "AUTOCOMMIT"
+ ]
+ }
+ """
+
+ @property
+ def json_type(self):
+ """target platform implements a native JSON type."""
+
+ return exclusions.closed()
+
+ @property
+ def json_array_indexes(self):
+ """target platform supports numeric array indexes
+ within a JSON structure"""
+
+ return self.json_type
+
+ @property
+ def json_index_supplementary_unicode_element(self):
+ return exclusions.open()
+
+ @property
+ def legacy_unconditional_json_extract(self):
+ """Backend has a JSON_EXTRACT or similar function that returns a
+ valid JSON string in all cases.
+
+ Used to test a legacy feature and is not needed.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_general(self):
+ """target backend has general support for moderately high-precision
+ numerics."""
+ return exclusions.open()
+
+ @property
+ def precision_numerics_enotation_small(self):
+ """target backend supports Decimal() objects using E notation
+ to represent very small values."""
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_enotation_large(self):
+ """target backend supports Decimal() objects using E notation
+ to represent very large values."""
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_many_significant_digits(self):
+ """target backend supports values with many digits on both sides,
+ such as 319438950232418390.273596, 87673.594069654243
+
+ """
+ return exclusions.closed()
+
+ @property
+ def cast_precision_numerics_many_significant_digits(self):
+ """same as precision_numerics_many_significant_digits but within the
+ context of a CAST statement (hello MySQL)
+
+ """
+ return self.precision_numerics_many_significant_digits
+
+ @property
+ def implicit_decimal_binds(self):
+ """target backend will return a selected Decimal as a Decimal, not
+ a string.
+
+ e.g.::
+
+ expr = decimal.Decimal("15.7563")
+
+ value = e.scalar(
+ select(literal(expr))
+ )
+
+ assert value == expr
+
+ See :ticket:`4036`
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def nested_aggregates(self):
+ """target database can select an aggregate from a subquery that's
+ also using an aggregate
+
+ """
+ return exclusions.open()
+
+ @property
+ def recursive_fk_cascade(self):
+ """target database must support ON DELETE CASCADE on a self-referential
+ foreign key
+
+ """
+ return exclusions.open()
+
+ @property
+ def precision_numerics_retains_significant_digits(self):
+ """A precision numeric type will return empty significant digits,
+ i.e. a value such as 10.000 will come back in Decimal form with
+ the .000 maintained."""
+
+ return exclusions.closed()
+
+ @property
+ def infinity_floats(self):
+ """The Float type can persist and load float('inf'), float('-inf')."""
+
+ return exclusions.closed()
+
+ @property
+ def precision_generic_float_type(self):
+ """target backend will return native floating point numbers with at
+ least seven decimal places when using the generic Float type.
+
+ """
+ return exclusions.open()
+
+ @property
+ def floats_to_four_decimals(self):
+ """target backend can return a floating-point number with four
+ significant digits (such as 15.7563) accurately
+ (i.e. without FP inaccuracies, such as 15.75629997253418).
+
+ """
+ return exclusions.open()
+
+ @property
+ def fetch_null_from_numeric(self):
+ """target backend doesn't crash when you try to select a NUMERIC
+ value that has a value of NULL.
+
+ Added to support Pyodbc bug #351.
+ """
+
+ return exclusions.open()
+
+ @property
+ def text_type(self):
+ """Target database must support an unbounded Text() "
+ "type such as TEXT or CLOB"""
+
+ return exclusions.open()
+
+ @property
+ def empty_strings_varchar(self):
+ """target database can persist/return an empty string with a
+ varchar.
+
+ """
+ return exclusions.open()
+
+ @property
+ def empty_strings_text(self):
+ """target database can persist/return an empty string with an
+ unbounded text."""
+
+ return exclusions.open()
+
+ @property
+ def expressions_against_unbounded_text(self):
+ """target database supports use of an unbounded textual field in a
+ WHERE clause."""
+
+ return exclusions.open()
+
+ @property
+ def selectone(self):
+ """target driver must support the literal statement 'select 1'"""
+ return exclusions.open()
+
+ @property
+ def savepoints(self):
+ """Target database must support savepoints."""
+
+ return exclusions.closed()
+
+ @property
+ def two_phase_transactions(self):
+ """Target database must support two-phase transactions."""
+
+ return exclusions.closed()
+
+ @property
+ def update_from(self):
+ """Target must support UPDATE..FROM syntax"""
+ return exclusions.closed()
+
+ @property
+ def delete_from(self):
+ """Target must support DELETE FROM..FROM or DELETE..USING syntax"""
+ return exclusions.closed()
+
+ @property
+ def update_where_target_in_subquery(self):
+ """Target must support UPDATE (or DELETE) where the same table is
+ present in a subquery in the WHERE clause.
+
+ This is an ANSI-standard syntax that apparently MySQL can't handle,
+ such as::
+
+ UPDATE documents SET flag=1 WHERE documents.title IN
+ (SELECT max(documents.title) AS title
+ FROM documents GROUP BY documents.user_id
+ )
+
+ """
+ return exclusions.open()
+
+ @property
+ def mod_operator_as_percent_sign(self):
+ """target database must use a plain percent '%' as the 'modulus'
+ operator."""
+ return exclusions.closed()
+
+ @property
+ def percent_schema_names(self):
+ """target backend supports weird identifiers with percent signs
+ in them, e.g. 'some % column'.
+
+ this is a very weird use case but often has problems because of
+ DBAPIs that use python formatting. It's not a critical use
+ case either.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def order_by_col_from_union(self):
+ """target database supports ordering by a column from a SELECT
+ inside of a UNION
+
+ E.g. (SELECT id, ...) UNION (SELECT id, ...) ORDER BY id
+
+ """
+ return exclusions.open()
+
+ @property
+ def order_by_label_with_expression(self):
+ """target backend supports ORDER BY a column label within an
+ expression.
+
+ Basically this::
+
+ select data as foo from test order by foo || 'bar'
+
+ Lots of databases including PostgreSQL don't support this,
+ so this is off by default.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def order_by_collation(self):
+ def check(config):
+ try:
+ self.get_order_by_collation(config)
+ return False
+ except NotImplementedError:
+ return True
+
+ return exclusions.skip_if(check)
+
+ def get_order_by_collation(self, config):
+ raise NotImplementedError()
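+
+ # Hedged sketch, not part of the upstream module: a concrete requirements
+ # class would return a collation name the backend accepts in an
+ # ORDER BY ... COLLATE clause (the value below is illustrative only)::
+ #
+ #     def get_order_by_collation(self, config):
+ #         return "POSIX"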
+
+ @property
+ def unicode_connections(self):
+ """Target driver must support non-ASCII characters being passed at
+ all.
+ """
+ return exclusions.open()
+
+ @property
+ def graceful_disconnects(self):
+ """Target driver must raise a DBAPI-level exception, such as
+ InterfaceError, when the underlying connection has been closed
+ and the execute() method is called.
+ """
+ return exclusions.open()
+
+ @property
+ def independent_connections(self):
+ """
+ Target must support simultaneous, independent database connections.
+ """
+ return exclusions.open()
+
+ @property
+ def skip_mysql_on_windows(self):
+ """Catchall for a large variety of MySQL on Windows failures"""
+ return exclusions.open()
+
+ @property
+ def ad_hoc_engines(self):
+ """Test environment must allow ad-hoc engine/connection creation.
+
+ DBs that scale poorly for many connections, even when closed, i.e.
+ Oracle, may use the "--low-connections" option which flags this
+ requirement as not present.
+
+ """
+ return exclusions.skip_if(
+ lambda config: config.options.low_connections
+ )
+
+ @property
+ def no_windows(self):
+ return exclusions.skip_if(self._running_on_windows())
+
+ def _running_on_windows(self):
+ return exclusions.LambdaPredicate(
+ lambda: platform.system() == "Windows",
+ description="running on Windows",
+ )
+
+ @property
+ def timing_intensive(self):
+ return exclusions.requires_tag("timing_intensive")
+
+ @property
+ def memory_intensive(self):
+ return exclusions.requires_tag("memory_intensive")
+
+ @property
+ def threading_with_mock(self):
+ """Mark tests that use threading and mock at the same time - stability
+ issues have been observed with coverage + python 3.3
+
+ """
+ return exclusions.skip_if(
+ lambda config: util.py3k and config.options.has_coverage,
+ "Stability issues with coverage + py3k",
+ )
+
+ @property
+ def sqlalchemy2_stubs(self):
+ def check(config):
+ try:
+ __import__("sqlalchemy-stubs.ext.mypy")
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(check)
+
+ @property
+ def python2(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info >= (3,),
+ "Python version 2.xx is required.",
+ )
+
+ @property
+ def python3(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3,), "Python version 3.xx is required."
+ )
+
+ @property
+ def pep520(self):
+ return self.python36
+
+ @property
+ def insert_order_dicts(self):
+ return self.python37
+
+ @property
+ def python36(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3, 6),
+ "Python version 3.6 or greater is required.",
+ )
+
+ @property
+ def python37(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3, 7),
+ "Python version 3.7 or greater is required.",
+ )
+
+ @property
+ def dataclasses(self):
+ return self.python37
+
+ @property
+ def python38(self):
+ return exclusions.only_if(
+ lambda: util.py38, "Python 3.8 or above required"
+ )
+
+ @property
+ def cpython(self):
+ return exclusions.only_if(
+ lambda: util.cpython, "cPython interpreter needed"
+ )
+
+ @property
+ def patch_library(self):
+ def check_lib():
+ try:
+ __import__("patch")
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(check_lib, "patch library needed")
+
+ @property
+ def non_broken_pickle(self):
+ from sqlalchemy.util import pickle
+
+ return exclusions.only_if(
+ lambda: util.cpython
+ and pickle.__name__ == "cPickle"
+ or sys.version_info >= (3, 2),
+ "Needs cPickle+cPython or newer Python 3 pickle",
+ )
+
+ @property
+ def predictable_gc(self):
+ """target platform must remove all cycles unconditionally when
+ gc.collect() is called, as well as clean out unreferenced subclasses.
+
+ """
+ return self.cpython
+
+ @property
+ def no_coverage(self):
+ """Test should be skipped if coverage is enabled.
+
+ This is to block tests that exercise libraries that seem to be
+ sensitive to coverage, such as PostgreSQL notice logging.
+
+ """
+ return exclusions.skip_if(
+ lambda config: config.options.has_coverage,
+ "Issues observed when coverage is enabled",
+ )
+
+ def _has_mysql_on_windows(self, config):
+ return False
+
+ def _has_mysql_fully_case_sensitive(self, config):
+ return False
+
+ @property
+ def sqlite(self):
+ return exclusions.skip_if(lambda: not self._has_sqlite())
+
+ @property
+ def cextensions(self):
+ return exclusions.skip_if(
+ lambda: not util.has_compiled_ext(), "C extensions not installed"
+ )
+
+ def _has_sqlite(self):
+ from sqlalchemy import create_engine
+
+ try:
+ create_engine("sqlite://")
+ return True
+ except ImportError:
+ return False
+
+ @property
+ def async_dialect(self):
+ """dialect makes use of await_() to invoke operations on the DBAPI."""
+
+ return exclusions.closed()
+
+ @property
+ def asyncio(self):
+ return self.greenlet
+
+ @property
+ def greenlet(self):
+ def go(config):
+ try:
+ import greenlet # noqa: F401
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(go)
+
+ @property
+ def computed_columns(self):
+ "Supports computed columns"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_stored(self):
+ "Supports computed columns with `persisted=True`"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_virtual(self):
+ "Supports computed columns with `persisted=False`"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_default_persisted(self):
+ """If the default persistence is virtual or stored when `persisted`
+ is omitted"""
+ return exclusions.closed()
+
+ @property
+ def computed_columns_reflect_persisted(self):
+ """If persistence information is returned by the reflection of
+ computed columns"""
+ return exclusions.closed()
+
+ @property
+ def supports_distinct_on(self):
+ """If a backend supports the DISTINCT ON in a select"""
+ return exclusions.closed()
+
+ @property
+ def supports_is_distinct_from(self):
+ """Supports some form of "x IS [NOT] DISTINCT FROM y" construct.
+ Different dialects will implement their own flavour, e.g.,
+ sqlite will emit "x IS NOT y" instead of "x IS DISTINCT FROM y".
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.is_distinct_from`
+
+ """
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_is_distinct_from,
+ "driver doesn't support an IS DISTINCT FROM construct",
+ )
+
+ @property
+ def identity_columns(self):
+ """If a backend supports GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY"""
+ return exclusions.closed()
+
+ @property
+ def identity_columns_standard(self):
+ """If a backend supports GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY with a standard syntax.
+ This is mainly to exclude MSSql.
+ """
+ return exclusions.closed()
+
+ @property
+ def regexp_match(self):
+ """backend supports the regexp_match operator."""
+ return exclusions.closed()
+
+ @property
+ def regexp_replace(self):
+ """backend supports the regexp_replace operator."""
+ return exclusions.closed()
+
+ @property
+ def fetch_first(self):
+ """backend supports the fetch first clause."""
+ return exclusions.closed()
+
+ @property
+ def fetch_percent(self):
+ """backend supports the fetch first clause with percent."""
+ return exclusions.closed()
+
+ @property
+ def fetch_ties(self):
+ """backend supports the fetch first clause with ties."""
+ return exclusions.closed()
+
+ @property
+ def fetch_no_order_by(self):
+ """backend supports the fetch first without order by"""
+ return exclusions.closed()
+
+ @property
+ def fetch_offset_with_options(self):
+ """backend supports the offset when using fetch first with percent
+ or ties. basically this is "not mssql"
+ """
+ return exclusions.closed()
+
+ @property
+ def fetch_expression(self):
+ """backend supports fetch / offset with expression in them, like
+
+ SELECT * FROM some_table
+ OFFSET 1 + 1 ROWS FETCH FIRST 1 + 1 ROWS ONLY
+ """
+ return exclusions.closed()
+
+ @property
+ def autoincrement_without_sequence(self):
+ """If autoincrement=True on a column does not require an explicit
+ sequence. This should be false only for oracle.
+ """
+ return exclusions.open()
+
+ @property
+ def generic_classes(self):
+ "If X[Y] can be implemented with ``__class_getitem__``. py3.7+"
+ return exclusions.only_if(lambda: util.py37)
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
new file mode 100644
index 0000000..bff07a5
--- /dev/null
+++ b/lib/sqlalchemy/testing/schema.py
@@ -0,0 +1,218 @@
+# testing/schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import sys
+
+from . import config
+from . import exclusions
+from .. import event
+from .. import schema
+from .. import types as sqltypes
+from ..util import OrderedDict
+
+
+__all__ = ["Table", "Column"]
+
+table_options = {}
+
+
+def Table(*args, **kw):
+ """A schema.Table wrapper/hook for dialect-specific tweaks."""
+
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
+
+ kw.update(table_options)
+
+ if exclusions.against(config._current, "mysql"):
+ if (
+ "mysql_engine" not in kw
+ and "mysql_type" not in kw
+ and "autoload_with" not in kw
+ ):
+ if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
+ kw["mysql_engine"] = "InnoDB"
+ else:
+ kw["mysql_engine"] = "MyISAM"
+ elif exclusions.against(config._current, "mariadb"):
+ if (
+ "mariadb_engine" not in kw
+ and "mariadb_type" not in kw
+ and "autoload_with" not in kw
+ ):
+ if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
+ kw["mariadb_engine"] = "InnoDB"
+ else:
+ kw["mariadb_engine"] = "MyISAM"
+
+ # Apply some default cascading rules for self-referential foreign keys.
+ # MySQL InnoDB has some issues around selecting self-refs too.
+ if exclusions.against(config._current, "firebird"):
+ table_name = args[0]
+ unpack = config.db.dialect.identifier_preparer.unformat_identifiers
+
+ # Only going after ForeignKeys in Columns. May need to
+ # expand to ForeignKeyConstraint too.
+ fks = [
+ fk
+ for col in args
+ if isinstance(col, schema.Column)
+ for fk in col.foreign_keys
+ ]
+
+ for fk in fks:
+ # root around in raw spec
+ ref = fk._colspec
+ if isinstance(ref, schema.Column):
+ name = ref.table.name
+ else:
+ # take just the table name: on FB there cannot be
+ # a schema, so the first element is always the
+ # table name, possibly followed by the field name
+ name = unpack(ref)[0]
+ if name == table_name:
+ if fk.ondelete is None:
+ fk.ondelete = "CASCADE"
+ if fk.onupdate is None:
+ fk.onupdate = "CASCADE"
+
+ return schema.Table(*args, **kw)
+
+
+def Column(*args, **kw):
+ """A schema.Column wrapper/hook for dialect-specific tweaks."""
+
+ test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
+
+ if not config.requirements.foreign_key_ddl.enabled_for_config(config):
+ args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
+
+ col = schema.Column(*args, **kw)
+ if test_opts.get("test_needs_autoincrement", False) and kw.get(
+ "primary_key", False
+ ):
+
+ if col.default is None and col.server_default is None:
+ col.autoincrement = True
+
+ # allow any test suite to pick up on this
+ col.info["test_needs_autoincrement"] = True
+
+ # hardcoded rule for firebird, oracle; this should
+ # be moved out
+ if exclusions.against(config._current, "firebird", "oracle"):
+
+ def add_seq(c, tbl):
+ c._init_items(
+ schema.Sequence(
+ _truncate_name(
+ config.db.dialect, tbl.name + "_" + c.name + "_seq"
+ ),
+ optional=True,
+ )
+ )
+
+ event.listen(col, "after_parent_attach", add_seq, propagate=True)
+ return col
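+
+# Usage sketch (illustrative only): test suites use these wrappers exactly
+# like schema.Table / schema.Column, plus the test_* hints consumed above::
+#
+#     t = Table(
+#         "some_table",
+#         metadata,
+#         Column(
+#             "id", Integer, primary_key=True, test_needs_autoincrement=True
+#         ),
+#         Column("data", String(50)),
+#         test_needs_fk=True,
+#     )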
+
+
+class eq_type_affinity(object):
+ """Helper to compare types inside of datastructures based on affinity.
+
+ E.g.::
+
+ eq_(
+ inspect(connection).get_columns("foo"),
+ [
+ {
+ "name": "id",
+ "type": testing.eq_type_affinity(sqltypes.INTEGER),
+ "nullable": False,
+ "default": None,
+ "autoincrement": False,
+ },
+ {
+ "name": "data",
+ "type": testing.eq_type_affinity(sqltypes.NullType),
+ "nullable": True,
+ "default": None,
+ "autoincrement": False,
+ },
+ ],
+ )
+
+ """
+
+ def __init__(self, target):
+ self.target = sqltypes.to_instance(target)
+
+ def __eq__(self, other):
+ return self.target._type_affinity is other._type_affinity
+
+ def __ne__(self, other):
+ return self.target._type_affinity is not other._type_affinity
+
+
+class eq_clause_element(object):
+ """Helper to compare SQL structures based on compare()"""
+
+ def __init__(self, target):
+ self.target = target
+
+ def __eq__(self, other):
+ return self.target.compare(other)
+
+ def __ne__(self, other):
+ return not self.target.compare(other)
+
+
+def _truncate_name(dialect, name):
+ if len(name) > dialect.max_identifier_length:
+ return (
+ name[0 : max(dialect.max_identifier_length - 6, 0)]
+ + "_"
+ + hex(hash(name) % 64)[2:]
+ )
+ else:
+ return name
+
+
+def pep435_enum(name):
+ # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+ __members__ = OrderedDict()
+
+ def __init__(self, name, value, alias=None):
+ self.name = name
+ self.value = value
+ self.__members__[name] = self
+ value_to_member[value] = self
+ setattr(self.__class__, name, self)
+ if alias:
+ self.__members__[alias] = self
+ setattr(self.__class__, alias, self)
+
+ value_to_member = {}
+
+ @classmethod
+ def get(cls, value):
+ return value_to_member[value]
+
+ someenum = type(
+ name,
+ (object,),
+ {"__members__": __members__, "__init__": __init__, "get": get},
+ )
+
+ # _getframe() trick for pickle support, courtesy of
+ # Python's namedtuple()
+ try:
+ module = sys._getframe(1).f_globals.get("__name__", "__main__")
+ except (AttributeError, ValueError):
+ module = None
+ if module is not None:
+ someenum.__module__ = module
+
+ return someenum
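+
+# Usage sketch based on the implementation above (names illustrative)::
+#
+#     SomeEnum = pep435_enum("SomeEnum")
+#     one = SomeEnum("one", 1)
+#     two = SomeEnum("two", 2, alias="deux")
+#     assert SomeEnum.get(1) is one
+#     assert SomeEnum.deux is two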
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
new file mode 100644
index 0000000..30817e1
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -0,0 +1,13 @@
+from .test_cte import * # noqa
+from .test_ddl import * # noqa
+from .test_deprecations import * # noqa
+from .test_dialect import * # noqa
+from .test_insert import * # noqa
+from .test_reflection import * # noqa
+from .test_results import * # noqa
+from .test_rowcount import * # noqa
+from .test_select import * # noqa
+from .test_sequence import * # noqa
+from .test_types import * # noqa
+from .test_unicode_ddl import * # noqa
+from .test_update_delete import * # noqa
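+
+# Note, stated here as an assumption about usage rather than upstream code: a
+# third-party dialect typically re-exports this package in its own test module
+# so the whole compliance suite runs against that dialect, e.g.::
+#
+#     # test/test_suite.py of an external dialect
+#     from sqlalchemy.testing.suite import *  # noqa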
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
new file mode 100644
index 0000000..a94ee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -0,0 +1,204 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import ForeignKey
+from ... import Integer
+from ... import select
+from ... import String
+from ... import testing
+
+
+class CTETest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("ctes",)
+
+ run_inserts = "each"
+ run_deletes = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", ForeignKey("some_table.id")),
+ )
+
+ Table(
+ "some_other_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "d1", "parent_id": None},
+ {"id": 2, "data": "d2", "parent_id": 1},
+ {"id": 3, "data": "d3", "parent_id": 1},
+ {"id": 4, "data": "d4", "parent_id": 3},
+ {"id": 5, "data": "d5", "parent_id": 3},
+ ],
+ )
+
+ def test_select_nonrecursive_round_trip(self, connection):
+ some_table = self.tables.some_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ result = connection.execute(
+ select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
+ )
+ eq_(result.fetchall(), [("d4",)])
+
+ def test_select_recursive_round_trip(self, connection):
+ some_table = self.tables.some_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte", recursive=True)
+ )
+
+ cte_alias = cte.alias("c1")
+ st1 = some_table.alias()
+ # note that SQL Server requires this to be UNION ALL,
+ # can't be UNION
+ cte = cte.union_all(
+ select(st1).where(st1.c.id == cte_alias.c.parent_id)
+ )
+ result = connection.execute(
+ select(cte.c.data)
+ .where(cte.c.data != "d2")
+ .order_by(cte.c.data.desc())
+ )
+ eq_(
+ result.fetchall(),
+ [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
+ )
+
+ def test_insert_from_select_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(cte)
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.update_from
+ def test_update_from_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.update()
+ .values(parent_id=5)
+ .where(some_other_table.c.data == cte.c.data)
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None),
+ (2, "d2", 5),
+ (3, "d3", 5),
+ (4, "d4", 5),
+ (5, "d5", 3),
+ ],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.delete_from
+ def test_delete_from_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(1, "d1", None), (5, "d5", 3)],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ def test_delete_scalar_subq_round_trip(self, connection):
+
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data
+ == select(cte.c.data)
+ .where(cte.c.id == some_other_table.c.id)
+ .scalar_subquery()
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(1, "d1", None), (5, "d5", 3)],
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
new file mode 100644
index 0000000..b3fee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -0,0 +1,381 @@
+import random
+
+from . import testing
+from .. import config
+from .. import fixtures
+from .. import util
+from ..assertions import eq_
+from ..assertions import is_false
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Table
+from ... import CheckConstraint
+from ... import Column
+from ... import ForeignKeyConstraint
+from ... import Index
+from ... import inspect
+from ... import Integer
+from ... import schema
+from ... import String
+from ... import UniqueConstraint
+
+
+class TableDDLTest(fixtures.TestBase):
+ __backend__ = True
+
+ def _simple_fixture(self, schema=None):
+ return Table(
+ "test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ schema=schema,
+ )
+
+ def _underscore_fixture(self):
+ return Table(
+ "_test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("_data", String(50)),
+ )
+
+ def _table_index_fixture(self, schema=None):
+ table = self._simple_fixture(schema=schema)
+ idx = Index("test_index", table.c.data)
+ return table, idx
+
+ def _simple_roundtrip(self, table):
+ with config.db.begin() as conn:
+ conn.execute(table.insert().values((1, "some data")))
+ result = conn.execute(table.select())
+ eq_(result.first(), (1, "some data"))
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_create_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.create_table
+ @requirements.schemas
+ @util.provide_metadata
+ def test_create_table_schema(self):
+ table = self._simple_fixture(schema=config.test_schema)
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.drop_table
+ @util.provide_metadata
+ def test_drop_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ table.drop(config.db, checkfirst=False)
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_underscore_names(self):
+ table = self._underscore_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_add_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"),
+ {"text": "a comment"},
+ )
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_drop_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ connection.execute(schema.DropTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"), {"text": None}
+ )
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_create_table_if_not_exists(self, connection):
+ table = self._simple_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ is_true(inspect(connection).has_table("test_table"))
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_create_index_if_not_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+ is_true(inspect(connection).has_table("test_table"))
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_table_if_exists(self, connection):
+ table = self._simple_fixture()
+
+ table.create(connection)
+
+ is_true(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ is_false(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_index_if_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ table.create(connection)
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+
+class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
+ pass
+
+
+class LongNameBlowoutTest(fixtures.TestBase):
+ """test the creation of a variety of DDL structures and ensure
+ label length limits pass on backends
+
+ """
+
+ __backend__ = True
+
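+ # each fixture method below (fk, pk, ix, uq, ck) installs a naming
+ # convention that produces a very long constraint/index name, creates the
+ # table(s), and returns (generated_name, reflected_name) for comparison.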
+ def fk(self, metadata, connection):
+ convention = {
+ "fk": "foreign_key_%(table_name)s_"
+ "%(column_0_N_name)s_"
+ "%(referred_table_name)s_"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(20))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ cons = ForeignKeyConstraint(
+ ["aid"], ["a_things_with_stuff.id_long_column_name"]
+ )
+ Table(
+ "b_related_things_of_value",
+ metadata,
+ Column(
+ "aid",
+ ),
+ cons,
+ test_needs_fk=True,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+
+ if testing.requires.foreign_key_constraint_name_reflection.enabled:
+ insp = inspect(connection)
+ fks = insp.get_foreign_keys("b_related_things_of_value")
+ reflected_name = fks[0]["name"]
+
+ return actual_name, reflected_name
+ else:
+ return actual_name, None
+
+ def pk(self, metadata, connection):
+ convention = {
+ "pk": "primary_key_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer, primary_key=True),
+ )
+ cons = a.primary_key
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ pk = insp.get_pk_constraint("a_things_with_stuff")
+ reflected_name = pk["name"]
+ return actual_name, reflected_name
+
+ def ix(self, metadata, connection):
+ convention = {
+ "ix": "index_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ )
+ cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ix = insp.get_indexes("a_things_with_stuff")
+ reflected_name = ix[0]["name"]
+ return actual_name, reflected_name
+
+ def uq(self, metadata, connection):
+ convention = {
+ "uq": "unique_constraint_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ uq = insp.get_unique_constraints("a_things_with_stuff")
+ reflected_name = uq[0]["name"]
+ return actual_name, reflected_name
+
+ def ck(self, metadata, connection):
+ convention = {
+ "ck": "check_constraint_%(table_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = CheckConstraint("some_long_column_name > 5")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("some_long_column_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ck = insp.get_check_constraints("a_things_with_stuff")
+ reflected_name = ck[0]["name"]
+ return actual_name, reflected_name
+
+ @testing.combinations(
+ ("fk",),
+ ("pk",),
+ ("ix",),
+ ("ck", testing.requires.check_constraint_reflection.as_skips()),
+ ("uq", testing.requires.unique_constraint_reflection.as_skips()),
+ argnames="type_",
+ )
+ def test_long_convention_name(self, type_, metadata, connection):
+ actual_name, reflected_name = getattr(self, type_)(
+ metadata, connection
+ )
+
+ assert len(actual_name) > 255
+
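+ # backends may truncate the generated name; compare only the overlapping
+ # portion, ignoring the last few characters where a truncation suffix
+ # can differ.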
+ if reflected_name is not None:
+ overlap = actual_name[0 : len(reflected_name)]
+ if len(overlap) < len(actual_name):
+ eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
+ else:
+ eq_(overlap, reflected_name)
+
+
+__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")
diff --git a/lib/sqlalchemy/testing/suite/test_deprecations.py b/lib/sqlalchemy/testing/suite/test_deprecations.py
new file mode 100644
index 0000000..b36162f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_deprecations.py
@@ -0,0 +1,145 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import select
+from ... import testing
+from ... import union
+
+
+class DeprecatedCompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, conn, select, result, params=()):
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ # note we've had to remove one use case entirely, which is this
+ # one. the Select gets its FROMS from the WHERE clause and the
+ # columns clause, but not the ORDER BY, which means the old ".c" system
+ # allowed you to "order_by(s.c.foo)" to get an unnamed column in the
+ # ORDER BY without adding the SELECT into the FROM and breaking the
+ # query. Users will have to adjust for this use case if they were doing
+ # it before.
+ def _dont_test_select_from_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
new file mode 100644
index 0000000..c2c17d0
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -0,0 +1,361 @@
+#! coding: utf-8
+
+from . import testing
+from .. import assert_raises
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import fixtures
+from .. import ne_
+from .. import provide_metadata
+from ..config import requirements
+from ..provision import set_default_schema_on_connection
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import event
+from ... import exc
+from ... import Integer
+from ... import literal_column
+from ... import select
+from ... import String
+from ...util import compat
+
+
+class ExceptionTest(fixtures.TablesTest):
+ """Test basic exception wrapping.
+
+ DBAPIs vary a lot in exception behavior, so to actually anticipate
+ specific exceptions from real round trips we need to be conservative.
+
+ """
+
+ run_deletes = "each"
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ @requirements.duplicate_key_raises_integrity_error
+ def test_integrity_error(self):
+
+ with config.db.connect() as conn:
+
+ trans = conn.begin()
+ conn.execute(
+ self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
+ )
+
+ assert_raises(
+ exc.IntegrityError,
+ conn.execute,
+ self.tables.manual_pk.insert(),
+ {"id": 1, "data": "d1"},
+ )
+
+ trans.rollback()
+
+ def test_exception_with_non_ascii(self):
+ with config.db.connect() as conn:
+ try:
+ # try to create an error message that likely has non-ascii
+ # characters in the DBAPI's message string. unfortunately
+ # there's no way to make this happen with some drivers like
+ # mysqlclient, pymysql. this at least does produce a non-
+ # ascii error message for cx_oracle, psycopg2
+ conn.execute(select(literal_column(u"méil")))
+ assert False
+ except exc.DBAPIError as err:
+ err_str = str(err)
+
+ assert str(err.orig) in str(err)
+
+ # test that str(err) gives us a native str in both cases: a
+ # bytes-based str on Py2k and a unicode str on Py3k.
+ if compat.py2k:
+ assert isinstance(err_str, str)
+ else:
+ assert isinstance(err_str, str)
+
+
+class IsolationLevelTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("isolation_level",)
+
+ def _get_non_default_isolation_level(self):
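+ # pick any supported isolation level other than AUTOCOMMIT and the
+ # backend's default; skip the test if none is available.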
+ levels = requirements.get_isolation_levels(config)
+
+ default = levels["default"]
+ supported = levels["supported"]
+
+ s = set(supported).difference(["AUTOCOMMIT", default])
+ if s:
+ return s.pop()
+ else:
+ config.skip_test("no non-default isolation level available")
+
+ def test_default_isolation_level(self):
+ eq_(
+ config.db.dialect.default_isolation_level,
+ requirements.get_isolation_levels(config)["default"],
+ )
+
+ def test_non_default_isolation_level(self):
+ non_default = self._get_non_default_isolation_level()
+
+ with config.db.connect() as conn:
+ existing = conn.get_isolation_level()
+
+ ne_(existing, non_default)
+
+ conn.execution_options(isolation_level=non_default)
+
+ eq_(conn.get_isolation_level(), non_default)
+
+ conn.dialect.reset_isolation_level(conn.connection)
+
+ eq_(conn.get_isolation_level(), existing)
+
+ def test_all_levels(self):
+ levels = requirements.get_isolation_levels(config)
+
+ all_levels = levels["supported"]
+
+ for level in set(all_levels).difference(["AUTOCOMMIT"]):
+ with config.db.connect() as conn:
+ conn.execution_options(isolation_level=level)
+
+ eq_(conn.get_isolation_level(), level)
+
+ trans = conn.begin()
+ trans.rollback()
+
+ eq_(conn.get_isolation_level(), level)
+
+ with config.db.connect() as conn:
+ eq_(
+ conn.get_isolation_level(),
+ levels["default"],
+ )
+
+
+class AutocommitIsolationTest(fixtures.TablesTest):
+
+ run_deletes = "each"
+
+ __requires__ = ("autocommit",)
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ test_needs_acid=True,
+ )
+
+ def _test_conn_autocommits(self, conn, autocommit):
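+ # insert a row and roll back; in autocommit mode the row persists and
+ # the subsequent SELECT finds it, otherwise the rollback removes it.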
+ trans = conn.begin()
+ conn.execute(
+ self.tables.some_table.insert(), {"id": 1, "data": "some data"}
+ )
+ trans.rollback()
+
+ eq_(
+ conn.scalar(select(self.tables.some_table.c.id)),
+ 1 if autocommit else None,
+ )
+
+ with conn.begin():
+ conn.execute(self.tables.some_table.delete())
+
+ def test_autocommit_on(self, connection_no_trans):
+ conn = connection_no_trans
+ c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(c2, True)
+
+ c2.dialect.reset_isolation_level(c2.connection)
+
+ self._test_conn_autocommits(conn, False)
+
+ def test_autocommit_off(self, connection_no_trans):
+ conn = connection_no_trans
+ self._test_conn_autocommits(conn, False)
+
+ def test_turn_autocommit_off_via_default_iso_level(
+ self, connection_no_trans
+ ):
+ conn = connection_no_trans
+ conn = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(conn, True)
+
+ conn.execution_options(
+ isolation_level=requirements.get_isolation_levels(config)[
+ "default"
+ ]
+ )
+ self._test_conn_autocommits(conn, False)
+
+
+class EscapingTest(fixtures.TestBase):
+ @provide_metadata
+ def test_percent_sign_round_trip(self):
+ """test that the DBAPI accommodates for escaped / nonescaped
+ percent signs in a way that matches the compiler
+
+ """
+ m = self.metadata
+ t = Table("t", m, Column("data", String(50)))
+ t.create(config.db)
+ with config.db.begin() as conn:
+ conn.execute(t.insert(), dict(data="some % value"))
+ conn.execute(t.insert(), dict(data="some %% other value"))
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some % value'")
+ )
+ ),
+ "some % value",
+ )
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some %% other value'")
+ )
+ ),
+ "some %% other value",
+ )
+
+
+class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("default_schema_name_switch",)
+
+ def test_control_case(self):
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+ with eng.connect():
+ pass
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_wont_work_wo_insert(self):
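+ # the listener here is appended (no insert=True), so by the time it runs
+ # the dialect has presumably already captured the original default schema;
+ # the engine-level default_schema_name is expected to stay unchanged.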
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect")
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_schema_change_on_connect(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+ def test_schema_change_works_w_transactions(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, *arg):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ trans = conn.begin()
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+ trans.rollback()
+
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+
+class FutureWeCanSetDefaultSchemaWEventsTest(
+ fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
+):
+ pass
+
+
+class DifficultParametersTest(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.combinations(
+ ("boring",),
+ ("per cent",),
+ ("per % cent",),
+ ("%percent",),
+ ("par(ens)",),
+ ("percent%(ens)yah",),
+ ("col:ons",),
+ ("more :: %colons%",),
+ ("/slashes/",),
+ ("more/slashes",),
+ ("q?marks",),
+ ("1param",),
+ ("1col:on",),
+ argnames="name",
+ )
+ def test_round_trip(self, name, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column(name, String(50), nullable=False),
+ )
+
+ # table is created
+ t.create(connection)
+
+ # automatic param generated by insert
+ connection.execute(t.insert().values({"id": 1, name: "some name"}))
+
+ # automatic param generated by criteria, plus selecting the column
+ stmt = select(t.c[name]).where(t.c[name] == "some name")
+
+ eq_(connection.scalar(stmt), "some name")
+
+ # use the name in a param explicitly
+ stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
+
+ row = connection.execute(stmt, {name: "some name"}).first()
+
+ # name works as the key from cursor.description
+ eq_(row._mapping[name], "some name")
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
new file mode 100644
index 0000000..3c22f50
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -0,0 +1,367 @@
+from .. import config
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import select
+from ... import String
+
+
+class LastrowidTest(fixtures.TablesTest):
+ run_deletes = "each"
+
+ __backend__ = True
+
+ __requires__ = "implements_get_lastrowid", "autoincrement_insert"
+
+ __engine_options__ = {"implicit_returning": False}
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ def test_autoincrement_on_insert(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+ @requirements.dbapi_lastrowid
+ def test_native_lastrowid_autoinc(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ lastrowid = r.lastrowid
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(lastrowid, pk)
+
+
+class InsertBehaviorTest(fixtures.TablesTest):
+ run_deletes = "each"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+ Table(
+ "includes_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ Column("x", Integer, default=5),
+ Column(
+ "y",
+ Integer,
+ default=literal_column("2", type_=Integer) + literal(2),
+ ),
+ )
+
+ @requirements.autoincrement_insert
+ def test_autoclose_on_insert(self):
+ if requirements.returning.enabled:
+ engine = engines.testing_engine(
+ options={"implicit_returning": False}
+ )
+ else:
+ engine = config.db
+
+ with engine.begin() as conn:
+ r = conn.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
+ # an insert where the PK was taken from a row that the dialect
+ # selected, as is the case for mssql/pyodbc, will still report
+ # returns_rows as true because there's a cursor description. in that
+ # case, the row had to have been consumed at least.
+ assert not r.returns_rows or r.fetchone() is None
+
+ @requirements.returning
+ def test_autoclose_on_insert_implicit_returning(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # note we are experimenting with having this be True
+ # as of I8091919d45421e3f53029b8660427f844fee0228 .
+ # implicit returning has fetched the row, but it still is a
+ # "returns rows"
+ assert r.returns_rows
+
+ # and we should be able to fetchone() on it, we just get no row
+ eq_(r.fetchone(), None)
+
+ # and the keys, etc.
+ eq_(r.keys(), ["id"])
+
+ # but the dialect took in the row already. not really sure
+ # what the best behavior is.
+
+ @requirements.empty_inserts
+ def test_empty_insert(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert())
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+ eq_(len(r.all()), 1)
+
+ @requirements.empty_inserts_executemany
+ def test_empty_insert_multiple(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+
+ eq_(len(r.all()), 3)
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+ connection.execute(
+ src_table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+ eq_(result.fetchall(), [("data2",), ("data3",)])
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc_no_rows(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+
+ eq_(result.fetchall(), [])
+
+ @requirements.insert_from_select
+ def test_insert_from_select(self, connection):
+ table = self.tables.manual_pk
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(
+ connection.execute(
+ select(table.c.data).order_by(table.c.data)
+ ).fetchall(),
+ [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
+ )
+
+ @requirements.insert_from_select
+ def test_insert_from_select_with_defaults(self, connection):
+ table = self.tables.includes_defaults
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
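+ # x gets its scalar default of 5; y gets the SQL expression default
+ # 2 + 2 = 4 on every row, whether inserted directly or via from_select.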
+ eq_(
+ connection.execute(
+ select(table).order_by(table.c.data, table.c.id)
+ ).fetchall(),
+ [
+ (1, "data1", 5, 4),
+ (2, "data2", 5, 4),
+ (7, "data2", 5, 4),
+ (3, "data3", 5, 4),
+ (8, "data3", 5, 4),
+ ],
+ )
+
+
+class ReturningTest(fixtures.TablesTest):
+ run_create_tables = "each"
+ __requires__ = "returning", "autoincrement_insert"
+ __backend__ = True
+
+ __engine_options__ = {"implicit_returning": True}
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ @requirements.fetch_rows_post_commit
+ def test_explicit_returning_pk_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_explicit_returning_pk_no_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_autoincrement_on_insert_implicit_returning(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id_implicit_returning(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
new file mode 100644
index 0000000..459a4d8
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -0,0 +1,1738 @@
+import operator
+import re
+
+import sqlalchemy as sa
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import expect_warnings
+from .. import fixtures
+from .. import is_
+from ..provision import get_temp_table_name
+from ..provision import temp_table_keyword_args
+from ..schema import Column
+from ..schema import Table
+from ... import event
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import String
+from ... import testing
+from ... import types as sql_types
+from ...schema import DDL
+from ...schema import Index
+from ...sql.elements import quoted_name
+from ...sql.schema import BLANK_SCHEMA
+from ...testing import is_false
+from ...testing import is_true
+
+
+metadata, users = None, None
+
+
+class HasTableTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ if testing.requires.schemas.enabled:
+ Table(
+ "test_table_s",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ schema=config.test_schema,
+ )
+
+ def test_has_table(self):
+ with config.db.begin() as conn:
+ is_true(config.db.dialect.has_table(conn, "test_table"))
+ is_false(config.db.dialect.has_table(conn, "test_table_s"))
+ is_false(config.db.dialect.has_table(conn, "nonexistent_table"))
+
+ @testing.requires.schemas
+ def test_has_table_schema(self):
+ with config.db.begin() as conn:
+ is_false(
+ config.db.dialect.has_table(
+ conn, "test_table", schema=config.test_schema
+ )
+ )
+ is_true(
+ config.db.dialect.has_table(
+ conn, "test_table_s", schema=config.test_schema
+ )
+ )
+ is_false(
+ config.db.dialect.has_table(
+ conn, "nonexistent_table", schema=config.test_schema
+ )
+ )
+
+
+class HasIndexTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ tt = Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Index("my_idx", tt.c.data)
+
+ if testing.requires.schemas.enabled:
+ tt = Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ schema=config.test_schema,
+ )
+ Index("my_idx_s", tt.c.data)
+
+ def test_has_index(self):
+ with config.db.begin() as conn:
+ assert config.db.dialect.has_index(conn, "test_table", "my_idx")
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "my_idx_s"
+ )
+ assert not config.db.dialect.has_index(
+ conn, "nonexistent_table", "my_idx"
+ )
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "nonexistent_idx"
+ )
+
+ @testing.requires.schemas
+ def test_has_index_schema(self):
+ with config.db.begin() as conn:
+ assert config.db.dialect.has_index(
+ conn, "test_table", "my_idx_s", schema=config.test_schema
+ )
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "my_idx", schema=config.test_schema
+ )
+ assert not config.db.dialect.has_index(
+ conn,
+ "nonexistent_table",
+ "my_idx_s",
+ schema=config.test_schema,
+ )
+ assert not config.db.dialect.has_index(
+ conn,
+ "test_table",
+ "nonexistent_idx_s",
+ schema=config.test_schema,
+ )
+
+
+class QuotedNameArgumentTest(fixtures.TablesTest):
+ run_create_tables = "once"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "quote ' one",
+ metadata,
+ Column("id", Integer),
+ Column("name", String(50)),
+ Column("data", String(50)),
+ Column("related_id", Integer),
+ sa.PrimaryKeyConstraint("id", name="pk quote ' one"),
+ sa.Index("ix quote ' one", "name"),
+ sa.UniqueConstraint(
+ "data",
+ name="uq quote' one",
+ ),
+ sa.ForeignKeyConstraint(
+ ["id"], ["related.id"], name="fk quote ' one"
+ ),
+ sa.CheckConstraint("name != 'foo'", name="ck quote ' one"),
+ comment=r"""quote ' one comment""",
+ test_needs_fk=True,
+ )
+
+ if testing.requires.symbol_names_w_double_quote.enabled:
+ Table(
+ 'quote " two',
+ metadata,
+ Column("id", Integer),
+ Column("name", String(50)),
+ Column("data", String(50)),
+ Column("related_id", Integer),
+ sa.PrimaryKeyConstraint("id", name='pk quote " two'),
+ sa.Index('ix quote " two', "name"),
+ sa.UniqueConstraint(
+ "data",
+ name='uq quote" two',
+ ),
+ sa.ForeignKeyConstraint(
+ ["id"], ["related.id"], name='fk quote " two'
+ ),
+ sa.CheckConstraint("name != 'foo'", name='ck quote " two '),
+ comment=r"""quote " two comment""",
+ test_needs_fk=True,
+ )
+
+ Table(
+ "related",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("related", Integer),
+ test_needs_fk=True,
+ )
+
+ if testing.requires.view_column_reflection.enabled:
+
+ if testing.requires.symbol_names_w_double_quote.enabled:
+ names = [
+ "quote ' one",
+ 'quote " two',
+ ]
+ else:
+ names = [
+ "quote ' one",
+ ]
+ for name in names:
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+ config.db.dialect.identifier_preparer.quote(
+ "view %s" % name
+ ),
+ config.db.dialect.identifier_preparer.quote(name),
+ )
+
+ event.listen(metadata, "after_create", DDL(query))
+ event.listen(
+ metadata,
+ "before_drop",
+ DDL(
+ "DROP VIEW %s"
+ % config.db.dialect.identifier_preparer.quote(
+ "view %s" % name
+ )
+ ),
+ )
+
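+ # note: defined in the class body and used as a decorator, so this is a
+ # plain function rather than a bound method; it parametrizes each test
+ # with table names containing quote characters.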
+ def quote_fixtures(fn):
+ return testing.combinations(
+ ("quote ' one",),
+ ('quote " two', testing.requires.symbol_names_w_double_quote),
+ )(fn)
+
+ @quote_fixtures
+ def test_get_table_options(self, name):
+ insp = inspect(config.db)
+
+ insp.get_table_options(name)
+
+ @quote_fixtures
+ @testing.requires.view_column_reflection
+ def test_get_view_definition(self, name):
+ insp = inspect(config.db)
+ assert insp.get_view_definition("view %s" % name)
+
+ @quote_fixtures
+ def test_get_columns(self, name):
+ insp = inspect(config.db)
+ assert insp.get_columns(name)
+
+ @quote_fixtures
+ def test_get_pk_constraint(self, name):
+ insp = inspect(config.db)
+ assert insp.get_pk_constraint(name)
+
+ @quote_fixtures
+ def test_get_foreign_keys(self, name):
+ insp = inspect(config.db)
+ assert insp.get_foreign_keys(name)
+
+ @quote_fixtures
+ def test_get_indexes(self, name):
+ insp = inspect(config.db)
+ assert insp.get_indexes(name)
+
+ @quote_fixtures
+ @testing.requires.unique_constraint_reflection
+ def test_get_unique_constraints(self, name):
+ insp = inspect(config.db)
+ assert insp.get_unique_constraints(name)
+
+ @quote_fixtures
+ @testing.requires.comment_reflection
+ def test_get_table_comment(self, name):
+ insp = inspect(config.db)
+ assert insp.get_table_comment(name)
+
+ @quote_fixtures
+ @testing.requires.check_constraint_reflection
+ def test_get_check_constraints(self, name):
+ insp = inspect(config.db)
+ assert insp.get_check_constraints(name)
+
+
+class ComponentReflectionTest(fixtures.TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+
+ @classmethod
+ def setup_bind(cls):
+ if config.requirements.independent_connections.enabled:
+ from sqlalchemy import pool
+
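+ # use a class-scoped StaticPool engine, presumably so that connection-
+ # scoped state (e.g. temp tables) stays visible across the tests here.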
+ return engines.testing_engine(
+ options=dict(poolclass=pool.StaticPool, scope="class"),
+ )
+ else:
+ return config.db
+
+ @classmethod
+ def define_tables(cls, metadata):
+ cls.define_reflected_tables(metadata, None)
+ if testing.requires.schemas.enabled:
+ cls.define_reflected_tables(metadata, testing.config.test_schema)
+
+ @classmethod
+ def define_reflected_tables(cls, metadata, schema):
+ if schema:
+ schema_prefix = schema + "."
+ else:
+ schema_prefix = ""
+
+ if testing.requires.self_referential_foreign_keys.enabled:
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ Column(
+ "parent_user_id",
+ sa.Integer,
+ sa.ForeignKey(
+ "%susers.user_id" % schema_prefix, name="user_id_fk"
+ ),
+ ),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ else:
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ Table(
+ "dingalings",
+ metadata,
+ Column("dingaling_id", sa.Integer, primary_key=True),
+ Column(
+ "address_id",
+ sa.Integer,
+ sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ ),
+ Column("data", sa.String(30)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "email_addresses",
+ metadata,
+ Column("address_id", sa.Integer),
+ Column(
+ "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
+ ),
+ Column("email_address", sa.String(20)),
+ sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "comment_test",
+ metadata,
+ Column("id", sa.Integer, primary_key=True, comment="id comment"),
+ Column("data", sa.String(20), comment="data % comment"),
+ Column(
+ "d2",
+ sa.String(20),
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ schema=schema,
+ comment=r"""the test % ' " \ table comment""",
+ )
+
+ if testing.requires.cross_schema_fk_reflection.enabled:
+ if schema is None:
+ Table(
+ "local_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
+ Column(
+ "remote_id",
+ ForeignKey(
+ "%s.remote_table_2.id" % testing.config.test_schema
+ ),
+ ),
+ test_needs_fk=True,
+ schema=config.db.dialect.default_schema_name,
+ )
+ else:
+ Table(
+ "remote_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column(
+ "local_id",
+ ForeignKey(
+ "%s.local_table.id"
+ % config.db.dialect.default_schema_name
+ ),
+ ),
+ Column("data", sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "remote_table_2",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ if testing.requires.index_reflection.enabled:
+ cls.define_index(metadata, users)
+
+ if not schema:
+ # test_needs_fk is used at the moment to force MySQL InnoDB
+ noncol_idx_test_nopk = Table(
+ "noncol_idx_test_nopk",
+ metadata,
+ Column("q", sa.String(5)),
+ test_needs_fk=True,
+ )
+
+ noncol_idx_test_pk = Table(
+ "noncol_idx_test_pk",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("q", sa.String(5)),
+ test_needs_fk=True,
+ )
+
+ if testing.requires.indexes_with_ascdesc.enabled:
+ Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
+ Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
+
+ if testing.requires.view_column_reflection.enabled:
+ cls.define_views(metadata, schema)
+ if not schema and testing.requires.temp_table_reflection.enabled:
+ cls.define_temp_tables(metadata)
+
+ @classmethod
+ def define_temp_tables(cls, metadata):
+ kw = temp_table_keyword_args(config, config.db)
+ table_name = get_temp_table_name(
+ config, config.db, "user_tmp_%s" % config.ident
+ )
+ user_tmp = Table(
+ table_name,
+ metadata,
+ Column("id", sa.INT, primary_key=True),
+ Column("name", sa.VARCHAR(50)),
+ Column("foo", sa.INT),
+ # disambiguate temp table unique constraint names. this is
+ # pretty arbitrary for a generic dialect; however, we are doing
+ # it to suit SQL Server, which will produce name conflicts for
+ # unique constraints created against temp tables in different
+ # databases.
+ # https://www.arbinada.com/en/node/1645
+ sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident),
+ sa.Index("user_tmp_ix", "foo"),
+ **kw
+ )
+ if (
+ testing.requires.view_reflection.enabled
+ and testing.requires.temporary_views.enabled
+ ):
+ event.listen(
+ user_tmp,
+ "after_create",
+ DDL(
+ "create temporary view user_tmp_v as "
+ "select * from user_tmp_%s" % config.ident
+ ),
+ )
+ event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
+
+ @classmethod
+ def define_index(cls, metadata, users):
+ Index("users_t_idx", users.c.test1, users.c.test2)
+ Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)
+
+ @classmethod
+ def define_views(cls, metadata, schema):
+ for table_name in ("users", "email_addresses"):
+ fullname = table_name
+ if schema:
+ fullname = "%s.%s" % (schema, table_name)
+ view_name = fullname + "_v"
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+ view_name,
+ fullname,
+ )
+
+ event.listen(metadata, "after_create", DDL(query))
+ event.listen(
+ metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
+ )
+
+ @testing.requires.schema_reflection
+ def test_get_schema_names(self):
+ insp = inspect(self.bind)
+
+ self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_get_schema_names_w_translate_map(self, connection):
+ """test #7300"""
+
+ connection = connection.execution_options(
+ schema_translate_map={
+ "foo": "bar",
+ BLANK_SCHEMA: testing.config.test_schema,
+ }
+ )
+ insp = inspect(connection)
+
+ self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_dialect_initialize(self):
+ engine = engines.testing_engine()
+ inspect(engine)
+ assert hasattr(engine.dialect, "default_schema_name")
+
+ @testing.requires.schema_reflection
+ def test_get_default_schema_name(self):
+ insp = inspect(self.bind)
+ eq_(insp.default_schema_name, self.bind.dialect.default_schema_name)
+
+ @testing.requires.foreign_key_constraint_reflection
+ @testing.combinations(
+ (None, True, False, False),
+ (None, True, False, True, testing.requires.schemas),
+ ("foreign_key", True, False, False),
+ (None, False, True, False),
+ (None, False, True, True, testing.requires.schemas),
+ (None, True, True, False),
+ (None, True, True, True, testing.requires.schemas),
+ argnames="order_by,include_plain,include_views,use_schema",
+ )
+ def test_get_table_names(
+ self, connection, order_by, include_plain, include_views, use_schema
+ ):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
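+ # tables created for other tests in this suite; filter them out so the
+ # expected name lists below stay stable.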
+ _ignore_tables = [
+ "comment_test",
+ "noncol_idx_test_pk",
+ "noncol_idx_test_nopk",
+ "local_table",
+ "remote_table",
+ "remote_table_2",
+ ]
+
+ insp = inspect(connection)
+
+ if include_views:
+ table_names = insp.get_view_names(schema)
+ table_names.sort()
+ answer = ["email_addresses_v", "users_v"]
+ eq_(sorted(table_names), answer)
+
+ if include_plain:
+ if order_by:
+ tables = [
+ rec[0]
+ for rec in insp.get_sorted_table_and_fkc_names(schema)
+ if rec[0]
+ ]
+ else:
+ tables = insp.get_table_names(schema)
+ table_names = [t for t in tables if t not in _ignore_tables]
+
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
+ eq_(table_names, answer)
+ else:
+ answer = ["dingalings", "email_addresses", "users"]
+ eq_(sorted(table_names), answer)
+
+ @testing.requires.temp_table_names
+ def test_get_temp_table_names(self):
+ insp = inspect(self.bind)
+ temp_table_names = insp.get_temp_table_names()
+ eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident])
+
+ @testing.requires.view_reflection
+ @testing.requires.temp_table_names
+ @testing.requires.temporary_views
+ def test_get_temp_view_names(self):
+ insp = inspect(self.bind)
+ temp_table_names = insp.get_temp_view_names()
+ eq_(sorted(temp_table_names), ["user_tmp_v"])
+
+ @testing.requires.comment_reflection
+ def test_get_comments(self):
+ self._test_get_comments()
+
+ @testing.requires.comment_reflection
+ @testing.requires.schemas
+ def test_get_comments_with_schema(self):
+ self._test_get_comments(testing.config.test_schema)
+
+ def _test_get_comments(self, schema=None):
+ insp = inspect(self.bind)
+
+ eq_(
+ insp.get_table_comment("comment_test", schema=schema),
+ {"text": r"""the test % ' " \ table comment"""},
+ )
+
+ eq_(insp.get_table_comment("users", schema=schema), {"text": None})
+
+ eq_(
+ [
+ {"name": rec["name"], "comment": rec["comment"]}
+ for rec in insp.get_columns("comment_test", schema=schema)
+ ],
+ [
+ {"comment": "id comment", "name": "id"},
+ {"comment": "data % comment", "name": "data"},
+ {
+ "comment": (
+ r"""Comment types type speedily ' " \ '' Fun!"""
+ ),
+ "name": "d2",
+ },
+ ],
+ )
+
+ @testing.combinations(
+ (False, False),
+ (False, True, testing.requires.schemas),
+ (True, False, testing.requires.view_reflection),
+ (
+ True,
+ True,
+ testing.requires.schemas + testing.requires.view_reflection,
+ ),
+ argnames="use_views,use_schema",
+ )
+ def test_get_columns(self, connection, use_views, use_schema):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ users, addresses = (self.tables.users, self.tables.email_addresses)
+ if use_views:
+ table_names = ["users_v", "email_addresses_v"]
+ else:
+ table_names = ["users", "email_addresses"]
+
+ insp = inspect(connection)
+ for table_name, table in zip(table_names, (users, addresses)):
+ schema_name = schema
+ cols = insp.get_columns(table_name, schema=schema_name)
+ self.assert_(len(cols) > 0, len(cols))
+
+ # should be in order
+
+ for i, col in enumerate(table.columns):
+ eq_(col.name, cols[i]["name"])
+ ctype = cols[i]["type"].__class__
+ ctype_def = col.type
+ if isinstance(ctype_def, sa.types.TypeEngine):
+ ctype_def = ctype_def.__class__
+
+ # Oracle returns Date for DateTime.
+
+ if testing.against("oracle") and ctype_def in (
+ sql_types.Date,
+ sql_types.DateTime,
+ ):
+ ctype_def = sql_types.Date
+
+ # assert that the desired type and return type share
+ # a base within one of the generic types.
+
+ self.assert_(
+ len(
+ set(ctype.__mro__)
+ .intersection(ctype_def.__mro__)
+ .intersection(
+ [
+ sql_types.Integer,
+ sql_types.Numeric,
+ sql_types.DateTime,
+ sql_types.Date,
+ sql_types.Time,
+ sql_types.String,
+ sql_types._Binary,
+ ]
+ )
+ )
+ > 0,
+ "%s(%s), %s(%s)"
+ % (col.name, col.type, cols[i]["name"], ctype),
+ )
+
+ if not col.primary_key:
+ assert cols[i]["default"] is None
+
+ @testing.requires.temp_table_reflection
+ def test_get_temp_table_columns(self):
+ table_name = get_temp_table_name(
+ config, self.bind, "user_tmp_%s" % config.ident
+ )
+ user_tmp = self.tables[table_name]
+ insp = inspect(self.bind)
+ cols = insp.get_columns(table_name)
+ self.assert_(len(cols) > 0, len(cols))
+
+ for i, col in enumerate(user_tmp.columns):
+ eq_(col.name, cols[i]["name"])
+
+ @testing.requires.temp_table_reflection
+ @testing.requires.view_column_reflection
+ @testing.requires.temporary_views
+ def test_get_temp_view_columns(self):
+ insp = inspect(self.bind)
+ cols = insp.get_columns("user_tmp_v")
+ eq_([col["name"] for col in cols], ["id", "name", "foo"])
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ @testing.requires.primary_key_constraint_reflection
+ def test_get_pk_constraint(self, connection, use_schema):
+ if use_schema:
+ schema = testing.config.test_schema
+ else:
+ schema = None
+
+ users, addresses = self.tables.users, self.tables.email_addresses
+ insp = inspect(connection)
+
+ users_cons = insp.get_pk_constraint(users.name, schema=schema)
+ users_pkeys = users_cons["constrained_columns"]
+ eq_(users_pkeys, ["user_id"])
+
+ addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
+ addr_pkeys = addr_cons["constrained_columns"]
+ eq_(addr_pkeys, ["address_id"])
+
+ with testing.requires.reflects_pk_names.fail_if():
+ eq_(addr_cons["name"], "email_ad_pk")
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ @testing.requires.foreign_key_constraint_reflection
+ def test_get_foreign_keys(self, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ users, addresses = (self.tables.users, self.tables.email_addresses)
+ insp = inspect(connection)
+ expected_schema = schema
+ # users
+
+ if testing.requires.self_referential_foreign_keys.enabled:
+ users_fkeys = insp.get_foreign_keys(users.name, schema=schema)
+ fkey1 = users_fkeys[0]
+
+ with testing.requires.named_constraints.fail_if():
+ eq_(fkey1["name"], "user_id_fk")
+
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ if testing.requires.self_referential_foreign_keys.enabled:
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
+
+ # addresses
+ addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
+ fkey1 = addr_fkeys[0]
+
+ with testing.requires.implicitly_named_constraints.fail_if():
+ self.assert_(fkey1["name"] is not None)
+
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ eq_(fkey1["constrained_columns"], ["remote_user_id"])
+
+ @testing.requires.cross_schema_fk_reflection
+ @testing.requires.schemas
+ def test_get_inter_schema_foreign_keys(self):
+ local_table, remote_table, remote_table_2 = self.tables(
+ "%s.local_table" % self.bind.dialect.default_schema_name,
+ "%s.remote_table" % testing.config.test_schema,
+ "%s.remote_table_2" % testing.config.test_schema,
+ )
+
+ insp = inspect(self.bind)
+
+ local_fkeys = insp.get_foreign_keys(local_table.name)
+ eq_(len(local_fkeys), 1)
+
+ fkey1 = local_fkeys[0]
+ eq_(fkey1["referred_schema"], testing.config.test_schema)
+ eq_(fkey1["referred_table"], remote_table_2.name)
+ eq_(fkey1["referred_columns"], ["id"])
+ eq_(fkey1["constrained_columns"], ["remote_id"])
+
+ remote_fkeys = insp.get_foreign_keys(
+ remote_table.name, schema=testing.config.test_schema
+ )
+ eq_(len(remote_fkeys), 1)
+
+ fkey2 = remote_fkeys[0]
+
+ assert fkey2["referred_schema"] in (
+ None,
+ self.bind.dialect.default_schema_name,
+ )
+ eq_(fkey2["referred_table"], local_table.name)
+ eq_(fkey2["referred_columns"], ["id"])
+ eq_(fkey2["constrained_columns"], ["local_id"])
+
+ def _assert_insp_indexes(self, indexes, expected_indexes):
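+ # each expected entry must be present in the reflected list; only the
+ # keys given in the expected entry are compared.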
+ index_names = [d["name"] for d in indexes]
+ for e_index in expected_indexes:
+ assert e_index["name"] in index_names
+ index = indexes[index_names.index(e_index["name"])]
+ for key in e_index:
+ eq_(e_index[key], index[key])
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ def test_get_indexes(self, connection, use_schema):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ # The database may decide to create indexes for foreign keys, etc.
+ # so there may be more indexes than expected.
+ insp = inspect(self.bind)
+ indexes = insp.get_indexes("users", schema=schema)
+ expected_indexes = [
+ {
+ "unique": False,
+ "column_names": ["test1", "test2"],
+ "name": "users_t_idx",
+ },
+ {
+ "unique": False,
+ "column_names": ["user_id", "test2", "test1"],
+ "name": "users_all_idx",
+ },
+ ]
+ self._assert_insp_indexes(indexes, expected_indexes)
+
+ @testing.combinations(
+ ("noncol_idx_test_nopk", "noncol_idx_nopk"),
+ ("noncol_idx_test_pk", "noncol_idx_pk"),
+ argnames="tname,ixname",
+ )
+ @testing.requires.index_reflection
+ @testing.requires.indexes_with_ascdesc
+ def test_get_noncol_index(self, connection, tname, ixname):
+ insp = inspect(connection)
+ indexes = insp.get_indexes(tname)
+
+ # reflecting an index that has "x DESC" in it as the column.
+ # the DB may or may not give us "x", but make sure we get the index
+ # back, that it has a name, and that it's connected to the table.
+ expected_indexes = [{"unique": False, "name": ixname}]
+ self._assert_insp_indexes(indexes, expected_indexes)
+
+ t = Table(tname, MetaData(), autoload_with=connection)
+ eq_(len(t.indexes), 1)
+ is_(list(t.indexes)[0].table, t)
+ eq_(list(t.indexes)[0].name, ixname)
+
+ @testing.requires.temp_table_reflection
+ @testing.requires.unique_constraint_reflection
+ def test_get_temp_table_unique_constraints(self):
+ insp = inspect(self.bind)
+ reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident)
+ for refl in reflected:
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ refl.pop("duplicates_index", None)
+ eq_(
+ reflected,
+ [
+ {
+ "column_names": ["name"],
+ "name": "user_tmp_uq_%s" % config.ident,
+ }
+ ],
+ )
+
+ @testing.requires.temp_table_reflect_indexes
+ def test_get_temp_table_indexes(self):
+ insp = inspect(self.bind)
+ table_name = get_temp_table_name(
+ config, config.db, "user_tmp_%s" % config.ident
+ )
+ indexes = insp.get_indexes(table_name)
+ for ind in indexes:
+ ind.pop("dialect_options", None)
+ expected = [
+ {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(
+ [idx for idx in indexes if idx["name"] == "user_tmp_ix"],
+ expected,
+ )
+
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ @testing.requires.unique_constraint_reflection
+ def test_get_unique_constraints(self, metadata, connection, use_schema):
+ # SQLite dialect needs to parse the names of the constraints
+ # separately from what it gets from PRAGMA index_list(), and
+ # then matches them up, so the same set of column_names in two
+ # constraints will confuse it. Perhaps we should no longer
+ # bother with index_list() here since we have the whole
+ # CREATE TABLE?
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ uniques = sorted(
+ [
+ {"name": "unique_a", "column_names": ["a"]},
+ {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]},
+ {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]},
+ {"name": "unique_asc_key", "column_names": ["asc", "key"]},
+ {"name": "i.have.dots", "column_names": ["b"]},
+ {"name": "i have spaces", "column_names": ["c"]},
+ ],
+ key=operator.itemgetter("name"),
+ )
+ table = Table(
+ "testtbl",
+ metadata,
+ Column("a", sa.String(20)),
+ Column("b", sa.String(30)),
+ Column("c", sa.Integer),
+ # reserved identifiers
+ Column("asc", sa.String(30)),
+ Column("key", sa.String(30)),
+ schema=schema,
+ )
+ for uc in uniques:
+ table.append_constraint(
+ sa.UniqueConstraint(*uc["column_names"], name=uc["name"])
+ )
+ table.create(connection)
+
+ inspector = inspect(connection)
+ reflected = sorted(
+ inspector.get_unique_constraints("testtbl", schema=schema),
+ key=operator.itemgetter("name"),
+ )
+
+ names_that_duplicate_index = set()
+
+ for orig, refl in zip(uniques, reflected):
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ dupe = refl.pop("duplicates_index", None)
+ if dupe:
+ names_that_duplicate_index.add(dupe)
+ eq_(orig, refl)
+
+ reflected_metadata = MetaData()
+ reflected = Table(
+ "testtbl",
+ reflected_metadata,
+ autoload_with=connection,
+ schema=schema,
+ )
+
+ # test "deduplicates for index" logic. MySQL and Oracle
+ # "unique constraints" are actually unique indexes (with possible
+ # exception of a unique that is a dupe of another one in the case
+ # of Oracle). make sure they aren't duplicated.
+ idx_names = set([idx.name for idx in reflected.indexes])
+ uq_names = set(
+ [
+ uq.name
+ for uq in reflected.constraints
+ if isinstance(uq, sa.UniqueConstraint)
+ ]
+ ).difference(["unique_c_a_b"])
+
+ assert not idx_names.intersection(uq_names)
+ if names_that_duplicate_index:
+ eq_(names_that_duplicate_index, idx_names)
+ eq_(uq_names, set())
+
+ @testing.requires.view_reflection
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ def test_get_view_definition(self, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ view_name1 = "users_v"
+ view_name2 = "email_addresses_v"
+ insp = inspect(connection)
+ v1 = insp.get_view_definition(view_name1, schema=schema)
+ self.assert_(v1)
+ v2 = insp.get_view_definition(view_name2, schema=schema)
+ self.assert_(v2)
+
+ # why is this here if it's PG specific?
+ @testing.combinations(
+ ("users", False),
+ ("users", True, testing.requires.schemas),
+ argnames="table_name,use_schema",
+ )
+ @testing.only_on("postgresql", "PG specific feature")
+ def test_get_table_oid(self, connection, table_name, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ insp = inspect(connection)
+ oid = insp.get_table_oid(table_name, schema)
+ self.assert_(isinstance(oid, int))
+
+ @testing.requires.table_reflection
+ def test_autoincrement_col(self):
+ """test that 'autoincrement' is reflected according to sqla's policy.
+
+ Don't mark this test as unsupported for any backend!
+
+ (technically it fails with MySQL InnoDB since "id" comes before "id2")
+
+ A backend is better off not returning "autoincrement" at all,
+ instead of potentially returning "False" for an auto-incrementing
+ primary key column.
+
+ """
+
+ insp = inspect(self.bind)
+
+ for tname, cname in [
+ ("users", "user_id"),
+ ("email_addresses", "address_id"),
+ ("dingalings", "dingaling_id"),
+ ]:
+ cols = insp.get_columns(tname)
+ id_ = {c["name"]: c for c in cols}[cname]
+ assert id_.get("autoincrement", True)
+
+
+class TableNoColumnsTest(fixtures.TestBase):
+ __requires__ = ("reflect_tables_no_columns",)
+ __backend__ = True
+
+ @testing.fixture
+ def table_no_columns(self, connection, metadata):
+ Table("empty", metadata)
+ metadata.create_all(connection)
+
+ @testing.fixture
+ def view_no_columns(self, connection, metadata):
+        Table("empty", metadata)
+ event.listen(
+ metadata,
+ "after_create",
+ DDL("CREATE VIEW empty_v AS SELECT * FROM empty"),
+ )
+
+ # for transactional DDL the transaction is rolled back before this
+ # drop statement is invoked
+ event.listen(
+ metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v")
+ )
+ metadata.create_all(connection)
+
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_table_no_columns(self, connection, table_no_columns):
+ t2 = Table("empty", MetaData(), autoload_with=connection)
+ eq_(list(t2.c), [])
+
+ @testing.requires.reflect_tables_no_columns
+ def test_get_columns_table_no_columns(self, connection, table_no_columns):
+ eq_(inspect(connection).get_columns("empty"), [])
+
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_incl_table_no_columns(self, connection, table_no_columns):
+ m = MetaData()
+ m.reflect(connection)
+ assert set(m.tables).intersection(["empty"])
+
+ @testing.requires.views
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_view_no_columns(self, connection, view_no_columns):
+ t2 = Table("empty_v", MetaData(), autoload_with=connection)
+ eq_(list(t2.c), [])
+
+ @testing.requires.views
+ @testing.requires.reflect_tables_no_columns
+ def test_get_columns_view_no_columns(self, connection, view_no_columns):
+ eq_(inspect(connection).get_columns("empty_v"), [])
+
+
+class ComponentReflectionTestExtra(fixtures.TestBase):
+
+ __backend__ = True
+
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ @testing.requires.check_constraint_reflection
+ def test_get_check_constraints(self, metadata, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ Table(
+ "sa_cc",
+ metadata,
+ Column("a", Integer()),
+ sa.CheckConstraint("a > 1 AND a < 5", name="cc1"),
+ sa.CheckConstraint(
+ "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing"
+ ),
+ schema=schema,
+ )
+
+ metadata.create_all(connection)
+
+ inspector = inspect(connection)
+ reflected = sorted(
+ inspector.get_check_constraints("sa_cc", schema=schema),
+ key=operator.itemgetter("name"),
+ )
+
+ # trying to minimize effect of quoting, parenthesis, etc.
+ # may need to add more to this as new dialects get CHECK
+ # constraint reflection support
+ def normalize(sqltext):
+ return " ".join(
+ re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)
+ )
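+        # Illustrative example (added; values not asserted here): a backend
+        # might reflect the first constraint's text as "((a > 1) AND (a < 5))";
+        # normalize() reduces both that form and the original
+        # "a > 1 AND a < 5" to the same string "a > 1 and a < 5".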
+
+ reflected = [
+ {"name": item["name"], "sqltext": normalize(item["sqltext"])}
+ for item in reflected
+ ]
+ eq_(
+ reflected,
+ [
+ {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"},
+ {"name": "cc1", "sqltext": "a > 1 and a < 5"},
+ ],
+ )
+
+ @testing.requires.indexes_with_expressions
+ def test_reflect_expression_based_indexes(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+
+ Index("t_idx", func.lower(t.c.x), func.lower(t.c.y))
+
+ Index("t_idx_2", t.c.x)
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ expected = [
+ {"name": "t_idx_2", "column_names": ["x"], "unique": False}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ expected[0]["dialect_options"] = {
+ "%s_include" % connection.engine.name: []
+ }
+
+ with expect_warnings(
+ "Skipped unsupported reflection of expression-based index t_idx"
+ ):
+ eq_(
+ insp.get_indexes("t"),
+ expected,
+ )
+
+ @testing.requires.index_reflects_included_columns
+ def test_reflect_covering_index(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+ idx = Index("t_idx", t.c.x)
+ idx.dialect_options[connection.engine.name]["include"] = ["y"]
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ eq_(
+ insp.get_indexes("t"),
+ [
+ {
+ "name": "t_idx",
+ "column_names": ["x"],
+ "include_columns": ["y"],
+ "unique": False,
+ "dialect_options": {
+ "%s_include" % connection.engine.name: ["y"]
+ },
+ }
+ ],
+ )
+
+ t2 = Table("t", MetaData(), autoload_with=connection)
+ eq_(
+ list(t2.indexes)[0].dialect_options[connection.engine.name][
+ "include"
+ ],
+ ["y"],
+ )
+
+ def _type_round_trip(self, connection, metadata, *types):
+ t = Table(
+ "t",
+ metadata,
+ *[Column("t%d" % i, type_) for i, type_ in enumerate(types)]
+ )
+ t.create(connection)
+
+ return [c["type"] for c in inspect(connection).get_columns("t")]
+
+ @testing.requires.table_reflection
+ def test_numeric_reflection(self, connection, metadata):
+ for typ in self._type_round_trip(
+ connection, metadata, sql_types.Numeric(18, 5)
+ ):
+ assert isinstance(typ, sql_types.Numeric)
+ eq_(typ.precision, 18)
+ eq_(typ.scale, 5)
+
+ @testing.requires.table_reflection
+ def test_varchar_reflection(self, connection, metadata):
+ typ = self._type_round_trip(
+ connection, metadata, sql_types.String(52)
+ )[0]
+ assert isinstance(typ, sql_types.String)
+ eq_(typ.length, 52)
+
+ @testing.requires.table_reflection
+ def test_nullable_reflection(self, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("a", Integer, nullable=True),
+ Column("b", Integer, nullable=False),
+ )
+ t.create(connection)
+ eq_(
+ dict(
+ (col["name"], col["nullable"])
+ for col in inspect(connection).get_columns("t")
+ ),
+ {"a": True, "b": False},
+ )
+
+ @testing.combinations(
+ (
+ None,
+ "CASCADE",
+ None,
+ testing.requires.foreign_key_constraint_option_reflection_ondelete,
+ ),
+ (
+ None,
+ None,
+ "SET NULL",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ None,
+ "NO ACTION",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ "NO ACTION",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_noaction,
+ ),
+ (
+ None,
+ None,
+ "RESTRICT",
+ testing.requires.fk_constraint_option_reflection_onupdate_restrict,
+ ),
+ (
+ None,
+ "RESTRICT",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_restrict,
+ ),
+ argnames="expected,ondelete,onupdate",
+ )
+ def test_get_foreign_key_options(
+ self, connection, metadata, expected, ondelete, onupdate
+ ):
+ options = {}
+ if ondelete:
+ options["ondelete"] = ondelete
+ if onupdate:
+ options["onupdate"] = onupdate
+
+ if expected is None:
+ expected = options
+
+ Table(
+ "x",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")),
+ Column("test", String(10)),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "user",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ sa.ForeignKeyConstraint(
+ ["tid"], ["table.id"], name="myfk", **options
+ ),
+ test_needs_fk=True,
+ )
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+        # test that 'options' is always present for a backend
+        # that can reflect these, since Alembic looks for this
+ opts = insp.get_foreign_keys("table")[0]["options"]
+
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), {})
+
+ opts = insp.get_foreign_keys("user")[0]["options"]
+ eq_(opts, expected)
+ # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected)
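+        # Note (added, not asserted): for e.g. ondelete="CASCADE" the
+        # reflected options are expected to be exactly {"ondelete": "CASCADE"},
+        # while default actions such as NO ACTION are expected to come back
+        # as an empty options dict.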
+
+
+class NormalizedNameTest(fixtures.TablesTest):
+ __requires__ = ("denormalized_names",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ quoted_name("t1", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ )
+ Table(
+ quoted_name("t2", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("t1id", ForeignKey("t1.id")),
+ )
+
+ def test_reflect_lowercase_forced_tables(self):
+
+ m2 = MetaData()
+ t2_ref = Table(
+ quoted_name("t2", quote=True), m2, autoload_with=config.db
+ )
+ t1_ref = m2.tables["t1"]
+ assert t2_ref.c.t1id.references(t1_ref.c.id)
+
+ m3 = MetaData()
+ m3.reflect(
+ config.db, only=lambda name, m: name.lower() in ("t1", "t2")
+ )
+ assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id)
+
+ def test_get_table_names(self):
+ tablenames = [
+ t
+ for t in inspect(config.db).get_table_names()
+ if t.lower() in ("t1", "t2")
+ ]
+
+ eq_(tablenames[0].upper(), tablenames[0].lower())
+ eq_(tablenames[1].upper(), tablenames[1].lower())
+
+
+class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest):
+ def test_computed_col_default_not_set(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_default_table")
+ col_data = {c["name"]: c for c in cols}
+ is_true("42" in col_data["with_default"]["default"])
+ is_(col_data["normal"]["default"], None)
+ is_(col_data["computed_col"]["default"], None)
+
+ def test_get_column_returns_computed(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_default_table")
+ data = {c["name"]: c for c in cols}
+ for key in ("id", "normal", "with_default"):
+ is_true("computed" not in data[key])
+ compData = data["computed_col"]
+ is_true("computed" in compData)
+ is_true("sqltext" in compData["computed"])
+ eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")
+ eq_(
+ "persisted" in compData["computed"],
+ testing.requires.computed_columns_reflect_persisted.enabled,
+ )
+ if testing.requires.computed_columns_reflect_persisted.enabled:
+ eq_(
+ compData["computed"]["persisted"],
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+
+ def check_column(self, data, column, sqltext, persisted):
+ is_true("computed" in data[column])
+ compData = data[column]["computed"]
+ eq_(self.normalize(compData["sqltext"]), sqltext)
+ if testing.requires.computed_columns_reflect_persisted.enabled:
+ is_true("persisted" in compData)
+ is_(compData["persisted"], persisted)
+
+ def test_get_column_returns_persisted(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_column_table")
+ data = {c["name"]: c for c in cols}
+
+ self.check_column(
+ data,
+ "computed_no_flag",
+ "normal+42",
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+ if testing.requires.computed_columns_virtual.enabled:
+ self.check_column(
+ data,
+ "computed_virtual",
+ "normal+2",
+ False,
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ self.check_column(
+ data,
+ "computed_stored",
+ "normal-42",
+ True,
+ )
+
+ @testing.requires.schemas
+ def test_get_column_returns_persisted_with_schema(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns(
+ "computed_column_table", schema=config.test_schema
+ )
+ data = {c["name"]: c for c in cols}
+
+ self.check_column(
+ data,
+ "computed_no_flag",
+ "normal/42",
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+ if testing.requires.computed_columns_virtual.enabled:
+ self.check_column(
+ data,
+ "computed_virtual",
+ "normal/2",
+ False,
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ self.check_column(
+ data,
+ "computed_stored",
+ "normal*42",
+ True,
+ )
+
+
+class IdentityReflectionTest(fixtures.TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("identity_columns", "table_reflection")
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "t1",
+ metadata,
+ Column("normal", Integer),
+ Column("id1", Integer, Identity()),
+ )
+ Table(
+ "t2",
+ metadata,
+ Column(
+ "id2",
+ Integer,
+ Identity(
+ always=True,
+ start=2,
+ increment=3,
+ minvalue=-2,
+ maxvalue=42,
+ cycle=True,
+ cache=4,
+ ),
+ ),
+ )
+ if testing.requires.schemas.enabled:
+ Table(
+ "t1",
+ metadata,
+ Column("normal", Integer),
+ Column("id1", Integer, Identity(always=True, start=20)),
+ schema=config.test_schema,
+ )
+
+ def check(self, value, exp, approx):
+ if testing.requires.identity_columns_standard.enabled:
+ common_keys = (
+ "always",
+ "start",
+ "increment",
+ "minvalue",
+ "maxvalue",
+ "cycle",
+ "cache",
+ )
+ for k in list(value):
+ if k not in common_keys:
+ value.pop(k)
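+            # Note (added comment): with approx=True the reflected values only
+            # need to bound the expected defaults -- minvalue may be lower and
+            # maxvalue / cache may be higher than requested -- since backends
+            # may report their own implementation defaults for an
+            # unconfigured Identity().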
+ if approx:
+ eq_(len(value), len(exp))
+ for k in value:
+ if k == "minvalue":
+ is_true(value[k] <= exp[k])
+ elif k in {"maxvalue", "cache"}:
+ is_true(value[k] >= exp[k])
+ else:
+ eq_(value[k], exp[k], k)
+ else:
+ eq_(value, exp)
+ else:
+ eq_(value["start"], exp["start"])
+ eq_(value["increment"], exp["increment"])
+
+ def test_reflect_identity(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("t1") + insp.get_columns("t2")
+ for col in cols:
+ if col["name"] == "normal":
+ is_false("identity" in col)
+ elif col["name"] == "id1":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=False,
+ start=1,
+ increment=1,
+ minvalue=1,
+ maxvalue=2147483647,
+ cycle=False,
+ cache=1,
+ ),
+ approx=True,
+ )
+ elif col["name"] == "id2":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=True,
+ start=2,
+ increment=3,
+ minvalue=-2,
+ maxvalue=42,
+ cycle=True,
+ cache=4,
+ ),
+ approx=False,
+ )
+
+ @testing.requires.schemas
+ def test_reflect_identity_schema(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("t1", schema=config.test_schema)
+ for col in cols:
+ if col["name"] == "normal":
+ is_false("identity" in col)
+ elif col["name"] == "id1":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=True,
+ start=20,
+ increment=1,
+ minvalue=1,
+ maxvalue=2147483647,
+ cycle=False,
+ cache=1,
+ ),
+ approx=True,
+ )
+
+
+class CompositeKeyReflectionTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ tb1 = Table(
+ "tb1",
+ metadata,
+ Column("id", Integer),
+ Column("attr", Integer),
+ Column("name", sql_types.VARCHAR(20)),
+ sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"),
+ schema=None,
+ test_needs_fk=True,
+ )
+ Table(
+ "tb2",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("pid", Integer),
+ Column("pattr", Integer),
+ Column("pname", sql_types.VARCHAR(20)),
+ sa.ForeignKeyConstraint(
+ ["pname", "pid", "pattr"],
+ [tb1.c.name, tb1.c.id, tb1.c.attr],
+ name="fk_tb1_name_id_attr",
+ ),
+ schema=None,
+ test_needs_fk=True,
+ )
+
+ @testing.requires.primary_key_constraint_reflection
+ def test_pk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ primary_key = insp.get_pk_constraint(self.tables.tb1.name)
+ eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"])
+
+ @testing.requires.foreign_key_constraint_reflection
+ def test_fk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ foreign_keys = insp.get_foreign_keys(self.tables.tb2.name)
+ eq_(len(foreign_keys), 1)
+ fkey1 = foreign_keys[0]
+ eq_(fkey1.get("referred_columns"), ["name", "id", "attr"])
+ eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"])
+
+
+__all__ = (
+ "ComponentReflectionTest",
+ "ComponentReflectionTestExtra",
+ "TableNoColumnsTest",
+ "QuotedNameArgumentTest",
+ "HasTableTest",
+ "HasIndexTest",
+ "NormalizedNameTest",
+ "ComputedReflectionTest",
+ "IdentityReflectionTest",
+ "CompositeKeyReflectionTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
new file mode 100644
index 0000000..c41a550
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -0,0 +1,426 @@
+import datetime
+
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import DateTime
+from ... import func
+from ... import Integer
+from ... import select
+from ... import sql
+from ... import String
+from ... import testing
+from ... import text
+from ... import util
+
+
+class RowFetchTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Table(
+ "has_dates",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("today", DateTime),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.plain_pk.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ connection.execute(
+ cls.tables.has_dates.insert(),
+ [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
+ )
+
+ def test_via_attr(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row.id, 1)
+ eq_(row.data, "d1")
+
+ def test_via_string(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row._mapping["id"], 1)
+ eq_(row._mapping["data"], "d1")
+
+ def test_via_int(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row[0], 1)
+ eq_(row[1], "d1")
+
+ def test_via_col_object(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row._mapping[self.tables.plain_pk.c.id], 1)
+ eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
+
+ @requirements.duplicate_names_in_cursor_description
+ def test_row_with_dupe_names(self, connection):
+ result = connection.execute(
+ select(
+ self.tables.plain_pk.c.data,
+ self.tables.plain_pk.c.data.label("data"),
+ ).order_by(self.tables.plain_pk.c.id)
+ )
+ row = result.first()
+ eq_(result.keys(), ["data", "data"])
+ eq_(row, ("d1", "d1"))
+
+ def test_row_w_scalar_select(self, connection):
+ """test that a scalar select as a column is returned as such
+ and that type conversion works OK.
+
+ (this is half a SQLAlchemy Core test and half to catch database
+ backends that may have unusual behavior with scalar selects.)
+
+ """
+ datetable = self.tables.has_dates
+ s = select(datetable.alias("x").c.today).scalar_subquery()
+ s2 = select(datetable.c.id, s.label("somelabel"))
+ row = connection.execute(s2).first()
+
+ eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
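+
+        # Rough illustration (added; not asserted): the statement above
+        # renders along the lines of
+        #   SELECT has_dates.id, (SELECT x.today FROM has_dates AS x)
+        #       AS somelabel FROM has_dates
+        # so the scalar subquery comes back as an ordinary column value.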
+
+
+class PercentSchemaNamesTest(fixtures.TablesTest):
+ """tests using percent signs, spaces in table and column names.
+
+ This didn't work for PostgreSQL / MySQL drivers for a long time
+ but is now supported.
+
+ """
+
+ __requires__ = ("percent_schema_names",)
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ cls.tables.percent_table = Table(
+ "percent%table",
+ metadata,
+ Column("percent%", Integer),
+ Column("spaces % more spaces", Integer),
+ )
+ cls.tables.lightweight_percent_table = sql.table(
+ "percent%table",
+ sql.column("percent%"),
+ sql.column("spaces % more spaces"),
+ )
+
+ def test_single_roundtrip(self, connection):
+ percent_table = self.tables.percent_table
+ for params in [
+ {"percent%": 5, "spaces % more spaces": 12},
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ]:
+ connection.execute(percent_table.insert(), params)
+ self._assert_table(connection)
+
+ def test_executemany_roundtrip(self, connection):
+ percent_table = self.tables.percent_table
+ connection.execute(
+ percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
+ )
+ connection.execute(
+ percent_table.insert(),
+ [
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ],
+ )
+ self._assert_table(connection)
+
+ def _assert_table(self, conn):
+ percent_table = self.tables.percent_table
+ lightweight_percent_table = self.tables.lightweight_percent_table
+
+ for table in (
+ percent_table,
+ percent_table.alias(),
+ lightweight_percent_table,
+ lightweight_percent_table.alias(),
+ ):
+ eq_(
+ list(
+ conn.execute(table.select().order_by(table.c["percent%"]))
+ ),
+ [(5, 12), (7, 11), (9, 10), (11, 9)],
+ )
+
+ eq_(
+ list(
+ conn.execute(
+ table.select()
+ .where(table.c["spaces % more spaces"].in_([9, 10]))
+ .order_by(table.c["percent%"])
+ )
+ ),
+ [(9, 10), (11, 9)],
+ )
+
+ row = conn.execute(
+ table.select().order_by(table.c["percent%"])
+ ).first()
+ eq_(row._mapping["percent%"], 5)
+ eq_(row._mapping["spaces % more spaces"], 12)
+
+ eq_(row._mapping[table.c["percent%"]], 5)
+ eq_(row._mapping[table.c["spaces % more spaces"]], 12)
+
+ conn.execute(
+ percent_table.update().values(
+ {percent_table.c["spaces % more spaces"]: 15}
+ )
+ )
+
+ eq_(
+ list(
+ conn.execute(
+ percent_table.select().order_by(
+ percent_table.c["percent%"]
+ )
+ )
+ ),
+ [(5, 15), (7, 15), (9, 15), (11, 15)],
+ )
+
+
+class ServerSideCursorsTest(
+ fixtures.TestBase, testing.AssertsExecutionResults
+):
+
+ __requires__ = ("server_side_cursors",)
+
+ __backend__ = True
+
+ def _is_server_side(self, cursor):
+ # TODO: this is a huge issue as it prevents these tests from being
+ # usable by third party dialects.
+ if self.engine.dialect.driver == "psycopg2":
+ return bool(cursor.name)
+ elif self.engine.dialect.driver == "pymysql":
+ sscursor = __import__("pymysql.cursors").cursors.SSCursor
+ return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver in ("aiomysql", "asyncmy"):
+ return cursor.server_side
+ elif self.engine.dialect.driver == "mysqldb":
+ sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
+ return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver == "mariadbconnector":
+ return not cursor.buffered
+ elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
+ return cursor.server_side
+ elif self.engine.dialect.driver == "pg8000":
+ return getattr(cursor, "server_side", False)
+ else:
+ return False
+
+ def _fixture(self, server_side_cursors):
+ if server_side_cursors:
+ with testing.expect_deprecated(
+ "The create_engine.server_side_cursors parameter is "
+ "deprecated and will be removed in a future release. "
+ "Please use the Connection.execution_options.stream_results "
+ "parameter."
+ ):
+ self.engine = engines.testing_engine(
+ options={"server_side_cursors": server_side_cursors}
+ )
+ else:
+ self.engine = engines.testing_engine(
+ options={"server_side_cursors": server_side_cursors}
+ )
+ return self.engine
+
+ @testing.combinations(
+ ("global_string", True, "select 1", True),
+ ("global_text", True, text("select 1"), True),
+ ("global_expr", True, select(1), True),
+ ("global_off_explicit", False, text("select 1"), False),
+ (
+ "stmt_option",
+ False,
+ select(1).execution_options(stream_results=True),
+ True,
+ ),
+ (
+ "stmt_option_disabled",
+ True,
+ select(1).execution_options(stream_results=False),
+ False,
+ ),
+ ("for_update_expr", True, select(1).with_for_update(), True),
+        # TODO: need a real requirement for this, or don't use this test
+ (
+ "for_update_string",
+ True,
+ "SELECT 1 FOR UPDATE",
+ True,
+ testing.skip_if("sqlite"),
+ ),
+ ("text_no_ss", False, text("select 42"), False),
+ (
+ "text_ss_option",
+ False,
+ text("select 42").execution_options(stream_results=True),
+ True,
+ ),
+ id_="iaaa",
+ argnames="engine_ss_arg, statement, cursor_ss_status",
+ )
+ def test_ss_cursor_status(
+ self, engine_ss_arg, statement, cursor_ss_status
+ ):
+ engine = self._fixture(engine_ss_arg)
+ with engine.begin() as conn:
+ if isinstance(statement, util.string_types):
+ result = conn.exec_driver_sql(statement)
+ else:
+ result = conn.execute(statement)
+ eq_(self._is_server_side(result.cursor), cursor_ss_status)
+ result.close()
+
+ def test_conn_option(self):
+ engine = self._fixture(False)
+
+ with engine.connect() as conn:
+ # should be enabled for this one
+ result = conn.execution_options(
+ stream_results=True
+ ).exec_driver_sql("select 1")
+ assert self._is_server_side(result.cursor)
+
+ def test_stmt_enabled_conn_option_disabled(self):
+ engine = self._fixture(False)
+
+ s = select(1).execution_options(stream_results=True)
+
+ with engine.connect() as conn:
+ # not this one
+ result = conn.execution_options(stream_results=False).execute(s)
+ assert not self._is_server_side(result.cursor)
+
+ def test_aliases_and_ss(self):
+ engine = self._fixture(False)
+ s1 = (
+ select(sql.literal_column("1").label("x"))
+ .execution_options(stream_results=True)
+ .subquery()
+ )
+
+ # options don't propagate out when subquery is used as a FROM clause
+ with engine.begin() as conn:
+ result = conn.execute(s1.select())
+ assert not self._is_server_side(result.cursor)
+ result.close()
+
+ s2 = select(1).select_from(s1)
+ with engine.begin() as conn:
+ result = conn.execute(s2)
+ assert not self._is_server_side(result.cursor)
+ result.close()
+
+ def test_roundtrip_fetchall(self, metadata):
+ md = self.metadata
+
+ engine = self._fixture(True)
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ with engine.begin() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(test_table.insert(), dict(data="data1"))
+ connection.execute(test_table.insert(), dict(data="data2"))
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2")],
+ )
+ connection.execute(
+ test_table.update()
+ .where(test_table.c.id == 2)
+ .values(data=test_table.c.data + " updated")
+ )
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2 updated")],
+ )
+ connection.execute(test_table.delete())
+ eq_(
+ connection.scalar(
+ select(func.count("*")).select_from(test_table)
+ ),
+ 0,
+ )
+
+ def test_roundtrip_fetchmany(self, metadata):
+ md = self.metadata
+
+ engine = self._fixture(True)
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ with engine.begin() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(
+ test_table.insert(),
+ [dict(data="data%d" % i) for i in range(1, 20)],
+ )
+
+ result = connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ )
+
+ eq_(
+ result.fetchmany(5),
+ [(i, "data%d" % i) for i in range(1, 6)],
+ )
+ eq_(
+ result.fetchmany(10),
+ [(i, "data%d" % i) for i in range(6, 16)],
+ )
+ eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py
new file mode 100644
index 0000000..82e831f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_rowcount.py
@@ -0,0 +1,165 @@
+from sqlalchemy import bindparam
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+
+
+class RowCountTest(fixtures.TablesTest):
+ """test rowcount functionality"""
+
+ __requires__ = ("sane_rowcount",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "employees",
+ metadata,
+ Column(
+ "employee_id",
+ Integer,
+ autoincrement=False,
+ primary_key=True,
+ ),
+ Column("name", String(50)),
+ Column("department", String(1)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ cls.data = data = [
+ ("Angela", "A"),
+ ("Andrew", "A"),
+ ("Anand", "A"),
+ ("Bob", "B"),
+ ("Bobette", "B"),
+ ("Buffy", "B"),
+ ("Charlie", "C"),
+ ("Cynthia", "C"),
+ ("Chris", "C"),
+ ]
+
+ employees_table = cls.tables.employees
+ connection.execute(
+ employees_table.insert(),
+ [
+ {"employee_id": i, "name": n, "department": d}
+ for i, (n, d) in enumerate(data)
+ ],
+ )
+
+ def test_basic(self, connection):
+ employees_table = self.tables.employees
+ s = select(
+ employees_table.c.name, employees_table.c.department
+ ).order_by(employees_table.c.employee_id)
+ rows = connection.execute(s).fetchall()
+
+ eq_(rows, self.data)
+
+ def test_update_rowcount1(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 3 rows changed
+ department = employees_table.c.department
+ r = connection.execute(
+ employees_table.update().where(department == "C"),
+ {"department": "Z"},
+ )
+ assert r.rowcount == 3
+
+ def test_update_rowcount2(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 0 rows changed
+ department = employees_table.c.department
+
+ r = connection.execute(
+ employees_table.update().where(department == "C"),
+ {"department": "C"},
+ )
+ eq_(r.rowcount, 3)
+
+ @testing.requires.sane_rowcount_w_returning
+ def test_update_rowcount_return_defaults(self, connection):
+ employees_table = self.tables.employees
+
+ department = employees_table.c.department
+ stmt = (
+ employees_table.update()
+ .where(department == "C")
+ .values(name=employees_table.c.department + "Z")
+ .return_defaults()
+ )
+
+ r = connection.execute(stmt)
+ eq_(r.rowcount, 3)
+
+ def test_raw_sql_rowcount(self, connection):
+ # test issue #3622, make sure eager rowcount is called for text
+ result = connection.exec_driver_sql(
+ "update employees set department='Z' where department='C'"
+ )
+ eq_(result.rowcount, 3)
+
+ def test_text_rowcount(self, connection):
+ # test issue #3622, make sure eager rowcount is called for text
+ result = connection.execute(
+ text("update employees set department='Z' " "where department='C'")
+ )
+ eq_(result.rowcount, 3)
+
+ def test_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 3 rows deleted
+ department = employees_table.c.department
+ r = connection.execute(
+ employees_table.delete().where(department == "C")
+ )
+ eq_(r.rowcount, 3)
+
+ @testing.requires.sane_multi_rowcount
+ def test_multi_update_rowcount(self, connection):
+ employees_table = self.tables.employees
+ stmt = (
+ employees_table.update()
+ .where(employees_table.c.name == bindparam("emp_name"))
+ .values(department="C")
+ )
+
+ r = connection.execute(
+ stmt,
+ [
+ {"emp_name": "Bob"},
+ {"emp_name": "Cynthia"},
+ {"emp_name": "nonexistent"},
+ ],
+ )
+
+ eq_(r.rowcount, 2)
+
+ @testing.requires.sane_multi_rowcount
+ def test_multi_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
+
+ stmt = employees_table.delete().where(
+ employees_table.c.name == bindparam("emp_name")
+ )
+
+ r = connection.execute(
+ stmt,
+ [
+ {"emp_name": "Bob"},
+ {"emp_name": "Cynthia"},
+ {"emp_name": "nonexistent"},
+ ],
+ )
+
+ eq_(r.rowcount, 2)
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
new file mode 100644
index 0000000..cb78fff
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -0,0 +1,1783 @@
+import itertools
+
+from .. import AssertsCompiledSQL
+from .. import AssertsExecutionResults
+from .. import config
+from .. import fixtures
+from ..assertions import assert_raises
+from ..assertions import eq_
+from ..assertions import in_
+from ..assertsql import CursorSQL
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import case
+from ... import column
+from ... import Computed
+from ... import exists
+from ... import false
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import null
+from ... import select
+from ... import String
+from ... import table
+from ... import testing
+from ... import text
+from ... import true
+from ... import tuple_
+from ... import TupleType
+from ... import union
+from ... import util
+from ... import values
+from ...exc import DatabaseError
+from ...exc import ProgrammingError
+from ...util import collections_abc
+
+
+class CollateTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "collate data1"},
+ {"id": 2, "data": "collate data2"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ @testing.requires.order_by_collation
+ def test_collate_order_by(self):
+ collation = testing.requires.get_order_by_collation(testing.config)
+
+ self._assert_result(
+ select(self.tables.some_table).order_by(
+ self.tables.some_table.c.data.collate(collation).asc()
+ ),
+ [(1, "collate data1"), (2, "collate data2")],
+ )
+
+
+class OrderByLabelTest(fixtures.TablesTest):
+ """Test the dialect sends appropriate ORDER BY expressions when
+ labels are used.
+
+ This essentially exercises the "supports_simple_order_by_label"
+ setting.
+
+ """
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", String(50)),
+ Column("p", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
+ {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
+ {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ def test_plain(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)])
+
+ def test_composed_int(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)])
+
+ def test_composed_multiple(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ ly = (func.lower(table.c.q) + table.c.p).label("ly")
+ self._assert_result(
+ select(lx, ly).order_by(lx, ly.desc()),
+ [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))],
+ )
+
+ def test_plain_desc(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)])
+
+ def test_composed_int_desc(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)])
+
+ @testing.requires.group_by_complex_expression
+ def test_group_by_composed(self):
+ table = self.tables.some_table
+ expr = (table.c.x + table.c.y).label("lx")
+ stmt = (
+ select(func.count(table.c.id), expr).group_by(expr).order_by(expr)
+ )
+ self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
+
+
+class ValuesExpressionTest(fixtures.TestBase):
+ __requires__ = ("table_value_constructor",)
+
+ __backend__ = True
+
+ def test_tuples(self, connection):
+ value_expr = values(
+ column("id", Integer), column("name", String), name="my_values"
+ ).data([(1, "name1"), (2, "name2"), (3, "name3")])
+
+ eq_(
+ connection.execute(select(value_expr)).all(),
+ [(1, "name1"), (2, "name2"), (3, "name3")],
+ )
+
+
+class FetchLimitOffsetTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ {"id": 5, "x": 4, "y": 6},
+ ],
+ )
+
+ def _assert_result(
+ self, connection, select, result, params=(), set_=False
+ ):
+ if set_:
+ query_res = connection.execute(select, params).fetchall()
+ eq_(len(query_res), len(result))
+ eq_(set(query_res), set(result))
+
+ else:
+ eq_(connection.execute(select, params).fetchall(), result)
+
+ def _assert_result_str(self, select, result, params=()):
+ conn = config.db.connect(close_with_result=True)
+ eq_(conn.exec_driver_sql(select, params).fetchall(), result)
+
+ def test_simple_limit(self, connection):
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id)
+ self._assert_result(
+ connection,
+ stmt.limit(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ stmt.limit(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ def test_limit_render_multiple_times(self, connection):
+ table = self.tables.some_table
+ stmt = select(table.c.id).limit(1).scalar_subquery()
+
+ u = union(select(stmt), select(stmt)).subquery().select()
+
+ self._assert_result(
+ connection,
+ u,
+ [
+ (1,),
+ ],
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.offset
+ def test_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.combinations(
+ ([(2, 0), (2, 1), (3, 2)]),
+ ([(2, 1), (2, 0), (3, 2)]),
+ ([(3, 1), (2, 1), (3, 1)]),
+ argnames="cases",
+ )
+ @testing.requires.offset
+ def test_simple_limit_offset(self, connection, cases):
+ table = self.tables.some_table
+ connection = connection.execution_options(compiled_cache={})
+
+ assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)]
+
+ for limit, offset in cases:
+ expected = assert_data[offset : offset + limit]
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(limit).offset(offset),
+ expected,
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2).offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_no_order_by
+ def test_fetch_offset_no_order(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).fetch(10),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.offset
+ def test_simple_offset_zero(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(0),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(1),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.offset
+ def test_limit_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).limit(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
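+
+        # Note (added): with compile_kwargs={"literal_binds": True} the
+        # LIMIT/OFFSET values are rendered inline, e.g. roughly
+        # "SELECT ... ORDER BY some_table.id LIMIT 2 OFFSET 1" on backends
+        # using LIMIT/OFFSET syntax, so the string executes with no
+        # parameters at all.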
+
+ @testing.requires.fetch_first
+ def test_fetch_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).fetch(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3)],
+ params={"l": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ params={"l": 3},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 1},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"l": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"l": 3, "o": 2},
+ )
+
+ @testing.requires.fetch_first
+ def test_bound_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"f": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"f": 3, "o": 2},
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .offset(literal_column("1") + literal_column("2")),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("2")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.fetch_first
+ @testing.requires.fetch_expression
+ def test_expr_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_simple_limit_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(2)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(3)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(2),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ def test_simple_fetch_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties_exact_number(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(3, with_ties=True)
+ .offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(20, percent=True),
+ [(1, 1, 2)],
+ )
+
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(40, percent=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x.desc())
+ .fetch(20, percent=True, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(40, percent=True, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+
+class JoinTest(fixtures.TablesTest):
+ __backend__ = True
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("a", metadata, Column("id", Integer, primary_key=True))
+ Table(
+ "b",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("a_id", ForeignKey("a.id"), nullable=False),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.a.insert(),
+ [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}],
+ )
+
+ connection.execute(
+ cls.tables.b.insert(),
+ [
+ {"id": 1, "a_id": 1},
+ {"id": 2, "a_id": 1},
+ {"id": 4, "a_id": 2},
+ {"id": 5, "a_id": 3},
+ ],
+ )
+
+ def test_inner_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)])
+
+ def test_inner_join_true(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, true()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (a, b, c)
+ for (a,), (b, c) in itertools.product(
+ [(1,), (2,), (3,), (4,), (5,)],
+ [(1, 1), (2, 1), (4, 2), (5, 3)],
+ )
+ ],
+ )
+
+ def test_inner_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(stmt, [])
+
+ def test_outer_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.outerjoin(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (1, None, None),
+ (2, None, None),
+ (3, None, None),
+ (4, None, None),
+ (5, None, None),
+ ],
+ )
+
+ def test_outer_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)])
+
+
+class CompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_select_from_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_in_unions_from_alias(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ # this necessarily has double parens
+ u1 = union(s1, s2).alias()
+ self._assert_result(
+ u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+
+class PostCompileParamsTest(
+ AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest
+):
+ __backend__ = True
+
+ __requires__ = ("standard_cursor_sql",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def test_compile(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table "
+ "WHERE some_table.x = __[POSTCOMPILE_q]",
+ {},
+ )
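+
+        # Note (added): __[POSTCOMPILE_q] is the compiler's placeholder for a
+        # literal_execute parameter; it is replaced with the literal value
+        # bound to "q" when the statement is actually executed, as the
+        # cursor-level assertions in the tests below show.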
+
+ def test_compile_literal_binds(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", 10, literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table WHERE some_table.x = 10",
+ {},
+ literal_binds=True,
+ )
+
+ def test_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=10))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x = 10",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ def test_execute_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x.in_(bindparam("q", expanding=True, literal_execute=True))
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[5, 6, 7]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x IN (5, 6, 7)",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, 10), (12, 18)]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.y) "
+ "IN (%s(5, 10), (12, 18))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.z) "
+ "IN (%s(5, 'z1'), (12, 'z3'))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+
+class ExpandingBoundInTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_multiple_empty_sets_bindparam(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .where(table.c.y.in_(bindparam("p")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": [], "p": []})
+
+ def test_multiple_empty_sets_direct(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.y.in_([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)])
+ go([], [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)])
+ go([], [])
+
+ def test_bound_in_scalar_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
+
+ def test_bound_in_scalar_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3, 4]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ def test_nonempty_in_plus_empty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3]))
+ .where(table.c.id.not_in([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,)])
+
+ def test_empty_in_plus_notempty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.id.not_in([2, 3]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ def test_typed_str_in(self):
+ """test related to #7292.
+
+        since a type is given to the bound param, there is no ambiguity
+        as to the type of the elements.
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", type_=String, expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ def test_untyped_str_in(self):
+ """test related to #7292.
+
+        for an untyped expression, we look at the types of the elements.
+        We test for Sequence to detect a tuple IN, but not for strings or
+        bytes, as always.
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ [(2, "z2"), (3, "z3"), (4, "z4")]
+ )
+ )
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam(self):
+        # note this becomes ARRAY if we don't use expanding
+ # explicitly right now
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self):
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(
+ bindparam(
+ "q", type_=TupleType(Integer(), String()), expanding=True
+ )
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self):
+        # note this becomes ARRAY if we don't use expanding
+ # explicitly right now
+
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
+ def test_empty_set_against_integer_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_integer_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
+ def test_empty_set_against_integer_negation_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.not_in(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
+ def test_empty_set_against_integer_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+ def test_empty_set_against_string_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.z.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_string_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
+ def test_empty_set_against_string_negation_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.z.not_in(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
+ def test_empty_set_against_string_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+ def test_null_in_empty_set_is_false_bindparam(self, connection):
+ stmt = select(
+ case(
+ (
+ null().in_(bindparam("foo", value=())),
+ true(),
+ ),
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+ def test_null_in_empty_set_is_false_direct(self, connection):
+ stmt = select(
+ case(
+ (
+ null().in_([]),
+ true(),
+ ),
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+
+class LikeFunctionsTest(fixtures.TablesTest):
+ __backend__ = True
+
+ run_inserts = "once"
+ run_deletes = None
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "abcdefg"},
+ {"id": 2, "data": "ab/cdefg"},
+ {"id": 3, "data": "ab%cdefg"},
+ {"id": 4, "data": "ab_cdefg"},
+ {"id": 5, "data": "abcde/fg"},
+ {"id": 6, "data": "abcde%fg"},
+ {"id": 7, "data": "ab#cdefg"},
+ {"id": 8, "data": "ab9cdefg"},
+ {"id": 9, "data": "abcde#fg"},
+ {"id": 10, "data": "abcd9fg"},
+ {"id": 11, "data": None},
+ ],
+ )
+
+ def _test(self, expr, expected):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ rows = {
+ value
+ for value, in conn.execute(select(some_table.c.id).where(expr))
+ }
+
+ eq_(rows, expected)
+
+ def test_startswith_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
+
+ def test_startswith_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c", autoescape=True), {3})
+
+ def test_startswith_sqlexpr(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.startswith(literal_column("'ab%c'")),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ )
+
+ def test_startswith_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab##c", escape="#"), {7})
+
+ def test_startswith_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3})
+ self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7})
+
+ def test_endswith_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9})
+
+ def test_endswith_sqlexpr(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9}
+ )
+
+ def test_endswith_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg", autoescape=True), {6})
+
+ def test_endswith_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e##fg", escape="#"), {9})
+
+ def test_endswith_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6})
+ self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9})
+
+ def test_contains_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9})
+
+ def test_contains_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cde", autoescape=True), {3})
+
+ def test_contains_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b##cde", escape="#"), {7})
+
+ def test_contains_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
+ self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
+
+ @testing.requires.regexp_match
+ def test_not_regexp_match(self):
+ col = self.tables.some_table.c.data
+ self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10})
+
+ @testing.requires.regexp_replace
+ def test_regexp_replace(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9}
+ )
+
+ @testing.requires.regexp_match
+ @testing.combinations(
+ ("a.cde", {1, 5, 6, 9}),
+ ("abc", {1, 5, 6, 9, 10}),
+ ("^abc", {1, 5, 6, 9, 10}),
+ ("9cde", {8}),
+ ("^a", set(range(1, 11))),
+ ("(b|c)", set(range(1, 11))),
+ ("^(b|c)", set()),
+ )
+ def test_regexp_match(self, text, expected):
+ col = self.tables.some_table.c.data
+ self._test(col.regexp_match(text), expected)
+
+
+class ComputedColumnTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("computed_columns",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "square",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("side", Integer),
+ Column("area", Integer, Computed("side * side")),
+ Column("perimeter", Integer, Computed("4 * side")),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.square.insert(),
+ [{"id": 1, "side": 10}, {"id": 10, "side": 42}],
+ )
+
+ def test_select_all(self):
+ with config.db.connect() as conn:
+ res = conn.execute(
+ select(text("*"))
+ .select_from(self.tables.square)
+ .order_by(self.tables.square.c.id)
+ ).fetchall()
+ eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)])
+
+ def test_select_columns(self):
+ with config.db.connect() as conn:
+ res = conn.execute(
+ select(
+ self.tables.square.c.area, self.tables.square.c.perimeter
+ )
+ .select_from(self.tables.square)
+ .order_by(self.tables.square.c.id)
+ ).fetchall()
+ eq_(res, [(100, 40), (1764, 168)])
+
+
+class IdentityColumnTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("identity_columns",)
+ run_inserts = "once"
+ run_deletes = "once"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "tbl_a",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(
+ always=True, start=42, nominvalue=True, nomaxvalue=True
+ ),
+ primary_key=True,
+ ),
+ Column("desc", String(100)),
+ )
+ Table(
+ "tbl_b",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0),
+ primary_key=True,
+ ),
+ Column("desc", String(100)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.tbl_a.insert(),
+ [{"desc": "a"}, {"desc": "b"}],
+ )
+ connection.execute(
+ cls.tables.tbl_b.insert(),
+ [{"desc": "a"}, {"desc": "b"}],
+ )
+ connection.execute(
+ cls.tables.tbl_b.insert(),
+ [{"id": 42, "desc": "c"}],
+ )
+
+ def test_select_all(self, connection):
+ res = connection.execute(
+ select(text("*"))
+ .select_from(self.tables.tbl_a)
+ .order_by(self.tables.tbl_a.c.id)
+ ).fetchall()
+ eq_(res, [(42, "a"), (43, "b")])
+
+ res = connection.execute(
+ select(text("*"))
+ .select_from(self.tables.tbl_b)
+ .order_by(self.tables.tbl_b.c.id)
+ ).fetchall()
+ eq_(res, [(-5, "b"), (0, "a"), (42, "c")])
+
+ def test_select_columns(self, connection):
+
+ res = connection.execute(
+ select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id)
+ ).fetchall()
+ eq_(res, [(42,), (43,)])
+
+ @testing.requires.identity_columns_standard
+ def test_insert_always_error(self, connection):
+ def fn():
+ connection.execute(
+ self.tables.tbl_a.insert(),
+ [{"id": 200, "desc": "a"}],
+ )
+
+ assert_raises((DatabaseError, ProgrammingError), fn)
+
+
+class IdentityAutoincrementTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("autoincrement_without_sequence",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "tbl",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(),
+ primary_key=True,
+ autoincrement=True,
+ ),
+ Column("desc", String(100)),
+ )
+
+ def test_autoincrement_with_identity(self, connection):
+ res = connection.execute(self.tables.tbl.insert(), {"desc": "row"})
+ res = connection.execute(self.tables.tbl.select()).first()
+ eq_(res, (1, "row"))
+
+
+class ExistsTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "stuff",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.stuff.insert(),
+ [
+ {"id": 1, "data": "some data"},
+ {"id": 2, "data": "some data"},
+ {"id": 3, "data": "some data"},
+ {"id": 4, "data": "some other data"},
+ ],
+ )
+
+ def test_select_exists(self, connection):
+ stuff = self.tables.stuff
+ eq_(
+ connection.execute(
+ select(literal(1)).where(
+ exists().where(stuff.c.data == "some data")
+ )
+ ).fetchall(),
+ [(1,)],
+ )
+
+ def test_select_exists_false(self, connection):
+ stuff = self.tables.stuff
+ eq_(
+ connection.execute(
+ select(literal(1)).where(
+ exists().where(stuff.c.data == "no data")
+ )
+ ).fetchall(),
+ [],
+ )
+
+
+class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest):
+ __backend__ = True
+
+ @testing.fails_if(testing.requires.supports_distinct_on)
+ def test_distinct_on(self):
+ stm = select("*").distinct(column("q")).select_from(table("foo"))
+ with testing.expect_deprecated(
+ "DISTINCT ON is currently supported only by the PostgreSQL "
+ ):
+ self.assert_compile(stm, "SELECT DISTINCT * FROM foo")
+
+
+class IsOrIsNotDistinctFromTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("supports_is_distinct_from",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "is_distinct_test",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("col_a", Integer, nullable=True),
+ Column("col_b", Integer, nullable=True),
+ )
+
+ @testing.combinations(
+ ("both_int_different", 0, 1, 1),
+ ("both_int_same", 1, 1, 0),
+ ("one_null_first", None, 1, 1),
+ ("one_null_second", 0, None, 1),
+ ("both_null", None, None, 0),
+ id_="iaaa",
+ argnames="col_a_value, col_b_value, expected_row_count_for_is",
+ )
+ def test_is_or_is_not_distinct_from(
+ self, col_a_value, col_b_value, expected_row_count_for_is, connection
+ ):
+ tbl = self.tables.is_distinct_test
+
+ connection.execute(
+ tbl.insert(),
+ [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}],
+ )
+
+ result = connection.execute(
+ tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b))
+ ).fetchall()
+ eq_(
+ len(result),
+ expected_row_count_for_is,
+ )
+
+ expected_row_count_for_is_not = (
+ 1 if expected_row_count_for_is == 0 else 0
+ )
+ result = connection.execute(
+ tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b))
+ ).fetchall()
+ eq_(
+ len(result),
+ expected_row_count_for_is_not,
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
new file mode 100644
index 0000000..d6747d2
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -0,0 +1,282 @@
+from .. import config
+from .. import fixtures
+from ..assertions import eq_
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import Sequence
+from ... import String
+from ... import testing
+
+
+class SequenceTest(fixtures.TablesTest):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ run_create_tables = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "seq_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "seq_opt_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq", data_type=Integer, optional=True),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "seq_no_returning",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ )
+
+ if testing.requires.schemas.enabled:
+ Table(
+ "seq_no_returning_sch",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema=config.test_schema),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ schema=config.test_schema,
+ )
+
+ def test_insert_roundtrip(self, connection):
+ connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
+ self._assert_round_trip(self.tables.seq_pk, connection)
+
+ def test_insert_lastrowid(self, connection):
+ r = connection.execute(
+ self.tables.seq_pk.insert(), dict(data="some data")
+ )
+ eq_(
+ r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
+ )
+
+ def test_nextval_direct(self, connection):
+ r = connection.execute(self.tables.seq_pk.c.id.default)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
+ @requirements.sequences_optional
+ def test_optional_seq(self, connection):
+ r = connection.execute(
+ self.tables.seq_opt_pk.insert(), dict(data="some data")
+ )
+ eq_(r.inserted_primary_key, (1,))
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
+
+ def test_insert_roundtrip_no_implicit_returning(self, connection):
+ connection.execute(
+ self.tables.seq_no_returning.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.seq_no_returning, connection)
+
+ @testing.combinations((True,), (False,), argnames="implicit_returning")
+ @testing.requires.schemas
+ def test_insert_roundtrip_translate(self, connection, implicit_returning):
+
+ seq_no_returning = Table(
+ "seq_no_returning_sch",
+ MetaData(),
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema="alt_schema"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=implicit_returning,
+ schema="alt_schema",
+ )
+
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+ connection.execute(seq_no_returning.insert(), dict(data="some data"))
+ self._assert_round_trip(seq_no_returning, connection)
+
+ @testing.requires.schemas
+ def test_nextval_direct_schema_translate(self, connection):
+ seq = Sequence("noret_sch_id_seq", schema="alt_schema")
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+
+ r = connection.execute(seq)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
+
+class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ def test_literal_binds_inline_compile(self, connection):
+ table = Table(
+ "x",
+ MetaData(),
+ Column("y", Integer, Sequence("y_seq")),
+ Column("q", Integer),
+ )
+
+ stmt = table.insert().values(q=5)
+
+ seq_nextval = connection.dialect.statement_compiler(
+ statement=None, dialect=connection.dialect
+ ).visit_sequence(Sequence("y_seq"))
+ self.assert_compile(
+ stmt,
+ "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
+ literal_binds=True,
+ dialect=connection.dialect,
+ )
+
+
+class HasSequenceTest(fixtures.TablesTest):
+ run_deletes = None
+
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Sequence("user_id_seq", metadata=metadata)
+ Sequence(
+ "other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True
+ )
+ if testing.requires.schemas.enabled:
+ Sequence(
+ "user_id_seq", schema=config.test_schema, metadata=metadata
+ )
+ Sequence(
+ "schema_seq", schema=config.test_schema, metadata=metadata
+ )
+ Table(
+ "user_id_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ )
+
+ def test_has_sequence(self, connection):
+ eq_(
+ inspect(connection).has_sequence("user_id_seq"),
+ True,
+ )
+
+ def test_has_sequence_other_object(self, connection):
+ eq_(
+ inspect(connection).has_sequence("user_id_table"),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_schema(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "user_id_seq", schema=config.test_schema
+ ),
+ True,
+ )
+
+ def test_has_sequence_neg(self, connection):
+ eq_(
+ inspect(connection).has_sequence("some_sequence"),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_schemas_neg(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "some_sequence", schema=config.test_schema
+ ),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_default_not_in_remote(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "other_sequence", schema=config.test_schema
+ ),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_remote_not_in_default(self, connection):
+ eq_(
+ inspect(connection).has_sequence("schema_seq"),
+ False,
+ )
+
+ def test_get_sequence_names(self, connection):
+ exp = {"other_seq", "user_id_seq"}
+
+ res = set(inspect(connection).get_sequence_names())
+ is_true(res.intersection(exp) == exp)
+ is_true("schema_seq" not in res)
+
+ @testing.requires.schemas
+ def test_get_sequence_names_no_sequence_schema(self, connection):
+ eq_(
+ inspect(connection).get_sequence_names(
+ schema=config.test_schema_2
+ ),
+ [],
+ )
+
+ @testing.requires.schemas
+ def test_get_sequence_names_sequences_schema(self, connection):
+ eq_(
+ sorted(
+ inspect(connection).get_sequence_names(
+ schema=config.test_schema
+ )
+ ),
+ ["schema_seq", "user_id_seq"],
+ )
+
+
+class HasSequenceTestEmpty(fixtures.TestBase):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ def test_get_sequence_names_no_sequence(self, connection):
+ eq_(
+ inspect(connection).get_sequence_names(),
+ [],
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
new file mode 100644
index 0000000..b96350e
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -0,0 +1,1508 @@
+# coding: utf-8
+
+import datetime
+import decimal
+import json
+import re
+
+from .. import config
+from .. import engines
+from .. import fixtures
+from .. import mock
+from ..assertions import eq_
+from ..assertions import is_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import and_
+from ... import BigInteger
+from ... import bindparam
+from ... import Boolean
+from ... import case
+from ... import cast
+from ... import Date
+from ... import DateTime
+from ... import Float
+from ... import Integer
+from ... import JSON
+from ... import literal
+from ... import MetaData
+from ... import null
+from ... import Numeric
+from ... import select
+from ... import String
+from ... import testing
+from ... import Text
+from ... import Time
+from ... import TIMESTAMP
+from ... import TypeDecorator
+from ... import Unicode
+from ... import UnicodeText
+from ... import util
+from ...orm import declarative_base
+from ...orm import Session
+from ...sql.sqltypes import LargeBinary
+from ...sql.sqltypes import PickleType
+from ...util import compat
+from ...util import u
+
+
+class _LiteralRoundTripFixture(object):
+ supports_whereclause = True
+
+ @testing.fixture
+ def literal_round_trip(self, metadata, connection):
+ """test literal rendering"""
+
+ # for literal, we test the literal render in an INSERT
+ # into a typed column. we can then SELECT it back as its
+ # official type; ideally we'd be able to use CAST here
+ # but MySQL in particular can't CAST fully
+
+ def run(type_, input_, output, filter_=None):
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+
+ for value in input_:
+ ins = (
+ t.insert()
+ .values(x=literal(value, type_))
+ .compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
+ )
+ connection.execute(ins)
+
+ if self.supports_whereclause:
+ stmt = t.select().where(t.c.x == literal(value))
+ else:
+ stmt = t.select()
+
+ stmt = stmt.compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
+ for row in connection.execute(stmt):
+ value = row[0]
+ if filter_ is not None:
+ value = filter_(value)
+ assert value in output
+
+ return run
+
+
+class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase):
+ __requires__ = ("unicode_data",)
+
+ data = u(
+ "Alors vous imaginez ma 🐍 surprise, au lever du jour, "
+ "quand une drôle de petite 🐍 voix m’a réveillé. Elle "
+ "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »"
+ )
+
+ @property
+ def supports_whereclause(self):
+ return config.requirements.expressions_against_unbounded_text.enabled
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "unicode_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("unicode_data", cls.datatype),
+ )
+
+ def test_round_trip(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": self.data}
+ )
+
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+
+ eq_(row, (self.data,))
+ assert isinstance(row[0], util.text_type)
+
+ def test_round_trip_executemany(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(),
+ [{"id": i, "unicode_data": self.data} for i in range(1, 4)],
+ )
+
+ rows = connection.execute(
+ select(unicode_table.c.unicode_data)
+ ).fetchall()
+ eq_(rows, [(self.data,) for i in range(1, 4)])
+ for row in rows:
+ assert isinstance(row[0], util.text_type)
+
+ def _test_null_strings(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": None}
+ )
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+ eq_(row, (None,))
+
+ def _test_empty_strings(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": u("")}
+ )
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+ eq_(row, (u(""),))
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(self.datatype, [self.data], [self.data])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+
+class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
+ __requires__ = ("unicode_data",)
+ __backend__ = True
+
+ datatype = Unicode(255)
+
+ @requirements.empty_strings_varchar
+ def test_empty_strings_varchar(self, connection):
+ self._test_empty_strings(connection)
+
+ def test_null_strings_varchar(self, connection):
+ self._test_null_strings(connection)
+
+
+class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
+ __requires__ = "unicode_data", "text_type"
+ __backend__ = True
+
+ datatype = UnicodeText()
+
+ @requirements.empty_strings_text
+ def test_empty_strings_text(self, connection):
+ self._test_empty_strings(connection)
+
+ def test_null_strings_text(self, connection):
+ self._test_null_strings(connection)
+
+
+class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("binary_literals",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "binary_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("binary_data", LargeBinary),
+ Column("pickle_data", PickleType),
+ )
+
+ def test_binary_roundtrip(self, connection):
+ binary_table = self.tables.binary_table
+
+ connection.execute(
+ binary_table.insert(), {"id": 1, "binary_data": b"this is binary"}
+ )
+ row = connection.execute(select(binary_table.c.binary_data)).first()
+ eq_(row, (b"this is binary",))
+
+ def test_pickle_roundtrip(self, connection):
+ binary_table = self.tables.binary_table
+
+ connection.execute(
+ binary_table.insert(),
+ {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}},
+ )
+ row = connection.execute(select(binary_table.c.pickle_data)).first()
+ eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},))
+
+
+class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("text_type",)
+ __backend__ = True
+
+ @property
+ def supports_whereclause(self):
+ return config.requirements.expressions_against_unbounded_text.enabled
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "text_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("text_data", Text),
+ )
+
+ def test_text_roundtrip(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(
+ text_table.insert(), {"id": 1, "text_data": "some text"}
+ )
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, ("some text",))
+
+ @testing.requires.empty_strings_text
+ def test_text_empty_strings(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(text_table.insert(), {"id": 1, "text_data": ""})
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, ("",))
+
+ def test_text_null_strings(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(text_table.insert(), {"id": 1, "text_data": None})
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, (None,))
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(Text, ["some text"], ["some text"])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+ def test_literal_quoting(self, literal_round_trip):
+ data = """some 'text' hey "hi there" that's text"""
+ literal_round_trip(Text, [data], [data])
+
+ def test_literal_backslashes(self, literal_round_trip):
+ data = r"backslash one \ backslash two \\ end"
+ literal_round_trip(Text, [data], [data])
+
+ def test_literal_percentsigns(self, literal_round_trip):
+ data = r"percent % signs %% percent"
+ literal_round_trip(Text, [data], [data])
+
+
+class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @requirements.unbounded_varchar
+ def test_nolength_string(self):
+ metadata = MetaData()
+ foo = Table("foo", metadata, Column("one", String))
+
+ foo.create(config.db)
+ foo.drop(config.db)
+
+ def test_literal(self, literal_round_trip):
+ # note that in Python 3, this invokes the Unicode
+ # datatype for the literal part because all strings are unicode
+ literal_round_trip(String(40), ["some text"], ["some text"])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+ def test_literal_quoting(self, literal_round_trip):
+ data = """some 'text' hey "hi there" that's text"""
+ literal_round_trip(String(40), [data], [data])
+
+ def test_literal_backslashes(self, literal_round_trip):
+ data = r"backslash one \ backslash two \\ end"
+ literal_round_trip(String(40), [data], [data])
+
+
+class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
+ compare = None
+
+ @classmethod
+ def define_tables(cls, metadata):
+ class Decorated(TypeDecorator):
+ impl = cls.datatype
+ cache_ok = True
+
+ Table(
+ "date_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("date_data", cls.datatype),
+ Column("decorated_date_data", Decorated),
+ )
+
+ @testing.requires.datetime_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
+ def test_round_trip(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(
+ date_table.insert(), {"id": 1, "date_data": self.data}
+ )
+
+ row = connection.execute(select(date_table.c.date_data)).first()
+
+ compare = self.compare or self.data
+ eq_(row, (compare,))
+ assert isinstance(row[0], type(compare))
+
+ def test_round_trip_decorated(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(
+ date_table.insert(), {"id": 1, "decorated_date_data": self.data}
+ )
+
+ row = connection.execute(
+ select(date_table.c.decorated_date_data)
+ ).first()
+
+ compare = self.compare or self.data
+ eq_(row, (compare,))
+ assert isinstance(row[0], type(compare))
+
+ def test_null(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(date_table.insert(), {"id": 1, "date_data": None})
+
+ row = connection.execute(select(date_table.c.date_data)).first()
+ eq_(row, (None,))
+
+ @testing.requires.datetime_literals
+ def test_literal(self, literal_round_trip):
+ compare = self.compare or self.data
+ literal_round_trip(self.datatype, [self.data], [compare])
+
+ @testing.requires.standalone_null_binds_whereclause
+ def test_null_bound_comparison(self):
+ # this test is based on an Oracle issue observed in #4886.
+        # when passing NULL for an expression that needs to be interpreted
+        # as a certain type, does the DBAPI have the info it needs to do so?
+ date_table = self.tables.date_table
+ with config.db.begin() as conn:
+ result = conn.execute(
+ date_table.insert(), {"id": 1, "date_data": self.data}
+ )
+ id_ = result.inserted_primary_key[0]
+ stmt = select(date_table.c.id).where(
+ case(
+ (
+ bindparam("foo", type_=self.datatype) != None,
+ bindparam("foo", type_=self.datatype),
+ ),
+ else_=date_table.c.date_data,
+ )
+ == date_table.c.date_data
+ )
+
+ row = conn.execute(stmt, {"foo": None}).first()
+ eq_(row[0], id_)
+
+
+class DateTimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+
+
+class DateTimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_timezone",)
+ __backend__ = True
+ datatype = DateTime(timezone=True)
+ data = datetime.datetime(
+ 2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc
+ )
+
+
+class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_microseconds",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
+
+class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("timestamp_microseconds",)
+ __backend__ = True
+ datatype = TIMESTAMP
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
+ @testing.requires.timestamp_microseconds_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
+
+class TimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time",)
+ __backend__ = True
+ datatype = Time
+ data = datetime.time(12, 57, 18)
+
+
+class TimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time_timezone",)
+ __backend__ = True
+ datatype = Time(timezone=True)
+ data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc)
+
+
+class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time_microseconds",)
+ __backend__ = True
+ datatype = Time
+ data = datetime.time(12, 57, 18, 396)
+
+
+class DateTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("date",)
+ __backend__ = True
+ datatype = Date
+ data = datetime.date(2012, 10, 15)
+
+
+class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = "date", "date_coerces_from_datetime"
+ __backend__ = True
+ datatype = Date
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+ compare = datetime.date(2012, 10, 15)
+
+
+class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_historic",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(1850, 11, 10, 11, 52, 35)
+
+
+class DateHistoricTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("date_historic",)
+ __backend__ = True
+ datatype = Date
+ data = datetime.date(1727, 4, 1)
+
+
+class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(Integer, [5], [5])
+
+ def test_huge_int(self, integer_round_trip):
+ integer_round_trip(BigInteger, 1376537018368127)
+
+ @testing.fixture
+ def integer_round_trip(self, metadata, connection):
+ def run(datatype, data):
+ int_table = Table(
+ "integer_table",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ primary_key=True,
+ test_needs_autoincrement=True,
+ ),
+ Column("integer_data", datatype),
+ )
+
+ metadata.create_all(config.db)
+
+ connection.execute(
+ int_table.insert(), {"id": 1, "integer_data": data}
+ )
+
+ row = connection.execute(select(int_table.c.integer_data)).first()
+
+ eq_(row, (data,))
+
+ if util.py3k:
+ assert isinstance(row[0], int)
+ else:
+ assert isinstance(row[0], (long, int)) # noqa
+
+ return run
+
+
+class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @testing.fixture
+ def string_as_int(self):
+ class StringAsInt(TypeDecorator):
+ impl = String(50)
+ cache_ok = True
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def column_expression(self, col):
+ return cast(col, Integer)
+
+ def bind_expression(self, col):
+ return cast(col, String(50))
+
+ return StringAsInt()
+
+ def test_special_type(self, metadata, connection, string_as_int):
+
+ type_ = string_as_int
+
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+
+ connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ eq_(result, {1, 2, 3})
+
+ result = {
+ row[0] for row in connection.execute(t.select().where(t.c.x == 2))
+ }
+ eq_(result, {2})
+
+
+class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @testing.fixture
+ def do_numeric_test(self, metadata, connection):
+ @testing.emits_warning(
+ r".*does \*not\* support Decimal objects natively"
+ )
+ def run(type_, input_, output, filter_=None, check_scale=False):
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+ connection.execute(t.insert(), [{"x": x} for x in input_])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ output = set(output)
+ if filter_:
+ result = set(filter_(x) for x in result)
+ output = set(filter_(x) for x in output)
+ eq_(result, output)
+ if check_scale:
+ eq_([str(x) for x in result], [str(x) for x in output])
+
+ return run
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_render_literal_numeric(self, literal_round_trip):
+ literal_round_trip(
+ Numeric(precision=8, scale=4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [decimal.Decimal("15.7563")],
+ )
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_render_literal_numeric_asfloat(self, literal_round_trip):
+ literal_round_trip(
+ Numeric(precision=8, scale=4, asdecimal=False),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ )
+
+ def test_render_literal_float(self, literal_round_trip):
+ literal_round_trip(
+ Float(4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
+ )
+
+ @testing.requires.precision_generic_float_type
+ def test_float_custom_scale(self, do_numeric_test):
+ do_numeric_test(
+ Float(None, decimal_return_scale=7, asdecimal=True),
+ [15.7563827, decimal.Decimal("15.7563827")],
+ [decimal.Decimal("15.7563827")],
+ check_scale=True,
+ )
+
+ def test_numeric_as_decimal(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [decimal.Decimal("15.7563")],
+ )
+
+ def test_numeric_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4, asdecimal=False),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ )
+
+ @testing.requires.infinity_floats
+ def test_infinity_floats(self, do_numeric_test):
+ """test for #977, #7283"""
+
+ do_numeric_test(
+ Float(None),
+ [float("inf")],
+ [float("inf")],
+ )
+
+ @testing.requires.fetch_null_from_numeric
+ def test_numeric_null_as_decimal(self, do_numeric_test):
+ do_numeric_test(Numeric(precision=8, scale=4), [None], [None])
+
+ @testing.requires.fetch_null_from_numeric
+ def test_numeric_null_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4, asdecimal=False), [None], [None]
+ )
+
+ @testing.requires.floats_to_four_decimals
+ def test_float_as_decimal(self, do_numeric_test):
+ do_numeric_test(
+ Float(precision=8, asdecimal=True),
+ [15.7563, decimal.Decimal("15.7563"), None],
+ [decimal.Decimal("15.7563"), None],
+ filter_=lambda n: n is not None and round(n, 4) or None,
+ )
+
+ def test_float_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Float(precision=8),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
+ )
+
+ def test_float_coerce_round_trip(self, connection):
+ expr = 15.7563
+
+ val = connection.scalar(select(literal(expr)))
+ eq_(val, expr)
+
+ # this does not work in MySQL, see #4036, however we choose not
+ # to render CAST unconditionally since this is kind of an edge case.
+
+ @testing.requires.implicit_decimal_binds
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_decimal_coerce_round_trip(self, connection):
+ expr = decimal.Decimal("15.7563")
+
+ val = connection.scalar(select(literal(expr)))
+ eq_(val, expr)
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_decimal_coerce_round_trip_w_cast(self, connection):
+ expr = decimal.Decimal("15.7563")
+
+ val = connection.scalar(select(cast(expr, Numeric(10, 4))))
+ eq_(val, expr)
+
+ @testing.requires.precision_numerics_general
+ def test_precision_decimal(self, do_numeric_test):
+ numbers = set(
+ [
+ decimal.Decimal("54.234246451650"),
+ decimal.Decimal("0.004354"),
+ decimal.Decimal("900.0"),
+ ]
+ )
+
+ do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers)
+
+ @testing.requires.precision_numerics_enotation_large
+ def test_enotation_decimal(self, do_numeric_test):
+ """test exceedingly small decimals.
+
+ Decimal reports values with E notation when the exponent
+ is greater than 6.
+
+ """
+
+ numbers = set(
+ [
+ decimal.Decimal("1E-2"),
+ decimal.Decimal("1E-3"),
+ decimal.Decimal("1E-4"),
+ decimal.Decimal("1E-5"),
+ decimal.Decimal("1E-6"),
+ decimal.Decimal("1E-7"),
+ decimal.Decimal("1E-8"),
+ decimal.Decimal("0.01000005940696"),
+ decimal.Decimal("0.00000005940696"),
+ decimal.Decimal("0.00000000000696"),
+ decimal.Decimal("0.70000000000696"),
+ decimal.Decimal("696E-12"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers)
+
+ @testing.requires.precision_numerics_enotation_large
+ def test_enotation_decimal_large(self, do_numeric_test):
+ """test exceedingly large decimals."""
+
+ numbers = set(
+ [
+ decimal.Decimal("4E+8"),
+ decimal.Decimal("5748E+15"),
+ decimal.Decimal("1.521E+15"),
+ decimal.Decimal("00000000000000.1E+12"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers)
+
+ @testing.requires.precision_numerics_many_significant_digits
+ def test_many_significant_digits(self, do_numeric_test):
+ numbers = set(
+ [
+ decimal.Decimal("31943874831932418390.01"),
+ decimal.Decimal("319438950232418390.273596"),
+ decimal.Decimal("87673.594069654243"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers)
+
+ @testing.requires.precision_numerics_retains_significant_digits
+ def test_numeric_no_decimal(self, do_numeric_test):
+ numbers = set([decimal.Decimal("1.000")])
+ do_numeric_test(
+ Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
+ )
+
+
+class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "boolean_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("value", Boolean),
+ Column("unconstrained_value", Boolean(create_constraint=False)),
+ )
+
+ def test_render_literal_bool(self, literal_round_trip):
+ literal_round_trip(Boolean(), [True, False], [True, False])
+
+ def test_round_trip(self, connection):
+ boolean_table = self.tables.boolean_table
+
+ connection.execute(
+ boolean_table.insert(),
+ {"id": 1, "value": True, "unconstrained_value": False},
+ )
+
+ row = connection.execute(
+ select(boolean_table.c.value, boolean_table.c.unconstrained_value)
+ ).first()
+
+ eq_(row, (True, False))
+ assert isinstance(row[0], bool)
+
+ @testing.requires.nullable_booleans
+ def test_null(self, connection):
+ boolean_table = self.tables.boolean_table
+
+ connection.execute(
+ boolean_table.insert(),
+ {"id": 1, "value": None, "unconstrained_value": None},
+ )
+
+ row = connection.execute(
+ select(boolean_table.c.value, boolean_table.c.unconstrained_value)
+ ).first()
+
+ eq_(row, (None, None))
+
+ def test_whereclause(self):
+ # testing "WHERE <column>" renders a compatible expression
+ boolean_table = self.tables.boolean_table
+
+ with config.db.begin() as conn:
+ conn.execute(
+ boolean_table.insert(),
+ [
+ {"id": 1, "value": True, "unconstrained_value": True},
+ {"id": 2, "value": False, "unconstrained_value": False},
+ ],
+ )
+
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(boolean_table.c.value)
+ ),
+ 1,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(
+ boolean_table.c.unconstrained_value
+ )
+ ),
+ 1,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(~boolean_table.c.value)
+ ),
+ 2,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(
+ ~boolean_table.c.unconstrained_value
+ )
+ ),
+ 2,
+ )
+
+
+class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("json_type",)
+ __backend__ = True
+
+ datatype = JSON
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype, nullable=False),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
+
+ def test_round_trip_data1(self, connection):
+ self._test_round_trip({"key1": "value1", "key2": "value2"}, connection)
+
+ def _test_round_trip(self, data_element, connection):
+ data_table = self.tables.data_table
+
+ connection.execute(
+ data_table.insert(),
+ {"id": 1, "name": "row1", "data": data_element},
+ )
+
+ row = connection.execute(select(data_table.c.data)).first()
+
+ eq_(row, (data_element,))
+
+ def _index_fixtures(include_comparison):
+
+ if include_comparison:
+ # basically SQL Server and MariaDB can kind of do json
+ # comparison, MySQL, PG and SQLite can't. not worth it.
+ json_elements = []
+ else:
+ json_elements = [
+ ("json", {"foo": "bar"}),
+ ("json", ["one", "two", "three"]),
+ (None, {"foo": "bar"}),
+ (None, ["one", "two", "three"]),
+ ]
+
+ elements = [
+ ("boolean", True),
+ ("boolean", False),
+ ("boolean", None),
+ ("string", "some string"),
+ ("string", None),
+ ("string", util.u("réve illé")),
+ (
+ "string",
+ util.u("réve🐍 illé"),
+ testing.requires.json_index_supplementary_unicode_element,
+ ),
+ ("integer", 15),
+ ("integer", 1),
+ ("integer", 0),
+ ("integer", None),
+ ("float", 28.5),
+ ("float", None),
+ (
+ "float",
+ 1234567.89,
+ ),
+ ("numeric", 1234567.89),
+ # this one "works" because the float value you see here is
+ # lost immediately to floating point stuff
+ ("numeric", 99998969694839.983485848, requirements.python3),
+ ("numeric", 99939.983485848, requirements.python3),
+ ("_decimal", decimal.Decimal("1234567.89")),
+ (
+ "_decimal",
+ decimal.Decimal("99998969694839.983485848"),
+ # fails on SQLite and MySQL (non-mariadb)
+ requirements.cast_precision_numerics_many_significant_digits,
+ ),
+ (
+ "_decimal",
+ decimal.Decimal("99939.983485848"),
+ ),
+ ] + json_elements
+
+ def decorate(fn):
+ fn = testing.combinations(id_="sa", *elements)(fn)
+
+ return fn
+
+ return decorate
+
+ def _json_value_insert(self, connection, datatype, value, data_element):
+ data_table = self.tables.data_table
+ if datatype == "_decimal":
+
+            # Python's builtin json serializer basically doesn't support
+            # Decimal objects without implicit float conversion, period.
+            # users can otherwise use simplejson, which supports
+            # precision decimals
+
+ # https://bugs.python.org/issue16535
+
+ # inserting as strings to avoid a new fixture around the
+ # dialect which would have idiosyncrasies for different
+ # backends.
+
+ class DecimalEncoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, decimal.Decimal):
+ return str(o)
+ return super(DecimalEncoder, self).default(o)
+
+ json_data = json.dumps(data_element, cls=DecimalEncoder)
+
+ # take the quotes out. yup, there is *literally* no other
+ # way to get Python's json.dumps() to put all the digits in
+ # the string
+ json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data)
+
+ datatype = "numeric"
+
+ connection.execute(
+ data_table.insert().values(
+ name="row1",
+ # to pass the string directly to every backend, including
+ # PostgreSQL which needs the value to be CAST as JSON
+ # both in the SQL as well as at the prepared statement
+ # level for asyncpg, while at the same time MySQL
+ # doesn't even support CAST for JSON, here we are
+ # sending the string embedded in the SQL without using
+ # a parameter.
+ data=bindparam(None, json_data, literal_execute=True),
+ nulldata=bindparam(None, json_data, literal_execute=True),
+ ),
+ )
+ else:
+ connection.execute(
+ data_table.insert(),
+ {
+ "name": "row1",
+ "data": data_element,
+ "nulldata": data_element,
+ },
+ )
+
+ p_s = None
+
+ if datatype:
+ if datatype == "numeric":
+ a, b = str(value).split(".")
+ s = len(b)
+ p = len(a) + s
+
+ if isinstance(value, decimal.Decimal):
+ compare_value = value
+ else:
+ compare_value = decimal.Decimal(str(value))
+
+ p_s = (p, s)
+ else:
+ compare_value = value
+ else:
+ compare_value = value
+
+ return datatype, compare_value, p_s
+
+ @_index_fixtures(False)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_index_typed_access(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": value}
+
+ with config.db.begin() as conn:
+
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data["key1"]
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ roundtrip = conn.scalar(select(expr))
+ eq_(roundtrip, compare_value)
+ if util.py3k: # skip py2k to avoid comparing unicode to str etc.
+ is_(type(roundtrip), type(compare_value))
+
+ @_index_fixtures(True)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_index_typed_comparison(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": value}
+
+ with config.db.begin() as conn:
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data["key1"]
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ row = conn.execute(
+ select(expr).where(expr == compare_value)
+ ).first()
+
+ # make sure we get a row even if value is None
+ eq_(row, (compare_value,))
+
+ @_index_fixtures(True)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_path_typed_comparison(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": {"subkey1": value}}
+ with config.db.begin() as conn:
+
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data[("key1", "subkey1")]
+
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ row = conn.execute(
+ select(expr).where(expr == compare_value)
+ ).first()
+
+ # make sure we get a row even if value is None
+ eq_(row, (compare_value,))
+
+ @testing.combinations(
+ (True,),
+ (False,),
+ (None,),
+ (15,),
+ (0,),
+ (-1,),
+ (-1.0,),
+ (15.052,),
+ ("a string",),
+ (util.u("réve illé"),),
+ (util.u("réve🐍 illé"),),
+ )
+ def test_single_element_round_trip(self, element):
+ data_table = self.tables.data_table
+ data_element = element
+ with config.db.begin() as conn:
+ conn.execute(
+ data_table.insert(),
+ {
+ "name": "row1",
+ "data": data_element,
+ "nulldata": data_element,
+ },
+ )
+
+ row = conn.execute(
+ select(data_table.c.data, data_table.c.nulldata)
+ ).first()
+
+ eq_(row, (data_element, data_element))
+
+ def test_round_trip_custom_json(self):
+ data_table = self.tables.data_table
+ data_element = {"key1": "data1"}
+
+ js = mock.Mock(side_effect=json.dumps)
+ jd = mock.Mock(side_effect=json.loads)
+ engine = engines.testing_engine(
+ options=dict(json_serializer=js, json_deserializer=jd)
+ )
+
+ # support sqlite :memory: database...
+ data_table.create(engine, checkfirst=True)
+ with engine.begin() as conn:
+ conn.execute(
+ data_table.insert(), {"name": "row1", "data": data_element}
+ )
+ row = conn.execute(select(data_table.c.data)).first()
+
+ eq_(row, (data_element,))
+ eq_(js.mock_calls, [mock.call(data_element)])
+ eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
+
+ @testing.combinations(
+ ("parameters",),
+ ("multiparameters",),
+ ("values",),
+ ("omit",),
+ argnames="insert_type",
+ )
+ def test_round_trip_none_as_sql_null(self, connection, insert_type):
+ col = self.tables.data_table.c["nulldata"]
+
+ conn = connection
+
+ if insert_type == "parameters":
+ stmt, params = self.tables.data_table.insert(), {
+ "name": "r1",
+ "nulldata": None,
+ "data": None,
+ }
+ elif insert_type == "multiparameters":
+ stmt, params = self.tables.data_table.insert(), [
+ {"name": "r1", "nulldata": None, "data": None}
+ ]
+ elif insert_type == "values":
+ stmt, params = (
+ self.tables.data_table.insert().values(
+ name="r1",
+ nulldata=None,
+ data=None,
+ ),
+ {},
+ )
+ elif insert_type == "omit":
+ stmt, params = (
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": None},
+ )
+
+ else:
+ assert False
+
+ conn.execute(stmt, params)
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(col.is_(null()))
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ def test_round_trip_json_null_as_json_null(self, connection):
+ col = self.tables.data_table.c["data"]
+
+ conn = connection
+ conn.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": JSON.NULL},
+ )
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(
+ cast(col, String) == "null"
+ )
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ @testing.combinations(
+ ("parameters",),
+ ("multiparameters",),
+ ("values",),
+ argnames="insert_type",
+ )
+ def test_round_trip_none_as_json_null(self, connection, insert_type):
+ col = self.tables.data_table.c["data"]
+
+ if insert_type == "parameters":
+ stmt, params = self.tables.data_table.insert(), {
+ "name": "r1",
+ "data": None,
+ }
+ elif insert_type == "multiparameters":
+ stmt, params = self.tables.data_table.insert(), [
+ {"name": "r1", "data": None}
+ ]
+ elif insert_type == "values":
+ stmt, params = (
+ self.tables.data_table.insert().values(name="r1", data=None),
+ {},
+ )
+ else:
+ assert False
+
+ conn = connection
+ conn.execute(stmt, params)
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(
+ cast(col, String) == "null"
+ )
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ def test_unicode_round_trip(self):
+ # note we include Unicode supplementary characters as well
+ with config.db.begin() as conn:
+ conn.execute(
+ self.tables.data_table.insert(),
+ {
+ "name": "r1",
+ "data": {
+ util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+ "data": {"k1": util.u("drôl🐍e")},
+ },
+ },
+ )
+
+ eq_(
+ conn.scalar(select(self.tables.data_table.c.data)),
+ {
+ util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+ "data": {"k1": util.u("drôl🐍e")},
+ },
+ )
+
+ def test_eval_none_flag_orm(self, connection):
+
+ Base = declarative_base()
+
+ class Data(Base):
+ __table__ = self.tables.data_table
+
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
+
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
+ )
+
+
+class JSONLegacyStringCastIndexTest(
+ _LiteralRoundTripFixture, fixtures.TablesTest
+):
+    """test JSON index access with "cast to string", which we have documented
+    for a long time as the way to compare JSON values, but which is
+    ultimately not reliable in all cases. The "as_XYZ()" comparators
+    should be used instead.
+
+ """
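+    # A minimal sketch of the two styles this suite contrasts (illustrative
+    # only; "col" stands in for a JSON column such as data_table.c.data):
+    #
+    #   legacy string-cast comparison, exercised below:
+    #       cast(col["key two"], String) == '"value2"'
+    #
+    #   typed comparators such as .as_string(), the approach recommended
+    #   by the docstring above:
+    #       col["key two"].as_string() == "value2"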
+
+ __requires__ = ("json_type", "legacy_unconditional_json_extract")
+ __backend__ = True
+
+ datatype = JSON
+
+ data1 = {"key1": "value1", "key2": "value2"}
+
+ data2 = {
+ "Key 'One'": "value1",
+ "key two": "value2",
+ "key three": "value ' three '",
+ }
+
+ data3 = {
+ "key1": [1, 2, 3],
+ "key2": ["one", "two", "three"],
+ "key3": [{"four": "five"}, {"six": "seven"}],
+ }
+
+ data4 = ["one", "two", "three"]
+
+ data5 = {
+ "nested": {
+ "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}],
+ "elem2": {"elem3": {"elem4": "elem5"}},
+ }
+ }
+
+ data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}}
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
+
+ def _criteria_fixture(self):
+ with config.db.begin() as conn:
+ conn.execute(
+ self.tables.data_table.insert(),
+ [
+ {"name": "r1", "data": self.data1},
+ {"name": "r2", "data": self.data2},
+ {"name": "r3", "data": self.data3},
+ {"name": "r4", "data": self.data4},
+ {"name": "r5", "data": self.data5},
+ {"name": "r6", "data": self.data6},
+ ],
+ )
+
+ def _test_index_criteria(self, crit, expected, test_literal=True):
+ self._criteria_fixture()
+ with config.db.connect() as conn:
+ stmt = select(self.tables.data_table.c.name).where(crit)
+
+ eq_(conn.scalar(stmt), expected)
+
+ if test_literal:
+ literal_sql = str(
+ stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}
+ )
+ )
+
+ eq_(conn.exec_driver_sql(literal_sql).scalar(), expected)
+
+ def test_string_cast_crit_spaces_in_key(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ # limit the rows here to avoid PG error
+ # "cannot extract field from a non-object", which is
+ # fixed in 9.4 but may exist in 9.3
+ self._test_index_criteria(
+ and_(
+ name.in_(["r1", "r2", "r3"]),
+ cast(col["key two"], String) == '"value2"',
+ ),
+ "r2",
+ )
+
+ @config.requirements.json_array_indexes
+ def test_string_cast_crit_simple_int(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ # limit the rows here to avoid PG error
+ # "cannot extract array element from a non-array", which is
+ # fixed in 9.4 but may exist in 9.3
+ self._test_index_criteria(
+ and_(
+ name == "r4",
+ cast(col[1], String) == '"two"',
+ ),
+ "r4",
+ )
+
+ def test_string_cast_crit_mixed_path(self):
+ col = self.tables.data_table.c["data"]
+ self._test_index_criteria(
+ cast(col[("key3", 1, "six")], String) == '"seven"',
+ "r3",
+ )
+
+ def test_string_cast_crit_string_path(self):
+ col = self.tables.data_table.c["data"]
+ self._test_index_criteria(
+ cast(col[("nested", "elem2", "elem3", "elem4")], String)
+ == '"elem5"',
+ "r5",
+ )
+
+ def test_string_cast_crit_against_string_basic(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ self._test_index_criteria(
+ and_(
+ name == "r6",
+ cast(col["b"], String) == '"some value"',
+ ),
+ "r6",
+ )
+
+
+__all__ = (
+ "BinaryTest",
+ "UnicodeVarcharTest",
+ "UnicodeTextTest",
+ "JSONTest",
+ "JSONLegacyStringCastIndexTest",
+ "DateTest",
+ "DateTimeTest",
+ "DateTimeTZTest",
+ "TextTest",
+ "NumericTest",
+ "IntegerTest",
+ "CastTypeDecoratorTest",
+ "DateTimeHistoricTest",
+ "DateTimeCoercedToDateTimeTest",
+ "TimeMicrosecondsTest",
+ "TimestampMicrosecondsTest",
+ "TimeTest",
+ "TimeTZTest",
+ "DateTimeMicrosecondsTest",
+ "DateHistoricTest",
+ "StringTest",
+ "BooleanTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
new file mode 100644
index 0000000..a4ae334
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
@@ -0,0 +1,206 @@
+# coding: utf-8
+"""verrrrry basic unicode column name testing"""
+
+from sqlalchemy import desc
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import testing
+from sqlalchemy import util
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+from sqlalchemy.util import u
+from sqlalchemy.util import ue
+
+
+class UnicodeSchemaTest(fixtures.TablesTest):
+ __requires__ = ("unicode_ddl",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ global t1, t2, t3
+
+ t1 = Table(
+ u("unitable1"),
+ metadata,
+ Column(u("méil"), Integer, primary_key=True),
+ Column(ue("\u6e2c\u8a66"), Integer),
+ test_needs_fk=True,
+ )
+ t2 = Table(
+ u("Unitéble2"),
+ metadata,
+ Column(u("méil"), Integer, primary_key=True, key="a"),
+ Column(
+ ue("\u6e2c\u8a66"),
+ Integer,
+ ForeignKey(u("unitable1.méil")),
+ key="b",
+ ),
+ test_needs_fk=True,
+ )
+
+ # Few DBs support Unicode foreign keys
+ if testing.against("sqlite"):
+ t3 = Table(
+ ue("\u6e2c\u8a66"),
+ metadata,
+ Column(
+ ue("\u6e2c\u8a66_id"),
+ Integer,
+ primary_key=True,
+ autoincrement=False,
+ ),
+ Column(
+ ue("unitable1_\u6e2c\u8a66"),
+ Integer,
+ ForeignKey(ue("unitable1.\u6e2c\u8a66")),
+ ),
+ Column(
+ u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b"))
+ ),
+ Column(
+ ue("\u6e2c\u8a66_self"),
+ Integer,
+ ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")),
+ ),
+ test_needs_fk=True,
+ )
+ else:
+ t3 = Table(
+ ue("\u6e2c\u8a66"),
+ metadata,
+ Column(
+ ue("\u6e2c\u8a66_id"),
+ Integer,
+ primary_key=True,
+ autoincrement=False,
+ ),
+ Column(ue("unitable1_\u6e2c\u8a66"), Integer),
+ Column(u("Unitéble2_b"), Integer),
+ Column(ue("\u6e2c\u8a66_self"), Integer),
+ test_needs_fk=True,
+ )
+
+ def test_insert(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
+ eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
+ eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])
+
+ def test_col_targeting(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ row = connection.execute(t1.select()).first()
+ eq_(row._mapping[t1.c[u("méil")]], 1)
+ eq_(row._mapping[t1.c[ue("\u6e2c\u8a66")]], 5)
+
+ row = connection.execute(t2.select()).first()
+ eq_(row._mapping[t2.c[u("a")]], 1)
+ eq_(row._mapping[t2.c[u("b")]], 1)
+
+ row = connection.execute(t3.select()).first()
+ eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_id")]], 1)
+ eq_(row._mapping[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5)
+ eq_(row._mapping[t3.c[u("Unitéble2_b")]], 1)
+ eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_self")]], 1)
+
+ def test_reflect(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 2, ue("\u6e2c\u8a66"): 7})
+ connection.execute(t2.insert(), {u("a"): 2, u("b"): 2})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 2,
+ ue("unitable1_\u6e2c\u8a66"): 7,
+ u("Unitéble2_b"): 2,
+ ue("\u6e2c\u8a66_self"): 2,
+ },
+ )
+
+ meta = MetaData()
+ tt1 = Table(t1.name, meta, autoload_with=connection)
+ tt2 = Table(t2.name, meta, autoload_with=connection)
+ tt3 = Table(t3.name, meta, autoload_with=connection)
+
+ connection.execute(tt1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(tt2.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 1})
+ connection.execute(
+ tt3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ eq_(
+ connection.execute(
+ tt1.select().order_by(desc(u("méil")))
+ ).fetchall(),
+ [(2, 7), (1, 5)],
+ )
+ eq_(
+ connection.execute(
+ tt2.select().order_by(desc(u("méil")))
+ ).fetchall(),
+ [(2, 2), (1, 1)],
+ )
+ eq_(
+ connection.execute(
+ tt3.select().order_by(desc(ue("\u6e2c\u8a66_id")))
+ ).fetchall(),
+ [(2, 7, 2, 2), (1, 5, 1, 1)],
+ )
+
+ def test_repr(self):
+ meta = MetaData()
+ t = Table(
+ ue("\u6e2c\u8a66"), meta, Column(ue("\u6e2c\u8a66_id"), Integer)
+ )
+
+ if util.py2k:
+ eq_(
+ repr(t),
+ (
+ "Table('\\u6e2c\\u8a66', MetaData(), "
+ "Column('\\u6e2c\\u8a66_id', Integer(), "
+ "table=<\u6e2c\u8a66>), "
+ "schema=None)"
+ ),
+ )
+ else:
+ eq_(
+ repr(t),
+ (
+ "Table('測試', MetaData(), "
+ "Column('測試_id', Integer(), "
+ "table=<測試>), "
+ "schema=None)"
+ ),
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
new file mode 100644
index 0000000..f04a9d5
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -0,0 +1,60 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import String
+
+
+class SimpleUpdateDeleteTest(fixtures.TablesTest):
+ run_deletes = "each"
+ __requires__ = ("sane_rowcount",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.plain_pk.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ def test_update(self, connection):
+ t = self.tables.plain_pk
+ r = connection.execute(
+ t.update().where(t.c.id == 2), dict(data="d2_new")
+ )
+ assert not r.is_insert
+ assert not r.returns_rows
+ assert r.rowcount == 1
+
+ eq_(
+ connection.execute(t.select().order_by(t.c.id)).fetchall(),
+ [(1, "d1"), (2, "d2_new"), (3, "d3")],
+ )
+
+ def test_delete(self, connection):
+ t = self.tables.plain_pk
+ r = connection.execute(t.delete().where(t.c.id == 2))
+ assert not r.is_insert
+ assert not r.returns_rows
+ assert r.rowcount == 1
+ eq_(
+ connection.execute(t.select().order_by(t.c.id)).fetchall(),
+ [(1, "d1"), (3, "d3")],
+ )
+
+
+__all__ = ("SimpleUpdateDeleteTest",)
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
new file mode 100644
index 0000000..be89bc6
--- /dev/null
+++ b/lib/sqlalchemy/testing/util.py
@@ -0,0 +1,458 @@
+# testing/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import decimal
+import gc
+import random
+import sys
+import types
+
+from . import config
+from . import mock
+from .. import inspect
+from ..engine import Connection
+from ..schema import Column
+from ..schema import DropConstraint
+from ..schema import DropTable
+from ..schema import ForeignKeyConstraint
+from ..schema import MetaData
+from ..schema import Table
+from ..sql import schema
+from ..sql.sqltypes import Integer
+from ..util import decorator
+from ..util import defaultdict
+from ..util import has_refcount_gc
+from ..util import inspect_getfullargspec
+from ..util import py2k
+
+
+if not has_refcount_gc:
+
+ def non_refcount_gc_collect(*args):
+ gc.collect()
+ gc.collect()
+
+ gc_collect = lazy_gc = non_refcount_gc_collect
+else:
+ # assume CPython - straight gc.collect, lazy_gc() is a pass
+ gc_collect = gc.collect
+
+ def lazy_gc():
+ pass
+
+
+def picklers():
+ picklers = set()
+ if py2k:
+ try:
+ import cPickle
+
+ picklers.add(cPickle)
+ except ImportError:
+ pass
+
+ import pickle
+
+ picklers.add(pickle)
+
+ # yes, this thing needs this much testing
+ for pickle_ in picklers:
+ for protocol in range(-2, pickle.HIGHEST_PROTOCOL):
+ yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
+
+
+if py2k:
+
+ def random_choices(population, k=1):
+ pop = list(population)
+ # lame but works :)
+ random.shuffle(pop)
+ return pop[0:k]
+
+
+else:
+
+ def random_choices(population, k=1):
+ return random.choices(population, k=k)
+
+
+def round_decimal(value, prec):
+ if isinstance(value, float):
+ return round(value, prec)
+
+ # can also use shift() here but that is 2.6 only
+ return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
+ decimal.ROUND_FLOOR
+ ) / pow(10, prec)
+
+
+class RandomSet(set):
+ def __iter__(self):
+ l = list(set.__iter__(self))
+ random.shuffle(l)
+ return iter(l)
+
+ def pop(self):
+ index = random.randint(0, len(self) - 1)
+ item = list(set.__iter__(self))[index]
+ self.remove(item)
+ return item
+
+ def union(self, other):
+ return RandomSet(set.union(self, other))
+
+ def difference(self, other):
+ return RandomSet(set.difference(self, other))
+
+ def intersection(self, other):
+ return RandomSet(set.intersection(self, other))
+
+ def copy(self):
+ return RandomSet(self)
+
+
+def conforms_partial_ordering(tuples, sorted_elements):
+ """True if the given sorting conforms to the given partial ordering."""
+
+ deps = defaultdict(set)
+ for parent, child in tuples:
+ deps[parent].add(child)
+ for i, node in enumerate(sorted_elements):
+ for n in sorted_elements[i:]:
+ if node in deps[n]:
+ return False
+ else:
+ return True
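+    # e.g. (illustrative): with tuples=[("a", "b")], meaning "a" must sort
+    # before "b":
+    #   conforms_partial_ordering([("a", "b")], ["a", "b"])  -> True
+    #   conforms_partial_ordering([("a", "b")], ["b", "a"])  -> False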
+
+
+def all_partial_orderings(tuples, elements):
+ edges = defaultdict(set)
+ for parent, child in tuples:
+ edges[child].add(parent)
+
+ def _all_orderings(elements):
+
+ if len(elements) == 1:
+ yield list(elements)
+ else:
+ for elem in elements:
+ subset = set(elements).difference([elem])
+ if not subset.intersection(edges[elem]):
+ for sub_ordering in _all_orderings(subset):
+ yield [elem] + sub_ordering
+
+ return iter(_all_orderings(elements))
+
+
+def function_named(fn, name):
+ """Return a function with a given __name__.
+
+ Will assign to __name__ and return the original function if possible on
+ the Python implementation, otherwise a new function will be constructed.
+
+ This function should be phased out as much as possible
+ in favor of @decorator. Tests that "generate" many named tests
+ should be modernized.
+
+ """
+ try:
+ fn.__name__ = name
+ except TypeError:
+ fn = types.FunctionType(
+ fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
+ )
+ return fn
+
+
+def run_as_contextmanager(ctx, fn, *arg, **kw):
+ """Run the given function under the given contextmanager,
+ simulating the behavior of 'with' to support older
+ Python versions.
+    This is not necessary anymore, as we have placed 2.6 as the
+    minimum Python version; however, some tests still use this
+    structure.
+ this structure.
+
+ """
+
+ obj = ctx.__enter__()
+ try:
+ result = fn(obj, *arg, **kw)
+ ctx.__exit__(None, None, None)
+ return result
+ except:
+ exc_info = sys.exc_info()
+ raise_ = ctx.__exit__(*exc_info)
+ if not raise_:
+ raise
+ else:
+ return raise_
+
+
+def rowset(results):
+ """Converts the results of sql execution into a plain set of column tuples.
+
+ Useful for asserting the results of an unordered query.
+ """
+
+ return {tuple(row) for row in results}
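+    # e.g. (illustrative): rowset(connection.execute(stmt)) ==
+    # {(1, "d1"), (2, "d2")} passes regardless of row ordering.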
+
+
+def fail(msg):
+ assert False, msg
+
+
+@decorator
+def provide_metadata(fn, *args, **kw):
+ """Provide bound MetaData for a single test, dropping afterwards.
+
+ Legacy; use the "metadata" pytest fixture.
+
+ """
+
+ from . import fixtures
+
+ metadata = schema.MetaData()
+ self = args[0]
+ prev_meta = getattr(self, "metadata", None)
+ self.metadata = metadata
+ try:
+ return fn(*args, **kw)
+ finally:
+ # close out some things that get in the way of dropping tables.
+ # when using the "metadata" fixture, there is a set ordering
+ # of things that makes sure things are cleaned up in order, however
+ # the simple "decorator" nature of this legacy function means
+ # we have to hardcode some of that cleanup ahead of time.
+
+ # close ORM sessions
+ fixtures._close_all_sessions()
+
+ # integrate with the "connection" fixture as there are many
+ # tests where it is used along with provide_metadata
+ if fixtures._connection_fixture_connection:
+ # TODO: this warning can be used to find all the places
+ # this is used with connection fixture
+ # warn("mixing legacy provide metadata with connection fixture")
+ drop_all_tables_from_metadata(
+ metadata, fixtures._connection_fixture_connection
+ )
+ # as the provide_metadata fixture is often used with "testing.db",
+ # when we do the drop we have to commit the transaction so that
+ # the DB is actually updated as the CREATE would have been
+ # committed
+ fixtures._connection_fixture_connection.get_transaction().commit()
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
+ self.metadata = prev_meta
+
+
+def flag_combinations(*combinations):
+ """A facade around @testing.combinations() oriented towards boolean
+ keyword-based arguments.
+
+ Basically generates a nice looking identifier based on the keywords
+ and also sets up the argument names.
+
+ E.g.::
+
+ @testing.flag_combinations(
+ dict(lazy=False, passive=False),
+ dict(lazy=True, passive=False),
+ dict(lazy=False, passive=True),
+ dict(lazy=False, passive=True, raiseload=True),
+ )
+
+
+ would result in::
+
+ @testing.combinations(
+ ('', False, False, False),
+ ('lazy', True, False, False),
+            ('passive', False, True, False),
+            ('passive_raiseload', False, True, True),
+ id_='iaaa',
+ argnames='lazy,passive,raiseload'
+ )
+
+ """
+
+ keys = set()
+
+ for d in combinations:
+ keys.update(d)
+
+ keys = sorted(keys)
+
+ return config.combinations(
+ *[
+ ("_".join(k for k in keys if d.get(k, False)),)
+ + tuple(d.get(k, False) for k in keys)
+ for d in combinations
+ ],
+ id_="i" + ("a" * len(keys)),
+ argnames=",".join(keys)
+ )
+
+
+def lambda_combinations(lambda_arg_sets, **kw):
+ args = inspect_getfullargspec(lambda_arg_sets)
+
+ arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]])
+
+ def create_fixture(pos):
+ def fixture(**kw):
+ return lambda_arg_sets(**kw)[pos]
+
+ fixture.__name__ = "fixture_%3.3d" % pos
+ return fixture
+
+ return config.combinations(
+ *[(create_fixture(i),) for i in range(len(arg_sets))], **kw
+ )
+
+
+def resolve_lambda(__fn, **kw):
+ """Given a no-arg lambda and a namespace, return a new lambda that
+ has all the values filled in.
+
+ This is used so that we can have module-level fixtures that
+ refer to instance-level variables using lambdas.
+
+ """
+
+ pos_args = inspect_getfullargspec(__fn)[0]
+ pass_pos_args = {arg: kw.pop(arg) for arg in pos_args}
+ glb = dict(__fn.__globals__)
+ glb.update(kw)
+ new_fn = types.FunctionType(__fn.__code__, glb)
+ return new_fn(**pass_pos_args)
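+    # e.g. (illustrative, hypothetical names): given
+    #   fixture = lambda: users_table.c.id == some_id
+    # resolve_lambda(fixture, users_table=users, some_id=5) evaluates the
+    # lambda with those names injected into its globals.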
+
+
+def metadata_fixture(ddl="function"):
+ """Provide MetaData for a pytest fixture."""
+
+ def decorate(fn):
+ def run_ddl(self):
+
+ metadata = self.metadata = schema.MetaData()
+ try:
+ result = fn(self, metadata)
+ metadata.create_all(config.db)
+ # TODO:
+ # somehow get a per-function dml erase fixture here
+ yield result
+ finally:
+ metadata.drop_all(config.db)
+
+ return config.fixture(scope=ddl)(run_ddl)
+
+ return decorate
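+    # e.g. (illustrative, hypothetical fixture): used as a decorator on a
+    # method that receives a MetaData and returns the tables it defines:
+    #
+    #   @metadata_fixture()
+    #   def some_tables(self, metadata):
+    #       return Table("t", metadata, Column("id", Integer, primary_key=True))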
+
+
+def force_drop_names(*names):
+ """Force the given table names to be dropped after test complete,
+    """Force the given table names to be dropped after the test completes,
+    isolating for foreign key cycles.
+ """
+
+ @decorator
+ def go(fn, *args, **kw):
+
+ try:
+ return fn(*args, **kw)
+ finally:
+ drop_all_tables(config.db, inspect(config.db), include_names=names)
+
+ return go
+
+
+class adict(dict):
+ """Dict keys available as attributes. Shadows."""
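+    # e.g. (illustrative): a = adict(name="x", items=2); a.name -> "x" and
+    # a.items -> 2, i.e. a matching key shadows the usual dict attribute.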
+
+ def __getattribute__(self, key):
+ try:
+ return self[key]
+ except KeyError:
+ return dict.__getattribute__(self, key)
+
+ def __call__(self, *keys):
+ return tuple([self[key] for key in keys])
+
+ get_all = __call__
+
+
+def drop_all_tables_from_metadata(metadata, engine_or_connection):
+ from . import engines
+
+ def go(connection):
+ engines.testing_reaper.prepare_for_drop_tables(connection)
+
+ if not connection.dialect.supports_alter:
+ from . import assertions
+
+ with assertions.expect_warnings(
+ "Can't sort tables", assert_=False
+ ):
+ metadata.drop_all(connection)
+ else:
+ metadata.drop_all(connection)
+
+ if not isinstance(engine_or_connection, Connection):
+ with engine_or_connection.begin() as connection:
+ go(connection)
+ else:
+ go(engine_or_connection)
+
+
+def drop_all_tables(engine, inspector, schema=None, include_names=None):
+
+ if include_names is not None:
+ include_names = set(include_names)
+
+ with engine.begin() as conn:
+ for tname, fkcs in reversed(
+ inspector.get_sorted_table_and_fkc_names(schema=schema)
+ ):
+ if tname:
+ if include_names is not None and tname not in include_names:
+ continue
+ conn.execute(
+ DropTable(Table(tname, MetaData(), schema=schema))
+ )
+ elif fkcs:
+ if not engine.dialect.supports_alter:
+ continue
+ for tname, fkc in fkcs:
+ if (
+ include_names is not None
+ and tname not in include_names
+ ):
+ continue
+ tb = Table(
+ tname,
+ MetaData(),
+ Column("x", Integer),
+ Column("y", Integer),
+ schema=schema,
+ )
+ conn.execute(
+ DropConstraint(
+ ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
+ )
+ )
+
+
+def teardown_events(event_cls):
+ @decorator
+ def decorate(fn, *arg, **kw):
+ try:
+ return fn(*arg, **kw)
+ finally:
+ event_cls._clear()
+
+ return decorate
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
new file mode 100644
index 0000000..3e78387
--- /dev/null
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -0,0 +1,82 @@
+# testing/warnings.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import warnings
+
+from . import assertions
+from .. import exc as sa_exc
+from ..util.langhelpers import _warnings_warn
+
+
+class SATestSuiteWarning(Warning):
+ """warning for a condition detected during tests that is non-fatal
+
+ Currently outside of SAWarning so that we can work around tools like
+ Alembic doing the wrong thing with warnings.
+
+ """
+
+
+def warn_test_suite(message):
+ _warnings_warn(message, category=SATestSuiteWarning)
+
+
+def setup_filters():
+ """Set global warning behavior for the test suite."""
+
+ # TODO: at this point we can use the normal pytest warnings plugin,
+ # if we decide the test suite can be linked to pytest only
+
+ origin = r"^(?:test|sqlalchemy)\..*"
+
+ warnings.filterwarnings(
+ "ignore", category=sa_exc.SAPendingDeprecationWarning
+ )
+ warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+ warnings.filterwarnings("error", category=sa_exc.SAWarning)
+
+ warnings.filterwarnings("always", category=SATestSuiteWarning)
+
+ warnings.filterwarnings(
+ "error", category=DeprecationWarning, module=origin
+ )
+
+ # ignore things that are deprecated *as of* 2.0 :)
+ warnings.filterwarnings(
+ "ignore",
+ category=sa_exc.SADeprecationWarning,
+ message=r".*\(deprecated since: 2.0\)$",
+ )
+ warnings.filterwarnings(
+ "ignore",
+ category=sa_exc.SADeprecationWarning,
+ message=r"^The (Sybase|firebird) dialect is deprecated and will be",
+ )
+
+ try:
+ import pytest
+ except ImportError:
+ pass
+ else:
+ warnings.filterwarnings(
+ "once", category=pytest.PytestDeprecationWarning, module=origin
+ )
+
+
+def assert_warnings(fn, warning_msgs, regex=False):
+ """Assert that each of the given warnings are emitted by fn.
+
+ Deprecated. Please use assertions.expect_warnings().
+
+ """
+
+ with assertions._expect_warnings(
+ sa_exc.SAWarning, warning_msgs, regex=regex
+ ):
+ return fn()
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
new file mode 100644
index 0000000..07263c5
--- /dev/null
+++ b/lib/sqlalchemy/types.py
@@ -0,0 +1,119 @@
+# types.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Compatibility namespace for sqlalchemy.sql.types.
+
+"""
+
+__all__ = [
+ "TypeEngine",
+ "TypeDecorator",
+ "UserDefinedType",
+ "ExternalType",
+ "INT",
+ "CHAR",
+ "VARCHAR",
+ "NCHAR",
+ "NVARCHAR",
+ "TEXT",
+ "Text",
+ "FLOAT",
+ "NUMERIC",
+ "REAL",
+ "DECIMAL",
+ "TIMESTAMP",
+ "DATETIME",
+ "CLOB",
+ "BLOB",
+ "BINARY",
+ "VARBINARY",
+ "BOOLEAN",
+ "BIGINT",
+ "SMALLINT",
+ "INTEGER",
+ "DATE",
+ "TIME",
+ "TupleType",
+ "String",
+ "Integer",
+ "SmallInteger",
+ "BigInteger",
+ "Numeric",
+ "Float",
+ "DateTime",
+ "Date",
+ "Time",
+ "LargeBinary",
+ "Boolean",
+ "Unicode",
+ "Concatenable",
+ "UnicodeText",
+ "PickleType",
+ "Interval",
+ "Enum",
+ "Indexable",
+ "ARRAY",
+ "JSON",
+]
+
+from .sql.sqltypes import _Binary
+from .sql.sqltypes import ARRAY
+from .sql.sqltypes import BIGINT
+from .sql.sqltypes import BigInteger
+from .sql.sqltypes import BINARY
+from .sql.sqltypes import BLOB
+from .sql.sqltypes import BOOLEAN
+from .sql.sqltypes import Boolean
+from .sql.sqltypes import CHAR
+from .sql.sqltypes import CLOB
+from .sql.sqltypes import Concatenable
+from .sql.sqltypes import DATE
+from .sql.sqltypes import Date
+from .sql.sqltypes import DATETIME
+from .sql.sqltypes import DateTime
+from .sql.sqltypes import DECIMAL
+from .sql.sqltypes import Enum
+from .sql.sqltypes import FLOAT
+from .sql.sqltypes import Float
+from .sql.sqltypes import Indexable
+from .sql.sqltypes import INT
+from .sql.sqltypes import INTEGER
+from .sql.sqltypes import Integer
+from .sql.sqltypes import Interval
+from .sql.sqltypes import JSON
+from .sql.sqltypes import LargeBinary
+from .sql.sqltypes import MatchType
+from .sql.sqltypes import NCHAR
+from .sql.sqltypes import NULLTYPE
+from .sql.sqltypes import NullType
+from .sql.sqltypes import NUMERIC
+from .sql.sqltypes import Numeric
+from .sql.sqltypes import NVARCHAR
+from .sql.sqltypes import PickleType
+from .sql.sqltypes import REAL
+from .sql.sqltypes import SchemaType
+from .sql.sqltypes import SMALLINT
+from .sql.sqltypes import SmallInteger
+from .sql.sqltypes import String
+from .sql.sqltypes import STRINGTYPE
+from .sql.sqltypes import TEXT
+from .sql.sqltypes import Text
+from .sql.sqltypes import TIME
+from .sql.sqltypes import Time
+from .sql.sqltypes import TIMESTAMP
+from .sql.sqltypes import TupleType
+from .sql.sqltypes import Unicode
+from .sql.sqltypes import UnicodeText
+from .sql.sqltypes import VARBINARY
+from .sql.sqltypes import VARCHAR
+from .sql.type_api import adapt_type
+from .sql.type_api import ExternalType
+from .sql.type_api import to_instance
+from .sql.type_api import TypeDecorator
+from .sql.type_api import TypeEngine
+from .sql.type_api import UserDefinedType
+from .sql.type_api import Variant
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
new file mode 100644
index 0000000..33427e3
--- /dev/null
+++ b/lib/sqlalchemy/util/__init__.py
@@ -0,0 +1,175 @@
+# util/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from collections import defaultdict
+from contextlib import contextmanager
+from functools import partial
+from functools import update_wrapper
+
+from ._collections import coerce_generator_arg
+from ._collections import coerce_to_immutabledict
+from ._collections import collections_abc
+from ._collections import column_dict
+from ._collections import column_set
+from ._collections import EMPTY_DICT
+from ._collections import EMPTY_SET
+from ._collections import FacadeDict
+from ._collections import flatten_iterator
+from ._collections import has_dupes
+from ._collections import has_intersection
+from ._collections import IdentitySet
+from ._collections import ImmutableContainer
+from ._collections import immutabledict
+from ._collections import ImmutableProperties
+from ._collections import LRUCache
+from ._collections import ordered_column_set
+from ._collections import OrderedDict
+from ._collections import OrderedIdentitySet
+from ._collections import OrderedProperties
+from ._collections import OrderedSet
+from ._collections import PopulateDict
+from ._collections import Properties
+from ._collections import ScopedRegistry
+from ._collections import sort_dictionary
+from ._collections import ThreadLocalRegistry
+from ._collections import to_column_set
+from ._collections import to_list
+from ._collections import to_set
+from ._collections import unique_list
+from ._collections import UniqueAppender
+from ._collections import update_copy
+from ._collections import WeakPopulateDict
+from ._collections import WeakSequence
+from ._preloaded import preload_module
+from ._preloaded import preloaded
+from .compat import ABC
+from .compat import arm
+from .compat import b
+from .compat import b64decode
+from .compat import b64encode
+from .compat import binary_type
+from .compat import binary_types
+from .compat import byte_buffer
+from .compat import callable
+from .compat import cmp
+from .compat import cpython
+from .compat import dataclass_fields
+from .compat import decode_backslashreplace
+from .compat import dottedgetter
+from .compat import has_refcount_gc
+from .compat import inspect_getfullargspec
+from .compat import int_types
+from .compat import iterbytes
+from .compat import itertools_filter
+from .compat import itertools_filterfalse
+from .compat import local_dataclass_fields
+from .compat import namedtuple
+from .compat import next
+from .compat import nullcontext
+from .compat import osx
+from .compat import parse_qsl
+from .compat import perf_counter
+from .compat import pickle
+from .compat import print_
+from .compat import py2k
+from .compat import py311
+from .compat import py37
+from .compat import py38
+from .compat import py39
+from .compat import py3k
+from .compat import pypy
+from .compat import quote_plus
+from .compat import raise_
+from .compat import raise_from_cause
+from .compat import reduce
+from .compat import reraise
+from .compat import string_types
+from .compat import StringIO
+from .compat import text_type
+from .compat import threading
+from .compat import timezone
+from .compat import TYPE_CHECKING
+from .compat import u
+from .compat import ue
+from .compat import unquote
+from .compat import unquote_plus
+from .compat import win32
+from .compat import with_metaclass
+from .compat import zip_longest
+from .concurrency import asyncio
+from .concurrency import await_fallback
+from .concurrency import await_only
+from .concurrency import greenlet_spawn
+from .concurrency import is_exit_exception
+from .deprecations import deprecated
+from .deprecations import deprecated_20
+from .deprecations import deprecated_20_cls
+from .deprecations import deprecated_cls
+from .deprecations import deprecated_params
+from .deprecations import inject_docstring_text
+from .deprecations import moved_20
+from .deprecations import SQLALCHEMY_WARN_20
+from .deprecations import warn_deprecated
+from .deprecations import warn_deprecated_20
+from .langhelpers import add_parameter_text
+from .langhelpers import as_interface
+from .langhelpers import asbool
+from .langhelpers import asint
+from .langhelpers import assert_arg_type
+from .langhelpers import attrsetter
+from .langhelpers import bool_or_str
+from .langhelpers import chop_traceback
+from .langhelpers import class_hierarchy
+from .langhelpers import classproperty
+from .langhelpers import clsname_as_plain_name
+from .langhelpers import coerce_kw_type
+from .langhelpers import constructor_copy
+from .langhelpers import constructor_key
+from .langhelpers import counter
+from .langhelpers import create_proxy_methods
+from .langhelpers import decode_slice
+from .langhelpers import decorator
+from .langhelpers import dictlike_iteritems
+from .langhelpers import duck_type_collection
+from .langhelpers import ellipses_string
+from .langhelpers import EnsureKWArgType
+from .langhelpers import format_argspec_init
+from .langhelpers import format_argspec_plus
+from .langhelpers import generic_repr
+from .langhelpers import get_callable_argspec
+from .langhelpers import get_cls_kwargs
+from .langhelpers import get_func_kwargs
+from .langhelpers import getargspec_init
+from .langhelpers import has_compiled_ext
+from .langhelpers import HasMemoized
+from .langhelpers import hybridmethod
+from .langhelpers import hybridproperty
+from .langhelpers import iterate_attributes
+from .langhelpers import map_bits
+from .langhelpers import md5_hex
+from .langhelpers import memoized_instancemethod
+from .langhelpers import memoized_property
+from .langhelpers import MemoizedSlots
+from .langhelpers import method_is_overridden
+from .langhelpers import methods_equivalent
+from .langhelpers import monkeypatch_proxied_specials
+from .langhelpers import NoneType
+from .langhelpers import only_once
+from .langhelpers import PluginLoader
+from .langhelpers import portable_instancemethod
+from .langhelpers import quoted_token_parser
+from .langhelpers import safe_reraise
+from .langhelpers import set_creation_order
+from .langhelpers import string_or_unprintable
+from .langhelpers import symbol
+from .langhelpers import unbound_method_to_callable
+from .langhelpers import walk_subclasses
+from .langhelpers import warn
+from .langhelpers import warn_exception
+from .langhelpers import warn_limited
+from .langhelpers import wrap_callable
diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py
new file mode 100644
index 0000000..8e21830
--- /dev/null
+++ b/lib/sqlalchemy/util/_collections.py
@@ -0,0 +1,1089 @@
+# util/_collections.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Collection classes and helpers."""
+
+from __future__ import absolute_import
+
+import operator
+import types
+import weakref
+
+from .compat import binary_types
+from .compat import collections_abc
+from .compat import itertools_filterfalse
+from .compat import py2k
+from .compat import py37
+from .compat import string_types
+from .compat import threading
+
+
+EMPTY_SET = frozenset()
+
+
+class ImmutableContainer(object):
+ def _immutable(self, *arg, **kw):
+ raise TypeError("%s object is immutable" % self.__class__.__name__)
+
+ __delitem__ = __setitem__ = __setattr__ = _immutable
+
+
+def _immutabledict_py_fallback():
+ class immutabledict(ImmutableContainer, dict):
+
+ clear = (
+ pop
+ ) = popitem = setdefault = update = ImmutableContainer._immutable
+
+ def __new__(cls, *args):
+ new = dict.__new__(cls)
+ dict.__init__(new, *args)
+ return new
+
+ def __init__(self, *args):
+ pass
+
+ def __reduce__(self):
+ return _immutabledict_reconstructor, (dict(self),)
+
+ def union(self, __d=None):
+ if not __d:
+ return self
+
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ dict.update(new, __d)
+ return new
+
+ def _union_w_kw(self, __d=None, **kw):
+ # not sure if C version works correctly w/ this yet
+ if not __d and not kw:
+ return self
+
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ if __d:
+ dict.update(new, __d)
+ dict.update(new, kw)
+ return new
+
+ def merge_with(self, *dicts):
+ new = None
+ for d in dicts:
+ if d:
+ if new is None:
+ new = dict.__new__(self.__class__)
+ dict.__init__(new, self)
+ dict.update(new, d)
+ if new is None:
+ return self
+
+ return new
+
+ def __repr__(self):
+ return "immutabledict(%s)" % dict.__repr__(self)
+
+ return immutabledict
+
+
+try:
+ from sqlalchemy.cimmutabledict import immutabledict
+
+ collections_abc.Mapping.register(immutabledict)
+
+except ImportError:
+ immutabledict = _immutabledict_py_fallback()
+
+ def _immutabledict_reconstructor(*arg):
+ """do the pickle dance"""
+ return immutabledict(*arg)
+
+
+def coerce_to_immutabledict(d):
+ if not d:
+ return EMPTY_DICT
+ elif isinstance(d, immutabledict):
+ return d
+ else:
+ return immutabledict(d)
+
+
+EMPTY_DICT = immutabledict()
+
+
+class FacadeDict(ImmutableContainer, dict):
+ """A dictionary that is not publicly mutable."""
+
+ clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
+
+ def __new__(cls, *args):
+ new = dict.__new__(cls)
+ return new
+
+ def copy(self):
+ raise NotImplementedError(
+ "an immutabledict shouldn't need to be copied. use dict(d) "
+ "if you need a mutable dictionary."
+ )
+
+ def __reduce__(self):
+ return FacadeDict, (dict(self),)
+
+ def _insert_item(self, key, value):
+ """insert an item into the dictionary directly."""
+ dict.__setitem__(self, key, value)
+
+ def __repr__(self):
+ return "FacadeDict(%s)" % dict.__repr__(self)
+
+
+class Properties(object):
+ """Provide a __getattr__/__setattr__ interface over a dict."""
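+    # e.g. (illustrative): p = Properties({"x": 1}); p.x -> 1, and p.y = 2
+    # writes "y" through to the underlying dict; iteration yields the values.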
+
+ __slots__ = ("_data",)
+
+ def __init__(self, data):
+ object.__setattr__(self, "_data", data)
+
+ def __len__(self):
+ return len(self._data)
+
+ def __iter__(self):
+ return iter(list(self._data.values()))
+
+ def __dir__(self):
+ return dir(super(Properties, self)) + [
+ str(k) for k in self._data.keys()
+ ]
+
+ def __add__(self, other):
+ return list(self) + list(other)
+
+ def __setitem__(self, key, obj):
+ self._data[key] = obj
+
+ def __getitem__(self, key):
+ return self._data[key]
+
+ def __delitem__(self, key):
+ del self._data[key]
+
+ def __setattr__(self, key, obj):
+ self._data[key] = obj
+
+ def __getstate__(self):
+ return {"_data": self._data}
+
+ def __setstate__(self, state):
+ object.__setattr__(self, "_data", state["_data"])
+
+ def __getattr__(self, key):
+ try:
+ return self._data[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __contains__(self, key):
+ return key in self._data
+
+ def as_immutable(self):
+ """Return an immutable proxy for this :class:`.Properties`."""
+
+ return ImmutableProperties(self._data)
+
+ def update(self, value):
+ self._data.update(value)
+
+ def get(self, key, default=None):
+ if key in self:
+ return self[key]
+ else:
+ return default
+
+ def keys(self):
+ return list(self._data)
+
+ def values(self):
+ return list(self._data.values())
+
+ def items(self):
+ return list(self._data.items())
+
+ def has_key(self, key):
+ return key in self._data
+
+ def clear(self):
+ self._data.clear()
+
+
+class OrderedProperties(Properties):
+ """Provide a __getattr__/__setattr__ interface with an OrderedDict
+ as backing store."""
+
+ __slots__ = ()
+
+ def __init__(self):
+ Properties.__init__(self, OrderedDict())
+
+
+class ImmutableProperties(ImmutableContainer, Properties):
+ """Provide immutable dict/object attribute to an underlying dictionary."""
+
+ __slots__ = ()
+
+
+def _ordered_dictionary_sort(d, key=None):
+ """Sort an OrderedDict in-place."""
+
+ items = [(k, d[k]) for k in sorted(d, key=key)]
+
+ d.clear()
+
+ d.update(items)
+
+
+if py37:
+ OrderedDict = dict
+ sort_dictionary = _ordered_dictionary_sort
+
+else:
+ # prevent sort_dictionary from being used against a plain dictionary
+ # for Python < 3.7
+
+ def sort_dictionary(d, key=None):
+ """Sort an OrderedDict in place."""
+
+ d._ordered_dictionary_sort(key=key)
+
+ class OrderedDict(dict):
+ """Dictionary that maintains insertion order.
+
+ Superseded by Python dict as of Python 3.7
+
+ """
+
+ __slots__ = ("_list",)
+
+ def _ordered_dictionary_sort(self, key=None):
+ _ordered_dictionary_sort(self, key=key)
+
+ def __reduce__(self):
+ return OrderedDict, (self.items(),)
+
+ def __init__(self, ____sequence=None, **kwargs):
+ self._list = []
+ if ____sequence is None:
+ if kwargs:
+ self.update(**kwargs)
+ else:
+ self.update(____sequence, **kwargs)
+
+ def clear(self):
+ self._list = []
+ dict.clear(self)
+
+ def copy(self):
+ return self.__copy__()
+
+ def __copy__(self):
+ return OrderedDict(self)
+
+ def update(self, ____sequence=None, **kwargs):
+ if ____sequence is not None:
+ if hasattr(____sequence, "keys"):
+ for key in ____sequence.keys():
+ self.__setitem__(key, ____sequence[key])
+ else:
+ for key, value in ____sequence:
+ self[key] = value
+ if kwargs:
+ self.update(kwargs)
+
+ def setdefault(self, key, value):
+ if key not in self:
+ self.__setitem__(key, value)
+ return value
+ else:
+ return self.__getitem__(key)
+
+ def __iter__(self):
+ return iter(self._list)
+
+ def keys(self):
+ return list(self)
+
+ def values(self):
+ return [self[key] for key in self._list]
+
+ def items(self):
+ return [(key, self[key]) for key in self._list]
+
+ if py2k:
+
+ def itervalues(self):
+ return iter(self.values())
+
+ def iterkeys(self):
+ return iter(self)
+
+ def iteritems(self):
+ return iter(self.items())
+
+ def __setitem__(self, key, obj):
+ if key not in self:
+ try:
+ self._list.append(key)
+ except AttributeError:
+ # work around Python pickle loads() with
+ # dict subclass (seems to ignore __setstate__?)
+ self._list = [key]
+ dict.__setitem__(self, key, obj)
+
+ def __delitem__(self, key):
+ dict.__delitem__(self, key)
+ self._list.remove(key)
+
+ def pop(self, key, *default):
+ present = key in self
+ value = dict.pop(self, key, *default)
+ if present:
+ self._list.remove(key)
+ return value
+
+ def popitem(self):
+ item = dict.popitem(self)
+ self._list.remove(item[0])
+ return item
+
+
+class OrderedSet(set):
+ def __init__(self, d=None):
+ set.__init__(self)
+ if d is not None:
+ self._list = unique_list(d)
+ set.update(self, self._list)
+ else:
+ self._list = []
+
+ def add(self, element):
+ if element not in self:
+ self._list.append(element)
+ set.add(self, element)
+
+ def remove(self, element):
+ set.remove(self, element)
+ self._list.remove(element)
+
+ def insert(self, pos, element):
+ if element not in self:
+ self._list.insert(pos, element)
+ set.add(self, element)
+
+ def discard(self, element):
+ if element in self:
+ self._list.remove(element)
+ set.remove(self, element)
+
+ def clear(self):
+ set.clear(self)
+ self._list = []
+
+ def __getitem__(self, key):
+ return self._list[key]
+
+ def __iter__(self):
+ return iter(self._list)
+
+ def __add__(self, other):
+ return self.union(other)
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self._list)
+
+ __str__ = __repr__
+
+ def update(self, iterable):
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ set.add(self, e)
+ return self
+
+ __ior__ = update
+
+ def union(self, other):
+ result = self.__class__(self)
+ result.update(other)
+ return result
+
+ __or__ = union
+
+ def intersection(self, other):
+ other = set(other)
+ return self.__class__(a for a in self if a in other)
+
+ __and__ = intersection
+
+ def symmetric_difference(self, other):
+ other = set(other)
+ result = self.__class__(a for a in self if a not in other)
+ result.update(a for a in other if a not in self)
+ return result
+
+ __xor__ = symmetric_difference
+
+ def difference(self, other):
+ other = set(other)
+ return self.__class__(a for a in self if a not in other)
+
+ __sub__ = difference
+
+ def intersection_update(self, other):
+ other = set(other)
+ set.intersection_update(self, other)
+ self._list = [a for a in self._list if a in other]
+ return self
+
+ __iand__ = intersection_update
+
+ def symmetric_difference_update(self, other):
+ set.symmetric_difference_update(self, other)
+ self._list = [a for a in self._list if a in self]
+ self._list += [a for a in other._list if a in self]
+ return self
+
+ __ixor__ = symmetric_difference_update
+
+ def difference_update(self, other):
+ set.difference_update(self, other)
+ self._list = [a for a in self._list if a in self]
+ return self
+
+ __isub__ = difference_update
+
+
+class IdentitySet(object):
+ """A set that considers only object id() for uniqueness.
+
+    This strategy has edge cases for builtin types: it's possible to have
+ two 'foo' strings in one of these sets, for example. Use sparingly.
+
+ """
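+    # e.g. (illustrative): s = IdentitySet(); a = [1]; b = [1]
+    # s.add(a); s.add(b)  ->  len(s) == 2, since a and b are equal but
+    # distinct objects (and, unlike set, no hashability is required).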
+
+ def __init__(self, iterable=None):
+ self._members = dict()
+ if iterable:
+ self.update(iterable)
+
+ def add(self, value):
+ self._members[id(value)] = value
+
+ def __contains__(self, value):
+ return id(value) in self._members
+
+ def remove(self, value):
+ del self._members[id(value)]
+
+ def discard(self, value):
+ try:
+ self.remove(value)
+ except KeyError:
+ pass
+
+ def pop(self):
+ try:
+ pair = self._members.popitem()
+ return pair[1]
+ except KeyError:
+ raise KeyError("pop from an empty set")
+
+ def clear(self):
+ self._members.clear()
+
+ def __cmp__(self, other):
+ raise TypeError("cannot compare sets using cmp()")
+
+ def __eq__(self, other):
+ if isinstance(other, IdentitySet):
+ return self._members == other._members
+ else:
+ return False
+
+ def __ne__(self, other):
+ if isinstance(other, IdentitySet):
+ return self._members != other._members
+ else:
+ return True
+
+ def issubset(self, iterable):
+ if isinstance(iterable, self.__class__):
+ other = iterable
+ else:
+ other = self.__class__(iterable)
+
+ if len(self) > len(other):
+ return False
+ for m in itertools_filterfalse(
+ other._members.__contains__, iter(self._members.keys())
+ ):
+ return False
+ return True
+
+ def __le__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.issubset(other)
+
+ def __lt__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return len(self) < len(other) and self.issubset(other)
+
+ def issuperset(self, iterable):
+ if isinstance(iterable, self.__class__):
+ other = iterable
+ else:
+ other = self.__class__(iterable)
+
+ if len(self) < len(other):
+ return False
+
+ for m in itertools_filterfalse(
+ self._members.__contains__, iter(other._members.keys())
+ ):
+ return False
+ return True
+
+ def __ge__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.issuperset(other)
+
+ def __gt__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return len(self) > len(other) and self.issuperset(other)
+
+ def union(self, iterable):
+ result = self.__class__()
+ members = self._members
+ result._members.update(members)
+ result._members.update((id(obj), obj) for obj in iterable)
+ return result
+
+ def __or__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.union(other)
+
+ def update(self, iterable):
+ self._members.update((id(obj), obj) for obj in iterable)
+
+ def __ior__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ self.update(other)
+ return self
+
+ def difference(self, iterable):
+ result = self.__class__()
+ members = self._members
+ if isinstance(iterable, self.__class__):
+ other = set(iterable._members.keys())
+ else:
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ ((k, v) for k, v in members.items() if k not in other)
+ )
+ return result
+
+ def __sub__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.difference(other)
+
+ def difference_update(self, iterable):
+ self._members = self.difference(iterable)._members
+
+ def __isub__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ self.difference_update(other)
+ return self
+
+ def intersection(self, iterable):
+ result = self.__class__()
+ members = self._members
+ if isinstance(iterable, self.__class__):
+ other = set(iterable._members.keys())
+ else:
+ other = {id(obj) for obj in iterable}
+ result._members.update(
+ (k, v) for k, v in members.items() if k in other
+ )
+ return result
+
+ def __and__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.intersection(other)
+
+ def intersection_update(self, iterable):
+ self._members = self.intersection(iterable)._members
+
+ def __iand__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ self.intersection_update(other)
+ return self
+
+ def symmetric_difference(self, iterable):
+ result = self.__class__()
+ members = self._members
+ if isinstance(iterable, self.__class__):
+ other = iterable._members
+ else:
+ other = {id(obj): obj for obj in iterable}
+ result._members.update(
+ ((k, v) for k, v in members.items() if k not in other)
+ )
+ result._members.update(
+ ((k, v) for k, v in other.items() if k not in members)
+ )
+ return result
+
+ def __xor__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+ return self.symmetric_difference(other)
+
+ def symmetric_difference_update(self, iterable):
+ self._members = self.symmetric_difference(iterable)._members
+
+ def __ixor__(self, other):
+ if not isinstance(other, IdentitySet):
+ return NotImplemented
+        self.symmetric_difference_update(other)
+ return self
+
+ def copy(self):
+ return type(self)(iter(self._members.values()))
+
+ __copy__ = copy
+
+ def __len__(self):
+ return len(self._members)
+
+ def __iter__(self):
+ return iter(self._members.values())
+
+ def __hash__(self):
+ raise TypeError("set objects are unhashable")
+
+ def __repr__(self):
+ return "%s(%r)" % (type(self).__name__, list(self._members.values()))
+
+
+class WeakSequence(object):
+ def __init__(self, __elements=()):
+ # adapted from weakref.WeakKeyDictionary, prevent reference
+ # cycles in the collection itself
+ def _remove(item, selfref=weakref.ref(self)):
+ self = selfref()
+ if self is not None:
+ self._storage.remove(item)
+
+ self._remove = _remove
+ self._storage = [
+ weakref.ref(element, _remove) for element in __elements
+ ]
+
+ def append(self, item):
+ self._storage.append(weakref.ref(item, self._remove))
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ return (
+ obj for obj in (ref() for ref in self._storage) if obj is not None
+ )
+
+ def __getitem__(self, index):
+ try:
+ obj = self._storage[index]
+ except KeyError:
+ raise IndexError("Index %s out of range" % index)
+ else:
+ return obj()
+
+
+class OrderedIdentitySet(IdentitySet):
+ def __init__(self, iterable=None):
+ IdentitySet.__init__(self)
+ self._members = OrderedDict()
+ if iterable:
+ for o in iterable:
+ self.add(o)
+
+
+class PopulateDict(dict):
+ """A dict which populates missing values via a creation function.
+
+ Note the creation function takes a key, unlike
+ collections.defaultdict.
+
+ """
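+    # e.g. (illustrative): d = PopulateDict(str.upper); d["abc"] -> "ABC",
+    # computed on the first miss and then stored in the dict.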
+
+ def __init__(self, creator):
+ self.creator = creator
+
+ def __missing__(self, key):
+ self[key] = val = self.creator(key)
+ return val
+
+
+class WeakPopulateDict(dict):
+ """Like PopulateDict, but assumes a self + a method and does not create
+ a reference cycle.
+
+ """
+
+ def __init__(self, creator_method):
+ self.creator = creator_method.__func__
+ weakself = creator_method.__self__
+ self.weakself = weakref.ref(weakself)
+
+ def __missing__(self, key):
+ self[key] = val = self.creator(self.weakself(), key)
+ return val
+
+
+# Define collections that are capable of storing
+# ColumnElement objects as hashable keys/elements.
+# At this point, these are mostly historical; things
+# used to be more complicated.
+column_set = set
+column_dict = dict
+ordered_column_set = OrderedSet
+
+
+_getters = PopulateDict(operator.itemgetter)
+
+_property_getters = PopulateDict(
+ lambda idx: property(operator.itemgetter(idx))
+)
+
+
+def unique_list(seq, hashfunc=None):
+ seen = set()
+ seen_add = seen.add
+ if not hashfunc:
+ return [x for x in seq if x not in seen and not seen_add(x)]
+ else:
+ return [
+ x
+ for x in seq
+ if hashfunc(x) not in seen and not seen_add(hashfunc(x))
+ ]
+
+
+class UniqueAppender(object):
+ """Appends items to a collection ensuring uniqueness.
+
+ Additional appends() of the same object are ignored. Membership is
+    determined by identity (``is``), not equality (``==``).
+ """
+
+ def __init__(self, data, via=None):
+ self.data = data
+ self._unique = {}
+ if via:
+ self._data_appender = getattr(data, via)
+ elif hasattr(data, "append"):
+ self._data_appender = data.append
+ elif hasattr(data, "add"):
+ self._data_appender = data.add
+
+ def append(self, item):
+ id_ = id(item)
+ if id_ not in self._unique:
+ self._data_appender(item)
+ self._unique[id_] = True
+
+ def __iter__(self):
+ return iter(self.data)
+
+
+def coerce_generator_arg(arg):
+ if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
+ return list(arg[0])
+ else:
+ return arg
+
+
+def to_list(x, default=None):
+ if x is None:
+ return default
+ if not isinstance(x, collections_abc.Iterable) or isinstance(
+ x, string_types + binary_types
+ ):
+ return [x]
+ elif isinstance(x, list):
+ return x
+ else:
+ return list(x)
+
+
+def has_intersection(set_, iterable):
+ r"""return True if any items of set\_ are present in iterable.
+
+ Goes through special effort to ensure __hash__ is not called
+ on items in iterable that don't support it.
+
+ """
+ # TODO: optimize, write in C, etc.
+ return bool(set_.intersection([i for i in iterable if i.__hash__]))
+
+
+def to_set(x):
+ if x is None:
+ return set()
+ if not isinstance(x, set):
+ return set(to_list(x))
+ else:
+ return x
+
+
+def to_column_set(x):
+ if x is None:
+ return column_set()
+ if not isinstance(x, column_set):
+ return column_set(to_list(x))
+ else:
+ return x
+
+
+def update_copy(d, _new=None, **kw):
+ """Copy the given dict and update with the given values."""
+
+ d = d.copy()
+ if _new:
+ d.update(_new)
+ d.update(**kw)
+ return d
+
+
+def flatten_iterator(x):
+ """Given an iterator of which further sub-elements may also be
+ iterators, flatten the sub-elements into a single iterator.
+
+ """
+ for elem in x:
+ if not isinstance(elem, str) and hasattr(elem, "__iter__"):
+ for y in flatten_iterator(elem):
+ yield y
+ else:
+ yield elem
+
+
+class LRUCache(dict):
+ """Dictionary with 'squishy' removal of least
+ recently used items.
+
+ Note that either get() or [] should be used here, but
+    generally it's not safe to do an "in" check first, as the dictionary
+ can change subsequent to that call.
+
+ """
+
+ __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"
+
+ def __init__(self, capacity=100, threshold=0.5, size_alert=None):
+ self.capacity = capacity
+ self.threshold = threshold
+ self.size_alert = size_alert
+ self._counter = 0
+ self._mutex = threading.Lock()
+
+ def _inc_counter(self):
+ self._counter += 1
+ return self._counter
+
+ def get(self, key, default=None):
+ item = dict.get(self, key, default)
+ if item is not default:
+ item[2] = self._inc_counter()
+ return item[1]
+ else:
+ return default
+
+ def __getitem__(self, key):
+ item = dict.__getitem__(self, key)
+ item[2] = self._inc_counter()
+ return item[1]
+
+ def values(self):
+ return [i[1] for i in dict.values(self)]
+
+ def setdefault(self, key, value):
+ if key in self:
+ return self[key]
+ else:
+ self[key] = value
+ return value
+
+ def __setitem__(self, key, value):
+ item = dict.get(self, key)
+ if item is None:
+ item = [key, value, self._inc_counter()]
+ dict.__setitem__(self, key, item)
+ else:
+ item[1] = value
+ self._manage_size()
+
+ @property
+ def size_threshold(self):
+ return self.capacity + self.capacity * self.threshold
+
+ def _manage_size(self):
+ if not self._mutex.acquire(False):
+ return
+ try:
+ size_alert = bool(self.size_alert)
+ while len(self) > self.capacity + self.capacity * self.threshold:
+ if size_alert:
+ size_alert = False
+ self.size_alert(self)
+ by_counter = sorted(
+ dict.values(self), key=operator.itemgetter(2), reverse=True
+ )
+ for item in by_counter[self.capacity :]:
+ try:
+ del self[item[0]]
+ except KeyError:
+ # deleted elsewhere; skip
+ continue
+ finally:
+ self._mutex.release()
+
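+# Rough usage sketch (values are assumptions for illustration): eviction is
+# lazy, kicking in only once len() exceeds capacity + capacity * threshold,
+# and it then trims back to the `capacity` most recently used entries.
+#
+#     cache = LRUCache(capacity=2, threshold=0.5)
+#     cache["a"] = 1; cache["b"] = 2; cache["c"] = 3  # len 3, no eviction yet
+#     cache["a"]                                      # refreshes "a"
+#     cache["d"] = 4        # len 4 > 3: only "d" and "a" survive the trim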
+
+class ScopedRegistry(object):
+ """A Registry that can store one or multiple instances of a single
+ class on the basis of a "scope" function.
+
+ The object implements ``__call__`` as the "getter", so by
+ calling ``myregistry()`` the contained object is returned
+ for the current scope.
+
+ :param createfunc:
+ a callable that returns a new object to be placed in the registry
+
+ :param scopefunc:
+ a callable that will return a key to store/retrieve an object.
+ """
+
+ def __init__(self, createfunc, scopefunc):
+ """Construct a new :class:`.ScopedRegistry`.
+
+ :param createfunc: A creation function that will generate
+ a new value for the current scope, if none is present.
+
+ :param scopefunc: A function that returns a hashable
+          token representing the current scope (such as the current
+          thread identifier).
+
+ """
+ self.createfunc = createfunc
+ self.scopefunc = scopefunc
+ self.registry = {}
+
+ def __call__(self):
+ key = self.scopefunc()
+ try:
+ return self.registry[key]
+ except KeyError:
+ return self.registry.setdefault(key, self.createfunc())
+
+ def has(self):
+ """Return True if an object is present in the current scope."""
+
+ return self.scopefunc() in self.registry
+
+ def set(self, obj):
+ """Set the value for the current scope."""
+
+ self.registry[self.scopefunc()] = obj
+
+ def clear(self):
+ """Clear the current scope, if any."""
+
+ try:
+ del self.registry[self.scopefunc()]
+ except KeyError:
+ pass
+
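+# Hedged example of the scoping contract (createfunc/scopefunc chosen for
+# illustration): objects are keyed by whatever scopefunc() returns, here the
+# current thread ident, so each thread sees its own instance.
+#
+#     import threading
+#     registry = ScopedRegistry(dict, threading.get_ident)
+#     registry()                  # creates and caches a dict for this thread
+#     registry() is registry()    # True while within the same scope
+#     registry.has()              # True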
+
+class ThreadLocalRegistry(ScopedRegistry):
+ """A :class:`.ScopedRegistry` that uses a ``threading.local()``
+ variable for storage.
+
+ """
+
+ def __init__(self, createfunc):
+ self.createfunc = createfunc
+ self.registry = threading.local()
+
+ def __call__(self):
+ try:
+ return self.registry.value
+ except AttributeError:
+ val = self.registry.value = self.createfunc()
+ return val
+
+ def has(self):
+ return hasattr(self.registry, "value")
+
+ def set(self, obj):
+ self.registry.value = obj
+
+ def clear(self):
+ try:
+ del self.registry.value
+ except AttributeError:
+ pass
+
+
+def has_dupes(sequence, target):
+    """Given a sequence and a search object, return True if the object
+    appears more than once, False if it appears zero or one times.
+
+
+ """
+ # compare to .index version below, this version introduces less function
+ # overhead and is usually the same speed. At 15000 items (way bigger than
+ # a relationship-bound collection in memory usually is) it begins to
+ # fall behind the other version only by microseconds.
+ c = 0
+ for item in sequence:
+ if item is target:
+ c += 1
+ if c > 1:
+ return True
+ return False
+
+
+# .index version. the two __contains__ calls as well
+# as .index() and isinstance() slow this down.
+# def has_dupes(sequence, target):
+# if target not in sequence:
+# return False
+# elif not isinstance(sequence, collections_abc.Sequence):
+# return False
+#
+# idx = sequence.index(target)
+# return target in sequence[idx + 1:]
diff --git a/lib/sqlalchemy/util/_compat_py3k.py b/lib/sqlalchemy/util/_compat_py3k.py
new file mode 100644
index 0000000..ce659a4
--- /dev/null
+++ b/lib/sqlalchemy/util/_compat_py3k.py
@@ -0,0 +1,67 @@
+# util/_compat_py3k.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from functools import wraps
+
+# vendored from py3.7
+
+
+class _AsyncGeneratorContextManager:
+ """Helper for @asynccontextmanager."""
+
+ def __init__(self, func, args, kwds):
+ self.gen = func(*args, **kwds)
+ self.func, self.args, self.kwds = func, args, kwds
+ doc = getattr(func, "__doc__", None)
+ if doc is None:
+ doc = type(self).__doc__
+ self.__doc__ = doc
+
+ async def __aenter__(self):
+ try:
+ return await self.gen.__anext__()
+ except StopAsyncIteration:
+ raise RuntimeError("generator didn't yield") from None
+
+ async def __aexit__(self, typ, value, traceback):
+ if typ is None:
+ try:
+ await self.gen.__anext__()
+ except StopAsyncIteration:
+ return
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ if value is None:
+ value = typ()
+ # See _GeneratorContextManager.__exit__ for comments on subtleties
+ # in this implementation
+ try:
+ await self.gen.athrow(typ, value, traceback)
+ raise RuntimeError("generator didn't stop after athrow()")
+ except StopAsyncIteration as exc:
+ return exc is not value
+ except RuntimeError as exc:
+ if exc is value:
+ return False
+ if isinstance(value, (StopIteration, StopAsyncIteration)):
+ if exc.__cause__ is value:
+ return False
+ raise
+ except BaseException as exc:
+ if exc is not value:
+ raise
+
+
+# using the vendored version in all cases at the moment to establish
+# full test coverage
+def asynccontextmanager(func):
+ @wraps(func)
+ def helper(*args, **kwds):
+ return _AsyncGeneratorContextManager(func, args, kwds)
+
+ return helper
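+
+# Example of the decorator contract (user code below is an assumption;
+# open_connection() is a hypothetical coroutine used only for illustration):
+#
+#     @asynccontextmanager
+#     async def connect():
+#         conn = await open_connection()
+#         try:
+#             yield conn
+#         finally:
+#             await conn.close()
+#
+#     # async with connect() as conn: ...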
diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py
new file mode 100644
index 0000000..0b12834
--- /dev/null
+++ b/lib/sqlalchemy/util/_concurrency_py3k.py
@@ -0,0 +1,194 @@
+# util/_concurrency_py3k.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import asyncio
+import sys
+from typing import Any
+from typing import Callable
+from typing import Coroutine
+
+import greenlet
+
+from . import compat
+from .langhelpers import memoized_property
+from .. import exc
+
+# If greenlet.gr_context is present in current version of greenlet,
+# it will be set with the current context on creation.
+# Refs: https://github.com/python-greenlet/greenlet/pull/198
+_has_gr_context = hasattr(greenlet.getcurrent(), "gr_context")
+
+
+def is_exit_exception(e):
+    # note: asyncio.CancelledError is already a BaseException,
+    # so it is treated as an exit exception in any case
+ return not isinstance(e, Exception) or isinstance(
+ e, (asyncio.TimeoutError, asyncio.CancelledError)
+ )
+
+
+# implementation based on snaury gist at
+# https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef
+# Issue for context: https://github.com/python-greenlet/greenlet/issues/173
+
+
+class _AsyncIoGreenlet(greenlet.greenlet):
+ def __init__(self, fn, driver):
+ greenlet.greenlet.__init__(self, fn, driver)
+ self.driver = driver
+ if _has_gr_context:
+ self.gr_context = driver.gr_context
+
+
+def await_only(awaitable: Coroutine) -> Any:
+ """Awaits an async function in a sync method.
+
+ The sync method must be inside a :func:`greenlet_spawn` context.
+ :func:`await_only` calls cannot be nested.
+
+ :param awaitable: The coroutine to call.
+
+ """
+ # this is called in the context greenlet while running fn
+ current = greenlet.getcurrent()
+ if not isinstance(current, _AsyncIoGreenlet):
+ raise exc.MissingGreenlet(
+ "greenlet_spawn has not been called; can't call await_only() "
+ "here. Was IO attempted in an unexpected place?"
+ )
+
+    # returns control to the driver greenlet, passing it a coroutine to run.
+    # Once the awaitable is done, the driver greenlet switches back to this
+    # greenlet with the result of the awaitable, which is then returned to
+    # the caller (or raised as an error)
+ return current.driver.switch(awaitable)
+
+
+def await_fallback(awaitable: Coroutine) -> Any:
+ """Awaits an async function in a sync method.
+
+ The sync method must be inside a :func:`greenlet_spawn` context.
+ :func:`await_fallback` calls cannot be nested.
+
+ :param awaitable: The coroutine to call.
+
+ """
+ # this is called in the context greenlet while running fn
+ current = greenlet.getcurrent()
+ if not isinstance(current, _AsyncIoGreenlet):
+ loop = get_event_loop()
+ if loop.is_running():
+ raise exc.MissingGreenlet(
+ "greenlet_spawn has not been called and asyncio event "
+ "loop is already running; can't call await_fallback() here. "
+ "Was IO attempted in an unexpected place?"
+ )
+ return loop.run_until_complete(awaitable)
+
+ return current.driver.switch(awaitable)
+
+
+async def greenlet_spawn(
+ fn: Callable, *args, _require_await=False, **kwargs
+) -> Any:
+ """Runs a sync function ``fn`` in a new greenlet.
+
+ The sync function can then use :func:`await_only` to wait for async
+ functions.
+
+ :param fn: The sync callable to call.
+ :param \\*args: Positional arguments to pass to the ``fn`` callable.
+ :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable.
+ """
+
+ context = _AsyncIoGreenlet(fn, greenlet.getcurrent())
+    # runs the function synchronously in the new greenlet. If the execution
+ # is interrupted by await_only, context is not dead and result is a
+ # coroutine to wait. If the context is dead the function has
+ # returned, and its result can be returned.
+ switch_occurred = False
+ try:
+ result = context.switch(*args, **kwargs)
+ while not context.dead:
+ switch_occurred = True
+ try:
+ # wait for a coroutine from await_only and then return its
+ # result back to it.
+ value = await result
+ except BaseException:
+ # this allows an exception to be raised within
+ # the moderated greenlet so that it can continue
+ # its expected flow.
+ result = context.throw(*sys.exc_info())
+ else:
+ result = context.switch(value)
+ finally:
+ # clean up to avoid cycle resolution by gc
+ del context.driver
+ if _require_await and not switch_occurred:
+ raise exc.AwaitRequired(
+ "The current operation required an async execution but none was "
+ "detected. This will usually happen when using a non compatible "
+ "DBAPI driver. Please ensure that an async DBAPI is used."
+ )
+ return result
+
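+# Sketch of the intended round trip (driver code below is an assumption, not
+# part of this module): a sync function calls await_only(), which switches
+# control back to the greenlet_spawn() coroutine; that coroutine awaits the
+# value and switches the result back in.
+#
+#     def sync_fn():
+#         return await_only(asyncio.sleep(0, result="hello"))
+#
+#     async def main():
+#         return await greenlet_spawn(sync_fn)    # -> "hello"
+#
+#     # asyncio.run(main())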
+
+class AsyncAdaptedLock:
+ @memoized_property
+ def mutex(self):
+        # there should not be a race here for coroutines creating the
+        # new lock, as we are not using await, so there is no concurrency
+ return asyncio.Lock()
+
+ def __enter__(self):
+ # await is used to acquire the lock only after the first calling
+ # coroutine has created the mutex.
+ await_fallback(self.mutex.acquire())
+ return self
+
+ def __exit__(self, *arg, **kw):
+ self.mutex.release()
+
+
+def _util_async_run_coroutine_function(fn, *args, **kwargs):
+ """for test suite/ util only"""
+
+ loop = get_event_loop()
+ if loop.is_running():
+ raise Exception(
+ "for async run coroutine we expect that no greenlet or event "
+ "loop is running when we start out"
+ )
+ return loop.run_until_complete(fn(*args, **kwargs))
+
+
+def _util_async_run(fn, *args, **kwargs):
+ """for test suite/ util only"""
+
+ loop = get_event_loop()
+ if not loop.is_running():
+ return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs))
+ else:
+ # allow for a wrapped test function to call another
+ assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet)
+ return fn(*args, **kwargs)
+
+
+def get_event_loop():
+ """vendor asyncio.get_event_loop() for python 3.7 and above.
+
+ Python 3.10 deprecates get_event_loop() as a standalone.
+
+ """
+ if compat.py37:
+ try:
+ return asyncio.get_running_loop()
+ except RuntimeError:
+ return asyncio.get_event_loop_policy().get_event_loop()
+ else:
+ return asyncio.get_event_loop()
diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/_preloaded.py
new file mode 100644
index 0000000..1803de4
--- /dev/null
+++ b/lib/sqlalchemy/util/_preloaded.py
@@ -0,0 +1,68 @@
+# util/_preloaded.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""supplies the "preloaded" registry to resolve circular module imports at
+runtime.
+
+"""
+
+import sys
+
+from . import compat
+
+
+class _ModuleRegistry:
+ """Registry of modules to load in a package init file.
+
+ To avoid potential thread safety issues for imports that are deferred
+ in a function, like https://bugs.python.org/issue38884, these modules
+    are added to the system module cache by importing them after the package
+    has finished initialization.
+
+ A global instance is provided under the name :attr:`.preloaded`. Use
+ the function :func:`.preload_module` to register modules to load and
+ :meth:`.import_prefix` to load all the modules that start with the
+ given path.
+
+ While the modules are loaded in the global module cache, it's advisable
+    to access them using :attr:`.preloaded` to ensure that they were actually
+    registered. Each registered module is added to the instance ``__dict__``
+ in the form `<package>_<module>`, omitting ``sqlalchemy`` from the package
+ name. Example: ``sqlalchemy.sql.util`` becomes ``preloaded.sql_util``.
+ """
+
+ def __init__(self, prefix="sqlalchemy."):
+ self.module_registry = set()
+ self.prefix = prefix
+
+ def preload_module(self, *deps):
+ """Adds the specified modules to the list to load.
+
+ This method can be used both as a normal function and as a decorator.
+ No change is performed to the decorated object.
+ """
+ self.module_registry.update(deps)
+ return lambda fn: fn
+
+ def import_prefix(self, path):
+ """Resolve all the modules in the registry that start with the
+ specified path.
+ """
+ for module in self.module_registry:
+ if self.prefix:
+ key = module.split(self.prefix)[-1].replace(".", "_")
+ else:
+ key = module
+ if (
+ not path or module.startswith(path)
+ ) and key not in self.__dict__:
+ compat.import_(module, globals(), locals())
+ self.__dict__[key] = sys.modules[module]
+
+
+preloaded = _ModuleRegistry()
+preload_module = preloaded.preload_module
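+
+# Usage sketch following the docstring above (the module name is an
+# assumption for illustration):
+#
+#     @preload_module("sqlalchemy.sql.util")
+#     def some_function():
+#         sql_util = preloaded.sql_util   # available after import_prefix()
+#         ...
+#
+#     # typically invoked from the package __init__:
+#     # preloaded.import_prefix("sqlalchemy")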
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
new file mode 100644
index 0000000..21a9491
--- /dev/null
+++ b/lib/sqlalchemy/util/compat.py
@@ -0,0 +1,632 @@
+# util/compat.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Handle Python version/platform incompatibilities."""
+
+import collections
+import contextlib
+import inspect
+import operator
+import platform
+import sys
+
+py311 = sys.version_info >= (3, 11)
+py39 = sys.version_info >= (3, 9)
+py38 = sys.version_info >= (3, 8)
+py37 = sys.version_info >= (3, 7)
+py3k = sys.version_info >= (3, 0)
+py2k = sys.version_info < (3, 0)
+pypy = platform.python_implementation() == "PyPy"
+
+
+cpython = platform.python_implementation() == "CPython"
+win32 = sys.platform.startswith("win")
+osx = sys.platform.startswith("darwin")
+arm = "aarch" in platform.machine().lower()
+
+has_refcount_gc = bool(cpython)
+
+contextmanager = contextlib.contextmanager
+dottedgetter = operator.attrgetter
+namedtuple = collections.namedtuple
+next = next # noqa
+
+FullArgSpec = collections.namedtuple(
+ "FullArgSpec",
+ [
+ "args",
+ "varargs",
+ "varkw",
+ "defaults",
+ "kwonlyargs",
+ "kwonlydefaults",
+ "annotations",
+ ],
+)
+
+
+class nullcontext(object):
+ """Context manager that does no additional processing.
+
+ Vendored from Python 3.7.
+
+ """
+
+ def __init__(self, enter_result=None):
+ self.enter_result = enter_result
+
+ def __enter__(self):
+ return self.enter_result
+
+ def __exit__(self, *excinfo):
+ pass
+
+
+try:
+ import threading
+except ImportError:
+ import dummy_threading as threading # noqa
+
+
+def inspect_getfullargspec(func):
+ """Fully vendored version of getfullargspec from Python 3.3."""
+
+ if inspect.ismethod(func):
+ func = func.__func__
+ if not inspect.isfunction(func):
+ raise TypeError("{!r} is not a Python function".format(func))
+
+ co = func.__code__
+ if not inspect.iscode(co):
+ raise TypeError("{!r} is not a code object".format(co))
+
+ nargs = co.co_argcount
+ names = co.co_varnames
+ nkwargs = co.co_kwonlyargcount if py3k else 0
+ args = list(names[:nargs])
+ kwonlyargs = list(names[nargs : nargs + nkwargs])
+
+ nargs += nkwargs
+ varargs = None
+ if co.co_flags & inspect.CO_VARARGS:
+ varargs = co.co_varnames[nargs]
+ nargs = nargs + 1
+ varkw = None
+ if co.co_flags & inspect.CO_VARKEYWORDS:
+ varkw = co.co_varnames[nargs]
+
+ return FullArgSpec(
+ args,
+ varargs,
+ varkw,
+ func.__defaults__,
+ kwonlyargs,
+ func.__kwdefaults__ if py3k else None,
+ func.__annotations__ if py3k else {},
+ )
+
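+# Hedged illustration of the vendored spec (shape follows the FullArgSpec
+# namedtuple defined above):
+#
+#     def f(a, b=1, *args, c, **kw): pass
+#     inspect_getfullargspec(f)
+#     # FullArgSpec(args=['a', 'b'], varargs='args', varkw='kw',
+#     #             defaults=(1,), kwonlyargs=['c'], kwonlydefaults=None,
+#     #             annotations={})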
+
+if py38:
+ from importlib import metadata as importlib_metadata
+else:
+ import importlib_metadata # noqa
+
+
+def importlib_metadata_get(group):
+ ep = importlib_metadata.entry_points()
+ if hasattr(ep, "select"):
+ return ep.select(group=group)
+ else:
+ return ep.get(group, ())
+
+
+if py3k:
+ import base64
+ import builtins
+ import configparser
+ import itertools
+ import pickle
+
+ from functools import reduce
+ from io import BytesIO as byte_buffer
+ from io import StringIO
+ from itertools import zip_longest
+ from time import perf_counter
+ from urllib.parse import (
+ quote_plus,
+ unquote_plus,
+ parse_qsl,
+ quote,
+ unquote,
+ )
+
+ string_types = (str,)
+ binary_types = (bytes,)
+ binary_type = bytes
+ text_type = str
+ int_types = (int,)
+ iterbytes = iter
+ long_type = int
+
+ itertools_filterfalse = itertools.filterfalse
+ itertools_filter = filter
+ itertools_imap = map
+
+ exec_ = getattr(builtins, "exec")
+ import_ = getattr(builtins, "__import__")
+ print_ = getattr(builtins, "print")
+
+ def b(s):
+ return s.encode("latin-1")
+
+ def b64decode(x):
+ return base64.b64decode(x.encode("ascii"))
+
+ def b64encode(x):
+ return base64.b64encode(x).decode("ascii")
+
+ def decode_backslashreplace(text, encoding):
+ return text.decode(encoding, errors="backslashreplace")
+
+ def cmp(a, b):
+ return (a > b) - (a < b)
+
+ def raise_(
+ exception, with_traceback=None, replace_context=None, from_=False
+ ):
+ r"""implement "raise" with cause support.
+
+ :param exception: exception to raise
+ :param with_traceback: will call exception.with_traceback()
+ :param replace_context: an as-yet-unsupported feature. This is
+ an exception object which we are "replacing", e.g., it's our
+ "cause" but we don't want it printed. Basically just what
+ ``__suppress_context__`` does but we don't want to suppress
+ the enclosing context, if any. So for now we make it the
+ cause.
+ :param from\_: the cause. this actually sets the cause and doesn't
+ hope to hide it someday.
+
+ """
+ if with_traceback is not None:
+ exception = exception.with_traceback(with_traceback)
+
+ if from_ is not False:
+ exception.__cause__ = from_
+ elif replace_context is not None:
+ # no good solution here, we would like to have the exception
+ # have only the context of replace_context.__context__ so that the
+ # intermediary exception does not change, but we can't figure
+ # that out.
+ exception.__cause__ = replace_context
+
+ try:
+ raise exception
+ finally:
+ # credit to
+ # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/
+ # as the __traceback__ object creates a cycle
+ del exception, replace_context, from_, with_traceback
+
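+    # Brief illustration (assumed call site): raise_() sets __cause__ much
+    # like ``raise X from Y`` while also supporting an explicit traceback.
+    #
+    #     try:
+    #         1 / 0
+    #     except ZeroDivisionError as err:
+    #         raise_(ValueError("bad input"), from_=err)
+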
+ def u(s):
+ return s
+
+ def ue(s):
+ return s
+
+ from typing import TYPE_CHECKING
+
+ # Unused. Kept for backwards compatibility.
+ callable = callable # noqa
+
+ from abc import ABC
+
+ def _qualname(fn):
+ return fn.__qualname__
+
+
+else:
+ import base64
+ import ConfigParser as configparser # noqa
+ import itertools
+
+ from StringIO import StringIO # noqa
+ from cStringIO import StringIO as byte_buffer # noqa
+ from itertools import izip_longest as zip_longest # noqa
+ from time import clock as perf_counter # noqa
+ from urllib import quote # noqa
+ from urllib import quote_plus # noqa
+ from urllib import unquote # noqa
+ from urllib import unquote_plus # noqa
+ from urlparse import parse_qsl # noqa
+
+ from abc import ABCMeta
+
+ class ABC(object):
+ __metaclass__ = ABCMeta
+
+ try:
+ import cPickle as pickle
+ except ImportError:
+ import pickle # noqa
+
+ string_types = (basestring,) # noqa
+ binary_types = (bytes,)
+ binary_type = str
+ text_type = unicode # noqa
+ int_types = int, long # noqa
+ long_type = long # noqa
+
+ callable = callable # noqa
+ cmp = cmp # noqa
+ reduce = reduce # noqa
+
+ b64encode = base64.b64encode
+ b64decode = base64.b64decode
+
+ itertools_filterfalse = itertools.ifilterfalse
+ itertools_filter = itertools.ifilter
+ itertools_imap = itertools.imap
+
+ def b(s):
+ return s
+
+ def exec_(func_text, globals_, lcl=None):
+ if lcl is None:
+ exec("exec func_text in globals_")
+ else:
+ exec("exec func_text in globals_, lcl")
+
+ def iterbytes(buf):
+ return (ord(byte) for byte in buf)
+
+ def import_(*args):
+ if len(args) == 4:
+ args = args[0:3] + ([str(arg) for arg in args[3]],)
+ return __import__(*args)
+
+ def print_(*args, **kwargs):
+ fp = kwargs.pop("file", sys.stdout)
+ if fp is None:
+ return
+        for arg in args:
+ if not isinstance(arg, basestring): # noqa
+ arg = str(arg)
+ fp.write(arg)
+
+ def u(s):
+ # this differs from what six does, which doesn't support non-ASCII
+ # strings - we only use u() with
+ # literal source strings, and all our source files with non-ascii
+ # in them (all are tests) are utf-8 encoded.
+ return unicode(s, "utf-8") # noqa
+
+ def ue(s):
+ return unicode(s, "unicode_escape") # noqa
+
+ def decode_backslashreplace(text, encoding):
+ try:
+ return text.decode(encoding)
+ except UnicodeDecodeError:
+ # regular "backslashreplace" for an incompatible encoding raises:
+ # "TypeError: don't know how to handle UnicodeDecodeError in
+ # error callback"
+ return repr(text)[1:-1].decode()
+
+ def safe_bytestring(text):
+ # py2k only
+ if not isinstance(text, string_types):
+ return unicode(text).encode( # noqa: F821
+ "ascii", errors="backslashreplace"
+ )
+ elif isinstance(text, unicode): # noqa: F821
+ return text.encode("ascii", errors="backslashreplace")
+ else:
+ return text
+
+ exec(
+ "def raise_(exception, with_traceback=None, replace_context=None, "
+ "from_=False):\n"
+ " if with_traceback:\n"
+ " raise type(exception), exception, with_traceback\n"
+ " else:\n"
+ " raise exception\n"
+ )
+
+ TYPE_CHECKING = False
+
+ def _qualname(meth):
+ """return __qualname__ equivalent for a method on a class"""
+
+ for cls in meth.im_class.__mro__:
+ if meth.__name__ in cls.__dict__:
+ break
+ else:
+ return meth.__name__
+
+ return "%s.%s" % (cls.__name__, meth.__name__)
+
+
+if py3k:
+
+ def _formatannotation(annotation, base_module=None):
+ """vendored from python 3.7"""
+
+ if getattr(annotation, "__module__", None) == "typing":
+ return repr(annotation).replace("typing.", "")
+ if isinstance(annotation, type):
+ if annotation.__module__ in ("builtins", base_module):
+ return annotation.__qualname__
+ return annotation.__module__ + "." + annotation.__qualname__
+ return repr(annotation)
+
+ def inspect_formatargspec(
+ args,
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=(),
+ kwonlydefaults={},
+ annotations={},
+ formatarg=str,
+ formatvarargs=lambda name: "*" + name,
+ formatvarkw=lambda name: "**" + name,
+ formatvalue=lambda value: "=" + repr(value),
+ formatreturns=lambda text: " -> " + text,
+ formatannotation=_formatannotation,
+ ):
+ """Copy formatargspec from python 3.7 standard library.
+
+ Python 3 has deprecated formatargspec and requested that Signature
+ be used instead, however this requires a full reimplementation
+ of formatargspec() in terms of creating Parameter objects and such.
+ Instead of introducing all the object-creation overhead and having
+ to reinvent from scratch, just copy their compatibility routine.
+
+ Ultimately we would need to rewrite our "decorator" routine completely
+ which is not really worth it right now, until all Python 2.x support
+ is dropped.
+
+ """
+
+ kwonlydefaults = kwonlydefaults or {}
+ annotations = annotations or {}
+
+ def formatargandannotation(arg):
+ result = formatarg(arg)
+ if arg in annotations:
+ result += ": " + formatannotation(annotations[arg])
+ return result
+
+ specs = []
+ if defaults:
+ firstdefault = len(args) - len(defaults)
+ for i, arg in enumerate(args):
+ spec = formatargandannotation(arg)
+ if defaults and i >= firstdefault:
+ spec = spec + formatvalue(defaults[i - firstdefault])
+ specs.append(spec)
+
+ if varargs is not None:
+ specs.append(formatvarargs(formatargandannotation(varargs)))
+ else:
+ if kwonlyargs:
+ specs.append("*")
+
+ if kwonlyargs:
+ for kwonlyarg in kwonlyargs:
+ spec = formatargandannotation(kwonlyarg)
+ if kwonlydefaults and kwonlyarg in kwonlydefaults:
+ spec += formatvalue(kwonlydefaults[kwonlyarg])
+ specs.append(spec)
+
+ if varkw is not None:
+ specs.append(formatvarkw(formatargandannotation(varkw)))
+
+ result = "(" + ", ".join(specs) + ")"
+ if "return" in annotations:
+ result += formatreturns(formatannotation(annotations["return"]))
+ return result
+
+
+else:
+ from inspect import formatargspec as _inspect_formatargspec
+
+ def inspect_formatargspec(*spec, **kw):
+ # convert for a potential FullArgSpec from compat.getfullargspec()
+ return _inspect_formatargspec(*spec[0:4], **kw) # noqa
+
+
+# Fix deprecation of accessing ABCs straight from collections module
+# (which will stop working in 3.8).
+if py3k:
+ import collections.abc as collections_abc
+else:
+ import collections as collections_abc # noqa
+
+
+if py37:
+ import dataclasses
+
+ def dataclass_fields(cls):
+ """Return a sequence of all dataclasses.Field objects associated
+ with a class."""
+
+ if dataclasses.is_dataclass(cls):
+ return dataclasses.fields(cls)
+ else:
+ return []
+
+ def local_dataclass_fields(cls):
+ """Return a sequence of all dataclasses.Field objects associated with
+ a class, excluding those that originate from a superclass."""
+
+ if dataclasses.is_dataclass(cls):
+ super_fields = set()
+ for sup in cls.__bases__:
+ super_fields.update(dataclass_fields(sup))
+ return [
+ f for f in dataclasses.fields(cls) if f not in super_fields
+ ]
+ else:
+ return []
+
+
+else:
+
+ def dataclass_fields(cls):
+ return []
+
+ def local_dataclass_fields(cls):
+ return []
+
+
+def raise_from_cause(exception, exc_info=None):
+ r"""legacy. use raise\_()"""
+
+ if exc_info is None:
+ exc_info = sys.exc_info()
+ exc_type, exc_value, exc_tb = exc_info
+ cause = exc_value if exc_value is not exception else None
+ reraise(type(exception), exception, tb=exc_tb, cause=cause)
+
+
+def reraise(tp, value, tb=None, cause=None):
+ r"""legacy. use raise\_()"""
+
+ raise_(value, with_traceback=tb, from_=cause)
+
+
+def with_metaclass(meta, *bases, **kw):
+ """Create a base class with a metaclass.
+
+ Drops the middle class upon creation.
+
+ Source: https://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/
+
+ """
+
+ class metaclass(meta):
+ __call__ = type.__call__
+ __init__ = type.__init__
+
+ def __new__(cls, name, this_bases, d):
+ if this_bases is None:
+ cls = type.__new__(cls, name, (), d)
+ else:
+ cls = meta(name, bases, d)
+
+ if hasattr(cls, "__init_subclass__") and hasattr(
+ cls.__init_subclass__, "__func__"
+ ):
+ cls.__init_subclass__.__func__(cls, **kw)
+ return cls
+
+ return metaclass("temporary_class", None, {})
+
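+# Minimal sketch of typical consumption (class names are illustrative):
+#
+#     class Meta(type):
+#         pass
+#
+#     class MyBase(with_metaclass(Meta, object)):
+#         pass
+#
+#     # type(MyBase) is Meta; the temporary middle class is discarded.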
+
+if py3k:
+ from datetime import timezone
+else:
+ from datetime import datetime
+ from datetime import timedelta
+ from datetime import tzinfo
+
+ class timezone(tzinfo):
+ """Minimal port of python 3 timezone object"""
+
+ __slots__ = "_offset"
+
+ def __init__(self, offset):
+ if not isinstance(offset, timedelta):
+ raise TypeError("offset must be a timedelta")
+ if not self._minoffset <= offset <= self._maxoffset:
+ raise ValueError(
+ "offset must be a timedelta "
+ "strictly between -timedelta(hours=24) and "
+ "timedelta(hours=24)."
+ )
+ self._offset = offset
+
+ def __eq__(self, other):
+ if type(other) != timezone:
+ return False
+ return self._offset == other._offset
+
+ def __hash__(self):
+ return hash(self._offset)
+
+ def __repr__(self):
+ return "sqlalchemy.util.%s(%r)" % (
+ self.__class__.__name__,
+ self._offset,
+ )
+
+ def __str__(self):
+ return self.tzname(None)
+
+ def utcoffset(self, dt):
+ return self._offset
+
+ def tzname(self, dt):
+ return self._name_from_offset(self._offset)
+
+ def dst(self, dt):
+ return None
+
+ def fromutc(self, dt):
+ if isinstance(dt, datetime):
+ if dt.tzinfo is not self:
+ raise ValueError("fromutc: dt.tzinfo " "is not self")
+ return dt + self._offset
+ raise TypeError(
+ "fromutc() argument must be a datetime instance" " or None"
+ )
+
+ @staticmethod
+ def _timedelta_to_microseconds(timedelta):
+ """backport of timedelta._to_microseconds()"""
+ return (
+ timedelta.days * (24 * 3600) + timedelta.seconds
+ ) * 1000000 + timedelta.microseconds
+
+ @staticmethod
+ def _divmod_timedeltas(a, b):
+ """backport of timedelta.__divmod__"""
+
+ q, r = divmod(
+ timezone._timedelta_to_microseconds(a),
+ timezone._timedelta_to_microseconds(b),
+ )
+ return q, timedelta(0, 0, r)
+
+ @staticmethod
+ def _name_from_offset(delta):
+ if not delta:
+ return "UTC"
+ if delta < timedelta(0):
+ sign = "-"
+ delta = -delta
+ else:
+ sign = "+"
+ hours, rest = timezone._divmod_timedeltas(
+ delta, timedelta(hours=1)
+ )
+ minutes, rest = timezone._divmod_timedeltas(
+ rest, timedelta(minutes=1)
+ )
+ result = "UTC%s%02d:%02d" % (sign, hours, minutes)
+ if rest.seconds:
+ result += ":%02d" % (rest.seconds,)
+ if rest.microseconds:
+ result += ".%06d" % (rest.microseconds,)
+ return result
+
+ _maxoffset = timedelta(hours=23, minutes=59)
+ _minoffset = -_maxoffset
+
+ timezone.utc = timezone(timedelta(0))
diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py
new file mode 100644
index 0000000..e900b43
--- /dev/null
+++ b/lib/sqlalchemy/util/concurrency.py
@@ -0,0 +1,73 @@
+# util/concurrency.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from . import compat
+
+have_greenlet = False
+greenlet_error = None
+
+if compat.py3k:
+ try:
+ import greenlet # noqa: F401
+ except ImportError as e:
+ greenlet_error = str(e)
+ else:
+ have_greenlet = True
+ from ._concurrency_py3k import await_only
+ from ._concurrency_py3k import await_fallback
+ from ._concurrency_py3k import greenlet_spawn
+ from ._concurrency_py3k import is_exit_exception
+ from ._concurrency_py3k import AsyncAdaptedLock
+ from ._concurrency_py3k import _util_async_run # noqa: F401
+ from ._concurrency_py3k import (
+ _util_async_run_coroutine_function,
+ ) # noqa: F401, E501
+ from ._concurrency_py3k import asyncio # noqa: F401
+
+        # does not need greenlet, just Python 3
+ from ._compat_py3k import asynccontextmanager # noqa: F401
+
+if not have_greenlet:
+
+ asyncio = None # noqa: F811
+
+ def _not_implemented():
+ # this conditional is to prevent pylance from considering
+ # greenlet_spawn() etc as "no return" and dimming out code below it
+ if have_greenlet:
+ return None
+
+ if not compat.py3k:
+ raise ValueError("Cannot use this function in py2.")
+ else:
+            raise ValueError(
+                "the greenlet library is required to use this function."
+                + (" %s" % greenlet_error if greenlet_error else "")
+            )
+
+ def is_exit_exception(e): # noqa: F811
+ return not isinstance(e, Exception)
+
+ def await_only(thing): # noqa: F811
+ _not_implemented()
+
+ def await_fallback(thing): # noqa: F811
+ return thing
+
+ def greenlet_spawn(fn, *args, **kw): # noqa: F811
+ _not_implemented()
+
+ def AsyncAdaptedLock(*args, **kw): # noqa: F811
+ _not_implemented()
+
+ def _util_async_run(fn, *arg, **kw): # noqa: F811
+ return fn(*arg, **kw)
+
+ def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa: F811
+ _not_implemented()
diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py
new file mode 100644
index 0000000..b61516d
--- /dev/null
+++ b/lib/sqlalchemy/util/deprecations.py
@@ -0,0 +1,417 @@
+# util/deprecations.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Helpers related to deprecation of functions, methods, classes, other
+functionality."""
+
+import os
+import re
+
+from . import compat
+from .langhelpers import _hash_limit_string
+from .langhelpers import _warnings_warn
+from .langhelpers import decorator
+from .langhelpers import inject_docstring_text
+from .langhelpers import inject_param_text
+from .. import exc
+
+
+SQLALCHEMY_WARN_20 = False
+
+if os.getenv("SQLALCHEMY_WARN_20", "false").lower() in ("true", "yes", "1"):
+ SQLALCHEMY_WARN_20 = True
+
+
+def _warn_with_version(msg, version, type_, stacklevel, code=None):
+ if (
+ issubclass(type_, exc.Base20DeprecationWarning)
+ and not SQLALCHEMY_WARN_20
+ ):
+ return
+
+ warn = type_(msg, code=code)
+ warn.deprecated_since = version
+
+ _warnings_warn(warn, stacklevel=stacklevel + 1)
+
+
+def warn_deprecated(msg, version, stacklevel=3, code=None):
+ _warn_with_version(
+ msg, version, exc.SADeprecationWarning, stacklevel, code=code
+ )
+
+
+def warn_deprecated_limited(msg, args, version, stacklevel=3, code=None):
+ """Issue a deprecation warning with a parameterized string,
+ limiting the number of registrations.
+
+ """
+ if args:
+ msg = _hash_limit_string(msg, 10, args)
+ _warn_with_version(
+ msg, version, exc.SADeprecationWarning, stacklevel, code=code
+ )
+
+
+def warn_deprecated_20(msg, stacklevel=3, code=None):
+
+ _warn_with_version(
+ msg,
+ exc.RemovedIn20Warning.deprecated_since,
+ exc.RemovedIn20Warning,
+ stacklevel,
+ code=code,
+ )
+
+
+def deprecated_cls(version, message, constructor="__init__"):
+ header = ".. deprecated:: %s %s" % (version, (message or ""))
+
+ def decorate(cls):
+ return _decorate_cls_with_warning(
+ cls,
+ constructor,
+ exc.SADeprecationWarning,
+ message % dict(func=constructor),
+ version,
+ header,
+ )
+
+ return decorate
+
+
+def deprecated_20_cls(
+ clsname, alternative=None, constructor="__init__", becomes_legacy=False
+):
+ message = (
+ ".. deprecated:: 1.4 The %s class is considered legacy as of the "
+ "1.x series of SQLAlchemy and %s in 2.0."
+ % (
+ clsname,
+ "will be removed"
+ if not becomes_legacy
+ else "becomes a legacy construct",
+ )
+ )
+
+ if alternative:
+ message += " " + alternative
+
+ if becomes_legacy:
+ warning_cls = exc.LegacyAPIWarning
+ else:
+ warning_cls = exc.RemovedIn20Warning
+
+ def decorate(cls):
+ return _decorate_cls_with_warning(
+ cls,
+ constructor,
+ warning_cls,
+ message,
+ warning_cls.deprecated_since,
+ message,
+ )
+
+ return decorate
+
+
+def deprecated(
+ version,
+ message=None,
+ add_deprecation_to_docstring=True,
+ warning=None,
+ enable_warnings=True,
+):
+ """Decorates a function and issues a deprecation warning on use.
+
+ :param version:
+ Issue version in the warning.
+
+ :param message:
+ If provided, issue message in the warning. A sensible default
+ is used if not provided.
+
+ :param add_deprecation_to_docstring:
+ Default True. If False, the wrapped function's __doc__ is left
+ as-is. If True, the 'message' is prepended to the docs if
+ provided, or sensible default if message is omitted.
+
+ """
+
+ # nothing is deprecated "since" 2.0 at this time. All "removed in 2.0"
+ # should emit the RemovedIn20Warning, but messaging should be expressed
+ # in terms of "deprecated since 1.4".
+
+ if version == "2.0":
+ if warning is None:
+ warning = exc.RemovedIn20Warning
+ version = "1.4"
+ if add_deprecation_to_docstring:
+ header = ".. deprecated:: %s %s" % (
+ version,
+ (message or ""),
+ )
+ else:
+ header = None
+
+ if message is None:
+ message = "Call to deprecated function %(func)s"
+
+ if warning is None:
+ warning = exc.SADeprecationWarning
+
+ if warning is not exc.RemovedIn20Warning:
+ message += " (deprecated since: %s)" % version
+
+ def decorate(fn):
+ return _decorate_with_warning(
+ fn,
+ warning,
+ message % dict(func=fn.__name__),
+ version,
+ header,
+ enable_warnings=enable_warnings,
+ )
+
+ return decorate
+
+
+def moved_20(message, **kw):
+ return deprecated(
+ "2.0", message=message, warning=exc.MovedIn20Warning, **kw
+ )
+
+
+def deprecated_20(api_name, alternative=None, becomes_legacy=False, **kw):
+ type_reg = re.match("^:(attr|func|meth):", api_name)
+ if type_reg:
+ type_ = {"attr": "attribute", "func": "function", "meth": "method"}[
+ type_reg.group(1)
+ ]
+ else:
+ type_ = "construct"
+ message = (
+ "The %s %s is considered legacy as of the "
+ "1.x series of SQLAlchemy and %s in 2.0."
+ % (
+ api_name,
+ type_,
+ "will be removed"
+ if not becomes_legacy
+ else "becomes a legacy construct",
+ )
+ )
+
+ if ":attr:" in api_name:
+ attribute_ok = kw.pop("warn_on_attribute_access", False)
+ if not attribute_ok:
+ assert kw.get("enable_warnings") is False, (
+ "attribute %s will emit a warning on read access. "
+ "If you *really* want this, "
+ "add warn_on_attribute_access=True. Otherwise please add "
+ "enable_warnings=False." % api_name
+ )
+
+ if alternative:
+ message += " " + alternative
+
+ if becomes_legacy:
+ warning_cls = exc.LegacyAPIWarning
+ else:
+ warning_cls = exc.RemovedIn20Warning
+
+ return deprecated("2.0", message=message, warning=warning_cls, **kw)
+
+
+def deprecated_params(**specs):
+ """Decorates a function to warn on use of certain parameters.
+
+ e.g. ::
+
+ @deprecated_params(
+ weak_identity_map=(
+ "0.7",
+                "the :paramref:`.Session.weak_identity_map` parameter "
+ "is deprecated."
+ )
+
+ )
+
+ """
+
+ messages = {}
+ versions = {}
+ version_warnings = {}
+
+ for param, (version, message) in specs.items():
+ versions[param] = version
+ messages[param] = _sanitize_restructured_text(message)
+ version_warnings[param] = (
+ exc.RemovedIn20Warning
+ if version == "2.0"
+ else exc.SADeprecationWarning
+ )
+
+ def decorate(fn):
+ spec = compat.inspect_getfullargspec(fn)
+
+ if spec.defaults is not None:
+ defaults = dict(
+ zip(
+ spec.args[(len(spec.args) - len(spec.defaults)) :],
+ spec.defaults,
+ )
+ )
+ check_defaults = set(defaults).intersection(messages)
+ check_kw = set(messages).difference(defaults)
+ else:
+ check_defaults = ()
+ check_kw = set(messages)
+
+ check_any_kw = spec.varkw
+
+ @decorator
+ def warned(fn, *args, **kwargs):
+ for m in check_defaults:
+ if (defaults[m] is None and kwargs[m] is not None) or (
+ defaults[m] is not None and kwargs[m] != defaults[m]
+ ):
+ _warn_with_version(
+ messages[m],
+ versions[m],
+ version_warnings[m],
+ stacklevel=3,
+ )
+
+ if check_any_kw in messages and set(kwargs).difference(
+ check_defaults
+ ):
+
+ _warn_with_version(
+ messages[check_any_kw],
+ versions[check_any_kw],
+ version_warnings[check_any_kw],
+ stacklevel=3,
+ )
+
+ for m in check_kw:
+ if m in kwargs:
+ _warn_with_version(
+ messages[m],
+ versions[m],
+ version_warnings[m],
+ stacklevel=3,
+ )
+ return fn(*args, **kwargs)
+
+ doc = fn.__doc__ is not None and fn.__doc__ or ""
+ if doc:
+ doc = inject_param_text(
+ doc,
+ {
+ param: ".. deprecated:: %s %s"
+ % ("1.4" if version == "2.0" else version, (message or ""))
+ for param, (version, message) in specs.items()
+ },
+ )
+ decorated = warned(fn)
+ decorated.__doc__ = doc
+ return decorated
+
+ return decorate
+
+
+def _sanitize_restructured_text(text):
+ def repl(m):
+ type_, name = m.group(1, 2)
+ if type_ in ("func", "meth"):
+ name += "()"
+ return name
+
+ text = re.sub(r":ref:`(.+) <.*>`", lambda m: '"%s"' % m.group(1), text)
+ return re.sub(r"\:(\w+)\:`~?(?:_\w+)?\.?(.+?)`", repl, text)
+
+
+def _decorate_cls_with_warning(
+ cls, constructor, wtype, message, version, docstring_header=None
+):
+ doc = cls.__doc__ is not None and cls.__doc__ or ""
+ if docstring_header is not None:
+
+ if constructor is not None:
+ docstring_header %= dict(func=constructor)
+
+ if issubclass(wtype, exc.Base20DeprecationWarning):
+ docstring_header += (
+ " (Background on SQLAlchemy 2.0 at: "
+ ":ref:`migration_20_toplevel`)"
+ )
+ doc = inject_docstring_text(doc, docstring_header, 1)
+
+ if type(cls) is type:
+ clsdict = dict(cls.__dict__)
+ clsdict["__doc__"] = doc
+ clsdict.pop("__dict__", None)
+ clsdict.pop("__weakref__", None)
+ cls = type(cls.__name__, cls.__bases__, clsdict)
+ if constructor is not None:
+ constructor_fn = clsdict[constructor]
+
+ else:
+ cls.__doc__ = doc
+ if constructor is not None:
+ constructor_fn = getattr(cls, constructor)
+
+ if constructor is not None:
+ setattr(
+ cls,
+ constructor,
+ _decorate_with_warning(
+ constructor_fn, wtype, message, version, None
+ ),
+ )
+ return cls
+
+
+def _decorate_with_warning(
+ func, wtype, message, version, docstring_header=None, enable_warnings=True
+):
+ """Wrap a function with a warnings.warn and augmented docstring."""
+
+ message = _sanitize_restructured_text(message)
+
+ if issubclass(wtype, exc.Base20DeprecationWarning):
+ doc_only = (
+ " (Background on SQLAlchemy 2.0 at: "
+ ":ref:`migration_20_toplevel`)"
+ )
+ else:
+ doc_only = ""
+
+ @decorator
+ def warned(fn, *args, **kwargs):
+ skip_warning = not enable_warnings or kwargs.pop(
+ "_sa_skip_warning", False
+ )
+ if not skip_warning:
+ _warn_with_version(message, version, wtype, stacklevel=3)
+ return fn(*args, **kwargs)
+
+ doc = func.__doc__ is not None and func.__doc__ or ""
+ if docstring_header is not None:
+ docstring_header %= dict(func=func.__name__)
+
+ docstring_header += doc_only
+
+ doc = inject_docstring_text(doc, docstring_header, 1)
+
+ decorated = warned(func)
+ decorated.__doc__ = doc
+ decorated._sa_warn = lambda: _warn_with_version(
+ message, version, wtype, stacklevel=3
+ )
+ return decorated
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
new file mode 100644
index 0000000..c3636f0
--- /dev/null
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -0,0 +1,1945 @@
+# util/langhelpers.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Routines to help with the creation, loading and introspection of
+modules, classes, hierarchies, attributes, functions, and methods.
+
+"""
+
+import collections
+from functools import update_wrapper
+import hashlib
+import inspect
+import itertools
+import operator
+import re
+import sys
+import textwrap
+import types
+import warnings
+
+from . import _collections
+from . import compat
+from .. import exc
+
+
+def md5_hex(x):
+ if compat.py3k:
+ x = x.encode("utf-8")
+ m = hashlib.md5()
+ m.update(x)
+ return m.hexdigest()
+
+
+class safe_reraise(object):
+ """Reraise an exception after invoking some
+ handler code.
+
+ Stores the existing exception info before
+ invoking so that it is maintained across a potential
+ coroutine context switch.
+
+ e.g.::
+
+ try:
+ sess.commit()
+ except:
+ with safe_reraise():
+ sess.rollback()
+
+ """
+
+ __slots__ = ("warn_only", "_exc_info")
+
+ def __init__(self, warn_only=False):
+ self.warn_only = warn_only
+
+ def __enter__(self):
+ self._exc_info = sys.exc_info()
+
+ def __exit__(self, type_, value, traceback):
+ # see #2703 for notes
+ if type_ is None:
+ exc_type, exc_value, exc_tb = self._exc_info
+ self._exc_info = None # remove potential circular references
+ if not self.warn_only:
+ compat.raise_(
+ exc_value,
+ with_traceback=exc_tb,
+ )
+ else:
+ if not compat.py3k and self._exc_info and self._exc_info[1]:
+ # emulate Py3K's behavior of telling us when an exception
+ # occurs in an exception handler.
+ warn(
+ "An exception has occurred during handling of a "
+ "previous exception. The previous exception "
+ "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1])
+ )
+ self._exc_info = None # remove potential circular references
+ compat.raise_(value, with_traceback=traceback)
+
+
+def walk_subclasses(cls):
+ seen = set()
+
+ stack = [cls]
+ while stack:
+ cls = stack.pop()
+ if cls in seen:
+ continue
+ else:
+ seen.add(cls)
+ stack.extend(cls.__subclasses__())
+ yield cls
+
+
+def string_or_unprintable(element):
+ if isinstance(element, compat.string_types):
+ return element
+ else:
+ try:
+ return str(element)
+ except Exception:
+ return "unprintable element %r" % element
+
+
+def clsname_as_plain_name(cls):
+ return " ".join(
+ n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__)
+ )
+
+
+def method_is_overridden(instance_or_cls, against_method):
+ """Return True if the two class methods don't match."""
+
+ if not isinstance(instance_or_cls, type):
+ current_cls = instance_or_cls.__class__
+ else:
+ current_cls = instance_or_cls
+
+ method_name = against_method.__name__
+
+ current_method = getattr(current_cls, method_name)
+
+ return current_method != against_method
+
+
+def decode_slice(slc):
+ """decode a slice object as sent to __getitem__.
+
+ takes into account the 2.5 __index__() method, basically.
+
+ """
+ ret = []
+ for x in slc.start, slc.stop, slc.step:
+ if hasattr(x, "__index__"):
+ x = x.__index__()
+ ret.append(x)
+ return tuple(ret)
+
+
+def _unique_symbols(used, *bases):
+ used = set(used)
+ for base in bases:
+ pool = itertools.chain(
+ (base,),
+ compat.itertools_imap(lambda i: base + str(i), range(1000)),
+ )
+ for sym in pool:
+ if sym not in used:
+ used.add(sym)
+ yield sym
+ break
+ else:
+ raise NameError("exhausted namespace for symbol base %s" % base)
+
+
+def map_bits(fn, n):
+ """Call the given function given each nonzero bit from n."""
+
+ while n:
+ b = n & (~n + 1)
+ yield fn(b)
+ n ^= b
+
+
+def decorator(target):
+ """A signature-matching decorator factory."""
+
+ def decorate(fn):
+ if not inspect.isfunction(fn) and not inspect.ismethod(fn):
+ raise Exception("not a decoratable function")
+
+ spec = compat.inspect_getfullargspec(fn)
+ env = {}
+
+ spec = _update_argspec_defaults_into_env(spec, env)
+
+ names = tuple(spec[0]) + spec[1:3] + (fn.__name__,)
+ targ_name, fn_name = _unique_symbols(names, "target", "fn")
+
+ metadata = dict(target=targ_name, fn=fn_name)
+ metadata.update(format_argspec_plus(spec, grouped=False))
+ metadata["name"] = fn.__name__
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return %(target)s(%(fn)s, %(apply_kw)s)
+"""
+ % metadata
+ )
+ env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
+
+ decorated = _exec_code_in_env(code, env, fn.__name__)
+ decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ decorated.__wrapped__ = fn
+ return update_wrapper(decorated, fn)
+
+ return update_wrapper(decorate, target)
+
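+# Hedged sketch of the code this factory generates (symbol names assumed to
+# be non-colliding): for a decorated function fn(a, b=1) named "go", the
+# exec'd source is approximately
+#
+#     def go(a, b=1):
+#         return target(fn, a, b=b)
+#
+# so the wrapper preserves fn's exact signature while delegating to target.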
+
+def _update_argspec_defaults_into_env(spec, env):
+ """given a FullArgSpec, convert defaults to be symbol names in an env."""
+
+ if spec.defaults:
+ new_defaults = []
+ i = 0
+ for arg in spec.defaults:
+ if type(arg).__module__ not in ("builtins", "__builtin__"):
+ name = "x%d" % i
+ env[name] = arg
+ new_defaults.append(name)
+ i += 1
+ else:
+ new_defaults.append(arg)
+ elem = list(spec)
+ elem[3] = tuple(new_defaults)
+ return compat.FullArgSpec(*elem)
+ else:
+ return spec
+
+
+def _exec_code_in_env(code, env, fn_name):
+ exec(code, env)
+ return env[fn_name]
+
+
+def public_factory(target, location, class_location=None):
+ """Produce a wrapping function for the given cls or classmethod.
+
+ Rationale here is so that the __init__ method of the
+ class can serve as documentation for the function.
+
+ """
+
+ if isinstance(target, type):
+ fn = target.__init__
+ callable_ = target
+ doc = (
+ "Construct a new :class:`%s` object. \n\n"
+ "This constructor is mirrored as a public API function; "
+ "see :func:`sqlalchemy%s` "
+ "for a full usage and argument description."
+ % (
+ class_location if class_location else ".%s" % target.__name__,
+ location,
+ )
+ )
+ else:
+ fn = callable_ = target
+ doc = (
+ "This function is mirrored; see :func:`sqlalchemy%s` "
+ "for a description of arguments." % location
+ )
+
+ location_name = location.split(".")[-1]
+ spec = compat.inspect_getfullargspec(fn)
+ del spec[0][0]
+ metadata = format_argspec_plus(spec, grouped=False)
+ metadata["name"] = location_name
+ code = (
+ """\
+def %(name)s(%(args)s):
+ return cls(%(apply_kw)s)
+"""
+ % metadata
+ )
+ env = {
+ "cls": callable_,
+ "symbol": symbol,
+ "__name__": callable_.__module__,
+ }
+ exec(code, env)
+ decorated = env[location_name]
+
+ if hasattr(fn, "_linked_to"):
+ linked_to, linked_to_location = fn._linked_to
+ linked_to_doc = linked_to.__doc__
+ if class_location is None:
+ class_location = "%s.%s" % (target.__module__, target.__name__)
+
+ linked_to_doc = inject_docstring_text(
+ linked_to_doc,
+ ".. container:: inherited_member\n\n "
+ "This documentation is inherited from :func:`sqlalchemy%s`; "
+ "this constructor, :func:`sqlalchemy%s`, "
+ "creates a :class:`sqlalchemy%s` object. See that class for "
+ "additional details describing this subclass."
+ % (linked_to_location, location, class_location),
+ 1,
+ )
+ decorated.__doc__ = linked_to_doc
+ else:
+ decorated.__doc__ = fn.__doc__
+
+ decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0]
+ if decorated.__module__ not in sys.modules:
+ raise ImportError(
+ "public_factory location %s is not in sys.modules"
+ % (decorated.__module__,)
+ )
+
+ if compat.py2k or hasattr(fn, "__func__"):
+ fn.__func__.__doc__ = doc
+ if not hasattr(fn.__func__, "_linked_to"):
+ fn.__func__._linked_to = (decorated, location)
+ else:
+ fn.__doc__ = doc
+ if not hasattr(fn, "_linked_to"):
+ fn._linked_to = (decorated, location)
+
+ return decorated
+
+
+class PluginLoader(object):
+ def __init__(self, group, auto_fn=None):
+ self.group = group
+ self.impls = {}
+ self.auto_fn = auto_fn
+
+ def clear(self):
+ self.impls.clear()
+
+ def load(self, name):
+ if name in self.impls:
+ return self.impls[name]()
+
+ if self.auto_fn:
+ loader = self.auto_fn(name)
+ if loader:
+ self.impls[name] = loader
+ return loader()
+
+ for impl in compat.importlib_metadata_get(self.group):
+ if impl.name == name:
+ self.impls[name] = impl.load
+ return impl.load()
+
+ raise exc.NoSuchModuleError(
+ "Can't load plugin: %s:%s" % (self.group, name)
+ )
+
+ def register(self, name, modulepath, objname):
+ def load():
+ mod = compat.import_(modulepath)
+ for token in modulepath.split(".")[1:]:
+ mod = getattr(mod, token)
+ return getattr(mod, objname)
+
+ self.impls[name] = load
+
+
+def _inspect_func_args(fn):
+ try:
+ co_varkeywords = inspect.CO_VARKEYWORDS
+ except AttributeError:
+ # https://docs.python.org/3/library/inspect.html
+ # The flags are specific to CPython, and may not be defined in other
+ # Python implementations. Furthermore, the flags are an implementation
+ # detail, and can be removed or deprecated in future Python releases.
+ spec = compat.inspect_getfullargspec(fn)
+ return spec[0], bool(spec[2])
+ else:
+ # use fn.__code__ plus flags to reduce method call overhead
+ co = fn.__code__
+ nargs = co.co_argcount
+ return (
+ list(co.co_varnames[:nargs]),
+ bool(co.co_flags & co_varkeywords),
+ )
+
+
+def get_cls_kwargs(cls, _set=None):
+ r"""Return the full set of inherited kwargs for the given `cls`.
+
+ Probes a class's __init__ method, collecting all named arguments. If the
+ __init__ defines a \**kwargs catch-all, then the constructor is presumed
+ to pass along unrecognized keywords to its base classes, and the
+ collection process is repeated recursively on each of the bases.
+
+ Uses a subset of inspect.getfullargspec() to cut down on method overhead,
+ as this is used within the Core typing system to create copies of type
+ objects which is a performance-sensitive operation.
+
+    No anonymous tuple arguments, please!
+
+ """
+ toplevel = _set is None
+ if toplevel:
+ _set = set()
+
+ ctr = cls.__dict__.get("__init__", False)
+
+ has_init = (
+ ctr
+ and isinstance(ctr, types.FunctionType)
+ and isinstance(ctr.__code__, types.CodeType)
+ )
+
+ if has_init:
+ names, has_kw = _inspect_func_args(ctr)
+ _set.update(names)
+
+ if not has_kw and not toplevel:
+ return None
+
+ if not has_init or has_kw:
+ for c in cls.__bases__:
+ if get_cls_kwargs(c, _set) is None:
+ break
+
+ _set.discard("self")
+ return _set
+
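+# Illustration of the recursive collection (classes are assumptions for the
+# example): Sub declares **kwargs, so its bases are probed as well.
+#
+#     class Base(object):
+#         def __init__(self, base_arg=None): ...
+#
+#     class Sub(Base):
+#         def __init__(self, sub_arg=None, **kwargs): ...
+#
+#     # get_cls_kwargs(Sub) -> {"sub_arg", "base_arg"}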
+
+def get_func_kwargs(func):
+ """Return the set of legal kwargs for the given `func`.
+
+ Uses getargspec so is safe to call for methods, functions,
+ etc.
+
+ """
+
+ return compat.inspect_getfullargspec(func)[0]
+
+
+def get_callable_argspec(fn, no_self=False, _is_init=False):
+ """Return the argument signature for any callable.
+
+ All pure-Python callables are accepted, including
+ functions, methods, classes, objects with __call__;
+ builtins and other edge cases like functools.partial() objects
+ raise a TypeError.
+
+ """
+ if inspect.isbuiltin(fn):
+ raise TypeError("Can't inspect builtin: %s" % fn)
+ elif inspect.isfunction(fn):
+ if _is_init and no_self:
+ spec = compat.inspect_getfullargspec(fn)
+ return compat.FullArgSpec(
+ spec.args[1:],
+ spec.varargs,
+ spec.varkw,
+ spec.defaults,
+ spec.kwonlyargs,
+ spec.kwonlydefaults,
+ spec.annotations,
+ )
+ else:
+ return compat.inspect_getfullargspec(fn)
+ elif inspect.ismethod(fn):
+ if no_self and (_is_init or fn.__self__):
+ spec = compat.inspect_getfullargspec(fn.__func__)
+ return compat.FullArgSpec(
+ spec.args[1:],
+ spec.varargs,
+ spec.varkw,
+ spec.defaults,
+ spec.kwonlyargs,
+ spec.kwonlydefaults,
+ spec.annotations,
+ )
+ else:
+ return compat.inspect_getfullargspec(fn.__func__)
+ elif inspect.isclass(fn):
+ return get_callable_argspec(
+ fn.__init__, no_self=no_self, _is_init=True
+ )
+ elif hasattr(fn, "__func__"):
+ return compat.inspect_getfullargspec(fn.__func__)
+ elif hasattr(fn, "__call__"):
+ if inspect.ismethod(fn.__call__):
+ return get_callable_argspec(fn.__call__, no_self=no_self)
+ else:
+ raise TypeError("Can't inspect callable: %s" % fn)
+ else:
+ raise TypeError("Can't inspect callable: %s" % fn)
+
+
+def format_argspec_plus(fn, grouped=True):
+ """Returns a dictionary of formatted, introspected function arguments.
+
+ An enhanced variant of inspect.formatargspec to support code generation.
+
+ fn
+ An inspectable callable or tuple of inspect getargspec() results.
+ grouped
+ Defaults to True; include (parens, around, argument) lists
+
+ Returns:
+
+ args
+ Full inspect.formatargspec for fn
+ self_arg
+ The name of the first positional argument, varargs[0], or None
+ if the function defines no positional arguments.
+ apply_pos
+ args, re-written in calling rather than receiving syntax. Arguments are
+ passed positionally.
+ apply_kw
+ Like apply_pos, except keyword-ish args are passed as keywords.
+ apply_pos_proxied
+ Like apply_pos but omits the self/cls argument
+
+ Example::
+
+ >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
+ {'args': '(self, a, b, c=3, **d)',
+ 'self_arg': 'self',
+ 'apply_kw': '(self, a, b, c=c, **d)',
+ 'apply_pos': '(self, a, b, c, **d)'}
+
+ """
+ if compat.callable(fn):
+ spec = compat.inspect_getfullargspec(fn)
+ else:
+ spec = fn
+
+ args = compat.inspect_formatargspec(*spec)
+
+ apply_pos = compat.inspect_formatargspec(
+ spec[0], spec[1], spec[2], None, spec[4]
+ )
+
+ if spec[0]:
+ self_arg = spec[0][0]
+
+ apply_pos_proxied = compat.inspect_formatargspec(
+ spec[0][1:], spec[1], spec[2], None, spec[4]
+ )
+
+ elif spec[1]:
+ # I'm not sure what this is
+ self_arg = "%s[0]" % spec[1]
+
+ apply_pos_proxied = apply_pos
+ else:
+ self_arg = None
+ apply_pos_proxied = apply_pos
+
+ num_defaults = 0
+ if spec[3]:
+ num_defaults += len(spec[3])
+ if spec[4]:
+ num_defaults += len(spec[4])
+ name_args = spec[0] + spec[4]
+
+ if num_defaults:
+ defaulted_vals = name_args[0 - num_defaults :]
+ else:
+ defaulted_vals = ()
+
+ apply_kw = compat.inspect_formatargspec(
+ name_args,
+ spec[1],
+ spec[2],
+ defaulted_vals,
+ formatvalue=lambda x: "=" + x,
+ )
+
+ if spec[0]:
+ apply_kw_proxied = compat.inspect_formatargspec(
+ name_args[1:],
+ spec[1],
+ spec[2],
+ defaulted_vals,
+ formatvalue=lambda x: "=" + x,
+ )
+ else:
+ apply_kw_proxied = apply_kw
+
+ if grouped:
+ return dict(
+ args=args,
+ self_arg=self_arg,
+ apply_pos=apply_pos,
+ apply_kw=apply_kw,
+ apply_pos_proxied=apply_pos_proxied,
+ apply_kw_proxied=apply_kw_proxied,
+ )
+ else:
+ return dict(
+ args=args[1:-1],
+ self_arg=self_arg,
+ apply_pos=apply_pos[1:-1],
+ apply_kw=apply_kw[1:-1],
+ apply_pos_proxied=apply_pos_proxied[1:-1],
+ apply_kw_proxied=apply_kw_proxied[1:-1],
+ )
+
+
+def format_argspec_init(method, grouped=True):
+ """format_argspec_plus with considerations for typical __init__ methods
+
+ Wraps format_argspec_plus with error handling strategies for typical
+ __init__ cases::
+
+ object.__init__ -> (self)
+ other unreflectable (usually C) -> (self, *args, **kwargs)
+
+ """
+ if method is object.__init__:
+ args = "(self)" if grouped else "self"
+ proxied = "()" if grouped else ""
+ else:
+ try:
+ return format_argspec_plus(method, grouped=grouped)
+ except TypeError:
+ args = (
+ "(self, *args, **kwargs)"
+ if grouped
+ else "self, *args, **kwargs"
+ )
+ proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
+ return dict(
+ self_arg="self",
+ args=args,
+ apply_pos=args,
+ apply_kw=args,
+ apply_pos_proxied=proxied,
+ apply_kw_proxied=proxied,
+ )
+
+
+def create_proxy_methods(
+ target_cls,
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ classmethods=(),
+ methods=(),
+ attributes=(),
+):
+ """A class decorator that will copy attributes to a proxy class.
+
+ The class to be instrumented must define a single accessor "_proxied".
+
+ """
+
+ def decorate(cls):
+ def instrument(name, clslevel=False):
+ fn = getattr(target_cls, name)
+ spec = compat.inspect_getfullargspec(fn)
+ env = {"__name__": fn.__module__}
+
+ spec = _update_argspec_defaults_into_env(spec, env)
+ caller_argspec = format_argspec_plus(spec, grouped=False)
+
+ metadata = {
+ "name": fn.__name__,
+ "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
+ "apply_kw_proxied": caller_argspec["apply_kw_proxied"],
+ "args": caller_argspec["args"],
+ "self_arg": caller_argspec["self_arg"],
+ }
+
+ if clslevel:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return target_cls.%(name)s(%(apply_kw_proxied)s)"
+ % metadata
+ )
+ env["target_cls"] = target_cls
+ else:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return %(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)" # noqa: E501
+ % metadata
+ )
+
+ proxy_fn = _exec_code_in_env(code, env, fn.__name__)
+ proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ proxy_fn.__doc__ = inject_docstring_text(
+ fn.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (target_cls_sphinx_name, proxy_cls_sphinx_name),
+ 1,
+ )
+
+ if clslevel:
+ proxy_fn = classmethod(proxy_fn)
+
+ return proxy_fn
+
+ def makeprop(name):
+ attr = target_cls.__dict__.get(name, None)
+
+ if attr is not None:
+ doc = inject_docstring_text(
+ attr.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ ),
+ 1,
+ )
+ else:
+ doc = None
+
+ code = (
+ "def set_(self, attr):\n"
+ " self._proxied.%(name)s = attr\n"
+ "def get(self):\n"
+ " return self._proxied.%(name)s\n"
+ "get.__doc__ = doc\n"
+ "getset = property(get, set_)"
+ ) % {"name": name}
+
+ getset = _exec_code_in_env(code, {"doc": doc}, "getset")
+
+ return getset
+
+ for meth in methods:
+ if hasattr(cls, meth):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, meth)
+ )
+ setattr(cls, meth, instrument(meth))
+
+ for prop in attributes:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, makeprop(prop))
+
+ for prop in classmethods:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, instrument(prop, clslevel=True))
+
+ return cls
+
+ return decorate
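+
+ # Usage sketch (editor's illustration; the Session/scoped_session names below
+ # only illustrate the intended pattern and are not defined by this module):
+ #
+ #     @create_proxy_methods(
+ #         Session, ":class:`.Session`", ":class:`.scoped_session`",
+ #         methods=("add", "commit"), attributes=("dirty",),
+ #     )
+ #     class scoped_session(object):
+ #         @property
+ #         def _proxied(self):
+ #             return self.session_factory()
+ #
+ # Each listed method and attribute is regenerated on the decorated class and
+ # forwards to ``self._proxied``; entries in ``classmethods`` call the target
+ # class directly.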
+
+
+def getargspec_init(method):
+ """inspect.getargspec with considerations for typical __init__ methods
+
+ Wraps inspect.getargspec with error handling for typical __init__ cases::
+
+ object.__init__ -> (self)
+ other unreflectable (usually C) -> (self, *args, **kwargs)
+
+ """
+ try:
+ return compat.inspect_getfullargspec(method)
+ except TypeError:
+ if method is object.__init__:
+ return (["self"], None, None, None)
+ else:
+ return (["self"], "args", "kwargs", None)
+
+
+def unbound_method_to_callable(func_or_cls):
+ """Adjust the incoming callable such that a 'self' argument is not
+ required.
+
+ """
+
+ if isinstance(func_or_cls, types.MethodType) and not func_or_cls.__self__:
+ return func_or_cls.__func__
+ else:
+ return func_or_cls
+
+
+def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()):
+ """Produce a __repr__() based on direct association of the __init__()
+ specification vs. same-named attributes present.
+
+ """
+ if to_inspect is None:
+ to_inspect = [obj]
+ else:
+ to_inspect = _collections.to_list(to_inspect)
+
+ missing = object()
+
+ pos_args = []
+ kw_args = _collections.OrderedDict()
+ vargs = None
+ for i, insp in enumerate(to_inspect):
+ try:
+ spec = compat.inspect_getfullargspec(insp.__init__)
+ except TypeError:
+ continue
+ else:
+ default_len = spec.defaults and len(spec.defaults) or 0
+ if i == 0:
+ if spec.varargs:
+ vargs = spec.varargs
+ if default_len:
+ pos_args.extend(spec.args[1:-default_len])
+ else:
+ pos_args.extend(spec.args[1:])
+ else:
+ kw_args.update(
+ [(arg, missing) for arg in spec.args[1:-default_len]]
+ )
+
+ if default_len:
+ kw_args.update(
+ [
+ (arg, default)
+ for arg, default in zip(
+ spec.args[-default_len:], spec.defaults
+ )
+ ]
+ )
+ output = []
+
+ output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
+
+ if vargs is not None and hasattr(obj, vargs):
+ output.extend([repr(val) for val in getattr(obj, vargs)])
+
+ for arg, defval in kw_args.items():
+ if arg in omit_kwarg:
+ continue
+ try:
+ val = getattr(obj, arg, missing)
+ if val is not missing and val != defval:
+ output.append("%s=%r" % (arg, val))
+ except Exception:
+ pass
+
+ if additional_kw:
+ for arg, defval in additional_kw:
+ try:
+ val = getattr(obj, arg, missing)
+ if val is not missing and val != defval:
+ output.append("%s=%r" % (arg, val))
+ except Exception:
+ pass
+
+ return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
+
+
+class portable_instancemethod(object):
+ """Turn an instancemethod into a (parent, name) pair
+ to produce a serializable callable.
+
+ """
+
+ __slots__ = "target", "name", "kwargs", "__weakref__"
+
+ def __getstate__(self):
+ return {
+ "target": self.target,
+ "name": self.name,
+ "kwargs": self.kwargs,
+ }
+
+ def __setstate__(self, state):
+ self.target = state["target"]
+ self.name = state["name"]
+ self.kwargs = state.get("kwargs", ())
+
+ def __init__(self, meth, kwargs=()):
+ self.target = meth.__self__
+ self.name = meth.__name__
+ self.kwargs = kwargs
+
+ def __call__(self, *arg, **kw):
+ kw.update(self.kwargs)
+ return getattr(self.target, self.name)(*arg, **kw)
+
+
+def class_hierarchy(cls):
+ """Return an unordered sequence of all classes related to cls.
+
+ Traverses diamond hierarchies.
+
+ Fibs slightly: subclasses of builtin types are not returned. Thus
+ class_hierarchy(class A(object)) returns (A, object), not A plus every
+ class systemwide that derives from object.
+
+ Old-style classes are discarded and hierarchies rooted on them
+ will not be descended.
+
+ """
+ if compat.py2k:
+ if isinstance(cls, types.ClassType):
+ return list()
+
+ hier = {cls}
+ process = list(cls.__mro__)
+ while process:
+ c = process.pop()
+ if compat.py2k:
+ if isinstance(c, types.ClassType):
+ continue
+ bases = (
+ _
+ for _ in c.__bases__
+ if _ not in hier and not isinstance(_, types.ClassType)
+ )
+ else:
+ bases = (_ for _ in c.__bases__ if _ not in hier)
+
+ for b in bases:
+ process.append(b)
+ hier.add(b)
+
+ if compat.py3k:
+ if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
+ continue
+ else:
+ if c.__module__ == "__builtin__" or not hasattr(
+ c, "__subclasses__"
+ ):
+ continue
+
+ for s in [_ for _ in c.__subclasses__() if _ not in hier]:
+ process.append(s)
+ hier.add(s)
+ return list(hier)
+
+
+def iterate_attributes(cls):
+ """iterate all the keys and attributes associated
+ with a class, without using getattr().
+
+ Does not use getattr() so that class-sensitive
+ descriptors (i.e. property.__get__()) are not called.
+
+ """
+ keys = dir(cls)
+ for key in keys:
+ for c in cls.__mro__:
+ if key in c.__dict__:
+ yield (key, c.__dict__[key])
+ break
+
+
+def monkeypatch_proxied_specials(
+ into_cls,
+ from_cls,
+ skip=None,
+ only=None,
+ name="self.proxy",
+ from_instance=None,
+):
+ """Automates delegation of __specials__ for a proxying type."""
+
+ if only:
+ dunders = only
+ else:
+ if skip is None:
+ skip = (
+ "__slots__",
+ "__del__",
+ "__getattribute__",
+ "__metaclass__",
+ "__getstate__",
+ "__setstate__",
+ )
+ dunders = [
+ m
+ for m in dir(from_cls)
+ if (
+ m.startswith("__")
+ and m.endswith("__")
+ and not hasattr(into_cls, m)
+ and m not in skip
+ )
+ ]
+
+ for method in dunders:
+ try:
+ fn = getattr(from_cls, method)
+ if not hasattr(fn, "__call__"):
+ continue
+ fn = getattr(fn, "__func__", fn)
+ except AttributeError:
+ continue
+ try:
+ spec = compat.inspect_getfullargspec(fn)
+ fn_args = compat.inspect_formatargspec(spec[0])
+ d_args = compat.inspect_formatargspec(spec[0][1:])
+ except TypeError:
+ fn_args = "(self, *args, **kw)"
+ d_args = "(*args, **kw)"
+
+ py = (
+ "def %(method)s%(fn_args)s: "
+ "return %(name)s.%(method)s%(d_args)s" % locals()
+ )
+
+ env = from_instance is not None and {name: from_instance} or {}
+ compat.exec_(py, env)
+ try:
+ env[method].__defaults__ = fn.__defaults__
+ except AttributeError:
+ pass
+ setattr(into_cls, method, env[method])
+
+
+def methods_equivalent(meth1, meth2):
+ """Return True if the two methods are the same implementation."""
+
+ return getattr(meth1, "__func__", meth1) is getattr(
+ meth2, "__func__", meth2
+ )
+
+
+def as_interface(obj, cls=None, methods=None, required=None):
+ """Ensure basic interface compliance for an instance or dict of callables.
+
+ Checks that ``obj`` implements public methods of ``cls`` or has members
+ listed in ``methods``. If ``required`` is not supplied, implementing at
+ least one interface method is sufficient. Methods present on ``obj`` that
+ are not in the interface are ignored.
+
+ If ``obj`` is a dict and ``dict`` does not meet the interface
+ requirements, the keys of the dictionary are inspected. Keys present in
+ ``obj`` that are not in the interface will raise TypeErrors.
+
+ Raises TypeError if ``obj`` does not meet the interface criteria.
+
+ In all passing cases, an object with callable members is returned. In the
+ simple case, ``obj`` is returned as-is; if dict processing kicks in then
+ an anonymous class is returned.
+
+ obj
+ A type, instance, or dictionary of callables.
+ cls
+ Optional, a type. All public methods of cls are considered the
+ interface. An ``obj`` instance of cls will always pass, ignoring
+ ``required``.
+ methods
+ Optional, a sequence of method names to consider as the interface.
+ required
+ Optional, a sequence of mandatory implementations. If omitted, an
+ ``obj`` that provides at least one interface method is considered
+ sufficient. As a convenience, required may be a type, in which case
+ all public methods of the type are required.
+
+ """
+ if not cls and not methods:
+ raise TypeError("a class or collection of method names are required")
+
+ if isinstance(cls, type) and isinstance(obj, cls):
+ return obj
+
+ interface = set(methods or [m for m in dir(cls) if not m.startswith("_")])
+ implemented = set(dir(obj))
+
+ complies = operator.ge
+ if isinstance(required, type):
+ required = interface
+ elif not required:
+ required = set()
+ complies = operator.gt
+ else:
+ required = set(required)
+
+ if complies(implemented.intersection(interface), required):
+ return obj
+
+ # No dict duck typing here.
+ if not isinstance(obj, dict):
+ qualifier = complies is operator.gt and "any of" or "all of"
+ raise TypeError(
+ "%r does not implement %s: %s"
+ % (obj, qualifier, ", ".join(interface))
+ )
+
+ class AnonymousInterface(object):
+ """A callable-holding shell."""
+
+ if cls:
+ AnonymousInterface.__name__ = "Anonymous" + cls.__name__
+ found = set()
+
+ for method, impl in dictlike_iteritems(obj):
+ if method not in interface:
+ raise TypeError("%r: unknown in this interface" % method)
+ if not compat.callable(impl):
+ raise TypeError("%r=%r is not callable" % (method, impl))
+ setattr(AnonymousInterface, method, staticmethod(impl))
+ found.add(method)
+
+ if complies(found, required):
+ return AnonymousInterface
+
+ raise TypeError(
+ "dictionary does not contain required keys %s"
+ % ", ".join(required - found)
+ )
+
+
+class memoized_property(object):
+ """A read-only @property that is only evaluated once."""
+
+ def __init__(self, fget, doc=None):
+ self.fget = fget
+ self.__doc__ = doc or fget.__doc__
+ self.__name__ = fget.__name__
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ obj.__dict__[self.__name__] = result = self.fget(obj)
+ return result
+
+ def _reset(self, obj):
+ memoized_property.reset(obj, self.__name__)
+
+ @classmethod
+ def reset(cls, obj, name):
+ obj.__dict__.pop(name, None)
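+
+ # Usage sketch (editor's illustration with a hypothetical class):
+ #
+ #     class Widget(object):
+ #         @memoized_property
+ #         def expensive(self):
+ #             return compute()   # hypothetical; runs once per instance
+ #
+ # The first access of ``Widget().expensive`` calls the function and stores the
+ # result in the instance __dict__; because this is a non-data descriptor,
+ # later accesses read the plain attribute and skip the descriptor entirely.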
+
+
+def memoized_instancemethod(fn):
+ """Decorate a method memoize its return value.
+
+ Best applied to no-arg methods: memoization is not sensitive to
+ argument values, and will always return the same value even when
+ called with different arguments.
+
+ """
+
+ def oneshot(self, *args, **kw):
+ result = fn(self, *args, **kw)
+
+ def memo(*a, **kw):
+ return result
+
+ memo.__name__ = fn.__name__
+ memo.__doc__ = fn.__doc__
+ self.__dict__[fn.__name__] = memo
+ return result
+
+ return update_wrapper(oneshot, fn)
+
+
+class HasMemoized(object):
+ """A class that maintains the names of memoized elements in a
+ collection for easy cache clearing, generative, etc.
+
+ """
+
+ __slots__ = ()
+
+ _memoized_keys = frozenset()
+
+ def _reset_memoizations(self):
+ for elem in self._memoized_keys:
+ self.__dict__.pop(elem, None)
+
+ def _assert_no_memoizations(self):
+ for elem in self._memoized_keys:
+ assert elem not in self.__dict__
+
+ def _set_memoized_attribute(self, key, value):
+ self.__dict__[key] = value
+ self._memoized_keys |= {key}
+
+ class memoized_attribute(object):
+ """A read-only @property that is only evaluated once.
+
+ :meta private:
+
+ """
+
+ def __init__(self, fget, doc=None):
+ self.fget = fget
+ self.__doc__ = doc or fget.__doc__
+ self.__name__ = fget.__name__
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ obj.__dict__[self.__name__] = result = self.fget(obj)
+ obj._memoized_keys |= {self.__name__}
+ return result
+
+ @classmethod
+ def memoized_instancemethod(cls, fn):
+ """Decorate a method memoize its return value."""
+
+ def oneshot(self, *args, **kw):
+ result = fn(self, *args, **kw)
+
+ def memo(*a, **kw):
+ return result
+
+ memo.__name__ = fn.__name__
+ memo.__doc__ = fn.__doc__
+ self.__dict__[fn.__name__] = memo
+ self._memoized_keys |= {fn.__name__}
+ return result
+
+ return update_wrapper(oneshot, fn)
+
+
+class MemoizedSlots(object):
+ """Apply memoized items to an object using a __getattr__ scheme.
+
+ This allows the functionality of memoized_property and
+ memoized_instancemethod to be available to a class using __slots__.
+
+ """
+
+ __slots__ = ()
+
+ def _fallback_getattr(self, key):
+ raise AttributeError(key)
+
+ def __getattr__(self, key):
+ if key.startswith("_memoized"):
+ raise AttributeError(key)
+ elif hasattr(self, "_memoized_attr_%s" % key):
+ value = getattr(self, "_memoized_attr_%s" % key)()
+ setattr(self, key, value)
+ return value
+ elif hasattr(self, "_memoized_method_%s" % key):
+ fn = getattr(self, "_memoized_method_%s" % key)
+
+ def oneshot(*args, **kw):
+ result = fn(*args, **kw)
+
+ def memo(*a, **kw):
+ return result
+
+ memo.__name__ = fn.__name__
+ memo.__doc__ = fn.__doc__
+ setattr(self, key, memo)
+ return result
+
+ oneshot.__doc__ = fn.__doc__
+ return oneshot
+ else:
+ return self._fallback_getattr(key)
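+
+ # Usage sketch (editor's illustration; ``Thing`` is a hypothetical subclass):
+ #
+ #     class Thing(MemoizedSlots):
+ #         __slots__ = ("config",)
+ #         def _memoized_attr_config(self):
+ #             return {"x": 1}
+ #
+ # The first access of ``Thing().config`` misses the empty slot, so __getattr__
+ # runs _memoized_attr_config() and setattr()'s the result into the slot;
+ # subsequent accesses read the slot directly.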
+
+
+# from paste.deploy.converters
+def asbool(obj):
+ if isinstance(obj, compat.string_types):
+ obj = obj.strip().lower()
+ if obj in ["true", "yes", "on", "y", "t", "1"]:
+ return True
+ elif obj in ["false", "no", "off", "n", "f", "0"]:
+ return False
+ else:
+ raise ValueError("String is not true/false: %r" % obj)
+ return bool(obj)
+
+
+def bool_or_str(*text):
+ """Return a callable that will evaluate a string as
+ boolean, or one of a set of "alternate" string values.
+
+ """
+
+ def bool_or_value(obj):
+ if obj in text:
+ return obj
+ else:
+ return asbool(obj)
+
+ return bool_or_value
+
+
+def asint(value):
+ """Coerce to integer."""
+
+ if value is None:
+ return value
+ return int(value)
+
+
+def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None):
+ r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if
+ necessary. If 'flexi_bool' is True, the string '0' is considered false
+ when coercing to boolean.
+ """
+
+ if dest is None:
+ dest = kw
+
+ if (
+ key in kw
+ and (not isinstance(type_, type) or not isinstance(kw[key], type_))
+ and kw[key] is not None
+ ):
+ if type_ is bool and flexi_bool:
+ dest[key] = asbool(kw[key])
+ else:
+ dest[key] = type_(kw[key])
+
+
+def constructor_key(obj, cls):
+ """Produce a tuple structure that is cacheable using the __dict__ of
+ obj to retrieve values
+
+ """
+ names = get_cls_kwargs(cls)
+ return (cls,) + tuple(
+ (k, obj.__dict__[k]) for k in names if k in obj.__dict__
+ )
+
+
+def constructor_copy(obj, cls, *args, **kw):
+ """Instantiate cls using the __dict__ of obj as constructor arguments.
+
+ Uses inspect to match the named arguments of ``cls``.
+
+ """
+
+ names = get_cls_kwargs(cls)
+ kw.update(
+ (k, obj.__dict__[k]) for k in names.difference(kw) if k in obj.__dict__
+ )
+ return cls(*args, **kw)
+
+
+def counter():
+ """Return a threadsafe counter function."""
+
+ lock = compat.threading.Lock()
+ counter = itertools.count(1)
+
+ # avoid the 2to3 "next" transformation...
+ def _next():
+ with lock:
+ return next(counter)
+
+ return _next
+
+
+def duck_type_collection(specimen, default=None):
+ """Given an instance or class, guess if it is or is acting as one of
+ the basic collection types: list, set and dict. If the __emulates__
+ property is present, return that preferentially.
+ """
+
+ if hasattr(specimen, "__emulates__"):
+ # canonicalize set vs sets.Set to a standard: the builtin set
+ if specimen.__emulates__ is not None and issubclass(
+ specimen.__emulates__, set
+ ):
+ return set
+ else:
+ return specimen.__emulates__
+
+ isa = isinstance(specimen, type) and issubclass or isinstance
+ if isa(specimen, list):
+ return list
+ elif isa(specimen, set):
+ return set
+ elif isa(specimen, dict):
+ return dict
+
+ if hasattr(specimen, "append"):
+ return list
+ elif hasattr(specimen, "add"):
+ return set
+ elif hasattr(specimen, "set"):
+ return dict
+ else:
+ return default
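+
+ # Usage sketch (editor's illustration): duck_type_collection([]) is list,
+ # duck_type_collection(set) is set, and an object exposing only an append()
+ # method is guessed to be list-like; anything unrecognized returns ``default``.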
+
+
+def assert_arg_type(arg, argtype, name):
+ if isinstance(arg, argtype):
+ return arg
+ else:
+ if isinstance(argtype, tuple):
+ raise exc.ArgumentError(
+ "Argument '%s' is expected to be one of type %s, got '%s'"
+ % (name, " or ".join("'%s'" % a for a in argtype), type(arg))
+ )
+ else:
+ raise exc.ArgumentError(
+ "Argument '%s' is expected to be of type '%s', got '%s'"
+ % (name, argtype, type(arg))
+ )
+
+
+def dictlike_iteritems(dictlike):
+ """Return a (key, value) iterator for almost any dict-like object."""
+
+ if compat.py3k:
+ if hasattr(dictlike, "items"):
+ return list(dictlike.items())
+ else:
+ if hasattr(dictlike, "iteritems"):
+ return dictlike.iteritems()
+ elif hasattr(dictlike, "items"):
+ return iter(dictlike.items())
+
+ getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None))
+ if getter is None:
+ raise TypeError("Object '%r' is not dict-like" % dictlike)
+
+ if hasattr(dictlike, "iterkeys"):
+
+ def iterator():
+ for key in dictlike.iterkeys():
+ yield key, getter(key)
+
+ return iterator()
+ elif hasattr(dictlike, "keys"):
+ return iter((key, getter(key)) for key in dictlike.keys())
+ else:
+ raise TypeError("Object '%r' is not dict-like" % dictlike)
+
+
+class classproperty(property):
+ """A decorator that behaves like @property except that operates
+ on classes rather than instances.
+
+ The decorator is currently special when using the declarative
+ module, but note that the
+ :class:`~.sqlalchemy.ext.declarative.declared_attr`
+ decorator should be used for this purpose with declarative.
+
+ """
+
+ def __init__(self, fget, *arg, **kw):
+ super(classproperty, self).__init__(fget, *arg, **kw)
+ self.__doc__ = fget.__doc__
+
+ def __get__(desc, self, cls):
+ return desc.fget(cls)
+
+
+class hybridproperty(object):
+ def __init__(self, func):
+ self.func = func
+ self.clslevel = func
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ clsval = self.clslevel(owner)
+ return clsval
+ else:
+ return self.func(instance)
+
+ def classlevel(self, func):
+ self.clslevel = func
+ return self
+
+
+class hybridmethod(object):
+ """Decorate a function as cls- or instance- level."""
+
+ def __init__(self, func):
+ self.func = self.__func__ = func
+ self.clslevel = func
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.clslevel.__get__(owner, owner.__class__)
+ else:
+ return self.func.__get__(instance, owner)
+
+ def classlevel(self, func):
+ self.clslevel = func
+ return self
+
+
+class _symbol(int):
+ def __new__(self, name, doc=None, canonical=None):
+ """Construct a new named symbol."""
+ assert isinstance(name, compat.string_types)
+ if canonical is None:
+ canonical = hash(name)
+ v = int.__new__(_symbol, canonical)
+ v.name = name
+ if doc:
+ v.__doc__ = doc
+ return v
+
+ def __reduce__(self):
+ return symbol, (self.name, "x", int(self))
+
+ def __str__(self):
+ return repr(self)
+
+ def __repr__(self):
+ return "symbol(%r)" % self.name
+
+
+_symbol.__name__ = "symbol"
+
+
+class symbol(object):
+ """A constant symbol.
+
+ >>> symbol('foo') is symbol('foo')
+ True
+ >>> symbol('foo')
+ symbol('foo')
+
+ A slight refinement of the MAGICCOOKIE=object() pattern. The primary
+ advantage of symbol() is its repr(). They are also singletons.
+
+ Repeated calls of symbol('name') will all return the same instance.
+
+ The optional ``doc`` argument assigns to ``__doc__``. This
+ is strictly so that Sphinx autoattr picks up the docstring we want
+ (it doesn't appear to pick up the in-module docstring if the datamember
+ is in a different module - autoattribute also blows up completely).
+ If Sphinx fixes/improves this then we would no longer need
+ ``doc`` here.
+
+ """
+
+ symbols = {}
+ _lock = compat.threading.Lock()
+
+ def __new__(cls, name, doc=None, canonical=None):
+ with cls._lock:
+ sym = cls.symbols.get(name)
+ if sym is None:
+ cls.symbols[name] = sym = _symbol(name, doc, canonical)
+ return sym
+
+ @classmethod
+ def parse_user_argument(
+ cls, arg, choices, name, resolve_symbol_names=False
+ ):
+ """Given a user parameter, parse the parameter into a chosen symbol.
+
+ The user argument can be a string name that matches the name of a
+ symbol, or the symbol object itself, or any number of alternate choices
+ such as True/False/None, etc.
+
+ :param arg: the user argument.
+ :param choices: dictionary of symbol object to list of possible
+ entries.
+ :param name: name of the argument. Used in an :class:`.ArgumentError`
+ that is raised if the parameter doesn't match any available argument.
+ :param resolve_symbol_names: include the name of each symbol as a valid
+ entry.
+
+ """
+ # note using hash lookup is tricky here because symbol's `__hash__`
+ # is its int value which we don't want included in the lookup
+ # explicitly, so we iterate and compare each.
+ for sym, choice in choices.items():
+ if arg is sym:
+ return sym
+ elif resolve_symbol_names and arg == sym.name:
+ return sym
+ elif arg in choice:
+ return sym
+
+ if arg is None:
+ return None
+
+ raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg))
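+
+ # Usage sketch (editor's illustration with hypothetical symbols/choices):
+ #
+ #     CASCADE = symbol("CASCADE")
+ #     RESTRICT = symbol("RESTRICT")
+ #     symbol.parse_user_argument(
+ #         "CASCADE",
+ #         {CASCADE: ["cascade"], RESTRICT: ["restrict"]},
+ #         "ondelete",
+ #         resolve_symbol_names=True,
+ #     )
+ #
+ # returns the CASCADE symbol; a value matching neither a symbol, its name,
+ # nor one of the listed alternates raises ArgumentError, while None passes
+ # through unchanged.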
+
+
+_creation_order = 1
+
+
+def set_creation_order(instance):
+ """Assign a '_creation_order' sequence to the given instance.
+
+ This allows multiple instances to be sorted in order of creation
+ (typically within a single thread; the counter is not particularly
+ threadsafe).
+
+ """
+ global _creation_order
+ instance._creation_order = _creation_order
+ _creation_order += 1
+
+
+def warn_exception(func, *args, **kwargs):
+ """executes the given function, catches all exceptions and converts to
+ a warning.
+
+ """
+ try:
+ return func(*args, **kwargs)
+ except Exception:
+ warn("%s('%s') ignored" % sys.exc_info()[0:2])
+
+
+def ellipses_string(value, len_=25):
+ try:
+ if len(value) > len_:
+ return "%s..." % value[0:len_]
+ else:
+ return value
+ except TypeError:
+ return value
+
+
+class _hash_limit_string(compat.text_type):
+ """A string subclass that can only be hashed on a maximum amount
+ of unique values.
+
+ This is used for warnings so that we can send out parameterized warnings
+ without the __warningregistry__ of the module, or the non-overridable
+ "once" registry within warnings.py, overloading memory,
+
+
+ """
+
+ def __new__(cls, value, num, args):
+ interpolated = (value % args) + (
+ " (this warning may be suppressed after %d occurrences)" % num
+ )
+ self = super(_hash_limit_string, cls).__new__(cls, interpolated)
+ self._hash = hash("%s_%d" % (value, hash(interpolated) % num))
+ return self
+
+ def __hash__(self):
+ return self._hash
+
+ def __eq__(self, other):
+ return hash(self) == hash(other)
+
+
+def warn(msg, code=None):
+ """Issue a warning.
+
+ If msg is a string, :class:`.exc.SAWarning` is used as
+ the category.
+
+ """
+ if code:
+ _warnings_warn(exc.SAWarning(msg, code=code))
+ else:
+ _warnings_warn(msg, exc.SAWarning)
+
+
+def warn_limited(msg, args):
+ """Issue a warning with a parameterized string, limiting the number
+ of registrations.
+
+ """
+ if args:
+ msg = _hash_limit_string(msg, 10, args)
+ _warnings_warn(msg, exc.SAWarning)
+
+
+def _warnings_warn(message, category=None, stacklevel=2):
+
+ # adjust the given stacklevel to be outside of SQLAlchemy
+ try:
+ frame = sys._getframe(stacklevel)
+ except ValueError:
+ # being called from less than 3 (or given) stacklevels, weird,
+ # but don't crash
+ stacklevel = 0
+ except:
+ # _getframe() doesn't work due to a weird interpreter issue;
+ # don't crash
+ stacklevel = 0
+ else:
+ # using __name__ here requires that we have __name__ in the
+ # __globals__ of the decorated string functions we make also.
+ # we generate this using {"__name__": fn.__module__}
+ while frame is not None and re.match(
+ r"^(?:sqlalchemy\.|alembic\.)", frame.f_globals.get("__name__", "")
+ ):
+ frame = frame.f_back
+ stacklevel += 1
+
+ if category is not None:
+ warnings.warn(message, category, stacklevel=stacklevel + 1)
+ else:
+ warnings.warn(message, stacklevel=stacklevel + 1)
+
+
+def only_once(fn, retry_on_exception):
+ """Decorate the given function to be a no-op after it is called exactly
+ once."""
+
+ once = [fn]
+
+ def go(*arg, **kw):
+ # strong reference fn so that it isn't garbage collected,
+ # which interferes with the event system's expectations
+ strong_fn = fn # noqa
+ if once:
+ once_fn = once.pop()
+ try:
+ return once_fn(*arg, **kw)
+ except:
+ if retry_on_exception:
+ once.insert(0, once_fn)
+ raise
+
+ return go
+
+
+_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py")
+_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)")
+
+
+def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE):
+ """Chop extraneous lines off beginning and end of a traceback.
+
+ :param tb:
+ a list of traceback lines as returned by ``traceback.format_stack()``
+
+ :param exclude_prefix:
+ a regular expression object matching lines to skip at beginning of
+ ``tb``
+
+ :param exclude_suffix:
+ a regular expression object matching lines to skip at end of ``tb``
+ """
+ start = 0
+ end = len(tb) - 1
+ while start <= end and exclude_prefix.search(tb[start]):
+ start += 1
+ while start <= end and exclude_suffix.search(tb[end]):
+ end -= 1
+ return tb[start : end + 1]
+
+
+NoneType = type(None)
+
+
+def attrsetter(attrname):
+ code = "def set(obj, value):" " obj.%s = value" % attrname
+ env = locals().copy()
+ exec(code, env)
+ return env["set"]
+
+
+class EnsureKWArgType(type):
+ r"""Apply translation of functions to accept \**kw arguments if they
+ don't already.
+
+ """
+
+ def __init__(cls, clsname, bases, clsdict):
+ fn_reg = cls.ensure_kwarg
+ if fn_reg:
+ for key in clsdict:
+ m = re.match(fn_reg, key)
+ if m:
+ fn = clsdict[key]
+ spec = compat.inspect_getfullargspec(fn)
+ if not spec.varkw:
+ clsdict[key] = wrapped = cls._wrap_w_kw(fn)
+ setattr(cls, key, wrapped)
+ super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict)
+
+ def _wrap_w_kw(self, fn):
+ def wrap(*arg, **kw):
+ return fn(*arg)
+
+ return update_wrapper(wrap, fn)
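+
+ # Usage sketch (editor's illustration; ``MyVisitor`` is hypothetical and uses
+ # Python 3 metaclass syntax):
+ #
+ #     class MyVisitor(object, metaclass=EnsureKWArgType):
+ #         ensure_kwarg = r"^visit_"
+ #         def visit_thing(self, thing):
+ #             return thing
+ #
+ # Because visit_thing() declares no **kw, the metaclass wraps it so callers
+ # may pass extra keyword arguments, which are silently discarded.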
+
+
+def wrap_callable(wrapper, fn):
+ """Augment functools.update_wrapper() to work with objects with
+ a ``__call__()`` method.
+
+ :param fn:
+ object with __call__ method
+
+ """
+ if hasattr(fn, "__name__"):
+ return update_wrapper(wrapper, fn)
+ else:
+ _f = wrapper
+ _f.__name__ = fn.__class__.__name__
+ if hasattr(fn, "__module__"):
+ _f.__module__ = fn.__module__
+
+ if hasattr(fn.__call__, "__doc__") and fn.__call__.__doc__:
+ _f.__doc__ = fn.__call__.__doc__
+ elif fn.__doc__:
+ _f.__doc__ = fn.__doc__
+
+ return _f
+
+
+def quoted_token_parser(value):
+ """Parse a dotted identifier with accommodation for quoted names.
+
+ Includes support for SQL-style double quotes as a literal character.
+
+ E.g.::
+
+ >>> quoted_token_parser("name")
+ ["name"]
+ >>> quoted_token_parser("schema.name")
+ ["schema", "name"]
+ >>> quoted_token_parser('"Schema"."Name"')
+ ['Schema', 'Name']
+ >>> quoted_token_parser('"Schema"."Name""Foo"')
+ ['Schema', 'Name""Foo']
+
+ """
+
+ if '"' not in value:
+ return value.split(".")
+
+ # 0 = outside of quotes
+ # 1 = inside of quotes
+ state = 0
+ result = [[]]
+ idx = 0
+ lv = len(value)
+ while idx < lv:
+ char = value[idx]
+ if char == '"':
+ if state == 1 and idx < lv - 1 and value[idx + 1] == '"':
+ result[-1].append('"')
+ idx += 1
+ else:
+ state ^= 1
+ elif char == "." and state == 0:
+ result.append([])
+ else:
+ result[-1].append(char)
+ idx += 1
+
+ return ["".join(token) for token in result]
+
+
+def add_parameter_text(params, text):
+ params = _collections.to_list(params)
+
+ def decorate(fn):
+ doc = fn.__doc__ is not None and fn.__doc__ or ""
+ if doc:
+ doc = inject_param_text(doc, {param: text for param in params})
+ fn.__doc__ = doc
+ return fn
+
+ return decorate
+
+
+def _dedent_docstring(text):
+ split_text = text.split("\n", 1)
+ if len(split_text) == 1:
+ return text
+ else:
+ firstline, remaining = split_text
+ if not firstline.startswith(" "):
+ return firstline + "\n" + textwrap.dedent(remaining)
+ else:
+ return textwrap.dedent(text)
+
+
+def inject_docstring_text(doctext, injecttext, pos):
+ doctext = _dedent_docstring(doctext or "")
+ lines = doctext.split("\n")
+ if len(lines) == 1:
+ lines.append("")
+ injectlines = textwrap.dedent(injecttext).split("\n")
+ if injectlines[0]:
+ injectlines.insert(0, "")
+
+ blanks = [num for num, line in enumerate(lines) if not line.strip()]
+ blanks.insert(0, 0)
+
+ inject_pos = blanks[min(pos, len(blanks) - 1)]
+
+ lines = lines[0:inject_pos] + injectlines + lines[inject_pos:]
+ return "\n".join(lines)
+
+
+_param_reg = re.compile(r"(\s+):param (.+?):")
+
+
+def inject_param_text(doctext, inject_params):
+ doclines = collections.deque(doctext.splitlines())
+ lines = []
+
+ # TODO: this is not working for params like ":param case_sensitive=True:"
+
+ to_inject = None
+ while doclines:
+ line = doclines.popleft()
+
+ m = _param_reg.match(line)
+
+ if to_inject is None:
+ if m:
+ param = m.group(2).lstrip("*")
+ if param in inject_params:
+ # default indent to that of :param: plus one
+ indent = " " * len(m.group(1)) + " "
+
+ # but if the next line has text, use that line's
+ # indentation
+ if doclines:
+ m2 = re.match(r"(\s+)\S", doclines[0])
+ if m2:
+ indent = " " * len(m2.group(1))
+
+ to_inject = indent + inject_params[param]
+ elif m:
+ lines.extend(["\n", to_inject, "\n"])
+ to_inject = None
+ elif not line.rstrip():
+ lines.extend([line, to_inject, "\n"])
+ to_inject = None
+ elif line.endswith("::"):
+ # TODO: this still won't cover if the code example itself has blank
+ # lines in it, need to detect those via indentation.
+ lines.extend([line, doclines.popleft()])
+ continue
+ lines.append(line)
+
+ return "\n".join(lines)
+
+
+def repr_tuple_names(names):
+ """Trims a list of strings from the middle and return a string of up to
+ four elements. Strings greater than 11 characters will be truncated"""
+ if len(names) == 0:
+ return None
+ flag = len(names) <= 4
+ names = names[0:4] if flag else names[0:3] + names[-1:]
+ res = ["%s.." % name[:11] if len(name) > 11 else name for name in names]
+ if flag:
+ return ", ".join(res)
+ else:
+ return "%s, ..., %s" % (", ".join(res[0:3]), res[-1])
+
+
+def has_compiled_ext():
+ try:
+ from sqlalchemy import cimmutabledict # noqa: F401
+ from sqlalchemy import cprocessors # noqa: F401
+ from sqlalchemy import cresultproxy # noqa: F401
+
+ return True
+ except ImportError:
+ return False
diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py
new file mode 100644
index 0000000..67c5219
--- /dev/null
+++ b/lib/sqlalchemy/util/queue.py
@@ -0,0 +1,291 @@
+# util/queue.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""An adaptation of Py2.3/2.4's Queue module which supports reentrant
+behavior, using RLock instead of Lock for its mutex object. The
+Queue object is used exclusively by the sqlalchemy.pool.QueuePool
+class.
+
+This is to support the connection pool's usage of weakref callbacks to return
+connections to the underlying Queue, which can in extremely
+rare cases be invoked within the ``get()`` method of the Queue itself,
+producing a ``put()`` inside the ``get()`` and therefore a reentrant
+condition.
+
+"""
+
+from collections import deque
+from time import time as _time
+
+from . import compat
+from .compat import threading
+from .concurrency import asyncio
+from .concurrency import await_fallback
+from .concurrency import await_only
+from .langhelpers import memoized_property
+
+
+__all__ = ["Empty", "Full", "Queue"]
+
+
+class Empty(Exception):
+ "Exception raised by Queue.get(block=0)/get_nowait()."
+
+ pass
+
+
+class Full(Exception):
+ "Exception raised by Queue.put(block=0)/put_nowait()."
+
+ pass
+
+
+class Queue:
+ def __init__(self, maxsize=0, use_lifo=False):
+ """Initialize a queue object with a given maximum size.
+
+ If `maxsize` is <= 0, the queue size is infinite.
+
+ If `use_lifo` is True, this Queue acts like a Stack (LIFO).
+ """
+
+ self._init(maxsize)
+ # mutex must be held whenever the queue is mutating. All methods
+ # that acquire mutex must release it before returning. mutex
+ # is shared between the two conditions, so acquiring and
+ # releasing the conditions also acquires and releases mutex.
+ self.mutex = threading.RLock()
+ # Notify not_empty whenever an item is added to the queue; a
+ # thread waiting to get is notified then.
+ self.not_empty = threading.Condition(self.mutex)
+ # Notify not_full whenever an item is removed from the queue;
+ # a thread waiting to put is notified then.
+ self.not_full = threading.Condition(self.mutex)
+ # If this queue uses LIFO or FIFO
+ self.use_lifo = use_lifo
+
+ def qsize(self):
+ """Return the approximate size of the queue (not reliable!)."""
+
+ with self.mutex:
+ return self._qsize()
+
+ def empty(self):
+ """Return True if the queue is empty, False otherwise (not
+ reliable!)."""
+
+ with self.mutex:
+ return self._empty()
+
+ def full(self):
+ """Return True if the queue is full, False otherwise (not
+ reliable!)."""
+
+ with self.mutex:
+ return self._full()
+
+ def put(self, item, block=True, timeout=None):
+ """Put an item into the queue.
+
+ If optional args `block` is True and `timeout` is None (the
+ default), block if necessary until a free slot is
+ available. If `timeout` is a positive number, it blocks at
+ most `timeout` seconds and raises the ``Full`` exception if no
+ free slot was available within that time. Otherwise (`block`
+ is false), put an item on the queue if a free slot is
+ immediately available, else raise the ``Full`` exception
+ (`timeout` is ignored in that case).
+ """
+
+ with self.not_full:
+ if not block:
+ if self._full():
+ raise Full
+ elif timeout is None:
+ while self._full():
+ self.not_full.wait()
+ else:
+ if timeout < 0:
+ raise ValueError("'timeout' must be a positive number")
+ endtime = _time() + timeout
+ while self._full():
+ remaining = endtime - _time()
+ if remaining <= 0.0:
+ raise Full
+ self.not_full.wait(remaining)
+ self._put(item)
+ self.not_empty.notify()
+
+ def put_nowait(self, item):
+ """Put an item into the queue without blocking.
+
+ Only enqueue the item if a free slot is immediately available.
+ Otherwise raise the ``Full`` exception.
+ """
+ return self.put(item, False)
+
+ def get(self, block=True, timeout=None):
+ """Remove and return an item from the queue.
+
+ If optional args `block` is True and `timeout` is None (the
+ default), block if necessary until an item is available. If
+ `timeout` is a positive number, it blocks at most `timeout`
+ seconds and raises the ``Empty`` exception if no item was
+ available within that time. Otherwise (`block` is false),
+ return an item if one is immediately available, else raise the
+ ``Empty`` exception (`timeout` is ignored in that case).
+
+ """
+ with self.not_empty:
+ if not block:
+ if self._empty():
+ raise Empty
+ elif timeout is None:
+ while self._empty():
+ self.not_empty.wait()
+ else:
+ if timeout < 0:
+ raise ValueError("'timeout' must be a positive number")
+ endtime = _time() + timeout
+ while self._empty():
+ remaining = endtime - _time()
+ if remaining <= 0.0:
+ raise Empty
+ self.not_empty.wait(remaining)
+ item = self._get()
+ self.not_full.notify()
+ return item
+
+ def get_nowait(self):
+ """Remove and return an item from the queue without blocking.
+
+ Only get an item if one is immediately available. Otherwise
+ raise the ``Empty`` exception.
+ """
+
+ return self.get(False)
+
+ # Override these methods to implement other queue organizations
+ # (e.g. stack or priority queue).
+ # These will only be called with appropriate locks held
+
+ # Initialize the queue representation
+ def _init(self, maxsize):
+ self.maxsize = maxsize
+ self.queue = deque()
+
+ def _qsize(self):
+ return len(self.queue)
+
+ # Check whether the queue is empty
+ def _empty(self):
+ return not self.queue
+
+ # Check whether the queue is full
+ def _full(self):
+ return self.maxsize > 0 and len(self.queue) == self.maxsize
+
+ # Put a new item in the queue
+ def _put(self, item):
+ self.queue.append(item)
+
+ # Get an item from the queue
+ def _get(self):
+ if self.use_lifo:
+ # LIFO
+ return self.queue.pop()
+ else:
+ # FIFO
+ return self.queue.popleft()
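+
+ # Usage sketch (editor's illustration): q = Queue(maxsize=2); q.put("a");
+ # q.put("b"); q.get() returns "a" (FIFO by default), while
+ # Queue(use_lifo=True) would hand back the most recently added item first.
+ # put()/get() honor the ``block``/``timeout`` arguments just like the stdlib
+ # queue module, but the RLock mutex keeps a reentrant put() issued from
+ # inside get() from deadlocking.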
+
+
+class AsyncAdaptedQueue:
+ await_ = staticmethod(await_only)
+
+ def __init__(self, maxsize=0, use_lifo=False):
+ self.use_lifo = use_lifo
+ self.maxsize = maxsize
+
+ def empty(self):
+ return self._queue.empty()
+
+ def full(self):
+ return self._queue.full()
+
+ def qsize(self):
+ return self._queue.qsize()
+
+ @memoized_property
+ def _queue(self):
+ # Delay creation of the queue until it is first used, to avoid
+ # binding it to a possibly wrong event loop.
+ # By delaying the creation of the pool we accommodate the common
+ # usage pattern of instantiating the engine at module level, where a
+ # different event loop is present compared to when the application
+ # is actually run.
+
+ if self.use_lifo:
+ queue = asyncio.LifoQueue(maxsize=self.maxsize)
+ else:
+ queue = asyncio.Queue(maxsize=self.maxsize)
+ return queue
+
+ def put_nowait(self, item):
+ try:
+ return self._queue.put_nowait(item)
+ except asyncio.QueueFull as err:
+ compat.raise_(
+ Full(),
+ replace_context=err,
+ )
+
+ def put(self, item, block=True, timeout=None):
+ if not block:
+ return self.put_nowait(item)
+
+ try:
+ if timeout is not None:
+ return self.await_(
+ asyncio.wait_for(self._queue.put(item), timeout)
+ )
+ else:
+ return self.await_(self._queue.put(item))
+ except (asyncio.QueueFull, asyncio.TimeoutError) as err:
+ compat.raise_(
+ Full(),
+ replace_context=err,
+ )
+
+ def get_nowait(self):
+ try:
+ return self._queue.get_nowait()
+ except asyncio.QueueEmpty as err:
+ compat.raise_(
+ Empty(),
+ replace_context=err,
+ )
+
+ def get(self, block=True, timeout=None):
+ if not block:
+ return self.get_nowait()
+
+ try:
+ if timeout is not None:
+ return self.await_(
+ asyncio.wait_for(self._queue.get(), timeout)
+ )
+ else:
+ return self.await_(self._queue.get())
+ except (asyncio.QueueEmpty, asyncio.TimeoutError) as err:
+ compat.raise_(
+ Empty(),
+ replace_context=err,
+ )
+
+
+class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue):
+ await_ = staticmethod(await_fallback)
diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py
new file mode 100644
index 0000000..bbc819f
--- /dev/null
+++ b/lib/sqlalchemy/util/topological.py
@@ -0,0 +1,100 @@
+# util/topological.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Topological sorting algorithms."""
+
+from .. import util
+from ..exc import CircularDependencyError
+
+__all__ = ["sort", "sort_as_subsets", "find_cycles"]
+
+
+def sort_as_subsets(tuples, allitems):
+
+ edges = util.defaultdict(set)
+ for parent, child in tuples:
+ edges[child].add(parent)
+
+ todo = list(allitems)
+ todo_set = set(allitems)
+
+ while todo_set:
+ output = []
+ for node in todo:
+ if todo_set.isdisjoint(edges[node]):
+ output.append(node)
+
+ if not output:
+ raise CircularDependencyError(
+ "Circular dependency detected.",
+ find_cycles(tuples, allitems),
+ _gen_edges(edges),
+ )
+
+ todo_set.difference_update(output)
+ todo = [t for t in todo if t in todo_set]
+ yield output
+
+
+def sort(tuples, allitems, deterministic_order=True):
+ """sort the given list of items by dependency.
+
+ 'tuples' is a list of tuples representing a partial ordering.
+
+ deterministic_order is no longer used, the order is now always
+ deterministic given the order of "allitems". the flag is there
+ for backwards compatibility with Alembic.
+
+ """
+
+ for set_ in sort_as_subsets(tuples, allitems):
+ for s in set_:
+ yield s
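+
+ # Usage sketch (editor's illustration): the tuples are (parent, child) pairs,
+ # i.e. the parent must sort before the child:
+ #
+ #     list(sort([("a", "b"), ("b", "c")], ["c", "b", "a"]))  # ['a', 'b', 'c']
+ #
+ # A cycle such as [("a", "b"), ("b", "a")] raises CircularDependencyError.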
+
+
+def find_cycles(tuples, allitems):
+ # adapted from:
+ # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
+
+ edges = util.defaultdict(set)
+ for parent, child in tuples:
+ edges[parent].add(child)
+ nodes_to_test = set(edges)
+
+ output = set()
+
+ # we'd like to find all nodes that are
+ # involved in cycles, so we do the full
+ # pass through the whole thing for each
+ # node in the original list.
+
+ # we can go just through parent edge nodes.
+ # if a node is only a child and never a parent,
+ # by definition it can't be part of a cycle. same
+ # if it's not in the edges at all.
+ for node in nodes_to_test:
+ stack = [node]
+ todo = nodes_to_test.difference(stack)
+ while stack:
+ top = stack[-1]
+ for node in edges[top]:
+ if node in stack:
+ cyc = stack[stack.index(node) :]
+ todo.difference_update(cyc)
+ output.update(cyc)
+
+ if node in todo:
+ stack.append(node)
+ todo.remove(node)
+ break
+ else:
+ node = stack.pop()
+ return output
+
+
+def _gen_edges(edges):
+ return set([(right, left) for left in edges for right in edges[left]])
diff --git a/lib/sunhpc/__init__.py b/lib/sunhpc/__init__.py
new file mode 100644
index 0000000..e0eb8f4
--- /dev/null
+++ b/lib/sunhpc/__init__.py
@@ -0,0 +1,5 @@
+version_major = "7"
+version_minor = "0"
+version_micro = "0"
+version = version_major + '.' + version_minor
+release = "Kamaitachi"
diff --git a/lib/sunhpc/commands/__init__.py b/lib/sunhpc/commands/__init__.py
new file mode 100644
index 0000000..50df3ae
--- /dev/null
+++ b/lib/sunhpc/commands/__init__.py
@@ -0,0 +1,1631 @@
+#coding:utf-8
+import re
+import os
+import sys
+import pwd
+import xml
+import time
+import fcntl
+import psutil
+import sunhpc
+import shutil
+import socket
+import syslog
+import struct
+import sqlite3
+import termios
+import argparse
+import textwrap
+import subprocess
+import prettytable
+import configparser
+import sunhpc.invoke
+import sqlalchemy.engine.result
+
+from struct import pack
+from collections import Counter
+from xml.sax import saxutils
+from xml.sax import handler
+from xml.sax import make_parser
+from xml.sax._exceptions import SAXParseException
+from prettytable import DEFAULT,MSWORD_FRIENDLY,PLAIN_COLUMNS,RANDOM
+
+LEFT_PADDING = 8
+RIGHT_PADDING = 8
+DEFAULT_HELP_WIDTH = 8
+
+def get_help_width():
+ try:
+ data = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, '1234')
+ columns = int(struct.unpack('hh', data)[1])
+ except (IOError, ValueError) as e:
+ print ("terminal size detection failed, using default width.")
+ return DEFAULT_HELP_WIDTH
+
+ columns = columns - RIGHT_PADDING
+ if columns > 0:
+ width = columns
+ else:
+ width = DEFAULT_HELP_WIDTH
+
+ return width
+
+class DatabaseConnection(object):
+
+ def __init__(self, db):
+ self.database = db
+
+ def search(self, command):
+ return self.database.search(command)
+
+ def execute(self, command):
+ return self.database.execute(command)
+
+ def fetchone(self):
+ return self.database.fetchone()
+
+ def fetchall(self):
+ return self.database.fetchall()
+
+ def getSession(self):
+ """helper function to get the session"""
+ return self.database.getSession()
+
+ def checkHostnameValidity(self, hostname):
+ return self.database.checkHostnameValidity(hostname)
+
+ def getHostname(self, hostname=None):
+ return self.database.getHostname(hostname)
+
+ def getHostIp(self, hostname=None):
+ return self.database.getHostIp(hostname)
+
+ def getHostAttr(self, host, key):
+ return self.getHostAttrs(host).get(key)
+
+ def getHostAttrs(self, host):
+ hostname = self.getHostname(host)
+ return self.database.getHostAttrs(hostname)
+
+ def getFrontendName(self):
+ return self.database.getFrontendName()
+
+ def setNewHostAttr(self, node, attr, value):
+ return self.database.setNewHostAttr(node, attr, value)
+
+ def commit(self):
+ return self.database.commit()
+
+ @property
+ def getDBFile(self):
+ return self.database.getDBFile()
+
+ @property
+ def getEngine(self):
+ return self.database.getEngine()
+
+class HostArgumentProcessor(object):
+
+ def getHostnames(self, names=None, managed_only=0):
+
+ list = []
+ if not names:
+ query = 'select name from nodes'
+ self.db.execute(query)
+ for host, in self.db.fetchall():
+ list.append(host)
+
+ if managed_only:
+ managed_list = []
+ for hostname in list:
+ if self.db.getHostAttr(hostname,
+ 'managed') == 'true':
+ managed_list.append(hostname)
+ return managed_list
+ return list
+
+ groups = {}
+ self.db.execute('select min(rack), max(rack) from nodes')
+ min,max = self.db.fetchone()
+ for i in range(min, max+1): # racks
+ self.db.execute("""select n.name from nodes n where n.rack=%d""" % i)
+ l = []
+ for node, in self.db.fetchall():
+ l.append(node)
+ groups['rack%d' % i] = l
+
+ dict = {}
+ for name in names:
+ if name.find('select') == 0: # SQL select
+ self.db.execute(name)
+ for host, in self.db.fetchall():
+ dict[host] = 1
+ elif name.find('%') >= 0: # SQL % pattern
+ self.db.execute("""select name from nodes where
+ name like '%s'""" % name)
+ for h, in self.db.fetchall():
+ dict[h] = 1
+ elif name in groups: # group name
+ for host in groups[name]:
+ dict[host] = 1
+ else: # host name
+ dict[self.db.getHostname(name)] = 1
+
+ list = sorted(dict.keys())
+ return list
+
+class NetworkFunc(object):
+
+ def netmask_to_cidr(self, netmask):
+ return sum([bin(int(i)).count('1') for i in str(netmask).split('.')])
+
+ def cidr_to_netmask(self, cidr):
+ return socket.inet_ntoa(pack('>I', 0xffffffff ^ (1 << 32 - int(cidr)) - 1))
+
+ def getNetwork(self, addr, mask):
+ address = sunhpc.core.ip.IPAddr(addr)
+ netmask = sunhpc.core.ip.IPAddr(mask)
+ return sunhpc.core.ip.IPAddr(address & netmask)
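+
+ # Usage sketch (editor's illustration): netmask_to_cidr('255.255.255.0')
+ # returns 24 and cidr_to_netmask(24) returns '255.255.255.0'. getNetwork()
+ # masks an address with its netmask (e.g. 10.1.2.3 / 255.255.255.0 is meant
+ # to yield the 10.1.2.0 network), relying on sunhpc.core.ip.IPAddr for the
+ # bitwise math.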
+
+class FirewallConnection(object):
+
+ def __init__(self, cmd):
+ self.cmd = cmd
+
+ def addports(self, ports=[], proto=[]):
+ """prots=['80', '10-20'], proto=['tcp', 'udp']"""
+
+ if not self.isRunning('firewalld', active='is-active'):
+ return False
+
+ if not ports: return False
+ pr = ' --permanent >/dev/null 1>&2'
+ fw = '/usr/bin/firewall-cmd --add-port='
+
+ pplist = []
+ if proto:
+ for c in proto:
+ for p in ports:
+ pplist.append('%s/%s' % (p, c))
+ else:
+ for p in ports:
+ pplist.append('%s/tcp' % p)
+
+ cmds = []
+ for pp in pplist:
+ if not self.query('--query-port=', pp):
+ cmds.append(fw + pp + pr)
+
+ for cmd in cmds:
+ ret = subprocess.run(cmd, shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, encoding="utf-8", timeout=3)
+
+ self.reload()
+
+ for pp in pplist:
+ if self.query('--query-port=', pp):
+ self.cmd.msg('Adding port \t%s successful.' % pp)
+ else:
+ self.cmd.msg('Adding port \t%s Failed.' % pp, 'w')
+
+ def addservice(self, name):
+ """firewall-cmd --add-service=tftp --permanent"""
+ if not name: return
+ pr = ' --permanent >/dev/null 1>&2'
+ fw = '/usr/bin/firewall-cmd --add-service='
+
+ srvlist = []
+ if isinstance(name, type([])):
+ srvlist.extend(name)
+ else:
+ srvlist.append(name)
+
+ cmds = []
+ for s in srvlist:
+ if not self.query('--query-service=', s):
+ cmds.append(fw + s + pr)
+
+ for cmd in cmds:
+ ret = subprocess.run(cmd, shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, encoding="utf-8", timeout=3)
+
+ self.reload()
+ for s in srvlist:
+ if self.query('--query-service=', s):
+ self.cmd.msg('Adding service \t%s successful.' % s)
+ else:
+ self.cmd.msg('Adding service \t%s Failed.' % s, 'w')
+
+ def reload(self):
+ self.cmd.shcmd('/usr/bin/firewall-cmd --reload')
+
+ def isRunning(self, name, active='status'):
+ cmd = '/usr/bin/systemctl %s %s >/dev/null 1>&2' % (active, name)
+ ret = subprocess.run(cmd, shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, encoding="utf-8", timeout=3)
+
+ # shell success code 0
+ if ret.returncode:
+ return False
+ return True
+
+ def query(self, active, name):
+ cmd = '/usr/bin/firewall-cmd %s%s >/dev/null 1>&2' % (active, name)
+ ret = subprocess.run(cmd, shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, encoding="utf-8", timeout=3)
+
+ # shell success code 0
+ if ret.returncode:
+ return False
+ return True
+
+class KickstartProcess(object):
+
+ def __init__(self, command):
+ self.command = command
+ self.ks = {}
+ self.ks['main'] = []
+ self.ks['pre'] = []
+ self.ks['post'] = []
+ self.ks['packages'] = []
+ self.ks['addons'] = []
+ self.ks['anaconda'] = []
+
+ self.services_list = []
+
+ self.ksprelog = '/mnt/sysimage/root/ks-sunhpc-pre.log'
+ self.kspostlog = '/root/ks-sunhpc-post.log'
+
+ self.isFrontend = False
+
+ def getKickstart(self, phases):
+ return self.ks[phases]
+
+ def addMain(self, txt):
+ if isinstance(txt, type([])):
+ self.ks['main'].append(txt)
+ else:
+ self.ks['main'].append([txt])
+
+ def addPre(self, txt, i='', arg='', log=None):
+ """arg: --nochroot"""
+
+ if i in ['sh', 'bash', 'python']:
+ interpreter = '%%pre --interpreter=/usr/bin/%s' % i
+ else:
+ interpreter = '%pre'
+
+ if log:
+ log = '--log=%s' % log
+ else:
+ log = '--log=%s' % self.ksprelog
+
+ list = []
+ if self.isFrontend:
+ list.extend(self.getContent(txt))
+ else:
+ list.append(' '.join([interpreter, arg, log]))
+ list.extend(self.getContent(txt))
+ list.append('%end\n')
+ self.ks['pre'].append(list)
+
+ def addPost(self, txt, i='', arg='', log=None):
+ """arg: --nochroot"""
+
+ if i in ['sh', 'bash', 'python']:
+ interpreter = '%%post --interpreter=/usr/bin/%s' % i
+ else:
+ interpreter = '%post'
+
+ if log:
+ log = '--log=%s' % log
+ else:
+ log = '--log=%s' % self.kspostlog
+
+ list = []
+ if self.isFrontend:
+ list.extend(self.getContent(txt))
+ else:
+ list.append(' '.join([interpreter, arg, log]))
+ list.extend(self.getContent(txt))
+ list.append('%end\n')
+ self.ks['post'].append(list)
+
+ def addPackages(self, txt, arg=''):
+
+ if self.isFrontend:
+ self.ks['packages'].append(self.getContent(txt))
+ else:
+ self.ks['packages'].append(['%%packages %s' % arg])
+ self.ks['packages'].append(self.getContent(txt))
+ self.ks['packages'].append(['%end\n'])
+
+ def addAddons(self, txt, arg=''):
+
+ if not self.isFrontend:
+ self.ks['addons'].append(self.getContent(txt))
+ self.ks['addons'].append(['%end\n'])
+
+ def addAnaconda(self, txt, arg=''):
+
+ if not self.isFrontend:
+ self.ks['anaconda'].append(['%%anaconda %s' % arg])
+ self.ks['anaconda'].append(self.getContent(txt))
+ self.ks['anaconda'].append(['%end\n'])
+
+ def getContent(self, txt):
+ content = []
+ if isinstance(txt, type([])) or \
+ isinstance(txt, type(())):
+ content.extend(txt)
+ else:
+ content.append(txt)
+ return content
+
+ def makefile(self, text='', name='', mode='',
+ owner='', perms='', expr='', quot=''):
+ """
+ text: file content, as a string or a list/tuple of strings
+ name: destination file path
+ mode: 'append' to append (>>) instead of overwrite (>)
+ owner: file owner, e.g. root.apache
+ perms: file permissions, e.g. 755
+ expr: shell command whose output is redirected (> or >>) into the file
+ quot: 'expand' to allow shell variable expansion inside the here-doc
+ """
+
+ if isinstance(text, type([])) or \
+ isinstance(text, type(())):
+ fileText = ''.join(text)
+ else:
+ fileText = textwrap.dedent(text)
+
+ fileName = name
+ fileMode = mode
+ fileOwner = owner
+ filePerms = perms
+ fileCommand = expr
+ fileQuoting = quot
+
+ s = ''
+ if fileName:
+ # split path
+ paths, fname = os.path.split(fileName)
+ if paths:
+ mkdirs = "[ ! -e %s ] && mkdir -p %s\n" % (paths, paths)
+ else:
+ mkdirs = ""
+
+ if fileMode == 'append':
+ gt = '>>'
+ else:
+ gt = '>'
+
+ if fileCommand:
+ s += "%s %s %s\n" % (fileCommand, gt, fileName)
+ else:
+ if not fileText:
+ s += 'touch %s\n' % fileName
+ else:
+ if fileQuoting == 'expand':
+ eof = "EOF"
+ else:
+ eof = "'EOF'"
+
+ s += mkdirs
+ s += "cat %s %s << %s" % (gt, fileName, eof)
+ if fileText[0] != '\n':
+ s += '\n'
+ s += fileText
+ if fileText[-1] != '\n':
+ s += '\n'
+ s += "EOF\n"
+
+ if filePerms:
+ s += 'chmod %s %s\n' % (filePerms, fileName)
+
+ if fileOwner:
+ s += 'chown %s %s\n' % (fileOwner, fileName)
+ return s
+
+class ModulesArgument(object):
+ """
+ Subclass of ArgumentParser that also examines boot arguments.
+ formatter_class=argparse.RawTextHelpFormatter,
+ formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(
+ prog, max_help_position=LEFT_PADDING, width=get_help_width()),
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.parser = None
+ self.subparser = None
+
+ self.autoGenerate(*args, **kwargs)
+
+ def autoGenerate(self, *args, **kwargs):
+ self.parser = argparse.ArgumentParser(
+ epilog='Sunhpc online help: <https://www.sunhpc.com/software/coreutils/>',
+ formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(
+ prog, max_help_position=LEFT_PADDING, width=get_help_width()),
+ *args, **kwargs
+ )
+
+ # version
+ self.parser.add_argument('-v', '--version', action='version', version='Sunhpc ' + sunhpc.version)
+
+ def handler(self):
+ return self.parser
+
+ def add_mutux(self):
+ """
+ Mutually exclusive arguments.
+ mp = self.ap.add_mutux()
+ mp.add_argument('--aaa', action='store_true')
+ mp.add_argument('--bbb', action='store_true')
+ mp.add_argument('--ccc', action='store_true')
+ Only one of aaa, bbb or ccc may be given at a time; supplying more than one raises an error.
+ """
+ return self.parser.add_mutually_exclusive_group()
+
+ def add_group(self, gname):
+ """
+ Custom argument group.
+ op = self.ap.add_group('Configure')
+ op.add_argument('-c', dest='config', metavar='config',
+ default='/opt/sunhpc/etc/sunhpc.conf',
+ help='use config file'
+ )
+ """
+ return self.parser.add_argument_group('%s' % gname)
+
+ def add_sub(self, command, helps=''):
+ """
+ Nested sub-command parsers.
+ sub = self.ap.add_sub('create', 'Create a directory')
+ sub.add_argument('--dirname', action='store', help='New directory to create')
+
+ sub = self.ap.add_sub('delete', 'Remove a directory')
+ sub.add_argument('--dirname', action='store', help='The directory to remove')
+ """
+ if not self.subparser:
+ self.subparser = self.parser.add_subparsers()
+ cmd = self.subparser.add_parser('%s' % command, help=helps)
+ return cmd
+
+
+class Command(object):
+ """Base class for all Sunhpc commands the general command line form
+ is as follows:
+
+ sunhpc ACTION COMPONENT OBJECT [ <ARGNAME ARGS> ... ]
+
+ ACTION(s):
+ add
+ create
+ list
+ load
+ sync
+ """
+ MustBeRoot = 1
+
+ def __init__(self, database):
+
+ self.db = DatabaseConnection(database)
+ self.fw = FirewallConnection(self)
+ self.ks = KickstartProcess(self)
+ self.ap = ModulesArgument()
+ self.newdb = self.db.database
+
+ self.os = os.uname()[0].lower()
+ self.arch = os.uname()[4]
+ self.net = NetworkFunc()
+ self.text = ''
+ self.output = []
+ self.fmtOutput = []
+
+ self._args = None
+ self._params = None
+
+ self.major = sunhpc.version_major
+ self.minor = sunhpc.version_minor
+ self.micro = sunhpc.version_micro
+
+ self.prefix = os.environ.get('SUNHPC_HOME') if os.environ.get('SUNHPC_HOME') else '/opt/sunhpc'
+ self._debug = True if os.environ.get('SUNHPC_DEBUG') else False
+
+ self.modules = sunhpc.core.utils.index_modules()
+ self.modules_count = Counter()
+ self.modules_count.update([module.split('.')[0] for module in self.modules])
+
+ @property
+ def debug(self):
+ return self._debug
+ def setdebug(self, status=None):
+ self._debug = True if status else False
+
+ def getNetcard(self):
+ netcard_info = {}
+ info = psutil.net_if_addrs()
+ for k,v in info.items():
+ for item in v:
+ if item[0] == 2 and not item[1] == '127.0.0.1': # family 2 == AF_INET; skip loopback
+ netcard_info[k] = item.address
+
+ return netcard_info
+
+ def writeConfig(self, filename, content):
+ text = ''
+ if type(content) in [type([]), type(())]:
+ for c in content:
+ text += c
+ else:
+ text += content
+
+ with open(filename, 'w') as f:
+ f.write(text)
+
+ def create_user(self, username, passwd='', uid=None, gid=None, home=None, nohome=False, shell='/bin/bash'):
+
+ # Only create the user when no /etc/passwd entry matches the name.
+ exists = False
+ with open('/etc/passwd', 'r') as fd:
+ for line in fd:
+ if re.search(r'^%s:' % username, line, re.I):
+ exists = True
+ break
+
+ if exists:
+ self.msg('"%s" already exists.' % username, 'w')
+ return
+
+ # useradd -M -s /sbin/nologin apache
+ cmd = 'useradd '
+ if nohome: cmd += '-M '
+ if uid: cmd += '-u %d ' % uid
+ if gid:
+ self.shcmd('groupadd -g %d %s' % (gid, username))
+ cmd += '-g %d ' % gid
+ if home: cmd += '-d %s ' % home
+ if shell: cmd += '-s %s ' % shell
+ if username:
+ cmd += '%s ' % username
+ if passwd:
+ cmd += '%s' % passwd
+
+ os.system(cmd)
+ self.msg('"%s" was created successfully.' % username)
+
+ def msg(self, msg, level='', r=None, q=False):
+ """ .e,g.. self.msg(messages, 'e') """
+
+ if q: return
+ r = str(r)
+
+ if level in ['e', 'err', 'error']:
+ if not r:
+ print("\033[91m[errs] %s\033[0m" % (msg))
+ else:
+ r = "\033[92m%s\033[0m\033[31m" % r
+ print("\033[31m[errs] %s \033[0m" % msg.replace('*', r, 1))
+
+ elif level in ['w', 'war', 'warn', 'warning']:
+ if not r:
+ print("\033[95m[warn] %s\033[0m" % (msg))
+ else:
+ r = "\033[92m%s\033[0m\033[95m" % r
+ print("\033[95m[warn] %s \033[0m" % msg.replace('*', r, 1))
+
+ elif level in ['i', 'inf', 'info', 'infomation']:
+ if not r:
+ print("\033[92m[info] %s\033[0m" % (msg))
+ else:
+ r = "\033[92m%s\033[0m\033[93m" % r
+ print("\033[93m[info] %s \033[0m" % msg.replace('*', r, 1))
+
+ elif level in ['o', 'ok']:
+ if not r:
+ print("\033[92m[+]\033[0m \033[96m%s\033[0m" % (msg))
+ else:
+ r = "\033[96m%s\033[0m\033[92m" % r
+ print("\033[92m[+] %s \033[0m" % msg.replace('*', r, 1))
+
+ elif level in ['a', 'abrt', 'abort']:
+ if not r:
+ #print("\033[31m[abrt] %s\033[0m" % (msg))
+ msg = "\033[31m[abrt] %s\033[0m" % (msg)
+ else:
+ r = "\033[93m%s\033[0m\033[31m" % r
+ msg = "\033[31m[abrt] %s \033[0m" % msg.replace('*', r, 1)
+ self.abort(msg)
+ else:
+ #print("\033[92m[+]\033[0m \033[96m%s\033[0m" % (msg))
+ print("\033[92m[+]\033[0m %s" % (msg))
+
+ def abort(self, msg):
+ syslog.syslog(syslog.LOG_ERR, msg.split('\n')[0])
+ raise sunhpc.core.utils.CommandError(msg)
+
+ def matchText(self, name, txt):
+ """name is file path, txt: match string"""
+ # 如果文件不存在,直接返回真,可以直接写入这个文件.
+ if not os.path.exists(name):
+ return 1
+
+ # 如果找到 txt 在文件中返回 False,不在附加到文件.
+ cmd = 'grep -q "%s" %s' % (txt, name)
+ ret = subprocess.run(cmd, shell=True)
+ if ret.returncode:
+ return 1
+ return 0
+
+ def isRootUser(self):
+ return True if os.geteuid() == 0 else False
+
+ def isApacheUser(self):
+ try:
+ if os.geteuid() == pwd.getpwnam('apache').pw_uid:
+ return 1
+ except:
+ pass
+ return 0
+
+ def str2bool(self, s):
+ if s and s.upper() in [ 'ON', 'YES', 'Y', 'TRUE', '1' ]:
+ return 1
+ return 0
+
+ def bool2str(self, b):
+ if b:
+ return 'yes'
+ return 'no'
+
+ def sbin(self, name, path="sbin"):
+ """Return sunhpc bin/sbin command path"""
+ return os.path.join(self.prefix, path, name)
+
+ def tranpath(self, filename):
+
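+ # Round-trips between a filesystem path and a dotted name, e.g. (illustrative):
+ # '/export/test.conf' <-> 'export.test..conf' (literal dots doubled, '/' mapped to '.').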
+ if '/' in filename:
+ # Convert the path into the dotted filename format.
+ fullpath = os.path.join(os.sep, filename)
+ dotPath = '..'.join(fullpath.split('.'))
+ dotPath = '.'.join(dotPath.split('/'))[1:]
+ else:
+ # Convert a dotted filename back into a normal path.
+ dotFullPath = '/'.join(filename.split('.'))
+ dotPath = '.'.join(dotFullPath.split('//'))
+ dotPath = os.path.join(os.sep, dotPath)
+ return dotPath
+
+ def replace(self, files, srckey, dstkey):
+ """替换指定文件中的关键字"""
+ if not os.path.exists(files):
+ self.abort('The %s no exists.' % files)
+
+ line = []
+ with open(files, 'r') as f:
+ line.extend(f.readlines())
+
+ newline = []
+ for i in line:
+ if srckey in i:
+ rep = i.replace(srckey, dstkey)
+ newline.append(rep)
+ else:
+ newline.append(i)
+
+ return newline
+
+ def shrun(self, cmd):
+ return subprocess.run(cmd, shell=True, check=True)
+
+ def shcmd(self, cmd, ret='str', flag=None, code=False, env=None, cwd=None, preexec_fn=None):
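+ # Returns a dict: stdout in info['o'], stderr in info['e'] and, when code=True,
+ # the exit status in info['c']; e.g., self.shcmd('uname -r')['o'] (illustrative).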
+ info = {}
+ cmd = ' '.join(cmd.split())
+ p = subprocess.Popen(cmd, shell=True, bufsize = 0,
+ cwd=cwd, env=env, preexec_fn=preexec_fn,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+
+ if ret == 'str':
+ info['o'] = p.stdout.read().decode('UTF-8')
+ info['e'] = p.stderr.read().decode('UTF-8')
+ if ret == 'list':
+ info['o'] = [ l.decode('UTF-8') for l in p.stdout.readlines() ]
+ info['e'] = [ l.decode('UTF-8') for l in p.stderr.readlines() ]
+ if code:
+ #info['code'] = p.poll()
+
+ # p.wait() can deadlock: check the pipe size with "ulimit -a" -- if the pipe buffer
+ # is too small and the output exceeds it, p.wait() blocks forever.
+ # info['c'] = p.wait()
+
+ # We therefore use p.communicate() instead, which drains the output into memory,
+ # so this is safe in practice.
+ # info['o'], info['e'] = p.communicate() # now wait
+ p.communicate() # now wait
+
+ # None means the process has not finished yet;
+ # -N means the process was terminated by signal N.
+ info['c'] = p.returncode # get exe return value
+ return info
+
+ def copyfiles(self, files, dstdirs):
+
+ if not files:
+ self.msg('supply files list is empty.', 'e')
+ return
+
+ if not os.path.exists(dstdirs):
+ os.makedirs(dstdirs)
+
+ sys.stdout.write("\33[?25l")
+ for i in range(len(files)):
+ self.shcmd("/usr/bin/cp --parents %s %s" % (files[i], dstdirs))
+ self.precent(i+1, len(files), files[i])
+ sys.stdout.write("\33[?25h")
+ print ('', end='\n')
+
+ def precent(self, step, total, name=''):
+
+ rate = step / total
+ percent = int(rate * 100)
+ txt = "\033[92m[+]\033[0m\033[95m Running "
+ txt += "%02d%% [%s/%s] finished.\033[0m\r" % (percent, step, total)
+ print (txt, end="")
+
+ def systemctl(self, name, active=None):
+
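+ # Returns True when "systemctl is-active <name>" exits 0 (unit is active), otherwise False.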
+ cmd = '/usr/bin/systemctl is-active %s' % name
+ ret = subprocess.run(cmd, shell=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, encoding="utf-8", timeout=3)
+
+ # shell success code 0
+ if ret.returncode:
+ return False
+ return True
+
+ def system(self, cmd, std=1):
+
+ sys.stdout.write("\33[?25l")
+ if std:
+ subprocess.call(cmd, shell=True)
+ else:
+ p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ close_fds=True)
+
+ w, r ,e = (p.stdin, p.stdout, p.stderr)
+ currLength = 0
+ prevLength = 0
+ spinChars = '-\|/'
+ spinIndex = 0
+ while 1:
+ line = e.readline().decode()
+ if not line:
+ break
+ if len(line) > 79:
+ data = line[0:78]
+ else:
+ data = line[:-1]
+
+ currLength = len(data)
+ pad = ''
+ # Pad with spaces to erase leftover characters from the previous, longer line.
+ for i in range(0, prevLength - currLength):
+ pad = pad + ' '
+ spin = spinChars[spinIndex % len(spinChars)]
+ spinIndex = spinIndex + 1
+ print (spin + data + pad + '\r', end='')
+ prevLength = currLength
+ sys.stdout.flush()
+ r.close()
+ w.close()
+ e.close()
+
+ pad = ''
+ for i in range(0,78):
+ pad = pad + ' '
+ print ('\r%s\r' % pad,)
+ sys.stdout.write("\33[?25h")
+
+ def loadModules(self, name, args=''):
+ """Parser Kickstart modules sunhpc.modules.*"""
+
+ module_list = []
+ for mod in self.modules:
+ mod_name = mod.split('.')[0]
+
+ if name != mod_name: continue
+
+ module = 'sunhpc.modules.%s' % mod
+ module = sunhpc.core.utils.import_centrals(module)(self)
+ module_list.append(module)
+
+ if not module_list: return
+
+ text = []
+ for mod in module_list:
+ modname = mod.__module__.split('.')
+ if modname[2] == 'kickstart':
+ text.append('# Include kickstart module %s' % modname[-1])
+ else:
+ text.append('# Include kickstart module %s' % modname[-1])
+
+ txt = mod.run(args)
+ if not txt: continue
+
+ if isinstance(txt, type([])):
+ text.extend(txt)
+ elif isinstance(txt, type(())):
+ for e in txt:
+ text.append(e)
+ else:
+ text.append(txt)
+
+ return text
+
+ def loadPlugins(self):
+
+ plugins = []
+ #
+ # A plugin file may be added under every command directory, so the module
+ # path has to be derived from the command itself,
+ # e.g. the /commands/sync/users command.
+ dirs = eval('%s.__path__[0]' % self.__module__)
+ for fe in os.listdir(dirs):
+ # Extension modules inside a command directory are only recognised when
+ # their file name starts with "plugin_", e.g. "plugin_00_*.py".
+ if fe.split('_')[0] != 'plugin':
+ continue
+ if os.path.splitext(fe)[1] != '.py':
+ continue
+
+ # Build the full module name from the file found in the command directory,
+ # e.g., sunhpc.commands.sync.users.plugin_fixusers
+ module = '%s.%s' % (self.__module__, os.path.splitext(fe)[0])
+
+ # Import the module.
+ __import__(module)
+ module = eval(module)
+ try:
+ # Instantiate the Plugin class defined in the module.
+ o = getattr(module, 'Plugin')(self)
+ except AttributeError:
+ continue
+
+ plugins.append(o)
+
+ return plugins
+
+ def runPlugins(self, args='', plugins=None):
+ if not plugins:
+ plugins = self.loadPlugins()
+
+ if not plugins:
+ return
+
+ for plugin in plugins:
+ syslog.syslog(syslog.LOG_INFO, 'Run %s' % plugin)
+ plugin.run(args)
+
+ def clearText(self):
+ self.text = ''
+ def addText(self, s):
+ if s: self.text += s
+ def getText(self):
+ return self.text
+
+ def dictOutput(self, strDict, sep='\t'):
+ maxlen = max(map(len, strDict.keys()))
+ for k in strDict:
+ strings = str(strDict[k])
+ if strings.strip().lower() in ['ok', 'on', 'yes']:
+ value = '|\033[1;32m %s \033[0m' % strings.upper()
+ elif strings.strip().lower() in ['fail', 'off', 'no', 'error', 'failed']:
+ value = '|\033[1;31m %s \033[0m' % strings.upper()
+ else:
+ value = '| %s' % strings
+ print ('\t\033[1;95m%s\033[0m %s \033[1;96m%s\033[0m' % (k.ljust(maxlen), sep, value))
+
+ def beginFmtOutput(self, header=[]):
+ self.fmtOutput = prettytable.PrettyTable(header)
+ def addFmtOutput(self, list_string):
+ """use prettytable fotmatting output buffer."""
+
+ # 如果不是列表,转换成列表进行添加.
+ if not isinstance(list_string, type([])):
+ list_string = [ list_string ]
+ if isinstance(list_string, type(())):
+ list_string = list(list_string)
+
+ # 如果是[[]]列表,使用rows一次添加多行.否则使用单行添加row.
+ if list_string and isinstance(list_string[0], type([])):
+ self.fmtOutput.add_rows(list_string)
+ else:
+ self.fmtOutput.add_row(list_string)
+ def endFmtOutput(self, name='', sort='', rever=False, style='def', align='l'):
+ """
+ :types name: string
+ :param name: 需要匹配对其的名称.
+
+ :types sort: string
+ :param sort: 需要匹配排序的名称.
+
+ :types rever: bool
+ :param rever: 反向排序.
+
+ :types style: string
+ :param style: 设置边框样式.
+
+ :types align: string
+ :param align: 对其方式, l左对齐,r右对齐,c居中.
+ """
+ style_id = {'def':DEFAULT, 'ms':MSWORD_FRIENDLY, 'plain': PLAIN_COLUMNS, 'random': RANDOM}
+ if style in style_id:
+ self.fmtOutput.set_style(style_id[style])
+
+ if align in ['l', 'r', 'c']:
+ names = name.split(',')
+ if names[0] != '':
+ for n in names:
+ self.fmtOutput.align[n] = align
+ else:
+ self.fmtOutput.align = align
+
+ if sort != '':
+ print (self.fmtOutput.get_string(sortby=sort, reversesort=rever))
+ else:
+ print (self.fmtOutput)
+
+ def beginOutput(self):
+ self.output = []
+
+ def addOutput(self, owner, vals):
+
+ plist = [ '%s:' % owner ]
+ if isinstance(vals, type([])):
+ plist.extend(vals)
+ elif isinstance(vals, type(())) or \
+ isinstance(vals, sqlalchemy.engine.result.Row):
+ for e in vals:
+ plist.append(e)
+ else:
+ plist.append(vals)
+ self.output.append(plist)
+
+ def endOutput(self, header=[], padChar='-', trimOwner=1,linesep='\n'):
+
+ if not self.output: return
+
+ showHeader = True
+ if 'output-header' in self._params:
+ showHeader = self.str2bool(self._params['output-header'])
+
+ self.outputCols = []
+ if 'output-col' in self._params:
+ showCols = self._params['output-col'].split(',')
+ for i in header:
+ if i.lower() in showCols:
+ self.outputCols.append(True)
+ else:
+ self.outputCols.append(False)
+
+ if trimOwner:
+ owner = ''
+ self.startOfLine = 1
+ for line in self.output:
+ if not owner:
+ owner = line[0]
+ if not owner == line[0]:
+ self.startOfLine = 0
+ else:
+ self.startOfLine = 0
+
+ if header and showHeader:
+ plist = []
+ for field in header:
+ plist.append(field.upper())
+ output = [ plist ]
+ output.extend(self.output)
+ else:
+ output = self.output
+
+ colwidth = []
+ for line in output:
+ for i in range(0, len(line)):
+ if len(colwidth) <= i:
+ colwidth.append(0)
+ if not isinstance(line[i], str):
+ if line[i] == None:
+ itemlen = 0
+ else:
+ itemlen = len(repr(line[i]))
+ else:
+ itemlen = len(line[i])
+
+ if itemlen > colwidth[i]:
+ colwidth[i] = itemlen
+
+ o = ''
+ for line in output:
+ plist = []
+ for i in range(self.startOfLine, len(line)):
+ if line[i] == None:
+ s = ''
+ else:
+ s = str(line[i])
+ if padChar != '':
+ if s:
+ o = s.ljust(colwidth[i])
+ else:
+ o = ''.ljust(colwidth[i], padChar)
+ else:
+ o = s
+ plist.append(o)
+
+ self.addText('%s%s' % (self.outputRow(plist),linesep))
+
+ def outputRow(self, plist):
+ if self.outputCols:
+ l = []
+ for i in range(0, len(plist)):
+ if self.outputCols[i + self.startOfLine]:
+ l.append(plist[i])
+ return ' '.join(l)
+ else:
+ return ' '.join(plist)
+
+ def usage(self):
+ if self.__doc__:
+ handler = DocStringHandler()
+ parser = make_parser()
+ parser.setContentHandler(handler)
+ try:
+ parser.feed('<docstring>%s</docstring>' % self.__doc__)
+ except:
+ return '-- invalid doc string --'
+ return handler.getUsageText()
+ else:
+ return '-- missing doc string --'
+
+ def help(self, command, flags={}):
+
+ if not self.__doc__: return
+
+ if self.MustBeRoot:
+ users = [ 'root', 'apache' ]
+ else:
+ users = []
+
+ if 'format' in flags:
+ formats = flags['format'].lower()
+ else:
+ formats = 'plain'
+
+ if formats == 'raw':
+ i = 1
+ for line in self.__doc__.split('\n'):
+ self.addText('%d:%s\n' % (i, line))
+ i += 1
+ else:
+ handler = DocStringHandler(command, users)
+ parser = make_parser()
+ parser.setContentHandler(handler)
+ parser.feed('<docstring>%s</docstring>' % self.__doc__)
+ if formats == 'docbook':
+ self.addText(handler.getDocbookText())
+ elif formats == 'parsed':
+ self.addText(handler.getParsedText())
+ elif formats == 'sphinx':
+ self.addText(handler.getSphinxText())
+ else:
+ self.addText(handler.getPlainText())
+
+ def fillPositionalArgs(self, names, params=None, args=None):
+
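+ # Worked example (illustrative): with names=('iface',), args=['compute-0-0', 'eth1']
+ # and no 'iface=' parameter, this returns [['compute-0-0'], 'eth1'] -- the last
+ # positional argument fills 'iface' and the remaining args are returned first.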
+ if not type(names) in [ type([]), type(()) ]:
+ names = [ names ]
+
+ if not params:
+ params = self._params
+ if not args:
+ args = self._args
+
+ plist = []
+ for name in names:
+ if name in params:
+ plist.append(params[name])
+ else:
+ plist.append(None)
+
+ variate = []
+ trimmed = args
+ plist.reverse()
+ for e in plist:
+ if not e and len(trimmed):
+ variate.append(trimmed[-1])
+ trimmed = trimmed[:-1]
+ else:
+ variate.append(e)
+ variate.reverse()
+
+ rlist = []
+ rlist.append(trimmed)
+ rlist.extend(variate)
+ return rlist
+
+ def fillParams(self, names, params=None):
+
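+ # e.g., fillParams([('force', 'no'), ('enc', 'sha')]) returns the 'force' and 'enc'
+ # values from self._params, falling back to 'no' and 'sha' when they are absent.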
+ if not type(names) in [ type([]), type(()) ]:
+ names = [ names ]
+
+ pdlist = []
+ for e in names:
+ if type(e) in [ type([]), type(()) ] and len(e) == 2:
+ tuples = ( e[0], e[1] )
+ elif type(e) in [ type([]), type(()) ]:
+ tuples = ( e[0], None )
+ else:
+ # e is a plain string name with no default
+ tuples = ( e, None )
+ pdlist.append(tuples)
+
+ if not params:
+ params = self._params
+
+ plist = []
+ for (key, default) in pdlist:
+ if key in params:
+ plist.append(params[key])
+ else:
+ plist.append(default)
+ return plist
+
+ def command(self, command, args=[]):
+
+ modpath = 'sunhpc.commands.%s' % command
+ __import__(modpath)
+ mod = eval(modpath)
+
+ try:
+ o = getattr(mod, 'Command')(self.db)
+ n = ' '.join(command.split('.'))
+ except AttributeError:
+ return ''
+
+ o.runWrapper(n, args)
+ return o.getText()
+
+ def runWrapper(self, name, args):
+
+ username = pwd.getpwuid(os.geteuid())[0]
+ if args:
+ command = '%s %s' % (name, ' '.join(args))
+ else:
+ command = name
+
+ syslog.syslog(syslog.LOG_INFO,
+ 'user %s called "%s"' % (username, command))
+
+ pdict = {}
+ plist = []
+ nparams = 0
+ flagpattern = re.compile(r"^[a-zA-Z0-9\-_+]+=")
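+ # e.g., 'rack=0' becomes pdict['rack'] = '0', while bare words such as
+ # 'compute-0-0' are collected into plist as positional arguments.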
+ for arg in args:
+ tokens = arg.split()
+ if tokens[0] == 'select':
+ plist.append(arg)
+ elif flagpattern.match(arg):
+ (key, val) = arg.split('=', 1)
+ pdict[key] = val
+ if nparams == 0:
+ pdict['@SUNHPCPARAM0'] = arg
+ nparams += 1
+ else:
+ plist.append(arg)
+
+ if plist and plist[0] == 'help':
+ self.help(name, pdict)
+ else:
+ if self.MustBeRoot and not \
+ (self.isRootUser() or self.isApacheUser()):
+ self.abort('command "%s" requires root' % name)
+ else:
+ self._args = plist
+ self._params = pdict
+ try:
+ self.run(self._params, self._args)
+ except sunhpc.core.utils.HostnotfoundException as e:
+ if self.debug:
+ traceback.print_exc()
+ self.abort(str(e))
+ except sqlite3.OperationalError as e:
+ if self.debug:
+ traceback.print_exc()
+ self.abort("Dabase error: " + str(e))
+
+ def getFunction(self, name):
+ import inspect
+ path = inspect.getfile(name)
+ self.msg(path)
+
+ def makeEULA(self, dst):
+ s = 'SunHPC Linux %s EULA\n\n' % sunhpc.version_major
+ s += 'This version was created using SunHPC %s\n\n' % sunhpc.version
+ s += 'The Distribution is released as CentOS. Individual packages in the\n'
+ s += 'distribution come with their own licences.\n'
+ with open(dst, 'w') as f:
+ f.write(s)
+
+ def makeBuildTag(self, dst):
+ timefmt = time.strftime('%Y%m%d-%H%M', time.localtime())
+ with open(dst, 'w') as f:
+ f.write('%s\n' % timefmt)
+
+ def makeDiscInfo(self, dst):
+ with open(dst, 'w') as f:
+ f.write('%s\n' % time.time())
+ f.write('%s\n' % sunhpc.version)
+ f.write('%s\n' % self.arch)
+
+ def makeTreeInfo(self, dst):
+ config = configparser.ConfigParser()
+ config.add_section('general')
+ config.set('general', 'name', 'Sunhpc-%s' % sunhpc.version_major)
+ config.set('general', 'family', 'Sunhpc')
+ config.set('general', 'timestamp', '%.2f' % time.time())
+ config.set('general', 'variant', '')
+ config.set('general', 'version', sunhpc.version_major)
+ config.set('general', 'packagedir', '')
+ config.set('general', 'arch', self.arch)
+
+ config.add_section('stage2')
+ config.set('stage2', 'mainimage', 'images/install.img')
+
+ config.add_section('images-%s' % self.arch)
+
+ vmlinuz = 'images/pxeboot/vmlinuz-%s-%s' % (sunhpc.version, self.arch)
+ initrds = 'images/pxeboot/initrd-%s-%s' % (sunhpc.version, self.arch)
+ updates = 'images/pxeboot/updates.img'
+ config.set('images-%s'% self.arch, 'kernel', vmlinuz)
+ config.set('images-%s'% self.arch, 'initrd', initrds)
+ #config.set('images-%s'% self.arch, 'upgrade', updates)
+ config.set('images-%s'% self.arch, 'boot.iso', 'images/boot.iso')
+
+ with open(dst, 'w') as f:
+ config.write(f)
+
+ def buildstamp(self, dst_stamp):
+ uuidTime = time.strftime('%Y%m%d%H%M%S', time.localtime())
+ with open(dst_stamp, 'w') as f:
+ f.write('[Main]\n')
+ f.write('Product=Sunhpc\n')
+ f.write('Version=%s\n' % sunhpc.version_major)
+ f.write('BugURL=your distribution provided bug reporting tool\n')
+ f.write('IsFinal=True\n')
+ f.write('UUID=%s.%s\n' % (uuidTime, self.arch))
+ f.write('[Compose]\n')
+ f.write('Lorax=19.7.19-1\n')
+
+ def run(self, flags, args):
+ pass
+
+class DocStringHandler(
+ handler.ContentHandler, handler.ErrorHandler,
+ handler.EntityResolver, handler.DTDHandler ):
+
+ def __init__(self, name='', users=[]):
+ handler.ContentHandler.__init__(self)
+ self.text = ''
+ self.name = name
+ self.users = users
+ self.section= {}
+ self.section['arg'] = []
+ self.section['param'] = []
+ self.section['example'] = []
+ self.section['related'] = []
+ self.section['description'] = ''
+ self.parser = make_parser()
+ self.parser.setContentHandler(self)
+
+ def getDocbookText(self):
+ s = ''
+ s += '<section id="sunhpc-%s" xreflabel="%s">\n' % \
+ ('-'.join(self.name.split(' ')), self.name)
+ s += '<title>%s</title>\n' % self.name
+ s += '<cmdsynopsis>\n'
+ s += '\t<command>sunhpc %s</command>\n' % self.name
+ for ((name, type, opt, rep), txt) in self.section['arg']:
+ if opt:
+ choice = 'opt'
+ else:
+ choice = 'req'
+ if rep:
+ repeat = 'repeat'
+ else:
+ repeat = 'norepeat'
+ s += '\t<arg rep="%s" choice="%s">%s</arg>\n' % \
+ (repeat, choice, name)
+ for ((name, type, opt, rep), txt) in self.section['param']:
+ if opt:
+ choice = 'opt'
+ else:
+ choice = 'req'
+ if rep:
+ repeat = 'repeat'
+ else:
+ repeat = 'norepeat'
+ s += '\t<arg rep="%s" choice="%s">' % (repeat, choice)
+ s += '%s=<replaceable>%s</replaceable>' % (name, type)
+ s += '</arg>\n'
+ s += '</cmdsynopsis>\n'
+ s += '<para>\n'
+ s += saxutils.escape(self.section['description'])
+ s += '\n</para>\n'
+ if self.section['arg']:
+ s += '<variablelist><title>arguments</title>\n'
+ for ((name, type, opt, rep), txt) in \
+ self.section['arg']:
+ s += '\t<varlistentry>\n'
+ if opt:
+ term = '<optional>%s</optional>' % name
+ else:
+ term = name
+ s += '\t<term>%s</term>\n' % term
+ s += '\t<listitem>\n'
+ s += '\t<para>\n'
+ s += saxutils.escape(txt)
+ s += '\n\t</para>\n'
+ s += '\t</listitem>\n'
+ s += '\t</varlistentry>\n'
+ s += '</variablelist>\n'
+ if self.section['param']:
+ s += '<variablelist><title>parameters</title>\n'
+ for ((name, type, opt, rep), txt) in \
+ self.section['param']:
+ s += '\t<varlistentry>\n'
+ if opt:
+ optStart = '<optional>'
+ optEnd = '</optional>'
+ else:
+ optStart = ''
+ optEnd = ''
+ key = '%s=' % name
+ val = '<replaceable>%s</replaceable>' % type
+ s += '\t<term>%s%s%s%s</term>\n' % \
+ (optStart, key, val, optEnd)
+ s += '\t<listitem>\n'
+ s += '\t<para>\n'
+ s += saxutils.escape(txt)
+ s += '\n\t</para>\n'
+ s += '\t</listitem>\n'
+ s += '\t</varlistentry>\n'
+ s += '</variablelist>\n'
+ if self.section['example']:
+ s += '<variablelist><title>examples</title>\n'
+ for (cmd, txt) in self.section['example']:
+ s += '\t<varlistentry>\n'
+ s += '\t<term>\n'
+ if 'root' in self.users:
+ s += '# '
+ else:
+ s += '$ '
+ s += 'sunhpc %s' % cmd
+ s += '\n\t</term>\n'
+ s += '\t<listitem>\n'
+ s += '\t<para>\n'
+ s += saxutils.escape(txt)
+ s += '\n\t</para>\n'
+ s += '\t</listitem>\n'
+ s += '\t</varlistentry>\n'
+ s += '</variablelist>\n'
+ if self.section['related']:
+ s += '<variablelist><title>related commands</title>\n'
+ for related in self.section['related']:
+ s += '\t<varlistentry>\n'
+ s += '\t<term>'
+ s += '<xref linkend="sunhpc-%s">' % \
+ '-'.join(related.split(' '))
+ s += '</term>\n'
+ s += '\t<listitem>\n'
+ s += '\t<para>\n'
+ s += '\n\t</para>\n'
+ s += '\t</listitem>\n'
+ s += '\t</varlistentry>\n'
+ s += '</variablelist>\n'
+ s += '</section>'
+ return s
+
+ def getUsageText(self):
+ s = ''
+ for ((name, type, opt, rep), txt) in self.section['arg']:
+ if opt:
+ s += '[%s]' % name
+ else:
+ s += '{%s}' % name
+ if rep:
+ s += '...'
+ s += ' '
+ for ((name, type, opt, rep), txt) in self.section['param']:
+ if opt:
+ s += '[%s=%s]' % (name, type)
+ else:
+ s += '{%s=%s}' % (name, type)
+ if rep:
+ s += '...'
+ s += ' '
+ if s and s[-1] == ' ':
+ return s[:-1]
+ else:
+ return s
+
+ def getSphinxText(self):
+ if 'root' in self.users:
+ prompt = '#'
+ else:
+ prompt = '$'
+
+ s = ':orphan:\n\n'
+ s += '%s\n' % self.name
+ s += '%s\n\n' % ("-" * len(self.name))
+ s += '.. role:: defn\n\n'
+ utxt = self.getUsageText()
+ if len(utxt):
+ s += ':defn:`sunhpc %s` *%s*\n' % (self.name, utxt)
+ else:
+ s += ':defn:`sunhpc %s` %s\n' % (self.name, utxt)
+ s += '\n\n**Description:**\n'
+ s += self.section['description'].replace('\t',' ')
+ if self.section['arg']:
+ s += '\n**Arguments:**\n\n'
+ for ((name, type, opt, rep), txt) in \
+ self.section['arg']:
+ if opt:
+ s += '*[%s]*' % name
+ else:
+ s += '*{%s}*' % name
+ txt = txt.replace('*', '\\*')
+ s += '\n%s\n' % txt.replace('\t', ' ')
+ if self.section['param']:
+ s += '\n**Parameters:**\n\n'
+ for ((name, type, opt, rep), txt) in \
+ self.section['param']:
+ if opt:
+ s += '*[%s=%s]*' % (name, type)
+ else:
+ s += '*{%s=%s}*' % (name, type)
+ txt = txt.replace('*', '\\*')
+ s += '\n%s\n' % txt.replace('\t', ' ')
+ if self.section['example']:
+ s += '\n**Examples:**\n'
+ for (cmd, txt) in self.section['example']:
+ txt = txt.replace('*', '\\*')
+ s += '%s::\n\n' % txt.replace('\t',' ')
+ s += ' %s sunhpc %s\n' % (prompt, cmd)
+ if self.section['related']:
+ s += '\n**Related Commands:**\n\n'
+ for related in self.section['related']:
+ s += ' * :ref:`sunhpc-%s`\n' % related.replace(' ','-')
+
+ word = self.name.split()[0]
+ s += '\n:ref:`%s commands <%s-ref>`\n' % (word, word)
+
+ return s
+
+ def getPlainText(self):
+ if 'root' in self.users:
+ prompt = '\033[91m#\033[0m'
+ else:
+ prompt = '\033[96m$\033[0m'
+ s = ''
+ s += 'sunhpc %s %s' % (self.name, self.getUsageText())
+ s += '\n\n\033[95m Description: \033[0m\n'
+ s += '\033[96m%s\033[0m' % self.section['description']
+ if self.section['arg']:
+ s += '\n\033[95m Arguments: \033[0m\n\n'
+ for ((name, type, opt, rep), txt) in \
+ self.section['arg']:
+ if opt:
+ s += '\t\033[92m[%s]\033[0m' % name
+ else:
+ s += '\t\033[92m{%s}\033[0m' % name
+ s += '\033[96m%s\033[0m' % txt
+ if self.section['param']:
+ s += '\n\033[95m Parameters: \033[0m\n\n'
+ for ((name, type, opt, rep), txt) in \
+ self.section['param']:
+ if opt:
+ s += '\t\033[92m[%s=%s]\033[0m' % (name, type)
+ else:
+ s += '\t\033[92m{%s=%s}\033[0m' % (name, type)
+ #s += '\n%s\n' % txt
+ s += '\033[96m%s\033[0m' % txt
+ if self.section['example']:
+ s += '\n\033[95m Examples: \033[0m\n\n'
+ for (cmd, txt) in self.section['example']:
+ s += '\t%s sunhpc %s' % (prompt, cmd)
+ #s += '%s\n' % txt
+ s += '\033[96m%s\033[0m' % txt
+ if self.section['related']:
+ s += '\n\033[95m Related Commands: \033[0m\n'
+ for related in self.section['related']:
+ s += '\tsunhpc %s\n' % related
+ return s
+
+ def getParsedText(self):
+ return '%s' % self.section
+
+ def startElement(self, name, attrs):
+ if not self.section['description']:
+ self.section['description'] = self.text
+ self.key = None
+ self.text = ''
+ if name in [ 'arg', 'param' ]:
+ try:
+ type = attrs.get('type')
+ except:
+ type = 'string'
+ try:
+ optional = int(attrs.get('optional'))
+ except:
+ if name == 'arg':
+ optional = 0
+ if name == 'param':
+ optional = 1
+ try:
+ repeat = int(attrs.get('repeat'))
+ except:
+ repeat = 0
+ name = attrs.get('name')
+ self.key = (name, type, optional, repeat)
+ elif name == 'example':
+ self.key = attrs.get('cmd')
+
+ def endElement(self, name):
+ if name == 'docstring':
+ self.section['param'].sort()
+ self.section['related'].sort()
+ elif name in [ 'arg', 'param', 'example' ]:
+ self.section[name].append((self.key, self.text))
+ else:
+ if name in self.section:
+ self.section[name].append(self.text)
+
+ def characters(self, s):
+ self.text += s
+
+class RollArgumentProcessor(object):
+ """An Interface class to add the ability to process roll arguments."""
+
+ def getRollNames(self, args, params):
+ if 'version' in params:
+ version = params['version']
+ else:
+ version = '%' # SQL wildcard
+
+ plist = []
+ if not args: args = [ '%' ] # find all roll names
+ for arg in args:
+ rows = self.db.search("""select distinct name,version
+ from rolls where name like '%s' and
+ version like '%s'""" % (arg, version))
+ if rows == 0 and arg == '%': # empty table is OK
+ continue
+ if rows < 1:
+ self.abort('unknown roll name "%s"' % arg)
+ for (name, ver) in self.db.fetchall():
+ plist.append((name, ver))
+
+ return plist
+
+class Plugin(object):
+ """Base class for all Sunhpc command plug-ins."""
+ def __init__(self, command):
+ self.cmd = command
+ self.db = command.db
+
+ def run(self, args):
+ """All derived classes should override this method. This
+ is the entry point into the Plugin object."""
+ pass
diff --git a/lib/sunhpc/commands/add/__init__.py b/lib/sunhpc/commands/add/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/add/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/add/host/__init__.py b/lib/sunhpc/commands/add/host/__init__.py
new file mode 100644
index 0000000..7da1164
--- /dev/null
+++ b/lib/sunhpc/commands/add/host/__init__.py
@@ -0,0 +1,124 @@
+#coding:utf-8
+import os
+import sys
+import string
+import sunhpc
+from sunhpc.db.mappings.base import *
+class command(sunhpc.commands.HostArgumentProcessor, sunhpc.commands.add.command):
+ pass
+
+class Command(command):
+ """
+ Add a new host to the cluster.
+
+ <arg type='string' name='host'>
+ A single host name. If the hostname is of the standard form
+ basename-rack-rank, the default values for the membership, rack,
+ and rank parameters are taken from the hostname.
+ </arg>
+
+ <param type='int' name='cpus'>
+ Number of CPUs (cores) in the given host. If not provided the
+ default of 1 CPU is inserted into the database.
+ </param>
+
+ <param type='string' name='membership'>
+ Appliance membership name. If not provided and the host name is of
+ the standard form the membership is taken from the basename of
+ the host.
+ </param>
+
+ <param type='int' name='rack'>
+ The number of the rack where the machine is located. The convention
+ in Rocks is to start numbering at 0. If not provided and the host
+ name is of the standard form the rack number is taken from the host
+ name.
+ </param>
+
+ <param type='int' name='rank'>
+ The position of the machine in the rack. The convention in Rocks
+ is to number from the bottom of the rack to the top starting at 0.
+ If not provided and the host name is of the standard form the rank
+ number is taken from the host name.
+ </param>
+
+ <param type='string' name='os'>
+ The operating system name. The default is: linux.
+ </param>
+
+ <example cmd='add host compute-0-1'>
+ Adds the host "compute-0-1" to the database with 1 CPU, a membership
+ name of "compute", a rack number of 0, and rank of 1.
+ </example>
+
+ <example cmd='add host frontend rack=0 rank=0 membership=Frontend'>
+ Adds the host "frontend" to the database with 1 CPU, a membership name
+ of "Frontend", a rack number of 0, and rank of 1.
+ </example>
+
+ <related>add host interface</related>
+
+ """
+
+ def run(self, params, args):
+
+ if len(args) != 1:
+ self.abort('must supply one host')
+ host = args[0]
+
+ self.newdb.checkHostnameValidity(host)
+ self.newdb.commit()
+
+ s = self.newdb.getSession()
+ try:
+ basename, rack, rank = host.split('-')
+ rack = int(rack)
+ rank = int(rank)
+ except:
+ rack = None
+ rank = None
+
+ (numCPUs, rack, rank, osname) = self.fillParams([
+ ('cpus', 1),
+ ('rack', rack),
+ ('rank', rank),
+ ('os', None) ])
+
+ if rack == None:
+ self.abort('rack not specified')
+ if rank == None:
+ self.abort('rank not specified')
+
+ if osname is None:
+ osname = 'linux'
+
+ n = Node(name=host, cpus=int(numCPUs), rack=int(rack), rank=int(rank), os=osname)
+ s.add(n)
+
+ self.newdb.commit()
+ next_server = self.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+
+ self.newdb.setNewHostAttr(host, 'os', osname)
+ self.newdb.setNewHostAttr(host, 'kickstartable', 'yes')
+ self.newdb.setNewHostAttr(host, 'dhcp_filename', 'pxelinux.0')
+ self.newdb.setNewHostAttr(host, 'dhcp_nextserver', next_server)
+ self.newdb.setNewHostAttr(host, 'managed', 'false')
+
+RollName = "base"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/add/host/interface/__init__.py b/lib/sunhpc/commands/add/host/interface/__init__.py
new file mode 100644
index 0000000..8150814
--- /dev/null
+++ b/lib/sunhpc/commands/add/host/interface/__init__.py
@@ -0,0 +1,109 @@
+#coding:utf-8
+
+import os
+import sys
+import stat
+import time
+import string
+import sunhpc
+class Command(sunhpc.commands.add.host.command):
+ """
+ Adds an interface to a host and sets the associated values
+
+ <arg type='string' name='host'>
+ Host name of machine
+ </arg>
+
+ <arg type='string' name='iface'>
+ The interface name on the host (e.g., 'eth0', 'eth1')
+ </arg>
+
+ <param type='string' name='iface'>
+ Can be used in place of the iface argument.
+ </param>
+
+ <param type='string' name='ip'>
+ The IP address to assign to the interface (e.g., '192.168.1.254')
+ </param>
+
+ <param type='string' name='subnet'>
+ The name of the subnet to assign to the interface (e.g., 'private')
+ </param>
+
+ <param type='string' name='name'>
+ The name to assign to the interface
+ </param>
+
+ <param type='string' name='mac'>
+ The MAC address of the interface (e.g., '00:11:22:33:44:55')
+ </param>
+
+ <example cmd='add host interface compute-0-0 eth1 ip=192.168.1.2 subnet=private name=fast-0-0'>
+ </example>
+
+ <example cmd='add host interface compute-0-0 iface=eth1 ip=192.168.1.2 subnet=private name=fast-0-0'>
+ same as above
+ </example>
+
+ <related>set host interface iface</related>
+ <related>set host interface ip</related>
+ <related>set host interface mac</related>
+ <related>set host interface name</related>
+ <related>set host interface subnet</related>
+ """
+
+ def run(self, params, args):
+
+ (args, iface) = self.fillPositionalArgs(('iface',))
+
+ hosts = self.getHostnames(args)
+
+ if not iface:
+ self.abort('missing iface')
+
+ if len(hosts) != 1:
+ self.abort('must supply one host')
+ host = hosts[0]
+
+ #
+ # determine if this is an interface name or a MAC address
+ #
+
+ isMac = 0
+ m = iface.split(':')
+ if len(m) >= 6:
+ isMac = 1
+
+ rows = self.db.search("""select * from networks,nodes where
+ nodes.name='%s' and
+ (networks.device='%s' or networks.mac='%s') and
+ networks.node=nodes.id""" % (host, iface, iface))
+ if rows:
+ self.abort('interface "%s" exists' % iface)
+
+ if isMac:
+ self.db.execute("""insert into networks(node,mac)
+ values ((select id from nodes where name='%s'),
+ '%s')""" % (host, iface))
+ else:
+ self.db.execute("""insert into networks(node,device)
+ values ((select id from nodes where name='%s'),
+ '%s')""" % (host, iface))
+
+ for key in ['ip', 'mac', 'name', 'subnet']:
+ if key in params:
+ self.command('set.host.interface.%s' % key,
+ (host, iface, params[key]))
+RollName = "base"
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/add/host/security/__init__.py b/lib/sunhpc/commands/add/host/security/__init__.py
new file mode 100644
index 0000000..79d2a5e
--- /dev/null
+++ b/lib/sunhpc/commands/add/host/security/__init__.py
@@ -0,0 +1,99 @@
+#coding:utf-8
+
+import os
+import sys
+import socket
+import sunhpc
+class Command(sunhpc.commands.add.host.command):
+ """
+ Add a host's security keys to the database.
+
+ <arg type='string' name='host'>
+ Host name of machine
+ </arg>
+
+ <param type='Bool' name='force'>
+ Force overwriting the host's security attributes in the database; default: false
+ </param>
+
+ <example cmd='add host security compute-0-0'>
+ Adds one host's security keys to the database.
+ </example>
+
+ <example cmd='add host security compute-0-0 force=1'>
+ Force overwriting the host's security keys in the database.
+ </example>
+ """
+ def nodeup(self, host):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(0.2)
+ try:
+ sock.connect((host, 22))
+ buf = sock.recv(64)
+ sock.send(b'SSH-2.0-nodeup\r\n')
+ buf = sock.recv(1024)
+ sock.close()
+ except:
+ return 0
+ return 1
+
+ def run(self, params, args):
+ (force, enc) = self.fillParams([('force', 'no'), ('enc', 'sha')])
+
+ if not args:
+ self.msg('must supply a host', 'a')
+
+ force = self.str2bool(force)
+ frontend = self.db.getFrontendName()
+ hostname = self.getHostnames(args)[0]
+
+ if hostname == frontend:
+ self.msg("cannot add the frontend's keys to the database", 'a')
+
+ content = []
+ output = self.shcmd('/usr/bin/ssh-keyscan %s' % hostname)['o']
+ nodeid = self.newdb.getNodeId(hostname)
+ line = output.split('\n')
+ for i in line:
+ r = i.split()
+ if not r or len(r) != 3:
+ continue
+
+ attr, value = r[1], ' '.join(r[1:])
+ content.append((attr, enc, value, nodeid))
+
+ if not content:
+ self.msg('Unable to retrieve any keys from %s' % hostname, 'a')
+
+ for a, e, v, n in content:
+ rows = self.db.search('select * from secnodes where attr="%s" and node="%s"' % (a, n))
+ if rows and not force:
+ self.msg('Attribute %s already exists.' % a, 'a')
+
+ if force and rows:
+ cmd = '''update secnodes set attr="%s",
+ enc="%s", value="%s", node="%s" where
+ attr="%s" and node="%s" ''' % (a, e, v, n, a, n)
+ else:
+ cmd = 'insert into secnodes values("%s", "%s", "%s", "%s") ' % (a, e, v, n)
+
+ self.db.execute(cmd)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/add/security/__init__.py b/lib/sunhpc/commands/add/security/__init__.py
new file mode 100644
index 0000000..c016763
--- /dev/null
+++ b/lib/sunhpc/commands/add/security/__init__.py
@@ -0,0 +1,68 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+class Command(sunhpc.commands.add.command):
+ """
+ Add all security attributes to the database.
+
+ <param type='Bool' name='force'>
+ Force overwriting the security attributes in the database; default: false
+ </param>
+
+ <example cmd='add security'>
+ Adds all security attributes to the database.
+ </example>
+
+ <example cmd='add security force=1'>
+ Force overwriting the security attributes in the database.
+ </example>
+ """
+
+ def run(self, params, args):
+
+ (force, enc) = self.fillParams([('force', 'no'), ('enc', 'sha')])
+
+ attr_list = []
+ force = self.str2bool(force)
+
+ # add the *.pub keys under /etc/safe-security to attr_list
+ sshd_dirs = '/etc/safe-security'
+ for i in os.listdir(sshd_dirs):
+ if i.split('.')[-1] != 'pub':
+ continue
+
+ with open(os.path.join(sshd_dirs, i), 'r') as fe:
+ value = fe.read()
+
+ attr_list.append((i, value))
+
+ for n, v in attr_list:
+ rows = self.db.search('select * from secglobals where attr="%s"' % n)
+ if rows and not force:
+ self.msg('Attribute %s already exists.' % n, 'a')
+
+ if force and rows:
+ cmd = 'update secglobals set attr="%s", value="%s", enc="%s" where attr="%s" ' % (n, v, enc, n)
+ else:
+ cmd = 'insert into secglobals values("%s", "%s", "%s") ' % (n, v, enc)
+
+ self.db.execute(cmd)
+
+RollName = 'base'
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/build/__init__.py b/lib/sunhpc/commands/build/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/build/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/build/initializes/__init__.py b/lib/sunhpc/commands/build/initializes/__init__.py
new file mode 100644
index 0000000..bd0a17a
--- /dev/null
+++ b/lib/sunhpc/commands/build/initializes/__init__.py
@@ -0,0 +1,626 @@
+#
+#coding:utf-8
+#
+#Author : QCSun
+#Email : qcsun@sunhpc.com
+#Times : 2023-04-14 05:21:02
+#WebSite : https://www.sunhpc.com
+
+import os
+import sys
+import time
+import dbus
+import socket
+import sunhpc
+import random
+import readline
+import argparse
+from sunhpc.db.mappings.base import *
+
+class command(sunhpc.commands.build.command):
+ pass
+
+class Command(command):
+ """
+ Initialize the sunhpc database.
+ """
+ def run(self, params, args):
+
+ '''
+ pars = argparse.ArgumentParser(description="Initialzed the sunhpc database")
+ pars.add_argument('-c', dest='config', metavar='config', default='/opt/sunhpc', help='use config file')
+ argv = pars.parse_args(args=args)
+ '''
+
+ ap = self.ap.handler()
+ op = self.ap.add_group('Configure')
+ op.add_argument(
+ '-c', '--config', metavar='config', action='store',
+ default='/opt/sunhpc/etc/sunhpc.conf',
+ help='Use config file'
+ )
+ op.add_argument(
+ '-f', '--force', action='store_true',
+ default=False,
+ help='Force reinitialization of the sunhpc database file.'
+ )
+
+ argv = ap.parse_args(args)
+
+ #self.setdebug(True)
+
+ lock = os.path.join(self.newdb._dbPath, self.newdb._dbLock)
+ self.init_tables(lock)
+
+ # get cluster config file
+ results = self.init_config(argv)
+
+ # Return sql initialze data query list.
+ query = self.init_database(results)
+
+ # Add data to database.
+ for i in query:
+ if self.debug:
+ print (i)
+ self.db.execute(i)
+
+ #[self.db.execute(d) for d in query]
+
+ #print (session)
+
+ def init_tables(self, lock):
+ """create the database file"""
+ print ('Initializing the sunhpc database ...')
+ if os.path.exists(lock):
+ with open(lock, 'r') as f:
+ data = f.readline()
+
+ self.abort('Error - the sunhpc database was already initialized on %s.' % data.strip())
+ else:
+ # Use sqlalchemy initialze the database.
+ #Base.metadata.create_all(self.newdb.engine)
+
+ # Use sqlite initialze the database
+ for i in self.baseTables():
+ if self.debug:
+ print (i)
+ self.db.execute(i)
+
+ current_time = int(time.time())
+ localtime = time.localtime(current_time)
+ today = time.strftime('%Y-%m-%d %H:%M:%S', localtime)
+ with open(lock, 'w') as f:
+ f.write('%s\n' % today)
+
+ def askResponse(self, key, conf, ask, default=None):
+ # key: attribute name, e.g. 'publichostname'; conf: dict parsed from the config file
+ # ask: prompt text shown to the user; default: default value (a list means "choose one").
+
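+ # e.g., askResponse('timezone', confs, 'Time zone', 'Asia/Shanghai') returns the value
+ # from the config file when present, otherwise prompts until something is entered.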
+ if key in conf:
+ return conf[key]
+
+ tmp = None
+ if isinstance(default, type([])) or isinstance(default, type(())):
+ while tmp not in default:
+ tmp = input('Please input \033[1;93m%s\033[0m default[ "\033[1;31m%s\033[0m" ]: ' % (ask, ','.join(default)))
+ else:
+ while not tmp:
+ tmp = input('Please input \033[1;93m%s\033[0m default[ "\033[1;31m%s\033[0m" ]: ' % (ask, default)) or default
+
+ return tmp
+
+ def init_config(self, args):
+
+ confs = {}
+ results = {}
+
+ # Read the settings from the config file.
+ if os.path.exists(args.config):
+ with open(args.config, 'r') as f:
+ for i in f.readlines():
+ r = i.split()
+ if i.startswith('#') or len(r) < 2:
+ continue
+
+ key, val = r[0], ' '.join(r[1:])
+ confs[key.lower()] = val
+
+ results['country'] = self.askResponse('country', confs, 'Country code', 'CN')
+ results['state'] = self.askResponse('state', confs, 'State name', 'LiaoNing')
+ results['city'] = self.askResponse('city', confs, 'Locality name', 'DaLian')
+ results['url'] = 'https://www.sunhpc.com'
+ results['name'] = 'Sunhpc-cluster'
+ results['contact'] = 'info@sunhpc.com'
+ results['worksdir'] = '/export'
+ results['distrodir'] = '/export/sunhpc'
+ results['partition'] = 'default'
+
+ # safe security
+ results['safeport'] = '372'
+ results['safedirs'] = 'safe.d'
+ results['safesecurity'] = 'safe-security'
+
+ try:
+ # Hostname = socket.getfqdn(socket.gethostname())
+ devices = self.getNetcard()
+ interface = list(devices.keys())
+
+ results['publichostname'] = self.askResponse('publichostname', confs, 'PublicHostname', 'cluster.sunhpc.com')
+ results['publicinterface'] = self.askResponse('publicinterface', confs, 'PublicInterface', interface)
+
+ wan_dev = results['publicinterface']
+ results['publicaddress'] = self.askResponse('publicaddress', confs, 'PublicAddress %s' % wan_dev, devices[wan_dev])
+
+ try:
+ gwaddr = devices[wan_dev].split('.')[:-1]
+ gwaddr.append('1')
+ gwaddr = "%s" % '.'.join(gwaddr)
+ except (KeyError, IndexError):
+ gwaddr = devices[wan_dev]
+
+ results['publicgateway'] = self.askResponse(
+ 'publicgateway', confs, 'PublicGateway (%s:%s)' % (wan_dev, gwaddr), gwaddr)
+
+ results['publicnetmask'] = self.askResponse(
+ 'publicnetmask', confs, 'PublicNetmask (%s:%s)' % (wan_dev, devices[wan_dev]), '255.255.255.0')
+
+ results['publicdnsserver'] = self.askResponse('publicdnsserver', confs, 'PublicDNSServer', '223.5.5.5')
+
+ wan_mac = os.popen("nmcli device show %s|grep -i general.hwaddr|awk '{print $2}'" % wan_dev).readline().strip()
+ results['publicmacaddress'] = self.askResponse('publicmacaddr', confs, 'PublicMACAddress', wan_mac)
+
+ wan_cidr = self.net.netmask_to_cidr(results['publicnetmask'])
+ results['publiccidr'] = self.askResponse('publiccidr', confs, 'PublicCIDR', wan_cidr)
+
+ wan_network = self.net.getNetwork(results['publicaddress'], results['publicnetmask'])
+ results['publicnetwork'] = self.askResponse('publicnetwork', confs, 'PublicNetwork', wan_network)
+
+ results['publicdomain'] = 'sunhpc.com'
+ results['publicmtu'] = '1500'
+ results['publicntphost'] = 'pool.ntp.org'
+
+
+ # private network configure
+ results['privatehostname'] = self.askResponse('privatehostname', confs, 'PrivateHostname', 'cluster')
+
+ interface.remove(results['publicinterface'])
+ if len(interface) == 1:
+ interface = interface[0]
+ results['privateinterface'] = self.askResponse('privateinterface', confs, 'PrivateInterface', interface)
+
+ lan_dev = results['privateinterface']
+ results['privateaddress'] = self.askResponse(
+ 'privateaddress', confs, 'PrivateAddress %s' % lan_dev, devices[lan_dev])
+
+ results['privatenetmask'] = self.askResponse(
+ 'privatenetmask', confs, 'PrivateNetmask (%s:%s)' % (lan_dev, devices[lan_dev]), '255.255.255.0')
+
+ lan_mac = os.popen("nmcli device show %s|grep -i general.hwaddr|awk '{print $2}'" % lan_dev).readline().strip()
+ results['privatemacaddress'] = self.askResponse(
+ 'privateMacAddress', confs, 'PrivateMacAddress (%s:%s)' % (lan_dev, devices[lan_dev]), lan_mac)
+
+ lan_cidr = self.net.netmask_to_cidr(results['privatenetmask'])
+ results['privatecidr'] = self.askResponse(
+ 'privatecidr', confs, 'PrivateCIDR (%s:%s)' % (lan_dev, results['privatenetmask']), lan_cidr)
+
+ lan_network = self.net.getNetwork(results['privateaddress'], results['privatenetmask'])
+ results['privatenetwork'] = self.askResponse('privatenetwork', confs,
+ 'PrivateNetwork (%s:%s/%s)' % (lan_dev, results['privateaddress'], results['privatenetmask']), lan_network)
+
+ results['privatedomain'] = self.askResponse('privatednsdomain', confs, 'PrivateDomain', 'local')
+ results['privatemtu'] = self.askResponse('privatemtu', confs, 'PrivateMTU', '1500')
+
+ domain = '%s.%s' % (results['privatehostname'], results['privatedomain'])
+ results['privatentphost'] = self.askResponse('privatentphost', confs, 'PrivateNTPHost', domain)
+
+ except KeyError as e:
+ self.msg("The %s config file error. Field -> *" % conf, 'a', e.args)
+
+ results['timezone'] = self.askResponse('timezone', confs, 'Time zone', 'Asia/Shanghai')
+
+ results['bootargs'] = 'net.ifnames=0 biosdevname=0 ksdevice=bootif'
+ results['basedir'] = 'install'
+ results['ganglia'] = '224.0.0.3'
+ results['nextserver'] = results['privateaddress']
+ results['plugin_port'] = '12345'
+ results['pxefilename'] = 'pxelinux.0'
+ results['pxelinuxdir'] = '/tftpboot/pxelinux'
+ results['distribution'] = 'sunhpc-dist'
+
+ self.dictOutput(results)
+
+ tmp = ''
+ commands = ['s', 'save', 'q', 'quit', 'e', 'exit']
+ while tmp not in commands:
+ print ('--------------------------------------------------')
+ print ('\ts/save : Save vars to %s and the database, then quit' % args.config)
+ print ('\tq/quit : Quit without saving')
+ print ('\tModify a value with key=value, e.g., plugin_port=2222')
+ tmp = input('Please input command: ').lower()
+ value = tmp.split('=', 1)
+ if len(value) == 2 and value[0] in results:
+ results[value[0]] = value[1]
+ self.dictOutput(results)
+
+ if tmp in ['s', 'save']:
+ with open(args.config, 'w') as f:
+ f.write('#\n# sunhpc cluster config file.\n#\n')
+ for k in results:
+ value = "%s %s\n" % (k.lower(), results[k])
+ f.write(value)
+ self.msg('Settings have been written to the configuration file %s' % args.config)
+
+ return results
+
+ def init_database(self, results):
+ queuelist = []
+
+ category = [
+ "insert into categories values (1, 'global', 'Global Default')",
+ "insert into categories values (2, 'os', 'OS System name')",
+ "insert into categories values (3, 'appliance', 'Logical appliances')",
+ "insert into categories values (4, 'rack', 'Machine room racks')",
+ "insert into categories values (5, 'host', 'Physical and virtual')"
+ ]
+ queuelist.extend(category)
+
+ catindex = [
+ "insert into catindex values (1, 'global', 1)",
+ "insert into catindex values (2, 'linux', 2)",
+ "insert into catindex values (3, 'other', 2)",
+ "insert into catindex values (4, 'frontend', 3)",
+ "insert into catindex values (5, 'compute', 3)",
+ "insert into catindex values (6, 'nas', 3)",
+ "insert into catindex values (7, 'network', 3)",
+ "insert into catindex values (8, 'power', 3)",
+ "insert into catindex values (9, 'devel-server', 3)",
+ "insert into catindex values (10, 'login', 3)"
+ ]
+ queuelist.extend(catindex)
+
+ bootver = '%s-%s' % (sunhpc.version, self.arch)
+ vmlinuz = 'vmlinuz-%s' % bootver
+ initrds = 'initrd-%s' % bootver
+ defArgs = "net.ifnames=0 biosdevname=0"
+ insArgs = "%s inst.ks.sendmac ksdevice=bootif" % (defArgs)
+ resArgs = "%s rescue" % defArgs
+ lesArgs = "%s vnc" % defArgs
+ bootaction = [
+ "insert into bootactions values (1, 'install', '%s', '%s', '%s')" % (vmlinuz, initrds, insArgs),
+ "insert into bootactions values (2, 'os', 'localboot 0', NULL, NULL)",
+ "insert into bootactions values (3, 'memtest', 'kernel memtest', NULL, NULL)",
+ "insert into bootactions values (4, 'install headless', '%s', '%s', '%s')" % (vmlinuz, initrds, lesArgs),
+ "insert into bootactions values (5, 'rescue', '%s', '%s', '%s')" % (vmlinuz, initrds, resArgs),
+ "insert into bootactions values (6, 'pxeflash', 'kernel memdisk bigraw', 'pxeflash.img', 'keeppxe')"]
+
+ queuelist.extend(bootaction)
+
+ cmds = "dmidecode -t memory|sed s/[[:space:]]//g|grep '^Size'|grep -E 'MB|GB'"
+ Names = os.popen("cat /sys/class/dmi/id/product_name").readline().strip()
+ Vender = os.popen("cat /sys/class/dmi/id/sys_vendor").readline().strip()
+ Serial = os.popen("cat /sys/class/dmi/id/product_serial").readline().strip()
+ CPUs = os.popen("dmidecode -t processor|grep Version|cut -d ':' -f2|wc -l").readline().strip()
+ Core = os.popen("cat /proc/cpuinfo |grep processor|wc -l").readline().strip()
+ Model = os.popen("dmidecode -t processor|grep Version|head -n1|cut -d ':' -f2").readline().strip()
+ MemNumb = os.popen("%s|wc -l" % cmds).readline().strip()
+ MemSize = os.popen("%s|head -n1|cut -d ':' -f2" % cmds).readline().strip()
+
+ part = sunhpc.core.partition.Partition()
+ part.discoveredDisks()
+ disks = part.getDisks()
+ nodeDisks = part.getNodePartInfo(disks)
+ disklist = sorted(nodeDisks)
+ for disk in disklist:
+ for d in nodeDisks[disk]:
+ dev, sec, size, partid, fstype, bootflags, partflags, mountpoint = d
+ if dev and sec and size:
+ queuelist.append('''
+ insert into partitions(
+ node, device, mountpoint, sectorstart, partitionsize, fstype, partitionflags, formatflags) values
+ (1, "%s", "%s", "%s" ,"%s", "%s", "%s", "%s")''' % (
+ dev, mountpoint, sec, size, fstype, bootflags, partflags))
+
+ for attr in self.attributions(results):
+ k, v, s, n = attr
+ queuelist.append('insert into attributes values (NULL, "%s","%s", NULL, "%s", %d)' % (k, v, s, n))
+
+ # add frontend node
+ lhostname = results['privatehostname']
+ nodes = [ "insert into nodes values (1, '%s', %d, 0, 0, '%s', '%s', '', 'install')" \
+ % (lhostname, int(Core), self.arch, self.os)]
+ queuelist.extend(nodes)
+
+ # add frontend network
+ network = [ "insert into networks values (1, 1, '%s', '%s', '%s', '%s', '2')" %
+ (results['publicmacaddress'], results['publicaddress'], lhostname, results['publicinterface']),
+ "insert into networks values (2, 1, '%s', '%s', '%s', '%s', '1')" %
+ (results['privatemacaddress'], results['privateaddress'], lhostname, results['privateinterface'])]
+ queuelist.extend(network)
+
+ # add subnet data
+ subnet = [ "insert into subnets values (1, 'private', '%s', '%s', '%s', '%s', '1')" %
+ (results['privatedomain'], results['privatenetwork'],
+ results['privatenetmask'], results['privatemtu']),
+ "insert into subnets values (2, 'public', '%s', '%s', '%s', '%s', '0')" %
+ (results['publicdomain'], results['publicnetwork'],
+ results['publicnetmask'], results['publicmtu'])]
+ queuelist.extend(subnet)
+ return queuelist
+
+ def attributions(self, results):
+
+ attrs = [
+ ('Info_CertificateCountry', results['country'], 1, 1),
+ ('Info_CertificateState', results['state'], 1, 1),
+ ('Info_CertificateLocality', results['city'], 1, 1),
+ ('Info_CertificateOrganization','DLHP', 1, 1),
+ ('Info_ClusterUrl', results['url'], 1, 1),
+ ('Info_ClusterName', results['name'], 1, 1),
+ ('Info_ClusterContact', results['contact'], 1, 1),
+ ('Kickstart_WorksDir', results['worksdir'], 1, 1),
+ ('Kickstart_DistroDir', results['distrodir'], 1, 1),
+ ('Kickstart_Partition', results['partition'], 1, 1),
+
+ ('Kickstart_PublicHostname', results['publichostname'], 1, 1),
+ ('Kickstart_PublicInterface', results['publicinterface'], 1, 1),
+ ('Kickstart_PublicAddress', results['publicaddress'], 1, 1),
+ ('Kickstart_PublicMacAddr', results['publicmacaddress'], 1, 1),
+ ('Kickstart_PublicNetmask', results['publicnetmask'], 1, 1),
+ ('Kickstart_PublicNetmaskCIDR', results['publiccidr'], 1, 1),
+ ('Kickstart_PublicGateway', results['publicgateway'], 1, 1),
+ ('Kickstart_PublicNetwork', results['publicnetwork'], 1, 1),
+ ('Kickstart_PublicNTPHost', results['publicntphost'], 1, 1),
+ ('Kickstart_PublicDNSServer', results['publicdnsserver'], 1, 1),
+ ('Kickstart_PublicDomain', results['publicdomain'], 1, 1),
+ ('Kickstart_PublicMTU', results['publicmtu'], 1, 1),
+
+ ('Kickstart_PrivateHostname', results['privatehostname'], 1, 1),
+ ('Kickstart_PrivateInterface', results['privateinterface'], 1, 1),
+ ('Kickstart_PrivateAddress', results['privateaddress'], 1, 1),
+ ('Kickstart_PrivateMacAddr', results['privatemacaddress'], 1, 1),
+ ('Kickstart_PrivateNetmask', results['privatenetmask'], 1, 1),
+ ('Kickstart_PrivateNetmaskCIDR',results['privatecidr'], 1, 1),
+ ('Kickstart_PrivateGateway', results['privateaddress'], 1, 1),
+ ('Kickstart_PrivateNetwork', results['privatenetwork'], 1, 1),
+ ('Kickstart_PrivateNTPHost', results['privatentphost'], 1, 1),
+ ('Kickstart_PrivateDNSServer', results['privateaddress'], 1, 1),
+ ('Kickstart_PrivateDomain', results['privatedomain'], 1, 1),
+ ('Kickstart_PrivateMTU', results['privatemtu'], 1, 1),
+
+ ('Kickstart_Plugin_Port', results['plugin_port'], 1, 1),
+ ('Kickstart_Plugin_Keys', self.daemon_pass(), 1, 1),
+
+ ('Kickstart_Timezone', results['timezone'], 1, 1),
+ ('Kickstart_Bootargs', results['bootargs'], 1, 1),
+ ('distribution', results['distribution'], 1, 1),
+ ('Kickstart_BaseDir', results['basedir'], 1, 1),
+
+ ('safeport', results['safeport'], 1, 1),
+ ('safedirs', results['safedirs'], 1, 1),
+ ('safesecurity', results['safesecurity'], 1, 1),
+
+ ('sunhpc_version', sunhpc.version, 1, 1),
+ ('sunhpc_major', self.major, 1, 1),
+ ('sunhpc_minor', self.minor, 1, 1),
+ ('sunhpc_micro', self.micro, 1, 1),
+ ('sunhpc_release', sunhpc.release, 1, 1),
+ ('ganglia_address', results['ganglia'], 1, 1),
+
+ ('dhcp_filename', results['pxefilename'], 1, 1),
+ ('dhcp_nextserver', results['nextserver'], 1, 1),
+ ('pxelinuxdir', results['pxelinuxdir'], 1, 1),
+
+ ('kickstartable', 'yes', 3, 4),
+ ('kickstartable', 'yes', 3, 5),
+ ('kickstartable', 'yes', 3, 6),
+ ('kickstartable', 'no', 3, 7),
+ ('kickstartable', 'no', 3, 8),
+ ('kickstartable', 'yes', 3, 9),
+ ('kickstartable', 'yes', 3, 10),
+
+ ('managed', 'true', 1, 1),
+ ('os', 'linux', 1, 1)
+ ]
+ return attrs
+
+ def baseTables(self):
+ """Base data tables"""
+
+ drop = 'DROP TABLE IF EXISTS tablename'
+ tables = ['Nodes', 'Networks', 'Subnets', 'PublicKeys', 'SecNodes',
+ 'Attributes', 'Partitions', 'Categories', 'Catindex',
+ 'Rolls', 'Bootactions', 'distributions',
+ 'SecGlobals'
+ ]
+ datalist = []
+
+ Categories = '''
+ CREATE TABLE Categories (
+ ID integer primary key autoincrement,
+ Name varchar(64) NOT NULL unique default '0',
+ Description varchar(512) default null
+ )'''
+ datalist.append(Categories)
+
+ Catindex = '''
+ CREATE TABLE Catindex (
+ ID integer primary key autoincrement,
+ Name varchar(64) not null default '0',
+ Category integer not null,
+ Foreign key(Category) references categories(ID) on delete cascade on update restrict
+ )'''
+ datalist.append(Catindex)
+
+ Nodes = '''
+ CREATE TABLE Nodes (
+ ID integer primary key autoincrement,
+ Name varchar(128) default NULL,
+ CPUs integer(11) NOT NULL default '1',
+ Rack integer(11) default NULL,
+ Rank integer(11) default NULL,
+ Arch varchar(32) default NULL,
+ OS varchar(64) NOT NULL default 'linux',
+ Alias varchar(64) default '',
+ Status varchar(32) default 'os'
+ )'''
+ datalist.append(Nodes)
+
+ Networks = '''
+ CREATE TABLE Networks (
+ ID integer NOT NULL primary key autoincrement,
+ Node integer(11) default NULL,
+ MAC varchar(64) default NULL,
+ IP varchar(32) default NULL,
+ Name varchar(128) default NULL,
+ Device varchar(32) default NULL,
+ Subnet integer(11) default NULL,
+ Foreign key(subnet) references subnets(id) on delete cascade on update restrict,
+ Foreign key(node) references nodes(id) on delete cascade on update restrict
+ )'''
+ datalist.append(Networks)
+
+ Subnets = '''
+ CREATE TABLE Subnets (
+ ID integer NOT NULL primary key autoincrement,
+ name varchar(32) UNIQUE NOT NULL,
+ dnszone varchar(64) UNIQUE NOT NULL ,
+ subnet varchar(32) NOT NULL,
+ netmask varchar(32) NOT NULL,
+ mtu integer(11) default '1500',
+ servedns boolean default false
+ )'''
+ datalist.append(Subnets)
+
+ PublicKeys = '''
+ CREATE TABLE PublicKeys (
+ ID integer NOT NULL primary key autoincrement,
+ Public_Key varchar(4096) default NULL,
+ Description varchar(4096) default NULL,
+ Node integer(11) NOT NULL default '0',
+ Foreign key(Node) references nodes(id) on delete cascade on update restrict
+ )'''
+ datalist.append(PublicKeys)
+
+ SecNodes = '''
+ CREATE TABLE SecNodes (
+ Attr varchar(128) default NULL,
+ Enc varchar(64) default NULL,
+ Value text,
+ Node integer(11) NOT NULL default '0',
+ PRIMARY KEY (Node, Attr)
+ )'''
+ datalist.append(SecNodes)
+
+ Attributes = '''
+ CREATE TABLE Attributes (
+ ID integer NOT NULL primary key autoincrement,
+ Attr varchar(128) NOT NULL,
+ Value text,
+ Shadow text,
+ Category int(11) NOT NULL,
+ Catindex int(11) NOT NULL,
+ Foreign key(Catindex) references catindex(id) on delete cascade on update restrict
+ )'''
+ datalist.append(Attributes)
+
+ Partitions = '''
+ CREATE TABLE Partitions (
+ ID integer NOT NULL primary key autoincrement,
+ Node integer(11) NOT NULL default '0',
+ Device varchar(128) NOT NULL default '',
+ Mountpoint varchar(128) NOT NULL default '',
+ SectorStart varchar(128) NOT NULL default '',
+ PartitionSize varchar(128) NOT NULL default '',
+ PartitionID varchar(128) NOT NULL default '',
+ FsType varchar(128) NOT NULL default '',
+ PartitionFlags varchar(128) NOT NULL default '',
+ FormatFlags varchar(128) NOT NULL default ''
+ )'''
+ datalist.append(Partitions)
+
+ Firewalls = '''
+ CREATE TABLE Firewalls (
+ ID integer NOT NULL primary key autoincrement,
+ Rulename varchar(128) NOT NULL,
+ Service varchar(256),
+ Protocol varchar(256),
+ Ports varchar(256),
+ Action varchar(256),
+ Comment varchar(256),
+ Node integer(11) NOT NULL default '0'
+ )'''
+ datalist.append(Firewalls)
+
+ Rolls = '''
+ CREATE TABLE Rolls (
+ ID integer NOT NULL primary key autoincrement,
+ Name varchar(128) NOT NULL default '',
+ Version varchar(32) NOT NULL default '',
+ Arch varchar(32) NOT NULL default '',
+ OS varchar(64) NOT NULL default 'linux',
+ Enabled varchar(11) NOT NULL default 'yes'
+ )'''
+ datalist.append(Rolls)
+
+ Bootactions = '''
+ CREATE TABLE Bootactions (
+ ID integer NOT NULL primary key autoincrement,
+ Action varchar(1024) default NULL,
+ Kernel varchar(1024) default NULL,
+ Ramdisk varchar(1024) default NULL,
+ Args varchar(1024) default NULL
+ )'''
+ datalist.append(Bootactions)
+
+ Distributions = '''
+ CREATE TABLE Distributions (
+ ID integer NOT NULL primary key autoincrement,
+ Name varchar(32) NOT NULL default '',
+ OS varchar(32) default '',
+ Release varchar(32) default ''
+ )'''
+ datalist.append(Distributions)
+
+ SecGlobals = '''
+ CREATE TABLE SecGlobals (
+ Attr varchar(128) default NULL,
+ Value text,
+ Enc varchar(128) default NULL,
+ PRIMARY KEY (Attr)
+ )'''
+ datalist.append(SecGlobals)
+
+ return datalist
+
+ def daemon_pass(self):
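+ # Generate a random 16-character key (letters, digits and punctuation,
+ # sampled without repetition) used as the plugin daemon secret.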
+ char = 'abcdefghijklmnopqrstuvwxyz'
+ char += '!@#$%+=-_^&*.?'
+ char += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890'
+ return ''.join(random.sample(char,16))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/check/__init__.py b/lib/sunhpc/commands/check/__init__.py
new file mode 100644
index 0000000..8add27f
--- /dev/null
+++ b/lib/sunhpc/commands/check/__init__.py
@@ -0,0 +1,6 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+RollName = "base"
diff --git a/lib/sunhpc/commands/check/services/__init__.py b/lib/sunhpc/commands/check/services/__init__.py
new file mode 100644
index 0000000..b4c42d9
--- /dev/null
+++ b/lib/sunhpc/commands/check/services/__init__.py
@@ -0,0 +1,52 @@
+#coding:utf-8
+import shlex
+import sunhpc
+import socket
+import subprocess
+class Command(sunhpc.commands.Command):
+ """
+ Checks that all required Sunhpc services are up and running. If a
+ required service is not running, it will be reported.
+
+ <example cmd='check services'>
+ Check that all required services are up and running.
+ </example>
+ """
+ def checkService(self, cmdname, error_msg):
+ """check if the given cmdname returns 0 once executed.
+ If cmdname fails it returns runninvg"""
+ devnull = open('/dev/null', 'w')
+ process = subprocess.Popen(shlex.split(cmdname), stdout=devnull, stderr=devnull)
+ retcode = process.wait()
+ devnull.close()
+ if retcode != 0:
+ self.abort(error_msg)
+
+ def run(self, params, args):
+
+ if len(args) != 0:
+ self.msg("check services does not accept any argument", 'a')
+
+ if socket.gethostname().split('.')[0] != self.newdb.getFrontendName():
+ self.msg('this command should run only on the frontend', 'a')
+
+ # dhcpd
+ cmd = "service dhcpd status"
+ error_msg = "dhcpd is not running.\n" + \
+ "Restart it with 'systemctl start dhcpd'"
+ self.checkService(cmd, error_msg)
+
+ # xinetd
+ cmd = "curl 'tftp://localhost/pxelinux.0' -o /dev/null"
+ error_msg = "unable to download pxelinux with tftp.\n" + \
+ "Verify that xinetd is running with 'systemctl start tftp'"
+ self.checkService(cmd, error_msg)
+
+ # httpd wget kickstart
+ cmd = "bash -c \"curl -k http://localhost/install/sbin/kickstart.cgi\""
+ error_msg = "unable to download kickstart.\n" + \
+ "Verify httpd is running with 'systemctl start httpd'"
+ self.checkService(cmd, error_msg)
+ return True
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/__init__.py b/lib/sunhpc/commands/create/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/create/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/distro/__init__.py b/lib/sunhpc/commands/create/distro/__init__.py
new file mode 100644
index 0000000..0c3569e
--- /dev/null
+++ b/lib/sunhpc/commands/create/distro/__init__.py
@@ -0,0 +1,162 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import datetime
+import tempfile
+import subprocess
+import sunhpc.commands
+class Command(sunhpc.commands.create.command):
+ """
+ Create a Sunhpc distribution. Use this distribution to install Sunhpc nodes.
+
+ <param type='str' name='arch'>
+ The architecture of the distribution. Default: the local machine architecture.
+ </param>
+
+ <param type='str' name='version'>
+ The version of the distribution. The default is the current sunhpc version.
+ </param>
+
+ <param type='str' name='rolls'>
+ Specify one or more roll names. The default is: all rolls.
+ </param>
+
+ <param type='str' name='root'>
+ The path prefix location of the rolls. The default is: /export/sunhpc/install
+ </param>
+
+ <param type='str' name='dist'>
+ The directory name of the distribution. The default is: "sunhpc-dist"
+ </param>
+
+ <param type='bool' name='md5'>
+ Calculate MD5SUM of all packages. The default is: 'yes'
+ </param>
+
+ <example cmd='create distro'>
+ Create a distribution in the current directory.
+ </example>
+ """
+ def getRolls(self):
+ rolls = []
+ self.db.execute('select name, version, arch, enabled from rolls where OS="linux" ')
+ for n, v, a, e in self.db.fetchall():
+ if e == 'yes':
+ rolls.append([n, v, a, e])
+
+ return rolls
+
+ def addRolls(self, rolls):
+ for roll_name, roll_vers, roll_arch, roll_enb in rolls:
+ rows = self.db.search("""select * from rolls
+ where name="%s"
+ and version="%s"
+ and arch="%s"
+ and os="%s" """ % (
+ roll_name, roll_vers, roll_arch, self.os))
+ if not rows:
+ addRoll = """insert into rolls
+ (name, version, arch, enabled, os)
+ values("%s", "%s", "%s", "%s", "%s")
+ """ % (roll_name, roll_vers,
+ roll_arch, roll_enb, self.os)
+ self.db.execute(addRoll)
+
+ def commandDist(self, dist, rolls):
+ builder = sunhpc.core.build.DistributionBuilder(dist)
+ builder.setRolls(rolls)
+ builder.setQuiet(self.quiet)
+ builder.setSiteProfiles(1)
+ builder.setCalcMD5(self.md5)
+ builder.setCommand(self)
+ builder.build()
+ return builder
+
+ def run(self, param, args):
+
+ '''
+ lockfile = '/var/lock/sunhpc-dist'
+ if os.path.exists(lockfile):
+ self.msg("%s exists already.Waiting or remove the lockfile." % lockfile, 'a')
+ os.system('touch %s' % lockfile)
+ '''
+
+ (arch, version, withrolls, root,
+ calcmd5, quiet, dist) = self.fillParams([
+ ('arch', self.arch),
+ ('version', sunhpc.version),
+ ('rolls', None),
+ ('root', '/export/sunhpc/install'),
+ ('md5', 'yes'),
+ ('quiet', 'no'),
+ ('dist', 'sunhpc-dist')
+ ])
+
+ rolls = []
+ if withrolls == None:
+ rolls = self.getRolls()
+ else:
+ for i in withrolls.split():
+ rolls.append(i.split(',') + [ 'yes' ])
+
+ self.md5 = self.str2bool(calcmd5)
+ self.quiet = self.str2bool(quiet)
+
+ mirror = sunhpc.core.dist.Mirror()
+ mirror.setHost('rolls')
+ mirror.setPath(root)
+ mirror.setRoot(root)
+ mirror.setArch(arch)
+
+ mirrors = []
+ mirrors.append(mirror)
+
+ distro = sunhpc.core.dist.Distribution(mirrors, version)
+ distro.setRoot(os.getcwd())
+
+ old_umask = os.umask(0o022)
+ try:
+ #
+ # build the new distro in a temporary directory.
+ #
+ tempdist = tempfile.mkdtemp(dir="")
+ distro.setDist(tempdist)
+
+ distro.setLocal('/usr/src/redhat')
+ distro.setContrib(os.path.join(mirror.getRootPath(), 'contrib', version))
+ builder = self.commandDist(distro, rolls)
+ #
+ # make sure everyone can traverse the rolls directories.
+ #
+ mirrors = distro.getMirrors()
+ fullmirror = mirrors[0].getRollsPath()
+ # modify all dirs mode 755
+ os.system('find %s -type d ' % (fullmirror) + '-exec chmod -R 0755 {} \;')
+
+ if self.arch != arch and os.path.exists(dist):
+ shutil.move(os.path.join(tempdist, arch), os.path.join(dist, arch))
+ shutil.rmtree(tempdist)
+ else:
+ #
+ # now move the previous distro into a temporary directory
+ #
+ prevdist = tempfile.mkdtemp(dir="")
+ try:
+ shutil.move(dist, prevdist)
+ except:
+ pass
+
+ shutil.move(tempdist, dist)
+ os.system('chmod 755 %s' % dist)
+
+ try:
+ shutil.rmtree(prevdist)
+ except:
+ pass
+
+ #os.unlink(lockfile)
+ finally:
+ os.umask(old_umask)
+ self.addRolls(rolls)
diff --git a/lib/sunhpc/commands/create/pxelinux/__init__.py b/lib/sunhpc/commands/create/pxelinux/__init__.py
new file mode 100644
index 0000000..e0304be
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.create.command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/pxelinux/client/__init__.py b/lib/sunhpc/commands/create/pxelinux/client/__init__.py
new file mode 100644
index 0000000..28bc228
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/client/__init__.py
@@ -0,0 +1,85 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import sunhpc
+import subprocess
+class Command(sunhpc.commands.create.pxelinux.command):
+ """
+ Build the sunhpc client rpm files for the cluster.
+
+ <param type='Path' name='basedir'>
+ Supply the unpack base directory. Default: the current directory.
+ </param>
+
+ <param type='Path' name='outdir'>
+ Specify the rpm output directory. Default: `pwd`/pxeboot
+ </param>
+
+ <param type='String' name='version'>
+ Specify the generated rpm version. Default: the sunhpc version.
+ </param>
+
+ <example cmd='create pxelinux client'>
+ Build the sunhpc client rpm files.
+ </example>
+ """
+ def run(self, param, args):
+
+ cwd = os.getcwd()
+ (self.src_base, self.dst_base, version) = self.fillParams([
+ ('basedir', os.path.join(cwd, 'source')),
+ ('outdir', os.path.join(cwd, 'pxeboot')),
+ ('version', sunhpc.version)])
+
+ self.kickstart = os.path.join(self.prefix, 'share', 'isobuild', 'anaconda-sunhpc')
+ self.copyBoot()
+
+ os.chdir(cwd)
+
+ def copyBoot(self):
+ """copy kickstart file to current source directory."""
+
+ # create current/source/images dirs
+ ksdir = os.path.join(self.src_base, 'kickstart')
+ if os.path.exists(ksdir):
+ shutil.rmtree(ksdir)
+ os.makedirs(ksdir)
+
+ kslist = os.listdir(self.kickstart)
+ for k in kslist:
+ tmpdir = os.path.join(self.kickstart, k)
+ os.chdir(tmpdir)
+
+ if os.path.exists(self.src_base):
+ shutil.rmtree(self.src_base)
+
+ os.makedirs(self.src_base)
+ ret = self.shcmd('find . -print|cpio -mpud %s' % self.src_base)
+
+ fullname = '%s-%s-%s.noarch.rpm' % (k, sunhpc.version, sunhpc.version_micro)
+ dst_rpms = self.makeRpms(k, fullname)
+
+ if os.path.exists(dst_rpms):
+ self.msg('\tGenerate complete - %s.' % k)
+ else:
+ self.msg('\tFailed: %s.' % ret, 'a')
+
+ def makeRpms(self, name, fullname):
+ # Build the roll-*-kickstart.rpm file.
+
+ dst_fullname = os.path.join(self.dst_base, fullname)
+ if os.path.exists(dst_fullname):
+ os.remove(dst_fullname)
+
+ argsk = []
+ argsk.append('arch=noarch')
+ argsk.append('name=%s' % name)
+ argsk.append('source=%s' % self.src_base)
+ argsk.append('prefix=/opt/sunhpc')
+ argsk.append('outdir=%s' % self.dst_base)
+ self.command('create.rpm', argsk)
+ return dst_fullname
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/pxelinux/efiboot/__init__.py b/lib/sunhpc/commands/create/pxelinux/efiboot/__init__.py
new file mode 100644
index 0000000..e6e0260
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/efiboot/__init__.py
@@ -0,0 +1,113 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import sunhpc
+import subprocess
+class Command(sunhpc.commands.create.pxelinux.command):
+ """
+ Build the efiboot file for the cluster
+
+ <param type='Path' name='basedir'>
+ Supply the unpack base directory. Default: the current directory.
+ </param>
+
+ <param type='Path' name='outdir'>
+ Specify the sunhpc-anaconda-efiboot-*.rpm output directory. Default: `pwd`/pxeboot
+ </param>
+
+ <param type='Path' name='cdrom'>
+ Specify the cdrom directory. Default: /mnt/cdrom
+ </param>
+
+ <param type='string' name='version'>
+ The version embedded in the image name. Default: the sunhpc version.
+ </param>
+
+ <example cmd='create pxelinux efiboot cdrom=/mnt/cdrom'>
+ Build the efiboot image from the cdrom.
+ </example>
+ """
+ def run(self, param, args):
+
+ cwd = os.getcwd()
+ (self.src_base, self.dst_base, self.cdrom,
+ version) = self.fillParams([
+ ('basedir', os.path.join(cwd, 'source')),
+ ('outdir', os.path.join(cwd, 'pxeboot')),
+ ('cdrom', '/mnt/cdrom'),
+ ('version', sunhpc.version)])
+
+ self.checkCDROM()
+ self.copyBoot()
+ self.makeRpms()
+
+ os.chdir(cwd)
+
+ def copyBoot(self):
+ """copy /mnt/cdrom efiboot to current pxeboot directory."""
+
+ if os.path.exists(self.src_base):
+ shutil.rmtree(self.src_base)
+
+ # create current/source/images dirs
+ os.makedirs(os.path.join(self.src_base, 'images'))
+
+ install = os.path.join(self.src_base, 'images', 'efiboot.img')
+ efiboot = self.shcmd('find %s -type f -name efiboot.img -print' % self.cdrom)['o'].strip()
+
+ # copy /mnt/cdrom/LiveOS/efiboot.img to current/source/images/efiboot.img
+ shutil.copyfile(efiboot, install)
+ if os.path.exists(install):
+ self.msg('\tCopying complete efiboot.img')
+
+ def makeRpms(self):
+ # Build the sunhpc-anaconda-efiboot-*.rpm file.
+
+ rpm_fullname = 'sunhpc-anaconda-efiboot-%s-%s.noarch.rpm' % (sunhpc.version, sunhpc.version_micro)
+ dst_fullname = os.path.join(self.dst_base, rpm_fullname)
+ if os.path.exists(dst_fullname):
+ os.remove(dst_fullname)
+
+ argsk = []
+ argsk.append('arch=noarch')
+ argsk.append('name=sunhpc-anaconda-efiboot')
+ argsk.append('source=%s' % self.src_base)
+ argsk.append('prefix=/export/sunhpc/install/source')
+ argsk.append('outdir=%s' % self.dst_base)
+ self.command('create.rpm', argsk)
+
+ def checkCDROM(self):
+ """check cdrom"""
+
+ if not len(os.listdir(self.cdrom)):
+ self.msg('Must mount a valid linux cd to %s.' % self.cdrom, 'a')
+
+ efiboot = subprocess.run('find %s -type f -name efiboot.img -print >/dev/null' % self.cdrom, shell=True)
+ if efiboot.returncode:
+ self.msg('Could not find the efiboot file. %s' % efiboot, 'a')
+
+ # Read version information from .discinfo
+ # (timestamp, release e.g. 7.6, arch e.g. x86_64)
+ info = ['unknown', 'unknown', 'unknown']
+ with open(os.path.join(self.cdrom, '.discinfo'), 'r') as f:
+ line = f.readlines()
+ if len(line) >= 3:
+ info[0] = line[0].strip()
+ info[1] = line[1].strip()
+ info[2] = line[2].strip()
+
+ name = 'unknown'
+ with open(os.path.join(self.cdrom, '.treeinfo'), 'r') as f:
+ for l in f.readlines():
+ if l.find('family') == 0:
+ try:
+ name = l.split('=')[1].strip()
+ except IndexError:
+ pass
+
+ if name not in ['CentOS', 'Sunhpc']:
+ self.msg('The %s is not supported.' % name, 'a')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/pxelinux/kickstart/__init__.py b/lib/sunhpc/commands/create/pxelinux/kickstart/__init__.py
new file mode 100644
index 0000000..bac56a4
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/kickstart/__init__.py
@@ -0,0 +1,85 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import sunhpc
+import subprocess
+class Command(sunhpc.commands.create.pxelinux.command):
+ """
+ Build the kickstart rpm files for the cluster.
+
+ <param type='Path' name='basedir'>
+ Supply the unpack base directory. Default: the current directory.
+ </param>
+
+ <param type='Path' name='outdir'>
+ Specify the roll-*-kickstart.rpm output directory. Default: `pwd`/pxeboot
+ </param>
+
+ <param type='String' name='version'>
+ Specify the generated rpm version. Default: the sunhpc version.
+ </param>
+
+ <example cmd='create pxelinux kickstart'>
+ Build the kickstart rpm files.
+ </example>
+ """
+ def run(self, param, args):
+
+ cwd = os.getcwd()
+ (self.src_base, self.dst_base, version) = self.fillParams([
+ ('basedir', os.path.join(cwd, 'source')),
+ ('outdir', os.path.join(cwd, 'pxeboot')),
+ ('version', sunhpc.version)])
+
+ self.kickstart = os.path.join(self.prefix, 'share', 'isobuild', 'kickstart')
+ self.copyBoot()
+
+ os.chdir(cwd)
+
+ def copyBoot(self):
+ """copy kickstart file to current source directory."""
+
+ # create current/source/images dirs
+ ksdir = os.path.join(self.src_base, 'kickstart')
+ if os.path.exists(ksdir):
+ shutil.rmtree(ksdir)
+ os.makedirs(ksdir)
+
+ kslist = os.listdir(self.kickstart)
+ for k in kslist:
+ tmpdir = os.path.join(self.kickstart, k)
+ os.chdir(tmpdir)
+
+ if os.path.exists(self.src_base):
+ shutil.rmtree(self.src_base)
+
+ os.makedirs(self.src_base)
+ ret = self.shcmd('find . -print|cpio -mpud %s' % self.src_base)
+
+ fullname = '%s-%s-%s.noarch.rpm' % (k, sunhpc.version, sunhpc.version_micro)
+ dst_rpms = self.makeRpms(k, fullname)
+
+ if os.path.exists(dst_rpms):
+ self.msg('\tGenerate complete %s' % k)
+ else:
+ self.msg('Failed: %s' % ret, 'a')
+
+ def makeRpms(self, name, fullname):
+ # Build the roll-*-kickstart.rpm file.
+
+ dst_fullname = os.path.join(self.dst_base, fullname)
+ if os.path.exists(dst_fullname):
+ os.remove(dst_fullname)
+
+ argsk = []
+ argsk.append('arch=noarch')
+ argsk.append('name=%s' % name)
+ argsk.append('source=%s' % self.src_base)
+ argsk.append('prefix=/export/sunhpc/install/source')
+ argsk.append('outdir=%s' % self.dst_base)
+ self.command('create.rpm', argsk)
+ return dst_fullname
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/pxelinux/product/__init__.py b/lib/sunhpc/commands/create/pxelinux/product/__init__.py
new file mode 100644
index 0000000..bd6e182
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/product/__init__.py
@@ -0,0 +1,80 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import sunhpc
+import subprocess
+class Command(sunhpc.commands.create.pxelinux.command):
+ """
+ Build the product file for the cluster
+
+ <param type='Path' name='basedir'>
+ Supply the unpack base directory. Default: the current directory.
+ </param>
+
+ <param type='Path' name='outdir'>
+ Specify the sunhpc-anaconda-product-*.rpm output directory. Default: `pwd`/pxeboot
+ </param>
+
+ <param type='String' name='version'>
+ Specify the generated rpm version. Default: the sunhpc version.
+ </param>
+
+ <example cmd='create pxelinux product'>
+ Build the product image file.
+ </example>
+ """
+ def run(self, param, args):
+
+ cwd = os.getcwd()
+ (self.src_base, self.dst_base, version) = self.fillParams([
+ ('basedir', os.path.join(cwd, 'source')),
+ ('outdir', os.path.join(cwd, 'pxeboot')),
+ ('version', sunhpc.version)])
+
+ self.product = os.path.join(self.prefix, 'share', 'isobuild', 'anaconda-product')
+
+ self.copyBoot()
+ self.makeRpms()
+
+ os.chdir(cwd)
+
+ def copyBoot(self):
+ """copy product file to current source directory."""
+
+ # create current/source/images dirs
+ images = os.path.join(self.src_base, 'images')
+ if os.path.exists(images):
+ shutil.rmtree(images)
+ os.makedirs(images)
+
+ os.chdir(self.product)
+ dst_product = os.path.join(self.src_base, 'images', 'product.img')
+ if os.path.exists(dst_product):
+ os.remove(dst_product)
+
+ ret = self.shcmd('find . |cpio -c -o|xz -9 --format=xz > %s' % dst_product, code=True)
+
+ if os.path.exists(dst_product):
+ self.msg('\tGenerate complete product.img')
+ else:
+ self.msg('Failed: %s.' % ret, 'a')
+
+ def makeRpms(self):
+ # Build the sunhpc-anaconda-product-*.rpm file.
+
+ rpm_fullname = 'sunhpc-anaconda-product-%s-%s.noarch.rpm' % (sunhpc.version, sunhpc.version_micro)
+ dst_fullname = os.path.join(self.dst_base, rpm_fullname)
+ if os.path.exists(dst_fullname):
+ os.remove(dst_fullname)
+
+ argsk = []
+ argsk.append('arch=noarch')
+ argsk.append('name=sunhpc-anaconda-product')
+ argsk.append('source=%s' % self.src_base)
+ argsk.append('prefix=/export/sunhpc/install/source')
+ argsk.append('outdir=%s' % self.dst_base)
+ self.command('create.rpm', argsk)
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/pxelinux/squashfs/__init__.py b/lib/sunhpc/commands/create/pxelinux/squashfs/__init__.py
new file mode 100644
index 0000000..1917dcc
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/squashfs/__init__.py
@@ -0,0 +1,113 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import sunhpc
+import subprocess
+class Command(sunhpc.commands.create.pxelinux.command):
+ """
+ Build the squashfs file for the cluster
+
+ <param type='Path' name='basedir'>
+ Supply the unpack base directory. Default: the current directory.
+ </param>
+
+ <param type='Path' name='outdir'>
+ Specify the sunhpc-anaconda-install-*.rpm output directory. Default: `pwd`/pxeboot
+ </param>
+
+ <param type='Path' name='cdrom'>
+ Specify the cdrom directory. Default: /mnt/cdrom
+ </param>
+
+ <param type='string' name='version'>
+ The version embedded in the image name. Default: the sunhpc version.
+ </param>
+
+ <example cmd='create pxelinux squashfs cdrom=/mnt/cdrom'>
+ Build the install.img (squashfs) from the cdrom.
+ </example>
+ """
+ def run(self, param, args):
+
+ cwd = os.getcwd()
+ (self.src_base, self.dst_base, self.cdrom,
+ version) = self.fillParams([
+ ('basedir', os.path.join(cwd, 'source')),
+ ('outdir', os.path.join(cwd, 'pxeboot')),
+ ('cdrom', '/mnt/cdrom'),
+ ('version', sunhpc.version)])
+
+ self.checkCDROM()
+ self.copyBoot()
+ self.makeRpms()
+
+ os.chdir(cwd)
+
+ def copyBoot(self):
+ """copy /mnt/cdrom squashfs to current pxeboot directory."""
+
+ if os.path.exists(self.src_base):
+ shutil.rmtree(self.src_base)
+
+ # create current/source/images dirs
+ os.makedirs(os.path.join(self.src_base, 'images'))
+
+ install = os.path.join(self.src_base, 'images', 'install.img')
+ squashfs = self.shcmd('find %s -type f -name squashfs.img -print' % self.cdrom)['o'].strip()
+
+ # copy /mnt/cdrom/LiveOS/squashfs.img to current/source/images/squashfs.img
+ shutil.copyfile(squashfs, install)
+ if os.path.exists(install):
+ self.msg('\tCopying complete install.img')
+
+ def makeRpms(self):
+ # Build the sunhpc-anaconda-install-*.rpm file.
+
+ rpm_fullname = 'sunhpc-anaconda-install-%s-%s.noarch.rpm' % (sunhpc.version, sunhpc.version_micro)
+ dst_fullname = os.path.join(self.dst_base, rpm_fullname)
+ if os.path.exists(dst_fullname):
+ os.remove(dst_fullname)
+
+ argsk = []
+ argsk.append('arch=noarch')
+ argsk.append('name=sunhpc-anaconda-install')
+ argsk.append('source=%s' % self.src_base)
+ argsk.append('prefix=/export/sunhpc/install/source')
+ argsk.append('outdir=%s' % self.dst_base)
+ self.command('create.rpm', argsk)
+
+ def checkCDROM(self):
+ """check cdrom"""
+
+ if not len(os.listdir(self.cdrom)):
+ self.msg('Must mount a valid linux cd to %s.' % self.cdrom, 'a')
+
+ squashfs = subprocess.run('find %s -type f -name squashfs.img -print >/dev/null' % self.cdrom, shell=True)
+ if squashfs.returncode:
+ self.msg('Could not find the squashfs file. %s' % squashfs, 'a')
+
+ # Read version information from .discinfo
+ # (timestamp, release e.g. 7.6, arch e.g. x86_64)
+ info = ['unknown', 'unknown', 'unknown']
+ with open(os.path.join(self.cdrom, '.discinfo'), 'r') as f:
+ line = f.readlines()
+ if len(line) >= 3:
+ info[0] = line[0].strip()
+ info[1] = line[1].strip()
+ info[2] = line[2].strip()
+
+ name = 'unknown'
+ with open(os.path.join(self.cdrom, '.treeinfo'), 'r') as f:
+ for l in f.readlines():
+ if l.find('family') == 0:
+ try:
+ name = l.split('=')[1].strip()
+ except IndexError:
+ pass
+
+ if name not in ['CentOS', 'Sunhpc']:
+ self.msg('The %s is not supported.' % name, 'a')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/pxelinux/updates/__init__.py b/lib/sunhpc/commands/create/pxelinux/updates/__init__.py
new file mode 100644
index 0000000..bf99022
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/updates/__init__.py
@@ -0,0 +1,80 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import sunhpc
+import subprocess
+class Command(sunhpc.commands.create.pxelinux.command):
+ """
+ Build the updates images file for the cluster
+
+ <param type='Path' name='basedir'>
+ supply an unpack base dirs. default: current directorys.
+ </param>
+
+ <param type='Path' name='outdir'>
+ specify an sunhpc-anaconda-updates-*.rpm out directory. default: `pwd`/pxeboot
+ </param>
+
+ <param type='String' name='version'>
+ specify an gen rpm version. default: sunhpc version
+ </param>
+
+ <example cmd='create pxelinux updates'>
+ build an updates images file
+ </example>
+ """
+ def run(self, param, args):
+
+ cwd = os.getcwd()
+ (self.src_base, self.dst_base, version) = self.fillParams([
+ ('basedir', os.path.join(cwd, 'source')),
+ ('outdir', os.path.join(cwd, 'pxeboot')),
+ ('version', sunhpc.version)])
+
+ self.updates = os.path.join(self.prefix, 'share', 'isobuild', 'anaconda-updates')
+
+ self.copyBoot()
+ self.makeRpms()
+
+ os.chdir(cwd)
+
+ def copyBoot(self):
+ """copy updates file to current source directory."""
+
+ # create current/source/images dirs
+ images = os.path.join(self.src_base, 'images')
+ if os.path.exists(images):
+ shutil.rmtree(images)
+ os.makedirs(images)
+
+ os.chdir(self.updates)
+ dst_updates = os.path.join(self.src_base, 'images', 'updates.img')
+ if os.path.exists(dst_updates):
+ os.remove(dst_updates)
+
+ ret = self.shcmd('find . |cpio -c -o|xz -9 --format=xz > %s' % dst_updates, code=True)
+
+ if os.path.exists(dst_updates):
+ self.msg('\tGenerate complete updates.img')
+ else:
+ self.msg('Failed: %s.' % ret, 'a')
+
+ def makeRpms(self):
+ # Build the sunhpc-anaconda-updates-*.rpm file.
+
+ rpm_fullname = 'sunhpc-anaconda-updates-%s-%s.noarch.rpm' % (sunhpc.version, sunhpc.version_micro)
+ dst_fullname = os.path.join(self.dst_base, rpm_fullname)
+ if os.path.exists(dst_fullname):
+ os.remove(dst_fullname)
+
+ argsk = []
+ argsk.append('arch=noarch')
+ argsk.append('name=sunhpc-anaconda-updates')
+ argsk.append('source=%s' % self.src_base)
+ argsk.append('prefix=/export/sunhpc/install/source')
+ argsk.append('outdir=%s' % self.dst_base)
+ self.command('create.rpm', argsk)
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/pxelinux/vminitrd/__init__.py b/lib/sunhpc/commands/create/pxelinux/vminitrd/__init__.py
new file mode 100644
index 0000000..3562eef
--- /dev/null
+++ b/lib/sunhpc/commands/create/pxelinux/vminitrd/__init__.py
@@ -0,0 +1,124 @@
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import sunhpc
+import subprocess
+class Command(sunhpc.commands.create.pxelinux.command):
+ """
+ Build the initrd and vmlinuz files for the cluster.
+
+ <param type='Path' name='basedir'>
+ Supply the unpack base directory. Default: the current directory.
+ </param>
+
+ <param type='Path' name='outdir'>
+ Specify the sunhpc-anaconda-initrds-*.rpm output directory. Default: `pwd`/pxeboot
+ </param>
+
+ <param type='Path' name='cdrom'>
+ Specify the cdrom directory. Default: /mnt/cdrom
+ </param>
+
+ <param type='string' name='version'>
+ The version embedded in the image names. Default: the sunhpc version.
+ </param>
+
+ <example cmd='create pxelinux vminitrd cdrom=/mnt/cdrom'>
+ Build the boot vmlinuz and initrd files from the cdrom.
+ </example>
+ """
+ def run(self, param, args):
+
+ cwd = os.getcwd()
+ (self.src_base, self.dst_base, self.cdrom,
+ version) = self.fillParams([
+ ('basedir', os.path.join(cwd, 'source')),
+ ('outdir', os.path.join(cwd, 'pxeboot')),
+ ('cdrom', '/mnt/cdrom'),
+ ('version', sunhpc.version)])
+
+ self.printVersion()
+ self.copyBoot()
+ self.makeRpms()
+
+ os.chdir(cwd)
+
+ def copyBoot(self):
+ """copy /mnt/cdrom initrd vmlinuz to current pxeboot directory."""
+
+ # create current/source/images dirs
+ images = os.path.join(self.src_base, 'images')
+ if os.path.exists(images):
+ shutil.rmtree(images)
+ os.makedirs(images)
+
+ initrds = 'initrd-%s-%s' % (sunhpc.version, self.arch)
+ vmlinuz = 'vmlinuz-%s-%s' % (sunhpc.version, self.arch)
+
+ dst_initrds = os.path.join(self.src_base, 'images', initrds)
+ dst_vmlinuz = os.path.join(self.src_base, 'images', vmlinuz)
+
+ src_initrds = self.shcmd('find %s -type f -name initrd.img -print|head -n1' % self.cdrom)['o'].strip()
+ src_vmlinuz = self.shcmd('find %s -type f -name vmlinuz -print|head -n1' % self.cdrom)['o'].strip()
+
+ # copy initrd.img to current/source/images/initrd-7.0-x86_64
+ # copy vmlinuz to current/source/images/vmlinuz-7.0-x86_64
+ if os.path.exists(dst_initrds):
+ os.remove(dst_initrds)
+ shutil.copyfile(src_initrds, dst_initrds)
+ if os.path.exists(dst_initrds):
+ self.msg('\tCopying complete %s' % initrds)
+ else:
+ self.msg('Failed: %s' % dst_initrds, 'a')
+
+ if os.path.exists(dst_vmlinuz):
+ os.remove(dst_vmlinuz)
+ shutil.copyfile(src_vmlinuz, dst_vmlinuz)
+ if os.path.exists(dst_vmlinuz):
+ self.msg('\tCopying complete %s' % vmlinuz)
+ else:
+ self.msg('Failed for %s.' % dst_vmlinuz, 'a')
+
+ def makeRpms(self):
+ # Build the sunhpc-anaconda-initrds-*.rpm file.
+
+ rpm_fullname = 'sunhpc-anaconda-initrds-%s-%s.noarch.rpm' % (sunhpc.version, sunhpc.version_micro)
+ dst_fullname = os.path.join(self.dst_base, rpm_fullname)
+ if os.path.exists(dst_fullname):
+ os.remove(dst_fullname)
+
+ argsk = []
+ argsk.append('arch=noarch')
+ argsk.append('name=sunhpc-anaconda-initrds')
+ argsk.append('source=%s' % self.src_base)
+ argsk.append('prefix=/export/sunhpc/install/source')
+ argsk.append('outdir=%s' % self.dst_base)
+ self.command('create.rpm', argsk)
+
+ def printVersion(self):
+
+ # Read version information from .discinfo
+ # (timestamp, release e.g. 7.6, arch e.g. x86_64)
+ info = ['unknown', 'unknown', 'unknown']
+ with open(os.path.join(self.cdrom, '.discinfo'), 'r') as f:
+ line = f.readlines()
+ if len(line) >= 3:
+ info[0] = line[0].strip()
+ info[1] = line[1].strip()
+ info[2] = line[2].strip()
+
+ name = 'unknown'
+ with open(os.path.join(self.cdrom, '.treeinfo'), 'r') as f:
+ for l in f.readlines():
+ if l.find('family') == 0:
+ try:
+ name = l.split('=')[1].strip()
+ except IndexError:
+ pass
+
+ if name not in ['CentOS', 'Sunhpc']:
+ self.msg('The %s is not supported.' % name, 'a')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/repos/__init__.py b/lib/sunhpc/commands/create/repos/__init__.py
new file mode 100644
index 0000000..fc6bd6a
--- /dev/null
+++ b/lib/sunhpc/commands/create/repos/__init__.py
@@ -0,0 +1,105 @@
+#coding:utf-8
+import os
+import sys
+import shutil
+import subprocess
+import sunhpc.commands
+class Command(sunhpc.commands.create.command):
+ """
+ Create a Sunhpc local yum repo.
+
+ <param type='Bool' name='file'>
+ Enable the file-based repo.
+ </param>
+
+ <param type='Bool' name='web'>
+ Enable the http web repo.
+ </param>
+
+ <param type='Bool' name='gpk'>
+ Enable GPG signature checking (gpgcheck) in the generated repo.
+ </param>
+
+ <example cmd='create repos'>
+ Create a sunhpc local yum repo.
+ </example>
+
+ <example cmd='create repos file=1'>
+ Enable yum repo for file.
+ </example>
+
+ <example cmd='create repos web=1'>
+ Enable yum web repo for http.
+ </example>
+ """
+ def run(self, param, args):
+
+ (version, filerepo, webrepo, gpk, quiet) = self.fillParams([
+ ('version', sunhpc.version),
+ ('file', 'no'),
+ ('web', 'no'),
+ ('gpk', 'no'),
+ ('quiet', 'no')
+ ])
+
+ q = self.str2bool(quiet)
+ fil = self.str2bool(filerepo)
+ web = self.str2bool(webrepo)
+ gpk = self.str2bool(gpk)
+
+ if fil:
+ web = False
+
+ if web:
+ fil = False
+
+ if not fil and not web:
+ self.msg('must supply either file=1 or web=1.', 'a')
+
+
+ addr = self.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ base = self.db.getHostAttr('localhost', 'Kickstart_BaseDir')
+ dist = self.db.getHostAttr('localhost', 'distribution')
+
+ distro_base = self.command('report.distro').strip()
+ distro_dirs = os.path.join(distro_base, dist, self.arch)
+
+ sunhpc_repo = '/etc/yum.repos.d/sunhpc-local.repo'
+ with open(sunhpc_repo, 'w') as f:
+ f.write('#\n# Generated by "sunhpc create repos" command\n#\n')
+
+ if web:
+ f.write('[Sunhpc-%s]\n' % sunhpc.version)
+ f.write('name = Sunhpc %s\n' % sunhpc.version)
+ f.write('baseurl = http://%s/%s/%s/%s\n' % (addr, base, dist, self.arch))
+ f.write('enabled = 1\n')
+
+ if fil:
+ f.write('[Sunhpc-File-%s]\n' % sunhpc.version)
+ f.write('name = Sunhpc file %s\n' % sunhpc.version)
+ f.write('baseurl = file://%s\n' % distro_dirs)
+ f.write('enabled = 1\n')
+
+ if gpk:
+ f.write('gpgcheck = 1\n')
+ else:
+ f.write('gpgcheck = 0\n')
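+ # Resulting /etc/yum.repos.d/sunhpc-local.repo in web mode (hypothetical values):
+ # [Sunhpc-7.0]
+ # name = Sunhpc 7.0
+ # baseurl = http://10.1.1.1/install/sunhpc-dist/x86_64
+ # enabled = 1
+ # gpgcheck = 0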
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/create/roll/__init__.py b/lib/sunhpc/commands/create/roll/__init__.py
new file mode 100644
index 0000000..b5d9aca
--- /dev/null
+++ b/lib/sunhpc/commands/create/roll/__init__.py
@@ -0,0 +1,440 @@
+#coding:utf-8
+import os
+import re
+import sys
+import time
+import string
+import socket
+import shutil
+import sunhpc
+import pexpect
+import tempfile
+import subprocess
+
+class Builder:
+ def __init__(self, cmd):
+ self.cmd = cmd
+ self.config = None
+ self.tempdir = os.getcwd()
+
+ def mktemp(self):
+ return tempfile.mktemp(dir=self.tempdir)
+
+ def mkisofs(self, isoName, rollName, diskName, rollDir):
+ self.cmd.msg('Building ISO image for %s ...' % diskName)
+
+ if self.config.isBootable():
+ extraflags = self.config.getISOFlags()
+ else:
+ extraflags = ''
+
+ volname = '%s %s' % (rollName, diskName)
+ if len(volname) > 32:
+ volname = volname[0:32]
+
+ cwd = os.getcwd()
+ cmd = 'mkisofs -V "%s" %s -input-charset utf-8 -max-iso9660-filenames -r -T -f -o %s .' % \
+ (volname, extraflags, os.path.join(cwd, isoName))
+
+ os.chdir(rollDir)
+ self.cmd.system(cmd, std=0)
+ os.chdir(cwd)
+
+ def discinfo(self, dir, name, arch, id=1):
+ # .discinfo format:
+ # line 1: timestamp
+ # line 2: disc name
+ # line 3: architecture
+ # line 4: disc number (1, 2, 3, ... for a set of consecutive discs)
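+ # Example .discinfo contents (hypothetical values):
+ # 1600000000.000000
+ # base
+ # x86_64
+ # 1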
+ disc = os.path.join(dir, '.discinfo')
+ if os.path.isfile(disc):
+ os.unlink(disc)
+
+ with open(disc, 'w') as f:
+ f.write('%f\n' % time.time())
+ f.write('%s\n' % name)
+ f.write('%s\n' % arch)
+ f.write('%d\n' % id)
+
+class RollBuilder(Builder, sunhpc.core.dist.Arch):
+
+ def __init__(self, cmd, xmlfile, sign, kernel):
+ Builder.__init__(self, cmd)
+ sunhpc.core.dist.Arch.__init__(self)
+
+ self.config = sunhpc.core.files.RollInfoFile(xmlfile)
+ self.setArch(self.config.getRollArch())
+
+ self.cmd = cmd
+ self.sign = sign # whether the rpms should be signed.
+ self.kernel = kernel # mount root of the supplied Linux iso.
+
+ def mkisofs(self, isoName, rollName, diskName):
+ Builder.mkisofs(self, isoName, rollName, diskName, diskName)
+
+ def getRPMS(self, path):
+
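+ # Walk the given tree, skip non-rpm files and foreign architectures,
+ # and keep only the newest version of each unique package name.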
+ dict = {}
+ tree = sunhpc.core.files.Tree(os.path.join(os.getcwd(), path))
+ for dir in tree.getDirs():
+ for file in tree.getFiles(dir):
+ try:
+ file.getPackageName()
+ except AttributeError:
+ continue # skip all non-rpm files
+
+ # Skip RPMS for other architectures
+ if file.getPackageArch() not in self.getCPUs():
+ continue
+
+ # Resolve package versions
+ name = file.getUniqueName()
+ if not name in dict or file >= dict[name]:
+ dict[name] = file
+
+ list = []
+ for e in dict.keys():
+ list.append(dict[e])
+ return list
+
+ def signRPM(self, rpm):
+
+ cmd = "rpm -q --qf '%%{BUILDHOST}' -p %s" % rpm.getFullName()
+ buildhost = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE).stdout.readline().decode().strip()
+ hostname = socket.gethostname()
+
+ if buildhost == hostname:
+ print (rpm.getFullName())
+ print ('Need to resign for the rpm files')
+
+ def treeinfo(self, name):
+
+ with open(os.path.join(name, '.treeinfo'), 'w') as f:
+ ti = '[general]\n'
+ ti += 'name = Sunhpc-%s\n' % sunhpc.version_major
+ ti += 'family = SunhpcOS\n'
+ ti += 'timestamp = %.2f\n' % time.time()
+ ti += 'variant = \n'
+ ti += 'version = %s\n' % sunhpc.version_major
+ ti += 'packagedir =\n'
+ ti += 'arch = x86_64\n\n'
+
+ ti += '[stage2]\n'
+ ti += 'mainimage = images/install.img\n\n'
+
+ ti += '[images-x86_64]\n'
+ ti += 'kernel = images/vmlinuz-%s-%s\n' % (sunhpc.version, self.cmd.arch)
+ ti += 'initrd = images/initrd-%s-%s\n' % (sunhpc.version, self.cmd.arch)
+ ti += 'boot.iso = images/boot.iso\n\n'
+ f.write(ti)
+
+ def makeBootOS(self, disk, root):
+ self.cmd.msg('\tCreate roll comps repodata ......')
+ cwd = os.getcwd()
+ #
+ # base: roll/kernel/version/x86_64/RedHat/base
+ #
+ base = os.path.join(cwd, root, 'RedHat', 'base')
+ os.makedirs(base)
+
+ #
+ # Copy the *-comps.xml file from the CD's repodata/ directory into the base directory.
+ # comps = os.path.join(cwd, disk, 'comps.xml')
+ #
+ comps = os.path.join(base, 'comps.xml')
+ if not os.path.exists(comps):
+ compfile = ''
+ repodirs = os.path.join(self.kernel, 'repodata')
+ for c in os.listdir(repodirs):
+ if c.endswith('.xml') and 'comps' in c.split('.')[0]:
+ compfile = os.path.join(repodirs, c)
+ break
+
+ if not os.path.exists(compfile):
+ self.cmd.msg('The *-comps.xml file was not found in %s' % self.kernel, 'a')
+
+ shutil.copy(compfile, comps)
+
+ # Create the custom image files via the updates feature in .treeinfo.
+ product = os.path.join(cwd, disk, 'updates')
+ if not os.path.exists(product):
+ os.makedirs(product)
+ self.treeinfo(disk)
+
+ images = os.path.join(cwd, disk, 'images')
+ if not os.path.exists(images):
+ shutil.copytree(os.path.join(cwd, 'isolinux'), os.path.join(disk, 'images'))
+
+ isolinux = os.path.join(cwd, disk, 'isolinux')
+ if not os.path.exists(isolinux):
+ isodirs = os.path.join(self.kernel, 'isolinux')
+ shutil.copytree(isodirs, isolinux)
+
+ if not os.path.exists(comps):
+ self.cmd.msg('\n\tCould not find a comps.xml file. - %s' % comps, 'e')
+ sys.exit(-1)
+
+ self.createrepo(disk, comps)
+
+ def createrepo(self, disk, comps):
+ self.cmd.msg('Creating %s repository......' % disk)
+ cwd = os.getcwd()
+ base = os.path.join(cwd, disk)
+ os.chdir(base)
+
+ createrepo = '/usr/share/createrepo/genpkgmetadata.py'
+ if not os.path.exists(createrepo):
+ self.cmd.msg("Please run 'yum install createrepo' first", 'a')
+
+ # Using the workers option requires enough free space in tmpdir.
+ tmpdir = os.getenv("TMPDIR")
+ os.putenv("TMPDIR",".")
+ cmd = "%s --groupfile %s --workers 8 --quiet ." % (createrepo, comps)
+ subprocess.call(cmd, shell=True)
+ if tmpdir is not None:
+ os.putenv("TMPDIR",tmpdir)
+ else:
+ os.unsetenv("TMPDIR")
+ os.chdir(cwd)
+
+ def getExternalRPMS(self):
+ # All rpms extracted by parsing the kickstart xml scripts.
+ rpms = ['gcc']
+ '''
+ selected = []
+ for rpm in os.popen('%s/sbin/runpy2 %s' % (
+ os.environ.get('SUNHPC_HOME'), ' '.join(rpms))):
+ selected.extend(eval(rpm))
+ '''
+ sunhpc = []
+ nonsun = []
+ return ([], [])
+
+ def getReqiredRPMS(self):
+ self.cmd.msg('Please wait while collecting rpm packages...', 'e')
+
+ rpms = ['minimal', '@base', '@core', '@development']
+ required = []
+ for rpm in os.popen('%s/sbin/runpy2 %s' % (
+ os.environ.get('SUNHPC_HOME'), ' '.join(rpms))):
+ required.extend(eval(rpm))
+
+ if os.path.exists('RPMS'):
+ shutil.rmtree('RPMS')
+
+ os.makedirs('RPMS')
+
+ self.cmd.msg('Copying iso boot required rpm files...', 'e')
+ packDirs = os.path.join(self.kernel, 'Packages')
+ if not os.path.exists(packDirs):
+ self.cmd.msg('Cannot find the %s directory; the iso may not be bootable!' % packDirs, 'w')
+
+ packages = {}
+ packTree = sunhpc.core.files.Tree(self.kernel)
+ for dir in packTree.getDirs():
+ if dir != 'Packages': continue
+
+ for file in packTree.getFiles(dir):
+ packages[file.getBaseName()] = file
+
+ #os.symlink(file.getFullName(), os.path.join(path, file.getName()))
+ for p in required:
+ if p in packages:
+ os.symlink(
+ packages[p].getFullName(),
+ os.path.join('RPMS', packages[p].getName())
+ )
+ else:
+ self.cmd.msg('The %s package could not be found.' % p)
+
+ def run(self):
+
+ rpmlist = []
+ if self.config.hasRPMS():
+ # Collect the rpm files from the current RPMS directory.
+ rpmlist.extend(self.getRPMS('RPMS'))
+ rpmlist.extend(self.getRPMS('PXES'))
+
+ if self.config.hasSRPMS():
+ # Collect the rpm files from the current SRPMS directory.
+ rpmlist.extend(self.getRPMS('SRPMS'))
+
+ # If the error "Header V3 RSA/SHA256 Signature, key ID f4a80eb5: NOKEY" appears,
+ # the gpg key needs to be imported:
+ # rpm --import /etc/pki/rpm-gpg/RPM-GPG-KEY-CentOS-7
+ if rpmlist:
+ cmd = "rpm -q --qf '%%{BUILDHOST}' -p %s" % rpmlist[0].getFullName()
+ fix = "rpm --import /etc/pki/rpm-gpg/RPM-*"
+ ret = self.cmd.shcmd(cmd, code=True)
+ if ret['c'] != 0:
+ self.cmd.shcmd(fix, code=True)
+
+ # Verify signatures; rpms that we built ourselves need to be re-signed.
+ # Once signed, signature checking can be enabled in the yum configuration.
+ if self.sign:
+ self.cmd.msg('Starting the rpm sign...')
+ for rpm in rpmlist:
+ self.signRPM(rpm)
+
+ required = []
+ if self.config.hasRolls():
+ self.cmd.msg('Starting making the rolls iso...')
+ (required, optional) = self.getExternalRPMS()
+ for file in rpmlist:
+ required.append(file)
+
+ # required is total packages.
+ self.cmd.msg('\tRequired Packages %s' % len(required))
+ self.cmd.msg('\tOptional Packages %s' % len(optional))
+ for file in required:
+ rpmlist.append(file)
+ rpmlist.extend(optional)
+
+ # After getExternalRPMS() completes, required should not be empty.
+ if len(required) == 0:
+ self.cmd.msg('This disk is optional (extra rpms)', 'w')
+
+ # Base directory of the roll ISO.
+ name = 'disk'
+ root = os.path.join(name, self.config.getRollName(),
+ self.config.getRollVersion(), self.config.getRollArch())
+
+ os.makedirs(root)
+ os.makedirs(os.path.join(root, 'RedHat', 'RPMS'))
+ os.makedirs(os.path.join(root, 'SRPMS'))
+
+ # Symlink in all the RPMS and SRPMS
+ for file in rpmlist:
+ try:
+ arch = file.getPackageArch()
+ except:
+ continue
+
+ if arch == 'src':
+ file.symlink(os.path.join(root, 'SRPMS', file.getName()))
+ else:
+ file.symlink(os.path.join(root, 'RedHat', 'RPMS', file.getName()))
+
+ if file in required:
+ del required[required.index(file)]
+
+ # Copy the Roll XML file onto all the disks
+ shutil.copy(self.config.getFullName(), root)
+
+ # Create the .discinfo file
+ self.discinfo(name, self.config.getRollName(), self.config.getRollArch())
+
+ isoname = '%s-%s-%s.%s.%s.iso' % (
+ self.config.getRollName(),
+ self.config.getRollVersion(),
+ self.config.getRollRelease(),
+ self.config.getRollArch(),
+ name)
+
+ if self.config.isBootable() == 1:
+ self.cmd.msg('Configuring %s kickstart bootable ...' % self.config.getRollName())
+ self.makeBootOS(name, root)
+ #self.makeBootGrub(self.config.getRollName(), name)
+
+ self.mkisofs(isoname, self.config.getRollName(), name)
+ self.cmd.msg("%s create finished." % isoname)
+
+ def makeBootGrub(self, rollname, name):
+ # Since CentOS 7, spaces in the grub boot menu label must be escaped.
+ config = os.path.join('kernel', 'isolinux', 'isolinux.cfg')
+ srckey = 'inst.stage2=hd:LABEL=CentOS\\x207\\x20x86_64'
+ dstkey = 'inst.stage2=hd:LABEL=%s\\x20%s ' % (rollname, name)
+ dstkey += 'net.ifnames=0 biosdevname=0'
+
+ text = self.cmd.replace(config, srckey, dstkey)
+ with open(os.path.join(name,
+ 'isolinux', 'isolinux.cfg'), 'w') as f:
+ for i in text:
+ f.write(i)
+
+ def makeBoot(self):
+ if not self.config.isBootable():
+ return
+
+ file_list = ['initrd.img', 'vmlinuz', 'squashfs.img', 'efiboot.img']
+ for f in file_list:
+ require = self.cmd.shcmd('find %s -type f -name %s' % (self.kernel, f))['o'].strip()
+ if not require:
+ self.cmd.msg('e.g. sunhpc create roll roll-*.xml boot=/mnt/cdrom', 'e')
+ self.cmd.msg('must provide a source linux cd path to build the boot files; %s does not exist.' % self.kernel, 'a')
+
+ self.cmd.msg('Configuring %s pxelinux bootable ...' % self.config.getRollName())
+ if not os.path.exists('isolinux'):
+ os.makedirs('isolinux')
+
+ pxelist = ['product', 'updates', 'efiboot', 'vminitrd', 'squashfs', 'kickstart', 'client']
+ for i in pxelist:
+ self.cmd.command('create.pxelinux.%s' % i)
+ src = 'source/images'
+ if not os.path.exists(src):
+ continue
+ for s in os.listdir(src):
+ shutil.copyfile(os.path.join(src, s), os.path.join('isolinux', s))
+
+class Command(sunhpc.commands.create.command):
+ """
+ Create a new Roll.
+ <arg type='file' name="*.xml">
+ Either a list of Roll ISO files or the name of a single Roll XML
+ description file. A list of Roll ISO files will be merged together
+ into a single Roll. Otherwise the single argument is assumed to
+ be the name of the XML file generated by the top level Makefile in
+ the Roll's source.
+ </arg>
+
+ <param type="bool" name="sign">
+ Whether to sign the rpms
+ </param>
+
+ <param type="Path" name="boot">
+ Provides an ISO mount directory for the original system distribution.
+ </param>
+
+ <example cmd='create roll roll-base.xml'>
+ Create a new roll iso file.
+ </example>
+
+ <example cmd='create roll roll-CentOS.xml boot=/mnt/cdrom'>
+ Create a new bootable roll iso file.
+ </example>
+
+ <related>create xml name=CentOS version=7.9.2009 boot=1</related>
+ """
+
+ def run(self, param, args):
+
+ (kernel, sign) = self.fillParams([('boot', ''), ('sign', 'no')])
+
+ sign = self.str2bool(sign)
+ if len(args) != 1:
+ self.msg('must supply a roll xml config file, or use the following command to create it.', 'e')
+ self.msg('\t# sunhpc create xml name=CentOS version=7.9.2009 boot=1', 'a')
+
+ base, ext = os.path.splitext(args[0])
+ if ext != '.xml':
+ self.msg('missing xml file - %s' % args[0], 'a')
+
+ cwd = os.getcwd()
+ builder = RollBuilder(self, args[0], sign, kernel)
+
+ # create boot iso symlink
+ rpms = os.path.join(kernel, 'Packages')
+ if not os.path.islink('RPMS'):
+ os.symlink(rpms, 'RPMS')
+
+ builder.makeBoot()
+ if os.path.exists('pxeboot') and not os.path.islink('PXES'):
+ os.symlink('pxeboot', 'PXES')
+
+ builder.run()
+ self.msg('Use the following command to init pxe:', 'i')
+ self.msg(' sunhpc pxelinux build xxx.x86_64.disk.iso', 'i')
+
+
+
+
diff --git a/lib/sunhpc/commands/create/rpm/__init__.py b/lib/sunhpc/commands/create/rpm/__init__.py
new file mode 100644
index 0000000..d858c79
--- /dev/null
+++ b/lib/sunhpc/commands/create/rpm/__init__.py
@@ -0,0 +1,476 @@
+#coding:utf-8
+import os
+import sys
+import pwd
+import grp
+import time
+import sunhpc
+import shutil
+import tarfile
+import tempfile
+import datetime
+import subprocess
+
+class RPMBUILD(object):
+
+ def __init__(self, cmd, name, root, source, prefix,
+ version, release, arch, pyver, spec, tarfile, outdir, quiet):
+
+ self.cmd = cmd
+ self.name = name
+ self.root = root
+ self.arch = arch
+ self.quiet = quiet
+ self.srcdir = source
+ self.outdir = outdir
+ self.prefix = prefix
+ self.version = version
+ self.release = release
+ self.pyver = pyver
+ self.specfile = spec
+ self.tarfile = tarfile
+
+ self.build = os.path.join(self.root, 'BUILD')
+ self.rpms = os.path.join(self.root, 'RPMS')
+ self.srpms = os.path.join(self.root, 'SRPMS')
+ self.spec = os.path.join(self.root, 'SPECS')
+ self.source = os.path.join(self.root, 'SOURCES')
+ self.buildroot = os.path.join(self.root, 'BUILDROOT')
+
+ def makedirs(self):
+ rpmdirs = ['BUILD', 'RPMS', 'SOURCES', 'SPECS', 'SRPMS', 'BUILDROOT']
+ [ os.makedirs(os.path.join(self.root, x))
+ for x in rpmdirs if not os.path.exists(x)]
+
+ def makeMacros(self):
+ macros = os.path.join(self.root, '.rpmmacros')
+
+ '''
+ %_signature gpg
+ %_gpg_path ~/.gnupg
+ %_gpg_name xiubuzhe@sina.com
+ %_gpgbin /usr/bin/gpg2
+ %_gpg_digest_algo sha512
+ %_topdir %{getenv:HOME}/rpmbuild
+ '''
+ with open(macros, 'w') as f:
+ f.write('%%_topdir %s\n' % self.root)
+ f.write('%debug_package %{nil}\n')
+
+ def createHeader(self):
+ # %_sourcedir SOURCE
+ # %_specdir SPECS
+ # $RPM_SOURCE_DIR ="/root/rpmbuild/SOURCES"
+ # $RPM_BUILD_DIR ="/root/rpmbuild/BUILD"
+ # $RPM_BUILD_ROOT ="/root/rpmbuild/BUILDROOT/test-0.1-1.x86_64"
+
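+ # The assembled preamble looks roughly like this (hypothetical values):
+ # %define _prefix /opt/sunhpc
+ # Summary: mytool
+ # Name: mytool
+ # Version: 7.0
+ # Release: 1
+ # BuildArch: noarch
+ # Source: mytool.tar.gz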
+ header = '%%define _prefix %s\n' % self.prefix
+ # Package summary
+ header += 'Summary: %s\n' % self.name
+ # Package name
+ header += 'Name: %s\n' % self.name
+ # Package version
+ header += 'Version: %s\n' % self.version
+ # Actual release number of the package
+ header += 'Release: %s\n' % self.release
+ # Package architecture
+ header += 'BuildArch: %s\n' % self.arch
+ # License and project address of the software.
+ header += 'License: %s\n' % self.name
+ header += 'URL: https://www.sunhpc.com\n'
+ # Vendor or packaging organisation
+ header += 'Vendor: %s\n' % self.name
+ # Software group
+ header += 'Group: System Environment/Base\n'
+ # Source tarball; several sources (Source0, Source1, ...) may be given and
+ # referenced later as %{source}, %{source1}, etc.
+ # Source0: http://nginx.org/download/%{name}-%{version}.tar.gz
+ header += 'Source: %s.tar.gz\n' % self.name
+ header += 'Buildroot: %s\n' % self.buildroot
+ header += 'Prefix: %{_prefix}\n'
+ header += 'Packager: Sunhpc cluster for %s\n' % self.name
+ header += 'AutoReqProv: no\n'
+
+ # Packages required at build time, e.g.
+ # BuildRequires: zlib-devel
+
+ # Capabilities this package provides, so other rpms can identify it, e.g.
+ # Provides: webserver
+
+ # Conditionals:
+ # %if 0%{?aaaa} -- the leading 0 assumes the aaa variable is undefined.
+ #   this branch runs when the aaa variable is defined
+ # %else
+ #   this branch runs when the aaa variable is not defined
+
+ return header
+
+ def createDesc(self):
+ text = '\n%description\n'
+ text += 'The %s spec file make by Sunhpc OS\n' % self.name
+ return text
+
+ def createPrep(self):
+ text = '\n%prep\n'
+ text += '%setup -q -n %{name}\n'
+ return text
+
+ def createBuild(self):
+ text = '\n%build\n'
+ return text
+
+ def createInstall(self):
+ """制作Spec Install部分"""
+ # 如果制作python,需要将/usr/bin/python 替换成python3版本命令.
+ tree = sunhpc.core.files.RPMbuild(self.srcdir)
+ root = tree.getRootMode()
+ name = tree.getRootName()
+
+ dirslist, filelist, usefiles = [], [], []
+ for path in tree.getDirs():
+
+ Dirsmode = tree.getPathMode(os.path.join(self.srcdir, path))
+ PathUid = pwd.getpwuid(Dirsmode[1]).pw_name
+ PathGid = grp.getgrgid(Dirsmode[2]).gr_name
+ pathmode = 'install -d -m 0%s -o %s -g %s \"$RPM_BUILD_ROOT%%{_prefix}/%s\"\n' % \
+ (Dirsmode[0], PathUid, PathGid, path)
+ if path: dirslist.append(pathmode)
+
+ for f in tree.getFiles(path):
+
+ Filename = f.getName()
+ Filemode = f.getFileMode()
+ FileUid = pwd.getpwuid(Filemode[1]).pw_name
+ FileGid = grp.getgrgid(Filemode[2]).gr_name
+
+ fullpath = os.path.join(path, Filename)
+ #fullname = '%%attr(0%s, %s, %s) %%{_prefix}/%s' % (Filemode[0], FileUid, FileGid, fullpath)
+ fullname = 'install -m 0%s -o %s -g %s \"%s\" \"$RPM_BUILD_ROOT%%{_prefix}/%s\"\n' % \
+ (Filemode[0], FileUid, FileGid, fullpath, fullpath)
+
+ if fullname not in filelist: filelist.append(fullname)
+
+ if fullpath not in usefiles: usefiles.append("/%s" % fullpath)
+
+ # Saved for later use by %files.
+ self.usefiles = usefiles
+ self.usepaths = dirslist
+
+ # Set the owner and permissions of the root directory.
+ RootUid = pwd.getpwuid(root[1]).pw_name
+ RootGid = grp.getgrgid(root[2]).gr_name
+ rootdirs = 'install -d -m 0%s -o %s -g %s \"$RPM_BUILD_ROOT%%{_prefix}\"\n' % (root[0], RootUid, RootGid)
+
+ text = ["\n%install\n"]
+ text.append(rootdirs)
+ text.extend(dirslist)
+ text.extend(filelist)
+
+ # Inspect the buildroot directory structure while building.
+ #text.append('\ntree %s\n' % self.buildroot)
+
+ return ''.join(text)
+ def createClean(self):
+ # Clean up temporary files.
+ text = '\n[ "$RPM_BUILD_ROOT" != "/" ] && rm -rf "$RPM_BUILD_ROOT"\n'
+ text += 'rm -rf $RPM_BUILD_DIR/%{name}-%{version}\n'
+ return text
+
+ def createPre(self):
+ # Script executed before the rpm is installed.
+ text = "\n%pre\n"
+ return text
+
+ def createPost(self):
+ # Script executed after the rpm is installed.
+ text = "\n%post\n"
+ return text
+
+ def createPreun(self):
+ # Script executed before the rpm is uninstalled.
+ text = "\n%preun\n"
+ return text
+ def createPostun(self):
+ # Script executed after the rpm is uninstalled.
+ text = "\n%postun\n"
+ return text
+
+ def createFiles(self):
+ """包含所以编译阶段中所产生的文件"""
+
+ text = "\n%files\n"
+ # text += "%{_prefix}\n"
+ # Do not include the prefix root directory directly. It is fine when the
+ # package never writes into system directories, but if package files land
+ # in a system directory the install fails with file conflicts. On uninstall
+ # every path listed here is removed, so including a system directory would
+ # delete that directory as well, which would be fatal. Therefore each file
+ # is listed individually.
+
+ # Another caveat: for a pure python program the tarball contains only .py
+ # sources, yet rpmbuild generates pyc/pyo files under BUILDROOT, so packing
+ # usually fails complaining that the pyc/pyo files must be listed in %files.
+ # Normally pyc/pyo files would be produced during the build stage in the
+ # BUILD directory and then installed into BUILDROOT by %install; in testing
+ # BUILD never produces them, yet they still appear in BUILDROOT.
+ # Explanation 1: the source may need a make step to generate the pyc/pyo files.
+ # Explanation 2: this spec has no %build stage, which would normally run the software's own configure and make.
+ for f in self.usefiles:
+ (dname, fname) = os.path.split(f)
+ filename = '%s' % os.path.join(self.prefix, dname[1:], fname)
+ text += '%s\n' % filename
+
+ '''
+ name, ext = os.path.splitext(filename)
+ if ext == '.py' and self.pyver == '2':
+ text += '%so\n' % filename
+ text += '%sc\n' % filename
+
+ (pydir, init) = os.path.split(f)
+ if ext == '.py' and self.pyver == '3':
+ text += '\n%%dir %s/__pycache__\n' % pydir
+ text += '\n%s/__pycache__/*\n' % pydir
+ #text += '\n%s/__pycache__/__init__.cpython-36.opt-1.pyc\n' % pydir
+ #text += '\n%s/__pycache__/__init__.cpython-36.pyc\n' % pydir
+ '''
+ return text
+
+ def excludeFiles(self):
+ # List files that should not be packaged into the rpm; an error is raised if a listed file does not exist.
+ text = "\n%exclude\n"
+ return text
+
+ def createLogs(self):
+ tims = time.strftime("%a %b %d %Y", time.localtime())
+ text = "\n%changelog\n"
+ text += "* %s kelvin <kelvin@sunhpc.com>\n" % tims
+ text += "- Specfile auto-generated by sunhpcOS\n"
+ return text
+
+ def makeTargz(self, output_name):
+ """打包目录成tar.gz格式
+ :param output_name : 压缩文件名
+ :param source_path : 需要打包的目录
+ :return: bool
+ """
+ if not self.quiet:
+ self.cmd.msg('\tCompressing the %s.tar.gz package, please wait...' % self.name)
+ try:
+ with tarfile.open(output_name, "w:gz") as tar:
+ tar.add(self.srcdir, arcname=self.name)
+ except Exception as e:
+ self.cmd.abort('Compressing the %s.tar.gz package failed. %s' % (self.name, repr(e)))
+
+ def run(self):
+ # Create the base rpmbuild directories.
+ self.makedirs()
+
+ # Create the rpmbuild configuration file (.rpmmacros).
+ self.makeMacros()
+
+ # Assemble the spec file contents.
+ contents = self.createHeader()
+ contents += self.createDesc()
+ contents += self.createPrep()
+ # contents += self.createBuild()
+ contents += self.createInstall()
+ # contents += self.createClean()
+ # contents += self.createPre() # script run before rpm install
+ contents += self.createPost() # script run after rpm install
+ # contents += self.createPreun() # script run before rpm uninstall
+ # contents += self.createPostun() # script run after rpm uninstall
+ contents += self.createFiles()
+ # contents += self.excludeFiles() # files excluded from the rpm; a missing listed file raises an error.
+ contents += self.createLogs()
+
+ # Write the spec file into the SPECS directory.
+ spec_file = os.path.join(self.spec, '%s.spec' % self.name)
+
+ if self.specfile and self.tarfile:
+ with open(self.specfile, 'r') as f:
+ contents = f.read()
+
+ with open(spec_file, 'w') as f:
+ f.write(contents)
+
+ # Compress the source directory into a tar.gz file and place it in the SOURCE directory.
+ tar_file = os.path.join(self.source, '%s.tar.gz' % self.name)
+
+ # If the tarball already exists, remove it first.
+ if os.path.isfile(tar_file):
+ os.remove(tar_file)
+
+ if self.tarfile:
+ dirname, srcname = os.path.split(self.tarfile)
+ dstname = os.path.join(self.source, srcname)
+ shutil.copyfile(self.tarfile, dstname)
+ else:
+ # Pack the contents of the path.
+ self.makeTargz(tar_file)
+
+ # Set the HOME variable.
+ HOME = os.environ['HOME']
+ os.environ['HOME'] = self.root
+ cwd = os.getcwd()
+
+ # Start running the command to build the rpm package.
+ os.chdir(self.root)
+ # -bp only run the prep stage (unpack and apply patches)
+ # -bc prep and build
+ # -bi build and install
+ # -bl check that the file list is complete
+ # -ba build both the binary *.rpm and the src.rpm
+ # -bb build only the binary *.rpm
+ # -bs build only the *.src.rpm
+ rpmbuild = 'rpmbuild -bb SPECS/%s.spec' % self.name
+
+ # rpmbuild command
+ #self.cmd.execWithCommand(rpmbuild, quiet=not(self.cmd.debug))
+ retval = subprocess.run(rpmbuild + ' > /tmp/make-sunhpc-rpm.log 2>&1', shell=True)
+ if retval.returncode:
+ self.cmd.msg('check /tmp/make-sunhpc-rpm.log file.','w')
+
+ # Switch back to the original directory.
+ os.environ['HOME'] = HOME
+ os.chdir(cwd)
+
+ rpm_name = "%s-%s-%s.%s.rpm" % (self.name, self.version, self.release, self.arch)
+ rpm_file = os.path.join(self.rpms, self.arch, rpm_name)
+ if not os.path.exists(rpm_file):
+ self.cmd.msg(' %s generation failed.' % rpm_name, 'a')
+
+ new_rpm_dir = os.path.join(cwd, 'newRPMS')
+ if self.outdir:
+ new_rpm_dir = self.outdir
+
+ if not os.path.exists(new_rpm_dir):
+ os.makedirs(new_rpm_dir)
+
+ shutil.copyfile(rpm_file, os.path.join(new_rpm_dir, rpm_name))
+
+ new_rpm = os.path.join(new_rpm_dir, rpm_name)
+ if os.path.exists(new_rpm):
+ if not self.quiet:
+ self.cmd.msg('The new "%s" in %s directory.' % (self.name, new_rpm_dir))
+
+class Command(sunhpc.commands.create.command):
+ """
+ Create an rpm packages
+ <param type="str" name="name">
+ Must specify the rpm name.
+ </param>
+
+ <param type="str" name="source">
+ Specify the rpm source path.
+ </param>
+
+ <param type="str" name="prefix">
+ Specify the rpm install path.
+ </param>
+
+ <param type="str" name="arch">
+ Specifies the rpm arch. The default is local machine arch
+ </param>
+
+ <param type="str" name="version">
+ Specifies the rpm version.
+ </param>
+
+ <param type="str" name="release">
+ Specifies the rpm release.
+ </param>
+
+ <param type="file" name="spec">
+ Specifies the spec file for rpmbuild.
+ </param>
+
+ <param type="path" name="outdir">
+ Specifies the rpm file out dir.
+ </param>
+
+ <example cmd='create rpm'>
+ Create an rpm package.
+ </example>
+ """
+ def run(self, params, args):
+
+ (name, source, prefix, outdir, pyver,
+ spec, arch, tarfile, quiet, version,
+ release) = self.fillParams([
+ ('name', ),
+ ('source', None),
+ ('prefix', None),
+ ('outdir', None),
+ ('pyver', '2'),
+ ('spec', None),
+ ('arch', self.arch), # pass arch=noarch to build a noarch package.
+ ('tarfile', None),
+ ('quiet', 'no'),
+ ('version', sunhpc.version),
+ ('release', sunhpc.version_micro)])
+
+ quiet = self.str2bool(quiet)
+
+ if spec and os.path.isfile(spec):
+ if not tarfile:
+ self.msg('The spec and tarfile params must be supplied together.', 'a')
+
+ if not source:
+ self.msg('must supply the source package path.', 'a')
+
+ if not os.path.isdir(source):
+ self.msg('The %s directory does not exist or is not a directory.' % source, 'a')
+
+ if not prefix:
+ self.msg('must supply the rpm install (prefix) path.', 'a')
+
+ # If no package name is supplied, the directory name is used as the package name.
+ if not name:
+ name = os.path.basename(os.path.abspath(source))
+
+ # The final RPM install directory; a full path must be provided.
+ # sunhpc create rpm source=gaussian-09-a01/ prefix=/works/apps/soft/gaussian-09-a01 version=9.1
+ prefix = os.path.join(os.sep, prefix)
+
+ # Create a temporary build directory.
+ cwd = os.getcwd()
+ root = tempfile.mktemp(dir=cwd)
+ build = RPMBUILD(self, name, root, source, prefix, version, release, arch, pyver, spec, tarfile, outdir, quiet)
+
+ src_brp = '/usr/lib/rpm/brp-python-bytecompile'
+ dst_brp = '/usr/lib/rpm/bak-brp-python-bytecompile'
+ try:
+ # If the system python is kept, rpm's python byte-compile check can also be
+ # disabled, but that is unwise: even if the RPM gets built it may not work well.
+ # (Comment out the /usr/lib/rpm/brp-python-bytecompile line in the
+ # /usr/lib/rpm/redhat/macros file.)
+ os.rename(src_brp, dst_brp)
+ with open(src_brp, 'w') as f:
+ f.write('#!/bin/bash\n')
+ f.write('exit 0\n')
+ os.chmod(src_brp, 0o755)
+
+ # If pyver is not specified, the system python version is used to build the rpm package.
+ # RPMs related to Sunhpc itself must be built with python 3.
+
+ dst = '/usr/bin/python'
+ org = '/usr/bin/python2'
+ if int(pyver) == 3:
+ src = '/opt/sunpy3/bin/python3'
+ else:
+ src = '/usr/bin/python2'
+
+ # Remove the python symlink.
+ os.unlink(dst)
+ # Create the new symlink.
+ os.symlink(src, dst)
+
+ build.run()
+
+ shutil.rmtree(root)
+ except Exception as e:
+ self.abort('%s' % repr(e))
+
+ finally:
+ # Restore the python version.
+ os.unlink(dst)
+ os.symlink(org, dst)
+ os.rename(dst_brp, src_brp)
diff --git a/lib/sunhpc/commands/create/security/__init__.py b/lib/sunhpc/commands/create/security/__init__.py
new file mode 100644
index 0000000..f5a0e47
--- /dev/null
+++ b/lib/sunhpc/commands/create/security/__init__.py
@@ -0,0 +1,221 @@
+#coding:utf-8
+import os
+import sys
+import stat
+import time
+import base64
+import shutil
+import sunhpc
+from sunhpc.core.utils import SafeError
+class command(sunhpc.commands.create.command):
+
+ def makeEncrypt(self, filename, path=None, safedir="/etc/safe.d", quiet=0):
+ """
+ filename: real path of the file to encrypt, e.g. /etc/safe-security/ssh_host_rsa_key
+ path : final installed path of the file, e.g. /etc/ssh/ssh_host_rsa_key
+ """
+
+ if not os.access(filename, os.R_OK):
+ self.msg("I cannot find or see '%s' file" % filename, 'a')
+
+ if not os.access(safedir, os.W_OK):
+ self.msg("I do not have permission to write to '%s'" % safedir, 'a')
+
+ # Get the type, permissions and other metadata of this file.
+ s = os.stat(filename)
+
+ # File type and permission bits as a decimal value.
+ mode = s[stat.ST_MODE]
+
+ # Convert the file type and permission bits to an octal value.
+ moct = oct(mode)
+
+ # Get the UID and GID.
+ uid = s[stat.ST_UID]
+ gid = s[stat.ST_GID]
+
+ # Get the file size.
+ fsize = s[stat.ST_SIZE]
+
+ # Get the last access time of the file.
+ atime = s[stat.ST_ATIME]
+
+ # Get the file creation time.
+ ctime = s[stat.ST_CTIME]
+
+ # Get the last modification time of the file.
+ mtime = s[stat.ST_MTIME]
+
+ # Combine the uid and gid.
+ owner = "%d.%d" % (uid, gid)
+
+ # Convert the timestamp into a human-readable time format, e.g.:
+ # timeArray = time.localtime(atime)
+ # timeFmt = time.strftime("%Y-%m-%d %H:%M:%S", timeArray)
+ # 2020-11-26 20:36:28
+
+ # Encryption helper class.
+ secu = sunhpc.core.security.Security(2048)
+
+ # Get the absolute path of the file. e.g., /etc/passwd
+ # fullpath: /etc/passwd
+ fullpath = os.path.abspath(filename)
+ if path:
+ fullpath = path
+
+ # Get the security plugins from /opt/sunhpc/var/plugins/security
+ plugin_path = os.path.join(self.prefix, 'var', 'plugins', 'security')
+ sys.path.append(plugin_path)
+ mod_file = None
+ for plugin_file in os.listdir(plugin_path):
+ if not plugin_file.endswith('.py'):
+ continue
+ mod_name = plugin_file.split('.py')[0]
+ # Import the plugin
+ mod = __import__(mod_name)
+ plugin = mod.Plugin()
+ if plugin.get_filename() == fullpath:
+ mod_file = plugin_file
+ break
+ else:
+ plugin = None
+
+ # get_filename() tells which file the plugin handles; that plugin's code is embedded into the ciphertext.
+ if mod_file is None:
+ pycode = None
+ else:
+ f = open(os.path.join(plugin_path, mod_file), 'r')
+ pycode = f.read().encode('UTF-8')
+ f.close()
+
+ # e.g., dirs:/etc name:passwd
+ dirs, name = os.path.split(filename)
+ if path:
+ dirs, name = os.path.split(path)
+
+ # Start building the xml configuration file.
+ header = "<?xml version='1.0' standalone='yes'?>\n"
+ header = "<SafeService>\n"
+ header += "<name>%s</name>\n" % fullpath
+ header += "<mode>%s</mode>\n" % moct
+ header += "<owner>%s</owner>\n" % owner
+ header += "<directory>%s</directory>\n" % dirs
+
+ plaintext = header
+
+ # If it is a regular file.
+ if stat.S_ISREG(mode):
+
+ plaintext += "<content>\n<![CDATA[\n"
+
+ with open(filename, 'r') as f:
+ # Convert the content to bytes, otherwise it cannot be base64-encoded.
+ content = f.read().encode('UTF-8')
+
+ # Convert the bytes back to str.
+ plaintext += str(base64.b64encode(content), encoding="utf-8")
+ plaintext += "]]>\n</content>\n"
+
+ elif stat.S_ISDIR(mode):
+ pass
+
+ else:
+ raise SafeError("I can only publish a regular file or a directory.")
+
+ # Add the pycode to the configuration; it is used when the safeGet command runs.
+ if pycode is not None:
+ plaintext += "<pycode>\n<![CDATA[\n"
+ plaintext += str(base64.b64encode(pycode), encoding="utf-8")
+ plaintext += "]]>\n</pycode>\n"
+ plaintext += '</SafeService>\n'
+
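+ # For illustration only, the assembled plaintext looks roughly like this
+ # (values depend on the file being published):
+ #   <SafeService>
+ #   <name>/etc/passwd</name>
+ #   <mode>0o100644</mode>
+ #   <owner>0.0</owner>
+ #   <directory>/etc</directory>
+ #   <content><![CDATA[ ...base64 of the file body... ]]></content>
+ #   <pycode><![CDATA[ ...base64 of the plugin source, if any... ]]></pycode>
+ #   </SafeService>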
+ # Encrypt the plaintext content.
+ msg = secu.encrypt(plaintext)
+
+ # safe-format file name, e.g., /etc/passwd -> etc.passwd
+ safeFilename = self.tranpath(filename)
+ if path:
+ safeFilename = self.tranpath(path)
+
+ # Write the encrypted content into the safe directory, e.g., /etc/safe.d/passwd
+ fname = os.path.join(safedir, safeFilename)
+ with open(fname, 'w') as f:
+ f.write(msg)
+
+ if not quiet:
+ print ("SafeFile Wrote: %s/%s" % (safedir, safeFilename))
+
+
+class Command(command):
+ """
+ Create sunhpc cluster security keys ...
+
+ <arg type='string' name='keyname'>
+ supply a key name, e.g., rsa, dsa, ecdsa, and so on.
+ </arg>
+
+ <param type='string' name='keyname'>
+ supply a key name, e.g., rsa, dsa, ecdsa, and so on.
+ </param>
+
+ <param type='Bool' name='force'>
+ Force create the keys.
+ </param>
+
+ <example cmd='create security rsa'>
+ Create an security keypair to /etc/safe-security
+ </example>
+ """
+
+ def run(self, param, args):
+
+ (keyname, force) = self.fillParams([
+ ('keyname', None),
+ ('force', 'no')
+ ])
+
+ force = self.str2bool(force)
+ safedir = '/etc/safe-security'
+ if not keyname:
+ if not args:
+ self.msg('must supply a keyname.', 'a')
+
+ keyname = args[0]
+
+ if keyname not in ['rsa', 'dsa', 'ecdsa']:
+ self.msg('The supplied keyname is not supported (rsa, dsa, ecdsa).', 'a')
+
+ if force and os.path.exists(safedir):
+ shutil.rmtree(safedir)
+
+ if not os.path.exists(safedir):
+ os.makedirs(safedir)
+
+ security = sunhpc.core.security.Security()
+
+ if keyname in ['r', 'rsa']:
+ keyspair = security.makeRsaKeyPair()
+ self.makeKeys(keyspair, safedir)
+
+ if keyname in ['d', 'dsa']:
+ pass
+
+ if keyname in ['e', 'ec', 'ecd', 'ecdsa']:
+ pass
+
+ def makeKeys(self, keyspair, safedir):
+ plist = ['master.key', 'master.pub', 'shared.key', 'shared.pub']
+ for pp in plist:
+ filename = os.path.join(safedir, pp)
+ if os.path.exists(filename):
+ continue
+
+ with open(filename, 'wb') as f:
+ f.write(keyspair[pp] + b'\n')
+
+ if pp.split('.')[-1] == 'key':
+ os.system('chmod 0400 %s' % filename)
+ else:
+ os.system('chmod 0444 %s' % filename)
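+ # Result (a sketch): under /etc/safe-security this writes master.key and
+ # shared.key with mode 0400, and master.pub and shared.pub with mode 0444.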
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/security/sshd/__init__.py b/lib/sunhpc/commands/create/security/sshd/__init__.py
new file mode 100644
index 0000000..9573a18
--- /dev/null
+++ b/lib/sunhpc/commands/create/security/sshd/__init__.py
@@ -0,0 +1,60 @@
+#coding:utf-8
+
+import os
+import sys
+import stat
+import time
+import base64
+import sunhpc
+from sunhpc.core.utils import SafeError
+class Command(sunhpc.commands.create.security.command):
+ """
+ Generate the root rsa and sshd rsa, dsa, ecdsa ...
+
+ <params type='Bool' name='force'>
+ Force overwrite old files.
+ </params>
+
+ <params type='Bool' name='quiet'>
+ quiet mode.
+ </params>
+
+ <example cmd='create security sshd'>
+ Generate the crypted.
+ </example>
+ """
+ def run(self, parms, args):
+
+ (force, quiet) = self.fillParams([('force', 'no'), ('quiet', 'yes')])
+
+ force = self.str2bool(force)
+ quiet = self.str2bool(quiet)
+
+ safedirs = '/etc/safe.d'
+ src_dirs = '/etc/safe-security'
+ filelist = ['ssh_host_rsa_key', 'ssh_host_ecdsa_key', 'ssh_host_ed25519_key']
+
+ valid_list = []
+ for i in filelist:
+ filename = os.path.join(src_dirs, i)
+ # Derive the key type (rsa, ecdsa, ed25519) from the file name so that
+ # ssh_host_ecdsa_key is not generated as an rsa key.
+ ktype = i.split('_')[2]
+ if force and os.path.exists(filename):
+ os.remove(filename)
+ os.system("ssh-keygen -q -t %s -f %s -C '' -N ''" % (ktype, filename))
+
+ if not os.path.exists(filename):
+ os.system("ssh-keygen -q -t %s -f %s -C '' -N ''" % (ktype, filename))
+
+ os.system("chmod 0640 %s" % filename)
+ os.system("chmod 0644 %s.pub" % filename)
+ os.system("chown root:ssh_keys %s" % filename)
+
+
+ key = (filename, "/etc/ssh/%s" % i)
+ pub = ("%s.pub" % filename, "/etc/ssh/%s.pub" % i)
+ valid_list.extend([key, pub])
+
+
+ for s, d in valid_list:
+ self.makeEncrypt(s, d, safedirs, quiet)
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/security/users/__init__.py b/lib/sunhpc/commands/create/security/users/__init__.py
new file mode 100644
index 0000000..e5cddfa
--- /dev/null
+++ b/lib/sunhpc/commands/create/security/users/__init__.py
@@ -0,0 +1,42 @@
+#coding:utf-8
+
+import os
+import sys
+import stat
+import time
+import base64
+import sunhpc
+from sunhpc.core.utils import SafeError
+class Command(sunhpc.commands.create.security.command):
+ """
+ Update all user-related files (e.g., /etc/passwd, /etc/shadow, etc.)
+ on all known hosts. Also, restart autofs on all known hosts.
+
+ <arg type='string' name='safedir'>
+ Provide a path to encrypt, default: /etc/safe.d
+ </arg>
+
+ <params type='string' name='safedir'>
+ Provide a encrypt file output path, default: /etc/safe.d
+ </params>
+
+ <example cmd='create security users'>
+ Encrypt sunhpc os all base data.
+ </example>
+ """
+ def run(self, parms, args):
+ (self.safedir, salt, quiet) = self.fillParams([
+ ('safedir', '/etc/safe.d'),
+ ('salt', None),
+ ('quiet', 'yes')])
+
+ quiet = self.str2bool(quiet)
+ if not os.path.exists(self.safedir):
+ os.makedirs(self.safedir)
+
+ userdirs = ['/etc/passwd', '/etc/shadow', '/etc/group']
+ services = ['/etc/auto.master', '/etc/auto.home', '/etc/auto.share']
+ for i in userdirs + services:
+ self.makeEncrypt(i, quiet=quiet)
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/create/xml/__init__.py b/lib/sunhpc/commands/create/xml/__init__.py
new file mode 100644
index 0000000..d8a4c05
--- /dev/null
+++ b/lib/sunhpc/commands/create/xml/__init__.py
@@ -0,0 +1,116 @@
+#coding:utf-8
+import os
+import sys
+import time
+import sunhpc
+import datetime
+class Command(sunhpc.commands.create.command):
+ """
+ Create a new roll xml configuration file.
+ <param type="str" name="name">
+ Must specify the ISO/Roll name.
+ </param>
+
+ <param type="bool" name="boot">
+ Def:0 Specifies whether the ISO supports booting.
+ </param>
+
+ <param type="str" name="flag">
+ Specifies boot extension flag.
+ </param>
+
+ <param type="bool" name="bin">
+ Def:1 Specifies the file bin format in the ISO package.
+ </param>
+
+ <param type="bool" name="src">
+ Def:0 Specifies whether the ISO contains source files.
+ </param>
+
+ <param type="bool" name="rolls">
+ Def:1 Specifies the ISO for sunhpc rolls.
+ </param>
+
+ <param type="str" name="version">
+ Specifies the ISO/Roll version.
+ </param>
+
+ <param type="str" name="release">
+ Specifies the ISO/Roll release.
+ </param>
+
+ <param type="str" name="arch">
+ Def:x86_64 Specifies the ISO/Roll arch.
+ </param>
+
+ <param type="str" name="os">
+ Def:linux Specifies the ISO/Roll os.
+ </param>
+
+ <example cmd='create xml name=CentOS version=7.9.2009 boot=1'>
+ Create Roll xml configuration.
+ </example>
+ """
+ def run(self, param, args):
+
+ (name, boot, flag, b, src, rolls, ver, release,
+ arch, self.out, plat) = self.fillParams([
+ ('name', ),
+ ('boot', '0'),
+ ('flag', ''),
+ ('bin', '1'),
+ ('src', '0'),
+ ('rolls', '1'),
+ ('version', sunhpc.version),
+ ('release', sunhpc.version_micro),
+ ('arch', self.arch),
+ ('path', os.getcwd()),
+ ('os', self.os)])
+
+ boot = self.str2bool(boot)
+ isbin = self.str2bool(b)
+ issrc = self.str2bool(src)
+ rolls = self.str2bool(rolls)
+
+ if not name:
+ self.abort('must supply roll name')
+
+ if not os.path.exists(self.out):
+ os.makedirs(self.out)
+
+ if not os.path.isdir(self.out):
+ self.abort('The %s path must be a directory' % self.out)
+
+ self.makeRollXML(name, arch, plat, flag, isbin, issrc, rolls, boot, ver, release)
+
+ def makeRollXML(self, name, arch, plat, flag, isbin, issrc, rolls, boot, ver, release):
+
+ # ver1 = '-b isolinux/isolinux.bin -c isolinux/boot.cat -no-emul-boot -boot-load-size 4 -boot-info-table'
+ # ver2 = '-b isolinux/isolinux.bin -c isolinux/boot.cat -no-emul-boot -boot-load-size 4 -boot-info-table '
+ # '-eltorito-alt-boot -e images/efiboot.img -no-emul-boot -J -T '
+ if boot:
+ flags = '-b isolinux/isolinux.bin -c isolinux/boot.cat -no-emul-boot '
+ flags += '-boot-load-size 4 -boot-info-table -eltorito-alt-boot '
+ flags += '-e images/efiboot.img %s -no-emul-boot -J -T ' % flag
+ else:
+ flags = ''
+
+ filename = os.path.join(self.out, "roll-%s.xml" % name)
+ file = open(filename, 'w')
+ file.write('<roll name="%s" interface="%s">\n' % (name, sunhpc.version))
+
+ rolltime = time.strftime('%X')
+ rolldate = time.strftime('%b %d %Y')
+ rollzone = time.strftime('%Z')
+ file.write('\t<timestamp time="%s" date="%s" tz="%s"/>\n' %
+ (rolltime, rolldate, rollzone))
+
+ file.write('\t<info version="%s" release="%s" arch="%s" os="%s"/>\n' %
+ (ver, release, arch, plat))
+
+ # size is the maximum size of the ISO being built; if it is exceeded, the build is split automatically.
+ file.write('\t<iso maxsize="0" bootable="%s" mkisofs="%s"/>\n' % (boot, flags))
+ file.write('\t<rpm rolls="%s" bin="%s" src="%s"/>\n' % (rolls, isbin, issrc))
+ file.write('\t<deps name="%s" host="sunhpc-cluster.local"/>\n' % name)
+ file.write('</roll>\n')
+ file.close()
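+ # For reference, the generated roll-<name>.xml looks roughly like this
+ # (values are illustrative):
+ #   <roll name="CentOS" interface="1.0.0">
+ #     <timestamp time="12:00:00" date="Jan 01 2021" tz="CST"/>
+ #     <info version="7.9.2009" release="1" arch="x86_64" os="linux"/>
+ #     <iso maxsize="0" bootable="True" mkisofs="-b isolinux/isolinux.bin ..."/>
+ #     <rpm rolls="True" bin="True" src="False"/>
+ #     <deps name="CentOS" host="sunhpc-cluster.local"/>
+ #   </roll>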
diff --git a/lib/sunhpc/commands/database/__init__.py b/lib/sunhpc/commands/database/__init__.py
new file mode 100644
index 0000000..8add27f
--- /dev/null
+++ b/lib/sunhpc/commands/database/__init__.py
@@ -0,0 +1,6 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+RollName = "base"
diff --git a/lib/sunhpc/commands/database/init/__init__.py b/lib/sunhpc/commands/database/init/__init__.py
new file mode 100644
index 0000000..dba50f3
--- /dev/null
+++ b/lib/sunhpc/commands/database/init/__init__.py
@@ -0,0 +1,487 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import random
+import sqlalchemy
+import sunhpc.invoke
+class Command(sunhpc.commands.database.command):
+ """
+ Create sunhpc database tables.
+ <arg type='path' name='cfg'>
+ supply a config file. The default is /opt/sunhpc/etc/sunhpc.conf.
+ </arg>
+ """
+
+ def run(self, params, args):
+
+ cfg = os.path.join(self.prefix, 'etc', 'sunhpc.conf')
+ if args:
+ cfg = args[0]
+
+ if not os.path.exists(cfg):
+ self.msg('The * configuration file does not exist. You must provide one.', 'a', cfg)
+
+ self.readConfig(cfg)
+ self.insertdb()
+
+ def insertdb(self):
+ try:
+
+ [ self.db.execute(t) for t in self.baseTables()]
+
+ except sqlalchemy.exc.OperationalError as e:
+ # e.args, e.code, e.orig, e.statement
+ if e.code == 'e3q8':
+ self.msg(e.args, 'e')
+ self.msg('Database * already exists. Please delete it first!!!',
+ 'a', self.db.database.datafile)
+ else:
+ print (e)
+ raise
+
+ self.msg('The %s database has been created.' % self.db.database.datafile)
+
+ # add data to DB
+ [self.db.execute(d) for d in self.tableData()]
+ self.msg('Initial data has been added to the %s database.' % self.db.database.datafile)
+
+ def readConfig(self, conf):
+
+ confs = {}
+ with open(conf, 'r') as f:
+ for i in f.readlines():
+ r = i.split()
+ if i.startswith('#') or len(r) < 2:
+ continue
+ key, val = r[0], ' '.join(r[1:])
+ confs[key.lower()] = val
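+ # The config file is parsed as plain "key value" lines; for example
+ # (illustrative values only):
+ #   PublicInterface  eth1
+ #   PublicAddress    192.168.100.10
+ #   PrivateInterface eth0
+ #   PrivateAddress   10.1.1.1
+ # Keys are lower-cased and everything after the first token becomes the value.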
+
+ self.country = confs['country'] if 'country' in confs else 'CN'
+ self.state = confs['state'] if 'state' in confs else 'LiaoNing'
+ self.city = confs['locality'] if 'locality' in confs else 'DaLian'
+ self.URL = confs['url'] if 'url' in confs else 'https://www.sunhpc.com'
+ self.cname = confs['name'] if 'name' in confs else 'Sunhpc-cluster'
+ self.contact = confs['contact'] if 'contact' in confs else 'info@sunhpc.com'
+ self.worksdir = confs['worksdir'] if 'worksdir' in confs else '/export'
+ self.distrodir = confs['distrodir'] if 'distrodir' in confs else '/export/sunhpc'
+ self.partition = confs['partition'] if 'partition' in confs else 'default'
+
+ # safe security
+ self.safeport = confs['safeport'] if 'safeport' in confs else '372'
+ self.safedirs = confs['safedirs'] if 'safedirs' in confs else 'safe.d'
+ self.safesecurity = confs['safecurity'] if 'safecurity' in confs else 'safe-security'
+
+ try:
+
+ self.wan_hostname = confs['publichostname']
+ self.wan_device = confs['publicinterface']
+ self.wan_address = confs['publicaddress']
+ self.wan_gateway = confs['publicgateway']
+ self.wan_dns = confs['publicdnsserver']
+ self.wan_netmask = confs['publicnetmask']
+
+ wan_mac = os.popen("nmcli device show %s|grep -i general.hwaddr|awk '{print $2}'" %
+ self.wan_device).readline().strip()
+ self.wan_mac = confs['publicmacaddr'] if 'publicmacaddr' in confs else wan_mac
+
+ wan_cidr = self.net.netmask_to_cidr(self.wan_netmask)
+ self.wan_cidr = confs['publiccidr'] if 'publiccidr' in confs else wan_cidr
+
+ wan_network = self.net.getNetwork(self.wan_address, self.wan_netmask)
+ self.wan_network = confs['publicnetwork'] if 'publicnetwork' in confs else wan_network
+
+ self.wan_ntphost = confs['publicntphost'] if 'publicntphost' in confs else 'pool.ntp.org'
+ self.wan_domain = confs['publicdnsdomain'] if 'publicdnsdomain' in confs else 'hpc.org'
+ self.wan_mtu = confs['publicmtu'] if 'publicmtu' in confs else '1500'
+
+ # private network configuration
+ self.lan_hostname = confs['privatehostname'] if 'privatehostname' in confs else 'cluster'
+ self.lan_device = confs['privateinterface']
+ self.lan_address = confs['privateaddress']
+ self.lan_netmask = confs['privatenetmask']
+
+ lan_mac = os.popen("nmcli device show %s|grep -i general.hwaddr|awk '{print $2}'" %
+ self.lan_device).readline().strip()
+ self.lan_mac = confs['privatemacaddr'] if 'privatemacaddr' in confs else lan_mac
+
+ lan_cidr = self.net.netmask_to_cidr(self.lan_netmask)
+ self.lan_cidr = confs['privatecidr'] if 'privatecidr' in confs else lan_cidr
+
+ lan_network = self.net.getNetwork(self.lan_address, self.lan_netmask)
+ self.lan_network = confs['privatenetwork'] if 'privatenetwork' in confs else lan_network
+
+
+ self.lan_gateway = confs['privategateway'] if 'privategateway' in confs else self.lan_address
+ self.lan_ntphost = confs['privatentphost'] if 'privatentphost' in confs else self.lan_address
+ self.lan_dns = confs['privatednsserver'] if 'privatednsserver' in confs else self.lan_address
+ self.lan_domain = confs['privatednsdomain'] if 'privatednsdomain' in confs else 'local'
+ self.lan_mtu = confs['privatemtu'] if 'privatemtu' in confs else '1500'
+
+ except KeyError as e:
+ self.msg("The %s config file error. Field -> *" % conf, 'a', e.args)
+
+ self.daemonPort = confs['plugin_port'] if 'plugin_port' in confs else '12345'
+ self.time_zone = confs['timezone'] if 'timezone' in confs else 'Asia/Shanghai'
+ self.def_args = confs['bootargs'] if 'bootargs' in confs else 'net.ifnames=0 biosdevname=0'
+ self.sunhpcdist = confs['distribution'] if 'distribution' in confs else 'sunhpc-dist'
+ self.ksbasedir = confs['basedir'] if 'basedir' in confs else 'install'
+ self.pxefilename = confs['dhcp_filename'] if 'dhcp_filename' in confs else 'pxelinux.0'
+ self.nextserver = confs['dhcp_nextserver'] if 'dhcp_nextserver' in confs else self.lan_address
+ self.gangliaaddr = confs['ganglia'] if 'ganglia' in confs else '224.0.0.3'
+ self.pxelinuxdir = confs['pxelinuxdir'] if 'pxelinuxdir' in confs else '/tftpboot/pxelinux'
+
+
+ def attributions(self):
+ attrs = [
+ ('Info_CertificateCountry', self.country, '', 1),
+ ('Info_CertificateState', self.state, '', 1),
+ ('Info_CertificateLocality', self.city, '', 1),
+ ('Info_CertificateOrganization','DLHP', '', 1),
+ ('Info_ClusterUrl', self.URL, '', 1),
+ ('Info_ClusterName', self.cname, '', 1),
+ ('Info_ClusterContact', self.contact, '', 1),
+ ('Kickstart_WorksDir', self.worksdir, '', 1),
+ ('Kickstart_DistroDir', self.distrodir, '', 1),
+ ('Kickstart_Partition', self.partition, '', 1),
+ ('Kickstart_PublicHostname', self.wan_hostname, '', 1),
+ ('Kickstart_PublicInterface', self.wan_device, '', 1),
+ ('Kickstart_PublicAddress', self.wan_address, '', 1),
+ ('Kickstart_PublicMacAddr', self.wan_mac, '', 1),
+ ('Kickstart_PublicNetmask', self.wan_netmask, '', 1),
+ ('Kickstart_PublicNetmaskCIDR', self.wan_cidr, '', 1),
+ ('Kickstart_PublicGateway', self.wan_gateway, '', 1),
+ ('Kickstart_PublicNetwork', self.wan_network, '', 1),
+ ('Kickstart_PublicNTPHost', self.wan_ntphost, '', 1),
+ ('Kickstart_PublicDNSServer', self.wan_dns, '', 1),
+ ('Kickstart_PublicDNSDomain', self.wan_domain, '', 1),
+ ('Kickstart_PublicMTU', self.wan_mtu, '', 1),
+
+ ('Kickstart_PrivateHostname', self.lan_hostname, '', 1),
+ ('Kickstart_PrivateInterface', self.lan_device, '', 1),
+ ('Kickstart_PrivateAddress', self.lan_address, '', 1),
+ ('Kickstart_PrivateMacAddr', self.lan_mac, '', 1),
+ ('Kickstart_PrivateNetmask', self.lan_netmask, '', 1),
+ ('Kickstart_PrivateNetmaskCIDR',self.lan_cidr, '', 1),
+ ('Kickstart_PrivateGateway', self.lan_address, '', 1),
+ ('Kickstart_PrivateNetwork', self.lan_network, '', 1),
+ ('Kickstart_PrivateNTPHost', self.lan_ntphost, '', 1),
+ ('Kickstart_PrivateDNSServer', self.lan_dns, '', 1),
+ ('Kickstart_PrivateDNSDomain', self.lan_domain, '', 1),
+ ('Kickstart_PrivateMTU', self.lan_mtu, '', 1),
+
+ ('Kickstart_Plugin_Port', self.daemonPort, '', 1),
+ ('Kickstart_Plugin_Keys', self.daemon_pass(), '', 1),
+
+ ('Kickstart_Timezone', self.time_zone, '', 1),
+ ('Kickstart_Bootargs', self.def_args, '', 1),
+ ('distribution', self.sunhpcdist, '', 1),
+ ('Kickstart_BaseDir', self.ksbasedir, '', 1),
+
+ ('safeport', self.safeport, '', 1),
+ ('safedirs', self.safedirs, '', 1),
+ ('safesecurity', self.safesecurity, '', 1),
+
+ ('sunhpc_version', sunhpc.version, '', 1),
+ ('sunhpc_major', self.major, '', 1),
+ ('sunhpc_minor', self.minor, '', 1),
+ ('sunhpc_micro', self.micro, '', 1),
+ ('sunhpc_release', sunhpc.release, '', 1),
+ ('ganglia_address', self.gangliaaddr, '', 1),
+
+ ('dhcp_filename', self.pxefilename, '', 1),
+ ('dhcp_nextserver', self.nextserver, '', 1),
+ ('pxelinuxdir', self.pxelinuxdir, '', 1),
+
+ ('kickstartable', 'yes', '', 1),
+ ('managed', 'true', '', 1),
+ ('os', 'linux', '', 1)
+ ]
+ return attrs
+
+ def tableData(self):
+ queuelist = []
+ bootver = '%s-%s' % (sunhpc.version, self.arch)
+ vmlinuz = 'vmlinuz-%s' % bootver
+ initrds = 'initrd-%s' % bootver
+ defArgs = "ramdisk_size=150000 net.ifnames=0 biosdevname=0"
+ insArgs = "%s inst.ks.sendmac" % (defArgs)
+ resArgs = "%s rescue" % defArgs
+ lesArgs = "%s vnc" % defArgs
+ bootaction = [
+ "insert into bootactions values (1, 'install', '%s', '%s', '%s')" % (vmlinuz, initrds, insArgs),
+ "insert into bootactions values (2, 'os', 'localboot 0', NULL, NULL)",
+ "insert into bootactions values (3, 'memtest', 'kernel memtest', NULL, NULL)",
+ "insert into bootactions values (4, 'install headless', '%s', '%s', '%s')" % (vmlinuz, initrds, lesArgs),
+ "insert into bootactions values (5, 'rescue', '%s', '%s', '%s')" % (vmlinuz, initrds, resArgs),
+ "insert into bootactions values (6, 'pxeflash', 'kernel memdisk bigraw', 'pxeflash.img', 'keeppxe')"]
+
+ queuelist.extend(bootaction)
+
+ distributions = [
+ "insert into distributions values (1, 'sunhpc-dist', '%s', '%s')" % (self.os, sunhpc.release)
+ ]
+ queuelist.extend(distributions)
+
+ cmds = "dmidecode -t memory|sed s/[[:space:]]//g|grep '^Size'|grep -E 'MB|GB'"
+ Names = os.popen("cat /sys/class/dmi/id/product_name").readline().strip()
+ Vender = os.popen("cat /sys/class/dmi/id/sys_vendor").readline().strip()
+ Serial = os.popen("cat /sys/class/dmi/id/product_serial").readline().strip()
+ CPUs = os.popen("dmidecode -t processor|grep Version|cut -d ':' -f2|wc -l").readline().strip()
+ Core = os.popen("cat /proc/cpuinfo |grep processor|wc -l").readline().strip()
+ Model = os.popen("dmidecode -t processor|grep Version|head -n1|cut -d ':' -f2").readline().strip()
+ MemNumb = os.popen("%s|wc -l" % cmds).readline().strip()
+ MemSize = os.popen("%s|head -n1|cut -d ':' -f2" % cmds).readline().strip()
+
+ Machines = [
+ "insert into machines values (1, 1, '%s', '%s', '%s', '%s', '%s', '%s', '%s', '%s')" %
+ (Names, Vender, Serial, CPUs, Core, Model, MemNumb, MemSize)
+ ]
+ queuelist.extend(Machines)
+
+ part = sunhpc.core.partition.Partition()
+ part.discoveredDisks()
+ disks = part.getDisks()
+ nodeDisks = part.getNodePartInfo(disks)
+ disklist = sorted(nodeDisks)
+ for disk in disklist:
+ for d in nodeDisks[disk]:
+ dev, sec, size, partid, fstype, bootflags, partflags, mountpoint = d
+ if dev and sec and size:
+ queuelist.append('''
+ insert into partitions(node, device, mountpoint, sectorstart,
+ partitionsize, fstype, partitionflags, formatflags) values
+ (1, "%s", "%s", "%s" ,"%s", "%s", "%s", "%s")''' % (
+ dev, mountpoint, sec, size, fstype, bootflags, partflags))
+
+ for attr in self.attributions():
+ k, v, s, n = attr
+ queuelist.append('insert into attributes values (NULL, "%s","%s", "%s", %d)' % (k, v, s, n))
+
+ # add frontend node
+ nodes = [ "insert into nodes values (1, '%s', %d, 0, 0, '%s', '%s', '','', 'os')" \
+ % (self.lan_hostname, int(Core), self.arch, self.os)]
+ queuelist.extend(nodes)
+
+ # add frontend network
+ network = [ "insert into networks values (1, 1, '%s', '%s', '%s', '%s', '2')" %
+ (self.wan_mac, self.wan_address, self.lan_hostname, self.wan_device),
+ "insert into networks values (2, 1, '%s', '%s', '%s', '%s', '1')" %
+ (self.lan_mac, self.lan_address, self.lan_hostname, self.lan_device)]
+ queuelist.extend(network)
+
+ # add subnet data
+ subnet = [ "insert into subnets values (1, 'private', '%s', '%s', '%s', '%s', '1')" %
+ (self.lan_domain, self.lan_network, self.lan_netmask, self.lan_mtu),
+ "insert into subnets values (2, 'public', '%s', '%s', '%s', '%s', '0')" %
+ (self.wan_domain, self.wan_network, self.wan_netmask, self.wan_mtu)]
+ queuelist.extend(subnet)
+
+ # add globalRouters
+ globalrouter = [ "insert into globalroutes values ('0.0.0.0', '0.0.0.0', '%s', NULL)" %
+ self.lan_address,
+ "insert into globalroutes values ('%s', '255.255.255.255', '%s', NULL)" %
+ (self.wan_address, self.lan_address)]
+ queuelist.extend(globalrouter)
+
+ return queuelist
+
+ def baseTables(self):
+ """Base data tables"""
+
+ drop = 'DROP TABLE IF EXISTS tablename'
+ tables = ['Nodes', 'Networks', 'Subnets', 'GlobalRouters',
+ 'PublicKeys', 'SecNodes', 'Attributes', 'Partitions',
+ 'Firewalls', 'Rolls', 'Bootactions', 'distributions',
+ 'SecGlobals'
+ ]
+ datalist = []
+
+ Nodes = '''
+ CREATE TABLE Nodes (
+ ID integer primary key autoincrement,
+ Name varchar(128) default NULL,
+ CPUs integer(11) NOT NULL default '1',
+ Rack integer(11) default NULL,
+ Rank integer(11) default NULL,
+ Arch varchar(32) default NULL,
+ OS varchar(64) NOT NULL default 'linux',
+ Alias varchar(64) default '',
+ Flags varchar(256) default '',
+ Status varchar(32) default 'os'
+ )'''
+ datalist.append(Nodes)
+
+ Networks = '''
+ CREATE TABLE Networks (
+ ID integer NOT NULL primary key autoincrement,
+ Node integer(11) default NULL,
+ MAC varchar(64) default NULL,
+ IP varchar(32) default NULL,
+ Name varchar(128) default NULL,
+ Device varchar(32) default NULL,
+ Subnet integer(11) default NULL,
+ Foreign key(subnet) references subnets(id) on delete cascade on update restrict,
+ Foreign key(node) references nodes(id) on delete cascade on update restrict
+ )'''
+ datalist.append(Networks)
+
+ Subnets = '''
+ CREATE TABLE Subnets (
+ ID integer NOT NULL primary key autoincrement,
+ name varchar(32) UNIQUE NOT NULL,
+ dnszone varchar(64) UNIQUE NOT NULL ,
+ subnet varchar(32) NOT NULL,
+ netmask varchar(32) NOT NULL,
+ mtu integer(11) default '1500',
+ servedns boolean default false
+ )'''
+ datalist.append(Subnets)
+
+ GlobalRouters = '''
+ CREATE TABLE GlobalRoutes (
+ Network varchar(32) NOT NULL default '',
+ Netmask varchar(32) NOT NULL default '',
+ Gateway varchar(32) NOT NULL default '',
+ Subnet integer(11) default NULL,
+ Primary key(Network, Netmask),
+ Foreign key(subnet) references subnets(id) on delete cascade on update restrict
+ )'''
+ datalist.append(GlobalRouters)
+
+ PublicKeys = '''
+ CREATE TABLE PublicKeys (
+ ID integer NOT NULL primary key autoincrement,
+ Public_Key varchar(4096) default NULL,
+ Description varchar(4096) default NULL,
+ Node integer(11) NOT NULL default '0',
+ Foreign key(Node) references nodes(id) on delete cascade on update restrict
+ )'''
+ datalist.append(PublicKeys)
+
+ SecNodes = '''
+ CREATE TABLE SecNodes (
+ Attr varchar(128) default NULL,
+ Enc varchar(64) default NULL,
+ Value text,
+ Node integer(11) NOT NULL default '0',
+ PRIMARY KEY (Node, Attr)
+ )'''
+ datalist.append(SecNodes)
+
+ Attributes = '''
+ CREATE TABLE Attributes (
+ ID integer NOT NULL primary key autoincrement,
+ Attr varchar(128) NOT NULL,
+ Value text,
+ Shadow text,
+ Node integer(11) NOT NULL,
+ Foreign key(Node) references nodes(id) on delete cascade on update restrict
+ )'''
+ datalist.append(Attributes)
+
+ Partitions = '''
+ CREATE TABLE Partitions (
+ ID integer NOT NULL primary key autoincrement,
+ Node integer(11) NOT NULL default '0',
+ Device varchar(128) NOT NULL default '',
+ Mountpoint varchar(128) NOT NULL default '',
+ SectorStart varchar(128) NOT NULL default '',
+ PartitionSize varchar(128) NOT NULL default '',
+ PartitionID varchar(128) NOT NULL default '',
+ FsType varchar(128) NOT NULL default '',
+ PartitionFlags varchar(128) NOT NULL default '',
+ FormatFlags varchar(128) NOT NULL default ''
+ )'''
+ datalist.append(Partitions)
+
+ Firewalls = '''
+ CREATE TABLE Firewalls (
+ ID integer NOT NULL primary key autoincrement,
+ Rulename varchar(128) NOT NULL,
+ Service varchar(256),
+ Protocol varchar(256),
+ Ports varchar(256),
+ Action varchar(256),
+ Comment varchar(256),
+ Node integer(11) NOT NULL default '0'
+ )'''
+ datalist.append(Firewalls)
+
+ Rolls = '''
+ CREATE TABLE Rolls (
+ ID integer NOT NULL primary key autoincrement,
+ Name varchar(128) NOT NULL default '',
+ Version varchar(32) NOT NULL default '',
+ Arch varchar(32) NOT NULL default '',
+ OS varchar(64) NOT NULL default 'linux',
+ Enabled varchar(11) NOT NULL default 'yes'
+ )'''
+ datalist.append(Rolls)
+
+ Bootactions = '''
+ CREATE TABLE Bootactions (
+ ID integer NOT NULL primary key autoincrement,
+ Action varchar(1024) default NULL,
+ Kernel varchar(1024) default NULL,
+ Ramdisk varchar(1024) default NULL,
+ Args varchar(1024) default NULL
+ )'''
+ datalist.append(Bootactions)
+
+ Distributions = '''
+ CREATE TABLE Distributions (
+ ID integer NOT NULL primary key autoincrement,
+ Name varchar(32) NOT NULL default '',
+ OS varchar(32) default '',
+ Release varchar(32) default ''
+ )'''
+ datalist.append(Distributions)
+
+ SecGlobals = '''
+ CREATE TABLE SecGlobals (
+ Attr varchar(128) default NULL,
+ Value text,
+ Enc varchar(128) default NULL,
+ PRIMARY KEY (Attr)
+ )'''
+ datalist.append(SecGlobals)
+
+ Machines = '''
+ CREATE TABLE Machines (
+ ID integer primary key autoincrement,
+ Node integer(128) default '0',
+ Name varchar(128) default NULL,
+ Vender varchar(128) default NULL,
+ Serial varchar(128) default NULL,
+ CPUs integer(128) default '1',
+ Cores integer(128) default '1',
+ Model varchar(128) default '',
+ MemNumber integer(128) default '1',
+ MemSize varchar(64) default NULL,
+ foreign key(Node) references nodes(id) on delete cascade on update restrict
+ )'''
+ datalist.append(Machines)
+
+ return datalist
+
+ def daemon_pass(self):
+ char = 'abcdefghijklmnopqrstuvwxyz'
+ char += '!@#$%+=-_^&*.?'
+ char += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890'
+ return ''.join(random.sample(char,16))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/list/__init__.py b/lib/sunhpc/commands/list/__init__.py
new file mode 100644
index 0000000..b4ae391
--- /dev/null
+++ b/lib/sunhpc/commands/list/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc.commands
+
+class command(sunhpc.commands.Command):
+ MustBeRoot = 0
+RollName = "base"
diff --git a/lib/sunhpc/commands/list/help/__init__.py b/lib/sunhpc/commands/list/help/__init__.py
new file mode 100644
index 0000000..a8e230a
--- /dev/null
+++ b/lib/sunhpc/commands/list/help/__init__.py
@@ -0,0 +1,71 @@
+#coding:utf-8
+import os
+import sys
+import sunhpc.invoke
+import sunhpc.commands
+class Command(sunhpc.commands.list.command):
+ """
+ The Help Command print the usage of all the registered
+ Commands.
+
+ <param optional='1' type='string' name='subdir'>
+ Relative path of Python commands for listing help. This is used internally
+ only.
+ </param>
+
+ <example cmd='list help'>
+ List help for all commands
+ </example>
+
+ <example cmd='list help subdir=list/host'>
+ List help for all commands under list/host
+ </example>
+ """
+
+ def run(self, params, args):
+
+ (subdir, cols) = self.fillParams([('subdir', ), ('cols', 80) ], params)
+ if subdir:
+ filepath = os.path.join(sunhpc.commands.__path__[0], subdir)
+ modpath = 'sunhpc.commands.%s' % '.'.join(subdir.split(os.sep))
+ else:
+ filepath = sunhpc.commands.__path__[0]
+ modpath = 'sunhpc.commands'
+
+ tree = sunhpc.core.files.Tree(filepath)
+ dirs = tree.getDirs()
+ dirs.sort()
+
+ if 'COLUMNS' in os.environ:
+ cols = int(os.environ['COLUMNS'])
+
+ for dir in dirs:
+ if not dir: continue
+
+ module = '%s.%s' % (modpath, '.'.join(dir.split(os.sep)))
+ __import__(module)
+ module = eval(module)
+
+ try:
+ o = getattr(module, 'Command')(None)
+ except AttributeError:
+ continue
+
+ if o.MustBeRoot and not self.isRootUser():
+ continue
+
+ cmd = ' '.join(dir.split(os.sep))
+ l = len(cmd) + 1
+ s = ''
+ for arg in o.usage().split():
+ if l + len(arg) < cols or cols == 0:
+ s += '%s ' % arg
+ l += len(arg) + 1 # space
+ else:
+ #s += '\n\t %s ' % arg
+ s += '\n%s %s ' % (" "*len(cmd), arg)
+ l = len(arg) + 9 # tab + space
+
+ self.addText('%s %s\n' % (cmd, s))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/list/host/__init__.py b/lib/sunhpc/commands/list/host/__init__.py
new file mode 100644
index 0000000..c050e80
--- /dev/null
+++ b/lib/sunhpc/commands/list/host/__init__.py
@@ -0,0 +1,38 @@
+#coding:utf-8
+
+import sunhpc.commands
+from sunhpc.db.mappings.base import *
+class command(sunhpc.commands.HostArgumentProcessor,
+ sunhpc.commands.list.command):
+ pass
+
+class Command(command):
+ """
+ List the CPU count and physical position info for
+ a list of hosts.
+
+ <arg optional='1' type='string' name='host' repeat='1'>
+ Zero, one or more host names. If no host names are supplied, info about
+ all the known hosts is listed.
+ </arg>
+
+ <example cmd='list host compute-0-0'>
+ List info for compute-0-0.
+ </example>
+
+ <example cmd='list host'>
+ List info for all known hosts.
+ </example>
+ """
+
+ def run(self, params, args):
+ self.beginOutput()
+
+ for host in self.newdb.getNodesfromNames(args):
+ self.addOutput(host.name, [host.cpus, host.rack,
+ host.rank, host.os, host.arch, host.alias, host.status])
+
+ self.endOutput(header=['host', 'cpus', 'rack',
+ 'rank', 'os', 'arch', 'alias', 'status'])
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/list/host/interface/__init__.py b/lib/sunhpc/commands/list/host/interface/__init__.py
new file mode 100644
index 0000000..9a0763e
--- /dev/null
+++ b/lib/sunhpc/commands/list/host/interface/__init__.py
@@ -0,0 +1,41 @@
+#coding:utf-8
+
+import sunhpc.commands
+class Command(sunhpc.commands.list.host.command):
+ """
+ Lists the interface definitions for hosts. For each host supplied on
+ the command line, this command prints the hostname and interface
+ definitions for that host.
+
+ <arg optional='1' type='string' name='host' repeat='1'>
+ Zero, one or more host names. If no host names are supplied, info about
+ all the known hosts is listed.
+ </arg>
+
+ <example cmd='list host interface compute-0-0'>
+ List network interface info for compute-0-0.
+ </example>
+
+ <example cmd='list host interface'>
+ List network interface info for all known hosts.
+ </example>
+ """
+
+ def run(self, params, args):
+
+ self.beginOutput()
+ for host in self.getHostnames(args):
+ self.db.execute(""" SELECT s.name, n.Device, n.Mac,
+ n.ip, s.netmask, n.Name
+ FROM networks n LEFT JOIN subnets s ON n.subnet=s.id
+ INNER JOIN nodes h ON n.node = h.id AND h.name='%s'""" %host)
+ for row in self.db.fetchall():
+ #
+ # if device name matches vlan* then clear
+ # fields for printing
+ #
+ self.addOutput(host, row)
+
+ self.endOutput(header=['host', 'subnet', 'iface', 'mac', 'ip', 'netmask', 'name'])
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/list/license/__init__.py b/lib/sunhpc/commands/list/license/__init__.py
new file mode 100644
index 0000000..b7893e7
--- /dev/null
+++ b/lib/sunhpc/commands/list/license/__init__.py
@@ -0,0 +1,20 @@
+
+import os
+import sunhpc
+class Command(sunhpc.commands.list.command):
+ """
+ List the Sunhpc copyright.
+
+ <example cmd='list license'>
+ List the Sunhpc copyright.
+ </example>
+ """
+
+ def run(self, params, args):
+ license = os.path.join(eval(self.__module__).__path__[0],
+ 'license.txt')
+ with open(license, 'r') as f:
+ for line in f.readlines():
+ self.addText(line)
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/list/license/license.txt b/lib/sunhpc/commands/list/license/license.txt
new file mode 100644
index 0000000..a748443
--- /dev/null
+++ b/lib/sunhpc/commands/list/license/license.txt
@@ -0,0 +1,6 @@
+ SUNHPC(r)
+ www.sunhpc.com
+ version 1.0.0 (Kamaitachi)
+
+Copyright (c) 2000 - 2020 The software of the Dalian Hengpu Electronic
+Technology Co., Ltd. all rights reserved.
diff --git a/lib/sunhpc/commands/pxelinux/__init__.py b/lib/sunhpc/commands/pxelinux/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/pxelinux/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/pxelinux/build/__init__.py b/lib/sunhpc/commands/pxelinux/build/__init__.py
new file mode 100644
index 0000000..98befa4
--- /dev/null
+++ b/lib/sunhpc/commands/pxelinux/build/__init__.py
@@ -0,0 +1,228 @@
+#coding:utf-8
+
+import os
+import sys
+import shutil
+import sunhpc
+import subprocess
+class RollHandler(object):
+
+ def __init__(self, cmd, mnt, clean):
+
+ self.cmd = cmd
+ self.clean = clean
+ self.mntdir = mnt
+ self.rinfo = None
+
+ def is_mounted(self):
+ cmd = 'mount |grep %s' % self.mntdir
+ if subprocess.call(cmd, shell=True):
+ return 0
+ return 1
+
+ def mount_iso(self, iso):
+ subprocess.run('mount -r %s %s' % (iso, self.mntdir), shell=True, check=True)
+
+ def umount_iso(self):
+ subprocess.run('umount %s' % self.mntdir, shell=True)
+
+ def read_iso(self):
+ cmd = 'find %s -type f -name roll-\*.xml' % self.mntdir
+ ret = self.cmd.shcmd(cmd, code=True)
+ try:
+ roll = sunhpc.core.files.RollInfoFile(ret['o'].strip())
+ self.rinfo = roll
+ except FileNotFoundError:
+ pass
+
+ def foreign_roll(self):
+ self.cmd.msg('This ISO file cannot be recognized. It was not produced by sunhpc.', 'w')
+
+ def copy_iso(self):
+ self.read_iso()
+ if not self.rinfo:
+ self.foreign_roll()
+
+ self.copy_roll()
+
+ def copy_roll(self):
+ # self.rinfo has been parsed by files.RollInfoFile and holds the corresponding object.
+ roll_name = self.rinfo.getRollName()
+ roll_vers = self.rinfo.getRollVersion()
+ roll_arch = self.rinfo.getRollArch()
+ roll_os = self.rinfo.getRollOS()
+ roll_boot = self.rinfo.isBootable()
+
+ # Get the directory where rolls are stored.
+ cmd = '/opt/sunhpc/bin/sunhpc report distro'
+ for line in os.popen(cmd).readlines():
+ distro = line[:-1]
+
+ # /export/sunhpc/install/rolls
+ rolls_dir = '%s/rolls' % (distro)
+
+ # /export/sunhpc/install/rolls/kernel......
+ roll_dir = os.path.join(rolls_dir, roll_name)
+
+ if self.clean:
+ # /export/sunhpc/install/rolls/kernel/version/x86_64
+ specific_roll_dir = os.path.join(roll_dir, roll_vers, roll_arch)
+ if os.path.exists(specific_roll_dir):
+ txt = '[+] Cleaning %s version %s for %s from the Rolls directory.' % \
+ (roll_name, roll_vers, roll_arch)
+ self.cmd.msg(txt)
+ self.clean_dirs(specific_roll_dir)
+ os.makedirs(specific_roll_dir)
+
+ cwd = os.getcwd()
+ os.chdir(os.path.join(self.mntdir, roll_name))
+
+ # Compute the size.
+ cmd = 'du -sh %s' % self.mntdir
+ ret = self.cmd.shcmd(cmd, code=True)
+ size = 'unknown'
+ if ret['c'] == 0:
+ size = ret['o'].split()[0]
+
+ # Start copying files.
+ self.cmd.msg('Copying %s (%s) to %s, please wait....\r' % (roll_name, size, roll_dir))
+
+ # find . ! -name TRANS.TBL ! -path './kernel*' -print  excludes that directory and the matching files
+ # cpio: -d creates directories automatically, -u overwrites old files, -p pass-through (copy) mode, -m keeps the original mtime
+ # subprocess.run('find . ! -name TRANS.TBL -print|cpio -mpud %s' % roll_dir, shell=True)
+
+ rpmlist = self.cmd.shcmd('find . ! -name TRANS.TBL -print')['o'].split('\n')
+ self.cmd.copyfiles(rpmlist, roll_dir)
+
+ # Copy the other files needed for booting from the iso root directory.
+ if roll_boot:
+ os.chdir(self.mntdir)
+ boot_dir = os.path.join(rolls_dir, 'boot')
+ bootlist = self.cmd.shcmd('find . ! -name TRANS.TBL ! -path "./%s*" -print' % roll_name)['o'].split('\n')
+ self.cmd.copyfiles(bootlist, boot_dir)
+
+ os.chdir(os.path.join(self.mntdir, roll_name))
+
+ # Fix permissions on all directories so that everyone (e.g. Apache) can access them.
+ self.cmd.shcmd('find %s -type d -exec chmod a+rx {} \;' % roll_dir, code=True)
+
+ # Insert the roll info into the database; skip the insert if it already exists.
+ rows = self.cmd.db.search('select * from rolls where' \
+ ' name="%s" and version="%s" and arch="%s" and os="%s"' \
+ % (roll_name, roll_vers, roll_arch, roll_os))
+
+ if not rows:
+ db_cmd = 'insert into rolls (name, version, arch, enabled, os)'\
+ ' values("%s", "%s", "%s", "%s", "%s")' \
+ % (roll_name, roll_vers, roll_arch, 'yes', roll_os)
+
+ self.cmd.db.execute(db_cmd)
+
+ os.chdir(cwd)
+ # If this is a bootable roll, it needs to be redeployed with the Kylins "create distro" command.
+ # Recreate the images, stage2, and the related symlink tasks.
+ self.cmd.msg('Copy the %s finished.' % roll_name)
+
+ def clean_dirs(self, path):
+ for root, dirs, files in os.walk(path, topdown=False):
+ for name in files:
+ os.remove(os.path.join(root, name))
+ for name in dirs:
+ os.rmdir(os.path.join(root, name))
+ os.removedirs(path)
+
+class command(sunhpc.commands.pxelinux.command):
+ pass
+
+class Command(command):
+ """
+ build pxe install in the cluster
+ <arg type='File' name='iso'>
+ supply an sunhpc iso images file.
+ </arg>
+
+ <param type='File' name='iso'>
+ supply an sunhpc iso images file.
+ </param>
+
+ <param type='Path' name='mnt'>
+ Specify a temporary iso images mount directory.default:/mnt/cdrom
+ </param>
+
+ <param type='Bool' name='pxesrv'>
+ The pxelinux services (httpd, dhcpd, tftpd, autofs) are also installed. default: yes
+ </param>
+
+ <param type='bool' name='clean'>
+ If set, then remove all files from any existing rolls of the same name.
+ </param>
+
+ <example cmd='pxelinux build CentOS-7.9.2009-0.x86_64.disk.iso'>
+ Added the sunhpc iso to the local Roll directory.
+ </example>
+
+ <related>sunhpc create roll roll-CentOS.xml boot=/mnt/cdrom</related>
+
+ """
+
+ def run(self, params, args):
+
+ (clean, mnt, pxe, quiet) = self.fillParams([
+ ('clean', 'n'), ('mnt', '/mnt/cdrom'),
+ ('pxesrv', 'yes'),
+ ('quiet', 'no')])
+
+ if os.path.exists(self.db.getDBFile):
+ shutil.chown(self.db.getDBFile, 'root', 'apache')
+
+ cwd = os.getcwd()
+ q = self.str2bool(quiet)
+ pxesrv = self.str2bool(pxe)
+ clean = self.str2bool(clean)
+ if len(args) == 0:
+ self.abort('must supply linux iso images file.')
+
+ if not os.path.exists(mnt):
+ os.makedirs(mnt)
+
+ iso_list = []
+ for i in args:
+ i = os.path.join(os.getcwd(),i)
+ if os.path.exists(i) and i.endswith('.iso'):
+ iso_list.append(i)
+ else:
+ self.msg("Cannot find %s or %s is not and ISO image" % (i, i), 'a')
+
+ roll_handler = RollHandler(self, mnt, clean)
+ if roll_handler.is_mounted():
+ self.msg("The %s has been mounted,umount is firsrt." % mnt, 'a')
+
+ for i in iso_list:
+ roll_handler.mount_iso(i)
+ roll_handler.copy_iso()
+ roll_handler.umount_iso()
+
+ if pxesrv:
+ self.configSRV(q)
+ else:
+ self.msg('Rebuild the sunhpc distribution -> /export/sunhpc/install')
+ self.msg('Use the command: sunhpc create distro')
+
+ # modify the database file owner root.apache
+
+ os.chdir(cwd)
+
+ def configSRV(self, q):
+
+ distbase = self.command('report.distro',[]).strip()
+ os.chdir(distbase)
+ # /export/sunhpc/install
+ self.command('create.distro')
+
+ srvlist = ['httpd', 'tftpd', 'dhcpd', 'autofs']
+ for s in srvlist:
+ self.msg('Building the pxe services - %s' % s, q=q)
+ self.command('pxelinux.build.%s' % s, ['quiet=%s' % q])
+
+RollName = "base"
+
diff --git a/lib/sunhpc/commands/pxelinux/build/autofs/__init__.py b/lib/sunhpc/commands/pxelinux/build/autofs/__init__.py
new file mode 100644
index 0000000..0f35d99
--- /dev/null
+++ b/lib/sunhpc/commands/pxelinux/build/autofs/__init__.py
@@ -0,0 +1,124 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import shutil
+import textwrap
+import subprocess
+class Command(sunhpc.commands.pxelinux.build.command):
+ """
+ Build the autofs service for the sunhpc cluster.
+
+ <param type='Bool' name='Quiet'>
+ Whether to output detailed information, default: no
+ </param>
+
+ <example cmd='pxelinux build autofs'>
+ In local build the autofs service.
+ </example>
+ """
+ def run(self, params, args):
+
+ (quiet, ) = self.fillParams([('quiet', 'no')])
+
+ cwd = os.getcwd()
+ q = self.str2bool(quiet)
+ self.msg('Starting build the autofs service...', q=q)
+
+ isInstalled = True
+ autofs_files = ['/etc/auto.master', '/etc/auto.home', '/etc/auto.share']
+ for autofs in autofs_files:
+ if not os.path.exists(autofs):
+ isInstalled = False
+
+
+ # The httpd is not installed.
+ if not isInstalled:
+ self.shcmd('yum install -y autofs rpcbind nfs-utils --enablerepo=Sunhpc-%s' % sunhpc.version)
+
+ self.fw.addports(['111', '2049', '30001-30004'], ['tcp', 'udp'])
+
+ address = self.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ network = self.db.getHostAttr('localhost', 'Kickstart_PrivateNetwork')
+ netmask = self.db.getHostAttr('localhost', 'Kickstart_PrivateNetmask')
+
+ export_file = '/etc/exports'
+ with open(export_file, 'w') as f:
+ f.write(textwrap.dedent(f"""\
+ /export {address}(rw,async,no_root_squash) {network}/{netmask}(rw,async)
+ """))
+
+ sysconfig = '/etc/sysconfig/nfs'
+ if self.matchText(sysconfig, 'RQUOTAD_PORT=30001'):
+ with open(sysconfig, 'a') as f:
+ f.write(textwrap.dedent(f"""\
+ RQUOTAD_PORT=30001
+ LOCKD_TCPPORT=30002
+ LOCKD_UDPPORT=30002
+ MOUNTD_PORT=30003
+ STATD_PORT=30004
+ """))
+
+ lockd = '/etc/modprobe.d/lockd.conf'
+ if self.matchText(lockd, 'nlm_tcpport=30002'):
+
+ with open(lockd, 'a') as f:
+ f.write(textwrap.dedent(f"""\
+ options lockd nlm_tcpport=30002
+ options lockd nlm_udpport=30002
+ """))
+
+ autofs_config = "/etc/autofs.conf"
+ os.system("sed -i 's/mount_nfs_default_protocol = 4/mount_nfs_default_protocol = 3/g' %s" % autofs_config)
+
+ apps = '/export/apps'
+ if not os.path.exists(apps):
+ os.makedirs(apps)
+
+ home = '/export/home'
+ if not os.path.exists(home):
+ os.makedirs(home)
+
+ self.command('sync.users')
+
+ self.shcmd('systemctl daemon-reload')
+ self.shcmd('systemctl stop autofs')
+ self.shcmd('systemctl start autofs')
+ self.shcmd('systemctl enable autofs')
+ self.shcmd('systemctl stop nfs')
+ self.shcmd('systemctl start nfs')
+ self.shcmd('systemctl enable nfs')
+
+RollName = "base"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/pxelinux/build/dhcpd/__init__.py b/lib/sunhpc/commands/pxelinux/build/dhcpd/__init__.py
new file mode 100644
index 0000000..4a4d0c0
--- /dev/null
+++ b/lib/sunhpc/commands/pxelinux/build/dhcpd/__init__.py
@@ -0,0 +1,69 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import shutil
+import textwrap
+import subprocess
+class Command(sunhpc.commands.pxelinux.build.command):
+ """
+ Build the dhcp service for the sunhpc cluster.
+
+ <param type='Bool' name='Quiet'>
+ Whether to output detailed information, default: no
+ </param>
+
+ <example cmd='pxelinux build dhcpd'>
+ In local build the dhcp service.
+ </example>
+
+ <example cmd='pxelinux build dhcpd'>
+ Build the dhcp service.
+ </example>
+ """
+ def run(self, params, args):
+
+ (quiet, ) = self.fillParams([('quiet', 'no')])
+
+ q = self.str2bool(quiet)
+ self.msg('Starting build the dhcpd service...', q=q)
+
+ # install dhcp server
+ self.installDhcpd(q)
+
+ def installDhcpd(self, q):
+ config = '/etc/dhcp/dhcpd.conf'
+
+ # The dhcpd is not installed.
+ if not os.path.exists(config):
+ self.shcmd('yum install -y dhcp dhcp-common dhcp-libs --enablerepo=Sunhpc-%s' % sunhpc.version)
+
+ self.fw.addports(['67'], ['udp'])
+
+ os.system('sunhpc report host dhcpd > %s' % config)
+
+ # `devices` is referenced in the unit file below but was never defined; as an
+ # assumption, take it from the Kickstart_PrivateInterface attribute so dhcpd
+ # listens only on the private interface.
+ devices = self.db.getHostAttr('localhost', 'Kickstart_PrivateInterface')
+ dhcpd_srv_conf = '/usr/lib/systemd/system/dhcpd.service'
+ with open(dhcpd_srv_conf, 'w') as f:
+ f.write(textwrap.dedent("""\
+ [Unit]
+ Description=DHCPv4 Server Daemon
+ Documentation=man:dhcpd(8) man:dhcpd.conf(5)
+ Wants=network-online.target
+ After=network-online.target
+ After=time-sync.target
+
+ [Service]
+ Type=notify
+ ExecStart=/usr/sbin/dhcpd -f -cf /etc/dhcp/dhcpd.conf -user dhcpd -group dhcpd --no-pid %s
+
+ [Install]
+ WantedBy=multi-user.target
+ """ % devices))
+
+ self.shcmd('systemctl daemon-reload')
+ self.shcmd('systemctl stop dhcpd')
+ self.shcmd('systemctl start dhcpd')
+ self.shcmd('systemctl enable dhcpd')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/pxelinux/build/httpd/__init__.py b/lib/sunhpc/commands/pxelinux/build/httpd/__init__.py
new file mode 100644
index 0000000..8a076b5
--- /dev/null
+++ b/lib/sunhpc/commands/pxelinux/build/httpd/__init__.py
@@ -0,0 +1,195 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import shutil
+import textwrap
+import subprocess
+class Command(sunhpc.commands.pxelinux.build.command):
+ """
+ Build the http service for the sunhpc cluster.
+
+ <param type='Path' name='PxeDir'>
+ supply a pxelinux base directory. default: /httpboot/pxelinux
+ </param>
+
+ <param type='Bool' name='Quiet'>
+ Whether to output detailed information, default: no
+ </param>
+
+ <example cmd='pxelinux build httpd'>
+ In local build the http service.
+ </example>
+
+ <example cmd='pxelinux build httpd pxedir=/httpboot/pxelinux'>
+ In local /httpboot/pxelinux directory build the http service.
+ </example>
+ """
+ def run(self, params, args):
+
+ (quiet, ) = self.fillParams([('quiet', 'no')])
+ cwd = os.getcwd()
+ q = self.str2bool(quiet)
+ self.msg('Starting build the httpd service...', q=q)
+
+ # sunhpc-dist
+ distname = self.db.getHostAttr('localhost', 'distribution')
+ # /export/sunhpc/install
+ distbase = self.command('report.distro',[]).strip()
+ # /export/sunhpc/install/sunhpc-dist/x86_64/build
+ install = os.path.join(distbase, distname, self.arch, 'build')
+ # /export/sunhpc/install/sunhpc-dist/x86_64/build/sbin
+ src_sbin = os.path.join(distname, self.arch, 'build', 'sbin')
+ # /export/sunhpc/install/sbin
+ dst_sbin = os.path.join(distbase, 'sbin')
+
+ os.chdir(distbase)
+ if not os.path.exists(os.path.join(src_sbin, 'kickstart.cgi')):
+ self.msg('Install the cgi-module-kickstart rpm ...', q=q)
+ self.command('create.distro')
+
+ if not os.path.exists(dst_sbin):
+ os.symlink(src_sbin, dst_sbin)
+
+ # chmod kickstart perms 755
+ kickstart_cgi = os.path.join(src_sbin, 'kickstart.cgi')
+ os.chmod(kickstart_cgi, 0o0755)
+
+ # enable yum file repos.
+ self.command('create.repos', ['file=1'])
+
+ self.installHttpd(q)
+
+ # enable yum http repos.
+ self.command('create.repos', ['web=1'])
+ os.chdir(cwd)
+
+ def installHttpd(self, q):
+ html = '/var/www/html'
+ conf = '/etc/httpd/conf/httpd.conf'
+ hostname = self.db.getHostAttr('localhost', 'Kickstart_PrivateHostname')
+ basedirs = self.db.getHostAttr('localhost', 'Kickstart_BaseDir')
+ distdirs = self.command('report.distro')
+
+ # The httpd is not installed.
+ if not os.path.exists(conf):
+ self.shcmd('yum clean all; yum makecache')
+ self.shcmd('yum install -y httpd httpd-tools httpd-devel')
+
+ self.fw.addports(['80', '443'], ['tcp'])
+
+ sunhpc_httpd_conf = '/etc/httpd/conf.d/sunhpc.conf'
+ with open(sunhpc_httpd_conf, 'w') as f:
+ f.write(textwrap.dedent(f"""\
+ #
+ # Sunhpc specific apache configuration.
+ #
+ <IfModule mod_mime.c>
+ AddHandler cgi-script .cgi
+ </IfModule>
+
+ UseCanonicalName Off
+ ServerName {hostname}
+
+ <Directory "/var/www/html">
+ Options FollowSymLinks Indexes ExecCGI
+ AllowOverride None
+ Order allow,deny
+ Allow from all
+ </Directory>
+ """))
+ os.chmod(sunhpc_httpd_conf, 0o0644)
+
+ central_httpd_conf = '/etc/httpd/conf.d/central.conf'
+ with open(central_httpd_conf, 'w') as f:
+ f.write(textwrap.dedent(f"""\
+ #
+ # Export Root Access Config
+ #
+ <Directory /var/www/html/{basedirs}>
+ Options FollowSymLinks Indexes ExecCGI
+ AllowOverride None
+ Allow from all
+ </Directory>
+
+ # HTTPS access for serving kickstart files
+ <Directory /var/www/html/{basedirs}/sbin>
+ Options ExecCGI
+ AllowOverride None
+ Order allow,deny
+ Allow from all
+ </Directory>
+
+ # allow all access to the rolls RPMS
+ <Directory /var/www/html/{basedirs}/rolls>
+ Allow from all
+ </Directory>
+ """))
+ os.chmod(central_httpd_conf, 0o0644)
+
+ address = self.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ network = '.'.join(address.split('.')[:3])
+ safe_httpd_conf = '/etc/httpd/conf.d/safe.conf'
+ with open(safe_httpd_conf, 'w') as f:
+ f.write(textwrap.dedent(f"""\
+ #
+ # sunhpc pxelinux build httpd
+ #
+ Listen 372
+ <VirtualHost {address}:372>
+ Alias /safe.d/ "/etc/safe.d/"
+ Alias /safe.d "/etc/safe.d"
+
+ <Directory /etc/safe.d>
+
+ Options Indexes MultiViews
+ IndexOptions FancyIndexing NameWidth=*
+ RemoveHandler .var
+
+ AllowOverride None
+ Require ip {network}
+
+ </Directory>
+ </VirtualHost>
+ """))
+ os.chmod(safe_httpd_conf, 0o0644)
+
+ # symlink /export/sunhpc/install into /var/www/html
+ html = os.path.join(html, basedirs)
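+ # e.g. /var/www/html/install -> /export/sunhpc/install, assuming
+ # Kickstart_BaseDir is 'install' (values are illustrative).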
+ if not os.path.islink(html):
+ os.symlink(distdirs.strip(), html.strip())
+
+ self.shcmd('systemctl daemon-reload')
+ self.shcmd('systemctl stop httpd')
+ self.shcmd('systemctl start httpd')
+ self.shcmd('systemctl enable httpd')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/pxelinux/build/nodes/__init__.py b/lib/sunhpc/commands/pxelinux/build/nodes/__init__.py
new file mode 100644
index 0000000..8dcbe2b
--- /dev/null
+++ b/lib/sunhpc/commands/pxelinux/build/nodes/__init__.py
@@ -0,0 +1,71 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import shutil
+import textwrap
+import subprocess
+class Command(sunhpc.commands.pxelinux.build.command):
+ """
+ Build the node configuration for the sunhpc cluster.
+
+ <arg type='string' name='hosts'>
+ Supply a host and execute the node configuration.
+ </arg>
+
+ <param type='Bool' name='Quiet'>
+ Whether to output detailed information, default: no
+ </param>
+
+ <example cmd='pxelinux build nodes cluster'>
+ Modify the configuration of the "cluster" node.
+ </example>
+ """
+ def run(self, params, args):
+
+ (quiet, ) = self.fillParams([('quiet', 'no')])
+ q = self.str2bool(quiet)
+
+ if not len(args):
+ self.msg('must supply one host.', 'a')
+
+ host = args[0]
+ control = self.db.getHostAttr('localhost', 'Kickstart_PrivateHostname')
+ if host == control:
+ # build the cluster node
+ self.msg('Starting to build the %s node configuration...' % host)
+ self.loadModules('control')
+ else:
+ # build the compute node
+ self.msg('Starting to build the %s node configuration...' % host)
+ self.loadModules('compute')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/pxelinux/build/tftpd/__init__.py b/lib/sunhpc/commands/pxelinux/build/tftpd/__init__.py
new file mode 100644
index 0000000..82a7b4c
--- /dev/null
+++ b/lib/sunhpc/commands/pxelinux/build/tftpd/__init__.py
@@ -0,0 +1,173 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import shutil
+import textwrap
+import subprocess
+class Command(sunhpc.commands.pxelinux.build.command):
+ """
+ Build the tftp service for the sunhpc cluster.
+
+ <param type='Path' name='PxeDir'>
+ Supply a pxelinux base directory. Default: /tftpboot/pxelinux
+ </param>
+
+ <param type='Bool' name='Quiet'>
+ Whether to output detailed information, default: no
+ </param>
+
+ <example cmd='pxelinux build tftpd'>
+ Build the tftp service on the local host.
+ </example>
+
+ <example cmd='pxelinux build tftpd pxedir=/tftpboot/pxelinux'>
+ Build the tftp service using the local /tftpboot/pxelinux directory.
+ </example>
+ """
+ def run(self, params, args):
+
+ (self.basedir, quiet) = self.fillParams([
+ ('pxedir', '/tftpboot/pxelinux'),
+ ('quiet', 'no')])
+
+ q = self.str2bool(quiet)
+ self.msg('Starting to build the tftpd service...', q=q)
+
+ if not os.path.exists(self.basedir):
+ self.msg('Creating the tftpd directory ...', q=q)
+ os.makedirs(self.basedir)
+
+ self.pxecfg = os.path.join(self.basedir, 'pxelinux.cfg')
+ if not os.path.exists(self.pxecfg):
+ self.msg('Creating the pxelinux.cfg directory ...', q=q)
+ os.makedirs(self.pxecfg)
+
+ shared = os.path.join(self.prefix, 'share/isobuild/anaconda-shared')
+
+ os.chmod(self.pxecfg, 0o755)
+ self.create_user('apache', uid=48, gid=48, home='/var/apache', nohome=1, shell='/sbin/nologin')
+ os.system('chown -R root.apache %s' % self.pxecfg)
+
+ distbase = self.command('report.distro',[]).strip()
+ distname = self.db.getHostAttr('localhost', 'distribution')
+ distro = os.path.join(distbase, distname, self.arch)
+ images = os.path.join(distro, 'images')
+
+ pxefile = os.listdir(images) + os.listdir(shared)
+ for i in pxefile:
+ if i.startswith('initrd'):
+ self.msg('Copying to %s - %s' % (self.basedir, i), q=q)
+ self.copyit(os.path.join(images, i), os.path.join(self.basedir, i), 0o755, '0.0')
+
+ if i.startswith('vmlinuz'):
+ self.msg('Copying to %s - %s' % (self.basedir, i), q=q)
+ self.copyit(os.path.join(images, i), os.path.join(self.basedir, i), 0o755, '0.0')
+
+ if i.startswith('efiboot.img'):
+ self.msg('Copying to %s - %s' % (self.basedir, i), q=q)
+ self.copyit(os.path.join(images, i), os.path.join(self.basedir, i), 0o644, '0.0')
+
+ if i.startswith('pxelinux.0'):
+ self.msg('Copying to %s - %s' % (self.basedir, i), q=q)
+ self.copyit(os.path.join(shared, i), os.path.join(self.basedir, i), 0o644, '0.0')
+
+ if i.startswith('gpxelinux.0'):
+ self.msg('Copying to %s - %s' % (self.basedir, i), q=q)
+ self.copyit(os.path.join(shared, i), os.path.join(self.basedir, i), 0o644, '0.0')
+
+ if i.startswith('memtest'):
+ self.msg('Copying to %s - %s' % (self.basedir, i), q=q)
+ self.copyit(os.path.join(shared, i), os.path.join(self.basedir, i), 0o755, '0.0')
+
+ # write the pxelinux default file
+ self.writeDefault(q)
+
+ # install tftp server
+ self.installTftpd(q)
+
+ def writeDefault(self, q):
+ default = os.path.join(self.pxecfg, 'default')
+
+ self.msg('Configure pxelinux default boot file to %s ' % (default), q=q)
+ self.db.execute('select kernel, ramdisk, args from bootactions where id=1')
+ vmlinuz, initrd, argsks = self.db.fetchone()
+
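+ # The file written below has the usual pxelinux layout:
+ #   default sunhpc
+ #   prompt 0
+ #   label sunhpc
+ #     kernel <kernel>
+ #     append initrd=<ramdisk> <kickstart args>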
+ with open(default, 'w') as f:
+ f.write('#\n# Generated by sunhpc pxelinux build tftpd\n#\n')
+ f.write('default sunhpc\n')
+ f.write('prompt 0\n')
+ f.write('label sunhpc\n')
+ f.write(' kernel %s\n' % vmlinuz)
+ f.write(' append initrd=%s %s \n' % (initrd, argsks))
+ os.chmod(default, 0o0664)
+ os.system('chown root.apache %s' % default)
+
+ def installTftpd(self, q):
+ conf = '/etc/xinetd.d/tftp'
+ hostname = self.db.getHostAttr('localhost', 'Kickstart_PrivateHostname')
+ basedirs = self.db.getHostAttr('localhost', 'Kickstart_BaseDir')
+ distdirs = self.command('report.distro')
+
+ # The tftp server is not installed yet.
+ if not os.path.exists(conf):
+ self.shcmd('yum install -y tftp tftp-server syslinux --enablerepo=Sunhpc-%s' % sunhpc.version)
+
+ self.fw.addports(['69'], ['udp'])
+ self.fw.addservice('tftp')
+
+ sunhpc_tftpd_conf = '/etc/xinetd.d/tftp'
+ with open(sunhpc_tftpd_conf, 'w') as f:
+ f.write(textwrap.dedent("""\
+ # default: off
+ # description: The tftp server serves files using the trivial file transfer
+ # protocol. The tftp protocol is often used to boot diskless
+ # workstations, download configuration files to network-aware printers,
+ # and to start the installation process for some operating systems.
+ service tftp
+ {
+ socket_type = dgram
+ protocol = udp
+ wait = yes
+ user = root
+ server = /usr/sbin/in.tftpd
+ server_args = -s %s -B 1468
+ disable = no
+ per_source = 11
+ cps = 100 2
+ flags = IPv4
+ }
+ """ % self.basedir))
+
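+ # Note: /usr/lib/systemd/system/tftp.service is normally owned by the
+ # tftp-server package; the unit written below is assumed to override it
+ # so that in.tftpd serves files from the pxelinux directory.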
+ tftpd_service_conf = '/usr/lib/systemd/system/tftp.service'
+ with open(tftpd_service_conf, 'w') as f:
+ f.write(textwrap.dedent(f"""\
+ [Unit]
+ Description=Tftp Server
+ Requires=tftp.socket
+ Documentation=man:in.tftpd
+
+ [Service]
+ ExecStart=/usr/sbin/in.tftpd -s {self.basedir}
+ StandardInput=socket
+
+ [Install]
+ Also=tftp.socket
+ """))
+
+ self.shcmd('systemctl daemon-reload')
+ self.shcmd('systemctl start tftp')
+ self.shcmd('systemctl enable tftp.socket')
+
+ def copyit(self, src, dst, perms=0o644, owner=None):
+ if not os.path.exists(src):
+ self.msg('The %s not exists.' % src, 'a')
+
+ if os.path.exists(dst): os.remove(dst)
+
+ shutil.copyfile(src, dst)
+ os.chmod(dst, perms)
+ if owner:
+ os.system('chown %s %s' % (owner, dst))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/repair/__init__.py b/lib/sunhpc/commands/repair/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/repair/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/repair/permission/__init__.py b/lib/sunhpc/commands/repair/permission/__init__.py
new file mode 100644
index 0000000..f8d53f2
--- /dev/null
+++ b/lib/sunhpc/commands/repair/permission/__init__.py
@@ -0,0 +1,54 @@
+#coding:utf-8
+import os
+import sys
+import pwd
+import shutil
+import sunhpc
+class Command(sunhpc.commands.repair.command):
+ """
+ Repair the Sunhpc distribution permissions.
+
+ <example cmd='repair permission'>
+ Repair the Sunhpc distribution permissions.
+ </example>
+ """
+ def run(self, params, args):
+
+ #
+ # repair database permission.
+ #
+ dbpath = os.path.join(self.prefix, 'data')
+ dbfile = os.path.join(self.prefix, 'data', 'sunhpcdb')
+ self.chown('apache', 'apache', [dbpath, dbfile])
+ self.chmod('700', [dbpath, dbfile])
+
+ #
+ # Repair cgi file permission
+ #
+ distro = self.command('report.distro').strip()
+ ksdirs = os.path.join(distro, 'sbin')
+ kslist = list(map(lambda x: os.path.join(ksdirs, x), os.listdir(ksdirs)))
+ self.chmod('755', kslist)
+
+ #
+ # Repair pxelinux.cfg
+ #
+ pxedir = self.db.getHostAttr('localhost', 'pxelinuxdir')
+ pxecfg = os.path.join(pxedir, 'pxelinux.cfg')
+ pxes = list(map(lambda x: os.path.join(pxecfg, x), os.listdir(pxecfg)))
+ result = pxes + [pxecfg]
+ self.chown('root', 'apache', result)
+ self.chmod('664', pxes)
+ self.chmod('755', [pxecfg])
+
+
+ def chown(self, user, group, path=[]):
+ for p in path:
+ shutil.chown(p, user, group)
+
+ def chmod(self, perms, path=[]):
+ """perms is a mode string passed to /usr/bin/chmod, e.g. '644'."""
+ for p in path:
+ os.system('/usr/bin/chmod %s %s' % (perms, p))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/repair/users/__init__.py b/lib/sunhpc/commands/repair/users/__init__.py
new file mode 100644
index 0000000..08fd583
--- /dev/null
+++ b/lib/sunhpc/commands/repair/users/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.repair.command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/repair/users/authorized/__init__.py b/lib/sunhpc/commands/repair/users/authorized/__init__.py
new file mode 100644
index 0000000..d1c794d
--- /dev/null
+++ b/lib/sunhpc/commands/repair/users/authorized/__init__.py
@@ -0,0 +1,82 @@
+#coding:utf-8
+import os
+import sys
+import sunhpc
+class command(sunhpc.commands.repair.users.command):
+ pass
+class Command(command):
+ """
+ Repair system users' SSH keys and authorized_keys files.
+ <arg type='string' name='user'>
+ Provide a user account name.
+ </arg>
+
+ <param type='Bool' name='all'>
+ Fix the authorized_keys files of all users.
+ </param>
+
+ <example cmd='repair users authorized all=yes'>
+ Repair the authorized_keys files of all users.
+ </example>
+ """
+ def run(self, params, args):
+
+ (alls, ) = self.fillParams([('all', 'no')])
+ alls = self.str2bool(alls)
+
+ if not args and not alls:
+ self.msg('must supply an user or all=True', 'a')
+
+ userlist = []
+
+ if alls:
+ userlist.extend(self.getAllUsers())
+
+ if args:
+ userlist.extend(args)
+
+ # Remove duplicate user names.
+ userlist = sorted(list(set(userlist)))
+
+ home = '/export/home'
+ for user in userlist:
+ userhome = os.path.join(home, user)
+ sshdhome = os.path.join(userhome, '.ssh')
+ if not os.path.exists(sshdhome):
+ os.makedirs(sshdhome)
+
+ ssh_auth = os.path.join(sshdhome, 'authorized_keys')
+ ssh_keys = os.path.join(sshdhome, 'id_rsa')
+ ssh_pubs = os.path.join(sshdhome, 'id_rsa.pub')
+
+ # create the id_rsa keypair
+ if os.path.exists(ssh_keys):
+ os.remove(ssh_keys)
+ os.remove(ssh_pubs)
+ cmd = '/usr/bin/ssh-keygen -q -t rsa -P "" -f %s' % ssh_keys
+ os.system(cmd)
+
+ # generate the authorized_keys
+ if os.path.exists(ssh_pubs):
+ cmd = '/usr/bin/cat %s > %s' % (ssh_pubs, ssh_auth)
+ os.system(cmd)
+
+ # chown and chmod ssh file
+ os.system('chown -R %s:%s %s' % (user, user, userhome))
+
+ os.chmod(userhome, 0o700)
+ os.chmod(sshdhome, 0o700)
+ os.chmod(ssh_auth, 0o644)
+ os.chmod(ssh_keys, 0o600)
+ os.chmod(ssh_pubs, 0o644)
+
+ def getAllUsers(self):
+ users = []
+ autofs = '/etc/auto.home'
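+ # Each auto.home entry begins with the user name, roughly (illustrative):
+ #   alice  -rw,soft  nas-0-0:/export/home/&
+ # so the first whitespace-separated field is the account name.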
+ with open(autofs, 'r') as f:
+ for line in f:
+ users.append(line.split()[0])
+
+ return users
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/__init__.py b/lib/sunhpc/commands/report/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/report/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/distro/__init__.py b/lib/sunhpc/commands/report/distro/__init__.py
new file mode 100644
index 0000000..4d05369
--- /dev/null
+++ b/lib/sunhpc/commands/report/distro/__init__.py
@@ -0,0 +1,19 @@
+#coding:utf-8
+
+import sunhpc
+class Command(sunhpc.commands.report.command):
+ """
+ Output the path prefix for the location of the Sunhpc distribution.
+
+ <example cmd='report distro'>
+ Output the current path prefix to the distribution.
+ </example>
+ """
+ def run(self, params, args):
+ distrodir = '/export/sunhpc/install'
+
+ self.beginOutput()
+ self.addOutput('', distrodir)
+ self.endOutput(padChar='')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/host/__init__.py b/lib/sunhpc/commands/report/host/__init__.py
new file mode 100644
index 0000000..7570d64
--- /dev/null
+++ b/lib/sunhpc/commands/report/host/__init__.py
@@ -0,0 +1,148 @@
+import os
+import sunhpc
+class command(sunhpc.commands.HostArgumentProcessor,
+ sunhpc.commands.report.command):
+ pass
+
+class Command(command):
+ """
+ Report the host-to-IP-address mapping in a form suitable for
+ /etc/hosts.
+
+ <example cmd='report host'>
+ Outputs data for /etc/hosts.
+ </example>
+ """
+ def hostlocal(self, hostsFile):
+
+ if os.path.isfile(hostsFile):
+ print ('# import from %s' % hostsFile)
+ with open(hostsFile, 'r') as f:
+ for line in f.readlines():
+ print (line[:-1])
+
+ def extranics(self):
+ self.db.execute("""select networks.IP, networks.Name from
+ networks,subnets where subnets.name != "private" and
+ networks.subnet = subnets.id and
+ networks.ip is not NULL order by networks.IP""")
+
+ nodes=[]
+ for row in self.db.fetchall():
+ node = sunhpc.core.utils.Struct()
+ node.address = row[0]
+ node.name = [row[1],]
+ nodes.append(node)
+
+ for node in nodes:
+ if node.name[0] is not None:
+ print ('%s\t%s' % (node.address, ' '.join(node.name)))
+
+ def hostlines(self, subnet, netmask):
+
+ ip = sunhpc.core.ip.IPGenerator(subnet, netmask)
+ domain = self.db.getHostAttr('localhost', 'Kickstart_PrivateDNSDomain')
+ self.db.execute("""select n.id, n.rack, n.rank from nodes n order by n.id""")
+
+ nodes=[]
+ for row in self.db.fetchall():
+ node = sunhpc.core.utils.Struct()
+ node.id = row[0]
+ node.rack = row[1]
+ node.rank = row[2]
+ node.warning = None
+
+ self.db.execute("""select networks.name, networks.ip
+ from networks, subnets where
+ networks.node = %d and
+ subnets.name = "private" and
+ networks.subnet = subnets.id and
+ (networks.device not like 'vlan%%' or
+ networks.device is NULL)""" %
+ (node.id))
+
+ row = self.db.fetchone()
+ if row == None:
+ continue
+
+ nodes.append(node)
+ node.name = [row[0],]
+ node.address = row[1]
+
+ if not node.address:
+ node.address = ip.dec()
+
+ name = 'compute-%d-%d' % (node.rack, node.rank)
+
+ # If there is no name in the database, use the
+ # generated one.
+ if not node.name[0]:
+ node.name = [name,]
+
+ if node.name[0] != name:
+ node.warning = 'originally %s' % name
+
+ # Append names from the Aliases table.
+ for node in nodes:
+ self.db.execute('select nodes.alias from nodes '
+ 'where nodes.id = %d' % (node.id))
+ for alias, in self.db.fetchall():
+ node.name.append(alias)
+
+ # Format the data
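+ # e.g. (illustrative values):
+ #   10.1.255.254  compute-0-0.local compute-0-0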
+ for node in nodes:
+ fqdn = "%s.%s" % (node.name[0], domain)
+ entry = '%s\t%s %s' % (node.address, fqdn,
+ ' '.join(node.name))
+ if node.warning:
+ entry = entry + ' # ' + node.warning
+ print (entry)
+
+ def run(self, param, args):
+ self.beginOutput()
+ self.addOutput('localhost', '# Added by sunhpc report host #')
+ self.addOutput('localhost', '# DO NOT MODIFY #')
+ self.addOutput('localhost', '# Add any modifications to #')
+ self.addOutput('localhost', '# /etc/hosts.local file #\n')
+ self.addOutput('localhost','127.0.0.1\tlocalhost.localdomain\tlocalhost\n')
+
+ self.db.execute('select dnszone from subnets where name="private"')
+ localzone, = self.db.fetchone()
+
+ cmd = 'select nt.ip, n.alias, s.dnszone, ' +\
+ 'coalesce(nt.name,n.name) ' +\
+ 'from networks nt, subnets s, nodes n ' +\
+ 'where nt.subnet=s.id and nt.ip!="NULL" ' +\
+ 'and nt.node=n.id order by nt.subnet, nt.name'
+ self.db.execute(cmd)
+
+ # <ip> <name>.<private_zone> <name> <alias> <alias>.<private_zone>
+ # <ip> <name>.<zonename> <alias>.<zonename>
+ for (ip, alias, zone, record) in self.db.fetchall():
+ # Add <ip> <name>.<zonename>
+ h = '%s\t%s.%s' % (ip, record, zone)
+
+ # If it's the private zone
+ if zone == localzone:
+ # Add the <name> entry
+ h = h + '\t' + record
+ # and the <alias> entry
+ if alias:
+ h = h + '\t' + alias
+ # Finally add the <alias>.<zonename> entry
+ if alias:
+ h = h + '\t' + '%s.%s' % (alias, zone)
+ self.addOutput('localhost', h)
+
+ # Finally, add the hosts.local file to the list
+ hostlocal = '/etc/hosts.local'
+ if os.path.exists(hostlocal):
+ f = open(hostlocal,'r')
+ self.addOutput('localhost','\n# Imported from /etc/hosts.local\n')
+ h = f.read()
+ self.addOutput('localhost',h)
+ f.close()
+
+ self.endOutput(padChar='')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/host/attr/__init__.py b/lib/sunhpc/commands/report/host/attr/__init__.py
new file mode 100644
index 0000000..0aa2c0a
--- /dev/null
+++ b/lib/sunhpc/commands/report/host/attr/__init__.py
@@ -0,0 +1,72 @@
+#coding:utf-8
+
+import sys
+import socket
+import sunhpc
+class Command(sunhpc.commands.HostArgumentProcessor,
+ sunhpc.commands.report.command):
+ """
+ Report the set of attributes for hosts.
+
+ <arg optional='1' type='string' name='host'>
+ Host name of machine
+ </arg>
+
+ <param optional='1' type='string' name='attr'>
+ Output just the value of a particular attribute
+ </param>
+
+ <param optional='1' type='bool' name='pydict'>
+ Output as a python-formatted dictionary. Defaults to false.
+ Only valid if attr parameter is not specified.
+ </param>
+
+ <example cmd='report host attr compute-0-0'>
+ Report the attributes for compute-0-0.
+ </example>
+
+ <example cmd='report host attr compute-0-0 pydict=true'>
+ Report the attributes for compute-0-0 as a python dictionary suitable
+ for input to sunhpc report script.
+ </example>
+
+ <example cmd='report host attr compute-0-0 attr=Kickstart_Lang'>
+ Output value of the attribute called Kickstart_Lang for node
+ compute-0-0.
+ </example>
+
+ <related>report script</related>
+ """
+
+ def run(self, params, args):
+
+ (attr, pydict) = self.fillParams([('attr', ),('pydict','false')])
+ pyformat=self.str2bool(pydict)
+
+ self.beginOutput()
+
+ for host in self.getHostnames(args):
+
+ attrs = self.newdb.getHostAttrs(host)
+ if attr:
+ try:
+ self.addOutput(host, attrs[attr])
+ except KeyError:
+ raise sunhpc.core.utils.CommandError('Attribute %s does not exist' % attr)
+ elif pyformat:
+ # When the dictionary is passed to the shell this way it does not
+ # need escaping, so escapeStringForShell is not used here:
+ #attrs = sunhpc.core.utils.escapeStringForShell(str(attrs))
+ self.addOutput(host, attrs)
+ else:
+ fmt="%s:%s"
+ for key in sorted(attrs.keys()):
+ self.addOutput(host,
+ fmt % (key, attrs[key]))
+
+ if pyformat:
+ self.endOutput(padChar='',linesep='')
+ else:
+ self.endOutput(padChar='')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/host/dhcpd/__init__.py b/lib/sunhpc/commands/report/host/dhcpd/__init__.py
new file mode 100644
index 0000000..5e7d38c
--- /dev/null
+++ b/lib/sunhpc/commands/report/host/dhcpd/__init__.py
@@ -0,0 +1,171 @@
+#coding:utf-8
+import os
+import sunhpc
+class Command(sunhpc.commands.HostArgumentProcessor,
+ sunhpc.commands.report.command):
+ """
+ Output the DHCP server configuration information.
+ <example cmd='report host dhcpd'>
+ Output dhcpd data for host.
+ </example>
+
+ <arg type='str' name='host'>
+ Specify a node name.
+ </arg>
+ """
+ def makeAttrDictionary(self):
+ self.db.execute("""
+ select n.id, n.name, att.attr, att.value from nodes n, attributes att
+ where att.node = n.id order by n.id""")
+ self.attrdict = {}
+ for row in self.db.fetchall():
+ key = "%d-%s" % (row[0], row[2])
+ self.attrdict[key] = row[3]
+
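+ # printHost() emits one dhcpd host stanza, e.g. (illustrative values):
+ #   host compute-0-0-eth0 {
+ #     hardware ethernet 00:25:90:aa:bb:cc;
+ #     option host-name "compute-0-0";
+ #     fixed-address 10.1.255.254;
+ #     filename "pxelinux.0";
+ #     next-server 10.1.1.1;
+ #   }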
+ def printHost(self, name, hostname, mac, ip, filename, nextserver):
+ self.addOutput('', '\t\thost %s {' % name)
+ if mac:
+ self.addOutput('', '\t\t\thardware ethernet %s;' % mac)
+
+ self.addOutput('', '\t\t\toption host-name "%s";' % hostname)
+ self.addOutput('', '\t\t\tfixed-address %s;' % ip)
+
+ if filename:
+ self.addOutput('','\t\t\tfilename "%s";' % filename)
+ if nextserver:
+ self.addOutput('','\t\t\tnext-server %s;' % nextserver)
+ self.addOutput('', '\t\t}')
+
+ def printOptions(self, prefix):
+
+ self.addOutput('', '%soption routers %s;' %
+ (prefix, self.db.getHostAttr('localhost', 'Kickstart_PrivateGateway')))
+
+ self.addOutput('', '%soption subnet-mask %s;' %
+ (prefix, self.db.getHostAttr('localhost', 'Kickstart_PrivateNetmask')))
+
+ self.addOutput('', '%soption domain-name "%s";' %
+ (prefix, self.db.getHostAttr('localhost', 'Kickstart_PrivateDNSDomain')))
+
+ self.addOutput('', '%soption domain-name-servers %s;' %
+ (prefix, self.db.getHostAttr('localhost', 'Kickstart_PrivateDNSServer')))
+
+ self.addOutput('', '%soption interface-mtu %s;' %
+ (prefix, self.db.getHostAttr('localhost', 'Kickstart_PrivateMTU')))
+
+ def writeDHCPConf(self, host):
+
+ ip = self.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ dm = self.db.getHostAttr('localhost', 'Kickstart_PrivateDNSDomain')
+ nw = self.db.getHostAttr('localhost', 'Kickstart_PrivateNetwork')
+ nm = self.db.getHostAttr('localhost', 'Kickstart_PrivateNetmask')
+ dl = self.db.getHostAttr('localhost', 'Kickstart_DefaultLeaseTime')
+ ml = self.db.getHostAttr('localhost', 'Kickstart_MaxLeaseTime')
+
+ if not dl: dl = '1200'
+ if not ml: ml = '1200'
+
+ #self.addOutput('', '<file name="/etc/dhcp/dhcpd.conf">')
+ self.addOutput('', 'ddns-update-style none;')
+ self.addOutput('', 'subnet %s netmask %s {' % (nw, nm))
+ self.addOutput('', '\tdefault-lease-time %s;' % dl)
+ self.addOutput('', '\tmax-lease-time %s;' % ml)
+
+ self.printOptions('\t')
+ self.addOutput('', '\tgroup "%s" {' % dm)
+
+ currnode = 0
+ self.makeAttrDictionary()
+ self.db.execute("""
+ SELECT n.id,n.name,n.rack,n.rank,net.device,net.mac,net.ip,sub.name
+ FROM nodes n INNER JOIN networks net
+ ON net.node=n.id, subnets sub WHERE net.subnet=sub.id
+ AND sub.name="private"
+ UNION
+ SELECT n.id,n.name,n.rack,n.rank,net.device,net.mac,
+ net.ip, NULL FROM nodes n INNER JOIN networks net
+ ON net.node=n.id where net.subnet IS NULL
+ ORDER BY 1,7 DESC""")
+
+ #
+ # Add extra (static) hosts read from /etc/dhcp/ehost.conf; each line is
+ # '<name> <mac>', e.g. 'hiwifi 00:00:00:00:00:00'.
+ exthost = []
+ extfile = '/etc/dhcp/ehost.conf'
+ if os.path.exists(extfile):
+ with open(extfile, 'r') as f:
+ for i in f.readlines():
+ if i.startswith('#'):
+ continue
+ tmp = i.split()
+ tmp.insert(0, 0)
+ tmp.insert(2, 0)
+ tmp.insert(3, 'ext0')
+ tmp.insert(4, 0)
+ tmp.insert(6, '255.255.255.255')
+ tmp.insert(7, 'private')
+ exthost.append(tmp)
+
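+ # Each padded ehost.conf row now matches the tuple layout expected by
+ # the loop below: (id, name, rack, device, rank, mac, ip, subnet).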
+ for row in self.db.fetchall() + exthost:
+ node = sunhpc.core.utils.Struct()
+ node.id = row[0]
+ node.name = row[1]
+ node.rack = row[2]
+ netdevice = row[3]
+ node.rank = row[4]
+ node.mac = row[5]
+ node.ip = row[6]
+ netname = row[7]
+ hostname = node.name
+
+ if currnode != node.id:
+ currnode = node.id
+ unassignedidx = 0
+
+ kickstartable = self.attrdict.get('%d-kickstartable' % node.id)
+ if kickstartable:
+ kickstartable = self.str2bool(kickstartable)
+ else:
+ kickstartable = False
+
+ if not kickstartable:
+ nextserver = None
+ filename = None
+
+ if kickstartable:
+ filename = self.attrdict.get('%d-dhcp_filename' % node.id)
+ nextserver = self.attrdict.get('%d-dhcp_nextserver' % node.id)
+
+ if netname == "private":
+ privateIP = node.ip
+
+ if netname is None and node.ip is None:
+ node.ip = privateIP
+
+ if node.name is None or node.mac is None or node.ip is None or len(node.mac) > 20:
+ continue
+
+ try:
+ node.name = node.name + '-' + netdevice.replace(':', '_')
+ except:
+ pass
+ self.printHost(node.name, hostname, node.mac, node.ip, filename, nextserver)
+
+ self.addOutput('', '\t}')
+ self.addOutput('', '}')
+ #self.addOutput('', '</file>')
+
+ def run(self, params, args):
+
+ if len(args) > 1:
+ self.msg('Cannot supply more than one host name', 'a')
+
+ if len(args) == 0:
+ args = [ os.uname()[1] ]
+
+ hosts = self.getHostnames(args)
+ self.beginOutput()
+ self.writeDHCPConf(hosts)
+ self.endOutput(padChar='')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/kickstart/__init__.py b/lib/sunhpc/commands/report/kickstart/__init__.py
new file mode 100644
index 0000000..ded2a58
--- /dev/null
+++ b/lib/sunhpc/commands/report/kickstart/__init__.py
@@ -0,0 +1,57 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import textwrap
+class Command(sunhpc.commands.report.command):
+ """
+ Output the kickstart scripts.
+
+ <arg type='string' name='nodes'>
+ Output compute node kickstart scripts.
+ </arg>
+
+ <example cmd='report kickstart'>
+ Output the kickstart scripts.
+ </example>
+
+ <example cmd='report kickstart nodes'>
+ Output the compute node kickstart scripts.
+ </example>
+ """
+ def run(self, params, args):
+
+ if args and args[0] == 'nodes':
+
+ try:
+ hostname = args[1]
+ except IndexError:
+ hostname = 'temporary-nodes'
+
+ include_info = self.loadModules('kickstart', hostname)
+ print ('\n'.join(include_info))
+
+ phases = ['main', 'pre', 'post', 'packages', 'addons', 'anaconda']
+ for phase in phases:
+ kickstart = self.ks.getKickstart(phase)
+ for kslist in kickstart:
+ for ks in kslist:
+ print (textwrap.dedent(ks))
+ else:
+ # Output control scripts
+ self.ks.isFrontend = True
+ include_info = self.loadModules('control')
+ print ('\n'.join(include_info))
+
+ phases = ['main', 'pre', 'post', 'packages']
+ for phase in phases:
+ kickstart = self.ks.getKickstart(phase)
+ for kslist in kickstart:
+ for ks in kslist:
+ print (textwrap.dedent(ks))
+
+
+
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/knownhosts/__init__.py b/lib/sunhpc/commands/report/knownhosts/__init__.py
new file mode 100644
index 0000000..f2c62ae
--- /dev/null
+++ b/lib/sunhpc/commands/report/knownhosts/__init__.py
@@ -0,0 +1,102 @@
+#coding:utf-8
+import os
+import sys
+import sunhpc
+import subprocess
+class command(sunhpc.commands.HostArgumentProcessor, sunhpc.commands.report.command):
+ pass
+
+class Command(command):
+ """
+ Report the host to known hosts (public keys) for
+ Report the hosts' public keys in a form suitable for
+
+ <example cmd='report knownhosts'>
+ Outputs host public key entries to be used for /etc/ssh/ssh_known_hosts
+ </example>
+ """
+
+ def run(self, param, args):
+ self.beginOutput()
+ self.addOutput('localhost', '# Added by sunhpc report knownhosts #')
+ self.addOutput('localhost', '# DO NOT MODIFY #')
+ self.addOutput('localhost', '# If you need to add entries use #')
+ self.addOutput('localhost', '# /etc/ssh/ssh_known_hosts.local #')
+
+ # grab per-node public keys
+ allhosts = {}
+ cmd = """SELECT n.name, s.value FROM nodes n INNER JOIN
+ secnodes s ON s.node = n.id """
+ self.db.execute(cmd)
+
+ for (host,pubkey) in self.db.fetchall():
+ ipaddr = self.newdb.getHostIp(host)
+ #data = "%s,%s %s" % (host, ipaddr, pubkey.rstrip('\n'))
+ allhosts[host] = [ipaddr, pubkey.rstrip('\n')]
+
+ cmd = """SELECT n.name,s.dnszone,net.name,s.name FROM nodes n INNER JOIN
+ networks net ON net.node = n.id INNER JOIN
+ subnets s on net.subnet = s.id; """
+ self.db.execute(cmd)
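+ # Each emitted entry has the standard known_hosts form, e.g. (illustrative):
+ #   compute-0-0.local,10.1.255.254 ssh-rsa AAAA...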
+ for (host,zone,ifname,subnet) in self.db.fetchall():
+ if host in allhosts and zone is not None:
+ if ifname is not None:
+ hostname = ifname
+ else:
+ hostname = host
+ self.addOutput('localhost',
+ '%s.%s,%s %s' % (hostname, zone, allhosts[host][0], allhosts[host][1]))
+ if subnet == 'private':
+ self.addOutput('localhost',
+ '%s,%s %s' % (hostname, allhosts[host][0], allhosts[host][1]))
+
+ cmd = """SELECT s.value FROM secglobals s
+ WHERE s.attr = 'ssh_host_rsa_key.pub'"""
+ row = self.db.search(cmd)
+ if row > 0:
+ pubkey, = self.db.fetchone()
+ else:
+ pubkey = None
+ if pubkey is not None:
+ pubkey = pubkey.rstrip('\n')
+
+ cmd = """SELECT dnszone FROM subnets where dnszone
+ IS NOT NULL AND subnets.name != 'public';"""
+ self.db.execute(cmd)
+ for zone, in self.db.fetchall():
+ if pubkey is not None:
+ self.addOutput('localhost',
+ '*.%s %s' % (zone,pubkey))
+
+ cmd = """SELECT n.name,s.dnszone,net.name FROM nodes n INNER JOIN
+ networks net ON net.node = n.id INNER JOIN
+ subnets s on net.subnet = s.id AND s.name = 'private'; """
+ self.db.execute(cmd)
+
+ if pubkey is not None:
+ for (host,zone,ifname) in self.db.fetchall():
+ if host not in allhosts:
+ if ifname is not None:
+ hostname = ifname
+ else:
+ hostname = host
+ ipaddr = self.newdb.getHostIp(hostname)
+ self.addOutput('localhost',
+ '%s,%s %s' % (hostname, ipaddr, pubkey))
+
+ hostlocal = '/etc/ssh/ssh_known_hosts.local'
+ try:
+ f = open(hostlocal,'r')
+ self.addOutput('localhost','#\n# Imported from %s\n#' % hostlocal)
+ h = f.read()
+ self.addOutput('localhost',h)
+ f.close()
+ except :
+ pass
+
+ self.endOutput(padChar='')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/report/nextip/__init__.py b/lib/sunhpc/commands/report/nextip/__init__.py
new file mode 100644
index 0000000..601c8b8
--- /dev/null
+++ b/lib/sunhpc/commands/report/nextip/__init__.py
@@ -0,0 +1,73 @@
+#coding:utf-8
+
+import sunhpc
+import sqlalchemy
+from sunhpc.db.mappings.base import *
+class command(sunhpc.commands.HostArgumentProcessor,
+ sunhpc.commands.report.command):
+ pass
+
+class Command(command):
+ """
+ Report the next available IP address on the given subnet
+
+ <arg type='string' name='subnet'>
+ The subnet name that we should use
+ </arg>
+
+ <param type='int' name='increment'>
+ The increment that should be used to find the next available IP
+ Defaults to -1.
+ </param>
+
+ <param type='string' name='baseip'>
+ The starting IP address we should use to generate IPs
+ </param>
+
+ <example cmd='report nextip private'>
+ Output the next available IP on the private subnet
+ </example>
+ """
+
+ def run(self, param, args):
+
+ (increment, baseip) = self.fillParams( [
+ ('increment', '-1'),
+ ('baseip', '')
+ ])
+
+ if len(args) != 1:
+ self.abort('must supply the subnet')
+
+ increment = int(increment)
+
+ subnet = args[0]
+
+ try:
+ subnet_db = self.newdb.getSession().query(sunhpc.db.mappings.base.Subnet)\
+ .options(sqlalchemy.orm.joinedload('networks'))\
+ .filter(Subnet.name == subnet).one()
+ except sqlalchemy.orm.exc.NoResultFound:
+ self.abort('subnet %s is not valid' % subnet)
+
+ mask_ip = sunhpc.core.ip.IPAddr(subnet_db.netmask)
+ network_ip = sunhpc.core.ip.IPAddr(subnet_db.subnet)
+ bcast_ip = sunhpc.core.ip.IPAddr(network_ip | sunhpc.core.ip.IPAddr(~mask_ip))
+ bcast = "%s" % (bcast_ip)
+ used_ip = [ net.ip for net in subnet_db.networks]
+
+ ip = sunhpc.core.ip.IPGenerator(bcast, subnet_db.netmask)
+ if baseip:
+ ip.addr = sunhpc.core.ip.IPAddr(baseip)
+
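+ # Walk from the broadcast address in steps of `increment` (default -1,
+ # i.e. downward) until an address not already assigned on this subnet
+ # is found.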
+ while 1:
+ nextip = ip.next(increment)
+ if str(nextip) not in used_ip:
+ # we found it
+ break
+
+ self.beginOutput()
+ self.addOutput('localhost', nextip)
+ self.endOutput(padChar='')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/run/__init__.py b/lib/sunhpc/commands/run/__init__.py
new file mode 100644
index 0000000..0a924f2
--- /dev/null
+++ b/lib/sunhpc/commands/run/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = 'base'
diff --git a/lib/sunhpc/commands/run/host/__init__.py b/lib/sunhpc/commands/run/host/__init__.py
new file mode 100644
index 0000000..b1cebf0
--- /dev/null
+++ b/lib/sunhpc/commands/run/host/__init__.py
@@ -0,0 +1,280 @@
+import os
+import sys
+import time
+import socket
+import threading
+import subprocess
+import sunhpc.commands
+class Parallel(threading.Thread):
+ def __init__(self, cmdclass, cmd, host, hostif, stats, collate):
+ threading.Thread.__init__(self)
+ self.cmd = cmd
+ self.host = host
+ self.hostif = hostif
+ self.stats = stats
+ self.collate = collate
+ self.cmdclass = cmdclass
+
+ def run(self):
+ starttime = time.time()
+ self.p = subprocess.Popen(self.cmd,
+ stdin = subprocess.PIPE, stdout = subprocess.PIPE,
+ stderr = subprocess.STDOUT)
+
+ for line in self.p.stdout.readlines():
+ if self.collate:
+ self.cmdclass.addOutput(self.host, line[:-1])
+ else:
+ print (line[:-1].decode('utf-8'))
+
+ if self.stats:
+ msg = 'command on host %s took %f seconds' % \
+ (self.host, time.time() - starttime)
+
+ if self.collate:
+ self.cmdclass.addOutput(self.host, msg)
+ else:
+ print (msg)
+
+ def kill(self):
+ os.kill(self.p.pid, 9)
+
+class command(sunhpc.commands.HostArgumentProcessor,
+ sunhpc.commands.run.command):
+ MustBeRoot = 0
+
+class Command(command):
+ """
+ Run a command for each specified host.
+
+ <arg optional='1' type='string' name='host' repeat='1'>
+ Zero, one or more host names. If no host names are supplied, the command
+ is run on all 'managed' hosts. By default, all compute nodes are
+ 'managed' nodes. To determine if a host is managed, execute:
+ 'sunhpc list host attr hostname | grep managed'. If you see output like:
+ 'compute-0-0: managed true', then the host is managed.
+ </arg>
+
+ <arg type='string' name='command'>
+ The command to run on the list of hosts.
+ </arg>
+
+ <param type='boolean' name='managed'>
+ Run the command only on 'managed' hosts, that is, hosts that generally
+ have an ssh login. Default is 'yes'.
+ </param>
+
+ <param type='boolean' name='x11'>
+ If 'no', disable X11 forwarding when connecting to hosts.
+ Default is 'yes'.
+ </param>
+
+ <param type='string' name='timeout'>
+ Sets the maximum length of time (in seconds) that the command is
+ allowed to run.
+ Default is '30'.
+ </param>
+
+ <param type='string' name='delay'>
+ Sets the time (in seconds) to delay between each executed command
+ on multiple hosts. For example, if the command is run on two
+ hosts and if the delay is 10, then the command will be executed on host
+ 1, then 10 seconds later, the command will be executed on host 2.
+ Default is '0' (no delay).
+ </param>
+
+ <param type='string' name='stats'>
+ Display performance statistics if this parameter is set to 'yes'.
+ Default is 'no'.
+ </param>
+
+ <param type='string' name='collate'>
+ Prepend the hostname to every output line if this parameter is set to
+ 'yes'.
+ Default is 'no'.
+ </param>
+
+ <param type='string' name='num-threads'>
+ The number of threads to start in parallel. If num-threads is 0, then
+ try to run the command in parallel on all hosts. Default is '8'.
+ </param>
+
+ <param type='string' name='command'>
+ Can be used in place of the 'command' argument.
+ </param>
+
+ <example cmd='run host compute-0-0 command="hostname"'>
+ Run the command 'hostname' on compute-0-0.
+ </example>
+
+ <example cmd='run host compute "ls /tmp"'>
+ Run the command 'ls /tmp/' on all compute nodes.
+ </example>
+ """
+
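+ # nodeup() is a cheap reachability probe: connect to port 22, read the
+ # server banner, send a dummy client banner and report the node as up
+ # only if the exchange succeeds.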
+ def nodeup(self, hostif):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(2.0)
+ try:
+ sock.connect((hostif, 22))
+ buf = sock.recv(64)
+ sock.send(b'SSH-2.0-nodeup\r\n')
+ buf = sock.recv(1024)
+ sock.close()
+ except:
+ return 0
+ return 1
+
+ def run(self, params, args):
+ (args, command) = self.fillPositionalArgs(('command', ))
+
+ if not command:
+ self.abort('must supply a command')
+
+ (managed, x11, t, d, s, c, n) = \
+ self.fillParams([
+ ('managed', 'y'),
+ ('x11', 'y'),
+ ('timeout', '30'),
+ ('delay', '0'),
+ ('stats', 'n'),
+ ('collate', 'n'),
+ ('num-threads', '8')
+ ])
+
+ try:
+ timeout = int(t)
+ except:
+ self.abort('"timeout" must be an integer')
+
+ if timeout < 0:
+ self.abort('"timeout" must be a positive integer')
+
+ try:
+ numthreads = int(n)
+ except:
+ self.abort('"num-threads" must be an integer')
+
+ try:
+ delay = float(d)
+ except:
+ self.abort('"delay" must be a floating point number')
+
+ hosts = self.getHostnames(args, self.str2bool(managed))
+ # This is the same as doing -x using ssh. Might be useful
+ # for the common case, but required for the Viz Roll.
+ if not self.str2bool(x11):
+ try:
+ del os.environ['DISPLAY']
+ except KeyError:
+ pass
+
+ collate = self.str2bool(c)
+ stats = self.str2bool(s)
+ if collate:
+ self.beginOutput()
+
+ if numthreads <= 0:
+ numthreads = len(hosts)
+
+ threads = []
+ i = 0
+ work = len(hosts)
+ while work:
+ localhost = socket.gethostname().split('.')[0]
+ while i < numthreads and i < len(hosts):
+ host = hosts[i]
+ # Is this host me?
+ runlocal = (localhost == host.split('.')[0])
+ i += 1
+
+ try:
+ hnet=self.db.getHostAttr(host,'primary_net')
+ query="""select net.ip from networks net, nodes n, subnets s where
+ net.node=n.id and net.subnet=s.id and
+ n.name='%s' and s.name='%s' """ % (host, hnet)
+ self.db.execute(query)
+ hostif,=self.db.fetchone()
+ except:
+ hostif=host
+ #
+ # first test if the node is up and responding
+ # to ssh
+ #
+ if not runlocal and not self.nodeup(hostif):
+ if collate:
+ self.addOutput(host, 'down')
+ else:
+ print ('%s: down' % host)
+
+ numthreads += 1
+ work -= 1
+ continue
+ #
+ # fire off the command
+ #
+ if runlocal:
+ cmd = ('bash', '-c', command)
+ else:
+ cmd = ('ssh', hostif, command)
+
+ p = Parallel(self, cmd, host, hostif, stats, collate)
+ p.start()
+ threads.append(p)
+
+ if delay > 0:
+ time.sleep(delay)
+ #
+ # collect completed threads
+ #
+ try:
+ totaltime = time.time()
+ while timeout == 0 or \
+ (time.time() - totaltime) < timeout:
+
+ active = threading.enumerate()
+
+ t = threads
+ for thread in t:
+ if thread not in active:
+ thread.join(0.1)
+ threads.remove(thread)
+ numthreads += 1
+ work -= 1
+
+ if len(active) == 1:
+ break
+ #
+ # don't burn a CPU while waiting for the
+ # threads to complete
+ #
+ time.sleep(0.5)
+
+ except KeyboardInterrupt:
+ #
+ # try to collect all the active threads
+ #
+ active = threading.enumerate()
+
+ t = threads
+ for thread in t:
+ if thread not in active:
+ thread.join(0.1)
+ threads.remove(thread)
+ #
+ # no more work to do if the user hits
+ # control-c
+ #
+ work = 0
+ #
+ # kill all still active threads
+ #
+ active = threading.enumerate()
+ if len(active) >= 2:
+ for i in range(1, len(active)):
+ active[i].kill()
+
+ if collate:
+ self.endOutput(padChar='')
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/__init__.py b/lib/sunhpc/commands/set/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/set/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/__init__.py b/lib/sunhpc/commands/set/host/__init__.py
new file mode 100644
index 0000000..b5225ca
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/__init__.py
@@ -0,0 +1,8 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.HostArgumentProcessor,
+ sunhpc.commands.set.command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/boot/__init__.py b/lib/sunhpc/commands/set/host/boot/__init__.py
new file mode 100644
index 0000000..d5fe7a6
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/boot/__init__.py
@@ -0,0 +1,52 @@
+#coding:utf-8
+
+import sunhpc.commands
+class Command(sunhpc.commands.set.host.command):
+ """
+ Set a bootaction for a host. A host's action can be set to 'install'
+ or to 'os'.
+
+ <arg type='string' name='host' repeat='1'>
+ One or more host names.
+ </arg>
+
+ <param type='string' name='action'>
+ The label name for the bootaction. This must be either 'os'
+ or 'install'.
+
+ If no action is supplied, then only the configuration file for the
+ list of hosts will be rewritten.
+ </param>
+
+ <example cmd='set host boot compute-0-0 action=os'>
+ On the next boot, compute-0-0 will boot the profile based on its
+ "status". To see the node's "status", execute:
+ "sunhpc list host compute-0-0" and examine the value in the "status" column.
+ </example>
+ """
+
+ def updateBoot(self, host, action):
+
+ self.db.execute("select id from nodes where name = '%s'" % host)
+ bootid, = self.db.fetchone()
+ self.db.execute("update nodes set status = '%s' where id = %s" % (action, bootid))
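+ # The new status is later picked up by the boot plugins (e.g.
+ # plugin_00_ip2hex) when they rewrite the host's pxelinux.cfg entry.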
+
+ def run(self, params, args):
+
+ (action,) = self.fillParams([('action', )])
+
+ if not len(args):
+ self.abort('must supply host')
+
+ if action not in [ 'os', 'install', None ]:
+ self.abort('invalid action. action must be "os" or "install"')
+
+ for host in self.getHostnames(args):
+ if action:
+ self.updateBoot(host, action)
+ #
+ # run the plugins
+ #
+ self.runPlugins(host)
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/boot/plugin_00_ip2hex.py b/lib/sunhpc/commands/set/host/boot/plugin_00_ip2hex.py
new file mode 100644
index 0000000..b68aaba
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/boot/plugin_00_ip2hex.py
@@ -0,0 +1,170 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc.commands
+class Plugin(sunhpc.commands.Plugin):
+
+ def getFilename(self, nodeid):
+ nrows = self.db.search("""select networks.ip from
+ networks,subnets where
+ networks.node = %s and subnets.name = "private" and
+ networks.subnet = subnets.id
+ """ % (nodeid))
+
+ if nrows < 1:
+ return None
+
+ ipaddr, = self.db.fetchone()
+
+ filename = '/tftpboot/pxelinux/pxelinux.cfg/'
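+ # e.g. a private address of 10.1.255.254 maps to
+ # /tftpboot/pxelinux/pxelinux.cfg/0A01FFFE (address is illustrative).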
+ for i in ipaddr.split('.'):
+ hexstr = '%02x' % (int(i))
+ filename += '%s' % hexstr.upper()
+ return filename
+
+ def writeDefaultPxebootCfg(self):
+ nrows = self.db.search("select kernel, ramdisk, args from bootactions where action='install'")
+ if nrows == 1:
+ kernel, ramdisk, args = self.db.fetchone()
+
+ filename = '/tftpboot/pxelinux/pxelinux.cfg/default'
+ file = open(filename, 'w')
+ file.write('default sunhpc\n')
+ file.write('prompt 0\n')
+ file.write('label sunhpc\n')
+
+ if len(kernel) > 6 and kernel[0:7] == 'vmlinuz':
+ file.write('\tkernel %s\n' % (kernel))
+ if len(ramdisk) > 0:
+ if len(args) > 0:
+ args += ' initrd=%s' % ramdisk
+ else:
+ args = 'initrd=%s' % ramdisk
+ if len(args) > 0:
+ file.write('\tappend %s\n' % (args))
+
+ file.close()
+
+ #
+ # make sure apache can update the file
+ #
+ os.system('chown root.apache %s' % (filename))
+ os.system('chmod 664 %s' % (filename))
+
+
+ def writePxebootCfg(self, node, nodeid):
+ #
+ # there is a case where the host name may be in the nodes table
+ # but not in the boot table. in this case, remove the current
+ # configuration file (if it exists) and return
+ #
+ filename = self.getFilename(nodeid)
+
+ nrows = self.db.search("select status from nodes where id = %s" % nodeid)
+ if nrows < 1:
+ if filename != None and os.path.exists(filename):
+ os.unlink(filename)
+ return
+ else:
+ action, = self.db.fetchone()
+
+ #
+ # look up the boot action stored in the host's 'status' column
+ #
+ if action in [ 'os', 'install' ]:
+ nrows = self.db.search("select status from nodes where name = '%s'" % node)
+ else:
+ print ('action "%s" for host "%s" is invalid' % (action, node))
+ sys.exit(-1)
+
+ if nrows == 1:
+ bootaction, = self.db.fetchone()
+ else:
+ print ('failed to get bootaction')
+ sys.exit(-1)
+
+ nrows = self.db.search("""select kernel, ramdisk, args from
+ bootactions where action = '%s' """% bootaction)
+
+ if nrows == 1:
+ kernel, ramdisk, args = self.db.fetchone()
+ else:
+ print ('bootaction "%s" for host "%s" is invalid' % (action, node))
+ sys.exit(-1)
+
+ if args and args.find('ksdevice=') != -1:
+ self.db.execute("""select net.ip
+ from networks net, subnets s, nodes n
+ where n.name='%s' and net.node=n.id and
+ s.id=net.subnet and s.name='private' and
+ net.ip is not NULL""" % node)
+ ip, = self.db.fetchone()
+ args += ' ip=%s ' % ip
+ attrs = self.db.getHostAttrs('localhost')
+ # hostname=%s
+ args += 'gateway=%s netmask=%s dns=%s nextserver=%s'%(\
+ attrs['Kickstart_PrivateGateway'],
+ attrs['Kickstart_PrivateNetmask'],
+ attrs['Kickstart_PrivateDNSServer'],
+ attrs['Kickstart_PrivateAddress'])
+
+ if args and args.find('inst.ks=') == -1:
+ address = self.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ args += ' inst.ks=http://%s/install/sbin/kickstart.cgi' % address
+
+ if filename != None:
+ file = open(filename, 'w')
+ file.write('default sunhpc\n')
+ file.write('prompt 0\n')
+ file.write('label sunhpc\n')
+
+ if kernel:
+ if kernel[0:7] == 'vmlinuz':
+ file.write('\tkernel %s\n' % (kernel))
+ else:
+ file.write('\t%s\n' % (kernel))
+
+ if ramdisk and len(ramdisk) > 0:
+ if len(args) > 0:
+ args += ' initrd=%s' % ramdisk
+ else:
+ args = 'initrd=%s' % ramdisk
+
+ if args and len(args) > 0:
+ file.write('\tappend %s\n' % (args))
+
+ # If using ksdevice=bootif we need to
+ # pass the PXE information to loader.
+
+ if args and args.find('bootif') != -1:
+ file.write('\tipappend 2\n')
+
+ file.close()
+
+ #
+ # make sure apache can update the file
+ #
+ os.system('chown root.apache %s' % (filename))
+ os.system('chmod 664 %s' % (filename))
+
+ def run(self, host):
+ nrows = self.db.search("select id from nodes where name = '%s'" % host)
+ if nrows > 0:
+ nodeid, = self.db.fetchone()
+ else:
+ print ('could not find host "%s" in the database' % host)
+ sys.exit(-1)
+
+ frontend_host = self.db.getHostAttr('localhost', 'Kickstart_PrivateHostname')
+
+ if host == frontend_host:
+ self.writeDefaultPxebootCfg()
+ else:
+ self.writePxebootCfg(host, nodeid)
+
+ def __repr__(self):
+ return '01'
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/cpus/__init__.py b/lib/sunhpc/commands/set/host/cpus/__init__.py
new file mode 100644
index 0000000..853a48b
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/cpus/__init__.py
@@ -0,0 +1,45 @@
+#coding:utf-8
+
+import sunhpc.commands
+class Command(sunhpc.commands.set.host.command):
+ """
+ Set the number of CPUs for a list of hosts.
+
+ <arg type='string' name='host' repeat='1'>
+ One or more host names.
+ </arg>
+
+ <arg type='string' name='cpus'>
+ The number of CPUs to assign to each host.
+ </arg>
+
+ <param optional='1' type='string' name='cpus'>
+ Can be used in place of the cpus argument.
+ </param>
+
+ <example cmd='set host cpus compute-0-0 2'>
+ Sets the CPU value to 2 for compute-0-0.
+ </example>
+
+ <example cmd='set host cpus compute-0-0 compute-0-1 4'>
+ Sets the CPU value to 4 for compute-0-0 and compute-0-1.
+ </example>
+
+ <example cmd='set host cpus compute-0-0 compute-0-1 cpus=4'>
+ Same as above.
+ </example>
+ """
+
+ def run(self, params, args):
+ (args, cpus) = self.fillPositionalArgs(('cpus',))
+
+ if not len(args):
+ self.abort('must supply host')
+ if not cpus:
+ self.abort('must supply cpus')
+
+ for host in self.getHostnames(args):
+ self.db.execute("""update nodes set cpus=%d where
+ name='%s'""" % (int(cpus), host))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/interface/__init__.py b/lib/sunhpc/commands/set/host/interface/__init__.py
new file mode 100644
index 0000000..e11ea5c
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/interface/__init__.py
@@ -0,0 +1,3 @@
+#coding:utf-8
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/interface/iface/__init__.py b/lib/sunhpc/commands/set/host/interface/iface/__init__.py
new file mode 100644
index 0000000..26a2a20
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/interface/iface/__init__.py
@@ -0,0 +1,62 @@
+#coding:utf-8
+
+import sunhpc
+class Command(sunhpc.commands.set.host.command):
+ """
+ Sets the logical interface of a mac address for particular hosts.
+
+ <arg type='string' name='host' repeat='1'>
+ One or more named hosts.
+ </arg>
+
+ <arg type='string' name='mac'>
+ MAC address of the interface whose logical interface will be reassigned
+ </arg>
+
+ <arg type='string' name='iface'>
+ Logical interface.
+ </arg>
+
+ <param type='string' name='mac'>
+ Can be used in place of the mac argument.
+ </param>
+
+ <param type='string' name='iface'>
+ Can be used in place of the iface argument.
+ </param>
+
+
+ <example cmd='set host interface iface compute-0-0 00:0e:0c:a7:5d:ff eth1'>
+ Sets the logical interface of MAC address 00:0e:0c:a7:5d:ff to be eth1
+ </example>
+
+ <example cmd='set host interface iface compute-0-0 iface=eth1 mac=00:0e:0c:a7:5d:ff'>
+ Same as above.
+ </example>
+
+ <!-- cross refs do not exist yet
+ <related>set host interface iface</related>
+ <related>set host interface ip</related>
+ <related>set host interface module</related>
+ -->
+ <related>add host</related>
+ """
+
+ def run(self, params, args):
+
+ (args, mac, iface) = self.fillPositionalArgs(('mac', 'iface'))
+
+ if not len(args):
+ self.abort('must supply host')
+ if not mac:
+ self.abort('must supply mac')
+ if not iface:
+ self.abort('must supply iface')
+
+ for host in self.getHostnames(args):
+ self.db.execute("""update networks set device='%s'
+ where nodes.name='%s' and networks.node=nodes.id and
+ networks.mac='%s'""" %
+ (iface, host, mac))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/interface/ip/__init__.py b/lib/sunhpc/commands/set/host/interface/ip/__init__.py
new file mode 100644
index 0000000..d4b7566
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/interface/ip/__init__.py
@@ -0,0 +1,87 @@
+#coding:utf-8
+
+import sunhpc
+import socket
+class Command(sunhpc.commands.set.host.command):
+ """
+ Sets the IP address for the named interface on one host.
+
+ <arg type='string' name='host'>
+ Host name.
+ </arg>
+
+ <arg type='string' name='iface'>
+ Interface that should be updated. This may be a logical interface or
+ the mac address of the interface.
+ </arg>
+
+ <arg type='string' name='ip'>
+ The IP address of the interface. Usually of the form nnn.nnn.nnn.nnn
+ where n is a decimal digit. This format is not enforced. Use IP=NULL
+ to clear.
+ </arg>
+
+ <param type='string' name='iface'>
+ Can be used in place of the iface argument.
+ </param>
+
+ <param type='string' name='ip'>
+ Can be used in place of the ip argument.
+ </param>
+
+
+ <example cmd='set host interface ip compute-0-0 eth1 192.168.0.10'>
+ Sets the IP Address for the eth1 device on host compute-0-0.
+ </example>
+
+ <example cmd='set host interface ip compute-0-0 iface=eth1 ip=192.168.0.10'>
+ Same as above.
+ </example>
+
+ <related>set host interface iface</related>
+ <related>set host interface ip</related>
+ <related>set host interface module</related>
+ <related>add host</related>
+ """
+
+ def run(self, params, args):
+
+ (args, iface, ip) = self.fillPositionalArgs(('iface', 'ip'))
+
+ hosts = self.getHostnames(args)
+
+ if len(hosts) != 1:
+ self.abort('must supply one host')
+ if not iface:
+ self.abort('must supply iface')
+ if not ip:
+ self.abort('must supply ip')
+
+ ip = ip.upper() # null -> NULL
+ if ip != 'NULL':
+ #let's check if this is a valid IPv4 address
+ #for IPv6 use socket.AF_INET6
+ try:
+ socket.inet_pton(socket.AF_INET, ip)
+ except:
+ self.abort("The ip address %s is invalid" % ip )
+
+ for host in hosts:
+ self.db.execute("""update networks set ip=NULLIF('%s','NULL')
+ where node=(select id from nodes where nodes.name="%s") and
+ (networks.device='%s' or networks.mac='%s')""" %
+ (ip, host, iface, iface))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/interface/mac/__init__.py b/lib/sunhpc/commands/set/host/interface/mac/__init__.py
new file mode 100644
index 0000000..a119030
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/interface/mac/__init__.py
@@ -0,0 +1,72 @@
+#coding:utf-8
+
+import sunhpc
+class Command(sunhpc.commands.set.host.command):
+ """
+ Sets the MAC address for the named interface on a host.
+
+ <arg type='string' name='host'>
+ Host name.
+ </arg>
+
+ <arg type='string' name='iface'>
+ Interface that should be updated. This may be a logical interface or
+ the mac address of the interface.
+ </arg>
+
+ <arg type='string' name='mac'>
+ The mac address of the interface. Usually of the form dd:dd:dd:dd:dd:dd
+ where d is a hex digit. This format is not enforced. Use mac=NULL to
+ clear the mac address.
+ </arg>
+
+ <param type='string' name='iface'>
+ Can be used in place of the iface argument.
+ </param>
+
+ <param type='string' name='mac'>
+ Can be used in place of the mac argument.
+ </param>
+
+
+ <example cmd='set host interface mac compute-0-0 eth1 00:0e:0c:a7:5d:ff'>
+ Sets the MAC Address for the eth1 device on host compute-0-0.
+ </example>
+
+ <example cmd='set host interface mac compute-0-0 iface=eth1 mac=00:0e:0c:a7:5d:ff'>
+ Same as above.
+ </example>
+
+ <example cmd='set host interface mac compute-0-0 iface=eth1 mac=NULL'>
+ clears the mac address from the database
+ </example>
+
+ <!-- cross refs do not exist yet
+ <related>set host interface iface</related>
+ <related>set host interface ip</related>
+ <related>set host interface module</related>
+ -->
+ <related>add host</related>
+ """
+
+ def run(self, params, args):
+
+ (args, iface, mac) = self.fillPositionalArgs(('iface', 'mac'))
+
+ hosts = self.getHostnames(args)
+
+ if len(hosts) != 1:
+ self.abort('must supply one host')
+ if not iface:
+ self.abort('must supply iface')
+ if not mac:
+ self.abort('must supply mac')
+
+ for host in hosts:
+ self.db.execute("""update networks set mac=NULLIF('%s','NULL')
+ where
+ node=(select id from nodes where nodes.name="%s") and
+ (networks.device='%s' or networks.mac='%s')""" %
+ (mac, host, iface, iface))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/interface/name/__init__.py b/lib/sunhpc/commands/set/host/interface/name/__init__.py
new file mode 100644
index 0000000..2fddcda
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/interface/name/__init__.py
@@ -0,0 +1,83 @@
+#coding:utf-8
+
+import sunhpc
+class Command(sunhpc.commands.set.host.command):
+ """
+ Sets the logical name of a network interface on a particular host.
+
+ <arg type='string' name='host'>
+ Host name.
+ </arg>
+
+ <arg type='string' name='iface'>
+ Interface that should be updated. This may be a logical interface or
+ the MAC address of the interface.
+ </arg>
+
+ <arg type='string' name='name'>
+ Name of this interface (e.g. newname). This is only the
+ name associated with a certain interface. FQDNs are disallowed.
+ To set the domain or zone for an interface, use the
+ "sunhpc add network" command, and then associate the interface
+ with the network
+ </arg>
+
+ <param type='string' name='iface'>
+ Can be used in place of the iface argument.
+ </param>
+
+ <param type='string' name='name'>
+ Can be used in place of the name argument.
+ </param>
+
+
+ <example cmd='set host interface name compute-0-0 eth1 cluster-0-0'>
+ Sets the name for the eth1 device on host compute-0-0 to
+ cluster-0-0.zonename. The zone is decided by the subnet that the
+ interface is attached to.
+ </example>
+
+ <example cmd='set host interface name compute-0-0 iface=eth1 name=c0-0'>
+ Same as above.
+ </example>
+
+ <!-- cross refs do not exist yet
+ <related>set host interface iface</related>
+ <related>set host interface ip</related>
+ <related>set host interface module</related>
+ -->
+ <related>add host</related>
+ <related>add network</related>
+ """
+
+ def run(self, params, args):
+
+ (args, iface, name) = self.fillPositionalArgs(('iface','name'))
+
+ hosts = self.getHostnames(args)
+
+ if len(hosts) != 1:
+ self.abort('must supply one host')
+
+ # One host only
+ host = hosts[0]
+
+ if not iface:
+ self.abort('must supply iface')
+ if not name:
+ self.abort('must supply name')
+
+ if len(name.split('.')) > 1:
+ self.abort('cannot be fqdn\n' +\
+ 'Please use subnets table to set domain name')
+
+ if name.upper() == "NULL":
+ name = host
+
+ self.db.execute("""update networks set networks.name='%s'
+ where
+ node=(select id from nodes where nodes.name="%s") and
+ (networks.device='%s' or networks.mac='%s')""" %
+ (name, host, iface, iface))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/set/host/interface/subnet/__init__.py b/lib/sunhpc/commands/set/host/interface/subnet/__init__.py
new file mode 100644
index 0000000..dcbd3fc
--- /dev/null
+++ b/lib/sunhpc/commands/set/host/interface/subnet/__init__.py
@@ -0,0 +1,80 @@
+#coding:utf-8
+
+import sunhpc
+class Command(sunhpc.commands.set.host.command):
+ """
+ Sets the subnet for the named interface on one or more hosts.
+
+ <arg type='string' name='host' repeat='1'>
+ One or more named hosts.
+ </arg>
+
+ <arg type='string' name='iface'>
+ Interface that should be updated. This may be a logical interface or
+ the MAC address of the interface.
+ </arg>
+
+ <arg type='string' name='subnet'>
+ The subnet address of the interface. This is a named subnet and must be
+ listable by the command 'sunhpc list network'.
+ </arg>
+
+ <param type='string' name='iface'>
+ Can be used in place of the iface argument.
+ </param>
+
+ <param type='string' name='subnet'>
+ Can be used in place of the subnet argument.
+ </param>
+
+
+ <example cmd='set host interface subnet compute-0-0 eth1 public'>
+ Sets eth1 to be on the public subnet.
+ </example>
+
+ <example cmd='set host interface mac compute-0-0 iface=eth1 subnet=public'>
+ Same as above.
+ </example>
+
+ <!-- cross refs do not exist yet
+ <related>set host interface iface</related>
+ <related>set host interface ip</related>
+ <related>set host interface module</related>
+ -->
+ <related>add host</related>
+ """
+
+ def run(self, params, args):
+
+ (args, iface, subnet) = self.fillPositionalArgs(
+ ('iface', 'subnet'))
+
+ if not len(args):
+ self.abort('must supply host')
+ if not iface:
+ self.abort('must supply iface')
+ if not subnet:
+ self.abort('must supply subnet')
+
+ for host in self.getHostnames(args):
+
+ self.db.execute("""select net.name from
+ networks net, nodes n where
+ n.name='%s' and
+ net.node=n.id and
+ (net.device='%s' or net.mac='%s')""" %
+ (host, iface, iface))
+
+ name, = self.db.fetchone()
+ if not name: name = host
+
+ # Updates the subnet id and the name. The name
+ # is updated even if it did not change (see above)
+ self.db.execute("""update networks set
+ subnet=(select id from subnets s where s.name='%s'),
+ name='%s' where
+ node=(select id from nodes where nodes.name="%s") and
+ (device='%s' or mac='%s')""" %
+ (subnet, name, host, iface, iface))
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/soft/__init__.py b/lib/sunhpc/commands/soft/__init__.py
new file mode 100644
index 0000000..8fc7ab7
--- /dev/null
+++ b/lib/sunhpc/commands/soft/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/soft/autodock/__init__.py b/lib/sunhpc/commands/soft/autodock/__init__.py
new file mode 100644
index 0000000..8161f15
--- /dev/null
+++ b/lib/sunhpc/commands/soft/autodock/__init__.py
@@ -0,0 +1,120 @@
+#
+#coding:utf-8
+#
+#Author : QCSun
+#Email : qcsun@sunhpc.com
+#Times : 2023-04-14 05:21:02
+#WebSite : https://www.sunhpc.com
+
+import os
+import sys
+import sunhpc
+import shutil
+
+class command(sunhpc.commands.soft.command):
+ pass
+
+class Command(command):
+ """
+ Build the Gaussian software.
+
+ <arg type="string" name="version">
+ Specifies the software version, e.g. version=03/09/16
+ </arg>
+
+ <param type="path" name="prefix">
+ Specifies the software install path.
+ </param>
+
+ <param type="path" name="envs">
+ Specifies the software env path.
+ </param>
+
+ <param type="path" name="source">
+ Specifies the software source path, e.g. /mnt/usb
+ </param>
+
+ <example cmd='soft gaussian prefix=/share/apps/soft version=16'>
+ Install the Gaussian software.
+ </example>
+ """
+ def run(self, params, args):
+
+ (prefix, version, source, envs) = self.fillParams([
+ ('prefix', '/share/apps/soft'),
+ ('version', None),
+ ('source', '/mnt/usb'),
+ ('envs', '/share/apps/envs'),
+ ])
+
+ if not version:
+ self.msg('must supply a "Gaussian version", e.g. version=03/09/16', 'a')
+
+ try:
+ os.makedirs(prefix)
+ self.msg("The %s directory does not exist,and it will be created." % prefix, 'w')
+ except FileExistsError:
+ pass
+
+ if not os.path.exists(envs):
+ os.makedirs(envs)
+
+ softname = 'Gaussian%s' % version
+
+ softdirs = os.path.join(source, 'hpcsoft')
+ softlist = os.listdir(softdirs)
+
+ if 'Gaussian' not in softlist:
+ self.msg('The "%s" software was not found in the %s directory.' % (softname, softdirs), 'a')
+
+ gspathname = os.path.join(softdirs, 'Gaussian', softname, 'Gaussian16-a03.tbz')
+
+
+ gaussian = os.path.join(prefix, 'g%s' % version)
+ if os.path.exists(gaussian):
+ self.msg('The %s already exists; to reinstall it, remove it first.' % gaussian)
+ else:
+ self.msg('Start installing the %s software to the %s directory...' % (softname, prefix))
+ os.system('tar -xf %s -C %s' % (gspathname, prefix))
+
+ gsenv = os.path.join(envs, 'g16-env.sh')
+ with open(gsenv, 'w') as f:
+ f.write('#!/bin/sh\n')
+ f.write('#\n# %s env config\n#\n\n' % gaussian)
+ f.write('export g%sroot=%s\n' % (version, prefix))
+ f.write('source $g%sroot/g%s/bsd/g%s.profile\n\n' % (version, version, version))
+ f.write('export GAUSS_SCDIR=~/gstmp\n')
+
+ # create shared user and group.
+ self.msg('Create a shared group to run the %s software.' % softname)
+ self.msg(' 1, groupadd -g 888 public ')
+ self.msg(' 2, usermod -G public dell ')
+ self.msg(' 3, chown -R root:public %s/g%s' % (prefix, version))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/soft/gaussian/__init__.py b/lib/sunhpc/commands/soft/gaussian/__init__.py
new file mode 100644
index 0000000..1be8725
--- /dev/null
+++ b/lib/sunhpc/commands/soft/gaussian/__init__.py
@@ -0,0 +1,152 @@
+#
+#coding:utf-8
+#
+#Author : QCSun
+#Email : qcsun@sunhpc.com
+#Times : 2023-04-14 05:21:02
+#WebSite : https://www.sunhpc.com
+
+import os
+import sys
+import sunhpc
+import shutil
+
+class command(sunhpc.commands.soft.command):
+ pass
+
+class Command(command):
+ """
+ Build the Gaussian software.
+
+ <arg type="string" name="version">
+ Specifies the software version, e.g. version=03/09/16
+ </arg>
+
+ <param type="string" name="ext">
+ Only the Gaussian09 version uses it, e.g. ext=a01/d01/e01.
+ Default: E01
+ </param>
+
+ <param type="path" name="prefix">
+ Specifies the software install path.
+ Default: /share/apps/soft/gaussian
+ </param>
+
+ <param type="path" name="envs">
+ Specifies the software env path.
+ Default: /share/apps/envs
+ </param>
+
+ <param type="path" name="source">
+ Specifies the software source path, e.g. /mnt/usb
+ Default: /mnt/usb
+ </param>
+
+ <example cmd='soft gaussian prefix=/share/apps/soft version=16'>
+ Install the Gaussian software.
+ </example>
+ """
+ def run(self, params, args):
+
+ (prefix, version, source, envs, ext) = self.fillParams([
+ ('prefix', '/share/apps/soft/gaussian'),
+ ('version', None),
+ ('source', '/mnt/usb'),
+ ('envs', '/share/apps/envs'),
+ ('ext', 'e01'),
+ ])
+
+ if len(args):
+ version = args[0]
+
+ if not version:
+ self.msg('must supply a "Gaussian version", e.g. version=03/09/16', 'a')
+
+ try:
+ os.makedirs(prefix)
+ self.msg("The %s directory does not exist,and it will be created." % prefix, 'w')
+ except FileExistsError:
+ pass
+
+ if not os.path.exists(envs):
+ os.makedirs(envs)
+
+ if version == '03':
+ if ext in ['d01']:
+ basename = 'g03-d01'
+ else:
+ basename = 'g03-std'
+
+ elif version == '09':
+ basename = 'g09-%s' % ext.lower()
+
+ elif version == '16':
+ basename = 'g16-a03'
+
+ else:
+ self.msg('version error, must be 03/09/16', 'a')
+
+ filename = '%s.tar.bz2' % basename
+
+
+ pathname = 'Gaussian%s' % version
+ softname = os.path.join(source, 'hpcsoft/Gaussian', pathname, filename)
+
+ if not os.path.exists(softname):
+ self.msg('The "%s" not found.' % softname, 'a')
+
+ gaussian = os.path.join(prefix, basename)
+
+
+ softlist = os.listdir(prefix)
+ print ('------softlist-----', softlist)
+ print ('------basename-----', basename)
+ if basename in softlist:
+ self.msg('The %s already exists; to reinstall it, remove it first.' % gaussian)
+ else:
+ self.msg('Start installing the %s software to the %s directory...' % (basename, prefix))
+ os.system('tar -xf %s -C %s' % (softname, prefix))
+
+ gsenv = os.path.join(envs, '%s-env.sh' % basename)
+ with open(gsenv, 'w') as f:
+ f.write('#!/bin/sh\n')
+ f.write('#\n# %s env config\n#\n\n' % gaussian)
+ f.write('export g%sroot=%s\n' % (version, prefix))
+ f.write('source $g%sroot/g%s/bsd/g%s.profile\n\n' % (version, version, version))
+ f.write('export GAUSS_SCDIR=~/gstmp\n')
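+
+ # For reference (illustrative, values depend on the chosen parameters):
+ # with version=16 and the default prefix, the generated g16-a03-env.sh
+ # would contain roughly:
+ #   export g16root=/share/apps/soft/gaussian
+ #   source $g16root/g16/bsd/g16.profile
+ #   export GAUSS_SCDIR=~/gstmp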
+
+ # create shared user and group.
+ self.msg('Create a shared group to run the %s software.' % softname)
+ self.msg(' 1, groupadd -g 888 public ')
+ self.msg(' 2, usermod -G public dell ')
+ self.msg(' 3, chown -R root:public %s/%s' % (prefix, basename))
+ self.msg('')
+ self.msg(' source %s' % gsenv)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/commands/sync/__init__.py b/lib/sunhpc/commands/sync/__init__.py
new file mode 100644
index 0000000..0a924f2
--- /dev/null
+++ b/lib/sunhpc/commands/sync/__init__.py
@@ -0,0 +1,7 @@
+#coding:utf-8
+
+import sunhpc
+class command(sunhpc.commands.Command):
+ pass
+
+RollName = 'base'
diff --git a/lib/sunhpc/commands/sync/config/__init__.py b/lib/sunhpc/commands/sync/config/__init__.py
new file mode 100644
index 0000000..029d9a9
--- /dev/null
+++ b/lib/sunhpc/commands/sync/config/__init__.py
@@ -0,0 +1,29 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+class Command(sunhpc.commands.sync.command):
+ """
+ For each system configuration file controlled by Sunhpc, first
+ rebuild the configuration file by extracting data from the
+ database, then restart the relevant services.
+
+ <example cmd='sync config'>
+ Rebuild all configuration files and restart relevant services.
+ </example>
+ """
+ def run(self, params, args):
+ #
+ # don't call insert-ethers if insert-ethers is already
+ # running. this can occur when one replaces a node
+ # (insert-ethers calls 'sunhpc remove host' which calls
+ # sunhpc sync config).
+ #
+ if not os.path.exists('/var/lock/insert-ethers'):
+ #
+ # run the plugins
+ #
+ self.runPlugins()
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/sync/config/plugin_00_safe.py b/lib/sunhpc/commands/sync/config/plugin_00_safe.py
new file mode 100644
index 0000000..20fd5a6
--- /dev/null
+++ b/lib/sunhpc/commands/sync/config/plugin_00_safe.py
@@ -0,0 +1,18 @@
+#coding:utf-8
+
+import sunhpc
+class Plugin(sunhpc.commands.Plugin):
+ """
+ Configure the sunhpc cluster security.
+ """
+ def run(self, args):
+
+ with open('/opt/sunhpc/etc/safeputrc', 'w') as fe:
+ fe.write('<?xml version="1.0" standalone="yes"?>\n')
+ fe.write('<safeput>\n')
+ fe.write('\t<PrivateNetwork id="%s" mask="%s"/>\n' % (
+ self.db.getHostAttr('localhost', 'Kickstart_PrivateAddress'),
+ self.db.getHostAttr('localhost', 'Kickstart_PrivateNetmaskCIDR')))
+ fe.write('</safeput>\n')
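+
+ # Illustrative result (actual values come from the host attributes), e.g.
+ # with Kickstart_PrivateAddress=10.1.1.1 and Kickstart_PrivateNetmaskCIDR=24:
+ #   <?xml version="1.0" standalone="yes"?>
+ #   <safeput>
+ #       <PrivateNetwork id="10.1.1.1" mask="24"/>
+ #   </safeput>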
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/sync/config/plugin_05_sshd.py b/lib/sunhpc/commands/sync/config/plugin_05_sshd.py
new file mode 100644
index 0000000..6eaf33f
--- /dev/null
+++ b/lib/sunhpc/commands/sync/config/plugin_05_sshd.py
@@ -0,0 +1,11 @@
+#coding:utf-8
+
+import sunhpc
+import subprocess
+class Plugin(sunhpc.commands.Plugin):
+
+ def run(self, args):
+ cmd = "report knownhosts > /etc/ssh/ssh_known_hosts"
+ subprocess.call("/opt/sunhpc/bin/sunhpc %s 2>/dev/null" % cmd, shell=True)
+
+RollName = "base"
diff --git a/lib/sunhpc/commands/sync/users/__init__.py b/lib/sunhpc/commands/sync/users/__init__.py
new file mode 100644
index 0000000..b9c50e9
--- /dev/null
+++ b/lib/sunhpc/commands/sync/users/__init__.py
@@ -0,0 +1,21 @@
+#coding:utf-8
+
+import sunhpc
+class Command(sunhpc.commands.sync.command):
+ """
+ Update all user-related files (e.g., /etc/passwd, /etc/shadow, etc.)
+ on all known hosts. Also, restart autofs on all known hosts.
+
+ <example cmd='sync users'>
+ Send all user info to all known hosts.
+ </example>
+ """
+ def run(self, params, args):
+
+ #
+ # fix /etc/passwd
+ #
+ self.runPlugins()
+
+ # Encrypt file to /etc/safe.d directory.
+ self.command('create.security.users')
diff --git a/lib/sunhpc/commands/sync/users/plugin_00_fixmaster.py b/lib/sunhpc/commands/sync/users/plugin_00_fixmaster.py
new file mode 100644
index 0000000..e8bf082
--- /dev/null
+++ b/lib/sunhpc/commands/sync/users/plugin_00_fixmaster.py
@@ -0,0 +1,19 @@
+#coding:utf-8
+
+import os
+import sunhpc
+class Plugin(sunhpc.commands.Plugin):
+ """Relocates home directories to location on file server and fixes autofs.share"""
+
+ def run(self, args):
+ """修复auto.share文件"""
+
+ # Default auto.master entries
+ share_mount = '/share /etc/auto.share --timeout=1200'
+ homes_mount = '/home /etc/auto.home --timeout=1200'
+ auto_master = '/etc/auto.master'
+
+ if self.cmd.matchText(auto_master, share_mount):
+ with open(auto_master, 'w') as f:
+ f.write('%s\n' % share_mount)
+ f.write('%s\n' % homes_mount)
diff --git a/lib/sunhpc/commands/sync/users/plugin_05_share.py b/lib/sunhpc/commands/sync/users/plugin_05_share.py
new file mode 100644
index 0000000..94e6d94
--- /dev/null
+++ b/lib/sunhpc/commands/sync/users/plugin_05_share.py
@@ -0,0 +1,20 @@
+#coding:utf-8
+
+import os
+import sunhpc
+class Plugin(sunhpc.commands.Plugin):
+ """Relocates home directories to location on file server and fixes autofs.share"""
+
+ def run(self, args):
+ """修复auto.share文件"""
+
+ # Default auto.share entry
+ hostname = '%s.%s' % (self.db.getFrontendName(),
+ self.db.getHostAttr('localhost', 'Kickstart_PrivateDNSDomain'))
+
+ shared = '/etc/auto.share'
+ content = 'apps %s:/export/&' % (hostname)
+
+ if self.cmd.matchText(shared, content):
+ with open(shared, 'w') as f:
+ f.write('%s\n' % content)
diff --git a/lib/sunhpc/commands/sync/users/plugin_10_fixusers.py b/lib/sunhpc/commands/sync/users/plugin_10_fixusers.py
new file mode 100644
index 0000000..ae083d9
--- /dev/null
+++ b/lib/sunhpc/commands/sync/users/plugin_10_fixusers.py
@@ -0,0 +1,121 @@
+#coding:utf-8
+
+import os
+import sunhpc
+class Plugin(sunhpc.commands.Plugin):
+ """Relocates home directories to location on file server and fixes autofs.share"""
+
+ def run(self, args):
+ """修复auto.home文件"""
+
+ auto_users, pwd_users, new_users = [], [], []
+
+ # First read the user names already present in the autofs map
+ auto_home = '/etc/auto.home'
+ if os.path.exists(auto_home):
+ with open(auto_home, 'r') as f:
+ for li in f.readlines():
+ auto_users.append(li.split()[0])
+
+ # De-duplicate and sort
+ auto_users = sorted(list(set(auto_users)))
+
+ # Default home directory root for new accounts
+ default_dir = '/export/home/'
+
+ # fix /etc/default/useradd command
+ userhome = 'HOME=%s' % default_dir[:-1]
+ useradd = '/etc/default/useradd'
+ if self.cmd.matchText(useradd, userhome):
+ data = []
+ with open(useradd, 'r') as fe:
+ for line in fe:
+ if line.startswith('HOME='):
+ data.append(userhome)
+ continue
+
+ data.append(line.strip())
+
+ with open(useradd, 'w') as f:
+ f.write('\n'.join(data))
+
+ # Read new users from /etc/passwd.
+ fe = open('/etc/passwd', 'r')
+ for line in fe.readlines():
+
+ l = line[:-1].split(':')
+ if len(l) < 6: continue
+
+ username = l[0]
+ homedirs = l[5]
+
+ # Collect users whose home directory is under '/export/home'
+ if homedirs[:len(default_dir)] == default_dir:
+ pwd_users.append(username)
+
+ # Skip users with uid &lt; 1000 and users already listed in auto_users;
+ # users already in auto_users do not need another usermod switch.
+ if self.handler_uid(line) and username not in auto_users:
+ auto_users.append(username)
+
+ fe.close()
+
+ # cluster.local
+ # Get the frontend (control node) domain name
+ hostname = self.db.getHostAttr('localhost', 'Info_HomeDirSrv')
+ if not hostname:
+ hostname = '%s.%s' % (self.db.getFrontendName(),
+ self.db.getHostAttr('localhost', 'Kickstart_PrivateDNSDomain'))
+
+ # Append custom mount options if any were defined.
+ options = self.db.getHostAttr('localhost', 'Info_HomeDirOptions')
+ if options:
+ options = '\t-' + options
+ else:
+ options = ""
+
+ # Clear the contents of /etc/auto.home.
+ #open(auto_home, 'w').close()
+
+ # Fix the users' home directories.
+ for user in pwd_users:
+ cmd = '/usr/sbin/usermod -d %s %s' % (os.path.join('/home', user), user)
+ for line in os.popen(cmd).readlines():
+ self.cmd.addText(line)
+
+ # Automatically generate authorized_keys
+ rootssh = os.path.join(default_dir, user, '.ssh')
+ if not os.path.exists(rootssh):
+ os.makedirs(rootssh)
+
+ rootrsa = os.path.join(rootssh, 'id_rsa')
+ if not os.path.exists(rootrsa):
+ self.cmd.command('repair.users.authorized', [user])
+
+ content = []
+ # Merge the autofs users and the new passwd users into new_users.
+ new_users.extend(auto_users)
+ new_users.extend(pwd_users)
+ new_users = sorted(list(set(new_users)))
+ for user in new_users:
+ # Update the auto.home file.
+ # /export/home/dell
+ new_user_dir = os.path.join(default_dir, user)
+
+ # dell cluster.local:/export/home/dell
+ autofs_entry = '%s%s\t%s:%s\n' % (user, options, hostname, new_user_dir)
+ content.append(autofs_entry)
+
+ with open(auto_home, 'w') as f:
+ f.write(''.join(content))
+
+ def handler_uid(self, x):
+ l = x.split(':')
+ if int(l[2]) < 1000:
+ return False
+ if l[0] in self.avoid_uname():
+ return False
+ return True
+
+ def avoid_uname(self):
+ return ['nobody', 'nobody4', 'noaccess', 'nfsnobody']
diff --git a/lib/sunhpc/core/build.py b/lib/sunhpc/core/build.py
new file mode 100644
index 0000000..13514cc
--- /dev/null
+++ b/lib/sunhpc/core/build.py
@@ -0,0 +1,585 @@
+#coding:utf-8
+import os
+import re
+import sys
+import xml
+import time
+import sunhpc
+import shutil
+import socket
+import tempfile
+import subprocess
+class BuildError(Exception):
+ pass
+
+class Builder:
+
+ def __init__(self):
+ self.verbose = 0
+ self.debug = 0
+
+ def build(self):
+ pass
+
+ def setVerbose(self, level=1):
+ self.verbose = level
+
+ def setDebug(self, level=1):
+ self.debug = level
+
+class MirrorBuilder(Builder):
+
+ def __init__(self, m):
+ Builder.__init__(self)
+ self.mirrors = m
+
+ def build(self):
+ for m in self.mirrors:
+ dirs = []
+ if m.getRemoteReleasePath():
+ dirs.append(m.getRemoteReleasePath())
+ for dir in dirs:
+ self.buildMirror(m.getHost(), dir)
+
+
+ def buildMirror(self, host, path):
+ # Try FTP first, failover to HTTP
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ try:
+ sock.connect((host, 21))
+ sock.close()
+ cmd = 'wget -m -nv ftp://%s//%s/' % (host, path)
+ except socket.error:
+ cmd = 'wget -m -nv -np http://%s//%s/' % (host, path)
+ sock = None
+
+ if self.verbose or self.debug:
+ print (cmd)
+ if not self.debug:
+ subprocess.call(cmd, shell=True)
+
+class DistributionBuilder(Builder):
+
+ def __init__(self, dist, links=1):
+ Builder.__init__(self)
+ self.cmd = None
+ self.dist = dist
+ self.quiet = None
+ self.useLinks = links
+ self.compsPath = None
+ self.useRolls = {}
+ self.allRolls = 1
+ self.onlyRolls = 0
+ self.version = '1.0'
+ self.calcmd5 = 1
+ self.withSiteProfiles = 0
+ for mirror in self.dist.getMirrors():
+ if not mirror.isBuilt():
+ mirror.build()
+
+ if not self.dist.isBuilt():
+ self.dist.build()
+
+ def setCommand(self, command):
+ self.cmd = command
+ def setQuiet(self, quiet):
+ self.quiet = quiet
+
+ def setRolls(self, list, only=0):
+ if list:
+ for e in list:
+ self.useRolls[e[0]] = (e[1], e[2])
+ self.allRolls = 0
+ else:
+ self.useRolls = {}
+ self.allRolls = 1
+ self.onlyRolls = only
+
+ def setVersion(self, ver):
+ self.version = ver
+
+ def setSiteProfiles(self, bool):
+ self.withSiteProfiles = bool
+
+ def setCalcMD5(self, bool):
+ self.calcmd5 = bool
+
+ def clean(self):
+ self.dist.getTree('release').apply(self.cleaner)
+
+ def useRoll(self, key, ver, arch):
+ "Returns true if we should include this roll"
+ if arch == self.dist.arch:
+ if self.allRolls:
+ return 1
+ if key in self.useRolls:
+ version, enabled = self.useRolls[key]
+ if enabled and version == ver:
+ return 1
+ return 0
+
+ def getRollBaseFiles(self):
+ # Collect the files contained in the rolls. The roll name, architecture,
+ # and version must all match the database before files are collected.
+ files = []
+ for m in self.dist.getMirrors():
+ # m: dist-Mirror class
+ for key, value in m.getRolls().items():
+ # key-value : base [('x86_64', '7.0.0')]
+ # key-value : kernel [('x86_64', '7.0.0')]
+ for arch, ver in value:
+ if self.useRoll(key, ver, arch):
+ if not self.quiet:
+ self.cmd.msg(' including roll (%s,%s) "%s" ...' % (ver, arch, key))
+ files.extend(m.getRollBaseFiles(key, ver, arch))
+ return files
+
+ def getRollRPMS(self):
+ rpms = []
+ for m in self.dist.getMirrors():
+ for key, value in m.getRolls().items():
+ for arch, ver in value:
+ if self.useRoll(key, ver, arch):
+ if not self.quiet:
+ self.cmd.msg(' including roll (%s,%s) "%s" ...' % (ver, arch, key))
+ rpms.extend(m.getRollRPMS(key, ver, arch))
+ return rpms
+
+ def getRollSRPMS(self):
+ rpms = []
+ for m in self.dist.getMirrors():
+ for key, value in m.getRolls().items():
+ for arch, ver in value:
+ if self.useRoll(key,ver,arch):
+ if not self.quiet:
+ self.cmd.msg(' including roll (%s,%s) "%s" ...' % (ver, arch, key))
+ rpms.extend(m.getRollSRPMS(key, ver, arch))
+ return rpms
+
+ def buildRPMSList(self):
+ rpms = self.getRollRPMS()
+ for mirror in self.dist.getMirrors():
+ rpms.extend(mirror.getRPMS())
+ if not self.onlyRolls:
+ rpms.extend(self.dist.getContribRPMS())
+ rpms.extend(self.dist.getLocalRPMS())
+ if not os.path.isdir(self.dist.getForceRPMSPath()):
+ os.makedirs(self.dist.getForceRPMSPath())
+ else:
+ rpms.extend(self.dist.getForceRPMS())
+ return rpms
+
+ def buildSRPMSList(self):
+ rpms = self.getRollSRPMS()
+ for mirror in self.dist.getMirrors():
+ rpms.extend(mirror.getSRPMS())
+ rpms.extend(self.dist.getContribSRPMS())
+ rpms.extend(self.dist.getLocalSRPMS())
+ return rpms
+
+ def buildRollLinks(self):
+ """Links all rolls from our mirrors into rocks-dist/rolls/"""
+
+ print("Building Roll Links")
+ rollLocation = self.dist.getRollsPath()
+ subprocess.call('mkdir -p %s' % rollLocation, shell=True)
+
+ rolls = []
+ for mirror in self.dist.getMirrors():
+ rolldir = mirror.getRollsPath()
+ if not os.path.exists(rolldir):
+ continue
+ for d in os.listdir(rolldir):
+ rollpath = os.path.join(rolldir,d)
+ if os.path.isdir(rollpath):
+ rolls.append(rollpath)
+
+ here = os.getcwd()
+ os.chdir(rollLocation)
+ for r in rolls:
+ subprocess.call('ln -sf %s .' % (r), shell=True)
+ os.chdir(here)
+
+ def buildBase(self):
+ if not self.quiet:
+ baselen = len(self.resolveVersions(self.getRollBaseFiles()))
+ self.cmd.msg('Resolving versions (base - %s)' % baselen)
+
+ self.dist.setBaseFiles(self.resolveVersions(self.getRollBaseFiles()))
+
+ def touchCriticalFiles(self, m, key, ver, arch):
+ criticalfiles = [ 'anaconda', 'anaconda-runtime', 'kudzu', 'kudzu-devel' ]
+ for rpm in m.getRollRPMS(key,ver,arch):
+ try:
+ if rpm.getPackageName() in criticalfiles:
+ rpm.timestamp = int(time.time())
+ except:
+ pass
+
+ def includeCriticalRPMS(self):
+ if not self.quiet:
+ self.cmd.msg(' including "critical" RPMS')
+ for m in self.dist.getMirrors():
+ for key, value in m.getRolls().items():
+ if key != 'base':
+ continue
+ for arch, ver in value:
+ if self.useRoll(key, ver, arch):
+ self.touchCriticalFiles(m,key,ver,arch)
+
+ def buildRPMS(self):
+ if not self.quiet:
+ rpms = len(self.resolveVersions(self.buildRPMSList()))
+ self.cmd.msg('Resolving versions (RPMs - %s)' % rpms)
+ self.dist.setRPMS(self.resolveVersions(self.buildRPMSList()))
+
+ def buildSRPMS(self):
+ if not self.quiet:
+ srpms = len(self.resolveVersions(self.buildSRPMSList()))
+ self.cmd.msg('Resolving versions (SRPMs - %s)' % srpms)
+ self.dist.setSRPMS(self.resolveVersions(self.buildSRPMSList()))
+
+ def insertNetstage(self):
+ if not self.quiet:
+ self.cmd.msg('Rebuilding the images...')
+
+ # cmd = 'rm -f %s/RedHat/base/install.img' % (self.dist.getReleasePath())
+ # subprocess.call(cmd, shell=True)
+
+ imgFiles = ['install', 'initrds', 'efiboot', 'product', 'updates']
+ for i in imgFiles:
+ if not self.quiet:
+ self.cmd.msg(' Applying %s images...' % i)
+ try:
+ self.applyRPM('sunhpc-anaconda-%s' % i, self.dist.getReleasePath())
+ except:
+ self.cmd.msg("Couldn't find the package sunhpc-anaconda-%s" % i, 'w')
+
+ def build(self):
+ #Start SUNHPC Build
+ self.cmd.msg('Starting to build the sunhpc distro release ...')
+
+ self.clean()
+ self.dist.syncMirror()
+
+ # Start collecting the relevant files.
+ self.buildBase()
+ self.includeCriticalRPMS()
+ self.buildRPMS()
+ self.buildSRPMS()
+
+ if not self.quiet:
+ self.cmd.msg('Start creating files ...')
+ if self.useLinks:
+ if not self.quiet:
+ self.cmd.msg(' (symbolic links - fast)')
+ else:
+ if not self.quiet:
+ self.cmd.msg(' (use deep copy - slow)')
+
+ self.dist.getReleaseTree().apply(self.builder)
+ self.dist.getReleaseTree().apply(self.normalizer)
+
+ # install.img and updates.img
+ self.insertNetstage()
+
+ # install roll-*-kickstart
+ self.buildKickstart()
+
+ # write treeinfo
+ self.buildTreeInfo()
+
+ if not self.quiet:
+ self.cmd.msg('Calling Yum genpkgmetadata.py')
+ self.createrepo()
+
+ if not self.quiet:
+ self.cmd.msg('Rebuilding Product Image including md5 sums')
+ #self.buildProductImg()
+
+ if not self.quiet:
+ self.cmd.msg('Creating Directory Listing')
+ self.makeDirListing()
+
+ if not self.quiet:
+ self.cmd.msg('The sunhpc cluster system distribution build is complete.')
+
+ def buildTreeInfo(self):
+ if not self.quiet:
+ self.cmd.msg(' installing profiles "treeinfo" ...')
+ treeinfo = os.path.join(self.dist.getReleasePath(), '.treeinfo')
+ self.cmd.makeTreeInfo(treeinfo)
+
+ def buildKickstart(self):
+ if not self.quiet:
+ self.cmd.msg('Installing XML Kickstart profiles')
+
+ build = self.dist.getBuildPath()
+ for rpm in self.dist.getRPMS():
+ tok = rpm.getBaseName().split('-')
+ if tok[0] != 'roll':
+ continue
+ try:
+ k = tok.index('kickstart')
+ rollname = '-'.join(tok[1:k])
+ except ValueError:
+ continue
+
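+ # e.g. (illustrative): an RPM base name "roll-base-kickstart" splits into
+ # ['roll', 'base', 'kickstart'], so rollname becomes 'base'.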
+ if not self.quiet:
+ self.cmd.msg(' installing profiles from "%s" ...' % rollname)
+ self.applyRPM(rpm.getBaseName(), build)
+
+ # install the cgi-module-kickstart
+ #self.applyRPM('cgi-module-kickstart', build)
+
+ # Copy local profiles into the distribution.
+ if self.withSiteProfiles:
+ if not self.quiet:
+ self.cmd.msg(' installing profiles from "site" ...')
+ tree = self.dist.getSiteProfilesTree()
+ for dir in tree.getDirs():
+ for file in tree.getFiles(dir):
+ path = os.path.join(build, dir)
+ if not os.path.isdir(path):
+ os.makedirs(path)
+ shutil.copy(file.getFullName(),
+ os.path.join(path, file.getName()))
+ # make sure apache can read site XML
+ file.chmod(0o664)
+
+ def applyRPM(self, name, root, flags=''):
+ rpm = None
+ try:
+ rpm = self.dist.getRPM(name)
+ except sunhpc.core.dist.DistRPMList as e:
+ for r in e.list:
+ if r.getPackageArch() == self.dist.getArch():
+ rpm = r
+ break
+
+ if not rpm:
+ raise ValueError("could not find %s" % name)
+
+ dbdir = os.path.join(root, 'var', 'lib', 'rpm')
+ if not os.path.isdir(dbdir):
+ os.makedirs(dbdir)
+
+ reloc = subprocess.call("rpm -q --queryformat '%{prefixes}\n' -p " +
+ rpm.getFullName() + "| grep none > /dev/null", shell=True)
+
+ cmd = 'rpm -i --ignoresize --nomd5 --force --nodeps --ignorearch '
+ cmd += '--dbpath %s ' % dbdir
+ if reloc:
+ cmd = cmd + '--prefix %s %s %s' % (root, flags, rpm.getFullName())
+ else:
+ cmd = cmd + '--badreloc --relocate /=%s %s %s' % (root, flags, rpm.getFullName())
+ retval = subprocess.call(cmd + ' > /dev/null 2>&1', shell=True)
+
+ shutil.rmtree(os.path.join(root, 'var'))
+ if retval == 256:
+ raise BuildError("could not apply RPM %s" % (name))
+ return retval
+
+ def buildProductImg(self):
+ product = '../../images/product.img'
+ productfilesdir = os.path.join(self.dist.getBuildPath(), 'include')
+ if not os.path.exists(productfilesdir):
+ #
+ # there are no 'product' files, so there's nothing to do.
+ # let's just return
+ #
+ return
+
+ if not self.quiet:
+ self.cmd.msg(' Create product images...')
+ cwd = os.getcwd()
+ #
+ # make an MD5 checksum for all files in the distribution
+ # the 'sed' command strips off the leading "./" from the pathnames
+ # don't include the build, SRPMS and force directories
+ #
+ os.chdir(self.dist.getReleasePath())
+ if self.calcmd5:
+ cmd = '/usr/bin/md5sum `find -L . -type f | sed "s/^\.\///" | '
+ cmd += 'egrep -v "^build|^SRPMS|^force" | egrep -v "rpm$"` '
+ cmd += '> %s/packages.md5' % (productfilesdir)
+ else:
+ cmd = 'touch %s/packages.md5' % (productfilesdir)
+
+ subprocess.call(cmd, shell=True)
+ #
+ # create the product.img file
+ #
+ os.chdir(productfilesdir)
+
+ if not os.path.exists('../../images'):
+ os.makedirs('../../images')
+
+ subprocess.call('rm -f %s' % (product), shell=True)
+ cmd = '/sbin/mksquashfs packages.md5 applets pyanaconda '
+ cmd += '%s ' % (product)
+ cmd += '-keep-as-directory > /dev/null 2>&1'
+ subprocess.call(cmd,shell=True)
+
+ if os.path.exists(product):
+ #
+ # on a server installation (e.g., frontend), mksquashfs
+ # fails, but it is not important that product.img is built
+ # during the installation. product.img was already downloaded
+ # off the CD, so it will not be needed for the remainder of
+ # the server installation.
+ #
+ os.chmod(product, 0o664)
+ os.chdir(cwd)
+ return
+
+ def createrepo(self):
+ if not self.quiet:
+ self.cmd.msg('Creating repository ......')
+
+ cwd = os.getcwd()
+ releasedir = self.dist.getReleasePath()
+ os.chdir(releasedir)
+ #
+ # first check in the install environment (/tmp/updates), then
+ # look in the 'normal' place (on a running frontend).
+ #
+ createrepo = '/tmp/updates/usr/share/createrepo/genpkgmetadata.py'
+ if not os.path.exists(createrepo):
+ createrepo = '/usr/share/createrepo/genpkgmetadata.py'
+ if not self.quiet:
+ self.cmd.msg(' Using genpkgmetadata to create the repository metadata ...')
+ groupfile = "%s/RedHat/base/comps.xml" % releasedir
+ if os.path.exists(groupfile):
+ gf = "--groupfile %s/RedHat/base/comps.xml " % (releasedir)
+ else:
+ self.cmd.msg("Couldn't find the groupfile %s" % groupfile, 'w')
+ self.cmd.msg("\tIf you are bootstrapping, this is not a problem", 'w')
+ gf = " "
+
+ tmpdir = os.getenv("TMPDIR")
+
+ # worker.py (Called by genpkgmetadata) needs tmp space
+ os.putenv("TMPDIR",".")
+ subprocess.call('%s ' % (createrepo) + gf + ' --workers 8 ' + '--quiet .', shell=True)
+
+ if tmpdir is not None:
+ os.putenv("TMPDIR",tmpdir)
+ else:
+ os.unsetenv("TMPDIR")
+ os.chdir(cwd)
+ return
+
+ def makeDirListing(self):
+ #
+ # This function runs only when 'sunhpc create distro' is executed
+ # in the /works/sunhpc/install directory.
+ #
+ path = os.path.join(self.dist.getRootPath(), 'rolls')
+ if os.path.exists(path):
+ filename = os.path.join(path, 'index.cgi')
+
+ file = open(filename, 'w')
+ file.write('%s' % (directory_listing_cgi))
+ file.close()
+
+ os.chmod(path, 0o755)
+ os.chmod(filename, 0o755)
+ return
+
+ def cleaner(self, path, file, root):
+ if not root:
+ root = self.dist.getReleasePath()
+ dir = os.path.join(root, path)
+ if dir not in [ self.dist.getForceRPMSPath() ]:
+ os.unlink(os.path.join(dir, file.getName()))
+
+ def builder(self, path, file, root):
+ # path : 'RedHat/RPMS'
+ if not root:
+ # root : /root/test/tmpxxxxx/x86_64
+ root = self.dist.getReleasePath()
+ dir = os.path.join(root, path)
+ fullname = os.path.join(dir, file.getName())
+
+ if file.getFullName() == fullname:
+ return
+
+ if not os.path.isdir(dir):
+ os.makedirs(dir)
+
+ if self.useLinks:
+ file.symlink(fullname, self.dist.getRootPath())
+ else:
+ if os.path.islink(file.getFullName()):
+ os.symlink(os.readlink(file.getFullName()), fullname)
+ else:
+ shutil.copy(file.getFullName(), fullname)
+ os.utime(fullname, (file.getTimestamp(), file.getTimestamp()))
+
+ def normalizer(self, path, file, root):
+ if not root:
+ root = self.dist.getReleasePath()
+ dir = os.path.join(root, path)
+ fullname = os.path.join(dir, file.getName())
+ if file.getFullName() != fullname:
+ file.setFile(fullname)
+
+ def resolveVersions(self, files):
+ dict = {}
+ for e in files:
+ name = e.getUniqueName() # name w/ arch string appended
+ if name not in dict or e.__cmp__(dict[name]) >= 0:
+ dict[name] = e
+
+ list = []
+ for e in dict.keys():
+ list.append(dict[e])
+ return list
+
+ def setComps(self, path):
+ self.compsPath = path
+
+directory_listing_cgi = """#!/opt/sunpy3/bin/python3
+import os
+try:
+ dir = os.environ['DOCUMENT_ROOT'] + os.environ['REQUEST_URI']
+except:
+ dir = '.'
+
+if os.path.isfile(dir):
+ dir, filename = os.path.split(dir)
+
+out = ''
+out += '<html>'
+out += '<body>'
+out += '<table>'
+
+#
+# test environ value
+#for i in os.environ:
+# out += '\n%s --- %s\n' % (i, os.environ[i])
+
+listing = os.listdir(dir)
+listing.sort(key=str.lower)
+for file in listing:
+ if file not in [ 'index.cgi' ]:
+ out += '<tr><td>\\n'
+
+ if os.path.isdir(os.path.join(dir, file)):
+ out += '<a href="%s/">%s/</a>\\n' % (file, file)
+ else:
+ out += '<a href="%s">%s</a>\\n' % (file, file)
+
+ out += '</td></tr>'
+ out += '\\n'
+
+out += '</table>'
+out += '</body>'
+out += '</html>'
+
+print ('Content-type: text/html')
+print ('Content-length: %d' % (len(out)))
+print ('')
+print (out)
+"""
diff --git a/lib/sunhpc/core/dist.py b/lib/sunhpc/core/dist.py
new file mode 100644
index 0000000..647c2c4
--- /dev/null
+++ b/lib/sunhpc/core/dist.py
@@ -0,0 +1,391 @@
+#coding:utf-8
+import os
+import sunhpc
+import xml.sax
+class DistError(Exception):
+ pass
+
+class DistRPMList(DistError):
+ def __init__(self, list):
+ Exception.__init__(self, list)
+ self.list = list
+
+class Arch:
+ """Base class that understands Linux architecture strings and nothing
+ else. All distributions need this information, as does other code
+ that handles RPMs."""
+
+ def __init__(self):
+ self.arch = ''
+ self.distArch = ''
+ self.cpus = []
+ self.i86cpus = [ 'athlon', 'i686', 'i586', 'i486', 'i386' ]
+
+ def getCPUs(self):
+ return self.cpus
+
+ def getArch(self):
+ return self.arch
+
+ def getDistArch(self):
+ return self.distArch
+
+ def setArch(self, arch, distArch=None):
+ """The two architectures are to handle trends like
+ the AMD64 dist arch, where the true arch is x86_64.
+ NOTE: This trend does not exist with RHEL."""
+
+ self.arch = arch
+ if arch in self.i86cpus:
+ self.cpus = self.i86cpus
+ self.arch = 'i386'
+ elif arch == 'x86_64':
+ self.cpus = [ arch ]
+ self.cpus.extend([ 'ia32e' ])
+ self.cpus.extend(self.i86cpus)
+ else:
+ self.cpus = [ arch ]
+
+ self.cpus.extend([ 'src', 'noarch' ])
+
+ if distArch:
+ self.distArch = distArch
+ else:
+ self.distArch = arch
+
+class Base(Arch):
+ """Understands how to navigate the sometimes arcane
+ RedHat linux distribution directory paths. Used to build
+ and manipulate custom RedHat-compatible distributions."""
+
+ def __init__(self):
+ Arch.__init__(self)
+ self.root = ''
+ self.distdir = ''
+ self.trees = {}
+
+ def isBuilt(self):
+ if self.trees != {}:
+ return 1
+ else:
+ return 0
+
+ def build(self):
+ self.trees['release'] = sunhpc.core.files.Tree(self.getReleasePath())
+
+ def setRoot(self, s):
+ self.root = s
+
+ def setDist(self, d):
+ self.distdir = d
+
+ def getDist(self):
+ return self.distdir
+
+ def getRootPath(self):
+ return self.root
+
+ def getHomePath(self):
+ return os.path.join(self.root, self.distdir)
+
+ def getReleasePath(self):
+ return os.path.join(self.getHomePath(), self.getDistArch())
+
+ def getWANReleasePath(self, client='all'):
+ return os.path.join(self.getHomePath(), client,
+ self.getDistArch())
+
+ def getRPMSPath(self):
+ return os.path.join(self.getReleasePath(), 'RedHat', 'RPMS')
+
+ def getSRPMSPath(self):
+ return os.path.join(self.getReleasePath(), 'SRPMS')
+
+ def getBasePath(self):
+ return os.path.join(self.getReleasePath(), 'RedHat', 'base')
+
+ def getRollCentralPath(self):
+ return str(os.path.join(self.getHomePath(), 'rolls'))
+
+ def getBaseFile(self, name):
+ for file in self.getFiles('release',
+ os.path.join('RedHat', 'base')):
+ if file.getName() == name:
+ return file
+ return None
+
+ def getTreeNames(self):
+ return self.trees.keys()
+
+ def getTree(self, name):
+ if name in self.trees.keys():
+ return self.trees[name]
+ else:
+ return None
+
+ def setFiles(self, name, path, list):
+ self.trees[name].setFiles(path, list)
+
+ def getFiles(self, name, path):
+ try:
+ value = self.trees[name]
+ except KeyError:
+ return []
+ list = []
+ if type(value) == type([]):
+ for tree in value:
+ list.extend(tree.getFiles(path))
+ return list
+ else:
+ return value.getFiles(path)
+
+ def setBaseFiles(self, list):
+ self.setFiles('release', os.path.join('RedHat', 'base'), list)
+
+ def setRPMS(self, list):
+ self.setFiles('release', os.path.join('RedHat', 'RPMS'), list)
+
+ def setSRPMS(self, list):
+ self.setFiles('release', 'SRPMS', list)
+
+ def getPackage(self, name, list):
+ matches = []
+ for file in list:
+ if file.getBaseName() == name:
+ matches.append(file)
+
+ if not matches:
+ return None
+ elif len(matches) == 1:
+ return matches[0]
+ else:
+ raise DistRPMList(matches)
+
+ def getRPM(self, name):
+ return self.getPackage(name, self.getRPMS())
+
+ def getSRPM(self, name):
+ return self.getPackage(name, self.getSRPMS())
+
+ def getRPMS(self):
+ return self.getFiles('release', os.path.join('RedHat', 'RPMS'))
+
+ def getSRPMS(self):
+ return self.getFiles('release', os.path.join('SRPMS'))
+
+ def getReleaseTree(self):
+ return self.getTree('release')
+
+ def dumpDirNames(self):
+ for key in self.trees.keys():
+ value = self.trees[key]
+ if type(value) == type([]):
+ for e in value:
+ e.dumpDirNames()
+ else:
+ value.dumpDirNames()
+
+ def dump(self):
+ for key in self.trees.keys():
+ value = self.trees[key]
+ if type(value) == type([]):
+ for e in value:
+ e.dump()
+ else:
+ value.dump()
+
+class Mirror(Base):
+
+ def __init__(self, mirror=None):
+ Base.__init__(self)
+ if mirror:
+ self.setHost(mirror.host)
+ self.setPath(mirror.dir)
+ self.setRoot(mirror.root)
+ self.setArch(mirror.arch, mirror.distArch)
+ else:
+ self.host = ''
+ self.dir = ''
+ self.getRelease = 1
+
+ def __str__(self):
+ s = "SunHPC Mirror Distribution\n"
+ s += "Host: %s\n" % self.getHost()
+ s += "Path: %s\n" % self.getPath()
+ return s
+
+ def __cmp__(self, other):
+ if not other:
+ return -1
+ elif other.getHost() == self.getHost() and \
+ other.getPath() == self.getPath():
+ return 0
+ else:
+ return -1
+
+ def build(self):
+ Base.build(self)
+ self.trees['rolls'] = sunhpc.core.files.Tree(self.getRollsPath())
+
+ def getRootPath(self):
+ return self.root
+
+ def setHost(self, s):
+ self.host = s
+
+ def setPath(self, s):
+ self.dir = s
+
+ def getHost(self):
+ return self.host
+
+ def getPath(self):
+ return self.dir
+
+ def getHomePath(self):
+ return os.path.join(self.root, self.host, self.dir)
+
+ def getRemoteReleasePath(self):
+ return os.path.join(self.dir, self.getDistArch())
+
+ def getRollsPath(self):
+ return os.path.join(self.getRootPath(), 'rolls')
+
+ def getRollRPMS(self, roll, version, arch):
+ path = os.path.join(roll, version, arch, 'RedHat', 'RPMS')
+ return self.getFiles('rolls', path)
+
+ def getRollBaseFiles(self, roll, version, arch):
+ path = os.path.join(roll, version, arch, 'RedHat', 'base')
+ return self.getFiles('rolls', path)
+
+ def getRollSRPMS(self, roll, version, arch):
+ path = os.path.join(roll, version, arch, 'SRPMS')
+ return self.getFiles('rolls', path)
+
+ def getRolls(self):
+ rolls = {}
+ rollsPath = self.getRollsPath()
+ if not os.path.exists(rollsPath):
+ return rolls
+ for r in os.listdir(rollsPath):
+ rolls[r] = []
+ rdir = os.path.join(self.getRollsPath(), r)
+ if not os.path.isdir(rdir):
+ continue
+ for v in os.listdir(rdir):
+ vdir = os.path.join(rdir, v)
+ if not os.path.isdir(vdir):
+ continue
+ for a in os.listdir(vdir):
+ adir = os.path.join(vdir, a)
+ if not os.path.isdir(adir):
+ continue
+ rolls[r].append((a, v))
+ return rolls
+
+class Distribution(Base):
+
+ def __init__(self, m, v):
+ Base.__init__(self)
+ self.contrib = ''
+ self.local = ''
+ self.mirrors = m
+ self.root = self.mirrors[0].root
+ self.arch = self.mirrors[0].arch
+ self.distArch = self.mirrors[0].distArch
+ self.cpus = self.mirrors[0].cpus
+ self.version = v
+
+ def build(self):
+ Base.build(self)
+
+ self.trees['contrib'] = sunhpc.core.files.Tree(self.contrib)
+ self.trees['force'] = sunhpc.core.files.Tree(self.getForceRPMSPath())
+ self.trees['site-profiles'] = sunhpc.core.files.Tree(self.getSiteProfilesPath())
+ self.trees['local_srpms'] = []
+ self.trees['local'] = []
+ self.trees['cdrom'] = []
+ self.trees['rolls'] = []
+ for e in self.getSiteRPMSPath():
+ self.trees['local'].append(sunhpc.core.files.Tree(e))
+ for e in self.getSiteSRPMSPath():
+ self.trees['local_srpms'].append(sunhpc.core.files.Tree(e))
+ for f in self.trees['force'].getFiles(''):
+ f.setImortal()
+
+ def setContrib(self, s):
+ self.contrib = s
+
+ def setLocal(self, s):
+ self.local = s
+
+ def getSunhpcRelease(self):
+ return self.version
+
+ def getBuildPath(self):
+ return os.path.join(self.getReleasePath(), 'build')
+
+ def getSiteRPMSPath(self):
+ l = []
+ if self.local:
+ for cpu in self.cpus:
+ l.append(os.path.join(self.local, 'RPMS', cpu))
+ if 'RPMHOME' in os.environ:
+ for cpu in self.cpus:
+ l.append(os.path.join(os.environ['RPMHOME'],
+ 'RPMS', cpu))
+ return l
+
+ def getSiteSRPMSPath(self):
+ l = []
+ if self.local:
+ l.append(os.path.join(self.local, 'SRPMS'))
+ if 'RPMHOME' in os.environ:
+ l.append(os.path.join(os.environ['RPMHOME'], 'SRPMS'))
+ return l
+
+ def getForceRPMSPath(self):
+ return os.path.join(self.getReleasePath(), 'force', 'RPMS')
+
+ def getRollsPath(self):
+ return os.path.join(self.getReleasePath(), 'rolls')
+
+ def getContribRPMSPath(self):
+ return os.path.join(self.contrib, self.arch, 'RPMS')
+
+ def getContribSRPMSPath(self):
+ return os.path.join(self.contrib, self.arch, 'SRPMS')
+
+ def getSiteProfilesPath(self):
+ return os.path.join(self.getRootPath(), 'site-profiles',
+ self.getSunhpcRelease())
+
+ def getMirrors(self):
+ return self.mirrors
+
+ def getContribRPMS(self):
+ return self.getFiles('contrib', os.path.join(self.arch, 'RPMS'))
+
+ def getContribSRPMS(self):
+ return self.getFiles('contrib', os.path.join(self.arch, 'SRPMS'))
+
+ def getLocalRPMS(self):
+ return self.getFiles('local', '')
+
+ def getLocalSRPMS(self):
+ return self.getFiles('local_srpms', '')
+
+ def getForceRPMS(self):
+ return self.getFiles('force', '')
+
+ def getSiteProfilesTree(self):
+ return self.getTree('site-profiles')
+
+ def syncMirror(self):
+ for mirror in self.mirrors:
+ tree = mirror.getTree('release')
+ for key in tree.getDirs():
+ self.getTree('release').\
+ setFiles(key, tree.getFiles(key))
+
diff --git a/lib/sunhpc/core/files.py b/lib/sunhpc/core/files.py
new file mode 100644
index 0000000..886a059
--- /dev/null
+++ b/lib/sunhpc/core/files.py
@@ -0,0 +1,445 @@
+#coding:utf-8
+import os
+import re
+import sys
+import stat
+import time
+import shutil
+import string
+import xml.sax
+import datetime
+
+class File:
+
+ def __init__(self, file, timestamp=None, size=None):
+ self.setFile(file, timestamp, size)
+ self.imortal = 0
+
+ def __cmp__(self, file):
+ if self.getBaseName() != file.getBaseName() or \
+ self.timestamp == file.timestamp:
+ rc = 0
+ elif self.timestamp > file.timestamp:
+ rc = 1
+ else:
+ rc = -1
+
+ if rc and self.imortal + file.imortal == 1:
+ if self.imortal:
+ rc = 1
+ else:
+ rc = -1
+
+ return rc
+
+ def setFile(self, file, timestamp=None, size=None):
+ self.pathname = os.path.dirname(file)
+ self.filename = os.path.basename(file)
+
+ if None not in (timestamp, size):
+ self.timestamp = timestamp
+ self.size = size
+ elif not os.path.islink(file):
+ self.timestamp = os.path.getmtime(file)
+ self.size = os.path.getsize(file)
+ else:
+ orig = os.readlink(file)
+ if os.path.isfile(orig):
+ self.timestamp = os.path.getmtime(orig)
+ self.size = os.path.getsize(file)
+ else:
+ self.timestamp = 0
+ self.size = 0
+
+ def explode(self):
+
+ file = self.getFullName()
+ if os.path.islink(file):
+ orig = os.readlink(file)
+ if os.path.isfile(orig):
+ os.unlink(file)
+ shutil.copy2(orig, file)
+
+ tm = os.path.getmtime(orig)
+ os.utime(file, (tm, tm))
+
+ def setImortal(self):
+ self.imortal = 1
+
+ def getTimestamp(self):
+ return self.timestamp
+
+ def getSize(self):
+ return float(self.size) / (1024*1024)
+
+ def getUniqueName(self):
+ return self.filename
+
+ def getBaseName(self):
+ return self.filename
+
+ def getName(self):
+ return self.filename
+
+ def getShortName(self):
+ return os.path.splitext(self.filename)[0]
+
+ def getPath(self):
+ return self.pathname
+
+ def getFullName(self):
+ return str(os.path.join(self.pathname, self.filename))
+
+ def getFileMode(self):
+ return self.getMode(self.getFullName())
+
+ def getMode(self, path):
+ dirs = os.stat(path)
+ mode = oct(dirs.st_mode)[-3:] # permission bits, e.g. 0o750 -> '750'
+ return (mode, dirs.st_uid, dirs.st_gid)
+
+ def symlink(self, target, base=''):
+ if os.path.isfile(target) or os.path.islink(target):
+ os.unlink(target)
+ os.symlink(self.getFullName(), target)
+
+ def chmod(self, mode):
+ # python2 : 0664
+ # python3 : 0o664
+ if os.path.exists(self.getFullName()):
+ os.chmod(self.getFullName(), mode)
+
+ def dump(self):
+ print ('%s(%s)' % (self.filename, self.pathname))
+
+class Tree:
+
+ def __init__(self, root):
+ self.root = root
+ self.tree = {}
+ self.build('')
+
+ def getRoot(self):
+ return self.root
+
+ def getDirs(self):
+ return list(self.tree.keys())
+
+ def getSize(self):
+ len = 0
+ for key in self.tree.keys():
+ for file in self.tree[key]:
+ len = len + file.getSize()
+ return float(len)
+
+ def getFiles(self, path=''):
+
+ try:
+ list = self.tree[path]
+ except KeyError:
+ list = []
+ return list
+
+ def setFiles(self, path, files):
+ self.tree[path] = files
+
+ def build(self, dir):
+ path = os.path.join(self.root, dir)
+ if not os.path.isdir(path):
+ return
+
+ try:
+ files = os.listdir(path)
+ # Remove the Python cache directory
+ if '__pycache__' in files: files.remove('__pycache__')
+ except:
+ files = []
+
+ v = []
+ for f in files:
+ filepath = os.path.join(path, f)
+ if os.path.isdir(filepath) and not \
+ os.path.islink(filepath):
+ self.build(os.path.join(dir, f))
+ else:
+ if re.match('.*\.rpm$', f) != None:
+ v.append(RPMFile(filepath))
+ elif re.match('roll-.*\.iso$', f) != None:
+ v.append(RollFile(filepath))
+ else:
+ v.append(File(filepath))
+ self.tree[dir] = v
+
+ def apply(self, func, root=None):
+ for key in self.tree.keys():
+ for e in self.tree[key]:
+ func(key, e, root)
+
+
+class RPMBaseFile(File):
+
+ def __init__(self, file, timestamp=None, size=None, ext=1):
+ File.__init__(self, file, timestamp, size)
+ self.list = []
+
+ s = self.filename # name-ver-rpmver.arch.rpm
+ for x in range(0, ext):
+ i = s.rfind(".")
+ s = self.filename[:i]
+
+ i = s.rfind(".")
+ self.list.append(s[i+1:]) # get architecture string
+ s = self.filename[:i]
+
+ i = s.rfind("-") # get RPM version string
+ self.release = s[i+1:]
+ self.list.append(self.versionList(s[i+1:]))
+ s = self.filename[:i]
+
+ i = s.rfind("-") # get software version string
+ self.version = s[i+1:]
+ self.list.append(self.versionList(s[i+1:]))
+ self.list.append(self.filename[:i]) # get package name
+ self.list.reverse() # we built the list backwards
+
+ def versionList(self, s):
+ list = []
+ for e in re.split('\.+|_+', s):
+ l = []
+ num = ''
+ alpha = ''
+ for c in e:
+ if c in string.digits:
+ num = num + c
+ if alpha:
+ l.append(alpha)
+ alpha = ''
+ else:
+ alpha = alpha + c
+ if num:
+ l.append(num)
+ num = ''
+ if alpha:
+ l.append(alpha)
+ if num:
+ l.append(num)
+ list.append(l)
+ return list
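+ # Worked example (illustrative): a file named "foo-1.2-3.el7.x86_64.rpm"
+ # is parsed by __init__ into
+ # self.list == ['foo', [['1'], ['2']], [['3'], ['el', '7']], 'x86_64'],
+ # and versionList('2b') == [['2', 'b']].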
+
+ def getBaseName(self):
+ return self.list[0]
+
+ def getUniqueName(self):
+ return '%s-%s' % (self.list[0], self.list[3])
+
+class RPMFile(RPMBaseFile):
+
+ def __init__(self, file, timestamp=None, size=None):
+ RPMBaseFile.__init__(self, file, timestamp, size)
+
+ def __cmp__(self, file):
+ if self.getPackageArch() != file.getPackageArch():
+ rc = 0
+ else:
+ if abs(int(self.timestamp) - int(file.timestamp)) < 120 :
+ # print "CMP %s:%s" % (self.getFullName(), file.getFullName())
+ f1=os.popen("rpm -qp --qf '%%{BUILDTIME}' %s" % self.getFullName())
+ self.timestamp=float(f1.readline())
+ f1.close()
+ f2=os.popen("rpm -qp --qf '%%{BUILDTIME}' %s" % file.getFullName())
+ file.timestamp=float(f2.readline())
+ f2.close()
+
+ rc = File.__cmp__(self, file)
+ return rc
+
+ def getPackageName(self):
+ return self.getBaseName()
+
+ def getPackageVersion(self):
+ return self.list[1]
+
+ def getPackageRelease(self):
+ return self.list[2]
+
+ def getPackageVersionString(self):
+ return self.version
+
+ def getPackageReleaseString(self):
+ return self.release
+
+ def getPackageArch(self):
+ return self.list[3]
+
+class RollInfoFile(File,
+ xml.sax.handler.ContentHandler, xml.sax.handler.DTDHandler,
+ xml.sax.handler.EntityResolver, xml.sax.handler.ErrorHandler):
+
+ def __init__(self, file):
+ File.__init__(self, file)
+
+ self.attrs = {}
+ parser = xml.sax.make_parser()
+ parser.setContentHandler(self)
+ fin = open(file, 'r')
+ parser.parse(fin)
+ fin.close()
+
+ def startElement(self, name, attrs):
+ self.attrs[str(name)] = {}
+ for (attrName, attrVal) in attrs.items():
+ self.attrs[str(name)][str(attrName)] = str(attrVal)
+
+ def getXML(self):
+
+ xml = []
+ xml.append('<roll name="%s" interface="%s">' %
+ (self.getRollName(), self.getRollInterface()))
+ for tag in self.attrs.keys():
+ if tag == 'roll':
+ continue
+ attrs = ''
+ for key,val in self.attrs[tag].items():
+ attrs += ' %s="%s"' % (key, val)
+ xml.append('\t<%s%s/>' % (tag, attrs))
+ xml.append('</roll>')
+
+ return '\n'.join(xml)
+
+ def getRollName(self):
+ return self.attrs['roll']['name']
+
+ def getRollInterface(self):
+ return self.attrs['roll']['interface']
+
+ def getRollVersion(self):
+ return self.attrs['info']['version']
+
+ def getRollRelease(self):
+ return self.attrs['info']['release']
+
+ def setRollOS(self, os):
+ self.attrs['info']['os'] = os
+
+ def getRollOS(self):
+ try:
+ return self.attrs['info']['os']
+ except KeyError:
+ return 'linux'
+
+ def setRollArch(self, arch):
+ self.attrs['info']['arch'] = arch
+
+ def getRollArch(self):
+ return self.attrs['info']['arch']
+
+ def getISOMaxSize(self):
+ return float(self.attrs['iso']['maxsize'])
+
+ def setISOMaxSize(self, size):
+ self.attrs['iso']['maxsize'] = size
+
+ def getISOFlags(self):
+ return self.attrs['iso']['mkisofs']
+
+ def getRollRolls(self):
+ return self.attrs['rpm']['rolls']
+
+ def isBootable(self):
+ return int(self.attrs['iso']['bootable'])
+
+ def hasRolls(self):
+ if self.attrs['rpm']['rolls'] != '0':
+ return 1
+ else:
+ return 0
+
+ def hasRPMS(self):
+ return int(self.attrs['rpm']['bin'])
+
+ def hasSRPMS(self):
+ return int(self.attrs['rpm']['src'])
+
+ def getDepsName(self):
+ return self.attrs['deps']['name']
+
+ def getDepsHost(self):
+ return self.attrs['deps']['host']
+
+class RPMbuild:
+ def __init__(self, root):
+ self.root = root
+ self.tree = {}
+ self.fulldirs = {}
+ self.build('')
+
+ def getRoot(self):
+ return self.root
+
+ def getRootName(self):
+ return self.root.split('/')[-1]
+
+ def getRootMode(self):
+ return self.getMode(self.root)
+
+ def getPathMode(self, path):
+ return self.getMode(path)
+
+ def getMode(self, path):
+ dirs = os.stat(path)
+ mode = oct(dirs.st_mode)[-3:] # permission bits, e.g. 0o750 -> '750'
+ return (mode, dirs.st_uid, dirs.st_gid)
+
+ def getDirs(self):
+ return list(self.tree.keys())
+
+ def getDirMode(self, path):
+ return self.fulldirs[path]
+
+ def getFiles(self, path=''):
+ try:
+ list = self.tree[path]
+ except KeyError:
+ list = []
+ return list
+
+ def build(self, dir):
+
+ path = os.path.join(self.root, dir)
+ if not os.path.isdir(path):
+ return
+
+ try:
+ files = os.listdir(path)
+ # Remove the Python cache directory
+ #if '__pycache__' in files: files.remove('__pycache__')
+ except:
+ raise
+ files = []
+
+ v = []
+ for f in files:
+ filepath = os.path.join(path, f)
+ if os.path.isdir(filepath) and not os.path.islink(filepath):
+ self.build(os.path.join(dir, f))
+ else:
+ v.append(File(filepath))
+
+ dirs = os.stat(path)
+ mode = oct(dirs[stat.ST_MODE])[-3:] # permission bits, e.g. 0o750 -> '750'
+ uid = dirs.st_uid # directory uid
+ gid = dirs.st_gid # directory gid
+
+ # Format: (mode, uid, gid)
+ self.fulldirs[dir] = (mode, uid, gid)
+ self.tree[dir] = v
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/core/firewalld.py b/lib/sunhpc/core/firewalld.py
new file mode 100644
index 0000000..6e32593
--- /dev/null
+++ b/lib/sunhpc/core/firewalld.py
@@ -0,0 +1,6 @@
+#coding:utf-8
+
+from firewall.client import FirewallClient
+
+fw = FirewallClient()
+print (fw)
diff --git a/lib/sunhpc/core/ip.py b/lib/sunhpc/core/ip.py
new file mode 100644
index 0000000..906e3b1
--- /dev/null
+++ b/lib/sunhpc/core/ip.py
@@ -0,0 +1,116 @@
+#coding:utf-8
+import sys
+from sunhpc.core.utils import CommandError
+class IPAddr:
+ def __init__(self, addr):
+ # str: 255.255.0.0 -> 0.0.255.255
+ if type(addr) == str:
+ self.list = list(map(int, addr.split('.')))
+ self.list.reverse()
+ else:
+ self.list = []
+ # Lowest byte: addr=255 gives 255; addr=256 wraps to 0 (the next byte then steps by 1).
+ self.list.append((addr & 0x000000ff))
+ # Second byte: addr=255 gives 0; addr=256 gives 1 (it steps by 1 for every 256 addresses).
+ self.list.append((addr & 0x0000ff00) >> 8)
+ # Same as above
+ self.list.append((addr & 0x00ff0000) >> 16)
+ # Same as above
+ self.list.append(((addr & 0xff000000) >> 24) & 0x000000ff)
+
+ def address(self):
+ return ((self.list[3] << 24) +
+ (self.list[2] << 16) +
+ (self.list[1] << 8) + self.list[0])
+
+ def __call__(self):
+ return self.address()
+
+ def __getitem__(self, i):
+ return self.list[i]
+
+ def __setitem__(self, i, v):
+ self.list[i] = v
+
+ def __add__(self, n):
+ return IPAddr(self.address() + n)
+
+ def __sub__(self, n):
+ return IPAddr(self.address() - n)
+
+ def __invert__(self):
+ return ~self.address()
+
+ def __or__(self, o):
+ return self.address() | o.address()
+
+ def __and__(self, o):
+ return self.address() & o.address()
+
+ def __xor__(self, o):
+ return self.address() ^ o.address()
+
+ def __repr__(self):
+ # Return the address in dotted notation, octets back in their original order.
+ return '%d.%d.%d.%d' % ( self.list[3], self.list[2],
+ self.list[1], self.list[0] )
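+
+ # Worked example (illustrative): IPAddr('10.1.1.0') stores the octets
+ # reversed as [0, 1, 1, 10], so address() == (10 << 24) + (1 << 16) +
+ # (1 << 8) + 0 == 0x0A010100, and repr() prints '10.1.1.0' again.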
+
+class IPGenerator:
+
+ def __init__(self, network, netmask=None):
+
+ self.network = IPAddr(network)
+ # If no netmask was provided infer it from the address
+ #
+ # 0* - class A
+ # 10* - class B
+ # 110* - class C
+
+ if not netmask:
+ if self.network() & 0x80 == 0x00:
+ self.netmask = IPAddr('255.0.0.0')
+ elif self.network() & 0xc0 == 0x80:
+ self.netmask = IPAddr('255.255.0.0')
+ elif self.network() & 0xe0 == 0xc0:
+ self.netmask = IPAddr('255.255.255.0')
+ else:
+ print ('not a unicast address %s' % self.network)
+ sys.exit(-1)
+ else:
+ self.netmask = IPAddr(netmask)
+
+ # Set the initial address to the top of the address range.
+ self.addr = IPAddr(self.network | IPAddr(~self.netmask))
+
+ def curr(self):
+ if (self.addr & IPAddr(~self.netmask)) == ~self.netmask:
+ raise CommandError('At top of address range')
+
+ if (self.addr & IPAddr(~self.netmask)) == 0x00:
+ raise CommandError('At bottom of address range')
+
+ return self.addr
+
+ def dec(self):
+ return self.next()
+
+ def get_network(self):
+ return "%s" % IPAddr(self.addr & self.netmask)
+
+ def next(self, n=-1):
+ addr = self.addr + n
+
+ if (addr & IPAddr(~self.netmask)) == ~self.netmask:
+ raise CommandError('At top of address range')
+
+ if (addr & IPAddr(~self.netmask)) == 0x00:
+ raise CommandError('At bottom of address range')
+
+ self.addr = addr
+ return self.addr
+
+if __name__ == '__main__':
+ a = IPGenerator('10.1.1.0', '255.255.255.128')
+ print (a.curr())
+ a.next(-126)
+ print (a.curr())
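+ # (Illustrative) with the 255.255.255.128 netmask the generator starts at
+ # the top of the host range, so this demo is expected to print 10.1.1.127
+ # and then, after next(-126), 10.1.1.1.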
diff --git a/lib/sunhpc/core/partition.py b/lib/sunhpc/core/partition.py
new file mode 100644
index 0000000..efd272d
--- /dev/null
+++ b/lib/sunhpc/core/partition.py
@@ -0,0 +1,902 @@
+#coding:utf-8
+import os
+import re
+import sys
+import grp
+import stat
+import time
+import syslog
+import tempfile
+
+syslog.openlog('SUNHPC')
+def msg(message):
+ syslog.syslog('DOP: %s' % message)
+
+class Partition:
+
+ raidinfo = ''
+ mountpoints = []
+ saved_fstab = []
+
+ def __init__(self):
+ if os.path.exists('/mnt/runtime/usr/sbin/parted'):
+ self.parted = '/mnt/runtime/usr/sbin/parted'
+ else:
+ self.parted = '/sbin/parted'
+
+ if os.path.exists('/mnt/runtime/usr/sbin/e2label'):
+ #self.e2label = '/mnt/runtime/usr/sbin/e2label'
+ self.xfs_admin = '/mnt/runtime/usr/sbin/xfs_admin'
+ else:
+ self.xfs_admin = '/sbin/xfs_admin'
+
+ if os.path.exists('/mnt/runtime/usr/sbin/mdadm'):
+ self.mdadm = '/mnt/runtime/usr/sbin/mdadm'
+ else:
+ self.mdadm = '/sbin/mdadm'
+
+ if not os.path.exists('/tmp/discovered.disks'):
+ self.discoveredDisks()
+
+ def discoveredDisks(self):
+ raids = []
+ disks = []
+ if not disks:
+ disks = Device().get_disks
+
+ if raids:
+ mdstat = open('/proc/mdstat')
+ for line in mdstat:
+ if line.startswith('md'):
+ devicesName = line.split()[4:]
+ for devName in devicesName:
+ devName = devName.strip('0123456789[]')
+ if devName not in disks:
+ disks.append(devName)
+ mdstat.close()
+
+ # Print disks
+ discoveredDisk = open('/tmp/discovered.disks', 'w')
+ discoveredDisk.write("disks: " + ' '.join(disks) + '\n')
+ discoveredDisk.write("raids: " + ' '.join(raids) + '\n')
+ discoveredDisk.close()
+
+ def getDisks(self):
+ disks = []
+ fe = open('/tmp/discovered.disks', 'r')
+ for line in fe.readlines():
+ l = line.split()
+ if len(l) > 0 and l[0] == 'disks:':
+ for d in l[1:]:
+ #
+ # only include disks that have their
+ # 'media present' -- this is one way
+ # to filter out Dell Virtual Floppy
+ # devices
+ #
+ disks.append(d)
+ fe.close()
+ return disks
+
+ def getRaids(self):
+ raids = []
+ file = open('/tmp/discovered.disks', 'r')
+ for line in file.readlines():
+ l = line.split()
+ if len(l) > 0 and l[0] == 'raids:':
+ raids = l[1:]
+ file.close()
+ try:
+ file = open('/dev/md/md-device-map', 'r')
+ for line in file.readlines():
+ l = line.split()
+ if len(l) > 0 and not (l[0] in raids):
+ raids.append(l[0])
+ file.close()
+ except:
+ #no raid
+ pass
+ return raids
+
+ def gptDrive(self, devname):
+ #
+ # if this is a drive with a GPT format, then return '1'
+ #
+ retval = 0
+
+ cmd = '%s /dev/%s print -s 2> /dev/null' % \
+ (self.parted, devname)
+
+ label = 'Partition Table:'
+ for line in os.popen(cmd).readlines():
+ if len(line) > len(label) and \
+ line[0:len(label)] == label:
+
+ l = line.split()
+ if len(l) > 2 and l[2] == 'gpt':
+ retval = 1
+ break
+
+ return retval
+
+
+ def getDevice(self, strs):
+ device = ''
+
+ a = strs.split('/dev/')
+ if len(a) > 1:
+ device = a[1]
+
+ return device.strip()
+
+
+ def getSectorStart(self, str):
+ sectorstart = ''
+
+ a = str.split('=')
+ if len(a) > 1 and a[0].strip() == 'start':
+ sectorstart = a[1]
+ else:
+ sectorstart = a[0]
+
+ return sectorstart.strip()
+
+
+ def getPartitionSize(self, str):
+ partitionsize = ''
+
+ a = str.split('=')
+ if len(a) > 1 and a[0].strip() == 'size':
+ partitionsize = a[1]
+ else:
+ partitionsize = a[0]
+
+ return partitionsize.strip()
+
+
+ def getPartId(self, str):
+ partid = ''
+
+ a = str.split('=')
+ if len(a) > 1 and a[0].strip() == 'Id':
+ partid = a[1]
+ else:
+ partid = a[0]
+
+ return partid.strip()
+
+
+ def getFsType(self, mntpoint):
+ return self.findFsTypeInFstab(mntpoint)
+
+
+ def getBootFlags(self, str):
+ return str.strip()
+
+
+ def getMountPoint(self, devicename):
+ mntpoint = ''
+
+ cmd = 'blkid -o export /dev/%s | grep UUID ' % devicename
+ cmd += ' 2> /dev/null'
+ uuid = os.popen(cmd).readlines()
+ if len(uuid) > 0:
+ mntpoint = self.findMntInFstab(uuid[0][:-1])
+
+ if mntpoint == '':
+ mntpoint = self.findMntInFstab('/dev/' + devicename)
+
+ if mntpoint == '':
+ #
+ # see if the device is part of a raidset
+ #
+ mntpoint = self.getRaidName(devicename)
+
+ if mntpoint == '':
+ cmd = '%s /dev/%s 2> /dev/null' % \
+ (self.xfs_admin, devicename)
+ label = os.popen(cmd).readlines()
+
+ label = ''.join(label)
+ id = 'LABEL=%s' % (label[:-1])
+
+ mntpoint = self.findMntInFstab(id)
+
+ return mntpoint
+
+
+ def getRaidName(self, partition_device):
+ raidname = ''
+
+ for info in self.raidinfo:
+ if len(info) > 3:
+ (device, partitions, raidlevel,
+ num_partitions) = info
+
+ if partition_device in partitions:
+ raidname = 'raid.%s' % partition_device
+ break
+
+ return raidname
+
+
+ def findMntInFstab(self, identifier):
+ for line in self.saved_fstab:
+ l = line.split()
+ if len(l) > 0:
+ if l[0] == identifier:
+ return l[1]
+
+ return ''
+
+ def findFsTypeInFstab(self, mntpoint):
+ for line in self.saved_fstab:
+ l = line.split()
+ if len(l) > 2:
+ if l[1] == mntpoint:
+ return l[2]
+
+ return ''
+
+
+ def formatPartedNodePartInfo(self, devname, info):
+ #
+ # this function parses partition info from 'parted'
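+ # each entry it returns is a comma-separated record of the form
+ # 'device,start,size,id,fstype,bootflags,partflags,mountpoint'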
+ #
+ partinfo = []
+ isDisk = 0
+
+ for line in info:
+ l = line[:-1].split()
+
+ if len(l) > 2 and re.match('[0-9]+', l[0]):
+ if devname[0:2] == 'md':
+ device = devname
+ elif len(devname) > 4 and \
+ devname[0:5] == 'cciss':
+ #
+ # special case for HP smart array
+ # controllers
+ #
+ device = devname + 'p' + l[0]
+ else:
+ device = devname + l[0]
+ isDisk = 1
+ else:
+ if len(l) > 1 and l[0] == 'Disk':
+ isDisk = 1
+ continue
+
+ sectorstart = l[1]
+ partitionsize = l[3]
+ partid = ''
+
+ if devname[0:2] == 'md' and len(l) > 4:
+ #
+ # special case for software RAID. there is
+ # no 'Type' or 'Flags' fields, so the
+ # 'File system' field is 5th field
+ #
+ fstype = l[4]
+ bootflags = ''
+ else:
+ bfs = None
+ if len(l) > 5 and not self.gptDrive(devname):
+ #
+ # there is a case for RAID 0 that the
+ # second partition of the drive does not
+ # get the file system label
+ # (e.g., ext4), so the 'boot flags'
+ # get misidentified as a
+ # file system type
+ #
+ if 'raid' in l[5] or 'boot' in l[5]:
+ bfs = l[5:]
+ fstype = ''
+
+ else:
+ fstype = l[5]
+ else:
+ fstype = ''
+
+ if len(l) > 4 and self.gptDrive(devname):
+ # gpt partition there is not
+ # Type column so fstype is in 4
+ fstype = l[4]
+ if len(l) > 5:
+ bfs = [l[5]]
+
+
+ if not bfs and len(l) > 6:
+ bfs = l[6:]
+
+ if bfs:
+ bf = []
+ for b in bfs:
+ bf.append(b.rstrip(','))
+ bootflags = ' '.join(bf)
+ else:
+ bootflags = ''
+
+ if 'linux-swap' in fstype:
+ mntpoint = 'swap'
+ else:
+ mntpoint = self.getMountPoint(device)
+
+ # print 'formatPartedNodePartInfo:l: ', l
+
+ partinfo.append('%s,%s,%s,%s,%s,%s,%s,%s\n' %
+ (device, sectorstart, partitionsize,
+ partid, fstype, bootflags, '',
+ mntpoint))
+
+ # print 'formatPartedNodePartInfo:partinfo: ', partinfo
+
+ if partinfo == [] and isDisk:
+ #
+ # this disk has no partitions, create a
+ # dummy null entry for it
+ #
+ partinfo = [ '%s,,,,,,,\n' % (devname) ]
+
+ return partinfo
+
+
+ def parsePartInfo(self, info):
+ n = info.split(',')
+
+ if len(n) != 8:
+ return ('', '', '', '', '', '', '', '')
+
+ device = n[0].strip()
+ sectorstart = n[1].strip()
+ partitionsize = n[2].strip()
+ partid = n[3].strip()
+ fstype = n[4].strip()
+ bootflags = n[5].strip()
+ partflags = n[6].strip()
+ mntpoint = n[7].strip()
+
+ return (device, sectorstart, partitionsize, partid,
+ fstype, bootflags, partflags, mntpoint)
+
+
+ def getDiskInfo(self, disk):
+ syslog.syslog('getDiskInfo: disk:%s' % (disk))
+
+ cmd = '%s /dev/%s ' % (self.parted, disk)
+ cmd += 'print -s 2> /dev/null'
+ diskinfo = os.popen(cmd).readlines()
+
+ syslog.syslog('getNodePartInfo: diskinfo:%s' % (diskinfo))
+
+ return diskinfo
+
+
+ def getRaidLevel(self, device):
+ level = None
+
+ cmd = '%s --query --detail ' % (self.mdadm)
+ cmd += '/dev/%s' % (device)
+ for line in os.popen(cmd).readlines():
+ l = line.split()
+ if len(l) > 3 and l[0] == 'Raid' and l[1] == 'Level':
+ if l[3][0:4] == 'raid':
+ level = l[3][4:]
+ break
+
+ return level
+
+
+ def getRaidParts(self, device):
+ parts = []
+
+ foundparts = 0
+ cmd = '%s --query --detail ' % (self.mdadm)
+ cmd += '/dev/%s' % (device)
+ for line in os.popen(cmd).readlines():
+ l = line.split()
+ if len(l) > 4 and l[3] == 'RaidDevice':
+ foundparts = 1
+ continue
+
+ if foundparts == 0:
+ continue
+
+ if len(l) == 0:
+ continue
+
+ part = l[-1].split('/')
+ parts.append('raid.%s' % part[-1])
+
+ return ' '.join(parts)
+
+
+ def getNodePartInfo(self, disks):
+ arch = os.uname()[4]
+
+ partinfo = []
+ nodedisks = {}
+
+ for line in self.getFstab(disks):
+ self.saved_fstab.append(line)
+
+ for devname in disks:
+ diskinfo = self.getDiskInfo(devname)
+ partinfo += self.formatPartedNodePartInfo(devname,
+ diskinfo)
+
+ syslog.syslog('getNodePartInfo: partinfo:%s' % (partinfo))
+
+ for node in partinfo:
+ n = self.parsePartInfo(node)
+
+ (nodedevice, nodesectorstart, nodepartitionsize,
+ nodepartid, nodefstype, nodebootflags,
+ nodepartflags, nodemntpoint) = n
+
+ if (len(nodedevice) > 2) and (nodedevice[0:2] == 'md'):
+ nodepartflags = '--level=%s' % \
+ self.getRaidLevel(nodedevice)
+
+ nodebootflags = self.getRaidParts(nodedevice)
+
+ n = (nodedevice, nodesectorstart,
+ nodepartitionsize,
+ nodepartid, nodefstype,
+ nodebootflags,
+ nodepartflags, nodemntpoint)
+
+ elif nodebootflags != '':
+ if 'raid' in nodebootflags.split():
+ nodemntpoint = 'raid.%s' % (nodedevice)
+
+ n = (nodedevice, nodesectorstart,
+ nodepartitionsize,
+ nodepartid, nodefstype,
+ nodebootflags,
+ nodepartflags, nodemntpoint)
+
+ if nodedevice != '':
+ key = ''
+ for disk in disks:
+ if len(disk) <= len(nodedevice) and \
+ disk == nodedevice[0:len(disk)]:
+
+ key = disk
+ break
+
+ if key != '':
+ if key not in nodedisks:
+ nodedisks[key] = [n]
+ else:
+ nodedisks[key].append(n)
+
+ syslog.syslog('getNodePartInfo:nodedisks:%s' % (nodedisks))
+
+ return nodedisks
+
+
+ def listDiskPartitions(self, disk):
+ list = []
+ inHeader = 1
+
+ if disk[0:2] == 'md':
+ return [ (disk, 'dummy') ]
+
+ for part in self.getDiskInfo(disk):
+ l = part.split()
+
+ #
+ # skip the 'parted' header
+ #
+ if len(l) > 1 and l[0] == 'Number':
+ inHeader = 0
+ continue
+
+ if inHeader:
+ continue
+
+ partnumber = 0
+
+ #
+ # look for a part number
+ #
+ if len(l) > 2 and re.match('[0-9]+', l[0]):
+ partnumber = int(l[0])
+
+ if partnumber > 0:
+ if len(disk) > 4 and disk[0:5] == 'cciss':
+ #
+ # special case for HP smart array
+ # controllers
+ #
+ disk = disk + 'p'
+
+ if len(l) > 5:
+ fstype = l[5]
+ else:
+ fstype = ''
+ if len(l) > 4 and self.gptDrive( disk ):
+ # this is a gpt partition
+ fstype = l[4]
+
+
+ list.append(('%s%d' % (disk, partnumber),
+ fstype))
+
+ return list
+
+
+ def defaultDataDisk(self, disk):
+ basename = '/state/partition'
+ parts = []
+
+ i = 1
+ while 1:
+ nextname = '%s%d' % (basename, i)
+ if nextname not in self.mountpoints:
+ break
+ i = i + 1
+
+ p = 'part '
+ p += '%s --size=1 ' % (nextname)
+ p += '--fstype=ext4 --grow --ondisk=%s ' % (disk)
+ self.mountpoints.append(nextname)
+ parts.append(p)
+
+ return parts
+
+
+ def SunhpcGetPartsize(self, mountpoint):
+ size = 0
+
+ if mountpoint == 'root':
+ size = 60000
+ elif mountpoint == 'var':
+ size = 10000
+ elif mountpoint == 'swap':
+ size = 1000
+
+ return size
+
+
+ def defaultRootDisk(self, disk):
+ arch = os.uname()[4]
+ parts = []
+
+ if arch == 'ia64':
+ p = 'part /boot/efi --size=1000 --fstype=vfat '
+ p += '--ondisk=%s\n' % (disk)
+
+ p = 'part '
+ p += '/ --size=%d ' % (self.SunhpcGetPartsize('root'))
+ p += '--fstype=ext4 --ondisk=%s ' % (disk)
+ self.mountpoints.append('/')
+ parts.append(p)
+
+ p = 'part '
+ p += '/var --size=%d ' % (self.SunhpcGetPartsize('var'))
+ p += '--fstype=ext4 --ondisk=%s ' % (disk)
+ self.mountpoints.append('/var')
+ parts.append(p)
+
+ p = 'part '
+ p += 'swap --size=%d ' % (self.SunhpcGetPartsize('swap'))
+ p += '--fstype=swap --ondisk=%s ' % (disk)
+ self.mountpoints.append('swap')
+ parts.append(p)
+
+ parts += self.defaultDataDisk(disk)
+ return parts
+
+ def getFstab(self, disks):
+ if os.path.exists('/upgrade/etc/fstab'):
+ file = open('/upgrade/etc/fstab')
+ lines = file.readlines()
+ file.close()
+ return lines
+
+ #
+ # if we are here, let's go look at all the disks for /etc/fstab
+ #
+ mountpoint = tempfile.mktemp()
+ os.makedirs(mountpoint)
+ fstab = mountpoint + '/etc/fstab'
+
+ lines = []
+ for disk in disks:
+ for (partition, fstype) in \
+ self.listDiskPartitions(disk):
+
+ if not fstype or 'linux-swap' in fstype:
+ continue
+
+ os.system('mount /dev/%s %s' \
+ % (partition, mountpoint) + \
+ ' > /dev/null 2>&1')
+
+ if os.path.exists(fstab):
+ file = open(fstab)
+ lines = file.readlines()
+ file.close()
+
+ os.system('umount %s 2> /dev/null' %
+ (mountpoint))
+
+ if len(lines) > 0:
+ break
+
+ if len(lines) > 0:
+ break
+
+ try:
+ os.removedirs(mountpoint)
+ except:
+ pass
+
+ return lines
+
+ def isSunhpcDisk(self, partinfo, touchit=0):
+ retval = 0
+ mountpoint = tempfile.mktemp()
+ os.makedirs(mountpoint)
+ for part in partinfo:
+ (dev,start,size,id,fstype,bootflags,partflags,mnt) = part
+
+ if not fstype or fstype == 'linux-swap':
+ continue
+
+ devname = '/dev/%s' % (dev)
+ os.system('mount %s %s' % (devname, mountpoint))
+
+ try:
+ filename = mountpoint + '/.sunhpc-release'
+ if touchit == 1:
+ os.system('touch %s' % filename)
+
+ if os.path.exists(filename):
+ retval = 1
+ except:
+ pass
+
+ os.system('umount %s' % (mountpoint) + ' > /dev/null 2>&1')
+ if retval == 1:
+ break
+ try:
+ os.removedirs(mountpoint)
+ except:
+ pass
+ return retval
+
+ def addPartitions(self, nodepartinfo, format):
+ arch = os.uname()[4]
+ parts = []
+ #
+ # for each partition on a drive, build a partition
+ # specification for anaconda
+ #
+ for node in nodepartinfo:
+ if len(node) == 1: continue
+
+ (nodedevice, nodesectorstart, nodepartitionsize,
+ nodepartid, nodefstype, nodebootflags,
+ nodepartflags, nodemntpoint) = node
+
+ if arch == 'ia64':
+ if nodefstype == 'fat32':
+ nodefstype = 'vfat'
+ elif nodefstype == 'linux-swap':
+ nodefstype = 'swap'
+
+ if nodemntpoint == '':
+ continue
+ #
+ # only add raid partitions if they have a mountpoint
+ # defined by their respective 'md' device.
+ #
+ # anaconda will crash if there is not a valid
+ # mountpoint for the md device
+ #
+ if nodepartid == 'fd':
+ if not self.getRaidMountPoint(nodedevice):
+ continue
+ args = [ nodemntpoint ]
+ if len(nodemntpoint) > 3 and \
+ nodemntpoint[0:4] == 'raid':
+ #
+ # never format a software raid partition and
+ # always set its size to 1
+ #
+ args.append('--noformat')
+ args += [ '--size', '1']
+ elif (nodemntpoint != '/' and nodemntpoint != '/var') \
+ and not format:
+ args.append('--noformat')
+ else:
+ if nodefstype == '':
+ args += [ '--fstype', self.fstype ]
+ else:
+ args += [ '--fstype', nodefstype ]
+
+ israid = 0
+ if len(nodedevice) > 2 and nodedevice[0:2] == 'md':
+ israid = 1
+ args += [ "--device=%s" % (nodedevice) ]
+ args += [ "--useexisting" ]
+
+ if nodepartflags != '':
+ args += [ nodepartflags ]
+ else:
+ args += [ '--onpart', nodedevice ]
+
+ if israid:
+ parts.append('raid %s' % (' '.join(args)))
+ else:
+ parts.append('part %s' % (' '.join(args)))
+
+ self.mountpoints.append(nodemntpoint)
+ return parts
+
+ def compareDiskInfo(self, dbpartinfo, nodepartinfo):
+ if len(dbpartinfo) != len(nodepartinfo):
+ return 0
+
+ for db in dbpartinfo:
+ if len(db) == 1:
+ continue
+
+ (dbdevice, dbsectorstart, dbpartsize, dbpartid,
+ dbfstype, dbbootflags, dbpartflags,
+ dbmntpoint) = db
+
+ found = 0
+ for node in nodepartinfo:
+ if len(node) == 1:
+ continue
+
+ (nodedevice, nodesectorstart, nodepartsize,
+ nodepartid, nodefstype, nodebootflags,
+ nodepartflags, nodemntpoint) = node
+
+ # print 'compareDiskInfo:node: ', node
+ # print 'compareDiskInfo:db: ', db
+ if dbsectorstart == nodesectorstart and \
+ dbpartsize == nodepartsize and \
+ dbpartid == nodepartid and \
+ dbfstype == nodefstype and \
+ dbbootflags == nodebootflags and \
+ dbpartflags == nodepartflags and \
+ dbmntpoint == nodemntpoint:
+ found = 1
+ break
+ if not found: return 0
+ return 1
+
+class Device:
+ def __init__(self, subsys='/sys/class/block'):
+ self.subsys = subsys
+ self.block = []
+ self.disks = {}
+ self.parts = {}
+ self.build()
+
+ def build(self):
+ if not os.path.isdir(self.subsys):
+ return
+
+ for sub in os.listdir(self.subsys):
+ self.block.append(Block(sub, self.subsys))
+
+ for disk in self.block:
+
+ if disk.device_is_partition:
+ self.parts[disk.info.get('DEVNAME')] = disk
+
+ if not disk.device_is_disk:
+ continue
+
+ if disk.device_is_loop:
+ continue
+
+ if disk.device_is_dm:
+ continue
+
+ self.disks[disk.info.get('DEVNAME')] = disk
+
+ @property
+ def get_parts(self):
+ return list(self.parts.keys())
+ @property
+ def get_disks(self):
+ return list(self.disks.keys())
+
+ # these take a device name argument, so they are plain methods rather than properties
+ def get_disk_object(self, dev=None):
+ if dev in self.disks:
+ return self.disks.get(dev)
+ def get_part_object(self, dev=None):
+ if dev in self.parts:
+ return self.parts.get(dev)
+
+class Block:
+ def __init__(self, devname, block):
+ self.devname = devname
+ self.block = block
+ self.devinfo = {}
+ self.build()
+
+ def build(self):
+ self.device_info()
+
+ @property
+ def info(self):
+ return self.devinfo
+
+ @property
+ def device_path(self):
+ return os.path.realpath(os.path.join(self.block, self.devname))
+
+ def device_info(self):
+ """read /sys/class/block/sda/uevent info"""
+ # DEVNAME=sda
+ # DEVTYPE=disk
+ uevent = os.path.join(self.device_path, 'uevent')
+ if not os.path.isfile(uevent):
+ return self.devinfo
+
+ with open(uevent) as f:
+ for l in f.readlines():
+ sp = l.split('=', 1)
+ if len(sp) != 2: continue
+ self.devinfo[sp[0]] = sp[1].strip()
+
+ @property
+ def device_is_cdrom(self):
+ try:
+ dev = os.path.join(os.sep, 'dev', self.info.get('DEVNAME'))
+ gid = os.stat(dev)[stat.ST_GID]
+ t = grp.getgrgid(gid).gr_name
+ except:
+ t = None
+
+ if t and t == 'cdrom':
+ return True
+
+ if os.path.exists('/opt/sunhpc/bin/lsscsi'):
+ lsscsi = '/opt/sunhpc/bin/lsscsi'
+ else:
+ lsscsi = '/sbin/lsscsi'
+
+ if not os.path.exists(lsscsi):
+ return False
+
+ line = os.popen('%s |grep "%s"' % (lsscsi, self.info.get('DEVNAME'))).readline().split()
+ if len(line) > 4:
+ if line[1] == 'cd/dvd' or line[3] == 'DVD-ROM':
+ return True
+ return False
+
+ @property
+ def device_is_disk(self):
+ if self.device_is_cdrom:
+ return False
+ has_range = os.path.exists(os.path.join(self.device_path, 'range'))
+ return self.info.get("DEVTYPE") == "disk" or has_range
+
+ @property
+ def device_is_partition(self):
+ has_start = os.path.exists(os.path.join(self.device_path, 'start'))
+ return self.info.get("DEVTYPE") == "partition" or has_start
+
+ @property
+ def device_is_loop(self):
+ return self.info.get('DEVNAME').startswith('loop')
+
+ @property
+ def device_is_dm(self):
+ return self.info.get('DEVNAME').startswith('dm')
+
+ def __repr__(self):
+ return 'Device("%s")' % self.device_path
diff --git a/lib/sunhpc/core/printer.py b/lib/sunhpc/core/printer.py
new file mode 100644
index 0000000..da27c7c
--- /dev/null
+++ b/lib/sunhpc/core/printer.py
@@ -0,0 +1,222 @@
+from __future__ import print_function
+from __future__ import absolute_import
+
+import sys
+import threading
+import collections
+from weakref import WeakKeyDictionary
+
+try:
+ import queue
+except ImportError: # Python 2.x fallback
+ import Queue as queue
+
+
+printer_queue = queue.Queue()
+thread_output_stream = WeakKeyDictionary()
+
+PrintResource = collections.namedtuple("PrintResource", ["content", "sep", "end", "file", "thread"])
+
+
+class PrinterThread(threading.Thread):
+ def __init__(self):
+ super(PrinterThread, self).__init__()
+ self.daemon = True
+
+ def run(self):
+ while True:
+ content, sep, end, file_, thread = printer_queue.get()
+ print(*content, sep=sep, end=end, file=file_)
+ printer_queue.task_done()
+
+
+def __cprint(*args, **kwargs):
+ """ Color print()
+
+ Signature like Python 3 print() function
+ print([object, ...][, sep=' '][, end='\n'][, file=sys.stdout])
+ """
+ if not kwargs.pop("verbose", True):
+ return
+
+ sep = kwargs.get("sep", " ")
+ end = kwargs.get("end", "\n")
+ thread = threading.current_thread()
+ try:
+ file_ = thread_output_stream.get(thread, ())[-1]
+ except IndexError:
+ file_ = kwargs.get("file", sys.stdout)
+
+ printer_queue.put(PrintResource(content=args, sep=sep, end=end, file=file_, thread=thread))
+
+
+def print_error(*args, **kwargs) -> None:
+ """ Print error message prefixing it with [-]
+
+ """
+
+ __cprint("\033[91m[-]\033[0m", *args, **kwargs)
+
+
+def print_status(*args, **kwargs) -> None:
+ """ Print status message prefixing it with [-]
+
+ """
+
+ __cprint("\033[93m[*]\033[0m", *args, **kwargs)
+
+
+def print_success(*args, **kwargs) -> None:
+ """ Print success message prefixing it with [-]
+
+ """
+
+ __cprint("\033[92m[+]\033[0m", *args, **kwargs)
+
+
+def print_info(*args, **kwargs) -> None:
+ """ Print info message prefixing it with [-]
+
+ """
+
+ __cprint(*args, **kwargs)
+
+
+def print_table(headers, *args, **kwargs) -> None:
+ """ Print table.
+
+ example:
+
+ Name Current setting Description
+ ---- --------------- -----------
+ option_name value description
+ foo bar baz
+ foo bar baz
+
+ :param headers: header names, e.g. ('Name', 'Current setting', 'Description')
+ :param args: table values, each element representing one row, e.g. ('option_name', 'value', 'description'), ...
+ :param kwargs: 'extra_fill' space between columns, 'header_separator' character to separate headers from content
+ :return:
+ """
+ extra_fill = kwargs.get("extra_fill", 5)
+ header_separator = kwargs.get("header_separator", "-")
+
+ if not all(map(lambda x: len(x) == len(headers), args)):
+ print_error("Headers and table rows tuples should be the same length.")
+ return
+
+ def custom_len(x):
+ try:
+ return len(x)
+ except TypeError:
+ return 0
+
+ fill = []
+ headers_line = ' '
+ headers_separator_line = ' '
+ for idx, header in enumerate(headers):
+ column = [custom_len(arg[idx]) for arg in args]
+ column.append(len(header))
+
+ current_line_fill = max(column) + extra_fill
+ fill.append(current_line_fill)
+ #
+ # headers_line : Name , Current settings .....
+ # headers_sepa : ---- , ---------------- .....
+ #
+ headers_line = "".join((headers_line, "{header:<{fill}}".format(header=header, fill=current_line_fill)))
+ headers_separator_line = "".join((
+ headers_separator_line,
+ "{:<{}}".format(header_separator * len(header), current_line_fill)
+ ))
+
+ print_info()
+ print_info(headers_line)
+ print_info(headers_separator_line)
+ #
+ # args : all table rows
+ # arg : one row (tuple) -> ('http_use', 'true', 'Check HTTP[s] service: true/false')
+ #
+ for arg in args:
+ content_line = " "
+ for idx, element in enumerate(arg):
+ content_line = "".join((
+ content_line,
+ "{:<{}}".format(element, fill[idx])
+ ))
+ print_info(content_line)
+
+ print_info()
+
+
+def pprint_dict_in_order(dictionary, order=None) -> None:
+ """ Pretty dict print.
+
+ Pretty printing dictionary in specific order. (as in 'show info' command)
+ Keys not mentioned in *order* parameter will be printed in random order.
+
+ ex. pprint_dict_in_order({'name': 'John', 'sex': 'male', 'hobby': ['rugby', 'golf']}, ('sex', 'name'))
+
+ Sex:
+ male
+
+ Name:
+ John
+
+ Hobby:
+ - rugby
+ - golf
+
+ """
+ order = order or ()
+
+ def prettyprint(title, body):
+ print_info("\n{}:".format(title.capitalize()))
+ if not isinstance(body, str):
+ for value_element in body:
+ print_info("- ", value_element)
+ else:
+ print_info(body)
+
+ keys = list(dictionary.keys())
+ for element in order:
+ try:
+ key = keys.pop(keys.index(element))
+ value = dictionary[key]
+ except (KeyError, ValueError):
+ pass
+ else:
+ prettyprint(element, value)
+
+ for rest_keys in keys:
+ prettyprint(rest_keys, dictionary[rest_keys])
+
+
+def color_blue(string: str) -> str:
+ """ Returns string colored with blue
+
+ :param str string:
+ :return str:
+ """
+
+ return "\033[94m{}\033[0m".format(string)
+
+
+def color_green(string: str) -> str:
+ """ Returns string colored with green
+
+ :param str string:
+ :return str:
+ """
+
+ return "\033[92m{}\033[0m".format(string)
+
+
+def color_red(string: str) -> str:
+ """ Returns string colored with red
+
+ :param str string:
+ :return str:
+ """
+
+ return "\033[91m{}\033[0m".format(string)
diff --git a/lib/sunhpc/core/security.py b/lib/sunhpc/core/security.py
new file mode 100644
index 0000000..033d4cf
--- /dev/null
+++ b/lib/sunhpc/core/security.py
@@ -0,0 +1,200 @@
+#coding:utf-8
+import os
+import re
+import stat
+import sunhpc
+import base64
+import xml.dom.minidom
+from Crypto import Random
+from Crypto.Hash import SHA
+from Crypto.PublicKey import RSA
+from Crypto.Cipher import PKCS1_v1_5
+from Crypto.Signature import PKCS1_v1_5 as Sig_pk
+from sunhpc.core.utils import SafeError
+
+class Security(object):
+
+ def __init__(self, keybit=2048):
+
+ self.master_key = '/etc/safe-security/master.key'
+ self.master_pub = '/etc/safe-security/master.pub'
+ self.shared_key = '/etc/safe-security/shared.key'
+ self.shared_pub = '/etc/safe-security/shared.pub'
+ self.rsa_keybit = keybit
+ if keybit not in [ 1024, 2048, 4096 ]:
+ raise SafeError('RSA key size must be 1024, 2048 or 4096 bits.')
+
+ self.mkey = None
+ self.mpub = None
+ self.skey = None
+ self.spub = None
+
+ self.master = None
+ self.conn = None
+ self.masters = []
+
+ # A regex for our header search.
+ pattern = "\n*(?P<comment>.*?)\$110id\$"
+ self.header_pattern = re.compile(pattern)
+
+ pattern = "<a href=.+>(?P<filename>.+)</a> +(?P<date>\d+.*) +(?P<size>\d+.*)"
+ # Make the pattern matching engine case-insensitive.
+ self.dir_pattern = re.compile(pattern, re.I)
+
+ def addHeader(self, msg):
+ result = '-----BEGIN SECURITY MESSAGE-----\n'
+ result += msg
+ result += '-----END SECURITY MESSAGE-----\n'
+ return result
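+
+ def removeHeader(self, msg):
+ # Hedged addition: decrypt() below calls removeHeader(), which is not
+ # defined anywhere else in this file; this minimal sketch simply drops
+ # the BEGIN/END SECURITY MESSAGE lines that addHeader() wraps around
+ # the payload.
+ lines = []
+ for line in msg.split('\n'):
+ if line.startswith('-----BEGIN SECURITY MESSAGE-----'):
+ continue
+ if line.startswith('-----END SECURITY MESSAGE-----'):
+ continue
+ lines.append(line)
+ return '\n'.join(lines)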
+
+ def makeRsaKeyPair(self):
+ """生成新的Master 和 Shared RSA密钥对"""
+ dict_rsa = {}
+ Master_rsa = RSA.generate(self.rsa_keybit, Random.new().read)
+ dict_rsa['master.key'] = Master_rsa.exportKey()
+ dict_rsa['master.pub'] = Master_rsa.publickey().exportKey()
+
+ Shared_rsa = RSA.generate(self.rsa_keybit, Random.new().read)
+ dict_rsa['shared.key'] = Shared_rsa.exportKey()
+ dict_rsa['shared.pub'] = Shared_rsa.publickey().exportKey()
+ return dict_rsa
+
+ def readEncKeyPair(self):
+ """使用Master私钥签名,Shared公钥加密,Frontend节点使用"""
+ if not os.path.exists(self.master_key):
+ raise SafeError("Master key is not exists.")
+ if not os.path.exists(self.shared_pub):
+ raise SafeError("Shared pub is not exists.")
+
+ # 使用master 私钥对RSA进行签名.
+ with open(self.master_key, 'r') as f: self.mkey = f.read()
+ # 使用shared 公钥对内容加密.
+ with open(self.shared_pub, 'r') as f: self.spub = f.read()
+
+ def readDecKeyPair(self):
+ """使用Master私钥签名,Shared公钥加密,Frontend节点使用"""
+ if not os.path.exists(self.shared_key):
+ raise SafeError("Shared key is not exists.")
+ if not os.path.exists(self.master_pub):
+ raise SafeError("Master pub is not exists.")
+
+ # 使用Shared 私钥对内容解密.
+ with open(self.shared_key, 'r') as f: self.skey = f.read()
+ # 使用Master 公钥对RSA 验签.
+ with open(self.master_pub, 'r') as f: self.mpub = f.read()
+
+ def encrypt(self, msg, type110 = 1):
+ """使用Master私钥签名,Shared公钥加密,Frontend节点使用"""
+ if not self.spub: self.readEncKeyPair()
+
+ # Convert bytes to str if needed.
+ if isinstance(msg, bytes):
+ msg = msg.decode()
+
+ # Encrypt with the Shared public key.
+ shared_pub = RSA.importKey(self.spub)
+ pk = PKCS1_v1_5.new(shared_pub)
+ encrypt_text = []
+ for i in range(0, len(msg), 100):
+ cont = msg[i:i+100]
+ encrypt_text.append(pk.encrypt(cont.encode()))
+
+ cipher_text = b''.join(encrypt_text)
+ basefmt = base64.b64encode(cipher_text).decode()
+
+ result = ''
+ if type110:
+ result += self.sign(basefmt)
+
+ # Wrap the ciphertext into fixed-width lines for readability.
+ for i in range(0, len(basefmt), 100):
+ result += basefmt[i:i+100] + '\n'
+
+ result = self.addHeader(result)
+ return result
+
+ def sign(self, msg):
+ """使用master 私钥进行签名"""
+ if not self.mkey: self.readEncKeyPair()
+ master_key = RSA.importKey(self.mkey)
+
+ # 解码 base64格式
+ msg = base64.b64decode(msg)
+
+ # 将内容进行Hash
+ data = SHA.new(msg)
+ # 读取master 私钥
+ sig_pk = Sig_pk.new(master_key)
+ # 使用master 私钥进行签名
+ sign = sig_pk.sign(data)
+ # 转换成base64位格式
+ result = base64.b64encode(sign)
+ data = result.decode()
+
+ new_fmt = ''
+ # 友好格式输出.
+ for i in range(0, len(data), 100):
+ new_fmt += data[i:i+100] + '\n'
+
+ data = new_fmt + '\n'
+ return data
+
+ def verify(self, msg, key):
+ """使用master 公钥进行验签"""
+ if not self.mpub: self.readDecKeyPair()
+ master_public = RSA.importKey(self.mpub)
+
+ # 将密文进行Hash读取.
+ sha_text = SHA.new(msg)
+
+ # 读取master 公钥进行验签.
+ signer = Sig_pk.new(master_public)
+
+ # 使用签名密文进行验签. 通过则返回真,否则假.
+ result = signer.verify(sha_text, key)
+ return result
+
+ def decrypt(self, msg, type110=1):
+ """使用Master公钥验签,Shared私钥解密,Compute节点使用"""
+
+ if not self.skey: self.readDecKeyPair()
+
+ # 读取Shared 公钥进行解密.
+ shared_private = RSA.importKey(self.skey)
+
+ # 去头部信息.
+ msg_text = self.removeHeader(msg)
+
+ if type110:
+ # 将签名和内容密钥进行分割.
+ msg_text = msg_text.split('\n\n')
+ # 签名密文
+ sig_cip = base64.b64decode(''.join(msg_text[0].split('\n')))
+ # 内容密文
+ rsa_cip = base64.b64decode(''.join(msg_text[1].split('\n')))
+ if not self.verify(rsa_cip, sig_cip):
+ raise SafeError("Signature does not verify.")
+ else:
+ # 内容密文
+ rsa_cip = base64.b64decode(''.join(msg_text.split('\n')))
+
+ cipher = PKCS1_v1_5.new(shared_private)
+ decrypt_text = []
+ # RSA密码生成的位置不同,这里的256需要改变.
+ # RSA 1024Bit->128, 2048Bit->256, 4096->512
+ step = int(self.rsa_keybit / 8)
+ for i in range(0, len(rsa_cip), step):
+ cont = rsa_cip[i:i+step]
+ decrypt_text.append(cipher.decrypt(cont,1))
+
+ decrypt_text = b''.join(decrypt_text)
+ return decrypt_text.decode()
+
+if __name__ == "__main__":
+
+ s = 'Aaab' * 20
+ a = Security()
+ # Encrypt; signing is enabled by default.
+ e = a.encrypt(s)
+ # Decrypt the ciphertext produced above.
+ d = a.decrypt(e)
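+ # Sketch only: assuming the key files under /etc/safe-security exist,
+ # the round trip should recover the original plaintext.
+ assert d == s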
diff --git a/lib/sunhpc/core/sql.py b/lib/sunhpc/core/sql.py
new file mode 100644
index 0000000..ce317bc
--- /dev/null
+++ b/lib/sunhpc/core/sql.py
@@ -0,0 +1,110 @@
+#coding:utf-8
+import os
+import sys
+import sunhpc
+import socket
+import sunhpc.commands
+import sunhpc.db.helper
+
+class Application(object):
+ def __init__(self, argv=None):
+
+ self.db = None
+ self.newdb = None
+
+ def connect(self):
+ self.newdb = sunhpc.db.helper.DatabaseHelper()
+ self.newdb.connect()
+
+ self.db = sunhpc.commands.DatabaseConnection(self.newdb)
+ return 1
+
+ def search(self, command):
+ return self.db.search(command)
+
+ def execute(self, command):
+ return self.db.execute(command)
+
+ def fetchone(self):
+ return self.db.fetchone()
+
+ def fetchall(self):
+ return self.db.fetchall()
+
+ def getHostAttr(self, host, attr):
+ return self.getHostAttrs(host).get(attr)
+
+ def getHostAttrs(self, host):
+ return self.db.getHostAttrs(host)
+
+ def close(self):
+ self.db.database.close()
+
+ def commit(self):
+ self.db.database.commit()
+
+ def getNodeId(self, host):
+
+ try:
+ return int(host)
+ except Exception:
+ pass
+
+ self.execute('select id from nodes where name="%s"' % host)
+ try:
+ nodeid, = self.fetchone()
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.ip="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.mac="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ return nodeid
+
+ def __repr__(self):
+ return 'SQLApplication'
+
+class InsertEthersPlugin:
+ """Base class for any module that wants to be notified when a node
+ is added or deleted in the cluster"""
+
+ def __init__(self, app):
+ self.app = app
+ self.screen = app.insertor.screen
+
+ def update(self):
+ "Regenerate your config files and reload them."
+ pass
+
+ def done(self):
+ """Called just before insert-ethers quits and nodes
+ have been added or removed."""
+ pass
+
+ def added(self, nodename):
+ """This node has been added to the cluster."""
+ pass
+
+ def removed(self, nodename):
+ "This node has been removed from the cluster"
+ pass
+
+ def changed(self, old, new):
+ "Not currently used"
+ pass
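+
+if __name__ == "__main__":
+ # Hedged usage sketch (added for illustration): assumes the sunhpc database
+ # is present on this host.
+ app = Application()
+ app.connect()
+ print (app.getHostAttrs('localhost'))
+ app.close()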
diff --git a/lib/sunhpc/core/utils.py b/lib/sunhpc/core/utils.py
new file mode 100644
index 0000000..97ec7c3
--- /dev/null
+++ b/lib/sunhpc/core/utils.py
@@ -0,0 +1,278 @@
+#coding:utf-8
+
+import os
+import sys
+import xml.sax
+import importlib
+import subprocess
+from xml.sax import handler
+import sunhpc.modules as sunhpc_modules
+
+MODULES_DIR = sunhpc_modules.__path__[0]
+
+# Sunhpc exception hierarchy
+class SunhpcException(Exception):
+ """Base class for Sunhpc exceptions."""
+ pass
+
+class HostnotfoundException(SunhpcException):
+ """This exception is used when the given host does not exist"""
+ pass
+
+class ParameterNotValid(SunhpcException):
+ """This exception is used when the user input parameters are
+ not valid"""
+ pass
+
+class CommandError(SunhpcException):
+ """This exception is thrown by the sunhpc command line when
+ something goes awry"""
+ pass
+
+class KickstartError(SunhpcException):
+ pass
+
+class KickstartGraphError(KickstartError):
+ pass
+
+class KickstartNodeError(KickstartError):
+ pass
+
+class SafeError(KickstartError):
+ pass
+
+class InsertError(Exception):
+ pass
+
+class InsertDone(Exception):
+ pass
+
+class DumpError(Exception):
+ pass
+
+class Struct:
+ pass
+
+
+def escapeAttr(value):
+ """escape attribute values with XML escaping"""
+ if value is None:
+ return value
+ return xml.sax.saxutils.escape(value, { "\"": "&quot;",
+ "%": "&#x0025;",
+ "'": "\\'"})
+
+def unescapeAttr(value):
+ """unescape attribute values with XML escaping """
+ if value is None:
+ return value
+ return xml.sax.saxutils.unescape(value, {"&quot;": "\"",
+ "&#x0025;": "%"})
+
+def escapeStringForShell(string):
+ """escape the given string so that it can be used in a shell script
+ inside a double quote string"""
+ return string.replace("\"", "\\\"")
+
+def str2bool(s):
+ """Converts an on/off, yes/no, true/false string to 1/0."""
+ if s and s.upper() in [ 'ON', 'YES', 'Y', 'TRUE', '1', 'ENABLED', 'ENABLE']:
+ return True
+ else:
+ return False
+
+def bool2str(b):
+ """Converts an 1/0 to a yes/no"""
+ if b:
+ return 'yes'
+ else:
+ return 'no'
+
+def list2str(list):
+ s = ''
+ for e in list:
+ s = s + e
+ return s
+
+
+def listcmp(l1, l2):
+ return list(map(lambda a, b: a == b, l1, l2))
+
+def listdup(e, n):
+ l = []
+ for i in range(0, n):
+ l.append(e)
+ return l
+
+
+def list_isprefix(l1, l2):
+ l = listcmp(l1, l2)
+ for i in range(0, len(l1)):
+ if not l[i]:
+ return 0
+ return 1
+
+def getNativeArch():
+ """Returns the canotical arch as reported by the operating system"""
+
+ arch = os.uname()[4]
+ if arch in [ 'i386', 'i486', 'i586', 'i686']:
+ arch = 'i386'
+ return arch
+
+def mkdir(newdir):
+ """Works the way a good mkdir should :)
+ - already exists, silently complete
+ - regular file in the way, raise an exception
+ - parent directory(ies) does not exist, make them as well
+ From Trent Mick's post to ASPN."""
+ if os.path.isdir(newdir):
+ pass
+ elif os.path.isfile(newdir):
+ raise OSError("a file with the same name as the desired " \
+ "dir, '%s', already exists." % newdir)
+ else:
+ head, tail = os.path.split(newdir)
+ if head and not os.path.isdir(head):
+ mkdir(head)
+ if tail:
+ os.mkdir(newdir)
+
+class ParseXML(handler.ContentHandler,
+ handler.DTDHandler,
+ handler.EntityResolver,
+ handler.ErrorHandler):
+ """A helper class to for XML parsers. Uses our
+ startElement_name style."""
+
+ def __init__(self, app=None):
+ handler.ContentHandler.__init__(self)
+ self.app = app
+ self.text = ''
+
+ def startElement(self, name, attrs):
+ """The Mason Katz school of parsers. Make small functions
+ instead of monolithic case statements. Easier to override and
+ to add new tag handlers."""
+ try:
+ f = getattr(self, "startElement_%s" % name)
+ f(name, attrs)
+ except AttributeError:
+ return
+
+ def endElement(self, name):
+ try:
+ f = getattr(self, "endElement_%s" % name)
+ f(name)
+ except AttributeError:
+ return
+
+ def characters(self, s):
+ self.text += s
+
+def system(cmd, type='standard'):
+ if type == 'spinner':
+ return startSpinner(cmd)
+ else:
+ return subprocess.call(cmd, shell=True)
+
+def startSpinner(cmd):
+ """This used to just be a system() but now we
+ control the child output to keep the status
+ on one line using stupid CR terminal tricks.
+ We even add a way cool spinny thing in
+ column zero just to be l33t!
+
+ Does not show standard error output."""
+
+ p = subprocess.Popen(cmd, shell=True,
+ stdin=subprocess.PIPE, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, close_fds=True, universal_newlines=True)
+ w, r ,e = (p.stdin, p.stdout, p.stderr)
+ currLength = 0
+ prevLength = 0
+ spinChars = '-\\|/'
+ spinIndex = 0
+ while 1:
+ line = e.readline()
+ if not line:
+ break
+ if len(line) > 79:
+ data = line[0:78]
+ else:
+ data = line[:-1]
+ currLength = len(data)
+ pad = ''
+ for i in range(0, prevLength - currLength):
+ pad = pad + ' '
+ spin = spinChars[spinIndex % len(spinChars)]
+ spinIndex = spinIndex + 1
+ print (spin + data + pad + '\r', end='')
+ prevLength = currLength
+ sys.stdout.flush()
+ r.close()
+ w.close()
+ e.close()
+
+ # Cleanup screen when done
+ pad = ''
+ for i in range(0,78):
+ pad = pad + ' '
+ print ('\r%s\r' % pad, end='')
+
+def index_modules(modules_directory: str = MODULES_DIR) -> list:
+ """ Returns list of all exploits modules
+
+ :param str modules_directory: path to modules directory
+ :return list: list of found modules
+ """
+ modules = []
+ for root, dirs, files in os.walk(modules_directory):
+ files.sort()
+ _, package, root = root.rpartition("sunhpc/modules/".replace("/", os.sep))
+ root = root.replace(os.sep, ".")
+ files = filter(lambda x: not x.startswith("__") and x.endswith(".py"), files)
+ modules.extend(map(lambda x: ".".join((root, os.path.splitext(x)[0])), files))
+
+ return modules
+
+def import_centrals(module: str):
+ """ Imports centrals module
+
+ :param str module: dotted module path, e.g. sunhpc.modules.<category>.<name>
+ :return: centrals module or error
+ """
+ try:
+ # Determine which class the module exposes.
+ module = importlib.import_module(module)
+ if hasattr(module, "Modules"):
+ return getattr(module, "Modules")
+ elif hasattr(module, "Plugins"):
+ return getattr(module, "Plugins")
+ else:
+ raise ImportError("No module named '{}'".format(module))
+
+ except (ImportError, AttributeError, KeyError) as err:
+ pass
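+
+if __name__ == '__main__':
+ # Hedged usage sketch (added for illustration): list the discovered module
+ # paths and try to import the first one. Assumes sunhpc.modules is importable;
+ # the actual module names depend on the installation.
+ found = index_modules()
+ print (found)
+ if found:
+ print (import_centrals('sunhpc.modules.' + found[0]))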
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/db/__init__.py b/lib/sunhpc/db/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sunhpc/db/__init__.py
diff --git a/lib/sunhpc/db/alchemy-bak/database.py b/lib/sunhpc/db/alchemy-bak/database.py
new file mode 100644
index 0000000..8ce81d5
--- /dev/null
+++ b/lib/sunhpc/db/alchemy-bak/database.py
@@ -0,0 +1,210 @@
+#coding:utf-8
+
+import os
+import pwd
+import sys
+import string
+import getopt
+import types
+import subprocess
+import threading
+
+from sqlalchemy import create_engine
+import sqlalchemy
+import sqlalchemy.exc
+
+import sunhpc
+from sunhpc.db.mappings.base import *
+
+threadlocal = threading.local()
+
+
+class Database(object):
+ """
+ This class should proxy all connections to the database.
+
+ There are two main internal objects inside this class which come from
+ sqlalchemy:
+
+ - session: this is used by the ORM layer
+ - connection: this is used by the execute statement, so every time
+ you use pure sql
+
+ These two objects have two separate DB connections, which means
+ that DB status can be different when queried through them.
+
+ Usage Example::
+
+ db = Database()
+ db.setVerbose()
+ db.connect()
+ """
+
+ def __init__(self):
+ self._dbPath = "/opt/sunhpc/data"
+ self._dbFile = "sunhpcdb"
+ self.verbose = False
+ self.results = False
+ self.conn = None
+ self.engine = None
+ self.datafile = os.path.join(self._dbPath, self._dbFile)
+
+ def setVerbose(self, verbose):
+ """
+ If the verbose is true all the sql will be printed to
+ stdout. This function must be called before the connect
+
+ :type verbose: bool
+ :param verbose: if verbose should be set to True
+
+ """
+ self.verbose = verbose
+
+ def getDBFile(self):
+ return self.datafile
+
+ def connect(self):
+ """
+ It starts the connection to the DB and creates all the internal
+ data structures
+ """
+
+ if 'SUNHPCDEBUG' in os.environ:
+ self.setVerbose(True)
+
+ url = 'sqlite:///' + self._dbPath + '/' + self._dbFile
+
+ if self.verbose:
+ import logging
+ logging.basicConfig()
+ logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
+
+ if self.verbose:
+ print ("Database connection URL: ", url)
+
+ self.engine = create_engine(url)
+ self.conn = self.engine.connect()
+
+ def reconnect(self):
+ self.engine.dispose()
+ self.conn = self.engine.connect()
+
+ def getSession(self):
+ session = getattr(threadlocal, "session", None)
+
+ if session:
+ return session
+ elif self.engine:
+ Session = sqlalchemy.orm.sessionmaker(bind=self.engine)
+ session = Session()
+ setattr(threadlocal, "session", session)
+ return session
+ else:
+ return None
+
+ def closeSession(self):
+ session = getattr(threadlocal, "session", None)
+ if session:
+ session.close()
+ setattr(threadlocal, "session", None)
+ return
+ return
+
+ def commit(self):
+ """
+ Commit the current session if it exists. *It does not touch the
+ connection.*
+ """
+ session = self.getSession()
+ if session:
+ session.commit()
+ else:
+ pass
+
+ def search(self, command):
+ self.execute(command)
+ rows = self.fetchall()
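+ # Re-run the query: fetchall() above consumed the result set, and callers
+ # expect to be able to fetch rows after calling search().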
+ self.execute(command)
+ return len(rows)
+
+
+ def execute(self, command):
+ if self.conn:
+ if '%' in command:
+ command = command.replace('%', '%%')
+ try:
+ self.results = self.conn.execute(command)
+ except sqlalchemy.exc.OperationalError as e:
+ self.renewConnection()
+ self.results = self.conn.execute(command)
+
+ return self.results.rowcount
+ else:
+ return None
+
+
+ def fetchone(self):
+ """
+ Fetch one row from the results of the previous query
+
+ :rtype: tuple
+ :return: a tuple containing the values of the fetched row.
+ It really returns a :class:`sqlalchemy.engine.result.RowProxy`
+ but it can be treated as a tuple
+ """
+ if self.results:
+ return self.results.fetchone()
+ return ()
+
+ def fetchall(self):
+ """
+ Fetch all rows from the results of the previous query
+
+ :rtype: list
+ :return: a list of tuples containing the values of the fetched rows
+ """
+ if self.results:
+ return self.results.fetchall()
+ return ()
+
+ def close(self):
+ """
+ It closes the connection only. You also need to close the
+ session, if you want to release all the DB resources
+ :meth:`closeSession`
+ """
+ if self.results:
+ self.results.close()
+ self.results = None
+ if self.conn:
+ self.conn.close()
+
+ def renewConnection(self):
+ """
+ It renews the connection; if left inactive for a few hours mysql
+ closes down the connection, so you might need to renew it.
+ """
+ self.close()
+ self.conn = self.engine.connect()
+
+
+if __name__ == "__main__":
+ d = Database()
+ d.connect()
+ conn = d.getSession()
+ print (conn)
+
+ qry = conn.query(Node)
+ print (qry.all())
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/db/alchemy-bak/helper.py b/lib/sunhpc/db/alchemy-bak/helper.py
new file mode 100644
index 0000000..8f6d110
--- /dev/null
+++ b/lib/sunhpc/db/alchemy-bak/helper.py
@@ -0,0 +1,369 @@
+#coding:utf-8
+
+import socket
+import sunhpc
+import sunhpc.core.utils
+import sunhpc.db.database
+from sqlalchemy import or_, and_
+from sunhpc.db.mappings.base import *
+
+class DatabaseHelper(sunhpc.db.database.Database):
+ """
+ This class extend the Database class with a set of helper methods
+ which deal with the new ORM interface (aka the objects from the
+ sunhpc.db.mappings classes).
+
+ These methods should replace the old methods in :mod:`sunhpc.commands`;
+ only methods related to the command line should remain in that
+ module, and all DB functionality should be migrated here.
+ """
+
+ def __init__(self):
+ super(DatabaseHelper, self).__init__()
+
+ self._appliances_list = None
+ self._attribute = None
+ self._frontend = None
+ self._cacheAttrs = {}
+
+ def getFrontendName(self):
+ if self._frontend:
+ return self._frontend
+
+ self._frontend = self.getHostAttr('localhost', 'Kickstart_PrivateHostname')
+ return self._frontend
+
+ def getNodesfromNames(self, names=None, managed_only=0):
+
+ list = []
+ if not names:
+
+ list = self.getSession().query(Node)
+ list = list.all()
+
+ if managed_only:
+ managed_list = []
+ for hostname in list:
+ if self.getHostAttr(hostname,
+ 'managed') == 'true':
+ managed_list.append(hostname)
+ return managed_list
+ return list
+
+ clause = sqlalchemy.sql.expression.false()
+ query = self.getSession().query(Node)
+
+ for name in names:
+ if name.find('select ') == 0: # SQL select
+ self.execute(name)
+ nodes = [i for i, in self.fetchall()]
+ clause = or_(clause, Node.name.in_(nodes))
+ elif name.find('%') >= 0: # SQL % pattern
+ clause = or_(clause, Node.name.like(name))
+
+ else:
+ clause = or_(clause, Node.name == self.getHostname(name))
+
+ # now we register the query on the table Node and append all our clauses on OR
+ query = query.filter(clause)
+ return query.all()
+
+ def getHostIp(self, hostname):
+ hostname = self.getHostname(hostname)
+
+ rows = self.search('''select n.ip from networks n, subnets s
+ WHERE
+ n.subnet = s.id and n.name = "%s"
+ and s.name = "private"
+ ''' % hostname)
+ if rows:
+ return self.fetchone()[0]
+
+ def getHostname(self, hostname=None):
+ """涉及到数据库, 数据库中域名, /etc/hosts 和 DNS解析."""
+
+ arghostname = hostname
+ if hostname:
+ rows = self.search("select * from nodes where name='%s'" % hostname)
+ if rows:
+ return hostname
+
+ '''
+ try:
+ int(hostname)
+ rows = self.search("select name from nodes where id='%d'" % hostname)
+ if rows:
+ return self.fetchone()[0]
+ except ValueError:
+ return
+ '''
+
+ if not hostname:
+ hostname = socket.gethostname().split('.')[0]
+ try:
+ addr = socket.gethostbyname(hostname)
+ if not addr == hostname:
+ (name, aliases, addrs) = socket.gethostbyaddr(addr)
+ if hostname != name and hostname not in aliases:
+ raise NameError
+ except:
+ addr = None
+ if hostname == 'localhost':
+ addr = '127.0.0.1'
+
+ if not addr and self.conn:
+ self.execute("select name from nodes where name='%s'" % hostname)
+ if self.fetchone():
+ return hostname
+
+ # let's check if the name is nodeID
+ row = self.search("select name from nodes where nodes.id = '%s'" % hostname)
+ if row:
+ (hostname, ) = self.fetchone()
+ return hostname
+
+ # let's check if the name is an alias
+ row = self.search("select name from nodes where nodes.alias = '%s'" % hostname)
+ if row:
+ (hostname, ) = self.fetchone()
+ return hostname
+
+ # let's check if this is a mac address
+ self.execute("""select nodes.name from networks, nodes where
+ nodes.id = networks.node and
+ networks.mac = '%s' """ % hostname)
+ try:
+ hostname, = self.fetchone()
+ return hostname
+ except:
+ pass
+
+ # let's check if this is a FQDN
+ n = hostname.split('.')
+ if len(n) > 1:
+ name = n[0]
+ domain = '.'.join(n[1:])
+ query = """select n.name from nodes n, networks nt, subnets s where
+ nt.subnet=s.id and nt.node=n.id and
+ s.dnszone='{}' and (nt.name='{}' or n.name='{}')
+ """.format(domain, name, name)
+ self.execute(query)
+ try:
+ hostname, = self.fetchone()
+ return hostname
+ except:
+ pass
+
+ #
+ # We only get here when the hostname is not in the database.
+ # For example, hostname: aaa, resolv.conf: search lan, /etc/hosts: x.x.x.x aaa
+ # resolves to hostname: aaa.
+ # If hostname is aaa.lan, this raises "cannot resolve host aaa.lan".
+ #
+ #
+ # For example, hostname: aaa, resolv.conf: search lan, /etc/hosts: x.x.x.x aaa.lan
+ # resolves to hostname: aaa.lan.
+ #
+ # If the hostname is not in /etc/hosts at all, we raise immediately.
+ #
+
+ try:
+ fin = open('/etc/resolv.conf', 'r')
+ except:
+ fin = None
+ if fin:
+ # e.g. a "search local hpc.org" line yields
+ # domains = ['local', 'hpc.org']
+ domains = []
+ for line in fin.readlines():
+ tokens = line[:-1].split()
+ if len(tokens) > 0 and tokens[0] == 'search':
+ domains = tokens[1:]
+ for domain in domains:
+ try:
+ name = '%s.%s' % (hostname, domain)
+ addr = socket.gethostbyname(name)
+ hostname = name
+ break
+ except:
+ pass
+ # If the hostname matches an /etc/hosts entry (it must match exactly,
+ # including the domain), addr gets the corresponding IP address.
+ # At this point hostname carries the domain, e.g. aaa.lan,
+ # and we recurse into getHostname() with that fully qualified name.
+ if addr:
+ return self.getHostname(hostname)
+
+ fin.close()
+
+ raise (sunhpc.core.utils.HostnotfoundException(\
+ 'cannot resolve host "%s"' % hostname))
+
+ if addr == '127.0.0.1': # allow localhost to be valid
+ if arghostname == None:
+ return 'localhost'
+ else:
+ return self.getHostname()
+
+ if self.conn:
+ rows = self.search('select nodes.name from '
+ 'networks, nodes where '
+ 'nodes.id=networks.node and ip="%s"' % addr)
+ if not rows:
+ rows = self.search('select nodes.name '
+ 'from networks, nodes where '
+ 'nodes.id=networks.node and '
+ 'networks.name="%s"' % hostname)
+ if not rows:
+ raise (sunhpc.core.utils.HostnotfoundException(\
+ 'host "%s" is not in cluster' % hostname))
+ hostname, = self.fetchone()
+
+ return hostname
+
+ def getHostAttrs(self, hostname):
+
+ session = self.getSession()
+ if isinstance(hostname, str):
+ hostname = self.getHostname(hostname)
+ node = session.query(Node).filter(Node.name == hostname).one()
+
+ elif isinstance(hostname, Node):
+ node = hostname
+ hostname = node.name
+ else:
+ assert False, "hostname must be either a string with a hostname or a Node."
+
+ attrs = {}
+ attrs['hostname'] = hostname
+ attrs['rack'] = str(node.rack)
+ attrs['rank'] = str(node.rank)
+
+ query = """select attr,value from attributes, nodes where
+ nodes.id = attributes.node and nodes.name='%s'""" % hostname
+ for (attr, value) in self.conn.execute(query):
+ attrs[attr] = value
+
+ return attrs
+
+ def getHostAttr(self, hostname, attr):
+ return self.getHostAttrs(hostname).get(attr)
+
+ def checkHostnameValidity(self, hostname):
+ """
+ check that the given host name is valid
+ it checks that the hostname:
+ - it does not contain any .
+ - it is not already used
+ - it is not in the form of rack<number>
+ - it is not an alias
+ - it is not a mac address
+ """
+ if '.' in hostname:
+ raise (sunhpc.core.utils.CommandError('Hostname %s cannot contain a dot.'
+ % hostname))
+
+ msg = ''
+ if hostname.startswith('rack'):
+ number = hostname.split('rack')[1]
+ try:
+ int(number)
+ msg = ('Hostname %s cannot be in the form ' \
+ + 'of rack<number>.\n') % hostname
+ msg += 'Select a different hostname.'
+ except ValueError:
+ pass
+ if msg:
+ raise (sunhpc.core.utils.CommandError(msg))
+
+ try:
+ host = self.getHostname(hostname)
+ if host:
+ msg = 'Node "%s" already exists.\n' % hostname
+ msg += 'Select a different hostname, cabinet '
+ msg += 'and/or rank value.'
+ except (sunhpc.core.utils.HostnotfoundException, NameError):
+ # good, Host does not exists.
+ return
+ raise sunhpc.core.utils.CommandError(msg)
+
+ def getNodeId(self, host):
+ """Lookup hostname in nodes table. Host may be a name
+ or an IP address. Returns None if not found."""
+ try:
+ return int(host)
+ except Exception:
+ pass
+
+ self.execute('select id from nodes where name="%s"' % host)
+ try:
+ nodeid, = self.fetchone()
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.ip="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.mac="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ return nodeid
+
+ def setNewHostAttr(self, hostname, attr, value):
+
+ session = self.getSession()
+ hostid = self.getNodeId(hostname)
+ try:
+ old_attr = Attribute.loadOne(session, attr=attr, node=hostid)
+ except sqlalchemy.orm.exc.NoResultFound:
+ new_attr = Attribute(attr=attr, value=value, node=hostid)
+ session.add(new_attr)
+ session.commit()
+ return
+
+ old_value = old_attr.value
+ old_attr.value = value
+
+ # If the attribute already existed, re-save the previous value under a backup attribute with the postfix appended.
+ if not attr.endswith(attr_postfix):
+ self.setCategoryAttr(hostname, attr + attr_postfix, old_value)
+
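+if __name__ == '__main__':
+ # Hedged usage sketch (added for illustration): assumes a populated sunhpc
+ # database on the frontend node.
+ db = DatabaseHelper()
+ db.connect()
+ print (db.getHostname())
+ print (db.getHostAttrs('localhost'))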
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/db/alchemy-bak/mappings/__init__.py b/lib/sunhpc/db/alchemy-bak/mappings/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sunhpc/db/alchemy-bak/mappings/__init__.py
diff --git a/lib/sunhpc/db/alchemy-bak/mappings/base.py b/lib/sunhpc/db/alchemy-bak/mappings/base.py
new file mode 100644
index 0000000..0191b06
--- /dev/null
+++ b/lib/sunhpc/db/alchemy-bak/mappings/base.py
@@ -0,0 +1,219 @@
+#coding:utf-8
+import sqlalchemy.ext.declarative
+import sqlalchemy.orm
+from sqlalchemy import *
+
+Base = sqlalchemy.ext.declarative.declarative_base()
+
+class SunhpcBase(object):
+ """Additional base class of Sunhpc ORM hierarchy which includes some
+ helper methods for all classes"""
+
+ @property
+ def session(self):
+ """Singelton which return the session of the object"""
+ return sqlalchemy.orm.session.object_session(self)
+
+
+ def delete(self):
+ """instance method to autodelete the instance which calls it
+
+ so you can use
+ node.delete()"""
+ self.session.delete(self)
+
+ @classmethod
+ def loadOne(cls, session, **kwargs):
+ """ """
+ return cls.load(session, **kwargs).one()
+
+ @classmethod
+ def load(cls, session, **kwargs):
+ """
+ this method allow us to run query on all the mapping objects
+ simply using
+
+ e.g.::
+
+ node = Node.load(session, name='compute-0-0', cpus=2)
+ nic = Network.load(session, name='compute-0-0', device='eth0')
+
+ taken from:
+ http://petrushev.wordpress.com/2010/06/22/sqlalchemy-base-model/
+ """
+ q = session.query(cls)
+ filters = [getattr(cls, field_name)==kwargs[field_name] \
+ for field_name in kwargs]
+ return q.filter(and_(*filters))
+
+class Node(SunhpcBase, Base):
+ __tablename__ = 'nodes'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ name = Column('Name', String(128))
+ cpus = Column('CPUs', Integer, nullable=False, default=1)
+ rack = Column('Rack', Integer)
+ rank = Column('Rank', Integer)
+ arch = Column('Arch', String(32))
+ os = Column('OS', Enum(u'linux', u'sunos'), nullable=False, default=u'linux')
+ alias = Column('Alias', String(128), default='')
+ flags = Column('Flags', String(128), default='')
+ status = Column('Status', String(128), default='os')
+
+ networks = sqlalchemy.orm.relationship("Network", backref="nodes")
+ public_keys = sqlalchemy.orm.relationship("PublicKey", backref="nodes")
+
+ def __repr__(self):
+ return "<Node(name='%s')>" % (self.name)
+
+class Network(SunhpcBase, Base):
+ __tablename__ = 'networks'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ node = Column('Node', Integer, ForeignKey('nodes.ID'))
+ mac = Column('MAC', String(64))
+ ip = Column('IP', String(32))
+ name = Column('Name', String(128))
+ device = Column('Device', String(32))
+ subnet = Column('Subnet', Integer, ForeignKey('subnets.ID'))
+
+class Subnet(SunhpcBase, Base):
+ __tablename__ = 'subnets'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ name = Column('name', String(32), nullable=False, unique=True)
+ dnszone = Column('dnszone', String(64), nullable=False, unique=True)
+ subnet = Column('subnet', String(32), nullable=False)
+ netmask = Column('netmask', String(32), nullable=False)
+ mtu = Column('mtu', Integer, default=1500)
+ servedns = Column('servedns', Boolean, default=False)
+
+ networks = sqlalchemy.orm.relationship("Network", backref="subnets")
+
+class GlobalRoute(SunhpcBase, Base):
+ __tablename__ = 'globalroutes'
+ __table_args__ = {}
+
+ #column definitions
+ gateway = Column('Gateway', String(32), nullable=False)
+ netmask = Column('Netmask', String(32), primary_key=True, nullable=False)
+ network = Column('Network', String(32), primary_key=True, nullable=False)
+ subnet = Column('Subnet', Integer, ForeignKey('subnets.ID'))
+
+class PublicKey(SunhpcBase, Base):
+ __tablename__ = 'publickeys'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ node = Column('Node', Integer, ForeignKey('nodes.ID'), nullable=False)
+ public_key = Column('Public_Key', String(4096))
+ description = Column('Description', String(4096))
+
+class SecNode(SunhpcBase, Base):
+ __tablename__ = 'secnodes'
+ __table_args__ = {}
+
+ #column definitions
+ attr = Column('Attr', String(128), primary_key=True, nullable=False)
+ enc = Column('Enc', String(64))
+ node = Column('Node', Integer, primary_key=True, nullable=False)
+ value = Column('Value', TEXT())
+
+class Attribute(SunhpcBase, Base):
+ __tablename__ = 'attributes'
+ __table_args__ = {}
+
+ #column definitions
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ attr = Column('Attr', String(128), nullable=False)
+ value = Column('Value', TEXT())
+ shadow = Column('Shadow', TEXT())
+ node = Column('Node', Integer, ForeignKey('nodes.ID'), nullable=False)
+
+class Firewall(SunhpcBase, Base):
+ __tablename__ = 'firewalls'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ rulename = Column('Rulename', String(128), nullable=False)
+ service = Column('Service', String(256))
+ protocol = Column('Protocol', String(256))
+ ports = Column('Flags', String(256))
+ action = Column('Action', String(256))
+ comment = Column('Comment', String(256))
+ node = Column('Node', Integer, ForeignKey('nodes.ID'), nullable=False)
+
+class Roll(SunhpcBase, Base):
+ __tablename__ = 'rolls'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ name = Column('Name', String(128), nullable=False)
+ version = Column('Version', String(32), nullable=False)
+ arch = Column('Arch', String(32), nullable=False)
+ os = Column('OS', String(32), nullable=False, default=u'linux')
+ enabled = Column('Enabled', Enum(u'yes', u'no'), nullable=False, default=u'yes')
+
+class Bootaction(SunhpcBase, Base):
+ __tablename__ = 'bootactions'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ action = Column('Action', String(256))
+ kernel = Column('Kernel', String(256))
+ ramdisk = Column('Ramdisk', String(256))
+ args = Column('Args', String(1024))
+
+class Distribution(SunhpcBase, Base):
+ __tablename__ = 'distributions'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ name = Column('Name', String(32), nullable=False, default='')
+ os = Column('OS', String(32), nullable=False, default='')
+ Release = Column('Release', String(32), nullable=False, default='')
+
+class SecGlobal(SunhpcBase, Base):
+ __tablename__ = 'secglobals'
+ __table_args__ = {}
+
+ attr = Column('Attr', String(128), primary_key=True, nullable=False)
+ enc = Column('Enc', String(64))
+ value = Column('Value', TEXT())
+
+
+class Partition(SunhpcBase, Base):
+ __tablename__ = 'partitions'
+ __table_args__ = {}
+
+ device = Column('Device', String(128), nullable=False)
+ formatFlags = Column('FormatFlags', String(128), nullable=False)
+ fsType = Column('FsType', String(128), nullable=False)
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ mountpoint = Column('Mountpoint', String(128), nullable=False)
+ node = Column('Node', Integer, nullable=False)
+ partitionFlags = Column('PartitionFlags', String(128), nullable=False)
+ partitionID = Column('PartitionID', String(128), nullable=False)
+ partitionSize = Column('PartitionSize', String(128), nullable=False)
+ sectorStart = Column('SectorStart', String(128), nullable=False)
+
+class Machine(SunhpcBase, Base):
+ __tablename__ = 'machines'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False)
+ node = Column('Node', Integer, nullable=False)
+ name = Column('Name', String(128), nullable=False)
+ vender = Column('Vender', String(128), nullable=False)
+ serial = Column('Serial', String(128), nullable=False)
+ cpus = Column('CPUs', Integer, nullable=False, default=1)
+ core = Column('Cores', Integer, nullable=False)
+ model = Column('Model', String(256), nullable=False)
+ memnumber = Column('MemNumber', Integer, nullable=False, default=1)
+ memsize = Column('MemSize', String(128), nullable=False)
+
+
+
diff --git a/lib/sunhpc/db/database.py b/lib/sunhpc/db/database.py
new file mode 100644
index 0000000..9146059
--- /dev/null
+++ b/lib/sunhpc/db/database.py
@@ -0,0 +1,204 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import threading
+import subprocess
+import sqlalchemy
+import sqlalchemy.exc
+from sqlalchemy import create_engine
+from sunhpc.db.mappings.base import *
+
+threadlocal = threading.local()
+
+class Database(object):
+ """
+ This class should proxy all the connections to the database.
+
+ There are two main internal objects inside this class which come from
+ sqlalchemy:
+
+ - session: this is used by the ORM layer
+ - connection: this is used by the execute statement, i.e. whenever
+ you run raw SQL
+
+ These two objects have two separate DB connections, which means
+ that DB status can be different when queried through them.
+
+ Usage Example::
+
+ db = Database()
+ db.setVerbose(True)
+ db.connect()
+ """
+
+ def __init__(self):
+ self.conn = None
+ self.engine = None
+ self.verbose = False
+ self.results = False
+ self._dbPath = "/opt/sunhpc/data"
+ self._dbFile = "sunhpc.db"
+ self._dbLock = ".database"
+ self._datafile = os.path.join(self._dbPath, self._dbFile)
+
+ def setVerbose(self, verbose):
+ """
+ If verbose is true, all the SQL will be printed to
+ stdout. This function must be called before connect().
+
+ :type verbose: bool
+ :param verbose: if verbose should be set to True
+
+ """
+ self.verbose = verbose
+
+ def getDBFile(self):
+ return self._datafile
+
+ def connect(self):
+ """
+ It starts the connection to the DB and creates all the internal
+ data structures.
+ """
+
+ if 'SUNHPCDEBUG' in os.environ:
+ self.setVerbose(True)
+
+ url = 'sqlite:///' + self._dbPath + '/' + self._dbFile
+
+ if self.verbose:
+ import logging
+ logging.basicConfig()
+ logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
+
+ if self.verbose:
+ print ("Database connection URL: ", url)
+
+ self.engine = create_engine(url)
+ self.conn = self.engine.connect()
+
+ def reconnect(self):
+ self.engine.dispose()
+ self.conn = self.engine.connect()
+
+ def getSession(self):
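+ # One ORM session per thread: the session is cached on the module-level
+ # threading.local() object above and reused on subsequent calls.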
+ session = getattr(threadlocal, "session", None)
+
+ if session:
+ return session
+ elif self.engine:
+ Session = sqlalchemy.orm.sessionmaker(bind=self.engine)
+ session = Session()
+ setattr(threadlocal, "session", session)
+ return session
+ else:
+ return None
+
+ def closeSession(self):
+ session = getattr(threadlocal, "session", None)
+ if session:
+ session.close()
+ setattr(threadlocal, "session", None)
+ return
+ return
+
+ def commit(self):
+ """
+ Commit the current session if it exists. *It does not touch the
+ connection.*
+ """
+ session = self.getSession()
+ if session:
+ session.commit()
+ else:
+ pass
+
+ def search(self, command):
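+ # Run the query twice: the first pass counts the rows, the second leaves
+ # self.results positioned so callers can still fetchone()/fetchall().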
+ self.execute(command)
+ rows = self.fetchall()
+ self.execute(command)
+ return len(rows)
+
+
+ def execute(self, command):
+ if self.conn:
+ if '%' in command:
+ command = command.replace('%', '%%')
+ try:
+ self.results = self.conn.execute(command)
+ except sqlalchemy.exc.OperationalError as e:
+ self.renewConnection()
+ self.results = self.conn.execute(command)
+
+ return self.results.rowcount
+ else:
+ return None
+
+
+ def fetchone(self):
+ """
+ Fetch one row from the results of the previous query
+
+ :rtype: tuple
+ :return: a tuple containing the values of the fetched row.
+ It really returns a :class:`sqlalchemy.engine.result.RowProxy`
+ but it can be treated as a tuple
+ """
+ if self.results:
+ return self.results.fetchone()
+ return ()
+
+ def fetchall(self):
+ """
+ Fetch all rows from the results of the previous query
+
+ :rtype: list
+ :return: a list of tuples containing the values of the fetched rows
+ """
+ if self.results:
+ return self.results.fetchall()
+ return ()
+
+ def close(self):
+ """
+ It closes the connection only. If you want to release all the DB
+ resources you also need to close the session with :meth:`closeSession`.
+ """
+ if self.results:
+ self.results.close()
+ self.results = None
+ if self.conn:
+ self.conn.close()
+
+ def renewConnection(self):
+ """
+ It renews the connection. If it has been inactive for a few hours the
+ database server may close it, so you might need to renew it.
+ """
+ self.close()
+ self.conn = self.engine.connect()
+
+
+if __name__ == "__main__":
+ d = Database()
+ d.connect()
+ conn = d.getSession()
+ print (conn)
+
+ qry = conn.query(Node)
+ print (qry.all())
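+
+ # Illustrative sketch (not part of the original demo): the same object also
+ # supports raw SQL through execute()/fetchall(), which use the plain
+ # connection rather than the ORM session. Table names follow the mappings above.
+ d.execute("select id, name from nodes")
+ for row in d.fetchall():
+ print (row)
+ d.close()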
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/db/helper.py b/lib/sunhpc/db/helper.py
new file mode 100644
index 0000000..8f6d110
--- /dev/null
+++ b/lib/sunhpc/db/helper.py
@@ -0,0 +1,369 @@
+#coding:utf-8
+
+import socket
+import sunhpc
+import sunhpc.core.utils
+import sunhpc.db.database
+import sqlalchemy
+from sqlalchemy import or_, and_
+from sunhpc.db.mappings.base import *
+
+class DatabaseHelper(sunhpc.db.database.Database):
+ """
+ This class extends the Database class with a set of helper methods
+ which deal with the new ORM interface (i.e. the objects from the
+ sunhpc.db.mappings classes).
+
+ These methods should replace the old methods in :mod:`sunhpc.commands`;
+ only methods related to the command line should remain in that
+ module, all DB functionality should be migrated here.
+ """
+
+ def __init__(self):
+ super(DatabaseHelper, self).__init__()
+
+ self._appliances_list = None
+ self._attribute = None
+ self._frontend = None
+ self._cacheAttrs = {}
+
+ def getFrontendName(self):
+ if self._frontend:
+ return self._frontend
+
+ self._frontend = self.getHostAttr('localhost', 'Kickstart_PrivateHostname')
+ return self._frontend
+
+ def getNodesfromNames(self, names=None, managed_only=0):
+
+ list = []
+ if not names:
+
+ list = self.getSession().query(Node)
+ list = list.all()
+
+ if managed_only:
+ managed_list = []
+ for hostname in list:
+ if self.getHostAttr(hostname,
+ 'managed') == 'true':
+ managed_list.append(hostname)
+ return managed_list
+ return list
+
+ clause = sqlalchemy.sql.expression.false()
+ query = self.getSession().query(Node)
+
+ for name in names:
+ if name.find('select ') == 0: # SQL select
+ self.execute(name)
+ nodes = [i for i, in self.fetchall()]
+ clause = or_(clause, Node.name.in_(nodes))
+ elif name.find('%') >= 0: # SQL % pattern
+ clause = or_(clause, Node.name.like(name))
+
+ else:
+ clause = or_(clause, Node.name == self.getHostname(name))
+
+ # now we register the query on the table Node and append all our clauses on OR
+ query = query.filter(clause)
+ return query.all()
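+
+ # Usage sketch (hypothetical host names): names may be plain hostnames,
+ # SQL LIKE patterns, or full "select ..." statements, e.g.:
+ #
+ # db.getNodesfromNames(['compute-0-0'])
+ # db.getNodesfromNames(['compute-0-%'])
+ # db.getNodesfromNames(['select name from nodes where cpus > 4'])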
+
+ def getHostIp(self, hostname):
+ hostname = self.getHostname(hostname)
+
+ rows = self.search('''select n.ip from networks n, subnets s
+ WHERE
+ n.subnet = s.id and n.name = "%s"
+ and s.name = "private"
+ ''' % hostname)
+ if rows:
+ return self.fetchone()[0]
+
+ def getHostname(self, hostname=None):
+ """涉及到数据库, 数据库中域名, /etc/hosts 和 DNS解析."""
+
+ arghostname = hostname
+ if hostname:
+ rows = self.search("select * from nodes where name='%s'" % hostname)
+ if rows:
+ return hostname
+
+ '''
+ try:
+ int(hostname)
+ rows = self.search("select name from nodes where id='%d'" % hostname)
+ if rows:
+ return self.fetchone()[0]
+ except ValueError:
+ return
+ '''
+
+ if not hostname:
+ hostname = socket.gethostname().split('.')[0]
+ try:
+ addr = socket.gethostbyname(hostname)
+ if not addr == hostname:
+ (name, aliases, addrs) = socket.gethostbyaddr(addr)
+ if hostname != name and hostname not in aliases:
+ raise NameError
+ except:
+ addr = None
+ if hostname == 'localhost':
+ addr = '127.0.0.1'
+
+ if not addr and self.conn:
+ self.execute("select name from nodes where name='%s'" % hostname)
+ if self.fetchone():
+ return hostname
+
+ # let's check if the name is nodeID
+ row = self.search("select name from nodes where nodes.id = '%s'" % hostname)
+ if row:
+ (hostname, ) = self.fetchone()
+ return hostname
+
+ # let's check if the name is an alias
+ row = self.search("select name from nodes where nodes.alias = '%s'" % hostname)
+ if row:
+ (hostname, ) = self.fetchone()
+ return hostname
+
+ # let's check if this is a mac address
+ self.execute("""select nodes.name from networks, nodes where
+ nodes.id = networks.node and
+ networks.mac = '%s' """ % hostname)
+ try:
+ hostname, = self.fetchone()
+ return hostname
+ except:
+ pass
+
+ # let's check if this is a FQDN
+ n = hostname.split('.')
+ if len(n) > 1:
+ name = n[0]
+ domain = '.'.join(n[1:])
+ query = """select n.name from nodes n, networks nt, subnets s where
+ nt.subnet=s.id and nt.node=n.id and
+ s.dnszone='{}' and (nt.name='{}' or n.name='{}')
+ """.format(domain, name, name)
+ self.execute(query)
+ try:
+ hostname, = self.fetchone()
+ return hostname
+ except:
+ pass
+
+ #
+ # We only get here when the hostname is not in the database.
+ # E.g. hostname: aaa, resolv: search lan, /etc/hosts: x.x.x.x aaa
+ # resolves to hostname: aaa
+ # If hostname is aaa.lan, this raises "cannot resolve host aaa.lan".
+ #
+ # E.g. hostname: aaa, resolv: search lan, /etc/hosts: x.x.x.x aaa.lan
+ # resolves to hostname: aaa.lan
+ #
+ # If the hostname is not in /etc/hosts at all, we raise immediately.
+ #
+
+ try:
+ fin = open('/etc/resolv.conf', 'r')
+ except:
+ fin = None
+ if fin:
+ # e.g. "search local hpc.org"
+ # domains = ['local', 'hpc.org']
+ domains = []
+ for line in fin.readlines():
+ tokens = line[:-1].split()
+ if len(tokens) > 0 and tokens[0] == 'search':
+ domains = tokens[1:]
+ for domain in domains:
+ try:
+ name = '%s.%s' % (hostname, domain)
+ addr = socket.gethostbyname(name)
+ hostname = name
+ break
+ except:
+ pass
+ # If the hostname matches an entry in /etc/hosts (the match must be
+ # exact, domain included), addr gets the corresponding IP address.
+ # At that point hostname carries the domain, e.g. aaa.lan, and we
+ # recurse into getHostname() with the fully qualified name.
+ if addr:
+ return self.getHostname(hostname)
+
+ fin.close()
+
+ raise (sunhpc.core.utils.HostnotfoundException(\
+ 'cannot resolve host "%s"' % hostname))
+
+ if addr == '127.0.0.1': # allow localhost to be valid
+ if arghostname == None:
+ return 'localhost'
+ else:
+ return self.getHostname()
+
+ if self.conn:
+ rows = self.search('select nodes.name from '
+ 'networks, nodes where '
+ 'nodes.id=networks.node and ip="%s"' % addr)
+ if not rows:
+ rows = self.search('select nodes.name '
+ 'from networks, nodes where '
+ 'nodes.id=networks.node and '
+ 'networks.name="%s"' % hostname)
+ if not rows:
+ raise (sunhpc.core.utils.HostnotfoundException(\
+ 'host "%s" is not in cluster' % hostname))
+ hostname, = self.fetchone()
+
+ return hostname
+
+ def getHostAttrs(self, hostname):
+
+ session = self.getSession()
+ if isinstance(hostname, str):
+ hostname = self.getHostname(hostname)
+ node = session.query(Node).filter(Node.name == hostname).one()
+
+ elif isinstance(hostname, Node):
+ node = hostname
+ hostname = node.name
+ else:
+ assert False, "hostname must be either a string with a hostname or a Node."
+
+ attrs = {}
+ attrs['hostname'] = hostname
+ attrs['rack'] = str(node.rack)
+ attrs['rank'] = str(node.rank)
+
+ query = """select attr,value from attributes, nodes where
+ nodes.id = attributes.node and nodes.name='%s'""" % hostname
+ for (attr, value) in self.conn.execute(query):
+ attrs[attr] = value
+
+ return attrs
+
+ def getHostAttr(self, hostname, attr):
+ return self.getHostAttrs(hostname).get(attr)
+
+ def checkHostnameValidity(self, hostname):
+ """
+ Check that the given host name is valid.
+ It checks that the hostname:
+ - does not contain any dots
+ - is not already in use
+ - is not in the form of rack<number>
+ - is not an alias
+ - is not a MAC address
+ """
+ if '.' in hostname:
+ raise (sunhpc.core.utils.CommandError('Hostname %s cannot contain any dots.'
+ % hostname))
+
+ msg = ''
+ if hostname.startswith('rack'):
+ number = hostname.split('rack')[1]
+ try:
+ int(number)
+ msg = ('Hostname %s cannot be in the form ' \
+ + 'of rack<number>.\n') % hostname
+ msg += 'Select a different hostname.'
+ except ValueError:
+ pass
+ if msg:
+ raise (sunhpc.core.utils.CommandError(msg))
+
+ try:
+ host = self.getHostname(hostname)
+ if host:
+ msg = 'Node "%s" already exists.\n' % hostname
+ msg += 'Select a different hostname, cabinet '
+ msg += 'and/or rank value.'
+ except (sunhpc.core.utils.HostnotfoundException, NameError):
+ # good, host does not exist.
+ return
+ raise sunhpc.core.utils.CommandError(msg)
+
+ def getNodeId(self, host):
+ """Lookup hostname in nodes table. Host may be a name
+ or an IP address. Returns None if not found."""
+ try:
+ return int(host)
+ except Exception:
+ pass
+
+ self.execute('select id from nodes where name="%s"' % host)
+ try:
+ nodeid, = self.fetchone()
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.ip="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.mac="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ return nodeid
+
+ def setNewHostAttr(self, hostname, attr, value):
+
+ session = self.getSession()
+ hostid = self.getNodeId(hostname)
+ try:
+ old_attr = Attribute.loadOne(session, attr=attr, node=hostid)
+ except sqlalchemy.orm.exc.NoResultFound:
+ new_attr = Attribute(attr=attr, value=value, node=hostid)
+ session.add(new_attr)
+ session.commit()
+ return
+
+ old_value = old_attr.value
+ old_attr.value = value
+
+ # If the attribute already exists in the database, re-save the old value under the attribute name with the "_old" suffix (attr_postfix) appended.
+ if not attr.endswith(attr_postfix):
+ self.setCategoryAttr(hostname, attr + attr_postfix, old_value)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/db/mappings/__init__.py b/lib/sunhpc/db/mappings/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sunhpc/db/mappings/__init__.py
diff --git a/lib/sunhpc/db/mappings/base.py b/lib/sunhpc/db/mappings/base.py
new file mode 100644
index 0000000..e397d68
--- /dev/null
+++ b/lib/sunhpc/db/mappings/base.py
@@ -0,0 +1,177 @@
+#coding:utf-8
+from sqlalchemy import *
+import sqlalchemy.orm
+import sqlalchemy.ext.declarative
+
+Base = sqlalchemy.ext.declarative.declarative_base()
+
+class SunhpcBase(object):
+ """Additional base class of Sunhpc ORM hierarchy which includes some
+ helper methods for all classes"""
+
+ @property
+ def session(self):
+ """Singelton which return the session of the object"""
+ return sqlalchemy.orm.session.object_session(self)
+
+
+ def delete(self):
+ """instance method to autodelete the instance which calls it
+
+ so you can use
+ node.delete()"""
+ self.session.delete(self)
+
+ @classmethod
+ def loadOne(cls, session, **kwargs):
+ """ """
+ return cls.load(session, **kwargs).one()
+
+ @classmethod
+ def load(cls, session, **kwargs):
+ """
+ This method allows us to run queries on any of the mapping objects
+ simply by using
+
+ e.g.::
+
+ node = Node.load(session, name='compute-0-0', cpus=2)
+ nic = Network.load(session, name='compute-0-0', device='eth0')
+
+ taken from:
+ http://petrushev.wordpress.com/2010/06/22/sqlalchemy-base-model/
+ """
+ q = session.query(cls)
+ filters = [getattr(cls, field_name)==kwargs[field_name] \
+ for field_name in kwargs]
+ return q.filter(and_(*filters))
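+
+ # Usage sketch (hypothetical host name, mirroring the docstring above):
+ # loadOne() raises sqlalchemy.orm.exc.NoResultFound when no row matches,
+ # so callers typically guard it:
+ #
+ # try:
+ # node = Node.loadOne(session, name='compute-0-0')
+ # except sqlalchemy.orm.exc.NoResultFound:
+ # node = None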
+
+class Attribute(SunhpcBase, Base):
+ __tablename__ = 'attributes'
+ __table_args__ = {}
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ attr = Column('Attr', String(128), nullable=False)
+ value = Column('Value', TEXT())
+ shadow = Column('Shadow', TEXT())
+ categoryID = Column('Category', Integer, ForeignKey('categories.ID'), nullable=False)
+ catindexID = Column('Catindex', Integer, ForeignKey('catindex.ID'), nullable=False)
+
+class Bootaction(SunhpcBase, Base):
+ __tablename__ = 'bootactions'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ action = Column('Action', String(256))
+ kernel = Column('Kernel', String(256))
+ ramdisk = Column('Ramdisk', String(256))
+ args = Column('Args', String(1024))
+
+class Category(SunhpcBase, Base):
+ __tablename__ = 'categories'
+ __table_args__ = {}
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ name = Column('Name', String(64), nullable=False)
+ description = Column('Description', String(512))
+ attributes = sqlalchemy.orm.relationship('Attribute', backref='category')
+ catindexes = sqlalchemy.orm.relationship('Catindex', backref='category')
+
+class Catindex(SunhpcBase, Base):
+ __tablename__ = 'catindex'
+ __table_args__ = {}
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ name = Column('Name', String(64), nullable=False)
+ categoryID = Column('Category', Integer, ForeignKey('categories.ID'), nullable=False)
+ attributes = sqlalchemy.orm.relationship('Attribute', backref='catindex')
+
+class Node(SunhpcBase, Base):
+ __tablename__ = 'nodes'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ name = Column('Name', String(128))
+ cpus = Column('CPUs', Integer, nullable=False, default=1)
+ rack = Column('Rack', Integer)
+ rank = Column('Rank', Integer)
+ arch = Column('Arch', String(32))
+ os = Column('OS', Enum(u'linux', u'sunos'), nullable=False, default=u'linux')
+ alias = Column('Alias', String(128), default='')
+ status = Column('Status', String(128), default='install')
+
+ networks_ID = sqlalchemy.orm.relationship("Network", backref="nodes")
+ pub_keys_ID = sqlalchemy.orm.relationship("PublicKey", backref="nodes")
+ def __repr__(self):
+ return "<Node(name='%s')>" % (self.name)
+
+class Network(SunhpcBase, Base):
+ __tablename__ = 'networks'
+ __table_args__ = {}
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ node = Column('Node', Integer, ForeignKey('nodes.ID'))
+ mac = Column('MAC', String(64))
+ ip = Column('IP', String(32))
+ name = Column('Name', String(128))
+ device = Column('Device', String(32))
+ subnet_ID = Column('Subnet', Integer, ForeignKey('subnets.ID'))
+
+class Subnet(SunhpcBase, Base):
+ __tablename__ = 'subnets'
+ __table_args__ = {}
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ name = Column('name', String(32), nullable=False, unique=True)
+ dnszone = Column('dnszone', String(64), nullable=False, unique=True)
+ subnet = Column('subnet', String(32), nullable=False)
+ netmask = Column('netmask', String(32), nullable=False)
+ mtu = Column('mtu', Integer, default=1500)
+ servedns = Column('servedns', Boolean, default=False)
+ networks = sqlalchemy.orm.relationship('Network', backref='subnet')
+
+class PublicKey(SunhpcBase, Base):
+ __tablename__ = 'publickeys'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ node = Column('Node', Integer, ForeignKey('nodes.ID'), nullable=False)
+ public_key = Column('Public_Key', String(4096))
+ description = Column('Description', String(4096))
+
+class Roll(SunhpcBase, Base):
+ __tablename__ = 'rolls'
+ __table_args__ = {}
+
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ name = Column('Name', String(128), nullable=False)
+ version = Column('Version', String(32), nullable=False)
+ arch = Column('Arch', String(32), nullable=False)
+ os = Column('OS', Enum('linux', 'sunos'), nullable=False, default=u'linux')
+ enabled = Column('Enabled', Enum(u'yes', u'no'), nullable=False, default=u'yes')
+
+class SecNode(SunhpcBase, Base):
+ __tablename__ = 'secnodes'
+ __table_args__ = {}
+ attr = Column('Attr', String(128), primary_key=True, nullable=False)
+ enc = Column('Enc', String(64))
+ node = Column('Node', Integer, primary_key=True, nullable=False)
+ value = Column('Value', TEXT())
+
+class SecGlobal(SunhpcBase, Base):
+ __tablename__ = 'secglobals'
+ __table_args__ = {}
+ attr = Column('Attr', String(128), primary_key=True, nullable=False)
+ enc = Column('Enc', String(64))
+ value = Column('Value', TEXT())
+
+class Partition(SunhpcBase, Base):
+ __tablename__ = 'partitions'
+ __table_args__ = {}
+ device = Column('Device', String(128), nullable=False)
+ formatFlags = Column('FormatFlags', String(128), nullable=False)
+ fsType = Column('FsType', String(128), nullable=False)
+ ID = Column('ID', Integer, primary_key=True, nullable=False, autoincrement=True)
+ mountpoint = Column('Mountpoint', String(128), nullable=False)
+ node = Column('Node', Integer, nullable=False)
+ partitionFlags = Column('PartitionFlags', String(128), nullable=False)
+ partitionID = Column('PartitionID', String(128), nullable=False)
+ partitionSize = Column('PartitionSize', String(128), nullable=False)
+ sectorStart = Column('SectorStart', String(128), nullable=False)
+
+
diff --git a/lib/sunhpc/db/sqlite-bak/database.py b/lib/sunhpc/db/sqlite-bak/database.py
new file mode 100644
index 0000000..3885ff5
--- /dev/null
+++ b/lib/sunhpc/db/sqlite-bak/database.py
@@ -0,0 +1,165 @@
+#coding:utf-8
+
+import os
+import sys
+import sunhpc
+import sqlite3
+import threading
+import subprocess
+
+threadlocal = threading.local()
+
+class Database(object):
+ """
+ This class should proxy all the connections to the database.
+
+ It keeps a sqlite3 connection and a cursor; the cursor is also handed
+ out as the thread-local "session".
+
+ Usage Example::
+
+ db = Database()
+ db.connect()
+ """
+
+ def __init__(self):
+ self.conn = None
+ self.curs = None
+ self.results = None
+ self._dbPath = "/opt/sunhpc/data"
+ self._dbFile = "sunhpc.db"
+ self._datafile = os.path.join(self._dbPath, self._dbFile)
+
+ def connect(self):
+ """
+ It starts the connection to the DB and creates all the internal
+ data structures.
+ """
+ self.conn = sqlite3.connect(self._datafile, isolation_level=None)
+ self.curs = self.conn.cursor()
+
+ def getSession(self):
+ session = getattr(threadlocal, "session", None)
+
+ if session:
+ return session
+ elif self.conn:
+ session = self.curs
+ setattr(threadlocal, "session", session)
+ return session
+ else:
+ return None
+
+ def closeSession(self):
+ session = getattr(threadlocal, "session", None)
+ if session:
+ session.close()
+ setattr(threadlocal, "session", None)
+
+ def commit(self):
+ """
+ Commit the current session if it exists. *It does not touch the
+ connection.*
+ """
+ session = self.getSession()
+ if session:
+ self.conn.commit()
+ else:
+ pass
+
+ def fetchone(self):
+ """
+ Fetch one row from the results of the previous query
+
+ :rtype: tuple
+ :return: a tuple containing the values of the fetched row.
+ """
+ if self.results:
+ return self.results.fetchone()
+ return ()
+
+ def fetchall(self):
+ """
+ Fetch all rows from the results of the previous query
+
+ :rtype: list
+ :return: a list of tuples containing the values of the fetched rows
+ """
+ if self.results:
+ return self.results.fetchall()
+ return ()
+
+ def execute(self, command):
+
+ if self.conn:
+ if '%' in command:
+ command = command.replace('%', '%%')
+ try:
+ self.results = self.curs.execute(command)
+ except sqlite3.OperationalError as e:
+ print ('SQliteErr: ', repr(e))
+ sys.exit(-1)
+ #
+ # lastrowid returns the id of the inserted row.
+ # rowcount is normally 1 for insert, update and delete statements;
+ # other statements, e.g. select or create table, return -1.
+ # When passing parameters to execute(), the number of arguments must
+ # match the number of ? placeholders, e.g.:
+ # cursor.execute('select * from user where name=? and pwd=?', ('abc', 'password'))
+ #
+ # print ('lastrowid: ', self.results.lastrowid)
+ # print ('rowcount : ', self.results.rowcount)
+ # lastrowid: 0
+ # rowcount : -1
+ # bool : (-1, 1, 2, ) -> True, (0, None) -> False
+ return self.results.rowcount
+ else:
+ return None
+
+ def close(self):
+ if self.results:
+ self.results.close()
+ self.results = None
+ if self.conn:
+ self.conn.close()
+
+
+ def test(self):
+ qry = 'create table userb (id varchar(20) primary key, name varchar(20))'
+ self.execute(qry)
+
+ # insert
+ #qry = 'insert into user (id, name) values ("5", "Michael")'
+ #row = self.execute(qry)
+ #print ('row: ', row)
+
+ '''
+ qry = 'select * from usera'
+ self.execute(qry)
+ data = self.fetchall()
+ print ('data: ', data)
+ '''
+
+
+if __name__ == "__main__":
+ d = Database()
+ d.connect()
+
+ d.test()
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/db/sqlite-bak/helper.py b/lib/sunhpc/db/sqlite-bak/helper.py
new file mode 100644
index 0000000..4e9e36c
--- /dev/null
+++ b/lib/sunhpc/db/sqlite-bak/helper.py
@@ -0,0 +1,367 @@
+#coding:utf-8
+
+import socket
+import sunhpc
+import sunhpc.core.utils
+import sqlalchemy
+import sunhpc.db.database
+from sqlalchemy import or_, and_
+from sunhpc.db.mappings.base import *
+
+class DatabaseHelper(sunhpc.db.database.Database):
+ """
+ This class extends the Database class with a set of helper methods
+ which deal with the new ORM interface (i.e. the objects from the
+ sunhpc.db.mappings classes).
+
+ These methods should replace the old methods in :mod:`sunhpc.commands`;
+ only methods related to the command line should remain in that
+ module, all DB functionality should be migrated here.
+ """
+
+ def __init__(self):
+ super(DatabaseHelper, self).__init__()
+
+ self._appliances_list = None
+ self._attribute = None
+ self._frontend = None
+ self._cacheAttrs = {}
+
+ def getFrontendName(self):
+ if self._frontend:
+ return self._frontend
+
+ self._frontend = self.getHostAttr('localhost', 'Kickstart_PrivateHostname')
+ return self._frontend
+
+ def getNodesfromNames(self, names=None, managed_only=0):
+
+ list = []
+ if not names:
+
+ list = self.getSession().query(Node)
+ list = list.all()
+
+ if managed_only:
+ managed_list = []
+ for hostname in list:
+ if self.getHostAttr(hostname,
+ 'managed') == 'true':
+ managed_list.append(hostname)
+ return managed_list
+ return list
+
+ clause = sqlalchemy.sql.expression.false()
+ query = self.getSession().query(Node)
+
+ for name in names:
+ if name.find('select ') == 0: # SQL select
+ self.execute(name)
+ nodes = [i for i, in self.fetchall()]
+ clause = or_(clause, Node.name.in_(nodes))
+ elif name.find('%') >= 0: # SQL % pattern
+ clause = or_(clause, Node.name.like(name))
+
+ else:
+ clause = or_(clause, Node.name == self.getHostname(name))
+
+ # now we register the query on the table Node and append all our clauses on OR
+ query = query.filter(clause)
+ return query.all()
+
+ def getHostIp(self, hostname):
+ hostname = self.getHostname(hostname)
+
+ rows = self.search('''select n.ip from networks n, subnets s
+ WHERE
+ n.subnet = s.id and n.name = "%s"
+ and s.name = "private"
+ ''' % hostname)
+ if rows:
+ return self.fetchone()[0]
+
+ def getHostname(self, hostname=None):
+ """涉及到数据库, 数据库中域名, /etc/hosts 和 DNS解析."""
+
+ arghostname = hostname
+ if hostname:
+ rows = self.search("select * from nodes where name='%s'" % hostname)
+ if rows:
+ return hostname
+
+ '''
+ try:
+ int(hostname)
+ rows = self.search("select name from nodes where id='%d'" % hostname)
+ if rows:
+ return self.fetchone()[0]
+ except ValueError:
+ return
+ '''
+
+ if not hostname:
+ hostname = socket.gethostname().split('.')[0]
+ try:
+ addr = socket.gethostbyname(hostname)
+ if not addr == hostname:
+ (name, aliases, addrs) = socket.gethostbyaddr(addr)
+ if hostname != name and hostname not in aliases:
+ raise NameError
+ except:
+ addr = None
+ if hostname == 'localhost':
+ addr = '127.0.0.1'
+
+ if not addr and self.conn:
+ self.execute("select name from nodes where name='%s'" % hostname)
+ if self.fetchone():
+ return hostname
+
+ # let's check if the name is nodeID
+ row = self.search("select name from nodes where nodes.id = '%s'" % hostname)
+ if row:
+ (hostname, ) = self.fetchone()
+ return hostname
+
+ # let's check if the name is an alias
+ row = self.search("select name from nodes where nodes.alias = '%s'" % hostname)
+ if row:
+ (hostname, ) = self.fetchone()
+ return hostname
+
+ # let's check if this is a mac address
+ self.execute("""select nodes.name from networks, nodes where
+ nodes.id = networks.node and
+ networks.mac = '%s' """ % hostname)
+ try:
+ hostname, = self.fetchone()
+ return hostname
+ except:
+ pass
+
+ # let's check if this is a FQDN
+ n = hostname.split('.')
+ if len(n) > 1:
+ name = n[0]
+ domain = '.'.join(n[1:])
+ query = """select n.name from nodes n, networks nt, subnets s where
+ nt.subnet=s.id and nt.node=n.id and
+ s.dnszone='{}' and (nt.name='{}' or n.name='{}')
+ """.format(domain, name, name)
+ self.execute(query)
+ try:
+ hostname, = self.fetchone()
+ return hostname
+ except:
+ pass
+
+ #
+ # We only get here when the hostname is not in the database.
+ # E.g. hostname: aaa, resolv: search lan, /etc/hosts: x.x.x.x aaa
+ # resolves to hostname: aaa
+ # If hostname is aaa.lan, this raises "cannot resolve host aaa.lan".
+ #
+ # E.g. hostname: aaa, resolv: search lan, /etc/hosts: x.x.x.x aaa.lan
+ # resolves to hostname: aaa.lan
+ #
+ # If the hostname is not in /etc/hosts at all, we raise immediately.
+ #
+
+ try:
+ fin = open('/etc/resolv.conf', 'r')
+ except:
+ fin = None
+ if fin:
+ # e.g. "search local hpc.org"
+ # domains = ['local', 'hpc.org']
+ domains = []
+ for line in fin.readlines():
+ tokens = line[:-1].split()
+ if len(tokens) > 0 and tokens[0] == 'search':
+ domains = tokens[1:]
+ for domain in domains:
+ try:
+ name = '%s.%s' % (hostname, domain)
+ addr = socket.gethostbyname(name)
+ hostname = name
+ break
+ except:
+ pass
+ # If the hostname matches an entry in /etc/hosts (the match must be
+ # exact, domain included), addr gets the corresponding IP address.
+ # At that point hostname carries the domain, e.g. aaa.lan, and we
+ # recurse into getHostname() with the fully qualified name.
+ if addr:
+ return self.getHostname(hostname)
+
+ fin.close()
+
+ raise (sunhpc.core.utils.HostnotfoundException(\
+ 'cannot resolve host "%s"' % hostname))
+
+ if addr == '127.0.0.1': # allow localhost to be valid
+ if arghostname == None:
+ return 'localhost'
+ else:
+ return self.getHostname()
+
+ if self.conn:
+ rows = self.search('select nodes.name from '
+ 'networks, nodes where '
+ 'nodes.id=networks.node and ip="%s"' % addr)
+ if not rows:
+ rows = self.search('select nodes.name '
+ 'from networks, nodes where '
+ 'nodes.id=networks.node and '
+ 'networks.name="%s"' % hostname)
+ if not rows:
+ raise (sunhpc.core.utils.HostnotfoundException(\
+ 'host "%s" is not in cluster' % hostname))
+ hostname, = self.fetchone()
+
+ return hostname
+
+ def getHostAttrs(self, hostname):
+
+ session = self.getSession()
+ if isinstance(hostname, str):
+ hostname = self.getHostname(hostname)
+ node = session.query(Node).filter(Node.name == hostname).one()
+
+ elif isinstance(hostname, Node):
+ node = hostname
+ hostname = node.name
+ else:
+ assert False, "hostname must be either a string with a hostname or a Node."
+
+ attrs = {}
+ attrs['hostname'] = hostname
+ attrs['rack'] = str(node.rack)
+ attrs['rank'] = str(node.rank)
+
+ query = """select attr,value from attributes, nodes where
+ nodes.id = attributes.node and nodes.name='%s'""" % hostname
+ for (attr, value) in self.conn.execute(query):
+ attrs[attr] = value
+
+ return attrs
+
+ def getHostAttr(self, hostname, attr):
+ return self.getHostAttrs(hostname).get(attr)
+
+ def checkHostnameValidity(self, hostname):
+ """
+ Check that the given host name is valid.
+ It checks that the hostname:
+ - does not contain any dots
+ - is not already in use
+ - is not in the form of rack<number>
+ - is not an alias
+ - is not a MAC address
+ """
+ if '.' in hostname:
+ raise (sunhpc.core.utils.CommandError('Hostname %s cannot contain any dots.'
+ % hostname))
+
+ msg = ''
+ if hostname.startswith('rack'):
+ number = hostname.split('rack')[1]
+ try:
+ int(number)
+ msg = ('Hostname %s cannot be in the form ' \
+ + 'of rack<number>.\n') % hostname
+ msg += 'Select a different hostname.'
+ except ValueError:
+ pass
+ if msg:
+ raise (sunhpc.core.utils.CommandError(msg))
+
+ try:
+ host = self.getHostname(hostname)
+ if host:
+ msg = 'Node "%s" already exists.\n' % hostname
+ msg += 'Select a different hostname, cabinet '
+ msg += 'and/or rank value.'
+ except (sunhpc.core.utils.HostnotfoundException, NameError):
+ # good, host does not exist.
+ return
+ raise sunhpc.core.utils.CommandError(msg)
+
+ def getNodeId(self, host):
+ """Lookup hostname in nodes table. Host may be a name
+ or an IP address. Returns None if not found."""
+ try:
+ return int(host)
+ except Exception:
+ pass
+
+ self.execute('select id from nodes where name="%s"' % host)
+ try:
+ nodeid, = self.fetchone()
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.ip="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ self.execute('select max(n.id) from networks net, \
+ nodes n where net.mac="%s" and net.node=n.id' % host)
+ try:
+ nodeid, = self.fetchone()
+ if nodeid:
+ return nodeid
+ except TypeError:
+ nodeid = None
+
+ return nodeid
+
+ def setNewHostAttr(self, hostname, attr, value):
+
+ session = self.getSession()
+ hostid = self.getNodeId(hostname)
+ try:
+ old_attr = Attribute.loadOne(session, attr=attr, node=hostid)
+ except sqlalchemy.orm.exc.NoResultFound:
+ new_attr = Attribute(attr=attr, value=value, node=hostid)
+ session.add(new_attr)
+ session.commit()
+ return
+
+ old_value = old_attr.value
+ old_attr.value = value
+
+ # If the attribute already exists in the database, re-save the old value under the attribute name with the "_old" suffix (attr_postfix) appended.
+ if not attr.endswith(attr_postfix):
+ self.setCategoryAttr(hostname, attr + attr_postfix, old_value)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/invoke.py b/lib/sunhpc/invoke.py
new file mode 100644
index 0000000..44bdb63
--- /dev/null
+++ b/lib/sunhpc/invoke.py
@@ -0,0 +1,10 @@
+#coding:utf-8
+import sunhpc
+import sunhpc.core
+import sunhpc.core.ip
+import sunhpc.core.sql
+import sunhpc.core.dist
+import sunhpc.core.files
+import sunhpc.core.build
+import sunhpc.core.security
+import sunhpc.core.partition
diff --git a/lib/sunhpc/modules/__init__.py b/lib/sunhpc/modules/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sunhpc/modules/__init__.py
diff --git a/lib/sunhpc/modules/compute/00_selinux.py b/lib/sunhpc/modules/compute/00_selinux.py
new file mode 100644
index 0000000..13f4558
--- /dev/null
+++ b/lib/sunhpc/modules/compute/00_selinux.py
@@ -0,0 +1,11 @@
+#coding:utf-8
+
+class Modules(object):
+ def __init__(self, command):
+ self.cmd = command
+
+ def run(self, args):
+ print ('this is compute selinux modules....')
+
+ def __repr__(self):
+ return "selinux"
diff --git a/lib/sunhpc/modules/compute/__init__.py b/lib/sunhpc/modules/compute/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sunhpc/modules/compute/__init__.py
diff --git a/lib/sunhpc/modules/control/00-base.py b/lib/sunhpc/modules/control/00-base.py
new file mode 100644
index 0000000..a539411
--- /dev/null
+++ b/lib/sunhpc/modules/control/00-base.py
@@ -0,0 +1,191 @@
+#coding:utf-8
+import time
+import sunhpc
+class Modules(object):
+ """
+ Configure base for the sunhpc cluster
+ """
+ def __init__(self, command):
+ self.cmd = command
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ content = []
+ content.append(self.__help__)
+
+ # hostname setting
+ hostname = self.cmd.db.getHostAttr('localhost', 'Kickstart_PublicHostname')
+ hostname_info = "hostnamectl --static set-hostname %s" % hostname
+ content.append(hostname_info)
+
+ # selinux config
+ selinux = '/etc/selinux/config'
+ selinux_info = """
+ #
+ # Generated by sunhpc selinux module.
+ #
+ # This file controls the state of SELinux on the system.
+ # SELINUX= can take one of these three values:
+ # enforcing - SELinux security policy is enforced.
+ # permissive - SELinux prints warnings instead of enforcing.
+ # disabled - No SELinux policy is loaded.
+ SELINUX=disabled
+ # SELINUXTYPE= can take one of three values:
+ # targeted - Targeted processes are protected,
+ # minimum - Modification of targeted policy. Only selected processes are protected.
+ # mls - Multi Level Security protection.
+ SELINUXTYPE=targeted
+ """
+ content.append(self.cmd.ks.makefile(selinux_info, selinux))
+
+ # hosts
+ hostfile = '/etc/hosts'
+ hostinfo = self.cmd.command('report.host').strip()
+ content.append(self.cmd.ks.makefile(hostinfo, hostfile))
+
+
+ # resolv
+ resolv = '/etc/resolv.conf'
+ search_lan = self.cmd.db.getHostAttr('localhost', 'Kickstart_PrivateDNSDomain')
+ search_wan = self.cmd.db.getHostAttr('localhost', 'Kickstart_PublicDNSDomain')
+ public_dns = self.cmd.db.getHostAttr('localhost', 'Kickstart_PublicDNSServer')
+ resolv_info = """
+ search %s %s
+ nameserver %s
+ nameserver 223.5.5.5
+ """ % (search_lan, search_wan, public_dns)
+ content.append(self.cmd.ks.makefile(resolv_info, resolv))
+
+ # motd
+ tnow = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())
+ motd = '/etc/motd'
+ motd_info = """
+ SunhpcOS %s (%s)
+ Profile built %s
+
+ _____ _
+ / ____| | |
+ ======= | (___ _ _ _ __ | |__ _ __ ___
+ \ // \___ \| | | | '_ \| '_ \| '_ \ / __|
+ \ // ___ ) | |_| | | | | | | | |_) | (__
+ \// |_____/ \__,_|_| |_|_| |_| .__/ \___|
+ | |
+ Powered by HengPU Technology |_| By kylins
+
+ PlatName : HengPu High Performance Computing Platform
+ WeChat : xiubuzhe
+ Version : %s (%s)
+ Email : info@sunhpc.com
+ Phone : +86 18640977650
+ Homepage : https://www.sunhpc.com - @kylins
+ -----------------------------------------------------
+
+
+ """ % (sunhpc.version, sunhpc.release, tnow, sunhpc.version, sunhpc.release)
+ content.append(self.cmd.ks.makefile(motd_info, motd))
+
+ # sunhpc-release
+ rels = '/etc/sunhpc-release'
+ rels_info = "SunhpcOS %s" % sunhpc.version
+ content.append(self.cmd.ks.makefile(rels_info, rels))
+
+ # sys-release
+ sysrels = '/etc/system-release'
+ sysrels_info = "SunhpcOS Linux release %s (%s)" % (sunhpc.version, sunhpc.release)
+ content.append(self.cmd.ks.makefile(sysrels_info, sysrels))
+
+ # os-release
+ osrels = '/etc/os-release'
+ osrels_info = """
+ NAME="SunHPC Linux"
+ VERSION="%s (%s)"
+ ID="sunhpc"
+ ID_LIKE="rhel fedora"
+ VERSION_ID="%s"
+ PRETTY_NAME="SunhpcOS Linux %s (%s)"
+ ANSI_COLOR="0;31"
+ CPE_NAME="cpe:/o:sunhpc:sunhpc:%s"
+ HOME_URL="https://www.sunhpc.com/"
+ BUG_REPORT_URL="https://bugs.sunhpc.com/"
+
+ CENTOS_MANTISBT_PROJECT="SunhpcOS-%s"
+ CENTOS_MANTISBT_PROJECT_VERSION="%s"
+ REDHAT_SUPPORT_PRODUCT="sunhpc"
+ REDHAT_SUPPORT_PRODUCT_VERSION="%s"
+ """ % (self.cmd.major, sunhpc.release,
+ self.cmd.major, self.cmd.major, sunhpc.release,
+ self.cmd.major, self.cmd.major, self.cmd.major,
+ self.cmd.major)
+ content.append(self.cmd.ks.makefile(osrels_info, osrels))
+
+ # issue
+ issue = '/etc/issue'
+ issue_net = '/etc/issue.net'
+ issue_info = """
+ \S
+ Kernel \r on an \m
+ Current times \d \t
+
+ _____ _
+ / ____| | |
+ ======= | (___ _ _ _ __ | |__ _ __ ___
+ \\ // \\___ \\| | | | '_ \\| '_ \\| '_ \\ / __|
+ \\ // ___ ) | |_| | | | | | | | |_) | (__
+ \\// |_____/ \\__,_|_| |_|_| |_| .__/ \\___|
+ | |
+ Powered by HengPU Technology |_| By kylins
+
+ PlatName : HengPu High Performance Computing Platform
+ WeChat : xiubuzhe
+ Version : %s (%s)
+ Email : info@sunhpc.com
+ Phone : +86 18640977650
+ Homepage : https://www.sunhpc.com - @kylins
+ -----------------------------------------------------
+
+ """ % (sunhpc.version, sunhpc.release)
+ content.append(self.cmd.ks.makefile(issue_info, issue))
+ content.append(self.cmd.ks.makefile(issue_info, issue_net))
+
+ self.cmd.ks.addMain(content)
+ def __repr__(self):
+ return "control-base"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/modules/control/05-securtiy.py b/lib/sunhpc/modules/control/05-securtiy.py
new file mode 100644
index 0000000..ab71cc4
--- /dev/null
+++ b/lib/sunhpc/modules/control/05-securtiy.py
@@ -0,0 +1,98 @@
+#coding:utf-8
+import time
+import sunhpc
+class Modules(object):
+ """
+ Configure security for the sunhpc cluster
+ """
+ def __init__(self, command):
+ self.cmd = command
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ content = []
+ content.append(self.__help__)
+
+ # create rsa keys
+ rsa_command = "sunhpc create security keyname=rsa"
+ content.append(rsa_command)
+
+ ssh_config = '/etc/ssh/ssh_config'
+ ssh_config_info = """
+ Host *
+ CheckHostIP no
+ ForwardX11 yes
+ ForwardAgent yes
+ StrictHostKeyChecking no
+ UsePrivilegedPort no
+ Protocol 2,1
+ HostbasedAuthentication yes
+ EnableSSHKeySign yes
+ """
+ content.append(self.cmd.ks.makefile(ssh_config_info, ssh_config))
+
+
+ sshd_config = '/etc/ssh/sshd_config'
+ sshd_config_info = """
+ Match User root
+ HostbasedAuthentication no
+ """
+ content.append(self.cmd.ks.makefile(sshd_config_info, sshd_config, mode="append"))
+
+ root_info = """
+ ROOTRSAKEY=/root/.ssh/id_rsa
+ if [ ! -f $ROOTRSAKEY ]; then
+ /usr/bin/ssh-keygen -q -t rsa -P '' -f $ROOTRSAKEY
+ fi
+
+ mkdir -p /etc/ssh/authorized_keys
+ AUTHKEYS=/etc/ssh/authorized_keys/id_rsa.pub
+ if [ ! -f $AUTHKEYS ]; then
+ /usr/bin/cat /root/.ssh/id_rsa.pub > $AUTHKEYS
+ fi
+ """
+ content.append(root_info)
+
+
+ self.cmd.ks.addMain(content)
+ def __repr__(self):
+ return "control-base"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/modules/control/50-packages.py b/lib/sunhpc/modules/control/50-packages.py
new file mode 100644
index 0000000..28bbf5c
--- /dev/null
+++ b/lib/sunhpc/modules/control/50-packages.py
@@ -0,0 +1,94 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Configure nodes packages.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ must_package = ['libffi-devel']
+
+
+ content = []
+ content.append(self.__help__)
+ #@^gnome-desktop-environment
+ parts = """
+ @^developer-workstation-environment
+ @additional-devel
+ @base
+ @compat-libraries
+ @core
+ @debugging
+ @desktop-debugging
+ @development
+ @dial-up
+ @directory-client
+ @fonts
+ @gnome-apps
+ @gnome-desktop
+ @guest-agents
+ @guest-desktop-agents
+ @hardware-monitoring
+ @identity-management-server
+ @infiniband
+ @input-methods
+ @internet-applications
+ @internet-browser
+ @java-platform
+ @large-systems
+ @multimedia
+ @network-file-system-client
+ @performance
+ @perl-runtime
+ @perl-web
+ @php
+ @platform-devel
+ @print-client
+ @python-web
+ @ruby-runtime
+ @system-admin-tools
+ @virtualization-client
+ @virtualization-hypervisor
+ @virtualization-tools
+ @web-server
+ @x11
+ """
+
+ mini = """
+ @^minimal
+ wget
+ net-tools
+ libffi-devel
+ """
+ content.append(mini)
+ self.cmd.ks.addPackages(content)
+
+ def __repr__(self):
+ return "kickstart-packages"
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/modules/control/__init__.py b/lib/sunhpc/modules/control/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sunhpc/modules/control/__init__.py
diff --git a/lib/sunhpc/modules/kickstart/00-base.py b/lib/sunhpc/modules/kickstart/00-base.py
new file mode 100644
index 0000000..e07100e
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/00-base.py
@@ -0,0 +1,79 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ base config module.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ addr = self.cmd.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ base = self.cmd.db.getHostAttr('localhost', 'Kickstart_BaseDir')
+ dist = self.cmd.db.getHostAttr('localhost', 'distribution')
+ http = "http://%s/%s/%s/%s" % (addr, base, dist, self.cmd.arch)
+ zone = self.cmd.db.getHostAttr('localhost', 'Kickstart_Timezone')
+
+ main = ['#', '# sunhpc cluster system for kickstart', '#']
+ main.append(self.__help__)
+
+ # System authorization information
+ main.append('auth --enableshadow --passalgo=sha512')
+
+ # Use URL installation
+ main.append('install')
+ main.append('url --url %s' % http)
+
+ # Use graphical install
+ main.append('graphical')
+
+ # Keyboard layouts
+ main.append("keyboard --vckeymap=us --xlayouts='us'")
+
+ # System language
+ main.append('lang en_US.UTF-8')
+
+ main.append('eula --agreed')
+
+ # Run the Setup Agent on first boot
+ main.append('firstboot --disable')
+
+ # System timezone
+ main.append('timezone %s --isUtc --nontp' % zone)
+
+ # X Window System configuration information
+ # Do not use the command below on systems without the X Window System installed.
+ # --startxonboot - log in through a graphical interface on the installed system
+ # main.append('xconfig --startxonboot')
+
+ # System services
+ main.append("services --disabled='chronyd'")
+
+ # Root password
+ passwd = '$6$6oVIU.5y$mAFN7begXFk5A1g0ODIQauIz6na5ja3AM0'
+ passwd += 'QebWc8dLVHmLguPv65nVQbpYR3.w1h6HdfbUkFfhJkv/KNO2Sj3/'
+ main.append('rootpw --iscrypted %s' % passwd)
+
+ # Reboot install finished
+ main.append('reboot')
+
+ # Disk partitioning information
+ main.append('%include /tmp/partition-info')
+
+ self.cmd.ks.addMain(main)
+
+ def __repr__(self):
+ return "kickstart-base"
diff --git a/lib/sunhpc/modules/kickstart/05-partition.py b/lib/sunhpc/modules/kickstart/05-partition.py
new file mode 100644
index 0000000..42ac552
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/05-partition.py
@@ -0,0 +1,42 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Client auto partition module.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ content = []
+ content.append(self.__help__)
+
+ parts = """
+ /bin/echo "sunhpc force-default" > /tmp/user_partition_info
+
+ if [ ! -f /tmp/do_partition.py ]
+ then
+ /bin/cat /applets/partition.py > /tmp/do_partition.py
+ fi
+
+ /bin/chmod 755 /tmp/do_partition.py
+ /tmp/do_partition.py
+ """
+ content.append(parts)
+ self.cmd.ks.addPre(content)
+
+ def __repr__(self):
+ return "kickstart-partition"
diff --git a/lib/sunhpc/modules/kickstart/10-hostauth.py b/lib/sunhpc/modules/kickstart/10-hostauth.py
new file mode 100644
index 0000000..4ffdbc4
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/10-hostauth.py
@@ -0,0 +1,62 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Configure sshd, authorized, rsa and so on.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ text = ''
+ # get root's public keys
+ public = '/etc/ssh/authorized_keys/id_rsa.pub'
+ try:
+ with open(public, 'r') as f:
+ for key in f.readlines():
+ if len(key) > 0:
+ text = key[:-1]
+ except:
+ pass
+
+ content = []
+ content.append(self.__help__)
+
+ parts = """
+ [ ! -e /root/.ssh ] && mkdir -p /root/.ssh
+ chmod -R 700 /root/.ssh
+
+ cat > /root/.ssh/authorized_keys << 'EOF'
+ %s
+ EOF
+ chmod 600 /root/.ssh/authorized_keys
+
+ ifconfig virbr0 down
+ brctl delbr virbr0
+ systemctl disable libvirtd.service
+
+ sed -i 's/#UseDNS yes/UseDNS no/' /etc/ssh/sshd_config
+
+ systemctl set-default multi-user.target
+
+ systemctl enable autofs
+ """ % (text)
+
+ content.append(parts)
+ self.cmd.ks.addPost(content)
+
+ def __repr__(self):
+ return "kickstart-secret"
diff --git a/lib/sunhpc/modules/kickstart/12-security.py b/lib/sunhpc/modules/kickstart/12-security.py
new file mode 100644
index 0000000..42ae555
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/12-security.py
@@ -0,0 +1,147 @@
+#coding:utf-8
+import time
+import sunhpc
+class Modules(object):
+ """
+ Configure sunhpc cluster security.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ content = [self.__help__]
+
+ # master.pub and shared.key
+ try:
+ nodeid = self.cmd.newdb.getNodeId(args)
+ except:
+ nodeid = None
+
+ if nodeid:
+ mkeys = '/etc/safe-security/master.pub'
+ skeys = '/etc/safe-security/shared.key'
+ with open(mkeys, 'r') as f: mtext = f.read()
+ with open(skeys, 'r') as f: stext = f.read()
+
+ content.append(
+ self.cmd.ks.makefile(
+ mtext, mkeys, perms="444", owner="root:root"))
+ content.append(
+ self.cmd.ks.makefile(
+ stext, skeys, perms="400", owner="root:root"))
+
+ # ssh config not autogen ssh keys
+ # check /usr/sbin/sshd-keygen
+ sshd_file = "/etc/sysconfig/sshd"
+ sshconfig = """
+ # Configuration file for the sshd service.
+
+ # The server keys are automatically generated if they are missing.
+ # To change the automatic creation uncomment and change the appropriate
+ # line. Accepted key types are: DSA RSA ECDSA ED25519.
+ # The default is "RSA ECDSA ED25519"
+
+ AUTOCREATE_SERVER_KEYS="NO"
+ # AUTOCREATE_SERVER_KEYS="RSA ECDSA ED25519"
+
+ # Do not change this option unless you have hardware random
+ # generator and you REALLY know what you are doing
+
+ SSH_USE_STRONG_RNG=0
+ # SSH_USE_STRONG_RNG=1
+ """
+ content.append(self.cmd.ks.makefile(sshconfig, sshd_file, perms="640", owner="root.root"))
+
+ # /etc/safe.conf
+ address = self.cmd.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ safeport = self.cmd.db.getHostAttr('localhost', 'safeport')
+ safedirs = self.cmd.db.getHostAttr('localhost', 'safedirs')
+ safe_file = '/etc/safe.conf'
+ safe_conf = """
+ <!-- Safe security configuration -->
+ <config>
+ <url>http://%s</url>
+ <port>%s</port>
+ <urldir>%s</urldir>
+ </config>
+ """ % (address, safeport, safedirs)
+ content.append(self.cmd.ks.makefile(safe_conf, safe_file, perms="644", owner="root.root"))
+
+ # safeGet
+ safe_info = """
+ if [ -f /opt/sunhpc/sbin/safeGet ];then
+ /opt/sunhpc/sbin/safeGet --all
+ else
+ echo "Error safeGet command not found." >> /tmp/sunhpc.log
+ fi
+
+ /usr/bin/chown root:ssh_keys /etc/ssh/ssh_host_ecdsa_key
+ /usr/bin/chown root:ssh_keys /etc/ssh/ssh_host_ed25519_key
+ /usr/bin/chown root:ssh_keys /etc/ssh/ssh_host_rsa_key
+ """
+ content.append(safe_info)
+
+ ssh_config = '/etc/ssh/ssh_config'
+ ssh_info = """
+ Host *
+ CheckHostIP no
+ ForwardX11 yes
+ ForwardAgent yes
+ StrictHostKeyChecking no
+ UsePrivilegedPort no
+ Protocol 2,1
+ HostbasedAuthentication yes
+ EnableSSHKeySign yes
+ """
+ content.append(self.cmd.ks.makefile(ssh_info, ssh_config, perms="644", owner="root.root"))
+
+ sshd_info = """
+ echo "MaxStartups 1280" >> /etc/ssh/sshd_config
+ echo "Match User root" >> /etc/ssh/sshd_config
+ echo " HostbasedAuthentication no" >> /etc/ssh/sshd_config
+ """
+ content.append(sshd_info)
+
+ # memory sysctl
+ mem_info = """
+ SHMSIZE=`gawk '/MemTotal:/ { printf("(%s/4) * (3 * 1024)\\n", $2); }' /proc/meminfo | bc`
+ if [ $SHMSIZE ]
+ then
+ echo "kernel.shmmax = " $SHMSIZE >> /etc/sysctl.conf
+ fi
+
+ """
+ content.append(mem_info)
+
+ self.cmd.ks.addPost(content)
+ def __repr__(self):
+ return "kickstart-security"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/modules/kickstart/15-network.py b/lib/sunhpc/modules/kickstart/15-network.py
new file mode 100644
index 0000000..22e3cca
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/15-network.py
@@ -0,0 +1,40 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Configure node networks.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ # network --bootproto=static
+ # --device=eth0
+ # --gateway=192.168.199.1
+ # --ip=192.168.199.155
+ # --nameserver=192.168.199.1
+ # --netmask=255.255.255.0
+ # --ipv6=auto --activate
+ # --hostname=cluster
+
+ network = []
+ network.append(self.__help__)
+ #network.append('network --bootproto=static --hostname=%s --activate' % args)
+ network.append('network --hostname=%s' % args)
+ self.cmd.ks.addMain(network)
+
+ def __repr__(self):
+ return "kickstart-network"
diff --git a/lib/sunhpc/modules/kickstart/20-scripts.py b/lib/sunhpc/modules/kickstart/20-scripts.py
new file mode 100644
index 0000000..ab86dac
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/20-scripts.py
@@ -0,0 +1,161 @@
+#coding:utf-8
+import time
+import sunhpc
+class Modules(object):
+ """
+ Configure yum repos, /etc/hosts.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ #
+ # self.cmd.ks.makefile
+ # args:
+ # text, type:list
+ # name, type:file path
+ # mode, owner, perms, expr, quot
+ #
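+ # e.g. (illustrative only, mirroring the calls below):
+ # self.cmd.ks.makefile(hostinfo, '/etc/hosts')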
+
+ contents = [self.__help__]
+ repofile = "/etc/yum.repos.d/sunhpc-local.repo"
+ with open(repofile, 'r') as fe:
+ repotext = fe.readlines()
+ contents.append(self.cmd.ks.makefile(repotext, repofile))
+
+ # hosts
+ hostfile = '/etc/hosts'
+ hostinfo = self.cmd.command('report.host').strip()
+ contents.append(self.cmd.ks.makefile(hostinfo, hostfile))
+
+
+ # resolv
+ resolv = '/etc/resolv.conf'
+ landns = "nameserver %s" % self.cmd.db.getHostAttr(
+ 'localhost', 'Kickstart_PrivateDNSServer')
+ contents.append(self.cmd.ks.makefile(landns, resolv))
+
+
+ # selinux
+ selinux = '/etc/selinux/config'
+ selinux_info = """
+ # This file controls the state of SELinux on the system.
+ # SELINUX= can take one of these three values:
+ # enforcing - SELinux security policy is enforced.
+ # permissive - SELinux prints warnings instead of enforcing.
+ # disabled - No SELinux policy is loaded.
+ SELINUX=disabled
+ # SELINUXTYPE= can take one of three values:
+ # targeted - Targeted processes are protected,
+ # minimum - Modification of targeted policy. Only selected processes are protected.
+ # mls - Multi Level Security protection.
+ SELINUXTYPE=targeted
+ """
+ contents.append(self.cmd.ks.makefile(selinux_info, selinux))
+
+ # motd
+ tnow = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())
+ motd = '/etc/motd'
+ motd_info = """
+ SunhpcOS %s (%s)
+ Profile built %s
+
+ """ % (sunhpc.version, sunhpc.release, tnow)
+ contents.append(self.cmd.ks.makefile(motd_info, motd))
+
+ # sunhpc-release
+ rels = '/etc/sunhpc-release'
+ rels_info = "SunhpcOS %s" % sunhpc.version
+ contents.append(self.cmd.ks.makefile(rels_info, rels))
+
+ # sys-release
+ sysrels = '/etc/system-release'
+ sysrels_info = "SunhpcOS Linux release %s (%s)" % (sunhpc.version, sunhpc.release)
+ contents.append(self.cmd.ks.makefile(sysrels_info, sysrels))
+
+ # os-release
+ osrels = '/etc/os-release'
+ osrels_info = """
+ NAME="SunHPC Linux"
+ VERSION="%s (%s)"
+ ID="sunhpc"
+ ID_LIKE="rhel fedora"
+ VERSION_ID="%s"
+ PRETTY_NAME="SunhpcOS Linux %s (%s)"
+ ANSI_COLOR="0;31"
+ CPE_NAME="cpe:/o:sunhpc:sunhpc:%s"
+ HOME_URL="https://www.sunhpc.com/"
+ BUG_REPORT_URL="https://bugs.sunhpc.com/"
+
+ CENTOS_MANTISBT_PROJECT="SunhpcOS-%s"
+ CENTOS_MANTISBT_PROJECT_VERSION="%s"
+ REDHAT_SUPPORT_PRODUCT="sunhpc"
+ REDHAT_SUPPORT_PRODUCT_VERSION="%s"
+ """ % (self.cmd.major, sunhpc.release,
+ self.cmd.major, self.cmd.major, sunhpc.release,
+ self.cmd.major, self.cmd.major, self.cmd.major,
+ self.cmd.major)
+ contents.append(self.cmd.ks.makefile(osrels_info, osrels))
+
+ # issue
+ issue = '/etc/issue'
+ issue_net = '/etc/issue.net'
+ issue_info = """
+ \S
+ Kernel \\r on an \m
+ Current times \d \\t
+
+ _____ _
+ / ____| | |
+ ======= | (___ _ _ _ __ | |__ _ __ ___
+ \\\ // \\\___ \\\| | | | '_ \\\| '_ \\\| '_ \\\ / __|
+ \\\ // ___ ) | |_| | | | | | | | |_) | (__
+ \\\// |_____/ \\\__,_|_| |_|_| |_| .__/ \\\___|
+ | |
+ Powered by HengPU Technology |_| By kylins
+
+ PlatName : HengPu High Performance Computing Platform
+ WeChat : xiubuzhe
+ Version : %s (%s)
+ Email : info@sunhpc.com
+ Phone : +86 18640977650
+ Homepage : https://www.sunhpc.com - @kylins
+ -----------------------------------------------------
+
+ """ % (sunhpc.version, sunhpc.release)
+ contents.append(self.cmd.ks.makefile(issue_info, issue))
+ contents.append(self.cmd.ks.makefile(issue_info, issue_net))
+
+
+ self.cmd.ks.addPost(contents)
+ def __repr__(self):
+ return "kickstart-auxiliary"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/modules/kickstart/30-services.py b/lib/sunhpc/modules/kickstart/30-services.py
new file mode 100644
index 0000000..7d56ca2
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/30-services.py
@@ -0,0 +1,59 @@
+#coding:utf-8
+import time
+import sunhpc
+class Modules(object):
+ """
+ Configure sunhpc compute node services.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ #
+ # self.cmd.ks.makefile
+ # args:
+ # text, type:list
+ # name, type:file path
+ # mode, owner, perms, expr, quot
+ #
+
+ content = [self.__help__]
+
+ # Autofs
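+ # Force autofs to mount over NFSv3 instead of the v4 default.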
+ autofs_info = """
+ sed -i 's/mount_nfs_default_protocol = 4/mount_nfs_default_protocol = 3/g' /etc/autofs.conf
+ """
+ content.append(autofs_info)
+
+ self.cmd.ks.addPost(content)
+ def __repr__(self):
+ return "kickstart-services"
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/modules/kickstart/50-packages.py b/lib/sunhpc/modules/kickstart/50-packages.py
new file mode 100644
index 0000000..020e0bc
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/50-packages.py
@@ -0,0 +1,98 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Configure nodes packages.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ content = []
+ content.append(self.__help__)
+ #@^gnome-desktop-environment
+ parts = """
+ @^developer-workstation-environment
+ @additional-devel
+ @base
+ @compat-libraries
+ @core
+ @debugging
+ @desktop-debugging
+ @development
+ @dial-up
+ @directory-client
+ @fonts
+ @gnome-apps
+ @gnome-desktop
+ @guest-agents
+ @guest-desktop-agents
+ @hardware-monitoring
+ @identity-management-server
+ @infiniband
+ @input-methods
+ @internet-applications
+ @internet-browser
+ @java-platform
+ @large-systems
+ @multimedia
+ @network-file-system-client
+ @performance
+ @perl-runtime
+ @perl-web
+ @php
+ @platform-devel
+ @print-client
+ @python-web
+ @ruby-runtime
+ @system-admin-tools
+ @virtualization-client
+ @virtualization-hypervisor
+ @virtualization-tools
+ @web-server
+ @x11
+ """
+
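+ # Only the minimal package set below is installed; the full group list
+ # above ('parts') is kept for reference and stays commented out.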
+ mini = """
+ @^minimal
+ wget
+ net-tools
+ libffi-devel
+ libffi-devel*.i686
+ nfs*
+ autofs
+ vim
+ sunhpc-client
+ sunhpc-python
+ """
+ content.append(mini)
+ #content.append(parts)
+ self.cmd.ks.addPackages(content, arg="--multilib --ignoremissing")
+
+ def __repr__(self):
+ return "kickstart-packages"
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/lib/sunhpc/modules/kickstart/60-addons.py b/lib/sunhpc/modules/kickstart/60-addons.py
new file mode 100644
index 0000000..9663b6b
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/60-addons.py
@@ -0,0 +1,30 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Configure node kdump.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ content = []
+ content.append(self.__help__)
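+ # Emit the %addon line that disables kdump on installed nodes.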
+ content.append("%addon com_redhat_kdump --disable --reserve-mb='auto'")
+ self.cmd.ks.addAddons(content)
+
+ def __repr__(self):
+ return "kickstart-addons"
diff --git a/lib/sunhpc/modules/kickstart/62-anaconda.py b/lib/sunhpc/modules/kickstart/62-anaconda.py
new file mode 100644
index 0000000..bd32ce3
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/62-anaconda.py
@@ -0,0 +1,32 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Configure node system pwpolicy.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ anaconda = []
+ anaconda.append(self.__help__)
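+ # pwpolicy directives for root, user and LUKS passwords, added to the
+ # kickstart via ks.addAnaconda.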
+ anaconda.append("pwpolicy root --minlen=6 --minquality=1 --notstrict --nochanges --notempty")
+ anaconda.append("pwpolicy user --minlen=6 --minquality=1 --notstrict --nochanges --emptyok")
+ anaconda.append("pwpolicy luks --minlen=6 --minquality=1 --notstrict --nochanges --notempty")
+ self.cmd.ks.addAnaconda(anaconda)
+
+ def __repr__(self):
+ return "kickstart-anaconda"
diff --git a/lib/sunhpc/modules/kickstart/64-pxeboot.py b/lib/sunhpc/modules/kickstart/64-pxeboot.py
new file mode 100644
index 0000000..7f74d7d
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/64-pxeboot.py
@@ -0,0 +1,39 @@
+#coding:utf-8
+import textwrap
+class Modules(object):
+ """
+ Configure node install pxeboot.
+ """
+ def __init__(self, command):
+ self.cmd = command
+
+ @property
+ def __help__(self):
+ info = """
+ #==============================================================
+ # %s
+ #
+ # module_path: %s
+ # module_name: %s
+ #==============================================================
+ """ % (self.__doc__.strip(), self.__module__.strip(), self.__repr__())
+ return info
+
+ def run(self, args):
+
+ content = []
+ content.append(self.__help__)
+ addr = self.cmd.db.getHostAttr('localhost', 'Kickstart_PrivateAddress')
+ base = self.cmd.db.getHostAttr('localhost', 'Kickstart_BaseDir')
+ http = "http://%s/%s/%s/%s" % (addr, base, 'sbin', 'setPxeboot.cgi')
+
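+ # Call the frontend's setPxeboot.cgi (output discarded) so the node's
+ # PXE boot configuration is updated after the install.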
+ setPxeboot = "wget --no-check-certificate -O /dev/null %s" % http
+ content.append(setPxeboot)
+
+ makegrub = "grub2-mkconfig > /boot/grub2/grub.cfg"
+ content.append(makegrub)
+
+ self.cmd.ks.addPost(content)
+
+ def __repr__(self):
+ return "kickstart-pxeboot"
diff --git a/lib/sunhpc/modules/kickstart/__init__.py b/lib/sunhpc/modules/kickstart/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sunhpc/modules/kickstart/__init__.py
diff --git a/sbin/build-sunhpc b/sbin/build-sunhpc
new file mode 100755
index 0000000..ebe0777
--- /dev/null
+++ b/sbin/build-sunhpc
@@ -0,0 +1,14 @@
+#!/bin/bash
+# first : initialize the sunhpc database
+# second : add cluster data to database.
+WHERE=$1
+HERE=`pwd`
+SUNHPC=/opt/sunhpc/bin/sunhpc
+
+# initialize database
+/opt/sunhpc/sbin/init-sunhpcDB
+
+# add frontend node to the database
+MYNAME=`hostname -s`
+$SUNHPC add host $MYNAME rack=0 rank=1 member=server 2> /tmp/sunhpc-initdb.log
+$SUNHPC add catindex $MYNAME category=host 2>> /tmp/sunhpc-initdb.log
diff --git a/sbin/calcrollmd5 b/sbin/calcrollmd5
new file mode 100755
index 0000000..d6bcfe1
--- /dev/null
+++ b/sbin/calcrollmd5
@@ -0,0 +1,49 @@
+#!/usr/bin/python
+#coding:utf-8
+
+import os
+import sys
+import hashlib
+
+def GetFileMD5(filename):
+ myhash = hashlib.md5()
+ with open(filename, 'rb') as f:
+ while True:
+ b = f.read(8096)
+ if not b:
+ break
+ myhash.update(b)
+ return myhash.hexdigest()
+
+def running(path):
+ rpmlist = []
+ for root, dirs,files in os.walk(path):
+ if root.endswith("RedHat"):
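+ # Roll trees are laid out as <roll>/<version>/<arch>/RedHat/RPMS;
+ # recover roll, version and arch from the path components.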
+ (roll, version, arch) = root.split(os.path.sep)[-4:-1]
+ rollname = "%s/%s/%s" % (roll, version, arch)
+ rollvers = "%s-%s-%s" % (roll, version, arch)
+
+ baseDirs = os.path.join(rollname, 'RedHat', 'RPMS')
+ for f in os.listdir(baseDirs):
+ rpmname = os.path.join(baseDirs, f)
+ md5sum = GetFileMD5(rpmname)
+ result = "%s %s %s" % (md5sum, rollvers, rpmname)
+ rpmlist.append(result)
+
+ return rpmlist
+
+if __name__ == "__main__":
+ args = sys.argv
+ if len(args) < 2:
+ print " - A path must be supplied."
+ sys.exit(0)
+
+ path = args[1]
+ if os.path.isdir(path) and path != '/':
+ path = os.path.abspath(path)
+ rpmlist = running(path)
+
+ rfile = os.path.join(path, 'rollslist')
+ with open(rfile, 'w') as f:
+ f.write('\n'.join(rpmlist))
+ f.write('\n')
diff --git a/sbin/gen_root_pw b/sbin/gen_root_pw
new file mode 100755
index 0000000..ec8c378
--- /dev/null
+++ b/sbin/gen_root_pw
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+
+import crypt
+import string
+import random
+def get_pw():
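+    # Hash a random float with crypt(); with no salt argument, crypt picks
+    # the strongest hashing method available on the system.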
+ pw = random.random()
+ return crypt.crypt(str(pw))
+if __name__ == '__main__':
+ print (get_pw())
diff --git a/sbin/insert-ethers b/sbin/insert-ethers
new file mode 100755
index 0000000..b15aad1
--- /dev/null
+++ b/sbin/insert-ethers
@@ -0,0 +1,742 @@
+#!/opt/sunpy3/bin/python3
+#coding:utf-8
+import os
+import sys
+import time
+import shutil
+import snack
+import getopt
+import syslog
+import sunhpc
+import signal
+import sqlite3
+import logging
+import sunhpc.invoke
+from sunhpc.core.utils import InsertError
+from sunhpc.core.utils import InsertDone
+from sunhpc.core.utils import DumpError
+
+logging.basicConfig(filename="/tmp/sunhpc.log", level=logging.INFO)
+log = logging.getLogger("Insertnodes ")
+
+try:
+ from rhpl.translate import _, N_
+ import rhpl.translate as translate
+ translate.textdomain('insert-ethers')
+except:
+ from gettext import gettext as _
+
+class ServiceController(object):
+ """Handler system services functions"""
+
+ def __init__(self):
+ self.services = {}
+ self.ignoreList = []
+
+ self.plugins = []
+ self.plugindir = os.path.abspath(
+ '/opt/sunhpc/var/plugins')
+
+ def ignore(self, service):
+ if service not in self.ignoreList:
+ self.ignoreList.append(service)
+
+ def isIgnored(self, service):
+ return service in self.ignoreList
+
+ def restart(self, service):
+ for name in self.services[service]:
+ if service not in self.ignoreList:
+ getattr(self, 'restart_%s' % name)()
+
+ def loadPlugins(self, app):
+ # Load plugins from the /opt/sunhpc/var/plugins directory.
+ if not os.path.exists(self.plugindir):
+ return
+
+ # Add the plugin directory to sys.path so its modules can be imported.
+ if self.plugindir not in sys.path:
+ sys.path.append(self.plugindir)
+
+ info = _("insert-ethers loading plugins: ")
+
+ # Only load the modules related to insertnodes.
+ modlist = os.listdir(self.plugindir + '/insertnodes')
+ modlist.sort()
+ for f in modlist:
+ modname, ext = os.path.splitext(f)
+ if modname == '__init__' or \
+ modname == '__pycache__' or ext != '.py':
+ continue
+
+ info += "%s " % modname
+ mods = __import__('insertnodes.%s' % modname)
+ mod = getattr(mods, modname)
+ try:
+ # Import the Plugin class from this module.
+ plugin_class = getattr(mod, 'Plugin')
+ # Pass the app object into the Plugin class;
+ # p is the plugin instance exposing its methods.
+ p = plugin_class(app)
+ self.plugins.append(p)
+ except:
+ info += _("(invalid, skipping) ")
+
+ # Write the plugin-loading summary to the system log.
+ log.info('Load KS Plugins: %s' % info)
+ syslog.syslog(info)
+
+ def logError(self, o=''):
+ "Logs the last execption to syslog"
+ oops = "%s threw exception '%s'" % (o, sys.exc_info())
+ syslog.syslog(oops)
+
+ def added(self, nodename):
+ """Tell all plugins this node has been added"""
+ for p in self.plugins:
+ try:
+ p.added(nodename)
+ except:
+ self.logError(p)
+
+ def removed(self, nodename):
+ """Tell all plugins this node has been removed"""
+ for p in self.plugins:
+ try:
+ p.removed(nodename)
+ except:
+ self.logError(p)
+
+ def done(self):
+ """Tell all plugins we are finished"""
+ for p in self.plugins:
+ try:
+ p.done()
+ except:
+ self.logError(p)
+
+ def update(self):
+ """Tell all plugins we to reload"""
+ for p in self.plugins:
+ try:
+ p.update()
+ except:
+ self.logError(p)
+
+class GUI(object):
+ """Use the snack gui class"""
+
+ def __init__(self):
+ self.screen = None
+
+ def startGUI(self):
+ self.screen = snack.SnackScreen()
+
+ def endGUI(self):
+ self.screen.finish()
+
+ def errorGUI(self, message, l1=_("Quit"), l2=None):
+ return self.modalGUI(str(message), _("Error"), l1, l2)
+
+ def warningGUI(self, message, l1=_("OK"), l2=None):
+ return self.modalGUI(str(message), _("Warning"), l1, l2)
+
+ def infoGUI(self, message, l1=_("OK"), l2=None):
+ return self.modalGUI(str(message), _("Information"), l1, l2)
+
+ def modalGUI(self, message, title, l1, l2):
+ form = snack.GridForm(self.screen, title, 2, 2)
+
+ textbox = snack.TextboxReflowed(40, message)
+ form.add(textbox, 0, 0)
+ if not l2:
+ b1 = snack.Button(l1)
+ form.add(b1, 0, 1)
+ else:
+ b1 = snack.Button(l1)
+ b2 = snack.Button(l2)
+ form.add(b1, 0, 1)
+ form.add(b2, 1, 1)
+
+ if form.runOnce() == b1:
+ return 0
+ else:
+ return 1
+
+
+class InsertEthers(GUI):
+
+ def __init__(self, app):
+ super(InsertEthers, self).__init__()
+
+ self.sql = app
+ self.cmd = None
+ self.controller = ServiceController()
+ self.cabinet = 0
+ self.rank = -1
+ self.replace = ''
+ self.maxNew = -1
+ self.remove = 0
+ self.membership = None
+ self.basename = None
+ self.restart_srvs = 0
+ self.inserted = []
+ self.kickstarted = {}
+ self.excludeMacList = []
+ self.dist_lockFile = '/var/lock/sunhpc-dist'
+ self.osname = 'linux'
+
+
+ self.doRestart = 1
+ # MAC addresses to skip are collected in self.excludeMacList (above).
+ self.subnet = 'private' # Internal Network
+ self.hostname = None
+ self.kickstartable = True
+
+ def setMembershipName(self, membership_name):
+ self.membership = membership_name
+
+ def setRemove(self, host):
+ self.replace = host
+ self.remove = 1
+
+ def startGUI(self):
+
+ GUI.startGUI(self)
+ self.form = snack.GridForm(self.screen, _("Install the system using pxelinux"), 1, 1)
+
+ self.textbox = snack.Textbox(50, 4, "", scroll=1)
+ self.form.add(self.textbox, 0, 0)
+
+ self.screen.drawRootText(0, 0, _("SunHPC(%s) -- version %s") %
+ (self.sql.usage_name,
+ self.sql.usage_version))
+
+ self.screen.drawRootText(0, 1, _("Opened kickstart access to %s network") %
+ self.sql.getPrivateNet())
+
+ self.screen.pushHelpLine(' ')
+
+ def statusGUI(self):
+ """Updates the list of nodes in 'Inserted Appliances' windows"""
+ macs_n_names = ''
+ ks = ''
+ for (mac, name) in self.inserted:
+ if name not in self.kickstarted:
+ ks = ''
+ elif self.kickstarted[name] == 0:
+ ks = '( )'
+ elif self.kickstarted[name] == 200:
+ ks = '(*)'
+ else: # An error
+ ks = '(%s)' % self.kickstarted[name]
+ macs_n_names += '%s\t%s\t%s\n' % (mac, name, ks)
+
+ self.textbox.setText(_(macs_n_names))
+
+ self.form.draw()
+ self.screen.refresh()
+
+ def waitGUI(self):
+
+ not_done = ''
+ hosts = list(self.kickstarted.keys())
+ hosts.sort()
+ for name in hosts:
+ status = self.kickstarted[name]
+ if status != 200:
+ ks = '( )'
+ if status:
+ ks = '(%s)' % status
+ not_done += '%s \t %s\n' % (name, ks)
+
+ form = snack.GridForm(self.screen,
+ _("Not kickstarted, please wait..."), 1, 1)
+ textbox = snack.Textbox(35, 4, not_done, scroll=1)
+ form.add(textbox, 0,0)
+
+ form.draw()
+ self.screen.refresh()
+ time.sleep(1)
+ self.screen.popWindow()
+
+ def membershipGUI(self):
+ self.kickstartable = True
+ self.basename = 'compute'
+ self.setMembershipName(self.basename)
+
+ def initializeRank(self):
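+ # Start at one past the highest rank already used in this rack,
+ # or at 0 if the rack is empty.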
+ query = 'select rank,max(rank) from nodes where rack = %d group by rack' % (self.cabinet)
+
+ if self.sql.search(query) > 0:
+ (rank, max_rank) = self.sql.fetchone()
+ self.rank = max_rank + 1
+ else:
+ self.rank = 0
+
+ def getnextIP(self, subnet):
+
+ args = [ subnet ]
+ if self.sql.ipIncrement != -1:
+ args.append('increment=%d' % self.sql.ipIncrement)
+
+ text = self.cmd.command('report.nextip', args)
+ if len(text) == 0:
+ raise Exception("Unable to get next IP address")
+
+ return text.strip()
+
+ def addit(self, mac, nodename, ip):
+
+ self.cmd.command('add.host', [nodename, 'os=' + self.osname,
+ 'rack=' + str(self.cabinet), 'rank=' + str(self.rank)])
+
+ self.cmd.command('add.host.interface', [nodename, 'eth0',
+ 'ip=' + ip, 'mac=' + mac, 'subnet=' + self.subnet])
+
+ self.sql.commit()
+
+ self.controller.added(nodename)
+ self.restart_srvs = 1
+
+ self.sql.commit()
+
+ nodes = [(mac, nodename)]
+ nodes.extend(self.inserted)
+ self.inserted = nodes
+ self.kickstarted[nodename] = 0
+
+ def discover(self, mac, dev):
+ """如果存在数据库中返回真"""
+ retval = False
+ query = 'select mac from networks where mac="%s"' % (mac)
+ if not self.sql.search(query):
+ nodename = self.getNodename()
+ log.info('GetNodename: %s' % nodename)
+
+ ipaddr = self.getnextIP(self.subnet)
+ self.addit(mac, nodename, ipaddr)
+ log.info('Addit Host: %s/%s/%s' % (nodename, ipaddr, mac))
+ self.printDiscovered(mac)
+
+ retval = True
+ return retval
+
+ def printDiscovered(self, mac):
+
+ form = snack.GridForm(self.screen,
+ _("Discovered New Appliance"), 1, 1)
+
+ new_app = _("Discovered a new appliance with MAC (%s)") % (mac)
+ textbox = snack.Textbox(len(new_app), 1, new_app)
+ form.add(textbox, 0, 0)
+
+ form.draw()
+ self.screen.refresh()
+ time.sleep(2)
+ self.screen.popWindow()
+
+ def getNodename(self):
+ if self.hostname is not None:
+ return self.hostname
+ else:
+ return '%s-%d-%d' % (self.basename, self.cabinet, self.rank)
+
+ def listenDHCP(self, line):
+
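+ # Expected syslog line (token indices used below):
+ # "<month> <day> <time> <host> dhcpd: DHCPDISCOVER from <mac> via <iface> ..."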
+ tokens = line.split()[:-1]
+ if len(tokens) > 9 and tokens[4] == 'dhcpd:' and \
+ (tokens[5] in ['DHCPDISCOVER', 'BOOTREQUEST']):
+
+ Dev = tokens[9].replace(':','').strip()
+ Mac = tokens[7].strip()
+
+ # In "DHCPDISCOVER from <mac> via eth0", the eth0 refers to the interface
+ # on the frontend where dhcpd listens (the private network interface);
+ # it is not the compute node's interface name.
+ self.sql.execute("""select networks.device from
+ networks, subnets, nodes where
+ subnets.name='%s' and nodes.name='%s' and
+ networks.subnet=subnets.id and networks.node=nodes.id""" % (
+ self.subnet, self.sql.newdb.getFrontendName()))
+
+ # Skip MAC addresses that were explicitly excluded.
+ if Mac in self.excludeMacList: return
+
+ # Ignore requests that did not arrive on the frontend's DHCP interface.
+ subnet_dev = self.sql.fetchone()[0]
+ if Dev != subnet_dev: return
+
+ # If this MAC address is already in the database, drop the request.
+ if not self.discover(Mac, Dev): return
+
+ log.info('Discover New MAC: %s' % Mac)
+ self.statusGUI()
+
+ if self.maxNew > 0:
+ self.maxNew -= 1
+ if self.maxNew == 0:
+ raise InsertDone(_("Suggest Done"))
+
+ # Auto-increment the rank used for the next hostname.
+ self.rank = self.rank + 1
+
+ elif len(tokens) > 6 and tokens[4] == 'last' and \
+ tokens[5] == 'message' and tokens[6] == 'repeated':
+
+ shortname = os.uname()[1].split('.')[0]
+ if tokens[3] == shortname:
+ os.system('/usr/bin/systemctl restart syslog >/dev/null 2>&1')
+
+ def monitoring(self):
+ # Tail the system log and the httpd access log; seek to the end so only new entries are read.
+ mslog = open('/var/log/messages', 'r')
+ mslog.seek(0, 2)
+
+ kslog = open('/var/log/httpd/access_log', 'r')
+ kslog.seek(0, 2)
+
+ self.screen.pushHelpLine(
+ _(" Press <F8> to quit, press <F9> to force quit"))
+ self.form.addHotKey('F8')
+ self.form.addHotKey('F9')
+ self.form.setTimer(1000)
+
+ self.statusGUI()
+
+ result = self.form.run()
+ suggest_done = 0
+ done = 0
+ log.info('Monitoring Log: OK')
+ while not done:
+
+ # Watch /var/log/messages for dhcpd activity.
+ syslog_line = mslog.readline()
+ if syslog_line and not suggest_done:
+ try:
+ self.listenDHCP(syslog_line)
+ except InsertDone:
+ suggest_done = 1
+
+ except (sunhpc.core.utils.CommandError, InsertError) as msg:
+ self.warningGUI(msg)
+ continue
+
+ # Watch the httpd access log for kickstart (pxelinux) requests.
+ access_line = kslog.readline()
+ if access_line:
+ try:
+ self.listenKS(access_line)
+ except InsertError as msg:
+ self.warningGUI(msg)
+ continue
+ #
+ result = self.form.run()
+ done = self.checkDone(result, suggest_done)
+
+ log.info('Restarting services status: %s' % self.restart_srvs)
+ if self.restart_srvs:
+ log.info('Start restart services ...')
+ form = snack.GridForm(self.screen, _("Restarting Services"), 1, 1)
+ message = _("Restarting Services...")
+ textbox = snack.Textbox(len(message), 1, message)
+ form.add(textbox, 0, 0)
+ form.draw()
+
+ self.screen.refresh()
+ self.controller.done()
+ self.screen.popWindow()
+
+ mslog.close()
+ self.endGUI()
+
+ def listenKS(self, line):
+ """Look in log line for a kickstart request."""
+
+ # Track accesses both with and without local certs.
+ interesting = line.count('install/sbin/kickstart.cgi') \
+ or line.count('install/sbin/public/kickstart.cgi') \
+ or line.count('install/sbin/public/jumpstart.cgi')
+ if not interesting:
+ return
+
+ fields = line.split()
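+ # In the Apache common/combined log format, field 8 is the HTTP status code.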
+ try:
+ status = int(fields[8])
+ log.info('Kickstart Code: %s' % status)
+ except:
+ raise InsertError(_("Apache log file not well formed!"))
+
+ nodeid = int(self.sql.getNodeId(fields[0]))
+ self.sql.execute('select name from nodes where id=%d' % nodeid)
+ try:
+ name, = self.sql.fetchone()
+ except:
+ if status == 200:
+ raise InsertError( _("Unknown node %s got a kickstart file!") % fields[0])
+ return
+
+ log.info('Kickstart NodeID %s->%s' % (name, nodeid))
+ if name not in self.kickstarted:
+ return
+
+ log.info('Change KS Status %s->%s' % (name, status))
+ self.kickstarted[name] = status
+ self.statusGUI()
+
+ def checkDone(self, result, suggest_done):
+
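+ # F9 quits immediately; F8 quits only once every discovered node has
+ # kickstarted (HTTP 200), otherwise it shows the "please wait" screen.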
+ if result == 'TIMER' and not suggest_done:
+ return 0
+
+ if result == 'F9': return 1
+
+ if not self.kickstartable: return 1
+
+ ok = 1
+ for status in self.kickstarted.values():
+ if status != 200:
+ ok = 0
+ break
+
+ if not ok:
+ if result == 'F8':
+ self.waitGUI()
+ else:
+ if suggest_done or result == 'F8':
+ return 1
+ return 0
+
+ def distDone(self):
+ if os.path.exists(self.dist_lockFile):
+ self.warningGUI(_("Sunhpc distribution is not ready\n\n")
+ + _("Please wait for 'sunhpc create distro' to complete\n"))
+ return 0
+ return 1
+
+ def run(self):
+
+ self.cmd = sunhpc.commands.Command(self.sql.newdb)
+ try:
+ self.cmd.command('check.services', [])
+ log.info('Check services: OK')
+ except sunhpc.core.utils.CommandError as err:
+ sys.stderr.write('error - ' + str(err) + '\n')
+ return
+
+ # Start the text-mode GUI.
+ self.startGUI()
+ log.info('Start Daemon GUI: OK')
+
+ # Make sure 'sunhpc create distro' has finished building.
+ if not self.distDone():
+ self.endGUI()
+ return
+
+ self.controller.loadPlugins(self.sql)
+ try:
+ if self.remove:
+ self.endGUI()
+ self.controller.done()
+ print ('Removed node %s' % self.replace)
+ return
+
+ # Initialize the membership information.
+ self.membershipGUI()
+ # Initialize the rank counter.
+ self.initializeRank()
+
+ if self.hostname:
+ # Check that the supplied hostname is valid.
+ self.checkHostNameValidity(self.hostname)
+
+ except (sunhpc.core.utils.CommandError, InsertError) as msg:
+ self.errorGUI(msg)
+ self.endGUI()
+ sys.stderr.write(_("%s\n") % str(msg))
+ return
+
+ log.info('Start Monitoring ...')
+ self.monitoring()
+
+class App(sunhpc.core.sql.Application):
+
+ def __init__(self, argv=None):
+ sunhpc.core.sql.Application.__init__(self, argv)
+
+ if not argv:
+ argv = sys.argv
+
+ self.args = []
+ self.caller_args = argv[1:]
+ self.usage_name = 'Kamaitachi'
+ self.usage_version = '1.0.0'
+ self.usage_command = os.path.basename(argv[0])
+
+ self.getopt = sunhpc.core.utils.Struct()
+ # Short options
+ self.getopt.s = []
+ # Long options
+ self.getopt.l = [ ('help', 'display the command help information'),
+ ('version', 'Display the sunhpc version')
+ ]
+
+ try:
+ # unset our locale
+ del os.environ['LANG']
+ except KeyError:
+ pass
+
+ self.dist = None
+ self.doUpdate = 0
+ self.lockFile = '/var/lock/insert-ethers'
+ self.insertor = InsertEthers(self)
+ self.controller = ServiceController()
+ self.ipIncrement = -1
+ self.doPublicMode = 0
+
+ self.getopt.l.extend([
+ ('remove=', 'remove an hostname')
+ ])
+
+ def getArgs(self):
+ return self.args
+
+ def setArgs(self, list):
+ self.args = list
+
+ def getPrivateNet(self):
+ net = self.getHostAttr('localhost', 'Kickstart_PrivateNetwork')
+ mask = self.getHostAttr('localhost', 'Kickstart_PrivateNetmask')
+ return "%s/%s" % (net, mask)
+
+ def parseArgs(self, rcbase=None):
+ """解析参数"""
+
+ args = self.getArgs()
+
+ # Store the caller's arguments.
+ self.setArgs(self.caller_args)
+
+ # Parse them.
+ self.parseCommandLine()
+
+ def parseCommandLine(self):
+
+ # Parse the command line with the getopt module.
+
+ # Build the short-option string.
+ short = ''
+ for e in self.getopt.s:
+ if type(e) == type(()):
+ # Take the option name (left element).
+ short = short + e[0]
+ else:
+ short = short + e
+
+ # Build the long-option list.
+ long = []
+ for e in self.getopt.l:
+ if type(e) == type(()):
+ # Take the option name (left element).
+ long.append(e[0])
+ else:
+ long.append(e)
+
+ try:
+ opts, args = getopt.getopt(self.args, short, long)
+ except getopt.GetoptError as msg:
+ sys.stderr.write('error - %s\n' % msg)
+ self.usage()
+ sys.exit(1)
+
+ for c in opts:
+ self.parseArg(c)
+
+ def parseArg(self, c):
+
+ if c[0] in ('-h', '--help'):
+ self.usage()
+ sys.exit(0)
+
+ elif c[0] == '--version':
+ print (self.getClusterVersion())
+ elif c[0] == '--remove':
+ self.insertor.setRemove(c[1])
+ return 0
+
+ def getClusterVersion(self):
+ return "SunhpcOS (%s) for version - %s" % (
+ self.cmd.release, self.cmd.version)
+
+ def usage(self):
+ argDict = {}
+ for e in self.getopt.s:
+ if type(e) == type(()):
+ argDict['-%s' % e[0]] = e[1]
+ else:
+ argDict['-%s' % e] = ''
+
+ for l in self.getopt.l:
+ if type(l) == type(()):
+ argDict['--%s' % l[0]] = l[1]
+ else:
+ argDict['--%s' % l] = ''
+
+ if not argDict: return
+ maxlen = max(map(len, argDict.keys()))
+ print ('\nUsage: %s [options] command information' % self.usage_command)
+ for k in argDict:
+ keys = k.ljust(maxlen)
+ vals = argDict[k]
+ print (' %s\t%s' % (keys, vals))
+ print ('If you have any questions, please contact info@sunhpc.com')
+
+ def run(self):
+ self.connect()
+
+ if os.path.isfile(self.lockFile):
+ self.cmd.abort('lock file %s exists.' % self.lockFile)
+ else:
+ os.system('touch %s' % self.lockFile)
+
+ if self.doPublicMode:
+ self.insertor.runPublicOnly()
+ else:
+ self.insertor.run()
+
+ self.cleanup()
+
+ def cleanup(self):
+ try:
+ os.unlink(self.lockFile)
+ except:
+ pass
+
+if __name__ == "__main__":
+
+ try:
+ (width, height) = shutil.get_terminal_size()
+ except:
+ width = 80
+ os.environ['COLUMNS'] = str(width)
+ log.info('starting insert-node ...')
+
+ app = App(sys.argv)
+ app.parseArgs()
+ try:
+ app.run()
+ except Exception as msg:
+ app.cleanup()
+ if app.insertor and app.insertor.screen:
+ app.insertor.endGUI()
+ sys.stderr.write('error - ' + str(msg) + '\n')
+ import traceback
+ traceback.print_exc()
+ sys.exit(1)
+
+ finally:
+ if os.path.exists(app.lockFile):
+ os.unlink(app.lockFile)
+
+
+
diff --git a/sbin/kgen b/sbin/kgen
new file mode 100755
index 0000000..fa01d54
--- /dev/null
+++ b/sbin/kgen
@@ -0,0 +1,157 @@
+#!/opt/sunpy3/bin/python3
+#coding:utf-8
+import os,sys
+import getopt
+import sunhpc.invoke
+from xml.sax._exceptions import SAXParseException
+class App(sunhpc.core.database.ApplicationSQL):
+
+ def __init__(self, argv=None):
+ sunhpc.core.database.ApplicationSQL.__init__(self)
+
+ if not argv:
+ argv = sys.argv
+
+ self.args = []
+ self.caller_args = argv[1:]
+ self.usage_name = 'Kickstart Generator'
+ self.usage_version = '1.0'
+ self.usage_command = os.path.basename(argv[0])
+ self.sections = []
+
+ self.os = os.uname()[0].lower()
+ self.arch = os.uname()[4]
+ osGenerator = getattr(sunhpc.core.xmlgen, 'Generator_%s' % self.os)
+ self.generator = osGenerator()
+ self.generator.setArch(self.arch)
+ self.generator.setOS(self.os)
+
+ self.getopt = sunhpc.core.utils.Struct()
+ self.getopt.s = [('h', 'help information'), ('a', 'architecture')]
+ self.getopt.l = [('arch=', 'architecture'),
+ ('section=', 'name'),
+ ('postonly', 'show post'),
+ ]
+ def usage(self):
+ argDict = {}
+ for e in self.getopt.s:
+ if type(e) == type(()):
+ argDict['-%s' % e[0]] = e[1]
+ else:
+ argDict['-%s' % e] = ''
+
+ for l in self.getopt.l:
+ if type(l) == type(()):
+ argDict['--%s' % l[0]] = l[1]
+ else:
+ argDict['--%s' % l] = ''
+
+ if not argDict: return
+ maxlen = max(map(len, argDict.keys()))
+ print ('\nUsage: %s [options] command information' % self.usage_command)
+ for k in argDict:
+ keys = k.ljust(maxlen)
+ vals = argDict[k]
+ print (' %s\t%s' % (keys, vals))
+ print ('If you have any questions, please contact info@sunhpc.com')
+
+ def parseArg(self, c):
+ if c[0] in ('-h', '--help'):
+ self.usage()
+ sys.exit(-1)
+ elif c[0] in ('-a', '--arch'):
+ self.generator.setArch(c[1])
+ elif c[0] == '--section':
+ self.sections += c[1].split()
+ elif c[0] == '--postonly':
+ self.sections.append('post')
+ else:
+ return 0
+ return 1
+
+ def parseArgs(self):
+ self.parseCommandLine()
+
+ def parseCommandLine(self):
+ # Build the short-option string.
+ short = ''
+ for e in self.getopt.s:
+ if type(e) == type(()):
+ # Take the option name (left element).
+ short = short + e[0]
+ else:
+ short = short + e
+
+ # Build the long-option list.
+ long = []
+ for e in self.getopt.l:
+ if type(e) == type(()):
+ # Take the option name (left element).
+ long.append(e[0])
+ else:
+ long.append(e)
+ try:
+ opts, args = getopt.getopt(self.caller_args, short, long)
+ except getopt.GetoptError as msg:
+ sys.stderr.write('error - %s\n' % msg)
+ self.usage()
+ sys.exit(1)
+
+ for c in opts:
+ self.parseArg(c)
+
+ def run(self):
+
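+        # Read the kickstart XML from the file argument (or stdin),
+        # then print the requested sections.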
+ if self.args:
+ fe = open(self.args[0], 'r')
+ else:
+ fe = sys.stdin
+
+ self.generator.parse(fe.read())
+ print ('#')
+ print ('# %s version %s' % (self.usage_name, self.usage_version))
+ print ('#')
+
+ sections = self.sections
+ if not sections:
+ sections = ['order', 'debug', 'main', 'packages', 'pre', 'post']
+
+ plist = []
+ for s in sections:
+ plist += self.generator.generate(s)
+
+ for line in plist:
+ print (line.rstrip())
+
+if __name__ == "__main__":
+ app = App(sys.argv)
+ app.parseArgs()
+ try:
+ app.run()
+ except sunhpc.core.exceptions.KickstartError as msg:
+ sys.stderr.write("kgen error - %s\n" % msg)
+ sys.exit(-1)
+
+ except SAXParseException as msg:
+ sys.stderr.write("kgen XML parse exception: %s\n" % msg)
+ sys.exit(-1)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/sbin/mksquashfs b/sbin/mksquashfs
new file mode 100755
index 0000000..109ea23
--- /dev/null
+++ b/sbin/mksquashfs
Binary files differ
diff --git a/sbin/mom_gencfg b/sbin/mom_gencfg
new file mode 100755
index 0000000..f676b59
--- /dev/null
+++ b/sbin/mom_gencfg
@@ -0,0 +1,559 @@
+#!/usr/bin/perl
+# *****************************************************************************
+#
+# Copyright 2011 Zuse Institute Berlin
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+#
+# Please send comments to kallies@zib.de
+#
+# *****************************************************************************
+# Purpose: - called from /etc/init.d/pbs_mom during start actions.
+# - creates /var/spool/torque/mom_priv/mom.layout
+# - creates/modifies /dev/cpuset/torque
+# Prereq: - hwloc >= 1.1, http://www.open-mpi.org/projects/hwloc/
+# - Sys::Hwloc >= 0.09, http://search.cpan.org/~bka/
+# Install: Install this script on each UV rack
+# /opt/torque/Scripts/mom_gencfg root:root -rwxr-xr-x
+# Config: Set MOM_GENCFG=/opt/torque/Scripts/mom_gencfg
+# in /etc/init.d/pbs_mom for UV, execute $MOM_GENCFG before
+# starting the pbs_mom daemon.
+# MOM_GENCFG can be overridden in /etc/sysconfig/pbs_mom.
+# *****************************************************************************
+# $Id: mom_gencfg,v 1.1.2.1 2011/01/17 10:12:46 acountin Exp $
+# *****************************************************************************
+
+#
+# *** Instructions for use ***
+#
+# 1. Install hwloc - see contrib/hwloc_install.sh. This should already be done since
+# TORQUE needs hwloc for its cpuset implementation starting in 4.0
+# 2. Install Sys::Hwloc from CPAN
+# 3. Set $PBS_HOME to the proper value if not already set
+# 4. Update the variables in the section 'Config Definitions' Especially update firstNodeId
+# and nodesPerBoard if desired.
+# firstNodeId should be set above 0 if you have a root cpuset that you wish to exclude
+# nodesPerBoard is the number of numa nodes per board. Each node is defined in the
+# directory /sys/devices/system/node, in a subdirectory node<node index>
+# 5. Backup your current file, just in case a variable is set incorrectly or neglected
+# 6. Run this script and enjoy the layout file
+#
+#
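+# Example invocation (a sketch; paths assume a stock torque install):
+#   export PBS_HOME=/var/spool/torque
+#   ./mom_gencfg -dry-run -verbose    # preview what would be written
+#   ./mom_gencfg                      # write mom.layout and set up /dev/cpuset/torque
+#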
+
+
+use strict;
+
+use lib qw(
+ /usr/lib/perl5
+ /usr/lib/perl5/site_perl
+ );
+
+use Sys::Hostname;
+use File::Basename;
+use Getopt::Long qw(:config no_ignore_case);
+use autouse 'Pod::Usage' => qw(pod2usage);
+use Sys::Hwloc 0.09;
+
+my $progName = basename($0);
+my $hostName = hostname();
+
+$SIG{__DIE__} = \&xDie;
+
+# ==============================================================================
+# Setup needed before init
+# ==============================================================================
+
+BEGIN: {
+ die "This script needs at least hwloc-1.1\n" unless HWLOC_XSAPI_VERSION() >= 0x00010100;
+}
+
+# ==============================================================================
+# Config definitions
+# ==============================================================================
+
+my $hostNames = undef; # hostname pattern to be run on, undef to skip test
+my $cpusetFsName = '/dev/cpuset'; # the name of the cpuset file system
+my $cpusetBaseName = '/torque'; # the name of the parent cpuset of a job's cpuset
+my $mkdirCmd = '/bin/mkdir'; # the path to the mkdir command
+my $catCmd = '/bin/cat'; # the path to the cat command
+my $echoCmd = '/bin/echo'; # the path to the echo command
+my $momCfgDir = 'mom_priv'; # the directory where MOM configs are stored
+my $momLayoutFile = 'mom.layout'; # the name of the MOM layout file
+my $firstNodeId = 0; # ID of 1st NUMA node to be used by Torque (start with 0)
+my $lastNodeId = undef; # ID of last NUMA node to be used (undef means last available)
+my $nodesPerBoard = 1; # number of NUMA nodes per nodeboard
+my %cpusetConf = (
+ cpus => undef, # undef means auto-generate
+ mems => undef, # undef means auto-generate
+ cpu_exclusive => 1, #
+ mem_exclusive => 1, #
+ );
+my %options = (
+ -doLayout => 1, # generate mom.layout
+ -withCpus => 1, # include cpus in mom.layout
+ -withMems => 1, # include mems in mom.layout
+ -doCpuset => 1, # generate/modify /torque cpuset
+ -withSmt => 1, # include logical processors running on the same core
+ -verbose => undef, # be verbose to STDERR
+ -dryRun => undef, # no actions, just tell what would be done
+ );
+
+# ==============================================================================
+# Command line options
+# ==============================================================================
+
+GetOptions(
+ "layout!" => \$options{-doLayout},
+ "cpus!" => \$options{-withCpus},
+ "mems!" => \$options{-withMems},
+ "smt!" => \$options{-withSmt},
+ "cpuset!" => \$options{-doCpuset},
+ "dry-run!" => \$options{-dryRun},
+ "verbose!" => \$options{-verbose},
+ "help|?" => sub { usage(0) },
+ "man" => sub { manPage() },
+ ) or usage(2);
+
+if($options{-dryRun}) {
+ $options{-verbose} = 1 unless defined $options{-verbose};
+ xDebug(">>> DryRunDryRunDryRunDryRunDryRun <<<");
+}
+
+# ==============================================================================
+# Quick exit if not wanted on this host, or if no work to do
+# ==============================================================================
+
+#if(defined $hostNames) {
+# unless($hostName =~ /$hostNames/) {
+# xDebug("--- Don't run on $hostName ---");
+# exit 0;
+# }
+#}
+
+exit 0 unless ($options{-doLayout} || $options{-doCpuset});
+
+# ==============================================================================
+# See if PBS_HOME is set, and if $PBS_HOME/mom_priv exists.
+# If not, we are probably not called correctly, thus die.
+# See if cpusets are configured. If not, die.
+# ==============================================================================
+
+die "\$PBS_HOME not set\n" unless (exists $ENV{PBS_HOME} && $ENV{PBS_HOME});
+die "PBS_HOME=$ENV{PBS_HOME} does not exist\n" unless -d $ENV{PBS_HOME};
+$momCfgDir = "$ENV{PBS_HOME}/${momCfgDir}";
+die "MOM config dir $momCfgDir does not exist\n" unless -d $momCfgDir;
+$momLayoutFile = "${momCfgDir}/${momLayoutFile}";
+die "this system does not support cpusets\n" unless -d $cpusetFsName;
+
+# ==============================================================================
+# Figure out system topology, collect wanted node objects
+# ==============================================================================
+
+my $topology = Sys::Hwloc::Topology->init;
+die "Failed to init topology\n" unless defined $topology;
+$topology->set_flags(HWLOC_TOPOLOGY_FLAG_WHOLE_SYSTEM);
+die("Failed to load topology\n") if $topology->load;
+
+# ==============================================================================
+# Collect nodesets of wanted NUMA nodes per nodeBoard
+# ==============================================================================
+
+my @nodeBoards = ();
+my $nodeObj = undef;
+my $nNodes = 0;
+while($nodeObj = $topology->get_next_obj_by_type(HWLOC_OBJ_NODE, $nodeObj)) {
+ my $nodeId = $nodeObj->logical_index;
+ next if $nodeId < $firstNodeId;
+ last if (defined $lastNodeId && $nodeId > $lastNodeId);
+ if($nNodes) {
+ $nodeBoards[$#nodeBoards]->{nodeset}->or($nodeObj->nodeset);
+ } else {
+ push @nodeBoards, {
+ cpuset => Sys::Hwloc::Bitmap->new,
+ nodeset => $nodeObj->nodeset->dup,
+ };
+ }
+ $nNodes++;
+ $nNodes = 0 if $nNodes >= $nodesPerBoard;
+}
+
+# ==============================================================================
+# Assemble cpusets per nodeBoard
+# ==============================================================================
+
+foreach my $nodeBoard (@nodeBoards) {
+ $topology->cpuset_from_nodeset_strict($nodeBoard->{cpuset}, $nodeBoard->{nodeset});
+ next if $options{-withSmt};
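+  # With -nosmt, drop all but the first PU of every core from the board's cpuset.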
+ my $core = undef;
+ while($core = $topology->get_next_obj_inside_cpuset_by_type($nodeBoard->{cpuset}, HWLOC_OBJ_CORE, $core)) {
+ my $j = 1;
+ while (my $pu = $topology->get_obj_inside_cpuset_by_type($core->cpuset, HWLOC_OBJ_PU, $j++)) {
+ $nodeBoard->{cpuset}->andnot($pu->cpuset);
+ }
+ }
+}
+
+# ==============================================================================
+# Generate mom.layout
+# ==============================================================================
+
+if($options{-doLayout}) {
+
+ xDebug("--- Generating $momLayoutFile ---");
+ if(! $options{-dryRun}) {
+ open(FILE, "> $momLayoutFile") or die "failed to open $momLayoutFile: $!\n";
+ }
+ foreach my $nodeBoard (@nodeBoards) {
+ my $line = sprintf("nodes=%s", $nodeBoard->{nodeset}->sprintf_list);
+ $line .= sprintf(" cpus=%s", $nodeBoard->{cpuset}->sprintf_list) if $options{-withCpus};
+ $line .= sprintf(" mems=%s", $nodeBoard->{nodeset}->sprintf_list) if $options{-withMems};
+ xDebug(" $line");
+ print FILE "$line\n" unless $options{-dryRun};
+ }
+ close(FILE) unless $options{-dryRun};
+
+}
+
+# ==============================================================================
+# Create/modify torque cpuset
+# ==============================================================================
+
+if($options{-doCpuset}) {
+
+ # Create it if it is not there
+ my $cpusetPath = "${cpusetFsName}${cpusetBaseName}";
+ if(! -d $cpusetPath) {
+ xDebug("--- Creating $cpusetPath ---");
+ my $rc = execCmd($mkdirCmd,1,$cpusetPath);
+ die "Failed to create $cpusetPath\n" unless defined $rc;
+ }
+
+ # Read content
+ xDebug("--- Reading $cpusetPath ---");
+ my $cpusetData = readCpuset($cpusetPath);
+ die "Failed to read $cpusetPath\n" unless defined $cpusetData;
+
+ # Assemble changes
+ my %cpusetMod = ();
+ foreach my $key (keys %cpusetConf) {
+ next unless exists $cpusetData->{$key};
+ my $val = $cpusetConf{$key};
+ CASE: {
+ $key eq 'cpus' && do {
+ if(! defined $val) {
+ my $cpuset = Sys::Hwloc::Bitmap->new;
+ foreach my $nodeBoard (@nodeBoards) {
+ $cpuset->or($nodeBoard->{cpuset});
+ }
+ $val = $cpuset->sprintf_list;
+ $cpuset->free;
+ }
+ last CASE;
+ };
+ $key eq 'mems' && do {
+ if(! defined $val) {
+ my $nodeset = Sys::Hwloc::Bitmap->new;
+ foreach my $nodeBoard (@nodeBoards) {
+ $nodeset->or($nodeBoard->{nodeset});
+ }
+ $val = $nodeset->sprintf_list;
+ $nodeset->free;
+ }
+ last CASE;
+ };
+ }
+ next unless defined $val;
+ if(
+ (! defined $cpusetData->{$key}) ||
+ (defined $cpusetData->{$key} && $cpusetData->{$key} ne $val)
+ ) {
+ $cpusetMod{$key} = $val;
+ }
+ }
+
+ # Write changes, if any. Don't abort on error, but warn if changes not done
+ if(%cpusetMod) {
+ xDebug("--- Modifying $cpusetPath ---");
+ if($options{-dryRun}) {
+ while(my ($key, $val) = each %cpusetMod) {
+ xDebug(sprintf(" = cpuset %s: %-25s %s", $cpusetPath, $key, $val));
+ }
+ } else {
+ while(my ($key, $val) = each %cpusetMod) {
+ my $out = execCmd($echoCmd, 0, "$val > ${cpusetPath}/$key");
+ }
+ if($options{-verbose}) {
+ $cpusetData = readCpuset($cpusetPath);
+ die "Failed to read $cpusetPath\n" unless defined $cpusetData;
+ while(my ($key, $val) = each %cpusetMod) {
+ xDebug(sprintf(" %s cpuset %s: %-25s %s", $val eq $cpusetData->{$key} ? '=' : '-', $cpusetPath, $key, $val));
+ }
+ }
+ }
+ }
+}
+
+# ==============================================================================
+# All done
+# ==============================================================================
+
+$topology->destroy;
+
+exit 0;
+
+# #############################################################################
+
+# ==============================================================================
+# Read cpuset data into a hash, return 0 on error, 1 on success
+# ==============================================================================
+
+sub readCpuset {
+ my $cpusetPath = shift;
+ my $cpusetData = {};
+
+ # Check if cpuset exists
+ unless(-d $cpusetPath) {
+ xDebug("ERROR: Cpuset $cpusetPath does not exist.");
+ return undef;
+ }
+
+ # Read content of cpuset
+ foreach my $key (qw(
+ cpu_exclusive
+ cpus
+ mem_exclusive
+ mem_hardwall
+ memory_migrate
+ memory_pressure
+ memory_spread_page
+ memory_spread_slab
+ mems
+ notify_on_release
+ sched_load_balance
+ sched_relax_domain_level
+ )) {
+ my $f = "${cpusetPath}/$key";
+ next unless -e $f;
+ my $rc = execCmd($catCmd,0,$f);
+ return undef unless defined $rc; # Command failed
+ my $val = undef;
+ if(@{$rc}) {
+ CASE: {
+ $key eq 'tasks' && do { $val = join(",", @{$rc}); last CASE };
+ $val = $rc->[0];
+ }
+ }
+ xDebug(sprintf(" cpuset %s: %-25s %s", $cpusetPath, $key, defined $val ? $val : "NO DATA"));
+ $cpusetData->{$key} = $val;
+ }
+
+ return $cpusetData;
+
+}
+
+# ==============================================================================
+# Execute a command with args.
+# Returns arrayref with chomped output on success.
+# On command failure, print error msg and return undef.
+# ==============================================================================
+
+sub execCmd {
+ my $cmdBase = shift;
+ my $verbose = shift;
+ my @cmdArgs = @_;
+
+ if(! $cmdBase) {
+ xDebug("ERROR execCmd: need \$cmdBase.");
+ return undef;
+ }
+
+ # --
+ # Check if cmdBase is executable
+ # --
+
+ if(! -x $cmdBase) {
+ xDebug("ERROR: File \"$cmdBase\" does not exist or is not executable.");
+ return undef;
+ }
+
+ # --
+ # Execute
+ # --
+
+ my $cmd = $cmdBase;
+ $cmd .= (" " . join(" ", @cmdArgs)) if @cmdArgs;
+ xDebug(" About to execute \"$cmd\"") if $verbose;
+ open(CMD, "$cmd 2>&1 |") or do {
+ xDebug("ERROR: Failed to execute \"$cmd\": $!");
+ return undef;
+ };
+
+ my @cmdOut = (<CMD>);
+ chomp @cmdOut;
+
+ close(CMD);
+ my $rc = $? >> 8;
+ if($rc) {
+ xDebug("ERROR: Command \"$cmd\" returned rc = $rc");
+ if(@cmdOut) {
+ xDebug(join("\n", map { " $_" } grep { /\S/ } $#cmdOut < 3 ? @cmdOut : (@cmdOut[0..2], "...")));
+ }
+ return undef;
+ }
+
+ # --
+ # Return output
+ # --
+
+ return \@cmdOut;
+
+}
+
+# ==============================================================================
+# Usage message
+# ==============================================================================
+
+sub usage {
+ my $code = shift || 0;
+ pod2usage(
+ -verbose => 0,
+ -exitval => "NOEXIT",
+ );
+ exit $code;
+}
+
+# ==============================================================================
+# Man page
+# ==============================================================================
+
+sub manPage {
+ if ($< == 0) { # Cannot invoke perldoc as root
+ my $id = eval { getpwnam("nobody") };
+ $id = eval { getpwnam("nouser") } unless defined $id;
+ $id = -2 unless defined $id;
+ $< = $id;
+ }
+ $> = $<; # Disengage setuid
+ $ENV{PATH} = "/bin:/usr/bin"; # Untaint PATH
+ delete @ENV{ 'IFS', 'CDPATH', 'ENV', 'BASH_ENV' };
+ if ($0 =~ /^([-\/\w\.]+)$/) {
+ $0 = $1; # Untaint $0
+ } else {
+ die "Illegal characters were found in \$0 ($0)\n";
+ }
+ pod2usage(
+ -verbose => 2,
+ -exitval => 0,
+ );
+}
+
+# ==============================================================================
+# Verbose printing
+# ==============================================================================
+
+sub xDebug {
+ return unless $options{-verbose};
+ my $msg = join("", @_);
+ if($msg) {
+ foreach(split("\n", $msg)) {
+ print STDERR "$progName - $_\n"
+ }
+ } else {
+ print STDERR "$progName - something to debug\n";
+ }
+}
+
+sub xDie {
+ die "$progName - ", @_;
+}
+
+__END__
+
+=head1 NAME
+
+mom_gencfg - Create mom.layout and /dev/cpuset/torque, designed to be called from /etc/init.d/pbs_mom
+
+=head1 SYNOPSIS
+
+mom_gencfg --help|-?|--man
+
+mom_gencfg -(no)layout -(no)cpus -(no)mems -(no)cpuset -(no)smt -(no)dry-run -(no)verbose
+
+=head1 DESCRIPTION
+
+This script creates /var/spool/torque/mom_priv/mom.layout and creates/modifies /dev/cpuset/torque
+for a pbs_mom that is compiled with --enable-numa-support.
+
+The basic configuration like number and offset of NUMA node IDs per nodeboard,
+cpuset settings, and defaults of command line options is hardcoded in the script.
+
+The script checks if I<PBS_HOME> is set in the environment. Usually this should point to
+/var/spool/torque.
+
+=head1 OPTIONS
+
+=over 4
+
+=item B<-(no)layout>
+
+Create the mom.layout file or not.
+
+=item B<-(no)cpus>
+
+mom.layout contains cpu IDs per nodeboard or not.
+
+=item B<-(no)mems>
+
+mom.layout contains memory node IDs per nodeboard or not.
+
+=item B<-(no)cpuset>
+
+Create/modify /dev/cpuset/torque or not.
+
+=item B<-(no)smt>
+
+The I<cpus> entry in mom.layout and in /dev/cpuset/torque contain additional
+logical processors running on the same core or not.
+
+=item B<-(no)dry-run>
+
+If B<-dry-run> is given, show what would have been done. Switches B<-verbose> on, unless B<-noverbose> was given.
+
+=item B<-(no)verbose>
+
+Verbose printing to STDERR.
+
+=item B<-man>
+
+Prints this man page.
+
+=item B<-help|-?>
+
+Prints synopsis.
+
+=back
+
+=head1 AUTHOR
+
+Bernd Kallies, E<lt>kallies@zib.deE<gt>
+
+=head1 COPYRIGHT AND LICENSE
+
+Copyright (C) 2011 Zuse Institute Berlin
+
+This library is free software; you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation.
+
+=cut
diff --git a/sbin/restart-anaconda b/sbin/restart-anaconda
new file mode 100755
index 0000000..5e654ef
--- /dev/null
+++ b/sbin/restart-anaconda
@@ -0,0 +1,34 @@
+#! /bin/bash
+#
+# restart-anaconda: Debugging tool to restart stage2 Anaconda.
+#
+# Copyright (C) 2010
+# Red Hat, Inc. All rights reserved.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+
+rm -rf /tmp/updates || echo "Error removing /tmp/updates. Updates won't be re-downloaded." >&2
+
+if [[ -f /var/run/iscsid.pid ]]; then
+ # iscsid must die else it will cause us troubles on the next run
+ # log out of all nodes
+ /sbin/iscsiadm -m node --logoutall=all
+fi
+
+# This will kill all programs in the anaconda group and restart the
+# service.
+systemctl stop anaconda.service
+anaconda-cleanup
+systemctl start --no-block anaconda.service
diff --git a/sbin/suncli b/sbin/suncli
new file mode 100755
index 0000000..221932c
--- /dev/null
+++ b/sbin/suncli
@@ -0,0 +1,38 @@
+#!/opt/sunpy3/bin/python3
+import os
+import sys
+import sunhpc
+import sunhpc.invoke
+import logging.handlers
+
+if sys.version_info.major < 3:
+ print("Sunhpc cluster supports only Python3. Rerun application in Python3 environment.")
+ exit(0)
+
+from sunhpc.console import SunhpcConsole
+
+sunhpc_home = os.environ.get('SUNHPCHOME')
+if sunhpc_home:
+ log_file = os.path.join(sunhpc_home, 'logs', 'runSunhpc.log')
+else:
+ log_file = os.path.join('/opt/sunhpc', 'logs', 'runSunhpc.log')
+
+log_handler = logging.handlers.RotatingFileHandler(filename=log_file, maxBytes=500000)
+log_formatter = logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
+log_handler.setFormatter(log_formatter)
+LOGGER = logging.getLogger()
+LOGGER.setLevel(logging.DEBUG)
+LOGGER.addHandler(log_handler)
+
+def sunhpcApplication(argv):
+ hpc = SunhpcConsole()
+ if len(argv[1:]):
+ hpc.nonInteractive(argv)
+ else:
+ hpc.start()
+
+if __name__ == "__main__":
+ try:
+ sunhpcApplication(sys.argv)
+ except (KeyboardInterrupt, SystemExit):
+ pass
diff --git a/sbin/sunyums b/sbin/sunyums
new file mode 100755
index 0000000..fc4eab8
--- /dev/null
+++ b/sbin/sunyums
@@ -0,0 +1,178 @@
+#!/usr/bin/python
+#coding:utf-8
+import os, sys
+import yum, pickle
+import tempfile
+usages = \
+"""
+Usage: sunyums [OPTION]... [FILE]...
+ Output and match all dependent installation packages
+
+Example:
+ sunyums packname1 packname2
+ sunyums packname1 packname2 --config=file --comps=comps.xml
+ sunyums packname1 packname2 --mandatory=1 --default=1 --options=0
+
+Options:
+ --config=file.conf supply a yum config file, default: optional
+ --comps=comps.xml supply a parsed comps.xml default: optional
+ --mandatory=True include mandatory packages default: True
+ --default=True include default packages default: True
+ --options=False include optional packages default: False
+"""
+class Application(object):
+
+ def __init__(self, args):
+ self.args = args[1:]
+ self.yums = yum.YumBase()
+ self.comps = None
+ self.config = None
+ self.groups = []
+ self.mandatory = True
+ self.default = True
+ self.options = False
+
+ self.basePacks = []
+ self.origPacks = []
+ self.packages = []
+
+ def str2bool(self, s):
+ """Converts an on/off, yes/no, true/false string to 1/0."""
+ if s and s.upper() in [ 'ON', 'YES', 'Y', 'TRUE', '1', 'ENABLED', 'ENABLE']:
+ return True
+ else:
+ return False
+
+ def usages(self):
+ print usages
+ sys.exit(0)
+
+ def parseArgs(self):
+
+ if not self.args:
+ self.usages()
+
+ for arg in self.args:
+ if arg in [ '-h', '--help']:
+ self.usages()
+
+ elif arg.startswith('--comps='):
+ self.comps = arg.split('=')[1]
+
+ elif arg.startswith('--config='):
+ self.config = arg.split('=')[1]
+
+ elif arg.startswith('--mandatory='):
+ self.mandatory = self.str2bool(arg.split('=')[1])
+
+ elif arg.startswith('--default='):
+ self.default = self.str2bool(arg.split('=')[1])
+
+ elif arg.startswith('--options='):
+ self.options = self.str2bool(arg.split('=')[1])
+
+ else:
+ self.groups.append(arg)
+
+ def depends(self):
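+        # Iteratively expand self.basePacks with everything the current
+        # package set requires, until no new packages are found.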
+ pkgs = []
+ avail = self.yums.pkgSack.returnNewestByNameArch()
+ for p in avail:
+ if p.name in self.basePacks:
+ pkgs.append(p)
+
+ done = 0
+ while not done:
+ done = 1
+ results = self.yums.findDeps(pkgs)
+ for pkg in results.keys():
+ for req in results[pkg].keys():
+ reqlist = results[pkg][req]
+ for r in reqlist:
+ if r.name not in self.basePacks:
+ self.basePacks.append(r.name)
+ pkgs.append(r)
+ done = 0
+
+ def allgroups(self):
+ for grp in self.yums.comps.groups:
+ self.packages.extend(grp.packages)
+
+ def handerPackages(self, name):
+ if not self.packages:
+ self.allgroups()
+
+ if name in self.packages and \
+ name not in self.basePacks:
+ self.basePacks.append(name)
+
+ if name not in self.origPacks:
+ self.origPacks.append(name)
+
+ def handerGroups(self, name):
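+        # Expand a comps group into its member packages, honoring the
+        # mandatory/default/options switches.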
+ groups = []
+ if not self.yums.comps.has_group(name):
+ return
+
+ valid_groups = self.yums.comps.return_group(name.encode('utf-8'))
+ if self.mandatory:
+ groups.extend(valid_groups.mandatory_packages.keys())
+ if self.default:
+ groups.extend(valid_groups.default_packages.keys())
+ if self.options:
+ groups.extend(valid_groups.options_packages.keys())
+
+ for package in groups:
+ self.handerPackages(package)
+
+ def handerEnviron(self, name):
+ groups = []
+ if not self.yums.comps.has_environment(name):
+ return
+
+ valid_environ = self.yums.comps.return_environment(name)
+ for grp in valid_environ.groups:
+ self.handerGroups(grp)
+
+ def run(self):
+
+ if self.comps and os.path.exists(self.comps):
+ self.yums.comps.add(self.comps)
+
+ if self.config and os.path.exists(self.config):
+ self.yums.doConfigSetup(fn=self.config, init_plugins=False)
+
+ self.yums.conf.cache = 0
+ for rpm in self.groups:
+ if rpm[0] == '@':
+ self.handerGroups(rpm[1:])
+
+ elif rpm[0] == '^':
+ self.handerEnviron(rpm[1:])
+
+ else:
+ self.handerPackages(rpm)
+
+ self.depends()
+
+ for o in self.origPacks:
+ if o not in self.basePacks:
+ print '#%s' % o
+
+ for p in self.basePacks:
+ print p
+
+if __name__ == "__main__":
+ app = Application(sys.argv)
+ app.parseArgs()
+ app.allgroups()
+ app.run()
+
+
+
+
+
+
+
+
+
diff --git a/sbin/unsquashfs b/sbin/unsquashfs
new file mode 100755
index 0000000..36c3773
--- /dev/null
+++ b/sbin/unsquashfs
Binary files differ